# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.


import argparse
from enum import Enum
import importlib
import logging
import math
import shutil
import tempfile
import time
from typing import Any, List, Optional, cast

import numpy as np
import torch
import torch.autograd.profiler as profiler
import torch.distributed as dist
import torch.multiprocessing as mp
import torch.nn as nn
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.utils.data import BatchSampler, DataLoader, Sampler
from torch.utils.data.distributed import DistributedSampler
from torchvision.datasets import MNIST
from torchvision.transforms import ToTensor

from fairscale.nn.data_parallel import ShardedDataParallel as ShardedDDP
from fairscale.optim import OSS

OPTIM = torch.optim.RMSprop
TEMPDIR = tempfile.gettempdir()


def dist_init(rank, world_size, backend):
    logging.info(f"Using backend: {backend}")
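    # All ranks rendezvous over TCP at this address; the port is arbitrary but must be free
    # and identical across processes.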
    dist.init_process_group(backend=backend, init_method="tcp://localhost:29501", rank=rank, world_size=world_size)


def get_problem(rank, world_size, batch_size, device, model_name: str):
    # Select the desired model on the fly
    logging.info(f"Using {model_name} for benchmarking")
    model = getattr(importlib.import_module("torchvision.models"), model_name)(pretrained=False).to(device)

    # Data setup, duplicate the grey channels to get pseudo color
    def collate(inputs: List[Any]):
        return {
            "inputs": torch.stack([i[0] for i in inputs]).repeat(1, 3, 1, 1).to(device),
            "label": torch.tensor([i[1] for i in inputs]).to(device),
        }

    dataset = MNIST(transform=ToTensor(), download=False, root=TEMPDIR)
    sampler: Sampler = DistributedSampler(dataset, num_replicas=world_size, rank=rank)
    batch_sampler = BatchSampler(sampler, batch_size, drop_last=True)
    dataloader = DataLoader(dataset=dataset, batch_sampler=batch_sampler, collate_fn=collate)

    loss_fn = nn.CrossEntropyLoss()

    return model, dataloader, loss_fn


class OptimType(str, Enum):
    vanilla = "pytorch"
    oss_ddp = "oss_ddp"
    oss_sharded_ddp = "oss_sharded_ddp"
    everyone = "everyone"


def train(
    rank: int,
    args: argparse.Namespace,
    backend: str = "gloo",
    optim_type: OptimType = OptimType.vanilla,
    check_regression: bool = True,
):
    logging.basicConfig(level=logging.INFO if not args.debug else logging.DEBUG)

    # DDP
    dist_init(rank=rank, world_size=args.world_size, backend=backend)

    # Setup
    if not args.cpu:
        torch.cuda.set_device(rank)
        torch.cuda.manual_seed(0)
    torch.manual_seed(0)  # also sets the cuda seed
    np.random.seed(0)

    if backend == "nccl":
        torch.backends.cudnn.deterministic = True
        torch.backends.cudnn.benchmark = False

    device = torch.device("cpu") if args.cpu else torch.device(rank)
    model, dataloader, loss_fn = get_problem(rank, args.world_size, args.batch_size, device, args.torchvision_model)

    # Shard the optimizer
    optimizer: Optional[torch.optim.Optimizer] = None
    model = cast(nn.Module, model)

    if optim_type == OptimType.oss_sharded_ddp:
        model = ShardedDDP(
            model,
            optimizer=OPTIM,
            optimizer_params={"lr": 1e-4, "momentum": 0.9},
            world_size=args.world_size,
            broadcast_buffers=True,
        )
        optimizer = model.sharded_optimizer

    else:
        # device_ids must be left unset for CPU-only runs
        model = DDP(model, device_ids=[rank] if not args.cpu else None, find_unused_parameters=False)  # type: ignore
        optimizer = (
            OSS(params=model.parameters(), optim=OPTIM, lr=1e-4, momentum=0.9)
            if optim_type == OptimType.oss_ddp
            else OPTIM(model.parameters(), lr=1e-4, momentum=0.9)
        )
    optimizer = cast(torch.optim.Optimizer, optimizer)
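    # The three configurations exercised by this benchmark:
    # - vanilla: plain DDP with a full, replicated RMSprop optimizer on every rank.
    # - oss_ddp: DDP handles gradient reduction while OSS shards the optimizer state
    #   (ZeRO style), so each rank only stores and updates its own shard.
    # - oss_sharded_ddp: ShardedDDP owns the OSS optimizer and additionally reduces each
    #   gradient directly to the rank holding the matching optimizer shard.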

    # Reset the memory use counter
    if not args.cpu:
        torch.cuda.reset_peak_memory_stats(rank)
        torch.cuda.synchronize(rank)

    # Standard training loop
    training_start = time.monotonic()
    model.train()

    measurements = []
    final_loss: Optional[float] = -1.0
    need_profiling = args.profile

    for epoch in range(args.epochs):
        n_items = 0
        epoch_runtime = 0.0

        for batch in dataloader:
            batch_start = time.monotonic()

            def closure():
                model.zero_grad()
                if args.debug and rank == 0 and next(model.parameters()).grad is not None:
                    logging.debug(
                        "\nbefore:  param {} -- grad {}".format(
                            next(model.parameters()).norm().item(), next(model.parameters()).grad.norm().item()
                        )
                    )

                outputs = model(batch["inputs"])
                loss = loss_fn(outputs, batch["label"])
                loss.backward()

                if optim_type == OptimType.oss_sharded_ddp:
                    model.reduce()

                if args.debug and rank == 0 and next(model.parameters()).grad is not None:
                    logging.debug(
                        "after BW: param {} -- grad {}".format(
                            next(model.parameters()).norm().item(), next(model.parameters()).grad.norm().item()
                        )
                    )
                return loss
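            # As with stock PyTorch optimizers, optimizer.step(closure) re-runs the closure
            # (forward and backward pass) and returns its loss. With OSS, each rank then updates
            # only the parameter shard it owns and broadcasts the new values to the other ranks
            # so that every replica stays in sync.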

            if need_profiling and not args.cpu:
                logging.info("Profiling the run")
                with profiler.profile(use_cuda=True, record_shapes=True, profile_memory=True) as prof:  # type: ignore
                    with profiler.record_function("batch"):
                        final_loss = optimizer.step(closure)
                        logging.info("profiling done")

                if rank == 0:
                    prof.export_chrome_trace(f"{optim_type}_trace.json")

                need_profiling = False  # only profile once

            else:
                final_loss = optimizer.step(closure)

            if args.debug and rank == 0:
                logging.debug("buffer: {}".format(next(model.buffers()).norm().item()))
                logging.debug(
                    "after update: param {} -- grad {}".format(
                        next(model.parameters()).norm().item(), next(model.parameters()).grad.norm().item()
                    )
                )

            n_items += args.batch_size

            batch_end = time.monotonic()
            epoch_runtime += batch_end - batch_start

        if optim_type == OptimType.oss_ddp or optim_type == OptimType.oss_sharded_ddp:
            # Check the checkpointing in the case of the OSS optimizer
            # Memory usage could spill over from there
            optimizer = cast(OSS, optimizer)
            optimizer.consolidate_state_dict()
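            # The optimizer state is sharded across ranks, so consolidate_state_dict() first
            # gathers every shard onto a single recipient rank (rank 0 by default); only that
            # rank can then materialize the full state_dict below.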
            if dist.get_rank() == 0:
                _ = optimizer.state_dict()
                logging.info("... State dict collected")

        measurements.append(n_items / epoch_runtime)
        if dist.get_rank() == 0:
            logging.info(f"Epoch {epoch} - processed {measurements[-1]:.2f} img per sec. Loss {final_loss:.3f}")

    max_memory = -1.0
    if not args.cpu:
        torch.cuda.synchronize(rank)
        max_memory = torch.cuda.max_memory_allocated(rank) / 2 ** 20

    training_stop = time.monotonic()
    img_per_sec = n_items / (training_stop - training_start) * args.epochs

    logging.info(f"[{dist.get_rank()}] : Training done. {img_per_sec:.2f} img per sec inc. checkpoint")
    logging.info(f"[{dist.get_rank()}] : Peak memory {max_memory:.1f}MiB")

    # Compute the mean and standard deviation of the observed images per second
    mean = sum(measurements) / len(measurements)
    diff = map(lambda x: pow(x - mean, 2.0), measurements)
    std = math.sqrt(sum(diff) / (len(measurements) - 1)) if args.epochs > 2 else -1
    logging.info(f"[{dist.get_rank()}] : Mean speed: {mean:.2f} +/- {std:.2f}")

    if check_regression and dist.get_rank() == 0:
        assert (mean + 3.0 * std) > args.reference_speed, "Speed regression detected"
        assert max_memory < 1.05 * args.reference_memory, "Memory use regression detected"
        assert abs(cast(float, final_loss) - args.reference_loss) < 1e-3, "Loss regression detected"

        logging.info("[Regression Test] VALID")

    dist.destroy_process_group()  # type: ignore


if __name__ == "__main__":
    parser = argparse.ArgumentParser(
        description="Benchmark the optimizer state sharding, on a typical computer vision workload"
    )
    parser.add_argument("--world_size", action="store", default=2, type=int)
    parser.add_argument("--epochs", action="store", default=10, type=int)
    parser.add_argument("--batch_size", action="store", default=256, type=int)
    parser.add_argument("--check_regression", action="store_true", default=False)
    parser.add_argument("--reference_speed", action="store", default=1430, type=float)
    parser.add_argument("--reference_memory", action="store", default=1220, type=float)
    parser.add_argument("--reference_loss", action="store", default=0.006, type=float)
    parser.add_argument(
        "--optim_type", type=OptimType, choices=[o.value for o in OptimType], default=OptimType.everyone
    )
    parser.add_argument("--gloo", action="store_true", default=False)
    parser.add_argument("--profile", action="store_true", default=False)
    parser.add_argument("--cpu", action="store_true", default=False)
    parser.add_argument("--torchvision_model", type=str, help="Any torchvision model name (str)", default="resnet101")
    parser.add_argument("--debug", action="store_true", default=False)

    args = parser.parse_args()

    logging.basicConfig(level=logging.INFO if not args.debug else logging.DEBUG)
    logging.info(f"Benchmark arguments: {args}")

    backend = "nccl" if not args.gloo and torch.cuda.is_available() and not args.cpu else "gloo"
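    # NCCL is the fastest option but requires CUDA devices; fall back to Gloo when CUDA is
    # unavailable, when running on CPU, or when --gloo is passed explicitly.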

    # Download dataset once for all processes
    dataset, tentatives = None, 0
    while dataset is None and tentatives < 5:
        try:
            dataset = MNIST(transform=None, download=True, root=TEMPDIR)
        except (RuntimeError, EOFError) as e:
            if isinstance(e, RuntimeError):
                # Corrupted data, erase and restart
                shutil.rmtree(TEMPDIR + "/MNIST")

            logging.warning("Failed loading dataset: %s", e)
            tentatives += 1

    if dataset is None:
        logging.error("Could not download MNIST dataset")
        exit(-1)
    else:
        logging.info("Dataset downloaded")
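    # The spawned workers reload this local copy with download=False (see get_problem),
    # which avoids several processes racing to download and extract the same archive.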

    # Benchmark the different configurations, via multiple processes
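    # mp.spawn launches args.world_size processes per configuration and passes each one its
    # rank as the first positional argument of train(); the remaining arguments come from
    # the args tuple handed to spawn.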
    if args.optim_type == OptimType.vanilla or args.optim_type == OptimType.everyone:
        logging.info("\n*** Benchmark vanilla optimizer")
        mp.spawn(
            train,
            args=(args, backend, OptimType.vanilla, False,),  # no regression check
            nprocs=args.world_size,
            join=True,
        )

    if args.optim_type == OptimType.oss_ddp or args.optim_type == OptimType.everyone:
        logging.info("\n*** Benchmark OSS with DDP")
        mp.spawn(
            train, args=(args, backend, OptimType.oss_ddp, args.check_regression), nprocs=args.world_size, join=True,
        )

    if args.optim_type == OptimType.oss_sharded_ddp or args.optim_type == OptimType.everyone:
        logging.info("\n*** Benchmark OSS with ShardedDDP")
        mp.spawn(
            train,
            args=(
                args,
                backend,
                OptimType.oss_sharded_ddp,
                False,
            ),  # FIXME: @lefaudeux - SDP should give the same results
            nprocs=args.world_size,
            join=True,
        )