import os
import copy

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.distributed as dist

import pytest

import deepspeed
import deepspeed.runtime.utils as ds_utils

from deepspeed.runtime.pipe.topology import PipeDataParallelTopology, PipeModelDataParallelTopology
# Short alias for the pipeline x data-parallel topology class; used to build the
# parametrized topologies of test_pipe_cifar10.
PipeTopo = PipeDataParallelTopology
from deepspeed.runtime.pipe.module import PipelineModule, LayerSpec
from .common import distributed_test


def rel_diff(A, B):
    """Return the relative difference ``|A - B| / |A|`` of B from reference A.

    When the reference ``A`` is zero the ratio is undefined. In that case 0.0
    is returned if ``B`` is also zero, and ``inf`` otherwise. Callers that
    compare against a tolerance therefore fail the comparison instead of
    crashing with ZeroDivisionError.
    """
    if A == 0:
        return 0.0 if B == 0 else float('inf')
    return abs(A - B) / abs(A)


# All models
from .simple_model import args_from_dict


class AlexNet(nn.Module):
    """Small AlexNet-style CNN whose forward pass returns the loss directly.

    ``features`` holds five convolutions with ReLU activations and three
    max-pools. Its output is flattened and fed to a single ``nn.Linear(256,
    num_classes)``. The input must therefore shrink to 256 channels at 1x1
    spatial size, which is what 32x32 (CIFAR-sized) images produce.
    """

    def __init__(self, num_classes=10):
        super().__init__()
        # Layers are created in this exact order so that seeded weight
        # initialization and state_dict keys ("features.<idx>...") are stable.
        layers = [
            nn.Conv2d(3, 64, kernel_size=11, stride=4, padding=5),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=2, stride=2),
            nn.Conv2d(64, 192, kernel_size=5, padding=2),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=2, stride=2),
            nn.Conv2d(192, 384, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
            nn.Conv2d(384, 256, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
            nn.Conv2d(256, 256, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=2, stride=2),
        ]
        self.features = nn.Sequential(*layers)
        self.classifier = nn.Linear(256, num_classes)
        self.loss_fn = nn.CrossEntropyLoss()

    def forward(self, x, y):
        """Return the cross-entropy loss of the logits for ``x`` against labels ``y``."""
        feats = self.features(x)
        flat = feats.view(feats.size(0), -1)
        logits = self.classifier(flat)
        return self.loss_fn(logits, y)


class AlexNetPipe(AlexNet):
    """AlexNet that can be unrolled into a flat list of layers for PipelineModule."""

    def to_layers(self):
        """Return the network as an ordered list of callables.

        The list holds every submodule of ``features``, then a flatten step,
        then ``classifier``. The loss function is not included; callers pass
        it to PipelineModule separately.
        """
        layers = list(self.features)
        layers.append(lambda x: x.view(x.size(0), -1))
        layers.append(self.classifier)
        return layers


class AlexNetPipeSpec(PipelineModule):
    """AlexNet described with ``LayerSpec`` entries instead of built modules.

    The layer sequence mirrors ``AlexNet`` (with functional ``F.relu`` for all
    but the first activation), followed by a flatten step and the classifier.
    Cross-entropy is attached as the pipeline's ``loss_fn``. Any extra keyword
    arguments (e.g. ``topology`` or ``num_stages``) are forwarded unchanged to
    ``PipelineModule``.
    """

    def __init__(self, num_classes=10, **kwargs):
        # Stored before PipelineModule.__init__ so it can be read while the
        # specs below are assembled.
        self.num_classes = num_classes
        specs = [
            LayerSpec(nn.Conv2d, 3, 64, kernel_size=11, stride=4, padding=5),
            LayerSpec(nn.ReLU, inplace=True),
            LayerSpec(nn.MaxPool2d, kernel_size=2, stride=2),
            LayerSpec(nn.Conv2d, 64, 192, kernel_size=5, padding=2),
            F.relu,
            LayerSpec(nn.MaxPool2d, kernel_size=2, stride=2),
            LayerSpec(nn.Conv2d, 192, 384, kernel_size=3, padding=1),
            F.relu,
            LayerSpec(nn.Conv2d, 384, 256, kernel_size=3, padding=1),
            F.relu,
            LayerSpec(nn.Conv2d, 256, 256, kernel_size=3, padding=1),
            F.relu,
            LayerSpec(nn.MaxPool2d, kernel_size=2, stride=2),

            # Flatten (batch, 256, 1, 1) feature maps for the linear classifier.
            lambda x: x.view(x.size(0), -1),
            LayerSpec(nn.Linear, 256, self.num_classes),  # classifier
        ]
        super().__init__(layers=specs, loss_fn=nn.CrossEntropyLoss(), **kwargs)


def cifar_trainset(fp16=False):
    """Return the CIFAR-10 training split, letting only one rank per machine download it.

    Images are converted to tensors and normalized per channel with mean 0.5
    and std 0.5, which maps them to [-1, 1]. When ``fp16`` is true they are
    additionally cast to half precision.

    Every rank must call this collectively because it synchronizes with
    ``dist.barrier()``. The rank whose current CUDA device is 0 performs the
    download while the others wait, then all ranks load the cached copy.
    """
    import torchvision
    import torchvision.transforms as transforms

    half = (0.5, ) * 3
    pipeline = [transforms.ToTensor(), transforms.Normalize(half, half)]
    if fp16:
        pipeline.append(torchvision.transforms.Lambda(lambda img: img.half()))
    transform = transforms.Compose(pipeline)

    # The current CUDA device index is treated as the local rank.
    is_downloader = torch.cuda.current_device() == 0

    dist.barrier()
    # Non-downloading ranks park here until the downloader reaches its barrier.
    if not is_downloader:
        dist.barrier()
    trainset = torchvision.datasets.CIFAR10(root='/tmp/cifar10-data',
                                            train=True,
                                            download=True,
                                            transform=transform)
    # Once the data is on disk, the downloader releases the waiting ranks.
    if is_downloader:
        dist.barrier()
    return trainset


def train_cifar(model, args, num_steps=400, average_dp_losses=True, fp16=True, seed=123):
    """Train ``model`` on CIFAR-10 through a DeepSpeed engine; return per-step losses.

    All work runs inside ``torch.random.fork_rng``, so the seed set here does
    not alter the caller's RNG state after the function returns.

    Args:
        model: module passed to ``deepspeed.initialize``. It is switched to
            eval mode to disable dropout.
        args: DeepSpeed argument namespace. ``args.local_rank`` is overwritten
            with the global rank.
        num_steps: number of ``engine.train_batch()`` calls.
        average_dp_losses: if true, the loss history is summed across all ranks
            with ``all_reduce`` and divided by the world size, so every rank
            returns the same averaged list.
        fp16: forwarded to ``cifar_trainset``; casts images to half precision.
        seed: random seed applied via ``ds_utils.set_random_seed``.

    Returns:
        A list of ``num_steps`` float losses.
    """
    with torch.random.fork_rng(devices=[torch.cuda.current_device()]):
        ds_utils.set_random_seed(seed)

        # disable dropout
        model.eval()

        trainset = cifar_trainset(fp16=fp16)
        # NOTE(review): the global rank is used as the local rank, which only
        # coincides on a single node.
        args.local_rank = dist.get_rank()

        engine, _, _, _ = deepspeed.initialize(
            args=args,
            model=model,
            model_parameters=list(model.parameters()),
            training_data=trainset)

        losses = []
        for step in range(num_steps):
            loss_value = engine.train_batch().item()
            losses.append(loss_value)
            if step % 50 == 0 and dist.get_rank() == 0:
                print(f'STEP={step} LOSS={loss_value}')

        if average_dp_losses:
            loss_tensor = torch.tensor(losses).cuda()
            dist.all_reduce(loss_tensor)
            loss_tensor /= dist.get_world_size()
            losses = loss_tensor.tolist()

    return losses


@pytest.mark.skip(reason="been seeing nondeterministic failures, skipping for now")
@pytest.mark.parametrize('topo',
                         [
                             PipeTopo(num_pp=1,
                                      num_dp=4),
                             PipeTopo(num_pp=2,
                                      num_dp=2),
                             PipeTopo(num_pp=4,
                                      num_dp=1),
                         ])
def test_pipe_cifar10(topo, tmpdir):
    """Check that pipeline-parallel training of AlexNet tracks a data-parallel baseline.

    Both runs start from deep copies of the same ``AlexNetPipe`` instance and
    train on CIFAR-10 with the same DeepSpeed config on 4 workers. The average
    of the last 100 per-step losses must agree within 3% relative difference.
    """
    config_dict = {
        "train_batch_size": 16,
        "train_micro_batch_size_per_gpu": 4,
        "steps_per_print": 20,
        "optimizer": {
            "type": "Adam",
            "params": {
                "lr": 0.001,
                "betas": [0.9,
                          0.999],
                "eps": 1e-8,
                "weight_decay": 3e-7
            }
        },
        "zero_optimization": {
            "stage": 0
        },
        "fp16": {
            "enabled": False
        },
        "pipeline": {
            "seed_layers": True,
            "activation_checkpoint_interval": 1
        }
    }
    args = args_from_dict(tmpdir, config_dict)

    # Allocate model for consistent initial weights.
    init_net = AlexNetPipe()

    @distributed_test(world_size=4)
    def _helper(topo, tmpdir, steps=500):
        # At least 100 steps are needed for the last-100 comparison below.
        assert steps >= 100

        # Baseline: a single pipeline stage, i.e. just data parallelism.
        base_net = copy.deepcopy(init_net)
        base_model = PipelineModule(layers=base_net.to_layers(),
                                    num_stages=1,
                                    loss_fn=nn.CrossEntropyLoss())
        base_losses = train_cifar(base_model,
                                  args,
                                  num_steps=steps,
                                  fp16=config_dict['fp16']['enabled'])

        # Candidate: identical initial weights, partitioned according to `topo`.
        test_net = copy.deepcopy(init_net)
        test_model = PipelineModule(layers=test_net.to_layers(),
                                    topology=topo,
                                    loss_fn=nn.CrossEntropyLoss())
        test_losses = train_cifar(test_model,
                                  args,
                                  num_steps=steps,
                                  fp16=config_dict['fp16']['enabled'])

        # Per-step differences; note these are signed (base - test) even though
        # they are reported under the "abs" label.
        abs_diffs = [l0 - l1 for l0, l1 in zip(base_losses, test_losses)]
        rel_diffs = [rel_diff(l0, l1) for l0, l1 in zip(base_losses, test_losses)]
        if dist.get_rank() == 0:
            print(
                f'abs min={min(abs_diffs)} max={max(abs_diffs)} avg={sum(abs_diffs)/len(abs_diffs)}'
            )
            print(
                f'rel min={min(rel_diffs)} max={max(rel_diffs)} avg={sum(rel_diffs)/len(rel_diffs)}'
            )
            print(
                f'first: base={base_losses[0]} test={test_losses[0]} abs={abs_diffs[0]} rel={rel_diffs[0]}'
            )

            for last_x in [1, 10, 100]:
                base_tail = base_losses[-last_x:]
                test_tail = test_losses[-last_x:]
                # Divide by the real slice length in case fewer steps were run.
                base_avg = sum(base_tail) / len(base_tail)
                test_avg = sum(test_tail) / len(test_tail)
                print(
                    f'last-{last_x}: base={base_avg} test={test_avg} abs={base_avg - test_avg} rel={rel_diff(base_avg, test_avg)}'
                )

        last_x = 100
        base = base_losses[-last_x:]
        base_avg = sum(base) / len(base)
        test = test_losses[-last_x:]
        test_avg = sum(test) / len(test)
        assert rel_diff(base_avg, test_avg) < 0.03

    _helper(topo, tmpdir)