"vscode:/vscode.git/clone" did not exist on "211f93aab95d1c683494e61c3cf8ff10e1f5d6b7"
Unverified Commit a4562552 authored by Nicolas Patry, committed by GitHub

[DX fix] Fixing QA pipeline streaming a dataset. (#18516)

* [DX fix] Fixing QA pipeline streaming a dataset.

QuestionAnsweringArgumentHandler would iterate over the whole dataset up front,
effectively killing the lazy/streaming properties of the pipeline.
This restores those properties when using a `Dataset` or a generator, since
those are meant to be consumed lazily.

* Handling TF better.
parent 88a0ce57
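As context for the fix, here is a minimal usage sketch of the streaming behavior this commit restores. It mirrors the new test added below; the tiny checkpoint and batch size come from that test and are not requirements.

from transformers import pipeline

# Build the QA pipeline; the checkpoint and batch_size mirror the test below.
qa = pipeline(
    "question-answering",
    model="sshleifer/tiny-distilbert-base-cased-distilled-squad",
    batch_size=16,
)

def data():
    # A generator of QA examples, meant to be consumed lazily.
    for _ in range(10):
        yield {"question": "Where was HuggingFace founded ?", "context": "HuggingFace was founded in Paris."}

# After this fix the argument handler passes the generator through untouched,
# so examples are pulled one batch at a time instead of being materialized up front.
for output in qa(data()):
    print(output["answer"], output["score"])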
import types
import warnings
from collections.abc import Iterable
from typing import TYPE_CHECKING, Dict, List, Optional, Tuple, Union
@@ -22,8 +23,11 @@ if is_tf_available():
from ..models.auto.modeling_tf_auto import TF_MODEL_FOR_QUESTION_ANSWERING_MAPPING
Dataset = None
if is_torch_available():
import torch
from torch.utils.data import Dataset
from ..models.auto.modeling_auto import MODEL_FOR_QUESTION_ANSWERING_MAPPING
@@ -82,6 +86,11 @@ class QuestionAnsweringArgumentHandler(ArgumentHandler):
else:
raise ValueError(f"Unknown arguments {kwargs}")
# When user is sending a generator we need to trust it's a valid example
generator_types = (types.GeneratorType, Dataset) if Dataset is not None else (types.GeneratorType,)
if isinstance(inputs, generator_types):
return inputs
# Normalize inputs
if isinstance(inputs, dict):
inputs = [inputs]
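To make the laziness argument concrete (this is an illustration, not code from the repository): eagerly iterating a generator, as the old argument handler did while validating inputs, both materializes every example and exhausts the generator for any downstream consumer, which is why the new check above returns generator/Dataset inputs untouched.

import types

def examples():
    for i in range(3):
        yield {"question": "Q ?", "context": f"context {i}"}

gen = examples()
consumed = list(gen)    # eager validation amounts to something like this
print(len(consumed))    # 3 -- every example materialized at once
print(next(gen, None))  # None -- the generator is now exhausted

# Returning the object as-is (as the isinstance check above does) keeps consumption lazy.
gen2 = examples()
assert isinstance(gen2, types.GeneratorType)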
@@ -245,12 +254,18 @@ class QuestionAnsweringPipeline(ChunkPipeline):
"""
# Convert inputs to features
examples = self._args_parser(*args, **kwargs)
if len(examples) == 1:
if isinstance(examples, (list, tuple)) and len(examples) == 1:
return super().__call__(examples[0], **kwargs)
return super().__call__(examples, **kwargs)
def preprocess(self, example, padding="do_not_pad", doc_stride=None, max_question_len=64, max_seq_len=None):
# XXX: This is special: args_parser will not handle anything generator- or dataset-like.
# For those we expect the user to send a simple valid example, either directly as a SquadExample or as a simple dict.
# So we still need a little sanitation here.
if isinstance(example, dict):
example = SquadExample(None, example["question"], example["context"], None, None, None)
if max_seq_len is None:
max_seq_len = min(self.tokenizer.model_max_length, 384)
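Because generator/Dataset items now bypass the argument handler, `preprocess` does this minimal sanitation itself: a plain dict with "question"/"context" keys is wrapped into a SquadExample. A small standalone sketch of that conversion, using the same positional constructor as the diff above:

from transformers.data.processors.squad import SquadExample

item = {"question": "Where was HuggingFace founded ?", "context": "HuggingFace was founded in Paris."}

# Positional arguments: qas_id, question_text, context_text, answer_text,
# start_position_character, title -- only question and context matter at inference time.
example = SquadExample(None, item["question"], item["context"], None, None, None)
print(example.question_text, "|", example.context_text)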
@@ -125,6 +125,18 @@ class QAPipelineTests(unittest.TestCase, metaclass=PipelineTestCaseMeta):
self.assertEqual(nested_simplify(outputs), {"score": 0.01, "start": 0, "end": 11, "answer": "HuggingFace"})
@require_torch
def test_small_model_pt_iterator(self):
# https://github.com/huggingface/transformers/issues/18510
pipe = pipeline(model="sshleifer/tiny-distilbert-base-cased-distilled-squad", batch_size=16, framework="pt")
def data():
for i in range(10):
yield {"question": "Where was HuggingFace founded ?", "context": "HuggingFace was founded in Paris."}
for outputs in pipe(data()):
self.assertEqual(nested_simplify(outputs), {"score": 0.01, "start": 0, "end": 11, "answer": "HuggingFace"})
@require_torch
def test_small_model_pt_softmax_trick(self):
question_answerer = pipeline(