# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the BSD license found in the
# LICENSE file in the root directory of this source tree.

import argparse
7
from enum import Enum
8
import importlib
9
10
import logging
import tempfile
11
import time
from typing import Any, List, Optional, cast
13

14
from golden_configs import oss_mnist
15
import numpy as np
16
import torch
17
import torch.autograd.profiler as profiler
18
from torch.cuda.amp import GradScaler as TorchGradScaler
19
20
21
import torch.distributed as dist
import torch.multiprocessing as mp
import torch.nn as nn
22
from torch.nn.parallel import DistributedDataParallel as DDP
23
24
25
from torch.utils.data import BatchSampler, DataLoader, Sampler
from torch.utils.data.distributed import DistributedSampler
from torchvision.datasets import MNIST
26
from torchvision.transforms import Compose, Resize, ToTensor
27

28
from benchmarks.datasets.mnist import setup_cached_mnist
29
30
from fairscale.nn.data_parallel import ShardedDataParallel as ShardedDDP
from fairscale.optim import OSS
31
from fairscale.optim.grad_scaler import ShardedGradScaler
# Root directory the MNIST dataset is loaded from, resolved once at import time
# from the platform's standard temporary-file location.
TEMPDIR: str = tempfile.gettempdir()


def dist_init(rank, world_size, backend, init_method="tcp://localhost:29501"):
    """Join the default torch.distributed process group for this worker.

    Args:
        rank: index of this process within the group.
        world_size: total number of processes in the group.
        backend: torch.distributed backend name (e.g. "nccl" or "gloo").
        init_method: rendezvous URL handed to ``init_process_group``. Defaults to
            the historical fixed TCP endpoint on localhost port 29501; override it
            to avoid port clashes or to use another rendezvous scheme.
    """
    logging.info(f"Using backend: {backend}")
    dist.init_process_group(backend=backend, init_method=init_method, rank=rank, world_size=world_size)


def get_problem(rank, world_size, batch_size, device, model_name: str):
    """Build the model, this worker's MNIST dataloader and the loss function.

    Args:
        rank: this worker's rank; selects its shard through ``DistributedSampler``.
        world_size: number of workers sharing the dataset.
        batch_size: per-worker batch size; an incomplete final batch is dropped.
        device: device the model and every collated batch are moved to.
        model_name: model constructor name, looked up in ``torchvision.models``
            first and in ``timm.models`` if that raises ``AttributeError``.

    Returns:
        ``(model, dataloader, loss_fn)``.
    """
    # Select the desired model on the fly
    logging.info(f"Using {model_name} for benchmarking")

    try:
        model = getattr(importlib.import_module("torchvision.models"), model_name)(pretrained=False).to(device)
    except AttributeError:
        model = getattr(importlib.import_module("timm.models"), model_name)(pretrained=False).to(device)

    def collate(inputs: List[Any]):
        # Data setup: duplicate the grey channel three times to get pseudo color.
        return {
            "inputs": torch.stack([i[0] for i in inputs]).repeat(1, 3, 1, 1).to(device),
            "label": torch.tensor([i[1] for i in inputs]).to(device),
        }

    # Transforms
    transforms = []
    if model_name.startswith("vit"):
        # ViT models are fixed size; the expected picture size is the trailing
        # "_<size>" of the model name. Resize the pictures accordingly.
        pic_size = int(model_name.split("_")[-1])
        transforms.append(Resize(pic_size))

    transforms.append(ToTensor())

    # download=False: the dataset is expected to be present under TEMPDIR already.
    dataset = MNIST(transform=Compose(transforms), download=False, root=TEMPDIR)

    sampler: Sampler = DistributedSampler(dataset, num_replicas=world_size, rank=rank)
    batch_sampler = BatchSampler(sampler, batch_size, drop_last=True)
    dataloader = DataLoader(dataset=dataset, batch_sampler=batch_sampler, collate_fn=collate)

    loss_fn = nn.CrossEntropyLoss()

    return model, dataloader, loss_fn


class OptimType(str, Enum):
    """Optimizer / data-parallel combination benchmarked by a run.

    Mixing in ``str`` makes every member compare equal to its value string,
    which is what the ``--optim_type`` argparse ``choices`` check relies on.
    """

    vanilla = "pytorch"  # plain torch optimizer, model wrapped in torch DDP
    oss_ddp = "oss_ddp"  # fairscale OSS optimizer, model wrapped in torch DDP
    oss_sharded_ddp = "oss_sharded_ddp"  # fairscale OSS optimizer, model wrapped in ShardedDDP
    everyone = "everyone"  # benchmark every configuration above in turn


def validate_benchmark(measurements, final_loss, args, check_regression):
    """Validate the measurements against the golden benchmark config.

    Logs the peak GPU memory (unless ``args.cpu``), then the median and the
    median absolute deviation (MAD) of the per-epoch throughput. When
    ``check_regression`` is set, rank 0 asserts that speed, memory and final
    loss are within tolerance of ``oss_mnist.get_golden_real_stats()``.

    Args:
        measurements: per-epoch throughput samples (images per second). The
            list is not modified.
        final_loss: loss reported by the last training step.
        args: parsed CLI namespace; ``cpu`` and ``epochs`` are read.
        check_regression: whether to compare against the golden numbers.

    Raises:
        AssertionError: if a speed, memory or loss regression is detected.
    """
    max_memory = -1.0
    rank = dist.get_rank()
    if not args.cpu:
        # TODO(anj-s): Check if we need to synchronize before we calculate total training time.
        torch.cuda.synchronize(rank)
        max_memory = torch.cuda.max_memory_allocated(rank) / 2 ** 20
        logging.info(f"[{rank}] : Peak memory {max_memory:.1f}MiB")

    # Compute the median and the median of absolute differences, in img per second.
    # Work on a sorted copy so the caller's list is left untouched.
    ordered = sorted(measurements)
    middle = len(ordered) // 2
    median = ordered[middle]
    abs_diff = sorted(abs(x - median) for x in ordered)
    # The MAD is only meaningful with more than two epochs; -1 flags "not available".
    mad = abs_diff[middle] if args.epochs > 2 else -1

    # TODO(anj-s): Add a debug flag to perform the above calculation only when required.
    logging.info(f"[{rank}] : Median speed: {median:.2f} +/- {mad:.2f}")

    if check_regression and rank == 0:
        golden_data = oss_mnist.get_golden_real_stats()

        # Clamp the sentinel so an unavailable MAD (-1) does not lower the estimate below the median.
        speed_estimate = median + 8.0 * max(mad, 0.0)
        assert speed_estimate > golden_data["reference_speed"], (
            f"Speed regression detected: " f"{speed_estimate} vs.  {golden_data['reference_speed']}"
        )
        assert max_memory < 1.05 * golden_data["reference_memory"], (
            f"Memory use regression detected: " f"{max_memory} vs. {1.05* golden_data['reference_memory']}"
        )
        assert abs(cast(float, final_loss) - golden_data["reference_loss"]) < 1e-2, (
            f"Loss regression detected: " f"{final_loss} vs. {golden_data['reference_loss']}"
        )

        logging.info("[Regression Test] VALID")


def train(
    rank: int,
    args: argparse.Namespace,
    backend: str = "gloo",
    optim_type: OptimType = OptimType.vanilla,
    check_regression: bool = True,
):
    """Run one benchmark worker: train ``args.model`` on MNIST and report throughput.

    Intended to be launched once per rank (the script uses ``mp.spawn``). It:
    - joins the process group;
    - builds the model / optimizer combination selected by ``optim_type``;
    - trains for ``args.epochs`` epochs while timing each batch;
    - hands the per-epoch throughput to ``validate_benchmark``;
    - destroys the process group.

    Args:
        rank: rank of this worker; also used as its CUDA device index.
        args: parsed CLI namespace (world_size, epochs, batch_size, model, cpu,
            amp, debug, profile, multi_tensor_optim).
        backend: torch.distributed backend name.
        optim_type: optimizer / wrapper combination benchmarked by this run.
        check_regression: forwarded to ``validate_benchmark``.
    """
    logging.basicConfig(level=logging.INFO if not args.debug else logging.DEBUG)

    use_multi_tensor = args.multi_tensor_optim and hasattr(torch.optim, "_multi_tensor")
    OPTIM = torch.optim._multi_tensor.RMSprop if use_multi_tensor else torch.optim.RMSprop  # type: ignore  # attr is  checked but mypy misses that
    logging.info("Multi tensor optimizer: {}".format(use_multi_tensor))

    # DDP
    dist_init(rank=rank, world_size=args.world_size, backend=backend)

    # Setup
    if not args.cpu:
        torch.cuda.set_device(rank)
        torch.cuda.manual_seed(0)

    torch.manual_seed(0)  # also sets the cuda seed
    np.random.seed(0)

    if backend == "nccl":
        # Prefer reproducible cuDNN kernels over autotuned ones.
        torch.backends.cudnn.deterministic = True
        torch.backends.cudnn.benchmark = False

    device = torch.device("cpu") if args.cpu else torch.device(rank)
    model, dataloader, loss_fn = get_problem(rank, args.world_size, args.batch_size, device, args.model)

    # Shard the optimizer
    optimizer: Optional[torch.optim.Optimizer] = None
    model = cast(nn.Module, model)

    # Choose the scaler from *this run's* optim_type: args.optim_type can be "everyone",
    # which previously handed a ShardedGradScaler to the vanilla (non-OSS) run.
    scaler = (TorchGradScaler() if optim_type == OptimType.vanilla else ShardedGradScaler()) if args.amp else None

    if optim_type == OptimType.oss_sharded_ddp:
        optimizer = OSS(params=model.parameters(), optim=OPTIM, lr=1e-4, momentum=0.9)

        # Single node run typically, no need for reduce buckets
        model = ShardedDDP(model, optimizer, reduce_buffer_size=0)
    else:
        device_ids = None if args.cpu else [rank]
        model = DDP(model, device_ids=device_ids, find_unused_parameters=False)  # type: ignore
        optimizer = (
            OSS(params=model.parameters(), optim=OPTIM, lr=1e-4, momentum=0.9)
            if optim_type == OptimType.oss_ddp
            else OPTIM(model.parameters(), lr=1e-4, momentum=0.9)
        )
    optimizer = cast(torch.optim.Optimizer, optimizer)

    # Reset the memory use counter
    if not args.cpu:
        torch.cuda.empty_cache()
        torch.cuda.reset_peak_memory_stats(rank)
        torch.cuda.synchronize(rank)

    def run_closure(step_closure, grad_scaler, optim):
        """Run one optimization step and return the loss computed by ``step_closure``."""
        if grad_scaler is not None:
            # AMP scaler.step does not support closures: run the closure by hand.
            loss = step_closure(grad_scaler=grad_scaler)
            grad_scaler.step(optim)
            grad_scaler.update()
            return loss
        return optim.step(step_closure)

    # Standard training loop
    training_start = time.monotonic()
    model.train()

    measurements = []
    final_loss: Optional[float] = -1.0
    need_profiling = args.profile

    for epoch in range(args.epochs):
        n_items = 0
        epoch_runtime = 0.0

        for batch in dataloader:
            if not args.cpu:
                torch.cuda.synchronize(rank)
            batch_start = time.monotonic()

            # `data=batch` binds the current batch at definition time.
            def closure(data=batch, grad_scaler=None):
                model.zero_grad()
                if args.debug and rank == 0 and next(model.parameters()).grad is not None:
                    logging.debug(
                        "\nbefore:  param {} -- grad {}".format(
                            next(model.parameters()).norm().item(), next(model.parameters()).grad.norm().item()
                        )
                    )
                if grad_scaler is not None:
                    # Automatically computes the FW pass in half precision. Only the forward
                    # pass and the loss belong in the autocast region; backward runs outside it.
                    with torch.cuda.amp.autocast():
                        outputs = model(data["inputs"])
                        loss = loss_fn(outputs, data["label"])

                    # Accumulates scaled gradients.
                    grad_scaler.scale(loss).backward()
                else:
                    outputs = model(data["inputs"])
                    loss = loss_fn(outputs, data["label"])
                    loss.backward()

                if args.debug and rank == 0 and next(model.parameters()).grad is not None:
                    logging.debug(
                        "after BW: param {} -- grad {}".format(
                            next(model.parameters()).norm().item(), next(model.parameters()).grad.norm().item()
                        )
                    )
                return loss

            if need_profiling and not args.cpu:
                logging.info("Profiling the run")
                with profiler.profile(use_cuda=True, record_shapes=True, profile_memory=True) as prof:  # type: ignore
                    with profiler.record_function("batch"):
                        final_loss = run_closure(closure, scaler, optimizer)

                prof.export_chrome_trace(f"{optim_type}_trace_rank_{rank}.json")
                need_profiling = False  # only profile once
            else:
                final_loss = run_closure(closure, scaler, optimizer)

            if args.debug and rank == 0:
                logging.debug("buffer: {}".format(next(model.buffers()).norm().item()))
                logging.debug(
                    "after update: param {} -- grad {}".format(
                        next(model.parameters()).norm().item(), next(model.parameters()).grad.norm().item()
                    )
                )

            n_items += args.batch_size

            if not args.cpu:
                # make sure that the cuda kernels are finished before taking a timestamp
                torch.cuda.synchronize(rank)

            batch_end = time.monotonic()
            epoch_runtime += batch_end - batch_start

        if optim_type == OptimType.oss_ddp or optim_type == OptimType.oss_sharded_ddp:
            # Check the checkpointing in the case of the OSS optimizer
            # Memory usage could spill over from there
            optimizer = cast(OSS, optimizer)
            optimizer.consolidate_state_dict()
            if dist.get_rank() == 0:
                _ = optimizer.state_dict()
                logging.info("... State dict collected")

        measurements.append(n_items / epoch_runtime)
        if dist.get_rank() == 0:
            logging.info(f"Epoch {epoch} - processed {measurements[-1]:.2f} img per sec. Loss {final_loss:.3f}")

    training_stop = time.monotonic()
    img_per_sec = n_items / (training_stop - training_start) * args.epochs
    logging.info(f"[{dist.get_rank()}] : Training done. {img_per_sec:.2f} img per sec inc. checkpoint")

    validate_benchmark(measurements, final_loss, args, check_regression)

    dist.destroy_process_group()  # type: ignore

if __name__ == "__main__":
    parser = argparse.ArgumentParser(
        description="Benchmark the optimizer state sharding, on a typical computer vision workload"
    )
    parser.add_argument("--world_size", action="store", default=2, type=int)
    parser.add_argument("--epochs", action="store", default=10, type=int)
    parser.add_argument("--batch_size", action="store", default=256, type=int)
    parser.add_argument("--check_regression", action="store_true", default=False)
    parser.add_argument(
        "--optim_type", type=OptimType, choices=[o.value for o in OptimType], default=OptimType.everyone
    )
    parser.add_argument("--gloo", action="store_true", default=False)
    parser.add_argument("--profile", action="store_true", default=False)
    parser.add_argument("--cpu", action="store_true", default=False)
    parser.add_argument("--model", type=str, help="Any torchvision or timm model name (str)", default="resnet101")
    parser.add_argument("--debug", action="store_true", default=False, help="Display additional debug information")
    parser.add_argument("--amp", action="store_true", default=False, help="Activate torch AMP")
    parser.add_argument(
        "--multi_tensor_optim", action="store_true", default=False, help="Use the faster multi-tensor optimizers"
    )

    args = parser.parse_args()

    logging.basicConfig(level=logging.INFO if not args.debug else logging.DEBUG)
    logging.info("Benchmark arguments: %s", args)

    # Use NCCL only when it can work: gloo not requested, CUDA present, and not a CPU run.
    # (The previous `not gloo or not cuda_available` picked NCCL precisely when CUDA was missing.)
    BACKEND = "nccl" if (not args.gloo and torch.cuda.is_available() and not args.cpu) else "gloo"

    # Download dataset once for all processes
    setup_cached_mnist()

    # Benchmark the different configurations, via multiple processes.
    # Each entry: (configuration, banner, whether to run the regression check).
    benchmark_runs = (
        (OptimType.vanilla, "\n*** Benchmark vanilla optimizer", False),  # no regression check
        (OptimType.oss_ddp, "\n*** Benchmark OSS with DDP", args.check_regression),
        (OptimType.oss_sharded_ddp, "\n*** Benchmark OSS with ShardedDDP", args.check_regression),
    )
    for run_type, banner, run_check_regression in benchmark_runs:
        if args.optim_type == run_type or args.optim_type == OptimType.everyone:
            logging.info(banner)
            mp.spawn(
                train,  # type: ignore
                args=(args, BACKEND, run_type, run_check_regression),
                nprocs=args.world_size,
                join=True,
            )