Commit 0f091a1d authored by Sugon_ldc's avatar Sugon_ldc
Browse files

add fastmoe project

parents
Pipeline #263 failed with stages
in 0 seconds
import sys
import math
import functools
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
sys.path.append('utils')
from proj_adaptive_softmax import ProjectedAdaptiveLogSoftmax, Projection
from log_uniform_sampler import LogUniformSampler, sample_logits
class PositionalEmbedding(nn.Module):
def __init__(self, demb):
super(PositionalEmbedding, self).__init__()
self.demb = demb
inv_freq = 1 / (10000 ** (torch.arange(0.0, demb, 2.0) / demb))
self.register_buffer('inv_freq', inv_freq)
def forward(self, pos_seq, bsz=None):
sinusoid_inp = torch.ger(pos_seq, self.inv_freq)
pos_emb = torch.cat([sinusoid_inp.sin(), sinusoid_inp.cos()], dim=-1)
if bsz is not None:
return pos_emb[:,None,:].expand(-1, bsz, -1)
else:
return pos_emb[:,None,:]
class PositionwiseFF(nn.Module):
def __init__(self, d_model, d_inner, dropout, pre_lnorm=False):
super(PositionwiseFF, self).__init__()
self.d_model = d_model
self.d_inner = d_inner
self.dropout = dropout
self.CoreNet = nn.Sequential(
nn.Linear(d_model, d_inner), nn.ReLU(inplace=True),
nn.Dropout(dropout),
nn.Linear(d_inner, d_model),
nn.Dropout(dropout),
)
self.layer_norm = nn.LayerNorm(d_model)
self.pre_lnorm = pre_lnorm
def forward(self, inp):
if self.pre_lnorm:
##### layer normalization + positionwise feed-forward
core_out = self.CoreNet(self.layer_norm(inp))
##### residual connection
output = core_out + inp
else:
##### positionwise feed-forward
core_out = self.CoreNet(inp)
##### residual connection + layer normalization
output = self.layer_norm(inp + core_out)
return output
class MultiHeadAttn(nn.Module):
def __init__(self, n_head, d_model, d_head, dropout, dropatt=0,
pre_lnorm=False):
super(MultiHeadAttn, self).__init__()
self.n_head = n_head
self.d_model = d_model
self.d_head = d_head
self.dropout = dropout
self.q_net = nn.Linear(d_model, n_head * d_head, bias=False)
self.kv_net = nn.Linear(d_model, 2 * n_head * d_head, bias=False)
self.drop = nn.Dropout(dropout)
self.dropatt = nn.Dropout(dropatt)
self.o_net = nn.Linear(n_head * d_head, d_model, bias=False)
self.layer_norm = nn.LayerNorm(d_model)
self.scale = 1 / (d_head ** 0.5)
self.pre_lnorm = pre_lnorm
def forward(self, h, attn_mask=None, mems=None):
##### multihead attention
# [hlen x bsz x n_head x d_head]
if mems is not None:
c = torch.cat([mems, h], 0)
else:
c = h
if self.pre_lnorm:
##### layer normalization
c = self.layer_norm(c)
head_q = self.q_net(h)
head_k, head_v = torch.chunk(self.kv_net(c), 2, -1)
head_q = head_q.view(h.size(0), h.size(1), self.n_head, self.d_head)
head_k = head_k.view(c.size(0), c.size(1), self.n_head, self.d_head)
head_v = head_v.view(c.size(0), c.size(1), self.n_head, self.d_head)
# [qlen x klen x bsz x n_head]
attn_score = torch.einsum('ibnd,jbnd->ijbn', (head_q, head_k))
attn_score.mul_(self.scale)
if attn_mask is not None and attn_mask.any().item():
if attn_mask.dim() == 2:
attn_score.masked_fill_(attn_mask[None,:,:,None].bool(), -float('inf'))
elif attn_mask.dim() == 3:
attn_score.masked_fill_(attn_mask[:,:,:,None].bool(), -float('inf'))
# [qlen x klen x bsz x n_head]
attn_prob = F.softmax(attn_score, dim=1)
attn_prob = self.dropatt(attn_prob)
# [qlen x klen x bsz x n_head] + [klen x bsz x n_head x d_head] -> [qlen x bsz x n_head x d_head]
attn_vec = torch.einsum('ijbn,jbnd->ibnd', (attn_prob, head_v))
attn_vec = attn_vec.contiguous().view(
attn_vec.size(0), attn_vec.size(1), self.n_head * self.d_head)
##### linear projection
attn_out = self.o_net(attn_vec)
attn_out = self.drop(attn_out)
if self.pre_lnorm:
##### residual connection
output = h + attn_out
else:
##### residual connection + layer normalization
output = self.layer_norm(h + attn_out)
return output
class RelMultiHeadAttn(nn.Module):
def __init__(self, n_head, d_model, d_head, dropout, dropatt=0,
tgt_len=None, ext_len=None, mem_len=None, pre_lnorm=False,
moe=False, moe_num_expert=64, moe_top_k=2):
super(RelMultiHeadAttn, self).__init__()
self.n_head = n_head
self.d_model = d_model
self.d_head = d_head
self.dropout = dropout
self.qkv_net = nn.Linear(d_model, 3 * n_head * d_head, bias=False)
self.drop = nn.Dropout(dropout)
self.dropatt = nn.Dropout(dropatt)
self.o_net = nn.Linear(n_head * d_head, d_model, bias=False)
self.layer_norm = nn.LayerNorm(d_model)
self.scale = 1 / (d_head ** 0.5)
self.pre_lnorm = pre_lnorm
def _parallelogram_mask(self, h, w, left=False):
mask = torch.ones((h, w)).byte()
m = min(h, w)
mask[:m,:m] = torch.triu(mask[:m,:m])
mask[-m:,-m:] = torch.tril(mask[-m:,-m:])
if left:
return mask
else:
return mask.flip(0)
def _shift(self, x, qlen, klen, mask, left=False):
if qlen > 1:
zero_pad = torch.zeros((x.size(0), qlen-1, x.size(2), x.size(3)),
device=x.device, dtype=x.dtype)
else:
zero_pad = torch.zeros(0, device=x.device, dtype=x.dtype)
if left:
mask = mask.flip(1)
x_padded = torch.cat([zero_pad, x], dim=1).expand(qlen, -1, -1, -1)
else:
x_padded = torch.cat([x, zero_pad], dim=1).expand(qlen, -1, -1, -1)
x = x_padded.masked_select(mask[:,:,None,None]) \
.view(qlen, klen, x.size(2), x.size(3))
return x
def _rel_shift(self, x, zero_triu=False):
zero_pad = torch.zeros((x.size(0), 1, *x.size()[2:]),
device=x.device, dtype=x.dtype)
x_padded = torch.cat([zero_pad, x], dim=1)
x_padded = x_padded.view(x.size(1) + 1, x.size(0), *x.size()[2:])
x = x_padded[1:].view_as(x)
if zero_triu:
ones = torch.ones((x.size(0), x.size(1)))
x = x * torch.tril(ones, x.size(1) - x.size(0))[:,:,None,None]
return x
def forward(self, w, r, attn_mask=None, mems=None):
raise NotImplementedError
class RelPartialLearnableMultiHeadAttn(RelMultiHeadAttn):
def __init__(self, *args, **kwargs):
super(RelPartialLearnableMultiHeadAttn, self).__init__(*args, **kwargs)
self.r_net = nn.Linear(self.d_model, self.n_head * self.d_head, bias=False)
def forward(self, w, r, r_w_bias, r_r_bias, attn_mask=None, mems=None):
qlen, rlen, bsz = w.size(0), r.size(0), w.size(1)
if mems is not None:
cat = torch.cat([mems, w], 0)
if self.pre_lnorm:
w_heads = self.qkv_net(self.layer_norm(cat))
else:
w_heads = self.qkv_net(cat)
r_head_k = self.r_net(r)
w_head_q, w_head_k, w_head_v = torch.chunk(w_heads, 3, dim=-1)
w_head_q = w_head_q[-qlen:]
else:
if self.pre_lnorm:
w_heads = self.qkv_net(self.layer_norm(w))
else:
w_heads = self.qkv_net(w)
r_head_k = self.r_net(r)
w_head_q, w_head_k, w_head_v = torch.chunk(w_heads, 3, dim=-1)
klen = w_head_k.size(0)
w_head_q = w_head_q.view(qlen, bsz, self.n_head, self.d_head) # qlen x bsz x n_head x d_head
w_head_k = w_head_k.view(klen, bsz, self.n_head, self.d_head) # qlen x bsz x n_head x d_head
w_head_v = w_head_v.view(klen, bsz, self.n_head, self.d_head) # qlen x bsz x n_head x d_head
r_head_k = r_head_k.view(rlen, self.n_head, self.d_head) # qlen x n_head x d_head
#### compute attention score
rw_head_q = w_head_q + r_w_bias # qlen x bsz x n_head x d_head
AC = torch.einsum('ibnd,jbnd->ijbn', (rw_head_q, w_head_k)) # qlen x klen x bsz x n_head
rr_head_q = w_head_q + r_r_bias
BD = torch.einsum('ibnd,jnd->ijbn', (rr_head_q, r_head_k)) # qlen x klen x bsz x n_head
BD = self._rel_shift(BD)
# [qlen x klen x bsz x n_head]
attn_score = AC + BD
attn_score.mul_(self.scale)
#### compute attention probability
if attn_mask is not None and attn_mask.any().item():
if attn_mask.dim() == 2:
attn_score = attn_score.float().masked_fill(
attn_mask[None,:,:,None].bool(), -float('inf')).type_as(attn_score)
elif attn_mask.dim() == 3:
attn_score = attn_score.float().masked_fill(
attn_mask[:,:,:,None].bool(), -float('inf')).type_as(attn_score)
# [qlen x klen x bsz x n_head]
attn_prob = F.softmax(attn_score, dim=1)
attn_prob = self.dropatt(attn_prob)
#### compute attention vector
attn_vec = torch.einsum('ijbn,jbnd->ibnd', (attn_prob, w_head_v))
# [qlen x bsz x n_head x d_head]
attn_vec = attn_vec.contiguous().view(
attn_vec.size(0), attn_vec.size(1), self.n_head * self.d_head)
##### linear projection
attn_out = self.o_net(attn_vec)
attn_out = self.drop(attn_out)
if self.pre_lnorm:
##### residual connection
output = w + attn_out
else:
##### residual connection + layer normalization
output = self.layer_norm(w + attn_out)
return output
class RelLearnableMultiHeadAttn(RelMultiHeadAttn):
def __init__(self, *args, **kwargs):
super(RelLearnableMultiHeadAttn, self).__init__(*args, **kwargs)
def forward(self, w, r_emb, r_w_bias, r_bias, attn_mask=None, mems=None):
# r_emb: [klen, n_head, d_head], used for term B
# r_w_bias: [n_head, d_head], used for term C
# r_bias: [klen, n_head], used for term D
qlen, bsz = w.size(0), w.size(1)
if mems is not None:
cat = torch.cat([mems, w], 0)
if self.pre_lnorm:
w_heads = self.qkv_net(self.layer_norm(cat))
else:
w_heads = self.qkv_net(cat)
w_head_q, w_head_k, w_head_v = torch.chunk(w_heads, 3, dim=-1)
w_head_q = w_head_q[-qlen:]
else:
if self.pre_lnorm:
w_heads = self.qkv_net(self.layer_norm(w))
else:
w_heads = self.qkv_net(w)
w_head_q, w_head_k, w_head_v = torch.chunk(w_heads, 3, dim=-1)
klen = w_head_k.size(0)
w_head_q = w_head_q.view(qlen, bsz, self.n_head, self.d_head)
w_head_k = w_head_k.view(klen, bsz, self.n_head, self.d_head)
w_head_v = w_head_v.view(klen, bsz, self.n_head, self.d_head)
if klen > r_emb.size(0):
r_emb_pad = r_emb[0:1].expand(klen-r_emb.size(0), -1, -1)
r_emb = torch.cat([r_emb_pad, r_emb], 0)
r_bias_pad = r_bias[0:1].expand(klen-r_bias.size(0), -1)
r_bias = torch.cat([r_bias_pad, r_bias], 0)
else:
r_emb = r_emb[-klen:]
r_bias = r_bias[-klen:]
#### compute attention score
rw_head_q = w_head_q + r_w_bias[None] # qlen x bsz x n_head x d_head
AC = torch.einsum('ibnd,jbnd->ijbn', (rw_head_q, w_head_k)) # qlen x klen x bsz x n_head
B_ = torch.einsum('ibnd,jnd->ijbn', (w_head_q, r_emb)) # qlen x klen x bsz x n_head
D_ = r_bias[None, :, None] # 1 x klen x 1 x n_head
BD = self._rel_shift(B_ + D_)
# [qlen x klen x bsz x n_head]
attn_score = AC + BD
attn_score.mul_(self.scale)
#### compute attention probability
if attn_mask is not None and attn_mask.any().item():
if attn_mask.dim() == 2:
attn_score.masked_fill_(attn_mask[None,:,:,None].bool(), -float('inf'))
elif attn_mask.dim() == 3:
attn_score.masked_fill_(attn_mask[:,:,:,None].bool(), -float('inf'))
# [qlen x klen x bsz x n_head]
attn_prob = F.softmax(attn_score, dim=1)
attn_prob = self.dropatt(attn_prob)
#### compute attention vector
attn_vec = torch.einsum('ijbn,jbnd->ibnd', (attn_prob, w_head_v))
# [qlen x bsz x n_head x d_head]
attn_vec = attn_vec.contiguous().view(
attn_vec.size(0), attn_vec.size(1), self.n_head * self.d_head)
##### linear projection
attn_out = self.o_net(attn_vec)
attn_out = self.drop(attn_out)
if self.pre_lnorm:
##### residual connection
output = w + attn_out
else:
##### residual connection + layer normalization
output = self.layer_norm(w + attn_out)
return output
from fmoe import FMoETransformerMLP
class CustomizedMoEPositionwiseFF(FMoETransformerMLP):
def __init__(self, d_model, d_inner, dropout, pre_lnorm=False, moe_num_expert=64, moe_top_k=2):
activation = nn.Sequential(
nn.ReLU(),
nn.Dropout(dropout)
)
super().__init__(num_expert=moe_num_expert, d_model=d_model, d_hidden=d_inner, top_k=moe_top_k,
activation=activation)
self.pre_lnorm = pre_lnorm
self.layer_norm = nn.LayerNorm(d_model)
self.dropout = nn.Dropout(dropout)
def forward(self, inp):
if self.pre_lnorm:
##### layer normalization + positionwise feed-forward
core_out = super().forward(self.layer_norm(inp))
core_out = self.dropout(core_out)
##### residual connection
output = core_out + inp
else:
##### positionwise feed-forward
core_out = super().forward(inp)
core_out = self.dropout(core_out)
##### residual connection + layer normalization
output = self.layer_norm(inp + core_out)
return output
class DecoderLayer(nn.Module):
def __init__(self, n_head, d_model, d_head, d_inner, dropout, **kwargs):
super(DecoderLayer, self).__init__()
self.dec_attn = MultiHeadAttn(n_head, d_model, d_head, dropout, **kwargs)
if kwargs.get('moe') is False:
self.pos_ff = PositionwiseFF(d_model, d_inner, dropout,
pre_lnorm=kwargs.get('pre_lnorm'))
else:
self.pos_ff = CustomizedMoEPositionwiseFF(d_model, d_inner, dropout,
pre_lnorm=kwargs.get('pre_lnorm'),
moe_num_expert=kwargs.get('moe_num_expert'),
moe_top_k=kwargs.get('moe_top_k'))
def forward(self, dec_inp, dec_attn_mask=None, mems=None):
output = self.dec_attn(dec_inp, attn_mask=dec_attn_mask,
mems=mems)
output = self.pos_ff(output)
return output
class RelLearnableDecoderLayer(nn.Module):
def __init__(self, n_head, d_model, d_head, d_inner, dropout,
**kwargs):
super(RelLearnableDecoderLayer, self).__init__()
self.dec_attn = RelLearnableMultiHeadAttn(n_head, d_model, d_head, dropout,
**kwargs)
if kwargs.get('moe') is False:
self.pos_ff = PositionwiseFF(d_model, d_inner, dropout,
pre_lnorm=kwargs.get('pre_lnorm'))
else:
self.pos_ff = CustomizedMoEPositionwiseFF(d_model, d_inner, dropout,
pre_lnorm=kwargs.get('pre_lnorm'),
moe_num_expert=kwargs.get('moe_num_expert'),
moe_top_k=kwargs.get('moe_top_k'))
def forward(self, dec_inp, r_emb, r_w_bias, r_bias, dec_attn_mask=None, mems=None):
output = self.dec_attn(dec_inp, r_emb, r_w_bias, r_bias,
attn_mask=dec_attn_mask,
mems=mems)
output = self.pos_ff(output)
return output
class RelPartialLearnableDecoderLayer(nn.Module):
def __init__(self, n_head, d_model, d_head, d_inner, dropout,
**kwargs):
super(RelPartialLearnableDecoderLayer, self).__init__()
self.dec_attn = RelPartialLearnableMultiHeadAttn(n_head, d_model,
d_head, dropout, **kwargs)
if kwargs.get('moe') is False:
self.pos_ff = PositionwiseFF(d_model, d_inner, dropout,
pre_lnorm=kwargs.get('pre_lnorm'))
else:
self.pos_ff = CustomizedMoEPositionwiseFF(d_model, d_inner, dropout,
pre_lnorm=kwargs.get('pre_lnorm'),
moe_num_expert=kwargs.get('moe_num_expert'),
moe_top_k=kwargs.get('moe_top_k'))
def forward(self, dec_inp, r, r_w_bias, r_r_bias, dec_attn_mask=None, mems=None):
output = self.dec_attn(dec_inp, r, r_w_bias, r_r_bias,
attn_mask=dec_attn_mask,
mems=mems)
output = self.pos_ff(output)
return output
class AdaptiveEmbedding(nn.Module):
def __init__(self, n_token, d_embed, d_proj, cutoffs, div_val=1,
sample_softmax=False):
super(AdaptiveEmbedding, self).__init__()
self.n_token = n_token
self.d_embed = d_embed
self.cutoffs = cutoffs + [n_token]
self.div_val = div_val
self.d_proj = d_proj
self.emb_scale = d_proj ** 0.5
self.cutoff_ends = [0] + self.cutoffs
self.emb_layers = nn.ModuleList()
self.emb_projs = nn.ModuleList()
if div_val == 1:
self.emb_layers.append(
nn.Embedding(n_token, d_embed, sparse=sample_softmax>0)
)
if d_proj != d_embed:
self.emb_projs.append(Projection(d_proj, d_embed))
else:
for i in range(len(self.cutoffs)):
l_idx, r_idx = self.cutoff_ends[i], self.cutoff_ends[i+1]
d_emb_i = d_embed // (div_val ** i)
self.emb_layers.append(nn.Embedding(r_idx-l_idx, d_emb_i))
self.emb_projs.append(Projection(d_proj, d_emb_i))
def forward(self, inp):
if self.div_val == 1:
embed = self.emb_layers[0](inp)
if self.d_proj != self.d_embed:
embed = F.linear(embed, self.emb_projs[0].weight)
else:
param = next(self.parameters())
inp_flat = inp.view(-1)
emb_flat = torch.zeros([inp_flat.size(0), self.d_proj],
dtype=param.dtype, device=param.device)
for i in range(len(self.cutoffs)):
l_idx, r_idx = self.cutoff_ends[i], self.cutoff_ends[i + 1]
mask_i = (inp_flat >= l_idx) & (inp_flat < r_idx)
indices_i = mask_i.nonzero().squeeze()
if indices_i.numel() == 0:
continue
inp_i = inp_flat.index_select(0, indices_i) - l_idx
emb_i = self.emb_layers[i](inp_i)
emb_i = F.linear(emb_i, self.emb_projs[i].weight)
emb_flat.index_copy_(0, indices_i, emb_i)
embed = emb_flat.view(*inp.size(), self.d_proj)
embed.mul_(self.emb_scale)
return embed
class MemTransformerLM(nn.Module):
def __init__(self, n_token, n_layer, n_head, d_model, d_head, d_inner,
dropout, dropatt, tie_weight=True, d_embed=None,
div_val=1, tie_projs=[False], pre_lnorm=False,
tgt_len=None, ext_len=None, mem_len=None,
cutoffs=[], adapt_inp=False,
same_length=False, attn_type=0, clamp_len=-1,
sample_softmax=-1, moe=False, moe_num_expert=64, moe_top_k=2):
super(MemTransformerLM, self).__init__()
self.n_token = n_token
d_embed = d_model if d_embed is None else d_embed
self.d_embed = d_embed
self.d_model = d_model
self.n_head = n_head
self.d_head = d_head
self.word_emb = AdaptiveEmbedding(n_token, d_embed, d_model, cutoffs,
div_val=div_val)
self.drop = nn.Dropout(dropout)
self.n_layer = n_layer
self.tgt_len = tgt_len
self.mem_len = mem_len
self.ext_len = ext_len
self.max_klen = tgt_len + ext_len + mem_len
self.attn_type = attn_type
self.layers = nn.ModuleList()
if attn_type == 0: # the default attention
for i in range(n_layer):
self.layers.append(
RelPartialLearnableDecoderLayer(
n_head, d_model, d_head, d_inner, dropout,
tgt_len=tgt_len, ext_len=ext_len, mem_len=mem_len,
dropatt=dropatt, pre_lnorm=pre_lnorm,
moe=moe, moe_num_expert=moe_num_expert, moe_top_k=moe_top_k)
)
elif attn_type == 1: # learnable embeddings
for i in range(n_layer):
self.layers.append(
RelLearnableDecoderLayer(
n_head, d_model, d_head, d_inner, dropout,
tgt_len=tgt_len, ext_len=ext_len, mem_len=mem_len,
dropatt=dropatt, pre_lnorm=pre_lnorm,
moe=moe, moe_num_expert=moe_num_expert, moe_top_k=moe_top_k)
)
elif attn_type in [2, 3]: # absolute embeddings
for i in range(n_layer):
self.layers.append(
DecoderLayer(
n_head, d_model, d_head, d_inner, dropout,
dropatt=dropatt, pre_lnorm=pre_lnorm,
moe=moe, moe_num_expert=moe_num_expert, moe_top_k=moe_top_k)
)
self.sample_softmax = sample_softmax
# use sampled softmax
if sample_softmax > 0:
self.out_layer = nn.Linear(d_model, n_token)
if tie_weight:
self.out_layer.weight = self.word_emb.weight
self.tie_weight = tie_weight
self.sampler = LogUniformSampler(n_token, sample_softmax)
# use adaptive softmax (including standard softmax)
else:
self.crit = ProjectedAdaptiveLogSoftmax(n_token, d_embed, d_model,
cutoffs, div_val=div_val)
if tie_weight:
for i in range(len(self.crit.out_layers)):
self.crit.out_layers[i].weight = self.word_emb.emb_layers[i].weight
if tie_projs:
for i, tie_proj in enumerate(tie_projs):
if tie_proj and div_val == 1 and d_model != d_embed:
self.crit.out_projs[i].weight = self.word_emb.emb_projs[0].weight
elif tie_proj and div_val != 1:
self.crit.out_projs[i].weight = self.word_emb.emb_projs[i].weight
self.same_length = same_length
self.clamp_len = clamp_len
self._create_params()
def backward_compatible(self):
self.sample_softmax = -1
def _create_params(self):
if self.attn_type == 0: # default attention
self.pos_emb = PositionalEmbedding(self.d_model)
self.r_w_bias = nn.Parameter(torch.Tensor(self.n_head, self.d_head))
self.r_r_bias = nn.Parameter(torch.Tensor(self.n_head, self.d_head))
elif self.attn_type == 1: # learnable
self.r_emb = nn.Parameter(torch.Tensor(
self.n_layer, self.max_klen, self.n_head, self.d_head))
self.r_w_bias = nn.Parameter(torch.Tensor(
self.n_layer, self.n_head, self.d_head))
self.r_bias = nn.Parameter(torch.Tensor(
self.n_layer, self.max_klen, self.n_head))
elif self.attn_type == 2: # absolute standard
self.pos_emb = PositionalEmbedding(self.d_model)
elif self.attn_type == 3: # absolute deeper SA
self.r_emb = nn.Parameter(torch.Tensor(
self.n_layer, self.max_klen, self.n_head, self.d_head))
def reset_length(self, tgt_len, ext_len, mem_len):
self.tgt_len = tgt_len
self.mem_len = mem_len
self.ext_len = ext_len
def init_mems(self, x):
if self.mem_len > 0:
mems = []
for i in range(self.n_layer+1):
empty = torch.empty(0, dtype=x.dtype, device=x.device)
mems.append(empty)
return mems
else:
return None
def _update_mems(self, hids, mems, qlen, mlen):
# does not deal with None
if mems is None: return None
# mems is not None
assert len(hids) == len(mems), 'len(hids) != len(mems)'
# There are `mlen + qlen` steps that can be cached into mems
# For the next step, the last `ext_len` of the `qlen` tokens
# will be used as the extended context. Hence, we only cache
# the tokens from `mlen + qlen - self.ext_len - self.mem_len`
# to `mlen + qlen - self.ext_len`.
with torch.no_grad():
new_mems = []
end_idx = mlen + max(0, qlen - 0 - self.ext_len)
beg_idx = max(0, end_idx - self.mem_len)
for i in range(len(hids)):
cat = torch.cat([mems[i], hids[i]], dim=0)
new_mems.append(cat[beg_idx:end_idx].detach())
return new_mems
def _forward(self, dec_inp, mems=None):
qlen, bsz = dec_inp.size()
word_emb = self.word_emb(dec_inp)
mlen = mems[0].size(0) if mems is not None else 0
klen = mlen + qlen
if self.same_length:
all_ones = word_emb.new_ones(qlen, klen)
mask_len = klen - self.mem_len
if mask_len > 0:
mask_shift_len = qlen - mask_len
else:
mask_shift_len = qlen
dec_attn_mask = (torch.triu(all_ones, 1+mlen)
+ torch.tril(all_ones, -mask_shift_len)).byte()[:, :, None] # -1
else:
dec_attn_mask = torch.triu(
word_emb.new_ones(qlen, klen), diagonal=1+mlen).byte()[:,:,None]
hids = []
if self.attn_type == 0: # default
pos_seq = torch.arange(klen-1, -1, -1.0, device=word_emb.device,
dtype=word_emb.dtype)
if self.clamp_len > 0:
pos_seq.clamp_(max=self.clamp_len)
pos_emb = self.pos_emb(pos_seq)
core_out = self.drop(word_emb)
pos_emb = self.drop(pos_emb)
hids.append(core_out)
for i, layer in enumerate(self.layers):
mems_i = None if mems is None else mems[i]
core_out = layer(core_out, pos_emb, self.r_w_bias,
self.r_r_bias, dec_attn_mask=dec_attn_mask, mems=mems_i)
hids.append(core_out)
elif self.attn_type == 1: # learnable
core_out = self.drop(word_emb)
hids.append(core_out)
for i, layer in enumerate(self.layers):
if self.clamp_len > 0:
r_emb = self.r_emb[i][-self.clamp_len :]
r_bias = self.r_bias[i][-self.clamp_len :]
else:
r_emb, r_bias = self.r_emb[i], self.r_bias[i]
mems_i = None if mems is None else mems[i]
core_out = layer(core_out, r_emb, self.r_w_bias[i],
r_bias, dec_attn_mask=dec_attn_mask, mems=mems_i)
hids.append(core_out)
elif self.attn_type == 2: # absolute
pos_seq = torch.arange(klen - 1, -1, -1.0, device=word_emb.device,
dtype=word_emb.dtype)
if self.clamp_len > 0:
pos_seq.clamp_(max=self.clamp_len)
pos_emb = self.pos_emb(pos_seq)
core_out = self.drop(word_emb + pos_emb[-qlen:])
hids.append(core_out)
for i, layer in enumerate(self.layers):
mems_i = None if mems is None else mems[i]
if mems_i is not None and i == 0:
mems_i += pos_emb[:mlen]
core_out = layer(core_out, dec_attn_mask=dec_attn_mask,
mems=mems_i)
hids.append(core_out)
elif self.attn_type == 3:
core_out = self.drop(word_emb)
hids.append(core_out)
for i, layer in enumerate(self.layers):
mems_i = None if mems is None else mems[i]
if mems_i is not None and mlen > 0:
cur_emb = self.r_emb[i][:-qlen]
cur_size = cur_emb.size(0)
if cur_size < mlen:
cur_emb_pad = cur_emb[0:1].expand(mlen-cur_size, -1, -1)
cur_emb = torch.cat([cur_emb_pad, cur_emb], 0)
else:
cur_emb = cur_emb[-mlen:]
mems_i += cur_emb.view(mlen, 1, -1)
core_out += self.r_emb[i][-qlen:].view(qlen, 1, -1)
core_out = layer(core_out, dec_attn_mask=dec_attn_mask,
mems=mems_i)
hids.append(core_out)
core_out = self.drop(core_out)
new_mems = self._update_mems(hids, mems, mlen, qlen)
return core_out, new_mems
def forward(self, data, target, *mems):
# nn.DataParallel does not allow size(0) tensors to be broadcasted.
# So, have to initialize size(0) mems inside the model forward.
# Moreover, have to return new_mems to allow nn.DataParallel to piece
# them together.
if not mems: mems = self.init_mems(data)
tgt_len = target.size(0)
hidden, new_mems = self._forward(data, mems=mems)
pred_hid = hidden[-tgt_len:]
if self.sample_softmax > 0 and self.training:
assert self.tie_weight
logit = sample_logits(self.word_emb,
self.out_layer.bias, target, pred_hid, self.sampler)
loss = -F.log_softmax(logit, -1)[:, :, 0]
else:
loss = self.crit(pred_hid.view(-1, pred_hid.size(-1)), target.contiguous().view(-1))
loss = loss.view(tgt_len, -1)
if new_mems is None:
return [loss]
else:
return [loss] + new_mems
if __name__ == '__main__':
import argparse
parser = argparse.ArgumentParser(description='unit test')
parser.add_argument('--n_layer', type=int, default=4, help='')
parser.add_argument('--n_rel_layer', type=int, default=4, help='')
parser.add_argument('--n_head', type=int, default=2, help='')
parser.add_argument('--d_head', type=int, default=2, help='')
parser.add_argument('--d_model', type=int, default=200, help='')
parser.add_argument('--d_embed', type=int, default=200, help='')
parser.add_argument('--d_inner', type=int, default=200, help='')
parser.add_argument('--dropout', type=float, default=0.0, help='')
parser.add_argument('--cuda', action='store_true', help='')
parser.add_argument('--seed', type=int, default=1111, help='')
parser.add_argument('--multi_gpu', action='store_true', help='')
args = parser.parse_args()
device = torch.device("cuda" if args.cuda else "cpu")
B = 4
tgt_len, mem_len, ext_len = 36, 36, 0
data_len = tgt_len * 20
args.n_token = 10000
import data_utils
data = torch.LongTensor(data_len*B).random_(0, args.n_token).to(device)
diter = data_utils.LMOrderedIterator(data, B, tgt_len, device=device, ext_len=ext_len)
cutoffs = [args.n_token // 2]
tie_projs = [False] + [True] * len(cutoffs)
for div_val in [1, 2]:
for d_embed in [200, 100]:
model = MemTransformerLM(args.n_token, args.n_layer, args.n_head,
args.d_model, args.d_head, args.d_inner, args.dropout,
dropatt=args.dropout, tie_weight=True,
d_embed=d_embed, div_val=div_val,
tie_projs=tie_projs, pre_lnorm=True,
tgt_len=tgt_len, ext_len=ext_len, mem_len=mem_len,
cutoffs=cutoffs, attn_type=0).to(device)
print(sum(p.numel() for p in model.parameters()))
mems = tuple()
for idx, (inp, tgt, seqlen) in enumerate(diter):
print('batch {}'.format(idx))
out = model(inp, tgt, *mems)
mems = out[1:]
echo "=== Acquiring datasets ==="
echo "---"
mkdir -p ../data
cd ../data
if [[ ! -d 'wikitext-2' ]]; then
echo "- Downloading WikiText-2 (WT2)"
wget --quiet --continue https://s3.amazonaws.com/research.metamind.io/wikitext/wikitext-2-v1.zip
unzip -q wikitext-2-v1.zip
cd wikitext-2
mv wiki.train.tokens train.txt
mv wiki.valid.tokens valid.txt
mv wiki.test.tokens test.txt
cd ..
fi
echo "- Downloading WikiText-103 (WT2)"
if [[ ! -d 'wikitext-103' ]]; then
wget --continue https://s3.amazonaws.com/research.metamind.io/wikitext/wikitext-103-v1.zip
unzip -q wikitext-103-v1.zip
cd wikitext-103
mv wiki.train.tokens train.txt
mv wiki.valid.tokens valid.txt
mv wiki.test.tokens test.txt
cd ..
fi
echo "- Downloading enwik8 (Character)"
if [[ ! -d 'enwik8' ]]; then
mkdir -p enwik8
cd enwik8
wget --continue http://mattmahoney.net/dc/enwik8.zip
wget https://raw.githubusercontent.com/salesforce/awd-lstm-lm/master/data/enwik8/prep_enwik8.py
python3 prep_enwik8.py
cd ..
fi
echo "- Downloading text8 (Character)"
if [[ ! -d 'text8' ]]; then
mkdir -p text8
cd text8
wget --continue http://mattmahoney.net/dc/text8.zip
python ../../prep_text8.py
cd ..
fi
echo "- Downloading Penn Treebank (PTB)"
if [[ ! -d 'penn' ]]; then
wget --quiet --continue http://www.fit.vutbr.cz/~imikolov/rnnlm/simple-examples.tgz
tar -xzf simple-examples.tgz
mkdir -p penn
cd penn
mv ../simple-examples/data/ptb.train.txt train.txt
mv ../simple-examples/data/ptb.test.txt test.txt
mv ../simple-examples/data/ptb.valid.txt valid.txt
cd ..
echo "- Downloading Penn Treebank (Character)"
mkdir -p pennchar
cd pennchar
mv ../simple-examples/data/ptb.char.train.txt train.txt
mv ../simple-examples/data/ptb.char.test.txt test.txt
mv ../simple-examples/data/ptb.char.valid.txt valid.txt
cd ..
rm -rf simple-examples/
fi
echo "- Downloading 1B words"
if [[ ! -d 'one-billion-words' ]]; then
mkdir -p one-billion-words
cd one-billion-words
wget --no-proxy http://www.statmt.org/lm-benchmark/1-billion-word-language-modeling-benchmark-r13output.tar.gz
tar xzvf 1-billion-word-language-modeling-benchmark-r13output.tar.gz
path="1-billion-word-language-modeling-benchmark-r13output/heldout-monolingual.tokenized.shuffled/"
cat ${path}/news.en.heldout-00000-of-00050 > valid.txt
cat ${path}/news.en.heldout-00000-of-00050 > test.txt
wget https://github.com/rafaljozefowicz/lm/raw/master/1b_word_vocab.txt
cd ..
fi
echo "---"
echo "Happy language modeling :)"
#!/bin/bash
if [[ $1 == 'train' ]]; then
echo 'Run training...'
python train.py \
--cuda \
--data ../data/enwik8/ \
--dataset enwik8 \
--n_layer 12 \
--d_model 512 \
--n_head 8 \
--d_head 64 \
--d_inner 2048 \
--dropout 0.1 \
--dropatt 0.0 \
--optim adam \
--lr 0.00025 \
--warmup_step 0 \
--max_step 400000 \
--tgt_len 512 \
--mem_len 512 \
--eval_tgt_len 128 \
--batch_size 22 \
--multi_gpu \
--gpu0_bsz 4 \
${@:2}
elif [[ $1 == 'eval' ]]; then
echo 'Run evaluation...'
python eval.py \
--cuda \
--data ../data/enwik8/ \
--dataset enwik8 \
--tgt_len 80 \
--mem_len 2100 \
--clamp_len 820 \
--same_length \
--split test \
${@:2}
else
echo 'unknown argment 1'
fi
#!/bin/bash
if [[ $1 == 'train' ]]; then
echo 'Run training...'
python train.py \
--cuda \
--data ../data/enwik8/ \
--dataset enwik8 \
--n_layer 12 \
--d_model 512 \
--n_head 8 \
--d_head 64 \
--d_inner 1024 \
--dropout 0.1 \
--dropatt 0.0 \
--optim adam \
--lr 0.00025 \
--warmup_step 0 \
--max_step 400000 \
--tgt_len 512 \
--mem_len 512 \
--eval_tgt_len 128 \
--batch_size 22 \
--multi_gpu \
--gpu0_bsz 4 \
--moe --moe-num-expert 64 --moe-top-k 2 \
${@:2}
elif [[ $1 == 'eval' ]]; then
echo 'Run evaluation...'
python eval.py \
--cuda \
--data ../data/enwik8/ \
--dataset enwik8 \
--tgt_len 80 \
--mem_len 2100 \
--clamp_len 820 \
--same_length \
--split test \
${@:2}
else
echo 'unknown argment 1'
fi
#!/bin/bash
if [[ $1 == 'train' ]]; then
echo 'Run training...'
python train.py \
--cuda \
--data ../data/enwik8/ \
--dataset enwik8 \
--n_layer 24 \
--d_model 1024 \
--n_head 8 \
--d_head 128 \
--d_inner 3072 \
--dropout 0.15 \
--dropatt 0.15 \
--optim adam \
--lr 0.00025 \
--warmup_step 4000 \
--max_step 400000 \
--tgt_len 768 \
--mem_len 768 \
--eval_tgt_len 128 \
--batch_size 64 \
--multi_gpu \
--gpu0_bsz 0 \
${@:2}
elif [[ $1 == 'eval' ]]; then
echo 'Run evaluation...'
python eval.py \
--cuda \
--data ../data/enwik8/ \
--dataset enwik8 \
--tgt_len 128 \
--mem_len 3800 \
--clamp_len 1000 \
--same_length \
--split test \
${@:2}
else
echo 'unknown argment 1'
fi
#!/bin/bash
if [[ $1 == 'train' ]]; then
echo 'Run training...'
python train.py \
--cuda \
--data ../data/one-billion-words/ \
--dataset lm1b \
--adaptive \
--n_layer 18 \
--d_model 1024 \
--div_val 4 \
--n_head 8 \
--d_head 128 \
--d_inner 4096 \
--dropout 0.0 \
--dropatt 0.0 \
--optim adam \
--warmup_step 20000 \
--max_step 500000 \
--lr 0.00025 \
--tgt_len 32 \
--mem_len 32 \
--eval_tgt_len 32 \
--batch_size 224 \
--multi_gpu \
--gpu0_bsz 32 \
${@:2}
elif [[ $1 == 'eval' ]]; then
echo 'Run evaluation...'
python eval.py \
--cuda \
--data ../data/one-billion-words/ \
--dataset lm1b \
--batch_size 64 \
--tgt_len 32 \
--mem_len 128 \
--split test \
--same_length \
${@:2}
else
echo 'unknown argment 1'
fi
#!/bin/bash
if [[ $1 == 'train' ]]; then
echo 'Run training...'
python train.py \
--cuda \
--data ../data/one-billion-words/ \
--dataset lm1b \
--adaptive \
--div_val 4 \
--n_layer 24 \
--d_model 1280 \
--n_head 16 \
--d_head 80 \
--d_inner 8192 \
--dropout 0.05 \
--dropatt 0.05 \
--optim adam \
--warmup_step 30000 \
--max_step 1200000 \
--lr 0.00025 \
--tgt_len 32 \
--mem_len 32 \
--eval_tgt_len 32 \
--batch_size 512 \
--multi_gpu \
--gpu0_bsz 0 \
${@:2}
elif [[ $1 == 'eval' ]]; then
echo 'Run evaluation...'
python eval.py \
--cuda \
--data ../data/one-billion-words/ \
--dataset lm1b \
--batch_size 8 \
--tgt_len 32 \
--mem_len 128 \
--split test \
--same_length \
${@:2}
else
echo 'unknown argment 1'
fi
#!/bin/bash
if [[ $1 == 'train' ]]; then
echo 'Run training...'
python train.py \
--cuda \
--data ../data/text8/ \
--dataset text8 \
--n_layer 12 \
--d_model 512 \
--n_head 8 \
--d_head 64 \
--d_inner 2048 \
--dropout 0.1 \
--dropatt 0.0 \
--optim adam \
--lr 0.00025 \
--warmup_step 0 \
--max_step 400000 \
--tgt_len 512 \
--mem_len 512 \
--eval_tgt_len 128 \
--batch_size 22 \
--multi_gpu \
--gpu0_bsz 4 \
${@:2}
elif [[ $1 == 'eval' ]]; then
echo 'Run evaluation...'
python eval.py \
--cuda \
--data ../data/text8/ \
--dataset text8 \
--tgt_len 80 \
--mem_len 2100 \
--clamp_len 820 \
--same_length \
--split test \
${@:2}
else
echo 'unknown argment 1'
fi
#!/bin/bash
if [[ $1 == 'train' ]]; then
echo 'Run training...'
python train.py \
--cuda \
--data ../data/text8/ \
--dataset text8 \
--n_layer 24 \
--d_model 1024 \
--n_head 8 \
--d_head 128 \
--d_inner 3072 \
--dropout 0.15 \
--dropatt 0.15 \
--optim adam \
--lr 0.00025 \
--tgt_len 768 \
--mem_len 768 \
--eval_tgt_len 128 \
--batch_size 64 \
--max_step 400000 \
${@:2}
elif [[ $1 == 'eval' ]]; then
echo 'Run evaluation...'
python eval.py \
--cuda \
--data ../data/text8/ \
--dataset text8 \
--tgt_len 128 \
--mem_len 3800 \
--clamp_len 1000 \
--same_length \
--split test \
${@:2}
else
echo 'unknown argment 1'
fi
#!/bin/bash
if [[ $1 == 'train' ]]; then
echo 'Run training...'
python train.py \
--cuda \
--data ../data/wikitext-103/ \
--dataset wt103 \
--adaptive \
--n_layer 16 \
--d_model 410 \
--n_head 10 \
--d_head 41 \
--d_inner 2100 \
--dropout 0.1 \
--dropatt 0.0 \
--optim adam \
--lr 0.00025 \
--warmup_step 0 \
--max_step 200000 \
--tgt_len 150 \
--mem_len 150 \
--eval_tgt_len 150 \
--batch_size 60 \
--multi_gpu \
--gpu0_bsz 4 \
${@:2}
elif [[ $1 == 'eval' ]]; then
echo 'Run evaluation...'
python eval.py \
--cuda \
--data ../data/wikitext-103/ \
--dataset wt103 \
--tgt_len 64 \
--mem_len 640 \
--clamp_len 400 \
--same_length \
--split test \
${@:2}
else
echo 'unknown argment 1'
fi
#!/bin/bash
export PYTHONPATH=$PWD/cuda/build/lib.linux-x86_64-3.7
if [[ $1 == 'train' ]]; then
echo 'Run training...'
python3 train.py \
--cuda \
--data ../data/wikitext-103/ \
--dataset wt103 \
--adaptive \
--div_val 4 \
--n_layer 18 \
--d_model 1024 \
--n_head 16 \
--d_head 64 \
--d_inner 4096 \
--dropout 0.2 \
--dropatt 0.2 \
--optim adam \
--lr 0.00025 \
--warmup_step 16000 \
--max_step 4000000 \
--tgt_len 384 \
--mem_len 384 \
--eval_tgt_len 128 \
--batch_size 128 \
--multi_gpu \
--gpu0_bsz 0 \
${@:2}
elif [[ $1 == 'eval' ]]; then
echo 'Run evaluation...'
python eval.py \
--cuda \
--data ../data/wikitext-103/ \
--dataset wt103 \
--tgt_len 128 \
--mem_len 1600 \
--clamp_len 1000 \
--same_length \
--split test \
${@:2}
else
echo 'unknown argment 1'
fi
# coding: utf-8
import argparse
import time
import math
import os, sys
import itertools
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
from data_utils import get_lm_corpus
from mem_transformer import MemTransformerLM
from utils.exp_utils import create_exp_dir
from utils.data_parallel import BalancedDataParallel
parser = argparse.ArgumentParser(description='PyTorch Transformer Language Model')
parser.add_argument('--data', type=str, default='../data/wikitext-103',
help='location of the data corpus')
parser.add_argument('--dataset', type=str, default='wt103',
choices=['wt103', 'lm1b', 'enwik8', 'text8'],
help='dataset name')
parser.add_argument('--n_layer', type=int, default=12,
help='number of total layers')
parser.add_argument('--n_head', type=int, default=10,
help='number of heads')
parser.add_argument('--d_head', type=int, default=50,
help='head dimension')
parser.add_argument('--d_embed', type=int, default=-1,
help='embedding dimension')
parser.add_argument('--d_model', type=int, default=500,
help='model dimension')
parser.add_argument('--d_inner', type=int, default=1000,
help='inner dimension in FF')
parser.add_argument('--dropout', type=float, default=0.0,
help='global dropout rate')
parser.add_argument('--dropatt', type=float, default=0.0,
help='attention probability dropout rate')
parser.add_argument('--init', default='normal', type=str,
help='parameter initializer to use.')
parser.add_argument('--emb_init', default='normal', type=str,
help='parameter initializer to use.')
parser.add_argument('--init_range', type=float, default=0.1,
help='parameters initialized by U(-init_range, init_range)')
parser.add_argument('--emb_init_range', type=float, default=0.01,
help='parameters initialized by U(-init_range, init_range)')
parser.add_argument('--init_std', type=float, default=0.02,
help='parameters initialized by N(0, init_std)')
parser.add_argument('--proj_init_std', type=float, default=0.01,
help='parameters initialized by N(0, init_std)')
parser.add_argument('--optim', default='adam', type=str,
choices=['adam', 'sgd', 'adagrad'],
help='optimizer to use.')
parser.add_argument('--lr', type=float, default=0.00025,
help='initial learning rate (0.00025|5 for adam|sgd)')
parser.add_argument('--mom', type=float, default=0.0,
help='momentum for sgd')
parser.add_argument('--scheduler', default='cosine', type=str,
choices=['cosine', 'inv_sqrt', 'dev_perf', 'constant'],
help='lr scheduler to use.')
parser.add_argument('--warmup_step', type=int, default=0,
help='upper epoch limit')
parser.add_argument('--decay_rate', type=float, default=0.5,
help='decay factor when ReduceLROnPlateau is used')
parser.add_argument('--lr_min', type=float, default=0.0,
help='minimum learning rate during annealing')
parser.add_argument('--clip', type=float, default=0.25,
help='gradient clipping')
parser.add_argument('--clip_nonemb', action='store_true',
help='only clip the gradient of non-embedding params')
parser.add_argument('--max_step', type=int, default=100000,
help='upper epoch limit')
parser.add_argument('--batch_size', type=int, default=60,
help='batch size')
parser.add_argument('--batch_chunk', type=int, default=1,
help='split batch into chunks to save memory')
parser.add_argument('--tgt_len', type=int, default=70,
help='number of tokens to predict')
parser.add_argument('--eval_tgt_len', type=int, default=50,
help='number of tokens to predict for evaluation')
parser.add_argument('--ext_len', type=int, default=0,
help='length of the extended context')
parser.add_argument('--mem_len', type=int, default=0,
help='length of the retained previous heads')
parser.add_argument('--not_tied', action='store_true',
help='do not tie the word embedding and softmax weights')
parser.add_argument('--seed', type=int, default=1111,
help='random seed')
parser.add_argument('--cuda', action='store_true',
help='use CUDA')
parser.add_argument('--adaptive', action='store_true',
help='use adaptive softmax')
parser.add_argument('--div_val', type=int, default=1,
help='divident value for adapative input and softmax')
parser.add_argument('--pre_lnorm', action='store_true',
help='apply LayerNorm to the input instead of the output')
parser.add_argument('--varlen', action='store_true',
help='use variable length')
parser.add_argument('--multi_gpu', action='store_true',
help='use multiple GPU')
parser.add_argument('--log-interval', type=int, default=200,
help='report interval')
parser.add_argument('--eval-interval', type=int, default=4000,
help='evaluation interval')
parser.add_argument('--work_dir', default='LM-TFM', type=str,
help='experiment directory.')
parser.add_argument('--restart', action='store_true',
help='restart training from the saved checkpoint')
parser.add_argument('--restart_dir', type=str, default='',
help='restart dir')
parser.add_argument('--debug', action='store_true',
help='run in debug mode (do not create exp dir)')
parser.add_argument('--same_length', action='store_true',
help='use the same attn length for all tokens')
parser.add_argument('--attn_type', type=int, default=0,
help='attention type. 0 for ours, 1 for Shaw et al,'
'2 for Vaswani et al, 3 for Al Rfou et al.')
parser.add_argument('--clamp_len', type=int, default=-1,
help='use the same pos embeddings after clamp_len')
parser.add_argument('--eta_min', type=float, default=0.0,
help='min learning rate for cosine scheduler')
parser.add_argument('--gpu0_bsz', type=int, default=-1,
help='batch size on gpu 0')
parser.add_argument('--max_eval_steps', type=int, default=-1,
help='max eval steps')
parser.add_argument('--sample_softmax', type=int, default=-1,
help='number of samples in sampled softmax')
parser.add_argument('--patience', type=int, default=0,
help='patience')
parser.add_argument('--finetune_v2', action='store_true',
help='finetune v2')
parser.add_argument('--finetune_v3', action='store_true',
help='finetune v3')
parser.add_argument('--fp16', action='store_true',
help='Run in pseudo-fp16 mode (fp16 storage fp32 math).')
parser.add_argument('--static-loss-scale', type=float, default=1,
help='Static loss scale, positive power of 2 values can '
'improve fp16 convergence.')
parser.add_argument('--dynamic-loss-scale', action='store_true',
help='Use dynamic loss scaling. If supplied, this argument'
' supersedes --static-loss-scale.')
parser.add_argument('--moe', action='store_true',
help='replace position-wise ffn with moe position-wise ffn')
parser.add_argument('--moe-num-expert', type=int, default=64,
help='number of experts in MoE')
parser.add_argument('--moe-top-k', type=int, default=2,
help='top_k experts in hard gate of moe')
args = parser.parse_args()
args.tied = not args.not_tied
assert args.moe_num_expert >= args.moe_top_k, "must have moe-num-expert >= moe-top_k"
if args.d_embed < 0:
args.d_embed = args.d_model
assert args.ext_len >= 0, 'extended context length must be non-negative'
assert args.batch_size % args.batch_chunk == 0
args.work_dir = '{}-{}'.format(args.work_dir, args.dataset)
args.work_dir = os.path.join(args.work_dir, time.strftime('%Y%m%d-%H%M%S'))
logging = create_exp_dir(args.work_dir,
scripts_to_save=['train.py', 'mem_transformer.py'], debug=args.debug)
# Set the random seed manually for reproducibility.
np.random.seed(args.seed)
torch.manual_seed(args.seed)
if torch.cuda.is_available():
if not args.cuda:
print('WARNING: You have a CUDA device, so you should probably run with --cuda')
else:
torch.cuda.manual_seed_all(args.seed)
# Validate `--fp16` option
if args.fp16:
if not args.cuda:
print('WARNING: --fp16 requires --cuda, ignoring --fp16 option')
args.fp16 = False
else:
try:
from apex.fp16_utils import FP16_Optimizer
except:
print('WARNING: apex not installed, ignoring --fp16 option')
args.fp16 = False
device = torch.device('cuda' if args.cuda else 'cpu')
###############################################################################
# Load data
###############################################################################
corpus = get_lm_corpus(args.data, args.dataset)
ntokens = len(corpus.vocab)
args.n_token = ntokens
eval_batch_size = 10
tr_iter = corpus.get_iterator('train', args.batch_size, args.tgt_len,
device=device, ext_len=args.ext_len)
va_iter = corpus.get_iterator('valid', eval_batch_size, args.eval_tgt_len,
device=device, ext_len=args.ext_len)
te_iter = corpus.get_iterator('test', eval_batch_size, args.eval_tgt_len,
device=device, ext_len=args.ext_len)
# adaptive softmax / embedding
cutoffs, tie_projs = [], [False]
if args.adaptive:
assert args.dataset in ['wt103', 'lm1b']
if args.dataset == 'wt103':
cutoffs = [20000, 40000, 200000]
tie_projs += [True] * len(cutoffs)
elif args.dataset == 'lm1b':
cutoffs = [60000, 100000, 640000]
tie_projs += [False] * len(cutoffs)
###############################################################################
# Build the model
###############################################################################
def init_weight(weight):
if args.init == 'uniform':
nn.init.uniform_(weight, -args.init_range, args.init_range)
elif args.init == 'normal':
nn.init.normal_(weight, 0.0, args.init_std)
def init_bias(bias):
nn.init.constant_(bias, 0.0)
def weights_init(m):
classname = m.__class__.__name__
if classname.find('Linear') != -1:
if hasattr(m, 'weight') and m.weight is not None:
init_weight(m.weight)
if hasattr(m, 'bias') and m.bias is not None:
init_bias(m.bias)
elif classname.find('AdaptiveEmbedding') != -1:
if hasattr(m, 'emb_projs'):
for i in range(len(m.emb_projs)):
if m.emb_projs[i] is not None:
nn.init.normal_(m.emb_projs[i], 0.0, args.proj_init_std)
elif classname.find('Embedding') != -1:
if hasattr(m, 'weight'):
init_weight(m.weight)
elif classname.find('ProjectedAdaptiveLogSoftmax') != -1:
if hasattr(m, 'cluster_weight') and m.cluster_weight is not None:
init_weight(m.cluster_weight)
if hasattr(m, 'cluster_bias') and m.cluster_bias is not None:
init_bias(m.cluster_bias)
if hasattr(m, 'out_projs'):
for i in range(len(m.out_projs)):
if m.out_projs[i] is not None:
nn.init.normal_(m.out_projs[i], 0.0, args.proj_init_std)
elif classname.find('LayerNorm') != -1:
if hasattr(m, 'weight'):
nn.init.normal_(m.weight, 1.0, args.init_std)
if hasattr(m, 'bias') and m.bias is not None:
init_bias(m.bias)
elif classname.find('TransformerLM') != -1:
if hasattr(m, 'r_emb'):
init_weight(m.r_emb)
if hasattr(m, 'r_w_bias'):
init_weight(m.r_w_bias)
if hasattr(m, 'r_r_bias'):
init_weight(m.r_r_bias)
if hasattr(m, 'r_bias'):
init_bias(m.r_bias)
def update_dropout(m):
classname = m.__class__.__name__
if classname.find('Dropout') != -1:
if hasattr(m, 'p'):
m.p = args.dropout
def update_dropatt(m):
if hasattr(m, 'dropatt'):
m.dropatt.p = args.dropatt
if args.restart:
with open(os.path.join(args.restart_dir, 'model.pt'), 'rb') as f:
model = torch.load(f)
if not args.fp16:
model = model.float()
model.apply(update_dropout)
model.apply(update_dropatt)
else:
model = MemTransformerLM(ntokens, args.n_layer, args.n_head, args.d_model,
args.d_head, args.d_inner, args.dropout, args.dropatt,
tie_weight=args.tied, d_embed=args.d_embed, div_val=args.div_val,
tie_projs=tie_projs, pre_lnorm=args.pre_lnorm, tgt_len=args.tgt_len,
ext_len=args.ext_len, mem_len=args.mem_len, cutoffs=cutoffs,
same_length=args.same_length, attn_type=args.attn_type,
clamp_len=args.clamp_len, sample_softmax=args.sample_softmax,
moe=args.moe, moe_num_expert=args.moe_num_expert, moe_top_k=args.moe_top_k)
model.apply(weights_init)
model.word_emb.apply(weights_init) # ensure embedding init is not overridden by out_layer in case of weight sharing
args.n_all_param = sum([p.nelement() for p in model.parameters()])
args.n_nonemb_param = sum([p.nelement() for p in model.layers.parameters()])
if args.fp16:
model = model.half()
if args.multi_gpu:
model = model.to(device)
if args.gpu0_bsz >= 0:
para_model = BalancedDataParallel(args.gpu0_bsz // args.batch_chunk,
model, dim=1).to(device)
else:
para_model = nn.DataParallel(model, dim=1).to(device)
else:
para_model = model.to(device)
#### optimizer
if args.optim.lower() == 'sgd':
if args.sample_softmax > 0:
dense_params, sparse_params = [], []
for param in model.parameters():
if param.size() == model.word_emb.weight.size():
sparse_params.append(param)
else:
dense_params.append(param)
optimizer_sparse = optim.SGD(sparse_params, lr=args.lr * 2)
optimizer = optim.SGD(dense_params, lr=args.lr, momentum=args.mom)
else:
optimizer = optim.SGD(model.parameters(), lr=args.lr,
momentum=args.mom)
elif args.optim.lower() == 'adam':
if args.sample_softmax > 0:
dense_params, sparse_params = [], []
for param in model.parameters():
if param.size() == model.word_emb.weight.size():
sparse_params.append(param)
else:
dense_params.append(param)
optimizer_sparse = optim.SparseAdam(sparse_params, lr=args.lr)
optimizer = optim.Adam(dense_params, lr=args.lr)
else:
optimizer = optim.Adam(model.parameters(), lr=args.lr)
elif args.optim.lower() == 'adagrad':
optimizer = optim.Adagrad(model.parameters(), lr=args.lr)
#### scheduler
if args.scheduler == 'cosine':
# here we do not set eta_min to lr_min to be backward compatible
# because in previous versions eta_min is default to 0
# rather than the default value of lr_min 1e-6
scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer,
args.max_step, eta_min=args.eta_min) # should use eta_min arg
if args.sample_softmax > 0:
scheduler_sparse = optim.lr_scheduler.CosineAnnealingLR(optimizer_sparse,
args.max_step, eta_min=args.eta_min) # should use eta_min arg
elif args.scheduler == 'inv_sqrt':
# originally used for Transformer (in Attention is all you need)
def lr_lambda(step):
# return a multiplier instead of a learning rate
if step == 0 and args.warmup_step == 0:
return 1.
else:
return 1. / (step ** 0.5) if step > args.warmup_step \
else step / (args.warmup_step ** 1.5)
scheduler = optim.lr_scheduler.LambdaLR(optimizer, lr_lambda=lr_lambda)
elif args.scheduler == 'dev_perf':
scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer,
factor=args.decay_rate, patience=args.patience, min_lr=args.lr_min)
if args.sample_softmax > 0:
scheduler_sparse = optim.lr_scheduler.ReduceLROnPlateau(optimizer_sparse,
factor=args.decay_rate, patience=args.patience, min_lr=args.lr_min)
elif args.scheduler == 'constant':
pass
if args.cuda and args.fp16:
# If args.dynamic_loss_scale is False, static_loss_scale will be used.
# If args.dynamic_loss_scale is True, it will take precedence over static_loss_scale.
optimizer = FP16_Optimizer(optimizer,
static_loss_scale = args.static_loss_scale,
dynamic_loss_scale = args.dynamic_loss_scale,
dynamic_loss_args = {'init_scale': 2 ** 16})
if args.restart:
if os.path.exists(os.path.join(args.restart_dir, 'optimizer.pt')):
with open(os.path.join(args.restart_dir, 'optimizer.pt'), 'rb') as f:
opt_state_dict = torch.load(f)
optimizer.load_state_dict(opt_state_dict)
else:
print('Optimizer was not saved. Start from scratch.')
logging('=' * 100)
for k, v in args.__dict__.items():
logging(' - {} : {}'.format(k, v))
logging('=' * 100)
logging('#params = {}'.format(args.n_all_param))
logging('#non emb params = {}'.format(args.n_nonemb_param))
###############################################################################
# Training code
###############################################################################
def evaluate(eval_iter):
# Turn on evaluation mode which disables dropout.
model.eval()
# If the model does not use memory at all, make the ext_len longer.
# Otherwise, make the mem_len longer and keep the ext_len the same.
if args.mem_len == 0:
model.reset_length(args.eval_tgt_len,
args.ext_len+args.tgt_len-args.eval_tgt_len, args.mem_len)
else:
model.reset_length(args.eval_tgt_len,
args.ext_len, args.mem_len+args.tgt_len-args.eval_tgt_len)
# Evaluation
total_len, total_loss = 0, 0.
with torch.no_grad():
mems = tuple()
for i, (data, target, seq_len) in enumerate(eval_iter):
if args.max_eval_steps > 0 and i >= args.max_eval_steps:
break
ret = model(data, target, *mems)
loss, mems = ret[0], ret[1:]
loss = loss.mean()
total_loss += seq_len * loss.float().item()
total_len += seq_len
# Switch back to the training mode
model.reset_length(args.tgt_len, args.ext_len, args.mem_len)
model.train()
return total_loss / total_len
def train():
# Turn on training mode which enables dropout.
global train_step, train_loss, best_val_loss, eval_start_time, log_start_time
model.train()
if args.batch_chunk > 1:
mems = [tuple() for _ in range(args.batch_chunk)]
else:
mems = tuple()
train_iter = tr_iter.get_varlen_iter() if args.varlen else tr_iter
for batch, (data, target, seq_len) in enumerate(train_iter):
model.zero_grad()
if args.batch_chunk > 1:
data_chunks = torch.chunk(data, args.batch_chunk, 1)
target_chunks = torch.chunk(target, args.batch_chunk, 1)
for i in range(args.batch_chunk):
data_i = data_chunks[i].contiguous()
target_i = target_chunks[i].contiguous()
ret = para_model(data_i, target_i, *mems[i])
loss, mems[i] = ret[0], ret[1:]
loss = loss.float().mean().type_as(loss) / args.batch_chunk
if args.fp16:
optimizer.backward(loss)
else:
loss.backward()
train_loss += loss.float().item()
else:
ret = para_model(data, target, *mems)
loss, mems = ret[0], ret[1:]
loss = loss.float().mean().type_as(loss)
if args.fp16:
optimizer.backward(loss)
else:
loss.backward()
train_loss += loss.float().item()
if args.fp16:
optimizer.clip_master_grads(args.clip)
else:
torch.nn.utils.clip_grad_norm_(model.parameters(), args.clip)
optimizer.step()
if args.sample_softmax > 0:
optimizer_sparse.step()
# step-wise learning rate annealing
train_step += 1
if args.scheduler in ['cosine', 'constant', 'dev_perf']:
# linear warmup stage
if train_step < args.warmup_step:
curr_lr = args.lr * train_step / args.warmup_step
optimizer.param_groups[0]['lr'] = curr_lr
if args.sample_softmax > 0:
optimizer_sparse.param_groups[0]['lr'] = curr_lr * 2
else:
if args.scheduler == 'cosine':
scheduler.step(train_step)
if args.sample_softmax > 0:
scheduler_sparse.step(train_step)
elif args.scheduler == 'inv_sqrt':
scheduler.step(train_step)
if train_step % args.log_interval == 0:
cur_loss = train_loss / args.log_interval
elapsed = time.time() - log_start_time
log_str = '| epoch {:3d} step {:>8d} | {:>6d} batches | lr {:.3g} ' \
'| ms/batch {:5.2f} | loss {:5.2f}'.format(
epoch, train_step, batch+1, optimizer.param_groups[0]['lr'],
elapsed * 1000 / args.log_interval, cur_loss)
if args.dataset in ['enwik8', 'text8']:
log_str += ' | bpc {:9.5f}'.format(cur_loss / math.log(2))
else:
log_str += ' | ppl {:9.3f}'.format(math.exp(cur_loss))
logging(log_str)
train_loss = 0
log_start_time = time.time()
if train_step % args.eval_interval == 0:
val_loss = evaluate(va_iter)
logging('-' * 100)
log_str = '| Eval {:3d} at step {:>8d} | time: {:5.2f}s ' \
'| valid loss {:5.2f}'.format(
train_step // args.eval_interval, train_step,
(time.time() - eval_start_time), val_loss)
if args.dataset in ['enwik8', 'text8']:
log_str += ' | bpc {:9.5f}'.format(val_loss / math.log(2))
else:
log_str += ' | valid ppl {:9.3f}'.format(math.exp(val_loss))
logging(log_str)
logging('-' * 100)
# Save the model if the validation loss is the best we've seen so far.
if not best_val_loss or val_loss < best_val_loss:
if not args.debug:
with open(os.path.join(args.work_dir, 'model.pt'), 'wb') as f:
torch.save(model, f)
with open(os.path.join(args.work_dir, 'optimizer.pt'), 'wb') as f:
torch.save(optimizer.state_dict(), f)
best_val_loss = val_loss
# dev-performance based learning rate annealing
if args.scheduler == 'dev_perf':
scheduler.step(val_loss)
if args.sample_softmax > 0:
scheduler_sparse.step(val_loss)
eval_start_time = time.time()
if train_step == args.max_step:
break
# Loop over epochs.
train_step = 0
train_loss = 0
best_val_loss = None
log_start_time = time.time()
eval_start_time = time.time()
# At any point you can hit Ctrl + C to break out of training early.
try:
for epoch in itertools.count(start=1):
train()
if train_step == args.max_step:
logging('-' * 100)
logging('End of training')
break
except KeyboardInterrupt:
logging('-' * 100)
logging('Exiting from training early')
# Load the best saved model.
with open(os.path.join(args.work_dir, 'model.pt'), 'rb') as f:
model = torch.load(f)
para_model = model.to(device)
# Run on test data.
test_loss = evaluate(te_iter)
logging('=' * 100)
if args.dataset in ['enwik8', 'text8']:
logging('| End of training | test loss {:5.2f} | test bpc {:9.5f}'.format(
test_loss, test_loss / math.log(2)))
else:
logging('| End of training | test loss {:5.2f} | test ppl {:9.3f}'.format(
test_loss, math.exp(test_loss)))
logging('=' * 100)
from collections import defaultdict
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
class AdaptiveLogSoftmax(nn.Module):
def __init__(self, in_features, n_classes, cutoffs, keep_order=False):
super(AdaptiveLogSoftmax, self).__init__()
cutoffs = list(cutoffs)
if (cutoffs != sorted(cutoffs)) \
or (min(cutoffs) <= 0) \
or (max(cutoffs) >= (n_classes - 1)) \
or (len(set(cutoffs)) != len(cutoffs)) \
or any([int(c) != c for c in cutoffs]):
raise ValueError("cutoffs should be a sequence of unique, positive "
"integers sorted in an increasing order, where "
"each value is between 1 and n_classes-1")
self.in_features = in_features
self.n_classes = n_classes
self.cutoffs = cutoffs + [n_classes]
self.shortlist_size = self.cutoffs[0]
self.n_clusters = len(self.cutoffs) - 1
self.head_size = self.shortlist_size + self.n_clusters
self.cluster_weight = nn.Parameter(torch.zeros(self.n_clusters, self.in_features))
self.cluster_bias = nn.Parameter(torch.zeros(self.n_clusters))
self.keep_order = keep_order
def forward(self, hidden, target, weight, bias, keep_order=False):
if hidden.size(0) != target.size(0):
raise RuntimeError('Input and target should have the same size '
'in the batch dimension.')
head_weight = torch.cat(
[weight[:self.shortlist_size], self.cluster_weight], dim=0)
head_bias = torch.cat(
[bias[:self.shortlist_size], self.cluster_bias], dim=0)
head_logit = F.linear(hidden, head_weight, bias=head_bias)
head_logprob = F.log_softmax(head_logit, dim=1)
nll = torch.zeros_like(target,
dtype=hidden.dtype, device=hidden.device)
offset = 0
cutoff_values = [0] + self.cutoffs
for i in range(len(cutoff_values) - 1):
l_idx, h_idx = cutoff_values[i], cutoff_values[i + 1]
mask_i = (target >= l_idx) & (target < h_idx)
indices_i = mask_i.nonzero().squeeze()
if indices_i.numel() == 0:
continue
target_i = target.index_select(0, indices_i) - l_idx
head_logprob_i = head_logprob.index_select(0, indices_i)
if i == 0:
logprob_i = head_logprob_i.gather(1, target_i[:,None]).squeeze(1)
else:
weight_i = weight[l_idx:h_idx]
bias_i = bias[l_idx:h_idx]
hidden_i = hidden.index_select(0, indices_i)
tail_logit_i = F.linear(hidden_i, weight_i, bias=bias_i)
tail_logprob_i = F.log_softmax(tail_logit_i, dim=1)
logprob_i = head_logprob_i[:, -i] \
+ tail_logprob_i.gather(1, target_i[:,None]).squeeze(1)
if (hasattr(self, 'keep_order') and self.keep_order) or keep_order:
nll.index_copy_(0, indices_i, -logprob_i)
else:
nll[offset:offset+logprob_i.size(0)].copy_(-logprob_i)
offset += logprob_i.size(0)
return nll
from torch.nn.parallel import DataParallel
import torch
from torch.nn.parallel._functions import Scatter
from torch.nn.parallel.parallel_apply import parallel_apply
def scatter(inputs, target_gpus, chunk_sizes, dim=0):
r"""
Slices tensors into approximately equal chunks and
distributes them across given GPUs. Duplicates
references to objects that are not tensors.
"""
def scatter_map(obj):
if isinstance(obj, torch.Tensor):
try:
return Scatter.apply(target_gpus, chunk_sizes, dim, obj)
except:
print('obj', obj.size())
print('dim', dim)
print('chunk_sizes', chunk_sizes)
quit()
if isinstance(obj, tuple) and len(obj) > 0:
return list(zip(*map(scatter_map, obj)))
if isinstance(obj, list) and len(obj) > 0:
return list(map(list, zip(*map(scatter_map, obj))))
if isinstance(obj, dict) and len(obj) > 0:
return list(map(type(obj), zip(*map(scatter_map, obj.items()))))
return [obj for targets in target_gpus]
# After scatter_map is called, a scatter_map cell will exist. This cell
# has a reference to the actual function scatter_map, which has references
# to a closure that has a reference to the scatter_map cell (because the
# fn is recursive). To avoid this reference cycle, we set the function to
# None, clearing the cell
try:
return scatter_map(inputs)
finally:
scatter_map = None
def scatter_kwargs(inputs, kwargs, target_gpus, chunk_sizes, dim=0):
r"""Scatter with support for kwargs dictionary"""
inputs = scatter(inputs, target_gpus, chunk_sizes, dim) if inputs else []
kwargs = scatter(kwargs, target_gpus, chunk_sizes, dim) if kwargs else []
if len(inputs) < len(kwargs):
inputs.extend([() for _ in range(len(kwargs) - len(inputs))])
elif len(kwargs) < len(inputs):
kwargs.extend([{} for _ in range(len(inputs) - len(kwargs))])
inputs = tuple(inputs)
kwargs = tuple(kwargs)
return inputs, kwargs
class BalancedDataParallel(DataParallel):
def __init__(self, gpu0_bsz, *args, **kwargs):
self.gpu0_bsz = gpu0_bsz
super().__init__(*args, **kwargs)
def forward(self, *inputs, **kwargs):
if not self.device_ids:
return self.module(*inputs, **kwargs)
if self.gpu0_bsz == 0:
device_ids = self.device_ids[1:]
else:
device_ids = self.device_ids
inputs, kwargs = self.scatter(inputs, kwargs, device_ids)
if len(self.device_ids) == 1:
return self.module(*inputs[0], **kwargs[0])
replicas = self.replicate(self.module, self.device_ids)
if self.gpu0_bsz == 0:
replicas = replicas[1:]
outputs = self.parallel_apply(replicas, device_ids, inputs, kwargs)
return self.gather(outputs, self.output_device)
def parallel_apply(self, replicas, device_ids, inputs, kwargs):
return parallel_apply(replicas, inputs, kwargs, device_ids)
def scatter(self, inputs, kwargs, device_ids):
bsz = inputs[0].size(self.dim)
num_dev = len(self.device_ids)
gpu0_bsz = self.gpu0_bsz
bsz_unit = (bsz - gpu0_bsz) // (num_dev - 1)
if gpu0_bsz < bsz_unit:
chunk_sizes = [gpu0_bsz] + [bsz_unit] * (num_dev - 1)
delta = bsz - sum(chunk_sizes)
for i in range(delta):
chunk_sizes[i + 1] += 1
if gpu0_bsz == 0:
chunk_sizes = chunk_sizes[1:]
else:
return super().scatter(inputs, kwargs, device_ids)
return scatter_kwargs(inputs, kwargs, device_ids, chunk_sizes, dim=self.dim)
import functools
import os, shutil
import numpy as np
import torch
def logging(s, log_path, print_=True, log_=True):
if print_:
print(s)
if log_:
with open(log_path, 'a+') as f_log:
f_log.write(s + '\n')
def get_logger(log_path, **kwargs):
return functools.partial(logging, log_path=log_path, **kwargs)
def create_exp_dir(dir_path, scripts_to_save=None, debug=False):
if debug:
print('Debug Mode : no experiment dir created')
return functools.partial(logging, log_path=None, log_=False)
if not os.path.exists(dir_path):
os.makedirs(dir_path)
print('Experiment dir : {}'.format(dir_path))
if scripts_to_save is not None:
script_path = os.path.join(dir_path, 'scripts')
if not os.path.exists(script_path):
os.makedirs(script_path)
for script in scripts_to_save:
dst_file = os.path.join(dir_path, 'scripts', os.path.basename(script))
shutil.copyfile(script, dst_file)
return get_logger(log_path=os.path.join(dir_path, 'log.txt'))
def save_checkpoint(model, optimizer, path, epoch):
torch.save(model, os.path.join(path, 'model_{}.pt'.format(epoch)))
torch.save(optimizer.state_dict(), os.path.join(path, 'optimizer_{}.pt'.format(epoch)))
import torch
from torch import nn
import numpy as np
class LogUniformSampler(object):
def __init__(self, range_max, n_sample):
"""
Reference : https://github.com/tensorflow/tensorflow/blob/r1.10/tensorflow/python/ops/candidate_sampling_ops.py
`P(class) = (log(class + 2) - log(class + 1)) / log(range_max + 1)`
expected count can be approximated by 1 - (1 - p)^n
and we use a numerically stable version -expm1(num_tries * log1p(-p))
Our implementation fixes num_tries at 2 * n_sample, and the actual #samples will vary from run to run
"""
with torch.no_grad():
self.range_max = range_max
log_indices = torch.arange(1., range_max+2., 1.).log_()
self.dist = (log_indices[1:] - log_indices[:-1]) / log_indices[-1]
# print('P', self.dist.numpy().tolist()[-30:])
self.log_q = (- (-self.dist.double().log1p_() * 2 * n_sample).expm1_()).log_().float()
self.n_sample = n_sample
def sample(self, labels):
"""
labels: [b1, b2]
Return
true_log_probs: [b1, b2]
samp_log_probs: [n_sample]
neg_samples: [n_sample]
"""
# neg_samples = torch.empty(0).long()
n_sample = self.n_sample
n_tries = 2 * n_sample
with torch.no_grad():
neg_samples = torch.multinomial(self.dist, n_tries, replacement=True).unique()
device = labels.device
neg_samples = neg_samples.to(device)
true_log_probs = self.log_q[labels].to(device)
samp_log_probs = self.log_q[neg_samples].to(device)
return true_log_probs, samp_log_probs, neg_samples
def sample_logits(embedding, bias, labels, inputs, sampler):
"""
embedding: an nn.Embedding layer
bias: [n_vocab]
labels: [b1, b2]
inputs: [b1, b2, n_emb]
sampler: you may use a LogUniformSampler
Return
logits: [b1, b2, 1 + n_sample]
"""
true_log_probs, samp_log_probs, neg_samples = sampler.sample(labels)
n_sample = neg_samples.size(0)
b1, b2 = labels.size(0), labels.size(1)
all_ids = torch.cat([labels.view(-1), neg_samples])
all_w = embedding(all_ids)
true_w = all_w[: -n_sample].view(b1, b2, -1)
sample_w = all_w[- n_sample:].view(n_sample, -1)
all_b = bias[all_ids]
true_b = all_b[: -n_sample].view(b1, b2)
sample_b = all_b[- n_sample:]
hit = (labels[:, :, None] == neg_samples).detach()
true_logits = torch.einsum('ijk,ijk->ij',
[true_w, inputs]) + true_b - true_log_probs
sample_logits = torch.einsum('lk,ijk->ijl',
[sample_w, inputs]) + sample_b - samp_log_probs
sample_logits.masked_fill_(hit, -1e30)
logits = torch.cat([true_logits[:, :, None], sample_logits], -1)
return logits
# class LogUniformSampler(object):
# def __init__(self, range_max, unique=False):
# """
# Reference : https://github.com/tensorflow/tensorflow/blob/r1.10/tensorflow/python/ops/candidate_sampling_ops.py
# `P(class) = (log(class + 2) - log(class + 1)) / log(range_max + 1)`
# """
# self.range_max = range_max
# log_indices = torch.arange(1., range_max+2., 1.).log_()
# self.dist = (log_indices[1:] - log_indices[:-1]) / log_indices[-1]
# self.unique = unique
# if self.unique:
# self.exclude_mask = torch.ByteTensor(range_max).fill_(0)
# def sample(self, n_sample, labels):
# pos_sample, new_labels = labels.unique(return_inverse=True)
# n_pos_sample = pos_sample.size(0)
# n_neg_sample = n_sample - n_pos_sample
# if self.unique:
# self.exclude_mask.index_fill_(0, pos_sample, 1)
# sample_dist = self.dist.clone().masked_fill_(self.exclude_mask, 0)
# self.exclude_mask.index_fill_(0, pos_sample, 0)
# else:
# sample_dist = self.dist
# neg_sample = torch.multinomial(sample_dist, n_neg_sample)
# sample = torch.cat([pos_sample, neg_sample])
# sample_prob = self.dist[sample]
# return new_labels, sample, sample_prob
if __name__ == '__main__':
S, B = 3, 4
n_vocab = 10000
n_sample = 5
H = 32
labels = torch.LongTensor(S, B).random_(0, n_vocab)
# sampler = LogUniformSampler(n_vocab, unique=False)
# new_labels, sample, sample_prob = sampler.sample(n_sample, labels)
sampler = LogUniformSampler(n_vocab, unique=True)
# true_probs, samp_probs, neg_samples = sampler.sample(n_sample, labels)
# print('true_probs', true_probs.numpy().tolist())
# print('samp_probs', samp_probs.numpy().tolist())
# print('neg_samples', neg_samples.numpy().tolist())
# print('sum', torch.sum(sampler.dist).item())
# assert torch.all(torch.sort(sample.unique())[0].eq(torch.sort(sample)[0])).item()
embedding = nn.Embedding(n_vocab, H)
bias = torch.zeros(n_vocab)
inputs = torch.Tensor(S, B, H).normal_()
logits, out_labels = sample_logits(embedding, bias, labels, inputs, sampler, n_sample)
print('logits', logits.detach().numpy().tolist())
print('logits shape', logits.size())
print('out_labels', out_labels.detach().numpy().tolist())
print('out_labels shape', out_labels.size())
from collections import defaultdict
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
CUDA_MAJOR = int(torch.version.cuda.split('.')[0])
CUDA_MINOR = int(torch.version.cuda.split('.')[1])
class Projection(nn.Module):
def __init__(self, out_feat, in_feat):
self.weight = nn.Parameter(torch.Tensor(out_feat, in_feat))
class ProjectedAdaptiveLogSoftmax(nn.Module):
def __init__(self, n_token, d_embed, d_proj, cutoffs, div_val=1,
keep_order=False):
super(ProjectedAdaptiveLogSoftmax, self).__init__()
self.n_token = n_token
self.d_embed = d_embed
self.d_proj = d_proj
self.cutoffs = cutoffs + [n_token]
self.cutoff_ends = [0] + self.cutoffs
self.div_val = div_val
self.shortlist_size = self.cutoffs[0]
self.n_clusters = len(self.cutoffs) - 1
self.head_size = self.shortlist_size + self.n_clusters
if self.n_clusters > 0:
self.cluster_weight = nn.Parameter(torch.zeros(self.n_clusters, self.d_embed))
self.cluster_bias = nn.Parameter(torch.zeros(self.n_clusters))
self.out_layers = nn.ModuleList()
self.out_projs = nn.ModuleList()
if div_val == 1:
for i in range(len(self.cutoffs)):
if d_proj != d_embed:
self.out_projs.append(
Projection(d_proj, d_embed)
)
else:
self.out_projs.append(None)
self.out_layers.append(nn.Linear(d_embed, n_token))
else:
for i in range(len(self.cutoffs)):
l_idx, r_idx = self.cutoff_ends[i], self.cutoff_ends[i+1]
d_emb_i = d_embed // (div_val ** i)
self.out_projs.append(
Projection(d_proj, d_emb_i)
)
self.out_layers.append(nn.Linear(d_emb_i, r_idx-l_idx))
self.keep_order = keep_order
def _compute_logit(self, hidden, weight, bias, proj):
if proj is None:
logit = F.linear(hidden, weight, bias=bias)
else:
# if CUDA_MAJOR <= 9 and CUDA_MINOR <= 1:
proj_hid = F.linear(hidden, proj.t().contiguous())
logit = F.linear(proj_hid, weight, bias=bias)
# else:
# logit = torch.einsum('bd,de,ev->bv', (hidden, proj, weight.t()))
# if bias is not None:
# logit = logit + bias
return logit
def forward(self, hidden, target, keep_order=False):
'''
hidden :: [len*bsz x d_proj]
target :: [len*bsz]
'''
if hidden.size(0) != target.size(0):
raise RuntimeError('Input and target should have the same size '
'in the batch dimension.')
if self.n_clusters == 0:
logit = self._compute_logit(hidden, self.out_layers[0].weight,
self.out_layers[0].bias, self.out_projs[0].weight if self.out_projs[0] is not None else None)
nll = -F.log_softmax(logit, dim=-1) \
.gather(1, target.unsqueeze(1)).squeeze(1)
else:
# construct weights and biases
weights, biases = [], []
for i in range(len(self.cutoffs)):
if self.div_val == 1:
l_idx, r_idx = self.cutoff_ends[i], self.cutoff_ends[i + 1]
weight_i = self.out_layers[0].weight[l_idx:r_idx]
bias_i = self.out_layers[0].bias[l_idx:r_idx]
else:
weight_i = self.out_layers[i].weight
bias_i = self.out_layers[i].bias
if i == 0:
weight_i = torch.cat(
[weight_i, self.cluster_weight], dim=0)
bias_i = torch.cat(
[bias_i, self.cluster_bias], dim=0)
weights.append(weight_i)
biases.append(bias_i)
head_weight, head_bias, head_proj = weights[0], biases[0], self.out_projs[0].weight if self.out_projs[0] is not None else None
head_logit = self._compute_logit(hidden, head_weight, head_bias, head_proj)
head_logprob = F.log_softmax(head_logit, dim=1)
nll = torch.zeros_like(target,
dtype=hidden.dtype, device=hidden.device)
offset = 0
cutoff_values = [0] + self.cutoffs
for i in range(len(cutoff_values) - 1):
l_idx, r_idx = cutoff_values[i], cutoff_values[i + 1]
mask_i = (target >= l_idx) & (target < r_idx)
indices_i = mask_i.nonzero().squeeze()
if indices_i.numel() == 0:
continue
target_i = target.index_select(0, indices_i) - l_idx
head_logprob_i = head_logprob.index_select(0, indices_i)
if i == 0:
logprob_i = head_logprob_i.gather(1, target_i[:,None]).squeeze(1)
else:
weight_i, bias_i, proj_i = weights[i], biases[i], self.out_projs[i].weight if self.out_projs[i] is not None else None
hidden_i = hidden.index_select(0, indices_i)
tail_logit_i = self._compute_logit(hidden_i, weight_i, bias_i, proj_i)
tail_logprob_i = F.log_softmax(tail_logit_i, dim=1)
logprob_i = head_logprob_i[:, -i] \
+ tail_logprob_i.gather(1, target_i[:,None]).squeeze(1)
if (hasattr(self, 'keep_order') and self.keep_order) or keep_order:
nll.index_copy_(0, indices_i, -logprob_i)
else:
nll[offset:offset+logprob_i.size(0)].copy_(-logprob_i)
offset += logprob_i.size(0)
return nll
import os
from collections import Counter, OrderedDict
import torch
class Vocab(object):
def __init__(self, special=[], min_freq=0, max_size=None, lower_case=True,
delimiter=None, vocab_file=None):
self.counter = Counter()
self.special = special
self.min_freq = min_freq
self.max_size = max_size
self.lower_case = lower_case
self.delimiter = delimiter
self.vocab_file = vocab_file
def tokenize(self, line, add_eos=False, add_double_eos=False):
line = line.strip()
# convert to lower case
if self.lower_case:
line = line.lower()
# empty delimiter '' will evaluate False
if self.delimiter == '':
symbols = line
else:
symbols = line.split(self.delimiter)
if add_double_eos: # lm1b
return ['<S>'] + symbols + ['<S>']
elif add_eos:
return symbols + ['<eos>']
else:
return symbols
def count_file(self, path, verbose=False, add_eos=False):
if verbose: print('counting file {} ...'.format(path))
assert os.path.exists(path)
sents = []
with open(path, 'r', encoding='utf-8') as f:
for idx, line in enumerate(f):
if verbose and idx > 0 and idx % 500000 == 0:
print(' line {}'.format(idx))
symbols = self.tokenize(line, add_eos=add_eos)
self.counter.update(symbols)
sents.append(symbols)
return sents
def count_sents(self, sents, verbose=False):
"""
sents : a list of sentences, each a list of tokenized symbols
"""
if verbose: print('counting {} sents ...'.format(len(sents)))
for idx, symbols in enumerate(sents):
if verbose and idx > 0 and idx % 500000 == 0:
print(' line {}'.format(idx))
self.counter.update(symbols)
def _build_from_file(self, vocab_file):
self.idx2sym = []
self.sym2idx = OrderedDict()
with open(vocab_file, 'r', encoding='utf-8') as f:
for line in f:
symb = line.strip().split()[0]
self.add_symbol(symb)
self.unk_idx = self.sym2idx['<UNK>']
def build_vocab(self):
if self.vocab_file:
print('building vocab from {}'.format(self.vocab_file))
self._build_from_file(self.vocab_file)
print('final vocab size {}'.format(len(self)))
else:
print('building vocab with min_freq={}, max_size={}'.format(
self.min_freq, self.max_size))
self.idx2sym = []
self.sym2idx = OrderedDict()
for sym in self.special:
self.add_special(sym)
for sym, cnt in self.counter.most_common(self.max_size):
if cnt < self.min_freq: break
self.add_symbol(sym)
print('final vocab size {} from {} unique tokens'.format(
len(self), len(self.counter)))
def encode_file(self, path, ordered=False, verbose=False, add_eos=True,
add_double_eos=False):
if verbose: print('encoding file {} ...'.format(path))
assert os.path.exists(path)
encoded = []
with open(path, 'r', encoding='utf-8') as f:
for idx, line in enumerate(f):
if verbose and idx > 0 and idx % 500000 == 0:
print(' line {}'.format(idx))
symbols = self.tokenize(line, add_eos=add_eos,
add_double_eos=add_double_eos)
encoded.append(self.convert_to_tensor(symbols))
if ordered:
encoded = torch.cat(encoded)
return encoded
def encode_sents(self, sents, ordered=False, verbose=False):
if verbose: print('encoding {} sents ...'.format(len(sents)))
encoded = []
for idx, symbols in enumerate(sents):
if verbose and idx > 0 and idx % 500000 == 0:
print(' line {}'.format(idx))
encoded.append(self.convert_to_tensor(symbols))
if ordered:
encoded = torch.cat(encoded)
return encoded
def add_special(self, sym):
if sym not in self.sym2idx:
self.idx2sym.append(sym)
self.sym2idx[sym] = len(self.idx2sym) - 1
setattr(self, '{}_idx'.format(sym.strip('<>')), self.sym2idx[sym])
def add_symbol(self, sym):
if sym not in self.sym2idx:
self.idx2sym.append(sym)
self.sym2idx[sym] = len(self.idx2sym) - 1
def get_sym(self, idx):
assert 0 <= idx < len(self), 'Index {} out of range'.format(idx)
return self.idx2sym[idx]
def get_idx(self, sym):
if sym in self.sym2idx:
return self.sym2idx[sym]
else:
# print('encounter unk {}'.format(sym))
assert '<eos>' not in sym
assert hasattr(self, 'unk_idx')
return self.sym2idx.get(sym, self.unk_idx)
def get_symbols(self, indices):
return [self.get_sym(idx) for idx in indices]
def get_indices(self, symbols):
return [self.get_idx(sym) for sym in symbols]
def convert_to_tensor(self, symbols):
return torch.LongTensor(self.get_indices(symbols))
def convert_to_sent(self, indices, exclude=None):
if exclude is None:
return ' '.join([self.get_sym(idx) for idx in indices])
else:
return ' '.join([self.get_sym(idx) for idx in indices if idx not in exclude])
def __len__(self):
return len(self.idx2sym)
r"""
The fmoe package contains MoE Layers only.
"""
from .layers import FMoE
from .linear import FMoELinear
from .transformer import FMoETransformerMLP
from .distributed import DistributedGroupedDataParallel
import torch
import torch.nn.functional as F
metrics = {
"coefficient-variation": lambda c_e: torch.std(c_e) / torch.mean(c_e),
"Lmax-over-Lmin": lambda c_e: (torch.max(c_e) + 1) / (torch.min(c_e) + 1),
"Lmax-over-Lmean": lambda c_e: torch.max(c_e) / torch.mean(c_e),
}
def reset_balance_profile(balance_dict, num_layers, balance_strategy):
for key in metrics:
balance_dict[key] = [None for _ in range(num_layers)]
if balance_strategy:
balance_dict[f"{balance_strategy}_loss"] = [None for _ in range(num_layers)]
def update_balance_profile(
balance_dict,
gate_top_k_idx,
_gate_score_top_k,
gate_context,
layer_idx,
num_expert,
balance_strategy,
):
# Fill in this function to conduct balance related jobs
pass
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment