"ml/git@developer.sourcefind.cn:OpenDAS/ollama.git" did not exist on "5d22953ba7913cde3f791ba4aa4ae6c55f3f56bf"
Commit 4e10acb3 authored by Sylvain Gugger, committed by GitHub

Add more models to common tests (#4910)

parent 3b3619a3
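The change set has two parts: the test-file hunks register additional head classes in each model's `all_model_classes` tuple, and the modeling-file hunks fix the small issues the shared tests then surface (output-tuple slicing, missing `init_weights()` calls, unforwarded `output_attentions`, `input_ids` made optional). The common tests pick the new classes up automatically because `ModelTesterMixin` iterates over `all_model_classes`; a minimal sketch of that pattern (simplified names, not the real mixin code):

```python
import unittest

# Simplified sketch of the ModelTesterMixin pattern, not the actual transformers test code:
# every common test loops over all_model_classes, so adding a head class to the tuple
# automatically runs it through the shared checks.
class CommonTesterMixin:
    all_model_classes = ()  # each concrete test class overrides this

    def test_forward_signature(self):  # hypothetical common test
        for model_class in self.all_model_classes:
            self.assertTrue(callable(getattr(model_class, "forward", None)))


class DummyModel:
    def forward(self):  # stand-in for a real model's forward pass
        return ()


class DummyModelTest(CommonTesterMixin, unittest.TestCase):
    all_model_classes = (DummyModel,)  # the tuple the diffs below extend per model


if __name__ == "__main__":
    unittest.main()
```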
@@ -848,7 +848,7 @@ class DistilBertForTokenClassification(DistilBertPreTrainedModel):
         sequence_output = self.dropout(sequence_output)
         logits = self.classifier(sequence_output)

-        outputs = (logits,) + outputs[2:]  # add hidden states and attention if they are here
+        outputs = (logits,) + outputs[1:]  # add hidden states and attention if they are here
         if labels is not None:
             loss_fct = CrossEntropyLoss()
             # Only keep active parts of the loss
......
@@ -435,7 +435,7 @@ class ElectraForSequenceClassification(ElectraPreTrainedModel):
         sequence_output = discriminator_hidden_states[0]
         logits = self.classifier(sequence_output)

-        outputs = (logits,) + discriminator_hidden_states[2:]  # add hidden states and attention if they are here
+        outputs = (logits,) + discriminator_hidden_states[1:]  # add hidden states and attention if they are here
         if labels is not None:
             if self.num_labels == 1:
......
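The `[2:]` → `[1:]` change in the DistilBERT and ELECTRA heads above reflects the shape of their base-model output tuples: unlike BERT, neither model returns a pooled output, so the optional hidden states and attentions start at index 1 rather than 2. A self-contained sketch of the indexing (plain tuples standing in for the real outputs):

```python
# Toy stand-ins for the (pre-ModelOutput) tuples returned by the base models.
bert_like_outputs = ("sequence_output", "pooled_output", "hidden_states", "attentions")
distilbert_like_outputs = ("sequence_output", "hidden_states", "attentions")

logits = "logits"

# BERT-style heads skip sequence_output and pooled_output -> slice from index 2.
# DistilBERT/ELECTRA heads only need to skip sequence_output -> slice from index 1.
assert (logits,) + bert_like_outputs[2:] == (logits,) + distilbert_like_outputs[1:]
```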
@@ -797,6 +797,8 @@ class LongformerForSequenceClassification(BertPreTrainedModel):
         self.longformer = LongformerModel(config)
         self.classifier = LongformerClassificationHead(config)

+        self.init_weights()
+
     @add_start_docstrings_to_callable(LONGFORMER_INPUTS_DOCSTRING.format("(batch_size, sequence_length)"))
     def forward(
         self,
@@ -861,6 +863,7 @@ class LongformerForSequenceClassification(BertPreTrainedModel):
             token_type_ids=token_type_ids,
             position_ids=position_ids,
             inputs_embeds=inputs_embeds,
+            output_attentions=output_attentions,
         )
         sequence_output = outputs[0]
         logits = self.classifier(sequence_output)
@@ -919,7 +922,7 @@ class LongformerForQuestionAnswering(BertPreTrainedModel):
     @add_start_docstrings_to_callable(LONGFORMER_INPUTS_DOCSTRING.format("(batch_size, sequence_length)"))
     def forward(
         self,
-        input_ids,
+        input_ids=None,
         attention_mask=None,
         global_attention_mask=None,
         token_type_ids=None,
@@ -1099,6 +1102,7 @@ class LongformerForTokenClassification(BertPreTrainedModel):
             token_type_ids=token_type_ids,
             position_ids=position_ids,
             inputs_embeds=inputs_embeds,
+            output_attentions=output_attentions,
         )
         sequence_output = outputs[0]
@@ -1228,6 +1232,7 @@ class LongformerForMultipleChoice(BertPreTrainedModel):
             token_type_ids=flat_token_type_ids,
             attention_mask=flat_attention_mask,
             global_attention_mask=flat_global_attention_mask,
+            output_attentions=output_attentions,
         )
         pooled_output = outputs[1]
......
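The Longformer heads above now forward `output_attentions` to the base model; without that, the attention tuple never reaches the head's outputs and the common attention checks have nothing to inspect. Roughly the shape of such a check (illustrative only, not the exact test from `ModelTesterMixin`):

```python
# Illustrative check, loosely modeled on the common attention test:
# the head model must pass output_attentions down to its base model,
# otherwise the last element of the output tuple is missing or wrong.
def check_attentions(model, inputs, expected_num_layers):
    outputs = model(**inputs, output_attentions=True)
    attentions = outputs[-1]  # expected: one attention tensor per layer
    assert len(attentions) == expected_num_layers
```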
@@ -300,6 +300,8 @@ class RobertaForSequenceClassification(BertPreTrainedModel):
         self.roberta = RobertaModel(config)
         self.classifier = RobertaClassificationHead(config)

+        self.init_weights()
+
     @add_start_docstrings_to_callable(ROBERTA_INPUTS_DOCSTRING.format("(batch_size, sequence_length)"))
     def forward(
         self,
@@ -618,7 +620,7 @@ class RobertaForQuestionAnswering(BertPreTrainedModel):
     @add_start_docstrings_to_callable(ROBERTA_INPUTS_DOCSTRING.format("(batch_size, sequence_length)"))
     def forward(
         self,
-        input_ids,
+        input_ids=None,
         attention_mask=None,
         token_type_ids=None,
         position_ids=None,
......
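The `self.init_weights()` calls added to the Longformer and RoBERTa sequence-classification heads matter for the shared checks (e.g. the initialization test): without an explicit initialization pass, newly created submodules keep PyTorch's default init rather than the one dictated by the config. A self-contained sketch of the idea (plain `torch.nn`, not the transformers API):

```python
import torch
from torch import nn

class TinyHead(nn.Module):
    """Toy head illustrating why an explicit init pass is needed."""

    def __init__(self, hidden_size=8, num_labels=3, initializer_range=0.02):
        super().__init__()
        self.classifier = nn.Linear(hidden_size, num_labels)
        self._init_weights(initializer_range)  # analogous to self.init_weights() in the diff

    def _init_weights(self, std):
        # Re-initialize with the configured std instead of nn.Linear's default.
        nn.init.normal_(self.classifier.weight, mean=0.0, std=std)
        nn.init.zeros_(self.classifier.bias)

print(TinyHead().classifier.weight.std())  # close to 0.02 rather than the PyTorch default
```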
@@ -38,7 +38,13 @@ if is_torch_available():
 class DistilBertModelTest(ModelTesterMixin, unittest.TestCase):

     all_model_classes = (
-        (DistilBertModel, DistilBertForMaskedLM, DistilBertForQuestionAnswering, DistilBertForSequenceClassification)
+        (
+            DistilBertModel,
+            DistilBertForMaskedLM,
+            DistilBertForQuestionAnswering,
+            DistilBertForSequenceClassification,
+            DistilBertForTokenClassification,
+        )
         if is_torch_available()
         else None
     )
......
@@ -39,7 +39,15 @@ if is_torch_available():
 class ElectraModelTest(ModelTesterMixin, unittest.TestCase):

     all_model_classes = (
-        (ElectraModel, ElectraForMaskedLM, ElectraForTokenClassification,) if is_torch_available() else ()
+        (
+            ElectraModel,
+            ElectraForPreTraining,
+            ElectraForMaskedLM,
+            ElectraForTokenClassification,
+            ElectraForSequenceClassification,
+        )
+        if is_torch_available()
+        else ()
     )

     class ElectraModelTester(object):
......
@@ -296,7 +296,19 @@ class LongformerModelTest(ModelTesterMixin, unittest.TestCase):
     test_headmasking = False  # head masking is not supported
     test_torchscript = False

-    all_model_classes = (LongformerModel, LongformerForMaskedLM,) if is_torch_available() else ()
+    all_model_classes = (
+        (
+            LongformerModel,
+            LongformerForMaskedLM,
+            # TODO: make tests pass for those models
+            # LongformerForSequenceClassification,
+            # LongformerForQuestionAnswering,
+            # LongformerForTokenClassification,
+            # LongformerForMultipleChoice,
+        )
+        if is_torch_available()
+        else ()
+    )

     def setUp(self):
         self.model_tester = LongformerModelTester(self)
......
@@ -29,10 +29,12 @@ if is_torch_available():
         RobertaConfig,
         RobertaModel,
         RobertaForMaskedLM,
+        RobertaForMultipleChoice,
+        RobertaForQuestionAnswering,
         RobertaForSequenceClassification,
         RobertaForTokenClassification,
     )
-    from transformers.modeling_roberta import RobertaEmbeddings, RobertaForMultipleChoice, RobertaForQuestionAnswering
+    from transformers.modeling_roberta import RobertaEmbeddings
     from transformers.modeling_roberta import ROBERTA_PRETRAINED_MODEL_ARCHIVE_LIST
     from transformers.modeling_utils import create_position_ids_from_input_ids
@@ -40,7 +42,18 @@ if is_torch_available():
 @require_torch
 class RobertaModelTest(ModelTesterMixin, unittest.TestCase):

-    all_model_classes = (RobertaForMaskedLM, RobertaModel) if is_torch_available() else ()
+    all_model_classes = (
+        (
+            RobertaForMaskedLM,
+            RobertaModel,
+            RobertaForSequenceClassification,
+            RobertaForTokenClassification,
+            RobertaForMultipleChoice,
+            RobertaForQuestionAnswering,
+        )
+        if is_torch_available()
+        else ()
+    )

     class RobertaModelTester(object):
         def __init__(
......
@@ -31,6 +31,7 @@ if is_torch_available():
         XLNetConfig,
         XLNetModel,
         XLNetLMHeadModel,
+        XLNetForMultipleChoice,
         XLNetForSequenceClassification,
         XLNetForTokenClassification,
         XLNetForQuestionAnswering,
@@ -48,6 +49,7 @@ class XLNetModelTest(ModelTesterMixin, unittest.TestCase):
             XLNetForTokenClassification,
             XLNetForSequenceClassification,
             XLNetForQuestionAnswering,
+            XLNetForMultipleChoice,
         )
         if is_torch_available()
         else ()
@@ -84,6 +86,7 @@ class XLNetModelTest(ModelTesterMixin, unittest.TestCase):
             bos_token_id=1,
             eos_token_id=2,
             pad_token_id=5,
+            num_choices=4,
         ):
             self.parent = parent
             self.batch_size = batch_size
@@ -110,6 +113,7 @@ class XLNetModelTest(ModelTesterMixin, unittest.TestCase):
             self.bos_token_id = bos_token_id
             self.pad_token_id = pad_token_id
             self.eos_token_id = eos_token_id
+            self.num_choices = num_choices

         def prepare_config_and_inputs(self):
             input_ids_1 = ids_tensor([self.batch_size, self.seq_length], self.vocab_size)
......
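The XLNet tester gains a `num_choices` attribute because multiple-choice heads take inputs of shape `(batch_size, num_choices, sequence_length)` and flatten them before the base model sees them. A small sketch of that reshaping (toy tensors, not the tester's actual code):

```python
import torch

batch_size, num_choices, seq_length, vocab_size = 2, 4, 7, 99

# Multiple-choice inputs carry an extra num_choices dimension...
input_ids = torch.randint(0, vocab_size, (batch_size, num_choices, seq_length))

# ...which the head flattens to (batch_size * num_choices, seq_length) before
# calling the base model, then reshapes the scores back to (batch_size, num_choices).
flat_input_ids = input_ids.view(-1, seq_length)
assert flat_input_ids.shape == (batch_size * num_choices, seq_length)
```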