Commit 07b41ddf authored by zhuwenwen
Browse files

Use the Triton prefix-cache attention path on BW GPUs

parent 290a0d15
...@@ -22,7 +22,7 @@ from vllm.logger import init_logger ...@@ -22,7 +22,7 @@ from vllm.logger import init_logger
from vllm import _custom_ops as ops from vllm import _custom_ops as ops
from vllm.platforms import current_platform from vllm.platforms import current_platform
from vllm.platforms.rocm import use_rocm_custom_paged_attention from vllm.platforms.rocm import use_rocm_custom_paged_attention
from vllm.utils import SUPPORT_TC from vllm.utils import SUPPORT_TC, gpuname
if TYPE_CHECKING: if TYPE_CHECKING:
from vllm.worker.model_runner import ModelInputForGPUWithSamplingMetadata from vllm.worker.model_runner import ModelInputForGPUWithSamplingMetadata
...@@ -578,7 +578,7 @@ class ROCmFlashAttentionImpl(AttentionImpl): ...@@ -578,7 +578,7 @@ class ROCmFlashAttentionImpl(AttentionImpl):
try: try:
from flash_attn import flash_attn_varlen_func # noqa: F401 from flash_attn import flash_attn_varlen_func # noqa: F401
self.fa_attn_func = flash_attn_varlen_func self.fa_attn_func = flash_attn_varlen_func
if not envs.VLLM_USE_TRITON_PREFIX_FLASH_ATTN: if not envs.VLLM_USE_TRITON_PREFIX_FLASH_ATTN and gpuname.startswith('K100_AI'):
from flash_attn import vllm_flash_attn_varlen_func from flash_attn import vllm_flash_attn_varlen_func
self.fa_prefix_attn_func = vllm_flash_attn_varlen_func self.fa_prefix_attn_func = vllm_flash_attn_varlen_func
...@@ -854,7 +854,7 @@ class ROCmFlashAttentionImpl(AttentionImpl): ...@@ -854,7 +854,7 @@ class ROCmFlashAttentionImpl(AttentionImpl):
# not applicable for encoder-only models # not applicable for encoder-only models
# if not envs.VLLM_USE_TRITON_PREFIX_FLASH_ATTN: # if not envs.VLLM_USE_TRITON_PREFIX_FLASH_ATTN:
# self.fa_prefix_attn_func = vllm_flash_attn_varlen_func # self.fa_prefix_attn_func = vllm_flash_attn_varlen_func
if envs.VLLM_USE_TRITON_PREFIX_FLASH_ATTN: if envs.VLLM_USE_TRITON_PREFIX_FLASH_ATTN or gpuname.startswith('BW'):
version_key = triton_key() version_key = triton_key()
if self.attn_type != AttentionType.ENCODER_ONLY: if self.attn_type != AttentionType.ENCODER_ONLY:
output[:num_prefill_tokens] = paged_attn.forward_prefix( output[:num_prefill_tokens] = paged_attn.forward_prefix(
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment