datasets.py 2.47 KB
Newer Older
dengjb's avatar
update  
dengjb committed
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
# encoding: utf-8
"""
@author:  xingyu liao
@contact: sherlockliao01@gmail.com
"""

import os

from fastreid.data.datasets import DATASET_REGISTRY
from fastreid.data.datasets.bases import ImageDataset

# Public dataset classes exported by a star-import of this module.
__all__ = [
    "Cars196",
    "CUB",
    "SOP",
    "InShop",
]


@DATASET_REGISTRY.register()
class Cars196(ImageDataset):
    """Cars196 retrieval dataset read from comma-separated label files.

    Expects ``<root>/Cars_196/train.txt`` and ``<root>/Cars_196/test.txt``,
    each non-blank line holding ``<relative image path>,<label>``. The test
    split becomes the query set and the gallery is left empty.

    Subclasses reuse this loader by overriding ``dataset_dir`` and
    ``dataset_name``.
    """
    dataset_dir = 'Cars_196'
    dataset_name = "cars"

    def __init__(self, root='datasets', **kwargs):
        self.root = root
        self.dataset_dir = os.path.join(self.root, self.dataset_dir)
        train_file = os.path.join(self.dataset_dir, "train.txt")
        test_file = os.path.join(self.dataset_dir, "test.txt")

        required_files = [
            self.dataset_dir,
            train_file,
            test_file,
        ]
        self.check_before_run(required_files)

        train = self.process_label_file(train_file, is_train=True)
        query = self.process_label_file(test_file, is_train=False)

        super(Cars196, self).__init__(train, query, [], **kwargs)

    def process_label_file(self, file, is_train):
        """Parse a ``<image path>,<label>`` file into ``(path, label, camid)`` tuples.

        Args:
            file: path of the label file to read.
            is_train: when True, each label is prefixed with
                ``dataset_name + '_'`` (presumably to keep ids distinct when
                several datasets are trained together -- TODO confirm).

        Returns:
            list of ``(image path joined onto dataset_dir, label str, '0')``.

        Raises:
            ValueError: if a non-blank line does not contain exactly one comma.
        """
        data_list = []
        with open(file, 'r') as f:
            lines = f.read().splitlines()

        for lineno, raw_line in enumerate(lines, start=1):
            line = raw_line.strip()
            # Blank lines (e.g. a trailing empty line) carry no sample; skip
            # them instead of failing to unpack.
            if not line:
                continue
            try:
                img_name, label = line.split(',')
            except ValueError:
                raise ValueError(
                    "{}:{}: expected '<image>,<label>', got {!r}".format(file, lineno, raw_line)
                ) from None
            # Surrounding whitespace would otherwise end up in the path / label.
            img_name, label = img_name.strip(), label.strip()
            if is_train:
                label = self.dataset_name + '_' + label

            data_list.append((os.path.join(self.dataset_dir, img_name), label, '0'))

        return data_list


@DATASET_REGISTRY.register()
class CUB(Cars196):
    """CUB_200_2011 dataset, parsed with the inherited Cars196 train/test loader."""

    dataset_name = "cub"
    dataset_dir = "CUB_200_2011"


@DATASET_REGISTRY.register()
class SOP(Cars196):
    """Stanford_Online_Products dataset, parsed with the inherited Cars196 loader."""

    dataset_name = "sop"
    dataset_dir = "Stanford_Online_Products"


@DATASET_REGISTRY.register()
class InShop(Cars196):
    """InShop retrieval dataset with separate query and gallery label files.

    Reads ``train.txt``, ``test_query.txt`` and ``test_gallery.txt`` from
    ``<root>/InShop`` using the ``<image path>,<label>`` line format parsed by
    the inherited ``process_label_file``.
    """
    dataset_dir = "InShop"
    dataset_name = "inshop"

    def __init__(self, root="datasets", **kwargs):
        self.root = root
        self.dataset_dir = os.path.join(self.root, self.dataset_dir)
        train_file = os.path.join(self.dataset_dir, "train.txt")
        query_file = os.path.join(self.dataset_dir, "test_query.txt")
        gallery_file = os.path.join(self.dataset_dir, "test_gallery.txt")

        # Include the dataset directory itself, matching Cars196's check.
        required_files = [
            self.dataset_dir,
            train_file,
            query_file,
            gallery_file,
        ]
        self.check_before_run(required_files)

        train = self.process_label_file(train_file, is_train=True)
        query = self.process_label_file(query_file, is_train=False)
        gallery = self.process_label_file(gallery_file, is_train=False)

        # Deliberately skip Cars196.__init__ (it only knows train/test files)
        # and dispatch to the next class in the MRO, ImageDataset.__init__.
        super(Cars196, self).__init__(train, query, gallery, **kwargs)