Unverified Commit 03b2a725 authored by Rick Ho, committed by GitHub

Merge pull request #6 from xfmr-xl

Test Transformer-XL
parents e86dea53 0a942e3f
...@@ -10,3 +10,6 @@ a.out ...@@ -10,3 +10,6 @@ a.out
build build
*swp *swp
logs logs
examples/transformer-xl/data
examples/data
examples/transformer-xl/LM-TFM-enwik8
...@@ -7,10 +7,9 @@ import numpy as np ...@@ -7,10 +7,9 @@ import numpy as np
import torch import torch
import torch.nn as nn import torch.nn as nn
import torch.nn.functional as F import torch.nn.functional as F
# import torch_sparse
sys.path.append('utils') sys.path.append('utils')
from proj_adaptive_softmax import ProjectedAdaptiveLogSoftmax from proj_adaptive_softmax import ProjectedAdaptiveLogSoftmax, Projection
from log_uniform_sampler import LogUniformSampler, sample_logits from log_uniform_sampler import LogUniformSampler, sample_logits
class PositionalEmbedding(nn.Module): class PositionalEmbedding(nn.Module):
...@@ -32,361 +31,16 @@ class PositionalEmbedding(nn.Module): ...@@ -32,361 +31,16 @@ class PositionalEmbedding(nn.Module):
return pos_emb[:,None,:] return pos_emb[:,None,:]
# A baseline naive slow implementation
class MoEPositionwiseFFRaw(nn.Module):
def __init__(self, d_model, d_inner, dropout, pre_lnorm=False, top_k=64):
super(MoEPositionwiseFFRaw, self).__init__()
print("MoEPositionwiseFF")
self.top_k = top_k
self.d_model = d_model
self.d_inner = d_inner
self.dropout = dropout
self.gate = nn.Linear(d_model, d_inner)
self.W2 = nn.Parameter(torch.Tensor(d_inner, d_model))
self.b2 = nn.Parameter(torch.Tensor(d_model))
self.layer_norm = nn.LayerNorm(d_model)
self.pre_lnorm = pre_lnorm
ratio = top_k / d_inner
self.dropout_middle = nn.Dropout(dropout * ratio)
self.dropout_final = nn.Dropout(dropout)
self.reset_parameter()
def reset_parameter(self):
temp_Linear = nn.Linear(self.d_inner, self.d_model)
self.W2.data = temp_Linear.weight.data.transpose(0, 1)
self.b2.data = temp_Linear.bias.data
def forward(self, inp):
residual = inp
if self.pre_lnorm:
inp = self.layer_norm(inp)
gate = self.gate(inp)
gate_top_k_val, gate_top_k_idx = torch.topk(gate, k=self.top_k, dim=-1, largest=True, sorted=False) # [.. x top_k]
relu_out = F.relu(gate_top_k_val)
x = self.dropout_middle(relu_out)
W2_select = self.W2[gate_top_k_idx] # [.. x top_k x d_model]
core_out = torch.einsum('ijk,ijkd->ijd', (x, W2_select)) + self.b2 # [.. x d_model]
core_out = self.dropout_final(core_out)
output = core_out + residual
if not self.pre_lnorm:
output = self.layer_norm(output)
return output
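A minimal shape check for this naive baseline (a hypothetical snippet, not part of the commit): per position, only the top-k of the d_inner hidden units contribute to the output, and the layer preserves the input shape.
# Hypothetical sanity check for MoEPositionwiseFFRaw.
import torch
ff = MoEPositionwiseFFRaw(d_model=32, d_inner=128, dropout=0.1, top_k=16)
x = torch.randn(10, 4, 32)       # [seq_len, batch, d_model]
assert ff(x).shape == x.shape    # only 16 of the 128 inner units are used per position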
def my_topk(x, k, inplace=True):
y = x if inplace else x.clone()
top1_val, top1_idx = torch.max(y, dim=-1)
top1_val = top1_val.unsqueeze(-1)
top1_idx = top1_idx.unsqueeze(-1)
if k == 1:
return top1_val, top1_idx
y.scatter_(-1, top1_idx, value=float('-inf'))
top2_val, top2_idx = torch.max(y, dim=-1)
top2_val = top2_val.unsqueeze(-1)
top2_idx = top2_idx.unsqueeze(-1)
top_val = torch.cat((top1_val, top2_val), dim=-1)
top_idx = torch.cat((top1_idx, top2_idx), dim=-1)
return top_val, top_idx
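A small equivalence check for my_topk (hypothetical, not part of the commit). With the default inplace=True the input tensor is mutated (-inf is scattered into it), so a clone is passed here.
# Hypothetical check: my_topk should agree with torch.topk for k = 2.
import torch
x = torch.randn(4, 8, 16)
vals, idx = my_topk(x.clone(), k=2)        # clone: my_topk overwrites the top-1 entries with -inf
ref_vals, _ = torch.topk(x, k=2, dim=-1)   # sorted descending by default
assert torch.allclose(vals, ref_vals)      # my_topk also returns top-1 first, then top-2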
class MultiHeadHierarchicalMoEPositionwiseFF(nn.Module):
def __init__(self, d_model, d_inner, dropout, pre_lnorm=False, n_block=16, top_block=2):
super(MultiHeadHierarchicalMoEPositionwiseFF, self).__init__()
print("MultiHeadHierarchicalMoEPositionwiseFF")
assert d_inner % n_block == 0
assert top_block in [1, 2]
self.top_block = top_block
self.n_block = n_block
d_block = d_inner // n_block
self.d_block = d_block
self.d_model = d_model
self.d_inner = d_inner
self.dropout = dropout
self.block_net_W = nn.Parameter(torch.Tensor(d_model, top_block, n_block))
self.block_net_b = nn.Parameter(torch.Tensor(top_block, n_block))
self.W1 = nn.Parameter(torch.Tensor(n_block, d_block, d_model))
self.b1 = nn.Parameter(torch.Tensor(n_block, d_block))
self.W2 = nn.Parameter(torch.Tensor(n_block, d_block, d_model))
self.b2 = nn.Parameter(torch.Tensor(d_model))
self.layer_norm = nn.LayerNorm(d_model)
self.pre_lnorm = pre_lnorm
ratio = top_block / n_block
self.dropout_middle = nn.Dropout(dropout * ratio)
self.dropout_final = nn.Dropout(dropout)
# self.scale = 1 / (d_model ** 0.5)
self.reset_parameter()
def reset_parameter(self):
temp = nn.Linear(self.d_model, self.d_inner)
self.W1.data = temp.weight.data.view(self.n_block, self.d_block, self.d_model)
self.b1.data = temp.bias.data.view(self.n_block, self.d_block)
temp = nn.Linear(self.d_inner, self.d_model)
self.W2.data = temp.weight.data.transpose(0, 1).contiguous().view(self.n_block, self.d_block, self.d_model)
self.b2.data = temp.bias.data
for i in range(self.top_block):
temp = nn.Linear(self.d_model, self.n_block)
self.block_net_W.data[:, i] = temp.weight.data.transpose(0, 1).contiguous()
self.block_net_b.data[i] = temp.bias.data
def forward(self, inp):
residual = inp
if self.pre_lnorm:
inp = self.layer_norm(inp)
block = torch.einsum("ibd,dan->iban", (inp, self.block_net_W)) + self.block_net_b # [.. x top_block x n_block ]
block_val, block_idx = my_topk(block, k=1, inplace=True)
# block_val, block_idx = torch.topk(block, k=1, dim=-1, largest=True, sorted=False) # [.. x top_k x 1]
block_val = block_val.squeeze(-1)
block_idx = block_idx.squeeze(-1)
gate = F.softmax(block_val, dim=-1)
W1_block = self.W1[block_idx] # [.. x top_k x d_block x d_model]
b1_block = self.b1[block_idx] # [.. x top_k x d_block]
x = torch.einsum('ibd,ibnhd->ibnh', (inp, W1_block)) + b1_block # [.. x top_k x d_block]
# x = x + block_val.unsqueeze(-1) # somehow like residual
x = x * gate.unsqueeze(-1)
relu_out = F.relu(x)
relu_out = self.dropout_middle(relu_out)
W2_block = self.W2[block_idx] # [.. x top_k x d_model]
core_out = torch.einsum('ibnh,ibnhd->ibd', (x, W2_block)) + self.b2 # [.. x d_model]
core_out = self.dropout_final(core_out)
output = core_out + residual
if not self.pre_lnorm:
output = self.layer_norm(output)
return output
class HierarchicalMoEPositionwiseFF(nn.Module):
def __init__(self, d_model, d_inner, dropout, pre_lnorm=False, n_block=16, top_block=2):
super(HierarchicalMoEPositionwiseFF, self).__init__()
print("HierarchicalMoEPositionwiseFF")
assert d_inner % n_block == 0
assert top_block in [1, 2]
self.top_block = top_block
self.n_block = n_block
d_block = d_inner // n_block
self.d_block = d_block
self.d_model = d_model
self.d_inner = d_inner
self.dropout = dropout
self.block_net = nn.Linear(d_model, n_block, bias=True)
self.W1 = nn.Parameter(torch.Tensor(n_block, d_block, d_model))
self.b1 = nn.Parameter(torch.Tensor(n_block, d_block))
self.W2 = nn.Parameter(torch.Tensor(n_block, d_block, d_model))
self.b2 = nn.Parameter(torch.Tensor(d_model))
self.layer_norm = nn.LayerNorm(d_model)
self.pre_lnorm = pre_lnorm
ratio = top_block / n_block
self.dropout_middle = nn.Dropout(dropout * ratio)
self.dropout_final = nn.Dropout(dropout)
# self.scale = 1 / (d_model ** 0.5)
self.reset_parameter()
def reset_parameter(self):
temp = nn.Linear(self.d_model, self.d_inner)
self.W1.data = temp.weight.data.view(self.n_block, self.d_block, self.d_model)
self.b1.data = temp.bias.data.view(self.n_block, self.d_block)
temp = nn.Linear(self.d_inner, self.d_model)
self.W2.data = temp.weight.data.transpose(0, 1).contiguous().view(self.n_block, self.d_block, self.d_model)
self.b2.data = temp.bias.data
def forward(self, inp):
residual = inp
if self.pre_lnorm:
inp = self.layer_norm(inp)
block = self.block_net(inp)
# block_val, block_idx = my_topk(block, k=self.top_block)
block_val, block_idx = torch.topk(block, k=self.top_block, dim=-1, largest=True, sorted=False) # [.. x top_k]
gate = F.softmax(block_val, dim=-1)
W1_block = self.W1[block_idx] # [.. x top_k x d_block x d_model]
b1_block = self.b1[block_idx] # [.. x top_k x d_block]
x = torch.einsum('ibd,ibnhd->ibnh', (inp, W1_block)) + b1_block # [.. x top_k x d_block]
# x = x + block_val.unsqueeze(-1) # somehow like residual
x = x * gate.unsqueeze(-1)
relu_out = F.relu(x)
relu_out = self.dropout_middle(relu_out)
W2_block = self.W2[block_idx] # [.. x top_k x d_model]
core_out = torch.einsum('ibnh,ibnhd->ibd', (x, W2_block)) + self.b2 # [.. x d_model]
core_out = self.dropout_final(core_out)
output = core_out + residual
if not self.pre_lnorm:
output = self.layer_norm(output)
return output
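For reference, a hypothetical shape walk-through of the hierarchical variant (not part of the commit): each position routes through top_block of the n_block expert blocks, and the output keeps the input shape.
# Hypothetical sanity check for HierarchicalMoEPositionwiseFF.
import torch
ff = HierarchicalMoEPositionwiseFF(d_model=32, d_inner=64, dropout=0.1, n_block=16, top_block=2)
x = torch.randn(10, 4, 32)       # [seq_len, batch, d_model]
assert ff(x).shape == x.shape    # 2 of 16 blocks (d_block = 4 units each) are active per position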
class SparsePositionwiseFF(nn.Module):
def __init__(self, d_model, d_inner, dropout, pre_lnorm=False):
super(SparsePositionwiseFF, self).__init__()
print("SparsePositionwiseFF")
self.d_model = d_model
self.d_inner = d_inner
self.dropout = dropout
self.CoreNet_1 = nn.Sequential(
nn.Linear(d_model, d_inner),
nn.ReLU(inplace=True),
nn.Dropout(dropout)
)
self.W2 = nn.Parameter(torch.Tensor(d_inner, d_model))
self.b2 = nn.Parameter(torch.Tensor(d_model))
self.layer_norm = nn.LayerNorm(d_model)
self.pre_lnorm = pre_lnorm
self.dropout_final = nn.Dropout(dropout)
self.reset_parameter()
def reset_parameter(self):
temp_Linear = nn.Linear(self.d_inner, self.d_model)
self.W2.data = temp_Linear.weight.data.transpose(0, 1)
self.b2.data = temp_Linear.bias.data
def forward(self, inp):
residual = inp
if self.pre_lnorm:
inp = self.layer_norm(inp)
relu_out = self.CoreNet_1(inp).view(-1, self.d_inner)
sparse_relu_out = torch_sparse.SparseTensor.from_dense(relu_out)
core_out = torch_sparse.matmul(sparse_relu_out, self.W2) + self.b2
core_out = core_out.view(inp.size(0), inp.size(1), self.d_model)
core_out = self.dropout_final(core_out)
output = core_out + residual
if not self.pre_lnorm:
output = self.layer_norm(output)
return output
class MultiHeadPositionwiseFF(nn.Module):
def __init__(self, d_model, d_inner, dropout, pre_lnorm=False, n_head=2):
super(MultiHeadPositionwiseFF, self).__init__()
print("MultiHeadPositionwiseFF")
assert d_model % n_head == 0
self.n_head = n_head
d_head = d_model // n_head
self.d_head = d_head
self.d_model = d_model
self.d_inner = d_inner
self.dropout = dropout
self.q_net = nn.Linear(d_model, d_model)
self.k_weight = nn.Parameter(torch.Tensor(n_head, d_inner, d_head))
self.k_bias = nn.Parameter(torch.Tensor(n_head, d_inner))
self.v_weight = nn.Parameter(torch.Tensor(n_head, d_head, d_inner))
self.v_bias = nn.Parameter(torch.Tensor(n_head, d_head))
#self.o_net = nn.Linear(d_model, d_model)
self.layer_norm = nn.LayerNorm(d_model)
self.pre_lnorm = pre_lnorm
self.dropout = nn.Dropout(dropout)
self.reset_parameter()
def reset_parameter(self):
for i in range(self.n_head):
tmp = nn.Linear(self.d_head, self.d_inner)
self.k_weight.data[i] = tmp.weight.data
self.k_bias.data[i] = tmp.bias.data
tmp = nn.Linear(self.d_inner, self.d_head)
self.v_weight.data[i] = tmp.weight.data
self.v_bias.data[i] = tmp.bias.data
def forward(self, inp):
residual = inp
if self.pre_lnorm:
inp = self.layer_norm(inp)
head_q = self.q_net(inp)
head_q = head_q.view(inp.size(0), inp.size(1), self.n_head, self.d_head) # [.. x n_head x d_head]
attn_score = torch.einsum('ibnd,nhd->ibnh', (head_q, self.k_weight)) + self.k_bias # [.. x n_head x d_inner]
attn_score = F.relu(attn_score)
attn_score = self.dropout(attn_score)
attn_vec = torch.einsum('ibnh,ndh->ibnd', (attn_score, self.v_weight)) + self.v_bias
attn_vec = attn_vec.contiguous().view(inp.size(0), inp.size(1), self.d_model)
# core_out = self.o_net(attn_vec)
core_out = self.dropout(attn_vec)
output = core_out + residual
if not self.pre_lnorm:
output = self.layer_norm(output)
return output
class PositionwiseFF(nn.Module): class PositionwiseFF(nn.Module):
def __init__(self, d_model, d_inner, dropout, pre_lnorm=False, use_softmax=True): def __init__(self, d_model, d_inner, dropout, pre_lnorm=False):
super(PositionwiseFF, self).__init__() super(PositionwiseFF, self).__init__()
self.d_model = d_model self.d_model = d_model
self.d_inner = d_inner self.d_inner = d_inner
self.dropout = dropout self.dropout = dropout
self.CoreNet_1 = nn.Sequential( self.CoreNet = nn.Sequential(
nn.Linear(d_model, d_inner), nn.Linear(d_model, d_inner), nn.ReLU(inplace=True),
nn.Softmax(dim=-1) if use_softmax else nn.ReLU(inplace=True)
)
self.CoreNet_2 = nn.Sequential(
nn.Dropout(dropout), nn.Dropout(dropout),
nn.Linear(d_inner, d_model), nn.Linear(d_inner, d_model),
nn.Dropout(dropout), nn.Dropout(dropout),
...@@ -399,113 +53,18 @@ class PositionwiseFF(nn.Module): ...@@ -399,113 +53,18 @@ class PositionwiseFF(nn.Module):
def forward(self, inp): def forward(self, inp):
if self.pre_lnorm: if self.pre_lnorm:
##### layer normalization + positionwise feed-forward ##### layer normalization + positionwise feed-forward
relu_out = self.CoreNet_1(self.layer_norm(inp)) core_out = self.CoreNet(self.layer_norm(inp))
core_out = self.CoreNet_2(relu_out)
##### residual connection ##### residual connection
output = core_out + inp output = core_out + inp
else: else:
##### positionwise feed-forward ##### positionwise feed-forward
relu_out = self.CoreNet_1(inp) core_out = self.CoreNet(inp)
core_out = self.CoreNet_2(relu_out)
##### residual connection + layer normalization ##### residual connection + layer normalization
output = self.layer_norm(inp + core_out) output = self.layer_norm(inp + core_out)
return output return output
# return output, relu_out.detach()
class ExtendedMultiHeadAttn(nn.Module):
def __init__(self, n_head, d_model, d_head, dropout, dropatt=0,
pre_lnorm=False):
super(ExtendedMultiHeadAttn, self).__init__()
print("ExtendedMultiHeadAttn")
self.n_head = n_head
self.d_model = d_model
self.d_head = d_head
self.dropout = dropout
self.q_net = nn.Linear(d_model, n_head * d_head, bias=False)
self.kv_net = nn.Linear(d_model, 2 * n_head * d_head, bias=False)
self.drop = nn.Dropout(dropout)
self.dropatt = nn.Dropout(dropatt)
self.o_net = nn.Linear(n_head * d_head * 2, d_model, bias=False)
self.layer_norm = nn.LayerNorm(d_model)
self.scale = 1 / (d_head ** 0.5)
self.pre_lnorm = pre_lnorm
# self.coeff = nn.Parameter(torch.Tensor(n_head, 2))
# nn.init.uniform_(self.coeff, a=-1, b=1)
def forward(self, h, attn_mask=None, mems=None):
##### multihead attention
# [hlen x bsz x n_head x d_head]
if mems is not None:
c = torch.cat([mems, h], 0)
mem_len = mems.size(0)
else:
c = h
mem_len = 0
if self.pre_lnorm:
##### layer normalization
c = self.layer_norm(c)
head_q = self.q_net(c)
head_k, head_v = torch.chunk(self.kv_net(c), 2, -1)
head_q = head_q.view(c.size(0), c.size(1), self.n_head, self.d_head)
head_k = head_k.view(c.size(0), c.size(1), self.n_head, self.d_head)
head_v = head_v.view(c.size(0), c.size(1), self.n_head, self.d_head)
# [qlen x klen x bsz x n_head]
attn_score = torch.einsum('ibnd,jbnd->ijbn', (head_q, head_k))
attn_score.mul_(self.scale)
if attn_mask is not None and attn_mask.any().item():
if attn_mask.dim() == 2:
attn_score[mem_len:].masked_fill_(attn_mask[None,:,:,None].bool(), -float('inf'))
elif attn_mask.dim() == 3:
attn_score[mem_len:].masked_fill_(attn_mask[:,:,:,None].bool(), -float('inf'))
mem2other_attn = attn_mask.new_ones(mem_len, c.size(0))
mem2other_attn[:, :mem_len] = 0
attn_score[:mem_len].masked_fill_(mem2other_attn[:, :, None, None].bool(), -float('inf'))
# [qlen x klen x bsz x n_head]
attn_prob = F.softmax(attn_score, dim=1)
attn_prob = self.dropatt(attn_prob)
# [qlen x klen x bsz x n_head] + [klen x bsz x n_head x d_head] -> [qlen x bsz x n_head x d_head]
attn_vec = torch.einsum('ijbn,jbnd->ibnd', (attn_prob, head_v))
attn_vec_quad = torch.einsum('ijbn,jbnd->ibnd', (attn_prob, attn_vec))
# [qlen x bsz x n_head x d_head x 2]
attn_vecs = torch.cat([attn_vec.unsqueeze(-1), attn_vec_quad.unsqueeze(-1)], dim=-1)
# attn_vec = torch.einsum('ibndt,nt->ibnd', (attn_vecs, self.coeff))
attn_vec = attn_vecs.contiguous().view(
attn_vec.size(0), attn_vec.size(1), self.n_head * self.d_head * 2)
attn_vec = attn_vec[mem_len:]
##### linear projection
attn_out = self.o_net(attn_vec)
attn_out = self.drop(attn_out)
if self.pre_lnorm:
##### residual connection
output = h + attn_out
else:
##### residual connection + layer normalization
output = self.layer_norm(h + attn_out)
return output
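A hypothetical shape check for this (removed) experimental attention variant, assuming no memory and no attention mask; it concatenates the first- and second-order attention outputs before the output projection.
# Hypothetical sanity check for ExtendedMultiHeadAttn.
import torch
attn = ExtendedMultiHeadAttn(n_head=4, d_model=32, d_head=8, dropout=0.1)
h = torch.randn(10, 2, 32)       # [seq_len, batch, d_model]
assert attn(h).shape == h.shape  # output projection maps 2 * n_head * d_head back to d_model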
class MultiHeadAttn(nn.Module): class MultiHeadAttn(nn.Module):
def __init__(self, n_head, d_model, d_head, dropout, dropatt=0, def __init__(self, n_head, d_model, d_head, dropout, dropatt=0,
...@@ -583,7 +142,8 @@ class MultiHeadAttn(nn.Module): ...@@ -583,7 +142,8 @@ class MultiHeadAttn(nn.Module):
class RelMultiHeadAttn(nn.Module): class RelMultiHeadAttn(nn.Module):
def __init__(self, n_head, d_model, d_head, dropout, dropatt=0, def __init__(self, n_head, d_model, d_head, dropout, dropatt=0,
tgt_len=None, ext_len=None, mem_len=None, pre_lnorm=False): tgt_len=None, ext_len=None, mem_len=None, pre_lnorm=False,
moe=False, moe_num_expert=64, moe_top_k=2):
super(RelMultiHeadAttn, self).__init__() super(RelMultiHeadAttn, self).__init__()
self.n_head = n_head self.n_head = n_head
...@@ -816,42 +376,41 @@ class RelLearnableMultiHeadAttn(RelMultiHeadAttn): ...@@ -816,42 +376,41 @@ class RelLearnableMultiHeadAttn(RelMultiHeadAttn):
return output return output
from fmoe import FMoETransformerMLP from fmoe import FMoETransformerMLP
class CustomizedMoEPositionwiseFF(FMoETransformerMLP): class CustomizedMoEPositionwiseFF(FMoETransformerMLP):
def __init__(self, d_model, d_inner, dropout, pre_lnorm=False): def __init__(self, d_model, d_inner, dropout, pre_lnorm=False, moe_num_expert=64, moe_top_k=2):
def activation(x): activation = nn.Sequential(
return self.dropout(F.relu(x)) nn.Dropout(dropout),
super().__init__(num_expert=8, d_model=d_model, d_hidden=d_inner, nn.ReLU()
pre_lnorm=pre_lnorm, activation=activation)
self.dropout = nn.Dropout(dropout)
self.bias = nn.Parameter(
torch.zeros(d_model, dtype=torch.float32)
) )
super().__init__(num_expert=moe_num_expert, d_model=d_model, d_hidden=d_inner, top_k=moe_top_k,
do_lnorm=True, pre_lnorm=pre_lnorm, activation=activation, dropout=dropout)
def forward(self, x): def forward(self, x):
x = super().forward(x) x = super().forward(x)
return x + self.bias return x
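A hypothetical usage sketch of the MoE feed-forward layer (not part of the commit); it assumes the fmoe package is installed and a CUDA device is available, since fmoe's expert kernels are CUDA-based.
# Hypothetical usage; shapes follow Transformer-XL's [seq_len, batch, d_model] layout.
import torch
moe_ff = CustomizedMoEPositionwiseFF(d_model=512, d_inner=2048, dropout=0.1,
                                     pre_lnorm=False, moe_num_expert=64, moe_top_k=2).cuda()
x = torch.randn(128, 8, 512, device='cuda')
y = moe_ff(x)                    # each position is routed to its top-2 of the 64 experts
assert y.shape == x.shape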
class DecoderLayer(nn.Module): class DecoderLayer(nn.Module):
def __init__(self, n_head, d_model, d_head, d_inner, dropout, **kwargs): def __init__(self, n_head, d_model, d_head, d_inner, dropout, **kwargs):
super(DecoderLayer, self).__init__() super(DecoderLayer, self).__init__()
self.dec_attn = MultiHeadAttn(n_head, d_model, d_head, dropout, **kwargs) self.dec_attn = MultiHeadAttn(n_head, d_model, d_head, dropout, **kwargs)
# self.dec_attn = ExtendedMultiHeadAttn(n_head, d_model, d_head, dropout, **kwargs) if kwargs.get('moe') is False:
self.pos_ff = CustomizedMoEPositionwiseFF(d_model, d_inner, dropout, self.pos_ff = PositionwiseFF(d_model, d_inner, dropout,
pre_lnorm=kwargs.get('pre_lnorm')) pre_lnorm=kwargs.get('pre_lnorm'))
else:
self.pos_ff = CustomizedMoEPositionwiseFF(d_model, d_inner, dropout,
pre_lnorm=kwargs.get('pre_lnorm'),
moe_num_expert=kwargs.get('moe_num_expert'),
moe_top_k=kwargs.get('moe_top_k'))
def forward(self, dec_inp, dec_attn_mask=None, mems=None): def forward(self, dec_inp, dec_attn_mask=None, mems=None):
output = self.dec_attn(dec_inp, attn_mask=dec_attn_mask, output = self.dec_attn(dec_inp, attn_mask=dec_attn_mask,
mems=mems) mems=mems)
output = self.pos_ff(output) output = self.pos_ff(output)
# output, relu_out = self.pos_ff(output)
return output return output
# return output, relu_out
class RelLearnableDecoderLayer(nn.Module): class RelLearnableDecoderLayer(nn.Module):
def __init__(self, n_head, d_model, d_head, d_inner, dropout, def __init__(self, n_head, d_model, d_head, d_inner, dropout,
...@@ -860,8 +419,15 @@ class RelLearnableDecoderLayer(nn.Module): ...@@ -860,8 +419,15 @@ class RelLearnableDecoderLayer(nn.Module):
self.dec_attn = RelLearnableMultiHeadAttn(n_head, d_model, d_head, dropout, self.dec_attn = RelLearnableMultiHeadAttn(n_head, d_model, d_head, dropout,
**kwargs) **kwargs)
self.pos_ff = CustomizedMoEPositionwiseFF(d_model, d_inner, dropout,
pre_lnorm=kwargs.get('pre_lnorm')) if kwargs.get('moe') is False:
self.pos_ff = PositionwiseFF(d_model, d_inner, dropout,
pre_lnorm=kwargs.get('pre_lnorm'))
else:
self.pos_ff = CustomizedMoEPositionwiseFF(d_model, d_inner, dropout,
pre_lnorm=kwargs.get('pre_lnorm'),
moe_num_expert=kwargs.get('moe_num_expert'),
moe_top_k=kwargs.get('moe_top_k'))
def forward(self, dec_inp, r_emb, r_w_bias, r_bias, dec_attn_mask=None, mems=None): def forward(self, dec_inp, r_emb, r_w_bias, r_bias, dec_attn_mask=None, mems=None):
...@@ -869,10 +435,8 @@ class RelLearnableDecoderLayer(nn.Module): ...@@ -869,10 +435,8 @@ class RelLearnableDecoderLayer(nn.Module):
attn_mask=dec_attn_mask, attn_mask=dec_attn_mask,
mems=mems) mems=mems)
output = self.pos_ff(output) output = self.pos_ff(output)
# output, relu_out = self.pos_ff(output)
return output return output
# return output, relu_out
class RelPartialLearnableDecoderLayer(nn.Module): class RelPartialLearnableDecoderLayer(nn.Module):
def __init__(self, n_head, d_model, d_head, d_inner, dropout, def __init__(self, n_head, d_model, d_head, d_inner, dropout,
...@@ -881,8 +445,15 @@ class RelPartialLearnableDecoderLayer(nn.Module): ...@@ -881,8 +445,15 @@ class RelPartialLearnableDecoderLayer(nn.Module):
self.dec_attn = RelPartialLearnableMultiHeadAttn(n_head, d_model, self.dec_attn = RelPartialLearnableMultiHeadAttn(n_head, d_model,
d_head, dropout, **kwargs) d_head, dropout, **kwargs)
self.pos_ff = CustomizedMoEPositionwiseFF(d_model, d_inner, dropout,
pre_lnorm=kwargs.get('pre_lnorm')) if kwargs.get('moe') is False:
self.pos_ff = PositionwiseFF(d_model, d_inner, dropout,
pre_lnorm=kwargs.get('pre_lnorm'))
else:
self.pos_ff = CustomizedMoEPositionwiseFF(d_model, d_inner, dropout,
pre_lnorm=kwargs.get('pre_lnorm'),
moe_num_expert=kwargs.get('moe_num_expert'),
moe_top_k=kwargs.get('moe_top_k'))
def forward(self, dec_inp, r, r_w_bias, r_r_bias, dec_attn_mask=None, mems=None): def forward(self, dec_inp, r, r_w_bias, r_r_bias, dec_attn_mask=None, mems=None):
...@@ -890,10 +461,8 @@ class RelPartialLearnableDecoderLayer(nn.Module): ...@@ -890,10 +461,8 @@ class RelPartialLearnableDecoderLayer(nn.Module):
attn_mask=dec_attn_mask, attn_mask=dec_attn_mask,
mems=mems) mems=mems)
output = self.pos_ff(output) output = self.pos_ff(output)
# output, relu_out = self.pos_ff(output)
return output return output
# return output, relu_out
class AdaptiveEmbedding(nn.Module): class AdaptiveEmbedding(nn.Module):
...@@ -913,25 +482,26 @@ class AdaptiveEmbedding(nn.Module): ...@@ -913,25 +482,26 @@ class AdaptiveEmbedding(nn.Module):
self.cutoff_ends = [0] + self.cutoffs self.cutoff_ends = [0] + self.cutoffs
self.emb_layers = nn.ModuleList() self.emb_layers = nn.ModuleList()
self.emb_projs = nn.ParameterList() self.emb_projs = nn.ModuleList()
if div_val == 1: if div_val == 1:
self.emb_layers.append( self.emb_layers.append(
nn.Embedding(n_token, d_embed, sparse=sample_softmax>0) nn.Embedding(n_token, d_embed, sparse=sample_softmax>0)
) )
if d_proj != d_embed: if d_proj != d_embed:
self.emb_projs.append(nn.Parameter(torch.Tensor(d_proj, d_embed))) self.emb_projs.append(Projection(d_proj, d_embed))
else: else:
for i in range(len(self.cutoffs)): for i in range(len(self.cutoffs)):
l_idx, r_idx = self.cutoff_ends[i], self.cutoff_ends[i+1] l_idx, r_idx = self.cutoff_ends[i], self.cutoff_ends[i+1]
d_emb_i = d_embed // (div_val ** i) d_emb_i = d_embed // (div_val ** i)
self.emb_layers.append(nn.Embedding(r_idx-l_idx, d_emb_i)) self.emb_layers.append(nn.Embedding(r_idx-l_idx, d_emb_i))
self.emb_projs.append(nn.Parameter(torch.Tensor(d_proj, d_emb_i))) self.emb_projs.append(Projection(d_proj, d_emb_i))
def forward(self, inp): def forward(self, inp):
if self.div_val == 1: if self.div_val == 1:
embed = self.emb_layers[0](inp) embed = self.emb_layers[0](inp)
if self.d_proj != self.d_embed: if self.d_proj != self.d_embed:
embed = F.linear(embed, self.emb_projs[0]) embed = F.linear(embed, self.emb_projs[0].weight)
else: else:
param = next(self.parameters()) param = next(self.parameters())
inp_flat = inp.view(-1) inp_flat = inp.view(-1)
...@@ -948,7 +518,7 @@ class AdaptiveEmbedding(nn.Module): ...@@ -948,7 +518,7 @@ class AdaptiveEmbedding(nn.Module):
inp_i = inp_flat.index_select(0, indices_i) - l_idx inp_i = inp_flat.index_select(0, indices_i) - l_idx
emb_i = self.emb_layers[i](inp_i) emb_i = self.emb_layers[i](inp_i)
emb_i = F.linear(emb_i, self.emb_projs[i]) emb_i = F.linear(emb_i, self.emb_projs[i].weight)
emb_flat.index_copy_(0, indices_i, emb_i) emb_flat.index_copy_(0, indices_i, emb_i)
...@@ -965,7 +535,7 @@ class MemTransformerLM(nn.Module): ...@@ -965,7 +535,7 @@ class MemTransformerLM(nn.Module):
tgt_len=None, ext_len=None, mem_len=None, tgt_len=None, ext_len=None, mem_len=None,
cutoffs=[], adapt_inp=False, cutoffs=[], adapt_inp=False,
same_length=False, attn_type=0, clamp_len=-1, same_length=False, attn_type=0, clamp_len=-1,
sample_softmax=-1): sample_softmax=-1, moe=False, moe_num_expert=64, moe_top_k=2):
super(MemTransformerLM, self).__init__() super(MemTransformerLM, self).__init__()
self.n_token = n_token self.n_token = n_token
...@@ -996,7 +566,8 @@ class MemTransformerLM(nn.Module): ...@@ -996,7 +566,8 @@ class MemTransformerLM(nn.Module):
RelPartialLearnableDecoderLayer( RelPartialLearnableDecoderLayer(
n_head, d_model, d_head, d_inner, dropout, n_head, d_model, d_head, d_inner, dropout,
tgt_len=tgt_len, ext_len=ext_len, mem_len=mem_len, tgt_len=tgt_len, ext_len=ext_len, mem_len=mem_len,
dropatt=dropatt, pre_lnorm=pre_lnorm) dropatt=dropatt, pre_lnorm=pre_lnorm,
moe=moe, moe_num_expert=moe_num_expert, moe_top_k=moe_top_k)
) )
elif attn_type == 1: # learnable embeddings elif attn_type == 1: # learnable embeddings
for i in range(n_layer): for i in range(n_layer):
...@@ -1004,14 +575,16 @@ class MemTransformerLM(nn.Module): ...@@ -1004,14 +575,16 @@ class MemTransformerLM(nn.Module):
RelLearnableDecoderLayer( RelLearnableDecoderLayer(
n_head, d_model, d_head, d_inner, dropout, n_head, d_model, d_head, d_inner, dropout,
tgt_len=tgt_len, ext_len=ext_len, mem_len=mem_len, tgt_len=tgt_len, ext_len=ext_len, mem_len=mem_len,
dropatt=dropatt, pre_lnorm=pre_lnorm) dropatt=dropatt, pre_lnorm=pre_lnorm,
moe=moe, moe_num_expert=moe_num_expert, moe_top_k=moe_top_k)
) )
elif attn_type in [2, 3]: # absolute embeddings elif attn_type in [2, 3]: # absolute embeddings
for i in range(n_layer): for i in range(n_layer):
self.layers.append( self.layers.append(
DecoderLayer( DecoderLayer(
n_head, d_model, d_head, d_inner, dropout, n_head, d_model, d_head, d_inner, dropout,
dropatt=dropatt, pre_lnorm=pre_lnorm) dropatt=dropatt, pre_lnorm=pre_lnorm,
moe=moe, moe_num_expert=moe_num_expert, moe_top_k=moe_top_k)
) )
self.sample_softmax = sample_softmax self.sample_softmax = sample_softmax
...@@ -1035,9 +608,9 @@ class MemTransformerLM(nn.Module): ...@@ -1035,9 +608,9 @@ class MemTransformerLM(nn.Module):
if tie_projs: if tie_projs:
for i, tie_proj in enumerate(tie_projs): for i, tie_proj in enumerate(tie_projs):
if tie_proj and div_val == 1 and d_model != d_embed: if tie_proj and div_val == 1 and d_model != d_embed:
self.crit.out_projs[i] = self.word_emb.emb_projs[0] self.crit.out_projs[i].weight = self.word_emb.emb_projs[0].weight
elif tie_proj and div_val != 1: elif tie_proj and div_val != 1:
self.crit.out_projs[i] = self.word_emb.emb_projs[i] self.crit.out_projs[i].weight = self.word_emb.emb_projs[i].weight
self.same_length = same_length self.same_length = same_length
self.clamp_len = clamp_len self.clamp_len = clamp_len
...@@ -1070,12 +643,11 @@ class MemTransformerLM(nn.Module): ...@@ -1070,12 +643,11 @@ class MemTransformerLM(nn.Module):
self.mem_len = mem_len self.mem_len = mem_len
self.ext_len = ext_len self.ext_len = ext_len
def init_mems(self): def init_mems(self, x):
if self.mem_len > 0: if self.mem_len > 0:
mems = [] mems = []
param = next(self.parameters())
for i in range(self.n_layer+1): for i in range(self.n_layer+1):
empty = torch.empty(0, dtype=param.dtype, device=param.device) empty = torch.empty(0, dtype=x.dtype, device=x.device)
mems.append(empty) mems.append(empty)
return mems return mems
...@@ -1126,7 +698,6 @@ class MemTransformerLM(nn.Module): ...@@ -1126,7 +698,6 @@ class MemTransformerLM(nn.Module):
word_emb.new_ones(qlen, klen), diagonal=1+mlen).byte()[:,:,None] word_emb.new_ones(qlen, klen), diagonal=1+mlen).byte()[:,:,None]
hids = [] hids = []
# relu_outs = []
if self.attn_type == 0: # default if self.attn_type == 0: # default
pos_seq = torch.arange(klen-1, -1, -1.0, device=word_emb.device, pos_seq = torch.arange(klen-1, -1, -1.0, device=word_emb.device,
dtype=word_emb.dtype) dtype=word_emb.dtype)
...@@ -1140,11 +711,9 @@ class MemTransformerLM(nn.Module): ...@@ -1140,11 +711,9 @@ class MemTransformerLM(nn.Module):
hids.append(core_out) hids.append(core_out)
for i, layer in enumerate(self.layers): for i, layer in enumerate(self.layers):
mems_i = None if mems is None else mems[i] mems_i = None if mems is None else mems[i]
# core_out, relu_out = layer(core_out, pos_emb, self.r_w_bias,
core_out = layer(core_out, pos_emb, self.r_w_bias, core_out = layer(core_out, pos_emb, self.r_w_bias,
self.r_r_bias, dec_attn_mask=dec_attn_mask, mems=mems_i) self.r_r_bias, dec_attn_mask=dec_attn_mask, mems=mems_i)
hids.append(core_out) hids.append(core_out)
# relu_outs.append(relu_out)
elif self.attn_type == 1: # learnable elif self.attn_type == 1: # learnable
core_out = self.drop(word_emb) core_out = self.drop(word_emb)
hids.append(core_out) hids.append(core_out)
...@@ -1156,11 +725,9 @@ class MemTransformerLM(nn.Module): ...@@ -1156,11 +725,9 @@ class MemTransformerLM(nn.Module):
r_emb, r_bias = self.r_emb[i], self.r_bias[i] r_emb, r_bias = self.r_emb[i], self.r_bias[i]
mems_i = None if mems is None else mems[i] mems_i = None if mems is None else mems[i]
# core_out, relu_out = layer(core_out, r_emb, self.r_w_bias[i],
core_out = layer(core_out, r_emb, self.r_w_bias[i], core_out = layer(core_out, r_emb, self.r_w_bias[i],
r_bias, dec_attn_mask=dec_attn_mask, mems=mems_i) r_bias, dec_attn_mask=dec_attn_mask, mems=mems_i)
hids.append(core_out) hids.append(core_out)
# relu_outs.append(relu_out)
elif self.attn_type == 2: # absolute elif self.attn_type == 2: # absolute
pos_seq = torch.arange(klen - 1, -1, -1.0, device=word_emb.device, pos_seq = torch.arange(klen - 1, -1, -1.0, device=word_emb.device,
dtype=word_emb.dtype) dtype=word_emb.dtype)
...@@ -1175,11 +742,9 @@ class MemTransformerLM(nn.Module): ...@@ -1175,11 +742,9 @@ class MemTransformerLM(nn.Module):
mems_i = None if mems is None else mems[i] mems_i = None if mems is None else mems[i]
if mems_i is not None and i == 0: if mems_i is not None and i == 0:
mems_i += pos_emb[:mlen] mems_i += pos_emb[:mlen]
# core_out, relu_out = layer(core_out, dec_attn_mask=dec_attn_mask,
core_out = layer(core_out, dec_attn_mask=dec_attn_mask, core_out = layer(core_out, dec_attn_mask=dec_attn_mask,
mems=mems_i) mems=mems_i)
hids.append(core_out) hids.append(core_out)
# relu_outs.append(relu_out)
elif self.attn_type == 3: elif self.attn_type == 3:
core_out = self.drop(word_emb) core_out = self.drop(word_emb)
...@@ -1197,31 +762,25 @@ class MemTransformerLM(nn.Module): ...@@ -1197,31 +762,25 @@ class MemTransformerLM(nn.Module):
mems_i += cur_emb.view(mlen, 1, -1) mems_i += cur_emb.view(mlen, 1, -1)
core_out += self.r_emb[i][-qlen:].view(qlen, 1, -1) core_out += self.r_emb[i][-qlen:].view(qlen, 1, -1)
# core_out, relu_out = layer(core_out, dec_attn_mask=dec_attn_mask,
core_out = layer(core_out, dec_attn_mask=dec_attn_mask, core_out = layer(core_out, dec_attn_mask=dec_attn_mask,
mems=mems_i) mems=mems_i)
hids.append(core_out) hids.append(core_out)
# relu_outs.append(relu_out)
core_out = self.drop(core_out) core_out = self.drop(core_out)
new_mems = self._update_mems(hids, mems, mlen, qlen) new_mems = self._update_mems(hids, mems, mlen, qlen)
return core_out, new_mems return core_out, new_mems
# return core_out, new_mems, relu_outs
def forward(self, data, target, *mems): def forward(self, data, target, *mems):
# nn.DataParallel does not allow size(0) tensors to be broadcasted. # nn.DataParallel does not allow size(0) tensors to be broadcasted.
# So, have to initialize size(0) mems inside the model forward. # So, have to initialize size(0) mems inside the model forward.
# Moreover, have to return new_mems to allow nn.DataParallel to piece # Moreover, have to return new_mems to allow nn.DataParallel to piece
# them together. # them together.
if not mems: mems = self.init_mems() if not mems: mems = self.init_mems(data)
tgt_len = target.size(0) tgt_len = target.size(0)
hidden, new_mems = self._forward(data, mems=mems) hidden, new_mems = self._forward(data, mems=mems)
# hidden, new_mems, relu_outs = self._forward(data, mems=mems)
# relu_outs = torch.cat([relu_out.unsqueeze(-1) for relu_out in relu_outs], dim=-1)
pred_hid = hidden[-tgt_len:] pred_hid = hidden[-tgt_len:]
if self.sample_softmax > 0 and self.training: if self.sample_softmax > 0 and self.training:
...@@ -1235,10 +794,8 @@ class MemTransformerLM(nn.Module): ...@@ -1235,10 +794,8 @@ class MemTransformerLM(nn.Module):
if new_mems is None: if new_mems is None:
return [loss] return [loss]
# return [relu_outs, loss]
else: else:
return [loss] + new_mems return [loss] + new_mems
# return [relu_outs, loss] + new_mems
if __name__ == '__main__': if __name__ == '__main__':
import argparse import argparse
......
...@@ -25,6 +25,7 @@ if [[ $1 == 'train' ]]; then ...@@ -25,6 +25,7 @@ if [[ $1 == 'train' ]]; then
--batch_size 22 \ --batch_size 22 \
--multi_gpu \ --multi_gpu \
--gpu0_bsz 4 \ --gpu0_bsz 4 \
--moe --moe-num-expert 64 --moe-top-k 2 \
${@:2} ${@:2}
elif [[ $1 == 'eval' ]]; then elif [[ $1 == 'eval' ]]; then
echo 'Run evaluation...' echo 'Run evaluation...'
......
...@@ -4,7 +4,6 @@ import time ...@@ -4,7 +4,6 @@ import time
import math import math
import os, sys import os, sys
import itertools import itertools
import pathlib
import numpy as np import numpy as np
...@@ -17,31 +16,6 @@ from mem_transformer import MemTransformerLM ...@@ -17,31 +16,6 @@ from mem_transformer import MemTransformerLM
from utils.exp_utils import create_exp_dir from utils.exp_utils import create_exp_dir
from utils.data_parallel import BalancedDataParallel from utils.data_parallel import BalancedDataParallel
class AverageMeter(object):
"""Computes and stores the average and current value.
Examples::
>>> # Initialize a meter to record loss
>>> losses = AverageMeter()
>>> # Update meter after every minibatch update
>>> losses.update(loss_value, batch_size)
"""
def __init__(self):
self.reset()
def reset(self):
self.val = 0
self.avg = 0
self.sum = 0
self.count = 0
def update(self, val, n=1):
self.val = val
self.sum += val * n
self.count += n
self.avg = self.sum / self.count
parser = argparse.ArgumentParser(description='PyTorch Transformer Language Model') parser = argparse.ArgumentParser(description='PyTorch Transformer Language Model')
parser.add_argument('--data', type=str, default='../data/wikitext-103', parser.add_argument('--data', type=str, default='../data/wikitext-103',
help='location of the data corpus') help='location of the data corpus')
...@@ -167,8 +141,15 @@ parser.add_argument('--static-loss-scale', type=float, default=1, ...@@ -167,8 +141,15 @@ parser.add_argument('--static-loss-scale', type=float, default=1,
parser.add_argument('--dynamic-loss-scale', action='store_true', parser.add_argument('--dynamic-loss-scale', action='store_true',
help='Use dynamic loss scaling. If supplied, this argument' help='Use dynamic loss scaling. If supplied, this argument'
' supersedes --static-loss-scale.') ' supersedes --static-loss-scale.')
parser.add_argument('--moe', action='store_true',
help='replace the position-wise FFN with an MoE position-wise FFN')
parser.add_argument('--moe-num-expert', type=int, default=64,
help='number of experts in the MoE')
parser.add_argument('--moe-top-k', type=int, default=2,
help='number of experts selected per token by the MoE hard gate')
args = parser.parse_args() args = parser.parse_args()
args.tied = not args.not_tied args.tied = not args.not_tied
assert args.moe_num_expert >= args.moe_top_k, "moe-num-expert must be >= moe-top-k"
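A quick illustration of the new flags (hypothetical snippet, not part of the commit); it reuses the parser defined above, with all other train.py arguments keeping their defaults.
# Hypothetical sanity check of the new MoE flags.
_args = parser.parse_args(['--data', '../data/enwik8', '--dataset', 'enwik8',
                           '--moe', '--moe-num-expert', '64', '--moe-top-k', '2'])
assert _args.moe is True
assert _args.moe_num_expert >= _args.moe_top_k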
if args.d_embed < 0: if args.d_embed < 0:
args.d_embed = args.d_model args.d_embed = args.d_model
...@@ -305,7 +286,8 @@ else: ...@@ -305,7 +286,8 @@ else:
tie_projs=tie_projs, pre_lnorm=args.pre_lnorm, tgt_len=args.tgt_len, tie_projs=tie_projs, pre_lnorm=args.pre_lnorm, tgt_len=args.tgt_len,
ext_len=args.ext_len, mem_len=args.mem_len, cutoffs=cutoffs, ext_len=args.ext_len, mem_len=args.mem_len, cutoffs=cutoffs,
same_length=args.same_length, attn_type=args.attn_type, same_length=args.same_length, attn_type=args.attn_type,
clamp_len=args.clamp_len, sample_softmax=args.sample_softmax) clamp_len=args.clamp_len, sample_softmax=args.sample_softmax,
moe=args.moe, moe_num_expert=args.moe_num_expert, moe_top_k=args.moe_top_k)
model.apply(weights_init) model.apply(weights_init)
model.word_emb.apply(weights_init) # ensure embedding init is not overridden by out_layer in case of weight sharing model.word_emb.apply(weights_init) # ensure embedding init is not overridden by out_layer in case of weight sharing
args.n_all_param = sum([p.nelement() for p in model.parameters()]) args.n_all_param = sum([p.nelement() for p in model.parameters()])
...@@ -412,9 +394,6 @@ logging('#non emb params = {}'.format(args.n_nonemb_param)) ...@@ -412,9 +394,6 @@ logging('#non emb params = {}'.format(args.n_nonemb_param))
def evaluate(eval_iter): def evaluate(eval_iter):
# Turn on evaluation mode which disables dropout. # Turn on evaluation mode which disables dropout.
model.eval() model.eval()
# avg_nnzs = None
# act_hist = None
# co_act_hist = None
# If the model does not use memory at all, make the ext_len longer. # If the model does not use memory at all, make the ext_len longer.
# Otherwise, make the mem_len longer and keep the ext_len the same. # Otherwise, make the mem_len longer and keep the ext_len the same.
...@@ -434,33 +413,15 @@ def evaluate(eval_iter): ...@@ -434,33 +413,15 @@ def evaluate(eval_iter):
break break
ret = model(data, target, *mems) ret = model(data, target, *mems)
loss, mems = ret[0], ret[1:] loss, mems = ret[0], ret[1:]
# relu_outs, loss, mems = ret[0], ret[1], ret[2:]
loss = loss.mean() loss = loss.mean()
total_loss += seq_len * loss.float().item() total_loss += seq_len * loss.float().item()
total_len += seq_len total_len += seq_len
# acts = [(relu_out > 0).float().cpu() for relu_out in relu_outs]
# if avg_nnzs is None:
# n_layer = len(acts)
# avg_nnzs = [AverageMeter() for i in range(n_layer)]
# d_inner = acts[0].size(-1)
# act_hist = [torch.zeros(d_inner) for i in range(n_layer)]
# co_act_hist = [torch.zeros(d_inner, d_inner) for i in range(n_layer)]
# for i, act in enumerate(acts):
# nnz = act.sum().item() / act.numel()
# avg_nnzs[i].update(nnz)
# act_hist[i] += torch.sum(act, dim=[0, 1])
# co_act = torch.einsum("ija,ijb->ab", (act, act))
# co_act_hist[i] += co_act
# Switch back to the training mode # Switch back to the training mode
model.reset_length(args.tgt_len, args.ext_len, args.mem_len) model.reset_length(args.tgt_len, args.ext_len, args.mem_len)
model.train() model.train()
return total_loss / total_len return total_loss / total_len
# return total_loss / total_len, avg_nnzs, act_hist, co_act_hist
def train(): def train():
...@@ -471,11 +432,6 @@ def train(): ...@@ -471,11 +432,6 @@ def train():
mems = [tuple() for _ in range(args.batch_chunk)] mems = [tuple() for _ in range(args.batch_chunk)]
else: else:
mems = tuple() mems = tuple()
# avg_nnzs = None
# act_hist = None
# co_act_hist = None
train_iter = tr_iter.get_varlen_iter() if args.varlen else tr_iter train_iter = tr_iter.get_varlen_iter() if args.varlen else tr_iter
for batch, (data, target, seq_len) in enumerate(train_iter): for batch, (data, target, seq_len) in enumerate(train_iter):
model.zero_grad() model.zero_grad()
...@@ -487,7 +443,6 @@ def train(): ...@@ -487,7 +443,6 @@ def train():
target_i = target_chunks[i].contiguous() target_i = target_chunks[i].contiguous()
ret = para_model(data_i, target_i, *mems[i]) ret = para_model(data_i, target_i, *mems[i])
loss, mems[i] = ret[0], ret[1:] loss, mems[i] = ret[0], ret[1:]
# relu_outs, loss, mems[i] = ret[0], ret[1], ret[2:]
loss = loss.float().mean().type_as(loss) / args.batch_chunk loss = loss.float().mean().type_as(loss) / args.batch_chunk
if args.fp16: if args.fp16:
optimizer.backward(loss) optimizer.backward(loss)
...@@ -497,28 +452,12 @@ def train(): ...@@ -497,28 +452,12 @@ def train():
else: else:
ret = para_model(data, target, *mems) ret = para_model(data, target, *mems)
loss, mems = ret[0], ret[1:] loss, mems = ret[0], ret[1:]
# relu_outs, loss, mems = ret[0], ret[1], ret[2:]
loss = loss.float().mean().type_as(loss) loss = loss.float().mean().type_as(loss)
if args.fp16: if args.fp16:
optimizer.backward(loss) optimizer.backward(loss)
else: else:
loss.backward() loss.backward()
train_loss += loss.float().item() train_loss += loss.float().item()
# acts = [(relu_out > 0).float().cpu() for relu_out in relu_outs]
# # nnzs = [act.sum().item() / act.numel() for act in acts]
# if avg_nnzs is None:
# n_layer = len(acts)
# avg_nnzs = [AverageMeter() for i in range(n_layer)]
# d_inner = acts[0].size(-1)
# act_hist = [torch.zeros(d_inner) for i in range(n_layer)]
# co_act_hist = [torch.zeros(d_inner, d_inner) for i in range(n_layer)]
# for i, act in enumerate(acts):
# nnz = act.sum().item() / act.numel()
# avg_nnzs[i].update(nnz)
# act_hist[i] += torch.sum(act, dim=[0, 1])
# co_act = torch.einsum("ija,ijb->ab", (act, act))
# co_act_hist[i] += co_act
if args.fp16: if args.fp16:
optimizer.clip_master_grads(args.clip) optimizer.clip_master_grads(args.clip)
...@@ -557,39 +496,12 @@ def train(): ...@@ -557,39 +496,12 @@ def train():
log_str += ' | bpc {:9.5f}'.format(cur_loss / math.log(2)) log_str += ' | bpc {:9.5f}'.format(cur_loss / math.log(2))
else: else:
log_str += ' | ppl {:9.3f}'.format(math.exp(cur_loss)) log_str += ' | ppl {:9.3f}'.format(math.exp(cur_loss))
# final_avg_nnzs = [avg_nnzs[i].avg for i in range(len(avg_nnzs))]
# log_str += ' | avgnnz {:5.2f} | maxnnz {:5.2f}'.format(
# sum(final_avg_nnzs)/len(final_avg_nnzs)*100,
# max(final_avg_nnzs)*100,
# )
logging(log_str) logging(log_str)
# co_act_dir = pathlib.Path(logging.keywords['log_path']).parent.joinpath("co_act")
# co_act_dir.mkdir(parents=True, exist_ok=True)
# co_act_path = co_act_dir.joinpath('epoch_%d_train_step_%d.pt' % (epoch, train_step))
# torch.save(co_act_hist, co_act_path)
# for i in range(len(avg_nnzs)):
# avg_nnzs[i].reset()
# act_hist[i] /= act_hist[i].sum()
# prob, index = torch.topk(act_hist[i], min(1024, act_hist[i].size(-1)))
# log_str = '| layer {:2d} | top 64 prob {:3.2f} | top 128 prob {:3.2f} | top 256 prob {:3.2f} | top 512 prob {:3.2f} | top 1024 prob {:3.2f}'.format(
# i+1,
# prob[:64].sum().item(),
# prob[:128].sum().item(),
# prob[:256].sum().item(),
# prob[:512].sum().item(),
# prob[:1024].sum().item()
# )
# logging(log_str)
# act_hist[i] = 0.
# co_act_hist[i] = 0.
train_loss = 0 train_loss = 0
log_start_time = time.time() log_start_time = time.time()
if train_step % args.eval_interval == 0: if train_step % args.eval_interval == 0:
val_loss = evaluate(va_iter) val_loss = evaluate(va_iter)
# val_loss, eval_avg_nnzs, eval_act_hist, eval_co_act_hist = evaluate(va_iter)
logging('-' * 100) logging('-' * 100)
log_str = '| Eval {:3d} at step {:>8d} | time: {:5.2f}s ' \ log_str = '| Eval {:3d} at step {:>8d} | time: {:5.2f}s ' \
'| valid loss {:5.2f}'.format( '| valid loss {:5.2f}'.format(
...@@ -599,11 +511,6 @@ def train(): ...@@ -599,11 +511,6 @@ def train():
log_str += ' | bpc {:9.5f}'.format(val_loss / math.log(2)) log_str += ' | bpc {:9.5f}'.format(val_loss / math.log(2))
else: else:
log_str += ' | valid ppl {:9.3f}'.format(math.exp(val_loss)) log_str += ' | valid ppl {:9.3f}'.format(math.exp(val_loss))
# final_eval_avg_nnzs = [eval_avg_nnzs[i].avg for i in range(len(eval_avg_nnzs))]
# log_str += ' | avgnnz {:5.2f} | maxnnz {:5.2f}'.format(
# sum(final_eval_avg_nnzs)/len(final_eval_avg_nnzs)*100,
# max(final_eval_avg_nnzs)*100
# )
logging(log_str) logging(log_str)
logging('-' * 100) logging('-' * 100)
# Save the model if the validation loss is the best we've seen so far. # Save the model if the validation loss is the best we've seen so far.
...@@ -653,7 +560,6 @@ para_model = model.to(device) ...@@ -653,7 +560,6 @@ para_model = model.to(device)
# Run on test data. # Run on test data.
test_loss = evaluate(te_iter) test_loss = evaluate(te_iter)
# test_loss, eval_avg_nnzs, eval_act_hist, eval_co_act_hist = evaluate(te_iter)
logging('=' * 100) logging('=' * 100)
if args.dataset in ['enwik8', 'text8']: if args.dataset in ['enwik8', 'text8']:
logging('| End of training | test loss {:5.2f} | test bpc {:9.5f}'.format( logging('| End of training | test loss {:5.2f} | test bpc {:9.5f}'.format(
...@@ -661,11 +567,4 @@ if args.dataset in ['enwik8', 'text8']: ...@@ -661,11 +567,4 @@ if args.dataset in ['enwik8', 'text8']:
else: else:
logging('| End of training | test loss {:5.2f} | test ppl {:9.3f}'.format( logging('| End of training | test loss {:5.2f} | test ppl {:9.3f}'.format(
test_loss, math.exp(test_loss))) test_loss, math.exp(test_loss)))
# final_eval_avg_nnzs = [eval_avg_nnzs[i].avg for i in range(len(eval_avg_nnzs))]
# log_str = ' | avgnnz {:5.2f} | maxnnz {:5.2f}'.format(
# sum(final_eval_avg_nnzs)/len(final_eval_avg_nnzs)*100,
# max(final_eval_avg_nnzs)*100
# )
# logging(log_str)
logging('=' * 100) logging('=' * 100)
...@@ -9,6 +9,10 @@ import torch.nn.functional as F ...@@ -9,6 +9,10 @@ import torch.nn.functional as F
CUDA_MAJOR = int(torch.version.cuda.split('.')[0]) CUDA_MAJOR = int(torch.version.cuda.split('.')[0])
CUDA_MINOR = int(torch.version.cuda.split('.')[1]) CUDA_MINOR = int(torch.version.cuda.split('.')[1])
class Projection(nn.Module):
def __init__(self, out_feat, in_feat):
super(Projection, self).__init__()
self.weight = nn.Parameter(torch.Tensor(out_feat, in_feat))
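Wrapping the raw projection Parameter in this small module lets out_projs/emb_projs become nn.ModuleLists whose .weight can be shared for weight tying. A hypothetical sketch of the pattern (not part of the commit); the initialization shown is illustrative, since the training script is expected to initialize these weights.
# Hypothetical tying sketch: two Projection modules sharing one Parameter.
import torch
import torch.nn.functional as F
a = Projection(512, 256)                      # weight shape: [out_feat=512, in_feat=256]
b = Projection(512, 256)
torch.nn.init.normal_(a.weight, 0.0, 0.02)    # illustrative init only
b.weight = a.weight                           # tie: both modules now share one Parameter
assert b.weight is a.weight
y = F.linear(torch.randn(4, 256), b.weight)   # used exactly like the old raw Parameter
assert y.shape == (4, 512)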
class ProjectedAdaptiveLogSoftmax(nn.Module): class ProjectedAdaptiveLogSoftmax(nn.Module):
def __init__(self, n_token, d_embed, d_proj, cutoffs, div_val=1, def __init__(self, n_token, d_embed, d_proj, cutoffs, div_val=1,
keep_order=False): keep_order=False):
...@@ -31,13 +35,13 @@ class ProjectedAdaptiveLogSoftmax(nn.Module): ...@@ -31,13 +35,13 @@ class ProjectedAdaptiveLogSoftmax(nn.Module):
self.cluster_bias = nn.Parameter(torch.zeros(self.n_clusters)) self.cluster_bias = nn.Parameter(torch.zeros(self.n_clusters))
self.out_layers = nn.ModuleList() self.out_layers = nn.ModuleList()
self.out_projs = nn.ParameterList() self.out_projs = nn.ModuleList()
if div_val == 1: if div_val == 1:
for i in range(len(self.cutoffs)): for i in range(len(self.cutoffs)):
if d_proj != d_embed: if d_proj != d_embed:
self.out_projs.append( self.out_projs.append(
nn.Parameter(torch.Tensor(d_proj, d_embed)) Projection(d_proj, d_embed)
) )
else: else:
self.out_projs.append(None) self.out_projs.append(None)
...@@ -49,7 +53,7 @@ class ProjectedAdaptiveLogSoftmax(nn.Module): ...@@ -49,7 +53,7 @@ class ProjectedAdaptiveLogSoftmax(nn.Module):
d_emb_i = d_embed // (div_val ** i) d_emb_i = d_embed // (div_val ** i)
self.out_projs.append( self.out_projs.append(
nn.Parameter(torch.Tensor(d_proj, d_emb_i)) Projection(d_proj, d_emb_i)
) )
self.out_layers.append(nn.Linear(d_emb_i, r_idx-l_idx)) self.out_layers.append(nn.Linear(d_emb_i, r_idx-l_idx))
...@@ -82,7 +86,7 @@ class ProjectedAdaptiveLogSoftmax(nn.Module): ...@@ -82,7 +86,7 @@ class ProjectedAdaptiveLogSoftmax(nn.Module):
if self.n_clusters == 0: if self.n_clusters == 0:
logit = self._compute_logit(hidden, self.out_layers[0].weight, logit = self._compute_logit(hidden, self.out_layers[0].weight,
self.out_layers[0].bias, self.out_projs[0]) self.out_layers[0].bias, self.out_projs[0].weight if self.out_projs[0] is not None else None)
nll = -F.log_softmax(logit, dim=-1) \ nll = -F.log_softmax(logit, dim=-1) \
.gather(1, target.unsqueeze(1)).squeeze(1) .gather(1, target.unsqueeze(1)).squeeze(1)
else: else:
...@@ -106,7 +110,7 @@ class ProjectedAdaptiveLogSoftmax(nn.Module): ...@@ -106,7 +110,7 @@ class ProjectedAdaptiveLogSoftmax(nn.Module):
weights.append(weight_i) weights.append(weight_i)
biases.append(bias_i) biases.append(bias_i)
head_weight, head_bias, head_proj = weights[0], biases[0], self.out_projs[0] head_weight, head_bias, head_proj = weights[0], biases[0], self.out_projs[0].weight if self.out_projs[0] is not None else None
head_logit = self._compute_logit(hidden, head_weight, head_bias, head_proj) head_logit = self._compute_logit(hidden, head_weight, head_bias, head_proj)
head_logprob = F.log_softmax(head_logit, dim=1) head_logprob = F.log_softmax(head_logit, dim=1)
...@@ -131,7 +135,7 @@ class ProjectedAdaptiveLogSoftmax(nn.Module): ...@@ -131,7 +135,7 @@ class ProjectedAdaptiveLogSoftmax(nn.Module):
if i == 0: if i == 0:
logprob_i = head_logprob_i.gather(1, target_i[:,None]).squeeze(1) logprob_i = head_logprob_i.gather(1, target_i[:,None]).squeeze(1)
else: else:
weight_i, bias_i, proj_i = weights[i], biases[i], self.out_projs[i] weight_i, bias_i, proj_i = weights[i], biases[i], self.out_projs[i].weight if self.out_projs[i] is not None else None
hidden_i = hidden.index_select(0, indices_i) hidden_i = hidden.index_select(0, indices_i)
......
...@@ -49,10 +49,12 @@ class FMoETransformerMLP(FMoE): ...@@ -49,10 +49,12 @@ class FMoETransformerMLP(FMoE):
top_k=2, top_k=2,
do_lnorm=False, do_lnorm=False,
pre_lnorm=False, pre_lnorm=False,
expert_dp_comm='none' expert_dp_comm='none',
dropout=0.1
): ):
super().__init__(num_expert=num_expert, d_model=d_model, gate=gate, super().__init__(num_expert=num_expert, d_model=d_model, gate=gate,
top_k=top_k, world_size=world_size, mp_group=mp_group) top_k=top_k, world_size=world_size, mp_group=mp_group)
self.dropout = nn.Dropout(dropout)
self.experts = _Expert(num_expert, d_model, d_hidden, activation, self.experts = _Expert(num_expert, d_model, d_hidden, activation,
rank=self.mp_rank) rank=self.mp_rank)
self.pre_lnorm = pre_lnorm self.pre_lnorm = pre_lnorm
...@@ -72,7 +74,9 @@ class FMoETransformerMLP(FMoE): ...@@ -72,7 +74,9 @@ class FMoETransformerMLP(FMoE):
inp = inp.reshape(-1, self.d_model) inp = inp.reshape(-1, self.d_model)
if self.pre_lnorm is not None and self.pre_lnorm: if self.pre_lnorm is not None and self.pre_lnorm:
inp = self.layer_norm(inp) inp = self.layer_norm(inp)
output = super().forward(inp) + inp output = super().forward(inp)
output = self.dropout(output)
output += inp
if self.pre_lnorm is not None and not self.pre_lnorm: if self.pre_lnorm is not None and not self.pre_lnorm:
output = self.layer_norm(output) output = self.layer_norm(output)
return output.reshape(original_shape) return output.reshape(original_shape)
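For clarity, a standalone sketch of the ordering the patched forward implements (hypothetical, plain PyTorch with a dense stand-in for the routed expert computation): dropout is applied to the expert output before the residual, and layer norm runs before or after depending on pre_lnorm.
# Hypothetical illustration of the dropout/residual/layer-norm ordering above.
import torch
import torch.nn as nn

d_model = 512
layer_norm, drop = nn.LayerNorm(d_model), nn.Dropout(0.1)
expert_mlp = nn.Linear(d_model, d_model)     # stand-in for super().forward(inp)

def moe_ff(inp, pre_lnorm=False):
    if pre_lnorm:
        inp = layer_norm(inp)                # pre-norm: normalize before the experts
    out = drop(expert_mlp(inp)) + inp        # dropout on the expert output, then the residual
    if not pre_lnorm:
        out = layer_norm(out)                # post-norm: normalize after the residual
    return out

y = moe_ff(torch.randn(16, d_model))
assert y.shape == (16, d_model)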