"vscode:/vscode.git/clone" did not exist on "b4a253fc20071d58fa0cd43ab90cc76bde71c0fa"
Commit 77a29416 authored by zhuwenwen's avatar zhuwenwen
Browse files

update VLLM_FLASH_ATTN_V1 to VLLM_USE_FLASH_ATTN_PA

parent 687b0ad7
......@@ -103,7 +103,7 @@ class Attention(nn.Module):
calculate_kv_scales = cache_config.calculate_kv_scales
else:
kv_cache_dtype = "auto"
block_size = 16 if not envs.VLLM_USE_FLASH_ATTN_PA or not envs.VLLM_FLASH_ATTN_V1 else 64
block_size = 16 if not envs.VLLM_USE_FLASH_ATTN_PA else 64
is_attention_free = False
calculate_kv_scales = False
if num_kv_heads is None:
......@@ -335,7 +335,7 @@ class MultiHeadAttention(nn.Module):
attn_backend = get_attn_backend(head_size,
dtype,
kv_cache_dtype=None,
block_size=16 if not envs.VLLM_USE_FLASH_ATTN_PA or not envs.VLLM_FLASH_ATTN_V1 else 64,
block_size=16 if not envs.VLLM_USE_FLASH_ATTN_PA else 64,
is_attention_free=False)
backend = backend_name_to_enum(attn_backend.get_name())
if current_platform.is_rocm():
......
......@@ -1630,7 +1630,7 @@ PrefixCachingHashAlgo = Literal["builtin", "sha256", "sha256_cbor_64bit"]
class CacheConfig:
"""Configuration for the KV cache."""
block_size: BlockSize = 16 if not envs.VLLM_USE_FLASH_ATTN_PA or not envs.VLLM_FLASH_ATTN_V1 else 64 # type: ignore
block_size: BlockSize = 16 if not envs.VLLM_USE_FLASH_ATTN_PA else 64 # type: ignore
"""Size of a contiguous cache block in number of tokens. This is ignored on
neuron devices and set to `--max-model-len`. On CUDA devices, only block
sizes up to 32 are supported. On HPU devices, block size defaults to 128.
......
......@@ -159,7 +159,6 @@ if TYPE_CHECKING:
VLLM_CUSTOM_ALLREDUCE_SUPPORTED_WORLDSIZE_MAX: int = 16
VLLM_ENFORCE_EAGER_BS_THRESHOLD: Optional[int] = None
VLLM_HAS_CONTEXT_DEFAULT: bool = False
VLLM_FLASH_ATTN_V1: bool = False
VLLM_USE_NN: bool = False
VLLM_ENABLE_TBO: bool = False
VLLM_TBO_REQ_DELAY_MS: int = 0
......@@ -1078,11 +1077,6 @@ environment_variables: dict[str, Callable[[], Any]] = {
# to restore the default usage.
"VLLM_HAS_CONTEXT_DEFAULT":
lambda: bool(int(os.getenv("VLLM_HAS_CONTEXT_DEFAULT", "0"))),
# If set, vLLM will use FlashAttention Backend for v1 attention computation on rocm
"VLLM_FLASH_ATTN_V1":
lambda: (os.environ.get("VLLM_FLASH_ATTN_V1", "False").lower() in
("true", "1")),
# If set, vLLM will transpose weight to use nn layout
"VLLM_USE_NN":
......
......@@ -276,7 +276,7 @@ class RocmPlatform(Platform):
# logger.info_once("Using Triton backend on V1 engine.")
# return TRITON_ATTN_VLLM_V1
if envs.VLLM_FLASH_ATTN_V1 and block_size == 64:
if envs.is_set("VLLM_USE_FLASH_ATTN_PA") and envs.VLLM_USE_FLASH_ATTN_PA and block_size == 64:
logger.info_once("Using Flash Attention backend on V1 engine. (only supports block size 64)")
return FLASH_ATTN_V1
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment