# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
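"""Benchmark fairscale.nn.Pipe on a GPT-2-style transformer language model over wikitext2."""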

import argparse
from collections import defaultdict
from functools import reduce
import gc
import logging
import math
import operator
import time

from datasets.wikitext2_data import get_real_dataloaders as get_real_wikitext2_dataloaders
from datasets.wikitext2_data import get_synthetic_dataloaders as get_synthetic_wikitext2_dataloaders
from models import transformer_lm
import numpy as np
import torch
import torch.distributed as dist
from torch.distributed import rpc
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.optim import Adam

from benchmarks.golden_configs.lm_wikitext2 import Pipe as lm_wikitext2
from fairscale.nn import Pipe
from fairscale.nn.model_parallel import initialize_model_parallel
from fairscale.nn.pipe import LazyModule  # LazyModule wrapper used when --lazy-construction is set
from fairscale.utils.testing import dist_init

MPI_PORT = 29500
RPC_PORT = 29501


def init_random_seed(seed: int):
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    np.random.seed(seed)


def get_model_and_optimizer(args, device, benchmark_config, model_config):
    """Return instantiated model and optimizer function."""

    if args.model_name == "lm":
        model = get_lm_model(args, device, model_config)
    else:
        raise RuntimeError("Unrecognized args.model_name %s" % args.model_name)

    lr = benchmark_config["lr"]

    def make_adam(params):
        return Adam(params, lr=lr)

    optimizer = make_adam
    return model, optimizer


def get_lm_model(args, device, config):
    """Get language model(based on GPT-2) used for sequence prediction."""

    ninp = config["ninp"]
    nhead = config["nhead"]
    initrange = config["initrange"]
    dropout = config["dropout"]
    vocab_size = config["vocab_size"]
    nhid = config["nhid"]
    ndecoder = config["num_decoder_layers"]

    if args.lazy_construction:
        layers = [
            LazyModule(lambda: transformer_lm.EmbeddingLayer(vocab_size, ninp, initrange)),
            LazyModule(lambda: transformer_lm.PositionalEncodingLayer(ninp, dropout)),
        ]
        for _ in range(ndecoder):
            layers.append(LazyModule(lambda: transformer_lm.TransformerDecoderLayer(ninp, nhead, nhid, dropout)))

        layers.append(LazyModule(lambda: transformer_lm.LinearLayer(ninp, vocab_size, initrange)))
        model = layers
    else:
        model = transformer_lm.TransformerLM(vocab_size, ninp, nhead, nhid, dropout, initrange, ndecoder).to(device)

    return model


def get_tensors_by_size_bucket():
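    """Count live CUDA tensors, bucketed by (*shape, element_size); useful when debugging memory."""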

    size_buckets = defaultdict(int)
    for obj in gc.get_objects():
        if not isinstance(obj, torch.Tensor):
            continue
        if obj.device.type == "cuda":
            size_buckets[(*obj.size(),) + (obj.element_size(),)] += 1

    return size_buckets


def log_number_of_parameters(model):
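    # Total parameter count: product of each parameter's dimensions, summed over all parameters.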
    num_params = reduce(operator.add, (reduce(operator.mul, x.size()) for x in model.parameters()))
    if hasattr(model, "group"):
        total = torch.Tensor([num_params])
        if torch.cuda.is_available():
            total = total.cuda()
        torch.distributed.all_reduce(total, group=model.group)
        logging.debug(
            f"training model, #params = {num_params}, group: {model.group.rank()}, grank:"
            f" {torch.distributed.get_rank()}, sizes {model.group.size()}"
        )
        torch.distributed.barrier()
        if model.group.rank() == 0:
            logging.debug(f"total #params = {total.item()}")
    else:
        logging.debug(f"training model, #params = {num_params}")


def get_device(model, index):
    if isinstance(model, DDP):
        model = model.module

    if not torch.cuda.is_available():
        return torch.device("cpu")
    if hasattr(model, "devices"):
        return model.devices[index]
    else:
        return torch.cuda.current_device()


def get_fake_dataloader(lm_dataloader_len, args):
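    """Return a fake map-style dataset with lm_dataloader_len identical batches."""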
    fake_input = {"input": torch.zeros(args.batch_size)}

    class FakeDataset:
        def __getitem__(self, index):
            return fake_input

        def __len__(self):
            return lm_dataloader_len

    return FakeDataset()


def train(model_config, model, benchmark_config, model_specs, args):
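    """Train for one epoch; return (wps, last loss) on the final rank and (0.0, 0.0) elsewhere."""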
    lm_dataloader, _, _ = model_config["data"]
    criterion = benchmark_config["criterion"]
    vocab_size = model_specs["vocab_size"]
    optimizer = model_config["optimizer"]

    model.train()
    log_number_of_parameters(model)

    total_loss = 0.0
    word_counter = 0

    optimizer = optimizer(model.parameters())

    pipe_group = model.group if hasattr(model, "group") else None

    # TODO(anj-s): Avoid sending fake data to all replicas except the first and last one.
    device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
    if pipe_group and pipe_group.rank() != 0 and pipe_group.rank() != (pipe_group.size() - 1):
        lm_dataloader, _, _ = get_synthetic_dataloaders(args, device, benchmark_config, model_specs)

    total_tokens = 0
    total_tokens_per_log_interval = 0
    bptt = 2
    start_time = time.time()
    epoch_start_time = 0.0

    def get_batch(source):
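        # Language-model batches: the target is the input sequence shifted by one token.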
        seq_len = len(source) - 1
        data = source[0:seq_len]
        target = source[1 : 1 + seq_len]
        return data, target

    for i, batch in enumerate(lm_dataloader):
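        # Start the epoch timer at the second batch so the first (warm-up) batch is excluded.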
        if i == 1:
            epoch_start_time = time.time()

        source, target = get_batch(batch)
        if args.max_batch and i > args.max_batch:
            break

        if i > 0:
            total_tokens += source.numel()

        optimizer.zero_grad()
        try:
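            # Only the first pipeline stage (or a non-pipelined model) consumes the real input.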
            if pipe_group is None or pipe_group.rank() == 0:
                tmp = source.to(get_device(model, 0))
                output = model(tmp)
            else:
                output = model(source)
        except Exception as e:
            raise RuntimeError(f"training failed on {torch.distributed.get_rank()}") from e

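        # Only the last pipeline stage holds the model output, so only it computes the loss.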
        if pipe_group is None or pipe_group.rank() == pipe_group.size() - 1:
            target = target.to(get_device(model, -1))
            output = output.to(target.device)
            loss = criterion(output.view(-1, vocab_size), target.view(-1))
            loss.backward()
            del target
        else:
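            # Intermediate stages have no loss to call backward on; back_helper runs their backward pass.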
            model.back_helper(output)

        del output

        torch.nn.utils.clip_grad_value_(model.parameters(), model_specs["clip_value"])
        optimizer.step()

        if pipe_group is None or pipe_group.rank() == pipe_group.size() - 1:
            total_loss += loss.item()
            log_interval = 1
            total_tokens_per_log_interval += source.numel()
            if i % log_interval == 0 and i > 0:
                cur_loss = total_loss / log_interval
                elapsed = time.time() - start_time
                if dist.get_rank() == dist.get_world_size() - 1:
                    logging.debug(
                        "| batch {:5d} | wps {:5.2f} | loss {:5.2f} | ppl {:8.2f}".format(
                            i, total_tokens_per_log_interval / elapsed, cur_loss, math.exp(cur_loss)
                        )
                    )
                total_tokens_per_log_interval = 0
                total_loss = 0
                start_time = time.time()

    if epoch_start_time != 0:
        wps = total_tokens / (time.time() - epoch_start_time)
    else:
        raise RuntimeError(
            "Unable to benchmark on a single batch. Increase the size of the dataset and rerun the benchmark."
        )
    if dist.get_rank() == dist.get_world_size() - 1:
        return wps, loss.item()
    else:
        return 0.0, 0.0


# TODO(anj-s): Add an option for users to benchmark evaluation.
def evaluate(eval_model, data_source, criterion, ntokens):
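    """Return the average loss of eval_model over data_source."""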
    eval_model.eval()
    total_loss = 0.0
    # TODO(anj-s): Move this to the benchmark config if we want to benchmark evaluation.
    bptt = 35

    def get_batch(source, i, bptt):
        seq_len = min(bptt, len(source) - 1 - i)
        data = source[i : i + seq_len]
        target = source[i + 1 : i + 1 + seq_len].view(-1)
        return data, target

    with torch.no_grad():
        for i in range(0, data_source.size(0) - 1, bptt):
            data, targets = get_batch(data_source, i, bptt)
            output = eval_model(data)
            output = output.to(targets.device)
            output_flat = output.view(-1, ntokens)
            total_loss += len(data) * criterion(output_flat, targets).item()
    return total_loss / (len(data_source) - 1)


def get_number_of_words(data):
    return data.size()[0] * data.size()[1]


def verify_peak_memory(rank, golden_config, std_dev):
    logging.debug(
        "Peak allocated bytes on cuda:{:d}: {:1d}".format(rank, torch.cuda.memory_stats(rank)["allocated_bytes.all.peak"])
    )
    current_device_usage = torch.cuda.memory_stats(rank)["allocated_bytes.all.peak"]
    golden_ref = golden_config["peak_mem_usage"][rank]
    if not current_device_usage < golden_ref * std_dev:
        raise RuntimeError(
            "Peak memory usage for cuda device {:d} is {:d}, which "
            "exceeds the golden reference value of {:d}.".format(rank, current_device_usage, golden_ref)
        )


def verify_lm_run(wps, golden_config, args):
    """Verify that words per second for a given benchmark run matches the golden data."""

    if dist.get_rank() == dist.get_world_size() - 1:
        # Assert that words per second is within 3 standard deviations of the average
        # of five golden runs
        logging.info("Throughput(wps) is {:.2f}.".format(wps))
        if not wps > (golden_config["avg_wps"] - (3 * golden_config["std_dev_wps"])):
            raise RuntimeError(
                "Throughput(wps):{:.2f} is below the golden threshold of an "
                "average value of {:.2f} and standard dev of {:.2f}.".format(
                    wps, golden_config["avg_wps"], golden_config["std_dev_wps"]
                )
            )

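    # Verify peak memory on each of the four pipeline devices, allowing 10% headroom.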
    for i in range(4):
        verify_peak_memory(i, golden_config, 1.1)


def benchmark_language_model(model_config, model, benchmark_config, model_specs, args):
    golden_config = get_golden_config(args.model_name, args)
    epoch = benchmark_config["epochs"]
    start_time = time.time()
    if dist.get_rank() == dist.get_world_size() - 1:
        logging.debug("-" * 110)
        logging.debug("| start of epoch {:1d}".format(epoch))
        logging.debug("-" * 110)
    wps, loss = train(model_config, model, benchmark_config, model_specs, args)
    elapsed_time = time.time() - start_time
    if dist.get_rank() == dist.get_world_size() - 1:
        logging.debug("-" * 110)
        logging.debug("| end of epoch {:1d} | time: {:5.2f}s | train loss {:5.2f} ".format(epoch, elapsed_time, loss))
        logging.debug("-" * 110)
        logging.debug("Throughput(wps) is {:.2f}.".format(wps))
    logging.debug(
        "Peak allocated bytes on cuda:{}: {:1d}".format(
            dist.get_rank(), torch.cuda.memory_stats(dist.get_rank())["allocated_bytes.all.peak"]
        )
    )

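    # Golden stats were collected with a 4-way balance; skip verification for other configurations.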
    if len(model.balance) == 4:
        if args.model_name == "lm":
            verify_lm_run(wps, golden_config, args)
        else:
            raise RuntimeError("Unrecognized args.model_name %s" % args.model_name)


def generate_balance(num_devices, num_layers):
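    """Distribute num_layers across num_devices as evenly as possible, larger shares first.

    For example, generate_balance(4, 14) returns [4, 4, 3, 3].
    """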
    balance = []
    layers_assigned = 0
    for i in range(num_devices):
        x = (num_layers - layers_assigned) / (num_devices - i)
        if x.is_integer():
            balance.append(int(x))
            layers_assigned += x
        else:
            balance.append(math.ceil(x))
            layers_assigned += math.ceil(x)
    return balance


def get_synthetic_dataloaders(args, device, benchmark_config, model_specs):
    """Returns dataloaders for synthetic data."""

    if args.model_name == "lm":
        return get_synthetic_wikitext2_dataloaders(args, benchmark_config, model_specs)
    else:
        raise RuntimeError("Unrecognized args.model_name %s" % args.model_name)


def get_real_dataloaders(args, device, benchmark_config, model_specs):
    """Returns dataloaders for real data."""

    if args.model_name == "lm":
        data = get_real_wikitext2_dataloaders(args, benchmark_config, model_specs)
        ntokens, train_dataloader, valid_dataloader, test_dataloader = data
        model_specs["vocab_size"] = ntokens
        return train_dataloader, valid_dataloader, test_dataloader
    else:
        raise RuntimeError("Unrecognized args.model_name %s" % args.model_name)


def create_model_config(args, benchmark_config=None, model_specs=None):
    """Return a dict with the given model, dataset and optimizer."""

    device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

    if args.use_synthetic_data:
        dataloader_fn = get_synthetic_dataloaders
    else:
        dataloader_fn = get_real_dataloaders

    data = dataloader_fn(args, device, benchmark_config, model_specs)
    model, optimizer = get_model_and_optimizer(args, device, benchmark_config, model_specs)
    return {
        "model": model,
        "optimizer": optimizer,
        "data": data,
    }


def create_benchmark_config(model_name):
    """Return a dict with configurations required for benchmarking `model_name` model."""

    if model_name == "lm":
        return lm_wikitext2.get_benchmark_config()
    else:
        raise RuntimeError("Unrecognized model_name %s" % model_name)


def get_model_specs(model_name):
    """Return a dict with configurations required for configuring `model_name` model."""

    if model_name == "lm":
        return lm_wikitext2.get_model_config()
    else:
        raise RuntimeError("Unrecognized model_name %s" % model_name)


def get_golden_config(model_name, args):
    """Return a dict with the golden data for throughput and memory usage."""

    if model_name == "lm":
        return lm_wikitext2.get_golden_real_stats()
    else:
        raise RuntimeError("Unrecognized model_name %s" % model_name)


def benchmark_single_process(args):
    """Benchmark a given model using a single process and multiple devices."""

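    # Even a single process needs a (world_size=1) process group, since train() queries dist ranks.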
    init_method_pgroup = "tcp://localhost:{}".format(MPI_PORT)
    torch.distributed.init_process_group(backend="gloo", rank=0, world_size=1, init_method=init_method_pgroup)

    num_devices = torch.cuda.device_count() if torch.cuda.is_available() else 1
    assert num_devices > 0
    init_random_seed(0)

    benchmark_config = create_benchmark_config(args.model_name)
    model_specs = get_model_specs(args.model_name)
    model_config = create_model_config(args, benchmark_config=benchmark_config, model_specs=model_specs)
    model = model_config["model"]

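    # Use at most four pipeline stages, matching the golden configs used for verification.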
    balance = generate_balance(min(num_devices, 4), len(model))
    pipe_model = Pipe(model, balance, chunks=args.chunks, checkpoint=args.checkpoint)
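    # Drop the extra references so only Pipe's partitioned copy of the model stays alive.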
    del model
    del model_config["model"]

    if args.dry_run:
        train(model_config, pipe_model, benchmark_config, model_specs, args)
    else:
        benchmark_language_model(model_config, pipe_model, benchmark_config, model_specs, args)


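# NOTE: this multiprocess path assumes a run_mp_worker entry point and world_size/rank_base
# arguments supplied by a separate launcher; benchmark_single_process does not use it.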
def run_worker(rank, world_size, args):
    if args.world_size != 0:
        world_size = args.world_size
    dist_init(rank + args.rank_base, world_size, hostname=args.host)
    initialize_model_parallel(1, world_size)
    init_random_seed(0)
    run_mp_worker(args, world_size)

    rpc.shutdown()
    torch.distributed.destroy_process_group()


parser = argparse.ArgumentParser(description="benchmark")
parser.add_argument("--host", "-o", type=str, default="localhost", help="hostname")
parser.add_argument("--chunks", type=int, default=1, help="number of microbatches per batch")
parser.add_argument("--batch-size", type=int, default=8, help="size of a batch")
parser.add_argument(
    "--checkpoint", default="never", choices=["always", "except_last", "never"], help="Checkpointing strategy for pipe"
)
parser.add_argument(
    "--lazy-construction", action="store_true", default=False, help="Construct the model lazily with LazyModule wrappers"
)
parser.add_argument("--max-batch", type=int, default=4, help="Max number of batches")
parser.add_argument("--use_synthetic_data", action="store_true", help="Uses synthetic data for running benchmarks.")
parser.add_argument("--dry_run", action="store_true", help="Run a sample training run without regression testing.")
parser.add_argument(
    # TODO(anj-s): More models are being added, hence the need for this flag.
    "--model_name",
    default="lm",
    help="Language Model(LM) used to benchmark nn.pipe.",
)
parser.add_argument("--debug", action="store_true", default=False, help="Display additional debug information")

if __name__ == "__main__":
    args = parser.parse_args()
    logging.basicConfig(level=logging.INFO if not args.debug else logging.DEBUG)

    logging.info(f"Running single process benchmark with args: {args}")
    benchmark_single_process(args)