Unverified Commit caa5d296 authored by Yuhong Guo, committed by GitHub

feat: return partial generation results when aborting requests in waiting queue (#11673)

parent 750940ae
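For context, a minimal client-side sketch (not part of the commit) of the behavior this change enables: a request aborted while still sitting in the waiting queue now comes back with a `finish_reason` of type `abort` plus whatever text and `output_ids` it had already generated, instead of an empty placeholder result. The base URL and timing are assumptions; the endpoints and fields mirror the test added below.

```python
# Illustrative sketch only: assumes an SGLang server is already running at BASE_URL.
import threading
import time

import requests

BASE_URL = "http://127.0.0.1:30000"  # placeholder address


def run_one():
    resp = requests.post(
        f"{BASE_URL}/generate",
        json={
            "text": "The capital of France is",
            "sampling_params": {"temperature": 0, "max_new_tokens": 4000, "ignore_eos": True},
        },
    )
    result = resp.json()
    meta = result["meta_info"]
    # Aborted requests report finish_reason type "abort"; requests aborted while
    # queued carry the message "Abort in waiting queue" and may still have
    # partial output_ids after this change.
    print(meta["finish_reason"], len(result.get("output_ids", [])))


t = threading.Thread(target=run_one)
t.start()
time.sleep(5)  # let some decoding happen first
requests.post(f"{BASE_URL}/abort_request", json={"abort_all": True})
t.join()
```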
...@@ -86,6 +86,7 @@ SGLang supports various environment variables that can be used to configure its
| `SGLANG_IS_IN_CI` | Indicates if running in CI environment | `false` |
| `SGLANG_IS_IN_CI_AMD` | Indicates running in AMD CI environment | `0` |
| `SGLANG_TEST_RETRACT` | Enable retract decode testing | `false` |
| `SGLANG_TEST_RETRACT_NO_PREFILL_BS` | When `SGLANG_TEST_RETRACT` is enabled, prefill is skipped while the running batch size exceeds this value. | `2 ** 31` |
| `SGLANG_RECORD_STEP_TIME` | Record step time for profiling | `false` |
| `SGLANG_TEST_REQUEST_TIME_STATS` | Test request time statistics | `false` |
| `SGLANG_CI_SMALL_KV_SIZE` | Use small KV cache size in CI | Not set |
...
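To make the new row concrete, here is a hedged sketch of how the two retract-testing knobs combine when set as plain environment variables; the value 6 is only an example (the new test further down uses the same pair via `envs.*.override`).

```python
# Illustrative only (not part of the commit).
import os

os.environ["SGLANG_TEST_RETRACT"] = "true"             # enable retract decode testing
os.environ["SGLANG_TEST_RETRACT_NO_PREFILL_BS"] = "6"  # skip prefill once the running batch size exceeds 6
# ...then launch the SGLang server from this environment so the scheduler picks the values up.
```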
...@@ -135,6 +135,7 @@ class Envs:
# Scheduler: memory leak test
SGLANG_TEST_RETRACT = EnvBool(False)
SGLANG_TEST_RETRACT_INTERVAL = EnvInt(3)
SGLANG_TEST_RETRACT_NO_PREFILL_BS = EnvInt(2 ** 31)
SGLANG_ENABLE_RUNTIME_MEM_LEAK_CHECK = EnvBool(False)
# Scheduler: new token ratio hyperparameters
...
...@@ -198,6 +198,7 @@ logger = logging.getLogger(__name__)
# Test retract decode for debugging purposes
TEST_RETRACT = envs.SGLANG_TEST_RETRACT.get()
TEST_RETRACT_INTERVAL = envs.SGLANG_TEST_RETRACT_INTERVAL.get()
TEST_RETRACT_NO_PREFILL_BS = envs.SGLANG_TEST_RETRACT_NO_PREFILL_BS.get()
GRAMMAR_TIMEOUT = float(os.environ.get("SGLANG_GRAMMAR_TIMEOUT", 300))
...@@ -1657,6 +1658,12 @@ class Scheduler(
# Get priority queue
self.policy.calc_priority(self.waiting_queue)
if TEST_RETRACT and running_bs > TEST_RETRACT_NO_PREFILL_BS:
# If we are testing retraction and the running batch size exceeds
# TEST_RETRACT_NO_PREFILL_BS, we skip the prefill to keep the requests
# in the waiting queue.
return None
# Prefill policy
adder = PrefillAdder(
self.page_size,
...
...@@ -1487,6 +1487,51 @@ class TokenizerManager(TokenizerCommunicatorMixin):
if self.crash_dump_folder and state.finished and state.obj.log_metrics:
self.record_request_for_crash_dump(state, out_dict)
def add_logprob_to_meta_info(
self,
meta_info: dict,
state: ReqState,
top_logprobs_num: int,
token_ids_logprob: List[int],
return_text_in_logprobs: bool,
):
meta_info["input_token_logprobs"] = self.detokenize_logprob_tokens(
state.input_token_logprobs_val,
state.input_token_logprobs_idx,
return_text_in_logprobs,
)
meta_info["output_token_logprobs"] = self.detokenize_logprob_tokens(
state.output_token_logprobs_val,
state.output_token_logprobs_idx,
return_text_in_logprobs,
)
if top_logprobs_num > 0:
meta_info["input_top_logprobs"] = self.detokenize_top_logprobs_tokens(
state.input_top_logprobs_val,
state.input_top_logprobs_idx,
return_text_in_logprobs,
)
meta_info["output_top_logprobs"] = self.detokenize_top_logprobs_tokens(
state.output_top_logprobs_val,
state.output_top_logprobs_idx,
return_text_in_logprobs,
)
if token_ids_logprob is not None:
meta_info["input_token_ids_logprobs"] = self.detokenize_top_logprobs_tokens(
state.input_token_ids_logprobs_val,
state.input_token_ids_logprobs_idx,
return_text_in_logprobs,
)
meta_info["output_token_ids_logprobs"] = (
self.detokenize_top_logprobs_tokens(
state.output_token_ids_logprobs_val,
state.output_token_ids_logprobs_idx,
return_text_in_logprobs,
)
)
def convert_logprob_style(
self,
meta_info: dict,
...@@ -1513,16 +1558,6 @@ class TokenizerManager(TokenizerCommunicatorMixin):
state.output_token_logprobs_idx.extend(
recv_obj.output_token_logprobs_idx[recv_obj_index]
)
meta_info["input_token_logprobs"] = self.detokenize_logprob_tokens(
state.input_token_logprobs_val,
state.input_token_logprobs_idx,
return_text_in_logprobs,
)
meta_info["output_token_logprobs"] = self.detokenize_logprob_tokens(
state.output_token_logprobs_val,
state.output_token_logprobs_idx,
return_text_in_logprobs,
)
if top_logprobs_num > 0:
if len(recv_obj.input_top_logprobs_val) > 0:
...@@ -1538,16 +1573,6 @@ class TokenizerManager(TokenizerCommunicatorMixin):
state.output_top_logprobs_idx.extend(
recv_obj.output_top_logprobs_idx[recv_obj_index]
)
meta_info["input_top_logprobs"] = self.detokenize_top_logprobs_tokens(
state.input_top_logprobs_val,
state.input_top_logprobs_idx,
return_text_in_logprobs,
)
meta_info["output_top_logprobs"] = self.detokenize_top_logprobs_tokens(
state.output_top_logprobs_val,
state.output_top_logprobs_idx,
return_text_in_logprobs,
)
if token_ids_logprob is not None:
if len(recv_obj.input_token_ids_logprobs_val) > 0:
...@@ -1563,18 +1588,14 @@ class TokenizerManager(TokenizerCommunicatorMixin):
state.output_token_ids_logprobs_idx.extend(
recv_obj.output_token_ids_logprobs_idx[recv_obj_index]
)
meta_info["input_token_ids_logprobs"] = self.detokenize_top_logprobs_tokens(
state.input_token_ids_logprobs_val, self.add_logprob_to_meta_info(
state.input_token_ids_logprobs_idx, meta_info,
return_text_in_logprobs, state,
) state.obj.top_logprobs_num,
meta_info["output_token_ids_logprobs"] = ( state.obj.token_ids_logprob,
self.detokenize_top_logprobs_tokens(
state.output_token_ids_logprobs_val,
state.output_token_ids_logprobs_idx,
return_text_in_logprobs, return_text_in_logprobs,
) )
)
def detokenize_logprob_tokens(
self,
...@@ -1759,25 +1780,32 @@ class TokenizerManager(TokenizerCommunicatorMixin):
return
state = self.rid_to_state[recv_obj.rid]
state.finished = True
if recv_obj.finished_reason:
    out = {
        "meta_info": {
            "id": recv_obj.rid,
            "finish_reason": recv_obj.finished_reason,
        },
    }
else:
    out = {
        "text": "",
        "meta_info": {
            "id": recv_obj.rid,
            "finish_reason": {
                "type": "abort",
                "message": "Abort before prefill",
            },
            "prompt_tokens": 0,
            "completion_tokens": 0,
        },
    }

abort_message = recv_obj.abort_reason or "Abort in waiting queue"
finish_reason = {
    "type": "abort",
    "message": abort_message,
}
if recv_obj.finished_reason:
    finish_reason = recv_obj.finished_reason
meta_info = {"id": recv_obj.rid, "finish_reason": finish_reason}
is_stream = getattr(state.obj, "stream", False)
if getattr(state.obj, "return_logprob", False):
    self.add_logprob_to_meta_info(
        meta_info,
        state,
        state.obj.top_logprobs_num,
        state.obj.token_ids_logprob,
        state.obj.return_text_in_logprobs
        and not self.server_args.skip_tokenizer_init,
    )
output_ids = state.output_ids
meta_info["completion_tokens"] = len(output_ids)
out = {
    "text": state.text,
    "output_ids": [output_ids[-1]] if is_stream else output_ids,
    "meta_info": meta_info,
}
state.out_list.append(out)
state.event.set()
...
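Putting the handler above together, a hedged sketch of the payload a non-streaming client now receives when its request is aborted while still queued; all values are invented for illustration, and logprob fields are attached only when `return_logprob` was requested.

```python
# Illustrative response shape only (values made up).
out = {
    "text": " Paris. It",             # partial generation accumulated before the abort
    "output_ids": [12366, 13, 1102],  # full list; only the last id when streaming
    "meta_info": {
        "id": "example-rid",          # placeholder request id
        "finish_reason": {"type": "abort", "message": "Abort in waiting queue"},
        "completion_tokens": 3,
    },
}
```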
...@@ -71,7 +71,7 @@ suites = {
TestFile("rl/test_fp32_lm_head.py", 30), TestFile("rl/test_fp32_lm_head.py", 30),
TestFile("rl/test_update_weights_from_disk.py", 114), TestFile("rl/test_update_weights_from_disk.py", 114),
TestFile("rl/test_update_weights_from_tensor.py", 48), TestFile("rl/test_update_weights_from_tensor.py", 48),
TestFile("test_abort.py", 51), TestFile("test_abort.py", 121),
TestFile("test_build_eagle_tree.py", 8), TestFile("test_build_eagle_tree.py", 8),
TestFile("test_chunked_prefill.py", 313), TestFile("test_chunked_prefill.py", 313),
TestFile("test_create_kvindices.py", 2), TestFile("test_create_kvindices.py", 2),
...
import json
import multiprocessing
import os
import time
import unittest
from concurrent.futures import ThreadPoolExecutor, as_completed
import requests
from sglang.srt.environ import envs
from sglang.srt.utils import kill_process_tree
from sglang.test.test_utils import (
DEFAULT_MODEL_NAME_FOR_TEST,
...@@ -110,5 +112,82 @@ class TestAbortAll(CustomTestCase):
)
class TestAbortAllWithRetraction(CustomTestCase):
@classmethod
def setUpClass(cls):
cls.model = DEFAULT_MODEL_NAME_FOR_TEST
cls.base_url = DEFAULT_URL_FOR_TEST
# Here's a small trick: in scheduler.py, when SGLANG_TEST_RETRACT is enabled,
# retraction is triggered when the batch size reaches 10.
# However, since SGLANG_TEST_RETRACT_NO_PREFILL_BS is set to 6, the remaining 4
# requests will stay in the waiting queue.
with (
envs.SGLANG_TEST_RETRACT.override(True),
envs.SGLANG_TEST_RETRACT_NO_PREFILL_BS.override(6),
):
cls.process = popen_launch_server(
cls.model,
cls.base_url,
timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
other_args=[
"--max-running-requests",
16,
"--schedule-policy",
"random",
],
)
@classmethod
def tearDownClass(cls):
kill_process_tree(cls.process.pid)
def _run_decode(self):
response = requests.post(
self.base_url + "/generate",
json={
"text": "The capital of France is",
"sampling_params": {
"temperature": 0,
"max_new_tokens": 4000,
"ignore_eos": True,
},
},
)
return response.json()
def test_abort_all_with_retraction(self):
num_requests = 32
with ThreadPoolExecutor(num_requests) as executor:
futures = [executor.submit(self._run_decode) for _ in range(num_requests)]
# Ensure decoding has started and retractions have happened.
time.sleep(8)
requests.post(
self.base_url + "/abort_request",
json={
"abort_all": True,
},
)
abort_in_queue_count = 0
abort_in_queue_with_none_empty_text = 0
for future in as_completed(futures):
self.assertEqual(
future.result()["meta_info"]["finish_reason"]["type"], "abort"
)
if (
future.result()["meta_info"]["finish_reason"]["message"]
== "Abort in waiting queue"
):
abort_in_queue_count += 1
if len(future.result()["output_ids"]) > 0:
abort_in_queue_with_none_empty_text += 1
assert abort_in_queue_count > 0
assert abort_in_queue_with_none_empty_text > 0
print("Finished test_abort_all_with_retraction")
if __name__ == "__main__":
unittest.main()