Commit b61f7a69 authored by zhuwenwen's avatar zhuwenwen
Browse files

set default block_size and pa

parent 1092a467
......@@ -75,7 +75,7 @@ class Attention(nn.Module):
calculate_kv_scales = cache_config.calculate_kv_scales
else:
kv_cache_dtype = "auto"
block_size = 16
block_size = 64
is_attention_free = False
calculate_kv_scales = False
if num_kv_heads is None:
......@@ -298,7 +298,7 @@ class MultiHeadAttention(nn.Module):
attn_backend = get_attn_backend(head_size,
dtype,
kv_cache_dtype=None,
block_size=16,
block_size=64,
is_attention_free=False)
backend = backend_name_to_enum(attn_backend.get_name())
if backend in {_Backend.FLASH_ATTN, _Backend.FLASH_ATTN_VLLM_V1}:
......
......@@ -995,7 +995,7 @@ environment_variables: dict[str, Callable[[], Any]] = {
# vLLM will use FlashAttention Backend for page attention computation on rocm
"VLLM_USE_FLASH_ATTN_PA":
lambda: (os.environ.get("VLLM_USE_FLASH_ATTN_PA", "False").lower() in
lambda: (os.environ.get("VLLM_USE_FLASH_ATTN_PA", "True").lower() in
("true", "1")),
# vLLM will use apex for rmsnorm
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment