Unverified Commit ebba39e4 authored by Patrick von Platen's avatar Patrick von Platen Committed by GitHub
Browse files

[Bart] Question Answering Model is added to tests (#5024)

* fix test

* Update tests/test_modeling_common.py

* Update tests/test_modeling_common.py
parent bbad4c69
...@@ -113,7 +113,9 @@ def prepare_bart_inputs_dict( ...@@ -113,7 +113,9 @@ def prepare_bart_inputs_dict(
@require_torch @require_torch
class BARTModelTest(ModelTesterMixin, unittest.TestCase): class BARTModelTest(ModelTesterMixin, unittest.TestCase):
all_model_classes = ( all_model_classes = (
(BartModel, BartForConditionalGeneration, BartForSequenceClassification) if is_torch_available() else () (BartModel, BartForConditionalGeneration, BartForSequenceClassification, BartForQuestionAnswering)
if is_torch_available()
else ()
) )
all_generative_model_classes = (BartForConditionalGeneration,) if is_torch_available() else () all_generative_model_classes = (BartForConditionalGeneration,) if is_torch_available() else ()
is_encoder_decoder = True is_encoder_decoder = True
......
...@@ -38,6 +38,7 @@ if is_torch_available(): ...@@ -38,6 +38,7 @@ if is_torch_available():
BertConfig, BertConfig,
BERT_PRETRAINED_MODEL_ARCHIVE_LIST, BERT_PRETRAINED_MODEL_ARCHIVE_LIST,
MODEL_FOR_MULTIPLE_CHOICE_MAPPING, MODEL_FOR_MULTIPLE_CHOICE_MAPPING,
MODEL_FOR_QUESTION_ANSWERING_MAPPING,
top_k_top_p_filtering, top_k_top_p_filtering,
) )
...@@ -180,8 +181,13 @@ class ModelTesterMixin: ...@@ -180,8 +181,13 @@ class ModelTesterMixin:
correct_outlen = 4 correct_outlen = 4
decoder_attention_idx = 1 decoder_attention_idx = 1
if "lm_labels" in inputs_dict: # loss will come first # loss is at first position
correct_outlen += 1 # compute loss if "labels" in inputs_dict:
correct_outlen += 1 # loss is added to beginning
decoder_attention_idx += 1
# Question Answering model returns start_logits and end_logits
if model_class in MODEL_FOR_QUESTION_ANSWERING_MAPPING.values():
correct_outlen += 1 # start_logits and end_logits instead of only 1 output
decoder_attention_idx += 1 decoder_attention_idx += 1
self.assertEqual(out_len, correct_outlen) self.assertEqual(out_len, correct_outlen)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment