pipe.py 19.1 KB
Newer Older
Mandeep Singh Baines's avatar
Mandeep Singh Baines committed
1
2
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.

Tom Birch's avatar
Tom Birch committed
3
import argparse
4
5
6
from collections import defaultdict
from functools import reduce
import gc
7
import logging
Mandeep Singh Baines's avatar
Mandeep Singh Baines committed
8
import math
9
import operator
Mandeep Singh Baines's avatar
Mandeep Singh Baines committed
10
11
import time

12
13
14
from datasets.wikitext2_data import get_real_dataloaders as get_real_wikitext2_dataloaders
from datasets.wikitext2_data import get_synthetic_dataloaders as get_synthetic_wikitext2_dataloaders
from models import transformer_lm
15
import numpy as np
Mandeep Singh Baines's avatar
Mandeep Singh Baines committed
16
import torch
17
import torch.distributed as dist
Tom Birch's avatar
Tom Birch committed
18
19
from torch.distributed import rpc
import torch.multiprocessing as mp
20
from torch.nn.parallel import DistributedDataParallel as DDP
21
from torch.optim import Adam
Mandeep Singh Baines's avatar
Mandeep Singh Baines committed
22

23
from benchmarks.golden_configs.lm_wikitext2 import Pipe as lm_wikitext2
24
from fairscale.nn import Pipe
Tom Birch's avatar
Tom Birch committed
25
from fairscale.nn.model_parallel import initialize_model_parallel
26
from fairscale.nn.model_parallel.initialize import get_pipeline_parallel_group
27
from fairscale.nn.pipe import LazyModule, MultiProcessPipe
28
from fairscale.utils.testing import dist_init, get_worker_map
29

30
31
32
MPI_PORT = 29500  # TCP port for torch.distributed process-group init
RPC_PORT = 29501  # TCP port for torch.distributed.rpc init

33

Tom Birch's avatar
Tom Birch committed
34
35
36
37
def init_random_seed(seed: int):
    """Seed the numpy, torch and torch.cuda RNGs for reproducible runs."""
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
Tom Birch's avatar
Tom Birch committed
39
40


anj-s's avatar
anj-s committed
41
def get_model_and_optimizer(args, device, benchmark_config, model_config):
    """Return instantiated model and optimizer function.

    Returns a tuple ``(model, optimizer_fn)`` where ``optimizer_fn(params)``
    builds an Adam optimizer with the benchmark's learning rate.

    Raises:
        RuntimeError: if ``args.model_name`` is not a known model.
    """
    if args.model_name == "lm":
        model = get_lm_model(args, device, model_config)
    else:
        # Robustness fix: the original fell through and raised an opaque
        # NameError on the unbound ``model``; fail loudly like the sibling
        # helpers in this file do.
        raise RuntimeError(f"Unrecognized args.model_name {args.model_name}")

    lr = benchmark_config["lr"]

    def make_adam(params):
        # Factory closed over the benchmark learning rate.
        return Adam(params, lr=lr)

    return model, make_adam


def get_lm_model(args, device, config):
    """Get language model(based on GPT-2) used for sequence prediction."""
    vocab_size = config["vocab_size"]
    ninp = config["ninp"]
    nhead = config["nhead"]
    nhid = config["nhid"]
    dropout = config["dropout"]
    initrange = config["initrange"]
    ndecoder = config["num_decoder_layers"]

    if not args.lazy_construction:
        return transformer_lm.TransformerLM(vocab_size, ninp, nhead, nhid, dropout, initrange, ndecoder).to(device)

    # Lazy path: return a list of LazyModule wrappers so each pipeline stage
    # can materialize only its own layers.
    layers = [
        LazyModule(lambda: transformer_lm.EmbeddingLayer(vocab_size, ninp, initrange)),
        LazyModule(lambda: transformer_lm.PositionalEncodingLayer(ninp, dropout)),
    ]
    for _ in range(ndecoder):
        layers.append(LazyModule(lambda: transformer_lm.TransformerDecoderLayer(ninp, nhead, nhid, dropout)))
    layers.append(LazyModule(lambda: transformer_lm.LinearLayer(ninp, vocab_size, initrange)))
    return layers
Mandeep Singh Baines's avatar
Mandeep Singh Baines committed
81
82


Tom Birch's avatar
Tom Birch committed
83
84
85
86
87
88
89
90
91
92
93
94
def get_tensors_by_size_bucket():
    """Count live CUDA tensors, keyed by ``(*shape, element_size)``."""
    buckets = defaultdict(int)
    for candidate in gc.get_objects():
        if isinstance(candidate, torch.Tensor) and candidate.device.type == "cuda":
            buckets[tuple(candidate.size()) + (candidate.element_size(),)] += 1
    return buckets


95
def log_number_of_parameters(model):
    """Log the number of parameters held by this rank (and the group total)."""

    # Scalar parameter count of the locally held partition.
    num_params = reduce(operator.add, (reduce(operator.mul, x.size()) for x in model.parameters()))
    if hasattr(model, "group"):
        # Pipelined model: each rank holds a slice, so sum across the group.
        total = torch.Tensor([num_params])
        if torch.cuda.is_available():
            total = total.cuda()
        torch.distributed.all_reduce(total, group=model.group)
        logging.debug(
            f"training model, #params = {num_params}, group: {model.group.rank()}, grank:"
            f" {torch.distributed.get_rank()}, sizes {model.group.size()}"
        )
        torch.distributed.barrier()
        # Only the first rank of the pipe group reports the aggregated total.
        if model.group.rank() == 0:
            logging.debug(f"total #prams = {total.item()}")
    else:
        # Non-pipelined model: the local count is the whole model.
        logging.debug(f"training model, #params = {num_params}")
Tom Birch's avatar
Tom Birch committed
112
113


114
115
116
def get_device(model, index):
    """Return the device of partition ``index`` of ``model`` (CPU fallback)."""
    target = model.module if isinstance(model, DDP) else model

    if not torch.cuda.is_available():
        return torch.device("cpu")
    if hasattr(target, "devices"):
        # Pipe models expose one device per partition.
        return target.devices[index]
    return torch.cuda.current_device()
Tom Birch's avatar
Tom Birch committed
124

125

126
def get_fake_dataloader(lm_dataloader_len, args):
    """Return a dataset of ``lm_dataloader_len`` identical zero-filled batches."""
    dummy_batch = {"input": torch.zeros(args.batch_size)}

    class FakeDataset:
        """Fixed-length dataset that always yields the same dummy batch."""

        def __len__(self):
            return lm_dataloader_len

        def __getitem__(self, index):
            return dummy_batch

    return FakeDataset()


anj-s's avatar
anj-s committed
139
def train(model_config, model, benchmark_config, model_specs, args):
    """Run one training epoch over the LM dataloader and return throughput.

    Returns ``(words_per_second, final_loss)`` on the last pipeline rank (or in
    single-process mode); ``(0.0, 0.0)`` on intermediate ranks.

    Raises:
        RuntimeError: if fewer than two batches were processed (throughput
            cannot be timed), or if the forward pass fails on any rank.
    """
    lm_dataloader, _, _ = model_config["data"]
    criterion = benchmark_config["criterion"]
    vocab_size = model_specs["vocab_size"]
    optimizer = model_config["optimizer"]

    model.train()
    log_number_of_parameters(model)

    total_loss = 0.0
    # NOTE(review): word_counter and bptt below are never used — confirm before removing.
    word_counter = 0

    # ``optimizer`` arrives as a factory; instantiate it on this model's params.
    optimizer = optimizer(model.parameters())

    pipe_group = model.group if hasattr(model, "group") else None

    # TODO(anj-s): Avoid sending fake data to all replicas except the first and last one.
    device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
    # Intermediate pipeline ranks never consume real data; feed them synthetic
    # batches of the right shape instead.
    if pipe_group and pipe_group.rank() != 0 and pipe_group.rank() != (pipe_group.size() - 1):
        lm_dataloader, _, _ = get_synthetic_dataloaders(args, device, benchmark_config, model_specs)

    total_tokens = 0
    total_tokens_per_log_interval = 0
    bptt = 2
    start_time = time.time()
    epoch_start_time = 0.0

    def get_batch(source):
        # Shift by one token: predict position t+1 from position t.
        seq_len = len(source) - 1
        data = source[0:seq_len]
        target = source[1 : 1 + seq_len]
        return data, target

    for i, batch in enumerate(lm_dataloader):
        # The first batch is treated as warm-up; timing starts at batch 1.
        if i == 1:
            epoch_start_time = time.time()

        source, target = get_batch(batch)
        if args.max_batch and i > args.max_batch:
            break

        # Token accounting also skips the warm-up batch.
        if i > 0:
            total_tokens += source.numel()

        optimizer.zero_grad()
        try:
            if pipe_group is None or pipe_group.rank() == 0:
                # First stage (or non-pipelined): move the input to the first partition.
                tmp = source.to(get_device(model, 0))
                output = model(tmp)
            else:
                # Later stages receive activations over the pipe; input content is ignored.
                output = model(source)
        except Exception as e:
            raise RuntimeError(f"training failed on {torch.distributed.get_rank()}") from e

        if pipe_group is None or pipe_group.rank() == pipe_group.size() - 1:
            # Last stage computes the loss and starts the backward pass.
            target = target.to(get_device(model, -1))
            output = output.to(target.device)

            loss = criterion(output.view(-1, vocab_size), target.view(-1))
            loss.backward()
            del target
        else:
            # Non-last stages participate in backward via the pipe helper.
            model.back_helper(output)

        del output

        torch.nn.utils.clip_grad_value_(model.parameters(), model_specs["clip_value"])
        optimizer.step()

        if pipe_group is None or pipe_group.rank() == pipe_group.size() - 1:
            total_loss += loss.item()
            log_interval = 1
            total_tokens_per_log_interval += source.numel()
            if i % log_interval == 0 and i > 0:
                cur_loss = total_loss / log_interval
                elapsed = time.time() - start_time
                # In multiprocess mode only the last global rank logs.
                if not args.multiprocess or dist.get_rank() == dist.get_world_size() - 1:
                    logging.debug(
                        "| batch {:5d} | wps {:5.2f} | loss {:5.2f} | ppl {:8.2f}".format(
                            i, total_tokens_per_log_interval / elapsed, cur_loss, math.exp(cur_loss)
                        )
                    )
                total_tokens_per_log_interval = 0
                total_loss = 0
                start_time = time.time()

    # epoch_start_time stays 0.0 if we never reached batch 1.
    if epoch_start_time != 0:
        wps = total_tokens / (time.time() - epoch_start_time)
    else:
        raise RuntimeError(
            "Unable to benchmark on a single batch. Increase the size " " of the dataset and rerun the benchmark."
        )

    if not args.multiprocess or dist.get_rank() == dist.get_world_size() - 1:
        return wps, loss.item()
    else:
        return 0.0, 0.0
235

Mandeep Singh Baines's avatar
Mandeep Singh Baines committed
236

237
238
# TODO(anj-s): Add an option for users to be able to benchmark evaluate.
def evaluate(eval_model, data_source, criterion, ntokens):
    """Return the average per-token loss of ``eval_model`` on ``data_source``."""
    # TODO(anj-s): Move this to the benchmark config if we want to benchmark evaluation.
    bptt = 35

    def slice_batch(source, start):
        # Window of up to ``bptt`` tokens, with targets shifted by one.
        length = min(bptt, len(source) - 1 - start)
        return source[start : start + length], source[start + 1 : start + 1 + length].view(-1)

    eval_model.eval()
    accumulated = 0.0
    with torch.no_grad():
        for offset in range(0, data_source.size(0) - 1, bptt):
            inputs, targets = slice_batch(data_source, offset)
            predictions = eval_model(inputs).to(targets.device)
            accumulated += len(inputs) * criterion(predictions.view(-1, ntokens), targets).item()
    return accumulated / (len(data_source) - 1)


def get_number_of_words(data):
    """Number of tokens in a (sequence, batch) shaped tensor."""
    seq_len, batch = data.size()[0], data.size()[1]
    return seq_len * batch


264
def verify_peak_memory(rank, golden_config, std_dev):
    """Check that peak CUDA memory on ``rank`` stays under the golden bound.

    Raises:
        RuntimeError: if peak allocated bytes reach or exceed
            ``golden_ref * std_dev``.
    """
    current_device_usage = torch.cuda.memory_stats(rank)["allocated_bytes.all.peak"]
    logging.debug("Peak allocated bytes on cuda:0: {:1d}".format(current_device_usage))
    golden_ref = golden_config["peak_mem_usage"][rank]
    if not current_device_usage < golden_ref * std_dev:
        # Bug fix: the original message claimed usage was "less than" the
        # reference (and was missing a space, "whichis"), but this branch
        # fires when usage is too HIGH.
        raise RuntimeError(
            "Peak memory usage for cuda device {:d} is {:d} which "
            "exceeds the golden reference value of {:d}".format(rank, current_device_usage, golden_ref)
        )
275
276


277
278
279
280
281
282
283
def verify_lm_run(wps, golden_config, args):
    """Verify that words per second for a given benchmark run matches the golden data."""

    # Throughput is only checked on the last rank when running multiprocess.
    if not args.multiprocess or dist.get_rank() == dist.get_world_size() - 1:
        logging.info("Throughput(wps) is {:.2f}.".format(wps))
        # Accept anything within three standard deviations below the mean of
        # five golden runs.
        threshold = golden_config["avg_wps"] - (3 * golden_config["std_dev_wps"])
        if not wps > threshold:
            raise RuntimeError(
                "Throughput(wps):{:.2f} is below the golden threshold of an "
                "average value of {:.2f} and standard dev of {:.2f}.".format(
                    wps, golden_config["avg_wps"], golden_config["std_dev_wps"]
                )
            )

    # Memory is checked on every rank (multiprocess) or all four local devices.
    if args.multiprocess:
        verify_peak_memory(dist.get_rank(), golden_config, 1.5)
    else:
        for device_rank in range(4):
            verify_peak_memory(device_rank, golden_config, 1.1)

299

anj-s's avatar
anj-s committed
300
def benchmark_language_model(model_config, model, benchmark_config, model_specs, args):
    """Run one training epoch, log timing/memory, and verify against golden data."""
    golden_config = get_golden_config(args.model_name, args)
    epoch = benchmark_config["epochs"]
    start_time = time.time()
    # Only the last rank prints the epoch banner.
    if dist.get_rank() == dist.get_world_size() - 1:
        logging.debug("-" * 110)
        logging.debug("| start of epoch {:1d}".format(epoch))
        logging.debug("-" * 110)
    wps, loss = train(model_config, model, benchmark_config, model_specs, args)
    elapsed_time = time.time() - start_time
    if dist.get_rank() == dist.get_world_size() - 1:
        logging.debug("-" * 110)
        logging.debug("| end of epoch {:1d} | time: {:5.2f}s | train loss {:5.2f} ".format(epoch, elapsed_time, loss))
        logging.debug("-" * 110)
        logging.debug("Throughput(wps) is {:.2f}.".format(wps))
    logging.debug(
        "Peak allocated bytes on cuda:{}: {:1d}".format(
            dist.get_rank(), torch.cuda.memory_stats(dist.get_rank())["allocated_bytes.all.peak"]
        )
    )

    # Golden data only exists for the 4-partition configuration.
    if len(model.balance) == 4:
        if args.model_name == "lm":
            verify_lm_run(wps, golden_config, args)
        else:
            # Bug fix: the original used "%"-formatting on a string with no
            # placeholder, which raised TypeError instead of the intended error.
            raise RuntimeError(f"Unrecognized args.model_name {args.model_name}")
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340


def generate_balance(num_devices, num_layers):
    """Split ``num_layers`` across ``num_devices`` as evenly as possible.

    Earlier devices take the ceiling share when the split is uneven, so the
    returned list always sums to ``num_layers``.
    """
    balance = []
    remaining = num_layers
    for device in range(num_devices):
        share = remaining / (num_devices - device)
        layers_here = int(share) if share.is_integer() else math.ceil(share)
        balance.append(layers_here)
        remaining -= layers_here
    return balance

Mandeep Singh Baines's avatar
Mandeep Singh Baines committed
341

342
def get_synthetic_dataloaders(args, device, benchmark_config, model_specs):
    """Returns dataloader for synthetic data.

    Raises:
        RuntimeError: if ``args.model_name`` is not a known model.
    """
    if args.model_name == "lm":
        return get_synthetic_wikitext2_dataloaders(args, benchmark_config, model_specs)
    else:
        # Bug fix: "model_mame" typo and "%"-formatting with no placeholder
        # (which raised TypeError instead of this RuntimeError).
        raise RuntimeError(f"Unrecognized args.model_name {args.model_name}")


anj-s's avatar
anj-s committed
351
def get_real_dataloaders(args, device, benchmark_config, model_specs):
    """Returns dataloaders for real data.

    Side effect: records the tokenizer vocabulary size in
    ``model_specs["vocab_size"]``, which is only known once the data is built.

    Raises:
        RuntimeError: if ``args.model_name`` is not a known model.
    """
    if args.model_name == "lm":
        data = get_real_wikitext2_dataloaders(args, benchmark_config, model_specs)
        ntokens, train_dataloader, valid_dataloader, test_dataloader = data
        model_specs["vocab_size"] = ntokens
        return train_dataloader, valid_dataloader, test_dataloader
    else:
        # Bug fix: "model_mame" typo and "%"-formatting with no placeholder
        # (which raised TypeError instead of this RuntimeError).
        raise RuntimeError(f"Unrecognized args.model_name {args.model_name}")


anj-s's avatar
anj-s committed
363
def create_model_config(args, benchmark_config=None, model_specs=None):
    """Return a dict with the given model, dataset and optimizer."""
    device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

    # Choose the data source requested on the command line.
    dataloader_fn = get_synthetic_dataloaders if args.use_synthetic_data else get_real_dataloaders

    data = dataloader_fn(args, device, benchmark_config, model_specs)
    model, optimizer = get_model_and_optimizer(args, device, benchmark_config, model_specs)
    return {"model": model, "optimizer": optimizer, "data": data}
Tom Birch's avatar
Tom Birch committed
380
381


382
383
384
def create_benchmark_config(model_name):
    """Return a dict with configurations required for benchmarking `model_name` model.

    Raises:
        RuntimeError: if ``model_name`` is not a known model.
    """
    if model_name == "lm":
        return lm_wikitext2.get_benchmark_config()
    else:
        # Bug fix: the original read the global ``args`` (NameError when this
        # function is used before argument parsing) and used "%"-formatting
        # with no placeholder; report the local parameter instead.
        raise RuntimeError(f"Unrecognized model_name {model_name}")


anj-s's avatar
anj-s committed
391
392
393
394
395
396
397
398
399
def get_model_specs(model_name):
    """Return a dict with configurations required for configuring `model_name` model.

    Raises:
        RuntimeError: if ``model_name`` is not a known model.
    """
    if model_name == "lm":
        return lm_wikitext2.get_model_config()
    else:
        # Bug fix: the original read the global ``args`` and used
        # "%"-formatting with no placeholder; report the parameter instead.
        raise RuntimeError(f"Unrecognized model_name {model_name}")


400
def get_golden_config(model_name, args):
    """Return a dict with the golden data for throughput and memory usage.

    Raises:
        RuntimeError: if ``model_name`` is not a known model.
    """
    if model_name == "lm":
        return lm_wikitext2.get_golden_real_stats(args.multiprocess)
    else:
        # Bug fix: "model_mame" typo and "%"-formatting with no placeholder.
        raise RuntimeError(f"Unrecognized model_name {model_name}")
407
408
409
410
411


def benchmark_single_process(args):
    """Benchmark a given model using a single process and multiple devices."""

    # Single-member gloo group (rank 0 of world size 1) on the local TCP port.
    init_method_pgroup = "tcp://localhost:{}".format(MPI_PORT)
    torch.distributed.init_process_group(backend="gloo", rank=0, world_size=1, init_method=init_method_pgroup)

    num_devices = torch.cuda.device_count() if torch.cuda.is_available() else 1
    assert num_devices > 0
    init_random_seed(0)

    benchmark_config = create_benchmark_config(args.model_name)
    model_specs = get_model_specs(args.model_name)
    model_config = create_model_config(args, benchmark_config=benchmark_config, model_specs=model_specs)
    model = model_config["model"]

    # Partition the layers over at most 4 devices and wrap in a single-process Pipe.
    balance = generate_balance(min(num_devices, 4), len(model))
    pipe_model = Pipe(model, balance, chunks=args.chunks, checkpoint=args.checkpoint)
    # Drop references to the unpartitioned model so only the pipe holds it.
    del model
    del model_config["model"]

    if args.dry_run:
        # Sample run without verifying against golden throughput/memory data.
        train(model_config, pipe_model, benchmark_config, model_specs, args)
    else:
        benchmark_language_model(model_config, pipe_model, benchmark_config, model_specs, args)
Tom Birch's avatar
Tom Birch committed
433
434
435
436


def run_mp_worker(args, available_workers):
    """Build the model on this worker and run it under MultiProcessPipe."""
    # NOTE(review): ``available_workers`` is unused — confirm before removing.

    benchmark_config = create_benchmark_config(args.model_name)
    model_specs = get_model_specs(args.model_name)
    model_config = create_model_config(args, benchmark_config=benchmark_config, model_specs=model_specs)
    model = model_config["model"]

    # One balance entry per pipeline-parallel rank.
    balance = generate_balance(get_pipeline_parallel_group().size(), len(model))
    pipe_model = MultiProcessPipe(
        model,
        balance,
        chunks=args.chunks,
        worker_map=get_worker_map(),
        input_device=torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu"),
        checkpoint=args.checkpoint,
        # TODO(anj-s): Do we need to comment this out? loss_fn=benchmark_config["criterion"],
    )
    if torch.cuda.is_available():
        pipe_model = pipe_model.cuda()

    if args.dry_run:
        # Sample run without verifying against golden throughput/memory data.
        train(model_config, pipe_model, benchmark_config, model_specs, args)
    else:
        benchmark_language_model(model_config, pipe_model, benchmark_config, model_specs, args)
Tom Birch's avatar
Tom Birch committed
459
460
461
462
463
464
465
466
467
468
469
470
471
472


def run_worker(rank, world_size, args):
    """Entry point for one RPC-based worker process."""
    # NOTE(review): relies on args.world_size and args.rank_base, which the
    # parser in this file does not define — confirm this path is still used.
    if args.world_size != 0:
        world_size = args.world_size
    dist_init(rank + args.rank_base, world_size, hostname=args.host)
    initialize_model_parallel(1, world_size)
    init_random_seed(0)
    run_mp_worker(args, world_size)

    # Tear down RPC before destroying the process group.
    rpc.shutdown()
    torch.distributed.destroy_process_group()


473
474
475
476
477
478
479
def benchmark_multiprocess(rank, world_size, args):
    """Per-rank entry point spawned by ``mp.spawn`` for the multiprocess benchmark."""

    init_method_pgroup = "tcp://localhost:{}".format(MPI_PORT)
    # TODO(anj-s): Add regression benchmarks for nccl as well.
    torch.distributed.init_process_group(
        backend="gloo", rank=rank, world_size=world_size, init_method=init_method_pgroup
    )

    # Round-robin ranks over the visible GPUs.
    torch.cuda.set_device(rank % torch.cuda.device_count())
    # TODO(anj-s): Move to TensorPipeRpcBackendOptions.
    rpc.init_rpc(
        f"Test{rank}",
        rank=rank,
        world_size=world_size,
        backend=rpc.BackendType.PROCESS_GROUP,
        rpc_backend_options=rpc.ProcessGroupRpcBackendOptions(
            rpc_timeout=20, init_method="tcp://localhost:{}".format(RPC_PORT)
        ),
    )
    # Pure pipeline parallelism: model-parallel size 1, pipeline size == world.
    initialize_model_parallel(1, world_size)
    init_random_seed(0)
    run_mp_worker(args, world_size)

    # Tear down RPC before destroying the process group.
    rpc.shutdown()
    torch.distributed.destroy_process_group()


parser = argparse.ArgumentParser(description="benchmark")
# Bug fix: help text said "Runs single process benchmarks." for the flag that
# ENABLES multiprocess mode.
parser.add_argument("--multiprocess", action="store_true", help="Runs multiprocess benchmarks.")
parser.add_argument("--host", "-o", type=str, default="localhost", help="hostname")
parser.add_argument("--chunks", type=int, default=1, help="number of microbatches per batch")
parser.add_argument("--batch-size", type=int, default=8, help="size of a batch")
parser.add_argument(
    "--checkpoint", default="never", choices=["always", "except_last", "never"], help="Checkpointing strategy for pipe"
)
# Bug fix: help text was copy-pasted from an unrelated layer-count flag.
parser.add_argument(
    "--lazy-construction", action="store_true", default=False, help="Construct the model lazily via LazyModule wrappers"
)
parser.add_argument("--max-batch", type=int, default=4, help="Max number of batches")
parser.add_argument("--use_synthetic_data", action="store_true", help="Uses synthetic data for running benchmarks.")
parser.add_argument("--dry_run", action="store_true", help="Run a sample training run without regression testing.")
parser.add_argument(
    # TODO(anj-s): In the process of adding more models and hence the requirement for a flag.
    "--model_name",
    default="lm",
    help="Language Model(LM) used to benchmark nn.pipe.",
)
parser.add_argument("--debug", action="store_true", default=False, help="Display additional debug information")

if __name__ == "__main__":
    args = parser.parse_args()
    logging.basicConfig(level=logging.INFO if not args.debug else logging.DEBUG)

    if not args.multiprocess:
        logging.info(f"Running single process benchmark with args: {args}")
        benchmark_single_process(args)
    else:
        # One process per visible GPU (at least one).
        world_size = max(torch.cuda.device_count(), 1)
        logging.info(f"Running multiprocess benchmark with args: {args}")
        mp.spawn(benchmark_multiprocess, args=(world_size, args), nprocs=world_size, join=True)