"git@developer.sourcefind.cn:gaoqiong/flash-attention.git" did not exist on "27f8f890dff58986391b606bc7c181c3b9f5148a"
Commit 15b70338 authored by thomwolf

adding squad model to xlnet and xlm

parent fbe04423
......@@ -25,7 +25,7 @@ from io import open
import torch
from torch import nn
from torch.nn import CrossEntropyLoss, MSELoss
from torch.nn import CrossEntropyLoss, MSELoss, functional as F
from .file_utils import cached_path
......@@ -301,8 +301,174 @@ class Conv1D(nn.Module):
return x
class SequenceSummary(nn.Module):
class PoolerStartLogits(nn.Module):
""" Compute SQuAD start_logits from sequence hidden states. """
def __init__(self, config):
super(PoolerStartLogits, self).__init__()
self.dense = nn.Linear(config.hidden_size, 1)
def forward(self, hidden_states, p_mask=None):
""" Args:
`p_mask`: [optional] invalid position mask such as query and special symbols (PAD, SEP, CLS)
shape [batch_size, seq_len]. 1.0 means token should be masked.
"""
x = self.dense(hidden_states).squeeze(-1)
if p_mask is not None:
x = x * (1 - p_mask) - 1e30 * p_mask
return x
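
As an aside on the masking idiom above: multiplying by (1 - p_mask) and subtracting 1e30 * p_mask drives the logits of masked positions (query tokens, PAD/SEP/CLS) to a very large negative value, so a later softmax gives them essentially zero probability. A minimal, self-contained sketch of the effect (toy tensors, not part of this commit):

import torch

logits = torch.tensor([[1.0, 2.0, 3.0, 4.0]])
p_mask = torch.tensor([[0.0, 0.0, 1.0, 1.0]])   # 1.0 marks positions that must be ignored

masked = logits * (1 - p_mask) - 1e30 * p_mask
probs = torch.softmax(masked, dim=-1)
# probs ~ [[0.27, 0.73, 0.00, 0.00]]: the two masked positions receive ~zero probability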
class PoolerEndLogits(nn.Module):
""" Compute SQuAD end_logits from sequence hidden states and start token hidden state.
"""
def __init__(self, config):
super(PoolerEndLogits, self).__init__()
self.dense_0 = nn.Linear(config.hidden_size * 2, config.hidden_size)
self.activation = nn.Tanh()
self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
self.dense_1 = nn.Linear(config.hidden_size, 1)
def forward(self, hidden_states, start_states=None, start_positions=None, p_mask=None):
""" Args:
One of start_states or start_positions should not be None. If both are set, start_positions overrides start_states.
`start_states`: hidden states of the first tokens for the labeled span: torch.FloatTensor with shape identical to hidden_states.
`start_positions`: position of the first token for the labeled span: torch.LongTensor of shape [batch_size].
`p_mask`: [optional] invalid position mask such as query and special symbols (PAD, SEP, CLS)
shape [batch_size, seq_len]. 1.0 means token should be masked.
"""
slen, hsz = hidden_states.shape[-2:]
assert start_states is not None or start_positions is not None, "One of start_states or start_positions should not be None"
if start_positions is not None:
start_positions = start_positions[:, None, None].expand(-1, -1, hsz) # shape (bsz, 1, hsz)
start_states = hidden_states.gather(-2, start_positions) # shape (bsz, 1, hsz)
start_states = start_states.expand(-1, slen, -1) # shape (bsz, slen, hsz)
x = self.dense_0(torch.cat([hidden_states, start_states], dim=-1))
x = self.activation(x)
x = self.LayerNorm(x)
x = self.dense_1(x).squeeze(-1)
if p_mask is not None:
x = x * (1 - p_mask) - 1e30 * p_mask
return x
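
The start_positions[:, None, None].expand(-1, -1, hsz) / gather(-2, ...) idiom used above picks out, for each example in the batch, the hidden state at the labeled start position and then broadcasts it along the sequence axis. A small illustration with toy shapes (bsz=2, slen=4, hsz=3 are assumptions for the example, not part of this commit):

import torch

bsz, slen, hsz = 2, 4, 3
hidden_states = torch.arange(bsz * slen * hsz, dtype=torch.float).view(bsz, slen, hsz)
start_positions = torch.tensor([1, 3])                         # labeled start token per example

index = start_positions[:, None, None].expand(-1, -1, hsz)     # shape (bsz, 1, hsz)
start_states = hidden_states.gather(-2, index)                  # shape (bsz, 1, hsz)
start_states = start_states.expand(-1, slen, -1)                # shape (bsz, slen, hsz)

assert torch.equal(start_states[0, 0], hidden_states[0, 1])     # row 0 uses position 1
assert torch.equal(start_states[1, 0], hidden_states[1, 3])     # row 1 uses position 3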
class PoolerAnswerClass(nn.Module):
""" Compute SQuAD 2.0 answer class from classification and start tokens hidden states. """
def __init__(self, config):
super(PoolerAnswerClass, self).__init__()
self.dense_0 = nn.Linear(config.hidden_size * 2, config.hidden_size)
self.activation = nn.Tanh()
self.dense_1 = nn.Linear(config.hidden_size, 1, bias=False)
def forward(self, hidden_states, start_states=None, start_positions=None, cls_index=None):
""" Args:
One of start_states or start_positions should not be None. If both are set, start_positions overrides start_states.
`start_states`: hidden states of the first tokens for the labeled span: torch.FloatTensor with shape identical to hidden_states.
`start_positions`: position of the first token for the labeled span: torch.LongTensor of shape [batch_size].
`cls_index`: position of the CLS token: torch.LongTensor of shape [batch_size]. If None, take the last token.
# note(zhiliny): no dependency on end_feature so that we can obtain one single `cls_logits` for each sample
"""
slen, hsz = hidden_states.shape[-2:]
assert start_states is not None or start_positions is not None, "One of start_states or start_positions should not be None"
if start_positions is not None:
start_positions = start_positions[:, None, None].expand(-1, -1, hsz) # shape (bsz, 1, hsz)
start_states = hidden_states.gather(-2, start_positions).squeeze(-2) # shape (bsz, hsz)
if cls_index is not None:
cls_index = cls_index[:, None, None].expand(-1, -1, hsz) # shape (bsz, 1, hsz)
cls_token_state = hidden_states.gather(-2, cls_index).squeeze(-2) # shape (bsz, hsz)
else:
cls_token_state = hidden_states[:, -1, :] # shape (bsz, hsz)
x = self.dense_0(torch.cat([start_states, cls_token_state], dim=-1))
x = self.activation(x)
x = self.dense_1(x).squeeze(-1)
return x
class SQuADHead(nn.Module):
""" A SQuAD head inspired by XLNet.
Compute
"""
def __init__(self, config):
super(SQuADHead, self).__init__()
self.start_n_top = config.start_n_top
self.end_n_top = config.end_n_top
self.start_logits = PoolerStartLogits(config)
self.end_logits = PoolerEndLogits(config)
self.answer_class = PoolerAnswerClass(config)
def forward(self, hidden_states, start_positions=None, end_positions=None,
cls_index=None, is_impossible=None, p_mask=None):
""" hidden_states: float Tensor in shape [bsz, seq_len, hidden_size], the hidden-states of the last layer.
"""
outputs = ()
start_logits = self.start_logits(hidden_states, p_mask)
if start_positions is not None and end_positions is not None:
# If we are on multi-GPU, let's remove the dimension added by batch splitting
for x in (start_positions, end_positions, cls_index, is_impossible):
if x is not None and x.dim() > 1:
x.squeeze_(-1)
# during training, compute the end logits based on the ground truth of the start position
end_logits = self.end_logits(hidden_states, start_positions=start_positions, p_mask=p_mask)
loss_fct = CrossEntropyLoss()
start_loss = loss_fct(start_logits, start_positions)
end_loss = loss_fct(end_logits, end_positions)
total_loss = (start_loss + end_loss) / 2
if cls_index is not None and is_impossible is not None:
# Predict answerability from the representation of CLS and START
cls_logits = self.answer_class(hidden_states, start_positions=start_positions, cls_index=cls_index)
loss_fct_cls = nn.BCEWithLogitsLoss()
cls_loss = loss_fct_cls(cls_logits, is_impossible)
# note(zhiliny): by default multiply the loss by 0.5 so that the scale is comparable to start_loss and end_loss
total_loss += cls_loss * 0.5
outputs = (total_loss, start_logits, end_logits, cls_logits) + outputs
else:
outputs = (total_loss, start_logits, end_logits) + outputs
else:
# during inference, compute the end logits based on beam search
bsz, slen, hsz = hidden_states.size()
start_log_probs = F.softmax(start_logits, dim=-1) # shape (bsz, slen)
start_top_log_probs, start_top_index = torch.topk(start_log_probs, self.start_n_top, dim=-1) # shape (bsz, start_n_top)
start_top_index = start_top_index.unsqueeze(-1).expand(-1, -1, hsz) # shape (bsz, start_n_top, hsz)
start_states = torch.gather(hidden_states, -2, start_top_index) # shape (bsz, start_n_top, hsz)
start_states = start_states.unsqueeze(1).expand(-1, slen, -1, -1) # shape (bsz, slen, start_n_top, hsz)
hidden_states_expanded = hidden_states.unsqueeze(2).expand_as(start_states) # shape (bsz, slen, start_n_top, hsz)
p_mask = p_mask.unsqueeze(-1) if p_mask is not None else None
end_logits = self.end_logits(hidden_states_expanded, start_states=start_states, p_mask=p_mask)
end_log_probs = F.softmax(end_logits, dim=1) # shape (bsz, slen, start_n_top)
end_top_log_probs, end_top_index = torch.topk(end_log_probs, self.end_n_top, dim=1) # shape (bsz, end_n_top, start_n_top)
end_top_log_probs = end_top_log_probs.view(-1, self.start_n_top * self.end_n_top)
end_top_index = end_top_index.view(-1, self.start_n_top * self.end_n_top)
start_states = torch.einsum("blh,bl->bh", hidden_states, start_log_probs)
cls_logits = self.answer_class(hidden_states, start_states=start_states, cls_index=cls_index)
outputs = (start_top_log_probs, start_top_index, end_top_log_probs, end_top_index, cls_logits) + outputs
# return start_top_log_probs, start_top_index, end_top_log_probs, end_top_index, cls_logits
# or (if labels are provided) total_loss, start_logits, end_logits, (cls_logits)
return outputs
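
To make the two code paths above concrete: when start/end labels are given, the head returns (total_loss, start_logits, end_logits), plus cls_logits if cls_index and is_impossible are also provided; without labels it returns the top-k beam-search candidates and the answerability logit. A hedged usage sketch, assuming SQuADHead is importable from this commit's model_utils module and given a toy config carrying only the attributes the head reads (hidden_size, layer_norm_eps, start_n_top, end_n_top):

import torch
from types import SimpleNamespace
from pytorch_pretrained_bert.model_utils import SQuADHead   # import path suggested by this commit's imports

config = SimpleNamespace(hidden_size=8, layer_norm_eps=1e-12, start_n_top=5, end_n_top=5)
head = SQuADHead(config)

bsz, slen = 2, 16
hidden_states = torch.randn(bsz, slen, config.hidden_size)

# Training: ground-truth span positions are provided, so the loss comes first.
loss, start_logits, end_logits = head(hidden_states,
                                      start_positions=torch.tensor([3, 7]),
                                      end_positions=torch.tensor([5, 9]))

# Inference: no labels, so the head beam-searches start/end candidates instead.
(start_top_log_probs, start_top_index,
 end_top_log_probs, end_top_index, cls_logits) = head(hidden_states)
# start_top_*: (bsz, start_n_top); end_top_*: (bsz, start_n_top * end_n_top); cls_logits: (bsz,)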
class SequenceSummary(nn.Module):
""" Compute a single vector summary of a sequence hidden states according to various possibilities:
Args of the config class:
summary_type:
......@@ -317,6 +483,7 @@ class SequenceSummary(nn.Module):
'tanh' => add a tanh activation to the output
None => no activation
"""
def __init__(self, config):
super(SequenceSummary, self).__init__()
self.summary_type = config.summary_type if hasattr(config, 'summary_use_proj') else 'last'
......
......@@ -35,7 +35,8 @@ from torch.nn import functional as F
from torch.nn import CrossEntropyLoss, MSELoss
from .file_utils import cached_path
from .model_utils import CONFIG_NAME, WEIGHTS_NAME, PretrainedConfig, PreTrainedModel, prune_linear_layer
from .model_utils import (CONFIG_NAME, WEIGHTS_NAME, PretrainedConfig, PreTrainedModel,
prune_linear_layer, SequenceSummary, SQuADHead)
logger = logging.getLogger(__name__)
......@@ -67,15 +68,23 @@ class XLMConfig(PretrainedConfig):
n_langs=1,
max_position_embeddings=512,
embed_init_std=2048 ** -0.5,
layer_norm_eps=1e-12,
init_std=0.02,
summary_type="last",
use_proj=True,
bos_index=0,
eos_index=1,
pad_index=2,
unk_index=3,
mask_index=5,
is_encoder=True,
finetuning_task=None,
num_labels=2,
summary_type='last',
summary_use_proj=True,
summary_activation='tanh',
summary_dropout=0.1,
start_n_top=5,
end_n_top=5,
**kwargs):
"""Constructs XLMConfig.
......@@ -140,8 +149,7 @@ class XLMConfig(PretrainedConfig):
self.causal = causal
self.asm = asm
self.n_langs = n_langs
self.summary_type = summary_type
self.use_proj = use_proj
self.layer_norm_eps = layer_norm_eps
self.bos_index = bos_index
self.eos_index = eos_index
self.pad_index = pad_index
......@@ -151,6 +159,14 @@ class XLMConfig(PretrainedConfig):
self.max_position_embeddings = max_position_embeddings
self.embed_init_std = embed_init_std
self.init_std = init_std
self.finetuning_task = finetuning_task
self.num_labels = num_labels
self.summary_type = summary_type
self.summary_use_proj = summary_use_proj
self.summary_activation = summary_activation
self.summary_dropout = summary_dropout
self.start_n_top = start_n_top
self.end_n_top = end_n_top
else:
raise ValueError("First argument must be either a vocabulary size (int)"
"or the path to a pretrained model config file (str)")
......@@ -172,26 +188,6 @@ class XLMConfig(PretrainedConfig):
return self.n_layers
def Embedding(num_embeddings, embedding_dim, padding_idx=None, config=None):
m = nn.Embedding(num_embeddings, embedding_dim, padding_idx=padding_idx)
if config is not None and config.embed_init_std is not None:
nn.init.normal_(m.weight, mean=0, std=config.embed_init_std)
if padding_idx is not None:
nn.init.constant_(m.weight[padding_idx], 0)
return m
def Linear(in_features, out_features, bias=True, config=None):
m = nn.Linear(in_features, out_features, bias)
if config is not None and config.init_std is not None:
nn.init.normal_(m.weight, mean=0, std=config.init_std)
if bias:
nn.init.constant_(m.bias, 0.)
# nn.init.xavier_uniform_(m.weight)
# nn.init.constant_(m.bias, 0.)
return m
def create_sinusoidal_embeddings(n_pos, dim, out):
position_enc = np.array([
[pos / np.power(10000, 2 * (j // 2) / dim) for j in range(dim)]
......@@ -244,7 +240,7 @@ class MultiHeadAttention(nn.Module):
NEW_ID = itertools.count()
def __init__(self, n_heads, dim, config):
super().__init__()
super(MultiHeadAttention, self).__init__()
self.layer_id = next(MultiHeadAttention.NEW_ID)
self.output_attentions = config.output_attentions
self.dim = dim
......@@ -252,10 +248,10 @@ class MultiHeadAttention(nn.Module):
self.dropout = config.attention_dropout
assert self.dim % self.n_heads == 0
self.q_lin = Linear(dim, dim, config=config)
self.k_lin = Linear(dim, dim, config=config)
self.v_lin = Linear(dim, dim, config=config)
self.out_lin = Linear(dim, dim, config=config)
self.q_lin = nn.Linear(dim, dim)
self.k_lin = nn.Linear(dim, dim)
self.v_lin = nn.Linear(dim, dim)
self.out_lin = nn.Linear(dim, dim)
def prune_heads(self, heads):
attention_head_size = self.dim // self.n_heads
......@@ -342,10 +338,10 @@ class MultiHeadAttention(nn.Module):
class TransformerFFN(nn.Module):
def __init__(self, in_dim, dim_hidden, out_dim, config):
super().__init__()
super(TransformerFFN, self).__init__()
self.dropout = config.dropout
self.lin1 = Linear(in_dim, dim_hidden, config=config)
self.lin2 = Linear(dim_hidden, out_dim, config=config)
self.lin1 = nn.Linear(in_dim, dim_hidden)
self.lin2 = nn.Linear(dim_hidden, out_dim)
self.act = gelu if config.gelu_activation else F.relu
def forward(self, input):
......@@ -363,17 +359,21 @@ class XLMPreTrainedModel(PreTrainedModel):
config_class = XLMConfig
pretrained_model_archive_map = PRETRAINED_MODEL_ARCHIVE_MAP
load_tf_weights = None
base_model_prefix = "xlm"
base_model_prefix = "transformer"
def __init__(self, *inputs, **kwargs):
super(XLMPreTrainedModel, self).__init__(*inputs, **kwargs)
def init_weights(self, module):
""" Initialize the weights.
"""
if isinstance(module, (nn.Linear, nn.Embedding)):
# Weights are initialized in module instantiation (see above)
pass
""" Initialize the weights. """
if isinstance(module, nn.Embedding):
if self.config is not None and self.config.embed_init_std is not None:
nn.init.normal_(module.weight, mean=0, std=self.config.embed_init_std)
if isinstance(module, nn.Linear):
if self.config is not None and self.config.init_std is not None:
nn.init.normal_(module.weight, mean=0, std=self.config.init_std)
if hasattr(module, 'bias') and module.bias is not None:
nn.init.constant_(module.bias, 0.)
if isinstance(module, nn.LayerNorm):
module.bias.data.zero_()
module.weight.data.fill_(1.0)
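
The rewritten init_weights above replaces the removed Embedding/Linear factory helpers: layers are now built with plain nn.Embedding/nn.Linear and initialized afterwards through Module.apply, which visits every submodule recursively (see the self.apply(self.init_weights) call added in XLMModel below). A standalone sketch of the mechanism, with the std values hard-coded only for illustration:

import torch.nn as nn

def init_weights(module, embed_init_std=2048 ** -0.5, init_std=0.02):
    # Same structure as the method above: normal init for embeddings and linears,
    # zero biases, and the usual (0, 1) reset for LayerNorm.
    if isinstance(module, nn.Embedding):
        nn.init.normal_(module.weight, mean=0, std=embed_init_std)
    if isinstance(module, nn.Linear):
        nn.init.normal_(module.weight, mean=0, std=init_std)
        if module.bias is not None:
            nn.init.constant_(module.bias, 0.)
    if isinstance(module, nn.LayerNorm):
        module.bias.data.zero_()
        module.weight.data.fill_(1.0)

model = nn.Sequential(nn.Embedding(10, 4), nn.Linear(4, 4), nn.LayerNorm(4))
model.apply(init_weights)   # applied to every submodule, mirroring self.apply(self.init_weights)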
......@@ -471,13 +471,13 @@ class XLMModel(XLMPreTrainedModel):
assert self.dim % self.n_heads == 0, 'transformer dim must be a multiple of n_heads'
# embeddings
self.position_embeddings = Embedding(config.max_position_embeddings, self.dim, config=config)
self.position_embeddings = nn.Embedding(config.max_position_embeddings, self.dim)
if config.sinusoidal_embeddings:
create_sinusoidal_embeddings(config.max_position_embeddings, self.dim, out=self.position_embeddings.weight)
if config.n_langs > 1:
self.lang_embeddings = Embedding(self.n_langs, self.dim, config=config)
self.embeddings = Embedding(self.n_words, self.dim, padding_idx=self.pad_index, config=config)
self.layer_norm_emb = nn.LayerNorm(self.dim, eps=1e-12)
self.lang_embeddings = nn.Embedding(self.n_langs, self.dim)
self.embeddings = nn.Embedding(self.n_words, self.dim, padding_idx=self.pad_index)
self.layer_norm_emb = nn.LayerNorm(self.dim, eps=config.layer_norm_eps)
# transformer layers
self.attentions = nn.ModuleList()
......@@ -490,12 +490,14 @@ class XLMModel(XLMPreTrainedModel):
for _ in range(self.n_layers):
self.attentions.append(MultiHeadAttention(self.n_heads, self.dim, config=config))
self.layer_norm1.append(nn.LayerNorm(self.dim, eps=1e-12))
self.layer_norm1.append(nn.LayerNorm(self.dim, eps=config.layer_norm_eps))
# if self.is_decoder:
# self.layer_norm15.append(nn.LayerNorm(self.dim, eps=1e-12))
# self.layer_norm15.append(nn.LayerNorm(self.dim, eps=config.layer_norm_eps))
# self.encoder_attn.append(MultiHeadAttention(self.n_heads, self.dim, dropout=self.attention_dropout))
self.ffns.append(TransformerFFN(self.dim, self.hidden_dim, self.dim, config=config))
self.layer_norm2.append(nn.LayerNorm(self.dim, eps=1e-12))
self.layer_norm2.append(nn.LayerNorm(self.dim, eps=config.layer_norm_eps))
self.apply(self.init_weights)
def _prune_heads(self, heads_to_prune):
""" Prunes heads of the model.
......@@ -636,14 +638,14 @@ class XLMPredLayer(nn.Module):
Prediction layer (cross_entropy or adaptive_softmax).
"""
def __init__(self, config):
super().__init__()
super(XLMPredLayer, self).__init__()
self.asm = config.asm
self.n_words = config.n_words
self.pad_index = config.pad_index
dim = config.emb_dim
if config.asm is False:
self.proj = Linear(dim, config.n_words, bias=True)
self.proj = nn.Linear(dim, config.n_words, bias=True)
else:
self.proj = nn.AdaptiveLogSoftmaxWithLoss(
in_features=dim,
......@@ -653,28 +655,24 @@ class XLMPredLayer(nn.Module):
head_bias=True, # default is False
)
def forward(self, x, y, get_scores=False):
def forward(self, x, y=None):
""" Compute the loss, and optionally the scores.
"""
Compute the loss, and optionally the scores.
"""
assert (y == self.pad_index).sum().item() == 0
outputs = ()
if self.asm is False:
scores = self.proj(x).view(-1, self.n_words)
outputs = (scores,) + outputs
if y is not None:
loss = F.cross_entropy(scores, y, reduction='elementwise_mean')
outputs = (loss,) + outputs
else:
scores = self.proj.log_prob(x)
outputs = (scores,) + outputs
if y is not None:
_, loss = self.proj(x, y)
scores = self.proj.log_prob(x) if get_scores else None
return scores, loss
def get_scores(self, x):
"""
Compute scores.
"""
assert x.dim() == 2
return self.proj.log_prob(x) if self.asm else self.proj(x)
outputs = (loss,) + outputs
return outputs
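
For reference, the adaptive-softmax branch above uses two different entry points of nn.AdaptiveLogSoftmaxWithLoss: calling the module with targets returns an (output, loss) pair, while log_prob returns full per-class log-probabilities for scoring. A toy sketch (the sizes and cutoffs are illustrative assumptions, not this commit's configuration):

import torch
import torch.nn as nn

asm = nn.AdaptiveLogSoftmaxWithLoss(in_features=16, n_classes=1000,
                                    cutoffs=[10, 100], head_bias=True)
x = torch.randn(8, 16)                  # (n_tokens, dim), i.e. already flattened
y = torch.randint(0, 1000, (8,))        # target token ids

_, loss = asm(x, y)                     # training: mean negative log-likelihood of the targets
scores = asm.log_prob(x)                # scoring: (n_tokens, n_classes) log-probabilities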
class XLMWithLMHeadModel(XLMPreTrainedModel):
......@@ -731,6 +729,7 @@ class XLMWithLMHeadModel(XLMPreTrainedModel):
"""
def __init__(self, config):
super(XLMWithLMHeadModel, self).__init__(config)
self.torchscript = config.torchscript
self.transformer = XLMModel(config)
self.pred_layer = XLMPredLayer(config)
......@@ -741,6 +740,9 @@ class XLMWithLMHeadModel(XLMPreTrainedModel):
def tie_weights(self):
""" Make sure we are sharing the embeddings
"""
if self.torchscript:
self.pred_layer.proj.weight = nn.Parameter(self.transformer.embeddings.weight.clone())
else:
self.pred_layer.proj.weight = self.transformer.embeddings.weight
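
The torchscript branch added to tie_weights above clones the embedding matrix instead of sharing it, because TorchScript export (at the time) did not support a single Parameter tied across two modules; in the normal path the output projection and the input embeddings share storage, so they stay identical during training. A toy illustration of the difference (hypothetical layer sizes, not this commit's classes):

import torch.nn as nn

emb = nn.Embedding(100, 16)
proj = nn.Linear(16, 100, bias=False)

# Standard tying: one Parameter object shared by both modules.
proj.weight = emb.weight
assert proj.weight.data_ptr() == emb.weight.data_ptr()

# TorchScript-friendly variant used above: an independent copy of the current values.
proj_ts = nn.Linear(16, 100, bias=False)
proj_ts.weight = nn.Parameter(emb.weight.clone())
assert proj_ts.weight.data_ptr() != emb.weight.data_ptr()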
def forward(self, input_ids, lengths=None, positions=None, langs=None, token_type_ids=None,
......@@ -775,55 +777,12 @@ class XLMWithLMHeadModel(XLMPreTrainedModel):
langs=langs, attention_mask=attention_mask, cache=cache, head_mask=head_mask)
output = transformer_outputs[0]
logits = self.pred_layer(output, labels)
outputs = transformer_outputs[1:] # Keep new_mems and attention/hidden states if they are here
if labels is not None:
# Flatten the tokens
loss_fct = CrossEntropyLoss(ignore_index=-1)
loss = loss_fct(logits.view(-1, logits.size(-1)),
labels.view(-1))
outputs = [loss] + outputs
outputs = [logits] + outputs
outputs = self.pred_layer(output, labels)
outputs = outputs + transformer_outputs[1:] # Keep new_mems and attention/hidden states if they are here
return outputs
class XLMSequenceSummary(nn.Module):
def __init__(self, config):
super(XLMSequenceSummary, self).__init__()
self.summary_type = config.summary_type
if config.use_proj:
self.summary = nn.Linear(config.d_model, config.d_model)
else:
self.summary = None
if config.summary_type == 'attn':
# We should use a standard multi-head attention module with absolute positional embedding for that.
# Cf. https://github.com/zihangdai/xlnet/blob/master/modeling.py#L253-L276
# We can probably just use the multi-head attention module of PyTorch >=1.1.0
raise NotImplementedError
self.dropout = nn.Dropout(config.dropout)
self.activation = nn.Tanh()
def forward(self, hidden_states):
""" hidden_states: float Tensor in shape [bsz, seq_len, d_model], the hidden-states of the last layer."""
if self.summary_type == 'last':
output = hidden_states[:, -1]
elif self.summary_type == 'first':
output = hidden_states[:, 0]
elif self.summary_type == 'mean':
output = hidden_states.mean(dim=1)
elif summary_type == 'attn':
raise NotImplementedError
output = self.summary(output)
output = self.activation(output)
output = self.dropout(output)
return output
class XLMForSequenceClassification(XLMPreTrainedModel):
"""XLM model ("XLM: Generalized Autoregressive Pretraining for Language Understanding").
......@@ -890,15 +849,15 @@ class XLMForSequenceClassification(XLMPreTrainedModel):
"""
def __init__(self, config):
super(XLMForSequenceClassification, self).__init__(config)
self.num_labels = config.num_labels
self.transformer = XLMModel(config)
self.sequence_summary = XLMSequenceSummary(config)
self.logits_proj = nn.Linear(config.d_model, config.num_labels)
self.sequence_summary = SequenceSummary(config)
self.apply(self.init_weights)
def forward(self, input_ids, lengths=None, positions=None, langs=None, attention_mask=None,
cache=None, labels=None, head_mask=None):
def forward(self, input_ids, lengths=None, positions=None, langs=None, token_type_ids=None,
attention_mask=None, cache=None, labels=None, head_mask=None):
"""
Args:
inp_k: int32 Tensor in shape [bsz, len], the input token IDs.
......@@ -930,10 +889,9 @@ class XLMForSequenceClassification(XLMPreTrainedModel):
langs=langs, attention_mask=attention_mask, cache=cache, head_mask=head_mask)
output = transformer_outputs[0]
output = self.sequence_summary(output)
logits = self.logits_proj(output)
logits = self.sequence_summary(output)
outputs = transformer_outputs[1:] # Keep new_mems and attention/hidden states if they are here
outputs = (logits,) + transformer_outputs[1:] # Keep new_mems and attention/hidden states if they are here
if labels is not None:
if self.num_labels == 1:
......@@ -943,9 +901,7 @@ class XLMForSequenceClassification(XLMPreTrainedModel):
else:
loss_fct = CrossEntropyLoss()
loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
outputs = [loss] + outputs
outputs = [logits] + outputs
outputs = (loss,) + outputs
return outputs
......@@ -1010,41 +966,22 @@ class XLMForQuestionAnswering(XLMPreTrainedModel):
super(XLMForQuestionAnswering, self).__init__(config)
self.transformer = XLMModel(config)
self.qa_outputs = nn.Linear(config.hidden_size, config.num_labels)
self.qa_outputs = SQuADHead(config)
self.apply(self.init_weights)
def forward(self, input_ids, lengths=None, positions=None, langs=None, attention_mask=None, cache=None,
labels=None, head_mask=None):
def forward(self, input_ids, lengths=None, positions=None, langs=None, token_type_ids=None,
attention_mask=None, cache=None, start_positions=None, end_positions=None,
cls_index=None, is_impossible=None, p_mask=None, head_mask=None):
transformer_outputs = self.transformer(input_ids, lengths=lengths, positions=positions, token_type_ids=token_type_ids,
langs=langs, attention_mask=attention_mask, cache=cache, head_mask=head_mask)
output = transformer_outputs[0]
logits = self.qa_outputs(output)
start_logits, end_logits = logits.split(1, dim=-1)
start_logits = start_logits.squeeze(-1)
end_logits = end_logits.squeeze(-1)
outputs = transformer_outputs[1:] # Keep new_mems and attention/hidden states if they are here
if start_positions is not None and end_positions is not None:
# If we are on multi-GPU, the batch split adds a dimension that we squeeze here
if len(start_positions.size()) > 1:
start_positions = start_positions.squeeze(-1)
if len(end_positions.size()) > 1:
end_positions = end_positions.squeeze(-1)
# sometimes the start/end positions are outside our model inputs, we ignore these terms
ignored_index = start_logits.size(1)
start_positions.clamp_(0, ignored_index)
end_positions.clamp_(0, ignored_index)
loss_fct = CrossEntropyLoss(ignore_index=ignored_index)
start_loss = loss_fct(start_logits, start_positions)
end_loss = loss_fct(end_logits, end_positions)
total_loss = (start_loss + end_loss) / 2
outputs = [total_loss] + outputs
outputs = [start_logits, end_logits] + outputs
outputs = self.qa_outputs(output, start_positions=start_positions, end_positions=end_positions,
cls_index=cls_index, is_impossible=is_impossible, p_mask=p_mask)
outputs = outputs + transformer_outputs[1:] # Keep new_mems and attention/hidden states if they are here
return outputs
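
The p_mask argument threaded through the new signature follows the convention documented in the pooler classes earlier in this commit: 1.0 marks positions that can never be part of the answer span (question tokens and special symbols), 0.0 marks candidate context tokens. A purely hypothetical sketch of how a caller might build it for one packed question+context example (the segment lengths are made up for illustration):

import torch

q_len, c_len = 4, 9                     # hypothetical question / context lengths
p_mask = torch.cat([
    torch.ones(1),                      # leading special symbol: masked
    torch.ones(q_len),                  # question tokens: masked
    torch.ones(1),                      # separator: masked
    torch.zeros(c_len),                 # context tokens: valid answer positions
    torch.ones(1),                      # trailing separator: masked
]).unsqueeze(0)                         # shape (1, seq_len), batched like the model inputs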
......@@ -32,8 +32,8 @@ from torch.nn import functional as F
from torch.nn import CrossEntropyLoss, MSELoss
from .file_utils import cached_path
from .model_utils import (CONFIG_NAME, WEIGHTS_NAME,
PretrainedConfig, PreTrainedModel, SequenceSummary)
from .model_utils import (CONFIG_NAME, WEIGHTS_NAME, PretrainedConfig, PreTrainedModel,
SequenceSummary, PoolerAnswerClass, PoolerEndLogits, PoolerStartLogits)
logger = logging.getLogger(__name__)
......@@ -228,6 +228,8 @@ class XLNetConfig(PretrainedConfig):
summary_use_proj=True,
summary_activation='tanh',
summary_dropout=0.1,
start_n_top=5,
end_n_top=5,
**kwargs):
"""Constructs XLNetConfig.
......@@ -313,6 +315,8 @@ class XLNetConfig(PretrainedConfig):
self.summary_use_proj = summary_use_proj
self.summary_activation = summary_activation
self.summary_dropout = summary_dropout
self.start_n_top = start_n_top
self.end_n_top = end_n_top
else:
raise ValueError("First argument must be either a vocabulary size (int)"
"or the path to a pretrained model config file (str)")
......@@ -1114,6 +1118,7 @@ class XLNetForSequenceClassification(XLNetPreTrainedModel):
"""
def __init__(self, config):
super(XLNetForSequenceClassification, self).__init__(config)
self.num_labels = config.num_labels
self.transformer = XLNetModel(config)
self.sequence_summary = SequenceSummary(config)
......@@ -1174,7 +1179,7 @@ class XLNetForSequenceClassification(XLNetPreTrainedModel):
class XLNetForQuestionAnswering(XLNetPreTrainedModel):
"""XLNet model for Question Answering (span extraction).
""" XLNet model for Question Answering (span extraction).
This module is composed of the XLNet model with a linear layer on top of
the sequence output that computes start_logits and end_logits
......@@ -1231,41 +1236,78 @@ class XLNetForQuestionAnswering(XLNetPreTrainedModel):
"""
def __init__(self, config):
super(XLNetForQuestionAnswering, self).__init__(config)
self.start_n_top = config.start_n_top
self.end_n_top = config.end_n_top
self.transformer = XLNetModel(config)
self.qa_outputs = nn.Linear(config.d_model, config.num_labels)
self.start_logits = PoolerStartLogits(config)
self.end_logits = PoolerEndLogits(config)
self.answer_class = PoolerAnswerClass(config)
self.apply(self.init_weights)
def forward(self, input_ids, token_type_ids=None, input_mask=None, attention_mask=None,
mems=None, perm_mask=None, target_mapping=None, inp_q=None,
start_positions=None, end_positions=None, head_mask=None):
start_positions=None, end_positions=None, cls_index=None, is_impossible=None, p_mask=None,
head_mask=None):
transformer_outputs = self.transformer(input_ids, token_type_ids, input_mask, attention_mask,
mems, perm_mask, target_mapping, inp_q, head_mask)
hidden_states = transformer_outputs[0]
start_logits = self.start_logits(hidden_states, p_mask)
logits = self.qa_outputs(transformer_outputs[0])
outputs = transformer_outputs[1:] # Keep mems, hidden states, attentions if there are in it
start_logits, end_logits = logits.split(1, dim=-1)
start_logits = start_logits.squeeze(-1)
end_logits = end_logits.squeeze(-1)
if start_positions is not None and end_positions is not None:
# If we are on multi-GPU, let's remove the dimension added by batch splitting
for x in (start_positions, end_positions, cls_index, is_impossible):
if x is not None and x.dim() > 1:
x.squeeze_(-1)
outputs = (start_logits, end_logits,) + transformer_outputs[1:] # Keep mems, hidden states, attentions if there are in it
# during training, compute the end logits based on the ground truth of the start position
end_logits = self.end_logits(hidden_states, start_positions=start_positions, p_mask=p_mask)
if start_positions is not None and end_positions is not None:
# If we are on multi-GPU, the batch split adds a dimension that we squeeze here
if len(start_positions.size()) > 1:
start_positions = start_positions.squeeze(-1)
if len(end_positions.size()) > 1:
end_positions = end_positions.squeeze(-1)
# sometimes the start/end positions are outside our model inputs, we ignore these terms
ignored_index = start_logits.size(1)
start_positions.clamp_(0, ignored_index)
end_positions.clamp_(0, ignored_index)
loss_fct = CrossEntropyLoss(ignore_index=ignored_index)
loss_fct = CrossEntropyLoss()
start_loss = loss_fct(start_logits, start_positions)
end_loss = loss_fct(end_logits, end_positions)
total_loss = (start_loss + end_loss) / 2
outputs = (total_loss,) + outputs
return outputs # return (loss), logits, (mems), (hidden states), (attentions)
if cls_index is not None and is_impossible is not None:
# Predict answerability from the representation of CLS and START
cls_logits = self.answer_class(hidden_states, start_positions=start_positions, cls_index=cls_index)
loss_fct_cls = nn.BCEWithLogitsLoss()
cls_loss = loss_fct_cls(cls_logits, is_impossible)
# note(zhiliny): by default multiply the loss by 0.5 so that the scale is
# comparable to start_loss and end_loss
total_loss += cls_loss * 0.5
outputs = (total_loss, start_logits, end_logits, cls_logits) + outputs
else:
outputs = (total_loss, start_logits, end_logits) + outputs
else:
# during inference, compute the end logits based on beam search
bsz, slen, hsz = hidden_states.size()
start_log_probs = F.softmax(start_logits, dim=-1) # shape (bsz, slen)
start_top_log_probs, start_top_index = torch.topk(start_log_probs, self.start_n_top, dim=-1) # shape (bsz, start_n_top)
start_top_index = start_top_index.unsqueeze(-1).expand(-1, -1, hsz) # shape (bsz, start_n_top, hsz)
start_states = torch.gather(hidden_states, -2, start_top_index) # shape (bsz, start_n_top, hsz)
start_states = start_states.unsqueeze(1).expand(-1, slen, -1, -1) # shape (bsz, slen, start_n_top, hsz)
hidden_states_expanded = hidden_states.unsqueeze(2).expand_as(start_states) # shape (bsz, slen, start_n_top, hsz)
p_mask = p_mask.unsqueeze(-1) if p_mask is not None else None
end_logits = self.end_logits(hidden_states_expanded, start_states=start_states, p_mask=p_mask)
end_log_probs = F.softmax(end_logits, dim=1) # shape (bsz, slen, start_n_top)
end_top_log_probs, end_top_index = torch.topk(end_log_probs, self.end_n_top, dim=1) # shape (bsz, end_n_top, start_n_top)
end_top_log_probs = end_top_log_probs.view(-1, self.start_n_top * self.end_n_top)
end_top_index = end_top_index.view(-1, self.start_n_top * self.end_n_top)
start_states = torch.einsum("blh,bl->bh", hidden_states, start_log_probs)
cls_logits = self.answer_class(hidden_states, start_states=start_states, cls_index=cls_index)
outputs = (start_top_log_probs, start_top_index, end_top_log_probs, end_top_index, cls_logits) + outputs
# return start_top_log_probs, start_top_index, end_top_log_probs, end_top_index, cls_logits, mems, (hidden states), (attentions)
# or (if labels are provided) total_loss, start_logits, end_logits, (cls_logits), mems, (hidden states), (attentions)
return outputs
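
One step in the inference branch above that is easy to miss: torch.einsum("blh,bl->bh", hidden_states, start_log_probs) computes a probability-weighted average of the hidden states over sequence positions, i.e. an expected start representation that is fed to the answerability classifier (note that, despite the variable name, start_log_probs holds softmax probabilities, not log-probabilities). A small check of the equivalence with toy tensors (not part of this commit):

import torch

bsz, slen, hsz = 2, 5, 4
hidden_states = torch.randn(bsz, slen, hsz)
start_probs = torch.softmax(torch.randn(bsz, slen), dim=-1)

expected = torch.einsum("blh,bl->bh", hidden_states, start_probs)
manual = (hidden_states * start_probs.unsqueeze(-1)).sum(dim=1)   # the same weighted sum, spelled out
assert torch.allclose(expected, manual)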
......@@ -16,11 +16,7 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import os
import unittest
import json
import random
import shutil
import pytest
import torch
......
......@@ -20,7 +20,7 @@ import unittest
import shutil
import pytest
from pytorch_pretrained_bert import (XLMConfig, XLMModel, XLMForQuestionAnswering, XLMForSequenceClassification)
from pytorch_pretrained_bert import (XLMConfig, XLMModel, XLMWithLMHeadModel, XLMForQuestionAnswering, XLMForSequenceClassification)
from pytorch_pretrained_bert.modeling_xlm import PRETRAINED_MODEL_ARCHIVE_MAP
from .model_tests_commons import (create_and_check_commons, ConfigTester, ids_tensor)
......@@ -58,7 +58,8 @@ class XLMModelTest(unittest.TestCase):
summary_type="last",
use_proj=True,
scope=None,
all_model_classes = (XLMModel,), # , XLMForSequenceClassification, XLMForTokenClassification),
all_model_classes = (XLMModel, XLMWithLMHeadModel,
XLMForQuestionAnswering, XLMForSequenceClassification), # , XLMForSequenceClassification, XLMForTokenClassification),
):
self.parent = parent
self.batch_size = batch_size
......@@ -93,6 +94,7 @@ class XLMModelTest(unittest.TestCase):
def prepare_config_and_inputs(self):
input_ids = ids_tensor([self.batch_size, self.seq_length], self.vocab_size)
input_mask = ids_tensor([self.batch_size, self.seq_length], 2).float()
input_lengths = None
if self.use_input_lengths:
......@@ -104,11 +106,11 @@ class XLMModelTest(unittest.TestCase):
sequence_labels = None
token_labels = None
choice_labels = None
is_impossible_labels = None
if self.use_labels:
sequence_labels = ids_tensor([self.batch_size], self.type_sequence_label_size)
token_labels = ids_tensor([self.batch_size, self.seq_length], self.num_labels)
choice_labels = ids_tensor([self.batch_size], self.num_choices)
is_impossible_labels = ids_tensor([self.batch_size], 2).float()
config = XLMConfig(
vocab_size_or_config_json_file=self.vocab_size,
......@@ -128,14 +130,14 @@ class XLMModelTest(unittest.TestCase):
summary_type=self.summary_type,
use_proj=self.use_proj)
return config, input_ids, token_type_ids, input_lengths, sequence_labels, token_labels, choice_labels
return config, input_ids, token_type_ids, input_lengths, sequence_labels, token_labels, is_impossible_labels, input_mask
def check_loss_output(self, result):
self.parent.assertListEqual(
list(result["loss"].size()),
[])
def create_and_check_xlm_model(self, config, input_ids, token_type_ids, input_lengths, sequence_labels, token_labels, choice_labels):
def create_and_check_xlm_model(self, config, input_ids, token_type_ids, input_lengths, sequence_labels, token_labels, is_impossible_labels, input_mask):
model = XLMModel(config=config)
model.eval()
outputs = model(input_ids, lengths=input_lengths, langs=token_type_ids)
......@@ -150,90 +152,92 @@ class XLMModelTest(unittest.TestCase):
[self.batch_size, self.seq_length, self.hidden_size])
# def create_and_check_xlm_for_masked_lm(self, config, input_ids, token_type_ids, input_lengths, sequence_labels, token_labels, choice_labels):
# model = XLMForMaskedLM(config=config)
# model.eval()
# loss, prediction_scores = model(input_ids, token_type_ids, input_lengths, token_labels)
# result = {
# "loss": loss,
# "prediction_scores": prediction_scores,
# }
# self.parent.assertListEqual(
# list(result["prediction_scores"].size()),
# [self.batch_size, self.seq_length, self.vocab_size])
# self.check_loss_output(result)
# def create_and_check_xlm_for_question_answering(self, config, input_ids, token_type_ids, input_lengths, sequence_labels, token_labels, choice_labels):
# model = XLMForQuestionAnswering(config=config)
# model.eval()
# loss, start_logits, end_logits = model(input_ids, token_type_ids, input_lengths, sequence_labels, sequence_labels)
# result = {
# "loss": loss,
# "start_logits": start_logits,
# "end_logits": end_logits,
# }
# self.parent.assertListEqual(
# list(result["start_logits"].size()),
# [self.batch_size, self.seq_length])
# self.parent.assertListEqual(
# list(result["end_logits"].size()),
# [self.batch_size, self.seq_length])
# self.check_loss_output(result)
# def create_and_check_xlm_for_sequence_classification(self, config, input_ids, token_type_ids, input_lengths, sequence_labels, token_labels, choice_labels):
# config.num_labels = self.num_labels
# model = XLMForSequenceClassification(config)
# model.eval()
# loss, logits = model(input_ids, token_type_ids, input_lengths, sequence_labels)
# result = {
# "loss": loss,
# "logits": logits,
# }
# self.parent.assertListEqual(
# list(result["logits"].size()),
# [self.batch_size, self.num_labels])
# self.check_loss_output(result)
# def create_and_check_xlm_for_token_classification(self, config, input_ids, token_type_ids, input_lengths, sequence_labels, token_labels, choice_labels):
# config.num_labels = self.num_labels
# model = XLMForTokenClassification(config=config)
# model.eval()
# loss, logits = model(input_ids, token_type_ids, input_lengths, token_labels)
# result = {
# "loss": loss,
# "logits": logits,
# }
# self.parent.assertListEqual(
# list(result["logits"].size()),
# [self.batch_size, self.seq_length, self.num_labels])
# self.check_loss_output(result)
# def create_and_check_xlm_for_multiple_choice(self, config, input_ids, token_type_ids, input_lengths, sequence_labels, token_labels, choice_labels):
# config.num_choices = self.num_choices
# model = XLMForMultipleChoice(config=config)
# model.eval()
# multiple_choice_inputs_ids = input_ids.unsqueeze(1).expand(-1, self.num_choices, -1).contiguous()
# multiple_choice_token_type_ids = token_type_ids.unsqueeze(1).expand(-1, self.num_choices, -1).contiguous()
# multiple_choice_input_lengths = input_lengths.unsqueeze(1).expand(-1, self.num_choices, -1).contiguous()
# loss, logits = model(multiple_choice_inputs_ids,
# multiple_choice_token_type_ids,
# multiple_choice_input_lengths,
# choice_labels)
# result = {
# "loss": loss,
# "logits": logits,
# }
# self.parent.assertListEqual(
# list(result["logits"].size()),
# [self.batch_size, self.num_choices])
# self.check_loss_output(result)
def create_and_check_xlm_commons(self, config, input_ids, token_type_ids, input_lengths, sequence_labels, token_labels, choice_labels):
def create_and_check_xlm_lm_head(self, config, input_ids, token_type_ids, input_lengths, sequence_labels, token_labels, is_impossible_labels, input_mask):
model = XLMWithLMHeadModel(config)
model.eval()
loss, logits = model(input_ids, token_type_ids=token_type_ids, labels=token_labels)
result = {
"loss": loss,
"logits": logits,
}
self.parent.assertListEqual(
list(result["loss"].size()),
[])
self.parent.assertListEqual(
list(result["logits"].size()),
[self.batch_size, self.seq_length, self.vocab_size])
def create_and_check_xlm_qa(self, config, input_ids, token_type_ids, input_lengths, sequence_labels, token_labels, is_impossible_labels, input_mask):
model = XLMForQuestionAnswering(config)
model.eval()
outputs = model(input_ids)
start_top_log_probs, start_top_index, end_top_log_probs, end_top_index, cls_logits, mems = outputs
outputs = model(input_ids, start_positions=sequence_labels,
end_positions=sequence_labels,
cls_index=sequence_labels,
is_impossible=is_impossible_labels,
p_mask=input_mask)
outputs = model(input_ids, start_positions=sequence_labels,
end_positions=sequence_labels,
cls_index=sequence_labels,
is_impossible=is_impossible_labels)
total_loss, start_logits, end_logits, cls_logits = outputs
outputs = model(input_ids, start_positions=sequence_labels,
end_positions=sequence_labels)
total_loss, start_logits, end_logits = outputs
result = {
"loss": total_loss,
"start_logits": start_logits,
"end_logits": end_logits,
"cls_logits": cls_logits,
}
self.parent.assertListEqual(
list(result["loss"].size()),
[])
self.parent.assertListEqual(
list(result["start_logits"].size()),
[self.batch_size, self.seq_length])
self.parent.assertListEqual(
list(result["end_logits"].size()),
[self.batch_size, self.seq_length])
self.parent.assertListEqual(
list(result["cls_logits"].size()),
[self.batch_size])
def create_and_check_xlm_sequence_classif(self, config, input_ids, token_type_ids, input_lengths, sequence_labels, token_labels, is_impossible_labels, input_mask):
model = XLMForSequenceClassification(config)
model.eval()
(logits,) = model(input_ids)
loss, logits = model(input_ids, labels=sequence_labels)
result = {
"loss": loss,
"logits": logits,
}
self.parent.assertListEqual(
list(result["loss"].size()),
[])
self.parent.assertListEqual(
list(result["logits"].size()),
[self.batch_size, self.type_sequence_label_size])
def create_and_check_xlm_commons(self, config, input_ids, token_type_ids, input_lengths, sequence_labels, token_labels, is_impossible_labels, input_mask):
inputs_dict = {'input_ids': input_ids, 'token_type_ids': token_type_ids, 'lengths': input_lengths}
create_and_check_commons(self, config, inputs_dict)
......
......@@ -49,6 +49,7 @@ class XLNetModelTest(unittest.TestCase):
d_inner=128,
num_hidden_layers=5,
max_position_embeddings=10,
type_sequence_label_size=2,
untie_r=True,
bi_data=False,
same_length=False,
......@@ -80,12 +81,14 @@ class XLNetModelTest(unittest.TestCase):
self.initializer_range = initializer_range
self.seed = seed
self.type_vocab_size = type_vocab_size
self.type_sequence_label_size = type_sequence_label_size
self.all_model_classes = all_model_classes
def prepare_config_and_inputs(self):
input_ids_1 = ids_tensor([self.batch_size, self.seq_length], self.vocab_size)
input_ids_2 = ids_tensor([self.batch_size, self.seq_length], self.vocab_size)
segment_ids = ids_tensor([self.batch_size, self.seq_length], self.type_vocab_size)
input_mask = ids_tensor([self.batch_size, self.seq_length], 2).float()
input_ids_q = ids_tensor([self.batch_size, self.seq_length + 1], self.vocab_size)
perm_mask = torch.zeros(self.batch_size, self.seq_length + 1, self.seq_length + 1, dtype=torch.float)
......@@ -94,30 +97,13 @@ class XLNetModelTest(unittest.TestCase):
target_mapping[:, 0, -1] = 1.0 # predict last token
inp_q = target_mapping[:, 0, :].clone() # predict last token
# inp_k: int32 Tensor in shape [bsz, len], the input token IDs.
# token_type_ids: int32 Tensor in shape [bsz, len], the input segment IDs.
# input_mask: float32 Tensor in shape [bsz, len], the input mask.
# 0 for real tokens and 1 for padding.
# mems: a list of float32 Tensors in shape [bsz, mem_len, hidden_size], memory
# from previous batches. The length of the list equals num_hidden_layers.
# If None, no memory is used.
# perm_mask: float32 Tensor in shape [bsz, len, len].
# If perm_mask[k, i, j] = 0, i attend to j in batch k;
# if perm_mask[k, i, j] = 1, i does not attend to j in batch k.
# If None, each position attends to all the others.
# target_mapping: float32 Tensor in shape [bsz, num_predict, len].
# If target_mapping[k, i, j] = 1, the i-th predict in batch k is
# on the j-th token.
# Only used during pretraining for partial prediction.
# Set to None during finetuning.
# inp_q: float32 Tensor in shape [bsz, len].
# 1 for tokens with losses and 0 for tokens without losses.
# Only used during pretraining for two-stream attention.
# Set to None during finetuning.
sequence_labels = None
lm_labels = None
is_impossible_labels = None
if self.use_labels:
lm_labels = ids_tensor([self.batch_size, self.seq_length], self.vocab_size)
sequence_labels = ids_tensor([self.batch_size], self.type_sequence_label_size)
is_impossible_labels = ids_tensor([self.batch_size], 2).float()
config = XLNetConfig(
vocab_size_or_config_json_file=self.vocab_size,
......@@ -132,18 +118,23 @@ class XLNetModelTest(unittest.TestCase):
same_length=self.same_length,
reuse_len=self.reuse_len,
bi_data=self.bi_data,
initializer_range=self.initializer_range)
initializer_range=self.initializer_range,
num_labels=self.type_sequence_label_size)
return (config, input_ids_1, input_ids_2, input_ids_q, perm_mask, target_mapping, inp_q, segment_ids, lm_labels)
return (config, input_ids_1, input_ids_2, input_ids_q, perm_mask, input_mask,
target_mapping, inp_q, segment_ids, lm_labels, sequence_labels, is_impossible_labels)
def set_seed(self):
random.seed(self.seed)
torch.manual_seed(self.seed)
def create_and_check_xlnet_base_model(self, config, input_ids_1, input_ids_2, input_ids_q, perm_mask, target_mapping, inp_q, segment_ids, lm_labels):
def create_and_check_xlnet_base_model(self, config, input_ids_1, input_ids_2, input_ids_q, perm_mask, input_mask,
target_mapping, inp_q, segment_ids, lm_labels, sequence_labels, is_impossible_labels):
model = XLNetModel(config)
model.eval()
_, _ = model(input_ids_1, input_mask=input_mask)
_, _ = model(input_ids_1, attention_mask=input_mask)
_, _ = model(input_ids_1, token_type_ids=segment_ids)
outputs, mems_1 = model(input_ids_1)
......@@ -159,7 +150,8 @@ class XLNetModelTest(unittest.TestCase):
list(list(mem.size()) for mem in result["mems_1"]),
[[self.seq_length, self.batch_size, self.hidden_size]] * self.num_hidden_layers)
def create_and_check_xlnet_lm_head(self, config, input_ids_1, input_ids_2, input_ids_q, perm_mask, target_mapping, inp_q, segment_ids, lm_labels):
def create_and_check_xlnet_lm_head(self, config, input_ids_1, input_ids_2, input_ids_q, perm_mask, input_mask,
target_mapping, inp_q, segment_ids, lm_labels, sequence_labels, is_impossible_labels):
model = XLNetLMHeadModel(config)
model.eval()
......@@ -198,7 +190,82 @@ class XLNetModelTest(unittest.TestCase):
list(list(mem.size()) for mem in result["mems_2"]),
[[self.mem_len, self.batch_size, self.hidden_size]] * self.num_hidden_layers)
def create_and_check_xlnet_commons(self, config, input_ids_1, input_ids_2, input_ids_q, perm_mask, target_mapping, inp_q, segment_ids, lm_labels):
def create_and_check_xlnet_qa(self, config, input_ids_1, input_ids_2, input_ids_q, perm_mask, input_mask,
target_mapping, inp_q, segment_ids, lm_labels, sequence_labels, is_impossible_labels):
model = XLNetForQuestionAnswering(config)
model.eval()
outputs = model(input_ids_1)
start_top_log_probs, start_top_index, end_top_log_probs, end_top_index, cls_logits, mems = outputs
outputs = model(input_ids_1, start_positions=sequence_labels,
end_positions=sequence_labels,
cls_index=sequence_labels,
is_impossible=is_impossible_labels,
p_mask=input_mask)
outputs = model(input_ids_1, start_positions=sequence_labels,
end_positions=sequence_labels,
cls_index=sequence_labels,
is_impossible=is_impossible_labels)
total_loss, start_logits, end_logits, cls_logits, mems = outputs
outputs = model(input_ids_1, start_positions=sequence_labels,
end_positions=sequence_labels)
total_loss, start_logits, end_logits, mems = outputs
result = {
"loss": total_loss,
"start_logits": start_logits,
"end_logits": end_logits,
"cls_logits": cls_logits,
"mems": mems,
}
self.parent.assertListEqual(
list(result["loss"].size()),
[])
self.parent.assertListEqual(
list(result["start_logits"].size()),
[self.batch_size, self.seq_length])
self.parent.assertListEqual(
list(result["end_logits"].size()),
[self.batch_size, self.seq_length])
self.parent.assertListEqual(
list(result["cls_logits"].size()),
[self.batch_size])
self.parent.assertListEqual(
list(list(mem.size()) for mem in result["mems"]),
[[self.seq_length, self.batch_size, self.hidden_size]] * self.num_hidden_layers)
def create_and_check_xlnet_sequence_classif(self, config, input_ids_1, input_ids_2, input_ids_q, perm_mask, input_mask,
target_mapping, inp_q, segment_ids, lm_labels, sequence_labels, is_impossible_labels):
model = XLNetForSequenceClassification(config)
model.eval()
logits, mems_1 = model(input_ids_1)
loss, logits, mems_1 = model(input_ids_1, labels=sequence_labels)
result = {
"loss": loss,
"mems_1": mems_1,
"logits": logits,
}
self.parent.assertListEqual(
list(result["loss"].size()),
[])
self.parent.assertListEqual(
list(result["logits"].size()),
[self.batch_size, self.type_sequence_label_size])
self.parent.assertListEqual(
list(list(mem.size()) for mem in result["mems_1"]),
[[self.seq_length, self.batch_size, self.hidden_size]] * self.num_hidden_layers)
def create_and_check_xlnet_commons(self, config, input_ids_1, input_ids_2, input_ids_q, perm_mask, input_mask,
target_mapping, inp_q, segment_ids, lm_labels, sequence_labels, is_impossible_labels):
inputs_dict = {'input_ids': input_ids_1}
create_and_check_commons(self, config, inputs_dict, test_pruning=False)
......@@ -228,23 +295,15 @@ class XLNetModelTest(unittest.TestCase):
tester.set_seed()
config_and_inputs = tester.prepare_config_and_inputs()
tester.create_and_check_xlnet_commons(*config_and_inputs)
@classmethod
def mask_tensor(cls, shape, vocab_size, rng=None, name=None):
"""Creates a tensor with padding on the right (0.0 for )."""
if rng is None:
rng = random.Random()
total_dims = 1
for dim in shape:
total_dims *= dim
tester.create_and_check_xlnet_sequence_classif(*config_and_inputs)
values = []
for _ in range(total_dims):
values.append(rng.randint(0, vocab_size - 1))
tester.set_seed()
config_and_inputs = tester.prepare_config_and_inputs()
tester.create_and_check_xlnet_qa(*config_and_inputs)
return torch.tensor(data=values, dtype=torch.long).view(shape).contiguous()
tester.set_seed()
config_and_inputs = tester.prepare_config_and_inputs()
tester.create_and_check_xlnet_commons(*config_and_inputs)
if __name__ == "__main__":
......