oss.py 14.1 KB
Newer Older
1
2
3
4
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the BSD license found in the
# LICENSE file in the root directory of this source tree.
5
6

import argparse
7
from enum import Enum
8
import importlib
9
10
import logging
import tempfile
11
import time
Benjamin Lefaudeux's avatar
Benjamin Lefaudeux committed
12
from typing import Any, List, Optional, cast
13

14
from golden_configs import oss_mnist
15
import numpy as np
16
import torch
17
import torch.autograd.profiler as profiler
18
from torch.cuda.amp import GradScaler as TorchGradScaler
19
20
21
import torch.distributed as dist
import torch.multiprocessing as mp
import torch.nn as nn
22
from torch.nn.parallel import DistributedDataParallel as DDP
23
24
25
from torch.utils.data import BatchSampler, DataLoader, Sampler
from torch.utils.data.distributed import DistributedSampler
from torchvision.datasets import MNIST
26
from torchvision.transforms import Compose, Resize, ToTensor
27

28
from benchmarks.datasets.mnist import setup_cached_mnist
29
30
from fairscale.nn.data_parallel import ShardedDataParallel as ShardedDDP
from fairscale.optim import OSS
31
from fairscale.optim.grad_scaler import ShardedGradScaler
32

33
# System temporary directory. get_problem() reads MNIST from here with download=False, so the dataset
# must already exist at this location (presumably placed by setup_cached_mnist() in __main__ — confirm).
TEMPDIR: str = tempfile.gettempdir()
34
35


36
def dist_init(rank, world_size, backend, init_method: str = "tcp://localhost:29501"):
    """Join this process to the default ``torch.distributed`` process group.

    Args:
        rank: rank of the calling process within the group.
        world_size: total number of processes in the group.
        backend: ``torch.distributed`` backend name (e.g. "nccl" or "gloo").
        init_method: rendezvous URL handed to ``dist.init_process_group``. The default is the
            single-host TCP address this benchmark has always used. Override it when that port is
            already taken, or to run several benchmarks on one machine at the same time.
    """
    logging.info(f"Using backend: {backend}")
    dist.init_process_group(backend=backend, init_method=init_method, rank=rank, world_size=world_size)
39
40


41
def get_problem(rank, world_size, batch_size, device, model_name: str):
    """Assemble the model, the per-rank MNIST dataloader and the loss for the benchmark.

    Args:
        rank: rank of this process, used to select its shard of the dataset.
        world_size: number of processes sharing the dataset.
        batch_size: samples per batch; an incomplete trailing batch is dropped.
        device: device the model and every collated batch are moved to.
        model_name: name of a model constructor in ``torchvision.models``. ``timm.models`` is
            used instead when torchvision does not provide it.

    Returns:
        ``(model, dataloader, loss_fn)``. Batches are dicts with ``"inputs"`` and ``"label"`` keys.
    """
    logging.info(f"Using {model_name} for benchmarking")

    def build_from(module_path: str):
        # Instantiate `model_name` from the given module, without pretrained weights, on `device`
        return getattr(importlib.import_module(module_path), model_name)(pretrained=False).to(device)

    # Select the desired model on the fly: torchvision first, timm as a fallback
    try:
        model = build_from("torchvision.models")
    except AttributeError:
        model = build_from("timm.models")

    def collate(samples: List[Any]):
        # MNIST is greyscale: repeat the single channel 3 times to get pseudo color inputs
        images = torch.stack([sample[0] for sample in samples])
        labels = torch.tensor([sample[1] for sample in samples])
        return {"inputs": images.repeat(1, 3, 1, 1).to(device), "label": labels.to(device)}

    # Transforms
    steps: List[Any] = []
    if model_name.startswith("vit"):
        # ViT models take a fixed input size, read from the trailing "_<size>" field of the model name
        steps.append(Resize(int(model_name.split("_")[-1])))
    steps.append(ToTensor())
    dataset = MNIST(transform=Compose(steps), download=False, root=TEMPDIR)

    # Each rank iterates over its own shard of the dataset, in fixed-size batches
    sampler: Sampler = DistributedSampler(dataset, num_replicas=world_size, rank=rank)
    dataloader = DataLoader(
        dataset=dataset,
        batch_sampler=BatchSampler(sampler, batch_size, drop_last=True),
        collate_fn=collate,
    )

    return model, dataloader, nn.CrossEntropyLoss()


75
76
class OptimType(str, Enum):
    """Optimizer / data-parallel combinations the benchmark can run.

    The values double as the accepted ``--optim_type`` command line choices.
    """

    vanilla = "pytorch"  # plain torch optimizer, model wrapped in DDP
    oss_ddp = "oss_ddp"  # OSS sharded optimizer, model wrapped in DDP
    oss_sharded_ddp = "oss_sharded_ddp"  # OSS sharded optimizer, model wrapped in ShardedDDP
    everyone = "everyone"  # run every configuration above, one after the other


82
def validate_benchmark(measurements, final_loss, args, check_regression):
    """Validate the measurements against the golden benchmark config.

    Logs the peak CUDA memory (GPU runs only) and the median throughput with its median absolute
    deviation (MAD). When ``check_regression`` is set, rank 0 asserts that speed, memory and loss
    are within tolerance of the values returned by ``oss_mnist.get_golden_real_stats()``.

    Args:
        measurements: per-epoch throughput samples (images per second). The list is not modified.
        final_loss: loss compared against the golden reference loss.
        args: parsed command line arguments; ``args.cpu`` and ``args.epochs`` are read.
        check_regression: whether to enforce the golden thresholds.

    Raises:
        AssertionError: on rank 0, when ``check_regression`` is set and a threshold is exceeded.
    """
    golden_data = oss_mnist.get_golden_real_stats()

    max_memory = -1.0
    rank = dist.get_rank()
    if not args.cpu:
        # TODO(anj-s): Check if we need to synchronize before we calculate total training time.
        torch.cuda.synchronize(rank)
        max_memory = torch.cuda.max_memory_allocated(rank) / 2 ** 20
        logging.info(f"[{rank}] : Peak memory {max_memory:.1f}MiB")

    # Compute the median and median of absolute differences img per second.
    # Work on a sorted copy so that the caller's list is left untouched.
    samples = sorted(measurements)
    middle = len(samples) // 2
    median = samples[middle]
    abs_diff = sorted(abs(x - median) for x in samples)
    # With too few epochs the MAD is not computed; -1 flags that (and makes the speed check stricter)
    mad = abs_diff[middle] if args.epochs > 2 else -1

    # TODO(anj-s): Add a debug flag to perform the above calculation only when required.
    logging.info(f"[{rank}] : Median speed: {median:.2f} +/- {mad:.2f}")

    if check_regression and rank == 0:
        assert median + 8.0 * mad > golden_data["reference_speed"], (
            f"Speed regression detected: " f"{median + 8.0 * mad} vs.  {golden_data['reference_speed']}"
        )
        # max_memory stays at -1.0 on CPU runs, so this check always passes there
        assert max_memory < 1.05 * golden_data["reference_memory"], (
            f"Memory use regression detected: " f"{max_memory} vs. {1.05* golden_data['reference_memory']}"
        )
        # any min_loss < than golden + epsilon is OK.
        assert cast(float, final_loss) - golden_data["reference_loss"] < 1e-2, (
            f"Loss regression detected: " f"{final_loss} vs. {golden_data['reference_loss']}"
        )
        logging.info("[Regression Test] VALID")


119
120
def train(
    rank: int,
    args: argparse.Namespace,
    backend: str = "gloo",
    optim_type: OptimType = OptimType.vanilla,
    check_regression: bool = True,
):
    """Run one benchmark configuration on one rank (this is the target of ``mp.spawn``).

    Sets up the process group and the model/dataloader from ``get_problem``. It then wraps model and
    optimizer according to ``optim_type`` and trains for ``args.epochs`` epochs, timing every batch.
    Finally it hands the per-epoch throughput and the minimum observed loss to ``validate_benchmark``.

    Args:
        rank: index of this process; also used as its CUDA device index.
        args: parsed command line arguments.
        backend: ``torch.distributed`` backend name.
        optim_type: the optimizer / data-parallel combination benchmarked by this run.
        check_regression: forwarded to ``validate_benchmark``.
    """
    logging.basicConfig(level=logging.INFO if not args.debug else logging.DEBUG)

    use_multi_tensor = args.multi_tensor_optim and hasattr(torch.optim, "_multi_tensor")
    OPTIM = torch.optim._multi_tensor.RMSprop if use_multi_tensor else torch.optim.RMSprop  # type: ignore  # attr is  checked but mypy misses that
    logging.info("Multi tensor optimizer: {}".format(use_multi_tensor))

    # DDP
    dist_init(rank=rank, world_size=args.world_size, backend=backend)

    # Setup
    if not args.cpu:
        torch.cuda.set_device(rank)
        torch.cuda.manual_seed(0)
    torch.manual_seed(0)  # also sets the cuda seed
    np.random.seed(0)

    if backend == "nccl":
        torch.backends.cudnn.deterministic = True
        torch.backends.cudnn.benchmark = False

    device = torch.device("cpu") if args.cpu else torch.device(rank)
    model, dataloader, loss_fn = get_problem(rank, args.world_size, args.batch_size, device, args.model)

    # Shard the optimizer
    optimizer: Optional[torch.optim.Optimizer] = None
    model = cast(nn.Module, model)

    # Pick the scaler from the configuration run by *this* process (`optim_type`), not from the CLI
    # selection `args.optim_type`. The latter defaults to "everyone", which would otherwise give the
    # vanilla (non-sharded) optimizer a ShardedGradScaler.
    scaler = (TorchGradScaler() if optim_type == OptimType.vanilla else ShardedGradScaler()) if args.amp else None

    if optim_type == OptimType.oss_sharded_ddp:
        optimizer = OSS(params=model.parameters(), optim=OPTIM, lr=1e-4, momentum=0.9)
        # Single node run typically, no need for reduce buckets
        model = ShardedDDP(model, optimizer, reduce_buffer_size=0)
    else:
        device_ids = None if args.cpu else [rank]
        model = DDP(model, device_ids=device_ids, find_unused_parameters=False)  # type: ignore
        optimizer = (
            OSS(params=model.parameters(), optim=OPTIM, lr=1e-4, momentum=0.9)
            if optim_type == OptimType.oss_ddp
            else OPTIM(model.parameters(), lr=1e-4, momentum=0.9)
        )
    optimizer = cast(torch.optim.Optimizer, optimizer)

    # Reset the memory use counter
    if not args.cpu:
        torch.cuda.empty_cache()
        torch.cuda.reset_peak_memory_stats(rank)
        torch.cuda.synchronize(rank)

    def run_closure(closure, scaler, optimizer):
        """Run one optimization step and return the loss produced by ``closure``."""
        if scaler is not None:
            loss = closure(grad_scaler=scaler)  # AMP scaler.step does not support closures
            scaler.step(optimizer)
            scaler.update()
            return loss
        return optimizer.step(closure)

    # Standard training loop
    training_start = time.monotonic()
    model.train()

    measurements = []
    final_loss: Optional[float] = -1.0
    min_loss = 100.0
    need_profiling = args.profile

    for epoch in range(args.epochs):
        n_items = 0
        epoch_runtime = 0.0

        for batch in dataloader:
            if not args.cpu:
                torch.cuda.synchronize(rank)
            batch_start = time.monotonic()

            def closure(data=batch, grad_scaler=None):
                model.zero_grad()
                if args.debug and rank == 0 and next(model.parameters()).grad is not None:
                    logging.debug(
                        "\nbefore:  param {} -- grad {}".format(
                            next(model.parameters()).norm().item(), next(model.parameters()).grad.norm().item()
                        )
                    )
                if grad_scaler is not None:
                    # Automatically computes the FW pass in half precision
                    with torch.cuda.amp.autocast():
                        outputs = model(data["inputs"])
                        loss = loss_fn(outputs, data["label"])

                        # Accumulates scaled gradients.
                        grad_scaler.scale(loss).backward()
                else:
                    outputs = model(data["inputs"])
                    loss = loss_fn(outputs, data["label"])
                    loss.backward()

                if args.debug and rank == 0 and next(model.parameters()).grad is not None:
                    logging.debug(
                        "after BW: param {} -- grad {}".format(
                            next(model.parameters()).norm().item(), next(model.parameters()).grad.norm().item()
                        )
                    )
                return loss

            if need_profiling and not args.cpu:
                logging.info("Profiling the run")
                with profiler.profile(use_cuda=True, record_shapes=True, profile_memory=True) as prof:  # type: ignore
                    with profiler.record_function("batch"):
                        final_loss = run_closure(closure, scaler, optimizer)

                # Use the enum value: format() of a str-mixin Enum is not the same across Python versions
                prof.export_chrome_trace(f"{OptimType(optim_type).value}_trace_rank_{rank}.json")
                need_profiling = False  # only profile once
            else:
                final_loss = run_closure(closure, scaler, optimizer)

            if args.debug and rank == 0:
                # Not every model registers buffers, and a parameter may have no gradient
                first_buffer = next(model.buffers(), None)
                if first_buffer is not None:
                    logging.debug("buffer: {}".format(first_buffer.norm().item()))
                first_param = next(model.parameters())
                if first_param.grad is not None:
                    logging.debug(
                        "after update: param {} -- grad {}".format(
                            first_param.norm().item(), first_param.grad.norm().item()
                        )
                    )

            n_items += args.batch_size

            if not args.cpu:
                # make sure that the cuda kernels are finished before taking a timestamp
                torch.cuda.synchronize(rank)

            batch_end = time.monotonic()
            epoch_runtime += batch_end - batch_start

        # The step returns the loss as a tensor. Keep a plain Python float so that min_loss, and the
        # value handed to validate_benchmark, are numbers rather than tensors.
        final_loss = float(cast(float, final_loss))

        if optim_type == OptimType.oss_ddp or optim_type == OptimType.oss_sharded_ddp:
            # Check the checkpointing in the case of the OSS optimizer
            # Memory usage could spill over from there
            optimizer = cast(OSS, optimizer)
            optimizer.consolidate_state_dict()
            if dist.get_rank() == 0:
                _ = optimizer.state_dict()
                logging.info("... State dict collected")

        measurements.append(n_items / epoch_runtime)
        min_loss = min(min_loss, final_loss)
        if dist.get_rank() == 0:
            logging.info(
                f"Epoch {epoch} - processed {measurements[-1]:.2f} img per sec. "
                f"Loss {final_loss:.3f} min loss {min_loss:.3f}"
            )

    training_stop = time.monotonic()
    # n_items counts a single epoch (it is reset every epoch). The batch sampler drops incomplete
    # batches, so every epoch sees the same count, hence the multiplication by the number of epochs.
    img_per_sec = n_items / (training_stop - training_start) * args.epochs
    logging.info(f"[{dist.get_rank()}] : Training done. {img_per_sec:.2f} img per sec inc. checkpoint")

    # Use min_loss to check instead of final_loss since the final_loss is a bit random.
    # If the training min_loss reaches certain number, we can be reasonably certain the
    # training process was correct.
    validate_benchmark(measurements, min_loss, args, check_regression)

    dist.destroy_process_group()  # type: ignore

287
288
289
290
291
292
293

if __name__ == "__main__":
    parser = argparse.ArgumentParser(
        description="Benchmark the optimizer state sharding, on a typical computer vision workload"
    )
    parser.add_argument("--world_size", action="store", default=2, type=int)
    parser.add_argument("--epochs", action="store", default=10, type=int)
    parser.add_argument("--batch_size", action="store", default=256, type=int)
    parser.add_argument("--check_regression", action="store_true", default=False)
    parser.add_argument(
        "--optim_type", type=OptimType, choices=[o.value for o in OptimType], default=OptimType.everyone
    )
    parser.add_argument("--gloo", action="store_true", default=False)
    parser.add_argument("--profile", action="store_true", default=False)
    parser.add_argument("--cpu", action="store_true", default=False)
    parser.add_argument("--model", type=str, help="Any torchvision or timm model name (str)", default="resnet101")
    parser.add_argument("--debug", action="store_true", default=False, help="Display additional debug information")
    parser.add_argument("--amp", action="store_true", default=False, help="Activate torch AMP")
    parser.add_argument(
        "--multi_tensor_optim", action="store_true", default=False, help="Use the faster multi-tensor optimizers"
    )

    args = parser.parse_args()

    logging.basicConfig(level=logging.INFO if not args.debug else logging.DEBUG)
    logging.info("Benchmark arguments: %s", args)

    # NCCL needs CUDA. Use it only when a GPU is available and neither --gloo nor --cpu was requested.
    # The previous condition, `(not gloo or not cuda) and not cpu`, picked NCCL precisely when CUDA was
    # missing, even if --gloo had been asked for.
    BACKEND = "nccl" if not args.gloo and torch.cuda.is_available() and not args.cpu else "gloo"

    # Download dataset once for all processes
    setup_cached_mnist()

    # Benchmark the different configurations, via multiple processes.
    # Each entry: (configuration, log banner, whether to check against the golden numbers)
    runs = (
        (OptimType.vanilla, "\n*** Benchmark vanilla optimizer", False),  # no regression check
        (OptimType.oss_ddp, "\n*** Benchmark OSS with DDP", args.check_regression),
        (OptimType.oss_sharded_ddp, "\n*** Benchmark OSS with ShardedDDP", args.check_regression),
    )
    for run_type, banner, run_check_regression in runs:
        if args.optim_type == run_type or args.optim_type == OptimType.everyone:
            logging.info(banner)
            mp.spawn(
                train,  # type: ignore
                args=(args, BACKEND, run_type, run_check_regression),
                nprocs=args.world_size,
                join=True,
            )