Commit c2ea5aef authored by thomwolf

work in progress on xlnet

parent de713fa9
@@ -126,6 +126,16 @@ def swish(x):
ACT2FN = {"gelu": gelu, "relu": torch.nn.functional.relu, "swish": swish}
def positional_embedding(pos_seq, inv_freq, bsz=None):
    sinusoid_inp = torch.einsum('i,d->id', pos_seq, inv_freq)
    pos_emb = torch.cat([torch.sin(sinusoid_inp), torch.cos(sinusoid_inp)], dim=-1)
    pos_emb = pos_emb[:, None, :]
    if bsz is not None:
        pos_emb = pos_emb.expand(-1, bsz, -1)
    return pos_emb
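
# Illustrative example (not part of the original commit): with d_model = 1024, inv_freq has
# d_model/2 = 512 entries, sinusoid_inp is [plen, 512] and the returned pos_emb is
# [plen, 1, d_model], or [plen, bsz, d_model] once expanded, e.g.
#
#   inv_freq = 1.0 / (10000 ** (torch.arange(0.0, 1024.0, 2.0) / 1024.0))   # [512]
#   pos_emb = positional_embedding(torch.arange(9.0, -1.0, -1.0), inv_freq, bsz=4)
#   assert pos_emb.shape == (10, 4, 1024)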
class XLNetBaseConfig(object):
    @classmethod
    def from_dict(cls, json_object):
@@ -165,15 +175,14 @@ class XLNetConfig(XLNetBaseConfig):
""" """
def __init__(self, def __init__(self,
vocab_size_or_config_json_file, vocab_size_or_config_json_file,
d_model=768, d_model=1024,
n_layer=12, n_layer=24,
n_head=12, n_head=16,
d_inner=3072, d_inner=4096,
ff_activation="gelu", ff_activation="gelu",
untie_r=True, untie_r=True,
max_position_embeddings=512, max_position_embeddings=512,
type_vocab_size=2,
initializer_range=0.02, initializer_range=0.02,
layer_norm_eps=1e-12): layer_norm_eps=1e-12):
"""Constructs XLNetConfig. """Constructs XLNetConfig.
@@ -197,8 +206,6 @@ class XLNetConfig(XLNetBaseConfig):
            max_position_embeddings: The maximum sequence length that this model might
                ever be used with. Typically set this to something large just in case
                (e.g., 512 or 1024 or 2048).
            initializer_range: The stddev of the truncated_normal_initializer for
                initializing all weight matrices.
            layer_norm_eps: The epsilon used by LayerNorm.
@@ -214,11 +221,12 @@ class XLNetConfig(XLNetBaseConfig):
            self.d_model = d_model
            self.n_layer = n_layer
            self.n_head = n_head
            assert d_model % n_head == 0
            self.d_head = d_model // n_head
            self.ff_activation = ff_activation
            self.d_inner = d_inner
            self.untie_r = untie_r
            self.max_position_embeddings = max_position_embeddings
            self.initializer_range = initializer_range
            self.layer_norm_eps = layer_norm_eps
        else:
@@ -233,8 +241,8 @@ class XLNetRunConfig(XLNetBaseConfig):
    We store them separately from XLNetConfig for flexibility.
    """
    def __init__(self,
                 dropout=0.1,
                 dropatt=0.1,
                 init="normal",
                 init_range=0.1,
                 init_std=0.02,
@@ -278,12 +286,12 @@ try:
except ImportError:
    logger.info("Better speed can be achieved with apex installed from https://www.github.com/nvidia/apex .")
    class XLNetLayerNorm(nn.Module):
        def __init__(self, d_model, eps=1e-12):
            """Construct a layernorm module in the TF style (epsilon inside the square root).
            """
            super(XLNetLayerNorm, self).__init__()
            self.weight = nn.Parameter(torch.ones(d_model))
            self.bias = nn.Parameter(torch.zeros(d_model))
            self.variance_epsilon = eps

        def forward(self, x):
@@ -292,6 +300,220 @@ except ImportError:
            x = (x - u) / torch.sqrt(s + self.variance_epsilon)
            return self.weight * x + self.bias
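
# Illustrative note (not part of the original commit): as in the BERT implementation this file
# is derived from, `u` and `s` above are the mean and variance taken over the last (d_model)
# dimension, and keeping the epsilon inside the square root matches the original TF layer norm.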
class XLNetRelativeAttention(nn.Module):
    def __init__(self, config, output_attentions=False, keep_multihead_output=False):
        super(XLNetRelativeAttention, self).__init__()
        if config.d_model % config.n_head != 0:
            raise ValueError(
                "The hidden size (%d) is not a multiple of the number of attention "
                "heads (%d)" % (config.d_model, config.n_head))

        self.output_attentions = output_attentions
        self.keep_multihead_output = keep_multihead_output
        self.multihead_output = None

        self.n_head = config.n_head
        self.d_head = config.d_head
        self.d_model = config.d_model
        self.scale = 1 / (config.d_head ** 0.5)

        self.q = nn.Parameter(torch.Tensor(config.d_model, self.n_head, self.d_head))
        self.k = nn.Parameter(torch.Tensor(config.d_model, self.n_head, self.d_head))
        self.v = nn.Parameter(torch.Tensor(config.d_model, self.n_head, self.d_head))
        self.o = nn.Parameter(torch.Tensor(config.d_model, self.n_head, self.d_head))
        self.r = nn.Parameter(torch.Tensor(config.d_model, self.n_head, self.d_head))

        self.r_r_bias = nn.Parameter(torch.Tensor(self.n_head, self.d_head))
        self.r_s_bias = nn.Parameter(torch.Tensor(self.n_head, self.d_head))
        self.r_w_bias = nn.Parameter(torch.Tensor(self.n_head, self.d_head))
        self.seg_embed = nn.Parameter(torch.Tensor(2, self.n_head, self.d_head))

        self.LayerNorm = XLNetLayerNorm(config.d_model, eps=config.layer_norm_eps)
        self.dropout = nn.Dropout(config.dropout)
def prune_heads(self, heads):
raise NotImplementedError
def rel_attn_core(self, q_head, k_head_h, v_head_h, k_head_r, seg_mat=None, attn_mask=None):
"""Core relative positional attention operations."""
# content based attention score
ac = torch.einsum('ibnd,jbnd->ijbn', q_head + self.r_w_bias, k_head_h)
# position based attention score
bd = torch.einsum('ibnd,jbnd->ijbn', q_head + self.r_r_bias, k_head_r)
        bd = rel_shift(bd, klen=ac.shape[1])
# segment based attention score
if seg_mat is None:
ef = 0
else:
ef = torch.einsum('ibnd,snd->ibns', q_head + self.r_s_bias, self.seg_embed)
ef = torch.einsum('ijbs,ibns->ijbn', seg_mat, ef)
# merge attention scores and perform masking
attn_score = (ac + bd + ef) * self.scale
if attn_mask is not None:
# attn_score = attn_score * (1 - attn_mask) - 1e30 * attn_mask
attn_score = attn_score - 1e30 * attn_mask
# attention probability
attn_prob = F.softmax(attn_score, dim=1)
attn_prob = self.dropout(attn_prob)
# attention output
attn_vec = torch.einsum('ijbn,jbnd->ibnd', attn_prob, v_head_h)
return attn_vec
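
    # Illustrative note (not part of the original commit): in the einsum strings above, `i`
    # indexes query positions, `j` key positions, `b` the batch, `n` attention heads, `d` the
    # per-head dimension and `s` the two segment classes, so `ac`, `bd` and `ef` are
    # [qlen, klen, bsz, n_head] score tensors and `attn_vec` is [qlen, bsz, n_head, d_head].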
def post_attention(self, h, attn_vec, residual=True):
"""Post-attention processing."""
# post-attention projection (back to `d_model`)
attn_out = torch.einsum('ibnd,hnd->ibh', attn_vec, self.o)
attn_out = self.dropout(attn_out)
if residual:
attn_out = attn_out + h
output = self.LayerNorm(attn_out)
return output
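
    # Illustrative note (not part of the original commit): `self.o` is [d_model, n_head, d_head],
    # so the einsum 'ibnd,hnd->ibh' above merges the heads and projects the attention vectors
    # back to [qlen, bsz, d_model] before the dropout, residual connection and LayerNorm.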
def forward(self, h, g,
attn_mask_h, attn_mask_g,
r, seg_mat,
mems=None, target_mapping=None, head_mask=None):
if g is not None:
###### Two-stream attention with relative positional encoding.
# content based attention score
if mems is not None and mems.dim() > 1:
cat = torch.cat([mems, h], dim=0)
else:
cat = h
# content-based key head
k_head_h = torch.einsum('ibh,hnd->ibnd', cat, self.k)
# content-based value head
v_head_h = torch.einsum('ibh,hnd->ibnd', cat, self.v)
# position-based key head
k_head_r = torch.einsum('ibh,hnd->ibnd', r, self.r)
##### h-stream
# content-stream query head
q_head_h = torch.einsum('ibh,hnd->ibnd', h, self.q)
# core attention ops
attn_vec_h = self.rel_attn_core(
q_head_h, k_head_h, v_head_h, k_head_r, seg_mat=seg_mat, attn_mask=attn_mask_h)
# post processing
output_h = self.post_attention(h, attn_vec_h)
##### g-stream
# query-stream query head
q_head_g = torch.einsum('ibh,hnd->ibnd', g, self.q)
# core attention ops
if target_mapping is not None:
q_head_g = torch.einsum('mbnd,mlb->lbnd', q_head_g, target_mapping)
attn_vec_g = self.rel_attn_core(
q_head_g, k_head_h, v_head_h, k_head_r, seg_mat=seg_mat, attn_mask=attn_mask_g)
attn_vec_g = torch.einsum('lbnd,mlb->mbnd', attn_vec_g, target_mapping)
else:
attn_vec_g = self.rel_attn_core(
q_head_g, k_head_h, v_head_h, k_head_r, seg_mat=seg_mat, attn_mask=attn_mask_g)
# post processing
output_g = self.post_attention(g, attn_vec_g)
attention_output = output_h, output_g
else:
###### Multi-head attention with relative positional encoding
if mems is not None and mems.dim() > 1:
cat = torch.cat([mems, h], dim=0)
else:
cat = h
# content heads
q_head_h = torch.einsum('ibh,hnd->ibnd', h, self.q)
k_head_h = torch.einsum('ibh,hnd->ibnd', cat, self.k)
v_head_h = torch.einsum('ibh,hnd->ibnd', cat, self.v)
# positional heads
k_head_r = torch.einsum('ibh,hnd->ibnd', r, self.r)
# core attention ops
attn_vec = self.rel_attn_core(
q_head_h, k_head_h, v_head_h, k_head_r, seg_mat=seg_mat, attn_mask=attn_mask_h)
# post processing
attention_output = self.post_attention(h, attn_vec)
# Mask heads if we want to
# if head_mask is not None:
# attention_probs = attention_probs * head_mask
# context_layer = torch.matmul(attention_probs, value_layer)
# if self.keep_multihead_output:
# self.multihead_output = context_layer
# self.multihead_output.retain_grad()
# context_layer = context_layer.permute(0, 2, 1, 3).contiguous()
# new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,)
# context_layer = context_layer.view(*new_context_layer_shape)
# if self.output_attentions:
# attentions, self_output = self_output
# if self.output_attentions:
# return attentions, attention_output
return attention_output
class XLNetFeedForward(nn.Module):
def __init__(self, config):
super(XLNetFeedForward, self).__init__()
self.LayerNorm = XLNetLayerNorm(config.d_model, eps=config.layer_norm_eps)
self.layer_1 = nn.Linear(config.d_model, config.d_inner)
self.layer_2 = nn.Linear(config.d_inner, config.d_model)
self.dropout = nn.Dropout(config.dropout)
if isinstance(config.ff_activation, str) or (sys.version_info[0] == 2 and isinstance(config.ff_activation, unicode)):
self.activation_function = ACT2FN[config.ff_activation]
else:
self.activation_function = config.ff_activation
def forward(self, hidden_states, input_tensor):
hidden_states = self.layer_1(hidden_states)
hidden_states = self.activation_function(hidden_states)
hidden_states = self.layer_2(hidden_states)
hidden_states = self.dropout(hidden_states)
hidden_states = self.LayerNorm(hidden_states + input_tensor)
return hidden_states
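
    # Illustrative note (not part of the original commit): this is the position-wise feed-forward
    # block of each XLNet layer; it maps [len, bsz, d_model] -> d_inner -> d_model, then adds the
    # residual `input_tensor` and applies LayerNorm, as in Transformer-XL.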
class XLNetLayer(nn.Module):
def __init__(self, config, output_attentions=False, keep_multihead_output=False):
super(XLNetLayer, self).__init__()
self.output_attentions = output_attentions
self.rel_attn = XLNetRelativeAttention(config, output_attentions=output_attentions,
keep_multihead_output=keep_multihead_output)
self.ff = XLNetFeedForward(config)
self.dropout = nn.Dropout(config.dropout)
    def forward(self, output_h, output_g,
                attn_mask_h, attn_mask_g,
                r, seg_mat,
                mems=None, target_mapping=None, head_mask=None):
        if output_g is not None:
            # two-stream attention: update the content stream (h) and the query stream (g)
            output_h, output_g = self.rel_attn(output_h, output_g,
                                               attn_mask_h, attn_mask_g,
                                               r, seg_mat,
                                               mems=mems, target_mapping=target_mapping, head_mask=head_mask)
            output_g = self.ff(output_g)
        else:
            # single (content) stream attention
            output_h = self.rel_attn(output_h, output_g,
                                     attn_mask_h, attn_mask_g,
                                     r, seg_mat,
                                     mems=mems, target_mapping=target_mapping, head_mask=head_mask)
        output_h = self.ff(output_h)

        # if self.output_attentions:
        #     return attentions, layer_output
        return output_h, output_g
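
    # Illustrative usage sketch (not part of the original commit; assumes `config` also carries
    # the run-time `dropout` value used by the sub-modules above):
    #
    #   layer = XLNetLayer(config)
    #   # finetuning: content stream only, the query stream stays None
    #   h, g = layer(output_h, None, non_tgt_mask, None, pos_emb, seg_mat)
    #   # pretraining: two-stream attention, both streams are updated
    #   h, g = layer(output_h, output_g, non_tgt_mask, attn_mask, pos_emb, seg_mat,
    #                target_mapping=target_mapping)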
class XLNetPreTrainedModel(nn.Module):
    """ An abstract class to handle weights initialization and
        a simple interface for downloading and loading pretrained models.
@@ -445,6 +667,228 @@ class XLNetPreTrainedModel(nn.Module):
class XLNetModel(XLNetPreTrainedModel):
    def __init__(self, config, output_attentions=False, keep_multihead_output=False):
        super(XLNetModel, self).__init__(config)
        self.output_attentions = output_attentions
        self.mem_len = config.mem_len
        self.reuse_len = config.reuse_len

        layer = XLNetLayer(config, output_attentions=output_attentions,
                           keep_multihead_output=keep_multihead_output)
        self.layer = nn.ModuleList([copy.deepcopy(layer) for _ in range(config.n_layer)])
    def _create_mask(self, qlen, mlen, dtype=torch.float, same_length=False):
        """create causal attention mask."""
        attn_mask = torch.ones([qlen, qlen], dtype=dtype)
        mask_u = torch.triu(attn_mask, diagonal=1)       # strictly upper triangular part
        attn_mask_pad = torch.zeros([qlen, mlen], dtype=dtype)
        ret = torch.cat([attn_mask_pad, mask_u], dim=1)
        if same_length:
            mask_l = torch.tril(attn_mask, diagonal=-1)  # strictly lower triangular part
            ret = torch.cat([ret[:, :qlen] + mask_l, ret[:, qlen:]], dim=1)
        return ret
def cache_mem(self, curr_out, prev_mem):
"""cache hidden states into memory."""
if self.mem_len is None or self.mem_len == 0:
return None
else:
if self.reuse_len is not None and self.reuse_len > 0:
curr_out = curr_out[:self.reuse_len]
if prev_mem is None:
new_mem = curr_out[-self.mem_len:]
else:
new_mem = torch.cat([prev_mem, curr_out], dim=0)[-self.mem_len:]
return new_mem.detach()
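
    # Illustrative note (not part of the original commit): with mem_len=3 and reuse_len unset,
    # caching a [5, bsz, d_model] layer output keeps only its last 3 positions; when a previous
    # memory exists, [prev_mem; curr_out] is concatenated along the length dimension and again
    # truncated to the last mem_len positions, and the result is detached from the graph.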
def relative_positional_encoding(self, qlen, klen, bsz=None, dtype=torch.float):
"""create relative positional encoding."""
        freq_seq = torch.arange(0, self.config.d_model, 2.0, dtype=dtype)
inv_freq = 1 / (10000 ** (freq_seq / self.config.d_model))
if self.attn_type == 'bi':
# beg, end = klen - 1, -qlen
beg, end = klen, -qlen
elif self.attn_type == 'uni':
# beg, end = klen - 1, -1
beg, end = klen, -1
else:
raise ValueError('Unknown `attn_type` {}.'.format(self.attn_type))
if self.bi_data:
fwd_pos_seq = torch.arange(beg, end, -1.0, dtype=dtype)
bwd_pos_seq = torch.arange(-beg, -end, 1.0, dtype=dtype)
if self.clamp_len > 0:
fwd_pos_seq = fwd_pos_seq.clamp(-self.clamp_len, self.clamp_len)
bwd_pos_seq = bwd_pos_seq.clamp(-self.clamp_len, self.clamp_len)
if bsz is not None:
fwd_pos_emb = positional_embedding(fwd_pos_seq, inv_freq, bsz//2)
bwd_pos_emb = positional_embedding(bwd_pos_seq, inv_freq, bsz//2)
else:
fwd_pos_emb = positional_embedding(fwd_pos_seq, inv_freq)
bwd_pos_emb = positional_embedding(bwd_pos_seq, inv_freq)
pos_emb = torch.cat([fwd_pos_emb, bwd_pos_emb], dim=1)
else:
fwd_pos_seq = torch.arange(beg, end, -1.0, dtype=dtype)
if self.clamp_len > 0:
fwd_pos_seq = fwd_pos_seq.clamp(-self.clamp_len, self.clamp_len)
pos_emb = positional_embedding(fwd_pos_seq, inv_freq, bsz)
return pos_emb
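
    # Illustrative note (not part of the original commit): for attn_type='bi' the relative
    # positions run from beg=klen down to (but not including) end=-qlen, e.g. qlen=3, mlen=2
    # gives klen=5 and fwd_pos_seq = [5, 4, 3, 2, 1, 0, -1, -2], so pos_emb has qlen+klen rows;
    # with bi_data=True the forward and backward encodings are concatenated along the batch dim.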
def forward(self, inp_k, seg_id=None, input_mask=None,
mems=None, perm_mask=None, target_mapping=None, inp_q=None,
output_all_encoded_layers=True, head_mask=None):
"""
Args:
inp_k: int32 Tensor in shape [len, bsz], the input token IDs.
seg_id: int32 Tensor in shape [len, bsz], the input segment IDs.
input_mask: float32 Tensor in shape [len, bsz], the input mask.
0 for real tokens and 1 for padding.
mems: a list of float32 Tensors in shape [mem_len, bsz, d_model], memory
from previous batches. The length of the list equals n_layer.
If None, no memory is used.
perm_mask: float32 Tensor in shape [len, len, bsz].
                If perm_mask[i, j, k] = 0, i attends to j in batch k;
if perm_mask[i, j, k] = 1, i does not attend to j in batch k.
If None, each position attends to all the others.
target_mapping: float32 Tensor in shape [num_predict, len, bsz].
If target_mapping[i, j, k] = 1, the i-th predict in batch k is
on the j-th token.
Only used during pretraining for partial prediction.
Set to None during finetuning.
inp_q: float32 Tensor in shape [len, bsz].
1 for tokens with losses and 0 for tokens without losses.
Only used during pretraining for two-stream attention.
Set to None during finetuning.
mem_len: int, the number of tokens to cache.
            reuse_len: int, the number of tokens in the current batch to be cached
and reused in the future.
bi_data: bool, whether to use bidirectional input pipeline.
Usually set to True during pretraining and False during finetuning.
clamp_len: int, clamp all relative distances larger than clamp_len.
-1 means no clamping.
same_length: bool, whether to use the same attention length for each token.
summary_type: str, "last", "first", "mean", or "attn". The method
to pool the input to get a vector representation.
"""
qlen, bsz = inp_k.shape
mlen = mems[0].shape[0] if mems is not None else 0
klen = mlen + qlen
##### Attention mask
# causal attention mask
if self.attn_type == 'uni':
            attn_mask = self._create_mask(qlen, mlen, same_length=self.same_length)
attn_mask = attn_mask[:, :, None, None]
elif self.attn_type == 'bi':
attn_mask = None
else:
raise ValueError('Unsupported attention type: {}'.format(self.attn_type))
# data mask: input mask & perm mask
if input_mask is not None and perm_mask is not None:
data_mask = input_mask[None] + perm_mask
elif input_mask is not None and perm_mask is None:
data_mask = input_mask[None]
elif input_mask is None and perm_mask is not None:
data_mask = perm_mask
else:
data_mask = None
if data_mask is not None:
# all mems can be attended to
mems_mask = torch.zeros([data_mask.shape[0], mlen, bsz], dtype=data_mask.dtype, device=data_mask.device)
data_mask = torch.cat([mems_mask, data_mask], dim=1)
if attn_mask is None:
attn_mask = data_mask[:, :, :, None]
else:
attn_mask += data_mask[:, :, :, None]
if attn_mask is not None:
attn_mask = (attn_mask > 0).float()
        if attn_mask is not None:
            non_tgt_mask = -torch.eye(qlen, dtype=attn_mask.dtype, device=attn_mask.device)
            non_tgt_mask = torch.cat([torch.zeros([qlen, mlen], dtype=attn_mask.dtype, device=attn_mask.device),
                                      non_tgt_mask], dim=-1)
            non_tgt_mask = ((attn_mask + non_tgt_mask[:, :, None, None]) > 0).to(dtype=attn_mask.dtype)
        else:
            non_tgt_mask = None
##### Word embedding
word_emb_k = self.word_embedding(inp_k)
output_h = self.dropout(word_emb_k)
if inp_q is not None:
if target_mapping is not None:
                word_emb_q = self.mask_emb.expand(target_mapping.shape[0], bsz, -1)
else:
inp_q_ext = inp_q[:, :, None]
                word_emb_q = inp_q_ext * self.mask_emb + (1 - inp_q_ext) * word_emb_k
output_g = self.dropout(word_emb_q)
else:
output_g = None
##### Segment embedding
if seg_id is not None:
# Convert `seg_id` to one-hot `seg_mat`
            mem_pad = torch.zeros([mlen, bsz], dtype=torch.long, device=seg_id.device)
cat_ids = torch.cat([mem_pad, seg_id], dim=0)
# `1` indicates not in the same segment [qlen x klen x bsz]
seg_mat = (seg_id[:, None] != cat_ids[None, :]).long()
# seg_mat = tf.one_hot(seg_mat, 2, dtype=tf_float)
else:
seg_mat = None
##### Positional encoding
        pos_emb = self.relative_positional_encoding(qlen, klen, bsz=bsz)
pos_emb = self.dropout(pos_emb)
##### Head mask if needed (for bertology/pruning)
# 1.0 in head_mask indicate we keep the head
# attention_probs has shape bsz x n_heads x N x N
# input head_mask has shape [num_heads] or [num_hidden_layers x num_heads]
# and head_mask is converted to shape [num_hidden_layers x batch x num_heads x seq_length x seq_length]
        if head_mask is not None:
            if head_mask.dim() == 1:
                head_mask = head_mask.unsqueeze(0).unsqueeze(0).unsqueeze(-1).unsqueeze(-1)
                head_mask = head_mask.expand(self.config.n_layer, -1, -1, -1, -1)
            elif head_mask.dim() == 2:
                head_mask = head_mask.unsqueeze(1).unsqueeze(-1).unsqueeze(-1)  # We can specify head_mask for each layer
            head_mask = head_mask.to(dtype=next(self.parameters()).dtype)  # switch to float if needed + fp16 compatibility
        else:
            head_mask = [None] * self.config.n_layer
new_mems = []
if mems is None:
mems = [None] * len(self.layer)
for i, layer_module in enumerate(self.layer):
# cache new mems
new_mems.append(self.cache_mem(output_h, mems[i]))
            output_h, output_g = layer_module(output_h, output_g,
                                              non_tgt_mask, attn_mask,
                                              pos_emb, seg_mat,
                                              mems=mems[i], target_mapping=target_mapping,
                                              head_mask=head_mask[i])
        output = self.dropout(output_g if output_g is not None else output_h)
        return output, new_mems
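
    # Illustrative usage sketch (not part of the original commit; assumes a config that also
    # provides the run-time attributes referenced above, e.g. attn_type, bi_data, clamp_len):
    #
    #   inp_k = torch.randint(0, config.vocab_size, (qlen, bsz))   # [len, bsz] token ids
    #   output, new_mems = model(inp_k, seg_id=seg_id, input_mask=input_mask, mems=mems)
    #   # `output` is [qlen, bsz, d_model]; `new_mems` is fed back as `mems` for the next segment.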
class XLNetLMHeadModel(XLNetPreTrainedModel):
"""XLNet model ("XLNet: Generalized Autoregressive Pretraining for Language Understanding"). """XLNet model ("XLNet: Generalized Autoregressive Pretraining for Language Understanding").
    Params:
@@ -473,10 +917,10 @@ class XLNetModel(XLNetPreTrainedModel):
        `encoded_layers`: controlled by `output_all_encoded_layers` argument:
            - `output_all_encoded_layers=True`: outputs a list of the full sequences of encoded-hidden-states at the end
                of each attention block (i.e. 12 full sequences for XLNet-base, 24 for XLNet-large), each
                encoded-hidden-state is a torch.FloatTensor of size [batch_size, sequence_length, d_model],
            - `output_all_encoded_layers=False`: outputs only the full sequence of hidden-states corresponding
                to the last attention block of shape [batch_size, sequence_length, d_model],
        `pooled_output`: a torch.FloatTensor of size [batch_size, d_model] which is the output of a
            classifier pretrained on top of the hidden state associated to the first character of the
            input (`CLS`) to train on the Next-Sentence task (see XLNet's paper).
@@ -487,16 +931,30 @@ class XLNetModel(XLNetPreTrainedModel):
    input_mask = torch.LongTensor([[1, 1, 1], [1, 1, 0]])
    token_type_ids = torch.LongTensor([[0, 0, 1], [0, 1, 0]])

    config = modeling.XLNetConfig(vocab_size_or_config_json_file=32000, d_model=768,
        n_layer=12, n_head=12, d_inner=3072)

    model = modeling.XLNetModel(config=config)
    all_encoder_layers, pooled_output = model(input_ids, token_type_ids, input_mask)
    ```
    """
    def __init__(self, config, run_config, output_attentions=False, keep_multihead_output=False):
        super(XLNetLMHeadModel, self).__init__(config)
        self.output_attentions = output_attentions
        self.attn_type = run_config.attn_type
        self.same_length = run_config.same_length

        self.word_embedding = nn.Embedding(config.vocab_size, config.d_model)
        self.mask_emb = nn.Parameter(torch.Tensor(1, 1, config.d_model))
        self.transformer = XLNetModel(config,
                                      output_attentions=output_attentions,
                                      keep_multihead_output=keep_multihead_output)
        self.lm_loss = nn.Linear(config.d_model, config.vocab_size, bias=True)
        self.dropout = nn.Dropout(config.dropout)

        # Tie weights
        if config.tie_weight:
            self.lm_loss.weight = self.word_embedding.weight

        self.apply(self.init_xlnet_weights)
    def prune_heads(self, heads_to_prune):
@@ -512,54 +970,56 @@ class XLNetModel(XLNetPreTrainedModel):
""" """
return [layer.attention.self.multihead_output for layer in self.encoder.layer] return [layer.attention.self.multihead_output for layer in self.encoder.layer]
    def forward(self, inp_k, seg_id=None, input_mask=None,
                mems=None, perm_mask=None, target_mapping=None, inp_q=None,
                output_all_encoded_layers=True, head_mask=None):
        """
        Args:
            inp_k: int32 Tensor in shape [len, bsz], the input token IDs.
            seg_id: int32 Tensor in shape [len, bsz], the input segment IDs.
            input_mask: float32 Tensor in shape [len, bsz], the input mask.
                0 for real tokens and 1 for padding.
            mems: a list of float32 Tensors in shape [mem_len, bsz, d_model], memory
                from previous batches. The length of the list equals n_layer.
                If None, no memory is used.
            perm_mask: float32 Tensor in shape [len, len, bsz].
                If perm_mask[i, j, k] = 0, i attends to j in batch k;
                if perm_mask[i, j, k] = 1, i does not attend to j in batch k.
                If None, each position attends to all the others.
            target_mapping: float32 Tensor in shape [num_predict, len, bsz].
                If target_mapping[i, j, k] = 1, the i-th predict in batch k is
                on the j-th token.
                Only used during pretraining for partial prediction.
                Set to None during finetuning.
            inp_q: float32 Tensor in shape [len, bsz].
                1 for tokens with losses and 0 for tokens without losses.
                Only used during pretraining for two-stream attention.
                Set to None during finetuning.
            mem_len: int, the number of tokens to cache.
            reuse_len: int, the number of tokens in the current batch to be cached
                and reused in the future.
            bi_data: bool, whether to use bidirectional input pipeline.
                Usually set to True during pretraining and False during finetuning.
            clamp_len: int, clamp all relative distances larger than clamp_len.
                -1 means no clamping.
            same_length: bool, whether to use the same attention length for each token.
            summary_type: str, "last", "first", "mean", or "attn". The method
                to pool the input to get a vector representation.
        """
        # run the core XLNet transformer on the raw inputs (it builds the attention masks,
        # positional encodings and the two attention streams internally)
        output, new_mems = self.transformer(inp_k, seg_id=seg_id, input_mask=input_mask,
                                            mems=mems, perm_mask=perm_mask,
                                            target_mapping=target_mapping, inp_q=inp_q,
                                            head_mask=head_mask)
logits = self.lm_loss(output)
# if self.output_attentions:
# all_attentions, encoded_layers = encoded_layers
# sequence_output = encoded_layers[-1]
# pooled_output = self.pooler(sequence_output)
# if not output_all_encoded_layers:
# encoded_layers = encoded_layers[-1]
# if self.output_attentions:
# return all_attentions, encoded_layers, pooled_output
return output, new_mems
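
    # Illustrative note (not part of the original commit): when `config.tie_weight` is set, the
    # `lm_loss` projection above shares its weight matrix with `word_embedding`, so the softmax
    # layer maps [len, bsz, d_model] hidden states to [len, bsz, vocab_size] logits while adding
    # only a bias vector of new parameters.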