Unverified Commit f867000f authored by as-stevens, committed by GitHub

[Reformer classification head] Implement the reformer model classification head for text classification (#5198)

* Reformer model classification head implementation for text classification

* Reformat the reformer model classification code

* Addressed PR review comments and implemented test cases for the reformer classification head changes

* CI/CD: fixed test import error for the reformer classification head

* CI/CD test case implementation: added ReformerForSequenceClassification to all_model_classes

* Code formatting fixed

* Normal test cases added for reformer classification head

* Fix test case implementation for the reformer classification head

* removed token_type_id parameter from the reformer classification head

* fixed the test case for reformer classification head

* merge conflict with master fixed

* merge conflict: changed reformer classification to accept the choice_label parameter added in the latest code

* refactored the reformer classification head test code

* reformer classification head: common transformer test cases fixed

* final set of review comments: rearranged the reformer classes and added a docstring to the classification forward method

* fixed the compilation error and test cases for the reformer classification head

* Apply suggestions from code review

Remove unnecessary dup
Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>
parent f0bda06f
@@ -378,6 +378,7 @@ if is_torch_available():
        ReformerModel,
        ReformerForMaskedLM,
        ReformerModelWithLMHead,
        ReformerForSequenceClassification,
        ReformerForQuestionAnswering,
        REFORMER_PRETRAINED_MODEL_ARCHIVE_LIST,
    )
@@ -25,7 +25,7 @@ import numpy as np
import torch
from torch import nn
from torch.autograd.function import Function
from torch.nn import CrossEntropyLoss
from torch.nn import CrossEntropyLoss, MSELoss
from .activations import gelu, gelu_fast, gelu_new, swish
from .configuration_reformer import ReformerConfig
@@ -36,7 +36,13 @@ from .file_utils import (
    add_start_docstrings,
    add_start_docstrings_to_callable,
)
from .modeling_outputs import BaseModelOutput, CausalLMOutput, MaskedLMOutput, QuestionAnsweringModelOutput
from .modeling_outputs import (
    BaseModelOutput,
    CausalLMOutput,
    MaskedLMOutput,
    QuestionAnsweringModelOutput,
    SequenceClassifierOutput,
)
from .modeling_utils import PreTrainedModel, apply_chunking_to_forward
@@ -1858,6 +1864,108 @@ class ReformerForMaskedLM(ReformerPreTrainedModel):
)
@add_start_docstrings(
    """Reformer Model transformer with a sequence classification/regression head on top (a linear layer
    on top of the pooled output) e.g. for GLUE tasks. """,
    REFORMER_START_DOCSTRING,
)
class ReformerForSequenceClassification(ReformerPreTrainedModel):
    def __init__(self, config):
        super().__init__(config)
        self.num_labels = config.num_labels
        self.reformer = ReformerModel(config)
        self.classifier = ReformerClassificationHead(config)
        if config.is_decoder is True:
            logger.warning("You might want to disable causal masking for sequence classification")

        self.init_weights()

    def tie_weights(self):
        # word embeddings are not tied in Reformer
        pass

    @add_start_docstrings_to_callable(REFORMER_INPUTS_DOCSTRING)
    @add_code_sample_docstrings(
        tokenizer_class=_TOKENIZER_FOR_DOC,
        checkpoint="google/reformer-crime-and-punishment",
        output_type=SequenceClassifierOutput,
        config_class=_CONFIG_FOR_DOC,
    )
    def forward(
        self,
        input_ids=None,
        position_ids=None,
        attention_mask=None,
        head_mask=None,
        inputs_embeds=None,
        num_hashes=None,
        labels=None,
        output_hidden_states=None,
        output_attentions=None,
        return_tuple=None,
    ):
        r"""
        labels (:obj:`torch.LongTensor` of shape :obj:`(batch_size,)`, `optional`, defaults to :obj:`None`):
            Labels for computing the sequence classification/regression loss.
            Indices should be in :obj:`[0, ..., config.num_labels - 1]`.
            If :obj:`config.num_labels == 1` a regression loss is computed (Mean-Square loss),
            If :obj:`config.num_labels > 1` a classification loss is computed (Cross-Entropy).
        """
        outputs = self.reformer(
            input_ids,
            position_ids=position_ids,
            attention_mask=attention_mask,
            head_mask=head_mask,
            inputs_embeds=inputs_embeds,
            num_hashes=num_hashes,
            output_hidden_states=output_hidden_states,
            output_attentions=output_attentions,
            return_tuple=return_tuple,
        )

        sequence_output = outputs[0]
        logits = self.classifier(sequence_output)

        loss = None
        if labels is not None:
            if self.num_labels == 1:
                # We are doing regression
                loss_fct = MSELoss()
                loss = loss_fct(logits.view(-1), labels.view(-1))
            else:
                loss_fct = CrossEntropyLoss()
                loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))

        if return_tuple:
            output = (logits,) + outputs[2:]
            return ((loss,) + output) if loss is not None else output

        return SequenceClassifierOutput(
            loss=loss, logits=logits, hidden_states=outputs.hidden_states, attentions=outputs.attentions,
        )


class ReformerClassificationHead(nn.Module):
    """Head for sentence-level classification tasks."""

    def __init__(self, config):
        super().__init__()
        self.dense = nn.Linear(2 * config.hidden_size, config.hidden_size)
        self.dropout = nn.Dropout(config.hidden_dropout_prob)
        self.out_proj = nn.Linear(config.hidden_size, config.num_labels)

    def forward(self, hidden_states, **kwargs):
        hidden_states = hidden_states[:, 0, :]  # take <s> token (equiv. to [CLS])
        hidden_states = self.dropout(hidden_states)
        hidden_states = self.dense(hidden_states)
        hidden_states = torch.tanh(hidden_states)
        hidden_states = self.dropout(hidden_states)
        hidden_states = self.out_proj(hidden_states)
        return hidden_states


@add_start_docstrings(
    """Reformer Model with a span classification head on top for
    extractive question-answering tasks like SQuAD / TriviaQA ( a linear layer on
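Not part of the diff: a minimal usage sketch of the new classification head, assuming the google/reformer-crime-and-punishment checkpoint referenced in the code-sample docstring above and an installed transformers build that already contains this change. The classifier layers are freshly initialized, so the snippet only illustrates tensor shapes and which loss branch is taken.

# Illustrative only, not part of this PR.
import torch
from transformers import ReformerForSequenceClassification, ReformerTokenizer

tokenizer = ReformerTokenizer.from_pretrained("google/reformer-crime-and-punishment")
model = ReformerForSequenceClassification.from_pretrained(
    "google/reformer-crime-and-punishment",
    num_labels=2,      # num_labels > 1 -> cross-entropy in forward()
    is_decoder=False,  # avoids the causal-masking warning emitted in __init__
)
model.eval()

input_ids = tokenizer.encode("He buried the money under the floorboards.", return_tensors="pt")
labels = torch.tensor([1])  # class index for the single example

with torch.no_grad():
    outputs = model(input_ids, labels=labels)

# forward() returns a SequenceClassifierOutput by default in this version
print(outputs.logits.shape)  # torch.Size([1, 2])
print(outputs.loss)          # scalar cross-entropy loss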
@@ -28,6 +28,7 @@ if is_torch_available():
        ReformerForMaskedLM,
        ReformerModel,
        ReformerModelWithLMHead,
        ReformerForSequenceClassification,
        ReformerTokenizer,
        ReformerLayer,
        ReformerForQuestionAnswering,
@@ -77,6 +78,7 @@ class ReformerModelTester:
        eos_token_id=None,
        scope=None,
        hash_seed=None,
        num_labels=None,
    ):
        self.parent = parent
        self.batch_size = batch_size
@@ -124,6 +126,7 @@ class ReformerModelTester:
        self.encoder_seq_length = seq_length // attn_chunk_length + (self.seq_length % attn_chunk_length != 0)
        self.key_length = (num_chunks_before + num_chunks_after + 1) * attn_chunk_length
        self.chunk_length = attn_chunk_length
        self.num_labels = num_labels

    def prepare_config_and_inputs(self):
        input_ids = ids_tensor([self.batch_size, self.seq_length], self.vocab_size)
@@ -443,6 +446,22 @@ class ReformerModelTester:
        inputs_dict = {"input_ids": input_ids, "attention_mask": input_mask}
        return config, inputs_dict

    def create_and_check_reformer_for_sequence_classification(
        self, config, input_ids, input_mask, choice_labels, is_decoder
    ):
        config.is_decoder = is_decoder
        sequence_labels = ids_tensor([self.batch_size], config.num_labels)
        model = ReformerForSequenceClassification(config)
        model.to(torch_device)
        model.eval()
        loss, logits = model(input_ids, attention_mask=input_mask, labels=sequence_labels)
        result = {
            "loss": loss,
            "logits": logits,
        }
        self.parent.assertListEqual(list(result["logits"].size()), [self.batch_size, self.num_labels])
        self.check_loss_output(result)

class ReformerTesterMixin:
    """
@@ -510,11 +529,17 @@ class ReformerTesterMixin:
        # Opt-out of this test.
        pass

    def test_for_sequence_classification(self):
        config_and_inputs = self.model_tester.prepare_config_and_inputs()
        self.model_tester.create_and_check_reformer_for_sequence_classification(*config_and_inputs, is_decoder=False)
@require_torch
class ReformerLocalAttnModelTest(ReformerTesterMixin, ModelTesterMixin, unittest.TestCase):
    all_model_classes = (
        (ReformerModel, ReformerModelWithLMHead, ReformerForQuestionAnswering) if is_torch_available() else ()
        (ReformerModel, ReformerModelWithLMHead, ReformerForSequenceClassification, ReformerForQuestionAnswering)
        if is_torch_available()
        else ()
    )
    all_generative_model_classes = (ReformerModelWithLMHead,) if is_torch_available() else ()
    test_pruning = False
@@ -554,6 +579,7 @@ class ReformerLocalAttnModelTest(ReformerTesterMixin, ModelTesterMixin, unittest
            "eos_token_id": 2,
            "scope": None,
            "hash_seed": 0,
            "num_labels": 2,
        }

    def setUp(self):
@@ -571,7 +597,9 @@ class ReformerLocalAttnModelTest(ReformerTesterMixin, ModelTesterMixin, unittest
@require_torch
class ReformerLSHAttnModelTest(ReformerTesterMixin, ModelTesterMixin, unittest.TestCase):
    all_model_classes = (
        (ReformerModel, ReformerModelWithLMHead, ReformerForQuestionAnswering) if is_torch_available() else ()
        (ReformerModel, ReformerModelWithLMHead, ReformerForSequenceClassification, ReformerForQuestionAnswering)
        if is_torch_available()
        else ()
    )
    all_generative_model_classes = (ReformerModelWithLMHead,) if is_torch_available() else ()
    test_pruning = False
@@ -613,6 +641,7 @@ class ReformerLSHAttnModelTest(ReformerTesterMixin, ModelTesterMixin, unittest.T
            "eos_token_id": 2,
            "scope": None,
            "hash_seed": 0,
            "num_labels": 2,
        }

    def setUp(self):
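As a closing note, here is a self-contained sketch (not from the PR) of the data flow inside ReformerClassificationHead: Reformer's reversible trunk concatenates its two hidden streams, so the sequence output fed to the head has last dimension 2 * hidden_size; the head pools the first token and projects it through hidden_size down to num_labels. The ToyClassificationHead name and the SimpleNamespace config below are illustrative stand-ins, not library objects.

# Illustrative only, not part of this PR.
from types import SimpleNamespace

import torch
from torch import nn


class ToyClassificationHead(nn.Module):
    """Mirror of ReformerClassificationHead above, renamed to make clear it is a copy."""

    def __init__(self, config):
        super().__init__()
        self.dense = nn.Linear(2 * config.hidden_size, config.hidden_size)  # 2x: concatenated streams
        self.dropout = nn.Dropout(config.hidden_dropout_prob)
        self.out_proj = nn.Linear(config.hidden_size, config.num_labels)

    def forward(self, hidden_states):
        hidden_states = hidden_states[:, 0, :]  # pool the first token, like [CLS]
        hidden_states = self.dropout(hidden_states)
        hidden_states = torch.tanh(self.dense(hidden_states))
        hidden_states = self.dropout(hidden_states)
        return self.out_proj(hidden_states)  # (batch_size, num_labels)


config = SimpleNamespace(hidden_size=16, hidden_dropout_prob=0.1, num_labels=2)
head = ToyClassificationHead(config)
sequence_output = torch.randn(4, 12, 2 * config.hidden_size)  # (batch, seq_len, 2 * hidden_size)
print(head(sequence_output).shape)  # torch.Size([4, 2])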