pipe.py 23.2 KB
Newer Older
Mandeep Singh Baines's avatar
Mandeep Singh Baines committed
1
2
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.

Tom Birch's avatar
Tom Birch committed
3
import argparse
4
import logging
Mandeep Singh Baines's avatar
Mandeep Singh Baines committed
5
import math
Tom Birch's avatar
Tom Birch committed
6
import os
Mandeep Singh Baines's avatar
Mandeep Singh Baines committed
7
import time
Tom Birch's avatar
Tom Birch committed
8
import warnings
Mandeep Singh Baines's avatar
Mandeep Singh Baines committed
9

Tom Birch's avatar
Tom Birch committed
10
from benchmark_dataset import BenchmarkLMDataset, collate_sentences_lm
Mandeep Singh Baines's avatar
Mandeep Singh Baines committed
11
import torch
Tom Birch's avatar
Tom Birch committed
12
13
from torch.distributed import rpc
import torch.multiprocessing as mp
Mandeep Singh Baines's avatar
Mandeep Singh Baines committed
14
import torch.nn as nn
15
from torch.nn.parallel import DistributedDataParallel as DDP
Tom Birch's avatar
Tom Birch committed
16
from torch.utils.data import DataLoader
Mandeep Singh Baines's avatar
Mandeep Singh Baines committed
17
18
19
import torchtext
from torchtext.data.utils import get_tokenizer

20
from fairscale.nn import Pipe
Tom Birch's avatar
Tom Birch committed
21
from fairscale.nn.model_parallel import initialize_model_parallel
22
23
from fairscale.nn.model_parallel.initialize import get_data_parallel_group, get_pipeline_parallel_group
from fairscale.nn.pipe import LazyModule, pipe
Jun Ru Anderson's avatar
Jun Ru Anderson committed
24
from fairscale.optim import GradScaler
25
from fairscale.optim.oss import OSS
Tom Birch's avatar
Tom Birch committed
26
from tests.nn.model_parallel.commons import dist_init, get_worker_map
27

Jun Ru Anderson's avatar
Jun Ru Anderson committed
28
try:
Tom Birch's avatar
Tom Birch committed
29
    from fairscale.optim import Adam  # type: ignore
Jun Ru Anderson's avatar
Jun Ru Anderson committed
30
31
32
33
34
35
36

    can_benchmark = True
except ImportError:
    from torch.optim import Adam  # type: ignore

    can_benchmark = False

37

Tom Birch's avatar
Tom Birch committed
38
39
40
41
42
43
44
45
46
47
48
49
def init_random_seed(seed: int):
    """Seed the NumPy, PyTorch CPU and current-device CUDA generators with ``seed``.

    Used so that model initialisation and data shuffling are reproducible
    across benchmark runs.
    """
    import numpy

    # The three generators are independent, so the seeding order is irrelevant.
    numpy.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)


# Not referenced by the code shown in this file; presumably a default
# micro-batch count for Pipe (the CLI exposes ``--chunks`` instead) — confirm.
PIPE_CHUNKS = 2
# Number of TransformerDecoderLayer.forward calls so far (incremented there);
# only read by commented-out debugging code.
iteration_count = 0


50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
class EmbeddingLayer(nn.Embedding):
    """Token embedding whose output is scaled by ``sqrt(ninp)``.

    The embedding table is re-initialised uniformly in
    ``[-initrange, initrange]``.
    """

    def __init__(self, ntoken, ninp, initrange):
        super().__init__(ntoken, ninp)
        self.ninp = ninp
        with torch.no_grad():
            self.weight.uniform_(-initrange, initrange)

    def forward(self, src):
        scale = math.sqrt(self.ninp)
        embedded = super().forward(src)
        return embedded * scale


class PositionalEncodingLayer(nn.Module):
    """Add fixed sinusoidal positional encodings to the input, then apply dropout.

    The table is precomputed for ``max_len`` positions and registered as the
    buffer ``pe`` with shape ``(max_len, 1, d_model)``. Even columns hold
    ``sin`` terms and odd columns hold ``cos`` terms. ``forward`` adds the
    first ``x.size(0)`` rows to ``x``, so ``x`` is indexed sequence-first
    (presumably ``(seq_len, batch, d_model)`` — confirm against callers).
    Odd ``d_model`` values are supported.
    """

    def __init__(self, d_model, dropout=0.1, max_len=5000):
        super().__init__()
        self.dropout = nn.Dropout(p=dropout)

        pe = torch.zeros(max_len, d_model)
        position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model))
        pe[:, 0::2] = torch.sin(position * div_term)
        # For odd d_model there is one fewer odd column than even column, so
        # trim div_term to keep the assignment shapes compatible.
        pe[:, 1::2] = torch.cos(position * div_term[: d_model // 2])
        pe = pe.unsqueeze(0).transpose(0, 1)
        self.register_buffer("pe", pe)

    def forward(self, x):
        x = x + self.pe[: x.size(0), :]
        return self.dropout(x)


class TransformerDecoderLayer(nn.TransformerEncoderLayer):
    """Self-attention block used as a decoder.

    Though this class inherits from torch.nn.TransformerEncoderLayer, it
    functions as a decoder in this model. Every forward pass supplies a
    square subsequent mask, so position ``i`` can only attend to positions
    ``<= i``.
    """

    def __init__(self, ninp, nhead, nhid, droupout):
        # ``droupout`` (sic) keeps its original spelling for keyword compatibility.
        super().__init__(ninp, nhead, nhid, droupout)
        # Cached causal mask, rebuilt in forward() when it no longer fits the input.
        self.src_mask = None

    def _generate_square_subsequent_mask(self, sz, device=None):
        """Return an ``(sz, sz)`` float mask: 0.0 on/below the diagonal, -inf above it."""
        mask = (torch.triu(torch.ones(sz, sz, device=device)) == 1).transpose(0, 1)
        mask = mask.float().masked_fill(mask == 0, float("-inf")).masked_fill(mask == 1, float(0.0))
        return mask

    def forward(self, src):
        # Module-level debug counter; incremented once per forward call.
        global iteration_count
        iteration_count += 1

        # Rebuild the cached mask when the sequence length changes, or when the
        # input lives on a different device than the cached mask. Reusing a
        # mask from another device would fail inside attention.
        if (
            self.src_mask is None
            or self.src_mask.size(0) != len(src)
            or self.src_mask.device != src.device
        ):
            self.src_mask = self._generate_square_subsequent_mask(len(src), device=src.device)

        return super().forward(src, self.src_mask)


class LinearLayer(nn.Linear):
    """Output projection with a zeroed bias and uniformly initialised weights.

    Weights are drawn from ``[-initrange, initrange]``.
    """

    def __init__(self, ninp, ntoken, initrange):
        super().__init__(ninp, ntoken)
        with torch.no_grad():
            self.bias.zero_()
            self.weight.uniform_(-initrange, initrange)


class TransformerLMSequntial(nn.Sequential):
    """A small GPT-2-style language model expressed as a flat ``nn.Sequential``.

    Keeping the model a plain sequence of layers lets Pipe split it into
    per-device groups of consecutive layers. The layer order is:
    embedding, positional encoding, ``ndecoder`` decoder layers, then the
    output projection back to ``ntokens`` logits.
    """

    def __init__(self, ntokens, ninp, nhead, nhid, dropout, initrange, ndecoder):
        # Construct in forward order so RNG-driven initialisation is reproducible.
        head = [EmbeddingLayer(ntokens, ninp, initrange), PositionalEncodingLayer(ninp, dropout)]
        body = [TransformerDecoderLayer(ninp, nhead, nhid, dropout) for _ in range(ndecoder)]
        tail = [LinearLayer(ninp, ntokens, initrange)]
        super().__init__(*head, *body, *tail)
126

Mandeep Singh Baines's avatar
Mandeep Singh Baines committed
127
128

def get_data(device):
    """Load WikiText-2 via torchtext and return ``(ntokens, train, val, test)``.

    Text is tokenised with the ``basic_english`` tokenizer, lower-cased and
    wrapped in ``<sos>``/``<eos>`` markers. The vocabulary is built from the
    training split. Each split goes through ``batchify`` onto ``device``:
    batch size 20 for training, 10 for validation and test. Warnings emitted
    while loading are recorded rather than shown.
    """
    with warnings.catch_warnings(record=True):
        text_field = torchtext.data.Field(
            tokenize=get_tokenizer("basic_english"), init_token="<sos>", eos_token="<eos>", lower=True
        )
        train_txt, val_txt, test_txt = torchtext.datasets.WikiText2.splits(text_field)
        text_field.build_vocab(train_txt)
        ntokens = len(text_field.vocab.stoi)

        train_bsz, eval_bsz = 20, 10
        splits = (
            batchify(train_txt, train_bsz, text_field, device),
            batchify(val_txt, eval_bsz, text_field, device),
            batchify(test_txt, eval_bsz, text_field, device),
        )
        return (ntokens, *splits)
Mandeep Singh Baines's avatar
Mandeep Singh Baines committed
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160


def batchify(data, bsz, TEXT, device):
    """Numericalise the first example of ``data`` and lay it out as ``bsz`` columns.

    The token stream is truncated to a multiple of ``bsz`` and split into
    ``bsz`` equal contiguous chunks. Each chunk becomes one column of the
    returned (contiguous) tensor, which is moved to ``device``.
    """
    stream = TEXT.numericalize([data.examples[0].text])
    usable = (stream.size(0) // bsz) * bsz
    chunks = stream.narrow(0, 0, usable).view(bsz, -1)
    return chunks.t().contiguous().to(device)


def get_batch(source, i, bptt):
    """Cut a BPTT window starting at row ``i`` of ``source``.

    Returns ``(data, target)``. ``data`` holds at most ``bptt`` rows and
    shrinks near the end of ``source``. ``target`` is the same window shifted
    one row ahead, flattened to 1-D.
    """
    window = min(bptt, len(source) - 1 - i)
    end = i + window
    return source[i:end], source[i + 1 : end + 1].view(-1)


Tom Birch's avatar
Tom Birch committed
161
162
163
164
def make_model(args, device, ntokens):
    """Build the benchmark model, loss, optimizer factory and grad scaler.

    Returns ``(model, criterion, optimizer, scaler)``:

    * ``model``: a ``TransformerLMSequntial`` moved to ``device``. With
      ``args.lazy_construction`` it is instead a plain list of ``LazyModule``
      wrappers, one per layer, in the same order.
    * ``criterion``: ``nn.CrossEntropyLoss``.
    * ``optimizer``: a factory taking a model. It returns an ``OSS``-wrapped
      Adam over the data-parallel group when ``args.ddp_zero`` is set, and a
      plain Adam otherwise (lr 0.01 in both cases).
    * ``scaler``: a ``GradScaler``.
    """
    ninp = 2048  # embedding dimension
    nhid = 2048  # feed-forward width inside each nn.TransformerEncoderLayer
    nhead = 32  # number of attention heads
    dropout = 0
    initrange = 0.1
    ndecoder = args.num_decoder_layers
    lr = 0.01  # learning rate

    if not args.lazy_construction:
        model = TransformerLMSequntial(ntokens, ninp, nhead, nhid, dropout, initrange, ndecoder).to(device)
    else:
        model = [
            LazyModule(lambda: EmbeddingLayer(ntokens, ninp, initrange)),
            LazyModule(lambda: PositionalEncodingLayer(ninp, dropout)),
        ]
        model.extend(LazyModule(lambda: TransformerDecoderLayer(ninp, nhead, nhid, dropout)) for _ in range(ndecoder))
        model.append(LazyModule(lambda: LinearLayer(ninp, ntokens, initrange)))

    def make_adam(model):
        if not args.ddp_zero:
            return Adam(model.parameters(), lr=lr)
        return OSS(params=model.parameters(), optim=Adam, group=get_data_parallel_group(), lr=lr)

    return model, nn.CrossEntropyLoss(), make_adam, GradScaler()
Mandeep Singh Baines's avatar
Mandeep Singh Baines committed
195
196


Tom Birch's avatar
Tom Birch committed
197
198
def get_tensors_by_size_bucket():
    """Count live CUDA tensors grouped by shape and element size.

    Scans every object tracked by the garbage collector. Each key is the
    tensor's shape with its element size in bytes appended, e.g.
    ``(2, 3, 4)`` for a 2x3 float32 tensor. Each value is the number of
    such tensors. Returns a ``defaultdict(int)``.
    """
    from collections import defaultdict
    import gc

    counts = defaultdict(int)
    cuda_tensors = (
        obj for obj in gc.get_objects() if isinstance(obj, torch.Tensor) and obj.device.type == "cuda"
    )
    for tensor in cuda_tensors:
        counts[tuple(tensor.size()) + (tensor.element_size(),)] += 1
    return counts


def dump_size_buckets(size_buckets, prefix=""):
    """Print one line per bucket plus a grand total, each prefixed with ``prefix``.

    For each ``key -> count`` entry the printed figure is the product of the
    key's entries times ``count``. For keys from
    ``get_tensors_by_size_bucket`` this is the bytes held by that bucket.
    """
    from functools import reduce
    import operator

    grand_total = 0
    for shape_key, count in size_buckets.items():
        bucket_size = reduce(operator.mul, shape_key) * count
        grand_total += bucket_size
        print(prefix + f"{shape_key} : {count}, {bucket_size}")
    print(prefix + f"total = {grand_total}")


# Previous get_tensors_by_size_bucket() snapshot; overwritten by check_size_buckets().
last_size_buckets = None
# While True, check_size_buckets() dumps both snapshots on its next comparison, then clears it.
once = True


def safe_rank():
    """Return this process's global rank, or 0 when torch.distributed is not in use.

    Availability and initialisation are checked explicitly instead of relying
    on the exception ``get_rank`` raises without a process group. Older
    PyTorch raised ``AssertionError`` there, while newer releases raise
    ``RuntimeError``/``ValueError``, which an ``except AssertionError`` lets
    escape.
    """
    dist = torch.distributed
    if not (dist.is_available() and dist.is_initialized()):
        return 0
    return dist.get_rank()


def check_size_buckets():
    """Snapshot live CUDA tensors and report changes since the previous call.

    Compares the current ``get_tensors_by_size_bucket()`` snapshot with the
    module-level ``last_size_buckets`` and dumps both when they differ.
    While the module-level ``once`` flag is set, the first comparison also
    dumps both unconditionally. The current snapshot is stored for the next
    call.
    """
    global last_size_buckets
    global once
    size_buckets = get_tensors_by_size_bucket()
    if last_size_buckets is not None:
        if size_buckets != last_size_buckets:
            # Must be safe_rank(): ``safe-rank()`` parses as ``safe - rank()`` (NameError).
            print(f"difference in outstanding tensors: {safe_rank()}")
            dump_size_buckets(last_size_buckets, "old: ")
            dump_size_buckets(size_buckets, "new: ")
        if once:
            print(f"dumping buckets for: {safe_rank()}")
            dump_size_buckets(last_size_buckets, "old: ")
            dump_size_buckets(size_buckets, "new: ")
            once = False
    else:
        print(f"size buckets none on {safe_rank()}")
    last_size_buckets = size_buckets


def dump_cuda_tensors():
    """Print live CUDA tensors grouped by shape and element size, then allocator stats.

    Each group line reads ``key : count, prod(key) * count``, followed by the
    sum over all groups. ``torch.cuda.memory_stats()`` is pretty-printed
    afterwards when CUDA is available.
    """
    print("dumping cuda tensors...")
    from collections import defaultdict
    from functools import reduce
    import gc
    import operator
    import pprint

    # Key: tensor shape with element size appended; value: number of live tensors.
    # (This dict was previously never created, so the function always hit NameError.)
    size_buckets = defaultdict(int)
    for obj in gc.get_objects():
        if not isinstance(obj, torch.Tensor):
            continue
        if obj.device.type == "cuda":
            size_buckets[(*obj.size(),) + (obj.element_size(),)] += 1

    print("outstanding cuda tensors:")
    total = 0
    for key, value in size_buckets.items():
        this = reduce(operator.mul, key) * value
        total += this
        print(f"{key} : {value}, {this}")
    print(f"total size = {total}")

    # The allocator statistics need a CUDA runtime; skip them on CPU-only hosts.
    if torch.cuda.is_available():
        pprint.pprint(torch.cuda.memory_stats())


def train(lm_dataloader, model, criterion, optimizer, vocab_size, args):
    """Run the (pipelined) training loop over ``lm_dataloader``.

    Args:
        lm_dataloader: iterable of dict batches. ``"input"`` is read on every
            rank. ``"target"`` and ``"ntokens"`` are read only on the last
            pipeline stage (or when there is no pipeline group). Ranks that
            are neither the first nor the last stage get a placeholder
            dataset of the same length instead.
        model: Pipe-wrapped model. Its ``group`` (pipeline process group or
            ``None``), ``devices`` and ``back_helper`` are used.
        criterion: loss applied to logits reshaped to ``(-1, vocab_size)``.
        optimizer: factory called once as ``optimizer(model)`` (before any DDP
            wrapping) to build the actual optimizer.
        vocab_size: number of output classes used to reshape the logits.
        args: parsed CLI arguments; ``max_batch``, ``batch_size`` and
            ``ddp_zero`` are read.

    Raises:
        RuntimeError: wrapping any exception raised by the forward pass.
    """
    model.train()

    # numel() handles 0-dim parameters, and sum() yields 0 for a stage without
    # parameters; a reduce()-based product/sum raises TypeError in both cases.
    num_params = sum(p.numel() for p in model.parameters())
    if model.group:
        total = torch.Tensor([num_params]).cuda()
        torch.distributed.all_reduce(total, group=model.group)
        logging.info(
            f"training model, #prams = {num_params}, group: {model.group.rank()}, grank:"
            f" {torch.distributed.get_rank()}, sizes {model.group.size()}"
        )
        torch.distributed.barrier()
        if model.group.rank() == 0:
            logging.info(f"total #prams = {total.item()}")
    else:
        logging.info(f"training model, #prams = {num_params}")

    total_loss = 0.0
    start_time = time.time()
    word_counter = 0
    log_interval = 1  # report after every batch (batch 0 is never reported)

    optimizer = optimizer(model)

    def get_first_device(model):
        if isinstance(model, DDP):
            model = model.module
        if model.devices:
            return model.devices[0]
        else:
            return torch.cuda.current_device()

    def get_last_device(model):
        if isinstance(model, DDP):
            model = model.module
        if model.devices:
            return model.devices[-1]
        else:
            return torch.cuda.current_device()

    pipe_group = model.group
    # Loss is computed, back-propagated and logged only on the last pipeline
    # stage, or on the only process when there is no pipeline group.
    is_last_stage = pipe_group is None or pipe_group.rank() == pipe_group.size() - 1

    if args.ddp_zero:
        model = DDP(
            model,
            device_ids=[torch.cuda.current_device()],
            process_group=get_data_parallel_group(),
            find_unused_parameters=False,
        )

    if pipe_group and pipe_group.rank() != 0 and pipe_group.rank() != (pipe_group.size() - 1):
        # Middle stages do not consume real batches (presumably Pipe feeds them
        # activations from neighbouring stages — confirm). Give them a cheap
        # placeholder that yields the same number of iterations.
        thing = {"input": torch.zeros(args.batch_size)}
        # Capture the length now: ``lm_dataloader`` is rebound below, and
        # reading it lazily inside __len__ would recurse into the placeholder.
        num_batches = len(lm_dataloader)

        class FakeDataset:
            def __getitem__(self, index):
                # ``enumerate`` falls back to the sequence protocol here, which
                # only stops on IndexError.
                if index >= num_batches:
                    raise IndexError(index)
                return thing

            def __len__(self):
                return num_batches

        lm_dataloader = FakeDataset()

    for i, batch in enumerate(lm_dataloader):
        # Batch index ``args.max_batch`` still runs, so up to ``max_batch + 1``
        # batches are processed; 0/None disables the limit.
        if args.max_batch and i > args.max_batch:
            break
        optimizer.zero_grad()
        try:
            if (pipe_group is None or pipe_group.rank() == 0) and not args.ddp_zero:
                output = model(batch["input"].to(get_first_device(model)))
            else:
                output = model(batch["input"])
        except Exception as e:
            raise RuntimeError(f"training failed on {safe_rank()}") from e

        if is_last_stage:
            target = batch["target"].to(get_last_device(model))
            output = output.to(target.device)

            loss = criterion(output.view(-1, vocab_size), target.view(-1))
            if args.ddp_zero:
                # Average the loss across the data-parallel replicas.
                ddp_group = get_data_parallel_group()
                torch.distributed.all_reduce(loss, op=torch.distributed.ReduceOp.SUM, group=ddp_group)
                loss /= ddp_group.size()
            loss.backward()
            del target
        else:
            # Non-final stages have no loss; back_helper runs their share of the
            # backward pass.
            if args.ddp_zero:
                model.module.back_helper(output)
            else:
                model.back_helper(output)

        del output

        torch.nn.utils.clip_grad_value_(model.parameters(), 0.05)
        optimizer.step()

        if is_last_stage:
            total_loss += loss.item()
            word_counter += batch["ntokens"]
            if i % log_interval == 0 and i > 0:
                cur_loss = total_loss / log_interval
                elapsed = time.time() - start_time
                print(
                    "| batch {:5d} | wps {:5.2f} | loss {:5.2f} | ppl {:8.2f}".format(
                        i, word_counter / elapsed, cur_loss, math.exp(cur_loss)
                    )
                )
                word_counter = 0
                total_loss = 0
                start_time = time.time()
Mandeep Singh Baines's avatar
Mandeep Singh Baines committed
398
399
400
401
402
403
404
405
406


def evaluate(eval_model, data_source, criterion, bptt, ntokens):
    """Return the length-weighted average loss of ``eval_model`` on ``data_source``.

    Walks ``data_source`` in BPTT windows of up to ``bptt`` rows via
    ``get_batch``. Each window's loss is weighted by its row count, and the
    sum is divided by ``len(data_source) - 1``. Runs in eval mode without
    autograd.
    """
    eval_model.eval()
    weighted_sum = 0.0
    with torch.no_grad():
        for start in range(0, data_source.size(0) - 1, bptt):
            data, targets = get_batch(data_source, start, bptt)
            logits = eval_model(data).to(targets.device)
            window_loss = criterion(logits.view(-1, ntokens), targets).item()
            weighted_sum += len(data) * window_loss
    return weighted_sum / (len(data_source) - 1)


def get_number_of_words(data):
    """Return ``rows * columns`` for ``data``, i.e. the product of its first two dimensions."""
    rows, columns = data.size(0), data.size(1)
    return rows * columns


Tom Birch's avatar
Tom Birch committed
417
def benchmark_language_model(train_data, val_data, test_data, model, criterion, optimizer, ntokens, args):
    """Train for one epoch, print throughput, and check against golden numbers.

    Validation and test losses are placeholders (``1``); the real
    ``evaluate`` calls are commented out. The golden checks run only when
    ``can_benchmark`` is true and the model has exactly four partitions
    (``len(model.balance) == 4``). They assert words-per-second and the peak
    allocated bytes on ``cuda:0``..``cuda:3`` against recorded runs.

    Raises:
        AssertionError: on a throughput or memory regression.
    """
    epoch = 1
    bptt = 35  # only used by the commented-out evaluate() calls
    start_time = time.time()

    print("-" * 110)
    print("| start of epoch {:1d}".format(epoch))
    print("-" * 110)
    epoch_start_time = time.time()
    # train() takes (dataloader, model, criterion, optimizer, vocab_size, args);
    # passing ``bptt`` as an extra positional argument raised TypeError.
    # NOTE(review): train() indexes batches with batch["input"], while
    # ``train_data`` from get_data() is a plain tensor — this path likely
    # still needs a dataloader adapter before it can run.
    train(train_data, model, criterion, optimizer, ntokens, args)
    val_loss = 1  # evaluate(model, val_data, criterion, bptt, ntokens)
    print("-" * 89)
    print(
        "| end of epoch {:1d} | time: {:5.2f}s | valid loss {:5.2f} ".format(
            epoch, (time.time() - epoch_start_time), val_loss
        )
    )
    print("-" * 110)

    elapsed_time = time.time() - start_time
    nwords = get_number_of_words(train_data) + get_number_of_words(val_data)
    wps = nwords / elapsed_time

    test_loss = 1  # evaluate(model, test_data, criterion, bptt, ntokens)
    print("=" * 89)
    print(
        "| end of training | test loss {:5.2f} \n| time: {:5.2f}s | words: {:3d} | wps: {:5.2f}".format(
            test_loss, elapsed_time, nwords, wps
        )
    )
    print("=" * 110)

    if can_benchmark and len(model.balance) == 4:
        # Assert that words per second is within 3 standard deviations of the average
        # of six golden runs
        assert wps > 36954.4 - (3 * 116.825)

        peaks = [torch.cuda.memory_stats(dev)["allocated_bytes.all.peak"] for dev in range(4)]
        for dev, peak in enumerate(peaks):
            print("Peak allocated bytes on cuda:{}: {:1d}".format(dev, peak))

        # Assert that memory usage on each GPU is within 10% of golden run
        # Right-hand-side is golden run bytes * 110%
        golden_peak_bytes = [4061909504, 4050944, 10427392, 2031824896]
        for peak, golden in zip(peaks, golden_peak_bytes):
            assert peak < golden * 1.1

        print("No regression detected")


468
469
470
471
472
473
474
475
476
477
478
def generate_balance_weighted(num_devices, num_layers, fraction=0.5):
    """Split ``num_layers`` across ``num_devices`` with a reduced share on the last device.

    The last device receives ``int(fraction * num_layers / num_devices)``
    layers. The remaining layers are spread over the other devices with
    ``generate_balance``. With a single device, every layer goes to it.

    Raises:
        ValueError: if ``num_devices`` is less than 1.
    """
    if num_devices < 1:
        raise ValueError(f"num_devices must be >= 1, got {num_devices}")
    if num_devices == 1:
        # generate_balance(0, ...) returns [], so the weighted tail alone would
        # not account for every layer.
        return [num_layers]

    average_count = num_layers / num_devices
    last_layers = int(average_count * fraction)

    balance = generate_balance(num_devices - 1, num_layers - last_layers)
    balance.append(last_layers)
    return balance


479
480
481
482
483
484
485
486
487
488
489
490
491
def generate_balance(num_devices, num_layers):
    """Distribute ``num_layers`` as evenly as possible over ``num_devices``.

    Each device in turn takes the ceiling of (layers still unassigned /
    devices still remaining), so earlier devices absorb any remainder.
    For example, ``generate_balance(4, 13) == [4, 3, 3, 3]``.
    """
    balance = []
    remaining = num_layers
    for slots_left in range(num_devices, 0, -1):
        share = -(-remaining // slots_left)  # integer ceiling division
        balance.append(share)
        remaining -= share
    return balance

Mandeep Singh Baines's avatar
Mandeep Singh Baines committed
492

Tom Birch's avatar
Tom Birch committed
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
def make_model_and_data(args, device, new_data: bool = True):
    """Build the model, loss and optimizer factory together with training data.

    Args:
        args: parsed CLI arguments, forwarded to ``make_model``.
        device: device for the eagerly built model. ``None`` (what every
            caller in this file passes) selects ``torch.device("cuda")``.
        new_data: when true, use the synthetic ``BenchmarkLMDataset`` with a
            fixed 10000-token vocabulary. Otherwise load WikiText-2 via
            ``get_data``.

    Returns:
        dict with ``"model"``, ``"criterion"``, ``"optimizer"`` and ``"data"``.
        The synthetic-data variant also has ``"vocab_size"``.
    """
    # Honour an explicitly supplied device instead of always overwriting it.
    if device is None:
        device = torch.device("cuda")

    if new_data:
        vocab_size = 10000
        model, criterion, optimizer, _scaler = make_model(args, device, vocab_size)
        lm_dataset = BenchmarkLMDataset()
        lm_dataloader = DataLoader(
            lm_dataset, batch_size=args.batch_size, shuffle=True, num_workers=0, collate_fn=collate_sentences_lm
        )
        return {
            "model": model,
            "criterion": criterion,
            "optimizer": optimizer,
            "data": lm_dataloader,
            "vocab_size": vocab_size,
        }

    data = get_data(device)
    ntokens = data[0]
    model, criterion, optimizer, _scaler = make_model(args, device, ntokens)
    return {
        "model": model,
        "criterion": criterion,
        "optimizer": optimizer,
        "data": data,
    }


def bench_single_process(args):
    """Benchmark the model pipelined across up to four local GPUs in one process.

    Raises:
        AssertionError: if no CUDA device is visible.
    """
    num_devices = torch.cuda.device_count()
    assert num_devices > 0
    init_random_seed(0)

    new_data = True

    blob = make_model_and_data(args, None, new_data=new_data)
    model = blob["model"]

    balance = generate_balance(min(num_devices, 4), len(model))
    p = pipe.Pipe(
        model, balance, chunks=args.chunks, pipelined_backward=args.pipelined_backward, checkpoint=args.checkpoint
    )
    # Drop this function's references to the un-partitioned model.
    del model
    del blob["model"]

    if new_data:
        train(blob["data"], p, blob["criterion"], blob["optimizer"], blob["vocab_size"], args)
    else:
        ntokens, train_data, val_data, test_data = blob["data"]
        # criterion/optimizer must come from ``blob``; bare names were unbound here.
        benchmark_language_model(
            train_data, val_data, test_data, p, blob["criterion"], blob["optimizer"], ntokens, args
        )


def run_mp_worker(args, available_workers):
    """Build the model and train it with an asynchronous multi-process Pipe.

    The pipeline size comes from ``get_pipeline_parallel_group()``;
    ``available_workers`` is not used by this function.
    """
    new_data = True

    blob = make_model_and_data(args, None, new_data=new_data)
    model = blob["model"]

    # The last stage receives a reduced share of layers (80% of the average).
    balance = generate_balance_weighted(get_pipeline_parallel_group().size(), len(model), 0.8)
    p = pipe.Pipe(
        model,
        balance,
        style=Pipe.AsyncSchedule,
        chunks=args.chunks,
        worker_map=get_worker_map(),
        input_device=torch.cuda.current_device(),
        pipelined_backward=args.pipelined_backward,
        checkpoint=args.checkpoint,
        # loss_fn=blob["criterion"],
    ).cuda()

    if args.all_at_once and p.pipeline:
        print("running all at once")
        p.pipeline.all_at_once = True

    if new_data:
        train(blob["data"], p, blob["criterion"], blob["optimizer"], blob["vocab_size"], args)
    else:
        ntokens, train_data, val_data, test_data = blob["data"]
        # criterion/optimizer must come from ``blob``; bare names were unbound here.
        benchmark_language_model(
            train_data, val_data, test_data, p, blob["criterion"], blob["optimizer"], ntokens, args
        )


def run_worker(rank, world_size, args):
    """Body of one ``mp.spawn`` worker: set up distributed state, train, tear down.

    A non-zero ``args.world_size`` overrides ``world_size``. The global rank
    is ``rank + args.rank_base``.
    """
    effective_world_size = args.world_size if args.world_size != 0 else world_size
    dist_init(rank + args.rank_base, effective_world_size, hostname=args.host)
    initialize_model_parallel(1, effective_world_size)
    init_random_seed(0)
    run_mp_worker(args, effective_world_size)

    rpc.shutdown()
    torch.distributed.destroy_process_group()


def bench_multi_process(args, all_at_once=False):
    """Spawn one ``run_worker`` process per local rank and wait for all of them.

    Uses ``args.local_world_size`` processes when it is non-zero, otherwise
    ``min(torch.cuda.device_count(), 2)``. ``all_at_once`` is accepted but not
    used here.
    """
    if args.local_world_size == 0:
        world_size = min(torch.cuda.device_count(), 2)
    else:
        world_size = args.local_world_size
    mp.spawn(run_worker, args=(world_size, args), nprocs=world_size, join=True)


# UCX network device per node-local rank (used by bench_mpi for UCX_NET_DEVICES):
# local ranks 2k and 2k+1 share device ``mlx5_k``, port 1, for k = 0..3.
best_device_map = {local_rank: f"mlx5_{local_rank // 2}:1" for local_rank in range(8)}


def bench_mpi(args):
    """Benchmark entry point when launched by Open MPI.

    Steps:

    1. Read rank, world size and local rank from the ``OMPI_COMM_WORLD_*``
       environment variables.
    2. Initialise a gloo process group (rendezvous port 10638).
    3. Start an RPC agent (port 10639).
    4. Set up the model-parallel groups.
    5. Train via ``run_mp_worker``, then shut everything down.
    """
    guess_rank = int(os.environ["OMPI_COMM_WORLD_RANK"])
    world_size = int(os.environ["OMPI_COMM_WORLD_SIZE"])
    local_rank = int(os.environ["OMPI_COMM_WORLD_LOCAL_RANK"])
    # Pin UCX to the preferred network device for this local rank. Ranks without
    # an entry (more than 8 per node) keep UCX's default instead of a KeyError.
    ucx_device = best_device_map.get(local_rank)
    if ucx_device is not None:
        os.environ["UCX_NET_DEVICES"] = ucx_device

    os.environ["MASTER_ADDR"] = args.host
    os.environ["MASTER_PORT"] = "10638"
    if args.socket_name:
        os.environ["GLOO_SOCKET_IFNAME"] = args.socket_name
        os.environ["TP_SOCKET_IFNAME"] = args.socket_name

    torch.distributed.init_process_group(backend="gloo", rank=guess_rank, world_size=world_size)

    # The RPC agent rendezvouses on a different port than the process group above.
    os.environ["MASTER_ADDR"] = args.host
    os.environ["MASTER_PORT"] = "10639"
    init_method = f"tcp://{os.environ['MASTER_ADDR']}:{os.environ['MASTER_PORT']}"
    rank = torch.distributed.get_rank()
    world_size = torch.distributed.get_world_size()
    torch.cuda.set_device(local_rank % torch.cuda.device_count())

    rpc.init_rpc(
        f"Test{rank}",
        rank=rank,
        world_size=world_size,
        backend=rpc.BackendType.PROCESS_GROUP,
        rpc_backend_options=rpc.ProcessGroupRpcBackendOptions(rpc_timeout=20, init_method=init_method),
    )

    backends = {"model_parallel_backend": "nccl", "pipeline_backend": "mpi", "ddp_backend": "nccl"}

    if args.ddp_zero:
        # NOTE(review): the second argument (presumably the pipeline length) is
        # hard-coded to 4 in DDP+ZeRO mode — confirm.
        initialize_model_parallel(1, 4, **backends)
    else:
        initialize_model_parallel(1, world_size, **backends)

    init_random_seed(0)

    run_mp_worker(args, world_size)

    rpc.shutdown()
    torch.distributed.destroy_process_group()


# Command-line interface for the benchmark.
parser = argparse.ArgumentParser(description="benchmark")
parser.add_argument("--local-world-size", "-l", type=int, default=0, help="local world size")
parser.add_argument("--world-size", "-w", type=int, default=0, help="world size")
parser.add_argument("--rank-base", "-r", type=int, help="rank base", default=0)
parser.add_argument("--host", "-o", type=str, default="localhost", help="hostname")
parser.add_argument("--no-mpi", action="store_true", default=False, help="disable mpi")
parser.add_argument("--chunks", type=int, default=1, help="number of microbatches per batch")
parser.add_argument("--batch-size", type=int, default=8, help="size of a batch")
parser.add_argument("--all-at-once", action="store_true", default=False, help="do backward pass on whole batch at once")
parser.add_argument(
    "--max-batch", type=int, default=4, help="Index of the last batch to train on (0 disables the limit)"
)
parser.add_argument("--socket-name", type=str, default=None, help="socket ifname for gloo/tp")
parser.add_argument("--num-decoder-layers", type=int, default=10, help="Number of decoder layers in the model")
parser.add_argument(
    "--ddp-zero",
    action="store_true",
    default=False,
    help="enable DDP over the data-parallel group with an OSS-sharded optimizer",
)
parser.add_argument(
    "--lazy-construction",
    action="store_true",
    default=False,
    help="Build model layers lazily via LazyModule instead of constructing the full model up front",
)
parser.add_argument(
    "--checkpoint", default="never", choices=["always", "except_last", "never"], help="Checkpointing strategy for pipe"
)
parser.add_argument(
    "--pipelined-backward", dest="pipelined_backward", action="store_true", help="Pipelined backward pass"
)
parser.add_argument(
    "--no-pipelined-backward",
    dest="pipelined_backward",
    action="store_false",
    help="Disable the pipelined backward pass",
)
parser.set_defaults(pipelined_backward=True)

if __name__ == "__main__":
    args = parser.parse_args()
    # Use the MPI path only when launched by Open MPI and not disabled via --no-mpi.
    use_mpi = not args.no_mpi and "OMPI_COMM_WORLD_RANK" in os.environ
    if not use_mpi:
        print(f"Running benchmark with args: {args}")
        bench_single_process(args)
    else:
        # Only rank 0 echoes the configuration, to avoid duplicated output.
        if os.environ["OMPI_COMM_WORLD_RANK"] == "0":
            print(f"Running benchmark with args: {args}")
        bench_mpi(args)