Unverified Commit 44b1b394 authored by Liangsheng Yin's avatar Liangsheng Yin Committed by GitHub
Browse files

[PD-Disagg] Check finish after pop tranferred (#12638)

parent 0711d150
...@@ -773,17 +773,12 @@ class DecodeTransferQueue: ...@@ -773,17 +773,12 @@ class DecodeTransferQueue:
indices_to_remove.add(i) indices_to_remove.add(i)
decode_req.req.time_stats.wait_queue_entry_time = time.perf_counter() decode_req.req.time_stats.wait_queue_entry_time = time.perf_counter()
# special handling for corner cases decode_req.req.check_finished()
should_finish = decode_req.req.sampling_params.max_new_tokens == 1 or ( if decode_req.req.finished():
not decode_req.req.sampling_params.ignore_eos
and decode_req.req.output_ids[-1] in decode_req.req.eos_token_ids
)
if should_finish:
# finish immediately # finish immediately
decode_req.req.time_stats.forward_entry_time = ( decode_req.req.time_stats.forward_entry_time = (
decode_req.req.time_stats.completion_time decode_req.req.time_stats.completion_time
) = time.perf_counter() ) = time.perf_counter()
decode_req.req.check_finished()
self.scheduler.stream_output( self.scheduler.stream_output(
[decode_req.req], decode_req.req.return_logprob [decode_req.req], decode_req.req.return_logprob
) )
......
...@@ -3,7 +3,9 @@ import os ...@@ -3,7 +3,9 @@ import os
import unittest import unittest
from types import SimpleNamespace from types import SimpleNamespace
import openai
import requests import requests
from transformers import AutoTokenizer
from sglang.test.few_shot_gsm8k import run_eval as run_eval_few_shot_gsm8k from sglang.test.few_shot_gsm8k import run_eval as run_eval_few_shot_gsm8k
from sglang.test.test_disaggregation_utils import TestDisaggregationBase from sglang.test.test_disaggregation_utils import TestDisaggregationBase
...@@ -136,6 +138,52 @@ class TestDisaggregationAccuracy(TestDisaggregationBase): ...@@ -136,6 +138,52 @@ class TestDisaggregationAccuracy(TestDisaggregationBase):
# ensure the output is a valid JSON # ensure the output is a valid JSON
json.loads(output) json.loads(output)
def test_first_token_finish(self):
client = openai.Client(api_key="empty", base_url=f"{self.lb_url}/v1")
tokenizer = AutoTokenizer.from_pretrained(self.model)
eos_token = tokenizer.eos_token_id
prompt = "The best programming language for AI is"
# First token EOS
res = client.completions.create(
model="dummy", prompt=prompt, logit_bias={eos_token: 42}
).model_dump()
print(f"{res=}")
assert res["usage"]["completion_tokens"] == 1, (
"Expected completion_tokens to be 1 when first token is EOS, "
f"but got {res['usage']['completion_tokens']}"
)
# First token EOS with ignore_eos
res = client.completions.create(
model="dummy",
prompt=prompt,
logit_bias={eos_token: 42},
extra_body={"ignore_eos": True},
).model_dump()
print(f"{res=}")
assert res["usage"]["completion_tokens"] > 1, (
"Expected completion_tokens to be greater than 1 when ignore_eos is True, "
f"but got {res['usage']['completion_tokens']}"
)
# First token with specified stop token
stop_token_id = tokenizer.encode(" hello", add_special_tokens=False)[0]
res = client.completions.create(
model="dummy",
prompt=prompt,
logit_bias={stop_token_id: 42},
stop=[" hello"],
).model_dump()
print(f"{res=}")
assert res["usage"]["completion_tokens"] == 1, (
"Expected completion_tokens to be 1 when first token is stop token, "
f"but got {res['usage']['completion_tokens']}"
)
class TestDisaggregationMooncakeFailure(TestDisaggregationBase): class TestDisaggregationMooncakeFailure(TestDisaggregationBase):
@classmethod @classmethod
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment