"...git@developer.sourcefind.cn:chenpangpang/transformers.git" did not exist on "5451f8896c23f006648aa8da852fec499dfe6000"
Unverified commit cbc1abc4, authored by Ankur Goyal, committed by GitHub
Browse files

A few CI fixes for `DocumentQuestionAnsweringPipeline` (#19584)



* Fixes

* update expected values

* style

* fix
Co-authored-by: ydshieh <ydshieh@users.noreply.github.com>
parent 0b7b07ef
...@@ -235,7 +235,6 @@ class DocumentQuestionAnsweringPipeline(ChunkPipeline): ...@@ -235,7 +235,6 @@ class DocumentQuestionAnsweringPipeline(ChunkPipeline):
`word_boxes`). `word_boxes`).
- **answer** (`str`) -- The answer to the question. - **answer** (`str`) -- The answer to the question.
- **words** (`list[int]`) -- The index of each word/box pair that is in the answer - **words** (`list[int]`) -- The index of each word/box pair that is in the answer
- **page** (`int`) -- The page of the answer
""" """
if isinstance(question, str): if isinstance(question, str):
inputs = {"question": question, "image": image} inputs = {"question": question, "image": image}
...@@ -315,7 +314,6 @@ class DocumentQuestionAnsweringPipeline(ChunkPipeline): ...@@ -315,7 +314,6 @@ class DocumentQuestionAnsweringPipeline(ChunkPipeline):
"p_mask": None, "p_mask": None,
"word_ids": None, "word_ids": None,
"words": None, "words": None,
"page": None,
"output_attentions": True, "output_attentions": True,
"is_last": True, "is_last": True,
} }
...@@ -339,6 +337,7 @@ class DocumentQuestionAnsweringPipeline(ChunkPipeline): ...@@ -339,6 +337,7 @@ class DocumentQuestionAnsweringPipeline(ChunkPipeline):
return_overflowing_tokens=True, return_overflowing_tokens=True,
**tokenizer_kwargs, **tokenizer_kwargs,
) )
encoding.pop("overflow_to_sample_mapping") # We do not use this
num_spans = len(encoding["input_ids"]) num_spans = len(encoding["input_ids"])
...@@ -395,9 +394,6 @@ class DocumentQuestionAnsweringPipeline(ChunkPipeline): ...@@ -395,9 +394,6 @@ class DocumentQuestionAnsweringPipeline(ChunkPipeline):
words = model_inputs.pop("words", None) words = model_inputs.pop("words", None)
is_last = model_inputs.pop("is_last", False) is_last = model_inputs.pop("is_last", False)
if "overflow_to_sample_mapping" in model_inputs:
model_inputs.pop("overflow_to_sample_mapping")
if self.model_type == ModelType.VisionEncoderDecoder: if self.model_type == ModelType.VisionEncoderDecoder:
model_outputs = self.model.generate(**model_inputs) model_outputs = self.model.generate(**model_inputs)
else: else:
...@@ -421,7 +417,7 @@ class DocumentQuestionAnsweringPipeline(ChunkPipeline): ...@@ -421,7 +417,7 @@ class DocumentQuestionAnsweringPipeline(ChunkPipeline):
return answers return answers
def postprocess_encoder_decoder_single(self, model_outputs, **kwargs): def postprocess_encoder_decoder_single(self, model_outputs, **kwargs):
sequence = self.tokenizer.batch_decode(model_outputs.sequences)[0] sequence = self.tokenizer.batch_decode(model_outputs["sequences"])[0]
# TODO: A lot of this logic is specific to Donut and should probably be handled in the tokenizer # TODO: A lot of this logic is specific to Donut and should probably be handled in the tokenizer
# (see https://github.com/huggingface/transformers/pull/18414/files#r961747408 for more context). # (see https://github.com/huggingface/transformers/pull/18414/files#r961747408 for more context).
......
...@@ -209,8 +209,8 @@ class DocumentQuestionAnsweringPipelineTests(unittest.TestCase, metaclass=Pipeli ...@@ -209,8 +209,8 @@ class DocumentQuestionAnsweringPipelineTests(unittest.TestCase, metaclass=Pipeli
self.assertEqual( self.assertEqual(
nested_simplify(outputs, decimals=4), nested_simplify(outputs, decimals=4),
[ [
{"score": 0.9967, "answer": "1102/2019", "start": 22, "end": 22}, {"score": 0.9974, "answer": "1110212019", "start": 23, "end": 23},
{"score": 0.996, "answer": "us-001", "start": 15, "end": 15}, {"score": 0.9948, "answer": "us-001", "start": 16, "end": 16},
], ],
) )
...@@ -218,8 +218,8 @@ class DocumentQuestionAnsweringPipelineTests(unittest.TestCase, metaclass=Pipeli ...@@ -218,8 +218,8 @@ class DocumentQuestionAnsweringPipelineTests(unittest.TestCase, metaclass=Pipeli
self.assertEqual( self.assertEqual(
nested_simplify(outputs, decimals=4), nested_simplify(outputs, decimals=4),
[ [
{"score": 0.9967, "answer": "1102/2019", "start": 22, "end": 22}, {"score": 0.9974, "answer": "1110212019", "start": 23, "end": 23},
{"score": 0.996, "answer": "us-001", "start": 15, "end": 15}, {"score": 0.9948, "answer": "us-001", "start": 16, "end": 16},
], ],
) )
...@@ -230,8 +230,8 @@ class DocumentQuestionAnsweringPipelineTests(unittest.TestCase, metaclass=Pipeli ...@@ -230,8 +230,8 @@ class DocumentQuestionAnsweringPipelineTests(unittest.TestCase, metaclass=Pipeli
nested_simplify(outputs, decimals=4), nested_simplify(outputs, decimals=4),
[ [
[ [
{"score": 0.9967, "answer": "1102/2019", "start": 22, "end": 22}, {"score": 0.9974, "answer": "1110212019", "start": 23, "end": 23},
{"score": 0.996, "answer": "us-001", "start": 15, "end": 15}, {"score": 0.9948, "answer": "us-001", "start": 16, "end": 16},
] ]
] ]
* 2, * 2,
...@@ -320,8 +320,8 @@ class DocumentQuestionAnsweringPipelineTests(unittest.TestCase, metaclass=Pipeli ...@@ -320,8 +320,8 @@ class DocumentQuestionAnsweringPipelineTests(unittest.TestCase, metaclass=Pipeli
self.assertEqual( self.assertEqual(
nested_simplify(outputs, decimals=4), nested_simplify(outputs, decimals=4),
[ [
{"score": 0.9999, "answer": "us-001", "start": 15, "end": 15}, {"score": 0.9999, "answer": "us-001", "start": 16, "end": 16},
{"score": 0.9924, "answer": "us-001", "start": 15, "end": 15}, {"score": 0.9998, "answer": "us-001", "start": 16, "end": 16},
], ],
) )
...@@ -332,8 +332,8 @@ class DocumentQuestionAnsweringPipelineTests(unittest.TestCase, metaclass=Pipeli ...@@ -332,8 +332,8 @@ class DocumentQuestionAnsweringPipelineTests(unittest.TestCase, metaclass=Pipeli
nested_simplify(outputs, decimals=4), nested_simplify(outputs, decimals=4),
[ [
[ [
{"score": 0.9999, "answer": "us-001", "start": 15, "end": 15}, {"score": 0.9999, "answer": "us-001", "start": 16, "end": 16},
{"score": 0.9924, "answer": "us-001", "start": 15, "end": 15}, {"score": 0.9998, "answer": "us-001", "start": 16, "end": 16},
] ]
] ]
* 2, * 2,
...@@ -346,8 +346,8 @@ class DocumentQuestionAnsweringPipelineTests(unittest.TestCase, metaclass=Pipeli ...@@ -346,8 +346,8 @@ class DocumentQuestionAnsweringPipelineTests(unittest.TestCase, metaclass=Pipeli
self.assertEqual( self.assertEqual(
nested_simplify(outputs, decimals=4), nested_simplify(outputs, decimals=4),
[ [
{"score": 0.9999, "answer": "us-001", "start": 15, "end": 15}, {"score": 0.9999, "answer": "us-001", "start": 16, "end": 16},
{"score": 0.9924, "answer": "us-001", "start": 15, "end": 15}, {"score": 0.9998, "answer": "us-001", "start": 16, "end": 16},
], ],
) )
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment