"src/runtime/cuda/cuda_hashtable.hip" did not exist on "bc3a532f5e7cd1d8e25cc1e30d1d086fa0648c31"
Unverified Commit f8e46093 authored by Zhiqiang Xie, committed by GitHub

Fix prefill OOM error in the case of large page size (#5081)

parent 683707c3
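Context for the change below (our reading of the patch, not part of the commit message): with a paged KV cache the allocator reserves whole pages per request, so once `page_size > 1`, charging the prefill budget with the raw `extend_input_len` undercounts the memory that will actually be reserved. A minimal arithmetic sketch with made-up numbers:

```python
page_size = 64                # hypothetical large page size
extend_input_len = 65         # tokens this request adds during prefill

old_charge = extend_input_len                               # 65
new_charge = -(-extend_input_len // page_size) * page_size  # 128

# The request really occupies two whole pages (128 slots). Charging only 65
# tokens against the remaining budget lets the scheduler over-admit prefill
# work, which can then surface as the prefill OOM named in the commit title.
print(old_charge, new_charge)
```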
@@ -455,7 +455,10 @@ class PrefillAdder:
         total_tokens = req.extend_input_len + min(
             req.sampling_params.max_new_tokens, CLIP_MAX_NEW_TOKENS_ESTIMATION
         )
-        input_tokens = req.extend_input_len
+        input_tokens = (
+            -(-req.extend_input_len // self.tree_cache.page_size)
+            * self.tree_cache.page_size
+        )
         prefix_len = len(req.prefix_indices)

         if total_tokens >= self.rem_total_tokens:
@@ -477,7 +480,10 @@ class PrefillAdder:
                 req.last_node_global, req.prefix_indices
             )
             req.extend_input_len = len(req.fill_ids) - len(req.prefix_indices)
-            input_tokens = req.extend_input_len
+            input_tokens = (
+                -(-req.extend_input_len // self.tree_cache.page_size)
+                * self.tree_cache.page_size
+            )
             prefix_len = len(req.prefix_indices)

         if self.rem_chunk_tokens is None or input_tokens <= self.rem_chunk_tokens:
...
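The rounding idiom used in both hunks above, `-(-x // p) * p`, is integer ceiling division scaled back up by the page size. A standalone sketch (the helper name `round_up_to_page_multiple` is ours, not from the patch):

```python
import math


def round_up_to_page_multiple(extend_input_len: int, page_size: int) -> int:
    # -(-x // p) is ceiling division without floating point:
    # negate, floor-divide, negate back.
    return -(-extend_input_len // page_size) * page_size


assert round_up_to_page_multiple(65, 64) == 128  # 65 tokens -> 2 pages of 64
assert round_up_to_page_multiple(64, 64) == 64   # exact multiples are unchanged
assert round_up_to_page_multiple(65, 64) == math.ceil(65 / 64) * 64
```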
@@ -502,6 +502,7 @@ class Scheduler(
             self.tree_cache = ChunkCache(
                 req_to_token_pool=self.req_to_token_pool,
                 token_to_kv_pool_allocator=self.token_to_kv_pool_allocator,
+                page_size=self.page_size,
             )
         else:
             if self.enable_hierarchical_cache:
...
@@ -24,9 +24,11 @@ class ChunkCache(BasePrefixCache):
         self,
         req_to_token_pool: ReqToTokenPool,
         token_to_kv_pool_allocator: TokenToKVPoolAllocator,
+        page_size: int,
     ):
         self.req_to_token_pool = req_to_token_pool
         self.token_to_kv_pool_allocator = token_to_kv_pool_allocator
+        self.page_size = page_size

     def reset(self):
         pass
...
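Taken together, the three hunks thread `page_size` from the scheduler into `ChunkCache` so that `PrefillAdder`'s page-aligned accounting also works when the chunk cache is in use. A stripped-down sketch of that flow; the stub class names and the `fits_in_chunk` helper are ours for illustration, not sglang APIs:

```python
class StubChunkCache:
    """Illustrative stand-in for ChunkCache: only what the budget check needs."""

    def __init__(self, page_size: int):
        self.page_size = page_size


class StubPrefillAdder:
    """Illustrative stand-in for PrefillAdder's chunk-budget check."""

    def __init__(self, tree_cache: StubChunkCache, rem_chunk_tokens: int):
        self.tree_cache = tree_cache
        self.rem_chunk_tokens = rem_chunk_tokens

    def fits_in_chunk(self, extend_input_len: int) -> bool:
        # Same page-aligned accounting as the patch above.
        input_tokens = (
            -(-extend_input_len // self.tree_cache.page_size)
            * self.tree_cache.page_size
        )
        return input_tokens <= self.rem_chunk_tokens


adder = StubPrefillAdder(StubChunkCache(page_size=64), rem_chunk_tokens=100)
print(adder.fits_in_chunk(65))  # False: rounds up to 128 > 100
print(adder.fits_in_chunk(30))  # True: rounds up to 64 <= 100
```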