Unverified Commit 7a8a46dd authored by Harry Huang's avatar Harry Huang Committed by GitHub
Browse files

[BugFix] Fix and optimize max_num_blocks_per_req calculation for MambaSpec (#34440)


Signed-off-by: default avatarhuanghaoyan.hhy <huanghaoyan.hhy@alibaba-inc.com>
parent bcf0731a
...@@ -5698,28 +5698,23 @@ class GPUModelRunner( ...@@ -5698,28 +5698,23 @@ class GPUModelRunner(
kv_cache_config: The KV cache configuration. kv_cache_config: The KV cache configuration.
kernel_block_sizes: The kernel block sizes for each KV cache group. kernel_block_sizes: The kernel block sizes for each KV cache group.
""" """
block_sizes = [ block_sizes = []
kv_cache_group.kv_cache_spec.block_size
for kv_cache_group in kv_cache_config.kv_cache_groups
if not isinstance(kv_cache_group.kv_cache_spec, EncoderOnlyAttentionSpec)
]
max_num_blocks = [] max_num_blocks = []
max_model_len = max(self.max_model_len, self.max_encoder_len) max_model_len = max(self.max_model_len, self.max_encoder_len)
for i, kv_cache_group in enumerate(kv_cache_config.kv_cache_groups): for kv_cache_group in kv_cache_config.kv_cache_groups:
if isinstance(kv_cache_group.kv_cache_spec, EncoderOnlyAttentionSpec): if isinstance(kv_cache_group.kv_cache_spec, EncoderOnlyAttentionSpec):
continue continue
block_size = kv_cache_group.kv_cache_spec.block_size
block_sizes.append(block_size)
max_num_blocks_per_req = cdiv( max_num_blocks_per_req = cdiv(
max_model_len, block_sizes[i] * get_total_cp_world_size() max_model_len, block_size * get_total_cp_world_size()
) )
if isinstance(kv_cache_group.kv_cache_spec, MambaSpec): if isinstance(kv_cache_group.kv_cache_spec, MambaSpec):
mamba_blocks_per_req = ( max_num_blocks_per_req = (
max_num_blocks_per_req max_num_blocks_per_req
if self.cache_config.enable_prefix_caching if self.cache_config.enable_prefix_caching
else 1 else 1
) + kv_cache_group.kv_cache_spec.num_speculative_blocks ) + kv_cache_group.kv_cache_spec.num_speculative_blocks
max_num_blocks_per_req = max(
max_num_blocks_per_req, mamba_blocks_per_req
)
max_num_blocks.append(max_num_blocks_per_req) max_num_blocks.append(max_num_blocks_per_req)
if block_sizes != [self.cache_config.block_size] or kernel_block_sizes != [ if block_sizes != [self.cache_config.block_size] or kernel_block_sizes != [
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment