Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
b2f78cba
Unverified
Commit
b2f78cba
authored
Oct 16, 2025
by
Bram Wasti
Committed by
GitHub
Oct 16, 2025
Browse files
[small][batch invariance] Rename the env and internal flags to simplify usage (#26855)
Signed-off-by:
Bram Wasti
<
bwasti@meta.com
>
parent
23583ee2
Changes
20
Hide whitespace changes
Inline
Side-by-side
Showing
20 changed files
with
61 additions
and
61 deletions
+61
-61
csrc/core/batch_invariant.hpp
csrc/core/batch_invariant.hpp
+4
-4
csrc/layernorm_kernels.cu
csrc/layernorm_kernels.cu
+2
-2
csrc/layernorm_quant_kernels.cu
csrc/layernorm_quant_kernels.cu
+1
-1
tests/v1/e2e/test_async_sched_and_preempt.py
tests/v1/e2e/test_async_sched_and_preempt.py
+1
-1
tests/v1/generation/test_batch_invariance.py
tests/v1/generation/test_batch_invariance.py
+12
-12
vllm/config/model.py
vllm/config/model.py
+2
-2
vllm/config/parallel.py
vllm/config/parallel.py
+2
-2
vllm/distributed/device_communicators/all_reduce_utils.py
vllm/distributed/device_communicators/all_reduce_utils.py
+2
-2
vllm/distributed/device_communicators/symm_mem.py
vllm/distributed/device_communicators/symm_mem.py
+2
-2
vllm/model_executor/layers/batch_invariant.py
vllm/model_executor/layers/batch_invariant.py
+3
-3
vllm/model_executor/layers/fused_moe/fused_moe.py
vllm/model_executor/layers/fused_moe/fused_moe.py
+5
-5
vllm/model_executor/layers/layernorm.py
vllm/model_executor/layers/layernorm.py
+3
-3
vllm/model_executor/layers/quantization/fp8.py
vllm/model_executor/layers/quantization/fp8.py
+3
-3
vllm/v1/attention/backends/flash_attn.py
vllm/v1/attention/backends/flash_attn.py
+5
-5
vllm/v1/attention/backends/flashinfer.py
vllm/v1/attention/backends/flashinfer.py
+3
-3
vllm/v1/attention/backends/flex_attention.py
vllm/v1/attention/backends/flex_attention.py
+2
-2
vllm/v1/attention/backends/mla/common.py
vllm/v1/attention/backends/mla/common.py
+2
-2
vllm/v1/attention/backends/mla/flashattn_mla.py
vllm/v1/attention/backends/mla/flashattn_mla.py
+3
-3
vllm/v1/attention/backends/mla/flashmla.py
vllm/v1/attention/backends/mla/flashmla.py
+2
-2
vllm/v1/attention/backends/mla/triton_mla.py
vllm/v1/attention/backends/mla/triton_mla.py
+2
-2
No files found.
csrc/core/batch_invariant.hpp
View file @
b2f78cba
...
...
@@ -5,11 +5,11 @@
namespace
vllm
{
// vllm_
kernel_override
_batch_invariant(); returns true
// if env VLLM_
KERNEL_OVERRIDE_
BATCH_INVARIANT=1
inline
bool
vllm_
kernel_override
_batch_invariant
()
{
// vllm_
is
_batch_invariant(); returns true
// if env VLLM_BATCH_INVARIANT=1
inline
bool
vllm_
is
_batch_invariant
()
{
static
bool
cached
=
[]()
{
std
::
string
env_key
=
"VLLM_
KERNEL_OVERRIDE_
BATCH_INVARIANT"
;
std
::
string
env_key
=
"VLLM_BATCH_INVARIANT"
;
const
char
*
val
=
std
::
getenv
(
env_key
.
c_str
());
return
(
val
&&
std
::
atoi
(
val
)
!=
0
)
?
1
:
0
;
}();
...
...
csrc/layernorm_kernels.cu
View file @
b2f78cba
...
...
@@ -426,7 +426,7 @@ void fused_add_rms_norm(torch::Tensor& input, // [..., hidden_size]
wt_ptr
%
req_alignment_bytes
==
0
;
bool
offsets_are_multiple_of_vector_width
=
hidden_size
%
vector_width
==
0
&&
input_stride
%
vector_width
==
0
;
bool
batch_invariant_launch
=
vllm
::
vllm_
kernel_override
_batch_invariant
();
bool
batch_invariant_launch
=
vllm
::
vllm_
is
_batch_invariant
();
if
(
ptrs_are_aligned
&&
offsets_are_multiple_of_vector_width
&&
!
batch_invariant_launch
)
{
LAUNCH_FUSED_ADD_RMS_NORM
(
8
);
...
...
@@ -474,7 +474,7 @@ void poly_norm(torch::Tensor& out, // [..., hidden_size]
auto
inp_ptr
=
reinterpret_cast
<
std
::
uintptr_t
>
(
input
.
data_ptr
());
auto
out_ptr
=
reinterpret_cast
<
std
::
uintptr_t
>
(
out
.
data_ptr
());
bool
ptrs_are_aligned
=
inp_ptr
%
16
==
0
&&
out_ptr
%
16
==
0
;
bool
batch_invariant_launch
=
vllm
::
vllm_
kernel_override
_batch_invariant
();
bool
batch_invariant_launch
=
vllm
::
vllm_
is
_batch_invariant
();
if
(
ptrs_are_aligned
&&
hidden_size
%
8
==
0
&&
!
batch_invariant_launch
)
{
LAUNCH_FUSED_POLY_NORM
(
8
);
}
else
{
...
...
csrc/layernorm_quant_kernels.cu
View file @
b2f78cba
...
...
@@ -254,7 +254,7 @@ void fused_add_rms_norm_static_fp8_quant(
auto
wt_ptr
=
reinterpret_cast
<
std
::
uintptr_t
>
(
weight
.
data_ptr
());
bool
ptrs_are_aligned
=
inp_ptr
%
16
==
0
&&
res_ptr
%
16
==
0
&&
wt_ptr
%
16
==
0
;
bool
batch_invariant_launch
=
vllm
::
vllm_
kernel_override
_batch_invariant
();
bool
batch_invariant_launch
=
vllm
::
vllm_
is
_batch_invariant
();
if
(
ptrs_are_aligned
&&
hidden_size
%
8
==
0
&&
input_stride
%
8
==
0
&&
!
batch_invariant_launch
)
{
LAUNCH_FUSED_ADD_RMS_NORM
(
8
);
...
...
tests/v1/e2e/test_async_sched_and_preempt.py
View file @
b2f78cba
...
...
@@ -39,7 +39,7 @@ def test_preempt_and_async_scheduling_e2e(monkeypatch: pytest.MonkeyPatch):
with
monkeypatch
.
context
()
as
m
:
m
.
setenv
(
"VLLM_ATTENTION_BACKEND"
,
"FLEX_ATTENTION"
)
# m.setenv("VLLM_
KERNEL_OVERRIDE_
BATCH_INVARIANT", "1")
# m.setenv("VLLM_BATCH_INVARIANT", "1")
outputs
:
list
[
tuple
[
str
,
list
]]
=
[]
for
test_preemption
in
[
False
,
True
]:
...
...
tests/v1/generation/test_batch_invariance.py
View file @
b2f78cba
...
...
@@ -19,14 +19,14 @@ hopper_only = pytest.mark.skipif(
@
pytest
.
fixture
(
autouse
=
True
)
def
enable_batch_invariant_mode
():
"""Automatically enable batch invariant kernel overrides for all tests."""
old_value
=
os
.
environ
.
get
(
"VLLM_
KERNEL_OVERRIDE_
BATCH_INVARIANT"
)
os
.
environ
[
"VLLM_
KERNEL_OVERRIDE_
BATCH_INVARIANT"
]
=
"1"
old_value
=
os
.
environ
.
get
(
"VLLM_BATCH_INVARIANT"
)
os
.
environ
[
"VLLM_BATCH_INVARIANT"
]
=
"1"
yield
# Restore original value after test
if
old_value
is
None
:
os
.
environ
.
pop
(
"VLLM_
KERNEL_OVERRIDE_
BATCH_INVARIANT"
,
None
)
os
.
environ
.
pop
(
"VLLM_BATCH_INVARIANT"
,
None
)
else
:
os
.
environ
[
"VLLM_
KERNEL_OVERRIDE_
BATCH_INVARIANT"
]
=
old_value
os
.
environ
[
"VLLM_BATCH_INVARIANT"
]
=
old_value
def
_random_prompt
(
min_words
:
int
=
1024
,
max_words
:
int
=
1024
*
2
)
->
str
:
...
...
@@ -231,10 +231,10 @@ def test_logprobs_bitwise_batch_invariance_bs1_vs_bsN(backend):
# For batch invariance, disable custom all-reduce to ensure deterministic
# all-reduce operations (custom all-reduce may not be deterministic)
from
vllm.model_executor.layers.batch_invariant
import
(
vllm_
kernel_override
_batch_invariant
,
vllm_
is
_batch_invariant
,
)
disable_custom_ar
=
vllm_
kernel_override
_batch_invariant
()
disable_custom_ar
=
vllm_
is
_batch_invariant
()
if
disable_custom_ar
:
print
(
f
"
\n
{
'='
*
80
}
"
)
...
...
@@ -494,8 +494,8 @@ def test_logprobs_WITHOUT_batch_invariance_should_FAIL(backend):
os
.
environ
[
"VLLM_ATTENTION_BACKEND"
]
=
backend
# CRITICAL: Disable batch invariance for this test
old_value
=
os
.
environ
.
get
(
"VLLM_
KERNEL_OVERRIDE_
BATCH_INVARIANT"
)
os
.
environ
[
"VLLM_
KERNEL_OVERRIDE_
BATCH_INVARIANT"
]
=
"0"
old_value
=
os
.
environ
.
get
(
"VLLM_BATCH_INVARIANT"
)
os
.
environ
[
"VLLM_BATCH_INVARIANT"
]
=
"0"
try
:
seed
=
int
(
os
.
getenv
(
"VLLM_TEST_SEED"
,
"12345"
))
...
...
@@ -687,9 +687,9 @@ def test_logprobs_WITHOUT_batch_invariance_should_FAIL(backend):
finally
:
# Restore original value
if
old_value
is
None
:
os
.
environ
.
pop
(
"VLLM_
KERNEL_OVERRIDE_
BATCH_INVARIANT"
,
None
)
os
.
environ
.
pop
(
"VLLM_BATCH_INVARIANT"
,
None
)
else
:
os
.
environ
[
"VLLM_
KERNEL_OVERRIDE_
BATCH_INVARIANT"
]
=
old_value
os
.
environ
[
"VLLM_BATCH_INVARIANT"
]
=
old_value
@
hopper_only
...
...
@@ -718,10 +718,10 @@ def test_decode_logprobs_match_prefill_logprobs(backend):
tp_size
=
int
(
os
.
getenv
(
"VLLM_TEST_TP_SIZE"
,
"1"
))
from
vllm.model_executor.layers.batch_invariant
import
(
vllm_
kernel_override
_batch_invariant
,
vllm_
is
_batch_invariant
,
)
disable_custom_ar
=
vllm_
kernel_override
_batch_invariant
()
disable_custom_ar
=
vllm_
is
_batch_invariant
()
if
disable_custom_ar
:
print
(
f
"
\n
{
'='
*
80
}
"
)
...
...
vllm/config/model.py
View file @
b2f78cba
...
...
@@ -21,7 +21,7 @@ from vllm.config.scheduler import RunnerType
from
vllm.config.utils
import
assert_hashable
,
config
,
getattr_iter
from
vllm.logger
import
init_logger
from
vllm.model_executor.layers.batch_invariant
import
(
vllm_
kernel_override
_batch_invariant
,
vllm_
is
_batch_invariant
,
)
from
vllm.platforms
import
current_platform
from
vllm.transformers_utils.config
import
(
...
...
@@ -423,7 +423,7 @@ class ModelConfig:
video_pruning_rate
:
float
|
None
,
)
->
None
:
# Enable batch invariance settings if requested
if
vllm_
kernel_override
_batch_invariant
():
if
vllm_
is
_batch_invariant
():
self
.
enforce_eager
=
True
# Set the default seed to 0 in V1.
...
...
vllm/config/parallel.py
View file @
b2f78cba
...
...
@@ -15,7 +15,7 @@ import vllm.envs as envs
from
vllm.config.utils
import
config
from
vllm.logger
import
init_logger
from
vllm.model_executor.layers.batch_invariant
import
(
vllm_
kernel_override
_batch_invariant
,
vllm_
is
_batch_invariant
,
)
from
vllm.platforms
import
current_platform
from
vllm.utils
import
cuda_device_count_stateless
,
get_open_ports_list
...
...
@@ -565,7 +565,7 @@ class ParallelConfig:
from
vllm.executor.executor_base
import
ExecutorBase
# Enable batch invariance settings if requested
if
vllm_
kernel_override
_batch_invariant
():
if
vllm_
is
_batch_invariant
():
self
.
disable_custom_all_reduce
=
True
if
(
...
...
vllm/distributed/device_communicators/all_reduce_utils.py
View file @
b2f78cba
...
...
@@ -20,7 +20,7 @@ import vllm.envs as envs
from
vllm.distributed.device_communicators.cuda_wrapper
import
CudaRTLibrary
from
vllm.logger
import
init_logger
from
vllm.model_executor.layers.batch_invariant
import
(
vllm_
kernel_override
_batch_invariant
,
vllm_
is
_batch_invariant
,
)
from
vllm.utils
import
cuda_device_count_stateless
,
update_environment_variables
...
...
@@ -74,7 +74,7 @@ def should_nccl_symm_mem_allreduce(world_size: int, input_tensor: torch.Tensor)
is_symmetric_memory_enabled
,
)
if
vllm_
kernel_override
_batch_invariant
():
if
vllm_
is
_batch_invariant
():
return
False
if
not
is_symmetric_memory_enabled
():
...
...
vllm/distributed/device_communicators/symm_mem.py
View file @
b2f78cba
...
...
@@ -10,7 +10,7 @@ from vllm.distributed.device_communicators.all_reduce_utils import (
)
from
vllm.logger
import
init_logger
from
vllm.model_executor.layers.batch_invariant
import
(
vllm_
kernel_override
_batch_invariant
,
vllm_
is
_batch_invariant
,
)
from
vllm.platforms
import
current_platform
...
...
@@ -103,7 +103,7 @@ class SymmMemCommunicator:
return
self
.
force_multimem
=
force_multimem
self
.
disabled
=
False
if
vllm_
kernel_override
_batch_invariant
():
if
vllm_
is
_batch_invariant
():
self
.
disabled
=
True
def
should_use_symm_mem
(
self
,
inp
:
torch
.
Tensor
):
...
...
vllm/model_executor/layers/batch_invariant.py
View file @
b2f78cba
...
...
@@ -741,8 +741,8 @@ def get_batch_invariant_attention_block_size() -> AttentionBlockSize:
return
AttentionBlockSize
(
block_m
=
16
,
block_n
=
16
)
def
vllm_
kernel_override
_batch_invariant
():
env_key
=
"VLLM_
KERNEL_OVERRIDE_
BATCH_INVARIANT"
def
vllm_
is
_batch_invariant
():
env_key
=
"VLLM_BATCH_INVARIANT"
is_overridden
=
False
val
=
os
.
getenv
(
env_key
,
"0"
)
try
:
...
...
@@ -797,7 +797,7 @@ def override_envs_for_invariance():
def
init_batch_invariance
():
# this will hit all the csrc overrides as well
if
vllm_
kernel_override
_batch_invariant
():
if
vllm_
is
_batch_invariant
():
override_envs_for_invariance
()
enable_batch_invariant_mode
()
...
...
vllm/model_executor/layers/fused_moe/fused_moe.py
View file @
b2f78cba
...
...
@@ -16,7 +16,7 @@ import vllm.model_executor.layers.fused_moe.modular_kernel as mk
from
vllm
import
_custom_ops
as
ops
from
vllm.logger
import
init_logger
from
vllm.model_executor.layers.batch_invariant
import
(
vllm_
kernel_override
_batch_invariant
,
vllm_
is
_batch_invariant
,
)
from
vllm.model_executor.layers.fused_moe.config
import
(
FUSED_MOE_UNQUANTIZED_CONFIG
,
...
...
@@ -841,7 +841,7 @@ def get_moe_configs(
"""
# Avoid optimizing for the batch invariant case. Use default config
if
vllm_
kernel_override
_batch_invariant
():
if
vllm_
is
_batch_invariant
():
return
None
# First look up if an optimized configuration is available in the configs
...
...
@@ -976,7 +976,7 @@ def get_default_config(
dtype
:
str
|
None
,
block_shape
:
list
[
int
]
|
None
=
None
,
)
->
dict
[
str
,
int
]:
if
vllm_
kernel_override
_batch_invariant
():
if
vllm_
is
_batch_invariant
():
config
=
{
"BLOCK_SIZE_M"
:
64
,
"BLOCK_SIZE_N"
:
64
,
...
...
@@ -1136,7 +1136,7 @@ def fused_topk_bias(
)
+
e_score_correction_bias
.
unsqueeze
(
0
)
# For batch invariance, use sorted=True to ensure deterministic expert selection
use_sorted
=
vllm_
kernel_override
_batch_invariant
()
use_sorted
=
vllm_
is
_batch_invariant
()
topk_indices
=
torch
.
topk
(
scores_for_choice
,
k
=
topk
,
dim
=-
1
,
sorted
=
use_sorted
)[
1
]
topk_weights
=
scores
.
gather
(
1
,
topk_indices
)
if
renormalize
:
...
...
@@ -1200,7 +1200,7 @@ def grouped_topk(
)
# [n, n_group]
# For batch invariance, use sorted=True to ensure deterministic expert selection
use_sorted
=
vllm_
kernel_override
_batch_invariant
()
use_sorted
=
vllm_
is
_batch_invariant
()
group_idx
=
torch
.
topk
(
group_scores
,
k
=
topk_group
,
dim
=-
1
,
sorted
=
use_sorted
)[
1
]
# [n, top_k_group]
...
...
vllm/model_executor/layers/layernorm.py
View file @
b2f78cba
...
...
@@ -10,7 +10,7 @@ import vllm.envs as envs
from
vllm.model_executor.custom_op
import
CustomOp
from
vllm.model_executor.layers.batch_invariant
import
(
rms_norm_batch_invariant
,
vllm_
kernel_override
_batch_invariant
,
vllm_
is
_batch_invariant
,
)
from
vllm.platforms
import
current_platform
from
vllm.utils
import
direct_register_custom_op
...
...
@@ -25,7 +25,7 @@ def rms_norm(
)
->
torch
.
Tensor
:
from
vllm
import
_custom_ops
as
ops
if
vllm_
kernel_override
_batch_invariant
():
if
vllm_
is
_batch_invariant
():
return
rms_norm_batch_invariant
(
x
,
weight
,
variance_epsilon
)
out
=
torch
.
empty_like
(
x
)
ops
.
rms_norm
(
...
...
@@ -45,7 +45,7 @@ def fused_add_rms_norm(
)
->
tuple
[
torch
.
Tensor
,
torch
.
Tensor
]:
from
vllm
import
_custom_ops
as
ops
if
vllm_
kernel_override
_batch_invariant
():
if
vllm_
is
_batch_invariant
():
return
rms_norm_batch_invariant
(
x
+
residual
,
weight
,
variance_epsilon
),
x
+
residual
...
...
vllm/model_executor/layers/quantization/fp8.py
View file @
b2f78cba
...
...
@@ -15,7 +15,7 @@ from vllm import _custom_ops as ops
from
vllm.distributed
import
get_tensor_model_parallel_world_size
from
vllm.logger
import
init_logger
from
vllm.model_executor.layers.batch_invariant
import
(
vllm_
kernel_override
_batch_invariant
,
vllm_
is
_batch_invariant
,
)
from
vllm.model_executor.layers.fused_moe
import
(
FusedMoE
,
...
...
@@ -356,7 +356,7 @@ class Fp8LinearMethod(LinearMethodBase):
# Disable marlin for rocm
if
current_platform
.
is_rocm
():
self
.
use_marlin
=
False
if
vllm_
kernel_override
_batch_invariant
():
if
vllm_
is
_batch_invariant
():
self
.
use_marlin
=
False
self
.
use_aiter_and_is_supported
=
check_aiter_fp8_linear_support
()
...
...
@@ -540,7 +540,7 @@ class Fp8LinearMethod(LinearMethodBase):
bias
:
torch
.
Tensor
|
None
=
None
,
)
->
torch
.
Tensor
:
# If batch invariant mode is enabled, dequantize and use BF16 compute
if
vllm_
kernel_override
_batch_invariant
():
if
vllm_
is
_batch_invariant
():
# Dequantize FP8 weights to BF16
weight_fp8
=
layer
.
weight
.
to
(
torch
.
bfloat16
)
weight_scale
=
layer
.
weight_scale
.
to
(
torch
.
bfloat16
)
...
...
vllm/v1/attention/backends/flash_attn.py
View file @
b2f78cba
...
...
@@ -35,7 +35,7 @@ from vllm.config import VllmConfig, get_layers_from_vllm_config
from
vllm.distributed.parallel_state
import
get_dcp_group
from
vllm.logger
import
init_logger
from
vllm.model_executor.layers.batch_invariant
import
(
vllm_
kernel_override
_batch_invariant
,
vllm_
is
_batch_invariant
,
)
from
vllm.utils
import
cdiv
from
vllm.v1.attention.backends.utils
import
(
...
...
@@ -308,7 +308,7 @@ class FlashAttentionMetadataBuilder(AttentionMetadataBuilder[FlashAttentionMetad
# we only set num_splits when using cuda graphs.
max_num_splits
=
self
.
max_num_splits
if
vllm_
kernel_override
_batch_invariant
():
if
vllm_
is
_batch_invariant
():
max_num_splits
=
1
def
schedule
(
...
...
@@ -484,7 +484,7 @@ class FlashAttentionImpl(AttentionImpl):
self
.
attn_type
=
attn_type
self
.
vllm_flash_attn_version
=
get_flash_attn_version
()
# Cache the batch invariant result for use in forward passes
self
.
batch_invariant_enabled
=
vllm_
kernel_override
_batch_invariant
()
self
.
batch_invariant_enabled
=
vllm_
is
_batch_invariant
()
if
is_quantized_kv_cache
(
self
.
kv_cache_dtype
)
and
not
flash_attn_supports_fp8
():
raise
NotImplementedError
(
...
...
@@ -963,7 +963,7 @@ def cascade_attention(
# s_aux is incorporated into prefix_lse inside the GPU kernel,
# enabling its effect during the final attention merge.
s_aux
=
s_aux
,
num_splits
=
1
if
vllm_
kernel_override
_batch_invariant
()
else
0
,
num_splits
=
1
if
vllm_
is
_batch_invariant
()
else
0
,
)
descale_shape
=
(
cu_query_lens
.
shape
[
0
]
-
1
,
key_cache
.
shape
[
-
2
])
...
...
@@ -988,7 +988,7 @@ def cascade_attention(
q_descale
=
q_descale
.
expand
(
descale_shape
)
if
q_descale
is
not
None
else
None
,
k_descale
=
k_descale
.
expand
(
descale_shape
)
if
k_descale
is
not
None
else
None
,
v_descale
=
v_descale
.
expand
(
descale_shape
)
if
v_descale
is
not
None
else
None
,
num_splits
=
1
if
vllm_
kernel_override
_batch_invariant
()
else
0
,
num_splits
=
1
if
vllm_
is
_batch_invariant
()
else
0
,
)
# Merge prefix and suffix outputs, and store the result in output.
...
...
vllm/v1/attention/backends/flashinfer.py
View file @
b2f78cba
...
...
@@ -25,7 +25,7 @@ from vllm.attention.backends.abstract import (
from
vllm.config
import
CUDAGraphMode
,
VllmConfig
from
vllm.logger
import
init_logger
from
vllm.model_executor.layers.batch_invariant
import
(
vllm_
kernel_override
_batch_invariant
,
vllm_
is
_batch_invariant
,
)
from
vllm.model_executor.layers.quantization.utils.quant_utils
import
(
QuantKey
,
...
...
@@ -291,7 +291,7 @@ class FlashInferMetadataBuilder(AttentionMetadataBuilder[FlashInferMetadata]):
self
.
_prefill_wrapper
=
None
# Wrapper for prefill/append
self
.
_decode_wrapper
=
None
# Wrapper for decode (general shape)
if
vllm_
kernel_override
_batch_invariant
():
if
vllm_
is
_batch_invariant
():
self
.
decode_fixed_split_size
=
2048
self
.
prefill_fixed_split_size
=
4096
self
.
disable_split_kv
=
True
...
...
@@ -404,7 +404,7 @@ class FlashInferMetadataBuilder(AttentionMetadataBuilder[FlashInferMetadata]):
def
_get_workspace_buffer
(
self
):
if
self
.
_workspace_buffer
is
None
:
buffer_size
=
FLASHINFER_WORKSPACE_BUFFER_SIZE
if
vllm_
kernel_override
_batch_invariant
():
if
vllm_
is
_batch_invariant
():
buffer_size
=
FLASHINFER_WORKSPACE_BUFFER_SIZE_BATCH_INVARIANT
self
.
_workspace_buffer
=
torch
.
zeros
(
buffer_size
,
dtype
=
torch
.
uint8
,
device
=
self
.
device
...
...
vllm/v1/attention/backends/flex_attention.py
View file @
b2f78cba
...
...
@@ -26,7 +26,7 @@ from vllm.attention.backends.abstract import (
from
vllm.config
import
VllmConfig
from
vllm.logger
import
init_logger
from
vllm.model_executor.layers.batch_invariant
import
(
vllm_
kernel_override
_batch_invariant
,
vllm_
is
_batch_invariant
,
)
from
vllm.utils
import
cdiv
,
is_torch_equal_or_newer
from
vllm.v1.attention.backends.utils
import
(
...
...
@@ -863,7 +863,7 @@ def get_kernel_options(
kernel_options
:
dict
[
str
,
int
|
bool
]
=
{
"FORCE_USE_FLEX_ATTENTION"
:
True
,
}
if
vllm_
kernel_override
_batch_invariant
():
if
vllm_
is
_batch_invariant
():
kernel_options
[
"BLOCK_M"
]
=
16
kernel_options
[
"BLOCK_N"
]
=
16
kernel_options
[
"IS_DIVISIBLE"
]
=
False
...
...
vllm/v1/attention/backends/mla/common.py
View file @
b2f78cba
...
...
@@ -212,7 +212,7 @@ from vllm.config import VllmConfig, get_current_vllm_config
from
vllm.distributed.parallel_state
import
get_dcp_group
,
is_global_first_rank
from
vllm.logger
import
init_logger
from
vllm.model_executor.layers.batch_invariant
import
(
vllm_
kernel_override
_batch_invariant
,
vllm_
is
_batch_invariant
,
)
from
vllm.model_executor.layers.linear
import
(
ColumnParallelLinear
,
...
...
@@ -1283,7 +1283,7 @@ class MLACommonImpl(MLACommonBaseImpl[M], Generic[M]):
# ROCm leverages the upstream flash_attn, which takes a parameter
# called "return_attn_probs" instead of return_softmax_lse
kwargs
[
"return_attn_probs"
]
=
return_softmax_lse
if
vllm_
kernel_override
_batch_invariant
():
if
vllm_
is
_batch_invariant
():
kwargs
[
"num_splits"
]
=
1
attn_out
=
self
.
flash_attn_varlen_func
(
...
...
vllm/v1/attention/backends/mla/flashattn_mla.py
View file @
b2f78cba
...
...
@@ -19,7 +19,7 @@ from vllm.attention.utils.fa_utils import (
from
vllm.config
import
VllmConfig
from
vllm.logger
import
init_logger
from
vllm.model_executor.layers.batch_invariant
import
(
vllm_
kernel_override
_batch_invariant
,
vllm_
is
_batch_invariant
,
)
from
vllm.v1.attention.backends.mla.common
import
(
MLACommonBackend
,
...
...
@@ -110,7 +110,7 @@ class FlashAttnMLAMetadataBuilder(MLACommonMetadataBuilder[FlashAttnMLAMetadata]
# pre-allocated during capture.
self
.
max_num_splits
=
envs
.
VLLM_FLASH_ATTN_MAX_NUM_SPLITS_FOR_CUDA_GRAPH
if
vllm_
kernel_override
_batch_invariant
():
if
vllm_
is
_batch_invariant
():
self
.
max_num_splits
=
1
def
_schedule_decode
(
...
...
@@ -181,7 +181,7 @@ class FlashAttnMLAMetadataBuilder(MLACommonMetadataBuilder[FlashAttnMLAMetadata]
# we only set num_splits when using cuda graphs.
max_num_splits
=
self
.
max_num_splits
if
vllm_
kernel_override
_batch_invariant
():
if
vllm_
is
_batch_invariant
():
max_num_splits
=
1
metadata
=
FlashAttnMLADecodeMetadata
(
...
...
vllm/v1/attention/backends/mla/flashmla.py
View file @
b2f78cba
...
...
@@ -15,7 +15,7 @@ from vllm.attention.ops.flashmla import (
from
vllm.config
import
VllmConfig
from
vllm.logger
import
init_logger
from
vllm.model_executor.layers.batch_invariant
import
(
vllm_
kernel_override
_batch_invariant
,
vllm_
is
_batch_invariant
,
)
from
vllm.v1.attention.backends.mla.common
import
(
MLACommonBackend
,
...
...
@@ -234,7 +234,7 @@ class FlashMLAImpl(MLACommonImpl[FlashMLAMetadata]):
tile_scheduler_metadata
=
attn_metadata
.
decode
.
tile_scheduler_metadata
num_splits
=
attn_metadata
.
decode
.
num_splits
if
vllm_
kernel_override
_batch_invariant
():
if
vllm_
is
_batch_invariant
():
device
=
q
.
device
dtype
=
torch
.
int32
...
...
vllm/v1/attention/backends/mla/triton_mla.py
View file @
b2f78cba
...
...
@@ -14,7 +14,7 @@ from vllm.attention.ops.triton_decode_attention import decode_attention_fwd
from
vllm.attention.ops.triton_flash_attention
import
triton_attention
from
vllm.logger
import
init_logger
from
vllm.model_executor.layers.batch_invariant
import
(
vllm_
kernel_override
_batch_invariant
,
vllm_
is
_batch_invariant
,
)
from
vllm.platforms
import
current_platform
from
vllm.triton_utils
import
HAS_TRITON
...
...
@@ -163,7 +163,7 @@ class TritonMLAImpl(MLACommonImpl[MLACommonMetadata]):
lse
=
torch
.
zeros
(
B
,
q_num_heads
,
dtype
=
q
.
dtype
,
device
=
q
.
device
)
# For batch invariance, use only 1 split to ensure deterministic reduction
num_kv_splits
=
1
if
vllm_
kernel_override
_batch_invariant
()
else
4
num_kv_splits
=
1
if
vllm_
is
_batch_invariant
()
else
4
# TODO(lucas) Allocate ahead of time
attn_logits
=
torch
.
empty
(
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment