Commit 7cf5d5c4 authored by zhuwenwen's avatar zhuwenwen
Browse files

update fa interface

parent 6b5ea53c
......@@ -215,7 +215,7 @@ from vllm.multimodal import MultiModalPlaceholderMap
from vllm.platforms import current_platform
from vllm.triton_utils import HAS_TRITON
from vllm.utils import async_tensor_h2d, cdiv, make_tensor_with_pad, round_down
from vllm.vllm_flash_attn.fa_utils import get_flash_attn_version
# from vllm.vllm_flash_attn.fa_utils import get_flash_attn_version
if HAS_TRITON:
from vllm.attention.ops.triton_flash_attention import triton_attention
......@@ -1050,11 +1050,11 @@ class MLACommonImpl(MLAAttentionImpl[T], Generic[T]):
# and the one from vllm_flash_attn. The former is used on RoCM and the
# latter has an additional parameter to control FA2 vs FA3
self.flash_attn_varlen_func = flash_attn_varlen_func
self.vllm_flash_attn_version = get_flash_attn_version()
if self.vllm_flash_attn_version is not None:
self.flash_attn_varlen_func = \
functools.partial(flash_attn_varlen_func,
fa_version=self.vllm_flash_attn_version)
# self.vllm_flash_attn_version = get_flash_attn_version()
# if self.vllm_flash_attn_version is not None:
# self.flash_attn_varlen_func = \
# functools.partial(flash_attn_varlen_func,
# fa_version=self.vllm_flash_attn_version)
self.use_llama_nn = os.environ.get('LLAMA_NN') == '1'
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment