# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.


import argparse
import math
import time
from typing import Any, List, Optional, cast

import torch
import torch.distributed as dist
import torch.multiprocessing as mp
import torch.nn as nn
from torch.utils.data import DataLoader
from torchvision.datasets import FakeData
from torchvision.models import resnet101
from torchvision.transforms import ToTensor

from fairscale.nn.data_parallel import ShardedDataParallel
from fairscale.optim.oss import OSS

# Communication backend: NCCL when CUDA is available, Gloo otherwise
if torch.cuda.is_available():
    BACKEND = dist.Backend.NCCL  # type: ignore
else:
    BACKEND = dist.Backend.GLOO  # type: ignore

# Optimizer class used by `train`, either wrapped by OSS or on its own as the baseline
OPTIM = torch.optim.RMSprop


def dist_init(rank, world_size):
    """Join the default process group for this benchmark process.

    Every process rendezvous over TCP on localhost port 29501, using the module-level ``BACKEND``.

    Args:
        rank: index of the calling process within the group.
        world_size: total number of processes in the group.
    """
    rendezvous_url = "tcp://localhost:29501"
    dist.init_process_group(
        backend=BACKEND,
        init_method=rendezvous_url,
        rank=rank,
        world_size=world_size,
        store=None,
    )


def _collate_to_device(inputs: List[Any], device: torch.device):
    """Stack a list of ``(image, label)`` samples into a batch dict placed on ``device``.

    Labels go through ``int()`` before building the tensor, so that both plain Python ints and
    0-dim tensors are accepted. Depending on the torchvision version, ``FakeData`` yields one or
    the other, and ``torch.stack`` only accepts tensors.

    Returns:
        dict with ``"inputs"`` (images stacked along a new leading dim) and ``"label"`` (1-D long tensor).
    """
    return {
        "inputs": torch.stack([sample[0] for sample in inputs]).to(device),
        "label": torch.tensor([int(sample[1]) for sample in inputs], dtype=torch.long).to(device),
    }


def get_problem(rank, data_size, batch_size):
    """Build the benchmark problem: an untrained ResNet101, a fake-data loader and a loss.

    Args:
        rank: device index used for the model and the batches (``torch.device(rank)``).
        data_size: number of fake samples in the dataset.
        batch_size: number of samples per batch.

    Returns:
        ``(model, dataloader, loss_fn)``. The dataloader yields dicts with ``"inputs"`` and ``"label"``
        already moved to the device.
    """
    device = torch.device(rank)

    # Standard RN101
    model = resnet101(pretrained=False, progress=True).to(rank)

    # Data setup, dummy data
    def collate(inputs: List[Any]):
        return _collate_to_device(inputs, device)

    dataloader = DataLoader(
        dataset=FakeData(transform=ToTensor(), size=data_size), batch_size=batch_size, collate_fn=collate
    )
    loss_fn = nn.CrossEntropyLoss()
    return model, dataloader, loss_fn


def train_oss_ddp(
    rank: int, world_size: int, num_epochs: int = 10, batch_size: int = 32, data_size: int = 200,
):
    """Benchmark one process of a ``ShardedDataParallel`` (SGD, lr=1e-4, momentum=0.9) ResNet101 run.

    Intended to be launched once per process via ``mp.spawn``. ``rank`` is also the CUDA device used
    for timing synchronization and memory statistics.

    Args:
        rank: process rank / CUDA device index.
        world_size: total number of processes.
        num_epochs: number of passes over the fake dataset; must be at least 1.
        batch_size: per-process batch size.
        data_size: number of fake samples per epoch, used to compute img/s.
    """
    # DDP
    dist_init(rank, world_size)

    # Setup
    model, dataloader, loss_fn = get_problem(rank, data_size, batch_size)

    ddp = ShardedDataParallel(
        module=model, optimizer=torch.optim.SGD, optimizer_params={"lr": 1e-4, "momentum": 0.9}, world_size=world_size
    )
    optimizer = ddp.optimizer

    # Reset the memory use counter
    torch.cuda.reset_peak_memory_stats(rank)

    # Dummy training loop
    torch.cuda.synchronize(rank)
    training_start = time.monotonic()
    model.train()

    measurements = []

    for epoch in range(num_epochs):
        epoch_start = time.monotonic()

        for batch in dataloader:

            def closure():
                model.zero_grad()
                outputs = model(batch["inputs"])
                loss = loss_fn(outputs, batch["label"])
                loss /= world_size
                loss.backward()

                # Each rank holds loss / world_size, so the summed value is the mean over ranks
                dist.all_reduce(loss, op=dist.ReduceOp.SUM)

                if dist.get_rank() == 0:
                    print(f"Loss: {loss.item()}")

                ddp.reduce()  # Send the gradients to the appropriate shards
                return loss

            optimizer.step(closure)

        # CUDA work is queued asynchronously: wait for it so the epoch time covers everything it launched
        torch.cuda.synchronize(rank)
        epoch_end = time.monotonic()

        measurements.append(data_size / (epoch_end - epoch_start))
        if dist.get_rank() == 0:
            print(f"Epoch {epoch} - processed {measurements[-1]:.2f} img per sec")

    torch.cuda.synchronize(rank)
    training_stop = time.monotonic()
    img_per_sec = data_size / (training_stop - training_start) * num_epochs
    max_memory = torch.cuda.max_memory_allocated(rank) / 2 ** 20

    print(f"[{dist.get_rank()}] : Training done. {img_per_sec:.2f} img per sec overall")
    print(f"[{dist.get_rank()}] : Peak memory {max_memory:.1f}MiB")

    # Compute the mean and standard deviation of the per-epoch img/s
    mean = sum(measurements) / len(measurements)
    # Sample (n - 1) std is undefined for a single epoch: report 0 rather than dividing by zero
    if len(measurements) > 1:
        std = math.sqrt(sum((m - mean) ** 2 for m in measurements) / (len(measurements) - 1))
    else:
        std = 0.0
    print(f"[{dist.get_rank()}] : Mean speed: {mean:.2f} +/- {std:.2f}")


def train(
    rank: int,
    world_size: int,
    num_epochs: int = 10,
    batch_size: int = 32,
    data_size: int = 200,
    use_oss: bool = True,
    check_regression: bool = True,
    reference_speed: float = -1.0,
    reference_memory: float = -1.0,
):
    """Benchmark one process of a ResNet101 run with either an OSS-sharded or a plain ``OPTIM`` optimizer.

    Intended to be launched once per process via ``mp.spawn``. ``rank`` is also the CUDA device used
    for timing synchronization and memory statistics.

    Args:
        rank: process rank / CUDA device index.
        world_size: total number of processes.
        num_epochs: number of passes over the fake dataset; must be at least 1.
        batch_size: per-process batch size.
        data_size: number of fake samples per epoch, used to compute img/s.
        use_oss: wrap ``OPTIM`` in ``OSS`` and exercise state-dict consolidation after each epoch.
        check_regression: on rank 0 with OSS, compare speed and peak memory with the references.
        reference_speed: expected img/s; fails if ``mean + 3 * std`` does not exceed it.
        reference_memory: expected peak memory in MiB; fails if usage reaches 105% of it.

    Raises:
        AssertionError: when ``check_regression`` detects a speed or memory regression.
    """
    # DDP
    dist_init(rank, world_size)

    # Setup
    model, dataloader, loss_fn = get_problem(rank, data_size, batch_size)

    # Shard the optimizer
    optimizer: torch.optim.Optimizer = (
        OSS(params=model.parameters(), optim=OPTIM, lr=1e-4, momentum=0.9)
        if use_oss
        else OPTIM(model.parameters(), lr=1e-4, momentum=0.9)
    )

    # Reset the memory use counter
    torch.cuda.reset_peak_memory_stats(rank)

    # Dummy training loop
    torch.cuda.synchronize(rank)
    training_start = time.monotonic()
    model.train()

    measurements = []
    final_loss: Optional[float] = None

    for epoch in range(num_epochs):
        epoch_start = time.monotonic()

        for batch in dataloader:

            def closure():
                model.zero_grad()
                outputs = model(batch["inputs"])
                loss = loss_fn(outputs, batch["label"])
                loss /= world_size
                loss.backward()

                # Each rank holds loss / world_size, so the summed value is the mean over ranks
                dist.all_reduce(loss, op=dist.ReduceOp.SUM)

                return loss

            # Kept as returned (presumably the closure's loss tensor); converting here would force a
            # host sync on every step and skew the benchmark
            final_loss = optimizer.step(closure)

        # CUDA work is queued asynchronously: wait for it so the epoch time covers everything it launched,
        # instead of letting it drain during the untimed checkpointing below
        torch.cuda.synchronize(rank)
        epoch_end = time.monotonic()

        if use_oss:
            # Check the checkpointing in the case of the OSS optimizer
            # Memory usage could spill over from there
            optimizer = cast(OSS, optimizer)
            optimizer.consolidate_state_dict()
            if dist.get_rank() == 0:
                _ = optimizer.state_dict()
                print("... State dict collected")

        measurements.append(data_size / (epoch_end - epoch_start))
        if dist.get_rank() == 0:
            loss_value = float(final_loss) if final_loss is not None else None
            print(f"Epoch {epoch} - processed {measurements[-1]:.2f} img per sec. Loss {loss_value}")

    torch.cuda.synchronize(rank)
    training_stop = time.monotonic()
    img_per_sec = data_size / (training_stop - training_start) * num_epochs
    max_memory = torch.cuda.max_memory_allocated(rank) / 2 ** 20

    print(f"[{dist.get_rank()}] : Training done. {img_per_sec:.2f} img per sec overall")
    print(f"[{dist.get_rank()}] : Peak memory {max_memory:.1f}MiB")

    # Compute the mean and standard deviation of the per-epoch img/s
    mean = sum(measurements) / len(measurements)
    # Sample (n - 1) std is undefined for a single epoch: report 0 rather than dividing by zero
    if len(measurements) > 1:
        std = math.sqrt(sum((m - mean) ** 2 for m in measurements) / (len(measurements) - 1))
    else:
        std = 0.0
    print(f"[{dist.get_rank()}] : Mean speed: {mean:.2f} +/- {std:.2f}")

    if use_oss and check_regression and dist.get_rank() == 0:
        assert (mean + 3.0 * std) > reference_speed, "Speed regression detected"
        assert max_memory < 1.05 * reference_memory, "Memory use regression detected"
        print("[Regression Test] VALID")


if __name__ == "__main__":

    parser = argparse.ArgumentParser(
        description="Benchmark the optimizer state sharding, on a typical computer vision workload"
    )
    # Problem size
    parser.add_argument("--world_size", action="store", default=2, type=int)
    parser.add_argument("--epochs", action="store", default=10, type=int)
    parser.add_argument("--batch_size", action="store", default=32, type=int)
    parser.add_argument("--data_size", action="store", default=512, type=int)
    # Regression thresholds, only enforced for the OSS run
    parser.add_argument("--check_regression", action="store_true", default=False)
    parser.add_argument("--reference_speed", action="store", default=32.32, type=float)
    parser.add_argument("--reference_memory", action="store", default=4475, type=float)
    # beta - test oss_ddp
    parser.add_argument("--oss_ddp", action="store_true", default=False)

    args = parser.parse_args()
    print(f"Benchmark arguments: {args}")

    # Leading positional arguments common to train_oss_ddp and train
    shared_args = (args.world_size, args.epochs, args.batch_size, args.data_size)

    if not args.oss_ddp:
        print("\nBenchmark vanilla optimizer")
        # use_oss=False, check_regression=False
        mp.spawn(train, args=shared_args + (False, False), nprocs=args.world_size, join=True)

        print("\nBenchmark OSS")
        oss_args = shared_args + (True, args.check_regression, args.reference_speed, args.reference_memory)
        mp.spawn(train, args=oss_args, nprocs=args.world_size, join=True)
    else:
        print("\nBenchmark OSS DDP")
        mp.spawn(train_oss_ddp, args=shared_args, nprocs=args.world_size, join=True)