Unverified Commit 0e36e515 authored by Julien Plu's avatar Julien Plu Committed by GitHub
Browse files

Fix the tests for Electra (#6284)

* Fix the tests for Electra

* Apply style
parent 6ba540b7
...@@ -857,7 +857,7 @@ class ElectraForMultipleChoice(ElectraPreTrainedModel): ...@@ -857,7 +857,7 @@ class ElectraForMultipleChoice(ElectraPreTrainedModel):
super().__init__(config) super().__init__(config)
self.electra = ElectraModel(config) self.electra = ElectraModel(config)
self.summary = SequenceSummary(config) self.sequence_summary = SequenceSummary(config)
self.classifier = nn.Linear(config.hidden_size, 1) self.classifier = nn.Linear(config.hidden_size, 1)
self.init_weights() self.init_weights()
...@@ -915,7 +915,7 @@ class ElectraForMultipleChoice(ElectraPreTrainedModel): ...@@ -915,7 +915,7 @@ class ElectraForMultipleChoice(ElectraPreTrainedModel):
sequence_output = discriminator_hidden_states[0] sequence_output = discriminator_hidden_states[0]
pooled_output = self.summary(sequence_output) pooled_output = self.sequence_summary(sequence_output)
logits = self.classifier(pooled_output) logits = self.classifier(pooled_output)
reshaped_logits = logits.view(-1, num_choices) reshaped_logits = logits.view(-1, num_choices)
......
...@@ -63,6 +63,7 @@ class TFElectraModelTester: ...@@ -63,6 +63,7 @@ class TFElectraModelTester:
self.num_labels = 3 self.num_labels = 3
self.num_choices = 4 self.num_choices = 4
self.scope = None self.scope = None
self.embedding_size = 128
def prepare_config_and_inputs(self): def prepare_config_and_inputs(self):
input_ids = ids_tensor([self.batch_size, self.seq_length], self.vocab_size) input_ids = ids_tensor([self.batch_size, self.seq_length], self.vocab_size)
...@@ -194,7 +195,14 @@ class TFElectraModelTester: ...@@ -194,7 +195,14 @@ class TFElectraModelTester:
class TFElectraModelTest(TFModelTesterMixin, unittest.TestCase): class TFElectraModelTest(TFModelTesterMixin, unittest.TestCase):
all_model_classes = ( all_model_classes = (
(TFElectraModel, TFElectraForMaskedLM, TFElectraForPreTraining, TFElectraForTokenClassification,) (
TFElectraModel,
TFElectraForMaskedLM,
TFElectraForPreTraining,
TFElectraForTokenClassification,
TFElectraForMultipleChoice,
TFElectraForSequenceClassification,
)
if is_tf_available() if is_tf_available()
else () else ()
) )
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment