pipe.py 16.5 KB
Newer Older
1
2
3
4
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the BSD license found in the
# LICENSE file in the root directory of this source tree.
Mandeep Singh Baines's avatar
Mandeep Singh Baines committed
5

Tom Birch's avatar
Tom Birch committed
6
import argparse
from collections import defaultdict
from functools import reduce
import gc
import logging
import math
import operator
import time

import numpy as np
import torch
import torch.distributed as dist
from torch.distributed import rpc
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.optim import Adam

from benchmarks.golden_configs.lm_wikitext2 import Pipe as lm_wikitext2
from datasets.wikitext2_data import get_real_dataloaders as get_real_wikitext2_dataloaders
from datasets.wikitext2_data import get_synthetic_dataloaders as get_synthetic_wikitext2_dataloaders
from fairscale.nn import Pipe
from fairscale.nn.model_parallel import initialize_model_parallel
from fairscale.nn.pipe import LazyModule
from fairscale.utils.testing import dist_init
from models import transformer_lm
29

30
31
32
# TCP port for the single-process "gloo" process group (see benchmark_single_process).
MPI_PORT = 29500
# NOTE(review): not referenced in this chunk — presumably used by an RPC init path elsewhere.
RPC_PORT = 29501

33

Tom Birch's avatar
Tom Birch committed
34
35
36
37
def init_random_seed(seed: int):
    """Seed every RNG this benchmark relies on: numpy, torch CPU and torch CUDA."""
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
Tom Birch's avatar
Tom Birch committed
39
40


anj-s's avatar
anj-s committed
41
def get_model_and_optimizer(args, device, benchmark_config, model_config):
    """Return instantiated model and optimizer function.

    Returns a tuple ``(model, make_optimizer)``: the optimizer is a factory
    taking the model's parameters, because the final (possibly pipelined)
    parameter set is only known after Pipe wraps the model.
    """
    if args.model_name == "lm":
        model = get_lm_model(args, device, model_config)
    else:
        # Fixed: an unknown model name previously fell through and raised a
        # confusing NameError on `model`; fail explicitly like the sibling
        # get_*_dataloaders helpers do.
        raise RuntimeError(f"Unrecognized args.model_name {args.model_name}")

    lr = benchmark_config["lr"]

    def make_adam(params):
        # `lr` is captured from the enclosing benchmark config.
        return Adam(params, lr=lr)

    return model, make_adam


def get_lm_model(args, device, config):
    """Get language model(based on GPT-2) used for sequence prediction.

    With --lazy-construction, returns a list of LazyModule wrappers so Pipe can
    materialize each layer on its target device; otherwise returns a fully
    constructed TransformerLM already moved to `device`.
    """

    ninp = config["ninp"]
    nhead = config["nhead"]
    initrange = config["initrange"]
    dropout = config["dropout"]
    vocab_size = config["vocab_size"]
    nhid = config["nhid"]
    ndecoder = config["num_decoder_layers"]

    if args.lazy_construction:
        # Deferred construction: each lambda builds its layer only when the
        # pipeline partition that owns it is instantiated.
        layers = [
            LazyModule(lambda: transformer_lm.EmbeddingLayer(vocab_size, ninp, initrange)),
            LazyModule(lambda: transformer_lm.PositionalEncodingLayer(ninp, dropout)),
        ]
        for _ in range(ndecoder):
            # Late binding is safe here: the captured names are loop-invariant.
            layers.append(LazyModule(lambda: transformer_lm.TransformerDecoderLayer(ninp, nhead, nhid, dropout)))
        layers.append(LazyModule(lambda: transformer_lm.LinearLayer(ninp, vocab_size, initrange)))
        model = layers
    else:
        model = transformer_lm.TransformerLM(vocab_size, ninp, nhead, nhid, dropout, initrange, ndecoder).to(device)

    return model
Mandeep Singh Baines's avatar
Mandeep Singh Baines committed
81
82


Tom Birch's avatar
Tom Birch committed
83
84
85
86
87
88
89
90
91
92
93
94
def get_tensors_by_size_bucket():
    """Histogram of live CUDA tensors, keyed by (*shape, element_size)."""
    buckets = defaultdict(int)
    live_cuda_tensors = (
        o for o in gc.get_objects() if isinstance(o, torch.Tensor) and o.device.type == "cuda"
    )
    for tensor in live_cuda_tensors:
        buckets[tuple(tensor.size()) + (tensor.element_size(),)] += 1
    return buckets


95
def log_number_of_parameters(model):
    """Log the model's parameter count; all-reduce it across the pipe group when present."""
    num_params = reduce(operator.add, (reduce(operator.mul, x.size()) for x in model.parameters()))

    if not hasattr(model, "group"):
        logging.debug(f"training model, #params = {num_params}")
        return

    # Pipelined model: sum the per-rank parameter counts across the group.
    total = torch.Tensor([num_params])
    if torch.cuda.is_available():
        total = total.cuda()
    torch.distributed.all_reduce(total, group=model.group)
    logging.debug(
        f"training model, #params = {num_params}, group: {model.group.rank()}, grank:"
        f" {torch.distributed.get_rank()}, sizes {model.group.size()}"
    )
    torch.distributed.barrier()
    if model.group.rank() == 0:
        logging.debug(f"total #prams = {total.item()}")
Tom Birch's avatar
Tom Birch committed
112
113


114
115
116
def get_device(model, index):
    """Device hosting the `index`-th partition of `model` (CPU when CUDA is absent)."""
    target = model.module if isinstance(model, DDP) else model

    if not torch.cuda.is_available():
        return torch.device("cpu")
    if hasattr(target, "devices"):
        # Pipe exposes a per-partition device list.
        return target.devices[index]
    # NOTE(review): this branch returns an int ordinal, not a torch.device;
    # callers pass it to .to(), which accepts both.
    return torch.cuda.current_device()
Tom Birch's avatar
Tom Birch committed
124

125

126
def get_fake_dataloader(lm_dataloader_len, args):
    """Map-style dataset stand-in that serves the same all-zero batch at every index."""
    zero_batch = {"input": torch.zeros(args.batch_size)}

    class FakeDataset:
        # Shares one zero tensor across all indices — fine for throughput-only runs.
        def __getitem__(self, index):
            return zero_batch

        def __len__(self):
            return lm_dataloader_len

    return FakeDataset()


anj-s's avatar
anj-s committed
139
def train(model_config, model, benchmark_config, model_specs, args):
    """Run one training epoch and return ``(words_per_second, last_loss)``.

    Only the last pipeline rank computes the loss and returns real numbers;
    all other ranks return (0.0, 0.0). Intermediate pipe ranks swap the real
    dataloader for synthetic data, since their actual inputs arrive over the
    pipe.
    """
    lm_dataloader, _, _ = model_config["data"]
    criterion = benchmark_config["criterion"]
    vocab_size = model_specs["vocab_size"]
    optimizer = model_config["optimizer"]

    model.train()
    log_number_of_parameters(model)

    total_loss = 0.0
    word_counter = 0  # NOTE(review): never used below.

    # model_config carries an optimizer *factory*; instantiate it now that the
    # (possibly pipelined) parameter set is final.
    optimizer = optimizer(model.parameters())

    pipe_group = model.group if hasattr(model, "group") else None

    # TODO(anj-s): Avoid sending fake data to all replicas except the first and last one.
    device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
    if pipe_group and pipe_group.rank() != 0 and pipe_group.rank() != (pipe_group.size() - 1):
        lm_dataloader, _, _ = get_synthetic_dataloaders(args, device, benchmark_config, model_specs)

    total_tokens = 0
    total_tokens_per_log_interval = 0
    bptt = 2  # NOTE(review): unused — get_batch below consumes the whole sequence.
    start_time = time.time()
    epoch_start_time = 0.0

    def get_batch(source):
        # Next-token prediction: inputs are source[:-1], targets source[1:].
        seq_len = len(source) - 1
        data = source[0:seq_len]
        target = source[1 : 1 + seq_len]
        return data, target

    for i, batch in enumerate(lm_dataloader):
        # Start the epoch timer at the second batch to exclude warm-up cost.
        if i == 1:
            epoch_start_time = time.time()

        source, target = get_batch(batch)
        if args.max_batch and i > args.max_batch:
            break

        if i > 0:
            # Token accounting matches the timer: the first batch is excluded.
            total_tokens += source.numel()

        optimizer.zero_grad()
        try:
            if pipe_group is None or pipe_group.rank() == 0:
                # First stage (or non-pipelined model): move input onto the
                # first partition's device.
                tmp = source.to(get_device(model, 0))
                output = model(tmp)
            else:
                output = model(source)
        except Exception as e:
            raise RuntimeError(f"training failed on {torch.distributed.get_rank()}") from e

        if pipe_group is None or pipe_group.rank() == pipe_group.size() - 1:
            # Last stage owns the loss computation and backward pass.
            target = target.to(get_device(model, -1))
            output = output.to(target.device)
            loss = criterion(output.view(-1, vocab_size), target.view(-1))
            loss.backward()
            del target
        else:
            # Presumably triggers backward on intermediate stages — see
            # Pipe.back_helper.
            model.back_helper(output)

        del output

        torch.nn.utils.clip_grad_value_(model.parameters(), model_specs["clip_value"])
        optimizer.step()

        if pipe_group is None or pipe_group.rank() == pipe_group.size() - 1:
            total_loss += loss.item()
            log_interval = 1
            total_tokens_per_log_interval += source.numel()
            if i % log_interval == 0 and i > 0:
                cur_loss = total_loss / log_interval
                elapsed = time.time() - start_time
                if dist.get_rank() == dist.get_world_size() - 1:
                    logging.debug(
                        "| batch {:5d} | wps {:5.2f} | loss {:5.2f} | ppl {:8.2f}".format(
                            i, total_tokens_per_log_interval / elapsed, cur_loss, math.exp(cur_loss)
                        )
                    )
                total_tokens_per_log_interval = 0
                total_loss = 0
                start_time = time.time()

    if epoch_start_time != 0:
        wps = total_tokens / (time.time() - epoch_start_time)
    else:
        # Fewer than two batches ran, so no timing window exists.
        raise RuntimeError(
            "Unable to benchmark on a single batch. Increase the size " " of the dataset and rerun the benchmark."
        )
    if dist.get_rank() == dist.get_world_size() - 1:
        return wps, loss.item()
    else:
        return 0.0, 0.0
234

Mandeep Singh Baines's avatar
Mandeep Singh Baines committed
235

236
237
# TODO(anj-s): Add an option for users to be able to benchmark evaluate.
def evaluate(eval_model, data_source, criterion, ntokens):
    """Average per-token loss of `eval_model` over `data_source` in BPTT-sized chunks."""
    eval_model.eval()
    # TODO(anj-s): Move this to the benchmark config if we want to benchmark evaluation.
    bptt = 35
    cumulative_loss = 0.0

    def next_chunk(source, start, window):
        span = min(window, len(source) - 1 - start)
        return source[start : start + span], source[start + 1 : start + 1 + span].view(-1)

    with torch.no_grad():
        for offset in range(0, data_source.size(0) - 1, bptt):
            inputs, labels = next_chunk(data_source, offset, bptt)
            predictions = eval_model(inputs).to(labels.device)
            flat = predictions.view(-1, ntokens)
            # Weight each chunk's loss by its length before averaging.
            cumulative_loss += len(inputs) * criterion(flat, labels).item()

    return cumulative_loss / (len(data_source) - 1)


def get_number_of_words(data):
    """Product of the first two dimensions of `data` (token count of a 2-D batch)."""
    first, second = data.size()[0], data.size()[1]
    return first * second


263
def verify_peak_memory(rank, golden_config, std_dev):
    """Check peak CUDA memory on device `rank` against the golden reference.

    Raises RuntimeError when the measured peak is not below
    ``golden_ref * std_dev``.
    """
    # Fixed: query memory_stats once (was called twice) and log the actual
    # device being checked (the message hard-coded cuda:0).
    current_device_usage = torch.cuda.memory_stats(rank)["allocated_bytes.all.peak"]
    logging.debug("Peak allocated bytes on cuda:{:d}: {:1d}".format(rank, current_device_usage))
    golden_ref = golden_config["peak_mem_usage"][rank]
    if not current_device_usage < golden_ref * std_dev:
        # Fixed: the message claimed usage was "less than" the reference, but
        # this branch fires when usage meets or exceeds the tolerance band;
        # also restored the missing space between "which" and "is".
        raise RuntimeError(
            "Peak memory usage for cuda device {:d} is {:d} which "
            "exceeds the golden reference value of {:d}".format(rank, current_device_usage, golden_ref)
        )
274
275


276
277
278
def verify_lm_run(wps, golden_config, args):
    """Verify that words per second for a given benchmark run matches the golden data."""

    # Only the last rank holds the real throughput number (see train()).
    if dist.get_rank() == dist.get_world_size() - 1:
        # Assert that words per second is within 3 standard deviations of the average
        # of five golden runs
        logging.info("Throughput(wps) is {:.2f}.".format(wps))
        threshold = golden_config["avg_wps"] - (3 * golden_config["std_dev_wps"])
        if not wps > threshold:
            raise RuntimeError(
                "Throughput(wps):{:.2f} is below the golden threshold of an "
                "average value of {:.2f} and standard dev of {:.2f}.".format(
                    wps, golden_config["avg_wps"], golden_config["std_dev_wps"]
                )
            )

    # Every rank validates peak memory for all four pipeline devices.
    for device_rank in range(4):
        verify_peak_memory(device_rank, golden_config, 1.1)
293

294

anj-s's avatar
anj-s committed
295
def benchmark_language_model(model_config, model, benchmark_config, model_specs, args):
    """Run one timed training epoch and regression-check it against golden data.

    Only the last pipeline rank logs throughput/loss; every rank logs its own
    peak CUDA memory. The golden-data checks only run for a 4-way balance.
    """
    golden_config = get_golden_config(args.model_name, args)
    epoch = benchmark_config["epochs"]
    start_time = time.time()
    if dist.get_rank() == dist.get_world_size() - 1:
        logging.debug("-" * 110)
        logging.debug("| start of epoch {:1d}".format(epoch))
        logging.debug("-" * 110)
    wps, loss = train(model_config, model, benchmark_config, model_specs, args)
    elapsed_time = time.time() - start_time
    if dist.get_rank() == dist.get_world_size() - 1:
        logging.debug("-" * 110)
        logging.debug("| end of epoch {:1d} | time: {:5.2f}s | train loss {:5.2f} ".format(epoch, elapsed_time, loss))
        logging.debug("-" * 110)
        logging.debug("Throughput(wps) is {:.2f}.".format(wps))
    logging.debug(
        "Peak allocated bytes on cuda:{}: {:1d}".format(
            dist.get_rank(), torch.cuda.memory_stats(dist.get_rank())["allocated_bytes.all.peak"]
        )
    )

    if len(model.balance) == 4:
        if args.model_name == "lm":
            verify_lm_run(wps, golden_config, args)
        else:
            # Fixed: `"..." % args.model_name` on a string with no conversion
            # specifier raised TypeError instead of the intended RuntimeError.
            raise RuntimeError(f"Unrecognized args.model_name {args.model_name}")
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335


def generate_balance(num_devices, num_layers):
    """Spread `num_layers` across `num_devices` as evenly as possible, front-loaded."""
    balance = []
    remaining = num_layers
    for device in range(num_devices):
        # Even share of what is left among the devices not yet assigned.
        share = remaining / (num_devices - device)
        layers_here = int(share) if share.is_integer() else math.ceil(share)
        balance.append(layers_here)
        remaining -= layers_here
    return balance

Mandeep Singh Baines's avatar
Mandeep Singh Baines committed
336

337
def get_synthetic_dataloaders(args, device, benchmark_config, model_specs):
    """Returns dataloaders for synthetic (generated) wikitext2-shaped data."""

    if args.model_name == "lm":
        return get_synthetic_wikitext2_dataloaders(args, benchmark_config, model_specs)
    else:
        # Fixed: the message had a typo ("model_mame") and used "%" with no
        # conversion specifier, which raised TypeError instead of RuntimeError.
        raise RuntimeError(f"Unrecognized args.model_name {args.model_name}")


anj-s's avatar
anj-s committed
346
def get_real_dataloaders(args, device, benchmark_config, model_specs):
    """Returns dataloaders for real wikitext2 data.

    Side effect: overwrites ``model_specs["vocab_size"]`` with the dataset's
    actual token count so the model is sized to the real vocabulary.
    """
    if args.model_name == "lm":
        data = get_real_wikitext2_dataloaders(args, benchmark_config, model_specs)
        ntokens, train_dataloader, valid_dataloader, test_dataloader = data
        model_specs["vocab_size"] = ntokens
        return train_dataloader, valid_dataloader, test_dataloader
    else:
        # Fixed: typo "model_mame" and "%" misuse that raised TypeError.
        raise RuntimeError(f"Unrecognized args.model_name {args.model_name}")


anj-s's avatar
anj-s committed
358
def create_model_config(args, benchmark_config=None, model_specs=None):
    """Return a dict with the given model, dataset and optimizer."""
    device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

    # Synthetic data avoids downloading wikitext2 when only throughput matters.
    dataloader_fn = get_synthetic_dataloaders if args.use_synthetic_data else get_real_dataloaders

    data = dataloader_fn(args, device, benchmark_config, model_specs)
    model, optimizer = get_model_and_optimizer(args, device, benchmark_config, model_specs)
    return {"model": model, "optimizer": optimizer, "data": data}
Tom Birch's avatar
Tom Birch committed
375
376


377
378
379
def create_benchmark_config(model_name):
    """Return a dict with configurations required for benchmarking `model_name` model."""

    if model_name == "lm":
        return lm_wikitext2.get_benchmark_config()
    else:
        # Fixed: referenced the undefined global `args` (NameError) and used
        # "%" on a string with no conversion specifier (TypeError).
        raise RuntimeError(f"Unrecognized model_name {model_name}")


anj-s's avatar
anj-s committed
386
387
388
389
390
391
392
393
394
def get_model_specs(model_name):
    """Return a dict with configurations required for configuring `model_name` model."""

    if model_name == "lm":
        return lm_wikitext2.get_model_config()
    else:
        # Fixed: referenced the undefined global `args` (NameError) and used
        # "%" on a string with no conversion specifier (TypeError).
        raise RuntimeError(f"Unrecognized model_name {model_name}")


395
def get_golden_config(model_name, args):
    """Return a dict with the golden data for throughput and memory usage."""

    if model_name == "lm":
        return lm_wikitext2.get_golden_real_stats()
    else:
        # Fixed: "%" on a string with no conversion specifier raised TypeError
        # instead of the intended RuntimeError.
        raise RuntimeError(f"Unrecognized model_name {model_name}")
402
403
404
405
406


def benchmark_single_process(args):
    """Benchmark a given model using a single process and multiple devices."""

    # Single-member gloo group so the dist.* rank queries used throughout this
    # file work in single-process mode.
    init_method_pgroup = "tcp://localhost:{}".format(MPI_PORT)
    torch.distributed.init_process_group(backend="gloo", rank=0, world_size=1, init_method=init_method_pgroup)

    num_devices = torch.cuda.device_count() if torch.cuda.is_available() else 1
    assert num_devices > 0
    init_random_seed(0)

    benchmark_config = create_benchmark_config(args.model_name)
    model_specs = get_model_specs(args.model_name)
    model_config = create_model_config(args, benchmark_config=benchmark_config, model_specs=model_specs)
    model = model_config["model"]

    # At most 4 pipeline partitions, spread evenly over the model's layers.
    balance = generate_balance(min(num_devices, 4), len(model))
    pipe_model = Pipe(model, balance, chunks=args.chunks, checkpoint=args.checkpoint)
    del model
    # Drop the raw model reference so only pipe_model keeps the layers alive.
    del model_config["model"]

    if args.dry_run:
        train(model_config, pipe_model, benchmark_config, model_specs, args)
    else:
        benchmark_language_model(model_config, pipe_model, benchmark_config, model_specs, args)
Tom Birch's avatar
Tom Birch committed
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446


def run_worker(rank, world_size, args):
    # Entry point for one multiprocess/RPC pipeline worker.
    # NOTE(review): depends on `args.world_size`, `args.rank_base` and
    # `run_mp_worker`, none of which are defined in this chunk (the parser
    # below defines neither flag) — confirm they exist before invoking this.
    if args.world_size != 0:
        world_size = args.world_size
    dist_init(rank + args.rank_base, world_size, hostname=args.host)
    initialize_model_parallel(1, world_size)
    init_random_seed(0)
    run_mp_worker(args, world_size)

    # Tear down RPC first, then the process group.
    rpc.shutdown()
    torch.distributed.destroy_process_group()


# Command-line interface for the pipe benchmark.
parser = argparse.ArgumentParser(description="benchmark")
parser.add_argument("--host", "-o", type=str, default="localhost", help="hostname")
parser.add_argument("--chunks", type=int, default=1, help="number of microbatches per batch")
parser.add_argument("--batch-size", type=int, default=8, help="size of a batch")
parser.add_argument(
    "--checkpoint", default="never", choices=["always", "except_last", "never"], help="Checkpointing strategy for pipe"
)
parser.add_argument(
    # Fixed: the help text was copy-pasted from a decoder-layer option.
    "--lazy-construction",
    action="store_true",
    default=False,
    help="Construct the model lazily, one layer per pipeline partition",
)
parser.add_argument("--max-batch", type=int, default=4, help="Max number of batches")
parser.add_argument("--use_synthetic_data", action="store_true", help="Uses synthetic data for running benchmarks.")
parser.add_argument("--dry_run", action="store_true", help="Run a sample training run without regression testing.")
parser.add_argument(
    # TODO(anj-s): In the process of adding more models and hence the requirement for a flag.
    "--model_name",
    default="lm",
    help="Language Model(LM) used to benchmark nn.pipe.",
)
parser.add_argument("--debug", action="store_true", default=False, help="Display additional debug information")
# NOTE(review): run_worker() reads args.world_size and args.rank_base, which
# this parser does not define — presumably provided by a multiprocess variant.

if __name__ == "__main__":
    args = parser.parse_args()
    logging.basicConfig(level=logging.INFO if not args.debug else logging.DEBUG)

    logging.info(f"Running single process benchmark with args: {args}")
    benchmark_single_process(args)