Commit 74181b08 authored by Jun Ru Anderson, committed by Mandeep Singh Baines

[feat] add Transformer gpipe benchmark

parent 0cd65242
[settings]
-known_third_party =models,pytest,setuptools,torch,torchtext
+known_third_party =pytest,setuptools,torch,torchtext
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.

import math

import torch
import torch.nn as nn


class TransformerModel(nn.Module):
    def __init__(self, ntoken, ninp=200, nhead=2, nhid=200, nlayers=2, dropout=0.5):
        super(TransformerModel, self).__init__()
        from torch.nn import TransformerEncoder, TransformerEncoderLayer

        self.model_type = "Transformer"
        self.src_mask = None
        self.pos_encoder = PositionalEncoding(ninp, dropout)
        encoder_layers = TransformerEncoderLayer(ninp, nhead, nhid, dropout)
        self.transformer_encoder = TransformerEncoder(encoder_layers, nlayers)
        self.encoder = nn.Embedding(ntoken, ninp)
        self.ninp = ninp
        self.decoder = nn.Linear(ninp, ntoken)

        self.init_weights()

    def _generate_square_subsequent_mask(self, sz):
        mask = (torch.triu(torch.ones(sz, sz)) == 1).transpose(0, 1)
        mask = mask.float().masked_fill(mask == 0, float("-inf")).masked_fill(mask == 1, float(0.0))
        return mask

    def init_weights(self):
        initrange = 0.1
        self.encoder.weight.data.uniform_(-initrange, initrange)
        self.decoder.bias.data.zero_()
        self.decoder.weight.data.uniform_(-initrange, initrange)

    def forward(self, src):
        if self.src_mask is None or self.src_mask.size(0) != len(src):
            device = src.device
            mask = self._generate_square_subsequent_mask(len(src)).to(device)
            self.src_mask = mask

        src = self.encoder(src) * math.sqrt(self.ninp)
        src = self.pos_encoder(src)
        output = self.transformer_encoder(src, self.src_mask)
        output = self.decoder(output)
        return output


class PositionalEncoding(nn.Module):
    def __init__(self, d_model, dropout=0.1, max_len=5000):
        super(PositionalEncoding, self).__init__()
        self.dropout = nn.Dropout(p=dropout)

        pe = torch.zeros(max_len, d_model)
        position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model))
        pe[:, 0::2] = torch.sin(position * div_term)
        pe[:, 1::2] = torch.cos(position * div_term)
        pe = pe.unsqueeze(0).transpose(0, 1)
        self.register_buffer("pe", pe)

    def forward(self, x):
        x = x + self.pe[: x.size(0), :]
        return self.dropout(x)
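# Editor's sketch (not part of the commit): a quick sanity check of the causal
# mask and a forward pass of the standalone TransformerModel above.
#
#   model = TransformerModel(ntoken=100)
#   model._generate_square_subsequent_mask(3)
#   # tensor([[0., -inf, -inf],
#   #         [0., 0., -inf],
#   #         [0., 0., 0.]])
#   src = torch.randint(100, (5, 2))   # (seq_len, batch) of token ids
#   out = model(src)                   # -> (5, 2, 100): logits over the vocabulary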
@@ -3,12 +3,83 @@
 import math
 import time
 
-from models import transformerModel as transformer
-
 import torch
 import torch.nn as nn
 import torchtext
 from torchtext.data.utils import get_tokenizer
 
+import fairscale.nn.pipe.pipe as pipe
+
+
+class EmbeddingLayer(nn.Embedding):
+    def __init__(self, ntoken, ninp, initrange):
+        super().__init__(ntoken, ninp)
+        self.ninp = ninp
+        self.weight.data.uniform_(-initrange, initrange)
+
+    def forward(self, src):
+        # Scale embeddings by sqrt(ninp), as in the original Transformer.
+        return super().forward(src) * math.sqrt(self.ninp)
+
+
+class PositionalEncodingLayer(nn.Module):
+    def __init__(self, d_model, dropout=0.1, max_len=5000):
+        super(PositionalEncodingLayer, self).__init__()
+        self.dropout = nn.Dropout(p=dropout)
+
+        pe = torch.zeros(max_len, d_model)
+        position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
+        div_term = torch.exp(torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model))
+        pe[:, 0::2] = torch.sin(position * div_term)
+        pe[:, 1::2] = torch.cos(position * div_term)
+        pe = pe.unsqueeze(0).transpose(0, 1)
+        self.register_buffer("pe", pe)
+
+    def forward(self, x):
+        x = x + self.pe[: x.size(0), :]
+        return self.dropout(x)
+
+
+class TransformerDecoderLayer(nn.TransformerEncoderLayer):
+    """Though this class inherits from torch.nn.TransformerEncoderLayer,
+    it functions as a decoder in this model."""
+
+    def __init__(self, ninp, nhead, nhid, dropout):
+        super().__init__(ninp, nhead, nhid, dropout)
+        self.src_mask = None
+
+    def _generate_square_subsequent_mask(self, sz):
+        mask = (torch.triu(torch.ones(sz, sz)) == 1).transpose(0, 1)
+        mask = mask.float().masked_fill(mask == 0, float("-inf")).masked_fill(mask == 1, float(0.0))
+        return mask
+
+    def forward(self, src):
+        if self.src_mask is None or self.src_mask.size(0) != len(src):
+            device = src.device
+            mask = self._generate_square_subsequent_mask(len(src)).to(device)
+            self.src_mask = mask
+
+        return super().forward(src, self.src_mask)
+
+
+class LinearLayer(nn.Linear):
+    def __init__(self, ninp, ntoken, initrange):
+        super().__init__(ninp, ntoken)
+        self.bias.data.zero_()
+        self.weight.data.uniform_(-initrange, initrange)
+
+
+class TransformerLMSequential(nn.Sequential):
+    """A small language model based on the design of GPT-2, expressed as an
+    nn.Sequential for compatibility with Pipe."""
+
+    def __init__(self, ntokens, ninp, nhead, nhid, dropout, initrange):
+        super(TransformerLMSequential, self).__init__(
+            EmbeddingLayer(ntokens, ninp, initrange),
+            PositionalEncodingLayer(ninp, dropout),
+            TransformerDecoderLayer(ninp, nhead, nhid, dropout),
+            LinearLayer(ninp, ntokens, initrange),
+        )
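# Editor's note (not part of the commit): Pipe consumes an nn.Sequential and
# splits its child modules across GPUs according to a `balance` list; with the
# four stages above, balance=[2, 2] would place the embedding and positional
# encoding on cuda:0 and the decoder layer and output projection on cuda:1.
#
#   lm = TransformerLMSequential(ntokens=1000, ninp=50, nhead=2, nhid=50, dropout=0, initrange=0.1)
#   lm = pipe.Pipe(lm, balance=[2, 2])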
def get_data(device):
TEXT = torchtext.data.Field(
@@ -43,14 +114,16 @@ def get_batch(source, i, bptt):
 def make_model(device, ntokens):
-    emsize = 50  # embedding dimension
+    ninp = 50  # embedding dimension
     nhid = 50  # the dimension of the feedforward network model in nn.TransformerEncoder
-    nlayers = 1  # the number of nn.TransformerEncoderLayer in nn.TransformerEncoder
     nhead = 2  # the number of heads in the multiheadattention models
-    dropout = 0.2  # the dropout value
-    model = transformer.TransformerModel(ntokens, emsize, nhead, nhid, nlayers, dropout).to(device)
+    dropout = 0
+    initrange = 0.1
+    model = TransformerLMSequential(ntokens, ninp, nhead, nhid, dropout, initrange).to(device)
 
     criterion = nn.CrossEntropyLoss()
-    lr = 5.0  # learning rate
+    lr = 1.0  # learning rate
     optimizer = torch.optim.SGD(model.parameters(), lr=lr)
 
     return model, criterion, optimizer
@@ -64,9 +137,12 @@ def train(train_data, model, criterion, optimizer, bptt, ntokens):
         data, targets = get_batch(train_data, i, bptt)
         optimizer.zero_grad()
         output = model(data)
+        output = output.to(targets.device)
         loss = criterion(output.view(-1, ntokens), targets)
         loss.backward()
-        torch.nn.utils.clip_grad_norm_(model.parameters(), 0.5)
+        torch.nn.utils.clip_grad_value_(model.parameters(), 0.05)
         optimizer.step()
 
         total_loss += loss.item()
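# Editor's note (not part of the commit): with Pipe, the forward output lives on
# the last partition's device, so it is moved to the targets' device before the
# loss is computed; that is what `output.to(targets.device)` above does.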
@@ -75,8 +151,9 @@ def train(train_data, model, criterion, optimizer, bptt, ntokens):
             cur_loss = total_loss / log_interval
             elapsed = time.time() - start_time
             print(
-                "| {:5d}/{:5d} batches | ms/batch {:5.2f} | loss {:5.2f} | ppl {:8.2f}".format(
-                    batch, len(train_data) // bptt, elapsed * 1000 / log_interval, cur_loss, math.exp(cur_loss),
+                "| {:5d}/{:5d} batches | ms/batch {:5.2f} | "
+                "loss {:5.2f} | ppl {:8.2f}".format(
+                    batch, len(train_data) // bptt, elapsed * 1000 / log_interval, cur_loss, math.exp(cur_loss)
                 )
             )
             total_loss = 0
@@ -90,6 +167,7 @@ def evaluate(eval_model, data_source, criterion, bptt, ntokens):
         for i in range(0, data_source.size(0) - 1, bptt):
             data, targets = get_batch(data_source, i, bptt)
             output = eval_model(data)
+            output = output.to(targets.device)
             output_flat = output.view(-1, ntokens)
             total_loss += len(data) * criterion(output_flat, targets).item()
     return total_loss / (len(data_source) - 1)
@@ -112,8 +190,9 @@ def benchmark_language_model(train_data, val_data, test_data, model, criterion,
         val_loss = evaluate(model, val_data, criterion, bptt, ntokens)
         print("-" * 89)
         print(
-            "| end of epoch {:1d} | time: {:5.2f}s | valid loss {:5.2f} | "
-            "valid ppl {:8.2f}".format(epoch, (time.time() - epoch_start_time), val_loss, math.exp(val_loss))
+            "| end of epoch {:1d} | time: {:5.2f}s | valid loss {:5.2f} ".format(
+                epoch, (time.time() - epoch_start_time), val_loss
+            )
         )
         print("-" * 89)
@@ -124,16 +203,54 @@ def benchmark_language_model(train_data, val_data, test_data, model, criterion,
     test_loss = evaluate(model, test_data, criterion, bptt, ntokens)
     print("=" * 89)
     print(
-        "| end of training | test loss {:5.2f} | test ppl {:8.2f}\n| time: {:5.2f}s | words: {:3d} | wps: {:5.2f}".format(
-            test_loss, math.exp(test_loss), elapsed_time, nwords, wps
+        "| end of training | test loss {:5.2f} \n| time: {:5.2f}s | words: {:3d} | wps: {:5.2f}".format(
+            test_loss, elapsed_time, nwords, wps
         )
     )
     print("=" * 89)
 
+    if len(model.balance) == 4:
+        # Assert that words per second is within 3 standard deviations of the
+        # average of five golden runs.
+        assert wps > 19779.5 - (3 * 167.81)
+
+        print("Peak allocated bytes on cuda:0: {:1d}".format(torch.cuda.memory_stats(0)["allocated_bytes.all.peak"]))
+        print("Peak allocated bytes on cuda:1: {:1d}".format(torch.cuda.memory_stats(1)["allocated_bytes.all.peak"]))
+        print("Peak allocated bytes on cuda:2: {:1d}".format(torch.cuda.memory_stats(2)["allocated_bytes.all.peak"]))
+        print("Peak allocated bytes on cuda:3: {:1d}".format(torch.cuda.memory_stats(3)["allocated_bytes.all.peak"]))
+
+        # Assert that peak memory usage on each GPU is within 10% of the golden
+        # run: the right-hand side is the golden-run peak in KB, times 1024 to
+        # convert to bytes, times 110%.
+        assert torch.cuda.memory_stats(0)["allocated_bytes.all.peak"] < 346094 * 1024 * 1.1
+        assert torch.cuda.memory_stats(1)["allocated_bytes.all.peak"] < 1251 * 1024 * 1.1
+        assert torch.cuda.memory_stats(2)["allocated_bytes.all.peak"] < 2595 * 1024 * 1.1
+        assert torch.cuda.memory_stats(3)["allocated_bytes.all.peak"] < 174784 * 1024 * 1.1
+
+        print("No regression detected")
+
+def generate_balance(num_devices, num_layers):
+    balance = []
+    layers_assigned = 0
+    for i in range(num_devices):
+        x = (num_layers - layers_assigned) / (num_devices - i)
+        if x.is_integer():
+            balance.append(int(x))
+            layers_assigned += x
+        else:
+            balance.append(math.ceil(x))
+            layers_assigned += math.ceil(x)
+    return balance
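# Editor's note (not part of the commit): worked examples of generate_balance,
# assuming the implementation above.
#
#   generate_balance(4, 4)  ->  [1, 1, 1, 1]   # even split
#   generate_balance(3, 8)  ->  [3, 3, 2]      # earlier devices take the ceiling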
 if __name__ == "__main__":
     assert torch.cuda.is_available()
     num_devices = torch.cuda.device_count()
     assert num_devices > 0
 
     torch.manual_seed(0)
     device = torch.device("cuda")
 
     ntokens, train_data, val_data, test_data = get_data(device)
     model, criterion, optimizer = make_model(device, ntokens)
-    benchmark_language_model(train_data, val_data, test_data, model, criterion, optimizer, ntokens)
+    balance = generate_balance(min(num_devices, 4), len(model))
+    p = pipe.Pipe(model, balance)
+    benchmark_language_model(train_data, val_data, test_data, p, criterion, optimizer, ntokens)
+    del p
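# Editor's note (not part of the commit): the benchmark runs directly on a CUDA
# host; with fewer than four GPUs it still executes (e.g. a single GPU yields
# balance=[4]), but the golden-run assertions above are skipped because
# len(model.balance) != 4.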
@@ -47,7 +47,7 @@ use_parentheses=True
 skip_glob = build/*,stubs/*
 
 # Don't split "import" and "from".
 force_sort_within_sections = true
 
-known_third_party = models,pytest,setuptools,torch,torchtext
+known_third_party = pytest,setuptools,torch,torchtext
# -----------------------------------------------------------------------------
# mypy
......