# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.

import argparse
from collections import defaultdict
from functools import reduce
import gc
import logging
import math
import operator
import os
import pprint
import time

from benchmark_dataset import BenchmarkLMDataset, collate_sentences_lm
import datasets
import models
import numpy as np
import torch
from torch.distributed import rpc
import torch.multiprocessing as mp
import torch.nn as nn
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.utils.data import DataLoader

from fairscale.nn import Pipe
from fairscale.nn.model_parallel import initialize_model_parallel
from fairscale.nn.model_parallel.initialize import get_data_parallel_group, get_pipeline_parallel_group
from fairscale.nn.pipe import LazyModule, pipe
from fairscale.optim import GradScaler
from fairscale.optim.oss import OSS
from fairscale.utils.testing import dist_init, get_worker_map

try:
    from fairscale.optim import Adam  # type: ignore

    can_benchmark = True
except ImportError:
    from torch.optim import Adam  # type: ignore

    can_benchmark = False


def init_random_seed(seed: int):

    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    np.random.seed(seed)


def get_model_and_optimizer(args, device, config):
    """Return instantiated model and optimizer function."""

    if args.model_name == "lm":
        model = get_lm_model(args, device, config)

    lr = config["lr"]

    def make_adam(params):
        if args.ddp_zero:
            return OSS(params=params, optim=Adam, group=get_data_parallel_group(), lr=lr)
        else:
            return Adam(params, lr=lr)

    optimizer = make_adam
    return model, optimizer
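# Note: the "optimizer" returned above is a factory, not an optimizer instance.
# train() later calls it as optimizer(model.parameters()), presumably so that the
# parameters are bound only after the model has been wrapped (Pipe/DDP) and placed
# on its devices.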


def get_lm_model(args, device, config):
    """Get language model (based on GPT-2) used for sequence prediction."""

    ninp = config["ninp"]
    nhead = config["nhead"]
    initrange = config["initrange"]
    dropout = config["dropout"]
    vocab_size = config["vocab_size"]
    nhid = config["nhid"]
    ndecoder = args.num_decoder_layers

    if args.lazy_construction:
        layers = [
            LazyModule(lambda: models.EmbeddingLayer(vocab_size, ninp, initrange)),
            LazyModule(lambda: models.PositionalEncodingLayer(ninp, dropout)),
        ]
        for _ in range(ndecoder):
            layers.append(LazyModule(lambda: models.TransformerDecoderLayer(ninp, nhead, nhid, dropout)))

        layers.append(LazyModule(lambda: models.LinearLayer(ninp, vocab_size, initrange)))
        model = layers
    else:
        model = models.TransformerLMSequntial(vocab_size, ninp, nhead, nhid, dropout, initrange, ndecoder).to(device)

    return model
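# With --lazy-construction the return value is a plain list of LazyModule wrappers
# that are materialized later, stage by stage, when Pipe partitions the model;
# otherwise it is the fully constructed TransformerLMSequntial already moved to `device`.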


def get_tensors_by_size_bucket():

    size_buckets = defaultdict(int)
    for obj in gc.get_objects():
        if not isinstance(obj, torch.Tensor):
            continue
        if obj.device.type == "cuda":
            size_buckets[(*obj.size(),) + (obj.element_size(),)] += 1

    return size_buckets
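# Illustrative sketch (hypothetical values): the mapping returned above keys each
# live CUDA tensor by (*shape, element_size) and counts occurrences, e.g.
#   {(1024, 2048, 4): 12, (8, 33, 2048, 2): 3}
# would mean 12 float32 tensors of shape (1024, 2048) and 3 fp16 tensors of
# shape (8, 33, 2048).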


def dump_size_buckets(size_buckets, prefix=""):

    total = 0
    for key, value in size_buckets.items():
        this = reduce(operator.mul, key) * value
        total += this
        print(prefix + f"{key} : {value}, {this}")

    print(prefix + f"total = {total}")


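# Debugging state for the size-bucket helpers below: check_size_buckets() compares
# the current set of live CUDA tensors against the previous call so that tensors
# accumulating across iterations (a likely leak) stand out.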
last_size_buckets = None
once = True


def safe_rank():
    try:
        return torch.distributed.get_rank()
    except AssertionError:
        return 0


def check_size_buckets():
    global last_size_buckets
    global once
    size_buckets = get_tensors_by_size_bucket()
    if last_size_buckets is not None:
        if size_buckets != last_size_buckets:
            print(f"difference in outstanding tensors: {safe_rank()}")
            dump_size_buckets(last_size_buckets, "old: ")
            dump_size_buckets(size_buckets, "new: ")
        if once:
            print(f"dumping buckets for: {safe_rank()}")
            dump_size_buckets(last_size_buckets, "old: ")
            dump_size_buckets(size_buckets, "new: ")
            once = False
    else:
        print(f"size buckets none on {safe_rank()}")
    last_size_buckets = size_buckets


def dump_cuda_tensors():
    print("dumping cuda tensors...")

    size_buckets = defaultdict(int)
    for obj in gc.get_objects():
        if not isinstance(obj, torch.Tensor):
            continue
        if obj.device.type == "cuda":
            size_buckets[(*obj.size(),) + (obj.element_size(),)] += 1

    print("outstanding cuda tensors:")
    total = 0
    for key, value in size_buckets.items():
        this = reduce(operator.mul, key) * value
        total += this
        print(f"{key} : {value}, {this}")
    print(f"total size = {total}")
    pprint.pprint(torch.cuda.memory_stats())


def log_number_of_parameters(model):

    num_params = reduce(operator.add, (reduce(operator.mul, x.size()) for x in model.parameters()))
    if model.group:
        total = torch.Tensor([num_params])
        if torch.cuda.is_available():
            total = total.cuda()
        torch.distributed.all_reduce(total, group=model.group)
        logging.info(
            f"training model, #params = {num_params}, group: {model.group.rank()}, grank:"
            f" {torch.distributed.get_rank()}, sizes {model.group.size()}"
        )
        torch.distributed.barrier()
        if model.group.rank() == 0:
            logging.info(f"total #params = {total.item()}")
    else:
        logging.info(f"training model, #params = {num_params}")


def get_device(model, index):
    if isinstance(model, DDP):
        model = model.module

    if not torch.cuda.is_available():
        return torch.device("cpu")
    if model.devices:
        return model.devices[index]
    else:
        return torch.cuda.current_device()


def get_fake_dataloader(lm_dataloader_len):
    fake_input = {"input": torch.zeros(args.batch_size)}

    class FakeDataset:
        def __getitem__(self, index):
            return fake_input

        def __len__(self):
            return lm_dataloader_len

    return FakeDataset()
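# Hedged note: the fake dataset above returns the same zero-filled "input" for every
# index and reads batch_size from the module-level `args`, so it only works when this
# file is run as a script; its purpose is to keep the middle pipeline ranks iterating
# in lockstep with the real dataloader on the first and last ranks.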


def train(data_config, model, benchmark_config, args):
    lm_dataloader = data_config["data"]
    criterion = benchmark_config["criterion"]
    vocab_size = benchmark_config["vocab_size"]
    optimizer = data_config["optimizer"]

    model.train()
    log_number_of_parameters(model)

    total_loss = 0.0
    start_time = time.time()
    word_counter = 0

    optimizer = optimizer(model.parameters())

    pipe_group = model.group

    if args.ddp_zero:
        model = DDP(
            model,
            device_ids=[torch.cuda.current_device()],
            process_group=get_data_parallel_group(),
            find_unused_parameters=False,
        )

    # TODO(anj-s): Avoid sending fake data to all replicas except the first and last one.
    if pipe_group and pipe_group.rank() != 0 and pipe_group.rank() != (pipe_group.size() - 1):
        lm_dataloader = get_fake_dataloader(len(lm_dataloader))

    total_tokens = 0
    total_tokens_per_log_interval = 0
    for i, batch in enumerate(lm_dataloader):
        if args.max_batch and i > args.max_batch:
            break
        total_tokens += batch["input"].numel()

        optimizer.zero_grad()
        try:
            if (pipe_group is None or pipe_group.rank() == 0) and not args.ddp_zero:
                tmp = batch["input"].to(get_device(model, 0))
                output = model(tmp)
            else:
                output = model(batch["input"])
        except Exception as e:
            raise RuntimeError(f"training failed on {torch.distributed.get_rank()}") from e

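        # Only the last pipeline stage holds the final activations, so only it can
        # compute the loss; earlier stages just propagate gradients via back_helper().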
        if pipe_group is None or pipe_group.rank() == pipe_group.size() - 1:
            target = batch["target"].to(get_device(model, -1))
            output = output.to(target.device)

            loss = criterion(output.view(-1, vocab_size), target.view(-1))
            if args.ddp_zero:
                ddp_group = get_data_parallel_group()
                torch.distributed.all_reduce(loss, op=torch.distributed.ReduceOp.SUM, group=ddp_group)
                loss /= ddp_group.size()
            loss.backward()
            del target
        else:
            if args.ddp_zero:
                model.module.back_helper(output)
            else:
                model.back_helper(output)

        del output

        torch.nn.utils.clip_grad_value_(model.parameters(), benchmark_config["clip_value"])
        optimizer.step()

        if pipe_group is None or pipe_group.rank() == pipe_group.size() - 1:
            total_loss += loss.item()
            log_interval = 1
            total_tokens_per_log_interval += batch["input"].numel()
            if i % log_interval == 0 and i > 0:
                cur_loss = total_loss / log_interval
                elapsed = time.time() - start_time
                print(
                    "| batch {:5d} | wps {:5.2f} | loss {:5.2f} | ppl {:8.2f}".format(
                        i, total_tokens_per_log_interval / elapsed, cur_loss, math.exp(cur_loss)
                    )
                )
                total_tokens_per_log_interval = 0
                total_loss = 0
                start_time = time.time()

    return total_tokens, loss.item()


# TODO(anj-s): Add an option for users to be able to benchmark evaluate.
def evaluate(eval_model, data_source, criterion, ntokens):
    eval_model.eval()
    total_loss = 0.0
    # TODO(anj-s): Move this to the benchmark config if we want to benchmark evaluation.
    bptt = 35

    def get_batch(source, i, bptt):
        seq_len = min(bptt, len(source) - 1 - i)
        data = source[i : i + seq_len]
        target = source[i + 1 : i + 1 + seq_len].view(-1)
        return data, target

    with torch.no_grad():
        for i in range(0, data_source.size(0) - 1, bptt):
            data, targets = get_batch(data_source, i, bptt)
            output = eval_model(data)
            output = output.to(targets.device)
            output_flat = output.view(-1, ntokens)
            total_loss += len(data) * criterion(output_flat, targets).item()
    return total_loss / (len(data_source) - 1)
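# Sketch of the windowing above: with bptt=35, batch i covers data_source[i : i + 35]
# with next-token targets data_source[i + 1 : i + 36], and the final window is
# truncated to whatever remains of the source.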


def get_number_of_words(data):
    return data.size()[0] * data.size()[1]


def verify_lm_run(wps):
    """Verify that words per second for a given benchmark run matches the golden data."""

    # Assert that words per second is within 3 standard deviations of the average
    # of six golden runs
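    # i.e. wps must exceed 36954.4 - 3 * 116.825 = 36603.925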
    assert wps > 36954.4 - (3 * 116.825)

    for i in range(4):
        print("Peak allocated bytes on cuda:{:d}: {:1d}".format(i, torch.cuda.memory_stats(i)["allocated_bytes.all.peak"]))

    # Assert that memory usage on each GPU is within 10% of golden run
    # Right-hand-side is golden run bytes * 110%
    for i, golden_ref in zip(range(4), [4061909504, 4050944, 10427392, 2031824896]):
        assert torch.cuda.memory_stats(i)["allocated_bytes.all.peak"] < golden_ref * 1.1


def benchmark_language_model(model_config, model, benchmark_config, args):
    ntokens, train_data, val_data, test_data = model_config["data"]
    optimizer = model_config["optimizer"]
    criterion = benchmark_config["criterion"]
    epoch = 1

    print("-" * 110)
    print("| start of epoch {:1d}".format(epoch))
    print("-" * 110)
    start_time = time.time()
    n_words, loss = train(model_config, model, benchmark_config, args)
    elapsed_time = time.time() - start_time
    wps = n_words / elapsed_time
    print("-" * 110)
    print("| end of epoch {:1d} | time: {:5.2f}s | train loss {:5.2f} ".format(epoch, elapsed_time, loss))
    print("-" * 110)

    if can_benchmark and len(model.balance) == 4:

        if args.model_name == "lm":
            verify_lm_run(wps)
        else:
            raise RuntimeError(f"Unrecognized args.model_name {args.model_name}")


def generate_balance_weighted(num_devices, num_layers, fraction=0.5):
    average_count = num_layers / num_devices
    last_layers = int(average_count * fraction)

    balance = generate_balance(num_devices - 1, num_layers - last_layers)
    balance.append(last_layers)
    return balance
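# Worked example (values chosen for illustration): generate_balance_weighted(4, 10, 0.5)
# gives the last device int(10 / 4 * 0.5) = 1 layer and splits the remaining 9 layers
# over 3 devices, i.e. [3, 3, 3, 1]; run_mp_worker() below uses fraction=0.8 so the
# final pipeline stage is assigned a lighter share than the others.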


def generate_balance(num_devices, num_layers):
    balance = []
    layers_assigned = 0
    for i in range(num_devices):
        x = (num_layers - layers_assigned) / (num_devices - i)
        if x.is_integer():
            balance.append(int(x))
            layers_assigned += x
        else:
            balance.append(math.ceil(x))
            layers_assigned += math.ceil(x)
    return balance
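# Worked example: generate_balance(4, 10) spreads 10 layers over 4 devices as evenly
# as possible, taking ceil() of the remaining average at each step: [3, 3, 2, 2].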


def get_synthetic_dataloader(args):
    """Returns dataloader for synthetic data."""

    if args.model_name == "lm":
        lm_dataset = BenchmarkLMDataset()
        lm_dataloader = DataLoader(
            lm_dataset, batch_size=args.batch_size, shuffle=True, num_workers=0, collate_fn=collate_sentences_lm
        )
        return lm_dataloader
    else:
        raise RuntimeError(f"Unrecognized args.model_name {args.model_name}")


def get_real_dataloaders(device, config):
    """Returns dataloaders for real data."""

    if args.model_name == "lm":
        data = datasets.get_wikitext2_data(device)
        ntokens, _, _, _ = data
        config["vocab_size"] = ntokens
        return data
    else:
        raise RuntimeError(f"Unrecognized args.model_name {args.model_name}")


def create_model_config(args, config=None):
    """Return a dict with the given model, dataset and optimizer."""

    device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
    if args.use_synthetic_data:
        model, optimizer = get_model_and_optimizer(args, device, config)
        dataloader = get_synthetic_dataloader(args)
        return {"model": model, "optimizer": optimizer, "data": dataloader}
    else:
        data = get_real_dataloaders(device, config)
        model, optimizer = get_model_and_optimizer(args, device, config)
        return {
            "model": model,
            "optimizer": optimizer,
            "data": data,
        }


def create_benchmark_config(model_name):
    """Return a dict with configurations required for benchmarking `model_name` model."""

    if model_name == "lm":
        return {
            "vocab_size": 10000,
            "ninp": 2048,  # embedding dimension
            "nhid": 2048,  # the dimension of the feedforward network model in nn.TransformerEncoder
            "nhead": 32,  # the number of heads in the multiheadattention models
            "dropout": 0,
            "initrange": 0.1,
            "criterion": nn.CrossEntropyLoss(),
            "lr": 0.01,  # learning rate
            "scaler": GradScaler(),
            "clip_value": 0.05,
        }
    else:
        raise RuntimeError(f"Unrecognized model name {model_name}")


def benchmark_single_process(args):
    """Benchmark a given model using a single process and multiple devices."""

    num_devices = torch.cuda.device_count() if torch.cuda.is_available() else 1
    assert num_devices > 0
    init_random_seed(0)

    benchmark_config = create_benchmark_config(args.model_name)
    model_config = create_model_config(args, config=benchmark_config)
    model = model_config["model"]

    balance = generate_balance(min(num_devices, 4), len(model))
    pipe_model = pipe.Pipe(
        model, balance, chunks=args.chunks, pipelined_backward=args.pipelined_backward, checkpoint=args.checkpoint
    )
    del model
    del model_config["model"]

    if args.use_synthetic_data:
        train(model_config, pipe_model, benchmark_config, args)
    else:
        benchmark_language_model(model_config, pipe_model, benchmark_config, args)


def run_mp_worker(args, available_workers):

    benchmark_config = create_benchmark_config(args.model_name)
    model_config = create_model_config(args, config=benchmark_config)
    model = model_config["model"]

    balance = generate_balance_weighted(get_pipeline_parallel_group().size(), len(model), 0.8)
    pipe_model = pipe.Pipe(
        model,
        balance,
        style=Pipe.AsyncSchedule,
        chunks=args.chunks,
        worker_map=get_worker_map(),
        input_device=torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu"),
        pipelined_backward=args.pipelined_backward,
        checkpoint=args.checkpoint,
        # TODO(anj-s): Do we need to comment this out? loss_fn=benchmark_config["criterion"],
    )
    if torch.cuda.is_available():
        pipe_model = pipe_model.cuda()
    if args.all_at_once and pipe_model.pipeline:
        print("running all at once")
        pipe_model.pipeline.all_at_once = True

    if args.use_synthetic_data:
        train(model_config, pipe_model, benchmark_config, args)
    else:
        benchmark_language_model(model_config, pipe_model, benchmark_config, args)


def run_worker(rank, world_size, args):
    if args.world_size != 0:
        world_size = args.world_size
    dist_init(rank + args.rank_base, world_size, hostname=args.host)
    initialize_model_parallel(1, world_size)
    init_random_seed(0)
    run_mp_worker(args, world_size)

    rpc.shutdown()
    torch.distributed.destroy_process_group()


def bench_multi_process(args, all_at_once=False):
    if args.local_world_size != 0:
        world_size = args.local_world_size
    else:
        world_size = min(torch.cuda.device_count(), 2)
    mp.spawn(run_worker, args=(world_size, args), nprocs=world_size, join=True)


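# Assumed host layout: 8 GPUs sharing 4 mlx5 InfiniBand NICs, two consecutive local
# ranks per NIC; bench_mpi() exports this mapping as UCX_NET_DEVICES for the MPI/UCX
# pipeline backend.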
best_device_map = {
    0: "mlx5_0:1",
    1: "mlx5_0:1",
    2: "mlx5_1:1",
    3: "mlx5_1:1",
    4: "mlx5_2:1",
    5: "mlx5_2:1",
    6: "mlx5_3:1",
    7: "mlx5_3:1",
}


def bench_mpi(args):
    guess_rank = int(os.environ["OMPI_COMM_WORLD_RANK"])
    world_size = int(os.environ["OMPI_COMM_WORLD_SIZE"])
    local_rank = int(os.environ["OMPI_COMM_WORLD_LOCAL_RANK"])
    os.environ["UCX_NET_DEVICES"] = best_device_map[local_rank]

    os.environ["MASTER_ADDR"] = args.host
    os.environ["MASTER_PORT"] = "10638"
    if args.socket_name:
        os.environ["GLOO_SOCKET_IFNAME"] = args.socket_name
        os.environ["TP_SOCKET_IFNAME"] = args.socket_name

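    # Two rendezvous steps: a global gloo process group on port 10638 for collectives,
    # then (below) an RPC init on port 10639 that the pipeline uses for control messages.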
    torch.distributed.init_process_group(backend="gloo", rank=guess_rank, world_size=world_size)

    os.environ["MASTER_ADDR"] = args.host
    os.environ["MASTER_PORT"] = "10639"
    init_method = f"tcp://{os.environ['MASTER_ADDR']}:{os.environ['MASTER_PORT']}"
    rank = torch.distributed.get_rank()
    world_size = torch.distributed.get_world_size()
    torch.cuda.set_device(local_rank % torch.cuda.device_count())

    rpc.init_rpc(
        f"Test{rank}",
        rank=rank,
        world_size=world_size,
        backend=rpc.BackendType.PROCESS_GROUP,
        rpc_backend_options=rpc.ProcessGroupRpcBackendOptions(rpc_timeout=20, init_method=init_method),
    )

    backends = {"model_parallel_backend": "nccl", "pipeline_backend": "mpi", "ddp_backend": "nccl"}

    if args.ddp_zero:
        initialize_model_parallel(1, 4, **backends)
    else:
        initialize_model_parallel(1, world_size, **backends)
    init_random_seed(0)

    run_mp_worker(args, world_size)

    rpc.shutdown()
    torch.distributed.destroy_process_group()


parser = argparse.ArgumentParser(description="benchmark")
parser.add_argument("--local-world-size", "-l", type=int, default=0, help="local world size")
parser.add_argument("--world-size", "-w", type=int, default=0, help="world size")
parser.add_argument("--rank-base", "-r", type=int, help="rank base", default=0)
parser.add_argument("--host", "-o", type=str, default="localhost", help="hostname")
parser.add_argument("--no-mpi", action="store_true", default=False, help="disable mpi")
parser.add_argument("--chunks", type=int, default=1, help="number of microbatches per batch")
parser.add_argument("--batch-size", type=int, default=8, help="size of a batch")
parser.add_argument("--all-at-once", action="store_true", default=False, help="do backward pass on whole batch at once")
parser.add_argument("--max-batch", type=int, default=4, help="Max number of batches")
parser.add_argument("--socket-name", type=str, default=None, help="socket ifname for gloo/tp")
parser.add_argument("--num-decoder-layers", type=int, default=10, help="Number of decoder layers in the model")
parser.add_argument("--ddp-zero", action="store_true", default=False, help="use DDP data parallelism with the sharded (ZeRO/OSS) optimizer")
parser.add_argument(
    "--lazy-construction", action="store_true", default=False, help="Construct the model lazily, one LazyModule per layer"
)
parser.add_argument(
    "--checkpoint", default="never", choices=["always", "except_last", "never"], help="Checkpointing strategy for pipe"
)
parser.add_argument(
    "--pipelined-backward", dest="pipelined_backward", action="store_true", help="Pipelined backward pass"
)
parser.add_argument(
    "--no-pipelined-backward", dest="pipelined_backward", action="store_false", help="Pipelined backward pass"
)
parser.add_argument("--use_synthetic_data", default=True, help="Uses synthetic data for a sample training run.")
parser.add_argument(
    # TODO(anj-s): In the process of adding more models and hence the requirement for a flag.
    "--model_name",
    default="lm",
    help="Language Model (LM) used to benchmark nn.pipe.",
)
parser.set_defaults(pipelined_backward=True)

if __name__ == "__main__":
    args = parser.parse_args()
    # TODO(anj-s): Add support for multiprocess benchmarking.
    if args.no_mpi or "OMPI_COMM_WORLD_RANK" not in os.environ:
        print(f"Running benchmark with args: {args}")
        benchmark_single_process(args)
    else:
        if os.environ["OMPI_COMM_WORLD_RANK"] == "0":
            print(f"Running benchmark with args: {args}")
        bench_mpi(args)