"...git@developer.sourcefind.cn:chenpangpang/transformers.git" did not exist on "ae9a344cce52ff244f721425f660b55ebc522b88"
Unverified Commit 06107541 authored by Nicolas Patry's avatar Nicolas Patry Committed by GitHub
Browse files

Fixing support for `batch_size` and `num_return_sequences` in the `text-generation` pipeline (#15318)

* Fixing support for `batch_size` and `num_return_sequences` in the
`text-generation` pipeline

And `text2text-generation` too.

The bug was caused by the output batch dimension containing both the
incoming batch size **and** the generated `num_return_sequences`.

The fix simply consists of splitting these back into
separate dimensions.

* TF support.

* Odd backward compatibility script in the way.
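
In other words: `generate` returns a flat batch of `in_b * num_return_sequences` rows, and the fix reshapes that flat batch back into separate dimensions. A minimal standalone sketch of the idea (assuming PyTorch; the names `in_b`/`out_b` follow the patch, the shapes are made up):

import torch

in_b = 2                  # incoming batch size (number of prompts)
num_return_sequences = 3  # sequences generated per prompt
seq_len = 5

# `generate` returns a flat batch of in_b * num_return_sequences sequences.
flat = torch.zeros(in_b * num_return_sequences, seq_len, dtype=torch.long)

# The fix: split the flat batch back into (batch, num_sequences, seq_len)
# so postprocessing can group outputs per input prompt.
out_b = flat.shape[0]
grouped = flat.reshape(in_b, out_b // in_b, *flat.shape[1:])
assert grouped.shape == (2, 3, 5)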
parent c4d1fd77
@@ -136,7 +136,11 @@ class Text2TextGenerationPipeline(Pipeline):
         """
         result = super().__call__(*args, **kwargs)
-        if isinstance(args[0], list) and all(isinstance(el, str) for el in args[0]):
+        if (
+            isinstance(args[0], list)
+            and all(isinstance(el, str) for el in args[0])
+            and all(len(res) == 1 for res in result)
+        ):
             return [res[0] for res in result]
         return result
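
The extra `all(len(res) == 1 for res in result)` clause matters: list inputs are now only flattened to one dict per prompt when each prompt produced exactly one sequence. A tiny illustration of the rule (plain Python, hypothetical values):

result = [[{"generated_text": "a"}], [{"generated_text": "b"}]]

# Single sequence per prompt: unwrap one nesting level.
if all(len(res) == 1 for res in result):
    result = [res[0] for res in result]
assert result == [{"generated_text": "a"}, {"generated_text": "b"}]
# With num_return_sequences > 1, len(res) == 2 and the nesting is kept.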
@@ -146,19 +150,24 @@ class Text2TextGenerationPipeline(Pipeline):
     def _forward(self, model_inputs, **generate_kwargs):
         if self.framework == "pt":
-            input_length = model_inputs["input_ids"].shape[-1]
+            in_b, input_length = model_inputs["input_ids"].shape
         elif self.framework == "tf":
-            input_length = tf.shape(model_inputs["input_ids"])[-1].numpy()
+            in_b, input_length = tf.shape(model_inputs["input_ids"]).numpy()

         generate_kwargs["min_length"] = generate_kwargs.get("min_length", self.model.config.min_length)
         generate_kwargs["max_length"] = generate_kwargs.get("max_length", self.model.config.max_length)
         self.check_inputs(input_length, generate_kwargs["min_length"], generate_kwargs["max_length"])
         output_ids = self.model.generate(**model_inputs, **generate_kwargs)
+        out_b = output_ids.shape[0]
+        if self.framework == "pt":
+            output_ids = output_ids.reshape(in_b, out_b // in_b, *output_ids.shape[1:])
+        elif self.framework == "tf":
+            output_ids = tf.reshape(output_ids, (in_b, out_b // in_b, *output_ids.shape[1:]))
         return {"output_ids": output_ids}

     def postprocess(self, model_outputs, return_type=ReturnType.TEXT, clean_up_tokenization_spaces=False):
         records = []
-        for output_ids in model_outputs["output_ids"]:
+        for output_ids in model_outputs["output_ids"][0]:
             if return_type == ReturnType.TENSORS:
                 record = {f"{self.return_name}_token_ids": model_outputs}
             elif return_type == ReturnType.TEXT:
...
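The TF branch mirrors the PyTorch reshape. A sketch of just that step (assuming TensorFlow is installed; the shapes are made up):

import tensorflow as tf

in_b, num_seq, seq_len = 2, 3, 4
output_ids = tf.zeros((in_b * num_seq, seq_len), dtype=tf.int32)

out_b = output_ids.shape[0]
# out_b // in_b recovers num_return_sequences without passing it around.
output_ids = tf.reshape(output_ids, (in_b, out_b // in_b, *output_ids.shape[1:]))
assert output_ids.shape == (2, 3, 4)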
@@ -2,10 +2,14 @@ import enum

 from transformers import MODEL_FOR_CAUSAL_LM_MAPPING, TF_MODEL_FOR_CAUSAL_LM_MAPPING

-from ..file_utils import add_end_docstrings
+from ..file_utils import add_end_docstrings, is_tf_available
 from .base import PIPELINE_INIT_ARGS, Pipeline


+if is_tf_available():
+    import tensorflow as tf
+
+
 class ReturnType(enum.Enum):
     TENSORS = 0
     NEW_TEXT = 1
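
The guarded import keeps TensorFlow optional: the module only imports it when available, so torch-only installs never pay the import cost and the `tf` branches are simply never reached. The pattern in isolation (a sketch; `file_utils` was where `is_tf_available` lived at the time):

from transformers.file_utils import is_tf_available

if is_tf_available():
    import tensorflow as tf  # only imported when TF is actually installed
else:
    tf = None  # torch-only installs never hit the "tf" code paths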
@@ -202,23 +206,29 @@ class TextGenerationPipeline(Pipeline):
         # Allow empty prompts
         if input_ids.shape[1] == 0:
             input_ids = None
+            in_b = 1
+        else:
+            in_b = input_ids.shape[0]
         prompt_text = model_inputs.pop("prompt_text")
         generated_sequence = self.model.generate(input_ids=input_ids, **generate_kwargs)  # BS x SL
+        out_b = generated_sequence.shape[0]
+        if self.framework == "pt":
+            generated_sequence = generated_sequence.reshape(in_b, out_b // in_b, *generated_sequence.shape[1:])
+        elif self.framework == "tf":
+            generated_sequence = tf.reshape(generated_sequence, (in_b, out_b // in_b, *generated_sequence.shape[1:]))
         return {"generated_sequence": generated_sequence, "input_ids": input_ids, "prompt_text": prompt_text}

     def postprocess(self, model_outputs, return_type=ReturnType.FULL_TEXT, clean_up_tokenization_spaces=True):
-        generated_sequence = model_outputs["generated_sequence"]
+        generated_sequence = model_outputs["generated_sequence"][0]
         input_ids = model_outputs["input_ids"]
         prompt_text = model_outputs["prompt_text"]
-        if self.framework == "pt" and generated_sequence is not None:
-            generated_sequence = generated_sequence.cpu()
         generated_sequence = generated_sequence.numpy().tolist()
-        if return_type == ReturnType.TENSORS:
-            record = {"generated_token_ids": generated_sequence}
-        elif return_type in {ReturnType.NEW_TEXT, ReturnType.FULL_TEXT}:
-            # Decode text
-            record = []
-            for sequence in generated_sequence:
+        records = []
+        for sequence in generated_sequence:
+            if return_type == ReturnType.TENSORS:
+                record = {"generated_token_ids": generated_sequence}
+            elif return_type in {ReturnType.NEW_TEXT, ReturnType.FULL_TEXT}:
+                # Decode text
                 text = self.tokenizer.decode(
                     sequence,
                     skip_special_tokens=True,
@@ -242,7 +252,7 @@ class TextGenerationPipeline(Pipeline):
                 else:
                     all_text = text[prompt_length:]
-                item = {"generated_text": all_text}
-                record.append(item)
-        return record
+                record = {"generated_text": all_text}
+            records.append(record)
+        return records
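
After these changes, `postprocess` receives the `[0]`-indexed group for one input and emits one record per generated sequence. The resulting user-facing shape, sketched with made-up text:

# One input prompt with num_return_sequences=2 -> one list of two records:
records = [
    {"generated_text": "first sampled continuation"},
    {"generated_text": "second sampled continuation"},
]
# The pipeline returns one such list per input prompt, which is exactly
# the nested structure asserted by the tests below.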
@@ -40,6 +40,26 @@ class Text2TextGenerationPipelineTests(unittest.TestCase, metaclass=PipelineTestCaseMeta):
         # These are encoder decoder, they don't just append to incoming string
         self.assertFalse(outputs[0]["generated_text"].startswith("Something there"))

+        outputs = generator(["This is great !", "Something else"], num_return_sequences=2, do_sample=True)
+        self.assertEqual(
+            outputs,
+            [
+                [{"generated_text": ANY(str)}, {"generated_text": ANY(str)}],
+                [{"generated_text": ANY(str)}, {"generated_text": ANY(str)}],
+            ],
+        )
+
+        outputs = generator(
+            ["This is great !", "Something else"], num_return_sequences=2, batch_size=2, do_sample=True
+        )
+        self.assertEqual(
+            outputs,
+            [
+                [{"generated_text": ANY(str)}, {"generated_text": ANY(str)}],
+                [{"generated_text": ANY(str)}, {"generated_text": ANY(str)}],
+            ],
+        )
+
         with self.assertRaises(ValueError):
             generator(4)
...
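From user code the fixed behaviour looks like this (a sketch: the checkpoint name is only an example of a small seq2seq model, downloads on first use, and sampled text varies run to run):

from transformers import pipeline

generator = pipeline("text2text-generation", model="patrickvonplaten/t5-tiny-random")
outputs = generator(["This is great !", "Something else"], num_return_sequences=2, do_sample=True)

assert len(outputs) == 2     # one list per input prompt
assert len(outputs[0]) == 2  # num_return_sequences dicts per prompt
print(outputs[0][0]["generated_text"])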
@@ -113,6 +113,27 @@ class TextGenerationPipelineTests(unittest.TestCase, metaclass=PipelineTestCaseMeta):
         self.assertEqual(outputs, [{"generated_text": ANY(str)}])
         self.assertTrue(outputs[0]["generated_text"].startswith("This is a test"))

+        outputs = text_generator(["This is great !", "Something else"], num_return_sequences=2, do_sample=True)
+        self.assertEqual(
+            outputs,
+            [
+                [{"generated_text": ANY(str)}, {"generated_text": ANY(str)}],
+                [{"generated_text": ANY(str)}, {"generated_text": ANY(str)}],
+            ],
+        )
+
+        if text_generator.tokenizer.pad_token is not None:
+            outputs = text_generator(
+                ["This is great !", "Something else"], num_return_sequences=2, batch_size=2, do_sample=True
+            )
+            self.assertEqual(
+                outputs,
+                [
+                    [{"generated_text": ANY(str)}, {"generated_text": ANY(str)}],
+                    [{"generated_text": ANY(str)}, {"generated_text": ANY(str)}],
+                ],
+            )
+
         # Empty prompt is slightly special
         # it requires BOS token to exist.
         # Special case for Pegasus which will always append EOS so will
...
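The `pad_token` guard exists because `batch_size=2` batches prompts of unequal length, which requires padding, and GPT-2-style tokenizers ship without a pad token. A quick check of that premise (downloads the gpt2 tokenizer on first run):

from transformers import AutoTokenizer

tok = AutoTokenizer.from_pretrained("gpt2")
assert tok.pad_token is None  # batching would need e.g. tok.pad_token = tok.eos_token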