Unverified commit af69360b, authored by APAVOU Clément, committed by GitHub

Add `OPTForQuestionAnswering` (#19402)

* Add `OPTForQuestionAnswering`

- added `OPTForQuestionAnswering` class based on `BloomForQuestionAnswering`
- added `OPTForQuestionAnswering` in common tests
- all common tests pass
- make fixup done

* added docstrings for OPTForQuestionAnswering

* Fix docstrings for OPTForQuestionAnswering
parent ba71bf4c
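
For orientation before the file-by-file diff, here is a minimal sketch of how the new head can be exercised once this commit is installed. The question/context pair is illustrative, and since no fine-tuned OPT QA checkpoint ships with this PR, the new `qa_outputs` layer is randomly initialized when loading a base checkpoint:

```python
# Minimal sketch (illustrative inputs). Loading a base OPT checkpoint leaves
# the new qa_outputs layer randomly initialized, so the decoded span is only
# meaningful after fine-tuning on a QA dataset such as SQuAD.
import torch
from transformers import GPT2Tokenizer, OPTForQuestionAnswering

tokenizer = GPT2Tokenizer.from_pretrained("facebook/opt-350m")
model = OPTForQuestionAnswering.from_pretrained("facebook/opt-350m")

question, context = "Who was Jim Henson?", "Jim Henson was a nice puppet"
inputs = tokenizer(question, context, return_tensors="pt")
with torch.no_grad():
    outputs = model(**inputs)

# Decode the highest-scoring start/end span.
start = outputs.start_logits.argmax()
end = outputs.end_logits.argmax()
print(tokenizer.decode(inputs.input_ids[0, start : end + 1]))
```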
docs/source/en/model_doc/opt.mdx
@@ -59,6 +59,11 @@ The original code can be found [here](https://github.com/facebookresearch/metase
[[autodoc]] OPTForSequenceClassification
    - forward

## OPTForQuestionAnswering

[[autodoc]] OPTForQuestionAnswering
    - forward

## FlaxOPTModel

[[autodoc]] FlaxOPTModel
...
src/transformers/__init__.py
@@ -1661,6 +1661,7 @@ else:
            "OPTModel",
            "OPTPreTrainedModel",
            "OPTForSequenceClassification",
            "OPTForQuestionAnswering",
        ]
    )
    _import_structure["models.owlvit"].extend(
@@ -4408,6 +4409,7 @@ if TYPE_CHECKING:
    from .models.opt import (
        OPT_PRETRAINED_MODEL_ARCHIVE_LIST,
        OPTForCausalLM,
        OPTForQuestionAnswering,
        OPTForSequenceClassification,
        OPTModel,
        OPTPreTrainedModel,
...
src/transformers/models/auto/modeling_auto.py
@@ -611,6 +611,7 @@ MODEL_FOR_QUESTION_ANSWERING_MAPPING_NAMES = OrderedDict(
        ("mvp", "MvpForQuestionAnswering"),
        ("nezha", "NezhaForQuestionAnswering"),
        ("nystromformer", "NystromformerForQuestionAnswering"),
        ("opt", "OPTForQuestionAnswering"),
        ("qdqbert", "QDQBertForQuestionAnswering"),
        ("reformer", "ReformerForQuestionAnswering"),
        ("rembert", "RemBertForQuestionAnswering"),
...
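
With the `("opt", "OPTForQuestionAnswering")` entry registered above, the auto class can dispatch OPT configs to the new head. A small sketch (checkpoint choice illustrative; the QA head weights are newly initialized):

```python
# The mapping entry above lets AutoModelForQuestionAnswering resolve an
# "opt" model_type to the new class (QA head weights newly initialized).
from transformers import AutoModelForQuestionAnswering

model = AutoModelForQuestionAnswering.from_pretrained("facebook/opt-125m")
print(type(model).__name__)  # -> OPTForQuestionAnswering
```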
src/transformers/models/opt/__init__.py
@@ -41,6 +41,7 @@ else:
        "OPTModel",
        "OPTPreTrainedModel",
        "OPTForSequenceClassification",
        "OPTForQuestionAnswering",
    ]

try:
@@ -76,6 +77,7 @@ if TYPE_CHECKING:
        from .modeling_opt import (
            OPT_PRETRAINED_MODEL_ARCHIVE_LIST,
            OPTForCausalLM,
            OPTForQuestionAnswering,
            OPTForSequenceClassification,
            OPTModel,
            OPTPreTrainedModel,
...
src/transformers/models/opt/modeling_opt.py
@@ -22,7 +22,12 @@ from torch import nn
from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss

from ...activations import ACT2FN
from ...modeling_outputs import (
    BaseModelOutputWithPast,
    CausalLMOutputWithPast,
    QuestionAnsweringModelOutput,
    SequenceClassifierOutputWithPast,
)
from ...modeling_utils import PreTrainedModel
from ...utils import (
    add_code_sample_docstrings,
@@ -48,6 +53,11 @@ _CHECKPOINT_FOR_SEQUENCE_CLASSIFICATION = "ArthurZ/opt-350m-dummy-sc"
_SEQ_CLASS_EXPECTED_LOSS = 1.71
_SEQ_CLASS_EXPECTED_OUTPUT = "'LABEL_0'"

# QuestionAnswering docstring
_QA_EXPECTED_OUTPUT = "'a nice puppet'"
_QA_EXPECTED_LOSS = 7.41
_QA_TARGET_START_INDEX = 14
_QA_TARGET_END_INDEX = 15

OPT_PRETRAINED_MODEL_ARCHIVE_LIST = [
    "facebook/opt-125m",
@@ -1109,3 +1119,112 @@ class OPTForSequenceClassification(OPTPreTrainedModel):
    def set_input_embeddings(self, value):
        self.model.decoder.embed_tokens = value

@add_start_docstrings(
    """
    The OPT Model transformer with a span classification head on top for extractive question-answering tasks like SQuAD
    (a linear layer on top of the hidden-states output to compute `span start logits` and `span end logits`).
    """,
    OPT_START_DOCSTRING,
)
class OPTForQuestionAnswering(OPTPreTrainedModel):
    _keys_to_ignore_on_load_missing = [r"lm_head.weight"]

    def __init__(self, config: OPTConfig):
        super().__init__(config)
        self.model = OPTModel(config)
        self.qa_outputs = nn.Linear(config.word_embed_proj_dim, 2)

        # Initialize weights and apply final processing
        self.post_init()

    @add_start_docstrings_to_model_forward(OPT_INPUTS_DOCSTRING)
    @add_code_sample_docstrings(
        processor_class=_TOKENIZER_FOR_DOC,
        checkpoint=_CHECKPOINT_FOR_DOC,
        output_type=QuestionAnsweringModelOutput,
        config_class=_CONFIG_FOR_DOC,
        qa_target_start_index=_QA_TARGET_START_INDEX,
        qa_target_end_index=_QA_TARGET_END_INDEX,
        expected_output=_QA_EXPECTED_OUTPUT,
        expected_loss=_QA_EXPECTED_LOSS,
    )
    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        attention_mask: Optional[torch.FloatTensor] = None,
        head_mask: Optional[torch.FloatTensor] = None,
        past_key_values: Optional[Tuple[Tuple[torch.Tensor]]] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        start_positions: Optional[torch.LongTensor] = None,
        end_positions: Optional[torch.LongTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[Tuple, QuestionAnsweringModelOutput]:
        r"""
        start_positions (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
            Labels for position (index) of the start of the labelled span for computing the token classification loss.
            Positions are clamped to the length of the sequence (`sequence_length`). Positions outside of the sequence
            are not taken into account for computing the loss.
        end_positions (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
            Labels for position (index) of the end of the labelled span for computing the token classification loss.
            Positions are clamped to the length of the sequence (`sequence_length`). Positions outside of the sequence
            are not taken into account for computing the loss.
        """
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        transformer_outputs = self.model(
            input_ids,
            past_key_values=past_key_values,
            attention_mask=attention_mask,
            head_mask=head_mask,
            inputs_embeds=inputs_embeds,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        hidden_states = transformer_outputs[0]

        logits = self.qa_outputs(hidden_states)
        start_logits, end_logits = logits.split(1, dim=-1)
        start_logits = start_logits.squeeze(-1).contiguous()
        end_logits = end_logits.squeeze(-1).contiguous()

        total_loss = None
        if start_positions is not None and end_positions is not None:
            # If we are on multi-GPU, split adds an extra dimension
            if len(start_positions.size()) > 1:
                start_positions = start_positions.squeeze(-1)
            if len(end_positions.size()) > 1:
                end_positions = end_positions.squeeze(-1)
            # Sometimes the start/end positions lie outside the model inputs; ignore these terms
            ignored_index = start_logits.size(1)
            start_positions = start_positions.clamp(0, ignored_index)
            end_positions = end_positions.clamp(0, ignored_index)

            loss_fct = CrossEntropyLoss(ignore_index=ignored_index)
            start_loss = loss_fct(start_logits, start_positions)
            end_loss = loss_fct(end_logits, end_positions)
            total_loss = (start_loss + end_loss) / 2

        if not return_dict:
            output = (start_logits, end_logits) + transformer_outputs[2:]
            return ((total_loss,) + output) if total_loss is not None else output

        return QuestionAnsweringModelOutput(
            loss=total_loss,
            start_logits=start_logits,
            end_logits=end_logits,
            hidden_states=transformer_outputs.hidden_states,
            attentions=transformer_outputs.attentions,
        )

    def get_input_embeddings(self):
        return self.model.decoder.embed_tokens

    def set_input_embeddings(self, value):
        self.model.decoder.embed_tokens = value
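
Training-time behavior of the `forward` above (averaged start/end cross-entropy) can be sanity-checked with a toy config; all dimensions below are illustrative, not tuned values:

```python
# Toy-config sketch of the loss path in forward(): passing start/end
# positions returns the averaged start/end cross-entropy loss.
import torch
from transformers import OPTConfig, OPTForQuestionAnswering

config = OPTConfig(
    vocab_size=99,
    hidden_size=16,
    num_hidden_layers=2,
    num_attention_heads=4,
    ffn_dim=32,
    max_position_embeddings=64,
    word_embed_proj_dim=16,  # qa_outputs maps this dim to 2 (start/end logits)
)
model = OPTForQuestionAnswering(config)

input_ids = torch.randint(0, 99, (2, 10))
outputs = model(
    input_ids,
    start_positions=torch.tensor([1, 3]),
    end_positions=torch.tensor([4, 6]),
)
print(outputs.loss, outputs.start_logits.shape)  # scalar loss, (2, 10)
```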
src/transformers/utils/dummy_pt_objects.py
@@ -3714,6 +3714,13 @@ class OPTForCausalLM(metaclass=DummyObject):
        requires_backends(self, ["torch"])


class OPTForQuestionAnswering(metaclass=DummyObject):
    _backends = ["torch"]

    def __init__(self, *args, **kwargs):
        requires_backends(self, ["torch"])


class OPTForSequenceClassification(metaclass=DummyObject):
    _backends = ["torch"]
...
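
The dummy object mirrors the real class for installs without PyTorch; a sketch of the guard it provides (in an environment where torch is available, the real class is exported from `transformers` instead):

```python
# Sketch: without torch installed, the top-level name resolves to this dummy,
# and instantiating it raises an ImportError via requires_backends.
# (With torch available, this constructs silently and is never used.)
from transformers.utils.dummy_pt_objects import OPTForQuestionAnswering

try:
    OPTForQuestionAnswering()
except ImportError as e:
    print(e)  # explains that PyTorch is required
```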
tests/models/opt/test_modeling_opt.py
@@ -32,7 +32,13 @@ from ...test_modeling_common import ModelTesterMixin, ids_tensor
if is_torch_available():
    import torch

    from transformers import (
        GPT2Tokenizer,
        OPTForCausalLM,
        OPTForQuestionAnswering,
        OPTForSequenceClassification,
        OPTModel,
    )


def prepare_opt_inputs_dict(
@@ -178,7 +184,11 @@ class OPTModelTester:
@require_torch
class OPTModelTest(ModelTesterMixin, GenerationTesterMixin, unittest.TestCase):
    all_model_classes = (
        (OPTModel, OPTForCausalLM, OPTForSequenceClassification, OPTForQuestionAnswering)
        if is_torch_available()
        else ()
    )
    all_generative_model_classes = (OPTForCausalLM,) if is_torch_available() else ()
    is_encoder_decoder = False
    fx_compatible = True
...
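
Adding the class to `all_model_classes` is what opts it into the common test suite. To run just these tests locally, something like the following should work (the test path assumes the repository layout at this commit):

```python
# Run only the OPT model tests, which now exercise OPTForQuestionAnswering
# through ModelTesterMixin (path assumes this commit's repo layout).
import pytest

pytest.main(["tests/models/opt/test_modeling_opt.py", "-q"])
```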