# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.

import argparse
from collections import defaultdict
from functools import reduce
import gc
import logging
import math
import operator
import pprint
import time

from datasets.wikitext2_data import get_real_dataloaders as get_real_wikitext2_dataloaders
from datasets.wikitext2_data import get_synthetic_dataloaders as get_synthetic_wikitext2_dataloaders
from golden_configs import lm_wikitext2
from models import transformer_lm
import numpy as np
import torch
import torch.distributed as dist
from torch.distributed import rpc
import torch.multiprocessing as mp
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.optim import Adam

from fairscale.nn import Pipe
from fairscale.nn.model_parallel import initialize_model_parallel
from fairscale.nn.model_parallel.initialize import get_data_parallel_group, get_pipeline_parallel_group
from fairscale.nn.pipe import LazyModule, MultiProcessPipe
from fairscale.optim.oss import OSS
from fairscale.utils.testing import dist_init, get_worker_map

MPI_PORT = 29500
RPC_PORT = 29501


def init_random_seed(seed: int):
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    np.random.seed(seed)


def get_model_and_optimizer(args, device, config):
    """Return instantiated model and optimizer function."""

    if args.model_name == "lm":
        model = get_lm_model(args, device, config)

    lr = config["lr"]

    def make_adam(params):
        if args.ddp_zero:
            return OSS(params=params, optim=Adam, group=get_data_parallel_group(), lr=lr)
        else:
            return Adam(params, lr=lr)

    optimizer = make_adam
    return model, optimizer


def get_lm_model(args, device, config):
    """Get language model (based on GPT-2) used for sequence prediction."""

    ninp = config["ninp"]
    nhead = config["nhead"]
    initrange = config["initrange"]
    dropout = config["dropout"]
    vocab_size = config["vocab_size"]
    nhid = config["nhid"]
    ndecoder = config["num_decoder_layers"]

    if args.lazy_construction:
        layers = [
            LazyModule(lambda: transformer_lm.EmbeddingLayer(vocab_size, ninp, initrange)),
            LazyModule(lambda: transformer_lm.PositionalEncodingLayer(ninp, dropout)),
        ]
        for _ in range(ndecoder):
            layers.append(LazyModule(lambda: transformer_lm.TransformerDecoderLayer(ninp, nhead, nhid, dropout)))

        layers.append(LazyModule(lambda: transformer_lm.LinearLayer(ninp, vocab_size, initrange)))
        model = layers
    else:
        model = transformer_lm.TransformerLM(vocab_size, ninp, nhead, nhid, dropout, initrange, ndecoder).to(device)

    return model


def get_tensors_by_size_bucket():
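    """Count live CUDA tensors, bucketed by (shape..., element_size), for leak debugging."""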

    size_buckets = defaultdict(int)
    for obj in gc.get_objects():
        if not isinstance(obj, torch.Tensor):
            continue
        if obj.device.type == "cuda":
            size_buckets[(*obj.size(),) + (obj.element_size(),)] += 1

    return size_buckets


def dump_size_buckets(size_buckets, prefix=""):

    total = 0
    for key, value in size_buckets.items():
        this = reduce(operator.mul, key) * value
        total += this
        print(prefix + f"{key} : {value}, {this}")

    print(prefix + f"total = {total}")


last_size_buckets = None
once = True


def safe_rank():
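    """Return the distributed rank, or 0 if the default process group is not yet initialized."""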
    try:
        return torch.distributed.get_rank()
    except AssertionError:
        return 0


def check_size_buckets():
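    """Compare the current CUDA tensor buckets with the previous snapshot and print any difference."""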
    global last_size_buckets
    global once
    size_buckets = get_tensors_by_size_bucket()
    if last_size_buckets is not None:
        if size_buckets != last_size_buckets:
            print(f"difference in outstanding tensors: {safe_rank()}")
            dump_size_buckets(last_size_buckets, "old: ")
            dump_size_buckets(size_buckets, "new: ")
        if once:
            print(f"dumping buckets for: {safe_rank()}")
            dump_size_buckets(last_size_buckets, "old: ")
            dump_size_buckets(size_buckets, "new: ")
            once = False
    else:
        print(f"size buckets none on {safe_rank()}")
    last_size_buckets = size_buckets


def dump_cuda_tensors():
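    """Print a summary of all outstanding CUDA tensors along with the CUDA memory statistics."""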
    print("dumping cuda tensors...")

    size_buckets = defaultdict(int)
    for obj in gc.get_objects():
        if not isinstance(obj, torch.Tensor):
            continue
        if obj.device.type == "cuda":
            size_buckets[(*obj.size(),) + (obj.element_size(),)] += 1

    print("outstanding cuda tensors:")
    total = 0
    for key, value in size_buckets.items():
        this = reduce(operator.mul, key) * value
        total += this
        print(f"{key} : {value}, {this}")
    print(f"total size = {total}")
    pprint.pprint(torch.cuda.memory_stats())


def log_number_of_parameters(model):
    num_params = reduce(operator.add, (reduce(operator.mul, x.size()) for x in model.parameters()))
    if hasattr(model, "group"):
        total = torch.Tensor([num_params])
        if torch.cuda.is_available():
            total = total.cuda()
        torch.distributed.all_reduce(total, group=model.group)
        logging.info(
            f"training model, #params = {num_params}, group: {model.group.rank()}, grank:"
            f" {torch.distributed.get_rank()}, sizes {model.group.size()}"
        )
        torch.distributed.barrier()
        if model.group.rank() == 0:
            logging.info(f"total #params = {total.item()}")
    else:
        logging.info(f"training model, #params = {num_params}")


def get_device(model, index):
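    """Return the device that holds partition `index` of the model, or CPU if CUDA is unavailable."""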
    if isinstance(model, DDP):
        model = model.module

    if not torch.cuda.is_available():
        return torch.device("cpu")
    if hasattr(model, "devices"):
        return model.devices[index]
    else:
        return torch.cuda.current_device()


def get_fake_dataloader(lm_dataloader_len, args):
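    """Return a dataset of constant zero-filled batches that mimics the length of the real dataloader."""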
    fake_input = {"input": torch.zeros(args.batch_size)}

    class FakeDataset:
        def __getitem__(self, index):
            return fake_input

        def __len__(self):
            return lm_dataloader_len

    return FakeDataset()


def train(model_config, model, benchmark_config, args):
    lm_dataloader, _, _ = model_config["data"]
    criterion = benchmark_config["criterion"]
    vocab_size = benchmark_config["vocab_size"]
    optimizer = model_config["optimizer"]

    model.train()
    log_number_of_parameters(model)

    total_loss = 0.0
    word_counter = 0

    optimizer = optimizer(model.parameters())

    pipe_group = model.group if hasattr(model, "group") else None

    if args.ddp_zero:
        model = DDP(
            model,
            device_ids=[torch.cuda.current_device()],
            process_group=get_data_parallel_group(),
            find_unused_parameters=False,
        )

    # TODO(anj-s): Avoid sending fake data to all replicas except the first and last one.
    if pipe_group and pipe_group.rank() != 0 and pipe_group.rank() != (pipe_group.size() - 1):
        lm_dataloader, _, _ = get_synthetic_dataloaders(args, benchmark_config)

    total_tokens = 0
    total_tokens_per_log_interval = 0
    bptt = 2
    start_time = time.time()
    epoch_start_time = 0.0

    def get_batch(source):
        seq_len = len(source) - 1
        data = source[0:seq_len]
        target = source[1 : 1 + seq_len]
        return data, target

    for i, batch in enumerate(lm_dataloader):
        if i == 1:
            epoch_start_time = time.time()

        source, target = get_batch(batch)
        if args.max_batch and i > args.max_batch:
            break

        if i > 0:
            total_tokens += source.numel()

        optimizer.zero_grad()
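        # Only the first pipeline rank (for a pipelined model) needs to move the batch onto the
        # device of the first partition; every other case passes the batch through unchanged.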
        try:
            if (pipe_group is None or pipe_group.rank() == 0) and not args.ddp_zero:
                tmp = source.to(get_device(model, 0))
                output = model(tmp)
            else:
                output = model(source)
        except Exception as e:
            raise RuntimeError(f"training failed on {torch.distributed.get_rank()}") from e

        if pipe_group is None or pipe_group.rank() == pipe_group.size() - 1:
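            # Only the last pipeline stage holds the final output, so the loss and the
            # backward pass start from here.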
            target = target.to(get_device(model, -1))
            output = output.to(target.device)

            loss = criterion(output.view(-1, vocab_size), target.view(-1))
            if args.ddp_zero:
                ddp_group = get_data_parallel_group()
                torch.distributed.all_reduce(loss, op=torch.distributed.ReduceOp.SUM, group=ddp_group)
                loss /= ddp_group.size()
            loss.backward()
            del target
        else:
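            # Intermediate stages never compute a loss; back_helper() kicks off the backward
            # pass for the activations this stage sent downstream.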
            if args.ddp_zero:
                model.module.back_helper(output)
            else:
                model.back_helper(output)

        del output

        torch.nn.utils.clip_grad_value_(model.parameters(), benchmark_config["clip_value"])
        optimizer.step()

        if pipe_group is None or pipe_group.rank() == pipe_group.size() - 1:
            total_loss += loss.item()
            log_interval = 1
            total_tokens_per_log_interval += source.numel()
            if i % log_interval == 0 and i > 0:
                cur_loss = total_loss / log_interval
                elapsed = time.time() - start_time
                if not args.multiprocess or dist.get_rank() == dist.get_world_size() - 1:
                    print(
                        "| batch {:5d} | wps {:5.2f} | loss {:5.2f} | ppl {:8.2f}".format(
                            i, total_tokens_per_log_interval / elapsed, cur_loss, math.exp(cur_loss)
                        )
                    )
                total_tokens_per_log_interval = 0
                total_loss = 0
                start_time = time.time()

    if epoch_start_time != 0:
        wps = total_tokens / (time.time() - epoch_start_time)
    else:
        raise RuntimeError(
            "Unable to benchmark on a single batch. Increase the size of the dataset and rerun the benchmark."
        )
    if not args.multiprocess or dist.get_rank() == dist.get_world_size() - 1:
        return wps, loss.item()
    else:
        return 0.0, 0.0


# TODO(anj-s): Add an option for users to be able to benchmark evaluate.
def evaluate(eval_model, data_source, criterion, ntokens):
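    """Return the average loss of `eval_model` over `data_source`, evaluated in bptt-sized windows."""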
    eval_model.eval()
    total_loss = 0.0
    # TODO(anj-s): Move this to the benchmark config if we want to benchmark evaluation.
    bptt = 35

    def get_batch(source, i, bptt):
        seq_len = min(bptt, len(source) - 1 - i)
        data = source[i : i + seq_len]
        target = source[i + 1 : i + 1 + seq_len].view(-1)
        return data, target

    with torch.no_grad():
        for i in range(0, data_source.size(0) - 1, bptt):
            data, targets = get_batch(data_source, i, bptt)
            output = eval_model(data)
            output = output.to(targets.device)
            output_flat = output.view(-1, ntokens)
            total_loss += len(data) * criterion(output_flat, targets).item()
    return total_loss / (len(data_source) - 1)


def get_number_of_words(data):
    return data.size()[0] * data.size()[1]


def verify_peak_memory(rank, golden_config, std_dev):
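    """Check that peak CUDA memory allocated on `rank` stays below the golden reference scaled by `std_dev`."""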
    current_device_usage = torch.cuda.memory_stats(rank)["allocated_bytes.all.peak"]
    print("Peak allocated bytes on cuda:{:d}: {:1d}".format(rank, current_device_usage))
    golden_ref = golden_config["peak_mem_usage"][rank]
    if not current_device_usage < golden_ref * std_dev:
        raise RuntimeError(
            "Peak memory usage for cuda device {:d} is {:d} which "
            "exceeds the golden reference value of {:d}".format(rank, current_device_usage, golden_ref)
        )


def verify_lm_run(wps, golden_config, args):
    """Verify that words per second for a given benchmark run matches the golden data."""

    # Verify wps only on the last rank in multiprocess pipe
    if not args.multiprocess or dist.get_rank() == dist.get_world_size() - 1:
        # Assert that words per second is within 3 standard deviations of the average
        # of five golden runs
        print("Throughput(wps) is {:.2f}.".format(wps))
        if not wps > (golden_config["avg_wps"] - (3 * golden_config["std_dev_wps"])):
            raise RuntimeError(
                "Throughput(wps):{:.2f} is below the golden threshold of an "
                "average value of {:.2f} and standard dev of {:.2f}.".format(
                    wps, golden_config["avg_wps"], golden_config["std_dev_wps"]
                )
            )
    if args.multiprocess:
        verify_peak_memory(dist.get_rank(), golden_config, 1.5)
    else:
        for i in range(4):
            verify_peak_memory(i, golden_config, 1.1)

def benchmark_language_model(model_config, model, benchmark_config, args):
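    """Train for one epoch and, for a 4-partition balance, verify throughput and peak memory against golden data."""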
    golden_config = get_golden_config(args.model_name, args)
    epoch = benchmark_config["epochs"]
    start_time = time.time()
    if dist.get_rank() == dist.get_world_size() - 1:
        print("-" * 110)
        print("| start of epoch {:1d}".format(epoch))
        print("-" * 110)
    wps, loss = train(model_config, model, benchmark_config, args)
    elapsed_time = time.time() - start_time
    if dist.get_rank() == dist.get_world_size() - 1:
        print("-" * 110)
        print("| end of epoch {:1d} | time: {:5.2f}s | train loss {:5.2f} ".format(epoch, elapsed_time, loss))
        print("-" * 110)
        print("Throughput(wps) is {:.2f}.".format(wps))
    print(
        "Peak allocated bytes on cuda:{}: {:1d}".format(
            dist.get_rank(), torch.cuda.memory_stats(dist.get_rank())["allocated_bytes.all.peak"]
        )
    )

    if len(model.balance) == 4:
        if args.model_name == "lm":
            verify_lm_run(wps, golden_config, args)
        else:
            raise RuntimeError("Unrecognized args.model_name %s" % args.model_name)


def generate_balance_weighted(num_devices, num_layers, fraction=0.5):
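    """Variant of generate_balance that gives the last device only `fraction` of an even share of layers."""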
    balance = []
    layers_assigned = 0
    average_count = num_layers / num_devices
    last_layers = int(average_count * fraction)

    balance = generate_balance(num_devices - 1, num_layers - last_layers)
    balance.append(last_layers)
    return balance


def generate_balance(num_devices, num_layers):
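    """Split num_layers across num_devices as evenly as possible, assigning any extra layers to earlier devices."""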
    balance = []
    layers_assigned = 0
    for i in range(num_devices):
        x = (num_layers - layers_assigned) / (num_devices - i)
        if x.is_integer():
            balance.append(int(x))
            layers_assigned += x
        else:
            balance.append(math.ceil(x))
            layers_assigned += math.ceil(x)
    return balance

def get_synthetic_dataloaders(args, benchmark_config):
    """Returns dataloader for synthetic data."""

    if args.model_name == "lm":
        return get_synthetic_wikitext2_dataloaders(args, benchmark_config)
    else:
        raise RuntimeError("Unrecognized args.model_name %s" % args.model_name)


def get_real_dataloaders(args, device, benchmark_config):
    """Returns dataloaders for real data."""

    if args.model_name == "lm":
        data = get_real_wikitext2_dataloaders(args, benchmark_config)
        ntokens, train_dataloader, valid_dataloader, test_dataloader = data
        benchmark_config["vocab_size"] = ntokens
        return train_dataloader, valid_dataloader, test_dataloader
    else:
        raise RuntimeError("Unrecognized args.model_name %s" % args.model_name)


def create_model_config(args, benchmark_config=None):
    """Return a dict with the given model, dataset and optimizer."""

    device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

    if args.use_synthetic_data:
        data = get_synthetic_dataloaders(args, benchmark_config)
    else:
        data = get_real_dataloaders(args, device, benchmark_config)

    model, optimizer = get_model_and_optimizer(args, device, benchmark_config)
    return {
        "model": model,
        "optimizer": optimizer,
        "data": data,
    }


def create_benchmark_config(model_name):
    """Return a dict with configurations required for benchmarking `model_name` model."""

    if model_name == "lm":
        return lm_wikitext2.get_benchmark_config()
    else:
        raise RuntimeError("Unrecognized args.model_name %s" % model_name)


def get_golden_config(model_name, args):
    """Return a dict with the golden data for throughput and memory usage."""

    if model_name == "lm":
        return lm_wikitext2.get_golden_real_stats(args.multiprocess)
    else:
        raise RuntimeError("Unrecognized args.model_name %s" % model_name)


def benchmark_single_process(args):
    """Benchmark a given model using a single process and multiple devices."""

    init_method_pgroup = "tcp://localhost:{}".format(MPI_PORT)
    torch.distributed.init_process_group(backend="gloo", rank=0, world_size=1, init_method=init_method_pgroup)

    num_devices = torch.cuda.device_count() if torch.cuda.is_available() else 1
    assert num_devices > 0
    init_random_seed(0)

    benchmark_config = create_benchmark_config(args.model_name)
    model_config = create_model_config(args, benchmark_config=benchmark_config)
    model = model_config["model"]
    balance = generate_balance(min(num_devices, 4), len(model))
    pipe_model = Pipe(model, balance, chunks=args.chunks, checkpoint=args.checkpoint)
    del model
    del model_config["model"]
    if args.dry_run:
        train(model_config, pipe_model, benchmark_config, args)
    else:
        benchmark_language_model(model_config, pipe_model, benchmark_config, args)


def run_mp_worker(args, available_workers):
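    """Build the benchmark and model configs for this rank and run them through MultiProcessPipe."""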

    benchmark_config = create_benchmark_config(args.model_name)
    model_config = create_model_config(args, benchmark_config=benchmark_config)
    model = model_config["model"]
    balance = generate_balance(get_pipeline_parallel_group().size(), len(model))
    pipe_model = MultiProcessPipe(
        model,
        balance,
        chunks=args.chunks,
        worker_map=get_worker_map(),
        input_device=torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu"),
        checkpoint=args.checkpoint,
        # TODO(anj-s): Do we need to comment this out? loss_fn=benchmark_config["criterion"],
    )
    if torch.cuda.is_available():
        pipe_model = pipe_model.cuda()
    if args.all_at_once and pipe_model.pipeline:
        print("running all at once")
        pipe_model.pipeline.all_at_once = True
    if args.dry_run:
        train(model_config, pipe_model, benchmark_config, args)
    else:
        benchmark_language_model(model_config, pipe_model, benchmark_config, args)


def run_worker(rank, world_size, args):
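    """Alternative worker entry point: join the process group via dist_init, then run the multiprocess pipe benchmark."""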
    if args.world_size != 0:
        world_size = args.world_size
    dist_init(rank + args.rank_base, world_size, hostname=args.host)
    initialize_model_parallel(1, world_size)
    init_random_seed(0)
    run_mp_worker(args, world_size)

    rpc.shutdown()
    torch.distributed.destroy_process_group()


def benchmark_multiprocess(rank, world_size, args):
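    """Per-rank entry point for mp.spawn: set up the gloo process group and RPC, then run the pipe benchmark."""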

    init_method_pgroup = "tcp://localhost:{}".format(MPI_PORT)
    # TODO(anj-s): Add regression benchmarks for nccl as well.
    torch.distributed.init_process_group(
        backend="gloo", rank=rank, world_size=world_size, init_method=init_method_pgroup
    )
    torch.cuda.set_device(rank % torch.cuda.device_count())
    # TODO(anj-s): Move to TensorPipeRpcBackendOptions.
    rpc.init_rpc(
        f"Test{rank}",
        rank=rank,
        world_size=world_size,
        backend=rpc.BackendType.PROCESS_GROUP,
        rpc_backend_options=rpc.ProcessGroupRpcBackendOptions(
            rpc_timeout=20, init_method="tcp://localhost:{}".format(RPC_PORT)
        ),
    )
    initialize_model_parallel(1, world_size)
    init_random_seed(0)
    run_mp_worker(args, world_size)

    rpc.shutdown()
    torch.distributed.destroy_process_group()


parser = argparse.ArgumentParser(description="benchmark")
parser.add_argument("--multiprocess", action="store_true", help="Runs multi-process benchmarks.")
parser.add_argument("--host", "-o", type=str, default="localhost", help="hostname")
parser.add_argument("--chunks", type=int, default=1, help="number of microbatches per batch")
parser.add_argument("--batch-size", type=int, default=8, help="size of a batch")
parser.add_argument("--all-at-once", action="store_true", default=False, help="do backward pass on whole batch at once")
parser.add_argument("--max-batch", type=int, default=4, help="Max number of batches")
parser.add_argument("--ddp-zero", action="store_true", default=False, help="wrap the model in DDP and use the OSS (ZeRO) optimizer across data-parallel ranks")
parser.add_argument(
    "--lazy-construction", action="store_true", default=False, help="Construct the model lazily from LazyModule layers"
)
parser.add_argument(
    "--checkpoint", default="never", choices=["always", "except_last", "never"], help="Checkpointing strategy for pipe"
)
parser.add_argument("--use_synthetic_data", action="store_true", help="Uses synthetic data for running benchmarks.")
parser.add_argument("--dry_run", action="store_true", help="Run a sample training run without regression testing.")
parser.add_argument(
    # TODO(anj-s): In the process of adding more models and hence the requirement for a flag.
    "--model_name",
    default="lm",
    help="Language Model (LM) used to benchmark nn.pipe.",
)

if __name__ == "__main__":
    args = parser.parse_args()

    # TODO(anj-s): Remove print statements and introduce logging levels.

    if not args.multiprocess:
        print(f"Running single process benchmark with args: {args}")
        benchmark_single_process(args)
    else:
        world_size = max(torch.cuda.device_count(), 1)
        print(f"Running multiprocess benchmark with args: {args}")
        mp.spawn(benchmark_multiprocess, args=(world_size, args), nprocs=world_size, join=True)