Commit 07b41ddf authored by zhuwenwen's avatar zhuwenwen
Browse files

using triton prefix cache on BW

parent 290a0d15
......@@ -22,7 +22,7 @@ from vllm.logger import init_logger
from vllm import _custom_ops as ops
from vllm.platforms import current_platform
from vllm.platforms.rocm import use_rocm_custom_paged_attention
from vllm.utils import SUPPORT_TC
from vllm.utils import SUPPORT_TC, gpuname
if TYPE_CHECKING:
from vllm.worker.model_runner import ModelInputForGPUWithSamplingMetadata
......@@ -578,7 +578,7 @@ class ROCmFlashAttentionImpl(AttentionImpl):
try:
from flash_attn import flash_attn_varlen_func # noqa: F401
self.fa_attn_func = flash_attn_varlen_func
if not envs.VLLM_USE_TRITON_PREFIX_FLASH_ATTN:
if not envs.VLLM_USE_TRITON_PREFIX_FLASH_ATTN and gpuname.startswith('K100_AI'):
from flash_attn import vllm_flash_attn_varlen_func
self.fa_prefix_attn_func = vllm_flash_attn_varlen_func
......@@ -854,7 +854,7 @@ class ROCmFlashAttentionImpl(AttentionImpl):
# not applicable for encoder-only models
# if not envs.VLLM_USE_TRITON_PREFIX_FLASH_ATTN:
# self.fa_prefix_attn_func = vllm_flash_attn_varlen_func
if envs.VLLM_USE_TRITON_PREFIX_FLASH_ATTN:
if envs.VLLM_USE_TRITON_PREFIX_FLASH_ATTN or gpuname.startswith('BW'):
version_key = triton_key()
if self.attn_type != AttentionType.ENCODER_ONLY:
output[:num_prefill_tokens] = paged_attn.forward_prefix(
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment