Unverified Commit 18e9e1f7 authored by Nick Hill's avatar Nick Hill Committed by GitHub
Browse files

[HotFix] Fix final output truncation with stop string + streaming (#8468)

parent f57092c0
...@@ -159,7 +159,8 @@ def should_do_global_cleanup_after_test(request) -> bool: ...@@ -159,7 +159,8 @@ def should_do_global_cleanup_after_test(request) -> bool:
@pytest.mark.asyncio(scope="module") @pytest.mark.asyncio(scope="module")
async def test_asyncio_run(async_engine): @pytest.mark.parametrize("stop", [None, ["a stop string"]])
async def test_asyncio_run(async_engine, stop):
scheduler_config = await async_engine.get_scheduler_config() scheduler_config = await async_engine.get_scheduler_config()
num_scheduler_steps = scheduler_config.num_scheduler_steps num_scheduler_steps = scheduler_config.num_scheduler_steps
...@@ -169,6 +170,7 @@ async def test_asyncio_run(async_engine): ...@@ -169,6 +170,7 @@ async def test_asyncio_run(async_engine):
temperature=0, temperature=0,
max_tokens=32, max_tokens=32,
min_tokens=32, min_tokens=32,
stop=stop,
) )
output_count = 0 output_count = 0
...@@ -203,7 +205,8 @@ async def test_asyncio_run(async_engine): ...@@ -203,7 +205,8 @@ async def test_asyncio_run(async_engine):
@pytest.mark.asyncio(scope="module") @pytest.mark.asyncio(scope="module")
async def test_output_kinds(async_engine): @pytest.mark.parametrize("stop", [None, ["a stop string"]])
async def test_output_kinds(async_engine, stop):
"""Test that output_kind works as expected and that """Test that output_kind works as expected and that
results are equivalent across different kinds.""" results are equivalent across different kinds."""
...@@ -214,6 +217,7 @@ async def test_output_kinds(async_engine): ...@@ -214,6 +217,7 @@ async def test_output_kinds(async_engine):
temperature=0, temperature=0,
max_tokens=32, max_tokens=32,
min_tokens=32, min_tokens=32,
stop=stop,
) )
async def run(prompt: str, kind: RequestOutputKind): async def run(prompt: str, kind: RequestOutputKind):
...@@ -229,6 +233,8 @@ async def test_output_kinds(async_engine): ...@@ -229,6 +233,8 @@ async def test_output_kinds(async_engine):
final_output = output final_output = output
assert final_output is not None assert final_output is not None
assert final_output.finished
return (final_output.prompt_token_ids, return (final_output.prompt_token_ids,
final_output.outputs[0].token_ids, final_output.outputs[0].token_ids,
final_output.outputs[0].text, output_count) final_output.outputs[0].text, output_count)
...@@ -241,16 +247,18 @@ async def test_output_kinds(async_engine): ...@@ -241,16 +247,18 @@ async def test_output_kinds(async_engine):
output_tokens: List[int] = [] output_tokens: List[int] = []
output_text = "" output_text = ""
output_count = 0 output_count = 0
final_output = None
async for output in async_engine.generate(prompt, async for output in async_engine.generate(prompt,
params, params,
request_id=uid()): request_id=uid()):
token_ids = output.outputs[0].token_ids token_ids = output.outputs[0].token_ids
text = output.outputs[0].text text = output.outputs[0].text
final_output = output
# Ensure we get prompt ids iff we haven't yet received output tokens # Ensure we get prompt ids iff we haven't yet received output tokens
if output_tokens: if output_tokens:
assert 1 <= len(token_ids) <= num_scheduler_steps assert 1 <= len(token_ids) <= num_scheduler_steps
assert text assert stop or text
assert not output.prompt_token_ids assert not output.prompt_token_ids
else: else:
assert output.prompt_token_ids assert output.prompt_token_ids
...@@ -260,6 +268,10 @@ async def test_output_kinds(async_engine): ...@@ -260,6 +268,10 @@ async def test_output_kinds(async_engine):
output_text += text output_text += text
output_count += 1 output_count += 1
assert final_output is not None
assert final_output.finished
return prompt_tokens, output_tokens, output_text, output_count return prompt_tokens, output_tokens, output_text, output_count
results = await asyncio.gather( results = await asyncio.gather(
...@@ -291,7 +303,8 @@ async def test_output_kinds(async_engine): ...@@ -291,7 +303,8 @@ async def test_output_kinds(async_engine):
@pytest.mark.asyncio(scope="module") @pytest.mark.asyncio(scope="module")
async def test_cancellation(async_engine): @pytest.mark.parametrize("stop", [None, ["a stop string"]])
async def test_cancellation(async_engine, stop):
scheduler_config = await async_engine.get_scheduler_config() scheduler_config = await async_engine.get_scheduler_config()
num_scheduler_steps = scheduler_config.num_scheduler_steps num_scheduler_steps = scheduler_config.num_scheduler_steps
...@@ -299,6 +312,7 @@ async def test_cancellation(async_engine): ...@@ -299,6 +312,7 @@ async def test_cancellation(async_engine):
temperature=0, temperature=0,
min_tokens=13, min_tokens=13,
max_tokens=13, max_tokens=13,
stop=stop,
) )
stop_at = 5 if num_scheduler_steps == 1 else 1 stop_at = 5 if num_scheduler_steps == 1 else 1
...@@ -319,7 +333,8 @@ async def test_cancellation(async_engine): ...@@ -319,7 +333,8 @@ async def test_cancellation(async_engine):
@pytest.mark.asyncio(scope="module") @pytest.mark.asyncio(scope="module")
async def test_delayed_generator(async_engine): @pytest.mark.parametrize("stop", [None, ["a stop string"]])
async def test_delayed_generator(async_engine, stop):
scheduler_config = await async_engine.get_scheduler_config() scheduler_config = await async_engine.get_scheduler_config()
if scheduler_config.num_scheduler_steps != 1: if scheduler_config.num_scheduler_steps != 1:
...@@ -329,6 +344,7 @@ async def test_delayed_generator(async_engine): ...@@ -329,6 +344,7 @@ async def test_delayed_generator(async_engine):
temperature=0, temperature=0,
min_tokens=10, min_tokens=10,
max_tokens=10, max_tokens=10,
stop=stop,
) )
stream = async_engine.generate("test3", sampling_params, request_id=uid()) stream = async_engine.generate("test3", sampling_params, request_id=uid())
......
...@@ -477,7 +477,9 @@ class Sequence: ...@@ -477,7 +477,9 @@ class Sequence:
if not delta: if not delta:
return self.output_text[:-buffer_length] if truncate else ( return self.output_text[:-buffer_length] if truncate else (
self.output_text) self.output_text)
length = len(self.output_text) - buffer_length length = len(self.output_text)
if truncate:
length -= buffer_length
last_offset = self._last_output_text_offset last_offset = self._last_output_text_offset
if last_offset < length: if last_offset < length:
self._last_output_text_offset = length self._last_output_text_offset = length
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment