# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.


import argparse
from enum import Enum
import importlib
import logging
import math
import shutil
import tempfile
import time
from typing import Any, List, Optional, cast

import numpy as np
import torch
import torch.autograd.profiler as profiler
from torch.cuda.amp import GradScaler as TorchGradScaler
import torch.distributed as dist
import torch.multiprocessing as mp
import torch.nn as nn
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.utils.data import BatchSampler, DataLoader, Sampler
from torch.utils.data.distributed import DistributedSampler
from torchvision.datasets import MNIST
from torchvision.transforms import ToTensor

from fairscale.nn.data_parallel import ShardedDataParallel as ShardedDDP
from fairscale.optim import OSS
from fairscale.optim.grad_scaler import ShardedGradScaler

OPTIM = torch.optim.RMSprop
TEMPDIR = tempfile.gettempdir()


def dist_init(rank, world_size, backend):
    logging.info(f"Using backend: {backend}")
    dist.init_process_group(backend=backend, init_method="tcp://localhost:29501", rank=rank, world_size=world_size)


def get_problem(rank, world_size, batch_size, device, model_name: str):
    # Select the desired model on the fly
    logging.info(f"Using {model_name} for benchmarking")
    model = getattr(importlib.import_module("torchvision.models"), model_name)(pretrained=False).to(device)

    # Data setup, duplicate the grey channels to get pseudo color
    def collate(inputs: List[Any]):
        return {
            "inputs": torch.stack([i[0] for i in inputs]).repeat(1, 3, 1, 1).to(device),
            "label": torch.tensor([i[1] for i in inputs]).to(device),
        }

    dataset = MNIST(transform=ToTensor(), download=False, root=TEMPDIR)
    sampler: Sampler = DistributedSampler(dataset, num_replicas=world_size, rank=rank)
    batch_sampler = BatchSampler(sampler, batch_size, drop_last=True)
    dataloader = DataLoader(dataset=dataset, batch_sampler=batch_sampler, collate_fn=collate)

    loss_fn = nn.CrossEntropyLoss()

    return model, dataloader, loss_fn


class OptimType(str, Enum):
    vanilla = "pytorch"
    oss_ddp = "oss_ddp"
    oss_sharded_ddp = "oss_sharded_ddp"
    everyone = "everyone"


def train(
    rank: int,
    args: argparse.Namespace,
    backend: str = "gloo",
    optim_type: OptimType = OptimType.vanilla,
    check_regression: bool = True,
):
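    """Benchmark a single configuration on this rank: a vanilla PyTorch optimizer under DDP,
    OSS under DDP, or OSS paired with ShardedDDP, depending on ``optim_type``."""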
    logging.basicConfig(level=logging.INFO if not args.debug else logging.DEBUG)

    # DDP
    dist_init(rank=rank, world_size=args.world_size, backend=backend)

    # Setup
    if not args.cpu:
        torch.cuda.set_device(rank)
        torch.cuda.manual_seed(0)
    torch.manual_seed(0)  # also sets the cuda seed
    np.random.seed(0)

    if backend == "nccl":
        torch.backends.cudnn.deterministic = True
        torch.backends.cudnn.benchmark = False

    device = torch.device("cpu") if args.cpu else torch.device(rank)
    model, dataloader, loss_fn = get_problem(rank, args.world_size, args.batch_size, device, args.torchvision_model)

    # Shard the optimizer
    optimizer: Optional[torch.optim.Optimizer] = None
    model = cast(nn.Module, model)
    scaler = (TorchGradScaler() if args.optim_type == OptimType.vanilla else ShardedGradScaler()) if args.amp else None

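    # OSS shards the optimizer state across the participating ranks; ShardedDDP pairs with it and also
    # shards the gradient reduction, whereas plain DDP keeps the usual gradient all-reduce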
    if optim_type == OptimType.oss_sharded_ddp:
        model = ShardedDDP(
            model,
            optimizer=OPTIM,
            optimizer_params={"lr": 1e-4, "momentum": 0.9},
            world_size=args.world_size,
            broadcast_buffers=True,
        )
        optimizer = model.sharded_optimizer
    else:
        if args.cpu:
            device_ids = None
        else:
            device_ids = [rank]
        model = DDP(model, device_ids=device_ids, find_unused_parameters=False)  # type: ignore
        optimizer = (
            OSS(params=model.parameters(), optim=OPTIM, lr=1e-4, momentum=0.9)
            if optim_type == OptimType.oss_ddp
            else OPTIM(model.parameters(), lr=1e-4, momentum=0.9)
        )
    optimizer = cast(torch.optim.Optimizer, optimizer)

    # Reset the memory use counter
    if not args.cpu:
        torch.cuda.reset_peak_memory_stats(rank)
        torch.cuda.synchronize(rank)

    # Standard training loop
    training_start = time.monotonic()
    model.train()

    measurements = []
    final_loss: Optional[float] = -1.0
    need_profiling = args.profile

    for epoch in range(args.epochs):
        n_items = 0
        epoch_runtime = 0.0

        for batch in dataloader:
            batch_start = time.monotonic()

            def closure(data=batch, grad_scaler=None):
                model.zero_grad()
                if args.debug and rank == 0 and next(model.parameters()).grad is not None:
                    logging.debug(
                        "\nbefore:  param {} -- grad {}".format(
                            next(model.parameters()).norm().item(), next(model.parameters()).grad.norm().item()
                        )
                    )
                if grad_scaler is not None:
                    # Automatically computes the FW pass in half precision
                    with torch.cuda.amp.autocast():
                        outputs = model(data["inputs"])
                        loss = loss_fn(outputs, data["label"])

                        # Accumulates scaled gradients.
                        grad_scaler.scale(loss).backward()
                else:
                    outputs = model(data["inputs"])
                    loss = loss_fn(outputs, data["label"])
                    loss.backward()

                if optim_type == OptimType.oss_sharded_ddp:
                    model.reduce()

                if args.debug and rank == 0 and next(model.parameters()).grad is not None:
                    logging.debug(
                        "after BW: param {} -- grad {}".format(
                            next(model.parameters()).norm().item(), next(model.parameters()).grad.norm().item()
                        )
                    )
                return loss

            if need_profiling and not args.cpu:
                logging.info("Profiling the run")
                with profiler.profile(use_cuda=True, record_shapes=True, profile_memory=True) as prof:  # type: ignore
                    with profiler.record_function("batch"):
                        if scaler is not None:
                            final_loss = closure(grad_scaler=scaler)  # AMP scaler.step does not support closures
                            scaler.step(optimizer)
                            scaler.update()
                        else:
                            final_loss = optimizer.step(closure)

                # Export the trace once the profiling context has exited and the events are complete
                prof.export_chrome_trace(f"{optim_type}_trace_rank_{rank}.json")
                need_profiling = False  # only profile once

            else:
                if scaler is not None:
                    final_loss = closure(grad_scaler=scaler)  # AMP scaler.step does not support closures
                    scaler.step(optimizer)
                    scaler.update()
                else:
                    final_loss = optimizer.step(closure)

            if args.debug and rank == 0:
                logging.debug("buffer: {}".format(next(model.buffers()).norm().item()))
                logging.debug(
                    "after update: param {} -- grad {}".format(
                        next(model.parameters()).norm().item(), next(model.parameters()).grad.norm().item()
                    )
                )

            n_items += args.batch_size

            batch_end = time.monotonic()
            epoch_runtime += batch_end - batch_start

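        # OSS only keeps a shard of the optimizer state on each rank, so the state dict has to be
        # consolidated (gathered onto rank 0 by default) before the full checkpoint can be collected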
        if optim_type == OptimType.oss_ddp or optim_type == OptimType.oss_sharded_ddp:
            # Check the checkpointing in the case of the OSS optimizer
            # Memory usage could spill over from there
            optimizer = cast(OSS, optimizer)
            optimizer.consolidate_state_dict()
            if dist.get_rank() == 0:
                _ = optimizer.state_dict()
                logging.info("... State dict collected")

        measurements.append(n_items / epoch_runtime)
        if dist.get_rank() == 0:
            logging.info(f"Epoch {epoch} - processed {measurements[-1]:.2f} img per sec. Loss {final_loss:.3f}")

    max_memory = -1.0
    if not args.cpu:
        torch.cuda.synchronize(rank)
        max_memory = torch.cuda.max_memory_allocated(rank) / 2 ** 20
        logging.info(f"[{dist.get_rank()}] : Peak memory {max_memory:.1f}MiB")

    training_stop = time.monotonic()
    img_per_sec = n_items / (training_stop - training_start) * args.epochs
    logging.info(f"[{dist.get_rank()}] : Training done. {img_per_sec:.2f} img per sec inc. checkpoint")

    # Compute the mean and standard deviation of the measured img per sec
    mean = sum(measurements) / len(measurements)
    diff = map(lambda x: pow(x - mean, 2.0), measurements)
    std = math.sqrt(sum(diff) / (len(measurements) - 1)) if args.epochs > 2 else -1
    logging.info(f"[{dist.get_rank()}] : Mean speed: {mean:.2f} +/- {std:.2f}")

    if check_regression and dist.get_rank() == 0:
        assert (mean + 3.0 * std) > args.reference_speed, "Speed regression detected"
        assert max_memory < 1.05 * args.reference_memory, "Memory use regression detected"
        assert abs(cast(float, final_loss) - args.reference_loss) < 1e-3, "Loss regression detected"

        logging.info("[Regression Test] VALID")

    dist.destroy_process_group()  # type: ignore


if __name__ == "__main__":
    parser = argparse.ArgumentParser(
        description="Benchmark the optimizer state sharding, on a typical computer vision workload"
    )
    parser.add_argument("--world_size", action="store", default=2, type=int)
    parser.add_argument("--epochs", action="store", default=10, type=int)
    parser.add_argument("--batch_size", action="store", default=256, type=int)
    parser.add_argument("--check_regression", action="store_true", default=False)
    parser.add_argument("--reference_speed", action="store", default=1430, type=float)
    parser.add_argument("--reference_memory", action="store", default=1220, type=float)
    parser.add_argument("--reference_loss", action="store", default=0.006, type=float)
    parser.add_argument(
        "--optim_type", type=OptimType, choices=[o.value for o in OptimType], default=OptimType.everyone
    )
    parser.add_argument("--gloo", action="store_true", default=False)
    parser.add_argument("--profile", action="store_true", default=False)
    parser.add_argument("--cpu", action="store_true", default=False)
    parser.add_argument("--torchvision_model", type=str, help="Any torchvision model name (str)", default="resnet101")
    parser.add_argument("--debug", action="store_true", default=False, help="Display additional debug information")
    parser.add_argument("--amp", action="store_true", default=False, help="Activate torch AMP")
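    # Example invocations, assuming this file is saved as oss.py and enough GPUs are visible
    # for the requested world size:
    #   python oss.py --optim_type oss_ddp --epochs 5
    #   python oss.py --optim_type oss_sharded_ddp --amp
    #   python oss.py --optim_type pytorch --cpu --gloo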

    args = parser.parse_args()

    logging.basicConfig(level=logging.INFO if not args.debug else logging.DEBUG)
    logging.info("Benchmark arguments: %s" % args)

    # Use NCCL only when CUDA is available and neither --gloo nor --cpu was requested
    BACKEND = "nccl" if not args.gloo and torch.cuda.is_available() and not args.cpu else "gloo"

    # Download dataset once for all processes
    dataset, tentatives = None, 0
    while dataset is None and tentatives < 5:
        try:
            dataset = MNIST(transform=None, download=True, root=TEMPDIR)
        except (RuntimeError, EOFError) as e:
            if isinstance(e, RuntimeError):
                # Corrupted data, erase and restart
                shutil.rmtree(TEMPDIR + "/MNIST")

            logging.warning("Failed loading dataset: %s " % e)
            tentatives += 1

    if dataset is None:
        logging.error("Could not download MNIST dataset")
        exit(-1)
    else:
        logging.info("Dataset downloaded")

    # Benchmark the different configurations, via multiple processes
    if args.optim_type == OptimType.vanilla or args.optim_type == OptimType.everyone:
        logging.info("\n*** Benchmark vanilla optimizer")
        mp.spawn(
            train,
            args=(args, BACKEND, OptimType.vanilla, False,),  # no regression check
            nprocs=args.world_size,
            join=True,
        )

    if args.optim_type == OptimType.oss_ddp or args.optim_type == OptimType.everyone:
        logging.info("\n*** Benchmark OSS with DDP")
        mp.spawn(
            train, args=(args, BACKEND, OptimType.oss_ddp, args.check_regression), nprocs=args.world_size, join=True,
        )

    if args.optim_type == OptimType.oss_sharded_ddp or args.optim_type == OptimType.everyone:
        logging.info("\n*** Benchmark OSS with ShardedDDP")
        mp.spawn(
            train,
            args=(
                args,
                BACKEND,
                OptimType.oss_sharded_ddp,
                False,
            ),  # FIXME: @lefaudeux - SDP should give the same results
            nprocs=args.world_size,
            join=True,
        )