Commit 37d01e9c authored by Jiezhong Qiu's avatar Jiezhong Qiu
Browse files

multihead ffn

parent cf8a61d8
...@@ -30,6 +30,124 @@ class PositionalEmbedding(nn.Module): ...@@ -30,6 +30,124 @@ class PositionalEmbedding(nn.Module):
else: else:
return pos_emb[:,None,:] return pos_emb[:,None,:]
class MoEPositionwiseFF(nn.Module):
def __init__(self, d_model, d_inner, dropout, pre_lnorm=False, top_k=64):
super(MoEPositionwiseFF, self).__init__()
print("MoEPositionwiseFF")
self.top_k = top_k
self.d_model = d_model
self.d_inner = d_inner
self.dropout = dropout
self.gate = nn.Linear(d_model, d_inner)
self.W2 = nn.Parameter(torch.Tensor(d_inner, d_model))
self.b2 = nn.Parameter(torch.Tensor(d_model))
self.layer_norm = nn.LayerNorm(d_model)
self.pre_lnorm = pre_lnorm
ratio = top_k / d_inner
self.dropout_middle = nn.Dropout(dropout * ratio)
self.dropout_final = nn.Dropout(dropout)
self.reset_parameter()
def reset_parameter(self):
temp_Linear = nn.Linear(self.d_inner, self.d_model)
self.W2.data = temp_Linear.weight.data.transpose(0, 1)
self.b2.data = temp_Linear.bias.data
def forward(self, inp):
residual = inp
if self.pre_lnorm:
inp = self.layer_norm(inp)
gate = self.gate(inp)
gate_top_k_val, gate_top_k_idx = torch.topk(gate, k=self.top_k, dim=-1, largest=True, sorted=False) # [.. x top_k]
relu_out = F.relu(gate_top_k_val)
x = self.dropout_middle(relu_out)
W2_select = self.W2[gate_top_k_idx] # [.. x top_k x d_model]
core_out = torch.einsum('ijk,ijkd->ijd', (x, W2_select)) + self.b2 # [.. x d_model]
core_out = self.dropout_final(core_out)
output = core_out + residual
if not self.pre_lnorm:
output = self.layer_norm(output)
return output
# return output, relu_out.detach()
class MultiHeadPositionwiseFF(nn.Module):
def __init__(self, d_model, d_inner, dropout, pre_lnorm=False, n_head=2):
super(MultiHeadPositionwiseFF, self).__init__()
print("MultiHeadPositionwiseFF")
assert d_model % n_head == 0
self.n_head = n_head
d_head = d_model / n_head
self.d_head = d_head
self.d_model = d_model
self.d_inner = d_inner
self.dropout = dropout
self.q_net = nn.Linear(d_model, d_model)
self.k_weight = nn.Parameter(torch.Tensor(n_head, d_inner, d_head))
self.k_bias = nn.Parameter(torch.Tensor(n_head, d_inner))
self.v_weight = nn.Parameter(torch.Tensor(n_head, d_head, d_inner))
self.v_bias = nn.Parameter(torch.Tensor(n_head, d_head))
self.o_net = nn.Linear(d_model, d_model)
self.layer_norm = nn.LayerNorm(d_model)
self.pre_lnorm = pre_lnorm
self.dropout = nn.Dropout(dropout)
self.reset_parameter()
def reset_parameter(self):
for i in range(self.n_head):
tmp = nn.Linear(self.d_head, self.d_inner)
self.k_weight.data[i] = tmp.weight.data
self.k_bias.data[i] = tmp.bias.data
tmp = nn.Linear(self.d_inner, self.d_head)
self.v_weight.data[i] = tmp.weight.data
self.v_bias.data[i] = tmp.bias.data
def forward(self, inp):
residual = inp
if self.pre_lnorm:
inp = self.layer_norm(inp)
head_q = self.q_net(inp)
head_q = head_q.view(inp.size(0), inp.size(1), self.n_head, self.d_head) # [.. x n_head x d_head]
attn_score = torch.einsum('ibnd,nhd->ibnh', (head_q, self.k_weight)) + self.k_bias # [.. x n_head x d_inner]
attn_score = F.relu(attn_score)
attn_score = self.dropout(attn_score)
attn_vec = torch.einsum('ibnh,ndh->ibnd', (attn_score, self.v_weight)) + self.v_bias
attn_vec = attn_vec.view(inp.size(0), inp.size(1), self.d_model)
core_out = self.o_net(attn_vec)
core_out = self.dropout(core_out)
output = core_out + residual
if not self.pre_lnorm:
output = self.layer_norm(output)
return output
class PositionwiseFF(nn.Module): class PositionwiseFF(nn.Module):
def __init__(self, d_model, d_inner, dropout, pre_lnorm=False): def __init__(self, d_model, d_inner, dropout, pre_lnorm=False):
...@@ -69,7 +187,8 @@ class PositionwiseFF(nn.Module): ...@@ -69,7 +187,8 @@ class PositionwiseFF(nn.Module):
##### residual connection + layer normalization ##### residual connection + layer normalization
output = self.layer_norm(inp + core_out) output = self.layer_norm(inp + core_out)
return output, relu_out.detach() return output
# return output, relu_out.detach()
class ExtendedMultiHeadAttn(nn.Module): class ExtendedMultiHeadAttn(nn.Module):
def __init__(self, n_head, d_model, d_head, dropout, dropatt=0, def __init__(self, n_head, d_model, d_head, dropout, dropatt=0,
...@@ -125,14 +244,14 @@ class ExtendedMultiHeadAttn(nn.Module): ...@@ -125,14 +244,14 @@ class ExtendedMultiHeadAttn(nn.Module):
attn_score.mul_(self.scale) attn_score.mul_(self.scale)
if attn_mask is not None and attn_mask.any().item(): if attn_mask is not None and attn_mask.any().item():
if attn_mask.dim() == 2: if attn_mask.dim() == 2:
attn_score[mem_len:].masked_fill_(attn_mask[None,:,:,None], -float('inf')) attn_score[mem_len:].masked_fill_(attn_mask[None,:,:,None].bool(), -float('inf'))
elif attn_mask.dim() == 3: elif attn_mask.dim() == 3:
attn_score[mem_len:].masked_fill_(attn_mask[:,:,:,None], -float('inf')) attn_score[mem_len:].masked_fill_(attn_mask[:,:,:,None].bool(), -float('inf'))
mem2other_attn = attn_mask.new_ones(mem_len, c.size(0)) mem2other_attn = attn_mask.new_ones(mem_len, c.size(0))
mem2other_attn[:, :mem_len] = 0 mem2other_attn[:, :mem_len] = 0
attn_score[:mem_len].masked_fill_(mem2other_attn[:, :, None, None], -float('inf')) attn_score[:mem_len].masked_fill_(mem2other_attn[:, :, None, None].bool(), -float('inf'))
# [qlen x klen x bsz x n_head] # [qlen x klen x bsz x n_head]
attn_prob = F.softmax(attn_score, dim=1) attn_prob = F.softmax(attn_score, dim=1)
...@@ -211,9 +330,9 @@ class MultiHeadAttn(nn.Module): ...@@ -211,9 +330,9 @@ class MultiHeadAttn(nn.Module):
attn_score.mul_(self.scale) attn_score.mul_(self.scale)
if attn_mask is not None and attn_mask.any().item(): if attn_mask is not None and attn_mask.any().item():
if attn_mask.dim() == 2: if attn_mask.dim() == 2:
attn_score.masked_fill_(attn_mask[None,:,:,None], -float('inf')) attn_score.masked_fill_(attn_mask[None,:,:,None].bool(), -float('inf'))
elif attn_mask.dim() == 3: elif attn_mask.dim() == 3:
attn_score.masked_fill_(attn_mask[:,:,:,None], -float('inf')) attn_score.masked_fill_(attn_mask[:,:,:,None].bool(), -float('inf'))
# [qlen x klen x bsz x n_head] # [qlen x klen x bsz x n_head]
attn_prob = F.softmax(attn_score, dim=1) attn_prob = F.softmax(attn_score, dim=1)
...@@ -358,10 +477,10 @@ class RelPartialLearnableMultiHeadAttn(RelMultiHeadAttn): ...@@ -358,10 +477,10 @@ class RelPartialLearnableMultiHeadAttn(RelMultiHeadAttn):
if attn_mask is not None and attn_mask.any().item(): if attn_mask is not None and attn_mask.any().item():
if attn_mask.dim() == 2: if attn_mask.dim() == 2:
attn_score = attn_score.float().masked_fill( attn_score = attn_score.float().masked_fill(
attn_mask[None,:,:,None], -float('inf')).type_as(attn_score) attn_mask[None,:,:,None].bool(), -float('inf')).type_as(attn_score)
elif attn_mask.dim() == 3: elif attn_mask.dim() == 3:
attn_score = attn_score.float().masked_fill( attn_score = attn_score.float().masked_fill(
attn_mask[:,:,:,None], -float('inf')).type_as(attn_score) attn_mask[:,:,:,None].bool(), -float('inf')).type_as(attn_score)
# [qlen x klen x bsz x n_head] # [qlen x klen x bsz x n_head]
attn_prob = F.softmax(attn_score, dim=1) attn_prob = F.softmax(attn_score, dim=1)
...@@ -444,9 +563,9 @@ class RelLearnableMultiHeadAttn(RelMultiHeadAttn): ...@@ -444,9 +563,9 @@ class RelLearnableMultiHeadAttn(RelMultiHeadAttn):
#### compute attention probability #### compute attention probability
if attn_mask is not None and attn_mask.any().item(): if attn_mask is not None and attn_mask.any().item():
if attn_mask.dim() == 2: if attn_mask.dim() == 2:
attn_score.masked_fill_(attn_mask[None,:,:,None], -float('inf')) attn_score.masked_fill_(attn_mask[None,:,:,None].bool(), -float('inf'))
elif attn_mask.dim() == 3: elif attn_mask.dim() == 3:
attn_score.masked_fill_(attn_mask[:,:,:,None], -float('inf')) attn_score.masked_fill_(attn_mask[:,:,:,None].bool(), -float('inf'))
# [qlen x klen x bsz x n_head] # [qlen x klen x bsz x n_head]
attn_prob = F.softmax(attn_score, dim=1) attn_prob = F.softmax(attn_score, dim=1)
...@@ -478,16 +597,18 @@ class DecoderLayer(nn.Module): ...@@ -478,16 +597,18 @@ class DecoderLayer(nn.Module):
self.dec_attn = MultiHeadAttn(n_head, d_model, d_head, dropout, **kwargs) self.dec_attn = MultiHeadAttn(n_head, d_model, d_head, dropout, **kwargs)
# self.dec_attn = ExtendedMultiHeadAttn(n_head, d_model, d_head, dropout, **kwargs) # self.dec_attn = ExtendedMultiHeadAttn(n_head, d_model, d_head, dropout, **kwargs)
self.pos_ff = PositionwiseFF(d_model, d_inner, dropout, self.pos_ff = MultiHeadPositionwiseFF(d_model, d_inner, dropout,
pre_lnorm=kwargs.get('pre_lnorm')) pre_lnorm=kwargs.get('pre_lnorm'))
def forward(self, dec_inp, dec_attn_mask=None, mems=None): def forward(self, dec_inp, dec_attn_mask=None, mems=None):
output = self.dec_attn(dec_inp, attn_mask=dec_attn_mask, output = self.dec_attn(dec_inp, attn_mask=dec_attn_mask,
mems=mems) mems=mems)
output, relu_out = self.pos_ff(output) output = self.pos_ff(output)
# output, relu_out = self.pos_ff(output)
return output, relu_out return output
# return output, relu_out
class RelLearnableDecoderLayer(nn.Module): class RelLearnableDecoderLayer(nn.Module):
def __init__(self, n_head, d_model, d_head, d_inner, dropout, def __init__(self, n_head, d_model, d_head, d_inner, dropout,
...@@ -496,7 +617,7 @@ class RelLearnableDecoderLayer(nn.Module): ...@@ -496,7 +617,7 @@ class RelLearnableDecoderLayer(nn.Module):
self.dec_attn = RelLearnableMultiHeadAttn(n_head, d_model, d_head, dropout, self.dec_attn = RelLearnableMultiHeadAttn(n_head, d_model, d_head, dropout,
**kwargs) **kwargs)
self.pos_ff = PositionwiseFF(d_model, d_inner, dropout, self.pos_ff = MultiHeadPositionwiseFF(d_model, d_inner, dropout,
pre_lnorm=kwargs.get('pre_lnorm')) pre_lnorm=kwargs.get('pre_lnorm'))
def forward(self, dec_inp, r_emb, r_w_bias, r_bias, dec_attn_mask=None, mems=None): def forward(self, dec_inp, r_emb, r_w_bias, r_bias, dec_attn_mask=None, mems=None):
...@@ -504,9 +625,11 @@ class RelLearnableDecoderLayer(nn.Module): ...@@ -504,9 +625,11 @@ class RelLearnableDecoderLayer(nn.Module):
output = self.dec_attn(dec_inp, r_emb, r_w_bias, r_bias, output = self.dec_attn(dec_inp, r_emb, r_w_bias, r_bias,
attn_mask=dec_attn_mask, attn_mask=dec_attn_mask,
mems=mems) mems=mems)
output, relu_out = self.pos_ff(output) output = self.pos_ff(output)
# output, relu_out = self.pos_ff(output)
return output, relu_out return output
# return output, relu_out
class RelPartialLearnableDecoderLayer(nn.Module): class RelPartialLearnableDecoderLayer(nn.Module):
def __init__(self, n_head, d_model, d_head, d_inner, dropout, def __init__(self, n_head, d_model, d_head, d_inner, dropout,
...@@ -515,7 +638,7 @@ class RelPartialLearnableDecoderLayer(nn.Module): ...@@ -515,7 +638,7 @@ class RelPartialLearnableDecoderLayer(nn.Module):
self.dec_attn = RelPartialLearnableMultiHeadAttn(n_head, d_model, self.dec_attn = RelPartialLearnableMultiHeadAttn(n_head, d_model,
d_head, dropout, **kwargs) d_head, dropout, **kwargs)
self.pos_ff = PositionwiseFF(d_model, d_inner, dropout, self.pos_ff = MultiHeadPositionwiseFF(d_model, d_inner, dropout,
pre_lnorm=kwargs.get('pre_lnorm')) pre_lnorm=kwargs.get('pre_lnorm'))
def forward(self, dec_inp, r, r_w_bias, r_r_bias, dec_attn_mask=None, mems=None): def forward(self, dec_inp, r, r_w_bias, r_r_bias, dec_attn_mask=None, mems=None):
...@@ -523,9 +646,11 @@ class RelPartialLearnableDecoderLayer(nn.Module): ...@@ -523,9 +646,11 @@ class RelPartialLearnableDecoderLayer(nn.Module):
output = self.dec_attn(dec_inp, r, r_w_bias, r_r_bias, output = self.dec_attn(dec_inp, r, r_w_bias, r_r_bias,
attn_mask=dec_attn_mask, attn_mask=dec_attn_mask,
mems=mems) mems=mems)
output, relu_out = self.pos_ff(output) output = self.pos_ff(output)
# output, relu_out = self.pos_ff(output)
return output, relu_out return output
# return output, relu_out
class AdaptiveEmbedding(nn.Module): class AdaptiveEmbedding(nn.Module):
...@@ -758,7 +883,7 @@ class MemTransformerLM(nn.Module): ...@@ -758,7 +883,7 @@ class MemTransformerLM(nn.Module):
word_emb.new_ones(qlen, klen), diagonal=1+mlen).byte()[:,:,None] word_emb.new_ones(qlen, klen), diagonal=1+mlen).byte()[:,:,None]
hids = [] hids = []
relu_outs = [] # relu_outs = []
if self.attn_type == 0: # default if self.attn_type == 0: # default
pos_seq = torch.arange(klen-1, -1, -1.0, device=word_emb.device, pos_seq = torch.arange(klen-1, -1, -1.0, device=word_emb.device,
dtype=word_emb.dtype) dtype=word_emb.dtype)
...@@ -772,10 +897,11 @@ class MemTransformerLM(nn.Module): ...@@ -772,10 +897,11 @@ class MemTransformerLM(nn.Module):
hids.append(core_out) hids.append(core_out)
for i, layer in enumerate(self.layers): for i, layer in enumerate(self.layers):
mems_i = None if mems is None else mems[i] mems_i = None if mems is None else mems[i]
core_out, relu_out = layer(core_out, pos_emb, self.r_w_bias, # core_out, relu_out = layer(core_out, pos_emb, self.r_w_bias,
core_out = layer(core_out, pos_emb, self.r_w_bias,
self.r_r_bias, dec_attn_mask=dec_attn_mask, mems=mems_i) self.r_r_bias, dec_attn_mask=dec_attn_mask, mems=mems_i)
hids.append(core_out) hids.append(core_out)
relu_outs.append(relu_out) # relu_outs.append(relu_out)
elif self.attn_type == 1: # learnable elif self.attn_type == 1: # learnable
core_out = self.drop(word_emb) core_out = self.drop(word_emb)
hids.append(core_out) hids.append(core_out)
...@@ -787,10 +913,11 @@ class MemTransformerLM(nn.Module): ...@@ -787,10 +913,11 @@ class MemTransformerLM(nn.Module):
r_emb, r_bias = self.r_emb[i], self.r_bias[i] r_emb, r_bias = self.r_emb[i], self.r_bias[i]
mems_i = None if mems is None else mems[i] mems_i = None if mems is None else mems[i]
core_out, relu_out = layer(core_out, r_emb, self.r_w_bias[i], # core_out, relu_out = layer(core_out, r_emb, self.r_w_bias[i],
core_out = layer(core_out, r_emb, self.r_w_bias[i],
r_bias, dec_attn_mask=dec_attn_mask, mems=mems_i) r_bias, dec_attn_mask=dec_attn_mask, mems=mems_i)
hids.append(core_out) hids.append(core_out)
relu_outs.append(relu_out) # relu_outs.append(relu_out)
elif self.attn_type == 2: # absolute elif self.attn_type == 2: # absolute
pos_seq = torch.arange(klen - 1, -1, -1.0, device=word_emb.device, pos_seq = torch.arange(klen - 1, -1, -1.0, device=word_emb.device,
dtype=word_emb.dtype) dtype=word_emb.dtype)
...@@ -805,10 +932,11 @@ class MemTransformerLM(nn.Module): ...@@ -805,10 +932,11 @@ class MemTransformerLM(nn.Module):
mems_i = None if mems is None else mems[i] mems_i = None if mems is None else mems[i]
if mems_i is not None and i == 0: if mems_i is not None and i == 0:
mems_i += pos_emb[:mlen] mems_i += pos_emb[:mlen]
core_out, relu_out = layer(core_out, dec_attn_mask=dec_attn_mask, # core_out, relu_out = layer(core_out, dec_attn_mask=dec_attn_mask,
core_out = layer(core_out, dec_attn_mask=dec_attn_mask,
mems=mems_i) mems=mems_i)
hids.append(core_out) hids.append(core_out)
relu_outs.append(relu_out) # relu_outs.append(relu_out)
elif self.attn_type == 3: elif self.attn_type == 3:
core_out = self.drop(word_emb) core_out = self.drop(word_emb)
...@@ -826,16 +954,18 @@ class MemTransformerLM(nn.Module): ...@@ -826,16 +954,18 @@ class MemTransformerLM(nn.Module):
mems_i += cur_emb.view(mlen, 1, -1) mems_i += cur_emb.view(mlen, 1, -1)
core_out += self.r_emb[i][-qlen:].view(qlen, 1, -1) core_out += self.r_emb[i][-qlen:].view(qlen, 1, -1)
core_out, relu_out = layer(core_out, dec_attn_mask=dec_attn_mask, # core_out, relu_out = layer(core_out, dec_attn_mask=dec_attn_mask,
core_out = layer(core_out, dec_attn_mask=dec_attn_mask,
mems=mems_i) mems=mems_i)
hids.append(core_out) hids.append(core_out)
relu_outs.append(relu_out) # relu_outs.append(relu_out)
core_out = self.drop(core_out) core_out = self.drop(core_out)
new_mems = self._update_mems(hids, mems, mlen, qlen) new_mems = self._update_mems(hids, mems, mlen, qlen)
return core_out, new_mems, relu_outs return core_out, new_mems
# return core_out, new_mems, relu_outs
def forward(self, data, target, *mems): def forward(self, data, target, *mems):
# nn.DataParallel does not allow size(0) tensors to be broadcasted. # nn.DataParallel does not allow size(0) tensors to be broadcasted.
...@@ -845,7 +975,8 @@ class MemTransformerLM(nn.Module): ...@@ -845,7 +975,8 @@ class MemTransformerLM(nn.Module):
if not mems: mems = self.init_mems() if not mems: mems = self.init_mems()
tgt_len = target.size(0) tgt_len = target.size(0)
hidden, new_mems, relu_outs = self._forward(data, mems=mems) hidden, new_mems = self._forward(data, mems=mems)
# hidden, new_mems, relu_outs = self._forward(data, mems=mems)
# relu_outs = torch.cat([relu_out.unsqueeze(-1) for relu_out in relu_outs], dim=-1) # relu_outs = torch.cat([relu_out.unsqueeze(-1) for relu_out in relu_outs], dim=-1)
...@@ -860,9 +991,11 @@ class MemTransformerLM(nn.Module): ...@@ -860,9 +991,11 @@ class MemTransformerLM(nn.Module):
loss = loss.view(tgt_len, -1) loss = loss.view(tgt_len, -1)
if new_mems is None: if new_mems is None:
return [relu_outs, loss] return [loss]
# return [relu_outs, loss]
else: else:
return [relu_outs, loss] + new_mems return [loss] + new_mems
# return [relu_outs, loss] + new_mems
if __name__ == '__main__': if __name__ == '__main__':
import argparse import argparse
......
...@@ -4,6 +4,7 @@ import time ...@@ -4,6 +4,7 @@ import time
import math import math
import os, sys import os, sys
import itertools import itertools
import pathlib
import numpy as np import numpy as np
...@@ -411,7 +412,9 @@ logging('#non emb params = {}'.format(args.n_nonemb_param)) ...@@ -411,7 +412,9 @@ logging('#non emb params = {}'.format(args.n_nonemb_param))
def evaluate(eval_iter): def evaluate(eval_iter):
# Turn on evaluation mode which disables dropout. # Turn on evaluation mode which disables dropout.
model.eval() model.eval()
avg_nnzs = None # avg_nnzs = None
# act_hist = None
# co_act_hist = None
# If the model does not use memory at all, make the ext_len longer. # If the model does not use memory at all, make the ext_len longer.
# Otherwise, make the mem_len longer and keep the ext_len the same. # Otherwise, make the mem_len longer and keep the ext_len the same.
...@@ -430,22 +433,34 @@ def evaluate(eval_iter): ...@@ -430,22 +433,34 @@ def evaluate(eval_iter):
if args.max_eval_steps > 0 and i >= args.max_eval_steps: if args.max_eval_steps > 0 and i >= args.max_eval_steps:
break break
ret = model(data, target, *mems) ret = model(data, target, *mems)
# loss, mems = ret[0], ret[1:] loss, mems = ret[0], ret[1:]
relu_outs, loss, mems = ret[0], ret[1], ret[2:] # relu_outs, loss, mems = ret[0], ret[1], ret[2:]
loss = loss.mean() loss = loss.mean()
total_loss += seq_len * loss.float().item() total_loss += seq_len * loss.float().item()
total_len += seq_len total_len += seq_len
nnzs = [(relu_out > 0).sum().float().item() / relu_out.numel() for relu_out in relu_outs]
if avg_nnzs is None: # acts = [(relu_out > 0).float().cpu() for relu_out in relu_outs]
avg_nnzs = [AverageMeter() for i in range(len(nnzs))] # if avg_nnzs is None:
for i in range(len(nnzs)): # n_layer = len(acts)
avg_nnzs[i].update(nnzs[i]) # avg_nnzs = [AverageMeter() for i in range(n_layer)]
# d_inner = acts[0].size(-1)
# act_hist = [torch.zeros(d_inner) for i in range(n_layer)]
# co_act_hist = [torch.zeros(d_inner, d_inner) for i in range(n_layer)]
# for i, act in enumerate(acts):
# nnz = act.sum().item() / act.numel()
# avg_nnzs[i].update(nnz)
# act_hist[i] += torch.sum(act, dim=[0, 1])
# co_act = torch.einsum("ija,ijb->ab", (act, act))
# co_act_hist[i] += co_act
# Switch back to the training mode # Switch back to the training mode
model.reset_length(args.tgt_len, args.ext_len, args.mem_len) model.reset_length(args.tgt_len, args.ext_len, args.mem_len)
model.train() model.train()
return total_loss / total_len, avg_nnzs return total_loss / total_len
# return total_loss / total_len, avg_nnzs, act_hist, co_act_hist
def train(): def train():
...@@ -457,7 +472,9 @@ def train(): ...@@ -457,7 +472,9 @@ def train():
else: else:
mems = tuple() mems = tuple()
avg_nnzs = None # avg_nnzs = None
# act_hist = None
# co_act_hist = None
train_iter = tr_iter.get_varlen_iter() if args.varlen else tr_iter train_iter = tr_iter.get_varlen_iter() if args.varlen else tr_iter
for batch, (data, target, seq_len) in enumerate(train_iter): for batch, (data, target, seq_len) in enumerate(train_iter):
...@@ -469,8 +486,8 @@ def train(): ...@@ -469,8 +486,8 @@ def train():
data_i = data_chunks[i].contiguous() data_i = data_chunks[i].contiguous()
target_i = target_chunks[i].contiguous() target_i = target_chunks[i].contiguous()
ret = para_model(data_i, target_i, *mems[i]) ret = para_model(data_i, target_i, *mems[i])
# loss, mems[i] = ret[0], ret[1:] loss, mems[i] = ret[0], ret[1:]
relu_outs, loss, mems[i] = ret[0], ret[1], ret[2:] # relu_outs, loss, mems[i] = ret[0], ret[1], ret[2:]
loss = loss.float().mean().type_as(loss) / args.batch_chunk loss = loss.float().mean().type_as(loss) / args.batch_chunk
if args.fp16: if args.fp16:
optimizer.backward(loss) optimizer.backward(loss)
...@@ -479,19 +496,29 @@ def train(): ...@@ -479,19 +496,29 @@ def train():
train_loss += loss.float().item() train_loss += loss.float().item()
else: else:
ret = para_model(data, target, *mems) ret = para_model(data, target, *mems)
# loss, mems = ret[0], ret[1:] loss, mems = ret[0], ret[1:]
relu_outs, loss, mems = ret[0], ret[1], ret[2:] # relu_outs, loss, mems = ret[0], ret[1], ret[2:]
loss = loss.float().mean().type_as(loss) loss = loss.float().mean().type_as(loss)
if args.fp16: if args.fp16:
optimizer.backward(loss) optimizer.backward(loss)
else: else:
loss.backward() loss.backward()
train_loss += loss.float().item() train_loss += loss.float().item()
nnzs = [(relu_out > 0).sum().float().item() / relu_out.numel() for relu_out in relu_outs] # acts = [(relu_out > 0).float().cpu() for relu_out in relu_outs]
if avg_nnzs is None: # # nnzs = [act.sum().item() / act.numel() for act in acts]
avg_nnzs = [AverageMeter() for i in range(len(nnzs))] # if avg_nnzs is None:
for i in range(len(nnzs)): # n_layer = len(acts)
avg_nnzs[i].update(nnzs[i]) # avg_nnzs = [AverageMeter() for i in range(n_layer)]
# d_inner = acts[0].size(-1)
# act_hist = [torch.zeros(d_inner) for i in range(n_layer)]
# co_act_hist = [torch.zeros(d_inner, d_inner) for i in range(n_layer)]
# for i, act in enumerate(acts):
# nnz = act.sum().item() / act.numel()
# avg_nnzs[i].update(nnz)
# act_hist[i] += torch.sum(act, dim=[0, 1])
# co_act = torch.einsum("ija,ijb->ab", (act, act))
# co_act_hist[i] += co_act
if args.fp16: if args.fp16:
optimizer.clip_master_grads(args.clip) optimizer.clip_master_grads(args.clip)
...@@ -530,17 +557,39 @@ def train(): ...@@ -530,17 +557,39 @@ def train():
log_str += ' | bpc {:9.5f}'.format(cur_loss / math.log(2)) log_str += ' | bpc {:9.5f}'.format(cur_loss / math.log(2))
else: else:
log_str += ' | ppl {:9.3f}'.format(math.exp(cur_loss)) log_str += ' | ppl {:9.3f}'.format(math.exp(cur_loss))
final_avg_nnzs = [avg_nnzs[i].avg for i in range(len(avg_nnzs))] # final_avg_nnzs = [avg_nnzs[i].avg for i in range(len(avg_nnzs))]
for i in range(len(avg_nnzs)): # log_str += ' | avgnnz {:5.2f} | maxnnz {:5.2f}'.format(
avg_nnzs[i].reset() # sum(final_avg_nnzs)/len(final_avg_nnzs)*100,
log_str += " | avg nnz %.2f | max nnz %.2f" % (sum(final_avg_nnzs)/len(final_avg_nnzs)*100, max(final_avg_nnzs)*100) # max(final_avg_nnzs)*100,
# )
logging(log_str) logging(log_str)
# co_act_dir = pathlib.Path(logging.keywords['log_path']).parent.joinpath("co_act")
# co_act_dir.mkdir(parents=True, exist_ok=True)
# co_act_path = co_act_dir.joinpath('epoch_%d_train_step_%d.pt' % (epoch, train_step))
# torch.save(co_act_hist, co_act_path)
# for i in range(len(avg_nnzs)):
# avg_nnzs[i].reset()
# act_hist[i] /= act_hist[i].sum()
# prob, index = torch.topk(act_hist[i], min(1024, act_hist[i].size(-1)))
# log_str = '| layer {:2d} | top 64 prob {:3.2f} | top 128 prob {:3.2f} | top 256 prob {:3.2f} | top 512 prob {:3.2f} | top 1024 prob {:3.2f}'.format(
# i+1,
# prob[:64].sum().item(),
# prob[:128].sum().item(),
# prob[:256].sum().item(),
# prob[:512].sum().item(),
# prob[:1024].sum().item()
# )
# logging(log_str)
# act_hist[i] = 0.
# co_act_hist[i] = 0.
train_loss = 0 train_loss = 0
log_start_time = time.time() log_start_time = time.time()
if train_step % args.eval_interval == 0: if train_step % args.eval_interval == 0:
val_loss, eval_avg_nnzs = evaluate(va_iter) val_loss = evaluate(va_iter)
# val_loss, eval_avg_nnzs, eval_act_hist, eval_co_act_hist = evaluate(va_iter)
logging('-' * 100) logging('-' * 100)
log_str = '| Eval {:3d} at step {:>8d} | time: {:5.2f}s ' \ log_str = '| Eval {:3d} at step {:>8d} | time: {:5.2f}s ' \
'| valid loss {:5.2f}'.format( '| valid loss {:5.2f}'.format(
...@@ -550,8 +599,11 @@ def train(): ...@@ -550,8 +599,11 @@ def train():
log_str += ' | bpc {:9.5f}'.format(val_loss / math.log(2)) log_str += ' | bpc {:9.5f}'.format(val_loss / math.log(2))
else: else:
log_str += ' | valid ppl {:9.3f}'.format(math.exp(val_loss)) log_str += ' | valid ppl {:9.3f}'.format(math.exp(val_loss))
final_eval_avg_nnzs = [eval_avg_nnzs[i].avg for i in range(len(eval_avg_nnzs))] # final_eval_avg_nnzs = [eval_avg_nnzs[i].avg for i in range(len(eval_avg_nnzs))]
log_str += " | mean nnz %.2f | max nnz %.2f" % (sum(final_eval_avg_nnzs)/len(final_eval_avg_nnzs)*100, max(final_eval_avg_nnzs)*100) # log_str += ' | avgnnz {:5.2f} | maxnnz {:5.2f}'.format(
# sum(final_eval_avg_nnzs)/len(final_eval_avg_nnzs)*100,
# max(final_eval_avg_nnzs)*100
# )
logging(log_str) logging(log_str)
logging('-' * 100) logging('-' * 100)
# Save the model if the validation loss is the best we've seen so far. # Save the model if the validation loss is the best we've seen so far.
...@@ -600,7 +652,8 @@ with open(os.path.join(args.work_dir, 'model.pt'), 'rb') as f: ...@@ -600,7 +652,8 @@ with open(os.path.join(args.work_dir, 'model.pt'), 'rb') as f:
para_model = model.to(device) para_model = model.to(device)
# Run on test data. # Run on test data.
test_loss, eval_avg_nnzs = evaluate(te_iter) test_loss = evaluate(te_iter)
# test_loss, eval_avg_nnzs, eval_act_hist, eval_co_act_hist = evaluate(te_iter)
logging('=' * 100) logging('=' * 100)
if args.dataset in ['enwik8', 'text8']: if args.dataset in ['enwik8', 'text8']:
logging('| End of training | test loss {:5.2f} | test bpc {:9.5f}'.format( logging('| End of training | test loss {:5.2f} | test bpc {:9.5f}'.format(
...@@ -609,6 +662,10 @@ else: ...@@ -609,6 +662,10 @@ else:
logging('| End of training | test loss {:5.2f} | test ppl {:9.3f}'.format( logging('| End of training | test loss {:5.2f} | test ppl {:9.3f}'.format(
test_loss, math.exp(test_loss))) test_loss, math.exp(test_loss)))
final_eval_avg_nnzs = [eval_avg_nnzs[i].avg for i in range(len(eval_avg_nnzs))] # final_eval_avg_nnzs = [eval_avg_nnzs[i].avg for i in range(len(eval_avg_nnzs))]
logging(" | mean nnz %.2f | max nnz %.2f" % (sum(final_eval_avg_nnzs)/len(final_eval_avg_nnzs)*100, max(final_eval_avg_nnzs)*100)) # log_str = ' | avgnnz {:5.2f} | maxnnz {:5.2f}'.format(
# sum(final_eval_avg_nnzs)/len(final_eval_avg_nnzs)*100,
# max(final_eval_avg_nnzs)*100
# )
# logging(log_str)
logging('=' * 100) logging('=' * 100)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment