"...git@developer.sourcefind.cn:chenpangpang/open-webui.git" did not exist on "7c0c11b776caf7a30e5c8aabe3116e41321fc849"
Unverified Commit bdbcd5d4 authored by Matt's avatar Matt Committed by GitHub
Browse files

Fix and re-enable ConversationalPipeline tests (#26907)

* Fix and re-enable conversationalpipeline tests

* Fix the batch test so the change only applies to conversational pipeline
parent 734dd96e
...@@ -77,14 +77,14 @@ class ConversationalPipelineTests(unittest.TestCase): ...@@ -77,14 +77,14 @@ class ConversationalPipelineTests(unittest.TestCase):
def run_pipeline_test(self, conversation_agent, _): def run_pipeline_test(self, conversation_agent, _):
# Simple # Simple
outputs = conversation_agent(Conversation("Hi there!")) outputs = conversation_agent(Conversation("Hi there!"), max_new_tokens=20)
self.assertEqual( self.assertEqual(
outputs, outputs,
Conversation([{"role": "user", "content": "Hi there!"}, {"role": "assistant", "content": ANY(str)}]), Conversation([{"role": "user", "content": "Hi there!"}, {"role": "assistant", "content": ANY(str)}]),
) )
# Single list # Single list
outputs = conversation_agent([Conversation("Hi there!")]) outputs = conversation_agent([Conversation("Hi there!")], max_new_tokens=20)
self.assertEqual( self.assertEqual(
outputs, outputs,
Conversation([{"role": "user", "content": "Hi there!"}, {"role": "assistant", "content": ANY(str)}]), Conversation([{"role": "user", "content": "Hi there!"}, {"role": "assistant", "content": ANY(str)}]),
...@@ -96,7 +96,7 @@ class ConversationalPipelineTests(unittest.TestCase): ...@@ -96,7 +96,7 @@ class ConversationalPipelineTests(unittest.TestCase):
self.assertEqual(len(conversation_1), 1) self.assertEqual(len(conversation_1), 1)
self.assertEqual(len(conversation_2), 1) self.assertEqual(len(conversation_2), 1)
outputs = conversation_agent([conversation_1, conversation_2]) outputs = conversation_agent([conversation_1, conversation_2], max_new_tokens=20)
self.assertEqual(outputs, [conversation_1, conversation_2]) self.assertEqual(outputs, [conversation_1, conversation_2])
self.assertEqual( self.assertEqual(
outputs, outputs,
...@@ -118,7 +118,7 @@ class ConversationalPipelineTests(unittest.TestCase): ...@@ -118,7 +118,7 @@ class ConversationalPipelineTests(unittest.TestCase):
# One conversation with history # One conversation with history
conversation_2.add_message({"role": "user", "content": "Why do you recommend it?"}) conversation_2.add_message({"role": "user", "content": "Why do you recommend it?"})
outputs = conversation_agent(conversation_2) outputs = conversation_agent(conversation_2, max_new_tokens=20)
self.assertEqual(outputs, conversation_2) self.assertEqual(outputs, conversation_2)
self.assertEqual( self.assertEqual(
outputs, outputs,
......
...@@ -312,8 +312,12 @@ class PipelineTesterMixin: ...@@ -312,8 +312,12 @@ class PipelineTesterMixin:
yield copy.deepcopy(random.choice(examples)) yield copy.deepcopy(random.choice(examples))
out = [] out = []
for item in pipeline(data(10), batch_size=4): if task == "conversational":
out.append(item) for item in pipeline(data(10), batch_size=4, max_new_tokens=20):
out.append(item)
else:
for item in pipeline(data(10), batch_size=4):
out.append(item)
self.assertEqual(len(out), 10) self.assertEqual(len(out), 10)
run_batch_test(pipeline, examples) run_batch_test(pipeline, examples)
...@@ -327,7 +331,6 @@ class PipelineTesterMixin: ...@@ -327,7 +331,6 @@ class PipelineTesterMixin:
self.run_task_tests(task="automatic-speech-recognition") self.run_task_tests(task="automatic-speech-recognition")
@is_pipeline_test @is_pipeline_test
@unittest.skip("Conversational tests are currently broken for several models, will fix ASAP - Matt")
def test_pipeline_conversational(self): def test_pipeline_conversational(self):
self.run_task_tests(task="conversational") self.run_task_tests(task="conversational")
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment