"docs/vscode:/vscode.git/clone" did not exist on "6060b2f89b4ba3ad6d2ddb332835a95962c4bf2c"
Unverified commit 29d49924, authored by Julien Plu and committed by GitHub

New TF model inputs (#8602)

* Apply on BERT and ALBERT

* Update TF Bart

* Add input processing to TF BART

* Add input processing for TF CTRL

* Add input processing to TF Distilbert

* Add input processing to TF DPR

* Add input processing to TF Electra

* Add input processing for TF Flaubert

* Add deprecated arguments

* Add input processing to TF XLM

* remove unused imports

* Add input processing to TF Funnel

* Add input processing to TF GPT2

* Add input processing to TF Longformer

* Add input processing to TF Lxmert

* Apply style

* Add input processing to TF Mobilebert

* Add input processing to TF GPT

* Add input processing to TF Roberta

* Add input processing to TF T5

* Add input processing to TF TransfoXL

* Apply style

* Rebase on master

* Bug fix

* Retry to bugfix

* Retry bug fix

* Fix wrong model name

* Try another fix

* Fix BART

* Fix input processing

* Apply style

* Put the deprecated warnings in the input processing function

* Remove the unused imports

* Raise an error when len(kwargs)>0

* test ModelOutput instead of TFBaseModelOutput

* Bug fix

* Address Patrick's comments

* Address Patrick's comments

* Address Sylvain's comments

* Add the new inputs in new Longformer models

* Update the template with the new input processing

* Remove useless assert

* Apply style

* Trigger CI
parent 82d443a7
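This change replaces the hand-written input unpacking at the top of every TF `call` method (the `isinstance(inputs, (tuple, list))` / `isinstance(inputs, (dict, BatchEncoding))` boilerplate visible in the removed lines below) with a single shared helper, `input_processing`, imported from `modeling_tf_utils`. Each `call` now forwards its arguments plus `kwargs_call=kwargs` to the helper, which returns one canonical `inputs` dictionary and raises an error when unexpected keyword arguments are left over; the method body then reads everything from that dictionary (`inputs["attention_mask"]`, `inputs["training"]`, ...). The sketch below only illustrates that contract in simplified form; it is not the actual implementation shipped in `modeling_tf_utils.py`:

```python
import inspect


def input_processing(func, **kwargs):
    # Simplified sketch: normalize the three supported calling styles
    # (positional tuple/list, dict/BatchEncoding, plain keyword arguments)
    # into one dict keyed by the parameter names of `func`.
    extra = kwargs.pop("kwargs_call", {})
    if len(extra) > 0:
        raise ValueError(f"Unexpected keyword arguments: {list(extra.keys())}")

    names = [p for p in inspect.signature(func).parameters if p not in ("self", "kwargs")]
    output = dict(kwargs)
    main_input = output.get("input_ids")

    if isinstance(main_input, (tuple, list)):
        # model((input_ids, attention_mask, ...)) style call
        for value, name in zip(main_input, names):
            output[name] = value
    elif isinstance(main_input, dict):
        # model({"input_ids": ..., "attention_mask": ...}) style call
        output["input_ids"] = None
        output.update(main_input)

    return output
```

With such a helper in place, the per-model `call` methods shrink to a single `inputs = input_processing(func=self.call, ..., kwargs_call=kwargs)` call followed by dictionary look-ups, as the hunks below show.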
@@ -47,10 +47,10 @@ from ...modeling_tf_utils import (
     TFSharedEmbeddings,
     TFTokenClassificationLoss,
     get_initializer,
+    input_processing,
     keras_serializable,
     shape_list,
 )
-from ...tokenization_utils import BatchEncoding
 from ...utils import logging
 from .configuration_xlm import XLMConfig
@@ -343,7 +343,7 @@ class TFXLMMainLayer(tf.keras.layers.Layer):

     def call(
         self,
-        inputs,
+        input_ids=None,
         attention_mask=None,
         langs=None,
         token_type_ids=None,
@@ -356,63 +356,57 @@ class TFXLMMainLayer(tf.keras.layers.Layer):
         output_hidden_states=None,
         return_dict=None,
         training=False,
-    ):  # removed: src_enc=None, src_len=None
-        if isinstance(inputs, (tuple, list)):
-            input_ids = inputs[0]
-            attention_mask = inputs[1] if len(inputs) > 1 else attention_mask
-            langs = inputs[2] if len(inputs) > 2 else langs
-            token_type_ids = inputs[3] if len(inputs) > 3 else token_type_ids
-            position_ids = inputs[4] if len(inputs) > 4 else position_ids
-            lengths = inputs[5] if len(inputs) > 5 else lengths
-            cache = inputs[6] if len(inputs) > 6 else cache
-            head_mask = inputs[7] if len(inputs) > 7 else head_mask
-            inputs_embeds = inputs[8] if len(inputs) > 8 else inputs_embeds
-            output_attentions = inputs[9] if len(inputs) > 9 else output_attentions
-            output_hidden_states = inputs[10] if len(inputs) > 10 else output_hidden_states
-            return_dict = inputs[11] if len(inputs) > 11 else return_dict
-            assert len(inputs) <= 12, "Too many inputs."
-        elif isinstance(inputs, (dict, BatchEncoding)):
-            input_ids = inputs.get("input_ids")
-            attention_mask = inputs.get("attention_mask", attention_mask)
-            langs = inputs.get("langs", langs)
-            token_type_ids = inputs.get("token_type_ids", token_type_ids)
-            position_ids = inputs.get("position_ids", position_ids)
-            lengths = inputs.get("lengths", lengths)
-            cache = inputs.get("cache", cache)
-            head_mask = inputs.get("head_mask", head_mask)
-            inputs_embeds = inputs.get("inputs_embeds", inputs_embeds)
-            output_attentions = inputs.get("output_attentions", output_attentions)
-            output_hidden_states = inputs.get("output_hidden_states", output_hidden_states)
-            return_dict = inputs.get("return_dict", return_dict)
-            assert len(inputs) <= 12, "Too many inputs."
-        else:
-            input_ids = inputs
-
-        output_attentions = output_attentions if output_attentions is not None else self.output_attentions
-        output_hidden_states = output_hidden_states if output_hidden_states is not None else self.output_hidden_states
-        return_dict = return_dict if return_dict is not None else self.return_dict
+        **kwargs,
+    ):
+        # removed: src_enc=None, src_len=None
+        inputs = input_processing(
+            func=self.call,
+            input_ids=input_ids,
+            attention_mask=attention_mask,
+            langs=langs,
+            token_type_ids=token_type_ids,
+            position_ids=position_ids,
+            lengths=lengths,
+            cache=cache,
+            head_mask=head_mask,
+            inputs_embeds=inputs_embeds,
+            output_attentions=output_attentions,
+            output_hidden_states=output_hidden_states,
+            return_dict=return_dict,
+            training=training,
+            kwargs_call=kwargs,
+        )
+        output_attentions = (
+            inputs["output_attentions"] if inputs["output_attentions"] is not None else self.output_attentions
+        )
+        output_hidden_states = (
+            inputs["output_hidden_states"] if inputs["output_hidden_states"] is not None else self.output_hidden_states
+        )
+        return_dict = inputs["return_dict"] if inputs["return_dict"] is not None else self.return_dict

-        if input_ids is not None and inputs_embeds is not None:
+        if inputs["input_ids"] is not None and inputs["inputs_embeds"] is not None:
             raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
-        elif input_ids is not None:
-            bs, slen = shape_list(input_ids)
-        elif inputs_embeds is not None:
-            bs, slen = shape_list(inputs_embeds)[:2]
+        elif inputs["input_ids"] is not None:
+            bs, slen = shape_list(inputs["input_ids"])
+        elif inputs["inputs_embeds"] is not None:
+            bs, slen = shape_list(inputs["inputs_embeds"])[:2]
         else:
             raise ValueError("You have to specify either input_ids or inputs_embeds")

-        if lengths is None:
-            if input_ids is not None:
-                lengths = tf.reduce_sum(tf.cast(tf.not_equal(input_ids, self.pad_index), dtype=tf.int32), axis=1)
+        if inputs["lengths"] is None:
+            if inputs["input_ids"] is not None:
+                inputs["lengths"] = tf.reduce_sum(
+                    tf.cast(tf.not_equal(inputs["input_ids"], self.pad_index), dtype=tf.int32), axis=1
+                )
             else:
-                lengths = tf.convert_to_tensor([slen] * bs, tf.int32)
+                inputs["lengths"] = tf.convert_to_tensor([slen] * bs, tf.int32)
         # mask = input_ids != self.pad_index

         # check inputs
         # assert shape_list(lengths)[0] == bs
         tf.debugging.assert_equal(
-            shape_list(lengths)[0], bs
-        ), f"Expected batch size {shape_list(lengths)[0]} and received batch size {bs} mismatched"
+            shape_list(inputs["lengths"])[0], bs
+        ), f"Expected batch size {shape_list(inputs['lengths'])[0]} and received batch size {bs} mismatched"
         # assert lengths.max().item() <= slen
         # input_ids = input_ids.transpose(0, 1)  # batch size as dimension 0
         # assert (src_enc is None) == (src_len is None)
@@ -421,26 +415,26 @@ class TFXLMMainLayer(tf.keras.layers.Layer):
             # assert src_enc.size(0) == bs

         # generate masks
-        mask, attn_mask = get_masks(slen, lengths, self.causal, padding_mask=attention_mask)
+        mask, attn_mask = get_masks(slen, inputs["lengths"], self.causal, padding_mask=inputs["attention_mask"])
         # if self.is_decoder and src_enc is not None:
         #     src_mask = torch.arange(src_len.max(), dtype=torch.long, device=lengths.device) < src_len[:, None]

         # position_ids
-        if position_ids is None:
-            position_ids = tf.expand_dims(tf.range(slen), axis=0)
+        if inputs["position_ids"] is None:
+            inputs["position_ids"] = tf.expand_dims(tf.range(slen), axis=0)
         else:
             # assert shape_list(position_ids) == [bs, slen]  # (slen, bs)
             tf.debugging.assert_equal(
-                shape_list(position_ids), [bs, slen]
-            ), f"Position id shape {shape_list(position_ids)} and input shape {[bs, slen]} mismatched"
+                shape_list(inputs["position_ids"]), [bs, slen]
+            ), f"Position id shape {shape_list(inputs['position_ids'])} and input shape {[bs, slen]} mismatched"
             # position_ids = position_ids.transpose(0, 1)

         # langs
-        if langs is not None:
+        if inputs["langs"] is not None:
             # assert shape_list(langs) == [bs, slen]  # (slen, bs)
             tf.debugging.assert_equal(
-                shape_list(langs), [bs, slen]
-            ), f"Lang shape {shape_list(langs)} and input shape {[bs, slen]} mismatched"
+                shape_list(inputs["langs"]), [bs, slen]
+            ), f"Lang shape {shape_list(inputs['langs'])} and input shape {[bs, slen]} mismatched"
             # langs = langs.transpose(0, 1)

         # Prepare head mask if needed
@@ -448,34 +442,34 @@ class TFXLMMainLayer(tf.keras.layers.Layer):
         # attention_probs has shape bsz x n_heads x N x N
         # input head_mask has shape [num_heads] or [num_hidden_layers x num_heads]
         # and head_mask is converted to shape [num_hidden_layers x batch x num_heads x qlen x klen]
-        if head_mask is not None:
+        if inputs["head_mask"] is not None:
             raise NotImplementedError
         else:
-            head_mask = [None] * self.n_layers
+            inputs["head_mask"] = [None] * self.n_layers

         # do not recompute cached elements
-        if cache is not None and input_ids is not None:
-            _slen = slen - cache["slen"]
-            input_ids = input_ids[:, -_slen:]
-            position_ids = position_ids[:, -_slen:]
-            if langs is not None:
-                langs = langs[:, -_slen:]
+        if inputs["cache"] is not None and inputs["input_ids"] is not None:
+            _slen = slen - inputs["cache"]["slen"]
+            inputs["input_ids"] = inputs["input_ids"][:, -_slen:]
+            inputs["position_ids"] = inputs["position_ids"][:, -_slen:]
+            if inputs["langs"] is not None:
+                inputs["langs"] = inputs["langs"][:, -_slen:]
             mask = mask[:, -_slen:]
             attn_mask = attn_mask[:, -_slen:]

         # embeddings
-        if inputs_embeds is None:
-            inputs_embeds = self.embeddings(input_ids)
+        if inputs["inputs_embeds"] is None:
+            inputs["inputs_embeds"] = self.embeddings(inputs["input_ids"])

-        tensor = inputs_embeds + self.position_embeddings(position_ids)
+        tensor = inputs["inputs_embeds"] + self.position_embeddings(inputs["position_ids"])

-        if langs is not None and self.use_lang_emb and self.n_langs > 1:
-            tensor = tensor + self.lang_embeddings(langs)
-        if token_type_ids is not None:
-            tensor = tensor + self.embeddings(token_type_ids)
+        if inputs["langs"] is not None and self.use_lang_emb and self.n_langs > 1:
+            tensor = tensor + self.lang_embeddings(inputs["langs"])
+        if inputs["token_type_ids"] is not None:
+            tensor = tensor + self.embeddings(inputs["token_type_ids"])

         tensor = self.layer_norm_emb(tensor)
-        tensor = self.dropout(tensor, training=training)
+        tensor = self.dropout(tensor, training=inputs["training"])
         tensor = tensor * mask[..., tf.newaxis]

         # transformer layers
@@ -488,14 +482,20 @@ class TFXLMMainLayer(tf.keras.layers.Layer):
             # self attention
             attn_outputs = self.attentions[i](
-                tensor, attn_mask, None, cache, head_mask[i], output_attentions, training=training
+                tensor,
+                attn_mask,
+                None,
+                inputs["cache"],
+                inputs["head_mask"][i],
+                output_attentions,
+                training=inputs["training"],
             )
             attn = attn_outputs[0]

             if output_attentions:
                 attentions = attentions + (attn_outputs[1],)
-            attn = self.dropout(attn, training=training)
+            attn = self.dropout(attn, training=inputs["training"])
             tensor = tensor + attn
             tensor = self.layer_norm1[i](tensor)
@@ -516,8 +516,8 @@ class TFXLMMainLayer(tf.keras.layers.Layer):
             hidden_states = hidden_states + (tensor,)

         # update cache length
-        if cache is not None:
-            cache["slen"] += tensor.size(1)
+        if inputs["cache"] is not None:
+            inputs["cache"]["slen"] += tensor.size(1)

         # move back sequence length to dimension 0
         # tensor = tensor.transpose(0, 1)
@@ -701,8 +701,57 @@ class TFXLMModel(TFXLMPreTrainedModel):
         output_type=TFBaseModelOutput,
         config_class=_CONFIG_FOR_DOC,
     )
-    def call(self, inputs, **kwargs):
-        outputs = self.transformer(inputs, **kwargs)
+    def call(
+        self,
+        input_ids=None,
+        attention_mask=None,
+        langs=None,
+        token_type_ids=None,
+        position_ids=None,
+        lengths=None,
+        cache=None,
+        head_mask=None,
+        inputs_embeds=None,
+        output_attentions=None,
+        output_hidden_states=None,
+        return_dict=None,
+        training=False,
+        **kwargs,
+    ):
+        inputs = input_processing(
+            func=self.call,
+            input_ids=input_ids,
+            attention_mask=attention_mask,
+            langs=langs,
+            token_type_ids=token_type_ids,
+            position_ids=position_ids,
+            lengths=lengths,
+            cache=cache,
+            head_mask=head_mask,
+            inputs_embeds=inputs_embeds,
+            output_attentions=output_attentions,
+            output_hidden_states=output_hidden_states,
+            return_dict=return_dict,
+            training=training,
+            kwargs_call=kwargs,
+        )
+        return_dict = inputs["return_dict"] if inputs["return_dict"] is not None else self.transformer.return_dict
+        outputs = self.transformer(
+            input_ids=inputs["input_ids"],
+            attention_mask=inputs["attention_mask"],
+            langs=inputs["langs"],
+            token_type_ids=inputs["token_type_ids"],
+            position_ids=inputs["position_ids"],
+            lengths=inputs["lengths"],
+            cache=inputs["cache"],
+            head_mask=inputs["head_mask"],
+            inputs_embeds=inputs["inputs_embeds"],
+            output_attentions=inputs["output_attentions"],
+            output_hidden_states=inputs["output_hidden_states"],
+            return_dict=return_dict,
+            training=inputs["training"],
+        )

         return outputs
@@ -771,7 +820,7 @@ class TFXLMWithLMHeadModel(TFXLMPreTrainedModel):
             langs = tf.ones_like(inputs) * lang_id
         else:
             langs = None
-        return {"inputs": inputs, "langs": langs}
+        return {"input_ids": inputs, "langs": langs}
@add_start_docstrings_to_model_forward(XLM_INPUTS_DOCSTRING.format("batch_size, sequence_length")) @add_start_docstrings_to_model_forward(XLM_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
@add_code_sample_docstrings( @add_code_sample_docstrings(
...@@ -780,10 +829,56 @@ class TFXLMWithLMHeadModel(TFXLMPreTrainedModel): ...@@ -780,10 +829,56 @@ class TFXLMWithLMHeadModel(TFXLMPreTrainedModel):
output_type=TFXLMWithLMHeadModelOutput, output_type=TFXLMWithLMHeadModelOutput,
config_class=_CONFIG_FOR_DOC, config_class=_CONFIG_FOR_DOC,
) )
def call(self, inputs, **kwargs): def call(
return_dict = kwargs.get("return_dict") self,
return_dict = return_dict if return_dict is not None else self.transformer.return_dict input_ids=None,
transformer_outputs = self.transformer(inputs, **kwargs) attention_mask=None,
langs=None,
token_type_ids=None,
position_ids=None,
lengths=None,
cache=None,
head_mask=None,
inputs_embeds=None,
output_attentions=None,
output_hidden_states=None,
return_dict=None,
training=False,
**kwargs,
):
inputs = input_processing(
func=self.call,
input_ids=input_ids,
attention_mask=attention_mask,
langs=langs,
token_type_ids=token_type_ids,
position_ids=position_ids,
lengths=lengths,
cache=cache,
head_mask=head_mask,
inputs_embeds=inputs_embeds,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
training=training,
kwargs_call=kwargs,
)
return_dict = inputs["return_dict"] if inputs["return_dict"] is not None else self.transformer.return_dict
transformer_outputs = self.transformer(
input_ids=inputs["input_ids"],
attention_mask=inputs["attention_mask"],
langs=inputs["langs"],
token_type_ids=inputs["token_type_ids"],
position_ids=inputs["position_ids"],
lengths=inputs["lengths"],
cache=inputs["cache"],
head_mask=inputs["head_mask"],
inputs_embeds=inputs["inputs_embeds"],
output_attentions=inputs["output_attentions"],
output_hidden_states=inputs["output_hidden_states"],
return_dict=return_dict,
training=inputs["training"],
)
output = transformer_outputs[0] output = transformer_outputs[0]
outputs = self.pred_layer(output) outputs = self.pred_layer(output)
...@@ -820,7 +915,7 @@ class TFXLMForSequenceClassification(TFXLMPreTrainedModel, TFSequenceClassificat ...@@ -820,7 +915,7 @@ class TFXLMForSequenceClassification(TFXLMPreTrainedModel, TFSequenceClassificat
) )
def call( def call(
self, self,
inputs=None, input_ids=None,
attention_mask=None, attention_mask=None,
langs=None, langs=None,
token_type_ids=None, token_type_ids=None,
...@@ -834,6 +929,7 @@ class TFXLMForSequenceClassification(TFXLMPreTrainedModel, TFSequenceClassificat ...@@ -834,6 +929,7 @@ class TFXLMForSequenceClassification(TFXLMPreTrainedModel, TFSequenceClassificat
return_dict=None, return_dict=None,
labels=None, labels=None,
training=False, training=False,
**kwargs,
): ):
r""" r"""
labels (:obj:`tf.Tensor` of shape :obj:`(batch_size,)`, `optional`): labels (:obj:`tf.Tensor` of shape :obj:`(batch_size,)`, `optional`):
...@@ -841,16 +937,9 @@ class TFXLMForSequenceClassification(TFXLMPreTrainedModel, TFSequenceClassificat ...@@ -841,16 +937,9 @@ class TFXLMForSequenceClassification(TFXLMPreTrainedModel, TFSequenceClassificat
config.num_labels - 1]``. If ``config.num_labels == 1`` a regression loss is computed (Mean-Square loss), config.num_labels - 1]``. If ``config.num_labels == 1`` a regression loss is computed (Mean-Square loss),
If ``config.num_labels > 1`` a classification loss is computed (Cross-Entropy). If ``config.num_labels > 1`` a classification loss is computed (Cross-Entropy).
""" """
return_dict = return_dict if return_dict is not None else self.transformer.return_dict inputs = input_processing(
if isinstance(inputs, (tuple, list)): func=self.call,
labels = inputs[12] if len(inputs) > 12 else labels input_ids=input_ids,
if len(inputs) > 12:
inputs = inputs[:12]
elif isinstance(inputs, (dict, BatchEncoding)):
labels = inputs.pop("labels", labels)
transformer_outputs = self.transformer(
inputs,
attention_mask=attention_mask, attention_mask=attention_mask,
langs=langs, langs=langs,
token_type_ids=token_type_ids, token_type_ids=token_type_ids,
...@@ -862,13 +951,31 @@ class TFXLMForSequenceClassification(TFXLMPreTrainedModel, TFSequenceClassificat ...@@ -862,13 +951,31 @@ class TFXLMForSequenceClassification(TFXLMPreTrainedModel, TFSequenceClassificat
output_attentions=output_attentions, output_attentions=output_attentions,
output_hidden_states=output_hidden_states, output_hidden_states=output_hidden_states,
return_dict=return_dict, return_dict=return_dict,
labels=labels,
training=training, training=training,
kwargs_call=kwargs,
)
return_dict = inputs["return_dict"] if inputs["return_dict"] is not None else self.transformer.return_dict
transformer_outputs = self.transformer(
input_ids=inputs["input_ids"],
attention_mask=inputs["attention_mask"],
langs=inputs["langs"],
token_type_ids=inputs["token_type_ids"],
position_ids=inputs["position_ids"],
lengths=inputs["lengths"],
cache=inputs["cache"],
head_mask=inputs["head_mask"],
inputs_embeds=inputs["inputs_embeds"],
output_attentions=inputs["output_attentions"],
output_hidden_states=inputs["output_hidden_states"],
return_dict=return_dict,
training=inputs["training"],
) )
output = transformer_outputs[0] output = transformer_outputs[0]
logits = self.sequence_summary(output) logits = self.sequence_summary(output)
loss = None if labels is None else self.compute_loss(labels, logits) loss = None if inputs["labels"] is None else self.compute_loss(inputs["labels"], logits)
if not return_dict: if not return_dict:
output = (logits,) + transformer_outputs[1:] output = (logits,) + transformer_outputs[1:]
...@@ -921,7 +1028,7 @@ class TFXLMForMultipleChoice(TFXLMPreTrainedModel, TFMultipleChoiceLoss): ...@@ -921,7 +1028,7 @@ class TFXLMForMultipleChoice(TFXLMPreTrainedModel, TFMultipleChoiceLoss):
) )
def call( def call(
self, self,
inputs, input_ids=None,
attention_mask=None, attention_mask=None,
langs=None, langs=None,
token_type_ids=None, token_type_ids=None,
...@@ -935,71 +1042,58 @@ class TFXLMForMultipleChoice(TFXLMPreTrainedModel, TFMultipleChoiceLoss): ...@@ -935,71 +1042,58 @@ class TFXLMForMultipleChoice(TFXLMPreTrainedModel, TFMultipleChoiceLoss):
return_dict=None, return_dict=None,
labels=None, labels=None,
training=False, training=False,
**kwargs,
): ):
r""" inputs = input_processing(
labels (:obj:`tf.Tensor` of shape :obj:`(batch_size,)`, `optional`): func=self.call,
Labels for computing the multiple choice classification loss. Indices should be in ``[0, ..., input_ids=input_ids,
num_choices]`` where :obj:`num_choices` is the size of the second dimension of the input tensors. (See attention_mask=attention_mask,
:obj:`input_ids` above) langs=langs,
""" token_type_ids=token_type_ids,
if isinstance(inputs, (tuple, list)): position_ids=position_ids,
input_ids = inputs[0] lengths=lengths,
attention_mask = inputs[1] if len(inputs) > 1 else attention_mask cache=cache,
langs = inputs[2] if len(inputs) > 2 else langs head_mask=head_mask,
token_type_ids = inputs[3] if len(inputs) > 3 else token_type_ids inputs_embeds=inputs_embeds,
position_ids = inputs[4] if len(inputs) > 4 else position_ids output_attentions=output_attentions,
lengths = inputs[5] if len(inputs) > 5 else lengths output_hidden_states=output_hidden_states,
cache = inputs[6] if len(inputs) > 6 else cache return_dict=return_dict,
head_mask = inputs[7] if len(inputs) > 7 else head_mask labels=labels,
inputs_embeds = inputs[8] if len(inputs) > 8 else inputs_embeds training=training,
output_attentions = inputs[9] if len(inputs) > 9 else output_attentions kwargs_call=kwargs,
output_hidden_states = inputs[10] if len(inputs) > 10 else output_hidden_states )
return_dict = inputs[11] if len(inputs) > 11 else return_dict return_dict = inputs["return_dict"] if inputs["return_dict"] is not None else self.transformer.return_dict
labels = inputs[12] if len(inputs) > 12 else labels
assert len(inputs) <= 13, "Too many inputs."
elif isinstance(inputs, (dict, BatchEncoding)):
input_ids = inputs.get("input_ids")
attention_mask = inputs.get("attention_mask", attention_mask)
langs = inputs.get("langs", langs)
token_type_ids = inputs.get("token_type_ids", token_type_ids)
position_ids = inputs.get("position_ids", position_ids)
lengths = inputs.get("lengths", lengths)
cache = inputs.get("cache", cache)
head_mask = inputs.get("head_mask", head_mask)
inputs_embeds = inputs.get("inputs_embeds", inputs_embeds)
output_attentions = inputs.get("output_attentions", output_attentions)
output_hidden_states = inputs.get("output_hidden_states", output_hidden_states)
return_dict = inputs.get("return_dict", return_dict)
labels = inputs.get("labels", labels)
assert len(inputs) <= 13, "Too many inputs."
else:
input_ids = inputs
return_dict = return_dict if return_dict is not None else self.transformer.return_dict
if input_ids is not None: if inputs["input_ids"] is not None:
num_choices = shape_list(input_ids)[1] num_choices = shape_list(inputs["input_ids"])[1]
seq_length = shape_list(input_ids)[2] seq_length = shape_list(inputs["input_ids"])[2]
else: else:
num_choices = shape_list(inputs_embeds)[1] num_choices = shape_list(inputs["inputs_embeds"])[1]
seq_length = shape_list(inputs_embeds)[2] seq_length = shape_list(inputs["inputs_embeds"])[2]
flat_input_ids = tf.reshape(input_ids, (-1, seq_length)) if input_ids is not None else None flat_input_ids = tf.reshape(inputs["input_ids"], (-1, seq_length)) if inputs["input_ids"] is not None else None
flat_attention_mask = tf.reshape(attention_mask, (-1, seq_length)) if attention_mask is not None else None flat_attention_mask = (
flat_token_type_ids = tf.reshape(token_type_ids, (-1, seq_length)) if token_type_ids is not None else None tf.reshape(inputs["attention_mask"], (-1, seq_length)) if inputs["attention_mask"] is not None else None
flat_position_ids = tf.reshape(position_ids, (-1, seq_length)) if position_ids is not None else None )
flat_langs = tf.reshape(langs, (-1, seq_length)) if langs is not None else None flat_token_type_ids = (
tf.reshape(inputs["token_type_ids"], (-1, seq_length)) if inputs["token_type_ids"] is not None else None
)
flat_position_ids = (
tf.reshape(inputs["position_ids"], (-1, seq_length)) if inputs["position_ids"] is not None else None
)
flat_langs = tf.reshape(inputs["langs"], (-1, seq_length)) if inputs["langs"] is not None else None
flat_inputs_embeds = ( flat_inputs_embeds = (
tf.reshape(inputs_embeds, (-1, seq_length, shape_list(inputs_embeds)[3])) tf.reshape(inputs["inputs_embeds"], (-1, seq_length, shape_list(inputs["inputs_embeds"])[3]))
if inputs_embeds is not None if inputs["inputs_embeds"] is not None
else None else None
) )
if lengths is not None: if inputs["lengths"] is not None:
logger.warn( logger.warn(
"The `lengths` parameter cannot be used with the XLM multiple choice models. Please use the " "The `lengths` parameter cannot be used with the XLM multiple choice models. Please use the "
"attention mask instead.", "attention mask instead.",
) )
lengths = None inputs["lengths"] = None
transformer_outputs = self.transformer( transformer_outputs = self.transformer(
flat_input_ids, flat_input_ids,
...@@ -1007,21 +1101,21 @@ class TFXLMForMultipleChoice(TFXLMPreTrainedModel, TFMultipleChoiceLoss): ...@@ -1007,21 +1101,21 @@ class TFXLMForMultipleChoice(TFXLMPreTrainedModel, TFMultipleChoiceLoss):
flat_langs, flat_langs,
flat_token_type_ids, flat_token_type_ids,
flat_position_ids, flat_position_ids,
lengths, inputs["lengths"],
cache, inputs["cache"],
head_mask, inputs["head_mask"],
flat_inputs_embeds, flat_inputs_embeds,
output_attentions, inputs["output_attentions"],
output_hidden_states, inputs["output_hidden_states"],
return_dict=return_dict, return_dict=return_dict,
training=training, training=inputs["training"],
) )
output = transformer_outputs[0] output = transformer_outputs[0]
logits = self.sequence_summary(output) logits = self.sequence_summary(output)
logits = self.logits_proj(logits) logits = self.logits_proj(logits)
reshaped_logits = tf.reshape(logits, (-1, num_choices)) reshaped_logits = tf.reshape(logits, (-1, num_choices))
loss = None if labels is None else self.compute_loss(labels, reshaped_logits) loss = None if inputs["labels"] is None else self.compute_loss(inputs["labels"], reshaped_logits)
if not return_dict: if not return_dict:
output = (reshaped_logits,) + transformer_outputs[1:] output = (reshaped_logits,) + transformer_outputs[1:]
...@@ -1062,7 +1156,7 @@ class TFXLMForTokenClassification(TFXLMPreTrainedModel, TFTokenClassificationLos ...@@ -1062,7 +1156,7 @@ class TFXLMForTokenClassification(TFXLMPreTrainedModel, TFTokenClassificationLos
) )
def call( def call(
self, self,
inputs=None, input_ids=None,
attention_mask=None, attention_mask=None,
langs=None, langs=None,
token_type_ids=None, token_type_ids=None,
...@@ -1076,22 +1170,16 @@ class TFXLMForTokenClassification(TFXLMPreTrainedModel, TFTokenClassificationLos ...@@ -1076,22 +1170,16 @@ class TFXLMForTokenClassification(TFXLMPreTrainedModel, TFTokenClassificationLos
return_dict=None, return_dict=None,
labels=None, labels=None,
training=False, training=False,
**kwargs,
): ):
r""" r"""
labels (:obj:`tf.Tensor` of shape :obj:`(batch_size, sequence_length)`, `optional`): labels (:obj:`tf.Tensor` of shape :obj:`(batch_size, sequence_length)`, `optional`):
Labels for computing the token classification loss. Indices should be in ``[0, ..., config.num_labels - Labels for computing the token classification loss. Indices should be in ``[0, ..., config.num_labels -
1]``. 1]``.
""" """
return_dict = return_dict if return_dict is not None else self.transformer.return_dict inputs = input_processing(
if isinstance(inputs, (tuple, list)): func=self.call,
labels = inputs[12] if len(inputs) > 12 else labels input_ids=input_ids,
if len(inputs) > 12:
inputs = inputs[:12]
elif isinstance(inputs, (dict, BatchEncoding)):
labels = inputs.pop("labels", labels)
transformer_outputs = self.transformer(
inputs,
attention_mask=attention_mask, attention_mask=attention_mask,
langs=langs, langs=langs,
token_type_ids=token_type_ids, token_type_ids=token_type_ids,
...@@ -1103,15 +1191,33 @@ class TFXLMForTokenClassification(TFXLMPreTrainedModel, TFTokenClassificationLos ...@@ -1103,15 +1191,33 @@ class TFXLMForTokenClassification(TFXLMPreTrainedModel, TFTokenClassificationLos
output_attentions=output_attentions, output_attentions=output_attentions,
output_hidden_states=output_hidden_states, output_hidden_states=output_hidden_states,
return_dict=return_dict, return_dict=return_dict,
labels=labels,
training=training, training=training,
kwargs_call=kwargs,
)
return_dict = inputs["return_dict"] if inputs["return_dict"] is not None else self.transformer.return_dict
transformer_outputs = self.transformer(
input_ids=inputs["input_ids"],
attention_mask=inputs["attention_mask"],
langs=inputs["langs"],
token_type_ids=inputs["token_type_ids"],
position_ids=inputs["position_ids"],
lengths=inputs["lengths"],
cache=inputs["cache"],
head_mask=inputs["head_mask"],
inputs_embeds=inputs["inputs_embeds"],
output_attentions=inputs["output_attentions"],
output_hidden_states=inputs["output_hidden_states"],
return_dict=return_dict,
training=inputs["training"],
) )
sequence_output = transformer_outputs[0] sequence_output = transformer_outputs[0]
sequence_output = self.dropout(sequence_output, training=training) sequence_output = self.dropout(sequence_output, training=inputs["training"])
logits = self.classifier(sequence_output) logits = self.classifier(sequence_output)
loss = None if labels is None else self.compute_loss(labels, logits) loss = None if inputs["labels"] is None else self.compute_loss(inputs["labels"], logits)
if not return_dict: if not return_dict:
output = (logits,) + transformer_outputs[1:] output = (logits,) + transformer_outputs[1:]
...@@ -1149,7 +1255,7 @@ class TFXLMForQuestionAnsweringSimple(TFXLMPreTrainedModel, TFQuestionAnsweringL ...@@ -1149,7 +1255,7 @@ class TFXLMForQuestionAnsweringSimple(TFXLMPreTrainedModel, TFQuestionAnsweringL
) )
def call( def call(
self, self,
inputs=None, input_ids=None,
attention_mask=None, attention_mask=None,
langs=None, langs=None,
token_type_ids=None, token_type_ids=None,
...@@ -1164,6 +1270,7 @@ class TFXLMForQuestionAnsweringSimple(TFXLMPreTrainedModel, TFQuestionAnsweringL ...@@ -1164,6 +1270,7 @@ class TFXLMForQuestionAnsweringSimple(TFXLMPreTrainedModel, TFQuestionAnsweringL
start_positions=None, start_positions=None,
end_positions=None, end_positions=None,
training=False, training=False,
**kwargs,
): ):
r""" r"""
start_positions (:obj:`tf.Tensor` of shape :obj:`(batch_size,)`, `optional`): start_positions (:obj:`tf.Tensor` of shape :obj:`(batch_size,)`, `optional`):
...@@ -1175,18 +1282,9 @@ class TFXLMForQuestionAnsweringSimple(TFXLMPreTrainedModel, TFQuestionAnsweringL ...@@ -1175,18 +1282,9 @@ class TFXLMForQuestionAnsweringSimple(TFXLMPreTrainedModel, TFQuestionAnsweringL
Positions are clamped to the length of the sequence (:obj:`sequence_length`). Position outside of the Positions are clamped to the length of the sequence (:obj:`sequence_length`). Position outside of the
sequence are not taken into account for computing the loss. sequence are not taken into account for computing the loss.
""" """
return_dict = return_dict if return_dict is not None else self.transformer.return_dict inputs = input_processing(
if isinstance(inputs, (tuple, list)): func=self.call,
start_positions = inputs[12] if len(inputs) > 12 else start_positions input_ids=input_ids,
end_positions = inputs[13] if len(inputs) > 13 else end_positions
if len(inputs) > 12:
inputs = inputs[:12]
elif isinstance(inputs, (dict, BatchEncoding)):
start_positions = inputs.pop("start_positions", start_positions)
end_positions = inputs.pop("end_positions", start_positions)
transformer_outputs = self.transformer(
inputs,
attention_mask=attention_mask, attention_mask=attention_mask,
langs=langs, langs=langs,
token_type_ids=token_type_ids, token_type_ids=token_type_ids,
...@@ -1198,7 +1296,26 @@ class TFXLMForQuestionAnsweringSimple(TFXLMPreTrainedModel, TFQuestionAnsweringL ...@@ -1198,7 +1296,26 @@ class TFXLMForQuestionAnsweringSimple(TFXLMPreTrainedModel, TFQuestionAnsweringL
output_attentions=output_attentions, output_attentions=output_attentions,
output_hidden_states=output_hidden_states, output_hidden_states=output_hidden_states,
return_dict=return_dict, return_dict=return_dict,
start_positions=start_positions,
end_positions=end_positions,
training=training, training=training,
kwargs_call=kwargs,
)
return_dict = inputs["return_dict"] if inputs["return_dict"] is not None else self.transformer.return_dict
transformer_outputs = self.transformer(
input_ids=inputs["input_ids"],
attention_mask=inputs["attention_mask"],
langs=inputs["langs"],
token_type_ids=inputs["token_type_ids"],
position_ids=inputs["position_ids"],
lengths=inputs["lengths"],
cache=inputs["cache"],
head_mask=inputs["head_mask"],
inputs_embeds=inputs["inputs_embeds"],
output_attentions=inputs["output_attentions"],
output_hidden_states=inputs["output_hidden_states"],
return_dict=return_dict,
training=inputs["training"],
) )
sequence_output = transformer_outputs[0] sequence_output = transformer_outputs[0]
...@@ -1209,9 +1326,9 @@ class TFXLMForQuestionAnsweringSimple(TFXLMPreTrainedModel, TFQuestionAnsweringL ...@@ -1209,9 +1326,9 @@ class TFXLMForQuestionAnsweringSimple(TFXLMPreTrainedModel, TFQuestionAnsweringL
end_logits = tf.squeeze(end_logits, axis=-1) end_logits = tf.squeeze(end_logits, axis=-1)
loss = None loss = None
if start_positions is not None and end_positions is not None: if inputs["start_positions"] is not None and inputs["end_positions"] is not None:
labels = {"start_position": start_positions} labels = {"start_position": inputs["start_positions"]}
labels["end_position"] = end_positions labels["end_position"] = inputs["end_positions"]
loss = self.compute_loss(labels, (start_logits, end_logits)) loss = self.compute_loss(labels, (start_logits, end_logits))
if not return_dict: if not return_dict:
......
@@ -17,7 +17,6 @@
  TF 2.0 XLNet model.
 """
-
 from dataclasses import dataclass
 from typing import List, Optional, Tuple

@@ -42,10 +41,10 @@ from ...modeling_tf_utils import (
     TFSharedEmbeddings,
     TFTokenClassificationLoss,
     get_initializer,
+    input_processing,
     keras_serializable,
     shape_list,
 )
-from ...tokenization_utils import BatchEncoding
 from ...utils import logging
 from .configuration_xlnet import XLNetConfig
@@ -561,7 +560,7 @@ class TFXLNetMainLayer(tf.keras.layers.Layer):

     def call(
         self,
-        inputs,
+        input_ids=None,
         attention_mask=None,
         mems=None,
         perm_mask=None,
...@@ -575,66 +574,66 @@ class TFXLNetMainLayer(tf.keras.layers.Layer): ...@@ -575,66 +574,66 @@ class TFXLNetMainLayer(tf.keras.layers.Layer):
output_hidden_states=None, output_hidden_states=None,
return_dict=None, return_dict=None,
training=False, training=False,
**kwargs,
): ):
if isinstance(inputs, (tuple, list)): inputs = input_processing(
input_ids = inputs[0] func=self.call,
attention_mask = inputs[1] if len(inputs) > 1 else attention_mask input_ids=input_ids,
mems = inputs[2] if len(inputs) > 2 else mems attention_mask=attention_mask,
perm_mask = inputs[3] if len(inputs) > 3 else perm_mask mems=mems,
target_mapping = inputs[4] if len(inputs) > 4 else target_mapping perm_mask=perm_mask,
token_type_ids = inputs[5] if len(inputs) > 5 else token_type_ids target_mapping=target_mapping,
input_mask = inputs[6] if len(inputs) > 6 else input_mask token_type_ids=token_type_ids,
head_mask = inputs[7] if len(inputs) > 7 else head_mask input_mask=input_mask,
inputs_embeds = inputs[8] if len(inputs) > 8 else inputs_embeds head_mask=head_mask,
use_cache = inputs[9] if len(inputs) > 9 else use_cache inputs_embeds=inputs_embeds,
output_attentions = inputs[10] if len(inputs) > 10 else output_attentions use_cache=use_cache,
output_hidden_states = inputs[11] if len(inputs) > 11 else output_hidden_states output_attentions=output_attentions,
return_dict = inputs[12] if len(inputs) > 12 else return_dict output_hidden_states=output_hidden_states,
assert len(inputs) <= 13, "Too many inputs." return_dict=return_dict,
elif isinstance(inputs, (dict, BatchEncoding)): training=training,
input_ids = inputs.get("input_ids") kwargs_call=kwargs,
attention_mask = inputs.get("attention_mask", attention_mask) )
mems = inputs.get("mems", mems) output_attentions = (
perm_mask = inputs.get("perm_mask", perm_mask) inputs["output_attentions"] if inputs["output_attentions"] is not None else self.output_attentions
target_mapping = inputs.get("target_mapping", target_mapping) )
token_type_ids = inputs.get("token_type_ids", token_type_ids) output_hidden_states = (
input_mask = inputs.get("input_mask", input_mask) inputs["output_hidden_states"] if inputs["output_hidden_states"] is not None else self.output_hidden_states
head_mask = inputs.get("head_mask", head_mask) )
inputs_embeds = inputs.get("inputs_embeds", inputs_embeds) return_dict = inputs["return_dict"] if inputs["return_dict"] is not None else self.return_dict
use_cache = inputs.get("use_cache", use_cache)
output_attentions = inputs.get("output_attentions", output_attentions)
output_hidden_states = inputs.get("output_hidden_states", output_hidden_states)
return_dict = inputs.get("return_dict", return_dict)
assert len(inputs) <= 13, "Too many inputs."
else:
input_ids = inputs
output_attentions = output_attentions if output_attentions is not None else self.output_attentions
output_hidden_states = output_hidden_states if output_hidden_states is not None else self.output_hidden_states
return_dict = return_dict if return_dict is not None else self.return_dict
# the original code for XLNet uses shapes [len, bsz] with the batch dimension at the end # the original code for XLNet uses shapes [len, bsz] with the batch dimension at the end
# but we want a unified interface in the library with the batch size on the first dimension # but we want a unified interface in the library with the batch size on the first dimension
# so we move here the first dimension (batch) to the end # so we move here the first dimension (batch) to the end
if input_ids is not None and inputs_embeds is not None: if inputs["input_ids"] is not None and inputs["inputs_embeds"] is not None:
raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time") raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
elif input_ids is not None: elif inputs["input_ids"] is not None:
input_ids = tf.transpose(input_ids, perm=(1, 0)) inputs["input_ids"] = tf.transpose(inputs["input_ids"], perm=(1, 0))
qlen, bsz = shape_list(input_ids)[:2] qlen, bsz = shape_list(inputs["input_ids"])[:2]
elif inputs_embeds is not None: elif inputs["inputs_embeds"] is not None:
inputs_embeds = tf.transpose(inputs_embeds, perm=(1, 0, 2)) inputs["inputs_embeds"] = tf.transpose(inputs["inputs_embeds"], perm=(1, 0, 2))
qlen, bsz = shape_list(inputs_embeds)[:2] qlen, bsz = shape_list(inputs["inputs_embeds"])[:2]
else: else:
raise ValueError("You have to specify either input_ids or inputs_embeds") raise ValueError("You have to specify either input_ids or inputs_embeds")
token_type_ids = tf.transpose(token_type_ids, perm=(1, 0)) if token_type_ids is not None else None inputs["token_type_ids"] = (
input_mask = tf.transpose(input_mask, perm=(1, 0)) if input_mask is not None else None tf.transpose(inputs["token_type_ids"], perm=(1, 0)) if inputs["token_type_ids"] is not None else None
attention_mask = tf.transpose(attention_mask, perm=(1, 0)) if attention_mask is not None else None )
perm_mask = tf.transpose(perm_mask, perm=(1, 2, 0)) if perm_mask is not None else None inputs["input_mask"] = (
target_mapping = tf.transpose(target_mapping, perm=(1, 2, 0)) if target_mapping is not None else None tf.transpose(inputs["input_mask"], perm=(1, 0)) if inputs["input_mask"] is not None else None
)
inputs["attention_mask"] = (
tf.transpose(inputs["attention_mask"], perm=(1, 0)) if inputs["attention_mask"] is not None else None
)
inputs["perm_mask"] = (
tf.transpose(inputs["perm_mask"], perm=(1, 2, 0)) if inputs["perm_mask"] is not None else None
)
inputs["target_mapping"] = (
tf.transpose(inputs["target_mapping"], perm=(1, 2, 0)) if inputs["target_mapping"] is not None else None
)
mlen = shape_list(mems[0])[0] if mems is not None and mems[0] is not None else 0 mlen = shape_list(inputs["mems"][0])[0] if inputs["mems"] is not None and inputs["mems"][0] is not None else 0
klen = mlen + qlen klen = mlen + qlen
dtype_float = tf.bfloat16 if self.use_bfloat16 else tf.float32 dtype_float = tf.bfloat16 if self.use_bfloat16 else tf.float32
...@@ -650,18 +649,18 @@ class TFXLNetMainLayer(tf.keras.layers.Layer): ...@@ -650,18 +649,18 @@ class TFXLNetMainLayer(tf.keras.layers.Layer):
raise ValueError("Unsupported attention type: {}".format(self.attn_type)) raise ValueError("Unsupported attention type: {}".format(self.attn_type))
# data mask: input mask & perm mask # data mask: input mask & perm mask
assert input_mask is None or attention_mask is None, ( assert inputs["input_mask"] is None or inputs["attention_mask"] is None, (
"You can only use one of input_mask (uses 1 for padding) " "You can only use one of input_mask (uses 1 for padding) "
"or attention_mask (uses 0 for padding, added for compatibility with BERT). Please choose one." "or attention_mask (uses 0 for padding, added for compatibility with BERT). Please choose one."
) )
if input_mask is None and attention_mask is not None: if inputs["input_mask"] is None and inputs["attention_mask"] is not None:
input_mask = 1.0 - tf.cast(attention_mask, dtype=dtype_float) inputs["input_mask"] = 1.0 - tf.cast(inputs["attention_mask"], dtype=dtype_float)
if input_mask is not None and perm_mask is not None: if inputs["input_mask"] is not None and inputs["perm_mask"] is not None:
data_mask = input_mask[None] + perm_mask data_mask = inputs["input_mask"][None] + inputs["perm_mask"]
elif input_mask is not None and perm_mask is None: elif inputs["input_mask"] is not None and inputs["perm_mask"] is None:
data_mask = input_mask[None] data_mask = inputs["input_mask"][None]
elif input_mask is None and perm_mask is not None: elif inputs["input_mask"] is None and inputs["perm_mask"] is not None:
data_mask = perm_mask data_mask = inputs["perm_mask"]
else: else:
data_mask = None data_mask = None
...@@ -687,59 +686,59 @@ class TFXLNetMainLayer(tf.keras.layers.Layer): ...@@ -687,59 +686,59 @@ class TFXLNetMainLayer(tf.keras.layers.Layer):
non_tgt_mask = None non_tgt_mask = None
# Word embeddings and prepare h & g hidden states # Word embeddings and prepare h & g hidden states
if inputs_embeds is not None: if inputs["inputs_embeds"] is not None:
word_emb_k = inputs_embeds word_emb_k = inputs["inputs_embeds"]
else: else:
word_emb_k = self.word_embedding(input_ids) word_emb_k = self.word_embedding(inputs["input_ids"])
output_h = self.dropout(word_emb_k, training=training) output_h = self.dropout(word_emb_k, training=inputs["training"])
if target_mapping is not None: if inputs["target_mapping"] is not None:
word_emb_q = tf.tile(self.mask_emb, [shape_list(target_mapping)[0], bsz, 1]) word_emb_q = tf.tile(self.mask_emb, [shape_list(inputs["target_mapping"])[0], bsz, 1])
# else: # We removed the inp_q input which was same as target mapping # else: # We removed the inp_q input which was same as target mapping
# inp_q_ext = inp_q[:, :, None] # inp_q_ext = inp_q[:, :, None]
# word_emb_q = inp_q_ext * self.mask_emb + (1 - inp_q_ext) * word_emb_k # word_emb_q = inp_q_ext * self.mask_emb + (1 - inp_q_ext) * word_emb_k
output_g = self.dropout(word_emb_q, training=training) output_g = self.dropout(word_emb_q, training=inputs["training"])
else: else:
output_g = None output_g = None
# Segment embedding # Segment embedding
if token_type_ids is not None: if inputs["token_type_ids"] is not None:
# Convert `token_type_ids` to one-hot `seg_mat` # Convert `token_type_ids` to one-hot `seg_mat`
if mlen > 0: if mlen > 0:
mem_pad = tf.zeros([mlen, bsz], dtype=tf.int32) mem_pad = tf.zeros([mlen, bsz], dtype=tf.int32)
cat_ids = tf.concat([mem_pad, token_type_ids], 0) cat_ids = tf.concat([mem_pad, inputs["token_type_ids"]], 0)
else: else:
cat_ids = token_type_ids cat_ids = inputs["token_type_ids"]
# `1` indicates not in the same segment [qlen x klen x bsz] # `1` indicates not in the same segment [qlen x klen x bsz]
seg_mat = tf.cast(tf.logical_not(tf.equal(token_type_ids[:, None], cat_ids[None, :])), tf.int32) seg_mat = tf.cast(tf.logical_not(tf.equal(inputs["token_type_ids"][:, None], cat_ids[None, :])), tf.int32)
seg_mat = tf.one_hot(seg_mat, 2, dtype=dtype_float) seg_mat = tf.one_hot(seg_mat, 2, dtype=dtype_float)
else: else:
seg_mat = None seg_mat = None
# Positional encoding # Positional encoding
pos_emb = self.relative_positional_encoding(qlen, klen, bsz=bsz, dtype=dtype_float) pos_emb = self.relative_positional_encoding(qlen, klen, bsz=bsz, dtype=dtype_float)
pos_emb = self.dropout(pos_emb, training=training) pos_emb = self.dropout(pos_emb, training=inputs["training"])
# Prepare head mask if needed # Prepare head mask if needed
# 1.0 in head_mask indicate we keep the head # 1.0 in head_mask indicate we keep the head
# attention_probs has shape bsz x n_heads x N x N # attention_probs has shape bsz x n_heads x N x N
# input head_mask has shape [num_heads] or [num_hidden_layers x num_heads] (a head_mask for each layer) # input head_mask has shape [num_heads] or [num_hidden_layers x num_heads] (a head_mask for each layer)
# and head_mask is converted to shape [num_hidden_layers x qlen x klen x bsz x n_head] # and head_mask is converted to shape [num_hidden_layers x qlen x klen x bsz x n_head]
if head_mask is not None: if inputs["head_mask"] is not None:
raise NotImplementedError raise NotImplementedError
else: else:
head_mask = [None] * self.n_layer inputs["head_mask"] = [None] * self.n_layer
new_mems = () new_mems = ()
if mems is None: if inputs["mems"] is None:
mems = [None] * len(self.layer) inputs["mems"] = [None] * len(self.layer)
attentions = [] if output_attentions else None attentions = [] if output_attentions else None
hidden_states = [] if output_hidden_states else None hidden_states = [] if output_hidden_states else None
for i, layer_module in enumerate(self.layer): for i, layer_module in enumerate(self.layer):
# cache new mems # cache new mems
if self.mem_len is not None and self.mem_len > 0 and use_cache: if self.mem_len is not None and self.mem_len > 0 and use_cache:
new_mems = new_mems + (self.cache_mem(output_h, mems[i]),) new_mems = new_mems + (self.cache_mem(output_h, inputs["mems"][i]),)
if output_hidden_states: if output_hidden_states:
hidden_states.append((output_h, output_g) if output_g is not None else output_h) hidden_states.append((output_h, output_g) if output_g is not None else output_h)
...@@ -750,11 +749,11 @@ class TFXLNetMainLayer(tf.keras.layers.Layer): ...@@ -750,11 +749,11 @@ class TFXLNetMainLayer(tf.keras.layers.Layer):
attn_mask, attn_mask,
pos_emb, pos_emb,
seg_mat, seg_mat,
mems[i], inputs["mems"][i],
target_mapping, inputs["target_mapping"],
head_mask[i], inputs["head_mask"][i],
output_attentions, output_attentions,
training=training, training=inputs["training"],
) )
output_h, output_g = outputs[:2] output_h, output_g = outputs[:2]
if output_attentions: if output_attentions:
...@@ -764,7 +763,7 @@ class TFXLNetMainLayer(tf.keras.layers.Layer): ...@@ -764,7 +763,7 @@ class TFXLNetMainLayer(tf.keras.layers.Layer):
if output_hidden_states: if output_hidden_states:
hidden_states.append((output_h, output_g) if output_g is not None else output_h) hidden_states.append((output_h, output_g) if output_g is not None else output_h)
output = self.dropout(output_g if output_g is not None else output_h, training=training) output = self.dropout(output_g if output_g is not None else output_h, training=inputs["training"])
# Prepare outputs, we transpose back here to shape [bsz, len, hidden_dim] (cf. beginning of forward() method) # Prepare outputs, we transpose back here to shape [bsz, len, hidden_dim] (cf. beginning of forward() method)
output = tf.transpose(output, perm=(1, 0, 2)) output = tf.transpose(output, perm=(1, 0, 2))
...@@ -1137,8 +1136,59 @@ class TFXLNetModel(TFXLNetPreTrainedModel): ...@@ -1137,8 +1136,59 @@ class TFXLNetModel(TFXLNetPreTrainedModel):
output_type=TFXLNetModelOutput, output_type=TFXLNetModelOutput,
config_class=_CONFIG_FOR_DOC, config_class=_CONFIG_FOR_DOC,
) )
def call(self, inputs, **kwargs): def call(
outputs = self.transformer(inputs, **kwargs) self,
input_ids=None,
attention_mask=None,
mems=None,
perm_mask=None,
target_mapping=None,
token_type_ids=None,
input_mask=None,
head_mask=None,
inputs_embeds=None,
use_cache=True,
output_attentions=None,
output_hidden_states=None,
return_dict=None,
training=False,
**kwargs,
):
inputs = input_processing(
func=self.call,
input_ids=input_ids,
attention_mask=attention_mask,
mems=mems,
perm_mask=perm_mask,
target_mapping=target_mapping,
token_type_ids=token_type_ids,
input_mask=input_mask,
head_mask=head_mask,
inputs_embeds=inputs_embeds,
use_cache=use_cache,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
training=training,
kwargs_call=kwargs,
)
outputs = self.transformer(
input_ids=inputs["input_ids"],
attention_mask=inputs["attention_mask"],
mems=inputs["mems"],
perm_mask=inputs["perm_mask"],
target_mapping=inputs["target_mapping"],
token_type_ids=inputs["token_type_ids"],
input_mask=inputs["input_mask"],
head_mask=inputs["head_mask"],
inputs_embeds=inputs["inputs_embeds"],
use_cache=inputs["use_cache"],
output_attentions=inputs["output_attentions"],
output_hidden_states=inputs["output_hidden_states"],
return_dict=inputs["return_dict"],
training=inputs["training"],
)
return outputs return outputs
@@ -1185,7 +1235,7 @@ class TFXLNetLMHeadModel(TFXLNetPreTrainedModel, TFCausalLanguageModelingLoss):
             target_mapping = tf.concat([target_mapping, target_mapping_seq_end], axis=-1)

         inputs = {
-            "inputs": inputs,
+            "input_ids": inputs,
             "perm_mask": perm_mask,
             "target_mapping": target_mapping,
             "use_cache": kwargs["use_cache"],
...@@ -1201,7 +1251,7 @@ class TFXLNetLMHeadModel(TFXLNetPreTrainedModel, TFCausalLanguageModelingLoss): ...@@ -1201,7 +1251,7 @@ class TFXLNetLMHeadModel(TFXLNetPreTrainedModel, TFCausalLanguageModelingLoss):
@replace_return_docstrings(output_type=TFXLNetLMHeadModelOutput, config_class=_CONFIG_FOR_DOC) @replace_return_docstrings(output_type=TFXLNetLMHeadModelOutput, config_class=_CONFIG_FOR_DOC)
def call( def call(
self, self,
inputs, input_ids=None,
attention_mask=None, attention_mask=None,
mems=None, mems=None,
perm_mask=None, perm_mask=None,
...@@ -1216,6 +1266,7 @@ class TFXLNetLMHeadModel(TFXLNetPreTrainedModel, TFCausalLanguageModelingLoss): ...@@ -1216,6 +1266,7 @@ class TFXLNetLMHeadModel(TFXLNetPreTrainedModel, TFCausalLanguageModelingLoss):
return_dict=None, return_dict=None,
labels=None, labels=None,
training=False, training=False,
**kwargs,
): ):
r""" r"""
labels (:obj:`tf.Tensor` of shape :obj:`(batch_size, sequence_length)`, `optional`): labels (:obj:`tf.Tensor` of shape :obj:`(batch_size, sequence_length)`, `optional`):
...@@ -1247,16 +1298,9 @@ class TFXLNetLMHeadModel(TFXLNetPreTrainedModel, TFCausalLanguageModelingLoss): ...@@ -1247,16 +1298,9 @@ class TFXLNetLMHeadModel(TFXLNetPreTrainedModel, TFCausalLanguageModelingLoss):
>>> next_token_logits = outputs[0] # Output has shape [target_mapping.size(0), target_mapping.size(1), config.vocab_size] >>> next_token_logits = outputs[0] # Output has shape [target_mapping.size(0), target_mapping.size(1), config.vocab_size]
""" """
return_dict = return_dict if return_dict is not None else self.transformer.return_dict inputs = input_processing(
if isinstance(inputs, (tuple, list)): func=self.call,
labels = inputs[13] if len(inputs) > 13 else labels input_ids=input_ids,
if len(inputs) > 13:
inputs = inputs[:13]
elif isinstance(inputs, (dict, BatchEncoding)):
labels = inputs.pop("labels", labels)
transformer_outputs = self.transformer(
inputs,
attention_mask=attention_mask, attention_mask=attention_mask,
mems=mems, mems=mems,
perm_mask=perm_mask, perm_mask=perm_mask,
...@@ -1269,16 +1313,35 @@ class TFXLNetLMHeadModel(TFXLNetPreTrainedModel, TFCausalLanguageModelingLoss): ...@@ -1269,16 +1313,35 @@ class TFXLNetLMHeadModel(TFXLNetPreTrainedModel, TFCausalLanguageModelingLoss):
output_attentions=output_attentions, output_attentions=output_attentions,
output_hidden_states=output_hidden_states, output_hidden_states=output_hidden_states,
return_dict=return_dict, return_dict=return_dict,
labels=labels,
training=training, training=training,
kwargs_call=kwargs,
)
return_dict = inputs["return_dict"] if inputs["return_dict"] is not None else self.transformer.return_dict
transformer_outputs = self.transformer(
input_ids=inputs["input_ids"],
attention_mask=inputs["attention_mask"],
mems=inputs["mems"],
perm_mask=inputs["perm_mask"],
target_mapping=inputs["target_mapping"],
token_type_ids=inputs["token_type_ids"],
input_mask=inputs["input_mask"],
head_mask=inputs["head_mask"],
inputs_embeds=inputs["inputs_embeds"],
use_cache=inputs["use_cache"],
output_attentions=inputs["output_attentions"],
output_hidden_states=inputs["output_hidden_states"],
return_dict=return_dict,
training=inputs["training"],
) )
hidden_state = transformer_outputs[0] hidden_state = transformer_outputs[0]
logits = self.lm_loss(hidden_state, training=training) logits = self.lm_loss(hidden_state, training=inputs["training"])
loss = None loss = None
if labels is not None: if inputs["labels"] is not None:
# shift labels to the left and cut last logit token # shift labels to the left and cut last logit token
logits = logits[:, :-1] logits = logits[:, :-1]
labels = labels[:, 1:] labels = inputs["labels"][:, 1:]
loss = self.compute_loss(labels, logits) loss = self.compute_loss(labels, logits)
if not return_dict: if not return_dict:
...@@ -1323,7 +1386,7 @@ class TFXLNetForSequenceClassification(TFXLNetPreTrainedModel, TFSequenceClassif ...@@ -1323,7 +1386,7 @@ class TFXLNetForSequenceClassification(TFXLNetPreTrainedModel, TFSequenceClassif
) )
def call( def call(
self, self,
inputs=None, input_ids=None,
attention_mask=None, attention_mask=None,
mems=None, mems=None,
perm_mask=None, perm_mask=None,
...@@ -1338,6 +1401,7 @@ class TFXLNetForSequenceClassification(TFXLNetPreTrainedModel, TFSequenceClassif ...@@ -1338,6 +1401,7 @@ class TFXLNetForSequenceClassification(TFXLNetPreTrainedModel, TFSequenceClassif
return_dict=None, return_dict=None,
labels=None, labels=None,
training=False, training=False,
**kwargs,
): ):
r""" r"""
labels (:obj:`tf.Tensor` of shape :obj:`(batch_size,)`, `optional`): labels (:obj:`tf.Tensor` of shape :obj:`(batch_size,)`, `optional`):
...@@ -1345,16 +1409,9 @@ class TFXLNetForSequenceClassification(TFXLNetPreTrainedModel, TFSequenceClassif ...@@ -1345,16 +1409,9 @@ class TFXLNetForSequenceClassification(TFXLNetPreTrainedModel, TFSequenceClassif
config.num_labels - 1]``. If ``config.num_labels == 1`` a regression loss is computed (Mean-Square loss), config.num_labels - 1]``. If ``config.num_labels == 1`` a regression loss is computed (Mean-Square loss),
If ``config.num_labels > 1`` a classification loss is computed (Cross-Entropy). If ``config.num_labels > 1`` a classification loss is computed (Cross-Entropy).
""" """
return_dict = return_dict if return_dict is not None else self.transformer.return_dict inputs = input_processing(
if isinstance(inputs, (tuple, list)): func=self.call,
labels = inputs[13] if len(inputs) > 13 else labels input_ids=input_ids,
if len(inputs) > 13:
inputs = inputs[:13]
elif isinstance(inputs, (dict, BatchEncoding)):
labels = inputs.pop("labels", labels)
transformer_outputs = self.transformer(
inputs,
attention_mask=attention_mask, attention_mask=attention_mask,
mems=mems, mems=mems,
perm_mask=perm_mask, perm_mask=perm_mask,
...@@ -1367,13 +1424,33 @@ class TFXLNetForSequenceClassification(TFXLNetPreTrainedModel, TFSequenceClassif ...@@ -1367,13 +1424,33 @@ class TFXLNetForSequenceClassification(TFXLNetPreTrainedModel, TFSequenceClassif
output_attentions=output_attentions, output_attentions=output_attentions,
output_hidden_states=output_hidden_states, output_hidden_states=output_hidden_states,
return_dict=return_dict, return_dict=return_dict,
labels=labels,
training=training,
kwargs_call=kwargs,
)
return_dict = inputs["return_dict"] if inputs["return_dict"] is not None else self.transformer.return_dict
transformer_outputs = self.transformer(
input_ids=inputs["input_ids"],
attention_mask=inputs["attention_mask"],
mems=inputs["mems"],
perm_mask=inputs["perm_mask"],
target_mapping=inputs["target_mapping"],
token_type_ids=inputs["token_type_ids"],
input_mask=inputs["input_mask"],
head_mask=inputs["head_mask"],
inputs_embeds=inputs["inputs_embeds"],
use_cache=inputs["use_cache"],
output_attentions=inputs["output_attentions"],
output_hidden_states=inputs["output_hidden_states"],
return_dict=return_dict,
training=inputs["training"],
) )
output = transformer_outputs[0] output = transformer_outputs[0]
output = self.sequence_summary(output) output = self.sequence_summary(output)
logits = self.logits_proj(output) logits = self.logits_proj(output)
loss = None if labels is None else self.compute_loss(labels, logits) loss = None if inputs["labels"] is None else self.compute_loss(inputs["labels"], logits)
if not return_dict: if not return_dict:
output = (logits,) + transformer_outputs[1:] output = (logits,) + transformer_outputs[1:]
...@@ -1426,7 +1503,7 @@ class TFXLNetForMultipleChoice(TFXLNetPreTrainedModel, TFMultipleChoiceLoss): ...@@ -1426,7 +1503,7 @@ class TFXLNetForMultipleChoice(TFXLNetPreTrainedModel, TFMultipleChoiceLoss):
) )
def call( def call(
self, self,
inputs=None, input_ids=None,
token_type_ids=None, token_type_ids=None,
input_mask=None, input_mask=None,
attention_mask=None, attention_mask=None,
...@@ -1441,6 +1518,7 @@ class TFXLNetForMultipleChoice(TFXLNetPreTrainedModel, TFMultipleChoiceLoss): ...@@ -1441,6 +1518,7 @@ class TFXLNetForMultipleChoice(TFXLNetPreTrainedModel, TFMultipleChoiceLoss):
return_dict=None, return_dict=None,
labels=None, labels=None,
training=False, training=False,
**kwargs,
): ):
r""" r"""
labels (:obj:`tf.Tensor` of shape :obj:`(batch_size,)`, `optional`): labels (:obj:`tf.Tensor` of shape :obj:`(batch_size,)`, `optional`):
...@@ -1448,79 +1526,70 @@ class TFXLNetForMultipleChoice(TFXLNetPreTrainedModel, TFMultipleChoiceLoss): ...@@ -1448,79 +1526,70 @@ class TFXLNetForMultipleChoice(TFXLNetPreTrainedModel, TFMultipleChoiceLoss):
num_choices]`` where :obj:`num_choices` is the size of the second dimension of the input tensors. (See num_choices]`` where :obj:`num_choices` is the size of the second dimension of the input tensors. (See
:obj:`input_ids` above) :obj:`input_ids` above)
""" """
if isinstance(inputs, (tuple, list)): inputs = input_processing(
input_ids = inputs[0] func=self.call,
attention_mask = inputs[1] if len(inputs) > 1 else attention_mask input_ids=input_ids,
mems = inputs[2] if len(inputs) > 2 else mems attention_mask=attention_mask,
perm_mask = inputs[3] if len(inputs) > 3 else perm_mask mems=mems,
target_mapping = inputs[4] if len(inputs) > 4 else target_mapping perm_mask=perm_mask,
token_type_ids = inputs[5] if len(inputs) > 5 else token_type_ids target_mapping=target_mapping,
input_mask = inputs[6] if len(inputs) > 6 else input_mask token_type_ids=token_type_ids,
head_mask = inputs[7] if len(inputs) > 7 else head_mask input_mask=input_mask,
inputs_embeds = inputs[8] if len(inputs) > 8 else inputs_embeds head_mask=head_mask,
use_cache = inputs[9] if len(inputs) > 9 else use_cache inputs_embeds=inputs_embeds,
output_attentions = inputs[10] if len(inputs) > 10 else output_attentions use_cache=use_cache,
output_hidden_states = inputs[11] if len(inputs) > 11 else output_hidden_states output_attentions=output_attentions,
return_dict = inputs[12] if len(inputs) > 12 else return_dict output_hidden_states=output_hidden_states,
labels = inputs[13] if len(inputs) > 13 else labels return_dict=return_dict,
assert len(inputs) <= 14, "Too many inputs." labels=labels,
elif isinstance(inputs, (dict, BatchEncoding)): training=training,
input_ids = inputs.get("input_ids") kwargs_call=kwargs,
attention_mask = inputs.get("attention_mask", attention_mask) )
mems = inputs.get("mems", mems) return_dict = inputs["return_dict"] if inputs["return_dict"] is not None else self.transformer.return_dict
perm_mask = inputs.get("perm_mask", perm_mask)
target_mapping = inputs.get("target_mapping", target_mapping)
token_type_ids = inputs.get("token_type_ids", token_type_ids)
input_mask = inputs.get("input_mask", input_mask)
head_mask = inputs.get("head_mask", head_mask)
inputs_embeds = inputs.get("inputs_embeds", inputs_embeds)
use_cache = inputs.get("use_cache", use_cache)
output_attentions = inputs.get("output_attentions", output_attentions)
output_hidden_states = inputs.get("output_hidden_states", output_hidden_states)
return_dict = inputs.get("return_dict", return_dict)
labels = inputs.get("labels", labels)
assert len(inputs) <= 14, "Too many inputs."
else:
input_ids = inputs
return_dict = return_dict if return_dict is not None else self.transformer.return_dict
if input_ids is not None: if inputs["input_ids"] is not None:
num_choices = shape_list(input_ids)[1] num_choices = shape_list(inputs["input_ids"])[1]
seq_length = shape_list(input_ids)[2] seq_length = shape_list(inputs["input_ids"])[2]
else: else:
num_choices = shape_list(inputs_embeds)[1] num_choices = shape_list(inputs["inputs_embeds"])[1]
seq_length = shape_list(inputs_embeds)[2] seq_length = shape_list(inputs["inputs_embeds"])[2]
flat_input_ids = tf.reshape(input_ids, (-1, seq_length)) if input_ids is not None else None flat_input_ids = tf.reshape(inputs["input_ids"], (-1, seq_length)) if inputs["input_ids"] is not None else None
flat_attention_mask = tf.reshape(attention_mask, (-1, seq_length)) if attention_mask is not None else None flat_attention_mask = (
flat_token_type_ids = tf.reshape(token_type_ids, (-1, seq_length)) if token_type_ids is not None else None tf.reshape(inputs["attention_mask"], (-1, seq_length)) if inputs["attention_mask"] is not None else None
flat_input_mask = tf.reshape(input_mask, (-1, seq_length)) if input_mask is not None else None )
flat_token_type_ids = (
tf.reshape(inputs["token_type_ids"], (-1, seq_length)) if inputs["token_type_ids"] is not None else None
)
flat_input_mask = (
tf.reshape(inputs["input_mask"], (-1, seq_length)) if inputs["input_mask"] is not None else None
)
flat_inputs_embeds = ( flat_inputs_embeds = (
tf.reshape(inputs_embeds, (-1, seq_length, shape_list(inputs_embeds)[3])) tf.reshape(inputs["inputs_embeds"], (-1, seq_length, shape_list(inputs["inputs_embeds"])[3]))
if inputs_embeds is not None if inputs["inputs_embeds"] is not None
else None else None
) )
transformer_outputs = self.transformer( transformer_outputs = self.transformer(
flat_input_ids, flat_input_ids,
flat_attention_mask, flat_attention_mask,
mems, inputs["mems"],
perm_mask, inputs["perm_mask"],
target_mapping, inputs["target_mapping"],
flat_token_type_ids, flat_token_type_ids,
flat_input_mask, flat_input_mask,
head_mask, inputs["head_mask"],
flat_inputs_embeds, flat_inputs_embeds,
use_cache, inputs["use_cache"],
output_attentions, inputs["output_attentions"],
output_hidden_states, inputs["output_hidden_states"],
return_dict=return_dict, return_dict=return_dict,
training=training, training=inputs["training"],
) )
output = transformer_outputs[0] output = transformer_outputs[0]
logits = self.sequence_summary(output) logits = self.sequence_summary(output)
logits = self.logits_proj(logits) logits = self.logits_proj(logits)
reshaped_logits = tf.reshape(logits, (-1, num_choices)) reshaped_logits = tf.reshape(logits, (-1, num_choices))
loss = None if labels is None else self.compute_loss(labels, reshaped_logits) loss = None if inputs["labels"] is None else self.compute_loss(inputs["labels"], reshaped_logits)
if not return_dict: if not return_dict:
output = (reshaped_logits,) + transformer_outputs[1:] output = (reshaped_logits,) + transformer_outputs[1:]
...@@ -1561,7 +1630,7 @@ class TFXLNetForTokenClassification(TFXLNetPreTrainedModel, TFTokenClassificatio ...@@ -1561,7 +1630,7 @@ class TFXLNetForTokenClassification(TFXLNetPreTrainedModel, TFTokenClassificatio
) )
def call( def call(
self, self,
inputs=None, input_ids=None,
attention_mask=None, attention_mask=None,
mems=None, mems=None,
perm_mask=None, perm_mask=None,
...@@ -1576,22 +1645,16 @@ class TFXLNetForTokenClassification(TFXLNetPreTrainedModel, TFTokenClassificatio ...@@ -1576,22 +1645,16 @@ class TFXLNetForTokenClassification(TFXLNetPreTrainedModel, TFTokenClassificatio
return_dict=None, return_dict=None,
labels=None, labels=None,
training=False, training=False,
**kwargs,
): ):
r""" r"""
labels (:obj:`tf.Tensor` of shape :obj:`(batch_size, sequence_length)`, `optional`): labels (:obj:`tf.Tensor` of shape :obj:`(batch_size, sequence_length)`, `optional`):
Labels for computing the token classification loss. Indices should be in ``[0, ..., config.num_labels - Labels for computing the token classification loss. Indices should be in ``[0, ..., config.num_labels -
1]``. 1]``.
""" """
return_dict = return_dict if return_dict is not None else self.transformer.return_dict inputs = input_processing(
if isinstance(inputs, (tuple, list)): func=self.call,
labels = inputs[13] if len(inputs) > 13 else labels input_ids=input_ids,
if len(inputs) > 13:
inputs = inputs[:13]
elif isinstance(inputs, (dict, BatchEncoding)):
labels = inputs.pop("labels", labels)
transformer_outputs = self.transformer(
inputs,
attention_mask=attention_mask, attention_mask=attention_mask,
mems=mems, mems=mems,
perm_mask=perm_mask, perm_mask=perm_mask,
...@@ -1604,12 +1667,31 @@ class TFXLNetForTokenClassification(TFXLNetPreTrainedModel, TFTokenClassificatio ...@@ -1604,12 +1667,31 @@ class TFXLNetForTokenClassification(TFXLNetPreTrainedModel, TFTokenClassificatio
output_attentions=output_attentions, output_attentions=output_attentions,
output_hidden_states=output_hidden_states, output_hidden_states=output_hidden_states,
return_dict=return_dict, return_dict=return_dict,
labels=labels,
training=training, training=training,
kwargs_call=kwargs,
)
return_dict = inputs["return_dict"] if inputs["return_dict"] is not None else self.transformer.return_dict
transformer_outputs = self.transformer(
input_ids=inputs["input_ids"],
attention_mask=inputs["attention_mask"],
mems=inputs["mems"],
perm_mask=inputs["perm_mask"],
target_mapping=inputs["target_mapping"],
token_type_ids=inputs["token_type_ids"],
input_mask=inputs["input_mask"],
head_mask=inputs["head_mask"],
inputs_embeds=inputs["inputs_embeds"],
use_cache=inputs["use_cache"],
output_attentions=inputs["output_attentions"],
output_hidden_states=inputs["output_hidden_states"],
return_dict=return_dict,
training=inputs["training"],
) )
output = transformer_outputs[0] output = transformer_outputs[0]
logits = self.classifier(output) logits = self.classifier(output)
loss = None if labels is None else self.compute_loss(labels, logits) loss = None if inputs["labels"] is None else self.compute_loss(inputs["labels"], logits)
if not return_dict: if not return_dict:
output = (logits,) + transformer_outputs[1:] output = (logits,) + transformer_outputs[1:]
...@@ -1648,7 +1730,7 @@ class TFXLNetForQuestionAnsweringSimple(TFXLNetPreTrainedModel, TFQuestionAnswer ...@@ -1648,7 +1730,7 @@ class TFXLNetForQuestionAnsweringSimple(TFXLNetPreTrainedModel, TFQuestionAnswer
) )
def call( def call(
self, self,
inputs=None, input_ids=None,
attention_mask=None, attention_mask=None,
mems=None, mems=None,
perm_mask=None, perm_mask=None,
...@@ -1664,6 +1746,7 @@ class TFXLNetForQuestionAnsweringSimple(TFXLNetPreTrainedModel, TFQuestionAnswer ...@@ -1664,6 +1746,7 @@ class TFXLNetForQuestionAnsweringSimple(TFXLNetPreTrainedModel, TFQuestionAnswer
start_positions=None, start_positions=None,
end_positions=None, end_positions=None,
training=False, training=False,
**kwargs,
): ):
r""" r"""
start_positions (:obj:`tf.Tensor` of shape :obj:`(batch_size,)`, `optional`): start_positions (:obj:`tf.Tensor` of shape :obj:`(batch_size,)`, `optional`):
...@@ -1675,18 +1758,9 @@ class TFXLNetForQuestionAnsweringSimple(TFXLNetPreTrainedModel, TFQuestionAnswer ...@@ -1675,18 +1758,9 @@ class TFXLNetForQuestionAnsweringSimple(TFXLNetPreTrainedModel, TFQuestionAnswer
Positions are clamped to the length of the sequence (:obj:`sequence_length`). Position outside of the Positions are clamped to the length of the sequence (:obj:`sequence_length`). Position outside of the
sequence are not taken into account for computing the loss. sequence are not taken into account for computing the loss.
""" """
return_dict = return_dict if return_dict is not None else self.transformer.return_dict inputs = input_processing(
if isinstance(inputs, (tuple, list)): func=self.call,
start_positions = inputs[13] if len(inputs) > 13 else start_positions input_ids=input_ids,
end_positions = inputs[14] if len(inputs) > 14 else end_positions
if len(inputs) > 13:
inputs = inputs[:13]
elif isinstance(inputs, (dict, BatchEncoding)):
start_positions = inputs.pop("start_positions", start_positions)
end_positions = inputs.pop("end_positions", start_positions)
transformer_outputs = self.transformer(
inputs,
attention_mask=attention_mask, attention_mask=attention_mask,
mems=mems, mems=mems,
perm_mask=perm_mask, perm_mask=perm_mask,
...@@ -1699,7 +1773,27 @@ class TFXLNetForQuestionAnsweringSimple(TFXLNetPreTrainedModel, TFQuestionAnswer ...@@ -1699,7 +1773,27 @@ class TFXLNetForQuestionAnsweringSimple(TFXLNetPreTrainedModel, TFQuestionAnswer
output_attentions=output_attentions, output_attentions=output_attentions,
output_hidden_states=output_hidden_states, output_hidden_states=output_hidden_states,
return_dict=return_dict, return_dict=return_dict,
start_positions=start_positions,
end_positions=end_positions,
training=training, training=training,
kwargs_call=kwargs,
)
return_dict = inputs["return_dict"] if inputs["return_dict"] is not None else self.transformer.return_dict
transformer_outputs = self.transformer(
input_ids=inputs["input_ids"],
attention_mask=inputs["attention_mask"],
mems=inputs["mems"],
perm_mask=inputs["perm_mask"],
target_mapping=inputs["target_mapping"],
token_type_ids=inputs["token_type_ids"],
input_mask=inputs["input_mask"],
head_mask=inputs["head_mask"],
inputs_embeds=inputs["inputs_embeds"],
use_cache=inputs["use_cache"],
output_attentions=inputs["output_attentions"],
output_hidden_states=inputs["output_hidden_states"],
return_dict=return_dict,
training=inputs["training"],
) )
sequence_output = transformer_outputs[0] sequence_output = transformer_outputs[0]
...@@ -1710,9 +1804,9 @@ class TFXLNetForQuestionAnsweringSimple(TFXLNetPreTrainedModel, TFQuestionAnswer ...@@ -1710,9 +1804,9 @@ class TFXLNetForQuestionAnsweringSimple(TFXLNetPreTrainedModel, TFQuestionAnswer
end_logits = tf.squeeze(end_logits, axis=-1) end_logits = tf.squeeze(end_logits, axis=-1)
loss = None loss = None
if start_positions is not None and end_positions is not None: if inputs["start_positions"] is not None and inputs["end_positions"] is not None:
labels = {"start_position": start_positions} labels = {"start_position": inputs["start_positions"]}
labels["end_position"] = end_positions labels["end_position"] = inputs["end_positions"]
loss = self.compute_loss(labels, (start_logits, end_logits)) loss = self.compute_loss(labels, (start_logits, end_logits))
if not return_dict: if not return_dict:
......
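To see what this refactor buys in practice, here is an illustrative sketch (not part of the diff) of the three call styles the refactored XLNet heads now route through input_processing. The checkpoint name and the example sentence are placeholders; TensorFlow and the TF XLNet weights are assumed to be available.

# Illustrative sketch: keyword arguments, a dict/BatchEncoding, or a list of
# tensors are all normalized by input_processing into the same named inputs.
import tensorflow as tf
from transformers import TFXLNetForSequenceClassification, XLNetTokenizer

tokenizer = XLNetTokenizer.from_pretrained("xlnet-base-cased")
model = TFXLNetForSequenceClassification.from_pretrained("xlnet-base-cased")
encoding = tokenizer("A short example sentence.", return_tensors="tf")

# 1) Keyword arguments -- the canonical style after this refactor
outputs = model(input_ids=encoding["input_ids"], attention_mask=encoding["attention_mask"])

# 2) A dict / BatchEncoding passed as the first positional argument
outputs = model(dict(encoding))

# 3) A list of tensors, in the order of the call signature
outputs = model([encoding["input_ids"], encoding["attention_mask"]])

print(outputs[0].shape)  # (batch_size, num_labels)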
...@@ -42,10 +42,10 @@ from ...modeling_tf_utils import ( ...@@ -42,10 +42,10 @@ from ...modeling_tf_utils import (
TFTokenClassificationLoss, TFTokenClassificationLoss,
TFSequenceSummary, TFSequenceSummary,
get_initializer, get_initializer,
input_processing,
keras_serializable, keras_serializable,
shape_list, shape_list,
) )
from ...tokenization_utils import BatchEncoding
from ...utils import logging from ...utils import logging
from .configuration_{{cookiecutter.lowercase_modelname}} import {{cookiecutter.camelcase_modelname}}Config from .configuration_{{cookiecutter.lowercase_modelname}} import {{cookiecutter.camelcase_modelname}}Config
...@@ -499,7 +499,7 @@ class TF{{cookiecutter.camelcase_modelname}}MainLayer(tf.keras.layers.Layer): ...@@ -499,7 +499,7 @@ class TF{{cookiecutter.camelcase_modelname}}MainLayer(tf.keras.layers.Layer):
def call( def call(
self, self,
inputs, input_ids=None,
attention_mask=None, attention_mask=None,
token_type_ids=None, token_type_ids=None,
position_ids=None, position_ids=None,
...@@ -509,59 +509,59 @@ class TF{{cookiecutter.camelcase_modelname}}MainLayer(tf.keras.layers.Layer): ...@@ -509,59 +509,59 @@ class TF{{cookiecutter.camelcase_modelname}}MainLayer(tf.keras.layers.Layer):
output_hidden_states=None, output_hidden_states=None,
return_dict=None, return_dict=None,
training=False, training=False,
**kwargs,
): ):
if isinstance(inputs, (tuple, list)): inputs = input_processing(
input_ids = inputs[0] func=self.call,
attention_mask = inputs[1] if len(inputs) > 1 else attention_mask input_ids=input_ids,
token_type_ids = inputs[2] if len(inputs) > 2 else token_type_ids attention_mask=attention_mask,
position_ids = inputs[3] if len(inputs) > 3 else position_ids token_type_ids=token_type_ids,
head_mask = inputs[4] if len(inputs) > 4 else head_mask position_ids=position_ids,
inputs_embeds = inputs[5] if len(inputs) > 5 else inputs_embeds head_mask=head_mask,
output_attentions = inputs[6] if len(inputs) > 6 else output_attentions inputs_embeds=inputs_embeds,
output_hidden_states = inputs[7] if len(inputs) > 7 else output_hidden_states output_attentions=output_attentions,
return_dict = inputs[8] if len(inputs) > 8 else return_dict output_hidden_states=output_hidden_states,
assert len(inputs) <= 9, "Too many inputs." return_dict=return_dict,
elif isinstance(inputs, (dict, BatchEncoding)): training=training,
input_ids = inputs.get("input_ids") kwargs_call=kwargs,
attention_mask = inputs.get("attention_mask", attention_mask) )
token_type_ids = inputs.get("token_type_ids", token_type_ids) output_attentions = (
position_ids = inputs.get("position_ids", position_ids) inputs["output_attentions"] if inputs["output_attentions"] is not None else self.output_attentions
head_mask = inputs.get("head_mask", head_mask) )
inputs_embeds = inputs.get("inputs_embeds", inputs_embeds) output_hidden_states = (
output_attentions = inputs.get("output_attentions", output_attentions) inputs["output_hidden_states"] if inputs["output_hidden_states"] is not None else self.output_hidden_states
output_hidden_states = inputs.get("output_hidden_states", output_hidden_states) )
return_dict = inputs.get("return_dict", return_dict) return_dict = inputs["return_dict"] if inputs["return_dict"] is not None else self.return_dict
assert len(inputs) <= 9, "Too many inputs."
else:
input_ids = inputs
output_attentions = output_attentions if output_attentions is not None else self.output_attentions
output_hidden_states = output_hidden_states if output_hidden_states is not None else self.output_hidden_states
return_dict = return_dict if return_dict is not None else self.return_dict
if input_ids is not None and inputs_embeds is not None: if inputs["input_ids"] is not None and inputs["inputs_embeds"] is not None:
raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time") raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
elif input_ids is not None: elif inputs["input_ids"] is not None:
input_shape = shape_list(input_ids) input_shape = shape_list(inputs["input_ids"])
elif inputs_embeds is not None: elif inputs["inputs_embeds"] is not None:
input_shape = shape_list(inputs_embeds)[:-1] input_shape = shape_list(inputs["inputs_embeds"])[:-1]
else: else:
raise ValueError("You have to specify either input_ids or inputs_embeds") raise ValueError("You have to specify either input_ids or inputs_embeds")
if attention_mask is None: if inputs["attention_mask"] is None:
attention_mask = tf.fill(input_shape, 1) inputs["attention_mask"] = tf.fill(input_shape, 1)
if token_type_ids is None: if inputs["token_type_ids"] is None:
token_type_ids = tf.fill(input_shape, 0) inputs["token_type_ids"] = tf.fill(input_shape, 0)
embedding_output = self.embeddings(input_ids, position_ids, token_type_ids, inputs_embeds, training=training) embedding_output = self.embeddings(
inputs["input_ids"],
inputs["position_ids"],
inputs["token_type_ids"],
inputs["inputs_embeds"],
training=inputs["training"],
)
# We create a 3D attention mask from a 2D tensor mask. # We create a 3D attention mask from a 2D tensor mask.
# Sizes are [batch_size, 1, 1, to_seq_length] # Sizes are [batch_size, 1, 1, to_seq_length]
# So we can broadcast to [batch_size, num_heads, from_seq_length, to_seq_length] # So we can broadcast to [batch_size, num_heads, from_seq_length, to_seq_length]
# this attention mask is more simple than the triangular masking of causal attention # this attention mask is more simple than the triangular masking of causal attention
# used in OpenAI GPT, we just need to prepare the broadcast dimension here. # used in OpenAI GPT, we just need to prepare the broadcast dimension here.
extended_attention_mask = attention_mask[:, tf.newaxis, tf.newaxis, :] extended_attention_mask = inputs["attention_mask"][:, tf.newaxis, tf.newaxis, :]
# Since attention_mask is 1.0 for positions we want to attend and 0.0 for # Since attention_mask is 1.0 for positions we want to attend and 0.0 for
# masked positions, this operation will create a tensor which is 0.0 for # masked positions, this operation will create a tensor which is 0.0 for
...@@ -576,20 +576,19 @@ class TF{{cookiecutter.camelcase_modelname}}MainLayer(tf.keras.layers.Layer): ...@@ -576,20 +576,19 @@ class TF{{cookiecutter.camelcase_modelname}}MainLayer(tf.keras.layers.Layer):
# attention_probs has shape bsz x n_heads x N x N # attention_probs has shape bsz x n_heads x N x N
# input head_mask has shape [num_heads] or [num_hidden_layers x num_heads] # input head_mask has shape [num_heads] or [num_hidden_layers x num_heads]
# and head_mask is converted to shape [num_hidden_layers x batch x num_heads x seq_length x seq_length] # and head_mask is converted to shape [num_hidden_layers x batch x num_heads x seq_length x seq_length]
if head_mask is not None: if inputs["head_mask"] is not None:
raise NotImplementedError raise NotImplementedError
else: else:
head_mask = [None] * self.num_hidden_layers inputs["head_mask"] = [None] * self.num_hidden_layers
# head_mask = tf.constant([0] * self.num_hidden_layers)
encoder_outputs = self.encoder( encoder_outputs = self.encoder(
embedding_output, embedding_output,
extended_attention_mask, extended_attention_mask,
head_mask, inputs["head_mask"],
output_attentions, output_attentions,
output_hidden_states, output_hidden_states,
return_dict, return_dict,
training=training, training=inputs["training"],
) )
sequence_output = encoder_outputs[0] sequence_output = encoder_outputs[0]
...@@ -725,8 +724,46 @@ class TF{{cookiecutter.camelcase_modelname}}Model(TF{{cookiecutter.camelcase_mod ...@@ -725,8 +724,46 @@ class TF{{cookiecutter.camelcase_modelname}}Model(TF{{cookiecutter.camelcase_mod
output_type=TFBaseModelOutputWithPooling, output_type=TFBaseModelOutputWithPooling,
config_class=_CONFIG_FOR_DOC, config_class=_CONFIG_FOR_DOC,
) )
def call(self, inputs, **kwargs): def call(
outputs = self.{{cookiecutter.lowercase_modelname}}(inputs, **kwargs) self,
input_ids=None,
attention_mask=None,
token_type_ids=None,
position_ids=None,
head_mask=None,
inputs_embeds=None,
output_attentions=None,
output_hidden_states=None,
return_dict=None,
training=False,
**kwargs,
):
inputs = input_processing(
func=self.call,
input_ids=input_ids,
attention_mask=attention_mask,
token_type_ids=token_type_ids,
position_ids=position_ids,
head_mask=head_mask,
inputs_embeds=inputs_embeds,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
training=training,
kwargs_call=kwargs,
)
outputs = self.{{cookiecutter.lowercase_modelname}}(
input_ids=inputs["input_ids"],
attention_mask=inputs["attention_mask"],
token_type_ids=inputs["token_type_ids"],
position_ids=inputs["position_ids"],
head_mask=inputs["head_mask"],
inputs_embeds=inputs["inputs_embeds"],
output_attentions=inputs["output_attentions"],
output_hidden_states=inputs["output_hidden_states"],
return_dict=inputs["return_dict"],
training=inputs["training"],
)
return outputs return outputs
...@@ -758,7 +795,7 @@ class TF{{cookiecutter.camelcase_modelname}}ForMaskedLM(TF{{cookiecutter.camelca ...@@ -758,7 +795,7 @@ class TF{{cookiecutter.camelcase_modelname}}ForMaskedLM(TF{{cookiecutter.camelca
) )
def call( def call(
self, self,
inputs=None, input_ids=None,
attention_mask=None, attention_mask=None,
token_type_ids=None, token_type_ids=None,
position_ids=None, position_ids=None,
...@@ -769,6 +806,7 @@ class TF{{cookiecutter.camelcase_modelname}}ForMaskedLM(TF{{cookiecutter.camelca ...@@ -769,6 +806,7 @@ class TF{{cookiecutter.camelcase_modelname}}ForMaskedLM(TF{{cookiecutter.camelca
return_dict=None, return_dict=None,
labels=None, labels=None,
training=False, training=False,
**kwargs,
): ):
r""" r"""
labels (:obj:`tf.Tensor` of shape :obj:`(batch_size, sequence_length)`, `optional`): labels (:obj:`tf.Tensor` of shape :obj:`(batch_size, sequence_length)`, `optional`):
...@@ -777,17 +815,9 @@ class TF{{cookiecutter.camelcase_modelname}}ForMaskedLM(TF{{cookiecutter.camelca ...@@ -777,17 +815,9 @@ class TF{{cookiecutter.camelcase_modelname}}ForMaskedLM(TF{{cookiecutter.camelca
Tokens with indices set to ``-100`` are ignored (masked), the loss is only computed for the tokens with labels Tokens with indices set to ``-100`` are ignored (masked), the loss is only computed for the tokens with labels
in ``[0, ..., config.vocab_size]`` in ``[0, ..., config.vocab_size]``
""" """
return_dict = return_dict if return_dict is not None else self.{{cookiecutter.lowercase_modelname}}.return_dict inputs = input_processing(
func=self.call,
if isinstance(inputs, (tuple, list)): input_ids=input_ids,
labels = inputs[9] if len(inputs) > 9 else labels
if len(inputs) > 9:
inputs = inputs[:9]
elif isinstance(inputs, (dict, BatchEncoding)):
labels = inputs.pop("labels", labels)
outputs = self.{{cookiecutter.lowercase_modelname}}(
inputs,
attention_mask=attention_mask, attention_mask=attention_mask,
token_type_ids=token_type_ids, token_type_ids=token_type_ids,
position_ids=position_ids, position_ids=position_ids,
...@@ -796,12 +826,27 @@ class TF{{cookiecutter.camelcase_modelname}}ForMaskedLM(TF{{cookiecutter.camelca ...@@ -796,12 +826,27 @@ class TF{{cookiecutter.camelcase_modelname}}ForMaskedLM(TF{{cookiecutter.camelca
output_attentions=output_attentions, output_attentions=output_attentions,
output_hidden_states=output_hidden_states, output_hidden_states=output_hidden_states,
return_dict=return_dict, return_dict=return_dict,
labels=labels,
training=training, training=training,
kwargs_call=kwargs,
)
return_dict = inputs["return_dict"] if inputs["return_dict"] is not None else self.{{cookiecutter.lowercase_modelname}}.return_dict
outputs = self.{{cookiecutter.lowercase_modelname}}(
inputs["input_ids"],
attention_mask=inputs["attention_mask"],
token_type_ids=inputs["token_type_ids"],
position_ids=inputs["position_ids"],
head_mask=inputs["head_mask"],
inputs_embeds=inputs["inputs_embeds"],
output_attentions=inputs["output_attentions"],
output_hidden_states=inputs["output_hidden_states"],
return_dict=return_dict,
training=inputs["training"],
) )
sequence_output = outputs[0] sequence_output = outputs[0]
prediction_scores = self.mlm(sequence_output, training=training) prediction_scores = self.mlm(sequence_output, training=inputs["training"])
loss = None if labels is None else self.compute_loss(labels, prediction_scores) loss = None if inputs["labels"] is None else self.compute_loss(inputs["labels"], prediction_scores)
if not return_dict: if not return_dict:
output = (prediction_scores,) + outputs[1:] output = (prediction_scores,) + outputs[1:]
...@@ -862,18 +907,19 @@ class TF{{cookiecutter.camelcase_modelname}}ForSequenceClassification(TF{{cookie ...@@ -862,18 +907,19 @@ class TF{{cookiecutter.camelcase_modelname}}ForSequenceClassification(TF{{cookie
config_class=_CONFIG_FOR_DOC, config_class=_CONFIG_FOR_DOC,
) )
def call( def call(
self, self,
inputs, input_ids=None,
attention_mask=None, attention_mask=None,
token_type_ids=None, token_type_ids=None,
position_ids=None, position_ids=None,
head_mask=None, head_mask=None,
inputs_embeds=None, inputs_embeds=None,
output_attentions=None, output_attentions=None,
output_hidden_states=None, output_hidden_states=None,
return_dict=None, return_dict=None,
labels=None, labels=None,
training=False, training=False,
**kwargs,
): ):
r""" r"""
labels (:obj:`tf.Tensor` of shape :obj:`(batch_size,)`, `optional`): labels (:obj:`tf.Tensor` of shape :obj:`(batch_size,)`, `optional`):
...@@ -882,18 +928,9 @@ class TF{{cookiecutter.camelcase_modelname}}ForSequenceClassification(TF{{cookie ...@@ -882,18 +928,9 @@ class TF{{cookiecutter.camelcase_modelname}}ForSequenceClassification(TF{{cookie
If :obj:`config.num_labels == 1` a regression loss is computed (Mean-Square loss), If :obj:`config.num_labels == 1` a regression loss is computed (Mean-Square loss),
If :obj:`config.num_labels > 1` a classification loss is computed (Cross-Entropy). If :obj:`config.num_labels > 1` a classification loss is computed (Cross-Entropy).
""" """
return_dict = return_dict if return_dict is not None else self.{{cookiecutter.lowercase_modelname}}.config.return_dict inputs = input_processing(
func=self.call,
if isinstance(inputs, (tuple, list)): input_ids=input_ids,
labels = inputs[9] if len(inputs) > 9 else labels
if len(inputs) > 9:
inputs = inputs[:9]
elif isinstance(inputs, (dict, BatchEncoding)):
labels = inputs.pop("labels", labels)
outputs = self.{{cookiecutter.lowercase_modelname}}(
inputs,
attention_mask=attention_mask, attention_mask=attention_mask,
token_type_ids=token_type_ids, token_type_ids=token_type_ids,
position_ids=position_ids, position_ids=position_ids,
...@@ -902,10 +939,25 @@ class TF{{cookiecutter.camelcase_modelname}}ForSequenceClassification(TF{{cookie ...@@ -902,10 +939,25 @@ class TF{{cookiecutter.camelcase_modelname}}ForSequenceClassification(TF{{cookie
output_attentions=output_attentions, output_attentions=output_attentions,
output_hidden_states=output_hidden_states, output_hidden_states=output_hidden_states,
return_dict=return_dict, return_dict=return_dict,
labels=labels,
training=training, training=training,
kwargs_call=kwargs,
)
return_dict = inputs["return_dict"] if inputs["return_dict"] is not None else self.{{cookiecutter.lowercase_modelname}}.return_dict
outputs = self.{{cookiecutter.lowercase_modelname}}(
inputs["input_ids"],
attention_mask=inputs["attention_mask"],
token_type_ids=inputs["token_type_ids"],
position_ids=inputs["position_ids"],
head_mask=inputs["head_mask"],
inputs_embeds=inputs["inputs_embeds"],
output_attentions=inputs["output_attentions"],
output_hidden_states=inputs["output_hidden_states"],
return_dict=return_dict,
training=inputs["training"],
) )
logits = self.classifier(outputs[0]) logits = self.classifier(outputs[0], training=inputs["training"])
loss = None if labels is None else self.compute_loss(labels, logits) loss = None if inputs["labels"] is None else self.compute_loss(inputs["labels"], logits)
if not return_dict: if not return_dict:
output = (logits,) + outputs[1:] output = (logits,) + outputs[1:]
...@@ -956,7 +1008,7 @@ class TF{{cookiecutter.camelcase_modelname}}ForMultipleChoice(TF{{cookiecutter.c ...@@ -956,7 +1008,7 @@ class TF{{cookiecutter.camelcase_modelname}}ForMultipleChoice(TF{{cookiecutter.c
) )
def call( def call(
self, self,
inputs, input_ids=None,
attention_mask=None, attention_mask=None,
token_type_ids=None, token_type_ids=None,
position_ids=None, position_ids=None,
...@@ -967,6 +1019,7 @@ class TF{{cookiecutter.camelcase_modelname}}ForMultipleChoice(TF{{cookiecutter.c ...@@ -967,6 +1019,7 @@ class TF{{cookiecutter.camelcase_modelname}}ForMultipleChoice(TF{{cookiecutter.c
return_dict=None, return_dict=None,
labels=None, labels=None,
training=False, training=False,
**kwargs,
): ):
r""" r"""
labels (:obj:`tf.Tensor` of shape :obj:`(batch_size,)`, `optional`): labels (:obj:`tf.Tensor` of shape :obj:`(batch_size,)`, `optional`):
...@@ -974,49 +1027,43 @@ class TF{{cookiecutter.camelcase_modelname}}ForMultipleChoice(TF{{cookiecutter.c ...@@ -974,49 +1027,43 @@ class TF{{cookiecutter.camelcase_modelname}}ForMultipleChoice(TF{{cookiecutter.c
Indices should be in ``[0, ..., num_choices]`` where :obj:`num_choices` is the size of the second dimension Indices should be in ``[0, ..., num_choices]`` where :obj:`num_choices` is the size of the second dimension
of the input tensors. (See :obj:`input_ids` above) of the input tensors. (See :obj:`input_ids` above)
""" """
if isinstance(inputs, (tuple, list)): inputs = input_processing(
input_ids = inputs[0] func=self.call,
attention_mask = inputs[1] if len(inputs) > 1 else attention_mask input_ids=input_ids,
token_type_ids = inputs[2] if len(inputs) > 2 else token_type_ids attention_mask=attention_mask,
position_ids = inputs[3] if len(inputs) > 3 else position_ids token_type_ids=token_type_ids,
head_mask = inputs[4] if len(inputs) > 4 else head_mask position_ids=position_ids,
inputs_embeds = inputs[5] if len(inputs) > 5 else inputs_embeds head_mask=head_mask,
output_attentions = inputs[6] if len(inputs) > 6 else output_attentions inputs_embeds=inputs_embeds,
output_hidden_states = inputs[7] if len(inputs) > 7 else output_hidden_states output_attentions=output_attentions,
return_dict = inputs[8] if len(inputs) > 8 else return_dict output_hidden_states=output_hidden_states,
labels = inputs[9] if len(inputs) > 9 else labels return_dict=return_dict,
assert len(inputs) <= 10, "Too many inputs." labels=labels,
elif isinstance(inputs, (dict, BatchEncoding)): training=training,
input_ids = inputs.get("input_ids") kwargs_call=kwargs,
attention_mask = inputs.get("attention_mask", attention_mask) )
token_type_ids = inputs.get("token_type_ids", token_type_ids) return_dict = inputs["return_dict"] if inputs["return_dict"] is not None else self.{{cookiecutter.lowercase_modelname}}.config.return_dict
position_ids = inputs.get("position_ids", position_ids)
head_mask = inputs.get("head_mask", head_mask)
inputs_embeds = inputs.get("inputs_embeds", inputs_embeds)
output_attentions = inputs.get("output_attentions", output_attentions)
output_hidden_states = inputs.get("output_hidden_states", output_hidden_states)
return_dict = inputs.get("return_dict", return_dict)
labels = inputs.get("labels", labels)
assert len(inputs) <= 10, "Too many inputs."
else:
input_ids = inputs
return_dict = return_dict if return_dict is not None else self.{{cookiecutter.lowercase_modelname}}.config.return_dict
if input_ids is not None: if inputs["input_ids"] is not None:
num_choices = shape_list(input_ids)[1] num_choices = shape_list(inputs["input_ids"])[1]
seq_length = shape_list(input_ids)[2] seq_length = shape_list(inputs["input_ids"])[2]
else: else:
num_choices = shape_list(inputs_embeds)[1] num_choices = shape_list(inputs["inputs_embeds"])[1]
seq_length = shape_list(inputs_embeds)[2] seq_length = shape_list(inputs["inputs_embeds"])[2]
flat_input_ids = tf.reshape(input_ids, (-1, seq_length)) if input_ids is not None else None flat_input_ids = tf.reshape(inputs["input_ids"], (-1, seq_length)) if inputs["input_ids"] is not None else None
flat_attention_mask = tf.reshape(attention_mask, (-1, seq_length)) if attention_mask is not None else None flat_attention_mask = (
flat_token_type_ids = tf.reshape(token_type_ids, (-1, seq_length)) if token_type_ids is not None else None tf.reshape(inputs["attention_mask"], (-1, seq_length)) if inputs["attention_mask"] is not None else None
flat_position_ids = tf.reshape(position_ids, (-1, seq_length)) if position_ids is not None else None )
flat_token_type_ids = (
tf.reshape(inputs["token_type_ids"], (-1, seq_length)) if inputs["token_type_ids"] is not None else None
)
flat_position_ids = (
tf.reshape(inputs["position_ids"], (-1, seq_length)) if inputs["position_ids"] is not None else None
)
flat_inputs_embeds = ( flat_inputs_embeds = (
tf.reshape(inputs_embeds, (-1, seq_length, shape_list(inputs_embeds)[3])) tf.reshape(inputs["inputs_embeds"], (-1, seq_length, shape_list(inputs["inputs_embeds"])[3]))
if inputs_embeds is not None if inputs["inputs_embeds"] is not None
else None else None
) )
outputs = self.{{cookiecutter.lowercase_modelname}}( outputs = self.{{cookiecutter.lowercase_modelname}}(
...@@ -1024,17 +1071,17 @@ class TF{{cookiecutter.camelcase_modelname}}ForMultipleChoice(TF{{cookiecutter.c ...@@ -1024,17 +1071,17 @@ class TF{{cookiecutter.camelcase_modelname}}ForMultipleChoice(TF{{cookiecutter.c
flat_attention_mask, flat_attention_mask,
flat_token_type_ids, flat_token_type_ids,
flat_position_ids, flat_position_ids,
head_mask, inputs["head_mask"],
flat_inputs_embeds, flat_inputs_embeds,
output_attentions, inputs["output_attentions"],
output_hidden_states, inputs["output_hidden_states"],
return_dict=return_dict, return_dict=return_dict,
training=training, training=inputs["training"],
) )
logits = self.sequence_summary(outputs[0]) logits = self.sequence_summary(outputs[0], training=inputs["training"])
logits = self.classifier(logits) logits = self.classifier(logits)
reshaped_logits = tf.reshape(logits, (-1, num_choices)) reshaped_logits = tf.reshape(logits, (-1, num_choices))
loss = None if labels is None else self.compute_loss(labels, reshaped_logits) loss = None if inputs["labels"] is None else self.compute_loss(inputs["labels"], reshaped_logits)
if not return_dict: if not return_dict:
output = (reshaped_logits,) + outputs[1:] output = (reshaped_logits,) + outputs[1:]
...@@ -1074,7 +1121,7 @@ class TF{{cookiecutter.camelcase_modelname}}ForTokenClassification(TF{{cookiecut ...@@ -1074,7 +1121,7 @@ class TF{{cookiecutter.camelcase_modelname}}ForTokenClassification(TF{{cookiecut
) )
def call( def call(
self, self,
inputs=None, input_ids=None,
attention_mask=None, attention_mask=None,
token_type_ids=None, token_type_ids=None,
position_ids=None, position_ids=None,
...@@ -1085,23 +1132,16 @@ class TF{{cookiecutter.camelcase_modelname}}ForTokenClassification(TF{{cookiecut ...@@ -1085,23 +1132,16 @@ class TF{{cookiecutter.camelcase_modelname}}ForTokenClassification(TF{{cookiecut
return_dict=None, return_dict=None,
labels=None, labels=None,
training=False, training=False,
**kwargs,
): ):
r""" r"""
labels (:obj:`tf.Tensor` of shape :obj:`(batch_size, sequence_length)`, `optional`): labels (:obj:`tf.Tensor` of shape :obj:`(batch_size, sequence_length)`, `optional`):
Labels for computing the token classification loss. Labels for computing the token classification loss.
Indices should be in ``[0, ..., config.num_labels - 1]``. Indices should be in ``[0, ..., config.num_labels - 1]``.
""" """
return_dict = return_dict if return_dict is not None else self.{{cookiecutter.lowercase_modelname}}.return_dict inputs = input_processing(
func=self.call,
if isinstance(inputs, (tuple, list)): input_ids=input_ids,
labels = inputs[9] if len(inputs) > 9 else labels
if len(inputs) > 9:
inputs = inputs[:9]
elif isinstance(inputs, (dict, BatchEncoding)):
labels = inputs.pop("labels", labels)
outputs = self.{{cookiecutter.lowercase_modelname}}(
inputs,
attention_mask=attention_mask, attention_mask=attention_mask,
token_type_ids=token_type_ids, token_type_ids=token_type_ids,
position_ids=position_ids, position_ids=position_ids,
...@@ -1110,12 +1150,27 @@ class TF{{cookiecutter.camelcase_modelname}}ForTokenClassification(TF{{cookiecut ...@@ -1110,12 +1150,27 @@ class TF{{cookiecutter.camelcase_modelname}}ForTokenClassification(TF{{cookiecut
output_attentions=output_attentions, output_attentions=output_attentions,
output_hidden_states=output_hidden_states, output_hidden_states=output_hidden_states,
return_dict=return_dict, return_dict=return_dict,
labels=labels,
training=training, training=training,
kwargs_call=kwargs,
)
return_dict = inputs["return_dict"] if inputs["return_dict"] is not None else self.{{cookiecutter.uppercase_modelname}}.return_dict
outputs = self.{{cookiecutter.uppercase_modelname}}(
inputs["input_ids"],
attention_mask=inputs["attention_mask"],
token_type_ids=inputs["token_type_ids"],
position_ids=inputs["position_ids"],
head_mask=inputs["head_mask"],
inputs_embeds=inputs["inputs_embeds"],
output_attentions=inputs["output_attentions"],
output_hidden_states=inputs["output_hidden_states"],
return_dict=return_dict,
training=inputs["training"],
) )
sequence_output = outputs[0] sequence_output = outputs[0]
sequence_output = self.dropout(sequence_output, training=training) sequence_output = self.dropout(sequence_output, training=inputs["training"])
logits = self.classifier(sequence_output) logits = self.classifier(sequence_output)
loss = None if labels is None else self.compute_loss(labels, logits) loss = None if inputs["labels"] is None else self.compute_loss(inputs["labels"], logits)
if not return_dict: if not return_dict:
output = (logits,) + outputs[1:] output = (logits,) + outputs[1:]
...@@ -1154,7 +1209,7 @@ class TF{{cookiecutter.camelcase_modelname}}ForQuestionAnswering(TF{{cookiecutte ...@@ -1154,7 +1209,7 @@ class TF{{cookiecutter.camelcase_modelname}}ForQuestionAnswering(TF{{cookiecutte
) )
def call( def call(
self, self,
inputs=None, input_ids=None,
attention_mask=None, attention_mask=None,
token_type_ids=None, token_type_ids=None,
position_ids=None, position_ids=None,
...@@ -1166,6 +1221,7 @@ class TF{{cookiecutter.camelcase_modelname}}ForQuestionAnswering(TF{{cookiecutte ...@@ -1166,6 +1221,7 @@ class TF{{cookiecutter.camelcase_modelname}}ForQuestionAnswering(TF{{cookiecutte
start_positions=None, start_positions=None,
end_positions=None, end_positions=None,
training=False, training=False,
**kwargs,
): ):
r""" r"""
start_positions (:obj:`tf.Tensor` of shape :obj:`(batch_size,)`, `optional`): start_positions (:obj:`tf.Tensor` of shape :obj:`(batch_size,)`, `optional`):
...@@ -1177,19 +1233,9 @@ class TF{{cookiecutter.camelcase_modelname}}ForQuestionAnswering(TF{{cookiecutte ...@@ -1177,19 +1233,9 @@ class TF{{cookiecutter.camelcase_modelname}}ForQuestionAnswering(TF{{cookiecutte
Positions are clamped to the length of the sequence (:obj:`sequence_length`). Positions are clamped to the length of the sequence (:obj:`sequence_length`).
Position outside of the sequence are not taken into account for computing the loss. Position outside of the sequence are not taken into account for computing the loss.
""" """
return_dict = return_dict if return_dict is not None else self.{{cookiecutter.lowercase_modelname}}.return_dict inputs = input_processing(
func=self.call,
if isinstance(inputs, (tuple, list)): input_ids=input_ids,
start_positions = inputs[9] if len(inputs) > 9 else start_positions
end_positions = inputs[10] if len(inputs) > 10 else end_positions
if len(inputs) > 9:
inputs = inputs[:9]
elif isinstance(inputs, (dict, BatchEncoding)):
start_positions = inputs.pop("start_positions", start_positions)
end_positions = inputs.pop("end_positions", start_positions)
outputs = self.{{cookiecutter.lowercase_modelname}}(
inputs,
attention_mask=attention_mask, attention_mask=attention_mask,
token_type_ids=token_type_ids, token_type_ids=token_type_ids,
position_ids=position_ids, position_ids=position_ids,
...@@ -1198,7 +1244,23 @@ class TF{{cookiecutter.camelcase_modelname}}ForQuestionAnswering(TF{{cookiecutte ...@@ -1198,7 +1244,23 @@ class TF{{cookiecutter.camelcase_modelname}}ForQuestionAnswering(TF{{cookiecutte
output_attentions=output_attentions, output_attentions=output_attentions,
output_hidden_states=output_hidden_states, output_hidden_states=output_hidden_states,
return_dict=return_dict, return_dict=return_dict,
start_positions=start_positions,
end_positions=end_positions,
training=training, training=training,
kwargs_call=kwargs,
)
return_dict = inputs["return_dict"] if inputs["return_dict"] is not None else self.{{cookiecutter.uppercase_modelname}}.return_dict
outputs = self.{{cookiecutter.uppercase_modelname}}(
inputs["input_ids"],
attention_mask=inputs["attention_mask"],
token_type_ids=inputs["token_type_ids"],
position_ids=inputs["position_ids"],
head_mask=inputs["head_mask"],
inputs_embeds=inputs["inputs_embeds"],
output_attentions=inputs["output_attentions"],
output_hidden_states=inputs["output_hidden_states"],
return_dict=return_dict,
training=inputs["training"],
) )
sequence_output = outputs[0] sequence_output = outputs[0]
logits = self.qa_outputs(sequence_output) logits = self.qa_outputs(sequence_output)
...@@ -1207,9 +1269,9 @@ class TF{{cookiecutter.camelcase_modelname}}ForQuestionAnswering(TF{{cookiecutte ...@@ -1207,9 +1269,9 @@ class TF{{cookiecutter.camelcase_modelname}}ForQuestionAnswering(TF{{cookiecutte
end_logits = tf.squeeze(end_logits, axis=-1) end_logits = tf.squeeze(end_logits, axis=-1)
loss = None loss = None
if start_positions is not None and end_positions is not None: if inputs["start_positions"] is not None and inputs["end_positions"] is not None:
labels = {"start_position": start_positions} labels = {"start_position": inputs["start_positions"]}
labels["end_position"] = end_positions labels["end_position"] = inputs["end_positions"]
loss = self.compute_loss(labels, (start_logits, end_logits)) loss = self.compute_loss(labels, (start_logits, end_logits))
if not return_dict: if not return_dict:
......
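The template changes follow the same pattern as the hand-written models: the old tuple/dict parsing is replaced by a single call to input_processing, which maps whatever the caller passed onto the parameters of call and returns one dict. A simplified sketch of that idea (an illustration only, not the library's actual implementation):

# Simplified sketch of an input_processing-style helper: normalize keyword
# arguments, a positional tuple/list, or a dict into one dict keyed by the
# parameter names of `func`.
import inspect

def normalize_call_inputs(func, main_input=None, **kwargs):
    signature = inspect.signature(func)
    parameter_names = [
        name
        for name, param in signature.parameters.items()
        if name != "self" and param.kind != inspect.Parameter.VAR_KEYWORD
    ]

    outputs = dict.fromkeys(parameter_names)
    outputs.update({key: value for key, value in kwargs.items() if key in parameter_names})

    if isinstance(main_input, (tuple, list)):
        # Positional style: map the values onto the parameter names in order.
        for name, value in zip(parameter_names, main_input):
            outputs[name] = value
    elif isinstance(main_input, dict):
        # Dict / BatchEncoding style: copy the known keys.
        for name, value in main_input.items():
            if name in parameter_names:
                outputs[name] = value
    elif main_input is not None:
        # A plain tensor is treated as the first parameter (input_ids).
        outputs[parameter_names[0]] = main_input

    return outputs

def fake_call(self, input_ids=None, attention_mask=None, token_type_ids=None):
    pass

print(normalize_call_inputs(fake_call, [1, 2]))
# {'input_ids': 1, 'attention_mask': 2, 'token_type_ids': None}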
...@@ -14,7 +14,6 @@ ...@@ -14,7 +14,6 @@
# limitations under the License. # limitations under the License.
import tempfile
import unittest import unittest
import numpy as np import numpy as np
...@@ -102,15 +101,14 @@ def prepare_bart_inputs_dict( ...@@ -102,15 +101,14 @@ def prepare_bart_inputs_dict(
@require_tf @require_tf
class TestTFBart(TFModelTesterMixin, unittest.TestCase): class TFBartModelTest(TFModelTesterMixin, unittest.TestCase):
all_model_classes = (TFBartForConditionalGeneration, TFBartModel) if is_tf_available() else () all_model_classes = (TFBartForConditionalGeneration, TFBartModel) if is_tf_available() else ()
all_generative_model_classes = (TFBartForConditionalGeneration,) if is_tf_available() else () all_generative_model_classes = (TFBartForConditionalGeneration,) if is_tf_available() else ()
is_encoder_decoder = True is_encoder_decoder = True
test_pruning = False test_pruning = False
model_tester_cls = TFBartModelTester
def setUp(self): def setUp(self):
self.model_tester = self.model_tester_cls(self) self.model_tester = TFBartModelTester(self)
self.config_tester = ConfigTester(self, config_class=BartConfig) self.config_tester = ConfigTester(self, config_class=BartConfig)
def test_config(self): def test_config(self):
...@@ -120,37 +118,6 @@ class TestTFBart(TFModelTesterMixin, unittest.TestCase): ...@@ -120,37 +118,6 @@ class TestTFBart(TFModelTesterMixin, unittest.TestCase):
# inputs_embeds not supported # inputs_embeds not supported
pass pass
def test_compile_tf_model(self):
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
optimizer = tf.keras.optimizers.Adam(learning_rate=3e-5, epsilon=1e-08, clipnorm=1.0)
loss = tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True)
metric = tf.keras.metrics.SparseCategoricalAccuracy("accuracy")
model_class = self.all_generative_model_classes[0]
input_ids = {
"decoder_input_ids": tf.keras.Input(batch_shape=(2, 2000), name="decoder_input_ids", dtype="int32"),
"input_ids": tf.keras.Input(batch_shape=(2, 2000), name="input_ids", dtype="int32"),
}
# Prepare our model
model = model_class(config)
model(self._prepare_for_class(inputs_dict, model_class)) # Model must be called before saving.
# Let's load it from the disk to be sure we can use pretrained weights
with tempfile.TemporaryDirectory() as tmpdirname:
model.save_pretrained(tmpdirname)
model = model_class.from_pretrained(tmpdirname)
outputs_dict = model(input_ids)
hidden_states = outputs_dict[0]
# Add a dense layer on top to test integration with other keras modules
outputs = tf.keras.layers.Dense(2, activation="softmax", name="outputs")(hidden_states)
# Compile extended model
extended_model = tf.keras.Model(inputs=[input_ids], outputs=[outputs])
extended_model.compile(optimizer=optimizer, loss=loss, metrics=[metric])
def test_saved_model_with_hidden_states_output(self): def test_saved_model_with_hidden_states_output(self):
# Should be uncommented during patrick TF refactor # Should be uncommented during patrick TF refactor
pass pass
...@@ -190,7 +157,7 @@ class TFBartHeadTests(unittest.TestCase): ...@@ -190,7 +157,7 @@ class TFBartHeadTests(unittest.TestCase):
config, input_ids, batch_size = self._get_config_and_data() config, input_ids, batch_size = self._get_config_and_data()
decoder_lm_labels = ids_tensor([batch_size, input_ids.shape[1]], self.vocab_size) decoder_lm_labels = ids_tensor([batch_size, input_ids.shape[1]], self.vocab_size)
lm_model = TFBartForConditionalGeneration(config) lm_model = TFBartForConditionalGeneration(config)
outputs = lm_model(inputs=input_ids, lm_labels=decoder_lm_labels, decoder_input_ids=input_ids, use_cache=False) outputs = lm_model(input_ids=input_ids, labels=decoder_lm_labels, decoder_input_ids=input_ids, use_cache=False)
expected_shape = (batch_size, input_ids.shape[1], config.vocab_size) expected_shape = (batch_size, input_ids.shape[1], config.vocab_size)
self.assertEqual(outputs.logits.shape, expected_shape) self.assertEqual(outputs.logits.shape, expected_shape)
...@@ -209,7 +176,7 @@ class TFBartHeadTests(unittest.TestCase): ...@@ -209,7 +176,7 @@ class TFBartHeadTests(unittest.TestCase):
lm_model = TFBartForConditionalGeneration(config) lm_model = TFBartForConditionalGeneration(config)
context = tf.fill((7, 2), 4) context = tf.fill((7, 2), 4)
summary = tf.fill((7, 7), 6) summary = tf.fill((7, 7), 6)
outputs = lm_model(inputs=context, decoder_input_ids=summary, use_cache=False) outputs = lm_model(input_ids=context, decoder_input_ids=summary, use_cache=False)
expected_shape = (*summary.shape, config.vocab_size) expected_shape = (*summary.shape, config.vocab_size)
self.assertEqual(outputs.logits.shape, expected_shape) self.assertEqual(outputs.logits.shape, expected_shape)
......
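The BART head test above now calls the model with input_ids= and labels= instead of the old inputs= / lm_labels= names. A self-contained sketch of that call pattern (the tiny config values are invented for the example and only mirror the test's miniature setup):

# Illustrative sketch: the conditional-generation head is called with the
# renamed keyword arguments input_ids= and labels=.
import tensorflow as tf
from transformers import BartConfig, TFBartForConditionalGeneration

config = BartConfig(
    vocab_size=99,
    d_model=24,
    encoder_layers=2,
    decoder_layers=2,
    encoder_attention_heads=2,
    decoder_attention_heads=2,
    encoder_ffn_dim=32,
    decoder_ffn_dim=32,
    max_position_embeddings=48,
    bos_token_id=0,
    pad_token_id=1,
    eos_token_id=2,
)
model = TFBartForConditionalGeneration(config)

input_ids = tf.constant([[0, 4, 5, 6, 2]])
labels = tf.constant([[0, 4, 5, 6, 2]])

outputs = model(input_ids=input_ids, labels=labels, decoder_input_ids=input_ids, use_cache=False)
print(outputs.logits.shape)  # (batch_size, sequence_length, vocab_size)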
...@@ -12,24 +12,23 @@ ...@@ -12,24 +12,23 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
import tempfile
import unittest import unittest
from tests.test_configuration_common import ConfigTester from tests.test_configuration_common import ConfigTester
from tests.test_modeling_tf_bart import TFBartModelTester from tests.test_modeling_tf_bart import TFBartModelTester
from tests.test_modeling_tf_common import TFModelTesterMixin from tests.test_modeling_tf_common import TFModelTesterMixin
from transformers import BlenderbotConfig, BlenderbotSmallTokenizer, is_tf_available from transformers import (
BlenderbotConfig,
BlenderbotSmallTokenizer,
TFAutoModelForSeq2SeqLM,
TFBlenderbotForConditionalGeneration,
is_tf_available,
)
from transformers.file_utils import cached_property from transformers.file_utils import cached_property
from transformers.testing_utils import is_pt_tf_cross_test, require_tf, require_tokenizers, slow from transformers.testing_utils import is_pt_tf_cross_test, require_tf, require_tokenizers, slow
if is_tf_available(): class TFBlenderbotModelTester(TFBartModelTester):
import tensorflow as tf
from transformers import TFAutoModelForSeq2SeqLM, TFBlenderbotForConditionalGeneration
class ModelTester(TFBartModelTester):
config_updates = dict( config_updates = dict(
normalize_before=True, normalize_before=True,
static_position_embeddings=True, static_position_embeddings=True,
...@@ -40,15 +39,14 @@ class ModelTester(TFBartModelTester): ...@@ -40,15 +39,14 @@ class ModelTester(TFBartModelTester):
@require_tf @require_tf
class TestTFBlenderbotCommon(TFModelTesterMixin, unittest.TestCase): class TFBlenderbotModelTest(TFModelTesterMixin, unittest.TestCase):
all_model_classes = (TFBlenderbotForConditionalGeneration,) if is_tf_available() else () all_model_classes = (TFBlenderbotForConditionalGeneration,) if is_tf_available() else ()
all_generative_model_classes = (TFBlenderbotForConditionalGeneration,) if is_tf_available() else () all_generative_model_classes = (TFBlenderbotForConditionalGeneration,) if is_tf_available() else ()
model_tester_cls = ModelTester
is_encoder_decoder = True is_encoder_decoder = True
test_pruning = False test_pruning = False
def setUp(self): def setUp(self):
self.model_tester = self.model_tester_cls(self) self.model_tester = TFBlenderbotModelTester(self)
self.config_tester = ConfigTester(self, config_class=BlenderbotConfig) self.config_tester = ConfigTester(self, config_class=BlenderbotConfig)
def test_config(self): def test_config(self):
...@@ -66,37 +64,6 @@ class TestTFBlenderbotCommon(TFModelTesterMixin, unittest.TestCase): ...@@ -66,37 +64,6 @@ class TestTFBlenderbotCommon(TFModelTesterMixin, unittest.TestCase):
# Should be uncommented during patrick TF refactor # Should be uncommented during patrick TF refactor
pass pass
def test_compile_tf_model(self):
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
optimizer = tf.keras.optimizers.Adam(learning_rate=3e-5, epsilon=1e-08, clipnorm=1.0)
loss = tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True)
metric = tf.keras.metrics.SparseCategoricalAccuracy("accuracy")
model_class = self.all_generative_model_classes[0]
input_ids = {
"decoder_input_ids": tf.keras.Input(batch_shape=(2, 2000), name="decoder_input_ids", dtype="int32"),
"input_ids": tf.keras.Input(batch_shape=(2, 2000), name="input_ids", dtype="int32"),
}
# Prepare our model
model = model_class(config)
model(self._prepare_for_class(inputs_dict, model_class)) # Model must be called before saving.
# Let's load it from the disk to be sure we can use pretrained weights
with tempfile.TemporaryDirectory() as tmpdirname:
model.save_pretrained(tmpdirname)
model = model_class.from_pretrained(tmpdirname)
outputs_dict = model(input_ids)
hidden_states = outputs_dict[0]
# Add a dense layer on top to test integration with other keras modules
outputs = tf.keras.layers.Dense(2, activation="softmax", name="outputs")(hidden_states)
# Compile extended model
extended_model = tf.keras.Model(inputs=[input_ids], outputs=[outputs])
extended_model.compile(optimizer=optimizer, loss=loss, metrics=[metric])
@is_pt_tf_cross_test @is_pt_tf_cross_test
@require_tokenizers @require_tokenizers
......
...@@ -152,7 +152,7 @@ class TFModelTesterMixin: ...@@ -152,7 +152,7 @@ class TFModelTesterMixin:
if model.config.is_encoder_decoder: if model.config.is_encoder_decoder:
expected_arg_names = [ expected_arg_names = [
"inputs", "input_ids",
"attention_mask", "attention_mask",
"decoder_input_ids", "decoder_input_ids",
"decoder_attention_mask", "decoder_attention_mask",
...@@ -161,7 +161,7 @@ class TFModelTesterMixin: ...@@ -161,7 +161,7 @@ class TFModelTesterMixin:
self.assertListEqual(arg_names[:5], expected_arg_names) self.assertListEqual(arg_names[:5], expected_arg_names)
else: else:
expected_arg_names = ["inputs"] expected_arg_names = ["input_ids"]
self.assertListEqual(arg_names[:1], expected_arg_names) self.assertListEqual(arg_names[:1], expected_arg_names)
@slow @slow
...@@ -753,7 +753,7 @@ class TFModelTesterMixin: ...@@ -753,7 +753,7 @@ class TFModelTesterMixin:
def test_lm_head_model_random_no_beam_search_generate(self): def test_lm_head_model_random_no_beam_search_generate(self):
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common() config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
input_ids = inputs_dict["input_ids"] if "input_ids" in inputs_dict else inputs_dict["inputs"] input_ids = inputs_dict["input_ids"]
# iterate over all generative models # iterate over all generative models
for model_class in self.all_generative_model_classes: for model_class in self.all_generative_model_classes:
......
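The updated signature assertion is easy to reproduce outside the test suite; a minimal check (the model class here is picked only for illustration, and any non encoder-decoder TF model should behave the same):

# Illustrative sketch: the first real argument of a TF model's call is now
# input_ids rather than the old catch-all inputs.
import inspect
from transformers import TFBertModel

arg_names = list(inspect.signature(TFBertModel.call).parameters.keys())
print(arg_names[:2])  # ['self', 'input_ids']
assert arg_names[1] == "input_ids"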