"...git@developer.sourcefind.cn:OpenDAS/mmdetection3d.git" did not exist on "9611c2d0aae7a1a667a3eecaa92756fea1073f20"
Unverified commit f867000f, authored by as-stevens and committed by GitHub

[Reformer classification head] Implement the reformer model classification head for text classification (#5198)

* Reformer model classification head implementation for text classification

* Reformat the reformer model classification code

* Addressed PR review comments and implemented test cases for the reformer classification head changes

* Fixed the CI/CD test import error for the reformer classification head

* CI/CD test case implementation: added ReformerForSequenceClassification to all_model_classes

* Fixed code formatting

* Normal test cases added for reformer classification head

* Fix test cases implementation for the reformer classification head

* removed token_type_id parameter from the reformer classification head

* fixed the test case for reformer classification head

* merge conflict with master fixed

* Resolved merge conflict; changed the reformer classification head to accept the choice_label parameter added in the latest code

* Refactored the reformer classification head test code

* Reformer classification head: fixed the common transformer test cases

* Final set of review comments: rearranged the reformer classes and added a docstring to the classification forward method

* Fixed the compilation error and test cases for the reformer classification head

* Apply suggestions from code review

Remove unnecessary dup
Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>
parent f0bda06f
@@ -378,6 +378,7 @@ if is_torch_available():
         ReformerModel,
         ReformerForMaskedLM,
         ReformerModelWithLMHead,
+        ReformerForSequenceClassification,
         ReformerForQuestionAnswering,
         REFORMER_PRETRAINED_MODEL_ARCHIVE_LIST,
     )
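As a quick smoke test of the new export (a minimal sketch, not part of the diff; it assumes an installed transformers build that already contains this commit, plus PyTorch, since the re-export sits behind the is_torch_available() guard above):

# Hypothetical smoke test: this import only works when torch is installed,
# because the class is re-exported inside the is_torch_available() block.
from transformers import ReformerForSequenceClassification

assert ReformerForSequenceClassification.__name__ == "ReformerForSequenceClassification"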
@@ -25,7 +25,7 @@ import numpy as np
 import torch
 from torch import nn
 from torch.autograd.function import Function
-from torch.nn import CrossEntropyLoss
+from torch.nn import CrossEntropyLoss, MSELoss

 from .activations import gelu, gelu_fast, gelu_new, swish
 from .configuration_reformer import ReformerConfig
@@ -36,7 +36,13 @@ from .file_utils import (
     add_start_docstrings,
     add_start_docstrings_to_callable,
 )
-from .modeling_outputs import BaseModelOutput, CausalLMOutput, MaskedLMOutput, QuestionAnsweringModelOutput
+from .modeling_outputs import (
+    BaseModelOutput,
+    CausalLMOutput,
+    MaskedLMOutput,
+    QuestionAnsweringModelOutput,
+    SequenceClassifierOutput,
+)
 from .modeling_utils import PreTrainedModel, apply_chunking_to_forward
@@ -1858,6 +1864,108 @@ class ReformerForMaskedLM(ReformerPreTrainedModel):
         )


+@add_start_docstrings(
+    """Reformer Model transformer with a sequence classification/regression head on top (a linear layer
+    on top of the pooled output) e.g. for GLUE tasks. """,
+    REFORMER_START_DOCSTRING,
+)
+class ReformerForSequenceClassification(ReformerPreTrainedModel):
+    def __init__(self, config):
+        super().__init__(config)
+        self.num_labels = config.num_labels
+        self.reformer = ReformerModel(config)
+        self.classifier = ReformerClassificationHead(config)
+        if config.is_decoder is True:
+            logger.warning("You might want to disable causal masking for sequence classification")
+
+        self.init_weights()
+
+    def tie_weights(self):
+        # word embeddings are not tied in Reformer
+        pass
+
+    @add_start_docstrings_to_callable(REFORMER_INPUTS_DOCSTRING)
+    @add_code_sample_docstrings(
+        tokenizer_class=_TOKENIZER_FOR_DOC,
+        checkpoint="google/reformer-crime-and-punishment",
+        output_type=SequenceClassifierOutput,
+        config_class=_CONFIG_FOR_DOC,
+    )
+    def forward(
+        self,
+        input_ids=None,
+        position_ids=None,
+        attention_mask=None,
+        head_mask=None,
+        inputs_embeds=None,
+        num_hashes=None,
+        labels=None,
+        output_hidden_states=None,
+        output_attentions=None,
+        return_tuple=None,
+    ):
+        r"""
+        labels (:obj:`torch.LongTensor` of shape :obj:`(batch_size,)`, `optional`, defaults to :obj:`None`):
+            Labels for computing the sequence classification/regression loss.
+            Indices should be in :obj:`[0, ..., config.num_labels - 1]`.
+            If :obj:`config.num_labels == 1` a regression loss is computed (Mean-Square loss),
+            If :obj:`config.num_labels > 1` a classification loss is computed (Cross-Entropy).
+        """
+        outputs = self.reformer(
+            input_ids,
+            position_ids=position_ids,
+            attention_mask=attention_mask,
+            head_mask=head_mask,
+            inputs_embeds=inputs_embeds,
+            num_hashes=num_hashes,
+            output_hidden_states=output_hidden_states,
+            output_attentions=output_attentions,
+            return_tuple=return_tuple,
+        )
+
+        sequence_output = outputs[0]
+        logits = self.classifier(sequence_output)
+
+        loss = None
+        if labels is not None:
+            if self.num_labels == 1:
+                # We are doing regression
+                loss_fct = MSELoss()
+                loss = loss_fct(logits.view(-1), labels.view(-1))
+            else:
+                loss_fct = CrossEntropyLoss()
+                loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
+
+        if return_tuple:
+            output = (logits,) + outputs[2:]
+            return ((loss,) + output) if loss is not None else output
+
+        return SequenceClassifierOutput(
+            loss=loss, logits=logits, hidden_states=outputs.hidden_states, attentions=outputs.attentions,
+        )
+
+
+class ReformerClassificationHead(nn.Module):
+    """Head for sentence-level classification tasks."""
+
+    def __init__(self, config):
+        super().__init__()
+        self.dense = nn.Linear(2 * config.hidden_size, config.hidden_size)
+        self.dropout = nn.Dropout(config.hidden_dropout_prob)
+        self.out_proj = nn.Linear(config.hidden_size, config.num_labels)
+
+    def forward(self, hidden_states, **kwargs):
+        hidden_states = hidden_states[:, 0, :]  # take <s> token (equiv. to [CLS])
+        hidden_states = self.dropout(hidden_states)
+        hidden_states = self.dense(hidden_states)
+        hidden_states = torch.tanh(hidden_states)
+        hidden_states = self.dropout(hidden_states)
+        hidden_states = self.out_proj(hidden_states)
+        return hidden_states
+
+
 @add_start_docstrings(
     """Reformer Model with a span classification head on top for
     extractive question-answering tasks like SQuAD / TriviaQA ( a linear layer on
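For context, a rough usage sketch of the new head follows. It is not part of this diff and makes several assumptions: the google/reformer-crime-and-punishment checkpoint named in the @add_code_sample_docstrings decorator is available together with its SentencePiece tokenizer, and the model is run in eval mode so short inputs can be handled; the classification layer is freshly initialized, so the logits are only meaningful after fine-tuning.

import torch
from transformers import ReformerTokenizer, ReformerForSequenceClassification

# Checkpoint name taken from the code-sample decorator above; num_labels=2
# sizes the freshly initialized ReformerClassificationHead.
tokenizer = ReformerTokenizer.from_pretrained("google/reformer-crime-and-punishment")
model = ReformerForSequenceClassification.from_pretrained(
    "google/reformer-crime-and-punishment", num_labels=2
)
model.eval()

# Reformer's chunked attention expects sequence lengths compatible with its
# chunk length; in eval mode short inputs are padded internally (assumption).
input_ids = tokenizer.encode("A short sentence to classify.", return_tensors="pt")
labels = torch.tensor([1])

with torch.no_grad():
    outputs = model(input_ids, labels=labels)

# SequenceClassifierOutput carries the cross-entropy loss and per-class logits.
print(outputs.loss.item(), outputs.logits.shape)  # logits shape: (1, 2)

Note that ReformerClassificationHead projects from 2 * config.hidden_size because the reversible Reformer layers concatenate two hidden-state streams in the model output.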
@@ -28,6 +28,7 @@ if is_torch_available():
         ReformerForMaskedLM,
         ReformerModel,
         ReformerModelWithLMHead,
+        ReformerForSequenceClassification,
         ReformerTokenizer,
         ReformerLayer,
         ReformerForQuestionAnswering,
@@ -77,6 +78,7 @@ class ReformerModelTester:
         eos_token_id=None,
         scope=None,
         hash_seed=None,
+        num_labels=None,
     ):
         self.parent = parent
         self.batch_size = batch_size
@@ -124,6 +126,7 @@ class ReformerModelTester:
         self.encoder_seq_length = seq_length // attn_chunk_length + (self.seq_length % attn_chunk_length != 0)
         self.key_length = (num_chunks_before + num_chunks_after + 1) * attn_chunk_length
         self.chunk_length = attn_chunk_length
+        self.num_labels = num_labels

     def prepare_config_and_inputs(self):
         input_ids = ids_tensor([self.batch_size, self.seq_length], self.vocab_size)
@@ -443,6 +446,22 @@ class ReformerModelTester:
         inputs_dict = {"input_ids": input_ids, "attention_mask": input_mask}
         return config, inputs_dict

+    def create_and_check_reformer_for_sequence_classification(
+        self, config, input_ids, input_mask, choice_labels, is_decoder
+    ):
+        config.is_decoder = is_decoder
+        sequence_labels = ids_tensor([self.batch_size], config.num_labels)
+        model = ReformerForSequenceClassification(config)
+        model.to(torch_device)
+        model.eval()
+        loss, logits = model(input_ids, attention_mask=input_mask, labels=sequence_labels)
+        result = {
+            "loss": loss,
+            "logits": logits,
+        }
+        self.parent.assertListEqual(list(result["logits"].size()), [self.batch_size, self.num_labels])
+        self.check_loss_output(result)
+

 class ReformerTesterMixin:
     """
@@ -510,11 +529,17 @@ class ReformerTesterMixin:
         # Opt-out of this test.
         pass

+    def test_for_sequence_classification(self):
+        config_and_inputs = self.model_tester.prepare_config_and_inputs()
+        self.model_tester.create_and_check_reformer_for_sequence_classification(*config_and_inputs, is_decoder=False)
+

 @require_torch
 class ReformerLocalAttnModelTest(ReformerTesterMixin, ModelTesterMixin, unittest.TestCase):
     all_model_classes = (
-        (ReformerModel, ReformerModelWithLMHead, ReformerForQuestionAnswering) if is_torch_available() else ()
+        (ReformerModel, ReformerModelWithLMHead, ReformerForSequenceClassification, ReformerForQuestionAnswering)
+        if is_torch_available()
+        else ()
     )
     all_generative_model_classes = (ReformerModelWithLMHead,) if is_torch_available() else ()
     test_pruning = False
@@ -554,6 +579,7 @@ class ReformerLocalAttnModelTest(ReformerTesterMixin, ModelTesterMixin, unittest
         "eos_token_id": 2,
         "scope": None,
         "hash_seed": 0,
+        "num_labels": 2,
     }

     def setUp(self):
@@ -571,7 +597,9 @@ class ReformerLocalAttnModelTest(ReformerTesterMixin, ModelTesterMixin, unittest
 @require_torch
 class ReformerLSHAttnModelTest(ReformerTesterMixin, ModelTesterMixin, unittest.TestCase):
     all_model_classes = (
-        (ReformerModel, ReformerModelWithLMHead, ReformerForQuestionAnswering) if is_torch_available() else ()
+        (ReformerModel, ReformerModelWithLMHead, ReformerForSequenceClassification, ReformerForQuestionAnswering)
+        if is_torch_available()
+        else ()
     )
     all_generative_model_classes = (ReformerModelWithLMHead,) if is_torch_available() else ()
     test_pruning = False
@@ -613,6 +641,7 @@ class ReformerLSHAttnModelTest(ReformerTesterMixin, ModelTesterMixin, unittest.T
         "eos_token_id": 2,
         "scope": None,
         "hash_seed": 0,
+        "num_labels": 2,
     }

     def setUp(self):
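Since ReformerTesterMixin is shared by the local-attention and LSH-attention test classes, the new check runs for both configurations. Assuming the repository's usual test layout (the file path and filter string below are assumptions, not part of this diff), it can be selected on its own with pytest's keyword filter:

python -m pytest tests/test_modeling_reformer.py -k "sequence_classification" -v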