Unverified Commit f6b44e61 authored by sandip's avatar sandip Committed by GitHub
Browse files

Transfoxl seq classification (#8868)

* Transfoxl sequence classification

* Transfoxl sequence classification
parent 24f0c2fe
...@@ -75,6 +75,11 @@ TransfoXLLMHeadModel ...@@ -75,6 +75,11 @@ TransfoXLLMHeadModel
.. autoclass:: transformers.TransfoXLLMHeadModel .. autoclass:: transformers.TransfoXLLMHeadModel
:members: forward :members: forward
TransfoXLForSequenceClassification
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
.. autoclass:: transformers.TransfoXLForSequenceClassification
:members: forward
TFTransfoXLModel TFTransfoXLModel
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
......
...@@ -578,6 +578,7 @@ if is_torch_available(): ...@@ -578,6 +578,7 @@ if is_torch_available():
from .models.transfo_xl import ( from .models.transfo_xl import (
TRANSFO_XL_PRETRAINED_MODEL_ARCHIVE_LIST, TRANSFO_XL_PRETRAINED_MODEL_ARCHIVE_LIST,
AdaptiveEmbedding, AdaptiveEmbedding,
TransfoXLForSequenceClassification,
TransfoXLLMHeadModel, TransfoXLLMHeadModel,
TransfoXLModel, TransfoXLModel,
TransfoXLPreTrainedModel, TransfoXLPreTrainedModel,
......
...@@ -157,7 +157,7 @@ from ..squeezebert.modeling_squeezebert import ( ...@@ -157,7 +157,7 @@ from ..squeezebert.modeling_squeezebert import (
SqueezeBertModel, SqueezeBertModel,
) )
from ..t5.modeling_t5 import T5ForConditionalGeneration, T5Model from ..t5.modeling_t5 import T5ForConditionalGeneration, T5Model
from ..transfo_xl.modeling_transfo_xl import TransfoXLLMHeadModel, TransfoXLModel from ..transfo_xl.modeling_transfo_xl import TransfoXLForSequenceClassification, TransfoXLLMHeadModel, TransfoXLModel
from ..xlm.modeling_xlm import ( from ..xlm.modeling_xlm import (
XLMForMultipleChoice, XLMForMultipleChoice,
XLMForQuestionAnsweringSimple, XLMForQuestionAnsweringSimple,
...@@ -416,6 +416,7 @@ MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING = OrderedDict( ...@@ -416,6 +416,7 @@ MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING = OrderedDict(
(OpenAIGPTConfig, OpenAIGPTForSequenceClassification), (OpenAIGPTConfig, OpenAIGPTForSequenceClassification),
(ReformerConfig, ReformerForSequenceClassification), (ReformerConfig, ReformerForSequenceClassification),
(CTRLConfig, CTRLForSequenceClassification), (CTRLConfig, CTRLForSequenceClassification),
(TransfoXLConfig, TransfoXLForSequenceClassification),
] ]
) )
......
...@@ -11,6 +11,7 @@ if is_torch_available(): ...@@ -11,6 +11,7 @@ if is_torch_available():
from .modeling_transfo_xl import ( from .modeling_transfo_xl import (
TRANSFO_XL_PRETRAINED_MODEL_ARCHIVE_LIST, TRANSFO_XL_PRETRAINED_MODEL_ARCHIVE_LIST,
AdaptiveEmbedding, AdaptiveEmbedding,
TransfoXLForSequenceClassification,
TransfoXLLMHeadModel, TransfoXLLMHeadModel,
TransfoXLModel, TransfoXLModel,
TransfoXLPreTrainedModel, TransfoXLPreTrainedModel,
......
...@@ -23,6 +23,7 @@ from typing import List, Optional, Tuple ...@@ -23,6 +23,7 @@ from typing import List, Optional, Tuple
import torch import torch
import torch.nn as nn import torch.nn as nn
import torch.nn.functional as F import torch.nn.functional as F
from torch.nn import CrossEntropyLoss, MSELoss
from ...file_utils import ( from ...file_utils import (
ModelOutput, ModelOutput,
...@@ -632,6 +633,40 @@ class TransfoXLModelOutput(ModelOutput): ...@@ -632,6 +633,40 @@ class TransfoXLModelOutput(ModelOutput):
attentions: Optional[Tuple[torch.FloatTensor]] = None attentions: Optional[Tuple[torch.FloatTensor]] = None
@dataclass
class TransfoXLSequenceClassifierOutputWithPast(ModelOutput):
"""
Base class for outputs of sentence classification models.
Args:
loss (:obj:`torch.FloatTensor` of shape :obj:`(1,)`, `optional`, returned when :obj:`labels` is provided):
Classification (or regression if config.num_labels==1) loss.
logits (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, config.num_labels)`):
Classification (or regression if config.num_labels==1) scores (before SoftMax).
mems (:obj:`List[torch.FloatTensor]` of length :obj:`config.n_layers`):
Contains pre-computed hidden-states (key and values in the attention blocks). Can be used (see :obj:`mems`
input) to speed up sequential decoding. The token ids which have their past given to this model should not
be passed as input ids as they have already been computed.
hidden_states (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``output_hidden_states=True`` is passed or when ``config.output_hidden_states=True``):
Tuple of :obj:`torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer)
of shape :obj:`(batch_size, sequence_length, hidden_size)`.
Hidden-states of the model at the output of each layer plus the initial embedding outputs.
attentions (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``output_attentions=True`` is passed or when ``config.output_attentions=True``):
Tuple of :obj:`torch.FloatTensor` (one for each layer) of shape :obj:`(batch_size, num_heads,
sequence_length, sequence_length)`.
Attentions weights after the attention softmax, used to compute the weighted average in the self-attention
heads.
"""
loss: Optional[torch.FloatTensor] = None
logits: torch.FloatTensor = None
mems: List[torch.FloatTensor] = None
hidden_states: Optional[Tuple[torch.FloatTensor]] = None
attentions: Optional[Tuple[torch.FloatTensor]] = None
@dataclass @dataclass
class TransfoXLLMHeadModelOutput(ModelOutput): class TransfoXLLMHeadModelOutput(ModelOutput):
""" """
...@@ -1101,3 +1136,110 @@ class TransfoXLLMHeadModel(TransfoXLPreTrainedModel): ...@@ -1101,3 +1136,110 @@ class TransfoXLLMHeadModel(TransfoXLPreTrainedModel):
self.crit.cutoffs = new_cutoffs self.crit.cutoffs = new_cutoffs
self.crit.cutoff_ends = [0] + new_cutoffs self.crit.cutoff_ends = [0] + new_cutoffs
self.crit.n_token = new_num_tokens self.crit.n_token = new_num_tokens
@add_start_docstrings(
"""
The Transformer-XL Model transformer with a sequence classification head on top (linear layer).
:class:`~transformers.TransfoXLForSequenceClassification` uses the last token in order to do the classification, as
other causal models (e.g. GPT-1) do.
Since it does classification on the last token, it requires to know the position of the last token. If a
:obj:`pad_token_id` is defined in the configuration, it finds the last token that is not a padding token in each
row. If no :obj:`pad_token_id` is defined, it simply takes the last value in each row of the batch. Since it cannot
guess the padding tokens when :obj:`inputs_embeds` are passed instead of :obj:`input_ids`, it does the same (take
the last value in each row of the batch).
""",
TRANSFO_XL_START_DOCSTRING,
)
class TransfoXLForSequenceClassification(TransfoXLPreTrainedModel):
_keys_to_ignore_on_load_missing = [r"h\.\d+\.attn\.masked_bias", r"lm_head\.weight"]
def __init__(self, config):
super().__init__(config)
self.num_labels = config.num_labels
self.transformer = TransfoXLModel(config)
self.score = nn.Linear(config.d_embed, self.num_labels, bias=False)
self.init_weights()
@add_start_docstrings_to_model_forward(TRANSFO_XL_INPUTS_DOCSTRING)
@add_code_sample_docstrings(
tokenizer_class=_TOKENIZER_FOR_DOC,
checkpoint="transfo-xl-wt103",
output_type=TransfoXLSequenceClassifierOutputWithPast,
config_class=_CONFIG_FOR_DOC,
)
def forward(
self,
input_ids=None,
mems=None,
head_mask=None,
inputs_embeds=None,
labels=None,
output_attentions=None,
output_hidden_states=None,
return_dict=None,
):
r"""
labels (:obj:`torch.LongTensor` of shape :obj:`(batch_size,)`, `optional`):
Labels for computing the sequence classification/regression loss. Indices should be in :obj:`[0, ...,
config.num_labels - 1]`. If :obj:`config.num_labels == 1` a regression loss is computed (Mean-Square loss),
If :obj:`config.num_labels > 1` a classification loss is computed (Cross-Entropy).
"""
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
transformer_outputs = self.transformer(
input_ids,
mems=mems,
head_mask=head_mask,
inputs_embeds=inputs_embeds,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
)
hidden_states = transformer_outputs[0]
logits = self.score(hidden_states)
if input_ids is not None:
batch_size, sequence_length = input_ids.shape[:2]
else:
batch_size, sequence_length = inputs_embeds.shape[:2]
assert (
self.config.pad_token_id is not None or batch_size == 1
), "Cannot handle batch sizes > 1 if no padding token is defined."
if self.config.pad_token_id is None:
sequence_lengths = -1
else:
if input_ids is not None:
sequence_lengths = torch.ne(input_ids, self.config.pad_token_id).sum(-1) - 1
else:
sequence_lengths = -1
logger.warning(
f"{self.__class__.__name__} will not detect padding tokens in `inputs_embeds`. Results may be "
f"unexpected if using padding tokens in conjunction with `inputs_embeds.`"
)
pooled_logits = logits[range(batch_size), sequence_lengths]
loss = None
if labels is not None:
if self.num_labels == 1:
loss_fct = MSELoss()
loss = loss_fct(pooled_logits.view(-1), labels.to(self.dtype).view(-1))
else:
loss_fct = CrossEntropyLoss()
loss = loss_fct(pooled_logits.view(-1, self.num_labels), labels.view(-1))
if not return_dict:
output = (pooled_logits,) + transformer_outputs[1:]
return ((loss,) + output) if loss is not None else output
return TransfoXLSequenceClassifierOutputWithPast(
loss=loss,
logits=pooled_logits,
mems=transformer_outputs.mems,
hidden_states=transformer_outputs.hidden_states,
attentions=transformer_outputs.attentions,
)
...@@ -1785,6 +1785,15 @@ class AdaptiveEmbedding: ...@@ -1785,6 +1785,15 @@ class AdaptiveEmbedding:
requires_pytorch(self) requires_pytorch(self)
class TransfoXLForSequenceClassification:
def __init__(self, *args, **kwargs):
requires_pytorch(self)
@classmethod
def from_pretrained(self, *args, **kwargs):
requires_pytorch(self)
class TransfoXLLMHeadModel: class TransfoXLLMHeadModel:
def __init__(self, *args, **kwargs): def __init__(self, *args, **kwargs):
requires_pytorch(self) requires_pytorch(self)
......
...@@ -27,7 +27,7 @@ from .test_modeling_common import ModelTesterMixin, ids_tensor ...@@ -27,7 +27,7 @@ from .test_modeling_common import ModelTesterMixin, ids_tensor
if is_torch_available(): if is_torch_available():
import torch import torch
from transformers import TransfoXLConfig, TransfoXLLMHeadModel, TransfoXLModel from transformers import TransfoXLConfig, TransfoXLForSequenceClassification, TransfoXLLMHeadModel, TransfoXLModel
from transformers.models.transfo_xl.modeling_transfo_xl import TRANSFO_XL_PRETRAINED_MODEL_ARCHIVE_LIST from transformers.models.transfo_xl.modeling_transfo_xl import TRANSFO_XL_PRETRAINED_MODEL_ARCHIVE_LIST
...@@ -56,6 +56,8 @@ class TransfoXLModelTester: ...@@ -56,6 +56,8 @@ class TransfoXLModelTester:
self.scope = None self.scope = None
self.seed = 1 self.seed = 1
self.eos_token_id = 0 self.eos_token_id = 0
self.num_labels = 3
self.pad_token_id = self.vocab_size - 1
def prepare_config_and_inputs(self): def prepare_config_and_inputs(self):
input_ids_1 = ids_tensor([self.batch_size, self.seq_length], self.vocab_size) input_ids_1 = ids_tensor([self.batch_size, self.seq_length], self.vocab_size)
...@@ -78,6 +80,7 @@ class TransfoXLModelTester: ...@@ -78,6 +80,7 @@ class TransfoXLModelTester:
div_val=self.div_val, div_val=self.div_val,
n_layer=self.num_hidden_layers, n_layer=self.num_hidden_layers,
eos_token_id=self.eos_token_id, eos_token_id=self.eos_token_id,
pad_token_id=self.pad_token_id,
) )
return (config, input_ids_1, input_ids_2, lm_labels) return (config, input_ids_1, input_ids_2, lm_labels)
...@@ -148,6 +151,14 @@ class TransfoXLModelTester: ...@@ -148,6 +151,14 @@ class TransfoXLModelTester:
[(self.mem_len, self.batch_size, self.hidden_size)] * self.num_hidden_layers, [(self.mem_len, self.batch_size, self.hidden_size)] * self.num_hidden_layers,
) )
def create_and_check_transfo_xl_for_sequence_classification(self, config, input_ids_1, input_ids_2, lm_labels):
config.num_labels = self.num_labels
model = TransfoXLForSequenceClassification(config)
model.to(torch_device)
model.eval()
result = model(input_ids_1)
self.parent.assertEqual(result.logits.shape, (self.batch_size, self.num_labels))
def prepare_config_and_inputs_for_common(self): def prepare_config_and_inputs_for_common(self):
config_and_inputs = self.prepare_config_and_inputs() config_and_inputs = self.prepare_config_and_inputs()
(config, input_ids_1, input_ids_2, lm_labels) = config_and_inputs (config, input_ids_1, input_ids_2, lm_labels) = config_and_inputs
...@@ -157,7 +168,9 @@ class TransfoXLModelTester: ...@@ -157,7 +168,9 @@ class TransfoXLModelTester:
@require_torch @require_torch
class TransfoXLModelTest(ModelTesterMixin, GenerationTesterMixin, unittest.TestCase): class TransfoXLModelTest(ModelTesterMixin, GenerationTesterMixin, unittest.TestCase):
all_model_classes = (TransfoXLModel, TransfoXLLMHeadModel) if is_torch_available() else () all_model_classes = (
(TransfoXLModel, TransfoXLLMHeadModel, TransfoXLForSequenceClassification) if is_torch_available() else ()
)
all_generative_model_classes = (TransfoXLLMHeadModel,) if is_torch_available() else () all_generative_model_classes = (TransfoXLLMHeadModel,) if is_torch_available() else ()
test_pruning = False test_pruning = False
test_torchscript = False test_torchscript = False
...@@ -204,6 +217,10 @@ class TransfoXLModelTest(ModelTesterMixin, GenerationTesterMixin, unittest.TestC ...@@ -204,6 +217,10 @@ class TransfoXLModelTest(ModelTesterMixin, GenerationTesterMixin, unittest.TestC
output_result = self.model_tester.create_transfo_xl_lm_head(*config_and_inputs) output_result = self.model_tester.create_transfo_xl_lm_head(*config_and_inputs)
self.model_tester.check_transfo_xl_lm_head_output(output_result) self.model_tester.check_transfo_xl_lm_head_output(output_result)
def test_transfo_xl_sequence_classification_model(self):
config_and_inputs = self.model_tester.prepare_config_and_inputs()
self.model_tester.create_and_check_transfo_xl_for_sequence_classification(*config_and_inputs)
def test_retain_grad_hidden_states_attentions(self): def test_retain_grad_hidden_states_attentions(self):
# xlnet cannot keep gradients in attentions or hidden states # xlnet cannot keep gradients in attentions or hidden states
return return
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment