# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.


import argparse
from enum import Enum
import math
import time
from typing import Any, List, Optional, cast

import numpy as np
import torch
import torch.distributed as dist
import torch.multiprocessing as mp
import torch.nn as nn
from torch.utils.data import DataLoader
from torchvision.datasets import FakeData
from torchvision.models import resnet101
from torchvision.transforms import ToTensor

from fairscale.nn.data_parallel import ShardedDataParallel
from fairscale.optim.oss import OSS

OPTIM = torch.optim.RMSprop


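# Join the default process group. The rendezvous address is hardcoded to
# localhost:29501, so all ranks are expected to run on the same machine and the
# port must be free.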
def dist_init(rank, world_size, backend):
    print(f"Using backend: {backend}")
    dist.init_process_group(backend=backend, init_method="tcp://localhost:29501", rank=rank, world_size=world_size)


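# Build the benchmark problem: a stock ResNet101, a DataLoader over synthetic
# FakeData whose collate function moves each batch onto this rank's device, and
# a cross-entropy loss.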
def get_problem(rank, data_size, batch_size):
    # Standard RN101
    model = resnet101(pretrained=False, progress=True).to(rank)

    # Data setup, dummy data
    def collate(inputs: List[Any]):
        return {
            "inputs": torch.stack([i[0] for i in inputs]).to(torch.device(rank)),
            "label": torch.stack([i[1] for i in inputs]).to(torch.device(rank)),
        }

    dataloader = DataLoader(
        dataset=FakeData(transform=ToTensor(), size=data_size), batch_size=batch_size, collate_fn=collate
    )
    loss_fn = nn.CrossEntropyLoss()
    return model, dataloader, loss_fn


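# One benchmark worker per rank, launched via mp.spawn: sets up the process group,
# builds the model and data, trains for `num_epochs` with the requested setup
# (vanilla optimizer, OSS, or OSS + ShardedDataParallel), then reports throughput
# and peak memory and optionally checks them against reference values.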
def train(
    rank: int,
    world_size: int,
    num_epochs: int = 10,
    batch_size: int = 32,
    data_size: int = 200,
    backend: str = "gloo",
    use_oss: bool = True,
    use_sdp: bool = False,
    check_regression: bool = True,
    reference_speed: float = -1.0,
    reference_memory: float = -1.0,
    reference_loss: float = -1.0,
):
    assert not use_sdp or (use_sdp and use_oss), "ShardedDataParallel requires OSS"
    # Initialize the distributed process group
    dist_init(rank=rank, world_size=world_size, backend=backend)

    # Setup
    torch.cuda.set_device(rank)
    torch.cuda.manual_seed(0)
    torch.manual_seed(0)  # also sets the cuda seed
    np.random.seed(0)

    if backend == "nccl":
        torch.backends.cudnn.deterministic = True
        torch.backends.cudnn.benchmark = False

    model, dataloader, loss_fn = get_problem(rank, data_size, batch_size)

    # Shard the optimizer
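    # OSS wraps the base optimizer (RMSprop here) and partitions its state across
    # the ranks; ShardedDataParallel additionally reduces each gradient to the rank
    # owning the corresponding optimizer shard.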
    optimizer: Optional[torch.optim.Optimizer] = None

    if use_sdp:
        ddp = ShardedDataParallel(
            module=model,
            optimizer=OPTIM,
            optimizer_params={"lr": 1e-4, "momentum": 0.9},
            world_size=world_size,
            broadcast_buffers=False,
        )
        ddp.train()
        optimizer = ddp.optimizer
        model = ddp
    else:
        optimizer = (
            OSS(params=model.parameters(), optim=OPTIM, lr=1e-4, momentum=0.9)
            if use_oss
            else OPTIM(model.parameters(), lr=1e-4, momentum=0.9)
        )

    # Reset the memory use counter
    torch.cuda.reset_peak_memory_stats(rank)

    # Dummy training loop
    torch.cuda.synchronize(rank)
    training_start = time.monotonic()
    model.train()

    measurements = []
    final_loss: Optional[float] = -1.0

    for epoch in range(num_epochs):
        epoch_start = time.monotonic()

        for batch in dataloader:

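            # The closure runs the forward/backward pass so that the whole iteration
            # is driven through optimizer.step(closure). The loss is divided by
            # world_size and summed across ranks, so the returned value (reported
            # below) is the average loss over all workers.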
            def closure():
                model.zero_grad()
                outputs = model(batch["inputs"])
                loss = loss_fn(outputs, batch["label"])
                loss /= world_size
                loss.backward()

                dist.all_reduce(loss, op=dist.ReduceOp.SUM)

                if use_sdp:
                    ddp.reduce()  # Send the gradients to the appropriate shards

                return loss

            final_loss = optimizer.step(closure)

        epoch_end = time.monotonic()

        if use_oss:
            # Check the checkpointing in the case of the OSS optimizer
            # Memory usage could spill over from there
            optimizer = cast(OSS, optimizer)
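            # Gather the optimizer state shards from every rank so that a full
            # state_dict can be materialized (only rank 0 queries it below).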
            optimizer.consolidate_state_dict()
            if dist.get_rank() == 0:
                _ = optimizer.state_dict()
                print("... State dict collected")

        measurements.append(data_size / (epoch_end - epoch_start))
        if dist.get_rank() == 0:
            print(f"Epoch {epoch} - processed {measurements[-1]:.2f} img per sec. Loss {final_loss:.3f}")

    torch.cuda.synchronize(rank)
    training_stop = time.monotonic()
    img_per_sec = data_size / (training_stop - training_start) * num_epochs
    max_memory = torch.cuda.max_memory_allocated(rank) / 2 ** 20

    print(f"[{dist.get_rank()}] : Training done. {img_per_sec:.2f} img per sec overall")
    print(f"[{dist.get_rank()}] : Peak memory {max_memory:.1f}MiB")

    # Compute the mean and standard deviation of the per-epoch images-per-second measurements
    mean = sum(measurements) / len(measurements)
    diff = map(lambda x: pow(x - mean, 2.0), measurements)
    std = math.sqrt(sum(diff) / (len(measurements) - 1))
    print(f"[{dist.get_rank()}] : Mean speed: {mean:.2f} +/- {std:.2f}")

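    # Regression criteria: mean throughput no more than 3 standard deviations below
    # the reference speed, peak memory at most 5% above the reference, and final
    # loss within 1e-3 of the reference.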
    if use_oss and check_regression and dist.get_rank() == 0:
        assert (mean + 3.0 * std) > reference_speed, "Speed regression detected"
        assert max_memory < 1.05 * reference_memory, "Memory use regression detected"
        assert abs(cast(float, final_loss) - reference_loss) < 1e-3, "Loss regression detected"

        print("[Regression Test] VALID")

    dist.destroy_process_group()  # type: ignore


if __name__ == "__main__":

    class OptimType(str, Enum):
        vanilla = "pytorch"
        oss = "oss"
        oss_sdp = "oss_sdp"
        everyone = "everyone"

    parser = argparse.ArgumentParser(
        description="Benchmark the optimizer state sharding, on a typical computer vision workload"
    )
    parser.add_argument("--world_size", action="store", default=2, type=int)
    parser.add_argument("--epochs", action="store", default=10, type=int)
    parser.add_argument("--batch_size", action="store", default=32, type=int)
    parser.add_argument("--data_size", action="store", default=512, type=int)
    parser.add_argument("--check_regression", action="store_true", default=False)
    parser.add_argument("--reference_speed", action="store", default=29.7, type=float)
    parser.add_argument("--reference_memory", action="store", default=4475, type=float)
    parser.add_argument("--reference_loss", action="store", default=0.866, type=float)
    parser.add_argument(
        "--optim_type", type=OptimType, choices=[o.value for o in OptimType], default=OptimType.everyone
    )
    parser.add_argument("--gloo", action="store_true", default=False)

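    # Example invocation (illustrative, adjust to your setup):
    #   python oss.py --world_size 2 --epochs 5 --optim_type oss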
    args = parser.parse_args()
    print(f"Benchmark arguments: {args}")

    # Use NCCL when CUDA is available and --gloo was not requested; fall back to Gloo otherwise
    backend = "gloo" if args.gloo or not torch.cuda.is_available() else "nccl"

    if args.optim_type == OptimType.vanilla or args.optim_type == OptimType.everyone:
        print("\nBenchmark vanilla optimizer")
        mp.spawn(
            train,
            args=(
                args.world_size,
                args.epochs,
                args.batch_size,
                args.data_size,
                backend,
                False,  # OSS
                False,  # SDP
                False,  # no regression check
            ),
            nprocs=args.world_size,
            join=True,
        )

    if args.optim_type == OptimType.oss or args.optim_type == OptimType.everyone:
        print("\nBenchmark OSS")
        mp.spawn(
            train,
            args=(
                args.world_size,
                args.epochs,
                args.batch_size,
                args.data_size,
                backend,
                True,  # OSS
                False,  # SDP
                args.check_regression,
                args.reference_speed,
                args.reference_memory,
                args.reference_loss,
            ),
            nprocs=args.world_size,
            join=True,
        )

    if args.optim_type == OptimType.oss_sdp or args.optim_type == OptimType.everyone:
        print("\nBenchmark OSS DDP")
        mp.spawn(
            train,
            args=(
                args.world_size,
                args.epochs,
                args.batch_size,
                args.data_size,
                backend,
                True,  # OSS
                True,  # SDP
                args.check_regression,
                -1,  # Not checking SDP for speed regression for now, still slower than OSS
                args.reference_memory,
                args.reference_loss,
            ),
            nprocs=args.world_size,
            join=True,
        )