Commit 465870c3 authored by thomwolf

Xlnet working - also added simple question answering model for XLNet

parent 16b63617
@@ -68,7 +68,8 @@ if _torch_available:
                                GPT2LMHeadModel, GPT2DoubleHeadsModel,
                                load_tf_weights_in_gpt2, GPT2_PRETRAINED_MODEL_ARCHIVE_MAP)
     from .modeling_xlnet import (XLNetPreTrainedModel, XLNetModel, XLNetLMHeadModel,
-                                 XLNetForSequenceClassification, XLNetForQuestionAnswering,
+                                 XLNetForSequenceClassification, XLNetForQuestionAnsweringSimple,
+                                 XLNetForQuestionAnswering,
                                  load_tf_weights_in_xlnet, XLNET_PRETRAINED_MODEL_ARCHIVE_MAP)
     from .modeling_xlm import (XLMPreTrainedModel , XLMModel,
                                XLMWithLMHeadModel, XLMForSequenceClassification,
@@ -112,6 +113,12 @@ if _tf_available:
                                     load_gpt2_pt_weights_in_tf2,
                                     TF_GPT2_PRETRAINED_MODEL_ARCHIVE_MAP)
+    from .modeling_tf_xlnet import (TFXLNetPreTrainedModel, TFXLNetMainLayer,
+                                    TFXLNetModel, TFXLNetLMHeadModel,
+                                    TFXLNetForSequenceClassification,
+                                    TFXLNetForQuestionAnsweringSimple,
+                                    load_xlnet_pt_weights_in_tf2,
+                                    TF_XLNET_PRETRAINED_MODEL_ARCHIVE_MAP)

 # Files and general utilities
 from .file_utils import (PYTORCH_TRANSFORMERS_CACHE, PYTORCH_PRETRAINED_BERT_CACHE,
......
@@ -175,7 +175,7 @@ class PretrainedConfig(object):
         """Constructs a `Config` from a Python dictionary of parameters."""
         config = cls(vocab_size_or_config_json_file=-1)
         for key, value in json_object.items():
-            config.__dict__[key] = value
+            setattr(config, key, value)
         return config

     @classmethod
......
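Why `setattr` rather than writing into `__dict__` directly: `setattr` goes through any property setters or a custom `__setattr__` defined on the config class, while a direct `__dict__` write silently bypasses them. A minimal standalone sketch (illustrative only, not code from this repository):

class Config(object):
    def __init__(self):
        self._layers = 12

    @property
    def num_layers(self):
        return self._layers

    @num_layers.setter
    def num_layers(self, value):
        if value <= 0:
            raise ValueError("num_layers must be positive")
        self._layers = value


config = Config()

setattr(config, "num_layers", 24)       # routed through the validating setter
print(config.num_layers)                # 24

config.__dict__["num_layers"] = -1      # bypasses the setter entirely
print(config.num_layers)                # still 24: the property shadows the stray __dict__ entry
print(config.__dict__["num_layers"])    # -1: an inconsistent entry the setter never saw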
@@ -112,7 +112,7 @@ class XLNetConfig(PretrainedConfig):
             with open(vocab_size_or_config_json_file, "r", encoding='utf-8') as reader:
                 json_config = json.loads(reader.read())
             for key, value in json_config.items():
-                self.__dict__[key] = value
+                setattr(self, key, value)
         elif isinstance(vocab_size_or_config_json_file, int):
             self.n_token = vocab_size_or_config_json_file
             self.d_model = d_model
......
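The branch above gives `XLNetConfig` two construction paths: a string argument is read as a JSON file whose keys become config attributes, and an int is taken as the vocabulary size (`n_token`). A rough sketch of both paths, assuming `pytorch_transformers` with this commit installed and default values for the remaining constructor arguments:

import json

from pytorch_transformers import XLNetConfig

# Path 1: an int is interpreted as the vocabulary size.
config_from_int = XLNetConfig(32000)

# Path 2: a string is treated as a path to a JSON config file; every key in the
# file is set as an attribute on the config object.
with open("xlnet_config.json", "w", encoding="utf-8") as f:
    json.dump({"n_token": 32000, "d_model": 1024}, f)

config_from_json = XLNetConfig("xlnet_config.json")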
@@ -24,12 +24,13 @@ import tensorflow as tf

 from pytorch_transformers import is_torch_available
 from pytorch_transformers import (BertConfig, TFBertForPreTraining, load_bert_pt_weights_in_tf2,
-                                  GPT2Config, TFGPT2LMHeadModel, load_gpt2_pt_weights_in_tf2)
+                                  GPT2Config, TFGPT2LMHeadModel, load_gpt2_pt_weights_in_tf2,
+                                  XLNetConfig, TFXLNetLMHeadModel, load_xlnet_pt_weights_in_tf2)

 if is_torch_available():
     import torch
     import numpy as np
-    from pytorch_transformers import BertForPreTraining, GPT2LMHeadModel
+    from pytorch_transformers import BertForPreTraining, GPT2LMHeadModel, XLNetLMHeadModel
 else:
     BertForPreTraining, GPT2LMHeadModel = None, None

@@ -40,6 +41,7 @@ logging.basicConfig(level=logging.INFO)

 MODEL_CLASSES = {
     'bert': (BertConfig, TFBertForPreTraining, load_bert_pt_weights_in_tf2, BertForPreTraining),
     'gpt2': (GPT2Config, TFGPT2LMHeadModel, load_gpt2_pt_weights_in_tf2, GPT2LMHeadModel),
+    'xlnet': (XLNetConfig, TFXLNetLMHeadModel, load_xlnet_pt_weights_in_tf2, XLNetLMHeadModel),
 }

 def convert_pt_checkpoint_to_tf(model_type, pytorch_checkpoint_path, config_file, tf_dump_path, compare_with_pt_model=False):

@@ -50,6 +52,8 @@ def convert_pt_checkpoint_to_tf(model_type, pytorch_checkpoint_path, config_file, tf_dump_path, compare_with_pt_model=False):

     # Initialise TF model
     config = config_class.from_json_file(config_file)
+    config.output_hidden_states = True
+    config.output_attentions = True
     print("Building TensorFlow model from configuration: {}".format(str(config)))
     tf_model = model_class(config)
......
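With the new 'xlnet' entry in MODEL_CLASSES, the conversion helper can build a TF XLNet model from a PyTorch checkpoint. A hedged example call: the three paths are placeholders, only the function signature and the 'xlnet' key come from this diff, and the call is assumed to run inside the conversion script where the function is defined.

convert_pt_checkpoint_to_tf(model_type='xlnet',
                            pytorch_checkpoint_path='/path/to/xlnet/pytorch_model.bin',  # placeholder
                            config_file='/path/to/xlnet/config.json',                    # placeholder
                            tf_dump_path='/path/to/xlnet/tf_model.h5',                   # placeholder
                            compare_with_pt_model=True)  # also run the PyTorch model and compare outputs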
@@ -83,7 +83,7 @@ def load_bert_pt_weights_in_tf2(tf_model, config, pytorch_checkpoint_path):
         name = name.replace('cls_mlm', 'cls')  # We had to split this layer in two in the TF model to be
         name = name.replace('cls_nsp', 'cls')  # able to do transfer learning (Keras only allow to remove full layers)
         name = name.replace(':0', '')
-        name = name.replace('layer_', 'layer/')
+        name = name.replace('__', '/')
         name = name.split('/')
         name = name[1:]

@@ -391,7 +391,7 @@ class TFBertEncoder(tf.keras.layers.Layer):
         super(TFBertEncoder, self).__init__(**kwargs)
         self.output_attentions = config.output_attentions
         self.output_hidden_states = config.output_hidden_states
-        self.layer = [TFBertLayer(config, name='layer_{}'.format(i)) for i in range(config.num_hidden_layers)]
+        self.layer = [TFBertLayer(config, name='layer__{}'.format(i)) for i in range(config.num_hidden_layers)]

     def call(self, inputs, training=False):
         hidden_states, attention_mask, head_mask = inputs
......
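The two changes above belong together: Keras layer names cannot contain '/', so the i-th encoder block is now named 'layer__{i}' (double underscore), and the PyTorch-weight loader turns '__' back into '/' to recover a path it can match against PyTorch parameter names. A small standalone illustration of the renaming (the variable name string is made up for the example; exact scoping depends on the model instance):

tf_variable_name = 'tf_bert_for_pre_training/bert/encoder/layer__5/attention/self/query/kernel:0'

name = tf_variable_name
name = name.replace(':0', '')    # drop the TensorFlow output index
name = name.replace('__', '/')   # 'layer__5' -> 'layer/5'
name = name.split('/')
name = name[1:]                  # drop the outermost Keras scope, as in the loader above

print(name)
# ['bert', 'encoder', 'layer', '5', 'attention', 'self', 'query', 'kernel']
# i.e. the segments of the corresponding PyTorch parameter path 'bert.encoder.layer.5.attention...'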
@@ -70,7 +70,7 @@ def load_gpt2_pt_weights_in_tf2(tf_model, config, pytorch_checkpoint_path):
     for symbolic_weight in symbolic_weights:
         name = symbolic_weight.name
         name = name.replace(':0', '')
-        name = name.replace('h_', 'h/')
+        name = name.replace('__', '/')
         name = name.split('/')
         name = name[2:]

@@ -282,7 +282,7 @@ class TFGPT2MainLayer(tf.keras.layers.Layer):
         self.h = [TFBlock(config.n_ctx,
                           config,
                           scale=True,
-                          name='h_{}'.format(i)) for i in range(config.n_layer)]
+                          name='h__{}'.format(i)) for i in range(config.n_layer)]
         self.ln_f = tf.keras.layers.LayerNormalization(epsilon=config.layer_norm_epsilon, name='ln_f')

     def _resize_token_embeddings(self, new_num_tokens):
......
@@ -386,7 +386,7 @@ class TFSequenceSummary(tf.keras.layers.Layer):
         self.activation = None
         if hasattr(config, 'summary_activation') and config.summary_activation == 'tanh':
-            self.activation = tf.keras.layers.Tanh()
+            self.activation = tf.keras.activations.tanh

         self.first_dropout = None
         if hasattr(config, 'summary_first_dropout') and config.summary_first_dropout > 0:
......
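The activation swap above is needed because tf.keras has no `Tanh` layer class; `tf.keras.activations.tanh` is a plain callable that can be stored on the layer and applied like any function. A one-line sanity check, assuming TensorFlow 2 is installed:

import tensorflow as tf

activation = tf.keras.activations.tanh
print(activation(tf.constant([[-2.0, 0.0, 2.0]])))  # elementwise tanh, values in (-1, 1)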
This diff is collapsed.
@@ -1003,6 +1003,101 @@ class XLNetForSequenceClassification(XLNetPreTrainedModel):

         return outputs  # return (loss), logits, mems, (hidden states), (attentions)
+
+
+@add_start_docstrings("""XLNet Model with a span classification head on top for extractive question-answering tasks like SQuAD (a linear layer on top of
+    the hidden-states output to compute `span start logits` and `span end logits`). """,
+    XLNET_START_DOCSTRING, XLNET_INPUTS_DOCSTRING)
+class XLNetForQuestionAnsweringSimple(XLNetPreTrainedModel):
+    r"""
+        **start_positions**: (`optional`) ``torch.LongTensor`` of shape ``(batch_size,)``:
+            Labels for position (index) of the start of the labelled span for computing the token classification loss.
+            Positions are clamped to the length of the sequence (`sequence_length`).
+            Positions outside of the sequence are not taken into account for computing the loss.
+        **end_positions**: (`optional`) ``torch.LongTensor`` of shape ``(batch_size,)``:
+            Labels for position (index) of the end of the labelled span for computing the token classification loss.
+            Positions are clamped to the length of the sequence (`sequence_length`).
+            Positions outside of the sequence are not taken into account for computing the loss.
+
+    Outputs: `Tuple` comprising various elements depending on the configuration (config) and inputs:
+        **loss**: (`optional`, returned if both ``start_positions`` and ``end_positions`` are provided) ``torch.FloatTensor`` of shape ``(1,)``:
+            Classification loss as the sum of the start token and end token classification losses.
+        **start_scores**: ``torch.FloatTensor`` of shape ``(batch_size, sequence_length,)``
+            Span-start scores (before SoftMax).
+        **end_scores**: ``torch.FloatTensor`` of shape ``(batch_size, sequence_length,)``
+            Span-end scores (before SoftMax).
+        **mems**: list of ``torch.FloatTensor`` (one for each layer)
+            containing pre-computed hidden-states (keys and values in the attention blocks) as computed by the model
+            if ``config.mem_len > 0``, else a tuple of ``None``. Can be used to speed up sequential decoding and attend to longer context.
+            See details in the docstring of the `mems` input above.
+        **hidden_states**: (`optional`, returned when ``config.output_hidden_states=True``)
+            list of ``torch.FloatTensor`` (one for the output of each layer + the output of the embeddings)
+            of shape ``(batch_size, sequence_length, hidden_size)``:
+            Hidden-states of the model at the output of each layer plus the initial embedding outputs.
+        **attentions**: (`optional`, returned when ``config.output_attentions=True``)
+            list of ``torch.FloatTensor`` (one for each layer) of shape ``(batch_size, num_heads, sequence_length, sequence_length)``:
+            Attentions weights after the attention softmax, used to compute the weighted average in the self-attention heads.
+
+    Examples::
+
+        tokenizer = XLNetTokenizer.from_pretrained('xlnet-large-cased')
+        model = XLNetForQuestionAnsweringSimple.from_pretrained('xlnet-large-cased')
+        input_ids = torch.tensor(tokenizer.encode("Hello, my dog is cute")).unsqueeze(0)  # Batch size 1
+        start_positions = torch.tensor([1])
+        end_positions = torch.tensor([3])
+        outputs = model(input_ids, start_positions=start_positions, end_positions=end_positions)
+        loss, start_scores, end_scores = outputs[:3]
+
+    """
+    def __init__(self, config):
+        super(XLNetForQuestionAnsweringSimple, self).__init__(config)
+        self.num_labels = config.num_labels
+
+        self.transformer = XLNetModel(config)
+        self.qa_outputs = nn.Linear(config.hidden_size, config.num_labels)
+
+        self.init_weights()
+
+    def forward(self, input_ids, attention_mask=None, mems=None, perm_mask=None, target_mapping=None,
+                token_type_ids=None, input_mask=None, head_mask=None,
+                start_positions=None, end_positions=None):
+        outputs = self.transformer(input_ids,
+                                   attention_mask=attention_mask,
+                                   mems=mems,
+                                   perm_mask=perm_mask,
+                                   target_mapping=target_mapping,
+                                   token_type_ids=token_type_ids,
+                                   input_mask=input_mask,
+                                   head_mask=head_mask)
+
+        sequence_output = outputs[0]
+
+        logits = self.qa_outputs(sequence_output)
+        start_logits, end_logits = logits.split(1, dim=-1)
+        start_logits = start_logits.squeeze(-1)
+        end_logits = end_logits.squeeze(-1)
+
+        outputs = (start_logits, end_logits,) + outputs[2:]
+
+        if start_positions is not None and end_positions is not None:
+            # If we are on multi-GPU, split add a dimension
+            if len(start_positions.size()) > 1:
+                start_positions = start_positions.squeeze(-1)
+            if len(end_positions.size()) > 1:
+                end_positions = end_positions.squeeze(-1)
+            # sometimes the start/end positions are outside our model inputs, we ignore these terms
+            ignored_index = start_logits.size(1)
+            start_positions.clamp_(0, ignored_index)
+            end_positions.clamp_(0, ignored_index)
+
+            loss_fct = CrossEntropyLoss(ignore_index=ignored_index)
+            start_loss = loss_fct(start_logits, start_positions)
+            end_loss = loss_fct(end_logits, end_positions)
+            total_loss = (start_loss + end_loss) / 2
+            outputs = (total_loss,) + outputs
+
+        return outputs  # (loss), start_logits, end_logits, (hidden_states), (attentions)
+
+
 @add_start_docstrings("""XLNet Model with a span classification head on top for extractive question-answering tasks like SQuAD (a linear layers on top of
     the hidden-states output to compute `span start logits` and `span end logits`). """,
     XLNET_START_DOCSTRING, XLNET_INPUTS_DOCSTRING)
......
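The new head only adds a linear layer producing per-token start and end logits, so inference reduces to an argmax over each. A hedged usage sketch: the model and tokenizer names follow the docstring example, while the input text and the greedy span decoding are illustrative and not part of this commit.

import torch

from pytorch_transformers import XLNetTokenizer, XLNetForQuestionAnsweringSimple

tokenizer = XLNetTokenizer.from_pretrained('xlnet-large-cased')
model = XLNetForQuestionAnsweringSimple.from_pretrained('xlnet-large-cased')  # QA head is newly initialised
model.eval()

text = "XLNet was proposed by researchers from CMU and Google Brain."
input_ids = torch.tensor([tokenizer.encode(text)])  # batch size 1

with torch.no_grad():
    start_logits, end_logits = model(input_ids)[:2]  # no positions given, so logits come first

# Greedy span decoding: most likely start token and most likely end token.
start_index = start_logits.argmax(dim=-1).item()
end_index = end_logits.argmax(dim=-1).item()
answer = tokenizer.decode(input_ids[0, start_index:end_index + 1].tolist())
print(answer)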
@@ -28,9 +28,10 @@ from pytorch_transformers import XLNetConfig, is_tf_available

 if is_tf_available():
     import tensorflow as tf
-    from pytorch_transformers.modeling_tf_xlnet import (TFXLNetModel, TF_XLNET_PRETRAINED_MODEL_ARCHIVE_MAP)
-                                                        # XLNetLMHeadModel,
-                                                        # XLNetForSequenceClassification, XLNetForQuestionAnswering)
+    from pytorch_transformers.modeling_tf_xlnet import (TFXLNetModel, TFXLNetLMHeadModel,
+                                                        TFXLNetForSequenceClassification,
+                                                        TFXLNetForQuestionAnsweringSimple,
+                                                        TF_XLNET_PRETRAINED_MODEL_ARCHIVE_MAP)
 else:
     pytestmark = pytest.mark.skip("Require TensorFlow")
@@ -39,9 +40,9 @@ from .configuration_common_test import ConfigTester

 class TFXLNetModelTest(TFCommonTestCases.TFCommonModelTester):

-    all_model_classes=(TFXLNetModel, ) if is_tf_available() else ()
-    # all_model_classes=(TFXLNetModel, TFXLNetLMHeadModel,
-    #                    TFXLNetForSequenceClassification, TFXLNetForQuestionAnswering) if is_tf_available() else ()
+    all_model_classes=(TFXLNetModel, TFXLNetLMHeadModel,
+                       TFXLNetForSequenceClassification,
+                       TFXLNetForQuestionAnsweringSimple) if is_tf_available() else ()
     test_pruning = False

     class TFXLNetModelTester(object):
@@ -169,128 +170,88 @@ class TFXLNetModelTest(TFCommonTestCases.TFCommonModelTester):

         def create_and_check_xlnet_lm_head(self, config, input_ids_1, input_ids_2, input_ids_q, perm_mask, input_mask,
                                            target_mapping, segment_ids, lm_labels, sequence_labels, is_impossible_labels):
-            pass
-            # model = XLNetLMHeadModel(config)
-            # model.eval()
-
-            # loss_1, all_logits_1, mems_1 = model(input_ids_1, token_type_ids=segment_ids, labels=lm_labels)
-
-            # loss_2, all_logits_2, mems_2 = model(input_ids_2, token_type_ids=segment_ids, labels=lm_labels, mems=mems_1)
-
-            # logits, _ = model(input_ids_q, perm_mask=perm_mask, target_mapping=target_mapping)
-
-            # result = {
-            #     "loss_1": loss_1,
-            #     "mems_1": mems_1,
-            #     "all_logits_1": all_logits_1,
-            #     "loss_2": loss_2,
-            #     "mems_2": mems_2,
-            #     "all_logits_2": all_logits_2,
-            # }
-
-            # self.parent.assertListEqual(
-            #     list(result["loss_1"].size()),
-            #     [])
-            # self.parent.assertListEqual(
-            #     list(result["all_logits_1"].size()),
-            #     [self.batch_size, self.seq_length, self.vocab_size])
-            # self.parent.assertListEqual(
-            #     list(list(mem.size()) for mem in result["mems_1"]),
-            #     [[self.seq_length, self.batch_size, self.hidden_size]] * self.num_hidden_layers)
-
-            # self.parent.assertListEqual(
-            #     list(result["loss_2"].size()),
-            #     [])
-            # self.parent.assertListEqual(
-            #     list(result["all_logits_2"].size()),
-            #     [self.batch_size, self.seq_length, self.vocab_size])
-            # self.parent.assertListEqual(
-            #     list(list(mem.size()) for mem in result["mems_2"]),
-            #     [[self.mem_len, self.batch_size, self.hidden_size]] * self.num_hidden_layers)
+            model = TFXLNetLMHeadModel(config)
+
+            inputs_1 = {'input_ids': input_ids_1,
+                        'token_type_ids': segment_ids}
+
+            all_logits_1, mems_1 = model(inputs_1)
+
+            inputs_2 = {'input_ids': input_ids_2,
+                        'mems': mems_1,
+                        'token_type_ids': segment_ids}
+
+            all_logits_2, mems_2 = model(inputs_2)
+
+            inputs_3 = {'input_ids': input_ids_q,
+                        'perm_mask': perm_mask,
+                        'target_mapping': target_mapping}
+
+            logits, _ = model(inputs_3)
+
+            result = {
+                "mems_1": [mem.numpy() for mem in mems_1],
+                "all_logits_1": all_logits_1.numpy(),
+                "mems_2": [mem.numpy() for mem in mems_2],
+                "all_logits_2": all_logits_2.numpy(),
+            }
+
+            self.parent.assertListEqual(
+                list(result["all_logits_1"].shape),
+                [self.batch_size, self.seq_length, self.vocab_size])
+            self.parent.assertListEqual(
+                list(list(mem.shape) for mem in result["mems_1"]),
+                [[self.seq_length, self.batch_size, self.hidden_size]] * self.num_hidden_layers)
+
+            self.parent.assertListEqual(
+                list(result["all_logits_2"].shape),
+                [self.batch_size, self.seq_length, self.vocab_size])
+            self.parent.assertListEqual(
+                list(list(mem.shape) for mem in result["mems_2"]),
+                [[self.mem_len, self.batch_size, self.hidden_size]] * self.num_hidden_layers)
         def create_and_check_xlnet_qa(self, config, input_ids_1, input_ids_2, input_ids_q, perm_mask, input_mask,
                                       target_mapping, segment_ids, lm_labels, sequence_labels, is_impossible_labels):
-            pass
-            # model = XLNetForQuestionAnswering(config)
-            # model.eval()
-
-            # outputs = model(input_ids_1)
-            # start_top_log_probs, start_top_index, end_top_log_probs, end_top_index, cls_logits, mems = outputs
-
-            # outputs = model(input_ids_1, start_positions=sequence_labels,
-            #                 end_positions=sequence_labels,
-            #                 cls_index=sequence_labels,
-            #                 is_impossible=is_impossible_labels,
-            #                 p_mask=input_mask)
-
-            # outputs = model(input_ids_1, start_positions=sequence_labels,
-            #                 end_positions=sequence_labels,
-            #                 cls_index=sequence_labels,
-            #                 is_impossible=is_impossible_labels)
-
-            # total_loss, mems = outputs
-
-            # outputs = model(input_ids_1, start_positions=sequence_labels,
-            #                 end_positions=sequence_labels)
-
-            # total_loss, mems = outputs
-
-            # result = {
-            #     "loss": total_loss,
-            #     "start_top_log_probs": start_top_log_probs,
-            #     "start_top_index": start_top_index,
-            #     "end_top_log_probs": end_top_log_probs,
-            #     "end_top_index": end_top_index,
-            #     "cls_logits": cls_logits,
-            #     "mems": mems,
-            # }
-
-            # self.parent.assertListEqual(
-            #     list(result["loss"].size()),
-            #     [])
-            # self.parent.assertListEqual(
-            #     list(result["start_top_log_probs"].size()),
-            #     [self.batch_size, model.config.start_n_top])
-            # self.parent.assertListEqual(
-            #     list(result["start_top_index"].size()),
-            #     [self.batch_size, model.config.start_n_top])
-            # self.parent.assertListEqual(
-            #     list(result["end_top_log_probs"].size()),
-            #     [self.batch_size, model.config.start_n_top * model.config.end_n_top])
-            # self.parent.assertListEqual(
-            #     list(result["end_top_index"].size()),
-            #     [self.batch_size, model.config.start_n_top * model.config.end_n_top])
-            # self.parent.assertListEqual(
-            #     list(result["cls_logits"].size()),
-            #     [self.batch_size])
-            # self.parent.assertListEqual(
-            #     list(list(mem.size()) for mem in result["mems"]),
-            #     [[self.seq_length, self.batch_size, self.hidden_size]] * self.num_hidden_layers)
+            model = TFXLNetForQuestionAnsweringSimple(config)
+
+            inputs = {'input_ids': input_ids_1,
+                      'attention_mask': input_mask,
+                      'token_type_ids': segment_ids}
+
+            start_logits, end_logits, mems = model(inputs)
+
+            result = {
+                "start_logits": start_logits.numpy(),
+                "end_logits": end_logits.numpy(),
+                "mems": [m.numpy() for m in mems],
+            }
+
+            self.parent.assertListEqual(
+                list(result["start_logits"].shape),
+                [self.batch_size, self.seq_length])
+            self.parent.assertListEqual(
+                list(result["end_logits"].shape),
+                [self.batch_size, self.seq_length])
+            self.parent.assertListEqual(
+                list(list(mem.shape) for mem in result["mems"]),
+                [[self.seq_length, self.batch_size, self.hidden_size]] * self.num_hidden_layers)
         def create_and_check_xlnet_sequence_classif(self, config, input_ids_1, input_ids_2, input_ids_q, perm_mask, input_mask,
                                                     target_mapping, segment_ids, lm_labels, sequence_labels, is_impossible_labels):
-            pass
-            # model = XLNetForSequenceClassification(config)
-            # model.eval()
-
-            # logits, mems_1 = model(input_ids_1)
-            # loss, logits, mems_1 = model(input_ids_1, labels=sequence_labels)
-
-            # result = {
-            #     "loss": loss,
-            #     "mems_1": mems_1,
-            #     "logits": logits,
-            # }
-
-            # self.parent.assertListEqual(
-            #     list(result["loss"].size()),
-            #     [])
-            # self.parent.assertListEqual(
-            #     list(result["logits"].size()),
-            #     [self.batch_size, self.type_sequence_label_size])
-            # self.parent.assertListEqual(
-            #     list(list(mem.size()) for mem in result["mems_1"]),
-            #     [[self.seq_length, self.batch_size, self.hidden_size]] * self.num_hidden_layers)
+            model = TFXLNetForSequenceClassification(config)
+
+            logits, mems_1 = model(input_ids_1)
+
+            result = {
+                "mems_1": [mem.numpy() for mem in mems_1],
+                "logits": logits.numpy(),
+            }
+
+            self.parent.assertListEqual(
+                list(result["logits"].shape),
+                [self.batch_size, self.type_sequence_label_size])
+            self.parent.assertListEqual(
+                list(list(mem.shape) for mem in result["mems_1"]),
+                [[self.seq_length, self.batch_size, self.hidden_size]] * self.num_hidden_layers)

         def prepare_config_and_inputs_for_common(self):
             config_and_inputs = self.prepare_config_and_inputs()
......