Commit e036115e authored by zhuwenwen
Browse files

update VLLM_USE_FLASH_ATTN_BACKEND to VLLM_USE_FLASH_ATTN_PA

parent e0d49ac2
...@@ -968,7 +968,7 @@ class ROCmFlashAttentionImpl(AttentionImpl): ...@@ -968,7 +968,7 @@ class ROCmFlashAttentionImpl(AttentionImpl):
) )
else: else:
tree_attention_masks_tensor = decode_meta.tree_attention_masks_tensor tree_attention_masks_tensor = decode_meta.tree_attention_masks_tensor
if envs.VLLM_USE_FLASH_ATTN_BACKEND: if envs.VLLM_USE_FLASH_ATTN_PA:
from flash_attn import vllm_flash_attn_with_kvcache from flash_attn import vllm_flash_attn_with_kvcache
# output[num_prefill_tokens:] = self.fa_decode_attn_func( # output[num_prefill_tokens:] = self.fa_decode_attn_func(
output[num_prefill_tokens:] = vllm_flash_attn_with_kvcache( output[num_prefill_tokens:] = vllm_flash_attn_with_kvcache(
......
...@@ -64,7 +64,7 @@ class PagedAttention: ...@@ -64,7 +64,7 @@ class PagedAttention:
Triton key_cache layout: [num_blocks, num_kv_heads, head_size // x, block_size, x] Triton key_cache layout: [num_blocks, num_kv_heads, head_size // x, block_size, x]
value_cache layout: [num_blocks, num_kv_heads, head_size, block_size] value_cache layout: [num_blocks, num_kv_heads, head_size, block_size]
''' '''
if envs.VLLM_USE_FLASH_ATTN_BACKEND: if envs.VLLM_USE_FLASH_ATTN_PA:
key_cache = kv_cache[0] key_cache = kv_cache[0]
key_cache = key_cache.view(num_blocks, num_kv_heads, -1, head_size) key_cache = key_cache.view(num_blocks, num_kv_heads, -1, head_size)
value_cache = kv_cache[1] value_cache = kv_cache[1]
...@@ -88,7 +88,7 @@ class PagedAttention: ...@@ -88,7 +88,7 @@ class PagedAttention:
k_scale: torch.Tensor, k_scale: torch.Tensor,
v_scale: torch.Tensor, v_scale: torch.Tensor,
) -> None: ) -> None:
if envs.VLLM_USE_FLASH_ATTN_BACKEND: if envs.VLLM_USE_FLASH_ATTN_PA:
ops.reshape_and_cache_cuda( ops.reshape_and_cache_cuda(
key, key,
value, value,
......
...@@ -150,7 +150,7 @@ if TYPE_CHECKING: ...@@ -150,7 +150,7 @@ if TYPE_CHECKING:
VLLM_TBO_DECODE_BS: int = 0 VLLM_TBO_DECODE_BS: int = 0
VLLM_ZERO_OVERHEAD: bool = False VLLM_ZERO_OVERHEAD: bool = False
VLLM_ENABLE_MOE_FUSED_GATE: bool = False VLLM_ENABLE_MOE_FUSED_GATE: bool = False
VLLM_USE_FLASH_ATTN_BACKEND: bool = False VLLM_USE_FLASH_ATTN_PA: bool = False
def get_default_cache_root(): def get_default_cache_root():
...@@ -992,9 +992,9 @@ environment_variables: dict[str, Callable[[], Any]] = { ...@@ -992,9 +992,9 @@ environment_variables: dict[str, Callable[[], Any]] = {
"VLLM_ENABLE_MOE_FUSED_GATE": "VLLM_ENABLE_MOE_FUSED_GATE":
lambda: bool(int(os.getenv("VLLM_ENABLE_MOE_FUSED_GATE", "1"))), lambda: bool(int(os.getenv("VLLM_ENABLE_MOE_FUSED_GATE", "1"))),
# vLLM will use FlashAttention Backend for attention computation on rocm # vLLM will use FlashAttention Backend for page attention computation on rocm
"VLLM_USE_FLASH_ATTN_BACKEND": "VLLM_USE_FLASH_ATTN_PA":
lambda: (os.environ.get("VLLM_USE_FLASH_ATTN_BACKEND", "False").lower() in lambda: (os.environ.get("VLLM_USE_FLASH_ATTN_PA", "False").lower() in
("true", "1")), ("true", "1")),
} }
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment