"src/array/vscode:/vscode.git/clone" did not exist on "2f41fcd986a26661090e5eff6bafe212e354a690"
Unverified commit 8af7048d, authored by Zhiqiang Xie, committed by GitHub
Browse files

Query remaining memory dynamically for PrefillAdder (#2941)

parent d3024f4f
......@@ -24,6 +24,7 @@ import torch
from sglang.srt.managers.schedule_batch import Req, ScheduleBatch
from sglang.srt.mem_cache.base_prefix_cache import BasePrefixCache
from sglang.srt.mem_cache.memory_pool import BaseTokenToKVPool
from sglang.srt.mem_cache.radix_cache import RadixCache, TreeNode
# Clip the estimation of max_new_tokens for the request whose max_new_tokens is very large.
......@@ -250,23 +251,24 @@ class PrefillAdder:
def __init__(
self,
tree_cache: BasePrefixCache,
token_to_kv_pool: BaseTokenToKVPool,
running_batch: ScheduleBatch,
new_token_ratio: float,
rem_total_tokens: int,
rem_input_tokens: int,
rem_chunk_tokens: Optional[int],
mixed_with_decode_tokens: int = 0,
):
self.tree_cache = tree_cache
self.token_to_kv_pool = token_to_kv_pool
self.running_batch = running_batch
self.new_token_ratio = new_token_ratio
self.rem_total_tokens = rem_total_tokens - mixed_with_decode_tokens
self.rem_input_tokens = rem_input_tokens - mixed_with_decode_tokens
self.rem_chunk_tokens = rem_chunk_tokens
if self.rem_chunk_tokens is not None:
self.rem_chunk_tokens -= mixed_with_decode_tokens
self.cur_rem_tokens = rem_total_tokens - mixed_with_decode_tokens
self.rem_total_token_offset = mixed_with_decode_tokens
self.cur_rem_token_offset = mixed_with_decode_tokens
self.req_states = None
self.can_run_list = []
......@@ -275,8 +277,7 @@ class PrefillAdder:
self.log_input_tokens = 0
if running_batch is not None:
# Pre-remove the tokens which will be occupied by the running requests
self.rem_total_tokens -= sum(
self.rem_total_token_offset += sum(
[
min(
(r.sampling_params.max_new_tokens - len(r.output_ids)),
......@@ -287,6 +288,22 @@ class PrefillAdder:
]
)
# Total token budget still available to this adder.
# Recomputed on every access from live state: the KV pool's available_size()
# plus the tree cache's evictable_size(), minus rem_total_token_offset (the
# amount this adder has already reserved).
# NOTE(review): no setter is visible in this view, so statements elsewhere in
# this diff that assign to `self.rem_total_tokens` are presumably the removed
# (pre-change) lines -- confirm against the applied file.
@property
def rem_total_tokens(self):
return (
self.token_to_kv_pool.available_size()
+ self.tree_cache.evictable_size()
- self.rem_total_token_offset
)
# Token budget remaining for the current step.
# Same live base as rem_total_tokens (KV pool available_size() plus tree cache
# evictable_size()), but reduced by cur_rem_token_offset instead.
# NOTE(review): no setter is visible in this view, so statements elsewhere in
# this diff that assign to `self.cur_rem_tokens` are presumably the removed
# (pre-change) lines -- confirm against the applied file.
@property
def cur_rem_tokens(self):
return (
self.token_to_kv_pool.available_size()
+ self.tree_cache.evictable_size()
- self.cur_rem_token_offset
)
def budget_state(self):
if self.rem_total_tokens <= 0 or self.cur_rem_tokens <= 0:
return AddReqResult.NO_TOKEN
......@@ -301,8 +318,8 @@ class PrefillAdder:
def _prefill_one_req(
self, prefix_len: int, extend_input_len: int, max_new_tokens: int
):
self.rem_total_tokens -= extend_input_len + max_new_tokens
self.cur_rem_tokens -= extend_input_len
self.rem_total_token_offset += extend_input_len + max_new_tokens
self.cur_rem_token_offset += extend_input_len
self.rem_input_tokens -= extend_input_len
if self.rem_chunk_tokens is not None:
self.rem_chunk_tokens -= extend_input_len
......@@ -332,12 +349,10 @@ class PrefillAdder:
# Context manager that takes a lock reference on `last_node` via
# tree_cache.inc_lock_ref for the duration of the `with` body, and releases it
# with tree_cache.dec_lock_ref in `finally`, so the release also runs when the
# body raises.
# NOTE(review): this view is a flattened diff with the +/- markers stripped.
# The `delta = ...` / `self.rem_total_tokens += delta` pairs appear to be the
# pre-change lines, and the bare inc_lock_ref / dec_lock_ref calls appear to be
# their replacements -- confirm against the applied file before reading this
# as a single function body.
@contextmanager
def _lock_node(self, last_node: TreeNode):
try:
delta = self.tree_cache.inc_lock_ref(last_node)
self.rem_total_tokens += delta
self.tree_cache.inc_lock_ref(last_node)
yield None
finally:
delta = self.tree_cache.dec_lock_ref(last_node)
self.rem_total_tokens += delta
self.tree_cache.dec_lock_ref(last_node)
def add_one_req_ignore_eos(self, req: Req):
def add_req_state(r, insert_sort=False):
......
......@@ -891,9 +891,9 @@ class Scheduler:
# Prefill policy
adder = PrefillAdder(
self.tree_cache,
self.token_to_kv_pool,
self.running_batch,
self.new_token_ratio,
self.token_to_kv_pool.available_size() + self.tree_cache.evictable_size(),
self.max_prefill_tokens,
self.chunked_prefill_size,
running_bs if self.is_mixed_chunk else 0,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment