Unverified commit 402759d4, authored by Lucas Wilkinson, committed by GitHub
Browse files

[Attention] FlashAttn MLA (#14258)


Signed-off-by: Lucas Wilkinson <lwilkinson@neuralmagic.com>
Signed-off-by: Lucas Wilkinson <lwilkins@redhat.com>
Signed-off-by: Matthew Bonanni <mbonanni001@gmail.com>
Co-authored-by: Matthew Bonanni <mbonanni001@gmail.com>
Co-authored-by: Matthew Bonanni <mbonanni@redhat.com>
parent 2c301ee2
...@@ -58,8 +58,9 @@ class ShortConvAttentionMetadataBuilder( ...@@ -58,8 +58,9 @@ class ShortConvAttentionMetadataBuilder(
state_indices_tensor = common_attn_metadata.block_table_tensor[:, 0] state_indices_tensor = common_attn_metadata.block_table_tensor[:, 0]
num_decodes, num_prefills, num_decode_tokens, num_prefill_tokens = ( num_decodes, num_prefills, num_decode_tokens, num_prefill_tokens = (
split_decodes_and_prefills(common_attn_metadata, split_decodes_and_prefills(
decode_threshold=1)) common_attn_metadata,
decode_threshold=self.reorder_batch_threshold))
has_initial_states = None has_initial_states = None
if num_prefills > 0: if num_prefills > 0:
#[batch,] #[batch,]
...@@ -78,4 +79,4 @@ class ShortConvAttentionMetadataBuilder( ...@@ -78,4 +79,4 @@ class ShortConvAttentionMetadataBuilder(
has_initial_states=has_initial_states, has_initial_states=has_initial_states,
state_indices_tensor=state_indices_tensor, state_indices_tensor=state_indices_tensor,
) )
return attn_metadata return attn_metadata
\ No newline at end of file
...@@ -3,7 +3,7 @@ ...@@ -3,7 +3,7 @@
"""Attention layer with XFormersAttention.""" """Attention layer with XFormersAttention."""
from dataclasses import dataclass from dataclasses import dataclass
from typing import TYPE_CHECKING, Optional from typing import TYPE_CHECKING, ClassVar, Optional
import torch import torch
...@@ -197,6 +197,8 @@ class XFormersAttentionMetadata: ...@@ -197,6 +197,8 @@ class XFormersAttentionMetadata:
class XFormersAttentionMetadataBuilder( class XFormersAttentionMetadataBuilder(
AttentionMetadataBuilder[XFormersAttentionMetadata]): AttentionMetadataBuilder[XFormersAttentionMetadata]):
reorder_batch_threshold: ClassVar[int] = 1
def __init__( def __init__(
self, self,
kv_cache_spec: AttentionSpec, kv_cache_spec: AttentionSpec,
...@@ -212,9 +214,10 @@ class XFormersAttentionMetadataBuilder( ...@@ -212,9 +214,10 @@ class XFormersAttentionMetadataBuilder(
def reorder_batch(self, input_batch: "InputBatch", def reorder_batch(self, input_batch: "InputBatch",
scheduler_output: "SchedulerOutput") -> bool: scheduler_output: "SchedulerOutput") -> bool:
return reorder_batch_to_split_decodes_and_prefills(input_batch, return reorder_batch_to_split_decodes_and_prefills(
scheduler_output, input_batch,
decode_threshold=1) scheduler_output,
decode_threshold=self.reorder_batch_threshold)
def build( def build(
self, self,
...@@ -223,8 +226,9 @@ class XFormersAttentionMetadataBuilder( ...@@ -223,8 +226,9 @@ class XFormersAttentionMetadataBuilder(
fast_build: bool = False, fast_build: bool = False,
) -> XFormersAttentionMetadata: ) -> XFormersAttentionMetadata:
num_decodes, num_prefills, num_decode_tokens, num_prefill_tokens = ( num_decodes, num_prefills, num_decode_tokens, num_prefill_tokens = (
split_decodes_and_prefills(common_attn_metadata, split_decodes_and_prefills(
decode_threshold=1)) common_attn_metadata,
decode_threshold=self.reorder_batch_threshold))
num_actual_tokens = common_attn_metadata.num_actual_tokens num_actual_tokens = common_attn_metadata.num_actual_tokens
q_start_loc = common_attn_metadata.query_start_loc q_start_loc = common_attn_metadata.query_start_loc
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment