sglang / Commits / Commit 0194948f (Unverified)

Optimize Triton Kernel of Group GEMM in DeepGEMM Benchmark (#4014)

Authored Mar 02, 2025 by Stefan He; committed by GitHub on Mar 02, 2025
Parent: b4d34cd3

Showing 1 changed file with 102 additions and 102 deletions:
benchmark/kernels/deepseek/benchmark_deepgemm_fp8_group_gemm.py.py (+102, -102)
...
...
@@ -115,17 +115,17 @@ def fp8_gemm_group_triton_kernel(
 ):
     """Kernel for computing the matmul C = A x B with FP8 inputs and scaling factors.
     A has shape (M, K), B has shape (K, N) and C has shape (M, N)
+    Note: Block sizes must be multiples of 32 for optimal TMA performance.
     """
     # Map program ids to the block of C it should compute
-    pid = tl.program_id(axis=0)
-    num_pid_m = tl.cdiv(M, BLOCK_SIZE_M)
-    num_pid_n = tl.cdiv(N, BLOCK_SIZE_N)
-    num_pid_in_group = GROUP_SIZE_M * num_pid_n
-    group_id = pid // num_pid_in_group
-    first_pid_m = group_id * GROUP_SIZE_M
-    group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M)
-    pid_m = first_pid_m + ((pid % num_pid_in_group) % group_size_m)
-    pid_n = (pid % num_pid_in_group) // group_size_m
+    pid_group = tl.program_id(axis=0)  # Group ID
+    pid_n = tl.program_id(axis=1)  # N dimension ID
+
+    # Compute the M block ID within this group
+    group_size_m = min(M - pid_group * GROUP_SIZE_M, GROUP_SIZE_M)
+    pid_m_within_group = tl.program_id(axis=2) % group_size_m
+    pid_m = pid_group * GROUP_SIZE_M + pid_m_within_group

     # Create pointers for the first blocks of A and B
     offs_am = (pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)) % M
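For readers tracing the new indexing: the removed block derived both tile coordinates from one flat program id, while the added block reads them from a 3-D launch grid. A standalone sketch (not part of this commit) that mirrors the added index arithmetic in plain Python; map_program_to_tile and the example values are illustrative, and pid_axis2 stands in for tl.program_id(axis=2):

# Standalone sketch: the added 3-D index mapping, rewritten in plain Python
# so the resulting tile coordinates can be inspected outside the kernel.
def map_program_to_tile(pid_group, pid_n, pid_axis2, M, GROUP_SIZE_M):
    # Mirrors the kernel lines above: GROUP_SIZE_M consecutive M-tiles form one group.
    group_size_m = min(M - pid_group * GROUP_SIZE_M, GROUP_SIZE_M)
    pid_m_within_group = pid_axis2 % group_size_m
    pid_m = pid_group * GROUP_SIZE_M + pid_m_within_group
    return pid_m, pid_n

# Example: with GROUP_SIZE_M=4, program (pid_group=2, pid_n=5, axis2=3)
# works on output tile (pid_m=11, pid_n=5).
print(map_program_to_tile(2, 5, 3, M=4096, GROUP_SIZE_M=4))  # -> (11, 5)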
...
...
@@ -153,20 +153,15 @@ def fp8_gemm_group_triton_kernel(
             pid_n * stride_b_scale_n + k_block * stride_b_scale_k
         )

+        # Perform matrix multiplication in FP8
+        res = tl.dot(a, b)
+
         # Load scaling factors for the current block
         a_scale = tl.load(a_scale_ptrs)[:, None]  # [BLOCK_SIZE_M, 1]
         b_scale = tl.load(b_scale_ptrs)

-        # Convert FP8 to FP32 for computation
-        a = a.to(tl.float32)
-        b = b.to(tl.float32)
-
-        # Apply scaling factors to the current block
-        a = a * a_scale
-        b = b * b_scale
-
-        # Accumulate matmul for the current block
-        accumulator += tl.dot(a, b)
+        # Apply scaling factors to the accumulated result
+        accumulator += res * a_scale * b_scale

         # Advance pointers
         a_ptrs += BLOCK_SIZE_K * stride_ak
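The reordering above relies on an algebraic identity: with a_scale broadcast per row (shape [BLOCK_SIZE_M, 1]) and b_scale a single scalar for the block, scaling the dot product of the raw operands equals taking the dot product of pre-scaled operands, which is what allows the dot itself to stay in FP8. A standalone PyTorch check of that identity (not part of the diff; block sizes are arbitrary and the check uses float32 rather than FP8):

# Standalone check: scaling the dot result by a per-row a_scale and a scalar
# b_scale equals scaling the inputs before the dot.
import torch

BLOCK_M, BLOCK_K, BLOCK_N = 16, 32, 8
a = torch.randn(BLOCK_M, BLOCK_K)
b = torch.randn(BLOCK_K, BLOCK_N)
a_scale = torch.rand(BLOCK_M, 1)   # one scale per row of the A block
b_scale = torch.rand(())           # one scale for the whole B block

scaled_inputs = (a * a_scale) @ (b * b_scale)   # old ordering
scaled_output = (a @ b) * a_scale * b_scale     # new ordering
torch.testing.assert_close(scaled_inputs, scaled_output)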
...
...
@@ -183,13 +178,14 @@ def fp8_gemm_group_triton_kernel(
     tl.store(c_ptrs, c, mask=c_mask)


-def fp8_gemm_group_triton(a_tuple, b_tuple, num_groups):
+def fp8_gemm_group_triton(a_tuple, b_tuple, c, num_groups):
     """
     Perform matrix multiplication with FP8 inputs and proper scaling.

     Args:
         a_tuple: Tuple of (quantized_tensor, scale_factors) for input A
         b_tuple: Tuple of (quantized_tensor, scale_factors) for input B
+        c: Output tensor in BF16 format
         num_groups: Number of groups for grouped GEMM

     Returns:
...
...
@@ -199,32 +195,21 @@ def fp8_gemm_group_triton(a_tuple, b_tuple, num_groups):
     a, a_scale = a_tuple
     b, b_scale = b_tuple

     # Check constraints
-    assert a.shape[1] == b.shape[1], "Incompatible dimensions"
     assert a.is_contiguous(), "Matrix A must be contiguous"

     M, K = a.shape
-    N, K_b = b.shape
-    assert K == K_b, f"Incompatible K dimensions: {K} vs {K_b}"
-
-    # Transpose b to match kernel expectations (K,N format)
-    b = b.T.contiguous()
     _, N = b.shape

-    # Allocate output in bfloat16 (not float16)
-    c = torch.empty((M, N), device=a.device, dtype=torch.bfloat16)
-
     # Configure block sizes - must be multiples of 32 for TMA alignment
     BLOCK_SIZE_M = 128
     BLOCK_SIZE_N = 128
     BLOCK_SIZE_K = 128

-    # Prepare scale factors
-    # Ensure scales are in the right format and contiguous
-    a_scale = a_scale.contiguous()
-    b_scale = b_scale.contiguous()
-
-    # 1D launch kernel
-    grid = lambda META: (
-        triton.cdiv(M, META["BLOCK_SIZE_M"]) * triton.cdiv(N, META["BLOCK_SIZE_N"]),
-    )
-    # Calculate K blocks (128 elements per block)
-    K_blocks = triton.cdiv(K, 128)
+    # Calculate grid dimensions
+    num_pid_m = triton.cdiv(M, BLOCK_SIZE_M)
+    num_pid_n = triton.cdiv(N, BLOCK_SIZE_N)
+    num_groups_grid = triton.cdiv(num_pid_m, num_groups)
+
+    # 3D grid launch - (group, n_blocks, m_blocks_per_group)
+    grid = (num_groups_grid, num_pid_n, min(num_groups, num_pid_m))

     fp8_gemm_group_triton_kernel[grid](
         a,
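To make the new launch configuration concrete, here is a standalone recomputation (not code from the file) of the 3-D grid for one of the shapes exercised by calculate_diff below, m=8192, n=2048, num_groups=4, with the 128-wide blocks set above:

# Standalone recomputation of the 3-D launch grid for one benchmark shape.
import math

M, N, num_groups = 8192, 2048, 4
BLOCK_SIZE_M = BLOCK_SIZE_N = 128
num_pid_m = math.ceil(M / BLOCK_SIZE_M)               # 64 M-tiles
num_pid_n = math.ceil(N / BLOCK_SIZE_N)               # 16 N-tiles
num_groups_grid = math.ceil(num_pid_m / num_groups)   # 16 groups along M
grid = (num_groups_grid, num_pid_n, min(num_groups, num_pid_m))

print(grid)                          # (16, 16, 4)
print(grid[0] * grid[1] * grid[2])   # 1024 programs
print(num_pid_m * num_pid_n)         # 1024 output tiles, one per program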
...
...
@@ -245,9 +230,9 @@ def fp8_gemm_group_triton(a_tuple, b_tuple, num_groups):
         1,  # Stride in the K dimension may be 1
         b_scale.stride(0),
         1 if b_scale.dim() > 1 else 0,
-        BLOCK_SIZE_M=128,
-        BLOCK_SIZE_N=128,
-        BLOCK_SIZE_K=128,
+        BLOCK_SIZE_M=BLOCK_SIZE_M,
+        BLOCK_SIZE_N=BLOCK_SIZE_N,
+        BLOCK_SIZE_K=BLOCK_SIZE_K,
         GROUP_SIZE_M=num_groups,
     )
...
...
@@ -264,52 +249,6 @@ def fp8_gemm_group_deepgemm(x_fp8_grouped, y_fp8_grouped, out, m_indices):
     return out


-def get_weight_shapes(tp_size):
-    # cannot TP
-    total = [
-        (512 + 64, 7168),
-        ((128 + 64) * 128, 7168),
-        (128 * (128 + 128), 512),
-        (7168, 16384),
-        (7168, 18432),
-    ]
-    # N can TP
-    n_tp = [
-        (18432 * 2, 7168),
-        ((128 + 64) * 128, 7168),
-        (128 * (128 + 128), 512),
-        (24576, 1536),
-        (4096, 7168),
-    ]
-    # K can TP
-    k_tp = [(7168, 18432), (7168, 16384), (7168, 2048)]
-
-    weight_shapes = []
-    for t in total:
-        weight_shapes.append(t)
-    for n_t in n_tp:
-        new_t = (n_t[0] // tp_size, n_t[1])
-        weight_shapes.append(new_t)
-    for k_t in k_tp:
-        new_t = (k_t[0], k_t[1] // tp_size)
-        weight_shapes.append(new_t)
-    return weight_shapes
-
-
-def create_benchmark_configs(tp_size):
-    configs = []
-    weight_shapes = get_weight_shapes(tp_size)
-    batch_sizes = [2048, 4096]
-    group_sizes = [4, 8]
-
-    for n, k in weight_shapes:
-        for m in batch_sizes:
-            for num_groups in group_sizes:
-                configs.append((m, n, k, num_groups, tp_size))
-    return configs
-
-
 def calculate_diff(m: int, n: int, k: int, num_groups: int):
     print(f"Shape (m={m}, n={n}, k={k}")
     x = torch.randn((m, k), device="cuda", dtype=torch.bfloat16)
...
...
@@ -332,8 +271,16 @@ def calculate_diff(m: int, n: int, k: int, num_groups: int):
     )
     torch.cuda.synchronize()

-    # Quantized x and y
-    out_triton = fp8_gemm_group_triton(x_fp8_flat, y_fp8_flat, num_groups)
+    # Prepare inputs for Triton
+    a, a_scale = x_fp8_flat
+    b, b_scale = y_fp8_flat
+    b = b.T.contiguous()
+    # Ensure scales are in the right format and contiguous
+    a_scale, b_scale = a_scale.contiguous(), b_scale.contiguous()
+    M, _ = a.shape
+    _, N = b.shape
+    c = torch.empty((M, N), device=a.device, dtype=torch.bfloat16)
+
+    out_triton = fp8_gemm_group_triton((a, a_scale), (b, b_scale), c, num_groups)
     torch.cuda.synchronize()

     diff_torch_deepgemm = torch.abs(out_torch - out_deepgemm).mean().item()
...
...
@@ -369,6 +316,52 @@ def calculate_diff(m: int, n: int, k: int, num_groups: int):
         )


+def get_weight_shapes(tp_size):
+    # cannot TP
+    total = [
+        (512 + 64, 7168),
+        ((128 + 64) * 128, 7168),
+        (128 * (128 + 128), 512),
+        (7168, 16384),
+        (7168, 18432),
+    ]
+    # N can TP
+    n_tp = [
+        (18432 * 2, 7168),
+        ((128 + 64) * 128, 7168),
+        (128 * (128 + 128), 512),
+        (24576, 1536),
+        (4096, 7168),
+    ]
+    # K can TP
+    k_tp = [(7168, 18432), (7168, 16384), (7168, 2048)]
+
+    weight_shapes = []
+    for t in total:
+        weight_shapes.append(t)
+    for n_t in n_tp:
+        new_t = (n_t[0] // tp_size, n_t[1])
+        weight_shapes.append(new_t)
+    for k_t in k_tp:
+        new_t = (k_t[0], k_t[1] // tp_size)
+        weight_shapes.append(new_t)
+    return weight_shapes
+
+
+def create_benchmark_configs(tp_size):
+    configs = []
+    weight_shapes = get_weight_shapes(tp_size)
+    batch_sizes = [2048, 4096]
+    group_sizes = [4, 8]
+
+    for n, k in weight_shapes:
+        for m in batch_sizes:
+            for num_groups in group_sizes:
+                configs.append((m, n, k, num_groups, tp_size))
+    return configs
+
+
 def get_benchmark(tp_size):
     all_configs = create_benchmark_configs(tp_size)
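As a quick sanity check on the relocated helpers: get_weight_shapes returns 5 + 5 + 3 = 13 (n, k) shapes regardless of tp_size, so create_benchmark_configs yields 13 shapes x 2 batch sizes x 2 group sizes = 52 configurations. A standalone snippet (not part of the diff; it assumes the two functions above are available in the session):

# Standalone sanity check of the benchmark-config helpers above.
configs = create_benchmark_configs(tp_size=8)
assert len(configs) == 13 * 2 * 2 == 52
m, n, k, num_groups, tp_size = configs[0]
print(m, n, k, num_groups, tp_size)   # 2048 576 7168 4 8 for the first weight shape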
...
...
@@ -416,10 +409,21 @@ def get_benchmark(tp_size):
                 quantiles=quantiles,
             )
         elif provider == "triton":
+            # Prepare inputs for Triton
+            # We did it outside of the lambda function to make it fair comparison like deepgemm
+            a, a_scale = x_fp8_flat
+            b, b_scale = y_fp8_flat
+            b = b.T.contiguous()
+            # Ensure scales are in the right format and contiguous
+            a_scale, b_scale = a_scale.contiguous(), b_scale.contiguous()
+            M, _ = a.shape
+            _, N = b.shape
+            c = torch.empty((M, N), device=a.device, dtype=torch.bfloat16)
+
             ms, min_ms, max_ms = triton.testing.do_bench(
                 lambda: fp8_gemm_group_triton(
-                    x_fp8_flat,
-                    y_fp8_flat,
+                    (a, a_scale),
+                    (b, b_scale),
+                    c,
                     num_groups,
                 ),
                 quantiles=quantiles,
...
...
@@ -429,13 +433,8 @@ def get_benchmark(tp_size):
         flops = 2 * m * n * k  # multiply-adds
         tflops = flops / (ms * 1e-3) / 1e12

-        # Print shape-specific results with TFLOPS
-        print(f"Time: {ms:.2f} ms, TFLOPS: {tflops:.2f}")
-
-        return (
-            ms,
-            max_ms,
-            min_ms,
-        )  # return in seconds for consistency with triton benchmark
+        print(f"Time: {ms*1000:.2f} ms, TFLOPS: {tflops:.2f}")
+
+        return ms * 1000, max_ms * 1000, min_ms * 1000  # convert to ms

    return benchmark
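To make the throughput reporting concrete, the unchanged formula above is tflops = 2*m*n*k / (ms * 1e-3) / 1e12, with ms in milliseconds as the ms * 1e-3 factor assumes. A standalone worked example (the timing value is made up):

# Standalone example of the TFLOPS arithmetic used above (hypothetical timing).
m, n, k = 8192, 2048, 7168
ms = 0.5                                  # hypothetical kernel time in milliseconds
flops = 2 * m * n * k                     # 240,518,168,576 floating point operations
tflops = flops / (ms * 1e-3) / 1e12
print(f"{tflops:.1f} TFLOPS")             # ~481.0 TFLOPS for this made-up timing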
...
...
@@ -478,6 +477,7 @@ if __name__ == "__main__":
     calculate_diff(8192, 2048, 7168, 4)
     calculate_diff(4096, 7168, 4096, 8)
     calculate_diff(4096, 2048, 7168, 8)
+    calculate_diff(4096, 576, 7168, 8)

     # Get the benchmark function with the specified tp_size
     benchmark = get_benchmark(args.tp_size)
...
...