"ssh:/git@developer.sourcefind.cn:2222/OpenDAS/vllm_cscc.git" did not exist on "cda10fa3e2bb69ea276d663e5369ba16ec42cebb"
Unverified Commit ea53ca5e authored by Benjamin Chislett, committed by GitHub
Browse files

[Bugfix] Fix block size used in EAGLE slot mapping (#31540)


Signed-off-by: Benjamin Chislett <bchislett@nvidia.com>
parent 27864a85
...@@ -71,7 +71,6 @@ class EagleProposer: ...@@ -71,7 +71,6 @@ class EagleProposer:
self.device = device self.device = device
self.dtype = vllm_config.model_config.dtype self.dtype = vllm_config.model_config.dtype
self.max_model_len = vllm_config.model_config.max_model_len self.max_model_len = vllm_config.model_config.max_model_len
self.block_size = vllm_config.cache_config.block_size
self.dp_rank = vllm_config.parallel_config.data_parallel_rank self.dp_rank = vllm_config.parallel_config.data_parallel_rank
self.num_speculative_tokens = self.speculative_config.num_speculative_tokens self.num_speculative_tokens = self.speculative_config.num_speculative_tokens
self.max_num_tokens = vllm_config.scheduler_config.max_num_batched_tokens self.max_num_tokens = vllm_config.scheduler_config.max_num_batched_tokens
...@@ -470,22 +469,23 @@ class EagleProposer: ...@@ -470,22 +469,23 @@ class EagleProposer:
common_attn_metadata._num_computed_tokens_cpu += 1 common_attn_metadata._num_computed_tokens_cpu += 1
# Compute the slot mapping. # Compute the slot mapping.
block_size = attn_metadata_builder.kv_cache_spec.block_size
if self.uses_mrope: if self.uses_mrope:
# all dimensions of positions are the same # all dimensions of positions are the same
block_numbers = clamped_positions[0] // self.block_size block_numbers = clamped_positions[0] // block_size
else: else:
block_numbers = clamped_positions // self.block_size block_numbers = clamped_positions // block_size
block_ids = common_attn_metadata.block_table_tensor.gather( block_ids = common_attn_metadata.block_table_tensor.gather(
dim=1, index=block_numbers.view(-1, 1) dim=1, index=block_numbers.view(-1, 1)
) )
block_ids = block_ids.view(-1) block_ids = block_ids.view(-1)
if self.uses_mrope: if self.uses_mrope:
common_attn_metadata.slot_mapping = ( common_attn_metadata.slot_mapping = (
block_ids * self.block_size + clamped_positions[0] % self.block_size block_ids * block_size + clamped_positions[0] % block_size
) )
else: else:
common_attn_metadata.slot_mapping = ( common_attn_metadata.slot_mapping = (
block_ids * self.block_size + clamped_positions % self.block_size block_ids * block_size + clamped_positions % block_size
) )
# Mask out the slot mappings that exceed the max model length. # Mask out the slot mappings that exceed the max model length.
# Otherwise, the KV cache will be inadvertently updated with the # Otherwise, the KV cache will be inadvertently updated with the
...@@ -800,12 +800,11 @@ class EagleProposer: ...@@ -800,12 +800,11 @@ class EagleProposer:
attn_metadata.seq_lens.masked_fill_(exceeds_max_model_len, 1) attn_metadata.seq_lens.masked_fill_(exceeds_max_model_len, 1)
# Compute the slot mapping. # Compute the slot mapping.
block_size = tree_attn_metadata_builder.kv_cache_spec.block_size
query_positions = flattened_draft_positions[:, level : level + query_len] query_positions = flattened_draft_positions[:, level : level + query_len]
block_numbers = query_positions // self.block_size block_numbers = query_positions // block_size
block_ids = attn_metadata.block_table.gather(dim=1, index=block_numbers) block_ids = attn_metadata.block_table.gather(dim=1, index=block_numbers)
slot_mapping = ( slot_mapping = block_ids * block_size + query_positions % block_size
block_ids * self.block_size + query_positions % self.block_size
)
# Mask out the slot mappings that exceed the max model length. # Mask out the slot mappings that exceed the max model length.
# Otherwise, the KV cache will be inadvertently updated with the # Otherwise, the KV cache will be inadvertently updated with the
# padding tokens. # padding tokens.
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment