"torchvision/vscode:/vscode.git/clone" did not exist on "8eb6f887933517d8c8e2b18ee00e8e606a07d053"
oss.py 12.8 KB
Newer Older
1
2
3
4
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.


import argparse
5
from enum import Enum
6
import importlib
7
8
9
import logging
import shutil
import tempfile
10
import time
Benjamin Lefaudeux's avatar
Benjamin Lefaudeux committed
11
from typing import Any, List, Optional, cast
12

13
import numpy as np
14
import torch
15
import torch.autograd.profiler as profiler
16
from torch.cuda.amp import GradScaler as TorchGradScaler
17
18
19
import torch.distributed as dist
import torch.multiprocessing as mp
import torch.nn as nn
20
from torch.nn.parallel import DistributedDataParallel as DDP
21
22
23
from torch.utils.data import BatchSampler, DataLoader, Sampler
from torch.utils.data.distributed import DistributedSampler
from torchvision.datasets import MNIST
24
25
from torchvision.transforms import ToTensor

26
27
from fairscale.nn.data_parallel import ShardedDataParallel as ShardedDDP
from fairscale.optim import OSS
28
from fairscale.optim.grad_scaler import ShardedGradScaler
29

30
# Optimizer class used by every benchmarked configuration, either directly or wrapped by OSS.
OPTIM = torch.optim.RMSprop
# Shared folder where MNIST is downloaded once by the launcher and then read by every rank.
TEMPDIR = tempfile.gettempdir()
32
33


34
def dist_init(rank, world_size, backend):
    """Join the default torch.distributed process group.

    Every rank rendezvous over the same fixed localhost TCP address.

    Args:
        rank: rank of the calling process.
        world_size: total number of processes in the group.
        backend: torch.distributed backend name, e.g. "nccl" or "gloo".
    """
    logging.info(f"Using backend: {backend}")
    rendezvous_url = "tcp://localhost:29501"
    dist.init_process_group(rank=rank, world_size=world_size, backend=backend, init_method=rendezvous_url)
37
38


39
def get_problem(rank, world_size, batch_size, device, model_name: str):
    """Build what one rank needs to train: the model, its data pipeline and the loss.

    Args:
        rank: rank of the calling process, used to select its shard of the dataset.
        world_size: number of processes the dataset is split across.
        batch_size: number of samples per batch on this rank.
        device: device that the model and every collated batch are moved to.
        model_name: name of a ``torchvision.models`` constructor, e.g. "resnet101".

    Returns:
        A ``(model, dataloader, loss_fn)`` tuple. The dataloader yields dicts with the
        keys ``"inputs"`` and ``"label"``, already placed on ``device``.
    """
    logging.info(f"Using {model_name} for benchmarking")
    # Resolve the architecture by name at runtime, so any torchvision model can be benchmarked
    models_module = importlib.import_module("torchvision.models")
    build_model = getattr(models_module, model_name)
    model = build_model(pretrained=False).to(device)

    def collate(inputs: List[Any]):
        # Duplicate the grey channel three times to obtain a pseudo-colour input
        images = torch.stack([sample[0] for sample in inputs]).repeat(1, 3, 1, 1).to(device)
        labels = torch.tensor([sample[1] for sample in inputs]).to(device)
        return {"inputs": images, "label": labels}

    # download=False: the dataset must already be present under TEMPDIR
    dataset = MNIST(transform=ToTensor(), download=False, root=TEMPDIR)
    # Each rank only iterates over its own shard; an incomplete last batch is dropped
    sampler: Sampler = DistributedSampler(dataset, num_replicas=world_size, rank=rank)
    batches = BatchSampler(sampler, batch_size, drop_last=True)
    dataloader = DataLoader(dataset=dataset, batch_sampler=batches, collate_fn=collate)

    criterion = nn.CrossEntropyLoss()
    return model, dataloader, criterion


60
61
class OptimType(str, Enum):
    """Optimizer / data-parallel configurations this benchmark knows how to run.

    The member values are also the accepted ``--optim_type`` command-line choices.
    """

    # Plain OPTIM optimizer, model wrapped in DistributedDataParallel
    vanilla = "pytorch"
    # OSS (sharded optimizer state) with DistributedDataParallel
    oss_ddp = "oss_ddp"
    # OSS with fairscale's ShardedDataParallel
    oss_sharded_ddp = "oss_sharded_ddp"
    # Run all of the configurations above, one after the other
    everyone = "everyone"


67
68
def train(
    rank: int,
    args: argparse.Namespace,
    backend: str = "gloo",
    optim_type: OptimType = OptimType.vanilla,
    check_regression: bool = True,
):
    """Benchmark one optimizer configuration on a single rank.

    Designed to be launched once per rank by ``torch.multiprocessing.spawn``. It trains
    ``args.torchvision_model`` for ``args.epochs`` epochs and logs throughput, peak memory
    and loss. On rank 0 it can optionally assert that these did not regress against the
    ``args.reference_*`` values.

    Args:
        rank: rank of this process. Also used as the CUDA device index unless ``args.cpu``.
        args: parsed command-line arguments (see the ``__main__`` section).
        backend: torch.distributed backend used to build the process group.
        optim_type: configuration to benchmark (vanilla, oss_ddp or oss_sharded_ddp).
        check_regression: if True, rank 0 asserts against the reference speed, memory and loss.

    Raises:
        AssertionError: on rank 0, when ``check_regression`` is set and a regression is detected.
    """
    logging.basicConfig(level=logging.INFO if not args.debug else logging.DEBUG)

    # DDP
    dist_init(rank=rank, world_size=args.world_size, backend=backend)

    # Setup
    if not args.cpu:
        torch.cuda.set_device(rank)
        torch.cuda.manual_seed(0)
    torch.manual_seed(0)  # also sets the cuda seed
    np.random.seed(0)

    if backend == "nccl":
        torch.backends.cudnn.deterministic = True
        torch.backends.cudnn.benchmark = False

    device = torch.device("cpu") if args.cpu else torch.device(rank)
    model, dataloader, loss_fn = get_problem(rank, args.world_size, args.batch_size, device, args.torchvision_model)

    # Shard the optimizer
    optimizer: Optional[torch.optim.Optimizer] = None
    model = cast(nn.Module, model)
    # Pick the scaler from the configuration of *this* run. args.optim_type is "everyone" when all
    # configurations are benchmarked in sequence, which used to hand the vanilla run a ShardedGradScaler.
    scaler = (TorchGradScaler() if optim_type == OptimType.vanilla else ShardedGradScaler()) if args.amp else None

    if optim_type == OptimType.oss_sharded_ddp:
        optimizer = OSS(params=model.parameters(), optim=OPTIM, lr=1e-4, momentum=0.9)
        model = ShardedDDP(model, optimizer)
    else:
        device_ids = None if args.cpu else [rank]
        model = DDP(model, device_ids=device_ids, find_unused_parameters=False)  # type: ignore
        optimizer = (
            OSS(params=model.parameters(), optim=OPTIM, lr=1e-4, momentum=0.9)
            if optim_type == OptimType.oss_ddp
            else OPTIM(model.parameters(), lr=1e-4, momentum=0.9)
        )
    optimizer = cast(torch.optim.Optimizer, optimizer)

    # Reset the memory use counter
    if not args.cpu:
        torch.cuda.empty_cache()
        torch.cuda.reset_peak_memory_stats(rank)
        torch.cuda.synchronize(rank)

    def optimizer_step(closure):
        """Run one optimization step, through the AMP scaler when enabled, and return the loss."""
        if scaler is not None:
            # AMP scaler.step does not support closures: run the scaled forward/backward first
            loss = closure(grad_scaler=scaler)
            scaler.step(optimizer)
            scaler.update()
            return loss
        return optimizer.step(closure)

    # Standard training loop
    training_start = time.monotonic()
    model.train()

    measurements = []
    final_loss: Optional[float] = -1.0
    need_profiling = args.profile

    for epoch in range(args.epochs):
        n_items = 0
        epoch_runtime = 0.0

        for batch in dataloader:
            if not args.cpu:
                torch.cuda.synchronize(rank)
            batch__start = time.monotonic()

            # `data=batch` binds the current batch at definition time
            def closure(data=batch, grad_scaler=None):
                model.zero_grad()
                if args.debug and rank == 0 and next(model.parameters()).grad is not None:
                    logging.debug(
                        "\nbefore:  param {} -- grad {}".format(
                            next(model.parameters()).norm().item(), next(model.parameters()).grad.norm().item()
                        )
                    )
                if grad_scaler is not None:
                    # Automatically computes the FW pass in half precision.
                    # Only the forward pass and the loss go under autocast; the backward pass
                    # is kept outside, as recommended by the PyTorch AMP documentation.
                    with torch.cuda.amp.autocast():
                        outputs = model(data["inputs"])
                        loss = loss_fn(outputs, data["label"])

                    # Accumulates scaled gradients.
                    grad_scaler.scale(loss).backward()
                else:
                    outputs = model(data["inputs"])
                    loss = loss_fn(outputs, data["label"])
                    loss.backward()

                if args.debug and rank == 0 and next(model.parameters()).grad is not None:
                    logging.debug(
                        "after BW: param {} -- grad {}".format(
                            next(model.parameters()).norm().item(), next(model.parameters()).grad.norm().item()
                        )
                    )
                return loss

            if need_profiling and not args.cpu:
                logging.info("Profiling the run")
                with profiler.profile(use_cuda=True, record_shapes=True, profile_memory=True) as prof:  # type: ignore
                    with profiler.record_function("batch"):
                        final_loss = optimizer_step(closure)

                # The profiler only holds its events once its context has exited, so the trace must be
                # exported after the `with` block. Use the enum value so the file name does not depend
                # on the Python version's Enum formatting.
                trace_prefix = OptimType(optim_type).value
                prof.export_chrome_trace(f"{trace_prefix}_trace_rank_{rank}.json")
                need_profiling = False  # only profile once
            else:
                final_loss = optimizer_step(closure)

            if args.debug and rank == 0:
                logging.debug("buffer: {}".format(next(model.buffers()).norm().item()))
                logging.debug(
                    "after update: param {} -- grad {}".format(
                        next(model.parameters()).norm().item(), next(model.parameters()).grad.norm().item()
                    )
                )

            n_items += args.batch_size

            if not args.cpu:
                # make sure that the cuda kernels are finished before taking a timestamp
                torch.cuda.synchronize(rank)

            batch_end = time.monotonic()
            epoch_runtime += batch_end - batch__start

        if optim_type == OptimType.oss_ddp or optim_type == OptimType.oss_sharded_ddp:
            # Check the checkpointing in the case of the OSS optimizer
            # Memory usage could spill over from there
            optimizer = cast(OSS, optimizer)
            optimizer.consolidate_state_dict()
            if dist.get_rank() == 0:
                _ = optimizer.state_dict()
                logging.info("... State dict collected")

        measurements.append(n_items / epoch_runtime)
        if dist.get_rank() == 0:
            logging.info(f"Epoch {epoch} - processed {measurements[-1]:.2f} img per sec. Loss {final_loss:.3f}")

    max_memory = -1.0
    if not args.cpu:
        torch.cuda.synchronize(rank)
        max_memory = torch.cuda.max_memory_allocated(rank) / 2 ** 20
        logging.info(f"[{dist.get_rank()}] : Peak memory {max_memory:.1f}MiB")

    training_stop = time.monotonic()
    # n_items holds the last epoch's count; this assumes every epoch processes the same number of items
    img_per_sec = n_items / (training_stop - training_start) * args.epochs
    logging.info(f"[{dist.get_rank()}] : Training done. {img_per_sec:.2f} img per sec inc. checkpoint")

    # Compute the median and median of absolute differences img per second
    measurements.sort()
    median = measurements[len(measurements) // 2]

    abs_diff = sorted(abs(x - median) for x in measurements)
    mad = abs_diff[len(measurements) // 2] if args.epochs > 2 else -1

    logging.info(f"[{dist.get_rank()}] : Median speed: {median:.2f} +/- {mad:.2f}")

    if check_regression and dist.get_rank() == 0:
        assert (median + 3.0 * mad) > args.reference_speed, "Speed regression detected"
        assert max_memory < 1.05 * args.reference_memory, "Memory use regression detected"
        assert abs(cast(float, final_loss) - args.reference_loss) < 1e-3, "Loss regression detected"

        logging.info("[Regression Test] VALID")

    dist.destroy_process_group()  # type: ignore

245
246
247
248
249
250
251

if __name__ == "__main__":
    import os
    import sys

    parser = argparse.ArgumentParser(
        description="Benchmark the optimizer state sharding, on a typical computer vision workload"
    )
    parser.add_argument("--world_size", action="store", default=2, type=int)
    parser.add_argument("--epochs", action="store", default=10, type=int)
    parser.add_argument("--batch_size", action="store", default=256, type=int)
    parser.add_argument("--check_regression", action="store_true", default=False)
    parser.add_argument("--reference_speed", action="store", default=1430, type=float)
    parser.add_argument("--reference_memory", action="store", default=1220, type=float)
    parser.add_argument("--reference_loss", action="store", default=0.006, type=float)
    parser.add_argument(
        "--optim_type", type=OptimType, choices=[o.value for o in OptimType], default=OptimType.everyone
    )
    parser.add_argument("--gloo", action="store_true", default=False)
    parser.add_argument("--profile", action="store_true", default=False)
    parser.add_argument("--cpu", action="store_true", default=False)
    parser.add_argument("--torchvision_model", type=str, help="Any torchvision model name (str)", default="resnet101")
    parser.add_argument("--debug", action="store_true", default=False, help="Display additional debug information")
    parser.add_argument("--amp", action="store_true", default=False, help="Activate torch AMP")

    args = parser.parse_args()

    logging.basicConfig(level=logging.INFO if not args.debug else logging.DEBUG)
    logging.info("Benchmark arguments: %s", args)

    # NCCL requires CUDA: only select it when GPUs are available and neither --gloo nor --cpu was
    # requested. The previous condition picked NCCL precisely when CUDA was *not* available.
    BACKEND = "nccl" if not args.gloo and not args.cpu and torch.cuda.is_available() else "gloo"

    # Download dataset once for all processes
    dataset, tentatives = None, 0
    while dataset is None and tentatives < 5:
        try:
            dataset = MNIST(transform=None, download=True, root=TEMPDIR)
        except (RuntimeError, EOFError) as e:
            if isinstance(e, RuntimeError):
                # Corrupted data, erase and restart. The folder may not exist (e.g. the download
                # itself failed), so a missing folder must not abort the retry loop.
                shutil.rmtree(os.path.join(TEMPDIR, "MNIST"), ignore_errors=True)

            logging.warning("Failed loading dataset: %s ", e)
            tentatives += 1

    if dataset is None:
        logging.error("Could not download MNIST dataset")
        sys.exit(-1)
    else:
        logging.info("Dataset downloaded")

    # Benchmark the different configurations, via multiple processes
    if args.optim_type == OptimType.vanilla or args.optim_type == OptimType.everyone:
        logging.info("\n*** Benchmark vanilla optimizer")
        mp.spawn(
            train,
            args=(args, BACKEND, OptimType.vanilla, False,),  # no regression check
            nprocs=args.world_size,
            join=True,
        )

    if args.optim_type == OptimType.oss_ddp or args.optim_type == OptimType.everyone:
        logging.info("\n*** Benchmark OSS with DDP")
        mp.spawn(
            train, args=(args, BACKEND, OptimType.oss_ddp, args.check_regression), nprocs=args.world_size, join=True,
        )

    if args.optim_type == OptimType.oss_sharded_ddp or args.optim_type == OptimType.everyone:
        logging.info("\n*** Benchmark OSS with ShardedDDP")
        mp.spawn(
            train,
            args=(
                args,
                BACKEND,
                OptimType.oss_sharded_ddp,
                False,
            ),  # FIXME: @lefaudeux - SDP should give the same results
            nprocs=args.world_size,
            join=True,
        )