Commit b61f7a69 authored by zhuwenwen
Browse files

set default block_size and pa

parent 1092a467
...@@ -75,7 +75,7 @@ class Attention(nn.Module): ...@@ -75,7 +75,7 @@ class Attention(nn.Module):
calculate_kv_scales = cache_config.calculate_kv_scales calculate_kv_scales = cache_config.calculate_kv_scales
else: else:
kv_cache_dtype = "auto" kv_cache_dtype = "auto"
block_size = 16 block_size = 64
is_attention_free = False is_attention_free = False
calculate_kv_scales = False calculate_kv_scales = False
if num_kv_heads is None: if num_kv_heads is None:
...@@ -298,7 +298,7 @@ class MultiHeadAttention(nn.Module): ...@@ -298,7 +298,7 @@ class MultiHeadAttention(nn.Module):
attn_backend = get_attn_backend(head_size, attn_backend = get_attn_backend(head_size,
dtype, dtype,
kv_cache_dtype=None, kv_cache_dtype=None,
block_size=16, block_size=64,
is_attention_free=False) is_attention_free=False)
backend = backend_name_to_enum(attn_backend.get_name()) backend = backend_name_to_enum(attn_backend.get_name())
if backend in {_Backend.FLASH_ATTN, _Backend.FLASH_ATTN_VLLM_V1}: if backend in {_Backend.FLASH_ATTN, _Backend.FLASH_ATTN_VLLM_V1}:
......
...@@ -995,7 +995,7 @@ environment_variables: dict[str, Callable[[], Any]] = { ...@@ -995,7 +995,7 @@ environment_variables: dict[str, Callable[[], Any]] = {
# vLLM will use FlashAttention Backend for page attention computation on rocm # vLLM will use FlashAttention Backend for page attention computation on rocm
"VLLM_USE_FLASH_ATTN_PA": "VLLM_USE_FLASH_ATTN_PA":
lambda: (os.environ.get("VLLM_USE_FLASH_ATTN_PA", "False").lower() in lambda: (os.environ.get("VLLM_USE_FLASH_ATTN_PA", "True").lower() in
("true", "1")), ("true", "1")),
# vLLM will use apex for rmsnorm # vLLM will use apex for rmsnorm
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment