# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.

import math
import time

import torch
import torch.nn as nn
import torchtext
from torchtext.data.utils import get_tokenizer

import fairscale.nn.pipe.pipe as pipe

# Prefer fairscale's Adam implementation; fall back to the stock PyTorch
# optimizer when it cannot be imported.  `can_benchmark` records which one was
# picked: the golden-run regression assertions in benchmark_language_model()
# are only evaluated when it is True.
try:
    from fairscale.optim.adam import Adam  # type: ignore

    can_benchmark = True
except ImportError:
    from torch.optim import Adam  # type: ignore

    can_benchmark = False


class EmbeddingLayer(nn.Embedding):
    """Embedding table whose looked-up vectors are scaled by ``sqrt(ninp)``.

    Args:
        ntoken: number of rows in the embedding table.
        ninp: width of each embedding vector.
        initrange: weights are drawn uniformly from ``[-initrange, initrange]``.
    """

    def __init__(self, ntoken, ninp, initrange):
        super().__init__(ntoken, ninp)
        self.ninp = ninp
        # Replace nn.Embedding's default initialisation with a uniform one.
        with torch.no_grad():
            self.weight.uniform_(-initrange, initrange)

    def forward(self, src):
        """Return the embeddings of ``src`` multiplied by ``sqrt(ninp)``."""
        embedded = super().forward(src)
        scale = math.sqrt(self.ninp)
        return embedded * scale


class PositionalEncodingLayer(nn.Module):
    """Add fixed sinusoidal position encodings to the input, then apply dropout.

    The encoding table is precomputed once and stored as the (non-trainable)
    buffer ``pe`` of shape ``(max_len, 1, d_model)``: even feature columns hold
    sines, odd feature columns hold cosines.

    Args:
        d_model: size of the last (feature) dimension of the input.  Both even
            and odd values are supported.
        dropout: dropout probability applied after the encoding is added.
        max_len: number of positions to precompute; ``forward`` slices the
            first ``x.size(0)`` of them.
    """

    def __init__(self, d_model, dropout=0.1, max_len=5000):
        super().__init__()
        self.dropout = nn.Dropout(p=dropout)

        position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model))
        angles = position * div_term  # (max_len, ceil(d_model / 2))

        pe = torch.zeros(max_len, d_model)
        pe[:, 0::2] = torch.sin(angles)
        # With an odd d_model there is one fewer odd column than even column,
        # so trim the angle table before filling the cosine slots.
        pe[:, 1::2] = torch.cos(angles[:, : d_model // 2])
        # Insert a singleton dimension so the table broadcasts over dim 1 of the input.
        self.register_buffer("pe", pe.unsqueeze(1))

    def forward(self, x):
        """Add the first ``x.size(0)`` position encodings to ``x`` and apply dropout."""
        x = x + self.pe[: x.size(0), :]
        return self.dropout(x)


class TransformerDecoderLayer(nn.TransformerEncoderLayer):
    """Though this class inherits from torch.nn.TransformerEncoderLayer,
    it functions as a decoder in this model: every forward pass applies a
    square "subsequent" mask so a position cannot attend to later positions.

    The mask is cached in ``self.src_mask`` and rebuilt whenever the input
    length or the input device changes.
    """

    def __init__(self, ninp, nhead, nhid, droupout):
        # NOTE: the parameter name "droupout" is a misspelling of "dropout"; it
        # is kept as-is so existing keyword callers keep working.
        super().__init__(ninp, nhead, nhid, droupout)
        self.src_mask = None

    def _generate_square_subsequent_mask(self, sz, device=None):
        """Return an ``(sz, sz)`` float mask: 0.0 on and below the diagonal, -inf above it.

        Args:
            sz: side length of the mask.
            device: optional device to create the mask on (default: PyTorch's
                current default device).
        """
        blocked = torch.full((sz, sz), float("-inf"), device=device)
        # triu(diagonal=1) keeps the strictly-upper triangle and zeroes the rest.
        return torch.triu(blocked, diagonal=1)

    def forward(self, src):
        """Run the encoder layer on ``src`` with the causal mask applied."""
        seq_len = len(src)
        mask = self.src_mask
        # The cache is a plain attribute (not a buffer), so it does not follow
        # the module across devices; compare devices as well as length.
        if mask is None or mask.size(0) != seq_len or mask.device != src.device:
            mask = self._generate_square_subsequent_mask(seq_len, device=src.device)
            self.src_mask = mask

        return super().forward(src, mask)


class LinearLayer(nn.Linear):
    """Linear projection with a zeroed bias and uniformly initialised weights.

    Args:
        ninp: input feature size.
        ntoken: output feature size.
        initrange: weights are drawn uniformly from ``[-initrange, initrange]``.
    """

    def __init__(self, ninp, ntoken, initrange):
        super().__init__(ninp, ntoken)
        # Overwrite nn.Linear's default initialisation in place.
        with torch.no_grad():
            self.weight.uniform_(-initrange, initrange)
            self.bias.zero_()


class TransformerLMSequntial(nn.Sequential):
    """A small language model based on the design of GPT-2, built as an
    nn.Sequential for compatibility with Pipe.

    The stages, in order, are: EmbeddingLayer, PositionalEncodingLayer,
    TransformerDecoderLayer and LinearLayer.
    """

    def __init__(self, ntokens, ninp, nhead, nhid, dropout, initrange):
        stages = (
            EmbeddingLayer(ntokens, ninp, initrange),
            PositionalEncodingLayer(ninp, dropout),
            TransformerDecoderLayer(ninp, nhead, nhid, dropout),
            LinearLayer(ninp, ntokens, initrange),
        )
        super().__init__(*stages)


def get_data(device):
    """Load WikiText2 through torchtext and return it as batched tensors.

    Args:
        device: device the batched tensors are moved to.

    Returns:
        ``(ntokens, train_data, val_data, test_data)`` where ``ntokens`` is the
        size of the vocabulary built from the training split and the other
        three are the outputs of ``batchify`` for the respective splits.
    """
    # NOTE(review): torchtext.data.Field / WikiText2.splits look like
    # torchtext's older API - confirm the pinned torchtext version provides them.
    tokenizer = get_tokenizer("basic_english")
    TEXT = torchtext.data.Field(tokenize=tokenizer, init_token="<sos>", eos_token="<eos>", lower=True)
    splits = torchtext.datasets.WikiText2.splits(TEXT)
    train_txt, val_txt, test_txt = splits

    # The vocabulary is built from the training split only.
    TEXT.build_vocab(train_txt)
    ntokens = len(TEXT.vocab.stoi)

    batch_size = 20
    eval_batch_size = 10
    train_data = batchify(train_txt, batch_size, TEXT, device)
    val_data, test_data = (batchify(split, eval_batch_size, TEXT, device) for split in (val_txt, test_txt))

    return ntokens, train_data, val_data, test_data


def batchify(data, bsz, TEXT, device):
    """Numericalize a dataset and lay it out as ``bsz`` parallel columns.

    Args:
        data: dataset whose first example's ``text`` holds the tokens.
        bsz: number of columns (parallel streams) in the result.
        TEXT: field object providing ``numericalize``.
        device: device the result is moved to.

    Returns:
        A contiguous tensor with ``bsz`` columns; trailing elements that do not
        fill a whole row of ``bsz`` are dropped.
    """
    # presumably numericalize() returns a tensor indexed by token position on
    # dim 0 - TODO confirm against torchtext.
    token_ids = TEXT.numericalize([data.examples[0].text])
    usable = (token_ids.size(0) // bsz) * bsz
    trimmed = token_ids.narrow(0, 0, usable)
    columns = trimmed.view(bsz, -1).t().contiguous()
    return columns.to(device)


def get_batch(source, i, bptt):
    """Slice one input window and its next-step targets out of ``source``.

    Args:
        source: tensor indexed along dim 0.
        i: index of the first row of the window.
        bptt: maximum window length.

    Returns:
        ``(data, target)``: ``data`` is ``source[i : i + seq_len]`` and
        ``target`` is the same window shifted by one row, flattened to 1-D.
        ``seq_len`` is ``bptt`` capped so the shifted window stays in bounds.
    """
    remaining = len(source) - 1 - i
    seq_len = bptt if bptt <= remaining else remaining
    stop = i + seq_len
    data = source[i:stop]
    target = source[i + 1 : stop + 1].view(-1)
    return data, target


def make_model(device, ntokens):
    """Build the benchmark model together with its loss and optimizer.

    Args:
        device: device the model is moved to.
        ntokens: vocabulary size, used for both the embedding table and the
            output projection.

    Returns:
        ``(model, criterion, optimizer)``: a ``TransformerLMSequntial``, an
        ``nn.CrossEntropyLoss`` and an ``Adam`` optimizer over the model's
        parameters.
    """
    ninp = 50  # embedding dimension
    nhid = 50  # the dimension of the feedforward network model in nn.TransformerEncoder
    nhead = 2  # the number of heads in the multiheadattention models
    dropout = 0
    initrange = 0.1

    model = TransformerLMSequntial(ntokens, ninp, nhead, nhid, dropout, initrange).to(device)

    criterion = nn.CrossEntropyLoss()
    lr = 0.01  # learning rate
    optimizer = Adam(model.parameters(), lr=lr)

    return model, criterion, optimizer


def train(train_data, model, criterion, optimizer, bptt, ntokens):
    """Run one pass over ``train_data``, logging progress every 200 batches.

    Args:
        train_data: batched training tensor; windows of up to ``bptt`` rows are
            taken along dim 0 with ``get_batch``.
        model: module to train; called as ``model(data)``.
        criterion: loss called as ``criterion(output.view(-1, ntokens), targets)``.
        optimizer: optimizer stepped once per batch.
        bptt: window length passed to ``get_batch``.
        ntokens: size of the model's output dimension.
    """
    model.train()
    log_interval = 200
    num_batches = len(train_data) // bptt

    total_loss = 0.0
    window_batches = 0  # batches accumulated since the last log line
    start_time = time.time()
    for batch, i in enumerate(range(0, train_data.size(0) - 1, bptt)):
        data, targets = get_batch(train_data, i, bptt)
        optimizer.zero_grad()
        output = model(data)
        # The model output may live on a different device than the targets.
        output = output.to(targets.device)

        loss = criterion(output.view(-1, ntokens), targets)
        loss.backward()

        torch.nn.utils.clip_grad_value_(model.parameters(), 0.05)
        optimizer.step()

        total_loss += loss.item()
        window_batches += 1
        if batch % log_interval == 0 and batch > 0:
            # Divide by the number of batches actually accumulated: the first
            # window covers batches 0..log_interval, i.e. log_interval + 1 batches.
            cur_loss = total_loss / window_batches
            elapsed = time.time() - start_time
            print(
                "| {:5d}/{:5d} batches | ms/batch {:5.2f} | "
                "loss {:5.2f} | ppl {:8.2f}".format(
                    batch, num_batches, elapsed * 1000 / window_batches, cur_loss, math.exp(cur_loss)
                )
            )
            total_loss = 0.0
            window_batches = 0
            start_time = time.time()


def evaluate(eval_model, data_source, criterion, bptt, ntokens):
    """Return the average loss of ``eval_model`` over ``data_source``.

    The model is switched to eval mode and run without gradient tracking.
    Each window's loss is weighted by its length (``len(data)``) and the sum is
    divided by ``len(data_source) - 1``, the number of rows that have a target.

    Args:
        eval_model: module to evaluate; called as ``eval_model(data)``.
        data_source: batched tensor windowed along dim 0 with ``get_batch``.
        criterion: loss called as ``criterion(output_flat, targets)``.
        bptt: window length passed to ``get_batch``.
        ntokens: size of the model's output dimension.
    """
    eval_model.eval()
    total_loss = 0.0
    with torch.no_grad():
        for i in range(0, data_source.size(0) - 1, bptt):
            data, targets = get_batch(data_source, i, bptt)
            output = eval_model(data)
            # The model output may live on a different device than the targets.
            output = output.to(targets.device)
            output_flat = output.view(-1, ntokens)
            total_loss += len(data) * criterion(output_flat, targets).item()
    return total_loss / (len(data_source) - 1)


def get_number_of_words(data):
    """Return the product of the first two dimensions of ``data``."""
    rows = data.size(0)
    cols = data.size(1)
    return rows * cols


def benchmark_language_model(train_data, val_data, test_data, model, criterion, optimizer, ntokens):
    """Train for one epoch, evaluate, print timings and check for regressions.

    After training on ``train_data`` and evaluating on ``val_data`` the
    throughput (words per second over train + validation) is computed, then the
    test loss is reported.  When ``can_benchmark`` is set and the model has a
    4-way ``balance``, throughput and per-GPU peak memory are asserted against
    recorded golden-run values.

    Args:
        train_data, val_data, test_data: batched tensors for the three splits.
        model: the (pipelined) model; must expose ``balance``.
        criterion: loss function forwarded to ``train`` / ``evaluate``.
        optimizer: optimizer forwarded to ``train``.
        ntokens: size of the model's output dimension.

    Raises:
        AssertionError: if a golden-run regression check fails.
    """
    epoch = 1
    bptt = 35
    rule = "-" * 89
    start_time = time.time()

    print(rule)
    print("| start of epoch {:1d}".format(epoch))
    print(rule)
    epoch_start_time = time.time()
    train(train_data, model, criterion, optimizer, bptt, ntokens)
    val_loss = evaluate(model, val_data, criterion, bptt, ntokens)
    print(rule)
    print(
        "| end of epoch {:1d} | time: {:5.2f}s | valid loss {:5.2f} ".format(
            epoch, (time.time() - epoch_start_time), val_loss
        )
    )
    print(rule)

    # Throughput covers training plus validation, measured before the test pass.
    elapsed_time = time.time() - start_time
    nwords = get_number_of_words(train_data) + get_number_of_words(val_data)
    wps = nwords / elapsed_time

    test_loss = evaluate(model, test_data, criterion, bptt, ntokens)
    print("=" * 89)
    print(
        "| end of training | test loss {:5.2f} \n| time: {:5.2f}s | words: {:3d} | wps: {:5.2f}".format(
            test_loss, elapsed_time, nwords, wps
        )
    )
    print("=" * 89)

    if can_benchmark and len(model.balance) == 4:
        # Assert that words per second is within 3 standard deviations of the average
        # of six golden runs
        assert wps > 20052.1 - (3 * 359)

        # Golden-run peak allocated bytes, indexed by CUDA device.
        golden_peak_bytes = (365916160, 1281024, 2788864, 190724608)
        peaks = [
            torch.cuda.memory_stats(dev)["allocated_bytes.all.peak"] for dev in range(len(golden_peak_bytes))
        ]
        for dev, peak in enumerate(peaks):
            print("Peak allocated bytes on cuda:{:1d}: {:1d}".format(dev, peak))

        # Assert that memory usage on each GPU is within 10% of golden run
        # Right-hand-side is golden run bytes * 110%
        for peak, golden in zip(peaks, golden_peak_bytes):
            assert peak < golden * 1.1
        print("No regression detected")


def generate_balance(num_devices, num_layers):
    """Split ``num_layers`` across ``num_devices`` as evenly as possible.

    Each device in turn receives the ceiling of (layers still unassigned /
    devices still unfilled), so earlier devices get the larger shares.

    Returns:
        A list of ``num_devices`` integers, one layer count per device.
    """
    balance = []
    remaining = num_layers
    for devices_left in range(num_devices, 0, -1):
        share = math.ceil(remaining / devices_left)
        balance.append(share)
        remaining -= share
    return balance


if __name__ == "__main__":
    # The benchmark needs at least one CUDA device.
    num_devices = torch.cuda.device_count()
    assert num_devices > 0

    # Fixed seed so runs are comparable with the recorded golden values.
    torch.manual_seed(0)
    device = torch.device("cuda")
    ntokens, train_data, val_data, test_data = get_data(device)
    model, criterion, optimizer = make_model(device, ntokens)

    # Spread the sequential model's layers over at most four devices.
    balance = generate_balance(min(num_devices, 4), len(model))
    p = pipe.Pipe(model, balance)
    benchmark_language_model(train_data, val_data, test_data, p, criterion, optimizer, ntokens)
    del p