Unverified Commit dc6908ac authored by Ranran's avatar Ranran Committed by GitHub
Browse files

[Bugfix] Register VLLM_BATCH_INVARIANT in envs.py to fix spurious unknown env var warning (#35007)


Signed-off-by: default avatarRanran <1012869439@qq.com>
Signed-off-by: default avatarRanran <hzz5361@psu.edu>
Signed-off-by: default avatarran <hzz5361@psu.edu>
Co-authored-by: default avatarWentao Ye <44945378+yewentao256@users.noreply.github.com>
parent e85f8f09
...@@ -305,9 +305,7 @@ def _flashinfer_fp8_blockscale_gemm_impl( ...@@ -305,9 +305,7 @@ def _flashinfer_fp8_blockscale_gemm_impl(
) )
return output return output
from vllm.model_executor.layers.batch_invariant import vllm_is_batch_invariant if envs.VLLM_BATCH_INVARIANT:
if vllm_is_batch_invariant():
return run_deepgemm(input, weight, weight_scale) return run_deepgemm(input, weight, weight_scale)
condition = input.shape[0] < 32 condition = input.shape[0] < 32
......
...@@ -19,9 +19,6 @@ import torch ...@@ -19,9 +19,6 @@ import torch
import vllm.envs as envs import vllm.envs as envs
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.model_executor.layers.batch_invariant import (
vllm_is_batch_invariant,
)
from vllm.platforms import current_platform from vllm.platforms import current_platform
logger = init_logger(__name__) logger = init_logger(__name__)
...@@ -289,7 +286,7 @@ def supports_trtllm_attention() -> bool: ...@@ -289,7 +286,7 @@ def supports_trtllm_attention() -> bool:
NVIDIA artifactory is accessible, and batch-invariant mode is not enabled. NVIDIA artifactory is accessible, and batch-invariant mode is not enabled.
""" """
# Batch-invariant mode disables TRTLLM attention # Batch-invariant mode disables TRTLLM attention
if vllm_is_batch_invariant(): if envs.VLLM_BATCH_INVARIANT:
return False return False
# Requires SM100 and NVIDIA artifactory to be accessible to download cubins # Requires SM100 and NVIDIA artifactory to be accessible to download cubins
......
...@@ -3,8 +3,8 @@ ...@@ -3,8 +3,8 @@
from typing import Any from typing import Any
import vllm.envs as envs
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.model_executor.layers.batch_invariant import vllm_is_batch_invariant
from vllm.platforms import current_platform from vllm.platforms import current_platform
logger = init_logger(__name__) logger = init_logger(__name__)
...@@ -114,7 +114,7 @@ def get_flash_attn_version( ...@@ -114,7 +114,7 @@ def get_flash_attn_version(
# FA4 currently uses batch-shape-dependent scheduling # FA4 currently uses batch-shape-dependent scheduling
# heuristics on SM100+, which breaks batch invariance. # heuristics on SM100+, which breaks batch invariance.
if vllm_is_batch_invariant() and fa_version == 4: if envs.VLLM_BATCH_INVARIANT and fa_version == 4:
logger.warning_once( logger.warning_once(
"Cannot use FA version 4 with batch invariance, " "Cannot use FA version 4 with batch invariance, "
"defaulting to FA version 2.", "defaulting to FA version 2.",
......
...@@ -33,6 +33,7 @@ if is_flash_attn_varlen_func_available(): ...@@ -33,6 +33,7 @@ if is_flash_attn_varlen_func_available():
get_scheduler_metadata, get_scheduler_metadata,
reshape_and_cache_flash, reshape_and_cache_flash,
) )
import vllm.envs as envs
from vllm.config import ( from vllm.config import (
VllmConfig, VllmConfig,
get_current_vllm_config, get_current_vllm_config,
...@@ -42,9 +43,6 @@ from vllm.config import ( ...@@ -42,9 +43,6 @@ from vllm.config import (
from vllm.config.cache import CacheDType from vllm.config.cache import CacheDType
from vllm.distributed.parallel_state import get_dcp_group from vllm.distributed.parallel_state import get_dcp_group
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.model_executor.layers.batch_invariant import (
vllm_is_batch_invariant,
)
from vllm.platforms.interface import DeviceCapability from vllm.platforms.interface import DeviceCapability
from vllm.utils.math_utils import cdiv, round_up from vllm.utils.math_utils import cdiv, round_up
from vllm.v1.attention.backend import ( from vllm.v1.attention.backend import (
...@@ -402,7 +400,7 @@ class FlashAttentionMetadataBuilder(AttentionMetadataBuilder[FlashAttentionMetad ...@@ -402,7 +400,7 @@ class FlashAttentionMetadataBuilder(AttentionMetadataBuilder[FlashAttentionMetad
# we only set num_splits when using cuda graphs. # we only set num_splits when using cuda graphs.
max_num_splits = self.max_num_splits max_num_splits = self.max_num_splits
if vllm_is_batch_invariant(): if envs.VLLM_BATCH_INVARIANT:
max_num_splits = 1 max_num_splits = 1
def schedule( def schedule(
...@@ -601,7 +599,7 @@ class FlashAttentionImpl(AttentionImpl): ...@@ -601,7 +599,7 @@ class FlashAttentionImpl(AttentionImpl):
scope="local", scope="local",
) )
# Cache the batch invariant result for use in forward passes # Cache the batch invariant result for use in forward passes
self.batch_invariant_enabled = vllm_is_batch_invariant() self.batch_invariant_enabled = envs.VLLM_BATCH_INVARIANT
if is_quantized_kv_cache(self.kv_cache_dtype) and not flash_attn_supports_fp8(): if is_quantized_kv_cache(self.kv_cache_dtype) and not flash_attn_supports_fp8():
raise NotImplementedError( raise NotImplementedError(
...@@ -1124,7 +1122,7 @@ def cascade_attention( ...@@ -1124,7 +1122,7 @@ def cascade_attention(
# s_aux is incorporated into prefix_lse inside the GPU kernel, # s_aux is incorporated into prefix_lse inside the GPU kernel,
# enabling its effect during the final attention merge. # enabling its effect during the final attention merge.
s_aux=s_aux, s_aux=s_aux,
num_splits=1 if vllm_is_batch_invariant() else max_num_splits, num_splits=1 if envs.VLLM_BATCH_INVARIANT else max_num_splits,
) )
descale_shape = (cu_query_lens.shape[0] - 1, key_cache.shape[-2]) descale_shape = (cu_query_lens.shape[0] - 1, key_cache.shape[-2])
...@@ -1149,7 +1147,7 @@ def cascade_attention( ...@@ -1149,7 +1147,7 @@ def cascade_attention(
q_descale=q_descale.expand(descale_shape) if q_descale is not None else None, q_descale=q_descale.expand(descale_shape) if q_descale is not None else None,
k_descale=k_descale.expand(descale_shape) if k_descale is not None else None, k_descale=k_descale.expand(descale_shape) if k_descale is not None else None,
v_descale=v_descale.expand(descale_shape) if v_descale is not None else None, v_descale=v_descale.expand(descale_shape) if v_descale is not None else None,
num_splits=1 if vllm_is_batch_invariant() else max_num_splits, num_splits=1 if envs.VLLM_BATCH_INVARIANT else max_num_splits,
) )
# Merge prefix and suffix outputs, and store the result in output. # Merge prefix and suffix outputs, and store the result in output.
......
...@@ -28,9 +28,6 @@ from vllm.config import ( ...@@ -28,9 +28,6 @@ from vllm.config import (
from vllm.config.cache import CacheDType from vllm.config.cache import CacheDType
from vllm.distributed.parallel_state import get_dcp_group from vllm.distributed.parallel_state import get_dcp_group
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.model_executor.layers.batch_invariant import (
vllm_is_batch_invariant,
)
from vllm.model_executor.layers.quantization.utils.quant_utils import ( from vllm.model_executor.layers.quantization.utils.quant_utils import (
QuantKey, QuantKey,
kFp8StaticTensorSym, kFp8StaticTensorSym,
...@@ -544,7 +541,7 @@ class FlashInferMetadataBuilder(AttentionMetadataBuilder[FlashInferMetadata]): ...@@ -544,7 +541,7 @@ class FlashInferMetadataBuilder(AttentionMetadataBuilder[FlashInferMetadata]):
) = None # Wrapper for prefill/append ) = None # Wrapper for prefill/append
self._decode_wrapper = None # Wrapper for decode (general shape) self._decode_wrapper = None # Wrapper for decode (general shape)
if vllm_is_batch_invariant(): if envs.VLLM_BATCH_INVARIANT:
self.decode_fixed_split_size = 2048 self.decode_fixed_split_size = 2048
self.prefill_fixed_split_size = 4096 self.prefill_fixed_split_size = 4096
self.disable_split_kv = True self.disable_split_kv = True
...@@ -719,7 +716,7 @@ class FlashInferMetadataBuilder(AttentionMetadataBuilder[FlashInferMetadata]): ...@@ -719,7 +716,7 @@ class FlashInferMetadataBuilder(AttentionMetadataBuilder[FlashInferMetadata]):
def _get_workspace_buffer(self): def _get_workspace_buffer(self):
if self._workspace_buffer is None: if self._workspace_buffer is None:
buffer_size = envs.VLLM_FLASHINFER_WORKSPACE_BUFFER_SIZE buffer_size = envs.VLLM_FLASHINFER_WORKSPACE_BUFFER_SIZE
if vllm_is_batch_invariant(): if envs.VLLM_BATCH_INVARIANT:
buffer_size = FLASHINFER_WORKSPACE_BUFFER_SIZE_BATCH_INVARIANT buffer_size = FLASHINFER_WORKSPACE_BUFFER_SIZE_BATCH_INVARIANT
self._workspace_buffer = torch.zeros( self._workspace_buffer = torch.zeros(
buffer_size, dtype=torch.uint8, device=self.device buffer_size, dtype=torch.uint8, device=self.device
......
...@@ -20,12 +20,10 @@ from torch.nn.attention.flex_attention import ( ...@@ -20,12 +20,10 @@ from torch.nn.attention.flex_attention import (
or_masks, or_masks,
) )
import vllm.envs as envs
from vllm.config import VllmConfig from vllm.config import VllmConfig
from vllm.config.cache import CacheDType from vllm.config.cache import CacheDType
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.model_executor.layers.batch_invariant import (
vllm_is_batch_invariant,
)
from vllm.platforms import current_platform from vllm.platforms import current_platform
from vllm.utils.math_utils import cdiv from vllm.utils.math_utils import cdiv
from vllm.utils.torch_utils import is_torch_equal_or_newer from vllm.utils.torch_utils import is_torch_equal_or_newer
...@@ -995,7 +993,7 @@ def get_kernel_options( ...@@ -995,7 +993,7 @@ def get_kernel_options(
return block_size return block_size
return candidate return candidate
if vllm_is_batch_invariant(): if envs.VLLM_BATCH_INVARIANT:
kernel_options["BLOCK_M"] = 16 kernel_options["BLOCK_M"] = 16
kernel_options["BLOCK_N"] = 16 kernel_options["BLOCK_N"] = 16
kernel_options["IS_DIVISIBLE"] = False kernel_options["IS_DIVISIBLE"] = False
......
...@@ -6,6 +6,7 @@ from typing import ClassVar ...@@ -6,6 +6,7 @@ from typing import ClassVar
import torch import torch
import vllm.envs as envs
from vllm.config import VllmConfig from vllm.config import VllmConfig
from vllm.config.cache import CacheDType from vllm.config.cache import CacheDType
from vllm.logger import init_logger from vllm.logger import init_logger
...@@ -17,9 +18,6 @@ from vllm.model_executor.layers.attention.mla_attention import ( ...@@ -17,9 +18,6 @@ from vllm.model_executor.layers.attention.mla_attention import (
MLACommonMetadataBuilder, MLACommonMetadataBuilder,
QueryLenSupport, QueryLenSupport,
) )
from vllm.model_executor.layers.batch_invariant import (
vllm_is_batch_invariant,
)
from vllm.platforms.interface import DeviceCapability from vllm.platforms.interface import DeviceCapability
from vllm.utils.math_utils import round_up from vllm.utils.math_utils import round_up
from vllm.v1.attention.backend import ( from vllm.v1.attention.backend import (
...@@ -152,7 +150,7 @@ class FlashAttnMLAMetadataBuilder(MLACommonMetadataBuilder[FlashAttnMLAMetadata] ...@@ -152,7 +150,7 @@ class FlashAttnMLAMetadataBuilder(MLACommonMetadataBuilder[FlashAttnMLAMetadata]
vllm_config.attention_config.flash_attn_max_num_splits_for_cuda_graph vllm_config.attention_config.flash_attn_max_num_splits_for_cuda_graph
) )
if vllm_is_batch_invariant(): if envs.VLLM_BATCH_INVARIANT:
self.max_num_splits = 1 self.max_num_splits = 1
def _schedule_decode( def _schedule_decode(
...@@ -209,7 +207,7 @@ class FlashAttnMLAMetadataBuilder(MLACommonMetadataBuilder[FlashAttnMLAMetadata] ...@@ -209,7 +207,7 @@ class FlashAttnMLAMetadataBuilder(MLACommonMetadataBuilder[FlashAttnMLAMetadata]
# we only set num_splits when using cuda graphs. # we only set num_splits when using cuda graphs.
max_num_splits = self.max_num_splits max_num_splits = self.max_num_splits
if vllm_is_batch_invariant(): if envs.VLLM_BATCH_INVARIANT:
max_num_splits = 1 max_num_splits = 1
scheduler_metadata = self._schedule_decode( scheduler_metadata = self._schedule_decode(
......
...@@ -6,6 +6,7 @@ from typing import ClassVar ...@@ -6,6 +6,7 @@ from typing import ClassVar
import torch import torch
import vllm.envs as envs
from vllm.config import VllmConfig from vllm.config import VllmConfig
from vllm.config.cache import CacheDType from vllm.config.cache import CacheDType
from vllm.logger import init_logger from vllm.logger import init_logger
...@@ -17,9 +18,6 @@ from vllm.model_executor.layers.attention.mla_attention import ( ...@@ -17,9 +18,6 @@ from vllm.model_executor.layers.attention.mla_attention import (
MLACommonMetadataBuilder, MLACommonMetadataBuilder,
QueryLenSupport, QueryLenSupport,
) )
from vllm.model_executor.layers.batch_invariant import (
vllm_is_batch_invariant,
)
from vllm.platforms.interface import DeviceCapability from vllm.platforms.interface import DeviceCapability
from vllm.utils.platform_utils import num_compute_units from vllm.utils.platform_utils import num_compute_units
from vllm.v1.attention.backend import ( from vllm.v1.attention.backend import (
...@@ -256,7 +254,7 @@ class FlashMLAImpl(MLACommonImpl[FlashMLAMetadata]): ...@@ -256,7 +254,7 @@ class FlashMLAImpl(MLACommonImpl[FlashMLAMetadata]):
q = reshape_query_for_spec_decode(q, num_decodes) q = reshape_query_for_spec_decode(q, num_decodes)
scheduler_metadata = attn_metadata.decode.scheduler_metadata scheduler_metadata = attn_metadata.decode.scheduler_metadata
if vllm_is_batch_invariant() and not self.kv_cache_dtype.startswith("fp8"): if envs.VLLM_BATCH_INVARIANT and not self.kv_cache_dtype.startswith("fp8"):
device = q.device device = q.device
dtype = torch.int32 dtype = torch.int32
......
...@@ -5,6 +5,7 @@ from typing import ClassVar ...@@ -5,6 +5,7 @@ from typing import ClassVar
import torch import torch
import vllm.envs as envs
from vllm.config.cache import CacheDType from vllm.config.cache import CacheDType
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.model_executor.layers.attention.mla_attention import ( from vllm.model_executor.layers.attention.mla_attention import (
...@@ -12,9 +13,6 @@ from vllm.model_executor.layers.attention.mla_attention import ( ...@@ -12,9 +13,6 @@ from vllm.model_executor.layers.attention.mla_attention import (
MLACommonImpl, MLACommonImpl,
MLACommonMetadata, MLACommonMetadata,
) )
from vllm.model_executor.layers.batch_invariant import (
vllm_is_batch_invariant,
)
from vllm.platforms.interface import DeviceCapability from vllm.platforms.interface import DeviceCapability
from vllm.v1.attention.backend import ( from vllm.v1.attention.backend import (
AttentionLayer, AttentionLayer,
...@@ -151,7 +149,7 @@ class TritonMLAImpl(MLACommonImpl[MLACommonMetadata]): ...@@ -151,7 +149,7 @@ class TritonMLAImpl(MLACommonImpl[MLACommonMetadata]):
lse = torch.zeros(B, q_num_heads, dtype=q.dtype, device=q.device) lse = torch.zeros(B, q_num_heads, dtype=q.dtype, device=q.device)
# For batch invariance, use only 1 split to ensure deterministic reduction # For batch invariance, use only 1 split to ensure deterministic reduction
num_kv_splits = 1 if vllm_is_batch_invariant() else 4 num_kv_splits = 1 if envs.VLLM_BATCH_INVARIANT else 4
# TODO(lucas) Allocate ahead of time # TODO(lucas) Allocate ahead of time
attn_logits = torch.empty( attn_logits = torch.empty(
......
...@@ -9,13 +9,13 @@ ...@@ -9,13 +9,13 @@
import torch import torch
import vllm.envs as envs
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.model_executor.layers.batch_invariant import vllm_is_batch_invariant
from vllm.platforms import current_platform from vllm.platforms import current_platform
from vllm.triton_utils import tl, triton from vllm.triton_utils import tl, triton
logger = init_logger(__name__) logger = init_logger(__name__)
is_batch_invariant = vllm_is_batch_invariant() is_batch_invariant = envs.VLLM_BATCH_INVARIANT
float8_info = torch.finfo(current_platform.fp8_dtype()) float8_info = torch.finfo(current_platform.fp8_dtype())
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment