# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.


import argparse
import math
import os
import time
from typing import Any, List, Union, cast

import torch
import torch.distributed as dist
import torch.multiprocessing as mp
import torch.nn as nn
from torch.utils.data import DataLoader
from torchvision.datasets import FakeData
from torchvision.models import resnet101
from torchvision.transforms import ToTensor

from fairscale.optim.oss import OSS

BACKEND = dist.Backend.NCCL if torch.cuda.is_available() else dist.Backend.GLOO  # type: ignore
OPTIM = torch.optim.RMSprop
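
# Rough mental model (based on fairscale's OSS design, not enforced by this script): OSS wraps a
# regular torch.optim optimizer and partitions the parameters across the ranks of the process
# group, so each rank only keeps the optimizer state (e.g. RMSprop's running averages) for its
# own shard instead of replicating it for the whole model.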


def dist_init(rank, world_size):
    os.environ["MASTER_ADDR"] = "localhost"
    os.environ["MASTER_PORT"] = "29501"
    dist.init_process_group(backend=BACKEND, rank=rank, world_size=world_size)


def train(
    rank: int,
    world_size: int,
    num_epochs: int = 10,
    batch_size: int = 32,
    data_size: int = 200,
    use_oss: bool = True,
    check_regression: bool = True,
    reference_speed: float = -1.0,
    reference_memory: float = -1.0,
):

    # DDP
    dist_init(rank, world_size)

    # Standard RN101
    model = resnet101(pretrained=False, progress=True).to(rank)
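    # ResNet-101 has on the order of 44M parameters, so the per-parameter RMSprop state is
    # large enough for sharding it across ranks to make a measurable difference.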

    # Data setup, dummy data
    def collate(inputs: List[Any]):
        return {
            "inputs": torch.stack([i[0] for i in inputs]).to(rank),
            "label": torch.stack([i[1] for i in inputs]).to(rank),
        }

    dataloader = DataLoader(
        dataset=FakeData(transform=ToTensor(), size=data_size), batch_size=batch_size, collate_fn=collate
    )
    loss_fn = nn.CrossEntropyLoss()

    # Reset the memory use counter
    torch.cuda.reset_peak_memory_stats(rank)

    # Shard the optimizer
    optimizer: Union[OSS, OPTIM] = OSS(
        params=model.parameters(), optim=OPTIM, lr=1e-4, momentum=0.9
    ) if use_oss else OPTIM(model.parameters(), lr=1e-4, momentum=0.9)
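    # Roughly speaking, when use_oss is True OSS instantiates OPTIM (RMSprop here) internally for
    # the slice of parameters owned by this rank, forwarding lr and momentum to it; the vanilla
    # branch keeps a full replica of the optimizer state on every rank for comparison.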

    # Dummy training loop
    torch.cuda.synchronize(rank)
    training_start = time.monotonic()
    model.train()

    measurements = []

    for epoch in range(num_epochs):
        epoch_start = time.monotonic()

        for batch in dataloader:

            def closure():
                model.zero_grad()
                outputs = model(batch["inputs"])
                loss = loss_fn(outputs, batch["label"])
                dist.all_reduce(loss, op=dist.ReduceOp.SUM)
                loss /= world_size
                loss.backward()
                return loss

            optimizer.step(closure)
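            # Note: OSS.step runs the wrapped optimizer on this rank's shard and then broadcasts
            # the updated parameters to the other ranks; the closure computes the loss and the
            # gradients, mirroring the closure API of vanilla torch.optim optimizers.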

        epoch_end = time.monotonic()

        if use_oss:
            # Check the checkpointing path in the case of the OSS optimizer:
            # memory usage could spill over from there.
            optimizer = cast(OSS, optimizer)
            # consolidate_state_dict() gathers the sharded optimizer state onto rank 0 and
            # involves collective communication, so every rank must call it.
            optimizer.consolidate_state_dict()
            if dist.get_rank() == 0:
                _ = optimizer.state_dict()
                print("... State dict collected")

        measurements.append(data_size / (epoch_end - epoch_start))
        if dist.get_rank() == 0:
            print(f"Epoch {epoch} - processed {measurements[-1]:.2f} img per sec")

    torch.cuda.synchronize(rank)
    training_stop = time.monotonic()
    img_per_sec = data_size * num_epochs / (training_stop - training_start)
    max_memory = torch.cuda.max_memory_allocated(rank) / 2 ** 20

    print(f"[{dist.get_rank()}] : Training done. {img_per_sec:.2f} img per sec overall")
    print(f"[{dist.get_rank()}] : Peak memory {max_memory:.1f}MiB")

    # Compute the mean and standard deviation of the per-epoch throughput
    mean = sum(measurements) / len(measurements)
    diff = map(lambda x: pow(x - mean, 2.0), measurements)
    std = math.sqrt(sum(diff) / (len(measurements) - 1))
    print(f"[{dist.get_rank()}] : Mean speed: {mean:.2f} +/- {std:.2f}")

    if use_oss and check_regression and dist.get_rank() == 0:
        # Allow for run-to-run noise: speed within 3 sigma of the reference, memory within 5%
        assert (mean + 3.0 * std) > reference_speed, "Speed regression detected"
        assert max_memory < 1.05 * reference_memory, "Memory use regression detected"
        print("[Regression Test] VALID")


if __name__ == "__main__":

    parser = argparse.ArgumentParser(
        description="Benchmark the optimizer state sharding, on a typical computer vision workload"
    )
    parser.add_argument("--world_size", action="store", default=2, type=int)
    parser.add_argument("--epochs", action="store", default=10, type=int)
    parser.add_argument("--batch_size", action="store", default=32, type=int)
    parser.add_argument("--data_size", action="store", default=512, type=int)
    parser.add_argument("--check_regression", action="store", default=True, type=bool)
    parser.add_argument("--reference_speed", action="store", default=39.82, type=float)
    parser.add_argument("--reference_memory", action="store", default=4475, type=float)

    args = parser.parse_args()
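
    # Illustrative invocation (values are just examples):
    #   python oss.py --world_size 2 --epochs 10 --batch_size 32 --data_size 512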

    print("\nBenchmark vanilla optimizer")
    mp.spawn(
        train,
        args=(args.world_size, args.epochs, args.batch_size, args.data_size, False, False),
        nprocs=args.world_size,
        join=True,
    )

    print("\nBenchmark OSS")
    mp.spawn(
        train,
        args=(
            args.world_size,
            args.epochs,
            args.batch_size,
            args.data_size,
            True,
            args.check_regression,
            args.reference_speed,
            args.reference_memory,
        ),
        nprocs=args.world_size,
        join=True,
    )