"vscode:/vscode.git/clone" did not exist on "0cd3d9717e38c7a122ed01fe2a8fddd8b37dff4b"
Commit eb35ba1b authored by fanwl's avatar fanwl
Browse files

Add FA Unified Attention 2D

- Add VLLM_V1_USE_FA_UNIFIED_ATTN_2D 环境变量
- 0: Triton attention, 1: FA unified attention
parent 3f414133
...@@ -310,6 +310,7 @@ if TYPE_CHECKING: ...@@ -310,6 +310,7 @@ if TYPE_CHECKING:
VLLM_USE_LIGHTOP_RMS_ROPE_CONCAT: bool = False VLLM_USE_LIGHTOP_RMS_ROPE_CONCAT: bool = False
VLLM_USE_CUDA_GRAPH_SIZES: bool = False VLLM_USE_CUDA_GRAPH_SIZES: bool = False
VLLM_USE_LIGHTOP_FUSED_TOPP_TOPK: bool = False VLLM_USE_LIGHTOP_FUSED_TOPP_TOPK: bool = False
VLLM_V1_USE_FA_UNIFIED_ATTN_2D: bool = False
def get_default_cache_root(): def get_default_cache_root():
...@@ -1940,6 +1941,10 @@ environment_variables: dict[str, Callable[[], Any]] = { ...@@ -1940,6 +1941,10 @@ environment_variables: dict[str, Callable[[], Any]] = {
"VLLM_USE_LIGHTOP_FUSED_TOPP_TOPK": "VLLM_USE_LIGHTOP_FUSED_TOPP_TOPK":
lambda: (os.environ.get("VLLM_USE_LIGHTOP_FUSED_TOPP_TOPK", "False").lower() in lambda: (os.environ.get("VLLM_USE_LIGHTOP_FUSED_TOPP_TOPK", "False").lower() in
("true", "1")), ("true", "1")),
#If set to 1/True, enable the flash attention unified path.
"VLLM_V1_USE_FA_UNIFIED_ATTN_2D":
lambda: (os.environ.get("VLLM_V1_USE_FA_UNIFIED_ATTN_2D", "False").lower() in
("true", "1")),
} }
# --8<-- [end:env-vars-definition] # --8<-- [end:env-vars-definition]
......
...@@ -426,7 +426,7 @@ class HybridAttentionMambaModelConfig(VerifyAndUpdateConfig): ...@@ -426,7 +426,7 @@ class HybridAttentionMambaModelConfig(VerifyAndUpdateConfig):
dtype=kv_cache_dtype, dtype=kv_cache_dtype,
).page_size_bytes ).page_size_bytes
else: else:
kernel_block_alignment_size = 16 kernel_block_alignment_size = 64
if ( if (
current_platform.is_device_capability_family(100) current_platform.is_device_capability_family(100)
and model_config.get_head_size() == 256 and model_config.get_head_size() == 256
......
...@@ -12,6 +12,10 @@ import torch ...@@ -12,6 +12,10 @@ import torch
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.platforms import current_platform from vllm.platforms import current_platform
from vllm.triton_utils import tl, triton from vllm.triton_utils import tl, triton
from vllm import envs
from flash_attn import (
varlen_fwd_unified,
)
logger = init_logger(__name__) logger = init_logger(__name__)
float8_info = torch.finfo(current_platform.fp8_dtype()) float8_info = torch.finfo(current_platform.fp8_dtype())
...@@ -982,6 +986,9 @@ def unified_attention( ...@@ -982,6 +986,9 @@ def unified_attention(
or max_seqlen_q > 1 or max_seqlen_q > 1
or num_seqs > seq_threshold_3D or num_seqs > seq_threshold_3D
): ):
# print(f"[2D Triton] k shape: {k.shape}, v shape: {v.shape}")
if not envs.VLLM_V1_USE_FA_UNIFIED_ATTN_2D:
# print("Running Triton kernel")
kernel_unified_attention_2d[ kernel_unified_attention_2d[
( (
total_num_q_blocks, total_num_q_blocks,
...@@ -1038,6 +1045,30 @@ def unified_attention( ...@@ -1038,6 +1045,30 @@ def unified_attention(
USE_FP8=output_scale is not None, USE_FP8=output_scale is not None,
) )
else: else:
# print("Running FA kernel")
varlen_fwd_unified(
q=q,
k=k,
v=v,
cu_seqlens_q=cu_seqlens_q,
seqused_k=seqused_k,
block_table=block_table,
max_seqlen_q=max_seqlen_q,
max_seqlen_k=max_seqlen_k,
softmax_scale=softmax_scale,
causal=causal,
softcap=softcap,
window_size=window_size,
alibi_slopes=alibi_slopes,
use_alibi_sqrt=use_alibi_sqrt,
qq_bias=qq_bias,
s_aux=sinks,
mm_prefix_range=mm_prefix_range,
return_softmax_lse=False,
out=out,
)
else:
# print(f"[3D Triton] k shape: {k.shape}, v shape: {v.shape}")
kernel_unified_attention_3d[ kernel_unified_attention_3d[
(total_num_q_blocks, num_kv_heads, num_par_softmax_segments) (total_num_q_blocks, num_kv_heads, num_par_softmax_segments)
]( ](
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment