import torchvision.transforms.transforms as T

from torchvision.datasets import CIFAR10
from torch.utils.data import DataLoader
from torch.utils.data.distributed import DistributedSampler


# Per-channel RGB mean/std conventionally used with ImageNet-pretrained
# torchvision models; applied here to CIFAR-10 (presumably for transfer
# learning at 224x224 — confirm against the model in use).
_IMAGENET_MEAN = (0.485, 0.456, 0.406)
_IMAGENET_STD = (0.229, 0.224, 0.225)


def _build_transform(train):
    """Return the CIFAR-10 preprocessing pipeline (adds a random flip when *train*)."""
    steps = [T.Resize((224, 224))]
    if train:
        # Augmentation is applied to the training split only.
        steps.append(T.RandomHorizontalFlip(p=0.5))
    steps += [T.ToTensor(), T.Normalize(_IMAGENET_MEAN, _IMAGENET_STD)]
    return T.Compose(steps)


def prepare_dataloader(data_root,
                       train=True,
                       batch_size=512,
                       *,
                       eval_batch_size=16,
                       num_workers=0,
                       pin_memory=False,
                       download=True):
    """Build a CIFAR-10 DataLoader for training (distributed) or evaluation.

    Args:
        data_root: Directory where CIFAR-10 is stored / downloaded.
        train: If True, build the distributed training loader; otherwise the
            test-split loader.
        batch_size: Per-process batch size of the training loader.
        eval_batch_size: Batch size of the test loader (previously fixed at 16).
        num_workers: Worker subprocesses per loader (0 = load in main process).
        pin_memory: Forwarded to ``DataLoader``.
        download: Forwarded to ``CIFAR10``. In multi-process runs consider
            downloading once beforehand and passing False, so that ranks do
            not download into the same directory concurrently.

    Returns:
        ``(loader, sampler)`` for training, where ``sampler`` is the
        ``DistributedSampler`` (call ``sampler.set_epoch(epoch)`` before each
        epoch so shuffling differs between epochs); ``(loader, None)`` for
        evaluation.
    """
    loader_kwargs = {"num_workers": num_workers, "pin_memory": pin_memory}

    if train:
        dataset = CIFAR10(data_root, train=True,
                          transform=_build_transform(True), download=download)
        # Constructed without num_replicas/rank, so it reads them from the
        # default torch.distributed process group, which must be initialized.
        sampler = DistributedSampler(dataset)
        # shuffle must stay False: shuffling is delegated to the sampler, and
        # DataLoader rejects shuffle=True together with an explicit sampler.
        loader = DataLoader(dataset, shuffle=False, batch_size=batch_size,
                            sampler=sampler, **loader_kwargs)
        return loader, sampler

    dataset = CIFAR10(data_root, train=False,
                      transform=_build_transform(False), download=download)
    # Not sharded: every process calling this iterates the full test split.
    loader = DataLoader(dataset, shuffle=False, batch_size=eval_batch_size,
                        **loader_kwargs)
    return loader, None
