Unverified Commit d65b14ed authored by peter-sk, committed by GitHub

added GPTNeoForTokenClassification (#22908)



* added GPTNeoForTokenClassification

* add to top-level init

* fixup

* test

* more fixup

* add to gpt_neo.mdx

* repo consistency

* dummy copy

* fix copies

* optax >= 0.1.5 assumes jax.Array exists - which it doesn't for jax <= 0.3.6

* merge with main made this superfluous

* added classifier_dropout

* remove legacy code

* removed fmt:on/off
removed expected_outputs

* doc style fix

* classifier_dropout is always in config

---------
Co-authored-by: Prof. Peter Schneider-Kamp <jps@ordbogen.com>
parent 614e191c
...@@ -74,6 +74,11 @@ The `generate()` method can be used to generate text using the GPT Neo model.
[[autodoc]] GPTNeoForSequenceClassification
    - forward

## GPTNeoForTokenClassification

[[autodoc]] GPTNeoForTokenClassification
    - forward

## FlaxGPTNeoModel

[[autodoc]] FlaxGPTNeoModel
...
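For readers who want to try the newly documented head right away, here is a minimal sketch that drives it through the `token-classification` pipeline. It assumes an installation that includes this commit; since the base `EleutherAI/gpt-neo-125m` checkpoint ships no classification head, the head is randomly initialized and the predicted labels are placeholders until the model is fine-tuned.

```python
# Sketch only: exercise GPTNeoForTokenClassification via the pipeline API.
# The classification head is freshly initialized, so the output labels
# (LABEL_0 / LABEL_1 by default) carry no meaning before fine-tuning.
from transformers import pipeline

token_classifier = pipeline("token-classification", model="EleutherAI/gpt-neo-125m")
print(token_classifier("Hugging Face is based in New York City."))
```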
...@@ -28,7 +28,7 @@ The task illustrated in this tutorial is supported by the following model archit
<!--This tip is automatically generated by `make fix-copies`, do not fill manually!-->
[ALBERT](../model_doc/albert), [BERT](../model_doc/bert), [BigBird](../model_doc/big_bird), [BioGpt](../model_doc/biogpt), [BLOOM](../model_doc/bloom), [CamemBERT](../model_doc/camembert), [CANINE](../model_doc/canine), [ConvBERT](../model_doc/convbert), [Data2VecText](../model_doc/data2vec-text), [DeBERTa](../model_doc/deberta), [DeBERTa-v2](../model_doc/deberta-v2), [DistilBERT](../model_doc/distilbert), [ELECTRA](../model_doc/electra), [ERNIE](../model_doc/ernie), [ErnieM](../model_doc/ernie_m), [ESM](../model_doc/esm), [FlauBERT](../model_doc/flaubert), [FNet](../model_doc/fnet), [Funnel Transformer](../model_doc/funnel), [GPT-Sw3](../model_doc/gpt-sw3), [OpenAI GPT-2](../model_doc/gpt2), [GPTBigCode](../model_doc/gpt_bigcode), [GPT Neo](../model_doc/gpt_neo), [GPT NeoX](../model_doc/gpt_neox), [I-BERT](../model_doc/ibert), [LayoutLM](../model_doc/layoutlm), [LayoutLMv2](../model_doc/layoutlmv2), [LayoutLMv3](../model_doc/layoutlmv3), [LiLT](../model_doc/lilt), [Longformer](../model_doc/longformer), [LUKE](../model_doc/luke), [MarkupLM](../model_doc/markuplm), [MEGA](../model_doc/mega), [Megatron-BERT](../model_doc/megatron-bert), [MobileBERT](../model_doc/mobilebert), [MPNet](../model_doc/mpnet), [Nezha](../model_doc/nezha), [Nyströmformer](../model_doc/nystromformer), [QDQBert](../model_doc/qdqbert), [RemBERT](../model_doc/rembert), [RoBERTa](../model_doc/roberta), [RoBERTa-PreLayerNorm](../model_doc/roberta-prelayernorm), [RoCBert](../model_doc/roc_bert), [RoFormer](../model_doc/roformer), [SqueezeBERT](../model_doc/squeezebert), [XLM](../model_doc/xlm), [XLM-RoBERTa](../model_doc/xlm-roberta), [XLM-RoBERTa-XL](../model_doc/xlm-roberta-xl), [XLNet](../model_doc/xlnet), [X-MOD](../model_doc/xmod), [YOSO](../model_doc/yoso)
<!--End of the generated tip-->
...
...@@ -1687,6 +1687,7 @@ else:
            "GPT_NEO_PRETRAINED_MODEL_ARCHIVE_LIST",
            "GPTNeoForCausalLM",
            "GPTNeoForSequenceClassification",
            "GPTNeoForTokenClassification",
            "GPTNeoModel",
            "GPTNeoPreTrainedModel",
            "load_tf_weights_in_gpt_neo",
...@@ -5223,6 +5224,7 @@ if TYPE_CHECKING:
            GPT_NEO_PRETRAINED_MODEL_ARCHIVE_LIST,
            GPTNeoForCausalLM,
            GPTNeoForSequenceClassification,
            GPTNeoForTokenClassification,
            GPTNeoModel,
            GPTNeoPreTrainedModel,
            load_tf_weights_in_gpt_neo,
...
...@@ -814,6 +814,7 @@ MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING_NAMES = OrderedDict(
        ("gpt-sw3", "GPT2ForTokenClassification"),
        ("gpt2", "GPT2ForTokenClassification"),
        ("gpt_bigcode", "GPTBigCodeForTokenClassification"),
        ("gpt_neo", "GPTNeoForTokenClassification"),
        ("gpt_neox", "GPTNeoXForTokenClassification"),
        ("ibert", "IBertForTokenClassification"),
        ("layoutlm", "LayoutLMForTokenClassification"),
...
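With the mapping entry above in place, `AutoModelForTokenClassification` should resolve GPT Neo checkpoints to the new class. A small illustrative sketch follows; the `num_labels=9` value is only an example (e.g. a CoNLL-style tag set), and the classification head still needs to be trained.

```python
# Sketch: the auto class dispatches to GPTNeoForTokenClassification for gpt_neo.
from transformers import AutoModelForTokenClassification

model = AutoModelForTokenClassification.from_pretrained(
    "EleutherAI/gpt-neo-125m", num_labels=9
)
print(type(model).__name__)  # expected: GPTNeoForTokenClassification
```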
...@@ -30,6 +30,7 @@ else:
        "GPT_NEO_PRETRAINED_MODEL_ARCHIVE_LIST",
        "GPTNeoForCausalLM",
        "GPTNeoForSequenceClassification",
        "GPTNeoForTokenClassification",
        "GPTNeoModel",
        "GPTNeoPreTrainedModel",
        "load_tf_weights_in_gpt_neo",
...@@ -61,6 +62,7 @@ if TYPE_CHECKING:
            GPT_NEO_PRETRAINED_MODEL_ARCHIVE_LIST,
            GPTNeoForCausalLM,
            GPTNeoForSequenceClassification,
            GPTNeoForTokenClassification,
            GPTNeoModel,
            GPTNeoPreTrainedModel,
            load_tf_weights_in_gpt_neo,
...
...@@ -66,6 +66,10 @@ class GPTNeoConfig(PretrainedConfig):
            The dropout probability for all fully connected layers in the embeddings, encoder, and pooler.
        attention_dropout (`float`, *optional*, defaults to 0.0):
            The dropout ratio for the attention probabilities.
        classifier_dropout (`float`, *optional*, defaults to 0.1):
            Argument used when doing token classification, used in the model [`GPTNeoForTokenClassification`].
            The dropout ratio for the hidden layer.
        max_position_embeddings (`int`, *optional*, defaults to 2048):
            The maximum sequence length that this model might ever be used with. Typically set this to something large
            just in case (e.g., 512 or 1024 or 2048).
...@@ -111,6 +115,7 @@ class GPTNeoConfig(PretrainedConfig):
        resid_dropout=0.0,
        embed_dropout=0.0,
        attention_dropout=0.0,
        classifier_dropout=0.1,
        layer_norm_epsilon=1e-5,
        initializer_range=0.02,
        use_cache=True,
...@@ -129,6 +134,7 @@ class GPTNeoConfig(PretrainedConfig):
        self.resid_dropout = resid_dropout
        self.embed_dropout = embed_dropout
        self.attention_dropout = attention_dropout
        self.classifier_dropout = classifier_dropout
        self.layer_norm_epsilon = layer_norm_epsilon
        self.initializer_range = initializer_range
        self.use_cache = use_cache
...
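The new `classifier_dropout` field behaves like any other config attribute; a quick sketch of how it might be set (the override value below is illustrative only):

```python
# Sketch: classifier_dropout defaults to 0.1 and can be overridden like any
# other GPTNeoConfig attribute at construction time.
from transformers import GPTNeoConfig

default_config = GPTNeoConfig()
custom_config = GPTNeoConfig(classifier_dropout=0.2)
print(default_config.classifier_dropout, custom_config.classifier_dropout)  # 0.1 0.2
```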
...@@ -30,6 +30,7 @@ from ...modeling_outputs import (
    CausalLMOutputWithCrossAttentions,
    CausalLMOutputWithPast,
    SequenceClassifierOutputWithPast,
    TokenClassifierOutput,
)
from ...modeling_utils import PreTrainedModel
from ...utils import add_code_sample_docstrings, add_start_docstrings, add_start_docstrings_to_model_forward, logging
...@@ -926,3 +927,88 @@ class GPTNeoForSequenceClassification(GPTNeoPreTrainedModel):
            hidden_states=transformer_outputs.hidden_states,
            attentions=transformer_outputs.attentions,
        )


@add_start_docstrings(
    """
    GPT Neo model with a token classification head on top (a linear layer on top of the hidden-states output) e.g. for
    Named-Entity-Recognition (NER) tasks.
    """,
    GPT_NEO_START_DOCSTRING,
)
class GPTNeoForTokenClassification(GPTNeoPreTrainedModel):
    def __init__(self, config):
        super().__init__(config)
        self.num_labels = config.num_labels

        self.transformer = GPTNeoModel(config)
        self.dropout = nn.Dropout(config.classifier_dropout)
        self.classifier = nn.Linear(config.hidden_size, config.num_labels)

        # Initialize weights and apply final processing
        self.post_init()

    @add_start_docstrings_to_model_forward(GPT_NEO_INPUTS_DOCSTRING)
    @add_code_sample_docstrings(
        checkpoint="EleutherAI/gpt-neo-125m",
        output_type=TokenClassifierOutput,
        config_class=_CONFIG_FOR_DOC,
        expected_loss=0.25,
    )
    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[Tuple[Tuple[torch.Tensor]]] = None,
        attention_mask: Optional[torch.FloatTensor] = None,
        token_type_ids: Optional[torch.LongTensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        head_mask: Optional[torch.FloatTensor] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        labels: Optional[torch.LongTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[Tuple, TokenClassifierOutput]:
        r"""
        labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
            Labels for computing the token classification loss. Indices should be in `[0, ..., config.num_labels - 1]`.
"""
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
transformer_outputs = self.transformer(
input_ids,
past_key_values=past_key_values,
attention_mask=attention_mask,
token_type_ids=token_type_ids,
position_ids=position_ids,
head_mask=head_mask,
inputs_embeds=inputs_embeds,
use_cache=use_cache,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
)
hidden_states = transformer_outputs[0]
hidden_states = self.dropout(hidden_states)
logits = self.classifier(hidden_states)
loss = None
if labels is not None:
labels = labels.to(logits.device)
loss_fct = CrossEntropyLoss()
loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
if not return_dict:
output = (logits,) + transformer_outputs[2:]
return ((loss,) + output) if loss is not None else output
return TokenClassifierOutput(
loss=loss,
logits=logits,
hidden_states=transformer_outputs.hidden_states,
attentions=transformer_outputs.attentions,
)
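To make the data flow of the new head concrete, here is a rough end-to-end sketch: a tiny randomly initialized GPT Neo, a tokenized sentence, dummy per-token labels, and the resulting cross-entropy loss and `(batch, seq_len, num_labels)` logits. The small config values and the all-zero labels are illustrative only, not taken from this commit.

```python
# Sketch: forward pass through GPTNeoForTokenClassification with labels.
import torch
from transformers import AutoTokenizer, GPTNeoConfig, GPTNeoForTokenClassification

config = GPTNeoConfig(
    hidden_size=64,
    num_layers=2,
    num_heads=4,
    attention_types=[[["global", "local"], 1]],  # must expand to num_layers entries
    num_labels=3,
)
model = GPTNeoForTokenClassification(config).eval()  # random weights, dropout off

tokenizer = AutoTokenizer.from_pretrained("EleutherAI/gpt-neo-125m")
inputs = tokenizer("GPT Neo now supports token classification", return_tensors="pt")
labels = torch.zeros_like(inputs["input_ids"])  # dummy labels in [0, num_labels)

with torch.no_grad():
    outputs = model(**inputs, labels=labels)
print(outputs.loss, outputs.logits.shape)  # scalar loss, (1, seq_len, 3)
```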
...@@ -3273,6 +3273,13 @@ class GPTNeoForSequenceClassification(metaclass=DummyObject):
        requires_backends(self, ["torch"])


class GPTNeoForTokenClassification(metaclass=DummyObject):
    _backends = ["torch"]

    def __init__(self, *args, **kwargs):
        requires_backends(self, ["torch"])


class GPTNeoModel(metaclass=DummyObject):
    _backends = ["torch"]
...
...@@ -35,6 +35,7 @@ if is_torch_available():
        GPT2Tokenizer,
        GPTNeoForCausalLM,
        GPTNeoForSequenceClassification,
        GPTNeoForTokenClassification,
        GPTNeoModel,
    )
...@@ -334,6 +335,16 @@ class GPTNeoModelTester:
        result = model(input_ids, attention_mask=input_mask, token_type_ids=token_type_ids, labels=sequence_labels)
        self.parent.assertEqual(result.logits.shape, (self.batch_size, self.num_labels))

    def create_and_check_gpt_neo_for_token_classification(
        self, config, input_ids, input_mask, head_mask, token_type_ids, mc_token_ids, sequence_labels, *args
    ):
        config.num_labels = self.num_labels
        model = GPTNeoForTokenClassification(config)
        model.to(torch_device)
        model.eval()
        result = model(input_ids, attention_mask=input_mask, token_type_ids=token_type_ids)
        self.parent.assertEqual(result.logits.shape, (self.batch_size, self.seq_length, self.num_labels))

    def create_and_check_forward_and_backwards(
        self, config, input_ids, input_mask, head_mask, token_type_ids, *args, gradient_checkpointing=False
    ):
...@@ -374,13 +385,16 @@ class GPTNeoModelTester:
@require_torch
class GPTNeoModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixin, unittest.TestCase):
    all_model_classes = (
        (GPTNeoModel, GPTNeoForCausalLM, GPTNeoForSequenceClassification, GPTNeoForTokenClassification)
        if is_torch_available()
        else ()
    )
    all_generative_model_classes = (GPTNeoForCausalLM,) if is_torch_available() else ()
    pipeline_model_mapping = (
        {
            "feature-extraction": GPTNeoModel,
            "text-classification": GPTNeoForSequenceClassification,
            "token-classification": GPTNeoForTokenClassification,
            "text-generation": GPTNeoForCausalLM,
            "zero-shot": GPTNeoForSequenceClassification,
        }
...@@ -428,6 +442,10 @@ class GPTNeoModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMix
        config_and_inputs = self.model_tester.prepare_config_and_inputs()
        self.model_tester.create_and_check_gpt_neo_for_sequence_classification(*config_and_inputs)

    def test_gpt_neo_token_classification_model(self):
        config_and_inputs = self.model_tester.prepare_config_and_inputs()
        self.model_tester.create_and_check_gpt_neo_for_token_classification(*config_and_inputs)

    def test_gpt_neo_gradient_checkpointing(self):
        config_and_inputs = self.model_tester.prepare_config_and_inputs()
        self.model_tester.create_and_check_forward_and_backwards(*config_and_inputs, gradient_checkpointing=True)
...
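The new tester method boils down to a shape assertion on the logits. A standalone sketch of the same check, with illustrative tiny sizes rather than the tester's actual defaults, might look like this:

```python
# Sketch of the property the new test asserts: token-classification logits
# have shape (batch_size, seq_length, num_labels).
import torch
from transformers import GPTNeoConfig, GPTNeoForTokenClassification

config = GPTNeoConfig(
    vocab_size=99,
    hidden_size=32,
    num_layers=2,
    num_heads=4,
    attention_types=[[["global", "local"], 1]],
    num_labels=5,
)
model = GPTNeoForTokenClassification(config).eval()
input_ids = torch.randint(0, config.vocab_size, (2, 10))
with torch.no_grad():
    logits = model(input_ids).logits
assert logits.shape == (2, 10, config.num_labels)
```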