Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
1a56d6cb
Commit
1a56d6cb
authored
Sep 16, 2025
by
王敏
Browse files
添加mori ep
parent
3cb11400
Changes
7
Hide whitespace changes
Inline
Side-by-side
Showing
7 changed files
with
182 additions
and
245 deletions
+182
-245
vllm/config.py
vllm/config.py
+4
-4
vllm/model_executor/layers/fused_moe/ep_moe/layer.py
vllm/model_executor/layers/fused_moe/ep_moe/layer.py
+105
-202
vllm/model_executor/layers/fused_moe/ep_moe/token_dispatcher.py
...odel_executor/layers/fused_moe/ep_moe/token_dispatcher.py
+3
-8
vllm/model_executor/layers/fused_moe/layer.py
vllm/model_executor/layers/fused_moe/layer.py
+13
-13
vllm/model_executor/layers/quantization/slimquant_w4a8.py
vllm/model_executor/layers/quantization/slimquant_w4a8.py
+55
-11
vllm/platforms/rocm.py
vllm/platforms/rocm.py
+1
-6
vllm/v1/spec_decode/eagle.py
vllm/v1/spec_decode/eagle.py
+1
-1
No files found.
vllm/config.py
View file @
1a56d6cb
...
...
@@ -2004,10 +2004,10 @@ class ParallelConfig:
logger
.
info
(
"Disabling V1 multiprocessing for external launcher."
)
if
self
.
enable_eplb
:
if
not
current_platform
.
is_cuda
():
raise
ValueError
(
"Expert parallelism load balancing is only supported on "
"CUDA devices now."
)
#
if not current_platform.is_cuda():
#
raise ValueError(
#
"Expert parallelism load balancing is only supported on "
#
"CUDA devices now.")
if
self
.
num_redundant_experts
<
0
:
raise
ValueError
(
"num_redundant_experts must be non-negative, but got "
...
...
vllm/model_executor/layers/fused_moe/ep_moe/layer.py
View file @
1a56d6cb
...
...
@@ -2,6 +2,7 @@ import os
import
logging
from
typing
import
Callable
,
List
,
Optional
,
Tuple
from
dataclasses
import
dataclass
from
collections.abc
import
Iterable
import
torch
import
torch.nn.functional
as
F
...
...
@@ -11,7 +12,7 @@ from vllm.platforms import current_platform
from
vllm.config
import
get_current_vllm_config
from
vllm.model_executor.custom_op
import
CustomOp
from
vllm.forward_context
import
ForwardContext
,
get_forward_context
from
vllm.distributed.parallel_state
import
get_ep_group
,
get_node_count
from
vllm.distributed.parallel_state
import
get_ep_group
,
get_node_count
,
is_use_cuda_graph
from
vllm.model_executor.layers.fused_moe.config
import
FusedMoEConfig
from
vllm.model_executor.layers.quantization.base_config
import
(
QuantizationConfig
,
QuantizeMethodBase
)
...
...
@@ -20,7 +21,6 @@ from vllm.model_executor.layers.fused_moe.layer import FusedMoEMethodBase, Unqua
from
vllm.model_executor.layers.fused_moe.ep_moe.token_dispatcher
import
MoEAlltoAllTokenDispatcher
from
vllm.model_executor.layers.fused_moe.ep_moe.ep_moe_utlis
import
EpMoeConfig
from
vllm.utils
import
direct_register_custom_op
from
lightop
import
groupgemm
import
mori
import
torch.distributed
as
dist
...
...
@@ -45,7 +45,6 @@ class UnquantizedEPGroupedGemmMethod(UnquantizedFusedMoEMethod):
self
,
layer
:
torch
.
nn
.
Module
,
hidden_states
:
torch
.
Tensor
,
tokens_per_expert
:
torch
.
Tensor
,
topk_weights
:
torch
.
Tensor
,
topk_ids
:
torch
.
Tensor
,
inplace
=
True
,
...
...
@@ -59,7 +58,6 @@ class UnquantizedEPGroupedGemmMethod(UnquantizedFusedMoEMethod):
return
self
.
forward
(
hidden_states
=
hidden_states
,
layer
=
layer
,
tokens_per_expert
=
tokens_per_expert
,
topk_weights
=
topk_weights
,
topk_ids
=
topk_ids
,
inplace
=
inplace
,
...
...
@@ -73,7 +71,6 @@ class UnquantizedEPGroupedGemmMethod(UnquantizedFusedMoEMethod):
self
,
layer
:
torch
.
nn
.
Module
,
hidden_states
:
torch
.
Tensor
,
tokens_per_expert
:
torch
.
Tensor
,
topk_weights
:
torch
.
Tensor
,
topk_ids
:
torch
.
Tensor
,
inplace
=
True
,
...
...
@@ -85,80 +82,24 @@ class UnquantizedEPGroupedGemmMethod(UnquantizedFusedMoEMethod):
)
->
torch
.
Tensor
:
# process MoE
def
custom_forward
(
layer
,
hidden_states
,
tokens_per_expert
):
if
False
:
tokens_per_expert
=
tokens_per_expert
.
cpu
().
numpy
()
outputs
=
[]
start_idx
=
0
for
i
,
num_tokens
in
enumerate
(
tokens_per_expert
):
end_idx
=
start_idx
+
num_tokens
if
num_tokens
==
0
:
continue
w1
=
layer
.
w13_weight
[
i
]
w2
=
layer
.
w2_weight
[
i
]
tokens_for_this_expert
=
hidden_states
[
start_idx
:
end_idx
]
gateup_output
=
torch
.
matmul
(
tokens_for_this_expert
,
w1
)
# Act
down_input
=
torch
.
zeros
(
gateup_output
.
shape
[
0
],
gateup_output
.
shape
[
1
]
//
2
,
device
=
gateup_output
.
device
,
dtype
=
hidden_states
.
dtype
)
torch
.
ops
.
_C
.
silu_and_mul
(
down_input
,
gateup_output
.
view
(
-
1
,
w1
.
shape
[
1
]))
expert_out
=
torch
.
matmul
(
down_input
,
w2
)
outputs
.
append
(
expert_out
)
start_idx
=
end_idx
if
len
(
outputs
)
>
0
:
expert_output
=
torch
.
cat
(
outputs
,
dim
=
0
)
else
:
assert
hidden_states
.
numel
()
==
0
,
f
"sorted_tokens: should be empty, but got
{
hidden_states
.
shape
}
"
expert_output
=
hidden_states
else
:
if
topk_ids
is
None
:
if
self
.
zero_token_count
is
None
:
self
.
zero_token_count
=
torch
.
zeros
(
1
,
dtype
=
torch
.
int64
,
device
=
hidden_states
.
device
)
total_tokens
=
tokens_per_expert
.
sum
()
print
(
"#################total_tokens:"
,
total_tokens
.
tolist
())
if
total_tokens
>
self
.
zero_token_count
:
gateup_output
=
groupgemm
(
hidden_states
,
layer
.
w13_weight
,
tokens_per_expert
,
False
)
# Act
down_input
=
torch
.
zeros
(
gateup_output
.
shape
[
0
],
gateup_output
.
shape
[
1
]
//
2
,
device
=
gateup_output
.
device
,
dtype
=
hidden_states
.
dtype
)
torch
.
ops
.
_C
.
silu_and_mul
(
down_input
,
gateup_output
.
view
(
-
1
,
layer
.
w13_weight
.
shape
[
2
]))
expert_output
=
groupgemm
(
down_input
,
layer
.
w2_weight
,
tokens_per_expert
,
False
)
else
:
expert_output
=
hidden_states
else
:
expert_output
=
self
.
fused_experts
(
hidden_states
=
hidden_states
,
w1
=
layer
.
w13_weight
,
w2
=
layer
.
w2_weight
,
topk_weights
=
topk_weights
,
topk_ids
=
topk_ids
,
inplace
=
True
,
activation
=
activation
,
apply_router_weight_on_input
=
apply_router_weight_on_input
,
global_num_experts
=
global_num_experts
,
expert_map
=
expert_map
,
use_nn_moe
=
use_nn_moe
)
def
custom_forward
(
layer
,
hidden_states
):
expert_output
=
self
.
fused_experts
(
hidden_states
=
hidden_states
,
w1
=
layer
.
w13_weight
,
w2
=
layer
.
w2_weight
,
topk_weights
=
topk_weights
,
topk_ids
=
topk_ids
,
inplace
=
True
,
activation
=
activation
,
apply_router_weight_on_input
=
apply_router_weight_on_input
,
global_num_experts
=
global_num_experts
,
expert_map
=
expert_map
,
use_nn_moe
=
use_nn_moe
)
return
expert_output
output
=
custom_forward
(
layer
,
hidden_states
,
tokens_per_expert
)
output
=
custom_forward
(
layer
,
hidden_states
)
return
output
...
...
@@ -166,7 +107,6 @@ class UnquantizedEPGroupedGemmMethod(UnquantizedFusedMoEMethod):
self
,
layer
:
torch
.
nn
.
Module
,
hidden_states
:
torch
.
Tensor
,
tokens_per_expert
:
torch
.
Tensor
,
topk_weights
:
torch
.
Tensor
,
topk_ids
:
torch
.
Tensor
,
inplace
=
True
,
...
...
@@ -183,7 +123,6 @@ class UnquantizedEPGroupedGemmMethod(UnquantizedFusedMoEMethod):
self
,
layer
:
torch
.
nn
.
Module
,
hidden_states
:
torch
.
Tensor
,
tokens_per_expert
:
torch
.
Tensor
,
topk_weights
:
torch
.
Tensor
,
topk_ids
:
torch
.
Tensor
,
inplace
=
True
,
...
...
@@ -199,7 +138,6 @@ class UnquantizedEPGroupedGemmMethod(UnquantizedFusedMoEMethod):
self
,
layer
:
torch
.
nn
.
Module
,
hidden_states
:
torch
.
Tensor
,
tokens_per_expert
:
torch
.
Tensor
,
topk_weights
:
torch
.
Tensor
,
topk_ids
:
torch
.
Tensor
,
inplace
=
True
,
...
...
@@ -249,7 +187,7 @@ class EPMoE(FusedMoE):
routed_scaling_factor
:
Optional
[
float
]
=
None
,
enable_eplb
:
bool
=
False
,
num_redundant_experts
:
int
=
0
,
moe_permute_fusion
:
bool
=
Tru
e
,
moe_permute_fusion
:
bool
=
Fals
e
,
moe_shared_expert_overlap
:
bool
=
False
):
super
().
__init__
(
num_experts
,
top_k
,
hidden_size
,
...
...
@@ -296,8 +234,10 @@ class EPMoE(FusedMoE):
self
.
dpsk_fp16_quick
=
os
.
environ
.
get
(
'DPSK_FP16_QUICK'
)
==
'1'
if
True
:
self
.
mori_op
=
self
.
get_mori_op
()
self
.
mori_op
=
self
.
get_mori_op
()
self
.
zero_token_count
=
None
def
get_mori_op
(
self
):
global
_MORI_OP
...
...
@@ -319,7 +259,7 @@ class EPMoE(FusedMoE):
hidden_dim
=
self
.
hidden_size
,
scale_dim
=
0
,
scale_type_size
=
vllm_config
.
model_config
.
dtype
.
itemsize
,
max_num_inp_token_per_rank
=
4096
,
max_num_inp_token_per_rank
=
20480
,
num_experts_per_rank
=
self
.
local_num_experts
,
num_experts_per_token
=
self
.
top_k
,
max_token_type_size
=
2
,
...
...
@@ -334,6 +274,7 @@ class EPMoE(FusedMoE):
def
set_shared_experts
(
self
,
shared_experts
:
torch
.
nn
.
Module
):
if
self
.
shared_experts
is
None
:
self
.
shared_experts
=
shared_experts
if
self
.
shared_expert_overlap
:
self
.
token_dispatcher
.
set_shared_experts
(
self
.
shared_experts
)
...
...
@@ -355,8 +296,28 @@ class EPMoE(FusedMoE):
def
forward
(
self
,
hidden_states
:
torch
.
Tensor
,
router_logits
:
torch
.
Tensor
):
return
torch
.
ops
.
vllm
.
ep_moe_forward
(
hidden_states
,
router_logits
,
self
.
layer_name
)
self
.
layer_name
)
def
get_expert_weights
(
self
)
->
Iterable
[
torch
.
Tensor
]:
weights
=
list
(
self
.
named_parameters
())
# Filter out the non-expert weights.
# `e_score_correction_bias` is a bias for each logical expert,
# with shape (num_logical_experts,), not an expert weight.
NON_EXPERT_WEIGHTS
=
{
"e_score_correction_bias"
,
"shared_experts.gate_up_proj.weight"
,
"shared_experts.gate_up_proj.weight_scale"
,
"shared_experts.down_proj.weight"
,
"shared_experts.down_proj.weight_scale"
}
return
[
weight
.
view
(
self
.
local_num_experts
,
-
1
)
for
name
,
weight
in
weights
if
name
not
in
NON_EXPERT_WEIGHTS
]
def
forward_impl
(
self
,
hidden_states
:
torch
.
Tensor
,
router_logits
:
torch
.
Tensor
):
topk_weights
,
topk_ids
=
self
.
select_experts
(
...
...
@@ -376,129 +337,71 @@ class EPMoE(FusedMoE):
if
not
self
.
ep_moe_config
.
moe_shared_expert_overlap
and
self
.
shared_experts
is
not
None
:
shared_output
=
self
.
shared_experts
(
hidden_states
)
topk_ids
=
topk_ids
.
to
(
torch
.
int32
)
scales
=
torch
.
rand
(
hidden_states
.
shape
[
0
],
0
,
dtype
=
torch
.
float32
,
device
=
hidden_states
.
device
,
)
########################test#########################
# probs = None
# if self.apply_router_weight_on_input:
# probs = torch.zeros_like(router_logits, dtype=topk_weights.dtype).scatter(1, topk_ids, topk_weights)
# routing_map = torch.zeros_like(router_logits).int().scatter(1, topk_ids, 1).bool()
# (dispatch_output, tokens_per_expert) = self.token_dispatcher.token_permutation(
# hidden_states, probs, routing_map
# )
# torch.cuda.synchronize()
# print("###########################all2all dispatch_output shape:", dispatch_output.shape)
# print("###########################all2all dispatch_output:", dispatch_output[:10, :10])
# expert_output = self.quant_method.apply_ep(
# layer=self,
# hidden_states=dispatch_output,
# tokens_per_expert=tokens_per_expert,
# topk_weights=None,
# topk_ids=None,
# global_num_experts=self.global_num_experts,
# expert_map=self.expert_map,
# activation=self.activation,
# apply_router_weight_on_input=self.apply_router_weight_on_input,
# use_nn_moe=self.use_nn_moe,
# )
# torch.cuda.synchronize()
# print("###########################grouped gemm out:", expert_output[:10, :10])
# final_hidden_states = self.token_dispatcher.token_unpermutation(expert_output)
# final_hidden_states_all2all = final_hidden_states
# torch.cuda.synchronize()
# print("####################all2all unpermute output:", final_hidden_states[:10, :10].tolist())
########################test##########################
if
False
:
probs
=
None
if
self
.
apply_router_weight_on_input
:
probs
=
torch
.
zeros_like
(
router_logits
,
dtype
=
topk_weights
.
dtype
).
scatter
(
1
,
topk_ids
,
topk_weights
)
routing_map
=
torch
.
zeros_like
(
router_logits
).
int
().
scatter
(
1
,
topk_ids
,
1
).
bool
()
(
dispatch_output
,
tokens_per_expert
)
=
self
.
token_dispatcher
.
token_permutation
(
hidden_states
,
probs
,
routing_map
)
else
:
topk_ids
=
topk_ids
.
to
(
torch
.
int32
)
scales
=
torch
.
rand
(
hidden_states
.
shape
[
0
],
0
,
dtype
=
torch
.
float32
,
device
=
hidden_states
.
device
,
)
self
.
sync
()
print
(
"##########################topk_weights shape:{} topk_ids shape:{}"
.
format
(
topk_weights
.
shape
,
topk_ids
.
shape
))
(
dispatch_output
,
dispatch_weights
,
dispatch_scales
,
dispatch_indices
,
dispatch_recv_num_token
,
)
=
self
.
mori_op
.
dispatch
(
hidden_states
.
contiguous
(),
topk_weights
.
contiguous
(),
scales
.
contiguous
(),
topk_ids
.
contiguous
(),
)
self
.
sync
()
# with torch.inference_mode():
# src_token_pos = self.mori_op.get_dispatch_src_token_pos().tolist()
# print("##################src_token_pos:", src_token_pos[:10].tolist())
tokens_per_expert
=
dispatch_recv_num_token
print
(
"######################dispatched_input shape:{} dispatch_weights.shape:{} dispatch_indices shape:{}"
.
format
(
dispatch_output
.
shape
,
dispatch_weights
.
shape
,
dispatch_indices
.
shape
))
print
(
"####################dispatch_recv_num_token:"
,
dispatch_recv_num_token
)
dispatch_recv_num_token
=
dispatch_recv_num_token
.
cpu
()[
0
]
print
(
"########################dispatch_output:"
,
dispatch_output
[:
10
,
:
10
].
tolist
())
print
(
"########################dispatch_indices:"
,
dispatch_indices
[:
10
,
:].
tolist
())
print
(
"#########################start fused_moe"
)
has_greater_than_255
=
torch
.
any
(
dispatch_indices
>
255
).
item
()
has_less_than_0
=
torch
.
any
(
dispatch_indices
<
0
).
item
()
print
(
"##################################has_greater_than_255:{} has_less_than_0:{}"
.
format
(
has_greater_than_255
,
has_less_than_0
))
if
dispatch_recv_num_token
>
0
:
# Matrix multiply.
#expert_output = self.quant_method.apply_ep(
expert_output
=
self
.
quant_method
.
apply
(
layer
=
self
,
x
=
dispatch_output
[:
dispatch_recv_num_token
].
contiguous
(),
tokens_per_expert
=
tokens_per_expert
,
topk_weights
=
dispatch_weights
[:
dispatch_recv_num_token
].
contiguous
(),
topk_ids
=
dispatch_indices
[:
dispatch_recv_num_token
].
contiguous
(),
global_num_experts
=
self
.
global_num_experts
,
expert_map
=
self
.
expert_map
,
activation
=
self
.
activation
,
apply_router_weight_on_input
=
self
.
apply_router_weight_on_input
,
use_nn_moe
=
self
.
use_nn_moe
,
)
else
:
expert_output
=
dispatch_output
[:
dispatch_recv_num_token
]
self
.
sync
()
print
(
"####################fused_moe expert_output:"
,
expert_output
[:
10
,
:
10
].
tolist
())
(
dispatch_output
,
dispatch_weights
,
dispatch_scales
,
dispatch_indices
,
dispatch_recv_num_token
,
)
=
self
.
mori_op
.
dispatch
(
hidden_states
,
topk_weights
,
scales
,
topk_ids
,
)
if
False
:
final_hidden_states
=
self
.
token_dispatcher
.
token_unpermutation
(
expert_output
)
#self.sync()
#dispatch_recv_num_token = dispatch_recv_num_token[0].item()
dispatch_recv_num_token
=
dispatch_recv_num_token
.
cpu
()[
0
]
dispatch_output
=
dispatch_output
[:
dispatch_recv_num_token
]
dispatch_weights
=
dispatch_weights
[:
dispatch_recv_num_token
]
dispatch_indices
=
dispatch_indices
[:
dispatch_recv_num_token
]
valid_mask
=
((
dispatch_indices
<=
255
)
&
(
dispatch_indices
>=
0
)).
all
(
dim
=
1
)
dispatch_output
=
dispatch_output
[
valid_mask
]
dispatch_indices
=
dispatch_indices
[
valid_mask
]
dispatch_weights
=
dispatch_weights
[
valid_mask
]
dispatch_recv_num_token
=
dispatch_indices
.
shape
[
0
]
# dispatch_recv_num_token = dispatch_recv_num_token.cpu()[0]
# has_greater_than_255 = torch.any(dispatch_indices > 255).item()
# has_less_than_0 = torch.any(dispatch_indices < 0).item()
# print("##################################has_greater_than_255:{} has_less_than_0:{}".format(has_greater_than_255, has_less_than_0))
# if has_greater_than_255 or has_less_than_0:
# print("###################dispatch_indices:", dispatch_indices.tolist())
if
dispatch_recv_num_token
>
0
:
# Matrix multiply.
expert_output
=
self
.
quant_method
.
apply_ep
(
layer
=
self
,
x
=
dispatch_output
,
topk_weights
=
dispatch_weights
,
topk_ids
=
dispatch_indices
,
global_num_experts
=
self
.
global_num_experts
,
expert_map
=
self
.
expert_map
,
activation
=
self
.
activation
,
apply_router_weight_on_input
=
self
.
apply_router_weight_on_input
,
use_nn_moe
=
self
.
use_nn_moe
,
)
else
:
combine_output
,
_
=
self
.
mori_op
.
combine
(
expert_output
.
contiguous
(),
dispatch_weights
.
contiguous
(),
topk_ids
.
contiguous
())
final_hidden_states
=
combine_output
[:
hidden_states
.
shape
[
0
],
:]
torch
.
cuda
.
synchronize
()
print
(
"####################mori combine_output:"
,
combine_output
[:
10
,
:
10
].
tolist
())
expert_output
=
dispatch_output
#[:dispatch_recv_num_token]
#self.sync()
self
.
sync
()
combine_output
,
_
=
self
.
mori_op
.
combine
(
expert_output
,
dispatch_weights
,
topk_ids
)
final_hidden_states
=
combine_output
[:
hidden_states
.
shape
[
0
],
:]
####################test#################
# final_hidden_states_close = torch.allclose(final_hidden_states, final_hidden_states_all2all, rtol=1e-2, atol=1e-2)
# print(f"final_hidden_states_close: {final_hidden_states_close}")
#####################test################
#self.sync()
if
not
self
.
ep_moe_config
.
moe_shared_expert_overlap
and
self
.
shared_experts
is
not
None
:
# if shared_expert_overlap is True, the expert calculation happens in
...
...
vllm/model_executor/layers/fused_moe/ep_moe/token_dispatcher.py
View file @
1a56d6cb
...
...
@@ -331,7 +331,7 @@ class MoEAlltoAllTokenDispatcher(MoETokenDispatcher):
self
.
hidden_shape_before_permute
=
hidden_states
.
shape
if
Tru
e
:
if
Fals
e
:
permutated_local_input_tokens
,
self
.
reversed_local_input_permutation_mapping
=
permute
(
hidden_states
,
routing_map
,
...
...
@@ -339,15 +339,10 @@ class MoEAlltoAllTokenDispatcher(MoETokenDispatcher):
fused
=
self
.
config
.
moe_permute_fusion
)
else
:
torch
.
cuda
.
synchronize
()
print
(
"########################hidden_states shape:{}
\n
#####################routing_map shape:{}
\n
"
.
format
(
hidden_states
.
shape
,
routing_map
.
shape
))
print
(
"########################hidden_states:{}
\n
#####################routing_map:{}
\n
"
.
format
(
hidden_states
[
0
,
:
10
].
tolist
(),
routing_map
[
0
,
:
10
].
tolist
()))
cuda_permute_result
=
groupgemm_permute
(
hidden_states
,
routing_map
)
permutated_local_input_tokens
,
self
.
reversed_local_input_permutation_mapping
,
\
expert_m
,
self
.
expert_m_count
,
expert_m_max
=
cuda_permute_result
self
.
expert_m_count
=
cuda_permute_result
# Perform expert parallel AlltoAll communication
# tokens_per_expert = self._maybe_dtoh_and_synchronize(
...
...
@@ -427,7 +422,7 @@ class MoEAlltoAllTokenDispatcher(MoETokenDispatcher):
self
.
shared_experts
.
post_forward_comm
()
# Unpermutation 1: AlltoAll output to output
if
Tru
e
:
if
Fals
e
:
output
=
unpermute
(
permutated_local_input_tokens
,
self
.
reversed_local_input_permutation_mapping
,
...
...
vllm/model_executor/layers/fused_moe/layer.py
View file @
1a56d6cb
...
...
@@ -778,19 +778,19 @@ class FusedMoE(torch.nn.Module):
assert
isinstance
(
quant_method
,
FusedMoEMethodBase
)
self
.
quant_method
=
quant_method
if
self
.
enable_eplb
:
from
vllm.model_executor.layers.quantization.fp8
import
(
Fp8MoEMethod
)
if
not
isinstance
(
quant_method
,
Fp8MoEMethod
):
# TODO: Add support for additional quantization methods.
# The implementation for other quantization methods does not
# contain essential differences, but the current quant API
# design causes duplicated work when extending to new
# quantization methods, so I'm leaving it for now.
# If you plan to add support for more quantization methods,
# please refer to the implementation in `Fp8MoEMethod`.
raise
NotImplementedError
(
"EPLB is only supported for FP8 "
"quantization for now."
)
#
if self.enable_eplb:
#
from vllm.model_executor.layers.quantization.fp8 import (
#
Fp8MoEMethod)
#
if not isinstance(quant_method, Fp8MoEMethod):
#
# TODO: Add support for additional quantization methods.
#
# The implementation for other quantization methods does not
#
# contain essential differences, but the current quant API
#
# design causes duplicated work when extending to new
#
# quantization methods, so I'm leaving it for now.
#
# If you plan to add support for more quantization methods,
#
# please refer to the implementation in `Fp8MoEMethod`.
#
raise NotImplementedError("EPLB is only supported for FP8 "
#
"quantization for now.")
if
quant_config
is
None
:
# Not considering quant for now, temporarily
...
...
vllm/model_executor/layers/quantization/slimquant_w4a8.py
View file @
1a56d6cb
...
...
@@ -334,29 +334,59 @@ class SlimQuantW4A8Int8MoEMethod:
def
apply_ep
(
#dp+ep
self
,
layer
:
torch
.
nn
.
Module
,
hidden_states
:
torch
.
Tensor
,
tokens_per_expert
:
torch
.
Tensor
,
x
:
torch
.
Tensor
,
topk_weights
:
torch
.
Tensor
,
topk_ids
:
torch
.
Tensor
,
global_num_experts
:
int
=
-
1
,
expert_map
:
Optional
[
torch
.
Tensor
]
=
None
,
apply_router_weight_on_input
:
bool
=
False
,
activation
:
str
=
"silu"
,
use_nn_moe
:
Optional
[
bool
]
=
False
,
**
_
)
->
torch
.
Tensor
:
return
fused_experts_impl_w4a8_ep
(
hidden_states
,
layer
.
w13_weight
,
layer
.
w2_weight
,
layer
.
w13_weight_scale
,
layer
.
w2_weight_scale
,
tokens_per_expert
)
from
vllm.model_executor.layers.fused_moe
import
fused_experts
return
fused_experts
(
x
,
layer
.
w13_weight
,
layer
.
w2_weight
,
topk_weights
=
topk_weights
,
topk_ids
=
topk_ids
,
inplace
=
True
,
use_int4_w4a8
=
True
,
per_channel_quant
=
True
,
activation
=
activation
,
expert_map
=
expert_map
,
apply_router_weight_on_input
=
apply_router_weight_on_input
,
global_num_experts
=
global_num_experts
,
w1_scale
=
(
layer
.
w13_weight_scale
),
w2_scale
=
(
layer
.
w2_weight_scale
),
a1_scale
=
layer
.
w13_input_scale
,
a2_scale
=
layer
.
w2_input_scale
,
use_nn_moe
=
use_nn_moe
,
)
def
apply
(
# tp
self
,
layer
:
torch
.
nn
.
Module
,
x
:
torch
.
Tensor
,
tokens_per_expert
:
torch
.
Tensor
,
topk_weights
:
torch
.
Tensor
,
topk_ids
:
torch
.
Tensor
,
router_logits
:
torch
.
Tensor
,
top_k
:
int
,
renormalize
:
bool
,
use_grouped_topk
:
bool
=
False
,
topk_group
:
Optional
[
int
]
=
None
,
num_expert_group
:
Optional
[
int
]
=
None
,
global_num_experts
:
int
=
-
1
,
expert_map
:
Optional
[
torch
.
Tensor
]
=
None
,
custom_routing_function
:
Optional
[
Callable
]
=
None
,
scoring_func
:
str
=
"softmax"
,
e_score_correction_bias
:
Optional
[
torch
.
Tensor
]
=
None
,
apply_router_weight_on_input
:
bool
=
False
,
activation
:
str
=
"silu"
,
enable_eplb
:
bool
=
False
,
use_nn_moe
:
Optional
[
bool
]
=
False
,
routed_scaling_factor
:
Optional
[
float
]
=
None
,
use_fused_gate
:
Optional
[
bool
]
=
False
,
**
_
)
->
torch
.
Tensor
:
from
vllm.model_executor.layers.fused_moe
import
fused_experts
...
...
@@ -364,6 +394,20 @@ class SlimQuantW4A8Int8MoEMethod:
raise
NotImplementedError
(
"EPLB not supported for `SlimQuantW4A8Int8MoEMethod` yet."
)
# Expert selection
topk_weights
,
topk_ids
=
FusedMoE
.
select_experts
(
hidden_states
=
x
,
router_logits
=
router_logits
,
use_grouped_topk
=
use_grouped_topk
,
top_k
=
top_k
,
renormalize
=
renormalize
,
topk_group
=
topk_group
,
num_expert_group
=
num_expert_group
,
custom_routing_function
=
custom_routing_function
,
scoring_func
=
scoring_func
,
e_score_correction_bias
=
e_score_correction_bias
,
routed_scaling_factor
=
routed_scaling_factor
,
use_fused_gate
=
use_fused_gate
)
return
fused_experts
(
x
,
...
...
vllm/platforms/rocm.py
View file @
1a56d6cb
...
...
@@ -102,12 +102,7 @@ def with_amdsmi_context(fn):
def
device_id_to_physical_device_id
(
device_id
:
int
)
->
int
:
if
"CUDA_VISIBLE_DEVICES"
in
os
.
environ
:
device_ids
=
os
.
environ
[
"CUDA_VISIBLE_DEVICES"
].
split
(
","
)
physical_device_id
=
device_ids
[
device_id
]
return
int
(
physical_device_id
)
else
:
return
device_id
return
device_id
@
cache
...
...
vllm/v1/spec_decode/eagle.py
View file @
1a56d6cb
...
...
@@ -441,7 +441,7 @@ class EagleProposer:
# [batch_size]
num_accepted_tokens_tensor
:
torch
.
Tensor
,
)
->
tuple
[
torch
.
Tensor
,
torch
.
Tensor
]:
cu_num_tokens
=
torch
.
arange
(
cu_target_query_lens
.
shape
[
0
],
device
=
cu_target_query_lens
.
device
)
cu_num_tokens
=
torch
.
arange
(
cu_target_query_lens
.
shape
[
0
],
device
=
cu_target_query_lens
.
device
,
dtype
=
torch
.
int32
)
token_indices
=
num_accepted_tokens_tensor
+
cu_target_query_lens
[:
-
1
]
return
cu_num_tokens
,
token_indices
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment