zhaoyu6 / sglang · Commit 282eb59f (Unverified)

Add bf16 output option for dsv3_router_gemm kernel (#7999)

Authored Jul 19, 2025 by Baizhou Zhang; committed by GitHub on Jul 20, 2025.
Parent: 4540a466
Showing 7 changed files with 465 additions and 104 deletions (+465 / -104):

  sgl-kernel/CMakeLists.txt                            +3    -1
  sgl-kernel/benchmark/bench_dsv3_router_gemm.py       +50   -3
  sgl-kernel/csrc/gemm/dsv3_router_gemm_bf16_out.cu    +234  -0
  sgl-kernel/csrc/gemm/dsv3_router_gemm_entry.cu       +127  -0
  sgl-kernel/csrc/gemm/dsv3_router_gemm_float_out.cu   +39   -92
  sgl-kernel/python/sgl_kernel/gemm.py                 +2    -1
  sgl-kernel/tests/test_dsv3_router_gemm.py            +10   -7
sgl-kernel/CMakeLists.txt
@@ -222,7 +222,9 @@ set(SOURCES
     "csrc/gemm/awq_kernel.cu"
     "csrc/gemm/bmm_fp8.cu"
     "csrc/gemm/dsv3_fused_a_gemm.cu"
-    "csrc/gemm/dsv3_router_gemm.cu"
+    "csrc/gemm/dsv3_router_gemm_bf16_out.cu"
+    "csrc/gemm/dsv3_router_gemm_entry.cu"
+    "csrc/gemm/dsv3_router_gemm_float_out.cu"
     "csrc/gemm/fp8_blockwise_gemm_kernel.cu"
     "csrc/gemm/fp8_gemm_kernel.cu"
     "csrc/gemm/int8_gemm_kernel.cu"
sgl-kernel/benchmark/bench_dsv3_router_gemm.py
@@ -7,6 +7,48 @@ import triton.testing
 from sgl_kernel import dsv3_router_gemm
 
 
+@triton.testing.perf_report(
+    triton.testing.Benchmark(
+        x_names=["num_tokens"],
+        x_vals=[i + 1 for i in range(16)],
+        x_log=False,
+        line_arg="impl",
+        line_vals=["torch", "sgl-kernel"],
+        line_names=["torch", "dsv3_router_gemm"],
+        styles=[("blue", "-"), ("orange", "-")],
+        ylabel="TFLOPs",
+        plot_name="input-bf16-output-bf16 dsv3 router gemm throughput",
+        args={},
+    )
+)
+def benchmark_bf16_output(num_tokens, impl):
+    # M: num_tokens, K: hidden_dim, N: num_experts
+    M, K, N = num_tokens, 7168, 256
+    mat_a = torch.randn((M, K), dtype=torch.bfloat16, device="cuda").contiguous()
+    mat_b = torch.randn((N, K), dtype=torch.bfloat16, device="cuda").contiguous()
+    quantiles = [0.5, 0.2, 0.8]
+
+    if impl == "torch":
+
+        def runner():
+            F.linear(mat_a, mat_b)
+
+    elif impl == "sgl-kernel":
+
+        def runner():
+            dsv3_router_gemm(mat_a, mat_b, out_dtype=torch.bfloat16)
+
+    ms, min_ms, max_ms = triton.testing.do_bench(runner, quantiles=quantiles)
+
+    def tflops(t_ms):
+        flops = 2 * M * K * N
+        return flops / (t_ms * 1e-3) / 1e12
+
+    return tflops(ms), tflops(max_ms), tflops(min_ms)
+
+
 @triton.testing.perf_report(
     triton.testing.Benchmark(
         x_names=["num_tokens"],
@@ -21,7 +63,7 @@ from sgl_kernel import dsv3_router_gemm
         args={},
     )
 )
-def benchmark(num_tokens, impl):
+def benchmark_float_output(num_tokens, impl):
     # M: num_tokens, K: hidden_dim, N: num_experts
     M, K, N = num_tokens, 7168, 256
@@ -38,7 +80,7 @@ def benchmark(num_tokens, impl):
     elif impl == "sgl-kernel":
 
         def runner():
-            dsv3_router_gemm(mat_a, mat_b)
+            dsv3_router_gemm(mat_a, mat_b, out_dtype=torch.float32)
 
     ms, min_ms, max_ms = triton.testing.do_bench(runner, quantiles=quantiles)
@@ -53,4 +95,9 @@ if __name__ == "__main__":
     parser = argparse.ArgumentParser()
     args = parser.parse_args()
 
-    benchmark.run(print_data=True, show_plots=True, save_path="bench_dsv3_router_gemm")
+    benchmark_bf16_output.run(
+        print_data=True, show_plots=True, save_path="bench_dsv3_router_gemm"
+    )
+    benchmark_float_output.run(
+        print_data=True, show_plots=True, save_path="bench_dsv3_router_gemm"
+    )
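
Note (not part of the diff): a minimal usage sketch of the new option, assuming sgl-kernel built from this commit, a Hopper (SM90+) GPU, and the shapes the kernel currently supports (hidden_dim 7168, 256 experts, at most 16 tokens).

    import torch
    from sgl_kernel import dsv3_router_gemm

    hidden_states = torch.randn(4, 7168, dtype=torch.bfloat16, device="cuda")    # [num_tokens, hidden_dim]
    router_weights = torch.randn(256, 7168, dtype=torch.bfloat16, device="cuda")  # [num_experts, hidden_dim]

    # New in this commit: out_dtype selects the output kernel (default is bf16).
    logits_bf16 = dsv3_router_gemm(hidden_states, router_weights, out_dtype=torch.bfloat16)
    logits_fp32 = dsv3_router_gemm(hidden_states, router_weights, out_dtype=torch.float32)
    print(logits_bf16.dtype, logits_fp32.dtype)  # torch.bfloat16 torch.float32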
sgl-kernel/csrc/gemm/dsv3_router_gemm_bf16_out.cu
new file mode 100644
/*
* Adapted from
* https://github.com/NVIDIA/TensorRT-LLM/blob/main/cpp/tensorrt_llm/kernels/dsv3MinLatencyKernels/dsv3RouterGemm.cu
* https://github.com/NVIDIA/TensorRT-LLM/blob/main/cpp/tensorrt_llm/thop/dsv3RouterGemmOp.cpp
*
* Copyright (c) 2019-2023, NVIDIA CORPORATION. All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include <ATen/ATen.h>
#include <ATen/cuda/CUDAContext.h>
#include "cuda_bf16.h"
#include "cuda_runtime.h"
#include "utils.h"
// Custom FMA implementation using PTX assembly instructions
__device__ __forceinline__ void fma(float2& d, float2 const& a, float2 const& b, float2 const& c) {
  asm volatile("fma.rn.f32x2 %0, %1, %2, %3;\n"
               : "=l"(reinterpret_cast<uint64_t&>(d))
               : "l"(reinterpret_cast<uint64_t const&>(a)),
                 "l"(reinterpret_cast<uint64_t const&>(b)),
                 "l"(reinterpret_cast<uint64_t const&>(c)));
}

// Convert 8 bfloat16 values from a uint4 to float array - optimized conversion
template <int VPT>
__device__ __forceinline__ void bf16_uint4_to_float8(uint4 const& vec, float* dst) {
  __nv_bfloat16* bf16_ptr = reinterpret_cast<__nv_bfloat16*>(const_cast<uint4*>(&vec));
#pragma unroll
  for (int i = 0; i < VPT; i++) {
    dst[i] = __bfloat162float(bf16_ptr[i]);
  }
}

template <typename T, int kBlockSize, int VPT, int kNumTokens, int kNumExperts, int kHiddenDim>
__global__ __launch_bounds__(128, 1) void router_gemm_kernel_bf16_output(
    __nv_bfloat16* out, T const* mat_a, T const* mat_b) {
  // Each block handles one expert column
  int const n_idx = blockIdx.x;
  int const tid = threadIdx.x;
  constexpr int kWarpSize = 32;
  constexpr int kNumWarps = kBlockSize / kWarpSize;

  // Constants for this kernel
  constexpr int k_elems_per_k_iteration = VPT * kBlockSize;
  constexpr int k_iterations = kHiddenDim / k_elems_per_k_iteration;  // Total K iterations

  // Initialize accumulators for all M rows
  float acc[kNumTokens] = {};

  // Shared memory for warp-level reduction
  __shared__ float sm_reduction[kNumTokens][kNumWarps];  // kNumWarps

  // B matrix is in column-major order, so we can directly load a column for the n_idx expert
  T const* b_col = mat_b + n_idx * kHiddenDim;

  // Pre-compute k_base values for each iteration to help compiler optimize
  // int k_bases[k_iterations];
  int k_bases[k_iterations];
#pragma unroll
  for (int ki = 0; ki < k_iterations; ki++) {
    k_bases[ki] = ki * k_elems_per_k_iteration + tid * VPT;
  }

#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900))
  asm volatile("griddepcontrol.wait;");
#endif

  // Process the GEMM in chunks
  for (int ki = 0; ki < k_iterations; ki++) {
    int const k_base = k_bases[ki];

    // Load B matrix values using vector load (8 bf16 values)
    uint4 b_vec = *reinterpret_cast<uint4 const*>(b_col + k_base);

    // Convert B values to float
    float b_float[VPT];
    bf16_uint4_to_float8<VPT>(b_vec, b_float);

    // Process each token
#pragma unroll
    for (int m_idx = 0; m_idx < kNumTokens; m_idx++) {
      // Load both rows of A matrix using vector loads
      uint4 a_vec = *reinterpret_cast<uint4 const*>(mat_a + (m_idx * kHiddenDim) + k_base);

      // Convert A values to float
      float a_float[VPT];
      bf16_uint4_to_float8<VPT>(a_vec, a_float);

      // Process elements in this chunk
#pragma unroll
      for (int k = 0; k < VPT; k++) {
        float a = a_float[k];
        float b = b_float[k];
        acc[m_idx] += a * b;
      }
    }
  }

  // Perform warp-level reduction
  int const warpSize = 32;
  int const warpId = tid / warpSize;
  int const laneId = tid % warpSize;

  // Register for warp-level reduction results
  float warp_result[kNumTokens];
#pragma unroll
  for (int m_idx = 0; m_idx < kNumTokens; m_idx++) {
    warp_result[m_idx] = acc[m_idx];
  }

  // Perform warp-level reduction using optimized butterfly pattern
#pragma unroll
  for (int m = 0; m < kNumTokens; m++) {
    float sum = warp_result[m];

    // Butterfly reduction pattern
    sum += __shfl_xor_sync(0xffffffff, sum, 16);
    sum += __shfl_xor_sync(0xffffffff, sum, 8);
    sum += __shfl_xor_sync(0xffffffff, sum, 4);
    sum += __shfl_xor_sync(0xffffffff, sum, 2);
    sum += __shfl_xor_sync(0xffffffff, sum, 1);

    // Only the first thread in each warp stores to shared memory
    if (laneId == 0) {
      sm_reduction[m][warpId] = sum;
    }
  }

  __syncthreads();

  // Final reduction across warps (only first thread)
  if (tid == 0) {
#pragma unroll
    for (int m = 0; m < kNumTokens; m++) {
      float final_sum = 0.0f;
      // Sum across the kNumWarps
#pragma unroll
      for (int w = 0; w < kNumWarps; w++) {
        final_sum += sm_reduction[m][w];
      }
      // Write final result
      out[m * kNumExperts + n_idx] = __float2bfloat16(final_sum);
    }
  }

#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900))
  asm volatile("griddepcontrol.launch_dependents;");
#endif
}

template <typename T, int kNumTokens, int kNumExperts, int kHiddenDim>
void invokeRouterGemmBf16Output(__nv_bfloat16* output, T const* mat_a, T const* mat_b, cudaStream_t stream) {
  constexpr int VPT = 16 / sizeof(T);
  constexpr int kBlockSize = 128;
  cudaLaunchConfig_t config;
  config.gridDim = kNumExperts;
  config.blockDim = kBlockSize;
  config.dynamicSmemBytes = 0;
  config.stream = stream;
  cudaLaunchAttribute attrs[1];
  attrs[0].id = cudaLaunchAttributeProgrammaticStreamSerialization;
  attrs[0].val.programmaticStreamSerializationAllowed = getEnvEnablePDL();
  config.numAttrs = 1;
  config.attrs = attrs;
  cudaLaunchKernelEx(
      &config, router_gemm_kernel_bf16_output<T, kBlockSize, VPT, kNumTokens, kNumExperts, kHiddenDim>,
      output, mat_a, mat_b);
}

template void invokeRouterGemmBf16Output<__nv_bfloat16, 1, 256, 7168>(__nv_bfloat16*, __nv_bfloat16 const*, __nv_bfloat16 const*, cudaStream_t);
template void invokeRouterGemmBf16Output<__nv_bfloat16, 2, 256, 7168>(__nv_bfloat16*, __nv_bfloat16 const*, __nv_bfloat16 const*, cudaStream_t);
template void invokeRouterGemmBf16Output<__nv_bfloat16, 3, 256, 7168>(__nv_bfloat16*, __nv_bfloat16 const*, __nv_bfloat16 const*, cudaStream_t);
template void invokeRouterGemmBf16Output<__nv_bfloat16, 4, 256, 7168>(__nv_bfloat16*, __nv_bfloat16 const*, __nv_bfloat16 const*, cudaStream_t);
template void invokeRouterGemmBf16Output<__nv_bfloat16, 5, 256, 7168>(__nv_bfloat16*, __nv_bfloat16 const*, __nv_bfloat16 const*, cudaStream_t);
template void invokeRouterGemmBf16Output<__nv_bfloat16, 6, 256, 7168>(__nv_bfloat16*, __nv_bfloat16 const*, __nv_bfloat16 const*, cudaStream_t);
template void invokeRouterGemmBf16Output<__nv_bfloat16, 7, 256, 7168>(__nv_bfloat16*, __nv_bfloat16 const*, __nv_bfloat16 const*, cudaStream_t);
template void invokeRouterGemmBf16Output<__nv_bfloat16, 8, 256, 7168>(__nv_bfloat16*, __nv_bfloat16 const*, __nv_bfloat16 const*, cudaStream_t);
template void invokeRouterGemmBf16Output<__nv_bfloat16, 9, 256, 7168>(__nv_bfloat16*, __nv_bfloat16 const*, __nv_bfloat16 const*, cudaStream_t);
template void invokeRouterGemmBf16Output<__nv_bfloat16, 10, 256, 7168>(__nv_bfloat16*, __nv_bfloat16 const*, __nv_bfloat16 const*, cudaStream_t);
template void invokeRouterGemmBf16Output<__nv_bfloat16, 11, 256, 7168>(__nv_bfloat16*, __nv_bfloat16 const*, __nv_bfloat16 const*, cudaStream_t);
template void invokeRouterGemmBf16Output<__nv_bfloat16, 12, 256, 7168>(__nv_bfloat16*, __nv_bfloat16 const*, __nv_bfloat16 const*, cudaStream_t);
template void invokeRouterGemmBf16Output<__nv_bfloat16, 13, 256, 7168>(__nv_bfloat16*, __nv_bfloat16 const*, __nv_bfloat16 const*, cudaStream_t);
template void invokeRouterGemmBf16Output<__nv_bfloat16, 14, 256, 7168>(__nv_bfloat16*, __nv_bfloat16 const*, __nv_bfloat16 const*, cudaStream_t);
template void invokeRouterGemmBf16Output<__nv_bfloat16, 15, 256, 7168>(__nv_bfloat16*, __nv_bfloat16 const*, __nv_bfloat16 const*, cudaStream_t);
template void invokeRouterGemmBf16Output<__nv_bfloat16, 16, 256, 7168>(__nv_bfloat16*, __nv_bfloat16 const*, __nv_bfloat16 const*, cudaStream_t);
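
Note (not part of the diff): numerically, the new kernel produces the same router logits as a plain matmul with fp32 accumulation, only the final store is narrowed to bf16. A reference sketch in PyTorch, as an illustration of what the CUDA kernel above computes rather than how it computes it:

    import torch

    def router_gemm_bf16_reference(mat_a: torch.Tensor, mat_b: torch.Tensor) -> torch.Tensor:
        # mat_a: [num_tokens, 7168] bf16; mat_b: [256, 7168] bf16 (one expert per row).
        # The kernel accumulates per-thread partial dot products in float32 and
        # reduces them across the warp and block ...
        acc = mat_a.float() @ mat_b.float().t()  # [num_tokens, num_experts], fp32
        # ... then converts the final sums to bf16 on store (__float2bfloat16).
        return acc.to(torch.bfloat16)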
sgl-kernel/csrc/gemm/dsv3_router_gemm_entry.cu
new file mode 100644
/*
* Adapted from
* https://github.com/NVIDIA/TensorRT-LLM/blob/main/cpp/tensorrt_llm/kernels/dsv3MinLatencyKernels/dsv3RouterGemm.cu
* https://github.com/NVIDIA/TensorRT-LLM/blob/main/cpp/tensorrt_llm/thop/dsv3RouterGemmOp.cpp
*
* Copyright (c) 2019-2023, NVIDIA CORPORATION. All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include <ATen/ATen.h>
#include <ATen/cuda/CUDAContext.h>
#include "cuda_bf16.h"
#include "cuda_runtime.h"
#include "utils.h"
template <typename T, int kNumTokens, int kNumExperts, int kHiddenDim>
void invokeRouterGemmFloatOutput(float* output, T const* mat_a, T const* mat_b, cudaStream_t stream);

template <typename T, int kNumTokens, int kNumExperts, int kHiddenDim>
void invokeRouterGemmBf16Output(__nv_bfloat16* output, T const* mat_a, T const* mat_b, cudaStream_t stream);

template <int kBegin, int kEnd, int kNumExperts, int kHiddenDim>
struct LoopUnroller {
  static void unroll_float_output(
      int num_tokens, float* output, __nv_bfloat16 const* input, __nv_bfloat16 const* weights, cudaStream_t stream) {
    if (num_tokens == kBegin) {
      invokeRouterGemmFloatOutput<__nv_bfloat16, kBegin, kNumExperts, kHiddenDim>(output, input, weights, stream);
    } else {
      LoopUnroller<kBegin + 1, kEnd, kNumExperts, kHiddenDim>::unroll_float_output(
          num_tokens, output, input, weights, stream);
    }
  }

  static void unroll_bf16_output(
      int num_tokens,
      __nv_bfloat16* output,
      __nv_bfloat16 const* input,
      __nv_bfloat16 const* weights,
      cudaStream_t stream) {
    if (num_tokens == kBegin) {
      invokeRouterGemmBf16Output<__nv_bfloat16, kBegin, kNumExperts, kHiddenDim>(output, input, weights, stream);
    } else {
      LoopUnroller<kBegin + 1, kEnd, kNumExperts, kHiddenDim>::unroll_bf16_output(
          num_tokens, output, input, weights, stream);
    }
  }
};

template <int kEnd, int kNumExperts, int kHiddenDim>
struct LoopUnroller<kEnd, kEnd, kNumExperts, kHiddenDim> {
  static void unroll_float_output(
      int num_tokens, float* output, __nv_bfloat16 const* input, __nv_bfloat16 const* weights, cudaStream_t stream) {
    if (num_tokens == kEnd) {
      invokeRouterGemmFloatOutput<__nv_bfloat16, kEnd, kNumExperts, kHiddenDim>(output, input, weights, stream);
    } else {
      throw std::invalid_argument("Invalid num_tokens, only supports 1 to 16");
    }
  }

  static void unroll_bf16_output(
      int num_tokens,
      __nv_bfloat16* output,
      __nv_bfloat16 const* input,
      __nv_bfloat16 const* weights,
      cudaStream_t stream) {
    if (num_tokens == kEnd) {
      invokeRouterGemmBf16Output<__nv_bfloat16, kEnd, kNumExperts, kHiddenDim>(output, input, weights, stream);
    } else {
      throw std::invalid_argument("Invalid num_tokens, only supports 1 to 16");
    }
  }
};

void dsv3_router_gemm(
    torch::Tensor& output,       // [num_tokens, num_experts]
    const torch::Tensor& mat_a,  // [num_tokens, hidden_dim]
    const torch::Tensor& mat_b   // [num_experts, hidden_dim]
) {
  TORCH_CHECK(output.dim() == 2 && mat_a.dim() == 2 && mat_b.dim() == 2);
  const int num_tokens = mat_a.size(0);
  constexpr int num_experts = 256;
  constexpr int hidden_dim = 7168;
  TORCH_CHECK(mat_a.size(1) == mat_b.size(1), "mat_a and mat_b must have the same hidden_dim");
  TORCH_CHECK(mat_a.size(1) == hidden_dim, "currently hidden_dim only supports 7168");
  TORCH_CHECK(mat_b.size(0) == num_experts, "currently num_experts only supports 256");
  TORCH_CHECK(
      num_tokens >= 1 && num_tokens <= 16,
      "currently num_tokens must be less than or equal to 16 for router_gemm");
  TORCH_CHECK(mat_a.dtype() == torch::kBFloat16, "mat_a must be bf16");
  TORCH_CHECK(mat_b.dtype() == torch::kBFloat16, "mat_b must be bf16");
  TORCH_CHECK(
      output.dtype() == torch::kFloat32 || output.dtype() == torch::kBFloat16,
      "output must be float32 or bf16");

  auto const sm = getSMVersion();
  TORCH_CHECK(sm >= 90, "required CUDA ARCH >= SM_90");

  const cudaStream_t stream = at::cuda::getCurrentCUDAStream();

  if (output.dtype() == torch::kFloat32) {
    LoopUnroller<1, 16, num_experts, hidden_dim>::unroll_float_output(
        num_tokens,
        reinterpret_cast<float*>(output.mutable_data_ptr()),
        reinterpret_cast<__nv_bfloat16 const*>(mat_a.data_ptr()),
        reinterpret_cast<__nv_bfloat16 const*>(mat_b.data_ptr()),
        stream);
  } else if (output.dtype() == torch::kBFloat16) {
    LoopUnroller<1, 16, num_experts, hidden_dim>::unroll_bf16_output(
        num_tokens,
        reinterpret_cast<__nv_bfloat16*>(output.mutable_data_ptr()),
        reinterpret_cast<__nv_bfloat16 const*>(mat_a.data_ptr()),
        reinterpret_cast<__nv_bfloat16 const*>(mat_b.data_ptr()),
        stream);
  }
}
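
Note (not part of the diff): the entry point dispatches on the output dtype and on num_tokens (1 to 16) to a pre-instantiated kernel specialization. A rough Python analogue of that control flow, with illustrative names only; in the CUDA code above the selection is done at compile time by the LoopUnroller recursion.

    import torch

    def dispatch_router_gemm(output, mat_a, mat_b):
        num_tokens = mat_a.size(0)
        if not (1 <= num_tokens <= 16):
            raise ValueError("Invalid num_tokens, only supports 1 to 16")
        if output.dtype == torch.float32:
            name = f"invokeRouterGemmFloatOutput<__nv_bfloat16, {num_tokens}, 256, 7168>"
        elif output.dtype == torch.bfloat16:
            name = f"invokeRouterGemmBf16Output<__nv_bfloat16, {num_tokens}, 256, 7168>"
        else:
            raise TypeError("output must be float32 or bf16")
        return name  # the CUDA entry launches this specialization on the current stream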
sgl-kernel/csrc/gemm/dsv3_router_gemm.cu → sgl-kernel/csrc/gemm/dsv3_router_gemm_float_out.cu
@@ -46,7 +46,7 @@ __device__ __forceinline__ void bf16_uint4_to_float8(uint4 const& vec, float* ds
 }
 
 template <typename T, int kBlockSize, int VPT, int kNumTokens, int kNumExperts, int kHiddenDim>
-__global__ __launch_bounds__(128, 1) void router_gemm_kernel(float* out, T const* mat_a, T const* mat_b) {
+__global__ __launch_bounds__(128, 1) void router_gemm_kernel_float_output(float* out, T const* mat_a, T const* mat_b) {
   // Each block handles one expert column
   int const n_idx = blockIdx.x;
   int const tid = threadIdx.x;
@@ -163,7 +163,7 @@ __global__ __launch_bounds__(128, 1) void router_gemm_kernel(float* out, T const
 }
 
 template <typename T, int kNumTokens, int kNumExperts, int kHiddenDim>
-void invokeRouterGemm(float* output, T const* mat_a, T const* mat_b, cudaStream_t stream) {
+void invokeRouterGemmFloatOutput(float* output, T const* mat_a, T const* mat_b, cudaStream_t stream) {
   constexpr int VPT = 16 / sizeof(T);
   constexpr int kBlockSize = 128;
   cudaLaunchConfig_t config;
@@ -177,110 +177,57 @@ void invokeRouterGemm(float* output, T const* mat_a, T const* mat_b, cudaStream_
   config.numAttrs = 1;
   config.attrs = attrs;
   cudaLaunchKernelEx(
-      &config, router_gemm_kernel<T, kBlockSize, VPT, kNumTokens, kNumExperts, kHiddenDim>, output, mat_a, mat_b);
+      &config, router_gemm_kernel_float_output<T, kBlockSize, VPT, kNumTokens, kNumExperts, kHiddenDim>, output, mat_a, mat_b);
 }
 
-template void invokeRouterGemm<__nv_bfloat16, 1, 256, 7168>(float*, __nv_bfloat16 const*, __nv_bfloat16 const*, cudaStream_t);
+template void invokeRouterGemmFloatOutput<__nv_bfloat16, 1, 256, 7168>(float*, __nv_bfloat16 const*, __nv_bfloat16 const*, cudaStream_t);
-template void invokeRouterGemm<__nv_bfloat16, 2, 256, 7168>(float*, __nv_bfloat16 const*, __nv_bfloat16 const*, cudaStream_t);
+template void invokeRouterGemmFloatOutput<__nv_bfloat16, 2, 256, 7168>(float*, __nv_bfloat16 const*, __nv_bfloat16 const*, cudaStream_t);
-template void invokeRouterGemm<__nv_bfloat16, 3, 256, 7168>(float*, __nv_bfloat16 const*, __nv_bfloat16 const*, cudaStream_t);
+template void invokeRouterGemmFloatOutput<__nv_bfloat16, 3, 256, 7168>(float*, __nv_bfloat16 const*, __nv_bfloat16 const*, cudaStream_t);
-template void invokeRouterGemm<__nv_bfloat16, 4, 256, 7168>(float*, __nv_bfloat16 const*, __nv_bfloat16 const*, cudaStream_t);
+template void invokeRouterGemmFloatOutput<__nv_bfloat16, 4, 256, 7168>(float*, __nv_bfloat16 const*, __nv_bfloat16 const*, cudaStream_t);
-template void invokeRouterGemm<__nv_bfloat16, 5, 256, 7168>(float*, __nv_bfloat16 const*, __nv_bfloat16 const*, cudaStream_t);
+template void invokeRouterGemmFloatOutput<__nv_bfloat16, 5, 256, 7168>(float*, __nv_bfloat16 const*, __nv_bfloat16 const*, cudaStream_t);
-template void invokeRouterGemm<__nv_bfloat16, 6, 256, 7168>(float*, __nv_bfloat16 const*, __nv_bfloat16 const*, cudaStream_t);
+template void invokeRouterGemmFloatOutput<__nv_bfloat16, 6, 256, 7168>(float*, __nv_bfloat16 const*, __nv_bfloat16 const*, cudaStream_t);
-template void invokeRouterGemm<__nv_bfloat16, 7, 256, 7168>(float*, __nv_bfloat16 const*, __nv_bfloat16 const*, cudaStream_t);
+template void invokeRouterGemmFloatOutput<__nv_bfloat16, 7, 256, 7168>(float*, __nv_bfloat16 const*, __nv_bfloat16 const*, cudaStream_t);
-template void invokeRouterGemm<__nv_bfloat16, 8, 256, 7168>(float*, __nv_bfloat16 const*, __nv_bfloat16 const*, cudaStream_t);
+template void invokeRouterGemmFloatOutput<__nv_bfloat16, 8, 256, 7168>(float*, __nv_bfloat16 const*, __nv_bfloat16 const*, cudaStream_t);
-template void invokeRouterGemm<__nv_bfloat16, 9, 256, 7168>(float*, __nv_bfloat16 const*, __nv_bfloat16 const*, cudaStream_t);
+template void invokeRouterGemmFloatOutput<__nv_bfloat16, 9, 256, 7168>(float*, __nv_bfloat16 const*, __nv_bfloat16 const*, cudaStream_t);
-template void invokeRouterGemm<__nv_bfloat16, 10, 256, 7168>(float*, __nv_bfloat16 const*, __nv_bfloat16 const*, cudaStream_t);
+template void invokeRouterGemmFloatOutput<__nv_bfloat16, 10, 256, 7168>(float*, __nv_bfloat16 const*, __nv_bfloat16 const*, cudaStream_t);
-template void invokeRouterGemm<__nv_bfloat16, 11, 256, 7168>(float*, __nv_bfloat16 const*, __nv_bfloat16 const*, cudaStream_t);
+template void invokeRouterGemmFloatOutput<__nv_bfloat16, 11, 256, 7168>(float*, __nv_bfloat16 const*, __nv_bfloat16 const*, cudaStream_t);
-template void invokeRouterGemm<__nv_bfloat16, 12, 256, 7168>(float*, __nv_bfloat16 const*, __nv_bfloat16 const*, cudaStream_t);
+template void invokeRouterGemmFloatOutput<__nv_bfloat16, 12, 256, 7168>(float*, __nv_bfloat16 const*, __nv_bfloat16 const*, cudaStream_t);
-template void invokeRouterGemm<__nv_bfloat16, 13, 256, 7168>(float*, __nv_bfloat16 const*, __nv_bfloat16 const*, cudaStream_t);
+template void invokeRouterGemmFloatOutput<__nv_bfloat16, 13, 256, 7168>(float*, __nv_bfloat16 const*, __nv_bfloat16 const*, cudaStream_t);
-template void invokeRouterGemm<__nv_bfloat16, 14, 256, 7168>(float*, __nv_bfloat16 const*, __nv_bfloat16 const*, cudaStream_t);
+template void invokeRouterGemmFloatOutput<__nv_bfloat16, 14, 256, 7168>(float*, __nv_bfloat16 const*, __nv_bfloat16 const*, cudaStream_t);
-template void invokeRouterGemm<__nv_bfloat16, 15, 256, 7168>(float*, __nv_bfloat16 const*, __nv_bfloat16 const*, cudaStream_t);
+template void invokeRouterGemmFloatOutput<__nv_bfloat16, 15, 256, 7168>(float*, __nv_bfloat16 const*, __nv_bfloat16 const*, cudaStream_t);
-template void invokeRouterGemm<__nv_bfloat16, 16, 256, 7168>(float*, __nv_bfloat16 const*, __nv_bfloat16 const*, cudaStream_t);
+template void invokeRouterGemmFloatOutput<__nv_bfloat16, 16, 256, 7168>(float*, __nv_bfloat16 const*, __nv_bfloat16 const*, cudaStream_t);
-
-template <int kBegin, int kEnd, int kNumExperts, int kHiddenDim>
-struct LoopUnroller {
-  static void unroll(
-      int num_tokens, float* output, __nv_bfloat16 const* input, __nv_bfloat16 const* weights, cudaStream_t stream) {
-    if (num_tokens == kBegin) {
-      invokeRouterGemm<__nv_bfloat16, kBegin, kNumExperts, kHiddenDim>(output, input, weights, stream);
-    } else {
-      LoopUnroller<kBegin + 1, kEnd, kNumExperts, kHiddenDim>::unroll(num_tokens, output, input, weights, stream);
-    }
-  }
-};
-
-template <int kEnd, int kNumExperts, int kHiddenDim>
-struct LoopUnroller<kEnd, kEnd, kNumExperts, kHiddenDim> {
-  static void unroll(
-      int num_tokens, float* output, __nv_bfloat16 const* input, __nv_bfloat16 const* weights, cudaStream_t stream) {
-    if (num_tokens == kEnd) {
-      invokeRouterGemm<__nv_bfloat16, kEnd, kNumExperts, kHiddenDim>(output, input, weights, stream);
-    } else {
-      throw std::invalid_argument("Invalid num_tokens, only supports 1 to 16");
-    }
-  }
-};
-
-void dsv3_router_gemm(
-    torch::Tensor& output,       // [num_tokens, num_experts]
-    const torch::Tensor& mat_a,  // [num_tokens, hidden_dim]
-    const torch::Tensor& mat_b   // [num_experts, hidden_dim]
-) {
-  TORCH_CHECK(output.dim() == 2 && mat_a.dim() == 2 && mat_b.dim() == 2);
-  const int num_tokens = mat_a.size(0);
-  constexpr int num_experts = 256;
-  constexpr int hidden_dim = 7168;
-  TORCH_CHECK(mat_a.size(1) == mat_b.size(1), "mat_a and mat_b must have the same hidden_dim");
-  TORCH_CHECK(mat_a.size(1) == hidden_dim, "currently hidden_dim only supports 7168");
-  TORCH_CHECK(mat_b.size(0) == num_experts, "currently num_experts only supports 256");
-  TORCH_CHECK(
-      num_tokens >= 1 && num_tokens <= 16,
-      "currently num_tokens must be less than or equal to 16 for router_gemm");
-  TORCH_CHECK(mat_a.dtype() == torch::kBFloat16, "mat_a must be bf16");
-  TORCH_CHECK(mat_b.dtype() == torch::kBFloat16, "mat_b must be bf16");
-  TORCH_CHECK(output.dtype() == torch::kFloat32, "output must be float32");
-
-  auto const sm = getSMVersion();
-  TORCH_CHECK(sm >= 90, "required CUDA ARCH >= SM_90");
-
-  const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
-
-  LoopUnroller<1, 16, num_experts, hidden_dim>::unroll(
-      num_tokens,
-      reinterpret_cast<float*>(output.mutable_data_ptr()),
-      reinterpret_cast<__nv_bfloat16 const*>(mat_a.data_ptr()),
-      reinterpret_cast<__nv_bfloat16 const*>(mat_b.data_ptr()),
-      stream);
-}
sgl-kernel/python/sgl_kernel/gemm.py
@@ -262,12 +262,13 @@ def qserve_w4a8_per_group_gemm(
 def dsv3_router_gemm(
     hidden_states: torch.Tensor,
     router_weights: torch.Tensor,
+    out_dtype: torch.dtype = torch.bfloat16,
 ) -> torch.Tensor:
     output = torch.empty(
         hidden_states.shape[0],
         router_weights.shape[0],
         device=hidden_states.device,
-        dtype=torch.float32,
+        dtype=out_dtype,
     )
     torch.ops.sgl_kernel.dsv3_router_gemm(
         output,
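
Note (not part of the diff): the Python wrapper only allocates the output tensor in the requested dtype; the registered op does the dispatch. Calling the op directly looks roughly like this sketch, assuming the sgl_kernel extension from this commit is importable.

    import torch
    import sgl_kernel  # noqa: F401 -- registers torch.ops.sgl_kernel.dsv3_router_gemm

    mat_a = torch.randn(2, 7168, dtype=torch.bfloat16, device="cuda")
    mat_b = torch.randn(256, 7168, dtype=torch.bfloat16, device="cuda")
    out = torch.empty(2, 256, dtype=torch.bfloat16, device="cuda")  # bf16 or float32
    torch.ops.sgl_kernel.dsv3_router_gemm(out, mat_a, mat_b)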
sgl-kernel/tests/test_dsv3_router_gemm.py
@@ -15,17 +15,20 @@ def test_dsv3_router_gemm(num_tokens):
     mat_b = torch.randn(
         (num_experts, hidden_dim), dtype=torch.bfloat16, device="cuda"
     ).contiguous()
-    output = torch.empty(
-        (num_tokens, num_experts), dtype=torch.float32, device="cuda"
-    ).contiguous()
-    ref = F.linear(mat_a, mat_b).to(torch.float32)
+    bf16_ref = F.linear(mat_a, mat_b)
+    float_ref = bf16_ref.to(torch.float32)
 
-    output = dsv3_router_gemm(mat_a, mat_b)
+    bf16_output = dsv3_router_gemm(mat_a, mat_b, out_dtype=torch.bfloat16)
+    float_output = dsv3_router_gemm(mat_a, mat_b, out_dtype=torch.float32)
 
+    assert torch.allclose(
+        bf16_output, bf16_ref, rtol=1e-2, atol=1e-3
+    ), "Router GEMM output in bf16 dtype mismatch with torch.nn.functional.linear reference"
     assert torch.allclose(
-        output, ref, rtol=1e-2, atol=1e-3
-    ), "Router GEMM output mismatch with torch.nn.functional.linear reference"
+        float_output, float_ref, rtol=1e-2, atol=1e-3
+    ), "Router GEMM output in float32 dtype mismatch with torch.nn.functional.linear reference"
 
 
 if __name__ == "__main__":
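
Note (not part of the diff): a quick manual check mirroring the updated test, as a sketch; it assumes an SM90+ GPU and the sgl-kernel build from this commit.

    import torch
    import torch.nn.functional as F
    from sgl_kernel import dsv3_router_gemm

    for num_tokens in range(1, 17):
        mat_a = torch.randn(num_tokens, 7168, dtype=torch.bfloat16, device="cuda")
        mat_b = torch.randn(256, 7168, dtype=torch.bfloat16, device="cuda")
        bf16_ref = F.linear(mat_a, mat_b)
        assert torch.allclose(
            dsv3_router_gemm(mat_a, mat_b, out_dtype=torch.bfloat16), bf16_ref, rtol=1e-2, atol=1e-3
        )
        assert torch.allclose(
            dsv3_router_gemm(mat_a, mat_b, out_dtype=torch.float32), bf16_ref.float(), rtol=1e-2, atol=1e-3
        )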