Unverified Commit 56e19d7e authored by Wentao Ye, committed by GitHub
Browse files

[Model Runner V2] Fix flex attention kv blocks calculation issue (#39353)


Signed-off-by: yewentao256 <zhyanwentao@126.com>
parent 9036d4c4
...@@ -750,11 +750,11 @@ class FlexAttentionMetadataBuilder(AttentionMetadataBuilder[FlexAttentionMetadat ...@@ -750,11 +750,11 @@ class FlexAttentionMetadataBuilder(AttentionMetadataBuilder[FlexAttentionMetadat
self.max_model_len = self.model_config.max_model_len self.max_model_len = self.model_config.max_model_len
max_num_seqs = vllm_config.scheduler_config.max_num_seqs max_num_seqs = vllm_config.scheduler_config.max_num_seqs
max_num_batched_tokens = vllm_config.scheduler_config.max_num_batched_tokens max_num_batched_tokens = vllm_config.scheduler_config.max_num_batched_tokens
self.max_num_q_block = ( self.max_num_query_groups = cdiv(max_num_batched_tokens, self.q_block_size)
self.max_model_len + self.q_block_size - 1 max_num_pages_per_seq = cdiv(self.max_model_len, self.block_size)
) // self.q_block_size self.max_num_kv_indices = self.q_block_size * max_num_pages_per_seq
self.persistent_kv_num_blocks = torch.empty( self.persistent_kv_num_blocks = torch.empty(
self.max_num_q_block, dtype=torch.int32, device=device self.max_num_query_groups, dtype=torch.int32, device=device
) )
self.persistent_offset_tensor = torch.empty( self.persistent_offset_tensor = torch.empty(
max_num_seqs, dtype=torch.int32, device=device max_num_seqs, dtype=torch.int32, device=device
...@@ -828,12 +828,9 @@ class FlexAttentionMetadataBuilder(AttentionMetadataBuilder[FlexAttentionMetadat ...@@ -828,12 +828,9 @@ class FlexAttentionMetadataBuilder(AttentionMetadataBuilder[FlexAttentionMetadat
) )
if self.persistent_kv_indices is None: if self.persistent_kv_indices is None:
max_num_kv_block = (
self.max_model_len + self.kv_block_size - 1
) // self.kv_block_size
self.persistent_kv_indices = torch.empty( self.persistent_kv_indices = torch.empty(
self.max_model_len, self.max_num_query_groups,
max_num_kv_block, self.max_num_kv_indices,
dtype=torch.int32, dtype=torch.int32,
device=self.device, device=self.device,
) )
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment