Commit efa6bed2 authored by wangmin6
Browse files

Merge branch 'v0.15.1-dev-unified_fix' into 'v0.15.1-dev'

add fa unified attn 导入判断 (add an import guard for the flash-attn unified attention kernel)

See merge request dcutoolkit/deeplearing/vllm!509
parents 79052e70 3c900b76
...@@ -13,9 +13,10 @@ from vllm.logger import init_logger ...@@ -13,9 +13,10 @@ from vllm.logger import init_logger
from vllm.platforms import current_platform from vllm.platforms import current_platform
from vllm.triton_utils import tl, triton from vllm.triton_utils import tl, triton
from vllm import envs from vllm import envs
from flash_attn import ( try:
varlen_fwd_unified, from flash_attn import varlen_fwd_unified
) except Exception:
varlen_fwd_unified = None
logger = init_logger(__name__) logger = init_logger(__name__)
float8_info = torch.finfo(current_platform.fp8_dtype()) float8_info = torch.finfo(current_platform.fp8_dtype())
...@@ -1045,6 +1046,10 @@ def unified_attention( ...@@ -1045,6 +1046,10 @@ def unified_attention(
USE_FP8=output_scale is not None, USE_FP8=output_scale is not None,
) )
else: else:
if varlen_fwd_unified is None:
raise RuntimeError(
"flash_attn.varlen_fwd_unified is not available in this flash-attn version"
)
# print("Running FA kernel") # print("Running FA kernel")
varlen_fwd_unified( varlen_fwd_unified(
q=q, q=q,
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment