Unverified Commit c786e757 authored by Lucas Wilkinson's avatar Lucas Wilkinson Committed by GitHub
Browse files

[Attention] Use FA3 for MLA on Hopper (#12807)


Signed-off-by: default avatarLucas Wilkinson <lwilkinson@neuralmagic.com>
parent cefd56ee
...@@ -14,19 +14,16 @@ from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl, ...@@ -14,19 +14,16 @@ from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl,
AttentionMetadataBuilder, AttentionMetadataBuilder,
AttentionType) AttentionType)
from vllm.attention.backends.utils import ( from vllm.attention.backends.utils import (
PAD_SLOT_ID, CommonAttentionState, compute_slot_mapping, PAD_SLOT_ID, VLLM_FLASH_ATTN_VERSION, CommonAttentionState,
compute_slot_mapping_start_idx, get_num_prefill_decode_query_kv_tokens, compute_slot_mapping, compute_slot_mapping_start_idx,
get_seq_len_block_table_args, is_all_cross_attn_metadata_set, get_num_prefill_decode_query_kv_tokens, get_seq_len_block_table_args,
is_all_encoder_attn_metadata_set, is_block_tables_empty) is_all_cross_attn_metadata_set, is_all_encoder_attn_metadata_set,
from vllm.envs import VLLM_FLASH_ATTN_VERSION is_block_tables_empty)
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.multimodal import MultiModalPlaceholderMap from vllm.multimodal import MultiModalPlaceholderMap
from vllm.platforms import current_platform
from vllm.utils import async_tensor_h2d, make_tensor_with_pad from vllm.utils import async_tensor_h2d, make_tensor_with_pad
from vllm.vllm_flash_attn import (fa_version_unsupported_reason, from vllm.vllm_flash_attn import (flash_attn_varlen_func,
flash_attn_varlen_func, flash_attn_with_kvcache)
flash_attn_with_kvcache,
is_fa_version_supported)
if TYPE_CHECKING: if TYPE_CHECKING:
from vllm.worker.model_runner import (ModelInputForGPUBuilder, from vllm.worker.model_runner import (ModelInputForGPUBuilder,
...@@ -644,25 +641,6 @@ class FlashAttentionImpl(AttentionImpl): ...@@ -644,25 +641,6 @@ class FlashAttentionImpl(AttentionImpl):
f"Supported head sizes are: {support_head_sizes}.") f"Supported head sizes are: {support_head_sizes}.")
self.attn_type = attn_type self.attn_type = attn_type
# if hopper default to FA3, otherwise stick to FA2 for now
# TODO(lucas): profile FA3 on ampere to see if it makes sense to
# use FA3 as default for both
if current_platform.get_device_capability()[0] >= 9:
self.fa_version = 3 if is_fa_version_supported(3) else 2
else:
self.fa_version = 2
if VLLM_FLASH_ATTN_VERSION is not None:
assert VLLM_FLASH_ATTN_VERSION in [2, 3]
self.fa_version = VLLM_FLASH_ATTN_VERSION
if not is_fa_version_supported(self.fa_version):
logger.error("Cannot use FA version %d is not supported due to %s",
self.fa_version,
fa_version_unsupported_reason(self.fa_version))
assert is_fa_version_supported(self.fa_version)
def forward( def forward(
self, self,
layer: AttentionLayer, layer: AttentionLayer,
...@@ -781,7 +759,7 @@ class FlashAttentionImpl(AttentionImpl): ...@@ -781,7 +759,7 @@ class FlashAttentionImpl(AttentionImpl):
alibi_slopes=alibi_slopes, alibi_slopes=alibi_slopes,
softcap=logits_soft_cap, softcap=logits_soft_cap,
out=prefill_output, out=prefill_output,
fa_version=self.fa_version, fa_version=VLLM_FLASH_ATTN_VERSION,
) )
else: else:
# prefix-enabled attention # prefix-enabled attention
...@@ -804,7 +782,7 @@ class FlashAttentionImpl(AttentionImpl): ...@@ -804,7 +782,7 @@ class FlashAttentionImpl(AttentionImpl):
block_table=prefill_meta.block_tables, block_table=prefill_meta.block_tables,
softcap=logits_soft_cap, softcap=logits_soft_cap,
out=prefill_output, out=prefill_output,
fa_version=self.fa_version, fa_version=VLLM_FLASH_ATTN_VERSION,
) )
if decode_meta := attn_metadata.decode_metadata: if decode_meta := attn_metadata.decode_metadata:
...@@ -833,7 +811,7 @@ class FlashAttentionImpl(AttentionImpl): ...@@ -833,7 +811,7 @@ class FlashAttentionImpl(AttentionImpl):
softcap=logits_soft_cap, softcap=logits_soft_cap,
block_table=decode_meta.block_tables, block_table=decode_meta.block_tables,
out=decode_output, out=decode_output,
fa_version=self.fa_version, fa_version=VLLM_FLASH_ATTN_VERSION,
) )
else: else:
# Use flash_attn_with_kvcache for normal decoding. # Use flash_attn_with_kvcache for normal decoding.
...@@ -854,7 +832,7 @@ class FlashAttentionImpl(AttentionImpl): ...@@ -854,7 +832,7 @@ class FlashAttentionImpl(AttentionImpl):
alibi_slopes=alibi_slopes, alibi_slopes=alibi_slopes,
softcap=logits_soft_cap, softcap=logits_soft_cap,
out=decode_output.unsqueeze(1), out=decode_output.unsqueeze(1),
fa_version=self.fa_version, fa_version=VLLM_FLASH_ATTN_VERSION,
) )
return output return output
......
...@@ -12,6 +12,7 @@ from vllm import envs ...@@ -12,6 +12,7 @@ from vllm import envs
from vllm.attention.backends.abstract import (AttentionLayer, from vllm.attention.backends.abstract import (AttentionLayer,
AttentionMetadata, AttentionMetadata,
MLAAttentionImpl, T) MLAAttentionImpl, T)
from vllm.attention.backends.utils import VLLM_FLASH_ATTN_VERSION
from vllm.distributed import (get_tensor_model_parallel_world_size, from vllm.distributed import (get_tensor_model_parallel_world_size,
tensor_model_parallel_all_reduce) tensor_model_parallel_all_reduce)
from vllm.model_executor.layers.linear import (ColumnParallelLinear, from vllm.model_executor.layers.linear import (ColumnParallelLinear,
...@@ -533,6 +534,7 @@ class MLACommonImpl(MLAAttentionImpl[T], Generic[T]): ...@@ -533,6 +534,7 @@ class MLACommonImpl(MLAAttentionImpl[T], Generic[T]):
max_seqlen_k=max_prefill_seq_len, max_seqlen_k=max_prefill_seq_len,
softmax_scale=self.scale, softmax_scale=self.scale,
causal=True, causal=True,
fa_version=VLLM_FLASH_ATTN_VERSION,
) )
attn_output = attn_output\ attn_output = attn_output\
.view(-1, self.num_heads, q.shape[-1])[..., :v.shape[-1]]\ .view(-1, self.num_heads, q.shape[-1])[..., :v.shape[-1]]\
......
...@@ -8,12 +8,17 @@ from typing import TYPE_CHECKING, Any, Dict, List, Tuple, Type, TypeVar, Union ...@@ -8,12 +8,17 @@ from typing import TYPE_CHECKING, Any, Dict, List, Tuple, Type, TypeVar, Union
import numpy as np import numpy as np
import torch import torch
from vllm import envs
from vllm.attention import (AttentionMetadata, AttentionMetadataBuilder, from vllm.attention import (AttentionMetadata, AttentionMetadataBuilder,
AttentionState) AttentionState)
from vllm.attention.backends.abstract import AttentionType from vllm.attention.backends.abstract import AttentionType
from vllm.logger import logging
from vllm.multimodal import MultiModalPlaceholderMap from vllm.multimodal import MultiModalPlaceholderMap
from vllm.platforms import current_platform
from vllm.utils import async_tensor_h2d, make_tensor_with_pad from vllm.utils import async_tensor_h2d, make_tensor_with_pad
logger = logging.getLogger(__name__)
if TYPE_CHECKING: if TYPE_CHECKING:
from vllm.worker.model_runner_base import ModelRunnerBase from vllm.worker.model_runner_base import ModelRunnerBase
...@@ -580,3 +585,32 @@ def get_num_prefill_decode_query_kv_tokens( ...@@ -580,3 +585,32 @@ def get_num_prefill_decode_query_kv_tokens(
return (num_prefill_query_tokens, num_prefill_kv_tokens, return (num_prefill_query_tokens, num_prefill_kv_tokens,
num_decode_query_tokens) num_decode_query_tokens)
try:
from vllm.vllm_flash_attn.flash_attn_interface import (
fa_version_unsupported_reason, is_fa_version_supported)
def flash_attn_version():
# if hopper default to FA3, otherwise stick to FA2 for now
# TODO(lucas): profile FA3 on ampere to see if it makes sense to
# use FA3 as default for both
if current_platform.get_device_capability()[0] >= 9:
fa_version = 3 if is_fa_version_supported(3) else 2
else:
fa_version = 2
if envs.VLLM_FLASH_ATTN_VERSION is not None:
assert envs.VLLM_FLASH_ATTN_VERSION in [2, 3]
fa_version = envs.VLLM_FLASH_ATTN_VERSION
if not is_fa_version_supported(fa_version):
logger.error("Cannot use FA version %d is not supported due to %s",
fa_version, fa_version_unsupported_reason(fa_version))
assert is_fa_version_supported(fa_version)
return fa_version
VLLM_FLASH_ATTN_VERSION = flash_attn_version()
except ImportError:
VLLM_FLASH_ATTN_VERSION = None
...@@ -10,13 +10,10 @@ import triton.language as tl ...@@ -10,13 +10,10 @@ import triton.language as tl
from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl, from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl,
AttentionMetadata, AttentionType) AttentionMetadata, AttentionType)
from vllm.envs import VLLM_FLASH_ATTN_VERSION from vllm.attention.backends.utils import VLLM_FLASH_ATTN_VERSION
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.platforms import current_platform
from vllm.utils import cdiv from vllm.utils import cdiv
from vllm.vllm_flash_attn import (fa_version_unsupported_reason, from vllm.vllm_flash_attn import flash_attn_varlen_func
flash_attn_varlen_func,
is_fa_version_supported)
logger = init_logger(__name__) logger = init_logger(__name__)
...@@ -136,25 +133,6 @@ class FlashAttentionImpl(AttentionImpl): ...@@ -136,25 +133,6 @@ class FlashAttentionImpl(AttentionImpl):
"are not implemented for " "are not implemented for "
"FlashAttentionImpl") "FlashAttentionImpl")
# if hopper default to FA3, otherwise stick to FA2 for now
# TODO(lucas): profile FA3 on ampere to see if it makes sense to
# use FA3 as default for both
if current_platform.get_device_capability()[0] >= 9:
self.fa_version = 3 if is_fa_version_supported(3) else 2
else:
self.fa_version = 2
if VLLM_FLASH_ATTN_VERSION is not None:
assert VLLM_FLASH_ATTN_VERSION in [2, 3]
self.fa_version = VLLM_FLASH_ATTN_VERSION
if not is_fa_version_supported(self.fa_version):
logger.error("Cannot use FA version %d is not supported due to %s",
self.fa_version,
fa_version_unsupported_reason(self.fa_version))
assert is_fa_version_supported(self.fa_version)
def forward( def forward(
self, self,
layer: torch.nn.Module, layer: torch.nn.Module,
...@@ -227,7 +205,7 @@ class FlashAttentionImpl(AttentionImpl): ...@@ -227,7 +205,7 @@ class FlashAttentionImpl(AttentionImpl):
window_size=self.sliding_window, window_size=self.sliding_window,
block_table=attn_metadata.block_table, block_table=attn_metadata.block_table,
softcap=self.logits_soft_cap, softcap=self.logits_soft_cap,
fa_version=self.fa_version, fa_version=VLLM_FLASH_ATTN_VERSION,
) )
return output return output
...@@ -249,7 +227,7 @@ class FlashAttentionImpl(AttentionImpl): ...@@ -249,7 +227,7 @@ class FlashAttentionImpl(AttentionImpl):
logits_soft_cap=self.logits_soft_cap, logits_soft_cap=self.logits_soft_cap,
block_table=attn_metadata.block_table, block_table=attn_metadata.block_table,
common_prefix_len=attn_metadata.common_prefix_len, common_prefix_len=attn_metadata.common_prefix_len,
fa_version=self.fa_version, fa_version=VLLM_FLASH_ATTN_VERSION,
) )
return output return output
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment