Unverified Commit 14b04b4b authored by Matt's avatar Matt Committed by GitHub
Browse files

Conversation pipeline fixes (#26795)

* Adjust length limits and allow naked conversation list inputs

* Adjust length limits and allow naked conversation list inputs

* Maybe use a slightly more reasonable limit than 1024

* Skip tests for old models that never supported this anyway

* Cleanup input docstrings

* More docstring cleanup + skip failing TF test

* Make fixup
parent 5c6b83cb
...@@ -247,13 +247,15 @@ class ConversationalPipeline(Pipeline): ...@@ -247,13 +247,15 @@ class ConversationalPipeline(Pipeline):
forward_params.update(generate_kwargs) forward_params.update(generate_kwargs)
return preprocess_params, forward_params, postprocess_params return preprocess_params, forward_params, postprocess_params
def __call__(self, conversations: Union[Conversation, List[Conversation]], num_workers=0, **kwargs): def __call__(self, conversations: Union[List[Dict], Conversation, List[Conversation]], num_workers=0, **kwargs):
r""" r"""
Generate responses for the conversation(s) given as inputs. Generate responses for the conversation(s) given as inputs.
Args: Args:
conversations (a [`Conversation`] or a list of [`Conversation`]): conversations (a [`Conversation`] or a list of [`Conversation`]):
Conversations to generate responses for. Conversation to generate responses for. Inputs can also be passed as a list of dictionaries with `role`
and `content` keys - in this case, they will be converted to `Conversation` objects automatically.
Multiple conversations in either format may be passed as a list.
clean_up_tokenization_spaces (`bool`, *optional*, defaults to `False`): clean_up_tokenization_spaces (`bool`, *optional*, defaults to `False`):
Whether or not to clean up the potential extra spaces in the text output. Whether or not to clean up the potential extra spaces in the text output.
generate_kwargs: generate_kwargs:
...@@ -268,6 +270,10 @@ class ConversationalPipeline(Pipeline): ...@@ -268,6 +270,10 @@ class ConversationalPipeline(Pipeline):
# Otherwise the threads will require a Conversation copy. # Otherwise the threads will require a Conversation copy.
# This will definitely hinder performance on GPU, but has to be opted # This will definitely hinder performance on GPU, but has to be opted
# in because of this BC change. # in because of this BC change.
if isinstance(conversations, list) and isinstance(conversations[0], dict):
conversations = Conversation(conversations)
elif isinstance(conversations, list) and isinstance(conversations[0], list):
conversations = [Conversation(conv) for conv in conversations]
outputs = super().__call__(conversations, num_workers=num_workers, **kwargs) outputs = super().__call__(conversations, num_workers=num_workers, **kwargs)
if isinstance(outputs, list) and len(outputs) == 1: if isinstance(outputs, list) and len(outputs) == 1:
return outputs[0] return outputs[0]
...@@ -283,19 +289,10 @@ class ConversationalPipeline(Pipeline): ...@@ -283,19 +289,10 @@ class ConversationalPipeline(Pipeline):
return {"input_ids": input_ids, "conversation": conversation} return {"input_ids": input_ids, "conversation": conversation}
def _forward(self, model_inputs, minimum_tokens=10, **generate_kwargs): def _forward(self, model_inputs, minimum_tokens=10, **generate_kwargs):
max_length = generate_kwargs.get("max_length", self.model.config.max_length)
n = model_inputs["input_ids"].shape[1] n = model_inputs["input_ids"].shape[1]
if max_length - minimum_tokens < n:
logger.warning(
f"Conversation input is too long ({n}), trimming it to {max_length - minimum_tokens} tokens. Consider increasing `max_length` to avoid truncation."
)
trim = max_length - minimum_tokens
model_inputs["input_ids"] = model_inputs["input_ids"][:, -trim:]
if "attention_mask" in model_inputs:
model_inputs["attention_mask"] = model_inputs["attention_mask"][:, -trim:]
conversation = model_inputs.pop("conversation") conversation = model_inputs.pop("conversation")
generate_kwargs["max_length"] = max_length if "max_length" not in generate_kwargs and "max_new_tokens" not in generate_kwargs:
generate_kwargs["max_new_tokens"] = 256
output_ids = self.model.generate(**model_inputs, **generate_kwargs) output_ids = self.model.generate(**model_inputs, **generate_kwargs)
if self.model.config.is_encoder_decoder: if self.model.config.is_encoder_decoder:
start_position = 1 start_position = 1
......
...@@ -507,6 +507,10 @@ class BartModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixin ...@@ -507,6 +507,10 @@ class BartModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixin
model.generate(input_ids, attention_mask=attention_mask) model.generate(input_ids, attention_mask=attention_mask)
model.generate(num_beams=4, do_sample=True, early_stopping=False, num_return_sequences=3) model.generate(num_beams=4, do_sample=True, early_stopping=False, num_return_sequences=3)
@unittest.skip("Does not support conversations.")
def test_pipeline_conversational(self):
pass
def assert_tensors_close(a, b, atol=1e-12, prefix=""): def assert_tensors_close(a, b, atol=1e-12, prefix=""):
"""If tensors have different shapes, different values or a and b are not both tensors, raise a nice Assertion error.""" """If tensors have different shapes, different values or a and b are not both tensors, raise a nice Assertion error."""
......
...@@ -343,6 +343,10 @@ class TFBartModelTest(TFModelTesterMixin, TFCoreModelTesterMixin, PipelineTester ...@@ -343,6 +343,10 @@ class TFBartModelTest(TFModelTesterMixin, TFCoreModelTesterMixin, PipelineTester
# check that the output for the restored model is the same # check that the output for the restored model is the same
self.assert_outputs_same(restored_model_outputs, outputs) self.assert_outputs_same(restored_model_outputs, outputs)
@unittest.skip("Does not support conversations.")
def test_pipeline_conversational(self):
pass
def _long_tensor(tok_lst): def _long_tensor(tok_lst):
return tf.constant(tok_lst, dtype=tf.int32) return tf.constant(tok_lst, dtype=tf.int32)
......
...@@ -891,6 +891,10 @@ class T5ModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixin, ...@@ -891,6 +891,10 @@ class T5ModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixin,
def test_disk_offload(self): def test_disk_offload(self):
pass pass
@unittest.skip("Does not support conversations.")
def test_pipeline_conversational(self):
pass
class T5EncoderOnlyModelTester: class T5EncoderOnlyModelTester:
def __init__( def __init__(
......
...@@ -314,6 +314,10 @@ class TFT5ModelTest(TFModelTesterMixin, PipelineTesterMixin, unittest.TestCase): ...@@ -314,6 +314,10 @@ class TFT5ModelTest(TFModelTesterMixin, PipelineTesterMixin, unittest.TestCase):
def test_keras_save_load(self): def test_keras_save_load(self):
pass pass
@unittest.skip("Does not support conversations.")
def test_pipeline_conversational(self):
pass
class TFT5EncoderOnlyModelTester: class TFT5EncoderOnlyModelTester:
def __init__( def __init__(
...@@ -607,6 +611,10 @@ class TFT5GenerationIntegrationTests(unittest.TestCase): ...@@ -607,6 +611,10 @@ class TFT5GenerationIntegrationTests(unittest.TestCase):
expected_output_string = ["Ich liebe es so sehr!", "die Transformatoren sind wirklich erstaunlich"] expected_output_string = ["Ich liebe es so sehr!", "die Transformatoren sind wirklich erstaunlich"]
self.assertListEqual(expected_output_string, output_strings) self.assertListEqual(expected_output_string, output_strings)
@unittest.skip("Does not support conversations.")
def test_pipeline_conversational(self):
pass
@require_tf @require_tf
@require_sentencepiece @require_sentencepiece
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment