Commit 282eb59f (unverified), authored Jul 19, 2025 by Baizhou Zhang, committed by GitHub Jul 20, 2025
Parent: 4540a466

Add bf16 output option for dsv3_router_gemm kernel (#7999)

Showing 7 changed files with 465 additions and 104 deletions (+465 / -104)
sgl-kernel/CMakeLists.txt                               +3    -1
sgl-kernel/benchmark/bench_dsv3_router_gemm.py          +50   -3
sgl-kernel/csrc/gemm/dsv3_router_gemm_bf16_out.cu       +234  -0
sgl-kernel/csrc/gemm/dsv3_router_gemm_entry.cu          +127  -0
sgl-kernel/csrc/gemm/dsv3_router_gemm_float_out.cu      +39   -92
sgl-kernel/python/sgl_kernel/gemm.py                    +2    -1
sgl-kernel/tests/test_dsv3_router_gemm.py               +10   -7
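For callers, the visible change is a new out_dtype keyword on sgl_kernel.dsv3_router_gemm, defaulting to torch.bfloat16 (previously the output was always float32). A minimal usage sketch, assuming an SM90+ GPU and the DeepSeek-V3 router shapes the kernel requires (hidden_dim = 7168, num_experts = 256, 1 to 16 tokens):

import torch
from sgl_kernel import dsv3_router_gemm

# Router inputs: [num_tokens, hidden_dim] activations and [num_experts, hidden_dim] weights, both bf16.
hidden_states = torch.randn((4, 7168), dtype=torch.bfloat16, device="cuda")
router_weights = torch.randn((256, 7168), dtype=torch.bfloat16, device="cuda")

logits_bf16 = dsv3_router_gemm(hidden_states, router_weights)  # new default: bf16 output
logits_fp32 = dsv3_router_gemm(hidden_states, router_weights, out_dtype=torch.float32)  # float32, as before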
sgl-kernel/CMakeLists.txt
@@ -222,7 +222,9 @@ set(SOURCES
     "csrc/gemm/awq_kernel.cu"
     "csrc/gemm/bmm_fp8.cu"
     "csrc/gemm/dsv3_fused_a_gemm.cu"
-    "csrc/gemm/dsv3_router_gemm.cu"
+    "csrc/gemm/dsv3_router_gemm_bf16_out.cu"
+    "csrc/gemm/dsv3_router_gemm_entry.cu"
+    "csrc/gemm/dsv3_router_gemm_float_out.cu"
     "csrc/gemm/fp8_blockwise_gemm_kernel.cu"
     "csrc/gemm/fp8_gemm_kernel.cu"
     "csrc/gemm/int8_gemm_kernel.cu"
sgl-kernel/benchmark/bench_dsv3_router_gemm.py
@@ -7,6 +7,48 @@ import triton.testing
 from sgl_kernel import dsv3_router_gemm


+@triton.testing.perf_report(
+    triton.testing.Benchmark(
+        x_names=["num_tokens"],
+        x_vals=[i + 1 for i in range(16)],
+        x_log=False,
+        line_arg="impl",
+        line_vals=["torch", "sgl-kernel"],
+        line_names=["torch", "dsv3_router_gemm"],
+        styles=[("blue", "-"), ("orange", "-")],
+        ylabel="TFLOPs",
+        plot_name="input-bf16-output-bf16 dsv3 router gemm throughput",
+        args={},
+    )
+)
+def benchmark_bf16_output(num_tokens, impl):
+    # M: num_tokens, K: hidden_dim, N: num_experts
+    M, K, N = num_tokens, 7168, 256
+    mat_a = torch.randn((M, K), dtype=torch.bfloat16, device="cuda").contiguous()
+    mat_b = torch.randn((N, K), dtype=torch.bfloat16, device="cuda").contiguous()
+    quantiles = [0.5, 0.2, 0.8]
+
+    if impl == "torch":
+
+        def runner():
+            F.linear(mat_a, mat_b)
+
+    elif impl == "sgl-kernel":
+
+        def runner():
+            dsv3_router_gemm(mat_a, mat_b, out_dtype=torch.bfloat16)
+
+    ms, min_ms, max_ms = triton.testing.do_bench(runner, quantiles=quantiles)
+
+    def tflops(t_ms):
+        flops = 2 * M * K * N
+        return flops / (t_ms * 1e-3) / 1e12
+
+    return tflops(ms), tflops(max_ms), tflops(min_ms)
+
+
 @triton.testing.perf_report(
     triton.testing.Benchmark(
         x_names=["num_tokens"],
@@ -21,7 +63,7 @@ from sgl_kernel import dsv3_router_gemm
         args={},
     )
 )
-def benchmark(num_tokens, impl):
+def benchmark_float_output(num_tokens, impl):
     # M: num_tokens, K: hidden_dim, N: num_experts
     M, K, N = num_tokens, 7168, 256
@@ -38,7 +80,7 @@ def benchmark(num_tokens, impl):
     elif impl == "sgl-kernel":

         def runner():
-            dsv3_router_gemm(mat_a, mat_b)
+            dsv3_router_gemm(mat_a, mat_b, out_dtype=torch.float32)

     ms, min_ms, max_ms = triton.testing.do_bench(runner, quantiles=quantiles)
@@ -53,4 +95,9 @@ if __name__ == "__main__":
     parser = argparse.ArgumentParser()
     args = parser.parse_args()

-    benchmark.run(print_data=True, show_plots=True, save_path="bench_dsv3_router_gemm")
+    benchmark_bf16_output.run(
+        print_data=True, show_plots=True, save_path="bench_dsv3_router_gemm"
+    )
+    benchmark_float_output.run(
+        print_data=True, show_plots=True, save_path="bench_dsv3_router_gemm"
+    )
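As a quick check on the TFLOPs conversion shared by both benchmark functions (the tflops helper above), the router GEMM performs 2 * M * K * N floating-point operations; a worked Python example with illustrative, not measured, numbers:

# Shapes match the benchmark's largest case; the timing below is a hypothetical placeholder.
M, K, N = 16, 7168, 256
flops = 2 * M * K * N                  # 58,720,256 FLOPs per router GEMM call
t_ms = 0.01                            # assumed kernel time in milliseconds
print(flops / (t_ms * 1e-3) / 1e12)    # ~5.87 TFLOPs at that assumed timing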
sgl-kernel/csrc/gemm/dsv3_router_gemm_bf16_out.cu
0 → 100644
/*
* Adapted from
* https://github.com/NVIDIA/TensorRT-LLM/blob/main/cpp/tensorrt_llm/kernels/dsv3MinLatencyKernels/dsv3RouterGemm.cu
* https://github.com/NVIDIA/TensorRT-LLM/blob/main/cpp/tensorrt_llm/thop/dsv3RouterGemmOp.cpp
*
* Copyright (c) 2019-2023, NVIDIA CORPORATION. All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include <ATen/ATen.h>
#include <ATen/cuda/CUDAContext.h>
#include "cuda_bf16.h"
#include "cuda_runtime.h"
#include "utils.h"
// Custom FMA implementation using PTX assembly instructions
__device__ __forceinline__ void fma(float2& d, float2 const& a, float2 const& b, float2 const& c) {
  asm volatile("fma.rn.f32x2 %0, %1, %2, %3;\n"
               : "=l"(reinterpret_cast<uint64_t&>(d))
               : "l"(reinterpret_cast<uint64_t const&>(a)),
                 "l"(reinterpret_cast<uint64_t const&>(b)),
                 "l"(reinterpret_cast<uint64_t const&>(c)));
}

// Convert 8 bfloat16 values from a uint4 to float array - optimized conversion
template <int VPT>
__device__ __forceinline__ void bf16_uint4_to_float8(uint4 const& vec, float* dst) {
  __nv_bfloat16* bf16_ptr = reinterpret_cast<__nv_bfloat16*>(const_cast<uint4*>(&vec));
#pragma unroll
  for (int i = 0; i < VPT; i++) {
    dst[i] = __bfloat162float(bf16_ptr[i]);
  }
}

template <typename T, int kBlockSize, int VPT, int kNumTokens, int kNumExperts, int kHiddenDim>
__global__ __launch_bounds__(128, 1) void router_gemm_kernel_bf16_output(
    __nv_bfloat16* out, T const* mat_a, T const* mat_b) {
  // Each block handles one expert column
  int const n_idx = blockIdx.x;
  int const tid = threadIdx.x;
  constexpr int kWarpSize = 32;
  constexpr int kNumWarps = kBlockSize / kWarpSize;

  // Constants for this kernel
  constexpr int k_elems_per_k_iteration = VPT * kBlockSize;
  constexpr int k_iterations = kHiddenDim / k_elems_per_k_iteration;  // Total K iterations

  // Initialize accumulators for all M rows
  float acc[kNumTokens] = {};

  // Shared memory for warp-level reduction
  __shared__ float sm_reduction[kNumTokens][kNumWarps];

  // B matrix is in column-major order, so we can directly load a column for the n_idx expert
  T const* b_col = mat_b + n_idx * kHiddenDim;

  // Pre-compute k_base values for each iteration to help compiler optimize
  int k_bases[k_iterations];
#pragma unroll
  for (int ki = 0; ki < k_iterations; ki++) {
    k_bases[ki] = ki * k_elems_per_k_iteration + tid * VPT;
  }

#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900))
  asm volatile("griddepcontrol.wait;");
#endif

  // Process the GEMM in chunks
  for (int ki = 0; ki < k_iterations; ki++) {
    int const k_base = k_bases[ki];

    // Load B matrix values using vector load (8 bf16 values)
    uint4 b_vec = *reinterpret_cast<uint4 const*>(b_col + k_base);

    // Convert B values to float
    float b_float[VPT];
    bf16_uint4_to_float8<VPT>(b_vec, b_float);

    // Process each token
#pragma unroll
    for (int m_idx = 0; m_idx < kNumTokens; m_idx++) {
      // Load the rows of the A matrix using vector loads
      uint4 a_vec = *reinterpret_cast<uint4 const*>(mat_a + (m_idx * kHiddenDim) + k_base);

      // Convert A values to float
      float a_float[VPT];
      bf16_uint4_to_float8<VPT>(a_vec, a_float);

      // Process elements in this chunk
#pragma unroll
      for (int k = 0; k < VPT; k++) {
        float a = a_float[k];
        float b = b_float[k];
        acc[m_idx] += a * b;
      }
    }
  }

  // Perform warp-level reduction
  int const warpSize = 32;
  int const warpId = tid / warpSize;
  int const laneId = tid % warpSize;

  // Register for warp-level reduction results
  float warp_result[kNumTokens];
#pragma unroll
  for (int m_idx = 0; m_idx < kNumTokens; m_idx++) {
    warp_result[m_idx] = acc[m_idx];
  }

  // Perform warp-level reduction using optimized butterfly pattern
#pragma unroll
  for (int m = 0; m < kNumTokens; m++) {
    float sum = warp_result[m];

    // Butterfly reduction pattern
    sum += __shfl_xor_sync(0xffffffff, sum, 16);
    sum += __shfl_xor_sync(0xffffffff, sum, 8);
    sum += __shfl_xor_sync(0xffffffff, sum, 4);
    sum += __shfl_xor_sync(0xffffffff, sum, 2);
    sum += __shfl_xor_sync(0xffffffff, sum, 1);

    // Only the first thread in each warp stores to shared memory
    if (laneId == 0) {
      sm_reduction[m][warpId] = sum;
    }
  }

  __syncthreads();

  // Final reduction across warps (only first thread)
  if (tid == 0) {
#pragma unroll
    for (int m = 0; m < kNumTokens; m++) {
      float final_sum = 0.0f;

      // Sum across the kNumWarps
#pragma unroll
      for (int w = 0; w < kNumWarps; w++) {
        final_sum += sm_reduction[m][w];
      }

      // Write final result
      out[m * kNumExperts + n_idx] = __float2bfloat16(final_sum);
    }
  }

#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900))
  asm volatile("griddepcontrol.launch_dependents;");
#endif
}

template <typename T, int kNumTokens, int kNumExperts, int kHiddenDim>
void invokeRouterGemmBf16Output(__nv_bfloat16* output, T const* mat_a, T const* mat_b, cudaStream_t stream) {
  constexpr int VPT = 16 / sizeof(T);
  constexpr int kBlockSize = 128;
  cudaLaunchConfig_t config;
  config.gridDim = kNumExperts;
  config.blockDim = kBlockSize;
  config.dynamicSmemBytes = 0;
  config.stream = stream;
  cudaLaunchAttribute attrs[1];
  attrs[0].id = cudaLaunchAttributeProgrammaticStreamSerialization;
  attrs[0].val.programmaticStreamSerializationAllowed = getEnvEnablePDL();
  config.numAttrs = 1;
  config.attrs = attrs;
  cudaLaunchKernelEx(
      &config,
      router_gemm_kernel_bf16_output<T, kBlockSize, VPT, kNumTokens, kNumExperts, kHiddenDim>,
      output,
      mat_a,
      mat_b);
}

template void invokeRouterGemmBf16Output<__nv_bfloat16, 1, 256, 7168>(__nv_bfloat16*, __nv_bfloat16 const*, __nv_bfloat16 const*, cudaStream_t);
template void invokeRouterGemmBf16Output<__nv_bfloat16, 2, 256, 7168>(__nv_bfloat16*, __nv_bfloat16 const*, __nv_bfloat16 const*, cudaStream_t);
template void invokeRouterGemmBf16Output<__nv_bfloat16, 3, 256, 7168>(__nv_bfloat16*, __nv_bfloat16 const*, __nv_bfloat16 const*, cudaStream_t);
template void invokeRouterGemmBf16Output<__nv_bfloat16, 4, 256, 7168>(__nv_bfloat16*, __nv_bfloat16 const*, __nv_bfloat16 const*, cudaStream_t);
template void invokeRouterGemmBf16Output<__nv_bfloat16, 5, 256, 7168>(__nv_bfloat16*, __nv_bfloat16 const*, __nv_bfloat16 const*, cudaStream_t);
template void invokeRouterGemmBf16Output<__nv_bfloat16, 6, 256, 7168>(__nv_bfloat16*, __nv_bfloat16 const*, __nv_bfloat16 const*, cudaStream_t);
template void invokeRouterGemmBf16Output<__nv_bfloat16, 7, 256, 7168>(__nv_bfloat16*, __nv_bfloat16 const*, __nv_bfloat16 const*, cudaStream_t);
template void invokeRouterGemmBf16Output<__nv_bfloat16, 8, 256, 7168>(__nv_bfloat16*, __nv_bfloat16 const*, __nv_bfloat16 const*, cudaStream_t);
template void invokeRouterGemmBf16Output<__nv_bfloat16, 9, 256, 7168>(__nv_bfloat16*, __nv_bfloat16 const*, __nv_bfloat16 const*, cudaStream_t);
template void invokeRouterGemmBf16Output<__nv_bfloat16, 10, 256, 7168>(__nv_bfloat16*, __nv_bfloat16 const*, __nv_bfloat16 const*, cudaStream_t);
template void invokeRouterGemmBf16Output<__nv_bfloat16, 11, 256, 7168>(__nv_bfloat16*, __nv_bfloat16 const*, __nv_bfloat16 const*, cudaStream_t);
template void invokeRouterGemmBf16Output<__nv_bfloat16, 12, 256, 7168>(__nv_bfloat16*, __nv_bfloat16 const*, __nv_bfloat16 const*, cudaStream_t);
template void invokeRouterGemmBf16Output<__nv_bfloat16, 13, 256, 7168>(__nv_bfloat16*, __nv_bfloat16 const*, __nv_bfloat16 const*, cudaStream_t);
template void invokeRouterGemmBf16Output<__nv_bfloat16, 14, 256, 7168>(__nv_bfloat16*, __nv_bfloat16 const*, __nv_bfloat16 const*, cudaStream_t);
template void invokeRouterGemmBf16Output<__nv_bfloat16, 15, 256, 7168>(__nv_bfloat16*, __nv_bfloat16 const*, __nv_bfloat16 const*, cudaStream_t);
template void invokeRouterGemmBf16Output<__nv_bfloat16, 16, 256, 7168>(__nv_bfloat16*, __nv_bfloat16 const*, __nv_bfloat16 const*, cudaStream_t);
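In terms of the math, each block reduces one expert column, so the kernel as a whole computes out = mat_a x mat_b^T with float32 accumulation and a single round to bf16 at the end. A PyTorch sketch of that semantics (illustrative only; it mirrors the math, not the launch path, and the unit test below compares against F.linear directly):

import torch

def router_gemm_bf16_reference(mat_a: torch.Tensor, mat_b: torch.Tensor) -> torch.Tensor:
    # mat_a: [num_tokens, hidden_dim] bf16; mat_b: [num_experts, hidden_dim] bf16
    acc = mat_a.float() @ mat_b.float().t()  # fp32 accumulation, like the kernel's acc[] registers
    return acc.to(torch.bfloat16)            # final __float2bfloat16 rounding on the reduced sums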
sgl-kernel/csrc/gemm/dsv3_router_gemm_entry.cu
0 → 100644
/*
* Adapted from
* https://github.com/NVIDIA/TensorRT-LLM/blob/main/cpp/tensorrt_llm/kernels/dsv3MinLatencyKernels/dsv3RouterGemm.cu
* https://github.com/NVIDIA/TensorRT-LLM/blob/main/cpp/tensorrt_llm/thop/dsv3RouterGemmOp.cpp
*
* Copyright (c) 2019-2023, NVIDIA CORPORATION. All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include <ATen/ATen.h>
#include <ATen/cuda/CUDAContext.h>
#include "cuda_bf16.h"
#include "cuda_runtime.h"
#include "utils.h"
template <typename T, int kNumTokens, int kNumExperts, int kHiddenDim>
void invokeRouterGemmFloatOutput(float* output, T const* mat_a, T const* mat_b, cudaStream_t stream);

template <typename T, int kNumTokens, int kNumExperts, int kHiddenDim>
void invokeRouterGemmBf16Output(__nv_bfloat16* output, T const* mat_a, T const* mat_b, cudaStream_t stream);

template <int kBegin, int kEnd, int kNumExperts, int kHiddenDim>
struct LoopUnroller {
  static void unroll_float_output(
      int num_tokens, float* output, __nv_bfloat16 const* input, __nv_bfloat16 const* weights, cudaStream_t stream) {
    if (num_tokens == kBegin) {
      invokeRouterGemmFloatOutput<__nv_bfloat16, kBegin, kNumExperts, kHiddenDim>(output, input, weights, stream);
    } else {
      LoopUnroller<kBegin + 1, kEnd, kNumExperts, kHiddenDim>::unroll_float_output(
          num_tokens, output, input, weights, stream);
    }
  }

  static void unroll_bf16_output(
      int num_tokens,
      __nv_bfloat16* output,
      __nv_bfloat16 const* input,
      __nv_bfloat16 const* weights,
      cudaStream_t stream) {
    if (num_tokens == kBegin) {
      invokeRouterGemmBf16Output<__nv_bfloat16, kBegin, kNumExperts, kHiddenDim>(output, input, weights, stream);
    } else {
      LoopUnroller<kBegin + 1, kEnd, kNumExperts, kHiddenDim>::unroll_bf16_output(
          num_tokens, output, input, weights, stream);
    }
  }
};

template <int kEnd, int kNumExperts, int kHiddenDim>
struct LoopUnroller<kEnd, kEnd, kNumExperts, kHiddenDim> {
  static void unroll_float_output(
      int num_tokens, float* output, __nv_bfloat16 const* input, __nv_bfloat16 const* weights, cudaStream_t stream) {
    if (num_tokens == kEnd) {
      invokeRouterGemmFloatOutput<__nv_bfloat16, kEnd, kNumExperts, kHiddenDim>(output, input, weights, stream);
    } else {
      throw std::invalid_argument("Invalid num_tokens, only supports 1 to 16");
    }
  }

  static void unroll_bf16_output(
      int num_tokens,
      __nv_bfloat16* output,
      __nv_bfloat16 const* input,
      __nv_bfloat16 const* weights,
      cudaStream_t stream) {
    if (num_tokens == kEnd) {
      invokeRouterGemmBf16Output<__nv_bfloat16, kEnd, kNumExperts, kHiddenDim>(output, input, weights, stream);
    } else {
      throw std::invalid_argument("Invalid num_tokens, only supports 1 to 16");
    }
  }
};

void dsv3_router_gemm(
    torch::Tensor& output,       // [num_tokens, num_experts]
    const torch::Tensor& mat_a,  // [num_tokens, hidden_dim]
    const torch::Tensor& mat_b   // [num_experts, hidden_dim]
) {
  TORCH_CHECK(output.dim() == 2 && mat_a.dim() == 2 && mat_b.dim() == 2);
  const int num_tokens = mat_a.size(0);
  constexpr int num_experts = 256;
  constexpr int hidden_dim = 7168;
  TORCH_CHECK(mat_a.size(1) == mat_b.size(1), "mat_a and mat_b must have the same hidden_dim");
  TORCH_CHECK(mat_a.size(1) == hidden_dim, "currently hidden_dim only supports 7168");
  TORCH_CHECK(mat_b.size(0) == num_experts, "currently num_experts only supports 256");
  TORCH_CHECK(
      num_tokens >= 1 && num_tokens <= 16,
      "currently num_tokens must be less than or equal to 16 for router_gemm");
  TORCH_CHECK(mat_a.dtype() == torch::kBFloat16, "mat_a must be bf16");
  TORCH_CHECK(mat_b.dtype() == torch::kBFloat16, "mat_b must be bf16");
  TORCH_CHECK(
      output.dtype() == torch::kFloat32 || output.dtype() == torch::kBFloat16,
      "output must be float32 or bf16");

  auto const sm = getSMVersion();
  TORCH_CHECK(sm >= 90, "required CUDA ARCH >= SM_90");

  const cudaStream_t stream = at::cuda::getCurrentCUDAStream();

  if (output.dtype() == torch::kFloat32) {
    LoopUnroller<1, 16, num_experts, hidden_dim>::unroll_float_output(
        num_tokens,
        reinterpret_cast<float*>(output.mutable_data_ptr()),
        reinterpret_cast<__nv_bfloat16 const*>(mat_a.data_ptr()),
        reinterpret_cast<__nv_bfloat16 const*>(mat_b.data_ptr()),
        stream);
  } else if (output.dtype() == torch::kBFloat16) {
    LoopUnroller<1, 16, num_experts, hidden_dim>::unroll_bf16_output(
        num_tokens,
        reinterpret_cast<__nv_bfloat16*>(output.mutable_data_ptr()),
        reinterpret_cast<__nv_bfloat16 const*>(mat_a.data_ptr()),
        reinterpret_cast<__nv_bfloat16 const*>(mat_b.data_ptr()),
        stream);
  }
}
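The LoopUnroller recursion above exists only to turn the runtime num_tokens (1 to 16) into a call to the matching compile-time instantiation, with one branch per output dtype. A conceptual Python analogue of that dispatch, purely illustrative and not part of the codebase:

# Hypothetical sketch: one launcher per supported token count, selected at call time.
def make_router_gemm_dispatch(launchers):
    # launchers: dict {num_tokens: callable}, e.g. {1: launch_1, ..., 16: launch_16}
    def dispatch(num_tokens, output, mat_a, mat_b, stream):
        if num_tokens not in launchers:
            raise ValueError("Invalid num_tokens, only supports 1 to 16")
        return launchers[num_tokens](output, mat_a, mat_b, stream)
    return dispatch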
sgl-kernel/csrc/gemm/dsv3_router_gemm.cu → sgl-kernel/csrc/gemm/dsv3_router_gemm_float_out.cu
@@ -46,7 +46,7 @@ __device__ __forceinline__ void bf16_uint4_to_float8(uint4 const& vec, float* dst
 }

 template <typename T, int kBlockSize, int VPT, int kNumTokens, int kNumExperts, int kHiddenDim>
-__global__ __launch_bounds__(128, 1) void router_gemm_kernel(float* out, T const* mat_a, T const* mat_b) {
+__global__ __launch_bounds__(128, 1) void router_gemm_kernel_float_output(float* out, T const* mat_a, T const* mat_b) {
   // Each block handles one expert column
   int const n_idx = blockIdx.x;
   int const tid = threadIdx.x;
@@ -163,7 +163,7 @@ __global__ __launch_bounds__(128, 1) void router_gemm_kernel(float* out, T const
 }

 template <typename T, int kNumTokens, int kNumExperts, int kHiddenDim>
-void invokeRouterGemm(float* output, T const* mat_a, T const* mat_b, cudaStream_t stream) {
+void invokeRouterGemmFloatOutput(float* output, T const* mat_a, T const* mat_b, cudaStream_t stream) {
   constexpr int VPT = 16 / sizeof(T);
   constexpr int kBlockSize = 128;
   cudaLaunchConfig_t config;
@@ -177,110 +177,57 @@ void invokeRouterGemm(float* output, T const* mat_a, T const* mat_b, cudaStream_
   config.numAttrs = 1;
   config.attrs = attrs;
   cudaLaunchKernelEx(
-      &config, router_gemm_kernel<T, kBlockSize, VPT, kNumTokens, kNumExperts, kHiddenDim>, output, mat_a, mat_b);
+      &config, router_gemm_kernel_float_output<T, kBlockSize, VPT, kNumTokens, kNumExperts, kHiddenDim>, output, mat_a, mat_b);
 }

-template void invokeRouterGemm<__nv_bfloat16, 1, 256, 7168>(float*, __nv_bfloat16 const*, __nv_bfloat16 const*, cudaStream_t);
+template void invokeRouterGemmFloatOutput<__nv_bfloat16, 1, 256, 7168>(float*, __nv_bfloat16 const*, __nv_bfloat16 const*, cudaStream_t);
-template void invokeRouterGemm<__nv_bfloat16, 2, 256, 7168>(float*, __nv_bfloat16 const*, __nv_bfloat16 const*, cudaStream_t);
+template void invokeRouterGemmFloatOutput<__nv_bfloat16, 2, 256, 7168>(float*, __nv_bfloat16 const*, __nv_bfloat16 const*, cudaStream_t);
-template void invokeRouterGemm<__nv_bfloat16, 3, 256, 7168>(float*, __nv_bfloat16 const*, __nv_bfloat16 const*, cudaStream_t);
+template void invokeRouterGemmFloatOutput<__nv_bfloat16, 3, 256, 7168>(float*, __nv_bfloat16 const*, __nv_bfloat16 const*, cudaStream_t);
-template void invokeRouterGemm<__nv_bfloat16, 4, 256, 7168>(float*, __nv_bfloat16 const*, __nv_bfloat16 const*, cudaStream_t);
+template void invokeRouterGemmFloatOutput<__nv_bfloat16, 4, 256, 7168>(float*, __nv_bfloat16 const*, __nv_bfloat16 const*, cudaStream_t);
-template void invokeRouterGemm<__nv_bfloat16, 5, 256, 7168>(float*, __nv_bfloat16 const*, __nv_bfloat16 const*, cudaStream_t);
+template void invokeRouterGemmFloatOutput<__nv_bfloat16, 5, 256, 7168>(float*, __nv_bfloat16 const*, __nv_bfloat16 const*, cudaStream_t);
-template void invokeRouterGemm<__nv_bfloat16, 6, 256, 7168>(float*, __nv_bfloat16 const*, __nv_bfloat16 const*, cudaStream_t);
+template void invokeRouterGemmFloatOutput<__nv_bfloat16, 6, 256, 7168>(float*, __nv_bfloat16 const*, __nv_bfloat16 const*, cudaStream_t);
-template void invokeRouterGemm<__nv_bfloat16, 7, 256, 7168>(float*, __nv_bfloat16 const*, __nv_bfloat16 const*, cudaStream_t);
+template void invokeRouterGemmFloatOutput<__nv_bfloat16, 7, 256, 7168>(float*, __nv_bfloat16 const*, __nv_bfloat16 const*, cudaStream_t);
-template void invokeRouterGemm<__nv_bfloat16, 8, 256, 7168>(float*, __nv_bfloat16 const*, __nv_bfloat16 const*, cudaStream_t);
+template void invokeRouterGemmFloatOutput<__nv_bfloat16, 8, 256, 7168>(float*, __nv_bfloat16 const*, __nv_bfloat16 const*, cudaStream_t);
-template void invokeRouterGemm<__nv_bfloat16, 9, 256, 7168>(float*, __nv_bfloat16 const*, __nv_bfloat16 const*, cudaStream_t);
+template void invokeRouterGemmFloatOutput<__nv_bfloat16, 9, 256, 7168>(float*, __nv_bfloat16 const*, __nv_bfloat16 const*, cudaStream_t);
-template void invokeRouterGemm<__nv_bfloat16, 10, 256, 7168>(float*, __nv_bfloat16 const*, __nv_bfloat16 const*, cudaStream_t);
+template void invokeRouterGemmFloatOutput<__nv_bfloat16, 10, 256, 7168>(float*, __nv_bfloat16 const*, __nv_bfloat16 const*, cudaStream_t);
-template void invokeRouterGemm<__nv_bfloat16, 11, 256, 7168>(float*, __nv_bfloat16 const*, __nv_bfloat16 const*, cudaStream_t);
+template void invokeRouterGemmFloatOutput<__nv_bfloat16, 11, 256, 7168>(float*, __nv_bfloat16 const*, __nv_bfloat16 const*, cudaStream_t);
-template void invokeRouterGemm<__nv_bfloat16, 12, 256, 7168>(float*, __nv_bfloat16 const*, __nv_bfloat16 const*, cudaStream_t);
+template void invokeRouterGemmFloatOutput<__nv_bfloat16, 12, 256, 7168>(float*, __nv_bfloat16 const*, __nv_bfloat16 const*, cudaStream_t);
-template void invokeRouterGemm<__nv_bfloat16, 13, 256, 7168>(float*, __nv_bfloat16 const*, __nv_bfloat16 const*, cudaStream_t);
+template void invokeRouterGemmFloatOutput<__nv_bfloat16, 13, 256, 7168>(float*, __nv_bfloat16 const*, __nv_bfloat16 const*, cudaStream_t);
-template void invokeRouterGemm<__nv_bfloat16, 14, 256, 7168>(float*, __nv_bfloat16 const*, __nv_bfloat16 const*, cudaStream_t);
+template void invokeRouterGemmFloatOutput<__nv_bfloat16, 14, 256, 7168>(float*, __nv_bfloat16 const*, __nv_bfloat16 const*, cudaStream_t);
-template void invokeRouterGemm<__nv_bfloat16, 15, 256, 7168>(float*, __nv_bfloat16 const*, __nv_bfloat16 const*, cudaStream_t);
+template void invokeRouterGemmFloatOutput<__nv_bfloat16, 15, 256, 7168>(float*, __nv_bfloat16 const*, __nv_bfloat16 const*, cudaStream_t);
-template void invokeRouterGemm<__nv_bfloat16, 16, 256, 7168>(float*, __nv_bfloat16 const*, __nv_bfloat16 const*, cudaStream_t);
-
-template <int kBegin, int kEnd, int kNumExperts, int kHiddenDim>
-struct LoopUnroller {
-  static void unroll(int num_tokens, float* output, __nv_bfloat16 const* input, __nv_bfloat16 const* weights, cudaStream_t stream) {
-    if (num_tokens == kBegin) {
-      invokeRouterGemm<__nv_bfloat16, kBegin, kNumExperts, kHiddenDim>(output, input, weights, stream);
-    } else {
-      LoopUnroller<kBegin + 1, kEnd, kNumExperts, kHiddenDim>::unroll(num_tokens, output, input, weights, stream);
-    }
-  }
-};
-
-template <int kEnd, int kNumExperts, int kHiddenDim>
-struct LoopUnroller<kEnd, kEnd, kNumExperts, kHiddenDim> {
-  static void unroll(int num_tokens, float* output, __nv_bfloat16 const* input, __nv_bfloat16 const* weights, cudaStream_t stream) {
-    if (num_tokens == kEnd) {
-      invokeRouterGemm<__nv_bfloat16, kEnd, kNumExperts, kHiddenDim>(output, input, weights, stream);
-    } else {
-      throw std::invalid_argument("Invalid num_tokens, only supports 1 to 16");
-    }
-  }
-};
-
-void dsv3_router_gemm(
-    torch::Tensor& output,       // [num_tokens, num_experts]
-    const torch::Tensor& mat_a,  // [num_tokens, hidden_dim]
-    const torch::Tensor& mat_b   // [num_experts, hidden_dim]
-) {
-  TORCH_CHECK(output.dim() == 2 && mat_a.dim() == 2 && mat_b.dim() == 2);
-  const int num_tokens = mat_a.size(0);
-  constexpr int num_experts = 256;
-  constexpr int hidden_dim = 7168;
-  TORCH_CHECK(mat_a.size(1) == mat_b.size(1), "mat_a and mat_b must have the same hidden_dim");
-  TORCH_CHECK(mat_a.size(1) == hidden_dim, "currently hidden_dim only supports 7168");
-  TORCH_CHECK(mat_b.size(0) == num_experts, "currently num_experts only supports 256");
-  TORCH_CHECK(
-      num_tokens >= 1 && num_tokens <= 16,
-      "currently num_tokens must be less than or equal to 16 for router_gemm");
-  TORCH_CHECK(mat_a.dtype() == torch::kBFloat16, "mat_a must be bf16");
-  TORCH_CHECK(mat_b.dtype() == torch::kBFloat16, "mat_b must be bf16");
-  TORCH_CHECK(output.dtype() == torch::kFloat32, "output must be float32");
-  auto const sm = getSMVersion();
-  TORCH_CHECK(sm >= 90, "required CUDA ARCH >= SM_90");
-  const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
-  LoopUnroller<1, 16, num_experts, hidden_dim>::unroll(
-      num_tokens,
-      reinterpret_cast<float*>(output.mutable_data_ptr()),
-      reinterpret_cast<__nv_bfloat16 const*>(mat_a.data_ptr()),
-      reinterpret_cast<__nv_bfloat16 const*>(mat_b.data_ptr()),
-      stream);
-}
+template void invokeRouterGemmFloatOutput<__nv_bfloat16, 16, 256, 7168>(float*, __nv_bfloat16 const*, __nv_bfloat16 const*, cudaStream_t);
sgl-kernel/python/sgl_kernel/gemm.py
@@ -262,12 +262,13 @@ def qserve_w4a8_per_group_gemm(
 def dsv3_router_gemm(
     hidden_states: torch.Tensor,
     router_weights: torch.Tensor,
+    out_dtype: torch.dtype = torch.bfloat16,
 ) -> torch.Tensor:
     output = torch.empty(
         hidden_states.shape[0],
         router_weights.shape[0],
         device=hidden_states.device,
-        dtype=torch.float32,
+        dtype=out_dtype,
     )
     torch.ops.sgl_kernel.dsv3_router_gemm(
         output,
...
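Because the wrapper now allocates output with dtype=out_dtype before handing it to the C++ op, the dtype of the returned tensor follows the argument directly. A small check one might run (hypothetical snippet; same device and shape constraints as the kernel enforces):

import torch
from sgl_kernel import dsv3_router_gemm

x = torch.randn((2, 7168), dtype=torch.bfloat16, device="cuda")
w = torch.randn((256, 7168), dtype=torch.bfloat16, device="cuda")

assert dsv3_router_gemm(x, w).dtype == torch.bfloat16                           # default out_dtype
assert dsv3_router_gemm(x, w, out_dtype=torch.float32).dtype == torch.float32   # explicit float32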
sgl-kernel/tests/test_dsv3_router_gemm.py
@@ -15,17 +15,20 @@ def test_dsv3_router_gemm(num_tokens):
     mat_b = torch.randn(
         (num_experts, hidden_dim), dtype=torch.bfloat16, device="cuda"
     ).contiguous()
-    output = torch.empty(
-        (num_tokens, num_experts), dtype=torch.float32, device="cuda"
-    ).contiguous()
-    ref = F.linear(mat_a, mat_b).to(torch.float32)
+    bf16_ref = F.linear(mat_a, mat_b)
+    float_ref = bf16_ref.to(torch.float32)

-    output = dsv3_router_gemm(mat_a, mat_b)
+    bf16_output = dsv3_router_gemm(mat_a, mat_b, out_dtype=torch.bfloat16)
+    float_output = dsv3_router_gemm(mat_a, mat_b, out_dtype=torch.float32)

     assert torch.allclose(
-        output, ref, rtol=1e-2, atol=1e-3
-    ), "Router GEMM output mismatch with torch.nn.functional.linear reference"
+        bf16_output, bf16_ref, rtol=1e-2, atol=1e-3
+    ), "Router GEMM output in bf16 dtype mismatch with torch.nn.functional.linear reference"
+    assert torch.allclose(
+        float_output, float_ref, rtol=1e-2, atol=1e-3
+    ), "Router GEMM output in float32 dtype mismatch with torch.nn.functional.linear reference"


 if __name__ == "__main__":
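To exercise the updated test locally, something like the following should work (the file path is assumed relative to the repository root, and a CUDA device with SM90 or newer must be visible):

import pytest

# Run only the parametrized router GEMM test module.
pytest.main(["-q", "sgl-kernel/tests/test_dsv3_router_gemm.py"])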