# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.


import argparse
from enum import Enum
import importlib
import math
import time
from typing import Any, List, Optional, cast

import numpy as np
import torch
import torch.autograd.profiler as profiler
import torch.distributed as dist
import torch.multiprocessing as mp
import torch.nn as nn
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.utils.data import DataLoader
from torchvision.datasets import FakeData
from torchvision.transforms import ToTensor

from fairscale.nn.data_parallel import ShardedDataParallel
from fairscale.optim.oss import OSS

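# Base optimizer benchmarked in every variant: used directly (vanilla), wrapped by OSS, or handed to ShardedDataParallel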
OPTIM = torch.optim.RMSprop


def dist_init(rank, world_size, backend):
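    """Initialize the default process group (TCP rendezvous on localhost:29501) so the benchmark ranks can communicate."""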
    print(f"Using backend: {backend}")
    dist.init_process_group(backend=backend, init_method="tcp://localhost:29501", rank=rank, world_size=world_size)


def get_problem(rank, data_size, batch_size, device, model_name: str):
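    """Build the requested torchvision model, a synthetic FakeData loader, and the loss function used by the benchmark."""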
    # Select the desired model on the fly
    print(f"Using {model_name} for benchmarking")
    model = getattr(importlib.import_module("torchvision.models"), model_name)(pretrained=False).to(device)

    # Data setup, dummy data
    def collate(inputs: List[Any]):
        return {
            "inputs": torch.stack([i[0] for i in inputs]).to(device),
            "label": torch.stack([i[1] for i in inputs]).to(device),
        }

    dataloader = DataLoader(
        dataset=FakeData(transform=ToTensor(), size=data_size, random_offset=rank),
        batch_size=batch_size,
        collate_fn=collate,
    )
    loss_fn = nn.CrossEntropyLoss()
    return model, dataloader, loss_fn


class OptimType(str, Enum):
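    """Benchmark variants: vanilla PyTorch optimizer with DDP, OSS with DDP, OSS via ShardedDataParallel, or all of them."""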
    vanilla = "pytorch"
    oss = "oss"
    oss_sdp = "oss_sdp"
    everyone = "everyone"


def train(
    rank: int,
    args: argparse.Namespace,
    backend: str = "gloo",
    optim_type: OptimType = OptimType.vanilla,
    check_regression: bool = True,
):
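    """Benchmark entry point for a single worker: set up the requested optimizer variant, run the dummy training loop, and report throughput and peak memory."""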
    # DDP
    dist_init(rank=rank, world_size=args.world_size, backend=backend)

    # Setup
    if not args.cpu:
        # Guard the CUDA calls so that --cpu runs also work on machines without a GPU
        torch.cuda.set_device(rank)
        torch.cuda.manual_seed(0)
    torch.manual_seed(0)  # also sets the cuda seed
    np.random.seed(0)

    if backend == "nccl":
        torch.backends.cudnn.deterministic = True
        torch.backends.cudnn.benchmark = False

    device = torch.device("cpu") if args.cpu else torch.device(rank)
    model, dataloader, loss_fn = get_problem(rank, args.data_size, args.batch_size, device, args.torchvision_model)

    # Shard the optimizer
    optimizer: Optional[torch.optim.Optimizer] = None

    if optim_type == OptimType.oss_sdp:
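        # ShardedDataParallel builds its own sharded (OSS) optimizer and reduces each gradient to the rank owning that shard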
        ddp = ShardedDataParallel(
            module=model,
            optimizer=OPTIM,
            optimizer_params={"lr": 1e-4, "momentum": 0.9},
            world_size=args.world_size,
            broadcast_buffers=True,
        )
        ddp.train()
        optimizer = ddp.optimizer
        model = ddp
    else:
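        # With plain DDP the full gradients are allreduced on every rank; OSS only shards the optimizer state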
        model = DDP(model, device_ids=[rank] if not args.cpu else None, find_unused_parameters=True)  # type: ignore
        optimizer = (
            OSS(params=model.parameters(), optim=OPTIM, lr=1e-4, momentum=0.9)
            if optim_type == OptimType.oss
            else OPTIM(model.parameters(), lr=1e-4, momentum=0.9)
        )

    # Reset the memory use counter
    if not args.cpu:
        torch.cuda.reset_peak_memory_stats(rank)
        torch.cuda.synchronize(rank)

    # Dummy training loop
    training_start = time.monotonic()
    model.train()

    measurements = []
    final_loss: Optional[float] = -1.0
    need_profiling = args.profile

    for epoch in range(args.epochs):
        epoch_start = time.monotonic()

        for batch in dataloader:

            def closure():
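                # Forward + backward pass, wrapped in a closure so that optimizer.step(closure) can re-evaluate the loss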
                model.zero_grad()
                outputs = model(batch["inputs"])
                loss = loss_fn(outputs, batch["label"])
                loss.backward()

                if optim_type == OptimType.oss_sdp:
                    ddp.reduce()  # Send the gradients to the appropriate shards

                return loss

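            # Profile a single optimizer step with the autograd profiler and export a Chrome trace from rank 0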
            if need_profiling and not args.cpu:
                print("Profiling the run")
                with profiler.profile(use_cuda=True) as prof:  # type: ignore
                    with profiler.record_function("batch"):
                        final_loss = optimizer.step(closure)
                        print("profiling done, final loss ", cast(float, final_loss))

                if rank == 0:
                    prof.export_chrome_trace(f"{optim_type}_trace.json")

                need_profiling = False  # only profile once

            else:
                final_loss = optimizer.step(closure)

        epoch_end = time.monotonic()

        if optim_type == OptimType.oss:
            # Exercise the checkpointing path of the OSS optimizer:
            # consolidating the sharded state onto rank 0 can spill over into extra memory use
            optimizer = cast(OSS, optimizer)
            optimizer.consolidate_state_dict()
            if dist.get_rank() == 0:
                _ = optimizer.state_dict()
                print("... State dict collected")

        measurements.append(args.data_size / (epoch_end - epoch_start))
        if dist.get_rank() == 0:
            print(f"Epoch {epoch} - processed {measurements[-1]:.2f} img per sec. Loss {final_loss:.3f}")

    if not args.cpu:
        torch.cuda.synchronize(rank)
    training_stop = time.monotonic()
    img_per_sec = args.data_size / (training_stop - training_start) * args.epochs
    max_memory = torch.cuda.max_memory_allocated(rank) / 2 ** 20 if not args.cpu else 0.0

    print(f"[{dist.get_rank()}] : Training done. {img_per_sec:.2f} img per sec overall")
    print(f"[{dist.get_rank()}] : Peak memory {max_memory:.1f}MiB")

    # Compute the mean and standard deviation of the per-epoch images/sec measurements
    mean = sum(measurements) / len(measurements)
    diff = map(lambda x: pow(x - mean, 2.0), measurements)
    std = math.sqrt(sum(diff) / (len(measurements) - 1))
    print(f"[{dist.get_rank()}] : Mean speed: {mean:.2f} +/- {std:.2f}")

    if check_regression and dist.get_rank() == 0:
        assert (mean + 3.0 * std) > args.reference_speed, "Speed regression detected"
        assert max_memory < 1.05 * args.reference_memory, "Memory use regression detected"
        assert abs(cast(float, final_loss) - args.reference_loss) < 1e-3, "Loss regression detected"

        print("[Regression Test] VALID")

    dist.destroy_process_group()  # type: ignore


if __name__ == "__main__":
    parser = argparse.ArgumentParser(
        description="Benchmark the optimizer state sharding, on a typical computer vision workload"
    )
    parser.add_argument("--world_size", action="store", default=2, type=int)
    parser.add_argument("--epochs", action="store", default=10, type=int)
    parser.add_argument("--batch_size", action="store", default=32, type=int)
    parser.add_argument("--data_size", action="store", default=512, type=int)
    parser.add_argument("--check_regression", action="store_true", default=False)
    parser.add_argument("--reference_speed", action="store", default=29.7, type=float)
    parser.add_argument("--reference_memory", action="store", default=4475, type=float)
    parser.add_argument("--reference_loss", action="store", default=0.866, type=float)
    parser.add_argument(
        "--optim_type", type=OptimType, choices=[o.value for o in OptimType], default=OptimType.everyone
    )
    parser.add_argument("--gloo", action="store_true", default=False)
    parser.add_argument("--profile", action="store_true", default=False)
    parser.add_argument("--cpu", action="store_true", default=False)
    parser.add_argument("--torchvision_model", type=str, help="Any torchvision model name (str)", default="resnet101")

    args = parser.parse_args()
    print(f"Benchmark arguments: {args}")

    # Use NCCL when running on GPUs (unless --gloo is requested), otherwise fall back to Gloo
    backend = "nccl" if not args.gloo and torch.cuda.is_available() and not args.cpu else "gloo"

    if args.optim_type == OptimType.vanilla or args.optim_type == OptimType.everyone:
        print("\nBenchmark vanilla optimizer")
        mp.spawn(
            train,
            args=(args, backend, OptimType.vanilla, False,),  # no regression check
            nprocs=args.world_size,
            join=True,
        )

    if args.optim_type == OptimType.oss or args.optim_type == OptimType.everyone:
        print("\nBenchmark OSS with DDP")
        mp.spawn(
            train, args=(args, backend, OptimType.oss, args.check_regression), nprocs=args.world_size, join=True,
        )

    if args.optim_type == OptimType.oss_sdp or args.optim_type == OptimType.everyone:
        print("\nBenchmark OSS with SDP")
        mp.spawn(
            train,
            args=(args, backend, OptimType.oss_sdp, False,),  # FIXME: @lefaudeux - SDP should give the same results
            nprocs=args.world_size,
            join=True,
        )