sglang commit 0194948f (Unverified)
Authored Mar 02, 2025 by Stefan He; committed by GitHub on Mar 02, 2025
Optimize Triton Kernel of Group GEMM in DeepGEMM Benchmark (#4014)
Parent: b4d34cd3

Showing 1 changed file with 102 additions and 102 deletions:
benchmark/kernels/deepseek/benchmark_deepgemm_fp8_group_gemm.py.py (+102, -102)
@@ -115,17 +115,17 @@ def fp8_gemm_group_triton_kernel(
 ):
     """Kernel for computing the matmul C = A x B with FP8 inputs and scaling factors.
     A has shape (M, K), B has shape (K, N) and C has shape (M, N)
+    Note: Block sizes must be multiples of 32 for optimal TMA performance.
     """
     # Map program ids to the block of C it should compute
-    pid = tl.program_id(axis=0)
-    num_pid_m = tl.cdiv(M, BLOCK_SIZE_M)
-    num_pid_n = tl.cdiv(N, BLOCK_SIZE_N)
-    num_pid_in_group = GROUP_SIZE_M * num_pid_n
-    group_id = pid // num_pid_in_group
-    first_pid_m = group_id * GROUP_SIZE_M
-    group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M)
-    pid_m = first_pid_m + ((pid % num_pid_in_group) % group_size_m)
-    pid_n = (pid % num_pid_in_group) // group_size_m
+    pid_group = tl.program_id(axis=0)  # Group ID
+    pid_n = tl.program_id(axis=1)  # N dimension ID
+    # Compute the M block ID within this group
+    group_size_m = min(M - pid_group * GROUP_SIZE_M, GROUP_SIZE_M)
+    pid_m_within_group = tl.program_id(axis=2) % group_size_m
+    pid_m = pid_group * GROUP_SIZE_M + pid_m_within_group

     # Create pointers for the first blocks of A and B
     offs_am = (pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)) % M
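The removed index math derived every tile coordinate from a single 1D program id, while the new code reads the group id and the N-block id from separate launch axes and recovers the M-block id from the third axis. As a sanity check on that mapping, here is a minimal pure-Python sketch (not part of the commit; the sizes and the triple loop standing in for the GPU grid are illustrative assumptions) showing that it still produces one program per (pid_m, pid_n) output tile:

```python
from math import ceil

# Illustrative sizes only; in the benchmark GROUP_SIZE_M is set to num_groups.
M, N = 512, 384
BLOCK_SIZE_M, BLOCK_SIZE_N, GROUP_SIZE_M = 128, 128, 2

num_pid_m = ceil(M / BLOCK_SIZE_M)                # M-blocks (4 here)
num_pid_n = ceil(N / BLOCK_SIZE_N)                # N-blocks (3 here)
num_groups_grid = ceil(num_pid_m / GROUP_SIZE_M)  # launch axis 0

tiles = []
# The triple loop stands in for the 3D grid (group, n_blocks, m_blocks_per_group).
for pid_group in range(num_groups_grid):
    for pid_n in range(num_pid_n):
        for axis2 in range(min(GROUP_SIZE_M, num_pid_m)):
            # Same arithmetic as the new kernel prologue.
            group_size_m = min(M - pid_group * GROUP_SIZE_M, GROUP_SIZE_M)
            pid_m_within_group = axis2 % group_size_m
            pid_m = pid_group * GROUP_SIZE_M + pid_m_within_group
            tiles.append((pid_m, pid_n))

expected = {(i, j) for i in range(num_pid_m) for j in range(num_pid_n)}
assert set(tiles) == expected and len(tiles) == len(expected)
print(f"{len(tiles)} programs, one per output tile")
```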
@@ -153,20 +153,15 @@ def fp8_gemm_group_triton_kernel(
             pid_n * stride_b_scale_n + k_block * stride_b_scale_k
         )
+        # Perform matrix multiplication in FP8
+        res = tl.dot(a, b)
         # Load scaling factors for the current block
         a_scale = tl.load(a_scale_ptrs)[:, None]  # [BLOCK_SIZE_M, 1]
         b_scale = tl.load(b_scale_ptrs)
-        # Convert FP8 to FP32 for computation
-        a = a.to(tl.float32)
-        b = b.to(tl.float32)
-        # Apply scaling factors to the current block
-        a = a * a_scale
-        b = b * b_scale
-        # Accumulate matmul for the current block
-        accumulator += tl.dot(a, b)
+        # Apply scaling factors to the accumulated result
+        accumulator += res * a_scale * b_scale
         # Advance pointers
         a_ptrs += BLOCK_SIZE_K * stride_ak
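The loop body now performs the dot product on the raw FP8 blocks and applies both scales to the result, instead of upcasting and scaling the operands first. This relies on the identity (A * s_a) @ (B * s_b) == (A @ B) * s_a * s_b for a per-row scale s_a and a single per-block scale s_b. A small torch sketch of that identity (CPU float32 stand-ins for the FP8 blocks, illustrative only; actual FP8 rounding can make the two paths differ slightly):

```python
import torch

# Float32 stand-ins for one FP8 K-block of the kernel.
BLOCK_M, BLOCK_N, BLOCK_K = 4, 3, 8
a = torch.randn(BLOCK_M, BLOCK_K)
b = torch.randn(BLOCK_K, BLOCK_N)
a_scale = torch.rand(BLOCK_M, 1)  # per-row scale, like tl.load(a_scale_ptrs)[:, None]
b_scale = torch.rand(1)           # single scale for this B block

# Old formulation: scale the operands, then multiply.
old = (a * a_scale) @ (b * b_scale)
# New formulation: multiply first, then scale the result.
new = (a @ b) * a_scale * b_scale

torch.testing.assert_close(old, new)
print("scaling after the dot matches scaling before the dot")
```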
@@ -183,13 +178,14 @@ def fp8_gemm_group_triton_kernel(
     tl.store(c_ptrs, c, mask=c_mask)


-def fp8_gemm_group_triton(a_tuple, b_tuple, num_groups):
+def fp8_gemm_group_triton(a_tuple, b_tuple, c, num_groups):
     """
     Perform matrix multiplication with FP8 inputs and proper scaling.

     Args:
         a_tuple: Tuple of (quantized_tensor, scale_factors) for input A
         b_tuple: Tuple of (quantized_tensor, scale_factors) for input B
+        c: Output tensor in BF16 format
         num_groups: Number of groups for grouped GEMM

     Returns:
@@ -199,32 +195,21 @@ def fp8_gemm_group_triton(a_tuple, b_tuple, num_groups):
     a, a_scale = a_tuple
     b, b_scale = b_tuple
+    # Check constraints
+    assert a.shape[1] == b.shape[1], "Incompatible dimensions"
+    assert a.is_contiguous(), "Matrix A must be contiguous"
     M, K = a.shape
-    N, K_b = b.shape
-    assert K == K_b, f"Incompatible K dimensions: {K} vs {K_b}"
-    # Transpose b to match kernel expectations (K,N format)
-    b = b.T.contiguous()
-    # Allocate output in bfloat16 (not float16)
-    c = torch.empty((M, N), device=a.device, dtype=torch.bfloat16)
-    # Prepare scale factors
-    # Ensure scales are in the right format and contiguous
-    a_scale = a_scale.contiguous()
-    b_scale = b_scale.contiguous()
-    # 1D launch kernel
-    grid = lambda META: (
-        triton.cdiv(M, META["BLOCK_SIZE_M"]) * triton.cdiv(N, META["BLOCK_SIZE_N"]),
-    )
-    # Calculate K blocks (128 elements per block)
-    K_blocks = triton.cdiv(K, 128)
+    _, N = b.shape
+    # Configure block sizes - must be multiples of 32 for TMA alignment
+    BLOCK_SIZE_M = 128
+    BLOCK_SIZE_N = 128
+    BLOCK_SIZE_K = 128
+    # Calculate grid dimensions
+    num_pid_m = triton.cdiv(M, BLOCK_SIZE_M)
+    num_pid_n = triton.cdiv(N, BLOCK_SIZE_N)
+    num_groups_grid = triton.cdiv(num_pid_m, num_groups)
+    # 3D grid launch - (group, n_blocks, m_blocks_per_group)
+    grid = (num_groups_grid, num_pid_n, min(num_groups, num_pid_m))
     fp8_gemm_group_triton_kernel[grid](
         a,
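On the host side, the meta-parameterized 1D grid is replaced by a fixed 3D grid of (groups, N-blocks, M-blocks per group). When num_groups divides the number of M-blocks evenly, both launches schedule the same total number of programs; a short sketch with hypothetical shapes (illustration only, not part of the commit):

```python
from math import ceil

# Hypothetical benchmark shape, for illustration only.
M, N, num_groups = 2048, 2048, 4
BLOCK_SIZE_M = BLOCK_SIZE_N = 128

num_pid_m = ceil(M / BLOCK_SIZE_M)  # 16 M-blocks
num_pid_n = ceil(N / BLOCK_SIZE_N)  # 16 N-blocks

# Old launch: one program per output tile, flattened into a 1D grid.
old_grid = (num_pid_m * num_pid_n,)

# New launch: (group, n_blocks, m_blocks_per_group), as in the code above.
new_grid = (ceil(num_pid_m / num_groups), num_pid_n, min(num_groups, num_pid_m))

print(old_grid, new_grid)  # (256,) (4, 16, 4)
assert old_grid[0] == new_grid[0] * new_grid[1] * new_grid[2]
```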
@@ -245,9 +230,9 @@ def fp8_gemm_group_triton(a_tuple, b_tuple, num_groups):
         1,  # Stride in the K dimension may be 1
         b_scale.stride(0),
         1 if b_scale.dim() > 1 else 0,
-        BLOCK_SIZE_M=128,
-        BLOCK_SIZE_N=128,
-        BLOCK_SIZE_K=128,
+        BLOCK_SIZE_M=BLOCK_SIZE_M,
+        BLOCK_SIZE_N=BLOCK_SIZE_N,
+        BLOCK_SIZE_K=BLOCK_SIZE_K,
         GROUP_SIZE_M=num_groups,
     )
@@ -264,52 +249,6 @@ def fp8_gemm_group_deepgemm(x_fp8_grouped, y_fp8_grouped, out, m_indices):
     return out


-def get_weight_shapes(tp_size):
-    # cannot TP
-    total = [
-        (512 + 64, 7168),
-        ((128 + 64) * 128, 7168),
-        (128 * (128 + 128), 512),
-        (7168, 16384),
-        (7168, 18432),
-    ]
-    # N can TP
-    n_tp = [
-        (18432 * 2, 7168),
-        ((128 + 64) * 128, 7168),
-        (128 * (128 + 128), 512),
-        (24576, 1536),
-        (4096, 7168),
-    ]
-    # K can TP
-    k_tp = [(7168, 18432), (7168, 16384), (7168, 2048)]
-
-    weight_shapes = []
-    for t in total:
-        weight_shapes.append(t)
-    for n_t in n_tp:
-        new_t = (n_t[0] // tp_size, n_t[1])
-        weight_shapes.append(new_t)
-    for k_t in k_tp:
-        new_t = (k_t[0], k_t[1] // tp_size)
-        weight_shapes.append(new_t)
-
-    return weight_shapes
-
-
-def create_benchmark_configs(tp_size):
-    configs = []
-    weight_shapes = get_weight_shapes(tp_size)
-    batch_sizes = [2048, 4096]
-    group_sizes = [4, 8]
-
-    for n, k in weight_shapes:
-        for m in batch_sizes:
-            for num_groups in group_sizes:
-                configs.append((m, n, k, num_groups, tp_size))
-
-    return configs
-
-
 def calculate_diff(m: int, n: int, k: int, num_groups: int):
     print(f"Shape (m={m}, n={n}, k={k}")
     x = torch.randn((m, k), device="cuda", dtype=torch.bfloat16)
@@ -332,8 +271,16 @@ def calculate_diff(m: int, n: int, k: int, num_groups: int):
     )
     torch.cuda.synchronize()

-    # Quantized x and y
-    out_triton = fp8_gemm_group_triton(x_fp8_flat, y_fp8_flat, num_groups)
+    # Prepare inputs for Triton
+    a, a_scale = x_fp8_flat
+    b, b_scale = y_fp8_flat
+    b = b.T.contiguous()
+    # Ensure scales are in the right format and contiguous
+    a_scale, b_scale = a_scale.contiguous(), b_scale.contiguous()
+    M, _ = a.shape
+    _, N = b.shape
+    c = torch.empty((M, N), device=a.device, dtype=torch.bfloat16)
+    out_triton = fp8_gemm_group_triton((a, a_scale), (b, b_scale), c, num_groups)
     torch.cuda.synchronize()

     diff_torch_deepgemm = torch.abs(out_torch - out_deepgemm).mean().item()
@@ -369,6 +316,52 @@ def calculate_diff(m: int, n: int, k: int, num_groups: int):
     )


+def get_weight_shapes(tp_size):
+    # cannot TP
+    total = [
+        (512 + 64, 7168),
+        ((128 + 64) * 128, 7168),
+        (128 * (128 + 128), 512),
+        (7168, 16384),
+        (7168, 18432),
+    ]
+    # N can TP
+    n_tp = [
+        (18432 * 2, 7168),
+        ((128 + 64) * 128, 7168),
+        (128 * (128 + 128), 512),
+        (24576, 1536),
+        (4096, 7168),
+    ]
+    # K can TP
+    k_tp = [(7168, 18432), (7168, 16384), (7168, 2048)]
+
+    weight_shapes = []
+    for t in total:
+        weight_shapes.append(t)
+    for n_t in n_tp:
+        new_t = (n_t[0] // tp_size, n_t[1])
+        weight_shapes.append(new_t)
+    for k_t in k_tp:
+        new_t = (k_t[0], k_t[1] // tp_size)
+        weight_shapes.append(new_t)
+
+    return weight_shapes
+
+
+def create_benchmark_configs(tp_size):
+    configs = []
+    weight_shapes = get_weight_shapes(tp_size)
+    batch_sizes = [2048, 4096]
+    group_sizes = [4, 8]
+
+    for n, k in weight_shapes:
+        for m in batch_sizes:
+            for num_groups in group_sizes:
+                configs.append((m, n, k, num_groups, tp_size))
+
+    return configs
+
+
 def get_benchmark(tp_size):
     all_configs = create_benchmark_configs(tp_size)
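The relocated get_weight_shapes only shards the tensor-parallel dimension: the first (N) dimension for n_tp entries and the second (K) dimension for k_tp entries, while the total list is kept as-is. A tiny illustration with a hypothetical tp_size of 2:

```python
tp_size = 2  # hypothetical tensor-parallel degree

# N-parallel shapes split their first (N) dimension ...
n_t = (18432 * 2, 7168)
print((n_t[0] // tp_size, n_t[1]))  # -> (18432, 7168)

# ... while K-parallel shapes split their second (K) dimension.
k_t = (7168, 18432)
print((k_t[0], k_t[1] // tp_size))  # -> (7168, 9216)
```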
@@ -416,10 +409,21 @@ def get_benchmark(tp_size):
                 quantiles=quantiles,
             )
         elif provider == "triton":
+            # Prepare inputs for Triton
+            # We did it outside of the lambda function to make it fair comparison like deepgemm
+            a, a_scale = x_fp8_flat
+            b, b_scale = y_fp8_flat
+            b = b.T.contiguous()
+            # Ensure scales are in the right format and contiguous
+            a_scale, b_scale = a_scale.contiguous(), b_scale.contiguous()
+            M, _ = a.shape
+            _, N = b.shape
+            c = torch.empty((M, N), device=a.device, dtype=torch.bfloat16)
             ms, min_ms, max_ms = triton.testing.do_bench(
                 lambda: fp8_gemm_group_triton(
-                    x_fp8_flat,
-                    y_fp8_flat,
+                    (a, a_scale),
+                    (b, b_scale),
+                    c,
                     num_groups,
                 ),
                 quantiles=quantiles,
@@ -429,13 +433,8 @@ def get_benchmark(tp_size):
         flops = 2 * m * n * k  # multiply-adds
         tflops = flops / (ms * 1e-3) / 1e12
-
-        # Print shape-specific results with TFLOPS
-        print(f"Time: {ms * 1000:.2f} ms, TFLOPS: {tflops:.2f}")
-
-        return ms * 1000, max_ms * 1000, min_ms * 1000  # convert to ms
+        print(f"Time: {ms:.2f} ms, TFLOPS: {tflops:.2f}")
+        return (
+            ms,
+            max_ms,
+            min_ms,
+        )  # return in seconds for consistency with triton benchmark

     return benchmark
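The reported TFLOPS figure is the usual 2 * m * n * k operation count divided by the measured time. A quick worked example with made-up numbers (illustration only, not part of the commit):

```python
# Hypothetical measurement, for illustration only.
m, n, k = 4096, 2048, 7168
ms = 1.25  # elapsed time in milliseconds, as reported by triton.testing.do_bench

flops = 2 * m * n * k                # one multiply and one add per (m, n, k) triple
tflops = flops / (ms * 1e-3) / 1e12  # ms -> s, then FLOP/s -> TFLOP/s

print(f"Time: {ms:.2f} ms, TFLOPS: {tflops:.2f}")  # ~96 TFLOPS for these numbers
```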
@@ -478,6 +477,7 @@ if __name__ == "__main__":
     calculate_diff(8192, 2048, 7168, 4)
     calculate_diff(4096, 7168, 4096, 8)
     calculate_diff(4096, 2048, 7168, 8)
+    calculate_diff(4096, 576, 7168, 8)

     # Get the benchmark function with the specified tp_size
     benchmark = get_benchmark(args.tp_size)