Unverified commit 2b99f210, authored by Guofang.Tang and committed by GitHub
Browse files

[Misc] Fix typo: seperator -> separator in flashmla_sparse.py (#32411)


Signed-off-by: Guofang Tang <tinggofun@gmail.com>
Co-authored-by: Guofang Tang <tinggofun@gmail.com>
parent 1646fea6
...@@ -149,7 +149,7 @@ class FlashMLASparseMetadata(AttentionMetadata): ...@@ -149,7 +149,7 @@ class FlashMLASparseMetadata(AttentionMetadata):
cache_lens: torch.Tensor cache_lens: torch.Tensor
@dataclass @dataclass
class FP8SeperatePrefillDecode: class FP8SeparatePrefillDecode:
@dataclass @dataclass
class Decode: class Decode:
kernel_metadata: "FlashMLASparseMetadata.FP8KernelMetadata" kernel_metadata: "FlashMLASparseMetadata.FP8KernelMetadata"
...@@ -196,7 +196,7 @@ class FlashMLASparseMetadata(AttentionMetadata): ...@@ -196,7 +196,7 @@ class FlashMLASparseMetadata(AttentionMetadata):
decode: Decode | None = None decode: Decode | None = None
prefill: Prefill | None = None prefill: Prefill | None = None
fp8_extra_metadata: FP8SeperatePrefillDecode | FP8KernelMetadata | None = None fp8_extra_metadata: FP8SeparatePrefillDecode | FP8KernelMetadata | None = None
fp8_use_mixed_batch: bool = False fp8_use_mixed_batch: bool = False
...@@ -485,7 +485,7 @@ class FlashMLASparseMetadataBuilder(AttentionMetadataBuilder[FlashMLASparseMetad ...@@ -485,7 +485,7 @@ class FlashMLASparseMetadataBuilder(AttentionMetadataBuilder[FlashMLASparseMetad
def _build_fp8_separate_prefill_decode( def _build_fp8_separate_prefill_decode(
self, self,
common_attn_metadata: CommonAttentionMetadata, common_attn_metadata: CommonAttentionMetadata,
) -> "FlashMLASparseMetadata.FP8SeperatePrefillDecode": ) -> "FlashMLASparseMetadata.FP8SeparatePrefillDecode":
num_tokens = common_attn_metadata.num_actual_tokens num_tokens = common_attn_metadata.num_actual_tokens
(num_decodes, num_prefills, num_decode_tokens, num_prefill_tokens) = ( (num_decodes, num_prefills, num_decode_tokens, num_prefill_tokens) = (
...@@ -496,7 +496,7 @@ class FlashMLASparseMetadataBuilder(AttentionMetadataBuilder[FlashMLASparseMetad ...@@ -496,7 +496,7 @@ class FlashMLASparseMetadataBuilder(AttentionMetadataBuilder[FlashMLASparseMetad
) )
) )
FP8Meta = FlashMLASparseMetadata.FP8SeperatePrefillDecode FP8Meta = FlashMLASparseMetadata.FP8SeparatePrefillDecode
fp8_metadata = FP8Meta( fp8_metadata = FP8Meta(
num_decodes=num_decodes, num_decodes=num_decodes,
num_prefills=num_prefills, num_prefills=num_prefills,
...@@ -659,7 +659,7 @@ class FlashMLASparseMetadataBuilder(AttentionMetadataBuilder[FlashMLASparseMetad ...@@ -659,7 +659,7 @@ class FlashMLASparseMetadataBuilder(AttentionMetadataBuilder[FlashMLASparseMetad
req_id_per_token = self.req_id_per_token_buffer[:num_tokens] req_id_per_token = self.req_id_per_token_buffer[:num_tokens]
fp8_extra_metadata: ( fp8_extra_metadata: (
FlashMLASparseMetadata.FP8SeperatePrefillDecode FlashMLASparseMetadata.FP8SeparatePrefillDecode
| FlashMLASparseMetadata.FP8KernelMetadata | FlashMLASparseMetadata.FP8KernelMetadata
| None | None
) = None ) = None
...@@ -765,7 +765,7 @@ class FlashMLASparseImpl(MLACommonBaseImpl[FlashMLASparseMetadata]): ...@@ -765,7 +765,7 @@ class FlashMLASparseImpl(MLACommonBaseImpl[FlashMLASparseMetadata]):
attn_metadata: FlashMLASparseMetadata, attn_metadata: FlashMLASparseMetadata,
) -> torch.Tensor: ) -> torch.Tensor:
fp8_metadata = attn_metadata.fp8_extra_metadata fp8_metadata = attn_metadata.fp8_extra_metadata
assert isinstance(fp8_metadata, FlashMLASparseMetadata.FP8SeperatePrefillDecode) assert isinstance(fp8_metadata, FlashMLASparseMetadata.FP8SeparatePrefillDecode)
num_decodes = fp8_metadata.num_decodes num_decodes = fp8_metadata.num_decodes
prefill_request_ids = None prefill_request_ids = None
...@@ -794,7 +794,7 @@ class FlashMLASparseImpl(MLACommonBaseImpl[FlashMLASparseMetadata]): ...@@ -794,7 +794,7 @@ class FlashMLASparseImpl(MLACommonBaseImpl[FlashMLASparseMetadata]):
) )
fp8_metadata = attn_metadata.fp8_extra_metadata fp8_metadata = attn_metadata.fp8_extra_metadata
assert isinstance(fp8_metadata, FlashMLASparseMetadata.FP8SeperatePrefillDecode) assert isinstance(fp8_metadata, FlashMLASparseMetadata.FP8SeparatePrefillDecode)
def _fp8_decode(q: torch.Tensor, topk_indices: torch.Tensor) -> torch.Tensor: def _fp8_decode(q: torch.Tensor, topk_indices: torch.Tensor) -> torch.Tensor:
# Reshape q: (num_decode_tokens, num_heads, head_dim) # Reshape q: (num_decode_tokens, num_heads, head_dim)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment