Commit 54e03934 authored by wangmin6's avatar wangmin6
Browse files

Merge branch 'v0.15.1-dev_lightop_moe_sum_mul_add' into 'v0.15.1-dev'

feat(v1 attention): 为 ROCm FlashAttention 接入 unified kv layout,并打通 mm_prefix、qq_bias 与 use_alibi_sqrt 透传

See merge request dcutoolkit/deeplearing/vllm!526
parents b81573da ee989f6d
......@@ -245,6 +245,7 @@ class Attention(nn.Module, AttentionLayerBase):
use_mla=False,
has_sink=self.has_sink,
use_mm_prefix=self.use_mm_prefix,
use_alibi_sqrt=bool(use_alibi_sqrt),
attn_type=attn_type,
)
else:
......
......@@ -319,7 +319,6 @@ if TYPE_CHECKING:
VLLM_USE_CUDA_GRAPH_SIZES: bool = False
VLLM_USE_LIGHTOP_MOE_SUM_MUL_ADD: bool = False
VLLM_USE_LIGHTOP_FUSED_TOPP_TOPK: bool = False
VLLM_V1_USE_FA_UNIFIED_ATTN_2D: bool = False
VLLM_ENABLE_RAY_ASYNC_SCHEDULING: bool = False
USE_LIGHTOP_PER_TOKEN_GROUP_QUANT_FP8: bool = False
USE_LIGHTOP_TOPK: bool = False
......
......@@ -262,6 +262,7 @@ class RocmPlatform(Platform):
from vllm._aiter_ops import rocm_aiter_ops
block_size = attn_selector_config.block_size
head_size = attn_selector_config.head_size
kv_cache_dtype = attn_selector_config.kv_cache_dtype
if attn_selector_config.use_sparse:
......@@ -304,9 +305,36 @@ class RocmPlatform(Platform):
f"is not MLA type while requested for MLA backend."
)
if envs.VLLM_USE_FLASH_ATTN_PA and block_size == 64:
logger.info_once("Using Flash Attention backend on V1 engine. (only supports block size 64)")
is_non64_block_multiple_64 = (
block_size != 64
and block_size % 64 == 0
)
use_unified_flash = (
is_non64_block_multiple_64
and head_size == 256
)
if (
envs.VLLM_USE_FLASH_ATTN_PA
and is_non64_block_multiple_64
and head_size != 256
):
logger.info_once(
"Skip unified varlen kernel on V1 engine: head size %d is "
"unsupported (requires 256).",
head_size,
)
if envs.VLLM_USE_FLASH_ATTN_PA and (block_size == 64 or use_unified_flash):
if use_unified_flash and block_size != 64:
logger.info_once(
"Using Flash Attention backend with unified varlen kernel on "
"V1 engine. (block size %d, requires block size divisible by 64)",
block_size,
)
else:
logger.info_once(
"Using Flash Attention backend on V1 engine. "
"(only supports block size 64)"
)
return AttentionBackendEnum.FLASH_ATTN.get_path()
else:
os.environ['VLLM_USE_FLASH_ATTN_PA'] = '0'
......
......@@ -225,6 +225,7 @@ class AttentionBackend(ABC):
has_sink: bool,
use_sparse: bool,
use_mm_prefix: bool,
use_alibi_sqrt: bool,
device_capability: "DeviceCapability",
attn_type: str,
) -> list[str]:
......@@ -241,6 +242,8 @@ class AttentionBackend(ABC):
invalid_reasons.append(
"partial multimodal token full attention not supported"
)
if use_alibi_sqrt and not cls.supports_alibi_sqrt():
invalid_reasons.append("use_alibi_sqrt not supported")
if use_mla != cls.is_mla():
if use_mla:
invalid_reasons.append("MLA not supported")
......
......@@ -33,6 +33,13 @@ if is_flash_attn_varlen_func_available():
vllm_flash_attn_varlen_func,
reshape_and_cache_cuda,
)
from vllm.v1.attention.ops.triton_reshape_and_cache_flash import (
triton_reshape_and_cache_flash,
)
try:
from flash_attn import varlen_fwd_unified
except Exception:
varlen_fwd_unified = None
else:
from vllm.v1.attention.backends.fa_utils import (
flash_attn_supports_sinks,
......@@ -113,6 +120,38 @@ class FlashAttentionBackend(AttentionBackend):
def get_builder_cls() -> type["FlashAttentionMetadataBuilder"]:
return FlashAttentionMetadataBuilder
@classmethod
def supports_alibi_sqrt(cls) -> bool:
return True
@classmethod
def supports_mm_prefix(cls) -> bool:
return True
@staticmethod
def _use_rocm_unified_kv_layout(
block_size: int | None = None,
key_cache: torch.Tensor | None = None,
value_cache: torch.Tensor | None = None,
) -> bool:
if not current_platform.is_rocm():
return False
if block_size is None:
if key_cache is not None and value_cache is not None:
if key_cache.ndim != 4 or value_cache.ndim != 4:
return False
if key_cache.shape != value_cache.shape:
return False
block_size = key_cache.shape[1]
else:
try:
block_size = get_current_vllm_config().cache_config.block_size
except Exception:
return False
return block_size is not None and block_size != 64 and block_size % 64 == 0
if current_platform.is_rocm():
@staticmethod
def get_kv_cache_shape(
......@@ -124,6 +163,9 @@ class FlashAttentionBackend(AttentionBackend):
) -> tuple[tuple[int, ...], tuple[int, ...]]:
if block_size % 16 != 0:
raise ValueError("Block size must be a multiple of 16.")
if FlashAttentionBackend._use_rocm_unified_kv_layout(block_size):
unified_shape = (num_blocks, block_size, num_kv_heads, head_size)
return (unified_shape, unified_shape)
return (
(num_blocks, num_kv_heads, block_size, head_size),
(num_blocks, num_kv_heads, head_size, block_size),
......@@ -136,6 +178,17 @@ class FlashAttentionBackend(AttentionBackend):
# `stride_order` indicates the permutation that gets
# us from `get_kv_cache_shape` to the actual memory layout we want.
cache_layout = get_kv_cache_layout()
if FlashAttentionBackend._use_rocm_unified_kv_layout():
if cache_layout != "NHD":
raise RuntimeError(
"ROCm unified KV layout currently supports NHD only."
)
if include_num_layers_dimension:
# (num_blocks, num_layers, block_size, num_kv_heads, head_size)
return (1, 0, 2, 3, 4), (1, 0, 2, 3, 4)
key_stride_order = (0, 1, 2, 3)
value_stride_order = (0, 1, 2, 3)
else:
if cache_layout == "NHD" and include_num_layers_dimension:
# (num_blocks, num_layers, block_size, num_kv_heads, head_size)
return (1, 0, 3, 2, 5), (1, 0, 4, 2, 3)
......@@ -271,8 +324,34 @@ class FlashAttentionMetadata:
prefix_scheduler_metadata: torch.Tensor | None = None
max_num_splits: int = 0
mm_prefix_range: dict[int, list[tuple[int, int]]] | None = None
qq_bias: torch.Tensor | None = None
causal: bool = True
@property
def mm_prefix_range_tensor(self) -> torch.Tensor | None:
if self.mm_prefix_range is None:
return None
num_seqs = self.seq_lens.shape[0]
device = self.seq_lens.device
range_lists = [
self.mm_prefix_range.get(i, [(0, 0)]) or [(0, 0)]
for i in range(num_seqs)
]
if all(r == [(0, 0)] for r in range_lists):
return None
range_tensors = [
torch.tensor(r, dtype=torch.int32, device=device).view(-1, 2)
for r in range_lists
]
return torch.nested.nested_tensor(
range_tensors, layout=torch.jagged
).to_padded_tensor(0)
def _get_sliding_window_configs(
vllm_config: VllmConfig,
......@@ -597,6 +676,7 @@ class FlashAttentionImpl(AttentionImpl):
attn_type: AttentionType = AttentionType.DECODER,
kv_sharing_target_layer_name: str | None = None,
sinks: torch.Tensor | None = None,
use_alibi_sqrt: bool = False,
) -> None:
self.num_heads = num_heads
self.head_size = head_size
......@@ -622,6 +702,7 @@ class FlashAttentionImpl(AttentionImpl):
self.attn_type = attn_type
self.vllm_flash_attn_version = get_flash_attn_version()
self.use_alibi_sqrt = use_alibi_sqrt
# Cache the batch invariant result for use in forward passes
self.batch_invariant_enabled = vllm_is_batch_invariant()
......@@ -648,6 +729,14 @@ class FlashAttentionImpl(AttentionImpl):
else False
)
def _get_unified_extras(
self,
attn_metadata: FlashAttentionMetadata,
) -> tuple[torch.Tensor | None, torch.Tensor | None]:
mm_prefix_range_tensor = attn_metadata.mm_prefix_range_tensor
qq_bias = attn_metadata.qq_bias
return mm_prefix_range_tensor, qq_bias
def forward(
self,
layer: torch.nn.Module,
......@@ -774,6 +863,36 @@ class FlashAttentionImpl(AttentionImpl):
print(f"q.shape = {query[:num_actual_tokens].shape}, key_cache.shape = {key_cache.shape}, value_cache.shape = {value_cache.shape}")
print(f"cu_seqlens_q.shape = {cu_seqlens_q.shape}, max_seqlen_q = {max_seqlen_q}, seqused_k.shape = {seqused_k.shape}, max_seqlen_k = {max_seqlen_k}")
print(f"softmax_scale = {self.scale:.3f}, alibi_slopes = {self.alibi_slopes}, window_size = {self.sliding_window}, block_tables.shape = {block_table.shape}, softcap = {self.logits_soft_cap}, scheduler_metadata = {scheduler_metadata}")
use_unified_kv_layout = (
FlashAttentionBackend._use_rocm_unified_kv_layout(
key_cache=key_cache, value_cache=value_cache)
)
if use_unified_kv_layout:
mm_prefix_range_tensor, qq_bias = self._get_unified_extras(
attn_metadata
)
varlen_fwd_unified(
q=query[:num_actual_tokens],
k=key_cache,
v=value_cache,
cu_seqlens_q=cu_seqlens_q,
seqused_k=seqused_k,
block_table=block_table,
max_seqlen_q=max_seqlen_q,
max_seqlen_k=max_seqlen_k,
softmax_scale=self.scale,
causal=attn_metadata.causal,
softcap=self.logits_soft_cap,
window_size=tuple(self.sliding_window),
alibi_slopes=self.alibi_slopes,
use_alibi_sqrt=self.use_alibi_sqrt,
qq_bias=qq_bias,
s_aux=self.sinks,
mm_prefix_range=mm_prefix_range_tensor,
return_softmax_lse=False,
out=output[:num_actual_tokens],
)
else:
vllm_flash_attn_varlen_func(
q=query[:num_actual_tokens],
k=key_cache,
......@@ -889,8 +1008,24 @@ class FlashAttentionImpl(AttentionImpl):
# op uses the slot_mapping's shape to determine the number of
# actual tokens.
if current_platform.is_rocm():
if FlashAttentionBackend._use_rocm_unified_kv_layout(
key_cache=key_cache,
value_cache=value_cache,
):
triton_reshape_and_cache_flash(
key,
value,
key_cache,
value_cache,
slot_mapping,
self.kv_cache_dtype,
layer._k_scale,
layer._v_scale,
)
else:
if envs.VLLM_USE_OPT_RESHAPE_AND_CACHE:
from lightop import reshape_and_cache_cuda
reshape_and_cache_cuda(
key,
value,
......
......@@ -12,11 +12,6 @@ import torch
from vllm.logger import init_logger
from vllm.platforms import current_platform
from vllm.triton_utils import tl, triton
from vllm import envs
try:
from flash_attn import varlen_fwd_unified
except Exception:
varlen_fwd_unified = None
logger = init_logger(__name__)
float8_info = torch.finfo(current_platform.fp8_dtype())
......@@ -988,8 +983,6 @@ def unified_attention(
or num_seqs > seq_threshold_3D
):
# print(f"[2D Triton] k shape: {k.shape}, v shape: {v.shape}")
if not envs.VLLM_V1_USE_FA_UNIFIED_ATTN_2D:
# print("Running Triton kernel")
kernel_unified_attention_2d[
(
total_num_q_blocks,
......@@ -1045,33 +1038,6 @@ def unified_attention(
BLOCK_M=BLOCK_M,
USE_FP8=output_scale is not None,
)
else:
if varlen_fwd_unified is None:
raise RuntimeError(
"flash_attn.varlen_fwd_unified is not available in this flash-attn version"
)
# print("Running FA kernel")
varlen_fwd_unified(
q=q,
k=k,
v=v,
cu_seqlens_q=cu_seqlens_q,
seqused_k=seqused_k,
block_table=block_table,
max_seqlen_q=max_seqlen_q,
max_seqlen_k=max_seqlen_k,
softmax_scale=softmax_scale,
causal=causal,
softcap=softcap,
window_size=window_size,
alibi_slopes=alibi_slopes,
use_alibi_sqrt=use_alibi_sqrt,
qq_bias=qq_bias,
s_aux=sinks,
mm_prefix_range=mm_prefix_range,
return_softmax_lse=False,
out=out,
)
else:
# print(f"[3D Triton] k shape: {k.shape}, v shape: {v.shape}")
kernel_unified_attention_3d[
......
......@@ -27,6 +27,7 @@ class AttentionSelectorConfig(NamedTuple):
has_sink: bool = False
use_sparse: bool = False
use_mm_prefix: bool = False
use_alibi_sqrt: bool = False
attn_type: str = AttentionType.DECODER
def __repr__(self):
......@@ -39,6 +40,7 @@ class AttentionSelectorConfig(NamedTuple):
f"has_sink={self.has_sink}, "
f"use_sparse={self.use_sparse}, "
f"use_mm_prefix={self.use_mm_prefix}, "
f"use_alibi_sqrt={self.use_alibi_sqrt}, "
f"attn_type={self.attn_type})"
)
......@@ -52,6 +54,7 @@ def get_attn_backend(
has_sink: bool = False,
use_sparse: bool = False,
use_mm_prefix: bool = False,
use_alibi_sqrt: bool = False,
attn_type: str | None = None,
) -> type[AttentionBackend]:
"""Selects which attention backend to use and lazily imports it."""
......@@ -77,6 +80,7 @@ def get_attn_backend(
has_sink=has_sink,
use_sparse=use_sparse,
use_mm_prefix=use_mm_prefix,
use_alibi_sqrt=use_alibi_sqrt,
attn_type=attn_type or AttentionType.DECODER,
)
......
......@@ -5952,7 +5952,7 @@ class GPUModelRunner(
return kv_caches
def _update_hybrid_attention_mamba_layout(
self, kv_caches: dict[str, torch.Tensor]
self, kv_caches: dict[str, Any]
) -> None:
"""
Update the layout of attention layers from (2, num_blocks, ...) to
......@@ -5966,6 +5966,8 @@ class GPUModelRunner(
kv_cache_spec = group.kv_cache_spec
for layer_name in group.layer_names:
kv_cache = kv_caches[layer_name]
if not isinstance(kv_cache, torch.Tensor):
continue
if isinstance(kv_cache_spec, AttentionSpec) and kv_cache.shape[0] == 2:
assert kv_cache.shape[1] != 2, (
"Fail to determine whether the layout is "
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment