# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
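# Benchmarks fairscale's OSS (optimizer state sharding) on a synthetic vision workload:
# a ResNet101 trained on fake data with a vanilla optimizer under DDP, with OSS under DDP,
# and with OSS under ShardedDataParallel, reporting throughput, peak memory and final loss.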


import argparse
from enum import Enum
import math
import time
from typing import Any, List, Optional, cast

import numpy as np
import torch
import torch.distributed as dist
import torch.multiprocessing as mp
import torch.nn as nn
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.utils.data import DataLoader
from torchvision.datasets import FakeData
from torchvision.models import resnet101
from torchvision.transforms import ToTensor

from fairscale.nn.data_parallel import ShardedDataParallel
from fairscale.optim.oss import OSS

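# Optimizer class under test. RMSprop keeps per-parameter state (a running average of the
# squared gradients, plus a momentum buffer when momentum > 0), which is what gets sharded.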
OPTIM = torch.optim.RMSprop


def dist_init(rank, world_size, backend):
    print(f"Using backend: {backend}")
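    # Single-node TCP rendezvous: every spawned process joins the same group on this port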
    dist.init_process_group(backend=backend, init_method="tcp://localhost:29501", rank=rank, world_size=world_size)


def get_problem(rank, data_size, batch_size):
    # Standard RN101
    model = resnet101(pretrained=False, progress=True).to(rank)

    # Data setup, dummy data
    def collate(inputs: List[Any]):
        return {
            "inputs": torch.stack([i[0] for i in inputs]).to(torch.device(rank)),
            "label": torch.stack([i[1] for i in inputs]).to(torch.device(rank)),
        }

    dataloader = DataLoader(
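        # random_offset=rank should give each process a different set of fake images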
        dataset=FakeData(transform=ToTensor(), size=data_size, random_offset=rank),
        batch_size=batch_size,
        collate_fn=collate,
    )
    loss_fn = nn.CrossEntropyLoss()
    return model, dataloader, loss_fn


def train(
    rank: int,
    world_size: int,
    num_epochs: int = 10,
    batch_size: int = 32,
    data_size: int = 200,
    backend: str = "gloo",
    use_oss: bool = True,
    use_sdp: bool = False,
    check_regression: bool = True,
    reference_speed: float = -1.0,
    reference_memory: float = -1.0,
    reference_loss: float = -1.0,
):
    assert not use_sdp or (use_sdp and use_oss), "ShardedDataParallel requires OSS"
    # DDP
    dist_init(rank=rank, world_size=world_size, backend=backend)

    # Setup
    torch.cuda.set_device(rank)
    torch.cuda.manual_seed(0)
    torch.manual_seed(0)  # also sets the cuda seed
    np.random.seed(0)

    if backend == "nccl":
        torch.backends.cudnn.deterministic = True
        torch.backends.cudnn.benchmark = False

    model, dataloader, loss_fn = get_problem(rank, data_size, batch_size)

    # Shard the optimizer
    optimizer: Optional[torch.optim.Optimizer] = None

    if use_sdp:
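        # ShardedDataParallel builds the sharded (OSS) optimizer internally and exposes it
        # as ddp.optimizer; ddp.reduce() sends each gradient to the rank owning its shard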
        ddp = ShardedDataParallel(
            module=model,
            optimizer=OPTIM,
            optimizer_params={"lr": 1e-4, "momentum": 0.9},
            world_size=world_size,
            broadcast_buffers=True,
        )
        ddp.train()
        optimizer = ddp.optimizer
        model = ddp
    else:
        model = DDP(model, device_ids=[rank], find_unused_parameters=True)  # type: ignore
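        # OSS wraps the base optimizer so that each rank only stores and updates the
        # optimizer state for its own shard of the parameters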
        optimizer = (
            OSS(params=model.parameters(), optim=OPTIM, lr=1e-4, momentum=0.9)
            if use_oss
            else OPTIM(model.parameters(), lr=1e-4, momentum=0.9)
        )

    # Reset the memory use counter
    torch.cuda.reset_peak_memory_stats(rank)

    # Dummy training loop
    torch.cuda.synchronize(rank)
    training_start = time.monotonic()
    model.train()

    measurements = []
    final_loss: Optional[float] = -1.0

    for epoch in range(num_epochs):
        epoch_start = time.monotonic()

        for batch in dataloader:

            def closure():
                model.zero_grad()
                outputs = model(batch["inputs"])
                loss = loss_fn(outputs, batch["label"])
                loss /= world_size
                loss.backward()

                # Sum the per-rank losses (already scaled by 1/world_size) so the reported
                # loss is the mean across ranks; gradient reduction is handled by DDP / SDP
                dist.all_reduce(loss, op=dist.ReduceOp.SUM)

                if use_sdp:
                    ddp.reduce()  # Send the gradients to the appropriate shards

                return loss

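            # step(closure) runs the closure (forward + backward) and then the update; with
            # OSS each rank updates its own shard and broadcasts the new parameter values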
            final_loss = optimizer.step(closure)

        epoch_end = time.monotonic()

        if use_oss:
            # Exercise checkpointing with the OSS optimizer; consolidating the sharded
            # state is where extra memory use could show up
            optimizer = cast(OSS, optimizer)
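            # consolidate_state_dict() gathers the optimizer state shards from all ranks
            # onto a single rank (0 by default) so a complete state_dict() can be built there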
            optimizer.consolidate_state_dict()
            if dist.get_rank() == 0:
                _ = optimizer.state_dict()
                print("... State dict collected")

        measurements.append(data_size / (epoch_end - epoch_start))
        if dist.get_rank() == 0:
            print(f"Epoch {epoch} - processed {measurements[-1]:.2f} img per sec. Loss {final_loss:.3f}")

    torch.cuda.synchronize(rank)
    training_stop = time.monotonic()
    img_per_sec = data_size / (training_stop - training_start) * num_epochs
    max_memory = torch.cuda.max_memory_allocated(rank) / 2 ** 20

    print(f"[{dist.get_rank()}] : Training done. {img_per_sec:.2f} img per sec overall")
    print(f"[{dist.get_rank()}] : Peak memory {max_memory:.1f}MiB")

    # Mean and standard deviation of the per-epoch throughput (img per sec)
    mean = sum(measurements) / len(measurements)
    diff = map(lambda x: pow(x - mean, 2.0), measurements)
    std = math.sqrt(sum(diff) / (len(measurements) - 1))
    print(f"[{dist.get_rank()}] : Mean speed: {mean:.2f} +/- {std:.2f}")

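    # Regression thresholds: speed gets 3 standard deviations of slack, memory a 5%
    # overhead, and the final loss has to match the reference within 1e-3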
    if use_oss and check_regression and dist.get_rank() == 0:
        assert (mean + 3.0 * std) > reference_speed, "Speed regression detected"
        assert max_memory < 1.05 * reference_memory, "Memory use regression detected"
        assert abs(cast(float, final_loss) - reference_loss) < 1e-3, "Loss regression detected"

        print("[Regression Test] VALID")

    dist.destroy_process_group()  # type: ignore


if __name__ == "__main__":

    class OptimType(str, Enum):
        vanilla = "pytorch"
        oss = "oss"
        oss_sdp = "oss_sdp"
        everyone = "everyone"

    parser = argparse.ArgumentParser(
        description="Benchmark the optimizer state sharding, on a typical computer vision workload"
    )
    parser.add_argument("--world_size", action="store", default=2, type=int)
    parser.add_argument("--epochs", action="store", default=10, type=int)
    parser.add_argument("--batch_size", action="store", default=32, type=int)
    parser.add_argument("--data_size", action="store", default=512, type=int)
    parser.add_argument("--check_regression", action="store_true", default=False)
    parser.add_argument("--reference_speed", action="store", default=29.7, type=float)
    parser.add_argument("--reference_memory", action="store", default=4475, type=float)
    parser.add_argument("--reference_loss", action="store", default=0.866, type=float)
    parser.add_argument(
        "--optim_type", type=OptimType, choices=[o.value for o in OptimType], default=OptimType.everyone
    )
    parser.add_argument("--gloo", action="store_true", default=False)

    args = parser.parse_args()
    print(f"Benchmark arguments: {args}")

    backend = "nccl" if not args.gloo and torch.cuda.is_available() else "gloo"  # gloo if requested or if CUDA is unavailable

    if args.optim_type == OptimType.vanilla or args.optim_type == OptimType.everyone:
        print("\nBenchmark vanilla optimizer")
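        # mp.spawn passes the process rank as the first argument; the remaining train()
        # parameters are supplied positionally in the args tuple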
        mp.spawn(
            train,
            args=(
                args.world_size,
                args.epochs,
                args.batch_size,
                args.data_size,
                backend,
                False,  # OSS
                False,  # SDP
                False,  # no regression check
            ),
            nprocs=args.world_size,
            join=True,
        )

    if args.optim_type == OptimType.oss or args.optim_type == OptimType.everyone:
        print("\nBenchmark OSS with DDP")
        mp.spawn(
            train,
            args=(
                args.world_size,
                args.epochs,
                args.batch_size,
                args.data_size,
                backend,
                True,  # OSS
                False,  # SDP
                args.check_regression,
                args.reference_speed,
                args.reference_memory,
                args.reference_loss,
            ),
            nprocs=args.world_size,
            join=True,
        )

    if args.optim_type == OptimType.oss_sdp or args.optim_type == OptimType.everyone:
        print("\nBenchmark OSS with SDP")
        mp.spawn(
            train,
            args=(
                args.world_size,
                args.epochs,
                args.batch_size,
                args.data_size,
                backend,
                True,  # OSS
                True,  # SDP
                False,  # FIXME: @lefaudeux - SDP should give the same results
                -1,  # Not checking SDP for speed regression for now, still slower than OSS
                args.reference_memory,
                args.reference_loss,
            ),
            nprocs=args.world_size,
            join=True,
        )