"git@developer.sourcefind.cn:chenpangpang/ComfyUI.git" did not exist on "870fae62e7a784fcc065382e31f173cc048da8fe"
Unverified commit c67d1a02 authored by Sylvain Gugger, committed via GitHub

Tf model outputs (#6247)

* TF outputs and test on BERT

* Albert to DistilBert

* All remaining TF models except T5

* Documentation

* One file forgotten

* Add new models and fix issues

* Quality improvements

* Add T5

* A bit of cleanup

* Fix for slow tests

* Style
parent bd0eab35
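
In short, this commit brings the TensorFlow models in line with their PyTorch counterparts: the TF models gain a `return_dict` argument and, when it is set, return a typed `ModelOutput` subclass with named fields instead of a positional tuple. A minimal usage sketch of the new behavior (the checkpoint name is only an example, and this assumes TensorFlow 2.x plus transformers at this commit):

```python
from transformers import BertTokenizer, TFBertModel

tokenizer = BertTokenizer.from_pretrained("bert-base-cased")
model = TFBertModel.from_pretrained("bert-base-cased")
inputs = tokenizer("Hello, world!", return_tensors="tf")

# Legacy behavior: a plain tuple, indexed by position.
last_hidden_state = model(inputs, return_dict=False)[0]

# New behavior introduced here: a typed output with named fields.
outputs = model(inputs, return_dict=True)
print(outputs.last_hidden_state.shape)  # (batch_size, sequence_length, hidden_size)
```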
@@ -50,7 +50,10 @@ AlbertTokenizer
 Albert specific outputs
 ~~~~~~~~~~~~~~~~~~~~~~~
-.. autoclass:: transformers.modeling_albert.AlbertForPretrainingOutput
+.. autoclass:: transformers.modeling_albert.AlbertForPreTrainingOutput
     :members:
+.. autoclass:: transformers.modeling_tf_albert.TFAlbertForPreTrainingOutput
+    :members:
@@ -57,7 +57,10 @@ BertTokenizerFast
 Bert specific outputs
 ~~~~~~~~~~~~~~~~~~~~~
-.. autoclass:: transformers.modeling_bert.BertForPretrainingOutput
+.. autoclass:: transformers.modeling_bert.BertForPreTrainingOutput
     :members:
+.. autoclass:: transformers.modeling_tf_bert.TFBertForPreTrainingOutput
+    :members:
@@ -74,7 +74,10 @@ ElectraTokenizerFast
 Electra specific outputs
 ~~~~~~~~~~~~~~~~~~~~~~~~
-.. autoclass:: transformers.modeling_electra.ElectraForPretrainingOutput
+.. autoclass:: transformers.modeling_electra.ElectraForPreTrainingOutput
     :members:
+.. autoclass:: transformers.modeling_tf_electra.TFElectraForPreTrainingOutput
+    :members:
@@ -106,6 +109,13 @@ ElectraForSequenceClassification
     :members:
+ElectraForMultipleChoice
+~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
+.. autoclass:: transformers.ElectraForMultipleChoice
+    :members:
 ElectraForTokenClassification
 ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
@@ -141,6 +151,20 @@ TFElectraForMaskedLM
     :members:
+TFElectraForSequenceClassification
+~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
+.. autoclass:: transformers.TFElectraForSequenceClassification
+    :members:
+TFElectraForMultipleChoice
+~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
+.. autoclass:: transformers.TFElectraForMultipleChoice
+    :members:
 TFElectraForTokenClassification
 ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
@@ -77,6 +77,9 @@ OpenAI specific outputs
 .. autoclass:: transformers.modeling_openai.OpenAIGPTDoubleHeadsModelOutput
     :members:
+.. autoclass:: transformers.modeling_tf_openai.TFOpenAIGPTDoubleHeadsModelOutput
+    :members:
 OpenAIGPTModel
 ~~~~~~~~~~~~~~~~~~~~~~~~~
@@ -64,6 +64,9 @@ GPT2 specific outputs
 .. autoclass:: transformers.modeling_gpt2.GPT2DoubleHeadsModelOutput
     :members:
+.. autoclass:: transformers.modeling_tf_gpt2.TFGPT2DoubleHeadsModelOutput
+    :members:
 GPT2Model
 ~~~~~~~~~~~~~~~~~~~~~
@@ -59,7 +59,10 @@ MobileBertTokenizerFast
 MobileBert specific outputs
 ~~~~~~~~~~~~~~~~~~~~~~~~~~~
-.. autoclass:: transformers.modeling_mobilebert.MobileBertForPretrainingOutput
+.. autoclass:: transformers.modeling_mobilebert.MobileBertForPreTrainingOutput
     :members:
+.. autoclass:: transformers.modeling_tf_mobilebert.TFMobileBertForPreTrainingOutput
+    :members:
@@ -63,6 +63,12 @@ TransfoXL specific outputs
 .. autoclass:: transformers.modeling_transfo_xl.TransfoXLLMHeadModelOutput
     :members:
+.. autoclass:: transformers.modeling_tf_transfo_xl.TFTransfoXLModelOutput
+    :members:
+.. autoclass:: transformers.modeling_tf_transfo_xl.TFTransfoXLLMHeadModelOutput
+    :members:
 TransfoXLModel
 ~~~~~~~~~~~~~~~~~~~~~~~~~~
@@ -74,6 +74,24 @@ XLNet specific outputs
 .. autoclass:: transformers.modeling_xlnet.XLNetForQuestionAnsweringOutput
     :members:
+.. autoclass:: transformers.modeling_tf_xlnet.TFXLNetModelOutput
+    :members:
+.. autoclass:: transformers.modeling_tf_xlnet.TFXLNetLMHeadModelOutput
+    :members:
+.. autoclass:: transformers.modeling_tf_xlnet.TFXLNetForSequenceClassificationOutput
+    :members:
+.. autoclass:: transformers.modeling_tf_xlnet.TFXLNetForMultipleChoiceOutput
+    :members:
+.. autoclass:: transformers.modeling_tf_xlnet.TFXLNetForTokenClassificationOutput
+    :members:
+.. autoclass:: transformers.modeling_tf_xlnet.TFXLNetForQuestionAnsweringSimpleOutput
+    :members:
 XLNetModel
 ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
@@ -190,7 +190,7 @@ def add_end_docstrings(*docstr):
     return docstring_decorator
-RETURN_INTRODUCTION = r"""
+PT_RETURN_INTRODUCTION = r"""
     Returns:
         :class:`~{full_output_type}` or :obj:`tuple(torch.FloatTensor)`:
         A :class:`~{full_output_type}` (if ``return_dict=True`` is passed or when ``config.return_dict=True``) or a
@@ -200,6 +200,16 @@ RETURN_INTRODUCTION = r"""
 """
+TF_RETURN_INTRODUCTION = r"""
+    Returns:
+        :class:`~{full_output_type}` or :obj:`tuple(tf.Tensor)`:
+        A :class:`~{full_output_type}` (if ``return_dict=True`` is passed or when ``config.return_dict=True``) or a
+        tuple of :obj:`tf.Tensor` comprising various elements depending on the configuration
+        (:class:`~transformers.{config_class}`) and inputs.
+"""
 def _get_indent(t):
     """Returns the indentation in the first line of t"""
     search = re.search(r"^(\s*)\S", t)
@@ -249,7 +259,8 @@ def _prepare_output_docstrings(output_type, config_class):
     # Add the return introduction
     full_output_type = f"{output_type.__module__}.{output_type.__name__}"
-    intro = RETURN_INTRODUCTION.format(full_output_type=full_output_type, config_class=config_class)
+    intro = TF_RETURN_INTRODUCTION if output_type.__name__.startswith("TF") else PT_RETURN_INTRODUCTION
+    intro = intro.format(full_output_type=full_output_type, config_class=config_class)
     return intro + docstrings
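
The file_utils.py hunks above make the shared docstring machinery framework-aware: the return-section introduction is now picked by the `TF` prefix of the output class's name. A self-contained sketch of that dispatch, with toy classes and shortened templates standing in for the real ones:

```python
# Toy stand-ins; the real templates live in transformers.file_utils.
PT_RETURN_INTRODUCTION = "Returns {full_output_type} or tuple(torch.FloatTensor)."
TF_RETURN_INTRODUCTION = "Returns {full_output_type} or tuple(tf.Tensor)."

class BertForPreTrainingOutput: ...
class TFBertForPreTrainingOutput: ...

def pick_intro(output_type):
    # TF output classes are recognized purely by their name prefix.
    template = TF_RETURN_INTRODUCTION if output_type.__name__.startswith("TF") else PT_RETURN_INTRODUCTION
    return template.format(full_output_type=f"{output_type.__module__}.{output_type.__name__}")

print(pick_intro(BertForPreTrainingOutput))    # torch flavour
print(pick_intro(TFBertForPreTrainingOutput))  # tf flavour
```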
@@ -407,9 +407,9 @@ class AlbertPreTrainedModel(PreTrainedModel):
 @dataclass
-class AlbertForPretrainingOutput(ModelOutput):
+class AlbertForPreTrainingOutput(ModelOutput):
     """
-    Output type of :class:`~transformers.AlbertForPretrainingModel`.
+    Output type of :class:`~transformers.AlbertForPreTrainingModel`.
     Args:
         loss (`optional`, returned when ``labels`` is provided, ``torch.FloatTensor`` of shape :obj:`(1,)`):
@@ -643,7 +643,7 @@ class AlbertForPreTraining(AlbertPreTrainedModel):
         return self.predictions.decoder
     @add_start_docstrings_to_callable(ALBERT_INPUTS_DOCSTRING)
-    @replace_return_docstrings(output_type=AlbertForPretrainingOutput, config_class=_CONFIG_FOR_DOC)
+    @replace_return_docstrings(output_type=AlbertForPreTrainingOutput, config_class=_CONFIG_FOR_DOC)
     def forward(
         self,
         input_ids=None,
@@ -728,7 +728,7 @@ class AlbertForPreTraining(AlbertPreTrainedModel):
             output = (prediction_scores, sop_scores) + outputs[2:]
             return ((total_loss,) + output) if total_loss is not None else output
-        return AlbertForPretrainingOutput(
+        return AlbertForPreTrainingOutput(
             loss=total_loss,
             prediction_logits=prediction_scores,
             sop_logits=sop_scores,
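
The renamed `AlbertForPreTrainingOutput` follows the pattern every model hunk in this commit relies on: a `@dataclass` with named fields that the model returns when `return_dict=True`, while `return_dict=False` keeps the legacy tuple with `None` entries dropped. A simplified stand-in, not the real `transformers.ModelOutput` (which additionally supports dict- and index-style access):

```python
from dataclasses import dataclass, fields
from typing import Any, Optional

@dataclass
class PreTrainingOutput:
    loss: Optional[Any] = None
    prediction_logits: Optional[Any] = None
    sop_logits: Optional[Any] = None

    def to_tuple(self):
        # Mirror the old tuple return: unset fields are simply omitted.
        return tuple(getattr(self, f.name) for f in fields(self) if getattr(self, f.name) is not None)

out = PreTrainingOutput(prediction_logits=[0.1, 0.9], sop_logits=[0.7, 0.3])
assert out.to_tuple() == ([0.1, 0.9], [0.7, 0.3])  # no loss, so it is dropped
```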
@@ -586,9 +586,9 @@ class BertPreTrainedModel(PreTrainedModel):
 @dataclass
-class BertForPretrainingOutput(ModelOutput):
+class BertForPreTrainingOutput(ModelOutput):
     """
-    Output type of :class:`~transformers.BertForPretrainingModel`.
+    Output type of :class:`~transformers.BertForPreTrainingModel`.
     Args:
         loss (`optional`, returned when ``labels`` is provided, ``torch.FloatTensor`` of shape :obj:`(1,)`):
@@ -837,7 +837,7 @@ class BertForPreTraining(BertPreTrainedModel):
         return self.cls.predictions.decoder
     @add_start_docstrings_to_callable(BERT_INPUTS_DOCSTRING.format("(batch_size, sequence_length)"))
-    @replace_return_docstrings(output_type=BertForPretrainingOutput, config_class=_CONFIG_FOR_DOC)
+    @replace_return_docstrings(output_type=BertForPreTrainingOutput, config_class=_CONFIG_FOR_DOC)
     def forward(
         self,
         input_ids=None,
@@ -918,7 +918,7 @@ class BertForPreTraining(BertPreTrainedModel):
             output = (prediction_scores, seq_relationship_score) + outputs[2:]
             return ((total_loss,) + output) if total_loss is not None else output
-        return BertForPretrainingOutput(
+        return BertForPreTrainingOutput(
             loss=total_loss,
             prediction_logits=prediction_scores,
             seq_relationship_logits=seq_relationship_score,
@@ -188,9 +188,9 @@ class ElectraPreTrainedModel(BertPreTrainedModel):
 @dataclass
-class ElectraForPretrainingOutput(ModelOutput):
+class ElectraForPreTrainingOutput(ModelOutput):
     """
-    Output type of :class:`~transformers.ElectraForPretrainingModel`.
+    Output type of :class:`~transformers.ElectraForPreTrainingModel`.
     Args:
         loss (`optional`, returned when ``labels`` is provided, ``torch.FloatTensor`` of shape :obj:`(1,)`):
@@ -496,7 +496,7 @@ class ElectraForPreTraining(ElectraPreTrainedModel):
         self.init_weights()
     @add_start_docstrings_to_callable(ELECTRA_INPUTS_DOCSTRING)
-    @replace_return_docstrings(output_type=ElectraForPretrainingOutput, config_class=_CONFIG_FOR_DOC)
+    @replace_return_docstrings(output_type=ElectraForPreTrainingOutput, config_class=_CONFIG_FOR_DOC)
     def forward(
         self,
         input_ids=None,
@@ -562,7 +562,7 @@ class ElectraForPreTraining(ElectraPreTrainedModel):
             output = (logits,) + discriminator_hidden_states[1:]
             return ((loss,) + output) if loss is not None else output
-        return ElectraForPretrainingOutput(
+        return ElectraForPreTrainingOutput(
             loss=loss,
             logits=logits,
             hidden_states=discriminator_hidden_states.hidden_states,
@@ -850,7 +850,7 @@ class ElectraForQuestionAnswering(ElectraPreTrainedModel):
 @add_start_docstrings(
     """ELECTRA Model with a multiple choice classification head on top (a linear layer on top of
     the pooled output and a softmax) e.g. for RocStories/SWAG tasks. """,
-    ELECTRA_INPUTS_DOCSTRING,
+    ELECTRA_START_DOCSTRING,
 )
 class ElectraForMultipleChoice(ElectraPreTrainedModel):
     def __init__(self, config):
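
The last ELECTRA hunk is a small bug fix: the class decorator was being fed `ELECTRA_INPUTS_DOCSTRING` where the model-level `ELECTRA_START_DOCSTRING` belongs. The helper simply prepends whatever strings it receives to the existing docstring, so passing the wrong template would document the class with argument descriptions instead of the model overview. A sketch of the helper's behavior, simplified from transformers.file_utils:

```python
def add_start_docstrings(*docstr):
    # Prepend the given doc fragments to the decorated object's docstring.
    def docstring_decorator(fn):
        fn.__doc__ = "".join(docstr) + (fn.__doc__ or "")
        return fn
    return docstring_decorator

@add_start_docstrings("Model overview goes first. ")
class ElectraForMultipleChoiceSketch:  # hypothetical stand-in class
    """Task-specific details follow."""

print(ElectraForMultipleChoiceSketch.__doc__)
# Model overview goes first. Task-specific details follow.
```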
@@ -685,9 +685,9 @@ class MobileBertPreTrainedModel(PreTrainedModel):
 @dataclass
-class MobileBertForPretrainingOutput(ModelOutput):
+class MobileBertForPreTrainingOutput(ModelOutput):
     """
-    Output type of :class:`~transformers.MobileBertForPretrainingModel`.
+    Output type of :class:`~transformers.MobileBertForPreTrainingModel`.
     Args:
         loss (`optional`, returned when ``labels`` is provided, ``torch.FloatTensor`` of shape :obj:`(1,)`):
@@ -948,7 +948,7 @@ class MobileBertForPreTraining(MobileBertPreTrainedModel):
         self._tie_or_clone_weights(output_embeddings, self.get_input_embeddings())
     @add_start_docstrings_to_callable(MOBILEBERT_INPUTS_DOCSTRING)
-    @replace_return_docstrings(output_type=MobileBertForPretrainingOutput, config_class=_CONFIG_FOR_DOC)
+    @replace_return_docstrings(output_type=MobileBertForPreTrainingOutput, config_class=_CONFIG_FOR_DOC)
     def forward(
         self,
         input_ids=None,
@@ -1018,7 +1018,7 @@ class MobileBertForPreTraining(MobileBertPreTrainedModel):
             output = (prediction_scores, seq_relationship_score) + outputs[2:]
             return ((total_loss,) + output) if total_loss is not None else output
-        return MobileBertForPretrainingOutput(
+        return MobileBertForPreTrainingOutput(
             loss=total_loss,
             prediction_logits=prediction_scores,
             seq_relationship_logits=seq_relationship_score,
@@ -973,7 +973,7 @@ class T5Model(T5PreTrainedModel):
                 output_hidden_states=output_hidden_states,
                 return_dict=return_dict,
             )
-        elif not return_dict and not isinstance(encoder_outputs, BaseModelOutput):
+        elif return_dict and not isinstance(encoder_outputs, BaseModelOutput):
             encoder_outputs = BaseModelOutput(
                 last_hidden_state=encoder_outputs[0],
                 hidden_states=encoder_outputs[1] if len(encoder_outputs) > 1 else None,
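
The T5 hunk fixes an inverted condition: caller-supplied `encoder_outputs` may arrive as a plain tuple, and they only need wrapping into `BaseModelOutput` when `return_dict=True`; the old code wrapped exactly when the caller had asked for tuples. A sketch of the corrected logic, with a simplified stand-in for `BaseModelOutput`:

```python
from dataclasses import dataclass
from typing import Any, Optional

@dataclass
class BaseModelOutputSketch:  # simplified stand-in, not transformers' class
    last_hidden_state: Any
    hidden_states: Optional[Any] = None

def normalize_encoder_outputs(encoder_outputs, return_dict):
    if return_dict and not isinstance(encoder_outputs, BaseModelOutputSketch):
        encoder_outputs = BaseModelOutputSketch(
            last_hidden_state=encoder_outputs[0],
            hidden_states=encoder_outputs[1] if len(encoder_outputs) > 1 else None,
        )
    return encoder_outputs

print(normalize_encoder_outputs(("enc_hidden",), return_dict=True))   # wrapped
print(normalize_encoder_outputs(("enc_hidden",), return_dict=False))  # left as a tuple
```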
@@ -62,8 +62,6 @@ CAMEMBERT_START_DOCSTRING = r"""
         config (:class:`~transformers.CamembertConfig`): Model configuration class with all the parameters of the
             model. Initializing with a config file does not load the weights associated with the model, only the configuration.
             Check out the :meth:`~transformers.PreTrainedModel.from_pretrained` method to load the model weights.
-        output_attentions (:obj:`bool`, `optional`, defaults to :obj:`None`):
-            If set to ``True``, the attentions tensors of all attention layers are returned. See ``attentions`` under returned tensors for more detail.
 """
@@ -23,6 +23,7 @@ import tensorflow as tf
 from .configuration_ctrl import CTRLConfig
 from .file_utils import add_code_sample_docstrings, add_start_docstrings, add_start_docstrings_to_callable
+from .modeling_tf_outputs import TFBaseModelOutputWithPast, TFCausalLMOutputWithPast
 from .modeling_tf_utils import (
     TFCausalLanguageModelingLoss,
     TFPreTrainedModel,
@@ -35,7 +36,8 @@ from .tokenization_utils import BatchEncoding
 logger = logging.getLogger(__name__)
-_TOKENIZER_FOR_DOC = "CtrlTokenizer"
+_CONFIG_FOR_DOC = "CTRLConfig"
+_TOKENIZER_FOR_DOC = "CTRLTokenizer"
 TF_CTRL_PRETRAINED_MODEL_ARCHIVE_LIST = [
     "ctrl"
@@ -207,6 +209,7 @@ class TFCTRLMainLayer(tf.keras.layers.Layer):
         self.output_hidden_states = config.output_hidden_states
         self.output_attentions = config.output_attentions
         self.use_cache = config.use_cache
+        self.return_dict = config.use_return_dict
         self.d_model_size = config.n_embd
         self.num_layers = config.n_layer
@@ -260,6 +263,7 @@ class TFCTRLMainLayer(tf.keras.layers.Layer):
         use_cache=None,
         output_attentions=None,
         output_hidden_states=None,
+        return_dict=None,
         training=False,
     ):
@@ -274,7 +278,8 @@ class TFCTRLMainLayer(tf.keras.layers.Layer):
             use_cache = inputs[7] if len(inputs) > 7 else use_cache
             output_attentions = inputs[8] if len(inputs) > 8 else output_attentions
             output_hidden_states = inputs[9] if len(inputs) > 9 else output_hidden_states
-            assert len(inputs) <= 10, "Too many inputs."
+            return_dict = inputs[10] if len(inputs) > 10 else return_dict
+            assert len(inputs) <= 11, "Too many inputs."
         elif isinstance(inputs, (dict, BatchEncoding)):
             input_ids = inputs.get("input_ids")
             past = inputs.get("past", past)
@@ -286,13 +291,15 @@ class TFCTRLMainLayer(tf.keras.layers.Layer):
             use_cache = inputs.get("use_cache", use_cache)
             output_attentions = inputs.get("output_attentions", output_attentions)
             output_hidden_states = inputs.get("output_hidden_states", output_hidden_states)
-            assert len(inputs) <= 10, "Too many inputs."
+            return_dict = inputs.get("return_dict", return_dict)
+            assert len(inputs) <= 11, "Too many inputs."
         else:
             input_ids = inputs
         output_attentions = output_attentions if output_attentions is not None else self.output_attentions
         output_hidden_states = output_hidden_states if output_hidden_states is not None else self.output_hidden_states
         use_cache = use_cache if use_cache is not None else self.use_cache
+        return_dict = return_dict if return_dict is not None else self.return_dict
         # If using past key value states, only the last tokens
         # should be given as an input
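
These two hunks thread `return_dict` through the flexible calling convention TF models support: positional tuples/lists, dicts (or `BatchEncoding`), or a bare tensor, with `return_dict` taking the new eleventh positional slot. A condensed, hypothetical sketch of that unpacking (only a few of the real arguments shown):

```python
def unpack_call_args(inputs, output_hidden_states=None, return_dict=None):
    # Hypothetical reduction of TFCTRLMainLayer.call's input handling.
    if isinstance(inputs, (tuple, list)):
        # Positional calling: every known argument has a fixed slot.
        output_hidden_states = inputs[9] if len(inputs) > 9 else output_hidden_states
        return_dict = inputs[10] if len(inputs) > 10 else return_dict
        assert len(inputs) <= 11, "Too many inputs."
        input_ids = inputs[0]
    elif isinstance(inputs, dict):
        # Keyword-style calling through a dict or BatchEncoding.
        output_hidden_states = inputs.get("output_hidden_states", output_hidden_states)
        return_dict = inputs.get("return_dict", return_dict)
        input_ids = inputs.get("input_ids")
    else:
        input_ids = inputs  # a bare tensor of input ids
    return input_ids, output_hidden_states, return_dict
```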
@@ -374,9 +381,9 @@ class TFCTRLMainLayer(tf.keras.layers.Layer):
         hidden_states = self.dropout(hidden_states, training=training)
         output_shape = input_shape + [shape_list(hidden_states)[-1]]
-        presents = ()
-        all_hidden_states = ()
-        all_attentions = []
+        presents = () if use_cache else None
+        all_hidden_states = () if output_hidden_states else None
+        all_attentions = () if output_attentions else None
         for i, (h, layer_past) in enumerate(zip(self.h, past)):
             if output_hidden_states:
                 all_hidden_states = all_hidden_states + (tf.reshape(hidden_states, output_shape),)
@@ -396,24 +403,27 @@ class TFCTRLMainLayer(tf.keras.layers.Layer):
                 presents = presents + (present,)
             if output_attentions:
-                all_attentions.append(outputs[2])
+                all_attentions = all_attentions + (outputs[2],)
         hidden_states = self.layernorm(hidden_states)
         hidden_states = tf.reshape(hidden_states, output_shape)
         if output_hidden_states:
             all_hidden_states = all_hidden_states + (hidden_states,)
-        outputs = (hidden_states,)
-        if use_cache:
-            outputs = outputs + (presents,)
-        if output_hidden_states:
-            outputs = outputs + (all_hidden_states,)
         if output_attentions:
             # let the number of heads free (-1) so we can extract attention even after head pruning
             attention_output_shape = input_shape[:-1] + [-1] + shape_list(all_attentions[0])[-2:]
             all_attentions = tuple(tf.reshape(t, attention_output_shape) for t in all_attentions)
-            outputs = outputs + (all_attentions,)
-        return outputs
+        if not return_dict:
+            return tuple(v for v in [hidden_states, presents, all_hidden_states, all_attentions] if v is not None)
+        return TFBaseModelOutputWithPast(
+            last_hidden_state=hidden_states,
+            past_key_values=presents,
+            hidden_states=all_hidden_states,
+            attentions=all_attentions,
+        )
 class TFCTRLPreTrainedModel(TFPreTrainedModel):
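
Initializing each accumulator to `None` when its flag is off is what keeps the tuple fallback compact: one comprehension drops everything that was never collected, and the same values feed the typed output unchanged. A toy sketch of the pattern:

```python
def collect_outputs(use_cache, output_hidden_states, output_attentions, return_dict):
    # Toy version of the accumulator logic in TFCTRLMainLayer.call.
    hidden_states = "h_final"
    presents = ("p0", "p1") if use_cache else None
    all_hidden_states = ("h0", "h1") if output_hidden_states else None
    all_attentions = ("a0", "a1") if output_attentions else None
    if not return_dict:
        # Unused accumulators are None and silently fall out of the tuple.
        return tuple(v for v in [hidden_states, presents, all_hidden_states, all_attentions] if v is not None)
    return {"last_hidden_state": hidden_states, "past_key_values": presents,
            "hidden_states": all_hidden_states, "attentions": all_attentions}

print(collect_outputs(True, False, False, return_dict=False))  # ('h_final', ('p0', 'p1'))
```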
@@ -503,6 +513,11 @@ CTRL_INPUTS_DOCSTRING = r"""
             (if set to :obj:`False`) for evaluation.
         output_attentions (:obj:`bool`, `optional`, defaults to :obj:`None`):
             If set to ``True``, the attentions tensors of all attention layers are returned. See ``attentions`` under returned tensors for more detail.
+        output_hidden_states (:obj:`bool`, `optional`, defaults to :obj:`None`):
+            If set to ``True``, the hidden states of all layers are returned. See ``hidden_states`` under returned tensors for more detail.
+        return_dict (:obj:`bool`, `optional`, defaults to :obj:`None`):
+            If set to ``True``, the model will return a :class:`~transformers.file_utils.ModelOutput` instead of a
+            plain tuple.
 """
@@ -516,29 +531,13 @@ class TFCTRLModel(TFCTRLPreTrainedModel):
         self.transformer = TFCTRLMainLayer(config, name="transformer")
     @add_start_docstrings_to_callable(CTRL_INPUTS_DOCSTRING)
-    @add_code_sample_docstrings(tokenizer_class=_TOKENIZER_FOR_DOC, checkpoint="ctrl")
+    @add_code_sample_docstrings(
+        tokenizer_class=_TOKENIZER_FOR_DOC,
+        checkpoint="ctrl",
+        output_type=TFBaseModelOutputWithPast,
+        config_class=_CONFIG_FOR_DOC,
+    )
     def call(self, inputs, **kwargs):
-        r"""
-        Return:
-            :obj:`tuple(tf.Tensor)` comprising various elements depending on the configuration (:class:`~transformers.CTRLConfig`) and inputs:
-            last_hidden_state (:obj:`tf.Tensor` of shape :obj:`(batch_size, sequence_length, hidden_size)`):
-                Sequence of hidden-states at the last layer of the model.
-            past (:obj:`List[tf.Tensor]` of length :obj:`config.n_layers` with each tensor of shape :obj:`(2, batch_size, num_heads, sequence_length, embed_size_per_head)`):
-                Contains pre-computed hidden-states (key and values in the attention blocks).
-                Can be used (see `past` input) to speed up sequential decoding. The token ids which have their past given to this model
-                should not be passed as input ids as they have already been computed.
-            hidden_states (:obj:`tuple(tf.Tensor)`, `optional`, returned when ``output_hidden_states=True`` is passed or when ``config.output_hidden_states=True``):
-                tuple of :obj:`tf.Tensor` (one for the output of the embeddings + one for the output of each layer)
-                of shape :obj:`(batch_size, sequence_length, hidden_size)`.
-                Hidden-states of the model at the output of each layer plus the initial embedding outputs.
-            attentions (:obj:`tuple(tf.Tensor)`, `optional`, returned when ``output_attentions=True`` is passed or when ``config.output_attentions=True``):
-                tuple of :obj:`tf.Tensor` (one for each layer) of shape
-                :obj:`(batch_size, num_heads, sequence_length, sequence_length)`:
-                Attentions weights after the attention softmax, used to compute the weighted average in the self-attention
-                heads.
-        """
         outputs = self.transformer(inputs, **kwargs)
         return outputs
@@ -585,7 +584,12 @@ class TFCTRLLMHeadModel(TFCTRLPreTrainedModel, TFCausalLanguageModelingLoss):
         return {"inputs": inputs, "past": past, "use_cache": kwargs["use_cache"]}
     @add_start_docstrings_to_callable(CTRL_INPUTS_DOCSTRING)
-    @add_code_sample_docstrings(tokenizer_class=_TOKENIZER_FOR_DOC, checkpoint="ctrl")
+    @add_code_sample_docstrings(
+        tokenizer_class=_TOKENIZER_FOR_DOC,
+        checkpoint="ctrl",
+        output_type=TFCausalLMOutputWithPast,
+        config_class=_CONFIG_FOR_DOC,
+    )
     def call(
         self,
         inputs,
@@ -598,6 +602,7 @@ class TFCTRLLMHeadModel(TFCTRLPreTrainedModel, TFCausalLanguageModelingLoss):
         use_cache=None,
         output_attentions=None,
         output_hidden_states=None,
+        return_dict=None,
         labels=None,
         training=False,
     ):
@@ -605,31 +610,12 @@ class TFCTRLLMHeadModel(TFCTRLPreTrainedModel, TFCausalLanguageModelingLoss):
         labels (:obj:`tf.Tensor` of shape :obj:`(batch_size, sequence_length)`, `optional`, defaults to :obj:`None`):
             Labels for computing the cross entropy classification loss.
             Indices should be in ``[0, ..., config.vocab_size - 1]``.
-        Return:
-            :obj:`tuple(tf.Tensor)` comprising various elements depending on the configuration (:class:`~transformers.CTRLConfig`) and inputs:
-            prediction_scores (:obj:`tf.Tensor` of shape :obj:`(batch_size, sequence_length, config.vocab_size)`):
-                Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax).
-            past (:obj:`List[tf.Tensor]` of length :obj:`config.n_layers` with each tensor of shape :obj:`(2, batch_size, num_heads, sequence_length, embed_size_per_head)`):
-                Contains pre-computed hidden-states (key and values in the attention blocks).
-                Can be used (see `past` input) to speed up sequential decoding. The token ids which have their past given to this model
-                should not be passed as input ids as they have already been computed.
-            hidden_states (:obj:`tuple(tf.Tensor)`, `optional`, returned when ``output_hidden_states=True`` is passed or when ``config.output_hidden_states=True``):
-                tuple of :obj:`tf.Tensor` (one for the output of the embeddings + one for the output of each layer)
-                of shape :obj:`(batch_size, sequence_length, hidden_size)`.
-                Hidden-states of the model at the output of each layer plus the initial embedding outputs.
-            attentions (:obj:`tuple(tf.Tensor)`, `optional`, returned when ``output_attentions=True`` is passed or when ``config.output_attentions=True``):
-                tuple of :obj:`tf.Tensor` (one for each layer) of shape
-                :obj:`(batch_size, num_heads, sequence_length, sequence_length)`:
-                Attentions weights after the attention softmax, used to compute the weighted average in the self-attention
-                heads.
         """
+        return_dict = return_dict if return_dict is not None else self.transformer.return_dict
         if isinstance(inputs, (tuple, list)):
-            labels = inputs[10] if len(inputs) > 10 else labels
-            if len(inputs) > 10:
-                inputs = inputs[:10]
+            labels = inputs[11] if len(inputs) > 11 else labels
+            if len(inputs) > 11:
+                inputs = inputs[:11]
         elif isinstance(inputs, (dict, BatchEncoding)):
             labels = inputs.pop("labels", labels)
@@ -644,6 +630,7 @@ class TFCTRLLMHeadModel(TFCTRLPreTrainedModel, TFCausalLanguageModelingLoss):
             use_cache=use_cache,
             output_attentions=output_attentions,
             output_hidden_states=output_hidden_states,
+            return_dict=return_dict,
             training=training,
         )
@@ -651,12 +638,21 @@ class TFCTRLLMHeadModel(TFCTRLPreTrainedModel, TFCausalLanguageModelingLoss):
         logits = self.lm_head(hidden_states)
-        outputs = (logits,) + transformer_outputs[1:]
+        loss = None
         if labels is not None:
             # shift labels to the left and cut last logit token
             logits = logits[:, :-1]
             labels = labels[:, 1:]
             loss = self.compute_loss(labels, logits)
-            outputs = (loss,) + outputs
-        return outputs  # lm_logits, presents, (all hidden_states), (attentions)
+        if not return_dict:
+            output = (logits,) + transformer_outputs[1:]
+            return ((loss,) + output) if loss is not None else output
+        return TFCausalLMOutputWithPast(
+            loss=loss,
+            logits=logits,
+            past_key_values=transformer_outputs.past_key_values,
+            hidden_states=transformer_outputs.hidden_states,
+            attentions=transformer_outputs.attentions,
+        )
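
The loss path above uses the standard causal-LM alignment: the logits at position i are scored against the token at position i + 1, so the last logit column and the first label column are dropped before `compute_loss`. A NumPy sketch of just that shift:

```python
import numpy as np

batch_size, seq_len, vocab_size = 2, 5, 7
logits = np.random.rand(batch_size, seq_len, vocab_size)
labels = np.random.randint(vocab_size, size=(batch_size, seq_len))

shifted_logits = logits[:, :-1]  # nothing comes after the last position
shifted_labels = labels[:, 1:]   # nothing predicts the first token
assert shifted_logits.shape[:2] == shifted_labels.shape == (2, 4)
```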