Unverified commit b0b4f716, authored by cctry, committed by GitHub

[Fix] memory leak by overlap + retract (#11981)


Co-authored-by: Liangsheng Yin <lsyincs@gmail.com>
parent 6c18addb
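
Under overlap scheduling, the results of a forward pass are processed one step after the pass is launched, so a request can be retracted while a decode step for it is still in flight; the KV-cache slot written by that step was never freed, leaking one token per retracted request. This commit frees the orphaned slot in the output processor, subtracts already-generated tokens when a retracted request is re-admitted to prefill, fixes a zero division in the retract ratio estimate, and adds an opt-in runtime checker that asserts token accounting after every scheduler step.

A minimal sketch of how the three debugging knobs introduced by this commit (registered in the `Envs` hunk below) might be combined to reproduce and detect the leak; the values and launch flow are illustrative:

```python
# Illustrative only: force a retraction on every 3rd forward pass and assert
# KV-token accounting after each scheduler step. All three variables are
# registered in the Envs class in the first hunk below.
import os

os.environ["SGLANG_TEST_RETRACT"] = "1"
os.environ["SGLANG_TEST_RETRACT_INTERVAL"] = "3"
os.environ["SGLANG_ENABLE_RUNTIME_MEM_LEAK_CHECK"] = "1"
# ...then launch the server in this environment (e.g. via popen_launch_server,
# as in the test file at the end of this diff) and run any workload.
```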
@@ -114,7 +114,6 @@ class Envs:
     # Test & Debug
     SGLANG_IS_IN_CI = EnvBool(False)
     SGLANG_IS_IN_CI_AMD = EnvBool(False)
-    SGLANG_TEST_RETRACT = EnvBool(False)
     SGLANG_SET_CPU_AFFINITY = EnvBool(False)
     SGLANG_PROFILE_WITH_STACK = EnvBool(True)
     SGLANG_RECORD_STEP_TIME = EnvBool(False)
@@ -128,6 +127,11 @@ class Envs:
     SGLANG_SIMULATE_ACC_METHOD = EnvStr("multinomial")
     SGLANG_TORCH_PROFILER_DIR = EnvStr("/tmp")

+    # Scheduler: memory leak test
+    SGLANG_TEST_RETRACT = EnvBool(False)
+    SGLANG_TEST_RETRACT_INTERVAL = EnvInt(3)
+    SGLANG_ENABLE_RUNTIME_MEM_LEAK_CHECK = EnvBool(False)
+
     # Scheduler: new token ratio hyperparameters
     SGLANG_INIT_NEW_TOKEN_RATIO = EnvFloat(0.7)
     SGLANG_MIN_NEW_TOKEN_RATIO_FACTOR = EnvFloat(0.14)

@@ -885,7 +885,6 @@ class Req:
         self.temp_input_top_logprobs_idx = None
         self.extend_logprob_start_len = 0
         self.is_chunked = 0
-        self.req_pool_idx = None
         self.mamba_pool_idx = None
         self.already_computed = 0
@@ -1482,7 +1481,9 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
         new_estimate_ratio = (
             total_decoded_tokens
             + envs.SGLANG_RETRACT_DECODE_STEPS.get() * len(self.reqs)
-        ) / total_max_new_tokens
+        ) / (
+            total_max_new_tokens + 1
+        )  # avoid zero division
         new_estimate_ratio = min(1.0, new_estimate_ratio)
         return retracted_reqs, new_estimate_ratio, []
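
The divisor gains a `+ 1` so the estimate stays finite when every request in the batch has already exhausted its `max_new_tokens` budget (`total_max_new_tokens == 0`). A standalone sketch of the guarded formula; the helper name and arguments are hypothetical, only the arithmetic mirrors the hunk above:

```python
def estimate_new_token_ratio(
    total_decoded_tokens: int,
    retract_decode_steps: int,  # envs.SGLANG_RETRACT_DECODE_STEPS
    num_reqs: int,
    total_max_new_tokens: int,
) -> float:
    # "+ 1" avoids a ZeroDivisionError when no generation budget remains
    ratio = (total_decoded_tokens + retract_decode_steps * num_reqs) / (
        total_max_new_tokens + 1
    )
    return min(1.0, ratio)

# With the budget fully consumed, the old code divided by zero;
# the guarded version simply clamps to 1.0:
assert estimate_new_token_ratio(128, 20, 4, 0) == 1.0
```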

@@ -1780,6 +1781,8 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
         # Only contain fields that will be used by process_batch_result
         return ScheduleBatch(
             reqs=self.reqs,
+            req_to_token_pool=self.req_to_token_pool,
+            req_pool_indices=self.req_pool_indices,
             model_config=self.model_config,
             forward_mode=self.forward_mode,
             out_cache_loc=self.out_cache_loc,
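
The lightweight batch copy handed to `process_batch_result` now also carries `req_to_token_pool` and `req_pool_indices`: the output processor needs both to locate, and free, the KV slot that an in-flight decode step wrote for a request retracted in the meantime (see the `SchedulerOutputProcessorMixin` hunks below).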

@@ -569,7 +569,8 @@ class PrefillAdder:
             return self.add_one_req_ignore_eos(req, has_chunked_req)

         total_tokens = req.extend_input_len + min(
-            req.sampling_params.max_new_tokens, CLIP_MAX_NEW_TOKENS
+            max(req.sampling_params.max_new_tokens - len(req.output_ids), 0),
+            CLIP_MAX_NEW_TOKENS,
         )

         # adjusting the input_tokens based on host_hit_length and page_size
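
A retracted request that re-enters prefill has already produced `len(req.output_ids)` tokens, so only the remaining budget (clamped at zero) is reserved. A worked example with made-up numbers; the value of `CLIP_MAX_NEW_TOKENS` is assumed purely for illustration:

```python
extend_input_len = 200      # prompt/extend tokens to prefill
max_new_tokens = 128        # req.sampling_params.max_new_tokens
already_generated = 40      # len(req.output_ids) before retraction
CLIP_MAX_NEW_TOKENS = 512   # illustrative value

remaining = max(max_new_tokens - already_generated, 0)   # 88 tokens still owed
total_tokens = extend_input_len + min(remaining, CLIP_MAX_NEW_TOKENS)
assert total_tokens == 288  # the old formula reserved 200 + 128 = 328
```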

@@ -194,7 +194,8 @@ from sglang.utils import TypeBasedDispatcher, get_exception_traceback
 logger = logging.getLogger(__name__)

 # Test retract decode for debugging purposes
-TEST_RETRACT = get_bool_env_var("SGLANG_TEST_RETRACT")
+TEST_RETRACT = envs.SGLANG_TEST_RETRACT.get()
+TEST_RETRACT_INTERVAL = envs.SGLANG_TEST_RETRACT_INTERVAL.get()

 GRAMMAR_TIMEOUT = float(os.environ.get("SGLANG_GRAMMAR_TIMEOUT", 300))
@@ -1017,6 +1018,9 @@ class Scheduler(
             self.launch_batch_sample_if_needed(batch_result)

         self.last_batch = batch
+
+        if envs.SGLANG_ENABLE_RUNTIME_MEM_LEAK_CHECK.get():
+            self._check_runtime_mem_leak()

     def recv_requests(self) -> List[Req]:
         """Receive results at tp_rank = 0 and broadcast it to all other TP ranks."""
@@ -1833,7 +1837,7 @@ class Scheduler(
         # Check if decode out of memory
         if not batch.check_decode_mem(self.decode_mem_cache_buf_multiplier) or (
-            TEST_RETRACT and batch.batch_size() > 10
+            TEST_RETRACT and self.forward_ct % TEST_RETRACT_INTERVAL == 0
         ):
             old_ratio = self.new_token_ratio

             retracted_reqs, new_token_ratio, reqs_to_abort = batch.retract_decode(
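
With `SGLANG_TEST_RETRACT` set, retraction is now forced on every `TEST_RETRACT_INTERVAL`-th forward pass instead of only when the batch holds more than ten requests, so the retract path is exercised deterministically even for small test batches.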

@@ -77,15 +77,28 @@ class SchedulerOutputProcessorMixin:
         logprob_pt = 0
         for i, (req, next_token_id) in enumerate(zip(batch.reqs, next_token_ids)):
-            if req.is_retracted:
+            if self.enable_overlap and req.is_retracted and len(req.output_ids) > 0:
+                req_idx = batch.req_pool_indices[i]
+                seq_len = len(req.origin_input_ids) + len(req.output_ids)
+                pos = batch.req_to_token_pool.req_to_token[req_idx][
+                    seq_len - 1 : seq_len
+                ]
+                self.token_to_kv_pool_allocator.free(pos)
                 continue

-            if self.is_mixed_chunk and self.enable_overlap and req.finished():
+            if (
+                self.is_mixed_chunk
+                and self.enable_overlap
+                and (req.finished() or req.is_retracted)
+            ):
                 # Free the one delayed token for the mixed decode batch
                 j = len(batch.out_cache_loc) - len(batch.reqs) + i
                 self.token_to_kv_pool_allocator.free(batch.out_cache_loc[j : j + 1])
                 continue

+            if req.is_retracted:
+                continue
+
             if req.is_chunked <= 0:
                 # req output_ids are set here
                 req.output_ids.append(next_token_id)
@@ -269,10 +282,8 @@ class SchedulerOutputProcessorMixin:
         # We should ignore using next_token_ids for spec decoding cases.
         for i, (req, next_token_id) in enumerate(zip(batch.reqs, next_token_ids)):
             req: Req
-            if req.is_retracted:
-                continue
-
-            if self.enable_overlap and req.finished():
+            if self.enable_overlap and (req.finished() or req.is_retracted):
                 indices_to_free = None
                 if batch.spec_algorithm.is_eagle():
                     from sglang.srt.speculative.eagle_info import EagleDraftInput
@@ -301,6 +312,9 @@ class SchedulerOutputProcessorMixin:
                 self.token_to_kv_pool_allocator.free(indices_to_free)
                 continue

+            if req.is_retracted:
+                continue
+
             new_accepted_len = 1
             if batch.spec_algorithm.is_none():
                 req.output_ids.append(next_token_id)
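
These hunks are the heart of the fix. Because the output processor runs one step behind the GPU under overlap scheduling, a request can be retracted while its decode step is still executing; the retraction frees everything the request held at that moment, but the token the in-flight step writes afterwards lands at `req_to_token[req_idx][seq_len - 1]` with no owner. The first new branch frees exactly that slot, and retracted requests are routed through the same freeing logic as finished ones for mixed chunked batches and speculative (EAGLE) decoding before being skipped.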

@@ -4,6 +4,7 @@ import time
 from typing import TYPE_CHECKING

 from sglang.srt.disaggregation.utils import DisaggregationMode
+from sglang.srt.managers.schedule_batch import ScheduleBatch
 from sglang.srt.mem_cache.mamba_radix_cache import MambaRadixCache
 from sglang.srt.mem_cache.swa_radix_cache import SWARadixCache
@@ -65,6 +66,58 @@ class SchedulerRuntimeCheckerMixin:
         token_msg = f"{self.max_total_num_tokens=}, {available_size=}, {evictable_size=}, {protected_size=}\n"
         return memory_leak, token_msg
+    def _check_runtime_mem_leak(self: Scheduler):
+        current_batch: ScheduleBatch = self.last_batch
+        if current_batch is None:
+            return
+
+        _, _, available_size, evictable_size = self._get_token_info()
+        protected_size = self.tree_cache.protected_size()
+
+        extend_size = 0
+        for i, req in enumerate(current_batch.reqs):
+            seq_len = len(req.origin_input_ids) + len(req.output_ids)
+            fill_len = len(req.fill_ids) if req.fill_ids is not None else 0
+            prefix_len = (
+                len(req.prefix_indices) if req.prefix_indices is not None else 0
+            )
+            if current_batch.forward_mode.is_decode():
+                if req.finished():
+                    unreleased_len = 1
+                else:
+                    unreleased_len = seq_len - prefix_len
+            else:
+                unreleased_len = fill_len - prefix_len
+            extend_size += unreleased_len
+
+        if (
+            current_batch.forward_mode.is_extend()
+            and self.running_batch is not None
+            and not self.running_batch.is_empty()
+            and self.running_batch.forward_mode.is_decode()
+        ):
+            for i, req in enumerate(self.running_batch.reqs):
+                seq_len = len(req.origin_input_ids) + len(req.output_ids)
+                prefix_len = (
+                    len(req.prefix_indices) if req.prefix_indices is not None else 0
+                )
+                if req.finished():
+                    unreleased_len = 0
+                else:
+                    unreleased_len = seq_len - prefix_len - 1
+                extend_size += unreleased_len
+
+        total_tokens = available_size + evictable_size + protected_size + extend_size
+        assert (
+            total_tokens == self.max_total_num_tokens
+        ), f"Mem Leak Detected! {total_tokens=} vs {self.max_total_num_tokens=}"
     def _check_req_pool(self: Scheduler):
         if self.disaggregation_mode == DisaggregationMode.DECODE:
             req_total_size = (
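
The checker asserts a conservation law over KV-cache tokens: free slots plus evictable and protected cache tokens plus the tokens still held by in-flight requests (`extend_size`) must add up to `max_total_num_tokens`; any shortfall means a slot was allocated and never freed. The invariant in isolation, with illustrative values (the names mirror the method above):

```python
max_total_num_tokens = 1000  # capacity of the KV-token allocator

available_size = 900   # free slots
evictable_size = 50    # cached tokens that could be evicted
protected_size = 30    # cached tokens pinned by running/chunked requests
extend_size = 20       # tokens held by in-flight requests, not yet cached

assert (
    available_size + evictable_size + protected_size + extend_size
    == max_total_num_tokens
), "Mem Leak Detected!"
```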

@@ -32,6 +32,8 @@ class ChunkCache(BasePrefixCache):
         else:
             self.device = torch.device("cpu")

+        self.protected_size_ = 0
+
         # NOTE (csy): this is to determine if a cache has prefix matching feature.
         # Chunk cache always return True to indicate no prefix matching.
         # TODO (csy): Using a prefix cache trait to replace this
@@ -57,11 +59,13 @@ class ChunkCache(BasePrefixCache):
         ]
         self.req_to_token_pool.free(req.req_pool_idx)
         self.token_to_kv_pool_allocator.free(kv_indices)
+        self.protected_size_ -= len(req.prefix_indices)

     def cache_unfinished_req(self, req: Req, chunked=False):
         kv_indices = self.req_to_token_pool.req_to_token[
             req.req_pool_idx, : len(req.fill_ids)
         ]
+        self.protected_size_ += len(kv_indices) - len(req.prefix_indices)

         # `req.prefix_indices` will be used in `PrefillAdder::add_chunked_req` later
         req.prefix_indices = kv_indices.to(dtype=torch.int64, copy=True)
@@ -75,6 +79,9 @@ class ChunkCache(BasePrefixCache):
     def dec_lock_ref(self, node: Any, swa_uuid_for_lock: Optional[str] = None):
         return 0

+    def protected_size(self):
+        return self.protected_size_
+
     def pretty_print(self):
         return ""
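
`ChunkCache` previously reported nothing as protected, which would break the new accounting under chunked prefill: between chunks, the KV indices of an unfinished request stay pinned through `req.prefix_indices` without being evictable. The counter grows by the newly cached span in `cache_unfinished_req`, shrinks by the full prefix when the request is finished and freed, and is exposed through `protected_size()` so the runtime checker can include it.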

@@ -112,7 +112,7 @@ suites = {
         TestFile("test_reasoning_parser.py", 5),
         TestFile("test_regex_constrained.py", 64),
         TestFile("test_request_queue_validation.py", 30),
-        TestFile("test_retract_decode.py", 54),
+        TestFile("test_retract_decode.py", 90),
         TestFile("test_score_api.py", 310),
         TestFile("test_server_args.py", 1),
         TestFile("test_skip_tokenizer_init.py", 117),

-import os
+import time
 import unittest
 from types import SimpleNamespace

+from sglang.srt.environ import envs
 from sglang.srt.utils import kill_process_tree
 from sglang.test.run_eval import run_eval
 from sglang.test.test_utils import (
@@ -16,13 +17,12 @@ from sglang.test.test_utils import (
 class TestRetractDecode(CustomTestCase):
     @classmethod
     def setUpClass(cls):
-        os.environ["SGLANG_TEST_RETRACT"] = "1"
         cls.model = DEFAULT_MODEL_NAME_FOR_TEST
         cls.base_url = DEFAULT_URL_FOR_TEST
-        cls.process = popen_launch_server(
-            cls.model, cls.base_url, timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH
-        )
+        with envs.SGLANG_TEST_RETRACT.override(True):
+            cls.process = popen_launch_server(
+                cls.model, cls.base_url, timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH
+            )

     @classmethod
     def tearDownClass(cls):
@@ -39,22 +39,43 @@ class TestRetractDecode(CustomTestCase):
         metrics = run_eval(args)
         self.assertGreaterEqual(metrics["score"], 0.65)

+        time.sleep(1)  # wait for mem check
+        assert self.process.poll() is None, "Server crashed during test"
+

 class TestRetractDecodeChunkCache(CustomTestCase):
     @classmethod
     def setUpClass(cls):
-        os.environ["SGLANG_TEST_RETRACT"] = "1"
         cls.model = DEFAULT_MODEL_NAME_FOR_TEST
         cls.base_url = DEFAULT_URL_FOR_TEST
-        cls.process = popen_launch_server(
-            cls.model,
-            cls.base_url,
-            timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
-            other_args=["--disable-radix-cache", "--chunked-prefill-size", 128],
-        )
+        with envs.SGLANG_TEST_RETRACT.override(True):
+            cls.process = popen_launch_server(
+                cls.model,
+                cls.base_url,
+                timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
+                other_args=["--disable-radix-cache", "--chunked-prefill-size", 128],
+            )
+
+    def test_mmlu(self):
+        args = SimpleNamespace(
+            base_url=self.base_url,
+            model=self.model,
+            eval_name="mmlu",
+            num_examples=64,
+            num_threads=32,
+        )
+
+        metrics = run_eval(args)
+        self.assertGreaterEqual(metrics["score"], 0.65)
+
+        time.sleep(1)  # wait for mem check
+        assert self.process.poll() is None, "Server crashed during test"
+
+    @classmethod
+    def tearDownClass(cls):
+        kill_process_tree(cls.process.pid)

 if __name__ == "__main__":
     unittest.main()
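
Note the env-var handling: `envs.SGLANG_TEST_RETRACT.override(True)` wraps only the server launch, presumably setting the variable in the parent's environment so the spawned server inherits it and then restoring the previous value, instead of mutating `os.environ` for the whole test process. The trailing `time.sleep(1)` plus `poll()` assertion gives the server time to run its final runtime memory-leak check, so the test fails loudly if that assertion crashed the server.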