Unverified commit 18e9e1f7, authored by Nick Hill and committed by GitHub
Browse files

[HotFix] Fix final output truncation with stop string + streaming (#8468)

parent f57092c0
......@@ -159,7 +159,8 @@ def should_do_global_cleanup_after_test(request) -> bool:
@pytest.mark.asyncio(scope="module")
async def test_asyncio_run(async_engine):
@pytest.mark.parametrize("stop", [None, ["a stop string"]])
async def test_asyncio_run(async_engine, stop):
scheduler_config = await async_engine.get_scheduler_config()
num_scheduler_steps = scheduler_config.num_scheduler_steps
......@@ -169,6 +170,7 @@ async def test_asyncio_run(async_engine):
temperature=0,
max_tokens=32,
min_tokens=32,
stop=stop,
)
output_count = 0
......@@ -203,7 +205,8 @@ async def test_asyncio_run(async_engine):
@pytest.mark.asyncio(scope="module")
async def test_output_kinds(async_engine):
@pytest.mark.parametrize("stop", [None, ["a stop string"]])
async def test_output_kinds(async_engine, stop):
"""Test that output_kind works as expected and that
results are equivalent across different kinds."""
......@@ -214,6 +217,7 @@ async def test_output_kinds(async_engine):
temperature=0,
max_tokens=32,
min_tokens=32,
stop=stop,
)
async def run(prompt: str, kind: RequestOutputKind):
......@@ -229,6 +233,8 @@ async def test_output_kinds(async_engine):
final_output = output
assert final_output is not None
assert final_output.finished
return (final_output.prompt_token_ids,
final_output.outputs[0].token_ids,
final_output.outputs[0].text, output_count)
......@@ -241,16 +247,18 @@ async def test_output_kinds(async_engine):
output_tokens: List[int] = []
output_text = ""
output_count = 0
final_output = None
async for output in async_engine.generate(prompt,
params,
request_id=uid()):
token_ids = output.outputs[0].token_ids
text = output.outputs[0].text
final_output = output
# Ensure we get prompt ids iff we haven't yet received output tokens
if output_tokens:
assert 1 <= len(token_ids) <= num_scheduler_steps
assert text
assert stop or text
assert not output.prompt_token_ids
else:
assert output.prompt_token_ids
......@@ -260,6 +268,10 @@ async def test_output_kinds(async_engine):
output_text += text
output_count += 1
assert final_output is not None
assert final_output.finished
return prompt_tokens, output_tokens, output_text, output_count
results = await asyncio.gather(
......@@ -291,7 +303,8 @@ async def test_output_kinds(async_engine):
@pytest.mark.asyncio(scope="module")
async def test_cancellation(async_engine):
@pytest.mark.parametrize("stop", [None, ["a stop string"]])
async def test_cancellation(async_engine, stop):
scheduler_config = await async_engine.get_scheduler_config()
num_scheduler_steps = scheduler_config.num_scheduler_steps
......@@ -299,6 +312,7 @@ async def test_cancellation(async_engine):
temperature=0,
min_tokens=13,
max_tokens=13,
stop=stop,
)
stop_at = 5 if num_scheduler_steps == 1 else 1
......@@ -319,7 +333,8 @@ async def test_cancellation(async_engine):
@pytest.mark.asyncio(scope="module")
async def test_delayed_generator(async_engine):
@pytest.mark.parametrize("stop", [None, ["a stop string"]])
async def test_delayed_generator(async_engine, stop):
scheduler_config = await async_engine.get_scheduler_config()
if scheduler_config.num_scheduler_steps != 1:
......@@ -329,6 +344,7 @@ async def test_delayed_generator(async_engine):
temperature=0,
min_tokens=10,
max_tokens=10,
stop=stop,
)
stream = async_engine.generate("test3", sampling_params, request_id=uid())
......
......@@ -477,7 +477,9 @@ class Sequence:
if not delta:
return self.output_text[:-buffer_length] if truncate else (
self.output_text)
length = len(self.output_text) - buffer_length
length = len(self.output_text)
if truncate:
length -= buffer_length
last_offset = self._last_output_text_offset
if last_offset < length:
self._last_output_text_offset = length
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment