Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
d2e57a90
Commit
d2e57a90
authored
Sep 30, 2025
by
王敏
Browse files
[feat]优化mori计算逻辑,支持cudagraph,按照bs*ep_size截断fused_moe的输入,共享专家不tp切分,去掉最后的allreduce
parent
8824ae6a
Changes
13
Show whitespace changes
Inline
Side-by-side
Showing
13 changed files
with
251 additions
and
135 deletions
+251
-135
vllm/config.py
vllm/config.py
+0
-3
vllm/distributed/parallel_state.py
vllm/distributed/parallel_state.py
+5
-1
vllm/envs.py
vllm/envs.py
+3
-3
vllm/forward_context.py
vllm/forward_context.py
+2
-2
vllm/model_executor/layers/fused_moe/ep_moe/ep_moe_utlis.py
vllm/model_executor/layers/fused_moe/ep_moe/ep_moe_utlis.py
+5
-2
vllm/model_executor/layers/fused_moe/ep_moe/layer.py
vllm/model_executor/layers/fused_moe/ep_moe/layer.py
+98
-87
vllm/model_executor/layers/fused_moe/fused_moe.py
vllm/model_executor/layers/fused_moe/fused_moe.py
+25
-9
vllm/model_executor/layers/fused_moe/moe_align_block_size.py
vllm/model_executor/layers/fused_moe/moe_align_block_size.py
+5
-4
vllm/model_executor/layers/linear.py
vllm/model_executor/layers/linear.py
+29
-1
vllm/model_executor/layers/quantization/slimquant_w4a8_marlin.py
...del_executor/layers/quantization/slimquant_w4a8_marlin.py
+58
-2
vllm/model_executor/models/deepseek_mtp.py
vllm/model_executor/models/deepseek_mtp.py
+5
-5
vllm/model_executor/models/deepseek_v2.py
vllm/model_executor/models/deepseek_v2.py
+14
-14
vllm/v1/worker/gpu_model_runner.py
vllm/v1/worker/gpu_model_runner.py
+2
-2
No files found.
vllm/config.py
View file @
d2e57a90
...
@@ -4320,9 +4320,6 @@ class CompilationConfig:
...
@@ -4320,9 +4320,6 @@ class CompilationConfig:
self
.
splitting_ops
=
[]
if
self
.
full_cuda_graph
else
[
self
.
splitting_ops
=
[]
if
self
.
full_cuda_graph
else
[
"vllm.unified_attention"
,
"vllm.unified_attention"
,
"vllm.unified_attention_with_output"
,
"vllm.unified_attention_with_output"
,
"vllm.token_permutation_forward"
,
"vllm.token_unpermutation_forward"
,
"vllm.ep_moe_forward"
,
]
]
...
...
vllm/distributed/parallel_state.py
View file @
d2e57a90
...
@@ -948,6 +948,10 @@ def init_distributed_environment(
...
@@ -948,6 +948,10 @@ def init_distributed_environment(
"Fallback Gloo backend is not available."
)
"Fallback Gloo backend is not available."
)
backend
=
"gloo"
backend
=
"gloo"
# this backend is used for WORLD
# this backend is used for WORLD
data_parallel_size
=
parallel_config
.
data_parallel_size
use_mori_ep
=
envs
.
VLLM_USE_MORI_EP
and
data_parallel_size
>
1
and
parallel_config
.
enable_expert_parallel
if
use_mori_ep
:
backend
=
"cpu:gloo,cuda:nccl"
backend
=
"cpu:gloo,cuda:nccl"
torch
.
distributed
.
init_process_group
(
torch
.
distributed
.
init_process_group
(
backend
=
backend
,
backend
=
backend
,
...
...
vllm/envs.py
View file @
d2e57a90
...
@@ -168,7 +168,7 @@ if TYPE_CHECKING:
...
@@ -168,7 +168,7 @@ if TYPE_CHECKING:
VLLM_USE_TRITON_CAT
:
bool
=
False
VLLM_USE_TRITON_CAT
:
bool
=
False
USE_FUSED_RMS_QUANT
:
bool
=
False
USE_FUSED_RMS_QUANT
:
bool
=
False
VLLM_USE_MERGE_ATTN_STATES_OPT
:
bool
=
False
VLLM_USE_MERGE_ATTN_STATES_OPT
:
bool
=
False
VLLM_USE_
ALLTOALL
_EP
:
bool
=
False
VLLM_USE_
MORI
_EP
:
bool
=
False
def
get_default_cache_root
():
def
get_default_cache_root
():
return
os
.
getenv
(
return
os
.
getenv
(
...
@@ -1112,8 +1112,8 @@ environment_variables: dict[str, Callable[[], Any]] = {
...
@@ -1112,8 +1112,8 @@ environment_variables: dict[str, Callable[[], Any]] = {
(
"true"
,
"1"
)),
(
"true"
,
"1"
)),
# vLLM will use all_to_all ep mode
# vLLM will use all_to_all ep mode
"VLLM_USE_
ALLTOALL
_EP"
:
"VLLM_USE_
MORI
_EP"
:
lambda
:
(
os
.
environ
.
get
(
"VLLM_USE_
ALLTOALL
_EP"
,
"True"
).
lower
()
in
lambda
:
(
os
.
environ
.
get
(
"VLLM_USE_
MORI
_EP"
,
"True"
).
lower
()
in
(
"true"
,
"1"
)),
(
"true"
,
"1"
)),
}
}
...
...
vllm/forward_context.py
View file @
d2e57a90
...
@@ -136,8 +136,8 @@ def set_forward_context(
...
@@ -136,8 +136,8 @@ def set_forward_context(
forward_start_time
=
time
.
perf_counter
()
forward_start_time
=
time
.
perf_counter
()
dp_metadata
:
Optional
[
DPMetadata
]
=
None
dp_metadata
:
Optional
[
DPMetadata
]
=
None
dp_size
=
vllm_config
.
parallel_config
.
data_parallel_size
dp_size
=
vllm_config
.
parallel_config
.
data_parallel_size
use_
all2all
_ep
=
envs
.
VLLM_USE_
ALLTOALL
_EP
and
dp_size
>
1
and
vllm_config
.
parallel_config
.
enable_expert_parallel
use_
mori
_ep
=
envs
.
VLLM_USE_
MORI
_EP
and
dp_size
>
1
and
vllm_config
.
parallel_config
.
enable_expert_parallel
if
not
use_
all2all
_ep
and
dp_size
>
1
and
(
if
not
use_
mori
_ep
and
dp_size
>
1
and
(
attn_metadata
is
not
None
or
num_tokens
is
not
None
)
:
attn_metadata
is
not
None
or
num_tokens
is
not
None
)
:
dp_metadata
=
DPMetadata
.
make
(
vllm_config
.
parallel_config
,
dp_metadata
=
DPMetadata
.
make
(
vllm_config
.
parallel_config
,
attn_metadata
,
num_tokens
or
0
,
attn_metadata
,
num_tokens
or
0
,
...
...
vllm/model_executor/layers/fused_moe/ep_moe/ep_moe_utlis.py
View file @
d2e57a90
...
@@ -88,13 +88,16 @@ class EPSharedExperts(nn.Module):
...
@@ -88,13 +88,16 @@ class EPSharedExperts(nn.Module):
hidden_size
,
[
intermediate_size
]
*
2
,
hidden_size
,
[
intermediate_size
]
*
2
,
bias
=
False
,
bias
=
False
,
quant_config
=
quant_config
,
quant_config
=
quant_config
,
prefix
=
f
"
{
prefix
}
.gate_up_proj"
)
prefix
=
f
"
{
prefix
}
.gate_up_proj"
,
expect_tp_size
=
1
)
self
.
down_proj
=
RowParallelLinear
(
intermediate_size
,
self
.
down_proj
=
RowParallelLinear
(
intermediate_size
,
hidden_size
,
hidden_size
,
bias
=
False
,
bias
=
False
,
quant_config
=
quant_config
,
quant_config
=
quant_config
,
reduce_results
=
reduce_results
,
reduce_results
=
reduce_results
,
prefix
=
f
"
{
prefix
}
.down_proj"
)
prefix
=
f
"
{
prefix
}
.down_proj"
,
expect_tp_size
=
1
)
if
hidden_act
!=
"silu"
:
if
hidden_act
!=
"silu"
:
raise
ValueError
(
f
"Unsupported activation:
{
hidden_act
}
. "
raise
ValueError
(
f
"Unsupported activation:
{
hidden_act
}
. "
"Only silu is supported for now."
)
"Only silu is supported for now."
)
...
...
vllm/model_executor/layers/fused_moe/ep_moe/layer.py
View file @
d2e57a90
...
@@ -7,7 +7,6 @@ from collections.abc import Iterable
...
@@ -7,7 +7,6 @@ from collections.abc import Iterable
import
torch
import
torch
import
torch.nn.functional
as
F
import
torch.nn.functional
as
F
import
vllm.envs
as
envs
from
vllm.logger
import
init_logger
from
vllm.logger
import
init_logger
from
vllm.platforms
import
current_platform
from
vllm.platforms
import
current_platform
from
vllm.config
import
get_current_vllm_config
from
vllm.config
import
get_current_vllm_config
...
@@ -17,6 +16,7 @@ from vllm.distributed.parallel_state import get_ep_group, get_node_count
...
@@ -17,6 +16,7 @@ from vllm.distributed.parallel_state import get_ep_group, get_node_count
from
vllm.model_executor.layers.fused_moe.config
import
FusedMoEConfig
from
vllm.model_executor.layers.fused_moe.config
import
FusedMoEConfig
from
vllm.model_executor.layers.quantization.base_config
import
(
from
vllm.model_executor.layers.quantization.base_config
import
(
QuantizationConfig
,
QuantizeMethodBase
)
QuantizationConfig
,
QuantizeMethodBase
)
from
vllm.distributed
import
expert_parallel_all_gather
,
expert_parallel_all_reduce
from
vllm.model_executor.layers.fused_moe
import
FusedMoE
from
vllm.model_executor.layers.fused_moe
import
FusedMoE
from
vllm.model_executor.layers.fused_moe.layer
import
FusedMoEMethodBase
,
UnquantizedFusedMoEMethod
from
vllm.model_executor.layers.fused_moe.layer
import
FusedMoEMethodBase
,
UnquantizedFusedMoEMethod
from
vllm.model_executor.layers.fused_moe.ep_moe.token_dispatcher
import
MoEAlltoAllTokenDispatcher
from
vllm.model_executor.layers.fused_moe.ep_moe.token_dispatcher
import
MoEAlltoAllTokenDispatcher
...
@@ -25,6 +25,10 @@ from vllm.utils import direct_register_custom_op
...
@@ -25,6 +25,10 @@ from vllm.utils import direct_register_custom_op
import
mori
import
mori
import
torch.distributed
as
dist
import
torch.distributed
as
dist
from
lmslim.layers.gemm.int8_utils
import
(
per_token_group_quant_int8
,
per_token_quant_int8
)
logger
=
init_logger
(
__name__
)
logger
=
init_logger
(
__name__
)
...
@@ -40,7 +44,6 @@ class UnquantizedEPGroupedGemmMethod(UnquantizedFusedMoEMethod):
...
@@ -40,7 +44,6 @@ class UnquantizedEPGroupedGemmMethod(UnquantizedFusedMoEMethod):
self
.
moe
=
moe
self
.
moe
=
moe
self
.
rocm_aiter_moe_enabled
=
False
# is_rocm_aiter_moe_enabled()
self
.
rocm_aiter_moe_enabled
=
False
# is_rocm_aiter_moe_enabled()
self
.
zero_token_count
=
None
def
apply_ep
(
def
apply_ep
(
self
,
self
,
...
@@ -235,9 +238,11 @@ class EPMoE(FusedMoE):
...
@@ -235,9 +238,11 @@ class EPMoE(FusedMoE):
self
.
dpsk_fp16_quick
=
os
.
environ
.
get
(
'DPSK_FP16_QUICK'
)
==
'1'
self
.
dpsk_fp16_quick
=
os
.
environ
.
get
(
'DPSK_FP16_QUICK'
)
==
'1'
self
.
mori_op
=
self
.
get_mori_op
()
self
.
scales
=
None
self
.
use_int8_dispatch
=
True
self
.
zero_token_count
=
None
self
.
mori_op
=
self
.
get_mori_op
()
self
.
first
=
True
def
get_mori_op
(
self
):
def
get_mori_op
(
self
):
...
@@ -253,20 +258,28 @@ class EPMoE(FusedMoE):
...
@@ -253,20 +258,28 @@ class EPMoE(FusedMoE):
mori
.
shmem
.
shmem_torch_process_group_init
(
"default"
)
mori
.
shmem
.
shmem_torch_process_group_init
(
"default"
)
vllm_config
=
get_current_vllm_config
()
vllm_config
=
get_current_vllm_config
()
multi_node
=
self
.
ep_size
/
8
>
1
mori_data_type
=
vllm_config
.
model_config
.
dtype
mori_scale_type_size
=
vllm_config
.
model_config
.
dtype
.
itemsize
if
self
.
use_int8_dispatch
:
mori_scale_type_size
=
4
config
=
mori
.
ops
.
EpDispatchCombineConfig
(
config
=
mori
.
ops
.
EpDispatchCombineConfig
(
data_type
=
vllm_config
.
model_config
.
d
type
,
data_type
=
mori_data_
type
,
rank
=
self
.
ep_rank
,
rank
=
self
.
ep_rank
,
world_size
=
self
.
ep_size
,
world_size
=
self
.
ep_size
,
hidden_dim
=
self
.
hidden_size
,
hidden_dim
=
self
.
hidden_size
,
scale_dim
=
0
,
scale_dim
=
1
if
self
.
use_int8_dispatch
else
0
,
scale_type_size
=
vllm_config
.
model_config
.
dtype
.
item
size
,
scale_type_size
=
mori_scale_type_
size
,
max_num_inp_token_per_rank
=
4096
,
max_num_inp_token_per_rank
=
2048
,
num_experts_per_rank
=
self
.
local_num_experts
,
num_experts_per_rank
=
self
.
local_num_experts
,
num_experts_per_token
=
self
.
top_k
,
num_experts_per_token
=
self
.
top_k
,
max_token_type_size
=
2
,
max_token_type_size
=
2
,
block_num
=
64
,
block_num
=
80
,
warp_num_per_block
=
16
,
warp_num_per_block
=
16
,
kernel_type
=
mori
.
ops
.
EpDispatchCombineKernelType
.
InterNode
kernel_type
=
mori
.
ops
.
EpDispatchCombineKernelType
.
InterNode
if
multi_node
else
\
mori
.
ops
.
EpDispatchCombineKernelType
.
IntraNode
)
)
_MORI_OP
=
mori
.
ops
.
EpDispatchCombineOp
(
config
)
_MORI_OP
=
mori
.
ops
.
EpDispatchCombineOp
(
config
)
...
@@ -291,13 +304,11 @@ class EPMoE(FusedMoE):
...
@@ -291,13 +304,11 @@ class EPMoE(FusedMoE):
return
quant_method
return
quant_method
def
sync
(
self
):
def
sync
(
self
):
torch
.
cuda
.
synchronize
()
#
torch.cuda.synchronize()
dist
.
barrier
()
dist
.
barrier
()
def
forward
(
self
,
hidden_states
:
torch
.
Tensor
,
def
forward
(
self
,
hidden_states
:
torch
.
Tensor
,
router_logits
:
torch
.
Tensor
,
router_logits
:
torch
.
Tensor
):
rms_weight
:
Optional
[
torch
.
Tensor
]
=
None
,
residual
:
Optional
[
torch
.
Tensor
]
=
None
):
return
torch
.
ops
.
vllm
.
ep_moe_forward
(
hidden_states
,
router_logits
,
return
torch
.
ops
.
vllm
.
ep_moe_forward
(
hidden_states
,
router_logits
,
self
.
layer_name
)
self
.
layer_name
)
...
@@ -322,9 +333,7 @@ class EPMoE(FusedMoE):
...
@@ -322,9 +333,7 @@ class EPMoE(FusedMoE):
def
forward_impl
(
self
,
hidden_states
:
torch
.
Tensor
,
def
forward_impl
(
self
,
hidden_states
:
torch
.
Tensor
,
router_logits
:
torch
.
Tensor
,
router_logits
:
torch
.
Tensor
):
rms_weight
:
Optional
[
torch
.
Tensor
]
=
None
,
residual
:
Optional
[
torch
.
Tensor
]
=
None
):
topk_weights
,
topk_ids
=
self
.
select_experts
(
topk_weights
,
topk_ids
=
self
.
select_experts
(
hidden_states
=
hidden_states
,
hidden_states
=
hidden_states
,
...
@@ -337,25 +346,27 @@ class EPMoE(FusedMoE):
...
@@ -337,25 +346,27 @@ class EPMoE(FusedMoE):
custom_routing_function
=
self
.
custom_routing_function
,
custom_routing_function
=
self
.
custom_routing_function
,
scoring_func
=
self
.
scoring_func
,
scoring_func
=
self
.
scoring_func
,
e_score_correction_bias
=
self
.
e_score_correction_bias
,
e_score_correction_bias
=
self
.
e_score_correction_bias
,
indices_type
=
torch
.
int
64
,
indices_type
=
torch
.
int
32
,
routed_scaling_factor
=
self
.
routed_scaling_factor
,
routed_scaling_factor
=
self
.
routed_scaling_factor
,
use_fused_gate
=
self
.
use_fused_gate
)
use_fused_gate
=
self
.
use_fused_gate
)
if
not
self
.
ep_moe_config
.
moe_shared_expert_overlap
and
self
.
shared_experts
is
not
None
:
if
not
self
.
ep_moe_config
.
moe_shared_expert_overlap
and
self
.
shared_experts
is
not
None
:
if
envs
.
USE_FUSED_RMS_QUANT
:
shared_output
,
new_resi
=
self
.
shared_experts
(
hidden_states
,
rms_weight
,
residual
,
update_hd
=
True
)
else
:
shared_output
=
self
.
shared_experts
(
hidden_states
)
shared_output
=
self
.
shared_experts
(
hidden_states
)
topk_ids
=
topk_ids
.
to
(
torch
.
int32
)
if
self
.
use_int8_dispatch
:
scales
=
torch
.
rand
(
hidden_states
,
scales
=
per_token_quant_int8
(
hidden_states
)
else
:
if
self
.
scales
is
None
:
self
.
scales
=
torch
.
rand
(
hidden_states
.
shape
[
0
],
hidden_states
.
shape
[
0
],
0
,
0
,
dtype
=
torch
.
float32
,
dtype
=
torch
.
float32
,
device
=
hidden_states
.
device
,
device
=
hidden_states
.
device
,
)
)
scales
=
self
.
scales
#dist.barrier()
#self.sync()
(
(
dispatch_output
,
dispatch_output
,
...
@@ -369,49 +380,54 @@ class EPMoE(FusedMoE):
...
@@ -369,49 +380,54 @@ class EPMoE(FusedMoE):
scales
,
scales
,
topk_ids
,
topk_ids
,
)
)
#self.sync()
#self.sync()
# dispatch_recv_num_token = dispatch_recv_num_token.cpu()[0]
expect_m
=
hidden_states
.
shape
[
0
]
*
self
.
ep_size
# #dispatch_recv_num_token = dispatch_recv_num_token.item()
dispatch_output_clip
=
dispatch_output
[:
expect_m
]
# dispatch_output = dispatch_output[:dispatch_recv_num_token]
dispatch_weights_clip
=
dispatch_weights
[:
expect_m
]
# dispatch_weights = dispatch_weights[:dispatch_recv_num_token]
dispatch_indices_clip
=
dispatch_indices
[:
expect_m
]
# dispatch_indices = dispatch_indices[:dispatch_recv_num_token]
dispatch_scales_clip
=
dispatch_scales
[:
expect_m
]
# dispatch_recv_num_token = dispatch_recv_num_token.item()
# dispatch_output = torch.narrow(dispatch_output, dim=0, start=0, length=dispatch_recv_num_token)
# dispatch_weights = torch.narrow(dispatch_weights, dim=0, start=0, length=dispatch_recv_num_token)
# dispatch_indices = torch.narrow(dispatch_indices, dim=0, start=0, length=dispatch_recv_num_token)
# valid_mask = ((dispatch_indices <= 255) & (dispatch_indices >= 0)).all(dim=1)
# dispatch_output = dispatch_output[valid_mask]
# dispatch_indices = dispatch_indices[valid_mask]
# dispatch_weights = dispatch_weights[valid_mask]
# dispatch_recv_num_token = dispatch_indices.shape[0]
# dispatch_recv_num_token = dispatch_recv_num_token.cpu()[0]
# has_greater_than_255 = torch.any(dispatch_indices > 255).item()
# has_less_than_0 = torch.any(dispatch_indices < 0).item()
# print("##################################has_greater_than_255:{} has_less_than_0:{}".format(has_greater_than_255, has_less_than_0))
# if has_greater_than_255 or has_less_than_0:
# print("###################dispatch_indices:", dispatch_indices.tolist())
if
dispatch_recv_num_token
>
0
:
# Matrix multiply.
expert_output
=
self
.
quant_method
.
apply_ep
(
expert_output
=
self
.
quant_method
.
apply_ep
(
layer
=
self
,
layer
=
self
,
x
=
dispatch_output
,
x
=
dispatch_output
_clip
,
topk_weights
=
dispatch_weights
,
topk_weights
=
dispatch_weights
_clip
,
topk_ids
=
dispatch_indices
,
topk_ids
=
dispatch_indices
_clip
,
global_num_experts
=
self
.
global_num_experts
,
global_num_experts
=
self
.
global_num_experts
,
expert_map
=
self
.
expert_map
,
expert_map
=
self
.
expert_map
,
activation
=
self
.
activation
,
activation
=
self
.
activation
,
apply_router_weight_on_input
=
self
.
apply_router_weight_on_input
,
apply_router_weight_on_input
=
self
.
apply_router_weight_on_input
,
use_nn_moe
=
self
.
use_nn_moe
,
use_nn_moe
=
self
.
use_nn_moe
,
num_local_tokens
=
dispatch_recv_num_token
,
config_select_bs
=
hidden_states
.
shape
[
0
],
scales
=
dispatch_scales_clip
if
self
.
use_int8_dispatch
else
None
#routed_scaling_factor=self.routed_scaling_factor,
)
)
else
:
expert_output
=
dispatch_output
#[:dispatch_recv_num_token]
# if self.first and hidden_states.shape[0] == 2:
# self.first = False
# import numpy as np
# np.save(f'/work/vllm_profile/ep{self.ep_rank}_topk_ids.npy', dispatch_indices_clip.cpu().numpy())
# print("##################config_select_bs:{} topk_ids shape:{} num_local_tokens:{}".format(hidden_states.shape[0],
# topk_ids.shape,
# dispatch_recv_num_token))
# expert_output = self.quant_method.apply_ep(
# layer=self,
# x=dispatch_output,
# topk_weights=dispatch_weights,
# topk_ids=dispatch_indices,
# global_num_experts=self.global_num_experts,
# expert_map=self.expert_map,
# activation=self.activation,
# apply_router_weight_on_input=self.apply_router_weight_on_input,
# use_nn_moe=self.use_nn_moe,
# num_local_tokens=dispatch_recv_num_token,
# config_select_bs=hidden_states.shape[0]*2,
# scales=dispatch_scales if self.use_int8_dispatch else None
# #routed_scaling_factor=self.routed_scaling_factor,
# )
#self.sync()
#self.sync()
combine_output
,
_
=
self
.
mori_op
.
combine
(
expert_output
,
dispatch_weights
,
topk_ids
)
combine_output
,
_
=
self
.
mori_op
.
combine
(
expert_output
,
dispatch_weights
,
topk_ids
)
...
@@ -422,9 +438,9 @@ class EPMoE(FusedMoE):
...
@@ -422,9 +438,9 @@ class EPMoE(FusedMoE):
if
not
self
.
ep_moe_config
.
moe_shared_expert_overlap
and
self
.
shared_experts
is
not
None
:
if
not
self
.
ep_moe_config
.
moe_shared_expert_overlap
and
self
.
shared_experts
is
not
None
:
# if shared_expert_overlap is True, the expert calculation happens in
# if shared_expert_overlap is True, the expert calculation happens in
# the token_dispatcher to overlap communications and computations
# the token_dispatcher to overlap communications and computations
shared_output
=
(
#
shared_output = (
self
.
maybe_all_reduce_tensor_model_parallel
(
#
self.maybe_all_reduce_tensor_model_parallel(
shared_output
))
#
shared_output))
if
hidden_states
.
dtype
!=
torch
.
float16
or
self
.
dpsk_fp16_quick
:
if
hidden_states
.
dtype
!=
torch
.
float16
or
self
.
dpsk_fp16_quick
:
final_hidden_states
=
final_hidden_states
+
shared_output
final_hidden_states
=
final_hidden_states
+
shared_output
...
@@ -434,31 +450,26 @@ class EPMoE(FusedMoE):
...
@@ -434,31 +450,26 @@ class EPMoE(FusedMoE):
final_hidden_states
=
final_hidden_states
+
shared_output
\
final_hidden_states
=
final_hidden_states
+
shared_output
\
*
(
1.
/
self
.
routed_scaling_factor
)
*
(
1.
/
self
.
routed_scaling_factor
)
if
envs
.
USE_FUSED_RMS_QUANT
:
return
final_hidden_states
return
final_hidden_states
,
new_resi
else
:
return
final_hidden_states
,
None
def
ep_moe_forward
(
hidden_states
:
torch
.
Tensor
,
router_logits
:
torch
.
Tensor
,
def
ep_moe_forward
(
hidden_states
:
torch
.
Tensor
,
router_logits
:
torch
.
Tensor
,
layer_name
:
str
,
rms_weight
:
Optional
[
torch
.
Tensor
]
=
None
,
layer_name
:
str
)
->
torch
.
Tensor
:
residual
:
Optional
[
torch
.
Tensor
]
=
None
)
->
tuple
[
torch
.
Tensor
,
torch
.
Tensor
]:
forward_context
:
ForwardContext
=
get_forward_context
()
forward_context
:
ForwardContext
=
get_forward_context
()
self
=
forward_context
.
no_compile_layers
[
layer_name
]
self
=
forward_context
.
no_compile_layers
[
layer_name
]
assert
self
.
quant_method
is
not
None
assert
self
.
quant_method
is
not
None
return
self
.
forward_impl
(
hidden_states
,
router_logits
,
rms_weight
,
residual
)
return
self
.
forward_impl
(
hidden_states
,
router_logits
)
def
ep_moe_forward_fake
(
hidden_states
:
torch
.
Tensor
,
router_logits
:
torch
.
Tensor
,
def
ep_moe_forward_fake
(
hidden_states
:
torch
.
Tensor
,
router_logits
:
torch
.
Tensor
,
layer_name
:
str
,
rms_weight
:
Optional
[
torch
.
Tensor
]
=
None
,
layer_name
:
str
)
->
torch
.
Tensor
:
residual
:
Optional
[
torch
.
Tensor
]
=
None
)
->
tuple
[
torch
.
Tensor
,
torch
.
Tensor
]:
return
torch
.
empty_like
(
hidden_states
)
return
torch
.
empty_like
(
hidden_states
),
torch
.
empty_like
(
hidden_states
)
direct_register_custom_op
(
direct_register_custom_op
(
op_name
=
"ep_moe_forward"
,
op_name
=
"ep_moe_forward"
,
op_func
=
ep_moe_forward
,
op_func
=
ep_moe_forward
,
mutates_args
=
[
"hidden_states"
,
"router_logits"
,
"rms_weight"
,
"residual"
],
mutates_args
=
[
"hidden_states"
,
"router_logits"
],
fake_impl
=
ep_moe_forward_fake
,
fake_impl
=
ep_moe_forward_fake
,
dispatch_key
=
current_platform
.
dispatch_key
,
dispatch_key
=
current_platform
.
dispatch_key
,
tags
=
(
torch
.
Tag
.
needs_fixed_stride_order
,
),
tags
=
(
torch
.
Tag
.
needs_fixed_stride_order
,
),
...
...
vllm/model_executor/layers/fused_moe/fused_moe.py
View file @
d2e57a90
...
@@ -1257,13 +1257,15 @@ def inplace_fused_experts(hidden_states: torch.Tensor,
...
@@ -1257,13 +1257,15 @@ def inplace_fused_experts(hidden_states: torch.Tensor,
a1_scale
:
Optional
[
torch
.
Tensor
]
=
None
,
a1_scale
:
Optional
[
torch
.
Tensor
]
=
None
,
a2_scale
:
Optional
[
torch
.
Tensor
]
=
None
,
a2_scale
:
Optional
[
torch
.
Tensor
]
=
None
,
block_shape
:
Optional
[
List
[
int
]]
=
None
,
block_shape
:
Optional
[
List
[
int
]]
=
None
,
use_nn_moe
:
Optional
[
bool
]
=
False
)
->
None
:
use_nn_moe
:
Optional
[
bool
]
=
False
,
num_local_tokens
:
Optional
[
torch
.
Tensor
]
=
None
,
true_bs
:
Optional
[
int
]
=
None
,)
->
None
:
fused_experts_impl
(
hidden_states
,
w1
,
w2
,
topk_weights
,
topk_ids
,
True
,
fused_experts_impl
(
hidden_states
,
w1
,
w2
,
topk_weights
,
topk_ids
,
True
,
activation
,
apply_router_weight_on_input
,
use_fp8_w8a8
,
activation
,
apply_router_weight_on_input
,
use_fp8_w8a8
,
use_int8_w8a8
,
use_int8_w8a16
,
use_int4_w4a16
,
use_int4_w4a8
,
use_int8_w8a8
,
use_int8_w8a16
,
use_int4_w4a16
,
use_int4_w4a8
,
per_channel_quant
,
global_num_experts
,
expert_map
,
per_channel_quant
,
global_num_experts
,
expert_map
,
w1_scale
,
w2_scale
,
w1_zp
,
w2_zp
,
a1_scale
,
a2_scale
,
w1_scale
,
w2_scale
,
w1_zp
,
w2_zp
,
a1_scale
,
a2_scale
,
block_shape
,
use_nn_moe
)
block_shape
,
use_nn_moe
,
num_local_tokens
,
true_bs
)
def
inplace_fused_experts_fake
(
def
inplace_fused_experts_fake
(
...
@@ -1289,7 +1291,9 @@ def inplace_fused_experts_fake(
...
@@ -1289,7 +1291,9 @@ def inplace_fused_experts_fake(
a1_scale
:
Optional
[
torch
.
Tensor
]
=
None
,
a1_scale
:
Optional
[
torch
.
Tensor
]
=
None
,
a2_scale
:
Optional
[
torch
.
Tensor
]
=
None
,
a2_scale
:
Optional
[
torch
.
Tensor
]
=
None
,
block_shape
:
Optional
[
List
[
int
]]
=
None
,
block_shape
:
Optional
[
List
[
int
]]
=
None
,
use_nn_moe
:
Optional
[
bool
]
=
False
)
->
None
:
use_nn_moe
:
Optional
[
bool
]
=
False
,
num_local_tokens
:
Optional
[
torch
.
Tensor
]
=
None
,
true_bs
:
Optional
[
int
]
=
None
,)
->
None
:
pass
pass
...
@@ -1325,14 +1329,16 @@ def outplace_fused_experts(
...
@@ -1325,14 +1329,16 @@ def outplace_fused_experts(
a1_scale
:
Optional
[
torch
.
Tensor
]
=
None
,
a1_scale
:
Optional
[
torch
.
Tensor
]
=
None
,
a2_scale
:
Optional
[
torch
.
Tensor
]
=
None
,
a2_scale
:
Optional
[
torch
.
Tensor
]
=
None
,
block_shape
:
Optional
[
List
[
int
]]
=
None
,
block_shape
:
Optional
[
List
[
int
]]
=
None
,
use_nn_moe
:
Optional
[
bool
]
=
False
)
->
torch
.
Tensor
:
use_nn_moe
:
Optional
[
bool
]
=
False
,
num_local_tokens
:
Optional
[
torch
.
Tensor
]
=
None
,
true_bs
:
Optional
[
int
]
=
None
,)
->
torch
.
Tensor
:
return
fused_experts_impl
(
hidden_states
,
w1
,
w2
,
topk_weights
,
topk_ids
,
return
fused_experts_impl
(
hidden_states
,
w1
,
w2
,
topk_weights
,
topk_ids
,
False
,
activation
,
apply_router_weight_on_input
,
False
,
activation
,
apply_router_weight_on_input
,
use_fp8_w8a8
,
use_int8_w8a8
,
use_int8_w8a16
,
use_fp8_w8a8
,
use_int8_w8a8
,
use_int8_w8a16
,
use_int4_w4a16
,
use_int4_w4a8
,
per_channel_quant
,
use_int4_w4a16
,
use_int4_w4a8
,
per_channel_quant
,
global_num_experts
,
expert_map
,
w1_scale
,
global_num_experts
,
expert_map
,
w1_scale
,
w2_scale
,
w1_zp
,
w2_zp
,
a1_scale
,
a2_scale
,
w2_scale
,
w1_zp
,
w2_zp
,
a1_scale
,
a2_scale
,
block_shape
,
use_nn_moe
)
block_shape
,
use_nn_moe
,
num_local_tokens
,
true_bs
)
def
outplace_fused_experts_fake
(
def
outplace_fused_experts_fake
(
...
@@ -1357,7 +1363,9 @@ def outplace_fused_experts_fake(
...
@@ -1357,7 +1363,9 @@ def outplace_fused_experts_fake(
a1_scale
:
Optional
[
torch
.
Tensor
]
=
None
,
a1_scale
:
Optional
[
torch
.
Tensor
]
=
None
,
a2_scale
:
Optional
[
torch
.
Tensor
]
=
None
,
a2_scale
:
Optional
[
torch
.
Tensor
]
=
None
,
block_shape
:
Optional
[
List
[
int
]]
=
None
,
block_shape
:
Optional
[
List
[
int
]]
=
None
,
use_nn_moe
:
Optional
[
bool
]
=
False
)
->
torch
.
Tensor
:
use_nn_moe
:
Optional
[
bool
]
=
False
,
num_local_tokens
:
Optional
[
torch
.
Tensor
]
=
None
,
true_bs
:
Optional
[
int
]
=
None
,)
->
torch
.
Tensor
:
return
torch
.
empty_like
(
hidden_states
)
return
torch
.
empty_like
(
hidden_states
)
...
@@ -1414,7 +1422,9 @@ def fused_experts(
...
@@ -1414,7 +1422,9 @@ def fused_experts(
block_shape
:
Optional
[
List
[
int
]]
=
None
,
block_shape
:
Optional
[
List
[
int
]]
=
None
,
allow_deep_gemm
:
bool
=
False
,
allow_deep_gemm
:
bool
=
False
,
allow_cutlass_block_scaled_grouped_gemm
:
bool
=
False
,
allow_cutlass_block_scaled_grouped_gemm
:
bool
=
False
,
use_nn_moe
:
Optional
[
bool
]
=
False
)
->
torch
.
Tensor
:
use_nn_moe
:
Optional
[
bool
]
=
False
,
num_local_tokens
:
Optional
[
torch
.
Tensor
]
=
None
,
true_bs
:
Optional
[
int
]
=
None
,)
->
torch
.
Tensor
:
# For now, disable DeepGemm for small N (<= 512) until better
# For now, disable DeepGemm for small N (<= 512) until better
# permute/unpermute ops are available.
# permute/unpermute ops are available.
N
=
w1
.
size
(
1
)
N
=
w1
.
size
(
1
)
...
@@ -1472,7 +1482,9 @@ def fused_experts(
...
@@ -1472,7 +1482,9 @@ def fused_experts(
a1_scale
=
a1_scale
,
a1_scale
=
a1_scale
,
a2_scale
=
a2_scale
,
a2_scale
=
a2_scale
,
block_shape
=
block_shape
,
block_shape
=
block_shape
,
use_nn_moe
=
use_nn_moe
)
use_nn_moe
=
use_nn_moe
,
num_local_tokens
=
num_local_tokens
,
true_bs
=
true_bs
)
def
fused_experts_impl
(
def
fused_experts_impl
(
...
@@ -1500,6 +1512,8 @@ def fused_experts_impl(
...
@@ -1500,6 +1512,8 @@ def fused_experts_impl(
a2_scale
:
Optional
[
torch
.
Tensor
]
=
None
,
a2_scale
:
Optional
[
torch
.
Tensor
]
=
None
,
block_shape
:
Optional
[
List
[
int
]]
=
None
,
block_shape
:
Optional
[
List
[
int
]]
=
None
,
use_nn_moe
:
Optional
[
bool
]
=
False
,
use_nn_moe
:
Optional
[
bool
]
=
False
,
num_local_tokens
:
Optional
[
torch
.
Tensor
]
=
None
,
true_bs
:
Optional
[
int
]
=
None
,
)
->
torch
.
Tensor
:
)
->
torch
.
Tensor
:
num_tokens
=
hidden_states
.
size
(
0
)
num_tokens
=
hidden_states
.
size
(
0
)
if
use_nn_moe
:
if
use_nn_moe
:
...
@@ -1544,7 +1558,9 @@ def fused_experts_impl(
...
@@ -1544,7 +1558,9 @@ def fused_experts_impl(
a1_scale
=
a1_scale
,
a1_scale
=
a1_scale
,
a2_scale
=
a2_scale
,
a2_scale
=
a2_scale
,
block_shape
=
block_shape
,
block_shape
=
block_shape
,
use_nn_moe
=
False
use_nn_moe
=
False
,
num_local_tokens
=
num_local_tokens
,
true_bs
=
true_bs
,
)
)
elif
use_int4_w4a8
is
True
:
elif
use_int4_w4a8
is
True
:
return
fused_experts_impl_w4a8
(
hidden_states
=
hidden_states
,
return
fused_experts_impl_w4a8
(
hidden_states
=
hidden_states
,
...
...
vllm/model_executor/layers/fused_moe/moe_align_block_size.py
View file @
d2e57a90
...
@@ -152,7 +152,8 @@ def moe_align_block_size(
...
@@ -152,7 +152,8 @@ def moe_align_block_size(
num_experts
:
int
,
num_experts
:
int
,
expert_map
:
Optional
[
torch
.
Tensor
]
=
None
,
expert_map
:
Optional
[
torch
.
Tensor
]
=
None
,
pad_sorted_ids
:
bool
=
False
,
pad_sorted_ids
:
bool
=
False
,
num_token
:
Optional
[
int
]
=
None
num_token
:
Optional
[
int
]
=
None
,
num_local_tokens
:
Optional
[
torch
.
Tensor
]
=
None
,
)
->
Tuple
[
torch
.
Tensor
,
torch
.
Tensor
,
torch
.
Tensor
]:
)
->
Tuple
[
torch
.
Tensor
,
torch
.
Tensor
,
torch
.
Tensor
]:
"""
"""
Aligns the token distribution across experts to be compatible with block
Aligns the token distribution across experts to be compatible with block
...
@@ -234,7 +235,7 @@ def moe_align_block_size(
...
@@ -234,7 +235,7 @@ def moe_align_block_size(
if
envs
.
VLLM_USE_LIGHT_OP
:
if
envs
.
VLLM_USE_LIGHT_OP
:
op
.
moe_align_block_size
(
topk_ids
,
num_experts
,
block_size
,
sorted_ids
,
op
.
moe_align_block_size
(
topk_ids
,
num_experts
,
block_size
,
sorted_ids
,
expert_ids
,
num_tokens_post_pad
,
None
)
expert_ids
,
num_tokens_post_pad
,
expert_map
,
None
,
num_local_tokens
)
else
:
else
:
ops
.
moe_align_block_size
(
topk_ids
,
num_experts
,
block_size
,
sorted_ids
,
ops
.
moe_align_block_size
(
topk_ids
,
num_experts
,
block_size
,
sorted_ids
,
expert_ids
,
num_tokens_post_pad
)
expert_ids
,
num_tokens_post_pad
)
...
...
vllm/model_executor/layers/linear.py
View file @
d2e57a90
...
@@ -486,9 +486,13 @@ class ColumnParallelLinear(LinearBase):
...
@@ -486,9 +486,13 @@ class ColumnParallelLinear(LinearBase):
prefix
:
str
=
""
,
prefix
:
str
=
""
,
*
,
*
,
return_bias
:
bool
=
True
,
return_bias
:
bool
=
True
,
expect_tp_size
:
Optional
[
int
]
=
None
,
):
):
# Divide the weight matrix along the last dimension.
# Divide the weight matrix along the last dimension.
self
.
tp_size
=
get_tensor_model_parallel_world_size
()
self
.
tp_size
=
get_tensor_model_parallel_world_size
()
if
expect_tp_size
is
not
None
:
self
.
expect_tp_size
=
expect_tp_size
self
.
tp_size
=
self
.
expect_tp_size
self
.
input_size_per_partition
=
input_size
self
.
input_size_per_partition
=
input_size
self
.
output_size_per_partition
=
divide
(
output_size
,
self
.
tp_size
)
self
.
output_size_per_partition
=
divide
(
output_size
,
self
.
tp_size
)
self
.
output_partition_sizes
=
[
self
.
output_size_per_partition
]
self
.
output_partition_sizes
=
[
self
.
output_size_per_partition
]
...
@@ -728,10 +732,18 @@ class MergedColumnParallelLinear(ColumnParallelLinear):
...
@@ -728,10 +732,18 @@ class MergedColumnParallelLinear(ColumnParallelLinear):
prefix
:
str
=
""
,
prefix
:
str
=
""
,
*
,
*
,
return_bias
:
bool
=
True
,
return_bias
:
bool
=
True
,
expect_tp_size
:
Optional
[
int
]
=
None
,
):
):
self
.
eps
=
eps
self
.
eps
=
eps
self
.
output_sizes
=
output_sizes
self
.
output_sizes
=
output_sizes
tp_size
=
get_tensor_model_parallel_world_size
()
tp_size
=
get_tensor_model_parallel_world_size
()
if
expect_tp_size
is
not
None
:
tp_size
=
expect_tp_size
self
.
expect_tp_size
=
expect_tp_size
self
.
expect_tp_size
=
expect_tp_size
assert
all
(
output_size
%
tp_size
==
0
for
output_size
in
output_sizes
)
assert
all
(
output_size
%
tp_size
==
0
for
output_size
in
output_sizes
)
super
().
__init__
(
input_size
=
input_size
,
super
().
__init__
(
input_size
=
input_size
,
output_size
=
sum
(
output_sizes
),
output_size
=
sum
(
output_sizes
),
...
@@ -741,7 +753,8 @@ class MergedColumnParallelLinear(ColumnParallelLinear):
...
@@ -741,7 +753,8 @@ class MergedColumnParallelLinear(ColumnParallelLinear):
params_dtype
=
params_dtype
,
params_dtype
=
params_dtype
,
quant_config
=
quant_config
,
quant_config
=
quant_config
,
prefix
=
prefix
,
prefix
=
prefix
,
return_bias
=
return_bias
)
return_bias
=
return_bias
,
expect_tp_size
=
expect_tp_size
)
def
weight_loader
(
self
,
def
weight_loader
(
self
,
param
:
Parameter
,
param
:
Parameter
,
...
@@ -838,6 +851,11 @@ class MergedColumnParallelLinear(ColumnParallelLinear):
...
@@ -838,6 +851,11 @@ class MergedColumnParallelLinear(ColumnParallelLinear):
assert
loaded_shard_id
<
len
(
self
.
output_sizes
)
assert
loaded_shard_id
<
len
(
self
.
output_sizes
)
tp_rank
=
get_tensor_model_parallel_rank
()
tp_rank
=
get_tensor_model_parallel_rank
()
tp_size
=
get_tensor_model_parallel_world_size
()
tp_size
=
get_tensor_model_parallel_world_size
()
if
self
.
expect_tp_size
is
not
None
and
self
.
expect_tp_size
==
1
:
tp_rank
=
0
tp_size
=
1
if
output_dim
is
not
None
:
if
output_dim
is
not
None
:
shard_offset
=
sum
(
self
.
output_sizes
[:
loaded_shard_id
])
//
tp_size
shard_offset
=
sum
(
self
.
output_sizes
[:
loaded_shard_id
])
//
tp_size
shard_size
=
self
.
output_sizes
[
loaded_shard_id
]
//
tp_size
shard_size
=
self
.
output_sizes
[
loaded_shard_id
]
//
tp_size
...
@@ -1384,10 +1402,16 @@ class RowParallelLinear(LinearBase):
...
@@ -1384,10 +1402,16 @@ class RowParallelLinear(LinearBase):
prefix
:
str
=
""
,
prefix
:
str
=
""
,
*
,
*
,
return_bias
:
bool
=
True
,
return_bias
:
bool
=
True
,
expect_tp_size
:
Optional
[
int
]
=
None
,
):
):
# Divide the weight matrix along the first dimension.
# Divide the weight matrix along the first dimension.
self
.
tp_rank
=
get_tensor_model_parallel_rank
()
self
.
tp_rank
=
get_tensor_model_parallel_rank
()
self
.
tp_size
=
get_tensor_model_parallel_world_size
()
self
.
tp_size
=
get_tensor_model_parallel_world_size
()
if
expect_tp_size
is
not
None
:
self
.
tp_rank
=
0
self
.
tp_size
=
1
self
.
expect_tp_size
=
expect_tp_size
self
.
input_size_per_partition
=
divide
(
input_size
,
self
.
tp_size
)
self
.
input_size_per_partition
=
divide
(
input_size
,
self
.
tp_size
)
self
.
output_size_per_partition
=
output_size
self
.
output_size_per_partition
=
output_size
self
.
output_partition_sizes
=
[
output_size
]
self
.
output_partition_sizes
=
[
output_size
]
...
@@ -1433,6 +1457,10 @@ class RowParallelLinear(LinearBase):
...
@@ -1433,6 +1457,10 @@ class RowParallelLinear(LinearBase):
def
weight_loader
(
self
,
param
:
Parameter
,
loaded_weight
:
torch
.
Tensor
):
def
weight_loader
(
self
,
param
:
Parameter
,
loaded_weight
:
torch
.
Tensor
):
tp_rank
=
get_tensor_model_parallel_rank
()
tp_rank
=
get_tensor_model_parallel_rank
()
tp_size
=
get_tensor_model_parallel_world_size
()
tp_size
=
get_tensor_model_parallel_world_size
()
if
self
.
expect_tp_size
is
not
None
:
tp_rank
=
0
tp_size
=
1
input_dim
=
getattr
(
param
,
"input_dim"
,
None
)
input_dim
=
getattr
(
param
,
"input_dim"
,
None
)
use_bitsandbytes_4bit
=
getattr
(
param
,
"use_bitsandbytes_4bit"
,
False
)
use_bitsandbytes_4bit
=
getattr
(
param
,
"use_bitsandbytes_4bit"
,
False
)
is_sharded_weight
=
getattr
(
param
,
"is_sharded_weight"
,
False
)
is_sharded_weight
=
getattr
(
param
,
"is_sharded_weight"
,
False
)
...
...
vllm/model_executor/layers/quantization/slimquant_w4a8_marlin.py
View file @
d2e57a90
from
typing
import
Any
,
Callable
,
Dict
,
List
,
Optional
from
typing
import
Any
,
Callable
,
Dict
,
List
,
Optional
import
os
import
os
import
torch
import
torch
import
vllm.envs
as
envs
from
torch.nn.parameter
import
Parameter
from
vllm
import
_custom_ops
as
ops
from
vllm
import
_custom_ops
as
ops
from
vllm.model_executor.utils
import
set_weight_attrs
from
vllm.model_executor.utils
import
set_weight_attrs
from
vllm.distributed
import
get_tensor_model_parallel_world_size
from
vllm.distributed
import
get_tensor_model_parallel_world_size
from
torch.nn.paramet
er
import
Paramet
er
from
vllm.logg
er
import
init_logg
er
from
vllm.model_executor.layers.quantization
import
QuantizationMethods
from
vllm.model_executor.layers.quantization
import
QuantizationMethods
from
vllm.model_executor.layers.linear
import
(
LinearBase
,
LinearMethodBase
)
from
vllm.model_executor.layers.linear
import
(
LinearBase
,
LinearMethodBase
)
from
vllm.model_executor.layers.quantization.base_config
import
(
QuantizationConfig
,
from
vllm.model_executor.layers.quantization.base_config
import
(
QuantizationConfig
,
...
@@ -16,6 +18,10 @@ from vllm.model_executor.layers.fused_moe import (FusedMoE, FusedMoEMethodBase,
...
@@ -16,6 +18,10 @@ from vllm.model_executor.layers.fused_moe import (FusedMoE, FusedMoEMethodBase,
from
vllm.model_executor.parameter
import
(
ChannelQuantScaleParameter
,
from
vllm.model_executor.parameter
import
(
ChannelQuantScaleParameter
,
ModelWeightParameter
)
ModelWeightParameter
)
from
vllm.model_executor.layers.quantization.slimquant_w4a8
import
SlimQuantW4A8Int8LinearMethod
from
vllm.model_executor.layers.quantization.slimquant_w4a8
import
SlimQuantW4A8Int8LinearMethod
from
vllm.model_executor.layers.fused_moe
import
(
FusedMoE
,
FusedMoEActivationFormat
,
FusedMoEConfig
,
FusedMoEMethodBase
,
FusedMoEPermuteExpertsUnpermute
,
FusedMoEPrepareAndFinalize
,
FusedMoeWeightScaleSupported
)
try
:
try
:
from
lmslim.layers.fused_moe.fuse_moe_w4a8_marlin
import
fused_experts_impl_w4a8_marlin
from
lmslim.layers.fused_moe.fuse_moe_w4a8_marlin
import
fused_experts_impl_w4a8_marlin
...
@@ -23,6 +29,9 @@ except Exception:
...
@@ -23,6 +29,9 @@ except Exception:
print
(
"INFO: Please install lmslim if you want to infer the quantitative model of moe.
\n
"
)
print
(
"INFO: Please install lmslim if you want to infer the quantitative model of moe.
\n
"
)
logger
=
init_logger
(
__name__
)
class
MarlinMoeWorkspace
:
class
MarlinMoeWorkspace
:
"""
"""
Singleton manager for device-specific workspace buffers used by w4a8 Marlin-MoE.
Singleton manager for device-specific workspace buffers used by w4a8 Marlin-MoE.
...
@@ -220,6 +229,10 @@ class SlimQuantW4A8Int8MarlinMoEMethod:
...
@@ -220,6 +229,10 @@ class SlimQuantW4A8Int8MarlinMoEMethod:
apply_router_weight_on_input
:
bool
=
False
,
apply_router_weight_on_input
:
bool
=
False
,
activation
:
str
=
"silu"
,
activation
:
str
=
"silu"
,
use_nn_moe
:
Optional
[
bool
]
=
False
,
use_nn_moe
:
Optional
[
bool
]
=
False
,
num_local_tokens
:
Optional
[
torch
.
Tensor
]
=
None
,
config_select_bs
:
Optional
[
int
]
=
None
,
routed_scaling_factor
:
Optional
[
float
]
=
None
,
scales
:
Optional
[
torch
.
Tensor
]
=
None
,
**
_
**
_
)
->
torch
.
Tensor
:
)
->
torch
.
Tensor
:
workspace
,
global_reduce_buffer
=
MarlinMoeWorkspace
(
x
.
device
).
get_buffers
()
workspace
,
global_reduce_buffer
=
MarlinMoeWorkspace
(
x
.
device
).
get_buffers
()
...
@@ -243,6 +256,9 @@ class SlimQuantW4A8Int8MarlinMoEMethod:
...
@@ -243,6 +256,9 @@ class SlimQuantW4A8Int8MarlinMoEMethod:
a1_scale
=
layer
.
w13_input_scale
,
a1_scale
=
layer
.
w13_input_scale
,
a2_scale
=
layer
.
w2_input_scale
,
a2_scale
=
layer
.
w2_input_scale
,
use_nn_moe
=
use_nn_moe
,
use_nn_moe
=
use_nn_moe
,
num_local_tokens
=
num_local_tokens
,
config_select_bs
=
config_select_bs
,
q_scales
=
scales
)
)
def
apply
(
def
apply
(
...
@@ -309,3 +325,43 @@ class SlimQuantW4A8Int8MarlinMoEMethod:
...
@@ -309,3 +325,43 @@ class SlimQuantW4A8Int8MarlinMoEMethod:
a2_scale
=
layer
.
w2_input_scale
,
a2_scale
=
layer
.
w2_input_scale
,
use_nn_moe
=
use_nn_moe
,
use_nn_moe
=
use_nn_moe
,
)
)
def
select_gemm_impl
(
self
,
prepare_finalize
:
FusedMoEPrepareAndFinalize
,
moe
:
FusedMoEConfig
,
)
->
FusedMoEPermuteExpertsUnpermute
:
from
vllm.model_executor.layers.fused_moe
import
(
BatchedGroupedGemmExperts
,
GroupedGemmGemmExperts
)
assert
not
self
.
rocm_aiter_moe_enabled
,
(
"ROCm AITER are not supported with all2all yet."
)
if
(
prepare_finalize
.
activation_format
==
FusedMoEActivationFormat
.
BatchedExperts
):
max_num_tokens_per_rank
=
(
prepare_finalize
.
max_num_tokens_per_rank
())
assert
max_num_tokens_per_rank
is
not
None
logger
.
debug
(
"BatchedGroupedGemmExperts(%s): "
"max_tokens_per_rank=%s, block_size=%s, per_act_token=%s"
,
self
.
__class__
.
__name__
,
max_num_tokens_per_rank
,
self
.
quant_config
.
weight_block_size
,
False
)
return
BatchedGroupedGemmExperts
(
max_num_tokens
=
max_num_tokens_per_rank
,
num_dispatchers
=
prepare_finalize
.
num_dispatchers
(),
use_fp8_w8a8
=
False
,
block_shape
=
self
.
quant_config
.
weight_block_size
,
per_act_token_quant
=
True
,
allow_deep_gemm
=
False
,
)
else
:
logger
.
debug
(
"GroupedGemmGemmExperts(%s): block_size=%s, per_act_token=%s"
,
self
.
__class__
.
__name__
,
self
.
quant_config
.
weight_block_size
,
False
)
return
GroupedGemmGemmExperts
(
use_fp8_w8a8
=
False
,
block_shape
=
self
.
quant_config
.
weight_block_size
,
allow_deep_gemm
=
False
,
)
vllm/model_executor/models/deepseek_mtp.py
View file @
d2e57a90
...
@@ -178,7 +178,7 @@ class DeepSeekMTP(nn.Module, SupportsPP):
...
@@ -178,7 +178,7 @@ class DeepSeekMTP(nn.Module, SupportsPP):
parallel_config
=
vllm_config
.
parallel_config
parallel_config
=
vllm_config
.
parallel_config
dp_size
=
get_dp_group
().
world_size
dp_size
=
get_dp_group
().
world_size
self
.
use_
all2all
_ep
=
envs
.
VLLM_USE_
ALLTOALL
_EP
and
dp_size
>
1
and
parallel_config
.
enable_expert_parallel
self
.
use_
mori
_ep
=
envs
.
VLLM_USE_
MORI
_EP
and
dp_size
>
1
and
parallel_config
.
enable_expert_parallel
def
forward
(
def
forward
(
...
@@ -211,7 +211,7 @@ class DeepSeekMTP(nn.Module, SupportsPP):
...
@@ -211,7 +211,7 @@ class DeepSeekMTP(nn.Module, SupportsPP):
(
"gate_up_proj"
,
"up_proj"
,
1
),
(
"gate_up_proj"
,
"up_proj"
,
1
),
]
]
if
self
.
use_
all2all
_ep
:
if
self
.
use_
mori
_ep
:
ep_moe_shared_experts_keys
=
"mlp.shared_experts"
ep_moe_shared_experts_keys
=
"mlp.shared_experts"
ep_moe_shared_experts_mapping
=
{
ep_moe_shared_experts_keys
:
"mlp.experts.shared_experts"
}
ep_moe_shared_experts_mapping
=
{
ep_moe_shared_experts_keys
:
"mlp.experts.shared_experts"
}
...
@@ -244,7 +244,7 @@ class DeepSeekMTP(nn.Module, SupportsPP):
...
@@ -244,7 +244,7 @@ class DeepSeekMTP(nn.Module, SupportsPP):
continue
continue
name
=
name
.
replace
(
weight_name
,
param_name
)
name
=
name
.
replace
(
weight_name
,
param_name
)
if
self
.
use_
all2all
_ep
:
if
self
.
use_
mori
_ep
:
name
=
name
.
replace
(
ep_moe_shared_experts_keys
,
ep_moe_shared_experts_mapping
[
ep_moe_shared_experts_keys
])
name
=
name
.
replace
(
ep_moe_shared_experts_keys
,
ep_moe_shared_experts_mapping
[
ep_moe_shared_experts_keys
])
# Skip loading extra bias for GPTQ models.
# Skip loading extra bias for GPTQ models.
if
name
.
endswith
(
".bias"
)
and
name
not
in
params_dict
:
if
name
.
endswith
(
".bias"
)
and
name
not
in
params_dict
:
...
@@ -261,7 +261,7 @@ class DeepSeekMTP(nn.Module, SupportsPP):
...
@@ -261,7 +261,7 @@ class DeepSeekMTP(nn.Module, SupportsPP):
continue
continue
name
=
name
.
replace
(
weight_name
,
param_name
)
name
=
name
.
replace
(
weight_name
,
param_name
)
if
self
.
use_
all2all
_ep
:
if
self
.
use_
mori
_ep
:
name
=
name
.
replace
(
ep_moe_shared_experts_keys
,
ep_moe_shared_experts_mapping
[
ep_moe_shared_experts_keys
])
name
=
name
.
replace
(
ep_moe_shared_experts_keys
,
ep_moe_shared_experts_mapping
[
ep_moe_shared_experts_keys
])
param
=
params_dict
[
name
]
param
=
params_dict
[
name
]
...
@@ -273,7 +273,7 @@ class DeepSeekMTP(nn.Module, SupportsPP):
...
@@ -273,7 +273,7 @@ class DeepSeekMTP(nn.Module, SupportsPP):
expert_id
=
expert_id
)
expert_id
=
expert_id
)
break
break
else
:
else
:
if
self
.
use_
all2all
_ep
:
if
self
.
use_
mori
_ep
:
name
=
name
.
replace
(
ep_moe_shared_experts_keys
,
ep_moe_shared_experts_mapping
[
ep_moe_shared_experts_keys
])
name
=
name
.
replace
(
ep_moe_shared_experts_keys
,
ep_moe_shared_experts_mapping
[
ep_moe_shared_experts_keys
])
# Skip loading extra bias for GPTQ models.
# Skip loading extra bias for GPTQ models.
if
name
.
endswith
(
".bias"
)
and
name
not
in
params_dict
:
if
name
.
endswith
(
".bias"
)
and
name
not
in
params_dict
:
...
...
vllm/model_executor/models/deepseek_v2.py
View file @
d2e57a90
...
@@ -165,9 +165,9 @@ class DeepseekV2MoE(nn.Module):
...
@@ -165,9 +165,9 @@ class DeepseekV2MoE(nn.Module):
self
.
n_local_physical_experts
)
self
.
n_local_physical_experts
)
dp_size
=
get_dp_group
().
world_size
dp_size
=
get_dp_group
().
world_size
self
.
use_
all2all
_ep
=
envs
.
VLLM_USE_
ALLTOALL
_EP
and
dp_size
>
1
and
parallel_config
.
enable_expert_parallel
self
.
use_
mori
_ep
=
envs
.
VLLM_USE_
MORI
_EP
and
dp_size
>
1
and
parallel_config
.
enable_expert_parallel
moe_cls
=
FusedMoE
if
not
self
.
use_
all2all
_ep
else
EPMoE
moe_cls
=
FusedMoE
if
not
self
.
use_
mori
_ep
else
EPMoE
self
.
experts
=
moe_cls
(
self
.
experts
=
moe_cls
(
num_experts
=
config
.
n_routed_experts
,
num_experts
=
config
.
n_routed_experts
,
top_k
=
config
.
num_experts_per_tok
,
top_k
=
config
.
num_experts_per_tok
,
...
@@ -189,8 +189,8 @@ class DeepseekV2MoE(nn.Module):
...
@@ -189,8 +189,8 @@ class DeepseekV2MoE(nn.Module):
if
config
.
n_shared_experts
is
not
None
:
if
config
.
n_shared_experts
is
not
None
:
intermediate_size
=
(
config
.
moe_intermediate_size
*
intermediate_size
=
(
config
.
moe_intermediate_size
*
config
.
n_shared_experts
)
config
.
n_shared_experts
)
#
shared_expert_cls = DeepseekV2MLP if not self.use_
all2all
_ep else EPSharedExperts
shared_expert_cls
=
DeepseekV2MLP
if
not
self
.
use_
mori
_ep
else
EPSharedExperts
self
.
shared_experts
=
DeepseekV2MLP
(
self
.
shared_experts
=
shared_expert_cls
(
hidden_size
=
config
.
hidden_size
,
hidden_size
=
config
.
hidden_size
,
intermediate_size
=
intermediate_size
,
intermediate_size
=
intermediate_size
,
hidden_act
=
config
.
hidden_act
,
hidden_act
=
config
.
hidden_act
,
...
@@ -199,7 +199,7 @@ class DeepseekV2MoE(nn.Module):
...
@@ -199,7 +199,7 @@ class DeepseekV2MoE(nn.Module):
),
),
prefix
=
f
"
{
prefix
}
.shared_experts"
,
prefix
=
f
"
{
prefix
}
.shared_experts"
,
)
)
if
self
.
use_
all2all
_ep
:
if
self
.
use_
mori
_ep
:
self
.
experts
.
set_shared_experts
(
self
.
shared_experts
)
self
.
experts
.
set_shared_experts
(
self
.
shared_experts
)
from
vllm.two_batch_overlap.two_batch_overlap
import
tbo_all_reduce
from
vllm.two_batch_overlap.two_batch_overlap
import
tbo_all_reduce
...
@@ -212,7 +212,7 @@ class DeepseekV2MoE(nn.Module):
...
@@ -212,7 +212,7 @@ class DeepseekV2MoE(nn.Module):
num_tokens
,
hidden_dim
=
hidden_states
.
shape
num_tokens
,
hidden_dim
=
hidden_states
.
shape
hidden_states
=
hidden_states
.
view
(
-
1
,
hidden_dim
)
hidden_states
=
hidden_states
.
view
(
-
1
,
hidden_dim
)
if
not
self
.
use_
all2all
_ep
:
if
not
self
.
use_
mori
_ep
:
if
self
.
n_shared_experts
is
not
None
:
if
self
.
n_shared_experts
is
not
None
:
if
envs
.
USE_FUSED_RMS_QUANT
:
if
envs
.
USE_FUSED_RMS_QUANT
:
shared_output
,
new_resi
=
self
.
shared_experts
(
hidden_states
,
rms_weight
,
residual
,
update_hd
=
True
)
shared_output
,
new_resi
=
self
.
shared_experts
(
hidden_states
,
rms_weight
,
residual
,
update_hd
=
True
)
...
@@ -222,7 +222,7 @@ class DeepseekV2MoE(nn.Module):
...
@@ -222,7 +222,7 @@ class DeepseekV2MoE(nn.Module):
# router_logits: (num_tokens, n_experts)
# router_logits: (num_tokens, n_experts)
router_logits
,
_
=
self
.
gate
(
hidden_states
)
router_logits
,
_
=
self
.
gate
(
hidden_states
)
if
not
self
.
use_
all2all
_ep
:
if
not
self
.
use_
mori
_ep
:
if
hidden_states
.
dtype
!=
torch
.
float16
:
if
hidden_states
.
dtype
!=
torch
.
float16
:
final_hidden_states
=
self
.
experts
(
final_hidden_states
=
self
.
experts
(
hidden_states
=
hidden_states
,
hidden_states
=
hidden_states
,
...
@@ -233,10 +233,10 @@ class DeepseekV2MoE(nn.Module):
...
@@ -233,10 +233,10 @@ class DeepseekV2MoE(nn.Module):
final_hidden_states
=
self
.
experts
(
hidden_states
=
hidden_states
,
final_hidden_states
=
self
.
experts
(
hidden_states
=
hidden_states
,
router_logits
=
router_logits
)
router_logits
=
router_logits
)
else
:
else
:
final_hidden_states
,
new_resi
=
self
.
experts
(
hidden_states
=
hidden_states
,
final_hidden_states
=
self
.
experts
(
hidden_states
=
hidden_states
,
router_logits
=
router_logits
)
router_logits
=
router_logits
)
if
not
self
.
use_
all2all
_ep
:
if
not
self
.
use_
mori
_ep
:
if
shared_output
is
not
None
:
if
shared_output
is
not
None
:
if
hidden_states
.
dtype
!=
torch
.
float16
or
self
.
dpsk_fp16_quick
:
if
hidden_states
.
dtype
!=
torch
.
float16
or
self
.
dpsk_fp16_quick
:
final_hidden_states
=
final_hidden_states
+
shared_output
final_hidden_states
=
final_hidden_states
+
shared_output
...
@@ -917,7 +917,7 @@ class DeepseekV2ForCausalLM(nn.Module, SupportsPP, MixtureOfExperts):
...
@@ -917,7 +917,7 @@ class DeepseekV2ForCausalLM(nn.Module, SupportsPP, MixtureOfExperts):
parallel_config
=
vllm_config
.
parallel_config
parallel_config
=
vllm_config
.
parallel_config
dp_size
=
get_dp_group
().
world_size
dp_size
=
get_dp_group
().
world_size
self
.
use_
all2all
_ep
=
envs
.
VLLM_USE_
ALLTOALL
_EP
and
dp_size
>
1
and
parallel_config
.
enable_expert_parallel
self
.
use_
mori
_ep
=
envs
.
VLLM_USE_
MORI
_EP
and
dp_size
>
1
and
parallel_config
.
enable_expert_parallel
def
set_eplb_state
(
def
set_eplb_state
(
self
,
self
,
...
@@ -1000,7 +1000,7 @@ class DeepseekV2ForCausalLM(nn.Module, SupportsPP, MixtureOfExperts):
...
@@ -1000,7 +1000,7 @@ class DeepseekV2ForCausalLM(nn.Module, SupportsPP, MixtureOfExperts):
(
"gate_up_proj"
,
"up_proj"
,
1
),
(
"gate_up_proj"
,
"up_proj"
,
1
),
]
]
if
self
.
use_
all2all
_ep
:
if
self
.
use_
mori
_ep
:
ep_moe_shared_experts_keys
=
"mlp.shared_experts"
ep_moe_shared_experts_keys
=
"mlp.shared_experts"
ep_moe_shared_experts_mapping
=
{
ep_moe_shared_experts_keys
:
"mlp.experts.shared_experts"
}
ep_moe_shared_experts_mapping
=
{
ep_moe_shared_experts_keys
:
"mlp.experts.shared_experts"
}
...
@@ -1037,7 +1037,7 @@ class DeepseekV2ForCausalLM(nn.Module, SupportsPP, MixtureOfExperts):
...
@@ -1037,7 +1037,7 @@ class DeepseekV2ForCausalLM(nn.Module, SupportsPP, MixtureOfExperts):
continue
continue
name
=
name
.
replace
(
weight_name
,
param_name
)
name
=
name
.
replace
(
weight_name
,
param_name
)
if
self
.
use_
all2all
_ep
:
if
self
.
use_
mori
_ep
:
name
=
name
.
replace
(
ep_moe_shared_experts_keys
,
ep_moe_shared_experts_mapping
[
ep_moe_shared_experts_keys
])
name
=
name
.
replace
(
ep_moe_shared_experts_keys
,
ep_moe_shared_experts_mapping
[
ep_moe_shared_experts_keys
])
# Skip loading extra bias for GPTQ models.
# Skip loading extra bias for GPTQ models.
...
@@ -1066,7 +1066,7 @@ class DeepseekV2ForCausalLM(nn.Module, SupportsPP, MixtureOfExperts):
...
@@ -1066,7 +1066,7 @@ class DeepseekV2ForCausalLM(nn.Module, SupportsPP, MixtureOfExperts):
# Instead, create a new variable
# Instead, create a new variable
name_mapped
=
name
.
replace
(
weight_name
,
param_name
)
name_mapped
=
name
.
replace
(
weight_name
,
param_name
)
if
self
.
use_
all2all
_ep
:
if
self
.
use_
mori
_ep
:
name_mapped
=
name_mapped
.
replace
(
ep_moe_shared_experts_keys
,
ep_moe_shared_experts_mapping
[
ep_moe_shared_experts_keys
])
name_mapped
=
name_mapped
.
replace
(
ep_moe_shared_experts_keys
,
ep_moe_shared_experts_mapping
[
ep_moe_shared_experts_keys
])
if
is_pp_missing_parameter
(
name_mapped
,
self
):
if
is_pp_missing_parameter
(
name_mapped
,
self
):
...
@@ -1094,7 +1094,7 @@ class DeepseekV2ForCausalLM(nn.Module, SupportsPP, MixtureOfExperts):
...
@@ -1094,7 +1094,7 @@ class DeepseekV2ForCausalLM(nn.Module, SupportsPP, MixtureOfExperts):
# So we simply skip it
# So we simply skip it
continue
continue
if
self
.
use_
all2all
_ep
:
if
self
.
use_
mori
_ep
:
name
=
name
.
replace
(
ep_moe_shared_experts_keys
,
ep_moe_shared_experts_mapping
[
ep_moe_shared_experts_keys
])
name
=
name
.
replace
(
ep_moe_shared_experts_keys
,
ep_moe_shared_experts_mapping
[
ep_moe_shared_experts_keys
])
# Skip loading extra bias for GPTQ models.
# Skip loading extra bias for GPTQ models.
if
name
.
endswith
(
".bias"
)
and
name
not
in
params_dict
:
if
name
.
endswith
(
".bias"
)
and
name
not
in
params_dict
:
...
...
vllm/v1/worker/gpu_model_runner.py
View file @
d2e57a90
...
@@ -320,7 +320,7 @@ class GPUModelRunner(LoRAModelRunnerMixin):
...
@@ -320,7 +320,7 @@ class GPUModelRunner(LoRAModelRunnerMixin):
self
.
shared_kv_cache_layers
:
dict
[
str
,
str
]
=
{}
self
.
shared_kv_cache_layers
:
dict
[
str
,
str
]
=
{}
dp_size
=
self
.
vllm_config
.
parallel_config
.
data_parallel_size
dp_size
=
self
.
vllm_config
.
parallel_config
.
data_parallel_size
self
.
use_
all2all
_ep
=
envs
.
VLLM_USE_
ALLTOALL
_EP
and
dp_size
>
1
and
parallel_config
.
enable_expert_parallel
self
.
use_
mori
_ep
=
envs
.
VLLM_USE_
MORI
_EP
and
dp_size
>
1
and
parallel_config
.
enable_expert_parallel
def
_may_reorder_batch
(
self
,
scheduler_output
:
"SchedulerOutput"
)
->
None
:
def
_may_reorder_batch
(
self
,
scheduler_output
:
"SchedulerOutput"
)
->
None
:
"""
"""
...
@@ -1234,7 +1234,7 @@ class GPUModelRunner(LoRAModelRunnerMixin):
...
@@ -1234,7 +1234,7 @@ class GPUModelRunner(LoRAModelRunnerMixin):
# TODO(tms) : There are many cases where padding is enabled for
# TODO(tms) : There are many cases where padding is enabled for
# prefills, causing unnecessary and excessive padding of activations.
# prefills, causing unnecessary and excessive padding of activations.
if
dp_size
==
1
or
self
.
vllm_config
.
model_config
.
enforce_eager
or
self
.
use_
all2all
_ep
:
if
dp_size
==
1
or
self
.
vllm_config
.
model_config
.
enforce_eager
or
self
.
use_
mori
_ep
:
# Early exit.
# Early exit.
return
0
,
None
return
0
,
None
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment