Unverified commit bf37812c, authored by Harry Huang, committed by GitHub
Browse files

[Hybrid] Fix and optimize block-aligned splitting in mamba cache align mode (#33706)


Signed-off-by: huanghaoyan.hhy <huanghaoyan.hhy@alibaba-inc.com>
parent b86bf441
...@@ -281,27 +281,30 @@ class Scheduler(SchedulerInterface): ...@@ -281,27 +281,30 @@ class Scheduler(SchedulerInterface):
assert num_external_computed_tokens == 0, ( assert num_external_computed_tokens == 0, (
"External KV connector is not verified yet" "External KV connector is not verified yet"
) )
# TODO: need check for resume requests num_computed_tokens = (
if request.num_output_tokens == 0: # prefill request.num_computed_tokens
+ num_new_local_computed_tokens
+ num_external_computed_tokens
)
# Perform block-aligned splitting at prefill phase, including:
# * non-resumed requests: num_computed_tokens < num_prompt_tokens + 0
# * resumed requests: num_computed_tokens < (
# num_prompt_tokens + num_output_tokens
# )
# NOTE: Use `request.num_tokens - 1` to bypass normal decoding.
if num_computed_tokens < max(request.num_prompt_tokens, request.num_tokens - 1):
# To enable block-aligned caching of the Mamba state, `num_new_tokens` # To enable block-aligned caching of the Mamba state, `num_new_tokens`
# must be a multiple of `block_size`. # must be a multiple of `block_size`.
# As an exception, if `num_new_tokens` is less than `block_size`, the # As an exception, if `num_new_tokens` is less than `block_size`, the
# state is simply not cached, requiring no special handling. # state is simply not cached, requiring no special handling.
# Additionally, when Eagle mode is enabled, FullAttn prunes the last # Additionally, when Eagle mode is enabled, FullAttn prunes the last
# matching block. To prevent this from causing a Mamba cache miss, the # matching block. To prevent this from causing a Mamba cache miss, the
# last chunk must be larger than `block_size`. # last chunk must be not smaller than `block_size`.
block_size = self.cache_config.block_size block_size = self.cache_config.block_size
last_cache_position = ( last_cache_position = request.num_tokens - request.num_tokens % block_size
request.num_prompt_tokens - request.num_prompt_tokens % block_size
)
# eagle prune # eagle prune
if self.use_eagle: if self.use_eagle:
last_cache_position = max(last_cache_position - block_size, 0) last_cache_position = max(last_cache_position - block_size, 0)
num_computed_tokens = (
request.num_computed_tokens
+ num_new_local_computed_tokens
+ num_external_computed_tokens
)
num_computed_tokens_after_sched = num_computed_tokens + num_new_tokens num_computed_tokens_after_sched = num_computed_tokens + num_new_tokens
if num_computed_tokens_after_sched < last_cache_position: if num_computed_tokens_after_sched < last_cache_position:
# align to block_size # align to block_size
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment