import sys
import math
import functools

import numpy as np

import torch
import torch.nn as nn
import torch.nn.functional as F
# import torch_sparse

sys.path.append('utils')
from proj_adaptive_softmax import ProjectedAdaptiveLogSoftmax
from log_uniform_sampler import LogUniformSampler, sample_logits


class PositionalEmbedding(nn.Module):
    def __init__(self, demb):
        super(PositionalEmbedding, self).__init__()

        self.demb = demb

        inv_freq = 1 / (10000 ** (torch.arange(0.0, demb, 2.0) / demb))
        self.register_buffer('inv_freq', inv_freq)

    def forward(self, pos_seq, bsz=None):
        sinusoid_inp = torch.ger(pos_seq, self.inv_freq)
        pos_emb = torch.cat([sinusoid_inp.sin(), sinusoid_inp.cos()], dim=-1)

        if bsz is not None:
            return pos_emb[:, None, :].expand(-1, bsz, -1)
        else:
            return pos_emb[:, None, :]


class MoEPositionwiseFF(nn.Module):
    def __init__(self, d_model, d_inner, dropout, pre_lnorm=False, top_k=64):
        super(MoEPositionwiseFF, self).__init__()
        print("MoEPositionwiseFF")

        self.top_k = top_k
        self.d_model = d_model
        self.d_inner = d_inner
        self.dropout = dropout

        self.gate = nn.Linear(d_model, d_inner)
        self.W2 = nn.Parameter(torch.Tensor(d_inner, d_model))
        self.b2 = nn.Parameter(torch.Tensor(d_model))

        self.layer_norm = nn.LayerNorm(d_model)
        self.pre_lnorm = pre_lnorm

        ratio = top_k / d_inner
        self.dropout_middle = nn.Dropout(dropout * ratio)
        self.dropout_final = nn.Dropout(dropout)

        self.reset_parameter()

    def reset_parameter(self):
        temp_Linear = nn.Linear(self.d_inner, self.d_model)
        self.W2.data = temp_Linear.weight.data.transpose(0, 1)
        self.b2.data = temp_Linear.bias.data

    def forward(self, inp):
        residual = inp
        if self.pre_lnorm:
            inp = self.layer_norm(inp)

        gate = self.gate(inp)
        gate_top_k_val, gate_top_k_idx = torch.topk(
            gate, k=self.top_k, dim=-1, largest=True, sorted=False)  # [.. x top_k]

        relu_out = F.relu(gate_top_k_val)
        x = self.dropout_middle(relu_out)

        W2_select = self.W2[gate_top_k_idx]  # [.. x top_k x d_model]
        core_out = torch.einsum('ijk,ijkd->ijd', (x, W2_select)) + self.b2  # [.. x d_model]
        core_out = self.dropout_final(core_out)

        output = core_out + residual
        if not self.pre_lnorm:
            output = self.layer_norm(output)

        return output
        # return output, relu_out.detach()
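
# Illustrative sketch (not called anywhere in this module): a quick shape check
# for MoEPositionwiseFF under assumed toy sizes. Only the top_k of the d_inner
# hidden units survive the routing, so the gathered rows of W2 have shape
# [qlen x bsz x top_k x d_model].
def _demo_moe_ff_shapes():
    qlen, bsz, d_model, d_inner, top_k = 8, 2, 16, 64, 4  # assumed toy sizes
    ff = MoEPositionwiseFF(d_model, d_inner, dropout=0.0, top_k=top_k)
    out = ff(torch.randn(qlen, bsz, d_model))
    assert out.shape == (qlen, bsz, d_model)  # residual + layer norm keep the shape
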

def my_topk(x, k, inplace=True):
    y = x if inplace else x.clone()

    top1_val, top1_idx = torch.max(y, dim=-1)
    top1_val = top1_val.unsqueeze(-1)
    top1_idx = top1_idx.unsqueeze(-1)
    if k == 1:
        return top1_val, top1_idx

    y.scatter_(-1, top1_idx, value=float('-inf'))
    top2_val, top2_idx = torch.max(y, dim=-1)
    top2_val = top2_val.unsqueeze(-1)
    top2_idx = top2_idx.unsqueeze(-1)

    top_val = torch.cat((top1_val, top2_val), dim=-1)
    top_idx = torch.cat((top1_idx, top2_idx), dim=-1)
    return top_val, top_idx


class MultiHeadHierarchicalMoEPositionwiseFF(nn.Module):
    def __init__(self, d_model, d_inner, dropout, pre_lnorm=False,
                 n_block=16, top_block=2):
        super(MultiHeadHierarchicalMoEPositionwiseFF, self).__init__()
        print("MultiHeadHierarchicalMoEPositionwiseFF")

        assert d_inner % n_block == 0
        assert top_block in [1, 2]
        self.top_block = top_block
        self.n_block = n_block
        d_block = d_inner // n_block
        self.d_block = d_block

        self.d_model = d_model
        self.d_inner = d_inner
        self.dropout = dropout

        self.block_net_W = nn.Parameter(torch.Tensor(d_model, top_block, n_block))
        self.block_net_b = nn.Parameter(torch.Tensor(top_block, n_block))
        self.W1 = nn.Parameter(torch.Tensor(n_block, d_block, d_model))
        self.b1 = nn.Parameter(torch.Tensor(n_block, d_block))
        self.W2 = nn.Parameter(torch.Tensor(n_block, d_block, d_model))
        self.b2 = nn.Parameter(torch.Tensor(d_model))

        self.layer_norm = nn.LayerNorm(d_model)
        self.pre_lnorm = pre_lnorm

        ratio = top_block / n_block
        self.dropout_middle = nn.Dropout(dropout * ratio)
        self.dropout_final = nn.Dropout(dropout)

        # self.scale = 1 / (d_model ** 0.5)

        self.reset_parameter()

    def reset_parameter(self):
        temp = nn.Linear(self.d_model, self.d_inner)
        self.W1.data = temp.weight.data.view(self.n_block, self.d_block, self.d_model)
        self.b1.data = temp.bias.data.view(self.n_block, self.d_block)

        temp = nn.Linear(self.d_inner, self.d_model)
        self.W2.data = temp.weight.data.transpose(0, 1).contiguous().view(
            self.n_block, self.d_block, self.d_model)
        self.b2.data = temp.bias.data

        for i in range(self.top_block):
            temp = nn.Linear(self.d_model, self.n_block)
            self.block_net_W.data[:, i] = temp.weight.data.transpose(0, 1).contiguous()
            self.block_net_b.data[i] = temp.bias.data

    def forward(self, inp):
        residual = inp
        if self.pre_lnorm:
            inp = self.layer_norm(inp)

        block = (torch.einsum("ibd,dan->iban", (inp, self.block_net_W))
                 + self.block_net_b)  # [.. x top_block x n_block]
        block_val, block_idx = my_topk(block, k=1, inplace=True)
        # block_val, block_idx = torch.topk(block, k=1, dim=-1, largest=True, sorted=False)  # [.. x top_k x 1]
        block_val = block_val.squeeze(-1)
        block_idx = block_idx.squeeze(-1)
        gate = F.softmax(block_val, dim=-1)

        W1_block = self.W1[block_idx]  # [.. x top_k x d_block x d_model]
        b1_block = self.b1[block_idx]  # [.. x top_k x d_block]
        x = torch.einsum('ibd,ibnhd->ibnh', (inp, W1_block)) + b1_block  # [.. x top_k x d_block]
        # x = x + block_val.unsqueeze(-1)  # somehow like residual
        x = x * gate.unsqueeze(-1)
        relu_out = F.relu(x)
        relu_out = self.dropout_middle(relu_out)

        W2_block = self.W2[block_idx]  # [.. x top_k x d_block x d_model]
        core_out = torch.einsum('ibnh,ibnhd->ibd', (relu_out, W2_block)) + self.b2  # [.. x d_model]
        core_out = self.dropout_final(core_out)

        output = core_out + residual
        if not self.pre_lnorm:
            output = self.layer_norm(output)

        return output
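
# Illustrative sketch (not called anywhere in this module): shape walk-through
# for the block-routed FF above, using assumed toy sizes. Each of the top_block
# routing heads picks one of n_block expert blocks per position, so block_idx is
# [qlen x bsz x top_block] and the gathered W1 slices are
# [qlen x bsz x top_block x d_block x d_model].
def _demo_multihead_hierarchical_moe_shapes():
    qlen, bsz, d_model, d_inner = 4, 3, 16, 64  # assumed toy sizes
    ff = MultiHeadHierarchicalMoEPositionwiseFF(d_model, d_inner, dropout=0.0,
                                                n_block=8, top_block=2)
    out = ff(torch.randn(qlen, bsz, d_model))
    assert out.shape == (qlen, bsz, d_model)
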

class HierarchicalMoEPositionwiseFF(nn.Module):
    def __init__(self, d_model, d_inner, dropout, pre_lnorm=False,
                 n_block=16, top_block=2):
        super(HierarchicalMoEPositionwiseFF, self).__init__()
        print("HierarchicalMoEPositionwiseFF")

        assert d_inner % n_block == 0
        assert top_block in [1, 2]
        self.top_block = top_block
        self.n_block = n_block
        d_block = d_inner // n_block
        self.d_block = d_block

        self.d_model = d_model
        self.d_inner = d_inner
        self.dropout = dropout

        self.block_net = nn.Linear(d_model, n_block, bias=True)
        self.W1 = nn.Parameter(torch.Tensor(n_block, d_block, d_model))
        self.b1 = nn.Parameter(torch.Tensor(n_block, d_block))
        self.W2 = nn.Parameter(torch.Tensor(n_block, d_block, d_model))
        self.b2 = nn.Parameter(torch.Tensor(d_model))

        self.layer_norm = nn.LayerNorm(d_model)
        self.pre_lnorm = pre_lnorm

        ratio = top_block / n_block
        self.dropout_middle = nn.Dropout(dropout * ratio)
        self.dropout_final = nn.Dropout(dropout)

        # self.scale = 1 / (d_model ** 0.5)

        self.reset_parameter()

    def reset_parameter(self):
        temp = nn.Linear(self.d_model, self.d_inner)
        self.W1.data = temp.weight.data.view(self.n_block, self.d_block, self.d_model)
        self.b1.data = temp.bias.data.view(self.n_block, self.d_block)

        temp = nn.Linear(self.d_inner, self.d_model)
        self.W2.data = temp.weight.data.transpose(0, 1).contiguous().view(
            self.n_block, self.d_block, self.d_model)
        self.b2.data = temp.bias.data

    def forward(self, inp):
        residual = inp
        if self.pre_lnorm:
            inp = self.layer_norm(inp)

        block = self.block_net(inp)
        # block_val, block_idx = my_topk(block, k=self.top_block)
        block_val, block_idx = torch.topk(
            block, k=self.top_block, dim=-1, largest=True, sorted=False)  # [.. x top_k]
        gate = F.softmax(block_val, dim=-1)

        W1_block = self.W1[block_idx]  # [.. x top_k x d_block x d_model]
        b1_block = self.b1[block_idx]  # [.. x top_k x d_block]
        x = torch.einsum('ibd,ibnhd->ibnh', (inp, W1_block)) + b1_block  # [.. x top_k x d_block]
        # x = x + block_val.unsqueeze(-1)  # somehow like residual
        x = x * gate.unsqueeze(-1)
        relu_out = F.relu(x)
        relu_out = self.dropout_middle(relu_out)

        W2_block = self.W2[block_idx]  # [.. x top_k x d_block x d_model]
        core_out = torch.einsum('ibnh,ibnhd->ibd', (relu_out, W2_block)) + self.b2  # [.. x d_model]
        core_out = self.dropout_final(core_out)

        output = core_out + residual
        if not self.pre_lnorm:
            output = self.layer_norm(output)

        return output
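
# Illustrative sketch (not called anywhere in this module): the hierarchical MoE
# keeps the same total W1/W2 parameter count as a dense d_model <-> d_inner
# feed-forward; the routing only changes how many of those parameters are
# touched per token (top_block of the n_block expert blocks). Assumed toy sizes.
def _demo_hierarchical_moe_param_count():
    d_model, d_inner, n_block = 16, 64, 8  # assumed toy sizes
    ff = HierarchicalMoEPositionwiseFF(d_model, d_inner, dropout=0.0,
                                       n_block=n_block, top_block=2)
    dense_weights = 2 * d_model * d_inner        # W1 and W2 of a dense FF
    moe_weights = ff.W1.numel() + ff.W2.numel()  # the stacked expert blocks
    assert moe_weights == dense_weights
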

class SparsePositionwiseFF(nn.Module):
    def __init__(self, d_model, d_inner, dropout, pre_lnorm=False):
        super(SparsePositionwiseFF, self).__init__()
        print("SparsePositionwiseFF")

        self.d_model = d_model
        self.d_inner = d_inner
        self.dropout = dropout

        self.CoreNet_1 = nn.Sequential(
            nn.Linear(d_model, d_inner),
            nn.ReLU(inplace=True),
            nn.Dropout(dropout)
        )
        self.W2 = nn.Parameter(torch.Tensor(d_inner, d_model))
        self.b2 = nn.Parameter(torch.Tensor(d_model))

        self.layer_norm = nn.LayerNorm(d_model)
        self.pre_lnorm = pre_lnorm
        self.dropout_final = nn.Dropout(dropout)

        self.reset_parameter()

    def reset_parameter(self):
        temp_Linear = nn.Linear(self.d_inner, self.d_model)
        self.W2.data = temp_Linear.weight.data.transpose(0, 1)
        self.b2.data = temp_Linear.bias.data

    def forward(self, inp):
        residual = inp
        if self.pre_lnorm:
            inp = self.layer_norm(inp)

        relu_out = self.CoreNet_1(inp).view(-1, self.d_inner)
        sparse_relu_out = torch_sparse.SparseTensor.from_dense(relu_out)
        core_out = torch_sparse.matmul(sparse_relu_out, self.W2) + self.b2
        core_out = core_out.view(inp.size(0), inp.size(1), self.d_model)
        core_out = self.dropout_final(core_out)

        output = core_out + residual
        if not self.pre_lnorm:
            output = self.layer_norm(output)

        return output
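
# Illustrative sketch (not called anywhere in this module): SparsePositionwiseFF
# relies on the optional torch_sparse package, whose import is commented out at
# the top of this file. The sparse product is expected to reproduce the dense
# product on the post-ReLU activations; this check uses only the two
# torch_sparse calls that the module itself uses, with assumed toy sizes.
def _demo_sparse_ff_dense_equivalence():
    import torch_sparse  # optional dependency; only needed for this path
    relu_out = F.relu(torch.randn(6, 10))  # [tokens x d_inner], many zeros after ReLU
    W2 = torch.randn(10, 4)                # [d_inner x d_model]
    sparse_relu_out = torch_sparse.SparseTensor.from_dense(relu_out)
    dense_ref = relu_out @ W2
    assert torch.allclose(torch_sparse.matmul(sparse_relu_out, W2), dense_ref, atol=1e-6)
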

class MultiHeadPositionwiseFF(nn.Module):
    def __init__(self, d_model, d_inner, dropout, pre_lnorm=False, n_head=2):
        super(MultiHeadPositionwiseFF, self).__init__()
        print("MultiHeadPositionwiseFF")

        assert d_model % n_head == 0
        self.n_head = n_head
        d_head = d_model // n_head
        self.d_head = d_head

        self.d_model = d_model
        self.d_inner = d_inner
        self.dropout = dropout

        self.q_net = nn.Linear(d_model, d_model)
        self.k_weight = nn.Parameter(torch.Tensor(n_head, d_inner, d_head))
        self.k_bias = nn.Parameter(torch.Tensor(n_head, d_inner))
        self.v_weight = nn.Parameter(torch.Tensor(n_head, d_head, d_inner))
        self.v_bias = nn.Parameter(torch.Tensor(n_head, d_head))
        # self.o_net = nn.Linear(d_model, d_model)

        self.layer_norm = nn.LayerNorm(d_model)
        self.pre_lnorm = pre_lnorm
        self.dropout = nn.Dropout(dropout)

        self.reset_parameter()

    def reset_parameter(self):
        for i in range(self.n_head):
            tmp = nn.Linear(self.d_head, self.d_inner)
            self.k_weight.data[i] = tmp.weight.data
            self.k_bias.data[i] = tmp.bias.data

            tmp = nn.Linear(self.d_inner, self.d_head)
            self.v_weight.data[i] = tmp.weight.data
            self.v_bias.data[i] = tmp.bias.data

    def forward(self, inp):
        residual = inp
        if self.pre_lnorm:
            inp = self.layer_norm(inp)

        head_q = self.q_net(inp)
        head_q = head_q.view(inp.size(0), inp.size(1), self.n_head, self.d_head)  # [.. x n_head x d_head]

        attn_score = torch.einsum('ibnd,nhd->ibnh', (head_q, self.k_weight)) + self.k_bias  # [.. x n_head x d_inner]
        attn_score = F.relu(attn_score)
        attn_score = self.dropout(attn_score)

        attn_vec = torch.einsum('ibnh,ndh->ibnd', (attn_score, self.v_weight)) + self.v_bias
        attn_vec = attn_vec.contiguous().view(inp.size(0), inp.size(1), self.d_model)

        # core_out = self.o_net(attn_vec)
        core_out = self.dropout(attn_vec)

        output = core_out + residual
        if not self.pre_lnorm:
            output = self.layer_norm(output)

        return output


class PositionwiseFF(nn.Module):
    def __init__(self, d_model, d_inner, dropout, pre_lnorm=False, use_softmax=True):
        super(PositionwiseFF, self).__init__()

        self.d_model = d_model
        self.d_inner = d_inner
        self.dropout = dropout

        self.CoreNet_1 = nn.Sequential(
            nn.Linear(d_model, d_inner),
            nn.Softmax(dim=-1) if use_softmax else nn.ReLU(inplace=True)
        )
        self.CoreNet_2 = nn.Sequential(
            nn.Dropout(dropout),
            nn.Linear(d_inner, d_model),
            nn.Dropout(dropout),
        )

        self.layer_norm = nn.LayerNorm(d_model)
        self.pre_lnorm = pre_lnorm

    def forward(self, inp):
        if self.pre_lnorm:
            ##### layer normalization + positionwise feed-forward
            relu_out = self.CoreNet_1(self.layer_norm(inp))
            core_out = self.CoreNet_2(relu_out)

            ##### residual connection
            output = core_out + inp
        else:
            ##### positionwise feed-forward
            relu_out = self.CoreNet_1(inp)
            core_out = self.CoreNet_2(relu_out)

            ##### residual connection + layer normalization
            output = self.layer_norm(inp + core_out)

        return output
        # return output, relu_out.detach()


class ExtendedMultiHeadAttn(nn.Module):
    def __init__(self, n_head, d_model, d_head, dropout, dropatt=0, pre_lnorm=False):
        super(ExtendedMultiHeadAttn, self).__init__()
        print("ExtendedMultiHeadAttn")

        self.n_head = n_head
        self.d_model = d_model
        self.d_head = d_head
        self.dropout = dropout

        self.q_net = nn.Linear(d_model, n_head * d_head, bias=False)
        self.kv_net = nn.Linear(d_model, 2 * n_head * d_head, bias=False)

        self.drop = nn.Dropout(dropout)
        self.dropatt = nn.Dropout(dropatt)
        self.o_net = nn.Linear(n_head * d_head * 2, d_model, bias=False)

        self.layer_norm = nn.LayerNorm(d_model)

        self.scale = 1 / (d_head ** 0.5)

        self.pre_lnorm = pre_lnorm

        # self.coeff = nn.Parameter(torch.Tensor(n_head, 2))
        # nn.init.uniform_(self.coeff, a=-1, b=1)

    def forward(self, h, attn_mask=None, mems=None):
        ##### multihead attention
        # [hlen x bsz x n_head x d_head]

        if mems is not None:
            c = torch.cat([mems, h], 0)
            mem_len = mems.size(0)
        else:
            c = h
            mem_len = 0

        if self.pre_lnorm:
            ##### layer normalization
            c = self.layer_norm(c)

        head_q = self.q_net(c)
        head_k, head_v = torch.chunk(self.kv_net(c), 2, -1)

        head_q = head_q.view(c.size(0), c.size(1), self.n_head, self.d_head)
        head_k = head_k.view(c.size(0), c.size(1), self.n_head, self.d_head)
        head_v = head_v.view(c.size(0), c.size(1), self.n_head, self.d_head)

        # [qlen x klen x bsz x n_head]
        attn_score = torch.einsum('ibnd,jbnd->ijbn', (head_q, head_k))
        attn_score.mul_(self.scale)
        if attn_mask is not None and attn_mask.any().item():
            if attn_mask.dim() == 2:
                attn_score[mem_len:].masked_fill_(attn_mask[None, :, :, None].bool(), -float('inf'))
            elif attn_mask.dim() == 3:
                attn_score[mem_len:].masked_fill_(attn_mask[:, :, :, None].bool(), -float('inf'))
            mem2other_attn = attn_mask.new_ones(mem_len, c.size(0))
            mem2other_attn[:, :mem_len] = 0
            attn_score[:mem_len].masked_fill_(mem2other_attn[:, :, None, None].bool(), -float('inf'))

        # [qlen x klen x bsz x n_head]
        attn_prob = F.softmax(attn_score, dim=1)
        attn_prob = self.dropatt(attn_prob)

        # [qlen x klen x bsz x n_head] + [klen x bsz x n_head x d_head] -> [qlen x bsz x n_head x d_head]
        attn_vec = torch.einsum('ijbn,jbnd->ibnd', (attn_prob, head_v))
        attn_vec_quad = torch.einsum('ijbn,jbnd->ibnd', (attn_prob, attn_vec))
        # [qlen x bsz x n_head x d_head x 2]
        attn_vecs = torch.cat([attn_vec.unsqueeze(-1), attn_vec_quad.unsqueeze(-1)], dim=-1)
        # attn_vec = torch.einsum('ibndt,nt->ibnd', (attn_vecs, self.coeff))
        attn_vec = attn_vecs.contiguous().view(
            attn_vec.size(0), attn_vec.size(1), self.n_head * self.d_head * 2)
        attn_vec = attn_vec[mem_len:]

        ##### linear projection
        attn_out = self.o_net(attn_vec)
        attn_out = self.drop(attn_out)

        if self.pre_lnorm:
            ##### residual connection
            output = h + attn_out
        else:
            ##### residual connection + layer normalization
            output = self.layer_norm(h + attn_out)

        return output

class MultiHeadAttn(nn.Module):
    def __init__(self, n_head, d_model, d_head, dropout, dropatt=0, pre_lnorm=False):
        super(MultiHeadAttn, self).__init__()

        self.n_head = n_head
        self.d_model = d_model
        self.d_head = d_head
        self.dropout = dropout

        self.q_net = nn.Linear(d_model, n_head * d_head, bias=False)
        self.kv_net = nn.Linear(d_model, 2 * n_head * d_head, bias=False)

        self.drop = nn.Dropout(dropout)
        self.dropatt = nn.Dropout(dropatt)
        self.o_net = nn.Linear(n_head * d_head, d_model, bias=False)

        self.layer_norm = nn.LayerNorm(d_model)

        self.scale = 1 / (d_head ** 0.5)

        self.pre_lnorm = pre_lnorm

    def forward(self, h, attn_mask=None, mems=None):
        ##### multihead attention
        # [hlen x bsz x n_head x d_head]

        if mems is not None:
            c = torch.cat([mems, h], 0)
        else:
            c = h

        if self.pre_lnorm:
            ##### layer normalization
            c = self.layer_norm(c)

        head_q = self.q_net(h)
        head_k, head_v = torch.chunk(self.kv_net(c), 2, -1)

        head_q = head_q.view(h.size(0), h.size(1), self.n_head, self.d_head)
        head_k = head_k.view(c.size(0), c.size(1), self.n_head, self.d_head)
        head_v = head_v.view(c.size(0), c.size(1), self.n_head, self.d_head)

        # [qlen x klen x bsz x n_head]
        attn_score = torch.einsum('ibnd,jbnd->ijbn', (head_q, head_k))
        attn_score.mul_(self.scale)
        if attn_mask is not None and attn_mask.any().item():
            if attn_mask.dim() == 2:
                attn_score.masked_fill_(attn_mask[None, :, :, None].bool(), -float('inf'))
            elif attn_mask.dim() == 3:
                attn_score.masked_fill_(attn_mask[:, :, :, None].bool(), -float('inf'))

        # [qlen x klen x bsz x n_head]
        attn_prob = F.softmax(attn_score, dim=1)
        attn_prob = self.dropatt(attn_prob)

        # [qlen x klen x bsz x n_head] + [klen x bsz x n_head x d_head] -> [qlen x bsz x n_head x d_head]
        attn_vec = torch.einsum('ijbn,jbnd->ibnd', (attn_prob, head_v))
        attn_vec = attn_vec.contiguous().view(
            attn_vec.size(0), attn_vec.size(1), self.n_head * self.d_head)

        ##### linear projection
        attn_out = self.o_net(attn_vec)
        attn_out = self.drop(attn_out)

        if self.pre_lnorm:
            ##### residual connection
            output = h + attn_out
        else:
            ##### residual connection + layer normalization
            output = self.layer_norm(h + attn_out)

        return output


class RelMultiHeadAttn(nn.Module):
    def __init__(self, n_head, d_model, d_head, dropout, dropatt=0,
                 tgt_len=None, ext_len=None, mem_len=None, pre_lnorm=False):
        super(RelMultiHeadAttn, self).__init__()

        self.n_head = n_head
        self.d_model = d_model
        self.d_head = d_head
        self.dropout = dropout

        self.qkv_net = nn.Linear(d_model, 3 * n_head * d_head, bias=False)

        self.drop = nn.Dropout(dropout)
        self.dropatt = nn.Dropout(dropatt)
        self.o_net = nn.Linear(n_head * d_head, d_model, bias=False)

        self.layer_norm = nn.LayerNorm(d_model)

        self.scale = 1 / (d_head ** 0.5)

        self.pre_lnorm = pre_lnorm

    def _parallelogram_mask(self, h, w, left=False):
        mask = torch.ones((h, w)).byte()
        m = min(h, w)
        mask[:m, :m] = torch.triu(mask[:m, :m])
        mask[-m:, -m:] = torch.tril(mask[-m:, -m:])

        if left:
            return mask
        else:
            return mask.flip(0)

    def _shift(self, x, qlen, klen, mask, left=False):
        if qlen > 1:
            zero_pad = torch.zeros((x.size(0), qlen-1, x.size(2), x.size(3)),
                                   device=x.device, dtype=x.dtype)
        else:
            zero_pad = torch.zeros(0, device=x.device, dtype=x.dtype)

        if left:
            mask = mask.flip(1)
            x_padded = torch.cat([zero_pad, x], dim=1).expand(qlen, -1, -1, -1)
        else:
            x_padded = torch.cat([x, zero_pad], dim=1).expand(qlen, -1, -1, -1)

        x = x_padded.masked_select(mask[:, :, None, None]) \
                    .view(qlen, klen, x.size(2), x.size(3))

        return x

    def _rel_shift(self, x, zero_triu=False):
        zero_pad = torch.zeros((x.size(0), 1, *x.size()[2:]),
                               device=x.device, dtype=x.dtype)
        x_padded = torch.cat([zero_pad, x], dim=1)

        x_padded = x_padded.view(x.size(1) + 1, x.size(0), *x.size()[2:])

        x = x_padded[1:].view_as(x)

        if zero_triu:
            ones = torch.ones((x.size(0), x.size(1)))
            x = x * torch.tril(ones, x.size(1) - x.size(0))[:, :, None, None]

        return x

    def forward(self, w, r, attn_mask=None, mems=None):
        raise NotImplementedError
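
# Illustrative sketch (not called anywhere in this module): what _rel_shift does
# to the position-based attention term. Row i of the output is row i of the
# input shifted left by (qlen - 1 - i), which lines the relative-position scores
# up with absolute key positions; the entries shifted in from beyond the input
# are discarded by the causal mask downstream. Assumed toy sizes.
def _demo_rel_shift():
    qlen, klen = 3, 5  # assumed toy sizes, one batch element and one head
    attn = RelMultiHeadAttn(n_head=1, d_model=4, d_head=4, dropout=0.0)
    x = torch.arange(qlen * klen, dtype=torch.float).view(qlen, klen, 1, 1)
    y = attn._rel_shift(x)
    for i in range(qlen):
        for j in range(klen - (qlen - 1 - i)):
            assert y[i, j, 0, 0] == x[i, j + qlen - 1 - i, 0, 0]
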

class RelPartialLearnableMultiHeadAttn(RelMultiHeadAttn):
    def __init__(self, *args, **kwargs):
        super(RelPartialLearnableMultiHeadAttn, self).__init__(*args, **kwargs)

        self.r_net = nn.Linear(self.d_model, self.n_head * self.d_head, bias=False)

    def forward(self, w, r, r_w_bias, r_r_bias, attn_mask=None, mems=None):
        qlen, rlen, bsz = w.size(0), r.size(0), w.size(1)

        if mems is not None:
            cat = torch.cat([mems, w], 0)
            if self.pre_lnorm:
                w_heads = self.qkv_net(self.layer_norm(cat))
            else:
                w_heads = self.qkv_net(cat)
            r_head_k = self.r_net(r)

            w_head_q, w_head_k, w_head_v = torch.chunk(w_heads, 3, dim=-1)
            w_head_q = w_head_q[-qlen:]
        else:
            if self.pre_lnorm:
                w_heads = self.qkv_net(self.layer_norm(w))
            else:
                w_heads = self.qkv_net(w)
            r_head_k = self.r_net(r)

            w_head_q, w_head_k, w_head_v = torch.chunk(w_heads, 3, dim=-1)

        klen = w_head_k.size(0)

        w_head_q = w_head_q.view(qlen, bsz, self.n_head, self.d_head)  # qlen x bsz x n_head x d_head
        w_head_k = w_head_k.view(klen, bsz, self.n_head, self.d_head)  # qlen x bsz x n_head x d_head
        w_head_v = w_head_v.view(klen, bsz, self.n_head, self.d_head)  # qlen x bsz x n_head x d_head

        r_head_k = r_head_k.view(rlen, self.n_head, self.d_head)       # qlen x n_head x d_head

        #### compute attention score
        rw_head_q = w_head_q + r_w_bias                                # qlen x bsz x n_head x d_head
        AC = torch.einsum('ibnd,jbnd->ijbn', (rw_head_q, w_head_k))    # qlen x klen x bsz x n_head

        rr_head_q = w_head_q + r_r_bias
        BD = torch.einsum('ibnd,jnd->ijbn', (rr_head_q, r_head_k))     # qlen x klen x bsz x n_head
        BD = self._rel_shift(BD)

        # [qlen x klen x bsz x n_head]
        attn_score = AC + BD
        attn_score.mul_(self.scale)

        #### compute attention probability
        if attn_mask is not None and attn_mask.any().item():
            if attn_mask.dim() == 2:
                attn_score = attn_score.float().masked_fill(
                    attn_mask[None, :, :, None].bool(), -float('inf')).type_as(attn_score)
            elif attn_mask.dim() == 3:
                attn_score = attn_score.float().masked_fill(
                    attn_mask[:, :, :, None].bool(), -float('inf')).type_as(attn_score)

        # [qlen x klen x bsz x n_head]
        attn_prob = F.softmax(attn_score, dim=1)
        attn_prob = self.dropatt(attn_prob)

        #### compute attention vector
        attn_vec = torch.einsum('ijbn,jbnd->ibnd', (attn_prob, w_head_v))

        # [qlen x bsz x n_head x d_head]
        attn_vec = attn_vec.contiguous().view(
            attn_vec.size(0), attn_vec.size(1), self.n_head * self.d_head)

        ##### linear projection
        attn_out = self.o_net(attn_vec)
        attn_out = self.drop(attn_out)

        if self.pre_lnorm:
            ##### residual connection
            output = w + attn_out
        else:
            ##### residual connection + layer normalization
            output = self.layer_norm(w + attn_out)

        return output
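
# Illustrative sketch (not called anywhere in this module): the calling
# convention for the relative-position attention above. `r` comes from
# PositionalEmbedding over a descending position sequence of length klen, and
# the two biases correspond to the per-layer r_w_bias / r_r_bias that
# MemTransformerLM creates below. Assumed toy sizes, no memory.
def _demo_rel_partial_attn_call():
    qlen, bsz, n_head, d_model, d_head = 4, 2, 2, 8, 4  # assumed toy sizes
    attn = RelPartialLearnableMultiHeadAttn(n_head, d_model, d_head, dropout=0.0)
    pos_emb = PositionalEmbedding(d_model)
    w = torch.randn(qlen, bsz, d_model)            # klen == qlen without mems
    r = pos_emb(torch.arange(qlen - 1, -1, -1.0))  # [klen x 1 x d_model]
    r_w_bias = torch.zeros(n_head, d_head)
    r_r_bias = torch.zeros(n_head, d_head)
    out = attn(w, r, r_w_bias, r_r_bias)
    assert out.shape == (qlen, bsz, d_model)
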

class RelLearnableMultiHeadAttn(RelMultiHeadAttn):
    def __init__(self, *args, **kwargs):
        super(RelLearnableMultiHeadAttn, self).__init__(*args, **kwargs)

    def forward(self, w, r_emb, r_w_bias, r_bias, attn_mask=None, mems=None):
        # r_emb: [klen, n_head, d_head], used for term B
        # r_w_bias: [n_head, d_head], used for term C
        # r_bias: [klen, n_head], used for term D

        qlen, bsz = w.size(0), w.size(1)

        if mems is not None:
            cat = torch.cat([mems, w], 0)
            if self.pre_lnorm:
                w_heads = self.qkv_net(self.layer_norm(cat))
            else:
                w_heads = self.qkv_net(cat)
            w_head_q, w_head_k, w_head_v = torch.chunk(w_heads, 3, dim=-1)

            w_head_q = w_head_q[-qlen:]
        else:
            if self.pre_lnorm:
                w_heads = self.qkv_net(self.layer_norm(w))
            else:
                w_heads = self.qkv_net(w)
            w_head_q, w_head_k, w_head_v = torch.chunk(w_heads, 3, dim=-1)

        klen = w_head_k.size(0)

        w_head_q = w_head_q.view(qlen, bsz, self.n_head, self.d_head)
        w_head_k = w_head_k.view(klen, bsz, self.n_head, self.d_head)
        w_head_v = w_head_v.view(klen, bsz, self.n_head, self.d_head)

        if klen > r_emb.size(0):
            r_emb_pad = r_emb[0:1].expand(klen-r_emb.size(0), -1, -1)
            r_emb = torch.cat([r_emb_pad, r_emb], 0)
            r_bias_pad = r_bias[0:1].expand(klen-r_bias.size(0), -1)
            r_bias = torch.cat([r_bias_pad, r_bias], 0)
        else:
            r_emb = r_emb[-klen:]
            r_bias = r_bias[-klen:]

        #### compute attention score
        rw_head_q = w_head_q + r_w_bias[None]                        # qlen x bsz x n_head x d_head

        AC = torch.einsum('ibnd,jbnd->ijbn', (rw_head_q, w_head_k))  # qlen x klen x bsz x n_head
        B_ = torch.einsum('ibnd,jnd->ijbn', (w_head_q, r_emb))       # qlen x klen x bsz x n_head
        D_ = r_bias[None, :, None]                                   # 1    x klen x 1   x n_head
        BD = self._rel_shift(B_ + D_)

        # [qlen x klen x bsz x n_head]
        attn_score = AC + BD
        attn_score.mul_(self.scale)

        #### compute attention probability
        if attn_mask is not None and attn_mask.any().item():
            if attn_mask.dim() == 2:
                attn_score.masked_fill_(attn_mask[None, :, :, None].bool(), -float('inf'))
            elif attn_mask.dim() == 3:
                attn_score.masked_fill_(attn_mask[:, :, :, None].bool(), -float('inf'))

        # [qlen x klen x bsz x n_head]
        attn_prob = F.softmax(attn_score, dim=1)
        attn_prob = self.dropatt(attn_prob)

        #### compute attention vector
        attn_vec = torch.einsum('ijbn,jbnd->ibnd', (attn_prob, w_head_v))

        # [qlen x bsz x n_head x d_head]
        attn_vec = attn_vec.contiguous().view(
            attn_vec.size(0), attn_vec.size(1), self.n_head * self.d_head)

        ##### linear projection
        attn_out = self.o_net(attn_vec)
        attn_out = self.drop(attn_out)

        if self.pre_lnorm:
            ##### residual connection
            output = w + attn_out
        else:
            ##### residual connection + layer normalization
            output = self.layer_norm(w + attn_out)

        return output


class DecoderLayer(nn.Module):
    def __init__(self, n_head, d_model, d_head, d_inner, dropout, **kwargs):
        super(DecoderLayer, self).__init__()

        self.dec_attn = MultiHeadAttn(n_head, d_model, d_head, dropout, **kwargs)
        # self.dec_attn = ExtendedMultiHeadAttn(n_head, d_model, d_head, dropout, **kwargs)
        self.pos_ff = MultiHeadHierarchicalMoEPositionwiseFF(d_model, d_inner, dropout,
                                                             pre_lnorm=kwargs.get('pre_lnorm'))

    def forward(self, dec_inp, dec_attn_mask=None, mems=None):
        output = self.dec_attn(dec_inp, attn_mask=dec_attn_mask, mems=mems)
        output = self.pos_ff(output)
        # output, relu_out = self.pos_ff(output)

        return output
        # return output, relu_out


class RelLearnableDecoderLayer(nn.Module):
    def __init__(self, n_head, d_model, d_head, d_inner, dropout, **kwargs):
        super(RelLearnableDecoderLayer, self).__init__()

        self.dec_attn = RelLearnableMultiHeadAttn(n_head, d_model, d_head, dropout, **kwargs)
        self.pos_ff = MultiHeadHierarchicalMoEPositionwiseFF(d_model, d_inner, dropout,
                                                             pre_lnorm=kwargs.get('pre_lnorm'))

    def forward(self, dec_inp, r_emb, r_w_bias, r_bias, dec_attn_mask=None, mems=None):
        output = self.dec_attn(dec_inp, r_emb, r_w_bias, r_bias,
                               attn_mask=dec_attn_mask, mems=mems)
        output = self.pos_ff(output)
        # output, relu_out = self.pos_ff(output)

        return output
        # return output, relu_out

class RelPartialLearnableDecoderLayer(nn.Module):
    def __init__(self, n_head, d_model, d_head, d_inner, dropout, **kwargs):
        super(RelPartialLearnableDecoderLayer, self).__init__()

        self.dec_attn = RelPartialLearnableMultiHeadAttn(n_head, d_model, d_head, dropout, **kwargs)
        self.pos_ff = MultiHeadHierarchicalMoEPositionwiseFF(d_model, d_inner, dropout,
                                                             pre_lnorm=kwargs.get('pre_lnorm'))

    def forward(self, dec_inp, r, r_w_bias, r_r_bias, dec_attn_mask=None, mems=None):
        output = self.dec_attn(dec_inp, r, r_w_bias, r_r_bias,
                               attn_mask=dec_attn_mask, mems=mems)
        output = self.pos_ff(output)
        # output, relu_out = self.pos_ff(output)

        return output
        # return output, relu_out


class AdaptiveEmbedding(nn.Module):
    def __init__(self, n_token, d_embed, d_proj, cutoffs, div_val=1,
                 sample_softmax=False):
        super(AdaptiveEmbedding, self).__init__()

        self.n_token = n_token
        self.d_embed = d_embed

        self.cutoffs = cutoffs + [n_token]
        self.div_val = div_val
        self.d_proj = d_proj

        self.emb_scale = d_proj ** 0.5

        self.cutoff_ends = [0] + self.cutoffs

        self.emb_layers = nn.ModuleList()
        self.emb_projs = nn.ParameterList()
        if div_val == 1:
            self.emb_layers.append(
                nn.Embedding(n_token, d_embed, sparse=sample_softmax > 0)
            )
            if d_proj != d_embed:
                self.emb_projs.append(nn.Parameter(torch.Tensor(d_proj, d_embed)))
        else:
            for i in range(len(self.cutoffs)):
                l_idx, r_idx = self.cutoff_ends[i], self.cutoff_ends[i+1]
                d_emb_i = d_embed // (div_val ** i)
                self.emb_layers.append(nn.Embedding(r_idx-l_idx, d_emb_i))
                self.emb_projs.append(nn.Parameter(torch.Tensor(d_proj, d_emb_i)))

    def forward(self, inp):
        if self.div_val == 1:
            embed = self.emb_layers[0](inp)
            if self.d_proj != self.d_embed:
                embed = F.linear(embed, self.emb_projs[0])
        else:
            param = next(self.parameters())
            inp_flat = inp.view(-1)
            emb_flat = torch.zeros([inp_flat.size(0), self.d_proj],
                                   dtype=param.dtype, device=param.device)
            for i in range(len(self.cutoffs)):
                l_idx, r_idx = self.cutoff_ends[i], self.cutoff_ends[i + 1]

                mask_i = (inp_flat >= l_idx) & (inp_flat < r_idx)
                indices_i = mask_i.nonzero().squeeze()

                if indices_i.numel() == 0:
                    continue

                inp_i = inp_flat.index_select(0, indices_i) - l_idx
                emb_i = self.emb_layers[i](inp_i)
                emb_i = F.linear(emb_i, self.emb_projs[i])

                emb_flat.index_copy_(0, indices_i, emb_i)

            embed = emb_flat.view(*inp.size(), self.d_proj)

        embed.mul_(self.emb_scale)

        return embed
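
# Illustrative sketch (not called anywhere in this module): AdaptiveEmbedding
# with div_val > 1 keeps a smaller embedding size for the rarer vocabulary
# buckets and projects every bucket back to d_proj. Assumed toy sizes and
# hand-picked token ids.
def _demo_adaptive_embedding():
    n_token, d_embed, d_proj = 1000, 32, 16  # assumed toy sizes
    emb = AdaptiveEmbedding(n_token, d_embed, d_proj, cutoffs=[500], div_val=2)
    # second bucket (tokens 500..999) is embedded in d_embed // 2 dims, then projected to d_proj
    assert emb.emb_layers[1].weight.shape == (500, d_embed // 2)
    tokens = torch.tensor([[1, 2, 600], [3, 700, 800]])  # [qlen x bsz] token ids
    assert emb(tokens).shape == (2, 3, d_proj)
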

class MemTransformerLM(nn.Module):
    def __init__(self, n_token, n_layer, n_head, d_model, d_head, d_inner,
                 dropout, dropatt, tie_weight=True, d_embed=None,
                 div_val=1, tie_projs=[False], pre_lnorm=False,
                 tgt_len=None, ext_len=None, mem_len=None,
                 cutoffs=[], adapt_inp=False,
                 same_length=False, attn_type=0, clamp_len=-1,
                 sample_softmax=-1):
        super(MemTransformerLM, self).__init__()
        self.n_token = n_token

        d_embed = d_model if d_embed is None else d_embed
        self.d_embed = d_embed
        self.d_model = d_model
        self.n_head = n_head
        self.d_head = d_head

        self.word_emb = AdaptiveEmbedding(n_token, d_embed, d_model, cutoffs,
                                          div_val=div_val)

        self.drop = nn.Dropout(dropout)

        self.n_layer = n_layer

        self.tgt_len = tgt_len
        self.mem_len = mem_len
        self.ext_len = ext_len
        self.max_klen = tgt_len + ext_len + mem_len

        self.attn_type = attn_type

        self.layers = nn.ModuleList()
        if attn_type == 0:  # the default attention
            for i in range(n_layer):
                self.layers.append(
                    RelPartialLearnableDecoderLayer(
                        n_head, d_model, d_head, d_inner, dropout,
                        tgt_len=tgt_len, ext_len=ext_len, mem_len=mem_len,
                        dropatt=dropatt, pre_lnorm=pre_lnorm)
                )
        elif attn_type == 1:  # learnable embeddings
            for i in range(n_layer):
                self.layers.append(
                    RelLearnableDecoderLayer(
                        n_head, d_model, d_head, d_inner, dropout,
                        tgt_len=tgt_len, ext_len=ext_len, mem_len=mem_len,
                        dropatt=dropatt, pre_lnorm=pre_lnorm)
                )
        elif attn_type in [2, 3]:  # absolute embeddings
            for i in range(n_layer):
                self.layers.append(
                    DecoderLayer(
                        n_head, d_model, d_head, d_inner, dropout,
                        dropatt=dropatt, pre_lnorm=pre_lnorm)
                )

        self.sample_softmax = sample_softmax
        # use sampled softmax
        if sample_softmax > 0:
            self.out_layer = nn.Linear(d_model, n_token)
            if tie_weight:
                self.out_layer.weight = self.word_emb.weight
            self.tie_weight = tie_weight
            self.sampler = LogUniformSampler(n_token, sample_softmax)

        # use adaptive softmax (including standard softmax)
        else:
            self.crit = ProjectedAdaptiveLogSoftmax(n_token, d_embed, d_model,
                                                    cutoffs, div_val=div_val)

            if tie_weight:
                for i in range(len(self.crit.out_layers)):
                    self.crit.out_layers[i].weight = self.word_emb.emb_layers[i].weight

            if tie_projs:
                for i, tie_proj in enumerate(tie_projs):
                    if tie_proj and div_val == 1 and d_model != d_embed:
                        self.crit.out_projs[i] = self.word_emb.emb_projs[0]
                    elif tie_proj and div_val != 1:
                        self.crit.out_projs[i] = self.word_emb.emb_projs[i]

        self.same_length = same_length
        self.clamp_len = clamp_len

        self._create_params()

    def backward_compatible(self):
        self.sample_softmax = -1

    def _create_params(self):
        if self.attn_type == 0:  # default attention
            self.pos_emb = PositionalEmbedding(self.d_model)
            self.r_w_bias = nn.Parameter(torch.Tensor(self.n_head, self.d_head))
            self.r_r_bias = nn.Parameter(torch.Tensor(self.n_head, self.d_head))
        elif self.attn_type == 1:  # learnable
            self.r_emb = nn.Parameter(torch.Tensor(
                self.n_layer, self.max_klen, self.n_head, self.d_head))
            self.r_w_bias = nn.Parameter(torch.Tensor(
                self.n_layer, self.n_head, self.d_head))
            self.r_bias = nn.Parameter(torch.Tensor(
                self.n_layer, self.max_klen, self.n_head))
        elif self.attn_type == 2:  # absolute standard
            self.pos_emb = PositionalEmbedding(self.d_model)
        elif self.attn_type == 3:  # absolute deeper SA
            self.r_emb = nn.Parameter(torch.Tensor(
                self.n_layer, self.max_klen, self.n_head, self.d_head))

    def reset_length(self, tgt_len, ext_len, mem_len):
        self.tgt_len = tgt_len
        self.mem_len = mem_len
        self.ext_len = ext_len

    def init_mems(self):
        if self.mem_len > 0:
            mems = []
            param = next(self.parameters())
            for i in range(self.n_layer+1):
                empty = torch.empty(0, dtype=param.dtype, device=param.device)
                mems.append(empty)

            return mems
        else:
            return None

    def _update_mems(self, hids, mems, qlen, mlen):
        # does not deal with None
        if mems is None: return None

        # mems is not None
        assert len(hids) == len(mems), 'len(hids) != len(mems)'

        # There are `mlen + qlen` steps that can be cached into mems
        # For the next step, the last `ext_len` of the `qlen` tokens
        # will be used as the extended context. Hence, we only cache
        # the tokens from `mlen + qlen - self.ext_len - self.mem_len`
        # to `mlen + qlen - self.ext_len`.
        with torch.no_grad():
            new_mems = []
            end_idx = mlen + max(0, qlen - 0 - self.ext_len)
            beg_idx = max(0, end_idx - self.mem_len)
            for i in range(len(hids)):

                cat = torch.cat([mems[i], hids[i]], dim=0)
                new_mems.append(cat[beg_idx:end_idx].detach())

        return new_mems

    def _forward(self, dec_inp, mems=None):
        qlen, bsz = dec_inp.size()

        word_emb = self.word_emb(dec_inp)

        mlen = mems[0].size(0) if mems is not None else 0
        klen = mlen + qlen
        if self.same_length:
            all_ones = word_emb.new_ones(qlen, klen)
            mask_len = klen - self.mem_len
            if mask_len > 0:
                mask_shift_len = qlen - mask_len
            else:
                mask_shift_len = qlen
            dec_attn_mask = (torch.triu(all_ones, 1+mlen)
                    + torch.tril(all_ones, -mask_shift_len)).byte()[:, :, None]  # -1
        else:
            dec_attn_mask = torch.triu(
                word_emb.new_ones(qlen, klen), diagonal=1+mlen).byte()[:, :, None]

        hids = []
        # relu_outs = []
        if self.attn_type == 0:  # default
            pos_seq = torch.arange(klen-1, -1, -1.0, device=word_emb.device,
                                   dtype=word_emb.dtype)
            if self.clamp_len > 0:
                pos_seq.clamp_(max=self.clamp_len)
            pos_emb = self.pos_emb(pos_seq)

            core_out = self.drop(word_emb)
            pos_emb = self.drop(pos_emb)

            hids.append(core_out)
            for i, layer in enumerate(self.layers):
                mems_i = None if mems is None else mems[i]
                # core_out, relu_out = layer(core_out, pos_emb, self.r_w_bias,
                core_out = layer(core_out, pos_emb, self.r_w_bias,
                                 self.r_r_bias, dec_attn_mask=dec_attn_mask, mems=mems_i)
                hids.append(core_out)
                # relu_outs.append(relu_out)
        elif self.attn_type == 1:  # learnable
            core_out = self.drop(word_emb)
            hids.append(core_out)
            for i, layer in enumerate(self.layers):
                if self.clamp_len > 0:
                    r_emb = self.r_emb[i][-self.clamp_len:]
                    r_bias = self.r_bias[i][-self.clamp_len:]
                else:
                    r_emb, r_bias = self.r_emb[i], self.r_bias[i]

                mems_i = None if mems is None else mems[i]
                # core_out, relu_out = layer(core_out, r_emb, self.r_w_bias[i],
                core_out = layer(core_out, r_emb, self.r_w_bias[i],
                                 r_bias, dec_attn_mask=dec_attn_mask, mems=mems_i)
                hids.append(core_out)
                # relu_outs.append(relu_out)
        elif self.attn_type == 2:  # absolute
            pos_seq = torch.arange(klen - 1, -1, -1.0, device=word_emb.device,
                                   dtype=word_emb.dtype)
            if self.clamp_len > 0:
                pos_seq.clamp_(max=self.clamp_len)
            pos_emb = self.pos_emb(pos_seq)

            core_out = self.drop(word_emb + pos_emb[-qlen:])

            hids.append(core_out)
            for i, layer in enumerate(self.layers):
                mems_i = None if mems is None else mems[i]
                if mems_i is not None and i == 0:
                    mems_i += pos_emb[:mlen]
                # core_out, relu_out = layer(core_out, dec_attn_mask=dec_attn_mask,
                core_out = layer(core_out, dec_attn_mask=dec_attn_mask,
                                 mems=mems_i)
                hids.append(core_out)
                # relu_outs.append(relu_out)
        elif self.attn_type == 3:
            core_out = self.drop(word_emb)

            hids.append(core_out)
            for i, layer in enumerate(self.layers):
                mems_i = None if mems is None else mems[i]
                if mems_i is not None and mlen > 0:
                    cur_emb = self.r_emb[i][:-qlen]
                    cur_size = cur_emb.size(0)
                    if cur_size < mlen:
                        cur_emb_pad = cur_emb[0:1].expand(mlen-cur_size, -1, -1)
                        cur_emb = torch.cat([cur_emb_pad, cur_emb], 0)
                    else:
                        cur_emb = cur_emb[-mlen:]
                    mems_i += cur_emb.view(mlen, 1, -1)
                core_out += self.r_emb[i][-qlen:].view(qlen, 1, -1)

                # core_out, relu_out = layer(core_out, dec_attn_mask=dec_attn_mask,
                core_out = layer(core_out, dec_attn_mask=dec_attn_mask,
                                 mems=mems_i)
                hids.append(core_out)
                # relu_outs.append(relu_out)

        core_out = self.drop(core_out)

        # note: the arguments are passed as (mlen, qlen) while the signature of
        # _update_mems reads (qlen, mlen); with ext_len == 0 the cached range is
        # the same either way.
        new_mems = self._update_mems(hids, mems, mlen, qlen)

        return core_out, new_mems
        # return core_out, new_mems, relu_outs

    def forward(self, data, target, *mems):
        # nn.DataParallel does not allow size(0) tensors to be broadcasted.
        # So, have to initialize size(0) mems inside the model forward.
        # Moreover, have to return new_mems to allow nn.DataParallel to piece
        # them together.
        if not mems:
            mems = self.init_mems()

        tgt_len = target.size(0)
        hidden, new_mems = self._forward(data, mems=mems)
        # hidden, new_mems, relu_outs = self._forward(data, mems=mems)
        # relu_outs = torch.cat([relu_out.unsqueeze(-1) for relu_out in relu_outs], dim=-1)

        pred_hid = hidden[-tgt_len:]
        if self.sample_softmax > 0 and self.training:
            assert self.tie_weight
            logit = sample_logits(self.word_emb, self.out_layer.bias, target, pred_hid, self.sampler)
            loss = -F.log_softmax(logit, -1)[:, :, 0]
        else:
            loss = self.crit(pred_hid.view(-1, pred_hid.size(-1)), target.contiguous().view(-1))
            loss = loss.view(tgt_len, -1)

        if new_mems is None:
            return [loss]
            # return [relu_outs, loss]
        else:
            return [loss] + new_mems
            # return [relu_outs, loss] + new_mems


if __name__ == '__main__':
    import argparse

    parser = argparse.ArgumentParser(description='unit test')

    parser.add_argument('--n_layer', type=int, default=4, help='')
    parser.add_argument('--n_rel_layer', type=int, default=4, help='')
    parser.add_argument('--n_head', type=int, default=2, help='')
    parser.add_argument('--d_head', type=int, default=2, help='')
    parser.add_argument('--d_model', type=int, default=200, help='')
    parser.add_argument('--d_embed', type=int, default=200, help='')
    # note: the MoE feed-forward asserts d_inner % n_block == 0 (n_block defaults to 16)
    parser.add_argument('--d_inner', type=int, default=200, help='')
    parser.add_argument('--dropout', type=float, default=0.0, help='')
    parser.add_argument('--cuda', action='store_true', help='')
    parser.add_argument('--seed', type=int, default=1111, help='')
    parser.add_argument('--multi_gpu', action='store_true', help='')

    args = parser.parse_args()

    device = torch.device("cuda" if args.cuda else "cpu")

    B = 4
    tgt_len, mem_len, ext_len = 36, 36, 0
    data_len = tgt_len * 20
    args.n_token = 10000

    import data_utils

    data = torch.LongTensor(data_len*B).random_(0, args.n_token).to(device)
    diter = data_utils.LMOrderedIterator(data, B, tgt_len, device=device, ext_len=ext_len)

    cutoffs = [args.n_token // 2]
    tie_projs = [False] + [True] * len(cutoffs)

    for div_val in [1, 2]:
        for d_embed in [200, 100]:
            model = MemTransformerLM(args.n_token, args.n_layer, args.n_head,
                                     args.d_model, args.d_head, args.d_inner,
                                     args.dropout, dropatt=args.dropout,
                                     tie_weight=True, d_embed=d_embed,
                                     div_val=div_val, tie_projs=tie_projs,
                                     pre_lnorm=True, tgt_len=tgt_len,
                                     ext_len=ext_len, mem_len=mem_len,
                                     cutoffs=cutoffs, attn_type=0).to(device)

            print(sum(p.numel() for p in model.parameters()))

            mems = tuple()
            for idx, (inp, tgt, seqlen) in enumerate(diter):
                print('batch {}'.format(idx))
                out = model(inp, tgt, *mems)
                mems = out[1:]