Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
3cb11400
Commit
3cb11400
authored
Sep 10, 2025
by
王敏
Browse files
临时添加mori代码
parent
22a4e07b
Changes
3
Hide whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
227 additions
and
87 deletions
+227
-87
vllm/model_executor/layers/fused_moe/ep_moe/layer.py
vllm/model_executor/layers/fused_moe/ep_moe/layer.py
+189
-47
vllm/model_executor/layers/fused_moe/ep_moe/token_dispatcher.py
...odel_executor/layers/fused_moe/ep_moe/token_dispatcher.py
+35
-14
vllm/model_executor/layers/quantization/slimquant_w4a8.py
vllm/model_executor/layers/quantization/slimquant_w4a8.py
+3
-26
No files found.
vllm/model_executor/layers/fused_moe/ep_moe/layer.py
View file @
3cb11400
...
@@ -21,7 +21,7 @@ from vllm.model_executor.layers.fused_moe.ep_moe.token_dispatcher import MoEAllt
...
@@ -21,7 +21,7 @@ from vllm.model_executor.layers.fused_moe.ep_moe.token_dispatcher import MoEAllt
from
vllm.model_executor.layers.fused_moe.ep_moe.ep_moe_utlis
import
EpMoeConfig
from
vllm.model_executor.layers.fused_moe.ep_moe.ep_moe_utlis
import
EpMoeConfig
from
vllm.utils
import
direct_register_custom_op
from
vllm.utils
import
direct_register_custom_op
from
lightop
import
groupgemm
from
lightop
import
groupgemm
#
import mori
import
mori
import
torch.distributed
as
dist
import
torch.distributed
as
dist
...
@@ -46,18 +46,42 @@ class UnquantizedEPGroupedGemmMethod(UnquantizedFusedMoEMethod):
...
@@ -46,18 +46,42 @@ class UnquantizedEPGroupedGemmMethod(UnquantizedFusedMoEMethod):
layer
:
torch
.
nn
.
Module
,
layer
:
torch
.
nn
.
Module
,
hidden_states
:
torch
.
Tensor
,
hidden_states
:
torch
.
Tensor
,
tokens_per_expert
:
torch
.
Tensor
,
tokens_per_expert
:
torch
.
Tensor
,
topk_weights
:
torch
.
Tensor
,
topk_ids
:
torch
.
Tensor
,
inplace
=
True
,
activation
:
str
=
"silu"
,
apply_router_weight_on_input
:
bool
=
False
,
global_num_experts
:
int
=
-
1
,
expert_map
:
Optional
[
torch
.
Tensor
]
=
None
,
use_nn_moe
:
Optional
[
bool
]
=
False
,
)
->
torch
.
Tensor
:
)
->
torch
.
Tensor
:
return
self
.
forward
(
return
self
.
forward
(
hidden_states
=
hidden_states
,
hidden_states
=
hidden_states
,
layer
=
layer
,
layer
=
layer
,
tokens_per_expert
=
tokens_per_expert
)
tokens_per_expert
=
tokens_per_expert
,
topk_weights
=
topk_weights
,
topk_ids
=
topk_ids
,
inplace
=
inplace
,
activation
=
activation
,
apply_router_weight_on_input
=
apply_router_weight_on_input
,
global_num_experts
=
global_num_experts
,
expert_map
=
expert_map
,
use_nn_moe
=
use_nn_moe
)
def
forward_cuda
(
def
forward_cuda
(
self
,
self
,
layer
:
torch
.
nn
.
Module
,
layer
:
torch
.
nn
.
Module
,
hidden_states
:
torch
.
Tensor
,
hidden_states
:
torch
.
Tensor
,
tokens_per_expert
:
torch
.
Tensor
,
tokens_per_expert
:
torch
.
Tensor
,
topk_weights
:
torch
.
Tensor
,
topk_ids
:
torch
.
Tensor
,
inplace
=
True
,
activation
:
str
=
"silu"
,
apply_router_weight_on_input
:
bool
=
False
,
global_num_experts
:
int
=
-
1
,
expert_map
:
Optional
[
torch
.
Tensor
]
=
None
,
use_nn_moe
:
Optional
[
bool
]
=
False
,
)
->
torch
.
Tensor
:
)
->
torch
.
Tensor
:
# process MoE
# process MoE
...
@@ -97,23 +121,39 @@ class UnquantizedEPGroupedGemmMethod(UnquantizedFusedMoEMethod):
...
@@ -97,23 +121,39 @@ class UnquantizedEPGroupedGemmMethod(UnquantizedFusedMoEMethod):
assert
hidden_states
.
numel
()
==
0
,
f
"sorted_tokens: should be empty, but got
{
hidden_states
.
shape
}
"
assert
hidden_states
.
numel
()
==
0
,
f
"sorted_tokens: should be empty, but got
{
hidden_states
.
shape
}
"
expert_output
=
hidden_states
expert_output
=
hidden_states
else
:
else
:
if
self
.
zero_token_count
is
None
:
if
topk_ids
is
None
:
self
.
zero_token_count
=
torch
.
zeros
(
1
,
dtype
=
torch
.
int64
,
device
=
hidden_states
.
device
)
if
self
.
zero_token_count
is
None
:
total_tokens
=
tokens_per_expert
.
sum
()
self
.
zero_token_count
=
torch
.
zeros
(
1
,
dtype
=
torch
.
int64
,
device
=
hidden_states
.
device
)
if
total_tokens
>
self
.
zero_token_count
:
total_tokens
=
tokens_per_expert
.
sum
()
gateup_output
=
groupgemm
(
hidden_states
,
layer
.
w13_weight
,
tokens_per_expert
,
False
)
print
(
"#################total_tokens:"
,
total_tokens
.
tolist
())
# Act
if
total_tokens
>
self
.
zero_token_count
:
down_input
=
torch
.
zeros
(
gateup_output
=
groupgemm
(
hidden_states
,
layer
.
w13_weight
,
tokens_per_expert
,
False
)
gateup_output
.
shape
[
0
],
# Act
gateup_output
.
shape
[
1
]
//
2
,
down_input
=
torch
.
zeros
(
device
=
gateup_output
.
device
,
gateup_output
.
shape
[
0
],
dtype
=
hidden_states
.
dtype
gateup_output
.
shape
[
1
]
//
2
,
device
=
gateup_output
.
device
,
dtype
=
hidden_states
.
dtype
)
torch
.
ops
.
_C
.
silu_and_mul
(
down_input
,
gateup_output
.
view
(
-
1
,
layer
.
w13_weight
.
shape
[
2
]))
expert_output
=
groupgemm
(
down_input
,
layer
.
w2_weight
,
tokens_per_expert
,
False
)
else
:
expert_output
=
hidden_states
else
:
expert_output
=
self
.
fused_experts
(
hidden_states
=
hidden_states
,
w1
=
layer
.
w13_weight
,
w2
=
layer
.
w2_weight
,
topk_weights
=
topk_weights
,
topk_ids
=
topk_ids
,
inplace
=
True
,
activation
=
activation
,
apply_router_weight_on_input
=
apply_router_weight_on_input
,
global_num_experts
=
global_num_experts
,
expert_map
=
expert_map
,
use_nn_moe
=
use_nn_moe
)
)
torch
.
ops
.
_C
.
silu_and_mul
(
down_input
,
gateup_output
.
view
(
-
1
,
layer
.
w13_weight
.
shape
[
2
]))
expert_output
=
groupgemm
(
down_input
,
layer
.
w2_weight
,
tokens_per_expert
,
False
)
else
:
expert_output
=
hidden_states
return
expert_output
return
expert_output
...
@@ -127,6 +167,14 @@ class UnquantizedEPGroupedGemmMethod(UnquantizedFusedMoEMethod):
...
@@ -127,6 +167,14 @@ class UnquantizedEPGroupedGemmMethod(UnquantizedFusedMoEMethod):
layer
:
torch
.
nn
.
Module
,
layer
:
torch
.
nn
.
Module
,
hidden_states
:
torch
.
Tensor
,
hidden_states
:
torch
.
Tensor
,
tokens_per_expert
:
torch
.
Tensor
,
tokens_per_expert
:
torch
.
Tensor
,
topk_weights
:
torch
.
Tensor
,
topk_ids
:
torch
.
Tensor
,
inplace
=
True
,
activation
:
str
=
"silu"
,
apply_router_weight_on_input
:
bool
=
False
,
global_num_experts
:
int
=
-
1
,
expert_map
:
Optional
[
torch
.
Tensor
]
=
None
,
use_nn_moe
:
Optional
[
bool
]
=
False
,
**
kwargs
,
**
kwargs
,
):
):
raise
NotImplementedError
raise
NotImplementedError
...
@@ -136,6 +184,14 @@ class UnquantizedEPGroupedGemmMethod(UnquantizedFusedMoEMethod):
...
@@ -136,6 +184,14 @@ class UnquantizedEPGroupedGemmMethod(UnquantizedFusedMoEMethod):
layer
:
torch
.
nn
.
Module
,
layer
:
torch
.
nn
.
Module
,
hidden_states
:
torch
.
Tensor
,
hidden_states
:
torch
.
Tensor
,
tokens_per_expert
:
torch
.
Tensor
,
tokens_per_expert
:
torch
.
Tensor
,
topk_weights
:
torch
.
Tensor
,
topk_ids
:
torch
.
Tensor
,
inplace
=
True
,
activation
:
str
=
"silu"
,
apply_router_weight_on_input
:
bool
=
False
,
global_num_experts
:
int
=
-
1
,
expert_map
:
Optional
[
torch
.
Tensor
]
=
None
,
use_nn_moe
:
Optional
[
bool
]
=
False
,
)
->
torch
.
Tensor
:
)
->
torch
.
Tensor
:
raise
NotImplementedError
raise
NotImplementedError
...
@@ -144,6 +200,14 @@ class UnquantizedEPGroupedGemmMethod(UnquantizedFusedMoEMethod):
...
@@ -144,6 +200,14 @@ class UnquantizedEPGroupedGemmMethod(UnquantizedFusedMoEMethod):
layer
:
torch
.
nn
.
Module
,
layer
:
torch
.
nn
.
Module
,
hidden_states
:
torch
.
Tensor
,
hidden_states
:
torch
.
Tensor
,
tokens_per_expert
:
torch
.
Tensor
,
tokens_per_expert
:
torch
.
Tensor
,
topk_weights
:
torch
.
Tensor
,
topk_ids
:
torch
.
Tensor
,
inplace
=
True
,
activation
:
str
=
"silu"
,
apply_router_weight_on_input
:
bool
=
False
,
global_num_experts
:
int
=
-
1
,
expert_map
:
Optional
[
torch
.
Tensor
]
=
None
,
use_nn_moe
:
Optional
[
bool
]
=
False
,
)
->
torch
.
Tensor
:
)
->
torch
.
Tensor
:
raise
NotImplementedError
raise
NotImplementedError
...
@@ -224,7 +288,7 @@ class EPMoE(FusedMoE):
...
@@ -224,7 +288,7 @@ class EPMoE(FusedMoE):
self
.
use_shared_expert
=
False
self
.
use_shared_expert
=
False
self
.
token_dispatcher
=
MoEAlltoAllTokenDispatcher
(
self
.
token_dispatcher
=
MoEAlltoAllTokenDispatcher
(
self
.
local_num_experts
,
self
.
local_expert_indices
,
self
.
local_num_experts
,
self
.
local_expert_indices
,
config
=
self
.
ep_moe_config
,
#
layer_name=f"{self.layer_name}.token_dispatcher",
config
=
self
.
ep_moe_config
,
layer_name
=
f
"
{
self
.
layer_name
}
.token_dispatcher"
,
)
)
self
.
shared_expert_overlap
=
moe_shared_expert_overlap
self
.
shared_expert_overlap
=
moe_shared_expert_overlap
...
@@ -232,7 +296,7 @@ class EPMoE(FusedMoE):
...
@@ -232,7 +296,7 @@ class EPMoE(FusedMoE):
self
.
dpsk_fp16_quick
=
os
.
environ
.
get
(
'DPSK_FP16_QUICK'
)
==
'1'
self
.
dpsk_fp16_quick
=
os
.
environ
.
get
(
'DPSK_FP16_QUICK'
)
==
'1'
if
Fals
e
:
if
Tru
e
:
self
.
mori_op
=
self
.
get_mori_op
()
self
.
mori_op
=
self
.
get_mori_op
()
def
get_mori_op
(
self
):
def
get_mori_op
(
self
):
...
@@ -240,8 +304,13 @@ class EPMoE(FusedMoE):
...
@@ -240,8 +304,13 @@ class EPMoE(FusedMoE):
if
_MORI_OP
is
None
:
if
_MORI_OP
is
None
:
# world_group = torch.distributed.group.WORLD
# world_group = torch.distributed.group.WORLD
# assert world_group is not None
# assert world_group is not None
torch
.
_C
.
_distributed_c10d
.
_register_process_group
(
"mori_ep"
,
get_ep_group
().
device_group
)
#torch._C._distributed_c10d._register_process_group("mori_ep", get_ep_group().device_group)
mori
.
shmem
.
shmem_torch_process_group_init
(
"mori_ep"
)
#mori.shmem.shmem_torch_process_group_init("mori_ep")
world_group
=
torch
.
distributed
.
group
.
WORLD
assert
world_group
is
not
None
torch
.
_C
.
_distributed_c10d
.
_register_process_group
(
"default"
,
world_group
)
mori
.
shmem
.
shmem_torch_process_group_init
(
"default"
)
vllm_config
=
get_current_vllm_config
()
vllm_config
=
get_current_vllm_config
()
config
=
mori
.
ops
.
EpDispatchCombineConfig
(
config
=
mori
.
ops
.
EpDispatchCombineConfig
(
data_type
=
vllm_config
.
model_config
.
dtype
,
data_type
=
vllm_config
.
model_config
.
dtype
,
...
@@ -250,12 +319,12 @@ class EPMoE(FusedMoE):
...
@@ -250,12 +319,12 @@ class EPMoE(FusedMoE):
hidden_dim
=
self
.
hidden_size
,
hidden_dim
=
self
.
hidden_size
,
scale_dim
=
0
,
scale_dim
=
0
,
scale_type_size
=
vllm_config
.
model_config
.
dtype
.
itemsize
,
scale_type_size
=
vllm_config
.
model_config
.
dtype
.
itemsize
,
max_num_inp_token_per_rank
=
10000
,
max_num_inp_token_per_rank
=
4096
,
num_experts_per_rank
=
self
.
local_num_experts
,
num_experts_per_rank
=
self
.
local_num_experts
,
num_experts_per_token
=
self
.
top_k
,
num_experts_per_token
=
self
.
top_k
,
max_token_type_size
=
4
,
max_token_type_size
=
2
,
#
block_num=4
0
,
block_num
=
6
4
,
#
warp_num_per_block=
8
,
warp_num_per_block
=
16
,
kernel_type
=
mori
.
ops
.
EpDispatchCombineKernelType
.
InterNode
kernel_type
=
mori
.
ops
.
EpDispatchCombineKernelType
.
InterNode
)
)
_MORI_OP
=
mori
.
ops
.
EpDispatchCombineOp
(
config
)
_MORI_OP
=
mori
.
ops
.
EpDispatchCombineOp
(
config
)
...
@@ -307,14 +376,50 @@ class EPMoE(FusedMoE):
...
@@ -307,14 +376,50 @@ class EPMoE(FusedMoE):
if
not
self
.
ep_moe_config
.
moe_shared_expert_overlap
and
self
.
shared_experts
is
not
None
:
if
not
self
.
ep_moe_config
.
moe_shared_expert_overlap
and
self
.
shared_experts
is
not
None
:
shared_output
=
self
.
shared_experts
(
hidden_states
)
shared_output
=
self
.
shared_experts
(
hidden_states
)
if
True
:
########################test#########################
# probs = None
# if self.apply_router_weight_on_input:
# probs = torch.zeros_like(router_logits, dtype=topk_weights.dtype).scatter(1, topk_ids, topk_weights)
# routing_map = torch.zeros_like(router_logits).int().scatter(1, topk_ids, 1).bool()
# (dispatch_output, tokens_per_expert) = self.token_dispatcher.token_permutation(
# hidden_states, probs, routing_map
# )
# torch.cuda.synchronize()
# print("###########################all2all dispatch_output shape:", dispatch_output.shape)
# print("###########################all2all dispatch_output:", dispatch_output[:10, :10])
# expert_output = self.quant_method.apply_ep(
# layer=self,
# hidden_states=dispatch_output,
# tokens_per_expert=tokens_per_expert,
# topk_weights=None,
# topk_ids=None,
# global_num_experts=self.global_num_experts,
# expert_map=self.expert_map,
# activation=self.activation,
# apply_router_weight_on_input=self.apply_router_weight_on_input,
# use_nn_moe=self.use_nn_moe,
# )
# torch.cuda.synchronize()
# print("###########################grouped gemm out:", expert_output[:10, :10])
# final_hidden_states = self.token_dispatcher.token_unpermutation(expert_output)
# final_hidden_states_all2all = final_hidden_states
# torch.cuda.synchronize()
# print("####################all2all unpermute output:", final_hidden_states[:10, :10].tolist())
########################test##########################
if
False
:
probs
=
None
probs
=
None
if
self
.
apply_router_weight_on_input
:
if
self
.
apply_router_weight_on_input
:
probs
=
torch
.
zeros_like
(
router_logits
,
dtype
=
topk_weights
.
dtype
).
scatter
(
1
,
topk_ids
,
topk_weights
)
probs
=
torch
.
zeros_like
(
router_logits
,
dtype
=
topk_weights
.
dtype
).
scatter
(
1
,
topk_ids
,
topk_weights
)
routing_map
=
torch
.
zeros_like
(
router_logits
).
int
().
scatter
(
1
,
topk_ids
,
1
).
bool
()
routing_map
=
torch
.
zeros_like
(
router_logits
).
int
().
scatter
(
1
,
topk_ids
,
1
).
bool
()
(
dispatch
ed_in
put
,
tokens_per_expert
)
=
self
.
token_dispatcher
.
token_permutation
(
(
dispatch
_out
put
,
tokens_per_expert
)
=
self
.
token_dispatcher
.
token_permutation
(
hidden_states
,
probs
,
routing_map
hidden_states
,
probs
,
routing_map
)
)
else
:
else
:
...
@@ -325,39 +430,76 @@ class EPMoE(FusedMoE):
...
@@ -325,39 +430,76 @@ class EPMoE(FusedMoE):
dtype
=
torch
.
float32
,
dtype
=
torch
.
float32
,
device
=
hidden_states
.
device
,
device
=
hidden_states
.
device
,
)
)
self
.
sync
()
print
(
"##########################topk_weights shape:{} topk_ids shape:{}"
.
format
(
topk_weights
.
shape
,
topk_ids
.
shape
))
(
(
dispatch
ed_in
put
,
dispatch
_out
put
,
dispatch_weights
,
dispatch_weights
,
dispatch_scales
,
dispatch_scales
,
dispatch_indices
,
dispatch_indices
,
dispatch_recv_num_token
,
dispatch_recv_num_token
,
)
=
self
.
mori_op
.
dispatch
(
)
=
self
.
mori_op
.
dispatch
(
hidden_states
,
hidden_states
.
contiguous
()
,
topk_weights
,
topk_weights
.
contiguous
()
,
scales
,
scales
.
contiguous
()
,
topk_ids
,
topk_ids
.
contiguous
()
,
)
)
tokens_per_expert
=
dispatch_recv_num_token
self
.
sync
()
self
.
sync
()
print
(
"######################dispatched_input shape:{} dispatch_weights.shape:{} dispatch_indices shape:{}"
.
format
(
dispatched_input
.
shape
,
# with torch.inference_mode():
# src_token_pos = self.mori_op.get_dispatch_src_token_pos().tolist()
# print("##################src_token_pos:", src_token_pos[:10].tolist())
tokens_per_expert
=
dispatch_recv_num_token
print
(
"######################dispatched_input shape:{} dispatch_weights.shape:{} dispatch_indices shape:{}"
.
format
(
dispatch_output
.
shape
,
dispatch_weights
.
shape
,
dispatch_indices
.
shape
))
dispatch_weights
.
shape
,
dispatch_indices
.
shape
))
print
(
"####################dispatch_recv_num_token:"
,
dispatch_recv_num_token
.
tolist
())
print
(
"####################dispatch_recv_num_token:"
,
dispatch_recv_num_token
)
#print("####################dispatch_weights:", dispatch_weights.tolist())
#print("####################dispatch_indices:", dispatch_indices.tolist())
dispatch_recv_num_token
=
dispatch_recv_num_token
.
cpu
()[
0
]
print
(
"########################dispatch_output:"
,
dispatch_output
[:
10
,
:
10
].
tolist
())
# Matrix multiply.
print
(
"########################dispatch_indices:"
,
dispatch_indices
[:
10
,
:].
tolist
())
expert_output
=
self
.
quant_method
.
apply_ep
(
print
(
"#########################start fused_moe"
)
layer
=
self
,
has_greater_than_255
=
torch
.
any
(
dispatch_indices
>
255
).
item
()
hidden_states
=
dispatched_input
,
has_less_than_0
=
torch
.
any
(
dispatch_indices
<
0
).
item
()
tokens_per_expert
=
tokens_per_expert
print
(
"##################################has_greater_than_255:{} has_less_than_0:{}"
.
format
(
has_greater_than_255
,
has_less_than_0
))
)
if
dispatch_recv_num_token
>
0
:
# Matrix multiply.
#expert_output = self.quant_method.apply_ep(
expert_output
=
self
.
quant_method
.
apply
(
layer
=
self
,
x
=
dispatch_output
[:
dispatch_recv_num_token
].
contiguous
(),
tokens_per_expert
=
tokens_per_expert
,
topk_weights
=
dispatch_weights
[:
dispatch_recv_num_token
].
contiguous
(),
topk_ids
=
dispatch_indices
[:
dispatch_recv_num_token
].
contiguous
(),
global_num_experts
=
self
.
global_num_experts
,
expert_map
=
self
.
expert_map
,
activation
=
self
.
activation
,
apply_router_weight_on_input
=
self
.
apply_router_weight_on_input
,
use_nn_moe
=
self
.
use_nn_moe
,
)
else
:
expert_output
=
dispatch_output
[:
dispatch_recv_num_token
]
self
.
sync
()
print
(
"####################fused_moe expert_output:"
,
expert_output
[:
10
,
:
10
].
tolist
())
if
True
:
if
False
:
final_hidden_states
=
self
.
token_dispatcher
.
token_unpermutation
(
expert_output
)
final_hidden_states
=
self
.
token_dispatcher
.
token_unpermutation
(
expert_output
)
else
:
else
:
combine_output
,
_
=
self
.
mori_op
.
combine
(
expert_output
,
dispatch_weights
,
topk_ids
)
combine_output
,
_
=
self
.
mori_op
.
combine
(
expert_output
.
contiguous
()
,
dispatch_weights
.
contiguous
(),
topk_ids
.
contiguous
()
)
final_hidden_states
=
combine_output
[:
hidden_states
.
shape
[
0
],
:]
final_hidden_states
=
combine_output
[:
hidden_states
.
shape
[
0
],
:]
torch
.
cuda
.
synchronize
()
print
(
"####################mori combine_output:"
,
combine_output
[:
10
,
:
10
].
tolist
())
self
.
sync
()
####################test#################
# final_hidden_states_close = torch.allclose(final_hidden_states, final_hidden_states_all2all, rtol=1e-2, atol=1e-2)
# print(f"final_hidden_states_close: {final_hidden_states_close}")
#####################test################
if
not
self
.
ep_moe_config
.
moe_shared_expert_overlap
and
self
.
shared_experts
is
not
None
:
if
not
self
.
ep_moe_config
.
moe_shared_expert_overlap
and
self
.
shared_experts
is
not
None
:
# if shared_expert_overlap is True, the expert calculation happens in
# if shared_expert_overlap is True, the expert calculation happens in
# the token_dispatcher to overlap communications and computations
# the token_dispatcher to overlap communications and computations
...
...
vllm/model_executor/layers/fused_moe/ep_moe/token_dispatcher.py
View file @
3cb11400
...
@@ -26,6 +26,7 @@ from vllm.platforms import current_platform
...
@@ -26,6 +26,7 @@ from vllm.platforms import current_platform
from
vllm.forward_context
import
ForwardContext
,
get_forward_context
from
vllm.forward_context
import
ForwardContext
,
get_forward_context
from
vllm.utils
import
direct_register_custom_op
from
vllm.utils
import
direct_register_custom_op
from
vllm.config
import
get_current_vllm_config
from
vllm.config
import
get_current_vllm_config
from
lightop
import
groupgemm_permute
,
groupgemm_unpermute
cuda_dtoh_stream
=
torch
.
cuda
.
Stream
()
cuda_dtoh_stream
=
torch
.
cuda
.
Stream
()
...
@@ -329,12 +330,24 @@ class MoEAlltoAllTokenDispatcher(MoETokenDispatcher):
...
@@ -329,12 +330,24 @@ class MoEAlltoAllTokenDispatcher(MoETokenDispatcher):
hidden_states
=
hidden_states
.
view
(
-
1
,
self
.
hidden_shape
[
-
1
])
hidden_states
=
hidden_states
.
view
(
-
1
,
self
.
hidden_shape
[
-
1
])
self
.
hidden_shape_before_permute
=
hidden_states
.
shape
self
.
hidden_shape_before_permute
=
hidden_states
.
shape
permutated_local_input_tokens
,
self
.
reversed_local_input_permutation_mapping
=
permute
(
hidden_states
,
if
True
:
routing_map
,
permutated_local_input_tokens
,
self
.
reversed_local_input_permutation_mapping
=
permute
(
num_out_tokens
=
self
.
num_out_tokens
,
hidden_states
,
fused
=
self
.
config
.
moe_permute_fusion
routing_map
,
)
num_out_tokens
=
self
.
num_out_tokens
,
fused
=
self
.
config
.
moe_permute_fusion
)
else
:
torch
.
cuda
.
synchronize
()
print
(
"########################hidden_states shape:{}
\n
#####################routing_map shape:{}
\n
"
.
format
(
hidden_states
.
shape
,
routing_map
.
shape
))
print
(
"########################hidden_states:{}
\n
#####################routing_map:{}
\n
"
.
format
(
hidden_states
[
0
,
:
10
].
tolist
(),
routing_map
[
0
,
:
10
].
tolist
()))
cuda_permute_result
=
groupgemm_permute
(
hidden_states
,
routing_map
)
permutated_local_input_tokens
,
self
.
reversed_local_input_permutation_mapping
,
\
expert_m
,
self
.
expert_m_count
,
expert_m_max
=
cuda_permute_result
# Perform expert parallel AlltoAll communication
# Perform expert parallel AlltoAll communication
# tokens_per_expert = self._maybe_dtoh_and_synchronize(
# tokens_per_expert = self._maybe_dtoh_and_synchronize(
...
@@ -414,14 +427,22 @@ class MoEAlltoAllTokenDispatcher(MoETokenDispatcher):
...
@@ -414,14 +427,22 @@ class MoEAlltoAllTokenDispatcher(MoETokenDispatcher):
self
.
shared_experts
.
post_forward_comm
()
self
.
shared_experts
.
post_forward_comm
()
# Unpermutation 1: AlltoAll output to output
# Unpermutation 1: AlltoAll output to output
output
=
unpermute
(
if
True
:
permutated_local_input_tokens
,
output
=
unpermute
(
self
.
reversed_local_input_permutation_mapping
,
permutated_local_input_tokens
,
restore_shape
=
self
.
hidden_shape_before_permute
,
self
.
reversed_local_input_permutation_mapping
,
probs
=
self
.
probs
,
restore_shape
=
self
.
hidden_shape_before_permute
,
routing_map
=
self
.
routing_map
,
probs
=
self
.
probs
,
fused
=
self
.
config
.
moe_permute_fusion
,
routing_map
=
self
.
routing_map
,
)
fused
=
self
.
config
.
moe_permute_fusion
,
)
else
:
output
=
groupgemm_unpermute
(
permutated_local_input_tokens
,
self
.
reversed_local_input_permutation_mapping
,
list
(
self
.
hidden_shape_before_permute
),
self
.
probs
,
self
.
routing_map
,
self
.
expert_m_count
)
# Reshape the output tensor
# Reshape the output tensor
output
=
output
.
view
(
self
.
hidden_shape
)
output
=
output
.
view
(
self
.
hidden_shape
)
...
...
vllm/model_executor/layers/quantization/slimquant_w4a8.py
View file @
3cb11400
...
@@ -349,23 +349,14 @@ class SlimQuantW4A8Int8MoEMethod:
...
@@ -349,23 +349,14 @@ class SlimQuantW4A8Int8MoEMethod:
self
,
self
,
layer
:
torch
.
nn
.
Module
,
layer
:
torch
.
nn
.
Module
,
x
:
torch
.
Tensor
,
x
:
torch
.
Tensor
,
router_logits
:
torch
.
Tensor
,
tokens_per_expert
:
torch
.
Tensor
,
top_k
:
int
,
topk_weights
:
torch
.
Tensor
,
renormalize
:
bool
,
topk_ids
:
torch
.
Tensor
,
use_grouped_topk
:
bool
=
False
,
topk_group
:
Optional
[
int
]
=
None
,
num_expert_group
:
Optional
[
int
]
=
None
,
global_num_experts
:
int
=
-
1
,
expert_map
:
Optional
[
torch
.
Tensor
]
=
None
,
expert_map
:
Optional
[
torch
.
Tensor
]
=
None
,
custom_routing_function
:
Optional
[
Callable
]
=
None
,
scoring_func
:
str
=
"softmax"
,
e_score_correction_bias
:
Optional
[
torch
.
Tensor
]
=
None
,
apply_router_weight_on_input
:
bool
=
False
,
apply_router_weight_on_input
:
bool
=
False
,
activation
:
str
=
"silu"
,
activation
:
str
=
"silu"
,
enable_eplb
:
bool
=
False
,
enable_eplb
:
bool
=
False
,
use_nn_moe
:
Optional
[
bool
]
=
False
,
use_nn_moe
:
Optional
[
bool
]
=
False
,
routed_scaling_factor
:
Optional
[
float
]
=
None
,
use_fused_gate
:
Optional
[
bool
]
=
False
,
**
_
**
_
)
->
torch
.
Tensor
:
)
->
torch
.
Tensor
:
from
vllm.model_executor.layers.fused_moe
import
fused_experts
from
vllm.model_executor.layers.fused_moe
import
fused_experts
...
@@ -373,20 +364,6 @@ class SlimQuantW4A8Int8MoEMethod:
...
@@ -373,20 +364,6 @@ class SlimQuantW4A8Int8MoEMethod:
raise
NotImplementedError
(
raise
NotImplementedError
(
"EPLB not supported for `SlimQuantW4A8Int8MoEMethod` yet."
)
"EPLB not supported for `SlimQuantW4A8Int8MoEMethod` yet."
)
# Expert selection
# Expert selection
topk_weights
,
topk_ids
=
FusedMoE
.
select_experts
(
hidden_states
=
x
,
router_logits
=
router_logits
,
use_grouped_topk
=
use_grouped_topk
,
top_k
=
top_k
,
renormalize
=
renormalize
,
topk_group
=
topk_group
,
num_expert_group
=
num_expert_group
,
custom_routing_function
=
custom_routing_function
,
scoring_func
=
scoring_func
,
e_score_correction_bias
=
e_score_correction_bias
,
routed_scaling_factor
=
routed_scaling_factor
,
use_fused_gate
=
use_fused_gate
)
return
fused_experts
(
return
fused_experts
(
x
,
x
,
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment