attr_dataset.py 1.15 KB
Newer Older
dengjb's avatar
update  
dengjb committed
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
# encoding: utf-8
"""
@author:  liaoxingyu
@contact: sherlockliao01@gmail.com
"""

import torch
from torch.utils.data import Dataset

from fastreid.data.data_utils import read_image


class AttrDataset(Dataset):
    """Dataset pairing person images with their attribute label vectors.

    Each entry of ``img_items`` is unpacked as ``(img_path, labels)``.
    Indexing loads the image through ``read_image``, runs the optional
    ``transform`` on it, and converts the labels to a tensor.

    Args:
        img_items: indexable, iterable collection of ``(img_path, labels)``
            pairs.
        transform: callable applied to the loaded image, or ``None`` to
            return the image exactly as ``read_image`` produced it.
        attr_dict: sized container of attributes; its length is reported
            as ``num_classes``.
    """

    def __init__(self, img_items, transform, attr_dict):
        self.attr_dict = attr_dict
        self.transform = transform
        self.img_items = img_items

    def __len__(self):
        """Return the number of ``(img_path, labels)`` items."""
        return len(self.img_items)

    def __getitem__(self, index):
        """Load one sample.

        Returns:
            dict with keys ``"images"`` (the loaded, optionally transformed
            image), ``"targets"`` (labels as a tensor) and ``"img_paths"``
            (the path stored in the item).
        """
        path, attrs = self.img_items[index]

        image = read_image(path)
        if self.transform is not None:
            image = self.transform(image)

        return dict(
            images=image,
            targets=torch.as_tensor(attrs),
            img_paths=path,
        )

    @property
    def num_classes(self):
        """Number of attributes, i.e. ``len(attr_dict)``."""
        return len(self.attr_dict)

    @property
    def sample_weights(self):
        """Per-attribute mean of the label vectors over the whole dataset.

        Recomputed on every access by summing each item's labels into a
        float32 accumulator of length ``num_classes`` and dividing by
        ``len(self)``.
        NOTE(review): presumably labels are 0/1 indicators, making this the
        positive-sample ratio per attribute -- confirm against the caller.
        """
        totals = torch.zeros(self.num_classes, dtype=torch.float32)
        for _path, attr_vec in self.img_items:
            totals.add_(torch.as_tensor(attr_vec))
        return totals.div_(len(self))