get_transform.py 3 KB
Newer Older
Hang Zhang's avatar
Hang Zhang committed
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
##+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
## Created by: Hang Zhang
## Email: zhanghang0704@gmail.com
## Copyright (c) 2020
##
## LICENSE file in the root directory of this source tree 
##+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
import torch
from torchvision.transforms import *
from .transforms import *

def get_transform(dataset, base_size=None, crop_size=224, rand_aug=False, etrans=True, **kwargs):
    normalize = Normalize(mean=[0.485, 0.456, 0.406],
                          std=[0.229, 0.224, 0.225])
    base_size = base_size if base_size is not None else int(1.0 * crop_size / 0.875)
    if dataset == 'imagenet':
        train_transforms = []
        val_transforms = []
        if rand_aug:
            from .autoaug import RandAugment
            train_transforms.append(RandAugment(2, 12))
        if etrans:
            train_transforms.extend([
                ERandomCrop(crop_size),
            ])
            val_transforms.extend([
                ECenterCrop(crop_size),
            ])
            
        else:
            train_transforms.extend([
                RandomResizedCrop(crop_size),
            ])
            val_transforms.extend([
                Resize(base_size),
                CenterCrop(crop_size),
            ])
        train_transforms.extend([
                RandomHorizontalFlip(),
            ColorJitter(0.4, 0.4, 0.4),
            ToTensor(),
            Lighting(0.1, _imagenet_pca['eigval'], _imagenet_pca['eigvec']),
            normalize,
        ])
        val_transforms.extend([
            ToTensor(),
            normalize,
        ])
        transform_train = Compose(train_transforms)
        transform_val = Compose(val_transforms)
    elif dataset == 'minc':
        transform_train = Compose([
            Resize(base_size),
            RandomResizedCrop(crop_size),
            RandomHorizontalFlip(),
            ColorJitter(0.4, 0.4, 0.4),
            ToTensor(),
            Lighting(0.1, _imagenet_pca['eigval'], _imagenet_pca['eigvec']),
            normalize,
        ])
        transform_val = Compose([
            Resize(base_size),
            CenterCrop(crop_size),
            ToTensor(),
            normalize,
        ])
    elif dataset == 'cifar10':
        transform_train = transforms.Compose([
            transforms.RandomCrop(32, padding=4),
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor(),
            transforms.Normalize((0.4914, 0.4822, 0.4465), 
                                 (0.2023, 0.1994, 0.2010)),
        ])
        transform_val = transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize((0.4914, 0.4822, 0.4465), 
                    (0.2023, 0.1994, 0.2010)),
        ])
    return transform_train, transform_val

_imagenet_pca = {
    'eigval': torch.Tensor([0.2175, 0.0188, 0.0045]),
    'eigvec': torch.Tensor([
        [-0.5675,  0.7192,  0.4009],
        [-0.5808, -0.0045, -0.8140],
        [-0.5836, -0.6948,  0.4203],
    ])
}