Unverified Commit 111c6814 authored by Ryan McCormick's avatar Ryan McCormick Committed by GitHub
Browse files

fix: Send last token batch when finish_reason is set (#3531)

parent 03bdced9
...@@ -188,26 +188,24 @@ class DecodeWorkerHandler(BaseWorkerHandler): ...@@ -188,26 +188,24 @@ class DecodeWorkerHandler(BaseWorkerHandler):
Yields: Yields:
Dict with token_ids and optional finish_reason. Dict with token_ids and optional finish_reason.
Raises:
ValueError: If response missing output_ids.
""" """
num_output_tokens_so_far = 0 num_output_tokens_so_far = 0
async for res in stream_source: async for res in stream_source:
out = {}
finish_reason = res["meta_info"]["finish_reason"] finish_reason = res["meta_info"]["finish_reason"]
if finish_reason: if finish_reason:
out = {"token_ids": [], "finish_reason": finish_reason["type"]} out["finish_reason"] = finish_reason["type"]
else:
try: output_ids = res.get("output_ids", [])
next_total_toks = len(res["output_ids"]) # If request is not finished yet, but there are no outputs, return an error.
except KeyError: if not output_ids and not finish_reason:
raise ValueError( yield {"finish_reason": "error", "token_ids": []}
f"Missing 'output_ids' in response. Response keys: {list(res.keys())}" break
)
out = {"token_ids": res["output_ids"][num_output_tokens_so_far:]}
num_output_tokens_so_far = next_total_toks
next_total_toks = len(output_ids)
out["token_ids"] = output_ids[num_output_tokens_so_far:]
num_output_tokens_so_far = next_total_toks
yield out yield out
async def _process_text_stream( async def _process_text_stream(
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment