Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
change
sglang
Commits
eae9a9fb
Unverified
Commit
eae9a9fb
authored
Oct 10, 2025
by
Stefan He
Committed by
GitHub
Oct 10, 2025
Browse files
Fix batch invariant ops (#11368)
parent
2674c1d2
Changes
3
Hide whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
168 additions
and
6 deletions
+168
-6
python/sglang/srt/batch_invariant_ops/batch_invariant_ops.py
python/sglang/srt/batch_invariant_ops/batch_invariant_ops.py
+4
-6
test/srt/batch_invariant/test_batch_invariant_ops.py
test/srt/batch_invariant/test_batch_invariant_ops.py
+163
-0
test/srt/run_suite.py
test/srt/run_suite.py
+1
-0
No files found.
python/sglang/srt/batch_invariant_ops/batch_invariant_ops.py
View file @
eae9a9fb
...
@@ -77,8 +77,6 @@ def matmul_kernel_persistent(
...
@@ -77,8 +77,6 @@ def matmul_kernel_persistent(
k_tiles
=
tl
.
cdiv
(
K
,
BLOCK_SIZE_K
)
k_tiles
=
tl
.
cdiv
(
K
,
BLOCK_SIZE_K
)
num_tiles
=
num_pid_m
*
num_pid_n
num_tiles
=
num_pid_m
*
num_pid_n
tile_id_c
=
start_pid
-
NUM_SMS
offs_k_for_mask
=
tl
.
arange
(
0
,
BLOCK_SIZE_K
)
offs_k_for_mask
=
tl
.
arange
(
0
,
BLOCK_SIZE_K
)
num_pid_in_group
=
GROUP_SIZE_M
*
num_pid_n
num_pid_in_group
=
GROUP_SIZE_M
*
num_pid_n
...
@@ -120,10 +118,6 @@ def matmul_kernel_persistent(
...
@@ -120,10 +118,6 @@ def matmul_kernel_persistent(
)
)
accumulator
=
tl
.
dot
(
a
,
b
,
accumulator
)
accumulator
=
tl
.
dot
(
a
,
b
,
accumulator
)
tile_id_c
+=
NUM_SMS
pid_m
,
pid_n
=
_compute_pid
(
tile_id_c
,
num_pid_in_group
,
num_pid_m
,
GROUP_SIZE_M
,
NUM_SMS
)
offs_cm
=
pid_m
*
BLOCK_SIZE_M
+
tl
.
arange
(
0
,
BLOCK_SIZE_M
)
offs_cm
=
pid_m
*
BLOCK_SIZE_M
+
tl
.
arange
(
0
,
BLOCK_SIZE_M
)
offs_cn
=
pid_n
*
BLOCK_SIZE_N
+
tl
.
arange
(
0
,
BLOCK_SIZE_N
)
offs_cn
=
pid_n
*
BLOCK_SIZE_N
+
tl
.
arange
(
0
,
BLOCK_SIZE_N
)
if
C_LARGE
:
if
C_LARGE
:
...
@@ -137,6 +131,10 @@ def matmul_kernel_persistent(
...
@@ -137,6 +131,10 @@ def matmul_kernel_persistent(
accumulator
+=
bias
accumulator
+=
bias
if
c_ptr
.
dtype
.
element_ty
==
tl
.
float8e4nv
:
if
c_ptr
.
dtype
.
element_ty
==
tl
.
float8e4nv
:
c
=
accumulator
.
to
(
tl
.
float8e4nv
)
c
=
accumulator
.
to
(
tl
.
float8e4nv
)
elif
c_ptr
.
dtype
.
element_ty
==
tl
.
bfloat16
:
c
=
accumulator
.
to
(
tl
.
bfloat16
)
elif
c_ptr
.
dtype
.
element_ty
==
tl
.
float32
:
c
=
accumulator
.
to
(
tl
.
float32
)
else
:
else
:
c
=
accumulator
.
to
(
tl
.
float16
)
c
=
accumulator
.
to
(
tl
.
float16
)
tl
.
store
(
c_ptrs
,
c
,
mask
=
c_mask
)
tl
.
store
(
c_ptrs
,
c
,
mask
=
c_mask
)
...
...
test/srt/batch_invariant/test_batch_invariant_ops.py
0 → 100644
View file @
eae9a9fb
# Adapted from https://github.com/thinking-machines-lab/batch_invariant_ops/blob/main/test_batch_invariance.py
import
math
import
unittest
import
torch
from
sglang.srt.batch_invariant_ops.batch_invariant_ops
import
set_batch_invariant_mode
from
sglang.test.test_utils
import
CustomTestCase
device_type
=
getattr
(
torch
.
accelerator
.
current_accelerator
(),
"type"
,
"cpu"
)
torch
.
set_default_device
(
device_type
)
# Just to get the logging out of the way
with
set_batch_invariant_mode
(
True
):
pass
class
TestBatchInvariantOps
(
CustomTestCase
):
def
_test_batch_invariance
(
self
,
M
,
K
,
N
,
dtype
):
"""
Test that matrix operations produce identical results for:
- Method 1: Matrix-vector multiplication (batch size 1)
- Method 2: Matrix-matrix multiplication, then slice (full batch)
"""
a
=
torch
.
linspace
(
-
100
,
100
,
M
*
K
,
dtype
=
dtype
).
reshape
(
M
,
K
)
# Create non-contiguous tensor
b
=
torch
.
linspace
(
-
100
,
100
,
K
*
N
,
dtype
=
dtype
).
reshape
(
N
,
K
)
b
=
b
.
transpose
(
0
,
1
)
# Method 1: Matrix-vector multiplication (batch size 1)
out1
=
torch
.
mm
(
a
[:
1
],
b
)
# Method 2: Matrix-matrix multiplication, then slice (full batch)
out2_pre
=
torch
.
mm
(
a
,
b
)
out2
=
out2_pre
[:
1
]
# Check if results are identical
diff
=
(
out1
-
out2
).
abs
().
max
()
return
diff
.
item
()
def
_run_multiple_iterations
(
self
,
iters
,
M
,
K
,
N
,
dtype
):
"""Run multiple iterations and collect diff statistics"""
difflist
=
[]
for
_
in
range
(
iters
):
diff
=
self
.
_test_batch_invariance
(
M
,
K
,
N
,
dtype
)
difflist
.
append
(
diff
)
return
difflist
def
_assert_batch_invariant_results
(
self
,
difflist
,
dtype
,
test_name
):
"""
Assert that in batch-invariant mode:
1. All diffs must not be NaN
2. All diffs must be exactly 0
3. Max, min, and diff of diffs must all be 0
"""
max_diff
=
max
(
difflist
)
min_diff
=
min
(
difflist
)
diff_range
=
max_diff
-
min_diff
# Check for NaN values
self
.
assertFalse
(
math
.
isnan
(
max_diff
),
f
"
{
test_name
}
: max_diff is NaN for
{
dtype
}
"
)
self
.
assertFalse
(
math
.
isnan
(
min_diff
),
f
"
{
test_name
}
: min_diff is NaN for
{
dtype
}
"
)
self
.
assertFalse
(
math
.
isnan
(
diff_range
),
f
"
{
test_name
}
: diff_range is NaN for
{
dtype
}
"
)
# Check that all diffs are exactly 0
self
.
assertEqual
(
max_diff
,
0.0
,
f
"
{
test_name
}
: max_diff must be 0 in batch-invariant mode, got
{
max_diff
}
for
{
dtype
}
"
,
)
self
.
assertEqual
(
min_diff
,
0.0
,
f
"
{
test_name
}
: min_diff must be 0 in batch-invariant mode, got
{
min_diff
}
for
{
dtype
}
"
,
)
self
.
assertEqual
(
diff_range
,
0.0
,
f
"
{
test_name
}
: diff_range must be 0 in batch-invariant mode, got
{
diff_range
}
for
{
dtype
}
"
,
)
def
test_small_matrices
(
self
):
"""Test batch invariance with small matrix sizes"""
test_cases
=
[
(
"Small-1"
,
8
,
64
,
128
),
(
"Small-2"
,
16
,
128
,
256
),
(
"Small-3"
,
4
,
32
,
64
),
]
for
name
,
M
,
K
,
N
in
test_cases
:
with
self
.
subTest
(
name
=
name
,
M
=
M
,
K
=
K
,
N
=
N
):
for
dtype
in
[
torch
.
float32
,
torch
.
bfloat16
]:
with
self
.
subTest
(
dtype
=
dtype
):
# Run with batch-invariant mode
with
set_batch_invariant_mode
(
True
):
difflist
=
self
.
_run_multiple_iterations
(
iters
=
5
,
M
=
M
,
K
=
K
,
N
=
N
,
dtype
=
dtype
)
self
.
_assert_batch_invariant_results
(
difflist
,
dtype
,
name
)
def
test_medium_matrices
(
self
):
"""Test batch invariance with medium matrix sizes"""
test_cases
=
[
(
"Medium-1"
,
32
,
128
,
1024
),
(
"Medium-2"
,
64
,
512
,
2048
),
(
"Medium-3"
,
24
,
192
,
768
),
]
for
name
,
M
,
K
,
N
in
test_cases
:
with
self
.
subTest
(
name
=
name
,
M
=
M
,
K
=
K
,
N
=
N
):
for
dtype
in
[
torch
.
float32
,
torch
.
bfloat16
]:
with
self
.
subTest
(
dtype
=
dtype
):
# Run with batch-invariant mode
with
set_batch_invariant_mode
(
True
):
difflist
=
self
.
_run_multiple_iterations
(
iters
=
5
,
M
=
M
,
K
=
K
,
N
=
N
,
dtype
=
dtype
)
self
.
_assert_batch_invariant_results
(
difflist
,
dtype
,
name
)
def
test_large_matrices
(
self
):
"""Test batch invariance with large matrix sizes"""
test_cases
=
[
(
"Large-1"
,
128
,
1024
,
4096
),
(
"Large-2"
,
256
,
2048
,
8192
),
(
"Large-3"
,
96
,
768
,
3072
),
]
for
name
,
M
,
K
,
N
in
test_cases
:
with
self
.
subTest
(
name
=
name
,
M
=
M
,
K
=
K
,
N
=
N
):
for
dtype
in
[
torch
.
float32
,
torch
.
bfloat16
]:
with
self
.
subTest
(
dtype
=
dtype
):
# Run with batch-invariant mode
with
set_batch_invariant_mode
(
True
):
difflist
=
self
.
_run_multiple_iterations
(
iters
=
5
,
M
=
M
,
K
=
K
,
N
=
N
,
dtype
=
dtype
)
self
.
_assert_batch_invariant_results
(
difflist
,
dtype
,
name
)
def
test_without_batch_invariant_mode
(
self
):
"""
Test that without batch-invariant mode, results may differ.
This test demonstrates the difference batch-invariant mode makes.
"""
M
,
K
,
N
=
32
,
128
,
1024
dtype
=
torch
.
float32
# Run without batch-invariant mode
with
set_batch_invariant_mode
(
False
):
difflist
=
self
.
_run_multiple_iterations
(
iters
=
5
,
M
=
M
,
K
=
K
,
N
=
N
,
dtype
=
dtype
)
print
(
f
"Without batch-invariant mode, we get diffs:
{
difflist
}
"
)
if
__name__
==
"__main__"
:
unittest
.
main
()
test/srt/run_suite.py
View file @
eae9a9fb
...
@@ -33,6 +33,7 @@ suites = {
...
@@ -33,6 +33,7 @@ suites = {
TestFile
(
"models/test_generation_models.py"
,
103
),
TestFile
(
"models/test_generation_models.py"
,
103
),
TestFile
(
"models/test_nvidia_nemotron_nano_v2.py"
,
180
),
TestFile
(
"models/test_nvidia_nemotron_nano_v2.py"
,
180
),
TestFile
(
"models/test_qwen_models.py"
,
82
),
TestFile
(
"models/test_qwen_models.py"
,
82
),
TestFile
(
"batch_invariant/test_batch_invariant_ops.py"
,
10
),
TestFile
(
"models/test_reward_models.py"
,
132
),
TestFile
(
"models/test_reward_models.py"
,
132
),
TestFile
(
"models/test_transformers_models.py"
,
320
),
TestFile
(
"models/test_transformers_models.py"
,
320
),
TestFile
(
"models/test_vlm_models.py"
,
741
),
TestFile
(
"models/test_vlm_models.py"
,
741
),
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment