"git@developer.sourcefind.cn:chenpangpang/ComfyUI.git" did not exist on "719c26c3c96ba84201b529ae3b5ff7590dee27df"
Unverified commit 59824318, authored by Lysandre Debut, committed by GitHub

Add GPT2ForSequenceClassification based on DialogRPT (#7501)

* Add GPT2ForSequenceClassification based on DialogRPT

* Better documentation

* Code quality
parent 500be01c
@@ -88,6 +88,13 @@ GPT2DoubleHeadsModel
    :members: forward

GPT2ForSequenceClassification
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

.. autoclass:: transformers.GPT2ForSequenceClassification
    :members: forward

TFGPT2Model
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
......
@@ -375,6 +375,7 @@ if is_torch_available():
    from .modeling_gpt2 import (
        GPT2_PRETRAINED_MODEL_ARCHIVE_LIST,
        GPT2DoubleHeadsModel,
        GPT2ForSequenceClassification,
        GPT2LMHeadModel,
        GPT2Model,
        GPT2PreTrainedModel,
......
@@ -22,7 +22,7 @@ from typing import List, Optional, Tuple
import torch
import torch.nn as nn
from torch.nn import CrossEntropyLoss, MSELoss
from .activations import ACT2FN
from .configuration_gpt2 import GPT2Config
@@ -33,7 +33,7 @@ from .file_utils import (
    add_start_docstrings_to_callable,
    replace_return_docstrings,
)
from .modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast, SequenceClassifierOutputWithPast
from .modeling_utils import (
    Conv1D,
    PreTrainedModel,
@@ -946,3 +946,121 @@ class GPT2DoubleHeadsModel(GPT2PreTrainedModel):
            hidden_states=transformer_outputs.hidden_states,
            attentions=transformer_outputs.attentions,
        )
@add_start_docstrings(
"""The GPT2 Model transformer with a sequence classification head on top
(linear layer).
:class:`~transformers.GPT2ForSequenceClassification` uses the last token in order to do the classification, as
other causal models (e.g. GPT-1) do.
Since it does classification on the last token, it needs to know the position of the last token.
If a :obj:`pad_token_id` is defined in the configuration, it finds the last token that is not a padding token
in each row. If no :obj:`pad_token_id` is defined, it simply takes the last value in each row of the batch.
Since it cannot guess the padding tokens when :obj:`inputs_embeds` are passed instead of :obj:`input_ids`, it
does the same (takes the last value in each row of the batch).
""",
GPT2_START_DOCSTRING,
)
class GPT2ForSequenceClassification(GPT2PreTrainedModel):
authorized_missing_keys = [r"h\.\d+\.attn\.masked_bias", r"lm_head\.weight"]
def __init__(self, config):
super().__init__(config)
self.num_labels = config.num_labels
self.transformer = GPT2Model(config)
self.score = nn.Linear(config.n_embd, self.num_labels, bias=False)
self.init_weights()
@add_start_docstrings_to_callable(GPT2_INPUTS_DOCSTRING)
@add_code_sample_docstrings(
tokenizer_class=_TOKENIZER_FOR_DOC,
checkpoint="microsoft/dialogrpt",
output_type=SequenceClassifierOutputWithPast,
config_class=_CONFIG_FOR_DOC,
)
def forward(
self,
input_ids=None,
past_key_values=None,
attention_mask=None,
token_type_ids=None,
position_ids=None,
head_mask=None,
inputs_embeds=None,
labels=None,
use_cache=None,
output_attentions=None,
output_hidden_states=None,
return_dict=None,
):
r"""
labels (:obj:`torch.LongTensor` of shape :obj:`(batch_size,)`, `optional`):
Labels for computing the sequence classification/regression loss.
Indices should be in :obj:`[0, ..., config.num_labels - 1]`.
If :obj:`config.num_labels == 1`, a regression loss is computed (Mean-Square loss);
if :obj:`config.num_labels > 1`, a classification loss is computed (Cross-Entropy).
"""
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
transformer_outputs = self.transformer(
input_ids,
past_key_values=past_key_values,
attention_mask=attention_mask,
token_type_ids=token_type_ids,
position_ids=position_ids,
head_mask=head_mask,
inputs_embeds=inputs_embeds,
use_cache=use_cache,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
)
hidden_states = transformer_outputs[0]
logits = self.score(hidden_states)
if input_ids is not None:
batch_size, sequence_length = input_ids.shape[:2]
else:
batch_size, sequence_length = inputs_embeds.shape[:2]
assert (
self.config.pad_token_id is not None or batch_size == 1
), "Cannot handle batch sizes > 1 if no padding token is defined."
if self.config.pad_token_id is None:
sequence_lengths = -1
else:
if input_ids is not None:
sequence_lengths = torch.ne(input_ids, self.config.pad_token_id).sum(-1) - 1
else:
sequence_lengths = -1
logger.warning(
f"{self.__class__.__name__} will not detect padding tokens in `inputs_embeds`. Results may be "
f"unexpected if using padding tokens in conjuction with `inputs_embeds.`"
)
pooled_logits = logits[range(batch_size), sequence_lengths]
loss = None
if labels is not None:
if self.num_labels == 1:
# We are doing regression
loss_fct = MSELoss()
loss = loss_fct(pooled_logits.view(-1), labels.view(-1))
else:
loss_fct = CrossEntropyLoss()
loss = loss_fct(pooled_logits.view(-1, self.num_labels), labels.view(-1))
if not return_dict:
output = (pooled_logits,) + transformer_outputs[1:]
return ((loss,) + output) if loss is not None else output
return SequenceClassifierOutputWithPast(
loss=loss,
logits=pooled_logits,
past_key_values=transformer_outputs.past_key_values,
hidden_states=transformer_outputs.hidden_states,
attentions=transformer_outputs.attentions,
)
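For reference, a minimal usage sketch of the new head (illustrative only, not part of this diff): it assumes the stock gpt2 checkpoint rather than a DialogRPT one, so the score layer starts from freshly initialized weights, and it shows the pad_token_id setup that the batch-size assertion above relies on. With num_labels=1 the same call would compute the MSE regression loss instead of cross-entropy.

import torch
from transformers import GPT2ForSequenceClassification, GPT2Tokenizer

tokenizer = GPT2Tokenizer.from_pretrained("gpt2")
model = GPT2ForSequenceClassification.from_pretrained("gpt2", num_labels=2)

# GPT-2 has no padding token by default; define one so the model can locate
# the last non-padding token in each row (required for batch sizes > 1).
tokenizer.pad_token = tokenizer.eos_token
model.config.pad_token_id = tokenizer.pad_token_id

inputs = tokenizer(["I really liked it!", "Not my thing."], padding=True, return_tensors="pt")
labels = torch.tensor([1, 0])

outputs = model(**inputs, labels=labels, return_dict=True)
print(outputs.loss, outputs.logits.shape)  # logits have shape (batch_size, num_labels)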
@@ -217,6 +217,42 @@ class CausalLMOutputWithPast(ModelOutput):
    attentions: Optional[Tuple[torch.FloatTensor]] = None
@dataclass
class SequenceClassifierOutputWithPast(ModelOutput):
"""
Base class for outputs of sentence classification models.
Args:
loss (:obj:`torch.FloatTensor` of shape :obj:`(1,)`, `optional`, returned when :obj:`labels` is provided):
Classification (or regression if config.num_labels==1) loss.
logits (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, config.num_labels)`):
Classification (or regression if config.num_labels==1) scores (before SoftMax).
past_key_values (:obj:`List[torch.FloatTensor]`, `optional`, returned when ``use_cache=True`` is passed or when ``config.use_cache=True``):
List of :obj:`torch.FloatTensor` of length :obj:`config.n_layers`, with each tensor of shape
:obj:`(2, batch_size, num_heads, sequence_length, embed_size_per_head)`.
Contains pre-computed hidden-states (key and values in the attention blocks) that can be used (see
``past_key_values`` input) to speed up sequential decoding.
hidden_states (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``output_hidden_states=True`` is passed or when ``config.output_hidden_states=True``):
Tuple of :obj:`torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer)
of shape :obj:`(batch_size, sequence_length, hidden_size)`.
Hidden-states of the model at the output of each layer plus the initial embedding outputs.
attentions (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``output_attentions=True`` is passed or when ``config.output_attentions=True``):
Tuple of :obj:`torch.FloatTensor` (one for each layer) of shape
:obj:`(batch_size, num_heads, sequence_length, sequence_length)`.
Attentions weights after the attention softmax, used to compute the weighted average in the self-attention
heads.
"""
loss: Optional[torch.FloatTensor] = None
logits: torch.FloatTensor = None
past_key_values: Optional[List[torch.FloatTensor]] = None
hidden_states: Optional[Tuple[torch.FloatTensor]] = None
attentions: Optional[Tuple[torch.FloatTensor]] = None
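Continuing the sketch above (illustrative only), the fields of the new output class map directly onto the documented shapes when return_dict=True is passed:

out = model(**inputs, return_dict=True, output_hidden_states=True, output_attentions=True)

print(out.logits.shape)          # (batch_size, num_labels), pooled at the last non-padding token
print(len(out.past_key_values))  # one entry per layer, each of shape (2, batch, heads, seq, head_dim)
print(len(out.hidden_states))    # embedding output + one hidden state per layer
print(out.attentions[0].shape)   # (batch, heads, seq, seq) attention weights for the first layer
print(out.loss)                  # None here, since no labels were passed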
@dataclass
class MaskedLMOutput(ModelOutput):
    """
......
@@ -30,6 +30,7 @@ if is_torch_available():
        GPT2_PRETRAINED_MODEL_ARCHIVE_LIST,
        GPT2Config,
        GPT2DoubleHeadsModel,
        GPT2ForSequenceClassification,
        GPT2LMHeadModel,
        GPT2Model,
    )
@@ -87,6 +88,7 @@ class GPT2ModelTester:
        self.scope = None
        self.bos_token_id = vocab_size - 1
        self.eos_token_id = vocab_size - 1
        self.pad_token_id = vocab_size - 1

    def prepare_config_and_inputs(self, gradient_checkpointing=False):
        input_ids = ids_tensor([self.batch_size, self.seq_length], self.vocab_size)
@@ -126,6 +128,7 @@ class GPT2ModelTester:
            # initializer_range=self.initializer_range,
            bos_token_id=self.bos_token_id,
            eos_token_id=self.eos_token_id,
            pad_token_id=self.pad_token_id,
            return_dict=True,
            gradient_checkpointing=gradient_checkpointing,
        )
@@ -337,6 +340,17 @@ class GPT2ModelTester:
        )
        self.parent.assertEqual(result.mc_logits.shape, (self.batch_size, self.num_choices))
    def create_and_check_gpt2_for_sequence_classification(
        self, config, input_ids, input_mask, head_mask, token_type_ids, mc_token_ids, sequence_labels, *args
    ):
        config.num_labels = self.num_labels
        model = GPT2ForSequenceClassification(config)
        model.to(torch_device)
        model.eval()
        result = model(input_ids, attention_mask=input_mask, token_type_ids=token_type_ids, labels=sequence_labels)
        self.parent.assertEqual(result.logits.shape, (self.batch_size, self.num_labels))
    def prepare_config_and_inputs_for_common(self):
        config_and_inputs = self.prepare_config_and_inputs()
@@ -364,10 +378,12 @@ class GPT2ModelTester:
@require_torch
class GPT2ModelTest(ModelTesterMixin, unittest.TestCase):
    all_model_classes = (
        (GPT2Model, GPT2LMHeadModel, GPT2DoubleHeadsModel, GPT2ForSequenceClassification)
        if is_torch_available()
        else ()
    )
    all_generative_model_classes = (GPT2LMHeadModel, GPT2DoubleHeadsModel) if is_torch_available() else ()
    test_missing_keys = False

    def setUp(self):
@@ -401,6 +417,10 @@ class GPT2ModelTest(ModelTesterMixin, unittest.TestCase):
        config_and_inputs = self.model_tester.prepare_config_and_inputs()
        self.model_tester.create_and_check_double_lm_head_model(*config_and_inputs)
    def test_gpt2_sequence_classification_model(self):
        config_and_inputs = self.model_tester.prepare_config_and_inputs()
        self.model_tester.create_and_check_gpt2_for_sequence_classification(*config_and_inputs)

    def test_gpt2_gradient_checkpointing(self):
        config_and_inputs = self.model_tester.prepare_config_and_inputs(gradient_checkpointing=True)
        self.model_tester.create_and_check_forward_and_backwards(*config_and_inputs)
......