Unverified Commit dcec4c43 authored by Raghavan, committed by GitHub

Adding OPTForSeqClassification class (#18123)

* Adding OPTForSeqClassification class

* Fix import issues

* Add documentation for OPTForSeqClassification

* Remove checkout

* fix failing tests

* fix typo

* Fix code formatting

* Incorporating the PR feedback

* Incorporate PR feedback

* Fix failing test and add new test for multi label setup

* Fix formatting issue

* Fix failing tests

* Fix formatting issues

* Fix failing tests

* Fix failing tests

* Fix failing tests

* Fix failing tests

* PR feedback
parent 0ed4d0df
@@ -54,6 +54,11 @@ The original code can be found [here](https://github.com/facebookresearch/metase
[[autodoc]] TFOPTForCausalLM
    - call

## OPTForSequenceClassification

[[autodoc]] OPTForSequenceClassification
    - forward

## FlaxOPTModel

[[autodoc]] FlaxOPTModel
...
@@ -1504,6 +1504,7 @@ else:
            "OPTForCausalLM",
            "OPTModel",
            "OPTPreTrainedModel",
            "OPTForSequenceClassification",
        ]
    )
    _import_structure["models.pegasus"].extend(
@@ -4026,7 +4027,13 @@ if TYPE_CHECKING:
        OpenAIGPTPreTrainedModel,
        load_tf_weights_in_openai_gpt,
    )
-    from .models.opt import OPT_PRETRAINED_MODEL_ARCHIVE_LIST, OPTForCausalLM, OPTModel, OPTPreTrainedModel
+    from .models.opt import (
        OPT_PRETRAINED_MODEL_ARCHIVE_LIST,
        OPTForCausalLM,
        OPTForSequenceClassification,
        OPTModel,
        OPTPreTrainedModel,
    )
    from .models.pegasus import (
        PegasusForCausalLM,
        PegasusForConditionalGeneration,
...
@@ -503,6 +503,7 @@ MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING_NAMES = OrderedDict(
        ("nezha", "NezhaForSequenceClassification"),
        ("nystromformer", "NystromformerForSequenceClassification"),
        ("openai-gpt", "OpenAIGPTForSequenceClassification"),
        ("opt", "OPTForSequenceClassification"),
        ("perceiver", "PerceiverForSequenceClassification"),
        ("plbart", "PLBartForSequenceClassification"),
        ("qdqbert", "QDQBertForSequenceClassification"),
...
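With the `("opt", "OPTForSequenceClassification")` entry registered, the auto classes can resolve OPT configurations to the new head. A minimal sketch of what that enables, assuming the public `facebook/opt-125m` checkpoint and a two-label setup chosen purely for illustration (the classification head is newly initialized, so the prediction itself is untrained):

```python
import torch
from transformers import AutoModelForSequenceClassification, AutoTokenizer

# Illustrative checkpoint and label count, not part of this diff.
checkpoint = "facebook/opt-125m"
tokenizer = AutoTokenizer.from_pretrained(checkpoint)
# The mapping entry above routes OPTConfig to OPTForSequenceClassification.
model = AutoModelForSequenceClassification.from_pretrained(checkpoint, num_labels=2)

inputs = tokenizer("Replace me with any text you like.", return_tensors="pt")
with torch.no_grad():
    logits = model(**inputs).logits  # shape: (batch_size, num_labels)
print(logits.argmax(dim=-1))
```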
@@ -40,6 +40,7 @@ else:
        "OPTForCausalLM",
        "OPTModel",
        "OPTPreTrainedModel",
        "OPTForSequenceClassification",
    ]

try:
@@ -72,7 +73,13 @@ if TYPE_CHECKING:
    except OptionalDependencyNotAvailable:
        pass
    else:
-        from .modeling_opt import OPT_PRETRAINED_MODEL_ARCHIVE_LIST, OPTForCausalLM, OPTModel, OPTPreTrainedModel
+        from .modeling_opt import (
            OPT_PRETRAINED_MODEL_ARCHIVE_LIST,
            OPTForCausalLM,
            OPTForSequenceClassification,
            OPTModel,
            OPTPreTrainedModel,
        )

    try:
        if not is_tf_available():
...
@@ -19,10 +19,10 @@ from typing import List, Optional, Tuple, Union
import torch
import torch.utils.checkpoint
from torch import nn
-from torch.nn import CrossEntropyLoss
+from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss

from ...activations import ACT2FN
-from ...modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast
+from ...modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast, SequenceClassifierOutputWithPast
from ...modeling_utils import PreTrainedModel
from ...utils import (
    add_code_sample_docstrings,
@@ -383,6 +383,7 @@ OPT_START_DOCSTRING = r"""
    OPT_START_DOCSTRING,
)
class OPTPreTrainedModel(PreTrainedModel):
    config_class = OPTConfig
    base_model_prefix = "model"
    supports_gradient_checkpointing = True
@@ -729,7 +730,6 @@ class OPTModel(OPTPreTrainedModel):
    def __init__(self, config: OPTConfig):
        super().__init__(config)
        self.decoder = OPTDecoder(config)
        # Initialize weights and apply final processing
        self.post_init()
@@ -976,3 +976,133 @@ class OPTForCausalLM(OPTPreTrainedModel):
        for layer_past in past:
            reordered_past += (tuple(past_state.index_select(0, beam_idx) for past_state in layer_past),)
        return reordered_past

@add_start_docstrings(
    """
    The OPT Model transformer with a sequence classification head on top (linear layer).

    [`OPTForSequenceClassification`] uses the last token in order to do the classification, as other causal models
    (e.g. GPT-2) do.

    Since it does classification on the last token, it requires to know the position of the last token. If a
    `pad_token_id` is defined in the configuration, it finds the last token that is not a padding token in each row. If
    no `pad_token_id` is defined, it simply takes the last value in each row of the batch. Since it cannot guess the
    padding tokens when `inputs_embeds` are passed instead of `input_ids`, it does the same (take the last value in
    each row of the batch).
    """,
    OPT_START_DOCSTRING,
)
class OPTForSequenceClassification(OPTPreTrainedModel):
    _keys_to_ignore_on_load_missing = [r"lm_head.weight"]

    def __init__(self, config: OPTConfig):
        super().__init__(config)
        self.num_labels = config.num_labels
        self.model = OPTModel(config)
        self.score = nn.Linear(config.word_embed_proj_dim, self.num_labels, bias=False)

        # Initialize weights and apply final processing
        self.post_init()

    @add_start_docstrings_to_model_forward(OPT_INPUTS_DOCSTRING)
    @add_code_sample_docstrings(
        processor_class=_TOKENIZER_FOR_DOC,
        output_type=SequenceClassifierOutputWithPast,
        config_class=_CONFIG_FOR_DOC,
        expected_output="'LABEL_0'",
        expected_loss=5.28,
    )
    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        attention_mask: Optional[torch.FloatTensor] = None,
        head_mask: Optional[torch.FloatTensor] = None,
        past_key_values: Optional[Tuple[Tuple[torch.Tensor]]] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        labels: Optional[torch.LongTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[Tuple, SequenceClassifierOutputWithPast]:
        r"""
        labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
            Labels for computing the sequence classification/regression loss. Indices should be in `[0, ...,
            config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If
            `config.num_labels > 1` a classification loss is computed (Cross-Entropy).
        """
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        transformer_outputs = self.model(
            input_ids,
            past_key_values=past_key_values,
            attention_mask=attention_mask,
            head_mask=head_mask,
            inputs_embeds=inputs_embeds,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )
        hidden_states = transformer_outputs[0]
        logits = self.score(hidden_states)

        if input_ids is not None:
            batch_size, sequence_length = input_ids.shape[:2]
        else:
            batch_size, sequence_length = inputs_embeds.shape[:2]

        if self.config.pad_token_id is None:
            sequence_lengths = -1
        else:
            if input_ids is not None:
                sequence_lengths = torch.ne(input_ids, self.config.pad_token_id).sum(-1) - 1
            else:
                sequence_lengths = -1
                logger.warning(
                    f"{self.__class__.__name__} will not detect padding tokens in `inputs_embeds`. Results may be "
                    "unexpected if using padding tokens in conjunction with `inputs_embeds.`"
                )

        pooled_logits = logits[torch.arange(batch_size, device=logits.device), sequence_lengths]

        loss = None
        if labels is not None:
            if self.config.problem_type is None:
                if self.num_labels == 1:
                    self.config.problem_type = "regression"
                elif self.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int):
                    self.config.problem_type = "single_label_classification"
                else:
                    self.config.problem_type = "multi_label_classification"

            if self.config.problem_type == "regression":
                loss_fct = MSELoss()
                if self.num_labels == 1:
                    loss = loss_fct(pooled_logits.squeeze(), labels.squeeze())
                else:
                    loss = loss_fct(pooled_logits, labels)
            elif self.config.problem_type == "single_label_classification":
                loss_fct = CrossEntropyLoss()
                loss = loss_fct(pooled_logits.view(-1, self.num_labels), labels.view(-1))
            elif self.config.problem_type == "multi_label_classification":
                loss_fct = BCEWithLogitsLoss()
                loss = loss_fct(pooled_logits, labels)

        if not return_dict:
            output = (pooled_logits,) + transformer_outputs[1:]
            return ((loss,) + output) if loss is not None else output

        return SequenceClassifierOutputWithPast(
            loss=loss,
            logits=pooled_logits,
            past_key_values=transformer_outputs.past_key_values,
            hidden_states=transformer_outputs.hidden_states,
            attentions=transformer_outputs.attentions,
        )

    def get_input_embeddings(self):
        return self.model.decoder.embed_tokens

    def set_input_embeddings(self, value):
        self.model.decoder.embed_tokens = value
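As the class docstring above explains, the head pools the logit of the last non-padding token of each row (falling back to the last position when no `pad_token_id` is configured), and the loss is chosen from `problem_type`. A rough usage sketch, assuming the public `facebook/opt-125m` checkpoint and two labels purely for illustration; the `score` layer is freshly initialized on top of the pretrained decoder, so the outputs are untrained:

```python
import torch
from transformers import GPT2Tokenizer, OPTForSequenceClassification

# Illustrative checkpoint and label count; the classification head is randomly initialized.
model = OPTForSequenceClassification.from_pretrained("facebook/opt-125m", num_labels=2)
tokenizer = GPT2Tokenizer.from_pretrained("facebook/opt-125m")

# OPT checkpoints define a pad token, so rows of different lengths can be padded;
# pooling then picks the last non-padding token of each row.
batch = tokenizer(["short", "a noticeably longer second sentence"], padding=True, return_tensors="pt")
labels = torch.tensor([0, 1])

with torch.no_grad():
    outputs = model(**batch, labels=labels)

print(outputs.logits.shape)  # torch.Size([2, 2]) -> one pooled logit vector per row
print(outputs.loss)          # cross-entropy, since num_labels > 1 and labels are integer class ids
```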
@@ -3438,6 +3438,13 @@ class OPTForCausalLM(metaclass=DummyObject):
        requires_backends(self, ["torch"])


class OPTForSequenceClassification(metaclass=DummyObject):
    _backends = ["torch"]

    def __init__(self, *args, **kwargs):
        requires_backends(self, ["torch"])


class OPTModel(metaclass=DummyObject):
    _backends = ["torch"]
...
@@ -32,7 +32,7 @@ from ...test_modeling_common import ModelTesterMixin, ids_tensor
if is_torch_available():
    import torch

-    from transformers import GPT2Tokenizer, OPTForCausalLM, OPTModel
+    from transformers import GPT2Tokenizer, OPTForCausalLM, OPTForSequenceClassification, OPTModel


def prepare_opt_inputs_dict(
@@ -74,7 +74,9 @@ class OPTModelTester:
        pad_token_id=1,
        bos_token_id=0,
        embed_dim=16,
        num_labels=3,
        word_embed_proj_dim=16,
        type_sequence_label_size=2,
    ):
        self.parent = parent
        self.batch_size = batch_size
@@ -94,11 +96,12 @@
        self.pad_token_id = pad_token_id
        self.bos_token_id = bos_token_id
        self.embed_dim = embed_dim
        self.num_labels = num_labels
        self.type_sequence_label_size = type_sequence_label_size
        self.word_embed_proj_dim = word_embed_proj_dim
        self.is_encoder_decoder = False

    def prepare_config_and_inputs(self):
-        input_ids = ids_tensor([self.batch_size, self.seq_length], self.vocab_size)
        input_ids = ids_tensor([self.batch_size, self.seq_length], self.vocab_size).clamp(
            3,
        )
@@ -175,7 +178,7 @@ class OPTModelTester:
@require_torch
class OPTModelTest(ModelTesterMixin, GenerationTesterMixin, unittest.TestCase):
-    all_model_classes = (OPTModel, OPTForCausalLM) if is_torch_available() else ()
+    all_model_classes = (OPTModel, OPTForCausalLM, OPTForSequenceClassification) if is_torch_available() else ()
    all_generative_model_classes = (OPTForCausalLM,) if is_torch_available() else ()
    is_encoder_decoder = False
    fx_compatible = True
@@ -242,6 +245,33 @@ class OPTModelTest(ModelTesterMixin, GenerationTesterMixin, unittest.TestCase):
        model.generate(input_ids, attention_mask=attention_mask)
        model.generate(num_beams=4, do_sample=True, early_stopping=False, num_return_sequences=3)

    def test_opt_sequence_classification_model(self):
        config, input_dict = self.model_tester.prepare_config_and_inputs()
        config.num_labels = 3
        input_ids = input_dict["input_ids"]
        attention_mask = input_ids.ne(1).to(torch_device)
        sequence_labels = ids_tensor([self.model_tester.batch_size], self.model_tester.type_sequence_label_size)
        model = OPTForSequenceClassification(config)
        model.to(torch_device)
        model.eval()
        result = model(input_ids, attention_mask=attention_mask, labels=sequence_labels)
        self.assertEqual(result.logits.shape, (self.model_tester.batch_size, self.model_tester.num_labels))

    def test_opt_sequence_classification_model_for_multi_label(self):
        config, input_dict = self.model_tester.prepare_config_and_inputs()
        config.num_labels = 3
        config.problem_type = "multi_label_classification"
        input_ids = input_dict["input_ids"]
        attention_mask = input_ids.ne(1).to(torch_device)
        sequence_labels = ids_tensor(
            [self.model_tester.batch_size, config.num_labels], self.model_tester.type_sequence_label_size
        ).to(torch.float)
        model = OPTForSequenceClassification(config)
        model.to(torch_device)
        model.eval()
        result = model(input_ids, attention_mask=attention_mask, labels=sequence_labels)
        self.assertEqual(result.logits.shape, (self.model_tester.batch_size, self.model_tester.num_labels))
def assert_tensors_close(a, b, atol=1e-12, prefix=""):
    """If tensors have different shapes, different values or a and b are not both tensors, raise a nice Assertion error."""
...
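The multi-label test above highlights one requirement worth calling out: `BCEWithLogitsLoss` expects float, multi-hot labels of shape `(batch_size, num_labels)`. A small self-contained sketch of that setup, using a tiny randomly initialized config whose values are chosen arbitrarily for illustration:

```python
import torch
from transformers import OPTConfig, OPTForSequenceClassification

# Tiny illustrative config; real use would start from a pretrained checkpoint.
config = OPTConfig(
    vocab_size=100,
    hidden_size=16,
    num_hidden_layers=2,
    num_attention_heads=4,
    ffn_dim=32,
    max_position_embeddings=64,
    word_embed_proj_dim=16,
    pad_token_id=1,
    num_labels=3,
    problem_type="multi_label_classification",
)
model = OPTForSequenceClassification(config).eval()

# Token ids drawn above the pad id so every position counts as non-padding.
input_ids = torch.randint(2, config.vocab_size, (2, 8))
# Multi-hot float targets: one row per example, one column per label.
labels = torch.tensor([[1.0, 0.0, 1.0], [0.0, 1.0, 0.0]])

outputs = model(input_ids, labels=labels)
print(outputs.logits.shape)  # torch.Size([2, 3])
print(outputs.loss)          # BCEWithLogitsLoss over the pooled logits
```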