resnet.py 1.17 KB
Newer Older
1
import os
2
3
from pathlib import Path

4
5
import torch
from torchvision.datasets import CIFAR10
6
7
8
9
10
11
from torchvision.models import resnet18
from torchvision.transforms import transforms

from colossalai.legacy.utils import get_dataloader

from .registry import non_distributed_component_funcs
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27


def get_cifar10_dataloader(train, batch_size=16):
    """Build a CIFAR10 dataloader rooted at the ``DATA`` environment variable.

    Args:
        train: Forwarded to ``CIFAR10(train=...)`` to choose between the
            training split and the test split.
        batch_size: Batch size forwarded to ``get_dataloader``. Defaults to
            16, the value that used to be hard-coded, so existing callers
            are unaffected.

    Returns:
        Whatever ``get_dataloader`` returns for the dataset; it is requested
        with ``shuffle=True`` and ``drop_last=True`` for both splits.

    Raises:
        KeyError: If the ``DATA`` environment variable is not set.
    """
    # The dataset location is taken from the environment, not from an argument.
    data_root = Path(os.environ['DATA'])

    # Convert images to tensors, then normalise every channel with
    # mean 0.5 and std 0.5.
    preprocess = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)),
    ])

    dataset = CIFAR10(root=data_root, download=True, train=train, transform=preprocess)
    return get_dataloader(dataset=dataset, shuffle=True, batch_size=batch_size, drop_last=True)


@non_distributed_component_funcs.register(name='resnet18')
def get_resnet_training_components():
    """Assemble the components registered under the name ``'resnet18'``.

    Returns:
        A 5-tuple ``(model_builder, trainloader, testloader, optimizer_class,
        criterion)``:

        * ``model_builder`` -- callable that creates a ``resnet18`` with
          ``num_classes=10``.
        * ``trainloader`` / ``testloader`` -- results of
          ``get_cifar10_dataloader`` for ``train=True`` and ``train=False``.
        * ``optimizer_class`` -- ``torch.optim.Adam`` itself (the class, not
          an instance); the caller is expected to instantiate it.
        * ``criterion`` -- a ``torch.nn.CrossEntropyLoss`` instance.
    """

    def model_builder(checkpoint=False):
        # ``checkpoint`` is accepted but has no effect on the model built here.
        return resnet18(num_classes=10)

    trainloader = get_cifar10_dataloader(train=True)
    testloader = get_cifar10_dataloader(train=False)

    criterion = torch.nn.CrossEntropyLoss()
    return model_builder, trainloader, testloader, torch.optim.Adam, criterion