Unverified Commit f8e46093 authored by Zhiqiang Xie's avatar Zhiqiang Xie Committed by GitHub
Browse files

Fix prefill OOM error in the case of large page size (#5081)

parent 683707c3
...@@ -455,7 +455,10 @@ class PrefillAdder: ...@@ -455,7 +455,10 @@ class PrefillAdder:
total_tokens = req.extend_input_len + min( total_tokens = req.extend_input_len + min(
req.sampling_params.max_new_tokens, CLIP_MAX_NEW_TOKENS_ESTIMATION req.sampling_params.max_new_tokens, CLIP_MAX_NEW_TOKENS_ESTIMATION
) )
input_tokens = req.extend_input_len input_tokens = (
-(-req.extend_input_len // self.tree_cache.page_size)
* self.tree_cache.page_size
)
prefix_len = len(req.prefix_indices) prefix_len = len(req.prefix_indices)
if total_tokens >= self.rem_total_tokens: if total_tokens >= self.rem_total_tokens:
...@@ -477,7 +480,10 @@ class PrefillAdder: ...@@ -477,7 +480,10 @@ class PrefillAdder:
req.last_node_global, req.prefix_indices req.last_node_global, req.prefix_indices
) )
req.extend_input_len = len(req.fill_ids) - len(req.prefix_indices) req.extend_input_len = len(req.fill_ids) - len(req.prefix_indices)
input_tokens = req.extend_input_len input_tokens = (
-(-req.extend_input_len // self.tree_cache.page_size)
* self.tree_cache.page_size
)
prefix_len = len(req.prefix_indices) prefix_len = len(req.prefix_indices)
if self.rem_chunk_tokens is None or input_tokens <= self.rem_chunk_tokens: if self.rem_chunk_tokens is None or input_tokens <= self.rem_chunk_tokens:
......
...@@ -502,6 +502,7 @@ class Scheduler( ...@@ -502,6 +502,7 @@ class Scheduler(
self.tree_cache = ChunkCache( self.tree_cache = ChunkCache(
req_to_token_pool=self.req_to_token_pool, req_to_token_pool=self.req_to_token_pool,
token_to_kv_pool_allocator=self.token_to_kv_pool_allocator, token_to_kv_pool_allocator=self.token_to_kv_pool_allocator,
page_size=self.page_size,
) )
else: else:
if self.enable_hierarchical_cache: if self.enable_hierarchical_cache:
......
...@@ -24,9 +24,11 @@ class ChunkCache(BasePrefixCache): ...@@ -24,9 +24,11 @@ class ChunkCache(BasePrefixCache):
self, self,
req_to_token_pool: ReqToTokenPool, req_to_token_pool: ReqToTokenPool,
token_to_kv_pool_allocator: TokenToKVPoolAllocator, token_to_kv_pool_allocator: TokenToKVPoolAllocator,
page_size: int,
): ):
self.req_to_token_pool = req_to_token_pool self.req_to_token_pool = req_to_token_pool
self.token_to_kv_pool_allocator = token_to_kv_pool_allocator self.token_to_kv_pool_allocator = token_to_kv_pool_allocator
self.page_size = page_size
def reset(self): def reset(self):
pass pass
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment