Unverified Commit c555ce2c authored by Lianmin Zheng, committed by GitHub

Revert "Fix memory leak when doing chunked prefill" (#1797)

parent 40900bae
@@ -15,7 +15,7 @@ class GlobalConfig:
         # Runtime constants: New generation token ratio estimation
         self.init_new_token_ratio = 0.7
-        self.min_new_token_ratio = 0.1
+        self.base_min_new_token_ratio = 0.1
         self.new_token_ratio_decay = 0.001

         # Runtime constants: others
@@ -32,15 +32,5 @@ class GlobalConfig:
         self.enable_precache_with_tracing = True
         self.enable_parallel_encoding = True

-    def adjust_new_token_ratio(self, schedule_conservativeness=1):
-        assert schedule_conservativeness >= 0, "Invalid schedule_conservativeness"
-        min_new_token_ratio = min(
-            self.min_new_token_ratio * schedule_conservativeness,
-            1.0,
-        )
-        init_new_token_ratio = max(self.init_new_token_ratio, min_new_token_ratio)
-        return min_new_token_ratio, init_new_token_ratio

 global_config = GlobalConfig()
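
Note (illustration, not part of the diff): the revert moves this ratio computation out of GlobalConfig and back into the Scheduler (see the Scheduler hunks below). A minimal standalone sketch of how the restored constants combine, assuming the default schedule_conservativeness of 1.0; the local variable names are stand-ins, not code from the commit:

    # Values taken from the GlobalConfig hunk above.
    init_new_token_ratio = 0.7
    base_min_new_token_ratio = 0.1
    new_token_ratio_decay = 0.001
    schedule_conservativeness = 1.0  # assumed server argument

    # Restored Scheduler logic: the floor scales with conservativeness, capped at 1.0.
    min_new_token_ratio = min(base_min_new_token_ratio * schedule_conservativeness, 1.0)

    # The scheduler starts at the floor, resets to init_new_token_ratio after an idle
    # check_memory() pass, and decays back toward the floor on retract-free steps.
    ratio = init_new_token_ratio
    for _ in range(5):
        ratio = max(ratio - new_token_ratio_decay, min_new_token_ratio)
    print(min_new_token_ratio, ratio)  # 0.1 0.695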
@@ -222,7 +222,7 @@ class Req:
         self.prefix_indices = []
         self.extend_input_len = 0
         self.last_node = None
-        self.is_being_chunked = False
+        self.is_inflight_req = 0

         # Logprobs (arguments)
         self.return_logprob = False
@@ -906,14 +906,15 @@ class ScheduleBatch:
     def filter_batch(
         self,
-        being_chunked_req: Optional[Req] = None,
+        current_inflight_req: Optional[Req] = None,
         keep_indices: Optional[List[int]] = None,
     ):
         if keep_indices is None:
             keep_indices = [
                 i
                 for i in range(len(self.reqs))
-                if not self.reqs[i].finished() and self.reqs[i] is not being_chunked_req
+                if not self.reqs[i].finished()
+                and self.reqs[i] is not current_inflight_req
             ]

         if keep_indices is None or len(keep_indices) == 0:
...
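
Note (illustration, not part of the diff): the restored filter_batch keeps only unfinished requests and always drops the in-flight chunked request, which the Scheduler handles separately. A small self-contained sketch of that selection using a stand-in request class (FakeReq and its fields are assumptions for the example, not part of the commit):

    from typing import List, Optional


    class FakeReq:
        """Stand-in for Req: just enough state to show the keep_indices selection."""

        def __init__(self, name: str, done: bool):
            self.name, self.done = name, done

        def finished(self) -> bool:
            return self.done


    def keep_indices_for(reqs: List[FakeReq], current_inflight_req: Optional[FakeReq]):
        # Same shape as the restored list comprehension above.
        return [
            i
            for i in range(len(reqs))
            if not reqs[i].finished() and reqs[i] is not current_inflight_req
        ]


    a, b, c = FakeReq("a", done=True), FakeReq("b", done=False), FakeReq("c", done=False)
    print(keep_indices_for([a, b, c], current_inflight_req=c))  # [1]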
@@ -136,7 +136,7 @@ class PrefillAdder:
         self.req_states = None
         self.can_run_list = []
-        self.new_chunked_req = None
+        self.new_inflight_req = None
         self.log_hit_tokens = 0
         self.log_input_tokens = 0
@@ -176,7 +176,7 @@ class PrefillAdder:
         self.log_hit_tokens += prefix_len
         self.log_input_tokens += extend_input_len

-    def add_being_chunked_req(self, req: Req):
+    def add_inflight_req(self, req: Req):
         truncated = req.extend_input_len > self.rem_chunk_tokens
         req.extend_input_len = min(req.extend_input_len, self.rem_chunk_tokens)
         req.fill_ids = req.fill_ids[: len(req.prefix_indices) + req.extend_input_len]
@@ -192,13 +192,8 @@ class PrefillAdder:
             ),
         )

-        if truncated:
-            # Continue to chunk the request
-            assert req.is_being_chunked
-            self.new_chunked_req = req
-        else:
-            # Release the being chunked status
-            req.is_being_chunked = False
+        # Return if chunked prefill not finished
+        return req if truncated else None

     @contextmanager
     def _lock_node(self, last_node: TreeNode):
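
Note (illustration, not part of the diff): the restored add_inflight_req returns the request while its prefill is still truncated and None once the final chunk fits, and the Scheduler assigns that return value back to current_inflight_req (see the Scheduler hunks below). A toy sketch of that hand-off with made-up token counts; ToyAdder is a hypothetical stand-in for PrefillAdder:

    class ToyAdder:
        """Hypothetical stand-in: models only the chunking decision."""

        def __init__(self, rem_chunk_tokens: int):
            self.rem_chunk_tokens = rem_chunk_tokens

        def add_inflight_req(self, remaining_tokens):
            # Mirrors the restored contract: return the still-pending work while
            # truncated, or None once everything fits into the chunk budget.
            truncated = remaining_tokens > self.rem_chunk_tokens
            return remaining_tokens - self.rem_chunk_tokens if truncated else None


    adder = ToyAdder(rem_chunk_tokens=128)
    current_inflight = 300  # prompt tokens still to prefill for the in-flight request
    rounds = 0
    while current_inflight is not None:
        current_inflight = adder.add_inflight_req(current_inflight)
        rounds += 1
    print(rounds)  # 3 rounds: 300 -> 172 -> 44 -> done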
@@ -267,14 +262,11 @@ class PrefillAdder:
             )
         else:
             # Chunked prefill
-            assert self.new_chunked_req is None
             trunc_len = self.rem_chunk_tokens
             req.extend_input_len = trunc_len
-            req.is_being_chunked = True
             req.fill_ids = req.fill_ids[:trunc_len]
             self.can_run_list.append(req)
-            self.new_chunked_req = req
+            self.new_inflight_req = req
             self._prefill_one_req(0, trunc_len, 0)

         return self.budget_state()
@@ -313,18 +305,15 @@ class PrefillAdder:
                     min(req.sampling_params.max_new_tokens, CLIP_MAX_NEW_TOKENS),
                 )
             else:
+                # Chunked prefill
                 trunc_len = self.rem_chunk_tokens
                 if trunc_len == 0:
                     return AddReqResult.OTHER

-                # Chunked prefill
-                assert self.new_chunked_req is None
                 req.extend_input_len = trunc_len
                 req.fill_ids = req.fill_ids[: len(req.prefix_indices) + trunc_len]
-                req.is_being_chunked = True
                 self.can_run_list.append(req)
-                self.new_chunked_req = req
+                self.new_inflight_req = req
                 self.tree_cache.inc_lock_ref(req.last_node)
                 self._prefill_one_req(prefix_len, trunc_len, 0)
...
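
Note (illustration, not part of the diff): a worked example of the truncation arithmetic in the chunked-prefill branch above, with made-up sizes (all numbers and names below are illustrative only):

    prefix_len = 64          # tokens already cached for this request
    extend_input_len = 300   # new prompt tokens that still need prefill
    rem_chunk_tokens = 128   # chunk budget left for this prefill batch

    # Chunked prefill keeps only the next chunk of the prompt in this round.
    trunc_len = rem_chunk_tokens
    fill_ids = list(range(prefix_len + extend_input_len))  # stand-in token ids
    fill_ids = fill_ids[: prefix_len + trunc_len]

    print(len(fill_ids))     # 192 = cached prefix (64) + this round's chunk (128)
    # The remaining 300 - 128 = 172 prompt tokens are prefilled in later rounds while
    # the request stays registered as new_inflight_req / current_inflight_req.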
@@ -219,12 +219,13 @@ class Scheduler:
         # Init chunked prefill
         self.chunked_prefill_size = server_args.chunked_prefill_size
-        self.being_chunked_req = None
+        self.current_inflight_req = None
         self.is_mixed_chunk = (
             self.chunked_prefill_size is not None and server_args.enable_mixed_chunk
         )

         # Init the FSM cache for constrained generation
+        if not server_args.skip_tokenizer_init:
             self.regex_fsm_cache = FSMCache(
                 server_args.tokenizer_path,
                 {
@@ -237,10 +238,16 @@ class Scheduler:
         self.jump_forward_cache = JumpForwardCache()

         # Init new token estimation
-        self.min_new_token_ratio, self.init_new_token_ratio = (
-            global_config.adjust_new_token_ratio(server_args.schedule_conservativeness)
-        )
-        self.new_token_ratio = self.init_new_token_ratio
+        assert (
+            server_args.schedule_conservativeness >= 0
+        ), "Invalid schedule_conservativeness"
+        self.min_new_token_ratio = min(
+            global_config.base_min_new_token_ratio
+            * server_args.schedule_conservativeness,
+            1.0,
+        )
+        self.new_token_ratio = self.min_new_token_ratio
+        self.new_token_ratio_decay = global_config.new_token_ratio_decay
         self.batch_is_full = False

         # Init profiler
@@ -287,7 +294,7 @@ class Scheduler:
                 self.process_batch_result(batch, result)
             else:
                 self.check_memory()
-                self.new_token_ratio = self.init_new_token_ratio
+                self.new_token_ratio = global_config.init_new_token_ratio

             self.last_batch = batch
@@ -314,7 +321,7 @@ class Scheduler:
                 self.process_batch_result(tmp_batch, tmp_result)
         elif batch is None:
             self.check_memory()
-            self.new_token_ratio = self.init_new_token_ratio
+            self.new_token_ratio = global_config.init_new_token_ratio

         self.last_batch = batch
@@ -492,18 +499,20 @@ class Scheduler:
             )
             exit(1) if crash_on_warning else None

-    def get_next_batch_to_run(self) -> Optional[ScheduleBatch]:
+    def get_next_batch_to_run(self):
         # Merge the prefill batch into the running batch
         if (
             self.last_batch
             and not self.last_batch.forward_mode.is_decode()
             and not self.last_batch.is_empty()
         ):
-            if self.being_chunked_req:
-                self.last_batch.filter_batch(being_chunked_req=self.being_chunked_req)
-                self.tree_cache.cache_unfinished_req(self.being_chunked_req)
-                # Being chunked request keeps its rid but will get a new req_pool_idx.
-                self.req_to_token_pool.free(self.being_chunked_req.req_pool_idx)
+            if self.current_inflight_req:
+                self.last_batch.filter_batch(
+                    current_inflight_req=self.current_inflight_req
+                )
+                self.tree_cache.cache_unfinished_req(self.current_inflight_req)
+                # Inflight request keeps its rid but will get a new req_pool_idx.
+                self.req_to_token_pool.free(self.current_inflight_req.req_pool_idx)
                 self.batch_is_full = False

             if not self.last_batch.is_empty():
                 if self.running_batch is None:
@@ -534,7 +543,7 @@ class Scheduler:
         # Handle the cases where prefill is not allowed
         if (
             self.batch_is_full or len(self.waiting_queue) == 0
-        ) and self.being_chunked_req is None:
+        ) and self.current_inflight_req is None:
             return None

         running_bs = len(self.running_batch.reqs) if self.running_batch else 0
@@ -557,6 +566,15 @@ class Scheduler:
             num_mixed_running,
         )

+        has_inflight = self.current_inflight_req is not None
+        if has_inflight:
+            self.current_inflight_req.init_next_round_input(
+                None if prefix_computed else self.tree_cache
+            )
+            self.current_inflight_req = adder.add_inflight_req(
+                self.current_inflight_req
+            )
+
         if self.lora_paths:
             lora_set = (
                 set([req.lora_path for req in self.running_batch.reqs])
@@ -564,13 +582,6 @@ class Scheduler:
                 else set([])
             )

-        # NOTE: if there is request being chunked, we always add it first
-        has_being_chunked = self.being_chunked_req is not None
-        if has_being_chunked:
-            # NOTE: the prefix_indices of being-chunked prefill should align with the last prefill result
-            self.being_chunked_req.init_next_round_input()
-            adder.add_being_chunked_req(self.being_chunked_req)
-
         # Get requests from the waiting queue to a new prefill batch
         for req in self.waiting_queue:
             if (
@@ -604,8 +615,12 @@ class Scheduler:
                 x for x in self.waiting_queue if x not in set(can_run_list)
             ]

-        # Update new round being chunked request
-        self.being_chunked_req = adder.new_chunked_req
+        if adder.new_inflight_req is not None:
+            assert self.current_inflight_req is None
+            self.current_inflight_req = adder.new_inflight_req
+
+        if self.current_inflight_req:
+            self.current_inflight_req.is_inflight_req += 1

         # Print stats
         if self.tp_rank == 0:
@@ -634,7 +649,7 @@ class Scheduler:
                 f"#cached-token: {adder.log_hit_tokens}, "
                 f"cache hit rate: {100.0 * tree_cache_hit_rate:.2f}%, "
                 f"token usage: {num_used / self.max_total_num_tokens:.2f}, "
-                f"#queue-req: {len(self.waiting_queue) + has_being_chunked}"
+                f"#queue-req: {len(self.waiting_queue) + has_inflight}"
             )
         else:
             logger.info(
@@ -645,7 +660,7 @@ class Scheduler:
                 f"cache hit rate: {100.0 * tree_cache_hit_rate:.2f}%, "
                 f"token usage: {num_used / self.max_total_num_tokens:.2f}, "
                 f"#running-req: {running_bs}, "
-                f"#queue-req: {len(self.waiting_queue) + has_being_chunked}"
+                f"#queue-req: {len(self.waiting_queue) + has_inflight}"
             )

         # Create a new batch
@@ -694,7 +709,7 @@ class Scheduler:
             self.waiting_queue.extend(retracted_reqs)
         else:
             self.new_token_ratio = max(
-                self.new_token_ratio - global_config.new_token_ratio_decay,
+                self.new_token_ratio - self.new_token_ratio_decay,
                 self.min_new_token_ratio,
             )
@@ -768,8 +783,10 @@ class Scheduler:
         # Check finish conditions
         logprob_pt = 0
         for i, req in enumerate(batch.reqs):
-            if not req.is_being_chunked:
-                # Being chunked reqs' prefill is not finished
+            if req.is_inflight_req > 0:
+                req.is_inflight_req -= 1
+            else:
+                # Inflight reqs' prefill is not finished
                 req.completion_tokens_wo_jump_forward += 1
                 req.output_ids.append(next_token_ids[i])
                 req.check_finished()
@@ -795,8 +812,10 @@ class Scheduler:
         # Check finish conditions
         for i, req in enumerate(batch.reqs):
             req.embedding = embeddings[i]
-            if not req.is_being_chunked:
-                # Being chunked reqs' prefill is not finished
+            if req.is_inflight_req > 0:
+                req.is_inflight_req -= 1
+            else:
+                # Inflight reqs' prefill is not finished
                 # dummy output token for embedding models
                 req.output_ids.append(0)
                 req.check_finished()
...
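
Note (illustration, not part of the diff): taken together, the restored Scheduler hunks treat is_inflight_req as a small counter rather than a boolean: it is incremented when the request is scheduled as the in-flight chunked request and decremented when that batch result is processed, and output tokens are only appended once it is back at zero. A compact sketch of that lifecycle with a stand-in request object (names and values are assumptions for the example):

    class FakeReq:
        """Stand-in for Req: tracks only the in-flight counter and outputs."""

        def __init__(self):
            self.is_inflight_req = 0
            self.output_ids = []


    def schedule_as_inflight(req):
        # Mirrors get_next_batch_to_run: the chunked request is marked in-flight.
        req.is_inflight_req += 1


    def process_prefill_result(req, next_token_id):
        # Mirrors process_batch_result: no token is emitted while prefill is unfinished.
        if req.is_inflight_req > 0:
            req.is_inflight_req -= 1
        else:
            req.output_ids.append(next_token_id)


    req = FakeReq()
    schedule_as_inflight(req)
    process_prefill_result(req, next_token_id=7)  # intermediate chunk: nothing emitted
    process_prefill_result(req, next_token_id=7)  # final chunk done: first token emitted
    print(req.output_ids)  # [7]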
@@ -663,7 +663,6 @@ def run_mmlu_test(
     chunked_prefill_size=32,
 ):
     other_args = ["--chunked-prefill-size", str(chunked_prefill_size)]
-    other_args += ["--mem-fraction-static", "0.85"]
     if disable_radix_cache:
         other_args += ["--disable-radix-cache"]
     if enable_mixed_chunk:
...
import os
import random
import unittest
import requests
from sglang.test.test_utils import (
DEFAULT_MODEL_NAME_FOR_TEST,
DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
DEFAULT_URL_FOR_TEST,
kill_child_process,
popen_launch_server,
)
def gen_radix_tree(num_nodes=400, chunk_len=256):
num0 = num_nodes // 2
num1 = num_nodes - num0
nodes = [{"input_ids": [37] * 117, "decode_len": 217}]
for _ in range(num0):
parent = random.choice(nodes)
unique_len = random.randint(0, chunk_len)
decode_len = random.randint(0, chunk_len)
token_id = random.randint(0, 32000)
child = {
"input_ids": parent["input_ids"] + [token_id] * unique_len,
"decode_len": decode_len,
}
nodes.append(child)
while num1 > 0:
num_branch = random.randint(1, min(num1, 10))
parent = random.choice(nodes)
for _ in range(num_branch):
unique_len = random.randint(0, chunk_len)
decode_len = random.randint(0, chunk_len)
token_id = random.randint(0, 32000)
child = {
"input_ids": parent["input_ids"] + [token_id] * unique_len,
"decode_len": decode_len,
}
nodes.append(child)
num1 -= num_branch
random.shuffle(nodes)
return nodes
def run_test(base_url, nodes):
data = {
"input_ids": [node["input_ids"] for node in nodes],
"sampling_params": [
{"max_new_tokens": node["decode_len"], "temperature": 0} for node in nodes
],
}
res = requests.post(base_url + "/generate", json=data)
assert res.status_code == 200
class TestRadixCacheFCFS(unittest.TestCase):
@classmethod
def setUpClass(cls):
cls.model = DEFAULT_MODEL_NAME_FOR_TEST
cls.base_url = DEFAULT_URL_FOR_TEST
cls.process = popen_launch_server(
cls.model,
cls.base_url,
timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
other_args=[
"--chunked-prefill-size",
"128",
"--max-total-tokens",
"20000",
"--schedule-policy",
"fcfs",
],
)
@classmethod
def tearDownClass(cls):
kill_child_process(cls.process.pid)
def test_radix_attention(self):
nodes = gen_radix_tree()
run_test(self.base_url, nodes)
class TestRadixCacheLPM(TestRadixCacheFCFS):
@classmethod
def setUpClass(cls):
cls.model = DEFAULT_MODEL_NAME_FOR_TEST
cls.base_url = DEFAULT_URL_FOR_TEST
cls.process = popen_launch_server(
cls.model,
cls.base_url,
timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
other_args=[
"--chunked-prefill-size",
"128",
"--max-total-tokens",
"20000",
"--schedule-policy",
"lpm",
],
)
if __name__ == "__main__":
os.environ["SGLANG_TEST_RETRACT"] = "true"
unittest.main()
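
Note (illustration, not part of the test): gen_radix_tree above builds requests whose input_ids share prefixes, since each child extends its parent's ids, which is what exercises the radix cache under chunked prefill. A tiny deterministic sketch of that shape with hypothetical values:

    # Deterministic miniature of the node shape produced by gen_radix_tree.
    root = {"input_ids": [37] * 117, "decode_len": 217}
    child = {"input_ids": root["input_ids"] + [1024] * 30, "decode_len": 16}

    shared = sum(a == b for a, b in zip(root["input_ids"], child["input_ids"]))
    print(shared)  # 117: the child reuses the parent's full prompt as a cached prefix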