"...audio/git@developer.sourcefind.cn:OpenDAS/lightx2v.git" did not exist on "a1ebc651ab830a381e8960029145b557990342d6"
Unverified Commit 29792864 authored by Patrick von Platen's avatar Patrick von Platen Committed by GitHub
Browse files

[ProphetNet] Add Question Generation Model + Test (#7942)

* new prophetnet model

* correct name

* make style
parent 13842e41
...@@ -1073,3 +1073,33 @@ class ProphetNetModelIntegrationTest(unittest.TestCase): ...@@ -1073,3 +1073,33 @@ class ProphetNetModelIntegrationTest(unittest.TestCase):
[EXPECTED_SUMMARIZE_100], [EXPECTED_SUMMARIZE_100],
generated_titles, generated_titles,
) )
@slow
def test_question_gen_inference(self):
model = ProphetNetForConditionalGeneration.from_pretrained("microsoft/prophetnet-large-uncased-squad-qg")
model.to(torch_device)
tokenizer = ProphetNetTokenizer.from_pretrained("microsoft/prophetnet-large-uncased-squad-qg")
INPUTS = [
"Bill Gates [SEP] Microsoft was founded by Bill Gates and Paul Allen on April 4, 1975.",
"1975 [SEP] Microsoft was founded by Bill Gates and Paul Allen on April 4, 1975.",
"April 4, 1975 [SEP] Microsoft was founded by Bill Gates and Paul Allen on April 4, 1975.",
]
input_ids = tokenizer(INPUTS, truncation=True, padding=True, return_tensors="pt").input_ids
input_ids = input_ids.to(torch_device)
gen_output = model.generate(input_ids, num_beams=5, early_stopping=True)
generated_questions = tokenizer.batch_decode(gen_output, skip_special_tokens=True)
EXPECTED_QUESTIONS = [
"along with paul allen, who founded microsoft?",
"what year was microsoft founded?",
"on what date was microsoft founded?",
]
self.assertListEqual(
EXPECTED_QUESTIONS,
generated_questions,
)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment