"docs/vscode:/vscode.git/clone" did not exist on "c4cc894086ba86fefbd265f9a80fc8220d2ee182"
Unverified commit edfd82f5 authored by Sylvain Gugger, committed by GitHub

Change model outputs types to self-document outputs (#5438)

* [WIP] Proposal for model outputs

* All Bert models

* Make CI green maybe?

* Fix ONNX test

* Isolate ModelOutput from pt and tf

* Formatting

* Add Electra models

* Auto-generate docstrings from outputs

* Add TF outputs

* Add some BERT models

* Revert TF side

* Remove last traces of TF changes

* Fail with a clear error message

* Add Albert and work through Bart

* Add CTRL and DistilBert

* Formatting

* Progress on Bart

* Renames and finish Bart

* Formatting

* Fix last test

* Add DPR

* Finish Electra and add FlauBERT

* Add GPT2

* Add Longformer

* Add MMBT

* Add MobileBert

* Add GPT

* Formatting

* Add Reformer

* Add Roberta

* Add T5

* Add Transformer XL

* Fix test

* Add XLM + fix XLMForTokenClassification

* Style + XLMRoberta

* Add XLNet

* Formatting

* Add doc of return_tuple arg
parent fa265230
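
From user code, the change reads roughly like the sketch below. This is illustrative only: the checkpoint name and the ``logits`` field name are assumptions based on the new output dataclasses, not something this commit message pins down.

from transformers import BertForSequenceClassification, BertTokenizer

tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")
model = BertForSequenceClassification.from_pretrained("bert-base-uncased")
inputs = tokenizer("Hello, world!", return_tensors="pt")

# After this change, models return a self-documenting dataclass by default,
# so outputs are read by name instead of by position.
outputs = model(**inputs)
logits = outputs.logits  # previously: logits = outputs[0]

# The new return_tuple argument restores the old plain-tuple behavior.
logits = model(**inputs, return_tuple=True)[0]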
@@ -17,6 +17,7 @@
 import inspect
 import logging
 import os
+from dataclasses import dataclass
 from typing import Callable, Dict, List, Optional, Tuple
 import torch
@@ -31,6 +32,7 @@ from .file_utils import (
     TF2_WEIGHTS_NAME,
     TF_WEIGHTS_NAME,
     WEIGHTS_NAME,
+    ModelOutput,
     cached_path,
     hf_bucket_url,
     is_remote_url,
@@ -941,6 +943,35 @@ class PoolerAnswerClass(nn.Module):
         return x
 
 
+@dataclass
+class SquadHeadOutput(ModelOutput):
+    """
+    Base class for outputs of question answering models using a :obj:`SQuADHead`.
+
+    Args:
+        loss (:obj:`torch.FloatTensor` of shape :obj:`(1,)`, `optional`, returned if both :obj:`start_positions` and :obj:`end_positions` are provided):
+            Classification loss as the sum of the start token, end token (and, if provided, is_impossible) classification losses.
+        start_top_log_probs (``torch.FloatTensor`` of shape ``(batch_size, config.start_n_top)``, `optional`, returned if ``start_positions`` or ``end_positions`` is not provided):
+            Log probabilities for the top ``config.start_n_top`` start token possibilities (beam-search).
+        start_top_index (``torch.LongTensor`` of shape ``(batch_size, config.start_n_top)``, `optional`, returned if ``start_positions`` or ``end_positions`` is not provided):
+            Indices for the top ``config.start_n_top`` start token possibilities (beam-search).
+        end_top_log_probs (``torch.FloatTensor`` of shape ``(batch_size, config.start_n_top * config.end_n_top)``, `optional`, returned if ``start_positions`` or ``end_positions`` is not provided):
+            Log probabilities for the top ``config.start_n_top * config.end_n_top`` end token possibilities (beam-search).
+        end_top_index (``torch.LongTensor`` of shape ``(batch_size, config.start_n_top * config.end_n_top)``, `optional`, returned if ``start_positions`` or ``end_positions`` is not provided):
+            Indices for the top ``config.start_n_top * config.end_n_top`` end token possibilities (beam-search).
+        cls_logits (``torch.FloatTensor`` of shape ``(batch_size,)``, `optional`, returned if ``start_positions`` or ``end_positions`` is not provided):
+            Log probabilities for the ``is_impossible`` label of the answers.
+    """
+
+    loss: Optional[torch.FloatTensor] = None
+    start_top_log_probs: Optional[torch.FloatTensor] = None
+    start_top_index: Optional[torch.LongTensor] = None
+    end_top_log_probs: Optional[torch.FloatTensor] = None
+    end_top_index: Optional[torch.LongTensor] = None
+    cls_logits: Optional[torch.FloatTensor] = None
+
+
 class SQuADHead(nn.Module):
     r""" A SQuAD head inspired by XLNet.
@@ -992,10 +1023,15 @@ class SQuADHead(nn.Module):
         self.answer_class = PoolerAnswerClass(config)
 
     def forward(
-        self, hidden_states, start_positions=None, end_positions=None, cls_index=None, is_impossible=None, p_mask=None,
+        self,
+        hidden_states,
+        start_positions=None,
+        end_positions=None,
+        cls_index=None,
+        is_impossible=None,
+        p_mask=None,
+        return_tuple=False,
     ):
-        outputs = ()
-
         start_logits = self.start_logits(hidden_states, p_mask=p_mask)
 
         if start_positions is not None and end_positions is not None:
@@ -1021,7 +1057,7 @@ class SQuADHead(nn.Module):
             # note(zhiliny): by default multiply the loss by 0.5 so that the scale is comparable to start_loss and end_loss
             total_loss += cls_loss * 0.5
 
-            outputs = (total_loss,) + outputs
+            return (total_loss,) if return_tuple else SquadHeadOutput(loss=total_loss)
 
         else:
             # during inference, compute the end logits based on beam search
@@ -1051,11 +1087,16 @@ class SQuADHead(nn.Module):
             start_states = torch.einsum("blh,bl->bh", hidden_states, start_log_probs)
             cls_logits = self.answer_class(hidden_states, start_states=start_states, cls_index=cls_index)
 
-            outputs = (start_top_log_probs, start_top_index, end_top_log_probs, end_top_index, cls_logits,) + outputs
-
-        # return start_top_log_probs, start_top_index, end_top_log_probs, end_top_index, cls_logits
-        # or (if labels are provided) (total_loss,)
-        return outputs
+            if return_tuple:
+                return (start_top_log_probs, start_top_index, end_top_log_probs, end_top_index, cls_logits)
+            else:
+                return SquadHeadOutput(
+                    start_top_log_probs=start_top_log_probs,
+                    start_top_index=start_top_index,
+                    end_top_log_probs=end_top_log_probs,
+                    end_top_index=end_top_index,
+                    cls_logits=cls_logits,
+                )
 
 
 class SequenceSummary(nn.Module):
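
Both branches of the new ``forward`` can be exercised directly. A sketch under the assumption that a default ``XLNetConfig`` is an acceptable config here (it supplies the ``hidden_size``, ``start_n_top`` and ``end_n_top`` attributes the head reads); shapes and positions are dummies:

import torch

from transformers import XLNetConfig
from transformers.modeling_utils import SQuADHead

config = XLNetConfig()
head = SQuADHead(config)

hidden_states = torch.randn(2, 16, config.hidden_size)  # (batch, seq_len, hidden)
start_positions = torch.tensor([3, 5])
end_positions = torch.tensor([7, 9])

# Training branch: with labels, the output carries only the loss.
out = head(hidden_states, start_positions=start_positions, end_positions=end_positions)
print(out.loss)

# Inference branch: without labels, the beam-search candidates come back by name.
out = head(hidden_states)
print(out.start_top_log_probs.shape)  # (batch, config.start_n_top)

# return_tuple=True restores the old tuple return type in both branches.
(loss,) = head(hidden_states, start_positions=start_positions, end_positions=end_positions, return_tuple=True)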
...
@@ -55,6 +55,10 @@ XLM_ROBERTA_START_DOCSTRING = r"""
         Check out the :meth:`~transformers.PreTrainedModel.from_pretrained` method to load the model weights.
     output_attentions (:obj:`bool`, `optional`, defaults to :obj:`None`):
         If set to ``True``, the attentions tensors of all attention layers are returned. See ``attentions`` under returned tensors for more detail.
+    output_hidden_states (:obj:`bool`, `optional`, defaults to :obj:`None`):
+        If set to ``True``, the hidden states of all layers are returned. See ``hidden_states`` under returned tensors for more detail.
+    return_tuple (:obj:`bool`, `optional`, defaults to :obj:`None`):
+        If set to ``True``, the output of the model will be a plain tuple instead of a ``dataclass``.
 """
...