Unverified Commit e8ee2a78 authored by Jiangyun Zhu's avatar Jiangyun Zhu Committed by GitHub
Browse files

[Attention] use diff kv backend for mimo v2 flash (#40045)


Signed-off-by: default avatarzjy0516 <riverclouds.zhu@qq.com>
parent 2ec18f5d
...@@ -172,7 +172,7 @@ Priority is **1 = highest** (tried first). ...@@ -172,7 +172,7 @@ Priority is **1 = highest** (tried first).
| `FLASHINFER` | TRTLLM† | fp16, bf16 | `auto`, `float16`, `bfloat16`, `fp8`, `fp8_e4m3`, `fp8_e5m2` | 16, 32, 64 | 64, 128, 256 | ✅ | ❌ | ✅ | Decoder | 10.x | | `FLASHINFER` | TRTLLM† | fp16, bf16 | `auto`, `float16`, `bfloat16`, `fp8`, `fp8_e4m3`, `fp8_e5m2` | 16, 32, 64 | 64, 128, 256 | ✅ | ❌ | ✅ | Decoder | 10.x |
| `FLASH_ATTN` | FA2* | fp16, bf16 | `auto`, `float16`, `bfloat16` | %16 | Any | ❌ | ❌ | ✅ | All | ≥8.0 | | `FLASH_ATTN` | FA2* | fp16, bf16 | `auto`, `float16`, `bfloat16` | %16 | Any | ❌ | ❌ | ✅ | All | ≥8.0 |
| `FLASH_ATTN` | FA3* | fp16, bf16 | `auto`, `float16`, `bfloat16`, `fp8`, `fp8_e4m3`, `fp8_e5m2` | %16 | Any | ✅ | ❌ | ✅ | All | 9.x | | `FLASH_ATTN` | FA3* | fp16, bf16 | `auto`, `float16`, `bfloat16`, `fp8`, `fp8_e4m3`, `fp8_e5m2` | %16 | Any | ✅ | ❌ | ✅ | All | 9.x |
| `FLASH_ATTN` | FA4* | fp16, bf16 | `auto`, `float16`, `bfloat16` | %16 | Any | | ❌ | ✅ | All | ≥10.0 | | `FLASH_ATTN` | FA4* | fp16, bf16 | `auto`, `float16`, `bfloat16` | %16 | Any | | ❌ | ✅ | All | ≥10.0 |
| `FLASH_ATTN_DIFFKV` | | fp16, bf16 | `auto` | Any | Any | ❌ | ❌ | ✅ | Decoder | Any | | `FLASH_ATTN_DIFFKV` | | fp16, bf16 | `auto` | Any | Any | ❌ | ❌ | ✅ | Decoder | Any |
| `FLEX_ATTENTION` | | fp16, bf16, fp32 | `auto`, `float16`, `bfloat16` | Any | Any | ❌ | ✅ | ❌ | Decoder, Encoder Only | Any | | `FLEX_ATTENTION` | | fp16, bf16, fp32 | `auto`, `float16`, `bfloat16` | Any | Any | ❌ | ✅ | ❌ | Decoder, Encoder Only | Any |
| `ROCM_AITER_FA` | | fp16, bf16 | `auto`, `float16`, `bfloat16`, `fp8`, `fp8_e4m3`, `fp8_e5m2` | 16, 32 | 64, 128, 256 | ❌ | ❌ | ❌ | Decoder | N/A | | `ROCM_AITER_FA` | | fp16, bf16 | `auto`, `float16`, `bfloat16`, `fp8`, `fp8_e4m3`, `fp8_e5m2` | 16, 32 | 64, 128, 256 | ❌ | ❌ | ❌ | Decoder | N/A |
......
...@@ -634,9 +634,10 @@ def parse_flash_attn_features() -> dict[str, dict[str, Any]]: ...@@ -634,9 +634,10 @@ def parse_flash_attn_features() -> dict[str, dict[str, Any]]:
except Exception: except Exception:
return {} return {}
# Analyze the functions to determine FA3-specific features # Analyze the functions to determine FA3/FA4-specific features
fa3_supports_fp8 = False fa3_supports_fp8 = False
fa3_supports_sinks = False fa3_supports_sinks = False
fa4_supports_sinks = False
fa3_compute_cap: str | None = None fa3_compute_cap: str | None = None
fa4_compute_cap: str | None = None fa4_compute_cap: str | None = None
...@@ -656,17 +657,49 @@ def parse_flash_attn_features() -> dict[str, dict[str, Any]]: ...@@ -656,17 +657,49 @@ def parse_flash_attn_features() -> dict[str, dict[str, Any]]:
fa3_supports_fp8 = True fa3_supports_fp8 = True
break break
# Check flash_attn_supports_sinks - looks for `get_flash_attn_version() == 3` # Check flash_attn_supports_sinks - looks for `fa_version == 3/4`
# or `get_flash_attn_version() == 3/4` (also accepts `in (3, 4)`)
if node.name == "flash_attn_supports_sinks": if node.name == "flash_attn_supports_sinks":
for n in ast.walk(node): for n in ast.walk(node):
if ( if (
isinstance(n, ast.Compare) isinstance(n, ast.Compare)
and isinstance(n.left, ast.Call) and len(n.ops) == 1
and isinstance(n.left.func, ast.Name) and isinstance(n.ops[0], ast.Eq)
and n.left.func.id == "get_flash_attn_version" and isinstance(n.comparators[0], ast.Constant)
): ):
fa3_supports_sinks = True is_version_compare = (
break isinstance(n.left, ast.Name) and n.left.id == "fa_version"
) or (
isinstance(n.left, ast.Call)
and isinstance(n.left.func, ast.Name)
and n.left.func.id == "get_flash_attn_version"
)
if is_version_compare:
val = n.comparators[0].value
if val == 3:
fa3_supports_sinks = True
elif val == 4:
fa4_supports_sinks = True
elif (
isinstance(n, ast.Compare)
and len(n.ops) == 1
and isinstance(n.ops[0], ast.In)
and isinstance(n.comparators[0], (ast.Tuple, ast.List, ast.Set))
):
is_version_compare = (
isinstance(n.left, ast.Name) and n.left.id == "fa_version"
) or (
isinstance(n.left, ast.Call)
and isinstance(n.left.func, ast.Name)
and n.left.func.id == "get_flash_attn_version"
)
if is_version_compare:
for elt in n.comparators[0].elts:
if isinstance(elt, ast.Constant):
if elt.value == 3:
fa3_supports_sinks = True
elif elt.value == 4:
fa4_supports_sinks = True
# Check get_flash_attn_version for FA3/FA4 compute capability # Check get_flash_attn_version for FA3/FA4 compute capability
if node.name == "get_flash_attn_version": if node.name == "get_flash_attn_version":
...@@ -731,7 +764,7 @@ def parse_flash_attn_features() -> dict[str, dict[str, Any]]: ...@@ -731,7 +764,7 @@ def parse_flash_attn_features() -> dict[str, dict[str, Any]]:
"fa4": { "fa4": {
"compute_capability": fa4_compute_cap, "compute_capability": fa4_compute_cap,
"supports_fp8": False, "supports_fp8": False,
"supports_sink": False, "supports_sink": fa4_supports_sinks,
}, },
} }
......
...@@ -597,6 +597,7 @@ class Attention(nn.Module, AttentionLayerBase): ...@@ -597,6 +597,7 @@ class Attention(nn.Module, AttentionLayerBase):
block_size=block_size, block_size=block_size,
num_kv_heads=self.num_kv_heads, num_kv_heads=self.num_kv_heads,
head_size=self.head_size, head_size=self.head_size,
head_size_v=self.head_size_v,
dtype=self.kv_cache_torch_dtype, dtype=self.kv_cache_torch_dtype,
kv_quant_mode=quant_mode, kv_quant_mode=quant_mode,
sliding_window=self.sliding_window, sliding_window=self.sliding_window,
......
...@@ -46,6 +46,9 @@ from vllm.model_executor.model_loader.weight_utils import ( ...@@ -46,6 +46,9 @@ from vllm.model_executor.model_loader.weight_utils import (
from vllm.model_executor.models.utils import sequence_parallel_chunk from vllm.model_executor.models.utils import sequence_parallel_chunk
from vllm.sequence import IntermediateTensors from vllm.sequence import IntermediateTensors
from vllm.v1.attention.backend import AttentionType from vllm.v1.attention.backend import AttentionType
from vllm.v1.attention.backends.flash_attn_diffkv import (
FlashAttentionDiffKVBackend,
)
from .interfaces import MixtureOfExperts, SupportsPP from .interfaces import MixtureOfExperts, SupportsPP
from .utils import ( from .utils import (
...@@ -287,6 +290,15 @@ class MiMoV2Attention(nn.Module): ...@@ -287,6 +290,15 @@ class MiMoV2Attention(nn.Module):
) )
sliding_window = sliding_window_size if sliding_window_size > -1 else None sliding_window = sliding_window_size if sliding_window_size > -1 else None
# Use DiffKV backend when V has a different head dim than K
if self.v_head_dim != self.head_dim:
FlashAttentionDiffKVBackend.set_head_size_v(self.v_head_dim)
attn_backend = FlashAttentionDiffKVBackend
logger.info_once("Using FlashAttentionDiffKVBackend for attention.")
else:
attn_backend = None
self.attn = Attention( self.attn = Attention(
self.num_heads, self.num_heads,
self.head_dim, self.head_dim,
...@@ -298,6 +310,8 @@ class MiMoV2Attention(nn.Module): ...@@ -298,6 +310,8 @@ class MiMoV2Attention(nn.Module):
attn_type=AttentionType.DECODER, attn_type=AttentionType.DECODER,
prefix=f"{prefix}.attn", prefix=f"{prefix}.attn",
sinks=self.attention_sink_bias, sinks=self.attention_sink_bias,
attn_backend=attn_backend,
head_size_v=self.v_head_dim,
) )
def forward( def forward(
...@@ -313,16 +327,8 @@ class MiMoV2Attention(nn.Module): ...@@ -313,16 +327,8 @@ class MiMoV2Attention(nn.Module):
if self.v_scale is not None: if self.v_scale is not None:
v = v * self.v_scale v = v * self.v_scale
v = v.view(-1, self.num_kv_heads, self.v_head_dim)
v = torch.nn.functional.pad(v, [0, self.head_dim - self.v_head_dim], value=0)
v = v.view(-1, self.num_kv_heads * self.head_dim)
attn_output = self.attn(q, k, v) attn_output = self.attn(q, k, v)
attn_output = attn_output.view(-1, self.num_heads, self.head_dim)[
..., : self.v_head_dim
].reshape(-1, self.num_heads * self.v_head_dim)
output, _ = self.o_proj(attn_output) output, _ = self.o_proj(attn_output)
return output return output
......
...@@ -54,7 +54,10 @@ elif current_platform.is_rocm(): ...@@ -54,7 +54,10 @@ elif current_platform.is_rocm():
def get_flash_attn_version( def get_flash_attn_version(
requires_alibi: bool = False, head_size: int | None = None requires_alibi: bool = False,
head_size: int | None = None,
head_size_v: int | None = None,
has_sinks: bool = False,
) -> int | None: ) -> int | None:
if current_platform.is_xpu(): if current_platform.is_xpu():
return 2 return 2
...@@ -112,6 +115,23 @@ def get_flash_attn_version( ...@@ -112,6 +115,23 @@ def get_flash_attn_version(
) )
fa_version = 2 fa_version = 2
# The FA3 kernel rejects s_aux (sinks) when hdim != hdim_v; upgrade to
# FA4 on SM90 when available.
if (
fa_version == 3
and has_sinks
and head_size is not None
and head_size_v is not None
and head_size != head_size_v
and device_capability.major == 9
and is_fa_version_supported(4)
):
logger.info_once(
"Diff-KV with sinks: upgrading FlashAttention 3 -> 4",
scope="local",
)
fa_version = 4
# FA4 currently uses batch-shape-dependent scheduling # FA4 currently uses batch-shape-dependent scheduling
# heuristics on SM100+, which breaks batch invariance. # heuristics on SM100+, which breaks batch invariance.
if envs.VLLM_BATCH_INVARIANT and fa_version == 4: if envs.VLLM_BATCH_INVARIANT and fa_version == 4:
...@@ -180,8 +200,7 @@ def flash_attn_supports_quant_query_input() -> bool: ...@@ -180,8 +200,7 @@ def flash_attn_supports_quant_query_input() -> bool:
def flash_attn_supports_sinks() -> bool: def flash_attn_supports_sinks() -> bool:
if current_platform.is_xpu(): if current_platform.is_xpu():
return True return True
else: return get_flash_attn_version() in (3, 4)
return get_flash_attn_version() == 3
def flash_attn_supports_mla(): def flash_attn_supports_mla():
......
...@@ -6,14 +6,16 @@ import torch ...@@ -6,14 +6,16 @@ import torch
from vllm.utils.torch_utils import is_quantized_kv_cache from vllm.utils.torch_utils import is_quantized_kv_cache
from vllm.v1.attention.backend import AttentionType from vllm.v1.attention.backend import AttentionType
from vllm.v1.attention.backends.fa_utils import is_flash_attn_varlen_func_available from vllm.v1.attention.backends.fa_utils import (
get_flash_attn_version,
is_flash_attn_varlen_func_available,
)
from vllm.v1.attention.ops.triton_reshape_and_cache_flash import ( from vllm.v1.attention.ops.triton_reshape_and_cache_flash import (
triton_reshape_and_cache_flash_diffkv, triton_reshape_and_cache_flash_diffkv,
) )
if is_flash_attn_varlen_func_available(): if is_flash_attn_varlen_func_available():
from vllm.v1.attention.backends.fa_utils import flash_attn_varlen_func from vllm.v1.attention.backends.fa_utils import flash_attn_varlen_func
from vllm.logger import init_logger
from vllm.v1.attention.backends.utils import get_kv_cache_layout from vllm.v1.attention.backends.utils import get_kv_cache_layout
from .flash_attn import ( from .flash_attn import (
...@@ -23,8 +25,6 @@ from .flash_attn import ( ...@@ -23,8 +25,6 @@ from .flash_attn import (
cascade_attention, cascade_attention,
) )
logger = init_logger(__name__)
class FlashAttentionDiffKVBackend(FlashAttentionBackend): class FlashAttentionDiffKVBackend(FlashAttentionBackend):
# Default to 128 for this backend # Default to 128 for this backend
...@@ -86,6 +86,20 @@ class FlashAttentionDiffKVBackend(FlashAttentionBackend): ...@@ -86,6 +86,20 @@ class FlashAttentionDiffKVBackend(FlashAttentionBackend):
class FlashAttentionDiffKVImpl(FlashAttentionImpl): class FlashAttentionDiffKVImpl(FlashAttentionImpl):
vllm_flash_attn_version: int | None
def __init__(self, *args, **kwargs) -> None:
super().__init__(*args, **kwargs)
# Re-derive the FA version with diff-kv context so that
# get_flash_attn_version can apply the FA3 -> FA4 upgrade rule
# for sinks + hdim != hdim_v.
self.vllm_flash_attn_version = get_flash_attn_version(
requires_alibi=self.alibi_slopes is not None,
head_size=self.head_size,
head_size_v=FlashAttentionDiffKVBackend.head_size_v,
has_sinks=self.sinks is not None,
)
def do_kv_cache_update( def do_kv_cache_update(
self, self,
layer: torch.nn.Module, layer: torch.nn.Module,
......
...@@ -356,6 +356,20 @@ class ChunkedLocalAttentionSpec(AttentionSpec): ...@@ -356,6 +356,20 @@ class ChunkedLocalAttentionSpec(AttentionSpec):
@dataclass(frozen=True, kw_only=True) @dataclass(frozen=True, kw_only=True)
class SlidingWindowSpec(AttentionSpec): class SlidingWindowSpec(AttentionSpec):
sliding_window: int sliding_window: int
head_size_v: int = None # type: ignore[assignment]
def __post_init__(self):
if self.head_size_v is None:
object.__setattr__(self, "head_size_v", self.head_size)
@property
def real_page_size_bytes(self) -> int:
return (
self.block_size
* self.num_kv_heads
* (self.head_size + self.head_size_v)
* get_dtype_size(self.dtype)
)
def max_memory_usage_bytes(self, vllm_config: VllmConfig) -> int: def max_memory_usage_bytes(self, vllm_config: VllmConfig) -> int:
assert vllm_config.parallel_config.decode_context_parallel_size == 1, ( assert vllm_config.parallel_config.decode_context_parallel_size == 1, (
......
...@@ -387,6 +387,7 @@ def flash_attn_varlen_func( ...@@ -387,6 +387,7 @@ def flash_attn_varlen_func(
num_splits=num_splits, num_splits=num_splits,
return_lse=return_softmax_lse, return_lse=return_softmax_lse,
out=out, out=out,
learnable_sink=s_aux,
) )
else: else:
raise ValueError(f"Unsupported FA version: {fa_version}") raise ValueError(f"Unsupported FA version: {fa_version}")
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment