Commit 77a29416 authored by zhuwenwen's avatar zhuwenwen
Browse files

update VLLM_FLASH_ATTN_V1 to VLLM_USE_FLASH_ATTN_PA

parent 687b0ad7
...@@ -103,7 +103,7 @@ class Attention(nn.Module): ...@@ -103,7 +103,7 @@ class Attention(nn.Module):
calculate_kv_scales = cache_config.calculate_kv_scales calculate_kv_scales = cache_config.calculate_kv_scales
else: else:
kv_cache_dtype = "auto" kv_cache_dtype = "auto"
block_size = 16 if not envs.VLLM_USE_FLASH_ATTN_PA or not envs.VLLM_FLASH_ATTN_V1 else 64 block_size = 16 if not envs.VLLM_USE_FLASH_ATTN_PA else 64
is_attention_free = False is_attention_free = False
calculate_kv_scales = False calculate_kv_scales = False
if num_kv_heads is None: if num_kv_heads is None:
...@@ -335,7 +335,7 @@ class MultiHeadAttention(nn.Module): ...@@ -335,7 +335,7 @@ class MultiHeadAttention(nn.Module):
attn_backend = get_attn_backend(head_size, attn_backend = get_attn_backend(head_size,
dtype, dtype,
kv_cache_dtype=None, kv_cache_dtype=None,
block_size=16 if not envs.VLLM_USE_FLASH_ATTN_PA or not envs.VLLM_FLASH_ATTN_V1 else 64, block_size=16 if not envs.VLLM_USE_FLASH_ATTN_PA else 64,
is_attention_free=False) is_attention_free=False)
backend = backend_name_to_enum(attn_backend.get_name()) backend = backend_name_to_enum(attn_backend.get_name())
if current_platform.is_rocm(): if current_platform.is_rocm():
......
...@@ -1630,7 +1630,7 @@ PrefixCachingHashAlgo = Literal["builtin", "sha256", "sha256_cbor_64bit"] ...@@ -1630,7 +1630,7 @@ PrefixCachingHashAlgo = Literal["builtin", "sha256", "sha256_cbor_64bit"]
class CacheConfig: class CacheConfig:
"""Configuration for the KV cache.""" """Configuration for the KV cache."""
block_size: BlockSize = 16 if not envs.VLLM_USE_FLASH_ATTN_PA or not envs.VLLM_FLASH_ATTN_V1 else 64 # type: ignore block_size: BlockSize = 16 if not envs.VLLM_USE_FLASH_ATTN_PA else 64 # type: ignore
"""Size of a contiguous cache block in number of tokens. This is ignored on """Size of a contiguous cache block in number of tokens. This is ignored on
neuron devices and set to `--max-model-len`. On CUDA devices, only block neuron devices and set to `--max-model-len`. On CUDA devices, only block
sizes up to 32 are supported. On HPU devices, block size defaults to 128. sizes up to 32 are supported. On HPU devices, block size defaults to 128.
......
...@@ -159,7 +159,6 @@ if TYPE_CHECKING: ...@@ -159,7 +159,6 @@ if TYPE_CHECKING:
VLLM_CUSTOM_ALLREDUCE_SUPPORTED_WORLDSIZE_MAX: int = 16 VLLM_CUSTOM_ALLREDUCE_SUPPORTED_WORLDSIZE_MAX: int = 16
VLLM_ENFORCE_EAGER_BS_THRESHOLD: Optional[int] = None VLLM_ENFORCE_EAGER_BS_THRESHOLD: Optional[int] = None
VLLM_HAS_CONTEXT_DEFAULT: bool = False VLLM_HAS_CONTEXT_DEFAULT: bool = False
VLLM_FLASH_ATTN_V1: bool = False
VLLM_USE_NN: bool = False VLLM_USE_NN: bool = False
VLLM_ENABLE_TBO: bool = False VLLM_ENABLE_TBO: bool = False
VLLM_TBO_REQ_DELAY_MS: int = 0 VLLM_TBO_REQ_DELAY_MS: int = 0
...@@ -1078,11 +1077,6 @@ environment_variables: dict[str, Callable[[], Any]] = { ...@@ -1078,11 +1077,6 @@ environment_variables: dict[str, Callable[[], Any]] = {
# to restore the default usage. # to restore the default usage.
"VLLM_HAS_CONTEXT_DEFAULT": "VLLM_HAS_CONTEXT_DEFAULT":
lambda: bool(int(os.getenv("VLLM_HAS_CONTEXT_DEFAULT", "0"))), lambda: bool(int(os.getenv("VLLM_HAS_CONTEXT_DEFAULT", "0"))),
# If set, vLLM will use FlashAttention Backend for v1 attention computation on rocm
"VLLM_FLASH_ATTN_V1":
lambda: (os.environ.get("VLLM_FLASH_ATTN_V1", "False").lower() in
("true", "1")),
# If set, vLLM will transpose weight to use nn layout # If set, vLLM will transpose weight to use nn layout
"VLLM_USE_NN": "VLLM_USE_NN":
......
...@@ -276,7 +276,7 @@ class RocmPlatform(Platform): ...@@ -276,7 +276,7 @@ class RocmPlatform(Platform):
# logger.info_once("Using Triton backend on V1 engine.") # logger.info_once("Using Triton backend on V1 engine.")
# return TRITON_ATTN_VLLM_V1 # return TRITON_ATTN_VLLM_V1
if envs.VLLM_FLASH_ATTN_V1 and block_size == 64: if envs.is_set("VLLM_USE_FLASH_ATTN_PA") and envs.VLLM_USE_FLASH_ATTN_PA and block_size == 64:
logger.info_once("Using Flash Attention backend on V1 engine. (only supports block size 64)") logger.info_once("Using Flash Attention backend on V1 engine. (only supports block size 64)")
return FLASH_ATTN_V1 return FLASH_ATTN_V1
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment