# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
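"""Benchmark fairscale's optimizer state sharding (OSS / ShardedDDP) against a
vanilla PyTorch optimizer wrapped in DDP, on a torchvision classification model
fed MNIST images (grey channel repeated to fake RGB).

Illustrative invocations (all flags are defined in the argparse section below):

    python oss.py --optim_type oss_ddp --world_size 2 --epochs 5
    python oss.py --optim_type everyone --amp --profile
"""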


import argparse
import importlib
import logging
import math
import shutil
import tempfile
import time
from enum import Enum
from typing import Any, List, Optional, cast

import numpy as np
import torch
import torch.autograd.profiler as profiler
import torch.distributed as dist
import torch.multiprocessing as mp
import torch.nn as nn
from torch.cuda.amp import GradScaler as TorchGradScaler
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.utils.data import BatchSampler, DataLoader, Sampler
from torch.utils.data.distributed import DistributedSampler
from torchvision.datasets import MNIST
from torchvision.transforms import ToTensor

from fairscale.nn.data_parallel import ShardedDataParallel as ShardedDDP
from fairscale.optim import OSS
from fairscale.optim.grad_scaler import ShardedGradScaler

OPTIM = torch.optim.RMSprop
TEMPDIR = tempfile.gettempdir()


def dist_init(rank, world_size, backend):
    logging.info(f"Using backend: {backend}")
    # All ranks rendezvous over TCP on a fixed local port; processes are spawned
    # on a single host by the __main__ section below.
    dist.init_process_group(backend=backend, init_method="tcp://localhost:29501", rank=rank, world_size=world_size)


def get_problem(rank, world_size, batch_size, device, model_name: str):
    # Select the desired model on the fly
    logging.info(f"Using {model_name} for benchmarking")
    model = getattr(importlib.import_module("torchvision.models"), model_name)(pretrained=False).to(device)

    # Data setup, duplicate the grey channels to get pseudo color
    def collate(inputs: List[Any]):
        return {
            "inputs": torch.stack([i[0] for i in inputs]).repeat(1, 3, 1, 1).to(device),
            "label": torch.tensor([i[1] for i in inputs]).to(device),
        }

    dataset = MNIST(transform=ToTensor(), download=False, root=TEMPDIR)
    sampler: Sampler = DistributedSampler(dataset, num_replicas=world_size, rank=rank)
    batch_sampler = BatchSampler(sampler, batch_size, drop_last=True)
    dataloader = DataLoader(dataset=dataset, batch_sampler=batch_sampler, collate_fn=collate)
    # Each batch is a dict {"inputs": float tensor [batch, 3, 28, 28], "label": int64 tensor [batch]},
    # already moved to the target device by the collate function above.

    loss_fn = nn.CrossEntropyLoss()
    return model, dataloader, loss_fn


class OptimType(str, Enum):
    vanilla = "pytorch"
    oss_ddp = "oss_ddp"
    oss_sharded_ddp = "oss_sharded_ddp"
    everyone = "everyone"
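# How each OptimType is wired up in train() below:
#   vanilla         -> DDP-wrapped model + a plain torch.optim.RMSprop instance
#   oss_ddp         -> DDP-wrapped model + fairscale OSS sharding the RMSprop state
#   oss_sharded_ddp -> fairscale ShardedDDP wrapping the model + OSS
#   everyone        -> the __main__ section runs all of the above in sequence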


def train(
    rank: int,
    args: argparse.Namespace,
    backend: str = "gloo",
    optim_type: OptimType = OptimType.vanilla,
    check_regression: bool = True,
):
    logging.basicConfig(level=logging.INFO if not args.debug else logging.DEBUG)

    # DDP
    dist_init(rank=rank, world_size=args.world_size, backend=backend)

    # Setup
    if not args.cpu:
        torch.cuda.set_device(rank)
        torch.cuda.manual_seed(0)
    torch.manual_seed(0)  # also sets the cuda seed
    np.random.seed(0)

    if backend == "nccl":
        torch.backends.cudnn.deterministic = True
        torch.backends.cudnn.benchmark = False

    device = torch.device("cpu") if args.cpu else torch.device(rank)
    model, dataloader, loss_fn = get_problem(rank, args.world_size, args.batch_size, device, args.torchvision_model)

    # Shard the optimizer
    optimizer: Optional[torch.optim.Optimizer] = None
    model = cast(nn.Module, model)
    # Use the sharded grad scaler whenever the optimizer state is sharded (OSS),
    # and the stock torch scaler for the plain PyTorch baseline.
    scaler = (TorchGradScaler() if optim_type == OptimType.vanilla else ShardedGradScaler()) if args.amp else None

    if optim_type == OptimType.oss_sharded_ddp:
        optimizer = OSS(params=model.parameters(), optim=OPTIM, lr=1e-4, momentum=0.9)
        model = ShardedDDP(model, optimizer)
    else:
        device_ids = None if args.cpu else [rank]
        model = DDP(model, device_ids=device_ids, find_unused_parameters=False)  # type: ignore
        optimizer = (
            OSS(params=model.parameters(), optim=OPTIM, lr=1e-4, momentum=0.9)
            if optim_type == OptimType.oss_ddp
            else OPTIM(model.parameters(), lr=1e-4, momentum=0.9)
        )
    optimizer = cast(torch.optim.Optimizer, optimizer)

    # Reset the memory use counter
    if not args.cpu:
        torch.cuda.empty_cache()
        torch.cuda.reset_peak_memory_stats(rank)
        torch.cuda.synchronize(rank)

    # Standard training loop
    training_start = time.monotonic()
    model.train()

    measurements = []
    final_loss: Optional[float] = -1.0
    need_profiling = args.profile

    for epoch in range(args.epochs):
        n_items = 0
        epoch_runtime = 0.0

        for batch in dataloader:
            batch_start = time.monotonic()

            # Closure form expected by Optimizer.step: zero the grads, run the
            # forward and backward passes, and return the loss.
            def closure(data=batch, grad_scaler=None):
                model.zero_grad()
                if args.debug and rank == 0 and next(model.parameters()).grad is not None:
                    logging.debug(
                        "\nbefore:  param {} -- grad {}".format(
                            next(model.parameters()).norm().item(), next(model.parameters()).grad.norm().item()
                        )
                    )
                if grad_scaler is not None:
                    # Automatically computes the FW pass in half precision
                    with torch.cuda.amp.autocast():
                        outputs = model(data["inputs"])
                        loss = loss_fn(outputs, data["label"])

                        # Accumulates scaled gradients.
                        grad_scaler.scale(loss).backward()
                else:
                    outputs = model(data["inputs"])
                    loss = loss_fn(outputs, data["label"])
                    loss.backward()

                if args.debug and rank == 0 and next(model.parameters()).grad is not None:
                    logging.debug(
                        "after BW: param {} -- grad {}".format(
                            next(model.parameters()).norm().item(), next(model.parameters()).grad.norm().item()
                        )
                    )
                return loss

            if need_profiling and not args.cpu:
                logging.info("Profiling the run")
                with profiler.profile(use_cuda=True, record_shapes=True, profile_memory=True) as prof:  # type: ignore
                    with profiler.record_function("batch"):
                        if scaler is not None:
                            final_loss = closure(grad_scaler=scaler)  # AMP scaler.step does not support closures
                            scaler.step(optimizer)
                            scaler.update()
                        else:
                            final_loss = optimizer.step(closure)

                # Export the trace once the profiler context has exited
                prof.export_chrome_trace(f"{optim_type}_trace_rank_{rank}.json")
                need_profiling = False  # only profile once

            else:
                if scaler is not None:
                    final_loss = closure(grad_scaler=scaler)  # AMP scaler.step does not support closures
                    scaler.step(optimizer)
                    scaler.update()
                else:
                    final_loss = optimizer.step(closure)

            if args.debug and rank == 0:
                logging.debug("buffer: {}".format(next(model.buffers()).norm().item()))
                logging.debug(
                    "after update: param {} -- grad {}".format(
                        next(model.parameters()).norm().item(), next(model.parameters()).grad.norm().item()
                    )
                )

            n_items += args.batch_size

            batch_end = time.monotonic()
            epoch_runtime += batch_end - batch_start

        if optim_type == OptimType.oss_ddp or optim_type == OptimType.oss_sharded_ddp:
            # Check the checkpointing in the case of the OSS optimizer
            # Memory usage could spill over from there
            optimizer = cast(OSS, optimizer)
            optimizer.consolidate_state_dict()
            if dist.get_rank() == 0:
                _ = optimizer.state_dict()
                logging.info("... State dict collected")

        measurements.append(n_items / epoch_runtime)
        if dist.get_rank() == 0:
            logging.info(f"Epoch {epoch} - processed {measurements[-1]:.2f} img per sec. Loss {final_loss:.3f}")

    max_memory = -1.0
    if not args.cpu:
        torch.cuda.synchronize(rank)
        max_memory = torch.cuda.max_memory_allocated(rank) / 2 ** 20
        logging.info(f"[{dist.get_rank()}] : Peak memory {max_memory:.1f}MiB")

    training_stop = time.monotonic()
    img_per_sec = n_items / (training_stop - training_start) * args.epochs
    logging.info(f"[{dist.get_rank()}] : Training done. {img_per_sec:.2f} img per sec inc. checkpoint")

    # Compute the mean and standard deviation of the per-epoch throughput
    mean = sum(measurements) / len(measurements)
    diff = map(lambda x: pow(x - mean, 2.0), measurements)
    std = math.sqrt(sum(diff) / (len(measurements) - 1)) if args.epochs > 2 else -1
    logging.info(f"[{dist.get_rank()}] : Mean speed: {mean:.2f} +/- {std:.2f}")

    if check_regression and dist.get_rank() == 0:
        assert (mean + 3.0 * std) > args.reference_speed, "Speed regression detected"
        assert max_memory < 1.05 * args.reference_memory, "Memory use regression detected"
        assert abs(cast(float, final_loss) - args.reference_loss) < 1e-3, "Loss regression detected"

        logging.info("[Regression Test] VALID")

    dist.destroy_process_group()  # type: ignore
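
# The training loop above uses the closure form of `Optimizer.step`: the closure
# zeroes the gradients, runs the forward and backward passes, and returns the loss
# so the optimizer can (re)evaluate it. A minimal standalone sketch of that pattern,
# purely illustrative and not executed by this benchmark:
#
#   model = nn.Linear(4, 2)
#   optim = torch.optim.SGD(model.parameters(), lr=0.1)
#   inputs, labels = torch.randn(8, 4), torch.randint(0, 2, (8,))
#
#   def closure():
#       optim.zero_grad()
#       loss = nn.functional.cross_entropy(model(inputs), labels)
#       loss.backward()
#       return loss
#
#   loss = optim.step(closure)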


if __name__ == "__main__":
    parser = argparse.ArgumentParser(
        description="Benchmark the optimizer state sharding, on a typical computer vision workload"
    )
    parser.add_argument("--world_size", action="store", default=2, type=int)
    parser.add_argument("--epochs", action="store", default=10, type=int)
    parser.add_argument("--batch_size", action="store", default=256, type=int)
    parser.add_argument("--check_regression", action="store_true", default=False)
    parser.add_argument("--reference_speed", action="store", default=1430, type=float)
    parser.add_argument("--reference_memory", action="store", default=1220, type=float)
    parser.add_argument("--reference_loss", action="store", default=0.006, type=float)
    parser.add_argument(
        "--optim_type", type=OptimType, choices=[o.value for o in OptimType], default=OptimType.everyone
    )
    parser.add_argument("--gloo", action="store_true", default=False)
    parser.add_argument("--profile", action="store_true", default=False)
    parser.add_argument("--cpu", action="store_true", default=False)
    parser.add_argument("--torchvision_model", type=str, help="Any torchvision model name (str)", default="resnet101")
    parser.add_argument("--debug", action="store_true", default=False, help="Display additional debug information")
    parser.add_argument("--amp", action="store_true", default=False, help="Activate torch AMP")

    args = parser.parse_args()

    logging.basicConfig(level=logging.INFO if not args.debug else logging.DEBUG)
    logging.info("Benchmark arguments: %s" % args)

    # Use NCCL only when CUDA is available and neither --cpu nor --gloo was requested
    BACKEND = "nccl" if not args.gloo and not args.cpu and torch.cuda.is_available() else "gloo"

    # Download dataset once for all processes
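    # The spawned workers build their dataloaders with download=False (see get_problem),
    # so the parent process fetches MNIST once up front, retrying on corrupted downloads.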
    dataset, tentatives = None, 0
    while dataset is None and tentatives < 5:
        try:
            dataset = MNIST(transform=None, download=True, root=TEMPDIR)
        except (RuntimeError, EOFError) as e:
            if isinstance(e, RuntimeError):
                # Corrupted data, erase and restart
                shutil.rmtree(TEMPDIR + "/MNIST")

            logging.warning("Failed loading dataset: %s " % e)
            tentatives += 1

    if dataset is None:
        logging.error("Could not download MNIST dataset")
        exit(-1)
    else:
        logging.info("Dataset downloaded")

    # Benchmark the different configurations, via multiple processes
    if args.optim_type == OptimType.vanilla or args.optim_type == OptimType.everyone:
        logging.info("\n*** Benchmark vanilla optimizer")
        mp.spawn(
            train,
            args=(args, BACKEND, OptimType.vanilla, False),  # no regression check
            nprocs=args.world_size,
            join=True,
        )

    if args.optim_type == OptimType.oss_ddp or args.optim_type == OptimType.everyone:
        logging.info("\n*** Benchmark OSS with DDP")
        mp.spawn(
            train, args=(args, BACKEND, OptimType.oss_ddp, args.check_regression), nprocs=args.world_size, join=True,
        )

    if args.optim_type == OptimType.oss_sharded_ddp or args.optim_type == OptimType.everyone:
        logging.info("\n*** Benchmark OSS with ShardedDDP")
        mp.spawn(
            train,
            args=(
                args,
                BACKEND,
                OptimType.oss_sharded_ddp,
                False,
            ),  # FIXME: @lefaudeux - SDP should give the same results
            nprocs=args.world_size,
            join=True,
        )