Commit 65c49bb2 authored by thomwolf's avatar thomwolf
Browse files

adding TF 2.0 adaptive softmax with logits + loss outputs

parent 39c38b2e
...@@ -455,6 +455,8 @@ class TFBertMainLayer(tf.keras.layers.Layer): ...@@ -455,6 +455,8 @@ class TFBertMainLayer(tf.keras.layers.Layer):
""" """
raise NotImplementedError raise NotImplementedError
# def call(self, input_ids, attention_mask=None, token_type_ids=None,
# position_ids=None, head_mask=None, training=False):
def call(self, inputs, training=False): def call(self, inputs, training=False):
if not isinstance(inputs, (dict, tuple, list)): if not isinstance(inputs, (dict, tuple, list)):
input_ids = inputs input_ids = inputs
......
...@@ -30,8 +30,8 @@ import numpy as np ...@@ -30,8 +30,8 @@ import numpy as np
import tensorflow as tf import tensorflow as tf
from .configuration_transfo_xl import TransfoXLConfig from .configuration_transfo_xl import TransfoXLConfig
from .modeling_tf_utils import TFPreTrainedModel, TFConv1D, TFSequenceSummary from .modeling_tf_utils import TFPreTrainedModel, TFConv1D, TFSequenceSummary, shape_list
from .modeling_transfo_xl_utilities import ProjectedAdaptiveLogSoftmax, sample_logits from .modeling_tf_transfo_xl_utilities import TFAdaptiveSoftmaxMask
from .file_utils import add_start_docstrings from .file_utils import add_start_docstrings
from .modeling_tf_pytorch_utils import load_pytorch_checkpoint_in_tf2_model from .modeling_tf_pytorch_utils import load_pytorch_checkpoint_in_tf2_model
...@@ -49,55 +49,56 @@ def load_transfo_xl_pt_weights_in_tf2(tf_model, pytorch_checkpoint_path): ...@@ -49,55 +49,56 @@ def load_transfo_xl_pt_weights_in_tf2(tf_model, pytorch_checkpoint_path):
return load_pytorch_checkpoint_in_tf2_model(tf_model, pytorch_checkpoint_path, tf_inputs=tf_inputs) return load_pytorch_checkpoint_in_tf2_model(tf_model, pytorch_checkpoint_path, tf_inputs=tf_inputs)
class PositionalEmbedding(nn.Module): class TFPositionalEmbedding(tf.keras.layers.Layer):
def __init__(self, demb): def __init__(self, demb, **kwargs):
super(PositionalEmbedding, self).__init__() super(TFPositionalEmbedding, self).__init__(**kwargs)
self.demb = demb self.inv_freq = 1 / (10000 ** (tf.range(0, demb, 2.0) / demb))
inv_freq = 1 / (10000 ** (torch.arange(0.0, demb, 2.0) / demb)) def call(self, pos_seq, bsz=None):
self.register_buffer('inv_freq', inv_freq) sinusoid_inp = tf.einsum('i,j->ij', pos_seq, self.inv_freq)
pos_emb = tf.concat([tf.sin(sinusoid_inp), tf.cos(sinusoid_inp)], -1)
def forward(self, pos_seq, bsz=None):
sinusoid_inp = torch.ger(pos_seq, self.inv_freq)
pos_emb = torch.cat([sinusoid_inp.sin(), sinusoid_inp.cos()], dim=-1)
if bsz is not None: if bsz is not None:
return pos_emb[:,None,:].expand(-1, bsz, -1) return tf.tile(pos_emb[:, None, :], [1, bsz, 1])
else: else:
return pos_emb[:,None,:] return pos_emb[:, None, :]
class PositionwiseFF(nn.Module): class TFPositionwiseFF(tf.keras.layers.Layer):
def __init__(self, d_model, d_inner, dropout, pre_lnorm=False): def __init__(self, d_model, d_inner, dropout, pre_lnorm=False, **kwargs):
super(PositionwiseFF, self).__init__() super(TFPositionwiseFF, self).__init__(**kwargs)
self.d_model = d_model self.d_model = d_model
self.d_inner = d_inner self.d_inner = d_inner
self.dropout = dropout self.dropout = dropout
self.CoreNet = nn.Sequential( self.layer_1 = tf.keras.layers.Dense(d_inner, activation=tf.nn.relu, name='CoreNet_._0')
nn.Linear(d_model, d_inner), nn.ReLU(inplace=True), self.drop_1 = tf.keras.layers.Dropout(dropout)
nn.Dropout(dropout), self.layer_2 = tf.keras.layers.Dense(d_model, name='CoreNet_._2')
nn.Linear(d_inner, d_model), self.drop_2 = tf.keras.layers.Dropout(dropout)
nn.Dropout(dropout),
)
self.layer_norm = nn.LayerNorm(d_model) self.layer_norm = tf.keras.layers.LayerNormalization(epsilon=1e-12, name='layer_norm')
self.pre_lnorm = pre_lnorm self.pre_lnorm = pre_lnorm
def forward(self, inp): def call(self, inp, training=False):
if self.pre_lnorm: if self.pre_lnorm:
##### layer normalization + positionwise feed-forward ##### layer normalization + positionwise feed-forward
core_out = self.CoreNet(self.layer_norm(inp)) core_out = self.layer_norm(inp)
core_out = self.layer_1(core_out)
core_out = self.drop_1(core_out, training=training)
core_out = self.layer_2(core_out)
core_out = self.drop_2(core_out, training=training)
##### residual connection ##### residual connection
output = core_out + inp output = core_out + inp
else: else:
##### positionwise feed-forward ##### positionwise feed-forward
core_out = self.CoreNet(inp) core_out = self.layer_1(inp)
core_out = self.drop_1(core_out, training=training)
core_out = self.layer_2(core_out)
core_out = self.drop_2(core_out, training=training)
##### residual connection + layer normalization ##### residual connection + layer normalization
output = self.layer_norm(inp + core_out) output = self.layer_norm(inp + core_out)
...@@ -105,102 +106,11 @@ class PositionwiseFF(nn.Module): ...@@ -105,102 +106,11 @@ class PositionwiseFF(nn.Module):
return output return output
class TFRelPartialLearnableMultiHeadAttn(tf.keras.layers.Layer):
class MultiHeadAttn(nn.Module):
def __init__(self, n_head, d_model, d_head, dropout, dropatt=0,
pre_lnorm=False, r_r_bias=None, r_w_bias=None, output_attentions=False):
super(MultiHeadAttn, self).__init__()
self.output_attentions = output_attentions
self.n_head = n_head
self.d_model = d_model
self.d_head = d_head
self.dropout = dropout
self.q_net = nn.Linear(d_model, n_head * d_head, bias=False)
self.kv_net = nn.Linear(d_model, 2 * n_head * d_head, bias=False)
self.drop = nn.Dropout(dropout)
self.dropatt = nn.Dropout(dropatt)
self.o_net = nn.Linear(n_head * d_head, d_model, bias=False)
self.layer_norm = nn.LayerNorm(d_model)
self.scale = 1 / (d_head ** 0.5)
self.pre_lnorm = pre_lnorm
if r_r_bias is None or r_w_bias is None: # Biases are not shared
self.r_r_bias = nn.Parameter(torch.FloatTensor(self.n_head, self.d_head))
self.r_w_bias = nn.Parameter(torch.FloatTensor(self.n_head, self.d_head))
else:
self.r_r_bias = r_r_bias
self.r_w_bias = r_w_bias
def forward(self, h, attn_mask=None, mems=None, head_mask=None):
##### multihead attention
# [hlen x bsz x n_head x d_head]
if mems is not None:
c = torch.cat([mems, h], 0)
else:
c = h
if self.pre_lnorm:
##### layer normalization
c = self.layer_norm(c)
head_q = self.q_net(h)
head_k, head_v = torch.chunk(self.kv_net(c), 2, -1)
head_q = head_q.view(h.size(0), h.size(1), self.n_head, self.d_head)
head_k = head_k.view(c.size(0), c.size(1), self.n_head, self.d_head)
head_v = head_v.view(c.size(0), c.size(1), self.n_head, self.d_head)
# [qlen x klen x bsz x n_head]
attn_score = torch.einsum('ibnd,jbnd->ijbn', (head_q, head_k))
attn_score.mul_(self.scale)
if attn_mask is not None and torch.sum(attn_mask).item():
attn_mask = (attn_mask == 1) # Switch to bool
if attn_mask.dim() == 2:
attn_score.masked_fill_(attn_mask[None,:,:,None], -float('inf'))
elif attn_mask.dim() == 3:
attn_score.masked_fill_(attn_mask[:,:,:,None], -float('inf'))
# [qlen x klen x bsz x n_head]
attn_prob = F.softmax(attn_score, dim=1)
attn_prob = self.dropatt(attn_prob)
# Mask heads if we want to
if head_mask is not None:
attn_prob = attn_prob * head_mask
# [qlen x klen x bsz x n_head] + [klen x bsz x n_head x d_head] -> [qlen x bsz x n_head x d_head]
attn_vec = torch.einsum('ijbn,jbnd->ibnd', (attn_prob, head_v))
attn_vec = attn_vec.contiguous().view(
attn_vec.size(0), attn_vec.size(1), self.n_head * self.d_head)
##### linear projection
attn_out = self.o_net(attn_vec)
attn_out = self.drop(attn_out)
if self.pre_lnorm:
##### residual connection
outputs = [h + attn_out]
else:
##### residual connection + layer normalization
outputs = [self.layer_norm(h + attn_out)]
if self.output_attentions:
outputs.append(attn_prob)
return outputs
class RelMultiHeadAttn(nn.Module):
def __init__(self, n_head, d_model, d_head, dropout, dropatt=0, def __init__(self, n_head, d_model, d_head, dropout, dropatt=0,
tgt_len=None, ext_len=None, mem_len=None, pre_lnorm=False, tgt_len=None, ext_len=None, mem_len=None, pre_lnorm=False,
r_r_bias=None, r_w_bias=None, output_attentions=False): r_r_bias=None, r_w_bias=None, output_attentions=False, **kwargs):
super(RelMultiHeadAttn, self).__init__() super(TFRelPartialLearnableMultiHeadAttn, self).__init__(**kwargs)
self.output_attentions = output_attentions self.output_attentions = output_attentions
self.n_head = n_head self.n_head = n_head
...@@ -208,91 +118,60 @@ class RelMultiHeadAttn(nn.Module): ...@@ -208,91 +118,60 @@ class RelMultiHeadAttn(nn.Module):
self.d_head = d_head self.d_head = d_head
self.dropout = dropout self.dropout = dropout
self.qkv_net = nn.Linear(d_model, 3 * n_head * d_head, bias=False) self.qkv_net = tf.keras.layers.Dense(3 * n_head * d_head, use_bias=False, name='qkv_net')
self.drop = nn.Dropout(dropout) self.drop = tf.keras.layers.Dropout(dropout)
self.dropatt = nn.Dropout(dropatt) self.dropatt = tf.keras.layers.Dropout(dropatt)
self.o_net = nn.Linear(n_head * d_head, d_model, bias=False) self.o_net = tf.keras.layers.Dense(d_model, use_bias=False, name='o_net')
self.layer_norm = nn.LayerNorm(d_model) self.layer_norm = tf.keras.layers.LayerNormalization(epsilon=1e-12, name='layer_norm')
self.scale = 1 / (d_head ** 0.5) self.scale = 1 / (d_head ** 0.5)
self.pre_lnorm = pre_lnorm self.pre_lnorm = pre_lnorm
if r_r_bias is None or r_w_bias is None: # Biases are not shared if r_r_bias is not None and r_w_bias is not None: # Biases are shared
self.r_r_bias = nn.Parameter(torch.FloatTensor(self.n_head, self.d_head))
self.r_w_bias = nn.Parameter(torch.FloatTensor(self.n_head, self.d_head))
else:
self.r_r_bias = r_r_bias self.r_r_bias = r_r_bias
self.r_w_bias = r_w_bias self.r_w_bias = r_w_bias
def _parallelogram_mask(self, h, w, left=False):
mask = torch.ones((h, w)).byte()
m = min(h, w)
mask[:m,:m] = torch.triu(mask[:m,:m])
mask[-m:,-m:] = torch.tril(mask[-m:,-m:])
if left:
return mask
else:
return mask.flip(0)
def _shift(self, x, qlen, klen, mask, left=False):
if qlen > 1:
zero_pad = torch.zeros((x.size(0), qlen-1, x.size(2), x.size(3)),
device=x.device, dtype=x.dtype)
else: else:
zero_pad = torch.zeros(0, device=x.device, dtype=x.dtype) self.r_r_bias = None
self.r_w_bias = None
if left: self.r_net = tf.keras.layers.Dense(self.n_head * self.d_head, use_bias=False, name='r_net')
mask = mask.flip(1)
x_padded = torch.cat([zero_pad, x], dim=1).expand(qlen, -1, -1, -1)
else:
x_padded = torch.cat([x, zero_pad], dim=1).expand(qlen, -1, -1, -1)
x = x_padded.masked_select(mask[:,:,None,None]) \
.view(qlen, klen, x.size(2), x.size(3))
return x
def _rel_shift(self, x, zero_triu=False):
zero_pad_shape = (x.size(0), 1) + x.size()[2:]
zero_pad = torch.zeros(zero_pad_shape, device=x.device, dtype=x.dtype)
x_padded = torch.cat([zero_pad, x], dim=1)
x_padded_shape = (x.size(1) + 1, x.size(0)) + x.size()[2:] def build(self, input_shape):
x_padded = x_padded.view(*x_padded_shape) if self.r_r_bias is None or self.r_w_bias is None: # Biases are not shared
self.r_r_bias = self.add_weight(shape=(self.n_head, self.d_head),
trainable=True,
name='r_r_bias')
self.r_w_bias = self.add_weight(shape=(self.n_head, self.d_head),
trainable=True,
name='r_w_bias')
super(TFRelPartialLearnableMultiHeadAttn, self).build(input_shape)
x = x_padded[1:].view_as(x) def _rel_shift(self, x):
x_size = shape_list(x)
if zero_triu: x = tf.pad(x, [[0, 0], [1, 0], [0, 0], [0, 0]])
ones = torch.ones((x.size(0), x.size(1))) x = tf.reshape(x, [x_size[1] + 1, x_size[0], x_size[2], x_size[3]])
x = x * torch.tril(ones, x.size(1) - x.size(0))[:,:,None,None] x = tf.slice(x, [1, 0, 0, 0], [-1, -1, -1, -1])
x = tf.reshape(x, x_size)
return x return x
def forward(self, w, r, attn_mask=None, mems=None): def call(self, inputs, training=False):
raise NotImplementedError w, r, attn_mask, mems, head_mask = inputs
qlen, rlen, bsz = shape_list(w)[0], shape_list(r)[0], shape_list(w)[1]
class RelPartialLearnableMultiHeadAttn(RelMultiHeadAttn):
def __init__(self, *args, **kwargs):
super(RelPartialLearnableMultiHeadAttn, self).__init__(*args, **kwargs)
self.r_net = nn.Linear(self.d_model, self.n_head * self.d_head, bias=False)
def forward(self, w, r, attn_mask=None, mems=None, head_mask=None):
qlen, rlen, bsz = w.size(0), r.size(0), w.size(1)
if mems is not None: if mems is not None:
cat = torch.cat([mems, w], 0) cat = tf.concat([mems, w], 0)
if self.pre_lnorm: if self.pre_lnorm:
w_heads = self.qkv_net(self.layer_norm(cat)) w_heads = self.qkv_net(self.layer_norm(cat))
else: else:
w_heads = self.qkv_net(cat) w_heads = self.qkv_net(cat)
r_head_k = self.r_net(r) r_head_k = self.r_net(r)
w_head_q, w_head_k, w_head_v = torch.chunk(w_heads, 3, dim=-1) w_head_q, w_head_k, w_head_v = tf.split(w_heads, 3, axis=-1)
w_head_q = w_head_q[-qlen:] w_head_q = w_head_q[-qlen:]
else: else:
if self.pre_lnorm: if self.pre_lnorm:
...@@ -301,56 +180,52 @@ class RelPartialLearnableMultiHeadAttn(RelMultiHeadAttn): ...@@ -301,56 +180,52 @@ class RelPartialLearnableMultiHeadAttn(RelMultiHeadAttn):
w_heads = self.qkv_net(w) w_heads = self.qkv_net(w)
r_head_k = self.r_net(r) r_head_k = self.r_net(r)
w_head_q, w_head_k, w_head_v = torch.chunk(w_heads, 3, dim=-1) w_head_q, w_head_k, w_head_v = tf.split(w_heads, 3, axis=-1)
klen = w_head_k.size(0) klen = shape_list(w_head_k)[0]
w_head_q = w_head_q.view(qlen, bsz, self.n_head, self.d_head) # qlen x bsz x n_head x d_head w_head_q = tf.reshape(w_head_q, (qlen, bsz, self.n_head, self.d_head)) # qlen x bsz x n_head x d_head
w_head_k = w_head_k.view(klen, bsz, self.n_head, self.d_head) # qlen x bsz x n_head x d_head w_head_k = tf.reshape(w_head_k, (klen, bsz, self.n_head, self.d_head)) # qlen x bsz x n_head x d_head
w_head_v = w_head_v.view(klen, bsz, self.n_head, self.d_head) # qlen x bsz x n_head x d_head w_head_v = tf.reshape(w_head_v, (klen, bsz, self.n_head, self.d_head)) # qlen x bsz x n_head x d_head
r_head_k = r_head_k.view(rlen, self.n_head, self.d_head) # qlen x n_head x d_head r_head_k = tf.reshape(r_head_k, (rlen, self.n_head, self.d_head)) # qlen x n_head x d_head
#### compute attention score #### compute attention score
rw_head_q = w_head_q + self.r_w_bias # qlen x bsz x n_head x d_head rw_head_q = w_head_q + self.r_w_bias # qlen x bsz x n_head x d_head
AC = torch.einsum('ibnd,jbnd->ijbn', (rw_head_q, w_head_k)) # qlen x klen x bsz x n_head AC = tf.einsum('ibnd,jbnd->ijbn', rw_head_q, w_head_k) # qlen x klen x bsz x n_head
rr_head_q = w_head_q + self.r_r_bias rr_head_q = w_head_q + self.r_r_bias
BD = torch.einsum('ibnd,jnd->ijbn', (rr_head_q, r_head_k)) # qlen x klen x bsz x n_head BD = tf.einsum('ibnd,jnd->ijbn', rr_head_q, r_head_k) # qlen x klen x bsz x n_head
BD = self._rel_shift(BD) BD = self._rel_shift(BD)
# [qlen x klen x bsz x n_head] # [qlen x klen x bsz x n_head]
attn_score = AC + BD attn_score = AC + BD
attn_score.mul_(self.scale) attn_score = attn_score * self.scale
#### compute attention probability #### compute attention probability
if attn_mask is not None and torch.sum(attn_mask).item(): if attn_mask is not None:
attn_mask = (attn_mask == 1) # Switch to bool attn_mask_t = attn_mask[:, :, None, None]
if attn_mask.dim() == 2: attn_score = attn_score * (1 - attn_mask_t) - 1e30 * attn_mask_t
attn_score = attn_score.float().masked_fill(
attn_mask[None,:,:,None], -1e30).type_as(attn_score)
elif attn_mask.dim() == 3:
attn_score = attn_score.float().masked_fill(
attn_mask[:,:,:,None], -1e30).type_as(attn_score)
# [qlen x klen x bsz x n_head] # [qlen x klen x bsz x n_head]
attn_prob = F.softmax(attn_score, dim=1) attn_prob = tf.nn.softmax(attn_score, axis=1)
attn_prob = self.dropatt(attn_prob) attn_prob = self.dropatt(attn_prob, training=training)
# Mask heads if we want to # Mask heads if we want to
if head_mask is not None: if head_mask is not None:
attn_prob = attn_prob * head_mask attn_prob = attn_prob * head_mask
#### compute attention vector #### compute attention vector
attn_vec = torch.einsum('ijbn,jbnd->ibnd', (attn_prob, w_head_v)) attn_vec = tf.einsum('ijbn,jbnd->ibnd', attn_prob, w_head_v)
# [qlen x bsz x n_head x d_head] # [qlen x bsz x n_head x d_head]
attn_vec = attn_vec.contiguous().view( attn_vec_sizes = shape_list(attn_vec)
attn_vec.size(0), attn_vec.size(1), self.n_head * self.d_head) attn_vec = tf.reshape(attn_vec,
(attn_vec_sizes[0], attn_vec_sizes[1], self.n_head * self.d_head))
##### linear projection ##### linear projection
attn_out = self.o_net(attn_vec) attn_out = self.o_net(attn_vec)
attn_out = self.drop(attn_out) attn_out = self.drop(attn_out, training=training)
if self.pre_lnorm: if self.pre_lnorm:
##### residual connection ##### residual connection
...@@ -364,166 +239,40 @@ class RelPartialLearnableMultiHeadAttn(RelMultiHeadAttn): ...@@ -364,166 +239,40 @@ class RelPartialLearnableMultiHeadAttn(RelMultiHeadAttn):
return outputs return outputs
class RelLearnableMultiHeadAttn(RelMultiHeadAttn):
def __init__(self, *args, **kwargs):
super(RelLearnableMultiHeadAttn, self).__init__(*args, **kwargs)
def forward(self, w, r_emb, r_w_bias, r_bias, attn_mask=None, mems=None, head_mask=None):
# r_emb: [klen, n_head, d_head], used for term B
# r_w_bias: [n_head, d_head], used for term C
# r_bias: [klen, n_head], used for term D
qlen, bsz = w.size(0), w.size(1)
if mems is not None: class TFRelPartialLearnableDecoderLayer(tf.keras.layers.Layer):
cat = torch.cat([mems, w], 0)
if self.pre_lnorm:
w_heads = self.qkv_net(self.layer_norm(cat))
else:
w_heads = self.qkv_net(cat)
w_head_q, w_head_k, w_head_v = torch.chunk(w_heads, 3, dim=-1)
w_head_q = w_head_q[-qlen:]
else:
if self.pre_lnorm:
w_heads = self.qkv_net(self.layer_norm(w))
else:
w_heads = self.qkv_net(w)
w_head_q, w_head_k, w_head_v = torch.chunk(w_heads, 3, dim=-1)
klen = w_head_k.size(0)
w_head_q = w_head_q.view(qlen, bsz, self.n_head, self.d_head)
w_head_k = w_head_k.view(klen, bsz, self.n_head, self.d_head)
w_head_v = w_head_v.view(klen, bsz, self.n_head, self.d_head)
if klen > r_emb.size(0):
r_emb_pad = r_emb[0:1].expand(klen-r_emb.size(0), -1, -1)
r_emb = torch.cat([r_emb_pad, r_emb], 0)
r_bias_pad = r_bias[0:1].expand(klen-r_bias.size(0), -1)
r_bias = torch.cat([r_bias_pad, r_bias], 0)
else:
r_emb = r_emb[-klen:]
r_bias = r_bias[-klen:]
#### compute attention score
rw_head_q = w_head_q + r_w_bias[None] # qlen x bsz x n_head x d_head
AC = torch.einsum('ibnd,jbnd->ijbn', (rw_head_q, w_head_k)) # qlen x klen x bsz x n_head
B_ = torch.einsum('ibnd,jnd->ijbn', (w_head_q, r_emb)) # qlen x klen x bsz x n_head
D_ = r_bias[None, :, None] # 1 x klen x 1 x n_head
BD = self._rel_shift(B_ + D_)
# [qlen x klen x bsz x n_head]
attn_score = AC + BD
attn_score.mul_(self.scale)
#### compute attention probability
if attn_mask is not None and torch.sum(attn_mask).item():
attn_mask = (attn_mask == 1) # Switch to bool
if attn_mask.dim() == 2:
attn_score.masked_fill_(attn_mask[None,:,:,None], -float('inf'))
elif attn_mask.dim() == 3:
attn_score.masked_fill_(attn_mask[:,:,:,None], -float('inf'))
# [qlen x klen x bsz x n_head]
attn_prob = F.softmax(attn_score, dim=1)
attn_prob = self.dropatt(attn_prob)
if head_mask is not None:
attn_prob = attn_prob * head_mask
#### compute attention vector
attn_vec = torch.einsum('ijbn,jbnd->ibnd', (attn_prob, w_head_v))
# [qlen x bsz x n_head x d_head]
attn_vec = attn_vec.contiguous().view(
attn_vec.size(0), attn_vec.size(1), self.n_head * self.d_head)
##### linear projection
attn_out = self.o_net(attn_vec)
attn_out = self.drop(attn_out)
if self.pre_lnorm:
##### residual connection
outputs = [w + attn_out]
else:
##### residual connection + layer normalization
outputs = [self.layer_norm(w + attn_out)]
if self.output_attentions:
outputs.append(attn_prob)
return outputs
class DecoderLayer(nn.Module):
def __init__(self, n_head, d_model, d_head, d_inner, dropout, **kwargs):
super(DecoderLayer, self).__init__()
self.dec_attn = MultiHeadAttn(n_head, d_model, d_head, dropout, **kwargs)
self.pos_ff = PositionwiseFF(d_model, d_inner, dropout,
pre_lnorm=kwargs.get('pre_lnorm'))
def forward(self, dec_inp, dec_attn_mask=None, mems=None, head_mask=None):
attn_outputs = self.dec_attn(dec_inp, attn_mask=dec_attn_mask,
mems=mems, head_mask=head_mask)
ff_output = self.pos_ff(attn_outputs[0])
outputs = [ff_output] + attn_outputs[1:]
return outputs
class RelLearnableDecoderLayer(nn.Module):
def __init__(self, n_head, d_model, d_head, d_inner, dropout,
**kwargs):
super(RelLearnableDecoderLayer, self).__init__()
self.dec_attn = RelLearnableMultiHeadAttn(n_head, d_model, d_head, dropout,
**kwargs)
self.pos_ff = PositionwiseFF(d_model, d_inner, dropout,
pre_lnorm=kwargs.get('pre_lnorm'))
def forward(self, dec_inp, r_emb, r_w_bias, r_bias, dec_attn_mask=None, mems=None, head_mask=None):
attn_outputs = self.dec_attn(dec_inp, r_emb, r_w_bias, r_bias,
attn_mask=dec_attn_mask,
mems=mems, head_mask=head_mask)
ff_output = self.pos_ff(attn_outputs[0])
outputs = [ff_output] + attn_outputs[1:]
return outputs
class RelPartialLearnableDecoderLayer(nn.Module):
def __init__(self, n_head, d_model, d_head, d_inner, dropout, def __init__(self, n_head, d_model, d_head, d_inner, dropout,
tgt_len=None, ext_len=None, mem_len=None,
dropatt=0., pre_lnorm=False,
r_w_bias=None,
r_r_bias=None,
output_attentions=False,
**kwargs): **kwargs):
super(RelPartialLearnableDecoderLayer, self).__init__() super(TFRelPartialLearnableDecoderLayer, self).__init__(**kwargs)
self.dec_attn = RelPartialLearnableMultiHeadAttn(n_head, d_model, self.dec_attn = TFRelPartialLearnableMultiHeadAttn(n_head, d_model,
d_head, dropout, **kwargs) d_head, dropout, tgt_len=tgt_len, ext_len=ext_len,
self.pos_ff = PositionwiseFF(d_model, d_inner, dropout, mem_len=mem_len, dropatt=dropatt, pre_lnorm=pre_lnorm,
pre_lnorm=kwargs.get('pre_lnorm')) r_w_bias=r_w_bias, r_r_bias=r_r_bias,
output_attentions=output_attentions, name='dec_attn')
def forward(self, dec_inp, r, dec_attn_mask=None, mems=None, head_mask=None): self.pos_ff = TFPositionwiseFF(d_model, d_inner, dropout,
pre_lnorm=pre_lnorm, name='pos_ff')
attn_outputs = self.dec_attn(dec_inp, r,
attn_mask=dec_attn_mask, def call(self, inputs, training=False):
mems=mems, head_mask=head_mask) dec_inp, r, dec_attn_mask, mems, head_mask = inputs
ff_output = self.pos_ff(attn_outputs[0]) attn_outputs = self.dec_attn([dec_inp, r, dec_attn_mask,
mems, head_mask], training=training)
ff_output = self.pos_ff(attn_outputs[0], training=training)
outputs = [ff_output] + attn_outputs[1:] outputs = [ff_output] + attn_outputs[1:]
return outputs return outputs
class TFAdaptiveEmbedding(tf.keras.layers.Layer):
class AdaptiveEmbedding(nn.Module):
def __init__(self, n_token, d_embed, d_proj, cutoffs, div_val=1, def __init__(self, n_token, d_embed, d_proj, cutoffs, div_val=1,
sample_softmax=False): sample_softmax=False, **kwargs):
super(AdaptiveEmbedding, self).__init__() super(TFAdaptiveEmbedding, self).__init__(**kwargs)
self.n_token = n_token self.n_token = n_token
self.d_embed = d_embed self.d_embed = d_embed
...@@ -536,188 +285,53 @@ class AdaptiveEmbedding(nn.Module): ...@@ -536,188 +285,53 @@ class AdaptiveEmbedding(nn.Module):
self.cutoff_ends = [0] + self.cutoffs self.cutoff_ends = [0] + self.cutoffs
self.emb_layers = nn.ModuleList() self.emb_layers = []
self.emb_projs = nn.ParameterList() self.emb_projs = []
if div_val == 1: if div_val == 1:
self.emb_layers.append( raise NotImplementedError # Removed these to avoid maintaining dead code - They are not used in our pretrained checkpoint
nn.Embedding(n_token, d_embed, sparse=sample_softmax>0)
)
if d_proj != d_embed:
self.emb_projs.append(nn.Parameter(torch.FloatTensor(d_proj, d_embed)))
else: else:
for i in range(len(self.cutoffs)): for i in range(len(self.cutoffs)):
l_idx, r_idx = self.cutoff_ends[i], self.cutoff_ends[i+1] l_idx, r_idx = self.cutoff_ends[i], self.cutoff_ends[i+1]
d_emb_i = d_embed // (div_val ** i) d_emb_i = d_embed // (div_val ** i)
self.emb_layers.append(nn.Embedding(r_idx-l_idx, d_emb_i)) self.emb_layers.append(tf.keras.layers.Embedding(r_idx-l_idx, d_emb_i, name='emb_layers_._{}'.format(i)))
self.emb_projs.append(nn.Parameter(torch.FloatTensor(d_proj, d_emb_i)))
def build(self, input_shape):
for i in range(len(self.cutoffs)):
d_emb_i = self.d_embed // (self.div_val ** i)
self.emb_projs.append(self.add_weight(shape=(d_emb_i, self.d_proj),
trainable=True,
name='emb_projs._{}'.format(i)))
super(TFAdaptiveEmbedding, self).build(input_shape)
def forward(self, inp): def call(self, inp):
if self.div_val == 1: if self.div_val == 1:
embed = self.emb_layers[0](inp) raise NotImplementedError # Removed these to avoid maintaining dead code - They are not used in our pretrained checkpoint
if self.d_proj != self.d_embed:
embed = F.linear(embed, self.emb_projs[0])
else: else:
param = next(self.parameters()) inp_flat = tf.reshape(inp, (-1,))
inp_flat = inp.view(-1) emb_flat = tf.zeros([shape_list(inp_flat)[0], self.d_proj])
emb_flat = torch.zeros([inp_flat.size(0), self.d_proj],
dtype=param.dtype, device=param.device)
for i in range(len(self.cutoffs)): for i in range(len(self.cutoffs)):
l_idx, r_idx = self.cutoff_ends[i], self.cutoff_ends[i + 1] l_idx, r_idx = self.cutoff_ends[i], self.cutoff_ends[i + 1]
mask_i = (inp_flat >= l_idx) & (inp_flat < r_idx) mask_i = (inp_flat >= l_idx) & (inp_flat < r_idx)
indices_i = mask_i.nonzero().squeeze()
if indices_i.numel() == 0:
continue
inp_i = inp_flat.index_select(0, indices_i) - l_idx inp_i = tf.boolean_mask(inp_flat, mask_i) - l_idx
emb_i = self.emb_layers[i](inp_i) emb_i = self.emb_layers[i](inp_i)
emb_i = F.linear(emb_i, self.emb_projs[i]) emb_i = tf.einsum('id,de->ie', emb_i, self.emb_projs[i])
emb_flat.index_copy_(0, indices_i, emb_i) mask_idx = tf.cast(tf.where(mask_i), dtype=tf.int64)
emb_flat += tf.scatter_nd(mask_idx, emb_i, tf.cast(tf.shape(emb_flat), dtype=tf.int64))
embed_shape = inp.size() + (self.d_proj,) embed_shape = shape_list(inp) + [self.d_proj]
embed = emb_flat.view(embed_shape) embed = tf.reshape(emb_flat, embed_shape)
embed.mul_(self.emb_scale) embed *= self.emb_scale
return embed return embed
class TransfoXLPreTrainedModel(PreTrainedModel): class TFTransfoXLMainLayer(tf.keras.layers.Layer):
""" An abstract class to handle weights initialization and def __init__(self, config, **kwargs):
a simple interface for dowloading and loading pretrained models. super(TFTransfoXLMainLayer, self).__init__(**kwargs)
"""
config_class = TransfoXLConfig
pretrained_model_archive_map = TRANSFO_XL_PRETRAINED_MODEL_ARCHIVE_MAP
load_tf_weights = load_tf_weights_in_transfo_xl
base_model_prefix = "transformer"
def _init_weight(self, weight):
if self.config.init == 'uniform':
nn.init.uniform_(weight, -self.config.init_range, self.config.init_range)
elif self.config.init == 'normal':
nn.init.normal_(weight, 0.0, self.config.init_std)
def _init_bias(self, bias):
nn.init.constant_(bias, 0.0)
def _init_weights(self, m):
""" Initialize the weights.
"""
classname = m.__class__.__name__
if classname.find('Linear') != -1:
if hasattr(m, 'weight') and m.weight is not None:
self._init_weight(m.weight)
if hasattr(m, 'bias') and m.bias is not None:
self._init_bias(m.bias)
elif classname.find('AdaptiveEmbedding') != -1:
if hasattr(m, 'emb_projs'):
for i in range(len(m.emb_projs)):
if m.emb_projs[i] is not None:
nn.init.normal_(m.emb_projs[i], 0.0, self.config.proj_init_std)
elif classname.find('Embedding') != -1:
if hasattr(m, 'weight'):
self._init_weight(m.weight)
elif classname.find('ProjectedAdaptiveLogSoftmax') != -1:
if hasattr(m, 'cluster_weight') and m.cluster_weight is not None:
self._init_weight(m.cluster_weight)
if hasattr(m, 'cluster_bias') and m.cluster_bias is not None:
self._init_bias(m.cluster_bias)
if hasattr(m, 'out_projs'):
for i in range(len(m.out_projs)):
if m.out_projs[i] is not None:
nn.init.normal_(m.out_projs[i], 0.0, self.config.proj_init_std)
elif classname.find('LayerNorm') != -1:
if hasattr(m, 'weight'):
nn.init.normal_(m.weight, 1.0, self.config.init_std)
if hasattr(m, 'bias') and m.bias is not None:
self._init_bias(m.bias)
else:
if hasattr(m, 'r_emb'):
self._init_weight(m.r_emb)
if hasattr(m, 'r_w_bias'):
self._init_weight(m.r_w_bias)
if hasattr(m, 'r_r_bias'):
self._init_weight(m.r_r_bias)
if hasattr(m, 'r_bias'):
self._init_bias(m.r_bias)
def set_num_special_tokens(self, num_special_tokens):
pass
TRANSFO_XL_START_DOCSTRING = r""" The Transformer-XL model was proposed in
`Transformer-XL: Attentive Language Models Beyond a Fixed-Length Context`_
by Zihang Dai*, Zhilin Yang*, Yiming Yang, Jaime Carbonell, Quoc V. Le, Ruslan Salakhutdinov.
It's a causal (uni-directional) transformer with relative positioning (sinusoïdal) embeddings which can reuse
previously computed hidden-states to attend to longer context (memory).
This model also uses adaptive softmax inputs and outputs (tied).
This model is a PyTorch `torch.nn.Module`_ sub-class. Use it as a regular PyTorch Module and
refer to the PyTorch documentation for all matter related to general usage and behavior.
.. _`Transformer-XL: Attentive Language Models Beyond a Fixed-Length Context`:
https://arxiv.org/abs/1901.02860
.. _`torch.nn.Module`:
https://pytorch.org/docs/stable/nn.html#module
Parameters:
config (:class:`~pytorch_transformers.TransfoXLConfig`): Model configuration class with all the parameters of the model.
Initializing with a config file does not load the weights associated with the model, only the configuration.
Check out the :meth:`~pytorch_transformers.PreTrainedModel.from_pretrained` method to load the model weights.
"""
TRANSFO_XL_INPUTS_DOCSTRING = r"""
Inputs:
**input_ids**: ``torch.LongTensor`` of shape ``(batch_size, sequence_length)``:
Indices of input sequence tokens in the vocabulary.
Transformer-XL is a model with relative position embeddings so you can either pad the inputs on
the right or on the left.
Indices can be obtained using :class:`pytorch_transformers.TransfoXLTokenizer`.
See :func:`pytorch_transformers.PreTrainedTokenizer.encode` and
:func:`pytorch_transformers.PreTrainedTokenizer.convert_tokens_to_ids` for details.
**mems**: (`optional`)
list of ``torch.FloatTensor`` (one for each layer):
that contains pre-computed hidden-states (key and values in the attention blocks) as computed by the model
(see `mems` output below). Can be used to speed up sequential decoding and attend to longer context.
**head_mask**: (`optional`) ``torch.FloatTensor`` of shape ``(num_heads,)`` or ``(num_layers, num_heads)``:
Mask to nullify selected heads of the self-attention modules.
Mask values selected in ``[0, 1]``:
``1`` indicates the head is **not masked**, ``0`` indicates the head is **masked**.
"""
@add_start_docstrings("The bare Bert Model transformer outputing raw hidden-states without any specific head on top.",
TRANSFO_XL_START_DOCSTRING, TRANSFO_XL_INPUTS_DOCSTRING)
class TransfoXLModel(TransfoXLPreTrainedModel):
r"""
Outputs: `Tuple` comprising various elements depending on the configuration (config) and inputs:
**last_hidden_state**: ``torch.FloatTensor`` of shape ``(batch_size, sequence_length, hidden_size)``
Sequence of hidden-states at the last layer of the model.
**mems**:
list of ``torch.FloatTensor`` (one for each layer):
that contains pre-computed hidden-states (key and values in the attention blocks) as computed by the model
(see `mems` input above). Can be used to speed up sequential decoding and attend to longer context.
**hidden_states**: (`optional`, returned when ``config.output_hidden_states=True``)
list of ``torch.FloatTensor`` (one for the output of each layer + the output of the embeddings)
of shape ``(batch_size, sequence_length, hidden_size)``:
Hidden-states of the model at the output of each layer plus the initial embedding outputs.
**attentions**: (`optional`, returned when ``config.output_attentions=True``)
list of ``torch.FloatTensor`` (one for each layer) of shape ``(batch_size, num_heads, sequence_length, sequence_length)``:
Attentions weights after the attention softmax, used to compute the weighted average in the self-attention heads.
Examples::
tokenizer = TransfoXLTokenizer.from_pretrained('transfo-xl-wt103')
model = TransfoXLModel.from_pretrained('transfo-xl-wt103')
input_ids = torch.tensor(tokenizer.encode("Hello, my dog is cute")).unsqueeze(0) # Batch size 1
outputs = model(input_ids)
last_hidden_states, mems = outputs[:2]
"""
def __init__(self, config):
super(TransfoXLModel, self).__init__(config)
self.output_attentions = config.output_attentions self.output_attentions = config.output_attentions
self.output_hidden_states = config.output_hidden_states self.output_hidden_states = config.output_hidden_states
...@@ -727,11 +341,12 @@ class TransfoXLModel(TransfoXLPreTrainedModel): ...@@ -727,11 +341,12 @@ class TransfoXLModel(TransfoXLPreTrainedModel):
self.d_model = config.d_model self.d_model = config.d_model
self.n_head = config.n_head self.n_head = config.n_head
self.d_head = config.d_head self.d_head = config.d_head
self.untie_r = config.untie_r
self.word_emb = AdaptiveEmbedding(config.n_token, config.d_embed, config.d_model, config.cutoffs, self.word_emb = TFAdaptiveEmbedding(config.n_token, config.d_embed, config.d_model, config.cutoffs,
div_val=config.div_val) div_val=config.div_val, name='word_emb')
self.drop = nn.Dropout(config.dropout) self.drop = tf.keras.layers.Dropout(config.dropout)
self.n_layer = config.n_layer self.n_layer = config.n_layer
...@@ -742,61 +357,41 @@ class TransfoXLModel(TransfoXLPreTrainedModel): ...@@ -742,61 +357,41 @@ class TransfoXLModel(TransfoXLPreTrainedModel):
self.attn_type = config.attn_type self.attn_type = config.attn_type
if not config.untie_r: self.layers = []
self.r_w_bias = nn.Parameter(torch.FloatTensor(self.n_head, self.d_head))
self.r_r_bias = nn.Parameter(torch.FloatTensor(self.n_head, self.d_head))
self.layers = nn.ModuleList()
if config.attn_type == 0: # the default attention if config.attn_type == 0: # the default attention
for i in range(config.n_layer): for i in range(config.n_layer):
self.layers.append( self.layers.append(
RelPartialLearnableDecoderLayer( TFRelPartialLearnableDecoderLayer(
config.n_head, config.d_model, config.d_head, config.d_inner, config.dropout,
tgt_len=config.tgt_len, ext_len=config.ext_len, mem_len=config.mem_len,
dropatt=config.dropatt, pre_lnorm=config.pre_lnorm,
r_w_bias=None if config.untie_r else self.r_w_bias,
r_r_bias=None if config.untie_r else self.r_r_bias,
output_attentions=self.output_attentions)
)
elif config.attn_type == 1: # learnable embeddings
for i in range(config.n_layer):
self.layers.append(
RelLearnableDecoderLayer(
config.n_head, config.d_model, config.d_head, config.d_inner, config.dropout, config.n_head, config.d_model, config.d_head, config.d_inner, config.dropout,
tgt_len=config.tgt_len, ext_len=config.ext_len, mem_len=config.mem_len, tgt_len=config.tgt_len, ext_len=config.ext_len, mem_len=config.mem_len,
dropatt=config.dropatt, pre_lnorm=config.pre_lnorm, dropatt=config.dropatt, pre_lnorm=config.pre_lnorm,
r_w_bias=None if config.untie_r else self.r_w_bias, r_w_bias=None if self.untie_r else self.r_w_bias,
r_r_bias=None if config.untie_r else self.r_r_bias, r_r_bias=None if self.untie_r else self.r_r_bias,
output_attentions=self.output_attentions) output_attentions=self.output_attentions,
) name='layers_._{}'.format(i))
elif config.attn_type in [2, 3]: # absolute embeddings
for i in range(config.n_layer):
self.layers.append(
DecoderLayer(
config.n_head, config.d_model, config.d_head, config.d_inner, config.dropout,
dropatt=config.dropatt, pre_lnorm=config.pre_lnorm,
r_w_bias=None if config.untie_r else self.r_w_bias,
r_r_bias=None if config.untie_r else self.r_r_bias,
output_attentions=self.output_attentions)
) )
else: # learnable embeddings and absolute embeddings
raise NotImplementedError # Removed these to avoid maintaining dead code - They are not used in our pretrained checkpoint
self.same_length = config.same_length self.same_length = config.same_length
self.clamp_len = config.clamp_len self.clamp_len = config.clamp_len
if self.attn_type == 0: # default attention if self.attn_type == 0: # default attention
self.pos_emb = PositionalEmbedding(self.d_model) self.pos_emb = TFPositionalEmbedding(self.d_model, name='pos_emb')
elif self.attn_type == 1: # learnable else: # learnable embeddings and absolute embeddings
self.r_emb = nn.Parameter(torch.FloatTensor( raise NotImplementedError # Removed these to avoid maintaining dead code - They are not used in our pretrained checkpoint
self.n_layer, self.max_klen, self.n_head, self.d_head))
self.r_bias = nn.Parameter(torch.FloatTensor( def build(self, input_shape):
self.n_layer, self.max_klen, self.n_head)) if not self.untie_r:
elif self.attn_type == 2: # absolute standard self.r_w_bias = self.add_weight(shape=(self.n_head, self.d_head),
self.pos_emb = PositionalEmbedding(self.d_model) initializer='zeros',
elif self.attn_type == 3: # absolute deeper SA trainable=True,
self.r_emb = nn.Parameter(torch.FloatTensor( name='r_w_bias')
self.n_layer, self.max_klen, self.n_head, self.d_head)) self.r_r_bias = self.add_weight(shape=(self.n_head, self.d_head),
initializer='zeros',
self.init_weights() trainable=True,
name='r_r_bias')
super(TFTransfoXLMainLayer, self).build(input_shape)
def _resize_token_embeddings(self, new_num_tokens): def _resize_token_embeddings(self, new_num_tokens):
return self.word_emb return self.word_emb
...@@ -810,16 +405,13 @@ class TransfoXLModel(TransfoXLPreTrainedModel): ...@@ -810,16 +405,13 @@ class TransfoXLModel(TransfoXLPreTrainedModel):
self.ext_len = ext_len self.ext_len = ext_len
def _prune_heads(self, heads): def _prune_heads(self, heads):
logger.info("Head pruning is not implemented for Transformer-XL model") raise NotImplementedError
pass
def init_mems(self, data): def init_mems(self, data):
if self.mem_len > 0: if self.mem_len > 0:
mems = [] mems = []
param = next(self.parameters())
for i in range(self.n_layer): for i in range(self.n_layer):
empty = torch.zeros(self.mem_len, data.size(1), self.config.d_model, empty = tf.zeros([self.mem_len, shape_list(data)[1], self.d_model])
dtype=param.dtype, device=param.device)
mems.append(empty) mems.append(empty)
return mems return mems
...@@ -838,164 +430,211 @@ class TransfoXLModel(TransfoXLPreTrainedModel): ...@@ -838,164 +430,211 @@ class TransfoXLModel(TransfoXLPreTrainedModel):
# will be used as the extended context. Hence, we only cache # will be used as the extended context. Hence, we only cache
# the tokens from `mlen + qlen - self.ext_len - self.mem_len` # the tokens from `mlen + qlen - self.ext_len - self.mem_len`
# to `mlen + qlen - self.ext_len`. # to `mlen + qlen - self.ext_len`.
with torch.no_grad(): new_mems = []
new_mems = [] end_idx = mlen + max(0, qlen - 0 - self.ext_len)
end_idx = mlen + max(0, qlen - 0 - self.ext_len) beg_idx = max(0, end_idx - self.mem_len)
beg_idx = max(0, end_idx - self.mem_len) for i in range(len(hids)):
for i in range(len(hids)):
cat = torch.cat([mems[i], hids[i]], dim=0) cat = tf.concat([mems[i], hids[i]], axis=0)
new_mems.append(cat[beg_idx:end_idx].detach()) tf.stop_gradient(cat)
new_mems.append(cat[beg_idx:end_idx])
return new_mems return new_mems
def _forward(self, dec_inp, mems=None, head_mask=None): def call(self, inputs, training=False):
qlen, bsz = dec_inp.size() if not isinstance(inputs, (dict, tuple, list)):
input_ids = inputs
mems, head_mask = None, None
elif isinstance(inputs, (tuple, list)):
input_ids = inputs[0]
mems = inputs[1] if len(inputs) > 1 else None
head_mask = inputs[2] if len(inputs) > 2 else None
assert len(inputs) <= 3, "Too many inputs."
else:
input_ids = inputs.get('input_ids')
mems = inputs.get('mems', None)
head_mask = inputs.get('head_mask', None)
assert len(inputs) <= 3, "Too many inputs."
# the original code for Transformer-XL used shapes [len, bsz] but we want a unified interface in the library
# so we transpose here from shape [bsz, len] to shape [len, bsz]
input_ids = tf.transpose(input_ids, perm=(1, 0))
if mems is None:
mems = self.init_mems(input_ids)
qlen, bsz = shape_list(input_ids)
# Prepare head mask if needed # Prepare head mask if needed
# 1.0 in head_mask indicate we keep the head # 1.0 in head_mask indicate we keep the head
# attention_probs has shape bsz x n_heads x N x N # attention_probs has shape bsz x n_heads x N x N
# input head_mask has shape [num_heads] or [num_hidden_layers x num_heads] (a head_mask for each layer) # input head_mask has shape [num_heads] or [num_hidden_layers x num_heads] (a head_mask for each layer)
# and head_mask is converted to shape [num_hidden_layers x qlen x klen x bsz x n_head] # and head_mask is converted to shape [num_hidden_layers x qlen x klen x bsz x n_head]
if head_mask is not None: if not head_mask is None:
if head_mask.dim() == 1: raise NotImplementedError
head_mask = head_mask.unsqueeze(0).unsqueeze(0).unsqueeze(0).unsqueeze(0)
head_mask = head_mask.expand(self.n_layer, -1, -1, -1, -1)
elif head_mask.dim() == 2:
head_mask = head_mask.unsqueeze(1).unsqueeze(1).unsqueeze(1)
head_mask = head_mask.to(dtype=next(self.parameters()).dtype) # switch to fload if need + fp16 compatibility
else: else:
head_mask = [None] * self.n_layer head_mask = [None] * self.n_layer
word_emb = self.word_emb(dec_inp) word_emb = self.word_emb(input_ids)
mlen = mems[0].size(0) if mems is not None else 0 mlen = shape_list(mems[0])[0] if mems is not None else 0
klen = mlen + qlen klen = mlen + qlen
attn_mask = tf.ones([qlen, qlen])
mask_u = tf.linalg.band_part(attn_mask, 0, -1)
mask_dia = tf.linalg.band_part(attn_mask, 0, 0)
attn_mask_pad = tf.zeros([qlen, mlen])
dec_attn_mask = tf.concat([attn_mask_pad, mask_u - mask_dia], 1)
if self.same_length: if self.same_length:
all_ones = word_emb.new_ones((qlen, klen), dtype=torch.uint8) mask_l = tf.linalg.band_part(attn_mask, -1, 0)
mask_len = klen - self.mem_len dec_attn_mask = tf.concat([dec_attn_mask[:, :qlen] + mask_l - mask_dia,
if mask_len > 0: dec_attn_mask[:, qlen:]], 1)
mask_shift_len = qlen - mask_len # ::: PyTorch masking code for reference :::
else: # if self.same_length:
mask_shift_len = qlen # all_ones = word_emb.new_ones((qlen, klen), dtype=torch.uint8)
dec_attn_mask = (torch.triu(all_ones, 1+mlen) # mask_len = klen - self.mem_len
+ torch.tril(all_ones, -mask_shift_len))[:, :, None] # -1 # if mask_len > 0:
else: # mask_shift_len = qlen - mask_len
dec_attn_mask = torch.triu( # else:
word_emb.new_ones((qlen, klen), dtype=torch.uint8), diagonal=1+mlen)[:,:,None] # mask_shift_len = qlen
# dec_attn_mask = (torch.triu(all_ones, 1+mlen)
# + torch.tril(all_ones, -mask_shift_len))[:, :, None] # -1
# else:
# dec_attn_mask = torch.triu(
# word_emb.new_ones((qlen, klen), dtype=torch.uint8), diagonal=1+mlen)[:,:,None]
hids = [] hids = []
attentions = [] attentions = []
if self.attn_type == 0: # default if self.attn_type == 0: # default
pos_seq = torch.arange(klen-1, -1, -1.0, device=word_emb.device, pos_seq = tf.range(klen-1, -1, -1.0)
dtype=word_emb.dtype)
if self.clamp_len > 0:
pos_seq.clamp_(max=self.clamp_len)
pos_emb = self.pos_emb(pos_seq)
core_out = self.drop(word_emb)
pos_emb = self.drop(pos_emb)
for i, layer in enumerate(self.layers):
hids.append(core_out)
mems_i = None if mems is None else mems[i]
layer_outputs = layer(core_out, pos_emb, dec_attn_mask=dec_attn_mask,
mems=mems_i, head_mask=head_mask[i])
core_out = layer_outputs[0]
if self.output_attentions:
attentions.append(layer_outputs[1])
elif self.attn_type == 1: # learnable
core_out = self.drop(word_emb)
for i, layer in enumerate(self.layers):
hids.append(core_out)
if self.clamp_len > 0:
r_emb = self.r_emb[i][-self.clamp_len :]
r_bias = self.r_bias[i][-self.clamp_len :]
else:
r_emb, r_bias = self.r_emb[i], self.r_bias[i]
mems_i = None if mems is None else mems[i]
layer_outputs = layer(core_out, r_emb, self.r_w_bias[i],
r_bias, dec_attn_mask=dec_attn_mask,
mems=mems_i, head_mask=head_mask[i])
core_out = layer_outputs[0]
if self.output_attentions:
attentions.append(layer_outputs[1])
elif self.attn_type == 2: # absolute
pos_seq = torch.arange(klen - 1, -1, -1.0, device=word_emb.device,
dtype=word_emb.dtype)
if self.clamp_len > 0: if self.clamp_len > 0:
pos_seq.clamp_(max=self.clamp_len) pos_seq = tf.minimum(pos_seq, self.clamp_len)
pos_emb = self.pos_emb(pos_seq) pos_emb = self.pos_emb(pos_seq)
core_out = self.drop(word_emb + pos_emb[-qlen:]) core_out = self.drop(word_emb, training=training)
pos_emb = self.drop(pos_emb, training=training)
for i, layer in enumerate(self.layers):
hids.append(core_out)
mems_i = None if mems is None else mems[i]
if mems_i is not None and i == 0:
mems_i += pos_emb[:mlen]
layer_outputs = layer(core_out, dec_attn_mask=dec_attn_mask,
mems=mems_i, head_mask=head_mask[i])
core_out = layer_outputs[0]
if self.output_attentions:
attentions.append(layer_outputs[1])
elif self.attn_type == 3:
core_out = self.drop(word_emb)
for i, layer in enumerate(self.layers): for i, layer in enumerate(self.layers):
hids.append(core_out) hids.append(core_out)
mems_i = None if mems is None else mems[i] mems_i = None if mems is None else mems[i]
if mems_i is not None and mlen > 0: layer_outputs = layer([core_out, pos_emb, dec_attn_mask,
cur_emb = self.r_emb[i][:-qlen] mems_i, head_mask[i]], training=training)
cur_size = cur_emb.size(0)
if cur_size < mlen:
cur_emb_pad = cur_emb[0:1].expand(mlen-cur_size, -1, -1)
cur_emb = torch.cat([cur_emb_pad, cur_emb], 0)
else:
cur_emb = cur_emb[-mlen:]
mems_i += cur_emb.view(mlen, 1, -1)
core_out += self.r_emb[i][-qlen:].view(qlen, 1, -1)
layer_outputs = layer(core_out, dec_attn_mask=dec_attn_mask,
mems=mems_i, head_mask=head_mask[i])
core_out = layer_outputs[0] core_out = layer_outputs[0]
if self.output_attentions: if self.output_attentions:
attentions.append(layer_outputs[1]) attentions.append(layer_outputs[1])
else: # learnable embeddings and absolute embeddings
raise NotImplementedError # Removed these to avoid maintaining dead code - They are not used in our pretrained checkpoint
core_out = self.drop(core_out) core_out = self.drop(core_out, training=training)
new_mems = self._update_mems(hids, mems, mlen, qlen) new_mems = self._update_mems(hids, mems, mlen, qlen)
# We transpose back here to shape [bsz, len, hidden_dim] # We transpose back here to shape [bsz, len, hidden_dim]
outputs = [core_out.transpose(0, 1).contiguous(), new_mems] outputs = [tf.transpose(core_out, perm=(1, 0, 2)), new_mems]
if self.output_hidden_states: if self.output_hidden_states:
# Add last layer and transpose to library standard shape [bsz, len, hidden_dim] # Add last layer and transpose to library standard shape [bsz, len, hidden_dim]
hids.append(core_out) hids.append(core_out)
hids = list(t.transpose(0, 1).contiguous() for t in hids) hids = list(tf.transpose(t, perm=(1, 0, 2)) for t in hids)
outputs.append(hids) outputs.append(hids)
if self.output_attentions: if self.output_attentions:
# Transpose to library standard shape [bsz, n_heads, query_seq_len, key_seq_len] # Transpose to library standard shape [bsz, n_heads, query_seq_len, key_seq_len]
attentions = list(t.permute(2, 3, 0, 1).contiguous() for t in attentions) attentions = list(tf.transpose(t, perm=(2, 3, 0, 1)) for t in attentions)
outputs.append(attentions) outputs.append(attentions)
return outputs # last hidden state, new_mems, (all hidden states), (all attentions) return outputs # last hidden state, new_mems, (all hidden states), (all attentions)
def forward(self, input_ids, mems=None, head_mask=None):
# the original code for Transformer-XL used shapes [len, bsz] but we want a unified interface in the library
# so we transpose here from shape [bsz, len] to shape [len, bsz]
input_ids = input_ids.transpose(0, 1).contiguous()
if mems is None: class TFTransfoXLPreTrainedModel(TFPreTrainedModel):
mems = self.init_mems(input_ids) """ An abstract class to handle weights initialization and
outputs = self._forward(input_ids, mems=mems, head_mask=head_mask) a simple interface for dowloading and loading pretrained models.
"""
config_class = TransfoXLConfig
pretrained_model_archive_map = TF_TRANSFO_XL_PRETRAINED_MODEL_ARCHIVE_MAP
load_pt_weights = load_transfo_xl_pt_weights_in_tf2
base_model_prefix = "transformer"
return outputs # last hidden state, new_mems, (all hidden states), (all attentions)
TRANSFO_XL_START_DOCSTRING = r""" The Transformer-XL model was proposed in
`Transformer-XL: Attentive Language Models Beyond a Fixed-Length Context`_
by Zihang Dai*, Zhilin Yang*, Yiming Yang, Jaime Carbonell, Quoc V. Le, Ruslan Salakhutdinov.
It's a causal (uni-directional) transformer with relative positioning (sinusoïdal) embeddings which can reuse
previously computed hidden-states to attend to longer context (memory).
This model also uses adaptive softmax inputs and outputs (tied).
This model is a PyTorch `torch.tf.keras.layers.Layer`_ sub-class. Use it as a regular PyTorch Module and
refer to the PyTorch documentation for all matter related to general usage and behavior.
.. _`Transformer-XL: Attentive Language Models Beyond a Fixed-Length Context`:
https://arxiv.org/abs/1901.02860
.. _`torch.tf.keras.layers.Layer`:
https://pytorch.org/docs/stable/nn.html#module
Parameters:
config (:class:`~pytorch_transformers.TransfoXLConfig`): Model configuration class with all the parameters of the model.
Initializing with a config file does not load the weights associated with the model, only the configuration.
Check out the :meth:`~pytorch_transformers.PreTrainedModel.from_pretrained` method to load the model weights.
"""
TRANSFO_XL_INPUTS_DOCSTRING = r"""
Inputs:
**input_ids**: ``torch.LongTensor`` of shape ``(batch_size, sequence_length)``:
Indices of input sequence tokens in the vocabulary.
Transformer-XL is a model with relative position embeddings so you can either pad the inputs on
the right or on the left.
Indices can be obtained using :class:`pytorch_transformers.TransfoXLTokenizer`.
See :func:`pytorch_transformers.PreTrainedTokenizer.encode` and
:func:`pytorch_transformers.PreTrainedTokenizer.convert_tokens_to_ids` for details.
**mems**: (`optional`)
list of ``torch.FloatTensor`` (one for each layer):
that contains pre-computed hidden-states (key and values in the attention blocks) as computed by the model
(see `mems` output below). Can be used to speed up sequential decoding and attend to longer context.
**head_mask**: (`optional`) ``torch.FloatTensor`` of shape ``(num_heads,)`` or ``(num_layers, num_heads)``:
Mask to nullify selected heads of the self-attention modules.
Mask values selected in ``[0, 1]``:
``1`` indicates the head is **not masked**, ``0`` indicates the head is **masked**.
"""
@add_start_docstrings("The bare Bert Model transformer outputing raw hidden-states without any specific head on top.",
TRANSFO_XL_START_DOCSTRING, TRANSFO_XL_INPUTS_DOCSTRING)
class TFTransfoXLModel(TFTransfoXLPreTrainedModel):
r"""
Outputs: `Tuple` comprising various elements depending on the configuration (config) and inputs:
**last_hidden_state**: ``torch.FloatTensor`` of shape ``(batch_size, sequence_length, hidden_size)``
Sequence of hidden-states at the last layer of the model.
**mems**:
list of ``torch.FloatTensor`` (one for each layer):
that contains pre-computed hidden-states (key and values in the attention blocks) as computed by the model
(see `mems` input above). Can be used to speed up sequential decoding and attend to longer context.
**hidden_states**: (`optional`, returned when ``config.output_hidden_states=True``)
list of ``torch.FloatTensor`` (one for the output of each layer + the output of the embeddings)
of shape ``(batch_size, sequence_length, hidden_size)``:
Hidden-states of the model at the output of each layer plus the initial embedding outputs.
**attentions**: (`optional`, returned when ``config.output_attentions=True``)
list of ``torch.FloatTensor`` (one for each layer) of shape ``(batch_size, num_heads, sequence_length, sequence_length)``:
Attentions weights after the attention softmax, used to compute the weighted average in the self-attention heads.
Examples::
tokenizer = TransfoXLTokenizer.from_pretrained('transfo-xl-wt103')
model = TransfoXLModel.from_pretrained('transfo-xl-wt103')
input_ids = torch.tensor(tokenizer.encode("Hello, my dog is cute")).unsqueeze(0) # Batch size 1
outputs = model(input_ids)
last_hidden_states, mems = outputs[:2]
"""
def __init__(self, config, *inputs, **kwargs):
super(TFTransfoXLModel, self).__init__(config, *inputs, **kwargs)
self.transformer = TFTransfoXLMainLayer(config, name='transformer')
def call(self, inputs, training=False, **kwargs):
outputs = self.transformer(inputs, training=training, **kwargs)
return outputs
@add_start_docstrings("""The Transformer-XL Model with a language modeling head on top @add_start_docstrings("""The Transformer-XL Model with a language modeling head on top
(adaptive softmax with weights tied to the adaptive input embeddings)""", (adaptive softmax with weights tied to the adaptive input embeddings)""",
TRANSFO_XL_START_DOCSTRING, TRANSFO_XL_INPUTS_DOCSTRING) TRANSFO_XL_START_DOCSTRING, TRANSFO_XL_INPUTS_DOCSTRING)
class TransfoXLLMHeadModel(TransfoXLPreTrainedModel): class TFTransfoXLLMHeadModel(TFTransfoXLPreTrainedModel):
r""" r"""
**lm_labels**: (`optional`) ``torch.LongTensor`` of shape ``(batch_size, sequence_length)``: **lm_labels**: (`optional`) ``torch.LongTensor`` of shape ``(batch_size, sequence_length)``:
Labels for language modeling. Labels for language modeling.
...@@ -1032,46 +671,16 @@ class TransfoXLLMHeadModel(TransfoXLPreTrainedModel): ...@@ -1032,46 +671,16 @@ class TransfoXLLMHeadModel(TransfoXLPreTrainedModel):
""" """
def __init__(self, config): def __init__(self, config):
super(TransfoXLLMHeadModel, self).__init__(config) super(TFTransfoXLLMHeadModel, self).__init__(config)
self.transformer = TransfoXLModel(config) self.transformer = TFTransfoXLMainLayer(config, name='transformer')
self.sample_softmax = config.sample_softmax self.sample_softmax = config.sample_softmax
# use sampled softmax # use sampled softmax
if config.sample_softmax > 0: if config.sample_softmax > 0:
self.out_layer = nn.Linear(config.d_model, config.n_token) raise NotImplementedError
self.sampler = LogUniformSampler(config.n_token, config.sample_softmax)
# use adaptive softmax (including standard softmax) # use adaptive softmax (including standard softmax)
else: else:
self.crit = ProjectedAdaptiveLogSoftmax(config.n_token, config.d_embed, config.d_model, self.crit = TFAdaptiveSoftmaxMask(config.n_token, config.d_embed, config.d_model,
config.cutoffs, div_val=config.div_val) config.cutoffs, div_val=config.div_val, name='crit')
self.init_weights()
self.tie_weights()
def tie_weights(self):
"""
Run this to be sure output and input (adaptive) softmax weights are tied
"""
# sampled softmax
if self.sample_softmax > 0:
if self.config.tie_weight:
self.out_layer.weight = self.transformer.word_emb.weight
# adaptive softmax (including standard softmax)
else:
if self.config.tie_weight:
for i in range(len(self.crit.out_layers)):
self._tie_or_clone_weights(self.crit.out_layers[i],
self.transformer.word_emb.emb_layers[i])
if self.config.tie_projs:
for i, tie_proj in enumerate(self.config.tie_projs):
if tie_proj and self.config.div_val == 1 and self.config.d_model != self.config.d_embed:
if self.config.torchscript:
self.crit.out_projs[i] = nn.Parameter(self.transformer.word_emb.emb_projs[0].clone())
else:
self.crit.out_projs[i] = self.transformer.word_emb.emb_projs[0]
elif tie_proj and self.config.div_val != 1:
if self.config.torchscript:
self.crit.out_projs[i] = nn.Parameter(self.transformer.word_emb.emb_projs[i].clone())
else:
self.crit.out_projs[i] = self.transformer.word_emb.emb_projs[i]
def reset_length(self, tgt_len, ext_len, mem_len): def reset_length(self, tgt_len, ext_len, mem_len):
self.transformer.reset_length(tgt_len, ext_len, mem_len) self.transformer.reset_length(tgt_len, ext_len, mem_len)
...@@ -1079,30 +688,36 @@ class TransfoXLLMHeadModel(TransfoXLPreTrainedModel): ...@@ -1079,30 +688,36 @@ class TransfoXLLMHeadModel(TransfoXLPreTrainedModel):
def init_mems(self, data): def init_mems(self, data):
return self.transformer.init_mems(data) return self.transformer.init_mems(data)
def forward(self, input_ids, mems=None, head_mask=None, labels=None): def call(self, inputs, training=False):
bsz = input_ids.size(0) if not isinstance(inputs, (dict, tuple, list)):
tgt_len = input_ids.size(1) input_ids = inputs
mems, head_mask, labels = None, None, None
elif isinstance(inputs, (tuple, list)):
input_ids = inputs[0]
mems = inputs[1] if len(inputs) > 1 else None
head_mask = inputs[2] if len(inputs) > 2 else None
labels = inputs[3] if len(inputs) > 3 else None
assert len(inputs) <= 4, "Too many inputs."
else:
input_ids = inputs.get('input_ids')
mems = inputs.get('mems', None)
head_mask = inputs.get('head_mask', None)
labels = inputs.get('labels', None)
assert len(inputs) <= 4, "Too many inputs."
bsz, tgt_len = shape_list(input_ids)[:2]
transformer_outputs = self.transformer(input_ids, mems=mems, head_mask=head_mask) transformer_outputs = self.transformer([input_ids, mems, head_mask], training=training)
last_hidden = transformer_outputs[0] last_hidden = transformer_outputs[0]
pred_hid = last_hidden[:, -tgt_len:] pred_hid = last_hidden[:, -tgt_len:]
outputs = transformer_outputs[1:] outputs = transformer_outputs[1:]
if self.sample_softmax > 0 and self.training: if self.sample_softmax > 0 and training:
assert self.config.tie_weight raise NotImplementedError
logit = sample_logits(self.transformer.word_emb, self.out_layer.bias, labels, pred_hid, self.sampler)
softmax_output = -F.log_softmax(logit, -1)[:, :, 0]
outputs = [softmax_output] + outputs
if labels is not None:
# TODO: This is not implemented
raise NotImplementedError
else: else:
softmax_output = self.crit(pred_hid.view(-1, pred_hid.size(-1)), labels) # pred_hid = tf.reshape(pred_hid, (-1, shape_list(pred_hid)[-1]))
if labels is None: softmax_output = self.crit([pred_hid, labels], training=training)
softmax_output = softmax_output.view(bsz, tgt_len, -1) # softmax_output = tf.reshape(softmax_output, (bsz, tgt_len, -1))
outputs = [softmax_output] + outputs outputs = [softmax_output] + outputs
else:
softmax_output = softmax_output.view(bsz, tgt_len)
outputs = [softmax_output, None] + outputs
return outputs # (loss), logits or None if labels is not None (speed up adaptive softmax), new_mems, (all hidden states), (all attentions) return outputs # logits, new_mems, (all hidden states), (all attentions)
# coding=utf-8
# Copyright 2018 Google AI, Google Brain and Carnegie Mellon University Authors and the HuggingFace Inc. team.
# Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
""" Utilities for PyTorch Transformer XL model.
Directly adapted from https://github.com/kimiyoung/transformer-xl.
"""
from collections import defaultdict
import numpy as np
import tensorflow as tf
from .modeling_tf_utils import shape_list
class TFAdaptiveSoftmaxMask(tf.keras.layers.Layer):
def __init__(self, n_token, d_embed, d_proj, cutoffs, div_val=1,
keep_order=False, **kwargs):
super(TFAdaptiveSoftmaxMask, self).__init__(**kwargs)
self.n_token = n_token
self.d_embed = d_embed
self.d_proj = d_proj
self.cutoffs = cutoffs + [n_token]
self.cutoff_ends = [0] + self.cutoffs
self.div_val = div_val
self.shortlist_size = self.cutoffs[0]
self.n_clusters = len(self.cutoffs) - 1
self.head_size = self.shortlist_size + self.n_clusters
self.keep_order = keep_order
self.out_layers = []
self.out_projs = []
def build(self, input_shape):
if self.n_clusters > 0:
self.cluster_weight = self.add_weight(shape=(self.n_clusters, self.d_embed),
initializer='zeros',
trainable=True,
name='cluster_weight')
self.cluster_bias = self.add_weight(shape=(self.n_clusters,),
initializer='zeros',
trainable=True,
name='cluster_bias')
if self.div_val == 1:
for i in range(len(self.cutoffs)):
if self.d_proj != self.d_embed:
weight = self.add_weight(shape=(self.d_embed, self.d_proj),
initializer='zeros',
trainable=True,
name='out_projs_._{}'.format(i))
self.out_projs.append(weight)
else:
self.out_projs.append(None)
weight = self.add_weight(shape=(self.n_token, self.d_embed,),
initializer='zeros',
trainable=True,
name='out_layers_._{}_._weight'.format(i))
bias = self.add_weight(shape=(self.n_token,),
initializer='zeros',
trainable=True,
name='out_layers_._{}_._bias'.format(i))
self.out_layers.append((weight, bias))
else:
for i in range(len(self.cutoffs)):
l_idx, r_idx = self.cutoff_ends[i], self.cutoff_ends[i+1]
d_emb_i = self.d_embed // (self.div_val ** i)
weight = self.add_weight(shape=(d_emb_i, self.d_proj),
initializer='zeros',
trainable=True,
name='out_projs_._{}'.format(i))
self.out_projs.append(weight)
weight = self.add_weight(shape=(r_idx-l_idx, d_emb_i,),
initializer='zeros',
trainable=True,
name='out_layers_._{}_._weight'.format(i))
bias = self.add_weight(shape=(r_idx-l_idx,),
initializer='zeros',
trainable=True,
name='out_layers_._{}_._bias'.format(i))
self.out_layers.append((weight, bias))
super(TFAdaptiveSoftmaxMask, self).build(input_shape)
@staticmethod
def _logit(x, W, b, proj=None):
y = x
if proj is not None:
y = tf.einsum('ibd,ed->ibe', y, proj)
return tf.einsum('ibd,nd->ibn', y, W) + b
@staticmethod
def _gather_logprob(logprob, target):
lp_size = tf.shape(logprob)
r = tf.range(lp_size[0])
idx = tf.stack([r, target], 1)
return tf.gather_nd(logprob, idx)
def call(self, inputs, return_mean=True, training=False):
hidden, target = inputs
head_logprob = 0
if self.n_clusters == 0:
softmax_b = tf.get_variable('bias', [n_token], initializer=tf.zeros_initializer())
output = self._logit(hidden, self.out_layers[0][0], self.out_layers[0][1], self.out_projs[0])
if target is not None:
loss = tf.nn.sparse_softmax_cross_entropy_with_logits(labels=target, logits=output)
out = tf.nn.log_softmax(output, axis=-1)
else:
hidden_sizes = shape_list(hidden)
out = []
loss = tf.zeros(hidden_sizes[:2], dtype=tf.float32)
for i in range(len(self.cutoffs)):
l_idx, r_idx = self.cutoff_ends[i], self.cutoff_ends[i + 1]
if target is not None:
mask = (target >= l_idx) & (target < r_idx)
mask_idx = tf.where(mask)
cur_target = tf.boolean_mask(target, mask) - l_idx
if self.div_val == 1:
cur_W = self.out_layers[0][0][l_idx:r_idx]
cur_b = self.out_layers[0][1][l_idx:r_idx]
else:
cur_W = self.out_layers[i][0]
cur_b = self.out_layers[i][1]
if i == 0:
cur_W = tf.concat([cur_W, self.cluster_weight], 0)
cur_b = tf.concat([cur_b, self.cluster_bias], 0)
head_logit = self._logit(hidden, cur_W, cur_b, self.out_projs[0])
head_logprob = tf.nn.log_softmax(head_logit)
out.append(head_logprob[..., :self.cutoffs[0]])
if target is not None:
cur_head_logprob = tf.boolean_mask(head_logprob, mask)
cur_logprob = self._gather_logprob(cur_head_logprob, cur_target)
else:
tail_logit = self._logit(hidden, cur_W, cur_b, self.out_projs[i])
tail_logprob = tf.nn.log_softmax(tail_logit)
cluster_prob_idx = self.cutoffs[0] + i - 1 # No probability for the head cluster
logprob_i = head_logprob[..., cluster_prob_idx, None] + tail_logprob
out.append(logprob_i)
if target is not None:
cur_head_logprob = tf.boolean_mask(head_logprob, mask)
cur_tail_logprob = tf.boolean_mask(tail_logprob, mask)
cur_logprob = self._gather_logprob(cur_tail_logprob, cur_target)
cur_logprob += cur_head_logprob[:, self.cutoff_ends[1] + i - 1]
if target is not None:
loss += tf.scatter_nd(mask_idx, -cur_logprob, tf.cast(tf.shape(loss), dtype=tf.int64))
out = tf.concat(out, axis=-1)
if target is not None:
if return_mean:
loss = tf.reduce_mean(loss)
# Add the training-time loss value to the layer using `self.add_loss()`.
self.add_loss(loss)
# Log the loss as a metric (we could log arbitrary metrics,
# including different metrics for training and inference.
self.add_metric(loss, name=self.name, aggregation='mean' if return_mean else '')
return out
def mul_adaptive_logsoftmax(hidden, target, n_token, d_embed, d_proj, cutoffs,
params, tie_projs,
initializer=None, proj_initializer=None,
div_val=1, perms=None, proj_same_dim=True,
scope='adaptive_softmax',
**kwargs):
def _logit(x, W, b, proj):
y = x
if x.shape.ndims == 3:
if proj is not None:
y = tf.einsum('ibd,ed->ibe', y, proj)
return tf.einsum('ibd,nd->ibn', y, W) + b
else:
if proj is not None:
y = tf.einsum('id,ed->ie', y, proj)
return tf.einsum('id,nd->in', y, W) + b
params_W, params_projs = params[0], params[1]
with tf.variable_scope(scope):
if len(cutoffs) == 0:
softmax_b = tf.get_variable('bias', [n_token],
initializer=tf.zeros_initializer())
output = _logit(hidden, params_W, softmax_b, params_projs)
nll = tf.nn.sparse_softmax_cross_entropy_with_logits(labels=target,
logits=output)
nll = tf.reduce_mean(nll)
else:
total_loss, total_cnt = 0, 0
cutoff_ends = [0] + cutoffs + [n_token]
for i in range(len(cutoff_ends) - 1):
with tf.variable_scope('cutoff_{}'.format(i)):
l_idx, r_idx = cutoff_ends[i], cutoff_ends[i + 1]
cur_d_embed = d_embed // (div_val ** i)
if div_val == 1:
cur_W = params_W[l_idx: r_idx]
else:
cur_W = params_W[i]
cur_b = tf.get_variable('b', [r_idx - l_idx],
initializer=tf.zeros_initializer())
if tie_projs[i]:
if div_val == 1:
cur_proj = params_projs
else:
cur_proj = params_projs[i]
else:
if (div_val == 1 or not proj_same_dim) and d_proj == cur_d_embed:
cur_proj = None
else:
cur_proj = tf.get_variable('proj', [cur_d_embed, d_proj],
initializer=proj_initializer)
if i == 0:
cluster_W = tf.get_variable('cluster_W', [len(cutoffs), d_embed],
initializer=tf.zeros_initializer())
cluster_b = tf.get_variable('cluster_b', [len(cutoffs)],
initializer=tf.zeros_initializer())
cur_W = tf.concat([cur_W, cluster_W], 0)
cur_b = tf.concat([cur_b, cluster_b], 0)
head_logit = _logit(hidden, cur_W, cur_b, cur_proj)
head_target = kwargs.get("head_target")
head_nll = tf.nn.sparse_softmax_cross_entropy_with_logits(
labels=head_target,
logits=head_logit)
masked_loss = head_nll * perms[i]
total_loss += tf.reduce_sum(masked_loss)
total_cnt += tf.reduce_sum(perms[i])
# head_logprob = tf.nn.log_softmax(head_logit)
# final_logprob = head_logprob * perms[i][:, :, None]
# final_target = tf.one_hot(target, tf.shape(head_logprob)[2])
# total_loss -= tf.einsum('ibn,ibn->', final_logprob, final_target)
# total_cnt += tf.reduce_sum(perms[i])
else:
cur_head_nll = tf.einsum('ib,ibk->k', head_nll, perms[i])
cur_hidden = tf.einsum('ibd,ibk->kd', hidden, perms[i])
tail_logit = _logit(cur_hidden, cur_W, cur_b, cur_proj)
tail_target = tf.einsum('ib,ibk->k', tf.to_float(target - l_idx),
perms[i])
tail_nll = tf.nn.sparse_softmax_cross_entropy_with_logits(
labels=tf.to_int32(tail_target),
logits=tail_logit)
sum_nll = cur_head_nll + tail_nll
mask = tf.reduce_sum(perms[i], [0, 1])
masked_loss = sum_nll * mask
total_loss += tf.reduce_sum(masked_loss)
total_cnt += tf.reduce_sum(mask)
nll = total_loss / total_cnt
return nll
\ No newline at end of file
...@@ -261,8 +261,8 @@ class TFXLMMainLayer(tf.keras.layers.Layer): ...@@ -261,8 +261,8 @@ class TFXLMMainLayer(tf.keras.layers.Layer):
self.ffns = [] self.ffns = []
self.layer_norm2 = [] self.layer_norm2 = []
# if self.is_decoder: # if self.is_decoder:
# self.layer_norm15 = tf.keras.layers.LayerList() # self.layer_norm15 = []
# self.encoder_attn = tf.keras.layers.LayerList() # self.encoder_attn = []
for i in range(self.n_layers): for i in range(self.n_layers):
self.attentions.append(TFMultiHeadAttention(self.n_heads, self.dim, config=config, name='attentions_._{}'.format(i))) self.attentions.append(TFMultiHeadAttention(self.n_heads, self.dim, config=config, name='attentions_._{}'.format(i)))
......
...@@ -229,102 +229,11 @@ class PositionwiseFF(nn.Module): ...@@ -229,102 +229,11 @@ class PositionwiseFF(nn.Module):
return output return output
class RelPartialLearnableMultiHeadAttn(nn.Module):
class MultiHeadAttn(nn.Module):
def __init__(self, n_head, d_model, d_head, dropout, dropatt=0,
pre_lnorm=False, r_r_bias=None, r_w_bias=None, output_attentions=False):
super(MultiHeadAttn, self).__init__()
self.output_attentions = output_attentions
self.n_head = n_head
self.d_model = d_model
self.d_head = d_head
self.dropout = dropout
self.q_net = nn.Linear(d_model, n_head * d_head, bias=False)
self.kv_net = nn.Linear(d_model, 2 * n_head * d_head, bias=False)
self.drop = nn.Dropout(dropout)
self.dropatt = nn.Dropout(dropatt)
self.o_net = nn.Linear(n_head * d_head, d_model, bias=False)
self.layer_norm = nn.LayerNorm(d_model)
self.scale = 1 / (d_head ** 0.5)
self.pre_lnorm = pre_lnorm
if r_r_bias is None or r_w_bias is None: # Biases are not shared
self.r_r_bias = nn.Parameter(torch.FloatTensor(self.n_head, self.d_head))
self.r_w_bias = nn.Parameter(torch.FloatTensor(self.n_head, self.d_head))
else:
self.r_r_bias = r_r_bias
self.r_w_bias = r_w_bias
def forward(self, h, attn_mask=None, mems=None, head_mask=None):
##### multihead attention
# [hlen x bsz x n_head x d_head]
if mems is not None:
c = torch.cat([mems, h], 0)
else:
c = h
if self.pre_lnorm:
##### layer normalization
c = self.layer_norm(c)
head_q = self.q_net(h)
head_k, head_v = torch.chunk(self.kv_net(c), 2, -1)
head_q = head_q.view(h.size(0), h.size(1), self.n_head, self.d_head)
head_k = head_k.view(c.size(0), c.size(1), self.n_head, self.d_head)
head_v = head_v.view(c.size(0), c.size(1), self.n_head, self.d_head)
# [qlen x klen x bsz x n_head]
attn_score = torch.einsum('ibnd,jbnd->ijbn', (head_q, head_k))
attn_score.mul_(self.scale)
if attn_mask is not None and torch.sum(attn_mask).item():
attn_mask = (attn_mask == 1) # Switch to bool
if attn_mask.dim() == 2:
attn_score.masked_fill_(attn_mask[None,:,:,None], -float('inf'))
elif attn_mask.dim() == 3:
attn_score.masked_fill_(attn_mask[:,:,:,None], -float('inf'))
# [qlen x klen x bsz x n_head]
attn_prob = F.softmax(attn_score, dim=1)
attn_prob = self.dropatt(attn_prob)
# Mask heads if we want to
if head_mask is not None:
attn_prob = attn_prob * head_mask
# [qlen x klen x bsz x n_head] + [klen x bsz x n_head x d_head] -> [qlen x bsz x n_head x d_head]
attn_vec = torch.einsum('ijbn,jbnd->ibnd', (attn_prob, head_v))
attn_vec = attn_vec.contiguous().view(
attn_vec.size(0), attn_vec.size(1), self.n_head * self.d_head)
##### linear projection
attn_out = self.o_net(attn_vec)
attn_out = self.drop(attn_out)
if self.pre_lnorm:
##### residual connection
outputs = [h + attn_out]
else:
##### residual connection + layer normalization
outputs = [self.layer_norm(h + attn_out)]
if self.output_attentions:
outputs.append(attn_prob)
return outputs
class RelMultiHeadAttn(nn.Module):
def __init__(self, n_head, d_model, d_head, dropout, dropatt=0, def __init__(self, n_head, d_model, d_head, dropout, dropatt=0,
tgt_len=None, ext_len=None, mem_len=None, pre_lnorm=False, tgt_len=None, ext_len=None, mem_len=None, pre_lnorm=False,
r_r_bias=None, r_w_bias=None, output_attentions=False): r_r_bias=None, r_w_bias=None, output_attentions=False):
super(RelMultiHeadAttn, self).__init__() super(RelPartialLearnableMultiHeadAttn, self).__init__()
self.output_attentions = output_attentions self.output_attentions = output_attentions
self.n_head = n_head self.n_head = n_head
...@@ -351,36 +260,9 @@ class RelMultiHeadAttn(nn.Module): ...@@ -351,36 +260,9 @@ class RelMultiHeadAttn(nn.Module):
self.r_r_bias = r_r_bias self.r_r_bias = r_r_bias
self.r_w_bias = r_w_bias self.r_w_bias = r_w_bias
def _parallelogram_mask(self, h, w, left=False): self.r_net = nn.Linear(self.d_model, self.n_head * self.d_head, bias=False)
mask = torch.ones((h, w)).byte()
m = min(h, w)
mask[:m,:m] = torch.triu(mask[:m,:m])
mask[-m:,-m:] = torch.tril(mask[-m:,-m:])
if left:
return mask
else:
return mask.flip(0)
def _shift(self, x, qlen, klen, mask, left=False):
if qlen > 1:
zero_pad = torch.zeros((x.size(0), qlen-1, x.size(2), x.size(3)),
device=x.device, dtype=x.dtype)
else:
zero_pad = torch.zeros(0, device=x.device, dtype=x.dtype)
if left:
mask = mask.flip(1)
x_padded = torch.cat([zero_pad, x], dim=1).expand(qlen, -1, -1, -1)
else:
x_padded = torch.cat([x, zero_pad], dim=1).expand(qlen, -1, -1, -1)
x = x_padded.masked_select(mask[:,:,None,None]) \
.view(qlen, klen, x.size(2), x.size(3))
return x
def _rel_shift(self, x, zero_triu=False): def _rel_shift(self, x):
zero_pad_shape = (x.size(0), 1) + x.size()[2:] zero_pad_shape = (x.size(0), 1) + x.size()[2:]
zero_pad = torch.zeros(zero_pad_shape, device=x.device, dtype=x.dtype) zero_pad = torch.zeros(zero_pad_shape, device=x.device, dtype=x.dtype)
x_padded = torch.cat([zero_pad, x], dim=1) x_padded = torch.cat([zero_pad, x], dim=1)
...@@ -390,21 +272,8 @@ class RelMultiHeadAttn(nn.Module): ...@@ -390,21 +272,8 @@ class RelMultiHeadAttn(nn.Module):
x = x_padded[1:].view_as(x) x = x_padded[1:].view_as(x)
if zero_triu:
ones = torch.ones((x.size(0), x.size(1)))
x = x * torch.tril(ones, x.size(1) - x.size(0))[:,:,None,None]
return x return x
def forward(self, w, r, attn_mask=None, mems=None):
raise NotImplementedError
class RelPartialLearnableMultiHeadAttn(RelMultiHeadAttn):
def __init__(self, *args, **kwargs):
super(RelPartialLearnableMultiHeadAttn, self).__init__(*args, **kwargs)
self.r_net = nn.Linear(self.d_model, self.n_head * self.d_head, bias=False)
def forward(self, w, r, attn_mask=None, mems=None, head_mask=None): def forward(self, w, r, attn_mask=None, mems=None, head_mask=None):
qlen, rlen, bsz = w.size(0), r.size(0), w.size(1) qlen, rlen, bsz = w.size(0), r.size(0), w.size(1)
...@@ -488,138 +357,6 @@ class RelPartialLearnableMultiHeadAttn(RelMultiHeadAttn): ...@@ -488,138 +357,6 @@ class RelPartialLearnableMultiHeadAttn(RelMultiHeadAttn):
return outputs return outputs
class RelLearnableMultiHeadAttn(RelMultiHeadAttn):
def __init__(self, *args, **kwargs):
super(RelLearnableMultiHeadAttn, self).__init__(*args, **kwargs)
def forward(self, w, r_emb, r_w_bias, r_bias, attn_mask=None, mems=None, head_mask=None):
# r_emb: [klen, n_head, d_head], used for term B
# r_w_bias: [n_head, d_head], used for term C
# r_bias: [klen, n_head], used for term D
qlen, bsz = w.size(0), w.size(1)
if mems is not None:
cat = torch.cat([mems, w], 0)
if self.pre_lnorm:
w_heads = self.qkv_net(self.layer_norm(cat))
else:
w_heads = self.qkv_net(cat)
w_head_q, w_head_k, w_head_v = torch.chunk(w_heads, 3, dim=-1)
w_head_q = w_head_q[-qlen:]
else:
if self.pre_lnorm:
w_heads = self.qkv_net(self.layer_norm(w))
else:
w_heads = self.qkv_net(w)
w_head_q, w_head_k, w_head_v = torch.chunk(w_heads, 3, dim=-1)
klen = w_head_k.size(0)
w_head_q = w_head_q.view(qlen, bsz, self.n_head, self.d_head)
w_head_k = w_head_k.view(klen, bsz, self.n_head, self.d_head)
w_head_v = w_head_v.view(klen, bsz, self.n_head, self.d_head)
if klen > r_emb.size(0):
r_emb_pad = r_emb[0:1].expand(klen-r_emb.size(0), -1, -1)
r_emb = torch.cat([r_emb_pad, r_emb], 0)
r_bias_pad = r_bias[0:1].expand(klen-r_bias.size(0), -1)
r_bias = torch.cat([r_bias_pad, r_bias], 0)
else:
r_emb = r_emb[-klen:]
r_bias = r_bias[-klen:]
#### compute attention score
rw_head_q = w_head_q + r_w_bias[None] # qlen x bsz x n_head x d_head
AC = torch.einsum('ibnd,jbnd->ijbn', (rw_head_q, w_head_k)) # qlen x klen x bsz x n_head
B_ = torch.einsum('ibnd,jnd->ijbn', (w_head_q, r_emb)) # qlen x klen x bsz x n_head
D_ = r_bias[None, :, None] # 1 x klen x 1 x n_head
BD = self._rel_shift(B_ + D_)
# [qlen x klen x bsz x n_head]
attn_score = AC + BD
attn_score.mul_(self.scale)
#### compute attention probability
if attn_mask is not None and torch.sum(attn_mask).item():
attn_mask = (attn_mask == 1) # Switch to bool
if attn_mask.dim() == 2:
attn_score.masked_fill_(attn_mask[None,:,:,None], -float('inf'))
elif attn_mask.dim() == 3:
attn_score.masked_fill_(attn_mask[:,:,:,None], -float('inf'))
# [qlen x klen x bsz x n_head]
attn_prob = F.softmax(attn_score, dim=1)
attn_prob = self.dropatt(attn_prob)
if head_mask is not None:
attn_prob = attn_prob * head_mask
#### compute attention vector
attn_vec = torch.einsum('ijbn,jbnd->ibnd', (attn_prob, w_head_v))
# [qlen x bsz x n_head x d_head]
attn_vec = attn_vec.contiguous().view(
attn_vec.size(0), attn_vec.size(1), self.n_head * self.d_head)
##### linear projection
attn_out = self.o_net(attn_vec)
attn_out = self.drop(attn_out)
if self.pre_lnorm:
##### residual connection
outputs = [w + attn_out]
else:
##### residual connection + layer normalization
outputs = [self.layer_norm(w + attn_out)]
if self.output_attentions:
outputs.append(attn_prob)
return outputs
class DecoderLayer(nn.Module):
def __init__(self, n_head, d_model, d_head, d_inner, dropout, **kwargs):
super(DecoderLayer, self).__init__()
self.dec_attn = MultiHeadAttn(n_head, d_model, d_head, dropout, **kwargs)
self.pos_ff = PositionwiseFF(d_model, d_inner, dropout,
pre_lnorm=kwargs.get('pre_lnorm'))
def forward(self, dec_inp, dec_attn_mask=None, mems=None, head_mask=None):
attn_outputs = self.dec_attn(dec_inp, attn_mask=dec_attn_mask,
mems=mems, head_mask=head_mask)
ff_output = self.pos_ff(attn_outputs[0])
outputs = [ff_output] + attn_outputs[1:]
return outputs
class RelLearnableDecoderLayer(nn.Module):
def __init__(self, n_head, d_model, d_head, d_inner, dropout,
**kwargs):
super(RelLearnableDecoderLayer, self).__init__()
self.dec_attn = RelLearnableMultiHeadAttn(n_head, d_model, d_head, dropout,
**kwargs)
self.pos_ff = PositionwiseFF(d_model, d_inner, dropout,
pre_lnorm=kwargs.get('pre_lnorm'))
def forward(self, dec_inp, r_emb, r_w_bias, r_bias, dec_attn_mask=None, mems=None, head_mask=None):
attn_outputs = self.dec_attn(dec_inp, r_emb, r_w_bias, r_bias,
attn_mask=dec_attn_mask,
mems=mems, head_mask=head_mask)
ff_output = self.pos_ff(attn_outputs[0])
outputs = [ff_output] + attn_outputs[1:]
return outputs
class RelPartialLearnableDecoderLayer(nn.Module): class RelPartialLearnableDecoderLayer(nn.Module):
def __init__(self, n_head, d_model, d_head, d_inner, dropout, def __init__(self, n_head, d_model, d_head, d_inner, dropout,
...@@ -643,7 +380,6 @@ class RelPartialLearnableDecoderLayer(nn.Module): ...@@ -643,7 +380,6 @@ class RelPartialLearnableDecoderLayer(nn.Module):
return outputs return outputs
class AdaptiveEmbedding(nn.Module): class AdaptiveEmbedding(nn.Module):
def __init__(self, n_token, d_embed, d_proj, cutoffs, div_val=1, def __init__(self, n_token, d_embed, d_proj, cutoffs, div_val=1,
sample_softmax=False): sample_softmax=False):
...@@ -767,9 +503,6 @@ class TransfoXLPreTrainedModel(PreTrainedModel): ...@@ -767,9 +503,6 @@ class TransfoXLPreTrainedModel(PreTrainedModel):
if hasattr(m, 'r_bias'): if hasattr(m, 'r_bias'):
self._init_bias(m.r_bias) self._init_bias(m.r_bias)
def set_num_special_tokens(self, num_special_tokens):
pass
TRANSFO_XL_START_DOCSTRING = r""" The Transformer-XL model was proposed in TRANSFO_XL_START_DOCSTRING = r""" The Transformer-XL model was proposed in
`Transformer-XL: Attentive Language Models Beyond a Fixed-Length Context`_ `Transformer-XL: Attentive Language Models Beyond a Fixed-Length Context`_
...@@ -882,43 +615,16 @@ class TransfoXLModel(TransfoXLPreTrainedModel): ...@@ -882,43 +615,16 @@ class TransfoXLModel(TransfoXLPreTrainedModel):
r_r_bias=None if config.untie_r else self.r_r_bias, r_r_bias=None if config.untie_r else self.r_r_bias,
output_attentions=self.output_attentions) output_attentions=self.output_attentions)
) )
elif config.attn_type == 1: # learnable embeddings else: # learnable embeddings and absolute embeddings are not used in our pretrained checkpoints
for i in range(config.n_layer): raise NotImplementedError # Removed them to avoid maintaining dead code
self.layers.append(
RelLearnableDecoderLayer(
config.n_head, config.d_model, config.d_head, config.d_inner, config.dropout,
tgt_len=config.tgt_len, ext_len=config.ext_len, mem_len=config.mem_len,
dropatt=config.dropatt, pre_lnorm=config.pre_lnorm,
r_w_bias=None if config.untie_r else self.r_w_bias,
r_r_bias=None if config.untie_r else self.r_r_bias,
output_attentions=self.output_attentions)
)
elif config.attn_type in [2, 3]: # absolute embeddings
for i in range(config.n_layer):
self.layers.append(
DecoderLayer(
config.n_head, config.d_model, config.d_head, config.d_inner, config.dropout,
dropatt=config.dropatt, pre_lnorm=config.pre_lnorm,
r_w_bias=None if config.untie_r else self.r_w_bias,
r_r_bias=None if config.untie_r else self.r_r_bias,
output_attentions=self.output_attentions)
)
self.same_length = config.same_length self.same_length = config.same_length
self.clamp_len = config.clamp_len self.clamp_len = config.clamp_len
if self.attn_type == 0: # default attention if self.attn_type == 0: # default attention
self.pos_emb = PositionalEmbedding(self.d_model) self.pos_emb = PositionalEmbedding(self.d_model)
elif self.attn_type == 1: # learnable else: # learnable embeddings and absolute embeddings
self.r_emb = nn.Parameter(torch.FloatTensor( raise NotImplementedError # Removed these to avoid maintaining dead code - They are not used in our pretrained checkpoint
self.n_layer, self.max_klen, self.n_head, self.d_head))
self.r_bias = nn.Parameter(torch.FloatTensor(
self.n_layer, self.max_klen, self.n_head))
elif self.attn_type == 2: # absolute standard
self.pos_emb = PositionalEmbedding(self.d_model)
elif self.attn_type == 3: # absolute deeper SA
self.r_emb = nn.Parameter(torch.FloatTensor(
self.n_layer, self.max_klen, self.n_head, self.d_head))
self.init_weights() self.init_weights()
...@@ -973,8 +679,15 @@ class TransfoXLModel(TransfoXLPreTrainedModel): ...@@ -973,8 +679,15 @@ class TransfoXLModel(TransfoXLPreTrainedModel):
return new_mems return new_mems
def _forward(self, dec_inp, mems=None, head_mask=None): def forward(self, input_ids, mems=None, head_mask=None):
qlen, bsz = dec_inp.size() # the original code for Transformer-XL used shapes [len, bsz] but we want a unified interface in the library
# so we transpose here from shape [bsz, len] to shape [len, bsz]
input_ids = input_ids.transpose(0, 1).contiguous()
if mems is None:
mems = self.init_mems(input_ids)
qlen, bsz = input_ids.size()
# Prepare head mask if needed # Prepare head mask if needed
# 1.0 in head_mask indicate we keep the head # 1.0 in head_mask indicate we keep the head
...@@ -991,7 +704,7 @@ class TransfoXLModel(TransfoXLPreTrainedModel): ...@@ -991,7 +704,7 @@ class TransfoXLModel(TransfoXLPreTrainedModel):
else: else:
head_mask = [None] * self.n_layer head_mask = [None] * self.n_layer
word_emb = self.word_emb(dec_inp) word_emb = self.word_emb(input_ids)
mlen = mems[0].size(0) if mems is not None else 0 mlen = mems[0].size(0) if mems is not None else 0
klen = mlen + qlen klen = mlen + qlen
...@@ -1028,64 +741,8 @@ class TransfoXLModel(TransfoXLPreTrainedModel): ...@@ -1028,64 +741,8 @@ class TransfoXLModel(TransfoXLPreTrainedModel):
core_out = layer_outputs[0] core_out = layer_outputs[0]
if self.output_attentions: if self.output_attentions:
attentions.append(layer_outputs[1]) attentions.append(layer_outputs[1])
elif self.attn_type == 1: # learnable else: # learnable embeddings and absolute embeddings
core_out = self.drop(word_emb) raise NotImplementedError # Removed these to avoid maintaining dead code - They are not used in our pretrained checkpoint
for i, layer in enumerate(self.layers):
hids.append(core_out)
if self.clamp_len > 0:
r_emb = self.r_emb[i][-self.clamp_len :]
r_bias = self.r_bias[i][-self.clamp_len :]
else:
r_emb, r_bias = self.r_emb[i], self.r_bias[i]
mems_i = None if mems is None else mems[i]
layer_outputs = layer(core_out, r_emb, self.r_w_bias[i],
r_bias, dec_attn_mask=dec_attn_mask,
mems=mems_i, head_mask=head_mask[i])
core_out = layer_outputs[0]
if self.output_attentions:
attentions.append(layer_outputs[1])
elif self.attn_type == 2: # absolute
pos_seq = torch.arange(klen - 1, -1, -1.0, device=word_emb.device,
dtype=word_emb.dtype)
if self.clamp_len > 0:
pos_seq.clamp_(max=self.clamp_len)
pos_emb = self.pos_emb(pos_seq)
core_out = self.drop(word_emb + pos_emb[-qlen:])
for i, layer in enumerate(self.layers):
hids.append(core_out)
mems_i = None if mems is None else mems[i]
if mems_i is not None and i == 0:
mems_i += pos_emb[:mlen]
layer_outputs = layer(core_out, dec_attn_mask=dec_attn_mask,
mems=mems_i, head_mask=head_mask[i])
core_out = layer_outputs[0]
if self.output_attentions:
attentions.append(layer_outputs[1])
elif self.attn_type == 3:
core_out = self.drop(word_emb)
for i, layer in enumerate(self.layers):
hids.append(core_out)
mems_i = None if mems is None else mems[i]
if mems_i is not None and mlen > 0:
cur_emb = self.r_emb[i][:-qlen]
cur_size = cur_emb.size(0)
if cur_size < mlen:
cur_emb_pad = cur_emb[0:1].expand(mlen-cur_size, -1, -1)
cur_emb = torch.cat([cur_emb_pad, cur_emb], 0)
else:
cur_emb = cur_emb[-mlen:]
mems_i += cur_emb.view(mlen, 1, -1)
core_out += self.r_emb[i][-qlen:].view(qlen, 1, -1)
layer_outputs = layer(core_out, dec_attn_mask=dec_attn_mask,
mems=mems_i, head_mask=head_mask[i])
core_out = layer_outputs[0]
if self.output_attentions:
attentions.append(layer_outputs[1])
core_out = self.drop(core_out) core_out = self.drop(core_out)
...@@ -1102,16 +759,6 @@ class TransfoXLModel(TransfoXLPreTrainedModel): ...@@ -1102,16 +759,6 @@ class TransfoXLModel(TransfoXLPreTrainedModel):
# Transpose to library standard shape [bsz, n_heads, query_seq_len, key_seq_len] # Transpose to library standard shape [bsz, n_heads, query_seq_len, key_seq_len]
attentions = list(t.permute(2, 3, 0, 1).contiguous() for t in attentions) attentions = list(t.permute(2, 3, 0, 1).contiguous() for t in attentions)
outputs.append(attentions) outputs.append(attentions)
return outputs # last hidden state, new_mems, (all hidden states), (all attentions)
def forward(self, input_ids, mems=None, head_mask=None):
# the original code for Transformer-XL used shapes [len, bsz] but we want a unified interface in the library
# so we transpose here from shape [bsz, len] to shape [len, bsz]
input_ids = input_ids.transpose(0, 1).contiguous()
if mems is None:
mems = self.init_mems(input_ids)
outputs = self._forward(input_ids, mems=mems, head_mask=head_mask)
return outputs # last hidden state, new_mems, (all hidden states), (all attentions) return outputs # last hidden state, new_mems, (all hidden states), (all attentions)
......
...@@ -131,10 +131,14 @@ class TFBertModelTest(TFCommonTestCases.TFCommonModelTester): ...@@ -131,10 +131,14 @@ class TFBertModelTest(TFCommonTestCases.TFCommonModelTester):
def create_and_check_bert_model(self, config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels): def create_and_check_bert_model(self, config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels):
model = TFBertModel(config=config) model = TFBertModel(config=config)
# inputs = {'input_ids': input_ids,
# 'attention_mask': input_mask,
# 'token_type_ids': token_type_ids}
# sequence_output, pooled_output = model(**inputs)
inputs = {'input_ids': input_ids, inputs = {'input_ids': input_ids,
'attention_mask': input_mask, 'attention_mask': input_mask,
'token_type_ids': token_type_ids} 'token_type_ids': token_type_ids}
sequence_output, pooled_output = model(inputs) sequence_output, pooled_output = model(input_ids, attention_mask=input_mask, token_type_ids=token_type_ids)
inputs = [input_ids, input_mask] inputs = [input_ids, input_mask]
sequence_output, pooled_output = model(inputs) sequence_output, pooled_output = model(inputs)
......
# coding=utf-8
# Copyright 2018 The Google AI Language Team Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import unittest
import random
import shutil
import pytest
from .modeling_tf_common_test import (TFCommonTestCases, ids_tensor)
from .configuration_common_test import ConfigTester
from pytorch_transformers import TransfoXLConfig, is_tf_available
if is_tf_available():
import tensorflow as tf
from pytorch_transformers.modeling_tf_transfo_xl import (TFTransfoXLModel,
TFTransfoXLLMHeadModel,
TF_TRANSFO_XL_PRETRAINED_MODEL_ARCHIVE_MAP)
else:
pytestmark = pytest.mark.skip("Require TensorFlow")
class TFTransfoXLModelTest(TFCommonTestCases.TFCommonModelTester):
all_model_classes = (TFTransfoXLModel, TFTransfoXLLMHeadModel) if is_tf_available() else ()
test_pruning = False
test_torchscript = False
test_resize_embeddings = False
class TFTransfoXLModelTester(object):
def __init__(self,
parent,
batch_size=13,
seq_length=7,
mem_len=30,
clamp_len=15,
is_training=True,
use_labels=True,
vocab_size=99,
cutoffs=[10, 50, 80],
hidden_size=32,
d_embed=32,
num_attention_heads=4,
d_head=8,
d_inner=128,
div_val=2,
num_hidden_layers=5,
scope=None,
seed=1,
):
self.parent = parent
self.batch_size = batch_size
self.seq_length = seq_length
self.mem_len = mem_len
self.key_len = seq_length + mem_len
self.clamp_len = clamp_len
self.is_training = is_training
self.use_labels = use_labels
self.vocab_size = vocab_size
self.cutoffs = cutoffs
self.hidden_size = hidden_size
self.d_embed = d_embed
self.num_attention_heads = num_attention_heads
self.d_head = d_head
self.d_inner = d_inner
self.div_val = div_val
self.num_hidden_layers = num_hidden_layers
self.scope = scope
self.seed = seed
def prepare_config_and_inputs(self):
input_ids_1 = ids_tensor([self.batch_size, self.seq_length], self.vocab_size)
input_ids_2 = ids_tensor([self.batch_size, self.seq_length], self.vocab_size)
lm_labels = None
if self.use_labels:
lm_labels = ids_tensor([self.batch_size, self.seq_length], self.vocab_size)
config = TransfoXLConfig(
vocab_size_or_config_json_file=self.vocab_size,
mem_len=self.mem_len,
clamp_len=self.clamp_len,
cutoffs=self.cutoffs,
d_model=self.hidden_size,
d_embed=self.d_embed,
n_head=self.num_attention_heads,
d_head=self.d_head,
d_inner=self.d_inner,
div_val=self.div_val,
n_layer=self.num_hidden_layers)
return (config, input_ids_1, input_ids_2, lm_labels)
def set_seed(self):
random.seed(self.seed)
tf.random.set_seed(self.seed)
def create_and_check_transfo_xl_model(self, config, input_ids_1, input_ids_2, lm_labels):
model = TFTransfoXLModel(config)
hidden_states_1, mems_1 = model(input_ids_1)
inputs = {'input_ids': input_ids_2,
'mems': mems_1}
hidden_states_2, mems_2 = model(inputs)
result = {
"hidden_states_1": hidden_states_1.numpy(),
"mems_1": [mem.numpy() for mem in mems_1],
"hidden_states_2": hidden_states_2.numpy(),
"mems_2": [mem.numpy() for mem in mems_2],
}
self.parent.assertListEqual(
list(result["hidden_states_1"].shape),
[self.batch_size, self.seq_length, self.hidden_size])
self.parent.assertListEqual(
list(result["hidden_states_2"].shape),
[self.batch_size, self.seq_length, self.hidden_size])
self.parent.assertListEqual(
list(list(mem.shape) for mem in result["mems_1"]),
[[self.mem_len, self.batch_size, self.hidden_size]] * self.num_hidden_layers)
self.parent.assertListEqual(
list(list(mem.shape) for mem in result["mems_2"]),
[[self.mem_len, self.batch_size, self.hidden_size]] * self.num_hidden_layers)
def create_and_check_transfo_xl_lm_head(self, config, input_ids_1, input_ids_2, lm_labels):
model = TFTransfoXLLMHeadModel(config)
lm_logits_1, mems_1 = model(input_ids_1)
inputs = {'input_ids': input_ids_1,
'labels': lm_labels}
_, mems_1 = model(inputs)
lm_logits_2, mems_2 = model([input_ids_2, mems_1])
inputs = {'input_ids': input_ids_1,
'mems': mems_1,
'labels': lm_labels}
_, mems_2 = model(inputs)
result = {
"mems_1": [mem.numpy() for mem in mems_1],
"lm_logits_1": lm_logits_1.numpy(),
"mems_2": [mem.numpy() for mem in mems_2],
"lm_logits_2": lm_logits_2.numpy(),
}
self.parent.assertListEqual(
list(result["lm_logits_1"].shape),
[self.batch_size, self.seq_length, self.vocab_size])
self.parent.assertListEqual(
list(list(mem.shape) for mem in result["mems_1"]),
[[self.mem_len, self.batch_size, self.hidden_size]] * self.num_hidden_layers)
self.parent.assertListEqual(
list(result["lm_logits_2"].shape),
[self.batch_size, self.seq_length, self.vocab_size])
self.parent.assertListEqual(
list(list(mem.shape) for mem in result["mems_2"]),
[[self.mem_len, self.batch_size, self.hidden_size]] * self.num_hidden_layers)
def prepare_config_and_inputs_for_common(self):
config_and_inputs = self.prepare_config_and_inputs()
(config, input_ids_1, input_ids_2, lm_labels) = config_and_inputs
inputs_dict = {'input_ids': input_ids_1}
return config, inputs_dict
def setUp(self):
self.model_tester = TFTransfoXLModelTest.TFTransfoXLModelTester(self)
self.config_tester = ConfigTester(self, config_class=TransfoXLConfig, d_embed=37)
def test_config(self):
self.config_tester.run_common_tests()
def test_transfo_xl_model(self):
self.model_tester.set_seed()
config_and_inputs = self.model_tester.prepare_config_and_inputs()
self.model_tester.create_and_check_transfo_xl_model(*config_and_inputs)
def test_transfo_xl_lm_head(self):
self.model_tester.set_seed()
config_and_inputs = self.model_tester.prepare_config_and_inputs()
self.model_tester.create_and_check_transfo_xl_lm_head(*config_and_inputs)
@pytest.mark.slow
def test_model_from_pretrained(self):
cache_dir = "/tmp/pytorch_transformers_test/"
for model_name in list(TF_TRANSFO_XL_PRETRAINED_MODEL_ARCHIVE_MAP.keys())[:1]:
model = TFTransfoXLModel.from_pretrained(model_name, cache_dir=cache_dir)
shutil.rmtree(cache_dir)
self.assertIsNotNone(model)
if __name__ == "__main__":
unittest.main()
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment