Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
a0e9ee83
Unverified
Commit
a0e9ee83
authored
Jan 03, 2026
by
Alfred
Committed by
GitHub
Jan 02, 2026
Browse files
[Benchmark] Fix OOM during MoE kernel tuning for large models (#31604)
Signed-off-by:
Alfred
<
massif0601@gmail.com
>
parent
a3f2f409
Changes
1
Hide whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
55 additions
and
1 deletion
+55
-1
benchmarks/kernels/benchmark_moe.py
benchmarks/kernels/benchmark_moe.py
+55
-1
No files found.
benchmarks/kernels/benchmark_moe.py
View file @
a0e9ee83
...
...
@@ -2,6 +2,7 @@
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import
argparse
import
gc
import
json
import
os
import
time
...
...
@@ -26,6 +27,46 @@ from vllm.utils.argparse_utils import FlexibleArgumentParser
FP8_DTYPE
=
current_platform
.
fp8_dtype
()
# Default interval for clearing Triton JIT cache during tuning
# Set to 0 to disable automatic cache clearing
_CACHE_CLEAR_INTERVAL_ENV
=
"VLLM_MOE_TUNE_CACHE_CLEAR_INTERVAL"
TRITON_CACHE_CLEAR_INTERVAL
=
int
(
os
.
environ
.
get
(
_CACHE_CLEAR_INTERVAL_ENV
,
"50"
))
def
clear_triton_cache
():
"""Clear Triton JIT compilation cache and Python/CUDA memory.
This helps prevent OOM during tuning with large models (many experts).
"""
# Force Python garbage collection
gc
.
collect
()
# Clear CUDA memory cache
if
torch
.
cuda
.
is_available
():
torch
.
cuda
.
empty_cache
()
# Try to clear Triton's runtime cache
try
:
import
triton
if
(
hasattr
(
triton
,
"runtime"
)
and
hasattr
(
triton
.
runtime
,
"cache"
)
and
hasattr
(
triton
.
runtime
.
cache
,
"clear"
)
):
triton
.
runtime
.
cache
.
clear
()
except
ImportError
:
# Triton not installed, skip cache clearing
pass
except
AttributeError
:
# Triton version doesn't have expected cache API
pass
except
Exception
as
e
:
print
(
f
"Warning: Failed to clear Triton cache:
{
e
}
"
)
# Additional garbage collection after clearing caches
gc
.
collect
()
def
ensure_divisibility
(
numerator
,
denominator
,
text
):
"""Ensure that numerator is divisible by the denominator."""
...
...
@@ -483,7 +524,7 @@ class BenchmarkWorker:
need_device_guard
=
True
with
torch
.
cuda
.
device
(
self
.
device_id
)
if
need_device_guard
else
nullcontext
():
for
config
in
tqdm
(
search_space
):
for
idx
,
config
in
enumerate
(
tqdm
(
search_space
)
)
:
try
:
kernel_time
=
benchmark_config
(
config
,
...
...
@@ -506,6 +547,19 @@ class BenchmarkWorker:
if
kernel_time
<
best_time
:
best_time
=
kernel_time
best_config
=
config
# Periodically clear Triton JIT cache to prevent OOM
# This is especially important for large models with many experts
if
(
TRITON_CACHE_CLEAR_INTERVAL
>
0
and
idx
>
0
and
idx
%
TRITON_CACHE_CLEAR_INTERVAL
==
0
):
clear_triton_cache
()
# Final cleanup after tuning completes
clear_triton_cache
()
now
=
datetime
.
now
()
print
(
f
"
{
now
.
ctime
()
}
] Completed tuning for batch_size=
{
num_tokens
}
"
)
assert
best_config
is
not
None
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment