Commit 57979f97 authored by zhangqha's avatar zhangqha
Browse files

Merge branch 'v0.15.1-dev-unified' into 'v0.15.1-dev'

Add FA Unified Attention 2D

See merge request dcutoolkit/deeplearing/vllm!501
parents 29646389 ad517f95
......@@ -310,6 +310,7 @@ if TYPE_CHECKING:
VLLM_USE_LIGHTOP_RMS_ROPE_CONCAT: bool = False
VLLM_USE_CUDA_GRAPH_SIZES: bool = False
VLLM_USE_LIGHTOP_FUSED_TOPP_TOPK: bool = False
VLLM_V1_USE_FA_UNIFIED_ATTN_2D: bool = False
VLLM_ENABLE_RAY_ASYNC_SCHEDULING: bool = False
......@@ -1946,6 +1947,11 @@ environment_variables: dict[str, Callable[[], Any]] = {
"VLLM_ENABLE_RAY_ASYNC_SCHEDULING":
lambda: (os.environ.get("VLLM_ENABLE_RAY_ASYNC_SCHEDULING", "False").lower() in
("true", "1")),
#If set to 1/True, enable the flash attention unified path.
"VLLM_V1_USE_FA_UNIFIED_ATTN_2D":
lambda: (os.environ.get("VLLM_V1_USE_FA_UNIFIED_ATTN_2D", "False").lower() in
("true", "1")),
}
# --8<-- [end:env-vars-definition]
......
......@@ -426,7 +426,7 @@ class HybridAttentionMambaModelConfig(VerifyAndUpdateConfig):
dtype=kv_cache_dtype,
).page_size_bytes
else:
kernel_block_alignment_size = 16
kernel_block_alignment_size = 64
if (
current_platform.is_device_capability_family(100)
and model_config.get_head_size() == 256
......
......@@ -12,6 +12,10 @@ import torch
from vllm.logger import init_logger
from vllm.platforms import current_platform
from vllm.triton_utils import tl, triton
from vllm import envs
from flash_attn import (
varlen_fwd_unified,
)
logger = init_logger(__name__)
float8_info = torch.finfo(current_platform.fp8_dtype())
......@@ -982,6 +986,9 @@ def unified_attention(
or max_seqlen_q > 1
or num_seqs > seq_threshold_3D
):
# print(f"[2D Triton] k shape: {k.shape}, v shape: {v.shape}")
if not envs.VLLM_V1_USE_FA_UNIFIED_ATTN_2D:
# print("Running Triton kernel")
kernel_unified_attention_2d[
(
total_num_q_blocks,
......@@ -1038,6 +1045,30 @@ def unified_attention(
USE_FP8=output_scale is not None,
)
else:
# print("Running FA kernel")
varlen_fwd_unified(
q=q,
k=k,
v=v,
cu_seqlens_q=cu_seqlens_q,
seqused_k=seqused_k,
block_table=block_table,
max_seqlen_q=max_seqlen_q,
max_seqlen_k=max_seqlen_k,
softmax_scale=softmax_scale,
causal=causal,
softcap=softcap,
window_size=window_size,
alibi_slopes=alibi_slopes,
use_alibi_sqrt=use_alibi_sqrt,
qq_bias=qq_bias,
s_aux=sinks,
mm_prefix_range=mm_prefix_range,
return_softmax_lse=False,
out=out,
)
else:
# print(f"[3D Triton] k shape: {k.shape}, v shape: {v.shape}")
kernel_unified_attention_3d[
(total_num_q_blocks, num_kv_heads, num_par_softmax_segments)
](
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment