# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.


import argparse
from enum import Enum
import importlib
import logging
import math
import shutil
import tempfile
import time
from typing import Any, List, Optional, cast

import numpy as np
import torch
import torch.autograd.profiler as profiler
import torch.distributed as dist
import torch.multiprocessing as mp
import torch.nn as nn
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.utils.data import BatchSampler, DataLoader, Sampler
from torch.utils.data.distributed import DistributedSampler
from torchvision.datasets import MNIST
from torchvision.transforms import ToTensor

from fairscale.nn.data_parallel import ShardedDataParallel as ShardedDDP
from fairscale.optim import OSS

OPTIM = torch.optim.RMSprop
TEMPDIR = tempfile.gettempdir()


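# NOTE: the process group rendezvous below uses a fixed TCP address (localhost:29501),
# so all ranks are expected to run on the same machine with that port available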
def dist_init(rank, world_size, backend):
    logging.info(f"Using backend: {backend}")
    dist.init_process_group(backend=backend, init_method="tcp://localhost:29501", rank=rank, world_size=world_size)


def get_problem(rank, world_size, batch_size, device, model_name: str):
    # Select the desired model on the fly
    logging.info(f"Using {model_name} for benchmarking")
    model = getattr(importlib.import_module("torchvision.models"), model_name)(pretrained=False).to(device)

    # Data setup, duplicate the grey channels to get pseudo color
    def collate(inputs: List[Any]):
        return {
            "inputs": torch.stack([i[0] for i in inputs]).repeat(1, 3, 1, 1).to(device),
            "label": torch.tensor([i[1] for i in inputs]).to(device),
        }

    dataset = MNIST(transform=ToTensor(), download=False, root=TEMPDIR)
    sampler: Sampler = DistributedSampler(dataset, num_replicas=world_size, rank=rank)
    batch_sampler = BatchSampler(sampler, batch_size, drop_last=True)
    dataloader = DataLoader(dataset=dataset, batch_sampler=batch_sampler, collate_fn=collate)

    loss_fn = nn.CrossEntropyLoss()
    return model, dataloader, loss_fn


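# Benchmark flavors: a plain PyTorch optimizer under DDP ("pytorch"), the OSS state-sharding
# wrapper under DDP, OSS driven by ShardedDDP, or all of the above ("everyone")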
class OptimType(str, Enum):
    vanilla = "pytorch"
    oss_ddp = "oss_ddp"
    oss_sharded_ddp = "oss_sharded_ddp"
    everyone = "everyone"


def train(
    rank: int,
    args: argparse.Namespace,
    backend: str = "gloo",
    optim_type: OptimType = OptimType.vanilla,
    check_regression: bool = True,
):
    logging.basicConfig(level=logging.INFO if not args.debug else logging.DEBUG)

    # DDP
    dist_init(rank=rank, world_size=args.world_size, backend=backend)

    # Setup
    if not args.cpu:
        torch.cuda.set_device(rank)
        torch.cuda.manual_seed(0)
    torch.manual_seed(0)  # also sets the cuda seed
    np.random.seed(0)

    if backend == "nccl":
        torch.backends.cudnn.deterministic = True
        torch.backends.cudnn.benchmark = False

    device = torch.device("cpu") if args.cpu else torch.device(rank)
    model, dataloader, loss_fn = get_problem(rank, args.world_size, args.batch_size, device, args.torchvision_model)

    # Shard the optimizer
    optimizer: Optional[torch.optim.Optimizer] = None
    model = cast(nn.Module, model)

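    # Either ShardedDDP (which builds and owns a sharded, ZeRO-style optimizer), or vanilla DDP
    # paired with either the OSS state-sharding wrapper or the plain optimizer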
    if optim_type == OptimType.oss_sharded_ddp:
        model = ShardedDDP(
            model,
            optimizer=OPTIM,
            optimizer_params={"lr": 1e-4, "momentum": 0.9},
            world_size=args.world_size,
            broadcast_buffers=True,
        )
        optimizer = model.sharded_optimizer

    else:
        device_ids = None if args.cpu else [rank]
        model = DDP(model, device_ids=device_ids, find_unused_parameters=False)  # type: ignore
        optimizer = (
            OSS(params=model.parameters(), optim=OPTIM, lr=1e-4, momentum=0.9)
            if optim_type == OptimType.oss_ddp
            else OPTIM(model.parameters(), lr=1e-4, momentum=0.9)
        )
    optimizer = cast(torch.optim.Optimizer, optimizer)

    # Reset the memory use counter
    if not args.cpu:
        torch.cuda.reset_peak_memory_stats(rank)
        torch.cuda.synchronize(rank)

    # Standard training loop
    training_start = time.monotonic()
    model.train()

    measurements = []
    final_loss: Optional[float] = -1.0
    need_profiling = args.profile

    for epoch in range(args.epochs):
        n_items = 0
        epoch_runtime = 0.0

        for batch in dataloader:
            batch_start = time.monotonic()

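            # Closure interface: optimizer.step(closure) calls this to recompute the loss and the gradients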
            def closure(data=batch):
                model.zero_grad()
                if args.debug and rank == 0 and next(model.parameters()).grad is not None:
                    logging.debug(
                        "\nbefore:  param {} -- grad {}".format(
                            next(model.parameters()).norm().item(), next(model.parameters()).grad.norm().item()
                        )
                    )
                if not args.cpu and args.amp:
                    # Automatically computes the FW pass in half precision
                    with torch.cuda.amp.autocast():
                        outputs = model(data["inputs"])
                        loss = loss_fn(outputs, data["label"])
                else:
                    outputs = model(data["inputs"])
                    loss = loss_fn(outputs, data["label"])

                loss.backward()

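                # With ShardedDDP the gradient reduction across ranks is triggered explicitly once backward is done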
                if optim_type == OptimType.oss_sharded_ddp:
                    model.reduce()

                if args.debug and rank == 0 and next(model.parameters()).grad is not None:
                    logging.debug(
                        "after BW: param {} -- grad {}".format(
                            next(model.parameters()).norm().item(), next(model.parameters()).grad.norm().item()
                        )
                    )
                return loss

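            # Optionally profile a single optimizer step with the autograd profiler and export a chrome trace (rank 0 only)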
            if need_profiling and not args.cpu:
                logging.info("Profiling the run")
                with profiler.profile(use_cuda=True, record_shapes=True, profile_memory=True) as prof:  # type: ignore
                    with profiler.record_function("batch"):
                        final_loss = optimizer.step(closure)
                        logging.info("profiling done")

                if rank == 0:
                    prof.export_chrome_trace(f"{optim_type}_trace.json")

                need_profiling = False  # only profile once

            else:
                final_loss = optimizer.step(closure)

            if args.debug and rank == 0:
                logging.debug("buffer: {}".format(next(model.buffers()).norm().item()))
                logging.debug(
                    "after update: param {} -- grad {}".format(
                        next(model.parameters()).norm().item(), next(model.parameters()).grad.norm().item()
                    )
                )

            n_items += args.batch_size

            batch_end = time.monotonic()
            epoch_runtime += batch_end - batch_start

        if optim_type == OptimType.oss_ddp or optim_type == OptimType.oss_sharded_ddp:
            # Check the checkpointing in the case of the OSS optimizer
            # Memory usage could spill over from there
            optimizer = cast(OSS, optimizer)
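            # consolidate_state_dict() gathers the sharded optimizer state onto a single rank (rank 0 here),
            # so that the state_dict() call below sees the full, unsharded state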
            optimizer.consolidate_state_dict()
            if dist.get_rank() == 0:
                _ = optimizer.state_dict()
                logging.info("... State dict collected")

        measurements.append(n_items / epoch_runtime)
        if dist.get_rank() == 0:
            logging.info(f"Epoch {epoch} - processed {measurements[-1]:.2f} img per sec. Loss {final_loss:.3f}")

    max_memory = -1.0
    if not args.cpu:
        torch.cuda.synchronize(rank)
        max_memory = torch.cuda.max_memory_allocated(rank) / 2 ** 20
        logging.info(f"[{dist.get_rank()}] : Peak memory {max_memory:.1f}MiB")

    training_stop = time.monotonic()
    img_per_sec = n_items / (training_stop - training_start) * args.epochs
    logging.info(f"[{dist.get_rank()}] : Training done. {img_per_sec:.2f} img per sec inc. checkpoint")

    # Compute the mean and standard deviation of the per-epoch img per second
    mean = sum(measurements) / len(measurements)
    diff = map(lambda x: pow(x - mean, 2.0), measurements)
    std = math.sqrt(sum(diff) / (len(measurements) - 1)) if args.epochs > 2 else -1
    logging.info(f"[{dist.get_rank()}] : Mean speed: {mean:.2f} +/- {std:.2f}")

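    # Regression guards against the stored reference numbers: throughput, peak memory and final loss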
    if check_regression and dist.get_rank() == 0:
        assert (mean + 3.0 * std) > args.reference_speed, "Speed regression detected"
        assert max_memory < 1.05 * args.reference_memory, "Memory use regression detected"
        assert abs(cast(float, final_loss) - args.reference_loss) < 1e-3, "Loss regression detected"

        logging.info("[Regression Test] VALID")

    dist.destroy_process_group()  # type: ignore


if __name__ == "__main__":
    parser = argparse.ArgumentParser(
        description="Benchmark the optimizer state sharding, on a typical computer vision workload"
    )
    parser.add_argument("--world_size", action="store", default=2, type=int)
    parser.add_argument("--epochs", action="store", default=10, type=int)
    parser.add_argument("--batch_size", action="store", default=256, type=int)
    parser.add_argument("--check_regression", action="store_true", default=False)
    parser.add_argument("--reference_speed", action="store", default=1430, type=float)
    parser.add_argument("--reference_memory", action="store", default=1220, type=float)
    parser.add_argument("--reference_loss", action="store", default=0.006, type=float)
    parser.add_argument(
        "--optim_type", type=OptimType, choices=[o.value for o in OptimType], default=OptimType.everyone
    )
    parser.add_argument("--gloo", action="store_true", default=False)
    parser.add_argument("--profile", action="store_true", default=False)
    parser.add_argument("--cpu", action="store_true", default=False)
    parser.add_argument("--torchvision_model", type=str, help="Any torchvision model name (str)", default="resnet101")
    parser.add_argument("--debug", action="store_true", default=False, help="Display additional debug information")
    parser.add_argument("--amp", action="store_true", default=False, help="Activate torch AMP")

    args = parser.parse_args()

    logging.basicConfig(level=logging.INFO if not args.debug else logging.DEBUG)
    logging.info("Benchmark arguments: %s", args)

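    # Use NCCL for GPU runs unless gloo is explicitly requested; fall back to gloo for CPU-only runs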
    BACKEND = "nccl" if not args.gloo and not args.cpu and torch.cuda.is_available() else "gloo"

    # Download dataset once for all processes
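    # (the MNIST download can fail transiently or leave corrupted files behind, hence the retry loop below)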
    dataset, tentatives = None, 0
    while dataset is None and tentatives < 5:
        try:
            dataset = MNIST(transform=None, download=True, root=TEMPDIR)
        except (RuntimeError, EOFError) as e:
            if isinstance(e, RuntimeError):
                # Corrupted data, erase and restart
                shutil.rmtree(TEMPDIR + "/MNIST")

            logging.warning("Failed loading dataset: %s", e)
            tentatives += 1

    if dataset is None:
        logging.error("Could not download MNIST dataset")
        exit(-1)
    else:
        logging.info("Dataset downloaded")

    # Benchmark the different configurations, via multiple processes
    if args.optim_type == OptimType.vanilla or args.optim_type == OptimType.everyone:
        logging.info("\n*** Benchmark vanilla optimizer")
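        # mp.spawn starts args.world_size processes and passes each one its rank as the first argument to train()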
        mp.spawn(
            train,
            args=(args, BACKEND, OptimType.vanilla, False),  # no regression check
            nprocs=args.world_size,
            join=True,
        )

    if args.optim_type == OptimType.oss_ddp or args.optim_type == OptimType.everyone:
        logging.info("\n*** Benchmark OSS with DDP")
        mp.spawn(
            train, args=(args, BACKEND, OptimType.oss_ddp, args.check_regression), nprocs=args.world_size, join=True,
        )

    if args.optim_type == OptimType.oss_sharded_ddp or args.optim_type == OptimType.everyone:
        logging.info("\n*** Benchmark OSS with ShardedDDP")
        mp.spawn(
            train,
            args=(
                args,
                BACKEND,
                OptimType.oss_sharded_ddp,
                False,
            ),  # FIXME: @lefaudeux - SDP should give the same results
            nprocs=args.world_size,
            join=True,
        )