"ts/webui/src/components/trial-detail/Intermediate.tsx" did not exist on "c329379db22fc19ba1038ea42cda491ff86039fe"
cifar100.py 2.17 KB
Newer Older
yuguo960516's avatar
yuguo960516 committed
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
from omegaconf import OmegaConf
from flowvision import transforms
from flowvision.data.mixup import Mixup
from flowvision.transforms import InterpolationMode
from flowvision.transforms.functional import str_to_interp_mode

from libai.data.datasets import CIFAR100Dataset
from libai.data.build import build_image_train_loader, build_image_test_loader
from libai.config import LazyCall

# Per-channel normalization statistics (one value per image channel) for the
# CIFAR-100 training split. Both the train and test pipelines below normalize
# with these same values.
CIFAR100_TRAIN_MEAN = (
    0.5070751592371323,
    0.48654887331495095,
    0.4409178433670343,
)
CIFAR100_TRAIN_STD = (
    0.2673342858792401,
    0.2564384629170883,
    0.27615047132568404,
)

# Training-time augmentation pipeline (lazily constructed):
#   1. RandomResizedCrop to 224x224, sampling crop area in [8%, 100%] of the
#      image and aspect ratio in [3/4, 4/3].
#   2. Random horizontal flip.
#   3. Conversion to tensor, then normalization with the CIFAR-100 training
#      statistics.
# NOTE(review): CIFAR-100 images are natively 32x32, so this crop upsamples
# them -- presumably to match a 224x224 model input; confirm against the model
# config.
# Bicubic interpolation is selected with the InterpolationMode enum, spelled the
# same way as in ``test_aug``, rather than via a string lookup.
train_aug = LazyCall(transforms.Compose)(
    transforms=[
        LazyCall(transforms.RandomResizedCrop)(
            size=(224, 224),
            scale=(0.08, 1.0),
            ratio=(3.0 / 4.0, 4.0 / 3.0),
            interpolation=InterpolationMode.BICUBIC,
        ),
        LazyCall(transforms.RandomHorizontalFlip)(),
        LazyCall(transforms.ToTensor)(),
        LazyCall(transforms.Normalize)(mean=CIFAR100_TRAIN_MEAN, std=CIFAR100_TRAIN_STD),
    ]
)

# Evaluation-time preprocessing (no random transforms): bicubic resize with
# size=256, center crop to 224, then tensor conversion and normalization with
# the same CIFAR-100 training statistics used by the training pipeline.
test_aug = LazyCall(transforms.Compose)(
    transforms=[
        LazyCall(transforms.Resize)(size=256, interpolation=InterpolationMode.BICUBIC),
        LazyCall(transforms.CenterCrop)(size=224),
        LazyCall(transforms.ToTensor)(),
        LazyCall(transforms.Normalize)(mean=CIFAR100_TRAIN_MEAN, std=CIFAR100_TRAIN_STD),
    ]
)


# Dataloader config: an initially empty OmegaConf node whose ``train`` entry
# lazily builds the training loader over the CIFAR-100 train split.
dataloader = OmegaConf.create()

dataloader.train = LazyCall(build_image_train_loader)(
    dataset=[
        # Train split stored under the current directory. NOTE(review):
        # download=True presumably fetches the archive when it is missing.
        LazyCall(CIFAR100Dataset)(root="./", train=True, download=True, transform=train_aug),
    ],
    num_workers=4,
    # Batch-level Mixup/CutMix (both alphas enabled, always applied, 50/50
    # switch between the two) producing mixed targets over the 100 classes.
    mixup_func=LazyCall(Mixup)(
        mixup_alpha=0.8,
        cutmix_alpha=1.0,
        prob=1.0,
        switch_prob=0.5,
        mode="batch",
        num_classes=100,
    ),
)

# Evaluation loaders: a single-entry list that lazily builds the loader for the
# CIFAR-100 test split, preprocessed with ``test_aug``.
dataloader.test = [
    LazyCall(build_image_test_loader)(
        dataset=LazyCall(CIFAR100Dataset)(
            root="./", train=False, download=True, transform=test_aug
        ),
        num_workers=4,
    ),
]