Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
06a1bee2
Commit
06a1bee2
authored
Oct 13, 2025
by
zhuwenwen
Browse files
optimize the implementation of moe_sum
parent
b7989b07
Changes
2
Show whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
140 additions
and
2 deletions
+140
-2
vllm/envs.py
vllm/envs.py
+5
-0
vllm/model_executor/layers/fused_moe/fused_moe.py
vllm/model_executor/layers/fused_moe/fused_moe.py
+135
-2
No files found.
vllm/envs.py
View file @
06a1bee2
...
...
@@ -166,6 +166,7 @@ if TYPE_CHECKING:
VLLM_USE_GLOBAL_CACHE13
:
bool
=
False
VLLM_USE_LIGHTOP
:
bool
=
False
VLLM_USE_OPT_CAT
:
bool
=
False
VLLM_USE_OPT_MOE_SUM
:
bool
=
False
VLLM_USE_MERGE_ATTN_STATES_OPT
:
bool
=
False
USE_FUSED_RMS_QUANT
:
bool
=
False
USE_FUSED_SILU_MUL_QUANT
:
bool
=
False
...
...
@@ -1104,6 +1105,10 @@ environment_variables: dict[str, Callable[[], Any]] = {
"VLLM_USE_OPT_CAT"
:
lambda
:
(
os
.
environ
.
get
(
"VLLM_USE_OPT_CAT"
,
"False"
).
lower
()
in
(
"true"
,
"1"
)),
# vLLM will use the optimized moe_sum dispatch (torch.compile / triton / custom op, chosen by token count)
"VLLM_USE_OPT_MOE_SUM"
:
lambda
:
(
os
.
environ
.
get
(
"VLLM_USE_OPT_MOE_SUM"
,
"False"
).
lower
()
in
(
"true"
,
"1"
)),
# vLLM will use opt merge_attn_states, not triton
"VLLM_USE_MERGE_ATTN_STATES_OPT"
:
lambda
:
(
os
.
environ
.
get
(
"VLLM_USE_MERGE_ATTN_STATES_OPT"
,
"True"
).
lower
()
in
...
...
vllm/model_executor/layers/fused_moe/fused_moe.py
View file @
06a1bee2
...
...
@@ -52,6 +52,136 @@ logger = init_logger(__name__)
# Module-level slot for the MoE cache singleton; it is only defined when the
# VLLM_USE_GLOBAL_CACHE13 flag is on. get_moe_cache() declares it global and
# tests it against None (presumably to populate it lazily -- confirm there).
if envs.VLLM_USE_GLOBAL_CACHE13:
    moe_cache_singleton = None
@torch.compile
def moe_sum_reduce_torch_compile(x, out, routed_scaling_factor):
    """Sum ``x`` over dimension 1 into ``out``, then scale ``out`` in place.

    Args:
        x: Tensor to reduce; dimension 1 is summed away.
        out: Destination tensor. It receives the sum and is then multiplied
            in place by ``routed_scaling_factor``.
        routed_scaling_factor: Multiplier applied to ``out`` after the sum.

    Returns:
        None. The result is left in ``out``.
    """
    torch.sum(x, 1, out=out)
    # In-place multiply: ``out`` keeps its storage, only its values change.
    out *= routed_scaling_factor
@triton.jit
def _moe_sum_reduce_kernel(
    input_ptr,
    input_stride_0,
    input_stride_1,
    input_stride_2,
    output_ptr,
    output_stride_0,
    output_stride_1,
    token_num: int,
    topk_num: int,
    hidden_dim: int,
    routed_scaling_factor: tl.constexpr,
    BLOCK_M: tl.constexpr,
    BLOCK_DIM: tl.constexpr,
    NUM_STAGE: tl.constexpr,
):
    """Sum the top-k slices of each token and write the scaled result.

    Program axis 0 selects a group of ``BLOCK_M`` tokens and program axis 1
    selects a tile of ``BLOCK_DIM`` hidden columns. For every token in the
    group, the kernel accumulates ``topk_num`` rows in float32, multiplies
    the sum by ``routed_scaling_factor`` and stores it cast to the element
    type of ``input_ptr``.

    ``input_stride_2`` and ``output_stride_1`` are accepted but not used:
    the hidden dimension is addressed by adding the column offsets directly
    to the row pointer.
    """
    # Cast the strides that take part in offset products to int64.
    input_stride_0 = tl.cast(input_stride_0, dtype=tl.int64)
    input_stride_1 = tl.cast(input_stride_1, dtype=tl.int64)
    output_stride_0 = tl.cast(output_stride_0, dtype=tl.int64)

    pid_token = tl.program_id(0)
    pid_dim = tl.program_id(1)

    # Token range handled by this program, clipped to the real token count.
    first_token = pid_token * BLOCK_M
    last_token = min(first_token + BLOCK_M, token_num)

    # Hidden-column tile handled by this program, clipped to hidden_dim.
    first_col = pid_dim * BLOCK_DIM
    col_limit = min(first_col + BLOCK_DIM, hidden_dim)
    cols = first_col + tl.arange(0, BLOCK_DIM)
    # The mask is the same for every load/store in this program.
    col_mask = cols < col_limit

    for tok in range(first_token, last_token):
        acc = tl.zeros((BLOCK_DIM,), dtype=tl.float32)
        src_ptr = input_ptr + tok * input_stride_0 + cols
        for k in tl.range(0, topk_num, num_stages=NUM_STAGE):
            acc += tl.load(src_ptr + k * input_stride_1, mask=col_mask, other=0.0)
        acc = acc * routed_scaling_factor
        dst_ptr = output_ptr + tok * output_stride_0 + cols
        tl.store(
            dst_ptr,
            acc.to(input_ptr.dtype.element_ty),
            mask=col_mask,
        )
def moe_sum_reduce_triton(
    input: torch.Tensor, output: torch.Tensor, routed_scaling_factor: float
):
    """Launch ``_moe_sum_reduce_kernel`` to reduce ``input`` into ``output``.

    Args:
        input: Contiguous 3-D tensor unpacked as
            ``(token_num, topk_num, hidden_dim)``.
        output: Contiguous tensor whose first two dimensions are
            ``(token_num, hidden_dim)``; it receives the result.
        routed_scaling_factor: Factor forwarded to the kernel.

    Raises:
        AssertionError: If either tensor is not contiguous or the output
            shape does not match the input.
    """
    assert input.is_contiguous()
    assert output.is_contiguous()

    token_num, topk_num, hidden_dim = input.shape
    assert output.shape[0] == token_num and output.shape[1] == hidden_dim

    # (inclusive token upper bound, BLOCK_DIM, NUM_STAGE, num_warps);
    # the first row whose bound covers token_num is used.
    launch_table = (
        (32, 512, 2, 4),
        (128, 1024, 0, 2),
        (4096, 2048, 0, 2),
    )
    # Settings for token counts above every bound in the table.
    block_dim, num_stage, warps = 2048, 2, 8
    for upper_bound, dim, stage, warp_count in launch_table:
        if token_num <= upper_bound:
            block_dim, num_stage, warps = dim, stage, warp_count
            break
    # Every configuration processes one token per program along axis 0.
    block_m = 1

    grid = (
        triton.cdiv(token_num, block_m),
        triton.cdiv(hidden_dim, block_dim),
    )

    _moe_sum_reduce_kernel[grid](
        input,
        *input.stride(),
        output,
        *output.stride(),
        token_num=token_num,
        topk_num=topk_num,
        hidden_dim=hidden_dim,
        routed_scaling_factor=routed_scaling_factor,
        BLOCK_M=block_m,
        BLOCK_DIM=block_dim,
        NUM_STAGE=num_stage,
        num_warps=warps,
    )
def moe_reduce_dispatch(
    intermediate_cache3: torch.Tensor,
    out_hidden_states: torch.Tensor,
    begin_chunk_idx: int,
    end_chunk_idx: int,
):
    """Reduce one chunk of ``intermediate_cache3`` with a size-specific kernel.

    The implementation is chosen from ``n = intermediate_cache3.shape[0]``:

    * ``1 <= n <= 4``        -> ``moe_sum_reduce_torch_compile``
    * ``4 < n <= 1024``      -> ``moe_sum_reduce_triton``
    * ``1024 < n <= 32768``  -> ``ops.moe_sum_opt1``
    * anything else          -> ``ops.moe_sum``

    Args:
        intermediate_cache3: Tensor holding the per-expert outputs for the
            current chunk; its leading dimension is the chunk's token count.
        out_hidden_states: Destination tensor. Either the chunk view that
            was already sliced by the caller, or the full output tensor from
            which rows ``begin_chunk_idx:end_chunk_idx`` are taken.
        begin_chunk_idx: First output row of the chunk (used only when
            ``out_hidden_states`` is not already chunk-sized).
        end_chunk_idx: One past the last output row of the chunk.

    Returns:
        None. The result is written into the selected output rows.
    """
    inter_cache_view = intermediate_cache3.view(*intermediate_cache3.shape)
    n = intermediate_cache3.shape[0]

    # fused_experts_impl passes out_hidden_states[begin:end], i.e. a view
    # that is already chunk-sized. Slicing that view again with the same
    # indices is only correct for begin == 0; for later chunks it selects
    # wrong or empty rows. Slice only when the tensor is not the chunk view.
    if out_hidden_states.shape[0] == n:
        out_chunk = out_hidden_states
    else:
        out_chunk = out_hidden_states[begin_chunk_idx:end_chunk_idx]

    # Pick a reduce implementation based on the number of tokens n.
    if 1 <= n <= 4:
        moe_sum_reduce_torch_compile(inter_cache_view, out_chunk, 1.0)
    elif 4 < n <= 1024:
        moe_sum_reduce_triton(inter_cache_view, out_chunk, 1.0)
    elif 1024 < n <= 32768:
        ops.moe_sum_opt1(inter_cache_view, out_chunk)
    else:
        ops.moe_sum(inter_cache_view, out_chunk)
def
get_moe_cache
(
top_k_num
,
N
,
K
,
device
,
dtype
):
global
moe_cache_singleton
if
moe_cache_singleton
is
None
:
...
...
@@ -1787,6 +1917,9 @@ def fused_experts_impl(
# if hidden_states.dtype != torch.float16 or dpsk_fp16_quick:
# ops.moe_sum(intermediate_cache3.view(*intermediate_cache3.size()),
# out_hidden_states[begin_chunk_idx:end_chunk_idx]) * routed_scaling_factor
else
:
if
envs
.
VLLM_USE_OPT_MOE_SUM
:
moe_reduce_dispatch
(
intermediate_cache3
.
view
(
*
intermediate_cache3
.
size
()),
out_hidden_states
[
begin_chunk_idx
:
end_chunk_idx
],
begin_chunk_idx
,
end_chunk_idx
)
else
:
ops
.
moe_sum
(
intermediate_cache3
.
view
(
*
intermediate_cache3
.
size
()),
out_hidden_states
[
begin_chunk_idx
:
end_chunk_idx
])
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment