"git@developer.sourcefind.cn:chenpangpang/transformers.git" did not exist on "4e6a3172cecef53f790f1c995c7569ca11e04444"
Commit c9591f6f authored by thomwolf

updated models input format + tests

parent c014d1f0
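In short, after this change the TF 2.0 models accept their tensor inputs in four equivalent ways: a single tensor of input ids, a tuple/list in positional order, a dict of named tensors, or plain keyword arguments (every secondary input now defaults to None in the call signatures). Below is a minimal sketch of the resulting call styles using TFBertModel; it is illustrative only, not part of this commit, and assumes pretrained TF weights are loadable via from_pretrained in this version of pytorch_transformers:

```python
import tensorflow as tf
from pytorch_transformers import TFBertModel  # assumes this export exists at this commit

model = TFBertModel.from_pretrained('bert-base-uncased')  # assumes TF weights are available
input_ids = tf.constant([[101, 7592, 2088, 102]])         # toy token ids
attention_mask = tf.constant([[1, 1, 1, 1]])

# 1) a single tensor of input ids
sequence_output, pooled_output = model(input_ids)[:2]

# 2) a list/tuple in the positional order of the call signature
sequence_output, pooled_output = model([input_ids, attention_mask])[:2]

# 3) a dict of named inputs
sequence_output, pooled_output = model({'input_ids': input_ids,
                                        'attention_mask': attention_mask})[:2]

# 4) keyword arguments, which the main layers now forward correctly
sequence_output, pooled_output = model(input_ids, attention_mask=attention_mask)[:2]
```

The new test_keyword_and_dict_args added to the common test cases below checks that the dict and keyword styles produce identical outputs.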
@@ -456,24 +456,23 @@ class TFBertMainLayer(tf.keras.layers.Layer):
     # def call(self, input_ids, attention_mask=None, token_type_ids=None,
     #          position_ids=None, head_mask=None, training=False):
-    def call(self, inputs, training=False):
-        if not isinstance(inputs, (dict, tuple, list)):
-            input_ids = inputs
-            attention_mask, head_mask, position_ids, token_type_ids = None, None, None, None
-        elif isinstance(inputs, (tuple, list)):
+    def call(self, inputs, attention_mask=None, token_type_ids=None, position_ids=None, head_mask=None, training=False):
+        if isinstance(inputs, (tuple, list)):
             input_ids = inputs[0]
-            attention_mask = inputs[1] if len(inputs) > 1 else None
-            token_type_ids = inputs[2] if len(inputs) > 2 else None
-            position_ids = inputs[3] if len(inputs) > 3 else None
-            head_mask = inputs[4] if len(inputs) > 4 else None
+            attention_mask = inputs[1] if len(inputs) > 1 else attention_mask
+            token_type_ids = inputs[2] if len(inputs) > 2 else token_type_ids
+            position_ids = inputs[3] if len(inputs) > 3 else position_ids
+            head_mask = inputs[4] if len(inputs) > 4 else head_mask
             assert len(inputs) <= 5, "Too many inputs."
-        else:
+        elif isinstance(inputs, dict):
             input_ids = inputs.get('input_ids')
-            attention_mask = inputs.get('attention_mask', None)
-            token_type_ids = inputs.get('token_type_ids', None)
-            position_ids = inputs.get('position_ids', None)
-            head_mask = inputs.get('head_mask', None)
+            attention_mask = inputs.get('attention_mask', attention_mask)
+            token_type_ids = inputs.get('token_type_ids', token_type_ids)
+            position_ids = inputs.get('position_ids', position_ids)
+            head_mask = inputs.get('head_mask', head_mask)
             assert len(inputs) <= 5, "Too many inputs."
+        else:
+            input_ids = inputs

         if attention_mask is None:
             attention_mask = tf.fill(tf.shape(input_ids), 1)
@@ -637,8 +636,8 @@ class TFBertModel(TFBertPreTrainedModel):
         super(TFBertModel, self).__init__(config, *inputs, **kwargs)
         self.bert = TFBertMainLayer(config, name='bert')

-    def call(self, inputs, training=False):
-        outputs = self.bert(inputs, training=training)
+    def call(self, inputs, **kwargs):
+        outputs = self.bert(inputs, **kwargs)
         return outputs
@@ -676,11 +675,11 @@ class TFBertForPreTraining(TFBertPreTrainedModel):
         self.nsp = TFBertNSPHead(config, name='nsp___cls')
         self.mlm = TFBertMLMHead(config, self.bert.embeddings, name='mlm___cls')

-    def call(self, inputs, training=False):
-        outputs = self.bert(inputs, training=training)
+    def call(self, inputs, **kwargs):
+        outputs = self.bert(inputs, **kwargs)
         sequence_output, pooled_output = outputs[:2]
-        prediction_scores = self.mlm(sequence_output, training=training)
+        prediction_scores = self.mlm(sequence_output, training=kwargs.get('training', False))
         seq_relationship_score = self.nsp(pooled_output)

         outputs = (prediction_scores, seq_relationship_score,) + outputs[2:]  # add hidden states and attention if they are here
@@ -718,11 +717,11 @@ class TFBertForMaskedLM(TFBertPreTrainedModel):
         self.bert = TFBertMainLayer(config, name='bert')
         self.mlm = TFBertMLMHead(config, self.bert.embeddings, name='mlm___cls')

-    def call(self, inputs, training=False):
-        outputs = self.bert(inputs, training=training)
+    def call(self, inputs, **kwargs):
+        outputs = self.bert(inputs, **kwargs)
         sequence_output = outputs[0]
-        prediction_scores = self.mlm(sequence_output, training=training)
+        prediction_scores = self.mlm(sequence_output, training=kwargs.get('training', False))

         outputs = (prediction_scores,) + outputs[2:]  # Add hidden states and attention if they are here
@@ -761,8 +760,8 @@ class TFBertForNextSentencePrediction(TFBertPreTrainedModel):
         self.bert = TFBertMainLayer(config, name='bert')
         self.nsp = TFBertNSPHead(config, name='nsp___cls')

-    def call(self, inputs, training=False):
-        outputs = self.bert(inputs, training=training)
+    def call(self, inputs, **kwargs):
+        outputs = self.bert(inputs, **kwargs)
         pooled_output = outputs[1]

         seq_relationship_score = self.nsp(pooled_output)
@@ -805,12 +804,12 @@ class TFBertForSequenceClassification(TFBertPreTrainedModel):
         self.dropout = tf.keras.layers.Dropout(config.hidden_dropout_prob)
         self.classifier = tf.keras.layers.Dense(config.num_labels, name='classifier')

-    def call(self, inputs, training=False):
-        outputs = self.bert(inputs, training=training)
+    def call(self, inputs, **kwargs):
+        outputs = self.bert(inputs, **kwargs)
         pooled_output = outputs[1]

-        pooled_output = self.dropout(pooled_output, training=training)
+        pooled_output = self.dropout(pooled_output, training=kwargs.get('training', False))
         logits = self.classifier(pooled_output)

         outputs = (logits,) + outputs[2:]  # add hidden states and attention if they are here
@@ -852,24 +851,23 @@ class TFBertForMultipleChoice(TFBertPreTrainedModel):
         self.dropout = tf.keras.layers.Dropout(config.hidden_dropout_prob)
         self.classifier = tf.keras.layers.Dense(1, name='classifier')

-    def call(self, inputs, training=False):
-        if not isinstance(inputs, (dict, tuple, list)):
-            input_ids = inputs
-            attention_mask, head_mask, position_ids, token_type_ids = None, None, None, None
-        elif isinstance(inputs, (tuple, list)):
+    def call(self, inputs, attention_mask=None, token_type_ids=None, position_ids=None, head_mask=None, training=False):
+        if isinstance(inputs, (tuple, list)):
             input_ids = inputs[0]
-            attention_mask = inputs[1] if len(inputs) > 1 else None
-            token_type_ids = inputs[2] if len(inputs) > 2 else None
-            position_ids = inputs[3] if len(inputs) > 3 else None
-            head_mask = inputs[4] if len(inputs) > 4 else None
+            attention_mask = inputs[1] if len(inputs) > 1 else attention_mask
+            token_type_ids = inputs[2] if len(inputs) > 2 else token_type_ids
+            position_ids = inputs[3] if len(inputs) > 3 else position_ids
+            head_mask = inputs[4] if len(inputs) > 4 else head_mask
             assert len(inputs) <= 5, "Too many inputs."
-        else:
+        elif isinstance(inputs, dict):
             input_ids = inputs.get('input_ids')
-            attention_mask = inputs.get('attention_mask', None)
-            token_type_ids = inputs.get('token_type_ids', None)
-            position_ids = inputs.get('position_ids', None)
-            head_mask = inputs.get('head_mask', None)
+            attention_mask = inputs.get('attention_mask', attention_mask)
+            token_type_ids = inputs.get('token_type_ids', token_type_ids)
+            position_ids = inputs.get('position_ids', position_ids)
+            head_mask = inputs.get('head_mask', head_mask)
             assert len(inputs) <= 5, "Too many inputs."
+        else:
+            input_ids = inputs

         num_choices = tf.shape(input_ids)[1]
         seq_length = tf.shape(input_ids)[2]
@@ -927,12 +925,12 @@ class TFBertForTokenClassification(TFBertPreTrainedModel):
         self.dropout = tf.keras.layers.Dropout(config.hidden_dropout_prob)
         self.classifier = tf.keras.layers.Dense(config.num_labels, name='classifier')

-    def call(self, inputs, training=False):
-        outputs = self.bert(inputs, training=training)
+    def call(self, inputs, **kwargs):
+        outputs = self.bert(inputs, **kwargs)
         sequence_output = outputs[0]

-        sequence_output = self.dropout(sequence_output, training=training)
+        sequence_output = self.dropout(sequence_output, training=kwargs.get('training', False))
         logits = self.classifier(sequence_output)

         outputs = (logits,) + outputs[2:]  # add hidden states and attention if they are here
@@ -976,8 +974,8 @@ class TFBertForQuestionAnswering(TFBertPreTrainedModel):
         self.bert = TFBertMainLayer(config, name='bert')
         self.qa_outputs = tf.keras.layers.Dense(config.num_labels, name='qa_outputs')

-    def call(self, inputs, training=False):
-        outputs = self.bert(inputs, training=training)
+    def call(self, inputs, **kwargs):
+        outputs = self.bert(inputs, **kwargs)
         sequence_output = outputs[0]
...
@@ -418,20 +418,19 @@ class TFDistilBertMainLayer(tf.keras.layers.Layer):
     def _prune_heads(self, heads_to_prune):
         raise NotImplementedError

-    def call(self, inputs, training=False):
-        if not isinstance(inputs, (dict, tuple, list)):
-            input_ids = inputs
-            (attention_mask, head_mask) = None, None
-        elif isinstance(inputs, (tuple, list)):
+    def call(self, inputs, attention_mask=None, head_mask=None, training=False):
+        if isinstance(inputs, (tuple, list)):
             input_ids = inputs[0]
-            attention_mask = inputs[1] if len(inputs) > 1 else None
-            head_mask = inputs[2] if len(inputs) > 2 else None
+            attention_mask = inputs[1] if len(inputs) > 1 else attention_mask
+            head_mask = inputs[2] if len(inputs) > 2 else head_mask
             assert len(inputs) <= 3, "Too many inputs."
-        else:
+        elif isinstance(inputs, dict):
             input_ids = inputs.get('input_ids')
-            attention_mask = inputs.get('attention_mask', None)
-            head_mask = inputs.get('head_mask', None)
+            attention_mask = inputs.get('attention_mask', attention_mask)
+            head_mask = inputs.get('head_mask', head_mask)
             assert len(inputs) <= 3, "Too many inputs."
+        else:
+            input_ids = inputs

         if attention_mask is None:
             attention_mask = tf.ones(shape_list(input_ids))  # (bs, seq_length)
@@ -532,8 +531,8 @@ class TFDistilBertModel(TFDistilBertPreTrainedModel):
         super(TFDistilBertModel, self).__init__(config, *inputs, **kwargs)
         self.distilbert = TFDistilBertMainLayer(config, name="distilbert")  # Embeddings

-    def call(self, inputs, training=False):
-        outputs = self.distilbert(inputs, training=training)
+    def call(self, inputs, **kwargs):
+        outputs = self.distilbert(inputs, **kwargs)
         return outputs
@@ -603,18 +602,17 @@ class TFDistilBertForMaskedLM(TFDistilBertPreTrainedModel):
         self.vocab_layer_norm = tf.keras.layers.LayerNormalization(epsilon=1e-12, name="vocab_layer_norm")
         self.vocab_projector = TFDistilBertLMHead(config, self.distilbert.embeddings, name="vocab_projector")

-    def call(self, inputs, training=False):
-        dlbrt_output = self.distilbert(inputs, training=training)
-        hidden_states = dlbrt_output[0]  # (bs, seq_length, dim)
+    def call(self, inputs, **kwargs):
+        distilbert_output = self.distilbert(inputs, **kwargs)
+        hidden_states = distilbert_output[0]  # (bs, seq_length, dim)
         prediction_logits = self.vocab_transform(hidden_states)  # (bs, seq_length, dim)
         prediction_logits = self.act(prediction_logits)  # (bs, seq_length, dim)
         prediction_logits = self.vocab_layer_norm(prediction_logits)  # (bs, seq_length, dim)
         prediction_logits = self.vocab_projector(prediction_logits)
-        outputs = (prediction_logits, ) + dlbrt_output[1:]
-        return outputs  # logits, (hidden_states), (attentions)
+        outputs = (prediction_logits,) + distilbert_output[1:]
+        return outputs  # prediction_logits, (all hidden_states), (all attentions)


 @add_start_docstrings("""DistilBert Model transformer with a sequence classification/regression head on top (a linear layer on top of
@@ -660,12 +658,13 @@ class TFDistilBertForSequenceClassification(TFDistilBertPreTrainedModel):
         self.classifier = tf.keras.layers.Dense(config.num_labels, name="classifier")
         self.dropout = tf.keras.layers.Dropout(config.seq_classif_dropout)

-    def call(self, inputs, training=False):
-        distilbert_output = self.distilbert(inputs, training=training)
+    def call(self, inputs, **kwargs):
+        distilbert_output = self.distilbert(inputs, **kwargs)
         hidden_state = distilbert_output[0]  # (bs, seq_len, dim)
         pooled_output = hidden_state[:, 0]  # (bs, dim)
         pooled_output = self.pre_classifier(pooled_output)  # (bs, dim)
-        pooled_output = self.dropout(pooled_output, training=training)  # (bs, dim)
+        pooled_output = self.dropout(pooled_output, training=kwargs.get('training', False))  # (bs, dim)
         logits = self.classifier(pooled_output)  # (bs, dim)

         outputs = (logits,) + distilbert_output[1:]
@@ -720,11 +719,11 @@ class TFDistilBertForQuestionAnswering(TFDistilBertPreTrainedModel):
         assert config.num_labels == 2
         self.dropout = tf.keras.layers.Dropout(config.qa_dropout)

-    def call(self, inputs, training=False):
-        distilbert_output = self.distilbert(inputs, training=training)
+    def call(self, inputs, **kwargs):
+        distilbert_output = self.distilbert(inputs, **kwargs)
         hidden_states = distilbert_output[0]  # (bs, max_query_len, dim)
-        hidden_states = self.dropout(hidden_states, training=training)  # (bs, max_query_len, dim)
+        hidden_states = self.dropout(hidden_states, training=kwargs.get('training', False))  # (bs, max_query_len, dim)
         logits = self.qa_outputs(hidden_states)  # (bs, max_query_len, 2)
         start_logits, end_logits = tf.split(logits, 2, axis=-1)
         start_logits = tf.squeeze(start_logits, axis=-1)
...
@@ -230,26 +230,25 @@ class TFGPT2MainLayer(tf.keras.layers.Layer):
         """
         raise NotImplementedError

-    def call(self, inputs, training=False):
-        if not isinstance(inputs, (dict, tuple, list)):
-            input_ids = inputs
-            past, attention_mask, token_type_ids, position_ids, head_mask = None, None, None, None, None
-        elif isinstance(inputs, (tuple, list)):
+    def call(self, inputs, past=None, attention_mask=None, token_type_ids=None, position_ids=None, head_mask=None, training=False):
+        if isinstance(inputs, (tuple, list)):
             input_ids = inputs[0]
-            past = inputs[1] if len(inputs) > 1 else None
-            attention_mask = inputs[2] if len(inputs) > 2 else None
-            token_type_ids = inputs[3] if len(inputs) > 3 else None
-            position_ids = inputs[4] if len(inputs) > 4 else None
-            head_mask = inputs[5] if len(inputs) > 5 else None
+            past = inputs[1] if len(inputs) > 1 else past
+            attention_mask = inputs[2] if len(inputs) > 2 else attention_mask
+            token_type_ids = inputs[3] if len(inputs) > 3 else token_type_ids
+            position_ids = inputs[4] if len(inputs) > 4 else position_ids
+            head_mask = inputs[5] if len(inputs) > 5 else head_mask
             assert len(inputs) <= 6, "Too many inputs."
-        else:
+        elif isinstance(inputs, dict):
             input_ids = inputs.get('input_ids')
-            past = inputs.get('past', None)
-            attention_mask = inputs.get('attention_mask', None)
-            token_type_ids = inputs.get('token_type_ids', None)
-            position_ids = inputs.get('position_ids', None)
-            head_mask = inputs.get('head_mask', None)
+            past = inputs.get('past', past)
+            attention_mask = inputs.get('attention_mask', attention_mask)
+            token_type_ids = inputs.get('token_type_ids', token_type_ids)
+            position_ids = inputs.get('position_ids', position_ids)
+            head_mask = inputs.get('head_mask', head_mask)
             assert len(inputs) <= 6, "Too many inputs."
+        else:
+            input_ids = inputs

         if past is None:
             past_length = 0
@@ -442,8 +441,8 @@ class TFGPT2Model(TFGPT2PreTrainedModel):
         super(TFGPT2Model, self).__init__(config, *inputs, **kwargs)
         self.transformer = TFGPT2MainLayer(config, name='transformer')

-    def call(self, inputs, training=False):
-        outputs = self.transformer(inputs, training=training)
+    def call(self, inputs, **kwargs):
+        outputs = self.transformer(inputs, **kwargs)
         return outputs
@@ -483,8 +482,8 @@ class TFGPT2LMHeadModel(TFGPT2PreTrainedModel):
         super(TFGPT2LMHeadModel, self).__init__(config, *inputs, **kwargs)
         self.transformer = TFGPT2MainLayer(config, name='transformer')

-    def call(self, inputs, training=False):
-        transformer_outputs = self.transformer(inputs, training=training)
+    def call(self, inputs, **kwargs):
+        transformer_outputs = self.transformer(inputs, **kwargs)
         hidden_states = transformer_outputs[0]

         lm_logits = self.transformer.wte(hidden_states, mode="linear")
@@ -551,28 +550,27 @@ class TFGPT2DoubleHeadsModel(TFGPT2PreTrainedModel):
         self.transformer = TFGPT2MainLayer(config, name='transformer')
         self.multiple_choice_head = TFSequenceSummary(config, name='multiple_choice_head')

-    def call(self, inputs, training=False):
-        if not isinstance(inputs, (dict, tuple, list)):
-            input_ids = inputs
-            mc_token_ids, past, attention_mask, token_type_ids, position_ids, head_mask = None, None, None, None, None
-        elif isinstance(inputs, (tuple, list)):
+    def call(self, inputs, past=None, attention_mask=None, token_type_ids=None, position_ids=None, head_mask=None, mc_token_ids=None, training=False):
+        if isinstance(inputs, (tuple, list)):
             input_ids = inputs[0]
-            mc_token_ids = inputs[1] if len(inputs) > 1 else None
-            past = inputs[2] if len(inputs) > 2 else None
-            attention_mask = inputs[3] if len(inputs) > 3 else None
-            token_type_ids = inputs[4] if len(inputs) > 4 else None
-            position_ids = inputs[5] if len(inputs) > 5 else None
-            head_mask = inputs[6] if len(inputs) > 6 else None
+            past = inputs[1] if len(inputs) > 1 else past
+            attention_mask = inputs[2] if len(inputs) > 2 else attention_mask
+            token_type_ids = inputs[3] if len(inputs) > 3 else token_type_ids
+            position_ids = inputs[4] if len(inputs) > 4 else position_ids
+            head_mask = inputs[5] if len(inputs) > 5 else head_mask
+            mc_token_ids = inputs[6] if len(inputs) > 6 else mc_token_ids
             assert len(inputs) <= 7, "Too many inputs."
-        else:
+        elif isinstance(inputs, dict):
             input_ids = inputs.get('input_ids')
-            mc_token_ids = inputs.get('mc_token_ids', None)
-            past = inputs.get('past', None)
-            attention_mask = inputs.get('attention_mask', None)
-            token_type_ids = inputs.get('token_type_ids', None)
-            position_ids = inputs.get('position_ids', None)
-            head_mask = inputs.get('head_mask', None)
+            past = inputs.get('past', past)
+            attention_mask = inputs.get('attention_mask', attention_mask)
+            token_type_ids = inputs.get('token_type_ids', token_type_ids)
+            position_ids = inputs.get('position_ids', position_ids)
+            head_mask = inputs.get('head_mask', head_mask)
+            mc_token_ids = inputs.get('mc_token_ids', mc_token_ids)
             assert len(inputs) <= 7, "Too many inputs."
+        else:
+            input_ids = inputs

         input_shapes = shape_list(input_ids)
...
@@ -229,24 +229,23 @@ class TFOpenAIGPTMainLayer(tf.keras.layers.Layer):
         """
         raise NotImplementedError

-    def call(self, inputs, training=False):
-        if not isinstance(inputs, (dict, tuple, list)):
-            input_ids = inputs
-            attention_mask, token_type_ids, position_ids, head_mask = None, None, None, None
-        elif isinstance(inputs, (tuple, list)):
+    def call(self, inputs, attention_mask=None, token_type_ids=None, position_ids=None, head_mask=None, training=False):
+        if isinstance(inputs, (tuple, list)):
             input_ids = inputs[0]
-            attention_mask = inputs[1] if len(inputs) > 1 else None
-            token_type_ids = inputs[2] if len(inputs) > 2 else None
-            position_ids = inputs[3] if len(inputs) > 3 else None
-            head_mask = inputs[4] if len(inputs) > 4 else None
+            attention_mask = inputs[1] if len(inputs) > 1 else attention_mask
+            token_type_ids = inputs[2] if len(inputs) > 2 else token_type_ids
+            position_ids = inputs[3] if len(inputs) > 3 else position_ids
+            head_mask = inputs[4] if len(inputs) > 4 else head_mask
             assert len(inputs) <= 5, "Too many inputs."
-        else:
+        elif isinstance(inputs, dict):
             input_ids = inputs.get('input_ids')
-            attention_mask = inputs.get('attention_mask', None)
-            token_type_ids = inputs.get('token_type_ids', None)
-            position_ids = inputs.get('position_ids', None)
-            head_mask = inputs.get('head_mask', None)
+            attention_mask = inputs.get('attention_mask', attention_mask)
+            token_type_ids = inputs.get('token_type_ids', token_type_ids)
+            position_ids = inputs.get('position_ids', position_ids)
+            head_mask = inputs.get('head_mask', head_mask)
             assert len(inputs) <= 5, "Too many inputs."
+        else:
+            input_ids = inputs

         if position_ids is None:
             position_ids = tf.range(shape_list(input_ids)[-1], dtype=tf.int32)[tf.newaxis, :]
@@ -420,8 +419,8 @@ class TFOpenAIGPTModel(TFOpenAIGPTPreTrainedModel):
         super(TFOpenAIGPTModel, self).__init__(config, *inputs, **kwargs)
         self.transformer = TFOpenAIGPTMainLayer(config, name='transformer')

-    def call(self, inputs, training=False):
-        outputs = self.transformer(inputs, training=training)
+    def call(self, inputs, **kwargs):
+        outputs = self.transformer(inputs, **kwargs)
         return outputs
@@ -455,8 +454,8 @@ class TFOpenAIGPTLMHeadModel(TFOpenAIGPTPreTrainedModel):
         super(TFOpenAIGPTLMHeadModel, self).__init__(config, *inputs, **kwargs)
         self.transformer = TFOpenAIGPTMainLayer(config, name='transformer')

-    def call(self, inputs, training=False):
-        transformer_outputs = self.transformer(inputs, training=training)
+    def call(self, inputs, **kwargs):
+        transformer_outputs = self.transformer(inputs, **kwargs)
         hidden_states = transformer_outputs[0]

         lm_logits = self.transformer.tokens_embed(hidden_states, mode="linear")
@@ -511,26 +510,25 @@ class TFOpenAIGPTDoubleHeadsModel(TFOpenAIGPTPreTrainedModel):
         self.transformer = TFOpenAIGPTMainLayer(config, name='transformer')
         self.multiple_choice_head = TFSequenceSummary(config, name='multiple_choice_head')

-    def call(self, inputs, training=False):
-        if not isinstance(inputs, (dict, tuple, list)):
-            input_ids = inputs
-            mc_token_ids, attention_mask, token_type_ids, position_ids, head_mask = None, None, None, None
-        elif isinstance(inputs, (tuple, list)):
+    def call(self, inputs, attention_mask=None, token_type_ids=None, position_ids=None, head_mask=None, mc_token_ids=None, training=False):
+        if isinstance(inputs, (tuple, list)):
             input_ids = inputs[0]
-            mc_token_ids = inputs[1] if len(inputs) > 1 else None
-            attention_mask = inputs[2] if len(inputs) > 2 else None
-            token_type_ids = inputs[3] if len(inputs) > 3 else None
-            position_ids = inputs[4] if len(inputs) > 4 else None
-            head_mask = inputs[5] if len(inputs) > 5 else None
+            attention_mask = inputs[1] if len(inputs) > 1 else attention_mask
+            token_type_ids = inputs[2] if len(inputs) > 2 else token_type_ids
+            position_ids = inputs[3] if len(inputs) > 3 else position_ids
+            head_mask = inputs[4] if len(inputs) > 4 else head_mask
+            mc_token_ids = inputs[5] if len(inputs) > 5 else mc_token_ids
             assert len(inputs) <= 6, "Too many inputs."
-        else:
+        elif isinstance(inputs, dict):
             input_ids = inputs.get('input_ids')
-            mc_token_ids = inputs.get('mc_token_ids', None)
-            attention_mask = inputs.get('attention_mask', None)
-            token_type_ids = inputs.get('token_type_ids', None)
-            position_ids = inputs.get('position_ids', None)
-            head_mask = inputs.get('head_mask', None)
+            attention_mask = inputs.get('attention_mask', attention_mask)
+            token_type_ids = inputs.get('token_type_ids', token_type_ids)
+            position_ids = inputs.get('position_ids', position_ids)
+            head_mask = inputs.get('head_mask', head_mask)
+            mc_token_ids = inputs.get('mc_token_ids', mc_token_ids)
             assert len(inputs) <= 6, "Too many inputs."
+        else:
+            input_ids = inputs

         input_shapes = shape_list(input_ids)
...
@@ -73,21 +73,21 @@ class TFRobertaMainLayer(TFBertMainLayer):
         super(TFRobertaMainLayer, self).__init__(config, **kwargs)
         self.embeddings = TFRobertaEmbeddings(config, name='embeddings')

-    def call(self, inputs, training=False):
+    def call(self, inputs, **kwargs):
         # Check that input_ids starts with control token
-        if not isinstance(inputs, (dict, tuple, list)):
-            input_ids = inputs
-        elif isinstance(inputs, (tuple, list)):
+        if isinstance(inputs, (tuple, list)):
             input_ids = inputs[0]
-        else:
+        elif isinstance(inputs, dict):
             input_ids = inputs.get('input_ids')
+        else:
+            input_ids = inputs

         if tf.not_equal(tf.reduce_sum(input_ids[:, 0]), 0):
             logger.warning("A sequence with no special tokens has been passed to the RoBERTa model. "
                            "This model requires special tokens in order to work. "
                            "Please specify add_special_tokens=True in your encoding.")

-        return super(TFRobertaMainLayer, self).call(inputs, training=training)
+        return super(TFRobertaMainLayer, self).call(inputs, **kwargs)


 class TFRobertaPreTrainedModel(TFPreTrainedModel):
@@ -203,8 +203,8 @@ class TFRobertaModel(TFRobertaPreTrainedModel):
         super(TFRobertaModel, self).__init__(config, *inputs, **kwargs)
         self.roberta = TFRobertaMainLayer(config, name='roberta')

-    def call(self, inputs, training=False):
-        outputs = self.roberta(inputs, training=training)
+    def call(self, inputs, **kwargs):
+        outputs = self.roberta(inputs, **kwargs)
         return outputs
@@ -277,8 +277,8 @@ class TFRobertaForMaskedLM(TFRobertaPreTrainedModel):
         self.roberta = TFRobertaMainLayer(config, name="roberta")
         self.lm_head = TFRobertaLMHead(config, self.roberta.embeddings, name="lm_head")

-    def call(self, inputs, training=False):
-        outputs = self.roberta(inputs, training=training)
+    def call(self, inputs, **kwargs):
+        outputs = self.roberta(inputs, **kwargs)
         sequence_output = outputs[0]
         prediction_scores = self.lm_head(sequence_output)
@@ -347,8 +347,9 @@ class TFRobertaForSequenceClassification(TFRobertaPreTrainedModel):
         self.roberta = TFRobertaMainLayer(config, name="roberta")
         self.classifier = TFRobertaClassificationHead(config, name="classifier")

-    def call(self, inputs, training=False):
-        outputs = self.roberta(inputs, training=training)
+    def call(self, inputs, **kwargs):
+        outputs = self.roberta(inputs, **kwargs)
         sequence_output = outputs[0]
         logits = self.classifier(sequence_output, training=training)
...
@@ -447,20 +447,19 @@ class TFTransfoXLMainLayer(tf.keras.layers.Layer):
         return new_mems

-    def call(self, inputs, training=False):
-        if not isinstance(inputs, (dict, tuple, list)):
-            input_ids = inputs
-            mems, head_mask = None, None
-        elif isinstance(inputs, (tuple, list)):
+    def call(self, inputs, mems=None, head_mask=None, training=False):
+        if isinstance(inputs, (tuple, list)):
             input_ids = inputs[0]
-            mems = inputs[1] if len(inputs) > 1 else None
-            head_mask = inputs[2] if len(inputs) > 2 else None
+            mems = inputs[1] if len(inputs) > 1 else mems
+            head_mask = inputs[2] if len(inputs) > 2 else head_mask
             assert len(inputs) <= 3, "Too many inputs."
-        else:
+        elif isinstance(inputs, dict):
             input_ids = inputs.get('input_ids')
-            mems = inputs.get('mems', None)
-            head_mask = inputs.get('head_mask', None)
+            mems = inputs.get('mems', mems)
+            head_mask = inputs.get('head_mask', head_mask)
             assert len(inputs) <= 3, "Too many inputs."
+        else:
+            input_ids = inputs

         # the original code for Transformer-XL used shapes [len, bsz] but we want a unified interface in the library
         # so we transpose here from shape [bsz, len] to shape [len, bsz]
@@ -632,8 +631,8 @@ class TFTransfoXLModel(TFTransfoXLPreTrainedModel):
         super(TFTransfoXLModel, self).__init__(config, *inputs, **kwargs)
         self.transformer = TFTransfoXLMainLayer(config, name='transformer')

-    def call(self, inputs, training=False, **kwargs):
-        outputs = self.transformer(inputs, training=training, **kwargs)
+    def call(self, inputs, **kwargs):
+        outputs = self.transformer(inputs, **kwargs)
         return outputs
@@ -694,22 +693,21 @@ class TFTransfoXLLMHeadModel(TFTransfoXLPreTrainedModel):
     def init_mems(self, data):
         return self.transformer.init_mems(data)

-    def call(self, inputs, training=False):
-        if not isinstance(inputs, (dict, tuple, list)):
-            input_ids = inputs
-            mems, head_mask, labels = None, None, None
-        elif isinstance(inputs, (tuple, list)):
+    def call(self, inputs, mems=None, head_mask=None, labels=None, training=False):
+        if isinstance(inputs, (tuple, list)):
             input_ids = inputs[0]
-            mems = inputs[1] if len(inputs) > 1 else None
-            head_mask = inputs[2] if len(inputs) > 2 else None
-            labels = inputs[3] if len(inputs) > 3 else None
+            mems = inputs[1] if len(inputs) > 1 else mems
+            head_mask = inputs[2] if len(inputs) > 2 else head_mask
+            labels = inputs[3] if len(inputs) > 3 else labels
             assert len(inputs) <= 4, "Too many inputs."
-        else:
+        elif isinstance(inputs, dict):
             input_ids = inputs.get('input_ids')
-            mems = inputs.get('mems', None)
-            head_mask = inputs.get('head_mask', None)
-            labels = inputs.get('labels', None)
+            mems = inputs.get('mems', mems)
+            head_mask = inputs.get('head_mask', head_mask)
+            labels = inputs.get('labels', labels)
             assert len(inputs) <= 4, "Too many inputs."
+        else:
+            input_ids = inputs

         bsz, tgt_len = shape_list(input_ids)[:2]
...
@@ -294,31 +294,31 @@ class TFXLMMainLayer(tf.keras.layers.Layer):
         """
         raise NotImplementedError

-    def call(self, inputs, training=False):  # removed: src_enc=None, src_len=None
-        if not isinstance(inputs, (dict, tuple, list)):
-            input_ids = inputs
-            (attention_mask, langs, token_type_ids, position_ids,
-             lengths, cache, head_mask) = None, None, None, None, None, None, None
-        elif isinstance(inputs, (tuple, list)):
+    def call(self, inputs, attention_mask=None, langs=None, token_type_ids=None,
+             position_ids=None, lengths=None, cache=None, head_mask=None,
+             training=False):  # removed: src_enc=None, src_len=None
+        if isinstance(inputs, (tuple, list)):
             input_ids = inputs[0]
-            attention_mask = inputs[1] if len(inputs) > 1 else None
-            langs = inputs[2] if len(inputs) > 2 else None
-            token_type_ids = inputs[3] if len(inputs) > 3 else None
-            position_ids = inputs[4] if len(inputs) > 4 else None
-            lengths = inputs[5] if len(inputs) > 5 else None
-            cache = inputs[6] if len(inputs) > 6 else None
-            head_mask = inputs[7] if len(inputs) > 7 else None
+            attention_mask = inputs[1] if len(inputs) > 1 else attention_mask
+            langs = inputs[2] if len(inputs) > 2 else langs
+            token_type_ids = inputs[3] if len(inputs) > 3 else token_type_ids
+            position_ids = inputs[4] if len(inputs) > 4 else position_ids
+            lengths = inputs[5] if len(inputs) > 5 else lengths
+            cache = inputs[6] if len(inputs) > 6 else cache
+            head_mask = inputs[7] if len(inputs) > 7 else head_mask
             assert len(inputs) <= 8, "Too many inputs."
-        else:
+        elif isinstance(inputs, dict):
             input_ids = inputs.get('input_ids')
-            attention_mask = inputs.get('attention_mask', None)
-            langs = inputs.get('langs', None)
-            token_type_ids = inputs.get('token_type_ids', None)
-            position_ids = inputs.get('position_ids', None)
-            lengths = inputs.get('lengths', None)
-            cache = inputs.get('cache', None)
-            head_mask = inputs.get('head_mask', None)
+            attention_mask = inputs.get('attention_mask', attention_mask)
+            langs = inputs.get('langs', langs)
+            token_type_ids = inputs.get('token_type_ids', token_type_ids)
+            position_ids = inputs.get('position_ids', position_ids)
+            lengths = inputs.get('lengths', lengths)
+            cache = inputs.get('cache', cache)
+            head_mask = inputs.get('head_mask', head_mask)
             assert len(inputs) <= 8, "Too many inputs."
+        else:
+            input_ids = inputs

         if lengths is None:
             lengths = tf.reduce_sum(tf.cast(tf.not_equal(input_ids, self.pad_index), dtype=tf.int32), axis=1)
@@ -538,8 +538,8 @@ class TFXLMModel(TFXLMPreTrainedModel):
         super(TFXLMModel, self).__init__(config, *inputs, **kwargs)
         self.transformer = TFXLMMainLayer(config, name='transformer')

-    def call(self, inputs, training=False):
-        outputs = self.transformer(inputs, training=training)
+    def call(self, inputs, **kwargs):
+        outputs = self.transformer(inputs, **kwargs)
         return outputs
@@ -619,8 +619,8 @@ class TFXLMWithLMHeadModel(TFXLMPreTrainedModel):
         self.pred_layer = TFXLMPredLayer(config, self.transformer.embeddings, name='pred_layer_._proj')

-    def call(self, inputs, training=False):
-        transformer_outputs = self.transformer(inputs, training=training)
+    def call(self, inputs, **kwargs):
+        transformer_outputs = self.transformer(inputs, **kwargs)
         output = transformer_outputs[0]
         outputs = self.pred_layer(output)
@@ -670,8 +670,8 @@ class TFXLMForSequenceClassification(TFXLMPreTrainedModel):
         self.transformer = TFXLMMainLayer(config, name='transformer')
         self.sequence_summary = TFSequenceSummary(config, name='sequence_summary')

-    def call(self, inputs, training=False):
-        transformer_outputs = self.transformer(inputs, training=training)
+    def call(self, inputs, **kwargs):
+        transformer_outputs = self.transformer(inputs, **kwargs)
         output = transformer_outputs[0]

         logits = self.sequence_summary(output)
@@ -731,8 +731,8 @@ class TFXLMForQuestionAnsweringSimple(TFXLMPreTrainedModel):
         self.transformer = TFXLMMainLayer(config, name='transformer')
         self.qa_outputs = tf.keras.layers.Dense(config.num_labels, name='qa_outputs')

-    def call(self, inputs, training=False):
-        transformer_outputs = self.transformer(inputs, training=training)
+    def call(self, inputs, **kwargs):
+        transformer_outputs = self.transformer(inputs, **kwargs)
         sequence_output = transformer_outputs[0]
...
@@ -489,31 +489,30 @@ class TFXLNetMainLayer(tf.keras.layers.Layer):
         return pos_emb

-    def call(self, inputs, training=False):
-        if not isinstance(inputs, (dict, tuple, list)):
-            input_ids = inputs
-            (attention_mask, mems, perm_mask, target_mapping,
-             token_type_ids, input_mask, head_mask) = None, None, None, None, None, None, None
-        elif isinstance(inputs, (tuple, list)):
+    def call(self, inputs, attention_mask=None, mems=None, perm_mask=None, target_mapping=None,
+             token_type_ids=None, input_mask=None, head_mask=None, training=False):
+        if isinstance(inputs, (tuple, list)):
             input_ids = inputs[0]
-            attention_mask = inputs[1] if len(inputs) > 1 else None
-            mems = inputs[2] if len(inputs) > 2 else None
-            perm_mask = inputs[3] if len(inputs) > 3 else None
-            target_mapping = inputs[4] if len(inputs) > 4 else None
-            token_type_ids = inputs[5] if len(inputs) > 5 else None
-            input_mask = inputs[6] if len(inputs) > 6 else None
-            head_mask = inputs[7] if len(inputs) > 7 else None
+            attention_mask = inputs[1] if len(inputs) > 1 else attention_mask
+            mems = inputs[2] if len(inputs) > 2 else mems
+            perm_mask = inputs[3] if len(inputs) > 3 else perm_mask
+            target_mapping = inputs[4] if len(inputs) > 4 else target_mapping
+            token_type_ids = inputs[5] if len(inputs) > 5 else token_type_ids
+            input_mask = inputs[6] if len(inputs) > 6 else input_mask
+            head_mask = inputs[7] if len(inputs) > 7 else head_mask
             assert len(inputs) <= 8, "Too many inputs."
-        else:
+        elif isinstance(inputs, dict):
             input_ids = inputs.get('input_ids')
-            attention_mask = inputs.get('attention_mask', None)
-            mems = inputs.get('mems', None)
-            perm_mask = inputs.get('perm_mask', None)
-            target_mapping = inputs.get('target_mapping', None)
-            token_type_ids = inputs.get('token_type_ids', None)
-            input_mask = inputs.get('input_mask', None)
-            head_mask = inputs.get('head_mask', None)
+            attention_mask = inputs.get('attention_mask', attention_mask)
+            mems = inputs.get('mems', mems)
+            perm_mask = inputs.get('perm_mask', perm_mask)
+            target_mapping = inputs.get('target_mapping', target_mapping)
+            token_type_ids = inputs.get('token_type_ids', token_type_ids)
+            input_mask = inputs.get('input_mask', input_mask)
+            head_mask = inputs.get('head_mask', head_mask)
             assert len(inputs) <= 8, "Too many inputs."
+        else:
+            input_ids = inputs

         # the original code for XLNet uses shapes [len, bsz] with the batch dimension at the end
         # but we want a unified interface in the library with the batch size on the first dimension
@@ -784,8 +783,8 @@ class TFXLNetModel(TFXLNetPreTrainedModel):
         super(TFXLNetModel, self).__init__(config, *inputs, **kwargs)
         self.transformer = TFXLNetMainLayer(config, name='transformer')

-    def call(self, inputs, training=False):
-        outputs = self.transformer(inputs, training=training)
+    def call(self, inputs, **kwargs):
+        outputs = self.transformer(inputs, **kwargs)
         return outputs
@@ -829,8 +828,8 @@ class TFXLNetLMHeadModel(TFXLNetPreTrainedModel):
         self.transformer = TFXLNetMainLayer(config, name='transformer')
         self.lm_loss = TFXLNetLMHead(config, self.transformer.word_embedding, name='lm_loss')

-    def call(self, inputs, training=False):
-        transformer_outputs = self.transformer(inputs, training=training)
+    def call(self, inputs, **kwargs):
+        transformer_outputs = self.transformer(inputs, **kwargs)
         hidden_state = transformer_outputs[0]
         logits = self.lm_loss(hidden_state)
@@ -886,8 +885,8 @@ class TFXLNetForSequenceClassification(TFXLNetPreTrainedModel):
         self.sequence_summary = TFSequenceSummary(config, name='sequence_summary')
         self.logits_proj = tf.keras.layers.Dense(config.num_labels, name='logits_proj')

-    def call(self, inputs, training=False):
-        transformer_outputs = self.transformer(inputs, training=training)
+    def call(self, inputs, **kwargs):
+        transformer_outputs = self.transformer(inputs, **kwargs)
         output = transformer_outputs[0]

         output = self.sequence_summary(output)
@@ -933,8 +932,8 @@ class TFXLNetForQuestionAnsweringSimple(TFXLNetPreTrainedModel):
         self.transformer = TFXLNetMainLayer(config, name='transformer')
         self.qa_outputs = tf.keras.layers.Dense(config.num_labels, name='qa_outputs')

-    def call(self, inputs, training=False):
-        transformer_outputs = self.transformer(inputs, training=training)
+    def call(self, inputs, **kwargs):
+        transformer_outputs = self.transformer(inputs, **kwargs)
         sequence_output = transformer_outputs[0]
...
@@ -138,7 +138,7 @@ class TFBertModelTest(TFCommonTestCases.TFCommonModelTester):
             inputs = {'input_ids': input_ids,
                       'attention_mask': input_mask,
                       'token_type_ids': token_type_ids}
-            sequence_output, pooled_output = model(input_ids, attention_mask=input_mask, token_type_ids=token_type_ids)
+            sequence_output, pooled_output = model(inputs)

             inputs = [input_ids, input_mask]
             sequence_output, pooled_output = model(inputs)
...
@@ -29,6 +29,7 @@ from pytorch_transformers import is_tf_available
 if is_tf_available():
     import tensorflow as tf
+    import numpy as np
     from pytorch_transformers import TFPreTrainedModel
     # from pytorch_transformers.modeling_bert import BertModel, BertConfig, BERT_PRETRAINED_MODEL_ARCHIVE_MAP
 else:
@@ -65,6 +66,22 @@ class TFCommonTestCases:
             #     msg="Parameter {} of model {} seems not properly initialized".format(name, model_class))

+    def test_keyword_and_dict_args(self):
+        config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
+
+        for model_class in self.all_model_classes:
+            model = model_class(config)
+            outputs_dict = model(inputs_dict)
+
+            inputs_keywords = copy.deepcopy(inputs_dict)
+            input_ids = inputs_keywords.pop('input_ids')
+            outputs_keywords = model(input_ids, **inputs_keywords)
+
+            output_dict = outputs_dict[0].numpy()
+            output_keywords = outputs_keywords[0].numpy()
+
+            self.assertLess(np.sum(np.abs(output_dict - output_keywords)), 1e-6)
+
     def test_attention_outputs(self):
         config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
...