"...lm-evaluation-harness.git" did not exist on "5f1d18d49e01f1349e323e3e291dd5b3642a5ce3"
Unverified Commit 29d49924 authored by Julien Plu, committed by GitHub

New TF model inputs (#8602)

* Apply on BERT and ALBERT

* Update TF Bart

* Add input processing to TF BART

* Add input processing for TF CTRL

* Add input processing to TF Distilbert

* Add input processing to TF DPR

* Add input processing to TF Electra

* Add input processing for TF Flaubert

* Add deprecated arguments

* Add input processing to TF XLM

* remove unused imports

* Add input processing to TF Funnel

* Add input processing to TF GPT2

* Add input processing to TF Longformer

* Add input processing to TF Lxmert

* Apply style

* Add input processing to TF Mobilebert

* Add input processing to TF GPT

* Add input processing to TF Roberta

* Add input processing to TF T5

* Add input processing to TF TransfoXL

* Apply style

* Rebase on master

* Bug fix

* Retry to bugfix

* Retry bug fix

* Fix wrong model name

* Try another fix

* Fix BART

* Fix input processing

* Apply style

* Put the deprecated warnings in the input processing function

* Remove the unused imports

* Raise an error when len(kwargs)>0

* test ModelOutput instead of TFBaseModelOutput

* Bug fix

* Address Patrick's comments

* Address Patrick's comments

* Address Sylvain's comments

* Add the new inputs in new Longformer models

* Update the template with the new input processing

* Remove useless assert

* Apply style

* Trigger CI
parent 82d443a7
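
The diff below applies the new pattern to the TF XLM and XLNet models (and to the cookiecutter template): each `call` signature now lists its inputs as explicit keyword arguments, and the hand-written tuple/dict unpacking is replaced by a single `input_processing` helper imported from `modeling_tf_utils`. A minimal usage sketch of the resulting calling convention follows; the checkpoint name and the exact runtime behaviour are illustrative assumptions, not part of this diff.

# Usage sketch (not part of the diff): the calling convention after this change.
from transformers import TFXLMModel, XLMTokenizer

tokenizer = XLMTokenizer.from_pretrained("xlm-mlm-en-2048")
model = TFXLMModel.from_pretrained("xlm-mlm-en-2048")

encoding = tokenizer("Hello, my dog is cute", return_tensors="tf")

# Named inputs, matching the new `call(self, input_ids=None, attention_mask=None, ...)`
# signatures introduced by this commit.
outputs = model(encoding["input_ids"], attention_mask=encoding["attention_mask"])

# Passing the BatchEncoding/dict as the first argument is still supported:
# input_processing unpacks it into the same named arguments.
outputs = model(encoding)

# Unknown keyword arguments are no longer silently dropped; per the commit
# message, input_processing raises an error when len(kwargs) > 0.
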
......@@ -47,10 +47,10 @@ from ...modeling_tf_utils import (
TFSharedEmbeddings,
TFTokenClassificationLoss,
get_initializer,
input_processing,
keras_serializable,
shape_list,
)
from ...tokenization_utils import BatchEncoding
from ...utils import logging
from .configuration_xlm import XLMConfig
......@@ -343,7 +343,7 @@ class TFXLMMainLayer(tf.keras.layers.Layer):
def call(
self,
inputs,
input_ids=None,
attention_mask=None,
langs=None,
token_type_ids=None,
......@@ -356,63 +356,57 @@ class TFXLMMainLayer(tf.keras.layers.Layer):
output_hidden_states=None,
return_dict=None,
training=False,
): # removed: src_enc=None, src_len=None
if isinstance(inputs, (tuple, list)):
input_ids = inputs[0]
attention_mask = inputs[1] if len(inputs) > 1 else attention_mask
langs = inputs[2] if len(inputs) > 2 else langs
token_type_ids = inputs[3] if len(inputs) > 3 else token_type_ids
position_ids = inputs[4] if len(inputs) > 4 else position_ids
lengths = inputs[5] if len(inputs) > 5 else lengths
cache = inputs[6] if len(inputs) > 6 else cache
head_mask = inputs[7] if len(inputs) > 7 else head_mask
inputs_embeds = inputs[8] if len(inputs) > 8 else inputs_embeds
output_attentions = inputs[9] if len(inputs) > 9 else output_attentions
output_hidden_states = inputs[10] if len(inputs) > 10 else output_hidden_states
return_dict = inputs[11] if len(inputs) > 11 else return_dict
assert len(inputs) <= 12, "Too many inputs."
elif isinstance(inputs, (dict, BatchEncoding)):
input_ids = inputs.get("input_ids")
attention_mask = inputs.get("attention_mask", attention_mask)
langs = inputs.get("langs", langs)
token_type_ids = inputs.get("token_type_ids", token_type_ids)
position_ids = inputs.get("position_ids", position_ids)
lengths = inputs.get("lengths", lengths)
cache = inputs.get("cache", cache)
head_mask = inputs.get("head_mask", head_mask)
inputs_embeds = inputs.get("inputs_embeds", inputs_embeds)
output_attentions = inputs.get("output_attentions", output_attentions)
output_hidden_states = inputs.get("output_hidden_states", output_hidden_states)
return_dict = inputs.get("return_dict", return_dict)
assert len(inputs) <= 12, "Too many inputs."
else:
input_ids = inputs
output_attentions = output_attentions if output_attentions is not None else self.output_attentions
output_hidden_states = output_hidden_states if output_hidden_states is not None else self.output_hidden_states
return_dict = return_dict if return_dict is not None else self.return_dict
**kwargs,
):
# removed: src_enc=None, src_len=None
inputs = input_processing(
func=self.call,
input_ids=input_ids,
attention_mask=attention_mask,
langs=langs,
token_type_ids=token_type_ids,
position_ids=position_ids,
lengths=lengths,
cache=cache,
head_mask=head_mask,
inputs_embeds=inputs_embeds,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
training=training,
kwargs_call=kwargs,
)
output_attentions = (
inputs["output_attentions"] if inputs["output_attentions"] is not None else self.output_attentions
)
output_hidden_states = (
inputs["output_hidden_states"] if inputs["output_hidden_states"] is not None else self.output_hidden_states
)
return_dict = inputs["return_dict"] if inputs["return_dict"] is not None else self.return_dict
if input_ids is not None and inputs_embeds is not None:
if inputs["input_ids"] is not None and inputs["inputs_embeds"] is not None:
raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
elif input_ids is not None:
bs, slen = shape_list(input_ids)
elif inputs_embeds is not None:
bs, slen = shape_list(inputs_embeds)[:2]
elif inputs["input_ids"] is not None:
bs, slen = shape_list(inputs["input_ids"])
elif inputs["inputs_embeds"] is not None:
bs, slen = shape_list(inputs["inputs_embeds"])[:2]
else:
raise ValueError("You have to specify either input_ids or inputs_embeds")
if lengths is None:
if input_ids is not None:
lengths = tf.reduce_sum(tf.cast(tf.not_equal(input_ids, self.pad_index), dtype=tf.int32), axis=1)
if inputs["lengths"] is None:
if inputs["input_ids"] is not None:
inputs["lengths"] = tf.reduce_sum(
tf.cast(tf.not_equal(inputs["input_ids"], self.pad_index), dtype=tf.int32), axis=1
)
else:
lengths = tf.convert_to_tensor([slen] * bs, tf.int32)
inputs["lengths"] = tf.convert_to_tensor([slen] * bs, tf.int32)
# mask = input_ids != self.pad_index
# check inputs
# assert shape_list(lengths)[0] == bs
tf.debugging.assert_equal(
shape_list(lengths)[0], bs
), f"Expected batch size {shape_list(lengths)[0]} and received batch size {bs} mismatched"
shape_list(inputs["lengths"])[0], bs
), f"Expected batch size {shape_list(inputs['lengths'])[0]} and received batch size {bs} mismatched"
# assert lengths.max().item() <= slen
# input_ids = input_ids.transpose(0, 1) # batch size as dimension 0
# assert (src_enc is None) == (src_len is None)
......@@ -421,26 +415,26 @@ class TFXLMMainLayer(tf.keras.layers.Layer):
# assert src_enc.size(0) == bs
# generate masks
mask, attn_mask = get_masks(slen, lengths, self.causal, padding_mask=attention_mask)
mask, attn_mask = get_masks(slen, inputs["lengths"], self.causal, padding_mask=inputs["attention_mask"])
# if self.is_decoder and src_enc is not None:
# src_mask = torch.arange(src_len.max(), dtype=torch.long, device=lengths.device) < src_len[:, None]
# position_ids
if position_ids is None:
position_ids = tf.expand_dims(tf.range(slen), axis=0)
if inputs["position_ids"] is None:
inputs["position_ids"] = tf.expand_dims(tf.range(slen), axis=0)
else:
# assert shape_list(position_ids) == [bs, slen] # (slen, bs)
tf.debugging.assert_equal(
shape_list(position_ids), [bs, slen]
), f"Position id shape {shape_list(position_ids)} and input shape {[bs, slen]} mismatched"
shape_list(inputs["position_ids"]), [bs, slen]
), f"Position id shape {shape_list(inputs['position_ids'])} and input shape {[bs, slen]} mismatched"
# position_ids = position_ids.transpose(0, 1)
# langs
if langs is not None:
if inputs["langs"] is not None:
# assert shape_list(langs) == [bs, slen] # (slen, bs)
tf.debugging.assert_equal(
shape_list(langs), [bs, slen]
), f"Lang shape {shape_list(langs)} and input shape {[bs, slen]} mismatched"
shape_list(inputs["langs"]), [bs, slen]
), f"Lang shape {shape_list(inputs['langs'])} and input shape {[bs, slen]} mismatched"
# langs = langs.transpose(0, 1)
# Prepare head mask if needed
......@@ -448,34 +442,34 @@ class TFXLMMainLayer(tf.keras.layers.Layer):
# attention_probs has shape bsz x n_heads x N x N
# input head_mask has shape [num_heads] or [num_hidden_layers x num_heads]
# and head_mask is converted to shape [num_hidden_layers x batch x num_heads x qlen x klen]
if head_mask is not None:
if inputs["head_mask"] is not None:
raise NotImplementedError
else:
head_mask = [None] * self.n_layers
inputs["head_mask"] = [None] * self.n_layers
# do not recompute cached elements
if cache is not None and input_ids is not None:
_slen = slen - cache["slen"]
input_ids = input_ids[:, -_slen:]
position_ids = position_ids[:, -_slen:]
if langs is not None:
langs = langs[:, -_slen:]
if inputs["cache"] is not None and inputs["input_ids"] is not None:
_slen = slen - inputs["cache"]["slen"]
inputs["input_ids"] = inputs["input_ids"][:, -_slen:]
inputs["position_ids"] = inputs["position_ids"][:, -_slen:]
if inputs["langs"] is not None:
inputs["langs"] = inputs["langs"][:, -_slen:]
mask = mask[:, -_slen:]
attn_mask = attn_mask[:, -_slen:]
# embeddings
if inputs_embeds is None:
inputs_embeds = self.embeddings(input_ids)
if inputs["inputs_embeds"] is None:
inputs["inputs_embeds"] = self.embeddings(inputs["input_ids"])
tensor = inputs_embeds + self.position_embeddings(position_ids)
tensor = inputs["inputs_embeds"] + self.position_embeddings(inputs["position_ids"])
if langs is not None and self.use_lang_emb and self.n_langs > 1:
tensor = tensor + self.lang_embeddings(langs)
if token_type_ids is not None:
tensor = tensor + self.embeddings(token_type_ids)
if inputs["langs"] is not None and self.use_lang_emb and self.n_langs > 1:
tensor = tensor + self.lang_embeddings(inputs["langs"])
if inputs["token_type_ids"] is not None:
tensor = tensor + self.embeddings(inputs["token_type_ids"])
tensor = self.layer_norm_emb(tensor)
tensor = self.dropout(tensor, training=training)
tensor = self.dropout(tensor, training=inputs["training"])
tensor = tensor * mask[..., tf.newaxis]
# transformer layers
......@@ -488,14 +482,20 @@ class TFXLMMainLayer(tf.keras.layers.Layer):
# self attention
attn_outputs = self.attentions[i](
tensor, attn_mask, None, cache, head_mask[i], output_attentions, training=training
tensor,
attn_mask,
None,
inputs["cache"],
inputs["head_mask"][i],
output_attentions,
training=inputs["training"],
)
attn = attn_outputs[0]
if output_attentions:
attentions = attentions + (attn_outputs[1],)
attn = self.dropout(attn, training=training)
attn = self.dropout(attn, training=inputs["training"])
tensor = tensor + attn
tensor = self.layer_norm1[i](tensor)
......@@ -516,8 +516,8 @@ class TFXLMMainLayer(tf.keras.layers.Layer):
hidden_states = hidden_states + (tensor,)
# update cache length
if cache is not None:
cache["slen"] += tensor.size(1)
if inputs["cache"] is not None:
inputs["cache"]["slen"] += tensor.size(1)
# move back sequence length to dimension 0
# tensor = tensor.transpose(0, 1)
......@@ -701,8 +701,57 @@ class TFXLMModel(TFXLMPreTrainedModel):
output_type=TFBaseModelOutput,
config_class=_CONFIG_FOR_DOC,
)
def call(self, inputs, **kwargs):
outputs = self.transformer(inputs, **kwargs)
def call(
self,
input_ids=None,
attention_mask=None,
langs=None,
token_type_ids=None,
position_ids=None,
lengths=None,
cache=None,
head_mask=None,
inputs_embeds=None,
output_attentions=None,
output_hidden_states=None,
return_dict=None,
training=False,
**kwargs,
):
inputs = input_processing(
func=self.call,
input_ids=input_ids,
attention_mask=attention_mask,
langs=langs,
token_type_ids=token_type_ids,
position_ids=position_ids,
lengths=lengths,
cache=cache,
head_mask=head_mask,
inputs_embeds=inputs_embeds,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
training=training,
kwargs_call=kwargs,
)
return_dict = inputs["return_dict"] if inputs["return_dict"] is not None else self.transformer.return_dict
outputs = self.transformer(
input_ids=inputs["input_ids"],
attention_mask=inputs["attention_mask"],
langs=inputs["langs"],
token_type_ids=inputs["token_type_ids"],
position_ids=inputs["position_ids"],
lengths=inputs["lengths"],
cache=inputs["cache"],
head_mask=inputs["head_mask"],
inputs_embeds=inputs["inputs_embeds"],
output_attentions=inputs["output_attentions"],
output_hidden_states=inputs["output_hidden_states"],
return_dict=return_dict,
training=inputs["training"],
)
return outputs
......@@ -771,7 +820,7 @@ class TFXLMWithLMHeadModel(TFXLMPreTrainedModel):
langs = tf.ones_like(inputs) * lang_id
else:
langs = None
return {"inputs": inputs, "langs": langs}
return {"input_ids": inputs, "langs": langs}
@add_start_docstrings_to_model_forward(XLM_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
@add_code_sample_docstrings(
......@@ -780,10 +829,56 @@ class TFXLMWithLMHeadModel(TFXLMPreTrainedModel):
output_type=TFXLMWithLMHeadModelOutput,
config_class=_CONFIG_FOR_DOC,
)
def call(self, inputs, **kwargs):
return_dict = kwargs.get("return_dict")
return_dict = return_dict if return_dict is not None else self.transformer.return_dict
transformer_outputs = self.transformer(inputs, **kwargs)
def call(
self,
input_ids=None,
attention_mask=None,
langs=None,
token_type_ids=None,
position_ids=None,
lengths=None,
cache=None,
head_mask=None,
inputs_embeds=None,
output_attentions=None,
output_hidden_states=None,
return_dict=None,
training=False,
**kwargs,
):
inputs = input_processing(
func=self.call,
input_ids=input_ids,
attention_mask=attention_mask,
langs=langs,
token_type_ids=token_type_ids,
position_ids=position_ids,
lengths=lengths,
cache=cache,
head_mask=head_mask,
inputs_embeds=inputs_embeds,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
training=training,
kwargs_call=kwargs,
)
return_dict = inputs["return_dict"] if inputs["return_dict"] is not None else self.transformer.return_dict
transformer_outputs = self.transformer(
input_ids=inputs["input_ids"],
attention_mask=inputs["attention_mask"],
langs=inputs["langs"],
token_type_ids=inputs["token_type_ids"],
position_ids=inputs["position_ids"],
lengths=inputs["lengths"],
cache=inputs["cache"],
head_mask=inputs["head_mask"],
inputs_embeds=inputs["inputs_embeds"],
output_attentions=inputs["output_attentions"],
output_hidden_states=inputs["output_hidden_states"],
return_dict=return_dict,
training=inputs["training"],
)
output = transformer_outputs[0]
outputs = self.pred_layer(output)
......@@ -820,7 +915,7 @@ class TFXLMForSequenceClassification(TFXLMPreTrainedModel, TFSequenceClassificat
)
def call(
self,
inputs=None,
input_ids=None,
attention_mask=None,
langs=None,
token_type_ids=None,
......@@ -834,6 +929,7 @@ class TFXLMForSequenceClassification(TFXLMPreTrainedModel, TFSequenceClassificat
return_dict=None,
labels=None,
training=False,
**kwargs,
):
r"""
labels (:obj:`tf.Tensor` of shape :obj:`(batch_size,)`, `optional`):
......@@ -841,16 +937,9 @@ class TFXLMForSequenceClassification(TFXLMPreTrainedModel, TFSequenceClassificat
config.num_labels - 1]``. If ``config.num_labels == 1`` a regression loss is computed (Mean-Square loss),
If ``config.num_labels > 1`` a classification loss is computed (Cross-Entropy).
"""
return_dict = return_dict if return_dict is not None else self.transformer.return_dict
if isinstance(inputs, (tuple, list)):
labels = inputs[12] if len(inputs) > 12 else labels
if len(inputs) > 12:
inputs = inputs[:12]
elif isinstance(inputs, (dict, BatchEncoding)):
labels = inputs.pop("labels", labels)
transformer_outputs = self.transformer(
inputs,
inputs = input_processing(
func=self.call,
input_ids=input_ids,
attention_mask=attention_mask,
langs=langs,
token_type_ids=token_type_ids,
......@@ -862,13 +951,31 @@ class TFXLMForSequenceClassification(TFXLMPreTrainedModel, TFSequenceClassificat
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
labels=labels,
training=training,
kwargs_call=kwargs,
)
return_dict = inputs["return_dict"] if inputs["return_dict"] is not None else self.transformer.return_dict
transformer_outputs = self.transformer(
input_ids=inputs["input_ids"],
attention_mask=inputs["attention_mask"],
langs=inputs["langs"],
token_type_ids=inputs["token_type_ids"],
position_ids=inputs["position_ids"],
lengths=inputs["lengths"],
cache=inputs["cache"],
head_mask=inputs["head_mask"],
inputs_embeds=inputs["inputs_embeds"],
output_attentions=inputs["output_attentions"],
output_hidden_states=inputs["output_hidden_states"],
return_dict=return_dict,
training=inputs["training"],
)
output = transformer_outputs[0]
logits = self.sequence_summary(output)
loss = None if labels is None else self.compute_loss(labels, logits)
loss = None if inputs["labels"] is None else self.compute_loss(inputs["labels"], logits)
if not return_dict:
output = (logits,) + transformer_outputs[1:]
......@@ -921,7 +1028,7 @@ class TFXLMForMultipleChoice(TFXLMPreTrainedModel, TFMultipleChoiceLoss):
)
def call(
self,
inputs,
input_ids=None,
attention_mask=None,
langs=None,
token_type_ids=None,
......@@ -935,71 +1042,58 @@ class TFXLMForMultipleChoice(TFXLMPreTrainedModel, TFMultipleChoiceLoss):
return_dict=None,
labels=None,
training=False,
**kwargs,
):
r"""
labels (:obj:`tf.Tensor` of shape :obj:`(batch_size,)`, `optional`):
Labels for computing the multiple choice classification loss. Indices should be in ``[0, ...,
num_choices]`` where :obj:`num_choices` is the size of the second dimension of the input tensors. (See
:obj:`input_ids` above)
"""
if isinstance(inputs, (tuple, list)):
input_ids = inputs[0]
attention_mask = inputs[1] if len(inputs) > 1 else attention_mask
langs = inputs[2] if len(inputs) > 2 else langs
token_type_ids = inputs[3] if len(inputs) > 3 else token_type_ids
position_ids = inputs[4] if len(inputs) > 4 else position_ids
lengths = inputs[5] if len(inputs) > 5 else lengths
cache = inputs[6] if len(inputs) > 6 else cache
head_mask = inputs[7] if len(inputs) > 7 else head_mask
inputs_embeds = inputs[8] if len(inputs) > 8 else inputs_embeds
output_attentions = inputs[9] if len(inputs) > 9 else output_attentions
output_hidden_states = inputs[10] if len(inputs) > 10 else output_hidden_states
return_dict = inputs[11] if len(inputs) > 11 else return_dict
labels = inputs[12] if len(inputs) > 12 else labels
assert len(inputs) <= 13, "Too many inputs."
elif isinstance(inputs, (dict, BatchEncoding)):
input_ids = inputs.get("input_ids")
attention_mask = inputs.get("attention_mask", attention_mask)
langs = inputs.get("langs", langs)
token_type_ids = inputs.get("token_type_ids", token_type_ids)
position_ids = inputs.get("position_ids", position_ids)
lengths = inputs.get("lengths", lengths)
cache = inputs.get("cache", cache)
head_mask = inputs.get("head_mask", head_mask)
inputs_embeds = inputs.get("inputs_embeds", inputs_embeds)
output_attentions = inputs.get("output_attentions", output_attentions)
output_hidden_states = inputs.get("output_hidden_states", output_hidden_states)
return_dict = inputs.get("return_dict", return_dict)
labels = inputs.get("labels", labels)
assert len(inputs) <= 13, "Too many inputs."
else:
input_ids = inputs
return_dict = return_dict if return_dict is not None else self.transformer.return_dict
inputs = input_processing(
func=self.call,
input_ids=input_ids,
attention_mask=attention_mask,
langs=langs,
token_type_ids=token_type_ids,
position_ids=position_ids,
lengths=lengths,
cache=cache,
head_mask=head_mask,
inputs_embeds=inputs_embeds,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
labels=labels,
training=training,
kwargs_call=kwargs,
)
return_dict = inputs["return_dict"] if inputs["return_dict"] is not None else self.transformer.return_dict
if input_ids is not None:
num_choices = shape_list(input_ids)[1]
seq_length = shape_list(input_ids)[2]
if inputs["input_ids"] is not None:
num_choices = shape_list(inputs["input_ids"])[1]
seq_length = shape_list(inputs["input_ids"])[2]
else:
num_choices = shape_list(inputs_embeds)[1]
seq_length = shape_list(inputs_embeds)[2]
flat_input_ids = tf.reshape(input_ids, (-1, seq_length)) if input_ids is not None else None
flat_attention_mask = tf.reshape(attention_mask, (-1, seq_length)) if attention_mask is not None else None
flat_token_type_ids = tf.reshape(token_type_ids, (-1, seq_length)) if token_type_ids is not None else None
flat_position_ids = tf.reshape(position_ids, (-1, seq_length)) if position_ids is not None else None
flat_langs = tf.reshape(langs, (-1, seq_length)) if langs is not None else None
num_choices = shape_list(inputs["inputs_embeds"])[1]
seq_length = shape_list(inputs["inputs_embeds"])[2]
flat_input_ids = tf.reshape(inputs["input_ids"], (-1, seq_length)) if inputs["input_ids"] is not None else None
flat_attention_mask = (
tf.reshape(inputs["attention_mask"], (-1, seq_length)) if inputs["attention_mask"] is not None else None
)
flat_token_type_ids = (
tf.reshape(inputs["token_type_ids"], (-1, seq_length)) if inputs["token_type_ids"] is not None else None
)
flat_position_ids = (
tf.reshape(inputs["position_ids"], (-1, seq_length)) if inputs["position_ids"] is not None else None
)
flat_langs = tf.reshape(inputs["langs"], (-1, seq_length)) if inputs["langs"] is not None else None
flat_inputs_embeds = (
tf.reshape(inputs_embeds, (-1, seq_length, shape_list(inputs_embeds)[3]))
if inputs_embeds is not None
tf.reshape(inputs["inputs_embeds"], (-1, seq_length, shape_list(inputs["inputs_embeds"])[3]))
if inputs["inputs_embeds"] is not None
else None
)
if lengths is not None:
if inputs["lengths"] is not None:
logger.warn(
"The `lengths` parameter cannot be used with the XLM multiple choice models. Please use the "
"attention mask instead.",
)
lengths = None
inputs["lengths"] = None
transformer_outputs = self.transformer(
flat_input_ids,
......@@ -1007,21 +1101,21 @@ class TFXLMForMultipleChoice(TFXLMPreTrainedModel, TFMultipleChoiceLoss):
flat_langs,
flat_token_type_ids,
flat_position_ids,
lengths,
cache,
head_mask,
inputs["lengths"],
inputs["cache"],
inputs["head_mask"],
flat_inputs_embeds,
output_attentions,
output_hidden_states,
inputs["output_attentions"],
inputs["output_hidden_states"],
return_dict=return_dict,
training=training,
training=inputs["training"],
)
output = transformer_outputs[0]
logits = self.sequence_summary(output)
logits = self.logits_proj(logits)
reshaped_logits = tf.reshape(logits, (-1, num_choices))
loss = None if labels is None else self.compute_loss(labels, reshaped_logits)
loss = None if inputs["labels"] is None else self.compute_loss(inputs["labels"], reshaped_logits)
if not return_dict:
output = (reshaped_logits,) + transformer_outputs[1:]
......@@ -1062,7 +1156,7 @@ class TFXLMForTokenClassification(TFXLMPreTrainedModel, TFTokenClassificationLos
)
def call(
self,
inputs=None,
input_ids=None,
attention_mask=None,
langs=None,
token_type_ids=None,
......@@ -1076,22 +1170,16 @@ class TFXLMForTokenClassification(TFXLMPreTrainedModel, TFTokenClassificationLos
return_dict=None,
labels=None,
training=False,
**kwargs,
):
r"""
labels (:obj:`tf.Tensor` of shape :obj:`(batch_size, sequence_length)`, `optional`):
Labels for computing the token classification loss. Indices should be in ``[0, ..., config.num_labels -
1]``.
"""
return_dict = return_dict if return_dict is not None else self.transformer.return_dict
if isinstance(inputs, (tuple, list)):
labels = inputs[12] if len(inputs) > 12 else labels
if len(inputs) > 12:
inputs = inputs[:12]
elif isinstance(inputs, (dict, BatchEncoding)):
labels = inputs.pop("labels", labels)
transformer_outputs = self.transformer(
inputs,
inputs = input_processing(
func=self.call,
input_ids=input_ids,
attention_mask=attention_mask,
langs=langs,
token_type_ids=token_type_ids,
......@@ -1103,15 +1191,33 @@ class TFXLMForTokenClassification(TFXLMPreTrainedModel, TFTokenClassificationLos
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
labels=labels,
training=training,
kwargs_call=kwargs,
)
return_dict = inputs["return_dict"] if inputs["return_dict"] is not None else self.transformer.return_dict
transformer_outputs = self.transformer(
input_ids=inputs["input_ids"],
attention_mask=inputs["attention_mask"],
langs=inputs["langs"],
token_type_ids=inputs["token_type_ids"],
position_ids=inputs["position_ids"],
lengths=inputs["lengths"],
cache=inputs["cache"],
head_mask=inputs["head_mask"],
inputs_embeds=inputs["inputs_embeds"],
output_attentions=inputs["output_attentions"],
output_hidden_states=inputs["output_hidden_states"],
return_dict=return_dict,
training=inputs["training"],
)
sequence_output = transformer_outputs[0]
sequence_output = self.dropout(sequence_output, training=training)
sequence_output = self.dropout(sequence_output, training=inputs["training"])
logits = self.classifier(sequence_output)
loss = None if labels is None else self.compute_loss(labels, logits)
loss = None if inputs["labels"] is None else self.compute_loss(inputs["labels"], logits)
if not return_dict:
output = (logits,) + transformer_outputs[1:]
......@@ -1149,7 +1255,7 @@ class TFXLMForQuestionAnsweringSimple(TFXLMPreTrainedModel, TFQuestionAnsweringL
)
def call(
self,
inputs=None,
input_ids=None,
attention_mask=None,
langs=None,
token_type_ids=None,
......@@ -1164,6 +1270,7 @@ class TFXLMForQuestionAnsweringSimple(TFXLMPreTrainedModel, TFQuestionAnsweringL
start_positions=None,
end_positions=None,
training=False,
**kwargs,
):
r"""
start_positions (:obj:`tf.Tensor` of shape :obj:`(batch_size,)`, `optional`):
......@@ -1175,18 +1282,9 @@ class TFXLMForQuestionAnsweringSimple(TFXLMPreTrainedModel, TFQuestionAnsweringL
Positions are clamped to the length of the sequence (:obj:`sequence_length`). Position outside of the
sequence are not taken into account for computing the loss.
"""
return_dict = return_dict if return_dict is not None else self.transformer.return_dict
if isinstance(inputs, (tuple, list)):
start_positions = inputs[12] if len(inputs) > 12 else start_positions
end_positions = inputs[13] if len(inputs) > 13 else end_positions
if len(inputs) > 12:
inputs = inputs[:12]
elif isinstance(inputs, (dict, BatchEncoding)):
start_positions = inputs.pop("start_positions", start_positions)
end_positions = inputs.pop("end_positions", start_positions)
transformer_outputs = self.transformer(
inputs,
inputs = input_processing(
func=self.call,
input_ids=input_ids,
attention_mask=attention_mask,
langs=langs,
token_type_ids=token_type_ids,
......@@ -1198,7 +1296,26 @@ class TFXLMForQuestionAnsweringSimple(TFXLMPreTrainedModel, TFQuestionAnsweringL
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
start_positions=start_positions,
end_positions=end_positions,
training=training,
kwargs_call=kwargs,
)
return_dict = inputs["return_dict"] if inputs["return_dict"] is not None else self.transformer.return_dict
transformer_outputs = self.transformer(
input_ids=inputs["input_ids"],
attention_mask=inputs["attention_mask"],
langs=inputs["langs"],
token_type_ids=inputs["token_type_ids"],
position_ids=inputs["position_ids"],
lengths=inputs["lengths"],
cache=inputs["cache"],
head_mask=inputs["head_mask"],
inputs_embeds=inputs["inputs_embeds"],
output_attentions=inputs["output_attentions"],
output_hidden_states=inputs["output_hidden_states"],
return_dict=return_dict,
training=inputs["training"],
)
sequence_output = transformer_outputs[0]
......@@ -1209,9 +1326,9 @@ class TFXLMForQuestionAnsweringSimple(TFXLMPreTrainedModel, TFQuestionAnsweringL
end_logits = tf.squeeze(end_logits, axis=-1)
loss = None
if start_positions is not None and end_positions is not None:
labels = {"start_position": start_positions}
labels["end_position"] = end_positions
if inputs["start_positions"] is not None and inputs["end_positions"] is not None:
labels = {"start_position": inputs["start_positions"]}
labels["end_position"] = inputs["end_positions"]
loss = self.compute_loss(labels, (start_logits, end_logits))
if not return_dict:
......
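
The same replacement is applied to the XLNet modeling file below. For reference, here is a simplified sketch of what a helper in the spirit of `input_processing` does at these call sites: it reads the parameter names from the model's `call` signature, still accepts the legacy tuple/list and dict/BatchEncoding input styles, returns a dict keyed by parameter name, and, per the commit message, raises when unexpected **kwargs remain. The real implementation lives in modeling_tf_utils.py and is not shown in this diff; the code below is an assumption-based illustration only.

# Illustrative sketch only -- NOT the implementation from modeling_tf_utils.py.
import inspect
from collections.abc import Mapping


def input_processing_sketch(func, kwargs_call=None, **model_kwargs):
    if kwargs_call:
        # Unexpected keyword arguments are rejected instead of silently ignored.
        raise ValueError(f"Unexpected keyword arguments: {list(kwargs_call)}.")

    # Accepted parameter names are read from the signature of the model's `call`.
    names = [p for p in inspect.signature(func).parameters if p not in ("self", "kwargs")]
    output = dict.fromkeys(names)
    output.update(model_kwargs)

    first = output.get("input_ids")
    if isinstance(first, (tuple, list)):
        # Legacy positional style: values given in signature order.
        output.update(dict(zip(names, first)))
    elif isinstance(first, Mapping):
        # Legacy dict / BatchEncoding style.
        output["input_ids"] = None
        output.update(dict(first))

    return output

At the call sites in this diff the helper is invoked as input_processing(func=self.call, ..., kwargs_call=kwargs), and the returned dict is then read as inputs["input_ids"], inputs["attention_mask"], and so on.
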
......@@ -17,7 +17,6 @@
TF 2.0 XLNet model.
"""
from dataclasses import dataclass
from typing import List, Optional, Tuple
......@@ -42,10 +41,10 @@ from ...modeling_tf_utils import (
TFSharedEmbeddings,
TFTokenClassificationLoss,
get_initializer,
input_processing,
keras_serializable,
shape_list,
)
from ...tokenization_utils import BatchEncoding
from ...utils import logging
from .configuration_xlnet import XLNetConfig
......@@ -561,7 +560,7 @@ class TFXLNetMainLayer(tf.keras.layers.Layer):
def call(
self,
inputs,
input_ids=None,
attention_mask=None,
mems=None,
perm_mask=None,
......@@ -575,66 +574,66 @@ class TFXLNetMainLayer(tf.keras.layers.Layer):
output_hidden_states=None,
return_dict=None,
training=False,
**kwargs,
):
if isinstance(inputs, (tuple, list)):
input_ids = inputs[0]
attention_mask = inputs[1] if len(inputs) > 1 else attention_mask
mems = inputs[2] if len(inputs) > 2 else mems
perm_mask = inputs[3] if len(inputs) > 3 else perm_mask
target_mapping = inputs[4] if len(inputs) > 4 else target_mapping
token_type_ids = inputs[5] if len(inputs) > 5 else token_type_ids
input_mask = inputs[6] if len(inputs) > 6 else input_mask
head_mask = inputs[7] if len(inputs) > 7 else head_mask
inputs_embeds = inputs[8] if len(inputs) > 8 else inputs_embeds
use_cache = inputs[9] if len(inputs) > 9 else use_cache
output_attentions = inputs[10] if len(inputs) > 10 else output_attentions
output_hidden_states = inputs[11] if len(inputs) > 11 else output_hidden_states
return_dict = inputs[12] if len(inputs) > 12 else return_dict
assert len(inputs) <= 13, "Too many inputs."
elif isinstance(inputs, (dict, BatchEncoding)):
input_ids = inputs.get("input_ids")
attention_mask = inputs.get("attention_mask", attention_mask)
mems = inputs.get("mems", mems)
perm_mask = inputs.get("perm_mask", perm_mask)
target_mapping = inputs.get("target_mapping", target_mapping)
token_type_ids = inputs.get("token_type_ids", token_type_ids)
input_mask = inputs.get("input_mask", input_mask)
head_mask = inputs.get("head_mask", head_mask)
inputs_embeds = inputs.get("inputs_embeds", inputs_embeds)
use_cache = inputs.get("use_cache", use_cache)
output_attentions = inputs.get("output_attentions", output_attentions)
output_hidden_states = inputs.get("output_hidden_states", output_hidden_states)
return_dict = inputs.get("return_dict", return_dict)
assert len(inputs) <= 13, "Too many inputs."
else:
input_ids = inputs
output_attentions = output_attentions if output_attentions is not None else self.output_attentions
output_hidden_states = output_hidden_states if output_hidden_states is not None else self.output_hidden_states
return_dict = return_dict if return_dict is not None else self.return_dict
inputs = input_processing(
func=self.call,
input_ids=input_ids,
attention_mask=attention_mask,
mems=mems,
perm_mask=perm_mask,
target_mapping=target_mapping,
token_type_ids=token_type_ids,
input_mask=input_mask,
head_mask=head_mask,
inputs_embeds=inputs_embeds,
use_cache=use_cache,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
training=training,
kwargs_call=kwargs,
)
output_attentions = (
inputs["output_attentions"] if inputs["output_attentions"] is not None else self.output_attentions
)
output_hidden_states = (
inputs["output_hidden_states"] if inputs["output_hidden_states"] is not None else self.output_hidden_states
)
return_dict = inputs["return_dict"] if inputs["return_dict"] is not None else self.return_dict
# the original code for XLNet uses shapes [len, bsz] with the batch dimension at the end
# but we want a unified interface in the library with the batch size on the first dimension
# so we move here the first dimension (batch) to the end
if input_ids is not None and inputs_embeds is not None:
if inputs["input_ids"] is not None and inputs["inputs_embeds"] is not None:
raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
elif input_ids is not None:
input_ids = tf.transpose(input_ids, perm=(1, 0))
qlen, bsz = shape_list(input_ids)[:2]
elif inputs_embeds is not None:
inputs_embeds = tf.transpose(inputs_embeds, perm=(1, 0, 2))
qlen, bsz = shape_list(inputs_embeds)[:2]
elif inputs["input_ids"] is not None:
inputs["input_ids"] = tf.transpose(inputs["input_ids"], perm=(1, 0))
qlen, bsz = shape_list(inputs["input_ids"])[:2]
elif inputs["inputs_embeds"] is not None:
inputs["inputs_embeds"] = tf.transpose(inputs["inputs_embeds"], perm=(1, 0, 2))
qlen, bsz = shape_list(inputs["inputs_embeds"])[:2]
else:
raise ValueError("You have to specify either input_ids or inputs_embeds")
token_type_ids = tf.transpose(token_type_ids, perm=(1, 0)) if token_type_ids is not None else None
input_mask = tf.transpose(input_mask, perm=(1, 0)) if input_mask is not None else None
attention_mask = tf.transpose(attention_mask, perm=(1, 0)) if attention_mask is not None else None
perm_mask = tf.transpose(perm_mask, perm=(1, 2, 0)) if perm_mask is not None else None
target_mapping = tf.transpose(target_mapping, perm=(1, 2, 0)) if target_mapping is not None else None
inputs["token_type_ids"] = (
tf.transpose(inputs["token_type_ids"], perm=(1, 0)) if inputs["token_type_ids"] is not None else None
)
inputs["input_mask"] = (
tf.transpose(inputs["input_mask"], perm=(1, 0)) if inputs["input_mask"] is not None else None
)
inputs["attention_mask"] = (
tf.transpose(inputs["attention_mask"], perm=(1, 0)) if inputs["attention_mask"] is not None else None
)
inputs["perm_mask"] = (
tf.transpose(inputs["perm_mask"], perm=(1, 2, 0)) if inputs["perm_mask"] is not None else None
)
inputs["target_mapping"] = (
tf.transpose(inputs["target_mapping"], perm=(1, 2, 0)) if inputs["target_mapping"] is not None else None
)
mlen = shape_list(mems[0])[0] if mems is not None and mems[0] is not None else 0
mlen = shape_list(inputs["mems"][0])[0] if inputs["mems"] is not None and inputs["mems"][0] is not None else 0
klen = mlen + qlen
dtype_float = tf.bfloat16 if self.use_bfloat16 else tf.float32
......@@ -650,18 +649,18 @@ class TFXLNetMainLayer(tf.keras.layers.Layer):
raise ValueError("Unsupported attention type: {}".format(self.attn_type))
# data mask: input mask & perm mask
assert input_mask is None or attention_mask is None, (
assert inputs["input_mask"] is None or inputs["attention_mask"] is None, (
"You can only use one of input_mask (uses 1 for padding) "
"or attention_mask (uses 0 for padding, added for compatibility with BERT). Please choose one."
)
if input_mask is None and attention_mask is not None:
input_mask = 1.0 - tf.cast(attention_mask, dtype=dtype_float)
if input_mask is not None and perm_mask is not None:
data_mask = input_mask[None] + perm_mask
elif input_mask is not None and perm_mask is None:
data_mask = input_mask[None]
elif input_mask is None and perm_mask is not None:
data_mask = perm_mask
if inputs["input_mask"] is None and inputs["attention_mask"] is not None:
inputs["input_mask"] = 1.0 - tf.cast(inputs["attention_mask"], dtype=dtype_float)
if inputs["input_mask"] is not None and inputs["perm_mask"] is not None:
data_mask = inputs["input_mask"][None] + inputs["perm_mask"]
elif inputs["input_mask"] is not None and inputs["perm_mask"] is None:
data_mask = inputs["input_mask"][None]
elif inputs["input_mask"] is None and inputs["perm_mask"] is not None:
data_mask = inputs["perm_mask"]
else:
data_mask = None
......@@ -687,59 +686,59 @@ class TFXLNetMainLayer(tf.keras.layers.Layer):
non_tgt_mask = None
# Word embeddings and prepare h & g hidden states
if inputs_embeds is not None:
word_emb_k = inputs_embeds
if inputs["inputs_embeds"] is not None:
word_emb_k = inputs["inputs_embeds"]
else:
word_emb_k = self.word_embedding(input_ids)
output_h = self.dropout(word_emb_k, training=training)
if target_mapping is not None:
word_emb_q = tf.tile(self.mask_emb, [shape_list(target_mapping)[0], bsz, 1])
word_emb_k = self.word_embedding(inputs["input_ids"])
output_h = self.dropout(word_emb_k, training=inputs["training"])
if inputs["target_mapping"] is not None:
word_emb_q = tf.tile(self.mask_emb, [shape_list(inputs["target_mapping"])[0], bsz, 1])
# else: # We removed the inp_q input which was same as target mapping
# inp_q_ext = inp_q[:, :, None]
# word_emb_q = inp_q_ext * self.mask_emb + (1 - inp_q_ext) * word_emb_k
output_g = self.dropout(word_emb_q, training=training)
output_g = self.dropout(word_emb_q, training=inputs["training"])
else:
output_g = None
# Segment embedding
if token_type_ids is not None:
if inputs["token_type_ids"] is not None:
# Convert `token_type_ids` to one-hot `seg_mat`
if mlen > 0:
mem_pad = tf.zeros([mlen, bsz], dtype=tf.int32)
cat_ids = tf.concat([mem_pad, token_type_ids], 0)
cat_ids = tf.concat([mem_pad, inputs["token_type_ids"]], 0)
else:
cat_ids = token_type_ids
cat_ids = inputs["token_type_ids"]
# `1` indicates not in the same segment [qlen x klen x bsz]
seg_mat = tf.cast(tf.logical_not(tf.equal(token_type_ids[:, None], cat_ids[None, :])), tf.int32)
seg_mat = tf.cast(tf.logical_not(tf.equal(inputs["token_type_ids"][:, None], cat_ids[None, :])), tf.int32)
seg_mat = tf.one_hot(seg_mat, 2, dtype=dtype_float)
else:
seg_mat = None
# Positional encoding
pos_emb = self.relative_positional_encoding(qlen, klen, bsz=bsz, dtype=dtype_float)
pos_emb = self.dropout(pos_emb, training=training)
pos_emb = self.dropout(pos_emb, training=inputs["training"])
# Prepare head mask if needed
# 1.0 in head_mask indicate we keep the head
# attention_probs has shape bsz x n_heads x N x N
# input head_mask has shape [num_heads] or [num_hidden_layers x num_heads] (a head_mask for each layer)
# and head_mask is converted to shape [num_hidden_layers x qlen x klen x bsz x n_head]
if head_mask is not None:
if inputs["head_mask"] is not None:
raise NotImplementedError
else:
head_mask = [None] * self.n_layer
inputs["head_mask"] = [None] * self.n_layer
new_mems = ()
if mems is None:
mems = [None] * len(self.layer)
if inputs["mems"] is None:
inputs["mems"] = [None] * len(self.layer)
attentions = [] if output_attentions else None
hidden_states = [] if output_hidden_states else None
for i, layer_module in enumerate(self.layer):
# cache new mems
if self.mem_len is not None and self.mem_len > 0 and use_cache:
new_mems = new_mems + (self.cache_mem(output_h, mems[i]),)
new_mems = new_mems + (self.cache_mem(output_h, inputs["mems"][i]),)
if output_hidden_states:
hidden_states.append((output_h, output_g) if output_g is not None else output_h)
......@@ -750,11 +749,11 @@ class TFXLNetMainLayer(tf.keras.layers.Layer):
attn_mask,
pos_emb,
seg_mat,
mems[i],
target_mapping,
head_mask[i],
inputs["mems"][i],
inputs["target_mapping"],
inputs["head_mask"][i],
output_attentions,
training=training,
training=inputs["training"],
)
output_h, output_g = outputs[:2]
if output_attentions:
......@@ -764,7 +763,7 @@ class TFXLNetMainLayer(tf.keras.layers.Layer):
if output_hidden_states:
hidden_states.append((output_h, output_g) if output_g is not None else output_h)
output = self.dropout(output_g if output_g is not None else output_h, training=training)
output = self.dropout(output_g if output_g is not None else output_h, training=inputs["training"])
# Prepare outputs, we transpose back here to shape [bsz, len, hidden_dim] (cf. beginning of forward() method)
output = tf.transpose(output, perm=(1, 0, 2))
......@@ -1137,8 +1136,59 @@ class TFXLNetModel(TFXLNetPreTrainedModel):
output_type=TFXLNetModelOutput,
config_class=_CONFIG_FOR_DOC,
)
def call(self, inputs, **kwargs):
outputs = self.transformer(inputs, **kwargs)
def call(
self,
input_ids=None,
attention_mask=None,
mems=None,
perm_mask=None,
target_mapping=None,
token_type_ids=None,
input_mask=None,
head_mask=None,
inputs_embeds=None,
use_cache=True,
output_attentions=None,
output_hidden_states=None,
return_dict=None,
training=False,
**kwargs,
):
inputs = input_processing(
func=self.call,
input_ids=input_ids,
attention_mask=attention_mask,
mems=mems,
perm_mask=perm_mask,
target_mapping=target_mapping,
token_type_ids=token_type_ids,
input_mask=input_mask,
head_mask=head_mask,
inputs_embeds=inputs_embeds,
use_cache=use_cache,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
training=training,
kwargs_call=kwargs,
)
outputs = self.transformer(
input_ids=inputs["input_ids"],
attention_mask=inputs["attention_mask"],
mems=inputs["mems"],
perm_mask=inputs["perm_mask"],
target_mapping=inputs["target_mapping"],
token_type_ids=inputs["token_type_ids"],
input_mask=inputs["input_mask"],
head_mask=inputs["head_mask"],
inputs_embeds=inputs["inputs_embeds"],
use_cache=inputs["use_cache"],
output_attentions=inputs["output_attentions"],
output_hidden_states=inputs["output_hidden_states"],
return_dict=inputs["return_dict"],
training=inputs["training"],
)
return outputs
......@@ -1185,7 +1235,7 @@ class TFXLNetLMHeadModel(TFXLNetPreTrainedModel, TFCausalLanguageModelingLoss):
target_mapping = tf.concat([target_mapping, target_mapping_seq_end], axis=-1)
inputs = {
"inputs": inputs,
"input_ids": inputs,
"perm_mask": perm_mask,
"target_mapping": target_mapping,
"use_cache": kwargs["use_cache"],
......@@ -1201,7 +1251,7 @@ class TFXLNetLMHeadModel(TFXLNetPreTrainedModel, TFCausalLanguageModelingLoss):
@replace_return_docstrings(output_type=TFXLNetLMHeadModelOutput, config_class=_CONFIG_FOR_DOC)
def call(
self,
inputs,
input_ids=None,
attention_mask=None,
mems=None,
perm_mask=None,
......@@ -1216,6 +1266,7 @@ class TFXLNetLMHeadModel(TFXLNetPreTrainedModel, TFCausalLanguageModelingLoss):
return_dict=None,
labels=None,
training=False,
**kwargs,
):
r"""
labels (:obj:`tf.Tensor` of shape :obj:`(batch_size, sequence_length)`, `optional`):
......@@ -1247,16 +1298,9 @@ class TFXLNetLMHeadModel(TFXLNetPreTrainedModel, TFCausalLanguageModelingLoss):
>>> next_token_logits = outputs[0] # Output has shape [target_mapping.size(0), target_mapping.size(1), config.vocab_size]
"""
return_dict = return_dict if return_dict is not None else self.transformer.return_dict
if isinstance(inputs, (tuple, list)):
labels = inputs[13] if len(inputs) > 13 else labels
if len(inputs) > 13:
inputs = inputs[:13]
elif isinstance(inputs, (dict, BatchEncoding)):
labels = inputs.pop("labels", labels)
transformer_outputs = self.transformer(
inputs,
inputs = input_processing(
func=self.call,
input_ids=input_ids,
attention_mask=attention_mask,
mems=mems,
perm_mask=perm_mask,
......@@ -1269,16 +1313,35 @@ class TFXLNetLMHeadModel(TFXLNetPreTrainedModel, TFCausalLanguageModelingLoss):
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
labels=labels,
training=training,
kwargs_call=kwargs,
)
return_dict = inputs["return_dict"] if inputs["return_dict"] is not None else self.transformer.return_dict
transformer_outputs = self.transformer(
input_ids=inputs["input_ids"],
attention_mask=inputs["attention_mask"],
mems=inputs["mems"],
perm_mask=inputs["perm_mask"],
target_mapping=inputs["target_mapping"],
token_type_ids=inputs["token_type_ids"],
input_mask=inputs["input_mask"],
head_mask=inputs["head_mask"],
inputs_embeds=inputs["inputs_embeds"],
use_cache=inputs["use_cache"],
output_attentions=inputs["output_attentions"],
output_hidden_states=inputs["output_hidden_states"],
return_dict=return_dict,
training=inputs["training"],
)
hidden_state = transformer_outputs[0]
logits = self.lm_loss(hidden_state, training=training)
logits = self.lm_loss(hidden_state, training=inputs["training"])
loss = None
if labels is not None:
if inputs["labels"] is not None:
# shift labels to the left and cut last logit token
logits = logits[:, :-1]
labels = labels[:, 1:]
labels = inputs["labels"][:, 1:]
loss = self.compute_loss(labels, logits)
if not return_dict:
......@@ -1323,7 +1386,7 @@ class TFXLNetForSequenceClassification(TFXLNetPreTrainedModel, TFSequenceClassif
)
def call(
self,
inputs=None,
input_ids=None,
attention_mask=None,
mems=None,
perm_mask=None,
......@@ -1338,6 +1401,7 @@ class TFXLNetForSequenceClassification(TFXLNetPreTrainedModel, TFSequenceClassif
return_dict=None,
labels=None,
training=False,
**kwargs,
):
r"""
labels (:obj:`tf.Tensor` of shape :obj:`(batch_size,)`, `optional`):
......@@ -1345,16 +1409,9 @@ class TFXLNetForSequenceClassification(TFXLNetPreTrainedModel, TFSequenceClassif
config.num_labels - 1]``. If ``config.num_labels == 1`` a regression loss is computed (Mean-Square loss),
If ``config.num_labels > 1`` a classification loss is computed (Cross-Entropy).
"""
return_dict = return_dict if return_dict is not None else self.transformer.return_dict
if isinstance(inputs, (tuple, list)):
labels = inputs[13] if len(inputs) > 13 else labels
if len(inputs) > 13:
inputs = inputs[:13]
elif isinstance(inputs, (dict, BatchEncoding)):
labels = inputs.pop("labels", labels)
transformer_outputs = self.transformer(
inputs,
inputs = input_processing(
func=self.call,
input_ids=input_ids,
attention_mask=attention_mask,
mems=mems,
perm_mask=perm_mask,
......@@ -1367,13 +1424,33 @@ class TFXLNetForSequenceClassification(TFXLNetPreTrainedModel, TFSequenceClassif
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
labels=labels,
training=training,
kwargs_call=kwargs,
)
return_dict = inputs["return_dict"] if inputs["return_dict"] is not None else self.transformer.return_dict
transformer_outputs = self.transformer(
input_ids=inputs["input_ids"],
attention_mask=inputs["attention_mask"],
mems=inputs["mems"],
perm_mask=inputs["perm_mask"],
target_mapping=inputs["target_mapping"],
token_type_ids=inputs["token_type_ids"],
input_mask=inputs["input_mask"],
head_mask=inputs["head_mask"],
inputs_embeds=inputs["inputs_embeds"],
use_cache=inputs["use_cache"],
output_attentions=inputs["output_attentions"],
output_hidden_states=inputs["output_hidden_states"],
return_dict=return_dict,
training=inputs["training"],
)
output = transformer_outputs[0]
output = self.sequence_summary(output)
logits = self.logits_proj(output)
loss = None if labels is None else self.compute_loss(labels, logits)
loss = None if inputs["labels"] is None else self.compute_loss(inputs["labels"], logits)
if not return_dict:
output = (logits,) + transformer_outputs[1:]
......@@ -1426,7 +1503,7 @@ class TFXLNetForMultipleChoice(TFXLNetPreTrainedModel, TFMultipleChoiceLoss):
)
def call(
self,
inputs=None,
input_ids=None,
token_type_ids=None,
input_mask=None,
attention_mask=None,
......@@ -1441,6 +1518,7 @@ class TFXLNetForMultipleChoice(TFXLNetPreTrainedModel, TFMultipleChoiceLoss):
return_dict=None,
labels=None,
training=False,
**kwargs,
):
r"""
labels (:obj:`tf.Tensor` of shape :obj:`(batch_size,)`, `optional`):
......@@ -1448,79 +1526,70 @@ class TFXLNetForMultipleChoice(TFXLNetPreTrainedModel, TFMultipleChoiceLoss):
num_choices]`` where :obj:`num_choices` is the size of the second dimension of the input tensors. (See
:obj:`input_ids` above)
"""
if isinstance(inputs, (tuple, list)):
input_ids = inputs[0]
attention_mask = inputs[1] if len(inputs) > 1 else attention_mask
mems = inputs[2] if len(inputs) > 2 else mems
perm_mask = inputs[3] if len(inputs) > 3 else perm_mask
target_mapping = inputs[4] if len(inputs) > 4 else target_mapping
token_type_ids = inputs[5] if len(inputs) > 5 else token_type_ids
input_mask = inputs[6] if len(inputs) > 6 else input_mask
head_mask = inputs[7] if len(inputs) > 7 else head_mask
inputs_embeds = inputs[8] if len(inputs) > 8 else inputs_embeds
use_cache = inputs[9] if len(inputs) > 9 else use_cache
output_attentions = inputs[10] if len(inputs) > 10 else output_attentions
output_hidden_states = inputs[11] if len(inputs) > 11 else output_hidden_states
return_dict = inputs[12] if len(inputs) > 12 else return_dict
labels = inputs[13] if len(inputs) > 13 else labels
assert len(inputs) <= 14, "Too many inputs."
elif isinstance(inputs, (dict, BatchEncoding)):
input_ids = inputs.get("input_ids")
attention_mask = inputs.get("attention_mask", attention_mask)
mems = inputs.get("mems", mems)
perm_mask = inputs.get("perm_mask", perm_mask)
target_mapping = inputs.get("target_mapping", target_mapping)
token_type_ids = inputs.get("token_type_ids", token_type_ids)
input_mask = inputs.get("input_mask", input_mask)
head_mask = inputs.get("head_mask", head_mask)
inputs_embeds = inputs.get("inputs_embeds", inputs_embeds)
use_cache = inputs.get("use_cache", use_cache)
output_attentions = inputs.get("output_attentions", output_attentions)
output_hidden_states = inputs.get("output_hidden_states", output_hidden_states)
return_dict = inputs.get("return_dict", return_dict)
labels = inputs.get("labels", labels)
assert len(inputs) <= 14, "Too many inputs."
else:
input_ids = inputs
return_dict = return_dict if return_dict is not None else self.transformer.return_dict
inputs = input_processing(
func=self.call,
input_ids=input_ids,
attention_mask=attention_mask,
mems=mems,
perm_mask=perm_mask,
target_mapping=target_mapping,
token_type_ids=token_type_ids,
input_mask=input_mask,
head_mask=head_mask,
inputs_embeds=inputs_embeds,
use_cache=use_cache,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
labels=labels,
training=training,
kwargs_call=kwargs,
)
return_dict = inputs["return_dict"] if inputs["return_dict"] is not None else self.transformer.return_dict
if input_ids is not None:
num_choices = shape_list(input_ids)[1]
seq_length = shape_list(input_ids)[2]
if inputs["input_ids"] is not None:
num_choices = shape_list(inputs["input_ids"])[1]
seq_length = shape_list(inputs["input_ids"])[2]
else:
num_choices = shape_list(inputs_embeds)[1]
seq_length = shape_list(inputs_embeds)[2]
num_choices = shape_list(inputs["inputs_embeds"])[1]
seq_length = shape_list(inputs["inputs_embeds"])[2]
flat_input_ids = tf.reshape(input_ids, (-1, seq_length)) if input_ids is not None else None
flat_attention_mask = tf.reshape(attention_mask, (-1, seq_length)) if attention_mask is not None else None
flat_token_type_ids = tf.reshape(token_type_ids, (-1, seq_length)) if token_type_ids is not None else None
flat_input_mask = tf.reshape(input_mask, (-1, seq_length)) if input_mask is not None else None
flat_input_ids = tf.reshape(inputs["input_ids"], (-1, seq_length)) if inputs["input_ids"] is not None else None
flat_attention_mask = (
tf.reshape(inputs["attention_mask"], (-1, seq_length)) if inputs["attention_mask"] is not None else None
)
flat_token_type_ids = (
tf.reshape(inputs["token_type_ids"], (-1, seq_length)) if inputs["token_type_ids"] is not None else None
)
flat_input_mask = (
tf.reshape(inputs["input_mask"], (-1, seq_length)) if inputs["input_mask"] is not None else None
)
flat_inputs_embeds = (
tf.reshape(inputs_embeds, (-1, seq_length, shape_list(inputs_embeds)[3]))
if inputs_embeds is not None
tf.reshape(inputs["inputs_embeds"], (-1, seq_length, shape_list(inputs["inputs_embeds"])[3]))
if inputs["inputs_embeds"] is not None
else None
)
transformer_outputs = self.transformer(
flat_input_ids,
flat_attention_mask,
mems,
perm_mask,
target_mapping,
inputs["mems"],
inputs["perm_mask"],
inputs["target_mapping"],
flat_token_type_ids,
flat_input_mask,
head_mask,
inputs["head_mask"],
flat_inputs_embeds,
use_cache,
output_attentions,
output_hidden_states,
inputs["use_cache"],
inputs["output_attentions"],
inputs["output_hidden_states"],
return_dict=return_dict,
training=training,
training=inputs["training"],
)
output = transformer_outputs[0]
logits = self.sequence_summary(output)
logits = self.logits_proj(logits)
reshaped_logits = tf.reshape(logits, (-1, num_choices))
loss = None if labels is None else self.compute_loss(labels, reshaped_logits)
loss = None if inputs["labels"] is None else self.compute_loss(inputs["labels"], reshaped_logits)
if not return_dict:
output = (reshaped_logits,) + transformer_outputs[1:]
......@@ -1561,7 +1630,7 @@ class TFXLNetForTokenClassification(TFXLNetPreTrainedModel, TFTokenClassificatio
)
def call(
self,
inputs=None,
input_ids=None,
attention_mask=None,
mems=None,
perm_mask=None,
......@@ -1576,22 +1645,16 @@ class TFXLNetForTokenClassification(TFXLNetPreTrainedModel, TFTokenClassificatio
return_dict=None,
labels=None,
training=False,
**kwargs,
):
r"""
labels (:obj:`tf.Tensor` of shape :obj:`(batch_size, sequence_length)`, `optional`):
Labels for computing the token classification loss. Indices should be in ``[0, ..., config.num_labels -
1]``.
"""
return_dict = return_dict if return_dict is not None else self.transformer.return_dict
if isinstance(inputs, (tuple, list)):
labels = inputs[13] if len(inputs) > 13 else labels
if len(inputs) > 13:
inputs = inputs[:13]
elif isinstance(inputs, (dict, BatchEncoding)):
labels = inputs.pop("labels", labels)
transformer_outputs = self.transformer(
inputs,
inputs = input_processing(
func=self.call,
input_ids=input_ids,
attention_mask=attention_mask,
mems=mems,
perm_mask=perm_mask,
......@@ -1604,12 +1667,31 @@ class TFXLNetForTokenClassification(TFXLNetPreTrainedModel, TFTokenClassificatio
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
labels=labels,
training=training,
kwargs_call=kwargs,
)
return_dict = inputs["return_dict"] if inputs["return_dict"] is not None else self.transformer.return_dict
transformer_outputs = self.transformer(
input_ids=inputs["input_ids"],
attention_mask=inputs["attention_mask"],
mems=inputs["mems"],
perm_mask=inputs["perm_mask"],
target_mapping=inputs["target_mapping"],
token_type_ids=inputs["token_type_ids"],
input_mask=inputs["input_mask"],
head_mask=inputs["head_mask"],
inputs_embeds=inputs["inputs_embeds"],
use_cache=inputs["use_cache"],
output_attentions=inputs["output_attentions"],
output_hidden_states=inputs["output_hidden_states"],
return_dict=return_dict,
training=inputs["training"],
)
output = transformer_outputs[0]
logits = self.classifier(output)
loss = None if labels is None else self.compute_loss(labels, logits)
loss = None if inputs["labels"] is None else self.compute_loss(inputs["labels"], logits)
if not return_dict:
output = (logits,) + transformer_outputs[1:]
......@@ -1648,7 +1730,7 @@ class TFXLNetForQuestionAnsweringSimple(TFXLNetPreTrainedModel, TFQuestionAnswer
)
def call(
self,
inputs=None,
input_ids=None,
attention_mask=None,
mems=None,
perm_mask=None,
......@@ -1664,6 +1746,7 @@ class TFXLNetForQuestionAnsweringSimple(TFXLNetPreTrainedModel, TFQuestionAnswer
start_positions=None,
end_positions=None,
training=False,
**kwargs,
):
r"""
start_positions (:obj:`tf.Tensor` of shape :obj:`(batch_size,)`, `optional`):
......@@ -1675,18 +1758,9 @@ class TFXLNetForQuestionAnsweringSimple(TFXLNetPreTrainedModel, TFQuestionAnswer
Positions are clamped to the length of the sequence (:obj:`sequence_length`). Position outside of the
sequence are not taken into account for computing the loss.
"""
return_dict = return_dict if return_dict is not None else self.transformer.return_dict
if isinstance(inputs, (tuple, list)):
start_positions = inputs[13] if len(inputs) > 13 else start_positions
end_positions = inputs[14] if len(inputs) > 14 else end_positions
if len(inputs) > 13:
inputs = inputs[:13]
elif isinstance(inputs, (dict, BatchEncoding)):
start_positions = inputs.pop("start_positions", start_positions)
end_positions = inputs.pop("end_positions", start_positions)
transformer_outputs = self.transformer(
inputs,
inputs = input_processing(
func=self.call,
input_ids=input_ids,
attention_mask=attention_mask,
mems=mems,
perm_mask=perm_mask,
......@@ -1699,7 +1773,27 @@ class TFXLNetForQuestionAnsweringSimple(TFXLNetPreTrainedModel, TFQuestionAnswer
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
start_positions=start_positions,
end_positions=end_positions,
training=training,
kwargs_call=kwargs,
)
return_dict = inputs["return_dict"] if inputs["return_dict"] is not None else self.transformer.return_dict
transformer_outputs = self.transformer(
input_ids=inputs["input_ids"],
attention_mask=inputs["attention_mask"],
mems=inputs["mems"],
perm_mask=inputs["perm_mask"],
target_mapping=inputs["target_mapping"],
token_type_ids=inputs["token_type_ids"],
input_mask=inputs["input_mask"],
head_mask=inputs["head_mask"],
inputs_embeds=inputs["inputs_embeds"],
use_cache=inputs["use_cache"],
output_attentions=inputs["output_attentions"],
output_hidden_states=inputs["output_hidden_states"],
return_dict=return_dict,
training=inputs["training"],
)
sequence_output = transformer_outputs[0]
......@@ -1710,9 +1804,9 @@ class TFXLNetForQuestionAnsweringSimple(TFXLNetPreTrainedModel, TFQuestionAnswer
end_logits = tf.squeeze(end_logits, axis=-1)
loss = None
if start_positions is not None and end_positions is not None:
labels = {"start_position": start_positions}
labels["end_position"] = end_positions
if inputs["start_positions"] is not None and inputs["end_positions"] is not None:
labels = {"start_position": inputs["start_positions"]}
labels["end_position"] = inputs["end_positions"]
loss = self.compute_loss(labels, (start_logits, end_logits))
if not return_dict:
......
......@@ -42,10 +42,10 @@ from ...modeling_tf_utils import (
TFTokenClassificationLoss,
TFSequenceSummary,
get_initializer,
input_processing,
keras_serializable,
shape_list,
)
from ...tokenization_utils import BatchEncoding
from ...utils import logging
from .configuration_{{cookiecutter.lowercase_modelname}} import {{cookiecutter.camelcase_modelname}}Config
......@@ -499,7 +499,7 @@ class TF{{cookiecutter.camelcase_modelname}}MainLayer(tf.keras.layers.Layer):
def call(
self,
inputs,
input_ids=None,
attention_mask=None,
token_type_ids=None,
position_ids=None,
......@@ -509,59 +509,59 @@ class TF{{cookiecutter.camelcase_modelname}}MainLayer(tf.keras.layers.Layer):
output_hidden_states=None,
return_dict=None,
training=False,
**kwargs,
):
if isinstance(inputs, (tuple, list)):
input_ids = inputs[0]
attention_mask = inputs[1] if len(inputs) > 1 else attention_mask
token_type_ids = inputs[2] if len(inputs) > 2 else token_type_ids
position_ids = inputs[3] if len(inputs) > 3 else position_ids
head_mask = inputs[4] if len(inputs) > 4 else head_mask
inputs_embeds = inputs[5] if len(inputs) > 5 else inputs_embeds
output_attentions = inputs[6] if len(inputs) > 6 else output_attentions
output_hidden_states = inputs[7] if len(inputs) > 7 else output_hidden_states
return_dict = inputs[8] if len(inputs) > 8 else return_dict
assert len(inputs) <= 9, "Too many inputs."
elif isinstance(inputs, (dict, BatchEncoding)):
input_ids = inputs.get("input_ids")
attention_mask = inputs.get("attention_mask", attention_mask)
token_type_ids = inputs.get("token_type_ids", token_type_ids)
position_ids = inputs.get("position_ids", position_ids)
head_mask = inputs.get("head_mask", head_mask)
inputs_embeds = inputs.get("inputs_embeds", inputs_embeds)
output_attentions = inputs.get("output_attentions", output_attentions)
output_hidden_states = inputs.get("output_hidden_states", output_hidden_states)
return_dict = inputs.get("return_dict", return_dict)
assert len(inputs) <= 9, "Too many inputs."
else:
input_ids = inputs
output_attentions = output_attentions if output_attentions is not None else self.output_attentions
output_hidden_states = output_hidden_states if output_hidden_states is not None else self.output_hidden_states
return_dict = return_dict if return_dict is not None else self.return_dict
inputs = input_processing(
func=self.call,
input_ids=input_ids,
attention_mask=attention_mask,
token_type_ids=token_type_ids,
position_ids=position_ids,
head_mask=head_mask,
inputs_embeds=inputs_embeds,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
training=training,
kwargs_call=kwargs,
)
output_attentions = (
inputs["output_attentions"] if inputs["output_attentions"] is not None else self.output_attentions
)
output_hidden_states = (
inputs["output_hidden_states"] if inputs["output_hidden_states"] is not None else self.output_hidden_states
)
return_dict = inputs["return_dict"] if inputs["return_dict"] is not None else self.return_dict
if input_ids is not None and inputs_embeds is not None:
if inputs["input_ids"] is not None and inputs["inputs_embeds"] is not None:
raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
elif input_ids is not None:
input_shape = shape_list(input_ids)
elif inputs_embeds is not None:
input_shape = shape_list(inputs_embeds)[:-1]
elif inputs["input_ids"] is not None:
input_shape = shape_list(inputs["input_ids"])
elif inputs["inputs_embeds"] is not None:
input_shape = shape_list(inputs["inputs_embeds"])[:-1]
else:
raise ValueError("You have to specify either input_ids or inputs_embeds")
if attention_mask is None:
attention_mask = tf.fill(input_shape, 1)
if inputs["attention_mask"] is None:
inputs["attention_mask"] = tf.fill(input_shape, 1)
if token_type_ids is None:
token_type_ids = tf.fill(input_shape, 0)
if inputs["token_type_ids"] is None:
inputs["token_type_ids"] = tf.fill(input_shape, 0)
embedding_output = self.embeddings(input_ids, position_ids, token_type_ids, inputs_embeds, training=training)
embedding_output = self.embeddings(
inputs["input_ids"],
inputs["position_ids"],
inputs["token_type_ids"],
inputs["inputs_embeds"],
training=inputs["training"],
)
# We create a 3D attention mask from a 2D tensor mask.
# Sizes are [batch_size, 1, 1, to_seq_length]
# So we can broadcast to [batch_size, num_heads, from_seq_length, to_seq_length]
# this attention mask is more simple than the triangular masking of causal attention
# used in OpenAI GPT, we just need to prepare the broadcast dimension here.
extended_attention_mask = attention_mask[:, tf.newaxis, tf.newaxis, :]
extended_attention_mask = inputs["attention_mask"][:, tf.newaxis, tf.newaxis, :]
# Since attention_mask is 1.0 for positions we want to attend and 0.0 for
# masked positions, this operation will create a tensor which is 0.0 for
......@@ -576,20 +576,19 @@ class TF{{cookiecutter.camelcase_modelname}}MainLayer(tf.keras.layers.Layer):
# attention_probs has shape bsz x n_heads x N x N
# input head_mask has shape [num_heads] or [num_hidden_layers x num_heads]
# and head_mask is converted to shape [num_hidden_layers x batch x num_heads x seq_length x seq_length]
if head_mask is not None:
if inputs["head_mask"] is not None:
raise NotImplementedError
else:
head_mask = [None] * self.num_hidden_layers
# head_mask = tf.constant([0] * self.num_hidden_layers)
inputs["head_mask"] = [None] * self.num_hidden_layers
encoder_outputs = self.encoder(
embedding_output,
extended_attention_mask,
head_mask,
inputs["head_mask"],
output_attentions,
output_hidden_states,
return_dict,
training=training,
training=inputs["training"],
)
sequence_output = encoder_outputs[0]
......@@ -725,8 +724,46 @@ class TF{{cookiecutter.camelcase_modelname}}Model(TF{{cookiecutter.camelcase_mod
output_type=TFBaseModelOutputWithPooling,
config_class=_CONFIG_FOR_DOC,
)
def call(self, inputs, **kwargs):
outputs = self.{{cookiecutter.lowercase_modelname}}(inputs, **kwargs)
def call(
self,
input_ids=None,
attention_mask=None,
token_type_ids=None,
position_ids=None,
head_mask=None,
inputs_embeds=None,
output_attentions=None,
output_hidden_states=None,
return_dict=None,
training=False,
**kwargs,
):
inputs = input_processing(
func=self.call,
input_ids=input_ids,
attention_mask=attention_mask,
token_type_ids=token_type_ids,
position_ids=position_ids,
head_mask=head_mask,
inputs_embeds=inputs_embeds,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
training=training,
kwargs_call=kwargs,
)
outputs = self.{{cookiecutter.lowercase_modelname}}(
input_ids=inputs["input_ids"],
attention_mask=inputs["attention_mask"],
token_type_ids=inputs["token_type_ids"],
position_ids=inputs["position_ids"],
head_mask=inputs["head_mask"],
inputs_embeds=inputs["inputs_embeds"],
output_attentions=inputs["output_attentions"],
output_hidden_states=inputs["output_hidden_states"],
return_dict=inputs["return_dict"],
training=inputs["training"],
)
return outputs
......@@ -758,7 +795,7 @@ class TF{{cookiecutter.camelcase_modelname}}ForMaskedLM(TF{{cookiecutter.camelca
)
def call(
self,
inputs=None,
input_ids=None,
attention_mask=None,
token_type_ids=None,
position_ids=None,
......@@ -769,6 +806,7 @@ class TF{{cookiecutter.camelcase_modelname}}ForMaskedLM(TF{{cookiecutter.camelca
return_dict=None,
labels=None,
training=False,
**kwargs,
):
r"""
labels (:obj:`tf.Tensor` of shape :obj:`(batch_size, sequence_length)`, `optional`):
......@@ -777,17 +815,9 @@ class TF{{cookiecutter.camelcase_modelname}}ForMaskedLM(TF{{cookiecutter.camelca
Tokens with indices set to ``-100`` are ignored (masked); the loss is only computed for the tokens with labels
in ``[0, ..., config.vocab_size]``
"""
return_dict = return_dict if return_dict is not None else self.{{cookiecutter.lowercase_modelname}}.return_dict
if isinstance(inputs, (tuple, list)):
labels = inputs[9] if len(inputs) > 9 else labels
if len(inputs) > 9:
inputs = inputs[:9]
elif isinstance(inputs, (dict, BatchEncoding)):
labels = inputs.pop("labels", labels)
outputs = self.{{cookiecutter.lowercase_modelname}}(
inputs,
inputs = input_processing(
func=self.call,
input_ids=input_ids,
attention_mask=attention_mask,
token_type_ids=token_type_ids,
position_ids=position_ids,
......@@ -796,12 +826,27 @@ class TF{{cookiecutter.camelcase_modelname}}ForMaskedLM(TF{{cookiecutter.camelca
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
labels=labels,
training=training,
kwargs_call=kwargs,
)
return_dict = inputs["return_dict"] if inputs["return_dict"] is not None else self.{{cookiecutter.lowercase_modelname}}.return_dict
outputs = self.{{cookiecutter.lowercase_modelname}}(
inputs["input_ids"],
attention_mask=inputs["attention_mask"],
token_type_ids=inputs["token_type_ids"],
position_ids=inputs["position_ids"],
head_mask=inputs["head_mask"],
inputs_embeds=inputs["inputs_embeds"],
output_attentions=inputs["output_attentions"],
output_hidden_states=inputs["output_hidden_states"],
return_dict=return_dict,
training=inputs["training"],
)
sequence_output = outputs[0]
prediction_scores = self.mlm(sequence_output, training=training)
loss = None if labels is None else self.compute_loss(labels, prediction_scores)
prediction_scores = self.mlm(sequence_output, training=inputs["training"])
loss = None if inputs["labels"] is None else self.compute_loss(inputs["labels"], prediction_scores)
if not return_dict:
output = (prediction_scores,) + outputs[1:]
......@@ -862,18 +907,19 @@ class TF{{cookiecutter.camelcase_modelname}}ForSequenceClassification(TF{{cookie
config_class=_CONFIG_FOR_DOC,
)
def call(
self,
inputs,
attention_mask=None,
token_type_ids=None,
position_ids=None,
head_mask=None,
inputs_embeds=None,
output_attentions=None,
output_hidden_states=None,
return_dict=None,
labels=None,
training=False,
self,
input_ids=None,
attention_mask=None,
token_type_ids=None,
position_ids=None,
head_mask=None,
inputs_embeds=None,
output_attentions=None,
output_hidden_states=None,
return_dict=None,
labels=None,
training=False,
**kwargs,
):
r"""
labels (:obj:`tf.Tensor` of shape :obj:`(batch_size,)`, `optional`):
......@@ -882,18 +928,9 @@ class TF{{cookiecutter.camelcase_modelname}}ForSequenceClassification(TF{{cookie
If :obj:`config.num_labels == 1` a regression loss is computed (Mean-Square loss);
if :obj:`config.num_labels > 1` a classification loss is computed (Cross-Entropy).
"""
return_dict = return_dict if return_dict is not None else self.{{cookiecutter.lowercase_modelname}}.config.return_dict
if isinstance(inputs, (tuple, list)):
labels = inputs[9] if len(inputs) > 9 else labels
if len(inputs) > 9:
inputs = inputs[:9]
elif isinstance(inputs, (dict, BatchEncoding)):
labels = inputs.pop("labels", labels)
outputs = self.{{cookiecutter.lowercase_modelname}}(
inputs,
inputs = input_processing(
func=self.call,
input_ids=input_ids,
attention_mask=attention_mask,
token_type_ids=token_type_ids,
position_ids=position_ids,
......@@ -902,10 +939,25 @@ class TF{{cookiecutter.camelcase_modelname}}ForSequenceClassification(TF{{cookie
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
labels=labels,
training=training,
kwargs_call=kwargs,
)
return_dict = inputs["return_dict"] if inputs["return_dict"] is not None else self.{{cookiecutter.lowercase_modelname}}.return_dict
outputs = self.{{cookiecutter.lowercase_modelname}}(
inputs["input_ids"],
attention_mask=inputs["attention_mask"],
token_type_ids=inputs["token_type_ids"],
position_ids=inputs["position_ids"],
head_mask=inputs["head_mask"],
inputs_embeds=inputs["inputs_embeds"],
output_attentions=inputs["output_attentions"],
output_hidden_states=inputs["output_hidden_states"],
return_dict=return_dict,
training=inputs["training"],
)
logits = self.classifier(outputs[0])
loss = None if labels is None else self.compute_loss(labels, logits)
logits = self.classifier(outputs[0], training=inputs["training"])
loss = None if inputs["labels"] is None else self.compute_loss(inputs["labels"], logits)
if not return_dict:
output = (logits,) + outputs[1:]
......@@ -956,7 +1008,7 @@ class TF{{cookiecutter.camelcase_modelname}}ForMultipleChoice(TF{{cookiecutter.c
)
def call(
self,
inputs,
input_ids=None,
attention_mask=None,
token_type_ids=None,
position_ids=None,
......@@ -967,6 +1019,7 @@ class TF{{cookiecutter.camelcase_modelname}}ForMultipleChoice(TF{{cookiecutter.c
return_dict=None,
labels=None,
training=False,
**kwargs,
):
r"""
labels (:obj:`tf.Tensor` of shape :obj:`(batch_size,)`, `optional`):
......@@ -974,49 +1027,43 @@ class TF{{cookiecutter.camelcase_modelname}}ForMultipleChoice(TF{{cookiecutter.c
Indices should be in ``[0, ..., num_choices - 1]`` where :obj:`num_choices` is the size of the second dimension
of the input tensors. (See :obj:`input_ids` above)
"""
if isinstance(inputs, (tuple, list)):
input_ids = inputs[0]
attention_mask = inputs[1] if len(inputs) > 1 else attention_mask
token_type_ids = inputs[2] if len(inputs) > 2 else token_type_ids
position_ids = inputs[3] if len(inputs) > 3 else position_ids
head_mask = inputs[4] if len(inputs) > 4 else head_mask
inputs_embeds = inputs[5] if len(inputs) > 5 else inputs_embeds
output_attentions = inputs[6] if len(inputs) > 6 else output_attentions
output_hidden_states = inputs[7] if len(inputs) > 7 else output_hidden_states
return_dict = inputs[8] if len(inputs) > 8 else return_dict
labels = inputs[9] if len(inputs) > 9 else labels
assert len(inputs) <= 10, "Too many inputs."
elif isinstance(inputs, (dict, BatchEncoding)):
input_ids = inputs.get("input_ids")
attention_mask = inputs.get("attention_mask", attention_mask)
token_type_ids = inputs.get("token_type_ids", token_type_ids)
position_ids = inputs.get("position_ids", position_ids)
head_mask = inputs.get("head_mask", head_mask)
inputs_embeds = inputs.get("inputs_embeds", inputs_embeds)
output_attentions = inputs.get("output_attentions", output_attentions)
output_hidden_states = inputs.get("output_hidden_states", output_hidden_states)
return_dict = inputs.get("return_dict", return_dict)
labels = inputs.get("labels", labels)
assert len(inputs) <= 10, "Too many inputs."
else:
input_ids = inputs
return_dict = return_dict if return_dict is not None else self.{{cookiecutter.lowercase_modelname}}.config.return_dict
inputs = input_processing(
func=self.call,
input_ids=input_ids,
attention_mask=attention_mask,
token_type_ids=token_type_ids,
position_ids=position_ids,
head_mask=head_mask,
inputs_embeds=inputs_embeds,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
labels=labels,
training=training,
kwargs_call=kwargs,
)
return_dict = inputs["return_dict"] if inputs["return_dict"] is not None else self.{{cookiecutter.lowercase_modelname}}.config.return_dict
if input_ids is not None:
num_choices = shape_list(input_ids)[1]
seq_length = shape_list(input_ids)[2]
if inputs["input_ids"] is not None:
num_choices = shape_list(inputs["input_ids"])[1]
seq_length = shape_list(inputs["input_ids"])[2]
else:
num_choices = shape_list(inputs_embeds)[1]
seq_length = shape_list(inputs_embeds)[2]
num_choices = shape_list(inputs["inputs_embeds"])[1]
seq_length = shape_list(inputs["inputs_embeds"])[2]
flat_input_ids = tf.reshape(input_ids, (-1, seq_length)) if input_ids is not None else None
flat_attention_mask = tf.reshape(attention_mask, (-1, seq_length)) if attention_mask is not None else None
flat_token_type_ids = tf.reshape(token_type_ids, (-1, seq_length)) if token_type_ids is not None else None
flat_position_ids = tf.reshape(position_ids, (-1, seq_length)) if position_ids is not None else None
flat_input_ids = tf.reshape(inputs["input_ids"], (-1, seq_length)) if inputs["input_ids"] is not None else None
flat_attention_mask = (
tf.reshape(inputs["attention_mask"], (-1, seq_length)) if inputs["attention_mask"] is not None else None
)
flat_token_type_ids = (
tf.reshape(inputs["token_type_ids"], (-1, seq_length)) if inputs["token_type_ids"] is not None else None
)
flat_position_ids = (
tf.reshape(inputs["position_ids"], (-1, seq_length)) if inputs["position_ids"] is not None else None
)
flat_inputs_embeds = (
tf.reshape(inputs_embeds, (-1, seq_length, shape_list(inputs_embeds)[3]))
if inputs_embeds is not None
tf.reshape(inputs["inputs_embeds"], (-1, seq_length, shape_list(inputs["inputs_embeds"])[3]))
if inputs["inputs_embeds"] is not None
else None
)
outputs = self.{{cookiecutter.lowercase_modelname}}(
......@@ -1024,17 +1071,17 @@ class TF{{cookiecutter.camelcase_modelname}}ForMultipleChoice(TF{{cookiecutter.c
flat_attention_mask,
flat_token_type_ids,
flat_position_ids,
head_mask,
inputs["head_mask"],
flat_inputs_embeds,
output_attentions,
output_hidden_states,
inputs["output_attentions"],
inputs["output_hidden_states"],
return_dict=return_dict,
training=training,
training=inputs["training"],
)
logits = self.sequence_summary(outputs[0])
logits = self.sequence_summary(outputs[0], training=inputs["training"])
logits = self.classifier(logits)
reshaped_logits = tf.reshape(logits, (-1, num_choices))
loss = None if labels is None else self.compute_loss(labels, reshaped_logits)
loss = None if inputs["labels"] is None else self.compute_loss(inputs["labels"], reshaped_logits)
if not return_dict:
output = (reshaped_logits,) + outputs[1:]
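The multiple-choice head above flattens the choice dimension into the batch dimension before calling the encoder and restores it for the loss. A toy illustration of the reshape, with made-up shapes:

import tensorflow as tf

batch_size, num_choices, seq_length = 2, 4, 8
input_ids = tf.zeros((batch_size, num_choices, seq_length), dtype=tf.int32)

# Flatten choices into the batch dimension so the encoder sees 2D inputs.
flat_input_ids = tf.reshape(input_ids, (-1, seq_length))       # shape (8, 8)

# ... encoder + classifier would produce one logit per flattened row ...
logits = tf.zeros((batch_size * num_choices, 1))

# Restore the choice dimension so the loss is a softmax over the choices.
reshaped_logits = tf.reshape(logits, (-1, num_choices))        # shape (2, 4)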
......@@ -1074,7 +1121,7 @@ class TF{{cookiecutter.camelcase_modelname}}ForTokenClassification(TF{{cookiecut
)
def call(
self,
inputs=None,
input_ids=None,
attention_mask=None,
token_type_ids=None,
position_ids=None,
......@@ -1085,23 +1132,16 @@ class TF{{cookiecutter.camelcase_modelname}}ForTokenClassification(TF{{cookiecut
return_dict=None,
labels=None,
training=False,
**kwargs,
):
r"""
labels (:obj:`tf.Tensor` of shape :obj:`(batch_size, sequence_length)`, `optional`):
Labels for computing the token classification loss.
Indices should be in ``[0, ..., config.num_labels - 1]``.
"""
return_dict = return_dict if return_dict is not None else self.{{cookiecutter.lowercase_modelname}}.return_dict
if isinstance(inputs, (tuple, list)):
labels = inputs[9] if len(inputs) > 9 else labels
if len(inputs) > 9:
inputs = inputs[:9]
elif isinstance(inputs, (dict, BatchEncoding)):
labels = inputs.pop("labels", labels)
outputs = self.{{cookiecutter.lowercase_modelname}}(
inputs,
inputs = input_processing(
func=self.call,
input_ids=input_ids,
attention_mask=attention_mask,
token_type_ids=token_type_ids,
position_ids=position_ids,
......@@ -1110,12 +1150,27 @@ class TF{{cookiecutter.camelcase_modelname}}ForTokenClassification(TF{{cookiecut
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
labels=labels,
training=training,
kwargs_call=kwargs,
)
return_dict = inputs["return_dict"] if inputs["return_dict"] is not None else self.{{cookiecutter.lowercase_modelname}}.return_dict
outputs = self.{{cookiecutter.lowercase_modelname}}(
inputs["input_ids"],
attention_mask=inputs["attention_mask"],
token_type_ids=inputs["token_type_ids"],
position_ids=inputs["position_ids"],
head_mask=inputs["head_mask"],
inputs_embeds=inputs["inputs_embeds"],
output_attentions=inputs["output_attentions"],
output_hidden_states=inputs["output_hidden_states"],
return_dict=return_dict,
training=inputs["training"],
)
sequence_output = outputs[0]
sequence_output = self.dropout(sequence_output, training=training)
sequence_output = self.dropout(sequence_output, training=inputs["training"])
logits = self.classifier(sequence_output)
loss = None if labels is None else self.compute_loss(labels, logits)
loss = None if inputs["labels"] is None else self.compute_loss(inputs["labels"], logits)
if not return_dict:
output = (logits,) + outputs[1:]
......@@ -1154,7 +1209,7 @@ class TF{{cookiecutter.camelcase_modelname}}ForQuestionAnswering(TF{{cookiecutte
)
def call(
self,
inputs=None,
input_ids=None,
attention_mask=None,
token_type_ids=None,
position_ids=None,
......@@ -1166,6 +1221,7 @@ class TF{{cookiecutter.camelcase_modelname}}ForQuestionAnswering(TF{{cookiecutte
start_positions=None,
end_positions=None,
training=False,
**kwargs,
):
r"""
start_positions (:obj:`tf.Tensor` of shape :obj:`(batch_size,)`, `optional`):
......@@ -1177,19 +1233,9 @@ class TF{{cookiecutter.camelcase_modelname}}ForQuestionAnswering(TF{{cookiecutte
Positions are clamped to the length of the sequence (:obj:`sequence_length`).
Positions outside of the sequence are not taken into account for computing the loss.
"""
return_dict = return_dict if return_dict is not None else self.{{cookiecutter.lowercase_modelname}}.return_dict
if isinstance(inputs, (tuple, list)):
start_positions = inputs[9] if len(inputs) > 9 else start_positions
end_positions = inputs[10] if len(inputs) > 10 else end_positions
if len(inputs) > 9:
inputs = inputs[:9]
elif isinstance(inputs, (dict, BatchEncoding)):
start_positions = inputs.pop("start_positions", start_positions)
end_positions = inputs.pop("end_positions", start_positions)
outputs = self.{{cookiecutter.lowercase_modelname}}(
inputs,
inputs = input_processing(
func=self.call,
input_ids=input_ids,
attention_mask=attention_mask,
token_type_ids=token_type_ids,
position_ids=position_ids,
......@@ -1198,7 +1244,23 @@ class TF{{cookiecutter.camelcase_modelname}}ForQuestionAnswering(TF{{cookiecutte
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
start_positions=start_positions,
end_positions=end_positions,
training=training,
kwargs_call=kwargs,
)
return_dict = inputs["return_dict"] if inputs["return_dict"] is not None else self.{{cookiecutter.lowercase_modelname}}.return_dict
outputs = self.{{cookiecutter.lowercase_modelname}}(
inputs["input_ids"],
attention_mask=inputs["attention_mask"],
token_type_ids=inputs["token_type_ids"],
position_ids=inputs["position_ids"],
head_mask=inputs["head_mask"],
inputs_embeds=inputs["inputs_embeds"],
output_attentions=inputs["output_attentions"],
output_hidden_states=inputs["output_hidden_states"],
return_dict=return_dict,
training=inputs["training"],
)
sequence_output = outputs[0]
logits = self.qa_outputs(sequence_output)
......@@ -1207,9 +1269,9 @@ class TF{{cookiecutter.camelcase_modelname}}ForQuestionAnswering(TF{{cookiecutte
end_logits = tf.squeeze(end_logits, axis=-1)
loss = None
if start_positions is not None and end_positions is not None:
labels = {"start_position": start_positions}
labels["end_position"] = end_positions
if inputs["start_positions"] is not None and inputs["end_positions"] is not None:
labels = {"start_position": inputs["start_positions"]}
labels["end_position"] = inputs["end_positions"]
loss = self.compute_loss(labels, (start_logits, end_logits))
if not return_dict:
......
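For readers following the question-answering heads: `compute_loss` comes from the `TFQuestionAnsweringLoss` mixin and is not shown in this diff. A rough sketch of the span loss it is expected to compute from the `labels` dict built above, assuming sparse categorical cross-entropy over token positions:

import tensorflow as tf

def sketch_qa_span_loss(labels, logits):
    # labels: {"start_position": (batch,), "end_position": (batch,)}
    # logits: (start_logits, end_logits), each of shape (batch, seq_len)
    loss_fn = tf.keras.losses.SparseCategoricalCrossentropy(
        from_logits=True, reduction=tf.keras.losses.Reduction.NONE
    )
    start_loss = loss_fn(labels["start_position"], logits[0])
    end_loss = loss_fn(labels["end_position"], logits[1])
    return (start_loss + end_loss) / 2.0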
......@@ -14,7 +14,6 @@
# limitations under the License.
import tempfile
import unittest
import numpy as np
......@@ -102,15 +101,14 @@ def prepare_bart_inputs_dict(
@require_tf
class TestTFBart(TFModelTesterMixin, unittest.TestCase):
class TFBartModelTest(TFModelTesterMixin, unittest.TestCase):
all_model_classes = (TFBartForConditionalGeneration, TFBartModel) if is_tf_available() else ()
all_generative_model_classes = (TFBartForConditionalGeneration,) if is_tf_available() else ()
is_encoder_decoder = True
test_pruning = False
model_tester_cls = TFBartModelTester
def setUp(self):
self.model_tester = self.model_tester_cls(self)
self.model_tester = TFBartModelTester(self)
self.config_tester = ConfigTester(self, config_class=BartConfig)
def test_config(self):
......@@ -120,37 +118,6 @@ class TestTFBart(TFModelTesterMixin, unittest.TestCase):
# inputs_embeds not supported
pass
def test_compile_tf_model(self):
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
optimizer = tf.keras.optimizers.Adam(learning_rate=3e-5, epsilon=1e-08, clipnorm=1.0)
loss = tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True)
metric = tf.keras.metrics.SparseCategoricalAccuracy("accuracy")
model_class = self.all_generative_model_classes[0]
input_ids = {
"decoder_input_ids": tf.keras.Input(batch_shape=(2, 2000), name="decoder_input_ids", dtype="int32"),
"input_ids": tf.keras.Input(batch_shape=(2, 2000), name="input_ids", dtype="int32"),
}
# Prepare our model
model = model_class(config)
model(self._prepare_for_class(inputs_dict, model_class)) # Model must be called before saving.
# Let's load it from the disk to be sure we can use pretrained weights
with tempfile.TemporaryDirectory() as tmpdirname:
model.save_pretrained(tmpdirname)
model = model_class.from_pretrained(tmpdirname)
outputs_dict = model(input_ids)
hidden_states = outputs_dict[0]
# Add a dense layer on top to test integration with other keras modules
outputs = tf.keras.layers.Dense(2, activation="softmax", name="outputs")(hidden_states)
# Compile extended model
extended_model = tf.keras.Model(inputs=[input_ids], outputs=[outputs])
extended_model.compile(optimizer=optimizer, loss=loss, metrics=[metric])
def test_saved_model_with_hidden_states_output(self):
# Should be uncommented during Patrick's TF refactor
pass
......@@ -190,7 +157,7 @@ class TFBartHeadTests(unittest.TestCase):
config, input_ids, batch_size = self._get_config_and_data()
decoder_lm_labels = ids_tensor([batch_size, input_ids.shape[1]], self.vocab_size)
lm_model = TFBartForConditionalGeneration(config)
outputs = lm_model(inputs=input_ids, lm_labels=decoder_lm_labels, decoder_input_ids=input_ids, use_cache=False)
outputs = lm_model(input_ids=input_ids, labels=decoder_lm_labels, decoder_input_ids=input_ids, use_cache=False)
expected_shape = (batch_size, input_ids.shape[1], config.vocab_size)
self.assertEqual(outputs.logits.shape, expected_shape)
......@@ -209,7 +176,7 @@ class TFBartHeadTests(unittest.TestCase):
lm_model = TFBartForConditionalGeneration(config)
context = tf.fill((7, 2), 4)
summary = tf.fill((7, 7), 6)
outputs = lm_model(inputs=context, decoder_input_ids=summary, use_cache=False)
outputs = lm_model(input_ids=context, decoder_input_ids=summary, use_cache=False)
expected_shape = (*summary.shape, config.vocab_size)
self.assertEqual(outputs.logits.shape, expected_shape)
......
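The head tests now call the model with explicit keyword arguments (`input_ids=`, `labels=`, `decoder_input_ids=`). The other calling conventions still work after the refactor; a hedged usage sketch (the checkpoint name is just an example):

# Usage sketch, assuming the standard "bert-base-uncased" checkpoint is available.
from transformers import BertTokenizer, TFBertModel

tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")
model = TFBertModel.from_pretrained("bert-base-uncased")
encoded = tokenizer("Hello world", return_tensors="tf")

# 1) Keyword arguments, the canonical form after this change.
out_kw = model(input_ids=encoded["input_ids"], attention_mask=encoded["attention_mask"])
# 2) A dict / BatchEncoding as the first (and only) positional argument.
out_dict = model(encoded)
# 3) A tuple or list in signature order.
out_tuple = model((encoded["input_ids"], encoded["attention_mask"]))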
......@@ -12,24 +12,23 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import tempfile
import unittest
from tests.test_configuration_common import ConfigTester
from tests.test_modeling_tf_bart import TFBartModelTester
from tests.test_modeling_tf_common import TFModelTesterMixin
from transformers import BlenderbotConfig, BlenderbotSmallTokenizer, is_tf_available
from transformers import (
BlenderbotConfig,
BlenderbotSmallTokenizer,
TFAutoModelForSeq2SeqLM,
TFBlenderbotForConditionalGeneration,
is_tf_available,
)
from transformers.file_utils import cached_property
from transformers.testing_utils import is_pt_tf_cross_test, require_tf, require_tokenizers, slow
if is_tf_available():
import tensorflow as tf
from transformers import TFAutoModelForSeq2SeqLM, TFBlenderbotForConditionalGeneration
class ModelTester(TFBartModelTester):
class TFBlenderbotModelTester(TFBartModelTester):
config_updates = dict(
normalize_before=True,
static_position_embeddings=True,
......@@ -40,15 +39,14 @@ class ModelTester(TFBartModelTester):
@require_tf
class TestTFBlenderbotCommon(TFModelTesterMixin, unittest.TestCase):
class TFBlenderbotModelTest(TFModelTesterMixin, unittest.TestCase):
all_model_classes = (TFBlenderbotForConditionalGeneration,) if is_tf_available() else ()
all_generative_model_classes = (TFBlenderbotForConditionalGeneration,) if is_tf_available() else ()
model_tester_cls = ModelTester
is_encoder_decoder = True
test_pruning = False
def setUp(self):
self.model_tester = self.model_tester_cls(self)
self.model_tester = TFBlenderbotModelTester(self)
self.config_tester = ConfigTester(self, config_class=BlenderbotConfig)
def test_config(self):
......@@ -66,37 +64,6 @@ class TestTFBlenderbotCommon(TFModelTesterMixin, unittest.TestCase):
# Should be uncommented during Patrick's TF refactor
pass
def test_compile_tf_model(self):
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
optimizer = tf.keras.optimizers.Adam(learning_rate=3e-5, epsilon=1e-08, clipnorm=1.0)
loss = tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True)
metric = tf.keras.metrics.SparseCategoricalAccuracy("accuracy")
model_class = self.all_generative_model_classes[0]
input_ids = {
"decoder_input_ids": tf.keras.Input(batch_shape=(2, 2000), name="decoder_input_ids", dtype="int32"),
"input_ids": tf.keras.Input(batch_shape=(2, 2000), name="input_ids", dtype="int32"),
}
# Prepare our model
model = model_class(config)
model(self._prepare_for_class(inputs_dict, model_class)) # Model must be called before saving.
# Let's load it from the disk to be sure we can use pretrained weights
with tempfile.TemporaryDirectory() as tmpdirname:
model.save_pretrained(tmpdirname)
model = model_class.from_pretrained(tmpdirname)
outputs_dict = model(input_ids)
hidden_states = outputs_dict[0]
# Add a dense layer on top to test integration with other keras modules
outputs = tf.keras.layers.Dense(2, activation="softmax", name="outputs")(hidden_states)
# Compile extended model
extended_model = tf.keras.Model(inputs=[input_ids], outputs=[outputs])
extended_model.compile(optimizer=optimizer, loss=loss, metrics=[metric])
@is_pt_tf_cross_test
@require_tokenizers
......
......@@ -152,7 +152,7 @@ class TFModelTesterMixin:
if model.config.is_encoder_decoder:
expected_arg_names = [
"inputs",
"input_ids",
"attention_mask",
"decoder_input_ids",
"decoder_attention_mask",
......@@ -161,7 +161,7 @@ class TFModelTesterMixin:
self.assertListEqual(arg_names[:5], expected_arg_names)
else:
expected_arg_names = ["inputs"]
expected_arg_names = ["input_ids"]
self.assertListEqual(arg_names[:1], expected_arg_names)
@slow
......@@ -753,7 +753,7 @@ class TFModelTesterMixin:
def test_lm_head_model_random_no_beam_search_generate(self):
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
input_ids = inputs_dict["input_ids"] if "input_ids" in inputs_dict else inputs_dict["inputs"]
input_ids = inputs_dict["input_ids"]
# iterate over all generative models
for model_class in self.all_generative_model_classes:
......
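The `test_forward_signature` change above now expects `input_ids` as the first argument of every converted model. A small sketch of how such a check can be expressed with `inspect` (illustrative, not the exact test body):

import inspect

def leading_call_arguments(model_class, count=1):
    # Parameter names of the model's `call`, in declaration order, skipping `self`.
    signature = inspect.signature(model_class.call)
    names = [name for name in signature.parameters if name != "self"]
    return names[:count]

# After this refactor every converted model is expected to start with "input_ids"
# (followed by "attention_mask", "decoder_input_ids", ... for encoder-decoder models).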