# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.


import argparse
from enum import Enum
import math
import time
from typing import Any, List, Optional, cast

import numpy as np
import torch
import torch.autograd.profiler as profiler
import torch.distributed as dist
import torch.multiprocessing as mp
import torch.nn as nn
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.utils.data import DataLoader
from torchvision.datasets import FakeData
from torchvision.models import resnet101
from torchvision.transforms import ToTensor

from fairscale.nn.data_parallel import ShardedDataParallel
from fairscale.optim.oss import OSS

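# This benchmark compares a vanilla (monolithic) PyTorch optimizer, OSS (optimizer state sharding) under DDP,
# and OSS driven by ShardedDataParallel (SDP), all built on top of the same base optimizer below.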
OPTIM = torch.optim.RMSprop


def dist_init(rank, world_size, backend):
    print(f"Using backend: {backend}")
    dist.init_process_group(backend=backend, init_method="tcp://localhost:29501", rank=rank, world_size=world_size)


def get_problem(rank, data_size, batch_size, device):
    # Standard RN101
    model = resnet101(pretrained=False, progress=True).to(device)

    # Data setup, dummy data
    def collate(inputs: List[Any]):
        return {
            "inputs": torch.stack([i[0] for i in inputs]).to(device),
            "label": torch.stack([i[1] for i in inputs]).to(device),
        }
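
    # FakeData generates random image/label pairs on the fly; random_offset=rank gives each rank a distinct
    # stream, and the collate closure above moves each batch to the target device.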

    dataloader = DataLoader(
        dataset=FakeData(transform=ToTensor(), size=data_size, random_offset=rank),
        batch_size=batch_size,
        collate_fn=collate,
    )
    loss_fn = nn.CrossEntropyLoss()
    return model, dataloader, loss_fn


class OptimType(str, Enum):
    vanilla = "pytorch"
    oss = "oss"
    oss_sdp = "oss_sdp"
    everyone = "everyone"


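# `train` is the per-rank benchmark entry point; __main__ launches it once per rank via mp.spawn.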
def train(
    rank: int,
    args: argparse.Namespace,
    backend: str = "gloo",
    optim_type: OptimType = OptimType.vanilla,
    check_regression: bool = True,
):
    # Distributed init
    dist_init(rank=rank, world_size=args.world_size, backend=backend)

    # Setup
    if not args.cpu:
        torch.cuda.set_device(rank)
        torch.cuda.manual_seed(0)
    torch.manual_seed(0)  # also sets the cuda seed
    np.random.seed(0)

    if backend == "nccl":
        torch.backends.cudnn.deterministic = True
        torch.backends.cudnn.benchmark = False

    device = torch.device("cpu") if args.cpu else torch.device(rank)
    model, dataloader, loss_fn = get_problem(rank, args.data_size, args.batch_size, device)

    # Shard the optimizer
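    # OSS keeps only this rank's shard of the optimizer state; ShardedDataParallel additionally reduces each
    # gradient directly to the rank owning the corresponding shard (see ddp.reduce() in the closure below).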
    optimizer: Optional[torch.optim.Optimizer] = None

    if optim_type == OptimType.oss_sdp:
        ddp = ShardedDataParallel(
            module=model,
            optimizer=OPTIM,
            optimizer_params={"lr": 1e-4, "momentum": 0.9},
            world_size=args.world_size,
            broadcast_buffers=True,
        )
        ddp.train()
        optimizer = ddp.optimizer
        model = ddp
    else:
        model = DDP(model, device_ids=[rank] if not args.cpu else None, find_unused_parameters=True)  # type: ignore
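        # With plain DDP the gradients are all-reduced to every rank as usual; OSS only changes where the
        # optimizer state lives and broadcasts the updated parameters after its step.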
        optimizer = (
            OSS(params=model.parameters(), optim=OPTIM, lr=1e-4, momentum=0.9)
            if optim_type == OptimType.oss
            else OPTIM(model.parameters(), lr=1e-4, momentum=0.9)
        )

    # Reset the memory use counter
    if not args.cpu:
        torch.cuda.reset_peak_memory_stats(rank)
        torch.cuda.synchronize(rank)

    # Dummy training loop
    training_start = time.monotonic()
    model.train()

    measurements = []
    final_loss: Optional[float] = -1.0
    need_profiling = args.profile

    for epoch in range(args.epochs):
        epoch_start = time.monotonic()

        for batch in dataloader:

            def closure():
                model.zero_grad()
                outputs = model(batch["inputs"])
                loss = loss_fn(outputs, batch["label"])
                loss.backward()

                if optim_type == OptimType.oss_sdp:
                    ddp.reduce()  # Send the gradients to the appropriate shards

                return loss

            if need_profiling and not args.cpu:
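                # Profile a single optimizer step with torch.autograd.profiler; rank 0 dumps a Chrome trace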
                print("Profiling the run")
                with profiler.profile(use_cuda=True) as prof:  # type: ignore
                    with profiler.record_function("batch"):
                        final_loss = optimizer.step(closure)
                        print("profiling done, final loss ", cast(float, final_loss))

                if rank == 0:
                    prof.export_chrome_trace(f"{optim_type}_trace.json")

                need_profiling = False  # only profile once

            else:
                final_loss = optimizer.step(closure)

        epoch_end = time.monotonic()

        if optim_type == OptimType.oss:
            # Exercise checkpointing for the OSS optimizer: consolidating the sharded
            # state on one rank can make memory use spike, so keep it in the benchmark
            optimizer = cast(OSS, optimizer)
            optimizer.consolidate_state_dict()
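            # consolidate_state_dict() gathers the sharded optimizer state onto a recipient rank
            # (rank 0 by default), so the full state_dict is only materialized there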
            if dist.get_rank() == 0:
                _ = optimizer.state_dict()
                print("... State dict collected")

        measurements.append(args.data_size / (epoch_end - epoch_start))
        if dist.get_rank() == 0:
            print(f"Epoch {epoch} - processed {measurements[-1]:.2f} img per sec. Loss {final_loss:.3f}")

    if not args.cpu:
        torch.cuda.synchronize(rank)
    training_stop = time.monotonic()
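    # Overall throughput: total images processed (data_size * epochs) divided by total wall-clock time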
    img_per_sec = args.data_size / (training_stop - training_start) * args.epochs
    max_memory = torch.cuda.max_memory_allocated(rank) / 2 ** 20 if not args.cpu else 0.0

    print(f"[{dist.get_rank()}] : Training done. {img_per_sec:.2f} img per sec overall")
    print(f"[{dist.get_rank()}] : Peak memory {max_memory:.1f}MiB")

    # Compute the mean and standard deviation of the per-epoch throughput (img/s)
    mean = sum(measurements) / len(measurements)
    diff = map(lambda x: pow(x - mean, 2.0), measurements)
    std = math.sqrt(sum(diff) / (len(measurements) - 1))
    print(f"[{dist.get_rank()}] : Mean speed: {mean:.2f} +/- {std:.2f}")

    if check_regression and dist.get_rank() == 0:
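        # Compare against the reference_* baselines supplied on the command line (rank 0 only)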
        assert (mean + 3.0 * std) > args.reference_speed, "Speed regression detected"
        assert max_memory < 1.05 * args.reference_memory, "Memory use regression detected"
        assert abs(cast(float, final_loss) - args.reference_loss) < 1e-3, "Loss regression detected"

        print("[Regression Test] VALID")

    dist.destroy_process_group()  # type: ignore


if __name__ == "__main__":
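    # Example invocation (assuming 2 visible GPUs): python oss.py --world_size 2 --epochs 10 --optim_type everyone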
    parser = argparse.ArgumentParser(
        description="Benchmark the optimizer state sharding, on a typical computer vision workload"
    )
    parser.add_argument("--world_size", action="store", default=2, type=int)
    parser.add_argument("--epochs", action="store", default=10, type=int)
    parser.add_argument("--batch_size", action="store", default=32, type=int)
    parser.add_argument("--data_size", action="store", default=512, type=int)
    parser.add_argument("--check_regression", action="store_true", default=False)
    parser.add_argument("--reference_speed", action="store", default=29.7, type=float)
    parser.add_argument("--reference_memory", action="store", default=4475, type=float)
    parser.add_argument("--reference_loss", action="store", default=0.866, type=float)
    parser.add_argument(
        "--optim_type", type=OptimType, choices=[o.value for o in OptimType], default=OptimType.everyone
    )
    parser.add_argument("--gloo", action="store_true", default=False)
    parser.add_argument("--profile", action="store_true", default=False)
    parser.add_argument("--cpu", action="store_true", default=False)

    args = parser.parse_args()
    print(f"Benchmark arguments: {args}")

    backend = "nccl" if not args.gloo and not args.cpu and torch.cuda.is_available() else "gloo"

    if args.optim_type == OptimType.vanilla or args.optim_type == OptimType.everyone:
        print("\nBenchmark vanilla optimizer")
        mp.spawn(
            train,
            args=(args, backend, OptimType.vanilla, False,),  # no regression check
            nprocs=args.world_size,
            join=True,
        )

    if args.optim_type == OptimType.oss or args.optim_type == OptimType.everyone:
        print("\nBenchmark OSS with DDP")
        mp.spawn(
            train, args=(args, backend, OptimType.oss, args.check_regression), nprocs=args.world_size, join=True,
        )

    if args.optim_type == OptimType.oss_sdp or args.optim_type == OptimType.everyone:
        print("\nBenchmark OSS with SDP")
        mp.spawn(
            train,
            args=(args, backend, OptimType.oss_sdp, False,),  # FIXME: @lefaudeux - SDP should give the same results
            nprocs=args.world_size,
            join=True,
        )