Unverified Commit 83b38fbe authored by peter-sk, committed by GitHub

GPTNeoXForQuestionAnswering (#23059)



* first draft - gives index error in question_answering.py

* maturing

* no labels

* pipeline should know about QA

* fixing checks

* formatting

* fixed docstring

* initial commit

* formatting

* adding the class to many places

* towards less unhappy checks

* nearly there

* and gpt neox for qa

* use right model

* forgot this one

* base_model_prefix is "gpt_neox" for GPTNeoX* models

* unnecessary stuff

* Update src/transformers/models/gpt_neox/modeling_gpt_neox.py
Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>

* format

* Update src/transformers/models/gpt_neox/modeling_gpt_neox.py
Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com>

* removed gpt2 stuff

---------
Co-authored-by: Prof. Peter Schneider-Kamp <jps@ordbogen.com>
Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>
Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com>
parent 510ad0a8
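
With the `question-answering` pipeline mapping added below, the new head can be exercised end to end. A minimal sketch, assuming any GPT-NeoX-architecture checkpoint; note that `qa_outputs` is randomly initialized until the model has been fine-tuned on a QA dataset, so the decoded answers are arbitrary until then:

```python
from transformers import pipeline

# "EleutherAI/gpt-neox-20b" is a base checkpoint used here only as a placeholder;
# its QA head is untrained, so answers/scores are not meaningful without fine-tuning.
qa = pipeline("question-answering", model="EleutherAI/gpt-neox-20b")
result = qa(
    question="What does the new head predict?",
    context="GPTNeoXForQuestionAnswering predicts the start and end positions of an answer span.",
)
print(result["answer"], result["score"])
```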
@@ -79,6 +79,11 @@ The `generate()` method can be used to generate text using GPT Neo model.
 [[autodoc]] GPTNeoXForCausalLM
     - forward

+## GPTNeoXForQuestionAnswering
+
+[[autodoc]] GPTNeoXForQuestionAnswering
+    - forward
+
 ## GPTNeoXForSequenceClassification

 [[autodoc]] GPTNeoXForSequenceClassification
......
@@ -31,7 +31,7 @@ The task illustrated in this tutorial is supported by the following model architectures:

 <!--This tip is automatically generated by `make fix-copies`, do not fill manually!-->

-[ALBERT](../model_doc/albert), [BART](../model_doc/bart), [BERT](../model_doc/bert), [BigBird](../model_doc/big_bird), [BigBird-Pegasus](../model_doc/bigbird_pegasus), [BLOOM](../model_doc/bloom), [CamemBERT](../model_doc/camembert), [CANINE](../model_doc/canine), [ConvBERT](../model_doc/convbert), [Data2VecText](../model_doc/data2vec-text), [DeBERTa](../model_doc/deberta), [DeBERTa-v2](../model_doc/deberta-v2), [DistilBERT](../model_doc/distilbert), [ELECTRA](../model_doc/electra), [ERNIE](../model_doc/ernie), [ErnieM](../model_doc/ernie_m), [FlauBERT](../model_doc/flaubert), [FNet](../model_doc/fnet), [Funnel Transformer](../model_doc/funnel), [OpenAI GPT-2](../model_doc/gpt2), [GPT Neo](../model_doc/gpt_neo), [GPT-J](../model_doc/gptj), [I-BERT](../model_doc/ibert), [LayoutLMv2](../model_doc/layoutlmv2), [LayoutLMv3](../model_doc/layoutlmv3), [LED](../model_doc/led), [LiLT](../model_doc/lilt), [Longformer](../model_doc/longformer), [LUKE](../model_doc/luke), [LXMERT](../model_doc/lxmert), [MarkupLM](../model_doc/markuplm), [mBART](../model_doc/mbart), [MEGA](../model_doc/mega), [Megatron-BERT](../model_doc/megatron-bert), [MobileBERT](../model_doc/mobilebert), [MPNet](../model_doc/mpnet), [MVP](../model_doc/mvp), [Nezha](../model_doc/nezha), [Nyströmformer](../model_doc/nystromformer), [OPT](../model_doc/opt), [QDQBert](../model_doc/qdqbert), [Reformer](../model_doc/reformer), [RemBERT](../model_doc/rembert), [RoBERTa](../model_doc/roberta), [RoBERTa-PreLayerNorm](../model_doc/roberta-prelayernorm), [RoCBert](../model_doc/roc_bert), [RoFormer](../model_doc/roformer), [Splinter](../model_doc/splinter), [SqueezeBERT](../model_doc/squeezebert), [XLM](../model_doc/xlm), [XLM-RoBERTa](../model_doc/xlm-roberta), [XLM-RoBERTa-XL](../model_doc/xlm-roberta-xl), [XLNet](../model_doc/xlnet), [X-MOD](../model_doc/xmod), [YOSO](../model_doc/yoso)
+[ALBERT](../model_doc/albert), [BART](../model_doc/bart), [BERT](../model_doc/bert), [BigBird](../model_doc/big_bird), [BigBird-Pegasus](../model_doc/bigbird_pegasus), [BLOOM](../model_doc/bloom), [CamemBERT](../model_doc/camembert), [CANINE](../model_doc/canine), [ConvBERT](../model_doc/convbert), [Data2VecText](../model_doc/data2vec-text), [DeBERTa](../model_doc/deberta), [DeBERTa-v2](../model_doc/deberta-v2), [DistilBERT](../model_doc/distilbert), [ELECTRA](../model_doc/electra), [ERNIE](../model_doc/ernie), [ErnieM](../model_doc/ernie_m), [FlauBERT](../model_doc/flaubert), [FNet](../model_doc/fnet), [Funnel Transformer](../model_doc/funnel), [OpenAI GPT-2](../model_doc/gpt2), [GPT Neo](../model_doc/gpt_neo), [GPT NeoX](../model_doc/gpt_neox), [GPT-J](../model_doc/gptj), [I-BERT](../model_doc/ibert), [LayoutLMv2](../model_doc/layoutlmv2), [LayoutLMv3](../model_doc/layoutlmv3), [LED](../model_doc/led), [LiLT](../model_doc/lilt), [Longformer](../model_doc/longformer), [LUKE](../model_doc/luke), [LXMERT](../model_doc/lxmert), [MarkupLM](../model_doc/markuplm), [mBART](../model_doc/mbart), [MEGA](../model_doc/mega), [Megatron-BERT](../model_doc/megatron-bert), [MobileBERT](../model_doc/mobilebert), [MPNet](../model_doc/mpnet), [MVP](../model_doc/mvp), [Nezha](../model_doc/nezha), [Nyströmformer](../model_doc/nystromformer), [OPT](../model_doc/opt), [QDQBert](../model_doc/qdqbert), [Reformer](../model_doc/reformer), [RemBERT](../model_doc/rembert), [RoBERTa](../model_doc/roberta), [RoBERTa-PreLayerNorm](../model_doc/roberta-prelayernorm), [RoCBert](../model_doc/roc_bert), [RoFormer](../model_doc/roformer), [Splinter](../model_doc/splinter), [SqueezeBERT](../model_doc/squeezebert), [XLM](../model_doc/xlm), [XLM-RoBERTa](../model_doc/xlm-roberta), [XLM-RoBERTa-XL](../model_doc/xlm-roberta-xl), [XLNet](../model_doc/xlnet), [X-MOD](../model_doc/xmod), [YOSO](../model_doc/yoso)

 <!--End of the generated tip-->
......
@@ -1702,6 +1702,7 @@ else:
         [
             "GPT_NEOX_PRETRAINED_MODEL_ARCHIVE_LIST",
             "GPTNeoXForCausalLM",
+            "GPTNeoXForQuestionAnswering",
             "GPTNeoXForSequenceClassification",
             "GPTNeoXForTokenClassification",
             "GPTNeoXLayer",
@@ -5245,6 +5246,7 @@ if TYPE_CHECKING:
     from .models.gpt_neox import (
         GPT_NEOX_PRETRAINED_MODEL_ARCHIVE_LIST,
         GPTNeoXForCausalLM,
+        GPTNeoXForQuestionAnswering,
         GPTNeoXForSequenceClassification,
         GPTNeoXForTokenClassification,
         GPTNeoXLayer,
......
@@ -737,6 +737,7 @@ MODEL_FOR_QUESTION_ANSWERING_MAPPING_NAMES = OrderedDict(
         ("funnel", "FunnelForQuestionAnswering"),
         ("gpt2", "GPT2ForQuestionAnswering"),
         ("gpt_neo", "GPTNeoForQuestionAnswering"),
+        ("gpt_neox", "GPTNeoXForQuestionAnswering"),
         ("gptj", "GPTJForQuestionAnswering"),
         ("ibert", "IBertForQuestionAnswering"),
         ("layoutlmv2", "LayoutLMv2ForQuestionAnswering"),
......
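
With the `gpt_neox` entry registered in `MODEL_FOR_QUESTION_ANSWERING_MAPPING_NAMES`, `AutoModelForQuestionAnswering` can now resolve a GPT-NeoX config to the new class. A quick sketch with a tiny, made-up config (the values are hypothetical, chosen only so this runs fast):

```python
from transformers import AutoModelForQuestionAnswering, GPTNeoXConfig

# Hypothetical tiny config, not a released checkpoint.
config = GPTNeoXConfig(
    vocab_size=1024, hidden_size=64, num_hidden_layers=2,
    num_attention_heads=4, intermediate_size=128,
)
model = AutoModelForQuestionAnswering.from_config(config)
print(type(model).__name__)  # GPTNeoXForQuestionAnswering
```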
@@ -52,7 +52,6 @@ from .configuration_gpt2 import GPT2Config
 logger = logging.get_logger(__name__)

 _CHECKPOINT_FOR_DOC = "gpt2"
-_REAL_CHECKPOINT_FOR_DOC = "gpt2"
 _CONFIG_FOR_DOC = "GPT2Config"

 GPT2_PRETRAINED_MODEL_ARCHIVE_LIST = [
@@ -1619,7 +1618,7 @@ class GPT2ForQuestionAnswering(GPT2PreTrainedModel):
         checkpoint=_CHECKPOINT_FOR_DOC,
         output_type=QuestionAnsweringModelOutput,
         config_class=_CONFIG_FOR_DOC,
-        real_checkpoint=_REAL_CHECKPOINT_FOR_DOC,
+        real_checkpoint=_CHECKPOINT_FOR_DOC,
     )
     def forward(
         self,
......
@@ -36,6 +36,7 @@ else:
     _import_structure["modeling_gpt_neox"] = [
         "GPT_NEOX_PRETRAINED_MODEL_ARCHIVE_LIST",
         "GPTNeoXForCausalLM",
+        "GPTNeoXForQuestionAnswering",
         "GPTNeoXForSequenceClassification",
         "GPTNeoXForTokenClassification",
         "GPTNeoXLayer",
@@ -64,6 +65,7 @@ if TYPE_CHECKING:
     from .modeling_gpt_neox import (
         GPT_NEOX_PRETRAINED_MODEL_ARCHIVE_LIST,
         GPTNeoXForCausalLM,
+        GPTNeoXForQuestionAnswering,
         GPTNeoXForSequenceClassification,
         GPTNeoXForTokenClassification,
         GPTNeoXLayer,
......
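
With both the top-level and submodule `_import_structure`/`TYPE_CHECKING` entries in place, the class is importable lazily from either path:

```python
# Both imports resolve to the same class once the entries above are merged.
from transformers import GPTNeoXForQuestionAnswering
from transformers.models.gpt_neox import GPTNeoXForQuestionAnswering  # noqa: F811
```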
@@ -31,6 +31,7 @@ from ...file_utils import (
 from ...modeling_outputs import (
     BaseModelOutputWithPast,
     CausalLMOutputWithPast,
+    QuestionAnsweringModelOutput,
     SequenceClassifierOutputWithPast,
     TokenClassifierOutput,
 )
@@ -955,3 +956,103 @@ class GPTNeoXForTokenClassification(GPTNeoXPreTrainedModel):
             hidden_states=outputs.hidden_states,
             attentions=outputs.attentions,
         )
+
+
+@add_start_docstrings(
+    """
+    The GPT-NeoX Model transformer with a span classification head on top for extractive question-answering tasks like
+    SQuAD (a linear layer on top of the hidden-states output to compute `span start logits` and `span end logits`).
+    """,
+    GPT_NEOX_START_DOCSTRING,
+)
+class GPTNeoXForQuestionAnswering(GPTNeoXPreTrainedModel):
+    _keys_to_ignore_on_load_missing = [r"h\.\d+\.attn\.masked_bias", r"h\.\d+\.attn\.bias", r"lm_head.weight"]
+
+    def __init__(self, config):
+        super().__init__(config)
+        self.num_labels = config.num_labels
+        self.gpt_neox = GPTNeoXModel(config)
+        self.qa_outputs = nn.Linear(config.hidden_size, 2)
+
+        # Initialize weights and apply final processing
+        self.post_init()
+
+    @add_start_docstrings_to_model_forward(GPT_NEOX_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
+    @add_code_sample_docstrings(
+        checkpoint=_CHECKPOINT_FOR_DOC,
+        output_type=QuestionAnsweringModelOutput,
+        config_class=_CONFIG_FOR_DOC,
+        real_checkpoint=_REAL_CHECKPOINT_FOR_DOC,
+    )
+    def forward(
+        self,
+        input_ids: Optional[torch.LongTensor] = None,
+        attention_mask: Optional[torch.FloatTensor] = None,
+        token_type_ids: Optional[torch.LongTensor] = None,
+        position_ids: Optional[torch.LongTensor] = None,
+        head_mask: Optional[torch.FloatTensor] = None,
+        inputs_embeds: Optional[torch.FloatTensor] = None,
+        start_positions: Optional[torch.LongTensor] = None,
+        end_positions: Optional[torch.LongTensor] = None,
+        output_attentions: Optional[bool] = None,
+        output_hidden_states: Optional[bool] = None,
+        return_dict: Optional[bool] = None,
+    ) -> Union[Tuple, QuestionAnsweringModelOutput]:
r"""
start_positions (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
Labels for position (index) of the start of the labelled span for computing the token classification loss.
Positions are clamped to the length of the sequence (`sequence_length`). Position outside of the sequence
are not taken into account for computing the loss.
end_positions (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
Labels for position (index) of the end of the labelled span for computing the token classification loss.
Positions are clamped to the length of the sequence (`sequence_length`). Position outside of the sequence
are not taken into account for computing the loss.
"""
+        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+        outputs = self.gpt_neox(
+            input_ids,
+            attention_mask=attention_mask,
+            position_ids=position_ids,
+            head_mask=head_mask,
+            inputs_embeds=inputs_embeds,
+            output_attentions=output_attentions,
+            output_hidden_states=output_hidden_states,
+            return_dict=return_dict,
+        )
+        sequence_output = outputs[0]
+
+        logits = self.qa_outputs(sequence_output)
+        start_logits, end_logits = logits.split(1, dim=-1)
+        start_logits = start_logits.squeeze(-1).contiguous()
+        end_logits = end_logits.squeeze(-1).contiguous()
+
+        total_loss = None
+        if start_positions is not None and end_positions is not None:
+            # If we are on multi-GPU, splitting adds a dimension; squeeze it
+            if len(start_positions.size()) > 1:
+                start_positions = start_positions.squeeze(-1).to(start_logits.device)
+            if len(end_positions.size()) > 1:
+                end_positions = end_positions.squeeze(-1).to(end_logits.device)
+            # Sometimes the start/end positions lie outside our model inputs; we ignore these terms
+            ignored_index = start_logits.size(1)
+            start_positions = start_positions.clamp(0, ignored_index)
+            end_positions = end_positions.clamp(0, ignored_index)
+
+            loss_fct = CrossEntropyLoss(ignore_index=ignored_index)
+            start_loss = loss_fct(start_logits, start_positions)
+            end_loss = loss_fct(end_logits, end_positions)
+            total_loss = (start_loss + end_loss) / 2
+
+        if not return_dict:
+            output = (start_logits, end_logits) + outputs[2:]
+            return ((total_loss,) + output) if total_loss is not None else output
+
+        return QuestionAnsweringModelOutput(
+            loss=total_loss,
+            start_logits=start_logits,
+            end_logits=end_logits,
+            hidden_states=outputs.hidden_states,
+            attentions=outputs.attentions,
+        )
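
The head returns per-token `start_logits` and `end_logits` rather than an answer string, so decoding a span takes an argmax over each. A minimal sketch; the Pythia checkpoint is a placeholder GPT-NeoX-architecture model, and without QA fine-tuning the decoded span is arbitrary:

```python
import torch
from transformers import AutoTokenizer, GPTNeoXForQuestionAnswering

name = "EleutherAI/pythia-70m"  # placeholder; qa_outputs is randomly initialized here
tokenizer = AutoTokenizer.from_pretrained(name)
model = GPTNeoXForQuestionAnswering.from_pretrained(name).eval()

inputs = tokenizer(
    "What does the head predict?",
    "The QA head predicts the start and end positions of the answer span.",
    return_tensors="pt",
)
with torch.no_grad():
    out = model(**inputs)

start = int(out.start_logits.argmax())                 # most likely span start
end = int(out.end_logits[0, start:].argmax()) + start  # constrain end >= start
print(tokenizer.decode(inputs["input_ids"][0, start : end + 1]))
```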
@@ -3336,6 +3336,13 @@ class GPTNeoXForCausalLM(metaclass=DummyObject):
         requires_backends(self, ["torch"])


+class GPTNeoXForQuestionAnswering(metaclass=DummyObject):
+    _backends = ["torch"]
+
+    def __init__(self, *args, **kwargs):
+        requires_backends(self, ["torch"])
+
+
 class GPTNeoXForSequenceClassification(metaclass=DummyObject):
     _backends = ["torch"]
......
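
The dummy mirrors the real class so that `from transformers import GPTNeoXForQuestionAnswering` still succeeds in an environment without torch; the failure is deferred to instantiation, with an actionable message. A sketch of the same pattern for a hypothetical class:

```python
from transformers.utils import DummyObject, requires_backends

class MyTorchOnlyModel(metaclass=DummyObject):  # hypothetical, not part of the library
    _backends = ["torch"]

    def __init__(self, *args, **kwargs):
        # Raises an ImportError with install instructions when torch is missing.
        requires_backends(self, ["torch"])
```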
@@ -31,6 +31,7 @@ if is_torch_available():
     from transformers import (
         GPTNeoXForCausalLM,
+        GPTNeoXForQuestionAnswering,
         GPTNeoXForSequenceClassification,
         GPTNeoXForTokenClassification,
         GPTNeoXModel,
@@ -149,6 +150,15 @@ class GPTNeoXModelTester:
         result = model(input_ids, attention_mask=input_mask, labels=token_labels)
         self.parent.assertEqual(result.logits.shape, (self.batch_size, self.seq_length, self.vocab_size))

+    def create_and_check_for_question_answering(self, config, input_ids, input_mask, token_labels):
+        config.num_labels = self.num_labels
+        model = GPTNeoXForQuestionAnswering(config)
+        model.to(torch_device)
+        model.eval()
+        result = model(input_ids, attention_mask=input_mask)
+        self.parent.assertEqual(result.start_logits.shape, (self.batch_size, self.seq_length))
+        self.parent.assertEqual(result.end_logits.shape, (self.batch_size, self.seq_length))
+
     def create_and_check_for_sequence_classification(self, config, input_ids, input_mask, token_labels):
         config.num_labels = self.num_labels
         model = GPTNeoXForSequenceClassification(config)
@@ -213,7 +223,13 @@ class GPTNeoXModelTester:
 @require_torch
 class GPTNeoXModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixin, unittest.TestCase):
     all_model_classes = (
-        (GPTNeoXModel, GPTNeoXForCausalLM, GPTNeoXForSequenceClassification, GPTNeoXForTokenClassification)
+        (
+            GPTNeoXModel,
+            GPTNeoXForCausalLM,
+            GPTNeoXForQuestionAnswering,
+            GPTNeoXForSequenceClassification,
+            GPTNeoXForTokenClassification,
+        )
         if is_torch_available()
         else ()
     )
@@ -221,6 +237,7 @@ class GPTNeoXModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixin, unittest.TestCase):
     pipeline_model_mapping = (
         {
             "feature-extraction": GPTNeoXModel,
+            "question-answering": GPTNeoXForQuestionAnswering,
             "text-classification": GPTNeoXForSequenceClassification,
             "token-classification": GPTNeoXForTokenClassification,
             "text-generation": GPTNeoXForCausalLM,
@@ -265,6 +282,10 @@ class GPTNeoXModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixin, unittest.TestCase):
         config_and_inputs = self.model_tester.prepare_config_and_inputs()
         self.model_tester.create_and_check_for_causal_lm(*config_and_inputs)

+    def test_model_for_question_answering(self):
+        config_and_inputs = self.model_tester.prepare_config_and_inputs()
+        self.model_tester.create_and_check_for_question_answering(*config_and_inputs)
+
     def test_model_for_sequence_classification(self):
         config_and_inputs = self.model_tester.prepare_config_and_inputs()
         self.model_tester.create_and_check_for_sequence_classification(*config_and_inputs)
......
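
Outside the test harness, the same shape contract can be checked directly with a tiny, hypothetical config (values chosen only for speed, not taken from any released model):

```python
import torch
from transformers import GPTNeoXConfig, GPTNeoXForQuestionAnswering

config = GPTNeoXConfig(
    vocab_size=512, hidden_size=32, num_hidden_layers=2,
    num_attention_heads=4, intermediate_size=64,
)
model = GPTNeoXForQuestionAnswering(config).eval()

batch_size, seq_length = 2, 10
input_ids = torch.randint(0, config.vocab_size, (batch_size, seq_length))
with torch.no_grad():
    result = model(input_ids)

# Mirrors the assertions in create_and_check_for_question_answering above.
assert result.start_logits.shape == (batch_size, seq_length)
assert result.end_logits.shape == (batch_size, seq_length)
```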