Commit efa6bed2 authored by wangmin6
Browse files

Merge branch 'v0.15.1-dev-unified_fix' into 'v0.15.1-dev'

Add an import guard for the flash-attn unified attention kernel (varlen_fwd_unified)

See merge request dcutoolkit/deeplearing/vllm!509
parents 79052e70 3c900b76
......@@ -13,9 +13,10 @@ from vllm.logger import init_logger
from vllm.platforms import current_platform
from vllm.triton_utils import tl, triton
from vllm import envs
from flash_attn import (
varlen_fwd_unified,
)
try:
from flash_attn import varlen_fwd_unified
except Exception:
varlen_fwd_unified = None
logger = init_logger(__name__)
float8_info = torch.finfo(current_platform.fp8_dtype())
......@@ -1045,6 +1046,10 @@ def unified_attention(
USE_FP8=output_scale is not None,
)
else:
if varlen_fwd_unified is None:
raise RuntimeError(
"flash_attn.varlen_fwd_unified is not available in this flash-attn version"
)
# print("Running FA kernel")
varlen_fwd_unified(
q=q,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment