import os
from pathlib import Path

import torch
from torchvision.datasets import CIFAR10
from torchvision.models import resnet18
from torchvision.transforms import transforms

from colossalai.legacy.utils import get_dataloader

from .registry import non_distributed_component_funcs


def get_cifar10_dataloader(train, *, batch_size=16):
    """Build a CIFAR-10 dataloader through ColossalAI's ``get_dataloader``.

    The dataset root is the directory named by the ``DATA`` environment
    variable. ``download=True`` is passed to torchvision's ``CIFAR10``.
    Images are converted with ``ToTensor`` and then normalized per channel
    with mean 0.5 and std 0.5.

    Args:
        train: If true, load the training split; otherwise the test split.
        batch_size: Number of samples per batch (keyword-only, default 16).

    Returns:
        The object returned by ``get_dataloader`` for this dataset, built with
        ``shuffle=True`` and ``drop_last=True``.

    Raises:
        KeyError: If the ``DATA`` environment variable is not set.
    """
    # NOTE(review): download=True may reach the network the first time the
    # dataset is requested on a machine; confirm this is acceptable in CI.
    dataset = CIFAR10(
        root=Path(os.environ["DATA"]),
        download=True,
        train=train,
        transform=transforms.Compose(
            [transforms.ToTensor(), transforms.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5))]
        ),
    )
    # Shuffle and drop_last apply to both the train and the test split.
    return get_dataloader(dataset=dataset, shuffle=True, batch_size=batch_size, drop_last=True)


@non_distributed_component_funcs.register(name="resnet18")
def get_resnet_training_components():
    """Return the components used to train ResNet-18 on CIFAR-10.

    Registered under the name ``"resnet18"`` in
    ``non_distributed_component_funcs``.

    Returns:
        A 5-tuple ``(model_builder, trainloader, testloader, optimizer_class,
        criterion)``:

        - ``model_builder``: a callable that returns a new torchvision
          ``resnet18`` with 10 output classes.
        - ``trainloader`` / ``testloader``: CIFAR-10 train / test loaders
          from ``get_cifar10_dataloader``.
        - ``optimizer_class``: the ``torch.optim.Adam`` class itself (not an
          instance).
        - ``criterion``: a ``torch.nn.CrossEntropyLoss`` instance.
    """

    def model_builder(checkpoint=False):
        # ``checkpoint`` is accepted but unused here.
        # NOTE(review): presumably kept so every registered model builder
        # shares one call signature -- confirm against the other components.
        return resnet18(num_classes=10)

    trainloader = get_cifar10_dataloader(train=True)
    testloader = get_cifar10_dataloader(train=False)
    criterion = torch.nn.CrossEntropyLoss()
    return model_builder, trainloader, testloader, torch.optim.Adam, criterion