Unverified commit 04b35190, authored by Ke Bao, committed by GitHub

Add dsv3 fused a gemm to sgl-kernel (#7630)

parent 071a1f51
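
The new op is a single bf16 GEMM with the DeepSeek-V3 "A"-projection shapes baked in: hidden size 7168 in, 2112 out (presumably the concatenated q_a/kv_a low-rank width, 1536 + 576, though the diff itself only fixes the shapes). As a sketch of the reference semantics only, inferred from the test below, not the kernel itself:

import torch

num_tokens = 8
kHdIn, kHdOut = 7168, 2112  # DeepSeek-V3 hidden size -> fused A-projection width

mat_a = torch.randn(num_tokens, kHdIn, dtype=torch.bfloat16, device="cuda")
# The weight is created as (N, K) and viewed as (K, N), i.e. K-major.
mat_b = torch.randn(kHdOut, kHdIn, dtype=torch.bfloat16, device="cuda").T

ref = mat_a @ mat_b  # dsv3_fused_a_gemm(mat_a, mat_b) should match this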
......@@ -221,6 +221,7 @@ set(SOURCES
"csrc/elementwise/rope.cu"
"csrc/gemm/awq_kernel.cu"
"csrc/gemm/bmm_fp8.cu"
"csrc/gemm/dsv3_fused_a_gemm.cu"
"csrc/gemm/fp8_blockwise_gemm_kernel.cu"
"csrc/gemm/fp8_gemm_kernel.cu"
"csrc/gemm/int8_gemm_kernel.cu"
......
import argparse

import torch
import torch.nn.functional as F
import triton
import triton.testing

from sgl_kernel import dsv3_fused_a_gemm


@triton.testing.perf_report(
    triton.testing.Benchmark(
        x_names=["num_tokens"],
        x_vals=[i + 1 for i in range(16)],  # 1..16 tokens (decode-sized batches)
        x_log=False,
        line_arg="impl",
        line_vals=["torch", "sgl-kernel"],
        line_names=["torch (bf16)", "dsv3_fused_a_gemm"],
        styles=[("blue", "-"), ("orange", "-")],
        ylabel="TFLOPs",
        plot_name="bf16 dsv3 fused a GEMM throughput",
        args={},
    )
)
def benchmark(num_tokens, impl):
    # DeepSeek-V3 A-projection shapes: hidden size 7168 in, 2112 out.
    kHdIn = 7168
    kHdOut = 2112
    M, K, N = num_tokens, kHdIn, kHdOut
    mat_a = torch.randn((M, K), dtype=torch.bfloat16, device="cuda").contiguous()
    # The weight is created as (N, K) and viewed as (K, N), i.e. K-major.
    mat_b = torch.randn((N, K), dtype=torch.bfloat16, device="cuda").transpose(0, 1)

    quantiles = [0.5, 0.2, 0.8]
    if impl == "torch":

        def runner():
            F.linear(mat_a, mat_b.T)

    elif impl == "sgl-kernel":

        def runner():
            dsv3_fused_a_gemm(mat_a, mat_b)

    ms, min_ms, max_ms = triton.testing.do_bench(runner, quantiles=quantiles)

    def tflops(t_ms):
        flops = 2 * M * K * N  # one multiply-add per (m, k, n)
        return flops / (t_ms * 1e-3) / 1e12

    # Median throughput first, then the min/max band; the slow (max_ms) and
    # fast (min_ms) time quantiles swap places under the time -> TFLOPs map.
    return tflops(ms), tflops(max_ms), tflops(min_ms)


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    args = parser.parse_args()
    benchmark.run(print_data=True, show_plots=True, save_path="bench_dsv3_gemm")
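
As a sanity check on the tflops helper: at num_tokens = 16 the GEMM performs flops = 2 · 16 · 7168 · 2112 ≈ 4.84e8 floating-point operations, so a measured median of, say, 0.02 ms would map to 4.84e8 / 2e-5 / 1e12 ≈ 24 TFLOPs (the 0.02 ms figure is illustrative, not a measured result).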
......@@ -141,6 +141,9 @@ TORCH_LIBRARY_FRAGMENT(sgl_kernel, m) {
" Tensor! output_scale, Tensor! input_scale) -> ()");
m.impl("scaled_fp4_quant", torch::kCUDA, &scaled_fp4_quant);
m.def("dsv3_fused_a_gemm(Tensor! output, Tensor mat_a, Tensor mat_b) -> ()");
m.impl("dsv3_fused_a_gemm", torch::kCUDA, &dsv3_fused_a_gemm);
// Compute NVFP4 experts quantization.
m.def(
"scaled_fp4_experts_quant(Tensor! output, Tensor! output_scale,"
......
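
In the schema, Tensor! marks an argument the op mutates, so the new op follows the destination-passing convention of the neighboring registrations: the caller supplies output and the kernel writes into it. A minimal sketch of invoking the registered op directly, with shapes assumed from the test below:

import torch
import sgl_kernel  # registers the ops; assumes a build that includes this PR

mat_a = torch.randn(4, 7168, dtype=torch.bfloat16, device="cuda")
mat_b = torch.randn(2112, 7168, dtype=torch.bfloat16, device="cuda").T  # (K, N)
out = torch.empty(4, 2112, dtype=torch.bfloat16, device="cuda")

torch.ops.sgl_kernel.dsv3_fused_a_gemm.default(out, mat_a, mat_b)  # fills `out`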
(Diff collapsed: the new kernel source, presumably csrc/gemm/dsv3_fused_a_gemm.cu.)
......@@ -201,6 +201,8 @@ void bmm_fp8(
    int64_t cublas_handle,
    int64_t cuda_stream);

void dsv3_fused_a_gemm(torch::Tensor& output, torch::Tensor const& mat_a, torch::Tensor const& mat_b);

/*
* From csrc/moe
*/
......
......@@ -241,6 +241,23 @@ inline int getSMVersion() {
return sm_major * 10 + sm_minor;
}
inline bool getBoolEnv(char const* name) {
  char const* env = std::getenv(name);
  return env && env[0] == '1' && env[1] == '\0';
}

inline bool getEnvEnablePDL() {
  static std::once_flag flag;
  static bool enablePDL = false;
  std::call_once(flag, [&]() {
    if (getSMVersion() >= 90) {
      // PDL is only available on Hopper (SM90) and newer; it is enabled by
      // setting the env variable `TRTLLM_ENABLE_PDL` to `1`.
      enablePDL = getBoolEnv("TRTLLM_ENABLE_PDL");
    }
  });
  return enablePDL;
}
// SGLANG_SHFL_XOR_* adapted from https://github.com/vllm-project/vllm/blob/v0.7.3/csrc/cuda_compat.h#L19-L28
#ifndef USE_ROCM
#define SGLANG_SHFL_XOR_SYNC(mask, var, lane_mask) __shfl_xor_sync((mask), (var), (lane_mask))
......
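
getEnvEnablePDL reads the environment once (std::call_once caches the result) and only reports true on SM90+; the helper lands in this PR, presumably consumed by the new kernel's launch path in the collapsed diff. From Python the switch is just an environment variable, which must be set before the first kernel launch because the C++ side caches it; a hedged example:

import os

# Opt in to programmatic dependent launch (Hopper+ only); the value must be
# exactly "1", and it is read once on first use, then cached.
os.environ["TRTLLM_ENABLE_PDL"] = "1"

import sgl_kernel  # subsequent kernel launches can now take the PDL path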
......@@ -33,6 +33,7 @@ from sgl_kernel.gemm import (
    awq_dequantize,
    bmm_fp8,
    cutlass_scaled_fp4_mm,
    dsv3_fused_a_gemm,
    fp8_blockwise_scaled_mm,
    fp8_scaled_mm,
    int8_scaled_mm,
......
......@@ -82,6 +82,21 @@ def bmm_fp8(
return out
def dsv3_fused_a_gemm(
    mat_a: torch.Tensor,
    mat_b: torch.Tensor,
    output: Optional[torch.Tensor] = None,
) -> torch.Tensor:
    # Allocate the output lazily unless the caller passes a preallocated buffer.
    if output is None:
        output = torch.empty(
            (mat_a.shape[0], mat_b.shape[1]),
            device=mat_a.device,
            dtype=mat_a.dtype,
        )
    # The registered CUDA op writes the GEMM result into `output` in place.
    torch.ops.sgl_kernel.dsv3_fused_a_gemm.default(output, mat_a, mat_b)
    return output


def sgl_per_token_group_quant_fp8(
    input: torch.Tensor,
    output_q: torch.Tensor,
......
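
Typical use of the wrapper, assuming the fixed DSV3 shapes exercised by the test; the optional output argument lets a caller reuse a preallocated buffer across steps:

import torch
from sgl_kernel import dsv3_fused_a_gemm

x = torch.randn(4, 7168, dtype=torch.bfloat16, device="cuda")
w = torch.randn(2112, 7168, dtype=torch.bfloat16, device="cuda").T  # (K, N), K-major

y = dsv3_fused_a_gemm(x, w)            # allocates a (4, 2112) bf16 output
y = dsv3_fused_a_gemm(x, w, output=y)  # reuses the buffer on later calls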
import pytest
import torch
import torch.nn.functional as F

from sgl_kernel import dsv3_fused_a_gemm


@pytest.mark.parametrize("num_tokens", [i + 1 for i in range(16)])
def test_dsv3_fused_a_gemm(num_tokens):
    # DeepSeek-V3 A-projection shapes: hidden size 7168 in, 2112 out.
    kHdIn = 7168
    kHdOut = 2112
    mat_a = torch.randn(
        (num_tokens, kHdIn), dtype=torch.bfloat16, device="cuda"
    ).contiguous()
    # Create the weight as (N, K) and view it as (K, N), i.e. K-major.
    mat_b = torch.randn((kHdOut, kHdIn), dtype=torch.bfloat16, device="cuda").transpose(
        0, 1
    )
    output = torch.empty(
        (num_tokens, kHdOut), dtype=torch.bfloat16, device="cuda"
    ).contiguous()

    ref = F.linear(mat_a, mat_b.T)
    # Pass the buffer explicitly so the preallocated-output path is exercised.
    output = dsv3_fused_a_gemm(mat_a, mat_b, output=output)

    assert torch.allclose(
        output, ref, rtol=1e-2, atol=1e-3
    ), "Fused GEMM output mismatch with torch.nn.functional.linear reference"


if __name__ == "__main__":
    pytest.main([__file__])