"src/vscode:/vscode.git/clone" did not exist on "ae4112d2bbfa363f2f3049daad54f0fb89c34499"
oss.py 13.7 KB
Newer Older
1
2
3
4
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.


import argparse
from enum import Enum
import importlib
import logging
import shutil
import tempfile
import time
from typing import Any, List, Optional, cast

from golden_configs import oss_mnist
import numpy as np
import torch
import torch.autograd.profiler as profiler
from torch.cuda.amp import GradScaler as TorchGradScaler
import torch.distributed as dist
import torch.multiprocessing as mp
import torch.nn as nn
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.utils.data import BatchSampler, DataLoader, Sampler
from torch.utils.data.distributed import DistributedSampler
from torchvision.datasets import MNIST
from torchvision.transforms import Compose, Resize, ToTensor

from fairscale.nn.data_parallel import ShardedDataParallel as ShardedDDP
from fairscale.optim import OSS
from fairscale.optim.grad_scaler import ShardedGradScaler

OPTIM = torch.optim.RMSprop
TEMPDIR = tempfile.gettempdir()


def dist_init(rank, world_size, backend):
    logging.info(f"Using backend: {backend}")
    dist.init_process_group(backend=backend, init_method="tcp://localhost:29501", rank=rank, world_size=world_size)


def get_problem(rank, world_size, batch_size, device, model_name: str):
    # Select the desired model on the fly
    logging.info(f"Using {model_name} for benchmarking")

    try:
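        # Look the model up in torchvision first, then fall back to timm (useful for ViT variants)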
        model = getattr(importlib.import_module("torchvision.models"), model_name)(pretrained=False).to(device)
    except AttributeError:
        model = getattr(importlib.import_module("timm.models"), model_name)(pretrained=False).to(device)

    # Data setup: repeat the single grey channel to get pseudo-color (3-channel) inputs
    def collate(inputs: List[Any]):
        return {
            "inputs": torch.stack([i[0] for i in inputs]).repeat(1, 3, 1, 1).to(device),
            "label": torch.tensor([i[1] for i in inputs]).to(device),
        }

    # Transforms
    transforms = []
    if model_name.startswith("vit"):
        # ViT models expect a fixed input size. Add an ad-hoc transform to resize the pictures accordingly
        pic_size = int(model_name.split("_")[-1])
        transforms.append(Resize(pic_size))

    transforms.append(ToTensor())

    dataset = MNIST(transform=Compose(transforms), download=False, root=TEMPDIR)
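    # Each rank samples a distinct 1/world_size shard of the dataset; BatchSampler then groups
    # the per-rank indices into fixed-size batches (drop_last keeps batch shapes constant)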
    sampler: Sampler = DistributedSampler(dataset, num_replicas=world_size, rank=rank)
    batch_sampler = BatchSampler(sampler, batch_size, drop_last=True)
    dataloader = DataLoader(dataset=dataset, batch_sampler=batch_sampler, collate_fn=collate)

    loss_fn = nn.CrossEntropyLoss()
    return model, dataloader, loss_fn


class OptimType(str, Enum):
    vanilla = "pytorch"
    oss_ddp = "oss_ddp"
    oss_sharded_ddp = "oss_sharded_ddp"
    everyone = "everyone"


def validate_benchmark(measurements, final_loss, args, check_regression):
    """Validate the measurements and final loss against the golden benchmark config."""

    golden_data = oss_mnist.get_golden_real_stats()

    max_memory = -1.0
    rank = dist.get_rank()
    if not args.cpu:
        # TODO(anj-s): Check if we need to synchronize before we calculate total training time.
        torch.cuda.synchronize(rank)
        max_memory = torch.cuda.max_memory_allocated(rank) / 2 ** 20
        logging.info(f"[{rank}] : Peak memory {max_memory:.1f}MiB")

    # Compute the median and the median absolute deviation (MAD) of the per-epoch img/sec measurements
    measurements.sort()
    median = measurements[len(measurements) // 2]
    abs_diff = list(map(lambda x: abs(x - median), measurements))
    abs_diff.sort()
    mad = abs_diff[len(measurements) // 2] if args.epochs > 2 else -1

    # TODO(anj-s): Add a debug flag to perform the above calculation only when required.
    logging.info(f"[{rank}] : Median speed: {median:.2f} +/- {mad:.2f}")

    if check_regression and rank == 0:
        assert (median + 3.0 * mad) > golden_data["reference_speed"], "Speed regression detected"
        assert max_memory < 1.05 * golden_data["reference_memory"], "Memory use regression detected"
        assert abs(cast(float, final_loss) - golden_data["reference_loss"]) < 1e-3, "Loss regression detected"

        logging.info("[Regression Test] VALID")


def train(
    rank: int,
    args: argparse.Namespace,
    backend: str = "gloo",
    optim_type: OptimType = OptimType.vanilla,
    check_regression: bool = True,
):
    logging.basicConfig(level=logging.INFO if not args.debug else logging.DEBUG)

    # Distributed init
    dist_init(rank=rank, world_size=args.world_size, backend=backend)

    # Setup
    if not args.cpu:
        torch.cuda.set_device(rank)
        torch.cuda.manual_seed(0)
    torch.manual_seed(0)  # also sets the cuda seed
    np.random.seed(0)

    if backend == "nccl":
        torch.backends.cudnn.deterministic = True
        torch.backends.cudnn.benchmark = False

    device = torch.device("cpu") if args.cpu else torch.device(rank)
    model, dataloader, loss_fn = get_problem(rank, args.world_size, args.batch_size, device, args.model)

    # Shard the optimizer
    optimizer: Optional[torch.optim.Optimizer] = None
    model = cast(nn.Module, model)
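    # When the optimizer state is sharded, the AMP scaler has to sync its scale and inf/nan checks
    # across ranks, hence ShardedGradScaler; the vanilla DDP baseline can use the stock GradScaler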
    scaler = (TorchGradScaler() if optim_type == OptimType.vanilla else ShardedGradScaler()) if args.amp else None

    if optim_type == OptimType.oss_sharded_ddp:
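        # OSS shards the optimizer state across ranks (ZeRO-1 style) and ShardedDDP reduces each
        # gradient shard only to the rank owning the matching optimizer shard, instead of
        # all-reducing the full gradients the way DDP does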
        optimizer = OSS(params=model.parameters(), optim=OPTIM, lr=1e-4, momentum=0.9)
        model = ShardedDDP(model, optimizer)
    else:
        device_ids = None if args.cpu else [rank]
        model = DDP(model, device_ids=device_ids, find_unused_parameters=False)  # type: ignore
        optimizer = (
            OSS(params=model.parameters(), optim=OPTIM, lr=1e-4, momentum=0.9)
            if optim_type == OptimType.oss_ddp
            else OPTIM(model.parameters(), lr=1e-4, momentum=0.9)
        )
    optimizer = cast(torch.optim.Optimizer, optimizer)

    # Reset the memory use counter
    if not args.cpu:
        torch.cuda.empty_cache()
        torch.cuda.reset_peak_memory_stats(rank)
        torch.cuda.synchronize(rank)

    # Standard training loop
    training_start = time.monotonic()
    model.train()

    measurements = []
    final_loss: Optional[float] = -1.0
    need_profiling = args.profile

    for epoch in range(args.epochs):
        n_items = 0
        epoch_runtime = 0.0

        for batch in dataloader:
            if not args.cpu:
                torch.cuda.synchronize(rank)
            batch_start = time.monotonic()

            def closure(data=batch, grad_scaler=None):
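                # `data=batch` binds the current batch at definition time, so optimizer.step(closure)
                # re-runs the forward/backward pass on the right batch even though it is called later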
                model.zero_grad()
                if args.debug and rank == 0 and next(model.parameters()).grad is not None:
                    logging.debug(
                        "\nbefore:  param {} -- grad {}".format(
                            next(model.parameters()).norm().item(), next(model.parameters()).grad.norm().item()
                        )
                    )
                if grad_scaler is not None:
                    # Automatically run the forward pass in mixed precision (autocast)
                    with torch.cuda.amp.autocast():
                        outputs = model(data["inputs"])
                        loss = loss_fn(outputs, data["label"])

                        # Accumulates scaled gradients.
                        grad_scaler.scale(loss).backward()
                else:
                    outputs = model(data["inputs"])
                    loss = loss_fn(outputs, data["label"])
                    loss.backward()

                if args.debug and rank == 0 and next(model.parameters()).grad is not None:
                    logging.debug(
                        "after BW: param {} -- grad {}".format(
                            next(model.parameters()).norm().item(), next(model.parameters()).grad.norm().item()
                        )
                    )
                return loss

            def run_closure(closure, scaler, optimizer):
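                # With AMP the closure has to be run by hand: GradScaler.step() does not accept a
                # closure, it unscales the gradients and skips the update if inf/nan are detected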
                if scaler is not None:
                    final_loss = closure(grad_scaler=scaler)  # AMP scaler.step does not support closures
                    scaler.step(optimizer)
                    scaler.update()
                    return final_loss
                else:
                    return optimizer.step(closure)

            if need_profiling and not args.cpu:
                logging.info("Profiling the run")
                with profiler.profile(use_cuda=True, record_shapes=True, profile_memory=True) as prof:  # type: ignore
                    with profiler.record_function("batch"):
                        final_loss = run_closure(closure, scaler, optimizer)

                prof.export_chrome_trace(f"{optim_type}_trace_rank_{rank}.json")
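                # The exported trace can be inspected with chrome://tracing or Perfetto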
                need_profiling = False  # only profile once

            else:
                final_loss = run_closure(closure, scaler, optimizer)

            if args.debug and rank == 0:
                logging.debug("buffer: {}".format(next(model.buffers()).norm().item()))
                logging.debug(
                    "after update: param {} -- grad {}".format(
                        next(model.parameters()).norm().item(), next(model.parameters()).grad.norm().item()
                    )
                )

            n_items += args.batch_size

            if not args.cpu:
                # make sure that the cuda kernels are finished before taking a timestamp
                torch.cuda.synchronize(rank)

            batch_end = time.monotonic()
            epoch_runtime += batch_end - batch_start

        if optim_type == OptimType.oss_ddp or optim_type == OptimType.oss_sharded_ddp:
            # Exercise checkpointing with the OSS optimizer:
            # consolidating the sharded state can make memory use spike
            optimizer = cast(OSS, optimizer)
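            # consolidate_state_dict() gathers the per-rank optimizer shards onto the recipient
            # rank (0 by default) so that a full state_dict() can be materialized there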
            optimizer.consolidate_state_dict()
            if dist.get_rank() == 0:
                _ = optimizer.state_dict()
                logging.info("... State dict collected")

        measurements.append(n_items / epoch_runtime)
        if dist.get_rank() == 0:
            logging.info(f"Epoch {epoch} - processed {measurements[-1]:.2f} img per sec. Loss {final_loss:.3f}")

    training_stop = time.monotonic()
    img_per_sec = n_items / (training_stop - training_start) * args.epochs
    logging.info(f"[{dist.get_rank()}] : Training done. {img_per_sec:.2f} img per sec inc. checkpoint")

    validate_benchmark(measurements, final_loss, args, check_regression)

    dist.destroy_process_group()  # type: ignore


if __name__ == "__main__":
    parser = argparse.ArgumentParser(
        description="Benchmark the optimizer state sharding, on a typical computer vision workload"
    )
    parser.add_argument("--world_size", action="store", default=2, type=int)
    parser.add_argument("--epochs", action="store", default=10, type=int)
    parser.add_argument("--batch_size", action="store", default=256, type=int)
    parser.add_argument("--check_regression", action="store_true", default=False)
    parser.add_argument("--reference_speed", action="store", default=1430, type=float)
    parser.add_argument("--reference_memory", action="store", default=1220, type=float)
    parser.add_argument("--reference_loss", action="store", default=0.006, type=float)
    parser.add_argument(
        "--optim_type", type=OptimType, choices=[o.value for o in OptimType], default=OptimType.everyone
    )
    parser.add_argument("--gloo", action="store_true", default=False)
    parser.add_argument("--profile", action="store_true", default=False)
    parser.add_argument("--cpu", action="store_true", default=False)
    parser.add_argument("--model", type=str, help="Any torchvision or timm model name (str)", default="resnet101")
    parser.add_argument("--debug", action="store_true", default=False, help="Display additional debug information")
    parser.add_argument("--amp", action="store_true", default=False, help="Activate torch AMP")

    args = parser.parse_args()

    logging.basicConfig(level=logging.INFO if not args.debug else logging.DEBUG)
    logging.info("Benchmark arguments: %s" % args)

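    # NCCL only works with CUDA; fall back to Gloo for CPU-only runs or when no GPU is available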
    BACKEND = "nccl" if (not args.gloo or not torch.cuda.is_available()) and not args.cpu else "gloo"

    # Download dataset once for all processes
    dataset, tentatives = None, 0
    while dataset is None and tentatives < 5:
        try:
            dataset = MNIST(transform=None, download=True, root=TEMPDIR)
        except (RuntimeError, EOFError) as e:
            if isinstance(e, RuntimeError):
                # Corrupted data, erase and restart
                shutil.rmtree(TEMPDIR + "/MNIST")

            logging.warning("Failed loading dataset: %s " % e)
            tentatives += 1

    if dataset is None:
        logging.error("Could not download MNIST dataset")
        exit(-1)
    else:
        logging.info("Dataset downloaded")

    # Benchmark the different configurations via multiple processes
    if args.optim_type == OptimType.vanilla or args.optim_type == OptimType.everyone:
        logging.info("\n*** Benchmark vanilla optimizer")
        mp.spawn(
            train,  # type: ignore
            args=(args, BACKEND, OptimType.vanilla, False),  # no regression check
            nprocs=args.world_size,
            join=True,
        )

    if args.optim_type == OptimType.oss_ddp or args.optim_type == OptimType.everyone:
        logging.info("\n*** Benchmark OSS with DDP")
        mp.spawn(
            train, args=(args, BACKEND, OptimType.oss_ddp, args.check_regression), nprocs=args.world_size, join=True,  # type: ignore
        )

    if args.optim_type == OptimType.oss_sharded_ddp or args.optim_type == OptimType.everyone:
        logging.info("\n*** Benchmark OSS with ShardedDDP")
        mp.spawn(
            train,  # type: ignore
            args=(
                args,
                BACKEND,
                OptimType.oss_sharded_ddp,
                False,
            ),  # FIXME: @lefaudeux - SDP should give the same results
            nprocs=args.world_size,
            join=True,
        )