Unverified Commit 44b1b394 authored by Liangsheng Yin's avatar Liangsheng Yin Committed by GitHub
Browse files

[PD-Disagg] Check finish after pop transferred (#12638)

parent 0711d150
......@@ -773,17 +773,12 @@ class DecodeTransferQueue:
indices_to_remove.add(i)
decode_req.req.time_stats.wait_queue_entry_time = time.perf_counter()
# special handling for corner cases
should_finish = decode_req.req.sampling_params.max_new_tokens == 1 or (
not decode_req.req.sampling_params.ignore_eos
and decode_req.req.output_ids[-1] in decode_req.req.eos_token_ids
)
if should_finish:
decode_req.req.check_finished()
if decode_req.req.finished():
# finish immediately
decode_req.req.time_stats.forward_entry_time = (
decode_req.req.time_stats.completion_time
) = time.perf_counter()
decode_req.req.check_finished()
self.scheduler.stream_output(
[decode_req.req], decode_req.req.return_logprob
)
......
......@@ -3,7 +3,9 @@ import os
import unittest
from types import SimpleNamespace
import openai
import requests
from transformers import AutoTokenizer
from sglang.test.few_shot_gsm8k import run_eval as run_eval_few_shot_gsm8k
from sglang.test.test_disaggregation_utils import TestDisaggregationBase
......@@ -136,6 +138,52 @@ class TestDisaggregationAccuracy(TestDisaggregationBase):
# ensure the output is a valid JSON
json.loads(output)
def test_first_token_finish(self):
    """Check usage accounting when the very first generated token ends the request.

    Three completions are sent through ``self.lb_url``, each with a
    ``logit_bias`` of 42 on a single token id:

    1. bias on the tokenizer's EOS id -> ``completion_tokens`` must be 1;
    2. same bias with ``ignore_eos`` enabled -> ``completion_tokens`` must
       be greater than 1;
    3. bias on the first token id of " hello", with " hello" passed as a
       stop string -> ``completion_tokens`` must be 1.
    """
    client = openai.Client(api_key="empty", base_url=f"{self.lb_url}/v1")
    tokenizer = AutoTokenizer.from_pretrained(self.model)
    eos_id = tokenizer.eos_token_id

    # Arguments common to all three requests.
    shared_args = {
        "model": "dummy",
        "prompt": "The best programming language for AI is",
    }

    # Case 1: first token is EOS.
    # NOTE: the local must stay named ``res`` -- ``f"{res=}"`` prints the name.
    res = client.completions.create(
        **shared_args, logit_bias={eos_id: 42}
    ).model_dump()
    print(f"{res=}")
    generated = res["usage"]["completion_tokens"]
    assert generated == 1, (
        "Expected completion_tokens to be 1 when first token is EOS, "
        f"but got {generated}"
    )

    # Case 2: first token is EOS, but the request asks to ignore EOS.
    res = client.completions.create(
        **shared_args,
        logit_bias={eos_id: 42},
        extra_body={"ignore_eos": True},
    ).model_dump()
    print(f"{res=}")
    generated = res["usage"]["completion_tokens"]
    assert generated > 1, (
        "Expected completion_tokens to be greater than 1 when ignore_eos is True, "
        f"but got {generated}"
    )

    # Case 3: first token matches a caller-specified stop string.
    hello_ids = tokenizer.encode(" hello", add_special_tokens=False)
    stop_token_id = hello_ids[0]
    res = client.completions.create(
        **shared_args,
        logit_bias={stop_token_id: 42},
        stop=[" hello"],
    ).model_dump()
    print(f"{res=}")
    generated = res["usage"]["completion_tokens"]
    assert generated == 1, (
        "Expected completion_tokens to be 1 when first token is stop token, "
        f"but got {generated}"
    )
class TestDisaggregationMooncakeFailure(TestDisaggregationBase):
@classmethod
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment