"mmdet3d/ops/vscode:/vscode.git/clone" did not exist on "b107238d91d505c1c7a8f982a5e56c319df8c9f1"
Unverified commit 8e67573a authored by Patrick von Platen, committed by GitHub

[EncoderDecoder Tests] Improve tests (#4046)



* Hoist bert model tester for patric

* indent

* make tests work

* Update tests/test_modeling_bert.py
Co-authored-by: Julien Chaumond <chaumond@gmail.com>
Co-authored-by: sshleifer <sshleifer@gmail.com>
Co-authored-by: Julien Chaumond <chaumond@gmail.com>
parent 6af3306a
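
In short, the commit hoists the BertModelTester helper out of the BertModelTest class to module level, so that the encoder-decoder tests can reuse it without importing a unittest.TestCase subclass; the middle hunks only collapse wrapped call arguments onto single lines. A minimal sketch of the before/after pattern, with simplified constructor bodies (the *Before/*After class names and the batch_size field are illustrative, not taken from the diff):

import unittest


# Before: the helper was nested in the TestCase, so reusers had to import
# BertModelTest itself, and the test runner would collect (and re-run) its
# tests in every module that imported it.
class BertModelTestBefore(unittest.TestCase):
    class BertModelTester(object):
        def __init__(self, parent, batch_size=13):
            self.parent = parent
            self.batch_size = batch_size


# After: the helper is a plain module-level class. It is not a TestCase,
# so importing it elsewhere adds nothing to test collection.
class BertModelTester:
    def __init__(self, parent, batch_size=13):
        self.parent = parent
        self.batch_size = batch_size


class BertModelTestAfter(unittest.TestCase):
    def setUp(self):
        self.model_tester = BertModelTester(self)

    def test_tester_is_wired(self):
        self.assertIs(self.model_tester.parent, self)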
tests/test_modeling_bert.py
@@ -38,24 +38,7 @@ if is_torch_available():
     from transformers.modeling_bert import BERT_PRETRAINED_MODEL_ARCHIVE_MAP
 
 
-@require_torch
-class BertModelTest(ModelTesterMixin, unittest.TestCase):
-
-    all_model_classes = (
-        (
-            BertModel,
-            BertForMaskedLM,
-            BertForNextSentencePrediction,
-            BertForPreTraining,
-            BertForQuestionAnswering,
-            BertForSequenceClassification,
-            BertForTokenClassification,
-        )
-        if is_torch_available()
-        else ()
-    )
-
-    class BertModelTester(object):
+class BertModelTester:
     def __init__(
         self,
         parent,
@@ -292,10 +275,7 @@ class BertModelTest(ModelTesterMixin, unittest.TestCase):
         model.to(torch_device)
         model.eval()
         loss, seq_relationship_score = model(
-            input_ids,
-            attention_mask=input_mask,
-            token_type_ids=token_type_ids,
-            next_sentence_label=sequence_labels,
+            input_ids, attention_mask=input_mask, token_type_ids=token_type_ids, next_sentence_label=sequence_labels,
         )
         result = {
             "loss": loss,
@@ -374,16 +354,12 @@ class BertModelTest(ModelTesterMixin, unittest.TestCase):
         model = BertForTokenClassification(config=config)
         model.to(torch_device)
         model.eval()
-        loss, logits = model(
-            input_ids, attention_mask=input_mask, token_type_ids=token_type_ids, labels=token_labels
-        )
+        loss, logits = model(input_ids, attention_mask=input_mask, token_type_ids=token_type_ids, labels=token_labels)
         result = {
             "loss": loss,
             "logits": logits,
         }
-        self.parent.assertListEqual(
-            list(result["logits"].size()), [self.batch_size, self.seq_length, self.num_labels]
-        )
+        self.parent.assertListEqual(list(result["logits"].size()), [self.batch_size, self.seq_length, self.num_labels])
         self.check_loss_output(result)
 
     def create_and_check_bert_for_multiple_choice(
@@ -423,8 +399,26 @@ class BertModelTest(ModelTesterMixin, unittest.TestCase):
         inputs_dict = {"input_ids": input_ids, "token_type_ids": token_type_ids, "attention_mask": input_mask}
         return config, inputs_dict
 
+
+@require_torch
+class BertModelTest(ModelTesterMixin, unittest.TestCase):
+
+    all_model_classes = (
+        (
+            BertModel,
+            BertForMaskedLM,
+            BertForNextSentencePrediction,
+            BertForPreTraining,
+            BertForQuestionAnswering,
+            BertForSequenceClassification,
+            BertForTokenClassification,
+        )
+        if is_torch_available()
+        else ()
+    )
+
     def setUp(self):
-        self.model_tester = BertModelTest.BertModelTester(self)
+        self.model_tester = BertModelTester(self)
         self.config_tester = ConfigTester(self, config_class=BertConfig, hidden_size=37)
 
     def test_config(self):
...
tests/test_modeling_encoder_decoder.py
@@ -21,7 +21,7 @@ from transformers import is_torch_available
 
 # TODO(PVP): this line reruns all the tests in BertModelTest; not sure whether this can be prevented
 # for now only run module with pytest tests/test_modeling_encoder_decoder.py::EncoderDecoderModelTest
-from .test_modeling_bert import BertModelTest
+from .test_modeling_bert import BertModelTester
 from .utils import require_torch, slow, torch_device
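
The import switch most likely also resolves the stale TODO above: unittest and pytest collect any unittest.TestCase subclass bound in a test module's namespace, including names pulled in by from-imports, while a plain helper class is ignored. A sketch with hypothetical module and class names:

# tests/test_a.py (hypothetical)
import unittest


class ThingTester:                         # plain helper: never collected
    def __init__(self, parent):
        self.parent = parent


class ThingModelTest(unittest.TestCase):   # collected when test_a.py runs
    def test_ok(self):
        self.assertTrue(True)


# tests/test_b.py (hypothetical)
# from .test_a import ThingModelTest   # bound here too -> its tests re-run in test_b
# from .test_a import ThingTester      # safe: not a TestCase, nothing re-runs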
@@ -34,7 +34,7 @@ if is_torch_available():
 @require_torch
 class EncoderDecoderModelTest(unittest.TestCase):
     def prepare_config_and_inputs_bert(self):
-        bert_model_tester = BertModelTest.BertModelTester(self)
+        bert_model_tester = BertModelTester(self)
         encoder_config_and_inputs = bert_model_tester.prepare_config_and_inputs()
         decoder_config_and_inputs = bert_model_tester.prepare_config_and_inputs_for_decoder()
         (
...
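
With the tester at module level, EncoderDecoderModelTest can instantiate it directly and reuse its input builders, as the hunk above shows. A minimal usage sketch mirroring that call pattern (it assumes, as in the diff, that the tester exposes prepare_config_and_inputs and prepare_config_and_inputs_for_decoder; that each returns a tuple whose first element is the config is an assumption):

import unittest

from .test_modeling_bert import BertModelTester  # helper only, no TestCase imported


class EncoderDecoderInputsSketch(unittest.TestCase):  # hypothetical reuser
    def test_bert_inputs_are_reusable(self):
        tester = BertModelTester(self)
        encoder_config_and_inputs = tester.prepare_config_and_inputs()
        decoder_config_and_inputs = tester.prepare_config_and_inputs_for_decoder()
        # Assumption: the first element of each tuple is a BertConfig.
        self.assertIsNotNone(encoder_config_and_inputs[0])
        self.assertIsNotNone(decoder_config_and_inputs[0])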