Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
3bbe11cc
Unverified
Commit
3bbe11cc
authored
Aug 21, 2025
by
Michael Goin
Committed by
GitHub
Aug 21, 2025
Browse files
[Perf] Small optimizations for silu_mul_fp8_quant_deep_gemm (#23265)
Signed-off-by:
mgoin
<
mgoin64@gmail.com
>
parent
c5041f89
Changes
3
Show whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
107 additions
and
32 deletions
+107
-32
benchmarks/kernels/benchmark_silu_mul_fp8_quant.py
benchmarks/kernels/benchmark_silu_mul_fp8_quant.py
+77
-0
tests/kernels/moe/test_silu_mul_fp8_quant_deep_gemm.py
tests/kernels/moe/test_silu_mul_fp8_quant_deep_gemm.py
+2
-2
vllm/model_executor/layers/fused_moe/batched_deep_gemm_moe.py
.../model_executor/layers/fused_moe/batched_deep_gemm_moe.py
+28
-30
No files found.
benchmarks/kernels/benchmark_silu_mul_fp8_quant.py
0 → 100644
View file @
3bbe11cc
#!/usr/bin/env python3
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import
time
import
torch
from
vllm.model_executor.layers.fused_moe.batched_deep_gemm_moe
import
(
silu_mul_fp8_quant_deep_gemm
,
)
from
vllm.platforms
import
current_platform
def benchmark(E, T, H, G=128, runs=50, warmup=10):
    """Time ``silu_mul_fp8_quant_deep_gemm`` on a random ``(E, T, 2 * H)`` input.

    A bfloat16 input and a random per-expert token count in ``[T // 2, T)``
    are created on the CUDA device, the kernel is warmed up, and then ``runs``
    back-to-back calls are timed with the host wall clock between two device
    synchronizations.

    Args:
        E: Size of the first input dimension (one token count per entry).
        T: Size of the second input dimension (token slots per expert).
        H: Half the size of the last input dimension.
        G: Value forwarded to the kernel as ``group_size``.
        runs: Number of timed calls. Must be positive.
        warmup: Number of untimed calls issued before timing starts.

    Returns:
        A tuple ``(avg_time_ms, gflops, memory_bw)`` where ``avg_time_ms`` is
        the mean time per call in milliseconds, and ``gflops`` / ``memory_bw``
        (GB/s) are estimates that only count the valid tokens.

    Raises:
        ValueError: If ``runs`` is not positive.
    """
    # Reject bad run counts up front: runs == 0 would otherwise fail with a
    # ZeroDivisionError only after the GPU work was done, and a negative value
    # would silently produce negative throughput numbers.
    if runs <= 0:
        raise ValueError(f"runs must be a positive integer, got {runs!r}")

    current_platform.seed_everything(42)

    y = torch.randn((E, T, 2 * H), dtype=torch.bfloat16, device="cuda")
    tokens_per_expert = torch.randint(
        T // 2, T, size=(E,), dtype=torch.int32, device="cuda"
    )

    # Warmup
    for _ in range(warmup):
        silu_mul_fp8_quant_deep_gemm(y, tokens_per_expert, group_size=G)
    # Drain all queued work so none of it leaks into the timed region.
    torch.cuda.synchronize()

    # Benchmark
    start = time.perf_counter()
    for _ in range(runs):
        silu_mul_fp8_quant_deep_gemm(y, tokens_per_expert, group_size=G)
    # Kernel launches are asynchronous; wait for completion before reading
    # the clock.
    torch.cuda.synchronize()
    avg_time = (time.perf_counter() - start) / runs * 1000
    avg_seconds = avg_time / 1000

    # Calculate actual work done (only count valid tokens)
    actual_tokens = tokens_per_expert.sum().item()
    actual_elements = actual_tokens * H

    # GFLOPS: per output element we count exp + 3 muls + 1 div + quantization
    # ops, approximated as 8 ops.
    ops_per_element = 8
    total_ops = actual_elements * ops_per_element
    gflops = total_ops / avg_seconds / 1e9

    # Memory bandwidth: bfloat16 inputs (2 bytes), fp8 output (1 byte),
    # scales (4 bytes)
    input_bytes = actual_tokens * 2 * H * 2  # 2*H bfloat16 inputs
    output_bytes = actual_tokens * H * 1  # H fp8 outputs
    scale_bytes = actual_tokens * (H // G) * 4  # scales in float32
    total_bytes = input_bytes + output_bytes + scale_bytes
    memory_bw = total_bytes / avg_seconds / 1e9

    return avg_time, gflops, memory_bw
# (E, T, H) shapes to benchmark: number of experts, token slots per expert,
# and half the size of the input's last dimension.
configs = [
    (8, 32, 1024),
    (16, 64, 2048),
    (32, 128, 4096),
    # DeepSeekV3 Configs
    (256, 16, 7168),
    (256, 32, 7168),
    (256, 64, 7168),
    (256, 128, 7168),
    (256, 256, 7168),
    (256, 512, 7168),
    (256, 1024, 7168),
]

print(f"GPU: {torch.cuda.get_device_name()}")
print(f"{'Config':<20} {'Time(ms)':<10} {'GFLOPS':<10} {'GB/s':<10}")
print("-" * 50)

for E, T, H in configs:
    label = f"E={E:3d},T={T:4d},H={H:4d}"
    try:
        time_ms, gflops, gbps = benchmark(E, T, H)
    except Exception as exc:
        # Best effort: one failing shape must not abort the whole sweep, but
        # report why it failed so an OOM can be told apart from an
        # unsupported shape or a missing device.
        print(f"{label} FAILED ({type(exc).__name__}: {exc})")
    else:
        print(f"{label} {time_ms:8.3f} {gflops:8.1f} {gbps:8.1f}")
tests/kernels/moe/test_silu_mul_fp8_quant_deep_gemm.py
View file @
3bbe11cc
...
...
@@ -24,7 +24,7 @@ def test_silu_mul_fp8_quant_deep_gemm(E, T, H, group_size, seed):
current_platform
.
seed_everything
(
seed
)
# Input tensor of shape (E, T, 2*H)
y
=
torch
.
randn
((
E
,
T
,
2
*
H
),
dtype
=
torch
.
float
32
,
device
=
"cuda"
)
y
=
torch
.
randn
((
E
,
T
,
2
*
H
),
dtype
=
torch
.
b
float
16
,
device
=
"cuda"
)
tokens_per_expert
=
torch
.
randint
(
low
=
0
,
high
=
T
,
...
...
@@ -74,7 +74,7 @@ def test_silu_mul_fp8_quant_deep_gemm(E, T, H, group_size, seed):
y_se
=
y_s
[
e
]
y_qe
=
y_q
[
e
]
torch
.
testing
.
assert_close
(
y_se
[:
nt
],
ref_s
[:
nt
])
torch
.
testing
.
assert_close
(
y_se
[:
nt
],
ref_s
[:
nt
]
,
atol
=
1e-4
,
rtol
=
1e-2
)
torch
.
testing
.
assert_close
(
y_qe
[:
nt
].
to
(
torch
.
float32
),
ref_q
[:
nt
].
to
(
torch
.
float32
),
...
...
vllm/model_executor/layers/fused_moe/batched_deep_gemm_moe.py
View file @
3bbe11cc
...
...
@@ -70,53 +70,51 @@ def _silu_mul_fp8_quant_deep_gemm(
# number of valid tokens for this expert
n_tokens
=
tl
.
load
(
counts_ptr
+
e
*
stride_counts_e
).
to
(
tl
.
int64
)
cols
=
tl
.
arange
(
0
,
BLOCK
)
cols
=
cols
.
to
(
tl
.
int64
)
mask_h
=
cols
<
BLOCK
cols
=
tl
.
arange
(
0
,
BLOCK
).
to
(
tl
.
int64
)
mask
=
cols
<
BLOCK
base_input_offset
=
e
*
stride_i_e
+
g
*
GROUP_SIZE
*
stride_i_h
base_gate_offset
=
base_input_offset
+
cols
*
stride_i_h
base_up_offset
=
base_input_offset
+
H
*
stride_i_h
+
cols
*
stride_i_h
base_yq_offset
=
(
e
*
stride_yq_e
+
g
*
GROUP_SIZE
*
stride_yq_h
+
cols
*
stride_yq_h
)
base_ys_offset
=
e
*
stride_ys_e
+
g
*
stride_ys_g
for
t
in
tl
.
range
(
0
,
n_tokens
,
num_stages
=
NUM_STAGES
):
base_i_offset
=
(
e
*
stride_i_e
+
t
*
stride_i_t
+
g
*
GROUP_SIZE
*
stride_i_h
)
base_yq_offset
=
(
e
*
stride_yq_e
+
t
*
stride_yq_t
+
g
*
GROUP_SIZE
*
stride_yq_h
)
base_ys_offset
=
e
*
stride_ys_e
+
t
*
stride_ys_t
+
g
*
stride_ys_g
mask
=
mask_h
x
=
tl
.
load
(
input_ptr
+
base_i_offset
+
cols
*
stride_i_h
,
gate
=
tl
.
load
(
input_ptr
+
base_gate_offset
+
t
*
stride_i_t
,
mask
=
mask
,
other
=
0.0
).
to
(
tl
.
float32
)
y2
=
tl
.
load
(
input_ptr
+
base_i_offset
+
H
*
stride_i_h
+
cols
*
stride_i_h
,
up
=
tl
.
load
(
input_ptr
+
base_up_offset
+
t
*
stride_i_t
,
mask
=
mask
,
other
=
0.0
)
.
to
(
tl
.
float32
)
other
=
0.0
)
x
=
x
*
(
1.0
/
(
1.0
+
tl
.
exp
(
-
x
)))
y
=
x
*
y2
gate
=
gate
*
(
1.0
/
(
1.0
+
tl
.
exp
(
-
gate
)))
y
=
gate
*
up
y_s
=
tl
.
maximum
(
tl
.
max
(
tl
.
abs
(
y
)),
eps
)
/
fp8_max
if
use_ue8m0
:
y_s
=
tl
.
exp2
(
tl
.
ceil
(
tl
.
log2
(
y_s
)))
_absmax
=
tl
.
maximum
(
tl
.
max
(
tl
.
abs
(
y
)),
eps
)
scale_raw
=
_absmax
/
fp8_max
y_s
=
tl
.
math
.
exp2
(
tl
.
ceil
(
tl
.
log2
(
scale_raw
)))
if
use_ue8m0
else
scale_raw
y_q
=
tl
.
clamp
(
y
/
y_s
,
fp8_min
,
fp8_max
).
to
(
y_q_ptr
.
dtype
.
element_ty
)
tl
.
store
(
y_q_ptr
+
base_yq_offset
+
cols
*
stride_yq_
h
,
y_q
,
mask
=
mask
)
tl
.
store
(
y_s_ptr
+
base_ys_offset
,
y_s
)
tl
.
store
(
y_q_ptr
+
base_yq_offset
+
t
*
stride_yq_
t
,
y_q
,
mask
=
mask
)
tl
.
store
(
y_s_ptr
+
base_ys_offset
+
t
*
stride_ys_t
,
y_s
)
def
silu_mul_fp8_quant_deep_gemm
(
y
:
torch
.
Tensor
,
# (E, T, 2*H)
float32
y
:
torch
.
Tensor
,
# (E, T, 2*H)
tokens_per_expert
:
torch
.
Tensor
,
# (E,) number of valid tokens per expert
group_size
:
int
=
128
,
eps
:
float
=
1e-10
,
):
)
->
tuple
[
torch
.
Tensor
,
torch
.
Tensor
]
:
"""Quantize silu(y[..., :H]) * y[..., H:] to FP8 with group per-token scales
y has shape (E, T, 2*H). The first half of the last dimension is
silu-activated, multiplied by the second half, then quantized into FP8.
Returns `(y_q, y_s)` where
* `y_q`
is the
FP8 tensor
of
shape
`
(E, T, H)
`
, same layout as
`
y[..., :H]
`.
* `y_s`
has
shape
`
(E, T, H // group_size)
` and
strides
`
(T*G, 1, T)
`
* `y_q`
:
FP8 tensor
,
shape (E, T, H), same layout as y[..., :H]
* `y_s`
: FP32 tensor,
shape (E, T, H // group_size)
,
strides (T*G, 1, T)
"""
assert
y
.
ndim
==
3
,
"y must be (E, T, 2*H)"
E
,
T
,
H2
=
y
.
shape
...
...
@@ -148,7 +146,7 @@ def silu_mul_fp8_quant_deep_gemm(
stride_cnt_e
=
tokens_per_expert
.
stride
()[
0
]
#
s
tatic grid over experts and H-groups.
#
S
tatic grid over experts and H-groups.
# A loop inside the kernel handles the token dim
grid
=
(
E
*
G
,
)
...
...
@@ -178,7 +176,7 @@ def silu_mul_fp8_quant_deep_gemm(
fp8_max
,
is_blackwell_deep_gemm_e8m0_used
(),
BLOCK
=
group_size
,
NUM_STAGES
=
8
,
NUM_STAGES
=
4
,
num_warps
=
1
,
)
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment