Commit bdb64914 authored by Jiezhong Qiu

remove unused code

parent 5dc62b41
@@ -7,7 +7,6 @@ import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
# import torch_sparse
sys.path.append('utils')
from proj_adaptive_softmax import ProjectedAdaptiveLogSoftmax, Projection
@@ -32,361 +31,16 @@ class PositionalEmbedding(nn.Module):
return pos_emb[:,None,:]
# A baseline naive slow implementation
class MoEPositionwiseFFRaw(nn.Module):
def __init__(self, d_model, d_inner, dropout, pre_lnorm=False, top_k=64):
super(MoEPositionwiseFFRaw, self).__init__()
print("MoEPositionwiseFF")
self.top_k = top_k
self.d_model = d_model
self.d_inner = d_inner
self.dropout = dropout
self.gate = nn.Linear(d_model, d_inner)
self.W2 = nn.Parameter(torch.Tensor(d_inner, d_model))
self.b2 = nn.Parameter(torch.Tensor(d_model))
self.layer_norm = nn.LayerNorm(d_model)
self.pre_lnorm = pre_lnorm
ratio = top_k / d_inner
self.dropout_middle = nn.Dropout(dropout * ratio)
self.dropout_final = nn.Dropout(dropout)
self.reset_parameter()
def reset_parameter(self):
temp_Linear = nn.Linear(self.d_inner, self.d_model)
self.W2.data = temp_Linear.weight.data.transpose(0, 1)
self.b2.data = temp_Linear.bias.data
def forward(self, inp):
residual = inp
if self.pre_lnorm:
inp = self.layer_norm(inp)
gate = self.gate(inp)
gate_top_k_val, gate_top_k_idx = torch.topk(gate, k=self.top_k, dim=-1, largest=True, sorted=False) # [.. x top_k]
relu_out = F.relu(gate_top_k_val)
x = self.dropout_middle(relu_out)
W2_select = self.W2[gate_top_k_idx] # [.. x top_k x d_model]
core_out = torch.einsum('ijk,ijkd->ijd', (x, W2_select)) + self.b2 # [.. x d_model]
core_out = self.dropout_final(core_out)
output = core_out + residual
if not self.pre_lnorm:
output = self.layer_norm(output)
return output
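# A minimal usage sketch for the class above (hedged; sizes are illustrative and
# the [seq_len x batch x d_model] layout is assumed from the einsums above):
#   ff = MoEPositionwiseFFRaw(d_model=512, d_inner=2048, dropout=0.1, top_k=64)
#   out = ff(torch.randn(32, 4, 512))   # -> [32, 4, 512]
# Only the top_k of the d_inner hidden units are evaluated per position, which is
# why dropout_middle is scaled by top_k / d_inner.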
def my_topk(x, k, inplace=True):
y = x if inplace else x.clone()
top1_val, top1_idx = torch.max(y, dim=-1)
top1_val = top1_val.unsqueeze(-1)
top1_idx = top1_idx.unsqueeze(-1)
if k == 1:
return top1_val, top1_idx
y.scatter_(-1, top1_idx, value=float('-inf'))
top2_val, top2_idx = torch.max(y, dim=-1)
top2_val = top2_val.unsqueeze(-1)
top2_idx = top2_idx.unsqueeze(-1)
top_val = torch.cat((top1_val, top2_val), dim=-1)
top_idx = torch.cat((top1_idx, top2_idx), dim=-1)
return top_val, top_idx
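# my_topk is a hand-rolled top-1/top-2 (two max passes, no full sort). A hedged
# sanity check against torch.topk with illustrative shapes; the two should agree
# up to ties, since both return the k largest values:
#   t = torch.randn(3, 5, 16)
#   val, idx = my_topk(t.clone(), k=2)                     # [3, 5, 2] each
#   ref_val, _ = torch.topk(t, k=2, dim=-1, largest=True)  # sorted descending
#   assert torch.allclose(val, ref_val)
# Note that inplace=True overwrites the argmax entries with -inf, hence t.clone().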
class MultiHeadHierarchicalMoEPositionwiseFF(nn.Module):
def __init__(self, d_model, d_inner, dropout, pre_lnorm=False, n_block=16, top_block=2):
super(MultiHeadHierarchicalMoEPositionwiseFF, self).__init__()
print("MultiHeadHierarchicalMoEPositionwiseFF")
assert d_inner % n_block == 0
assert top_block in [1, 2]
self.top_block = top_block
self.n_block = n_block
d_block = d_inner // n_block
self.d_block = d_block
self.d_model = d_model
self.d_inner = d_inner
self.dropout = dropout
self.block_net_W = nn.Parameter(torch.Tensor(d_model, top_block, n_block))
self.block_net_b = nn.Parameter(torch.Tensor(top_block, n_block))
self.W1 = nn.Parameter(torch.Tensor(n_block, d_block, d_model))
self.b1 = nn.Parameter(torch.Tensor(n_block, d_block))
self.W2 = nn.Parameter(torch.Tensor(n_block, d_block, d_model))
self.b2 = nn.Parameter(torch.Tensor(d_model))
self.layer_norm = nn.LayerNorm(d_model)
self.pre_lnorm = pre_lnorm
ratio = top_block / n_block
self.dropout_middle = nn.Dropout(dropout * ratio)
self.dropout_final = nn.Dropout(dropout)
# self.scale = 1 / (d_model ** 0.5)
self.reset_parameter()
def reset_parameter(self):
temp = nn.Linear(self.d_model, self.d_inner)
self.W1.data = temp.weight.data.view(self.n_block, self.d_block, self.d_model)
self.b1.data = temp.bias.data.view(self.n_block, self.d_block)
temp = nn.Linear(self.d_inner, self.d_model)
self.W2.data = temp.weight.data.transpose(0, 1).contiguous().view(self.n_block, self.d_block, self.d_model)
self.b2.data = temp.bias.data
for i in range(self.top_block):
temp = nn.Linear(self.d_model, self.n_block)
self.block_net_W.data[:, i] = temp.weight.data.transpose(0, 1).contiguous()
self.block_net_b.data[i] = temp.bias.data
def forward(self, inp):
residual = inp
if self.pre_lnorm:
inp = self.layer_norm(inp)
block = torch.einsum("ibd,dan->iban", (inp, self.block_net_W)) + self.block_net_b # [.. x top_block x n_block ]
block_val, block_idx = my_topk(block, k=1, inplace=True)
# block_val, block_idx = torch.topk(block, k=1, dim=-1, largest=True, sorted=False) # [.. x top_k x 1]
block_val = block_val.squeeze(-1)
block_idx = block_idx.squeeze(-1)
gate = F.softmax(block_val, dim=-1)
W1_block = self.W1[block_idx] # [.. x top_k x d_block x d_model]
b1_block = self.b1[block_idx] # [.. x top_k x d_block]
x = torch.einsum('ibd,ibnhd->ibnh', (inp, W1_block)) + b1_block # [.. x top_k x d_block]
# x = x + block_val.unsqueeze(-1) # somehow like residual
x = x * gate.unsqueeze(-1)
relu_out = F.relu(x)
relu_out = self.dropout_middle(relu_out)
W2_block = self.W2[block_idx] # [.. x top_k x d_model]
core_out = torch.einsum('ibnh,ibnhd->ibd', (relu_out, W2_block)) + self.b2 # [.. x d_model]
core_out = self.dropout_final(core_out)
output = core_out + residual
if not self.pre_lnorm:
output = self.layer_norm(output)
return output
class HierarchicalMoEPositionwiseFF(nn.Module):
def __init__(self, d_model, d_inner, dropout, pre_lnorm=False, n_block=16, top_block=2):
super(HierarchicalMoEPositionwiseFF, self).__init__()
print("HierarchicalMoEPositionwiseFF")
assert d_inner % n_block == 0
assert top_block in [1, 2]
self.top_block = top_block
self.n_block = n_block
d_block = d_inner // n_block
self.d_block = d_block
self.d_model = d_model
self.d_inner = d_inner
self.dropout = dropout
self.block_net = nn.Linear(d_model, n_block, bias=True)
self.W1 = nn.Parameter(torch.Tensor(n_block, d_block, d_model))
self.b1 = nn.Parameter(torch.Tensor(n_block, d_block))
self.W2 = nn.Parameter(torch.Tensor(n_block, d_block, d_model))
self.b2 = nn.Parameter(torch.Tensor(d_model))
self.layer_norm = nn.LayerNorm(d_model)
self.pre_lnorm = pre_lnorm
ratio = top_block / n_block
self.dropout_middle = nn.Dropout(dropout * ratio)
self.dropout_final = nn.Dropout(dropout)
# self.scale = 1 / (d_model ** 0.5)
self.reset_parameter()
def reset_parameter(self):
temp = nn.Linear(self.d_model, self.d_inner)
self.W1.data = temp.weight.data.view(self.n_block, self.d_block, self.d_model)
self.b1.data = temp.bias.data.view(self.n_block, self.d_block)
temp = nn.Linear(self.d_inner, self.d_model)
self.W2.data = temp.weight.data.transpose(0, 1).contiguous().view(self.n_block, self.d_block, self.d_model)
self.b2.data = temp.bias.data
def forward(self, inp):
residual = inp
if self.pre_lnorm:
inp = self.layer_norm(inp)
block = self.block_net(inp)
# block_val, block_idx = my_topk(block, k=self.top_block)
block_val, block_idx = torch.topk(block, k=self.top_block, dim=-1, largest=True, sorted=False) # [.. x top_k]
gate = F.softmax(block_val, dim=-1)
W1_block = self.W1[block_idx] # [.. x top_k x d_block x d_model]
b1_block = self.b1[block_idx] # [.. x top_k x d_block]
x = torch.einsum('ibd,ibnhd->ibnh', (inp, W1_block)) + b1_block # [.. x top_k x d_block]
# x = x + block_val.unsqueeze(-1) # somehow like residual
x = x * gate.unsqueeze(-1)
relu_out = F.relu(x)
relu_out = self.dropout_middle(relu_out)
W2_block = self.W2[block_idx] # [.. x top_k x d_model]
core_out = torch.einsum('ibnh,ibnhd->ibd', (relu_out, W2_block)) + self.b2 # [.. x d_model]
core_out = self.dropout_final(core_out)
output = core_out + residual
if not self.pre_lnorm:
output = self.layer_norm(output)
return output
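# Shape walk-through for the hierarchical variant above (hedged, illustrative
# sizes): with d_model=512, d_inner=2048, n_block=16, top_block=2, each position
# is routed to 2 of 16 blocks of size d_block=128, so only top_block/n_block of
# the feed-forward weights are touched per position:
#   ff = HierarchicalMoEPositionwiseFF(512, 2048, dropout=0.1, n_block=16, top_block=2)
#   out = ff(torch.randn(32, 4, 512))   # -> [32, 4, 512]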
class SparsePositionwiseFF(nn.Module):
def __init__(self, d_model, d_inner, dropout, pre_lnorm=False):
super(SparsePositionwiseFF, self).__init__()
print("SparsePositionwiseFF")
self.d_model = d_model
self.d_inner = d_inner
self.dropout = dropout
self.CoreNet_1 = nn.Sequential(
nn.Linear(d_model, d_inner),
nn.ReLU(inplace=True),
nn.Dropout(dropout)
)
self.W2 = nn.Parameter(torch.Tensor(d_inner, d_model))
self.b2 = nn.Parameter(torch.Tensor(d_model))
self.layer_norm = nn.LayerNorm(d_model)
self.pre_lnorm = pre_lnorm
self.dropout_final = nn.Dropout(dropout)
self.reset_parameter()
def reset_parameter(self):
temp_Linear = nn.Linear(self.d_inner, self.d_model)
self.W2.data = temp_Linear.weight.data.transpose(0, 1)
self.b2.data = temp_Linear.bias.data
def forward(self, inp):
residual = inp
if self.pre_lnorm:
inp = self.layer_norm(inp)
relu_out = self.CoreNet_1(inp).view(-1, self.d_inner)
sparse_relu_out = torch_sparse.SparseTensor.from_dense(relu_out)
core_out = torch_sparse.matmul(sparse_relu_out, self.W2) + self.b2
core_out = core_out.view(inp.size(0), inp.size(1), self.d_model)
core_out = self.dropout_final(core_out)
output = core_out + residual
if not self.pre_lnorm:
output = self.layer_norm(output)
return output
class MultiHeadPositionwiseFF(nn.Module):
def __init__(self, d_model, d_inner, dropout, pre_lnorm=False, n_head=2):
super(MultiHeadPositionwiseFF, self).__init__()
print("MultiHeadPositionwiseFF")
assert d_model % n_head == 0
self.n_head = n_head
d_head = d_model // n_head
self.d_head = d_head
self.d_model = d_model
self.d_inner = d_inner
self.dropout = dropout
self.q_net = nn.Linear(d_model, d_model)
self.k_weight = nn.Parameter(torch.Tensor(n_head, d_inner, d_head))
self.k_bias = nn.Parameter(torch.Tensor(n_head, d_inner))
self.v_weight = nn.Parameter(torch.Tensor(n_head, d_head, d_inner))
self.v_bias = nn.Parameter(torch.Tensor(n_head, d_head))
#self.o_net = nn.Linear(d_model, d_model)
self.layer_norm = nn.LayerNorm(d_model)
self.pre_lnorm = pre_lnorm
self.dropout = nn.Dropout(dropout)
self.reset_parameter()
def reset_parameter(self):
for i in range(self.n_head):
tmp = nn.Linear(self.d_head, self.d_inner)
self.k_weight.data[i] = tmp.weight.data
self.k_bias.data[i] = tmp.bias.data
tmp = nn.Linear(self.d_inner, self.d_head)
self.v_weight.data[i] = tmp.weight.data
self.v_bias.data[i] = tmp.bias.data
def forward(self, inp):
residual = inp
if self.pre_lnorm:
inp = self.layer_norm(inp)
head_q = self.q_net(inp)
head_q = head_q.view(inp.size(0), inp.size(1), self.n_head, self.d_head) # [.. x n_head x d_head]
attn_score = torch.einsum('ibnd,nhd->ibnh', (head_q, self.k_weight)) + self.k_bias # [.. x n_head x d_inner]
attn_score = F.relu(attn_score)
attn_score = self.dropout(attn_score)
attn_vec = torch.einsum('ibnh,ndh->ibnd', (attn_score, self.v_weight)) + self.v_bias
attn_vec = attn_vec.contiguous().view(inp.size(0), inp.size(1), self.d_model)
# core_out = self.o_net(attn_vec)
core_out = self.dropout(attn_vec)
output = core_out + residual
if not self.pre_lnorm:
output = self.layer_norm(output)
return output
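# Hedged reading of the block above: q_net projects the input and splits it into
# n_head chunks of size d_head; each chunk gets its own d_head -> d_inner -> d_head
# feed-forward via the einsums. Since n_head * d_head == d_model, the k_/v_ weights
# hold the same parameter count as the plain PositionwiseFF, with q_net adding an
# extra d_model x d_model projection. Illustrative shapes:
#   ff = MultiHeadPositionwiseFF(d_model=512, d_inner=2048, dropout=0.1, n_head=2)
#   out = ff(torch.randn(32, 4, 512))   # -> [32, 4, 512]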
class PositionwiseFF(nn.Module):
def __init__(self, d_model, d_inner, dropout, pre_lnorm=False, use_softmax=True):
def __init__(self, d_model, d_inner, dropout, pre_lnorm=False):
super(PositionwiseFF, self).__init__()
self.d_model = d_model
self.d_inner = d_inner
self.dropout = dropout
self.CoreNet_1 = nn.Sequential(
nn.Linear(d_model, d_inner),
nn.Softmax(dim=-1) if use_softmax else nn.ReLU(inplace=True)
)
self.CoreNet_2 = nn.Sequential(
self.CoreNet = nn.Sequential(
nn.Linear(d_model, d_inner), nn.ReLU(inplace=True),
nn.Dropout(dropout),
nn.Linear(d_inner, d_model),
nn.Dropout(dropout),
@@ -399,113 +53,18 @@ class PositionwiseFF(nn.Module):
def forward(self, inp):
if self.pre_lnorm:
##### layer normalization + positionwise feed-forward
relu_out = self.CoreNet_1(self.layer_norm(inp))
core_out = self.CoreNet_2(relu_out)
core_out = self.CoreNet(self.layer_norm(inp))
##### residual connection
output = core_out + inp
else:
##### positionwise feed-forward
relu_out = self.CoreNet_1(inp)
core_out = self.CoreNet_2(relu_out)
core_out = self.CoreNet(inp)
##### residual connection + layer normalization
output = self.layer_norm(inp + core_out)
return output
# return output, relu_out.detach()
class ExtendedMultiHeadAttn(nn.Module):
def __init__(self, n_head, d_model, d_head, dropout, dropatt=0,
pre_lnorm=False):
super(ExtendedMultiHeadAttn, self).__init__()
print("ExtendedMultiHeadAttn")
self.n_head = n_head
self.d_model = d_model
self.d_head = d_head
self.dropout = dropout
self.q_net = nn.Linear(d_model, n_head * d_head, bias=False)
self.kv_net = nn.Linear(d_model, 2 * n_head * d_head, bias=False)
self.drop = nn.Dropout(dropout)
self.dropatt = nn.Dropout(dropatt)
self.o_net = nn.Linear(n_head * d_head * 2, d_model, bias=False)
self.layer_norm = nn.LayerNorm(d_model)
self.scale = 1 / (d_head ** 0.5)
self.pre_lnorm = pre_lnorm
# self.coeff = nn.Parameter(torch.Tensor(n_head, 2))
# nn.init.uniform_(self.coeff, a=-1, b=1)
def forward(self, h, attn_mask=None, mems=None):
##### multihead attention
# [hlen x bsz x n_head x d_head]
if mems is not None:
c = torch.cat([mems, h], 0)
mem_len = mems.size(0)
else:
c = h
mem_len = 0
if self.pre_lnorm:
##### layer normalization
c = self.layer_norm(c)
head_q = self.q_net(c)
head_k, head_v = torch.chunk(self.kv_net(c), 2, -1)
head_q = head_q.view(c.size(0), c.size(1), self.n_head, self.d_head)
head_k = head_k.view(c.size(0), c.size(1), self.n_head, self.d_head)
head_v = head_v.view(c.size(0), c.size(1), self.n_head, self.d_head)
# [qlen x klen x bsz x n_head]
attn_score = torch.einsum('ibnd,jbnd->ijbn', (head_q, head_k))
attn_score.mul_(self.scale)
if attn_mask is not None and attn_mask.any().item():
if attn_mask.dim() == 2:
attn_score[mem_len:].masked_fill_(attn_mask[None,:,:,None].bool(), -float('inf'))
elif attn_mask.dim() == 3:
attn_score[mem_len:].masked_fill_(attn_mask[:,:,:,None].bool(), -float('inf'))
mem2other_attn = attn_mask.new_ones(mem_len, c.size(0))
mem2other_attn[:, :mem_len] = 0
attn_score[:mem_len].masked_fill_(mem2other_attn[:, :, None, None].bool(), -float('inf'))
# [qlen x klen x bsz x n_head]
attn_prob = F.softmax(attn_score, dim=1)
attn_prob = self.dropatt(attn_prob)
# [qlen x klen x bsz x n_head] + [klen x bsz x n_head x d_head] -> [qlen x bsz x n_head x d_head]
attn_vec = torch.einsum('ijbn,jbnd->ibnd', (attn_prob, head_v))
attn_vec_quad = torch.einsum('ijbn,jbnd->ibnd', (attn_prob, attn_vec))
# [qlen x bsz x n_head x d_head x 2]
attn_vecs = torch.cat([attn_vec.unsqueeze(-1), attn_vec_quad.unsqueeze(-1)], dim=-1)
# attn_vec = torch.einsum('ibndt,nt->ibnd', (attn_vecs, self.coeff))
attn_vec = attn_vecs.contiguous().view(
attn_vec.size(0), attn_vec.size(1), self.n_head * self.d_head * 2)
attn_vec = attn_vec[mem_len:]
##### linear projection
attn_out = self.o_net(attn_vec)
attn_out = self.drop(attn_out)
if self.pre_lnorm:
##### residual connection
output = h + attn_out
else:
##### residual connection + layer normalization
output = self.layer_norm(h + attn_out)
return output
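# Hedged summary of the class above: besides the usual attention output A @ V, it
# also forms a second-order term A @ (A @ V) per head, concatenates the two into a
# [.. x n_head x d_head x 2] tensor, and lets o_net project the doubled vector back
# to d_model; the extra mask block keeps memory positions (as queries) from
# attending to non-memory positions before the [mem_len:] slice is taken.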
class MultiHeadAttn(nn.Module):
def __init__(self, n_head, d_model, d_head, dropout, dropatt=0,
@@ -817,7 +376,6 @@ class RelLearnableMultiHeadAttn(RelMultiHeadAttn):
return output
from fmoe import FMoETransformerMLP
class CustomizedMoEPositionwiseFF(FMoETransformerMLP):
def __init__(self, d_model, d_inner, dropout, pre_lnorm=False, moe_num_expert=64, moe_top_k=2):
@@ -827,7 +385,6 @@ class CustomizedMoEPositionwiseFF(FMoETransformerMLP):
)
super().__init__(num_expert=moe_num_expert, d_model=d_model, d_hidden=d_inner, top_k=moe_top_k,
do_lnorm=True, pre_lnorm=pre_lnorm, activation=activation, dropout=dropout)
#self.dropout = nn.Dropout(dropout)
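# A hedged instantiation sketch (the values are illustrative; the expert MLP itself
# is delegated to fmoe's FMoETransformerMLP via the super().__init__ call above):
#   ff = CustomizedMoEPositionwiseFF(d_model=512, d_inner=2048, dropout=0.1,
#                                    pre_lnorm=False, moe_num_expert=64, moe_top_k=2)
#   out = ff(x)   # x: [seq_len x batch x d_model]; the output keeps the same shape,
#                 # since it feeds the residual stream in DecoderLayer.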
def forward(self, x):
x = super().forward(x)
@@ -838,7 +395,6 @@ class DecoderLayer(nn.Module):
super(DecoderLayer, self).__init__()
self.dec_attn = MultiHeadAttn(n_head, d_model, d_head, dropout, **kwargs)
# self.dec_attn = ExtendedMultiHeadAttn(n_head, d_model, d_head, dropout, **kwargs)
self.pos_ff = CustomizedMoEPositionwiseFF(d_model, d_inner, dropout,
pre_lnorm=kwargs.get('pre_lnorm'),
moe_num_expert=kwargs.get('moe_num_expert'),
@@ -849,10 +405,8 @@ class DecoderLayer(nn.Module):
output = self.dec_attn(dec_inp, attn_mask=dec_attn_mask,
mems=mems)
output = self.pos_ff(output)
# output, relu_out = self.pos_ff(output)
return output
# return output, relu_out
class RelLearnableDecoderLayer(nn.Module):
def __init__(self, n_head, d_model, d_head, d_inner, dropout,
@@ -872,10 +426,8 @@ class RelLearnableDecoderLayer(nn.Module):
attn_mask=dec_attn_mask,
mems=mems)
output = self.pos_ff(output)
# output, relu_out = self.pos_ff(output)
return output
# return output, relu_out
class RelPartialLearnableDecoderLayer(nn.Module):
def __init__(self, n_head, d_model, d_head, d_inner, dropout,
@@ -895,11 +447,8 @@ class RelPartialLearnableDecoderLayer(nn.Module):
attn_mask=dec_attn_mask,
mems=mems)
output = self.pos_ff(output)
# output, relu_out = self.pos_ff(output)
return output
# return output, relu_out
class AdaptiveEmbedding(nn.Module):
@@ -1135,7 +684,6 @@ class MemTransformerLM(nn.Module):
word_emb.new_ones(qlen, klen), diagonal=1+mlen).byte()[:,:,None]
hids = []
# relu_outs = []
if self.attn_type == 0: # default
pos_seq = torch.arange(klen-1, -1, -1.0, device=word_emb.device,
dtype=word_emb.dtype)
@@ -1149,11 +697,9 @@ class MemTransformerLM(nn.Module):
hids.append(core_out)
for i, layer in enumerate(self.layers):
mems_i = None if mems is None else mems[i]
# core_out, relu_out = layer(core_out, pos_emb, self.r_w_bias,
core_out = layer(core_out, pos_emb, self.r_w_bias,
self.r_r_bias, dec_attn_mask=dec_attn_mask, mems=mems_i)
hids.append(core_out)
# relu_outs.append(relu_out)
elif self.attn_type == 1: # learnable
core_out = self.drop(word_emb)
hids.append(core_out)
@@ -1165,11 +711,9 @@ class MemTransformerLM(nn.Module):
r_emb, r_bias = self.r_emb[i], self.r_bias[i]
mems_i = None if mems is None else mems[i]
# core_out, relu_out = layer(core_out, r_emb, self.r_w_bias[i],
core_out = layer(core_out, r_emb, self.r_w_bias[i],
r_bias, dec_attn_mask=dec_attn_mask, mems=mems_i)
hids.append(core_out)
# relu_outs.append(relu_out)
elif self.attn_type == 2: # absolute
pos_seq = torch.arange(klen - 1, -1, -1.0, device=word_emb.device,
dtype=word_emb.dtype)
@@ -1184,11 +728,9 @@ class MemTransformerLM(nn.Module):
mems_i = None if mems is None else mems[i]
if mems_i is not None and i == 0:
mems_i += pos_emb[:mlen]
# core_out, relu_out = layer(core_out, dec_attn_mask=dec_attn_mask,
core_out = layer(core_out, dec_attn_mask=dec_attn_mask,
mems=mems_i)
hids.append(core_out)
# relu_outs.append(relu_out)
elif self.attn_type == 3:
core_out = self.drop(word_emb)
@@ -1206,18 +748,15 @@ class MemTransformerLM(nn.Module):
mems_i += cur_emb.view(mlen, 1, -1)
core_out += self.r_emb[i][-qlen:].view(qlen, 1, -1)
# core_out, relu_out = layer(core_out, dec_attn_mask=dec_attn_mask,
core_out = layer(core_out, dec_attn_mask=dec_attn_mask,
mems=mems_i)
hids.append(core_out)
# relu_outs.append(relu_out)
core_out = self.drop(core_out)
new_mems = self._update_mems(hids, mems, mlen, qlen)
return core_out, new_mems
# return core_out, new_mems, relu_outs
def forward(self, data, target, *mems):
# nn.DataParallel does not allow size(0) tensors to be broadcasted.
@@ -1228,9 +767,6 @@ class MemTransformerLM(nn.Module):
tgt_len = target.size(0)
hidden, new_mems = self._forward(data, mems=mems)
# hidden, new_mems, relu_outs = self._forward(data, mems=mems)
# relu_outs = torch.cat([relu_out.unsqueeze(-1) for relu_out in relu_outs], dim=-1)
pred_hid = hidden[-tgt_len:]
if self.sample_softmax > 0 and self.training:
@@ -1244,10 +780,8 @@ class MemTransformerLM(nn.Module):
if new_mems is None:
return [loss]
# return [relu_outs, loss]
else:
return [loss] + new_mems
# return [relu_outs, loss] + new_mems
if __name__ == '__main__':
import argparse
@@ -4,7 +4,6 @@ import time
import math
import os, sys
import itertools
import pathlib
import numpy as np
@@ -17,31 +16,6 @@ from mem_transformer import MemTransformerLM
from utils.exp_utils import create_exp_dir
from utils.data_parallel import BalancedDataParallel
class AverageMeter(object):
"""Computes and stores the average and current value.
Examples::
>>> # Initialize a meter to record loss
>>> losses = AverageMeter()
>>> # Update meter after every minibatch update
>>> losses.update(loss_value, batch_size)
"""
def __init__(self):
self.reset()
def reset(self):
self.val = 0
self.avg = 0
self.sum = 0
self.count = 0
def update(self, val, n=1):
self.val = val
self.sum += val * n
self.count += n
self.avg = self.sum / self.count
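# A quick worked example of the running mean (illustrative values):
#   m = AverageMeter()
#   m.update(2.0, n=1)   # sum = 2.0,  count = 1
#   m.update(4.0, n=3)   # sum = 14.0, count = 4
#   # m.avg == 3.5, m.val == 4.0 (the most recent value)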
parser = argparse.ArgumentParser(description='PyTorch Transformer Language Model')
parser.add_argument('--data', type=str, default='../data/wikitext-103',
help='location of the data corpus')
@@ -418,9 +392,6 @@ logging('#non emb params = {}'.format(args.n_nonemb_param))
def evaluate(eval_iter):
# Turn on evaluation mode which disables dropout.
model.eval()
# avg_nnzs = None
# act_hist = None
# co_act_hist = None
# If the model does not use memory at all, make the ext_len longer.
# Otherwise, make the mem_len longer and keep the ext_len the same.
@@ -440,33 +411,15 @@ def evaluate(eval_iter):
break
ret = model(data, target, *mems)
loss, mems = ret[0], ret[1:]
# relu_outs, loss, mems = ret[0], ret[1], ret[2:]
loss = loss.mean()
total_loss += seq_len * loss.float().item()
total_len += seq_len
# acts = [(relu_out > 0).float().cpu() for relu_out in relu_outs]
# if avg_nnzs is None:
# n_layer = len(acts)
# avg_nnzs = [AverageMeter() for i in range(n_layer)]
# d_inner = acts[0].size(-1)
# act_hist = [torch.zeros(d_inner) for i in range(n_layer)]
# co_act_hist = [torch.zeros(d_inner, d_inner) for i in range(n_layer)]
# for i, act in enumerate(acts):
# nnz = act.sum().item() / act.numel()
# avg_nnzs[i].update(nnz)
# act_hist[i] += torch.sum(act, dim=[0, 1])
# co_act = torch.einsum("ija,ijb->ab", (act, act))
# co_act_hist[i] += co_act
# Switch back to the training mode
model.reset_length(args.tgt_len, args.ext_len, args.mem_len)
model.train()
return total_loss / total_len
# return total_loss / total_len, avg_nnzs, act_hist, co_act_hist
def train():
@@ -477,11 +430,6 @@ def train():
mems = [tuple() for _ in range(args.batch_chunk)]
else:
mems = tuple()
# avg_nnzs = None
# act_hist = None
# co_act_hist = None
train_iter = tr_iter.get_varlen_iter() if args.varlen else tr_iter
for batch, (data, target, seq_len) in enumerate(train_iter):
model.zero_grad()
@@ -493,7 +441,6 @@ def train():
target_i = target_chunks[i].contiguous()
ret = para_model(data_i, target_i, *mems[i])
loss, mems[i] = ret[0], ret[1:]
# relu_outs, loss, mems[i] = ret[0], ret[1], ret[2:]
loss = loss.float().mean().type_as(loss) / args.batch_chunk
if args.fp16:
optimizer.backward(loss)
@@ -503,28 +450,12 @@ def train():
else:
ret = para_model(data, target, *mems)
loss, mems = ret[0], ret[1:]
# relu_outs, loss, mems = ret[0], ret[1], ret[2:]
loss = loss.float().mean().type_as(loss)
if args.fp16:
optimizer.backward(loss)
else:
loss.backward()
train_loss += loss.float().item()
# acts = [(relu_out > 0).float().cpu() for relu_out in relu_outs]
# # nnzs = [act.sum().item() / act.numel() for act in acts]
# if avg_nnzs is None:
# n_layer = len(acts)
# avg_nnzs = [AverageMeter() for i in range(n_layer)]
# d_inner = acts[0].size(-1)
# act_hist = [torch.zeros(d_inner) for i in range(n_layer)]
# co_act_hist = [torch.zeros(d_inner, d_inner) for i in range(n_layer)]
# for i, act in enumerate(acts):
# nnz = act.sum().item() / act.numel()
# avg_nnzs[i].update(nnz)
# act_hist[i] += torch.sum(act, dim=[0, 1])
# co_act = torch.einsum("ija,ijb->ab", (act, act))
# co_act_hist[i] += co_act
if args.fp16:
optimizer.clip_master_grads(args.clip)
@@ -563,39 +494,12 @@ def train():
log_str += ' | bpc {:9.5f}'.format(cur_loss / math.log(2))
else:
log_str += ' | ppl {:9.3f}'.format(math.exp(cur_loss))
# final_avg_nnzs = [avg_nnzs[i].avg for i in range(len(avg_nnzs))]
# log_str += ' | avgnnz {:5.2f} | maxnnz {:5.2f}'.format(
# sum(final_avg_nnzs)/len(final_avg_nnzs)*100,
# max(final_avg_nnzs)*100,
# )
logging(log_str)
# co_act_dir = pathlib.Path(logging.keywords['log_path']).parent.joinpath("co_act")
# co_act_dir.mkdir(parents=True, exist_ok=True)
# co_act_path = co_act_dir.joinpath('epoch_%d_train_step_%d.pt' % (epoch, train_step))
# torch.save(co_act_hist, co_act_path)
# for i in range(len(avg_nnzs)):
# avg_nnzs[i].reset()
# act_hist[i] /= act_hist[i].sum()
# prob, index = torch.topk(act_hist[i], min(1024, act_hist[i].size(-1)))
# log_str = '| layer {:2d} | top 64 prob {:3.2f} | top 128 prob {:3.2f} | top 256 prob {:3.2f} | top 512 prob {:3.2f} | top 1024 prob {:3.2f}'.format(
# i+1,
# prob[:64].sum().item(),
# prob[:128].sum().item(),
# prob[:256].sum().item(),
# prob[:512].sum().item(),
# prob[:1024].sum().item()
# )
# logging(log_str)
# act_hist[i] = 0.
# co_act_hist[i] = 0.
train_loss = 0
log_start_time = time.time()
if train_step % args.eval_interval == 0:
val_loss = evaluate(va_iter)
# val_loss, eval_avg_nnzs, eval_act_hist, eval_co_act_hist = evaluate(va_iter)
logging('-' * 100)
log_str = '| Eval {:3d} at step {:>8d} | time: {:5.2f}s ' \
'| valid loss {:5.2f}'.format(
@@ -605,11 +509,6 @@ def train():
log_str += ' | bpc {:9.5f}'.format(val_loss / math.log(2))
else:
log_str += ' | valid ppl {:9.3f}'.format(math.exp(val_loss))
# final_eval_avg_nnzs = [eval_avg_nnzs[i].avg for i in range(len(eval_avg_nnzs))]
# log_str += ' | avgnnz {:5.2f} | maxnnz {:5.2f}'.format(
# sum(final_eval_avg_nnzs)/len(final_eval_avg_nnzs)*100,
# max(final_eval_avg_nnzs)*100
# )
logging(log_str)
logging('-' * 100)
# Save the model if the validation loss is the best we've seen so far.
@@ -659,7 +558,6 @@ para_model = model.to(device)
# Run on test data.
test_loss = evaluate(te_iter)
# test_loss, eval_avg_nnzs, eval_act_hist, eval_co_act_hist = evaluate(te_iter)
logging('=' * 100)
if args.dataset in ['enwik8', 'text8']:
logging('| End of training | test loss {:5.2f} | test bpc {:9.5f}'.format(
@@ -667,11 +565,4 @@ if args.dataset in ['enwik8', 'text8']:
else:
logging('| End of training | test loss {:5.2f} | test ppl {:9.3f}'.format(
test_loss, math.exp(test_loss)))
# final_eval_avg_nnzs = [eval_avg_nnzs[i].avg for i in range(len(eval_avg_nnzs))]
# log_str = ' | avgnnz {:5.2f} | maxnnz {:5.2f}'.format(
# sum(final_eval_avg_nnzs)/len(final_eval_avg_nnzs)*100,
# max(final_eval_avg_nnzs)*100
# )
# logging(log_str)
logging('=' * 100)