Unverified Commit ebba39e4 authored by Patrick von Platen, committed by GitHub

[Bart] Question Answering Model is added to tests (#5024)

* fix test

* Update tests/test_modeling_common.py

* Update tests/test_modeling_common.py
parent bbad4c69
@@ -113,7 +113,9 @@ def prepare_bart_inputs_dict(
 @require_torch
 class BARTModelTest(ModelTesterMixin, unittest.TestCase):
     all_model_classes = (
-        (BartModel, BartForConditionalGeneration, BartForSequenceClassification) if is_torch_available() else ()
+        (BartModel, BartForConditionalGeneration, BartForSequenceClassification, BartForQuestionAnswering)
+        if is_torch_available()
+        else ()
     )
     all_generative_model_classes = (BartForConditionalGeneration,) if is_torch_available() else ()
     is_encoder_decoder = True
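For reference (not part of this commit), a minimal sketch of running the newly tested head directly; facebook/bart-base is only an illustrative checkpoint here, and its question-answering head would be freshly initialized rather than fine-tuned:

import torch
from transformers import BartTokenizer, BartForQuestionAnswering

tokenizer = BartTokenizer.from_pretrained("facebook/bart-base")
model = BartForQuestionAnswering.from_pretrained("facebook/bart-base")  # QA head weights are newly initialized

question = "Who proposed BART?"
context = "BART was proposed by Facebook AI."
inputs = tokenizer(question, context, return_tensors="pt")

with torch.no_grad():
    outputs = model(**inputs)

# The QA head yields two tensors of shape (batch_size, sequence_length)
# instead of the single logits tensor most other heads return.
start_logits, end_logits = outputs[0], outputs[1]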
@@ -38,6 +38,7 @@ if is_torch_available():
         BertConfig,
         BERT_PRETRAINED_MODEL_ARCHIVE_LIST,
         MODEL_FOR_MULTIPLE_CHOICE_MAPPING,
+        MODEL_FOR_QUESTION_ANSWERING_MAPPING,
         top_k_top_p_filtering,
     )
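Side note (an observation about the mapping, not part of the diff): MODEL_FOR_QUESTION_ANSWERING_MAPPING maps config classes to model classes, which is why the membership check in the next hunk uses .values():

from transformers import MODEL_FOR_QUESTION_ANSWERING_MAPPING, BartForQuestionAnswering

# keys are config classes (e.g. BartConfig); values are the corresponding model classes
assert BartForQuestionAnswering in MODEL_FOR_QUESTION_ANSWERING_MAPPING.values()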
@@ -180,8 +181,13 @@ class ModelTesterMixin:
                 correct_outlen = 4
                 decoder_attention_idx = 1
-                if "lm_labels" in inputs_dict:  # loss will come first
-                    correct_outlen += 1  # compute loss
+                # loss is at first position
+                if "labels" in inputs_dict:
+                    correct_outlen += 1  # loss is added to beginning
                     decoder_attention_idx += 1
+                # Question Answering model returns start_logits and end_logits
+                if model_class in MODEL_FOR_QUESTION_ANSWERING_MAPPING.values():
+                    correct_outlen += 1  # start_logits and end_logits instead of only 1 output
+                    decoder_attention_idx += 1
                 self.assertEqual(out_len, correct_outlen)
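The output-length bookkeeping above can be restated as a standalone sketch; the helper below is illustrative only and does not exist in the test suite:

def expected_outlen(has_labels: bool, is_qa_model: bool) -> int:
    correct_outlen = 4  # base outputs for an encoder-decoder model with attentions returned
    if has_labels:
        correct_outlen += 1  # a returned loss is prepended, shifting the attention index
    if is_qa_model:
        correct_outlen += 1  # start_logits and end_logits instead of a single output
    return correct_outlen

assert expected_outlen(has_labels=False, is_qa_model=False) == 4
assert expected_outlen(has_labels=True, is_qa_model=True) == 6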