"src/diffusers/schedulers/scheduling_dpm_cogvideox.py" did not exist on "c812d97d5b87e4baa680a63b1d39cbe630544674"
oss.py 12 KB
Newer Older
1
2
3
4
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
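#
# Benchmark for fairscale's optimizer state sharding (OSS) and ShardedDDP:
# a torchvision model is trained on MNIST under three data-parallel setups
# (vanilla PyTorch DDP, DDP + OSS, ShardedDDP + OSS) and throughput, peak
# memory and final loss are reported, with optional regression checks.
# Example run (flag names come from the argparse definitions below):
#   python oss.py --optim_type oss_ddp --world_size 2 --epochs 5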


import argparse
from enum import Enum
import importlib
import logging
import math
import shutil
import tempfile
import time
from typing import Any, List, Optional, cast

import numpy as np
import torch
import torch.autograd.profiler as profiler
import torch.distributed as dist
import torch.multiprocessing as mp
import torch.nn as nn
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.utils.data import BatchSampler, DataLoader, Sampler
from torch.utils.data.distributed import DistributedSampler
from torchvision.datasets import MNIST
from torchvision.transforms import ToTensor

from fairscale.nn.data_parallel import ShardedDataParallel as ShardedDDP
from fairscale.optim import OSS

OPTIM = torch.optim.RMSprop
TEMPDIR = tempfile.gettempdir()


def dist_init(rank, world_size, backend):
    logging.info(f"Using backend: {backend}")
    dist.init_process_group(backend=backend, init_method="tcp://localhost:29501", rank=rank, world_size=world_size)


def get_problem(rank, world_size, batch_size, device, model_name: str):
    # Select the desired model on the fly
    logging.info(f"Using {model_name} for benchmarking")
    model = getattr(importlib.import_module("torchvision.models"), model_name)(pretrained=False).to(device)

    # Data setup: duplicate MNIST's single grey channel to get pseudo-color inputs
    def collate(inputs: List[Any]):
        return {
            "inputs": torch.stack([i[0] for i in inputs]).repeat(1, 3, 1, 1).to(device),
            "label": torch.tensor([i[1] for i in inputs]).to(device),
        }

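    # Each rank samples a disjoint shard of MNIST through DistributedSampler;
    # drop_last keeps the number of batches identical across ranks so they stay in lockstep.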
    dataset = MNIST(transform=ToTensor(), download=False, root=TEMPDIR)
    sampler: Sampler = DistributedSampler(dataset, num_replicas=world_size, rank=rank)
    batch_sampler = BatchSampler(sampler, batch_size, drop_last=True)
    dataloader = DataLoader(dataset=dataset, batch_sampler=batch_sampler, collate_fn=collate)

    loss_fn = nn.CrossEntropyLoss()

    return model, dataloader, loss_fn


class OptimType(str, Enum):
    vanilla = "pytorch"
    oss_ddp = "oss_ddp"
    oss_sharded_ddp = "oss_sharded_ddp"
    everyone = "everyone"


def train(
    rank: int,
    args: argparse.Namespace,
    backend: str = "gloo",
    optim_type: OptimType = OptimType.vanilla,
    check_regression: bool = True,
):
    logging.basicConfig(level=logging.INFO if not args.debug else logging.DEBUG)

    # DDP
    dist_init(rank=rank, world_size=args.world_size, backend=backend)

    # Setup
    if not args.cpu:
        torch.cuda.set_device(rank)
        torch.cuda.manual_seed(0)
    torch.manual_seed(0)  # also sets the cuda seed
    np.random.seed(0)

    if backend == "nccl":
        torch.backends.cudnn.deterministic = True
        torch.backends.cudnn.benchmark = False

    device = torch.device("cpu") if args.cpu else torch.device(rank)
    model, dataloader, loss_fn = get_problem(rank, args.world_size, args.batch_size, device, args.torchvision_model)

    # Shard the optimizer
    optimizer: Optional[torch.optim.Optimizer] = None
    model = cast(nn.Module, model)

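    # Three configurations are compared: ShardedDDP wraps the model and owns a sharded
    # optimizer, while plain DDP is paired either with OSS (optimizer state partitioned
    # across ranks, each rank updating only its own shard) or with the vanilla optimizer.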
    if optim_type == OptimType.oss_sharded_ddp:
        model = ShardedDDP(
            model,
            optimizer=OPTIM,
            optimizer_params={"lr": 1e-4, "momentum": 0.9},
            world_size=args.world_size,
            broadcast_buffers=True,
        )
        optimizer = model.sharded_optimizer

    else:
        model = DDP(model, device_ids=[rank], find_unused_parameters=False)  # type: ignore
        optimizer = (
            OSS(params=model.parameters(), optim=OPTIM, lr=1e-4, momentum=0.9)
            if optim_type == OptimType.oss_ddp
            else OPTIM(model.parameters(), lr=1e-4, momentum=0.9)
        )
    optimizer = cast(torch.optim.Optimizer, optimizer)

    # Reset the memory use counter
    if not args.cpu:
        torch.cuda.reset_peak_memory_stats(rank)
        torch.cuda.synchronize(rank)

    # Standard training loop
    training_start = time.monotonic()
    model.train()

    measurements = []
    final_loss: Optional[float] = -1.0
    need_profiling = args.profile

    for epoch in range(args.epochs):
        n_items = 0
        epoch_runtime = 0.0

        for batch in dataloader:
            batch_start = time.monotonic()

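            # The forward/backward pass is wrapped in a closure so that
            # optimizer.step(closure) runs it and returns the loss, which also
            # lets the profiler block below capture the whole batch in one step.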
            def closure():
                model.zero_grad()
                if args.debug and rank == 0 and next(model.parameters()).grad is not None:
                    logging.debug(
                        "\nbefore:  param {} -- grad {}".format(
                            next(model.parameters()).norm().item(), next(model.parameters()).grad.norm().item()
                        )
                    )
                if not args.cpu and args.amp:
                    # Automatically computes the FW pass in half precision
                    with torch.cuda.amp.autocast():
                        outputs = model(batch["inputs"])
                        loss = loss_fn(outputs, batch["label"])
                else:
                    outputs = model(batch["inputs"])
                    loss = loss_fn(outputs, batch["label"])

                loss.backward()

                if optim_type == OptimType.oss_sharded_ddp:
                    # With ShardedDDP the gradient reduction is triggered explicitly after the backward pass
                    model.reduce()

                if args.debug and rank == 0 and next(model.parameters()).grad is not None:
                    logging.debug(
                        "after BW: param {} -- grad {}".format(
                            next(model.parameters()).norm().item(), next(model.parameters()).grad.norm().item()
                        )
                    )

                return loss

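            # Profile a single optimizer step per configuration and export a chrome
            # trace (loadable in chrome://tracing) from rank 0.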
            if need_profiling and not args.cpu:
                logging.info("Profiling the run")
                with profiler.profile(use_cuda=True, record_shapes=True, profile_memory=True) as prof:  # type: ignore
                    with profiler.record_function("batch"):
                        final_loss = optimizer.step(closure)
                        logging.info("profiling done")

                if rank == 0:
                    prof.export_chrome_trace(f"{optim_type}_trace.json")

                need_profiling = False  # only profile once

            else:
                final_loss = optimizer.step(closure)

            if args.debug and rank == 0:
                logging.debug("buffer: {}".format(next(model.buffers()).norm().item()))
                logging.debug(
                    "after update: param {} -- grad {}".format(
                        next(model.parameters()).norm().item(), next(model.parameters()).grad.norm().item()
                    )
                )

            n_items += args.batch_size

            batch_end = time.monotonic()
            epoch_runtime += batch_end - batch_start

        if optim_type == OptimType.oss_ddp or optim_type == OptimType.oss_sharded_ddp:
            # Exercise checkpointing with the OSS optimizer: consolidating the
            # sharded state is where memory use could spill over
            optimizer = cast(OSS, optimizer)
            optimizer.consolidate_state_dict()
            if dist.get_rank() == 0:
                _ = optimizer.state_dict()
                logging.info("... State dict collected")

        measurements.append(n_items / epoch_runtime)
        if dist.get_rank() == 0:
            logging.info(f"Epoch {epoch} - processed {measurements[-1]:.2f} img per sec. Loss {final_loss:.3f}")

    # Report the performance and memory statistics for this rank
    max_memory = -1.0
    if not args.cpu:
        torch.cuda.synchronize(rank)
        max_memory = torch.cuda.max_memory_allocated(rank) / 2 ** 20

    training_stop = time.monotonic()
    img_per_sec = n_items / (training_stop - training_start) * args.epochs

    logging.info(f"[{dist.get_rank()}] : Training done. {img_per_sec:.2f} img per sec inc. checkpoint")
    logging.info(f"[{dist.get_rank()}] : Peak memory {max_memory:.1f}MiB")

    # Compute the mean and standard deviation of the measured throughput
    mean = sum(measurements) / len(measurements)
    diff = map(lambda x: pow(x - mean, 2.0), measurements)
    std = math.sqrt(sum(diff) / (len(measurements) - 1)) if args.epochs > 2 else -1
    logging.info(f"[{dist.get_rank()}] : Mean speed: {mean:.2f} +/- {std:.2f}")

    if check_regression and dist.get_rank() == 0:
        assert (mean + 3.0 * std) > args.reference_speed, "Speed regression detected"
        assert max_memory < 1.05 * args.reference_memory, "Memory use regression detected"
        assert abs(cast(float, final_loss) - args.reference_loss) < 1e-3, "Loss regression detected"

        logging.info("[Regression Test] VALID")

    dist.destroy_process_group()  # type: ignore


if __name__ == "__main__":
    parser = argparse.ArgumentParser(
        description="Benchmark the optimizer state sharding, on a typical computer vision workload"
    )
    parser.add_argument("--world_size", action="store", default=2, type=int)
    parser.add_argument("--epochs", action="store", default=10, type=int)
    parser.add_argument("--batch_size", action="store", default=256, type=int)
    parser.add_argument("--check_regression", action="store_true", default=False)
    parser.add_argument("--reference_speed", action="store", default=1430, type=float)
    parser.add_argument("--reference_memory", action="store", default=1220, type=float)
    parser.add_argument("--reference_loss", action="store", default=0.006, type=float)
    parser.add_argument(
        "--optim_type", type=OptimType, choices=[o.value for o in OptimType], default=OptimType.everyone
    )
    parser.add_argument("--gloo", action="store_true", default=False)
    parser.add_argument("--profile", action="store_true", default=False)
    parser.add_argument("--cpu", action="store_true", default=False)
    parser.add_argument("--torchvision_model", type=str, help="Any torchvision model name (str)", default="resnet101")
    parser.add_argument("--debug", action="store_true", default=False, help="Display additional debug information")
    parser.add_argument("--amp", action="store_true", default=False, help="Activate torch AMP")

    args = parser.parse_args()

    logging.basicConfig(level=logging.INFO if not args.debug else logging.DEBUG)
    logging.info(f"Benchmark arguments: {args}")

    # NCCL requires CUDA; fall back to gloo when requested, on CPU, or when no GPU is available
    backend = "nccl" if not args.gloo and not args.cpu and torch.cuda.is_available() else "gloo"

    # Download dataset once for all processes
    dataset, tentatives = None, 0
    while dataset is None and tentatives < 5:
        try:
            dataset = MNIST(transform=None, download=True, root=TEMPDIR)
        except (RuntimeError, EOFError) as e:
            if isinstance(e, RuntimeError):
                # Corrupted data, erase and restart
                shutil.rmtree(TEMPDIR + "/MNIST")

            logging.warning("Failed loading dataset: ", e)
            tentatives += 1

    if dataset is None:
        logging.error("Could not download MNIST dataset")
        exit(-1)
    else:
        logging.info("Dataset downloaded")

    # Benchmark the different configurations, via multiple processes
    if args.optim_type == OptimType.vanilla or args.optim_type == OptimType.everyone:
        logging.info("\n*** Benchmark vanilla optimizer")
        mp.spawn(
            train,
            args=(args, backend, OptimType.vanilla, False,),  # no regression check
            nprocs=args.world_size,
            join=True,
        )

    if args.optim_type == OptimType.oss_ddp or args.optim_type == OptimType.everyone:
        logging.info("\n*** Benchmark OSS with DDP")
        mp.spawn(
            train, args=(args, backend, OptimType.oss_ddp, args.check_regression), nprocs=args.world_size, join=True,
        )

    if args.optim_type == OptimType.oss_sharded_ddp or args.optim_type == OptimType.everyone:
        logging.info("\n*** Benchmark OSS with ShardedDDP")
        mp.spawn(
            train,
            args=(
                args,
                backend,
                OptimType.oss_sharded_ddp,
                False,
            ),  # FIXME: @lefaudeux - SDP should give the same results
            nprocs=args.world_size,
            join=True,
        )