face_data.py 2.32 KB
Newer Older
dengjb's avatar
update  
dengjb committed
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
# encoding: utf-8
"""
@author:  xingyu liao
@contact: sherlockliao01@gmail.com
"""

from PIL import Image
import io
import logging
import numbers

import torch
from torch.utils.data import Dataset

from fastreid.data.common import CommDataset

logger = logging.getLogger("fastreid.face_data")

try:
    import mxnet as mx
except ImportError:
    logger.info("Please install mxnet if you want to use .rec file")


class MXFaceDataset(Dataset):
    def __init__(self, path_imgrec, transforms):
        super().__init__()
        self.transforms = transforms

        logger.info(f"loading recordio {path_imgrec}...")
        path_imgidx = path_imgrec[0:-4] + ".idx"
        self.imgrec = mx.recordio.MXIndexedRecordIO(path_imgidx, path_imgrec, 'r')
        s = self.imgrec.read_idx(0)
        header, _ = mx.recordio.unpack(s)
        if header.flag > 0:
            # logger.debug(f"header0 label: {header.label}")
            self.header0 = (int(header.label[0]), int(header.label[1]))
            self.imgidx = list(range(1, int(header.label[0])))
            # logger.debug(self.imgidx)
        else:
            self.imgidx = list(self.imgrec.keys)
        logger.info(f"Number of Samples: {len(self.imgidx)}, "
                    f"Number of Classes: {int(self.header0[1] - self.header0[0])}")

    def __getitem__(self, index):
        idx = self.imgidx[index]
        s = self.imgrec.read_idx(idx)
        header, img = mx.recordio.unpack(s)
        label = header.label
        if not isinstance(label, numbers.Number):
            label = label[0]
        label = torch.tensor(label, dtype=torch.long)

        sample = Image.open(io.BytesIO(img))  # RGB
        if self.transforms is not None: sample = self.transforms(sample)
        return {
            "images": sample,
            "targets": label,
            "camids": 0,
        }

    def __len__(self):
        # logger.debug(f"mxface dataset length is {len(self.imgidx)}")
        return len(self.imgidx)

    @property
    def num_classes(self):
        return int(self.header0[1] - self.header0[0])


class TestFaceDataset(CommDataset):
    def __init__(self, img_items, labels):
        self.img_items = img_items
        self.labels = labels

    def __getitem__(self, index):
        img = torch.tensor(self.img_items[index]) * 127.5 + 127.5
        return {
            "images": img,
        }