# Copyright (c) Microsoft Corporation.
# SPDX-License-Identifier: Apache-2.0

# DeepSpeed Team

import pytest
import torch
import torch.nn as nn
import torch.nn.functional as F
import deepspeed
import deepspeed.comm as dist
import deepspeed.runtime.utils as ds_utils
from deepspeed.accelerator import get_accelerator
from deepspeed.runtime.pipe.module import PipelineModule, LayerSpec


class AlexNet(nn.Module):
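    """AlexNet-style CNN sized for CIFAR-10 (3x32x32 inputs).

    forward() takes (inputs, labels) and returns the cross-entropy loss, so the
    model can be driven directly by a DeepSpeed engine.
    """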

    def __init__(self, num_classes=10):
        super(AlexNet, self).__init__()
        # Feature extractor sized for 3x32x32 CIFAR images; after the three
        # 2x2 max-pools the output is 256x1x1, matching the Linear(256, ...) head below.
        self.features = nn.Sequential(
            nn.Conv2d(3, 64, kernel_size=11, stride=4, padding=5),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=2, stride=2),
            nn.Conv2d(64, 192, kernel_size=5, padding=2),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=2, stride=2),
            nn.Conv2d(192, 384, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
            nn.Conv2d(384, 256, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
            nn.Conv2d(256, 256, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=2, stride=2),
        )
        self.classifier = nn.Linear(256, num_classes)
        self.loss_fn = nn.CrossEntropyLoss()

    def forward(self, x, y):
        x = self.features(x)
        x = x.view(x.size(0), -1)
        x = self.classifier(x)
        return self.loss_fn(x, y)


class AlexNetPipe(AlexNet):
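    """AlexNet flattened into a list of callables for pipeline parallelism.

    The list returned by to_layers() is intended to be fed to
    PipelineModule(layers=...); the lambda handles the flatten step between the
    convolutional features and the linear classifier.
    """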

    def to_layers(self):
        layers = [*self.features, lambda x: x.view(x.size(0), -1), self.classifier]
        return layers


class AlexNetPipeSpec(PipelineModule):
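    """AlexNet expressed as LayerSpecs for DeepSpeed's PipelineModule.

    LayerSpec defers layer construction until the module has been partitioned
    across pipeline stages, so each rank only materializes the layers of its
    own stage. Parameter-free steps (F.relu, the flatten lambda) are passed as
    plain callables.
    """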

    def __init__(self, num_classes=10, **kwargs):
        self.num_classes = num_classes
        specs = [
            LayerSpec(nn.Conv2d, 3, 64, kernel_size=11, stride=4, padding=5),
            LayerSpec(nn.ReLU, inplace=True),
            LayerSpec(nn.MaxPool2d, kernel_size=2, stride=2),
            LayerSpec(nn.Conv2d, 64, 192, kernel_size=5, padding=2),
            F.relu,
            LayerSpec(nn.MaxPool2d, kernel_size=2, stride=2),
            LayerSpec(nn.Conv2d, 192, 384, kernel_size=3, padding=1),
            F.relu,
            LayerSpec(nn.Conv2d, 384, 256, kernel_size=3, padding=1),
            F.relu,
            LayerSpec(nn.Conv2d, 256, 256, kernel_size=3, padding=1),
            F.relu,
            LayerSpec(nn.MaxPool2d, kernel_size=2, stride=2),
            lambda x: x.view(x.size(0), -1),
            LayerSpec(nn.Linear, 256, self.num_classes),  # classifier
        ]
        super().__init__(layers=specs, loss_fn=nn.CrossEntropyLoss(), **kwargs)


# Define this here because we cannot pickle local lambda functions
def cast_to_half(x):
    return x.half()


def cifar_trainset(fp16=False):
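    """Return a normalized CIFAR-10 training set (optionally cast to fp16).

    Only one rank per machine downloads the data; the surrounding barriers make
    the other local ranks wait and then read from the already-populated cache.
    """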
    torchvision = pytest.importorskip("torchvision", minversion="0.5.0")
    import torchvision.transforms as transforms

    transform_list = [
        transforms.ToTensor(),
        transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
    ]
    if fp16:
        transform_list.append(torchvision.transforms.Lambda(cast_to_half))

    transform = transforms.Compose(transform_list)

    local_rank = get_accelerator().current_device()

    # Only one rank per machine downloads.
    dist.barrier()
    if local_rank != 0:
        dist.barrier()
    trainset = torchvision.datasets.CIFAR10(root='/blob/cifar10-data', train=True, download=True, transform=transform)
    if local_rank == 0:
        dist.barrier()
    return trainset


def train_cifar(model, config, num_steps=400, average_dp_losses=True, fp16=True, seed=123):
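    """Run ``num_steps`` of engine.train_batch() on CIFAR-10 and return the loss history.

    Seeds are fixed and the model is put in eval mode (no dropout) for
    reproducibility; if ``average_dp_losses`` is set, the per-step losses are
    all-reduced and averaged across data-parallel ranks.
    """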
    with get_accelerator().random().fork_rng(devices=[get_accelerator().current_device_name()]):
        ds_utils.set_random_seed(seed)

        # disable dropout
        model.eval()

        trainset = cifar_trainset(fp16=fp16)
        config['local_rank'] = dist.get_rank()

        engine, _, _, _ = deepspeed.initialize(config=config,
                                               model=model,
                                               model_parameters=[p for p in model.parameters()],
                                               training_data=trainset)

        losses = []
        for step in range(num_steps):
            loss = engine.train_batch()
            losses.append(loss.item())
            if step % 50 == 0 and dist.get_rank() == 0:
                print(f'STEP={step} LOSS={loss.item()}')

        if average_dp_losses:
            loss_tensor = torch.tensor(losses).to(get_accelerator().device_name())
            dist.all_reduce(loss_tensor)
            loss_tensor /= dist.get_world_size()
            losses = loss_tensor.tolist()

    return losses
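

# ---------------------------------------------------------------------------
# Hedged usage sketch (not part of the original test helpers): roughly how the
# pieces above fit together when this file is launched directly under the
# `deepspeed` launcher. The config dict below is a minimal assumption rather
# than a config taken from the test suite, and the hardcoded CIFAR-10 root in
# cifar_trainset() must exist or be writable for the download to succeed.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    deepspeed.init_distributed()
    ds_config = {
        "train_batch_size": 16,
        "train_micro_batch_size_per_gpu": 4,
        "optimizer": {
            "type": "Adam",
            "params": {
                "lr": 0.001
            }
        },
    }
    # Single pipeline stage assumed here; more stages require more ranks.
    net = AlexNetPipeSpec(num_classes=10, num_stages=1)
    losses = train_cifar(net, config=ds_config, num_steps=10, fp16=False)
    if dist.get_rank() == 0:
        print(f"final loss: {losses[-1]:.4f}")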