Unverified Commit 85ef7f64 authored by AniZpZ's avatar AniZpZ Committed by GitHub
Browse files

[FIX] fix incorrect output when enable both deepgemm and torch compile (#4359)


Co-authored-by: default avatarxuyongfei.xyf <xuyongfei.xyf@antgroup.com>
parent f1cf6eef
......@@ -22,7 +22,14 @@ import torch
import triton
import triton.language as tl
from sglang.srt.utils import get_device_core_count, get_device_name, is_cuda, is_hip
from sglang.srt.utils import (
direct_register_custom_op,
get_device_core_count,
get_device_name,
is_cuda,
is_hip,
supports_custom_op,
)
# True when running on an AMD/ROCm (HIP) device.
_is_hip = is_hip()
# HIP hardware uses the "fnuz" FP8 encoding; NVIDIA CUDA uses standard e4m3fn.
fp8_type_ = torch.float8_e4m3fnuz if _is_hip else torch.float8_e4m3fn
......@@ -36,6 +43,33 @@ logger = logging.getLogger(__name__)
_enable_jit_deepgemm = int(os.getenv("SGL_ENABLE_JIT_DEEPGEMM", "0"))
if supports_custom_op():
    # Register deep_gemm's GEMM as a proper torch custom op so that
    # torch.compile can trace through it instead of treating the raw call
    # as an opaque side effect (see commit title: fixes incorrect output
    # when DeepGEMM and torch compile are enabled together).

    def deep_gemm_fp8_fp8_bf16_nt(
        A: torch.Tensor,
        As: torch.Tensor,
        B: torch.Tensor,
        Bs: torch.Tensor,
        C: torch.Tensor,
    ) -> None:
        """Real implementation: delegate to deep_gemm's block-scaled FP8 GEMM.

        A/B are the FP8 operands and As/Bs their scale tensors (paired as
        tuples for deep_gemm); the result is written into C in place.
        """
        deep_gemm.gemm_fp8_fp8_bf16_nt((A, As), (B, Bs), C)

    def deep_gemm_fp8_fp8_bf16_nt_fake(
        A: torch.Tensor,
        As: torch.Tensor,
        B: torch.Tensor,
        Bs: torch.Tensor,
        C: torch.Tensor,
    ) -> None:
        """Fake (meta) implementation used during tracing: intentionally a no-op.

        Since the real op only mutates C and returns None, the fake impl
        needs no output computation.
        """
        return

    # Expose the pair as torch.ops.sglang.deep_gemm_fp8_fp8_bf16_nt.
    # mutates_args=["C"] tells the dispatcher the output buffer is written
    # in place, so the compiler keeps the call ordered correctly.
    direct_register_custom_op(
        op_name="deep_gemm_fp8_fp8_bf16_nt",
        op_func=deep_gemm_fp8_fp8_bf16_nt,
        mutates_args=["C"],
        fake_impl=deep_gemm_fp8_fp8_bf16_nt_fake,
    )
@triton.jit
def _per_token_group_quant_fp8(
......@@ -728,7 +762,10 @@ def w8a8_block_fp8_matmul(
# deepgemm only support bf16
if _is_cuda and C.dtype == torch.bfloat16 and _enable_jit_deepgemm:
deep_gemm.gemm_fp8_fp8_bf16_nt((A, As), (B, Bs), C)
if supports_custom_op():
torch.ops.sglang.deep_gemm_fp8_fp8_bf16_nt(A, As, B, Bs, C)
else:
deep_gemm.gemm_fp8_fp8_bf16_nt((A, As), (B, Bs), C)
else:
kernel = (
_w8a8_block_fp8_matmul_unrolledx4
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment