Unverified commit 31a719bc, authored by Stig-Arne Grönroos and committed by GitHub
Browse files

[ROCm][perf] fix Aiter sparse MLA with MTP>1 (#37887)


Signed-off-by: Stig-Arne Grönroos <stig-arne.gronroos@amd.com>
Signed-off-by: Stig-Arne Grönroos <sgronroo@amd.com>
Co-authored-by: Matthew Bonanni <mbonanni@redhat.com>
parent 2e569756
...@@ -257,19 +257,19 @@ class DFlashProposer(SpecDecodeBaseProposer): ...@@ -257,19 +257,19 @@ class DFlashProposer(SpecDecodeBaseProposer):
) )
@override @override
def build_per_group_and_layer_attn_metadata(
    self, cad: CommonAttentionMetadata, draft_index: int = 0
) -> tuple[list[object], dict[str, object]]:
    """Build drafting attention metadata and require non-causal support.

    Delegates to the parent implementation, then checks that every
    per-layer metadata object exposes a ``causal`` attribute that is
    exactly ``False``.

    Args:
        cad: Common attention metadata forwarded to the parent builder.
        draft_index: Draft index forwarded to the parent builder.

    Returns:
        The ``(per_group, per_layer)`` pair produced by the parent,
        passed through unchanged.

    Raises:
        AssertionError: If a layer's metadata lacks ``causal`` or its
            value is anything other than ``False``.
    """
    built = super().build_per_group_and_layer_attn_metadata(cad, draft_index)
    group_metadata, layer_metadata = built

    for layer_name, layer_md in layer_metadata.items():
        # A missing attribute reads as None, so it fails the identity
        # check the same way as causal=True does.
        causal_flag = getattr(layer_md, "causal", None)
        assert causal_flag is False, (
            f"Attention metadata for layer {layer_name} does not have"
            " non-causal support, which is required for DFlash."
            " Consider using a different attention backend, such as FlashAttention."
        )

    return group_metadata, layer_metadata
@override @override
def _get_eagle3_use_aux_hidden_state_from_config(self): def _get_eagle3_use_aux_hidden_state_from_config(self):
......
...@@ -225,6 +225,9 @@ class SpecDecodeBaseProposer: ...@@ -225,6 +225,9 @@ class SpecDecodeBaseProposer:
# Determine allowed attention backends once during initialization. # Determine allowed attention backends once during initialization.
self.allowed_attn_types: tuple | None = None self.allowed_attn_types: tuple | None = None
if current_platform.is_rocm(): if current_platform.is_rocm():
from vllm.v1.attention.backends.mla.indexer import (
DeepseekV32IndexerMetadata,
)
from vllm.v1.attention.backends.mla.rocm_aiter_mla_sparse import ( from vllm.v1.attention.backends.mla.rocm_aiter_mla_sparse import (
ROCMAiterMLASparseMetadata, ROCMAiterMLASparseMetadata,
) )
...@@ -234,6 +237,7 @@ class SpecDecodeBaseProposer: ...@@ -234,6 +237,7 @@ class SpecDecodeBaseProposer:
TritonAttentionMetadata, TritonAttentionMetadata,
RocmAttentionMetadata, RocmAttentionMetadata,
ROCMAiterMLASparseMetadata, ROCMAiterMLASparseMetadata,
DeepseekV32IndexerMetadata,
] ]
# ROCM_AITER_FA is an optional backend # ROCM_AITER_FA is an optional backend
# We check is_enabled() here to avoid importing the backend module during # We check is_enabled() here to avoid importing the backend module during
...@@ -444,8 +448,8 @@ class SpecDecodeBaseProposer: ...@@ -444,8 +448,8 @@ class SpecDecodeBaseProposer:
) )
) )
per_layer_attn_metadata = self.build_per_layer_attn_metadata( per_group_attn_metadata, per_layer_attn_metadata = (
common_attn_metadata self.build_per_group_and_layer_attn_metadata(common_attn_metadata)
) )
cudagraph_runtime_mode, num_input_tokens, num_tokens_across_dp = ( cudagraph_runtime_mode, num_input_tokens, num_tokens_across_dp = (
...@@ -486,10 +490,7 @@ class SpecDecodeBaseProposer: ...@@ -486,10 +490,7 @@ class SpecDecodeBaseProposer:
positions = self.positions[token_indices_to_sample] positions = self.positions[token_indices_to_sample]
hidden_states = hidden_states[token_indices_to_sample] hidden_states = hidden_states[token_indices_to_sample]
if any( if any(isinstance(md, TreeAttentionMetadata) for md in per_group_attn_metadata):
isinstance(attn_metadata, TreeAttentionMetadata)
for attn_metadata in per_layer_attn_metadata.values()
):
# Draft using tree attention - requires full logits for top-k # Draft using tree attention - requires full logits for top-k
logits = self.model.compute_logits(sample_hidden_states) logits = self.model.compute_logits(sample_hidden_states)
draft_token_ids_list = self.propose_tree( draft_token_ids_list = self.propose_tree(
...@@ -505,16 +506,15 @@ class SpecDecodeBaseProposer: ...@@ -505,16 +506,15 @@ class SpecDecodeBaseProposer:
draft_token_ids = self._greedy_sample(sample_hidden_states) draft_token_ids = self._greedy_sample(sample_hidden_states)
for attn_metadata in per_layer_attn_metadata.values(): if self.allowed_attn_types is not None:
if self.allowed_attn_types is not None and not isinstance( for group_md in per_group_attn_metadata:
attn_metadata, self.allowed_attn_types if not isinstance(group_md, self.allowed_attn_types):
): raise ValueError(
raise ValueError( f"Unsupported attention metadata type for speculative "
f"Unsupported attention metadata type for speculative " "decoding with num_speculative_tokens > 1: "
"decoding with num_speculative_tokens > 1: " f"{type(group_md)}. Supported types are: "
f"{type(attn_metadata)}. Supported types are: " f"{self.allowed_attn_types}"
f"{self.allowed_attn_types}" )
)
# Generate the remaining draft tokens. # Generate the remaining draft tokens.
draft_token_ids_list = [draft_token_ids] draft_token_ids_list = [draft_token_ids]
...@@ -595,7 +595,7 @@ class SpecDecodeBaseProposer: ...@@ -595,7 +595,7 @@ class SpecDecodeBaseProposer:
common_attn_metadata._num_computed_tokens_cpu += 1 common_attn_metadata._num_computed_tokens_cpu += 1
# Rebuild attention metadata # Rebuild attention metadata
per_layer_attn_metadata = self.build_per_layer_attn_metadata( _, per_layer_attn_metadata = self.build_per_group_and_layer_attn_metadata(
common_attn_metadata, draft_index=token_index + 1 common_attn_metadata, draft_index=token_index + 1
) )
...@@ -809,17 +809,19 @@ class SpecDecodeBaseProposer: ...@@ -809,17 +809,19 @@ class SpecDecodeBaseProposer:
return model_kwargs, num_input_tokens return model_kwargs, num_input_tokens
def build_per_group_and_layer_attn_metadata(
    self, common_attn_metadata: CommonAttentionMetadata, draft_index: int = 0
) -> tuple[list[object], dict[str, object]]:
    """Build drafting attention metadata once per draft attention group.

    For each entry in ``self.draft_attn_groups`` the group's metadata
    builder is asked to ``build_for_drafting``; the resulting object is
    recorded once for the group and shared by every layer name the
    group lists.

    Args:
        common_attn_metadata: Passed through to each builder.
        draft_index: Passed through to each builder.

    Returns:
        A pair ``(per_group, per_layer)``: a list with one metadata
        object per attention group, in group order, and a dict mapping
        each layer name to its group's metadata object.
    """
    group_metadata: list[object] = []
    layer_metadata: dict[str, object] = {}

    for group in self.draft_attn_groups:
        builder = group.get_metadata_builder()
        built = builder.build_for_drafting(
            common_attn_metadata=common_attn_metadata,
            draft_index=draft_index,
        )
        group_metadata.append(built)
        # Every layer in the group points at the same metadata object.
        layer_metadata.update(dict.fromkeys(group.layer_names, built))

    return group_metadata, layer_metadata
def model_returns_tuple(self) -> bool: def model_returns_tuple(self) -> bool:
return self.method not in ("mtp", "draft_model", "dflash") return self.method not in ("mtp", "draft_model", "dflash")
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment