Commit b233584a authored by laibao's avatar laibao Committed by zhangzbb
Browse files

[BUGFIX] 回退 ROCm FlashAttention unified KV layout 改动并修正 unified kernel 选择逻辑

parent 2888b4e5
......@@ -245,7 +245,6 @@ class Attention(nn.Module, AttentionLayerBase):
use_mla=False,
has_sink=self.has_sink,
use_mm_prefix=self.use_mm_prefix,
use_alibi_sqrt=bool(use_alibi_sqrt),
attn_type=attn_type,
)
else:
......@@ -1274,4 +1273,4 @@ direct_register_custom_op(
mutates_args=["qkv", "positions"],
fake_impl=fused_qkv_split_rmsnorm_rope_kv_store_fake,
tags=(torch.Tag.needs_fixed_stride_order,),
)
)
\ No newline at end of file
......@@ -1989,11 +1989,6 @@ environment_variables: dict[str, Callable[[], Any]] = {
"VLLM_ENABLE_RAY_ASYNC_SCHEDULING":
lambda: (os.environ.get("VLLM_ENABLE_RAY_ASYNC_SCHEDULING", "False").lower() in
("true", "1")),
#If set to 1/True, enable the flash attention unified path.
"VLLM_V1_USE_FA_UNIFIED_ATTN_2D":
lambda: (os.environ.get("VLLM_V1_USE_FA_UNIFIED_ATTN_2D", "False").lower() in
("true", "1")),
"USE_LIGHTOP_PER_TOKEN_GROUP_QUANT_FP8":
lambda: (os.environ.get("USE_LIGHTOP_PER_TOKEN_GROUP_QUANT_FP8", "False").lower() in
("true", "1")),
......
......@@ -262,7 +262,6 @@ class RocmPlatform(Platform):
from vllm._aiter_ops import rocm_aiter_ops
block_size = attn_selector_config.block_size
head_size = attn_selector_config.head_size
kv_cache_dtype = attn_selector_config.kv_cache_dtype
if attn_selector_config.use_sparse:
......@@ -305,36 +304,9 @@ class RocmPlatform(Platform):
f"is not MLA type while requested for MLA backend."
)
is_non64_block_multiple_64 = (
block_size != 64
and block_size % 64 == 0
)
use_unified_flash = (
is_non64_block_multiple_64
and head_size == 256
)
if (
envs.VLLM_USE_FLASH_ATTN_PA
and is_non64_block_multiple_64
and head_size != 256
):
logger.info_once(
"Skip unified varlen kernel on V1 engine: head size %d is "
"unsupported (requires 256).",
head_size,
)
if envs.VLLM_USE_FLASH_ATTN_PA and (block_size == 64 or use_unified_flash):
if use_unified_flash and block_size != 64:
logger.info_once(
"Using Flash Attention backend with unified varlen kernel on "
"V1 engine. (block size %d, requires block size divisible by 64)",
block_size,
)
else:
logger.info_once(
"Using Flash Attention backend on V1 engine. "
"(only supports block size 64)"
)
if envs.VLLM_USE_FLASH_ATTN_PA and block_size == 64:
logger.info_once("Using Flash Attention backend on V1 engine. (only supports block size 64)")
return AttentionBackendEnum.FLASH_ATTN.get_path()
else:
os.environ['VLLM_USE_FLASH_ATTN_PA'] = '0'
......
......@@ -225,7 +225,6 @@ class AttentionBackend(ABC):
has_sink: bool,
use_sparse: bool,
use_mm_prefix: bool,
use_alibi_sqrt: bool,
device_capability: "DeviceCapability",
attn_type: str,
) -> list[str]:
......@@ -242,8 +241,6 @@ class AttentionBackend(ABC):
invalid_reasons.append(
"partial multimodal token full attention not supported"
)
if use_alibi_sqrt and not cls.supports_alibi_sqrt():
invalid_reasons.append("use_alibi_sqrt not supported")
if use_mla != cls.is_mla():
if use_mla:
invalid_reasons.append("MLA not supported")
......
......@@ -33,13 +33,6 @@ if is_flash_attn_varlen_func_available():
vllm_flash_attn_varlen_func,
reshape_and_cache_cuda,
)
from vllm.v1.attention.ops.triton_reshape_and_cache_flash import (
triton_reshape_and_cache_flash,
)
try:
from flash_attn import varlen_fwd_unified
except Exception:
varlen_fwd_unified = None
else:
from vllm.v1.attention.backends.fa_utils import (
flash_attn_supports_sinks,
......@@ -119,38 +112,6 @@ class FlashAttentionBackend(AttentionBackend):
@staticmethod
def get_builder_cls() -> type["FlashAttentionMetadataBuilder"]:
return FlashAttentionMetadataBuilder
@classmethod
def supports_alibi_sqrt(cls) -> bool:
return True
@classmethod
def supports_mm_prefix(cls) -> bool:
return True
@staticmethod
def _use_rocm_unified_kv_layout(
block_size: int | None = None,
key_cache: torch.Tensor | None = None,
value_cache: torch.Tensor | None = None,
) -> bool:
if not current_platform.is_rocm():
return False
if block_size is None:
if key_cache is not None and value_cache is not None:
if key_cache.ndim != 4 or value_cache.ndim != 4:
return False
if key_cache.shape != value_cache.shape:
return False
block_size = key_cache.shape[1]
else:
try:
block_size = get_current_vllm_config().cache_config.block_size
except Exception:
return False
return block_size is not None and block_size != 64 and block_size % 64 == 0
if current_platform.is_rocm():
@staticmethod
......@@ -163,9 +124,6 @@ class FlashAttentionBackend(AttentionBackend):
) -> tuple[tuple[int, ...], tuple[int, ...]]:
if block_size % 16 != 0:
raise ValueError("Block size must be a multiple of 16.")
if FlashAttentionBackend._use_rocm_unified_kv_layout(block_size):
unified_shape = (num_blocks, block_size, num_kv_heads, head_size)
return (unified_shape, unified_shape)
return (
(num_blocks, num_kv_heads, block_size, head_size),
(num_blocks, num_kv_heads, head_size, block_size),
......@@ -178,31 +136,20 @@ class FlashAttentionBackend(AttentionBackend):
# `stride_order` indicates the permutation that gets
# us from `get_kv_cache_shape` to the actual memory layout we want.
cache_layout = get_kv_cache_layout()
if FlashAttentionBackend._use_rocm_unified_kv_layout():
if cache_layout != "NHD":
raise RuntimeError(
"ROCm unified KV layout currently supports NHD only."
)
if include_num_layers_dimension:
# (num_blocks, num_layers, block_size, num_kv_heads, head_size)
return (1, 0, 2, 3, 4), (1, 0, 2, 3, 4)
if cache_layout == "NHD" and include_num_layers_dimension:
# (num_blocks, num_layers, block_size, num_kv_heads, head_size)
return (1, 0, 3, 2, 5), (1, 0, 4, 2, 3)
elif cache_layout == "NHD":
key_stride_order = (0, 1, 2, 3)
value_stride_order = (0, 1, 2, 3)
elif cache_layout == "HND" and include_num_layers_dimension:
# (num_blocks, num_kv_heads, num_layers, block_size, head_size)
return (1, 2, 0, 3, 4), (1, 2, 0, 4, 3)
elif cache_layout == "HND":
key_stride_order = (0, 1, 2, 3)
value_stride_order = (0, 1, 3, 2)
else:
if cache_layout == "NHD" and include_num_layers_dimension:
# (num_blocks, num_layers, block_size, num_kv_heads, head_size)
return (1, 0, 3, 2, 5), (1, 0, 4, 2, 3)
elif cache_layout == "NHD":
key_stride_order = (0, 1, 2, 3)
value_stride_order = (0, 1, 2, 3)
elif cache_layout == "HND" and include_num_layers_dimension:
# (num_blocks, num_kv_heads, num_layers, block_size, head_size)
return (1, 2, 0, 3, 4), (1, 2, 0, 4, 3)
elif cache_layout == "HND":
key_stride_order = (0, 1, 2, 3)
value_stride_order = (0, 1, 3, 2)
else:
raise ValueError(f"Unknown cache layout format {cache_layout}.")
raise ValueError(f"Unknown cache layout format {cache_layout}.")
return key_stride_order, value_stride_order
else:
@staticmethod
......@@ -324,34 +271,8 @@ class FlashAttentionMetadata:
prefix_scheduler_metadata: torch.Tensor | None = None
max_num_splits: int = 0
mm_prefix_range: dict[int, list[tuple[int, int]]] | None = None
qq_bias: torch.Tensor | None = None
causal: bool = True
@property
def mm_prefix_range_tensor(self) -> torch.Tensor | None:
if self.mm_prefix_range is None:
return None
num_seqs = self.seq_lens.shape[0]
device = self.seq_lens.device
range_lists = [
self.mm_prefix_range.get(i, [(0, 0)]) or [(0, 0)]
for i in range(num_seqs)
]
if all(r == [(0, 0)] for r in range_lists):
return None
range_tensors = [
torch.tensor(r, dtype=torch.int32, device=device).view(-1, 2)
for r in range_lists
]
return torch.nested.nested_tensor(
range_tensors, layout=torch.jagged
).to_padded_tensor(0)
def _get_sliding_window_configs(
vllm_config: VllmConfig,
......@@ -676,7 +597,6 @@ class FlashAttentionImpl(AttentionImpl):
attn_type: AttentionType = AttentionType.DECODER,
kv_sharing_target_layer_name: str | None = None,
sinks: torch.Tensor | None = None,
use_alibi_sqrt: bool = False,
) -> None:
self.num_heads = num_heads
self.head_size = head_size
......@@ -702,7 +622,6 @@ class FlashAttentionImpl(AttentionImpl):
self.attn_type = attn_type
self.vllm_flash_attn_version = get_flash_attn_version()
self.use_alibi_sqrt = use_alibi_sqrt
# Cache the batch invariant result for use in forward passes
self.batch_invariant_enabled = vllm_is_batch_invariant()
......@@ -729,14 +648,6 @@ class FlashAttentionImpl(AttentionImpl):
else False
)
def _get_unified_extras(
self,
attn_metadata: FlashAttentionMetadata,
) -> tuple[torch.Tensor | None, torch.Tensor | None]:
mm_prefix_range_tensor = attn_metadata.mm_prefix_range_tensor
qq_bias = attn_metadata.qq_bias
return mm_prefix_range_tensor, qq_bias
def forward(
self,
layer: torch.nn.Module,
......@@ -863,60 +774,30 @@ class FlashAttentionImpl(AttentionImpl):
print(f"q.shape = {query[:num_actual_tokens].shape}, key_cache.shape = {key_cache.shape}, value_cache.shape = {value_cache.shape}")
print(f"cu_seqlens_q.shape = {cu_seqlens_q.shape}, max_seqlen_q = {max_seqlen_q}, seqused_k.shape = {seqused_k.shape}, max_seqlen_k = {max_seqlen_k}")
print(f"softmax_scale = {self.scale:.3f}, alibi_slopes = {self.alibi_slopes}, window_size = {self.sliding_window}, block_tables.shape = {block_table.shape}, softcap = {self.logits_soft_cap}, scheduler_metadata = {scheduler_metadata}")
use_unified_kv_layout = (
FlashAttentionBackend._use_rocm_unified_kv_layout(
key_cache=key_cache, value_cache=value_cache)
vllm_flash_attn_varlen_func(
q=query[:num_actual_tokens],
k=key_cache,
v=value_cache,
out=output[:num_actual_tokens],
cu_seqlens_q=cu_seqlens_q,
max_seqlen_q=max_seqlen_q,
seqused_k=seqused_k,
max_seqlen_k=max_seqlen_k,
softmax_scale=self.scale,
causal=attn_metadata.causal,
alibi_slopes=self.alibi_slopes,
window_size=sliding_window_size,
block_table=block_table,
softcap=self.logits_soft_cap,
scheduler_metadata=scheduler_metadata,
fa_version=self.vllm_flash_attn_version,
q_descale=q_descale,
k_descale=k_descale,
v_descale=v_descale,
# num_splits=attn_metadata.max_num_splits,
s_aux=self.sinks,
is_prefix_cache=True,
)
if use_unified_kv_layout:
mm_prefix_range_tensor, qq_bias = self._get_unified_extras(
attn_metadata
)
varlen_fwd_unified(
q=query[:num_actual_tokens],
k=key_cache,
v=value_cache,
cu_seqlens_q=cu_seqlens_q,
seqused_k=seqused_k,
block_table=block_table,
max_seqlen_q=max_seqlen_q,
max_seqlen_k=max_seqlen_k,
softmax_scale=self.scale,
causal=attn_metadata.causal,
softcap=self.logits_soft_cap,
window_size=tuple(self.sliding_window),
alibi_slopes=self.alibi_slopes,
use_alibi_sqrt=self.use_alibi_sqrt,
qq_bias=qq_bias,
s_aux=self.sinks,
mm_prefix_range=mm_prefix_range_tensor,
return_softmax_lse=False,
out=output[:num_actual_tokens],
)
else:
vllm_flash_attn_varlen_func(
q=query[:num_actual_tokens],
k=key_cache,
v=value_cache,
out=output[:num_actual_tokens],
cu_seqlens_q=cu_seqlens_q,
max_seqlen_q=max_seqlen_q,
seqused_k=seqused_k,
max_seqlen_k=max_seqlen_k,
softmax_scale=self.scale,
causal=attn_metadata.causal,
alibi_slopes=self.alibi_slopes,
window_size=sliding_window_size,
block_table=block_table,
softcap=self.logits_soft_cap,
scheduler_metadata=scheduler_metadata,
fa_version=self.vllm_flash_attn_version,
q_descale=q_descale,
k_descale=k_descale,
v_descale=v_descale,
# num_splits=attn_metadata.max_num_splits,
s_aux=self.sinks,
is_prefix_cache=True,
)
else:
flash_attn_varlen_func(
q=query[:num_actual_tokens],
......@@ -1008,11 +889,21 @@ class FlashAttentionImpl(AttentionImpl):
# op uses the slot_mapping's shape to determine the number of
# actual tokens.
if current_platform.is_rocm():
if FlashAttentionBackend._use_rocm_unified_kv_layout(
key_cache=key_cache,
value_cache=value_cache,
):
triton_reshape_and_cache_flash(
if envs.VLLM_USE_OPT_RESHAPE_AND_CACHE:
from lightop import reshape_and_cache_cuda
reshape_and_cache_cuda(
key,
value,
key_cache,
value_cache,
slot_mapping,
self.kv_cache_dtype,
layer._k_scale,
layer._v_scale
)
else:
from vllm.v1.attention.backends.fa_utils import reshape_and_cache_cuda
reshape_and_cache_cuda(
key,
value,
key_cache,
......@@ -1022,32 +913,6 @@ class FlashAttentionImpl(AttentionImpl):
layer._k_scale,
layer._v_scale,
)
else:
if envs.VLLM_USE_OPT_RESHAPE_AND_CACHE:
from lightop import reshape_and_cache_cuda
reshape_and_cache_cuda(
key,
value,
key_cache,
value_cache,
slot_mapping,
self.kv_cache_dtype,
layer._k_scale,
layer._v_scale
)
else:
from vllm.v1.attention.backends.fa_utils import reshape_and_cache_cuda
reshape_and_cache_cuda(
key,
value,
key_cache,
value_cache,
slot_mapping,
self.kv_cache_dtype,
layer._k_scale,
layer._v_scale,
)
else:
reshape_and_cache_flash(
key,
......
......@@ -12,6 +12,10 @@ import torch
from vllm.logger import init_logger
from vllm.platforms import current_platform
from vllm.triton_utils import tl, triton
try:
from flash_attn import varlen_fwd_unified
except Exception:
varlen_fwd_unified = None
logger = init_logger(__name__)
float8_info = torch.finfo(current_platform.fp8_dtype())
......@@ -983,61 +987,92 @@ def unified_attention(
or num_seqs > seq_threshold_3D
):
# print(f"[2D Triton] k shape: {k.shape}, v shape: {v.shape}")
kernel_unified_attention_2d[
(
total_num_q_blocks,
num_kv_heads,
)
](
output_ptr=out,
query_ptr=q,
key_cache_ptr=k,
value_cache_ptr=v,
sink_ptr=sinks,
block_tables_ptr=block_table,
seq_lens_ptr=seqused_k,
alibi_slopes_ptr=alibi_slopes,
qq_bias_ptr=qq_bias,
scale=softmax_scale,
k_scale=k_descale,
v_scale=v_descale,
out_scale=1 / output_scale if output_scale is not None else 1.0,
softcap=softcap,
num_query_heads=num_query_heads,
num_queries_per_kv=num_queries_per_kv,
block_table_stride=block_table.stride(0),
query_stride_0=q.stride(0),
query_stride_1=q.stride(1),
output_stride_0=out.stride(0),
output_stride_1=out.stride(1),
qq_bias_stride_0=qq_bias.stride(0) if use_qq_bias else 0,
BLOCK_SIZE=block_size,
TILE_SIZE=TILE_SIZE_PREFILL,
HEAD_SIZE=head_size,
HEAD_SIZE_PADDED=triton.next_power_of_2(head_size),
USE_ALIBI_SLOPES=use_alibi_slopes,
USE_ALIBI_SQRT=use_alibi_sqrt,
USE_QQ_BIAS=use_qq_bias,
USE_SOFTCAP=(softcap > 0),
USE_SINKS=(sinks is not None),
USE_MM_PREFIX=use_mm_prefix,
MAX_MM_RANGES=max_mm_ranges,
mm_prefix_range_ptr=mm_prefix_range,
SLIDING_WINDOW=(1 + window_size[0]),
stride_k_cache_0=k.stride(0),
stride_k_cache_1=k.stride(1),
stride_k_cache_2=k.stride(2),
stride_k_cache_3=k.stride(3),
stride_v_cache_0=v.stride(0),
stride_v_cache_1=v.stride(1),
stride_v_cache_2=v.stride(2),
stride_v_cache_3=v.stride(3),
query_start_len_ptr=cu_seqlens_q,
BLOCK_Q=BLOCK_Q,
num_seqs=num_seqs,
BLOCK_M=BLOCK_M,
USE_FP8=output_scale is not None,
use_fa_unified_2d = (
current_platform.is_rocm()
and varlen_fwd_unified is not None
and block_size % 64 == 0
and head_size == 256
)
if not use_fa_unified_2d:
# print("Running Triton kernel")
kernel_unified_attention_2d[
(
total_num_q_blocks,
num_kv_heads,
)
](
output_ptr=out,
query_ptr=q,
key_cache_ptr=k,
value_cache_ptr=v,
sink_ptr=sinks,
block_tables_ptr=block_table,
seq_lens_ptr=seqused_k,
alibi_slopes_ptr=alibi_slopes,
qq_bias_ptr=qq_bias,
scale=softmax_scale,
k_scale=k_descale,
v_scale=v_descale,
out_scale=1 / output_scale if output_scale is not None else 1.0,
softcap=softcap,
num_query_heads=num_query_heads,
num_queries_per_kv=num_queries_per_kv,
block_table_stride=block_table.stride(0),
query_stride_0=q.stride(0),
query_stride_1=q.stride(1),
output_stride_0=out.stride(0),
output_stride_1=out.stride(1),
qq_bias_stride_0=qq_bias.stride(0) if use_qq_bias else 0,
BLOCK_SIZE=block_size,
TILE_SIZE=TILE_SIZE_PREFILL,
HEAD_SIZE=head_size,
HEAD_SIZE_PADDED=triton.next_power_of_2(head_size),
USE_ALIBI_SLOPES=use_alibi_slopes,
USE_ALIBI_SQRT=use_alibi_sqrt,
USE_QQ_BIAS=use_qq_bias,
USE_SOFTCAP=(softcap > 0),
USE_SINKS=(sinks is not None),
USE_MM_PREFIX=use_mm_prefix,
MAX_MM_RANGES=max_mm_ranges,
mm_prefix_range_ptr=mm_prefix_range,
SLIDING_WINDOW=(1 + window_size[0]),
stride_k_cache_0=k.stride(0),
stride_k_cache_1=k.stride(1),
stride_k_cache_2=k.stride(2),
stride_k_cache_3=k.stride(3),
stride_v_cache_0=v.stride(0),
stride_v_cache_1=v.stride(1),
stride_v_cache_2=v.stride(2),
stride_v_cache_3=v.stride(3),
query_start_len_ptr=cu_seqlens_q,
BLOCK_Q=BLOCK_Q,
num_seqs=num_seqs,
BLOCK_M=BLOCK_M,
USE_FP8=output_scale is not None,
)
else:
# print("Running FA kernel")
varlen_fwd_unified(
q=q,
k=k,
v=v,
cu_seqlens_q=cu_seqlens_q,
seqused_k=seqused_k,
block_table=block_table,
max_seqlen_q=max_seqlen_q,
max_seqlen_k=max_seqlen_k,
softmax_scale=softmax_scale,
causal=causal,
softcap=softcap,
window_size=window_size,
alibi_slopes=alibi_slopes,
use_alibi_sqrt=use_alibi_sqrt,
qq_bias=qq_bias,
s_aux=sinks,
mm_prefix_range=mm_prefix_range,
return_softmax_lse=False,
out=out,
)
else:
# print(f"[3D Triton] k shape: {k.shape}, v shape: {v.shape}")
kernel_unified_attention_3d[
......
......@@ -27,7 +27,6 @@ class AttentionSelectorConfig(NamedTuple):
has_sink: bool = False
use_sparse: bool = False
use_mm_prefix: bool = False
use_alibi_sqrt: bool = False
attn_type: str = AttentionType.DECODER
def __repr__(self):
......@@ -40,7 +39,6 @@ class AttentionSelectorConfig(NamedTuple):
f"has_sink={self.has_sink}, "
f"use_sparse={self.use_sparse}, "
f"use_mm_prefix={self.use_mm_prefix}, "
f"use_alibi_sqrt={self.use_alibi_sqrt}, "
f"attn_type={self.attn_type})"
)
......@@ -54,7 +52,6 @@ def get_attn_backend(
has_sink: bool = False,
use_sparse: bool = False,
use_mm_prefix: bool = False,
use_alibi_sqrt: bool = False,
attn_type: str | None = None,
) -> type[AttentionBackend]:
"""Selects which attention backend to use and lazily imports it."""
......@@ -80,7 +77,6 @@ def get_attn_backend(
has_sink=has_sink,
use_sparse=use_sparse,
use_mm_prefix=use_mm_prefix,
use_alibi_sqrt=use_alibi_sqrt,
attn_type=attn_type or AttentionType.DECODER,
)
......
......@@ -5958,7 +5958,7 @@ class GPUModelRunner(
return kv_caches
def _update_hybrid_attention_mamba_layout(
self, kv_caches: dict[str, Any]
self, kv_caches: dict[str, torch.Tensor]
) -> None:
"""
Update the layout of attention layers from (2, num_blocks, ...) to
......@@ -5972,8 +5972,6 @@ class GPUModelRunner(
kv_cache_spec = group.kv_cache_spec
for layer_name in group.layer_names:
kv_cache = kv_caches[layer_name]
if not isinstance(kv_cache, torch.Tensor):
continue
if isinstance(kv_cache_spec, AttentionSpec) and kv_cache.shape[0] == 2:
assert kv_cache.shape[1] != 2, (
"Fail to determine whether the layout is "
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment