"vscode:/vscode.git/clone" did not exist on "d2e5b19b821f0cf43c7cf4f01be5faa1cb20aa64"
Commit c9591f6f authored by thomwolf

updated models input format + tests

parent c014d1f0
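
In practical terms, after this commit every TF 2.0 model's call() accepts its inputs in three equivalent ways: as keyword arguments (mirroring the PyTorch models), as a list/tuple in the documented positional order, or as a dictionary of named tensors. A minimal usage sketch, assuming a built TFBertModel instance `model` and pre-made `input_ids`, `attention_mask`, `token_type_ids` tensors (names are illustrative, taken from the diff below):

    # the three call styles below are interchangeable after this change
    outputs = model(input_ids, attention_mask=attention_mask, token_type_ids=token_type_ids)  # keyword arguments
    outputs = model([input_ids, attention_mask, token_type_ids])                              # list/tuple, positional order
    outputs = model({'input_ids': input_ids,                                                  # dict of named tensors
                     'attention_mask': attention_mask,
                     'token_type_ids': token_type_ids})
    sequence_output, pooled_output = outputs[:2]
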
......@@ -456,24 +456,23 @@ class TFBertMainLayer(tf.keras.layers.Layer):
# def call(self, input_ids, attention_mask=None, token_type_ids=None,
# position_ids=None, head_mask=None, training=False):
def call(self, inputs, training=False):
if not isinstance(inputs, (dict, tuple, list)):
input_ids = inputs
attention_mask, head_mask, position_ids, token_type_ids = None, None, None, None
elif isinstance(inputs, (tuple, list)):
def call(self, inputs, attention_mask=None, token_type_ids=None, position_ids=None, head_mask=None, training=False):
if isinstance(inputs, (tuple, list)):
input_ids = inputs[0]
attention_mask = inputs[1] if len(inputs) > 1 else None
token_type_ids = inputs[2] if len(inputs) > 2 else None
position_ids = inputs[3] if len(inputs) > 3 else None
head_mask = inputs[4] if len(inputs) > 4 else None
attention_mask = inputs[1] if len(inputs) > 1 else attention_mask
token_type_ids = inputs[2] if len(inputs) > 2 else token_type_ids
position_ids = inputs[3] if len(inputs) > 3 else position_ids
head_mask = inputs[4] if len(inputs) > 4 else head_mask
assert len(inputs) <= 5, "Too many inputs."
else:
elif isinstance(inputs, dict):
input_ids = inputs.get('input_ids')
attention_mask = inputs.get('attention_mask', None)
token_type_ids = inputs.get('token_type_ids', None)
position_ids = inputs.get('position_ids', None)
head_mask = inputs.get('head_mask', None)
attention_mask = inputs.get('attention_mask', attention_mask)
token_type_ids = inputs.get('token_type_ids', token_type_ids)
position_ids = inputs.get('position_ids', position_ids)
head_mask = inputs.get('head_mask', head_mask)
assert len(inputs) <= 5, "Too many inputs."
else:
input_ids = inputs
if attention_mask is None:
attention_mask = tf.fill(tf.shape(input_ids), 1)
......@@ -637,8 +636,8 @@ class TFBertModel(TFBertPreTrainedModel):
super(TFBertModel, self).__init__(config, *inputs, **kwargs)
self.bert = TFBertMainLayer(config, name='bert')
def call(self, inputs, training=False):
outputs = self.bert(inputs, training=training)
def call(self, inputs, **kwargs):
outputs = self.bert(inputs, **kwargs)
return outputs
......@@ -676,11 +675,11 @@ class TFBertForPreTraining(TFBertPreTrainedModel):
self.nsp = TFBertNSPHead(config, name='nsp___cls')
self.mlm = TFBertMLMHead(config, self.bert.embeddings, name='mlm___cls')
def call(self, inputs, training=False):
outputs = self.bert(inputs, training=training)
def call(self, inputs, **kwargs):
outputs = self.bert(inputs, **kwargs)
sequence_output, pooled_output = outputs[:2]
prediction_scores = self.mlm(sequence_output, training=training)
prediction_scores = self.mlm(sequence_output, training=kwargs.get('training', False))
seq_relationship_score = self.nsp(pooled_output)
outputs = (prediction_scores, seq_relationship_score,) + outputs[2:] # add hidden states and attention if they are here
......@@ -718,11 +717,11 @@ class TFBertForMaskedLM(TFBertPreTrainedModel):
self.bert = TFBertMainLayer(config, name='bert')
self.mlm = TFBertMLMHead(config, self.bert.embeddings, name='mlm___cls')
def call(self, inputs, training=False):
outputs = self.bert(inputs, training=training)
def call(self, inputs, **kwargs):
outputs = self.bert(inputs, **kwargs)
sequence_output = outputs[0]
prediction_scores = self.mlm(sequence_output, training=training)
prediction_scores = self.mlm(sequence_output, training=kwargs.get('training', False))
outputs = (prediction_scores,) + outputs[2:] # Add hidden states and attention if they are here
......@@ -761,8 +760,8 @@ class TFBertForNextSentencePrediction(TFBertPreTrainedModel):
self.bert = TFBertMainLayer(config, name='bert')
self.nsp = TFBertNSPHead(config, name='nsp___cls')
def call(self, inputs, training=False):
outputs = self.bert(inputs, training=training)
def call(self, inputs, **kwargs):
outputs = self.bert(inputs, **kwargs)
pooled_output = outputs[1]
seq_relationship_score = self.nsp(pooled_output)
......@@ -805,12 +804,12 @@ class TFBertForSequenceClassification(TFBertPreTrainedModel):
self.dropout = tf.keras.layers.Dropout(config.hidden_dropout_prob)
self.classifier = tf.keras.layers.Dense(config.num_labels, name='classifier')
def call(self, inputs, training=False):
outputs = self.bert(inputs, training=training)
def call(self, inputs, **kwargs):
outputs = self.bert(inputs, **kwargs)
pooled_output = outputs[1]
pooled_output = self.dropout(pooled_output, training=training)
pooled_output = self.dropout(pooled_output, training=kwargs.get('training', False))
logits = self.classifier(pooled_output)
outputs = (logits,) + outputs[2:] # add hidden states and attention if they are here
......@@ -852,24 +851,23 @@ class TFBertForMultipleChoice(TFBertPreTrainedModel):
self.dropout = tf.keras.layers.Dropout(config.hidden_dropout_prob)
self.classifier = tf.keras.layers.Dense(1, name='classifier')
def call(self, inputs, training=False):
if not isinstance(inputs, (dict, tuple, list)):
input_ids = inputs
attention_mask, head_mask, position_ids, token_type_ids = None, None, None, None
elif isinstance(inputs, (tuple, list)):
def call(self, inputs, attention_mask=None, token_type_ids=None, position_ids=None, head_mask=None, training=False):
if isinstance(inputs, (tuple, list)):
input_ids = inputs[0]
attention_mask = inputs[1] if len(inputs) > 1 else None
token_type_ids = inputs[2] if len(inputs) > 2 else None
position_ids = inputs[3] if len(inputs) > 3 else None
head_mask = inputs[4] if len(inputs) > 4 else None
attention_mask = inputs[1] if len(inputs) > 1 else attention_mask
token_type_ids = inputs[2] if len(inputs) > 2 else token_type_ids
position_ids = inputs[3] if len(inputs) > 3 else position_ids
head_mask = inputs[4] if len(inputs) > 4 else head_mask
assert len(inputs) <= 5, "Too many inputs."
else:
elif isinstance(inputs, dict):
input_ids = inputs.get('input_ids')
attention_mask = inputs.get('attention_mask', None)
token_type_ids = inputs.get('token_type_ids', None)
position_ids = inputs.get('position_ids', None)
head_mask = inputs.get('head_mask', None)
attention_mask = inputs.get('attention_mask', attention_mask)
token_type_ids = inputs.get('token_type_ids', token_type_ids)
position_ids = inputs.get('position_ids', position_ids)
head_mask = inputs.get('head_mask', head_mask)
assert len(inputs) <= 5, "Too many inputs."
else:
input_ids = inputs
num_choices = tf.shape(input_ids)[1]
seq_length = tf.shape(input_ids)[2]
......@@ -927,12 +925,12 @@ class TFBertForTokenClassification(TFBertPreTrainedModel):
self.dropout = tf.keras.layers.Dropout(config.hidden_dropout_prob)
self.classifier = tf.keras.layers.Dense(config.num_labels, name='classifier')
def call(self, inputs, training=False):
outputs = self.bert(inputs, training=training)
def call(self, inputs, **kwargs):
outputs = self.bert(inputs, **kwargs)
sequence_output = outputs[0]
sequence_output = self.dropout(sequence_output, training=training)
sequence_output = self.dropout(sequence_output, training=kwargs.get('training', False))
logits = self.classifier(sequence_output)
outputs = (logits,) + outputs[2:] # add hidden states and attention if they are here
......@@ -976,8 +974,8 @@ class TFBertForQuestionAnswering(TFBertPreTrainedModel):
self.bert = TFBertMainLayer(config, name='bert')
self.qa_outputs = tf.keras.layers.Dense(config.num_labels, name='qa_outputs')
def call(self, inputs, training=False):
outputs = self.bert(inputs, training=training)
def call(self, inputs, **kwargs):
outputs = self.bert(inputs, **kwargs)
sequence_output = outputs[0]
......
......@@ -418,20 +418,19 @@ class TFDistilBertMainLayer(tf.keras.layers.Layer):
def _prune_heads(self, heads_to_prune):
raise NotImplementedError
def call(self, inputs, training=False):
if not isinstance(inputs, (dict, tuple, list)):
input_ids = inputs
(attention_mask, head_mask) = None, None
elif isinstance(inputs, (tuple, list)):
def call(self, inputs, attention_mask=None, head_mask=None, training=False):
if isinstance(inputs, (tuple, list)):
input_ids = inputs[0]
attention_mask = inputs[1] if len(inputs) > 1 else None
head_mask = inputs[2] if len(inputs) > 2 else None
attention_mask = inputs[1] if len(inputs) > 1 else attention_mask
head_mask = inputs[2] if len(inputs) > 2 else head_mask
assert len(inputs) <= 3, "Too many inputs."
else:
elif isinstance(inputs, dict):
input_ids = inputs.get('input_ids')
attention_mask = inputs.get('attention_mask', None)
head_mask = inputs.get('head_mask', None)
attention_mask = inputs.get('attention_mask', attention_mask)
head_mask = inputs.get('head_mask', head_mask)
assert len(inputs) <= 3, "Too many inputs."
else:
input_ids = inputs
if attention_mask is None:
attention_mask = tf.ones(shape_list(input_ids)) # (bs, seq_length)
......@@ -532,8 +531,8 @@ class TFDistilBertModel(TFDistilBertPreTrainedModel):
super(TFDistilBertModel, self).__init__(config, *inputs, **kwargs)
self.distilbert = TFDistilBertMainLayer(config, name="distilbert") # Embeddings
def call(self, inputs, training=False):
outputs = self.distilbert(inputs, training=training)
def call(self, inputs, **kwargs):
outputs = self.distilbert(inputs, **kwargs)
return outputs
......@@ -603,18 +602,17 @@ class TFDistilBertForMaskedLM(TFDistilBertPreTrainedModel):
self.vocab_layer_norm = tf.keras.layers.LayerNormalization(epsilon=1e-12, name="vocab_layer_norm")
self.vocab_projector = TFDistilBertLMHead(config, self.distilbert.embeddings, name="vocab_projector")
def call(self, inputs, training=False):
dlbrt_output = self.distilbert(inputs, training=training)
def call(self, inputs, **kwargs):
distilbert_output = self.distilbert(inputs, **kwargs)
hidden_states = dlbrt_output[0] # (bs, seq_length, dim)
hidden_states = distilbert_output[0] # (bs, seq_length, dim)
prediction_logits = self.vocab_transform(hidden_states) # (bs, seq_length, dim)
prediction_logits = self.act(prediction_logits) # (bs, seq_length, dim)
prediction_logits = self.vocab_layer_norm(prediction_logits) # (bs, seq_length, dim)
prediction_logits = self.vocab_projector(prediction_logits)
outputs = (prediction_logits, ) + dlbrt_output[1:]
return outputs # prediction_logits, (all hidden_states), (all attentions)
outputs = (prediction_logits,) + distilbert_output[1:]
return outputs # logits, (hidden_states), (attentions)
@add_start_docstrings("""DistilBert Model transformer with a sequence classification/regression head on top (a linear layer on top of
......@@ -660,12 +658,13 @@ class TFDistilBertForSequenceClassification(TFDistilBertPreTrainedModel):
self.classifier = tf.keras.layers.Dense(config.num_labels, name="classifier")
self.dropout = tf.keras.layers.Dropout(config.seq_classif_dropout)
def call(self, inputs, training=False):
distilbert_output = self.distilbert(inputs, training=training)
def call(self, inputs, **kwargs):
distilbert_output = self.distilbert(inputs, **kwargs)
hidden_state = distilbert_output[0] # (bs, seq_len, dim)
pooled_output = hidden_state[:, 0] # (bs, dim)
pooled_output = self.pre_classifier(pooled_output) # (bs, dim)
pooled_output = self.dropout(pooled_output, training=training) # (bs, dim)
pooled_output = self.dropout(pooled_output, training=kwargs.get('training', False)) # (bs, dim)
logits = self.classifier(pooled_output) # (bs, dim)
outputs = (logits,) + distilbert_output[1:]
......@@ -720,11 +719,11 @@ class TFDistilBertForQuestionAnswering(TFDistilBertPreTrainedModel):
assert config.num_labels == 2
self.dropout = tf.keras.layers.Dropout(config.qa_dropout)
def call(self, inputs, training=False):
distilbert_output = self.distilbert(inputs, training=training)
hidden_states = distilbert_output[0] # (bs, max_query_len, dim)
def call(self, inputs, **kwargs):
distilbert_output = self.distilbert(inputs, **kwargs)
hidden_states = self.dropout(hidden_states, training=training) # (bs, max_query_len, dim)
hidden_states = distilbert_output[0] # (bs, max_query_len, dim)
hidden_states = self.dropout(hidden_states, training=kwargs.get('training', False)) # (bs, max_query_len, dim)
logits = self.qa_outputs(hidden_states) # (bs, max_query_len, 2)
start_logits, end_logits = tf.split(logits, 2, axis=-1)
start_logits = tf.squeeze(start_logits, axis=-1)
......
......@@ -230,26 +230,25 @@ class TFGPT2MainLayer(tf.keras.layers.Layer):
"""
raise NotImplementedError
def call(self, inputs, training=False):
if not isinstance(inputs, (dict, tuple, list)):
input_ids = inputs
past, attention_mask, token_type_ids, position_ids, head_mask = None, None, None, None, None
elif isinstance(inputs, (tuple, list)):
def call(self, inputs, past=None, attention_mask=None, token_type_ids=None, position_ids=None, head_mask=None, training=False):
if isinstance(inputs, (tuple, list)):
input_ids = inputs[0]
past = inputs[1] if len(inputs) > 1 else None
attention_mask = inputs[2] if len(inputs) > 2 else None
token_type_ids = inputs[3] if len(inputs) > 3 else None
position_ids = inputs[4] if len(inputs) > 4 else None
head_mask = inputs[5] if len(inputs) > 5 else None
past = inputs[1] if len(inputs) > 1 else past
attention_mask = inputs[2] if len(inputs) > 2 else attention_mask
token_type_ids = inputs[3] if len(inputs) > 3 else token_type_ids
position_ids = inputs[4] if len(inputs) > 4 else position_ids
head_mask = inputs[5] if len(inputs) > 5 else head_mask
assert len(inputs) <= 6, "Too many inputs."
else:
elif isinstance(inputs, dict):
input_ids = inputs.get('input_ids')
past = inputs.get('past', None)
attention_mask = inputs.get('attention_mask', None)
token_type_ids = inputs.get('token_type_ids', None)
position_ids = inputs.get('position_ids', None)
head_mask = inputs.get('head_mask', None)
past = inputs.get('past', past)
attention_mask = inputs.get('attention_mask', attention_mask)
token_type_ids = inputs.get('token_type_ids', token_type_ids)
position_ids = inputs.get('position_ids', position_ids)
head_mask = inputs.get('head_mask', head_mask)
assert len(inputs) <= 6, "Too many inputs."
else:
input_ids = inputs
if past is None:
past_length = 0
......@@ -442,8 +441,8 @@ class TFGPT2Model(TFGPT2PreTrainedModel):
super(TFGPT2Model, self).__init__(config, *inputs, **kwargs)
self.transformer = TFGPT2MainLayer(config, name='transformer')
def call(self, inputs, training=False):
outputs = self.transformer(inputs, training=training)
def call(self, inputs, **kwargs):
outputs = self.transformer(inputs, **kwargs)
return outputs
......@@ -483,8 +482,8 @@ class TFGPT2LMHeadModel(TFGPT2PreTrainedModel):
super(TFGPT2LMHeadModel, self).__init__(config, *inputs, **kwargs)
self.transformer = TFGPT2MainLayer(config, name='transformer')
def call(self, inputs, training=False):
transformer_outputs = self.transformer(inputs, training=training)
def call(self, inputs, **kwargs):
transformer_outputs = self.transformer(inputs, **kwargs)
hidden_states = transformer_outputs[0]
lm_logits = self.transformer.wte(hidden_states, mode="linear")
......@@ -551,28 +550,27 @@ class TFGPT2DoubleHeadsModel(TFGPT2PreTrainedModel):
self.transformer = TFGPT2MainLayer(config, name='transformer')
self.multiple_choice_head = TFSequenceSummary(config, name='multiple_choice_head')
def call(self, inputs, training=False):
if not isinstance(inputs, (dict, tuple, list)):
input_ids = inputs
mc_token_ids, past, attention_mask, token_type_ids, position_ids, head_mask = None, None, None, None, None
elif isinstance(inputs, (tuple, list)):
def call(self, inputs, past=None, attention_mask=None, token_type_ids=None, position_ids=None, head_mask=None, mc_token_ids=None, training=False):
if isinstance(inputs, (tuple, list)):
input_ids = inputs[0]
mc_token_ids = inputs[1] if len(inputs) > 1 else None
past = inputs[2] if len(inputs) > 2 else None
attention_mask = inputs[3] if len(inputs) > 3 else None
token_type_ids = inputs[4] if len(inputs) > 4 else None
position_ids = inputs[5] if len(inputs) > 5 else None
head_mask = inputs[6] if len(inputs) > 6 else None
past = inputs[1] if len(inputs) > 1 else past
attention_mask = inputs[2] if len(inputs) > 2 else attention_mask
token_type_ids = inputs[3] if len(inputs) > 3 else token_type_ids
position_ids = inputs[4] if len(inputs) > 4 else position_ids
head_mask = inputs[5] if len(inputs) > 5 else head_mask
mc_token_ids = inputs[6] if len(inputs) > 6 else mc_token_ids
assert len(inputs) <= 7, "Too many inputs."
else:
elif isinstance(inputs, dict):
input_ids = inputs.get('input_ids')
mc_token_ids = inputs.get('mc_token_ids', None)
past = inputs.get('past', None)
attention_mask = inputs.get('attention_mask', None)
token_type_ids = inputs.get('token_type_ids', None)
position_ids = inputs.get('position_ids', None)
head_mask = inputs.get('head_mask', None)
past = inputs.get('past', past)
attention_mask = inputs.get('attention_mask', attention_mask)
token_type_ids = inputs.get('token_type_ids', token_type_ids)
position_ids = inputs.get('position_ids', position_ids)
head_mask = inputs.get('head_mask', head_mask)
mc_token_ids = inputs.get('mc_token_ids', mc_token_ids)
assert len(inputs) <= 7, "Too many inputs."
else:
input_ids = inputs
input_shapes = shape_list(input_ids)
......
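
Note that the double-heads hunk above also reorders the optional inputs: mc_token_ids moves from the second positional slot to the last one (index 6 for GPT-2), so callers passing a list/tuple need to be updated. The keyword form added by this commit sidesteps the ordering entirely; a hedged sketch, assuming a built TFGPT2DoubleHeadsModel `model` and illustrative `input_ids` / `mc_token_ids` tensors:

    # mc_token_ids is now the last positional slot; passing it by name avoids any ambiguity
    outputs = model(input_ids, mc_token_ids=mc_token_ids)
    lm_logits, mc_logits = outputs[:2]
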
......@@ -229,24 +229,23 @@ class TFOpenAIGPTMainLayer(tf.keras.layers.Layer):
"""
raise NotImplementedError
def call(self, inputs, training=False):
if not isinstance(inputs, (dict, tuple, list)):
input_ids = inputs
attention_mask, token_type_ids, position_ids, head_mask = None, None, None, None
elif isinstance(inputs, (tuple, list)):
def call(self, inputs, attention_mask=None, token_type_ids=None, position_ids=None, head_mask=None, training=False):
if isinstance(inputs, (tuple, list)):
input_ids = inputs[0]
attention_mask = inputs[1] if len(inputs) > 1 else None
token_type_ids = inputs[2] if len(inputs) > 2 else None
position_ids = inputs[3] if len(inputs) > 3 else None
head_mask = inputs[4] if len(inputs) > 4 else None
attention_mask = inputs[1] if len(inputs) > 1 else attention_mask
token_type_ids = inputs[2] if len(inputs) > 2 else token_type_ids
position_ids = inputs[3] if len(inputs) > 3 else position_ids
head_mask = inputs[4] if len(inputs) > 4 else head_mask
assert len(inputs) <= 5, "Too many inputs."
else:
elif isinstance(inputs, dict):
input_ids = inputs.get('input_ids')
attention_mask = inputs.get('attention_mask', None)
token_type_ids = inputs.get('token_type_ids', None)
position_ids = inputs.get('position_ids', None)
head_mask = inputs.get('head_mask', None)
attention_mask = inputs.get('attention_mask', attention_mask)
token_type_ids = inputs.get('token_type_ids', token_type_ids)
position_ids = inputs.get('position_ids', position_ids)
head_mask = inputs.get('head_mask', head_mask)
assert len(inputs) <= 5, "Too many inputs."
else:
input_ids = inputs
if position_ids is None:
position_ids = tf.range(shape_list(input_ids)[-1], dtype=tf.int32)[tf.newaxis, :]
......@@ -420,8 +419,8 @@ class TFOpenAIGPTModel(TFOpenAIGPTPreTrainedModel):
super(TFOpenAIGPTModel, self).__init__(config, *inputs, **kwargs)
self.transformer = TFOpenAIGPTMainLayer(config, name='transformer')
def call(self, inputs, training=False):
outputs = self.transformer(inputs, training=training)
def call(self, inputs, **kwargs):
outputs = self.transformer(inputs, **kwargs)
return outputs
......@@ -455,8 +454,8 @@ class TFOpenAIGPTLMHeadModel(TFOpenAIGPTPreTrainedModel):
super(TFOpenAIGPTLMHeadModel, self).__init__(config, *inputs, **kwargs)
self.transformer = TFOpenAIGPTMainLayer(config, name='transformer')
def call(self, inputs, training=False):
transformer_outputs = self.transformer(inputs, training=training)
def call(self, inputs, **kwargs):
transformer_outputs = self.transformer(inputs, **kwargs)
hidden_states = transformer_outputs[0]
lm_logits = self.transformer.tokens_embed(hidden_states, mode="linear")
......@@ -511,26 +510,25 @@ class TFOpenAIGPTDoubleHeadsModel(TFOpenAIGPTPreTrainedModel):
self.transformer = TFOpenAIGPTMainLayer(config, name='transformer')
self.multiple_choice_head = TFSequenceSummary(config, name='multiple_choice_head')
def call(self, inputs, training=False):
if not isinstance(inputs, (dict, tuple, list)):
input_ids = inputs
mc_token_ids, attention_mask, token_type_ids, position_ids, head_mask = None, None, None, None
elif isinstance(inputs, (tuple, list)):
def call(self, inputs, attention_mask=None, token_type_ids=None, position_ids=None, head_mask=None, mc_token_ids=None, training=False):
if isinstance(inputs, (tuple, list)):
input_ids = inputs[0]
mc_token_ids = inputs[1] if len(inputs) > 1 else None
attention_mask = inputs[2] if len(inputs) > 2 else None
token_type_ids = inputs[3] if len(inputs) > 3 else None
position_ids = inputs[4] if len(inputs) > 4 else None
head_mask = inputs[5] if len(inputs) > 5 else None
attention_mask = inputs[1] if len(inputs) > 1 else attention_mask
token_type_ids = inputs[2] if len(inputs) > 2 else token_type_ids
position_ids = inputs[3] if len(inputs) > 3 else position_ids
head_mask = inputs[4] if len(inputs) > 4 else head_mask
mc_token_ids = inputs[5] if len(inputs) > 5 else mc_token_ids
assert len(inputs) <= 6, "Too many inputs."
else:
elif isinstance(inputs, dict):
input_ids = inputs.get('input_ids')
mc_token_ids = inputs.get('mc_token_ids', None)
attention_mask = inputs.get('attention_mask', None)
token_type_ids = inputs.get('token_type_ids', None)
position_ids = inputs.get('position_ids', None)
head_mask = inputs.get('head_mask', None)
attention_mask = inputs.get('attention_mask', attention_mask)
token_type_ids = inputs.get('token_type_ids', token_type_ids)
position_ids = inputs.get('position_ids', position_ids)
head_mask = inputs.get('head_mask', head_mask)
mc_token_ids = inputs.get('mc_token_ids', mc_token_ids)
assert len(inputs) <= 6, "Too many inputs."
else:
input_ids = inputs
input_shapes = shape_list(input_ids)
......
......@@ -73,21 +73,21 @@ class TFRobertaMainLayer(TFBertMainLayer):
super(TFRobertaMainLayer, self).__init__(config, **kwargs)
self.embeddings = TFRobertaEmbeddings(config, name='embeddings')
def call(self, inputs, training=False):
def call(self, inputs, **kwargs):
# Check that input_ids starts with control token
if not isinstance(inputs, (dict, tuple, list)):
input_ids = inputs
elif isinstance(inputs, (tuple, list)):
if isinstance(inputs, (tuple, list)):
input_ids = inputs[0]
else:
elif isinstance(inputs, dict):
input_ids = inputs.get('input_ids')
else:
input_ids = inputs
if tf.not_equal(tf.reduce_sum(input_ids[:, 0]), 0):
logger.warning("A sequence with no special tokens has been passed to the RoBERTa model. "
"This model requires special tokens in order to work. "
"Please specify add_special_tokens=True in your encoding.")
return super(TFRobertaMainLayer, self).call(inputs, training=training)
return super(TFRobertaMainLayer, self).call(inputs, **kwargs)
class TFRobertaPreTrainedModel(TFPreTrainedModel):
......@@ -203,8 +203,8 @@ class TFRobertaModel(TFRobertaPreTrainedModel):
super(TFRobertaModel, self).__init__(config, *inputs, **kwargs)
self.roberta = TFRobertaMainLayer(config, name='roberta')
def call(self, inputs, training=False):
outputs = self.roberta(inputs, training=training)
def call(self, inputs, **kwargs):
outputs = self.roberta(inputs, **kwargs)
return outputs
......@@ -277,8 +277,8 @@ class TFRobertaForMaskedLM(TFRobertaPreTrainedModel):
self.roberta = TFRobertaMainLayer(config, name="roberta")
self.lm_head = TFRobertaLMHead(config, self.roberta.embeddings, name="lm_head")
def call(self, inputs, training=False):
outputs = self.roberta(inputs, training=training)
def call(self, inputs, **kwargs):
outputs = self.roberta(inputs, **kwargs)
sequence_output = outputs[0]
prediction_scores = self.lm_head(sequence_output)
......@@ -347,8 +347,9 @@ class TFRobertaForSequenceClassification(TFRobertaPreTrainedModel):
self.roberta = TFRobertaMainLayer(config, name="roberta")
self.classifier = TFRobertaClassificationHead(config, name="classifier")
def call(self, inputs, training=False):
outputs = self.roberta(inputs, training=training)
def call(self, inputs, **kwargs):
outputs = self.roberta(inputs, **kwargs)
sequence_output = outputs[0]
logits = self.classifier(sequence_output, training=training)
......
......@@ -447,20 +447,19 @@ class TFTransfoXLMainLayer(tf.keras.layers.Layer):
return new_mems
def call(self, inputs, training=False):
if not isinstance(inputs, (dict, tuple, list)):
input_ids = inputs
mems, head_mask = None, None
elif isinstance(inputs, (tuple, list)):
def call(self, inputs, mems=None, head_mask=None, training=False):
if isinstance(inputs, (tuple, list)):
input_ids = inputs[0]
mems = inputs[1] if len(inputs) > 1 else None
head_mask = inputs[2] if len(inputs) > 2 else None
mems = inputs[1] if len(inputs) > 1 else mems
head_mask = inputs[2] if len(inputs) > 2 else head_mask
assert len(inputs) <= 3, "Too many inputs."
else:
elif isinstance(inputs, dict):
input_ids = inputs.get('input_ids')
mems = inputs.get('mems', None)
head_mask = inputs.get('head_mask', None)
mems = inputs.get('mems', mems)
head_mask = inputs.get('head_mask', head_mask)
assert len(inputs) <= 3, "Too many inputs."
else:
input_ids = inputs
# the original code for Transformer-XL used shapes [len, bsz] but we want a unified interface in the library
# so we transpose here from shape [bsz, len] to shape [len, bsz]
......@@ -632,8 +631,8 @@ class TFTransfoXLModel(TFTransfoXLPreTrainedModel):
super(TFTransfoXLModel, self).__init__(config, *inputs, **kwargs)
self.transformer = TFTransfoXLMainLayer(config, name='transformer')
def call(self, inputs, training=False, **kwargs):
outputs = self.transformer(inputs, training=training, **kwargs)
def call(self, inputs, **kwargs):
outputs = self.transformer(inputs, **kwargs)
return outputs
......@@ -694,22 +693,21 @@ class TFTransfoXLLMHeadModel(TFTransfoXLPreTrainedModel):
def init_mems(self, data):
return self.transformer.init_mems(data)
def call(self, inputs, training=False):
if not isinstance(inputs, (dict, tuple, list)):
input_ids = inputs
mems, head_mask, labels = None, None, None
elif isinstance(inputs, (tuple, list)):
def call(self, inputs, mems=None, head_mask=None, labels=None, training=False):
if isinstance(inputs, (tuple, list)):
input_ids = inputs[0]
mems = inputs[1] if len(inputs) > 1 else None
head_mask = inputs[2] if len(inputs) > 2 else None
labels = inputs[3] if len(inputs) > 3 else None
mems = inputs[1] if len(inputs) > 1 else mems
head_mask = inputs[2] if len(inputs) > 2 else head_mask
labels = inputs[3] if len(inputs) > 3 else labels
assert len(inputs) <= 4, "Too many inputs."
else:
elif isinstance(inputs, dict):
input_ids = inputs.get('input_ids')
mems = inputs.get('mems', None)
head_mask = inputs.get('head_mask', None)
labels = inputs.get('labels', None)
mems = inputs.get('mems', mems)
head_mask = inputs.get('head_mask', head_mask)
labels = inputs.get('labels', labels)
assert len(inputs) <= 4, "Too many inputs."
else:
input_ids = inputs
bsz, tgt_len = shape_list(input_ids)[:2]
......
......@@ -294,31 +294,31 @@ class TFXLMMainLayer(tf.keras.layers.Layer):
"""
raise NotImplementedError
def call(self, inputs, training=False): # removed: src_enc=None, src_len=None
if not isinstance(inputs, (dict, tuple, list)):
input_ids = inputs
(attention_mask, langs, token_type_ids, position_ids,
lengths, cache, head_mask) = None, None, None, None, None, None, None
elif isinstance(inputs, (tuple, list)):
def call(self, inputs, attention_mask=None, langs=None, token_type_ids=None,
position_ids=None, lengths=None, cache=None, head_mask=None,
training=False): # removed: src_enc=None, src_len=None
if isinstance(inputs, (tuple, list)):
input_ids = inputs[0]
attention_mask = inputs[1] if len(inputs) > 1 else None
langs = inputs[2] if len(inputs) > 2 else None
token_type_ids = inputs[3] if len(inputs) > 3 else None
position_ids = inputs[4] if len(inputs) > 4 else None
lengths = inputs[5] if len(inputs) > 5 else None
cache = inputs[6] if len(inputs) > 6 else None
head_mask = inputs[7] if len(inputs) > 7 else None
attention_mask = inputs[1] if len(inputs) > 1 else attention_mask
langs = inputs[2] if len(inputs) > 2 else langs
token_type_ids = inputs[3] if len(inputs) > 3 else token_type_ids
position_ids = inputs[4] if len(inputs) > 4 else position_ids
lengths = inputs[5] if len(inputs) > 5 else lengths
cache = inputs[6] if len(inputs) > 6 else cache
head_mask = inputs[7] if len(inputs) > 7 else head_mask
assert len(inputs) <= 8, "Too many inputs."
else:
elif isinstance(inputs, dict):
input_ids = inputs.get('input_ids')
attention_mask = inputs.get('attention_mask', None)
langs = inputs.get('langs', None)
token_type_ids = inputs.get('token_type_ids', None)
position_ids = inputs.get('position_ids', None)
lengths = inputs.get('lengths', None)
cache = inputs.get('cache', None)
head_mask = inputs.get('head_mask', None)
attention_mask = inputs.get('attention_mask', attention_mask)
langs = inputs.get('langs', langs)
token_type_ids = inputs.get('token_type_ids', token_type_ids)
position_ids = inputs.get('position_ids', position_ids)
lengths = inputs.get('lengths', lengths)
cache = inputs.get('cache', cache)
head_mask = inputs.get('head_mask', head_mask)
assert len(inputs) <= 8, "Too many inputs."
else:
input_ids = inputs
if lengths is None:
lengths = tf.reduce_sum(tf.cast(tf.not_equal(input_ids, self.pad_index), dtype=tf.int32), axis=1)
......@@ -538,8 +538,8 @@ class TFXLMModel(TFXLMPreTrainedModel):
super(TFXLMModel, self).__init__(config, *inputs, **kwargs)
self.transformer = TFXLMMainLayer(config, name='transformer')
def call(self, inputs, training=False):
outputs = self.transformer(inputs, training=training)
def call(self, inputs, **kwargs):
outputs = self.transformer(inputs, **kwargs)
return outputs
......@@ -619,8 +619,8 @@ class TFXLMWithLMHeadModel(TFXLMPreTrainedModel):
self.pred_layer = TFXLMPredLayer(config, self.transformer.embeddings, name='pred_layer_._proj')
def call(self, inputs, training=False):
transformer_outputs = self.transformer(inputs, training=training)
def call(self, inputs, **kwargs):
transformer_outputs = self.transformer(inputs, **kwargs)
output = transformer_outputs[0]
outputs = self.pred_layer(output)
......@@ -670,8 +670,8 @@ class TFXLMForSequenceClassification(TFXLMPreTrainedModel):
self.transformer = TFXLMMainLayer(config, name='transformer')
self.sequence_summary = TFSequenceSummary(config, name='sequence_summary')
def call(self, inputs, training=False):
transformer_outputs = self.transformer(inputs, training=training)
def call(self, inputs, **kwargs):
transformer_outputs = self.transformer(inputs, **kwargs)
output = transformer_outputs[0]
logits = self.sequence_summary(output)
......@@ -731,8 +731,8 @@ class TFXLMForQuestionAnsweringSimple(TFXLMPreTrainedModel):
self.transformer = TFXLMMainLayer(config, name='transformer')
self.qa_outputs = tf.keras.layers.Dense(config.num_labels, name='qa_outputs')
def call(self, inputs, training=False):
transformer_outputs = self.transformer(inputs, training=training)
def call(self, inputs, **kwargs):
transformer_outputs = self.transformer(inputs, **kwargs)
sequence_output = transformer_outputs[0]
......
......@@ -489,31 +489,30 @@ class TFXLNetMainLayer(tf.keras.layers.Layer):
return pos_emb
def call(self, inputs, training=False):
if not isinstance(inputs, (dict, tuple, list)):
input_ids = inputs
(attention_mask, mems, perm_mask, target_mapping,
token_type_ids, input_mask, head_mask) = None, None, None, None, None, None, None
elif isinstance(inputs, (tuple, list)):
def call(self, inputs, attention_mask=None, mems=None, perm_mask=None, target_mapping=None,
token_type_ids=None, input_mask=None, head_mask=None, training=False):
if isinstance(inputs, (tuple, list)):
input_ids = inputs[0]
attention_mask = inputs[1] if len(inputs) > 1 else None
mems = inputs[2] if len(inputs) > 2 else None
perm_mask = inputs[3] if len(inputs) > 3 else None
target_mapping = inputs[4] if len(inputs) > 4 else None
token_type_ids = inputs[5] if len(inputs) > 5 else None
input_mask = inputs[6] if len(inputs) > 6 else None
head_mask = inputs[7] if len(inputs) > 7 else None
attention_mask = inputs[1] if len(inputs) > 1 else attention_mask
mems = inputs[2] if len(inputs) > 2 else mems
perm_mask = inputs[3] if len(inputs) > 3 else perm_mask
target_mapping = inputs[4] if len(inputs) > 4 else target_mapping
token_type_ids = inputs[5] if len(inputs) > 5 else token_type_ids
input_mask = inputs[6] if len(inputs) > 6 else input_mask
head_mask = inputs[7] if len(inputs) > 7 else head_mask
assert len(inputs) <= 8, "Too many inputs."
else:
elif isinstance(inputs, dict):
input_ids = inputs.get('input_ids')
attention_mask = inputs.get('attention_mask', None)
mems = inputs.get('mems', None)
perm_mask = inputs.get('perm_mask', None)
target_mapping = inputs.get('target_mapping', None)
token_type_ids = inputs.get('token_type_ids', None)
input_mask = inputs.get('input_mask', None)
head_mask = inputs.get('head_mask', None)
attention_mask = inputs.get('attention_mask', attention_mask)
mems = inputs.get('mems', mems)
perm_mask = inputs.get('perm_mask', perm_mask)
target_mapping = inputs.get('target_mapping', target_mapping)
token_type_ids = inputs.get('token_type_ids', token_type_ids)
input_mask = inputs.get('input_mask', input_mask)
head_mask = inputs.get('head_mask', head_mask)
assert len(inputs) <= 8, "Too many inputs."
else:
input_ids = inputs
# the original code for XLNet uses shapes [len, bsz] with the batch dimension at the end
# but we want a unified interface in the library with the batch size on the first dimension
......@@ -784,8 +783,8 @@ class TFXLNetModel(TFXLNetPreTrainedModel):
super(TFXLNetModel, self).__init__(config, *inputs, **kwargs)
self.transformer = TFXLNetMainLayer(config, name='transformer')
def call(self, inputs, training=False):
outputs = self.transformer(inputs, training=training)
def call(self, inputs, **kwargs):
outputs = self.transformer(inputs, **kwargs)
return outputs
......@@ -829,8 +828,8 @@ class TFXLNetLMHeadModel(TFXLNetPreTrainedModel):
self.transformer = TFXLNetMainLayer(config, name='transformer')
self.lm_loss = TFXLNetLMHead(config, self.transformer.word_embedding, name='lm_loss')
def call(self, inputs, training=False):
transformer_outputs = self.transformer(inputs, training=training)
def call(self, inputs, **kwargs):
transformer_outputs = self.transformer(inputs, **kwargs)
hidden_state = transformer_outputs[0]
logits = self.lm_loss(hidden_state)
......@@ -886,8 +885,8 @@ class TFXLNetForSequenceClassification(TFXLNetPreTrainedModel):
self.sequence_summary = TFSequenceSummary(config, name='sequence_summary')
self.logits_proj = tf.keras.layers.Dense(config.num_labels, name='logits_proj')
def call(self, inputs, training=False):
transformer_outputs = self.transformer(inputs, training=training)
def call(self, inputs, **kwargs):
transformer_outputs = self.transformer(inputs, **kwargs)
output = transformer_outputs[0]
output = self.sequence_summary(output)
......@@ -933,8 +932,8 @@ class TFXLNetForQuestionAnsweringSimple(TFXLNetPreTrainedModel):
self.transformer = TFXLNetMainLayer(config, name='transformer')
self.qa_outputs = tf.keras.layers.Dense(config.num_labels, name='qa_outputs')
def call(self, inputs, training=False):
transformer_outputs = self.transformer(inputs, training=training)
def call(self, inputs, **kwargs):
transformer_outputs = self.transformer(inputs, **kwargs)
sequence_output = transformer_outputs[0]
......
......@@ -138,7 +138,7 @@ class TFBertModelTest(TFCommonTestCases.TFCommonModelTester):
inputs = {'input_ids': input_ids,
'attention_mask': input_mask,
'token_type_ids': token_type_ids}
sequence_output, pooled_output = model(input_ids, attention_mask=input_mask, token_type_ids=token_type_ids)
sequence_output, pooled_output = model(inputs)
inputs = [input_ids, input_mask]
sequence_output, pooled_output = model(inputs)
......
......@@ -29,6 +29,7 @@ from pytorch_transformers import is_tf_available
if is_tf_available():
import tensorflow as tf
import numpy as np
from pytorch_transformers import TFPreTrainedModel
# from pytorch_transformers.modeling_bert import BertModel, BertConfig, BERT_PRETRAINED_MODEL_ARCHIVE_MAP
else:
......@@ -65,6 +66,22 @@ class TFCommonTestCases:
# msg="Parameter {} of model {} seems not properly initialized".format(name, model_class))
def test_keyword_and_dict_args(self):
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
for model_class in self.all_model_classes:
model = model_class(config)
outputs_dict = model(inputs_dict)
inputs_keywords = copy.deepcopy(inputs_dict)
input_ids = inputs_keywords.pop('input_ids')
outputs_keywords = model(input_ids, **inputs_keywords)
output_dict = outputs_dict[0].numpy()
output_keywords = outputs_keywords[0].numpy()
self.assertLess(np.sum(np.abs(output_dict - output_keywords)), 1e-6)
def test_attention_outputs(self):
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
......