Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
98229db2
Unverified
Commit
98229db2
authored
Sep 13, 2025
by
Elvir Crnčević
Committed by
GitHub
Sep 13, 2025
Browse files
[Kernels][DP/EP] Optimize Silu Kernel for R1 (#24054)
Signed-off-by:
elvircrn
<
elvircrn@gmail.com
>
parent
dbeee384
Changes
6
Show whitespace changes
Inline
Side-by-side
Showing
6 changed files
with
1272 additions
and
131 deletions
+1272
-131
benchmarks/kernels/benchmark_silu_mul_fp8_quant.py
benchmarks/kernels/benchmark_silu_mul_fp8_quant.py
+636
-38
csrc/ops.h
csrc/ops.h
+6
-0
csrc/quantization/activation_kernels.cu
csrc/quantization/activation_kernels.cu
+465
-0
csrc/torch_bindings.cpp
csrc/torch_bindings.cpp
+7
-0
tests/kernels/moe/test_silu_mul_fp8_quant_deep_gemm.py
tests/kernels/moe/test_silu_mul_fp8_quant_deep_gemm.py
+71
-33
vllm/model_executor/layers/fused_moe/batched_deep_gemm_moe.py
.../model_executor/layers/fused_moe/batched_deep_gemm_moe.py
+87
-60
No files found.
benchmarks/kernels/benchmark_silu_mul_fp8_quant.py
View file @
98229db2
#!/usr/bin/env python3
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import
tim
e
from
collections.abc
import
Callabl
e
import
matplotlib.pyplot
as
plt
import
numpy
as
np
import
torch
from
vllm.model_executor.layers.fused_moe.batched_deep_gemm_moe
import
(
silu_mul_fp8_quant_deep_gemm
,
silu_mul_fp8_quant_deep_gemm
_cuda
,
)
from
vllm.platforms
import
current_platform
from
vllm.triton_utils
import
tl
,
triton
from
vllm.utils.deep_gemm
import
is_deep_gemm_e8m0_used
def
benchmark
(
E
,
T
,
H
,
G
=
128
,
runs
=
50
):
current_platform
.
seed_everything
(
42
)
y
=
torch
.
randn
((
E
,
T
,
2
*
H
),
dtype
=
torch
.
bfloat16
,
device
=
"cuda"
)
tokens_per_expert
=
torch
.
randint
(
T
//
2
,
T
,
size
=
(
E
,),
dtype
=
torch
.
int32
,
device
=
"cuda"
@
triton
.
jit
def
_silu_mul_fp8_quant_deep_gemm
(
# Pointers ------------------------------------------------------------
input_ptr
,
# 16-bit activations (E, T, 2*H)
y_q_ptr
,
# fp8 quantized activations (E, T, H)
y_s_ptr
,
# 16-bit scales (E, T, G)
counts_ptr
,
# int32 num tokens per expert (E)
# Sizes ---------------------------------------------------------------
H
:
tl
.
constexpr
,
# hidden dimension (per output)
GROUP_SIZE
:
tl
.
constexpr
,
# elements per group (usually 128)
# Strides for input (elements) ---------------------------------------
stride_i_e
,
stride_i_t
,
stride_i_h
,
# Strides for y_q (elements) -----------------------------------------
stride_yq_e
,
stride_yq_t
,
stride_yq_h
,
# Strides for y_s (elements) -----------------------------------------
stride_ys_e
,
stride_ys_t
,
stride_ys_g
,
# Stride for counts (elements)
stride_counts_e
,
# Numeric params ------------------------------------------------------
eps
:
tl
.
constexpr
,
fp8_min
:
tl
.
constexpr
,
fp8_max
:
tl
.
constexpr
,
use_ue8m0
:
tl
.
constexpr
,
# Meta ---------------------------------------------------------------
BLOCK
:
tl
.
constexpr
,
NUM_STAGES
:
tl
.
constexpr
,
):
G
=
H
//
GROUP_SIZE
# map program id -> (e, g)
pid
=
tl
.
program_id
(
0
)
e
=
pid
//
G
g
=
pid
%
G
e
=
e
.
to
(
tl
.
int64
)
g
=
g
.
to
(
tl
.
int64
)
# number of valid tokens for this expert
n_tokens
=
tl
.
load
(
counts_ptr
+
e
*
stride_counts_e
).
to
(
tl
.
int64
)
cols
=
tl
.
arange
(
0
,
BLOCK
).
to
(
tl
.
int64
)
mask
=
cols
<
BLOCK
base_input_offset
=
e
*
stride_i_e
+
g
*
GROUP_SIZE
*
stride_i_h
base_gate_offset
=
base_input_offset
+
cols
*
stride_i_h
base_up_offset
=
base_input_offset
+
H
*
stride_i_h
+
cols
*
stride_i_h
base_yq_offset
=
e
*
stride_yq_e
+
g
*
GROUP_SIZE
*
stride_yq_h
+
cols
*
stride_yq_h
base_ys_offset
=
e
*
stride_ys_e
+
g
*
stride_ys_g
for
t
in
tl
.
range
(
0
,
n_tokens
,
num_stages
=
NUM_STAGES
):
gate
=
tl
.
load
(
input_ptr
+
base_gate_offset
+
t
*
stride_i_t
,
mask
=
mask
,
other
=
0.0
).
to
(
tl
.
float32
)
up
=
tl
.
load
(
input_ptr
+
base_up_offset
+
t
*
stride_i_t
,
mask
=
mask
,
other
=
0.0
)
gate
=
gate
*
(
1.0
/
(
1.0
+
tl
.
exp
(
-
gate
)))
y
=
gate
*
up
y_s
=
tl
.
maximum
(
tl
.
max
(
tl
.
abs
(
y
)),
eps
)
/
fp8_max
if
use_ue8m0
:
y_s
=
tl
.
exp2
(
tl
.
ceil
(
tl
.
log2
(
y_s
)))
y_q
=
tl
.
clamp
(
y
/
y_s
,
fp8_min
,
fp8_max
).
to
(
y_q_ptr
.
dtype
.
element_ty
)
tl
.
store
(
y_q_ptr
+
base_yq_offset
+
t
*
stride_yq_t
,
y_q
,
mask
=
mask
)
tl
.
store
(
y_s_ptr
+
base_ys_offset
+
t
*
stride_ys_t
,
y_s
)
def
silu_mul_fp8_quant_deep_gemm_triton
(
y
:
torch
.
Tensor
,
# (E, T, 2*H)
tokens_per_expert
:
torch
.
Tensor
,
# (E,) number of valid tokens per expert
num_parallel_tokens
,
group_size
:
int
=
128
,
eps
:
float
=
1e-10
,
)
->
tuple
[
torch
.
Tensor
,
torch
.
Tensor
]:
"""Quantize silu(y[..., :H]) * y[..., H:] to FP8 with group per-token scales
y has shape (E, T, 2*H). The first half of the last dimension is
silu-activated, multiplied by the second half, then quantized into FP8.
Returns `(y_q, y_s)` where
* `y_q`: FP8 tensor, shape (E, T, H), same layout as y[..., :H]
* `y_s`: FP32 tensor, shape (E, T, H // group_size), strides (T*G, 1, T)
"""
assert
y
.
ndim
==
3
,
"y must be (E, T, 2*H)"
E
,
T
,
H2
=
y
.
shape
assert
H2
%
2
==
0
,
"last dim of y must be even (2*H)"
H
=
H2
//
2
G
=
(
H
+
group_size
-
1
)
//
group_size
assert
H
%
group_size
==
0
,
"H must be divisible by group_size"
assert
tokens_per_expert
.
ndim
==
1
and
tokens_per_expert
.
shape
[
0
]
==
E
,
(
"tokens_per_expert must be shape (E,)"
)
tokens_per_expert
=
tokens_per_expert
.
to
(
device
=
y
.
device
,
dtype
=
torch
.
int32
)
# allocate outputs
fp8_dtype
=
torch
.
float8_e4m3fn
y_q
=
torch
.
empty
((
E
,
T
,
H
),
dtype
=
fp8_dtype
,
device
=
y
.
device
)
# strides (elements)
stride_i_e
,
stride_i_t
,
stride_i_h
=
y
.
stride
()
stride_yq_e
,
stride_yq_t
,
stride_yq_h
=
y_q
.
stride
()
# desired scale strides (elements): (T*G, 1, T)
stride_ys_e
=
T
*
G
stride_ys_t
=
1
stride_ys_g
=
T
y_s
=
torch
.
empty_strided
(
(
E
,
T
,
G
),
(
stride_ys_e
,
stride_ys_t
,
stride_ys_g
),
dtype
=
torch
.
float32
,
device
=
y
.
device
,
)
stride_cnt_e
=
tokens_per_expert
.
stride
()[
0
]
# Static grid over experts and H-groups.
# A loop inside the kernel handles the token dim
grid
=
(
E
*
G
,)
f_info
=
torch
.
finfo
(
fp8_dtype
)
fp8_max
=
f_info
.
max
fp8_min
=
f_info
.
min
_silu_mul_fp8_quant_deep_gemm
[
grid
](
y
,
y_q
,
y_s
,
tokens_per_expert
,
H
,
group_size
,
stride_i_e
,
stride_i_t
,
stride_i_h
,
stride_yq_e
,
stride_yq_t
,
stride_yq_h
,
stride_ys_e
,
stride_ys_t
,
stride_ys_g
,
stride_cnt_e
,
eps
,
fp8_min
,
fp8_max
,
is_deep_gemm_e8m0_used
(),
BLOCK
=
group_size
,
NUM_STAGES
=
4
,
num_warps
=
1
,
)
return
y_q
,
y_s
# Parse generation strategies
strategies
=
[
"uniform"
,
"max_t"
,
"first_t"
]
def
benchmark
(
kernel
:
Callable
,
E
:
int
,
T
:
int
,
H
:
int
,
total_tokens
:
int
,
num_parallel_tokens
:
int
=
64
,
G
:
int
=
128
,
runs
:
int
=
200
,
num_warmups
:
int
=
20
,
gen_strategy
:
str
=
"default"
,
iterations_per_run
:
int
=
20
,
):
def
generate_data
(
seed_offset
=
0
):
"""Generate input data with given seed offset"""
current_platform
.
seed_everything
(
42
+
seed_offset
)
y
=
torch
.
rand
((
E
,
T
,
2
*
H
),
dtype
=
torch
.
bfloat16
,
device
=
"cuda"
).
contiguous
()
if
gen_strategy
==
"uniform"
:
r
=
torch
.
rand
(
size
=
(
E
,),
device
=
"cuda"
)
r
/=
r
.
sum
()
r
*=
total_tokens
tokens_per_expert
=
r
.
int
()
tokens_per_expert
=
torch
.
minimum
(
tokens_per_expert
,
torch
.
ones
((
E
,),
device
=
r
.
device
,
dtype
=
torch
.
int
)
*
T
,
)
elif
gen_strategy
==
"max_t"
:
tokens_per_expert
=
torch
.
empty
(
size
=
(
E
,),
dtype
=
torch
.
int32
,
device
=
"cuda"
)
tokens_per_expert
.
fill_
(
total_tokens
/
E
)
elif
gen_strategy
==
"first_t"
:
tokens_per_expert
=
torch
.
zeros
(
size
=
(
E
,),
dtype
=
torch
.
int32
,
device
=
"cuda"
)
tokens_per_expert
[
0
]
=
min
(
T
,
total_tokens
)
else
:
raise
ValueError
(
f
"Unknown generation strategy:
{
gen_strategy
}
"
)
return
y
,
tokens_per_expert
dataset_count
=
4
# Pre-generate different input matrices for each iteration to avoid cache effects
data_sets
=
[
generate_data
(
i
)
for
i
in
range
(
dataset_count
)]
# Warmup
for
_
in
range
(
10
):
silu_mul_fp8_quant_deep_gemm
(
y
,
tokens_per_expert
,
group_size
=
G
)
y
,
tokens_per_expert
=
data_sets
[
0
]
for
_
in
range
(
num_warmups
):
kernel
(
y
,
tokens_per_expert
,
num_parallel_tokens
=
num_parallel_tokens
,
group_size
=
G
)
torch
.
cuda
.
synchronize
()
start_event
=
torch
.
cuda
.
Event
(
enable_timing
=
True
)
end_event
=
torch
.
cuda
.
Event
(
enable_timing
=
True
)
# Benchmark
torch
.
cuda
.
synchronize
()
start
=
time
.
perf_counter
()
latencies
:
list
[
float
]
=
[]
for
_
in
range
(
runs
):
silu_mul_fp8_quant_deep_gemm
(
y
,
tokens_per_expert
,
group_size
=
G
)
torch
.
cuda
.
synchronize
()
avg_time
=
(
time
.
perf_counter
()
-
start
)
/
runs
*
1000
start_event
.
record
()
for
i
in
range
(
iterations_per_run
):
y
,
tokens_per_expert
=
data_sets
[
i
%
dataset_count
]
kernel
(
y
,
tokens_per_expert
,
num_parallel_tokens
=
num_parallel_tokens
,
group_size
=
G
,
)
end_event
.
record
()
end_event
.
synchronize
()
total_time_ms
=
start_event
.
elapsed_time
(
end_event
)
per_iter_time_ms
=
total_time_ms
/
iterations_per_run
latencies
.
append
(
per_iter_time_ms
)
# Calculate actual work done (only count valid tokens)
# Use median instead of average for better outlier handling
median_time_ms
=
np
.
median
(
latencies
)
median_time_s
=
median_time_ms
/
1000
# Calculate actual work done (using first dataset for consistency)
_
,
tokens_per_expert
=
data_sets
[
0
]
actual_tokens
=
tokens_per_expert
.
sum
().
item
()
actual_elements
=
actual_tokens
*
H
# GFLOPS: operations per element = exp + 3 muls + 1 div + quantization ops ≈ 8 ops
ops_per_element
=
8
total_ops
=
actual_elements
*
ops_per_element
gflops
=
total_ops
/
(
avg_time
/
1000
)
/
1e9
gflops
=
total_ops
/
median_time_s
/
1e9
# Memory bandwidth: bfloat16 inputs (2 bytes), fp8 output (1 byte), scales (4 bytes)
input_bytes
=
actual_tokens
*
2
*
H
*
2
# 2*H bfloat16 inputs
output_bytes
=
actual_tokens
*
H
*
1
# H fp8 outputs
scale_bytes
=
actual_tokens
*
(
H
//
G
)
*
4
# scales in float32
total_bytes
=
input_bytes
+
output_bytes
+
scale_bytes
memory_bw
=
total_bytes
/
(
avg_time
/
1000
)
/
1e9
memory_bw
=
total_bytes
/
median_time_s
/
1e9
return
avg_time
,
gflops
,
memory_bw
HOPPER_BANDWIDTH_TBPS
=
3.35
return
(
median_time_ms
,
gflops
,
memory_bw
,
(
memory_bw
/
(
HOPPER_BANDWIDTH_TBPS
*
1024
))
*
100
,
)
def
create_comparison_plot
(
ratio
,
cuda_times
,
baseline_times
,
config_labels
,
strategy_name
,
id
):
"""Create a comparison plot for a specific generation strategy"""
fig
,
ax
=
plt
.
subplots
(
1
,
1
,
figsize
=
(
16
,
6
))
# Configure x-axis positions
x
=
np
.
arange
(
len
(
config_labels
))
width
=
0.35
# Execution Time plot (lower is better)
ax
.
bar
(
x
-
width
/
2
,
cuda_times
,
width
,
label
=
"CUDA Kernel"
,
alpha
=
0.8
,
color
=
"blue"
)
ax
.
bar
(
x
+
width
/
2
,
baseline_times
,
width
,
label
=
"Baseline"
,
alpha
=
0.8
,
color
=
"orange"
,
)
# Add speedup labels over each bar pair
for
i
in
range
(
len
(
x
)):
speedup
=
ratio
[
i
]
max_height
=
max
(
cuda_times
[
i
],
baseline_times
[
i
])
ax
.
text
(
x
[
i
],
max_height
+
max_height
*
0.02
,
f
"
{
speedup
:.
2
f
}
x"
,
ha
=
"center"
,
va
=
"bottom"
,
fontweight
=
"bold"
,
fontsize
=
9
,
)
ax
.
set_xlabel
(
"Configuration"
)
ax
.
set_ylabel
(
"% Utilization"
)
ax
.
set_title
(
f
"Memory Bandwidth Utilization (%) -
{
strategy_name
}
\n
(Higher is Better)"
)
ax
.
set_xticks
(
x
)
ax
.
set_xticklabels
(
config_labels
,
rotation
=
45
,
ha
=
"right"
)
ax
.
legend
()
ax
.
grid
(
True
,
alpha
=
0.3
)
plt
.
tight_layout
()
return
fig
,
ax
def
create_combined_plot
(
all_results
):
"""Create a combined plot with all strategies in one PNG"""
num_strategies
=
len
(
all_results
)
fig
,
axes
=
plt
.
subplots
(
num_strategies
,
1
,
figsize
=
(
20
,
6
*
num_strategies
))
if
num_strategies
==
1
:
axes
=
[
axes
]
for
idx
,
(
strategy_name
,
ratio
,
cuda_times
,
baseline_times
,
config_labels
,
)
in
enumerate
(
all_results
):
ax
=
axes
[
idx
]
# Configure x-axis positions
x
=
np
.
arange
(
len
(
config_labels
))
width
=
0.35
# Execution Time plot (lower is better)
ax
.
bar
(
x
-
width
/
2
,
cuda_times
,
width
,
label
=
"CUDA Kernel"
,
alpha
=
0.8
,
color
=
"blue"
,
)
ax
.
bar
(
x
+
width
/
2
,
baseline_times
,
width
,
label
=
"Baseline"
,
alpha
=
0.8
,
color
=
"orange"
,
)
# Add speedup labels over each bar pair
for
i
in
range
(
len
(
x
)):
speedup
=
ratio
[
i
]
max_height
=
max
(
cuda_times
[
i
],
baseline_times
[
i
])
ax
.
text
(
x
[
i
],
max_height
+
max_height
*
0.02
,
f
"
{
speedup
:.
2
f
}
x"
,
ha
=
"center"
,
va
=
"bottom"
,
fontweight
=
"bold"
,
fontsize
=
9
,
)
ax
.
set_xlabel
(
"Configuration"
)
ax
.
set_ylabel
(
"% Utilization"
)
ax
.
set_title
(
f
"Memory Bandwidth Utilization (%) -
{
strategy_name
}
\n
(Higher is Better)"
)
ax
.
set_xticks
(
x
)
ax
.
set_xticklabels
(
config_labels
,
rotation
=
45
,
ha
=
"right"
)
ax
.
legend
()
ax
.
grid
(
True
,
alpha
=
0.3
)
plt
.
tight_layout
()
filename
=
"../../silu_bench/silu_benchmark_combined.png"
plt
.
savefig
(
filename
,
dpi
=
300
,
bbox_inches
=
"tight"
)
plt
.
show
()
return
filename
outer_dim
=
7168
configs
=
[
(
8
,
32
,
1024
),
(
16
,
64
,
2048
),
(
32
,
128
,
4096
),
# DeepSeekV3 Configs
(
256
,
16
,
7168
),
(
256
,
32
,
7168
),
(
256
,
64
,
7168
),
(
256
,
128
,
7168
),
(
256
,
256
,
7168
),
(
256
,
512
,
7168
),
(
8
,
1024
,
7168
),
# DeepSeekV3 Configs
(
32
,
1024
,
7168
),
# DeepSeekV3 Configs
(
256
,
1024
,
7168
),
]
runs
=
100
num_warmups
=
20
strategy_descriptions
=
{
"uniform"
:
"Uniform Random"
,
"max_t"
:
"Even Assignment"
,
"first_t"
:
"experts[0] = T, experts[1:] = 0"
,
}
print
(
f
"GPU:
{
torch
.
cuda
.
get_device_name
()
}
"
)
print
(
f
"
{
'Config'
:
<
20
}
{
'Time(ms)'
:
<
10
}
{
'GFLOPS'
:
<
10
}
{
'GB/s'
:
<
10
}
"
)
print
(
"-"
*
50
)
for
E
,
T
,
H
in
configs
:
try
:
time_ms
,
gflops
,
gbps
=
benchmark
(
E
,
T
,
H
)
print
(
f
"E=
{
E
:
3
d
}
,T=
{
T
:
4
d
}
,H=
{
H
:
4
d
}
{
time_ms
:
8.3
f
}
{
gflops
:
8.1
f
}
{
gbps
:
8.1
f
}
"
)
except
Exception
:
print
(
f
"E=
{
E
:
3
d
}
,T=
{
T
:
4
d
}
,H=
{
H
:
4
d
}
FAILED"
)
print
(
f
"Testing strategies:
{
', '
.
join
(
strategies
)
}
"
)
print
(
f
"Configurations:
{
len
(
configs
)
}
configs"
)
all_results
=
[]
# Run benchmarks for each strategy
for
id
,
strategy
in
enumerate
(
strategies
):
print
(
f
"
\n
{
'='
*
60
}
"
)
print
(
f
"Testing strategy:
{
strategy_descriptions
[
strategy
]
}
"
)
print
(
f
"
{
'='
*
60
}
"
)
# Collect benchmark data for both algorithms
config_labels
=
[]
config_x_axis
=
[]
all_cuda_results
=
[]
all_baseline_results
=
[]
all_ratios
=
[]
for
E
,
T
,
H
in
configs
:
total_tokens_config
=
[
8
*
E
,
16
*
E
,
32
*
E
,
64
*
E
,
128
*
E
,
256
*
E
]
config_x_axis
.
append
(
total_tokens_config
)
cuda_results
=
[]
baseline_results
=
[]
ratios
=
[]
for
total_tokens
in
total_tokens_config
:
config_label
=
f
"E=
{
E
}
,T=
{
T
}
,H=
{
H
}
,TT=
{
total_tokens
}
"
config_labels
.
append
(
config_label
)
# CUDA kernel results
time_ms_cuda
,
gflops
,
gbps
,
perc
=
benchmark
(
silu_mul_fp8_quant_deep_gemm_cuda
,
E
,
T
,
H
,
total_tokens
,
runs
=
runs
,
num_warmups
=
num_warmups
,
gen_strategy
=
strategy
,
)
cuda_results
.
append
((
time_ms_cuda
,
gflops
,
gbps
,
perc
))
# Baseline results
time_ms_triton
,
gflops
,
gbps
,
perc
=
benchmark
(
silu_mul_fp8_quant_deep_gemm_triton
,
E
,
T
,
H
,
total_tokens
,
runs
=
runs
,
num_warmups
=
num_warmups
,
gen_strategy
=
strategy
,
)
baseline_results
.
append
((
time_ms_triton
,
gflops
,
gbps
,
perc
))
ratios
.
append
(
time_ms_triton
/
time_ms_cuda
)
print
(
f
"Completed:
{
config_label
}
"
)
all_cuda_results
.
append
(
cuda_results
)
all_baseline_results
.
append
(
baseline_results
)
all_ratios
.
append
(
ratios
)
# Store results for combined plotting
all_results
.
append
(
(
strategy_descriptions
[
strategy
],
all_ratios
,
all_cuda_results
,
all_baseline_results
,
config_labels
,
config_x_axis
,
)
)
# Print summary table for this strategy
print
(
f
"
\n
Summary Table -
{
strategy_descriptions
[
strategy
]
}
:"
)
print
(
f
"
{
'Config'
:
<
20
}
{
'CUDA Time(ms)'
:
<
12
}
{
'Base Time(ms)'
:
<
12
}
{
'Speedup'
:
<
8
}
"
)
print
(
"-"
*
60
)
for
i
,
(
E
,
T
,
H
)
in
enumerate
(
configs
):
speedup
=
baseline_results
[
i
][
0
]
/
cuda_results
[
i
][
0
]
config_label
=
f
"E=
{
E
:
3
d
}
,T=
{
T
:
4
d
}
,H=
{
H
:
4
d
}
"
print
(
f
"
{
config_label
:
<
20
}
{
cuda_results
[
i
][
0
]:
8.5
f
}
"
f
"
{
baseline_results
[
i
][
0
]:
8.5
f
}
{
speedup
:
6.2
f
}
x"
)
def
create_total_tokens_plot
(
all_results
):
num_strategies
=
len
(
all_results
)
num_configs
=
len
(
configs
)
# Create side-by-side subplots: 2 columns for speedup and bandwidth percentage
fig
,
axs
=
plt
.
subplots
(
num_strategies
,
num_configs
*
2
,
figsize
=
(
28
,
6
*
num_strategies
)
)
# Add main title to the entire figure
fig
.
suptitle
(
"Performance Analysis: Speedup vs Bandwidth Utilization (Triton & CUDA)"
,
fontsize
=
16
,
fontweight
=
"bold"
,
y
=
0.98
,
)
# Handle single strategy case
if
num_strategies
==
1
:
axs
=
axs
.
reshape
(
1
,
-
1
)
# Handle single config case
if
num_configs
==
1
:
axs
=
axs
.
reshape
(
-
1
,
2
)
for
strategy_idx
,
result
in
enumerate
(
all_results
):
(
strategy_name
,
all_ratios
,
all_cuda_results
,
all_baseline_results
,
config_labels
,
config_x_axis
,
)
=
result
for
config_idx
in
range
(
num_configs
):
# Speedup plot (left column)
ax_speedup
=
axs
[
strategy_idx
,
config_idx
*
2
]
# Bandwidth plot (right column)
ax_bandwidth
=
axs
[
strategy_idx
,
config_idx
*
2
+
1
]
E
,
T
,
H
=
configs
[
config_idx
]
ratios
=
all_ratios
[
config_idx
]
total_tokens_values
=
config_x_axis
[
config_idx
]
# Extract CUDA and Triton bandwidth percentages
cuda_bandwidth_percentages
=
[
result
[
3
]
for
result
in
all_cuda_results
[
config_idx
]
]
triton_bandwidth_percentages
=
[
result
[
3
]
for
result
in
all_baseline_results
[
config_idx
]
]
# Plot speedup ratios vs total tokens (left plot)
ax_speedup
.
plot
(
total_tokens_values
,
ratios
,
"bo-"
,
linewidth
=
3
,
markersize
=
8
)
ax_speedup
.
set_title
(
f
"
{
strategy_name
}
\n
Speedup (CUDA/Triton)
\n
E=
{
E
}
, T=
{
T
}
, H=
{
H
}
"
,
fontsize
=
12
,
fontweight
=
"bold"
,
)
ax_speedup
.
set_xlabel
(
"Total Tokens"
,
fontweight
=
"bold"
,
fontsize
=
11
)
ax_speedup
.
set_ylabel
(
"Speedup Ratio"
,
fontweight
=
"bold"
,
fontsize
=
11
)
ax_speedup
.
grid
(
True
,
alpha
=
0.3
)
ax_bandwidth
.
plot
(
total_tokens_values
,
cuda_bandwidth_percentages
,
"ro-"
,
linewidth
=
3
,
markersize
=
8
,
label
=
"CUDA"
,
)
ax_bandwidth
.
plot
(
total_tokens_values
,
triton_bandwidth_percentages
,
"go-"
,
linewidth
=
3
,
markersize
=
8
,
label
=
"Triton"
,
)
ax_bandwidth
.
set_title
(
f
"
{
strategy_name
}
\n
Bandwidth Utilization (Hopper)
\n
E=
{
E
}
, T=
{
T
}
, H=
{
H
}
"
,
fontsize
=
12
,
fontweight
=
"bold"
,
)
ax_bandwidth
.
set_xlabel
(
"Total Tokens"
,
fontweight
=
"bold"
,
fontsize
=
11
)
ax_bandwidth
.
set_ylabel
(
"% of Peak Bandwidth"
,
fontweight
=
"bold"
,
fontsize
=
11
)
ax_bandwidth
.
legend
(
prop
=
{
"weight"
:
"bold"
})
ax_bandwidth
.
grid
(
True
,
alpha
=
0.3
)
# Format x-axis labels for both plots
for
ax
in
[
ax_speedup
,
ax_bandwidth
]:
ax
.
set_xticks
(
total_tokens_values
)
ax
.
set_xticklabels
(
[
f
"
{
tt
//
1000
}
K"
if
tt
>=
1000
else
str
(
tt
)
for
tt
in
total_tokens_values
],
fontweight
=
"bold"
,
)
# Make tick labels bold
for
label
in
ax
.
get_xticklabels
()
+
ax
.
get_yticklabels
():
label
.
set_fontweight
(
"bold"
)
# Add value labels on speedup points
for
x
,
y
in
zip
(
total_tokens_values
,
ratios
):
ax_speedup
.
annotate
(
f
"
{
y
:.
2
f
}
x"
,
(
x
,
y
),
textcoords
=
"offset points"
,
xytext
=
(
0
,
12
),
ha
=
"center"
,
fontsize
=
10
,
fontweight
=
"bold"
,
bbox
=
dict
(
boxstyle
=
"round,pad=0.3"
,
facecolor
=
"white"
,
alpha
=
0.7
),
)
# Add value labels on CUDA bandwidth points
for
x
,
y
in
zip
(
total_tokens_values
,
cuda_bandwidth_percentages
):
ax_bandwidth
.
annotate
(
f
"
{
y
:.
1
f
}
%"
,
(
x
,
y
),
textcoords
=
"offset points"
,
xytext
=
(
0
,
12
),
ha
=
"center"
,
fontsize
=
9
,
fontweight
=
"bold"
,
bbox
=
dict
(
boxstyle
=
"round,pad=0.2"
,
facecolor
=
"red"
,
alpha
=
0.3
),
)
# Add value labels on Triton bandwidth points
for
x
,
y
in
zip
(
total_tokens_values
,
triton_bandwidth_percentages
):
ax_bandwidth
.
annotate
(
f
"
{
y
:.
1
f
}
%"
,
(
x
,
y
),
textcoords
=
"offset points"
,
xytext
=
(
0
,
-
15
),
ha
=
"center"
,
fontsize
=
9
,
fontweight
=
"bold"
,
bbox
=
dict
(
boxstyle
=
"round,pad=0.2"
,
facecolor
=
"green"
,
alpha
=
0.3
),
)
plt
.
tight_layout
()
plt
.
subplots_adjust
(
top
=
0.93
)
# Make room for main title
filename
=
"silu_benchmark_total_tokens.png"
plt
.
savefig
(
filename
,
dpi
=
300
,
bbox_inches
=
"tight"
)
plt
.
show
()
return
filename
# Create combined plot with all strategies
combined_plot_filename
=
create_total_tokens_plot
(
all_results
)
print
(
f
"
\n
{
'='
*
60
}
"
)
print
(
"Benchmark Complete!"
)
print
(
f
"Generated combined plot:
{
combined_plot_filename
}
"
)
print
(
f
"
{
'='
*
60
}
"
)
csrc/ops.h
View file @
98229db2
...
...
@@ -133,6 +133,12 @@ void silu_and_mul_nvfp4_quant(torch::Tensor& out,
torch
::
Tensor
&
input
,
torch
::
Tensor
&
input_global_scale
);
#endif
void
silu_mul_fp8_quant_deep_gemm_cuda
(
const
at
::
Tensor
&
input
,
// (E, T, 2*H)
const
at
::
Tensor
&
counts
,
// (E)
at
::
Tensor
&
y_q
,
// (E, T, H) [OUT]
at
::
Tensor
&
y_s
,
// (E, T, H//group_size) [OUT]
int64_t
group_size
,
bool
use_ue8m0
,
int64_t
num_parallel_tokens
);
void
mul_and_silu
(
torch
::
Tensor
&
out
,
torch
::
Tensor
&
input
);
...
...
csrc/quantization/activation_kernels.cu
View file @
98229db2
...
...
@@ -9,6 +9,26 @@
#include "quantization/fp8/common.cuh"
#include <c10/util/Float8_e4m3fn.h>
#ifndef USE_ROCM
#include <cuda_bf16.h>
#include <cuda_fp16.h>
#include <cuda_fp8.h>
#else
#include <hip/hip_bf16.h>
#include <hip/hip_fp16.h>
#include <hip/hip_fp8.h>
typedef
__hip_bfloat162
__nv_bfloat162
;
typedef
__hip_bfloat16
__nv_bfloat16
;
typedef
__hip_bfloat16_raw
__nv_bfloat16_raw
;
typedef
__hip_fp8_e4m3
__nv_fp8_e4m3
;
typedef
__hip_fp8x4_e4m3
__nv_fp8x4_e4m3
;
#endif
#include "core/registration.h"
namespace
vllm
{
template
<
typename
T
>
...
...
@@ -87,6 +107,337 @@ __global__ void act_and_mul_quant_kernel(
}
}
}
__device__
__forceinline__
float
silu
(
float
x
)
{
return
(
__fdividef
(
x
,
(
1.
f
+
expf
(
-
x
))));
}
__device__
__forceinline__
float2
silu2
(
float2
x
)
{
return
make_float2
(
silu
(
x
.
x
),
silu
(
x
.
y
));
}
#ifndef USE_ROCM
__device__
__forceinline__
float
warp_max
(
float
v
)
{
static
constexpr
unsigned
FULL_MASK
=
0xffffffffu
;
for
(
int
offset
=
1
;
offset
<
WARP_SIZE
;
offset
*=
2
)
{
v
=
fmaxf
(
v
,
__shfl_xor_sync
(
FULL_MASK
,
v
,
offset
));
}
return
v
;
}
__device__
__forceinline__
__nv_bfloat16
warp_max
(
__nv_bfloat16
v
)
{
static
constexpr
unsigned
FULL_MASK
=
0xffffffffu
;
for
(
int
offset
=
1
;
offset
<
WARP_SIZE
;
offset
*=
2
)
{
v
=
__hmax
(
v
,
__shfl_xor_sync
(
FULL_MASK
,
v
,
offset
));
}
return
v
;
}
#endif
template
<
typename
T
,
typename
U
>
__device__
__forceinline__
void
cp_async4
(
T
*
_smem_ptr
,
const
U
*
_glob_ptr
)
{
#if __CUDACC_VER_MAJOR__ >= 11 && __CUDA_ARCH__ >= 800
auto
smem_ptr
=
reinterpret_cast
<
void
*>
(
_smem_ptr
);
auto
glob_ptr
=
reinterpret_cast
<
const
void
*>
(
_glob_ptr
);
const
int
BYTES
=
16
;
uint32_t
smem
=
static_cast
<
uint32_t
>
(
__cvta_generic_to_shared
(
smem_ptr
));
asm
volatile
(
"{
\n
"
" cp.async.cg.shared.global [%0], [%1], %2;
\n
"
"}
\n
"
::
"r"
(
smem
),
"l"
(
glob_ptr
),
"n"
(
BYTES
));
#else
_smem_ptr
[
0
]
=
_glob_ptr
[
0
];
#endif
}
__device__
__forceinline__
void
cp_async_fence
()
{
#if __CUDACC_VER_MAJOR__ >= 11 && __CUDA_ARCH__ >= 800
asm
volatile
(
"cp.async.commit_group;
\n
"
::
);
#else
#endif
}
template
<
int
N
>
__device__
__forceinline__
void
cp_async_wait
()
{
#if __CUDACC_VER_MAJOR__ >= 11 && __CUDA_ARCH__ >= 800
asm
volatile
(
"cp.async.wait_group %0;
\n
"
::
"n"
(
N
));
#else
#endif
}
template
<
>
__device__
__forceinline__
void
cp_async_wait
<
0
>
()
{
#if __CUDACC_VER_MAJOR__ >= 11 && __CUDA_ARCH__ >= 800
asm
volatile
(
"cp.async.wait_all;
\n
"
::
);
#else
#endif
}
__device__
__forceinline__
float
clip
(
float
v
,
float
mmin
,
float
mmax
)
{
#if __CUDACC_VER_MAJOR__ >= 11 && __CUDA_ARCH__ >= 800
return
fminf
(
mmax
,
fmaxf
(
v
,
mmin
));
#else
#endif
}
__device__
__forceinline__
__nv_bfloat16
clip
(
__nv_bfloat16
v
,
__nv_bfloat16
mmin
,
__nv_bfloat16
mmax
)
{
return
__hmin
(
mmax
,
__hmax
(
v
,
mmin
));
}
__device__
__forceinline__
__nv_bfloat162
clip
(
__nv_bfloat162
v
,
__nv_bfloat162
mmin
,
__nv_bfloat162
mmax
)
{
return
__hmin2
(
mmax
,
__hmax2
(
v
,
mmin
));
}
// We use the following values for fp8 min/max:
// __nv_fp8_e4m3 = (-448, +448)
// __nv_fp8_e4m3uz = (-240.0, +240.0)
// It is currently assumed that only
template
<
class
T
>
constexpr
__nv_bfloat16
get_fp8_max
()
{
static_assert
(
std
::
is_same_v
<
T
,
c10
::
Float8_e4m3fn
>
||
std
::
is_same_v
<
T
,
c10
::
Float8_e4m3fnuz
>
);
if
constexpr
(
std
::
is_same_v
<
T
,
c10
::
Float8_e4m3fn
>
)
{
return
__nv_bfloat16
(
__nv_bfloat16_raw
{.
x
=
17376
});
}
else
{
return
__nv_bfloat16
(
__nv_bfloat16_raw
{.
x
=
17264
});
}
}
template
<
class
T
>
constexpr
__nv_bfloat16
get_fp8_min
()
{
static_assert
(
std
::
is_same_v
<
T
,
c10
::
Float8_e4m3fn
>
||
std
::
is_same_v
<
T
,
c10
::
Float8_e4m3fnuz
>
);
if
constexpr
(
std
::
is_same_v
<
T
,
c10
::
Float8_e4m3fn
>
)
{
return
__nv_bfloat16
(
__nv_bfloat16_raw
{.
x
=
50144
});
}
else
{
return
__nv_bfloat16
(
__nv_bfloat16_raw
{.
x
=
50032
});
}
}
#ifndef USE_ROCM
template
<
typename
fp8_type
,
int32_t
NUM_WARPS
,
typename
Idx_t
,
int
NUM_PARALLEL_TOKENS
,
bool
USE_UE8M0
,
int
GROUP_SIZE
=
128
,
int
NUM_STAGES
=
3
>
__global__
void
silu_mul_fp8_quant_deep_gemm_kernel
(
const
__nv_bfloat16
*
__restrict__
_input
,
fp8_type
*
__restrict__
_y_q
,
float
*
__restrict__
_y_s
,
const
int32_t
*
__restrict__
counts
,
// sizes
int
H
,
int
G
,
// strides (in elements)
Idx_t
stride_i_e
,
Idx_t
stride_i_t
,
Idx_t
stride_i_h
,
Idx_t
stride_yq_e
,
Idx_t
stride_yq_t
,
Idx_t
stride_yq_h
,
Idx_t
stride_ys_e
,
Idx_t
stride_ys_t
,
Idx_t
stride_ys_g
,
Idx_t
stride_counts_e
)
{
static
constexpr
__nv_bfloat16
fp8_min
=
get_fp8_min
<
fp8_type
>
();
static
constexpr
__nv_bfloat16
fp8_max
=
get_fp8_max
<
fp8_type
>
();
// We assign EPS with its 16-bit unsigned counterpart to allow constexpr.
static
constexpr
__nv_bfloat16
EPS
=
(
__nv_bfloat16_raw
{.
x
=
11996
});
// We pack 8 16-bit bfloat16 values into a 128-bit __int128_t.
static
constexpr
int32_t
BFLOAT16_PER_GROUP
=
8
;
// We split the shared memory in half, corresponding to gate and up matrices:
// [...gate_i, ...up_i] where 0 <= i < stages.
static
constexpr
int32_t
S_NUM_128
=
2u
*
(
GROUP_SIZE
/
BFLOAT16_PER_GROUP
)
*
NUM_WARPS
*
NUM_STAGES
;
static
constexpr
auto
THREAD_COUNT
=
NUM_WARPS
*
WARP_SIZE
;
static
constexpr
int
HALF_THREAD_COUNT
=
THREAD_COUNT
/
2
;
static
constexpr
int32_t
S_NUM_64
=
S_NUM_128
*
2
;
__shared__
__int128_t
__align__
(
16
)
s_buff_128
[
S_NUM_128
];
const
int32_t
tid
=
threadIdx
.
x
;
const
int32_t
warp_id
=
tid
/
WARP_SIZE
;
const
int32_t
lane_id
=
tid
%
WARP_SIZE
;
auto
s_buff_compute_32
=
reinterpret_cast
<
__nv_bfloat162
*>
(
s_buff_128
);
// block handles one (expert e, group g)
int32_t
pid
=
blockIdx
.
x
;
int32_t
e
=
pid
/
G
;
int32_t
g
=
pid
%
G
;
const
int32_t
n_tokens
=
counts
[
e
*
stride_counts_e
];
if
(
!
n_tokens
)
{
return
;
// Exit ASAP.
}
const
Idx_t
stride_i_t_128
=
stride_i_t
/
8u
;
int32_t
n_tokens_lower
,
n_tokens_upper
;
// Each block i iterates over tokens of a slice of n_tokens =
// expert_counts[i], with the size of chunk being
// (n_tokens / NUM_PARALLEL_TOKENS) + residual, instead of
// updiv(n_tokens, NUM_PARALLEL_TOKENS) for better scheduling.
if
(
n_tokens
<
NUM_PARALLEL_TOKENS
&&
blockIdx
.
y
<
n_tokens
)
{
// Specialize this, but can be likely fused.
if
(
blockIdx
.
y
>=
NUM_PARALLEL_TOKENS
)
{
return
;
}
n_tokens_lower
=
blockIdx
.
y
;
n_tokens_upper
=
blockIdx
.
y
+
1
;
}
else
{
auto
chunk_size
=
n_tokens
/
NUM_PARALLEL_TOKENS
;
auto
residual
=
n_tokens
-
chunk_size
*
NUM_PARALLEL_TOKENS
;
auto
calc_id
=
[
&
](
int32_t
id
)
{
if
(
id
<
residual
)
{
return
min
(
n_tokens
,
id
*
(
chunk_size
+
1
));
}
else
{
return
min
(
n_tokens
,
id
*
chunk_size
+
residual
);
}
};
n_tokens_lower
=
calc_id
(
blockIdx
.
y
);
n_tokens_upper
=
calc_id
(
blockIdx
.
y
+
1
);
}
if
(
n_tokens_lower
>=
n_tokens_upper
)
{
return
;
}
// We do calculations here, using constexpr wherever possible.
const
Idx_t
base_i
=
e
*
stride_i_e
+
NUM_WARPS
*
g
*
GROUP_SIZE
*
stride_i_h
;
const
Idx_t
base_ys
=
e
*
stride_ys_e
+
NUM_WARPS
*
g
*
stride_ys_g
;
const
Idx_t
base_yq
=
e
*
stride_yq_e
+
NUM_WARPS
*
g
*
GROUP_SIZE
*
stride_yq_h
;
Idx_t
gate_off_128
=
(
base_i
/
static_cast
<
Idx_t
>
(
8u
));
auto
input_128_ptr
=
reinterpret_cast
<
const
__int128_t
*>
(
_input
);
auto
gate_128_ptr
=
input_128_ptr
+
gate_off_128
+
(
tid
%
HALF_THREAD_COUNT
)
+
stride_i_t_128
*
n_tokens_lower
;
auto
up_128_ptr
=
gate_128_ptr
+
(
H
*
stride_i_h
)
/
8u
;
auto
y_s_ptr
=
_y_s
+
base_ys
+
warp_id
*
stride_ys_g
+
n_tokens_lower
*
stride_ys_t
;
auto
y_q_ptr
=
_y_q
+
base_yq
+
warp_id
*
GROUP_SIZE
+
stride_yq_t
*
n_tokens_lower
+
4
*
lane_id
;
int32_t
t_load
=
n_tokens_lower
,
load_stage_id
=
0
;
auto
s_buff_gate_load_128
=
s_buff_128
+
(
tid
%
HALF_THREAD_COUNT
);
auto
s_buff_up_load_128
=
s_buff_gate_load_128
+
S_NUM_128
/
2u
;
int32_t
stage_offset
{};
static
constexpr
int32_t
LOAD_STAGE_SIZE
=
(
NUM_WARPS
*
WARP_SIZE
/
2
);
static
constexpr
int32_t
LOAD_STAGE_MOD
=
NUM_STAGES
*
(
NUM_WARPS
*
WARP_SIZE
/
2
);
// Two halves of all threads in a block conduct global loads for gate and up,
// repsectively.
auto
load_and_advance_y_pred
=
[
&
]
{
if
(
t_load
<
n_tokens_upper
)
{
auto
s_gate_stage_128_staged_ptr
=
s_buff_gate_load_128
+
stage_offset
;
auto
s_up_stage_128_staged_ptr
=
s_buff_up_load_128
+
stage_offset
;
// It is very important that LOAD_STAGE_SIZE is constexpr to avoid
// unnecessary ALU ops.
stage_offset
+=
LOAD_STAGE_SIZE
;
stage_offset
%=
LOAD_STAGE_MOD
;
if
(
tid
<
HALF_THREAD_COUNT
)
{
cp_async4
(
s_gate_stage_128_staged_ptr
,
gate_128_ptr
);
gate_128_ptr
+=
stride_i_t_128
;
}
else
{
cp_async4
(
s_up_stage_128_staged_ptr
,
up_128_ptr
);
up_128_ptr
+=
stride_i_t_128
;
}
++
t_load
;
++
load_stage_id
;
}
// We fence even if there is nothing to load to simplify pipelining.
cp_async_fence
();
};
#pragma unroll
for
(
int
i
=
0
;
i
<
NUM_STAGES
-
1
;
i
++
)
{
load_and_advance_y_pred
();
}
__int64_t
*
s_gate_ptr
=
reinterpret_cast
<
__int64_t
*>
(
s_buff_compute_32
+
warp_id
*
(
GROUP_SIZE
/
2
))
+
lane_id
;
__int64_t
*
s_up_ptr
=
s_gate_ptr
+
S_NUM_64
/
2
;
static
constexpr
int32_t
STAGE_SIZE
=
(
GROUP_SIZE
*
NUM_WARPS
)
/
4u
;
static
constexpr
int32_t
STAGE_MOD
=
STAGE_SIZE
*
NUM_STAGES
;
int32_t
compute_pipeline_offset_64
=
0
;
for
(
int32_t
t
=
n_tokens_lower
;
t
<
n_tokens_upper
;
++
t
)
{
__nv_bfloat16
y_max_bf16
=
EPS
;
__nv_bfloat162
results_bf162
[
2
];
cp_async_wait
<
NUM_STAGES
-
2
>
();
__syncthreads
();
// We double-buffer pipelined loads so that the next load will
// concurrently run with compute without overwrites.
load_and_advance_y_pred
();
auto
s_gate_compute_64
=
s_gate_ptr
+
compute_pipeline_offset_64
;
auto
s_up_compute_64
=
s_up_ptr
+
compute_pipeline_offset_64
;
// STAGE_SIZE must also be constexpr!
compute_pipeline_offset_64
+=
STAGE_SIZE
;
compute_pipeline_offset_64
%=
STAGE_MOD
;
// Each thread loads (gate/up) 2X 4X bfloat16 values into registers.
__int64_t
gate64
=
*
s_gate_compute_64
;
__nv_bfloat162
*
s_gate_compute_32
=
reinterpret_cast
<
__nv_bfloat162
*>
(
&
gate64
);
__int64_t
up64
=
*
s_up_compute_64
;
__nv_bfloat162
*
s_up_compute_32
=
reinterpret_cast
<
__nv_bfloat162
*>
(
&
up64
);
#pragma unroll
for
(
int
i
=
0
;
i
<
2
;
i
++
)
{
// For silu, we make sure that div is emitted.
float2
gate
=
silu2
(
__bfloat1622float2
(
s_gate_compute_32
[
i
]));
results_bf162
[
i
]
=
__float22bfloat162_rn
(
gate
);
}
#pragma unroll
for
(
int
i
=
0
;
i
<
2
;
i
++
)
{
results_bf162
[
i
]
=
__hmul2
(
results_bf162
[
i
],
s_up_compute_32
[
i
]);
}
auto
_y_max2
=
__hmax2
(
__habs2
(
results_bf162
[
0
]),
__habs2
(
results_bf162
[
1
]));
y_max_bf16
=
__hmax
(
_y_max2
.
x
,
_y_max2
.
y
);
// An entire group is assigned to a single warp, so a simple warp reduce
// is used.
__nv_bfloat16
y_s
=
warp_max
(
y_max_bf16
)
/
fp8_max
;
if
constexpr
(
USE_UE8M0
)
{
y_s
=
hexp2
(
hceil
(
hlog2
(
y_s
)));
}
auto
inv_y
=
__float2bfloat16_rn
(
1.
f
)
/
y_s
;
auto
y_s2
=
make_bfloat162
(
inv_y
,
inv_y
);
#pragma unroll
for
(
int32_t
i
=
0
;
i
<
2
;
++
i
)
{
results_bf162
[
i
]
=
clip
(
__hmul2
(
results_bf162
[
i
],
y_s2
),
__bfloat162bfloat162
(
fp8_min
),
__bfloat162bfloat162
(
fp8_max
));
}
auto
fp8x4
=
__nv_fp8x4_e4m3
(
results_bf162
[
0
],
results_bf162
[
1
]);
*
reinterpret_cast
<
__nv_fp8x4_e4m3
*>
(
y_q_ptr
)
=
fp8x4
;
y_q_ptr
+=
stride_yq_t
;
if
(
lane_id
==
0
)
{
*
y_s_ptr
=
y_s
;
y_s_ptr
+=
stride_ys_t
;
}
}
}
#endif
}
// namespace vllm
// Launch activation, gating, and quantize kernel.
...
...
@@ -119,3 +470,117 @@ void silu_and_mul_quant(torch::Tensor& out, // [..., d]
TORCH_CHECK
(
input
.
size
(
-
1
)
%
2
==
0
);
LAUNCH_ACTIVATION_GATE_KERNEL
(
vllm
::
silu_kernel
);
}
void
silu_mul_fp8_quant_deep_gemm_cuda
(
const
at
::
Tensor
&
input
,
// (E, T, 2*H)
const
at
::
Tensor
&
counts
,
// (E)
at
::
Tensor
&
y_q
,
// (E, T, H) [OUT]
at
::
Tensor
&
y_s
,
// (E, T, H//group_size) [OUT]
int64_t
group_size
,
bool
use_ue8m0
,
int64_t
num_parallel_tokens
)
{
#ifndef USE_ROCM
// This kernel relies heavily on cp.async and fp8 support.
// This kernel currently only supports H % 128 == 0 and assumes a
// fixed GROUP_SIZE of 128.
TORCH_CHECK
(
input
.
dtype
()
==
torch
::
kBFloat16
);
TORCH_CHECK
(
y_q
.
dtype
()
==
torch
::
kFloat8_e4m3fn
||
y_q
.
dtype
()
==
torch
::
kFloat8_e4m3fnuz
);
TORCH_CHECK
(
y_s
.
dtype
()
==
torch
::
kFloat32
);
TORCH_CHECK
(
input
.
size
(
-
1
)
%
256
==
0
);
// Check that num_parallel_tokens is of power of 2 and between 1 and 64.
TORCH_CHECK
(
1
<=
num_parallel_tokens
&&
num_parallel_tokens
<=
64
);
TORCH_CHECK
(
!
(
num_parallel_tokens
&
(
num_parallel_tokens
-
1
)));
using
Idx_t
=
int64_t
;
Idx_t
E
=
input
.
size
(
0
);
Idx_t
T
=
input
.
size
(
1
);
Idx_t
H
=
input
.
size
(
2
)
/
2
;
Idx_t
stride_i_e
=
input
.
stride
(
0
);
Idx_t
stride_i_t
=
input
.
stride
(
1
);
Idx_t
stride_i_h
=
input
.
stride
(
2
);
Idx_t
stride_yq_e
=
y_q
.
stride
(
0
);
Idx_t
stride_yq_t
=
y_q
.
stride
(
1
);
Idx_t
stride_yq_h
=
y_q
.
stride
(
2
);
Idx_t
stride_ys_e
=
y_s
.
stride
(
0
);
Idx_t
stride_ys_t
=
y_s
.
stride
(
1
);
Idx_t
stride_ys_g
=
y_s
.
stride
(
2
);
Idx_t
stride_counts_e
=
counts
.
stride
(
0
);
static
constexpr
int
GROUP_SIZE
=
128
;
#define KERNEL_FN \
if (use_ue8m0) { \
vllm::silu_mul_fp8_quant_deep_gemm_kernel<fp8_t, NUM_WARPS, Idx_t, \
NUM_PARALLEL_TOKENS, true> \
<<<grid, block, 0, stream>>>( \
reinterpret_cast<__nv_bfloat16*>(input.data_ptr()), \
(fp8_t*)y_q.data_ptr(), y_s.data_ptr<float>(), \
reinterpret_cast<int32_t*>(counts.data_ptr<int>()), H, G, \
stride_i_e, stride_i_t, stride_i_h, stride_yq_e, stride_yq_t, \
stride_yq_h, stride_ys_e, stride_ys_t, stride_ys_g, \
stride_counts_e); \
} else { \
vllm::silu_mul_fp8_quant_deep_gemm_kernel<fp8_t, NUM_WARPS, Idx_t, \
NUM_PARALLEL_TOKENS, false> \
<<<grid, block, 0, stream>>>( \
reinterpret_cast<__nv_bfloat16*>(input.data_ptr()), \
(fp8_t*)y_q.data_ptr(), y_s.data_ptr<float>(), \
reinterpret_cast<int32_t*>(counts.data_ptr<int>()), H, G, \
stride_i_e, stride_i_t, stride_i_h, stride_yq_e, stride_yq_t, \
stride_yq_h, stride_ys_e, stride_ys_t, stride_ys_g, \
stride_counts_e); \
}
#define KERNEL_CALL_H \
if (H % (4 * GROUP_SIZE) == 0) { \
static constexpr int NUM_WARPS = 4; \
populate_launch_params(NUM_WARPS, NUM_PARALLEL_TOKENS); \
KERNEL_FN \
} else { \
static constexpr int NUM_WARPS = 1; \
populate_launch_params(NUM_WARPS, NUM_PARALLEL_TOKENS); \
KERNEL_FN \
}
#define KERNEL_CALL_TOP_LEVEL \
if (num_parallel_tokens == 1) { \
static constexpr int NUM_PARALLEL_TOKENS = 1; \
KERNEL_CALL_H \
} else if (num_parallel_tokens == 2) { \
static constexpr int NUM_PARALLEL_TOKENS = 2; \
KERNEL_CALL_H \
} else if (num_parallel_tokens == 4) { \
static constexpr int NUM_PARALLEL_TOKENS = 4; \
KERNEL_CALL_H \
} else if (num_parallel_tokens == 8) { \
static constexpr int NUM_PARALLEL_TOKENS = 8; \
KERNEL_CALL_H \
} else if (num_parallel_tokens == 16) { \
static constexpr int NUM_PARALLEL_TOKENS = 16; \
KERNEL_CALL_H \
} else if (num_parallel_tokens == 32) { \
static constexpr int NUM_PARALLEL_TOKENS = 32; \
KERNEL_CALL_H \
} else if (num_parallel_tokens == 64) { \
static constexpr int NUM_PARALLEL_TOKENS = 64; \
KERNEL_CALL_H \
}
Idx_t
G
;
dim3
block
,
grid
;
auto
populate_launch_params
=
[
&
](
int
num_warps
,
int
_num_parallel_tokens
)
{
G
=
H
/
Idx_t
(
group_size
*
num_warps
);
grid
=
dim3
(
E
*
G
,
_num_parallel_tokens
);
block
=
dim3
(
num_warps
*
WARP_SIZE
);
};
const
cudaStream_t
stream
=
at
::
cuda
::
getCurrentCUDAStream
();
const
at
::
cuda
::
OptionalCUDAGuard
device_guard
(
device_of
(
input
));
VLLM_DISPATCH_FP8_TYPES
(
y_q
.
scalar_type
(),
"silu_mul_fp8_quant_deep_gemm_kernel"
,
[
&
]
{
KERNEL_CALL_TOP_LEVEL
});
#endif
}
csrc/torch_bindings.cpp
View file @
98229db2
...
...
@@ -32,6 +32,13 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
#define stride_tag
#endif
ops
.
def
(
"silu_mul_fp8_quant_deep_gemm_cuda(Tensor input, Tensor counts, Tensor! "
"y_q, Tensor! y_s, int group_size, "
"bool use_ue8m0, int num_parallel_tokens) -> ()"
);
ops
.
impl
(
"silu_mul_fp8_quant_deep_gemm_cuda"
,
torch
::
kCUDA
,
&
silu_mul_fp8_quant_deep_gemm_cuda
);
ops
.
def
(
"weak_ref_tensor(Tensor input) -> Tensor"
);
ops
.
impl
(
"weak_ref_tensor"
,
torch
::
kCUDA
,
&
weak_ref_tensor
);
...
...
tests/kernels/moe/test_silu_mul_fp8_quant_deep_gemm.py
View file @
98229db2
...
...
@@ -5,28 +5,52 @@ import pytest
import
torch
from
vllm.model_executor.layers.fused_moe.batched_deep_gemm_moe
import
(
silu_mul_fp8_quant_deep_gemm
)
silu_mul_fp8_quant_deep_gemm
_cuda
)
from
vllm.platforms
import
current_platform
from
vllm.utils
import
cdiv
fp8_dtype
=
torch
.
float8_e4m3fn
# (E, T, H, group_size, seed)
CASES
=
[
(
1
,
1
,
128
,
64
,
0
),
(
1
,
4
,
128
,
128
,
0
),
(
2
,
4
,
256
,
128
,
0
),
(
32
,
64
,
256
,
128
,
0
),
(
17
,
31
,
768
,
128
,
0
),
(
1
,
1
,
128
,
fp8_dtype
),
(
1
,
4
,
128
,
fp8_dtype
),
(
2
,
4
,
256
,
fp8_dtype
),
(
32
,
64
,
256
,
fp8_dtype
),
(
17
,
31
,
768
,
fp8_dtype
),
(
1
,
1
,
128
*
1
,
fp8_dtype
),
(
1
,
1
,
128
*
2
,
fp8_dtype
),
(
1
,
1
,
128
*
3
,
fp8_dtype
),
(
1
,
1
,
128
*
4
,
fp8_dtype
),
(
8
,
16
,
128
*
1
,
fp8_dtype
),
(
8
,
16
,
128
*
2
,
fp8_dtype
),
(
8
,
16
,
128
*
3
,
fp8_dtype
),
(
8
,
16
,
128
*
4
,
fp8_dtype
),
(
8
,
64
,
7168
,
fp8_dtype
),
(
8
,
128
,
7168
,
fp8_dtype
),
(
8
,
256
,
7168
,
fp8_dtype
),
(
8
,
512
,
7168
,
fp8_dtype
),
(
8
,
1024
,
7168
,
fp8_dtype
),
(
256
,
8
,
7168
,
fp8_dtype
),
(
256
,
16
,
7168
,
fp8_dtype
),
(
256
,
32
,
7168
,
fp8_dtype
),
(
256
,
64
,
7168
,
fp8_dtype
),
# Only add a few fnuz tests to help with long CI times.
(
8
,
512
,
7168
,
torch
.
float8_e4m3fnuz
),
(
8
,
1024
,
7168
,
torch
.
float8_e4m3fnuz
),
]
@
pytest
.
mark
.
parametrize
(
"E,T,H,
group_size,seed
"
,
CASES
)
@
pytest
.
mark
.
parametrize
(
"E,T,H,
fp8_type
"
,
CASES
)
@
torch
.
inference_mode
()
def
test_silu_mul_fp8_quant_deep_gemm
(
E
,
T
,
H
,
group_size
,
seed
):
current_platform
.
seed_everything
(
seed
)
def
test_silu_mul_fp8_quant_deep_gemm
(
E
,
T
,
H
,
fp8_type
):
group_size
=
128
current_platform
.
seed_everything
(
42
)
# Input tensor of shape (E, T, 2*H)
y
=
torch
.
randn
((
E
,
T
,
2
*
H
),
dtype
=
torch
.
bfloat16
,
device
=
"cuda"
)
tokens_per_expert
=
torch
.
randint
(
low
=
0
,
low
=
T
//
2
,
high
=
T
,
size
=
(
E
,
),
dtype
=
torch
.
int32
,
...
...
@@ -34,45 +58,59 @@ def test_silu_mul_fp8_quant_deep_gemm(E, T, H, group_size, seed):
)
# Run the Triton kernel
y_q
,
y_s
=
silu_mul_fp8_quant_deep_gemm
(
y
,
y_q
,
y_s
=
silu_mul_fp8_quant_deep_gemm
_cuda
(
y
,
tokens_per_expert
,
group_size
=
group_size
,
eps
=
1e-10
)
group_size
=
group_size
)
# Reference implementation
fp8_info
=
torch
.
finfo
(
torch
.
float8_e4m3fn
)
torch
.
cuda
.
synchronize
()
fp8_info
=
torch
.
finfo
(
fp8_dtype
)
fp8_max
=
fp8_info
.
max
fp8_min
=
fp8_info
.
min
eps
=
1e-10
# Compute silu activation and elementwise multiplication
y1
=
y
[...,
:
H
]
y1
=
y
[...,
:
H
].
float
()
y2
=
y
[...,
H
:]
silu_x
=
y1
*
torch
.
sigmoid
(
y1
)
merged
=
silu_x
*
y2
# Compute reference scales and quantized output, skipping padded tokens
for
e
in
range
(
E
):
nt
=
tokens_per_expert
[
e
].
item
()
ref_s
=
torch
.
empty
((
T
,
H
//
group_size
),
ref_s
=
torch
.
empty
((
T
,
cdiv
(
H
,
group_size
)
)
,
dtype
=
torch
.
float32
,
device
=
"cuda"
)
ref_q
=
torch
.
empty
((
T
,
H
),
dtype
=
torch
.
float8_e4m3fn
,
device
=
"cuda"
)
ref_q
=
torch
.
empty
((
T
,
H
),
dtype
=
fp8_dtype
,
device
=
"cuda"
)
for
t
in
range
(
nt
):
data
=
merged
[
e
,
t
]
data_grp
=
data
.
view
(
H
//
group_size
,
group_size
)
data
=
merged
[
e
,
t
].
float
()
ref_q_row
=
torch
.
empty_like
(
data
)
# process full groups
n_full_groups
=
H
//
group_size
if
n_full_groups
>
0
:
data_grp
=
data
[:
n_full_groups
*
group_size
].
view
(
n_full_groups
,
group_size
)
amax
=
data_grp
.
abs
().
amax
(
dim
=
1
).
clamp
(
min
=
eps
)
scale
=
amax
/
fp8_max
scaled
=
data
[:
n_full_groups
*
group_size
]
/
scale
.
repeat_interleave
(
group_size
)
ref_q_row
[:
n_full_groups
*
group_size
]
=
scaled
.
clamp
(
fp8_min
,
fp8_max
).
to
(
fp8_dtype
)
ref_s
[
t
,
:
n_full_groups
]
=
scale
scaled
=
data
/
scale
.
repeat_interleave
(
group_size
)
clamped
=
scaled
.
clamp
(
fp8_min
,
fp8_max
)
q
=
clamped
.
to
(
torch
.
float8_e4m3fn
)
# process remainder group
rem
=
H
%
group_size
if
rem
>
0
:
data_rem
=
data
[
-
rem
:]
amax
=
data_rem
.
abs
().
amax
().
clamp
(
min
=
eps
)
scale
=
amax
/
fp8_max
scaled
=
data_rem
/
scale
ref_q_row
[
-
rem
:]
=
scaled
.
clamp
(
fp8_min
,
fp8_max
).
to
(
fp8_dtype
)
ref_s
[
t
,
-
1
]
=
scale
ref_s
[
t
]
=
scale
ref_q
[
t
]
=
q
ref_q
[
t
]
=
ref_q_row
y_se
=
y_s
[
e
]
y_qe
=
y_q
[
e
]
y_se
=
y_s
[
e
]
.
float
()
y_qe
=
y_q
[
e
]
.
float
()
torch
.
testing
.
assert_close
(
y_se
[:
nt
],
ref_s
[:
nt
],
atol
=
1e-4
,
rtol
=
1e-2
)
torch
.
testing
.
assert_close
(
...
...
vllm/model_executor/layers/fused_moe/batched_deep_gemm_moe.py
View file @
98229db2
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from
math
import
log2
from
typing
import
Optional
import
torch
...
...
@@ -10,6 +11,7 @@ from vllm.model_executor.layers.fused_moe.config import FusedMoEQuantConfig
from
vllm.model_executor.layers.fused_moe.topk_weight_and_reduce
import
(
TopKWeightAndReduceDelegate
)
from
vllm.model_executor.layers.fused_moe.utils
import
_resize_cache
from
vllm.platforms
import
current_platform
from
vllm.triton_utils
import
tl
,
triton
from
vllm.utils.deep_gemm
import
(
fp8_m_grouped_gemm_nt_masked
,
is_deep_gemm_e8m0_used
)
...
...
@@ -24,35 +26,28 @@ def _silu_mul_fp8_quant_deep_gemm(
y_q_ptr
,
# fp8 quantized activations (E, T, H)
y_s_ptr
,
# 16-bit scales (E, T, G)
counts_ptr
,
# int32 num tokens per expert (E)
# Sizes ---------------------------------------------------------------
H
:
tl
.
constexpr
,
# hidden dimension (per output)
GROUP_SIZE
:
tl
.
constexpr
,
# elements per group (usually 128)
# Strides for input (elements) ---------------------------------------
stride_i_e
,
stride_i_t
,
stride_i_h
,
# Strides for y_q (elements) -----------------------------------------
stride_yq_e
,
stride_yq_t
,
stride_yq_h
,
# Strides for y_s (elements) -----------------------------------------
stride_ys_e
,
stride_ys_t
,
stride_ys_g
,
# Stride for counts (elements)
stride_counts_e
,
# Numeric params ------------------------------------------------------
eps
:
tl
.
constexpr
,
fp8_min
:
tl
.
constexpr
,
fp8_max
:
tl
.
constexpr
,
use_ue8m0
:
tl
.
constexpr
,
# Meta ---------------------------------------------------------------
BLOCK
:
tl
.
constexpr
,
NUM_STAGES
:
tl
.
constexpr
,
...
...
@@ -101,17 +96,15 @@ def _silu_mul_fp8_quant_deep_gemm(
tl
.
store
(
y_s_ptr
+
base_ys_offset
+
t
*
stride_ys_t
,
y_s
)
def
silu_mul_fp8_quant_deep_gemm
(
def
silu_mul_fp8_quant_deep_gemm
_cuda
(
y
:
torch
.
Tensor
,
# (E, T, 2*H)
tokens_per_expert
:
torch
.
Tensor
,
# (E,) number of valid tokens per expert
num_parallel_tokens
=
16
,
group_size
:
int
=
128
,
eps
:
float
=
1e-10
,
)
->
tuple
[
torch
.
Tensor
,
torch
.
Tensor
]:
"""Quantize silu(y[..., :H]) * y[..., H:] to FP8 with group per-token scales
y has shape (E, T, 2*H). The first half of the last dimension is
silu-activated, multiplied by the second half, then quantized into FP8.
Returns `(y_q, y_s)` where
* `y_q`: FP8 tensor, shape (E, T, H), same layout as y[..., :H]
* `y_s`: FP32 tensor, shape (E, T, H // group_size), strides (T*G, 1, T)
...
...
@@ -120,22 +113,17 @@ def silu_mul_fp8_quant_deep_gemm(
E
,
T
,
H2
=
y
.
shape
assert
H2
%
2
==
0
,
"last dim of y must be even (2*H)"
H
=
H2
//
2
G
=
H
//
group_size
assert
H
%
group_size
==
0
,
"H must be divisible by group_size"
assert
tokens_per_expert
.
ndim
==
1
and
tokens_per_expert
.
shape
[
0
]
==
E
,
\
"tokens_per_expert must be shape (E,)"
G
=
(
H
+
group_size
-
1
)
//
group_size
assert
H
%
8
==
0
,
"H must be divisible by 8"
assert
group_size
==
128
,
"H must be divisible by 8"
assert
tokens_per_expert
.
ndim
==
1
and
tokens_per_expert
.
shape
[
0
]
==
E
tokens_per_expert
=
tokens_per_expert
.
to
(
device
=
y
.
device
,
dtype
=
torch
.
int32
)
# allocate outputs
fp8_dtype
=
torch
.
float8_e4m3fn
y_q
=
torch
.
empty
((
E
,
T
,
H
),
dtype
=
fp8_dtype
,
device
=
y
.
device
)
# strides (elements)
stride_i_e
,
stride_i_t
,
stride_i_h
=
y
.
stride
()
stride_yq_e
,
stride_yq_t
,
stride_yq_h
=
y_q
.
stride
()
# desired scale strides (elements): (T*G, 1, T)
stride_ys_e
=
T
*
G
stride_ys_t
=
1
stride_ys_g
=
T
...
...
@@ -144,16 +132,56 @@ def silu_mul_fp8_quant_deep_gemm(
dtype
=
torch
.
float32
,
device
=
y
.
device
)
use_ue8m0
=
is_deep_gemm_e8m0_used
()
if
E
<=
16
:
max_empirical_parallelism
=
64
elif
E
<=
32
:
max_empirical_parallelism
=
16
else
:
max_empirical_parallelism
=
4
# We never want to launch more than Tx number of threads
# This computes the clip.
num_parallel_tokens
=
max
(
1
,
min
(
max_empirical_parallelism
,
2
**
int
(
log2
(
min
(
num_parallel_tokens
,
T
)))))
cuda_arch
=
current_platform
.
get_device_capability
(
device_id
=
y
.
device
.
index
).
to_int
()
if
cuda_arch
>=
80
:
torch
.
ops
.
_C
.
silu_mul_fp8_quant_deep_gemm_cuda
(
y
,
tokens_per_expert
,
y_q
,
y_s
,
group_size
,
use_ue8m0
,
num_parallel_tokens
)
else
:
# Default to triton if not on cuda or if arch is too old
y_q
=
torch
.
empty
((
E
,
T
,
H
),
dtype
=
fp8_dtype
,
device
=
y
.
device
)
stride_cnt_e
=
tokens_per_expert
.
stride
()[
0
]
# Static grid over experts and H-groups.
# A loop inside the kernel handles the token dim
grid
=
(
E
*
G
,
)
# strides (elements)
stride_i_e
,
stride_i_t
,
stride_i_h
=
y
.
stride
()
stride_yq_e
,
stride_yq_t
,
stride_yq_h
=
y_q
.
stride
()
# desired scale strides (elements): (T*G, 1, T)
stride_ys_e
=
T
*
G
stride_ys_t
=
1
stride_ys_g
=
T
y_s
=
torch
.
empty_strided
(
(
E
,
T
,
G
),
(
stride_ys_e
,
stride_ys_t
,
stride_ys_g
),
dtype
=
torch
.
float32
,
device
=
y
.
device
,
)
f_info
=
torch
.
finfo
(
fp8_dtype
)
fp8_max
=
f_info
.
max
fp8_min
=
f_info
.
min
eps
:
float
=
1e-10
_silu_mul_fp8_quant_deep_gemm
[
grid
](
y
,
y_q
,
...
...
@@ -184,7 +212,6 @@ def silu_mul_fp8_quant_deep_gemm(
class
BatchedDeepGemmExperts
(
mk
.
FusedMoEPermuteExpertsUnpermute
):
# The Deep Gemm kernels only support block size of 128
DEEPGEMM_BLOCK_SHAPE
:
list
[
int
]
=
[
128
,
128
]
...
...
@@ -297,8 +324,8 @@ class BatchedDeepGemmExperts(mk.FusedMoEPermuteExpertsUnpermute):
fp8_m_grouped_gemm_nt_masked
((
a1q
,
a1q_scale
),
(
w1
,
w1_scale
),
workspace1
,
expert_num_tokens
,
expected_m
)
a2q
,
a2q_scale
=
silu_mul_fp8_quant_deep_gemm
(
workspace1
,
expert_num_tokens
)
a2q
,
a2q_scale
=
silu_mul_fp8_quant_deep_gemm
_cuda
(
workspace1
,
expert_num_tokens
)
fp8_m_grouped_gemm_nt_masked
((
a2q
,
a2q_scale
),
(
w2
,
w2_scale
),
output
,
expert_num_tokens
,
expected_m
)
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment