Commit 45709d75 authored by thomwolf

model running with simple inputs

parent b407972e
@@ -17,6 +17,9 @@ from .modeling_transfo_xl import (TransfoXLConfig, TransfoXLModel, TransfoXLLMHe
from .modeling_gpt2 import (GPT2Config, GPT2Model,
GPT2LMHeadModel, GPT2DoubleHeadsModel, GPT2MultipleChoiceHead,
load_tf_weights_in_gpt2)
from .modeling_xlnet import (XLNetBaseConfig, XLNetConfig, XLNetRunConfig,
XLNetPreTrainedModel, XLNetModel, XLNetLMHeadModel,
load_tf_weights_in_xlnet)
from .optimization import BertAdam
from .optimization_openai import OpenAIAdam
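For a quick check that the new surface imports cleanly, a minimal smoke test (a sketch; it assumes this branch is on `PYTHONPATH`):

```python
from pytorch_pretrained_bert import (XLNetBaseConfig, XLNetConfig, XLNetRunConfig,
                                     XLNetPreTrainedModel, XLNetModel,
                                     XLNetLMHeadModel, load_tf_weights_in_xlnet)
print("XLNet symbols import fine")
```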
@@ -21,13 +21,13 @@ from __future__ import print_function
import argparse
import torch
from pytorch_pretrained_bert.modeling_xlnet import XLNetConfig, XLNetRunConfig, XLNetModel, load_tf_weights_in_xlnet
from pytorch_pretrained_bert.modeling_xlnet import XLNetConfig, XLNetRunConfig, XLNetLMHeadModel, load_tf_weights_in_xlnet
def convert_xlnet_checkpoint_to_pytorch(tf_checkpoint_path, bert_config_file, pytorch_dump_path):
# Initialise PyTorch model
config = XLNetConfig.from_json_file(bert_config_file)
print("Building PyTorch model from configuration: {}".format(str(config)))
model = XLNetModel(config)
model = XLNetLMHeadModel(config)
# Load weights from tf checkpoint
load_tf_weights_in_xlnet(model, tf_checkpoint_path)
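The converter now instantiates `XLNetLMHeadModel` so the checkpoint's LM-head weights have somewhere to land. A hypothetical direct call (all paths are placeholders, and the module name is assumed from the script's file name; the script is normally driven through its argparse CLI):

```python
from pytorch_pretrained_bert.convert_xlnet_checkpoint_to_pytorch import (
    convert_xlnet_checkpoint_to_pytorch)

convert_xlnet_checkpoint_to_pytorch(
    tf_checkpoint_path="/path/to/xlnet_model.ckpt",   # placeholder path
    bert_config_file="/path/to/xlnet_config.json",    # placeholder path
    pytorch_dump_path="/path/to/pytorch_model.bin")   # placeholder path
```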
@@ -867,7 +867,7 @@ class BertModel(BertPreTrainedModel):
if head_mask is not None:
if head_mask.dim() == 1:
head_mask = head_mask.unsqueeze(0).unsqueeze(0).unsqueeze(-1).unsqueeze(-1)
head_mask = head_mask.expand_as(self.config.num_hidden_layers, -1, -1, -1, -1)
head_mask = head_mask.expand(self.config.num_hidden_layers, -1, -1, -1, -1)
elif head_mask.dim() == 2:
head_mask = head_mask.unsqueeze(1).unsqueeze(-1).unsqueeze(-1) # We can specify head_mask for each layer
head_mask = head_mask.to(dtype=next(self.parameters()).dtype) # switch to float if needed + fp16 compatibility
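`expand_as` expects a template tensor, so passing sizes to it was a bug; `expand` is the call that takes target sizes. The same one-line fix recurs below in GPT2Model and OpenAIGPTModel. A minimal sketch of the broadcast being performed:

```python
import torch

num_hidden_layers, num_heads = 12, 12
head_mask = torch.ones(num_heads)     # 1.0 = keep this head
head_mask = head_mask.unsqueeze(0).unsqueeze(0).unsqueeze(-1).unsqueeze(-1)
print(head_mask.shape)                # torch.Size([1, 1, 12, 1, 1])
head_mask = head_mask.expand(num_hidden_layers, -1, -1, -1, -1)
print(head_mask.shape)                # torch.Size([12, 1, 12, 1, 1])
```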
@@ -722,7 +722,7 @@ class GPT2Model(GPT2PreTrainedModel):
if head_mask is not None:
if head_mask.dim() == 1:
head_mask = head_mask.unsqueeze(0).unsqueeze(0).unsqueeze(-1).unsqueeze(-1)
head_mask = head_mask.expand_as(self.config.n_layer, -1, -1, -1, -1)
head_mask = head_mask.expand(self.config.n_layer, -1, -1, -1, -1)
elif head_mask.dim() == 2:
head_mask = head_mask.unsqueeze(1).unsqueeze(-1).unsqueeze(-1) # We can specify head_mask for each layer
head_mask = head_mask.to(dtype=next(self.parameters()).dtype) # switch to float if needed + fp16 compatibility
@@ -718,7 +718,7 @@ class OpenAIGPTModel(OpenAIGPTPreTrainedModel):
if head_mask is not None:
if head_mask.dim() == 1:
head_mask = head_mask.unsqueeze(0).unsqueeze(0).unsqueeze(-1).unsqueeze(-1)
head_mask = head_mask.expand_as(self.config.n_layer, -1, -1, -1, -1)
head_mask = head_mask.expand(self.config.n_layer, -1, -1, -1, -1)
elif head_mask.dim() == 2:
head_mask = head_mask.unsqueeze(1).unsqueeze(-1).unsqueeze(-1) # We can specify head_mask for each layer
head_mask = head_mask.to(dtype=next(self.parameters()).dtype) # switch to float if needed + fp16 compatibility
@@ -29,6 +29,7 @@ from io import open
import torch
from torch import nn
from torch.nn import functional as F
from torch.nn import CrossEntropyLoss
from .file_utils import cached_path, WEIGHTS_NAME, CONFIG_NAME
@@ -126,32 +127,27 @@ def swish(x):
ACT2FN = {"gelu": gelu, "relu": torch.nn.functional.relu, "swish": swish}
def positional_embedding(pos_seq, inv_freq, bsz=None):
sinusoid_inp = torch.einsum('i,d->id', pos_seq, inv_freq)
pos_emb = torch.cat([tf.sin(sinusoid_inp), tf.cos(sinusoid_inp)], -1)
pos_emb = pos_emb[:, None, :]
if bsz is not None:
pos_emb = pos_emb.expand(1, bsz, 1)
return pos_emb
class XLNetBaseConfig(object):
@classmethod
def from_dict(cls, json_object):
"""Constructs a `XLNetConfig` from a Python dictionary of parameters."""
config = XLNetConfig(vocab_size_or_config_json_file=-1)
"""Constructs a `XLNetBaseConfig` from a Python dictionary of parameters."""
config = cls(vocab_size_or_config_json_file=-1)
for key, value in json_object.items():
config.__dict__[key] = value
return config
@classmethod
def from_json_file(cls, json_file):
"""Constructs a `XLNetConfig` from a json file of parameters."""
"""Constructs a `XLNetBaseConfig` from a json file of parameters."""
with open(json_file, "r", encoding='utf-8') as reader:
text = reader.read()
return cls.from_dict(json.loads(text))
def update(self, other):
dict_b = other.to_dict()
for key, value in dict_b.items():
self.__dict__[key] = value
def __repr__(self):
return str(self.to_json_string())
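With `from_dict` now built on `cls`, every `XLNetBaseConfig` subclass round-trips through a plain dict, and `update` folds one config's fields into another. A short sketch (constructor arguments borrowed from the new test file at the end of this commit; `to_dict` is assumed to exist on the base class, as `update` already relies on it):

```python
config = XLNetConfig(vocab_size_or_config_json_file=32000, d_model=768)
run_config = XLNetRunConfig(mem_len=30, clamp_len=15, same_length=False,
                            reuse_len=15, bi_data=False)
config.update(run_config)       # copy the run-time fields onto the model config
assert config.mem_len == 30

clone = XLNetConfig.from_dict(config.to_dict())   # round-trip through a dict
```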
@@ -181,6 +177,7 @@ class XLNetConfig(XLNetBaseConfig):
d_inner=4096,
ff_activation="gelu",
untie_r=True,
attn_type="bi",
max_position_embeddings=512,
initializer_range=0.02,
@@ -198,6 +195,7 @@ class XLNetConfig(XLNetBaseConfig):
ff_activation: The non-linear activation function (function or string) in the
encoder and pooler. If string, "gelu", "relu" and "swish" are supported.
untie_r: untie relative position biases
attn_type: 'bi' for XLNet, 'uni' for Transformer-XL
dropout: The dropout probability for all fully connected
layers in the embeddings, encoder, and pooler.
@@ -226,6 +224,7 @@ class XLNetConfig(XLNetBaseConfig):
self.ff_activation = ff_activation
self.d_inner = d_inner
self.untie_r = untie_r
self.attn_type = attn_type
self.max_position_embeddings = max_position_embeddings
self.initializer_range = initializer_range
self.layer_norm_eps = layer_norm_eps
@@ -304,15 +303,15 @@ class XLNetRelativeAttention(nn.Module):
def __init__(self, config, output_attentions=False, keep_multihead_output=False):
super(XLNetRelativeAttention, self).__init__()
self.output_attentions = output_attentions
if config.d_model % config.num_attention_heads != 0:
if config.d_model % config.n_head != 0:
raise ValueError(
"The hidden size (%d) is not a multiple of the number of attention "
"heads (%d)" % (config.d_model, config.num_attention_heads))
"heads (%d)" % (config.d_model, config.n_head))
self.output_attentions = output_attentions
self.keep_multihead_output = keep_multihead_output
self.multihead_output = None
self.n_head = config.num_attention_heads
self.n_head = config.n_head
self.d_head = config.d_head
self.d_model = config.d_model
self.scale = 1 / (config.d_head ** 0.5)
@@ -326,7 +325,7 @@ class XLNetRelativeAttention(nn.Module):
self.r_r_bias = nn.Parameter(torch.Tensor(self.n_head, self.d_head))
self.r_s_bias = nn.Parameter(torch.Tensor(self.n_head, self.d_head))
self.r_w_bias = nn.Parameter(torch.Tensor(self.n_head, self.d_head))
self.seg_embed = nn.Parameter(torch.Tensor(self.n_head, 2, self.d_head))
self.seg_embed = nn.Parameter(torch.Tensor(2, self.n_head, self.d_head))
self.LayerNorm = XLNetLayerNorm(config.d_model, eps=config.layer_norm_eps)
self.dropout = nn.Dropout(config.dropout)
@@ -334,6 +333,18 @@ class XLNetRelativeAttention(nn.Module):
def prune_heads(self, heads):
raise NotImplementedError
@staticmethod
def rel_shift(x, klen=-1):
"""perform relative shift to form the relative attention score."""
x_size = x.shape
x = x.reshape(x_size[1], x_size[0], x_size[2], x_size[3])
x = x[1:, ...]
x = x.reshape(x_size[0], x_size[1] - 1, x_size[2], x_size[3])
x = x[:, 0:klen, :, :]
return x
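`rel_shift` is the Transformer-XL realignment trick: transposing the first two dims, dropping one row, and reshaping back shifts each query row by its own index, so that column j of the result lines up with key position j. A toy sketch of the bookkeeping (values only illustrate which entries survive):

```python
import torch

x = torch.arange(12.).view(3, 4, 1, 1)     # [qlen=3, pos=4, bsz=1, n_head=1]
x_size = x.shape
y = x.reshape(x_size[1], x_size[0], x_size[2], x_size[3])
y = y[1:, ...]                              # drop one row of the transposed view
y = y.reshape(x_size[0], x_size[1] - 1, x_size[2], x_size[3])
y = y[:, 0:3, :, :]                         # trim to klen=3
print(y.squeeze())                          # row i satisfies y[i, j] == x[i, j - i + 3]
```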
def rel_attn_core(self, q_head, k_head_h, v_head_h, k_head_r, seg_mat=None, attn_mask=None):
"""Core relative positional attention operations."""
@@ -342,7 +353,7 @@ class XLNetRelativeAttention(nn.Module):
# position based attention score
bd = torch.einsum('ibnd,jbnd->ijbn', q_head + self.r_r_bias, k_head_r)
bd = rel_shift(bd, klen=torch.shape(ac)[1])
bd = self.rel_shift(bd, klen=ac.shape[1])
# segment based attention score
if seg_mat is None:
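Both score terms are einsum contractions over the head dimension, giving scores of shape [qlen, pos, bsz, n_head]; `bd` is then realigned with `rel_shift` and trimmed to `ac`'s key length. A shape-only sketch (for the bidirectional layout the positional keys span klen + qlen positions):

```python
import torch

qlen, klen, bsz, n_head, d_head = 5, 7, 2, 4, 8
q_head = torch.randn(qlen, bsz, n_head, d_head)
k_head_h = torch.randn(klen, bsz, n_head, d_head)          # content keys
k_head_r = torch.randn(klen + qlen, bsz, n_head, d_head)   # positional keys

ac = torch.einsum('ibnd,jbnd->ijbn', q_head, k_head_h)     # [5, 7, 2, 4]
bd = torch.einsum('ibnd,jbnd->ijbn', q_head, k_head_r)     # [5, 12, 2, 4]
```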
@@ -426,7 +437,6 @@ class XLNetRelativeAttention(nn.Module):
# post processing
output_g = self.post_attention(g, attn_vec_g)
attention_output = output_h, output_g
else:
###### Multi-head attention with relative positional encoding
if mems is not None and mems.dim() > 1:
@@ -447,7 +457,8 @@ class XLNetRelativeAttention(nn.Module):
q_head_h, k_head_h, v_head_h, k_head_r, seg_mat=seg_mat, attn_mask=attn_mask_h)
# post processing
attention_output = self.post_attention(h, attn_vec)
output_h = self.post_attention(h, attn_vec)
output_g = None
# Mask heads if we want to
@@ -467,7 +478,7 @@ class XLNetRelativeAttention(nn.Module):
# attentions, self_output = self_output
# if self.output_attentions:
# return attentions, attention_output
return attention_output
return output_h, output_g
class XLNetFeedForward(nn.Module):
def __init__(self, config):
@@ -481,13 +492,15 @@ class XLNetFeedForward(nn.Module):
else:
self.activation_function = config.ff_activation
def forward(self, hidden_states, input_tensor):
hidden_states = self.layer_1(hidden_states)
hidden_states = self.activation_function(hidden_states)
hidden_states = self.layer_2(hidden_states)
hidden_states = self.dropout(hidden_states)
hidden_states = self.LayerNorm(hidden_states + input_tensor)
return hidden_states
def forward(self, inp):
output = inp
output = self.layer_1(output)
output = self.activation_function(output)
output = self.dropout(output)
output = self.layer_2(output)
output = self.dropout(output)
output = self.LayerNorm(output + inp)
return output
class XLNetLayer(nn.Module):
def __init__(self, config, output_attentions=False, keep_multihead_output=False):
@@ -500,13 +513,13 @@ class XLNetLayer(nn.Module):
def forward(self, output_h, output_g,
attn_mask_h, attn_mask_g,
r, seg_mat,
two_streams=False, mems=None, target_mapping=None, head_mask=None):
r, seg_mat,
mems=None, target_mapping=None, head_mask=None):
output_h, output_g = self.rel_attn(output_h, output_g,
attn_mask_h, attn_mask_g,
r, seg_mat,
mems=mems, target_mapping=target_mapping, head_mask=head_mask)
if two_streams:
if output_g is not None:
output_g = self.ff(output_g)
output_h = self.ff(output_h)
@@ -520,9 +533,9 @@ class XLNetPreTrainedModel(nn.Module):
"""
def __init__(self, config, *inputs, **kwargs):
super(XLNetPreTrainedModel, self).__init__()
if not isinstance(config, XLNetConfig):
if not isinstance(config, XLNetBaseConfig):
raise ValueError(
"Parameter config in `{}(config)` should be an instance of class `XLNetConfig`. "
"Parameter config in `{}(config)` should be an instance of class `XLNetBaseConfig`. "
"To create a model from a Google pretrained model use "
"`model = {}.from_pretrained(PRETRAINED_MODEL_NAME)`".format(
self.__class__.__name__, self.__class__.__name__
@@ -668,26 +681,41 @@ class XLNetPreTrainedModel(nn.Module):
class XLNetModel(XLNetPreTrainedModel):
def __init__(self, config, output_attentions=False, keep_multihead_output=False):
super(XLNetModel, self).__init__()
super(XLNetModel, self).__init__(config)
self.output_attentions = output_attentions
self.mem_len = config.mem_len
self.reuse_len = config.reuse_len
layer = XLNetLayer(config, output_attentions=output_attentions,
keep_multihead_output=keep_multihead_output)
self.layer = nn.ModuleList([copy.deepcopy(layer) for _ in range(config.num_hidden_layers)])
self.d_model = config.d_model
self.same_length = config.same_length
self.attn_type = config.attn_type
self.bi_data = config.bi_data
self.clamp_len = config.clamp_len
@classmethod
def _create_mask(qlen, mlen, dtype=torch.float, same_length=False):
"""create causal attention mask."""
attn_mask = torch.ones([qlen, qlen], dtype=dtype)
mask_u = tf.matrix_band_part(attn_mask, 0, -1)
mask_dia = tf.matrix_band_part(attn_mask, 0, 0)
attn_mask_pad = tf.zeros([qlen, mlen], dtype=dtype)
ret = tf.concat([attn_mask_pad, mask_u - mask_dia], 1)
if same_length:
mask_l = tf.matrix_band_part(attn_mask, -1, 0)
ret = tf.concat([ret[:, :qlen] + mask_l - mask_dia, ret[:, qlen:]], 1)
layer = XLNetLayer(config, output_attentions=output_attentions,
keep_multihead_output=keep_multihead_output)
self.layer = nn.ModuleList([copy.deepcopy(layer) for _ in range(config.n_layer)])
self.dropout = nn.Dropout(config.dropout)
def create_mask(self, qlen, mlen):
""" create causal attention mask.
float mask where 1.0 indicates masked, 0.0 indicates not-masked.
same_length=False: same_length=True:
<mlen > < qlen > <mlen > < qlen >
^ [0 0 0 0 0 1 1 1 1] [0 0 0 0 0 1 1 1 1]
[0 0 0 0 0 0 1 1 1] [1 0 0 0 0 0 1 1 1]
qlen [0 0 0 0 0 0 0 1 1] [1 1 0 0 0 0 0 1 1]
[0 0 0 0 0 0 0 0 1] [1 1 1 0 0 0 0 0 1]
v [0 0 0 0 0 0 0 0 0] [1 1 1 1 0 0 0 0 0]
"""
attn_mask = torch.ones([qlen, qlen])
mask_up = torch.triu(attn_mask, diagonal=1)
attn_mask_pad = torch.zeros([qlen, mlen])
ret = torch.cat([attn_mask_pad, mask_up], dim=1)
if self.same_length:
mask_lo = torch.tril(attn_mask, diagonal=-1)
ret = torch.cat([ret[:, :qlen] + mask_lo, ret[:, qlen:]], dim=1)
ret = ret.to(next(self.parameters()))
return ret
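A standalone sketch of the same_length=False branch for qlen=4 and mlen=2, reproducing the docstring diagram (1.0 = masked):

```python
import torch

qlen, mlen = 4, 2
attn_mask = torch.ones(qlen, qlen)
mask_up = torch.triu(attn_mask, diagonal=1)   # strictly-upper triangle: future tokens
ret = torch.cat([torch.zeros(qlen, mlen), mask_up], dim=1)
print(ret)
# tensor([[0., 0., 0., 1., 1., 1.],
#         [0., 0., 0., 0., 1., 1.],
#         [0., 0., 0., 0., 0., 1.],
#         [0., 0., 0., 0., 0., 0.]])
```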
def cache_mem(self, curr_out, prev_mem):
@@ -705,10 +733,21 @@ class XLNetModel(XLNetPreTrainedModel):
return new_mem.detach()
def relative_positional_encoding(self, qlen, klen, bsz=None, dtype=torch.float):
@staticmethod
def positional_embedding(pos_seq, inv_freq, bsz=None):
sinusoid_inp = torch.einsum('i,d->id', pos_seq, inv_freq)
pos_emb = torch.cat([torch.sin(sinusoid_inp), torch.cos(sinusoid_inp)], dim=-1)
pos_emb = pos_emb[:, None, :]
if bsz is not None:
pos_emb = pos_emb.expand(-1, bsz, -1)
return pos_emb
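The sinusoid construction is the standard Transformer recipe, except the positions are relative offsets and the batch axis is inserted in the middle, matching the model's [len, bsz, d_model] layout. Shape sketch:

```python
import torch

d_model = 8
pos_seq = torch.arange(3., 0., -1.0)    # relative positions 3, 2, 1
inv_freq = 1. / (10000 ** (torch.arange(0., d_model, 2.0) / d_model))

sinusoid = torch.einsum('i,d->id', pos_seq, inv_freq)    # [3, d_model/2]
pos_emb = torch.cat([torch.sin(sinusoid), torch.cos(sinusoid)], dim=-1)[:, None, :]
print(pos_emb.shape)                     # torch.Size([3, 1, 8])
```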
def relative_positional_encoding(self, qlen, klen, bsz=None):
"""create relative positional encoding."""
freq_seq = torch.zrange(0, d_model, 2.0, dtype=dtype)
inv_freq = 1 / (10000 ** (freq_seq / self.config.d_model))
freq_seq = torch.arange(0, self.d_model, 2.0, dtype=torch.float)
inv_freq = 1 / (10000 ** (freq_seq / self.d_model))
if self.attn_type == 'bi':
# beg, end = klen - 1, -qlen
@@ -720,51 +759,52 @@ class XLNetModel(XLNetPreTrainedModel):
raise ValueError('Unknown `attn_type` {}.'.format(self.attn_type))
if self.bi_data:
fwd_pos_seq = torch.arange(beg, end, -1.0, dtype=dtype)
bwd_pos_seq = torch.arange(-beg, -end, 1.0, dtype=dtype)
fwd_pos_seq = torch.arange(beg, end, -1.0, dtype=torch.float)
bwd_pos_seq = torch.arange(-beg, -end, 1.0, dtype=torch.float)
if self.clamp_len > 0:
fwd_pos_seq = fwd_pos_seq.clamp(-self.clamp_len, self.clamp_len)
bwd_pos_seq = bwd_pos_seq.clamp(-self.clamp_len, self.clamp_len)
if bsz is not None:
fwd_pos_emb = positional_embedding(fwd_pos_seq, inv_freq, bsz//2)
bwd_pos_emb = positional_embedding(bwd_pos_seq, inv_freq, bsz//2)
fwd_pos_emb = self.positional_embedding(fwd_pos_seq, inv_freq, bsz//2)
bwd_pos_emb = self.positional_embedding(bwd_pos_seq, inv_freq, bsz//2)
else:
fwd_pos_emb = positional_embedding(fwd_pos_seq, inv_freq)
bwd_pos_emb = positional_embedding(bwd_pos_seq, inv_freq)
fwd_pos_emb = self.positional_embedding(fwd_pos_seq, inv_freq)
bwd_pos_emb = self.positional_embedding(bwd_pos_seq, inv_freq)
pos_emb = torch.cat([fwd_pos_emb, bwd_pos_emb], dim=1)
else:
fwd_pos_seq = torch.arange(beg, end, -1.0, dtype=dtype)
fwd_pos_seq = torch.arange(beg, end, -1.0)
if self.clamp_len > 0:
fwd_pos_seq = fwd_pos_seq.clamp(-self.clamp_len, self.clamp_len)
pos_emb = positional_embedding(fwd_pos_seq, inv_freq, bsz)
pos_emb = self.positional_embedding(fwd_pos_seq, inv_freq, bsz)
pos_emb = pos_emb.to(next(self.parameters()))
return pos_emb
def forward(self, inp_k, seg_id=None, input_mask=None,
def forward(self, word_emb_k, seg_id=None, input_mask=None,
mems=None, perm_mask=None, target_mapping=None, inp_q=None,
output_all_encoded_layers=True, head_mask=None):
"""
Args:
inp_k: int32 Tensor in shape [len, bsz], the input token IDs.
word_emb_k: float32 Tensor in shape [len, bsz, d_model], the input token embeddings.
seg_id: int32 Tensor in shape [len, bsz], the input segment IDs.
input_mask: float32 Tensor in shape [len, bsz], the input mask.
input_mask: [optional] float32 Tensor in shape [len, bsz], the input mask.
0 for real tokens and 1 for padding.
mems: a list of float32 Tensors in shape [mem_len, bsz, d_model], memory
mems: [optional] a list of float32 Tensors in shape [mem_len, bsz, d_model], memory
from previous batches. The length of the list equals n_layer.
If None, no memory is used.
perm_mask: float32 Tensor in shape [len, len, bsz].
perm_mask: [optional] float32 Tensor in shape [len, len, bsz].
If perm_mask[i, j, k] = 0, i attends to j in batch k;
if perm_mask[i, j, k] = 1, i does not attend to j in batch k.
If None, each position attends to all the others.
target_mapping: float32 Tensor in shape [num_predict, len, bsz].
target_mapping: [optional] float32 Tensor in shape [num_predict, len, bsz].
If target_mapping[i, j, k] = 1, the i-th prediction in batch k is
on the j-th token.
Only used during pretraining for partial prediction.
Set to None during finetuning.
inp_q: float32 Tensor in shape [len, bsz].
inp_q: [optional] float32 Tensor in shape [len, bsz].
1 for tokens with losses and 0 for tokens without losses.
Only used during pretraining for two-stream attention.
Set to None during finetuning.
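To make the pretraining arguments concrete, a hypothetical setup that predicts only the last of five tokens (conventions as documented above: batch dimension last, 1.0 = blocked or selected):

```python
import torch

qlen, bsz = 5, 2
perm_mask = torch.zeros(qlen, qlen, bsz)
perm_mask[:, -1, :] = 1.0        # no position may attend to the target token
target_mapping = torch.zeros(1, qlen, bsz)
target_mapping[0, -1, :] = 1.0   # the single prediction sits on the last token
inp_q = torch.zeros(qlen, bsz)
inp_q[-1, :] = 1.0               # the loss is computed on the last token only
```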
@@ -780,14 +820,16 @@ class XLNetModel(XLNetPreTrainedModel):
summary_type: str, "last", "first", "mean", or "attn". The method
to pool the input to get a vector representation.
"""
qlen, bsz = inp_k.shape
qlen, bsz = word_emb_k.shape[0], word_emb_k.shape[1]
mlen = mems[0].shape[0] if mems is not None else 0
klen = mlen + qlen
dtype_float = word_emb_k.dtype
device = word_emb_k.device
##### Attention mask
# causal attention mask
if self.attn_type == 'uni':
attn_mask = _create_mask(qlen, mlen, inp_k.dtype, self.same_length)
attn_mask = self.create_mask(qlen, mlen)
attn_mask = attn_mask[:, :, None, None]
elif self.attn_type == 'bi':
attn_mask = None
@@ -806,7 +848,7 @@ class XLNetModel(XLNetPreTrainedModel):
if data_mask is not None:
# all mems can be attended to
mems_mask = torch.zeros([data_mask.shape[0], mlen, bsz], dtype=data_mask.dtype, device=data_mask.device)
mems_mask = torch.zeros([data_mask.shape[0], mlen, bsz]).to(data_mask)
data_mask = torch.cat([mems_mask, data_mask], dim=1)
if attn_mask is None:
attn_mask = data_mask[:, :, :, None]
@@ -814,23 +856,20 @@ class XLNetModel(XLNetPreTrainedModel):
attn_mask += data_mask[:, :, :, None]
if attn_mask is not None:
attn_mask = (attn_mask > 0).float()
attn_mask = (attn_mask > 0).to(dtype_float)
if attn_mask is not None:
non_tgt_mask = -tf.eye(qlen, dtype=tf_float)
non_tgt_mask = tf.concat([tf.zeros([qlen, mlen], dtype=tf_float),
non_tgt_mask], axis=-1)
non_tgt_mask = tf.cast((attn_mask + non_tgt_mask[:, :, None, None]) > 0,
dtype=tf_float)
non_tgt_mask = -torch.eye(qlen).to(attn_mask)
non_tgt_mask = torch.cat([torch.zeros([qlen, mlen]).to(attn_mask), non_tgt_mask], dim=-1)
non_tgt_mask = ((attn_mask + non_tgt_mask[:, :, None, None]) > 0).to(attn_mask)
else:
non_tgt_mask = None
##### Word embedding
word_emb_k = self.word_embedding(inp_k)
##### Process Word embeddings and prepare h & g hidden states
output_h = self.dropout(word_emb_k)
if inp_q is not None:
if target_mapping is not None:
word_emb_q = mask_emb.expand(target_mapping.shape[0], bsz, 1)
word_emb_q = mask_emb.expand(target_mapping.shape[0], bsz, -1)
else:
inp_q_ext = inp_q[:, :, None]
word_emb_q = inp_q_ext * mask_emb + (1 - inp_q_ext) * word_emb_k
@@ -841,33 +880,33 @@ class XLNetModel(XLNetPreTrainedModel):
##### Segment embedding
if seg_id is not None:
# Convert `seg_id` to one-hot `seg_mat`
mem_pad = torch.zeros([mlen, bsz], dtype=torch.long)
mem_pad = torch.zeros([mlen, bsz], dtype=torch.long, device=device)
cat_ids = torch.cat([mem_pad, seg_id], dim=0)
# `1` indicates not in the same segment [qlen x klen x bsz]
seg_mat = (seg_id[:, None] != cat_ids[None, :]).long()
# seg_mat = tf.one_hot(seg_mat, 2, dtype=tf_float)
seg_mat = F.one_hot(seg_mat, num_classes=2).to(dtype_float)
else:
seg_mat = None
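TF's one-hot becomes `F.one_hot` here; note the relative segment encoding only keeps same-vs-different, not the raw segment IDs. A small sketch with no memory (mlen = 0):

```python
import torch
from torch.nn import functional as F

seg_id = torch.tensor([[0, 0], [0, 1], [1, 1]])           # [qlen=3, bsz=2]
cat_ids = seg_id                                          # mlen == 0: nothing to pad
seg_mat = (seg_id[:, None] != cat_ids[None, :]).long()    # [3, 3, 2], 1 = different segment
seg_mat = F.one_hot(seg_mat, num_classes=2).float()       # [3, 3, 2, 2]
print(seg_mat.shape)
```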
##### Positional encoding
pos_emb = relative_positional_encoding(qlen, klen, bsz=bsz, dtype=inp_k.dtype)
pos_emb = self.relative_positional_encoding(qlen, klen, bsz=bsz)
pos_emb = self.dropout(pos_emb)
##### Head mask if needed (for bertology/pruning)
# 1.0 in head_mask indicates we keep the head
# attention_probs has shape bsz x n_heads x N x N
# input head_mask has shape [num_heads] or [num_hidden_layers x num_heads]
# and head_mask is converted to shape [num_hidden_layers x batch x num_heads x seq_length x seq_length]
# input head_mask has shape [num_heads] or [n_layer x num_heads]
# and head_mask is converted to shape [n_layer x batch x num_heads x seq_length x seq_length]
if head_mask is not None:
if head_mask.dim() == 1:
head_mask = head_mask.unsqueeze(0).unsqueeze(0).unsqueeze(-1).unsqueeze(-1)
head_mask = head_mask.expand_as(self.config.num_hidden_layers, -1, -1, -1, -1)
head_mask = head_mask.expand(self.config.n_layer, -1, -1, -1, -1)
elif head_mask.dim() == 2:
head_mask = head_mask.unsqueeze(1).unsqueeze(-1).unsqueeze(-1) # We can specify head_mask for each layer
head_mask = head_mask.to(dtype=next(self.parameters()).dtype) # switch to float if needed + fp16 compatibility
else:
head_mask = [None] * self.config.num_hidden_layers
head_mask = [None] * self.config.n_layer
new_mems = []
if mems is None:
@@ -878,14 +917,14 @@ class XLNetModel(XLNetPreTrainedModel):
new_mems.append(self.cache_mem(output_h, mems[i]))
output_h, output_g = layer_module(output_h, output_g,
attn_mask_h, attn_mask_g,
r, seg_mat,
attn_mask_h=non_tgt_mask, attn_mask_g=attn_mask,
r=pos_emb, seg_mat=seg_mat,
mems=mems[i], target_mapping=target_mapping,
head_mask=head_mask)
output = self.dropout(output_g if output_g is not None else output_h)
return output
return output, new_mems
class XLNetLMHeadModel(XLNetPreTrainedModel):
@@ -932,28 +971,27 @@ class XLNetLMHeadModel(XLNetPreTrainedModel):
token_type_ids = torch.LongTensor([[0, 0, 1], [0, 1, 0]])
config = modeling.XLNetConfig(vocab_size_or_config_json_file=32000, d_model=768,
num_hidden_layers=12, num_attention_heads=12, intermediate_size=3072)
n_layer=12, num_attention_heads=12, intermediate_size=3072)
model = modeling.XLNetModel(config=config)
all_encoder_layers, pooled_output = model(input_ids, token_type_ids, input_mask)
```
"""
def __init__(self, config, run_config, output_attentions=False, keep_multihead_output=False):
def __init__(self, config, output_attentions=False, keep_multihead_output=False):
super(XLNetLMHeadModel, self).__init__(config)
self.output_attentions = output_attentions
self.attn_type = run_config.attn_type
self.same_length = run_config.same_length
self.attn_type = config.attn_type
self.same_length = config.same_length
self.word_embedding = nn.Embedding(config.vocab_size, config.d_model)
self.mask_emb = nn.Parameter(torch.Tensor(1, 1, self.d_model))
self.transformer = XLNetModel(config,
output_attentions=output_attentions,
keep_multihead_output=keep_multihead_output)
self.mask_emb = nn.Parameter(torch.Tensor(1, 1, config.d_model))
self.transformer = XLNetModel(config, output_attentions=output_attentions,
keep_multihead_output=keep_multihead_output)
self.lm_loss = nn.Linear(config.d_model, config.vocab_size, bias=True)
self.dropout = nn.Dropout(config.dropout)
# Tie weights
if config.tie_weight:
self.lm_loss.weight = self.word_embedding.weight
self.lm_loss.weight = self.word_embedding.weight
self.apply(self.init_xlnet_weights)
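Weight tying is now unconditional: the LM projection reuses the input embedding matrix, as in the TF implementation. A quick sanity check (a sketch; `config` as assembled in the tests below):

```python
model = XLNetLMHeadModel(config)
assert model.lm_loss.weight is model.word_embedding.weight   # same Parameter object
```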
@@ -972,7 +1010,7 @@ class XLNetLMHeadModel(XLNetPreTrainedModel):
def forward(self, inp_k, seg_id=None, input_mask=None,
mems=None, perm_mask=None, target_mapping=None, inp_q=None,
output_all_encoded_layers=True, head_mask=None):
target=None, output_all_encoded_layers=True, head_mask=None):
"""
Args:
inp_k: int32 Tensor in shape [len, bsz], the input token IDs.
@@ -1007,13 +1045,21 @@ class XLNetLMHeadModel(XLNetPreTrainedModel):
summary_type: str, "last", "first", "mean", or "attn". The method
to pool the input to get a vector representation.
"""
output, new_mems = self.transformer(output_h, non_tgt_mask, r, seg_mat,
output_g=output_g, attn_mask_g=attn_mask,
mems=mems, target_mapping=target_mapping,
head_mask=head_mask)
word_emb_k = self.word_embedding(inp_k)
output, new_mems = self.transformer(word_emb_k, seg_id, input_mask,
mems, perm_mask, target_mapping, inp_q,
output_all_encoded_layers, head_mask)
logits = self.lm_loss(output)
if target is not None:
# Flatten the tokens
loss_fct = CrossEntropyLoss(ignore_index=-1)
loss = loss_fct(logits.view(-1, logits.size(-1)),
target.view(-1))
return loss, new_mems
# if self.output_attentions:
# all_attentions, encoded_layers = encoded_layers
# sequence_output = encoded_layers[-1]
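End to end, the two call modes now look like this, mirroring the new test file below (sizes borrowed from XLNetModelTester; with the default CrossEntropyLoss reduction the returned loss is a scalar):

```python
import torch
from pytorch_pretrained_bert import XLNetConfig, XLNetRunConfig, XLNetLMHeadModel

config = XLNetConfig(vocab_size_or_config_json_file=99, d_model=32, n_head=4,
                     d_inner=128, n_layer=5, untie_r=True,
                     max_position_embeddings=10)
config.update(XLNetRunConfig(mem_len=30, clamp_len=15, same_length=False,
                             reuse_len=15, bi_data=False))
model = XLNetLMHeadModel(config)
model.eval()

input_ids = torch.randint(0, 99, (7, 13))            # [len, bsz]
labels = torch.randint(0, 99, (7, 13))

loss, new_mems = model(input_ids, target=labels)     # scalar loss + updated memory
logits, new_mems = model(input_ids, mems=new_mems)   # [len, bsz, vocab] logits
```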
# coding=utf-8
# Copyright 2018 The Google AI Language Team Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import os
import unittest
import json
import random
import shutil
import pytest
import torch
from pytorch_pretrained_bert import (XLNetConfig, XLNetRunConfig, XLNetModel, XLNetLMHeadModel)
from pytorch_pretrained_bert.modeling_xlnet import PRETRAINED_MODEL_ARCHIVE_MAP
class XLNetModelTest(unittest.TestCase):
class XLNetModelTester(object):
def __init__(self,
parent,
batch_size=13,
seq_length=7,
mem_len=30,
clamp_len=15,
reuse_len=15,
is_training=True,
use_labels=True,
vocab_size=99,
cutoffs=[10, 50, 80],
d_model=32,
n_head=4,
d_inner=128,
n_layer=5,
max_position_embeddings=10,
untie_r=True,
bi_data=False,
same_length=False,
seed=1,
type_vocab_size=2):
self.parent = parent
self.batch_size = batch_size
self.seq_length = seq_length
self.mem_len = mem_len
self.clamp_len = clamp_len
self.reuse_len = reuse_len
self.is_training = is_training
self.use_labels = use_labels
self.vocab_size = vocab_size
self.cutoffs = cutoffs
self.d_model = d_model
self.n_head = n_head
self.d_inner = d_inner
self.n_layer = n_layer
self.max_position_embeddings = max_position_embeddings
self.bi_data = bi_data
self.untie_r = untie_r
self.same_length = same_length
self.seed = seed
self.type_vocab_size = type_vocab_size
def prepare_config_and_inputs(self):
input_ids_1 = XLNetModelTest.ids_tensor([self.seq_length, self.batch_size], self.vocab_size)
input_ids_2 = XLNetModelTest.ids_tensor([self.seq_length, self.batch_size], self.vocab_size)
segment_ids = XLNetModelTest.ids_tensor([self.seq_length, self.batch_size], self.type_vocab_size)
lm_labels = None
if self.use_labels:
lm_labels = XLNetModelTest.ids_tensor([self.seq_length, self.batch_size], self.vocab_size)
config = XLNetConfig(
vocab_size_or_config_json_file=self.vocab_size,
d_model=self.d_model,
n_head=self.n_head,
d_inner=self.d_inner,
n_layer=self.n_layer,
untie_r=self.untie_r,
max_position_embeddings=self.max_position_embeddings)
run_config = XLNetRunConfig(
mem_len=self.mem_len,
clamp_len=self.clamp_len,
same_length=self.same_length,
reuse_len=self.reuse_len,
bi_data=self.bi_data)
config.update(run_config)
return (config, input_ids_1, input_ids_2, segment_ids, lm_labels)
def set_seed(self):
random.seed(self.seed)
torch.manual_seed(self.seed)
def create_transfo_xl_model(self, config, input_ids_1, input_ids_2, segment_ids, lm_labels):
model = XLNetLMHeadModel(config)
model.eval()
hidden_states_1, mems_1 = model(input_ids_1, seg_id=segment_ids)
hidden_states_2, mems_2 = model(input_ids_2, seg_id=segment_ids, mems=mems_1)
outputs = {
"hidden_states_1": hidden_states_1,
"mems_1": mems_1,
"hidden_states_2": hidden_states_2,
"mems_2": mems_2,
}
return outputs
def check_transfo_xl_model_output(self, result):
# XLNetLMHeadModel returns vocab-sized logits when no target is given
self.parent.assertListEqual(
list(result["hidden_states_1"].size()),
[self.seq_length, self.batch_size, self.vocab_size])
self.parent.assertListEqual(
list(result["hidden_states_2"].size()),
[self.seq_length, self.batch_size, self.vocab_size])
self.parent.assertListEqual(
list(list(mem.size()) for mem in result["mems_1"]),
[[self.mem_len, self.batch_size, self.d_model]] * self.n_layer)
self.parent.assertListEqual(
list(list(mem.size()) for mem in result["mems_2"]),
[[self.mem_len, self.batch_size, self.d_model]] * self.n_layer)
def create_transfo_xl_lm_head(self, config, input_ids_1, input_ids_2, segment_ids, lm_labels):
model = XLNetLMHeadModel(config)
model.eval()
loss_1, mems_1a = model(input_ids_1, target=lm_labels)
lm_logits_1, mems_1b = model(input_ids_1)
loss_2, mems_2a = model(input_ids_2, target=lm_labels, mems=mems_1a)
lm_logits_2, mems_2b = model(input_ids_2, mems=mems_1b)
outputs = {
"loss_1": loss_1,
"mems_1a": mems_1a,
"lm_logits_1": lm_logits_1,
"mems_1b": mems_1b,
"loss_2": loss_2,
"mems_2a": mems_2a,
"lm_logits_2": lm_logits_2,
"mems_2b": mems_2b,
}
return outputs
def check_transfo_xl_lm_head_output(self, result):
# with the default CrossEntropyLoss reduction, the loss is a scalar
self.parent.assertListEqual(
list(result["loss_1"].size()),
[])
self.parent.assertListEqual(
list(result["lm_logits_1"].size()),
[self.seq_length, self.batch_size, self.vocab_size])
self.parent.assertListEqual(
list(list(mem.size()) for mem in result["mems_1a"]),
[[self.mem_len, self.batch_size, self.d_model]] * self.n_layer)
self.parent.assertListEqual(
list(list(mem.size()) for mem in result["mems_1b"]),
[[self.mem_len, self.batch_size, self.d_model]] * self.n_layer)
self.parent.assertListEqual(
list(mem[~torch.isnan(mem)].sum() for mem in result["mems_1a"]),
list(mem[~torch.isnan(mem)].sum() for mem in result["mems_1b"]))
self.parent.assertListEqual(
list(result["loss_2"].size()),
[])
self.parent.assertListEqual(
list(result["lm_logits_2"].size()),
[self.seq_length, self.batch_size, self.vocab_size])
self.parent.assertListEqual(
list(list(mem.size()) for mem in result["mems_2a"]),
[[self.mem_len, self.batch_size, self.d_model]] * self.n_layer)
self.parent.assertListEqual(
list(list(mem.size()) for mem in result["mems_2b"]),
[[self.mem_len, self.batch_size, self.d_model]] * self.n_layer)
self.parent.assertListEqual(
list(mem[~torch.isnan(mem)].sum() for mem in result["mems_2a"]),
list(mem[~torch.isnan(mem)].sum() for mem in result["mems_2b"]))
def test_default(self):
self.run_tester(XLNetModelTest.XLNetModelTester(self))
def test_config_to_json_string(self):
config = XLNetConfig(vocab_size_or_config_json_file=96, d_model=37)
obj = json.loads(config.to_json_string())
self.assertEqual(obj["n_token"], 96)
self.assertEqual(obj["d_model"], 37)
def test_config_to_json_file(self):
config_first = XLNetConfig(vocab_size_or_config_json_file=96, d_model=37)
json_file_path = "/tmp/config.json"
config_first.to_json_file(json_file_path)
config_second = XLNetConfig.from_json_file(json_file_path)
os.remove(json_file_path)
self.assertEqual(config_second.to_dict(), config_first.to_dict())
@pytest.mark.slow
def test_model_from_pretrained(self):
cache_dir = "/tmp/pytorch_pretrained_bert_test/"
for model_name in list(PRETRAINED_MODEL_ARCHIVE_MAP.keys())[:1]:
model = XLNetModel.from_pretrained(model_name, cache_dir=cache_dir)
shutil.rmtree(cache_dir)
self.assertIsNotNone(model)
def run_tester(self, tester):
config_and_inputs = tester.prepare_config_and_inputs()
tester.set_seed()
output_result = tester.create_transfo_xl_model(*config_and_inputs)
tester.check_transfo_xl_model_output(output_result)
tester.set_seed()
output_result = tester.create_transfo_xl_lm_head(*config_and_inputs)
tester.check_transfo_xl_lm_head_output(output_result)
@classmethod
def ids_tensor(cls, shape, vocab_size, rng=None, name=None):
"""Creates a random int32 tensor of the shape within the vocab size."""
if rng is None:
rng = random.Random()
total_dims = 1
for dim in shape:
total_dims *= dim
values = []
for _ in range(total_dims):
values.append(rng.randint(0, vocab_size - 1))
return torch.tensor(data=values, dtype=torch.long).view(shape).contiguous()
if __name__ == "__main__":
unittest.main()