Commit 465870c3 authored by thomwolf's avatar thomwolf
Browse files

Xlnet working - also added simple question answering model for XLNet

parent 16b63617
......@@ -68,7 +68,8 @@ if _torch_available:
GPT2LMHeadModel, GPT2DoubleHeadsModel,
load_tf_weights_in_gpt2, GPT2_PRETRAINED_MODEL_ARCHIVE_MAP)
from .modeling_xlnet import (XLNetPreTrainedModel, XLNetModel, XLNetLMHeadModel,
XLNetForSequenceClassification, XLNetForQuestionAnswering,
XLNetForSequenceClassification, XLNetForQuestionAnsweringSimple,
XLNetForQuestionAnswering,
load_tf_weights_in_xlnet, XLNET_PRETRAINED_MODEL_ARCHIVE_MAP)
from .modeling_xlm import (XLMPreTrainedModel , XLMModel,
XLMWithLMHeadModel, XLMForSequenceClassification,
......@@ -112,6 +113,12 @@ if _tf_available:
load_gpt2_pt_weights_in_tf2,
TF_GPT2_PRETRAINED_MODEL_ARCHIVE_MAP)
from .modeling_tf_xlnet import (TFXLNetPreTrainedModel, TFXLNetMainLayer,
TFXLNetModel, TFXLNetLMHeadModel,
TFXLNetForSequenceClassification,
TFXLNetForQuestionAnsweringSimple,
load_xlnet_pt_weights_in_tf2,
TF_XLNET_PRETRAINED_MODEL_ARCHIVE_MAP)
# Files and general utilities
from .file_utils import (PYTORCH_TRANSFORMERS_CACHE, PYTORCH_PRETRAINED_BERT_CACHE,
......
......@@ -175,7 +175,7 @@ class PretrainedConfig(object):
"""Constructs a `Config` from a Python dictionary of parameters."""
config = cls(vocab_size_or_config_json_file=-1)
for key, value in json_object.items():
config.__dict__[key] = value
setattr(config, key, value)
return config
@classmethod
......
......@@ -112,7 +112,7 @@ class XLNetConfig(PretrainedConfig):
with open(vocab_size_or_config_json_file, "r", encoding='utf-8') as reader:
json_config = json.loads(reader.read())
for key, value in json_config.items():
self.__dict__[key] = value
setattr(config, key, value)
elif isinstance(vocab_size_or_config_json_file, int):
self.n_token = vocab_size_or_config_json_file
self.d_model = d_model
......
......@@ -24,12 +24,13 @@ import tensorflow as tf
from pytorch_transformers import is_torch_available
from pytorch_transformers import (BertConfig, TFBertForPreTraining, load_bert_pt_weights_in_tf2,
GPT2Config, TFGPT2LMHeadModel, load_gpt2_pt_weights_in_tf2)
GPT2Config, TFGPT2LMHeadModel, load_gpt2_pt_weights_in_tf2,
XLNetConfig, TFXLNetLMHeadModel, load_xlnet_pt_weights_in_tf2)
if is_torch_available():
import torch
import numpy as np
from pytorch_transformers import BertForPreTraining, GPT2LMHeadModel
from pytorch_transformers import BertForPreTraining, GPT2LMHeadModel, XLNetLMHeadModel
else:
BertForPreTraining, GPT2LMHeadModel = None, None
......@@ -40,6 +41,7 @@ logging.basicConfig(level=logging.INFO)
MODEL_CLASSES = {
'bert': (BertConfig, TFBertForPreTraining, load_bert_pt_weights_in_tf2, BertForPreTraining),
'gpt2': (GPT2Config, TFGPT2LMHeadModel, load_gpt2_pt_weights_in_tf2, GPT2LMHeadModel),
'xlnet': (XLNetConfig, TFXLNetLMHeadModel, load_xlnet_pt_weights_in_tf2, XLNetLMHeadModel),
}
def convert_pt_checkpoint_to_tf(model_type, pytorch_checkpoint_path, config_file, tf_dump_path, compare_with_pt_model=False):
......@@ -50,6 +52,8 @@ def convert_pt_checkpoint_to_tf(model_type, pytorch_checkpoint_path, config_file
# Initialise TF model
config = config_class.from_json_file(config_file)
config.output_hidden_states = True
config.output_attentions = True
print("Building TensorFlow model from configuration: {}".format(str(config)))
tf_model = model_class(config)
......
......@@ -83,7 +83,7 @@ def load_bert_pt_weights_in_tf2(tf_model, config, pytorch_checkpoint_path):
name = name.replace('cls_mlm', 'cls') # We had to split this layer in two in the TF model to be
name = name.replace('cls_nsp', 'cls') # able to do transfer learning (Keras only allow to remove full layers)
name = name.replace(':0', '')
name = name.replace('layer_', 'layer/')
name = name.replace('__', '/')
name = name.split('/')
name = name[1:]
......@@ -391,7 +391,7 @@ class TFBertEncoder(tf.keras.layers.Layer):
super(TFBertEncoder, self).__init__(**kwargs)
self.output_attentions = config.output_attentions
self.output_hidden_states = config.output_hidden_states
self.layer = [TFBertLayer(config, name='layer_{}'.format(i)) for i in range(config.num_hidden_layers)]
self.layer = [TFBertLayer(config, name='layer__{}'.format(i)) for i in range(config.num_hidden_layers)]
def call(self, inputs, training=False):
hidden_states, attention_mask, head_mask = inputs
......
......@@ -70,7 +70,7 @@ def load_gpt2_pt_weights_in_tf2(tf_model, config, pytorch_checkpoint_path):
for symbolic_weight in symbolic_weights:
name = symbolic_weight.name
name = name.replace(':0', '')
name = name.replace('h_', 'h/')
name = name.replace('__', '/')
name = name.split('/')
name = name[2:]
......@@ -282,7 +282,7 @@ class TFGPT2MainLayer(tf.keras.layers.Layer):
self.h = [TFBlock(config.n_ctx,
config,
scale=True,
name='h_{}'.format(i)) for i in range(config.n_layer)]
name='h__{}'.format(i)) for i in range(config.n_layer)]
self.ln_f = tf.keras.layers.LayerNormalization(epsilon=config.layer_norm_epsilon, name='ln_f')
def _resize_token_embeddings(self, new_num_tokens):
......
......@@ -386,7 +386,7 @@ class TFSequenceSummary(tf.keras.layers.Layer):
self.activation = None
if hasattr(config, 'summary_activation') and config.summary_activation == 'tanh':
self.activation = tf.keras.layers.Tanh()
self.activation = tf.keras.activations.tanh
self.first_dropout = None
if hasattr(config, 'summary_first_dropout') and config.summary_first_dropout > 0:
......
......@@ -28,7 +28,7 @@ import numpy as np
import tensorflow as tf
from .configuration_xlnet import XLNetConfig
from .modeling_tf_utils import TFPreTrainedModel, TFSharedEmbeddings, shape_list
from .modeling_tf_utils import TFPreTrainedModel, TFSharedEmbeddings, TFSequenceSummary, shape_list
from .file_utils import add_start_docstrings
......@@ -72,13 +72,15 @@ def load_xlnet_pt_weights_in_tf2(tf_model, config, pytorch_checkpoint_path):
name = name.replace('cls_mlm', 'cls') # We had to split this layer in two in the TF model to be
name = name.replace('cls_nsp', 'cls') # able to do transfer learning (Keras only allow to remove full layers)
name = name.replace(':0', '')
name = name.replace('layer_', 'layer/')
name = name.replace('layer__', 'layer/')
name = name.split('/')
name = name[1:]
transpose = bool(name[-1] == 'kernel')
if name[-1] == 'kernel' or name[-1] == 'embeddings':
if name[-1] == 'kernel' or name[-1] == 'embeddings' or name[-1] == 'gamma':
name[-1] = 'weight'
if name[-1] == 'beta':
name[-1] = 'bias'
name = '.'.join(name)
assert name in state_dict, "{} not found in PyTorch model".format(name)
......@@ -237,16 +239,16 @@ class TFXLNetRelativeAttention(tf.keras.layers.Layer):
return attn_vec
def post_attention(self, inputs, training=False):
def post_attention(self, inputs, residual=True, training=False):
"""Post-attention processing."""
# post-attention projection (back to `d_model`)
h, attn_vec, residual = inputs
h, attn_vec = inputs
attn_out = tf.einsum('ibnd,hnd->ibh', attn_vec, self.o)
attn_out = self.dropout(attn_out, training=training)
if residual is not None:
if residual:
attn_out = attn_out + h
output = self.layer_norm(attn_out)
......@@ -259,23 +261,23 @@ class TFXLNetRelativeAttention(tf.keras.layers.Layer):
if g is not None:
###### Two-stream attention with relative positional encoding.
# content based attention score
if mems is not None and mems.dim() > 1:
cat = torch.cat([mems, h], dim=0)
if mems is not None and mems.shape.ndims > 1:
cat = tf.concat([mems, h], axis=0)
else:
cat = h
# content-based key head
k_head_h = torch.einsum('ibh,hnd->ibnd', cat, self.k)
k_head_h = tf.einsum('ibh,hnd->ibnd', cat, self.k)
# content-based value head
v_head_h = torch.einsum('ibh,hnd->ibnd', cat, self.v)
v_head_h = tf.einsum('ibh,hnd->ibnd', cat, self.v)
# position-based key head
k_head_r = torch.einsum('ibh,hnd->ibnd', r, self.r)
k_head_r = tf.einsum('ibh,hnd->ibnd', r, self.r)
##### h-stream
# content-stream query head
q_head_h = torch.einsum('ibh,hnd->ibnd', h, self.q)
q_head_h = tf.einsum('ibh,hnd->ibnd', h, self.q)
# core attention ops
attn_vec_h = self.rel_attn_core(
......@@ -286,15 +288,15 @@ class TFXLNetRelativeAttention(tf.keras.layers.Layer):
attn_vec_h, attn_prob_h = attn_vec_h
# post processing
output_h = self.post_attention([h, attn_vec_h, None], training=training)
output_h = self.post_attention([h, attn_vec_h], training=training)
##### g-stream
# query-stream query head
q_head_g = torch.einsum('ibh,hnd->ibnd', g, self.q)
q_head_g = tf.einsum('ibh,hnd->ibnd', g, self.q)
# core attention ops
if target_mapping is not None:
q_head_g = torch.einsum('mbnd,mlb->lbnd', q_head_g, target_mapping)
q_head_g = tf.einsum('mbnd,mlb->lbnd', q_head_g, target_mapping)
attn_vec_g = self.rel_attn_core(
[q_head_g, k_head_h, v_head_h, k_head_r, seg_mat, attn_mask_g, head_mask],
training=training)
......@@ -302,7 +304,7 @@ class TFXLNetRelativeAttention(tf.keras.layers.Layer):
if self.output_attentions:
attn_vec_g, attn_prob_g = attn_vec_g
attn_vec_g = torch.einsum('lbnd,mlb->mbnd', attn_vec_g, target_mapping)
attn_vec_g = tf.einsum('lbnd,mlb->mbnd', attn_vec_g, target_mapping)
else:
attn_vec_g = self.rel_attn_core(
[q_head_g, k_head_h, v_head_h, k_head_r, seg_mat, attn_mask_g, head_mask],
......@@ -312,15 +314,15 @@ class TFXLNetRelativeAttention(tf.keras.layers.Layer):
attn_vec_g, attn_prob_g = attn_vec_g
# post processing
output_g = self.post_attention([g, attn_vec_g, None], training=training)
output_g = self.post_attention([g, attn_vec_g], training=training)
if self.output_attentions:
attn_prob = attn_prob_h, attn_prob_g
else:
###### Multi-head attention with relative positional encoding
if mems is not None and mems.dim() > 1:
cat = tf.concat([mems, h], dim=0)
if mems is not None and mems.shape.ndims > 1:
cat = tf.concat([mems, h], axis=0)
else:
cat = h
......@@ -341,7 +343,7 @@ class TFXLNetRelativeAttention(tf.keras.layers.Layer):
attn_vec, attn_prob = attn_vec
# post processing
output_h = self.post_attention([h, attn_vec, None], training=training)
output_h = self.post_attention([h, attn_vec], training=training)
output_g = None
outputs = (output_h, output_g)
......@@ -391,6 +393,27 @@ class TFXLNetLayer(tf.keras.layers.Layer):
return outputs
class TFXLNetLMHead(tf.keras.layers.Layer):
def __init__(self, config, input_embeddings, **kwargs):
super(TFXLNetLMHead, self).__init__(**kwargs)
self.vocab_size = config.vocab_size
# The output weights are the same as the input embeddings, but there is
# an output-only bias for each token.
self.input_embeddings = input_embeddings
def build(self, input_shape):
self.bias = self.add_weight(shape=(self.vocab_size,),
initializer='zeros',
trainable=True,
name='bias')
super(TFXLNetLMHead, self).build(input_shape)
def call(self, hidden_states):
hidden_states = self.input_embeddings(hidden_states, mode="linear")
hidden_states = hidden_states + self.bias
return hidden_states
class TFXLNetMainLayer(tf.keras.layers.Layer):
def __init__(self, config, **kwargs):
super(TFXLNetMainLayer, self).__init__(**kwargs)
......@@ -409,7 +432,7 @@ class TFXLNetMainLayer(tf.keras.layers.Layer):
self.initializer_range = config.initializer_range
self.word_embedding = TFSharedEmbeddings(config.n_token, config.d_model, initializer_range=config.initializer_range, name='word_embedding')
self.layer = [TFXLNetLayer(config, name='layer_{}'.format(i)) for i in range(config.n_layer)]
self.layer = [TFXLNetLayer(config, name='layer__{}'.format(i)) for i in range(config.n_layer)]
self.dropout = tf.keras.layers.Dropout(config.dropout)
def build(self, input_shape):
......@@ -464,7 +487,7 @@ class TFXLNetMainLayer(tf.keras.layers.Layer):
if prev_mem is None:
new_mem = curr_out[-self.mem_len:]
else:
new_mem = tf.concat([prev_mem, curr_out], 0)[-mem_len:]
new_mem = tf.concat([prev_mem, curr_out], 0)[-self.mem_len:]
return tf.stop_gradient(new_mem)
......@@ -618,7 +641,7 @@ class TFXLNetMainLayer(tf.keras.layers.Layer):
word_emb_k = self.word_embedding(input_ids)
output_h = self.dropout(word_emb_k, training=training)
if target_mapping is not None:
word_emb_q = tf.tile(mask_emb, [tf.shape(target_mapping)[0], bsz, 1])
word_emb_q = tf.tile(self.mask_emb, [tf.shape(target_mapping)[0], bsz, 1])
# else: # We removed the inp_q input which was same as target mapping
# inp_q_ext = inp_q[:, :, None]
# word_emb_q = inp_q_ext * self.mask_emb + (1 - inp_q_ext) * word_emb_k
......@@ -827,187 +850,173 @@ class TFXLNetModel(TFXLNetPreTrainedModel):
return outputs
# @add_start_docstrings("""XLNet Model with a language modeling head on top
# (linear layer with weights tied to the input embeddings). """,
# XLNET_START_DOCSTRING, XLNET_INPUTS_DOCSTRING)
# class XLNetLMHeadModel(XLNetPreTrainedModel):
# r"""
# **labels**: (`optional`) ``torch.LongTensor`` of shape ``(batch_size, sequence_length)``:
# Labels for language modeling.
# Note that the labels **are shifted** inside the model, i.e. you can set ``lm_labels = input_ids``
# Indices are selected in ``[-1, 0, ..., config.vocab_size]``
# All labels set to ``-1`` are ignored (masked), the loss is only
# computed for labels in ``[0, ..., config.vocab_size]``
@add_start_docstrings("""XLNet Model with a language modeling head on top
(linear layer with weights tied to the input embeddings). """,
XLNET_START_DOCSTRING, XLNET_INPUTS_DOCSTRING)
class TFXLNetLMHeadModel(TFXLNetPreTrainedModel):
r"""
Outputs: `Tuple` comprising various elements depending on the configuration (config) and inputs:
**prediction_scores**: ``torch.FloatTensor`` of shape ``(batch_size, sequence_length, config.vocab_size)``
Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax).
**mems**:
list of ``torch.FloatTensor`` (one for each layer):
that contains pre-computed hidden-states (key and values in the attention blocks) as computed by the model
if config.mem_len > 0 else tuple of None. Can be used to speed up sequential decoding and attend to longer context.
See details in the docstring of the `mems` input above.
**hidden_states**: (`optional`, returned when ``config.output_hidden_states=True``)
list of ``torch.FloatTensor`` (one for the output of each layer + the output of the embeddings)
of shape ``(batch_size, sequence_length, hidden_size)``:
Hidden-states of the model at the output of each layer plus the initial embedding outputs.
**attentions**: (`optional`, returned when ``config.output_attentions=True``)
list of ``torch.FloatTensor`` (one for each layer) of shape ``(batch_size, num_heads, sequence_length, sequence_length)``:
Attentions weights after the attention softmax, used to compute the weighted average in the self-attention heads.
# Outputs: `Tuple` comprising various elements depending on the configuration (config) and inputs:
# **loss**: (`optional`, returned when ``labels`` is provided) ``torch.FloatTensor`` of shape ``(1,)``:
# Language modeling loss.
# **prediction_scores**: ``torch.FloatTensor`` of shape ``(batch_size, sequence_length, config.vocab_size)``
# Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax).
# **mems**:
# list of ``torch.FloatTensor`` (one for each layer):
# that contains pre-computed hidden-states (key and values in the attention blocks) as computed by the model
# if config.mem_len > 0 else tuple of None. Can be used to speed up sequential decoding and attend to longer context.
# See details in the docstring of the `mems` input above.
# **hidden_states**: (`optional`, returned when ``config.output_hidden_states=True``)
# list of ``torch.FloatTensor`` (one for the output of each layer + the output of the embeddings)
# of shape ``(batch_size, sequence_length, hidden_size)``:
# Hidden-states of the model at the output of each layer plus the initial embedding outputs.
# **attentions**: (`optional`, returned when ``config.output_attentions=True``)
# list of ``torch.FloatTensor`` (one for each layer) of shape ``(batch_size, num_heads, sequence_length, sequence_length)``:
# Attentions weights after the attention softmax, used to compute the weighted average in the self-attention heads.
Examples::
# Examples::
tokenizer = XLNetTokenizer.from_pretrained('xlnet-large-cased')
model = TFXLNetLMHeadModel.from_pretrained('xlnet-large-cased')
# We show how to setup inputs to predict a next token using a bi-directional context.
input_ids = torch.tensor(tokenizer.encode("Hello, my dog is very <mask>")).unsqueeze(0) # We will predict the masked token
perm_mask = torch.zeros((1, input_ids.shape[1], input_ids.shape[1]), dtype=torch.float)
perm_mask[:, :, -1] = 1.0 # Previous tokens don't see last token
target_mapping = torch.zeros((1, 1, input_ids.shape[1]), dtype=torch.float) # Shape [1, 1, seq_length] => let's predict one token
target_mapping[0, 0, -1] = 1.0 # Our first (and only) prediction will be the last token of the sequence (the masked token)
outputs = model(input_ids, perm_mask=perm_mask, target_mapping=target_mapping)
next_token_logits = outputs[0] # Output has shape [target_mapping.size(0), target_mapping.size(1), config.vocab_size]
"""
def __init__(self, config, *inputs, **kwargs):
super(TFXLNetLMHeadModel, self).__init__(config, *inputs, **kwargs)
self.n_token = config.n_token
# tokenizer = XLNetTokenizer.from_pretrained('xlnet-large-cased')
# model = XLNetLMHeadModel.from_pretrained('xlnet-large-cased')
# # We show how to setup inputs to predict a next token using a bi-directional context.
# input_ids = torch.tensor(tokenizer.encode("Hello, my dog is very <mask>")).unsqueeze(0) # We will predict the masked token
# perm_mask = torch.zeros((1, input_ids.shape[1], input_ids.shape[1]), dtype=torch.float)
# perm_mask[:, :, -1] = 1.0 # Previous tokens don't see last token
# target_mapping = torch.zeros((1, 1, input_ids.shape[1]), dtype=torch.float) # Shape [1, 1, seq_length] => let's predict one token
# target_mapping[0, 0, -1] = 1.0 # Our first (and only) prediction will be the last token of the sequence (the masked token)
# outputs = model(input_ids, perm_mask=perm_mask, target_mapping=target_mapping)
# next_token_logits = outputs[0] # Output has shape [target_mapping.size(0), target_mapping.size(1), config.vocab_size]
self.transformer = TFXLNetMainLayer(config, name='transformer')
self.lm_loss = TFXLNetLMHead(config, self.transformer.word_embedding, name='lm_loss')
# """
# def __init__(self, config, **kwargs):
# super(XLNetLMHeadModel, self).__init__(config)
# self.attn_type = config.attn_type
# self.same_length = config.same_length
def call(self, inputs, training=False):
transformer_outputs = self.transformer(inputs, training=training)
hidden_state = transformer_outputs[0]
logits = self.lm_loss(hidden_state)
outputs = (logits,) + transformer_outputs[1:] # Keep mems, hidden states, attentions if there are in it
return outputs # return logits, mems, (hidden states), (attentions)
@add_start_docstrings("""XLNet Model with a sequence classification/regression head on top (a linear layer on top of
the pooled output) e.g. for GLUE tasks. """,
XLNET_START_DOCSTRING, XLNET_INPUTS_DOCSTRING)
class TFXLNetForSequenceClassification(TFXLNetPreTrainedModel):
r"""
**labels**: (`optional`) ``torch.LongTensor`` of shape ``(batch_size,)``:
Labels for computing the sequence classification/regression loss.
Indices should be in ``[0, ..., config.num_labels - 1]``.
If ``config.num_labels == 1`` a regression loss is computed (Mean-Square loss),
If ``config.num_labels > 1`` a classification loss is computed (Cross-Entropy).
Outputs: `Tuple` comprising various elements depending on the configuration (config) and inputs:
**loss**: (`optional`, returned when ``labels`` is provided) ``torch.FloatTensor`` of shape ``(1,)``:
Classification (or regression if config.num_labels==1) loss.
**logits**: ``torch.FloatTensor`` of shape ``(batch_size, config.num_labels)``
Classification (or regression if config.num_labels==1) scores (before SoftMax).
**mems**:
list of ``torch.FloatTensor`` (one for each layer):
that contains pre-computed hidden-states (key and values in the attention blocks) as computed by the model
if config.mem_len > 0 else tuple of None. Can be used to speed up sequential decoding and attend to longer context.
See details in the docstring of the `mems` input above.
**hidden_states**: (`optional`, returned when ``config.output_hidden_states=True``)
list of ``torch.FloatTensor`` (one for the output of each layer + the output of the embeddings)
of shape ``(batch_size, sequence_length, hidden_size)``:
Hidden-states of the model at the output of each layer plus the initial embedding outputs.
**attentions**: (`optional`, returned when ``config.output_attentions=True``)
list of ``torch.FloatTensor`` (one for each layer) of shape ``(batch_size, num_heads, sequence_length, sequence_length)``:
Attentions weights after the attention softmax, used to compute the weighted average in the self-attention heads.
# self.transformer = XLNetModel(config)
# self.lm_loss = nn.Linear(config.d_model, config.n_token, bias=True)
Examples::
# self.init_weights()
# self.tie_weights()
tokenizer = XLNetTokenizer.from_pretrained('xlnet-large-cased')
model = XLNetForSequenceClassification.from_pretrained('xlnet-large-cased')
input_ids = torch.tensor(tokenizer.encode("Hello, my dog is cute")).unsqueeze(0) # Batch size 1
labels = torch.tensor([1]).unsqueeze(0) # Batch size 1
outputs = model(input_ids, labels=labels)
loss, logits = outputs[:2]
# def tie_weights(self):
# """ Make sure we are sharing the embeddings
# """
# self._tie_or_clone_weights(self.lm_loss, self.transformer.word_embedding)
"""
def __init__(self, config, *inputs, **kwargs):
super(TFXLNetForSequenceClassification, self).__init__(config, *inputs, **kwargs)
self.num_labels = config.num_labels
# def forward(self, input_ids, attention_mask=None, mems=None, perm_mask=None, target_mapping=None,
# token_type_ids=None, input_mask=None, head_mask=None, labels=None):
# transformer_outputs = self.transformer(input_ids,
# attention_mask=attention_mask,
# mems=mems,
# perm_mask=perm_mask,
# target_mapping=target_mapping,
# token_type_ids=token_type_ids,
# input_mask=input_mask,
# head_mask=head_mask)
self.transformer = TFXLNetMainLayer(config, name='transformer')
self.sequence_summary = TFSequenceSummary(config, name='sequence_summary')
self.logits_proj = tf.keras.layers.Dense(config.num_labels, name='logits_proj')
# logits = self.lm_loss(transformer_outputs[0])
def call(self, inputs, training=False):
transformer_outputs = self.transformer(inputs, training=training)
output = transformer_outputs[0]
# outputs = (logits,) + transformer_outputs[1:] # Keep mems, hidden states, attentions if there are in it
output = self.sequence_summary(output)
logits = self.logits_proj(output)
# if labels is not None:
# # Flatten the tokens
# loss_fct = CrossEntropyLoss(ignore_index=-1)
# loss = loss_fct(logits.view(-1, logits.size(-1)),
# labels.view(-1))
# outputs = (loss,) + outputs
outputs = (logits,) + transformer_outputs[1:] # Keep mems, hidden states, attentions if there are in it
# return outputs # return (loss), logits, mems, (hidden states), (attentions)
return outputs # return logits, mems, (hidden states), (attentions)
# @add_start_docstrings("""XLNet Model with a sequence classification/regression head on top (a linear layer on top of
# the pooled output) e.g. for GLUE tasks. """,
# @add_start_docstrings("""XLNet Model with a span classification head on top for extractive question-answering tasks like SQuAD (a linear layers on top of
# the hidden-states output to compute `span start logits` and `span end logits`). """,
# XLNET_START_DOCSTRING, XLNET_INPUTS_DOCSTRING)
# class XLNetForSequenceClassification(XLNetPreTrainedModel):
# r"""
# **labels**: (`optional`) ``torch.LongTensor`` of shape ``(batch_size,)``:
# Labels for computing the sequence classification/regression loss.
# Indices should be in ``[0, ..., config.num_labels - 1]``.
# If ``config.num_labels == 1`` a regression loss is computed (Mean-Square loss),
# If ``config.num_labels > 1`` a classification loss is computed (Cross-Entropy).
# class TFXLNetForQuestionAnswering(TFXLNetPreTrainedModel):
class TFXLNetForQuestionAnsweringSimple(TFXLNetPreTrainedModel):
r"""
Outputs: `Tuple` comprising various elements depending on the configuration (config) and inputs:
**start_scores**: ``torch.FloatTensor`` of shape ``(batch_size, sequence_length,)``
Span-start scores (before SoftMax).
**end_scores**: ``torch.FloatTensor`` of shape ``(batch_size, sequence_length,)``
Span-end scores (before SoftMax).
**hidden_states**: (`optional`, returned when ``config.output_hidden_states=True``)
list of ``torch.FloatTensor`` (one for the output of each layer + the output of the embeddings)
of shape ``(batch_size, sequence_length, hidden_size)``:
Hidden-states of the model at the output of each layer plus the initial embedding outputs.
**attentions**: (`optional`, returned when ``config.output_attentions=True``)
list of ``torch.FloatTensor`` (one for each layer) of shape ``(batch_size, num_heads, sequence_length, sequence_length)``:
Attentions weights after the attention softmax, used to compute the weighted average in the self-attention heads.
# Outputs: `Tuple` comprising various elements depending on the configuration (config) and inputs:
# **loss**: (`optional`, returned when ``labels`` is provided) ``torch.FloatTensor`` of shape ``(1,)``:
# Classification (or regression if config.num_labels==1) loss.
# **logits**: ``torch.FloatTensor`` of shape ``(batch_size, config.num_labels)``
# Classification (or regression if config.num_labels==1) scores (before SoftMax).
# **mems**:
# list of ``torch.FloatTensor`` (one for each layer):
# that contains pre-computed hidden-states (key and values in the attention blocks) as computed by the model
# if config.mem_len > 0 else tuple of None. Can be used to speed up sequential decoding and attend to longer context.
# See details in the docstring of the `mems` input above.
# **hidden_states**: (`optional`, returned when ``config.output_hidden_states=True``)
# list of ``torch.FloatTensor`` (one for the output of each layer + the output of the embeddings)
# of shape ``(batch_size, sequence_length, hidden_size)``:
# Hidden-states of the model at the output of each layer plus the initial embedding outputs.
# **attentions**: (`optional`, returned when ``config.output_attentions=True``)
# list of ``torch.FloatTensor`` (one for each layer) of shape ``(batch_size, num_heads, sequence_length, sequence_length)``:
# Attentions weights after the attention softmax, used to compute the weighted average in the self-attention heads.
Examples::
# Examples::
tokenizer = XLNetTokenizer.from_pretrained('xlnet-base-cased')
model = TFXLNetForQuestionAnsweringSimple.from_pretrained('xlnet-base-cased')
input_ids = torch.tensor(tokenizer.encode("Hello, my dog is cute")).unsqueeze(0) # Batch size 1
start_positions = torch.tensor([1])
end_positions = torch.tensor([3])
outputs = model(input_ids, start_positions=start_positions, end_positions=end_positions)
loss, start_scores, end_scores = outputs[:2]
# tokenizer = XLNetTokenizer.from_pretrained('xlnet-large-cased')
# model = XLNetForSequenceClassification.from_pretrained('xlnet-large-cased')
# input_ids = torch.tensor(tokenizer.encode("Hello, my dog is cute")).unsqueeze(0) # Batch size 1
# labels = torch.tensor([1]).unsqueeze(0) # Batch size 1
# outputs = model(input_ids, labels=labels)
# loss, logits = outputs[:2]
"""
def __init__(self, config, *inputs, **kwargs):
super(TFXLNetForQuestionAnsweringSimple, self).__init__(config, *inputs, **kwargs)
self.num_labels = config.num_labels
# """
# def __init__(self, config, **kwargs):
# super(XLNetForSequenceClassification, self).__init__(config)
# self.num_labels = config.num_labels
# self.transformer = XLNetModel(config)
# self.sequence_summary = SequenceSummary(config)
# self.logits_proj = nn.Linear(config.d_model, config.num_labels)
# self.init_weights()
# def forward(self, input_ids, attention_mask=None, mems=None, perm_mask=None, target_mapping=None,
# token_type_ids=None, input_mask=None, head_mask=None, labels=None):
# transformer_outputs = self.transformer(input_ids,
# attention_mask=attention_mask,
# mems=mems,
# perm_mask=perm_mask,
# target_mapping=target_mapping,
# token_type_ids=token_type_ids,
# input_mask=input_mask,
# head_mask=head_mask)
# output = transformer_outputs[0]
# output = self.sequence_summary(output)
# logits = self.logits_proj(output)
# outputs = (logits,) + transformer_outputs[1:] # Keep mems, hidden states, attentions if there are in it
# if labels is not None:
# if self.num_labels == 1:
# # We are doing regression
# loss_fct = MSELoss()
# loss = loss_fct(logits.view(-1), labels.view(-1))
# else:
# loss_fct = CrossEntropyLoss()
# loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
# outputs = (loss,) + outputs
# return outputs # return (loss), logits, mems, (hidden states), (attentions)
self.transformer = TFXLNetMainLayer(config, name='transformer')
self.qa_outputs = tf.keras.layers.Dense(config.num_labels, name='qa_outputs')
def call(self, inputs, training=False):
transformer_outputs = self.transformer(inputs, training=training)
sequence_output = transformer_outputs[0]
logits = self.qa_outputs(sequence_output)
start_logits, end_logits = tf.split(logits, 2, axis=-1)
start_logits = tf.squeeze(start_logits, axis=-1)
end_logits = tf.squeeze(end_logits, axis=-1)
outputs = (start_logits, end_logits,) + transformer_outputs[1:] # Keep mems, hidden states, attentions if there are in it
return outputs # start_logits, end_logits, (hidden_states), (attentions)
# @add_start_docstrings("""XLNet Model with a span classification head on top for extractive question-answering tasks like SQuAD (a linear layers on top of
# the hidden-states output to compute `span start logits` and `span end logits`). """,
# XLNET_START_DOCSTRING, XLNET_INPUTS_DOCSTRING)
# class XLNetForQuestionAnswering(XLNetPreTrainedModel):
# class TFXLNetForQuestionAnswering(TFXLNetPreTrainedModel):
# r"""
# **start_positions**: (`optional`) ``torch.LongTensor`` of shape ``(batch_size,)``:
# Labels for position (index) of the start of the labelled span for computing the token classification loss.
# Positions are clamped to the length of the sequence (`sequence_length`).
# Position outside of the sequence are not taken into account for computing the loss.
# **end_positions**: (`optional`) ``torch.LongTensor`` of shape ``(batch_size,)``:
# Labels for position (index) of the end of the labelled span for computing the token classification loss.
# Positions are clamped to the length of the sequence (`sequence_length`).
# Position outside of the sequence are not taken into account for computing the loss.
# **is_impossible**: (`optional`) ``torch.LongTensor`` of shape ``(batch_size,)``:
# Labels whether a question has an answer or no answer (SQuAD 2.0)
# **cls_index**: (`optional`) ``torch.LongTensor`` of shape ``(batch_size,)``:
# Labels for position (index) of the classification token to use as input for computing plausibility of the answer.
# **p_mask**: (`optional`) ``torch.FloatTensor`` of shape ``(batch_size, sequence_length)``:
# Optional mask of tokens which can't be in answers (e.g. [CLS], [PAD], ...).
# 1.0 means token should be masked. 0.0 mean token is not masked.
......@@ -1054,29 +1063,18 @@ class TFXLNetModel(TFXLNetPreTrainedModel):
# loss, start_scores, end_scores = outputs[:2]
# """
# def __init__(self, config, **kwargs):
# super(XLNetForQuestionAnswering, self).__init__(config)
# def __init__(self, config, *inputs, **kwargs):
# super(TFXLNetForQuestionAnswering, self).__init__(config, *inputs, **kwargs)
# self.start_n_top = config.start_n_top
# self.end_n_top = config.end_n_top
# self.transformer = XLNetModel(config)
# self.start_logits = PoolerStartLogits(config)
# self.end_logits = PoolerEndLogits(config)
# self.answer_class = PoolerAnswerClass(config)
# self.init_weights()
# def forward(self, input_ids, attention_mask=None, mems=None, perm_mask=None, target_mapping=None,
# token_type_ids=None, input_mask=None, head_mask=None,
# start_positions=None, end_positions=None, is_impossible=None, cls_index=None, p_mask=None,):
# transformer_outputs = self.transformer(input_ids,
# attention_mask=attention_mask,
# mems=mems,
# perm_mask=perm_mask,
# target_mapping=target_mapping,
# token_type_ids=token_type_ids,
# input_mask=input_mask,
# head_mask=head_mask)
# self.transformer = TFXLNetMainLayer(config, name='transformer')
# self.start_logits = TFPoolerStartLogits(config, name='start_logits')
# self.end_logits = TFPoolerEndLogits(config, name='end_logits')
# self.answer_class = TFPoolerAnswerClass(config, name='answer_class')
# def call(self, inputs, training=False):
# transformer_outputs = self.transformer(inputs, training=training)
# hidden_states = transformer_outputs[0]
# start_logits = self.start_logits(hidden_states, p_mask=p_mask)
......
......@@ -1003,6 +1003,101 @@ class XLNetForSequenceClassification(XLNetPreTrainedModel):
return outputs # return (loss), logits, mems, (hidden states), (attentions)
@add_start_docstrings("""XLNet Model with a span classification head on top for extractive question-answering tasks like SQuAD (a linear layers on top of
the hidden-states output to compute `span start logits` and `span end logits`). """,
XLNET_START_DOCSTRING, XLNET_INPUTS_DOCSTRING)
class XLNetForQuestionAnsweringSimple(XLNetPreTrainedModel):
r"""
**start_positions**: (`optional`) ``torch.LongTensor`` of shape ``(batch_size,)``:
Labels for position (index) of the start of the labelled span for computing the token classification loss.
Positions are clamped to the length of the sequence (`sequence_length`).
Position outside of the sequence are not taken into account for computing the loss.
**end_positions**: (`optional`) ``torch.LongTensor`` of shape ``(batch_size,)``:
Labels for position (index) of the end of the labelled span for computing the token classification loss.
Positions are clamped to the length of the sequence (`sequence_length`).
Position outside of the sequence are not taken into account for computing the loss.
Outputs: `Tuple` comprising various elements depending on the configuration (config) and inputs:
**loss**: (`optional`, returned if both ``start_positions`` and ``end_positions`` are provided) ``torch.FloatTensor`` of shape ``(1,)``:
Classification loss as the sum of start token, end token (and is_impossible if provided) classification losses.
**start_scores**: ``torch.FloatTensor`` of shape ``(batch_size, sequence_length,)``
Span-start scores (before SoftMax).
**end_scores**: ``torch.FloatTensor`` of shape ``(batch_size, sequence_length,)``
Span-end scores (before SoftMax).
**mems**:
list of ``torch.FloatTensor`` (one for each layer):
that contains pre-computed hidden-states (key and values in the attention blocks) as computed by the model
if config.mem_len > 0 else tuple of None. Can be used to speed up sequential decoding and attend to longer context.
See details in the docstring of the `mems` input above.
**hidden_states**: (`optional`, returned when ``config.output_hidden_states=True``)
list of ``torch.FloatTensor`` (one for the output of each layer + the output of the embeddings)
of shape ``(batch_size, sequence_length, hidden_size)``:
Hidden-states of the model at the output of each layer plus the initial embedding outputs.
**attentions**: (`optional`, returned when ``config.output_attentions=True``)
list of ``torch.FloatTensor`` (one for each layer) of shape ``(batch_size, num_heads, sequence_length, sequence_length)``:
Attentions weights after the attention softmax, used to compute the weighted average in the self-attention heads.
Examples::
tokenizer = XLMTokenizer.from_pretrained('xlm-mlm-en-2048')
model = XLMForQuestionAnswering.from_pretrained('xlnet-large-cased')
input_ids = torch.tensor(tokenizer.encode("Hello, my dog is cute")).unsqueeze(0) # Batch size 1
start_positions = torch.tensor([1])
end_positions = torch.tensor([3])
outputs = model(input_ids, start_positions=start_positions, end_positions=end_positions)
loss, start_scores, end_scores = outputs[:2]
"""
def __init__(self, config):
super(XLNetForQuestionAnsweringSimple, self).__init__(config)
self.num_labels = config.num_labels
self.transformer = XLNetModel(config)
self.qa_outputs = nn.Linear(config.hidden_size, config.num_labels)
self.init_weights()
def forward(self, input_ids, attention_mask=None, mems=None, perm_mask=None, target_mapping=None,
token_type_ids=None, input_mask=None, head_mask=None,
start_positions=None, end_positions=None):
outputs = self.transformer(input_ids,
attention_mask=attention_mask,
mems=mems,
perm_mask=perm_mask,
target_mapping=target_mapping,
token_type_ids=token_type_ids,
input_mask=input_mask,
head_mask=head_mask)
sequence_output = outputs[0]
logits = self.qa_outputs(sequence_output)
start_logits, end_logits = logits.split(1, dim=-1)
start_logits = start_logits.squeeze(-1)
end_logits = end_logits.squeeze(-1)
outputs = (start_logits, end_logits,) + outputs[2:]
if start_positions is not None and end_positions is not None:
# If we are on multi-GPU, split add a dimension
if len(start_positions.size()) > 1:
start_positions = start_positions.squeeze(-1)
if len(end_positions.size()) > 1:
end_positions = end_positions.squeeze(-1)
# sometimes the start/end positions are outside our model inputs, we ignore these terms
ignored_index = start_logits.size(1)
start_positions.clamp_(0, ignored_index)
end_positions.clamp_(0, ignored_index)
loss_fct = CrossEntropyLoss(ignore_index=ignored_index)
start_loss = loss_fct(start_logits, start_positions)
end_loss = loss_fct(end_logits, end_positions)
total_loss = (start_loss + end_loss) / 2
outputs = (total_loss,) + outputs
return outputs # (loss), start_logits, end_logits, (hidden_states), (attentions)
@add_start_docstrings("""XLNet Model with a span classification head on top for extractive question-answering tasks like SQuAD (a linear layers on top of
the hidden-states output to compute `span start logits` and `span end logits`). """,
XLNET_START_DOCSTRING, XLNET_INPUTS_DOCSTRING)
......
......@@ -28,9 +28,10 @@ from pytorch_transformers import XLNetConfig, is_tf_available
if is_tf_available():
import tensorflow as tf
from pytorch_transformers.modeling_tf_xlnet import (TFXLNetModel, TF_XLNET_PRETRAINED_MODEL_ARCHIVE_MAP)
# XLNetLMHeadModel,
# XLNetForSequenceClassification, XLNetForQuestionAnswering)
from pytorch_transformers.modeling_tf_xlnet import (TFXLNetModel, TFXLNetLMHeadModel,
TFXLNetForSequenceClassification,
TFXLNetForQuestionAnsweringSimple,
TF_XLNET_PRETRAINED_MODEL_ARCHIVE_MAP)
else:
pytestmark = pytest.mark.skip("Require TensorFlow")
......@@ -39,9 +40,9 @@ from .configuration_common_test import ConfigTester
class TFXLNetModelTest(TFCommonTestCases.TFCommonModelTester):
all_model_classes=(TFXLNetModel, ) if is_tf_available() else ()
# all_model_classes=(TFXLNetModel, TFXLNetLMHeadModel,
# TFXLNetForSequenceClassification, TFXLNetForQuestionAnswering) if is_tf_available() else ()
all_model_classes=(TFXLNetModel, TFXLNetLMHeadModel,
TFXLNetForSequenceClassification,
TFXLNetForQuestionAnsweringSimple) if is_tf_available() else ()
test_pruning = False
class TFXLNetModelTester(object):
......@@ -169,128 +170,88 @@ class TFXLNetModelTest(TFCommonTestCases.TFCommonModelTester):
def create_and_check_xlnet_lm_head(self, config, input_ids_1, input_ids_2, input_ids_q, perm_mask, input_mask,
target_mapping, segment_ids, lm_labels, sequence_labels, is_impossible_labels):
pass
# model = XLNetLMHeadModel(config)
# model.eval()
# loss_1, all_logits_1, mems_1 = model(input_ids_1, token_type_ids=segment_ids, labels=lm_labels)
# loss_2, all_logits_2, mems_2 = model(input_ids_2, token_type_ids=segment_ids, labels=lm_labels, mems=mems_1)
# logits, _ = model(input_ids_q, perm_mask=perm_mask, target_mapping=target_mapping)
# result = {
# "loss_1": loss_1,
# "mems_1": mems_1,
# "all_logits_1": all_logits_1,
# "loss_2": loss_2,
# "mems_2": mems_2,
# "all_logits_2": all_logits_2,
# }
# self.parent.assertListEqual(
# list(result["loss_1"].size()),
# [])
# self.parent.assertListEqual(
# list(result["all_logits_1"].size()),
# [self.batch_size, self.seq_length, self.vocab_size])
# self.parent.assertListEqual(
# list(list(mem.size()) for mem in result["mems_1"]),
# [[self.seq_length, self.batch_size, self.hidden_size]] * self.num_hidden_layers)
# self.parent.assertListEqual(
# list(result["loss_2"].size()),
# [])
# self.parent.assertListEqual(
# list(result["all_logits_2"].size()),
# [self.batch_size, self.seq_length, self.vocab_size])
# self.parent.assertListEqual(
# list(list(mem.size()) for mem in result["mems_2"]),
# [[self.mem_len, self.batch_size, self.hidden_size]] * self.num_hidden_layers)
model = TFXLNetLMHeadModel(config)
inputs_1 = {'input_ids': input_ids_1,
'token_type_ids': segment_ids}
all_logits_1, mems_1 = model(inputs_1)
inputs_2 = {'input_ids': input_ids_2,
'mems': mems_1,
'token_type_ids': segment_ids}
all_logits_2, mems_2 = model(inputs_2)
inputs_3 = {'input_ids': input_ids_q,
'perm_mask': perm_mask,
'target_mapping': target_mapping}
logits, _ = model(inputs_3)
result = {
"mems_1": [mem.numpy() for mem in mems_1],
"all_logits_1": all_logits_1.numpy(),
"mems_2": [mem.numpy() for mem in mems_2],
"all_logits_2": all_logits_2.numpy(),
}
self.parent.assertListEqual(
list(result["all_logits_1"].shape),
[self.batch_size, self.seq_length, self.vocab_size])
self.parent.assertListEqual(
list(list(mem.shape) for mem in result["mems_1"]),
[[self.seq_length, self.batch_size, self.hidden_size]] * self.num_hidden_layers)
self.parent.assertListEqual(
list(result["all_logits_2"].shape),
[self.batch_size, self.seq_length, self.vocab_size])
self.parent.assertListEqual(
list(list(mem.shape) for mem in result["mems_2"]),
[[self.mem_len, self.batch_size, self.hidden_size]] * self.num_hidden_layers)
def create_and_check_xlnet_qa(self, config, input_ids_1, input_ids_2, input_ids_q, perm_mask, input_mask,
target_mapping, segment_ids, lm_labels, sequence_labels, is_impossible_labels):
pass
# model = XLNetForQuestionAnswering(config)
# model.eval()
# outputs = model(input_ids_1)
# start_top_log_probs, start_top_index, end_top_log_probs, end_top_index, cls_logits, mems = outputs
# outputs = model(input_ids_1, start_positions=sequence_labels,
# end_positions=sequence_labels,
# cls_index=sequence_labels,
# is_impossible=is_impossible_labels,
# p_mask=input_mask)
# outputs = model(input_ids_1, start_positions=sequence_labels,
# end_positions=sequence_labels,
# cls_index=sequence_labels,
# is_impossible=is_impossible_labels)
# total_loss, mems = outputs
# outputs = model(input_ids_1, start_positions=sequence_labels,
# end_positions=sequence_labels)
# total_loss, mems = outputs
# result = {
# "loss": total_loss,
# "start_top_log_probs": start_top_log_probs,
# "start_top_index": start_top_index,
# "end_top_log_probs": end_top_log_probs,
# "end_top_index": end_top_index,
# "cls_logits": cls_logits,
# "mems": mems,
# }
# self.parent.assertListEqual(
# list(result["loss"].size()),
# [])
# self.parent.assertListEqual(
# list(result["start_top_log_probs"].size()),
# [self.batch_size, model.config.start_n_top])
# self.parent.assertListEqual(
# list(result["start_top_index"].size()),
# [self.batch_size, model.config.start_n_top])
# self.parent.assertListEqual(
# list(result["end_top_log_probs"].size()),
# [self.batch_size, model.config.start_n_top * model.config.end_n_top])
# self.parent.assertListEqual(
# list(result["end_top_index"].size()),
# [self.batch_size, model.config.start_n_top * model.config.end_n_top])
# self.parent.assertListEqual(
# list(result["cls_logits"].size()),
# [self.batch_size])
# self.parent.assertListEqual(
# list(list(mem.size()) for mem in result["mems"]),
# [[self.seq_length, self.batch_size, self.hidden_size]] * self.num_hidden_layers)
model = TFXLNetForQuestionAnsweringSimple(config)
inputs = {'input_ids': input_ids_1,
'attention_mask': input_mask,
'token_type_ids': segment_ids}
start_logits, end_logits, mems = model(inputs)
result = {
"start_logits": start_logits.numpy(),
"end_logits": end_logits.numpy(),
"mems": [m.numpy() for m in mems],
}
self.parent.assertListEqual(
list(result["start_logits"].shape),
[self.batch_size, self.seq_length])
self.parent.assertListEqual(
list(result["end_logits"].shape),
[self.batch_size, self.seq_length])
self.parent.assertListEqual(
list(list(mem.shape) for mem in result["mems"]),
[[self.seq_length, self.batch_size, self.hidden_size]] * self.num_hidden_layers)
def create_and_check_xlnet_sequence_classif(self, config, input_ids_1, input_ids_2, input_ids_q, perm_mask, input_mask,
target_mapping, segment_ids, lm_labels, sequence_labels, is_impossible_labels):
pass
# model = XLNetForSequenceClassification(config)
# model.eval()
# logits, mems_1 = model(input_ids_1)
# loss, logits, mems_1 = model(input_ids_1, labels=sequence_labels)
# result = {
# "loss": loss,
# "mems_1": mems_1,
# "logits": logits,
# }
# self.parent.assertListEqual(
# list(result["loss"].size()),
# [])
# self.parent.assertListEqual(
# list(result["logits"].size()),
# [self.batch_size, self.type_sequence_label_size])
# self.parent.assertListEqual(
# list(list(mem.size()) for mem in result["mems_1"]),
# [[self.seq_length, self.batch_size, self.hidden_size]] * self.num_hidden_layers)
model = TFXLNetForSequenceClassification(config)
logits, mems_1 = model(input_ids_1)
result = {
"mems_1": [mem.numpy() for mem in mems_1],
"logits": logits.numpy(),
}
self.parent.assertListEqual(
list(result["logits"].shape),
[self.batch_size, self.type_sequence_label_size])
self.parent.assertListEqual(
list(list(mem.shape) for mem in result["mems_1"]),
[[self.seq_length, self.batch_size, self.hidden_size]] * self.num_hidden_layers)
def prepare_config_and_inputs_for_common(self):
config_and_inputs = self.prepare_config_and_inputs()
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment