Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
b35a518a
Commit
b35a518a
authored
Oct 13, 2025
by
zhuwenwen
Browse files
update moe_sum and moe_align
parent
2cf181fd
Changes
3
Show whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
166 additions
and
4 deletions
+166
-4
vllm/envs.py
vllm/envs.py
+15
-0
vllm/model_executor/layers/fused_moe/fused_moe.py
vllm/model_executor/layers/fused_moe/fused_moe.py
+141
-2
vllm/model_executor/layers/fused_moe/moe_align_block_size.py
vllm/model_executor/layers/fused_moe/moe_align_block_size.py
+10
-2
No files found.
vllm/envs.py
View file @
b35a518a
...
@@ -232,6 +232,9 @@ if TYPE_CHECKING:
...
@@ -232,6 +232,9 @@ if TYPE_CHECKING:
VLLM_USE_GLOBAL_CACHE13
:
bool
=
False
VLLM_USE_GLOBAL_CACHE13
:
bool
=
False
VLLM_USE_LIGHTOP
:
bool
=
False
VLLM_USE_LIGHTOP
:
bool
=
False
VLLM_USE_OPT_CAT
:
bool
=
False
VLLM_USE_OPT_CAT
:
bool
=
False
VLLM_USE_OPT_MOE_SUM
:
bool
=
False
VLLM_USE_LIGHTOP_MOE_SUM
:
bool
=
False
VLLM_USE_LIGHTOP_MOE_ALIGN
:
bool
=
False
VLLM_USE_MERGE_ATTN_STATES_OPT
:
bool
=
False
VLLM_USE_MERGE_ATTN_STATES_OPT
:
bool
=
False
USE_FUSED_RMS_QUANT
:
bool
=
False
USE_FUSED_RMS_QUANT
:
bool
=
False
USE_FUSED_SILU_MUL_QUANT
:
bool
=
False
USE_FUSED_SILU_MUL_QUANT
:
bool
=
False
...
@@ -1625,6 +1628,18 @@ environment_variables: dict[str, Callable[[], Any]] = {
...
@@ -1625,6 +1628,18 @@ environment_variables: dict[str, Callable[[], Any]] = {
"VLLM_USE_OPT_CAT"
:
"VLLM_USE_OPT_CAT"
:
lambda
:
(
os
.
environ
.
get
(
"VLLM_USE_OPT_CAT"
,
"True"
).
lower
()
in
lambda
:
(
os
.
environ
.
get
(
"VLLM_USE_OPT_CAT"
,
"True"
).
lower
()
in
(
"true"
,
"1"
)),
(
"true"
,
"1"
)),
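# vLLM will use the size-dispatched moe_sum (torch.compile, triton or a
# custom op, chosen by token count); VLLM_USE_LIGHTOP_MOE_SUM is checked first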
"VLLM_USE_OPT_MOE_SUM"
:
lambda
:
(
os
.
environ
.
get
(
"VLLM_USE_OPT_MOE_SUM"
,
"False"
).
lower
()
in
(
"true"
,
"1"
)),
# vLLM will use lightop moe_sum
"VLLM_USE_LIGHTOP_MOE_SUM"
:
lambda
:
(
os
.
environ
.
get
(
"VLLM_USE_LIGHTOP_MOE_SUM"
,
"False"
).
lower
()
in
(
"true"
,
"1"
)),
# vLLM will use lightop moe_align_block_size
"VLLM_USE_LIGHTOP_MOE_ALIGN"
:
lambda
:
(
os
.
environ
.
get
(
"VLLM_USE_LIGHTOP_MOE_ALIGN"
,
"False"
).
lower
()
in
(
"true"
,
"1"
)),
# vLLM will use opt merge_attn_states, not triton
# vLLM will use opt merge_attn_states, not triton
"VLLM_USE_MERGE_ATTN_STATES_OPT"
:
"VLLM_USE_MERGE_ATTN_STATES_OPT"
:
lambda
:
(
os
.
environ
.
get
(
"VLLM_USE_MERGE_ATTN_STATES_OPT"
,
"True"
).
lower
()
in
lambda
:
(
os
.
environ
.
get
(
"VLLM_USE_MERGE_ATTN_STATES_OPT"
,
"True"
).
lower
()
in
...
...
vllm/model_executor/layers/fused_moe/fused_moe.py
View file @
b35a518a
...
@@ -58,6 +58,137 @@ logger = init_logger(__name__)
...
@@ -58,6 +58,137 @@ logger = init_logger(__name__)
if
envs
.
VLLM_USE_GLOBAL_CACHE13
:
if
envs
.
VLLM_USE_GLOBAL_CACHE13
:
moe_cache_singleton
=
None
moe_cache_singleton
=
None
@torch.compile
def moe_sum_reduce_torch_compile(
    x: torch.Tensor, out: torch.Tensor, routed_scaling_factor: float
) -> None:
    """Reduce ``x`` over dimension 1 into ``out`` and scale the result in place.

    Args:
        x: Tensor to reduce. The sibling Triton wrapper treats the same
            tensor as (tokens, topk, hidden), so dim 1 is presumably the
            top-k axis -- confirm against the caller.
        out: Destination tensor; receives the sum and is then multiplied
            in place.
        routed_scaling_factor: Multiplier applied to ``out`` after the sum.

    Returns:
        None. The result is delivered through ``out``.
    """
    # torch.sum(..., out=out) hands back ``out`` itself, so the in-place
    # multiply can be chained directly onto the reduction.
    torch.sum(x, dim=1, out=out).mul_(routed_scaling_factor)
@triton.jit
def _moe_sum_reduce_kernel(
    input_ptr,
    input_stride_0,
    input_stride_1,
    input_stride_2,
    output_ptr,
    output_stride_0,
    output_stride_1,
    token_num: int,
    topk_num: int,
    hidden_dim: int,
    routed_scaling_factor: tl.constexpr,
    BLOCK_M: tl.constexpr,
    BLOCK_DIM: tl.constexpr,
    NUM_STAGE: tl.constexpr,
):
    """Sum the top-k slices of each token, scale, and write one row per token.

    For every token handled by this program, the kernel loads ``topk_num``
    vectors of the input (stepping by ``input_stride_1``), accumulates them,
    multiplies by ``routed_scaling_factor`` and stores the result to the
    matching row of the output.

    Launch grid: axis 0 indexes blocks of ``BLOCK_M`` tokens, axis 1 indexes
    blocks of ``BLOCK_DIM`` elements of the hidden dimension.

    ``input_stride_2`` and ``output_stride_1`` are accepted but never used:
    the last-dimension offset is added to the pointer directly, so the
    kernel addresses the hidden dimension with unit stride.
    ``routed_scaling_factor``, ``BLOCK_M``, ``BLOCK_DIM`` and ``NUM_STAGE``
    are ``tl.constexpr`` parameters.
    """
    # Widen the row strides to int64 before they are multiplied by a token or
    # top-k index -- presumably to keep the offset arithmetic from overflowing
    # 32 bits on large tensors (TODO confirm).
    input_stride_0 = tl.cast(input_stride_0, dtype=tl.int64)
    input_stride_1 = tl.cast(input_stride_1, dtype=tl.int64)
    output_stride_0 = tl.cast(output_stride_0, dtype=tl.int64)

    token_block_id = tl.program_id(0)
    dim_block_id = tl.program_id(1)

    # Half-open token range for this program, clamped to the real token count.
    token_start = token_block_id * BLOCK_M
    token_end = min((token_block_id + 1) * BLOCK_M, token_num)

    # Half-open hidden-dimension range; dim_end is used below as the mask bound
    # for the (possibly partial) last block.
    dim_start = dim_block_id * BLOCK_DIM
    dim_end = min((dim_block_id + 1) * BLOCK_DIM, hidden_dim)

    offs_dim = dim_start + tl.arange(0, BLOCK_DIM)

    for token_index in range(token_start, token_end):
        # Accumulate in float32 regardless of the input element type.
        accumulator = tl.zeros((BLOCK_DIM,), dtype=tl.float32)
        input_t_ptr = input_ptr + token_index * input_stride_0 + offs_dim
        for i in tl.range(0, topk_num, num_stages=NUM_STAGE):
            # Out-of-range lanes read 0.0 so they do not affect the sum.
            tmp = tl.load(
                input_t_ptr + i * input_stride_1, mask=offs_dim < dim_end, other=0.0
            )
            accumulator += tmp
        accumulator = accumulator * routed_scaling_factor
        store_t_ptr = output_ptr + token_index * output_stride_0 + offs_dim
        # NOTE(review): the accumulator is cast to the *input* pointer's
        # element type before being stored through the output pointer --
        # confirm this is intended when input and output dtypes differ.
        tl.store(
            store_t_ptr,
            accumulator.to(input_ptr.dtype.element_ty),
            mask=offs_dim < dim_end,
        )
def moe_sum_reduce_triton(
    input: torch.Tensor, output: torch.Tensor, routed_scaling_factor: float
):
    """Launch ``_moe_sum_reduce_kernel`` to reduce ``input`` into ``output``.

    Args:
        input: Contiguous 3-D tensor whose shape is unpacked as
            (token_num, topk_num, hidden_dim).
        output: Contiguous tensor whose first two dimensions must equal
            (token_num, hidden_dim). Its strides are splatted into the two
            output-stride slots of the kernel, so it is presumably 2-D.
        routed_scaling_factor: Forwarded to the kernel unchanged.

    Returns:
        None. The result is written into ``output``.

    Raises:
        AssertionError: If either tensor is not contiguous or the leading
            output dimensions do not match the input.
    """
    assert input.is_contiguous()
    assert output.is_contiguous()

    token_num, topk_num, hidden_dim = input.shape
    assert output.shape[0] == token_num and output.shape[1] == hidden_dim

    # Launch parameters keyed by an inclusive upper bound on token_num:
    # (BLOCK_M, BLOCK_DIM, NUM_STAGE, num_warps). The first matching row
    # wins; anything above the last bound uses the fallback tuple.
    tuning_table = (
        (32, (1, 512, 2, 4)),
        (128, (1, 1024, 0, 2)),
        (4096, (1, 2048, 0, 2)),
    )
    block_m, block_dim, num_stage, warps = next(
        (cfg for limit, cfg in tuning_table if token_num <= limit),
        (1, 2048, 2, 8),
    )

    launch_grid = (
        triton.cdiv(token_num, block_m),
        triton.cdiv(hidden_dim, block_dim),
    )

    _moe_sum_reduce_kernel[launch_grid](
        input,
        *input.stride(),
        output,
        *output.stride(),
        token_num=token_num,
        topk_num=topk_num,
        hidden_dim=hidden_dim,
        routed_scaling_factor=routed_scaling_factor,
        BLOCK_M=block_m,
        BLOCK_DIM=block_dim,
        NUM_STAGE=num_stage,
        num_warps=warps,
    )
def moe_reduce_dispatch(
    intermediate_cache3: torch.Tensor,
    out_hidden_states: torch.Tensor,
    begin_chunk_idx: int,
    end_chunk_idx: int,
):
    """Reduce ``intermediate_cache3`` into the output rows of the current chunk.

    The implementation is selected from ``n = intermediate_cache3.shape[0]``:

    * ``1 <= n <= 4``        -> ``moe_sum_reduce_torch_compile`` (factor 1.0)
    * ``4 < n <= 1024``      -> ``moe_sum_reduce_triton`` (factor 1.0)
    * ``1024 < n <= 32768``  -> ``ops.moe_sum_opt1``
    * anything else (including ``n == 0``) -> ``ops.moe_sum``

    Args:
        intermediate_cache3: Tensor to be reduced; its first dimension is
            the number of rows in this chunk.
        out_hidden_states: Destination. May be either the full output
            tensor or the ``[begin_chunk_idx:end_chunk_idx]`` slice of it.
        begin_chunk_idx: First output row of the chunk (inclusive).
        end_chunk_idx: Last output row of the chunk (exclusive).

    Returns:
        None. The result is written into the selected output rows.
    """
    n = intermediate_cache3.shape[0]

    # Callers may already hand in out_hidden_states[begin:end]. Slicing that
    # view a second time selects the wrong (or an empty) region for every
    # chunk with begin_chunk_idx > 0, so only slice when the destination does
    # not already have exactly one row per row of the chunk. When the full
    # tensor has n rows the chunk covers it entirely, so skipping the slice
    # is equivalent there as well.
    if out_hidden_states.shape[0] == n:
        out_chunk = out_hidden_states
    else:
        out_chunk = out_hidden_states[begin_chunk_idx:end_chunk_idx]

    # Choose a different reduce implementation based on the size of n.
    if 1 <= n <= 4:
        moe_sum_reduce_torch_compile(intermediate_cache3, out_chunk, 1.0)
    elif 4 < n <= 1024:
        moe_sum_reduce_triton(intermediate_cache3, out_chunk, 1.0)
    elif 1024 < n <= 32768:
        ops.moe_sum_opt1(intermediate_cache3, out_chunk)
    else:
        ops.moe_sum(intermediate_cache3, out_chunk)
def
get_moe_cache
(
top_k_num
,
N
,
K
,
device
,
dtype
):
def
get_moe_cache
(
top_k_num
,
N
,
K
,
device
,
dtype
):
global
moe_cache_singleton
global
moe_cache_singleton
if
moe_cache_singleton
is
None
:
if
moe_cache_singleton
is
None
:
...
@@ -2046,6 +2177,14 @@ def fused_experts_impl(
...
@@ -2046,6 +2177,14 @@ def fused_experts_impl(
B_bias
=
w2_bias
,
B_bias
=
w2_bias
,
use_nn_moe
=
use_nn_moe
)
use_nn_moe
=
use_nn_moe
)
if
envs
.
VLLM_USE_LIGHTOP_MOE_SUM
:
from
lightop
import
op
as
op
op
.
moe_sum
(
input
=
intermediate_cache3
.
view
(
*
intermediate_cache3
.
size
()),
output
=
out_hidden_states
[
begin_chunk_idx
:
end_chunk_idx
],
bias
=
None
,
expert_mask
=
None
,
num_local_tokens
=
None
,
factor
=
1.0
)
elif
envs
.
VLLM_USE_OPT_MOE_SUM
:
moe_reduce_dispatch
(
intermediate_cache3
.
view
(
*
intermediate_cache3
.
size
()),
out_hidden_states
[
begin_chunk_idx
:
end_chunk_idx
],
begin_chunk_idx
,
end_chunk_idx
)
else
:
ops
.
moe_sum
(
intermediate_cache3
.
view
(
*
intermediate_cache3
.
size
()),
ops
.
moe_sum
(
intermediate_cache3
.
view
(
*
intermediate_cache3
.
size
()),
out_hidden_states
[
begin_chunk_idx
:
end_chunk_idx
])
out_hidden_states
[
begin_chunk_idx
:
end_chunk_idx
])
...
...
vllm/model_executor/layers/fused_moe/moe_align_block_size.py
View file @
b35a518a
...
@@ -102,6 +102,14 @@ def moe_align_block_size(
...
@@ -102,6 +102,14 @@ def moe_align_block_size(
expert_map
=
expert_map
,
expert_map
=
expert_map
,
expert_mask
=
expert_mask
,
expert_mask
=
expert_mask
,
num_local_tokens
=
None
)
num_local_tokens
=
None
)
else
:
if
envs
.
VLLM_USE_LIGHTOP_MOE_ALIGN
:
from
lightop
import
op
as
op
op
.
moe_align_block_size
(
topk_ids
,
num_experts
,
block_size
,
sorted_ids
,
expert_ids
,
num_tokens_post_pad
,
expert_map
=
None
,
expert_mask
=
None
,
num_local_tokens
=
None
)
else
:
else
:
ops
.
moe_align_block_size
(
topk_ids
,
num_experts
,
block_size
,
sorted_ids
,
ops
.
moe_align_block_size
(
topk_ids
,
num_experts
,
block_size
,
sorted_ids
,
expert_ids
,
num_tokens_post_pad
)
expert_ids
,
num_tokens_post_pad
)
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment