Unverified Commit 8e67573a authored by Patrick von Platen's avatar Patrick von Platen Committed by GitHub
Browse files

[EncoderDecoder Tests] Improve tests (#4046)



* Hoist Bert model tester for Patrick

* indent

* make tests work

* Update tests/test_modeling_bert.py
Co-authored-by: default avatarJulien Chaumond <chaumond@gmail.com>
Co-authored-by: default avatarsshleifer <sshleifer@gmail.com>
Co-authored-by: default avatarJulien Chaumond <chaumond@gmail.com>
parent 6af3306a
...@@ -38,24 +38,7 @@ if is_torch_available(): ...@@ -38,24 +38,7 @@ if is_torch_available():
from transformers.modeling_bert import BERT_PRETRAINED_MODEL_ARCHIVE_MAP from transformers.modeling_bert import BERT_PRETRAINED_MODEL_ARCHIVE_MAP
@require_torch class BertModelTester:
class BertModelTest(ModelTesterMixin, unittest.TestCase):
all_model_classes = (
(
BertModel,
BertForMaskedLM,
BertForNextSentencePrediction,
BertForPreTraining,
BertForQuestionAnswering,
BertForSequenceClassification,
BertForTokenClassification,
)
if is_torch_available()
else ()
)
class BertModelTester(object):
def __init__( def __init__(
self, self,
parent, parent,
...@@ -292,10 +275,7 @@ class BertModelTest(ModelTesterMixin, unittest.TestCase): ...@@ -292,10 +275,7 @@ class BertModelTest(ModelTesterMixin, unittest.TestCase):
model.to(torch_device) model.to(torch_device)
model.eval() model.eval()
loss, seq_relationship_score = model( loss, seq_relationship_score = model(
input_ids, input_ids, attention_mask=input_mask, token_type_ids=token_type_ids, next_sentence_label=sequence_labels,
attention_mask=input_mask,
token_type_ids=token_type_ids,
next_sentence_label=sequence_labels,
) )
result = { result = {
"loss": loss, "loss": loss,
...@@ -374,16 +354,12 @@ class BertModelTest(ModelTesterMixin, unittest.TestCase): ...@@ -374,16 +354,12 @@ class BertModelTest(ModelTesterMixin, unittest.TestCase):
model = BertForTokenClassification(config=config) model = BertForTokenClassification(config=config)
model.to(torch_device) model.to(torch_device)
model.eval() model.eval()
loss, logits = model( loss, logits = model(input_ids, attention_mask=input_mask, token_type_ids=token_type_ids, labels=token_labels)
input_ids, attention_mask=input_mask, token_type_ids=token_type_ids, labels=token_labels
)
result = { result = {
"loss": loss, "loss": loss,
"logits": logits, "logits": logits,
} }
self.parent.assertListEqual( self.parent.assertListEqual(list(result["logits"].size()), [self.batch_size, self.seq_length, self.num_labels])
list(result["logits"].size()), [self.batch_size, self.seq_length, self.num_labels]
)
self.check_loss_output(result) self.check_loss_output(result)
def create_and_check_bert_for_multiple_choice( def create_and_check_bert_for_multiple_choice(
...@@ -423,8 +399,26 @@ class BertModelTest(ModelTesterMixin, unittest.TestCase): ...@@ -423,8 +399,26 @@ class BertModelTest(ModelTesterMixin, unittest.TestCase):
inputs_dict = {"input_ids": input_ids, "token_type_ids": token_type_ids, "attention_mask": input_mask} inputs_dict = {"input_ids": input_ids, "token_type_ids": token_type_ids, "attention_mask": input_mask}
return config, inputs_dict return config, inputs_dict
@require_torch
class BertModelTest(ModelTesterMixin, unittest.TestCase):
all_model_classes = (
(
BertModel,
BertForMaskedLM,
BertForNextSentencePrediction,
BertForPreTraining,
BertForQuestionAnswering,
BertForSequenceClassification,
BertForTokenClassification,
)
if is_torch_available()
else ()
)
def setUp(self): def setUp(self):
self.model_tester = BertModelTest.BertModelTester(self) self.model_tester = BertModelTester(self)
self.config_tester = ConfigTester(self, config_class=BertConfig, hidden_size=37) self.config_tester = ConfigTester(self, config_class=BertConfig, hidden_size=37)
def test_config(self): def test_config(self):
......
...@@ -21,7 +21,7 @@ from transformers import is_torch_available ...@@ -21,7 +21,7 @@ from transformers import is_torch_available
# TODO(PVP): this line reruns all the tests in BertModelTest; not sure whether this can be prevented # TODO(PVP): this line reruns all the tests in BertModelTest; not sure whether this can be prevented
# for now only run module with pytest tests/test_modeling_encoder_decoder.py::EncoderDecoderModelTest # for now only run module with pytest tests/test_modeling_encoder_decoder.py::EncoderDecoderModelTest
from .test_modeling_bert import BertModelTest from .test_modeling_bert import BertModelTester
from .utils import require_torch, slow, torch_device from .utils import require_torch, slow, torch_device
...@@ -34,7 +34,7 @@ if is_torch_available(): ...@@ -34,7 +34,7 @@ if is_torch_available():
@require_torch @require_torch
class EncoderDecoderModelTest(unittest.TestCase): class EncoderDecoderModelTest(unittest.TestCase):
def prepare_config_and_inputs_bert(self): def prepare_config_and_inputs_bert(self):
bert_model_tester = BertModelTest.BertModelTester(self) bert_model_tester = BertModelTester(self)
encoder_config_and_inputs = bert_model_tester.prepare_config_and_inputs() encoder_config_and_inputs = bert_model_tester.prepare_config_and_inputs()
decoder_config_and_inputs = bert_model_tester.prepare_config_and_inputs_for_decoder() decoder_config_and_inputs = bert_model_tester.prepare_config_and_inputs_for_decoder()
( (
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment