# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.


import argparse
from enum import Enum
import importlib
import logging
import shutil
import tempfile
import time
from typing import Any, List, Optional, cast

from golden_configs import oss_mnist
import numpy as np
import torch
import torch.autograd.profiler as profiler
from torch.cuda.amp import GradScaler as TorchGradScaler
import torch.distributed as dist
import torch.multiprocessing as mp
import torch.nn as nn
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.utils.data import BatchSampler, DataLoader, Sampler
from torch.utils.data.distributed import DistributedSampler
from torchvision.datasets import MNIST
from torchvision.transforms import Compose, Resize, ToTensor

from fairscale.nn.data_parallel import ShardedDataParallel as ShardedDDP
from fairscale.optim import OSS
from fairscale.optim.grad_scaler import ShardedGradScaler

TEMPDIR = tempfile.gettempdir()

def dist_init(rank, world_size, backend):
    logging.info(f"Using backend: {backend}")
    dist.init_process_group(backend=backend, init_method="tcp://localhost:29501", rank=rank, world_size=world_size)


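# Build the benchmark problem: a torchvision (or timm) model, a rank-sharded MNIST dataloader and the loss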
def get_problem(rank, world_size, batch_size, device, model_name: str):
    # Select the desired model on the fly
    logging.info(f"Using {model_name} for benchmarking")

    # Look the model up in torchvision first, fall back to timm if torchvision does not have it
    try:
        model = getattr(importlib.import_module("torchvision.models"), model_name)(pretrained=False).to(device)
    except AttributeError:
        model = getattr(importlib.import_module("timm.models"), model_name)(pretrained=False).to(device)

    # Data setup: duplicate the grey channel to get pseudo RGB inputs
    def collate(inputs: List[Any]):
        return {
            "inputs": torch.stack([i[0] for i in inputs]).repeat(1, 3, 1, 1).to(device),
            "label": torch.tensor([i[1] for i in inputs]).to(device),
        }

    # Transforms
    transforms = []
    if model_name.startswith("vit"):
        # ViT models are fixed size. Add an ad-hoc transform to resize the pictures accordingly
        pic_size = int(model_name.split("_")[-1])
        transforms.append(Resize(pic_size))

    transforms.append(ToTensor())

    dataset = MNIST(transform=Compose(transforms), download=False, root=TEMPDIR)

    # Shard the dataset across the ranks, then batch it
    sampler: Sampler = DistributedSampler(dataset, num_replicas=world_size, rank=rank)
    batch_sampler = BatchSampler(sampler, batch_size, drop_last=True)
    dataloader = DataLoader(dataset=dataset, batch_sampler=batch_sampler, collate_fn=collate)

    loss_fn = nn.CrossEntropyLoss()
    return model, dataloader, loss_fn


class OptimType(str, Enum):
    vanilla = "pytorch"
    oss_ddp = "oss_ddp"
    oss_sharded_ddp = "oss_sharded_ddp"
    everyone = "everyone"
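
# The configurations compared by this benchmark:
#   - "pytorch":         stock torch optimizer + DistributedDataParallel
#   - "oss_ddp":         fairscale OSS (sharded optimizer state) + DistributedDataParallel
#   - "oss_sharded_ddp": fairscale OSS + ShardedDataParallel (sharded gradient reduction)
#   - "everyone":        run all of the above back to back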


def validate_benchmark(measurements, final_loss, args, check_regression):
    """Validate the measurements against the golden benchmark config."""

    golden_data = oss_mnist.get_golden_real_stats()

    max_memory = -1.0
    rank = dist.get_rank()
    if not args.cpu:
        # TODO(anj-s): Check if we need to synchronize before we calculate total training time.
        torch.cuda.synchronize(rank)
        max_memory = torch.cuda.max_memory_allocated(rank) / 2 ** 20
        logging.info(f"[{rank}] : Peak memory {max_memory:.1f}MiB")

    # Compute the median and the median absolute deviation (MAD) of the images-per-second measurements
    measurements.sort()
    median = measurements[len(measurements) // 2]
    abs_diff = sorted(abs(x - median) for x in measurements)
    mad = abs_diff[len(measurements) // 2] if args.epochs > 2 else -1

    # TODO(anj-s): Add a debug flag to perform the above calculation only when required.
    logging.info(f"[{rank}] : Median speed: {median:.2f} +/- {mad:.2f}")

    if check_regression and rank == 0:
        assert (median + 3.0 * mad) > golden_data["reference_speed"], "Speed regression detected"
        assert max_memory < 1.05 * golden_data["reference_memory"], "Memory use regression detected"
        assert abs(cast(float, final_loss) - golden_data["reference_loss"]) < 1e-3, "Loss regression detected"

        logging.info("[Regression Test] VALID")


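# One benchmark run, spawned once per rank: sets up the requested optimizer/DDP wrapping and times the training loop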
def train(
    rank: int,
    args: argparse.Namespace,
    backend: str = "gloo",
    optim_type: OptimType = OptimType.vanilla,
    check_regression: bool = True,
):
    logging.basicConfig(level=logging.INFO if not args.debug else logging.DEBUG)

    # Optionally use the faster multi-tensor RMSprop, if this torch build exposes it
    use_multi_tensor = args.multi_tensor_optim and hasattr(torch.optim, "_multi_tensor")
    OPTIM = torch.optim._multi_tensor.RMSprop if use_multi_tensor else torch.optim.RMSprop  # type: ignore  # attr is checked above but mypy misses that
    logging.info("Multi tensor optimizer: {}".format(use_multi_tensor))

    # DDP
    dist_init(rank=rank, world_size=args.world_size, backend=backend)

    # Setup: pin this process to its device and seed every RNG for reproducibility
    if not args.cpu:
        torch.cuda.set_device(rank)
        torch.cuda.manual_seed(0)
    torch.manual_seed(0)  # also sets the cuda seed
    np.random.seed(0)

    if backend == "nccl":
        torch.backends.cudnn.deterministic = True
        torch.backends.cudnn.benchmark = False

    device = torch.device("cpu") if args.cpu else torch.device(rank)
    model, dataloader, loss_fn = get_problem(rank, args.world_size, args.batch_size, device, args.model)

    # Shard the optimizer
    optimizer: Optional[torch.optim.Optimizer] = None
    model = cast(nn.Module, model)

    # OSS-based runs need the sharded-aware grad scaler; a plain DDP run can use the stock torch one
    scaler = (TorchGradScaler() if optim_type == OptimType.vanilla else ShardedGradScaler()) if args.amp else None

    if optim_type == OptimType.oss_sharded_ddp:
        optimizer = OSS(params=model.parameters(), optim=OPTIM, lr=1e-4, momentum=0.9)
        model = ShardedDDP(model, optimizer)
    else:
        device_ids = None if args.cpu else [rank]
        model = DDP(model, device_ids=device_ids, find_unused_parameters=False)  # type: ignore
        optimizer = (
            OSS(params=model.parameters(), optim=OPTIM, lr=1e-4, momentum=0.9)
            if optim_type == OptimType.oss_ddp
            else OPTIM(model.parameters(), lr=1e-4, momentum=0.9)
        )
    optimizer = cast(torch.optim.Optimizer, optimizer)
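
    # Note: with OSS, each rank only keeps the optimizer state for its own shard of the parameters;
    # this is where the expected memory savings over the vanilla configuration come from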

    # Reset the memory use counter
    if not args.cpu:
        torch.cuda.empty_cache()
        torch.cuda.reset_peak_memory_stats(rank)
        torch.cuda.synchronize(rank)

    # Standard training loop
    training_start = time.monotonic()
    model.train()

    measurements = []
    final_loss: Optional[float] = -1.0
    need_profiling = args.profile

    for epoch in range(args.epochs):
        n_items = 0
        epoch_runtime = 0.0

        for batch in dataloader:
            if not args.cpu:
                torch.cuda.synchronize(rank)
            batch_start = time.monotonic()

            def closure(data=batch, grad_scaler=None):
                model.zero_grad()
                if args.debug and rank == 0 and next(model.parameters()).grad is not None:
                    logging.debug(
                        "\nbefore:  param {} -- grad {}".format(
                            next(model.parameters()).norm().item(), next(model.parameters()).grad.norm().item()
                        )
                    )
                if grad_scaler is not None:
                    # Automatically computes the FW pass in half precision
                    with torch.cuda.amp.autocast():
                        outputs = model(data["inputs"])
                        loss = loss_fn(outputs, data["label"])

                        # Accumulates scaled gradients.
                        grad_scaler.scale(loss).backward()
                else:
                    outputs = model(data["inputs"])
                    loss = loss_fn(outputs, data["label"])
                    loss.backward()

                if args.debug and rank == 0 and next(model.parameters()).grad is not None:
                    logging.debug(
                        "after BW: param {} -- grad {}".format(
                            next(model.parameters()).norm().item(), next(model.parameters()).grad.norm().item()
                        )
                    )
                return loss

            # One optimizer step, dispatching on whether AMP grad scaling is active
            def run_closure(closure, scaler, optimizer):
                if scaler is not None:
                    final_loss = closure(grad_scaler=scaler)  # AMP scaler.step does not support closures
                    scaler.step(optimizer)
                    scaler.update()
                    return final_loss
                else:
                    return optimizer.step(closure)

            if need_profiling and not args.cpu:
                logging.info("Profiling the run")
                with profiler.profile(use_cuda=True, record_shapes=True, profile_memory=True) as prof:  # type: ignore
                    with profiler.record_function("batch"):
                        final_loss = run_closure(closure, scaler, optimizer)

                prof.export_chrome_trace(f"{optim_type}_trace_rank_{rank}.json")
                need_profiling = False  # only profile once

            else:
                final_loss = run_closure(closure, scaler, optimizer)

            if args.debug and rank == 0:
                logging.debug("buffer: {}".format(next(model.buffers()).norm().item()))
                logging.debug(
                    "after update: param {} -- grad {}".format(
                        next(model.parameters()).norm().item(), next(model.parameters()).grad.norm().item()
                    )
                )

            n_items += args.batch_size

            if not args.cpu:
                # make sure that the cuda kernels are finished before taking a timestamp
                torch.cuda.synchronize(rank)

            batch_end = time.monotonic()
            epoch_runtime += batch_end - batch_start

        if optim_type == OptimType.oss_ddp or optim_type == OptimType.oss_sharded_ddp:
            # Exercise checkpointing with the OSS optimizer:
            # consolidating the sharded state onto rank 0 is where memory usage could spill over
            optimizer = cast(OSS, optimizer)
            optimizer.consolidate_state_dict()
            if dist.get_rank() == 0:
                _ = optimizer.state_dict()
                logging.info("... State dict collected")

        measurements.append(n_items / epoch_runtime)
        if dist.get_rank() == 0:
            logging.info(f"Epoch {epoch} - processed {measurements[-1]:.2f} img per sec. Loss {final_loss:.3f}")

    training_stop = time.monotonic()
    img_per_sec = n_items * args.epochs / (training_stop - training_start)
    logging.info(f"[{dist.get_rank()}] : Training done. {img_per_sec:.2f} img per sec inc. checkpoint")

    validate_benchmark(measurements, final_loss, args, check_regression)

    dist.destroy_process_group()  # type: ignore


if __name__ == "__main__":
    parser = argparse.ArgumentParser(
        description="Benchmark the optimizer state sharding, on a typical computer vision workload"
    )
    parser.add_argument("--world_size", action="store", default=2, type=int)
    parser.add_argument("--epochs", action="store", default=10, type=int)
    parser.add_argument("--batch_size", action="store", default=256, type=int)
    parser.add_argument("--check_regression", action="store_true", default=False)
    parser.add_argument(
        "--optim_type", type=OptimType, choices=[o.value for o in OptimType], default=OptimType.everyone
    )
    parser.add_argument("--gloo", action="store_true", default=False)
    parser.add_argument("--profile", action="store_true", default=False)
    parser.add_argument("--cpu", action="store_true", default=False)
    parser.add_argument("--model", type=str, help="Any torchvision or timm model name (str)", default="resnet101")
    parser.add_argument("--debug", action="store_true", default=False, help="Display additional debug information")
    parser.add_argument("--amp", action="store_true", default=False, help="Activate torch AMP")
    parser.add_argument(
        "--multi_tensor_optim", action="store_true", default=False, help="Use the faster multi-tensor optimizers"
    )

    args = parser.parse_args()

    logging.basicConfig(level=logging.INFO if not args.debug else logging.DEBUG)
    logging.info("Benchmark arguments: %s" % args)

    # NCCL needs an actual GPU run; fall back to gloo when it was requested, when running on CPU, or when CUDA is unavailable
    BACKEND = "nccl" if not args.gloo and not args.cpu and torch.cuda.is_available() else "gloo"

    # Download dataset once for all processes
    dataset, attempts = None, 0
    while dataset is None and attempts < 5:
        try:
            dataset = MNIST(transform=None, download=True, root=TEMPDIR)
        except (RuntimeError, EOFError) as e:
            if isinstance(e, RuntimeError):
                # Corrupted download, erase the local copy and retry
                shutil.rmtree(TEMPDIR + "/MNIST")

            logging.warning("Failed loading dataset: %s" % e)
            attempts += 1

    if dataset is None:
        logging.error("Could not download MNIST dataset")
        exit(-1)
    else:
        logging.info("Dataset downloaded")

    # Benchmark the different configurations, via multiple processes
    if args.optim_type == OptimType.vanilla or args.optim_type == OptimType.everyone:
        logging.info("\n*** Benchmark vanilla optimizer")
        mp.spawn(
            train,  # type: ignore
            args=(args, BACKEND, OptimType.vanilla, False),  # no regression check
            nprocs=args.world_size,
            join=True,
        )

    if args.optim_type == OptimType.oss_ddp or args.optim_type == OptimType.everyone:
        logging.info("\n*** Benchmark OSS with DDP")
        mp.spawn(
            train,  # type: ignore
            args=(args, BACKEND, OptimType.oss_ddp, args.check_regression),
            nprocs=args.world_size,
            join=True,
        )

    if args.optim_type == OptimType.oss_sharded_ddp or args.optim_type == OptimType.everyone:
        logging.info("\n*** Benchmark OSS with ShardedDDP")
        mp.spawn(
            train,  # type: ignore
            args=(args, BACKEND, OptimType.oss_sharded_ddp, args.check_regression),
            nprocs=args.world_size,
            join=True,
        )
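

# Example invocations (assuming a host with at least two GPUs):
#   python oss.py --world_size 2 --epochs 5 --optim_type oss_ddp
#   python oss.py --world_size 2 --optim_type oss_sharded_ddp --amp --check_regression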