sglang / Commits / 04b35190

Commit 04b35190 (unverified), authored Jun 29, 2025 by Ke Bao, committed via GitHub on Jun 29, 2025

Add dsv3 fused a gemm to sgl-kernel (#7630)

Parent: 071a1f51
Showing 9 changed files with 800 additions and 0 deletions (+800, -0)
Files changed:

  sgl-kernel/CMakeLists.txt                          +1    -0
  sgl-kernel/benchmark/bench_dsv3_fused_a_gemm.py    +57   -0
  sgl-kernel/csrc/common_extension.cc                +3    -0
  sgl-kernel/csrc/gemm/dsv3_fused_a_gemm.cu          +672  -0
  sgl-kernel/include/sgl_kernel_ops.h                +2    -0
  sgl-kernel/include/utils.h                         +17   -0
  sgl-kernel/python/sgl_kernel/__init__.py           +1    -0
  sgl-kernel/python/sgl_kernel/gemm.py               +15   -0
  sgl-kernel/tests/test_dsv3_fused_a_gemm.py         +32   -0
sgl-kernel/CMakeLists.txt
View file @ 04b35190
...
...
@@ -221,6 +221,7 @@ set(SOURCES
"csrc/elementwise/rope.cu"
"csrc/gemm/awq_kernel.cu"
"csrc/gemm/bmm_fp8.cu"
"csrc/gemm/dsv3_fused_a_gemm.cu"
"csrc/gemm/fp8_blockwise_gemm_kernel.cu"
"csrc/gemm/fp8_gemm_kernel.cu"
"csrc/gemm/int8_gemm_kernel.cu"
...
...
sgl-kernel/benchmark/bench_dsv3_fused_a_gemm.py
0 → 100644
View file @ 04b35190
import argparse

import torch
import torch.nn.functional as F
import triton
import triton.testing
from sgl_kernel import dsv3_fused_a_gemm


@triton.testing.perf_report(
    triton.testing.Benchmark(
        x_names=["num_tokens"],
        x_vals=[i + 1 for i in range(16)],
        x_log=False,
        line_arg="impl",
        line_vals=["torch", "sgl-kernel"],
        line_names=["torch (bf16)", "dsv3_fused_a_gemm"],
        styles=[("blue", "-"), ("orange", "-")],
        ylabel="TFLOPs",
        plot_name="bf16 dsv3 fused a GEMM throughput",
        args={},
    )
)
def benchmark(num_tokens, impl):
    kHdIn = 7168
    kHdOut = 2112
    M, K, N = num_tokens, kHdIn, kHdOut

    mat_a = torch.randn((M, K), dtype=torch.bfloat16, device="cuda").contiguous()
    mat_b = torch.randn((N, K), dtype=torch.bfloat16, device="cuda").transpose(0, 1)

    quantiles = [0.5, 0.2, 0.8]

    if impl == "torch":

        def runner():
            F.linear(mat_a, mat_b.T)

    elif impl == "sgl-kernel":

        def runner():
            dsv3_fused_a_gemm(mat_a, mat_b)

    ms, min_ms, max_ms = triton.testing.do_bench(runner, quantiles=quantiles)

    def tflops(t_ms):
        flops = 2 * M * K * N
        return flops / (t_ms * 1e-3) / 1e12

    return tflops(ms), tflops(max_ms), tflops(min_ms)


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    args = parser.parse_args()
    benchmark.run(print_data=True, show_plots=True, save_path="bench_dsv3_gemm")
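The script is run directly (e.g. `python bench_dsv3_fused_a_gemm.py`) and sweeps num_tokens from 1 to 16, reporting throughput from the usual 2*M*K*N flop count for a GEMM. As a quick sanity check of that accounting, here is a small sketch, not part of the commit, with assumed example latencies:

    # Hypothetical check of the benchmark's flop-to-TFLOPs conversion (illustrative only).
    M, K, N = 16, 7168, 2112              # num_tokens, kHdIn, kHdOut
    flops = 2 * M * K * N                 # same formula as tflops() above
    for t_ms in (0.010, 0.020, 0.050):    # assumed latencies in milliseconds
        print(f"{t_ms:.3f} ms -> {flops / (t_ms * 1e-3) / 1e12:.2f} TFLOPs")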
sgl-kernel/csrc/common_extension.cc
View file @ 04b35190
...
...
@@ -141,6 +141,9 @@ TORCH_LIBRARY_FRAGMENT(sgl_kernel, m) {
" Tensor! output_scale, Tensor! input_scale) -> ()"
);
m
.
impl
(
"scaled_fp4_quant"
,
torch
::
kCUDA
,
&
scaled_fp4_quant
);
m
.
def
(
"dsv3_fused_a_gemm(Tensor! output, Tensor mat_a, Tensor mat_b) -> ()"
);
m
.
impl
(
"dsv3_fused_a_gemm"
,
torch
::
kCUDA
,
&
dsv3_fused_a_gemm
);
// Compute NVFP4 experts quantization.
m
.
def
(
"scaled_fp4_experts_quant(Tensor! output, Tensor! output_scale,"
...
...
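The `m.def`/`m.impl` pair registers the op with the PyTorch dispatcher under the `sgl_kernel` namespace, so it is reachable both through the Python wrapper added below and directly via `torch.ops`. A minimal sketch for confirming the registration after installing a CUDA build of sgl-kernel (not part of the commit):

    import torch
    import sgl_kernel  # noqa: F401  -- importing the package loads the extension and registers the ops

    # Prints the registered op packet if the schema above was loaded successfully.
    print(torch.ops.sgl_kernel.dsv3_fused_a_gemm)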
sgl-kernel/csrc/gemm/dsv3_fused_a_gemm.cu
0 → 100644
View file @ 04b35190
/*
* Adapted from
* https://github.com/NVIDIA/TensorRT-LLM/blob/619709fc33bd5dc268f19d6a741fe7ed51c0f8f5/cpp/tensorrt_llm/kernels/dsv3MinLatencyKernels/dsv3FusedAGemm.cu
*
* Copyright (c) 2019-2024, NVIDIA CORPORATION. All rights reserved.
* Copyright (c) 2021, NAVER Corp. Authored by CLOVA.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include <ATen/ATen.h>
#include <ATen/cuda/CUDAContext.h>
#include <cuda_bf16.h>
#include <cuda_runtime.h>
#include "utils.h"
using bf16_t = __nv_bfloat16;

__device__ void hmma_16_8_16_f32acc_bf16ab(
    float (&d_reg)[4], const bf16_t (&a_reg)[8], const bf16_t (&b_reg)[4], float const (&c_reg)[4]) {
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 900
  uint32_t a0 = *reinterpret_cast<uint32_t const*>(a_reg + 0);
  uint32_t a1 = *reinterpret_cast<uint32_t const*>(a_reg + 2);
  uint32_t a2 = *reinterpret_cast<uint32_t const*>(a_reg + 4);
  uint32_t a3 = *reinterpret_cast<uint32_t const*>(a_reg + 6);
  uint32_t b0 = *reinterpret_cast<uint32_t const*>(b_reg + 0);
  uint32_t b1 = *reinterpret_cast<uint32_t const*>(b_reg + 2);
  asm volatile(
      "mma.sync.aligned.m16n8k16.row.col.f32.bf16.bf16.f32 "
      "{%0, %1, %2, %3},"
      "{%4, %5, %6, %7},"
      "{%8, %9},"
      "{%10, %11, %12, %13};\n"
      : "=f"(d_reg[0]), "=f"(d_reg[1]), "=f"(d_reg[2]), "=f"(d_reg[3])
      : "r"(a0), "r"(a1), "r"(a2), "r"(a3), "r"(b0), "r"(b1),
        "f"(d_reg[0]), "f"(d_reg[1]), "f"(d_reg[2]), "f"(d_reg[3]));
#endif
}
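The inline PTX above issues one m16n8k16 tensor-core MMA: a 16x16 bf16 A tile times a 16x8 bf16 B tile accumulated into a 16x8 fp32 tile, distributed across the 32 lanes of a warp, which is where the 8 bf16 A elements, 4 bf16 B elements and 4 fp32 accumulators per lane come from. A plain torch sketch of what one such instruction computes (illustrative only, not tied to the per-lane register layout):

    import torch

    # One m16n8k16 MMA: D = A @ B + C with bf16 inputs and fp32 accumulation.
    A = torch.randn(16, 16, dtype=torch.bfloat16)
    B = torch.randn(16, 8, dtype=torch.bfloat16)
    C = torch.zeros(16, 8, dtype=torch.float32)
    D = A.float() @ B.float() + C

    # Per-lane register counts implied by the fragment sizes (32 lanes per warp):
    print(16 * 16 // 32, 16 * 8 // 32, 16 * 8 // 32)  # 8 A regs, 4 B regs, 4 fp32 accumulators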
extern "C" {
__device__ uint32_t __nvvm_get_smem_pointer(void*);
}

__device__ void ldgsts_128(void const* gPtr, void* sPtr, uint32_t pred) {
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 900
  if (pred) {
    uint32_t smemPtrAsUint32 = __nvvm_get_smem_pointer(sPtr);
    asm volatile("cp.async.cg.shared.global.L2::128B [%0], [%1], %2;\n" ::"r"(smemPtrAsUint32), "l"(gPtr), "n"(16));
  }
#endif
}

__device__ void ldsm_x4(void* smem_ptr, uint32_t* reg_ptr) {
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 900
  asm volatile(
      "ldmatrix.sync.aligned.x4.m8n8.shared.b16 {%0, %1, %2, %3}, [%4];\n"
      : "=r"(reg_ptr[0]), "=r"(reg_ptr[1]), "=r"(reg_ptr[2]), "=r"(reg_ptr[3])
      : "r"(__nvvm_get_smem_pointer(smem_ptr)));
#endif
}

template <class Type>
__device__ int apply_swizzle_343_on_elem_row_col(int row_idx_, int col_idx_) {
  uint32_t row_idx = *reinterpret_cast<uint32_t*>(&row_idx_);
  uint32_t col_idx = *reinterpret_cast<uint32_t*>(&col_idx_);
  row_idx = row_idx % 8;
  row_idx = row_idx * (16 / sizeof(Type));
  col_idx = col_idx ^ row_idx;
  return *reinterpret_cast<int*>(&col_idx);
}
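For bf16 the factor 16/sizeof(Type) is 8, so the swizzle XORs the element column index with (row % 8) * 8: the same 16-byte vector position, read from eight consecutive rows, is redirected to eight distinct 16-byte chunks of the 128-byte span, which is what keeps the ldmatrix/cp.async accesses out of each other's shared-memory banks. A small sketch mirroring that arithmetic (bf16 only, not part of the commit):

    def swizzle_343(row_idx: int, col_idx: int, elem_bytes: int = 2) -> int:
        # Mirrors apply_swizzle_343_on_elem_row_col for 2-byte elements.
        return col_idx ^ ((row_idx % 8) * (16 // elem_bytes))

    # The same 8-element (16-byte) column chunk, accessed from 8 consecutive rows,
    # maps to 8 distinct chunks within the 128-byte span.
    chunk = 3
    chunks = [swizzle_343(r, chunk * 8) // 8 for r in range(8)]
    print(chunks)                        # [3, 2, 1, 0, 7, 6, 5, 4]
    assert sorted(chunks) == list(range(8))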
__device__ void initialize_barrier(
    uint64_t* smem_barrier,  // 64-bit user-managed barrier in smem
    int thread_count = 1)    // Thread count expected to arrive/wait on this barrier
{
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 900
  uint32_t smem_int_ptr = __nvvm_get_smem_pointer(smem_barrier);
  asm volatile("mbarrier.init.shared::cta.b64 [%0], %1;\n" ::"r"(smem_int_ptr), "r"(thread_count));
#endif
}

// Barrier wait
__device__ void wait_barrier(
    uint64_t* smem_barrier,  // 64-bit user-managed barrier in smem
    int phase_bit)           // Current phase bit the barrier is waiting to flip
{
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 900
  uint32_t smem_int_ptr = __nvvm_get_smem_pointer(smem_barrier);
  asm volatile(
      "{\n"
      ".reg .pred P1;\n"
      "LAB_WAIT:\n"
      "mbarrier.try_wait.parity.shared::cta.b64 P1, [%0], %1;\n"
      "@P1 bra DONE;\n"
      "bra LAB_WAIT;\n"
      "DONE:\n"
      "}\n" ::"r"(smem_int_ptr),
      "r"(phase_bit));
#endif
}

__device__ bool try_wait_barrier(uint64_t* smem_ptr, int phase_bit) {
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 900
  uint32_t wait_complete;
  uint32_t smem_int_ptr = __nvvm_get_smem_pointer(smem_ptr);
  asm volatile(
      "{\n\t"
      ".reg .pred P1;\n\t"
      "mbarrier.try_wait.parity.shared::cta.b64 P1, [%1], %2;\n\t"
      "selp.b32 %0, 1, 0, P1;\n\t"
      "}"
      : "=r"(wait_complete)
      : "r"(smem_int_ptr), "r"(phase_bit));
  return static_cast<bool>(wait_complete);
#endif
}

// Barrier arrive
__device__ void arrive_barrier(uint64_t* smem_barrier)  // 64-bit user-managed barrier in smem
{
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 900
  uint32_t smem_int_ptr = __nvvm_get_smem_pointer(smem_barrier);
  asm volatile(
      "{\n"
      ".reg .b64 state;\n"
      "mbarrier.arrive.shared::cta.b64 state, [%0];\n"
      "}\n" ::"r"(smem_int_ptr));
#endif
}

__device__ void ldgsts_arrive(uint64_t* smem_barrier) {
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 900
  uint32_t smem_int_ptr = __nvvm_get_smem_pointer(smem_barrier);
  asm volatile("cp.async.mbarrier.arrive.noinc.shared.b64 [%0];" : : "r"(smem_int_ptr));
#endif
}
template <int gemm_k, int tile_m, int tile_k, int stage_cnt>
struct GmemLoaderA {
  static constexpr int elem_bytes = 2;
  static constexpr int vec_bytes = 16;
  static constexpr int vec_elems = vec_bytes / elem_bytes;
  static constexpr int thread_cnt = 64;
  static_assert((tile_m * tile_k) % (vec_elems * thread_cnt) == 0);
  static constexpr int a_inst_cnt_per_iter = (tile_m * tile_k) / (vec_elems * thread_cnt);
  static_assert(gemm_k % tile_k == 0);
  static constexpr int k_iter_cnt = gemm_k / tile_k;
  // Extra params to keep the order of k reduction...
  static constexpr int mma_warp_cnt = 4;
  static constexpr int per_mma_warp_k = tile_k / mma_warp_cnt;
  static constexpr int k_each_chunk = gemm_k / mma_warp_cnt;

 private:
  __device__ int k_project(int tile_k_idx) {
    return (tile_k_idx / per_mma_warp_k * k_each_chunk) + (tile_k_idx % per_mma_warp_k);
  }

 public:
  __device__ GmemLoaderA(bf16_t const* gmem_a_local_, bf16_t* smem_a_, uint64_t* smem_barrier_)
      : gmem_a(gmem_a_local_), smem_a(smem_a_), smem_barrier(smem_barrier_), local_tid(threadIdx.x % thread_cnt) {}

  __device__ void prepare() {
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 900
    // swizzle, that's what we want.
#pragma unroll
    for (int i = 0; i < a_inst_cnt_per_iter; i++) {
      int linear_idx = local_tid * vec_elems + i * thread_cnt * vec_elems;
      int m_idx = linear_idx / tile_k;
      int k_idx = linear_idx % tile_k;
      k_idx = apply_swizzle_343_on_elem_row_col<bf16_t>(m_idx, k_idx);
      a_smem_offsets[i] = m_idx * tile_k + k_idx;
    }
#endif
  }

  __device__ void issue_mainloop() {
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 900
#pragma unroll 1
    for (int loop_idx = 0; loop_idx < k_iter_cnt; loop_idx++) {
      if (need_wait) {
        wait_barrier(smem_barrier + 1 + stage_idx * 2, phase_bit);
      }
      int next_stage_idx = stage_idx + 1;
      int next_phase_bit = next_stage_idx == stage_cnt ? phase_bit ^ 1 : phase_bit;
      next_stage_idx = next_stage_idx == stage_cnt ? 0 : next_stage_idx;
      if (loop_idx != k_iter_cnt - 1) {
        need_wait = !try_wait_barrier(smem_barrier + 1 + next_stage_idx * 2, next_phase_bit);
      }
#pragma unroll
      for (int i = 0; i < a_inst_cnt_per_iter; i++) {
        int smem_offset = a_smem_offsets[i];
        bf16_t* smem_ptr_this_iter = smem_a + stage_idx * tile_m * tile_k + smem_offset;
        int linear_idx = local_tid * vec_elems + i * thread_cnt * vec_elems;
        int m_idx = linear_idx / tile_k;
        int k_idx = linear_idx % tile_k;
        int gmem_offset = m_idx * gemm_k + k_project(k_idx);
        bf16_t const* gmem_ptr_this_iter = gmem_a + gmem_offset;
        ldgsts_128(gmem_ptr_this_iter, smem_ptr_this_iter, true);
      }
      ldgsts_arrive(smem_barrier + stage_idx * 2);
      stage_idx = next_stage_idx;
      phase_bit = next_phase_bit;
      gmem_a += per_mma_warp_k;
    }
#endif
  }

  bf16_t const* gmem_a;
  bf16_t* smem_a;
  uint64_t* smem_barrier;
  int local_tid;
  int stage_idx = 0;
  int phase_bit = 1;
  bool need_wait = true;
  // per smem_stage, store with swizzle information
  int a_smem_offsets[a_inst_cnt_per_iter];
};
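GmemLoaderA is the producer half of a multi-stage cp.async pipeline: each stage owns a producer mbarrier (index 2*stage) and a consumer mbarrier (index 2*stage + 1), and, as the indexing above reads, the loader waits on the consumer barrier before overwriting a stage, signals the producer barrier via ldgsts_arrive, and flips its phase bit whenever the stage index wraps. A host-side sketch of just that stage/phase bookkeeping (illustrative only, no synchronization semantics):

    def stage_phase_sequence(k_iter_cnt: int, stage_cnt: int):
        # Mirrors the stage_idx / phase_bit rotation in issue_mainloop().
        stage_idx, phase_bit = 0, 1
        seq = []
        for _ in range(k_iter_cnt):
            seq.append((stage_idx, phase_bit))
            next_stage = stage_idx + 1
            next_phase = phase_bit ^ 1 if next_stage == stage_cnt else phase_bit
            stage_idx = 0 if next_stage == stage_cnt else next_stage
            phase_bit = next_phase
        return seq

    # e.g. 28 k-iterations over a 16-stage ring: the phase bit flips once per wrap.
    print(stage_phase_sequence(28, 16)[:20])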
template <int gemm_k, int tile_n, int tile_k, int stage_cnt>
struct GmemLoaderB {
  static constexpr int elem_bytes = 2;
  static constexpr int vec_bytes = 16;
  static constexpr int vec_elems = vec_bytes / elem_bytes;
  static constexpr int thread_cnt = 64;
  static_assert((tile_n * tile_k) % (vec_elems * thread_cnt) == 0);
  static constexpr int b_inst_cnt_per_iter = (tile_n * tile_k) / (vec_elems * thread_cnt);
  static_assert(gemm_k % tile_k == 0);
  static constexpr int k_iter_cnt = gemm_k / tile_k;
  // Extra params to keep the order of k reduction...
  static constexpr int mma_warp_cnt = 4;
  static constexpr int per_mma_warp_k = tile_k / mma_warp_cnt;
  static constexpr int k_each_chunk = gemm_k / mma_warp_cnt;

 private:
  __device__ int k_project(int tile_k_idx) {
    return (tile_k_idx / per_mma_warp_k * k_each_chunk) + (tile_k_idx % per_mma_warp_k);
  }

 public:
  __device__ GmemLoaderB(bf16_t const* gmem_b_local_, bf16_t* smem_b_, uint64_t* smem_barrier_, int gemm_n_)
      : gmem_b(gmem_b_local_),
        smem_b(smem_b_),
        smem_barrier(smem_barrier_),
        gemm_n(gemm_n_),
        local_tid(threadIdx.x % thread_cnt) {}

  __device__ void prepare() {
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 900
    // swizzle, that's what we want.
#pragma unroll
    for (int i = 0; i < b_inst_cnt_per_iter; i++) {
      int linear_idx = local_tid * vec_elems + i * thread_cnt * vec_elems;
      int n_idx = linear_idx / tile_k;
      int k_idx = linear_idx % tile_k;
      k_idx = apply_swizzle_343_on_elem_row_col<bf16_t>(n_idx, k_idx);
      b_smem_offsets[i] = n_idx * tile_k + k_idx;
      preds[i] = n_idx < gemm_n;
    }
#endif
  }

  __device__ void issue_mainloop() {
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 900
    asm volatile("griddepcontrol.wait;");
#pragma unroll 1
    for (int loop_idx = 0; loop_idx < k_iter_cnt; loop_idx++) {
      if (need_wait) {
        wait_barrier(smem_barrier + 1 + stage_idx * 2, phase_bit);
      }
      int next_stage_idx = stage_idx + 1;
      int next_phase_bit = next_stage_idx == stage_cnt ? phase_bit ^ 1 : phase_bit;
      next_stage_idx = next_stage_idx == stage_cnt ? 0 : next_stage_idx;
      if (loop_idx != k_iter_cnt - 1) {
        need_wait = !try_wait_barrier(smem_barrier + 1 + next_stage_idx * 2, next_phase_bit);
      }
#pragma unroll
      for (int i = 0; i < b_inst_cnt_per_iter; i++) {
        int smem_offset = b_smem_offsets[i];
        bf16_t* smem_ptr_this_iter = smem_b + stage_idx * tile_n * tile_k + smem_offset;
        int linear_idx = local_tid * vec_elems + i * thread_cnt * vec_elems;
        int n_idx = linear_idx / tile_k;
        int k_idx = linear_idx % tile_k;
        int gmem_offset = n_idx * gemm_k + k_project(k_idx);
        bf16_t const* gmem_ptr_this_iter = gmem_b + gmem_offset;
        ldgsts_128(gmem_ptr_this_iter, smem_ptr_this_iter, preds[i]);
      }
      ldgsts_arrive(smem_barrier + stage_idx * 2);
      stage_idx = next_stage_idx;
      phase_bit = next_phase_bit;
      gmem_b += per_mma_warp_k;
    }
#endif
  }

  bf16_t const* gmem_b;
  bf16_t* smem_b;
  uint64_t* smem_barrier;
  int gemm_n;
  int local_tid;
  int stage_idx = 0;
  int phase_bit = 1;
  bool need_wait = true;
  // per smem_stage, store with swizzle information
  int b_smem_offsets[b_inst_cnt_per_iter];
  uint32_t preds[b_inst_cnt_per_iter];
};
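Both loaders route their k indices through `k_project`, which appears to spread one tile_k-wide slab across the four gemm_k/4 chunks reduced by the four compute warps, so each warp always consumes a contiguous K range and the reduction order stays fixed. A quick sketch showing that, combined with the per-iteration pointer advance of per_mma_warp_k, the projection still visits every k index exactly once (mirrors the device arithmetic; illustrative only):

    def k_project(tile_k_idx: int, tile_k: int, gemm_k: int, mma_warp_cnt: int = 4) -> int:
        per_mma_warp_k = tile_k // mma_warp_cnt
        k_each_chunk = gemm_k // mma_warp_cnt
        return (tile_k_idx // per_mma_warp_k) * k_each_chunk + tile_k_idx % per_mma_warp_k

    gemm_k, tile_k = 7168, 256
    covered = set()
    for k_iter in range(gemm_k // tile_k):      # one slab per main-loop iteration
        base = k_iter * (tile_k // 4)           # gmem pointer advances by per_mma_warp_k
        covered.update(base + k_project(i, tile_k, gemm_k) for i in range(tile_k))
    assert covered == set(range(gemm_k))        # every k index is loaded exactly once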
template <int gemm_m, int gemm_k, int tile_m, int tile_n, int tile_k, int stage_cnt>
struct MmaComputer {
  static constexpr int elem_bytes = 2;
  static constexpr int thread_cnt = 128;
  static_assert(gemm_k % tile_k == 0);
  static_assert(tile_k % (thread_cnt / 32) == 0);
  static constexpr int per_warp_tile_k = tile_k / (thread_cnt / 32);
  static constexpr int k_iter_cnt = gemm_k / tile_k;
  static constexpr int k_phase_cnt = per_warp_tile_k / 16;
  static constexpr int m_iter_cnt = (tile_m + 15) / 16;
  static constexpr int n_iter_cnt = (tile_n + 7) / 8;
  // Possible to have non-1 n_iter_cnt for ab_swap m16 case.
  static_assert(m_iter_cnt == 1);
  static_assert(n_iter_cnt == 1 || n_iter_cnt == 2);

  __device__ MmaComputer(
      bf16_t* gmem_c_local_, bf16_t* smem_a_, bf16_t* smem_b_, uint64_t* smem_barrier_, int warp_idx_, int gemm_n_)
      : gmem_c(gmem_c_local_),
        smem_a(smem_a_),
        smem_b(smem_b_),
        smem_barrier(smem_barrier_),
        warp_idx(warp_idx_ - (thread_cnt / 32)),
        gemm_n(gemm_n_) {}

 private:
  __device__ constexpr int internal_b_atom_func(int tid) {
    if constexpr (tile_n < 8) {
      return (tid % tile_n) + ((tid % 8) / tile_n * 0) + tid / 8 * 8 * tile_n;
    } else {
      return (tid % 8) + ((tid % 32) / 8 * (tile_n * 8));
    }
  }

 public:
  __device__ void prepare() {
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 900
#pragma unroll
    for (int i = 0; i < k_phase_cnt; i++) {
      int linear_idx = (lane_idx % 16) + (lane_idx / 16) * 128 + i * 256;
      int m_idx = linear_idx % tile_m;
      int k_idx = linear_idx / tile_m + warp_k_offset_in_tile_k;
      k_idx = apply_swizzle_343_on_elem_row_col<bf16_t>(m_idx, k_idx);
      a_smem_offsets[0][i] = m_idx * tile_k + k_idx;
    }
#pragma unroll
    for (int n_iter_idx = 0; n_iter_idx < n_iter_cnt; n_iter_idx++) {
#pragma unroll
      for (int i = 0; i < k_phase_cnt; i += 2) {
        // Special i+=2 for B.
        int linear_idx = internal_b_atom_func(lane_idx) + i * tile_n * 16 + n_iter_idx * 8;
        int n_idx = linear_idx % tile_n;
        int k_idx = linear_idx / tile_n + warp_k_offset_in_tile_k;
        k_idx = apply_swizzle_343_on_elem_row_col<bf16_t>(n_idx, k_idx);
        b_smem_offsets[n_iter_idx][i] = n_idx * tile_k + k_idx;
      }
    }
#endif
  }

  __device__ void issue_mainloop() {
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 900
#pragma unroll 1
    for (int loop_idx = 0; loop_idx < k_iter_cnt; loop_idx++) {
      wait_barrier(smem_barrier + 0 + stage_idx * 2, phase_bit);
#pragma unroll
      for (int i = 0; i < k_phase_cnt; i++) {
        int smem_offset = a_smem_offsets[0][i];
        bf16_t* smem_ptr_this_iter = smem_a + stage_idx * tile_m * tile_k + smem_offset;
        ldsm_x4(smem_ptr_this_iter, reinterpret_cast<uint32_t*>(a_reg[0][i]));
      }
#pragma unroll
      for (int n_iter_idx = 0; n_iter_idx < n_iter_cnt; n_iter_idx++) {
#pragma unroll
        for (int i = 0; i < k_phase_cnt; i += 2) {
          int smem_offset = b_smem_offsets[n_iter_idx][i];
          bf16_t* smem_ptr_this_iter = smem_b + stage_idx * tile_n * tile_k + smem_offset;
          ldsm_x4(smem_ptr_this_iter, reinterpret_cast<uint32_t*>(b_reg[n_iter_idx][i]));
        }
      }
#pragma unroll
      for (int k_iter_idx = 0; k_iter_idx < k_phase_cnt; k_iter_idx++) {
#pragma unroll
        for (int n_iter_idx = 0; n_iter_idx < n_iter_cnt; n_iter_idx++) {
          hmma_16_8_16_f32acc_bf16ab(
              acc_reg[0][n_iter_idx], a_reg[0][k_iter_idx], b_reg[n_iter_idx][k_iter_idx], acc_reg[0][n_iter_idx]);
        }
      }
      ::arrive_barrier(smem_barrier + 1 + stage_idx * 2);
      stage_idx += 1;
      phase_bit = stage_idx == stage_cnt ? phase_bit ^ 1 : phase_bit;
      stage_idx = stage_idx == stage_cnt ? 0 : stage_idx;
    }
#endif
  }

  __device__ void epi() {
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 900
    asm volatile("bar.sync %0, %1;" : : "r"(1), "r"(thread_cnt));
    // reorganize the acc_reg
    constexpr int thread_m = 2;
    constexpr int thread_n = 2 * n_iter_cnt;
    constexpr int cta_mma_n = n_iter_cnt * 8;
    float acc_reg_reorg[thread_m][thread_n];
    for (int i = 0; i < thread_m; i++) {
      for (int j = 0; j < thread_n; j++) {
        acc_reg_reorg[i][j] = acc_reg[0][j / 2][(j % 2) + (i * 2)];
      }
    }
    // 4 x cosize(smem_c_layout)
    float* smem_c = reinterpret_cast<float*>(smem_a);
    // coord -> index
    auto smem_c_index_func = [&](int m_idx, int n_idx) {
      int group_rows = 32 / cta_mma_n;
      int group_cnt = 2;
      return (m_idx % group_rows * cta_mma_n) + (m_idx / group_rows * (32 + group_cnt)) + n_idx;
    };
    constexpr int cosize_smem_c = ((tile_m * cta_mma_n) / 32) * (32 + 2);
    // This should be optimized to STS.64 but can not be STS.128 due to the bank index.
#pragma unroll
    for (int m_idx_thread = 0; m_idx_thread < thread_m; m_idx_thread++) {
#pragma unroll
      for (int n_idx_thread = 0; n_idx_thread < thread_n; n_idx_thread++) {
        int m_idx = (lane_idx / 4) + m_idx_thread * 8;
        int n_idx = ((lane_idx % 4) * 2) + (n_idx_thread % 2) + (n_idx_thread / 2) * 8;
        smem_c[cosize_smem_c * warp_idx + smem_c_index_func(m_idx, n_idx)] = acc_reg_reorg[m_idx_thread][n_idx_thread];
      }
    }
    asm volatile("bar.sync %0, %1;" : : "r"(1), "r"(thread_cnt));
    if (warp_idx == 0) {
      constexpr int final_acc_reg_cnt = (tile_m * tile_n + 31) / 32;
      float acc_final[final_acc_reg_cnt]{};
#pragma unroll
      for (int reg_idx = 0; reg_idx < final_acc_reg_cnt; reg_idx++) {
        int linear_idx = reg_idx * 32 + lane_idx;
        int m_idx = linear_idx % tile_m;
        int n_idx = linear_idx / tile_m;
        acc_final[reg_idx] += smem_c[smem_c_index_func(m_idx, n_idx) + 0 * cosize_smem_c] +
                              smem_c[smem_c_index_func(m_idx, n_idx) + 1 * cosize_smem_c] +
                              smem_c[smem_c_index_func(m_idx, n_idx) + 2 * cosize_smem_c] +
                              smem_c[smem_c_index_func(m_idx, n_idx) + 3 * cosize_smem_c];
      }
#pragma unroll
      for (int reg_idx = 0; reg_idx < final_acc_reg_cnt; reg_idx++) {
        int linear_idx = reg_idx * 32 + lane_idx;
        int m_idx = linear_idx % tile_m;
        int n_idx = linear_idx / tile_m;
        if (m_idx < tile_m && n_idx < gemm_n) {
          gmem_c[n_idx * gemm_m + m_idx] = acc_final[reg_idx];
        }
      }
    }
#endif
  }

  bf16_t* gmem_c;
  bf16_t* smem_a;
  bf16_t* smem_b;
  uint64_t* smem_barrier;
  int warp_idx;
  int gemm_n;
  int stage_idx = 0;
  int phase_bit = 0;
  int lane_idx = threadIdx.x % 32;
  int warp_k_offset_in_tile_k = warp_idx * per_warp_tile_k;
  int a_smem_offsets[m_iter_cnt][k_phase_cnt];
  int b_smem_offsets[n_iter_cnt][k_phase_cnt];
  bf16_t a_reg[m_iter_cnt][k_phase_cnt][8];
  bf16_t b_reg[n_iter_cnt][k_phase_cnt][4];
  float acc_reg[m_iter_cnt][n_iter_cnt][4]{};
};
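The four compute warps of MmaComputer each accumulate fp32 partial products over their own quarter of the reduction dimension (warp_k_offset_in_tile_k selects the per-warp K slice), and `epi()` then stages the four partial accumulators in shared memory so that warp 0 can sum them before the m-major store. The numerical shape of that split-K reduction, in plain torch (a sketch under the assumption of fp32 partials, as in acc_reg; not part of the commit):

    import torch

    M, N, K, warps = 16, 8, 7168, 4
    A = torch.randn(M, K)
    B = torch.randn(K, N)

    # Each compute warp reduces one contiguous K/4 chunk into fp32 partials ...
    partials = [
        A[:, w * K // warps:(w + 1) * K // warps] @ B[w * K // warps:(w + 1) * K // warps, :]
        for w in range(warps)
    ]
    # ... and the epilogue (warp 0) sums the four partials before the bf16 store.
    out = sum(partials)
    assert torch.allclose(out, A @ B, rtol=1e-4, atol=1e-3)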
// AB swapped, kernel is k-major, k-major, m-major
template <int batch_size, int gemm_m, int gemm_k, int tile_m, int tile_n, int tile_k, int stage_cnt>
__global__ __launch_bounds__(256, 1) void fused_a_gemm_kernel(
    bf16_t* output, bf16_t const* mat_a, bf16_t const* mat_b, int gemm_n) {
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 900
  constexpr int load_thread_cnt = 128;
  constexpr int compute_thread_cnt = 128;
  constexpr int thread_cnt = load_thread_cnt + compute_thread_cnt;
  (void)thread_cnt;
  static_assert(gemm_m % 16 == 0);
  static_assert(gemm_k % tile_k == 0);
  static_assert(gemm_m % tile_m == 0);
  static_assert(tile_k == 128 || tile_k == 256 || tile_k == 512 || tile_k == 1024);
  // tile_k must be larger than 64 since 4 warp splitK.
  static_assert(tile_m == 16);
  constexpr int g2s_vec_bytes = 16;
  constexpr int a_elem_bytes = 2;
  constexpr int b_elem_bytes = 2;
  // constexpr int c_elem_bytes = 2;
  static_assert((tile_m * a_elem_bytes + tile_n * b_elem_bytes) * tile_k * stage_cnt <= 225 * 1024);
  static_assert((tile_m * tile_k * a_elem_bytes) % (load_thread_cnt * g2s_vec_bytes) == 0);
  static_assert((tile_n * tile_k * b_elem_bytes) % (load_thread_cnt * g2s_vec_bytes) == 0);

  extern __shared__ char smem[];
  uint64_t* smem_barrier = reinterpret_cast<uint64_t*>(smem);  // producer,consumer; producer,consumer; ...
  bf16_t* smem_a = reinterpret_cast<bf16_t*>(smem + (stage_cnt * 8 * 2 + 1024) / 1024 * 1024);
  bf16_t* smem_b = smem_a + tile_m * tile_k * stage_cnt;

  int cta_m_idx = tile_m * blockIdx.x;
  int cta_n_idx = tile_n * blockIdx.y;
  bf16_t const* gmem_a_local = mat_a + cta_m_idx * gemm_k;
  bf16_t const* gmem_b_local = mat_b + cta_n_idx * gemm_k;
  bf16_t* gmem_c_local = output + cta_n_idx * gemm_m + cta_m_idx;

  int warp_idx = __shfl_sync(0xffffffff, threadIdx.x / 32, 0);
  if (warp_idx == 4) {
    for (int i = 0; i < stage_cnt; i++) {
      initialize_barrier(smem_barrier + i * 2 + 0, load_thread_cnt);     // producer
      initialize_barrier(smem_barrier + i * 2 + 1, compute_thread_cnt);  // consumer
    }
  }
  __syncthreads();

  if (warp_idx < 2) {
    GmemLoaderA<gemm_k, tile_m, tile_k, stage_cnt> a_loader(gmem_a_local, smem_a, smem_barrier);
    a_loader.prepare();
    a_loader.issue_mainloop();
  } else if (warp_idx < 4) {
    GmemLoaderB<gemm_k, tile_n, tile_k, stage_cnt> b_loader(gmem_b_local, smem_b, smem_barrier, gemm_n);
    b_loader.prepare();
    b_loader.issue_mainloop();
  } else {
    MmaComputer<gemm_m, gemm_k, tile_m, tile_n, tile_k, stage_cnt> mma_computer(
        gmem_c_local, smem_a, smem_b, smem_barrier, warp_idx, gemm_n);
    mma_computer.prepare();
    mma_computer.issue_mainloop();
    mma_computer.epi();
  }
  asm volatile("griddepcontrol.launch_dependents;");
#endif
}
template <typename T, int kHdIn, int kHdOut, int kTileN>
void invokeFusedAGemm(T* output, T const* mat_a, T const* mat_b, int num_tokens, cudaStream_t const stream) {
  constexpr int gemm_m = kHdOut;  // 2112
  int const gemm_n = num_tokens;  // 16
  constexpr int gemm_k = kHdIn;   // 7168
  constexpr int batch_size = 1;
  std::swap(mat_a, mat_b);
  constexpr int tile_m = 16;
  constexpr int tile_n = kTileN;                        // 8 or 16
  constexpr int tile_k = std::max(256, 1024 / tile_n);  // 256
  constexpr int max_stage_cnt = 1024 * 192 / ((tile_m + tile_n) * tile_k * sizeof(bf16_t));
  constexpr int k_iter_cnt = gemm_k / tile_k;
  constexpr int stage_cnt = k_iter_cnt > max_stage_cnt ? max_stage_cnt : k_iter_cnt;
  // possible tunable for smallK > 1 wave n. // 22
  int cta_m_cnt = gemm_m / tile_m;
  int cta_n_cnt = (gemm_n + tile_n - 1) / tile_n;
  constexpr int barrier_bytes = (stage_cnt * 16 + 1023) / 1024 * 1024;
  // 4096
  constexpr int smem_bytes = ((tile_m * 2 + tile_n * 2) * tile_k * stage_cnt + barrier_bytes + 1023) / 1024 * 1024;
  dim3 grid(cta_m_cnt, cta_n_cnt, 1);
  dim3 block_size(256);

  cudaLaunchConfig_t config;
  config.gridDim = grid;
  config.blockDim = block_size;
  config.dynamicSmemBytes = smem_bytes;
  config.stream = stream;
  cudaLaunchAttribute attrs[1];
  attrs[0].id = cudaLaunchAttributeProgrammaticStreamSerialization;
  attrs[0].val.programmaticStreamSerializationAllowed = getEnvEnablePDL();
  config.numAttrs = 1;
  config.attrs = attrs;

  if (smem_bytes >= (48 * 1024)) {
    cudaFuncSetAttribute(
        fused_a_gemm_kernel<batch_size, gemm_m, gemm_k, tile_m, tile_n, tile_k, stage_cnt>,
        cudaFuncAttributeMaxDynamicSharedMemorySize,
        smem_bytes);
  }
  cudaLaunchKernelEx(
      &config,
      fused_a_gemm_kernel<batch_size, gemm_m, gemm_k, tile_m, tile_n, tile_k, stage_cnt>,
      output,
      mat_a,
      mat_b,
      gemm_n);
}

template void invokeFusedAGemm<__nv_bfloat16, 7168, 2112, 8>(
    __nv_bfloat16*, __nv_bfloat16 const*, __nv_bfloat16 const*, int num_tokens, cudaStream_t);

template void invokeFusedAGemm<__nv_bfloat16, 7168, 2112, 16>(
    __nv_bfloat16*, __nv_bfloat16 const*, __nv_bfloat16 const*, int num_tokens, cudaStream_t);
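For the two instantiated tile widths, the constexpr arithmetic above resolves to concrete launch parameters. Reproducing it on the host side gives a feel for the grid shape and shared-memory budget (a sketch that simply mirrors the integer math, assuming kHdIn=7168, kHdOut=2112 and 2-byte bf16 elements):

    def launch_config(tile_n: int, num_tokens: int, gemm_m: int = 2112, gemm_k: int = 7168):
        # Mirrors the constexpr arithmetic in invokeFusedAGemm (bf16 => 2 bytes per element).
        tile_m = 16
        tile_k = max(256, 1024 // tile_n)
        max_stage_cnt = 1024 * 192 // ((tile_m + tile_n) * tile_k * 2)
        stage_cnt = min(gemm_k // tile_k, max_stage_cnt)
        barrier_bytes = (stage_cnt * 16 + 1023) // 1024 * 1024
        smem_bytes = ((tile_m * 2 + tile_n * 2) * tile_k * stage_cnt + barrier_bytes + 1023) // 1024 * 1024
        grid = (gemm_m // tile_m, (num_tokens + tile_n - 1) // tile_n)
        return dict(tile_k=tile_k, stage_cnt=stage_cnt, smem_bytes=smem_bytes, grid=grid)

    print(launch_config(tile_n=8, num_tokens=8))    # tile_k=256, stage_cnt=16, ~193 KiB smem, grid=(132, 1)
    print(launch_config(tile_n=16, num_tokens=16))  # tile_k=256, stage_cnt=12, ~193 KiB smem, grid=(132, 1)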
void dsv3_fused_a_gemm(torch::Tensor& output, torch::Tensor const& mat_a, torch::Tensor const& mat_b) {
  TORCH_CHECK(mat_a.dim() == 2 && mat_b.dim() == 2 && output.dim() == 2);

  int const num_tokens = mat_a.size(0);
  int const hd_in = mat_a.size(1);
  int const hd_out = mat_b.size(1);
  constexpr int kHdIn = 7168;
  constexpr int kHdOut = 2112;

  TORCH_CHECK(num_tokens >= 1 && num_tokens <= 16, "required 1 <= mat_a.shape[0] <= 16")
  TORCH_CHECK(hd_in == kHdIn, "required mat_a.shape[1] == 7168")
  TORCH_CHECK(hd_out == kHdOut, "required mat_b.shape[1] == 2112")
  TORCH_CHECK(output.size(0) == num_tokens, "required output.shape[0] == mat_a.shape[0]")
  TORCH_CHECK(output.size(1) == hd_out, "required output.shape[1] == mat_b.shape[1]")
  TORCH_CHECK(mat_a.strides()[1] == 1);   // Row-major
  TORCH_CHECK(output.strides()[1] == 1);  // Row-major
  TORCH_CHECK(mat_b.strides()[0] == 1);   // Column-major

  auto const data_type = mat_a.scalar_type();
  TORCH_CHECK(
      mat_a.scalar_type() == torch::kBFloat16 && mat_b.scalar_type() == torch::kBFloat16,
      "Only BFloat16 input dtype is supported")
  TORCH_CHECK(output.scalar_type() == torch::kBFloat16, "Only BFloat16 output dtype is supported")

  auto const sm = getSMVersion();
  TORCH_CHECK(sm >= 90, "required CUDA ARCH >= SM_90");

  auto stream = at::cuda::getCurrentCUDAStream(mat_a.get_device());
  if (num_tokens <= 8) {
    invokeFusedAGemm<__nv_bfloat16, kHdIn, kHdOut, 8>(
        reinterpret_cast<__nv_bfloat16*>(output.mutable_data_ptr()),
        reinterpret_cast<__nv_bfloat16 const*>(mat_a.data_ptr()),
        reinterpret_cast<__nv_bfloat16 const*>(mat_b.data_ptr()),
        num_tokens,
        stream);
  } else {
    invokeFusedAGemm<__nv_bfloat16, kHdIn, kHdOut, 16>(
        reinterpret_cast<__nv_bfloat16*>(output.mutable_data_ptr()),
        reinterpret_cast<__nv_bfloat16 const*>(mat_a.data_ptr()),
        reinterpret_cast<__nv_bfloat16 const*>(mat_b.data_ptr()),
        num_tokens,
        stream);
  }
}
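Because A and B are swapped internally and the result is written m-major, the host wrapper is strict about layouts: mat_a and output must be row-major and mat_b must be a column-major [7168, 2112] weight view. A sketch of tensors that satisfy those stride checks, assuming a CUDA device (it mirrors what the unit test below constructs):

    import torch

    num_tokens, kHdIn, kHdOut = 4, 7168, 2112
    mat_a = torch.randn(num_tokens, kHdIn, dtype=torch.bfloat16, device="cuda")   # row-major activations
    weight = torch.randn(kHdOut, kHdIn, dtype=torch.bfloat16, device="cuda")
    mat_b = weight.t()                                 # [7168, 2112] view with stride (1, 7168): column-major
    out = torch.empty(num_tokens, kHdOut, dtype=torch.bfloat16, device="cuda")    # row-major output

    print(mat_a.stride(), mat_b.stride(), out.stride())  # (7168, 1), (1, 7168), (2112, 1)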
sgl-kernel/include/sgl_kernel_ops.h
View file @ 04b35190
...
...
@@ -201,6 +201,8 @@ void bmm_fp8(
    int64_t cublas_handle,
    int64_t cuda_stream);

void dsv3_fused_a_gemm(torch::Tensor& output, torch::Tensor const& mat_a, torch::Tensor const& mat_b);

/*
* From csrc/moe
*/
...
...
sgl-kernel/include/utils.h
View file @ 04b35190
...
...
@@ -241,6 +241,23 @@ inline int getSMVersion() {
  return sm_major * 10 + sm_minor;
}

inline bool getBoolEnv(char const* name) {
  char const* env = std::getenv(name);
  return env && env[0] == '1' && env[1] == '\0';
}

inline bool getEnvEnablePDL() {
  static std::once_flag flag;
  static bool enablePDL = false;
  std::call_once(flag, [&]() {
    if (getSMVersion() >= 90) {
      // PDL will be enabled by setting the env variable `TRTLLM_ENABLE_PDL` to `1`
      enablePDL = getBoolEnv("TRTLLM_ENABLE_PDL");
    }
  });
  return enablePDL;
}
// SGLANG_SHFL_XOR_* adapted from https://github.com/vllm-project/vllm/blob/v0.7.3/csrc/cuda_compat.h#L19-L28
#ifndef USE_ROCM
#define SGLANG_SHFL_XOR_SYNC(mask, var, lane_mask) __shfl_xor_sync((mask), (var), (lane_mask))
...
...
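getEnvEnablePDL is evaluated once via std::call_once, so programmatic dependent launch is opted into by setting TRTLLM_ENABLE_PDL=1 in the environment before the first kernel launch on an SM90+ GPU. A hedged example of doing that from Python (an assumption for illustration; the variable is read and cached on the C++ side):

    import os

    # Must be set before the process first launches a kernel that calls getEnvEnablePDL().
    os.environ["TRTLLM_ENABLE_PDL"] = "1"

    import sgl_kernel  # noqa: E402,F401  -- subsequent dsv3_fused_a_gemm launches may then use PDL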
sgl-kernel/python/sgl_kernel/__init__.py
View file @ 04b35190
...
...
@@ -33,6 +33,7 @@ from sgl_kernel.gemm import (
    awq_dequantize,
    bmm_fp8,
    cutlass_scaled_fp4_mm,
    dsv3_fused_a_gemm,
    fp8_blockwise_scaled_mm,
    fp8_scaled_mm,
    int8_scaled_mm,
...
...
sgl-kernel/python/sgl_kernel/gemm.py
View file @ 04b35190
...
...
@@ -82,6 +82,21 @@ def bmm_fp8(
     return out


 def dsv3_fused_a_gemm(
     mat_a: torch.Tensor,
     mat_b: torch.Tensor,
     output: Optional[torch.Tensor] = None,
 ) -> torch.Tensor:
     if output is None:
         output = torch.empty(
             (mat_a.shape[0], mat_b.shape[1]),
             device=mat_a.device,
             dtype=mat_a.dtype,
         )
     torch.ops.sgl_kernel.dsv3_fused_a_gemm.default(output, mat_a, mat_b)
     return output


 def sgl_per_token_group_quant_fp8(
     input: torch.Tensor,
     output_q: torch.Tensor,
...
...
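The wrapper allocates the [num_tokens, 2112] output when none is passed; callers that keep a persistent decode buffer can hand it in and skip the allocation. A small usage sketch (assuming a CUDA device and the constraints enforced in the C++ wrapper, i.e. 1 <= num_tokens <= 16):

    import torch
    from sgl_kernel import dsv3_fused_a_gemm

    num_tokens = 8
    mat_a = torch.randn(num_tokens, 7168, dtype=torch.bfloat16, device="cuda")
    mat_b = torch.randn(2112, 7168, dtype=torch.bfloat16, device="cuda").t()  # column-major [7168, 2112]

    # Let the wrapper allocate the output ...
    out = dsv3_fused_a_gemm(mat_a, mat_b)

    # ... or reuse a preallocated buffer across decode steps.
    buf = torch.empty(num_tokens, 2112, dtype=torch.bfloat16, device="cuda")
    dsv3_fused_a_gemm(mat_a, mat_b, output=buf)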
sgl-kernel/tests/test_dsv3_fused_a_gemm.py
0 → 100644
View file @ 04b35190
import pytest
import torch
import torch.nn.functional as F
from sgl_kernel import dsv3_fused_a_gemm


@pytest.mark.parametrize("num_tokens", [i + 1 for i in range(16)])
def test_dsv3_fused_a_gemm(num_tokens):
    kHdIn = 7168
    kHdOut = 2112

    mat_a = torch.randn(
        (num_tokens, kHdIn), dtype=torch.bfloat16, device="cuda"
    ).contiguous()
    mat_b = torch.randn(
        (kHdOut, kHdIn), dtype=torch.bfloat16, device="cuda"
    ).transpose(0, 1)
    output = torch.empty(
        (num_tokens, kHdOut), dtype=torch.bfloat16, device="cuda"
    ).contiguous()

    ref = F.linear(mat_a, mat_b.T)
    output = dsv3_fused_a_gemm(mat_a, mat_b)

    assert torch.allclose(
        output, ref, rtol=1e-2, atol=1e-3
    ), "Fused GEMM output mismatch with torch.nn.functional.linear reference"


if __name__ == "__main__":
    pytest.main([__file__])