sglang · Commits · 136b8e6a

Unverified commit 136b8e6a, authored Apr 11, 2025 by Yineng Zhang, committed by GitHub on Apr 11, 2025.

fix: remove cublas_grouped_gemm (#5307)

Parent: 034c5256

Showing 8 changed files with 0 additions and 508 deletions (+0 -508):
sgl-kernel/CMakeLists.txt                            +0  -1
sgl-kernel/benchmark/bench_cublas_grouped_gemm.py    +0  -262
sgl-kernel/csrc/common_extension.cc                  +0  -5
sgl-kernel/csrc/gemm/cublas_grouped_gemm.cu          +0  -172
sgl-kernel/include/sgl_kernel_ops.h                  +0  -7
sgl-kernel/python/sgl_kernel/__init__.py             +0  -1
sgl-kernel/python/sgl_kernel/gemm.py                 +0  -20
sgl-kernel/tests/test_cublas_grouped_gemm.py         +0  -40
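The removed Python entry point (deleted from sgl-kernel/python/sgl_kernel/gemm.py below) took lists of 2-D tensors, with inputs[i] of shape (M, K), weights[i] of shape (N, K) and outputs[i] of shape (M, N). For callers that still depend on it, a minimal fallback sketch (not part of this commit; it mirrors the torch_grouped_gemm reference in the deleted test file) could look like:

from typing import List

import torch


def grouped_gemm_fallback(
    inputs: List[torch.Tensor],   # each (M_i, K_i), row major
    weights: List[torch.Tensor],  # each (N_i, K_i), row major
    outputs: List[torch.Tensor],  # each (M_i, N_i), written in place
    out_dtype: torch.dtype,
) -> None:
    # One plain matmul per group, same contract as the removed kernel wrapper.
    for a, b, c in zip(inputs, weights, outputs):
        c.copy_(torch.matmul(a, b.t()).to(out_dtype))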
sgl-kernel/CMakeLists.txt

@@ -164,7 +164,6 @@ set(SOURCES
   "csrc/elementwise/rope.cu"
   "csrc/gemm/awq_kernel.cu"
   "csrc/gemm/bmm_fp8.cu"
-  "csrc/gemm/cublas_grouped_gemm.cu"
   "csrc/gemm/fp8_blockwise_gemm_kernel.cu"
   "csrc/gemm/fp8_gemm_kernel.cu"
   "csrc/gemm/int8_gemm_kernel.cu"
sgl-kernel/benchmark/bench_cublas_grouped_gemm.py  deleted  100644 → 0

import argparse

import torch
import triton
import triton.language as tl
from sgl_kernel import cublas_grouped_gemm

WEIGHT_CONFIGS = {
    "DeepSeek-V2-Lite": {
        "num_routed_experts": 64,
        "ffn_shapes": [
            [2048, 2816],
            [1408, 2048],
        ],
    },
    "DeepSeek-V2": {
        "num_routed_experts": 160,
        "ffn_shapes": [
            [5120, 3072],
            [1536, 5120],
        ],
    },
}


# This Triton Grouped Gemm Kernel is adapted from
# https://triton-lang.org/main/getting-started/tutorials/08-grouped-gemm.html
@triton.jit
def grouped_matmul_kernel(
    # device tensor of matrices pointers
    group_a_ptrs,
    group_b_ptrs,
    group_c_ptrs,
    # device tensor of gemm sizes. its shape is [group_size, 3]
    # dim 0 is group_size, dim 1 is the values of <M, N, K> of each gemm
    group_gemm_sizes,
    # device tensor of leading dimension sizes. its shape is [group_size, 3]
    # dim 0 is group_size, dim 1 is the values of <lda, ldb, ldc> of each gemm
    g_lds,
    # Factors for multiplication.
    alphas,
    betas,
    # number of gemms
    group_size,
    # number of virtual SM
    NUM_SM: tl.constexpr,
    # tile sizes
    BLOCK_SIZE_M: tl.constexpr,
    BLOCK_SIZE_N: tl.constexpr,
    BLOCK_SIZE_K: tl.constexpr,
):
    tile_idx = tl.program_id(0)
    last_problem_end = 0
    for g in range(group_size):
        # get the gemm size of the current problem
        gm = tl.load(group_gemm_sizes + g * 3)
        gn = tl.load(group_gemm_sizes + g * 3 + 1)
        gk = tl.load(group_gemm_sizes + g * 3 + 2)
        num_m_tiles = tl.cdiv(gm, BLOCK_SIZE_M)
        num_n_tiles = tl.cdiv(gn, BLOCK_SIZE_N)
        num_tiles = num_m_tiles * num_n_tiles
        # load multiplication factors
        alpha = tl.load(alphas + g)
        beta = tl.load(betas + g)
        # iterate through the tiles in the current gemm problem
        while tile_idx >= last_problem_end and tile_idx < last_problem_end + num_tiles:
            # pick up a tile from the current gemm problem
            k = gk
            lda = tl.load(g_lds + g * 3)
            ldb = tl.load(g_lds + g * 3 + 1)
            ldc = tl.load(g_lds + g * 3 + 2)
            a_ptr = tl.load(group_a_ptrs + g).to(tl.pointer_type(tl.float16))
            b_ptr = tl.load(group_b_ptrs + g).to(tl.pointer_type(tl.float16))
            c_ptr = tl.load(group_c_ptrs + g).to(tl.pointer_type(tl.float16))
            # figure out tile coordinates
            tile_idx_in_gemm = tile_idx - last_problem_end
            tile_m_idx = tile_idx_in_gemm // num_n_tiles
            tile_n_idx = tile_idx_in_gemm % num_n_tiles
            # do regular gemm here
            offs_am = tile_m_idx * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
            offs_bn = tile_n_idx * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
            offs_k = tl.arange(0, BLOCK_SIZE_K)
            a_ptrs = a_ptr + offs_am[:, None] * lda + offs_k[None, :]
            b_ptrs = b_ptr + offs_k[:, None] * ldb + offs_bn[None, :]
            accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)
            for kk in range(0, tl.cdiv(k, BLOCK_SIZE_K)):
                a = tl.load(
                    a_ptrs,
                    mask=(offs_am[:, None] < gm)
                    and (offs_k[None, :] < gk - kk * BLOCK_SIZE_K),
                    other=0.0,
                )
                b = tl.load(
                    b_ptrs,
                    mask=(offs_k[:, None] < gk - kk * BLOCK_SIZE_K)
                    and (offs_bn[None, :] < gn),
                    other=0.0,
                )
                accumulator += tl.dot(a, b)
                a_ptrs += BLOCK_SIZE_K
                b_ptrs += BLOCK_SIZE_K * ldb
            accumulator *= alpha
            c = accumulator.to(tl.float16)
            offs_cm = tile_m_idx * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
            offs_cn = tile_n_idx * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
            c_ptrs = c_ptr + ldc * offs_cm[:, None] + offs_cn[None, :]
            output_mask = (offs_am[:, None] < gm) and (offs_bn[None, :] < gn)
            c += beta * tl.load(c_ptrs, mask=output_mask)
            tl.store(c_ptrs, c, mask=output_mask)
            # go to the next tile by advancing NUM_SM
            tile_idx += NUM_SM
        # get ready to go to the next gemm problem
        last_problem_end = last_problem_end + num_tiles


def triton_perf_fn(group_A, group_B, group_C, dtype):
    # We put the process of matrix lengths and pointers here out of fairness,
    # since cublas_grouped_gemm kernel also does these work.
    group_size = len(group_A)
    A_addrs = []
    B_addrs = []
    C_addrs = []
    g_sizes = []
    g_lds = []
    alphas = [1.0] * group_size
    betas = [0.0] * group_size
    for i in range(group_size):
        M, N, K = group_A[i].shape[0], group_B[i].shape[1], group_A[i].shape[1]
        g_sizes += [M, N, K]
        g_lds += [K, N, N]
        A_addrs.append(group_A[i].data_ptr())
        B_addrs.append(group_B[i].data_ptr())
        C_addrs.append(group_C[i].data_ptr())
    d_a_ptrs = torch.tensor(A_addrs, device="cuda")
    d_b_ptrs = torch.tensor(B_addrs, device="cuda")
    d_c_ptrs = torch.tensor(C_addrs, device="cuda")
    d_g_sizes = torch.tensor(g_sizes, dtype=torch.int32, device="cuda")
    d_g_lds = torch.tensor(g_lds, dtype=torch.int32, device="cuda")
    d_alphas = torch.tensor(alphas, dtype=torch.float32, device="cuda")
    d_betas = torch.tensor(betas, dtype=torch.float32, device="cuda")
    NUM_SM = 128
    grid = (NUM_SM,)
    grouped_matmul_kernel[grid](
        d_a_ptrs,
        d_b_ptrs,
        d_c_ptrs,
        d_g_sizes,
        d_g_lds,
        d_alphas,
        d_betas,
        group_size,
        NUM_SM=NUM_SM,
        BLOCK_SIZE_M=128,
        BLOCK_SIZE_N=128,
        BLOCK_SIZE_K=32,
    )


def cublas_perf_fn(group_A, group_B, group_C, dtype):
    cublas_grouped_gemm(group_A, group_B, group_C, dtype)


@triton.testing.perf_report(
    triton.testing.Benchmark(
        x_names=["M"],
        x_vals=[1, 16, 32, 64, 128, 256, 512, 1024, 2048],
        x_log=False,
        line_arg="provider",
        line_vals=[
            "triton",
            "cublas",
        ],
        line_names=[
            "triton",
            "cublas",
        ],
        styles=[("green", "-"), ("blue", "-")],
        ylabel="gbps",
        plot_name="grouped gemm",
        args={},
    )
)
def benchmark(M, provider, N, K):
    group_size = 20  # Number of used experts per gpu is usually around 20
    group_A = []
    group_B_row_major = []
    group_B_col_major = []
    group_C = []
    dtype = torch.float16
    for i in range(group_size):
        A = torch.rand((M, K), device="cuda", dtype=dtype)
        B_row_major = torch.rand((K, N), device="cuda", dtype=dtype)
        B_col_major = torch.rand((N, K), device="cuda", dtype=dtype)
        C = torch.empty((M, N), device="cuda", dtype=dtype)
        group_A.append(A)
        group_B_row_major.append(B_row_major)
        group_B_col_major.append(B_col_major)
        group_C.append(C)
    quantiles = [0.5, 0.2, 0.8]
    if "triton" in provider:
        ms, min_ms, max_ms = triton.testing.do_bench(
            lambda: triton_perf_fn(group_A, group_B_row_major, group_C, dtype),
            quantiles=quantiles,
        )
    elif "cublas" in provider:
        ms, min_ms, max_ms = triton.testing.do_bench(
            lambda: cublas_perf_fn(group_A, group_B_col_major, group_C, dtype),
            quantiles=quantiles,
        )
    gbps = (
        lambda ms: group_size
        * (2 * M * N * K + 2 * M * N)
        * group_A[0].element_size()
        * 1e-9
        / (ms * 1e-3)
    )
    return gbps(ms), gbps(max_ms), gbps(min_ms)


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--models",
        nargs="+",
        type=str,
        default=["DeepSeek-V2"],
        help="List of models to benchmark",
    )
    parser.add_argument(
        "--tp-size",
        type=int,
        default=8,
        help="Tensor parallel size",
    )
    args = parser.parse_args()

    for model in args.models:
        assert model in WEIGHT_CONFIGS
        num_experts_per_device = (
            WEIGHT_CONFIGS[model]["num_routed_experts"] // args.tp_size
        )
        for K, N in WEIGHT_CONFIGS[model]["ffn_shapes"]:
            print(
                f"{model} N={N} K={K} tp_size={args.tp_size} "
                f"group_size=num_experts_per_device={num_experts_per_device}: "
            )
            benchmark.run(
                print_data=True,
                show_plots=True,
                save_path="bench_grouped_gemm_res",
                N=N,
                K=K,
            )
    print("Benchmark finished!")
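For orientation, the throughput figure the deleted benchmark reported comes from its gbps lambda above. A standalone restatement of that arithmetic (my sketch, not part of the file; element_size is 2 bytes for the float16 tensors the benchmark allocates):

def reported_gbps(ms, group_size, M, N, K, element_size=2):
    # Same expression as the deleted benchmark's gbps lambda: group_size GEMMs,
    # (2*M*N*K + 2*M*N) scaled by the element size, divided by the runtime in seconds.
    return group_size * (2 * M * N * K + 2 * M * N) * element_size * 1e-9 / (ms * 1e-3)


# e.g. 20 experts per GPU with the DeepSeek-V2 shape K=5120, N=3072 at M=256 and 1.5 ms
print(f"{reported_gbps(1.5, 20, 256, 3072, 5120):.1f}")

Since the 2*M*N*K term counts multiply-adds rather than bytes moved, the "gbps" label is best read as a relative score between the two providers rather than literal memory bandwidth.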
sgl-kernel/csrc/common_extension.cc

@@ -112,11 +112,6 @@ TORCH_LIBRARY_FRAGMENT(sgl_kernel, m) {
   m.def("sgl_per_token_quant_fp8(Tensor input, Tensor output_q, Tensor output_s) -> ()");
   m.impl("sgl_per_token_quant_fp8", torch::kCUDA, &sgl_per_token_quant_fp8);

-  m.def(
-      "cublas_grouped_gemm(Tensor[] inputs, Tensor[] weights, Tensor[] outputs,"
-      " ScalarType out_dtype, int cublas_handle, int cuda_stream) -> ()");
-  m.impl("cublas_grouped_gemm", torch::kCUDA, &cublas_grouped_gemm);
-
   m.def(
       "cutlass_scaled_fp4_mm(Tensor! out, Tensor a, Tensor b,"
       " Tensor block_scale_a, Tensor block_scale_b,"
sgl-kernel/csrc/gemm/cublas_grouped_gemm.cu  deleted  100644 → 0

// References:
// https://docs.nvidia.com/cuda/cublas/index.html#cublasgemmgroupedbatchedex
// https://github.com/NVIDIA/CUDALibrarySamples/blob/master/cuBLAS/Extensions/GemmGroupedBatchedEx/cublas_GemmGroupedBatchedEx_example.cu
// https://github.com/zhihu/ZhiLight/blob/main/src/nn/linear/gemm_grouped.cpp

#include <ATen/ATen.h>
#include <ATen/cuda/CUDAContext.h>
#include <c10/cuda/CUDAGuard.h>
#include <c10/util/Exception.h>
#include <cublas_v2.h>
#include <cudaTypedefs.h>
#include <cuda_fp16.h>
#include <cuda_runtime.h>
#include <torch/all.h>

#include <cstdio>
#include <cstdlib>
#include <string>
#include <vector>

#include "utils.h"

static void check_group_count(
    const std::vector<torch::Tensor>& inputs,
    const std::vector<torch::Tensor>& weights,
    const std::vector<torch::Tensor>& outputs) {
  TORCH_CHECK(
      ((inputs.size() == weights.size()) && (inputs.size() == outputs.size())),
      "The group count of inputs, weights and outputs should be the same.");
}

static void check_device_dtype(const torch::Dtype& dtype, const std::vector<torch::Tensor>& tensors) {
  for (const auto& t : tensors) {
    TORCH_CHECK(dtype == t.dtype(), "dtype of all the tensors should be the same");
    TORCH_CHECK(t.is_cuda(), "All tensors should be in Cuda memory");
  }
}

static std::vector<int> get_dims(const std::vector<torch::Tensor>& tensors, int dim) {
  std::vector<int> results;
  for (const auto& t : tensors) {
    TORCH_CHECK(t.dim() == 2, "Should pass in 2D matrices");
    results.push_back(t.size(dim));
  }
  return std::move(results);
}

static std::vector<int> get_strides(const std::vector<torch::Tensor>& tensors, int dim) {
  std::vector<int> results;
  for (const auto& t : tensors) {
    results.push_back(t.stride(dim));
  }
  return std::move(results);
}

static void check_equal(const std::vector<int>& a, const std::vector<int>& b, const std::string& err_msg) {
  for (int i = 0; i < a.size(); ++i) {
    TORCH_CHECK(a[i] == b[i], err_msg);
  }
}

static std::vector<void*> get_tensor_ptrs(const std::vector<torch::Tensor>& tensors) {
  std::vector<void*> ptrs;
  for (auto& t : tensors) {
    ptrs.push_back(t.data_ptr());
  }
  return std::move(ptrs);
}

static torch::Tensor create_ptr_pointer(const std::vector<void*>& ptrs, cudaStream_t stream) {
  auto options = torch::TensorOptions().dtype(torch::kDouble).device(torch::kCUDA);
  torch::Tensor gpu_ptrs = torch::empty({static_cast<int>(ptrs.size())}, options);
  TORCH_CHECK(
      cudaMemcpyAsync(
          gpu_ptrs.data_ptr(), ptrs.data(), sizeof(void*) * ptrs.size(), cudaMemcpyHostToDevice, stream) ==
      CUBLAS_STATUS_SUCCESS);
  return gpu_ptrs;
}

// We want compute input @ weight^T in row major
// This is equivalent to computing weight @ input^T in col major
// Cublas only accepts matrix in column major, so this arrangement is needed
void cublas_grouped_gemm(
    const std::vector<torch::Tensor>& inputs,   // b: (m, k) row major = (k, m) col major
    const std::vector<torch::Tensor>& weights,  // a: (n, k) row major = (n, k)^T col major
    const std::vector<torch::Tensor>& outputs,  // c: (m, n) row major = (n, m) col major
    const torch::Dtype& out_dtype,
    int64_t cublas_handle,
    int64_t cuda_stream) {
  TORCH_CHECK(
      out_dtype == torch::kHalf || out_dtype == torch::kBFloat16,
      "cublas grouped_gemm can"
      "only be applied to float16 and bfloat16 dtype");

  int group_count = inputs.size();
  check_group_count(inputs, weights, outputs);
  std::vector<int> group_size(group_count, 1);

  // Make sure all tensors are on cuda and use the same dtype
  check_device_dtype(out_dtype, inputs);
  check_device_dtype(out_dtype, weights);
  check_device_dtype(out_dtype, outputs);

  // Weights should be transposed to (n, k) of column major
  std::vector<cublasOperation_t> transa_array(group_count, CUBLAS_OP_T);
  std::vector<cublasOperation_t> transb_array(group_count, CUBLAS_OP_N);

  // Get dim arrays
  std::vector<int> m_array = get_dims(weights, 0);
  std::vector<int> n_array = get_dims(inputs, 0);
  std::vector<int> k_array = get_dims(inputs, 1);

  // Make sure the dimensions in each group match
  std::vector<int> m_array1 = get_dims(outputs, 1);
  std::vector<int> n_array1 = get_dims(outputs, 0);
  std::vector<int> k_array1 = get_dims(weights, 1);
  check_equal(m_array, m_array1, "sizes don't match on m dimension");
  check_equal(n_array, n_array1, "sizes don't match on n dimension");
  check_equal(k_array, k_array1, "sizes don't match on k dimension");

  // Get leading dimensions
  std::vector<int> lda_array = get_strides(weights, 0);
  std::vector<int> ldb_array = get_strides(inputs, 0);
  std::vector<int> ldc_array = get_strides(outputs, 0);

  // Use default scaling factors
  std::vector<float> alpha_array(group_count, 1);
  std::vector<float> beta_array(group_count, 0);

  std::vector<void*> a_array = get_tensor_ptrs(weights);
  std::vector<void*> b_array = get_tensor_ptrs(inputs);
  std::vector<void*> c_array = get_tensor_ptrs(outputs);

  auto stream = reinterpret_cast<cudaStream_t>(cuda_stream);
  // Should allocate tensors for storage of pointers
  torch::Tensor d_a = create_ptr_pointer(a_array, stream);
  torch::Tensor d_b = create_ptr_pointer(b_array, stream);
  torch::Tensor d_c = create_ptr_pointer(c_array, stream);

#if defined CUDA_VERSION && CUDA_VERSION >= 12050
  auto handle = reinterpret_cast<cublasHandle_t>(cublas_handle);
  cudaDataType_t cuda_data_type = (out_dtype == torch::kHalf ? CUDA_R_16F : CUDA_R_16BF);
  auto status = cublasGemmGroupedBatchedEx(
      handle,
      transa_array.data(),
      transb_array.data(),
      m_array.data(),
      n_array.data(),
      k_array.data(),
      alpha_array.data(),
      (void**)d_a.data_ptr(),
      cuda_data_type,
      lda_array.data(),
      (void**)d_b.data_ptr(),
      cuda_data_type,
      ldb_array.data(),
      beta_array.data(),
      (void**)d_c.data_ptr(),
      cuda_data_type,
      ldc_array.data(),
      group_count,
      group_size.data(),
      CUBLAS_COMPUTE_32F);
  TORCH_CHECK(status == CUBLAS_STATUS_SUCCESS, "cublas grouped gemm failed: ", cublasGetStatusString(status));
  TORCH_CHECK(cudaStreamSynchronize(stream) == cudaSuccess, "Failed when stream synchronization");
  return;
#endif

  TORCH_CHECK_NOT_IMPLEMENTED(
      false, "Cublas GroupGemm is not implemented with current compute capability: ", getSMVersion());
}
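The layout comment above rests on the identity (input @ weight^T)^T = weight @ input^T: a row-major output buffer read in column-major order is its transpose, which is what lets a column-major cuBLAS call fill a row-major (M, N) result. A small numerical check of that identity (my sketch, not part of the deleted file):

import torch

M, N, K = 4, 6, 8
inp = torch.randn(M, K)  # input, row major
wgt = torch.randn(N, K)  # weight, row major

row_major_result = inp @ wgt.t()       # (M, N): what the removed wrapper promised
col_major_route = (wgt @ inp.t()).t()  # weight @ input^T, transposed back
assert torch.allclose(row_major_result, col_major_route)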
sgl-kernel/include/sgl_kernel_ops.h

@@ -160,13 +160,6 @@ void sgl_per_token_group_quant_int8(
     double int8_max);
 void sgl_per_tensor_quant_fp8(at::Tensor input, at::Tensor output_q, at::Tensor output_s, bool is_static);
 void sgl_per_token_quant_fp8(at::Tensor input, at::Tensor output_q, at::Tensor output_s);
-void cublas_grouped_gemm(
-    const std::vector<torch::Tensor>& inputs,
-    const std::vector<torch::Tensor>& weights,
-    const std::vector<torch::Tensor>& outputs,
-    const torch::Dtype& out_dtype,
-    int64_t cublas_handle,
-    int64_t cuda_stream);
 void bmm_fp8(
     at::Tensor A,
     at::Tensor B,
sgl-kernel/python/sgl_kernel/__init__.py

@@ -25,7 +25,6 @@ from sgl_kernel.elementwise import (
 from sgl_kernel.gemm import (
     awq_dequantize,
     bmm_fp8,
-    cublas_grouped_gemm,
     cutlass_scaled_fp4_mm,
     fp8_blockwise_scaled_mm,
     fp8_scaled_mm,
sgl-kernel/python/sgl_kernel/gemm.py

@@ -121,26 +121,6 @@ def sgl_per_tensor_quant_fp8(
     )


-def cublas_grouped_gemm(
-    inputs: List[torch.Tensor],
-    weights: List[torch.Tensor],
-    outputs: List[torch.Tensor],
-    out_dtype: torch.dtype,
-) -> None:
-    assert (
-        len(inputs) > 0 and len(weights) > 0 and len(outputs) > 0
-    ), "Inputs/weights/outputs should not be empty!"
-    cublas_handle = torch.cuda.current_blas_handle()
-    torch.ops.sgl_kernel.cublas_grouped_gemm.default(
-        inputs,
-        weights,
-        outputs,
-        out_dtype,
-        cublas_handle,
-        get_cuda_stream(),
-    )
-
-
 def sgl_per_token_quant_fp8(
     input: torch.Tensor,
     output_q: torch.Tensor,
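A call to the removed wrapper, as exercised by the deleted test file that follows, looked roughly like this (a sketch distilled from that test, not code from this commit):

import torch
from sgl_kernel import cublas_grouped_gemm  # no longer importable after this commit

a = torch.randn((16, 32), device="cuda", dtype=torch.float16)   # input  (M, K)
b = torch.randn((128, 32), device="cuda", dtype=torch.float16)  # weight (N, K)
c = torch.empty((16, 128), device="cuda", dtype=torch.float16)  # output (M, N)
cublas_grouped_gemm([a], [b], [c], torch.float16)               # writes a @ b.t() into c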
sgl-kernel/tests/test_cublas_grouped_gemm.py  deleted  100644 → 0

import pytest
import torch
from sgl_kernel import cublas_grouped_gemm


def torch_grouped_gemm(a_array, b_array, out_dtype):
    return [torch.matmul(a, b.t()).to(out_dtype) for a, b in zip(a_array, b_array)]


skip_condition = not torch.cuda.is_available() or (
    torch.version.cuda is None
    or tuple(map(int, torch.version.cuda.split("."))) < (12, 5)
)


@pytest.mark.skipif(skip_condition, reason="CUDA not available or CUDA version lower than 12.5")
@pytest.mark.parametrize("out_dtype", [torch.float16, torch.bfloat16])
@pytest.mark.parametrize("M", [1, 16, 32, 256, 1024])
@pytest.mark.parametrize("N", [2, 16, 128, 256, 4096])
@pytest.mark.parametrize("K", [3, 16, 32, 512, 8192])
def test_grouped_gemm_accuracy(out_dtype, M, N, K):
    a = torch.randn((M, K), device="cuda", dtype=out_dtype) * 5
    b = torch.randn((N, K), device="cuda", dtype=out_dtype) * 5
    expected = torch.matmul(a, b.t()).to(out_dtype)

    a_array = [a]
    b_array = [b]
    c_array = [torch.empty((M, N), device="cuda", dtype=out_dtype)]

    result_torch = torch_grouped_gemm(a_array, b_array, out_dtype)[0]
    cublas_grouped_gemm(a_array, b_array, c_array, out_dtype)

    torch.testing.assert_close(result_torch, expected)
    torch.testing.assert_close(c_array[0], expected)


if __name__ == "__main__":
    pytest.main([__file__])