pipe.py 19.1 KB
Newer Older
Mandeep Singh Baines's avatar
Mandeep Singh Baines committed
1
2
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.

Tom Birch's avatar
Tom Birch committed
3
import argparse
4
5
6
from collections import defaultdict
from functools import reduce
import gc
7
import logging
Mandeep Singh Baines's avatar
Mandeep Singh Baines committed
8
import math
9
import operator
Mandeep Singh Baines's avatar
Mandeep Singh Baines committed
10
11
import time

12
13
14
from datasets.wikitext2_data import get_real_dataloaders as get_real_wikitext2_dataloaders
from datasets.wikitext2_data import get_synthetic_dataloaders as get_synthetic_wikitext2_dataloaders
from models import transformer_lm
15
import numpy as np
Mandeep Singh Baines's avatar
Mandeep Singh Baines committed
16
import torch
17
import torch.distributed as dist
Tom Birch's avatar
Tom Birch committed
18
19
from torch.distributed import rpc
import torch.multiprocessing as mp
20
from torch.nn.parallel import DistributedDataParallel as DDP
21
from torch.optim import Adam
Mandeep Singh Baines's avatar
Mandeep Singh Baines committed
22

23
from benchmarks.golden_configs.lm_wikitext2 import Pipe as lm_wikitext2
24
from fairscale.nn import Pipe
Tom Birch's avatar
Tom Birch committed
25
from fairscale.nn.model_parallel import initialize_model_parallel
26
from fairscale.nn.model_parallel.initialize import get_pipeline_parallel_group
27
from fairscale.nn.pipe import LazyModule, MultiProcessPipe
28
from fairscale.utils.testing import dist_init, get_worker_map
29

30
31
32
# Fixed localhost ports: MPI_PORT seeds torch.distributed process-group init,
# RPC_PORT seeds the torch RPC agent used by the multiprocess pipe.
MPI_PORT = 29500
RPC_PORT = 29501

33

Tom Birch's avatar
Tom Birch committed
34
35
36
37
def init_random_seed(seed: int):
    """Seed every RNG the benchmark relies on so runs are reproducible."""
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
Tom Birch's avatar
Tom Birch committed
39
40


anj-s's avatar
anj-s committed
41
def get_model_and_optimizer(args, device, benchmark_config, model_config):
    """Return instantiated model and optimizer function.

    Args:
        args: parsed CLI namespace; ``args.model_name`` selects the model.
        device: torch device the model is instantiated on.
        benchmark_config: benchmark hyperparameters; reads ``"lr"``.
        model_config: model hyperparameters, forwarded to model construction.

    Returns:
        Tuple ``(model, optimizer_factory)`` where the factory maps a parameter
        iterable to an Adam optimizer with the configured learning rate.

    Raises:
        RuntimeError: if ``args.model_name`` is not a recognized model.
    """
    if args.model_name == "lm":
        model = get_lm_model(args, device, model_config)
    else:
        # Fail fast: the old code fell through and hit an UnboundLocalError
        # on `model` at the return statement for unknown model names.
        raise RuntimeError(f"Unrecognized args.model_name {args.model_name}")

    lr = benchmark_config["lr"]

    def make_adam(params):
        return Adam(params, lr=lr)

    optimizer = make_adam
    return model, optimizer


def get_lm_model(args, device, config):
    """Get language model(based on GPT-2) used for sequence prediction."""
    ninp = config["ninp"]
    nhead = config["nhead"]
    initrange = config["initrange"]
    dropout = config["dropout"]
    vocab_size = config["vocab_size"]
    nhid = config["nhid"]
    ndecoder = config["num_decoder_layers"]

    if not args.lazy_construction:
        # Eager path: build the full module and move it onto the target device.
        return transformer_lm.TransformerLM(vocab_size, ninp, nhead, nhid, dropout, initrange, ndecoder).to(device)

    # Lazy path: return a list of LazyModule thunks so each pipeline partition
    # constructs its layers only where they will actually live.
    layers = [
        LazyModule(lambda: transformer_lm.EmbeddingLayer(vocab_size, ninp, initrange)),
        LazyModule(lambda: transformer_lm.PositionalEncodingLayer(ninp, dropout)),
    ]
    layers.extend(
        LazyModule(lambda: transformer_lm.TransformerDecoderLayer(ninp, nhead, nhid, dropout))
        for _ in range(ndecoder)
    )
    layers.append(LazyModule(lambda: transformer_lm.LinearLayer(ninp, vocab_size, initrange)))
    return layers
Mandeep Singh Baines's avatar
Mandeep Singh Baines committed
81
82


Tom Birch's avatar
Tom Birch committed
83
84
85
86
87
88
89
90
91
92
93
94
def get_tensors_by_size_bucket():
    """Count live CUDA tensors, bucketed by (*shape, element_size)."""
    size_buckets = defaultdict(int)
    live_cuda_tensors = (
        obj
        for obj in gc.get_objects()
        if isinstance(obj, torch.Tensor) and obj.device.type == "cuda"
    )
    for tensor in live_cuda_tensors:
        size_buckets[tuple(tensor.size()) + (tensor.element_size(),)] += 1
    return size_buckets


95
def log_number_of_parameters(model):
    """Log the model's parameter count; all-reduce across the pipe group when present."""
    num_params = reduce(operator.add, (reduce(operator.mul, x.size()) for x in model.parameters()))

    # Single-process models carry no `group`: just log the local count.
    if not hasattr(model, "group"):
        logging.debug(f"training model, #params = {num_params}")
        return

    total = torch.Tensor([num_params])
    if torch.cuda.is_available():
        total = total.cuda()
    torch.distributed.all_reduce(total, group=model.group)
    logging.debug(
        f"training model, #params = {num_params}, group: {model.group.rank()}, grank:"
        f" {torch.distributed.get_rank()}, sizes {model.group.size()}"
    )
    torch.distributed.barrier()
    if model.group.rank() == 0:
        logging.debug(f"total #prams = {total.item()}")
Tom Birch's avatar
Tom Birch committed
112
113


114
115
116
def get_device(model, index):
    """Return the device holding partition `index` of `model` (CPU when CUDA is absent)."""
    if isinstance(model, DDP):
        model = model.module

    if not torch.cuda.is_available():
        return torch.device("cpu")

    # Pipe models expose a per-partition device list; otherwise fall back to
    # the currently selected CUDA device.
    if hasattr(model, "devices"):
        return model.devices[index]
    return torch.cuda.current_device()
Tom Birch's avatar
Tom Birch committed
124

125

126
def get_fake_dataloader(lm_dataloader_len, args):
    """Return a map-style dataset of fixed length serving zero-filled batches."""
    placeholder = {"input": torch.zeros(args.batch_size)}

    class FakeDataset:
        """Serves the same placeholder batch for every index."""

        def __getitem__(self, index):
            return placeholder

        def __len__(self):
            return lm_dataloader_len

    return FakeDataset()


anj-s's avatar
anj-s committed
139
def train(model_config, model, benchmark_config, model_specs, args):
    """Run one pass over the training dataloader; return (words/sec, final loss).

    Handles both single-process pipes (pipe_group is None) and multi-process
    pipes: only the first pipeline rank feeds real inputs and only the last
    rank computes loss/metrics. Returns (0.0, 0.0) on non-final ranks in
    multiprocess mode.
    """
    lm_dataloader, _, _ = model_config["data"]
    criterion = benchmark_config["criterion"]
    vocab_size = model_specs["vocab_size"]
    optimizer = model_config["optimizer"]

    model.train()
    log_number_of_parameters(model)

    total_loss = 0.0
    word_counter = 0  # NOTE(review): never updated or read below — appears vestigial.

    # model_config["optimizer"] is a factory; instantiate it over the params.
    optimizer = optimizer(model.parameters())

    # Non-None only for multi-process pipe models.
    pipe_group = model.group if hasattr(model, "group") else None

    # TODO(anj-s): Avoid sending fake data to all replicas except the first and last one.
    device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
    # Middle pipeline ranks never consume the real batches; feed synthetic data.
    if pipe_group and pipe_group.rank() != 0 and pipe_group.rank() != (pipe_group.size() - 1):
        lm_dataloader, _, _ = get_synthetic_dataloaders(args, device, benchmark_config, model_specs)

    total_tokens = 0
    total_tokens_per_log_interval = 0
    bptt = 2  # NOTE(review): unused in this function — TODO confirm it can be dropped.
    start_time = time.time()
    epoch_start_time = 0.0

    def get_batch(source):
        # Shift by one token: predict source[t+1] from source[t].
        seq_len = len(source) - 1
        data = source[0:seq_len]
        target = source[1 : 1 + seq_len]
        return data, target

    for i, batch in enumerate(lm_dataloader):
        # Start the epoch timer at batch 1 so warm-up cost is excluded from wps.
        if i == 1:
            epoch_start_time = time.time()

        source, target = get_batch(batch)
        if args.max_batch and i > args.max_batch:
            break

        if i > 0:
            total_tokens += source.numel()

        optimizer.zero_grad()
        try:
            # Only the first pipeline rank (or a single-process model) moves
            # the input onto the model's first device.
            if pipe_group is None or pipe_group.rank() == 0:
                tmp = source.to(get_device(model, 0))
                output = model(tmp)
            else:
                output = model(source)
        except Exception as e:
            raise RuntimeError(f"training failed on {torch.distributed.get_rank()}") from e

        # Loss and backward run on the last pipeline stage only.
        if pipe_group is None or pipe_group.rank() == pipe_group.size() - 1:
            target = target.to(get_device(model, -1))
            output = output.to(target.device)
            loss = criterion(output.view(-1, vocab_size), target.view(-1))
            loss.backward()
            del target
        else:
            # Non-final stages participate in backward through the pipe helper.
            model.back_helper(output)

        del output

        torch.nn.utils.clip_grad_value_(model.parameters(), model_specs["clip_value"])
        optimizer.step()

        # Metrics/logging only where the loss exists (last stage / single process).
        if pipe_group is None or pipe_group.rank() == pipe_group.size() - 1:
            total_loss += loss.item()
            log_interval = 1
            total_tokens_per_log_interval += source.numel()
            if i % log_interval == 0 and i > 0:
                cur_loss = total_loss / log_interval
                elapsed = time.time() - start_time
                if not args.multiprocess or dist.get_rank() == dist.get_world_size() - 1:
                    logging.debug(
                        "| batch {:5d} | wps {:5.2f} | loss {:5.2f} | ppl {:8.2f}".format(
                            i, total_tokens_per_log_interval / elapsed, cur_loss, math.exp(cur_loss)
                        )
                    )
                total_tokens_per_log_interval = 0
                total_loss = 0
                start_time = time.time()

    # epoch_start_time stays 0.0 when fewer than two batches ran; wps would be meaningless.
    if epoch_start_time != 0:
        wps = total_tokens / (time.time() - epoch_start_time)
    else:
        raise RuntimeError(
            "Unable to benchmark on a single batch. Increase the size " " of the dataset and rerun the benchmark."
        )
    if not args.multiprocess or dist.get_rank() == dist.get_world_size() - 1:
        return wps, loss.item()
    else:
        return 0.0, 0.0
234

Mandeep Singh Baines's avatar
Mandeep Singh Baines committed
235

236
237
# TODO(anj-s): Add an option for users to be able to benchmark evaluate.
def evaluate(eval_model, data_source, criterion, ntokens):
    """Return the average per-position loss of `eval_model` over `data_source`."""
    eval_model.eval()
    cumulative_loss = 0.0
    # TODO(anj-s): Move this to the benchmark config if we want to benchmark evaluation.
    bptt = 35

    with torch.no_grad():
        for start in range(0, data_source.size(0) - 1, bptt):
            # Window of at most `bptt` positions; targets are shifted by one token.
            seq_len = min(bptt, len(data_source) - 1 - start)
            data = data_source[start : start + seq_len]
            targets = data_source[start + 1 : start + 1 + seq_len].view(-1)
            output = eval_model(data).to(targets.device)
            output_flat = output.view(-1, ntokens)
            cumulative_loss += len(data) * criterion(output_flat, targets).item()
    return cumulative_loss / (len(data_source) - 1)


def get_number_of_words(data):
    """Return the total token count of a 2-D (sequence, batch) tensor."""
    seq_dim, batch_dim = data.size()[0], data.size()[1]
    return seq_dim * batch_dim


263
def verify_peak_memory(rank, golden_config, std_dev):
    """Check peak CUDA memory on device `rank` against the golden reference.

    Args:
        rank: CUDA device index whose memory stats are inspected.
        golden_config: dict containing "peak_mem_usage" indexed by rank.
        std_dev: multiplicative slack applied to the golden reference.

    Raises:
        RuntimeError: if peak usage meets or exceeds ``golden_ref * std_dev``.
    """
    logging.debug(
        "Peak allocated bytes on cuda:0: {:1d}".format(torch.cuda.memory_stats(rank)["allocated_bytes.all.peak"])
    )
    current_device_usage = torch.cuda.memory_stats(rank)["allocated_bytes.all.peak"]
    golden_ref = golden_config["peak_mem_usage"][rank]
    if not current_device_usage < golden_ref * std_dev:
        # The check fires when usage EXCEEDS the bound; the old message said
        # "less than" (and was missing a space), which inverted the meaning.
        raise RuntimeError(
            "Peak memory usage for cuda device {:d} is {:d} which "
            "is greater than golden reference value of {:d}".format(rank, current_device_usage, golden_ref)
        )
274
275


276
277
278
279
280
281
282
def verify_lm_run(wps, golden_config, args):
    """Verify that words per second for a given benchmark run matches the golden data."""
    # wps is only meaningful on the last rank of a multiprocess pipe.
    if not args.multiprocess or dist.get_rank() == dist.get_world_size() - 1:
        logging.info("Throughput(wps) is {:.2f}.".format(wps))
        # Require throughput within 3 standard deviations below the average
        # of five golden runs.
        threshold = golden_config["avg_wps"] - 3 * golden_config["std_dev_wps"]
        if not wps > threshold:
            raise RuntimeError(
                "Throughput(wps):{:.2f} is below the golden threshold of an "
                "average value of {:.2f} and standard dev of {:.2f}.".format(
                    wps, golden_config["avg_wps"], golden_config["std_dev_wps"]
                )
            )

    if args.multiprocess:
        verify_peak_memory(dist.get_rank(), golden_config, 1.5)
    else:
        for device_id in range(4):
            verify_peak_memory(device_id, golden_config, 1.1)

298

anj-s's avatar
anj-s committed
299
def benchmark_language_model(model_config, model, benchmark_config, model_specs, args):
    """Train for one epoch, log timing/memory, and verify against golden data.

    Args:
        model_config: dict with "data" and "optimizer" entries.
        model: pipe-wrapped model to benchmark; must expose ``balance``.
        benchmark_config: benchmark hyperparameters; reads "epochs".
        model_specs: model hyperparameters forwarded to ``train``.
        args: parsed CLI namespace.

    Raises:
        RuntimeError: if ``args.model_name`` is not a recognized model.
    """
    golden_config = get_golden_config(args.model_name, args)
    epoch = benchmark_config["epochs"]
    start_time = time.time()
    # Only the last rank prints the epoch banner to avoid duplicated output.
    if dist.get_rank() == dist.get_world_size() - 1:
        logging.debug("-" * 110)
        logging.debug("| start of epoch {:1d}".format(epoch))
        logging.debug("-" * 110)
    wps, loss = train(model_config, model, benchmark_config, model_specs, args)
    elapsed_time = time.time() - start_time
    if dist.get_rank() == dist.get_world_size() - 1:
        logging.debug("-" * 110)
        logging.debug("| end of epoch {:1d} | time: {:5.2f}s | train loss {:5.2f} ".format(epoch, elapsed_time, loss))
        logging.debug("-" * 110)
        logging.debug("Throughput(wps) is {:.2f}.".format(wps))
    logging.debug(
        "Peak allocated bytes on cuda:{}: {:1d}".format(
            dist.get_rank(), torch.cuda.memory_stats(dist.get_rank())["allocated_bytes.all.peak"]
        )
    )

    # Golden data was gathered on 4-way pipelines; skip verification otherwise.
    if len(model.balance) == 4:
        if args.model_name == "lm":
            verify_lm_run(wps, golden_config, args)
        else:
            # Old code used '"…" % args.model_name' with no %s placeholder,
            # which raised TypeError instead of this RuntimeError.
            raise RuntimeError(f"Unrecognized args.model_name {args.model_name}")
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339


def generate_balance(num_devices, num_layers):
    """Split `num_layers` across `num_devices` as evenly as possible.

    Earlier devices absorb the remainder (receive the ceiling share) so the
    partition sizes are non-increasing.
    """
    balance = []
    remaining = float(num_layers)
    for device_index in range(num_devices):
        share = remaining / (num_devices - device_index)
        layers_here = int(share) if share.is_integer() else math.ceil(share)
        balance.append(layers_here)
        remaining -= layers_here
    return balance

Mandeep Singh Baines's avatar
Mandeep Singh Baines committed
340

341
def get_synthetic_dataloaders(args, device, benchmark_config, model_specs):
    """Returns dataloader for synthetic data.

    Raises:
        RuntimeError: if ``args.model_name`` is not a recognized model.
    """
    if args.model_name == "lm":
        return get_synthetic_wikitext2_dataloaders(args, benchmark_config, model_specs)
    # Old code used '"…model_mame " % args.model_name' — a typo plus a %-format
    # with no placeholder, which raised TypeError instead of RuntimeError.
    raise RuntimeError(f"Unrecognized args.model_name {args.model_name}")


anj-s's avatar
anj-s committed
350
def get_real_dataloaders(args, device, benchmark_config, model_specs):
    """Returns dataloaders for real data.

    Also writes the dataset's vocabulary size into ``model_specs["vocab_size"]``.

    Raises:
        RuntimeError: if ``args.model_name`` is not a recognized model.
    """
    if args.model_name == "lm":
        data = get_real_wikitext2_dataloaders(args, benchmark_config, model_specs)
        ntokens, train_dataloader, valid_dataloader, test_dataloader = data
        # The real corpus determines the vocabulary; propagate it to the specs.
        model_specs["vocab_size"] = ntokens
        return train_dataloader, valid_dataloader, test_dataloader
    # Old code used '"…model_mame " % args.model_name' — a typo plus a %-format
    # with no placeholder, which raised TypeError instead of RuntimeError.
    raise RuntimeError(f"Unrecognized args.model_name {args.model_name}")


anj-s's avatar
anj-s committed
362
def create_model_config(args, benchmark_config=None, model_specs=None):
    """Return a dict with the given model, dataset and optimizer."""
    device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

    # Pick the data source once; both factories share the same signature.
    dataloader_fn = get_synthetic_dataloaders if args.use_synthetic_data else get_real_dataloaders

    data = dataloader_fn(args, device, benchmark_config, model_specs)
    model, optimizer = get_model_and_optimizer(args, device, benchmark_config, model_specs)
    return {
        "model": model,
        "optimizer": optimizer,
        "data": data,
    }
Tom Birch's avatar
Tom Birch committed
379
380


381
382
383
def create_benchmark_config(model_name):
    """Return a dict with configurations required for benchmarking `model_name` model.

    Raises:
        RuntimeError: if `model_name` is not a recognized model.
    """
    if model_name == "lm":
        return lm_wikitext2.get_benchmark_config()
    # Old code referenced the undefined global `args` and used "%" with no
    # placeholder, so this error path crashed with NameError/TypeError.
    raise RuntimeError(f"Unrecognized model name {model_name}")


anj-s's avatar
anj-s committed
390
391
392
393
394
395
396
397
398
def get_model_specs(model_name):
    """Return a dict with configurations required for configuring `model_name` model.

    Raises:
        RuntimeError: if `model_name` is not a recognized model.
    """
    if model_name == "lm":
        return lm_wikitext2.get_model_config()
    # Old code referenced the undefined global `args` and used "%" with no
    # placeholder, so this error path crashed with NameError/TypeError.
    raise RuntimeError(f"Unrecognized model name {model_name}")


399
def get_golden_config(model_name, args):
    """Return a dict with the golden data for throughput and memory usage.

    Raises:
        RuntimeError: if `model_name` is not a recognized model.
    """
    if model_name == "lm":
        return lm_wikitext2.get_golden_real_stats(args.multiprocess)
    # Old code used '"…model_mame " % args.model_name' — a typo plus a %-format
    # with no placeholder, which raised TypeError instead of RuntimeError.
    raise RuntimeError(f"Unrecognized model name {model_name}")
406
407
408
409
410


def benchmark_single_process(args):
    """Benchmark a given model using a single process and multiple devices."""
    # A one-member gloo group is still required because train/verify call
    # dist.get_rank()/get_world_size().
    init_method_pgroup = "tcp://localhost:{}".format(MPI_PORT)
    torch.distributed.init_process_group(backend="gloo", rank=0, world_size=1, init_method=init_method_pgroup)

    num_devices = torch.cuda.device_count() if torch.cuda.is_available() else 1
    assert num_devices > 0
    init_random_seed(0)

    benchmark_config = create_benchmark_config(args.model_name)
    model_specs = get_model_specs(args.model_name)
    model_config = create_model_config(args, benchmark_config=benchmark_config, model_specs=model_specs)
    model = model_config["model"]

    # Partition the model over at most 4 devices and wrap it in an in-process pipe.
    balance = generate_balance(min(num_devices, 4), len(model))
    pipe_model = Pipe(model, balance, chunks=args.chunks, checkpoint=args.checkpoint)
    # Drop the unwrapped model so only the pipe holds the parameters.
    del model
    del model_config["model"]

    if args.dry_run:
        train(model_config, pipe_model, benchmark_config, model_specs, args)
    else:
        benchmark_language_model(model_config, pipe_model, benchmark_config, model_specs, args)
Tom Birch's avatar
Tom Birch committed
432
433
434
435


def run_mp_worker(args, available_workers):
    """Build and benchmark a multi-process pipe model on this worker.

    NOTE(review): `available_workers` is not read in this body — partitioning
    uses the pipeline-parallel group size instead; confirm the parameter is
    intentionally unused.
    """
    benchmark_config = create_benchmark_config(args.model_name)
    model_specs = get_model_specs(args.model_name)
    model_config = create_model_config(args, benchmark_config=benchmark_config, model_specs=model_specs)
    model = model_config["model"]

    # One partition per rank in the pipeline-parallel group.
    balance = generate_balance(get_pipeline_parallel_group().size(), len(model))
    pipe_model = MultiProcessPipe(
        model,
        balance,
        chunks=args.chunks,
        worker_map=get_worker_map(),
        input_device=torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu"),
        checkpoint=args.checkpoint,
        # TODO(anj-s): Do we need to comment this out? loss_fn=benchmark_config["criterion"],
    )
    if torch.cuda.is_available():
        pipe_model = pipe_model.cuda()

    if args.dry_run:
        train(model_config, pipe_model, benchmark_config, model_specs, args)
    else:
        benchmark_language_model(model_config, pipe_model, benchmark_config, model_specs, args)
Tom Birch's avatar
Tom Birch committed
458
459
460
461
462
463
464
465
466
467
468
469
470
471


def run_worker(rank, world_size, args):
    """Initialize distributed state for one worker, run the benchmark, and tear down.

    NOTE(review): reads `args.world_size` and `args.rank_base`, which the
    argument parser in this file does not define — presumably this entry point
    is driven from elsewhere with an extended namespace; confirm before use.
    """
    if args.world_size != 0:
        world_size = args.world_size
    dist_init(rank + args.rank_base, world_size, hostname=args.host)
    initialize_model_parallel(1, world_size)
    init_random_seed(0)
    run_mp_worker(args, world_size)

    rpc.shutdown()
    torch.distributed.destroy_process_group()


472
473
474
475
476
477
478
def benchmark_multiprocess(rank, world_size, args):
    """Entry point for one spawned process: set up process group + RPC, run the benchmark."""
    init_method_pgroup = "tcp://localhost:{}".format(MPI_PORT)
    # TODO(anj-s): Add regression benchmarks for nccl as well.
    torch.distributed.init_process_group(
        backend="gloo", rank=rank, world_size=world_size, init_method=init_method_pgroup
    )
    # Round-robin ranks over the available GPUs.
    torch.cuda.set_device(rank % torch.cuda.device_count())
    # TODO(anj-s): Move to TensorPipeRpcBackendOptions.
    # NOTE(review): the PROCESS_GROUP RPC backend is deprecated in newer torch
    # releases — this matches the TODO above.
    rpc.init_rpc(
        f"Test{rank}",
        rank=rank,
        world_size=world_size,
        backend=rpc.BackendType.PROCESS_GROUP,
        rpc_backend_options=rpc.ProcessGroupRpcBackendOptions(
            rpc_timeout=20, init_method="tcp://localhost:{}".format(RPC_PORT)
        ),
    )
    initialize_model_parallel(1, world_size)
    init_random_seed(0)
    run_mp_worker(args, world_size)

    rpc.shutdown()
    torch.distributed.destroy_process_group()


# Command-line interface for the pipe benchmark.
parser = argparse.ArgumentParser(description="benchmark")
# NOTE(review): help text says "single process" but the flag enables the
# multiprocess path below — presumably a copy-paste slip; confirm intent.
parser.add_argument("--multiprocess", action="store_true", help="Runs single process benchmarks.")
parser.add_argument("--host", "-o", type=str, default="localhost", help="hostname")
parser.add_argument("--chunks", type=int, default=1, help="number of microbatches per batch")
parser.add_argument("--batch-size", type=int, default=8, help="size of a batch")
parser.add_argument(
    "--checkpoint", default="never", choices=["always", "except_last", "never"], help="Checkpointing strategy for pipe"
)
# NOTE(review): help text looks copied from another flag — this toggles lazy
# layer construction, not decoder depth.
parser.add_argument(
    "--lazy-construction", action="store_true", default=False, help="Number of decoder layers in the model"
)
parser.add_argument("--max-batch", type=int, default=4, help="Max number of batches")
parser.add_argument("--use_synthetic_data", action="store_true", help="Uses synthetic data for running benchmarks.")
parser.add_argument("--dry_run", action="store_true", help="Run a sample training run without regression testing.")
parser.add_argument(
    # TODO(anj-s): In the process of adding more models and hence the requirement for a flag.
    "--model_name",
    default="lm",
    help="Language Model(LM) used to benchmark nn.pipe.",
)
parser.add_argument("--debug", action="store_true", default=False, help="Display additional debug information")

if __name__ == "__main__":
    args = parser.parse_args()
    logging.basicConfig(level=logging.INFO if not args.debug else logging.DEBUG)

    if not args.multiprocess:
        logging.info(f"Running single process benchmark with args: {args}")
        benchmark_single_process(args)
    else:
        # One spawned process per visible GPU (at least one on CPU-only hosts).
        world_size = max(torch.cuda.device_count(), 1)
        logging.info(f"Running multiprocess benchmark with args: {args}")
        mp.spawn(benchmark_multiprocess, args=(world_size, args), nprocs=world_size, join=True)