Unverified commit 04ab2ca6, authored by Vasudev Gupta and committed by GitHub

add pooling layer support (#11439)

parent 30f06589
@@ -41,7 +41,6 @@ from ...modeling_outputs import (
     CausalLMOutputWithCrossAttentions,
     MaskedLMOutput,
     MultipleChoiceModelOutput,
-    QuestionAnsweringModelOutput,
     SequenceClassifierOutput,
     TokenClassifierOutput,
 )
@@ -1857,6 +1856,41 @@ class BigBirdForPreTrainingOutput(ModelOutput):
     attentions: Optional[Tuple[torch.FloatTensor]] = None


+@dataclass
+class BigBirdForQuestionAnsweringModelOutput(ModelOutput):
+    """
+    Base class for outputs of question answering models.
+
+    Args:
+        loss (:obj:`torch.FloatTensor` of shape :obj:`(1,)`, `optional`, returned when :obj:`labels` is provided):
+            Total span extraction loss is the sum of a Cross-Entropy for the start and end positions.
+        start_logits (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length)`):
+            Span-start scores (before SoftMax).
+        end_logits (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length)`):
+            Span-end scores (before SoftMax).
+        pooler_output (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, hidden_size)`):
+            Pooler output from BigBirdModel.
+        hidden_states (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``output_hidden_states=True`` is passed or when ``config.output_hidden_states=True``):
+            Tuple of :obj:`torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer)
+            of shape :obj:`(batch_size, sequence_length, hidden_size)`.
+
+            Hidden-states of the model at the output of each layer plus the initial embedding outputs.
+        attentions (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``output_attentions=True`` is passed or when ``config.output_attentions=True``):
+            Tuple of :obj:`torch.FloatTensor` (one for each layer) of shape :obj:`(batch_size, num_heads,
+            sequence_length, sequence_length)`.
+
+            Attention weights after the attention softmax, used to compute the weighted average in the self-attention
+            heads.
+    """
+
+    loss: Optional[torch.FloatTensor] = None
+    start_logits: torch.FloatTensor = None
+    end_logits: torch.FloatTensor = None
+    pooler_output: torch.FloatTensor = None
+    hidden_states: Optional[Tuple[torch.FloatTensor]] = None
+    attentions: Optional[Tuple[torch.FloatTensor]] = None
+
+
 @add_start_docstrings(
     "The bare BigBird Model transformer outputting raw hidden-states without any specific head on top.",
     BIG_BIRD_START_DOCSTRING,
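For context, the new output class inherits the dual dataclass/tuple behavior of transformers' ModelOutput: fields left at None (such as pooler_output when pooling is disabled) are dropped from the tuple view. A minimal sketch, assuming the import path used by transformers at the time of this PR:

    import torch
    from transformers.models.big_bird.modeling_big_bird import (
        BigBirdForQuestionAnsweringModelOutput,
    )

    # Construct the output with pooling disabled: pooler_output stays None.
    out = BigBirdForQuestionAnsweringModelOutput(
        start_logits=torch.zeros(2, 128),
        end_logits=torch.zeros(2, 128),
    )
    assert out.pooler_output is None
    # Tuple-style indexing skips None fields, so index 0 is start_logits here.
    assert out[0] is out.start_logits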
@@ -2852,14 +2886,14 @@ class BigBirdForQuestionAnsweringHead(nn.Module):
     BIG_BIRD_START_DOCSTRING,
 )
 class BigBirdForQuestionAnswering(BigBirdPreTrainedModel):
-    def __init__(self, config):
+    def __init__(self, config, add_pooling_layer=False):
         super().__init__(config)

         config.num_labels = 2
         self.num_labels = config.num_labels
         self.sep_token_id = config.sep_token_id

-        self.bert = BigBirdModel(config, add_pooling_layer=False)
+        self.bert = BigBirdModel(config, add_pooling_layer=add_pooling_layer)
         self.qa_classifier = BigBirdForQuestionAnsweringHead(config)

         self.init_weights()
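Because add_pooling_layer defaults to False, existing callers and checkpoints keep their prior behavior; opting in simply threads the flag through to the underlying BigBirdModel. A minimal sketch of opting in, using only the public transformers API:

    from transformers import BigBirdConfig, BigBirdForQuestionAnswering

    # Default config; add_pooling_layer=True attaches the BERT-style pooler
    # (dense + tanh over the first token) inside the wrapped BigBirdModel.
    config = BigBirdConfig()
    model = BigBirdForQuestionAnswering(config, add_pooling_layer=True)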
@@ -2868,7 +2902,7 @@ class BigBirdForQuestionAnswering(BigBirdPreTrainedModel):
     @add_code_sample_docstrings(
         tokenizer_class=_TOKENIZER_FOR_DOC,
         checkpoint="google/bigbird-base-trivia-itc",
-        output_type=QuestionAnsweringModelOutput,
+        output_type=BigBirdForQuestionAnsweringModelOutput,
         config_class=_CONFIG_FOR_DOC,
     )
     def forward(
@@ -2958,10 +2992,11 @@ class BigBirdForQuestionAnswering(BigBirdPreTrainedModel):
             output = (start_logits, end_logits) + outputs[2:]
             return ((total_loss,) + output) if total_loss is not None else output

-        return QuestionAnsweringModelOutput(
+        return BigBirdForQuestionAnsweringModelOutput(
             loss=total_loss,
             start_logits=start_logits,
             end_logits=end_logits,
+            pooler_output=outputs.pooler_output,
             hidden_states=outputs.hidden_states,
             attentions=outputs.attentions,
         )
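Downstream, callers can now read the pooled representation alongside the span logits. A usage sketch under the same assumptions, with the checkpoint name taken from the docstring above (from_pretrained forwards keyword arguments the config does not recognize, such as add_pooling_layer, on to the model constructor; the pooler weights may be newly initialized if the checkpoint lacks them):

    import torch
    from transformers import BigBirdTokenizer, BigBirdForQuestionAnswering

    tokenizer = BigBirdTokenizer.from_pretrained("google/bigbird-base-trivia-itc")
    model = BigBirdForQuestionAnswering.from_pretrained(
        "google/bigbird-base-trivia-itc", add_pooling_layer=True
    )

    inputs = tokenizer(
        "Who wrote Hamlet?", "William Shakespeare wrote Hamlet.", return_tensors="pt"
    )
    with torch.no_grad():
        outputs = model(**inputs)

    # Span prediction as before ...
    start = outputs.start_logits.argmax(dim=-1)  # (batch_size,)
    end = outputs.end_logits.argmax(dim=-1)      # (batch_size,)
    # ... plus the new pooled representation: (batch_size, hidden_size),
    # or None when add_pooling_layer=False (the default).
    pooled = outputs.pooler_output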