Unverified Commit dc6908ac authored by Ranran's avatar Ranran Committed by GitHub
Browse files

[Bugfix] Register VLLM_BATCH_INVARIANT in envs.py to fix spurious unknown env var warning (#35007)


Signed-off-by: default avatarRanran <1012869439@qq.com>
Signed-off-by: default avatarRanran <hzz5361@psu.edu>
Signed-off-by: default avatarran <hzz5361@psu.edu>
Co-authored-by: default avatarWentao Ye <44945378+yewentao256@users.noreply.github.com>
parent e85f8f09
......@@ -55,37 +55,37 @@ def _clear_supports_cache():
# supports_trtllm_attention
@patch("vllm.utils.flashinfer.vllm_is_batch_invariant", return_value=True)
def test_supports_batch_invariant_disables(_mock):
@patch("vllm.envs.VLLM_BATCH_INVARIANT", True)
def test_supports_batch_invariant_disables():
assert supports_trtllm_attention() is False
@patch("vllm.utils.flashinfer.vllm_is_batch_invariant", return_value=False)
@patch("vllm.envs.VLLM_BATCH_INVARIANT", False)
@patch(
"vllm.utils.flashinfer.current_platform.is_device_capability_family",
return_value=True,
)
@patch("vllm.utils.flashinfer.has_nvidia_artifactory", return_value=True)
def test_supports_sm100_with_artifactory(_art, _cap, _bi):
def test_supports_sm100_with_artifactory(_art, _cap):
assert supports_trtllm_attention() is True
@patch("vllm.utils.flashinfer.vllm_is_batch_invariant", return_value=False)
@patch("vllm.envs.VLLM_BATCH_INVARIANT", False)
@patch(
"vllm.utils.flashinfer.current_platform.is_device_capability_family",
return_value=False,
)
def test_supports_non_sm100_platform(_cap, _bi):
def test_supports_non_sm100_platform(_cap):
assert supports_trtllm_attention() is False
@patch("vllm.utils.flashinfer.vllm_is_batch_invariant", return_value=False)
@patch("vllm.envs.VLLM_BATCH_INVARIANT", False)
@patch(
"vllm.utils.flashinfer.current_platform.is_device_capability_family",
return_value=True,
)
@patch("vllm.utils.flashinfer.has_nvidia_artifactory", return_value=False)
def test_supports_sm100_without_artifactory(_art, _cap, _bi):
def test_supports_sm100_without_artifactory(_art, _cap):
assert supports_trtllm_attention() is False
......
......@@ -8,7 +8,7 @@ Run `pytest tests/kernels/moe/test_grouped_topk.py`.
import pytest
import torch
import vllm.model_executor.layers.batch_invariant as batch_invariant
import vllm.envs as envs
from vllm.config import (
CompilationConfig,
VllmConfig,
......@@ -69,7 +69,7 @@ def test_grouped_topk(
with set_current_vllm_config(vllm_config), monkeypatch.context() as m:
m.setenv("VLLM_USE_FUSED_MOE_GROUPED_TOPK", "0")
m.setattr(batch_invariant, "VLLM_BATCH_INVARIANT", True)
m.setattr(envs, "VLLM_BATCH_INVARIANT", True)
grouped_topk = GroupedTopk(
topk=topk,
renormalize=renormalize,
......
......@@ -2,11 +2,11 @@
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import pytest
import vllm.model_executor.layers.batch_invariant as batch_invariant
import vllm.envs as envs
@pytest.fixture(autouse=True)
def enable_batch_invariant_mode(monkeypatch: pytest.MonkeyPatch):
"""Automatically enable batch invariant kernel overrides for all tests."""
monkeypatch.setattr(batch_invariant, "VLLM_BATCH_INVARIANT", True)
monkeypatch.setattr(envs, "VLLM_BATCH_INVARIANT", True)
monkeypatch.setenv("VLLM_BATCH_INVARIANT", "1")
......@@ -15,7 +15,7 @@ from utils import (
skip_unsupported,
)
import vllm.model_executor.layers.batch_invariant as batch_invariant
import vllm.envs as envs
from vllm import LLM, SamplingParams
IS_DEVICE_CAPABILITY_BELOW_90 = is_device_capability_below_90()
......@@ -173,11 +173,9 @@ def test_logprobs_bitwise_batch_invariance_bs1_vs_bsN(
# For batch invariance, disable custom all-reduce to ensure deterministic
# all-reduce operations (custom all-reduce may not be deterministic)
from vllm.model_executor.layers.batch_invariant import (
vllm_is_batch_invariant,
)
import vllm.envs as envs
disable_custom_ar = vllm_is_batch_invariant()
disable_custom_ar = envs.VLLM_BATCH_INVARIANT
if disable_custom_ar:
print(f"\n{'=' * 80}")
......@@ -454,7 +452,7 @@ def test_logprobs_without_batch_invariance_should_fail(
"""
# CRITICAL: Disable batch invariance for this test
monkeypatch.setenv("VLLM_BATCH_INVARIANT", "0")
monkeypatch.setattr(batch_invariant, "VLLM_BATCH_INVARIANT", False)
monkeypatch.setattr(envs, "VLLM_BATCH_INVARIANT", False)
seed = int(os.getenv("VLLM_TEST_SEED", "12345"))
random.seed(seed)
tp_size = int(os.getenv("VLLM_TEST_TP_SIZE", "1"))
......@@ -674,11 +672,9 @@ def test_decode_logprobs_match_prefill_logprobs(
random.seed(seed)
tp_size = int(os.getenv("VLLM_TEST_TP_SIZE", "1"))
from vllm.model_executor.layers.batch_invariant import (
vllm_is_batch_invariant,
)
import vllm.envs as envs
disable_custom_ar = vllm_is_batch_invariant()
disable_custom_ar = envs.VLLM_BATCH_INVARIANT
if disable_custom_ar:
print(f"\n{'=' * 80}")
......
......@@ -14,9 +14,6 @@ from typing_extensions import Self
import vllm.envs as envs
from vllm.config.utils import config
from vllm.logger import init_logger
from vllm.model_executor.layers.batch_invariant import (
vllm_is_batch_invariant,
)
from vllm.platforms import current_platform
from vllm.utils.network_utils import get_open_ports_list
from vllm.utils.torch_utils import cuda_device_count_stateless
......@@ -786,7 +783,7 @@ class ParallelConfig:
from vllm.v1.executor import Executor
# Enable batch invariance settings if requested
if vllm_is_batch_invariant():
if envs.VLLM_BATCH_INVARIANT:
self.disable_custom_all_reduce = True
if (
......
......@@ -1112,11 +1112,9 @@ class VllmConfig: # type: ignore[misc]
"when cudagraph_mode piecewise cudagraphs is used, "
f"cudagraph_mode={self.compilation_config.cudagraph_mode}"
)
from vllm.model_executor.layers.batch_invariant import vllm_is_batch_invariant
if (
self.model_config
and vllm_is_batch_invariant()
and envs.VLLM_BATCH_INVARIANT
and not self.model_config.disable_cascade_attn
):
self.model_config.disable_cascade_attn = True
......
......@@ -19,9 +19,6 @@ import torch.multiprocessing as mp
import vllm.envs as envs
from vllm.distributed.device_communicators.cuda_wrapper import CudaRTLibrary
from vllm.logger import init_logger
from vllm.model_executor.layers.batch_invariant import (
vllm_is_batch_invariant,
)
from vllm.utils.system_utils import update_environment_variables
from vllm.utils.torch_utils import cuda_device_count_stateless
......@@ -115,7 +112,7 @@ def should_nccl_symm_mem_allreduce(world_size: int, input_tensor: torch.Tensor)
is_symmetric_memory_enabled,
)
if vllm_is_batch_invariant():
if envs.VLLM_BATCH_INVARIANT:
return False
if not is_symmetric_memory_enabled():
......
......@@ -5,13 +5,11 @@ import torch
import torch.distributed as dist
from torch.distributed import ProcessGroup
import vllm.envs as envs
from vllm.distributed.device_communicators.all_reduce_utils import (
SYMM_MEM_ALL_REDUCE_MAX_SIZES,
)
from vllm.logger import init_logger
from vllm.model_executor.layers.batch_invariant import (
vllm_is_batch_invariant,
)
from vllm.platforms import current_platform
try:
......@@ -112,7 +110,7 @@ class SymmMemCommunicator:
return
self.force_multimem = force_multimem
self.disabled = False
if vllm_is_batch_invariant():
if envs.VLLM_BATCH_INVARIANT:
self.disabled = True
def should_use_symm_mem(self, inp: torch.Tensor):
......
......@@ -74,6 +74,7 @@ if TYPE_CHECKING:
VLLM_TARGET_DEVICE: str = "cuda"
VLLM_MAIN_CUDA_VERSION: str = "12.9"
VLLM_FLOAT32_MATMUL_PRECISION: Literal["highest", "high", "medium"] = "highest"
VLLM_BATCH_INVARIANT: bool = False
MAX_JOBS: str | None = None
NVCC_THREADS: str | None = None
VLLM_USE_PRECOMPILED: bool = False
......@@ -280,9 +281,6 @@ def disable_compile_cache() -> bool:
def use_aot_compile() -> bool:
from vllm.model_executor.layers.batch_invariant import (
vllm_is_batch_invariant,
)
from vllm.utils.torch_utils import is_torch_equal_or_newer
default_value = (
......@@ -292,7 +290,7 @@ def use_aot_compile() -> bool:
)
return (
not vllm_is_batch_invariant()
not bool(int(os.getenv("VLLM_BATCH_INVARIANT", "0")))
and os.environ.get("VLLM_USE_AOT_COMPILE", default_value) == "1"
)
......@@ -498,6 +496,9 @@ environment_variables: dict[str, Callable[[], Any]] = {
["highest", "high", "medium"],
case_sensitive=False,
),
# Enable batch-invariant mode: deterministic results regardless of
# batch composition. Requires NVIDIA GPU with compute capability >= 9.0.
"VLLM_BATCH_INVARIANT": lambda: bool(int(os.getenv("VLLM_BATCH_INVARIANT", "0"))),
# Maximum number of compilation jobs to run in parallel.
# By default this is the number of CPUs
"MAX_JOBS": lambda: os.getenv("MAX_JOBS", None),
......
......@@ -11,12 +11,11 @@ import torch
from vllm import envs
from vllm.logger import init_logger
from vllm.model_executor.layers.batch_invariant import vllm_is_batch_invariant
from vllm.platforms import current_platform
from vllm.utils.math_utils import next_power_of_2
logger = init_logger(__name__)
is_batch_invariant = vllm_is_batch_invariant()
is_batch_invariant = envs.VLLM_BATCH_INVARIANT
_LORA_A_PTR_DICT: dict[tuple[int, ...], tuple[torch.tensor, ...]] = {}
_LORA_B_PTR_DICT: dict[tuple[int, ...], tuple[torch.tensor, ...]] = {}
......
......@@ -6,7 +6,6 @@ from collections.abc import Sequence
import torch
import vllm.envs as envs
from vllm.model_executor.layers.batch_invariant import vllm_is_batch_invariant
from vllm.model_executor.layers.quantization.utils.fp8_utils import (
process_fp8_weight_block_strategy,
)
......@@ -42,7 +41,7 @@ class MarlinFP8ScaledMMLinearKernel(FP8ScaledMMLinearKernel):
# Check if platform supports FP8 Marlin
if not is_fp8_marlin_supported():
return False, "FP8 Marlin requires compute capability 7.5 or higher"
if vllm_is_batch_invariant():
if envs.VLLM_BATCH_INVARIANT:
return False, "FP8 Marlin not supported for batch invariant execution."
if (
compute_capability is not None
......
......@@ -15,7 +15,6 @@ from vllm.model_executor.layers.attention.kv_transfer_utils import (
maybe_transfer_kv_layer,
)
from vllm.model_executor.layers.attention_layer_base import AttentionLayerBase
from vllm.model_executor.layers.batch_invariant import vllm_is_batch_invariant
from vllm.model_executor.layers.linear import (
UnquantizedLinearMethod,
)
......@@ -296,7 +295,7 @@ class Attention(nn.Module, AttentionLayerBase):
if (
cache_config is not None
and cache_config.enable_prefix_caching
and vllm_is_batch_invariant()
and envs.VLLM_BATCH_INVARIANT
and (
self.attn_backend.get_name() == "FLASHINFER"
or self.attn_backend.get_name() == "TRITON_MLA"
......
......@@ -227,7 +227,6 @@ from vllm.model_executor.layers.attention.kv_transfer_utils import (
maybe_transfer_kv_layer,
)
from vllm.model_executor.layers.attention_layer_base import AttentionLayerBase
from vllm.model_executor.layers.batch_invariant import vllm_is_batch_invariant
from vllm.model_executor.layers.linear import (
ColumnParallelLinear,
)
......@@ -372,7 +371,7 @@ class MLAAttention(nn.Module, AttentionLayerBase):
if (
cache_config is not None
and cache_config.enable_prefix_caching
and vllm_is_batch_invariant()
and envs.VLLM_BATCH_INVARIANT
and (
self.attn_backend.get_name() == "TRITON_MLA"
or self.attn_backend.get_name() == "FLASHINFER"
......@@ -2188,7 +2187,7 @@ class MLACommonImpl(MLAAttentionImpl[M], Generic[M]):
# ROCm leverages the upstream flash_attn, which takes a parameter
# called "return_attn_probs" instead of return_softmax_lse
kwargs["return_attn_probs"] = return_softmax_lse
if vllm_is_batch_invariant():
if envs.VLLM_BATCH_INVARIANT:
kwargs["num_splits"] = 1
attn_out = self.flash_attn_varlen_func(
......
......@@ -6,6 +6,7 @@ from typing import Any
import torch
import vllm.envs as envs
from vllm.logger import init_logger
from vllm.platforms import current_platform
from vllm.triton_utils import tl, triton
......@@ -986,21 +987,6 @@ def enable_batch_invariant_mode():
torch.backends.cuda.preferred_blas_library(backend="cublaslt")
def _read_vllm_batch_invariant() -> bool:
val = os.getenv("VLLM_BATCH_INVARIANT", "0")
try:
return int(val) != 0
except ValueError:
return False
VLLM_BATCH_INVARIANT: bool = _read_vllm_batch_invariant()
def vllm_is_batch_invariant() -> bool:
return VLLM_BATCH_INVARIANT
def override_envs_for_invariance(
attention_backend: AttentionBackendEnum | None,
):
......@@ -1059,7 +1045,7 @@ def init_batch_invariance(
attention_backend: AttentionBackendEnum | None,
):
# this will hit all the csrc overrides as well
if vllm_is_batch_invariant():
if envs.VLLM_BATCH_INVARIANT:
override_envs_for_invariance(attention_backend)
enable_batch_invariant_mode()
......
......@@ -14,9 +14,6 @@ import vllm.envs as envs
import vllm.model_executor.layers.fused_moe.modular_kernel as mk
from vllm import _custom_ops as ops
from vllm.logger import init_logger
from vllm.model_executor.layers.batch_invariant import (
vllm_is_batch_invariant,
)
from vllm.model_executor.layers.fused_moe.activation import (
MoEActivation,
apply_moe_activation,
......@@ -1051,7 +1048,7 @@ def get_moe_configs(
"""
# Avoid optimizing for the batch invariant case. Use default config
if vllm_is_batch_invariant():
if envs.VLLM_BATCH_INVARIANT:
return None
# First look up if an optimized configuration is available in the configs
......@@ -1232,7 +1229,7 @@ def get_default_config(
dtype: str | None,
block_shape: list[int] | None = None,
) -> dict[str, int]:
if vllm_is_batch_invariant():
if envs.VLLM_BATCH_INVARIANT:
return {
"BLOCK_SIZE_M": 64,
"BLOCK_SIZE_N": 64,
......
......@@ -6,11 +6,9 @@ from collections.abc import Callable
import torch
import vllm._custom_ops as ops
import vllm.envs as envs
from vllm._aiter_ops import rocm_aiter_ops
from vllm.distributed.eplb.eplb_state import EplbLayerState
from vllm.model_executor.layers.batch_invariant import (
vllm_is_batch_invariant,
)
from vllm.model_executor.layers.fused_moe.config import (
RoutingMethodType,
get_routing_method_type,
......@@ -160,7 +158,7 @@ def fused_topk_bias(
) + e_score_correction_bias.unsqueeze(0)
# For batch invariance, use sorted=True to ensure deterministic expert selection
use_sorted = vllm_is_batch_invariant()
use_sorted = envs.VLLM_BATCH_INVARIANT
topk_indices = torch.topk(scores_for_choice, k=topk, dim=-1, sorted=use_sorted)[1]
topk_weights = scores.gather(1, topk_indices)
if renormalize:
......
......@@ -10,9 +10,6 @@ from vllm import envs as envs
from vllm._aiter_ops import rocm_aiter_ops
from vllm.distributed.eplb.eplb_state import EplbLayerState
from vllm.model_executor.custom_op import CustomOp
from vllm.model_executor.layers.batch_invariant import (
vllm_is_batch_invariant,
)
from vllm.model_executor.layers.fused_moe.config import (
RoutingMethodType,
get_routing_method_type,
......@@ -135,7 +132,7 @@ def grouped_topk(
) # [n, n_group]
# For batch invariance, use sorted=True to ensure deterministic expert selection
use_sorted = vllm_is_batch_invariant()
use_sorted = envs.VLLM_BATCH_INVARIANT
group_idx = torch.topk(group_scores, k=topk_group, dim=-1, sorted=use_sorted)[
1
] # [n, top_k_group]
......
......@@ -12,7 +12,6 @@ from vllm.logger import init_logger
from vllm.model_executor.custom_op import CustomOp
from vllm.model_executor.layers.batch_invariant import (
rms_norm_batch_invariant,
vllm_is_batch_invariant,
)
from vllm.platforms import current_platform
......@@ -57,7 +56,7 @@ def rms_norm(
) -> torch.Tensor:
from vllm import _custom_ops as ops
if vllm_is_batch_invariant():
if envs.VLLM_BATCH_INVARIANT:
return rms_norm_batch_invariant(x, weight, variance_epsilon)
out = torch.empty_like(x)
ops.rms_norm(
......@@ -77,7 +76,7 @@ def fused_add_rms_norm(
) -> tuple[torch.Tensor, torch.Tensor]:
from vllm import _custom_ops as ops
if vllm_is_batch_invariant():
if envs.VLLM_BATCH_INVARIANT:
return rms_norm_batch_invariant(
x + residual, weight, variance_epsilon
), x + residual
......@@ -300,7 +299,7 @@ class RMSNorm(CustomOp):
and x.is_cuda
and x.dim() >= 2
and self.has_weight
and not vllm_is_batch_invariant()
and not envs.VLLM_BATCH_INVARIANT
and self.weight.data.dtype == x.dtype
and self.weight.data.is_contiguous()
):
......@@ -328,7 +327,7 @@ class RMSNorm(CustomOp):
and x.dtype == residual.dtype
and x.dim() >= 2
and self.has_weight
and not vllm_is_batch_invariant()
and not envs.VLLM_BATCH_INVARIANT
and self.weight.data.dtype == x.dtype
and self.weight.data.is_contiguous()
):
......
......@@ -7,6 +7,7 @@ from abc import abstractmethod
import torch
from torch.nn.parameter import Parameter, UninitializedParameter
import vllm.envs as envs
from vllm.distributed import (
divide,
get_tensor_model_parallel_rank,
......@@ -19,7 +20,6 @@ from vllm.logger import init_logger
from vllm.model_executor.custom_op import PluggableLayer
from vllm.model_executor.layers.batch_invariant import (
linear_batch_invariant,
vllm_is_batch_invariant,
)
from vllm.model_executor.layers.quantization.base_config import (
QuantizationConfig,
......@@ -223,7 +223,7 @@ class UnquantizedLinearMethod(LinearMethodBase):
x: torch.Tensor,
bias: torch.Tensor | None = None,
) -> torch.Tensor:
if vllm_is_batch_invariant() and current_platform.is_cuda_alike():
if envs.VLLM_BATCH_INVARIANT and current_platform.is_cuda_alike():
return linear_batch_invariant(x, layer.weight, bias)
return dispatch_unquantized_gemm()(layer, x, layer.weight, bias)
......
......@@ -7,6 +7,7 @@ import torch
from torch.nn import Module
from torch.utils._python_dispatch import TorchDispatchMode
import vllm.envs as envs
import vllm.model_executor.layers.fused_moe.modular_kernel as mk
from vllm import _custom_ops as ops
from vllm._aiter_ops import rocm_aiter_ops
......@@ -17,9 +18,6 @@ from vllm.model_executor.kernels.linear import (
)
from vllm.model_executor.kernels.linear.scaled_mm import MarlinFP8ScaledMMLinearKernel
from vllm.model_executor.layers.attention import Attention
from vllm.model_executor.layers.batch_invariant import (
vllm_is_batch_invariant,
)
from vllm.model_executor.layers.fused_moe import (
FusedMoE,
FusedMoEMethodBase,
......@@ -441,7 +439,7 @@ class Fp8LinearMethod(LinearMethodBase):
) -> torch.Tensor:
# if batch invariant mode is enabled, prefer DeepGEMM FP8 path
# we will use BF16 dequant when DeepGEMM is not supported.
if vllm_is_batch_invariant():
if envs.VLLM_BATCH_INVARIANT:
if self.block_quant:
assert self.weight_block_size is not None
return self.w8a8_block_fp8_linear.apply(
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment