# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.


import argparse
from enum import Enum
import importlib
import logging
import shutil
import tempfile
import time
from typing import Any, List, Optional, cast

from golden_configs import oss_mnist
import numpy as np
import torch
import torch.autograd.profiler as profiler
from torch.cuda.amp import GradScaler as TorchGradScaler
import torch.distributed as dist
import torch.multiprocessing as mp
import torch.nn as nn
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.utils.data import BatchSampler, DataLoader, Sampler
from torch.utils.data.distributed import DistributedSampler
from torchvision.datasets import MNIST
from torchvision.transforms import Compose, Resize, ToTensor

from fairscale.nn.data_parallel import ShardedDataParallel as ShardedDDP
from fairscale.optim import OSS
from fairscale.optim.grad_scaler import ShardedGradScaler

TEMPDIR = tempfile.gettempdir()


def dist_init(rank, world_size, backend):
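    """Initialize the default process group; the rendezvous address is hard-coded to tcp://localhost:29501."""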
    logging.info(f"Using backend: {backend}")
    dist.init_process_group(backend=backend, init_method="tcp://localhost:29501", rank=rank, world_size=world_size)


def get_problem(rank, world_size, batch_size, device, model_name: str):
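    """Build the model, the distributed MNIST dataloader and the loss function used by the benchmark."""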
    # Select the desired model on the fly
    logging.info(f"Using {model_name} for benchmarking")

    # Look the model up in torchvision first, fall back to timm if it is not there
    try:
        model = getattr(importlib.import_module("torchvision.models"), model_name)(pretrained=False).to(device)
    except AttributeError:
        model = getattr(importlib.import_module("timm.models"), model_name)(pretrained=False).to(device)

    # Data setup: duplicate the grey channel to get a pseudo-color input
    def collate(inputs: List[Any]):
        return {
            "inputs": torch.stack([i[0] for i in inputs]).repeat(1, 3, 1, 1).to(device),
            "label": torch.tensor([i[1] for i in inputs]).to(device),
        }

    # Transforms
    transforms = []
    if model_name.startswith("vit"):
        # ViT models expect a fixed input size. Add an ad-hoc transform to resize the pictures accordingly
        pic_size = int(model_name.split("_")[-1])
        transforms.append(Resize(pic_size))

    transforms.append(ToTensor())

    dataset = MNIST(transform=Compose(transforms), download=False, root=TEMPDIR)

    # Each rank only sees its shard of the dataset
    sampler: Sampler = DistributedSampler(dataset, num_replicas=world_size, rank=rank)
    batch_sampler = BatchSampler(sampler, batch_size, drop_last=True)
    dataloader = DataLoader(dataset=dataset, batch_sampler=batch_sampler, collate_fn=collate)

    loss_fn = nn.CrossEntropyLoss()

    return model, dataloader, loss_fn


class OptimType(str, Enum):
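    """The optimizer / data-parallel wrapping combinations which can be benchmarked."""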
    vanilla = "pytorch"
75
76
    oss_ddp = "oss_ddp"
    oss_sharded_ddp = "oss_sharded_ddp"
77
78
79
    everyone = "everyone"


80
def validate_benchmark(measurements, final_loss, args, check_regression):
    """Validate the measurements against the golden benchmark config."""

    golden_data = oss_mnist.get_golden_real_stats()

    max_memory = -1.0
    rank = dist.get_rank()
    if not args.cpu:
        # TODO(anj-s): Check if we need to synchronize before we calculate total training time.
        torch.cuda.synchronize(rank)
        max_memory = torch.cuda.max_memory_allocated(rank) / 2 ** 20
        logging.info(f"[{rank}] : Peak memory {max_memory:.1f}MiB")

    measurements.sort()
    # Compute the median and the median absolute deviation of the per-epoch img/sec measurements
    median = measurements[len(measurements) // 2]
    abs_diff = list(map(lambda x: abs(x - median), measurements))
    abs_diff.sort()
    mad = abs_diff[len(measurements) // 2] if args.epochs > 2 else -1

    # TODO(anj-s): Add a debug flag to perform the above calculation only when required.
    logging.info(f"[{rank}] : Median speed: {median:.2f} +/- {mad:.2f}")

    if check_regression and rank == 0:
        assert median + 3.0 * mad > golden_data["reference_speed"], (
            f"Speed regression detected: {median + 3.0 * mad} vs. {golden_data['reference_speed']}"
        )
        assert max_memory < 1.05 * golden_data["reference_memory"], (
            f"Memory use regression detected: {max_memory} vs. {1.05 * golden_data['reference_memory']}"
        )
        assert abs(cast(float, final_loss) - golden_data["reference_loss"]) < 1e-3, (
            f"Loss regression detected: {final_loss} vs. {golden_data['reference_loss']}"
        )

        logging.info("[Regression Test] VALID")


def train(
    rank: int,
    args: argparse.Namespace,
    backend: str = "gloo",
    optim_type: OptimType = OptimType.vanilla,
    check_regression: bool = True,
):
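    """Benchmark a single configuration on this rank; spawned in one process per rank by torch.multiprocessing."""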
    logging.basicConfig(level=logging.INFO if not args.debug else logging.DEBUG)

    use_multi_tensor = args.multi_tensor_optim and hasattr(torch.optim, "_multi_tensor")
    OPTIM = torch.optim._multi_tensor.RMSprop if use_multi_tensor else torch.optim.RMSprop  # type: ignore  # attr is checked above but mypy misses that
    logging.info("Multi tensor optimizer: {}".format(use_multi_tensor))

    # DDP
    dist_init(rank=rank, world_size=args.world_size, backend=backend)

    # Setup: pin this rank to its device and seed everything for reproducibility
    if not args.cpu:
        torch.cuda.set_device(rank)
        torch.cuda.manual_seed(0)
    torch.manual_seed(0)  # also sets the cuda seed
    np.random.seed(0)

    if backend == "nccl":
        torch.backends.cudnn.deterministic = True
        torch.backends.cudnn.benchmark = False

    device = torch.device("cpu") if args.cpu else torch.device(rank)
    model, dataloader, loss_fn = get_problem(rank, args.world_size, args.batch_size, device, args.model)

    # Shard the optimizer
    optimizer: Optional[torch.optim.Optimizer] = None
    model = cast(nn.Module, model)
    scaler = (TorchGradScaler() if optim_type == OptimType.vanilla else ShardedGradScaler()) if args.amp else None

    if optim_type == OptimType.oss_sharded_ddp:
        # OSS shards the optimizer state across ranks; ShardedDDP reduces each gradient to the rank which owns it
        optimizer = OSS(params=model.parameters(), optim=OPTIM, lr=1e-4, momentum=0.9)
        model = ShardedDDP(model, optimizer)
    else:
        device_ids = None if args.cpu else [rank]
        model = DDP(model, device_ids=device_ids, find_unused_parameters=False)  # type: ignore
        optimizer = (
            OSS(params=model.parameters(), optim=OPTIM, lr=1e-4, momentum=0.9)
            if optim_type == OptimType.oss_ddp
            else OPTIM(model.parameters(), lr=1e-4, momentum=0.9)
        )
    optimizer = cast(torch.optim.Optimizer, optimizer)

    # Reset the memory use counter
    if not args.cpu:
        torch.cuda.empty_cache()
        torch.cuda.reset_peak_memory_stats(rank)
        torch.cuda.synchronize(rank)

    # Standard training loop
    training_start = time.monotonic()
    model.train()

    measurements = []
    final_loss: Optional[float] = -1.0
    need_profiling = args.profile

    for epoch in range(args.epochs):
        n_items = 0
        epoch_runtime = 0.0

        for batch in dataloader:
            if not args.cpu:
                torch.cuda.synchronize(rank)
            batch_start = time.monotonic()

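            # The forward/backward pass is wrapped in a closure so it can either be handed to optimizer.step()
            # or, when AMP is active, evaluated explicitly since GradScaler.step() does not accept a closure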
            def closure(data=batch, grad_scaler=None):
                model.zero_grad()
                if args.debug and rank == 0 and next(model.parameters()).grad is not None:
                    logging.debug(
                        "\nbefore:  param {} -- grad {}".format(
                            next(model.parameters()).norm().item(), next(model.parameters()).grad.norm().item()
                        )
                    )

                if grad_scaler is not None:
                    # Automatically computes the FW pass in half precision
                    with torch.cuda.amp.autocast():
                        outputs = model(data["inputs"])
                        loss = loss_fn(outputs, data["label"])

                        # Accumulates scaled gradients.
                        grad_scaler.scale(loss).backward()
                else:
                    outputs = model(data["inputs"])
                    loss = loss_fn(outputs, data["label"])
                    loss.backward()

                if args.debug and rank == 0 and next(model.parameters()).grad is not None:
                    logging.debug(
                        "after BW: param {} -- grad {}".format(
                            next(model.parameters()).norm().item(), next(model.parameters()).grad.norm().item()
                        )
                    )

                return loss

            def run_closure(closure, scaler, optimizer):
                if scaler is not None:
                    final_loss = closure(grad_scaler=scaler)  # AMP scaler.step does not support closures
                    scaler.step(optimizer)
                    scaler.update()
                    return final_loss
                else:
                    return optimizer.step(closure)

            if need_profiling and not args.cpu:
                logging.info("Profiling the run")
                with profiler.profile(use_cuda=True, record_shapes=True, profile_memory=True) as prof:  # type: ignore
                    with profiler.record_function("batch"):
                        final_loss = run_closure(closure, scaler, optimizer)

                prof.export_chrome_trace(f"{optim_type}_trace_rank_{rank}.json")
                need_profiling = False  # only profile once

            else:
                final_loss = run_closure(closure, scaler, optimizer)

            if args.debug and rank == 0:
                logging.debug("buffer: {}".format(next(model.buffers()).norm().item()))
                logging.debug(
                    "after update: param {} -- grad {}".format(
                        next(model.parameters()).norm().item(), next(model.parameters()).grad.norm().item()
                    )
                )

            n_items += args.batch_size

            if not args.cpu:
                # make sure that the cuda kernels are finished before taking a timestamp
                torch.cuda.synchronize(rank)

            batch_end = time.monotonic()
            epoch_runtime += batch_end - batch_start

        if optim_type == OptimType.oss_ddp or optim_type == OptimType.oss_sharded_ddp:
            # Exercise checkpointing in the OSS case:
            # consolidating the sharded state dict could spill memory
            optimizer = cast(OSS, optimizer)
            optimizer.consolidate_state_dict()
            if dist.get_rank() == 0:
                _ = optimizer.state_dict()
                logging.info("... State dict collected")

        measurements.append(n_items / epoch_runtime)
        if dist.get_rank() == 0:
            logging.info(f"Epoch {epoch} - processed {measurements[-1]:.2f} img per sec. Loss {final_loss:.3f}")

    training_stop = time.monotonic()
    # n_items counts a single epoch, so scale by the number of epochs to get the overall throughput
    img_per_sec = n_items / (training_stop - training_start) * args.epochs
    logging.info(f"[{dist.get_rank()}] : Training done. {img_per_sec:.2f} img per sec inc. checkpoint")

    validate_benchmark(measurements, final_loss, args, check_regression)

    dist.destroy_process_group()  # type: ignore


if __name__ == "__main__":
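    # Example invocation (illustrative only, adjust to your setup):
    #   python oss.py --world_size 2 --epochs 5 --batch_size 256 --optim_type oss_ddp --amp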
    parser = argparse.ArgumentParser(
        description="Benchmark the optimizer state sharding, on a typical computer vision workload"
    )
    parser.add_argument("--world_size", action="store", default=2, type=int)
    parser.add_argument("--epochs", action="store", default=10, type=int)
    parser.add_argument("--batch_size", action="store", default=256, type=int)
    parser.add_argument("--check_regression", action="store_true", default=False)
    parser.add_argument(
        "--optim_type", type=OptimType, choices=[o.value for o in OptimType], default=OptimType.everyone
    )
    parser.add_argument("--gloo", action="store_true", default=False)
    parser.add_argument("--profile", action="store_true", default=False)
    parser.add_argument("--cpu", action="store_true", default=False)
    parser.add_argument("--model", type=str, help="Any torchvision or timm model name (str)", default="resnet101")
    parser.add_argument("--debug", action="store_true", default=False, help="Display additional debug information")
    parser.add_argument("--amp", action="store_true", default=False, help="Activate torch AMP")
    parser.add_argument(
        "--multi_tensor_optim", action="store_true", default=False, help="Use the faster multi-tensor optimizers"
    )

    args = parser.parse_args()

    logging.basicConfig(level=logging.INFO if not args.debug else logging.DEBUG)
    logging.info("Benchmark arguments: %s" % args)

    # Use NCCL whenever CUDA is usable and Gloo was not explicitly requested
    BACKEND = "nccl" if not args.gloo and torch.cuda.is_available() and not args.cpu else "gloo"

    # Download dataset once for all processes
    dataset, tentatives = None, 0
    while dataset is None and tentatives < 5:
        try:
            dataset = MNIST(transform=None, download=True, root=TEMPDIR)
        except (RuntimeError, EOFError) as e:
            if isinstance(e, RuntimeError):
                # Corrupted data, erase and restart
                shutil.rmtree(TEMPDIR + "/MNIST")

            logging.warning("Failed loading dataset: %s" % e)
            tentatives += 1

    if dataset is None:
        logging.error("Could not download MNIST dataset")
        exit(-1)
    else:
        logging.info("Dataset downloaded")

    # Benchmark the different configurations, via multiple processes
    if args.optim_type == OptimType.vanilla or args.optim_type == OptimType.everyone:
        logging.info("\n*** Benchmark vanilla optimizer")
        mp.spawn(
            train,  # type: ignore
            args=(args, BACKEND, OptimType.vanilla, False),  # no regression check
            nprocs=args.world_size,
            join=True,
        )

    if args.optim_type == OptimType.oss_ddp or args.optim_type == OptimType.everyone:
        logging.info("\n*** Benchmark OSS with DDP")
        mp.spawn(
            train,  # type: ignore
            args=(args, BACKEND, OptimType.oss_ddp, args.check_regression),
            nprocs=args.world_size,
            join=True,
        )

    if args.optim_type == OptimType.oss_sharded_ddp or args.optim_type == OptimType.everyone:
        logging.info("\n*** Benchmark OSS with ShardedDDP")
        mp.spawn(
            train,  # type: ignore
            args=(args, BACKEND, OptimType.oss_sharded_ddp, args.check_regression),
            nprocs=args.world_size,
            join=True,
        )