import torchvision.transforms as T
from torchvision.datasets import CIFAR10
from torch.utils.data import DataLoader
from torch.utils.data.distributed import DistributedSampler

# Per-channel mean/std for Normalize: (x - 0.5) / 0.5 maps ToTensor's [0, 1]
# range to [-1, 1]. Shared by the train and test pipelines so they stay in sync.
_CIFAR10_MEAN = (0.5, 0.5, 0.5)
_CIFAR10_STD = (0.5, 0.5, 0.5)


def prepare_dataloader(data_root, train=True, batch_size=512, test_batch_size=16):
    """Build a CIFAR-10 DataLoader for training or evaluation.

    Args:
        data_root: Directory where CIFAR-10 is stored (downloaded if missing).
        train: If True, return a distributed, augmented training loader;
            otherwise return a plain, non-augmented test loader.
        batch_size: Per-process batch size for the training loader.
        test_batch_size: Batch size for the test loader (defaults to 16,
            the value previously hard-coded).

    Returns:
        A ``(dataloader, sampler)`` tuple. For training, ``sampler`` is the
        ``DistributedSampler`` backing the loader; for testing it is ``None``.
    """
    normalize = T.Normalize(_CIFAR10_MEAN, _CIFAR10_STD)

    # NOTE(review): download=True runs in every process that calls this; in a
    # multi-rank job several ranks may try to download concurrently. Consider
    # downloading on rank 0 and synchronizing first — TODO confirm for this setup.
    if train:
        train_transform = T.Compose([
            T.RandomHorizontalFlip(p=0.5),
            T.RandomAffine(degrees=15, translate=(0.1, 0.1)),
            T.ToTensor(),
            normalize,
        ])
        train_dataset = CIFAR10(
            data_root, train=True, transform=train_transform, download=True
        )
        # DistributedSampler shards the dataset across ranks and shuffles on its
        # own. Without explicit num_replicas/rank, it needs an initialized
        # default process group.
        sampler = DistributedSampler(train_dataset)
        # shuffle must stay False: DataLoader rejects shuffle=True combined with
        # a custom sampler.
        train_dataloader = DataLoader(
            train_dataset, shuffle=False, batch_size=batch_size, sampler=sampler
        )
        # The sampler is returned so the caller can call sampler.set_epoch(epoch)
        # each epoch; otherwise every epoch reuses the same shuffle order.
        return train_dataloader, sampler

    test_transform = T.Compose([
        T.ToTensor(),
        normalize,
    ])
    test_dataset = CIFAR10(
        data_root, train=False, transform=test_transform, download=True
    )
    # Evaluation is not sharded: each process iterates the full test set in order.
    test_dataloader = DataLoader(
        test_dataset, shuffle=False, batch_size=test_batch_size
    )
    return test_dataloader, None