"docs/source/ko/tasks/token_classification.mdx" did not exist on "2b9513fdabbcfd3ca5d7003a955be633a2f365fc"
Unverified commit e1205e47, authored by Bhadresh Savani, committed by GitHub

Added Sequence Classification class in GPTNeo (#11906)

* seq classification changes

* fix tests
parent 80d712fa
docs/source/model_doc/gpt_neo.rst

@@ -65,3 +65,9 @@ GPTNeoForCausalLM

 .. autoclass:: transformers.GPTNeoForCausalLM
     :members: forward
+
+GPTNeoForSequenceClassification
+~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
+
+.. autoclass:: transformers.GPTNeoForSequenceClassification
+    :members: forward
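For orientation, a minimal usage sketch of the new head (the checkpoint name and label count are illustrative, and the freshly initialized classification layer would normally be fine-tuned before use):

```python
import torch
from transformers import GPT2Tokenizer, GPTNeoForSequenceClassification

tokenizer = GPT2Tokenizer.from_pretrained("EleutherAI/gpt-neo-125M")
model = GPTNeoForSequenceClassification.from_pretrained("EleutherAI/gpt-neo-125M", num_labels=2)

# GPT-Neo defines no pad token; reuse EOS so the model can locate each row's
# last real token (see the pad_token_id logic in the forward pass below).
tokenizer.pad_token = tokenizer.eos_token
model.config.pad_token_id = tokenizer.pad_token_id

inputs = tokenizer(["I loved it!", "Not great at all."], padding=True, return_tensors="pt")
with torch.no_grad():
    logits = model(**inputs).logits  # shape: (batch_size, num_labels)
```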
src/transformers/__init__.py

@@ -746,6 +746,7 @@ if is_torch_available():
         [
             "GPT_NEO_PRETRAINED_MODEL_ARCHIVE_LIST",
             "GPTNeoForCausalLM",
+            "GPTNeoForSequenceClassification",
             "GPTNeoModel",
             "GPTNeoPreTrainedModel",
             "load_tf_weights_in_gpt_neo",
@@ -2129,6 +2130,7 @@ if TYPE_CHECKING:
         from .models.gpt_neo import (
             GPT_NEO_PRETRAINED_MODEL_ARCHIVE_LIST,
             GPTNeoForCausalLM,
+            GPTNeoForSequenceClassification,
             GPTNeoModel,
             GPTNeoPreTrainedModel,
             load_tf_weights_in_gpt_neo,
...
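Both hunks feed the lazy-import machinery in the top-level `__init__.py`: the `_import_structure` entry makes the symbol resolvable at runtime, while the `TYPE_CHECKING` branch keeps it visible to static analyzers. The net effect, sketched:

```python
# After this change the new head is importable from the package root:
from transformers import GPTNeoForSequenceClassification
```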
src/transformers/models/auto/modeling_auto.py

@@ -145,7 +145,7 @@ from ..funnel.modeling_funnel import (
     FunnelModel,
 )
 from ..gpt2.modeling_gpt2 import GPT2ForSequenceClassification, GPT2LMHeadModel, GPT2Model
-from ..gpt_neo.modeling_gpt_neo import GPTNeoForCausalLM, GPTNeoModel
+from ..gpt_neo.modeling_gpt_neo import GPTNeoForCausalLM, GPTNeoForSequenceClassification, GPTNeoModel
 from ..ibert.modeling_ibert import (
     IBertForMaskedLM,
     IBertForMultipleChoice,
@@ -632,6 +632,7 @@ MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING = OrderedDict(
         (DebertaConfig, DebertaForSequenceClassification),
         (DebertaV2Config, DebertaV2ForSequenceClassification),
         (GPT2Config, GPT2ForSequenceClassification),
+        (GPTNeoConfig, GPTNeoForSequenceClassification),
         (OpenAIGPTConfig, OpenAIGPTForSequenceClassification),
         (ReformerConfig, ReformerForSequenceClassification),
         (CTRLConfig, CTRLForSequenceClassification),
...
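Registering `(GPTNeoConfig, GPTNeoForSequenceClassification)` in `MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING` is what lets the auto class dispatch GPT-Neo checkpoints to the new head. A minimal sketch (checkpoint name and label count are illustrative):

```python
from transformers import AutoModelForSequenceClassification

# The checkpoint's config is a GPTNeoConfig, so the mapping entry above
# resolves to GPTNeoForSequenceClassification.
model = AutoModelForSequenceClassification.from_pretrained("EleutherAI/gpt-neo-125M", num_labels=2)
print(type(model).__name__)  # GPTNeoForSequenceClassification
```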
src/transformers/models/gpt_neo/__init__.py

@@ -28,6 +28,7 @@ if is_torch_available():
     _import_structure["modeling_gpt_neo"] = [
         "GPT_NEO_PRETRAINED_MODEL_ARCHIVE_LIST",
         "GPTNeoForCausalLM",
+        "GPTNeoForSequenceClassification",
         "GPTNeoModel",
         "GPTNeoPreTrainedModel",
         "load_tf_weights_in_gpt_neo",
@@ -41,6 +42,7 @@ if TYPE_CHECKING:
         from .modeling_gpt_neo import (
             GPT_NEO_PRETRAINED_MODEL_ARCHIVE_LIST,
             GPTNeoForCausalLM,
+            GPTNeoForSequenceClassification,
             GPTNeoModel,
             GPTNeoPreTrainedModel,
             load_tf_weights_in_gpt_neo,
...
src/transformers/models/gpt_neo/modeling_gpt_neo.py

@@ -22,7 +22,7 @@ import torch
 import torch.nn.functional as F
 import torch.utils.checkpoint
 from torch import nn
-from torch.nn import CrossEntropyLoss
+from torch.nn import CrossEntropyLoss, MSELoss

 from ...activations import ACT2FN
 from ...file_utils import add_code_sample_docstrings, add_start_docstrings, add_start_docstrings_to_model_forward
@@ -31,6 +31,7 @@ from ...modeling_outputs import (
     BaseModelOutputWithPastAndCrossAttentions,
     CausalLMOutputWithCrossAttentions,
     CausalLMOutputWithPast,
+    SequenceClassifierOutputWithPast,
 )
 from ...modeling_utils import PreTrainedModel
 from ...utils import logging
@@ -1027,3 +1028,120 @@ class GPTNeoForCausalLM(GPTNeoPreTrainedModel):
             tuple(past_state.index_select(0, beam_idx.to(past_state.device)) for past_state in layer_past)
             for layer_past in past
         )
+
+
+@add_start_docstrings(
+    """
+    The GPTNeo Model transformer with a sequence classification head on top (linear layer).
+
+    :class:`~transformers.GPTNeoForSequenceClassification` uses the last token in order to do the classification, as
+    other causal models (e.g. GPT-1) do.
+
+    Since it does classification on the last token, it requires knowing the position of the last token. If a
+    :obj:`pad_token_id` is defined in the configuration, it finds the last token that is not a padding token in each
+    row. If no :obj:`pad_token_id` is defined, it simply takes the last value in each row of the batch. Since it
+    cannot guess the padding tokens when :obj:`inputs_embeds` are passed instead of :obj:`input_ids`, it does the
+    same (takes the last value in each row of the batch).
+    """,
+    GPT_NEO_START_DOCSTRING,
+)
+class GPTNeoForSequenceClassification(GPTNeoPreTrainedModel):
+    _keys_to_ignore_on_load_missing = [r"h\.\d+\.attn\.masked_bias", r"lm_head\.weight"]
+
+    def __init__(self, config):
+        super().__init__(config)
+        self.num_labels = config.num_labels
+        self.transformer = GPTNeoModel(config)
+        self.score = nn.Linear(config.hidden_size, self.num_labels, bias=False)
+
+        self.init_weights()
+
+    @add_start_docstrings_to_model_forward(GPT_NEO_INPUTS_DOCSTRING)
+    @add_code_sample_docstrings(
+        tokenizer_class=_TOKENIZER_FOR_DOC,
+        checkpoint=_CHECKPOINT_FOR_DOC,
+        output_type=SequenceClassifierOutputWithPast,
+        config_class=_CONFIG_FOR_DOC,
+    )
+    def forward(
+        self,
+        input_ids=None,
+        past_key_values=None,
+        attention_mask=None,
+        token_type_ids=None,
+        position_ids=None,
+        head_mask=None,
+        inputs_embeds=None,
+        labels=None,
+        use_cache=None,
+        output_attentions=None,
+        output_hidden_states=None,
+        return_dict=None,
+    ):
+        r"""
+        labels (:obj:`torch.LongTensor` of shape :obj:`(batch_size,)`, `optional`):
+            Labels for computing the sequence classification/regression loss. Indices should be in :obj:`[0, ...,
+            config.num_labels - 1]`. If :obj:`config.num_labels == 1` a regression loss is computed (Mean-Square
+            loss), if :obj:`config.num_labels > 1` a classification loss is computed (Cross-Entropy).
+        """
+        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+        transformer_outputs = self.transformer(
+            input_ids,
+            past_key_values=past_key_values,
+            attention_mask=attention_mask,
+            token_type_ids=token_type_ids,
+            position_ids=position_ids,
+            head_mask=head_mask,
+            inputs_embeds=inputs_embeds,
+            use_cache=use_cache,
+            output_attentions=output_attentions,
+            output_hidden_states=output_hidden_states,
+            return_dict=return_dict,
+        )
+        hidden_states = transformer_outputs[0]
+        logits = self.score(hidden_states)
+
+        if input_ids is not None:
+            batch_size, sequence_length = input_ids.shape[:2]
+        else:
+            batch_size, sequence_length = inputs_embeds.shape[:2]
+
+        assert (
+            self.config.pad_token_id is not None or batch_size == 1
+        ), "Cannot handle batch sizes > 1 if no padding token is defined."
+
+        if self.config.pad_token_id is None:
+            sequence_lengths = -1
+        else:
+            if input_ids is not None:
+                sequence_lengths = torch.ne(input_ids, self.config.pad_token_id).sum(-1) - 1
+            else:
+                sequence_lengths = -1
+                logger.warning(
+                    f"{self.__class__.__name__} will not detect padding tokens in `inputs_embeds`. Results may be "
+                    f"unexpected if using padding tokens in conjunction with `inputs_embeds.`"
+                )
+
+        pooled_logits = logits[range(batch_size), sequence_lengths]
+
+        loss = None
+        if labels is not None:
+            if self.num_labels == 1:
+                # We are doing regression
+                loss_fct = MSELoss()
+                loss = loss_fct(pooled_logits.view(-1), labels.to(self.dtype).view(-1))
+            else:
+                loss_fct = CrossEntropyLoss()
+                loss = loss_fct(pooled_logits.view(-1, self.num_labels), labels.view(-1))
+
+        if not return_dict:
+            output = (pooled_logits,) + transformer_outputs[1:]
+            return ((loss,) + output) if loss is not None else output
+
+        return SequenceClassifierOutputWithPast(
+            loss=loss,
+            logits=pooled_logits,
+            past_key_values=transformer_outputs.past_key_values,
+            hidden_states=transformer_outputs.hidden_states,
+            attentions=transformer_outputs.attentions,
+        )
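The pooling step above deserves a note: `torch.ne(input_ids, pad_token_id).sum(-1) - 1` counts the non-padding tokens in each row and turns that count into the index of the last real token, which advanced indexing then uses to pick one logit vector per sequence (this assumes right-padding). A standalone sketch of the same arithmetic with toy tensors (shapes and pad id are made up for illustration):

```python
import torch

pad_token_id = 0
input_ids = torch.tensor(
    [[5, 8, 3, 0, 0],   # 3 real tokens -> last real index is 2
     [7, 2, 9, 4, 0]]   # 4 real tokens -> last real index is 3
)
logits = torch.randn(2, 5, 4)  # (batch_size, sequence_length, num_labels)

# Count non-pad tokens per row, then step back one to index the last real token.
sequence_lengths = torch.ne(input_ids, pad_token_id).sum(-1) - 1  # tensor([2, 3])

# Pick one (num_labels,) vector per row: rows [0, 1] at positions [2, 3].
pooled_logits = logits[range(input_ids.shape[0]), sequence_lengths]
print(pooled_logits.shape)  # torch.Size([2, 4])
```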
src/transformers/utils/dummy_pt_objects.py

@@ -1603,6 +1603,15 @@ class GPTNeoForCausalLM:
         requires_backends(self, ["torch"])


+class GPTNeoForSequenceClassification:
+    def __init__(self, *args, **kwargs):
+        requires_backends(self, ["torch"])
+
+    @classmethod
+    def from_pretrained(cls, *args, **kwargs):
+        requires_backends(cls, ["torch"])
+
+
 class GPTNeoModel:
     def __init__(self, *args, **kwargs):
         requires_backends(self, ["torch"])
...
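The dummy class mirrors the real one so that `from transformers import GPTNeoForSequenceClassification` still succeeds in an environment without PyTorch; only actual use fails. Roughly the observable behavior (error text paraphrased, not verbatim):

```python
# Without torch installed, the import itself works thanks to the dummy object:
from transformers import GPTNeoForSequenceClassification

try:
    GPTNeoForSequenceClassification()
except ImportError as err:
    # requires_backends raises and points the user at the missing backend.
    print(err)
```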
tests/test_modeling_gpt2.py

@@ -361,7 +361,6 @@ class GPT2ModelTester:
         model = GPT2ForSequenceClassification(config)
         model.to(torch_device)
         model.eval()
-        print(config.num_labels, sequence_labels.size())
         result = model(input_ids, attention_mask=input_mask, token_type_ids=token_type_ids, labels=sequence_labels)
         self.parent.assertEqual(result.logits.shape, (self.batch_size, self.num_labels))
...
tests/test_modeling_gpt_neo.py

@@ -34,6 +34,7 @@ if is_torch_available():
         GPT2Tokenizer,
         GPTNeoConfig,
         GPTNeoForCausalLM,
+        GPTNeoForSequenceClassification,
         GPTNeoModel,
     )
     from transformers.models.gpt_neo.modeling_gpt_neo import GPTNeoAttentionMixin
@@ -238,6 +239,16 @@ class GPTNeoModelTester:
         self.parent.assertEqual(result.loss.shape, ())
         self.parent.assertEqual(result.logits.shape, (self.batch_size, self.seq_length, self.vocab_size))

+    def create_and_check_gpt_neo_for_sequence_classification(
+        self, config, input_ids, input_mask, head_mask, token_type_ids, mc_token_ids, sequence_labels, *args
+    ):
+        config.num_labels = self.num_labels
+        model = GPTNeoForSequenceClassification(config)
+        model.to(torch_device)
+        model.eval()
+        result = model(input_ids, attention_mask=input_mask, token_type_ids=token_type_ids, labels=sequence_labels)
+        self.parent.assertEqual(result.logits.shape, (self.batch_size, self.num_labels))
+
     def create_and_check_forward_and_backwards(self, config, input_ids, input_mask, head_mask, token_type_ids, *args):
         model = GPTNeoForCausalLM(config)
         model.to(torch_device)
@@ -274,7 +285,9 @@ class GPTNeoModelTester:

 @require_torch
 class GPTNeoModelTest(ModelTesterMixin, GenerationTesterMixin, unittest.TestCase):
-    all_model_classes = (GPTNeoModel, GPTNeoForCausalLM) if is_torch_available() else ()
+    all_model_classes = (
+        (GPTNeoModel, GPTNeoForCausalLM, GPTNeoForSequenceClassification) if is_torch_available() else ()
+    )
     all_generative_model_classes = (GPTNeoForCausalLM,) if is_torch_available() else ()
     fx_ready_model_classes = all_model_classes
     test_missing_keys = False
@@ -305,6 +318,10 @@ class GPTNeoModelTest(ModelTesterMixin, GenerationTesterMixin, unittest.TestCase
         config_and_inputs = self.model_tester.prepare_config_and_inputs()
         self.model_tester.create_and_check_lm_head_model(*config_and_inputs)

+    def test_gpt_neo_sequence_classification_model(self):
+        config_and_inputs = self.model_tester.prepare_config_and_inputs()
+        self.model_tester.create_and_check_gpt_neo_for_sequence_classification(*config_and_inputs)
+
     def test_gpt_neo_gradient_checkpointing(self):
         config_and_inputs = self.model_tester.prepare_config_and_inputs(gradient_checkpointing=True)
         self.model_tester.create_and_check_forward_and_backwards(*config_and_inputs)
...