Unverified commit b7a26050, authored by Srreyansh Sethi, committed by GitHub
Browse files

[Bugfix] Make Attention Backend Auto-Selection Batch-Invariance-Aware (#40193)


Signed-off-by: Srreyansh Sethi <srreyansh.sethi@gmail.com>
Signed-off-by: Matthew Bonanni <mbonanni@redhat.com>
Co-authored-by: Matthew Bonanni <mbonanni@redhat.com>
Co-authored-by: mergify[bot] <37929162+mergify[bot]@users.noreply.github.com>
parent d0009ddb
...@@ -131,16 +131,9 @@ class TrainModel: ...@@ -131,16 +131,9 @@ class TrainModel:
from vllm.model_executor.layers.batch_invariant import ( from vllm.model_executor.layers.batch_invariant import (
init_batch_invariance, init_batch_invariance,
) )
from vllm.platforms import current_platform
from vllm.v1.attention.backends.registry import AttentionBackendEnum
# need to init all env vars for batch invariance which affect nccl ops # need to init all env vars for batch invariance which affect nccl ops
attn_backend = ( init_batch_invariance()
AttentionBackendEnum.TRITON_ATTN
if current_platform.is_rocm()
else AttentionBackendEnum.FLASH_ATTN
)
init_batch_invariance(attn_backend)
self.model = AutoModelForCausalLM.from_pretrained( self.model = AutoModelForCausalLM.from_pretrained(
model_name, dtype=torch.bfloat16 model_name, dtype=torch.bfloat16
......
...@@ -14,7 +14,6 @@ from vllm.triton_utils import tl, triton ...@@ -14,7 +14,6 @@ from vllm.triton_utils import tl, triton
from vllm.utils.mem_utils import get_max_shared_memory_bytes from vllm.utils.mem_utils import get_max_shared_memory_bytes
from vllm.utils.platform_utils import num_compute_units from vllm.utils.platform_utils import num_compute_units
from vllm.utils.torch_utils import is_torch_equal_or_newer from vllm.utils.torch_utils import is_torch_equal_or_newer
from vllm.v1.attention.backends.registry import AttentionBackendEnum
logger = init_logger(__name__) logger = init_logger(__name__)
...@@ -991,40 +990,7 @@ def enable_batch_invariant_mode(): ...@@ -991,40 +990,7 @@ def enable_batch_invariant_mode():
torch.backends.cuda.preferred_blas_library(backend="cublaslt") torch.backends.cuda.preferred_blas_library(backend="cublaslt")
def override_envs_for_invariance( def override_envs_for_invariance():
attention_backend: AttentionBackendEnum | None,
):
decode_invariant_backends = [
AttentionBackendEnum.FLASH_ATTN, # best supported backend
AttentionBackendEnum.TRITON_ATTN,
]
supported_backends = decode_invariant_backends + [
# FlashInfer temporarily disabled due to invariant CTA sizes.
# See FlashInfer issue #2424
# AttentionBackendEnum.FLASHINFER,
AttentionBackendEnum.FLASH_ATTN_MLA,
AttentionBackendEnum.TRITON_MLA,
# Not yet supported MLA backends
# AttentionBackendEnum.FLASHMLA,
# AttentionBackendEnum.FLEX_ATTENTION, # IMA issue
# AttentionBackendEnum.FLASHINFER_MLA, # PR #28967
]
if attention_backend not in supported_backends:
supported_names = [b.name for b in supported_backends]
backend_name = attention_backend.name if attention_backend else None
error = (
"VLLM batch_invariant mode requires an attention backend in "
f"{supported_names}, but got '{backend_name}'. "
"Please use --attention-backend or attention_config to set "
"one of the supported backends before enabling batch_invariant."
)
raise RuntimeError(error)
if attention_backend not in decode_invariant_backends:
warning = (
"You are using a non-decode-invariant form of batch invariance. "
"This will not be invariant between prefill and decode."
)
logger.warning_once(warning)
os.environ["VLLM_ALLREDUCE_USE_SYMM_MEM"] = "0" os.environ["VLLM_ALLREDUCE_USE_SYMM_MEM"] = "0"
os.environ["CUBLAS_WORKSPACE_CONFIG"] = ":4096:8" os.environ["CUBLAS_WORKSPACE_CONFIG"] = ":4096:8"
...@@ -1045,12 +1011,10 @@ def override_envs_for_invariance( ...@@ -1045,12 +1011,10 @@ def override_envs_for_invariance(
os.environ["VLLM_USE_AOT_COMPILE"] = "0" os.environ["VLLM_USE_AOT_COMPILE"] = "0"
def init_batch_invariance( def init_batch_invariance():
attention_backend: AttentionBackendEnum | None,
):
# this will hit all the csrc overrides as well # this will hit all the csrc overrides as well
if envs.VLLM_BATCH_INVARIANT: if envs.VLLM_BATCH_INVARIANT:
override_envs_for_invariance(attention_backend) override_envs_for_invariance()
enable_batch_invariant_mode() enable_batch_invariant_mode()
# Disable TF32 for batch invariance - it causes non-deterministic rounding # Disable TF32 for batch invariance - it causes non-deterministic rounding
......
...@@ -236,6 +236,10 @@ class AttentionBackend(ABC): ...@@ -236,6 +236,10 @@ class AttentionBackend(ABC):
""" """
return False return False
@classmethod
def supports_batch_invariance(cls) -> bool:
return False
@classmethod @classmethod
def supports_attn_type(cls, attn_type: str) -> bool: def supports_attn_type(cls, attn_type: str) -> bool:
"""Check if backend supports a given attention type. """Check if backend supports a given attention type.
...@@ -278,6 +282,7 @@ class AttentionBackend(ABC): ...@@ -278,6 +282,7 @@ class AttentionBackend(ABC):
device_capability: "DeviceCapability", device_capability: "DeviceCapability",
attn_type: str, attn_type: str,
use_non_causal: bool = False, use_non_causal: bool = False,
use_batch_invariant: bool = False,
) -> list[str]: ) -> list[str]:
invalid_reasons = [] invalid_reasons = []
if not cls.supports_head_size(head_size): if not cls.supports_head_size(head_size):
...@@ -312,6 +317,8 @@ class AttentionBackend(ABC): ...@@ -312,6 +317,8 @@ class AttentionBackend(ABC):
invalid_reasons.append(f"attention type {attn_type} not supported") invalid_reasons.append(f"attention type {attn_type} not supported")
if use_non_causal and not cls.supports_non_causal(): if use_non_causal and not cls.supports_non_causal():
invalid_reasons.append("non-causal attention not supported") invalid_reasons.append("non-causal attention not supported")
if use_batch_invariant and not cls.supports_batch_invariance():
invalid_reasons.append("batch invariance not supported")
combination_reason = cls.supports_combination( combination_reason = cls.supports_combination(
head_size, head_size,
dtype, dtype,
......
...@@ -103,6 +103,10 @@ class FlashAttentionBackend(AttentionBackend): ...@@ -103,6 +103,10 @@ class FlashAttentionBackend(AttentionBackend):
def get_name() -> str: def get_name() -> str:
return "FLASH_ATTN" return "FLASH_ATTN"
@classmethod
def supports_batch_invariance(cls) -> bool:
return True
@classmethod @classmethod
def supports_non_causal(cls) -> bool: def supports_non_causal(cls) -> bool:
return True return True
......
...@@ -56,6 +56,10 @@ class FlashAttnMLABackend(MLACommonBackend): ...@@ -56,6 +56,10 @@ class FlashAttnMLABackend(MLACommonBackend):
def get_name() -> str: def get_name() -> str:
return "FLASH_ATTN_MLA" return "FLASH_ATTN_MLA"
@classmethod
def supports_batch_invariance(cls) -> bool:
return True
@staticmethod @staticmethod
def get_builder_cls() -> type["FlashAttnMLAMetadataBuilder"]: def get_builder_cls() -> type["FlashAttnMLAMetadataBuilder"]:
return FlashAttnMLAMetadataBuilder return FlashAttnMLAMetadataBuilder
......
...@@ -55,6 +55,10 @@ class TritonMLABackend(MLACommonBackend): ...@@ -55,6 +55,10 @@ class TritonMLABackend(MLACommonBackend):
def get_name() -> str: def get_name() -> str:
return "TRITON_MLA" return "TRITON_MLA"
@classmethod
def supports_batch_invariance(cls) -> bool:
return True
@staticmethod @staticmethod
def get_impl_cls() -> type["TritonMLAImpl"]: def get_impl_cls() -> type["TritonMLAImpl"]:
return TritonMLAImpl return TritonMLAImpl
......
...@@ -296,6 +296,10 @@ class TritonAttentionBackend(AttentionBackend): ...@@ -296,6 +296,10 @@ class TritonAttentionBackend(AttentionBackend):
def get_name() -> str: def get_name() -> str:
return "TRITON_ATTN" return "TRITON_ATTN"
@classmethod
def supports_batch_invariance(cls) -> bool:
return True
@staticmethod @staticmethod
def get_impl_cls() -> type["TritonAttentionImpl"]: def get_impl_cls() -> type["TritonAttentionImpl"]:
return TritonAttentionImpl return TritonAttentionImpl
......
...@@ -6,6 +6,7 @@ from typing import NamedTuple, cast, get_args ...@@ -6,6 +6,7 @@ from typing import NamedTuple, cast, get_args
import torch import torch
import vllm.envs as envs
from vllm.config.cache import CacheDType from vllm.config.cache import CacheDType
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.utils.import_utils import resolve_obj_by_qualname from vllm.utils.import_utils import resolve_obj_by_qualname
...@@ -30,6 +31,7 @@ class AttentionSelectorConfig(NamedTuple): ...@@ -30,6 +31,7 @@ class AttentionSelectorConfig(NamedTuple):
use_per_head_quant_scales: bool = False use_per_head_quant_scales: bool = False
attn_type: str = AttentionType.DECODER attn_type: str = AttentionType.DECODER
use_non_causal: bool = False use_non_causal: bool = False
use_batch_invariant: bool = False
def __repr__(self): def __repr__(self):
return ( return (
...@@ -43,7 +45,8 @@ class AttentionSelectorConfig(NamedTuple): ...@@ -43,7 +45,8 @@ class AttentionSelectorConfig(NamedTuple):
f"use_mm_prefix={self.use_mm_prefix}, " f"use_mm_prefix={self.use_mm_prefix}, "
f"use_per_head_quant_scales={self.use_per_head_quant_scales}, " f"use_per_head_quant_scales={self.use_per_head_quant_scales}, "
f"attn_type={self.attn_type}, " f"attn_type={self.attn_type}, "
f"use_non_causal={self.use_non_causal})" f"use_non_causal={self.use_non_causal}, "
f"use_batch_invariant={self.use_batch_invariant})"
) )
...@@ -95,6 +98,7 @@ def get_attn_backend( ...@@ -95,6 +98,7 @@ def get_attn_backend(
use_per_head_quant_scales=use_per_head_quant_scales, use_per_head_quant_scales=use_per_head_quant_scales,
attn_type=attn_type or AttentionType.DECODER, attn_type=attn_type or AttentionType.DECODER,
use_non_causal=use_non_causal, use_non_causal=use_non_causal,
use_batch_invariant=envs.VLLM_BATCH_INVARIANT,
) )
return _cached_get_attn_backend( return _cached_get_attn_backend(
...@@ -162,4 +166,9 @@ def _cached_get_mamba_attn_backend( ...@@ -162,4 +166,9 @@ def _cached_get_mamba_attn_backend(
) from e ) from e
mamba_attn_backend = selected_backend.get_class() mamba_attn_backend = selected_backend.get_class()
if envs.VLLM_BATCH_INVARIANT and not mamba_attn_backend.supports_batch_invariance():
raise RuntimeError(
"VLLM batch_invariant mode is not supported for "
f"{mamba_attn_backend.get_name()}."
)
return mamba_attn_backend return mamba_attn_backend
...@@ -1027,11 +1027,10 @@ def init_worker_distributed_environment( ...@@ -1027,11 +1027,10 @@ def init_worker_distributed_environment(
backend: str = "nccl", backend: str = "nccl",
) -> None: ) -> None:
"""Initialize the distributed environment.""" """Initialize the distributed environment."""
attention_config = vllm_config.attention_config
parallel_config = vllm_config.parallel_config parallel_config = vllm_config.parallel_config
from vllm.model_executor.layers.batch_invariant import init_batch_invariance from vllm.model_executor.layers.batch_invariant import init_batch_invariance
init_batch_invariance(attention_config.backend) init_batch_invariance()
override_envs_for_eplb(parallel_config) override_envs_for_eplb(parallel_config)
set_custom_all_reduce(not parallel_config.disable_custom_all_reduce) set_custom_all_reduce(not parallel_config.disable_custom_all_reduce)
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment