Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
e59ca942
Unverified
Commit
e59ca942
authored
Apr 01, 2025
by
bnellnm
Committed by
GitHub
Apr 01, 2025
Browse files
Add option to use DeepGemm contiguous grouped gemm kernel for fused MoE operations. (#13932)
Signed-off-by:
Bill Nell
<
bnell@redhat.com
>
parent
a57a3044
Changes
6
Hide whitespace changes
Inline
Side-by-side
Showing
6 changed files
with
774 additions
and
115 deletions
+774
-115
benchmarks/kernels/benchmark_moe.py
benchmarks/kernels/benchmark_moe.py
+61
-36
tests/kernels/test_block_fp8.py
tests/kernels/test_block_fp8.py
+256
-25
vllm/_custom_ops.py
vllm/_custom_ops.py
+1
-1
vllm/envs.py
vllm/envs.py
+5
-0
vllm/model_executor/layers/fused_moe/fused_moe.py
vllm/model_executor/layers/fused_moe/fused_moe.py
+415
-53
vllm/model_executor/layers/quantization/fp8.py
vllm/model_executor/layers/quantization/fp8.py
+36
-0
No files found.
benchmarks/kernels/benchmark_moe.py
View file @
e59ca942
...
...
@@ -30,19 +30,18 @@ class BenchmarkConfig(TypedDict):
num_stages
:
int
def
benchmark_config
(
config
:
BenchmarkConfig
,
num_tokens
:
int
,
num_experts
:
int
,
shard_intermediate_size
:
int
,
hidden_size
:
int
,
topk
:
int
,
dtype
:
torch
.
dtype
,
use_fp8_w8a8
:
bool
,
use_int8_w8a16
:
bool
,
num_iters
:
int
=
100
,
block_quant_shape
:
List
[
int
]
=
None
,
)
->
float
:
def
benchmark_config
(
config
:
BenchmarkConfig
,
num_tokens
:
int
,
num_experts
:
int
,
shard_intermediate_size
:
int
,
hidden_size
:
int
,
topk
:
int
,
dtype
:
torch
.
dtype
,
use_fp8_w8a8
:
bool
,
use_int8_w8a16
:
bool
,
num_iters
:
int
=
100
,
block_quant_shape
:
List
[
int
]
=
None
,
use_deep_gemm
:
bool
=
False
)
->
float
:
init_dtype
=
torch
.
float16
if
use_fp8_w8a8
else
dtype
x
=
torch
.
randn
(
num_tokens
,
hidden_size
,
dtype
=
dtype
)
if
use_int8_w8a16
:
...
...
@@ -115,22 +114,41 @@ def benchmark_config(
def
run
():
from
vllm.model_executor.layers.fused_moe
import
override_config
with
override_config
(
config
):
fused_moe
(
x
,
w1
,
w2
,
input_gating
,
topk
,
renormalize
=
True
,
inplace
=
True
,
use_fp8_w8a8
=
use_fp8_w8a8
,
use_int8_w8a16
=
use_int8_w8a16
,
w1_scale
=
w1_scale
,
w2_scale
=
w2_scale
,
a1_scale
=
a1_scale
,
a2_scale
=
a2_scale
,
block_shape
=
block_quant_shape
,
)
if
use_deep_gemm
:
topk_weights
,
topk_ids
=
fused_topk
(
x
,
input_gating
,
topk
,
False
)
return
fused_experts
(
x
,
w1
,
w2
,
topk_weights
,
topk_ids
,
inplace
=
True
,
use_fp8_w8a8
=
use_fp8_w8a8
,
w1_scale
=
w1_scale
,
w2_scale
=
w2_scale
,
a1_scale
=
a1_scale
,
a2_scale
=
a2_scale
,
block_shape
=
block_quant_shape
,
allow_deep_gemm
=
True
,
)
else
:
fused_moe
(
x
,
w1
,
w2
,
input_gating
,
topk
,
renormalize
=
True
,
inplace
=
True
,
use_fp8_w8a8
=
use_fp8_w8a8
,
use_int8_w8a16
=
use_int8_w8a16
,
w1_scale
=
w1_scale
,
w2_scale
=
w2_scale
,
a1_scale
=
a1_scale
,
a2_scale
=
a2_scale
,
block_shape
=
block_quant_shape
,
)
# JIT compilation & warmup
run
()
...
...
@@ -366,6 +384,7 @@ class BenchmarkWorker:
use_fp8_w8a8
:
bool
,
use_int8_w8a16
:
bool
,
block_quant_shape
:
List
[
int
]
=
None
,
use_deep_gemm
:
bool
=
False
,
)
->
tuple
[
dict
[
str
,
int
],
float
]:
current_platform
.
seed_everything
(
self
.
seed
)
dtype_str
=
get_config_dtype_str
(
dtype
,
...
...
@@ -396,7 +415,8 @@ class BenchmarkWorker:
use_fp8_w8a8
,
use_int8_w8a16
,
num_iters
=
100
,
block_quant_shape
=
block_quant_shape
)
block_quant_shape
=
block_quant_shape
,
use_deep_gemm
=
use_deep_gemm
)
return
config
,
kernel_time
def
tune
(
...
...
@@ -411,6 +431,7 @@ class BenchmarkWorker:
use_int8_w8a16
:
bool
,
search_space
:
list
[
dict
[
str
,
int
]],
block_quant_shape
:
list
[
int
],
use_deep_gemm
:
bool
,
)
->
dict
[
str
,
int
]:
best_config
=
None
best_time
=
float
(
"inf"
)
...
...
@@ -436,7 +457,8 @@ class BenchmarkWorker:
use_fp8_w8a8
,
use_int8_w8a16
,
num_iters
=
20
,
block_quant_shape
=
block_quant_shape
)
block_quant_shape
=
block_quant_shape
,
use_deep_gemm
=
use_deep_gemm
)
except
triton
.
runtime
.
autotuner
.
OutOfResources
:
# Some configurations may be invalid and fail to compile.
continue
...
...
@@ -550,6 +572,8 @@ def main(args: argparse.Namespace):
else
:
batch_sizes
=
[
args
.
batch_size
]
use_deep_gemm
=
bool
(
args
.
use_deep_gemm
)
ray
.
init
()
num_gpus
=
int
(
ray
.
available_resources
()[
"GPU"
])
workers
=
[
BenchmarkWorker
.
remote
(
args
.
seed
)
for
_
in
range
(
num_gpus
)]
...
...
@@ -572,10 +596,10 @@ def main(args: argparse.Namespace):
start
=
time
.
time
()
configs
=
_distribute
(
"tune"
,
[(
batch_size
,
E
,
shard_intermediate_size
,
hidden_size
,
topk
,
dtyp
e
,
use_fp8_w8a8
,
use_int8_w8a16
,
search_space
,
block_quant_shape
)
for
batch_size
in
batch_sizes
])
"tune"
,
[(
batch_size
,
E
,
shard_intermediate_size
,
hidden_size
,
topk
,
dtype
,
use_fp8_w8a8
,
use_int8_w8a16
,
search_spac
e
,
block_quant_shape
,
use_deep_gemm
)
for
batch_size
in
batch_sizes
])
best_configs
=
{
M
:
sort_config
(
config
)
for
M
,
config
in
zip
(
batch_sizes
,
configs
)
...
...
@@ -589,7 +613,7 @@ def main(args: argparse.Namespace):
outputs
=
_distribute
(
"benchmark"
,
[(
batch_size
,
E
,
shard_intermediate_size
,
hidden_size
,
topk
,
dtype
,
use_fp8_w8a8
,
use_int8_w8a16
,
block_quant_shape
)
use_fp8_w8a8
,
use_int8_w8a16
,
block_quant_shape
,
use_deep_gemm
)
for
batch_size
in
batch_sizes
])
for
batch_size
,
(
config
,
kernel_time
)
in
zip
(
batch_sizes
,
outputs
):
...
...
@@ -611,6 +635,7 @@ if __name__ == "__main__":
type
=
str
,
choices
=
[
"auto"
,
"fp8_w8a8"
,
"int8_w8a16"
],
default
=
"auto"
)
parser
.
add_argument
(
"--use-deep-gemm"
,
action
=
"store_true"
)
parser
.
add_argument
(
"--seed"
,
type
=
int
,
default
=
0
)
parser
.
add_argument
(
"--batch-size"
,
type
=
int
,
required
=
False
)
parser
.
add_argument
(
"--tune"
,
action
=
"store_true"
)
...
...
tests/kernels/test_block_fp8.py
View file @
e59ca942
...
...
@@ -6,12 +6,22 @@ import itertools
import
pytest
import
torch
from
vllm.config
import
VllmConfig
,
set_current_vllm_config
from
vllm.model_executor.layers.activation
import
SiluAndMul
from
vllm.model_executor.layers.fused_moe
import
fused_moe
from
vllm.model_executor.layers.fused_moe.fused_moe
import
(
deep_gemm_moe_fp8
,
fused_topk
,
moe_align_block_size
)
from
vllm.model_executor.layers.quantization.utils.fp8_utils
import
(
per_token_group_quant_fp8
,
w8a8_block_fp8_matmul
)
from
vllm.platforms
import
current_platform
dg_available
=
False
try
:
import
deep_gemm
dg_available
=
True
except
ImportError
:
pass
if
current_platform
.
get_device_capability
()
<
(
9
,
0
):
pytest
.
skip
(
"FP8 Triton requires CUDA 9.0 or higher"
,
allow_module_level
=
True
)
...
...
@@ -21,17 +31,18 @@ DTYPES = [torch.bfloat16] # [torch.half, torch.bfloat16, torch.float32]
NUM_TOKENS
=
[
7
,
83
,
2048
]
D
=
[
512
,
4096
,
5120
,
13824
]
GROUP_SIZE
=
[
64
,
128
,
256
,
512
]
M
=
[
1
,
7
,
8
3
,
512
,
2048
]
N
=
[
128
,
512
,
1024
,
4096
,
7748
,
13824
]
K
=
[
256
,
4096
,
5120
,
3884
,
13824
]
M
=
[
1
,
7
,
8
,
83
,
84
,
512
,
2048
,
4096
]
N
=
[
128
,
512
,
1024
,
4096
,
7168
,
7748
,
13824
]
K
=
[
256
,
4096
,
5120
,
3884
,
13824
,
16384
]
# Deepseek-V3's intermediate size 18432, so N is 18432*2/8=4608 at TP8
# and its hidden size is 7168.
M_moe
=
[
1
,
7
,
83
,
512
,
2048
]
N_moe
=
[
4608
]
# [128, 4608, 13824]
K_moe
=
[
7168
]
# [256, 7168, 13824]
M_moe
=
[
1
,
2
,
7
,
83
,
128
,
512
,
2048
]
M_moe_dg
=
[
128
,
192
,
512
,
1335
,
2048
]
N_moe
=
[
128
,
256
,
1024
,
4608
]
# [13824]
K_moe
=
[
256
,
512
,
7168
]
# [13824]
BLOCK_SIZE
=
[[
128
,
128
]]
E
=
[
8
,
24
]
# [
8, 24,
128, 256]
TOP_KS
=
[
2
]
#
[1, 2, 6]
E
=
[
2
,
8
,
16
,
24
]
# [128, 256]
TOP_KS
=
[
1
,
2
,
6
]
OUT_DTYPES
=
[
torch
.
bfloat16
]
# [torch.float32, torch.half, torch.bfloat16]
SEEDS
=
[
0
]
...
...
@@ -217,11 +228,16 @@ def test_w8a8_block_fp8_matmul(M, N, K, block_size, out_dtype, seed):
SEEDS
))
@
torch
.
inference_mode
()
def
test_w8a8_block_fp8_fused_moe
(
M
,
N
,
K
,
E
,
topk
,
block_size
,
dtype
,
seed
):
if
topk
>
E
:
pytest
.
skip
(
f
"Skipping test; topk=
{
topk
}
> E=
{
E
}
"
)
torch
.
manual_seed
(
seed
)
factor_for_scale
=
1e-2
fp8_info
=
torch
.
finfo
(
torch
.
float8_e4m3fn
)
fp8_max
,
fp8_min
=
fp8_info
.
max
,
fp8_info
.
min
vllm_config
=
VllmConfig
()
a
=
torch
.
randn
((
M
,
K
),
dtype
=
dtype
)
/
10
w1_bf16
=
(
torch
.
rand
(
...
...
@@ -246,25 +262,240 @@ def test_w8a8_block_fp8_fused_moe(M, N, K, E, topk, block_size, dtype, seed):
score
=
torch
.
randn
((
M
,
E
),
dtype
=
dtype
)
out
=
fused_moe
(
a
,
w1
,
w2
,
score
,
topk
,
renormalize
=
False
,
use_fp8_w8a8
=
True
,
w1_scale
=
w1_s
,
w2_scale
=
w2_s
,
block_shape
=
block_size
,
)
ref_out
=
torch_w8a8_block_fp8_moe
(
a
,
w1
,
w2
,
w1_s
,
w2_s
,
score
,
topk
,
block_size
)
print
(
f
"
{
out
.
sum
()
=
}
"
)
print
(
f
"
{
ref_out
.
sum
()
=
}
"
)
# Set the context to avoid lots of warning spam.
with
set_current_vllm_config
(
vllm_config
):
out
=
fused_moe
(
a
,
w1
,
w2
,
score
,
topk
,
renormalize
=
False
,
use_fp8_w8a8
=
True
,
w1_scale
=
w1_s
,
w2_scale
=
w2_s
,
block_shape
=
block_size
,
)
ref_out
=
torch_w8a8_block_fp8_moe
(
a
,
w1
,
w2
,
w1_s
,
w2_s
,
score
,
topk
,
block_size
)
#print(f"{out.sum()=}")
#print(f"{ref_out.sum()=}")
rel_diff
=
(
torch
.
mean
(
torch
.
abs
(
out
.
to
(
torch
.
float32
)
-
ref_out
.
to
(
torch
.
float32
)))
/
torch
.
mean
(
torch
.
abs
(
ref_out
.
to
(
torch
.
float32
))))
assert
rel_diff
<
0.03
def
per_block_cast_to_fp8
(
x
:
torch
.
Tensor
,
block_size_n
:
int
=
128
)
->
tuple
[
torch
.
Tensor
,
torch
.
Tensor
]:
assert
x
.
dim
()
==
2
m
,
n
=
x
.
shape
x_padded
=
torch
.
zeros
(
(
deep_gemm
.
ceil_div
(
m
,
128
)
*
128
,
deep_gemm
.
ceil_div
(
n
,
block_size_n
)
*
block_size_n
),
dtype
=
x
.
dtype
,
device
=
x
.
device
)
x_padded
[:
m
,
:
n
]
=
x
x_view
=
x_padded
.
view
(
-
1
,
128
,
x_padded
.
size
(
1
)
//
128
,
block_size_n
)
x_amax
=
x_view
.
abs
().
float
().
amax
(
dim
=
(
1
,
3
),
keepdim
=
True
).
clamp
(
1e-4
)
x_scaled
=
(
x_view
*
(
448.0
/
x_amax
)).
to
(
torch
.
float8_e4m3fn
)
x_scaled_sub
=
x_scaled
.
view_as
(
x_padded
)[:
m
,
:
n
].
contiguous
()
scales
=
(
x_amax
/
448.0
).
view
(
x_view
.
size
(
0
),
x_view
.
size
(
2
))
return
x_scaled_sub
,
scales
@
pytest
.
mark
.
parametrize
(
"M,N,K,block_size,out_dtype,seed"
,
itertools
.
product
(
M
,
N
,
K
,
BLOCK_SIZE
,
OUT_DTYPES
,
SEEDS
))
@
torch
.
inference_mode
()
def
test_w8a8_block_fp8_deep_gemm_matmul
(
M
,
N
,
K
,
block_size
,
out_dtype
,
seed
):
# only aligned sizes
if
M
%
4
!=
0
or
K
%
128
!=
0
or
N
%
64
!=
0
:
pytest
.
skip
(
f
"Skipping test; invalid size
{
M
}
,
{
N
}
,
{
K
}
"
)
torch
.
manual_seed
(
seed
)
fp8_info
=
torch
.
finfo
(
torch
.
float8_e4m3fn
)
fp8_max
=
fp8_info
.
max
A_fp32
=
(
torch
.
rand
(
M
,
K
,
dtype
=
torch
.
float32
)
-
0.5
)
*
2
*
fp8_max
B_fp32
=
(
torch
.
rand
(
N
,
K
,
dtype
=
torch
.
float32
)
-
0.5
)
*
2
*
fp8_max
_
,
block_k
=
block_size
[
0
],
block_size
[
1
]
A_fp8
,
As_fp8
=
per_token_group_quant_fp8
(
A_fp32
,
block_k
)
B_fp8
,
Bs_fp8
=
per_block_cast_to_fp8
(
B_fp32
)
As
=
As_fp8
.
to
(
torch
.
float32
)
Bs
=
Bs_fp8
.
to
(
torch
.
float32
)
ref_out
=
native_w8a8_block_fp8_matmul
(
A_fp8
,
B_fp8
,
As
,
Bs
,
block_size
,
out_dtype
)
# Transpose earlier so that the testing will not trigger transposing kernels
As_fp8
=
deep_gemm
.
get_col_major_tma_aligned_tensor
(
As_fp8
)
out
=
torch
.
zeros
((
M
,
N
),
device
=
'cuda'
,
dtype
=
out_dtype
)
assert
As_fp8
.
shape
==
(
M
,
(
K
+
127
)
//
128
),
f
"
{
As_fp8
.
shape
}
!=
{
(
M
,
(
K
+
127
)
//
128
)
}
"
deep_gemm
.
gemm_fp8_fp8_bf16_nt
((
A_fp8
,
As_fp8
),
(
B_fp8
,
Bs_fp8
),
out
)
rel_diff
=
(
torch
.
mean
(
torch
.
abs
(
out
.
to
(
torch
.
float32
)
-
ref_out
.
to
(
torch
.
float32
)))
/
torch
.
mean
(
torch
.
abs
(
ref_out
.
to
(
torch
.
float32
))))
assert
rel_diff
<
0.001
def
fp8_perm
(
m
,
idx
):
if
torch
.
is_floating_point
(
m
)
and
torch
.
finfo
(
m
.
dtype
).
bits
==
8
:
return
m
.
view
(
dtype
=
torch
.
uint8
)[
idx
,
...].
view
(
dtype
=
m
.
dtype
)
else
:
return
m
[
idx
,
...]
def
test_moe_permute
(
a
,
a_s
,
topk_ids
,
num_groups
,
topk
,
block_m
):
M
,
K
=
a
.
shape
sorted_token_ids
,
m_indices
,
num_pad
=
moe_align_block_size
(
topk_ids
,
block_m
,
num_groups
,
None
,
pad_sorted_ids
=
True
)
num_tokens
=
topk
*
M
sorted_token_ids
=
sorted_token_ids
.
clamp
(
max
=
num_tokens
-
1
)
m_indices
=
torch
.
repeat_interleave
(
m_indices
,
block_m
,
dim
=
0
)
inv_perm
=
torch
.
argsort
(
sorted_token_ids
)[:
M
*
topk
]
a
=
fp8_perm
(
a
,
sorted_token_ids
//
topk
)
if
a_s
is
not
None
:
a_s
=
a_s
[
sorted_token_ids
//
topk
]
return
a
,
a_s
,
m_indices
,
inv_perm
def
test_moe_unpermute
(
out
,
inv_perm
,
topk
,
K
,
topk_weight
):
M
=
topk_weight
.
shape
[
0
]
out
=
out
[
inv_perm
,
...]
tmp_out
=
out
.
view
(
-
1
,
topk
,
K
)
return
(
tmp_out
*
topk_weight
.
view
(
M
,
-
1
,
1
).
to
(
out
.
dtype
)).
sum
(
dim
=
1
)
def
deep_gemm_w8a8_block_fp8_moe
(
M
,
K
,
a
,
w1
,
w2
,
w1_s
,
w2_s
,
score
,
topk
,
block_shape
):
"""Fused moe with block-wise quantization using DeepGemm grouped gemm."""
num_groups
=
w1
.
shape
[
0
]
M
,
K
=
a
.
shape
N
=
w2
.
shape
[
-
1
]
topk_weight
,
topk_ids
=
fused_topk
(
a
,
score
.
float
(),
topk
,
False
)
block_m
=
deep_gemm
.
get_m_alignment_for_contiguous_layout
()
_
,
block_k
=
block_shape
[
0
],
block_shape
[
1
]
a_q
,
a_s
=
per_token_group_quant_fp8
(
a
,
block_m
)
a_q
,
a_s
,
m_indices
,
inv_perm
=
test_moe_permute
(
a_q
,
a_s
,
topk_ids
,
num_groups
,
topk
,
block_m
)
inter_out
=
torch
.
zeros
((
a_q
.
shape
[
0
],
N
*
2
),
dtype
=
torch
.
bfloat16
,
device
=
a
.
device
)
deep_gemm
.
m_grouped_gemm_fp8_fp8_bf16_nt_contiguous
((
a_q
,
a_s
),
(
w1
,
w1_s
),
inter_out
,
m_indices
)
act_out
=
SiluAndMul
().
forward_native
(
inter_out
)
act_out_q
,
act_out_s
=
per_token_group_quant_fp8
(
act_out
,
block_k
)
out
=
torch
.
zeros
(
a_q
.
shape
[
0
],
K
,
dtype
=
torch
.
bfloat16
,
device
=
a
.
device
)
deep_gemm
.
m_grouped_gemm_fp8_fp8_bf16_nt_contiguous
(
(
act_out_q
,
act_out_s
),
(
w2
,
w2_s
),
out
,
m_indices
)
final_out
=
test_moe_unpermute
(
out
,
inv_perm
,
topk
,
K
,
topk_weight
)
return
final_out
@
pytest
.
mark
.
parametrize
(
"M,N,K,E,topk,seed"
,
itertools
.
product
(
M_moe_dg
,
N_moe
,
K_moe
,
E
,
TOP_KS
,
SEEDS
))
@
pytest
.
mark
.
skipif
(
not
dg_available
,
reason
=
"DeepGemm kernels not available."
)
@
torch
.
inference_mode
()
def
test_w8a8_block_fp8_deep_gemm_fused_moe
(
M
,
N
,
K
,
E
,
topk
,
seed
):
block_m
=
deep_gemm
.
get_m_alignment_for_contiguous_layout
()
block_size
=
[
block_m
,
block_m
]
dtype
=
torch
.
bfloat16
# only aligned sizes
if
(
N
%
block_m
!=
0
or
K
%
block_m
!=
0
or
topk
>
E
):
pytest
.
skip
(
f
"Skipping test; bad size m=
{
M
}
, n=
{
N
}
, k=
{
K
}
, topk=
{
topk
}
, E=
{
E
}
"
)
if
(
N
<=
512
):
pytest
.
skip
(
"Skipping N <= 512 until performance issues solved."
)
vllm_config
=
VllmConfig
()
torch
.
manual_seed
(
seed
)
fp8_info
=
torch
.
finfo
(
torch
.
float8_e4m3fn
)
fp8_max
,
fp8_min
=
fp8_info
.
max
,
fp8_info
.
min
a
=
torch
.
randn
((
M
,
K
),
dtype
=
dtype
)
/
10
w1_bf16
=
((
torch
.
rand
((
E
,
2
*
N
,
K
),
dtype
=
torch
.
bfloat16
)
-
0.5
)
*
2
*
fp8_max
).
clamp
(
min
=
fp8_min
,
max
=
fp8_max
)
w2_bf16
=
((
torch
.
rand
((
E
,
K
,
N
),
dtype
=
torch
.
bfloat16
)
-
0.5
)
*
2
*
fp8_max
).
clamp
(
min
=
fp8_min
,
max
=
fp8_max
)
score
=
torch
.
randn
((
M
,
E
),
dtype
=
dtype
)
block_n
,
block_k
=
block_size
[
0
],
block_size
[
1
]
n_tiles_w1
=
((
2
*
N
)
+
block_n
-
1
)
//
block_n
k_tiles_w1
=
(
K
+
block_k
-
1
)
//
block_k
n_tiles_w2
=
(
K
+
block_n
-
1
)
//
block_n
k_tiles_w2
=
(
N
+
block_k
-
1
)
//
block_k
w1
=
torch
.
empty_like
(
w1_bf16
,
dtype
=
torch
.
float8_e4m3fn
)
w2
=
torch
.
empty_like
(
w2_bf16
,
dtype
=
torch
.
float8_e4m3fn
)
w1_s
=
torch
.
empty
((
E
,
n_tiles_w1
,
k_tiles_w1
),
dtype
=
torch
.
float32
)
w2_s
=
torch
.
empty
((
E
,
n_tiles_w2
,
k_tiles_w2
),
dtype
=
torch
.
float32
)
w1_s
=
deep_gemm
.
get_col_major_tma_aligned_tensor
(
w1_s
).
contiguous
()
w2_s
=
deep_gemm
.
get_col_major_tma_aligned_tensor
(
w2_s
).
contiguous
()
assert
w1_s
.
shape
==
(
E
,
(
2
*
N
+
127
)
//
128
,
(
K
+
127
)
//
128
)
assert
(
w2
.
shape
[
-
2
]
+
block_n
-
1
)
//
block_n
==
w2_s
.
shape
[
-
2
]
for
i
in
range
(
E
):
w1
[
i
],
w1_s
[
i
]
=
per_block_cast_to_fp8
(
w1_bf16
[
i
])
w2
[
i
],
w2_s
[
i
]
=
per_block_cast_to_fp8
(
w2_bf16
[
i
])
# Set the context to avoid lots of warning spam.
with
set_current_vllm_config
(
vllm_config
):
if
M
>=
128
:
ref_out
=
deep_gemm_w8a8_block_fp8_moe
(
M
,
K
,
a
,
w1
,
w2
,
w1_s
,
w2_s
,
score
,
topk
,
block_size
)
else
:
ref_out
=
torch_w8a8_block_fp8_moe
(
a
,
w1
,
w2
,
w1_s
,
w2_s
,
score
,
topk
,
block_size
)
topk_weights
,
topk_ids
=
fused_topk
(
a
,
score
.
float
(),
topk
,
False
)
out
=
deep_gemm_moe_fp8
(
a
,
w1
,
w2
,
w1_s
,
w2_s
,
topk_weights
,
topk_ids
)
#print(f"{out.sum()=}")
#print(f"{ref_out.sum()=}")
rel_diff
=
(
torch
.
mean
(
torch
.
abs
(
out
.
to
(
torch
.
float32
)
-
ref_out
.
to
(
torch
.
float32
)))
/
torch
.
mean
(
torch
.
abs
(
ref_out
.
to
(
torch
.
float32
))))
assert
rel_diff
<
0.03
vllm/_custom_ops.py
View file @
e59ca942
...
...
@@ -1224,7 +1224,7 @@ def moe_wna16_gemm(input: torch.Tensor, output: torch.Tensor,
def
topk_softmax
(
topk_weights
:
torch
.
Tensor
,
topk_ids
:
torch
.
Tensor
,
token_expert_indicies
:
torch
.
Tensor
,
gating_output
:
float
)
->
None
:
gating_output
:
torch
.
Tensor
)
->
None
:
torch
.
ops
.
_moe_C
.
topk_softmax
(
topk_weights
,
topk_ids
,
token_expert_indicies
,
gating_output
)
...
...
vllm/envs.py
View file @
e59ca942
...
...
@@ -105,6 +105,7 @@ if TYPE_CHECKING:
VLLM_V0_USE_OUTLINES_CACHE
:
bool
=
False
VLLM_TPU_DISABLE_TOPK_TOPP_OPTIMIZATION
:
bool
=
False
VLLM_TPU_BUCKET_PADDING_GAP
:
int
=
0
VLLM_USE_DEEP_GEMM
:
bool
=
False
def
get_default_cache_root
():
...
...
@@ -687,6 +688,10 @@ environment_variables: dict[str, Callable[[], Any]] = {
"VLLM_TPU_BUCKET_PADDING_GAP"
:
lambda
:
int
(
os
.
environ
[
"VLLM_TPU_BUCKET_PADDING_GAP"
])
if
"VLLM_TPU_BUCKET_PADDING_GAP"
in
os
.
environ
else
0
,
# Allow use of DeepGemm kernels for fused moe ops.
"VLLM_USE_DEEP_GEMM"
:
lambda
:
bool
(
int
(
os
.
getenv
(
"VLLM_USE_DEEP_GEMM"
,
"0"
))),
}
# end-env-vars-definition
...
...
vllm/model_executor/layers/fused_moe/fused_moe.py
View file @
e59ca942
# SPDX-License-Identifier: Apache-2.0
"""Fused MoE kernel."""
import
functools
import
importlib.util
import
json
import
os
from
math
import
prod
from
typing
import
Any
,
Callable
,
Dict
,
List
,
Optional
,
Tuple
import
torch
...
...
@@ -15,7 +17,7 @@ from vllm.logger import init_logger
from
vllm.model_executor.layers.quantization.utils.fp8_utils
import
(
per_token_group_quant_fp8
)
from
vllm.platforms
import
current_platform
from
vllm.utils
import
direct_register_custom_op
from
vllm.utils
import
direct_register_custom_op
,
round_up
from
.rocm_aiter_fused_moe
import
(
is_rocm_aiter_moe_enabled
,
rocm_aiter_fused_experts
,
...
...
@@ -23,6 +25,8 @@ from .rocm_aiter_fused_moe import (is_rocm_aiter_moe_enabled,
logger
=
init_logger
(
__name__
)
has_deep_gemm
=
importlib
.
util
.
find_spec
(
"deep_gemm"
)
is
not
None
@
triton
.
jit
def
write_zeros_to_output
(
c_ptr
,
stride_cm
,
stride_cn
,
pid_n
,
N
,
offs_token
,
...
...
@@ -581,7 +585,8 @@ def moe_align_block_size(
topk_ids
:
torch
.
Tensor
,
block_size
:
int
,
num_experts
:
int
,
expert_map
:
torch
.
Tensor
=
None
expert_map
:
Optional
[
torch
.
Tensor
]
=
None
,
pad_sorted_ids
:
bool
=
False
)
->
Tuple
[
torch
.
Tensor
,
torch
.
Tensor
,
torch
.
Tensor
]:
"""
Aligns the token distribution across experts to be compatible with block
...
...
@@ -596,6 +601,8 @@ def moe_align_block_size(
from the global space to the local index space of the current
expert parallel shard. If the expert is not in the current expert
parallel shard, the mapping is set to -1.
- pad_sorted_ids: A flag indicating whether the sorted_token_ids length
should be padded to a multiple of block_size,
Returns:
- sorted_token_ids: A tensor containing the sorted token indices according
...
...
@@ -625,6 +632,8 @@ def moe_align_block_size(
by block_size for proper block matrix operations.
"""
max_num_tokens_padded
=
topk_ids
.
numel
()
+
num_experts
*
(
block_size
-
1
)
if
pad_sorted_ids
:
max_num_tokens_padded
=
round_up
(
max_num_tokens_padded
,
block_size
)
sorted_ids
=
torch
.
empty
((
max_num_tokens_padded
,
),
dtype
=
torch
.
int32
,
device
=
topk_ids
.
device
)
...
...
@@ -667,6 +676,59 @@ def moe_align_block_size(
return
sorted_ids
,
expert_ids
,
num_tokens_post_pad
def
_valid_deep_gemm
(
hidden_states
:
torch
.
Tensor
,
w1
:
torch
.
Tensor
,
w2
:
torch
.
Tensor
,
expert_map
:
Optional
[
torch
.
Tensor
])
->
bool
:
"""
Check if the given problem size is supported by the DeepGemm grouped
gemm kernel. All of M, N, K and the quantization block_shape must be
aligned by `dg.get_m_alignment_for_contiguous_layout()`.
"""
if
not
has_deep_gemm
:
return
False
# Lazy import to avoid CUDA initialization problems.
import
deep_gemm
as
dg
# Expert maps not supported yet.
if
expert_map
is
not
None
:
return
False
align
=
dg
.
get_m_alignment_for_contiguous_layout
()
M
=
hidden_states
.
shape
[
0
]
_
,
K
,
N
=
w2
.
shape
# For now, disable DeepGemm for small N until better permute/unpermute
# ops are available.
if
N
<=
512
:
return
False
if
align
>
M
or
N
%
align
!=
0
or
K
%
align
!=
0
:
return
False
return
(
hidden_states
.
is_contiguous
()
and
w1
.
is_contiguous
()
and
w2
.
is_contiguous
())
def
_fp8_quantize
(
A
:
torch
.
Tensor
,
A_scale
:
Optional
[
torch
.
Tensor
],
block_shape
:
Optional
[
List
[
int
]],
)
->
Tuple
[
torch
.
Tensor
,
torch
.
Tensor
]:
"""
Perform fp8 quantization on the inputs. If a block_shape
is provided, the output will be blocked.
"""
if
block_shape
is
None
:
A
,
A_scale
=
ops
.
scaled_fp8_quant
(
A
,
A_scale
)
else
:
assert
len
(
block_shape
)
==
2
_
,
block_k
=
block_shape
[
0
],
block_shape
[
1
]
A
,
A_scale
=
per_token_group_quant_fp8
(
A
,
block_k
)
assert
triton
.
cdiv
(
A
.
shape
[
-
1
],
block_k
)
==
A_scale
.
shape
[
-
1
]
return
A
,
A_scale
def
invoke_fused_moe_kernel
(
A
:
torch
.
Tensor
,
B
:
torch
.
Tensor
,
C
:
torch
.
Tensor
,
...
...
@@ -691,15 +753,11 @@ def invoke_fused_moe_kernel(A: torch.Tensor,
if
use_fp8_w8a8
:
assert
B_scale
is
not
None
if
block_shape
is
None
:
A
,
A_scale
=
ops
.
scaled_fp8_quant
(
A
,
A_scale
)
else
:
assert
len
(
block_shape
)
==
2
block_n
,
block_k
=
block_shape
[
0
],
block_shape
[
1
]
A
,
A_scale
=
per_token_group_quant_fp8
(
A
,
block_k
)
assert
triton
.
cdiv
(
A
.
shape
[
-
1
],
block_k
)
==
A_scale
.
shape
[
-
1
]
assert
triton
.
cdiv
(
B
.
shape
[
-
2
],
block_n
)
==
B_scale
.
shape
[
-
2
]
assert
triton
.
cdiv
(
B
.
shape
[
-
1
],
block_k
)
==
B_scale
.
shape
[
-
1
]
assert
(
block_shape
is
None
or
triton
.
cdiv
(
B
.
shape
[
-
2
],
block_shape
[
0
])
==
B_scale
.
shape
[
-
2
])
assert
(
block_shape
is
None
or
triton
.
cdiv
(
B
.
shape
[
-
1
],
block_shape
[
1
])
==
B_scale
.
shape
[
-
1
])
elif
use_int8_w8a16
or
use_int4_w4a16
:
assert
B_scale
is
not
None
assert
block_shape
is
None
or
block_shape
[
0
]
==
0
...
...
@@ -1066,7 +1124,7 @@ def fused_topk(
gating_output
:
torch
.
Tensor
,
topk
:
int
,
renormalize
:
bool
,
):
)
->
Tuple
[
torch
.
Tensor
,
torch
.
Tensor
]
:
assert
hidden_states
.
shape
[
0
]
==
gating_output
.
shape
[
0
],
(
"Number of tokens mismatch"
)
...
...
@@ -1098,14 +1156,16 @@ def fused_topk(
# This is used by the Deepseek-V2 and Deepseek-V3 model
@
torch
.
compile
(
dynamic
=
True
,
backend
=
current_platform
.
simple_compile_backend
)
def
grouped_topk
(
hidden_states
:
torch
.
Tensor
,
gating_output
:
torch
.
Tensor
,
topk
:
int
,
renormalize
:
bool
,
num_expert_group
:
int
=
0
,
topk_group
:
int
=
0
,
scoring_func
:
str
=
"softmax"
,
e_score_correction_bias
:
Optional
[
torch
.
Tensor
]
=
None
):
def
grouped_topk
(
hidden_states
:
torch
.
Tensor
,
gating_output
:
torch
.
Tensor
,
topk
:
int
,
renormalize
:
bool
,
num_expert_group
:
int
=
0
,
topk_group
:
int
=
0
,
scoring_func
:
str
=
"softmax"
,
e_score_correction_bias
:
Optional
[
torch
.
Tensor
]
=
None
)
->
Tuple
[
torch
.
Tensor
,
torch
.
Tensor
]:
assert
hidden_states
.
shape
[
0
]
==
gating_output
.
shape
[
0
],
(
"Number of tokens mismatch"
)
...
...
@@ -1154,10 +1214,11 @@ def grouped_topk(hidden_states: torch.Tensor,
return
topk_weights
.
to
(
torch
.
float32
),
topk_ids
.
to
(
torch
.
int32
)
def
get_config_dtype_str
(
dtype
:
torch
.
dtype
,
use_int4_w4a16
:
Optional
[
bool
]
=
False
,
use_int8_w8a16
:
Optional
[
bool
]
=
False
,
use_fp8_w8a8
:
Optional
[
bool
]
=
False
):
def
get_config_dtype_str
(
dtype
:
torch
.
dtype
,
use_int4_w4a16
:
Optional
[
bool
]
=
False
,
use_int8_w8a16
:
Optional
[
bool
]
=
False
,
use_fp8_w8a8
:
Optional
[
bool
]
=
False
)
->
Optional
[
str
]:
if
use_fp8_w8a8
:
return
"fp8_w8a8"
elif
use_int8_w8a16
:
...
...
@@ -1318,26 +1379,123 @@ def fused_experts(hidden_states: torch.Tensor,
w2_zp
:
Optional
[
torch
.
Tensor
]
=
None
,
a1_scale
:
Optional
[
torch
.
Tensor
]
=
None
,
a2_scale
:
Optional
[
torch
.
Tensor
]
=
None
,
block_shape
:
Optional
[
List
[
int
]]
=
None
)
->
torch
.
Tensor
:
return
dispatch_fused_experts_func
(
inplace
)(
hidden_states
=
hidden_states
,
w1
=
w1
,
w2
=
w2
,
topk_weights
=
topk_weights
,
topk_ids
=
topk_ids
,
activation
=
activation
,
use_fp8_w8a8
=
use_fp8_w8a8
,
use_int8_w8a16
=
use_int8_w8a16
,
use_int4_w4a16
=
use_int4_w4a16
,
global_num_experts
=
global_num_experts
,
expert_map
=
expert_map
,
w1_scale
=
w1_scale
,
w2_scale
=
w2_scale
,
w1_zp
=
w1_zp
,
w2_zp
=
w2_zp
,
a1_scale
=
a1_scale
,
a2_scale
=
a2_scale
,
block_shape
=
block_shape
)
block_shape
:
Optional
[
List
[
int
]]
=
None
,
allow_deep_gemm
:
bool
=
False
)
->
torch
.
Tensor
:
if
(
allow_deep_gemm
and
use_fp8_w8a8
and
_valid_deep_gemm
(
hidden_states
,
w1
,
w2
,
expert_map
)):
return
deep_gemm_moe_fp8
(
hidden_states
=
hidden_states
,
w1
=
w1
,
w2
=
w2
,
topk_weights
=
topk_weights
,
topk_ids
=
topk_ids
,
activation
=
activation
,
global_num_experts
=
global_num_experts
,
expert_map
=
expert_map
,
w1_scale
=
w1_scale
,
w2_scale
=
w2_scale
,
a1_scale
=
a1_scale
,
a2_scale
=
a2_scale
,
)
else
:
return
dispatch_fused_experts_func
(
inplace
)(
hidden_states
=
hidden_states
,
w1
=
w1
,
w2
=
w2
,
topk_weights
=
topk_weights
,
topk_ids
=
topk_ids
,
activation
=
activation
,
use_fp8_w8a8
=
use_fp8_w8a8
,
use_int8_w8a16
=
use_int8_w8a16
,
use_int4_w4a16
=
use_int4_w4a16
,
global_num_experts
=
global_num_experts
,
expert_map
=
expert_map
,
w1_scale
=
w1_scale
,
w2_scale
=
w2_scale
,
w1_zp
=
w1_zp
,
w2_zp
=
w2_zp
,
a1_scale
=
a1_scale
,
a2_scale
=
a2_scale
,
block_shape
=
block_shape
)
def
_fp8_perm
(
m
:
torch
.
Tensor
,
idx
:
torch
.
Tensor
)
->
torch
.
Tensor
:
"""
A permutation routine that works on fp8 types.
"""
if
torch
.
is_floating_point
(
m
)
and
torch
.
finfo
(
m
.
dtype
).
bits
==
8
:
return
m
.
view
(
dtype
=
torch
.
uint8
)[
idx
,
...].
view
(
dtype
=
m
.
dtype
)
else
:
return
m
[
idx
,
...]
def
_moe_permute
(
curr_hidden_states
:
torch
.
Tensor
,
a1q_scale
:
Optional
[
torch
.
Tensor
],
curr_topk_ids
:
torch
.
Tensor
,
global_num_experts
:
int
,
expert_map
:
Optional
[
torch
.
Tensor
],
top_k_num
:
int
,
block_m
:
int
,
)
->
Tuple
[
torch
.
Tensor
,
Optional
[
torch
.
Tensor
],
torch
.
Tensor
,
torch
.
Tensor
,
torch
.
Tensor
]:
"""
Determine the sorted_token_ids, expert_ids for the given problem size.
Permute the hidden states and scales according to `sorted_token_ids`.
"""
tokens_in_chunk
,
_
=
curr_hidden_states
.
shape
sorted_token_ids
,
expert_ids
,
num_tokens_post_padded
=
(
moe_align_block_size
(
curr_topk_ids
,
block_m
,
global_num_experts
,
expert_map
,
pad_sorted_ids
=
True
))
inv_perm
:
Optional
[
torch
.
Tensor
]
=
None
num_tokens
=
top_k_num
*
tokens_in_chunk
sorted_token_ids
=
sorted_token_ids
.
clamp
(
max
=
num_tokens
-
1
)
expert_ids
=
torch
.
repeat_interleave
(
expert_ids
,
block_m
,
dim
=
0
)
inv_perm
=
torch
.
argsort
(
sorted_token_ids
)[:
num_tokens
]
# Permute according to sorted token ids.
curr_hidden_states
=
_fp8_perm
(
curr_hidden_states
,
sorted_token_ids
//
top_k_num
)
if
a1q_scale
is
not
None
:
a1q_scale
=
a1q_scale
[
sorted_token_ids
//
top_k_num
]
return
(
curr_hidden_states
,
a1q_scale
,
sorted_token_ids
,
expert_ids
,
inv_perm
)
def
_moe_unpermute_and_reduce
(
out
:
torch
.
Tensor
,
curr_hidden
:
torch
.
Tensor
,
inv_perm
:
Optional
[
torch
.
Tensor
],
topk
:
int
,
K
:
int
,
topk_weight
:
torch
.
Tensor
,
)
->
None
:
"""
Unpermute the final result and apply topk_weights, then perform the final
reduction on the hidden states.
"""
M
=
topk_weight
.
shape
[
0
]
curr_hidden
=
curr_hidden
[
inv_perm
,
...]
curr_hidden
=
curr_hidden
.
view
(
-
1
,
topk
,
K
)
curr_hidden
.
mul_
(
topk_weight
.
view
(
M
,
-
1
,
1
))
ops
.
moe_sum
(
curr_hidden
,
out
)
def
_resize_cache
(
x
:
torch
.
Tensor
,
v
:
Tuple
[
int
,
...])
->
torch
.
Tensor
:
"""
Shrink the given tensor and apply the given view to it. This is
used to resize the intermediate fused_moe caches.
"""
assert
prod
(
v
)
<=
x
.
numel
()
return
x
.
flatten
()[:
prod
(
v
)].
view
(
*
v
)
def
fused_experts_impl
(
hidden_states
:
torch
.
Tensor
,
...
...
@@ -1376,6 +1534,7 @@ def fused_experts_impl(hidden_states: torch.Tensor,
num_tokens
,
_
=
hidden_states
.
shape
E
,
N
,
_
=
w1
.
shape
K
=
w2
.
shape
[
1
]
if
global_num_experts
==
-
1
:
global_num_experts
=
E
top_k_num
=
topk_ids
.
shape
[
1
]
...
...
@@ -1401,13 +1560,11 @@ def fused_experts_impl(hidden_states: torch.Tensor,
# We can reuse the memory between these because by the time we need
# cache3, we're done with cache1
cache13
=
torch
.
empty
(
M
*
top_k_num
*
max
(
N
,
w2
.
shape
[
1
]
),
cache13
=
torch
.
empty
(
M
*
top_k_num
*
max
(
N
,
K
),
device
=
hidden_states
.
device
,
dtype
=
hidden_states
.
dtype
)
intermediate_cache1
=
cache13
[:
M
*
top_k_num
*
N
].
view
(
(
M
,
topk_ids
.
shape
[
1
],
N
))
intermediate_cache3
=
cache13
[:
M
*
top_k_num
*
w2
.
shape
[
1
]].
view
(
(
M
,
topk_ids
.
shape
[
1
],
w2
.
shape
[
1
]))
intermediate_cache1
=
cache13
[:
M
*
top_k_num
*
N
].
view
(
M
,
top_k_num
,
N
)
intermediate_cache3
=
cache13
[:
M
*
top_k_num
*
K
].
view
(
M
,
top_k_num
,
K
)
# This needs separate memory since it's used concurrently with cache1
intermediate_cache2
=
torch
.
empty
((
M
*
top_k_num
,
N
//
2
),
...
...
@@ -1452,14 +1609,23 @@ def fused_experts_impl(hidden_states: torch.Tensor,
curr_topk_ids
=
topk_ids
[
begin_chunk_idx
:
end_chunk_idx
]
curr_topk_weights
=
topk_weights
[
begin_chunk_idx
:
end_chunk_idx
]
a1q_scale
:
Optional
[
torch
.
Tensor
]
=
None
if
use_fp8_w8a8
:
qcurr_hidden_states
,
a1q_scale
=
_fp8_quantize
(
curr_hidden_states
,
a1_scale
,
block_shape
)
else
:
qcurr_hidden_states
=
curr_hidden_states
a1q_scale
=
a1_scale
sorted_token_ids
,
expert_ids
,
num_tokens_post_padded
=
(
moe_align_block_size
(
curr_topk_ids
,
config
[
'BLOCK_SIZE_M'
],
global_num_experts
,
expert_map
))
invoke_fused_moe_kernel
(
curr_hidden_states
,
invoke_fused_moe_kernel
(
q
curr_hidden_states
,
w1
,
intermediate_cache1
,
a1_scale
,
a1
q
_scale
,
w1_scale
,
w1_zp
,
curr_topk_weights
,
...
...
@@ -1485,10 +1651,19 @@ def fused_experts_impl(hidden_states: torch.Tensor,
else
:
raise
ValueError
(
f
"Unsupported FusedMoe activation:
{
activation
}
"
)
invoke_fused_moe_kernel
(
intermediate_cache2
,
a2q_scale
:
Optional
[
torch
.
Tensor
]
=
None
if
use_fp8_w8a8
:
qintermediate_cache2
,
a2q_scale
=
_fp8_quantize
(
intermediate_cache2
,
a2_scale
,
block_shape
)
else
:
qintermediate_cache2
=
intermediate_cache2
a2q_scale
=
a2_scale
invoke_fused_moe_kernel
(
qintermediate_cache2
,
w2
,
intermediate_cache3
,
a2_scale
,
a2
q
_scale
,
w2_scale
,
w2_zp
,
curr_topk_weights
,
...
...
@@ -1617,6 +1792,193 @@ def fused_moe(
block_shape
=
block_shape
)
def
deep_gemm_moe_fp8
(
hidden_states
:
torch
.
Tensor
,
w1
:
torch
.
Tensor
,
w2
:
torch
.
Tensor
,
w1_scale
:
torch
.
Tensor
,
w2_scale
:
torch
.
Tensor
,
topk_weights
:
torch
.
Tensor
,
topk_ids
:
torch
.
Tensor
,
inplace
:
bool
=
False
,
activation
:
str
=
"silu"
,
global_num_experts
:
int
=
-
1
,
expert_map
:
Optional
[
torch
.
Tensor
]
=
None
,
a1_scale
:
Optional
[
torch
.
Tensor
]
=
None
,
a2_scale
:
Optional
[
torch
.
Tensor
]
=
None
,
)
->
torch
.
Tensor
:
"""
This function computes a a8w8-quantized Mixture of Experts (MoE) layer
using two sets of quantized weights, w1_q and w2_q, and top-k gating
mechanism. The matrix multiplications are implemented with DeepGemm
grouped gemm.
Parameters:
- hidden_states (torch.Tensor): The input tensor to the MoE layer.
Shape: [M, K]
- w1 (torch.Tensor): The first set of fp8 quantized expert weights.
Shape: [num_experts, K, 2N] (the weights are passed transposed)
- w2 (torch.Tensor): The second set of fp8 quantized expert weights.
Shape: [num_experts, N, K] (the weights are passed transposed)
- w1_scale (torch.Tensor): The fp32 scale to dequantize w1_q.
Shape: [num_experts] or [num_experts, 2N]
- w2_scale (torch.Tensor): The fp32 scale to dequantize w2_q.
Shape: [num_experts] or [num_experts, K]
- topk_weights (torch.Tensor): The weights of each token->expert mapping.
- topk_ids (torch.Tensor): The token->expert mapping for topk_weights.
- inplace (bool): If True, perform the operation in-place.
Defaults to False.
- activation (str): The activation function to apply after the first
MoE layer.
- global_num_experts (int): The total number of experts in the global
expert space.
- expert_map (Optional[torch.Tensor]): A tensor mapping expert indices
from the global expert space to the local expert space of the expert
parallel shard.
- a1_scale (Optional[torch.Tensor]): The optional fp32 scale to quantize a.
Shape: scalar or [M]
- a2_scale (Optional[torch.Tensor]): The optional fp32 scale to
quantize the intermediate result between the gemms.
Shape: scalar or [M]
Returns:
- torch.Tensor: The bfloat16 output tensor after applying the MoE layer.
"""
# Lazy import to avoid CUDA initialization problems.
import
deep_gemm
as
dg
assert
expert_map
is
None
,
"Expert maps not supported yet"
assert
hidden_states
.
shape
[
1
]
==
w1
.
shape
[
2
],
"Hidden size mismatch"
assert
topk_weights
.
shape
==
topk_ids
.
shape
,
"topk shape mismatch"
assert
hidden_states
.
is_contiguous
(),
"Hidden_states must be contiguous"
assert
w1
.
stride
(
-
1
)
==
1
,
"Stride of last dimension must be 1"
assert
w2
.
stride
(
-
1
)
==
1
,
"Stride of last dimension must be 1"
assert
hidden_states
.
dtype
in
[
torch
.
float32
,
torch
.
float16
,
torch
.
bfloat16
]
assert
w1
.
dtype
==
torch
.
float8_e4m3fn
assert
w2
.
dtype
==
torch
.
float8_e4m3fn
assert
w1
.
shape
[
0
]
==
w2
.
shape
[
0
],
"Expert number mismatch"
assert
w1
.
shape
[
0
]
==
w1_scale
.
shape
[
0
],
"w1 scales expert number mismatch"
assert
w1
.
shape
[
0
]
==
w2_scale
.
shape
[
0
],
"w2 scales expert number mismatch"
assert
a1_scale
is
None
or
a1_scale
.
dim
(
)
==
0
or
a1_scale
.
shape
[
0
]
==
1
or
a1_scale
.
shape
[
0
]
==
hidden_states
.
shape
[
0
],
"Input scale shape mismatch"
assert
a2_scale
is
None
or
a1_scale
is
None
or
a2_scale
.
shape
==
a1_scale
.
shape
,
"Intermediate scale shape mismatch"
# noqa: E501
num_tokens
,
_
=
hidden_states
.
shape
E
,
N
,
_
=
w1
.
shape
K
=
w2
.
shape
[
1
]
if
global_num_experts
==
-
1
:
global_num_experts
=
E
top_k_num
=
topk_ids
.
shape
[
1
]
# We execute the fused_moe kernel in chunks to circumvent this issue:
# https://github.com/vllm-project/vllm/issues/5938
CHUNK_SIZE
=
envs
.
VLLM_FUSED_MOE_CHUNK_SIZE
assert
_valid_deep_gemm
(
hidden_states
,
w1
,
w2
,
expert_map
)
if
inplace
:
out_hidden_states
=
hidden_states
else
:
out_hidden_states
=
torch
.
empty_like
(
hidden_states
)
block_m
=
dg
.
get_m_alignment_for_contiguous_layout
()
block_shape
=
[
block_m
,
block_m
]
assert
w1_scale
is
not
None
assert
w2_scale
is
not
None
# We attempt to transpose and align offline in Fp8MoEMethod, in which
# case these calls will be nops. Otherwise, they'll be performed every
# time the layer is executed.
w1_scale
=
dg
.
get_col_major_tma_aligned_tensor
(
w1_scale
).
contiguous
()
w2_scale
=
dg
.
get_col_major_tma_aligned_tensor
(
w2_scale
).
contiguous
()
M_sum
=
topk_ids
.
numel
()
+
global_num_experts
*
(
block_m
-
1
)
M_sum
=
round_up
(
M_sum
,
block_m
)
num_chunks
=
(
num_tokens
//
CHUNK_SIZE
)
+
1
# We can reuse the memory between cache1 and cache3 because by the time
# we need cache3, we're done with cache1
cache13
=
torch
.
empty
(
M_sum
*
max
(
N
,
K
),
device
=
hidden_states
.
device
,
dtype
=
hidden_states
.
dtype
)
intermediate_cache1
=
cache13
[:
M_sum
*
N
].
view
(
M_sum
,
N
)
intermediate_cache2
=
torch
.
empty
((
M_sum
,
N
//
2
),
device
=
hidden_states
.
device
,
dtype
=
hidden_states
.
dtype
)
intermediate_cache3
=
cache13
[:
M_sum
*
K
].
view
(
M_sum
,
K
)
for
chunk
in
range
(
num_chunks
):
begin_chunk_idx
,
end_chunk_idx
=
(
chunk
*
CHUNK_SIZE
,
min
((
chunk
+
1
)
*
CHUNK_SIZE
,
num_tokens
))
curr_hidden_states
=
hidden_states
[
begin_chunk_idx
:
end_chunk_idx
]
tokens_in_chunk
,
_
=
curr_hidden_states
.
shape
if
tokens_in_chunk
==
0
:
break
curr_topk_ids
=
topk_ids
[
begin_chunk_idx
:
end_chunk_idx
]
curr_topk_weights
=
topk_weights
[
begin_chunk_idx
:
end_chunk_idx
]
a1q_scale
:
Optional
[
torch
.
Tensor
]
=
None
qcurr_hidden_states
,
a1q_scale
=
_fp8_quantize
(
curr_hidden_states
,
a1_scale
,
block_shape
)
(
qcurr_hidden_states
,
a1q_scale
,
sorted_token_ids
,
expert_ids
,
inv_perm
)
=
_moe_permute
(
qcurr_hidden_states
,
a1q_scale
,
curr_topk_ids
,
global_num_experts
,
expert_map
,
top_k_num
,
block_m
)
# Adjust the intermediate cache size and config for the last chunk.
# Note that in most cases we only have one chunk so the cache size
# and config are already set correctly and do not need to be adjusted.
if
tokens_in_chunk
<
CHUNK_SIZE
and
chunk
>
0
:
curr_M
=
sorted_token_ids
.
numel
()
intermediate_cache1
=
_resize_cache
(
intermediate_cache1
,
(
curr_M
,
N
))
intermediate_cache2
=
_resize_cache
(
intermediate_cache2
,
(
curr_M
,
N
//
2
))
intermediate_cache3
=
_resize_cache
(
intermediate_cache3
,
(
curr_M
,
K
))
dg
.
m_grouped_gemm_fp8_fp8_bf16_nt_contiguous
(
(
qcurr_hidden_states
,
a1q_scale
),
(
w1
,
w1_scale
),
intermediate_cache1
,
expert_ids
)
if
activation
==
"silu"
:
torch
.
ops
.
_C
.
silu_and_mul
(
intermediate_cache2
,
intermediate_cache1
.
view
(
-
1
,
N
))
elif
activation
==
"gelu"
:
torch
.
ops
.
_C
.
gelu_and_mul
(
intermediate_cache2
,
intermediate_cache1
.
view
(
-
1
,
N
))
else
:
raise
ValueError
(
f
"Unsupported FusedMoe activation:
{
activation
}
"
)
a2q_scale
:
Optional
[
torch
.
Tensor
]
=
None
qintermediate_cache2
,
a2q_scale
=
_fp8_quantize
(
intermediate_cache2
,
a2_scale
,
block_shape
)
dg
.
m_grouped_gemm_fp8_fp8_bf16_nt_contiguous
(
(
qintermediate_cache2
,
a2q_scale
),
(
w2
,
w2_scale
),
intermediate_cache3
,
expert_ids
)
_moe_unpermute_and_reduce
(
out_hidden_states
[
begin_chunk_idx
:
end_chunk_idx
],
intermediate_cache3
.
view
(
*
intermediate_cache3
.
shape
),
inv_perm
,
top_k_num
,
K
,
curr_topk_weights
)
return
out_hidden_states
#TODO make the grouped gemm kernel consistent with scaled gemm kernel
def
cutlass_moe_fp8
(
a
:
torch
.
Tensor
,
...
...
vllm/model_executor/layers/quantization/fp8.py
View file @
e59ca942
# SPDX-License-Identifier: Apache-2.0
import
importlib.util
from
typing
import
Any
,
Callable
,
Dict
,
List
,
Optional
import
torch
...
...
@@ -37,6 +38,14 @@ ACTIVATION_SCHEMES = ["static", "dynamic"]
logger
=
init_logger
(
__name__
)
has_deep_gemm
=
importlib
.
util
.
find_spec
(
"deep_gemm"
)
is
not
None
def
_is_col_major
(
x
:
torch
.
Tensor
)
->
bool
:
assert
x
.
dim
()
==
3
b
,
m
,
n
=
x
.
shape
return
x
.
stride
(
0
)
==
m
*
n
and
x
.
stride
(
1
)
==
1
and
x
.
stride
(
2
)
==
m
class
Fp8Config
(
QuantizationConfig
):
"""Config class for FP8."""
...
...
@@ -424,6 +433,19 @@ class Fp8MoEMethod(FusedMoEMethodBase):
self
.
quant_config
=
quant_config
self
.
block_quant
=
self
.
quant_config
.
weight_block_size
is
not
None
# Check for DeepGemm support.
self
.
allow_deep_gemm
=
False
if
envs
.
VLLM_USE_DEEP_GEMM
:
if
not
has_deep_gemm
:
logger
.
warning_once
(
"Failed to import DeepGemm kernels."
)
elif
(
current_platform
.
is_cuda
()
and
current_platform
.
has_device_capability
(
90
)):
logger
.
info_once
(
"Using DeepGemm kernels for Fp8MoEMethod."
)
self
.
allow_deep_gemm
=
True
else
:
logger
.
warning_once
(
"DeepGemm not supported on the current platform."
)
def
create_weights
(
self
,
layer
:
Module
,
num_experts
:
int
,
hidden_size
:
int
,
intermediate_size_per_partition
:
int
,
params_dtype
:
torch
.
dtype
,
**
extra_weight_attrs
):
...
...
@@ -585,6 +607,19 @@ class Fp8MoEMethod(FusedMoEMethodBase):
requires_grad
=
False
)
layer
.
w2_weight
=
torch
.
nn
.
Parameter
(
shuffled_w2
,
requires_grad
=
False
)
# DeepGemm scales need to be transposed and aligned. We try to do
# it ahead of time for performance reasons.
if
self
.
allow_deep_gemm
:
# Lazy import to avoid CUDA initialization problems.
import
deep_gemm
as
dg
if
_is_col_major
(
layer
.
w13_weight_scale_inv
):
layer
.
w13_weight_scale_inv
=
\
dg
.
get_col_major_tma_aligned_tensor
(
layer
.
w13_weight_scale_inv
).
contiguous
()
if
_is_col_major
(
layer
.
w2_weight_scale_inv
):
layer
.
w2_weight_scale_inv
=
\
dg
.
get_col_major_tma_aligned_tensor
(
layer
.
w2_weight_scale_inv
).
contiguous
()
return
# If checkpoint is fp16, quantize in place.
...
...
@@ -773,6 +808,7 @@ class Fp8MoEMethod(FusedMoEMethodBase):
a1_scale
=
layer
.
w13_input_scale
,
a2_scale
=
layer
.
w2_input_scale
,
block_shape
=
self
.
quant_config
.
weight_block_size
,
allow_deep_gemm
=
self
.
allow_deep_gemm
,
)
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment