# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.

import argparse
from enum import Enum
import importlib
import logging
import tempfile
import time
from typing import Any, List, Optional, cast

from golden_configs import oss_mnist
import numpy as np
import torch
import torch.autograd.profiler as profiler
from torch.cuda.amp import GradScaler as TorchGradScaler
import torch.distributed as dist
import torch.multiprocessing as mp
import torch.nn as nn
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.utils.data import BatchSampler, DataLoader, Sampler
from torch.utils.data.distributed import DistributedSampler
from torchvision.datasets import MNIST
from torchvision.transforms import Compose, Resize, ToTensor

from benchmarks.datasets.mnist import setup_cached_mnist

from fairscale.nn.data_parallel import ShardedDataParallel as ShardedDDP
from fairscale.optim import OSS
from fairscale.optim.grad_scaler import ShardedGradScaler

TEMPDIR = tempfile.gettempdir()


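# Initialize torch.distributed for this process; the TCP init on localhost:29501 assumes a single-node run.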
def dist_init(rank, world_size, backend):
    logging.info(f"Using backend: {backend}")
    dist.init_process_group(backend=backend, init_method="tcp://localhost:29501", rank=rank, world_size=world_size)


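# Build the benchmark problem: the requested torchvision (or timm) model, a rank-sharded MNIST dataloader
# whose grayscale images are repeated to 3 channels, and a cross-entropy loss.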
def get_problem(rank, world_size, batch_size, device, model_name: str):
    # Select the desired model on the fly
    logging.info(f"Using {model_name} for benchmarking")

    try:
        model = getattr(importlib.import_module("torchvision.models"), model_name)(pretrained=False).to(device)
    except AttributeError:
        model = getattr(importlib.import_module("timm.models"), model_name)(pretrained=False).to(device)

    # Data setup: duplicate the grey channel to get a pseudo-color, 3-channel input
    def collate(inputs: List[Any]):
        return {
            "inputs": torch.stack([i[0] for i in inputs]).repeat(1, 3, 1, 1).to(device),
            "label": torch.tensor([i[1] for i in inputs]).to(device),
        }

    # Transforms
    transforms = []
    if model_name.startswith("vit"):
        # ViT models expect a fixed input size. Add an ad-hoc transform to resize the pictures accordingly
        pic_size = int(model_name.split("_")[-1])
        transforms.append(Resize(pic_size))

    transforms.append(ToTensor())

    dataset = MNIST(transform=Compose(transforms), download=False, root=TEMPDIR)

    sampler: Sampler = DistributedSampler(dataset, num_replicas=world_size, rank=rank)
    batch_sampler = BatchSampler(sampler, batch_size, drop_last=True)
    dataloader = DataLoader(dataset=dataset, batch_sampler=batch_sampler, collate_fn=collate)

    loss_fn = nn.CrossEntropyLoss()

    return model, dataloader, loss_fn


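# Optimizer / data-parallel combinations covered by this benchmark ("everyone" runs them all)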
class OptimType(str, Enum):
    vanilla = "pytorch"
    oss_ddp = "oss_ddp"
    oss_sharded_ddp = "oss_sharded_ddp"
    everyone = "everyone"


def validate_benchmark(measurements, final_loss, args, check_regression):
    """Validate the measurements against the golden benchmark config."""

    golden_data = oss_mnist.get_golden_real_stats()

    max_memory = -1.0
    rank = dist.get_rank()
    if not args.cpu:
        # TODO(anj-s): Check if we need to synchronize before we calculate total training time.
        torch.cuda.synchronize(rank)
        max_memory = torch.cuda.max_memory_allocated(rank) / 2 ** 20
        logging.info(f"[{rank}] : Peak memory {max_memory:.1f}MiB")

    measurements.sort()
    median = measurements[len(measurements) // 2]
    # Compute the median absolute deviation (MAD) of the img/sec measurements
    abs_diff = list(map(lambda x: abs(x - median), measurements))
    abs_diff.sort()
    mad = abs_diff[len(measurements) // 2] if args.epochs > 2 else -1

    # TODO(anj-s): Add a debug flag to perform the above calculation only when required.
    logging.info(f"[{rank}] : Median speed: {median:.2f} +/- {mad:.2f}")

    if check_regression and rank == 0:
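        # Compare against the golden stats: speed at most 3 MAD below reference, peak memory at most 5% above,
        # and final loss within 1e-3 of the recorded value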
        assert median + 3.0 * mad > golden_data["reference_speed"], (
            f"Speed regression detected: {median + 3.0 * mad} vs. {golden_data['reference_speed']}"
        )
        assert max_memory < 1.05 * golden_data["reference_memory"], (
            f"Memory use regression detected: {max_memory} vs. {1.05 * golden_data['reference_memory']}"
        )
        assert abs(cast(float, final_loss) - golden_data["reference_loss"]) < 1e-3, (
            f"Loss regression detected: {final_loss} vs. {golden_data['reference_loss']}"
        )
        logging.info("[Regression Test] VALID")


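# Per-process training entry point; spawned once per rank by mp.spawn from the __main__ block below.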
def train(
    rank: int,
    args: argparse.Namespace,
    backend: str = "gloo",
    optim_type: OptimType = OptimType.vanilla,
    check_regression: bool = True,
):
    logging.basicConfig(level=logging.INFO if not args.debug else logging.DEBUG)

    use_multi_tensor = args.multi_tensor_optim and hasattr(torch.optim, "_multi_tensor")
    OPTIM = torch.optim._multi_tensor.RMSprop if use_multi_tensor else torch.optim.RMSprop  # type: ignore  # attr is checked but mypy misses that
    logging.info("Multi tensor optimizer: {}".format(use_multi_tensor))

    # DDP
    dist_init(rank=rank, world_size=args.world_size, backend=backend)

    # Setup
    if not args.cpu:
        torch.cuda.set_device(rank)
        torch.cuda.manual_seed(0)
    torch.manual_seed(0)  # also sets the cuda seed
    np.random.seed(0)

    if backend == "nccl":
        torch.backends.cudnn.deterministic = True
        torch.backends.cudnn.benchmark = False

    device = torch.device("cpu") if args.cpu else torch.device(rank)
    model, dataloader, loss_fn = get_problem(rank, args.world_size, args.batch_size, device, args.model)

    # Shard the optimizer
    optimizer: Optional[torch.optim.Optimizer] = None
    model = cast(nn.Module, model)
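    # AMP: vanilla runs use the stock torch GradScaler, the OSS/ShardedDDP runs use fairscale's ShardedGradScaler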
    scaler = (TorchGradScaler() if args.optim_type == OptimType.vanilla else ShardedGradScaler()) if args.amp else None

    if optim_type == OptimType.oss_sharded_ddp:
        optimizer = OSS(params=model.parameters(), optim=OPTIM, lr=1e-4, momentum=0.9)
        # Typically a single-node run, so no need for reduce buffers
        model = ShardedDDP(model, optimizer, reduce_buffer_size=0)
    else:
        device_ids = None if args.cpu else [rank]
        model = DDP(model, device_ids=device_ids, find_unused_parameters=False)  # type: ignore
        optimizer = (
            OSS(params=model.parameters(), optim=OPTIM, lr=1e-4, momentum=0.9)
            if optim_type == OptimType.oss_ddp
            else OPTIM(model.parameters(), lr=1e-4, momentum=0.9)
        )
    optimizer = cast(torch.optim.Optimizer, optimizer)

    # Reset the memory use counter
    if not args.cpu:
        torch.cuda.empty_cache()
        torch.cuda.reset_peak_memory_stats(rank)
        torch.cuda.synchronize(rank)

    # Standard training loop
    training_start = time.monotonic()
    model.train()

    measurements = []
    final_loss: Optional[float] = -1.0
    need_profiling = args.profile

    for epoch in range(args.epochs):
        n_items = 0
        epoch_runtime = 0.0

        for batch in dataloader:
            if not args.cpu:
                torch.cuda.synchronize(rank)
            batch_start = time.monotonic()

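            # The closure runs the forward and backward passes for this batch and returns the loss;
            # it is either handed to optimizer.step() or, with AMP, called directly (see run_closure below)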
            def closure(data=batch, grad_scaler=None):
                model.zero_grad()
                if args.debug and rank == 0 and next(model.parameters()).grad is not None:
                    logging.debug(
                        "\nbefore:  param {} -- grad {}".format(
                            next(model.parameters()).norm().item(), next(model.parameters()).grad.norm().item()
                        )
                    )
                if grad_scaler is not None:
                    # Automatically run the forward pass in mixed precision (autocast)
                    with torch.cuda.amp.autocast():
                        outputs = model(data["inputs"])
                        loss = loss_fn(outputs, data["label"])

                        # Accumulate scaled gradients
                        grad_scaler.scale(loss).backward()
                else:
                    outputs = model(data["inputs"])
                    loss = loss_fn(outputs, data["label"])
                    loss.backward()

                if args.debug and rank == 0 and next(model.parameters()).grad is not None:
                    logging.debug(
                        "after BW: param {} -- grad {}".format(
                            next(model.parameters()).norm().item(), next(model.parameters()).grad.norm().item()
                        )
                    )
                return loss

            def run_closure(closure, scaler, optimizer):
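                # With AMP, GradScaler.step() does not accept a closure, so the closure is evaluated by hand
                # and the scaler steps the optimizer; otherwise the optimizer drives the closure itself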
                if scaler is not None:
                    final_loss = closure(grad_scaler=scaler)  # AMP scaler.step does not support closures
                    scaler.step(optimizer)
                    scaler.update()
                    return final_loss
                else:
                    return optimizer.step(closure)

            if need_profiling and not args.cpu:
                logging.info("Profiling the run")
                with profiler.profile(use_cuda=True, record_shapes=True, profile_memory=True) as prof:  # type: ignore
                    with profiler.record_function("batch"):
                        final_loss = run_closure(closure, scaler, optimizer)

                prof.export_chrome_trace(f"{optim_type}_trace_rank_{rank}.json")
                need_profiling = False  # only profile once

            else:
                final_loss = run_closure(closure, scaler, optimizer)

            if args.debug and rank == 0:
                logging.debug("buffer: {}".format(next(model.buffers()).norm().item()))
                logging.debug(
                    "after update: param {} -- grad {}".format(
                        next(model.parameters()).norm().item(), next(model.parameters()).grad.norm().item()
                    )
                )

            n_items += args.batch_size

            if not args.cpu:
                # make sure that the cuda kernels are finished before taking a timestamp
                torch.cuda.synchronize(rank)

            batch_end = time.monotonic()
            epoch_runtime += batch_end - batch_start

        if optim_type == OptimType.oss_ddp or optim_type == OptimType.oss_sharded_ddp:
            # Exercise the checkpointing path of the OSS optimizer,
            # since memory usage could spill over from there
            optimizer = cast(OSS, optimizer)
            optimizer.consolidate_state_dict()
            if dist.get_rank() == 0:
                _ = optimizer.state_dict()
                logging.info("... State dict collected")

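        # Per-epoch throughput only counts batch time (epoch_runtime); the state dict consolidation above
        # is excluded here, but it is part of the overall img/sec reported at the end of training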
        measurements.append(n_items / epoch_runtime)
        if dist.get_rank() == 0:
            logging.info(f"Epoch {epoch} - processed {measurements[-1]:.2f} img per sec. Loss {final_loss:.3f}")

    training_stop = time.monotonic()
    img_per_sec = n_items / (training_stop - training_start) * args.epochs
    logging.info(f"[{dist.get_rank()}] : Training done. {img_per_sec:.2f} img per sec inc. checkpoint")

    validate_benchmark(measurements, final_loss, args, check_regression)

    dist.destroy_process_group()  # type: ignore


if __name__ == "__main__":
    parser = argparse.ArgumentParser(
        description="Benchmark the optimizer state sharding, on a typical computer vision workload"
    )
    parser.add_argument("--world_size", action="store", default=2, type=int)
    parser.add_argument("--epochs", action="store", default=10, type=int)
    parser.add_argument("--batch_size", action="store", default=256, type=int)
    parser.add_argument("--check_regression", action="store_true", default=False)
    parser.add_argument(
        "--optim_type", type=OptimType, choices=[o.value for o in OptimType], default=OptimType.everyone
    )
    parser.add_argument("--gloo", action="store_true", default=False)
    parser.add_argument("--profile", action="store_true", default=False)
    parser.add_argument("--cpu", action="store_true", default=False)
    parser.add_argument("--model", type=str, help="Any torchvision or timm model name (str)", default="resnet101")
    parser.add_argument("--debug", action="store_true", default=False, help="Display additional debug information")
    parser.add_argument("--amp", action="store_true", default=False, help="Activate torch AMP")
    parser.add_argument(
        "--multi_tensor_optim", action="store_true", default=False, help="Use the faster multi-tensor optimizers"
    )

    args = parser.parse_args()

    logging.basicConfig(level=logging.INFO if not args.debug else logging.DEBUG)
    logging.info("Benchmark arguments: %s" % args)

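    # NCCL requires CUDA; fall back to gloo when --gloo or --cpu is passed, or when no GPU is available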
    BACKEND = "nccl" if not args.gloo and torch.cuda.is_available() and not args.cpu else "gloo"

    # Download dataset once for all processes
    setup_cached_mnist()

    # Benchmark the different configurations, via multiple processes
    if args.optim_type == OptimType.vanilla or args.optim_type == OptimType.everyone:
        logging.info("\n*** Benchmark vanilla optimizer")
        mp.spawn(
            train,  # type: ignore
            args=(args, BACKEND, OptimType.vanilla, False),  # no regression check
            nprocs=args.world_size,
            join=True,
        )

    if args.optim_type == OptimType.oss_ddp or args.optim_type == OptimType.everyone:
        logging.info("\n*** Benchmark OSS with DDP")
        mp.spawn(
            train,  # type: ignore
            args=(args, BACKEND, OptimType.oss_ddp, args.check_regression),
            nprocs=args.world_size,
            join=True,
        )

    if args.optim_type == OptimType.oss_sharded_ddp or args.optim_type == OptimType.everyone:
        logging.info("\n*** Benchmark OSS with ShardedDDP")
        mp.spawn(
            train,  # type: ignore
            args=(args, BACKEND, OptimType.oss_sharded_ddp, args.check_regression),
            nprocs=args.world_size,
            join=True,
        )
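
# Example invocations (a sketch; the exact module path and PYTHONPATH setup depend on where this benchmark
# lives in the repo, and world_size should match the number of available GPUs):
#   python oss.py --world_size 2 --epochs 5 --optim_type oss_ddp
#   python oss.py --world_size 2 --optim_type oss_sharded_ddp --amp --check_regression
#   python oss.py --cpu --world_size 2 --optim_type pytorch --epochs 1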