Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
change
sglang
Commits
0103f374
Unverified
Commit
0103f374
authored
Oct 26, 2025
by
fzyzcjy
Committed by
GitHub
Oct 26, 2025
Browse files
Support DeepGEMM for deterministic inference (#12142)
parent
96a5a949
Changes
3
Hide whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
82 additions
and
1 deletion
+82
-1
python/sglang/srt/batch_invariant_ops/batch_invariant_ops.py
python/sglang/srt/batch_invariant_ops/batch_invariant_ops.py
+65
-1
python/sglang/srt/utils/common.py
python/sglang/srt/utils/common.py
+8
-0
test/srt/batch_invariant/test_batch_invariant_ops.py
test/srt/batch_invariant/test_batch_invariant_ops.py
+9
-0
No files found.
python/sglang/srt/batch_invariant_ops/batch_invariant_ops.py
View file @
0103f374
...
...
@@ -9,6 +9,22 @@ import torch
import
triton
import
triton.language
as
tl
from
sglang.srt.layers.deep_gemm_wrapper.configurer
import
ENABLE_JIT_DEEPGEMM
from
sglang.srt.utils.common
import
calc_diff
,
get_bool_env_var
if
ENABLE_JIT_DEEPGEMM
:
import
deep_gemm
_ENABLE_MM_DEEPGEMM
=
get_bool_env_var
(
"SGLANG_BATCH_INVARIANT_OPS_ENABLE_MM_DEEPGEMM"
,
"1"
)
_ENABLE_MM_COMPARISON_TEST
=
get_bool_env_var
(
"SGLANG_BATCH_INVARIANT_OPS_ENABLE_MM_COMPARISON_TEST"
)
if
not
_ENABLE_MM_DEEPGEMM
:
print
(
"Disable DeepGEMM in batch invariant ops. Performance may be suboptimal."
)
__all__
=
[
"set_batch_invariant_mode"
,
"is_batch_invariant_mode_enabled"
,
...
...
@@ -140,7 +156,7 @@ def matmul_kernel_persistent(
tl
.
store
(
c_ptrs
,
c
,
mask
=
c_mask
)
def
matmul_persistent
(
def
_
matmul_persistent
_triton
(
a
:
torch
.
Tensor
,
b
:
torch
.
Tensor
,
bias
:
torch
.
Tensor
|
None
=
None
):
# Check constraints.
...
...
@@ -217,6 +233,54 @@ def matmul_persistent(
return
c
def _matmul_persistent_deepgemm(
    a: torch.Tensor, b: torch.Tensor, bias: torch.Tensor | None = None
):
    """Compute ``a @ b`` (plus optional ``bias``) using DeepGEMM's BF16 NN kernel.

    Args:
        a: Left operand of shape (M, K); dtype/device determine the output's.
        b: Right operand of shape (K, N).
        bias: Optional tensor broadcastable to (M, N), added to the product.

    Returns:
        A new (M, N) tensor on ``a``'s device with ``a``'s dtype.
    """
    M, K = a.shape
    # NOTE: use a distinct name for b's inner dim instead of silently
    # rebinding K, so a shape mismatch fails loudly here rather than
    # surfacing as an opaque kernel error (or worse, corrupt output).
    K_b, N = b.shape
    if K != K_b:
        raise ValueError(f"Incompatible matmul shapes: {a.shape=} {b.shape=}")
    out = torch.empty((M, N), device=a.device, dtype=a.dtype)
    deep_gemm.bf16_gemm_nn(a, b, out)
    # TODO can this be put in DeepGEMM's `c`?
    if bias is not None:
        out += bias
    return out
def matmul_persistent(
    a: torch.Tensor, b: torch.Tensor, bias: torch.Tensor | None = None
):
    """Batch-invariant matmul entry point.

    Dispatches to the DeepGEMM backend when it is enabled and the inputs are
    BF16 with the layout DeepGEMM requires (row-major ``a``, column-major
    ``b``); otherwise falls back to the persistent Triton kernel.
    """
    deepgemm_eligible = (
        _ENABLE_MM_DEEPGEMM
        and ENABLE_JIT_DEEPGEMM
        and a.dtype == torch.bfloat16
        and b.dtype == torch.bfloat16
        and a.is_contiguous()
        and b.transpose(0, 1).is_contiguous()
    )
    if not deepgemm_eligible:
        return _matmul_persistent_triton(a=a, b=b, bias=bias)

    if not _ENABLE_MM_COMPARISON_TEST:
        return _matmul_persistent_deepgemm(a=a, b=b, bias=bias)

    # Comparison mode: run both backends and check they agree closely.
    out_triton = _matmul_persistent_triton(a=a, b=b, bias=bias)
    out_deepgemm = _matmul_persistent_deepgemm(a=a, b=b, bias=bias)
    diff = calc_diff(out_triton, out_deepgemm)
    assert diff < 0.0001, f"{diff=} {out_triton=} {out_deepgemm=}"
    # can be enabled for debugging
    # print(
    #     f"{diff=} "
    #     f"{(out_triton - out_deepgemm).abs().mean()=} "
    #     f"{(out_triton - out_deepgemm).abs().sum()=} "
    #     f"{torch.sum(out_triton != out_deepgemm)=} "
    # )
    # print(f"{a=} {b=} {bias=} {out_triton=} {out_deepgemm=}")
    return out_deepgemm
@
triton
.
jit
def
_log_softmax_kernel
(
input_ptr
,
...
...
python/sglang/srt/utils/common.py
View file @
0103f374
...
...
@@ -3565,3 +3565,11 @@ def cached_triton_kernel(key_fn=None):
return
CachedKernel
(
fn
,
key_fn
)
return
decorator
# Copy from: https://github.com/deepseek-ai/DeepGEMM/blob/main/deep_gemm/utils.py
def calc_diff(x, y):
    """Return a similarity-based difference between two tensors.

    Computes ``1 - 2*sum(x*y) / (sum(x*x) + sum(y*y))`` in float64: 0 for
    identical tensors, approaching 1 as they diverge.
    """
    xd = x.double()
    yd = y.double()
    cross = (xd * yd).sum()
    energy = (xd * xd).sum() + (yd * yd).sum()
    return 1 - 2 * cross / energy
test/srt/batch_invariant/test_batch_invariant_ops.py
View file @
0103f374
...
...
@@ -4,6 +4,7 @@ import unittest
import
torch
from
sglang.srt.batch_invariant_ops
import
batch_invariant_ops
from
sglang.srt.batch_invariant_ops.batch_invariant_ops
import
set_batch_invariant_mode
from
sglang.test.test_utils
import
CustomTestCase
...
...
@@ -16,6 +17,14 @@ with set_batch_invariant_mode(True):
class
TestBatchInvariantOps
(
CustomTestCase
):
@
classmethod
def
setUpClass
(
cls
):
batch_invariant_ops
.
_ENABLE_MM_COMPARISON_TEST
=
True
@
classmethod
def
tearDownClass
(
cls
):
batch_invariant_ops
.
_ENABLE_MM_COMPARISON_TEST
=
False
def
_test_batch_invariance
(
self
,
M
,
K
,
N
,
dtype
):
"""
Test that matrix operations produce identical results for:
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment