# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
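
# Example invocations (module path and flag spellings assumed from the parser
# defined at the bottom of this file):
#   python benchmarks/pipe.py --use_synthetic_data --dry_run
#   python benchmarks/pipe.py --multiprocess --use_synthetic_data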

import argparse
from collections import defaultdict
from functools import reduce
import gc
import logging
import math
import operator
import time

from datasets.wikitext2_data import get_real_dataloaders as get_real_wikitext2_dataloaders
from datasets.wikitext2_data import get_synthetic_dataloaders as get_synthetic_wikitext2_dataloaders
from models import transformer_lm
import numpy as np
import torch
import torch.distributed as dist
from torch.distributed import rpc
import torch.multiprocessing as mp
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.optim import Adam

from benchmarks.golden_configs.lm_wikitext2 import Pipe as lm_wikitext2
from fairscale.nn import Pipe
from fairscale.nn.model_parallel import initialize_model_parallel
from fairscale.nn.model_parallel.initialize import get_data_parallel_group, get_pipeline_parallel_group
from fairscale.nn.pipe import LazyModule, MultiProcessPipe
from fairscale.optim.oss import OSS
from fairscale.utils.testing import dist_init, get_worker_map

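# Fixed localhost ports: MPI_PORT seeds the torch.distributed process group,
# RPC_PORT seeds the RPC framework used by MultiProcessPipe.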
MPI_PORT = 29500
RPC_PORT = 29501


def init_random_seed(seed: int):
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    np.random.seed(seed)


def get_model_and_optimizer(args, device, benchmark_config, model_config):
    """Return the instantiated model and an optimizer factory."""

    if args.model_name == "lm":
        model = get_lm_model(args, device, model_config)

    lr = benchmark_config["lr"]

    def make_adam(params):
        if args.ddp_zero:
            return OSS(params=params, optim=Adam, group=get_data_parallel_group(), lr=lr)
        else:
            return Adam(params, lr=lr)

    optimizer = make_adam
    return model, optimizer


def get_lm_model(args, device, config):
    """Get the language model (based on GPT-2) used for sequence prediction."""

    ninp = config["ninp"]
    nhead = config["nhead"]
    initrange = config["initrange"]
    dropout = config["dropout"]
    vocab_size = config["vocab_size"]
    nhid = config["nhid"]
    ndecoder = config["num_decoder_layers"]

    if args.lazy_construction:
        layers = [
            LazyModule(lambda: transformer_lm.EmbeddingLayer(vocab_size, ninp, initrange)),
            LazyModule(lambda: transformer_lm.PositionalEncodingLayer(ninp, dropout)),
        ]
        for _ in range(ndecoder):
            layers.append(LazyModule(lambda: transformer_lm.TransformerDecoderLayer(ninp, nhead, nhid, dropout)))
        layers.append(LazyModule(lambda: transformer_lm.LinearLayer(ninp, vocab_size, initrange)))
        model = layers
    else:
        model = transformer_lm.TransformerLM(vocab_size, ninp, nhead, nhid, dropout, initrange, ndecoder).to(device)

    return model


def get_tensors_by_size_bucket():
    """Bucket the CUDA tensors currently tracked by gc, keyed by (shape, element size)."""

    size_buckets = defaultdict(int)
    for obj in gc.get_objects():
        if not isinstance(obj, torch.Tensor):
            continue
        if obj.device.type == "cuda":
            size_buckets[(*obj.size(),) + (obj.element_size(),)] += 1

    return size_buckets


def log_number_of_parameters(model):
    num_params = reduce(operator.add, (reduce(operator.mul, x.size()) for x in model.parameters()))
    if hasattr(model, "group"):
        total = torch.Tensor([num_params])
        if torch.cuda.is_available():
            total = total.cuda()
        torch.distributed.all_reduce(total, group=model.group)
        logging.debug(
            f"training model, #params = {num_params}, group: {model.group.rank()}, grank:"
            f" {torch.distributed.get_rank()}, sizes {model.group.size()}"
        )
        torch.distributed.barrier()
        if model.group.rank() == 0:
            logging.debug(f"total #params = {total.item()}")
    else:
        logging.debug(f"training model, #params = {num_params}")


def get_device(model, index):
    """Return the device for partition `index` when the model is pipelined, else the current device."""
    if isinstance(model, DDP):
        model = model.module

    if not torch.cuda.is_available():
        return torch.device("cpu")
    if hasattr(model, "devices"):
        return model.devices[index]
    else:
        return torch.cuda.current_device()


def get_fake_dataloader(lm_dataloader_len, args):
    fake_input = {"input": torch.zeros(args.batch_size)}

    class FakeDataset:
        def __getitem__(self, index):
            return fake_input

        def __len__(self):
            return lm_dataloader_len

    return FakeDataset()


def train(model_config, model, benchmark_config, model_specs, args):
    lm_dataloader, _, _ = model_config["data"]
    criterion = benchmark_config["criterion"]
    vocab_size = model_specs["vocab_size"]
    optimizer = model_config["optimizer"]

    model.train()
    log_number_of_parameters(model)

    total_loss = 0.0
    word_counter = 0

    optimizer = optimizer(model.parameters())

    pipe_group = model.group if hasattr(model, "group") else None

    if args.ddp_zero:
        model = DDP(
            model,
            device_ids=[torch.cuda.current_device()],
            process_group=get_data_parallel_group(),
            find_unused_parameters=False,
        )

    # TODO(anj-s): Avoid sending fake data to all replicas except the first and last one.
    if pipe_group and pipe_group.rank() != 0 and pipe_group.rank() != (pipe_group.size() - 1):
        lm_dataloader, _, _ = get_synthetic_dataloaders(args, benchmark_config, model_specs)

    total_tokens = 0
    total_tokens_per_log_interval = 0
    bptt = 2
    start_time = time.time()
    epoch_start_time = 0.0

    # Next-token language modeling: the target is the input shifted by one token.
    def get_batch(source):
        seq_len = len(source) - 1
        data = source[0:seq_len]
        target = source[1 : 1 + seq_len]
        return data, target

    for i, batch in enumerate(lm_dataloader):
        # Start the epoch timer at the second batch so that warm-up is excluded.
        if i == 1:
            epoch_start_time = time.time()

        source, target = get_batch(batch)

        if args.max_batch and i > args.max_batch:
            break

        if i > 0:
            total_tokens += source.numel()

        optimizer.zero_grad()
        try:
            if (pipe_group is None or pipe_group.rank() == 0) and not args.ddp_zero:
                tmp = source.to(get_device(model, 0))
                output = model(tmp)
            else:
                output = model(source)
        except Exception as e:
            raise RuntimeError(f"training failed on {torch.distributed.get_rank()}") from e

        if pipe_group is None or pipe_group.rank() == pipe_group.size() - 1:
            target = target.to(get_device(model, -1))
            output = output.to(target.device)

            loss = criterion(output.view(-1, vocab_size), target.view(-1))
            if args.ddp_zero:
                ddp_group = get_data_parallel_group()
                torch.distributed.all_reduce(loss, op=torch.distributed.ReduceOp.SUM, group=ddp_group)
                loss /= ddp_group.size()
            loss.backward()
            del target
        else:
            if args.ddp_zero:
                model.module.back_helper(output)
            else:
                model.back_helper(output)

        del output

        torch.nn.utils.clip_grad_value_(model.parameters(), model_specs["clip_value"])
        optimizer.step()

        if pipe_group is None or pipe_group.rank() == pipe_group.size() - 1:
            total_loss += loss.item()
            log_interval = 1
            total_tokens_per_log_interval += source.numel()
            if i % log_interval == 0 and i > 0:
                cur_loss = total_loss / log_interval
                elapsed = time.time() - start_time
                if not args.multiprocess or dist.get_rank() == dist.get_world_size() - 1:
                    logging.debug(
                        "| batch {:5d} | wps {:5.2f} | loss {:5.2f} | ppl {:8.2f}".format(
                            i, total_tokens_per_log_interval / elapsed, cur_loss, math.exp(cur_loss)
                        )
                    )
                total_tokens_per_log_interval = 0
                total_loss = 0
                start_time = time.time()

    if epoch_start_time != 0:
        wps = total_tokens / (time.time() - epoch_start_time)
    else:
        raise RuntimeError(
            "Unable to benchmark on a single batch. Increase the size of the dataset and rerun the benchmark."
        )
    if not args.multiprocess or dist.get_rank() == dist.get_world_size() - 1:
        return wps, loss.item()
    else:
        return 0.0, 0.0


# TODO(anj-s): Add an option for users to benchmark evaluation as well.
def evaluate(eval_model, data_source, criterion, ntokens):
    eval_model.eval()
    total_loss = 0.0
    # TODO(anj-s): Move this to the benchmark config if we want to benchmark evaluation.
    bptt = 35

    def get_batch(source, i, bptt):
        seq_len = min(bptt, len(source) - 1 - i)
        data = source[i : i + seq_len]
        target = source[i + 1 : i + 1 + seq_len].view(-1)
        return data, target

    with torch.no_grad():
        for i in range(0, data_source.size(0) - 1, bptt):
            data, targets = get_batch(data_source, i, bptt)
            output = eval_model(data)
            output = output.to(targets.device)
            output_flat = output.view(-1, ntokens)
            total_loss += len(data) * criterion(output_flat, targets).item()
    return total_loss / (len(data_source) - 1)


def get_number_of_words(data):
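    # Total tokens: sequence length (dim 0) times batch size (dim 1).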
    return data.size()[0] * data.size()[1]


def verify_peak_memory(rank, golden_config, std_dev):
    logging.debug(
        "Peak allocated bytes on cuda:{:1d}: {:1d}".format(
            rank, torch.cuda.memory_stats(rank)["allocated_bytes.all.peak"]
        )
    )
    current_device_usage = torch.cuda.memory_stats(rank)["allocated_bytes.all.peak"]
    golden_ref = golden_config["peak_mem_usage"][rank]
    if not current_device_usage < golden_ref * std_dev:
        raise RuntimeError(
            "Peak memory usage for cuda device {:d} is {:d} which "
            "exceeds the golden reference value of {:d}.".format(rank, current_device_usage, golden_ref)
        )


def verify_lm_run(wps, golden_config, args):
    """Verify that words per second for a given benchmark run matches the golden data."""

    # Verify wps only on the last rank in multiprocess pipe
    if not args.multiprocess or dist.get_rank() == dist.get_world_size() - 1:
        # Assert that words per second is within 3 standard deviations of the average
        # of five golden runs
        logging.info("Throughput(wps) is {:.2f}.".format(wps))
        if not wps > (golden_config["avg_wps"] - (3 * golden_config["std_dev_wps"])):
            raise RuntimeError(
                "Throughput(wps):{:.2f} is below the golden threshold of an "
                "average value of {:.2f} and standard dev of {:.2f}.".format(
                    wps, golden_config["avg_wps"], golden_config["std_dev_wps"]
                )
            )

    if args.multiprocess:
        verify_peak_memory(dist.get_rank(), golden_config, 1.5)
    else:
        for i in range(4):
            verify_peak_memory(i, golden_config, 1.1)


def benchmark_language_model(model_config, model, benchmark_config, model_specs, args):
    golden_config = get_golden_config(args.model_name, args)
    epoch = benchmark_config["epochs"]
    start_time = time.time()
    if dist.get_rank() == dist.get_world_size() - 1:
        logging.debug("-" * 110)
        logging.debug("| start of epoch {:1d}".format(epoch))
        logging.debug("-" * 110)
    wps, loss = train(model_config, model, benchmark_config, model_specs, args)
    elapsed_time = time.time() - start_time
    if dist.get_rank() == dist.get_world_size() - 1:
        logging.debug("-" * 110)
        logging.debug("| end of epoch {:1d} | time: {:5.2f}s | train loss {:5.2f} ".format(epoch, elapsed_time, loss))
        logging.debug("-" * 110)
        logging.debug("Throughput(wps) is {:.2f}.".format(wps))
    logging.debug(
        "Peak allocated bytes on cuda:{}: {:1d}".format(
            dist.get_rank(), torch.cuda.memory_stats(dist.get_rank())["allocated_bytes.all.peak"]
        )
    )

    # The golden data assumes a 4-partition pipeline, so only verify that case.
    if len(model.balance) == 4:
        if args.model_name == "lm":
            verify_lm_run(wps, golden_config, args)
        else:
            raise RuntimeError(f"Unrecognized args.model_name {args.model_name}")


def generate_balance(num_devices, num_layers):
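    """Distribute num_layers across num_devices as evenly as possible.

    Any remainder goes to the earlier partitions, e.g.
    generate_balance(4, 14) -> [4, 4, 3, 3].
    """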
    balance = []
    layers_assigned = 0
    for i in range(num_devices):
        x = (num_layers - layers_assigned) / (num_devices - i)
        if x.is_integer():
            balance.append(int(x))
            layers_assigned += x
        else:
            balance.append(math.ceil(x))
            layers_assigned += math.ceil(x)
    return balance


def get_synthetic_dataloaders(args, benchmark_config, model_specs):
    """Returns dataloaders for synthetic data."""

    if args.model_name == "lm":
        return get_synthetic_wikitext2_dataloaders(args, benchmark_config, model_specs)
    else:
        raise RuntimeError(f"Unrecognized args.model_name {args.model_name}")


def get_real_dataloaders(args, device, benchmark_config, model_specs):
    """Returns dataloaders for real data."""

    if args.model_name == "lm":
        data = get_real_wikitext2_dataloaders(args, benchmark_config, model_specs)
        ntokens, train_dataloader, valid_dataloader, test_dataloader = data
        model_specs["vocab_size"] = ntokens
        return train_dataloader, valid_dataloader, test_dataloader
    else:
        raise RuntimeError(f"Unrecognized args.model_name {args.model_name}")


def create_model_config(args, benchmark_config=None, model_specs=None):
    """Return a dict with the given model, dataset and optimizer."""

    device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

    # The synthetic dataloaders generate data on the fly and take no device argument.
    if args.use_synthetic_data:
        data = get_synthetic_dataloaders(args, benchmark_config, model_specs)
    else:
        data = get_real_dataloaders(args, device, benchmark_config, model_specs)

    model, optimizer = get_model_and_optimizer(args, device, benchmark_config, model_specs)
    return {
        "model": model,
        "optimizer": optimizer,
        "data": data,
    }


def create_benchmark_config(model_name):
    """Return a dict with configurations required for benchmarking `model_name` model."""

    if model_name == "lm":
        return lm_wikitext2.get_benchmark_config()
    else:
        raise RuntimeError(f"Unrecognized model name {model_name}")


def get_model_specs(model_name):
    """Return a dict with configurations required for configuring `model_name` model."""

    if model_name == "lm":
        return lm_wikitext2.get_model_config()
    else:
        raise RuntimeError(f"Unrecognized model name {model_name}")


def get_golden_config(model_name, args):
    """Return a dict with the golden data for throughput and memory usage."""

    if model_name == "lm":
        return lm_wikitext2.get_golden_real_stats(args.multiprocess)
    else:
        raise RuntimeError(f"Unrecognized model name {model_name}")


def benchmark_single_process(args):
    """Benchmark a given model using a single process and multiple devices."""

    init_method_pgroup = "tcp://localhost:{}".format(MPI_PORT)
    torch.distributed.init_process_group(backend="gloo", rank=0, world_size=1, init_method=init_method_pgroup)

    num_devices = torch.cuda.device_count() if torch.cuda.is_available() else 1
    assert num_devices > 0
    init_random_seed(0)

    benchmark_config = create_benchmark_config(args.model_name)
    model_specs = get_model_specs(args.model_name)
    model_config = create_model_config(args, benchmark_config=benchmark_config, model_specs=model_specs)
    model = model_config["model"]

    # Partition the model over at most four devices, then drop the unpartitioned copy.
    balance = generate_balance(min(num_devices, 4), len(model))
    pipe_model = Pipe(model, balance, chunks=args.chunks, checkpoint=args.checkpoint)
    del model
    del model_config["model"]

    if args.dry_run:
        train(model_config, pipe_model, benchmark_config, model_specs, args)
    else:
        benchmark_language_model(model_config, pipe_model, benchmark_config, model_specs, args)


def run_mp_worker(args, available_workers):
    benchmark_config = create_benchmark_config(args.model_name)
    model_specs = get_model_specs(args.model_name)
    model_config = create_model_config(args, benchmark_config=benchmark_config, model_specs=model_specs)
    model = model_config["model"]

    balance = generate_balance(get_pipeline_parallel_group().size(), len(model))
    pipe_model = MultiProcessPipe(
        model,
        balance,
        chunks=args.chunks,
        worker_map=get_worker_map(),
        input_device=torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu"),
        checkpoint=args.checkpoint,
        # TODO(anj-s): Should we pass loss_fn=benchmark_config["criterion"] here?
    )
    if torch.cuda.is_available():
        pipe_model = pipe_model.cuda()

    if args.dry_run:
        train(model_config, pipe_model, benchmark_config, model_specs, args)
    else:
        benchmark_language_model(model_config, pipe_model, benchmark_config, model_specs, args)


def run_worker(rank, world_size, args):
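    # NOTE: this helper is not invoked anywhere in this file; it relies on
    # args.world_size and args.rank_base, which the argument parser below
    # does not define.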
    if args.world_size != 0:
        world_size = args.world_size
    dist_init(rank + args.rank_base, world_size, hostname=args.host)
    initialize_model_parallel(1, world_size)
    init_random_seed(0)
    run_mp_worker(args, world_size)

    rpc.shutdown()
    torch.distributed.destroy_process_group()


def benchmark_multiprocess(rank, world_size, args):
    init_method_pgroup = "tcp://localhost:{}".format(MPI_PORT)
    # TODO(anj-s): Add regression benchmarks for nccl as well.
    torch.distributed.init_process_group(
        backend="gloo", rank=rank, world_size=world_size, init_method=init_method_pgroup
    )

    torch.cuda.set_device(rank % torch.cuda.device_count())
    # TODO(anj-s): Move to TensorPipeRpcBackendOptions.
    rpc.init_rpc(
        f"Test{rank}",
        rank=rank,
        world_size=world_size,
        backend=rpc.BackendType.PROCESS_GROUP,
        rpc_backend_options=rpc.ProcessGroupRpcBackendOptions(
            rpc_timeout=20, init_method="tcp://localhost:{}".format(RPC_PORT)
        ),
    )
    initialize_model_parallel(1, world_size)
    init_random_seed(0)
    run_mp_worker(args, world_size)

    rpc.shutdown()
    torch.distributed.destroy_process_group()


parser = argparse.ArgumentParser(description="benchmark")
parser.add_argument("--multiprocess", action="store_true", help="Runs multi-process benchmarks (one process per GPU).")
parser.add_argument("--host", "-o", type=str, default="localhost", help="hostname")
parser.add_argument("--chunks", type=int, default=1, help="number of microbatches per batch")
parser.add_argument("--batch-size", type=int, default=8, help="size of a batch")
parser.add_argument("--all-at-once", action="store_true", default=False, help="do backward pass on whole batch at once")
parser.add_argument("--max-batch", type=int, default=4, help="Max number of batches")
parser.add_argument("--ddp-zero", action="store_true", default=False, help="enable DDP with a ZeRO-sharded optimizer (OSS)")
parser.add_argument(
    "--lazy-construction", action="store_true", default=False, help="Construct the model layers lazily via LazyModule"
)
parser.add_argument(
    "--checkpoint", default="never", choices=["always", "except_last", "never"], help="Checkpointing strategy for pipe"
)
parser.add_argument("--use_synthetic_data", action="store_true", help="Uses synthetic data for running benchmarks.")
parser.add_argument("--dry_run", action="store_true", help="Run a sample training run without regression testing.")
parser.add_argument(
    # TODO(anj-s): In the process of adding more models and hence the requirement for a flag.
    "--model_name",
    default="lm",
    help="Language Model (LM) used to benchmark nn.pipe.",
)
parser.add_argument("--debug", action="store_true", default=False, help="Display additional debug information")

if __name__ == "__main__":
    args = parser.parse_args()
    logging.basicConfig(level=logging.INFO if not args.debug else logging.DEBUG)

    if not args.multiprocess:
        logging.info(f"Running single process benchmark with args: {args}")
        benchmark_single_process(args)
    else:
        world_size = max(torch.cuda.device_count(), 1)
        logging.info(f"Running multiprocess benchmark with args: {args}")
        mp.spawn(benchmark_multiprocess, args=(world_size, args), nprocs=world_size, join=True)