Unverified Commit 0103f374 authored by fzyzcjy's avatar fzyzcjy Committed by GitHub
Browse files

Support DeepGEMM for deterministic inference (#12142)

parent 96a5a949
...@@ -9,6 +9,22 @@ import torch ...@@ -9,6 +9,22 @@ import torch
import triton import triton
import triton.language as tl import triton.language as tl
from sglang.srt.layers.deep_gemm_wrapper.configurer import ENABLE_JIT_DEEPGEMM
from sglang.srt.utils.common import calc_diff, get_bool_env_var
# DeepGEMM is only importable when the JIT path is compiled in; guard the
# import so this module still loads on builds without it.
if ENABLE_JIT_DEEPGEMM:
    import deep_gemm

# Use DeepGEMM for batch-invariant matmuls (default: on). Falling back to the
# Triton kernel keeps results batch-invariant but may be slower.
_ENABLE_MM_DEEPGEMM = get_bool_env_var(
    "SGLANG_BATCH_INVARIANT_OPS_ENABLE_MM_DEEPGEMM", "1"
)
# When set, run BOTH the Triton and DeepGEMM kernels on every call and assert
# their outputs agree — debugging/testing aid only (default: off).
_ENABLE_MM_COMPARISON_TEST = get_bool_env_var(
    "SGLANG_BATCH_INVARIANT_OPS_ENABLE_MM_COMPARISON_TEST"
)

if not _ENABLE_MM_DEEPGEMM:
    print("Disable DeepGEMM in batch invariant ops. Performance may be suboptimal.")
__all__ = [ __all__ = [
"set_batch_invariant_mode", "set_batch_invariant_mode",
"is_batch_invariant_mode_enabled", "is_batch_invariant_mode_enabled",
...@@ -140,7 +156,7 @@ def matmul_kernel_persistent( ...@@ -140,7 +156,7 @@ def matmul_kernel_persistent(
tl.store(c_ptrs, c, mask=c_mask) tl.store(c_ptrs, c, mask=c_mask)
def matmul_persistent( def _matmul_persistent_triton(
a: torch.Tensor, b: torch.Tensor, bias: torch.Tensor | None = None a: torch.Tensor, b: torch.Tensor, bias: torch.Tensor | None = None
): ):
# Check constraints. # Check constraints.
...@@ -217,6 +233,54 @@ def matmul_persistent( ...@@ -217,6 +233,54 @@ def matmul_persistent(
return c return c
def _matmul_persistent_deepgemm(
    a: torch.Tensor, b: torch.Tensor, bias: torch.Tensor | None = None
):
    """Compute ``a @ b`` (plus optional ``bias``) with DeepGEMM's BF16 NN kernel.

    The caller (``matmul_persistent``) is responsible for checking dtype and
    contiguity eligibility; here we only validate shape compatibility, which
    the original code silently skipped by shadowing ``K``.

    Args:
        a: Left operand of shape ``(M, K)``.
        b: Right operand of shape ``(K, N)``.
        bias: Optional additive bias, broadcastable to ``(M, N)``.

    Returns:
        A new tensor of shape ``(M, N)`` with ``a``'s dtype and device.
    """
    M, K = a.shape
    K2, N = b.shape
    if K != K2:
        raise ValueError(f"Incompatible matmul shapes: {a.shape=} {b.shape=}")
    out = torch.empty((M, N), device=a.device, dtype=a.dtype)
    deep_gemm.bf16_gemm_nn(a, b, out)
    # TODO can this be put in DeepGEMM's `c`?
    if bias is not None:
        out += bias
    return out
def matmul_persistent(
    a: torch.Tensor, b: torch.Tensor, bias: torch.Tensor | None = None
):
    """Persistent matmul entry point: DeepGEMM when eligible, else Triton.

    DeepGEMM is used only when it is enabled (env flag + JIT build), both
    operands are bfloat16, ``a`` is row-contiguous, and ``b`` is
    column-contiguous (i.e. its transpose is contiguous). With the comparison
    flag set, both kernels run and their outputs are asserted to agree.
    """
    eligible_for_deepgemm = (
        _ENABLE_MM_DEEPGEMM
        and ENABLE_JIT_DEEPGEMM
        and a.dtype == torch.bfloat16
        and b.dtype == torch.bfloat16
        and a.is_contiguous()
        and b.transpose(0, 1).is_contiguous()
    )
    if not eligible_for_deepgemm:
        return _matmul_persistent_triton(a=a, b=b, bias=bias)

    out_deepgemm = _matmul_persistent_deepgemm(a=a, b=b, bias=bias)
    if _ENABLE_MM_COMPARISON_TEST:
        out_triton = _matmul_persistent_triton(a=a, b=b, bias=bias)
        diff = calc_diff(out_triton, out_deepgemm)
        assert diff < 0.0001, f"{diff=} {out_triton=} {out_deepgemm=}"
        # For deeper debugging, inspect elementwise deltas here, e.g.
        # (out_triton - out_deepgemm).abs().mean() / .sum(), or dump
        # a, b, bias and both outputs.
    return out_deepgemm
@triton.jit @triton.jit
def _log_softmax_kernel( def _log_softmax_kernel(
input_ptr, input_ptr,
......
...@@ -3565,3 +3565,11 @@ def cached_triton_kernel(key_fn=None): ...@@ -3565,3 +3565,11 @@ def cached_triton_kernel(key_fn=None):
return CachedKernel(fn, key_fn) return CachedKernel(fn, key_fn)
return decorator return decorator
# Copy from: https://github.com/deepseek-ai/DeepGEMM/blob/main/deep_gemm/utils.py
def calc_diff(x, y):
    """Return a similarity-based difference between two tensors.

    Computes ``1 - 2*<x, y> / (|x|^2 + |y|^2)`` in float64: 0 for identical
    tensors, growing toward 1 as they diverge.
    """
    xd = x.double()
    yd = y.double()
    numerator = 2 * (xd * yd).sum()
    denominator = (xd * xd).sum() + (yd * yd).sum()
    return 1 - numerator / denominator
...@@ -4,6 +4,7 @@ import unittest ...@@ -4,6 +4,7 @@ import unittest
import torch import torch
from sglang.srt.batch_invariant_ops import batch_invariant_ops
from sglang.srt.batch_invariant_ops.batch_invariant_ops import set_batch_invariant_mode from sglang.srt.batch_invariant_ops.batch_invariant_ops import set_batch_invariant_mode
from sglang.test.test_utils import CustomTestCase from sglang.test.test_utils import CustomTestCase
...@@ -16,6 +17,14 @@ with set_batch_invariant_mode(True): ...@@ -16,6 +17,14 @@ with set_batch_invariant_mode(True):
class TestBatchInvariantOps(CustomTestCase): class TestBatchInvariantOps(CustomTestCase):
@classmethod
def setUpClass(cls):
    # Force the DeepGEMM-vs-Triton comparison path for every matmul in this
    # test class, so divergence between the two kernels fails the tests.
    batch_invariant_ops._ENABLE_MM_COMPARISON_TEST = True
@classmethod
def tearDownClass(cls):
    # Restore the module-level flag so other test classes run without the
    # (expensive) dual-kernel comparison.
    batch_invariant_ops._ENABLE_MM_COMPARISON_TEST = False
def _test_batch_invariance(self, M, K, N, dtype): def _test_batch_invariance(self, M, K, N, dtype):
""" """
Test that matrix operations produce identical results for: Test that matrix operations produce identical results for:
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment