Commit 15b70338 authored by thomwolf

adding squad model to xlnet and xlm

parent fbe04423
@@ -25,7 +25,7 @@ from io import open
import torch
from torch import nn
-from torch.nn import CrossEntropyLoss, MSELoss
+from torch.nn import CrossEntropyLoss, MSELoss, functional as F
from .file_utils import cached_path

@@ -301,22 +301,189 @@ class Conv1D(nn.Module):
        return x

class PoolerStartLogits(nn.Module):
""" Compute SQuAD start_logits from sequence hidden states. """
def __init__(self, config):
super(PoolerStartLogits, self).__init__()
self.dense = nn.Linear(config.hidden_size, 1)
def forward(self, hidden_states, p_mask=None):
""" Args:
`p_mask`: [optional] invalid position mask such as query and special symbols (PAD, SEP, CLS)
shape [batch_size, seq_len]. 1.0 means token should be masked.
"""
x = self.dense(hidden_states).squeeze(-1)
if p_mask is not None:
x = x * (1 - p_mask) - 1e30 * p_mask
return x
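Editor's note: the masking line above is the usual additive-mask trick: multiplying by (1 - p_mask) zeroes the logit of every excluded position and subtracting 1e30 * p_mask pushes it towards minus infinity, so those positions receive essentially zero probability after a softmax. A minimal stand-alone sketch of the same trick (toy tensors, not part of the commit):

```python
import torch
import torch.nn.functional as F

logits = torch.tensor([[2.0, 0.5, 1.0, 3.0]])   # raw start logits for one sequence
p_mask = torch.tensor([[1.0, 0.0, 0.0, 1.0]])   # 1.0 marks query/special tokens to exclude

masked = logits * (1 - p_mask) - 1e30 * p_mask  # excluded positions become ~ -1e30
probs = F.softmax(masked, dim=-1)               # and get essentially zero probability mass
```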
class PoolerEndLogits(nn.Module):
""" Compute SQuAD end_logits from sequence hidden states and start token hidden state.
"""
def __init__(self, config):
super(PoolerEndLogits, self).__init__()
self.dense_0 = nn.Linear(config.hidden_size * 2, config.hidden_size)
self.activation = nn.Tanh()
self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
self.dense_1 = nn.Linear(config.hidden_size, 1)
def forward(self, hidden_states, start_states=None, start_positions=None, p_mask=None):
""" Args:
One of start_states, start_positions should be not None. If both are set, start_positions overrides start_states.
`start_states`: hidden states of the first tokens for the labeled span: torch.LongTensor of shape identical to hidden_states.
`start_positions`: position of the first token for the labeled span: torch.LongTensor of shape [batch_size].
`p_mask`: [optional] invalid position mask such as query and special symbols (PAD, SEP, CLS)
shape [batch_size, seq_len]. 1.0 means token should be masked.
"""
slen, hsz = hidden_states.shape[-2:]
assert start_states is not None or start_positions is not None, "One of start_states, start_positions should be not None"
if start_positions is not None:
start_positions = start_positions[:, None, None].expand(-1, -1, hsz) # shape (bsz, 1, hsz)
start_states = hidden_states.gather(-2, start_positions) # shape (bsz, 1, hsz)
start_states = start_states.expand(-1, slen, -1) # shape (bsz, slen, hsz)
x = self.dense_0(torch.cat([hidden_states, start_states], dim=-1))
x = self.activation(x)
x = self.LayerNorm(x)
x = self.dense_1(x).squeeze(-1)
if p_mask is not None:
x = x * (1 - p_mask) - 1e30 * p_mask
return x
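Editor's note: the gather call above selects, for every example in the batch, the hidden state sitting at the labeled start position. A small sketch with made-up shapes (not part of the commit) showing what the index expansion and gather do:

```python
import torch

bsz, slen, hsz = 2, 5, 4
hidden_states = torch.randn(bsz, slen, hsz)
start_positions = torch.tensor([3, 1])                       # gold start token per example

index = start_positions[:, None, None].expand(-1, -1, hsz)   # shape (bsz, 1, hsz)
start_states = hidden_states.gather(-2, index)                # shape (bsz, 1, hsz)

assert torch.equal(start_states[0, 0], hidden_states[0, 3])
assert torch.equal(start_states[1, 0], hidden_states[1, 1])
```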
class PoolerAnswerClass(nn.Module):
""" Compute SQuAD 2.0 answer class from classification and start tokens hidden states. """
def __init__(self, config):
super(PoolerAnswerClass, self).__init__()
self.dense_0 = nn.Linear(config.hidden_size * 2, config.hidden_size)
self.activation = nn.Tanh()
self.dense_1 = nn.Linear(config.hidden_size, 1, bias=False)
def forward(self, hidden_states, start_states=None, start_positions=None, cls_index=None):
""" Args:
One of start_states, start_positions should be not None. If both are set, start_positions overrides start_states.
`start_states`: hidden states of the first tokens for the labeled span: torch.LongTensor of shape identical to hidden_states.
`start_positions`: position of the first token for the labeled span: torch.LongTensor of shape [batch_size].
`cls_index`: position of the CLS token: torch.LongTensor of shape [batch_size]. If None, take the last token.
# note(zhiliny): no dependency on end_feature so that we can obtain one single `cls_logits` for each sample
"""
slen, hsz = hidden_states.shape[-2:]
assert start_states is not None or start_positions is not None, "One of start_states, start_positions should be not None"
if start_positions is not None:
start_positions = start_positions[:, None, None].expand(-1, -1, hsz) # shape (bsz, 1, hsz)
start_states = hidden_states.gather(-2, start_positions).squeeze(-2) # shape (bsz, hsz)
if cls_index is not None:
cls_index = cls_index[:, None, None].expand(-1, -1, hsz) # shape (bsz, 1, hsz)
cls_token_state = hidden_states.gather(-2, cls_index).squeeze(-2) # shape (bsz, hsz)
else:
cls_token_state = hidden_states[:, -1, :] # shape (bsz, hsz)
x = self.dense_0(torch.cat([start_states, cls_token_state], dim=-1))
x = self.activation(x)
x = self.dense_1(x).squeeze(-1)
return x
class SQuADHead(nn.Module):
""" A SQuAD head inspired by XLNet.
Compute
"""
    def __init__(self, config):
        super(SQuADHead, self).__init__()
        self.start_n_top = config.start_n_top
        self.end_n_top = config.end_n_top

        self.start_logits = PoolerStartLogits(config)
        self.end_logits = PoolerEndLogits(config)
        self.answer_class = PoolerAnswerClass(config)

    def forward(self, hidden_states, start_positions=None, end_positions=None,
                cls_index=None, is_impossible=None, p_mask=None):
        """ hidden_states: float Tensor in shape [bsz, seq_len, hidden_size], the hidden-states of the last layer.
        """
outputs = ()
start_logits = self.start_logits(hidden_states, p_mask)
if start_positions is not None and end_positions is not None:
# If we are on multi-GPU, let's remove the dimension added by batch splitting
for x in (start_positions, end_positions, cls_index, is_impossible):
if x is not None and x.dim() > 1:
x.squeeze_(-1)
# during training, compute the end logits based on the ground truth of the start position
end_logits = self.end_logits(hidden_states, start_positions=start_positions, p_mask=p_mask)
loss_fct = CrossEntropyLoss()
start_loss = loss_fct(start_logits, start_positions)
end_loss = loss_fct(end_logits, end_positions)
total_loss = (start_loss + end_loss) / 2
if cls_index is not None and is_impossible is not None:
# Predict answerability from the representation of CLS and START
cls_logits = self.answer_class(hidden_states, start_positions=start_positions, cls_index=cls_index)
loss_fct_cls = nn.BCEWithLogitsLoss()
cls_loss = loss_fct_cls(cls_logits, is_impossible)
# note(zhiliny): by default multiply the loss by 0.5 so that the scale is comparable to start_loss and end_loss
total_loss += cls_loss * 0.5
outputs = (total_loss, start_logits, end_logits, cls_logits) + outputs
else:
outputs = (total_loss, start_logits, end_logits) + outputs
else:
# during inference, compute the end logits based on beam search
bsz, slen, hsz = hidden_states.size()
start_log_probs = F.softmax(start_logits, dim=-1) # shape (bsz, slen)
start_top_log_probs, start_top_index = torch.topk(start_log_probs, self.start_n_top, dim=-1) # shape (bsz, start_n_top)
start_top_index = start_top_index.unsqueeze(-1).expand(-1, -1, hsz) # shape (bsz, start_n_top, hsz)
start_states = torch.gather(hidden_states, -2, start_top_index) # shape (bsz, start_n_top, hsz)
start_states = start_states.unsqueeze(1).expand(-1, slen, -1, -1) # shape (bsz, slen, start_n_top, hsz)
hidden_states_expanded = hidden_states.unsqueeze(2).expand_as(start_states) # shape (bsz, slen, start_n_top, hsz)
p_mask = p_mask.unsqueeze(-1) if p_mask is not None else None
end_logits = self.end_logits(hidden_states_expanded, start_states=start_states, p_mask=p_mask)
end_log_probs = F.softmax(end_logits, dim=1) # shape (bsz, slen, start_n_top)
end_top_log_probs, end_top_index = torch.topk(end_log_probs, self.end_n_top, dim=1) # shape (bsz, end_n_top, start_n_top)
end_top_log_probs = end_top_log_probs.view(-1, self.start_n_top * self.end_n_top)
end_top_index = end_top_index.view(-1, self.start_n_top * self.end_n_top)
start_states = torch.einsum("blh,bl->bh", hidden_states, start_log_probs)
cls_logits = self.answer_class(hidden_states, start_states=start_states, cls_index=cls_index)
outputs = (start_top_log_probs, start_top_index, end_top_log_probs, end_top_index, cls_logits) + outputs
# return start_top_log_probs, start_top_index, end_top_log_probs, end_top_index, cls_logits
# or (if labels are provided) total_loss, start_logits, end_logits, (cls_logits)
return outputs
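Editor's note: a minimal usage sketch of the new head, assuming a toy config object that carries the four attributes SQuADHead actually reads (hidden_size, layer_norm_eps, start_n_top, end_n_top); this is an illustration, not code from the commit. With labels the loss comes first in the returned tuple; without labels the head switches to the beam-search outputs described in the comments above.

```python
import torch
from types import SimpleNamespace
# SQuADHead as defined above; SimpleNamespace stands in for a real config object

config = SimpleNamespace(hidden_size=16, layer_norm_eps=1e-12, start_n_top=5, end_n_top=5)
head = SQuADHead(config)

bsz, slen = 2, 12
hidden_states = torch.randn(bsz, slen, config.hidden_size)
start_positions = torch.tensor([3, 7])
end_positions = torch.tensor([5, 9])

# training mode: (total_loss, start_logits, end_logits)
total_loss, start_logits, end_logits = head(hidden_states,
                                            start_positions=start_positions,
                                            end_positions=end_positions)

# inference mode: beam-search tensors for SQuAD decoding
(start_top_log_probs, start_top_index,
 end_top_log_probs, end_top_index, cls_logits) = head(hidden_states)
# start_top_index: (bsz, start_n_top); end_top_index: (bsz, start_n_top * end_n_top)
```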
class SequenceSummary(nn.Module):
""" Compute a single vector summary of a sequence hidden states according to various possibilities:
Args of the config class:
summary_type:
- 'last' => [default] take the last token hidden state (like XLNet)
- 'first' => take the first token hidden state (like Bert)
- 'mean' => take the mean of all tokens hidden states
- 'token_ids' => supply a Tensor of classification token indices (GPT/GPT-2)
- 'attn' => Not implemented now, use multi-head attention
summary_use_proj: Add a projection after the vector extraction
summary_num_classes: If > 0: the projection outputs to n classes (otherwise to hidden_size)
summary_activation:
'tanh' => add a tanh activation to the output
None => no activation
"""
    def __init__(self, config):
        super(SequenceSummary, self).__init__()
        self.summary_type = config.summary_type if hasattr(config, 'summary_use_proj') else 'last'
...
@@ -35,7 +35,8 @@ from torch.nn import functional as F
from torch.nn import CrossEntropyLoss, MSELoss
from .file_utils import cached_path
-from .model_utils import CONFIG_NAME, WEIGHTS_NAME, PretrainedConfig, PreTrainedModel, prune_linear_layer
+from .model_utils import (CONFIG_NAME, WEIGHTS_NAME, PretrainedConfig, PreTrainedModel,
+                          prune_linear_layer, SequenceSummary, SQuADHead)

logger = logging.getLogger(__name__)
@@ -67,15 +68,23 @@ class XLMConfig(PretrainedConfig):
                 n_langs=1,
                 max_position_embeddings=512,
                 embed_init_std=2048 ** -0.5,
+                 layer_norm_eps=1e-12,
                 init_std=0.02,
-                 summary_type="last",
-                 use_proj=True,
                 bos_index=0,
                 eos_index=1,
                 pad_index=2,
                 unk_index=3,
                 mask_index=5,
                 is_encoder=True,
+                 finetuning_task=None,
+                 num_labels=2,
+                 summary_type='last',
+                 summary_use_proj=True,
+                 summary_activation='tanh',
+                 summary_dropout=0.1,
+                 start_n_top=5,
+                 end_n_top=5,
                 **kwargs):
        """Constructs XLMConfig.
@@ -140,8 +149,7 @@ class XLMConfig(PretrainedConfig):
            self.causal = causal
            self.asm = asm
            self.n_langs = n_langs
-            self.summary_type = summary_type
-            self.use_proj = use_proj
+            self.layer_norm_eps = layer_norm_eps
            self.bos_index = bos_index
            self.eos_index = eos_index
            self.pad_index = pad_index
@@ -151,6 +159,14 @@ class XLMConfig(PretrainedConfig):
            self.max_position_embeddings = max_position_embeddings
            self.embed_init_std = embed_init_std
            self.init_std = init_std
+            self.finetuning_task = finetuning_task
+            self.num_labels = num_labels
+            self.summary_type = summary_type
+            self.summary_use_proj = summary_use_proj
+            self.summary_activation = summary_activation
+            self.summary_dropout = summary_dropout
+            self.start_n_top = start_n_top
+            self.end_n_top = end_n_top
        else:
            raise ValueError("First argument must be either a vocabulary size (int)"
                             "or the path to a pretrained model config file (str)")
@@ -172,26 +188,6 @@ class XLMConfig(PretrainedConfig):
        return self.n_layers
-def Embedding(num_embeddings, embedding_dim, padding_idx=None, config=None):
-    m = nn.Embedding(num_embeddings, embedding_dim, padding_idx=padding_idx)
-    if config is not None and config.embed_init_std is not None:
-        nn.init.normal_(m.weight, mean=0, std=config.embed_init_std)
-    if padding_idx is not None:
-        nn.init.constant_(m.weight[padding_idx], 0)
-    return m

-def Linear(in_features, out_features, bias=True, config=None):
-    m = nn.Linear(in_features, out_features, bias)
-    if config is not None and config.init_std is not None:
-        nn.init.normal_(m.weight, mean=0, std=config.init_std)
-    if bias:
-        nn.init.constant_(m.bias, 0.)
-    # nn.init.xavier_uniform_(m.weight)
-    # nn.init.constant_(m.bias, 0.)
-    return m

def create_sinusoidal_embeddings(n_pos, dim, out):
    position_enc = np.array([
        [pos / np.power(10000, 2 * (j // 2) / dim) for j in range(dim)]
@@ -244,7 +240,7 @@ class MultiHeadAttention(nn.Module):
    NEW_ID = itertools.count()

    def __init__(self, n_heads, dim, config):
-        super().__init__()
+        super(MultiHeadAttention, self).__init__()
        self.layer_id = next(MultiHeadAttention.NEW_ID)
        self.output_attentions = config.output_attentions
        self.dim = dim
@@ -252,10 +248,10 @@ class MultiHeadAttention(nn.Module):
        self.dropout = config.attention_dropout
        assert self.dim % self.n_heads == 0
-        self.q_lin = Linear(dim, dim, config=config)
-        self.k_lin = Linear(dim, dim, config=config)
-        self.v_lin = Linear(dim, dim, config=config)
-        self.out_lin = Linear(dim, dim, config=config)
+        self.q_lin = nn.Linear(dim, dim)
+        self.k_lin = nn.Linear(dim, dim)
+        self.v_lin = nn.Linear(dim, dim)
+        self.out_lin = nn.Linear(dim, dim)

    def prune_heads(self, heads):
        attention_head_size = self.dim // self.n_heads
@@ -342,10 +338,10 @@ class MultiHeadAttention(nn.Module):
class TransformerFFN(nn.Module):

    def __init__(self, in_dim, dim_hidden, out_dim, config):
-        super().__init__()
+        super(TransformerFFN, self).__init__()
        self.dropout = config.dropout
-        self.lin1 = Linear(in_dim, dim_hidden, config=config)
-        self.lin2 = Linear(dim_hidden, out_dim, config=config)
+        self.lin1 = nn.Linear(in_dim, dim_hidden)
+        self.lin2 = nn.Linear(dim_hidden, out_dim)
        self.act = gelu if config.gelu_activation else F.relu

    def forward(self, input):
@@ -363,17 +359,21 @@ class XLMPreTrainedModel(PreTrainedModel):
    config_class = XLMConfig
    pretrained_model_archive_map = PRETRAINED_MODEL_ARCHIVE_MAP
    load_tf_weights = None
-    base_model_prefix = "xlm"
+    base_model_prefix = "transformer"

    def __init__(self, *inputs, **kwargs):
        super(XLMPreTrainedModel, self).__init__(*inputs, **kwargs)

    def init_weights(self, module):
-        """ Initialize the weights.
-        """
-        if isinstance(module, (nn.Linear, nn.Embedding)):
-            # Weights are initialized in module instantiation (see above)
-            pass
+        """ Initialize the weights. """
+        if isinstance(module, nn.Embedding):
+            if self.config is not None and self.config.embed_init_std is not None:
+                nn.init.normal_(module.weight, mean=0, std=self.config.embed_init_std)
+        if isinstance(module, nn.Linear):
+            if self.config is not None and self.config.init_std is not None:
+                nn.init.normal_(module.weight, mean=0, std=self.config.init_std)
+                if hasattr(module, 'bias') and module.bias is not None:
+                    nn.init.constant_(module.bias, 0.)
        if isinstance(module, nn.LayerNorm):
            module.bias.data.zero_()
            module.weight.data.fill_(1.0)
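Editor's note: this hunk removes the custom Embedding/Linear constructors in favour of plain nn.Embedding/nn.Linear plus an init_weights hook, which the models trigger with self.apply(self.init_weights) (see the XLMModel hunk below). A tiny generic sketch of that pattern, with illustrative values rather than the commit's classes:

```python
import torch.nn as nn

def init_weights(module):
    # nn.Module.apply calls this once for every sub-module, depth first
    if isinstance(module, nn.Embedding):
        nn.init.normal_(module.weight, mean=0.0, std=0.02)
    elif isinstance(module, nn.Linear):
        nn.init.normal_(module.weight, mean=0.0, std=0.02)
        if module.bias is not None:
            nn.init.constant_(module.bias, 0.0)

model = nn.Sequential(nn.Embedding(10, 8), nn.Linear(8, 2))
model.apply(init_weights)   # initializes both sub-modules in place
```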
@@ -471,13 +471,13 @@ class XLMModel(XLMPreTrainedModel):
        assert self.dim % self.n_heads == 0, 'transformer dim must be a multiple of n_heads'

        # embeddings
-        self.position_embeddings = Embedding(config.max_position_embeddings, self.dim, config=config)
+        self.position_embeddings = nn.Embedding(config.max_position_embeddings, self.dim)
        if config.sinusoidal_embeddings:
            create_sinusoidal_embeddings(config.max_position_embeddings, self.dim, out=self.position_embeddings.weight)
        if config.n_langs > 1:
-            self.lang_embeddings = Embedding(self.n_langs, self.dim, config=config)
-        self.embeddings = Embedding(self.n_words, self.dim, padding_idx=self.pad_index, config=config)
-        self.layer_norm_emb = nn.LayerNorm(self.dim, eps=1e-12)
+            self.lang_embeddings = nn.Embedding(self.n_langs, self.dim)
+        self.embeddings = nn.Embedding(self.n_words, self.dim, padding_idx=self.pad_index)
+        self.layer_norm_emb = nn.LayerNorm(self.dim, eps=config.layer_norm_eps)

        # transformer layers
        self.attentions = nn.ModuleList()
@@ -490,12 +490,14 @@ class XLMModel(XLMPreTrainedModel):
        for _ in range(self.n_layers):
            self.attentions.append(MultiHeadAttention(self.n_heads, self.dim, config=config))
-            self.layer_norm1.append(nn.LayerNorm(self.dim, eps=1e-12))
+            self.layer_norm1.append(nn.LayerNorm(self.dim, eps=config.layer_norm_eps))
            # if self.is_decoder:
-            #     self.layer_norm15.append(nn.LayerNorm(self.dim, eps=1e-12))
+            #     self.layer_norm15.append(nn.LayerNorm(self.dim, eps=config.layer_norm_eps))
            #     self.encoder_attn.append(MultiHeadAttention(self.n_heads, self.dim, dropout=self.attention_dropout))
            self.ffns.append(TransformerFFN(self.dim, self.hidden_dim, self.dim, config=config))
-            self.layer_norm2.append(nn.LayerNorm(self.dim, eps=1e-12))
+            self.layer_norm2.append(nn.LayerNorm(self.dim, eps=config.layer_norm_eps))
+
+        self.apply(self.init_weights)

    def _prune_heads(self, heads_to_prune):
        """ Prunes heads of the model.
@@ -636,14 +638,14 @@ class XLMPredLayer(nn.Module):
        Prediction layer (cross_entropy or adaptive_softmax).
    """
    def __init__(self, config):
-        super().__init__()
+        super(XLMPredLayer, self).__init__()
        self.asm = config.asm
        self.n_words = config.n_words
        self.pad_index = config.pad_index
        dim = config.emb_dim

        if config.asm is False:
-            self.proj = Linear(dim, config.n_words, bias=True)
+            self.proj = nn.Linear(dim, config.n_words, bias=True)
        else:
            self.proj = nn.AdaptiveLogSoftmaxWithLoss(
                in_features=dim,
@@ -653,28 +655,24 @@ class XLMPredLayer(nn.Module):
                head_bias=True,  # default is False
            )

-    def forward(self, x, y, get_scores=False):
-        """
-        Compute the loss, and optionally the scores.
-        """
-        assert (y == self.pad_index).sum().item() == 0
+    def forward(self, x, y=None):
+        """ Compute the loss, and optionally the scores.
+        """
+        outputs = ()
        if self.asm is False:
            scores = self.proj(x).view(-1, self.n_words)
-            loss = F.cross_entropy(scores, y, reduction='elementwise_mean')
+            outputs = (scores,) + outputs
+            if y is not None:
+                loss = F.cross_entropy(scores, y, reduction='elementwise_mean')
+                outputs = (loss,) + outputs
        else:
-            _, loss = self.proj(x, y)
-            scores = self.proj.log_prob(x) if get_scores else None
-
-        return scores, loss
-
-    def get_scores(self, x):
-        """
-        Compute scores.
-        """
-        assert x.dim() == 2
-        return self.proj.log_prob(x) if self.asm else self.proj(x)
+            scores = self.proj.log_prob(x)
+            outputs = (scores,) + outputs
+            if y is not None:
+                _, loss = self.proj(x, y)
+                outputs = (loss,) + outputs

+        return outputs
class XLMWithLMHeadModel(XLMPreTrainedModel):
@@ -731,6 +729,7 @@ class XLMWithLMHeadModel(XLMPreTrainedModel):
    """
    def __init__(self, config):
        super(XLMWithLMHeadModel, self).__init__(config)
+        self.torchscript = config.torchscript
        self.transformer = XLMModel(config)
        self.pred_layer = XLMPredLayer(config)
@@ -741,7 +740,10 @@ class XLMWithLMHeadModel(XLMPreTrainedModel):
    def tie_weights(self):
        """ Make sure we are sharing the embeddings
        """
-        self.pred_layer.proj.weight = self.transformer.embeddings.weight
+        if self.torchscript:
+            self.pred_layer.proj.weight = nn.Parameter(self.transformer.embeddings.weight.clone())
+        else:
+            self.pred_layer.proj.weight = self.transformer.embeddings.weight
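Editor's note: the torchscript branch exists because tying shares a single Parameter object between the embedding matrix and the output projection, and TorchScript export historically rejected such shared parameters, so a cloned copy is used instead. A generic sketch of the two options on toy modules (not the commit's classes):

```python
import torch.nn as nn

embedding = nn.Embedding(100, 16)
decoder = nn.Linear(16, 100, bias=False)

# tied weights: both modules now hold the very same Parameter object
decoder.weight = embedding.weight

# TorchScript-friendly alternative: an independent copy with identical initial values
decoder.weight = nn.Parameter(embedding.weight.clone())
```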
    def forward(self, input_ids, lengths=None, positions=None, langs=None, token_type_ids=None,
                attention_mask=None, cache=None, labels=None, head_mask=None):
@@ -775,55 +777,12 @@ class XLMWithLMHeadModel(XLMPreTrainedModel):
                                               langs=langs, attention_mask=attention_mask, cache=cache, head_mask=head_mask)

        output = transformer_outputs[0]
-        logits = self.pred_layer(output, labels)
-
-        outputs = transformer_outputs[1:]  # Keep new_mems and attention/hidden states if they are here
-        if labels is not None:
-            # Flatten the tokens
-            loss_fct = CrossEntropyLoss(ignore_index=-1)
-            loss = loss_fct(logits.view(-1, logits.size(-1)),
-                            labels.view(-1))
-            outputs = [loss] + outputs
-        outputs = [logits] + outputs
+        outputs = self.pred_layer(output, labels)
+        outputs = outputs + transformer_outputs[1:]  # Keep new_mems and attention/hidden states if they are here

        return outputs
-class XLMSequenceSummary(nn.Module):
-    def __init__(self, config):
-        super(XLMSequenceSummary, self).__init__()
-        self.summary_type = config.summary_type
-        if config.use_proj:
-            self.summary = nn.Linear(config.d_model, config.d_model)
-        else:
-            self.summary = None
-        if config.summary_type == 'attn':
-            # We should use a standard multi-head attention module with absolute positional embedding for that.
-            # Cf. https://github.com/zihangdai/xlnet/blob/master/modeling.py#L253-L276
-            # We can probably just use the multi-head attention module of PyTorch >=1.1.0
-            raise NotImplementedError
-        self.dropout = nn.Dropout(config.dropout)
-        self.activation = nn.Tanh()
-
-    def forward(self, hidden_states):
-        """ hidden_states: float Tensor in shape [bsz, seq_len, d_model], the hidden-states of the last layer."""
-        if self.summary_type == 'last':
-            output = hidden_states[:, -1]
-        elif self.summary_type == 'first':
-            output = hidden_states[:, 0]
-        elif self.summary_type == 'mean':
-            output = hidden_states.mean(dim=1)
-        elif summary_type == 'attn':
-            raise NotImplementedError
-        output = self.summary(output)
-        output = self.activation(output)
-        output = self.dropout(output)
-        return output
class XLMForSequenceClassification(XLMPreTrainedModel):
    """XLM model ("XLM: Generalized Autoregressive Pretraining for Language Understanding").
@@ -890,15 +849,15 @@ class XLMForSequenceClassification(XLMPreTrainedModel):
    """
    def __init__(self, config):
        super(XLMForSequenceClassification, self).__init__(config)
+        self.num_labels = config.num_labels

        self.transformer = XLMModel(config)
-        self.sequence_summary = XLMSequenceSummary(config)
-        self.logits_proj = nn.Linear(config.d_model, config.num_labels)
+        self.sequence_summary = SequenceSummary(config)

        self.apply(self.init_weights)

-    def forward(self, input_ids, lengths=None, positions=None, langs=None, attention_mask=None,
-                cache=None, labels=None, head_mask=None):
+    def forward(self, input_ids, lengths=None, positions=None, langs=None, token_type_ids=None,
+                attention_mask=None, cache=None, labels=None, head_mask=None):
        """
        Args:
            inp_k: int32 Tensor in shape [bsz, len], the input token IDs.
@@ -930,10 +889,9 @@ class XLMForSequenceClassification(XLMPreTrainedModel):
                                               langs=langs, attention_mask=attention_mask, cache=cache, head_mask=head_mask)

        output = transformer_outputs[0]
-        output = self.sequence_summary(output)
-        logits = self.logits_proj(output)
+        logits = self.sequence_summary(output)

-        outputs = transformer_outputs[1:]  # Keep new_mems and attention/hidden states if they are here
+        outputs = (logits,) + transformer_outputs[1:]  # Keep new_mems and attention/hidden states if they are here
        if labels is not None:
            if self.num_labels == 1:
@@ -943,9 +901,7 @@ class XLMForSequenceClassification(XLMPreTrainedModel):
            else:
                loss_fct = CrossEntropyLoss()
                loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
-            outputs = [loss] + outputs
-            outputs = [logits] + outputs
+            outputs = (loss,) + outputs

        return outputs
@@ -1010,41 +966,22 @@ class XLMForQuestionAnswering(XLMPreTrainedModel):
        super(XLMForQuestionAnswering, self).__init__(config)
        self.transformer = XLMModel(config)
-        self.qa_outputs = nn.Linear(config.hidden_size, config.num_labels)
+        self.qa_outputs = SQuADHead(config)

        self.apply(self.init_weights)

-    def forward(self, input_ids, lengths=None, positions=None, langs=None, attention_mask=None, cache=None,
-                labels=None, head_mask=None):
+    def forward(self, input_ids, lengths=None, positions=None, langs=None, token_type_ids=None,
+                attention_mask=None, cache=None, start_positions=None, end_positions=None,
+                cls_index=None, is_impossible=None, p_mask=None, head_mask=None):
        transformer_outputs = self.transformer(input_ids, lengths=lengths, positions=positions, token_type_ids=token_type_ids,
                                               langs=langs, attention_mask=attention_mask, cache=cache, head_mask=head_mask)

        output = transformer_outputs[0]
-        logits = self.qa_outputs(output)
-        start_logits, end_logits = logits.split(1, dim=-1)
-        start_logits = start_logits.squeeze(-1)
-        end_logits = end_logits.squeeze(-1)
-
-        outputs = transformer_outputs[1:]  # Keep new_mems and attention/hidden states if they are here
-        if start_positions is not None and end_positions is not None:
-            # If we are on multi-GPU, split add a dimension
-            if len(start_positions.size()) > 1:
-                start_positions = start_positions.squeeze(-1)
-            if len(end_positions.size()) > 1:
-                end_positions = end_positions.squeeze(-1)
-            # sometimes the start/end positions are outside our model inputs, we ignore these terms
-            ignored_index = start_logits.size(1)
-            start_positions.clamp_(0, ignored_index)
-            end_positions.clamp_(0, ignored_index)
-
-            loss_fct = CrossEntropyLoss(ignore_index=ignored_index)
-            start_loss = loss_fct(start_logits, start_positions)
-            end_loss = loss_fct(end_logits, end_positions)
-            total_loss = (start_loss + end_loss) / 2
-            outputs = [total_loss] + outputs
-
-        outputs = [start_logits, end_logits] + outputs
+        outputs = self.qa_outputs(output, start_positions=start_positions, end_positions=end_positions,
+                                  cls_index=cls_index, is_impossible=is_impossible, p_mask=p_mask)
+        outputs = outputs + transformer_outputs[1:]  # Keep new_mems and attention/hidden states if they are here

        return outputs
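Editor's note: the new inputs follow the XLNet SQuAD recipe: p_mask is 1.0 on every token that can never be part of an answer (question tokens, padding, special symbols) and 0.0 on context tokens, cls_index points at the classification token (the poolers fall back to the last position when it is None), and is_impossible is the SQuAD 2.0 answerability label. A hand-written illustration of plausible tensors for one example; the token layout is invented, not produced by the library's tokenizer:

```python
import torch

# toy layout: [CLS] q1 q2 [SEP] c1 c2 c3 c4 [SEP] [PAD]
p_mask = torch.tensor([[1., 1., 1., 1., 0., 0., 0., 0., 1., 1.]])  # only c1..c4 can be an answer
cls_index = torch.tensor([0])                                      # position of the [CLS] token
is_impossible = torch.tensor([0.])                                 # 1.0 for unanswerable questions
start_positions = torch.tensor([5])                                # gold span c2..c3 as token indices
end_positions = torch.tensor([6])
```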
@@ -32,8 +32,8 @@ from torch.nn import functional as F
from torch.nn import CrossEntropyLoss, MSELoss
from .file_utils import cached_path
-from .model_utils import (CONFIG_NAME, WEIGHTS_NAME,
-                          PretrainedConfig, PreTrainedModel, SequenceSummary)
+from .model_utils import (CONFIG_NAME, WEIGHTS_NAME, PretrainedConfig, PreTrainedModel,
+                          SequenceSummary, PoolerAnswerClass, PoolerEndLogits, PoolerStartLogits)

logger = logging.getLogger(__name__)
@@ -221,13 +221,15 @@ class XLNetConfig(PretrainedConfig):
                 bi_data=False,
                 clamp_len=-1,
                 same_length=False,
                 finetuning_task=None,
                 num_labels=2,
                 summary_type='last',
                 summary_use_proj=True,
                 summary_activation='tanh',
                 summary_dropout=0.1,
+                 start_n_top=5,
+                 end_n_top=5,
                 **kwargs):
        """Constructs XLNetConfig.
@@ -313,6 +315,8 @@ class XLNetConfig(PretrainedConfig):
            self.summary_use_proj = summary_use_proj
            self.summary_activation = summary_activation
            self.summary_dropout = summary_dropout
+            self.start_n_top = start_n_top
+            self.end_n_top = end_n_top
        else:
            raise ValueError("First argument must be either a vocabulary size (int)"
                             "or the path to a pretrained model config file (str)")
@@ -1114,6 +1118,7 @@ class XLNetForSequenceClassification(XLNetPreTrainedModel):
    """
    def __init__(self, config):
        super(XLNetForSequenceClassification, self).__init__(config)
+        self.num_labels = config.num_labels
        self.transformer = XLNetModel(config)
        self.sequence_summary = SequenceSummary(config)
@@ -1174,7 +1179,7 @@ class XLNetForSequenceClassification(XLNetPreTrainedModel):
class XLNetForQuestionAnswering(XLNetPreTrainedModel):
-    """XLNet model for Question Answering (span extraction).
+    """ XLNet model for Question Answering (span extraction).
    This module is composed of the XLNet model with a linear layer on top of
    the sequence output that computes start_logits and end_logits
@@ -1231,41 +1236,78 @@ class XLNetForQuestionAnswering(XLNetPreTrainedModel):
    """
    def __init__(self, config):
        super(XLNetForQuestionAnswering, self).__init__(config)
+        self.start_n_top = config.start_n_top
+        self.end_n_top = config.end_n_top

        self.transformer = XLNetModel(config)
-        self.qa_outputs = nn.Linear(config.d_model, config.num_labels)
+        self.start_logits = PoolerStartLogits(config)
+        self.end_logits = PoolerEndLogits(config)
+        self.answer_class = PoolerAnswerClass(config)

        self.apply(self.init_weights)

    def forward(self, input_ids, token_type_ids=None, input_mask=None, attention_mask=None,
                mems=None, perm_mask=None, target_mapping=None, inp_q=None,
-                start_positions=None, end_positions=None, head_mask=None):
+                start_positions=None, end_positions=None, cls_index=None, is_impossible=None, p_mask=None,
+                head_mask=None):
        transformer_outputs = self.transformer(input_ids, token_type_ids, input_mask, attention_mask,
                                               mems, perm_mask, target_mapping, inp_q, head_mask)
+        hidden_states = transformer_outputs[0]
+        start_logits = self.start_logits(hidden_states, p_mask)

-        logits = self.qa_outputs(transformer_outputs[0])
-        start_logits, end_logits = logits.split(1, dim=-1)
-        start_logits = start_logits.squeeze(-1)
-        end_logits = end_logits.squeeze(-1)
-
-        outputs = (start_logits, end_logits,) + transformer_outputs[1:]  # Keep mems, hidden states, attentions if there are in it
+        outputs = transformer_outputs[1:]  # Keep mems, hidden states, attentions if there are in it

        if start_positions is not None and end_positions is not None:
-            # If we are on multi-GPU, split add a dimension
-            if len(start_positions.size()) > 1:
-                start_positions = start_positions.squeeze(-1)
-            if len(end_positions.size()) > 1:
-                end_positions = end_positions.squeeze(-1)
-            # sometimes the start/end positions are outside our model inputs, we ignore these terms
-            ignored_index = start_logits.size(1)
-            start_positions.clamp_(0, ignored_index)
-            end_positions.clamp_(0, ignored_index)
-
-            loss_fct = CrossEntropyLoss(ignore_index=ignored_index)
+            # If we are on multi-GPU, let's remove the dimension added by batch splitting
+            for x in (start_positions, end_positions, cls_index, is_impossible):
+                if x is not None and x.dim() > 1:
+                    x.squeeze_(-1)
+
+            # during training, compute the end logits based on the ground truth of the start position
+            end_logits = self.end_logits(hidden_states, start_positions=start_positions, p_mask=p_mask)
+
+            loss_fct = CrossEntropyLoss()
            start_loss = loss_fct(start_logits, start_positions)
            end_loss = loss_fct(end_logits, end_positions)
            total_loss = (start_loss + end_loss) / 2
-            outputs = (total_loss,) + outputs

-        return outputs  # return (loss), logits, (mems), (hidden states), (attentions)
+            if cls_index is not None and is_impossible is not None:
+                # Predict answerability from the representation of CLS and START
+                cls_logits = self.answer_class(hidden_states, start_positions=start_positions, cls_index=cls_index)
+                loss_fct_cls = nn.BCEWithLogitsLoss()
+                cls_loss = loss_fct_cls(cls_logits, is_impossible)
+
+                # note(zhiliny): by default multiply the loss by 0.5 so that the scale is
+                # comparable to start_loss and end_loss
+                total_loss += cls_loss * 0.5
+
+                outputs = (total_loss, start_logits, end_logits, cls_logits) + outputs
+            else:
+                outputs = (total_loss, start_logits, end_logits) + outputs
+
+        else:
+            # during inference, compute the end logits based on beam search
+            bsz, slen, hsz = hidden_states.size()
+            start_log_probs = F.softmax(start_logits, dim=-1)  # shape (bsz, slen)
+
+            start_top_log_probs, start_top_index = torch.topk(start_log_probs, self.start_n_top, dim=-1)  # shape (bsz, start_n_top)
+            start_top_index = start_top_index.unsqueeze(-1).expand(-1, -1, hsz)  # shape (bsz, start_n_top, hsz)
+            start_states = torch.gather(hidden_states, -2, start_top_index)  # shape (bsz, start_n_top, hsz)
+            start_states = start_states.unsqueeze(1).expand(-1, slen, -1, -1)  # shape (bsz, slen, start_n_top, hsz)
+
+            hidden_states_expanded = hidden_states.unsqueeze(2).expand_as(start_states)  # shape (bsz, slen, start_n_top, hsz)
+            p_mask = p_mask.unsqueeze(-1) if p_mask is not None else None
+            end_logits = self.end_logits(hidden_states_expanded, start_states=start_states, p_mask=p_mask)
+            end_log_probs = F.softmax(end_logits, dim=1)  # shape (bsz, slen, start_n_top)
+
+            end_top_log_probs, end_top_index = torch.topk(end_log_probs, self.end_n_top, dim=1)  # shape (bsz, end_n_top, start_n_top)
+            end_top_log_probs = end_top_log_probs.view(-1, self.start_n_top * self.end_n_top)
+            end_top_index = end_top_index.view(-1, self.start_n_top * self.end_n_top)
+
+            start_states = torch.einsum("blh,bl->bh", hidden_states, start_log_probs)
+            cls_logits = self.answer_class(hidden_states, start_states=start_states, cls_index=cls_index)
+
+            outputs = (start_top_log_probs, start_top_index, end_top_log_probs, end_top_index, cls_logits) + outputs
+
+        # return start_top_log_probs, start_top_index, end_top_log_probs, end_top_index, cls_logits, mems, (hidden states), (attentions)
+        # or (if labels are provided) total_loss, start_logits, end_logits, (cls_logits), mems, (hidden states), (attentions)
+        return outputs
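Editor's note: at inference time the head emits start_n_top start candidates and, for each of them, end_n_top end candidates, flattened into start_n_top * end_n_top columns (the shape comments above show the (bsz, end_n_top, start_n_top) layout before the view). A rough decoding sketch that undoes that flattening and keeps the best valid span; this loop is an editorial illustration, not code from the commit, and real SQuAD evaluation additionally maps token indices back to text and weighs cls_logits for unanswerable questions:

```python
def best_span_for_example(start_top_log_probs, start_top_index,
                          end_top_log_probs, end_top_index,
                          start_n_top, end_n_top, max_answer_length=30, example=0):
    """Scan every (start, end) beam candidate for one example and keep the best valid span."""
    # undo the flattening done in the forward pass: columns were (end_n_top, start_n_top)
    end_scores = end_top_log_probs[example].view(end_n_top, start_n_top)
    end_index = end_top_index[example].view(end_n_top, start_n_top)

    best = (None, None, float("-inf"))
    for i in range(start_n_top):                     # rank of the start candidate
        start = start_top_index[example, i].item()
        for j in range(end_n_top):                   # rank of the end candidate for that start
            end = end_index[j, i].item()
            score = start_top_log_probs[example, i].item() + end_scores[j, i].item()
            if start <= end < start + max_answer_length and score > best[2]:
                best = (start, end, score)
    return best   # (start_index, end_index, joint score)
```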
@@ -16,11 +16,7 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

-import os
import unittest
-import json
-import random
-import shutil

import pytest
import torch
...
@@ -20,7 +20,7 @@ import unittest
import shutil
import pytest

-from pytorch_pretrained_bert import (XLMConfig, XLMModel, XLMForQuestionAnswering, XLMForSequenceClassification)
+from pytorch_pretrained_bert import (XLMConfig, XLMModel, XLMWithLMHeadModel, XLMForQuestionAnswering, XLMForSequenceClassification)
from pytorch_pretrained_bert.modeling_xlm import PRETRAINED_MODEL_ARCHIVE_MAP

from .model_tests_commons import (create_and_check_commons, ConfigTester, ids_tensor)
@@ -58,7 +58,8 @@ class XLMModelTest(unittest.TestCase):
                     summary_type="last",
                     use_proj=True,
                     scope=None,
-                     all_model_classes = (XLMModel,),  # , XLMForSequenceClassification, XLMForTokenClassification),
+                     all_model_classes = (XLMModel, XLMWithLMHeadModel,
+                                          XLMForQuestionAnswering, XLMForSequenceClassification),  # , XLMForSequenceClassification, XLMForTokenClassification),
                    ):
            self.parent = parent
            self.batch_size = batch_size
@@ -93,6 +94,7 @@ class XLMModelTest(unittest.TestCase):
        def prepare_config_and_inputs(self):
            input_ids = ids_tensor([self.batch_size, self.seq_length], self.vocab_size)
+            input_mask = ids_tensor([self.batch_size, self.seq_length], 2).float()

            input_lengths = None
            if self.use_input_lengths:
@@ -104,11 +106,11 @@ class XLMModelTest(unittest.TestCase):
            sequence_labels = None
            token_labels = None
-            choice_labels = None
+            is_impossible_labels = None
            if self.use_labels:
                sequence_labels = ids_tensor([self.batch_size], self.type_sequence_label_size)
                token_labels = ids_tensor([self.batch_size, self.seq_length], self.num_labels)
-                choice_labels = ids_tensor([self.batch_size], self.num_choices)
+                is_impossible_labels = ids_tensor([self.batch_size], 2).float()

            config = XLMConfig(
                vocab_size_or_config_json_file=self.vocab_size,
@@ -128,14 +130,14 @@ class XLMModelTest(unittest.TestCase):
                summary_type=self.summary_type,
                use_proj=self.use_proj)

-            return config, input_ids, token_type_ids, input_lengths, sequence_labels, token_labels, choice_labels
+            return config, input_ids, token_type_ids, input_lengths, sequence_labels, token_labels, is_impossible_labels, input_mask

        def check_loss_output(self, result):
            self.parent.assertListEqual(
                list(result["loss"].size()),
                [])

-        def create_and_check_xlm_model(self, config, input_ids, token_type_ids, input_lengths, sequence_labels, token_labels, choice_labels):
+        def create_and_check_xlm_model(self, config, input_ids, token_type_ids, input_lengths, sequence_labels, token_labels, is_impossible_labels, input_mask):
            model = XLMModel(config=config)
            model.eval()
            outputs = model(input_ids, lengths=input_lengths, langs=token_type_ids)
@@ -150,90 +152,92 @@ class XLMModelTest(unittest.TestCase):
                [self.batch_size, self.seq_length, self.hidden_size])
-        # def create_and_check_xlm_for_masked_lm(self, config, input_ids, token_type_ids, input_lengths, sequence_labels, token_labels, choice_labels):
-        #     model = XLMForMaskedLM(config=config)
-        #     model.eval()
-        #     loss, prediction_scores = model(input_ids, token_type_ids, input_lengths, token_labels)
-        #     result = {
-        #         "loss": loss,
-        #         "prediction_scores": prediction_scores,
-        #     }
-        #     self.parent.assertListEqual(
-        #         list(result["prediction_scores"].size()),
-        #         [self.batch_size, self.seq_length, self.vocab_size])
-        #     self.check_loss_output(result)

-        # def create_and_check_xlm_for_question_answering(self, config, input_ids, token_type_ids, input_lengths, sequence_labels, token_labels, choice_labels):
-        #     model = XLMForQuestionAnswering(config=config)
-        #     model.eval()
-        #     loss, start_logits, end_logits = model(input_ids, token_type_ids, input_lengths, sequence_labels, sequence_labels)
-        #     result = {
-        #         "loss": loss,
-        #         "start_logits": start_logits,
-        #         "end_logits": end_logits,
-        #     }
-        #     self.parent.assertListEqual(
-        #         list(result["start_logits"].size()),
-        #         [self.batch_size, self.seq_length])
-        #     self.parent.assertListEqual(
-        #         list(result["end_logits"].size()),
-        #         [self.batch_size, self.seq_length])
-        #     self.check_loss_output(result)

-        # def create_and_check_xlm_for_sequence_classification(self, config, input_ids, token_type_ids, input_lengths, sequence_labels, token_labels, choice_labels):
-        #     config.num_labels = self.num_labels
-        #     model = XLMForSequenceClassification(config)
-        #     model.eval()
-        #     loss, logits = model(input_ids, token_type_ids, input_lengths, sequence_labels)
-        #     result = {
-        #         "loss": loss,
-        #         "logits": logits,
-        #     }
-        #     self.parent.assertListEqual(
-        #         list(result["logits"].size()),
-        #         [self.batch_size, self.num_labels])
-        #     self.check_loss_output(result)

-        # def create_and_check_xlm_for_token_classification(self, config, input_ids, token_type_ids, input_lengths, sequence_labels, token_labels, choice_labels):
-        #     config.num_labels = self.num_labels
-        #     model = XLMForTokenClassification(config=config)
-        #     model.eval()
-        #     loss, logits = model(input_ids, token_type_ids, input_lengths, token_labels)
-        #     result = {
-        #         "loss": loss,
-        #         "logits": logits,
-        #     }
-        #     self.parent.assertListEqual(
-        #         list(result["logits"].size()),
-        #         [self.batch_size, self.seq_length, self.num_labels])
-        #     self.check_loss_output(result)

-        # def create_and_check_xlm_for_multiple_choice(self, config, input_ids, token_type_ids, input_lengths, sequence_labels, token_labels, choice_labels):
-        #     config.num_choices = self.num_choices
-        #     model = XLMForMultipleChoice(config=config)
-        #     model.eval()
-        #     multiple_choice_inputs_ids = input_ids.unsqueeze(1).expand(-1, self.num_choices, -1).contiguous()
-        #     multiple_choice_token_type_ids = token_type_ids.unsqueeze(1).expand(-1, self.num_choices, -1).contiguous()
-        #     multiple_choice_input_lengths = input_lengths.unsqueeze(1).expand(-1, self.num_choices, -1).contiguous()
-        #     loss, logits = model(multiple_choice_inputs_ids,
-        #                          multiple_choice_token_type_ids,
-        #                          multiple_choice_input_lengths,
-        #                          choice_labels)
-        #     result = {
-        #         "loss": loss,
-        #         "logits": logits,
-        #     }
-        #     self.parent.assertListEqual(
-        #         list(result["logits"].size()),
-        #         [self.batch_size, self.num_choices])
-        #     self.check_loss_output(result)

        def create_and_check_xlm_lm_head(self, config, input_ids, token_type_ids, input_lengths, sequence_labels, token_labels, is_impossible_labels, input_mask):
            model = XLMWithLMHeadModel(config)
            model.eval()

            loss, logits = model(input_ids, token_type_ids=token_type_ids, labels=token_labels)

            result = {
                "loss": loss,
                "logits": logits,
            }

            self.parent.assertListEqual(
                list(result["loss"].size()),
                [])
            self.parent.assertListEqual(
                list(result["logits"].size()),
                [self.batch_size, self.seq_length, self.vocab_size])

        def create_and_check_xlm_qa(self, config, input_ids, token_type_ids, input_lengths, sequence_labels, token_labels, is_impossible_labels, input_mask):
            model = XLMForQuestionAnswering(config)
            model.eval()

            outputs = model(input_ids)
            start_top_log_probs, start_top_index, end_top_log_probs, end_top_index, cls_logits, mems = outputs

            outputs = model(input_ids, start_positions=sequence_labels,
                            end_positions=sequence_labels,
                            cls_index=sequence_labels,
                            is_impossible=is_impossible_labels,
                            p_mask=input_mask)
            outputs = model(input_ids, start_positions=sequence_labels,
                            end_positions=sequence_labels,
                            cls_index=sequence_labels,
                            is_impossible=is_impossible_labels)

            total_loss, start_logits, end_logits, cls_logits = outputs

            outputs = model(input_ids, start_positions=sequence_labels,
                            end_positions=sequence_labels)

            total_loss, start_logits, end_logits = outputs

            result = {
                "loss": total_loss,
                "start_logits": start_logits,
                "end_logits": end_logits,
                "cls_logits": cls_logits,
            }

            self.parent.assertListEqual(
                list(result["loss"].size()),
                [])
            self.parent.assertListEqual(
                list(result["start_logits"].size()),
                [self.batch_size, self.seq_length])
            self.parent.assertListEqual(
                list(result["end_logits"].size()),
                [self.batch_size, self.seq_length])
            self.parent.assertListEqual(
                list(result["cls_logits"].size()),
                [self.batch_size])

        def create_and_check_xlm_sequence_classif(self, config, input_ids, token_type_ids, input_lengths, sequence_labels, token_labels, is_impossible_labels, input_mask):
            model = XLMForSequenceClassification(config)
            model.eval()

            (logits,) = model(input_ids)
            loss, logits = model(input_ids, labels=sequence_labels)

            result = {
                "loss": loss,
                "logits": logits,
            }

            self.parent.assertListEqual(
                list(result["loss"].size()),
                [])
            self.parent.assertListEqual(
                list(result["logits"].size()),
                [self.batch_size, self.type_sequence_label_size])

-        def create_and_check_xlm_commons(self, config, input_ids, token_type_ids, input_lengths, sequence_labels, token_labels, choice_labels):
+        def create_and_check_xlm_commons(self, config, input_ids, token_type_ids, input_lengths, sequence_labels, token_labels, is_impossible_labels, input_mask):
            inputs_dict = {'input_ids': input_ids, 'token_type_ids': token_type_ids, 'lengths': input_lengths}
            create_and_check_commons(self, config, inputs_dict)
...
@@ -49,6 +49,7 @@ class XLNetModelTest(unittest.TestCase):
                     d_inner=128,
                     num_hidden_layers=5,
                     max_position_embeddings=10,
+                     type_sequence_label_size=2,
                     untie_r=True,
                     bi_data=False,
                     same_length=False,
@@ -80,12 +81,14 @@ class XLNetModelTest(unittest.TestCase):
            self.initializer_range = initializer_range
            self.seed = seed
            self.type_vocab_size = type_vocab_size
+            self.type_sequence_label_size = type_sequence_label_size
            self.all_model_classes = all_model_classes

        def prepare_config_and_inputs(self):
            input_ids_1 = ids_tensor([self.batch_size, self.seq_length], self.vocab_size)
            input_ids_2 = ids_tensor([self.batch_size, self.seq_length], self.vocab_size)
            segment_ids = ids_tensor([self.batch_size, self.seq_length], self.type_vocab_size)
+            input_mask = ids_tensor([self.batch_size, self.seq_length], 2).float()

            input_ids_q = ids_tensor([self.batch_size, self.seq_length + 1], self.vocab_size)
            perm_mask = torch.zeros(self.batch_size, self.seq_length + 1, self.seq_length + 1, dtype=torch.float)
@@ -94,30 +97,13 @@ class XLNetModelTest(unittest.TestCase):
            target_mapping[:, 0, -1] = 1.0  # predict last token
            inp_q = target_mapping[:, 0, :].clone()  # predict last token

-            # inp_k: int32 Tensor in shape [bsz, len], the input token IDs.
-            # token_type_ids: int32 Tensor in shape [bsz, len], the input segment IDs.
-            # input_mask: float32 Tensor in shape [bsz, len], the input mask.
-            #   0 for real tokens and 1 for padding.
-            # mems: a list of float32 Tensors in shape [bsz, mem_len, hidden_size], memory
-            #   from previous batches. The length of the list equals num_hidden_layers.
-            #   If None, no memory is used.
-            # perm_mask: float32 Tensor in shape [bsz, len, len].
-            #   If perm_mask[k, i, j] = 0, i attend to j in batch k;
-            #   if perm_mask[k, i, j] = 1, i does not attend to j in batch k.
-            #   If None, each position attends to all the others.
-            # target_mapping: float32 Tensor in shape [bsz, num_predict, len].
-            #   If target_mapping[k, i, j] = 1, the i-th predict in batch k is
-            #   on the j-th token.
-            #   Only used during pretraining for partial prediction.
-            #   Set to None during finetuning.
-            # inp_q: float32 Tensor in shape [bsz, len].
-            #   1 for tokens with losses and 0 for tokens without losses.
-            #   Only used during pretraining for two-stream attention.
-            #   Set to None during finetuning.
+            sequence_labels = None
            lm_labels = None
+            is_impossible_labels = None
            if self.use_labels:
                lm_labels = ids_tensor([self.batch_size, self.seq_length], self.vocab_size)
+                sequence_labels = ids_tensor([self.batch_size], self.type_sequence_label_size)
+                is_impossible_labels = ids_tensor([self.batch_size], 2).float()

            config = XLNetConfig(
                vocab_size_or_config_json_file=self.vocab_size,
@@ -132,18 +118,23 @@ class XLNetModelTest(unittest.TestCase):
                same_length=self.same_length,
                reuse_len=self.reuse_len,
                bi_data=self.bi_data,
-                initializer_range=self.initializer_range)
+                initializer_range=self.initializer_range,
+                num_labels=self.type_sequence_label_size)

-            return (config, input_ids_1, input_ids_2, input_ids_q, perm_mask, target_mapping, inp_q, segment_ids, lm_labels)
+            return (config, input_ids_1, input_ids_2, input_ids_q, perm_mask, input_mask,
+                    target_mapping, inp_q, segment_ids, lm_labels, sequence_labels, is_impossible_labels)

        def set_seed(self):
            random.seed(self.seed)
            torch.manual_seed(self.seed)

-        def create_and_check_xlnet_base_model(self, config, input_ids_1, input_ids_2, input_ids_q, perm_mask, target_mapping, inp_q, segment_ids, lm_labels):
+        def create_and_check_xlnet_base_model(self, config, input_ids_1, input_ids_2, input_ids_q, perm_mask, input_mask,
+                                              target_mapping, inp_q, segment_ids, lm_labels, sequence_labels, is_impossible_labels):
            model = XLNetModel(config)
            model.eval()

+            _, _ = model(input_ids_1, input_mask=input_mask)
+            _, _ = model(input_ids_1, attention_mask=input_mask)
            _, _ = model(input_ids_1, token_type_ids=segment_ids)
            outputs, mems_1 = model(input_ids_1)
@@ -159,7 +150,8 @@ class XLNetModelTest(unittest.TestCase):
                list(list(mem.size()) for mem in result["mems_1"]),
                [[self.seq_length, self.batch_size, self.hidden_size]] * self.num_hidden_layers)

-        def create_and_check_xlnet_lm_head(self, config, input_ids_1, input_ids_2, input_ids_q, perm_mask, target_mapping, inp_q, segment_ids, lm_labels):
+        def create_and_check_xlnet_lm_head(self, config, input_ids_1, input_ids_2, input_ids_q, perm_mask, input_mask,
+                                           target_mapping, inp_q, segment_ids, lm_labels, sequence_labels, is_impossible_labels):
            model = XLNetLMHeadModel(config)
            model.eval()
@@ -198,7 +190,82 @@ class XLNetModelTest(unittest.TestCase):
                list(list(mem.size()) for mem in result["mems_2"]),
                [[self.mem_len, self.batch_size, self.hidden_size]] * self.num_hidden_layers)
        def create_and_check_xlnet_qa(self, config, input_ids_1, input_ids_2, input_ids_q, perm_mask, input_mask,
                                      target_mapping, inp_q, segment_ids, lm_labels, sequence_labels, is_impossible_labels):
            model = XLNetForQuestionAnswering(config)
model.eval()
outputs = model(input_ids_1)
start_top_log_probs, start_top_index, end_top_log_probs, end_top_index, cls_logits, mems = outputs
outputs = model(input_ids_1, start_positions=sequence_labels,
end_positions=sequence_labels,
cls_index=sequence_labels,
is_impossible=is_impossible_labels,
p_mask=input_mask)
outputs = model(input_ids_1, start_positions=sequence_labels,
end_positions=sequence_labels,
cls_index=sequence_labels,
is_impossible=is_impossible_labels)
total_loss, start_logits, end_logits, cls_logits, mems = outputs
outputs = model(input_ids_1, start_positions=sequence_labels,
end_positions=sequence_labels)
total_loss, start_logits, end_logits, mems = outputs
result = {
"loss": total_loss,
"start_logits": start_logits,
"end_logits": end_logits,
"cls_logits": cls_logits,
"mems": mems,
}
self.parent.assertListEqual(
list(result["loss"].size()),
[])
self.parent.assertListEqual(
list(result["start_logits"].size()),
[self.batch_size, self.seq_length])
self.parent.assertListEqual(
list(result["end_logits"].size()),
[self.batch_size, self.seq_length])
self.parent.assertListEqual(
list(result["cls_logits"].size()),
[self.batch_size])
self.parent.assertListEqual(
list(list(mem.size()) for mem in result["mems"]),
[[self.seq_length, self.batch_size, self.hidden_size]] * self.num_hidden_layers)
def create_and_check_xlnet_sequence_classif(self, config, input_ids_1, input_ids_2, input_ids_q, perm_mask, input_mask,
target_mapping, inp_q, segment_ids, lm_labels, sequence_labels, is_impossible_labels):
model = XLNetForSequenceClassification(config)
model.eval()
logits, mems_1 = model(input_ids_1)
loss, logits, mems_1 = model(input_ids_1, labels=sequence_labels)
result = {
"loss": loss,
"mems_1": mems_1,
"logits": logits,
}
self.parent.assertListEqual(
list(result["loss"].size()),
[])
self.parent.assertListEqual(
list(result["logits"].size()),
[self.batch_size, self.type_sequence_label_size])
self.parent.assertListEqual(
list(list(mem.size()) for mem in result["mems_1"]),
[[self.seq_length, self.batch_size, self.hidden_size]] * self.num_hidden_layers)
-        def create_and_check_xlnet_commons(self, config, input_ids_1, input_ids_2, input_ids_q, perm_mask, target_mapping, inp_q, segment_ids, lm_labels):
+        def create_and_check_xlnet_commons(self, config, input_ids_1, input_ids_2, input_ids_q, perm_mask, input_mask,
+                                           target_mapping, inp_q, segment_ids, lm_labels, sequence_labels, is_impossible_labels):
            inputs_dict = {'input_ids': input_ids_1}
            create_and_check_commons(self, config, inputs_dict, test_pruning=False)
@@ -224,27 +291,19 @@ class XLNetModelTest(unittest.TestCase):
        tester.set_seed()
        config_and_inputs = tester.prepare_config_and_inputs()
        tester.create_and_check_xlnet_lm_head(*config_and_inputs)

        tester.set_seed()
        config_and_inputs = tester.prepare_config_and_inputs()
-        tester.create_and_check_xlnet_commons(*config_and_inputs)
+        tester.create_and_check_xlnet_sequence_classif(*config_and_inputs)

-    @classmethod
-    def mask_tensor(cls, shape, vocab_size, rng=None, name=None):
-        """Creates a tensor with padding on the right (0.0 for )."""
-        if rng is None:
-            rng = random.Random()
-
-        total_dims = 1
-        for dim in shape:
-            total_dims *= dim
-
-        values = []
-        for _ in range(total_dims):
-            values.append(rng.randint(0, vocab_size - 1))
-
-        return torch.tensor(data=values, dtype=torch.long).view(shape).contiguous()
+        tester.set_seed()
+        config_and_inputs = tester.prepare_config_and_inputs()
+        tester.create_and_check_xlnet_qa(*config_and_inputs)

+        tester.set_seed()
+        config_and_inputs = tester.prepare_config_and_inputs()
+        tester.create_and_check_xlnet_commons(*config_and_inputs)

if __name__ == "__main__":
...