Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
zhaoyu6
sglang
Commits
0103f374
"vscode:/vscode.git/clone" did not exist on "9461b915c224f202cea47256a823eeb97a687af1"
Unverified
Commit
0103f374
authored
Oct 26, 2025
by
fzyzcjy
Committed by
GitHub
Oct 26, 2025
Browse files
Support DeepGEMM for deterministic inference (#12142)
parent
96a5a949
Changes
3
Hide whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
82 additions
and
1 deletion
+82
-1
python/sglang/srt/batch_invariant_ops/batch_invariant_ops.py
python/sglang/srt/batch_invariant_ops/batch_invariant_ops.py
+65
-1
python/sglang/srt/utils/common.py
python/sglang/srt/utils/common.py
+8
-0
test/srt/batch_invariant/test_batch_invariant_ops.py
test/srt/batch_invariant/test_batch_invariant_ops.py
+9
-0
No files found.
python/sglang/srt/batch_invariant_ops/batch_invariant_ops.py
View file @
0103f374
...
...
@@ -9,6 +9,22 @@ import torch
import
triton
import
triton.language
as
tl
from
sglang.srt.layers.deep_gemm_wrapper.configurer
import
ENABLE_JIT_DEEPGEMM
from
sglang.srt.utils.common
import
calc_diff
,
get_bool_env_var
if
ENABLE_JIT_DEEPGEMM
:
import
deep_gemm
_ENABLE_MM_DEEPGEMM
=
get_bool_env_var
(
"SGLANG_BATCH_INVARIANT_OPS_ENABLE_MM_DEEPGEMM"
,
"1"
)
_ENABLE_MM_COMPARISON_TEST
=
get_bool_env_var
(
"SGLANG_BATCH_INVARIANT_OPS_ENABLE_MM_COMPARISON_TEST"
)
if
not
_ENABLE_MM_DEEPGEMM
:
print
(
"Disable DeepGEMM in batch invariant ops. Performance may be suboptimal."
)
__all__
=
[
"set_batch_invariant_mode"
,
"is_batch_invariant_mode_enabled"
,
...
...
@@ -140,7 +156,7 @@ def matmul_kernel_persistent(
tl
.
store
(
c_ptrs
,
c
,
mask
=
c_mask
)
def
matmul_persistent
(
def
_
matmul_persistent
_triton
(
a
:
torch
.
Tensor
,
b
:
torch
.
Tensor
,
bias
:
torch
.
Tensor
|
None
=
None
):
# Check constraints.
...
...
@@ -217,6 +233,54 @@ def matmul_persistent(
return
c
def
_matmul_persistent_deepgemm
(
a
:
torch
.
Tensor
,
b
:
torch
.
Tensor
,
bias
:
torch
.
Tensor
|
None
=
None
):
M
,
K
=
a
.
shape
K
,
N
=
b
.
shape
dtype
=
a
.
dtype
out
=
torch
.
empty
((
M
,
N
),
device
=
a
.
device
,
dtype
=
dtype
)
deep_gemm
.
bf16_gemm_nn
(
a
,
b
,
out
)
# TODO can this be put in DeepGEMM's `c`?
if
bias
is
not
None
:
out
+=
bias
return
out
def
matmul_persistent
(
a
:
torch
.
Tensor
,
b
:
torch
.
Tensor
,
bias
:
torch
.
Tensor
|
None
=
None
):
if
(
_ENABLE_MM_DEEPGEMM
and
ENABLE_JIT_DEEPGEMM
and
(
a
.
dtype
==
torch
.
bfloat16
)
and
(
b
.
dtype
==
torch
.
bfloat16
)
and
a
.
is_contiguous
()
and
b
.
transpose
(
0
,
1
).
is_contiguous
()
):
if
_ENABLE_MM_COMPARISON_TEST
:
out_triton
=
_matmul_persistent_triton
(
a
=
a
,
b
=
b
,
bias
=
bias
)
out_deepgemm
=
_matmul_persistent_deepgemm
(
a
=
a
,
b
=
b
,
bias
=
bias
)
diff
=
calc_diff
(
out_triton
,
out_deepgemm
)
assert
diff
<
0.0001
,
f
"
{
diff
=
}
{
out_triton
=
}
{
out_deepgemm
=
}
"
# can be enabled for debugging
# print(
# f"{diff=} "
# f"{(out_triton - out_deepgemm).abs().mean()=} "
# f"{(out_triton - out_deepgemm).abs().sum()=} "
# f"{torch.sum(out_triton != out_deepgemm)=} "
# )
# print(f"{a=} {b=} {bias=} {out_triton=} {out_deepgemm=}")
return
out_deepgemm
return
_matmul_persistent_deepgemm
(
a
=
a
,
b
=
b
,
bias
=
bias
)
return
_matmul_persistent_triton
(
a
=
a
,
b
=
b
,
bias
=
bias
)
@
triton
.
jit
def
_log_softmax_kernel
(
input_ptr
,
...
...
python/sglang/srt/utils/common.py
View file @
0103f374
...
...
@@ -3565,3 +3565,11 @@ def cached_triton_kernel(key_fn=None):
return
CachedKernel
(
fn
,
key_fn
)
return
decorator
# Copy from: https://github.com/deepseek-ai/DeepGEMM/blob/main/deep_gemm/utils.py
def
calc_diff
(
x
,
y
):
x
,
y
=
x
.
double
(),
y
.
double
()
denominator
=
(
x
*
x
+
y
*
y
).
sum
()
sim
=
2
*
(
x
*
y
).
sum
()
/
denominator
return
1
-
sim
test/srt/batch_invariant/test_batch_invariant_ops.py
View file @
0103f374
...
...
@@ -4,6 +4,7 @@ import unittest
import
torch
from
sglang.srt.batch_invariant_ops
import
batch_invariant_ops
from
sglang.srt.batch_invariant_ops.batch_invariant_ops
import
set_batch_invariant_mode
from
sglang.test.test_utils
import
CustomTestCase
...
...
@@ -16,6 +17,14 @@ with set_batch_invariant_mode(True):
class
TestBatchInvariantOps
(
CustomTestCase
):
@
classmethod
def
setUpClass
(
cls
):
batch_invariant_ops
.
_ENABLE_MM_COMPARISON_TEST
=
True
@
classmethod
def
tearDownClass
(
cls
):
batch_invariant_ops
.
_ENABLE_MM_COMPARISON_TEST
=
False
def
_test_batch_invariance
(
self
,
M
,
K
,
N
,
dtype
):
"""
Test that matrix operations produce identical results for:
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment