Unverified commit edfd82f5 authored by Sylvain Gugger, committed by GitHub

Change model outputs types to self-document outputs (#5438)

* [WIP] Proposal for model outputs

* All Bert models

* Make CI green maybe?

* Fix ONNX test

* Isolate ModelOutput from pt and tf

* Formatting

* Add Electra models

* Auto-generate docstrings from outputs

* Add TF outputs

* Add some BERT models

* Revert TF side

* Remove last traces of TF changes

* Fail with a clear error message

* Add Albert and work through Bart

* Add CTRL and DistilBert

* Formatting

* Progress on Bart

* Renames and finish Bart

* Formatting

* Fix last test

* Add DPR

* Finish Electra and add FlauBERT

* Add GPT2

* Add Longformer

* Add MMBT

* Add MobileBert

* Add GPT

* Formatting

* Add Reformer

* Add Roberta

* Add T5

* Add Transformer XL

* Fix test

* Add XLM + fix XLMForTokenClassification

* Style + XLMRoberta

* Add XLNet

* Formatting

* Add doc of return_tuple arg
parent fa265230
@@ -386,6 +386,9 @@ def start_memory_tracing(
elif isinstance(events_to_trace, (list, tuple)) and event not in events_to_trace:
return traceit
if "__name__" not in frame.f_globals:
return traceit
# Filter modules
name = frame.f_globals["__name__"]
if not isinstance(name, str):
...
@@ -49,6 +49,8 @@ class PretrainedConfig(object):
Whether or not the model should returns all attentions.
use_cache (:obj:`bool`, `optional`, defaults to :obj:`False`):
Whether or not the model should return the last key/values attentions (not used by all models).
return_tuple (:obj:`bool`, `optional`, defaults to :obj:`False`):
Whether or not the model should return tuples instead of :obj:`ModelOutput` objects.
is_encoder_decoder (:obj:`bool`, `optional`, defaults to :obj:`False`):
Whether the model is used as an encoder/decoder or not.
is_decoder (:obj:`bool`, `optional`, defaults to :obj:`False`):
@@ -131,6 +133,7 @@ class PretrainedConfig(object):
def __init__(self, **kwargs):
# Attributes with defaults
self.return_tuple = kwargs.pop("return_tuple", False)
self.output_hidden_states = kwargs.pop("output_hidden_states", False)
self.output_attentions = kwargs.pop("output_attentions", False)
self.use_cache = kwargs.pop("use_cache", True) # Not used by all models
@@ -190,6 +193,11 @@ class PretrainedConfig(object):
logger.error("Can't set {} with value {} for {}".format(key, value, self))
raise err
@property
def use_return_tuple(self):
# If torchscript is set, force return_tuple to avoid jit errors
return self.return_tuple or self.torchscript
@property
def num_labels(self) -> int:
return len(self.id2label)
...
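A minimal sketch of what the new flag does in practice (based only on the `PretrainedConfig` changes above; `BertConfig` is just a convenient concrete subclass):

    from transformers import BertConfig

    # `return_tuple` defaults to False, so models now return ModelOutput objects.
    config = BertConfig(return_tuple=True)   # opt back into the old tuple outputs
    assert config.use_return_tuple

    # TorchScript tracing cannot handle arbitrary Python objects, so
    # `use_return_tuple` is forced to True whenever `torchscript` is set.
    traced_config = BertConfig(torchscript=True)
    assert traced_config.use_return_tuple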
@@ -4,6 +4,7 @@ from os.path import abspath, dirname, exists
from typing import Dict, List, Optional, Tuple
from transformers import is_tf_available, is_torch_available
from transformers.file_utils import ModelOutput
from transformers.pipelines import Pipeline, pipeline
from transformers.tokenization_utils import BatchEncoding
@@ -89,7 +90,8 @@ def infer_shapes(nlp: Pipeline, framework: str) -> Tuple[List[str], List[str], D
tokens = nlp.tokenizer("This is a sample output", return_tensors=framework)
seq_len = tokens.input_ids.shape[-1]
outputs = nlp.model(**tokens) if framework == "pt" else nlp.model(tokens)
if isinstance(outputs, ModelOutput):
outputs = outputs.to_tuple()
if not isinstance(outputs, (list, tuple)):
outputs = (outputs,)
...
@@ -8,6 +8,7 @@ import fnmatch
import json
import logging
import os
import re
import shutil
import sys
import tarfile
@@ -186,6 +187,31 @@ def add_end_docstrings(*docstr):
return docstring_decorator
RETURN_INTRODUCTION = r"""
Returns:
:class:`~transformers.{output_type}` or :obj:`tuple(torch.FloatTensor)` (if ``return_tuple=True`` is passed or when ``config.return_tuple=True``) comprising various elements depending on the configuration (:class:`~transformers.{config_class}`) and inputs:
"""
def _prepare_output_docstrings(output_type, config_class):
"""
Prepares the return part of the docstring using `output_type`.
"""
docstrings = output_type.__doc__
# Remove the head of the docstring to keep the list of args only
lines = docstrings.split("\n")
i = 0
while i < len(lines) and re.search(r"^\s*(Args|Parameters):\s*$", lines[i]) is None:
i += 1
if i < len(lines):
docstrings = "\n".join(lines[(i + 1) :])
# Add the return introduction
intro = RETURN_INTRODUCTION.format(output_type=output_type.__name__, config_class=config_class)
return intro + docstrings
PT_TOKEN_CLASSIFICATION_SAMPLE = r"""
Example::
@@ -414,7 +440,7 @@ TF_CAUSAL_LM_SAMPLE = r"""
"""
def add_code_sample_docstrings(*docstr, tokenizer_class=None, checkpoint=None, output_type=None, config_class=None):
def docstring_decorator(fn):
model_class = fn.__qualname__.split(".")[0]
is_tf_class = model_class[:2] == "TF"
@@ -436,8 +462,29 @@ def add_code_sample_docstrings(*docstr, tokenizer_class=None, checkpoint=None):
else:
raise ValueError(f"Docstring can't be built for model {model_class}")
output_doc = _prepare_output_docstrings(output_type, config_class) if output_type is not None else ""
built_doc = code_sample.format(model_class=model_class, tokenizer_class=tokenizer_class, checkpoint=checkpoint)
fn.__doc__ = (fn.__doc__ or "") + "".join(docstr) + output_doc + built_doc
return fn
return docstring_decorator
def replace_return_docstrings(output_type=None, config_class=None):
def docstring_decorator(fn):
docstrings = fn.__doc__
lines = docstrings.split("\n")
i = 0
while i < len(lines) and re.search(r"^\s*Returns?:\s*$", lines[i]) is None:
i += 1
if i < len(lines):
lines[i] = _prepare_output_docstrings(output_type, config_class)
docstrings = "\n".join(lines)
else:
raise ValueError(
f"The function {fn} should have an empty 'Return:' or 'Returns:' in its docstring as placeholder, current docstring is:\n{docstrings}"
)
fn.__doc__ = docstrings
return fn
return docstring_decorator
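A sketch of how the new decorator is meant to be used (the function and the "MyConfig" name below are illustrative, not taken from this diff):

    from transformers.file_utils import replace_return_docstrings
    from transformers.modeling_outputs import BaseModelOutput

    @replace_return_docstrings(output_type=BaseModelOutput, config_class="MyConfig")
    def forward(input_ids=None):
        r"""
        input_ids: documented as usual.

        Returns:

        """

    # The empty "Returns:" placeholder above is replaced by a section generated
    # from BaseModelOutput's docstring, so output documentation lives in one place.
    print(forward.__doc__)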
@@ -806,3 +853,22 @@ def tf_required(func):
raise ImportError(f"Method `{func.__name__}` requires TF.")
return wrapper
class ModelOutput:
"""
Base class for all model outputs as dataclass. Has a ``__getitem__`` (to make it behave like a ``namedtuple``) that
will ignore ``None`` in the attributes.
"""
def to_tuple(self):
return tuple(getattr(self, f) for f in self.__dataclass_fields__.keys() if getattr(self, f, None) is not None)
def to_dict(self):
return {f: getattr(self, f) for f in self.__dataclass_fields__.keys() if getattr(self, f, None) is not None}
def __getitem__(self, i):
return self.to_dict()[i] if isinstance(i, str) else self.to_tuple()[i]
def __len__(self):
return len(self.to_tuple())
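To make the behaviour of the subclasses concrete, a minimal sketch (the `ToyOutput` class is illustrative only):

    from dataclasses import dataclass
    from typing import Optional, Tuple

    import torch

    from transformers.file_utils import ModelOutput

    @dataclass
    class ToyOutput(ModelOutput):
        logits: torch.FloatTensor
        hidden_states: Optional[Tuple[torch.FloatTensor]] = None

    out = ToyOutput(logits=torch.ones(2, 3))
    out.logits             # attribute access, self-documenting
    out[0]                 # positional access, like the old tuple outputs
    out["logits"]          # key access
    assert len(out) == 1   # the None field is skipped by to_tuple()/to_dict()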
@@ -53,6 +53,10 @@ CAMEMBERT_START_DOCSTRING = r"""
Check out the :meth:`~transformers.PreTrainedModel.from_pretrained` method to load the model weights.
output_attentions (:obj:`bool`, `optional`, defaults to :obj:`None`):
If set to ``True``, the attentions tensors of all attention layers are returned. See ``attentions`` under returned tensors for more detail.
output_hidden_states (:obj:`bool`, `optional`, defaults to :obj:`None`):
If set to ``True``, the hidden states of all layers are returned. See ``hidden_states`` under returned tensors for more detail.
return_tuple (:obj:`bool`, `optional`, defaults to :obj:`None`):
If set to ``True``, the output of the model will be a plain tuple instead of a ``dataclass``.
""" """
......
@@ -25,11 +25,13 @@ from torch.nn import CrossEntropyLoss
from .configuration_ctrl import CTRLConfig
from .file_utils import add_code_sample_docstrings, add_start_docstrings, add_start_docstrings_to_callable
from .modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast
from .modeling_utils import Conv1D, PreTrainedModel, find_pruneable_heads_and_indices, prune_linear_layer
logger = logging.getLogger(__name__)
_CONFIG_FOR_DOC = "CTRLConfig"
_TOKENIZER_FOR_DOC = "CTRLTokenizer" _TOKENIZER_FOR_DOC = "CTRLTokenizer"
CTRL_PRETRAINED_MODEL_ARCHIVE_LIST = [ CTRL_PRETRAINED_MODEL_ARCHIVE_LIST = [
...@@ -288,6 +290,10 @@ CTRL_INPUTS_DOCSTRING = r""" ...@@ -288,6 +290,10 @@ CTRL_INPUTS_DOCSTRING = r"""
can be used to speed up decoding (see `past`). Defaults to `True`. can be used to speed up decoding (see `past`). Defaults to `True`.
output_attentions (:obj:`bool`, `optional`, defaults to :obj:`None`): output_attentions (:obj:`bool`, `optional`, defaults to :obj:`None`):
If set to ``True``, the attentions tensors of all attention layers are returned. See ``attentions`` under returned tensors for more detail. If set to ``True``, the attentions tensors of all attention layers are returned. See ``attentions`` under returned tensors for more detail.
output_hidden_states (:obj:`bool`, `optional`, defaults to :obj:`None`):
If set to ``True``, the hidden states of all layers are returned. See ``hidden_states`` under returned tensors for more detail.
return_tuple (:obj:`bool`, `optional`, defaults to :obj:`None`):
If set to ``True``, the output of the model will be a plain tuple instead of a ``dataclass``.
""" """
...@@ -328,7 +334,12 @@ class CTRLModel(CTRLPreTrainedModel): ...@@ -328,7 +334,12 @@ class CTRLModel(CTRLPreTrainedModel):
self.h[layer].multi_head_attention.prune_heads(heads) self.h[layer].multi_head_attention.prune_heads(heads)
@add_start_docstrings_to_callable(CTRL_INPUTS_DOCSTRING) @add_start_docstrings_to_callable(CTRL_INPUTS_DOCSTRING)
@add_code_sample_docstrings(tokenizer_class=_TOKENIZER_FOR_DOC, checkpoint="ctrl") @add_code_sample_docstrings(
tokenizer_class=_TOKENIZER_FOR_DOC,
checkpoint="ctrl",
output_type=BaseModelOutputWithPast,
config_class=_CONFIG_FOR_DOC,
)
def forward(
self,
input_ids=None,
@@ -341,32 +352,14 @@ class CTRLModel(CTRLPreTrainedModel):
use_cache=None,
output_attentions=None,
output_hidden_states=None,
return_tuple=None,
):
r"""
Return:
:obj:`tuple(torch.FloatTensor)` comprising various elements depending on the configuration (:class:`~transformers.CTRLConfig`) and inputs:
last_hidden_state (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length, hidden_size)`):
Sequence of hidden-states at the last layer of the model.
past (:obj:`List[torch.FloatTensor]` of length :obj:`config.n_layers` with each tensor of shape :obj:`(2, batch_size, num_heads, sequence_length, embed_size_per_head)`):
Contains pre-computed hidden-states (key and values in the attention blocks).
Can be used (see `past` input) to speed up sequential decoding.
hidden_states (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``output_hidden_states=True`` is passed or when ``config.output_hidden_states=True``):
Tuple of :obj:`torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer)
of shape :obj:`(batch_size, sequence_length, hidden_size)`.
Hidden-states of the model at the output of each layer plus the initial embedding outputs.
attentions (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``output_attentions=True`` is passed or when ``config.output_attentions=True``):
Tuple of :obj:`torch.FloatTensor` (one for each layer) of shape
:obj:`(batch_size, num_heads, sequence_length, sequence_length)`.
Attentions weights after the attention softmax, used to compute the weighted average in the self-attention
heads.
"""
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
use_cache = use_cache if use_cache is not None else self.config.use_cache
output_hidden_states = (
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
)
return_tuple = return_tuple if return_tuple is not None else self.config.use_return_tuple
if input_ids is not None and inputs_embeds is not None:
raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
@@ -435,9 +428,9 @@ class CTRLModel(CTRLPreTrainedModel):
hidden_states = self.dropout(hidden_states)
output_shape = input_shape + (inputs_embeds.size(-1),)
presents = () if use_cache else None
all_hidden_states = () if output_hidden_states else None
all_attentions = [] if output_attentions else None
for i, (h, layer_past) in enumerate(zip(self.h, past)):
if output_hidden_states:
all_hidden_states = all_hidden_states + (hidden_states.view(*output_shape),)
@@ -462,17 +455,20 @@ class CTRLModel(CTRLPreTrainedModel):
if output_hidden_states:
all_hidden_states = all_hidden_states + (hidden_states,)
outputs = (hidden_states,)
if use_cache is True:
outputs = outputs + (presents,)
if output_hidden_states:
outputs = outputs + (all_hidden_states,)
if output_attentions:
# let the number of heads free (-1) so we can extract attention even after head pruning
attention_output_shape = input_shape[:-1] + (-1,) + all_attentions[0].shape[-2:]
all_attentions = tuple(t.view(*attention_output_shape) for t in all_attentions)
outputs = outputs + (all_attentions,)
return outputs
if return_tuple:
return tuple(v for v in [hidden_states, presents, all_hidden_states, all_attentions] if v is not None)
return BaseModelOutputWithPast(
last_hidden_state=hidden_states,
past_key_values=presents,
hidden_states=all_hidden_states,
attentions=all_attentions,
)
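A short usage sketch of the new return type (checkpoint name taken from the docstring decorator above; treat the exact tensors as illustrative):

    import torch
    from transformers import CTRLModel, CTRLTokenizer

    tokenizer = CTRLTokenizer.from_pretrained("ctrl")
    model = CTRLModel.from_pretrained("ctrl")
    inputs = tokenizer("Hello, my dog is cute", return_tensors="pt")

    outputs = model(**inputs)      # BaseModelOutputWithPast
    outputs.last_hidden_state      # (batch, seq_len, hidden_size)
    outputs.past_key_values        # present because use_cache defaults to True
    outputs.attentions             # None unless output_attentions=True is requested

    legacy = model(**inputs, return_tuple=True)
    assert torch.equal(legacy[0], outputs.last_hidden_state)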
@add_start_docstrings(
@@ -499,7 +495,12 @@ class CTRLLMHeadModel(CTRLPreTrainedModel):
return {"input_ids": input_ids, "past": past, "use_cache": kwargs["use_cache"]}
@add_start_docstrings_to_callable(CTRL_INPUTS_DOCSTRING)
@add_code_sample_docstrings(
tokenizer_class=_TOKENIZER_FOR_DOC,
checkpoint="ctrl",
output_type=CausalLMOutputWithPast,
config_class=_CONFIG_FOR_DOC,
)
def forward(
self,
input_ids=None,
@@ -513,6 +514,7 @@ class CTRLLMHeadModel(CTRLPreTrainedModel):
use_cache=None,
output_attentions=None,
output_hidden_states=None,
return_tuple=None,
):
r"""
labels (:obj:`torch.LongTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`, defaults to :obj:`None`):
@@ -521,28 +523,9 @@ class CTRLLMHeadModel(CTRLPreTrainedModel):
Indices are selected in ``[-100, 0, ..., config.vocab_size]``
All labels set to ``-100`` are ignored (masked), the loss is only
computed for labels in ``[0, ..., config.vocab_size]``
Return:
:obj:`tuple(torch.FloatTensor)` comprising various elements depending on the configuration (:class:`~transformers.CTRLConfig`) and inputs:
loss (:obj:`torch.FloatTensor` of shape `(1,)`, `optional`, returned when ``labels`` is provided)
Language modeling loss.
prediction_scores (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length, config.vocab_size)`):
Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax).
past (:obj:`List[torch.FloatTensor]` of length :obj:`config.n_layers` with each tensor of shape :obj:`(2, batch_size, num_heads, sequence_length, embed_size_per_head)`):
Contains pre-computed hidden-states (key and values in the attention blocks).
Can be used (see `past` input) to speed up sequential decoding.
hidden_states (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``output_hidden_states=True`` is passed or when ``config.output_hidden_states=True``):
Tuple of :obj:`torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer)
of shape :obj:`(batch_size, sequence_length, hidden_size)`.
Hidden-states of the model at the output of each layer plus the initial embedding outputs.
attentions (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``output_attentions=True`` is passed or when ``config.output_attentions=True``):
Tuple of :obj:`torch.FloatTensor` (one for each layer) of shape
:obj:`(batch_size, num_heads, sequence_length, sequence_length)`.
Attentions weights after the attention softmax, used to compute the weighted average in the self-attention
heads.
""" """
return_tuple = return_tuple if return_tuple is not None else self.config.use_return_tuple
transformer_outputs = self.transformer(
input_ids,
past=past,
@@ -554,14 +537,14 @@ class CTRLLMHeadModel(CTRLPreTrainedModel):
use_cache=use_cache,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_tuple=return_tuple,
)
hidden_states = transformer_outputs[0]
lm_logits = self.lm_head(hidden_states)
outputs = (lm_logits,) + transformer_outputs[1:]
loss = None
if labels is not None:
# Shift so that tokens < n predict n
shift_logits = lm_logits[..., :-1, :].contiguous()
@@ -569,6 +552,15 @@ class CTRLLMHeadModel(CTRLPreTrainedModel):
# Flatten the tokens
loss_fct = CrossEntropyLoss()
loss = loss_fct(shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1))
outputs = (loss,) + outputs
return outputs # (loss), lm_logits, presents, (all hidden_states), (attentions)
if return_tuple:
output = (lm_logits,) + transformer_outputs[1:]
return ((loss,) + output) if loss is not None else output
return CausalLMOutputWithPast(
loss=loss,
logits=lm_logits,
past_key_values=transformer_outputs.past_key_values,
hidden_states=transformer_outputs.hidden_states,
attentions=transformer_outputs.attentions,
)
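And the corresponding sketch for the LM head model, showing that the loss is now an optional, named field rather than a positional element that shifts the rest of the tuple:

    from transformers import CTRLLMHeadModel, CTRLTokenizer

    tokenizer = CTRLTokenizer.from_pretrained("ctrl")
    model = CTRLLMHeadModel.from_pretrained("ctrl")
    inputs = tokenizer("Hello, my dog is cute", return_tensors="pt")

    outputs = model(**inputs, labels=inputs["input_ids"])
    outputs.loss      # language modeling loss; None when no labels are passed
    outputs.logits    # (batch, seq_len, vocab_size)

    # Index/slice access still works because ModelOutput implements __getitem__:
    loss, logits = outputs[:2]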
@@ -16,19 +16,23 @@
import logging
from typing import Optional, Tuple
from dataclasses import dataclass
from typing import Optional, Tuple, Union
import torch
from torch import Tensor, nn
from .configuration_dpr import DPRConfig
from .file_utils import ModelOutput, add_start_docstrings, add_start_docstrings_to_callable, replace_return_docstrings
from .modeling_bert import BertModel
from .modeling_outputs import BaseModelOutputWithPooling
from .modeling_utils import PreTrainedModel
logger = logging.getLogger(__name__)
_CONFIG_FOR_DOC = "DPRConfig"
DPR_CONTEXT_ENCODER_PRETRAINED_MODEL_ARCHIVE_LIST = [
"facebook/dpr-ctx_encoder-single-nq-base",
]
@@ -40,6 +44,102 @@ DPR_READER_PRETRAINED_MODEL_ARCHIVE_LIST = [
]
##########
# Outputs
##########
@dataclass
class DPRContextEncoderOutput(ModelOutput):
"""
Class for outputs of :class:`~transformers.DPRContextEncoder`.
Args:
pooler_output: (:obj:``torch.FloatTensor`` of shape ``(batch_size, embeddings_size)``):
The DPR encoder outputs the `pooler_output` that corresponds to the context representation.
Last layer hidden-state of the first token of the sequence (classification token)
further processed by a Linear layer. This output is to be used to embed contexts for
nearest neighbors queries with questions embeddings.
hidden_states (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``output_hidden_states=True`` is passed or when ``config.output_hidden_states=True``):
Tuple of :obj:`torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer)
of shape :obj:`(batch_size, sequence_length, hidden_size)`.
Hidden-states of the model at the output of each layer plus the initial embedding outputs.
attentions (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``output_attentions=True`` is passed or when ``config.output_attentions=True``):
Tuple of :obj:`torch.FloatTensor` (one for each layer) of shape
:obj:`(batch_size, num_heads, sequence_length, sequence_length)`.
Attentions weights after the attention softmax, used to compute the weighted average in the self-attention
heads.
"""
pooler_output: torch.FloatTensor
hidden_states: Optional[Tuple[torch.FloatTensor]] = None
attentions: Optional[Tuple[torch.FloatTensor]] = None
@dataclass
class DPRQuestionEncoderOutput(ModelOutput):
"""
Class for outputs of :class:`~transformers.DPRQuestionEncoder`.
Args:
pooler_output: (:obj:``torch.FloatTensor`` of shape ``(batch_size, embeddings_size)``):
The DPR encoder outputs the `pooler_output` that corresponds to the question representation.
Last layer hidden-state of the first token of the sequence (classification token)
further processed by a Linear layer. This output is to be used to embed questions for
nearest neighbors queries with context embeddings.
hidden_states (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``output_hidden_states=True`` is passed or when ``config.output_hidden_states=True``):
Tuple of :obj:`torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer)
of shape :obj:`(batch_size, sequence_length, hidden_size)`.
Hidden-states of the model at the output of each layer plus the initial embedding outputs.
attentions (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``output_attentions=True`` is passed or when ``config.output_attentions=True``):
Tuple of :obj:`torch.FloatTensor` (one for each layer) of shape
:obj:`(batch_size, num_heads, sequence_length, sequence_length)`.
Attentions weights after the attention softmax, used to compute the weighted average in the self-attention
heads.
"""
pooler_output: torch.FloatTensor
hidden_states: Optional[Tuple[torch.FloatTensor]] = None
attentions: Optional[Tuple[torch.FloatTensor]] = None
@dataclass
class DPRReaderOutput(ModelOutput):
"""
Class for outputs of :class:`~transformers.DPRReader`.
Args:
start_logits: (:obj:``torch.FloatTensor`` of shape ``(n_passages, sequence_length)``):
Logits of the start index of the span for each passage.
end_logits: (:obj:``torch.FloatTensor`` of shape ``(n_passages, sequence_length)``):
Logits of the end index of the span for each passage.
relevance_logits: (:obj:`torch.FloatTensor`` of shape ``(n_passages, )``):
Outputs of the QA classifier of the DPRReader that corresponds to the scores of each passage
to answer the question, compared to all the other passages.
hidden_states (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``output_hidden_states=True`` is passed or when ``config.output_hidden_states=True``):
Tuple of :obj:`torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer)
of shape :obj:`(batch_size, sequence_length, hidden_size)`.
Hidden-states of the model at the output of each layer plus the initial embedding outputs.
attentions (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``output_attentions=True`` is passed or when ``config.output_attentions=True``):
Tuple of :obj:`torch.FloatTensor` (one for each layer) of shape
:obj:`(batch_size, num_heads, sequence_length, sequence_length)`.
Attentions weights after the attention softmax, used to compute the weighted average in the self-attention
heads.
"""
start_logits: torch.FloatTensor
end_logits: torch.FloatTensor
relevance_logits: torch.FloatTensor
hidden_states: Optional[Tuple[torch.FloatTensor]] = None
attentions: Optional[Tuple[torch.FloatTensor]] = None
class DPREncoder(PreTrainedModel):
base_model_prefix = "bert_model"
@@ -61,28 +161,31 @@ class DPREncoder(PreTrainedModel):
inputs_embeds: Optional[Tensor] = None,
output_attentions: bool = False,
output_hidden_states: bool = False,
) -> Tuple[Tensor, ...]:
return_tuple: bool = False,
) -> Union[BaseModelOutputWithPooling, Tuple[Tensor, ...]]:
outputs = self.bert_model(
input_ids=input_ids,
attention_mask=attention_mask,
token_type_ids=token_type_ids,
inputs_embeds=inputs_embeds,
output_hidden_states=True,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_tuple=return_tuple,
)
sequence_output, pooled_output, hidden_states = outputs[:3]
sequence_output, pooled_output = outputs[:2]
pooled_output = sequence_output[:, 0, :]
if self.projection_dim > 0:
pooled_output = self.encode_proj(pooled_output)
dpr_encoder_outputs = (sequence_output, pooled_output)
if return_tuple:
return (sequence_output, pooled_output) + outputs[2:]
if output_hidden_states:
dpr_encoder_outputs += (hidden_states,)
if output_attentions:
dpr_encoder_outputs += (outputs[-1],)
return dpr_encoder_outputs
return BaseModelOutputWithPooling(
last_hidden_state=sequence_output,
pooler_output=pooled_output,
hidden_states=outputs.hidden_states,
attentions=outputs.attentions,
)
@property
def embeddings_size(self) -> int:
@@ -114,7 +217,8 @@ class DPRSpanPredictor(PreTrainedModel):
inputs_embeds: Optional[Tensor] = None,
output_attentions: bool = False,
output_hidden_states: bool = False,
):
return_tuple: bool = False,
) -> Union[DPRReaderOutput, Tuple[Tensor, ...]]:
# notations: N - number of questions in a batch, M - number of passages per questions, L - sequence length
n_passages, sequence_length = input_ids.size() if input_ids is not None else inputs_embeds.size()[:2]
# feed encoder
@@ -124,6 +228,7 @@ class DPRSpanPredictor(PreTrainedModel):
inputs_embeds=inputs_embeds,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_tuple=return_tuple,
)
sequence_output = outputs[0]
@@ -133,12 +238,22 @@ class DPRSpanPredictor(PreTrainedModel):
start_logits = start_logits.squeeze(-1)
end_logits = end_logits.squeeze(-1)
relevance_logits = self.qa_classifier(sequence_output[:, 0, :])
# resize and return
return (
start_logits.view(n_passages, sequence_length),
end_logits.view(n_passages, sequence_length),
relevance_logits.view(n_passages),
) + outputs[2:]
# resize
start_logits = start_logits.view(n_passages, sequence_length)
end_logits = end_logits.view(n_passages, sequence_length)
relevance_logits = relevance_logits.view(n_passages)
if return_tuple:
return (start_logits, end_logits, relevance_logits) + outputs[2:]
return DPRReaderOutput(
start_logits=start_logits,
end_logits=end_logits,
relevance_logits=relevance_logits,
hidden_states=outputs.hidden_states,
attentions=outputs.attentions,
)
def init_weights(self):
self.encoder.init_weights()
@@ -288,6 +403,7 @@ class DPRContextEncoder(DPRPretrainedContextEncoder):
self.init_weights()
@add_start_docstrings_to_callable(DPR_ENCODERS_INPUTS_DOCSTRING)
@replace_return_docstrings(output_type=DPRContextEncoderOutput, config_class=_CONFIG_FOR_DOC)
def forward(
self,
input_ids: Optional[Tensor] = None,
@@ -296,26 +412,10 @@ class DPRContextEncoder(DPRPretrainedContextEncoder):
inputs_embeds: Optional[Tensor] = None,
output_attentions=None,
output_hidden_states=None,
) -> Tensor:
return_tuple=None,
) -> Union[DPRContextEncoderOutput, Tuple[Tensor, ...]]:
r""" r"""
Return: Return:
:obj:`tuple(torch.FloatTensor)` comprising various elements depending on the configuration (:class:`~transformers.DPRConfig`) and inputs:
pooler_output: (:obj:``torch.FloatTensor`` of shape ``(batch_size, embeddings_size)``):
The DPR encoder outputs the `pooler_output` that corresponds to the context representation.
Last layer hidden-state of the first token of the sequence (classification token)
further processed by a Linear layer. This output is to be used to embed contexts for
nearest neighbors queries with questions embeddings.
hidden_states (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``output_hidden_states=True`` is passed or when ``config.output_hidden_states=True``):
Tuple of :obj:`torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer)
of shape :obj:`(batch_size, sequence_length, hidden_size)`.
Hidden-states of the model at the output of each layer plus the initial embedding outputs.
attentions (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``output_attentions=True`` is passed or when ``config.output_attentions=True``):
Tuple of :obj:`torch.FloatTensor` (one for each layer) of shape
:obj:`(batch_size, num_heads, sequence_length, sequence_length)`.
Attentions weights after the attention softmax, used to compute the weighted average in the self-attention
heads.
Examples::
@@ -331,6 +431,7 @@ class DPRContextEncoder(DPRPretrainedContextEncoder):
output_hidden_states = (
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
)
return_tuple = return_tuple if return_tuple is not None else self.config.use_return_tuple
if input_ids is not None and inputs_embeds is not None:
raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
@@ -359,9 +460,14 @@ class DPRContextEncoder(DPRPretrainedContextEncoder):
inputs_embeds=inputs_embeds,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_tuple=return_tuple,
)
if return_tuple:
return outputs[1:]
return DPRContextEncoderOutput(
pooler_output=outputs.pooler_output, hidden_states=outputs.hidden_states, attentions=outputs.attentions
) )
sequence_output, pooled_output = outputs[:2]
return (pooled_output,) + outputs[2:]
@add_start_docstrings(
@@ -376,6 +482,7 @@ class DPRQuestionEncoder(DPRPretrainedQuestionEncoder):
self.init_weights()
@add_start_docstrings_to_callable(DPR_ENCODERS_INPUTS_DOCSTRING)
@replace_return_docstrings(output_type=DPRQuestionEncoderOutput, config_class=_CONFIG_FOR_DOC)
def forward(
self,
input_ids: Optional[Tensor] = None,
@@ -384,26 +491,10 @@ class DPRQuestionEncoder(DPRPretrainedQuestionEncoder):
inputs_embeds: Optional[Tensor] = None,
output_attentions=None,
output_hidden_states=None,
) -> Tensor:
return_tuple=None,
) -> Union[DPRQuestionEncoderOutput, Tuple[Tensor, ...]]:
r""" r"""
Return: Return:
:obj:`tuple(torch.FloatTensor)` comprising various elements depending on the configuration (:class:`~transformers.DPRConfig`) and inputs:
pooler_output: (:obj:``torch.FloatTensor`` of shape ``(batch_size, embeddings_size)``):
The DPR encoder outputs the `pooler_output` that corresponds to the question representation.
Last layer hidden-state of the first token of the sequence (classification token)
further processed by a Linear layer. This output is to be used to embed questions for
nearest neighbors queries with context embeddings.
hidden_states (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``output_hidden_states=True`` is passed or when ``config.output_hidden_states=True``):
Tuple of :obj:`torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer)
of shape :obj:`(batch_size, sequence_length, hidden_size)`.
Hidden-states of the model at the output of each layer plus the initial embedding outputs.
attentions (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``output_attentions=True`` is passed or when ``config.output_attentions=True``):
Tuple of :obj:`torch.FloatTensor` (one for each layer) of shape
:obj:`(batch_size, num_heads, sequence_length, sequence_length)`.
Attentions weights after the attention softmax, used to compute the weighted average in the self-attention
heads.
Examples::
@@ -417,6 +508,7 @@ class DPRQuestionEncoder(DPRPretrainedQuestionEncoder):
output_hidden_states = (
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
)
return_tuple = return_tuple if return_tuple is not None else self.config.use_return_tuple
if input_ids is not None and inputs_embeds is not None:
raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
@@ -445,9 +537,14 @@ class DPRQuestionEncoder(DPRPretrainedQuestionEncoder):
inputs_embeds=inputs_embeds,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_tuple=return_tuple,
)
if return_tuple:
return outputs[1:]
return DPRQuestionEncoderOutput(
pooler_output=outputs.pooler_output, hidden_states=outputs.hidden_states, attentions=outputs.attentions
) )
sequence_output, pooled_output = outputs[:2]
return (pooled_output,) + outputs[2:]
@add_start_docstrings(
@@ -461,6 +558,7 @@ class DPRReader(DPRPretrainedReader):
self.init_weights()
@add_start_docstrings_to_callable(DPR_READER_INPUTS_DOCSTRING)
@replace_return_docstrings(output_type=DPRReaderOutput, config_class=_CONFIG_FOR_DOC)
def forward(
self,
input_ids: Optional[Tensor] = None,
@@ -468,30 +566,10 @@ class DPRReader(DPRPretrainedReader):
inputs_embeds: Optional[Tensor] = None,
output_attentions: bool = None,
output_hidden_states: bool = None,
) -> Tuple[Tensor, ...]:
return_tuple=None,
) -> Union[DPRReaderOutput, Tuple[Tensor, ...]]:
r""" r"""
Return: Return:
:obj:`tuple(torch.FloatTensor)` comprising various elements depending on the configuration (:class:`~transformers.DPRConfig`) and inputs:
input_ids: (:obj:``torch.FloatTensor`` of shape ``(n_passages, sequence_length)``)
They correspond to the combined `input_ids` from `(question + context title + context content`).
start_logits: (:obj:``torch.FloatTensor`` of shape ``(n_passages, sequence_length)``):
Logits of the start index of the span for each passage.
end_logits: (:obj:``torch.FloatTensor`` of shape ``(n_passages, sequence_length)``):
Logits of the end index of the span for each passage.
relevance_logits: (:obj:`torch.FloatTensor`` of shape ``(n_passages, )``):
Outputs of the QA classifier of the DPRReader that corresponds to the scores of each passage
to answer the question, compared to all the other passages.
hidden_states (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``output_hidden_states=True`` is passed or when ``config.output_hidden_states=True``):
Tuple of :obj:`torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer)
of shape :obj:`(batch_size, sequence_length, hidden_size)`.
Hidden-states of the model at the output of each layer plus the initial embedding outputs.
attentions (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``output_attentions=True`` is passed or when ``config.output_attentions=True``):
Tuple of :obj:`torch.FloatTensor` (one for each layer) of shape
:obj:`(batch_size, num_heads, sequence_length, sequence_length)`.
Attentions weights after the attention softmax, used to compute the weighted average in the self-attention
heads.
Examples::
@@ -514,6 +592,7 @@ class DPRReader(DPRPretrainedReader):
output_hidden_states = (
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
)
return_tuple = return_tuple if return_tuple is not None else self.config.use_return_tuple
if input_ids is not None and inputs_embeds is not None:
raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
@@ -529,13 +608,11 @@ class DPRReader(DPRPretrainedReader):
if attention_mask is None:
attention_mask = torch.ones(input_shape, device=device)
span_outputs = self.span_predictor(
return self.span_predictor(
input_ids,
attention_mask,
inputs_embeds=inputs_embeds,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_tuple=return_tuple,
)
start_logits, end_logits, relevance_logits = span_outputs[:3]
return (start_logits, end_logits, relevance_logits) + span_outputs[3:]
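For reference, a hedged sketch of calling the reader with the new output class (the tokenizer class and checkpoint name are the usual DPR ones, not shown in this hunk):

    from transformers import DPRReader, DPRReaderTokenizer

    tokenizer = DPRReaderTokenizer.from_pretrained("facebook/dpr-reader-single-nq-base")
    model = DPRReader.from_pretrained("facebook/dpr-reader-single-nq-base")
    encoded = tokenizer(
        questions="What is love?",
        titles="Haddaway",
        texts="What Is Love is a song recorded by Haddaway",
        return_tensors="pt",
    )

    outputs = model(**encoded)    # DPRReaderOutput
    outputs.start_logits          # (n_passages, sequence_length)
    outputs.end_logits            # (n_passages, sequence_length)
    outputs.relevance_logits      # (n_passages,)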
@@ -273,6 +273,7 @@ class EncoderDecoderModel(PreTrainedModel):
attention_mask=attention_mask,
inputs_embeds=inputs_embeds,
head_mask=head_mask,
return_tuple=True,
**kwargs_encoder,
)
@@ -287,6 +288,7 @@ class EncoderDecoderModel(PreTrainedModel):
encoder_attention_mask=attention_mask,
head_mask=decoder_head_mask,
labels=labels,
return_tuple=True,
**kwargs_decoder,
)
...
@@ -23,6 +23,7 @@ from torch.nn import functional as F
from .configuration_flaubert import FlaubertConfig
from .file_utils import add_code_sample_docstrings, add_start_docstrings, add_start_docstrings_to_callable
from .modeling_outputs import BaseModelOutput
from .modeling_xlm import (
XLMForQuestionAnswering,
XLMForQuestionAnsweringSimple,
@@ -35,6 +36,7 @@ from .modeling_xlm import (
logger = logging.getLogger(__name__)
_CONFIG_FOR_DOC = "FlaubertConfig"
_TOKENIZER_FOR_DOC = "FlaubertTokenizer" _TOKENIZER_FOR_DOC = "FlaubertTokenizer"
FLAUBERT_PRETRAINED_MODEL_ARCHIVE_LIST = [ FLAUBERT_PRETRAINED_MODEL_ARCHIVE_LIST = [
...@@ -104,6 +106,10 @@ FLAUBERT_INPUTS_DOCSTRING = r""" ...@@ -104,6 +106,10 @@ FLAUBERT_INPUTS_DOCSTRING = r"""
than the model's internal embedding lookup matrix. than the model's internal embedding lookup matrix.
output_attentions (:obj:`bool`, `optional`, defaults to :obj:`None`): output_attentions (:obj:`bool`, `optional`, defaults to :obj:`None`):
If set to ``True``, the attentions tensors of all attention layers are returned. See ``attentions`` under returned tensors for more detail. If set to ``True``, the attentions tensors of all attention layers are returned. See ``attentions`` under returned tensors for more detail.
output_hidden_states (:obj:`bool`, `optional`, defaults to :obj:`None`):
If set to ``True``, the hidden states of all layers are returned. See ``hidden_states`` under returned tensors for more detail.
return_tuple (:obj:`bool`, `optional`, defaults to :obj:`None`):
If set to ``True``, the output of the model will be a plain tuple instead of a ``dataclass``.
""" """
...@@ -121,7 +127,12 @@ class FlaubertModel(XLMModel): ...@@ -121,7 +127,12 @@ class FlaubertModel(XLMModel):
self.pre_norm = getattr(config, "pre_norm", False) self.pre_norm = getattr(config, "pre_norm", False)
@add_start_docstrings_to_callable(FLAUBERT_INPUTS_DOCSTRING) @add_start_docstrings_to_callable(FLAUBERT_INPUTS_DOCSTRING)
@add_code_sample_docstrings(tokenizer_class=_TOKENIZER_FOR_DOC, checkpoint="flaubert/flaubert_base_cased") @add_code_sample_docstrings(
tokenizer_class=_TOKENIZER_FOR_DOC,
checkpoint="flaubert/flaubert_base_cased",
output_type=BaseModelOutput,
config_class=_CONFIG_FOR_DOC,
)
def forward(
self,
input_ids=None,
@@ -135,28 +146,13 @@ class FlaubertModel(XLMModel):
inputs_embeds=None,
output_attentions=None,
output_hidden_states=None,
return_tuple=None,
):
r"""
Return:
:obj:`tuple(torch.FloatTensor)` comprising various elements depending on the configuration (:class:`~transformers.XLMConfig`) and inputs:
last_hidden_state (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length, hidden_size)`):
Sequence of hidden-states at the output of the last layer of the model.
hidden_states (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``output_hidden_states=True`` is passed or when ``config.output_hidden_states=True``):
Tuple of :obj:`torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer)
of shape :obj:`(batch_size, sequence_length, hidden_size)`.
Hidden-states of the model at the output of each layer plus the initial embedding outputs.
attentions (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``output_attentions=True`` is passed or when ``config.output_attentions=True``):
Tuple of :obj:`torch.FloatTensor` (one for each layer) of shape
:obj:`(batch_size, num_heads, sequence_length, sequence_length)`.
Attentions weights after the attention softmax, used to compute the weighted average in the self-attention
heads.
"""
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
output_hidden_states = (
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
)
return_tuple = return_tuple if return_tuple is not None else self.config.use_return_tuple
# removed: src_enc=None, src_len=None
if input_ids is not None:
@@ -227,8 +223,8 @@ class FlaubertModel(XLMModel):
tensor *= mask.unsqueeze(-1).to(tensor.dtype)
# transformer layers
hidden_states = () if output_hidden_states else None
attentions = () if output_attentions else None
for i in range(self.n_layers):
# LayerDrop
dropout_probability = random.uniform(0, 1)
@@ -286,12 +282,10 @@ class FlaubertModel(XLMModel):
# move back sequence length to dimension 0
# tensor = tensor.transpose(0, 1)
outputs = (tensor,)
if output_hidden_states:
outputs = outputs + (hidden_states,)
if output_attentions:
outputs = outputs + (attentions,)
return outputs # outputs, (hidden_states), (attentions)
if return_tuple:
return tuple(v for v in [tensor, hidden_states, attentions] if v is not None)
return BaseModelOutput(last_hidden_state=tensor, hidden_states=hidden_states, attentions=attentions)
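A sketch of the effect on FlauBERT (checkpoint name from the decorator above); note that the optional fields are only populated when explicitly requested, matching the `None` initialisation in the hunk above:

    from transformers import FlaubertModel, FlaubertTokenizer

    tokenizer = FlaubertTokenizer.from_pretrained("flaubert/flaubert_base_cased")
    model = FlaubertModel.from_pretrained("flaubert/flaubert_base_cased")
    inputs = tokenizer("Le chat mange une pomme.", return_tensors="pt")

    outputs = model(**inputs, output_hidden_states=True)
    outputs.last_hidden_state   # always present
    outputs.hidden_states       # tuple of per-layer states, because it was requested
    outputs.attentions          # None: output_attentions was left at its default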
@add_start_docstrings(
...
@@ -19,6 +19,8 @@
import logging
import os
import warnings
from dataclasses import dataclass
from typing import List, Optional, Tuple
import torch
import torch.nn as nn
@@ -26,7 +28,14 @@ from torch.nn import CrossEntropyLoss
from .activations import ACT2FN
from .configuration_gpt2 import GPT2Config
from .file_utils import (
ModelOutput,
add_code_sample_docstrings,
add_start_docstrings,
add_start_docstrings_to_callable,
replace_return_docstrings,
)
from .modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast
from .modeling_utils import (
Conv1D,
PreTrainedModel,
@@ -38,6 +47,7 @@ from .modeling_utils import (
logger = logging.getLogger(__name__)
_CONFIG_FOR_DOC = "GPT2Config"
_TOKENIZER_FOR_DOC = "GPT2Tokenizer" _TOKENIZER_FOR_DOC = "GPT2Tokenizer"
GPT2_PRETRAINED_MODEL_ARCHIVE_LIST = [ GPT2_PRETRAINED_MODEL_ARCHIVE_LIST = [
...@@ -280,6 +290,48 @@ class GPT2PreTrainedModel(PreTrainedModel): ...@@ -280,6 +290,48 @@ class GPT2PreTrainedModel(PreTrainedModel):
module.weight.data.fill_(1.0) module.weight.data.fill_(1.0)
@dataclass
class GPT2DoubleHeadsModelOutput(ModelOutput):
"""
Base class for outputs of models predicting if two sentences are consecutive or not.
Args:
lm_loss (:obj:`torch.FloatTensor` of shape :obj:`(1,)`, `optional`, returned when ``labels`` is provided):
Language modeling loss.
mc_loss (:obj:`torch.FloatTensor` of shape :obj:`(1,)`, `optional`, returned when :obj:`mc_labels` is provided):
Multiple choice classification loss.
lm_logits (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, num_choices, sequence_length, config.vocab_size)`):
Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax).
mc_logits (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, num_choices)`):
Prediction scores of the multiple choice classification head (scores for each choice before SoftMax).
past_key_values (:obj:`List[torch.FloatTensor]`, `optional`, returned when ``use_cache=True`` is passed or when ``config.use_cache=True``):
List of :obj:`torch.FloatTensor` of length :obj:`config.n_layers`, with each tensor of shape
:obj:`(2, batch_size, num_heads, sequence_length, embed_size_per_head)`).
Contains pre-computed hidden-states (key and values in the attention blocks) that can be used (see
``past_key_values`` input) to speed up sequential decoding.
hidden_states (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``output_hidden_states=True`` is passed or when ``config.output_hidden_states=True``):
Tuple of :obj:`torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer)
of shape :obj:`(batch_size, sequence_length, hidden_size)`.
Hidden-states of the model at the output of each layer plus the initial embedding outputs.
attentions (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``output_attentions=True`` is passed or when ``config.output_attentions=True``):
Tuple of :obj:`torch.FloatTensor` (one for each layer) of shape
:obj:`(batch_size, num_heads, sequence_length, sequence_length)`.
Attentions weights after the attention softmax, used to compute the weighted average in the self-attention
heads.
"""
lm_loss: Optional[torch.FloatTensor]
mc_loss: Optional[torch.FloatTensor]
lm_logits: torch.FloatTensor
mc_logits: torch.FloatTensor
past_key_values: Optional[List[torch.FloatTensor]] = None
hidden_states: Optional[Tuple[torch.FloatTensor]] = None
attentions: Optional[Tuple[torch.FloatTensor]] = None
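A hedged sketch of how this output is consumed; GPT2DoubleHeadsModel itself is outside the lines shown here, and the `[CLS]` setup below follows the library's usual double-heads example:

    import torch
    from transformers import GPT2DoubleHeadsModel, GPT2Tokenizer

    tokenizer = GPT2Tokenizer.from_pretrained("gpt2")
    model = GPT2DoubleHeadsModel.from_pretrained("gpt2")

    # Add a classification token and make room for it in the embeddings.
    tokenizer.add_special_tokens({"cls_token": "[CLS]"})
    model.resize_token_embeddings(len(tokenizer))

    choices = ["Hello, my dog is cute [CLS]", "Hello, my cat is cute [CLS]"]
    encoded = [tokenizer.encode(c) for c in choices]
    input_ids = torch.tensor(encoded).unsqueeze(0)                    # (batch=1, n_choices, seq_len)
    mc_token_ids = torch.tensor([[len(ids) - 1 for ids in encoded]])  # position of [CLS] in each choice

    outputs = model(input_ids, mc_token_ids=mc_token_ids)
    outputs.lm_logits   # (1, n_choices, seq_len, vocab_size)
    outputs.mc_logits   # (1, n_choices)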
GPT2_START_DOCSTRING = r"""
This model is a PyTorch `torch.nn.Module <https://pytorch.org/docs/stable/nn.html#torch.nn.Module>`_ sub-class.
@@ -339,6 +391,10 @@ GPT2_INPUTS_DOCSTRING = r"""
If `use_cache` is True, `past` key value states are returned and can be used to speed up decoding (see `past`). Defaults to `True`.
output_attentions (:obj:`bool`, `optional`, defaults to :obj:`None`):
If set to ``True``, the attentions tensors of all attention layers are returned. See ``attentions`` under returned tensors for more detail.
output_hidden_states (:obj:`bool`, `optional`, defaults to :obj:`None`):
If set to ``True``, the hidden states of all layers are returned. See ``hidden_states`` under returned tensors for more detail.
return_tuple (:obj:`bool`, `optional`, defaults to :obj:`None`):
If set to ``True``, the output of the model will be a plain tuple instead of a ``dataclass``.
""" """
@@ -372,7 +428,12 @@ class GPT2Model(GPT2PreTrainedModel):
self.h[layer].attn.prune_heads(heads)
@add_start_docstrings_to_callable(GPT2_INPUTS_DOCSTRING)
@add_code_sample_docstrings(tokenizer_class=_TOKENIZER_FOR_DOC, checkpoint="gpt2")
@add_code_sample_docstrings(
tokenizer_class=_TOKENIZER_FOR_DOC,
checkpoint="gpt2",
output_type=BaseModelOutputWithPast,
config_class=_CONFIG_FOR_DOC,
)
def forward(
self,
input_ids=None,
@@ -385,33 +446,14 @@ class GPT2Model(GPT2PreTrainedModel):
use_cache=None,
output_attentions=None,
output_hidden_states=None,
return_tuple=None,
):
r"""
Return:
:obj:`tuple(torch.FloatTensor)` comprising various elements depending on the configuration (:class:`~transformers.GPT2Config`) and inputs:
last_hidden_state (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length, hidden_size)`):
Sequence of hidden-states at the last layer of the model.
If `past` is used only the last hidden-state of the sequences of shape :obj:`(batch_size, 1, hidden_size)` is output.
past (:obj:`List[torch.FloatTensor]` of length :obj:`config.n_layers` with each tensor of shape :obj:`(2, batch_size, num_heads, sequence_length, embed_size_per_head)`):
Contains pre-computed hidden-states (key and values in the attention blocks).
Can be used (see `past` input) to speed up sequential decoding.
hidden_states (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``output_hidden_states=True``) is passed or when ``config.output_hidden_states=True``:
Tuple of :obj:`torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer)
of shape :obj:`(batch_size, sequence_length, hidden_size)`.
Hidden-states of the model at the output of each layer plus the initial embedding outputs.
attentions (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``output_attentions=True`` is passed or ``config.output_attentions=True``):
Tuple of :obj:`torch.FloatTensor` (one for each layer) of shape
:obj:`(batch_size, num_heads, sequence_length, sequence_length)`.
Attentions weights after the attention softmax, used to compute the weighted average in the self-attention
heads.
"""
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
output_hidden_states = (
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
)
use_cache = use_cache if use_cache is not None else self.config.use_cache
return_tuple = return_tuple if return_tuple is not None else self.config.use_return_tuple
if input_ids is not None and inputs_embeds is not None:
raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
@@ -477,9 +519,9 @@ class GPT2Model(GPT2PreTrainedModel):
output_shape = input_shape + (hidden_states.size(-1),)
presents = ()
all_attentions = []
all_hidden_states = ()
presents = () if use_cache else None
all_attentions = () if output_attentions else None
all_hidden_states = () if output_hidden_states else None
for i, (block, layer_past) in enumerate(zip(self.h, past)):
if output_hidden_states:
all_hidden_states = all_hidden_states + (hidden_states.view(*output_shape),)
@@ -498,7 +540,7 @@ class GPT2Model(GPT2PreTrainedModel):
presents = presents + (present,)
if output_attentions:
all_attentions.append(outputs[2])
all_attentions = all_attentions + (outputs[2],)
hidden_states = self.ln_f(hidden_states)
@@ -507,17 +549,15 @@ class GPT2Model(GPT2PreTrainedModel):
if output_hidden_states:
all_hidden_states = all_hidden_states + (hidden_states,)
outputs = (hidden_states,)
if use_cache is True:
outputs = outputs + (presents,)
if output_hidden_states:
outputs = outputs + (all_hidden_states,)
if output_attentions:
# let the number of heads free (-1) so we can extract attention even after head pruning
attention_output_shape = input_shape[:-1] + (-1,) + all_attentions[0].shape[-2:]
all_attentions = tuple(t.view(*attention_output_shape) for t in all_attentions)
outputs = outputs + (all_attentions,)
return outputs # last hidden state, (presents), (all hidden_states), (attentions)
if return_tuple:
return tuple(v for v in [hidden_states, presents, all_hidden_states, all_attentions] if v is not None)
return BaseModelOutputWithPast(
last_hidden_state=hidden_states,
past_key_values=presents,
hidden_states=all_hidden_states,
attentions=all_attentions,
)
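For illustration, a short sketch of reusing the cached key/values returned above to continue decoding (the prompt text and single-token continuation are assumptions; the ``past`` argument name follows this file):
# Hedged sketch: feed only the new token together with the cached past to speed up sequential decoding.
from transformers import GPT2Model, GPT2Tokenizer

tokenizer = GPT2Tokenizer.from_pretrained("gpt2")
model = GPT2Model.from_pretrained("gpt2")

input_ids = tokenizer.encode("Hello, my dog", return_tensors="pt")
outputs = model(input_ids, use_cache=True)          # past_key_values is populated

next_token = tokenizer.encode(" is", return_tensors="pt")          # a single BPE token
continued = model(next_token, past=outputs.past_key_values)        # only the new token is processed
print(continued.last_hidden_state.shape)                           # (1, 1, hidden_size): just the new position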
@add_start_docstrings(
@@ -544,7 +584,12 @@ class GPT2LMHeadModel(GPT2PreTrainedModel):
return {"input_ids": input_ids, "past": past, "use_cache": kwargs["use_cache"]}
@add_start_docstrings_to_callable(GPT2_INPUTS_DOCSTRING)
@add_code_sample_docstrings(tokenizer_class=_TOKENIZER_FOR_DOC, checkpoint="gpt2")
@add_code_sample_docstrings(
tokenizer_class=_TOKENIZER_FOR_DOC,
checkpoint="ctrl",
output_type=CausalLMOutputWithPast,
config_class=_CONFIG_FOR_DOC,
)
def forward(
self,
input_ids=None,
@@ -558,6 +603,7 @@ class GPT2LMHeadModel(GPT2PreTrainedModel):
use_cache=None,
output_attentions=None,
output_hidden_states=None,
return_tuple=None,
):
r"""
labels (:obj:`torch.LongTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`, defaults to :obj:`None`):
@@ -566,28 +612,9 @@ class GPT2LMHeadModel(GPT2PreTrainedModel):
Indices are selected in ``[-100, 0, ..., config.vocab_size]``
All labels set to ``-100`` are ignored (masked), the loss is only
computed for labels in ``[0, ..., config.vocab_size]``
Return:
:obj:`tuple(torch.FloatTensor)` comprising various elements depending on the configuration (:class:`~transformers.GPT2Config`) and inputs:
loss (:obj:`torch.FloatTensor` of shape `(1,)`, `optional`, returned when ``labels`` is provided)
Language modeling loss.
prediction_scores (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length, config.vocab_size)`):
Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax).
past (:obj:`List[torch.FloatTensor]` of length :obj:`config.n_layers` with each tensor of shape :obj:`(2, batch_size, num_heads, sequence_length, embed_size_per_head)`):
Contains pre-computed hidden-states (key and values in the attention blocks).
Can be used (see `past` input) to speed up sequential decoding.
hidden_states (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``output_hidden_states=True`` is passed or when ``config.output_hidden_states=True``):
Tuple of :obj:`torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer)
of shape :obj:`(batch_size, sequence_length, hidden_size)`.
Hidden-states of the model at the output of each layer plus the initial embedding outputs.
attentions (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``output_attentions=True`` is passed or when ``config.output_attentions=True``):
Tuple of :obj:`torch.FloatTensor` (one for each layer) of shape
:obj:`(batch_size, num_heads, sequence_length, sequence_length)`.
Attentions weights after the attention softmax, used to compute the weighted average in the self-attention
heads.
""" """
return_tuple = return_tuple if return_tuple is not None else self.config.use_return_tuple
transformer_outputs = self.transformer(
input_ids,
past=past,
@@ -599,12 +626,13 @@ class GPT2LMHeadModel(GPT2PreTrainedModel):
use_cache=use_cache,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_tuple=return_tuple,
)
hidden_states = transformer_outputs[0]
lm_logits = self.lm_head(hidden_states)
outputs = (lm_logits,) + transformer_outputs[1:]
loss = None
if labels is not None:
# Shift so that tokens < n predict n
shift_logits = lm_logits[..., :-1, :].contiguous()
@@ -612,9 +640,18 @@ class GPT2LMHeadModel(GPT2PreTrainedModel):
# Flatten the tokens
loss_fct = CrossEntropyLoss()
loss = loss_fct(shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1))
outputs = (loss,) + outputs
return outputs # (loss), lm_logits, presents, (all hidden_states), (attentions)
if return_tuple:
output = (lm_logits,) + transformer_outputs[1:]
return ((loss,) + output) if loss is not None else output
return CausalLMOutputWithPast(
loss=loss,
logits=lm_logits,
past_key_values=transformer_outputs.past_key_values,
hidden_states=transformer_outputs.hidden_states,
attentions=transformer_outputs.attentions,
)
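For illustration, a usage sketch of the causal-LM head under the new output type (assumes the public ``gpt2`` checkpoint; the ``loss`` and ``logits`` field names are taken from the return statement above):
# Hedged sketch: passing labels yields a scalar loss field on the output object.
from transformers import GPT2LMHeadModel, GPT2Tokenizer

tokenizer = GPT2Tokenizer.from_pretrained("gpt2")
model = GPT2LMHeadModel.from_pretrained("gpt2")

input_ids = tokenizer.encode("Hello, my dog is cute", return_tensors="pt")
labels = input_ids.clone()                 # positions to ignore could be set to -100, per the labels docstring
outputs = model(input_ids, labels=labels)

print(outputs.loss)                        # language modeling loss; the shift is done inside forward
print(outputs.logits.shape)                # (batch_size, sequence_length, vocab_size)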
@add_start_docstrings(
@@ -639,6 +676,7 @@ class GPT2DoubleHeadsModel(GPT2PreTrainedModel):
return self.lm_head
@add_start_docstrings_to_callable(GPT2_INPUTS_DOCSTRING)
@replace_return_docstrings(output_type=GPT2DoubleHeadsModelOutput, config_class=_CONFIG_FOR_DOC)
def forward(
self,
input_ids=None,
@@ -654,6 +692,7 @@ class GPT2DoubleHeadsModel(GPT2PreTrainedModel):
use_cache=None,
output_attentions=None,
output_hidden_states=None,
return_tuple=None,
**kwargs
):
r"""
@@ -674,29 +713,6 @@ class GPT2DoubleHeadsModel(GPT2PreTrainedModel):
Used to hide legacy arguments that have been deprecated.
Return:
:obj:`tuple(torch.FloatTensor)` comprising various elements depending on the configuration (:class:`~transformers.GPT2Config`) and inputs:
lm_loss (:obj:`torch.FloatTensor` of shape :obj:`(1,)`, `optional`, returned when ``labels`` is provided):
Language modeling loss.
mc_loss (:obj:`torch.FloatTensor` of shape :obj:`(1,)`, `optional`, returned when :obj:`mc_labels` is provided):
Multiple choice classification loss.
lm_prediction_scores (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, num_choices, sequence_length, config.vocab_size)`):
Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax).
mc_prediction_scores (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, num_choices)`):
Prediction scores of the multiple choice classification head (scores for each choice before SoftMax).
past (:obj:`List[torch.FloatTensor]` of length :obj:`config.n_layers` with each tensor of shape :obj:`(2, batch_size, num_heads, sequence_length, embed_size_per_head)`):
Contains pre-computed hidden-states (key and values in the attention blocks).
Can be used (see `past` input) to speed up sequential decoding.
hidden_states (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``output_hidden_states=True`` is passed or when ``config.output_hidden_states=True``):
Tuple of :obj:`torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer)
of shape :obj:`(batch_size, sequence_length, hidden_size)`.
Hidden-states of the model at the output of each layer plus the initial embedding outputs.
attentions (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``output_attentions=True`` is passed or when ``config.output_attentions=True``):
Tuple of :obj:`torch.FloatTensor` (one for each layer) of shape
:obj:`(batch_size, num_heads, sequence_length, sequence_length)`.
Attentions weights after the attention softmax, used to compute the weighted average in the self-attention
heads.
Examples::
@@ -729,6 +745,7 @@ class GPT2DoubleHeadsModel(GPT2PreTrainedModel):
)
labels = kwargs.pop("lm_labels")
assert kwargs == {}, f"Unexpected keyword arguments: {list(kwargs.keys())}."
return_tuple = return_tuple if return_tuple is not None else self.config.use_return_tuple
transformer_outputs = self.transformer(
input_ids,
@@ -741,6 +758,7 @@ class GPT2DoubleHeadsModel(GPT2PreTrainedModel):
use_cache=use_cache,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_tuple=return_tuple,
)
hidden_states = transformer_outputs[0]
@@ -748,16 +766,29 @@ class GPT2DoubleHeadsModel(GPT2PreTrainedModel):
lm_logits = self.lm_head(hidden_states)
mc_logits = self.multiple_choice_head(hidden_states, mc_token_ids).squeeze(-1)
outputs = (lm_logits, mc_logits) + transformer_outputs[1:]
mc_loss = None
if mc_labels is not None:
loss_fct = CrossEntropyLoss()
loss = loss_fct(mc_logits.view(-1, mc_logits.size(-1)), mc_labels.view(-1))
outputs = (loss,) + outputs
mc_loss = loss_fct(mc_logits.view(-1, mc_logits.size(-1)), mc_labels.view(-1))
lm_loss = None
if labels is not None:
shift_logits = lm_logits[..., :-1, :].contiguous()
shift_labels = labels[..., 1:].contiguous()
loss_fct = CrossEntropyLoss()
loss = loss_fct(shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1))
outputs = (loss,) + outputs
return outputs # (lm loss), (mc loss), lm logits, mc logits, presents, (all hidden_states), (attentions)
lm_loss = loss_fct(shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1))
if return_tuple:
output = (lm_logits, mc_logits) + transformer_outputs[1:]
if mc_loss is not None:
output = (mc_loss,) + output
return ((lm_loss,) + output) if lm_loss is not None else output
return GPT2DoubleHeadsModelOutput(
lm_loss=lm_loss,
mc_loss=mc_loss,
lm_logits=lm_logits,
mc_logits=mc_logits,
past_key_values=transformer_outputs.past_key_values,
hidden_states=transformer_outputs.hidden_states,
attentions=transformer_outputs.attentions,
)
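For illustration, a double-heads usage sketch in the spirit of the ``Examples::`` section referenced above (the ``[CLS]`` token handling and the equal-length choices are assumptions made for brevity):
# Hedged sketch: both heads are read from the new output dataclass by field name.
import torch
from transformers import GPT2DoubleHeadsModel, GPT2Tokenizer

tokenizer = GPT2Tokenizer.from_pretrained("gpt2")
model = GPT2DoubleHeadsModel.from_pretrained("gpt2")

# Add a classification token for the multiple-choice head and resize the embeddings to match.
tokenizer.add_special_tokens({"cls_token": "[CLS]"})
model.resize_token_embeddings(len(tokenizer))

choices = ["Hello, my dog is cute [CLS]", "Hello, my cat is cute [CLS]"]
encoded = [tokenizer.encode(c) for c in choices]              # both choices tokenize to the same length here
input_ids = torch.tensor(encoded).unsqueeze(0)                # (batch_size=1, num_choices=2, sequence_length)
mc_token_ids = torch.tensor([[len(e) - 1 for e in encoded]])  # index of [CLS] in each choice

outputs = model(input_ids, mc_token_ids=mc_token_ids)
print(outputs.lm_logits.shape)   # (1, 2, sequence_length, vocab_size)
print(outputs.mc_logits.shape)   # (1, 2)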