Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
a6c4b87f
Unverified
Commit
a6c4b87f
authored
Jun 24, 2025
by
Wentao Ye
Committed by
GitHub
Jun 24, 2025
Browse files
Revert "[Feature] Integrate new deepgemm (#19820)" (#20049)
Signed-off-by:
yewentao256
<
zhyanwentao@126.com
>
parent
1afa9948
Changes
8
Show whitespace changes
Inline
Side-by-side
Showing
8 changed files
with
257 additions
and
233 deletions
+257
-233
benchmarks/kernels/benchmark_moe.py
benchmarks/kernels/benchmark_moe.py
+0
-3
benchmarks/kernels/deepgemm/benchmark_fp8_block_dense_gemm.py
...hmarks/kernels/deepgemm/benchmark_fp8_block_dense_gemm.py
+174
-133
tests/kernels/moe/test_deepep_deepgemm_moe.py
tests/kernels/moe/test_deepep_deepgemm_moe.py
+21
-2
tests/kernels/quantization/test_block_fp8.py
tests/kernels/quantization/test_block_fp8.py
+41
-14
vllm/model_executor/layers/fused_moe/batched_deep_gemm_moe.py
.../model_executor/layers/fused_moe/batched_deep_gemm_moe.py
+11
-8
vllm/model_executor/layers/fused_moe/deep_gemm_moe.py
vllm/model_executor/layers/fused_moe/deep_gemm_moe.py
+6
-4
vllm/model_executor/layers/quantization/deepgemm.py
vllm/model_executor/layers/quantization/deepgemm.py
+1
-1
vllm/model_executor/layers/quantization/utils/fp8_utils.py
vllm/model_executor/layers/quantization/utils/fp8_utils.py
+3
-68
No files found.
benchmarks/kernels/benchmark_moe.py
View file @
a6c4b87f
...
@@ -86,9 +86,6 @@ def benchmark_config(
...
@@ -86,9 +86,6 @@ def benchmark_config(
(
num_experts
,
2
*
shard_intermediate_size
),
dtype
=
torch
.
float32
(
num_experts
,
2
*
shard_intermediate_size
),
dtype
=
torch
.
float32
)
)
w2_scale
=
torch
.
randn
((
hidden_size
,
num_experts
),
dtype
=
torch
.
float32
)
w2_scale
=
torch
.
randn
((
hidden_size
,
num_experts
),
dtype
=
torch
.
float32
)
if
use_deep_gemm
:
# we use the default block shape for deepgemm
block_quant_shape
=
[
128
,
128
]
if
use_fp8_w8a8
:
if
use_fp8_w8a8
:
if
block_quant_shape
:
if
block_quant_shape
:
block_n
,
block_k
=
block_quant_shape
[
0
],
block_quant_shape
[
1
]
block_n
,
block_k
=
block_quant_shape
[
0
],
block_quant_shape
[
1
]
...
...
benchmarks/kernels/deepgemm/benchmark_fp8_block_dense_gemm.py
View file @
a6c4b87f
# SPDX-License-Identifier: Apache-2.0
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
# fmt: off
# ruff: noqa: E501
import
time
import
time
# Import DeepGEMM functions
import
deep_gemm
import
torch
import
torch
from
deep_gemm
import
fp8_gemm_nt
from
deep_gemm
import
calc_diff
,
ceil_div
,
get_col_major_tma_aligned_tensor
from
deep_gemm.testing.numeric
import
calc_diff
from
deep_gemm.utils.math
import
ceil_div
,
per_block_cast_to_fp8
,
per_token_cast_to_fp8
# Import vLLM functions
# Import vLLM functions
from
vllm
import
_custom_ops
as
ops
from
vllm
import
_custom_ops
as
ops
...
@@ -16,84 +18,96 @@ from vllm.model_executor.layers.quantization.utils.fp8_utils import (
...
@@ -16,84 +18,96 @@ from vllm.model_executor.layers.quantization.utils.fp8_utils import (
from
vllm.triton_utils
import
triton
from
vllm.triton_utils
import
triton
# Copied from
# https://github.com/deepseek-ai/DeepGEMM/blob/78cacf70d41d15d688bd493ebc85845f7f2a3d5d/tests/test_core.py#L9
def
per_token_cast_to_fp8
(
x
:
torch
.
Tensor
)
->
tuple
[
torch
.
Tensor
,
torch
.
Tensor
]:
"""Convert tensor to FP8 format with per-token scaling."""
assert
x
.
dim
()
==
2
and
x
.
size
(
1
)
%
128
==
0
m
,
n
=
x
.
shape
x_view
=
x
.
view
(
m
,
-
1
,
128
)
x_amax
=
x_view
.
abs
().
float
().
amax
(
dim
=
2
).
view
(
m
,
-
1
).
clamp
(
1e-4
)
return
(
x_view
*
(
448.0
/
x_amax
.
unsqueeze
(
2
))).
to
(
torch
.
float8_e4m3fn
).
view
(
m
,
n
),
(
x_amax
/
448.0
).
view
(
m
,
-
1
)
# Copied from
# Copied from
# https://github.com/deepseek-ai/DeepGEMM/blob/78cacf70d41d15d688bd493ebc85845f7f2a3d5d/tests/test_core.py#L17
# https://github.com/deepseek-ai/DeepGEMM/blob/78cacf70d41d15d688bd493ebc85845f7f2a3d5d/tests/test_core.py#L17
def
per_block_cast_to_fp8_vllm
(
x
:
torch
.
Tensor
)
->
tuple
[
torch
.
Tensor
,
torch
.
Tensor
]:
def
per_block_cast_to_fp8
(
x
:
torch
.
Tensor
)
->
tuple
[
torch
.
Tensor
,
torch
.
Tensor
]:
"""Convert tensor to FP8 format with per-block scaling."""
"""Convert tensor to FP8 format with per-block scaling."""
assert
x
.
dim
()
==
2
assert
x
.
dim
()
==
2
m
,
n
=
x
.
shape
m
,
n
=
x
.
shape
x_padded
=
torch
.
zeros
(
x_padded
=
torch
.
zeros
(
(
ceil_div
(
m
,
128
)
*
128
,
ceil_div
(
n
,
128
)
*
128
),
(
ceil_div
(
m
,
128
)
*
128
,
ceil_div
(
n
,
128
)
*
128
),
dtype
=
x
.
dtype
,
device
=
x
.
device
dtype
=
x
.
dtype
,
)
device
=
x
.
device
)
x_padded
[:
m
,
:
n
]
=
x
x_padded
[:
m
,
:
n
]
=
x
x_view
=
x_padded
.
view
(
-
1
,
128
,
x_padded
.
size
(
1
)
//
128
,
128
)
x_view
=
x_padded
.
view
(
-
1
,
128
,
x_padded
.
size
(
1
)
//
128
,
128
)
x_amax
=
x_view
.
abs
().
float
().
amax
(
dim
=
(
1
,
3
),
keepdim
=
True
).
clamp
(
1e-4
)
x_amax
=
x_view
.
abs
().
float
().
amax
(
dim
=
(
1
,
3
),
keepdim
=
True
).
clamp
(
1e-4
)
x_scaled
=
(
x_view
*
(
448.0
/
x_amax
)).
to
(
torch
.
float8_e4m3fn
)
x_scaled
=
(
x_view
*
(
448.0
/
x_amax
)).
to
(
torch
.
float8_e4m3fn
)
return
x_scaled
.
view_as
(
x_padded
)[:
m
,
:
n
].
contiguous
(),
(
x_amax
/
448.0
).
view
(
return
x_scaled
.
view_as
(
x_padded
)[:
m
,
:
n
].
contiguous
(),
(
x_view
.
size
(
0
),
x_view
.
size
(
2
)
x_amax
/
448.0
).
view
(
x_view
.
size
(
0
),
x_view
.
size
(
2
))
)
def
benchmark_shape
(
def
benchmark_shape
(
m
:
int
,
m
:
int
,
n
:
int
,
n
:
int
,
k
:
int
,
k
:
int
,
warmup
:
int
=
100
,
warmup
:
int
=
100
,
repeat
:
int
=
10000
,
repeat
:
int
=
10000
,
verbose
:
bool
=
False
,
verbose
:
bool
=
False
)
->
dict
:
)
->
dict
:
"""Benchmark all implementations for a specific (m, n, k) shape."""
"""Benchmark all implementations for a specific (m, n, k) shape."""
if
verbose
:
if
verbose
:
print
(
f
"
\n
=== Benchmarking shape: m=
{
m
}
, n=
{
n
}
, k=
{
k
}
==="
)
print
(
f
"
\n
=== Benchmarking shape: m=
{
m
}
, n=
{
n
}
, k=
{
k
}
==="
)
A
=
torch
.
randn
((
m
,
k
),
device
=
"cuda"
,
dtype
=
torch
.
bfloat16
)
# Create test tensors
B
=
torch
.
randn
((
n
,
k
),
device
=
"cuda"
,
dtype
=
torch
.
bfloat16
)
A
=
torch
.
randn
((
m
,
k
),
device
=
'cuda'
,
dtype
=
torch
.
bfloat16
)
B
=
torch
.
randn
((
n
,
k
),
device
=
'cuda'
,
dtype
=
torch
.
bfloat16
)
# Reference result in BF16
torch
.
cuda
.
synchronize
()
torch
.
cuda
.
synchronize
()
C_ref
=
A
@
B
.
t
()
C_ref
=
A
@
B
.
t
()
# Pre-quantize B for all implementations
# Pre-quantize B for all implementations
# (weights can be pre-quantized offline)
# (weights can be pre-quantized offline)
B_deepgemm
,
B_scale_deepgemm
=
per_block_cast_to_fp8
(
B
)
B_deepgemm
,
B_scale_deepgemm
=
per_block_cast_to_fp8
(
B
)
B_vllm
,
B_scale_vllm
=
per_block_cast_to_fp8
_vllm
(
B
)
B_vllm
,
B_scale_vllm
=
per_block_cast_to_fp8
(
B
)
# Block size configuration
# Block size configuration
block_size
=
[
128
,
128
]
block_size
=
[
128
,
128
]
# Pre-quantize A for all implementations
# Pre-quantize A for all implementations
A_deepgemm
,
A_scale_deepgemm
=
per_token_cast_to_fp8
(
A
)
A_deepgemm
,
A_scale_deepgemm
=
per_token_cast_to_fp8
(
A
)
C_deepgemm
=
(
A_scale_deepgemm
=
get_col_major_tma_aligned_tensor
(
A_scale_deepgemm
)
torch
.
empty
((
n
,
m
),
device
=
"cuda"
,
dtype
=
torch
.
bfloat16
).
t
().
contiguous
()
C_deepgemm
=
torch
.
empty
((
m
,
n
),
device
=
'cuda'
,
dtype
=
torch
.
bfloat16
)
)
A_vllm
,
A_scale_vllm
=
per_token_group_quant_fp8
(
A
,
block_size
[
1
])
A_vllm
,
A_scale_vllm
=
per_token_group_quant_fp8
(
A
,
block_size
[
1
])
A_vllm_cutlass
,
A_scale_vllm_cutlass
=
per_token_group_quant_fp8
(
A_vllm_cutlass
,
A_scale_vllm_cutlass
=
per_token_group_quant_fp8
(
A
,
block_size
[
1
],
column_major_scales
=
True
A
,
block_size
[
1
],
column_major_scales
=
True
)
)
# === DeepGEMM Implementation ===
def
deepgemm_gemm
():
def
deepgemm_gemm
():
fp8
_gemm
_nt
(
deep
_gemm
.
gemm_fp8_fp8_bf16_nt
((
A_deepgemm
,
A_scale_deepgemm
),
(
A_deepgemm
,
A_scale_deepgemm
),
(
B_deepgemm
,
B_scale_deepgemm
),
C_deepgemm
(
B_deepgemm
,
B_scale_deepgemm
),
)
C_deepgemm
)
return
C_deepgemm
return
C_deepgemm
# === vLLM Triton Implementation ===
def
vllm_triton_gemm
():
def
vllm_triton_gemm
():
return
w8a8_block_fp8_matmul
(
return
w8a8_block_fp8_matmul
(
A_vllm
,
A_vllm
,
B_vllm
,
B_vllm
,
A_scale_vllm
,
A_scale_vllm
,
B_scale_vllm
,
B_scale_vllm
,
block_size
,
block_size
,
output_dtype
=
torch
.
bfloat16
,
output_dtype
=
torch
.
bfloat16
)
)
# === vLLM CUTLASS Implementation ===
def
vllm_cutlass_gemm
():
def
vllm_cutlass_gemm
():
return
ops
.
cutlass_scaled_mm
(
return
ops
.
cutlass_scaled_mm
(
A_vllm_cutlass
,
A_vllm_cutlass
,
B_vllm
.
T
,
B_vllm
.
T
,
scale_a
=
A_scale_vllm_cutlass
,
scale_a
=
A_scale_vllm_cutlass
,
scale_b
=
B_scale_vllm
.
T
,
scale_b
=
B_scale_vllm
.
T
,
out_dtype
=
torch
.
bfloat16
,
out_dtype
=
torch
.
bfloat16
)
)
# Run correctness check first
if
verbose
:
if
verbose
:
print
(
"Running correctness check..."
)
print
(
"Running correctness check..."
)
C_deepgemm
=
deepgemm_gemm
()
C_deepgemm
=
deepgemm_gemm
()
...
@@ -108,22 +122,26 @@ def benchmark_shape(
...
@@ -108,22 +122,26 @@ def benchmark_shape(
print
(
f
"DeepGEMM vs Reference difference:
{
deepgemm_diff
:.
6
f
}
"
)
print
(
f
"DeepGEMM vs Reference difference:
{
deepgemm_diff
:.
6
f
}
"
)
print
(
f
"vLLM Triton vs Reference difference:
{
vllm_triton_diff
:.
6
f
}
"
)
print
(
f
"vLLM Triton vs Reference difference:
{
vllm_triton_diff
:.
6
f
}
"
)
print
(
f
"vLLM CUTLASS vs Reference difference:
{
vllm_cutlass_diff
:.
6
f
}
"
)
print
(
f
"vLLM CUTLASS vs Reference difference:
{
vllm_cutlass_diff
:.
6
f
}
"
)
print
(
print
(
"vLLM Triton vs DeepGEMM difference: "
"vLLM Triton vs DeepGEMM difference: "
f
"
{
calc_diff
(
C_vllm_triton
,
C_deepgemm
):.
6
f
}
"
)
f
"
{
calc_diff
(
C_vllm_triton
,
C_deepgemm
):.
6
f
}
"
print
(
"vLLM CUTLASS vs DeepGEMM difference: "
)
f
"
{
calc_diff
(
C_vllm_cutlass
,
C_deepgemm
):.
6
f
}
"
)
print
(
"vLLM CUTLASS vs DeepGEMM difference: "
f
"
{
calc_diff
(
C_vllm_cutlass
,
C_deepgemm
):.
6
f
}
"
)
# Benchmark implementations
implementations
=
{
implementations
=
{
"DeepGEMM"
:
deepgemm_gemm
,
"DeepGEMM"
:
deepgemm_gemm
,
"vLLM Triton"
:
vllm_triton_gemm
,
"vLLM Triton"
:
vllm_triton_gemm
,
"vLLM CUTLASS"
:
vllm_cutlass_gemm
,
"vLLM CUTLASS"
:
vllm_cutlass_gemm
}
}
benchmark_results
=
{
"shape"
:
{
"m"
:
m
,
"n"
:
n
,
"k"
:
k
},
"implementations"
:
{}}
benchmark_results
=
{
"shape"
:
{
"m"
:
m
,
"n"
:
n
,
"k"
:
k
},
"implementations"
:
{}
}
for
name
,
func
in
implementations
.
items
():
for
name
,
func
in
implementations
.
items
():
# Warmup
# Warmup
...
@@ -151,36 +169,38 @@ def benchmark_shape(
...
@@ -151,36 +169,38 @@ def benchmark_shape(
"tflops"
:
tflops
,
"tflops"
:
tflops
,
"gb_s"
:
gb_s
,
"gb_s"
:
gb_s
,
"diff"
:
{
"diff"
:
{
"DeepGEMM"
:
0.0
"DeepGEMM"
:
if
name
==
"DeepGEMM"
0.0
if
name
==
"DeepGEMM"
else
calc_diff
(
func
(),
C_deepgemm
),
else
calc_diff
(
func
(),
C_deepgemm
),
"Reference"
:
"Reference"
:
deepgemm_diff
deepgemm_diff
if
name
==
"DeepGEMM"
else
if
name
==
"DeepGEMM"
(
vllm_triton_diff
else
(
vllm_triton_diff
if
name
==
"vLLM Triton"
else
vllm_cutlass_diff
)
,
if
name
==
"vLLM Triton"
else
vllm_cutlass_diff
)
}
,
}
}
}
if
verbose
:
if
verbose
:
print
(
f
"
{
name
}
:
{
avg_time_ms
:.
3
f
}
ms,
{
tflops
:.
2
f
}
TFLOPS,
{
gb_s
:.
2
f
}
GB/s"
)
print
(
f
"
{
name
}
:
{
avg_time_ms
:.
3
f
}
ms,
{
tflops
:.
2
f
}
TFLOPS,
{
gb_s
:.
2
f
}
GB/s"
)
# Calculate speedups
# Calculate speedups
baseline
=
benchmark_results
[
"implementations"
][
"DeepGEMM"
][
"time_ms"
]
baseline
=
benchmark_results
[
"implementations"
][
"DeepGEMM"
][
"time_ms"
]
for
name
,
data
in
benchmark_results
[
"implementations"
].
items
():
for
name
,
data
in
benchmark_results
[
"implementations"
].
items
():
if
name
!=
"DeepGEMM"
:
if
name
!=
"DeepGEMM"
:
speedup
=
baseline
/
data
[
"time_ms"
]
speedup
=
baseline
/
data
[
"time_ms"
]
benchmark_results
[
"implementations"
][
name
][
"speedup_vs_deepgemm"
]
=
speedup
benchmark_results
[
"implementations"
][
name
][
"speedup_vs_deepgemm"
]
=
speedup
if
verbose
:
if
verbose
:
print
(
print
(
f
"DeepGEMM is
{
1
/
speedup
:.
2
f
}
x "
f
"DeepGEMM is
{
1
/
speedup
:.
2
f
}
x "
f
"
{
'faster'
if
1
/
speedup
>
1
else
'slower'
}
than
{
name
}
"
)
f
"
{
'faster'
if
1
/
speedup
>
1
else
'slower'
}
than
{
name
}
"
)
vllm_triton_time
=
benchmark_results
[
"implementations"
][
"vLLM Triton"
][
"time_ms"
]
vllm_triton_time
=
benchmark_results
[
"implementations"
][
"vLLM Triton"
][
vllm_cutlass_time
=
benchmark_results
[
"implementations"
][
"vLLM CUTLASS"
][
"time_ms"
]
"time_ms"
]
vllm_cutlass_time
=
benchmark_results
[
"implementations"
][
"vLLM CUTLASS"
][
"time_ms"
]
cutlass_vs_triton
=
vllm_triton_time
/
vllm_cutlass_time
cutlass_vs_triton
=
vllm_triton_time
/
vllm_cutlass_time
benchmark_results
[
"implementations"
][
"vLLM CUTLASS"
][
"speedup_vs_triton"
]
=
(
benchmark_results
[
"implementations"
][
"vLLM CUTLASS"
][
cutlass_vs_triton
"speedup_vs_triton"
]
=
cutlass_vs_triton
)
if
verbose
:
if
verbose
:
print
(
print
(
f
"vLLM CUTLASS is
{
cutlass_vs_triton
:.
2
f
}
x "
f
"vLLM CUTLASS is
{
cutlass_vs_triton
:.
2
f
}
x "
...
@@ -192,7 +212,8 @@ def benchmark_shape(
...
@@ -192,7 +212,8 @@ def benchmark_shape(
def
format_table_row
(
values
,
widths
):
def
format_table_row
(
values
,
widths
):
"""Format a row with specified column widths."""
"""Format a row with specified column widths."""
return
"| "
+
" | "
.
join
(
f
"
{
val
:
{
w
}}
"
for
val
,
w
in
zip
(
values
,
widths
))
+
" |"
return
"| "
+
" | "
.
join
(
f
"
{
val
:
{
w
}}
"
for
val
,
w
in
zip
(
values
,
widths
))
+
" |"
def
print_table
(
headers
,
rows
,
title
=
None
):
def
print_table
(
headers
,
rows
,
title
=
None
):
...
@@ -200,12 +221,16 @@ def print_table(headers, rows, title=None):
...
@@ -200,12 +221,16 @@ def print_table(headers, rows, title=None):
if
title
:
if
title
:
print
(
f
"
\n
{
title
}
"
)
print
(
f
"
\n
{
title
}
"
)
# Calculate column widths based on headers and data
widths
=
[
widths
=
[
max
(
len
(
str
(
h
)),
max
(
len
(
str
(
row
[
i
]))
for
row
in
rows
))
max
(
len
(
str
(
h
)),
max
(
len
(
str
(
row
[
i
]))
for
row
in
rows
))
for
i
,
h
in
enumerate
(
headers
)
for
i
,
h
in
enumerate
(
headers
)
]
]
# Create separator line
separator
=
"+-"
+
"-+-"
.
join
(
"-"
*
w
for
w
in
widths
)
+
"-+"
separator
=
"+-"
+
"-+-"
.
join
(
"-"
*
w
for
w
in
widths
)
+
"-+"
# Print table
print
(
separator
)
print
(
separator
)
print
(
format_table_row
(
headers
,
widths
))
print
(
format_table_row
(
headers
,
widths
))
print
(
separator
)
print
(
separator
)
...
@@ -223,22 +248,44 @@ def run_benchmarks(verbose: bool = False):
...
@@ -223,22 +248,44 @@ def run_benchmarks(verbose: bool = False):
"""Run benchmarks for a set of common shapes."""
"""Run benchmarks for a set of common shapes."""
print
(
"===== STARTING FP8 GEMM BENCHMARK ====="
)
print
(
"===== STARTING FP8 GEMM BENCHMARK ====="
)
# Make sure we're using the GPU
if
not
torch
.
cuda
.
is_available
():
if
not
torch
.
cuda
.
is_available
():
print
(
"CUDA not available! Tests require GPU."
)
print
(
"CUDA not available! Tests require GPU."
)
return
return
# Print system information
print
(
f
"PyTorch version:
{
torch
.
__version__
}
"
)
print
(
f
"PyTorch version:
{
torch
.
__version__
}
"
)
print
(
f
"CUDA version:
{
torch
.
version
.
cuda
}
"
)
print
(
f
"CUDA version:
{
torch
.
version
.
cuda
}
"
)
print
(
f
"Triton version:
{
triton
.
__version__
}
"
)
print
(
f
"Triton version:
{
triton
.
__version__
}
"
)
print
(
f
"Using device:
{
torch
.
cuda
.
get_device_name
()
}
"
)
print
(
f
"Using device:
{
torch
.
cuda
.
get_device_name
()
}
"
)
# Enable TF32 for better performance
torch
.
backends
.
cuda
.
matmul
.
allow_tf32
=
True
torch
.
backends
.
cuda
.
matmul
.
allow_tf32
=
True
torch
.
backends
.
cudnn
.
allow_tf32
=
True
torch
.
backends
.
cudnn
.
allow_tf32
=
True
# Set seeds for reproducibility
torch
.
manual_seed
(
42
)
torch
.
manual_seed
(
42
)
torch
.
cuda
.
manual_seed
(
42
)
torch
.
cuda
.
manual_seed
(
42
)
# Define benchmark shapes (m, n, k)
# Define benchmark shapes (m, n, k)
shapes
=
[
(
8
,
4096
,
7168
),
(
8
,
7168
,
18432
),
(
8
,
18432
,
7168
),
(
64
,
4096
,
7168
),
(
64
,
7168
,
18432
),
(
64
,
18432
,
7168
),
(
64
,
24576
,
1536
),
(
64
,
32768
,
512
),
(
64
,
7168
,
16384
),
(
128
,
4096
,
7168
),
(
128
,
7168
,
18432
),
(
128
,
18432
,
7168
),
(
1024
,
4096
,
7168
),
(
1024
,
18432
,
7168
),
(
2048
,
4096
,
7168
),
(
4096
,
4096
,
7168
),
]
shapes
=
[
shapes
=
[
# (64, 2112, 7168),
# (64, 2112, 7168),
(
64
,
24576
,
1536
),
(
64
,
24576
,
1536
),
...
@@ -265,6 +312,7 @@ def run_benchmarks(verbose: bool = False):
...
@@ -265,6 +312,7 @@ def run_benchmarks(verbose: bool = False):
result
=
benchmark_shape
(
m
,
n
,
k
,
verbose
=
verbose
)
result
=
benchmark_shape
(
m
,
n
,
k
,
verbose
=
verbose
)
all_results
.
append
(
result
)
all_results
.
append
(
result
)
# Print results in a nicely formatted table
print
(
"
\n
===== PERFORMANCE COMPARISON ====="
)
print
(
"
\n
===== PERFORMANCE COMPARISON ====="
)
# Print DeepGEMM table
# Print DeepGEMM table
...
@@ -273,50 +321,38 @@ def run_benchmarks(verbose: bool = False):
...
@@ -273,50 +321,38 @@ def run_benchmarks(verbose: bool = False):
for
result
in
all_results
:
for
result
in
all_results
:
shape
=
result
[
"shape"
]
shape
=
result
[
"shape"
]
impl_data
=
result
[
"implementations"
][
"DeepGEMM"
]
impl_data
=
result
[
"implementations"
][
"DeepGEMM"
]
deepgemm_rows
.
append
(
deepgemm_rows
.
append
([
[
shape
[
"m"
],
shape
[
"n"
],
shape
[
"k"
],
f
"
{
impl_data
[
'time_us'
]:.
1
f
}
"
,
shape
[
"m"
],
f
"
{
impl_data
[
'tflops'
]:.
1
f
}
"
,
f
"
{
impl_data
[
'gb_s'
]:.
1
f
}
"
shape
[
"n"
],
])
shape
[
"k"
],
f
"
{
impl_data
[
'time_us'
]:.
1
f
}
"
,
f
"
{
impl_data
[
'tflops'
]:.
1
f
}
"
,
f
"
{
impl_data
[
'gb_s'
]:.
1
f
}
"
,
]
)
print_table
(
deepgemm_headers
,
deepgemm_rows
,
title
=
"DeepGEMM Implementation:"
)
print_table
(
deepgemm_headers
,
deepgemm_rows
,
title
=
"DeepGEMM Implementation:"
)
# Print vLLM Triton table
# Print vLLM Triton table
triton_headers
=
[
"m"
,
"n"
,
"k"
,
"Time (μs)"
,
"TFLOPS"
,
"GB/s"
,
"vs DeepGEMM"
]
triton_headers
=
[
"m"
,
"n"
,
"k"
,
"Time (μs)"
,
"TFLOPS"
,
"GB/s"
,
"vs DeepGEMM"
]
triton_rows
=
[]
triton_rows
=
[]
for
result
in
all_results
:
for
result
in
all_results
:
shape
=
result
[
"shape"
]
shape
=
result
[
"shape"
]
impl_data
=
result
[
"implementations"
][
"vLLM Triton"
]
impl_data
=
result
[
"implementations"
][
"vLLM Triton"
]
speedup
=
impl_data
.
get
(
"speedup_vs_deepgemm"
,
1.0
)
speedup
=
impl_data
.
get
(
"speedup_vs_deepgemm"
,
1.0
)
triton_rows
.
append
(
triton_rows
.
append
([
[
shape
[
"m"
],
shape
[
"n"
],
shape
[
"k"
],
f
"
{
impl_data
[
'time_us'
]:.
1
f
}
"
,
shape
[
"m"
],
f
"
{
impl_data
[
'tflops'
]:.
1
f
}
"
,
f
"
{
impl_data
[
'gb_s'
]:.
1
f
}
"
,
shape
[
"n"
],
format_speedup
(
speedup
)
shape
[
"k"
],
])
f
"
{
impl_data
[
'time_us'
]:.
1
f
}
"
,
f
"
{
impl_data
[
'tflops'
]:.
1
f
}
"
,
f
"
{
impl_data
[
'gb_s'
]:.
1
f
}
"
,
format_speedup
(
speedup
),
]
)
print_table
(
triton_headers
,
triton_rows
,
title
=
"vLLM Triton Implementation:"
)
print_table
(
triton_headers
,
triton_rows
,
title
=
"vLLM Triton Implementation:"
)
# Print vLLM CUTLASS table
# Print vLLM CUTLASS table
cutlass_headers
=
[
cutlass_headers
=
[
"m"
,
"m"
,
"n"
,
"k"
,
"Time (μs)"
,
"TFLOPS"
,
"GB/s"
,
"vs DeepGEMM"
,
"n"
,
"vs Triton"
"k"
,
"Time (μs)"
,
"TFLOPS"
,
"GB/s"
,
"vs DeepGEMM"
,
"vs Triton"
,
]
]
cutlass_rows
=
[]
cutlass_rows
=
[]
for
result
in
all_results
:
for
result
in
all_results
:
...
@@ -324,27 +360,28 @@ def run_benchmarks(verbose: bool = False):
...
@@ -324,27 +360,28 @@ def run_benchmarks(verbose: bool = False):
impl_data
=
result
[
"implementations"
][
"vLLM CUTLASS"
]
impl_data
=
result
[
"implementations"
][
"vLLM CUTLASS"
]
vs_deepgemm
=
impl_data
.
get
(
"speedup_vs_deepgemm"
,
1.0
)
vs_deepgemm
=
impl_data
.
get
(
"speedup_vs_deepgemm"
,
1.0
)
vs_triton
=
impl_data
.
get
(
"speedup_vs_triton"
,
1.0
)
vs_triton
=
impl_data
.
get
(
"speedup_vs_triton"
,
1.0
)
cutlass_rows
.
append
(
cutlass_rows
.
append
([
[
shape
[
"m"
],
shape
[
"n"
],
shape
[
"k"
],
f
"
{
impl_data
[
'time_us'
]:.
1
f
}
"
,
shape
[
"m"
],
f
"
{
impl_data
[
'tflops'
]:.
1
f
}
"
,
f
"
{
impl_data
[
'gb_s'
]:.
1
f
}
"
,
shape
[
"n"
],
shape
[
"k"
],
f
"
{
impl_data
[
'time_us'
]:.
1
f
}
"
,
f
"
{
impl_data
[
'tflops'
]:.
1
f
}
"
,
f
"
{
impl_data
[
'gb_s'
]:.
1
f
}
"
,
format_speedup
(
vs_deepgemm
),
format_speedup
(
vs_deepgemm
),
format_speedup
(
vs_triton
),
format_speedup
(
vs_triton
)
]
])
)
print_table
(
cutlass_headers
,
cutlass_rows
,
title
=
"vLLM CUTLASS Implementation:"
)
print_table
(
cutlass_headers
,
cutlass_rows
,
title
=
"vLLM CUTLASS Implementation:"
)
# Calculate and print averages
# Calculate and print averages
print
(
"
\n
===== AVERAGE PERFORMANCE ====="
)
print
(
"
\n
===== AVERAGE PERFORMANCE ====="
)
implementations
=
[
"DeepGEMM"
,
"vLLM Triton"
,
"vLLM CUTLASS"
]
implementations
=
[
"DeepGEMM"
,
"vLLM Triton"
,
"vLLM CUTLASS"
]
avg_metrics
=
{
avg_metrics
=
{
impl
:
{
"tflops"
:
0
,
"gb_s"
:
0
,
"time_ms"
:
0
}
for
impl
in
implementations
impl
:
{
"tflops"
:
0
,
"gb_s"
:
0
,
"time_ms"
:
0
}
for
impl
in
implementations
}
}
for
result
in
all_results
:
for
result
in
all_results
:
...
@@ -362,9 +399,9 @@ def run_benchmarks(verbose: bool = False):
...
@@ -362,9 +399,9 @@ def run_benchmarks(verbose: bool = False):
avg_tflops
=
avg_metrics
[
impl
][
"tflops"
]
/
num_shapes
avg_tflops
=
avg_metrics
[
impl
][
"tflops"
]
/
num_shapes
avg_mem_bw
=
avg_metrics
[
impl
][
"gb_s"
]
/
num_shapes
avg_mem_bw
=
avg_metrics
[
impl
][
"gb_s"
]
/
num_shapes
avg_time
=
avg_metrics
[
impl
][
"time_ms"
]
/
num_shapes
avg_time
=
avg_metrics
[
impl
][
"time_ms"
]
/
num_shapes
avg_rows
.
append
(
avg_rows
.
append
(
[
[
impl
,
f
"
{
avg_tflops
:.
2
f
}
"
,
f
"
{
avg_mem_bw
:.
2
f
}
"
,
f
"
{
avg_time
:.
2
f
}
"
]
impl
,
f
"
{
avg_tflops
:.
2
f
}
"
,
f
"
{
avg_mem_bw
:.
2
f
}
"
,
f
"
{
avg_time
:.
2
f
}
"
)
]
)
print_table
(
avg_headers
,
avg_rows
)
print_table
(
avg_headers
,
avg_rows
)
...
@@ -372,19 +409,21 @@ def run_benchmarks(verbose: bool = False):
...
@@ -372,19 +409,21 @@ def run_benchmarks(verbose: bool = False):
avg_speedups
=
{
avg_speedups
=
{
"DeepGEMM vs vLLM Triton"
:
0
,
"DeepGEMM vs vLLM Triton"
:
0
,
"DeepGEMM vs vLLM CUTLASS"
:
0
,
"DeepGEMM vs vLLM CUTLASS"
:
0
,
"vLLM CUTLASS vs vLLM Triton"
:
0
,
"vLLM CUTLASS vs vLLM Triton"
:
0
}
}
for
result
in
all_results
:
for
result
in
all_results
:
deepgemm_time
=
result
[
"implementations"
][
"DeepGEMM"
][
"time_ms"
]
deepgemm_time
=
result
[
"implementations"
][
"DeepGEMM"
][
"time_ms"
]
vllm_triton_time
=
result
[
"implementations"
][
"vLLM Triton"
][
"time_ms"
]
vllm_triton_time
=
result
[
"implementations"
][
"vLLM Triton"
][
"time_ms"
]
vllm_cutlass_time
=
result
[
"implementations"
][
"vLLM CUTLASS"
][
"time_ms"
]
vllm_cutlass_time
=
result
[
"implementations"
][
"vLLM CUTLASS"
][
"time_ms"
]
avg_speedups
[
"DeepGEMM vs vLLM Triton"
]
+=
vllm_triton_time
/
deepgemm_time
avg_speedups
[
avg_speedups
[
"DeepGEMM vs vLLM CUTLASS"
]
+=
vllm_cutlass_time
/
deepgemm_time
"DeepGEMM vs vLLM Triton"
]
+=
vllm_triton_time
/
deepgemm_time
avg_speedups
[
"vLLM CUTLASS vs vLLM Triton"
]
+=
(
avg_speedups
[
vllm_triton_time
/
vllm_cutlass_time
"DeepGEMM vs vLLM CUTLASS"
]
+=
vllm_cutlass_time
/
deepgemm_time
)
avg_speedups
[
"vLLM CUTLASS vs vLLM Triton"
]
+=
vllm_triton_time
/
vllm_cutlass_time
print
(
"
\n
===== AVERAGE SPEEDUPS ====="
)
print
(
"
\n
===== AVERAGE SPEEDUPS ====="
)
speedup_headers
=
[
"Comparison"
,
"Speedup"
]
speedup_headers
=
[
"Comparison"
,
"Speedup"
]
...
@@ -396,12 +435,14 @@ def run_benchmarks(verbose: bool = False):
...
@@ -396,12 +435,14 @@ def run_benchmarks(verbose: bool = False):
print_table
(
speedup_headers
,
speedup_rows
)
print_table
(
speedup_headers
,
speedup_rows
)
# Average accuracy comparison
print
(
"
\n
===== ACCURACY COMPARISON ====="
)
print
(
"
\n
===== ACCURACY COMPARISON ====="
)
avg_diff
=
{
impl
:
0
for
impl
in
implementations
}
avg_diff
=
{
impl
:
0
for
impl
in
implementations
}
for
result
in
all_results
:
for
result
in
all_results
:
for
impl
in
implementations
:
for
impl
in
implementations
:
avg_diff
[
impl
]
+=
result
[
"implementations"
][
impl
][
"diff"
][
"Reference"
]
avg_diff
[
impl
]
+=
result
[
"implementations"
][
impl
][
"diff"
][
"Reference"
]
diff_headers
=
[
"Implementation"
,
"Avg Diff vs Reference"
]
diff_headers
=
[
"Implementation"
,
"Avg Diff vs Reference"
]
diff_rows
=
[]
diff_rows
=
[]
...
...
tests/kernels/moe/test_deepep_deepgemm_moe.py
View file @
a6c4b87f
...
@@ -66,6 +66,25 @@ def next_power_of_2(x):
...
@@ -66,6 +66,25 @@ def next_power_of_2(x):
return
2
**
math
.
ceil
(
math
.
log2
(
x
))
return
2
**
math
.
ceil
(
math
.
log2
(
x
))
def
per_block_cast_to_fp8
(
x
:
torch
.
Tensor
,
block_size_n
:
int
=
128
)
->
tuple
[
torch
.
Tensor
,
torch
.
Tensor
]:
assert
x
.
dim
()
==
2
m
,
n
=
x
.
shape
x_padded
=
torch
.
zeros
(
(
deep_gemm
.
ceil_div
(
m
,
128
)
*
128
,
deep_gemm
.
ceil_div
(
n
,
block_size_n
)
*
block_size_n
),
dtype
=
x
.
dtype
,
device
=
x
.
device
)
x_padded
[:
m
,
:
n
]
=
x
x_view
=
x_padded
.
view
(
-
1
,
128
,
x_padded
.
size
(
1
)
//
128
,
block_size_n
)
x_amax
=
x_view
.
abs
().
float
().
amax
(
dim
=
(
1
,
3
),
keepdim
=
True
).
clamp
(
1e-4
)
x_scaled
=
(
x_view
*
(
448.0
/
x_amax
)).
to
(
torch
.
float8_e4m3fn
)
x_scaled_sub
=
x_scaled
.
view_as
(
x_padded
)[:
m
,
:
n
].
contiguous
()
scales
=
(
x_amax
/
448.0
).
view
(
x_view
.
size
(
0
),
x_view
.
size
(
2
))
return
x_scaled_sub
,
scales
def
make_block_quant_fp8_weights
(
def
make_block_quant_fp8_weights
(
e
:
int
,
e
:
int
,
n
:
int
,
n
:
int
,
...
@@ -106,8 +125,8 @@ def make_block_quant_fp8_weights(
...
@@ -106,8 +125,8 @@ def make_block_quant_fp8_weights(
assert
(
w2
.
shape
[
-
2
]
+
block_n
-
1
)
//
block_n
==
w2_s
.
shape
[
-
2
]
assert
(
w2
.
shape
[
-
2
]
+
block_n
-
1
)
//
block_n
==
w2_s
.
shape
[
-
2
]
for
i
in
range
(
e
):
for
i
in
range
(
e
):
w1
[
i
],
w1_s
[
i
]
=
deep_gemm
.
utils
.
math
.
per_block_cast_to_fp8
(
w1_bf16
[
i
])
w1
[
i
],
w1_s
[
i
]
=
per_block_cast_to_fp8
(
w1_bf16
[
i
])
w2
[
i
],
w2_s
[
i
]
=
deep_gemm
.
utils
.
math
.
per_block_cast_to_fp8
(
w2_bf16
[
i
])
w2
[
i
],
w2_s
[
i
]
=
per_block_cast_to_fp8
(
w2_bf16
[
i
])
return
w1
,
w2
,
w1_s
,
w2_s
return
w1
,
w2
,
w1_s
,
w2_s
...
...
tests/kernels/quantization/test_block_fp8.py
View file @
a6c4b87f
...
@@ -18,8 +18,7 @@ from vllm.model_executor.layers.fused_moe.fused_moe import (
...
@@ -18,8 +18,7 @@ from vllm.model_executor.layers.fused_moe.fused_moe import (
from
vllm.model_executor.layers.fused_moe.moe_align_block_size
import
(
from
vllm.model_executor.layers.fused_moe.moe_align_block_size
import
(
moe_align_block_size
)
moe_align_block_size
)
from
vllm.model_executor.layers.quantization.utils.fp8_utils
import
(
from
vllm.model_executor.layers.quantization.utils.fp8_utils
import
(
get_col_major_tma_aligned_tensor
,
per_token_group_quant_fp8
,
per_token_group_quant_fp8
,
w8a8_block_fp8_matmul
)
w8a8_block_fp8_matmul
)
from
vllm.platforms
import
current_platform
from
vllm.platforms
import
current_platform
dg_available
=
False
dg_available
=
False
...
@@ -264,6 +263,25 @@ def test_w8a8_block_fp8_fused_moe(M, N, K, E, topk, block_size, dtype, seed):
...
@@ -264,6 +263,25 @@ def test_w8a8_block_fp8_fused_moe(M, N, K, E, topk, block_size, dtype, seed):
assert
rel_diff
<
0.03
assert
rel_diff
<
0.03
def
per_block_cast_to_fp8
(
x
:
torch
.
Tensor
,
block_size_n
:
int
=
128
)
->
tuple
[
torch
.
Tensor
,
torch
.
Tensor
]:
assert
x
.
dim
()
==
2
m
,
n
=
x
.
shape
x_padded
=
torch
.
zeros
(
(
deep_gemm
.
ceil_div
(
m
,
128
)
*
128
,
deep_gemm
.
ceil_div
(
n
,
block_size_n
)
*
block_size_n
),
dtype
=
x
.
dtype
,
device
=
x
.
device
)
x_padded
[:
m
,
:
n
]
=
x
x_view
=
x_padded
.
view
(
-
1
,
128
,
x_padded
.
size
(
1
)
//
128
,
block_size_n
)
x_amax
=
x_view
.
abs
().
float
().
amax
(
dim
=
(
1
,
3
),
keepdim
=
True
).
clamp
(
1e-4
)
x_scaled
=
(
x_view
*
(
448.0
/
x_amax
)).
to
(
torch
.
float8_e4m3fn
)
x_scaled_sub
=
x_scaled
.
view_as
(
x_padded
)[:
m
,
:
n
].
contiguous
()
scales
=
(
x_amax
/
448.0
).
view
(
x_view
.
size
(
0
),
x_view
.
size
(
2
))
return
x_scaled_sub
,
scales
@
pytest
.
mark
.
parametrize
(
@
pytest
.
mark
.
parametrize
(
"M,N,K,block_size,out_dtype,seed"
,
"M,N,K,block_size,out_dtype,seed"
,
itertools
.
product
(
M
,
N
,
K
,
BLOCK_SIZE
,
OUT_DTYPES
,
SEEDS
))
itertools
.
product
(
M
,
N
,
K
,
BLOCK_SIZE
,
OUT_DTYPES
,
SEEDS
))
...
@@ -281,8 +299,10 @@ def test_w8a8_block_fp8_deep_gemm_matmul(M, N, K, block_size, out_dtype, seed):
...
@@ -281,8 +299,10 @@ def test_w8a8_block_fp8_deep_gemm_matmul(M, N, K, block_size, out_dtype, seed):
A_fp32
=
(
torch
.
rand
(
M
,
K
,
dtype
=
torch
.
float32
)
-
0.5
)
*
2
*
fp8_max
A_fp32
=
(
torch
.
rand
(
M
,
K
,
dtype
=
torch
.
float32
)
-
0.5
)
*
2
*
fp8_max
B_fp32
=
(
torch
.
rand
(
N
,
K
,
dtype
=
torch
.
float32
)
-
0.5
)
*
2
*
fp8_max
B_fp32
=
(
torch
.
rand
(
N
,
K
,
dtype
=
torch
.
float32
)
-
0.5
)
*
2
*
fp8_max
A_fp8
,
As_fp8
=
deep_gemm
.
utils
.
math
.
per_token_cast_to_fp8
(
A_fp32
)
_
,
block_k
=
block_size
[
0
],
block_size
[
1
]
B_fp8
,
Bs_fp8
=
deep_gemm
.
utils
.
math
.
per_block_cast_to_fp8
(
B_fp32
)
A_fp8
,
As_fp8
=
per_token_group_quant_fp8
(
A_fp32
,
block_k
)
B_fp8
,
Bs_fp8
=
per_block_cast_to_fp8
(
B_fp32
)
As
=
As_fp8
.
to
(
torch
.
float32
)
As
=
As_fp8
.
to
(
torch
.
float32
)
Bs
=
Bs_fp8
.
to
(
torch
.
float32
)
Bs
=
Bs_fp8
.
to
(
torch
.
float32
)
...
@@ -290,12 +310,15 @@ def test_w8a8_block_fp8_deep_gemm_matmul(M, N, K, block_size, out_dtype, seed):
...
@@ -290,12 +310,15 @@ def test_w8a8_block_fp8_deep_gemm_matmul(M, N, K, block_size, out_dtype, seed):
ref_out
=
native_w8a8_block_matmul
(
A_fp8
,
B_fp8
,
As
,
Bs
,
block_size
,
ref_out
=
native_w8a8_block_matmul
(
A_fp8
,
B_fp8
,
As
,
Bs
,
block_size
,
out_dtype
)
out_dtype
)
# Transpose earlier so that the testing will not trigger transposing kernels
As_fp8
=
deep_gemm
.
get_col_major_tma_aligned_tensor
(
As_fp8
)
out
=
torch
.
zeros
((
M
,
N
),
device
=
'cuda'
,
dtype
=
out_dtype
)
out
=
torch
.
zeros
((
M
,
N
),
device
=
'cuda'
,
dtype
=
out_dtype
)
assert
As_fp8
.
shape
==
(
M
,
(
K
+
127
)
//
assert
As_fp8
.
shape
==
(
M
,
(
K
+
127
)
//
128
),
f
"
{
As_fp8
.
shape
}
!=
{
(
M
,
(
K
+
127
)
//
128
)
}
"
128
),
f
"
{
As_fp8
.
shape
}
!=
{
(
M
,
(
K
+
127
)
//
128
)
}
"
deep_gemm
.
fp8_
gemm_nt
((
A_fp8
,
As_fp8
),
(
B_fp8
,
Bs_fp8
),
out
)
deep_gemm
.
gemm
_fp8_fp8_bf16
_nt
((
A_fp8
,
As_fp8
),
(
B_fp8
,
Bs_fp8
),
out
)
rel_diff
=
(
torch
.
mean
(
rel_diff
=
(
torch
.
mean
(
torch
.
abs
(
out
.
to
(
torch
.
float32
)
-
ref_out
.
to
(
torch
.
float32
)))
/
torch
.
abs
(
out
.
to
(
torch
.
float32
)
-
ref_out
.
to
(
torch
.
float32
)))
/
...
@@ -359,7 +382,7 @@ def deep_gemm_w8a8_block_fp8_moe(M, K, a, w1, w2, w1_s, w2_s, score, topk,
...
@@ -359,7 +382,7 @@ def deep_gemm_w8a8_block_fp8_moe(M, K, a, w1, w2, w1_s, w2_s, score, topk,
dtype
=
torch
.
bfloat16
,
dtype
=
torch
.
bfloat16
,
device
=
a
.
device
)
device
=
a
.
device
)
deep_gemm
.
m_grouped_
fp8_
gemm_nt_contiguous
((
a_q
,
a_s
),
(
w1
,
w1_s
),
deep_gemm
.
m_grouped_gemm
_fp8_fp8_bf16
_nt_contiguous
((
a_q
,
a_s
),
(
w1
,
w1_s
),
inter_out
,
m_indices
)
inter_out
,
m_indices
)
act_out
=
SiluAndMul
().
forward_native
(
inter_out
)
act_out
=
SiluAndMul
().
forward_native
(
inter_out
)
...
@@ -367,8 +390,8 @@ def deep_gemm_w8a8_block_fp8_moe(M, K, a, w1, w2, w1_s, w2_s, score, topk,
...
@@ -367,8 +390,8 @@ def deep_gemm_w8a8_block_fp8_moe(M, K, a, w1, w2, w1_s, w2_s, score, topk,
out
=
torch
.
zeros
(
a_q
.
shape
[
0
],
K
,
dtype
=
torch
.
bfloat16
,
device
=
a
.
device
)
out
=
torch
.
zeros
(
a_q
.
shape
[
0
],
K
,
dtype
=
torch
.
bfloat16
,
device
=
a
.
device
)
deep_gemm
.
m_grouped_
fp8_
gemm_nt_contiguous
(
(
act_out_q
,
act_out_s
),
deep_gemm
.
m_grouped_gemm_
fp8_fp8_bf16_
nt_contiguous
(
(
w2
,
w2_s
),
out
,
m_indices
)
(
act_out_q
,
act_out_s
),
(
w2
,
w2_s
),
out
,
m_indices
)
final_out
=
_moe_unpermute
(
out
,
inv_perm
,
topk
,
K
,
topk_weight
)
final_out
=
_moe_unpermute
(
out
,
inv_perm
,
topk
,
K
,
topk_weight
)
...
@@ -418,15 +441,15 @@ def test_w8a8_block_fp8_deep_gemm_fused_moe(M, N, K, E, topk, seed):
...
@@ -418,15 +441,15 @@ def test_w8a8_block_fp8_deep_gemm_fused_moe(M, N, K, E, topk, seed):
w1_s
=
torch
.
empty
((
E
,
n_tiles_w1
,
k_tiles_w1
),
dtype
=
torch
.
float32
)
w1_s
=
torch
.
empty
((
E
,
n_tiles_w1
,
k_tiles_w1
),
dtype
=
torch
.
float32
)
w2_s
=
torch
.
empty
((
E
,
n_tiles_w2
,
k_tiles_w2
),
dtype
=
torch
.
float32
)
w2_s
=
torch
.
empty
((
E
,
n_tiles_w2
,
k_tiles_w2
),
dtype
=
torch
.
float32
)
w1_s
=
get_col_major_tma_aligned_tensor
(
w1_s
).
contiguous
()
w1_s
=
deep_gemm
.
get_col_major_tma_aligned_tensor
(
w1_s
).
contiguous
()
w2_s
=
get_col_major_tma_aligned_tensor
(
w2_s
).
contiguous
()
w2_s
=
deep_gemm
.
get_col_major_tma_aligned_tensor
(
w2_s
).
contiguous
()
assert
w1_s
.
shape
==
(
E
,
(
2
*
N
+
127
)
//
128
,
(
K
+
127
)
//
128
)
assert
w1_s
.
shape
==
(
E
,
(
2
*
N
+
127
)
//
128
,
(
K
+
127
)
//
128
)
assert
(
w2
.
shape
[
-
2
]
+
block_n
-
1
)
//
block_n
==
w2_s
.
shape
[
-
2
]
assert
(
w2
.
shape
[
-
2
]
+
block_n
-
1
)
//
block_n
==
w2_s
.
shape
[
-
2
]
for
i
in
range
(
E
):
for
i
in
range
(
E
):
w1
[
i
],
w1_s
[
i
]
=
deep_gemm
.
utils
.
math
.
per_block_cast_to_fp8
(
w1_bf16
[
i
])
w1
[
i
],
w1_s
[
i
]
=
per_block_cast_to_fp8
(
w1_bf16
[
i
])
w2
[
i
],
w2_s
[
i
]
=
deep_gemm
.
utils
.
math
.
per_block_cast_to_fp8
(
w2_bf16
[
i
])
w2
[
i
],
w2_s
[
i
]
=
per_block_cast_to_fp8
(
w2_bf16
[
i
])
# Set the context to avoid lots of warning spam.
# Set the context to avoid lots of warning spam.
with
set_current_vllm_config
(
vllm_config
):
with
set_current_vllm_config
(
vllm_config
):
...
@@ -437,10 +460,14 @@ def test_w8a8_block_fp8_deep_gemm_fused_moe(M, N, K, E, topk, seed):
...
@@ -437,10 +460,14 @@ def test_w8a8_block_fp8_deep_gemm_fused_moe(M, N, K, E, topk, seed):
ref_out
=
torch_w8a8_block_fp8_moe
(
a
,
w1
,
w2
,
w1_s
,
w2_s
,
score
,
ref_out
=
torch_w8a8_block_fp8_moe
(
a
,
w1
,
w2
,
w1_s
,
w2_s
,
score
,
topk
,
block_size
)
topk
,
block_size
)
topk_weights
,
topk_ids
,
_
=
fused_topk
(
a
,
score
.
float
(),
topk
,
False
)
topk_weights
,
topk_ids
,
token_expert_indices
=
fused_topk
(
a
,
score
.
float
(),
topk
,
False
)
out
=
deep_gemm_moe_fp8
(
a
,
w1
,
w2
,
w1_s
,
w2_s
,
topk_weights
,
topk_ids
)
out
=
deep_gemm_moe_fp8
(
a
,
w1
,
w2
,
w1_s
,
w2_s
,
topk_weights
,
topk_ids
)
#print(f"{out.sum()=}")
#print(f"{ref_out.sum()=}")
rel_diff
=
(
torch
.
mean
(
rel_diff
=
(
torch
.
mean
(
torch
.
abs
(
out
.
to
(
torch
.
float32
)
-
ref_out
.
to
(
torch
.
float32
)))
/
torch
.
abs
(
out
.
to
(
torch
.
float32
)
-
ref_out
.
to
(
torch
.
float32
)))
/
torch
.
mean
(
torch
.
abs
(
ref_out
.
to
(
torch
.
float32
))))
torch
.
mean
(
torch
.
abs
(
ref_out
.
to
(
torch
.
float32
))))
...
...
vllm/model_executor/layers/fused_moe/batched_deep_gemm_moe.py
View file @
a6c4b87f
...
@@ -266,7 +266,9 @@ class BatchedDeepGemmExperts(mk.FusedMoEPermuteExpertsUnpermute):
...
@@ -266,7 +266,9 @@ class BatchedDeepGemmExperts(mk.FusedMoEPermuteExpertsUnpermute):
# for the M expectation of each batch, correctly setting this value
# for the M expectation of each batch, correctly setting this value
# may lead to better performance.
# may lead to better performance.
expected_m
=
max_num_tokens
expected_m
=
max_num_tokens
dg
.
fp8_m_grouped_gemm_nt_masked
((
a1q
,
a1q_scale
),
(
w1
,
w1_scale
),
dg
.
m_grouped_gemm_fp8_fp8_bf16_nt_masked
((
a1q
,
a1q_scale
),
(
w1
,
w1_scale
),
out
=
workspace1
,
out
=
workspace1
,
masked_m
=
expert_num_tokens
,
masked_m
=
expert_num_tokens
,
expected_m
=
expected_m
)
expected_m
=
expected_m
)
...
@@ -275,7 +277,8 @@ class BatchedDeepGemmExperts(mk.FusedMoEPermuteExpertsUnpermute):
...
@@ -275,7 +277,8 @@ class BatchedDeepGemmExperts(mk.FusedMoEPermuteExpertsUnpermute):
a2q
,
a2q_scale
=
silu_mul_fp8_quant_deep_gemm
(
workspace1
,
a2q
,
a2q_scale
=
silu_mul_fp8_quant_deep_gemm
(
workspace1
,
expert_num_tokens
)
expert_num_tokens
)
dg
.
fp8_m_grouped_gemm_nt_masked
((
a2q
,
a2q_scale
),
(
w2
,
w2_scale
),
dg
.
m_grouped_gemm_fp8_fp8_bf16_nt_masked
((
a2q
,
a2q_scale
),
(
w2
,
w2_scale
),
out
=
output
,
out
=
output
,
masked_m
=
expert_num_tokens
,
masked_m
=
expert_num_tokens
,
expected_m
=
expected_m
)
expected_m
=
expected_m
)
vllm/model_executor/layers/fused_moe/deep_gemm_moe.py
View file @
a6c4b87f
...
@@ -143,9 +143,10 @@ class DeepGemmExperts(mk.FusedMoEPermuteExpertsUnpermute):
...
@@ -143,9 +143,10 @@ class DeepGemmExperts(mk.FusedMoEPermuteExpertsUnpermute):
quant_out
=
_resize_cache
(
workspace13
.
view
(
dtype
=
torch
.
float8_e4m3fn
),
quant_out
=
_resize_cache
(
workspace13
.
view
(
dtype
=
torch
.
float8_e4m3fn
),
(
M_sum
,
N
//
2
))
(
M_sum
,
N
//
2
))
mm2_out
=
_resize_cache
(
workspace2
,
(
M_sum
,
K
))
mm2_out
=
_resize_cache
(
workspace2
,
(
M_sum
,
K
))
# import pdb; pdb.set_trace()
dg
.
m_grouped_
fp8_gemm_nt_contiguous
((
a1q
,
a1q_scale
),
(
w1
,
w1_scale
),
dg
.
m_grouped_
gemm_fp8_fp8_bf16_nt_contiguous
(
mm1_out
,
expert_ids
)
(
a1q
,
a1q_scale
),
(
w1
,
w1_scale
),
mm1_out
,
expert_ids
)
self
.
activation
(
activation
,
act_out
,
mm1_out
.
view
(
-
1
,
N
))
self
.
activation
(
activation
,
act_out
,
mm1_out
.
view
(
-
1
,
N
))
...
@@ -154,8 +155,9 @@ class DeepGemmExperts(mk.FusedMoEPermuteExpertsUnpermute):
...
@@ -154,8 +155,9 @@ class DeepGemmExperts(mk.FusedMoEPermuteExpertsUnpermute):
self
.
block_shape
[
1
],
self
.
block_shape
[
1
],
column_major_scales
=
True
,
column_major_scales
=
True
,
out_q
=
quant_out
)
out_q
=
quant_out
)
dg
.
m_grouped_fp8_gemm_nt_contiguous
((
a2q
,
a2q_scale
),
(
w2
,
w2_scale
),
mm2_out
,
expert_ids
)
dg
.
m_grouped_gemm_fp8_fp8_bf16_nt_contiguous
(
(
a2q
,
a2q_scale
),
(
w2
,
w2_scale
),
mm2_out
,
expert_ids
)
torch
.
index_select
(
mm2_out
,
0
,
inv_perm
,
out
=
output
)
torch
.
index_select
(
mm2_out
,
0
,
inv_perm
,
out
=
output
)
...
...
vllm/model_executor/layers/quantization/deepgemm.py
View file @
a6c4b87f
...
@@ -58,7 +58,7 @@ def w8a8_block_fp8_matmul_deepgemm(
...
@@ -58,7 +58,7 @@ def w8a8_block_fp8_matmul_deepgemm(
output_dtype
)
output_dtype
)
# Deepgemm only supports output tensor type as bfloat16
# Deepgemm only supports output tensor type as bfloat16
assert
C
.
dtype
==
torch
.
bfloat16
assert
C
.
dtype
==
torch
.
bfloat16
deep_gemm
.
fp8_
gemm_nt
((
A
,
As
),
(
B
,
Bs
),
C
)
deep_gemm
.
gemm
_fp8_fp8_bf16
_nt
((
A
,
As
),
(
B
,
Bs
),
C
)
return
C
return
C
...
...
vllm/model_executor/layers/quantization/utils/fp8_utils.py
View file @
a6c4b87f
...
@@ -114,10 +114,6 @@ def should_use_deepgemm(output_dtype: torch.dtype, weight: torch.Tensor):
...
@@ -114,10 +114,6 @@ def should_use_deepgemm(output_dtype: torch.dtype, weight: torch.Tensor):
and
weight
.
shape
[
0
]
%
128
==
0
and
weight
.
shape
[
1
]
%
128
==
0
)
and
weight
.
shape
[
0
]
%
128
==
0
and
weight
.
shape
[
1
]
%
128
==
0
)
def
ceil_div
(
x
:
int
,
y
:
int
)
->
int
:
return
(
x
+
y
-
1
)
//
y
# TODO fix ROCm->Triton custom path:
# TODO fix ROCm->Triton custom path:
# https://github.com/vllm-project/vllm/issues/14397
# https://github.com/vllm-project/vllm/issues/14397
def
apply_w8a8_block_fp8_linear
(
def
apply_w8a8_block_fp8_linear
(
...
@@ -162,6 +158,9 @@ def apply_w8a8_block_fp8_linear(
...
@@ -162,6 +158,9 @@ def apply_w8a8_block_fp8_linear(
if
current_platform
.
is_cuda
():
if
current_platform
.
is_cuda
():
if
current_platform
.
has_device_capability
(
100
):
if
current_platform
.
has_device_capability
(
100
):
def
ceil_div
(
x
:
int
,
y
:
int
)
->
int
:
return
(
x
+
y
-
1
)
//
y
use_cutlass
=
cutlass_block_fp8_supported
and
(
use_cutlass
=
cutlass_block_fp8_supported
and
(
ceil_div
(
weight
.
shape
[
0
],
128
)
==
weight_scale
.
shape
[
0
]
ceil_div
(
weight
.
shape
[
0
],
128
)
==
weight_scale
.
shape
[
0
]
and
ceil_div
(
weight
.
shape
[
1
],
128
)
==
weight_scale
.
shape
[
1
])
and
ceil_div
(
weight
.
shape
[
1
],
128
)
==
weight_scale
.
shape
[
1
])
...
@@ -656,67 +655,3 @@ def w8a8_block_fp8_matmul(
...
@@ -656,67 +655,3 @@ def w8a8_block_fp8_matmul(
)
)
return
C
return
C
# Taken from https://github.com/deepseek-ai/DeepGEMM/blob/0c88cd01392c1073c7049a97d6328c7bba9b3947
# TODO(wentao): remove this function when DeepGEMM exposes this function
def
get_tma_aligned_size
(
x
:
int
,
element_size
:
int
)
->
int
:
"""
Global memory address of TMA must be 16-byte aligned.
Since we use column-major layout for the LHS scaling tensor,
the M-axis of the LHS scaling tensor needs to be padded to a multiple of
16 bytes.
Arguments:
x: original M-axis shape of the LHS scaling tensor.
element_size: element size of the LHS scaling tensor.
Returns:
M-axis shape of the LHS scaling tensor after padding.
"""
tma_alignment_bytes
=
16
assert
tma_alignment_bytes
%
element_size
==
0
alignment
=
tma_alignment_bytes
//
element_size
return
ceil_div
(
x
,
alignment
)
*
alignment
# Taken from https://github.com/deepseek-ai/DeepGEMM/blob/0c88cd01392c1073c7049a97d6328c7bba9b3947
# TODO(wentao): remove this function when DeepGEMM exposes this function
def
get_col_major_tma_aligned_tensor
(
x
:
torch
.
Tensor
)
->
torch
.
Tensor
:
"""
Returns TMA-aligned transposed format of the input tensor. `torch.transpose`
will be called if necessary.
If the input tensor is already column-major layout and 16-byte aligned along
the M axis (thus meets the requirement of LHS scaling tensor in
DeepGEMM), this function will do nothing.
Arguments:
x: usually the LHS scaling tensor in GEMM.
Returns:
The LHS scaling tensor of TMA-aligned transposed format.
"""
# NOTES: for the extreme performance, you may rewrite/fuse this function in
# CUDA
assert
x
.
dim
()
in
(
2
,
3
)
remove_dim
=
False
m
,
n
=
x
.
shape
[
-
2
],
x
.
shape
[
-
1
]
aligned_m
=
get_tma_aligned_size
(
m
,
x
.
element_size
())
if
x
.
dim
()
==
2
:
if
x
.
stride
(
0
)
==
1
and
x
.
stride
(
1
)
==
aligned_m
:
return
x
x
,
remove_dim
=
x
.
unsqueeze
(
0
),
True
b
=
x
.
shape
[
0
]
# The last kernel gives a column-major TMA aligned layout
if
x
.
stride
(
0
)
==
aligned_m
*
n
and
x
.
stride
(
1
)
==
1
and
x
.
stride
(
2
)
==
aligned_m
:
return
x
.
squeeze
(
0
)
if
remove_dim
else
x
# Normal layout requires transposing
aligned_x
=
torch
.
transpose
(
torch
.
empty
((
b
,
n
,
aligned_m
),
device
=
x
.
device
,
dtype
=
x
.
dtype
),
1
,
2
)
aligned_x
[:,
:
m
,
:]
=
x
aligned_x
=
aligned_x
[:,
:
m
,
:]
return
aligned_x
.
squeeze
(
0
)
if
remove_dim
else
aligned_x
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment