# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.


import argparse
import math
import time
from typing import Any, List, Optional, cast

import torch
import torch.distributed as dist
import torch.multiprocessing as mp
import torch.nn as nn
from torch.utils.data import DataLoader
from torchvision.datasets import FakeData
from torchvision.models import resnet101
from torchvision.transforms import ToTensor

from fairscale.nn.data_parallel import ShardedDataParallel
from fairscale.optim.oss import OSS

OPTIM = torch.optim.RMSprop
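# RMSprop keeps per-parameter running averages (plus a momentum buffer when momentum > 0),
# which is exactly the optimizer state that OSS shards across ranks below.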


def dist_init(rank, world_size, backend):
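    """Initialize the process group; all ranks rendezvous over TCP on a hard-coded local port."""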
    print(f"Using backend: {backend}")
    dist.init_process_group(backend=backend, init_method="tcp://localhost:29501", rank=rank, world_size=world_size)


def get_problem(rank, data_size, batch_size):
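    """Build a standard ResNet101, a dataloader over synthetic data, and the loss function."""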
    # Standard RN101
    model = resnet101(pretrained=False, progress=True).to(rank)

    # Data setup, dummy data
    def collate(inputs: List[Any]):
        return {
            "inputs": torch.stack([i[0] for i in inputs]).to(torch.device(rank)),
            "label": torch.stack([i[1] for i in inputs]).to(torch.device(rank)),
        }

    dataloader = DataLoader(
        dataset=FakeData(transform=ToTensor(), size=data_size), batch_size=batch_size, collate_fn=collate
    )
    loss_fn = nn.CrossEntropyLoss()
    return model, dataloader, loss_fn


def train_oss_ddp(
    rank: int, world_size: int, num_epochs: int = 10, batch_size: int = 32, data_size: int = 200, backend: str = "gloo",
):

    # DDP
    dist_init(rank, world_size, backend)

    # Setup
    model, dataloader, loss_fn = get_problem(rank, data_size, batch_size)

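    # ShardedDataParallel wraps the module together with a sharded (OSS) optimizer: each rank
    # only holds the optimizer state for the parameters it owns, and ddp.reduce() below routes
    # the gradients to their owning rank before the optimizer step.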
    ddp = ShardedDataParallel(
        module=model, optimizer=torch.optim.SGD, optimizer_params={"lr": 1e-4, "momentum": 0.9}, world_size=world_size
    )
    optimizer = ddp.optimizer

    # Reset the memory use counter
    torch.cuda.reset_peak_memory_stats(rank)

    # Dummy training loop
    torch.cuda.synchronize(rank)
    training_start = time.monotonic()
    model.train()

    measurements = []

    for epoch in range(num_epochs):
        epoch_start = time.monotonic()

        for batch in dataloader:

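            # The closure runs the forward/backward pass; optimizer.step(closure) evaluates it,
            # and the all-reduce below is only used to report a global loss value.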
            def closure():
                model.zero_grad()
                outputs = model(batch["inputs"])
                loss = loss_fn(outputs, batch["label"])
                loss /= world_size
                loss.backward()

                dist.all_reduce(loss, op=dist.ReduceOp.SUM)

                if dist.get_rank() == 0:
                    print(f"Loss: {loss.item()}")

                ddp.reduce()  # Send the gradients to the appropriate shards
                return loss

            optimizer.step(closure)

        epoch_end = time.monotonic()

        measurements.append(data_size / (epoch_end - epoch_start))
        if dist.get_rank() == 0:
            print(f"Epoch {epoch} - processed {measurements[-1]:.2f} img per sec")

    torch.cuda.synchronize(rank)
    training_stop = time.monotonic()
    img_per_sec = data_size / (training_stop - training_start) * num_epochs
    max_memory = torch.cuda.max_memory_allocated(rank) / 2 ** 20

    print(f"[{dist.get_rank()}] : Training done. {img_per_sec:.2f} img per sec overall")
    print(f"[{dist.get_rank()}] : Peak memory {max_memory:.1f}MiB")

    # Compute the mean and standard deviation of the per-epoch throughput (img per sec)
    mean = sum(measurements) / len(measurements)
    diff = map(lambda x: pow(x - mean, 2.0), measurements)
    std = math.sqrt(sum(diff) / (len(measurements) - 1))
    print(f"[{dist.get_rank()}] : Mean speed: {mean:.2f} +/- {std:.2f}")


def train(
    rank: int,
    world_size: int,
    num_epochs: int = 10,
    batch_size: int = 32,
    data_size: int = 200,
    backend: str = "gloo",
    use_oss: bool = True,
    check_regression: bool = True,
    reference_speed: float = -1.0,
    reference_memory: float = -1.0,
):
    # DDP
    dist_init(rank, world_size, backend)

    # Setup
    model, dataloader, loss_fn = get_problem(rank, data_size, batch_size)

    # Shard the optimizer
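    # OSS shards the optimizer state across ranks (ZeRO-style): each rank materializes the
    # RMSprop buffers only for its own shard of the parameters and broadcasts the updated
    # parameters to the other ranks after each step.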
    optimizer: torch.optim.Optimizer = (
        OSS(params=model.parameters(), optim=OPTIM, lr=1e-4, momentum=0.9)
        if use_oss
        else OPTIM(model.parameters(), lr=1e-4, momentum=0.9)
    )

    # Reset the memory use counter
    torch.cuda.reset_peak_memory_stats(rank)

    # Dummy training loop
    torch.cuda.synchronize(rank)
    training_start = time.monotonic()
    model.train()

    measurements = []
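    # optimizer.step(closure) returns the loss computed inside the closure; keep the last value for logging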
    final_loss: Optional[float] = -1.0

    for epoch in range(num_epochs):
        epoch_start = time.monotonic()

        for batch in dataloader:

            def closure():
                model.zero_grad()
                outputs = model(batch["inputs"])
                loss = loss_fn(outputs, batch["label"])
                loss /= world_size
                loss.backward()

                dist.all_reduce(loss, op=dist.ReduceOp.SUM)

                return loss

            final_loss = optimizer.step(closure)

        epoch_end = time.monotonic()

        if use_oss:
            # Check the checkpointing in the case of the OSS optimizer
            # Memory usage could spill over from there
            optimizer = cast(OSS, optimizer)
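            # Gather the sharded optimizer state onto the reference rank (rank 0 by default)
            # so that a full state_dict() can be produced there.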
            optimizer.consolidate_state_dict()
            if dist.get_rank() == 0:
                _ = optimizer.state_dict()
                print("... State dict collected")

        measurements.append(data_size / (epoch_end - epoch_start))
        if dist.get_rank() == 0:
            print(f"Epoch {epoch} - processed {measurements[-1]:.2f} img per sec. Loss {final_loss}")

    torch.cuda.synchronize(rank)
    training_stop = time.monotonic()
    img_per_sec = data_size / (training_stop - training_start) * num_epochs
    max_memory = torch.cuda.max_memory_allocated(rank) / 2 ** 20

    print(f"[{dist.get_rank()}] : Training done. {img_per_sec:.2f} img per sec overall")
    print(f"[{dist.get_rank()}] : Peak memory {max_memory:.1f}MiB")

    # Compute the mean and standard deviation of the per-epoch throughput (img per sec)
    mean = sum(measurements) / len(measurements)
    diff = map(lambda x: pow(x - mean, 2.0), measurements)
    std = math.sqrt(sum(diff) / (len(measurements) - 1))
    print(f"[{dist.get_rank()}] : Mean speed: {mean:.2f} +/- {std:.2f}")

    if use_oss and check_regression and dist.get_rank() == 0:
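        # Fail if throughput drops more than 3 standard deviations below the reference speed,
        # or if peak memory exceeds the reference by more than 5%.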
        assert (mean + 3.0 * std) > reference_speed, "Speed regression detected"
        assert max_memory < 1.05 * reference_memory, "Memory use regression detected"
        print("[Regression Test] VALID")


if __name__ == "__main__":
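    # Example invocation (assuming a 2-GPU host): python oss.py --world_size 2 --epochs 10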

    parser = argparse.ArgumentParser(
        description="Benchmark the optimizer state sharding, on a typical computer vision workload"
    )
    parser.add_argument("--world_size", action="store", default=2, type=int)
    parser.add_argument("--epochs", action="store", default=10, type=int)
    parser.add_argument("--batch_size", action="store", default=32, type=int)
    parser.add_argument("--data_size", action="store", default=512, type=int)
    parser.add_argument("--check_regression", action="store_true", default=False)
    parser.add_argument("--reference_speed", action="store", default=32.32, type=float)
    parser.add_argument("--reference_memory", action="store", default=4475, type=float)
    parser.add_argument("--gloo", action="store_true", default=False)

    # Beta: benchmark the experimental ShardedDataParallel (oss_ddp) path
    parser.add_argument("--oss_ddp", action="store_true", default=False)

    args = parser.parse_args()
    print(f"Benchmark arguments: {args}")

    backend = "gloo" if args.gloo or not torch.cuda.is_available() else "nccl"
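    # mp.spawn starts one process per rank; the rank index is passed automatically as the
    # first argument of the target function.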
    if args.oss_ddp:
        print("\nBenchmark OSS DDP")
        mp.spawn(
            train_oss_ddp,
            args=(args.world_size, args.epochs, args.batch_size, args.data_size, backend),
            nprocs=args.world_size,
            join=True,
        )
    else:
        print("\nBenchmark vanilla optimizer")
        mp.spawn(
            train,
            args=(args.world_size, args.epochs, args.batch_size, args.data_size, backend, False, False),
            nprocs=args.world_size,
            join=True,
        )

        print("\nBenchmark OSS")
        mp.spawn(
            train,
            args=(
                args.world_size,
                args.epochs,
                args.batch_size,
                args.data_size,
                backend,
                True,
                args.check_regression,
                args.reference_speed,
                args.reference_memory,
            ),
            nprocs=args.world_size,
            join=True,
        )