# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.

import argparse
from collections import defaultdict
from functools import reduce
import gc
import logging
import math
import operator
import pprint
import time

from datasets.wikitext2_data import get_real_dataloaders as get_real_wikitext2_dataloaders
from datasets.wikitext2_data import get_synthetic_dataloaders as get_synthetic_wikitext2_dataloaders
from golden_configs import lm_wikitext2
from models import transformer_lm
import numpy as np
import torch
import torch.distributed as dist
from torch.distributed import rpc
import torch.multiprocessing as mp
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.optim import Adam

from fairscale.nn import Pipe
from fairscale.nn.model_parallel import initialize_model_parallel
from fairscale.nn.model_parallel.initialize import get_data_parallel_group, get_pipeline_parallel_group
from fairscale.nn.pipe import LazyModule, MultiProcessPipe
from fairscale.optim.oss import OSS
from fairscale.utils.testing import dist_init, get_worker_map

MPI_PORT = 29500
RPC_PORT = 29501
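# Both ports are bound on localhost: MPI_PORT for the torch.distributed process group, RPC_PORT for torch.distributed.rpc.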


def init_random_seed(seed: int):

    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    np.random.seed(seed)


def get_model_and_optimizer(args, device, benchmark_config, model_config):
    """Return instantiated model and optimizer function."""

    if args.model_name == "lm":
        model = get_lm_model(args, device, model_config)

    lr = benchmark_config["lr"]

    def make_adam(params):
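        # OSS shards the Adam state across the data-parallel group (ZeRO-style); plain Adam keeps full state on every rank.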
        if args.ddp_zero:
            return OSS(params=params, optim=Adam, group=get_data_parallel_group(), lr=lr)
        else:
            return Adam(params, lr=lr)

    optimizer = make_adam
    return model, optimizer


def get_lm_model(args, device, config):
    """Get language model (based on GPT-2) used for sequence prediction."""

    ninp = config["ninp"]
    nhead = config["nhead"]
    initrange = config["initrange"]
    dropout = config["dropout"]
    vocab_size = config["vocab_size"]
    nhid = config["nhid"]
    ndecoder = config["num_decoder_layers"]

    if args.lazy_construction:
        layers = [
            LazyModule(lambda: transformer_lm.EmbeddingLayer(vocab_size, ninp, initrange)),
            LazyModule(lambda: transformer_lm.PositionalEncodingLayer(ninp, dropout)),
        ]
        for _ in range(ndecoder):
            layers.append(LazyModule(lambda: transformer_lm.TransformerDecoderLayer(ninp, nhead, nhid, dropout)))

        layers.append(LazyModule(lambda: transformer_lm.LinearLayer(ninp, vocab_size, initrange)))
        model = layers
    else:
        model = transformer_lm.TransformerLM(vocab_size, ninp, nhead, nhid, dropout, initrange, ndecoder).to(device)

    return model


def get_tensors_by_size_bucket():

    size_buckets = defaultdict(int)
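    # Bucket live CUDA tensors by (shape..., element_size); diffing successive snapshots surfaces leaked tensors.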
    for obj in gc.get_objects():
        if not isinstance(obj, torch.Tensor):
            continue
        if obj.device.type == "cuda":
            size_buckets[(*obj.size(),) + (obj.element_size(),)] += 1

    return size_buckets


def dump_size_buckets(size_buckets, prefix=""):

    total = 0
    for key, value in size_buckets.items():
        this = reduce(operator.mul, key) * value
        total += this
        print(prefix + f"{key} : {value}, {this}")

    print(prefix + f"total = {total}")


last_size_buckets = None
once = True


def safe_rank():
    try:
        return torch.distributed.get_rank()
    except AssertionError:
        return 0


def check_size_buckets():
    global last_size_buckets
    global once
    size_buckets = get_tensors_by_size_bucket()
    if last_size_buckets is not None:
        if size_buckets != last_size_buckets:
            print(f"difference in outstanding tensors: {safe_rank()}")
            dump_size_buckets(last_size_buckets, "old: ")
            dump_size_buckets(size_buckets, "new: ")
        if once:
            print(f"dumping buckets for: {safe_rank()}")
            dump_size_buckets(last_size_buckets, "old: ")
            dump_size_buckets(size_buckets, "new: ")
            once = False
    else:
        print(f"size buckets none on {safe_rank()}")
    last_size_buckets = size_buckets


def dump_cuda_tensors():
    print("dumping cuda tensors...")

    size_buckets = defaultdict(int)
    for obj in gc.get_objects():
        if not isinstance(obj, torch.Tensor):
            continue
        if obj.device.type == "cuda":
            size_buckets[(*obj.size(),) + (obj.element_size(),)] += 1

    print(f"outstanding cuda tensors:")
    total = 0
    for key, value in size_buckets.items():
        this = reduce(operator.mul, key) * value
        total += this
        print(f"{key} : {value}, {this}")
    print(f"total size = {total}")
    pprint.pprint(torch.cuda.memory_stats())


def log_number_of_parameters(model):

    num_params = reduce(operator.add, (reduce(operator.mul, x.size()) for x in model.parameters()))
    if hasattr(model, "group"):
        total = torch.Tensor([num_params])
        if torch.cuda.is_available():
            total = total.cuda()
        torch.distributed.all_reduce(total, group=model.group)
        logging.info(
            f"training model, #params = {num_params}, group: {model.group.rank()}, grank:"
            f" {torch.distributed.get_rank()}, sizes {model.group.size()}"
        )
        torch.distributed.barrier()
        if model.group.rank() == 0:
            logging.info(f"total #params = {total.item()}")
    else:
        logging.info(f"training model, #params = {num_params}")


def get_device(model, index):
    if isinstance(model, DDP):
        model = model.module

    if not torch.cuda.is_available():
        return torch.device("cpu")
    if hasattr(model, "devices"):
        return model.devices[index]
    else:
        return torch.cuda.current_device()


def get_fake_dataloader(lm_dataloader_len, args):
    fake_input = {"input": torch.zeros(args.batch_size)}

    class FakeDataset:
        def __getitem__(self, index):
            return fake_input

        def __len__(self):
            return lm_dataloader_len

    return FakeDataset()


def train(model_config, model, benchmark_config, model_specs, args):
    lm_dataloader, _, _ = model_config["data"]
    criterion = benchmark_config["criterion"]
    vocab_size = model_specs["vocab_size"]
    optimizer = model_config["optimizer"]

    model.train()
    log_number_of_parameters(model)

    total_loss = 0.0
    word_counter = 0

    optimizer = optimizer(model.parameters())

    pipe_group = model.group if hasattr(model, "group") else None

    if args.ddp_zero:
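        # Each pipeline stage is additionally replicated across the data-parallel group; gradients sync through DDP.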
        model = DDP(
            model,
            device_ids=[torch.cuda.current_device()],
            process_group=get_data_parallel_group(),
            find_unused_parameters=False,
        )

    # TODO(anj-s): Avoid sending fake data to all replicas except the first and last one.
    if pipe_group and pipe_group.rank() != 0 and pipe_group.rank() != (pipe_group.size() - 1):
        lm_dataloader, _, _ = get_synthetic_dataloaders(args, benchmark_config, model_specs)

    total_tokens = 0
    total_tokens_per_log_interval = 0
    bptt = 2
    start_time = time.time()
    epoch_start_time = 0.0

    def get_batch(source):
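        # A batch carries seq_len + 1 tokens; inputs are tokens [0, seq_len) and targets are the same tokens shifted by one.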
        seq_len = len(source) - 1
        data = source[0:seq_len]
        target = source[1 : 1 + seq_len]
        return data, target

    for i, batch in enumerate(lm_dataloader):
        if i == 1:
            epoch_start_time = time.time()

        source, target = get_batch(batch)
        if args.max_batch and i > args.max_batch:
            break

        if i > 0:
            total_tokens += source.numel()

        optimizer.zero_grad()
        try:
            if (pipe_group is None or pipe_group.rank() == 0) and not args.ddp_zero:
                tmp = source.to(get_device(model, 0))
                output = model(tmp)
            else:
                output = model(source)
        except Exception as e:
            raise RuntimeError(f"training failed on {torch.distributed.get_rank()}") from e

        if pipe_group is None or pipe_group.rank() == pipe_group.size() - 1:
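            # Only the last pipeline stage holds the final output, so only it computes the loss and kicks off backward.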
            target = target.to(get_device(model, -1))
            output = output.to(target.device)

            loss = criterion(output.view(-1, vocab_size), target.view(-1))
            if args.ddp_zero:
                ddp_group = get_data_parallel_group()
                torch.distributed.all_reduce(loss, op=torch.distributed.ReduceOp.SUM, group=ddp_group)
                loss /= ddp_group.size()
            loss.backward()
            del target
        else:
            if args.ddp_zero:
                model.module.back_helper(output)
            else:
                model.back_helper(output)

        del output

        torch.nn.utils.clip_grad_value_(model.parameters(), model_specs["clip_value"])
        optimizer.step()

        if pipe_group is None or pipe_group.rank() == pipe_group.size() - 1:
            total_loss += loss.item()
            log_interval = 1
            total_tokens_per_log_interval += source.numel()
            if i % log_interval == 0 and i > 0:
                cur_loss = total_loss / log_interval
                elapsed = time.time() - start_time
                if not args.multiprocess or dist.get_rank() == dist.get_world_size() - 1:
                    print(
                        "| batch {:5d} | wps {:5.2f} | loss {:5.2f} | ppl {:8.2f}".format(
                            i, total_tokens_per_log_interval / elapsed, cur_loss, math.exp(cur_loss)
                        )
                    )
                total_tokens_per_log_interval = 0
                total_loss = 0
                start_time = time.time()

    if epoch_start_time != 0:
        wps = total_tokens / (time.time() - epoch_start_time)
    else:
        raise RuntimeError(
            "Unable to benchmark on a single batch. Increase the size of the dataset and rerun the benchmark."
        )
    if not args.multiprocess or dist.get_rank() == dist.get_world_size() - 1:
        return wps, loss.item()
    else:
        return 0.0, 0.0


# TODO(anj-s): Add an option for users to be able to benchmark evaluate.
def evaluate(eval_model, data_source, criterion, ntokens):
    eval_model.eval()
    total_loss = 0.0
    # TODO(anj-s): Move this to the benchmark config if we want to benchmark evaluation.
    bptt = 35

    def get_batch(source, i, bptt):
        seq_len = min(bptt, len(source) - 1 - i)
        data = source[i : i + seq_len]
        target = source[i + 1 : i + 1 + seq_len].view(-1)
        return data, target

    with torch.no_grad():
        for i in range(0, data_source.size(0) - 1, bptt):
            data, targets = get_batch(data_source, i, bptt)
            output = eval_model(data)
            output = output.to(targets.device)
            output_flat = output.view(-1, ntokens)
            total_loss += len(data) * criterion(output_flat, targets).item()
    return total_loss / (len(data_source) - 1)


def get_number_of_words(data):
    return data.size()[0] * data.size()[1]


def verify_peak_memory(rank, golden_config, std_dev):
    current_device_usage = torch.cuda.memory_stats(rank)["allocated_bytes.all.peak"]
    print("Peak allocated bytes on cuda:{:d}: {:1d}".format(rank, current_device_usage))
    golden_ref = golden_config["peak_mem_usage"][rank]
    if not current_device_usage < golden_ref * std_dev:
        raise RuntimeError(
            "Peak memory usage for cuda device {:d} is {:d} which "
            "exceeds the golden reference value of {:d}".format(rank, current_device_usage, golden_ref)
        )


def verify_lm_run(wps, golden_config, args):
    """Verify that words per second for a given benchmark run matches the golden data."""

    # Verify wps only on the last rank in multiprocess pipe
    if not args.multiprocess or dist.get_rank() == dist.get_world_size() - 1:
        # Assert that words per second is within 3 standard deviations of the average
        # of five golden runs
        print("Throughput(wps) is {:.2f}.".format(wps))
        if not wps > (golden_config["avg_wps"] - (3 * golden_config["std_dev_wps"])):
            raise RuntimeError(
                "Throughput(wps):{:.2f} is below the golden threshold of an "
                "average value of {:.2f} and standard dev of {:.2f}.".format(
                    wps, golden_config["avg_wps"], golden_config["std_dev_wps"]
                )
            )

    if args.multiprocess:
        verify_peak_memory(dist.get_rank(), golden_config, 1.5)
    else:
        for i in range(4):
            verify_peak_memory(i, golden_config, 1.1)


def benchmark_language_model(model_config, model, benchmark_config, model_specs, args):
    golden_config = get_golden_config(args.model_name, args)
    epoch = benchmark_config["epochs"]
    start_time = time.time()
    if dist.get_rank() == dist.get_world_size() - 1:
        print("-" * 110)
        print("| start of epoch {:1d}".format(epoch))
        print("-" * 110)
    wps, loss = train(model_config, model, benchmark_config, model_specs, args)
    elapsed_time = time.time() - start_time
    if dist.get_rank() == dist.get_world_size() - 1:
        print("-" * 110)
        print("| end of epoch {:1d} | time: {:5.2f}s | train loss {:5.2f} ".format(epoch, elapsed_time, loss))
        print("-" * 110)
        print("Throughput(wps) is {:.2f}.".format(wps))
    print(
        "Peak allocated bytes on cuda:{}: {:1d}".format(
            dist.get_rank(), torch.cuda.memory_stats(dist.get_rank())["allocated_bytes.all.peak"]
        )
    )

    if len(model.balance) == 4:
        if args.model_name == "lm":
            verify_lm_run(wps, golden_config, args)
        else:
            raise RuntimeError("Unrecognized args.model_name %s" % args.model_name)


def generate_balance_weighted(num_devices, num_layers, fraction=0.5):
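    # Give the last device a lighter load: it gets `fraction` of an even share,
    # e.g. generate_balance_weighted(4, 10) -> [3, 3, 3, 1] with the default fraction of 0.5.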
    balance = []
    layers_assigned = 0
    average_count = num_layers / num_devices
    last_layers = int(average_count * fraction)

    balance = generate_balance(num_devices - 1, num_layers - last_layers)
    balance.append(last_layers)
    return balance


def generate_balance(num_devices, num_layers):
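    # Spread num_layers across num_devices as evenly as possible, front-loading any remainder,
    # e.g. generate_balance(4, 10) -> [3, 3, 2, 2].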
    balance = []
    layers_assigned = 0
    for i in range(num_devices):
        x = (num_layers - layers_assigned) / (num_devices - i)
        if x.is_integer():
            balance.append(int(x))
            layers_assigned += x
        else:
            balance.append(math.ceil(x))
            layers_assigned += math.ceil(x)
    return balance


def get_synthetic_dataloaders(args, benchmark_config, model_specs):
    """Returns dataloader for synthetic data."""

    if args.model_name == "lm":
        return get_synthetic_wikitext2_dataloaders(args, benchmark_config, model_specs)
    else:
        raise RuntimeError("Unrecognized args.model_name %s" % args.model_name)


def get_real_dataloaders(args, device, benchmark_config, model_specs):
    """Returns dataloaders for real data."""

    if args.model_name == "lm":
        data = get_real_wikitext2_dataloaders(args, benchmark_config, model_specs)
        ntokens, train_dataloader, valid_dataloader, test_dataloader = data
        model_specs["vocab_size"] = ntokens
        return train_dataloader, valid_dataloader, test_dataloader
    else:
        raise RuntimeError("Unrecognized args.model_name %s" % args.model_name)


def create_model_config(args, benchmark_config=None, model_specs=None):
    """Return a dict with the given model, dataset and optimizer."""

    device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

    if args.use_synthetic_data:
        data = get_synthetic_dataloaders(args, benchmark_config, model_specs)
    else:
        data = get_real_dataloaders(args, device, benchmark_config, model_specs)

    model, optimizer = get_model_and_optimizer(args, device, benchmark_config, model_specs)
    return {
        "model": model,
        "optimizer": optimizer,
        "data": data,
    }


def create_benchmark_config(model_name):
    """Return a dict with configurations required for benchmarking `model_name` model."""

    if model_name == "lm":
        return lm_wikitext2.get_benchmark_config()
    else:
        raise RuntimeError("Unrecognized model_name %s" % model_name)


def get_model_specs(model_name):
    """Return a dict with configurations required for configuring `model_name` model."""

    if model_name == "lm":
        return lm_wikitext2.get_model_config()
    else:
        raise RuntimeError("Unrecognized model_name %s" % model_name)


def get_golden_config(model_name, args):
    """Return a dict with the golden data for throughput and memory usage."""

    if model_name == "lm":
        return lm_wikitext2.get_golden_real_stats(args.multiprocess)
    else:
        raise RuntimeError("Unrecognized model_name %s" % model_name)


def benchmark_single_process(args):
    """Benchmark a given model using a single process and multiple devices."""

    init_method_pgroup = "tcp://localhost:{}".format(MPI_PORT)
    torch.distributed.init_process_group(backend="gloo", rank=0, world_size=1, init_method=init_method_pgroup)

    num_devices = torch.cuda.device_count() if torch.cuda.is_available() else 1
    assert num_devices > 0
    init_random_seed(0)

    benchmark_config = create_benchmark_config(args.model_name)
    model_specs = get_model_specs(args.model_name)
    model_config = create_model_config(args, benchmark_config=benchmark_config, model_specs=model_specs)
    model = model_config["model"]

    balance = generate_balance(min(num_devices, 4), len(model))
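    # Pipe splits the sequential model over up to four local devices and slices each mini-batch into args.chunks micro-batches.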
    pipe_model = Pipe(model, balance, chunks=args.chunks, checkpoint=args.checkpoint)
    del model
    del model_config["model"]

    if args.dry_run:
        train(model_config, pipe_model, benchmark_config, model_specs, args)
    else:
        benchmark_language_model(model_config, pipe_model, benchmark_config, model_specs, args)


def run_mp_worker(args, available_workers):

    benchmark_config = create_benchmark_config(args.model_name)
    model_specs = get_model_specs(args.model_name)
    model_config = create_model_config(args, benchmark_config=benchmark_config, model_specs=model_specs)
    model = model_config["model"]

    balance = generate_balance(get_pipeline_parallel_group().size(), len(model))
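    # In the multiprocess benchmark each pipeline-parallel rank hosts only the layers assigned to it by `balance`;
    # activations cross process boundaries instead of moving between local devices.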
    pipe_model = MultiProcessPipe(
        model,
        balance,
        chunks=args.chunks,
        worker_map=get_worker_map(),
        input_device=torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu"),
        checkpoint=args.checkpoint,
        # TODO(anj-s): Do we need to comment this out? loss_fn=benchmark_config["criterion"],
    )
    if torch.cuda.is_available():
        pipe_model = pipe_model.cuda()
    if args.all_at_once and pipe_model.pipeline:
        print("running all at once")
        pipe_model.pipeline.all_at_once = True

    if args.dry_run:
        train(model_config, pipe_model, benchmark_config, model_specs, args)
    else:
        benchmark_language_model(model_config, pipe_model, benchmark_config, model_specs, args)


def run_worker(rank, world_size, args):
    if args.world_size != 0:
        world_size = args.world_size
    dist_init(rank + args.rank_base, world_size, hostname=args.host)
    initialize_model_parallel(1, world_size)
    init_random_seed(0)
    run_mp_worker(args, world_size)

    rpc.shutdown()
    torch.distributed.destroy_process_group()


def benchmark_multiprocess(rank, world_size, args):

    init_method_pgroup = "tcp://localhost:{}".format(MPI_PORT)
    # TODO(anj-s): Add regression benchmarks for nccl as well.
    torch.distributed.init_process_group(
        backend="gloo", rank=rank, world_size=world_size, init_method=init_method_pgroup
    )

    torch.cuda.set_device(rank % torch.cuda.device_count())
    # TODO(anj-s): Move to TensorPipeRpcBackendOptions.
    rpc.init_rpc(
        f"Test{rank}",
        rank=rank,
        world_size=world_size,
        backend=rpc.BackendType.PROCESS_GROUP,
        rpc_backend_options=rpc.ProcessGroupRpcBackendOptions(
            rpc_timeout=20, init_method="tcp://localhost:{}".format(RPC_PORT)
        ),
    )
    initialize_model_parallel(1, world_size)
    init_random_seed(0)
    run_mp_worker(args, world_size)

    rpc.shutdown()
    torch.distributed.destroy_process_group()


parser = argparse.ArgumentParser(description="benchmark")
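# Illustrative invocations (assuming the script is launched directly as pipe.py):
#   python pipe.py --use_synthetic_data                  # single process, Pipe over up to 4 local GPUs
#   python pipe.py --multiprocess --use_synthetic_data   # one process per GPU, MultiProcessPipe over RPC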
parser.add_argument("--multiprocess", action="store_true", help="Runs multiprocess benchmarks (one process per GPU).")
parser.add_argument("--host", "-o", type=str, default="localhost", help="hostname")
parser.add_argument("--chunks", type=int, default=1, help="number of microbatches per batch")
parser.add_argument("--batch-size", type=int, default=8, help="size of a batch")
parser.add_argument("--all-at-once", action="store_true", default=False, help="do backward pass on whole batch at once")
parser.add_argument("--max-batch", type=int, default=4, help="Max number of batches")
parser.add_argument("--ddp-zero", action="store_true", default=False, help="enable DDP with a ZeRO-sharded (OSS) optimizer")
parser.add_argument(
    "--lazy-construction", action="store_true", default=False, help="Construct the model lazily with LazyModule wrappers"
)
parser.add_argument(
    "--checkpoint", default="never", choices=["always", "except_last", "never"], help="Checkpointing strategy for pipe"
)
parser.add_argument("--use_synthetic_data", action="store_true", help="Uses synthetic data for running benchmarks.")
parser.add_argument("--dry_run", action="store_true", help="Run a sample training run without regression testing.")
parser.add_argument(
    # TODO(anj-s): In the process of adding more models and hence the requirement for a flag.
    "--model_name",
    default="lm",
    help="Language Model(LM) used to benchmark nn.pipe.",
)

if __name__ == "__main__":
    args = parser.parse_args()

    # TODO(anj-s): Remove print statements and introduce logging levels.

    if not args.multiprocess:
        print(f"Running single process benchmark with args: {args}")
        benchmark_single_process(args)
    else:
        world_size = max(torch.cuda.device_count(), 1)
        print(f"Running multiprocess benchmark with args: {args}")
        mp.spawn(benchmark_multiprocess, args=(world_size, args), nprocs=world_size, join=True)