Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
faedbb4d
Unverified
Commit
faedbb4d
authored
Nov 05, 2025
by
Paul Zhang
Committed by
GitHub
Nov 05, 2025
Browse files
[Feature] Extend batch invariant torch.compile to B200 (#27856)
Signed-off-by:
PaulZhang12
<
paulzhan@fb.com
>
parent
40db1944
Changes
3
Show whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
30 additions
and
17 deletions
+30
-17
tests/v1/generation/test_batch_invariance.py
tests/v1/generation/test_batch_invariance.py
+0
-2
vllm/model_executor/layers/batch_invariant.py
vllm/model_executor/layers/batch_invariant.py
+24
-15
vllm/utils/flashinfer.py
vllm/utils/flashinfer.py
+6
-0
No files found.
tests/v1/generation/test_batch_invariance.py
View file @
faedbb4d
...
@@ -456,7 +456,6 @@ def test_simple_generation(backend, monkeypatch: pytest.MonkeyPatch):
...
@@ -456,7 +456,6 @@ def test_simple_generation(backend, monkeypatch: pytest.MonkeyPatch):
model
=
model
,
model
=
model
,
max_num_seqs
=
1
,
max_num_seqs
=
1
,
tensor_parallel_size
=
int
(
os
.
getenv
(
"VLLM_TP_SIZE"
,
"1"
)),
tensor_parallel_size
=
int
(
os
.
getenv
(
"VLLM_TP_SIZE"
,
"1"
)),
enforce_eager
=
True
,
gpu_memory_utilization
=
0.9
,
gpu_memory_utilization
=
0.9
,
max_model_len
=
2048
,
max_model_len
=
2048
,
dtype
=
"bfloat16"
,
dtype
=
"bfloat16"
,
...
@@ -998,7 +997,6 @@ def LLM_with_max_seqs(
...
@@ -998,7 +997,6 @@ def LLM_with_max_seqs(
dtype
=
"bfloat16"
,
dtype
=
"bfloat16"
,
tensor_parallel_size
=
int
(
os
.
getenv
(
"VLLM_TP_SIZE"
,
"1"
)),
tensor_parallel_size
=
int
(
os
.
getenv
(
"VLLM_TP_SIZE"
,
"1"
)),
enable_prefix_caching
=
False
,
enable_prefix_caching
=
False
,
enforce_eager
=
True
,
# Enable for MOE models
# Enable for MOE models
# enable_expert_parallel=True,
# enable_expert_parallel=True,
)
)
vllm/model_executor/layers/batch_invariant.py
View file @
faedbb4d
# SPDX-License-Identifier: Apache-2.0
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import
contextlib
import
contextlib
import
functools
import
os
import
os
from
collections
import
namedtuple
from
collections
import
namedtuple
from
collections.abc
import
Callable
from
collections.abc
import
Callable
...
@@ -11,6 +10,7 @@ import torch
...
@@ -11,6 +10,7 @@ import torch
import
vllm.envs
as
envs
import
vllm.envs
as
envs
from
vllm.logger
import
init_logger
from
vllm.logger
import
init_logger
from
vllm.platforms
import
current_platform
from
vllm.triton_utils
import
tl
,
triton
from
vllm.triton_utils
import
tl
,
triton
from
vllm.utils.torch_utils
import
is_torch_equal_or_newer
from
vllm.utils.torch_utils
import
is_torch_equal_or_newer
...
@@ -737,11 +737,28 @@ def enable_batch_invariant_mode():
...
@@ -737,11 +737,28 @@ def enable_batch_invariant_mode():
_batch_invariant_MODE
=
True
_batch_invariant_MODE
=
True
_batch_invariant_LIB
=
torch
.
library
.
Library
(
"aten"
,
"IMPL"
)
_batch_invariant_LIB
=
torch
.
library
.
Library
(
"aten"
,
"IMPL"
)
# Batch invariant matmuls are no longer needed after cublas overrides
if
not
is_torch_equal_or_newer
(
"2.10.0.dev"
):
if
current_platform
.
is_device_capability
(
100
):
# For PyTorch 2.9, B200 uses GEMV for bs=1
# Requires https://github.com/pytorch/pytorch/pull/166735
_batch_invariant_LIB
.
impl
(
"aten::mm"
,
mm_batch_invariant
,
"CUDA"
)
_batch_invariant_LIB
.
impl
(
"aten::mm"
,
mm_batch_invariant
,
"CUDA"
)
_batch_invariant_LIB
.
impl
(
"aten::addmm"
,
addmm_batch_invariant
,
"CUDA"
)
_batch_invariant_LIB
.
impl
(
"aten::addmm"
,
addmm_batch_invariant
,
"CUDA"
)
_batch_invariant_LIB
.
impl
(
"aten::matmul"
,
matmul_batch_invariant
,
"CUDA"
)
_batch_invariant_LIB
.
impl
(
"aten::matmul"
,
matmul_batch_invariant
,
"CUDA"
)
_batch_invariant_LIB
.
impl
(
"aten::bmm"
,
bmm_batch_invariant
,
"CUDA"
)
_batch_invariant_LIB
.
impl
(
"aten::linear"
,
linear_batch_invariant
,
"CUDA"
)
_batch_invariant_LIB
.
impl
(
"aten::linear"
,
linear_batch_invariant
,
"CUDA"
)
else
:
# The only source of batch non-invariance on Hopper is split-k, which can be
# disabled through the cuBLAS workspace config
_original_cublas_workspace_cfg
=
os
.
environ
.
get
(
"CUBLAS_WORKSPACE_CONFIG"
,
None
)
_original_cublaslt_workspace_size
=
os
.
environ
.
get
(
"CUBLASLT_WORKSPACE_SIZE"
,
None
)
os
.
environ
[
"CUBLAS_WORKSPACE_CONFIG"
]
=
":16:8"
os
.
environ
[
"CUBLASLT_WORKSPACE_SIZE"
]
=
"1"
_batch_invariant_LIB
.
impl
(
_batch_invariant_LIB
.
impl
(
"aten::_log_softmax"
,
_log_softmax_batch_invariant
,
"CUDA"
"aten::_log_softmax"
,
_log_softmax_batch_invariant
,
"CUDA"
)
)
...
@@ -750,6 +767,7 @@ def enable_batch_invariant_mode():
...
@@ -750,6 +767,7 @@ def enable_batch_invariant_mode():
_batch_invariant_LIB
.
impl
(
"aten::mean.dim"
,
mean_batch_invariant
,
"CUDA"
)
_batch_invariant_LIB
.
impl
(
"aten::mean.dim"
,
mean_batch_invariant
,
"CUDA"
)
# Also monkeypatch torch.bmm directly as a fallback
# Also monkeypatch torch.bmm directly as a fallback
_batch_invariant_LIB
.
impl
(
"aten::bmm"
,
bmm_batch_invariant
,
"CUDA"
)
_original_torch_bmm
=
torch
.
bmm
_original_torch_bmm
=
torch
.
bmm
torch
.
bmm
=
bmm_batch_invariant
torch
.
bmm
=
bmm_batch_invariant
...
@@ -771,14 +789,6 @@ def enable_batch_invariant_mode():
...
@@ -771,14 +789,6 @@ def enable_batch_invariant_mode():
)
)
torch
.
backends
.
cuda
.
preferred_blas_library
(
backend
=
"cublaslt"
)
torch
.
backends
.
cuda
.
preferred_blas_library
(
backend
=
"cublaslt"
)
if
not
is_torch_equal_or_newer
(
"2.10.0.dev"
):
_original_cublas_workspace_cfg
=
os
.
environ
.
get
(
"CUBLAS_WORKSPACE_CONFIG"
,
None
)
_original_cublaslt_workspace_size
=
os
.
environ
.
get
(
"CUBLASLT_WORKSPACE_SIZE"
,
None
)
os
.
environ
[
"CUBLAS_WORKSPACE_CONFIG"
]
=
":16:8"
os
.
environ
[
"CUBLASLT_WORKSPACE_SIZE"
]
=
"1"
def
disable_batch_invariant_mode
():
def
disable_batch_invariant_mode
():
global
_batch_invariant_MODE
,
_batch_invariant_LIB
,
_original_torch_bmm
global
_batch_invariant_MODE
,
_batch_invariant_LIB
,
_original_torch_bmm
...
@@ -847,7 +857,6 @@ def get_batch_invariant_attention_block_size() -> AttentionBlockSize:
...
@@ -847,7 +857,6 @@ def get_batch_invariant_attention_block_size() -> AttentionBlockSize:
return
AttentionBlockSize
(
block_m
=
16
,
block_n
=
16
)
return
AttentionBlockSize
(
block_m
=
16
,
block_n
=
16
)
@
functools
.
cache
def
vllm_is_batch_invariant
():
def
vllm_is_batch_invariant
():
env_key
=
"VLLM_BATCH_INVARIANT"
env_key
=
"VLLM_BATCH_INVARIANT"
is_overridden
=
False
is_overridden
=
False
...
...
vllm/utils/flashinfer.py
View file @
faedbb4d
...
@@ -19,6 +19,9 @@ import torch
...
@@ -19,6 +19,9 @@ import torch
import
vllm.envs
as
envs
import
vllm.envs
as
envs
from
vllm.logger
import
init_logger
from
vllm.logger
import
init_logger
from
vllm.model_executor.layers.batch_invariant
import
(
vllm_is_batch_invariant
,
)
from
vllm.platforms
import
current_platform
from
vllm.platforms
import
current_platform
logger
=
init_logger
(
__name__
)
logger
=
init_logger
(
__name__
)
...
@@ -222,6 +225,9 @@ def force_use_trtllm_attention() -> bool | None:
...
@@ -222,6 +225,9 @@ def force_use_trtllm_attention() -> bool | None:
return `True` if TRTLLM attention is forced to be used,
return `True` if TRTLLM attention is forced to be used,
return `False` if TRTLLM attention is forced to be not used.
return `False` if TRTLLM attention is forced to be not used.
"""
"""
if
vllm_is_batch_invariant
():
logger
.
info_once
(
"VLLM_USE_TRTLLM_ATTENTION is disabled for batch-invariant"
)
return
False
return
_force_use_trtllm_attention
(
envs
.
VLLM_USE_TRTLLM_ATTENTION
)
return
_force_use_trtllm_attention
(
envs
.
VLLM_USE_TRTLLM_ATTENTION
)
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment