Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
faedbb4d
Unverified
Commit
faedbb4d
authored
Nov 05, 2025
by
Paul Zhang
Committed by
GitHub
Nov 05, 2025
Browse files
[Feature] Extend batch invariant torch.compile to B200 (#27856)
Signed-off-by:
PaulZhang12
<
paulzhan@fb.com
>
parent
40db1944
Changes
3
Hide whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
30 additions
and
17 deletions
+30
-17
tests/v1/generation/test_batch_invariance.py
tests/v1/generation/test_batch_invariance.py
+0
-2
vllm/model_executor/layers/batch_invariant.py
vllm/model_executor/layers/batch_invariant.py
+24
-15
vllm/utils/flashinfer.py
vllm/utils/flashinfer.py
+6
-0
No files found.
tests/v1/generation/test_batch_invariance.py
View file @
faedbb4d
...
...
@@ -456,7 +456,6 @@ def test_simple_generation(backend, monkeypatch: pytest.MonkeyPatch):
model
=
model
,
max_num_seqs
=
1
,
tensor_parallel_size
=
int
(
os
.
getenv
(
"VLLM_TP_SIZE"
,
"1"
)),
enforce_eager
=
True
,
gpu_memory_utilization
=
0.9
,
max_model_len
=
2048
,
dtype
=
"bfloat16"
,
...
...
@@ -998,7 +997,6 @@ def LLM_with_max_seqs(
dtype
=
"bfloat16"
,
tensor_parallel_size
=
int
(
os
.
getenv
(
"VLLM_TP_SIZE"
,
"1"
)),
enable_prefix_caching
=
False
,
enforce_eager
=
True
,
# Enable for MOE models
# enable_expert_parallel=True,
)
vllm/model_executor/layers/batch_invariant.py
View file @
faedbb4d
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import
contextlib
import
functools
import
os
from
collections
import
namedtuple
from
collections.abc
import
Callable
...
...
@@ -11,6 +10,7 @@ import torch
import
vllm.envs
as
envs
from
vllm.logger
import
init_logger
from
vllm.platforms
import
current_platform
from
vllm.triton_utils
import
tl
,
triton
from
vllm.utils.torch_utils
import
is_torch_equal_or_newer
...
...
@@ -737,11 +737,28 @@ def enable_batch_invariant_mode():
_batch_invariant_MODE
=
True
_batch_invariant_LIB
=
torch
.
library
.
Library
(
"aten"
,
"IMPL"
)
_batch_invariant_LIB
.
impl
(
"aten::mm"
,
mm_batch_invariant
,
"CUDA"
)
_batch_invariant_LIB
.
impl
(
"aten::addmm"
,
addmm_batch_invariant
,
"CUDA"
)
_batch_invariant_LIB
.
impl
(
"aten::matmul"
,
matmul_batch_invariant
,
"CUDA"
)
_batch_invariant_LIB
.
impl
(
"aten::bmm"
,
bmm_batch_invariant
,
"CUDA"
)
_batch_invariant_LIB
.
impl
(
"aten::linear"
,
linear_batch_invariant
,
"CUDA"
)
# Batch invariant matmuls are no longer needed after cublas overrides
if
not
is_torch_equal_or_newer
(
"2.10.0.dev"
):
if
current_platform
.
is_device_capability
(
100
):
# For PyTorch 2.9, B200 uses GEMV for bs=1
# Requires https://github.com/pytorch/pytorch/pull/166735
_batch_invariant_LIB
.
impl
(
"aten::mm"
,
mm_batch_invariant
,
"CUDA"
)
_batch_invariant_LIB
.
impl
(
"aten::addmm"
,
addmm_batch_invariant
,
"CUDA"
)
_batch_invariant_LIB
.
impl
(
"aten::matmul"
,
matmul_batch_invariant
,
"CUDA"
)
_batch_invariant_LIB
.
impl
(
"aten::linear"
,
linear_batch_invariant
,
"CUDA"
)
else
:
# The only source of batch non-invariance on Hopper is split-k, which can be
# disabled through the cuBLAS workspace config
_original_cublas_workspace_cfg
=
os
.
environ
.
get
(
"CUBLAS_WORKSPACE_CONFIG"
,
None
)
_original_cublaslt_workspace_size
=
os
.
environ
.
get
(
"CUBLASLT_WORKSPACE_SIZE"
,
None
)
os
.
environ
[
"CUBLAS_WORKSPACE_CONFIG"
]
=
":16:8"
os
.
environ
[
"CUBLASLT_WORKSPACE_SIZE"
]
=
"1"
_batch_invariant_LIB
.
impl
(
"aten::_log_softmax"
,
_log_softmax_batch_invariant
,
"CUDA"
)
...
...
@@ -750,6 +767,7 @@ def enable_batch_invariant_mode():
_batch_invariant_LIB
.
impl
(
"aten::mean.dim"
,
mean_batch_invariant
,
"CUDA"
)
# Also monkeypatch torch.bmm directly as a fallback
_batch_invariant_LIB
.
impl
(
"aten::bmm"
,
bmm_batch_invariant
,
"CUDA"
)
_original_torch_bmm
=
torch
.
bmm
torch
.
bmm
=
bmm_batch_invariant
...
...
@@ -771,14 +789,6 @@ def enable_batch_invariant_mode():
)
torch
.
backends
.
cuda
.
preferred_blas_library
(
backend
=
"cublaslt"
)
if
not
is_torch_equal_or_newer
(
"2.10.0.dev"
):
_original_cublas_workspace_cfg
=
os
.
environ
.
get
(
"CUBLAS_WORKSPACE_CONFIG"
,
None
)
_original_cublaslt_workspace_size
=
os
.
environ
.
get
(
"CUBLASLT_WORKSPACE_SIZE"
,
None
)
os
.
environ
[
"CUBLAS_WORKSPACE_CONFIG"
]
=
":16:8"
os
.
environ
[
"CUBLASLT_WORKSPACE_SIZE"
]
=
"1"
def
disable_batch_invariant_mode
():
global
_batch_invariant_MODE
,
_batch_invariant_LIB
,
_original_torch_bmm
...
...
@@ -847,7 +857,6 @@ def get_batch_invariant_attention_block_size() -> AttentionBlockSize:
return
AttentionBlockSize
(
block_m
=
16
,
block_n
=
16
)
@
functools
.
cache
def
vllm_is_batch_invariant
():
env_key
=
"VLLM_BATCH_INVARIANT"
is_overridden
=
False
...
...
vllm/utils/flashinfer.py
View file @
faedbb4d
...
...
@@ -19,6 +19,9 @@ import torch
import
vllm.envs
as
envs
from
vllm.logger
import
init_logger
from
vllm.model_executor.layers.batch_invariant
import
(
vllm_is_batch_invariant
,
)
from
vllm.platforms
import
current_platform
logger
=
init_logger
(
__name__
)
...
...
@@ -222,6 +225,9 @@ def force_use_trtllm_attention() -> bool | None:
return `True` if TRTLLM attention is forced to be used,
return `False` if TRTLLM attention is forced to be not used.
"""
if
vllm_is_batch_invariant
():
logger
.
info_once
(
"VLLM_USE_TRTLLM_ATTENTION is disabled for batch-invariant"
)
return
False
return
_force_use_trtllm_attention
(
envs
.
VLLM_USE_TRTLLM_ATTENTION
)
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment