Commit c786e757 (unverified), authored by Lucas Wilkinson, committed by GitHub
Browse files

[Attention] Use FA3 for MLA on Hopper (#12807)


Signed-off-by: Lucas Wilkinson <lwilkinson@neuralmagic.com>
parent cefd56ee
......@@ -14,19 +14,16 @@ from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl,
AttentionMetadataBuilder,
AttentionType)
from vllm.attention.backends.utils import (
PAD_SLOT_ID, CommonAttentionState, compute_slot_mapping,
compute_slot_mapping_start_idx, get_num_prefill_decode_query_kv_tokens,
get_seq_len_block_table_args, is_all_cross_attn_metadata_set,
is_all_encoder_attn_metadata_set, is_block_tables_empty)
from vllm.envs import VLLM_FLASH_ATTN_VERSION
PAD_SLOT_ID, VLLM_FLASH_ATTN_VERSION, CommonAttentionState,
compute_slot_mapping, compute_slot_mapping_start_idx,
get_num_prefill_decode_query_kv_tokens, get_seq_len_block_table_args,
is_all_cross_attn_metadata_set, is_all_encoder_attn_metadata_set,
is_block_tables_empty)
from vllm.logger import init_logger
from vllm.multimodal import MultiModalPlaceholderMap
from vllm.platforms import current_platform
from vllm.utils import async_tensor_h2d, make_tensor_with_pad
from vllm.vllm_flash_attn import (fa_version_unsupported_reason,
flash_attn_varlen_func,
flash_attn_with_kvcache,
is_fa_version_supported)
from vllm.vllm_flash_attn import (flash_attn_varlen_func,
flash_attn_with_kvcache)
if TYPE_CHECKING:
from vllm.worker.model_runner import (ModelInputForGPUBuilder,
......@@ -644,25 +641,6 @@ class FlashAttentionImpl(AttentionImpl):
f"Supported head sizes are: {support_head_sizes}.")
self.attn_type = attn_type
# if hopper default to FA3, otherwise stick to FA2 for now
# TODO(lucas): profile FA3 on ampere to see if it makes sense to
# use FA3 as default for both
if current_platform.get_device_capability()[0] >= 9:
self.fa_version = 3 if is_fa_version_supported(3) else 2
else:
self.fa_version = 2
if VLLM_FLASH_ATTN_VERSION is not None:
assert VLLM_FLASH_ATTN_VERSION in [2, 3]
self.fa_version = VLLM_FLASH_ATTN_VERSION
if not is_fa_version_supported(self.fa_version):
logger.error("Cannot use FA version %d is not supported due to %s",
self.fa_version,
fa_version_unsupported_reason(self.fa_version))
assert is_fa_version_supported(self.fa_version)
def forward(
self,
layer: AttentionLayer,
......@@ -781,7 +759,7 @@ class FlashAttentionImpl(AttentionImpl):
alibi_slopes=alibi_slopes,
softcap=logits_soft_cap,
out=prefill_output,
fa_version=self.fa_version,
fa_version=VLLM_FLASH_ATTN_VERSION,
)
else:
# prefix-enabled attention
......@@ -804,7 +782,7 @@ class FlashAttentionImpl(AttentionImpl):
block_table=prefill_meta.block_tables,
softcap=logits_soft_cap,
out=prefill_output,
fa_version=self.fa_version,
fa_version=VLLM_FLASH_ATTN_VERSION,
)
if decode_meta := attn_metadata.decode_metadata:
......@@ -833,7 +811,7 @@ class FlashAttentionImpl(AttentionImpl):
softcap=logits_soft_cap,
block_table=decode_meta.block_tables,
out=decode_output,
fa_version=self.fa_version,
fa_version=VLLM_FLASH_ATTN_VERSION,
)
else:
# Use flash_attn_with_kvcache for normal decoding.
......@@ -854,7 +832,7 @@ class FlashAttentionImpl(AttentionImpl):
alibi_slopes=alibi_slopes,
softcap=logits_soft_cap,
out=decode_output.unsqueeze(1),
fa_version=self.fa_version,
fa_version=VLLM_FLASH_ATTN_VERSION,
)
return output
......
......@@ -12,6 +12,7 @@ from vllm import envs
from vllm.attention.backends.abstract import (AttentionLayer,
AttentionMetadata,
MLAAttentionImpl, T)
from vllm.attention.backends.utils import VLLM_FLASH_ATTN_VERSION
from vllm.distributed import (get_tensor_model_parallel_world_size,
tensor_model_parallel_all_reduce)
from vllm.model_executor.layers.linear import (ColumnParallelLinear,
......@@ -533,6 +534,7 @@ class MLACommonImpl(MLAAttentionImpl[T], Generic[T]):
max_seqlen_k=max_prefill_seq_len,
softmax_scale=self.scale,
causal=True,
fa_version=VLLM_FLASH_ATTN_VERSION,
)
attn_output = attn_output\
.view(-1, self.num_heads, q.shape[-1])[..., :v.shape[-1]]\
......
......@@ -8,12 +8,17 @@ from typing import TYPE_CHECKING, Any, Dict, List, Tuple, Type, TypeVar, Union
import numpy as np
import torch
from vllm import envs
from vllm.attention import (AttentionMetadata, AttentionMetadataBuilder,
AttentionState)
from vllm.attention.backends.abstract import AttentionType
from vllm.logger import logging
from vllm.multimodal import MultiModalPlaceholderMap
from vllm.platforms import current_platform
from vllm.utils import async_tensor_h2d, make_tensor_with_pad
logger = logging.getLogger(__name__)
if TYPE_CHECKING:
from vllm.worker.model_runner_base import ModelRunnerBase
......@@ -580,3 +585,32 @@ def get_num_prefill_decode_query_kv_tokens(
return (num_prefill_query_tokens, num_prefill_kv_tokens,
num_decode_query_tokens)
try:
    from vllm.vllm_flash_attn.flash_attn_interface import (
        fa_version_unsupported_reason, is_fa_version_supported)

    def flash_attn_version():
        """Pick the FlashAttention major version (2 or 3) to use.

        The default is FA3 when the device capability major is >= 9 and
        FA3 is reported as supported, otherwise FA2. A non-None
        ``envs.VLLM_FLASH_ATTN_VERSION`` overrides that default.

        Returns:
            The selected FlashAttention version, either 2 or 3.

        Raises:
            AssertionError: If the override is not 2 or 3, or if the
                selected version is reported as unsupported. Raised
                explicitly (rather than via ``assert``) so the checks
                still run under ``python -O``.
        """
        # if hopper default to FA3, otherwise stick to FA2 for now
        # TODO(lucas): profile FA3 on ampere to see if it makes sense to
        #  use FA3 as default for both
        if current_platform.get_device_capability()[0] >= 9:
            fa_version = 3 if is_fa_version_supported(3) else 2
        else:
            fa_version = 2

        # An explicit user override always wins over the hardware default.
        requested_version = envs.VLLM_FLASH_ATTN_VERSION
        if requested_version is not None:
            if requested_version not in (2, 3):
                raise AssertionError(
                    "VLLM_FLASH_ATTN_VERSION must be 2 or 3, got "
                    f"{requested_version!r}")
            fa_version = requested_version

        if not is_fa_version_supported(fa_version):
            reason = fa_version_unsupported_reason(fa_version)
            logger.error("Cannot use FA version %d: not supported due to %s",
                         fa_version, reason)
            raise AssertionError(
                f"FA version {fa_version} is not supported due to {reason}")
        return fa_version

    # Resolved once at import time and shared by the attention backends.
    VLLM_FLASH_ATTN_VERSION = flash_attn_version()
except ImportError:
    # vllm_flash_attn is not available in this build; callers receive None.
    VLLM_FLASH_ATTN_VERSION = None
......@@ -10,13 +10,10 @@ import triton.language as tl
from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl,
AttentionMetadata, AttentionType)
from vllm.envs import VLLM_FLASH_ATTN_VERSION
from vllm.attention.backends.utils import VLLM_FLASH_ATTN_VERSION
from vllm.logger import init_logger
from vllm.platforms import current_platform
from vllm.utils import cdiv
from vllm.vllm_flash_attn import (fa_version_unsupported_reason,
flash_attn_varlen_func,
is_fa_version_supported)
from vllm.vllm_flash_attn import flash_attn_varlen_func
logger = init_logger(__name__)
......@@ -136,25 +133,6 @@ class FlashAttentionImpl(AttentionImpl):
"are not implemented for "
"FlashAttentionImpl")
# if hopper default to FA3, otherwise stick to FA2 for now
# TODO(lucas): profile FA3 on ampere to see if it makes sense to
# use FA3 as default for both
if current_platform.get_device_capability()[0] >= 9:
self.fa_version = 3 if is_fa_version_supported(3) else 2
else:
self.fa_version = 2
if VLLM_FLASH_ATTN_VERSION is not None:
assert VLLM_FLASH_ATTN_VERSION in [2, 3]
self.fa_version = VLLM_FLASH_ATTN_VERSION
if not is_fa_version_supported(self.fa_version):
logger.error("Cannot use FA version %d is not supported due to %s",
self.fa_version,
fa_version_unsupported_reason(self.fa_version))
assert is_fa_version_supported(self.fa_version)
def forward(
self,
layer: torch.nn.Module,
......@@ -227,7 +205,7 @@ class FlashAttentionImpl(AttentionImpl):
window_size=self.sliding_window,
block_table=attn_metadata.block_table,
softcap=self.logits_soft_cap,
fa_version=self.fa_version,
fa_version=VLLM_FLASH_ATTN_VERSION,
)
return output
......@@ -249,7 +227,7 @@ class FlashAttentionImpl(AttentionImpl):
logits_soft_cap=self.logits_soft_cap,
block_table=attn_metadata.block_table,
common_prefix_len=attn_metadata.common_prefix_len,
fa_version=self.fa_version,
fa_version=VLLM_FLASH_ATTN_VERSION,
)
return output
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment