Unverified Commit dc6908ac authored by Ranran, committed by GitHub
Browse files

[Bugfix] Register VLLM_BATCH_INVARIANT in envs.py to fix spurious unknown env var warning (#35007)


Signed-off-by: Ranran <1012869439@qq.com>
Signed-off-by: Ranran <hzz5361@psu.edu>
Signed-off-by: ran <hzz5361@psu.edu>
Co-authored-by: Wentao Ye <44945378+yewentao256@users.noreply.github.com>
parent e85f8f09
...@@ -55,37 +55,37 @@ def _clear_supports_cache(): ...@@ -55,37 +55,37 @@ def _clear_supports_cache():
# supports_trtllm_attention # supports_trtllm_attention
@patch("vllm.utils.flashinfer.vllm_is_batch_invariant", return_value=True) @patch("vllm.envs.VLLM_BATCH_INVARIANT", True)
def test_supports_batch_invariant_disables(_mock): def test_supports_batch_invariant_disables():
assert supports_trtllm_attention() is False assert supports_trtllm_attention() is False
@patch("vllm.utils.flashinfer.vllm_is_batch_invariant", return_value=False) @patch("vllm.envs.VLLM_BATCH_INVARIANT", False)
@patch( @patch(
"vllm.utils.flashinfer.current_platform.is_device_capability_family", "vllm.utils.flashinfer.current_platform.is_device_capability_family",
return_value=True, return_value=True,
) )
@patch("vllm.utils.flashinfer.has_nvidia_artifactory", return_value=True) @patch("vllm.utils.flashinfer.has_nvidia_artifactory", return_value=True)
def test_supports_sm100_with_artifactory(_art, _cap, _bi): def test_supports_sm100_with_artifactory(_art, _cap):
assert supports_trtllm_attention() is True assert supports_trtllm_attention() is True
@patch("vllm.utils.flashinfer.vllm_is_batch_invariant", return_value=False) @patch("vllm.envs.VLLM_BATCH_INVARIANT", False)
@patch( @patch(
"vllm.utils.flashinfer.current_platform.is_device_capability_family", "vllm.utils.flashinfer.current_platform.is_device_capability_family",
return_value=False, return_value=False,
) )
def test_supports_non_sm100_platform(_cap, _bi): def test_supports_non_sm100_platform(_cap):
assert supports_trtllm_attention() is False assert supports_trtllm_attention() is False
@patch("vllm.utils.flashinfer.vllm_is_batch_invariant", return_value=False) @patch("vllm.envs.VLLM_BATCH_INVARIANT", False)
@patch( @patch(
"vllm.utils.flashinfer.current_platform.is_device_capability_family", "vllm.utils.flashinfer.current_platform.is_device_capability_family",
return_value=True, return_value=True,
) )
@patch("vllm.utils.flashinfer.has_nvidia_artifactory", return_value=False) @patch("vllm.utils.flashinfer.has_nvidia_artifactory", return_value=False)
def test_supports_sm100_without_artifactory(_art, _cap, _bi): def test_supports_sm100_without_artifactory(_art, _cap):
assert supports_trtllm_attention() is False assert supports_trtllm_attention() is False
......
...@@ -8,7 +8,7 @@ Run `pytest tests/kernels/moe/test_grouped_topk.py`. ...@@ -8,7 +8,7 @@ Run `pytest tests/kernels/moe/test_grouped_topk.py`.
import pytest import pytest
import torch import torch
import vllm.model_executor.layers.batch_invariant as batch_invariant import vllm.envs as envs
from vllm.config import ( from vllm.config import (
CompilationConfig, CompilationConfig,
VllmConfig, VllmConfig,
...@@ -69,7 +69,7 @@ def test_grouped_topk( ...@@ -69,7 +69,7 @@ def test_grouped_topk(
with set_current_vllm_config(vllm_config), monkeypatch.context() as m: with set_current_vllm_config(vllm_config), monkeypatch.context() as m:
m.setenv("VLLM_USE_FUSED_MOE_GROUPED_TOPK", "0") m.setenv("VLLM_USE_FUSED_MOE_GROUPED_TOPK", "0")
m.setattr(batch_invariant, "VLLM_BATCH_INVARIANT", True) m.setattr(envs, "VLLM_BATCH_INVARIANT", True)
grouped_topk = GroupedTopk( grouped_topk = GroupedTopk(
topk=topk, topk=topk,
renormalize=renormalize, renormalize=renormalize,
......
...@@ -2,11 +2,11 @@ ...@@ -2,11 +2,11 @@
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import pytest import pytest
import vllm.model_executor.layers.batch_invariant as batch_invariant import vllm.envs as envs
@pytest.fixture(autouse=True) @pytest.fixture(autouse=True)
def enable_batch_invariant_mode(monkeypatch: pytest.MonkeyPatch): def enable_batch_invariant_mode(monkeypatch: pytest.MonkeyPatch):
"""Automatically enable batch invariant kernel overrides for all tests.""" """Automatically enable batch invariant kernel overrides for all tests."""
monkeypatch.setattr(batch_invariant, "VLLM_BATCH_INVARIANT", True) monkeypatch.setattr(envs, "VLLM_BATCH_INVARIANT", True)
monkeypatch.setenv("VLLM_BATCH_INVARIANT", "1") monkeypatch.setenv("VLLM_BATCH_INVARIANT", "1")
...@@ -15,7 +15,7 @@ from utils import ( ...@@ -15,7 +15,7 @@ from utils import (
skip_unsupported, skip_unsupported,
) )
import vllm.model_executor.layers.batch_invariant as batch_invariant import vllm.envs as envs
from vllm import LLM, SamplingParams from vllm import LLM, SamplingParams
IS_DEVICE_CAPABILITY_BELOW_90 = is_device_capability_below_90() IS_DEVICE_CAPABILITY_BELOW_90 = is_device_capability_below_90()
...@@ -173,11 +173,9 @@ def test_logprobs_bitwise_batch_invariance_bs1_vs_bsN( ...@@ -173,11 +173,9 @@ def test_logprobs_bitwise_batch_invariance_bs1_vs_bsN(
# For batch invariance, disable custom all-reduce to ensure deterministic # For batch invariance, disable custom all-reduce to ensure deterministic
# all-reduce operations (custom all-reduce may not be deterministic) # all-reduce operations (custom all-reduce may not be deterministic)
from vllm.model_executor.layers.batch_invariant import ( import vllm.envs as envs
vllm_is_batch_invariant,
)
disable_custom_ar = vllm_is_batch_invariant() disable_custom_ar = envs.VLLM_BATCH_INVARIANT
if disable_custom_ar: if disable_custom_ar:
print(f"\n{'=' * 80}") print(f"\n{'=' * 80}")
...@@ -454,7 +452,7 @@ def test_logprobs_without_batch_invariance_should_fail( ...@@ -454,7 +452,7 @@ def test_logprobs_without_batch_invariance_should_fail(
""" """
# CRITICAL: Disable batch invariance for this test # CRITICAL: Disable batch invariance for this test
monkeypatch.setenv("VLLM_BATCH_INVARIANT", "0") monkeypatch.setenv("VLLM_BATCH_INVARIANT", "0")
monkeypatch.setattr(batch_invariant, "VLLM_BATCH_INVARIANT", False) monkeypatch.setattr(envs, "VLLM_BATCH_INVARIANT", False)
seed = int(os.getenv("VLLM_TEST_SEED", "12345")) seed = int(os.getenv("VLLM_TEST_SEED", "12345"))
random.seed(seed) random.seed(seed)
tp_size = int(os.getenv("VLLM_TEST_TP_SIZE", "1")) tp_size = int(os.getenv("VLLM_TEST_TP_SIZE", "1"))
...@@ -674,11 +672,9 @@ def test_decode_logprobs_match_prefill_logprobs( ...@@ -674,11 +672,9 @@ def test_decode_logprobs_match_prefill_logprobs(
random.seed(seed) random.seed(seed)
tp_size = int(os.getenv("VLLM_TEST_TP_SIZE", "1")) tp_size = int(os.getenv("VLLM_TEST_TP_SIZE", "1"))
from vllm.model_executor.layers.batch_invariant import ( import vllm.envs as envs
vllm_is_batch_invariant,
)
disable_custom_ar = vllm_is_batch_invariant() disable_custom_ar = envs.VLLM_BATCH_INVARIANT
if disable_custom_ar: if disable_custom_ar:
print(f"\n{'=' * 80}") print(f"\n{'=' * 80}")
......
...@@ -14,9 +14,6 @@ from typing_extensions import Self ...@@ -14,9 +14,6 @@ from typing_extensions import Self
import vllm.envs as envs import vllm.envs as envs
from vllm.config.utils import config from vllm.config.utils import config
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.model_executor.layers.batch_invariant import (
vllm_is_batch_invariant,
)
from vllm.platforms import current_platform from vllm.platforms import current_platform
from vllm.utils.network_utils import get_open_ports_list from vllm.utils.network_utils import get_open_ports_list
from vllm.utils.torch_utils import cuda_device_count_stateless from vllm.utils.torch_utils import cuda_device_count_stateless
...@@ -786,7 +783,7 @@ class ParallelConfig: ...@@ -786,7 +783,7 @@ class ParallelConfig:
from vllm.v1.executor import Executor from vllm.v1.executor import Executor
# Enable batch invariance settings if requested # Enable batch invariance settings if requested
if vllm_is_batch_invariant(): if envs.VLLM_BATCH_INVARIANT:
self.disable_custom_all_reduce = True self.disable_custom_all_reduce = True
if ( if (
......
...@@ -1112,11 +1112,9 @@ class VllmConfig: # type: ignore[misc] ...@@ -1112,11 +1112,9 @@ class VllmConfig: # type: ignore[misc]
"when cudagraph_mode piecewise cudagraphs is used, " "when cudagraph_mode piecewise cudagraphs is used, "
f"cudagraph_mode={self.compilation_config.cudagraph_mode}" f"cudagraph_mode={self.compilation_config.cudagraph_mode}"
) )
from vllm.model_executor.layers.batch_invariant import vllm_is_batch_invariant
if ( if (
self.model_config self.model_config
and vllm_is_batch_invariant() and envs.VLLM_BATCH_INVARIANT
and not self.model_config.disable_cascade_attn and not self.model_config.disable_cascade_attn
): ):
self.model_config.disable_cascade_attn = True self.model_config.disable_cascade_attn = True
......
...@@ -19,9 +19,6 @@ import torch.multiprocessing as mp ...@@ -19,9 +19,6 @@ import torch.multiprocessing as mp
import vllm.envs as envs import vllm.envs as envs
from vllm.distributed.device_communicators.cuda_wrapper import CudaRTLibrary from vllm.distributed.device_communicators.cuda_wrapper import CudaRTLibrary
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.model_executor.layers.batch_invariant import (
vllm_is_batch_invariant,
)
from vllm.utils.system_utils import update_environment_variables from vllm.utils.system_utils import update_environment_variables
from vllm.utils.torch_utils import cuda_device_count_stateless from vllm.utils.torch_utils import cuda_device_count_stateless
...@@ -115,7 +112,7 @@ def should_nccl_symm_mem_allreduce(world_size: int, input_tensor: torch.Tensor) ...@@ -115,7 +112,7 @@ def should_nccl_symm_mem_allreduce(world_size: int, input_tensor: torch.Tensor)
is_symmetric_memory_enabled, is_symmetric_memory_enabled,
) )
if vllm_is_batch_invariant(): if envs.VLLM_BATCH_INVARIANT:
return False return False
if not is_symmetric_memory_enabled(): if not is_symmetric_memory_enabled():
......
...@@ -5,13 +5,11 @@ import torch ...@@ -5,13 +5,11 @@ import torch
import torch.distributed as dist import torch.distributed as dist
from torch.distributed import ProcessGroup from torch.distributed import ProcessGroup
import vllm.envs as envs
from vllm.distributed.device_communicators.all_reduce_utils import ( from vllm.distributed.device_communicators.all_reduce_utils import (
SYMM_MEM_ALL_REDUCE_MAX_SIZES, SYMM_MEM_ALL_REDUCE_MAX_SIZES,
) )
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.model_executor.layers.batch_invariant import (
vllm_is_batch_invariant,
)
from vllm.platforms import current_platform from vllm.platforms import current_platform
try: try:
...@@ -112,7 +110,7 @@ class SymmMemCommunicator: ...@@ -112,7 +110,7 @@ class SymmMemCommunicator:
return return
self.force_multimem = force_multimem self.force_multimem = force_multimem
self.disabled = False self.disabled = False
if vllm_is_batch_invariant(): if envs.VLLM_BATCH_INVARIANT:
self.disabled = True self.disabled = True
def should_use_symm_mem(self, inp: torch.Tensor): def should_use_symm_mem(self, inp: torch.Tensor):
......
...@@ -74,6 +74,7 @@ if TYPE_CHECKING: ...@@ -74,6 +74,7 @@ if TYPE_CHECKING:
VLLM_TARGET_DEVICE: str = "cuda" VLLM_TARGET_DEVICE: str = "cuda"
VLLM_MAIN_CUDA_VERSION: str = "12.9" VLLM_MAIN_CUDA_VERSION: str = "12.9"
VLLM_FLOAT32_MATMUL_PRECISION: Literal["highest", "high", "medium"] = "highest" VLLM_FLOAT32_MATMUL_PRECISION: Literal["highest", "high", "medium"] = "highest"
VLLM_BATCH_INVARIANT: bool = False
MAX_JOBS: str | None = None MAX_JOBS: str | None = None
NVCC_THREADS: str | None = None NVCC_THREADS: str | None = None
VLLM_USE_PRECOMPILED: bool = False VLLM_USE_PRECOMPILED: bool = False
...@@ -280,9 +281,6 @@ def disable_compile_cache() -> bool: ...@@ -280,9 +281,6 @@ def disable_compile_cache() -> bool:
def use_aot_compile() -> bool: def use_aot_compile() -> bool:
from vllm.model_executor.layers.batch_invariant import (
vllm_is_batch_invariant,
)
from vllm.utils.torch_utils import is_torch_equal_or_newer from vllm.utils.torch_utils import is_torch_equal_or_newer
default_value = ( default_value = (
...@@ -292,7 +290,7 @@ def use_aot_compile() -> bool: ...@@ -292,7 +290,7 @@ def use_aot_compile() -> bool:
) )
return ( return (
not vllm_is_batch_invariant() not bool(int(os.getenv("VLLM_BATCH_INVARIANT", "0")))
and os.environ.get("VLLM_USE_AOT_COMPILE", default_value) == "1" and os.environ.get("VLLM_USE_AOT_COMPILE", default_value) == "1"
) )
...@@ -498,6 +496,9 @@ environment_variables: dict[str, Callable[[], Any]] = { ...@@ -498,6 +496,9 @@ environment_variables: dict[str, Callable[[], Any]] = {
["highest", "high", "medium"], ["highest", "high", "medium"],
case_sensitive=False, case_sensitive=False,
), ),
# Enable batch-invariant mode: deterministic results regardless of
# batch composition. Requires NVIDIA GPU with compute capability >= 9.0.
"VLLM_BATCH_INVARIANT": lambda: bool(int(os.getenv("VLLM_BATCH_INVARIANT", "0"))),
# Maximum number of compilation jobs to run in parallel. # Maximum number of compilation jobs to run in parallel.
# By default this is the number of CPUs # By default this is the number of CPUs
"MAX_JOBS": lambda: os.getenv("MAX_JOBS", None), "MAX_JOBS": lambda: os.getenv("MAX_JOBS", None),
......
...@@ -11,12 +11,11 @@ import torch ...@@ -11,12 +11,11 @@ import torch
from vllm import envs from vllm import envs
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.model_executor.layers.batch_invariant import vllm_is_batch_invariant
from vllm.platforms import current_platform from vllm.platforms import current_platform
from vllm.utils.math_utils import next_power_of_2 from vllm.utils.math_utils import next_power_of_2
logger = init_logger(__name__) logger = init_logger(__name__)
is_batch_invariant = vllm_is_batch_invariant() is_batch_invariant = envs.VLLM_BATCH_INVARIANT
_LORA_A_PTR_DICT: dict[tuple[int, ...], tuple[torch.tensor, ...]] = {} _LORA_A_PTR_DICT: dict[tuple[int, ...], tuple[torch.tensor, ...]] = {}
_LORA_B_PTR_DICT: dict[tuple[int, ...], tuple[torch.tensor, ...]] = {} _LORA_B_PTR_DICT: dict[tuple[int, ...], tuple[torch.tensor, ...]] = {}
......
...@@ -6,7 +6,6 @@ from collections.abc import Sequence ...@@ -6,7 +6,6 @@ from collections.abc import Sequence
import torch import torch
import vllm.envs as envs import vllm.envs as envs
from vllm.model_executor.layers.batch_invariant import vllm_is_batch_invariant
from vllm.model_executor.layers.quantization.utils.fp8_utils import ( from vllm.model_executor.layers.quantization.utils.fp8_utils import (
process_fp8_weight_block_strategy, process_fp8_weight_block_strategy,
) )
...@@ -42,7 +41,7 @@ class MarlinFP8ScaledMMLinearKernel(FP8ScaledMMLinearKernel): ...@@ -42,7 +41,7 @@ class MarlinFP8ScaledMMLinearKernel(FP8ScaledMMLinearKernel):
# Check if platform supports FP8 Marlin # Check if platform supports FP8 Marlin
if not is_fp8_marlin_supported(): if not is_fp8_marlin_supported():
return False, "FP8 Marlin requires compute capability 7.5 or higher" return False, "FP8 Marlin requires compute capability 7.5 or higher"
if vllm_is_batch_invariant(): if envs.VLLM_BATCH_INVARIANT:
return False, "FP8 Marlin not supported for batch invariant execution." return False, "FP8 Marlin not supported for batch invariant execution."
if ( if (
compute_capability is not None compute_capability is not None
......
...@@ -15,7 +15,6 @@ from vllm.model_executor.layers.attention.kv_transfer_utils import ( ...@@ -15,7 +15,6 @@ from vllm.model_executor.layers.attention.kv_transfer_utils import (
maybe_transfer_kv_layer, maybe_transfer_kv_layer,
) )
from vllm.model_executor.layers.attention_layer_base import AttentionLayerBase from vllm.model_executor.layers.attention_layer_base import AttentionLayerBase
from vllm.model_executor.layers.batch_invariant import vllm_is_batch_invariant
from vllm.model_executor.layers.linear import ( from vllm.model_executor.layers.linear import (
UnquantizedLinearMethod, UnquantizedLinearMethod,
) )
...@@ -296,7 +295,7 @@ class Attention(nn.Module, AttentionLayerBase): ...@@ -296,7 +295,7 @@ class Attention(nn.Module, AttentionLayerBase):
if ( if (
cache_config is not None cache_config is not None
and cache_config.enable_prefix_caching and cache_config.enable_prefix_caching
and vllm_is_batch_invariant() and envs.VLLM_BATCH_INVARIANT
and ( and (
self.attn_backend.get_name() == "FLASHINFER" self.attn_backend.get_name() == "FLASHINFER"
or self.attn_backend.get_name() == "TRITON_MLA" or self.attn_backend.get_name() == "TRITON_MLA"
......
...@@ -227,7 +227,6 @@ from vllm.model_executor.layers.attention.kv_transfer_utils import ( ...@@ -227,7 +227,6 @@ from vllm.model_executor.layers.attention.kv_transfer_utils import (
maybe_transfer_kv_layer, maybe_transfer_kv_layer,
) )
from vllm.model_executor.layers.attention_layer_base import AttentionLayerBase from vllm.model_executor.layers.attention_layer_base import AttentionLayerBase
from vllm.model_executor.layers.batch_invariant import vllm_is_batch_invariant
from vllm.model_executor.layers.linear import ( from vllm.model_executor.layers.linear import (
ColumnParallelLinear, ColumnParallelLinear,
) )
...@@ -372,7 +371,7 @@ class MLAAttention(nn.Module, AttentionLayerBase): ...@@ -372,7 +371,7 @@ class MLAAttention(nn.Module, AttentionLayerBase):
if ( if (
cache_config is not None cache_config is not None
and cache_config.enable_prefix_caching and cache_config.enable_prefix_caching
and vllm_is_batch_invariant() and envs.VLLM_BATCH_INVARIANT
and ( and (
self.attn_backend.get_name() == "TRITON_MLA" self.attn_backend.get_name() == "TRITON_MLA"
or self.attn_backend.get_name() == "FLASHINFER" or self.attn_backend.get_name() == "FLASHINFER"
...@@ -2188,7 +2187,7 @@ class MLACommonImpl(MLAAttentionImpl[M], Generic[M]): ...@@ -2188,7 +2187,7 @@ class MLACommonImpl(MLAAttentionImpl[M], Generic[M]):
# ROCm leverages the upstream flash_attn, which takes a parameter # ROCm leverages the upstream flash_attn, which takes a parameter
# called "return_attn_probs" instead of return_softmax_lse # called "return_attn_probs" instead of return_softmax_lse
kwargs["return_attn_probs"] = return_softmax_lse kwargs["return_attn_probs"] = return_softmax_lse
if vllm_is_batch_invariant(): if envs.VLLM_BATCH_INVARIANT:
kwargs["num_splits"] = 1 kwargs["num_splits"] = 1
attn_out = self.flash_attn_varlen_func( attn_out = self.flash_attn_varlen_func(
......
...@@ -6,6 +6,7 @@ from typing import Any ...@@ -6,6 +6,7 @@ from typing import Any
import torch import torch
import vllm.envs as envs
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.platforms import current_platform from vllm.platforms import current_platform
from vllm.triton_utils import tl, triton from vllm.triton_utils import tl, triton
...@@ -986,21 +987,6 @@ def enable_batch_invariant_mode(): ...@@ -986,21 +987,6 @@ def enable_batch_invariant_mode():
torch.backends.cuda.preferred_blas_library(backend="cublaslt") torch.backends.cuda.preferred_blas_library(backend="cublaslt")
def _read_vllm_batch_invariant() -> bool:
val = os.getenv("VLLM_BATCH_INVARIANT", "0")
try:
return int(val) != 0
except ValueError:
return False
VLLM_BATCH_INVARIANT: bool = _read_vllm_batch_invariant()
def vllm_is_batch_invariant() -> bool:
return VLLM_BATCH_INVARIANT
def override_envs_for_invariance( def override_envs_for_invariance(
attention_backend: AttentionBackendEnum | None, attention_backend: AttentionBackendEnum | None,
): ):
...@@ -1059,7 +1045,7 @@ def init_batch_invariance( ...@@ -1059,7 +1045,7 @@ def init_batch_invariance(
attention_backend: AttentionBackendEnum | None, attention_backend: AttentionBackendEnum | None,
): ):
# this will hit all the csrc overrides as well # this will hit all the csrc overrides as well
if vllm_is_batch_invariant(): if envs.VLLM_BATCH_INVARIANT:
override_envs_for_invariance(attention_backend) override_envs_for_invariance(attention_backend)
enable_batch_invariant_mode() enable_batch_invariant_mode()
......
...@@ -14,9 +14,6 @@ import vllm.envs as envs ...@@ -14,9 +14,6 @@ import vllm.envs as envs
import vllm.model_executor.layers.fused_moe.modular_kernel as mk import vllm.model_executor.layers.fused_moe.modular_kernel as mk
from vllm import _custom_ops as ops from vllm import _custom_ops as ops
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.model_executor.layers.batch_invariant import (
vllm_is_batch_invariant,
)
from vllm.model_executor.layers.fused_moe.activation import ( from vllm.model_executor.layers.fused_moe.activation import (
MoEActivation, MoEActivation,
apply_moe_activation, apply_moe_activation,
...@@ -1051,7 +1048,7 @@ def get_moe_configs( ...@@ -1051,7 +1048,7 @@ def get_moe_configs(
""" """
# Avoid optimizing for the batch invariant case. Use default config # Avoid optimizing for the batch invariant case. Use default config
if vllm_is_batch_invariant(): if envs.VLLM_BATCH_INVARIANT:
return None return None
# First look up if an optimized configuration is available in the configs # First look up if an optimized configuration is available in the configs
...@@ -1232,7 +1229,7 @@ def get_default_config( ...@@ -1232,7 +1229,7 @@ def get_default_config(
dtype: str | None, dtype: str | None,
block_shape: list[int] | None = None, block_shape: list[int] | None = None,
) -> dict[str, int]: ) -> dict[str, int]:
if vllm_is_batch_invariant(): if envs.VLLM_BATCH_INVARIANT:
return { return {
"BLOCK_SIZE_M": 64, "BLOCK_SIZE_M": 64,
"BLOCK_SIZE_N": 64, "BLOCK_SIZE_N": 64,
......
...@@ -6,11 +6,9 @@ from collections.abc import Callable ...@@ -6,11 +6,9 @@ from collections.abc import Callable
import torch import torch
import vllm._custom_ops as ops import vllm._custom_ops as ops
import vllm.envs as envs
from vllm._aiter_ops import rocm_aiter_ops from vllm._aiter_ops import rocm_aiter_ops
from vllm.distributed.eplb.eplb_state import EplbLayerState from vllm.distributed.eplb.eplb_state import EplbLayerState
from vllm.model_executor.layers.batch_invariant import (
vllm_is_batch_invariant,
)
from vllm.model_executor.layers.fused_moe.config import ( from vllm.model_executor.layers.fused_moe.config import (
RoutingMethodType, RoutingMethodType,
get_routing_method_type, get_routing_method_type,
...@@ -160,7 +158,7 @@ def fused_topk_bias( ...@@ -160,7 +158,7 @@ def fused_topk_bias(
) + e_score_correction_bias.unsqueeze(0) ) + e_score_correction_bias.unsqueeze(0)
# For batch invariance, use sorted=True to ensure deterministic expert selection # For batch invariance, use sorted=True to ensure deterministic expert selection
use_sorted = vllm_is_batch_invariant() use_sorted = envs.VLLM_BATCH_INVARIANT
topk_indices = torch.topk(scores_for_choice, k=topk, dim=-1, sorted=use_sorted)[1] topk_indices = torch.topk(scores_for_choice, k=topk, dim=-1, sorted=use_sorted)[1]
topk_weights = scores.gather(1, topk_indices) topk_weights = scores.gather(1, topk_indices)
if renormalize: if renormalize:
......
...@@ -10,9 +10,6 @@ from vllm import envs as envs ...@@ -10,9 +10,6 @@ from vllm import envs as envs
from vllm._aiter_ops import rocm_aiter_ops from vllm._aiter_ops import rocm_aiter_ops
from vllm.distributed.eplb.eplb_state import EplbLayerState from vllm.distributed.eplb.eplb_state import EplbLayerState
from vllm.model_executor.custom_op import CustomOp from vllm.model_executor.custom_op import CustomOp
from vllm.model_executor.layers.batch_invariant import (
vllm_is_batch_invariant,
)
from vllm.model_executor.layers.fused_moe.config import ( from vllm.model_executor.layers.fused_moe.config import (
RoutingMethodType, RoutingMethodType,
get_routing_method_type, get_routing_method_type,
...@@ -135,7 +132,7 @@ def grouped_topk( ...@@ -135,7 +132,7 @@ def grouped_topk(
) # [n, n_group] ) # [n, n_group]
# For batch invariance, use sorted=True to ensure deterministic expert selection # For batch invariance, use sorted=True to ensure deterministic expert selection
use_sorted = vllm_is_batch_invariant() use_sorted = envs.VLLM_BATCH_INVARIANT
group_idx = torch.topk(group_scores, k=topk_group, dim=-1, sorted=use_sorted)[ group_idx = torch.topk(group_scores, k=topk_group, dim=-1, sorted=use_sorted)[
1 1
] # [n, top_k_group] ] # [n, top_k_group]
......
...@@ -12,7 +12,6 @@ from vllm.logger import init_logger ...@@ -12,7 +12,6 @@ from vllm.logger import init_logger
from vllm.model_executor.custom_op import CustomOp from vllm.model_executor.custom_op import CustomOp
from vllm.model_executor.layers.batch_invariant import ( from vllm.model_executor.layers.batch_invariant import (
rms_norm_batch_invariant, rms_norm_batch_invariant,
vllm_is_batch_invariant,
) )
from vllm.platforms import current_platform from vllm.platforms import current_platform
...@@ -57,7 +56,7 @@ def rms_norm( ...@@ -57,7 +56,7 @@ def rms_norm(
) -> torch.Tensor: ) -> torch.Tensor:
from vllm import _custom_ops as ops from vllm import _custom_ops as ops
if vllm_is_batch_invariant(): if envs.VLLM_BATCH_INVARIANT:
return rms_norm_batch_invariant(x, weight, variance_epsilon) return rms_norm_batch_invariant(x, weight, variance_epsilon)
out = torch.empty_like(x) out = torch.empty_like(x)
ops.rms_norm( ops.rms_norm(
...@@ -77,7 +76,7 @@ def fused_add_rms_norm( ...@@ -77,7 +76,7 @@ def fused_add_rms_norm(
) -> tuple[torch.Tensor, torch.Tensor]: ) -> tuple[torch.Tensor, torch.Tensor]:
from vllm import _custom_ops as ops from vllm import _custom_ops as ops
if vllm_is_batch_invariant(): if envs.VLLM_BATCH_INVARIANT:
return rms_norm_batch_invariant( return rms_norm_batch_invariant(
x + residual, weight, variance_epsilon x + residual, weight, variance_epsilon
), x + residual ), x + residual
...@@ -300,7 +299,7 @@ class RMSNorm(CustomOp): ...@@ -300,7 +299,7 @@ class RMSNorm(CustomOp):
and x.is_cuda and x.is_cuda
and x.dim() >= 2 and x.dim() >= 2
and self.has_weight and self.has_weight
and not vllm_is_batch_invariant() and not envs.VLLM_BATCH_INVARIANT
and self.weight.data.dtype == x.dtype and self.weight.data.dtype == x.dtype
and self.weight.data.is_contiguous() and self.weight.data.is_contiguous()
): ):
...@@ -328,7 +327,7 @@ class RMSNorm(CustomOp): ...@@ -328,7 +327,7 @@ class RMSNorm(CustomOp):
and x.dtype == residual.dtype and x.dtype == residual.dtype
and x.dim() >= 2 and x.dim() >= 2
and self.has_weight and self.has_weight
and not vllm_is_batch_invariant() and not envs.VLLM_BATCH_INVARIANT
and self.weight.data.dtype == x.dtype and self.weight.data.dtype == x.dtype
and self.weight.data.is_contiguous() and self.weight.data.is_contiguous()
): ):
......
...@@ -7,6 +7,7 @@ from abc import abstractmethod ...@@ -7,6 +7,7 @@ from abc import abstractmethod
import torch import torch
from torch.nn.parameter import Parameter, UninitializedParameter from torch.nn.parameter import Parameter, UninitializedParameter
import vllm.envs as envs
from vllm.distributed import ( from vllm.distributed import (
divide, divide,
get_tensor_model_parallel_rank, get_tensor_model_parallel_rank,
...@@ -19,7 +20,6 @@ from vllm.logger import init_logger ...@@ -19,7 +20,6 @@ from vllm.logger import init_logger
from vllm.model_executor.custom_op import PluggableLayer from vllm.model_executor.custom_op import PluggableLayer
from vllm.model_executor.layers.batch_invariant import ( from vllm.model_executor.layers.batch_invariant import (
linear_batch_invariant, linear_batch_invariant,
vllm_is_batch_invariant,
) )
from vllm.model_executor.layers.quantization.base_config import ( from vllm.model_executor.layers.quantization.base_config import (
QuantizationConfig, QuantizationConfig,
...@@ -223,7 +223,7 @@ class UnquantizedLinearMethod(LinearMethodBase): ...@@ -223,7 +223,7 @@ class UnquantizedLinearMethod(LinearMethodBase):
x: torch.Tensor, x: torch.Tensor,
bias: torch.Tensor | None = None, bias: torch.Tensor | None = None,
) -> torch.Tensor: ) -> torch.Tensor:
if vllm_is_batch_invariant() and current_platform.is_cuda_alike(): if envs.VLLM_BATCH_INVARIANT and current_platform.is_cuda_alike():
return linear_batch_invariant(x, layer.weight, bias) return linear_batch_invariant(x, layer.weight, bias)
return dispatch_unquantized_gemm()(layer, x, layer.weight, bias) return dispatch_unquantized_gemm()(layer, x, layer.weight, bias)
......
...@@ -7,6 +7,7 @@ import torch ...@@ -7,6 +7,7 @@ import torch
from torch.nn import Module from torch.nn import Module
from torch.utils._python_dispatch import TorchDispatchMode from torch.utils._python_dispatch import TorchDispatchMode
import vllm.envs as envs
import vllm.model_executor.layers.fused_moe.modular_kernel as mk import vllm.model_executor.layers.fused_moe.modular_kernel as mk
from vllm import _custom_ops as ops from vllm import _custom_ops as ops
from vllm._aiter_ops import rocm_aiter_ops from vllm._aiter_ops import rocm_aiter_ops
...@@ -17,9 +18,6 @@ from vllm.model_executor.kernels.linear import ( ...@@ -17,9 +18,6 @@ from vllm.model_executor.kernels.linear import (
) )
from vllm.model_executor.kernels.linear.scaled_mm import MarlinFP8ScaledMMLinearKernel from vllm.model_executor.kernels.linear.scaled_mm import MarlinFP8ScaledMMLinearKernel
from vllm.model_executor.layers.attention import Attention from vllm.model_executor.layers.attention import Attention
from vllm.model_executor.layers.batch_invariant import (
vllm_is_batch_invariant,
)
from vllm.model_executor.layers.fused_moe import ( from vllm.model_executor.layers.fused_moe import (
FusedMoE, FusedMoE,
FusedMoEMethodBase, FusedMoEMethodBase,
...@@ -441,7 +439,7 @@ class Fp8LinearMethod(LinearMethodBase): ...@@ -441,7 +439,7 @@ class Fp8LinearMethod(LinearMethodBase):
) -> torch.Tensor: ) -> torch.Tensor:
# if batch invariant mode is enabled, prefer DeepGEMM FP8 path # if batch invariant mode is enabled, prefer DeepGEMM FP8 path
# we will use BF16 dequant when DeepGEMM is not supported. # we will use BF16 dequant when DeepGEMM is not supported.
if vllm_is_batch_invariant(): if envs.VLLM_BATCH_INVARIANT:
if self.block_quant: if self.block_quant:
assert self.weight_block_size is not None assert self.weight_block_size is not None
return self.w8a8_block_fp8_linear.apply( return self.w8a8_block_fp8_linear.apply(
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment