# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.


import argparse
from enum import Enum
import importlib
import logging
import shutil
import sys
import tempfile
import time
from typing import Any, List, Optional, cast

import numpy as np
import torch
import torch.autograd.profiler as profiler
from torch.cuda.amp import GradScaler as TorchGradScaler
import torch.distributed as dist
import torch.multiprocessing as mp
import torch.nn as nn
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.utils.data import BatchSampler, DataLoader, Sampler
from torch.utils.data.distributed import DistributedSampler
from torchvision.datasets import MNIST
from torchvision.transforms import Compose, Resize, ToTensor

from fairscale.nn.data_parallel import ShardedDataParallel as ShardedDDP
from fairscale.optim import OSS
from fairscale.optim.grad_scaler import ShardedGradScaler

OPTIM = torch.optim.RMSprop
TEMPDIR = tempfile.gettempdir()


def dist_init(rank, world_size, backend):
    logging.info(f"Using backend: {backend}")
    dist.init_process_group(backend=backend, init_method="tcp://localhost:29501", rank=rank, world_size=world_size)
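    # Note: the hard-coded tcp://localhost:29501 rendezvous assumes a single-node run with
    # that port free; rank 0 hosts the store that the other ranks connect to.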


def get_problem(rank, world_size, batch_size, device, model_name: str):
    # Select the desired model on the fly
    logging.info(f"Using {model_name} for benchmarking")

    try:
        model = getattr(importlib.import_module("torchvision.models"), model_name)(pretrained=False).to(device)
    except AttributeError:
        model = getattr(importlib.import_module("timm.models"), model_name)(pretrained=False).to(device)

    # Data setup: duplicate the greyscale channel so the MNIST images look like 3-channel (pseudo-color) inputs
    def collate(inputs: List[Any]):
        return {
            "inputs": torch.stack([i[0] for i in inputs]).repeat(1, 3, 1, 1).to(device),
            "label": torch.tensor([i[1] for i in inputs]).to(device),
        }

    # Transforms
    transforms = []
    if model_name.startswith("vit"):
        # ViT models have a fixed input size. Add an ad-hoc transform to resize the pictures accordingly
        pic_size = int(model_name.split("_")[-1])
        transforms.append(Resize(pic_size))

    transforms.append(ToTensor())

    dataset = MNIST(transform=Compose(transforms), download=False, root=TEMPDIR)
    sampler: Sampler = DistributedSampler(dataset, num_replicas=world_size, rank=rank)
    batch_sampler = BatchSampler(sampler, batch_size, drop_last=True)
    dataloader = DataLoader(dataset=dataset, batch_sampler=batch_sampler, collate_fn=collate)
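    # The DistributedSampler hands each rank a distinct 1/world_size slice of the dataset, and
    # BatchSampler(..., drop_last=True) keeps the number of batches identical across ranks, so
    # the collectives issued by (Sharded)DDP stay in lockstep.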

    loss_fn = nn.CrossEntropyLoss()
    return model, dataloader, loss_fn


class OptimType(str, Enum):
    vanilla = "pytorch"
    oss_ddp = "oss_ddp"
    oss_sharded_ddp = "oss_sharded_ddp"
    everyone = "everyone"
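
    # Rough mapping used below: `vanilla` runs plain DDP with the base optimizer, `oss_ddp`
    # keeps DDP but shards the optimizer state with OSS, `oss_sharded_ddp` pairs OSS with
    # ShardedDDP, and `everyone` benchmarks all three configurations in turn.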


def train(
    rank: int,
    args: argparse.Namespace,
    backend: str = "gloo",
    optim_type: OptimType = OptimType.vanilla,
    check_regression: bool = True,
):
    logging.basicConfig(level=logging.INFO if not args.debug else logging.DEBUG)

    # Initialize the distributed process group (one process per rank)
    dist_init(rank=rank, world_size=args.world_size, backend=backend)

    # Setup
    if not args.cpu:
        torch.cuda.set_device(rank)
        torch.cuda.manual_seed(0)
    torch.manual_seed(0)  # also sets the cuda seed
    np.random.seed(0)

    if backend == "nccl":
        torch.backends.cudnn.deterministic = True
        torch.backends.cudnn.benchmark = False

    device = torch.device("cpu") if args.cpu else torch.device(rank)
    model, dataloader, loss_fn = get_problem(rank, args.world_size, args.batch_size, device, args.model)

    # Set up the (possibly sharded) optimizer and the matching model wrapper
    optimizer: Optional[torch.optim.Optimizer] = None
    model = cast(nn.Module, model)
    scaler = (TorchGradScaler() if optim_type == OptimType.vanilla else ShardedGradScaler()) if args.amp else None
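    # Rough sketch of the AMP plumbing: fairscale's ShardedGradScaler extends the torch GradScaler
    # so that the inf/NaN check driving the loss-scale update is shared across ranks, since with
    # OSS each rank only ever sees the gradients of its own parameter shard.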

    if optim_type == OptimType.oss_sharded_ddp:
        optimizer = OSS(params=model.parameters(), optim=OPTIM, lr=1e-4, momentum=0.9)
        model = ShardedDDP(model, optimizer)
    else:
        device_ids = None if args.cpu else [rank]
        model = DDP(model, device_ids=device_ids, find_unused_parameters=False)  # type: ignore
        optimizer = (
            OSS(params=model.parameters(), optim=OPTIM, lr=1e-4, momentum=0.9)
            if optim_type == OptimType.oss_ddp
            else OPTIM(model.parameters(), lr=1e-4, momentum=0.9)
        )
    optimizer = cast(torch.optim.Optimizer, optimizer)
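    # What the wrappers do, roughly: OSS keeps the full model on every rank but partitions the
    # wrapped optimizer's state (here RMSprop), so each rank stores and steps only its own shard
    # and the updated parameters are broadcast afterwards (ZeRO-1 style). ShardedDDP additionally
    # reduces each gradient only to the rank owning that shard, instead of all-reducing every
    # gradient like plain DDP.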

    # Reset the memory use counter
    if not args.cpu:
        torch.cuda.empty_cache()
        torch.cuda.reset_peak_memory_stats(rank)
        torch.cuda.synchronize(rank)

    # Standard training loop
    training_start = time.monotonic()
    model.train()

    measurements = []
    final_loss: Optional[float] = -1.0
    need_profiling = args.profile

    for epoch in range(args.epochs):
        n_items = 0
        epoch_runtime = 0.0

        for batch in dataloader:
            if not args.cpu:
                torch.cuda.synchronize(rank)
            batch_start = time.monotonic()

            def closure(data=batch, grad_scaler=None):
                model.zero_grad()
                if args.debug and rank == 0 and next(model.parameters()).grad is not None:
                    logging.debug(
                        "\nbefore:  param {} -- grad {}".format(
                            next(model.parameters()).norm().item(), next(model.parameters()).grad.norm().item()
                        )
                    )
                if grad_scaler is not None:
                    # Automatically computes the FW pass in half precision
                    with torch.cuda.amp.autocast():
                        outputs = model(data["inputs"])
                        loss = loss_fn(outputs, data["label"])

                        # Accumulates scaled gradients.
                        grad_scaler.scale(loss).backward()
                else:
                    outputs = model(data["inputs"])
                    loss = loss_fn(outputs, data["label"])
                    loss.backward()

                if args.debug and rank == 0 and next(model.parameters()).grad is not None:
                    logging.debug(
                        "after BW: param {} -- grad {}".format(
                            next(model.parameters()).norm().item(), next(model.parameters()).grad.norm().item()
                        )
                    )
                return loss
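            # Note: the closure re-runs forward + backward and returns the loss, which is the
            # contract `optimizer.step(closure)` expects (for OSS as well as plain optimizers).
            # In the AMP paths below it is called by hand instead, since GradScaler.step()
            # does not accept a closure.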

            if need_profiling and not args.cpu:
                logging.info("Profiling the run")
                with profiler.profile(use_cuda=True, record_shapes=True, profile_memory=True) as prof:  # type: ignore
                    with profiler.record_function("batch"):
                        if scaler is not None:
                            final_loss = closure(grad_scaler=scaler)  # AMP scaler.step does not support closures
                            scaler.step(optimizer)
                            scaler.update()
                        else:
                            final_loss = optimizer.step(closure)

                prof.export_chrome_trace(f"{optim_type}_trace_rank_{rank}.json")
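                # The exported trace (one JSON file per rank, tagged with the optimizer flavour)
                # can be inspected in chrome://tracing or the Perfetto UI.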
                need_profiling = False  # only profile once

            else:
                if scaler is not None:
                    final_loss = closure(grad_scaler=scaler)  # AMP scaler.step does not support closures
                    scaler.step(optimizer)
                    scaler.update()
                else:
                    final_loss = optimizer.step(closure)

            if args.debug and rank == 0:
                logging.debug("buffer: {}".format(next(model.buffers()).norm().item()))
                logging.debug(
                    "after update: param {} -- grad {}".format(
                        next(model.parameters()).norm().item(), next(model.parameters()).grad.norm().item()
                    )
                )

            n_items += args.batch_size

            if not args.cpu:
                # make sure that the cuda kernels are finished before taking a timestamp
                torch.cuda.synchronize(rank)

            batch_end = time.monotonic()
            epoch_runtime += batch_end - batch_start

        if optim_type == OptimType.oss_ddp or optim_type == OptimType.oss_sharded_ddp:
            # Exercise the checkpointing path of the OSS optimizer:
            # gathering the sharded state is where memory use could spill over
            optimizer = cast(OSS, optimizer)
            optimizer.consolidate_state_dict()
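            # Each rank only holds the optimizer state of its own shard, so consolidate_state_dict()
            # first gathers all shards (onto rank 0 by default); only then does state_dict() on that
            # rank return a complete, checkpointable dict.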
            if dist.get_rank() == 0:
                _ = optimizer.state_dict()
                logging.info("... State dict collected")

        measurements.append(n_items / epoch_runtime)
        if dist.get_rank() == 0:
            logging.info(f"Epoch {epoch} - processed {measurements[-1]:.2f} img per sec. Loss {final_loss:.3f}")

    max_memory = -1.0
    if not args.cpu:
        torch.cuda.synchronize(rank)
        max_memory = torch.cuda.max_memory_allocated(rank) / 2 ** 20
        logging.info(f"[{dist.get_rank()}] : Peak memory {max_memory:.1f}MiB")

    training_stop = time.monotonic()
    img_per_sec = n_items * args.epochs / (training_stop - training_start)
    logging.info(f"[{dist.get_rank()}] : Training done. {img_per_sec:.2f} img per sec inc. checkpoint")

    # Compute the median and the median absolute deviation (MAD) of the per-epoch images-per-second measurements
    measurements.sort()
    median = measurements[len(measurements) // 2]

    abs_diff = list(map(lambda x: abs(x - median), measurements))
    abs_diff.sort()
    mad = abs_diff[len(measurements) // 2] if args.epochs > 2 else -1

    logging.info(f"[{dist.get_rank()}] : Median speed: {median:.2f} +/- {mad:.2f}")

    if check_regression and dist.get_rank() == 0:
        assert (median + 3.0 * mad) > args.reference_speed, "Speed regression detected"
        assert max_memory < 1.05 * args.reference_memory, "Memory use regression detected"
        assert abs(cast(float, final_loss) - args.reference_loss) < 1e-3, "Loss regression detected"

        logging.info("[Regression Test] VALID")

    dist.destroy_process_group()  # type: ignore


if __name__ == "__main__":
    parser = argparse.ArgumentParser(
        description="Benchmark the optimizer state sharding on a typical computer vision workload"
    )
    parser.add_argument("--world_size", action="store", default=2, type=int)
    parser.add_argument("--epochs", action="store", default=10, type=int)
    parser.add_argument("--batch_size", action="store", default=256, type=int)
    parser.add_argument("--check_regression", action="store_true", default=False)
    parser.add_argument("--reference_speed", action="store", default=1430, type=float)
    parser.add_argument("--reference_memory", action="store", default=1220, type=float)
    parser.add_argument("--reference_loss", action="store", default=0.006, type=float)
    parser.add_argument(
        "--optim_type", type=OptimType, choices=[o.value for o in OptimType], default=OptimType.everyone
    )
    parser.add_argument("--gloo", action="store_true", default=False)
    parser.add_argument("--profile", action="store_true", default=False)
    parser.add_argument("--cpu", action="store_true", default=False)
    parser.add_argument("--model", type=str, help="Any torchvision or timm model name (str)", default="resnet101")
    parser.add_argument("--debug", action="store_true", default=False, help="Display additional debug information")
    parser.add_argument("--amp", action="store_true", default=False, help="Activate torch AMP")

    args = parser.parse_args()

    logging.basicConfig(level=logging.INFO if not args.debug else logging.DEBUG)
    logging.info("Benchmark arguments: %s" % args)

    BACKEND = "nccl" if not args.gloo and torch.cuda.is_available() and not args.cpu else "gloo"
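    # NCCL is the fastest backend for CUDA tensors; gloo is the fallback for CPU-only runs or
    # when explicitly requested with --gloo.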

    # Download dataset once for all processes
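    # (doing this in the parent process avoids the spawned ranks racing to download and
    # extract the same files under TEMPDIR)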
    dataset, attempts = None, 0
    while dataset is None and attempts < 5:
        try:
            dataset = MNIST(transform=None, download=True, root=TEMPDIR)
        except (RuntimeError, EOFError) as e:
            if isinstance(e, RuntimeError):
                # Corrupted data, erase and restart
                shutil.rmtree(TEMPDIR + "/MNIST")

            logging.warning("Failed loading dataset: %s " % e)
            attempts += 1

    if dataset is None:
        logging.error("Could not download MNIST dataset")
        sys.exit(-1)
    else:
        logging.info("Dataset downloaded")

    # Benchmark the different configurations via multiple processes
    if args.optim_type == OptimType.vanilla or args.optim_type == OptimType.everyone:
        logging.info("\n*** Benchmark vanilla optimizer")
        mp.spawn(
            train,  # type: ignore
            args=(args, BACKEND, OptimType.vanilla, False),  # no regression check
            nprocs=args.world_size,
            join=True,
        )

    if args.optim_type == OptimType.oss_ddp or args.optim_type == OptimType.everyone:
        logging.info("\n*** Benchmark OSS with DDP")
        mp.spawn(
            train, args=(args, BACKEND, OptimType.oss_ddp, args.check_regression), nprocs=args.world_size, join=True,  # type: ignore
        )

    if args.optim_type == OptimType.oss_sharded_ddp or args.optim_type == OptimType.everyone:
        logging.info("\n*** Benchmark OSS with ShardedDDP")
        mp.spawn(
            train,  # type: ignore
            args=(
                args,
                BACKEND,
                OptimType.oss_sharded_ddp,
                False,
            ),  # FIXME: @lefaudeux - SDP should give the same results
            nprocs=args.world_size,
            join=True,
        )
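
# Example invocations (a couple of possibilities, using the flags defined above; the ViT name
# assumes the timm naming convention that get_problem() parses for the input resolution):
#   python oss.py --world_size 2 --epochs 10 --optim_type oss_ddp --check_regression
#   python oss.py --world_size 4 --model vit_base_patch16_224 --amp --optim_type oss_sharded_ddp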