Unverified commit 4a9e502a authored by elk-cloner, committed by GitHub

Ctrl for sequence classification (#8812)

* add CTRLForSequenceClassification

* pass local test

* merge with master

* fix modeling test for sequence classification

* fix deco

* fix assert
parent 7f34d757
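For context, a minimal usage sketch of the class this commit introduces. The `ctrl` checkpoint name is the one referenced by the docstring sample in the diff below; `num_labels=2`, the input text and the label are illustrative assumptions, not part of the change:

# Minimal usage sketch of the new CTRLForSequenceClassification head.
# "ctrl" is the checkpoint referenced in the diff's docstrings; num_labels=2,
# the input text and the label are illustrative assumptions.
import torch
from transformers import CTRLForSequenceClassification, CTRLTokenizer

tokenizer = CTRLTokenizer.from_pretrained("ctrl")
model = CTRLForSequenceClassification.from_pretrained("ctrl", num_labels=2)

inputs = tokenizer("Opinion My dog is cute", return_tensors="pt")
labels = torch.tensor([1])  # a single example, so no pad_token_id is required

outputs = model(**inputs, labels=labels)
print(outputs.loss, outputs.logits.shape)  # scalar loss, logits of shape (1, 2)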
@@ -65,6 +65,13 @@ CTRLLMHeadModel
     :members: forward
 
+
+CTRLForSequenceClassification
+~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
+
+.. autoclass:: transformers.CTRLForSequenceClassification
+    :members: forward
+
 
 TFCTRLModel
 ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
@@ -391,7 +391,13 @@ if is_torch_available():
         CamembertForTokenClassification,
         CamembertModel,
     )
-    from .models.ctrl import CTRL_PRETRAINED_MODEL_ARCHIVE_LIST, CTRLLMHeadModel, CTRLModel, CTRLPreTrainedModel
+    from .models.ctrl import (
+        CTRL_PRETRAINED_MODEL_ARCHIVE_LIST,
+        CTRLForSequenceClassification,
+        CTRLLMHeadModel,
+        CTRLModel,
+        CTRLPreTrainedModel,
+    )
     from .models.deberta import (
         DEBERTA_PRETRAINED_MODEL_ARCHIVE_LIST,
         DebertaForSequenceClassification,
@@ -60,7 +60,7 @@ from ..camembert.modeling_camembert import (
     CamembertForTokenClassification,
     CamembertModel,
 )
-from ..ctrl.modeling_ctrl import CTRLLMHeadModel, CTRLModel
+from ..ctrl.modeling_ctrl import CTRLForSequenceClassification, CTRLLMHeadModel, CTRLModel
 from ..deberta.modeling_deberta import DebertaForSequenceClassification, DebertaModel
 from ..distilbert.modeling_distilbert import (
     DistilBertForMaskedLM,

@@ -415,6 +415,7 @@ MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING = OrderedDict(
         (GPT2Config, GPT2ForSequenceClassification),
         (OpenAIGPTConfig, OpenAIGPTForSequenceClassification),
         (ReformerConfig, ReformerForSequenceClassification),
+        (CTRLConfig, CTRLForSequenceClassification),
     ]
 )
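Registering `(CTRLConfig, CTRLForSequenceClassification)` in `MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING` is what lets the auto classes resolve a CTRL checkpoint to the new head. A minimal sketch of that dispatch, assuming the `ctrl` checkpoint and an illustrative `num_labels=2`:

# Sketch: once the mapping entry exists, AutoModelForSequenceClassification
# dispatches a CTRLConfig to CTRLForSequenceClassification.
from transformers import AutoConfig, AutoModelForSequenceClassification

config = AutoConfig.from_pretrained("ctrl", num_labels=2)  # returns a CTRLConfig
model = AutoModelForSequenceClassification.from_pretrained("ctrl", config=config)
print(type(model).__name__)  # CTRLForSequenceClassification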
@@ -8,7 +8,13 @@ from .tokenization_ctrl import CTRLTokenizer
 if is_torch_available():
-    from .modeling_ctrl import CTRL_PRETRAINED_MODEL_ARCHIVE_LIST, CTRLLMHeadModel, CTRLModel, CTRLPreTrainedModel
+    from .modeling_ctrl import (
+        CTRL_PRETRAINED_MODEL_ARCHIVE_LIST,
+        CTRLForSequenceClassification,
+        CTRLLMHeadModel,
+        CTRLModel,
+        CTRLPreTrainedModel,
+    )
 
 if is_tf_available():
     from .modeling_tf_ctrl import (
@@ -18,10 +18,10 @@
 import numpy as np
 import torch
 import torch.nn as nn
-from torch.nn import CrossEntropyLoss
+from torch.nn import CrossEntropyLoss, MSELoss
 
 from ...file_utils import add_code_sample_docstrings, add_start_docstrings, add_start_docstrings_to_model_forward
-from ...modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast
+from ...modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast, SequenceClassifierOutput
 from ...modeling_utils import Conv1D, PreTrainedModel, find_pruneable_heads_and_indices, prune_linear_layer
 from ...utils import logging
 from .configuration_ctrl import CTRLConfig
@@ -571,3 +571,117 @@ class CTRLLMHeadModel(CTRLPreTrainedModel):
             hidden_states=transformer_outputs.hidden_states,
             attentions=transformer_outputs.attentions,
         )
+
+
+@add_start_docstrings(
+    """
+    The CTRL Model transformer with a sequence classification head on top (linear layer).
+
+    :class:`~transformers.CTRLForSequenceClassification` uses the last token in order to do the classification, as
+    other causal models (e.g. GPT-2) do. Since it does classification on the last token, it requires to know the
+    position of the last token. If a :obj:`pad_token_id` is defined in the configuration, it finds the last token that
+    is not a padding token in each row. If no :obj:`pad_token_id` is defined, it simply takes the last value in each
+    row of the batch. Since it cannot guess the padding tokens when :obj:`inputs_embeds` are passed instead of
+    :obj:`input_ids`, it does the same (take the last value in each row of the batch).
+    """,
+    CTRL_START_DOCSTRING,
+)
+class CTRLForSequenceClassification(CTRLPreTrainedModel):
+    def __init__(self, config):
+        super().__init__(config)
+        self.num_labels = config.num_labels
+        self.transformer = CTRLModel(config)
+        self.classifier = nn.Linear(config.n_embd, self.num_labels, bias=False)
+
+        self.init_weights()
+
+    @add_start_docstrings_to_model_forward(CTRL_INPUTS_DOCSTRING)
+    @add_code_sample_docstrings(
+        tokenizer_class=_TOKENIZER_FOR_DOC,
+        checkpoint="ctrl",
+        output_type=SequenceClassifierOutput,
+        config_class=_CONFIG_FOR_DOC,
+    )
+    def forward(
+        self,
+        input_ids=None,
+        past_key_values=None,
+        attention_mask=None,
+        token_type_ids=None,
+        position_ids=None,
+        head_mask=None,
+        inputs_embeds=None,
+        labels=None,
+        use_cache=None,
+        output_attentions=None,
+        output_hidden_states=None,
+        return_dict=None,
+    ):
+        r"""
+        labels (:obj:`torch.LongTensor` of shape :obj:`(batch_size,)`, `optional`):
+            Labels for computing the sequence classification/regression loss. Indices should be in :obj:`[0, ...,
+            config.num_labels - 1]`. If :obj:`config.num_labels == 1` a regression loss is computed (Mean-Square loss),
+            If :obj:`config.num_labels > 1` a classification loss is computed (Cross-Entropy).
+        """
+        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+        transformer_outputs = self.transformer(
+            input_ids,
+            past_key_values=past_key_values,
+            attention_mask=attention_mask,
+            token_type_ids=token_type_ids,
+            position_ids=position_ids,
+            head_mask=head_mask,
+            inputs_embeds=inputs_embeds,
+            use_cache=use_cache,
+            output_attentions=output_attentions,
+            output_hidden_states=output_hidden_states,
+            return_dict=return_dict,
+        )
+
+        hidden_states = transformer_outputs[0]
+        logits = self.classifier(hidden_states)
+
+        if input_ids is not None:
+            batch_size, sequence_length = input_ids.shape[:2]
+        else:
+            batch_size, sequence_length = inputs_embeds.shape[:2]
+
+        assert (
+            self.config.pad_token_id is not None or batch_size == 1
+        ), "Cannot handle batch sizes > 1 if no padding token is defined."
+
+        if self.config.pad_token_id is None:
+            sequence_lengths = -1
+        else:
+            if input_ids is not None:
+                sequence_lengths = torch.ne(input_ids, self.config.pad_token_id).sum(-1) - 1
+            else:
+                sequence_lengths = -1
+                logger.warning(
+                    f"{self.__class__.__name__} will not detect padding tokens in `inputs_embeds`. Results may be "
+                    f"unexpected if using padding tokens in conjunction with `inputs_embeds`."
+                )
+
+        pooled_logits = logits[range(batch_size), sequence_lengths]
+
+        loss = None
+        if labels is not None:
+            if self.num_labels == 1:
+                # We are doing regression
+                loss_fct = MSELoss()
+                loss = loss_fct(pooled_logits.view(-1), labels.to(self.dtype).view(-1))
+            else:
+                loss_fct = CrossEntropyLoss()
+                loss = loss_fct(pooled_logits.view(-1, self.num_labels), labels.view(-1))
+
+        if not return_dict:
+            output = (pooled_logits,) + transformer_outputs[2:]
+            return ((loss,) + output) if loss is not None else output
+
+        return SequenceClassifierOutput(
+            loss=loss,
+            logits=pooled_logits,
+            hidden_states=transformer_outputs.hidden_states,
+            attentions=transformer_outputs.attentions,
+        )
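The docstring above explains the pooling rule: classify from the last non-padding token when `pad_token_id` is set, otherwise fall back to the last position in each row. A standalone sketch of that index computation, with a toy `pad_token_id=0` and made-up token ids and shapes:

# Sketch of the last-non-padding-token pooling used in the forward pass above.
# pad_token_id=0, the token ids and num_labels=3 are illustrative assumptions.
import torch

pad_token_id = 0
input_ids = torch.tensor([[5, 7, 9, 0, 0],   # two padding tokens -> last real token at index 2
                          [3, 4, 6, 8, 2]])  # no padding -> last token at index 4
logits = torch.randn(2, 5, 3)                # (batch_size, sequence_length, num_labels)

# Same computation as in forward(): count non-pad tokens per row, minus one.
sequence_lengths = torch.ne(input_ids, pad_token_id).sum(-1) - 1      # tensor([2, 4])
pooled_logits = logits[range(input_ids.shape[0]), sequence_lengths]   # shape (2, 3)
print(sequence_lengths.tolist(), pooled_logits.shape)

When only `inputs_embeds` are passed, the model cannot see padding tokens, so the index falls back to -1 (the last position) and the warning above is emitted.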
@@ -634,6 +634,15 @@ class CamembertModel:
 CTRL_PRETRAINED_MODEL_ARCHIVE_LIST = None
 
 
+class CTRLForSequenceClassification:
+    def __init__(self, *args, **kwargs):
+        requires_pytorch(self)
+
+    @classmethod
+    def from_pretrained(self, *args, **kwargs):
+        requires_pytorch(self)
+
+
 class CTRLLMHeadModel:
     def __init__(self, *args, **kwargs):
         requires_pytorch(self)
@@ -26,7 +26,13 @@ from .test_modeling_common import ModelTesterMixin, ids_tensor, random_attention
 if is_torch_available():
     import torch
 
-    from transformers import CTRL_PRETRAINED_MODEL_ARCHIVE_LIST, CTRLConfig, CTRLLMHeadModel, CTRLModel
+    from transformers import (
+        CTRL_PRETRAINED_MODEL_ARCHIVE_LIST,
+        CTRLConfig,
+        CTRLForSequenceClassification,
+        CTRLLMHeadModel,
+        CTRLModel,
+    )
 
 
 class CTRLModelTester:

@@ -57,6 +63,7 @@ class CTRLModelTester:
         self.num_labels = 3
         self.num_choices = 4
         self.scope = None
+        self.pad_token_id = self.vocab_size - 1
 
     def prepare_config_and_inputs(self):
         input_ids = ids_tensor([self.batch_size, self.seq_length], self.vocab_size)

@@ -94,6 +101,7 @@ class CTRLModelTester:
             n_ctx=self.max_position_embeddings,
             # type_vocab_size=self.type_vocab_size,
             # initializer_range=self.initializer_range,
+            pad_token_id=self.pad_token_id,
         )
 
         head_mask = ids_tensor([self.num_hidden_layers, self.num_attention_heads], 2)

@@ -149,11 +157,20 @@ class CTRLModelTester:
         return config, inputs_dict
 
+    def create_and_check_ctrl_for_sequence_classification(self, config, input_ids, head_mask, token_type_ids, *args):
+        config.num_labels = self.num_labels
+        model = CTRLForSequenceClassification(config)
+        model.to(torch_device)
+        model.eval()
+        sequence_labels = ids_tensor([self.batch_size], self.type_sequence_label_size)
+        result = model(input_ids, token_type_ids=token_type_ids, labels=sequence_labels)
+        self.parent.assertEqual(result.logits.shape, (self.batch_size, self.num_labels))
+
 
 @require_torch
 class CTRLModelTest(ModelTesterMixin, GenerationTesterMixin, unittest.TestCase):
-    all_model_classes = (CTRLModel, CTRLLMHeadModel) if is_torch_available() else ()
+    all_model_classes = (CTRLModel, CTRLLMHeadModel, CTRLForSequenceClassification) if is_torch_available() else ()
     all_generative_model_classes = (CTRLLMHeadModel,) if is_torch_available() else ()
     test_pruning = True
     test_torchscript = False
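The tester's new `create_and_check_ctrl_for_sequence_classification` builds a model, runs a forward pass with labels, and checks the logits shape; the `CTRLModelTest` method that invokes it is outside this excerpt. A standalone sketch mirroring that check, with a tiny config whose hyperparameter values are illustrative assumptions rather than the tester's exact settings:

# Standalone sketch mirroring the new tester check: build a small CTRLConfig,
# run a forward pass with labels, and verify the pooled logits shape.
# All hyperparameter values below are illustrative assumptions.
import torch
from transformers import CTRLConfig, CTRLForSequenceClassification

config = CTRLConfig(
    vocab_size=99, n_embd=32, n_layer=2, n_head=4, dff=37,
    n_positions=512, n_ctx=512, num_labels=3, pad_token_id=98,
)
model = CTRLForSequenceClassification(config).eval()

input_ids = torch.randint(0, config.vocab_size - 1, (13, 7))  # avoid the pad id
labels = torch.randint(0, config.num_labels, (13,))

with torch.no_grad():
    result = model(input_ids, labels=labels)
assert result.logits.shape == (13, 3)  # (batch_size, num_labels)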