# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.


import argparse
from enum import Enum
import importlib
import logging
import math
import shutil
import tempfile
import time
from typing import Any, List, Optional, cast

import numpy as np
import torch
import torch.autograd.profiler as profiler
import torch.distributed as dist
import torch.multiprocessing as mp
import torch.nn as nn
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.utils.data import BatchSampler, DataLoader, Sampler
from torch.utils.data.distributed import DistributedSampler
from torchvision.datasets import MNIST
from torchvision.transforms import ToTensor

from fairscale.nn.data_parallel import ShardedDataParallel
from fairscale.optim.oss import OSS

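# Every benchmarked variant uses the same base optimizer, so only the sharding strategy differs between runs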
OPTIM = torch.optim.RMSprop
TEMPDIR = tempfile.gettempdir()


def dist_init(rank, world_size, backend):
    logging.info(f"Using backend: {backend}")
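    # Rendezvous over a fixed localhost TCP port, so only one benchmark instance can run per machine at a time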
    dist.init_process_group(backend=backend, init_method="tcp://localhost:29501", rank=rank, world_size=world_size)


def get_problem(rank, world_size, batch_size, device, model_name: str):
    # Select the desired model on the fly
    logging.info(f"Using {model_name} for benchmarking")
    model = getattr(importlib.import_module("torchvision.models"), model_name)(pretrained=False).to(device)

    # Data setup, duplicate the grey channels to get pseudo color
    def collate(inputs: List[Any]):
        return {
            "inputs": torch.stack([i[0] for i in inputs]).repeat(1, 3, 1, 1).to(device),
            "label": torch.tensor([i[1] for i in inputs]).to(device),
        }

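    # Each rank reads a disjoint shard of MNIST through DistributedSampler; drop_last keeps per-step work identical across ranks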
    dataset = MNIST(transform=ToTensor(), download=False, root=TEMPDIR)
    sampler: Sampler = DistributedSampler(dataset, num_replicas=world_size, rank=rank)
    batch_sampler = BatchSampler(sampler, batch_size, drop_last=True)
    dataloader = DataLoader(dataset=dataset, batch_sampler=batch_sampler, collate_fn=collate)

    loss_fn = nn.CrossEntropyLoss()
    return model, dataloader, loss_fn


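# Which data-parallel / optimizer combination to benchmark: plain PyTorch + DDP, OSS + DDP, OSS + ShardedDataParallel, or all of them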
class OptimType(str, Enum):
    vanilla = "pytorch"
    oss = "oss"
    oss_sdp = "oss_sdp"
    everyone = "everyone"


def train(
    rank: int,
    args: argparse.Namespace,
    backend: str = "gloo",
    optim_type: OptimType = OptimType.vanilla,
    check_regression: bool = True,
):
    logging.basicConfig(level=logging.INFO)

    # DDP
    dist_init(rank=rank, world_size=args.world_size, backend=backend)

    # Setup
    if not args.cpu:
        torch.cuda.set_device(rank)
        torch.cuda.manual_seed(0)
    torch.manual_seed(0)  # also sets the cuda seed
    np.random.seed(0)

    if backend == "nccl":
        torch.backends.cudnn.deterministic = True
        torch.backends.cudnn.benchmark = False

    device = torch.device("cpu") if args.cpu else torch.device(rank)
    model, dataloader, loss_fn = get_problem(rank, args.world_size, args.batch_size, device, args.torchvision_model)

    # Shard the optimizer
    optimizer: Optional[torch.optim.Optimizer] = None
    model = cast(nn.Module, model)

    if optim_type == OptimType.oss_sdp:
        ddp = ShardedDataParallel(
            module=model,
            optimizer=OPTIM,
            optimizer_params={"lr": 1e-4, "momentum": 0.9},
            world_size=args.world_size,
            broadcast_buffers=True,
        )
        ddp.train()
        optimizer = ddp.optimizer
        model = ddp
    else:
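        # Vanilla and OSS paths both rely on DDP for the gradient all-reduce; OSS only shards the optimizer state across ranks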
        model = DDP(model, device_ids=[rank], find_unused_parameters=True)  # type: ignore
        optimizer = (
            OSS(params=model.parameters(), optim=OPTIM, lr=1e-4, momentum=0.9)
            if optim_type == OptimType.oss
            else OPTIM(model.parameters(), lr=1e-4, momentum=0.9)
        )

    # Reset the memory use counter
    if not args.cpu:
        torch.cuda.reset_peak_memory_stats(rank)
        torch.cuda.synchronize(rank)

    # Standard training loop
    training_start = time.monotonic()
    model.train()

    measurements = []
    final_loss: Optional[float] = -1.0
    need_profiling = args.profile

    for epoch in range(args.epochs):
        n_items = 0
        epoch_runtime = 0.0

        for batch in dataloader:
            batch_start = time.monotonic()

            def closure():
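                # Recompute forward and backward on the current batch; optimizer.step(closure) re-evaluates this and returns the loss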
                model.zero_grad()
                outputs = model(batch["inputs"])
                loss = loss_fn(outputs, batch["label"])
                loss.backward()

                if optim_type == OptimType.oss_sdp:
                    ddp.reduce()  # Send the gradients to the appropriate shards

                return loss

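            # Profile a single optimizer step (including CUDA activity) and export a Chrome trace from rank 0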
            if need_profiling and not args.cpu:
                logging.info("Profiling the run")
                with profiler.profile(use_cuda=True) as prof:  # type: ignore
                    with profiler.record_function("batch"):
                        final_loss = optimizer.step(closure)
                        logging.info(f"profiling done, final loss {cast(float, final_loss)}")

                if rank == 0:
                    prof.export_chrome_trace(f"{optim_type}_trace.json")

                need_profiling = False  # only profile once

            else:
                final_loss = optimizer.step(closure)

            n_items += args.batch_size

            batch_end = time.monotonic()
            epoch_runtime += batch_end - batch_start

        if optim_type == OptimType.oss:
            # Exercise the checkpointing path of the OSS optimizer:
            # consolidating the sharded state dict can spike memory usage
            optimizer = cast(OSS, optimizer)
            optimizer.consolidate_state_dict()
            if dist.get_rank() == 0:
                _ = optimizer.state_dict()
                logging.info("... State dict collected")

        measurements.append(n_items / epoch_runtime)
        if dist.get_rank() == 0:
            logging.info(f"Epoch {epoch} - processed {measurements[-1]:.2f} img per sec. Loss {final_loss:.3f}")

    if not args.cpu:
        torch.cuda.synchronize(rank)
    training_stop = time.monotonic()
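    # Overall throughput, assuming every epoch processed the same number of items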
    img_per_sec = n_items / (training_stop - training_start) * args.epochs
    max_memory = torch.cuda.max_memory_allocated(rank) / 2 ** 20 if not args.cpu else 0.0

    logging.info(f"[{dist.get_rank()}] : Training done. {img_per_sec:.2f} img per sec overall")
    logging.info(f"[{dist.get_rank()}] : Peak memory {max_memory:.1f}MiB")

    # Compute the mean and standard deviation of the per-epoch throughput (img per sec)
    mean = sum(measurements) / len(measurements)
    diff = map(lambda x: pow(x - mean, 2.0), measurements)
    std = math.sqrt(sum(diff) / (len(measurements) - 1))
    logging.info(f"[{dist.get_rank()}] : Mean speed: {mean:.2f} +/- {std:.2f}")

    if check_regression and dist.get_rank() == 0:
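        # Tolerate 3 sigma of throughput noise and 5% of extra memory before flagging a regression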
        assert (mean + 3.0 * std) > args.reference_speed, "Speed regression detected"
        assert max_memory < 1.05 * args.reference_memory, "Memory use regression detected"
        assert abs(cast(float, final_loss) - args.reference_loss) < 1e-3, "Loss regression detected"

        logging.info("[Regression Test] VALID")

    dist.destroy_process_group()  # type: ignore


if __name__ == "__main__":
    parser = argparse.ArgumentParser(
        description="Benchmark the optimizer state sharding, on a typical computer vision workload"
    )
    parser.add_argument("--world_size", action="store", default=2, type=int)
    parser.add_argument("--epochs", action="store", default=10, type=int)
    parser.add_argument("--batch_size", action="store", default=256, type=int)
    parser.add_argument("--check_regression", action="store_true", default=False)
    parser.add_argument("--reference_speed", action="store", default=1430, type=float)
    parser.add_argument("--reference_memory", action="store", default=1220, type=float)
    parser.add_argument("--reference_loss", action="store", default=0.006, type=float)
    parser.add_argument(
        "--optim_type", type=OptimType, choices=[o.value for o in OptimType], default=OptimType.everyone
    )
    parser.add_argument("--gloo", action="store_true", default=False)
    parser.add_argument("--profile", action="store_true", default=False)
    parser.add_argument("--cpu", action="store_true", default=False)
    parser.add_argument("--torchvision_model", type=str, help="Any torchvision model name (str)", default="resnet101")

    args = parser.parse_args()

    logging.basicConfig(level=logging.INFO)
    logging.info(f"Benchmark arguments: {args}")

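    # Prefer NCCL when CUDA is usable; fall back to Gloo when --gloo or --cpu is requested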
    backend = "nccl" if not args.gloo and not args.cpu and torch.cuda.is_available() else "gloo"

    # Download dataset once for all processes
    dataset, tentatives = None, 0
    while dataset is None and tentatives < 5:
        try:
            dataset = MNIST(transform=None, download=True, root=TEMPDIR)
        except (RuntimeError, EOFError) as e:
            if isinstance(e, RuntimeError):
                # Corrupted data, erase and restart
                shutil.rmtree(TEMPDIR + "/MNIST")

            logging.warning(f"Failed loading dataset: {e}")
            tentatives += 1

    if dataset is None:
        logging.error("Could not download MNIST dataset")
        exit(-1)
    else:
        logging.info("Dataset downloaded")

    if args.optim_type == OptimType.vanilla or args.optim_type == OptimType.everyone:
        logging.info("*** Benchmark vanilla optimizer")
        mp.spawn(
            train,
            args=(args, backend, OptimType.vanilla, False,),  # no regression check
            nprocs=args.world_size,
            join=True,
        )

    if args.optim_type == OptimType.oss or args.optim_type == OptimType.everyone:
        logging.info("*** Benchmark OSS with DDP")
        mp.spawn(
            train, args=(args, backend, OptimType.oss, args.check_regression), nprocs=args.world_size, join=True,
        )

    if args.optim_type == OptimType.oss_sdp or args.optim_type == OptimType.everyone:
        logging.info("*** Benchmark OSS with SDP")
        mp.spawn(
            train,
            args=(args, backend, OptimType.oss_sdp, False,),  # FIXME: @lefaudeux - SDP should give the same results
            nprocs=args.world_size,
            join=True,
        )