sglang · Commits · c14cc47e

Commit c14cc47e (Unverified)
[Deterministic] Optimize bmm_batch_invariant op (#12522)

Authored Nov 04, 2025 by Minglei Zhu; committed by GitHub on Nov 04, 2025.
Parent: dbcf85b7
Changes: 2 changed files with 292 additions and 10 deletions (+292 −10)

    python/sglang/srt/batch_invariant_ops/batch_invariant_ops.py   +206 −10
    test/srt/batch_invariant/test_batch_invariant_ops.py            +86 −0
python/sglang/srt/batch_invariant_ops/batch_invariant_ops.py

@@ -559,19 +559,215 @@ def mean_batch_invariant(input, dim, keepdim=False, dtype: torch.dtype | None =
Context (unchanged):

    return torch.sum(input, dim=dim, keepdim=keepdim, dtype=torch.float32) / n_elems
Added:

@triton.jit
def bmm_kernel_persistent(
    a_ptr, b_ptr, c_ptr,  #
    B, M, N, K,  #
    stride_ab, stride_am, stride_ak,
    stride_bb, stride_bk, stride_bn,
    stride_cb, stride_cm, stride_cn,
    BLOCK_SIZE_M: tl.constexpr,  #
    BLOCK_SIZE_N: tl.constexpr,  #
    BLOCK_SIZE_K: tl.constexpr,  #
    GROUP_SIZE_M: tl.constexpr,  #
    NUM_SMS: tl.constexpr,  #
    A_LARGE: tl.constexpr,
    B_LARGE: tl.constexpr,
    C_LARGE: tl.constexpr,
):
    """
    Batched matrix multiplication kernel that processes batches in parallel.
    Each tile processes a (BLOCK_SIZE_M, BLOCK_SIZE_N) output block for a specific batch.
    """
    start_pid = tl.program_id(axis=0)
    num_pid_m = tl.cdiv(M, BLOCK_SIZE_M)
    num_pid_n = tl.cdiv(N, BLOCK_SIZE_N)
    k_tiles = tl.cdiv(K, BLOCK_SIZE_K)
    num_tiles_per_batch = num_pid_m * num_pid_n
    num_tiles_total = B * num_tiles_per_batch

    offs_k_for_mask = tl.arange(0, BLOCK_SIZE_K)
    num_pid_in_group = GROUP_SIZE_M * num_pid_n

    # Process tiles in a deterministic order: batch-major ordering
    for tile_id in tl.range(start_pid, num_tiles_total, NUM_SMS, flatten=True):
        # Decompose tile_id into batch and within-batch tile
        batch_idx = tile_id // num_tiles_per_batch
        tile_in_batch = tile_id % num_tiles_per_batch

        pid_m, pid_n = _compute_pid(
            tile_in_batch, num_pid_in_group, num_pid_m, GROUP_SIZE_M, NUM_SMS
        )
        start_m = pid_m * BLOCK_SIZE_M
        start_n = pid_n * BLOCK_SIZE_N
        offs_am = start_m + tl.arange(0, BLOCK_SIZE_M)
        offs_bn = start_n + tl.arange(0, BLOCK_SIZE_N)
        if A_LARGE:
            offs_am = offs_am.to(tl.int64)
        if B_LARGE:
            offs_bn = offs_bn.to(tl.int64)
        offs_am = tl.where(offs_am < M, offs_am, 0)
        offs_bn = tl.where(offs_bn < N, offs_bn, 0)
        offs_am = tl.max_contiguous(tl.multiple_of(offs_am, BLOCK_SIZE_M), BLOCK_SIZE_M)
        offs_bn = tl.max_contiguous(tl.multiple_of(offs_bn, BLOCK_SIZE_N), BLOCK_SIZE_N)

        # Add batch offset
        if A_LARGE or B_LARGE:
            batch_idx_typed = batch_idx.to(tl.int64)
        else:
            batch_idx_typed = batch_idx

        accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)
        for ki in range(k_tiles):
            if A_LARGE or B_LARGE:
                offs_k = ki * BLOCK_SIZE_K + tl.arange(0, BLOCK_SIZE_K).to(tl.int64)
            else:
                offs_k = ki * BLOCK_SIZE_K + tl.arange(0, BLOCK_SIZE_K)
            a_ptrs = a_ptr + (
                batch_idx_typed * stride_ab
                + offs_am[:, None] * stride_am
                + offs_k[None, :] * stride_ak
            )
            b_ptrs = b_ptr + (
                batch_idx_typed * stride_bb
                + offs_k[:, None] * stride_bk
                + offs_bn[None, :] * stride_bn
            )

            a = tl.load(
                a_ptrs, mask=offs_k_for_mask[None, :] < K - ki * BLOCK_SIZE_K, other=0.0
            )
            b = tl.load(
                b_ptrs, mask=offs_k_for_mask[:, None] < K - ki * BLOCK_SIZE_K, other=0.0
            )
            accumulator = tl.dot(a, b, accumulator)

        offs_cm = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
        offs_cn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
        if C_LARGE:
            offs_cm = offs_cm.to(tl.int64)
            offs_cn = offs_cn.to(tl.int64)
        c_ptrs = (
            c_ptr
            + batch_idx_typed * stride_cb
            + stride_cm * offs_cm[:, None]
            + stride_cn * offs_cn[None, :]
        )
        c_mask = (offs_cm[:, None] < M) & (offs_cn[None, :] < N)
        if c_ptr.dtype.element_ty == tl.float8e4nv:
            c = accumulator.to(tl.float8e4nv)
        elif c_ptr.dtype.element_ty == tl.bfloat16:
            c = accumulator.to(tl.bfloat16)
        elif c_ptr.dtype.element_ty == tl.float32:
            c = accumulator.to(tl.float32)
        else:
            c = accumulator.to(tl.float16)
        tl.store(c_ptrs, c, mask=c_mask)
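The kernel delegates within-batch tile ordering to _compute_pid, a helper that already exists in this file and is not part of this diff. Its signature matches the grouped-ordering helper from Triton's persistent-matmul tutorial, so a reasonable sketch of what it computes (an assumption for illustration, not the verbatim sglang code) is:

import triton
import triton.language as tl

@triton.jit
def _compute_pid(tile_id, num_pid_in_group, num_pid_m, GROUP_SIZE_M, NUM_SMS):
    # Grouped ordering (assumed, per Triton's persistent-matmul tutorial):
    # walk GROUP_SIZE_M rows of tiles at a time so consecutive tiles reuse
    # nearby B columns, improving L2 locality without changing results.
    group_id = tile_id // num_pid_in_group
    first_pid_m = group_id * GROUP_SIZE_M
    group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M)
    pid_m = first_pid_m + (tile_id % group_size_m)
    pid_n = (tile_id % num_pid_in_group) // group_size_m
    return pid_m, pid_n

Because the mapping from tile_id to (pid_m, pid_n) is a pure function, the set of output blocks and the K-reduction order inside each block are fixed regardless of which SM picks up which tile.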
Removed (old implementation, one matmul_persistent call per batch):

def bmm_batch_invariant(a, b, *, out=None):
    # Batched matrix multiply: (B, M, K) x (B, K, N) -> (B, M, N)
    # Process each batch separately with our persistent kernel
    if a.ndim == 3 and b.ndim == 3:
        results = []
        for i in range(a.shape[0]):
            results.append(matmul_persistent(a[i], b[i]))
        result = torch.stack(results, dim=0)
        if out is not None:
            out.copy_(result)
            return out
        return result
    else:
        raise ValueError(
            f"bmm_batch_invariant expects 3D tensors, "
            ...

Added (new implementation, a single persistent-kernel launch):

def bmm_batch_invariant(a, b, *, out=None):
    # Batched matrix multiply: (B, M, K) x (B, K, N) -> (B, M, N)
    # Process batches in parallel with our persistent kernel
    if a.ndim == 3 and b.ndim == 3:
        # Check constraints
        assert a.shape[0] == b.shape[0], "Batch sizes must match"
        assert a.shape[2] == b.shape[1], "Incompatible dimensions"
        assert a.dtype == b.dtype, "Incompatible dtypes"

        B = a.shape[0]
        M = a.shape[1]
        K = a.shape[2]
        N = b.shape[2]
        dtype = a.dtype

        # Allocate output
        if out is None:
            c = torch.empty((B, M, N), device=a.device, dtype=dtype)
        else:
            c = out

        NUM_SMS = torch.cuda.get_device_properties("cuda").multi_processor_count

        # Use fixed kernel configuration for determinism
        configs = {
            torch.bfloat16: {
                "BLOCK_SIZE_M": 128,
                "BLOCK_SIZE_N": 128,
                "BLOCK_SIZE_K": 64,
                "GROUP_SIZE_M": 8,
                "num_stages": 3,
                "num_warps": 8,
            },
            torch.float16: {
                "BLOCK_SIZE_M": 128,
                "BLOCK_SIZE_N": 256,
                "BLOCK_SIZE_K": 64,
                "GROUP_SIZE_M": 8,
                "num_stages": 3,
                "num_warps": 8,
            },
            torch.float32: {
                "BLOCK_SIZE_M": 128,
                "BLOCK_SIZE_N": 128,
                "BLOCK_SIZE_K": 32,
                "GROUP_SIZE_M": 8,
                "num_stages": 3,
                "num_warps": 8,
            },
        }
        config = configs.get(dtype)
        if config is None:
            raise ValueError(
                f"Unsupported dtype {dtype} for bmm_batch_invariant. "
                f"Supported dtypes are: {list(configs.keys())}"
            )

        # Grid: limit by NUM_SMS for persistent kernel approach
        num_tiles_per_batch = triton.cdiv(M, config["BLOCK_SIZE_M"]) * triton.cdiv(
            N, config["BLOCK_SIZE_N"]
        )
        num_tiles_total = B * num_tiles_per_batch
        grid = (min(NUM_SMS, num_tiles_total),)

        bmm_kernel_persistent[grid](
            a, b, c,  #
            B, M, N, K,  #
            a.stride(0), a.stride(1), a.stride(2),  #
            b.stride(0), b.stride(1), b.stride(2),  #
            c.stride(0), c.stride(1), c.stride(2),  #
            NUM_SMS=NUM_SMS,  #
            A_LARGE=a.numel() > 2**31,
            B_LARGE=b.numel() > 2**31,
            C_LARGE=c.numel() > 2**31,
            **config,
        )
        return c
    else:
        raise ValueError(
            f"bmm_batch_invariant expects 3D tensors, "
            ...
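The practical effect of the rewrite: the old path launched matmul_persistent once per batch from Python and stacked the results, while the new path issues one persistent-kernel launch covering every (batch, tile) pair. Because the block sizes, the K-loop order, and the dtype-keyed config are all fixed, the values computed for a given batch do not depend on how many batches are in the input. A minimal sketch of calling the op directly (hypothetical shapes; assumes a CUDA device and that the import path mirrors the file path above):

import torch
from sglang.srt.batch_invariant_ops.batch_invariant_ops import bmm_batch_invariant

a = torch.randn(8, 32, 128, device="cuda", dtype=torch.bfloat16)
b = torch.randn(8, 128, 256, device="cuda", dtype=torch.bfloat16)

full = bmm_batch_invariant(a, b)         # all 8 batches in one launch
sub = bmm_batch_invariant(a[:2], b[:2])  # only the first 2 batches

# Batch invariance is a bitwise claim, not merely an allclose claim:
assert torch.equal(full[:2], sub)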
test/srt/batch_invariant/test_batch_invariant_ops.py

@@ -167,6 +167,92 @@ class TestBatchInvariantOps(CustomTestCase):
Context (unchanged):

            )
        print(f"Without batch-invariant mode, we get diffs: {difflist}")
Added:

    def _test_bmm_batch_invariance(self, B, M, K, N, dtype):
        """
        Test that BMM operations produce identical results for:
        - Method 1: BMM with subset of batches
        - Method 2: BMM with all batches, then slice
        """
        a = torch.linspace(-100, 100, B * M * K, dtype=dtype).reshape(B, M, K)
        b = torch.linspace(-100, 100, B * K * N, dtype=dtype).reshape(B, K, N)

        # Method 1: BMM with subset (first 2 batches)
        subset_size = min(2, B)
        out1 = torch.bmm(a[:subset_size], b[:subset_size])

        # Method 2: BMM with all batches, then slice
        out2_pre = torch.bmm(a, b)
        out2 = out2_pre[:subset_size]

        # Check if results are identical
        diff = (out1 - out2).abs().max()
        return diff.item()

    def _run_bmm_multiple_iterations(self, iters, B, M, K, N, dtype):
        """Run multiple BMM iterations and collect diff statistics"""
        difflist = []
        for _ in range(iters):
            diff = self._test_bmm_batch_invariance(B, M, K, N, dtype)
            difflist.append(diff)
        return difflist

    def test_bmm_small_matrices(self):
        """Test BMM batch invariance with small matrix sizes"""
        test_cases = [
            ("BMM-Small-1", 4, 8, 64, 128),
            ("BMM-Small-2", 8, 16, 128, 256),
            ("BMM-Small-3", 6, 4, 32, 64),
        ]
        for name, B, M, K, N in test_cases:
            with self.subTest(name=name, B=B, M=M, K=K, N=N):
                for dtype in [torch.float32, torch.bfloat16]:
                    with self.subTest(dtype=dtype):
                        # Run with batch-invariant mode
                        with set_batch_invariant_mode(True):
                            difflist = self._run_bmm_multiple_iterations(
                                iters=5, B=B, M=M, K=K, N=N, dtype=dtype
                            )
                        self._assert_batch_invariant_results(difflist, dtype, name)

    def test_bmm_medium_matrices(self):
        """Test BMM batch invariance with medium matrix sizes"""
        test_cases = [
            ("BMM-Medium-1", 8, 32, 128, 1024),
            ("BMM-Medium-2", 16, 64, 512, 2048),
            ("BMM-Medium-3", 12, 24, 192, 768),
        ]
        for name, B, M, K, N in test_cases:
            with self.subTest(name=name, B=B, M=M, K=K, N=N):
                for dtype in [torch.float32, torch.bfloat16]:
                    with self.subTest(dtype=dtype):
                        # Run with batch-invariant mode
                        with set_batch_invariant_mode(True):
                            difflist = self._run_bmm_multiple_iterations(
                                iters=5, B=B, M=M, K=K, N=N, dtype=dtype
                            )
                        self._assert_batch_invariant_results(difflist, dtype, name)

    def test_bmm_large_matrices(self):
        """Test BMM batch invariance with large matrix sizes"""
        test_cases = [
            ("BMM-Large-1", 16, 128, 1024, 4096),
            ("BMM-Large-2", 32, 256, 2048, 8192),
            ("BMM-Large-3", 24, 96, 768, 3072),
        ]
        for name, B, M, K, N in test_cases:
            with self.subTest(name=name, B=B, M=M, K=K, N=N):
                for dtype in [torch.float32, torch.bfloat16]:
                    with self.subTest(dtype=dtype):
                        # Run with batch-invariant mode
                        with set_batch_invariant_mode(True):
                            difflist = self._run_bmm_multiple_iterations(
                                iters=5, B=B, M=M, K=K, N=N, dtype=dtype
                            )
                        self._assert_batch_invariant_results(difflist, dtype, name)

Context (unchanged):

if __name__ == "__main__":
    unittest.main()
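Note that the new tests call plain torch.bmm inside the set_batch_invariant_mode(True) context; that mode, provided by the ops module, is what reroutes torch.bmm onto bmm_batch_invariant. A minimal sketch of the same flow outside the test harness (hypothetical shapes; assumes a CUDA device and that set_batch_invariant_mode is importable from the same module as the op):

import torch
from sglang.srt.batch_invariant_ops.batch_invariant_ops import set_batch_invariant_mode

a = torch.randn(4, 8, 64, device="cuda", dtype=torch.float32)
b = torch.randn(4, 64, 128, device="cuda", dtype=torch.float32)

with set_batch_invariant_mode(True):
    out = torch.bmm(a, b)  # routed to the persistent Triton kernel

# Outside the context manager, torch.bmm uses the default backend again.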