# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.


import argparse
from enum import Enum
import math
import time
from typing import Any, List, Optional, cast

import numpy as np
import torch
import torch.autograd.profiler as profiler
import torch.distributed as dist
import torch.multiprocessing as mp
import torch.nn as nn
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.utils.data import DataLoader
from torchvision.datasets import FakeData
from torchvision.models import resnet101
from torchvision.transforms import ToTensor

from fairscale.nn.data_parallel import ShardedDataParallel
from fairscale.optim.oss import OSS

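# Base optimizer shared by every benchmark variant; OSS / ShardedDataParallel shard its state across ranks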
OPTIM = torch.optim.RMSprop


def dist_init(rank, world_size, backend):
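    """Initialize the default process group for this rank, using a local TCP rendezvous (hard-coded port 29501)."""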
    print(f"Using backend: {backend}")
    dist.init_process_group(backend=backend, init_method="tcp://localhost:29501", rank=rank, world_size=world_size)


def get_problem(rank, data_size, batch_size):
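    """Return a ResNet101 on the target device, a dataloader over synthetic FakeData images, and a cross-entropy loss."""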
    # Standard RN101
    model = resnet101(pretrained=False, progress=True).to(rank)

    # Data setup, dummy data
    def collate(inputs: List[Any]):
        return {
            "inputs": torch.stack([i[0] for i in inputs]).to(torch.device(rank)),
            "label": torch.stack([i[1] for i in inputs]).to(torch.device(rank)),
        }

    dataloader = DataLoader(
        dataset=FakeData(transform=ToTensor(), size=data_size, random_offset=rank),
        batch_size=batch_size,
        collate_fn=collate,
    )
    loss_fn = nn.CrossEntropyLoss()
    return model, dataloader, loss_fn


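# Benchmark variants: plain DDP, DDP + OSS, ShardedDataParallel + OSS, or all of them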
class OptimType(str, Enum):
    vanilla = "pytorch"
    oss = "oss"
    oss_sdp = "oss_sdp"
    everyone = "everyone"


def train(
    rank: int,
    world_size: int,
    num_epochs: int = 10,
    batch_size: int = 32,
    data_size: int = 200,
    backend: str = "gloo",
    optim_type: OptimType = OptimType.vanilla,
    profile: bool = False,
    check_regression: bool = True,
    reference_speed: float = -1.0,
    reference_memory: float = -1.0,
    reference_loss: float = -1.0,
):
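    """Run the benchmark on a single rank: wrap the model with DDP or ShardedDataParallel,
    train on synthetic data for `num_epochs` epochs, report throughput and peak memory,
    and optionally assert that speed, memory and final loss have not regressed."""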
    # Distributed init: join the process group
    dist_init(rank=rank, world_size=world_size, backend=backend)

    # Setup
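    # Pin this process to its GPU and seed every RNG so the benchmark variants are comparable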
    torch.cuda.set_device(rank)
    torch.cuda.manual_seed(0)
    torch.manual_seed(0)  # also sets the cuda seed
    np.random.seed(0)

    if backend == "nccl":
        torch.backends.cudnn.deterministic = True
        torch.backends.cudnn.benchmark = False

    model, dataloader, loss_fn = get_problem(rank, data_size, batch_size)

    # Build the optimizer: sharded (OSS / SDP) or vanilla, depending on optim_type
    optimizer: Optional[torch.optim.Optimizer] = None

    if optim_type == OptimType.oss_sdp:
        ddp = ShardedDataParallel(
            module=model,
            optimizer=OPTIM,
            optimizer_params={"lr": 1e-4, "momentum": 0.9},
            world_size=world_size,
            broadcast_buffers=True,
        )
        ddp.train()
        optimizer = ddp.optimizer
        model = ddp
    else:
        model = DDP(model, device_ids=[rank], find_unused_parameters=True)  # type: ignore
        optimizer = (
            OSS(params=model.parameters(), optim=OPTIM, lr=1e-4, momentum=0.9)
            if optim_type == OptimType.oss
            else OPTIM(model.parameters(), lr=1e-4, momentum=0.9)
        )

    # Reset the memory use counter
    torch.cuda.reset_peak_memory_stats(rank)

    # Dummy training loop
    torch.cuda.synchronize(rank)
    training_start = time.monotonic()
    model.train()

    measurements = []
    final_loss: Optional[float] = -1.0
    need_profiling = profile

    for epoch in range(num_epochs):
        epoch_start = time.monotonic()

        for batch in dataloader:

            def closure():
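                # Forward/backward pass handed to optimizer.step(closure); the loss is divided by the world size before the backward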
                model.zero_grad()
                outputs = model(batch["inputs"])
                loss = loss_fn(outputs, batch["label"])
                loss /= world_size
                loss.backward()

                if optim_type == OptimType.oss_sdp:
                    ddp.reduce()  # Send the gradients to the appropriate shards

                return loss

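            # Profile a single optimizer step per run and dump a Chrome trace from rank 0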
            if need_profiling:
                print("Profiling the run")
                with profiler.profile(use_cuda=True) as prof:  # type: ignore
                    with profiler.record_function("batch"):
                        final_loss = optimizer.step(closure)
                        print("profiling done, final loss ", cast(float, final_loss))

                if rank == 0:
                    prof.export_chrome_trace(f"{optim_type}_trace.json")

                need_profiling = False  # only profile once

            else:
                final_loss = optimizer.step(closure)

        epoch_end = time.monotonic()

        if optim_type == OptimType.oss:
            # Exercise checkpointing with the OSS optimizer: consolidating the sharded state
            # onto rank 0 is where memory usage could spill over
            optimizer = cast(OSS, optimizer)
            optimizer.consolidate_state_dict()
            if dist.get_rank() == 0:
                _ = optimizer.state_dict()
                print("... State dict collected")

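        # Per-epoch throughput, in images per second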
        measurements.append(data_size / (epoch_end - epoch_start))
        if dist.get_rank() == 0:
            print(f"Epoch {epoch} - processed {measurements[-1]:.2f} img per sec. Loss {final_loss:.3f}")

    torch.cuda.synchronize(rank)
    training_stop = time.monotonic()
    img_per_sec = data_size * num_epochs / (training_stop - training_start)
    max_memory = torch.cuda.max_memory_allocated(rank) / 2 ** 20

    print(f"[{dist.get_rank()}] : Training done. {img_per_sec:.2f} img per sec overall")
    print(f"[{dist.get_rank()}] : Peak memory {max_memory:.1f}MiB")

    # Compute the mean and standard deviation of the per-epoch throughput
    mean = sum(measurements) / len(measurements)
    diff = map(lambda x: pow(x - mean, 2.0), measurements)
    std = math.sqrt(sum(diff) / (len(measurements) - 1))
    print(f"[{dist.get_rank()}] : Mean speed: {mean:.2f} +/- {std:.2f}")

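    # Optional regression gate, checked on rank 0 against the reference throughput / memory / loss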
    if check_regression and dist.get_rank() == 0:
        assert (mean + 3.0 * std) > reference_speed, "Speed regression detected"
        assert max_memory < 1.05 * reference_memory, "Memory use regression detected"
        assert abs(cast(float, final_loss) - reference_loss) < 1e-3, "Loss regression detected"

        print("[Regression Test] VALID")

    dist.destroy_process_group()  # type: ignore


if __name__ == "__main__":
    parser = argparse.ArgumentParser(
        description="Benchmark the optimizer state sharding, on a typical computer vision workload"
    )
    parser.add_argument("--world_size", action="store", default=2, type=int)
    parser.add_argument("--epochs", action="store", default=10, type=int)
    parser.add_argument("--batch_size", action="store", default=32, type=int)
    parser.add_argument("--data_size", action="store", default=512, type=int)
    parser.add_argument("--check_regression", action="store_true", default=False)
    parser.add_argument("--reference_speed", action="store", default=29.7, type=float)
    parser.add_argument("--reference_memory", action="store", default=4475, type=float)
    parser.add_argument("--reference_loss", action="store", default=0.866, type=float)
    parser.add_argument(
        "--optim_type", type=OptimType, choices=[o.value for o in OptimType], default=OptimType.everyone
    )
    parser.add_argument("--gloo", action="store_true", default=False)
    parser.add_argument("--profile", action="store_true", default=False)

    args = parser.parse_args()
    print(f"Benchmark arguments: {args}")

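    # NCCL requires CUDA; fall back to Gloo when --gloo is requested or no GPU is available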
    backend = "nccl" if not args.gloo and torch.cuda.is_available() else "gloo"

    if args.optim_type == OptimType.vanilla or args.optim_type == OptimType.everyone:
        print("\nBenchmark vanilla optimizer")
        mp.spawn(
            train,
            args=(
                args.world_size,
                args.epochs,
                args.batch_size,
                args.data_size,
                backend,
                OptimType.vanilla,
                args.profile,
                False,  # no regression check
            ),
            nprocs=args.world_size,
            join=True,
        )

    if args.optim_type == OptimType.oss or args.optim_type == OptimType.everyone:
        print("\nBenchmark OSS with DDP")
        mp.spawn(
            train,
            args=(
                args.world_size,
                args.epochs,
                args.batch_size,
                args.data_size,
                backend,
                OptimType.oss,
                args.profile,
                args.check_regression,
                args.reference_speed,
                args.reference_memory,
                args.reference_loss,
            ),
            nprocs=args.world_size,
            join=True,
        )

    if args.optim_type == OptimType.oss_sdp or args.optim_type == OptimType.everyone:
        print("\nBenchmark OSS with SDP")
        mp.spawn(
            train,
            args=(
                args.world_size,
                args.epochs,
                args.batch_size,
                args.data_size,
                backend,
                OptimType.oss_sdp,
                args.profile,
                False,  # FIXME: @lefaudeux - SDP should give the same results
                -1,  # Not checking SDP for speed regression for now, still slower than OSS
                args.reference_memory,
                args.reference_loss,
            ),
            nprocs=args.world_size,
            join=True,
        )