Unverified commit caa5d296, authored by Yuhong Guo, committed by GitHub

feat: return partial generation results when aborting requests in waiting queue (#11673)

parent 750940ae
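At a glance, the new behavior from the client side, as a rough sketch (the server URL and timing are illustrative; the real coverage is in `test_abort.py` below):

```python
import time
from concurrent.futures import ThreadPoolExecutor

import requests

BASE_URL = "http://localhost:30000"  # assumed local server address

def generate():
    # A long generation that we will abort mid-flight.
    resp = requests.post(
        BASE_URL + "/generate",
        json={
            "text": "The capital of France is",
            "sampling_params": {"max_new_tokens": 4000, "ignore_eos": True},
        },
    )
    return resp.json()

with ThreadPoolExecutor(1) as executor:
    future = executor.submit(generate)
    time.sleep(2)  # let some decode steps run
    # Abort everything, including requests still in the waiting queue.
    requests.post(BASE_URL + "/abort_request", json={"abort_all": True})
    result = future.result()

# After this change, a request aborted while queued still returns the
# tokens it produced before being retracted:
print(result["meta_info"]["finish_reason"])  # {"type": "abort", "message": ...}
print(result["output_ids"])                  # partial ids, possibly non-empty
```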
@@ -86,6 +86,7 @@ SGLang supports various environment variables that can be used to configure its
| `SGLANG_IS_IN_CI` | Indicates if running in CI environment | `false` |
| `SGLANG_IS_IN_CI_AMD` | Indicates running in AMD CI environment | `0` |
| `SGLANG_TEST_RETRACT` | Enable retract decode testing | `false` |
| `SGLANG_TEST_RETRACT_NO_PREFILL_BS` | When `SGLANG_TEST_RETRACT` is enabled, skip prefill while the running batch size exceeds this value, so pending requests stay in the waiting queue | `2 ** 31` |
| `SGLANG_RECORD_STEP_TIME` | Record step time for profiling | `false` |
| `SGLANG_TEST_REQUEST_TIME_STATS` | Test request time statistics | `false` |
| `SGLANG_CI_SMALL_KV_SIZE` | Use small KV cache size in CI | Not set |
...
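For reference, a minimal sketch of wiring these test knobs up before launching a server, assuming they are read from the process environment as documented above:

```python
import os

# Illustrative values: enable retract testing and stop prefilling
# once more than 6 requests are running.
os.environ["SGLANG_TEST_RETRACT"] = "true"
os.environ["SGLANG_TEST_RETRACT_NO_PREFILL_BS"] = "6"
```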
@@ -135,6 +135,7 @@ class Envs:
# Scheduler: memory leak test
SGLANG_TEST_RETRACT = EnvBool(False)
SGLANG_TEST_RETRACT_INTERVAL = EnvInt(3)
SGLANG_TEST_RETRACT_NO_PREFILL_BS = EnvInt(2 ** 31)
SGLANG_ENABLE_RUNTIME_MEM_LEAK_CHECK = EnvBool(False)
# Scheduler: new token ratio hyperparameters
...
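A small usage sketch of the typed accessor, based on the `.get()` and `.override()` helpers that appear elsewhere in this diff:

```python
from sglang.srt.environ import envs

# Typed read; returns the default (2 ** 31) unless the variable is set.
limit = envs.SGLANG_TEST_RETRACT_NO_PREFILL_BS.get()

# Scoped override, as the new test below uses when launching its server.
with envs.SGLANG_TEST_RETRACT_NO_PREFILL_BS.override(6):
    assert envs.SGLANG_TEST_RETRACT_NO_PREFILL_BS.get() == 6
```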
@@ -198,6 +198,7 @@ logger = logging.getLogger(__name__)
# Test retract decode for debugging purposes
TEST_RETRACT = envs.SGLANG_TEST_RETRACT.get()
TEST_RETRACT_INTERVAL = envs.SGLANG_TEST_RETRACT_INTERVAL.get()
TEST_RETRACT_NO_PREFILL_BS = envs.SGLANG_TEST_RETRACT_NO_PREFILL_BS.get()
GRAMMAR_TIMEOUT = float(os.environ.get("SGLANG_GRAMMAR_TIMEOUT", 300))
@@ -1657,6 +1658,12 @@ class Scheduler(
# Get priority queue
self.policy.calc_priority(self.waiting_queue)
if TEST_RETRACT and running_bs > TEST_RETRACT_NO_PREFILL_BS:
# If we are testing retraction and the running batch size exceeds
# TEST_RETRACT_NO_PREFILL_BS, we skip the prefill to keep the requests
# in the waiting queue.
return None
# Prefill policy
adder = PrefillAdder(
self.page_size,
...
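Condensed, the control flow of the new gate looks like this (a simplified sketch, not the full method; the `running_bs` bookkeeping is assumed from the surrounding scheduler code):

```python
# Simplified sketch of the new early-exit in Scheduler.get_new_batch_prefill.
def get_new_batch_prefill(self):
    # Get priority queue
    self.policy.calc_priority(self.waiting_queue)

    running_bs = len(self.running_batch.reqs)  # assumed bookkeeping
    if TEST_RETRACT and running_bs > TEST_RETRACT_NO_PREFILL_BS:
        # Returning None means "no prefill batch this step", so the
        # remaining requests stay in the waiting queue, where an abort
        # can now return their partial results.
        return None

    # ... normal PrefillAdder flow continues here ...
```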
@@ -1487,6 +1487,51 @@ class TokenizerManager(TokenizerCommunicatorMixin):
if self.crash_dump_folder and state.finished and state.obj.log_metrics:
self.record_request_for_crash_dump(state, out_dict)
def add_logprob_to_meta_info(
self,
meta_info: dict,
state: ReqState,
top_logprobs_num: int,
token_ids_logprob: List[int],
return_text_in_logprobs: bool,
):
meta_info["input_token_logprobs"] = self.detokenize_logprob_tokens(
state.input_token_logprobs_val,
state.input_token_logprobs_idx,
return_text_in_logprobs,
)
meta_info["output_token_logprobs"] = self.detokenize_logprob_tokens(
state.output_token_logprobs_val,
state.output_token_logprobs_idx,
return_text_in_logprobs,
)
if top_logprobs_num > 0:
meta_info["input_top_logprobs"] = self.detokenize_top_logprobs_tokens(
state.input_top_logprobs_val,
state.input_top_logprobs_idx,
return_text_in_logprobs,
)
meta_info["output_top_logprobs"] = self.detokenize_top_logprobs_tokens(
state.output_top_logprobs_val,
state.output_top_logprobs_idx,
return_text_in_logprobs,
)
if token_ids_logprob is not None:
meta_info["input_token_ids_logprobs"] = self.detokenize_top_logprobs_tokens(
state.input_token_ids_logprobs_val,
state.input_token_ids_logprobs_idx,
return_text_in_logprobs,
)
meta_info["output_token_ids_logprobs"] = (
self.detokenize_top_logprobs_tokens(
state.output_token_ids_logprobs_val,
state.output_token_ids_logprobs_idx,
return_text_in_logprobs,
)
)
def convert_logprob_style(
self,
meta_info: dict,
@@ -1513,16 +1558,6 @@ class TokenizerManager(TokenizerCommunicatorMixin):
state.output_token_logprobs_idx.extend(
recv_obj.output_token_logprobs_idx[recv_obj_index]
)
meta_info["input_token_logprobs"] = self.detokenize_logprob_tokens(
state.input_token_logprobs_val,
state.input_token_logprobs_idx,
return_text_in_logprobs,
)
meta_info["output_token_logprobs"] = self.detokenize_logprob_tokens(
state.output_token_logprobs_val,
state.output_token_logprobs_idx,
return_text_in_logprobs,
)
if top_logprobs_num > 0:
if len(recv_obj.input_top_logprobs_val) > 0:
@@ -1538,16 +1573,6 @@ class TokenizerManager(TokenizerCommunicatorMixin):
state.output_top_logprobs_idx.extend(
recv_obj.output_top_logprobs_idx[recv_obj_index]
)
meta_info["input_top_logprobs"] = self.detokenize_top_logprobs_tokens(
state.input_top_logprobs_val,
state.input_top_logprobs_idx,
return_text_in_logprobs,
)
meta_info["output_top_logprobs"] = self.detokenize_top_logprobs_tokens(
state.output_top_logprobs_val,
state.output_top_logprobs_idx,
return_text_in_logprobs,
)
if token_ids_logprob is not None:
if len(recv_obj.input_token_ids_logprobs_val) > 0:
@@ -1563,18 +1588,14 @@ class TokenizerManager(TokenizerCommunicatorMixin):
state.output_token_ids_logprobs_idx.extend(
recv_obj.output_token_ids_logprobs_idx[recv_obj_index]
)
meta_info["input_token_ids_logprobs"] = self.detokenize_top_logprobs_tokens(
state.input_token_ids_logprobs_val,
state.input_token_ids_logprobs_idx,
return_text_in_logprobs,
)
meta_info["output_token_ids_logprobs"] = (
self.detokenize_top_logprobs_tokens(
state.output_token_ids_logprobs_val,
state.output_token_ids_logprobs_idx,
return_text_in_logprobs,
)
)
self.add_logprob_to_meta_info(
meta_info,
state,
state.obj.top_logprobs_num,
state.obj.token_ids_logprob,
return_text_in_logprobs,
)
def detokenize_logprob_tokens(
self,
@@ -1759,26 +1780,33 @@ class TokenizerManager(TokenizerCommunicatorMixin):
return
state = self.rid_to_state[recv_obj.rid]
state.finished = True
abort_message = recv_obj.abort_reason or "Abort in waiting queue"
finish_reason = {
"type": "abort",
"message": abort_message,
}
if recv_obj.finished_reason:
out = {
"meta_info": {
"id": recv_obj.rid,
"finish_reason": recv_obj.finished_reason,
},
}
else:
out = {
"text": "",
"meta_info": {
"id": recv_obj.rid,
"finish_reason": {
"type": "abort",
"message": "Abort before prefill",
},
"prompt_tokens": 0,
"completion_tokens": 0,
},
}
finish_reason = recv_obj.finished_reason
meta_info = {"id": recv_obj.rid, "finish_reason": finish_reason}
is_stream = getattr(state.obj, "stream", False)
if getattr(state.obj, "return_logprob", False):
self.add_logprob_to_meta_info(
meta_info,
state,
state.obj.top_logprobs_num,
state.obj.token_ids_logprob,
state.obj.return_text_in_logprobs
and not self.server_args.skip_tokenizer_init,
)
output_ids = state.output_ids
meta_info["completion_tokens"] = len(output_ids)
out = {
"text": state.text,
"output_ids": [output_ids[-1]] if is_stream else output_ids,
"meta_info": meta_info,
}
state.out_list.append(out)
state.event.set()
...
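With the handler above, a request aborted while still queued now resolves to a payload shaped roughly like this (all values illustrative):

```python
# Illustrative shape only; ids, text, and counts are made up.
out = {
    "text": " Paris, and",          # partial text decoded before the abort
    "output_ids": [6342, 11, 323],  # all ids (only the last id when streaming)
    "meta_info": {
        "id": "8f2a...",            # request id
        "finish_reason": {
            "type": "abort",
            "message": "Abort in waiting queue",
        },
        "completion_tokens": 3,
        # plus the *_logprobs keys when return_logprob was requested
    },
}
```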
@@ -71,7 +71,7 @@ suites = {
TestFile("rl/test_fp32_lm_head.py", 30),
TestFile("rl/test_update_weights_from_disk.py", 114),
TestFile("rl/test_update_weights_from_tensor.py", 48),
TestFile("test_abort.py", 51),
TestFile("test_abort.py", 121),
TestFile("test_build_eagle_tree.py", 8),
TestFile("test_chunked_prefill.py", 313),
TestFile("test_create_kvindices.py", 2),
...
import json
import multiprocessing
import os
import time
import unittest
from concurrent.futures import ThreadPoolExecutor, as_completed
import requests
from sglang.srt.environ import envs
from sglang.srt.utils import kill_process_tree
from sglang.test.test_utils import (
DEFAULT_MODEL_NAME_FOR_TEST,
@@ -110,5 +112,82 @@ class TestAbortAll(CustomTestCase):
)
class TestAbortAllWithRetraction(CustomTestCase):
@classmethod
def setUpClass(cls):
cls.model = DEFAULT_MODEL_NAME_FOR_TEST
cls.base_url = DEFAULT_URL_FOR_TEST
# Here's a small trick: in scheduler.py, when SGLANG_TEST_RETRACT is enabled,
# retraction is triggered when the batch size reaches 10.
# However, since SGLANG_TEST_RETRACT_NO_PREFILL_BS is set to 6, the remaining 4
# requests will stay in the waiting queue.
with (
envs.SGLANG_TEST_RETRACT.override(True),
envs.SGLANG_TEST_RETRACT_NO_PREFILL_BS.override(6),
):
cls.process = popen_launch_server(
cls.model,
cls.base_url,
timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
other_args=[
"--max-running-requests",
16,
"--schedule-policy",
"random",
],
)
@classmethod
def tearDownClass(cls):
kill_process_tree(cls.process.pid)
def _run_decode(self):
response = requests.post(
self.base_url + "/generate",
json={
"text": "The capital of France is",
"sampling_params": {
"temperature": 0,
"max_new_tokens": 4000,
"ignore_eos": True,
},
},
)
return response.json()
def test_abort_all_with_retraction(self):
num_requests = 32
with ThreadPoolExecutor(num_requests) as executor:
futures = [executor.submit(self._run_decode) for _ in range(num_requests)]
            # Ensure decoding has started and retractions have occurred.
time.sleep(8)
requests.post(
self.base_url + "/abort_request",
json={
"abort_all": True,
},
)
            abort_in_queue_count = 0
            abort_in_queue_with_non_empty_output = 0
            for future in as_completed(futures):
                result = future.result()
                finish_reason = result["meta_info"]["finish_reason"]
                self.assertEqual(finish_reason["type"], "abort")
                if finish_reason["message"] == "Abort in waiting queue":
                    abort_in_queue_count += 1
                    if len(result["output_ids"]) > 0:
                        abort_in_queue_with_non_empty_output += 1
            # Some requests must have been aborted while still queued, and at
            # least one of those must carry partial output produced before it
            # was retracted.
            self.assertGreater(abort_in_queue_count, 0)
            self.assertGreater(abort_in_queue_with_non_empty_output, 0)
            print("Finished test_abort_all_with_retraction")
if __name__ == "__main__":
unittest.main()
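Beyond `abort_all`, the same endpoint can target a single request by its id (a hedged sketch; the `rid` comes from a generate response's `meta_info["id"]`):

```python
import requests

BASE_URL = "http://localhost:30000"  # assumed server address

# Abort everything, as the test above does:
requests.post(BASE_URL + "/abort_request", json={"abort_all": True})

# Or abort one request by id (hedged sketch of the per-request form):
requests.post(BASE_URL + "/abort_request", json={"rid": "8f2a..."})
```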