Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
change
sglang
Commits
eae9a9fb
"docs/code-docs/vscode:/vscode.git/clone" did not exist on "6380ee35116dfc7c2a037d48fbd790d8dbabfc1d"
Unverified
Commit
eae9a9fb
authored
Oct 10, 2025
by
Stefan He
Committed by
GitHub
Oct 10, 2025
Browse files
Fix batch invariant ops (#11368)
parent
2674c1d2
Changes
3
Show whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
168 additions
and
6 deletions
+168
-6
python/sglang/srt/batch_invariant_ops/batch_invariant_ops.py
python/sglang/srt/batch_invariant_ops/batch_invariant_ops.py
+4
-6
test/srt/batch_invariant/test_batch_invariant_ops.py
test/srt/batch_invariant/test_batch_invariant_ops.py
+163
-0
test/srt/run_suite.py
test/srt/run_suite.py
+1
-0
No files found.
python/sglang/srt/batch_invariant_ops/batch_invariant_ops.py
View file @
eae9a9fb
...
@@ -77,8 +77,6 @@ def matmul_kernel_persistent(
...
@@ -77,8 +77,6 @@ def matmul_kernel_persistent(
k_tiles
=
tl
.
cdiv
(
K
,
BLOCK_SIZE_K
)
k_tiles
=
tl
.
cdiv
(
K
,
BLOCK_SIZE_K
)
num_tiles
=
num_pid_m
*
num_pid_n
num_tiles
=
num_pid_m
*
num_pid_n
tile_id_c
=
start_pid
-
NUM_SMS
offs_k_for_mask
=
tl
.
arange
(
0
,
BLOCK_SIZE_K
)
offs_k_for_mask
=
tl
.
arange
(
0
,
BLOCK_SIZE_K
)
num_pid_in_group
=
GROUP_SIZE_M
*
num_pid_n
num_pid_in_group
=
GROUP_SIZE_M
*
num_pid_n
...
@@ -120,10 +118,6 @@ def matmul_kernel_persistent(
...
@@ -120,10 +118,6 @@ def matmul_kernel_persistent(
)
)
accumulator
=
tl
.
dot
(
a
,
b
,
accumulator
)
accumulator
=
tl
.
dot
(
a
,
b
,
accumulator
)
tile_id_c
+=
NUM_SMS
pid_m
,
pid_n
=
_compute_pid
(
tile_id_c
,
num_pid_in_group
,
num_pid_m
,
GROUP_SIZE_M
,
NUM_SMS
)
offs_cm
=
pid_m
*
BLOCK_SIZE_M
+
tl
.
arange
(
0
,
BLOCK_SIZE_M
)
offs_cm
=
pid_m
*
BLOCK_SIZE_M
+
tl
.
arange
(
0
,
BLOCK_SIZE_M
)
offs_cn
=
pid_n
*
BLOCK_SIZE_N
+
tl
.
arange
(
0
,
BLOCK_SIZE_N
)
offs_cn
=
pid_n
*
BLOCK_SIZE_N
+
tl
.
arange
(
0
,
BLOCK_SIZE_N
)
if
C_LARGE
:
if
C_LARGE
:
...
@@ -137,6 +131,10 @@ def matmul_kernel_persistent(
...
@@ -137,6 +131,10 @@ def matmul_kernel_persistent(
accumulator
+=
bias
accumulator
+=
bias
if
c_ptr
.
dtype
.
element_ty
==
tl
.
float8e4nv
:
if
c_ptr
.
dtype
.
element_ty
==
tl
.
float8e4nv
:
c
=
accumulator
.
to
(
tl
.
float8e4nv
)
c
=
accumulator
.
to
(
tl
.
float8e4nv
)
elif
c_ptr
.
dtype
.
element_ty
==
tl
.
bfloat16
:
c
=
accumulator
.
to
(
tl
.
bfloat16
)
elif
c_ptr
.
dtype
.
element_ty
==
tl
.
float32
:
c
=
accumulator
.
to
(
tl
.
float32
)
else
:
else
:
c
=
accumulator
.
to
(
tl
.
float16
)
c
=
accumulator
.
to
(
tl
.
float16
)
tl
.
store
(
c_ptrs
,
c
,
mask
=
c_mask
)
tl
.
store
(
c_ptrs
,
c
,
mask
=
c_mask
)
...
...
test/srt/batch_invariant/test_batch_invariant_ops.py
0 → 100644
View file @
eae9a9fb
# Adapted from https://github.com/thinking-machines-lab/batch_invariant_ops/blob/main/test_batch_invariance.py
import
math
import
unittest
import
torch
from
sglang.srt.batch_invariant_ops.batch_invariant_ops
import
set_batch_invariant_mode
from
sglang.test.test_utils
import
CustomTestCase
# Resolve the device to run on: the current accelerator's type if one is
# available, otherwise fall back to CPU.
_accelerator = torch.accelerator.current_accelerator()
device_type = getattr(_accelerator, "type", "cpu")
torch.set_default_device(device_type)

# Enter and immediately exit batch-invariant mode once at import time so that
# its one-time logging happens here rather than inside the first test.
with set_batch_invariant_mode(True):
    pass
class TestBatchInvariantOps(CustomTestCase):
    """Verify that matmul results are bitwise identical across batch sizes
    when batch-invariant mode is enabled.

    Each size bucket (small / medium / large) runs the same check: multiply a
    single row of ``a`` by ``b``, multiply all of ``a`` by ``b`` and slice out
    the first row, and require the two results to match exactly.
    """

    # dtypes exercised by every size bucket
    _DTYPES = (torch.float32, torch.bfloat16)

    def _test_batch_invariance(self, M, K, N, dtype):
        """
        Test that matrix operations produce identical results for:
        - Method 1: Matrix-vector multiplication (batch size 1)
        - Method 2: Matrix-matrix multiplication, then slice (full batch)

        Returns the max absolute elementwise difference between the two.
        """
        a = torch.linspace(-100, 100, M * K, dtype=dtype).reshape(M, K)
        # Create non-contiguous tensor
        b = torch.linspace(-100, 100, K * N, dtype=dtype).reshape(N, K)
        b = b.transpose(0, 1)

        # Method 1: Matrix-vector multiplication (batch size 1)
        out1 = torch.mm(a[:1], b)

        # Method 2: Matrix-matrix multiplication, then slice (full batch)
        out2_pre = torch.mm(a, b)
        out2 = out2_pre[:1]

        # Check if results are identical
        diff = (out1 - out2).abs().max()
        return diff.item()

    def _run_multiple_iterations(self, iters, M, K, N, dtype):
        """Run multiple iterations and collect diff statistics."""
        return [self._test_batch_invariance(M, K, N, dtype) for _ in range(iters)]

    def _assert_batch_invariant_results(self, difflist, dtype, test_name):
        """
        Assert that in batch-invariant mode:
        1. All diffs must not be NaN
        2. All diffs must be exactly 0
        3. Max, min, and diff of diffs must all be 0
        """
        max_diff = max(difflist)
        min_diff = min(difflist)
        diff_range = max_diff - min_diff

        # Check for NaN values, then that every statistic is exactly 0.
        for label, value in (
            ("max_diff", max_diff),
            ("min_diff", min_diff),
            ("diff_range", diff_range),
        ):
            self.assertFalse(
                math.isnan(value), f"{test_name}: {label} is NaN for {dtype}"
            )
            self.assertEqual(
                value,
                0.0,
                f"{test_name}: {label} must be 0 in batch-invariant mode, "
                f"got {value} for {dtype}",
            )

    def _check_cases(self, test_cases):
        """Run each (name, M, K, N) case under batch-invariant mode for every
        dtype in ``_DTYPES`` and assert bitwise-identical results.

        Shared driver for the small/medium/large size buckets so the subTest
        nesting and assertions live in exactly one place.
        """
        for name, M, K, N in test_cases:
            with self.subTest(name=name, M=M, K=K, N=N):
                for dtype in self._DTYPES:
                    with self.subTest(dtype=dtype):
                        # Run with batch-invariant mode
                        with set_batch_invariant_mode(True):
                            difflist = self._run_multiple_iterations(
                                iters=5, M=M, K=K, N=N, dtype=dtype
                            )
                        self._assert_batch_invariant_results(difflist, dtype, name)

    def test_small_matrices(self):
        """Test batch invariance with small matrix sizes"""
        self._check_cases(
            [
                ("Small-1", 8, 64, 128),
                ("Small-2", 16, 128, 256),
                ("Small-3", 4, 32, 64),
            ]
        )

    def test_medium_matrices(self):
        """Test batch invariance with medium matrix sizes"""
        self._check_cases(
            [
                ("Medium-1", 32, 128, 1024),
                ("Medium-2", 64, 512, 2048),
                ("Medium-3", 24, 192, 768),
            ]
        )

    def test_large_matrices(self):
        """Test batch invariance with large matrix sizes"""
        self._check_cases(
            [
                ("Large-1", 128, 1024, 4096),
                ("Large-2", 256, 2048, 8192),
                ("Large-3", 96, 768, 3072),
            ]
        )

    def test_without_batch_invariant_mode(self):
        """
        Test that without batch-invariant mode, results may differ.
        This test demonstrates the difference batch-invariant mode makes.
        """
        M, K, N = 32, 128, 1024
        dtype = torch.float32

        # Run without batch-invariant mode
        with set_batch_invariant_mode(False):
            difflist = self._run_multiple_iterations(
                iters=5, M=M, K=K, N=N, dtype=dtype
            )
        print(f"Without batch-invariant mode, we get diffs: {difflist}")
# Allow running this test file directly as a script.
if __name__ == "__main__":
    unittest.main()
test/srt/run_suite.py
View file @
eae9a9fb
...
@@ -33,6 +33,7 @@ suites = {
...
@@ -33,6 +33,7 @@ suites = {
TestFile
(
"models/test_generation_models.py"
,
103
),
TestFile
(
"models/test_generation_models.py"
,
103
),
TestFile
(
"models/test_nvidia_nemotron_nano_v2.py"
,
180
),
TestFile
(
"models/test_nvidia_nemotron_nano_v2.py"
,
180
),
TestFile
(
"models/test_qwen_models.py"
,
82
),
TestFile
(
"models/test_qwen_models.py"
,
82
),
TestFile
(
"batch_invariant/test_batch_invariant_ops.py"
,
10
),
TestFile
(
"models/test_reward_models.py"
,
132
),
TestFile
(
"models/test_reward_models.py"
,
132
),
TestFile
(
"models/test_transformers_models.py"
,
320
),
TestFile
(
"models/test_transformers_models.py"
,
320
),
TestFile
(
"models/test_vlm_models.py"
,
741
),
TestFile
(
"models/test_vlm_models.py"
,
741
),
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment