Unverified Commit 29dd22e6 authored by Lingfan Yu, committed by GitHub

[Model] Support Multi-GPU for Transformer model (#356)

* multi-process version of transformer

* lots of fix

* fix bugs and accum gradients for multiple batches

* many fixes

* minor

* upd

* set torch device

* fix bugs

* fix and minor

* comments and clean up

* uncomment viz code
parent 1c63dd4f
@@ -2,14 +2,14 @@ from .graph import *
from .fields import *
from .utils import prepare_dataset
import os
import numpy as np
import random
class ClassificationDataset:
class ClassificationDataset(object):
"Dataset class for classification task."
def __init__(self):
raise NotImplementedError
class TranslationDataset:
class TranslationDataset(object):
'''
Dataset class for translation task.
By default, the source language shares the same vocabulary with the target language.
@@ -22,17 +22,17 @@ class TranslationDataset:
vocab_path = os.path.join(path, vocab)
self.src = {}
self.tgt = {}
with open(os.path.join(path, train + '.' + exts[0]), 'r') as f:
with open(os.path.join(path, train + '.' + exts[0]), 'r', encoding='utf-8') as f:
self.src['train'] = f.readlines()
with open(os.path.join(path, train + '.' + exts[1]), 'r') as f:
with open(os.path.join(path, train + '.' + exts[1]), 'r', encoding='utf-8') as f:
self.tgt['train'] = f.readlines()
with open(os.path.join(path, valid + '.' + exts[0]), 'r') as f:
with open(os.path.join(path, valid + '.' + exts[0]), 'r', encoding='utf-8') as f:
self.src['valid'] = f.readlines()
with open(os.path.join(path, valid + '.' + exts[1]), 'r') as f:
with open(os.path.join(path, valid + '.' + exts[1]), 'r', encoding='utf-8') as f:
self.tgt['valid'] = f.readlines()
with open(os.path.join(path, test + '.' + exts[0]), 'r') as f:
with open(os.path.join(path, test + '.' + exts[0]), 'r', encoding='utf-8') as f:
self.src['test'] = f.readlines()
with open(os.path.join(path, test + '.' + exts[1]), 'r') as f:
with open(os.path.join(path, test + '.' + exts[1]), 'r', encoding='utf-8') as f:
self.tgt['test'] = f.readlines()
if not os.path.exists(vocab_path):
@@ -90,20 +90,30 @@ class TranslationDataset:
def eos_id(self):
return self.vocab[self.EOS_TOKEN]
def __call__(self, graph_pool, mode='train', batch_size=32, k=1, devices=['cpu']):
def __call__(self, graph_pool, mode='train', batch_size=32, k=1,
device='cpu', dev_rank=0, ndev=1):
'''
Create batched graphs corresponding to mini-batches of the dataset.
args:
graph_pool: a GraphPool object for accelerating.
mode: train/valid/test
batch_size: batch size
devices: ['cpu'] or a list of gpu ids.
k: beam size (only required for test)
device: str or torch.device
dev_rank: rank (id) of current device
ndev: number of devices
'''
dev_id, gs = 0, []
src_data, tgt_data = self.src[mode], self.tgt[mode]
n = len(src_data)
order = np.random.permutation(n) if mode == 'train' else range(n)
# make sure all devices have the same number of batches
n = n // ndev * ndev
# XXX: partition then shuffle may not be equivalent to shuffle then
# partition
order = list(range(dev_rank, n, ndev))
if mode == 'train':
random.shuffle(order)
src_buf, tgt_buf = [], []
for idx in order:
@@ -115,22 +125,16 @@ class TranslationDataset:
tgt_buf.append(tgt_sample)
if len(src_buf) == batch_size:
if mode == 'test':
assert len(devices) == 1 # we only allow single gpu for inference
yield graph_pool.beam(src_buf, self.sos_id, self.MAX_LENGTH, k, device=devices[0])
yield graph_pool.beam(src_buf, self.sos_id, self.MAX_LENGTH, k, device=device)
else:
gs.append(graph_pool(src_buf, tgt_buf, device=devices[dev_id]))
dev_id += 1
if dev_id == len(devices):
yield gs if len(devices) > 1 else gs[0]
dev_id, gs = 0, []
yield graph_pool(src_buf, tgt_buf, device=device)
src_buf, tgt_buf = [], []
if len(src_buf) != 0:
if mode == 'test':
yield graph_pool.beam(src_buf, self.sos_id, self.MAX_LENGTH, k, device=devices[0])
yield graph_pool.beam(src_buf, self.sos_id, self.MAX_LENGTH, k, device=device)
else:
gs.append(graph_pool(src_buf, tgt_buf, device=devices[dev_id]))
yield gs if len(devices) > 1 else gs[0]
yield graph_pool(src_buf, tgt_buf, device=device)
def get_sequence(self, batch):
"return a list of sequence from a list of index arrays"
......
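The rewritten __call__ shards the data round-robin by device rank, so every worker iterates over a disjoint slice of equal size. A minimal standalone sketch of that index logic (shard_indices is an illustrative helper, not part of TranslationDataset):

```python
# Minimal sketch of the round-robin sharding used by __call__ above
# (shard_indices is an illustrative helper, not part of the dataset class).
import random

def shard_indices(n, dev_rank, ndev, shuffle=True):
    """Return the sample indices assigned to device `dev_rank` out of `ndev`."""
    n = n // ndev * ndev                      # drop the tail so every device gets the same batch count
    order = list(range(dev_rank, n, ndev))    # rank 0 gets 0, ndev, 2*ndev, ...
    if shuffle:
        random.shuffle(order)                 # shuffle within the shard, not globally
    return order

# With n=10 and ndev=2: rank 0 -> [0, 2, 4, 6, 8], rank 1 -> [1, 3, 5, 7, 9] (before shuffling)
```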
@@ -16,7 +16,7 @@ class Vocab:
self.vocab_lst.append(self.pad_token)
if self.unk_token is not None:
self.vocab_lst.append(self.unk_token)
with open(path, 'r') as f:
with open(path, 'r', encoding='utf-8') as f:
for token in f.readlines():
token = token.strip()
self.vocab_lst.append(token)
......
import torch as T
import torch.nn as nn
import torch.nn.functional as F
import torch.distributed as dist
class LabelSmoothing(nn.Module):
"""
@@ -37,14 +38,37 @@ class LabelSmoothing(nn.Module):
class SimpleLossCompute(nn.Module):
eps=1e-8
def __init__(self, criterion, opt=None):
"""
opt is required during training
def __init__(self, criterion, grad_accum, opt=None):
"""Loss function and optimizer for single device
Parameters
----------
criterion: torch.nn.Module
criterion to compute loss
grad_accum: int
number of batches to accumulate gradients
opt: Optimizer
Model optimizer to use. If None, no backward pass or parameter update
is performed
"""
super(SimpleLossCompute, self).__init__()
self.criterion = criterion
self.opt = opt
self.reset()
self.acc_loss = 0
self.n_correct = 0
self.norm_term = 0
self.loss = 0
self.batch_count = 0
self.grad_accum = grad_accum
def __enter__(self):
self.batch_count = 0
def __exit__(self, type, value, traceback):
# if fewer than grad_accum batches were accumulated and there are gradients
# not yet applied, do one more step
if self.batch_count > 0:
self.step()
@property
def avg_loss(self):
@@ -54,32 +78,56 @@ class SimpleLossCompute(nn.Module):
def accuracy(self):
return (self.n_correct + self.eps) / (self.norm_term + self.eps)
def reset(self):
self.acc_loss = 0
self.n_correct = 0
self.norm_term = 0
def step(self):
self.opt.step()
self.opt.optimizer.zero_grad()
def backward_and_step(self):
self.loss.backward()
self.batch_count += 1
# accumulate self.grad_accum times then synchronize and update
if self.batch_count == self.grad_accum:
self.step()
self.batch_count = 0
def __call__(self, y_pred, y, norm):
y_pred = y_pred.contiguous().view(-1, y_pred.shape[-1])
y = y.contiguous().view(-1)
loss = self.criterion(
y_pred, y
) / norm
self.loss = self.criterion(y_pred, y) / norm
if self.opt is not None:
loss.backward()
self.opt.step()
self.opt.optimizer.zero_grad()
self.backward_and_step()
self.n_correct += ((y_pred.max(dim=-1)[1] == y) & (y != self.criterion.padding_idx)).sum().item()
self.acc_loss += loss.item() * norm
self.acc_loss += self.loss.item() * norm
self.norm_term += norm
return loss.item() * norm
return self.loss.item() * norm
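With the context-manager protocol above, a training loop drives SimpleLossCompute roughly as follows; model, train_iter, criterion and model_opt are placeholders standing in for the objects built in the training script:

```python
# Illustrative driver loop for SimpleLossCompute (placeholder names, not the example's real script).
loss_compute = SimpleLossCompute(criterion, grad_accum=4, opt=model_opt)
with loss_compute:                                   # __enter__ resets the batch counter
    for g in train_iter:
        output = model(g)                            # forward pass on one batched graph
        loss_compute(output, g.tgt_y, g.n_tokens)    # backward every batch, optimizer step every 4th batch
# __exit__ applies any gradients left over from an incomplete accumulation window
```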
class MultiGPULossCompute(SimpleLossCompute):
def __init__(self, criterion, devices, opt=None, chunk_size=5):
self.criterion = criterion
self.opt = opt
self.devices = devices
self.chunk_size = chunk_size
def __init__(self, criterion, ndev, grad_accum, model, opt=None):
"""Loss function and optimizer for multiple devices
Parameters
----------
criterion: torch.nn.Module
criterion to compute loss
ndev: int
number of devices used
grad_accum: int
number of batches to accumulate gradients
model: torch.nn.Module
model to optimize (needed to iterate over and synchronize all parameters)
opt: Optimizer
Model optimizer to use. If None, no backward pass or parameter update
is performed
"""
super(MultiGPULossCompute, self).__init__(criterion, grad_accum, opt=opt)
self.ndev = ndev
self.model = model
def __call__(self, y_preds, ys, norms):
pass
def step(self):
# multi-gpu: synchronize (average) gradients across devices
for param in self.model.parameters():
if param.requires_grad and param.grad is not None:
dist.all_reduce(param.grad.data, op=dist.ReduceOp.SUM)
param.grad.data /= self.ndev
self.opt.step()
self.opt.optimizer.zero_grad()
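step() sums every gradient across workers with all_reduce and divides by ndev, i.e. it averages the per-device gradients. A tiny single-process check of that arithmetic (no process group needed):

```python
# Sanity check (illustrative): all_reduce(SUM) followed by division by the device
# count is equivalent to averaging the per-device gradients.
import torch

per_device_grads = [torch.tensor([1.0, 2.0]), torch.tensor([3.0, 4.0])]  # pretend: one grad per GPU
summed_then_divided = sum(per_device_grads) / len(per_device_grads)      # what step() computes
assert torch.equal(summed_then_divided, torch.stack(per_device_grads).mean(dim=0))
```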
@@ -12,7 +12,7 @@ class MultiHeadAttention(nn.Module):
self.h = h
# W_q, W_k, W_v, W_o
self.linears = clones(
nn.Linear(dim_model, dim_model), 4
nn.Linear(dim_model, dim_model, bias=False), 4
)
def get(self, x, fields='qkv'):
......
@@ -46,10 +46,10 @@ class Decoder(nn.Module):
layer = self.layers[i]
def func(nodes):
x = nodes.data['x']
if fields == 'kv':
norm_x = x # In enc-dec attention, x has already been normalized.
norm_x = layer.sublayer[l].norm(x) if fields.startswith('q') else x
if fields != 'qkv':
return layer.src_attn.get(norm_x, fields)
else:
norm_x = layer.sublayer[l].norm(x)
return layer.self_attn.get(norm_x, fields)
return func
@@ -64,8 +64,6 @@ class Decoder(nn.Module):
return {'x': x if i < self.N - 1 else self.norm(x)}
return func
lock = threading.Lock()
class Transformer(nn.Module):
def __init__(self, encoder, decoder, src_embed, tgt_embed, pos_enc, generator, h, d_k):
super(Transformer, self).__init__()
@@ -124,9 +122,10 @@ class Transformer(nn.Module):
self.update_graph(g, edges, [(pre_q, nodes), (pre_kv, nodes_e)], [(post_func, nodes)])
# visualize attention
with lock:
"""
if self.att_weight_map is None:
self._register_att_map(g, graph.nid_arr['enc'][VIZ_IDX], graph.nid_arr['dec'][VIZ_IDX])
"""
return self.generator(g.ndata['x'][nids['dec']])
......
# Mostly the same as PyTorch's implementation
import threading
import torch
def get_a_var(obj):
if isinstance(obj, torch.Tensor):
return obj
if isinstance(obj, list) or isinstance(obj, tuple):
for result in map(get_a_var, obj):
if isinstance(result, torch.Tensor):
return result
if isinstance(obj, dict):
for result in map(get_a_var, obj.items()):
if isinstance(result, torch.Tensor):
return result
return None
def parallel_apply(modules, inputs):
assert len(modules) == len(inputs)
lock = threading.Lock()
results = {}
grad_enabled = torch.is_grad_enabled()
def _worker(i, module, input):
torch.set_grad_enabled(grad_enabled)
try:
#with torch.cuda.device(device):
output = module(input)
with lock:
results[i] = output
except Exception as e:
with lock:
results[i] = e
if len(modules) > 1:
threads = [threading.Thread(target=_worker,
args=(i, module, input))
for i, (module, input) in
enumerate(zip(modules, inputs))]
for thread in threads:
thread.start()
for thread in threads:
thread.join()
else:
_worker(0, modules[0], inputs[0])
outputs = []
for i in range(len(inputs)):
output = results[i]
if isinstance(output, Exception):
raise output
outputs.append(output)
return outputs
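A hedged usage sketch for parallel_apply; the modules and tensor shapes below are arbitrary placeholders:

```python
# Illustrative use of parallel_apply: run several modules on their respective
# inputs, one worker thread per (module, input) pair.
import torch
import torch.nn as nn

modules = [nn.Linear(4, 2), nn.Linear(4, 2)]
inputs = [torch.randn(3, 4), torch.randn(3, 4)]
outputs = parallel_apply(modules, inputs)     # results come back in input order
print([o.shape for o in outputs])             # [torch.Size([3, 2]), torch.Size([3, 2])]
```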
"""
In the current version, we use multi30k as the default training and validation set.
Multi-GPU support is required to train the model on WMT14.
"""
from modules import *
from parallel import *
from loss import *
from optims import *
from dataset import *
from modules.config import *
from modules.viz import *
from tqdm import tqdm
#from modules.viz import *
import numpy as np
import argparse
import torch
from functools import partial
import torch.distributed as dist
def run_epoch(data_iter, model, loss_compute, is_train=True):
def run_epoch(epoch, data_iter, dev_rank, ndev, model, loss_compute, is_train=True):
universal = isinstance(model, UTransformer)
for i, g in tqdm(enumerate(data_iter)):
with loss_compute:
for i, g in enumerate(data_iter):
with T.set_grad_enabled(is_train):
if isinstance(model, list):
model = model[:len(gs)]
output = parallel_apply(model, g)
tgt_y = [g.tgt_y for g in gs]
n_tokens = [g.n_tokens for g in gs]
else:
if universal:
output, loss_act = model(g)
if is_train: loss_act.backward(retain_graph=True)
@@ -36,64 +28,128 @@ def run_epoch(data_iter, model, loss_compute, is_train=True):
for step in range(1, model.MAX_DEPTH + 1):
print("nodes entering step {}: {:.2f}%".format(step, (1.0 * model.stat[step] / model.stat[0])))
model.reset_stat()
print('average loss: {}'.format(loss_compute.avg_loss))
print('accuracy: {}'.format(loss_compute.accuracy))
print('Epoch {} {}: Dev {} average loss: {}, accuracy {}'.format(
epoch, "Training" if is_train else "Evaluating",
dev_rank, loss_compute.avg_loss, loss_compute.accuracy))
if __name__ == '__main__':
if not os.path.exists('checkpoints'):
os.makedirs('checkpoints')
np.random.seed(1111)
argparser = argparse.ArgumentParser('training translation model')
argparser.add_argument('--gpus', default='-1', type=str, help='gpu id')
argparser.add_argument('--N', default=6, type=int, help='enc/dec layers')
argparser.add_argument('--dataset', default='multi30k', help='dataset')
argparser.add_argument('--batch', default=128, type=int, help='batch size')
argparser.add_argument('--viz', action='store_true', help='visualize attention')
argparser.add_argument('--universal', action='store_true', help='use universal transformer')
args = argparser.parse_args()
args_filter = ['batch', 'gpus', 'viz']
exp_setting = '-'.join('{}'.format(v) for k, v in vars(args).items() if k not in args_filter)
devices = ['cpu'] if args.gpus == '-1' else [int(gpu_id) for gpu_id in args.gpus.split(',')]
def run(dev_id, args):
dist_init_method = 'tcp://{master_ip}:{master_port}'.format(
master_ip=args.master_ip, master_port=args.master_port)
world_size = args.ngpu
torch.distributed.init_process_group(backend="nccl",
init_method=dist_init_method,
world_size=world_size,
rank=dev_id)
gpu_rank = torch.distributed.get_rank()
assert gpu_rank == dev_id
main(dev_id, args)
def main(dev_id, args):
if dev_id == -1:
device = torch.device('cpu')
else:
device = torch.device('cuda:{}'.format(dev_id))
# Set current device (only meaningful for GPU runs)
if dev_id >= 0:
torch.cuda.set_device(dev_id)
# Prepare dataset
dataset = get_dataset(args.dataset)
V = dataset.vocab_size
criterion = LabelSmoothing(V, padding_idx=dataset.pad_id, smoothing=0.1)
dim_model = 512
# Build graph pool
graph_pool = GraphPool()
model = make_model(V, V, N=args.N, dim_model=dim_model, universal=args.universal)
# Create model
model = make_model(V, V, N=args.N, dim_model=dim_model,
universal=args.universal)
# Sharing weights between Encoder & Decoder
model.src_embed.lut.weight = model.tgt_embed.lut.weight
model.generator.proj.weight = model.tgt_embed.lut.weight
# Move model to corresponding device
model, criterion = model.to(device), criterion.to(device)
# Loss function
if args.ngpu > 1:
dev_rank = dev_id # current device id
ndev = args.ngpu # number of devices (including cpu)
loss_compute = partial(MultiGPULossCompute, criterion, args.ngpu,
args.grad_accum, model)
else: # cpu or single gpu case
dev_rank = 0
ndev = 1
loss_compute = partial(SimpleLossCompute, criterion, args.grad_accum)
if ndev > 1:
for param in model.parameters():
dist.all_reduce(param.data, op=dist.ReduceOp.SUM)
param.data /= ndev
model, criterion = model.to(devices[0]), criterion.to(devices[0])
model_opt = NoamOpt(dim_model, 1, 400,
T.optim.Adam(model.parameters(), lr=1e-3, betas=(0.9, 0.98), eps=1e-9))
if len(devices) > 1:
model, criterion = map(nn.parallel.replicate, [model, criterion], [devices, devices])
loss_compute = SimpleLossCompute if len(devices) == 1 else MultiGPULossCompute
# Optimizer
model_opt = NoamOpt(dim_model, 1, 4000,
T.optim.Adam(model.parameters(), lr=1e-3,
betas=(0.9, 0.98), eps=1e-9))
# Train & evaluate
for epoch in range(100):
train_iter = dataset(graph_pool, mode='train', batch_size=args.batch, devices=devices)
valid_iter = dataset(graph_pool, mode='valid', batch_size=args.batch, devices=devices)
print('Epoch: {} Training...'.format(epoch))
start = time.time()
train_iter = dataset(graph_pool, mode='train', batch_size=args.batch,
device=device, dev_rank=dev_rank, ndev=ndev)
model.train(True)
run_epoch(train_iter, model,
loss_compute(criterion, model_opt), is_train=True)
print('Epoch: {} Evaluating...'.format(epoch))
run_epoch(epoch, train_iter, dev_rank, ndev, model,
loss_compute(opt=model_opt), is_train=True)
if dev_rank == 0:
model.att_weight_map = None
model.eval()
run_epoch(valid_iter, model,
loss_compute(criterion, None), is_train=False)
valid_iter = dataset(graph_pool, mode='valid', batch_size=args.batch,
device=device, dev_rank=dev_rank, ndev=1)
run_epoch(epoch, valid_iter, dev_rank, 1, model,
loss_compute(opt=None), is_train=False)
end = time.time()
print("epoch time: {}".format(end - start))
# Visualize attention
if args.viz:
src_seq = dataset.get_seq_by_id(VIZ_IDX, mode='valid', field='src')
tgt_seq = dataset.get_seq_by_id(VIZ_IDX, mode='valid', field='tgt')[:-1]
draw_atts(model.att_weight_map, src_seq, tgt_seq, exp_setting, 'epoch_{}'.format(epoch))
print('----------------------------------')
args_filter = ['batch', 'gpus', 'viz', 'master_ip', 'master_port', 'grad_accum', 'ngpu']
exp_setting = '-'.join('{}'.format(v) for k, v in vars(args).items() if k not in args_filter)
with open('checkpoints/{}-{}.pkl'.format(exp_setting, epoch), 'wb') as f:
th.save(model.state_dict(), f)
torch.save(model.state_dict(), f)
if __name__ == '__main__':
if not os.path.exists('checkpoints'):
os.makedirs('checkpoints')
np.random.seed(1111)
argparser = argparse.ArgumentParser('training translation model')
argparser.add_argument('--gpus', default='-1', type=str, help='gpu id')
argparser.add_argument('--N', default=6, type=int, help='enc/dec layers')
argparser.add_argument('--dataset', default='multi30k', help='dataset')
argparser.add_argument('--batch', default=128, type=int, help='batch size')
argparser.add_argument('--viz', action='store_true',
help='visualize attention')
argparser.add_argument('--universal', action='store_true',
help='use universal transformer')
argparser.add_argument('--master-ip', type=str, default='127.0.0.1',
help='master ip address')
argparser.add_argument('--master-port', type=str, default='12345',
help='master port')
argparser.add_argument('--grad-accum', type=int, default=1,
help='accumulate gradients for this many times '
'then update weights')
args = argparser.parse_args()
print(args)
devices = list(map(int, args.gpus.split(',')))
if len(devices) == 1:
args.ngpu = 0 if devices[0] < 0 else 1
main(devices[0], args)
else:
args.ngpu = len(devices)
mp = torch.multiprocessing.get_context('spawn')
procs = []
for dev_id in devices:
procs.append(mp.Process(target=run, args=(dev_id, args),
daemon=True))
procs[-1].start()
for p in procs:
p.join()
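A note on the optimizer hunk above, where the warm-up was raised from 400 to 4000 steps: NoamOpt is defined in optims.py, which this diff does not show. Assuming it implements the standard schedule from the original Transformer paper, the learning rate produced by NoamOpt(dim_model, 1, 4000, ...) behaves roughly like this sketch:

```python
# Sketch of the assumed Noam learning-rate schedule (not the actual code in optims.py).
def noam_lr(step, dim_model=512, factor=1.0, warmup=4000):
    """Linear warm-up for `warmup` steps, then decay proportional to 1/sqrt(step). Steps are 1-based."""
    return factor * dim_model ** -0.5 * min(step ** -0.5, step * warmup ** -1.5)

# The rate peaks around step 4000 and halves by step 16000, e.g.:
# noam_lr(4000) ~ 7.0e-4, noam_lr(16000) ~ 3.5e-4
```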