# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
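#
# Benchmark of fairscale's optimizer state sharding (OSS) on a toy computer vision
# workload: ResNet101 trained on FakeData, with and without sharding, plus an
# optional ShardedDataParallel run. Launch directly, e.g.
# `python oss.py --world_size 2 --epochs 5` (illustrative invocation; adjust the
# flags defined below to your setup).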


import argparse
import math
import os
import time
from typing import Any, List, cast

import torch
import torch.distributed as dist
import torch.multiprocessing as mp
import torch.nn as nn
from torch.utils.data import DataLoader
from torchvision.datasets import FakeData
from torchvision.models import resnet101
from torchvision.transforms import ToTensor

from fairscale.nn.data_parallel import ShardedDataParallel
from fairscale.optim.oss import OSS

BACKEND = dist.Backend.NCCL if torch.cuda.is_available() else dist.Backend.GLOO  # type: ignore
OPTIM = torch.optim.RMSprop


def dist_init(rank, world_size):
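    # Single-node rendezvous: the master address/port are hardcoded, so each
    # spawned process only needs its rank and the world size.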
    os.environ["MASTER_ADDR"] = "localhost"
    os.environ["MASTER_PORT"] = "29501"
    dist.init_process_group(backend=BACKEND, rank=rank, world_size=world_size)


def get_problem(rank, data_size, batch_size):
    # Standard RN101
    model = resnet101(pretrained=False, progress=True).to(rank)

    # Data setup, dummy data
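    # The collate function stacks the FakeData samples and moves inputs and
    # labels to this rank's device up front.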
    def collate(inputs: List[Any]):
        return {
            "inputs": torch.stack([i[0] for i in inputs]).to(torch.device(rank)),
            "label": torch.stack([i[1] for i in inputs]).to(torch.device(rank)),
        }

    dataloader = DataLoader(
        dataset=FakeData(transform=ToTensor(), size=data_size), batch_size=batch_size, collate_fn=collate
    )
    loss_fn = nn.CrossEntropyLoss()
    return model, dataloader, loss_fn


def train_oss_ddp(
    rank: int, world_size: int, num_epochs: int = 10, batch_size: int = 32, data_size: int = 200,
):
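    # Benchmark path where ShardedDataParallel handles both the gradient
    # reduction and the sharded optimizer.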

    # DDP
    dist_init(rank, world_size)

    # Setup
    model, dataloader, loss_fn = get_problem(rank, data_size, batch_size)

    ddp = ShardedDataParallel(
        module=model, optimizer=torch.optim.SGD, optimizer_params={"lr": 1e-4, "momentum": 0.9}, world_size=world_size
    )
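    # ShardedDataParallel builds the sharded optimizer from the class and params
    # given above and exposes it, so it is pulled out here to drive the step() calls.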
    optimizer = ddp.optimizer

    # Reset the memory use counter
    torch.cuda.reset_peak_memory_stats(rank)

    # Dummy training loop
    torch.cuda.synchronize(rank)
    training_start = time.monotonic()
    model.train()

    measurements = []

    for epoch in range(num_epochs):
        epoch_start = time.monotonic()

        for batch in dataloader:

            def closure():
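                # Re-run forward + backward and return the loss; passing a closure
                # to optimizer.step() follows the standard torch.optim interface.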
                model.zero_grad()
                outputs = model(batch["inputs"])
                loss = loss_fn(outputs, batch["label"])
                dist.all_reduce(loss, op=dist.ReduceOp.SUM)
                loss /= world_size
                loss.backward()
                if dist.get_rank() == 0:
                    print(f"Loss: {loss.item()}")

                ddp.reduce()  # Send the gradients to the appropriate shards
                return loss

            optimizer.step(closure)

        epoch_end = time.monotonic()

        measurements.append(data_size / (epoch_end - epoch_start))
        if dist.get_rank() == 0:
            print(f"Epoch {epoch} - processed {measurements[-1]:.2f} img per sec")

    torch.cuda.synchronize(rank)
    training_stop = time.monotonic()
    img_per_sec = data_size / (training_stop - training_start) * num_epochs
    max_memory = torch.cuda.max_memory_allocated(rank) / 2 ** 20

    print(f"[{dist.get_rank()}] : Training done. {img_per_sec:.2f} img per sec overall")
    print(f"[{dist.get_rank()}] : Peak memory {max_memory:.1f}MiB")

    # Compute the mean and standard deviation of the per-epoch throughput
    mean = sum(measurements) / len(measurements)
    diff = map(lambda x: pow(x - mean, 2.0), measurements)
    std = math.sqrt(sum(diff) / (len(measurements) - 1))
    print(f"[{dist.get_rank()}] : Mean speed: {mean:.2f} +/- {std:.2f}")


def train(
    rank: int,
    world_size: int,
    num_epochs: int = 10,
    batch_size: int = 32,
    data_size: int = 200,
    use_oss: bool = True,
    check_regression: bool = True,
    reference_speed: float = -1.0,
    reference_memory: float = -1.0,
):
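    # Benchmark path where only the optimizer state is sharded (use_oss=True) or a
    # vanilla optimizer is used, optionally checking for speed/memory regressions.
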
    # DDP
    dist_init(rank, world_size)

    # Setup
    model, dataloader, loss_fn = get_problem(rank, data_size, batch_size)

    # Shard the optimizer state across ranks when use_oss is set, otherwise build
    # the plain reference optimizer
    optimizer: torch.optim.Optimizer = (
        OSS(params=model.parameters(), optim=OPTIM, lr=1e-4, momentum=0.9)
        if use_oss
        else OPTIM(model.parameters(), lr=1e-4, momentum=0.9)
    )

    # Reset the memory use counter
    torch.cuda.reset_peak_memory_stats(rank)

    # Dummy training loop
    torch.cuda.synchronize(rank)
    training_start = time.monotonic()
    model.train()

    measurements = []

    for epoch in range(num_epochs):
        epoch_start = time.monotonic()

        for batch in dataloader:

            def closure():
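                # Same closure pattern as above: recompute the loss and gradients
                # for optimizer.step().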
                model.zero_grad()
                outputs = model(batch["inputs"])
                loss = loss_fn(outputs, batch["label"])
                dist.all_reduce(loss, op=dist.ReduceOp.SUM)
                loss /= world_size
                loss.backward()
                return loss

            optimizer.step(closure)

        epoch_end = time.monotonic()

        if use_oss:
            # Exercise checkpointing with the OSS optimizer: consolidating the
            # sharded state dict on rank 0 is where memory use could spill over
            optimizer = cast(OSS, optimizer)
            optimizer.consolidate_state_dict()
            if dist.get_rank() == 0:
                _ = optimizer.state_dict()
                print("... State dict collected")

        measurements.append(data_size / (epoch_end - epoch_start))
        if dist.get_rank() == 0:
            print(f"Epoch {epoch} - processed {measurements[-1]:.2f} img per sec")

    torch.cuda.synchronize(rank)
    training_stop = time.monotonic()
    img_per_sec = data_size / (training_stop - training_start) * num_epochs
    max_memory = torch.cuda.max_memory_allocated(rank) / 2 ** 20

    print(f"[{dist.get_rank()}] : Training done. {img_per_sec:.2f} img per sec overall")
    print(f"[{dist.get_rank()}] : Peak memory {max_memory:.1f}MiB")

    # Compute the mean and standard deviation of the per-epoch throughput
    mean = sum(measurements) / len(measurements)
    diff = map(lambda x: pow(x - mean, 2.0), measurements)
    std = math.sqrt(sum(diff) / (len(measurements) - 1))
    print(f"[{dist.get_rank()}] : Mean speed: {mean:.2f} +/- {std:.2f}")

    if use_oss and check_regression and dist.get_rank() == 0:
        assert (mean + 3.0 * std) > reference_speed, "Speed regression detected"
        assert max_memory < 1.05 * reference_memory, "Memory use regression detected"
        print("[Regression Test] VALID")


if __name__ == "__main__":

    parser = argparse.ArgumentParser(
        description="Benchmark optimizer state sharding on a typical computer vision workload"
    )
    parser.add_argument("--world_size", action="store", default=2, type=int)
    parser.add_argument("--epochs", action="store", default=10, type=int)
    parser.add_argument("--batch_size", action="store", default=32, type=int)
    parser.add_argument("--data_size", action="store", default=512, type=int)
    parser.add_argument("--check_regression", action="store_true", default=False)
    parser.add_argument("--reference_speed", action="store", default=32.32, type=float)
    parser.add_argument("--reference_memory", action="store", default=4475, type=float)

    # Beta: benchmark the ShardedDataParallel (oss_ddp) path instead of the vanilla/OSS comparison
    parser.add_argument("--oss_ddp", action="store_true", default=False)

    args = parser.parse_args()
    print(f"Benchmark arguments: {args}")
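    # mp.spawn launches world_size processes and passes the process index (used
    # here as the rank) as the first argument to the target function.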

    if args.oss_ddp:
        print("\nBenchmark OSS DDP")
        mp.spawn(
            train_oss_ddp,
            args=(args.world_size, args.epochs, args.batch_size, args.data_size),
            nprocs=args.world_size,
            join=True,
        )
    else:
        print("\nBenchmark vanilla optimizer")
        mp.spawn(
            train,
            args=(args.world_size, args.epochs, args.batch_size, args.data_size, False, False),
            nprocs=args.world_size,
            join=True,
        )

        print("\nBenchmark OSS")
        mp.spawn(
            train,
            args=(
                args.world_size,
                args.epochs,
                args.batch_size,
                args.data_size,
                True,
                args.check_regression,
                args.reference_speed,
                args.reference_memory,
            ),
            nprocs=args.world_size,
            join=True,
        )