# oss.py (8.07 KB)
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.


import argparse
5
from enum import Enum
6
7
import math
import time
# Committed by Benjamin Lefaudeux
8
from typing import Any, List, Optional, cast
9

10
import numpy as np
11
12
13
14
15
16
17
18
19
import torch
import torch.distributed as dist
import torch.multiprocessing as mp
import torch.nn as nn
from torch.utils.data import DataLoader
from torchvision.datasets import FakeData
from torchvision.models import resnet101
from torchvision.transforms import ToTensor

20
from fairscale.nn.data_parallel import ShardedDataParallel
21
22
from fairscale.optim.oss import OSS

# Optimizer class used by every benchmark variant: plain, wrapped in OSS, or built by
# ShardedDataParallel (which receives the class itself, not an instance).
OPTIM = torch.optim.RMSprop


def dist_init(rank, world_size, backend):
    """Join the default process group for one benchmark worker.

    Rendezvous is done over TCP on localhost port 29501, so every worker is
    expected to run on the same machine.

    Args:
        rank: index of this worker within the group.
        world_size: total number of workers in the group.
        backend: torch.distributed backend name (e.g. "nccl" or "gloo").
    """
    print(f"Using backend: {backend}")
    rendezvous_url = "tcp://localhost:29501"
    dist.init_process_group(
        backend=backend,
        init_method=rendezvous_url,
        rank=rank,
        world_size=world_size,
    )


def get_problem(rank, data_size, batch_size):
    """Build the benchmark workload: an untrained ResNet-101, a synthetic dataloader and a loss.

    Args:
        rank: CUDA device index; the model and every collated batch are moved there.
        data_size: number of samples in the synthetic ``FakeData`` dataset.
        batch_size: samples per batch.

    Returns:
        ``(model, dataloader, loss_fn)``. Each batch yielded by the dataloader is a dict
        with ``"inputs"`` (the stacked transformed images) and ``"label"`` (a
        ``torch.long`` tensor of class indices), both already on device ``rank``.
    """
    # Resolved once instead of once per tensor per batch.
    device = torch.device(rank)

    # Standard RN101, no pretrained weights
    model = resnet101(pretrained=False, progress=True).to(rank)

    # Data setup, dummy data
    def collate(inputs: List[Any]):
        # Labels go through int() so this works whether FakeData yields them as 0-dim
        # tensors or as plain Python ints (torch.stack only accepts tensors).
        # NOTE(review): the label type has differed across torchvision releases —
        # confirm against the pinned version.
        return {
            "inputs": torch.stack([i[0] for i in inputs]).to(device),
            "label": torch.tensor([int(i[1]) for i in inputs], dtype=torch.long, device=device),
        }

    dataloader = DataLoader(
        dataset=FakeData(transform=ToTensor(), size=data_size), batch_size=batch_size, collate_fn=collate
    )
    loss_fn = nn.CrossEntropyLoss()
    return model, dataloader, loss_fn


def train(
    rank: int,
    world_size: int,
    num_epochs: int = 10,
    batch_size: int = 32,
    data_size: int = 200,
    backend: str = "gloo",
    use_oss: bool = True,
    use_sdp: bool = False,
    check_regression: bool = True,
    reference_speed: float = -1.0,
    reference_memory: float = -1.0,
    reference_loss: float = -1.0,
):
    """Run the ResNet-101 optimizer benchmark on one worker and report speed and memory.

    Intended to be launched once per rank (e.g. through ``mp.spawn``, which passes ``rank``).

    Args:
        rank: worker index, also used as the CUDA device index.
        world_size: number of workers in the process group.
        num_epochs: passes over the synthetic dataset.
        batch_size: samples per batch.
        data_size: size of the synthetic dataset each rank builds for itself.
        backend: torch.distributed backend name.
        use_oss: wrap the optimizer in fairscale's ``OSS``.
        use_sdp: wrap the model in ``ShardedDataParallel`` (requires ``use_oss``).
        check_regression: on rank 0, and only when ``use_oss``, assert speed / memory /
            loss against the reference values below.
        reference_speed: img/s; fails if ``mean + 3 * std`` is not above it.
        reference_memory: peak MiB; fails if peak memory reaches 1.05x this value.
        reference_loss: fails if the final loss differs from it by 1e-3 or more.

    Raises:
        AssertionError: if ``use_sdp`` is set without ``use_oss``, or a regression check fails.
    """
    assert not use_sdp or use_oss, "ShardedDataParallel requires OSS"

    # Join the process group
    dist_init(rank, world_size, backend)

    # Setup
    torch.cuda.set_device(rank)
    torch.cuda.manual_seed(0)
    torch.manual_seed(0)  # also sets the cuda seed
    np.random.seed(0)

    if backend == "nccl":
        torch.backends.cudnn.deterministic = True
        torch.backends.cudnn.benchmark = False

    model, dataloader, loss_fn = get_problem(rank, data_size, batch_size)

    # Shard the optimizer
    optimizer: Optional[torch.optim.Optimizer] = None

    if use_sdp:
        ddp = ShardedDataParallel(
            module=model, optimizer=OPTIM, optimizer_params={"lr": 1e-4, "momentum": 0.9}, world_size=world_size,
        )
        ddp.train()
        optimizer = ddp.optimizer
        model = ddp
    else:
        optimizer = (
            OSS(params=model.parameters(), optim=OPTIM, lr=1e-4, momentum=0.9)
            if use_oss
            else OPTIM(model.parameters(), lr=1e-4, momentum=0.9)
        )

    # Reset the memory use counter
    torch.cuda.reset_peak_memory_stats(rank)

    # Dummy training loop
    torch.cuda.synchronize(rank)
    training_start = time.monotonic()
    model.train()

    measurements = []
    # Holds whatever optimizer.step(closure) returned last (the all-reduced loss tensor),
    # or the -1.0 placeholder if no step has run. Kept as-is inside the loop so that no
    # extra host/device sync is added per step; converted with float() only when checked.
    final_loss: Any = -1.0

    for epoch in range(num_epochs):
        epoch_start = time.monotonic()

        for batch in dataloader:

            def closure():
                model.zero_grad()
                outputs = model(batch["inputs"])
                loss = loss_fn(outputs, batch["label"])
                loss /= world_size
                loss.backward()

                # Summing the world_size-scaled losses leaves every rank with the mean loss
                dist.all_reduce(loss, op=dist.ReduceOp.SUM)

                if use_sdp:
                    ddp.reduce()  # Send the gradients to the appropriate shards

                return loss

            final_loss = optimizer.step(closure)

        epoch_end = time.monotonic()

        if use_oss:
            # Check the checkpointing in the case of the OSS optimizer
            # Memory usage could spill over from there
            optimizer = cast(OSS, optimizer)
            optimizer.consolidate_state_dict()
            if dist.get_rank() == 0:
                _ = optimizer.state_dict()
                print("... State dict collected")

        measurements.append(data_size / (epoch_end - epoch_start))
        if dist.get_rank() == 0:
            print(f"Epoch {epoch} - processed {measurements[-1]:.2f} img per sec. Loss {final_loss:.3f}")

    torch.cuda.synchronize(rank)
    training_stop = time.monotonic()
    img_per_sec = data_size / (training_stop - training_start) * num_epochs
    max_memory = torch.cuda.max_memory_allocated(rank) / 2 ** 20

    print(f"[{dist.get_rank()}] : Training done. {img_per_sec:.2f} img per sec overall")
    print(f"[{dist.get_rank()}] : Peak memory {max_memory:.1f}MiB")

    # Compute the mean and standard deviation of the per-epoch img per second
    mean, std = _speed_stats(measurements)
    print(f"[{dist.get_rank()}] : Mean speed: {mean:.2f} +/- {std:.2f}")

    if use_oss and check_regression and dist.get_rank() == 0:
        assert (mean + 3.0 * std) > reference_speed, "Speed regression detected"
        assert max_memory < 1.05 * reference_memory, "Memory use regression detected"
        assert abs(float(final_loss) - reference_loss) < 1e-3, "Loss regression detected"

        print("[Regression Test] VALID")


def _speed_stats(measurements: List[float]):
    """Return ``(mean, sample_std)`` of the per-epoch throughput measurements.

    The sample standard deviation needs at least two points: with a single epoch it is
    reported as 0.0 instead of dividing by zero, and an empty list yields ``(0.0, 0.0)``.
    """
    count = len(measurements)
    if count == 0:
        return 0.0, 0.0
    mean = sum(measurements) / count
    if count < 2:
        return mean, 0.0
    variance = sum((x - mean) ** 2 for x in measurements) / (count - 1)
    return mean, math.sqrt(variance)


if __name__ == "__main__":

    class OptimType(str, Enum):
        vanilla = "pytorch"
        oss = "oss"
        oss_sdp = "oss_sdp"
        everyone = "everyone"

    parser = argparse.ArgumentParser(
        description="Benchmark the optimizer state sharding, on a typical computer vision workload"
    )
    parser.add_argument("--world_size", action="store", default=2, type=int)
    parser.add_argument("--epochs", action="store", default=10, type=int)
    parser.add_argument("--batch_size", action="store", default=32, type=int)
    parser.add_argument("--data_size", action="store", default=512, type=int)
    parser.add_argument("--check_regression", action="store_true", default=False)
    parser.add_argument("--reference_speed", action="store", default=29.7, type=float)
    parser.add_argument("--reference_memory", action="store", default=4475, type=float)
    parser.add_argument("--reference_loss", action="store", default=0.866, type=float)
    parser.add_argument(
        "--optim_type", type=OptimType, choices=[o.value for o in OptimType], default=OptimType.everyone
    )
    parser.add_argument("--gloo", action="store_true", default=False)

    args = parser.parse_args()
    print(f"Benchmark arguments: {args}")

    # NCCL requires CUDA: use gloo when it is explicitly requested or when no GPU is visible.
    # (The previous `not gloo or not cuda` condition selected NCCL in exactly those cases.)
    backend = "nccl" if not args.gloo and torch.cuda.is_available() else "gloo"

    # Leading positional arguments shared by every run, matching
    # train(rank, world_size, num_epochs, batch_size, data_size, backend, ...).
    common_args = (args.world_size, args.epochs, args.batch_size, args.data_size, backend)

    def spawn_benchmark(title, extra_args):
        """Print ``title``, then run ``train`` on ``args.world_size`` processes with ``common_args + extra_args``."""
        print(f"\n{title}")
        mp.spawn(train, args=common_args + extra_args, nprocs=args.world_size, join=True)

    run_all = args.optim_type == OptimType.everyone

    if run_all or args.optim_type == OptimType.vanilla:
        # use_oss=False, use_sdp=False, no regression check
        spawn_benchmark("Benchmark vanilla optimizer", (False, False, False))

    if run_all or args.optim_type == OptimType.oss:
        # use_oss=True, use_sdp=False, then the regression-check settings
        spawn_benchmark(
            "Benchmark OSS",
            (
                True,
                False,
                args.check_regression,
                args.reference_speed,
                args.reference_memory,
                args.reference_loss,
            ),
        )

    if run_all or args.optim_type == OptimType.oss_sdp:
        # use_oss=True, use_sdp=True, no regression check
        spawn_benchmark("Benchmark OSS DDP", (True, True, False))