Unverified commit 614e191c authored by peter-sk, committed by GitHub

added GPTNeoXForTokenClassification (#23002)



* initial commit

* added GPTNeoXForTokenClassification

* typo

* doc
fixed an extra comma that had turned a value into a tuple

* unified variable names
fixed the forward call

* classifier_dropout is in config
Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>

---------
Co-authored-by: Prof. Peter Schneider-Kamp <jps@ordbogen.com>
Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>
parent 1933231a
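For orientation before the diff, a minimal usage sketch of what the new head enables. The checkpoint name is hypothetical; an un-fine-tuned GPT-NeoX checkpoint would load with a randomly initialized classification head.

```python
from transformers import pipeline

# Hypothetical fine-tuned checkpoint; substitute any GPT-NeoX model trained for token classification.
token_classifier = pipeline("token-classification", model="your-org/gpt-neox-ner")
print(token_classifier("Hugging Face is based in New York City."))
```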
@@ -82,4 +82,9 @@ The `generate()` method can be used to generate text using GPT Neo model.
## GPTNeoXForSequenceClassification
[[autodoc]] GPTNeoXForSequenceClassification
- forward
\ No newline at end of file
## GPTNeoXForTokenClassification
[[autodoc]] GPTNeoXForTokenClassification
- forward
@@ -28,7 +28,7 @@ The task illustrated in this tutorial is supported by the following model architectures:
<!--This tip is automatically generated by `make fix-copies`, do not fill manually!-->
[ALBERT](../model_doc/albert), [BERT](../model_doc/bert), [BigBird](../model_doc/big_bird), [BioGpt](../model_doc/biogpt), [BLOOM](../model_doc/bloom), [CamemBERT](../model_doc/camembert), [CANINE](../model_doc/canine), [ConvBERT](../model_doc/convbert), [Data2VecText](../model_doc/data2vec-text), [DeBERTa](../model_doc/deberta), [DeBERTa-v2](../model_doc/deberta-v2), [DistilBERT](../model_doc/distilbert), [ELECTRA](../model_doc/electra), [ERNIE](../model_doc/ernie), [ErnieM](../model_doc/ernie_m), [ESM](../model_doc/esm), [FlauBERT](../model_doc/flaubert), [FNet](../model_doc/fnet), [Funnel Transformer](../model_doc/funnel), [GPT-Sw3](../model_doc/gpt-sw3), [OpenAI GPT-2](../model_doc/gpt2), [GPTBigCode](../model_doc/gpt_bigcode), [GPT NeoX](../model_doc/gpt_neox), [I-BERT](../model_doc/ibert), [LayoutLM](../model_doc/layoutlm), [LayoutLMv2](../model_doc/layoutlmv2), [LayoutLMv3](../model_doc/layoutlmv3), [LiLT](../model_doc/lilt), [Longformer](../model_doc/longformer), [LUKE](../model_doc/luke), [MarkupLM](../model_doc/markuplm), [MEGA](../model_doc/mega), [Megatron-BERT](../model_doc/megatron-bert), [MobileBERT](../model_doc/mobilebert), [MPNet](../model_doc/mpnet), [Nezha](../model_doc/nezha), [Nyströmformer](../model_doc/nystromformer), [QDQBert](../model_doc/qdqbert), [RemBERT](../model_doc/rembert), [RoBERTa](../model_doc/roberta), [RoBERTa-PreLayerNorm](../model_doc/roberta-prelayernorm), [RoCBert](../model_doc/roc_bert), [RoFormer](../model_doc/roformer), [SqueezeBERT](../model_doc/squeezebert), [XLM](../model_doc/xlm), [XLM-RoBERTa](../model_doc/xlm-roberta), [XLM-RoBERTa-XL](../model_doc/xlm-roberta-xl), [XLNet](../model_doc/xlnet), [X-MOD](../model_doc/xmod), [YOSO](../model_doc/yoso)
<!--End of the generated tip-->
...
@@ -1697,6 +1697,7 @@ else:
"GPT_NEOX_PRETRAINED_MODEL_ARCHIVE_LIST",
"GPTNeoXForCausalLM",
"GPTNeoXForSequenceClassification",
"GPTNeoXForTokenClassification",
"GPTNeoXLayer", "GPTNeoXLayer",
"GPTNeoXModel", "GPTNeoXModel",
"GPTNeoXPreTrainedModel", "GPTNeoXPreTrainedModel",
@@ -5230,6 +5231,7 @@ if TYPE_CHECKING:
GPT_NEOX_PRETRAINED_MODEL_ARCHIVE_LIST,
GPTNeoXForCausalLM,
GPTNeoXForSequenceClassification,
GPTNeoXForTokenClassification,
GPTNeoXLayer,
GPTNeoXModel,
GPTNeoXPreTrainedModel,
...
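With the exports above, the new class becomes importable from the top level of the package. A small sketch, assuming the `torch` backend is installed:

```python
# Top-level import enabled by the __init__ additions above.
from transformers import GPTNeoXForTokenClassification

print(GPTNeoXForTokenClassification.__module__)
# transformers.models.gpt_neox.modeling_gpt_neox
```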
@@ -814,6 +814,7 @@ MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING_NAMES = OrderedDict(
("gpt-sw3", "GPT2ForTokenClassification"),
("gpt2", "GPT2ForTokenClassification"),
("gpt_bigcode", "GPTBigCodeForTokenClassification"),
("gpt_neox", "GPTNeoXForTokenClassification"),
("ibert", "IBertForTokenClassification"), ("ibert", "IBertForTokenClassification"),
("layoutlm", "LayoutLMForTokenClassification"), ("layoutlm", "LayoutLMForTokenClassification"),
("layoutlmv2", "LayoutLMv2ForTokenClassification"), ("layoutlmv2", "LayoutLMv2ForTokenClassification"),
...
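With the `gpt_neox` entry registered in `MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING_NAMES`, the Auto classes can resolve the new head from a config alone. A minimal sketch with arbitrary small sizes (no weights are downloaded):

```python
from transformers import AutoModelForTokenClassification, GPTNeoXConfig

# Tiny, randomly initialized configuration; all sizes here are illustrative.
config = GPTNeoXConfig(
    vocab_size=1024,
    hidden_size=64,
    num_hidden_layers=2,
    num_attention_heads=4,
    intermediate_size=128,
    num_labels=5,
)
model = AutoModelForTokenClassification.from_config(config)
print(type(model).__name__)  # GPTNeoXForTokenClassification
```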
@@ -37,6 +37,7 @@ else:
"GPT_NEOX_PRETRAINED_MODEL_ARCHIVE_LIST",
"GPTNeoXForCausalLM",
"GPTNeoXForSequenceClassification",
"GPTNeoXForTokenClassification",
"GPTNeoXLayer", "GPTNeoXLayer",
"GPTNeoXModel", "GPTNeoXModel",
"GPTNeoXPreTrainedModel", "GPTNeoXPreTrainedModel",
@@ -64,6 +65,7 @@ if TYPE_CHECKING:
GPT_NEOX_PRETRAINED_MODEL_ARCHIVE_LIST,
GPTNeoXForCausalLM,
GPTNeoXForSequenceClassification,
GPTNeoXForTokenClassification,
GPTNeoXLayer,
GPTNeoXModel,
GPTNeoXPreTrainedModel,
...
@@ -56,6 +56,10 @@ class GPTNeoXConfig(PretrainedConfig):
percentage of hidden dimensions to allocate to rotary embeddings
rotary_emb_base (`int`, *optional*, defaults to 10000)
base for computing rotary embeddings frequency
classifier_dropout (`float`, *optional*, defaults to 0.1):
Argument used when doing token classification, used in the model [`GPTNeoXForTokenClassification`].
The dropout ratio for the hidden layer.
max_position_embeddings (`int`, *optional*, defaults to 2048):
The maximum sequence length that this model might ever be used with. Typically set this to something large
just in case (e.g., 512 or 1024 or 2048).
@@ -95,6 +99,7 @@ class GPTNeoXConfig(PretrainedConfig):
hidden_act="gelu",
rotary_pct=0.25,
rotary_emb_base=10000,
classifier_dropout=0.1,
max_position_embeddings=2048,
initializer_range=0.02,
layer_norm_eps=1e-5,
@@ -115,6 +120,7 @@ class GPTNeoXConfig(PretrainedConfig):
self.hidden_act = hidden_act
self.rotary_pct = rotary_pct
self.rotary_emb_base = rotary_emb_base
self.classifier_dropout = classifier_dropout
self.initializer_range = initializer_range
self.layer_norm_eps = layer_norm_eps
self.use_cache = use_cache
...
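Because `classifier_dropout` is now a regular config field, the dropout applied before the token-classification head can be overridden like any other setting, for example:

```python
from transformers import GPTNeoXConfig

# Override the 0.1 default; the value feeds the nn.Dropout layer of the new head.
config = GPTNeoXConfig(classifier_dropout=0.2)
print(config.classifier_dropout)  # 0.2
```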
@@ -28,7 +28,12 @@ from ...file_utils import (
add_start_docstrings_to_model_forward,
replace_return_docstrings,
)
from ...modeling_outputs import (
BaseModelOutputWithPast,
CausalLMOutputWithPast,
SequenceClassifierOutputWithPast,
TokenClassifierOutput,
)
from ...modeling_utils import PreTrainedModel
from ...utils import logging
from .configuration_gpt_neox import GPTNeoXConfig
@@ -873,3 +878,80 @@ class GPTNeoXForSequenceClassification(GPTNeoXPreTrainedModel):
hidden_states=outputs.hidden_states,
attentions=outputs.attentions,
)
class GPTNeoXForTokenClassification(GPTNeoXPreTrainedModel):
def __init__(self, config):
super().__init__(config)
self.num_labels = config.num_labels
self.gpt_neox = GPTNeoXModel(config)
self.dropout = nn.Dropout(config.classifier_dropout)
self.classifier = nn.Linear(config.hidden_size, config.num_labels)
# Initialize weights and apply final processing
self.post_init()
@add_start_docstrings_to_model_forward(GPT_NEOX_INPUTS_DOCSTRING)
@add_code_sample_docstrings(
checkpoint="LarsJonasson/pythia-410m-deduped-sft-swedish",
output_type=TokenClassifierOutput,
config_class=_CONFIG_FOR_DOC,
expected_loss=0.25,
)
def forward(
self,
input_ids: Optional[torch.LongTensor] = None,
past_key_values: Optional[Tuple[Tuple[torch.Tensor]]] = None,
attention_mask: Optional[torch.FloatTensor] = None,
token_type_ids: Optional[torch.LongTensor] = None,
position_ids: Optional[torch.LongTensor] = None,
head_mask: Optional[torch.FloatTensor] = None,
inputs_embeds: Optional[torch.FloatTensor] = None,
labels: Optional[torch.LongTensor] = None,
use_cache: Optional[bool] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
) -> Union[Tuple, TokenClassifierOutput]:
r"""
labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
Labels for computing the token classification loss. Indices should be in `[0, ..., config.num_labels - 1]`.
"""
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
outputs = self.gpt_neox(
input_ids,
past_key_values=past_key_values,
attention_mask=attention_mask,
position_ids=position_ids,
head_mask=head_mask,
inputs_embeds=inputs_embeds,
use_cache=use_cache,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
)
hidden_states = outputs[0]
hidden_states = self.dropout(hidden_states)
logits = self.classifier(hidden_states)
loss = None
if labels is not None:
labels = labels.to(logits.device)
loss_fct = CrossEntropyLoss()
loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
if not return_dict:
output = (logits,) + outputs[2:]
return ((loss,) + output) if loss is not None else output
return TokenClassifierOutput(
loss=loss,
logits=logits,
hidden_states=outputs.hidden_states,
attentions=outputs.attentions,
)
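For reference, a minimal forward pass through the new head with a tiny, randomly initialized config (mirroring the shape check in the test below): logits have shape `(batch_size, sequence_length, num_labels)`, and passing `labels` yields a cross-entropy loss.

```python
import torch
from transformers import GPTNeoXConfig, GPTNeoXForTokenClassification

# Arbitrary small sizes so the example runs quickly; no pretrained weights involved.
config = GPTNeoXConfig(
    vocab_size=1024, hidden_size=64, num_hidden_layers=2,
    num_attention_heads=4, intermediate_size=128, num_labels=3,
)
model = GPTNeoXForTokenClassification(config).eval()

input_ids = torch.randint(0, config.vocab_size, (2, 10))
labels = torch.randint(0, config.num_labels, (2, 10))

with torch.no_grad():
    outputs = model(input_ids, labels=labels)

print(outputs.logits.shape)  # torch.Size([2, 10, 3])
print(outputs.loss)          # scalar cross-entropy loss
```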
@@ -3308,6 +3308,13 @@ class GPTNeoXForSequenceClassification(metaclass=DummyObject):
requires_backends(self, ["torch"])
class GPTNeoXForTokenClassification(metaclass=DummyObject):
_backends = ["torch"]
def __init__(self, *args, **kwargs):
requires_backends(self, ["torch"])
class GPTNeoXLayer(metaclass=DummyObject):
_backends = ["torch"]
...
@@ -29,7 +29,12 @@ from ...test_pipeline_mixin import PipelineTesterMixin
if is_torch_available():
import torch
from transformers import (
GPTNeoXForCausalLM,
GPTNeoXForSequenceClassification,
GPTNeoXForTokenClassification,
GPTNeoXModel,
)
class GPTNeoXModelTester:
@@ -153,6 +158,14 @@ class GPTNeoXModelTester:
result = model(input_ids, attention_mask=input_mask, labels=sequence_labels)
self.parent.assertEqual(result.logits.shape, (self.batch_size, self.num_labels))
def create_and_check_for_token_classification(self, config, input_ids, input_mask, token_labels):
config.num_labels = self.num_labels
model = GPTNeoXForTokenClassification(config)
model.to(torch_device)
model.eval()
result = model(input_ids, attention_mask=input_mask, labels=token_labels)
self.parent.assertEqual(result.logits.shape, (self.batch_size, self.seq_length, self.num_labels))
def create_and_check_decoder_model_past_large_inputs(self, config, input_ids, input_mask):
config.is_decoder = True
model = GPTNeoXForCausalLM(config=config)
@@ -200,13 +213,16 @@ class GPTNeoXModelTester:
@require_torch
class GPTNeoXModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixin, unittest.TestCase):
all_model_classes = (
(GPTNeoXModel, GPTNeoXForCausalLM, GPTNeoXForSequenceClassification, GPTNeoXForTokenClassification)
if is_torch_available()
else ()
)
all_generative_model_classes = (GPTNeoXForCausalLM,) if is_torch_available() else ()
pipeline_model_mapping = (
{
"feature-extraction": GPTNeoXModel,
"text-classification": GPTNeoXForSequenceClassification,
"token-classification": GPTNeoXForTokenClassification,
"text-generation": GPTNeoXForCausalLM, "text-generation": GPTNeoXForCausalLM,
"zero-shot": GPTNeoXForSequenceClassification, "zero-shot": GPTNeoXForSequenceClassification,
} }
@@ -253,6 +269,10 @@ class GPTNeoXModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixin, unittest.TestCase):
config_and_inputs = self.model_tester.prepare_config_and_inputs()
self.model_tester.create_and_check_for_sequence_classification(*config_and_inputs)
def test_model_for_token_classification(self):
config_and_inputs = self.model_tester.prepare_config_and_inputs()
self.model_tester.create_and_check_for_token_classification(*config_and_inputs)
@unittest.skip(reason="Feed forward chunking is not implemented")
def test_feed_forward_chunking(self):
pass
...