Commit b0990e4b authored by Rick Ho

Merge branch 'master' into laekov/accfix

parents 89de2153 1cfc5462
transformer-xl/data
transformer-xl/LM-TFM-enwik8
data
This diff is collapsed.
@@ -25,6 +25,7 @@ if [[ $1 == 'train' ]]; then
        --batch_size 22 \
        --multi_gpu \
        --gpu0_bsz 4 \
+       --moe --moe-num-expert 64 --moe-top-k 2 \
        ${@:2}
elif [[ $1 == 'eval' ]]; then
    echo 'Run evaluation...'
...
@@ -4,7 +4,6 @@ import time
import math
import os, sys
import itertools
-import pathlib
import numpy as np

@@ -17,31 +16,6 @@ from mem_transformer import MemTransformerLM
from utils.exp_utils import create_exp_dir
from utils.data_parallel import BalancedDataParallel

-class AverageMeter(object):
-    """Computes and stores the average and current value.
-       Examples::
-        >>> # Initialize a meter to record loss
-        >>> losses = AverageMeter()
-        >>> # Update meter after every minibatch update
-        >>> losses.update(loss_value, batch_size)
-    """
-    def __init__(self):
-        self.reset()
-
-    def reset(self):
-        self.val = 0
-        self.avg = 0
-        self.sum = 0
-        self.count = 0
-
-    def update(self, val, n=1):
-        self.val = val
-        self.sum += val * n
-        self.count += n
-        self.avg = self.sum / self.count
-
parser = argparse.ArgumentParser(description='PyTorch Transformer Language Model')
parser.add_argument('--data', type=str, default='../data/wikitext-103',
                    help='location of the data corpus')

@@ -167,8 +141,15 @@ parser.add_argument('--static-loss-scale', type=float, default=1,
parser.add_argument('--dynamic-loss-scale', action='store_true',
                    help='Use dynamic loss scaling. If supplied, this argument'
                    ' supersedes --static-loss-scale.')
+parser.add_argument('--moe', action='store_true',
+                    help='replace position-wise ffn with moe position-wise ffn')
+parser.add_argument('--moe-num-expert', type=int, default=64,
+                    help='number of experts in MoE')
+parser.add_argument('--moe-top-k', type=int, default=2,
+                    help='top_k experts in hard gate of moe')
args = parser.parse_args()
args.tied = not args.not_tied
+assert args.moe_num_expert >= args.moe_top_k, "must have moe-num-expert >= moe-top_k"
if args.d_embed < 0:
    args.d_embed = args.d_model

@@ -305,7 +286,8 @@ else:
        tie_projs=tie_projs, pre_lnorm=args.pre_lnorm, tgt_len=args.tgt_len,
        ext_len=args.ext_len, mem_len=args.mem_len, cutoffs=cutoffs,
        same_length=args.same_length, attn_type=args.attn_type,
-       clamp_len=args.clamp_len, sample_softmax=args.sample_softmax)
+       clamp_len=args.clamp_len, sample_softmax=args.sample_softmax,
+       moe=args.moe, moe_num_expert=args.moe_num_expert, moe_top_k=args.moe_top_k)
    model.apply(weights_init)
    model.word_emb.apply(weights_init)  # ensure embedding init is not overridden by out_layer in case of weight sharing
args.n_all_param = sum([p.nelement() for p in model.parameters()])
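The new MoE flags are passed straight into the MemTransformerLM constructor above; the corresponding changes inside mem_transformer.py are collapsed in this view. A minimal sketch of the kind of wiring this implies — swapping the position-wise FFN for an FMoETransformerMLP when --moe is set — where the helper name and the plain-FFN fallback are illustrative assumptions, not a transcription of the collapsed diff:

import torch.nn as nn
from fmoe.transformer import FMoETransformerMLP

def build_pos_ff(d_model, d_inner, dropout, moe=False,
                 moe_num_expert=64, moe_top_k=2):
    # Hypothetical helper: choose the position-wise FFN for a decoder layer.
    if moe:
        # Mixture-of-experts FFN with top-k gating over moe_num_expert experts.
        return FMoETransformerMLP(num_expert=moe_num_expert,
                                  d_model=d_model,
                                  d_hidden=d_inner,
                                  top_k=moe_top_k)
    # Vanilla Transformer-XL position-wise FFN.
    return nn.Sequential(
        nn.Linear(d_model, d_inner), nn.ReLU(inplace=True),
        nn.Dropout(dropout),
        nn.Linear(d_inner, d_model), nn.Dropout(dropout),
    )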
@@ -412,9 +394,6 @@ logging('#non emb params = {}'.format(args.n_nonemb_param))
def evaluate(eval_iter):
    # Turn on evaluation mode which disables dropout.
    model.eval()
-    # avg_nnzs = None
-    # act_hist = None
-    # co_act_hist = None
    # If the model does not use memory at all, make the ext_len longer.
    # Otherwise, make the mem_len longer and keep the ext_len the same.

@@ -434,33 +413,15 @@ def evaluate(eval_iter):
                break
            ret = model(data, target, *mems)
            loss, mems = ret[0], ret[1:]
-            # relu_outs, loss, mems = ret[0], ret[1], ret[2:]
            loss = loss.mean()
            total_loss += seq_len * loss.float().item()
            total_len += seq_len
-            # acts = [(relu_out > 0).float().cpu() for relu_out in relu_outs]
-            # if avg_nnzs is None:
-            #     n_layer = len(acts)
-            #     avg_nnzs = [AverageMeter() for i in range(n_layer)]
-            #     d_inner = acts[0].size(-1)
-            #     act_hist = [torch.zeros(d_inner) for i in range(n_layer)]
-            #     co_act_hist = [torch.zeros(d_inner, d_inner) for i in range(n_layer)]
-            # for i, act in enumerate(acts):
-            #     nnz = act.sum().item() / act.numel()
-            #     avg_nnzs[i].update(nnz)
-            #     act_hist[i] += torch.sum(act, dim=[0, 1])
-            #     co_act = torch.einsum("ija,ijb->ab", (act, act))
-            #     co_act_hist[i] += co_act
    # Switch back to the training mode
    model.reset_length(args.tgt_len, args.ext_len, args.mem_len)
    model.train()
    return total_loss / total_len
-    # return total_loss / total_len, avg_nnzs, act_hist, co_act_hist

def train():
@@ -471,11 +432,6 @@ def train():
        mems = [tuple() for _ in range(args.batch_chunk)]
    else:
        mems = tuple()
-    # avg_nnzs = None
-    # act_hist = None
-    # co_act_hist = None
    train_iter = tr_iter.get_varlen_iter() if args.varlen else tr_iter
    for batch, (data, target, seq_len) in enumerate(train_iter):
        model.zero_grad()

@@ -487,7 +443,6 @@ def train():
                target_i = target_chunks[i].contiguous()
                ret = para_model(data_i, target_i, *mems[i])
                loss, mems[i] = ret[0], ret[1:]
-                # relu_outs, loss, mems[i] = ret[0], ret[1], ret[2:]
                loss = loss.float().mean().type_as(loss) / args.batch_chunk
                if args.fp16:
                    optimizer.backward(loss)

@@ -497,28 +452,12 @@ def train():
        else:
            ret = para_model(data, target, *mems)
            loss, mems = ret[0], ret[1:]
-            # relu_outs, loss, mems = ret[0], ret[1], ret[2:]
            loss = loss.float().mean().type_as(loss)
            if args.fp16:
                optimizer.backward(loss)
            else:
                loss.backward()
        train_loss += loss.float().item()
-        # acts = [(relu_out > 0).float().cpu() for relu_out in relu_outs]
-        # # nnzs = [act.sum().item() / act.numel() for act in acts]
-        # if avg_nnzs is None:
-        #     n_layer = len(acts)
-        #     avg_nnzs = [AverageMeter() for i in range(n_layer)]
-        #     d_inner = acts[0].size(-1)
-        #     act_hist = [torch.zeros(d_inner) for i in range(n_layer)]
-        #     co_act_hist = [torch.zeros(d_inner, d_inner) for i in range(n_layer)]
-        # for i, act in enumerate(acts):
-        #     nnz = act.sum().item() / act.numel()
-        #     avg_nnzs[i].update(nnz)
-        #     act_hist[i] += torch.sum(act, dim=[0, 1])
-        #     co_act = torch.einsum("ija,ijb->ab", (act, act))
-        #     co_act_hist[i] += co_act
        if args.fp16:
            optimizer.clip_master_grads(args.clip)
@@ -557,39 +496,12 @@ def train():
                log_str += ' | bpc {:9.5f}'.format(cur_loss / math.log(2))
            else:
                log_str += ' | ppl {:9.3f}'.format(math.exp(cur_loss))
-            # final_avg_nnzs = [avg_nnzs[i].avg for i in range(len(avg_nnzs))]
-            # log_str += ' | avgnnz {:5.2f} | maxnnz {:5.2f}'.format(
-            #     sum(final_avg_nnzs)/len(final_avg_nnzs)*100,
-            #     max(final_avg_nnzs)*100,
-            # )
            logging(log_str)
-            # co_act_dir = pathlib.Path(logging.keywords['log_path']).parent.joinpath("co_act")
-            # co_act_dir.mkdir(parents=True, exist_ok=True)
-            # co_act_path = co_act_dir.joinpath('epoch_%d_train_step_%d.pt' % (epoch, train_step))
-            # torch.save(co_act_hist, co_act_path)
-            # for i in range(len(avg_nnzs)):
-            #     avg_nnzs[i].reset()
-            #     act_hist[i] /= act_hist[i].sum()
-            #     prob, index = torch.topk(act_hist[i], min(1024, act_hist[i].size(-1)))
-            #     log_str = '| layer {:2d} | top 64 prob {:3.2f} | top 128 prob {:3.2f} | top 256 prob {:3.2f} | top 512 prob {:3.2f} | top 1024 prob {:3.2f}'.format(
-            #         i+1,
-            #         prob[:64].sum().item(),
-            #         prob[:128].sum().item(),
-            #         prob[:256].sum().item(),
-            #         prob[:512].sum().item(),
-            #         prob[:1024].sum().item()
-            #     )
-            #     logging(log_str)
-            #     act_hist[i] = 0.
-            #     co_act_hist[i] = 0.
            train_loss = 0
            log_start_time = time.time()

        if train_step % args.eval_interval == 0:
            val_loss = evaluate(va_iter)
-            # val_loss, eval_avg_nnzs, eval_act_hist, eval_co_act_hist = evaluate(va_iter)
            logging('-' * 100)
            log_str = '| Eval {:3d} at step {:>8d} | time: {:5.2f}s ' \
                      '| valid loss {:5.2f}'.format(

@@ -599,11 +511,6 @@ def train():
                log_str += ' | bpc {:9.5f}'.format(val_loss / math.log(2))
            else:
                log_str += ' | valid ppl {:9.3f}'.format(math.exp(val_loss))
-            # final_eval_avg_nnzs = [eval_avg_nnzs[i].avg for i in range(len(eval_avg_nnzs))]
-            # log_str += ' | avgnnz {:5.2f} | maxnnz {:5.2f}'.format(
-            #     sum(final_eval_avg_nnzs)/len(final_eval_avg_nnzs)*100,
-            #     max(final_eval_avg_nnzs)*100
-            # )
            logging(log_str)
            logging('-' * 100)
            # Save the model if the validation loss is the best we've seen so far.
@@ -653,7 +560,6 @@ para_model = model.to(device)

# Run on test data.
test_loss = evaluate(te_iter)
-# test_loss, eval_avg_nnzs, eval_act_hist, eval_co_act_hist = evaluate(te_iter)
logging('=' * 100)
if args.dataset in ['enwik8', 'text8']:
    logging('| End of training | test loss {:5.2f} | test bpc {:9.5f}'.format(

@@ -661,11 +567,4 @@ if args.dataset in ['enwik8', 'text8']:
else:
    logging('| End of training | test loss {:5.2f} | test ppl {:9.3f}'.format(
        test_loss, math.exp(test_loss)))
-# final_eval_avg_nnzs = [eval_avg_nnzs[i].avg for i in range(len(eval_avg_nnzs))]
-# log_str = ' | avgnnz {:5.2f} | maxnnz {:5.2f}'.format(
-#     sum(final_eval_avg_nnzs)/len(final_eval_avg_nnzs)*100,
-#     max(final_eval_avg_nnzs)*100
-# )
-# logging(log_str)
logging('=' * 100)
@@ -9,6 +9,10 @@ import torch.nn.functional as F
CUDA_MAJOR = int(torch.version.cuda.split('.')[0])
CUDA_MINOR = int(torch.version.cuda.split('.')[1])
+class Projection(nn.Module):
+    def __init__(self, out_feat, in_feat):
+        super().__init__()  # nn.Module must be initialized before a Parameter is assigned
+        self.weight = nn.Parameter(torch.Tensor(out_feat, in_feat))
class ProjectedAdaptiveLogSoftmax(nn.Module):
    def __init__(self, n_token, d_embed, d_proj, cutoffs, div_val=1,
                 keep_order=False):

@@ -31,13 +35,13 @@ class ProjectedAdaptiveLogSoftmax(nn.Module):
            self.cluster_bias = nn.Parameter(torch.zeros(self.n_clusters))

        self.out_layers = nn.ModuleList()
-        self.out_projs = nn.ParameterList()
+        self.out_projs = nn.ModuleList()

        if div_val == 1:
            for i in range(len(self.cutoffs)):
                if d_proj != d_embed:
                    self.out_projs.append(
-                        nn.Parameter(torch.Tensor(d_proj, d_embed))
+                        Projection(d_proj, d_embed)
                    )
                else:
                    self.out_projs.append(None)

@@ -49,7 +53,7 @@ class ProjectedAdaptiveLogSoftmax(nn.Module):
                d_emb_i = d_embed // (div_val ** i)
                self.out_projs.append(
-                    nn.Parameter(torch.Tensor(d_proj, d_emb_i))
+                    Projection(d_proj, d_emb_i)
                )
                self.out_layers.append(nn.Linear(d_emb_i, r_idx-l_idx))

@@ -82,7 +86,7 @@ class ProjectedAdaptiveLogSoftmax(nn.Module):
        if self.n_clusters == 0:
            logit = self._compute_logit(hidden, self.out_layers[0].weight,
-                                        self.out_layers[0].bias, self.out_projs[0])
+                                        self.out_layers[0].bias, self.out_projs[0].weight if self.out_projs[0] is not None else None)
            nll = -F.log_softmax(logit, dim=-1) \
                    .gather(1, target.unsqueeze(1)).squeeze(1)
        else:

@@ -106,7 +110,7 @@ class ProjectedAdaptiveLogSoftmax(nn.Module):
                weights.append(weight_i)
                biases.append(bias_i)

-            head_weight, head_bias, head_proj = weights[0], biases[0], self.out_projs[0]
+            head_weight, head_bias, head_proj = weights[0], biases[0], self.out_projs[0].weight
            head_logit = self._compute_logit(hidden, head_weight, head_bias, head_proj)
            head_logprob = F.log_softmax(head_logit, dim=1)

@@ -131,7 +135,7 @@ class ProjectedAdaptiveLogSoftmax(nn.Module):
                if i == 0:
                    logprob_i = head_logprob_i.gather(1, target_i[:,None]).squeeze(1)
                else:
-                    weight_i, bias_i, proj_i = weights[i], biases[i], self.out_projs[i]
+                    weight_i, bias_i, proj_i = weights[i], biases[i], self.out_projs[i].weight
                    hidden_i = hidden.index_select(0, indices_i)
...
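Switching out_projs from nn.ParameterList to an nn.ModuleList of Projection wrappers keeps every projection registered as a model parameter while exposing it through a module attribute (out_projs[i].weight); a likely motivation is that nn.ParameterList entries are not replicated correctly by nn.DataParallel, which this training setup uses via BalancedDataParallel. A minimal sketch of the pattern, with illustrative sizes and initialization not taken from the diff:

import torch
import torch.nn as nn

class Projection(nn.Module):
    """Thin module wrapper around a single projection matrix."""
    def __init__(self, out_feat, in_feat):
        super().__init__()
        self.weight = nn.Parameter(torch.empty(out_feat, in_feat))
        nn.init.normal_(self.weight, 0.0, 0.02)  # illustrative init only

projs = nn.ModuleList([Projection(512, 128), Projection(512, 64)])
# The wrapped weights are still ordinary model parameters ...
assert sum(p.numel() for p in projs.parameters()) == 512 * 128 + 512 * 64
# ... and are read through the module attribute, matching out_projs[i].weight above.
w0 = projs[0].weight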
@@ -36,14 +36,39 @@ class FMoELinear(nn.Module):
        '''
        x = MOELinear.apply(inp, self.weight, fwd_expert_count)
        if self.bias is not None:
+            # TODO: torch.repeat_interleave seems to have numerical
+            # instability in backward, leading to incorrect
+            # gradient computation for solutions 1 and 2.
+            # Solution 3 uses a for-loop to expand the bias,
+            # but is 50% slower.
+            # This part should eventually move into MOELinear.apply,
+            # like MOELinear.apply(x, weight, bias, count).
+            # Solution 1
            bias = torch.repeat_interleave(self.bias,
                    fwd_expert_count.to(self.bias.device), dim=0)
+            # Solution 2
+            # bias_idx = torch.arange(self.num_expert)\
+            #         .repeat_interleave(fwd_expert_count)
+            # bias = self.bias[bias_idx]
+            # Solution 3
+            # bias = []
+            # for i in range(self.num_expert):
+            #     if fwd_expert_count[i] > 0:
+            #         bias.append(
+            #             self.bias[i].unsqueeze(0).expand(
+            #                 fwd_expert_count[i], -1
+            #             )
+            #         )
+            # bias = torch.cat(bias, dim=0)
            x = x + bias
        return x

    def extra_repr(self) -> str:
        return 'num_expert={}, in_features={}, \
out_features={}, bias={}, rank={}'.format(
            self.num_expert, self.in_feat,
            self.out_feat, self.bias is not None, self.rank
        )
...
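All three solutions above expand a per-expert bias of shape (num_expert, out_feat) into a per-token bias of shape (sum(fwd_expert_count), out_feat). A minimal, self-contained sketch of Solutions 1 and 2 with illustrative sizes (it only demonstrates the expansion itself, not the backward-pass concern noted in the TODO):

import torch

num_expert, out_feat = 4, 3
bias = torch.arange(num_expert * out_feat, dtype=torch.float32).view(num_expert, out_feat)
fwd_expert_count = torch.tensor([2, 0, 1, 3])  # tokens routed to each expert

# Solution 1: repeat each expert's bias row once per routed token.
bias_rows_1 = torch.repeat_interleave(bias, fwd_expert_count, dim=0)

# Solution 2: build an index of expert ids per token, then gather.
bias_idx = torch.arange(num_expert).repeat_interleave(fwd_expert_count)
bias_rows_2 = bias[bias_idx]

assert bias_rows_1.shape == (int(fwd_expert_count.sum()), out_feat)
assert torch.equal(bias_rows_1, bias_rows_2)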
@@ -3,18 +3,17 @@ The adaptor to seamlessly enable FastMoE in Megatron-LM v2.0 with at most two
lines of modification.
See `examples/megatron` for usage instructions.
'''
-import torch
import torch.nn as nn
-import torch.nn.functional as F
-import numpy as np
-import math

from .transformer import FMoETransformerMLP
from .distributed import DistributedGroupedDataParallel
from .utils import get_torch_default_comm

-class _MegatronMLP(nn.Module):
+class _FakeMegatronMLP(nn.Module):
+    r'''
+    A fake MLP without model parallelism, for correctness testing.
+    '''
    def __init__(self, args, group):
        super().__init__()
        self.fc1 = nn.Linear(args.hidden_size, args.hidden_hidden_size)
...
@@ -44,25 +44,15 @@ class FMoETransformerMLP(FMoE):
            d_hidden=4096,
            world_size=1,
            mp_group=None,
-           activation=torch.nn.functional.gelu,
+           activation=torch.nn.GELU(),
            gate=NaiveGate,
            top_k=2,
-           do_lnorm=False,
-           pre_lnorm=False,
-           add_residual=False,
            expert_dp_comm='none'
            ):
        super().__init__(num_expert=num_expert, d_model=d_model, gate=gate,
                top_k=top_k, world_size=world_size, mp_group=mp_group)
        self.experts = _Expert(num_expert, d_model, d_hidden, activation,
                rank=self.mp_rank)
-        self.pre_lnorm = pre_lnorm
-        if do_lnorm:
-            self.layer_norm = nn.LayerNorm(d_model)
-            self.pre_lnorm = pre_lnorm
-        else:
-            self.pre_lnorm = None
-        self.add_residual = add_residual
        self.mark_parallel_comm(expert_dp_comm)

    def forward(self, inp: torch.Tensor):

@@ -72,11 +62,5 @@ class FMoETransformerMLP(FMoE):
        '''
        original_shape = inp.shape
        inp = inp.reshape(-1, self.d_model)
-        if self.pre_lnorm is not None and self.pre_lnorm:
-            inp = self.layer_norm(inp)
        output = super().forward(inp)
-        if self.pre_lnorm is not None and not self.pre_lnorm:
-            output = self.layer_norm(output)
-        if self.add_residual:
-            output += inp
        return output.reshape(original_shape)
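With do_lnorm, pre_lnorm and add_residual removed, FMoETransformerMLP now computes only the expert MLP; layer normalization and the residual connection become the surrounding Transformer block's job. A minimal sketch of how a caller might wrap it in a pre-norm block — the wrapper class and sizes are illustrative assumptions, not part of this diff:

import torch
import torch.nn as nn
from fmoe.transformer import FMoETransformerMLP

class MoEFFNBlock(nn.Module):
    """Illustrative pre-norm wrapper: LayerNorm and the residual add live here,
    not inside FMoETransformerMLP."""
    def __init__(self, d_model=512, d_hidden=2048, num_expert=64, top_k=2):
        super().__init__()
        self.layer_norm = nn.LayerNorm(d_model)
        self.moe_ffn = FMoETransformerMLP(num_expert=num_expert,
                                          d_model=d_model,
                                          d_hidden=d_hidden,
                                          top_k=top_k)

    def forward(self, x):
        # pre-norm, MoE FFN, then residual connection
        return x + self.moe_ffn(self.layer_norm(x))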