"git@developer.sourcefind.cn:OpenDAS/dgl.git" did not exist on "80b22ad881b6be61e49179940599614f47724553"
Unverified commit 8af7048d, authored by Zhiqiang Xie, committed by GitHub
Browse files

Query remaining memory dynamically for PrefillAdder (#2941)

parent d3024f4f
...@@ -24,6 +24,7 @@ import torch ...@@ -24,6 +24,7 @@ import torch
from sglang.srt.managers.schedule_batch import Req, ScheduleBatch from sglang.srt.managers.schedule_batch import Req, ScheduleBatch
from sglang.srt.mem_cache.base_prefix_cache import BasePrefixCache from sglang.srt.mem_cache.base_prefix_cache import BasePrefixCache
from sglang.srt.mem_cache.memory_pool import BaseTokenToKVPool
from sglang.srt.mem_cache.radix_cache import RadixCache, TreeNode from sglang.srt.mem_cache.radix_cache import RadixCache, TreeNode
# Clip the estimation of max_new_tokens for the request whose max_new_tokens is very large. # Clip the estimation of max_new_tokens for the request whose max_new_tokens is very large.
...@@ -250,23 +251,24 @@ class PrefillAdder: ...@@ -250,23 +251,24 @@ class PrefillAdder:
def __init__( def __init__(
self, self,
tree_cache: BasePrefixCache, tree_cache: BasePrefixCache,
token_to_kv_pool: BaseTokenToKVPool,
running_batch: ScheduleBatch, running_batch: ScheduleBatch,
new_token_ratio: float, new_token_ratio: float,
rem_total_tokens: int,
rem_input_tokens: int, rem_input_tokens: int,
rem_chunk_tokens: Optional[int], rem_chunk_tokens: Optional[int],
mixed_with_decode_tokens: int = 0, mixed_with_decode_tokens: int = 0,
): ):
self.tree_cache = tree_cache self.tree_cache = tree_cache
self.token_to_kv_pool = token_to_kv_pool
self.running_batch = running_batch self.running_batch = running_batch
self.new_token_ratio = new_token_ratio self.new_token_ratio = new_token_ratio
self.rem_total_tokens = rem_total_tokens - mixed_with_decode_tokens
self.rem_input_tokens = rem_input_tokens - mixed_with_decode_tokens self.rem_input_tokens = rem_input_tokens - mixed_with_decode_tokens
self.rem_chunk_tokens = rem_chunk_tokens self.rem_chunk_tokens = rem_chunk_tokens
if self.rem_chunk_tokens is not None: if self.rem_chunk_tokens is not None:
self.rem_chunk_tokens -= mixed_with_decode_tokens self.rem_chunk_tokens -= mixed_with_decode_tokens
self.cur_rem_tokens = rem_total_tokens - mixed_with_decode_tokens self.rem_total_token_offset = mixed_with_decode_tokens
self.cur_rem_token_offset = mixed_with_decode_tokens
self.req_states = None self.req_states = None
self.can_run_list = [] self.can_run_list = []
...@@ -275,8 +277,7 @@ class PrefillAdder: ...@@ -275,8 +277,7 @@ class PrefillAdder:
self.log_input_tokens = 0 self.log_input_tokens = 0
if running_batch is not None: if running_batch is not None:
# Pre-remove the tokens which will be occupied by the running requests self.rem_total_token_offset += sum(
self.rem_total_tokens -= sum(
[ [
min( min(
(r.sampling_params.max_new_tokens - len(r.output_ids)), (r.sampling_params.max_new_tokens - len(r.output_ids)),
...@@ -287,6 +288,22 @@ class PrefillAdder: ...@@ -287,6 +288,22 @@ class PrefillAdder:
] ]
) )
@property
def rem_total_tokens(self):
return (
self.token_to_kv_pool.available_size()
+ self.tree_cache.evictable_size()
- self.rem_total_token_offset
)
@property
def cur_rem_tokens(self):
return (
self.token_to_kv_pool.available_size()
+ self.tree_cache.evictable_size()
- self.cur_rem_token_offset
)
def budget_state(self): def budget_state(self):
if self.rem_total_tokens <= 0 or self.cur_rem_tokens <= 0: if self.rem_total_tokens <= 0 or self.cur_rem_tokens <= 0:
return AddReqResult.NO_TOKEN return AddReqResult.NO_TOKEN
...@@ -301,8 +318,8 @@ class PrefillAdder: ...@@ -301,8 +318,8 @@ class PrefillAdder:
def _prefill_one_req( def _prefill_one_req(
self, prefix_len: int, extend_input_len: int, max_new_tokens: int self, prefix_len: int, extend_input_len: int, max_new_tokens: int
): ):
self.rem_total_tokens -= extend_input_len + max_new_tokens self.rem_total_token_offset += extend_input_len + max_new_tokens
self.cur_rem_tokens -= extend_input_len self.cur_rem_token_offset += extend_input_len
self.rem_input_tokens -= extend_input_len self.rem_input_tokens -= extend_input_len
if self.rem_chunk_tokens is not None: if self.rem_chunk_tokens is not None:
self.rem_chunk_tokens -= extend_input_len self.rem_chunk_tokens -= extend_input_len
...@@ -332,12 +349,10 @@ class PrefillAdder: ...@@ -332,12 +349,10 @@ class PrefillAdder:
@contextmanager @contextmanager
def _lock_node(self, last_node: TreeNode): def _lock_node(self, last_node: TreeNode):
try: try:
delta = self.tree_cache.inc_lock_ref(last_node) self.tree_cache.inc_lock_ref(last_node)
self.rem_total_tokens += delta
yield None yield None
finally: finally:
delta = self.tree_cache.dec_lock_ref(last_node) self.tree_cache.dec_lock_ref(last_node)
self.rem_total_tokens += delta
def add_one_req_ignore_eos(self, req: Req): def add_one_req_ignore_eos(self, req: Req):
def add_req_state(r, insert_sort=False): def add_req_state(r, insert_sort=False):
......
...@@ -891,9 +891,9 @@ class Scheduler: ...@@ -891,9 +891,9 @@ class Scheduler:
# Prefill policy # Prefill policy
adder = PrefillAdder( adder = PrefillAdder(
self.tree_cache, self.tree_cache,
self.token_to_kv_pool,
self.running_batch, self.running_batch,
self.new_token_ratio, self.new_token_ratio,
self.token_to_kv_pool.available_size() + self.tree_cache.evictable_size(),
self.max_prefill_tokens, self.max_prefill_tokens,
self.chunked_prefill_size, self.chunked_prefill_size,
running_bs if self.is_mixed_chunk else 0, running_bs if self.is_mixed_chunk else 0,
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment