"src/nni_manager/vscode:/vscode.git/clone" did not exist on "058b58a6b9eaff708581b380b236a9cc64ef0b14"
Unverified commit 4e10acb3, authored by Sylvain Gugger, committed by GitHub

Add more models to common tests (#4910)

parent 3b3619a3
src/transformers/modeling_distilbert.py
@@ -848,7 +848,7 @@ class DistilBertForTokenClassification(DistilBertPreTrainedModel):
         sequence_output = self.dropout(sequence_output)
         logits = self.classifier(sequence_output)
-        outputs = (logits,) + outputs[2:]  # add hidden states and attention if they are here
+        outputs = (logits,) + outputs[1:]  # add hidden states and attention if they are here
         if labels is not None:
             loss_fct = CrossEntropyLoss()
             # Only keep active parts of the loss
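Note on the indexing fix (context, not part of the diff): the DistilBERT base model has no pooled output, so its output tuple is `(sequence_output, hidden_states, attentions)` when the optional extras are requested. Slicing from index 2 therefore dropped `hidden_states`; slicing from index 1 keeps both extras. A minimal sketch with placeholder values standing in for tensors:

```python
# Illustrative only -- strings stand in for model tensors.
outputs = ("sequence_output", "hidden_states", "attentions")
logits = "logits"

old = (logits,) + outputs[2:]  # ('logits', 'attentions') -- hidden_states lost
new = (logits,) + outputs[1:]  # ('logits', 'hidden_states', 'attentions')
```

The Electra hunk below applies the same one-index fix to the discriminator's output tuple, which likewise has no pooled output.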
src/transformers/modeling_electra.py
@@ -435,7 +435,7 @@ class ElectraForSequenceClassification(ElectraPreTrainedModel):
         sequence_output = discriminator_hidden_states[0]
         logits = self.classifier(sequence_output)
-        outputs = (logits,) + discriminator_hidden_states[2:]  # add hidden states and attention if they are here
+        outputs = (logits,) + discriminator_hidden_states[1:]  # add hidden states and attention if they are here
         if labels is not None:
             if self.num_labels == 1:
src/transformers/modeling_longformer.py
@@ -797,6 +797,8 @@ class LongformerForSequenceClassification(BertPreTrainedModel):
         self.longformer = LongformerModel(config)
         self.classifier = LongformerClassificationHead(config)
         self.init_weights()
+
+    @add_start_docstrings_to_callable(LONGFORMER_INPUTS_DOCSTRING.format("(batch_size, sequence_length)"))
     def forward(
         self,
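For readers unfamiliar with the decorator being added here and in the RoBERTa hunk further down: `add_start_docstrings_to_callable` prepends the shared inputs docstring to the decorated `forward` method's own docstring, so every model documents its arguments consistently. A hedged sketch of the idea (simplified; the real helper in `transformers.file_utils` also interpolates the class name):

```python
# Simplified sketch, not the library implementation.
def add_start_docstrings_to_callable(*docstr):
    def decorator(fn):
        # Prepend the shared argument documentation to fn's own docstring.
        fn.__doc__ = "".join(docstr) + (fn.__doc__ or "")
        return fn
    return decorator

INPUTS_DOCSTRING = "Args:\n    input_ids of shape {}\n"  # stand-in template

@add_start_docstrings_to_callable(INPUTS_DOCSTRING.format("(batch_size, sequence_length)"))
def forward():
    """Model-specific details."""
```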
@@ -861,6 +863,7 @@ class LongformerForSequenceClassification(BertPreTrainedModel):
             token_type_ids=token_type_ids,
             position_ids=position_ids,
             inputs_embeds=inputs_embeds,
+            output_attentions=output_attentions,
         )
         sequence_output = outputs[0]
         logits = self.classifier(sequence_output)
@@ -919,7 +922,7 @@ class LongformerForQuestionAnswering(BertPreTrainedModel):
     @add_start_docstrings_to_callable(LONGFORMER_INPUTS_DOCSTRING.format("(batch_size, sequence_length)"))
     def forward(
         self,
-        input_ids,
+        input_ids=None,
         attention_mask=None,
         global_attention_mask=None,
         token_type_ids=None,
@@ -1099,6 +1102,7 @@ class LongformerForTokenClassification(BertPreTrainedModel):
             token_type_ids=token_type_ids,
             position_ids=position_ids,
             inputs_embeds=inputs_embeds,
+            output_attentions=output_attentions,
         )
         sequence_output = outputs[0]
@@ -1228,6 +1232,7 @@ class LongformerForMultipleChoice(BertPreTrainedModel):
             token_type_ids=flat_token_type_ids,
             attention_mask=flat_attention_mask,
             global_attention_mask=flat_global_attention_mask,
+            output_attentions=output_attentions,
         )
         pooled_output = outputs[1]
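Two changes recur in the Longformer hunks above: `output_attentions` is now forwarded to the base model so attention tensors actually reach the head's outputs, and `input_ids` gains a `None` default. The default matters because the common tests call `forward` with keyword arguments only, and some tests pass `inputs_embeds` instead of `input_ids`. A sketch of the constraint (an assumption about the calling convention, not the library code):

```python
# Sketch: why input_ids needs a default of None.
def forward(input_ids=None, attention_mask=None, inputs_embeds=None, output_attentions=None):
    # Exactly one of input_ids / inputs_embeds must be given.
    if (input_ids is None) == (inputs_embeds is None):
        raise ValueError("Specify exactly one of input_ids or inputs_embeds")
    return {"output_attentions": output_attentions}

# The common tests call with keywords only, sometimes with embeddings;
# this works only because input_ids defaults to None:
forward(inputs_embeds=[[0.0]], output_attentions=True)
```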
src/transformers/modeling_roberta.py
@@ -300,6 +300,8 @@ class RobertaForSequenceClassification(BertPreTrainedModel):
         self.roberta = RobertaModel(config)
         self.classifier = RobertaClassificationHead(config)
         self.init_weights()
+
+    @add_start_docstrings_to_callable(ROBERTA_INPUTS_DOCSTRING.format("(batch_size, sequence_length)"))
     def forward(
         self,
@@ -618,7 +620,7 @@ class RobertaForQuestionAnswering(BertPreTrainedModel):
     @add_start_docstrings_to_callable(ROBERTA_INPUTS_DOCSTRING.format("(batch_size, sequence_length)"))
     def forward(
         self,
-        input_ids,
+        input_ids=None,
         attention_mask=None,
         token_type_ids=None,
         position_ids=None,
tests/test_modeling_distilbert.py
@@ -38,7 +38,13 @@ if is_torch_available():
 class DistilBertModelTest(ModelTesterMixin, unittest.TestCase):

     all_model_classes = (
-        (DistilBertModel, DistilBertForMaskedLM, DistilBertForQuestionAnswering, DistilBertForSequenceClassification)
+        (
+            DistilBertModel,
+            DistilBertForMaskedLM,
+            DistilBertForQuestionAnswering,
+            DistilBertForSequenceClassification,
+            DistilBertForTokenClassification,
+        )
         if is_torch_available()
         else None
     )
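Registering a class in `all_model_classes`, as the test hunks above and below do, is what opts it into the shared test suite: `ModelTesterMixin` iterates over the tuple in each common test, so one entry buys every common check at once. A simplified sketch of the pattern (assumed structure; the real mixin in `tests/test_modeling_common.py` does far more per class):

```python
import unittest

class ModelTesterMixinSketch:
    # Subclasses set this; every common test iterates over it.
    all_model_classes = ()

    def test_each_registered_class(self):
        # `or ()` also covers test files that use `else None` as the fallback.
        for model_class in self.all_model_classes or ():
            # The real mixin builds, runs, and checks a model here.
            self.assertTrue(callable(model_class))

class DummyModelTest(ModelTesterMixinSketch, unittest.TestCase):
    all_model_classes = (dict, list)  # stand-ins for model classes
```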
tests/test_modeling_electra.py
@@ -39,7 +39,15 @@ if is_torch_available():
 class ElectraModelTest(ModelTesterMixin, unittest.TestCase):

     all_model_classes = (
-        (ElectraModel, ElectraForMaskedLM, ElectraForTokenClassification,) if is_torch_available() else ()
+        (
+            ElectraModel,
+            ElectraForPreTraining,
+            ElectraForMaskedLM,
+            ElectraForTokenClassification,
+            ElectraForSequenceClassification,
+        )
+        if is_torch_available()
+        else ()
     )

 class ElectraModelTester(object):
tests/test_modeling_longformer.py
@@ -296,7 +296,19 @@ class LongformerModelTest(ModelTesterMixin, unittest.TestCase):
     test_headmasking = False  # head masking is not supported
     test_torchscript = False

-    all_model_classes = (LongformerModel, LongformerForMaskedLM,) if is_torch_available() else ()
+    all_model_classes = (
+        (
+            LongformerModel,
+            LongformerForMaskedLM,
+            # TODO: make tests pass for those models
+            # LongformerForSequenceClassification,
+            # LongformerForQuestionAnswering,
+            # LongformerForTokenClassification,
+            # LongformerForMultipleChoice,
+        )
+        if is_torch_available()
+        else ()
+    )

     def setUp(self):
         self.model_tester = LongformerModelTester(self)
tests/test_modeling_roberta.py
@@ -29,10 +29,12 @@ if is_torch_available():
         RobertaConfig,
         RobertaModel,
         RobertaForMaskedLM,
+        RobertaForMultipleChoice,
+        RobertaForQuestionAnswering,
         RobertaForSequenceClassification,
         RobertaForTokenClassification,
     )
-    from transformers.modeling_roberta import RobertaEmbeddings, RobertaForMultipleChoice, RobertaForQuestionAnswering
+    from transformers.modeling_roberta import RobertaEmbeddings
     from transformers.modeling_roberta import ROBERTA_PRETRAINED_MODEL_ARCHIVE_LIST
     from transformers.modeling_utils import create_position_ids_from_input_ids
@@ -40,7 +42,18 @@ if is_torch_available():
 @require_torch
 class RobertaModelTest(ModelTesterMixin, unittest.TestCase):

-    all_model_classes = (RobertaForMaskedLM, RobertaModel) if is_torch_available() else ()
+    all_model_classes = (
+        (
+            RobertaForMaskedLM,
+            RobertaModel,
+            RobertaForSequenceClassification,
+            RobertaForTokenClassification,
+            RobertaForMultipleChoice,
+            RobertaForQuestionAnswering,
+        )
+        if is_torch_available()
+        else ()
+    )

 class RobertaModelTester(object):
     def __init__(
tests/test_modeling_xlnet.py
@@ -31,6 +31,7 @@ if is_torch_available():
         XLNetConfig,
         XLNetModel,
         XLNetLMHeadModel,
+        XLNetForMultipleChoice,
         XLNetForSequenceClassification,
         XLNetForTokenClassification,
         XLNetForQuestionAnswering,
@@ -48,6 +49,7 @@ class XLNetModelTest(ModelTesterMixin, unittest.TestCase):
             XLNetForTokenClassification,
             XLNetForSequenceClassification,
             XLNetForQuestionAnswering,
+            XLNetForMultipleChoice,
         )
         if is_torch_available()
         else ()
@@ -84,6 +86,7 @@ class XLNetModelTest(ModelTesterMixin, unittest.TestCase):
         bos_token_id=1,
         eos_token_id=2,
         pad_token_id=5,
+        num_choices=4,
     ):
         self.parent = parent
         self.batch_size = batch_size
@@ -110,6 +113,7 @@ class XLNetModelTest(ModelTesterMixin, unittest.TestCase):
         self.bos_token_id = bos_token_id
         self.pad_token_id = pad_token_id
         self.eos_token_id = eos_token_id
+        self.num_choices = num_choices

     def prepare_config_and_inputs(self):
         input_ids_1 = ids_tensor([self.batch_size, self.seq_length], self.vocab_size)
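The new `num_choices` parameter feeds the multiple-choice tests: those models take inputs of shape `(batch_size, num_choices, seq_length)` and flatten the choice dimension internally. A shape sketch (illustrative, assuming torch is available as in the tests' `is_torch_available()` guard):

```python
import torch

batch_size, num_choices, seq_length = 13, 4, 7

# One token sequence per answer choice, ids_tensor-style random input.
input_ids = torch.randint(0, 99, (batch_size, num_choices, seq_length))

# Multiple-choice heads flatten choices before calling the base model...
flat_input_ids = input_ids.view(-1, seq_length)
assert flat_input_ids.shape == (batch_size * num_choices, seq_length)
# ...and reshape the classifier logits back to (batch_size, num_choices)
# so a cross-entropy loss can pick one choice per example.
```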