dataset_test.py 1.36 KB
Newer Older
dengjb's avatar
update  
dengjb committed
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
# encoding: utf-8
"""
@author:  liaoxingyu
@contact: sherlockliao01@gmail.com
"""

import sys
sys.path.append('.')
from data import get_dataloader
from config import cfg
import argparse
from data.datasets import init_dataset
# Datasets exercised by this smoke test. This default is applied at import
# time; a config file or command-line opts given below can still override it,
# because they are merged afterwards. To exercise the triplet sampler as well,
# set ``cfg.DATALOADER.SAMPLER = 'triplet'`` here.
cfg.DATASETS.NAMES = ("market1501", "dukemtmc", "cuhk03", "msmt17",)


if __name__ == '__main__':
    # Smoke test for the data pipeline: build the dataloaders from the
    # configured datasets. Constructing them is the whole test; the returned
    # value is not inspected.
    parser = argparse.ArgumentParser(description="ReID Baseline Training")
    parser.add_argument(
        '-cfg', "--config_file",
        default="",
        metavar="FILE",
        help="path to config file",
        type=str
    )
    parser.add_argument("opts", help="Modify config options using the command-line", default=None,
                        nargs=argparse.REMAINDER)
    args = parser.parse_args()

    # Apply the config file (when one is given) before the command-line opts,
    # so individual KEY VALUE pairs in ``opts`` take precedence over the file.
    if args.config_file:
        cfg.merge_from_file(args.config_file)
    cfg.merge_from_list(args.opts)

    get_dataloader(cfg)