Commit 74181b08 authored by Jun Ru Anderson, committed by Mandeep Singh Baines

[feat] add Transformer gpipe benchmark

parent 0cd65242
[settings]
-known_third_party =models,pytest,setuptools,torch,torchtext
+known_third_party =pytest,setuptools,torch,torchtext
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.

import math

import torch
import torch.nn as nn


class TransformerModel(nn.Module):
    def __init__(self, ntoken, ninp=200, nhead=2, nhid=200, nlayers=2, dropout=0.5):
        super(TransformerModel, self).__init__()
        from torch.nn import TransformerEncoder, TransformerEncoderLayer

        self.model_type = "Transformer"
        self.src_mask = None
        self.pos_encoder = PositionalEncoding(ninp, dropout)
        encoder_layers = TransformerEncoderLayer(ninp, nhead, nhid, dropout)
        self.transformer_encoder = TransformerEncoder(encoder_layers, nlayers)
        self.encoder = nn.Embedding(ntoken, ninp)
        self.ninp = ninp
        self.decoder = nn.Linear(ninp, ntoken)
        self.init_weights()

    def _generate_square_subsequent_mask(self, sz):
        mask = (torch.triu(torch.ones(sz, sz)) == 1).transpose(0, 1)
        mask = mask.float().masked_fill(mask == 0, float("-inf")).masked_fill(mask == 1, float(0.0))
        return mask
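    # Illustrative, not in the original file: for sz=3 the mask produced above is
    #     [[0., -inf, -inf],
    #      [0.,   0., -inf],
    #      [0.,   0.,   0.]]
    # so position i can attend only to positions 0..i (causal attention).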
    def init_weights(self):
        initrange = 0.1
        self.encoder.weight.data.uniform_(-initrange, initrange)
        self.decoder.bias.data.zero_()
        self.decoder.weight.data.uniform_(-initrange, initrange)

    def forward(self, src):
        if self.src_mask is None or self.src_mask.size(0) != len(src):
            device = src.device
            mask = self._generate_square_subsequent_mask(len(src)).to(device)
            self.src_mask = mask

        src = self.encoder(src) * math.sqrt(self.ninp)
        src = self.pos_encoder(src)
        output = self.transformer_encoder(src, self.src_mask)
        output = self.decoder(output)
        return output


class PositionalEncoding(nn.Module):
    def __init__(self, d_model, dropout=0.1, max_len=5000):
        super(PositionalEncoding, self).__init__()
        self.dropout = nn.Dropout(p=dropout)

        pe = torch.zeros(max_len, d_model)
        position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model))
        pe[:, 0::2] = torch.sin(position * div_term)
        pe[:, 1::2] = torch.cos(position * div_term)
        pe = pe.unsqueeze(0).transpose(0, 1)
        # Registered as a buffer: it moves with the module but is not a learnable parameter.
        self.register_buffer("pe", pe)

    def forward(self, x):
        x = x + self.pe[: x.size(0), :]
        return self.dropout(x)
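
The buffer above implements the sinusoidal scheme from "Attention Is All You Need": PE(pos, 2i) = sin(pos / 10000^(2i/d_model)) and PE(pos, 2i+1) = cos(pos / 10000^(2i/d_model)). The div_term line evaluates 10000^(-2i/d_model) as exp(2i * (-log 10000 / d_model)), a numerically convenient form of the same quantity.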
@@ -3,12 +3,83 @@
import math
import time

-from models import transformerModel as transformer

import torch
import torch.nn as nn
import torchtext
from torchtext.data.utils import get_tokenizer

+import fairscale.nn.pipe.pipe as pipe
class EmbeddingLayer(nn.Embedding):
    def __init__(self, ntoken, ninp, initrange):
        super().__init__(ntoken, ninp)
        self.ninp = ninp
        self.weight.data.uniform_(-initrange, initrange)

    def forward(self, src):
        # Scale embeddings by sqrt(d_model), as in "Attention Is All You Need".
        return super().forward(src) * math.sqrt(self.ninp)


class PositionalEncodingLayer(nn.Module):
    def __init__(self, d_model, dropout=0.1, max_len=5000):
        super(PositionalEncodingLayer, self).__init__()
        self.dropout = nn.Dropout(p=dropout)

        pe = torch.zeros(max_len, d_model)
        position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model))
        pe[:, 0::2] = torch.sin(position * div_term)
        pe[:, 1::2] = torch.cos(position * div_term)
        pe = pe.unsqueeze(0).transpose(0, 1)
        self.register_buffer("pe", pe)

    def forward(self, x):
        x = x + self.pe[: x.size(0), :]
        return self.dropout(x)
class TransformerDecoderLayer(nn.TransformerEncoderLayer):
    """Although this class inherits from torch.nn.TransformerEncoderLayer,
    it functions as a decoder in this model."""

    def __init__(self, ninp, nhead, nhid, dropout):
        super().__init__(ninp, nhead, nhid, dropout)
        self.src_mask = None

    def _generate_square_subsequent_mask(self, sz):
        mask = (torch.triu(torch.ones(sz, sz)) == 1).transpose(0, 1)
        mask = mask.float().masked_fill(mask == 0, float("-inf")).masked_fill(mask == 1, float(0.0))
        return mask

    def forward(self, src):
        # Rebuild the cached causal mask only when the sequence length changes.
        if self.src_mask is None or self.src_mask.size(0) != len(src):
            device = src.device
            mask = self._generate_square_subsequent_mask(len(src)).to(device)
            self.src_mask = mask

        return super().forward(src, self.src_mask)
class LinearLayer(nn.Linear):
    def __init__(self, ninp, ntoken, initrange):
        super().__init__(ninp, ntoken)
        self.bias.data.zero_()
        self.weight.data.uniform_(-initrange, initrange)


class TransformerLMSequential(nn.Sequential):
    """A small language model based on the design of GPT-2, expressed as an
    nn.Sequential for compatibility with Pipe."""

    def __init__(self, ntokens, ninp, nhead, nhid, dropout, initrange):
        super(TransformerLMSequential, self).__init__(
            EmbeddingLayer(ntokens, ninp, initrange),
            PositionalEncodingLayer(ninp, dropout),
            TransformerDecoderLayer(ninp, nhead, nhid, dropout),
            LinearLayer(ninp, ntokens, initrange),
        )
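# Illustrative sketch, not part of the commit: Pipe partitions an nn.Sequential
# by slicing its children, so each of the four layers above can be placed on a
# different GPU. Assuming a (seq_len, batch) tensor of token ids:
#
#     lm = TransformerLMSequential(1000, ninp=50, nhead=2, nhid=50, dropout=0, initrange=0.1)
#     data = torch.randint(0, 1000, (35, 20))  # (seq_len, batch)
#     logits = lm(data)                        # (35, 20, 1000)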
def get_data(device):
    TEXT = torchtext.data.Field(
@@ -43,14 +114,16 @@ def get_batch(source, i, bptt):
def make_model(device, ntokens):
-    emsize = 50  # embedding dimension
+    ninp = 50  # embedding dimension
    nhid = 50  # the dimension of the feedforward network model in nn.TransformerEncoder
-    nlayers = 1  # the number of nn.TransformerEncoderLayer in nn.TransformerEncoder
    nhead = 2  # the number of heads in the multiheadattention models
-    dropout = 0.2  # the dropout value
-    model = transformer.TransformerModel(ntokens, emsize, nhead, nhid, nlayers, dropout).to(device)
+    dropout = 0
+    initrange = 0.1
+    model = TransformerLMSequential(ntokens, ninp, nhead, nhid, dropout, initrange).to(device)
    criterion = nn.CrossEntropyLoss()
-    lr = 5.0  # learning rate
+    lr = 1.0  # learning rate
    optimizer = torch.optim.SGD(model.parameters(), lr=lr)
    return model, criterion, optimizer
@@ -64,9 +137,12 @@ def train(train_data, model, criterion, optimizer, bptt, ntokens):
        data, targets = get_batch(train_data, i, bptt)
        optimizer.zero_grad()
        output = model(data)
+        output = output.to(targets.device)  # Pipe leaves the output on the last device
        loss = criterion(output.view(-1, ntokens), targets)
        loss.backward()
-        torch.nn.utils.clip_grad_norm_(model.parameters(), 0.5)
+        torch.nn.utils.clip_grad_value_(model.parameters(), 0.05)
        optimizer.step()

        total_loss += loss.item()
@@ -75,8 +151,9 @@ def train(train_data, model, criterion, optimizer, bptt, ntokens):
            cur_loss = total_loss / log_interval
            elapsed = time.time() - start_time
            print(
                "| {:5d}/{:5d} batches | ms/batch {:5.2f} | "
                "loss {:5.2f} | ppl {:8.2f}".format(
                    batch, len(train_data) // bptt, elapsed * 1000 / log_interval, cur_loss, math.exp(cur_loss)
                )
            )
            total_loss = 0
@@ -90,6 +167,7 @@ def evaluate(eval_model, data_source, criterion, bptt, ntokens):
        for i in range(0, data_source.size(0) - 1, bptt):
            data, targets = get_batch(data_source, i, bptt)
            output = eval_model(data)
+            output = output.to(targets.device)
            output_flat = output.view(-1, ntokens)
            total_loss += len(data) * criterion(output_flat, targets).item()
    return total_loss / (len(data_source) - 1)
@@ -112,8 +190,9 @@ def benchmark_language_model(train_data, val_data, test_data, model, criterion,
        val_loss = evaluate(model, val_data, criterion, bptt, ntokens)
        print("-" * 89)
-        print(
-            "| end of epoch {:1d} | time: {:5.2f}s | valid loss {:5.2f} | "
-            "valid ppl {:8.2f}".format(epoch, (time.time() - epoch_start_time), val_loss, math.exp(val_loss))
-        )
+        print(
+            "| end of epoch {:1d} | time: {:5.2f}s | valid loss {:5.2f} ".format(
+                epoch, (time.time() - epoch_start_time), val_loss
+            )
+        )
        print("-" * 89)
@@ -124,16 +203,54 @@ def benchmark_language_model(train_data, val_data, test_data, model, criterion,
    test_loss = evaluate(model, test_data, criterion, bptt, ntokens)
    print("=" * 89)
-    print(
-        "| end of training | test loss {:5.2f} | test ppl {:8.2f}\n| time: {:5.2f}s | words: {:3d} | wps: {:5.2f}".format(
-            test_loss, math.exp(test_loss), elapsed_time, nwords, wps
-        )
-    )
+    print(
+        "| end of training | test loss {:5.2f} \n| time: {:5.2f}s | words: {:3d} | wps: {:5.2f}".format(
+            test_loss, elapsed_time, nwords, wps
+        )
+    )
    print("=" * 89)
    if len(model.balance) == 4:
        # Assert that words per second is within 3 standard deviations of the
        # average of five golden runs.
        assert wps > 19779.5 - (3 * 167.81)

        print("Peak allocated bytes on cuda:0: {:1d}".format(torch.cuda.memory_stats(0)["allocated_bytes.all.peak"]))
        print("Peak allocated bytes on cuda:1: {:1d}".format(torch.cuda.memory_stats(1)["allocated_bytes.all.peak"]))
        print("Peak allocated bytes on cuda:2: {:1d}".format(torch.cuda.memory_stats(2)["allocated_bytes.all.peak"]))
        print("Peak allocated bytes on cuda:3: {:1d}".format(torch.cuda.memory_stats(3)["allocated_bytes.all.peak"]))

        # Assert that memory usage on each GPU is within 10% of the golden run.
        # The right-hand side is the golden-run peak in KB, times 1024 (KB to
        # bytes), times 110%.
        assert torch.cuda.memory_stats(0)["allocated_bytes.all.peak"] < 346094 * 1024 * 1.1
        assert torch.cuda.memory_stats(1)["allocated_bytes.all.peak"] < 1251 * 1024 * 1.1
        assert torch.cuda.memory_stats(2)["allocated_bytes.all.peak"] < 2595 * 1024 * 1.1
        assert torch.cuda.memory_stats(3)["allocated_bytes.all.peak"] < 174784 * 1024 * 1.1
        print("No regression detected")
def generate_balance(num_devices, num_layers):
    """Spread num_layers across num_devices as evenly as possible,
    assigning any remainder to the earlier devices."""
    balance = []
    layers_assigned = 0
    for i in range(num_devices):
        x = (num_layers - layers_assigned) / (num_devices - i)
        if x.is_integer():
            balance.append(int(x))
            layers_assigned += x
        else:
            balance.append(math.ceil(x))
            layers_assigned += math.ceil(x)
    return balance
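# Illustrative, not part of the commit:
#     generate_balance(4, 4)  # -> [1, 1, 1, 1]: one layer per device
#     generate_balance(3, 4)  # -> [2, 1, 1]: the remainder is front-loaded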
if __name__ == "__main__": if __name__ == "__main__":
assert torch.cuda.is_available() num_devices = torch.cuda.device_count()
assert num_devices > 0
torch.manual_seed(0)
device = torch.device("cuda") device = torch.device("cuda")
ntokens, train_data, val_data, test_data = get_data(device) ntokens, train_data, val_data, test_data = get_data(device)
model, criterion, optimizer = make_model(device, ntokens) model, criterion, optimizer = make_model(device, ntokens)
benchmark_language_model(train_data, val_data, test_data, model, criterion, optimizer, ntokens) balance = generate_balance(min(num_devices, 4), len(model))
p = pipe.Pipe(model, balance)
benchmark_language_model(train_data, val_data, test_data, p, criterion, optimizer, ntokens)
del p
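
A note on the Pipe call above, hedged since it relies on fairscale's torchgpipe-derived API rather than anything stated in this diff: Pipe(model, balance) slices the nn.Sequential so that the first balance[0] child modules sit on cuda:0, the next balance[1] on cuda:1, and so on. The forward output therefore lives on the last device in the pipeline, which is why train() and evaluate() move it back to targets.device before computing the loss. Micro-batch pipelining can be tuned with the optional chunks argument, e.g. pipe.Pipe(model, balance, chunks=8).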
@@ -47,7 +47,7 @@ use_parentheses=True
skip_glob = build/*,stubs/*

# Don't split "import" and "from".
force_sort_within_sections = true
-known_third_party = models,pytest,setuptools,torch,torchtext
+known_third_party = pytest,setuptools,torch,torchtext
# -----------------------------------------------------------------------------
# mypy