data.py 1.24 KB
Newer Older
mashun1's avatar
mashun1 committed
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
import torchvision.transforms.transforms as T

from torchvision.datasets import CIFAR10
from torch.utils.data import DataLoader
from torch.utils.data.distributed import DistributedSampler


def prepare_dataloader(data_root,
                       train=True,
                       batch_size=512,
                       *,
                       test_batch_size=16):
    """Build a CIFAR-10 DataLoader for training or evaluation.

    The dataset is fetched into ``data_root`` if it is not already present
    (``download=True``).

    Args:
        data_root: Root directory handed to ``torchvision.datasets.CIFAR10``.
        train: If True, build the augmented training loader driven by a
            ``DistributedSampler``; otherwise build the un-augmented test loader.
        batch_size: Batch size of the training loader. Unused when ``train``
            is False.
        test_batch_size: Batch size of the test loader. Unused when ``train``
            is True. Defaults to 16, the value that used to be hard-coded.

    Returns:
        A ``(dataloader, sampler)`` tuple. For training, ``sampler`` is the
        ``DistributedSampler`` feeding the loader (presumably returned so the
        caller can call ``sampler.set_epoch(epoch)`` each epoch, which PyTorch
        requires for the shuffle order to change between epochs — confirm
        against caller). For evaluation, ``sampler`` is None.
    """
    # ToTensor produces values in [0, 1]; mean = std = 0.5 per channel maps
    # them to [-1, 1]. Shared by both branches so the two stay consistent.
    normalize = T.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))

    if train:
        train_transform = T.Compose([
            T.RandomHorizontalFlip(p=0.5),
            T.RandomAffine(degrees=15, translate=(0.1, 0.1)),
            T.ToTensor(),
            normalize,
        ])

        # NOTE(review): every distributed rank runs this with download=True; on
        # a fresh data_root concurrent downloads may collide. Consider
        # downloading on rank 0 only, behind a barrier.
        train_dataset = CIFAR10(data_root, train=True, transform=train_transform, download=True)

        # With no num_replicas/rank given, DistributedSampler reads them from
        # the default process group, so torch.distributed must be initialised.
        sampler = DistributedSampler(train_dataset)

        # shuffle stays False: the sampler does the (per-rank) shuffling, and
        # DataLoader rejects shuffle=True combined with an explicit sampler.
        train_dataloader = DataLoader(train_dataset, shuffle=False, batch_size=batch_size, sampler=sampler)

        return train_dataloader, sampler

    test_transform = T.Compose([
        T.ToTensor(),
        normalize,
    ])

    test_dataset = CIFAR10(data_root, train=False, transform=test_transform, download=True)

    # No sampler here: every rank iterates over the full test set in order.
    test_dataloader = DataLoader(test_dataset, shuffle=False, batch_size=test_batch_size)

    return test_dataloader, None