Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
98b7432a
Commit
98b7432a
authored
Oct 27, 2025
by
王敏
Browse files
[feat]w4a8适配deepep ht模式,解决开启dp时mtp>1时卡住问题
parent
1f4b9553
Changes
19
Hide whitespace changes
Inline
Side-by-side
Showing
19 changed files
with
565 additions
and
108 deletions
+565
-108
vllm/distributed/device_communicators/all2all.py
vllm/distributed/device_communicators/all2all.py
+15
-5
vllm/distributed/parallel_state.py
vllm/distributed/parallel_state.py
+1
-1
vllm/envs.py
vllm/envs.py
+1
-6
vllm/forward_context.py
vllm/forward_context.py
+2
-2
vllm/model_executor/layers/fused_moe/__init__.py
vllm/model_executor/layers/fused_moe/__init__.py
+3
-0
vllm/model_executor/layers/fused_moe/deepep_ht_prepare_finalize.py
...l_executor/layers/fused_moe/deepep_ht_prepare_finalize.py
+11
-2
vllm/model_executor/layers/fused_moe/deepep_ll_prepare_finalize.py
...l_executor/layers/fused_moe/deepep_ll_prepare_finalize.py
+2
-1
vllm/model_executor/layers/fused_moe/ep_moe/layer.py
vllm/model_executor/layers/fused_moe/ep_moe/layer.py
+3
-47
vllm/model_executor/layers/fused_moe/fused_batched_moe.py
vllm/model_executor/layers/fused_moe/fused_batched_moe.py
+1
-0
vllm/model_executor/layers/fused_moe/layer.py
vllm/model_executor/layers/fused_moe/layer.py
+15
-7
vllm/model_executor/layers/fused_moe/modular_kernel.py
vllm/model_executor/layers/fused_moe/modular_kernel.py
+305
-0
vllm/model_executor/layers/fused_moe/pplx_prepare_finalize.py
.../model_executor/layers/fused_moe/pplx_prepare_finalize.py
+1
-0
vllm/model_executor/layers/fused_moe/prepare_finalize.py
vllm/model_executor/layers/fused_moe/prepare_finalize.py
+1
-0
vllm/model_executor/layers/fused_moe/triton_group_gemm_moe.py
.../model_executor/layers/fused_moe/triton_group_gemm_moe.py
+107
-0
vllm/model_executor/layers/quantization/slimquant_w4a8_marlin.py
...del_executor/layers/quantization/slimquant_w4a8_marlin.py
+69
-24
vllm/model_executor/models/deepseek_mtp.py
vllm/model_executor/models/deepseek_mtp.py
+1
-1
vllm/model_executor/models/deepseek_v2.py
vllm/model_executor/models/deepseek_v2.py
+21
-6
vllm/v1/spec_decode/eagle.py
vllm/v1/spec_decode/eagle.py
+5
-2
vllm/v1/worker/gpu_model_runner.py
vllm/v1/worker/gpu_model_runner.py
+1
-4
No files found.
vllm/distributed/device_communicators/all2all.py
View file @
98b7432a
...
...
@@ -140,7 +140,7 @@ class DeepEPAll2AllManagerBase(All2AllManagerBase):
# This is the DeepEP default. Stick to it till we can establish
# reasonable defaults based on profiling.
self
.
num_sms
=
20
self
.
num_sms
=
24
#
20
def
get_handle
(
self
,
kwargs
):
raise
NotImplementedError
...
...
@@ -166,13 +166,21 @@ class DeepEPHTAll2AllManager(DeepEPAll2AllManagerBase):
def
_make_all2all_kwargs
(
self
)
->
dict
[
Any
,
Any
]:
# Defaults for internode and intranode are taken from DeepEP tests.
num_nvl_bytes
=
1024
*
1024
*
1024
num_nvl_bytes
=
int
(
2e9
/
2
)
#
1024 * 1024 * 1024
num_rdma_bytes
=
None
num_qps_per_rank
=
None
if
self
.
internode
:
num_rdma_bytes
=
1024
*
1024
*
1024
num_qps_per_rank
=
self
.
num_sms
//
2
num_rdma_bytes
=
int
(
1e9
/
2
)
#1024 * 1024 * 1024
num_qps_per_rank
=
30
#self.num_sms // 2
import
deep_ep
num_nvl_bytes
,
num_rdma_bytes
=
0
,
0
hidden_size
=
7168
hidden_bytes
=
hidden_size
*
2
for
config
in
(
deep_ep
.
Buffer
.
get_dispatch_config
(
self
.
cpu_group
.
size
()),
deep_ep
.
Buffer
.
get_combine_config
(
self
.
cpu_group
.
size
())):
num_nvl_bytes
=
max
(
config
.
get_nvl_buffer_size_hint
(
hidden_bytes
,
self
.
cpu_group
.
size
()),
num_nvl_bytes
)
num_rdma_bytes
=
max
(
config
.
get_rdma_buffer_size_hint
(
hidden_bytes
,
self
.
cpu_group
.
size
()),
num_rdma_bytes
)
else
:
num_rdma_bytes
=
0
num_qps_per_rank
=
1
...
...
@@ -183,7 +191,9 @@ class DeepEPHTAll2AllManager(DeepEPAll2AllManagerBase):
num_nvl_bytes
=
num_nvl_bytes
,
num_rdma_bytes
=
num_rdma_bytes
,
low_latency_mode
=
False
,
num_qps_per_rank
=
num_qps_per_rank
)
num_qps_per_rank
=
num_qps_per_rank
,
explicitly_destroy
=
False
,
use_default_stream_as_comm_stream
=
False
)
def
get_handle
(
self
,
kwargs
):
...
...
vllm/distributed/parallel_state.py
View file @
98b7432a
...
...
@@ -951,7 +951,7 @@ def init_distributed_environment(
parallel_config
=
config
.
parallel_config
data_parallel_size
=
parallel_config
.
data_parallel_size
use_mori_ep
=
envs
.
VLLM_
USE_MORI_EP
and
data_parallel_size
>
1
and
parallel_config
.
enable_expert_parallel
use_mori_ep
=
envs
.
VLLM_
ALL2ALL_BACKEND
==
'mori'
and
data_parallel_size
>
1
and
parallel_config
.
enable_expert_parallel
if
use_mori_ep
:
backend
=
"cpu:gloo,cuda:nccl"
torch
.
distributed
.
init_process_group
(
...
...
vllm/envs.py
View file @
98b7432a
...
...
@@ -173,7 +173,6 @@ if TYPE_CHECKING:
VLLM_USE_MERGE_ATTN_STATES_OPT
:
bool
=
False
USE_FUSED_RMS_QUANT
:
bool
=
False
USE_FUSED_SILU_MUL_QUANT
:
bool
=
False
VLLM_USE_MORI_EP
:
bool
=
False
VLLM_P2P_ASYNC
:
bool
=
False
VLLM_P2P_BUF_TOKENS
:
int
=
30000
...
...
@@ -945,6 +944,7 @@ environment_variables: dict[str, Callable[[], Any]] = {
# - "pplx": use pplx kernels
# - "deepep_high_throughput", use deepep high-throughput kernels
# - "deepep_low_latency", use deepep low-latency kernels
# - "mori", use mori kernels
"VLLM_ALL2ALL_BACKEND"
:
lambda
:
os
.
getenv
(
"VLLM_ALL2ALL_BACKEND"
,
"naive"
),
...
...
@@ -1144,11 +1144,6 @@ environment_variables: dict[str, Callable[[], Any]] = {
lambda
:
(
os
.
getenv
(
'USE_FUSED_SILU_MUL_QUANT'
,
'0'
).
lower
()
in
(
"true"
,
"1"
)),
# vLLM will use all_to_all ep mode
"VLLM_USE_MORI_EP"
:
lambda
:
(
os
.
environ
.
get
(
"VLLM_USE_MORI_EP"
,
"True"
).
lower
()
in
(
"true"
,
"1"
)),
# vllm pd separation will be used async
"VLLM_P2P_ASYNC"
:
lambda
:
bool
(
int
(
os
.
getenv
(
"VLLM_P2P_ASYNC"
,
"0"
))),
...
...
vllm/forward_context.py
View file @
98b7432a
...
...
@@ -136,8 +136,8 @@ def set_forward_context(
forward_start_time
=
time
.
perf_counter
()
dp_metadata
:
Optional
[
DPMetadata
]
=
None
dp_size
=
vllm_config
.
parallel_config
.
data_parallel_size
use_
mori
_ep
=
envs
.
VLLM_
USE_MORI_EP
and
dp_size
>
1
and
vllm_config
.
parallel_config
.
enable_expert_parallel
if
not
use_mori
_ep
and
dp_size
>
1
and
(
use_
navie
_ep
=
envs
.
VLLM_
ALL2ALL_BACKEND
==
'naive'
and
dp_size
>
1
and
vllm_config
.
parallel_config
.
enable_expert_parallel
if
use_navie
_ep
and
dp_size
>
1
and
(
attn_metadata
is
not
None
or
num_tokens
is
not
None
)
:
dp_metadata
=
DPMetadata
.
make
(
vllm_config
.
parallel_config
,
attn_metadata
,
num_tokens
or
0
,
...
...
vllm/model_executor/layers/fused_moe/__init__.py
View file @
98b7432a
...
...
@@ -59,6 +59,8 @@ if HAS_TRITON:
get_config_file_name
,
grouped_topk
)
from
vllm.model_executor.layers.fused_moe.triton_deep_gemm_moe
import
(
TritonOrDeepGemmExperts
)
from
vllm.model_executor.layers.fused_moe.triton_group_gemm_moe
import
(
TritonOrGroupGemmExperts
)
__all__
+=
[
"fused_moe"
,
...
...
@@ -75,4 +77,5 @@ if HAS_TRITON:
"BatchedDeepGemmExperts"
,
"TritonOrDeepGemmExperts"
,
"BatchedTritonOrDeepGemmExperts"
,
"TritonOrGroupGemmExperts"
,
]
vllm/model_executor/layers/fused_moe/deepep_ht_prepare_finalize.py
View file @
98b7432a
...
...
@@ -4,12 +4,15 @@ from typing import Optional
import
deep_ep
import
torch
import
torch.distributed
as
dist
import
vllm.model_executor.layers.fused_moe.modular_kernel
as
mk
from
vllm
import
_custom_ops
as
ops
from
vllm.model_executor.layers.fused_moe.config
import
FusedMoEQuantConfig
from
vllm.model_executor.layers.fused_moe.utils
import
(
moe_kernel_quantize_input
)
from
vllm.distributed.parallel_state
import
get_ep_group
class
DeepEPHTPrepareAndFinalize
(
mk
.
FusedMoEPrepareAndFinalize
):
...
...
@@ -54,6 +57,10 @@ class DeepEPHTPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize):
if
self
.
dp_size
not
in
self
.
available_rank_configs
:
return
None
return
deep_ep
.
Buffer
.
get_combine_config
(
self
.
dp_size
)
def
sync
(
self
):
# torch.cuda.synchronize()
dist
.
barrier
()
def
_do_dispatch
(
self
,
tokens
:
torch
.
Tensor
,
token_scales
:
Optional
[
torch
.
Tensor
],
...
...
@@ -205,13 +212,14 @@ class DeepEPHTPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize):
def
finalize
(
self
,
output
:
torch
.
Tensor
,
fused_expert_output
:
torch
.
Tensor
,
topk_weights
:
torch
.
Tensor
,
topk_ids
:
torch
.
Tensor
,
apply_router_weight_on_input
:
bool
)
->
None
:
apply_router_weight_on_input
:
bool
,
apply_weights_and_reduce
:
bool
=
True
)
->
None
:
assert
self
.
handle
is
not
None
# fused_expert_output can have 0 tokens - This happens when none of the
# tokens from the all2all reach this EP rank.
if
fused_expert_output
.
numel
()
!=
0
:
if
fused_expert_output
.
numel
()
!=
0
and
apply_weights_and_reduce
:
fused_expert_output
=
self
.
_apply_weights_and_reduce
(
num_tokens
=
topk_ids
.
size
(
0
),
fused_expert_output
=
fused_expert_output
,
...
...
@@ -227,5 +235,6 @@ class DeepEPHTPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize):
previous_event
=
None
,
async_finish
=
False
,
allocate_on_comm_stream
=
False
)
# Respect inplace outputs.
output
.
copy_
(
combined_x
,
non_blocking
=
True
)
vllm/model_executor/layers/fused_moe/deepep_ll_prepare_finalize.py
View file @
98b7432a
...
...
@@ -162,7 +162,8 @@ class DeepEPLLPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize):
def
finalize
(
self
,
output
:
torch
.
Tensor
,
fused_expert_output
:
torch
.
Tensor
,
topk_weights
:
torch
.
Tensor
,
topk_ids
:
torch
.
Tensor
,
apply_router_weight_on_input
:
bool
)
->
None
:
apply_router_weight_on_input
:
bool
,
apply_weights_and_reduce
:
bool
=
True
)
->
None
:
assert
self
.
handle
is
not
None
...
...
vllm/model_executor/layers/fused_moe/ep_moe/layer.py
View file @
98b7432a
...
...
@@ -230,19 +230,11 @@ class EPMoE(FusedMoE):
]
self
.
use_shared_expert
=
False
<<<<<<<
HEAD
# self.token_dispatcher = MoEAlltoAllTokenDispatcher(
# self.local_num_experts, self.local_expert_indices,
# config=self.ep_moe_config, layer_name=f"{self.layer_name}.token_dispatcher",
# )
=======
self
.
token_dispatcher
=
MoEAlltoAllTokenDispatcher
(
self
.
local_num_experts
,
self
.
local_expert_indices
,
config
=
self
.
ep_moe_config
,
layer_name
=
f
"
{
self
.
layer_name
}
.token_dispatcher"
,
)
>>>>>>>
origin
/
v0
.
9.2
-
dev
-
ds
self
.
shared_expert_overlap
=
moe_shared_expert_overlap
self
.
shared_experts
=
None
...
...
@@ -250,15 +242,9 @@ class EPMoE(FusedMoE):
self
.
use_int8_dispatch
=
True
vllm_config
=
get_current_vllm_config
()
self
.
max_num_inp_token_per_rank
=
vllm_config
.
scheduler_config
.
max_num_seqs
self
.
max_num_inp_token_per_rank
=
1024
#
vllm_config.scheduler_config.max_num_seqs
self
.
mori_op
=
self
.
get_mori_op
()
<<<<<<<
HEAD
=======
self
.
first
=
True
>>>>>>>
origin
/
v0
.
9.2
-
dev
-
ds
def
get_mori_op
(
self
):
global
_MORI_OP
if
_MORI_OP
is
None
:
...
...
@@ -291,7 +277,7 @@ class EPMoE(FusedMoE):
num_experts_per_token
=
self
.
top_k
,
max_token_type_size
=
2
,
block_num
=
80
,
warp_num_per_block
=
16
,
warp_num_per_block
=
4
,
# kernel_type=mori.ops.EpDispatchCombineKernelType.InterNode
kernel_type
=
mori
.
ops
.
EpDispatchCombineKernelType
.
InterNode
if
multi_node
else
\
mori
.
ops
.
EpDispatchCombineKernelType
.
IntraNode
...
...
@@ -319,11 +305,7 @@ class EPMoE(FusedMoE):
return
quant_method
def
sync
(
self
):
<<<<<<<
HEAD
torch
.
cuda
.
synchronize
()
=======
# torch.cuda.synchronize()
>>>>>>>
origin
/
v0
.
9.2
-
dev
-
ds
dist
.
barrier
()
def
forward
(
self
,
hidden_states
:
torch
.
Tensor
,
...
...
@@ -398,33 +380,6 @@ class EPMoE(FusedMoE):
topk_ids
,
#layer_idx=int(self.layer_name.split('.')[2])
)
<<<<<<<
HEAD
#self.sync()
=======
# self.sync()
# expect_m = topk_ids.shape[0] * self.ep_size
# dispatch_output_clip = dispatch_output[:expect_m]
# dispatch_weights_clip = dispatch_weights[:expect_m]
# dispatch_indices_clip = dispatch_indices[:expect_m]
# dispatch_scales_clip = dispatch_scales[:expect_m]
# expert_output = self.quant_method.apply_ep(
# layer=self,
# x=dispatch_output_clip,
# topk_weights=dispatch_weights_clip,
# topk_ids=dispatch_indices_clip,
# global_num_experts=self.global_num_experts,
# expert_map=self.expert_map,
# activation=self.activation,
# apply_router_weight_on_input=self.apply_router_weight_on_input,
# use_nn_moe=self.use_nn_moe,
# num_local_tokens=dispatch_recv_num_token,
# config_select_bs=hidden_states.shape[0],
# scales=dispatch_scales_clip if self.use_int8_dispatch else None
# #routed_scaling_factor=self.routed_scaling_factor,
# )
>>>>>>>
origin
/
v0
.
9.2
-
dev
-
ds
expert_output
=
self
.
quant_method
.
apply_ep
(
layer
=
self
,
...
...
@@ -441,6 +396,7 @@ class EPMoE(FusedMoE):
scales
=
dispatch_scales
if
self
.
use_int8_dispatch
else
None
# routed_scaling_factor=self.routed_scaling_factor,
)
# self.sync()
combine_output
,
_
=
self
.
mori_op
.
combine
(
expert_output
,
dispatch_weights
,
topk_ids
)
...
...
vllm/model_executor/layers/fused_moe/fused_batched_moe.py
View file @
98b7432a
...
...
@@ -596,6 +596,7 @@ class BatchedPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize):
topk_weights
:
torch
.
Tensor
,
topk_ids
:
torch
.
Tensor
,
apply_router_weight_on_input
:
bool
,
apply_weights_and_reduce
:
bool
=
True
)
->
None
:
num_tokens
=
topk_ids
.
size
(
0
)
num_local_experts
=
fused_expert_output
.
size
(
0
)
...
...
vllm/model_executor/layers/fused_moe/layer.py
View file @
98b7432a
...
...
@@ -28,8 +28,9 @@ from vllm.model_executor.layers.fused_moe.config import (
FusedMoEConfig
,
FusedMoEParallelConfig
)
# yapf: enable
from
vllm.model_executor.layers.fused_moe.modular_kernel
import
(
FusedMoEActivationFormat
,
FusedMoEModularKernel
,
FusedMoEPermuteExpertsUnpermute
,
FusedMoEPrepareAndFinalize
)
FusedMoEActivationFormat
,
FusedMoEModularKernel
,
DeepGemmBannedFusedMoEModularKernel
,
FusedMoEPermuteExpertsUnpermute
,
FusedMoEPrepareAndFinalize
)
# from vllm.model_executor.layers.fused_moe.rocm_aiter_fused_moe import (
# is_rocm_aiter_moe_enabled)
from
vllm.model_executor.layers.quantization.base_config
import
(
...
...
@@ -40,7 +41,7 @@ from vllm.model_executor.utils import set_weight_attrs
from
vllm.platforms
import
current_platform
from
vllm.platforms.interface
import
CpuArchEnum
from
vllm.utils
import
direct_register_custom_op
,
has_deep_ep
,
has_pplx
from
vllm.utils
import
direct_register_custom_op
,
has_deep_ep
,
has_pplx
,
has_deep_gemm
from
vllm
import
_custom_ops
as
ops
...
...
@@ -184,10 +185,17 @@ class FusedMoEMethodBase(QuantizeMethodBase):
logger
.
debug
(
"%s"
,
prepare_finalize
.
__class__
.
__name__
)
self
.
topk_indices_dtype
=
prepare_finalize
.
topk_indices_dtype
()
experts
=
self
.
select_gemm_impl
(
prepare_finalize
,
moe
)
self
.
fused_experts
=
FusedMoEModularKernel
(
prepare_finalize
,
experts
,
)
if
has_deep_gemm
():
self
.
fused_experts
=
FusedMoEModularKernel
(
prepare_finalize
,
experts
,
)
else
:
self
.
fused_experts
=
DeepGemmBannedFusedMoEModularKernel
(
prepare_finalize
,
experts
,
)
def
select_gemm_impl
(
self
,
...
...
vllm/model_executor/layers/fused_moe/modular_kernel.py
View file @
98b7432a
...
...
@@ -149,6 +149,7 @@ class FusedMoEPrepareAndFinalize(ABC):
topk_weights
:
torch
.
Tensor
,
topk_ids
:
torch
.
Tensor
,
apply_router_weight_on_input
:
bool
,
apply_weights_and_reduce
:
bool
=
True
)
->
None
:
"""
Perform any combine plus apply weights and perform a reduction on the
...
...
@@ -355,6 +356,168 @@ class FusedMoEPermuteExpertsUnpermute(ABC):
assigned to each expert when using batched experts format input.
"""
raise
NotImplementedError
class
CustomizedFusedMoEPermuteExpertsUnpermute
(
ABC
):
"""
An abstract base class for the [Permute-Experts-Unpermute] step described
above.
"""
def
__init__
(
self
,
quant_config
:
Optional
[
FusedMoEQuantConfig
],
):
if
quant_config
is
not
None
:
self
.
quant_config
=
quant_config
else
:
self
.
quant_config
=
FusedMoEQuantConfig
()
@
property
@
abstractmethod
def
activation_formats
(
self
)
->
tuple
[
FusedMoEActivationFormat
,
FusedMoEActivationFormat
]:
"""
A property which is a tuple of the input and output activation formats
for the 'apply' method.
"""
raise
NotImplementedError
@
property
def
quant_dtype
(
self
)
->
Optional
[
torch
.
dtype
]:
return
self
.
quant_config
.
quant_dtype
@
property
def
block_shape
(
self
)
->
Optional
[
list
[
int
]]:
return
self
.
quant_config
.
block_shape
@
property
def
per_act_token_quant
(
self
)
->
bool
:
return
self
.
quant_config
.
per_act_token_quant
@
property
def
per_out_ch_quant
(
self
)
->
bool
:
return
self
.
quant_config
.
per_out_ch_quant
# TODO (bnell): make this return a CHUNK_SIZE or None instead?
@
abstractmethod
def
supports_chunking
(
self
)
->
bool
:
"""
A flag indicating whether or not this class supports activation
chunking.
"""
raise
NotImplementedError
@
abstractmethod
def
supports_expert_map
(
self
)
->
bool
:
"""
A flag indicating whether or not this class supports expert maps
"""
raise
NotImplementedError
@
abstractmethod
def
workspace_shapes
(
self
,
a
:
torch
.
Tensor
,
aq
:
torch
.
Tensor
,
M
:
int
,
N
:
int
,
K
:
int
,
topk
:
int
,
global_num_experts
:
int
,
local_num_experts
:
int
,
)
->
tuple
[
tuple
[
int
,
...],
tuple
[
int
,
...],
tuple
[
int
,
...],
torch
.
dtype
]:
"""
Compute the shapes for the temporary and final outputs of the two gemms
and activation in the fused expert function. Since the gemms are
independent, the workspace for the first gemm can be shared with the
workspace for the last gemm.
Returns a tuple of:
- workspace13 shape tuple: must be large enough to hold the
result of either expert gemm.
- workspace2 shape tuple: must be large enough to hold the
result of the activation function.
- output shape tuple: must be exact size of the final gemm output.
- Workspace type: The dtype to use for the workspace tensors.
- Note: in order for activation chunking to work, the first dimension
of each tuple must be the number of tokens.
"""
raise
NotImplementedError
def
activation
(
self
,
activation
:
str
,
output
:
torch
.
Tensor
,
input
:
torch
.
Tensor
)
->
None
:
assert
output
.
size
(
-
1
)
*
2
==
input
.
size
(
-
1
)
if
activation
==
"silu"
:
torch
.
ops
.
_C
.
silu_and_mul
(
output
,
input
)
elif
activation
==
"gelu"
:
torch
.
ops
.
_C
.
gelu_and_mul
(
output
,
input
)
else
:
raise
ValueError
(
f
"Unsupported FusedMoe activation:
{
activation
}
"
)
def
enable_chunking
(
self
):
return
envs
.
VLLM_ENABLE_FUSED_MOE_ACTIVATION_CHUNKING
and
\
self
.
supports_chunking
()
@
abstractmethod
def
apply
(
self
,
output
:
torch
.
Tensor
,
hidden_states
:
torch
.
Tensor
,
w1
:
torch
.
Tensor
,
w2
:
torch
.
Tensor
,
topk_ids
:
torch
.
Tensor
,
activation
:
str
,
global_num_experts
:
int
,
expert_map
:
Optional
[
torch
.
Tensor
],
w1_scale
:
Optional
[
torch
.
Tensor
],
w2_scale
:
Optional
[
torch
.
Tensor
],
w1_zp
:
Optional
[
torch
.
Tensor
],
w2_zp
:
Optional
[
torch
.
Tensor
],
a1q_scale
:
Optional
[
torch
.
Tensor
],
a2_scale
:
Optional
[
torch
.
Tensor
],
workspace13
:
torch
.
Tensor
,
workspace2
:
torch
.
Tensor
,
expert_num_tokens
:
Optional
[
torch
.
Tensor
]
=
None
,
use_nn_moe
:
Optional
[
bool
]
=
False
,
shared_output
:
Optional
[
torch
.
Tensor
]
=
None
,
routed_scaling_factor
:
Optional
[
float
]
=
None
,
):
"""
This function computes the intermediate result of a Mixture of Experts
(MoE) layer using two sets of weights, w1 and w2.
Parameters:
- output: (torch.Tensor): The unweighted, unreduced output tensor.
- hidden_states: (torch.Tensor): The (quantized) input tensor to the MoE
layer.
- w1 (torch.Tensor): The first set of expert weights.
- w2 (torch.Tensor): The second set of expert weights.
- topk_ids (torch.Tensor): A map of row to expert id.
- activation (str): The activation function to apply after the first
MoE layer.
- global_num_experts (int): The total number of experts in the global
expert space.
- expert_map (Optional[torch.Tensor]): A tensor mapping expert indices
from the global expert space to the local expert space of the expert
parallel shard.
- w1_scale (Optional[torch.Tensor]): Optional scale to be used for w1.
- w2_scale (Optional[torch.Tensor]): Optional scale to be used for w2.
- w1_zp (Optional[torch.Tensor]): Optional zero points to be used for
w1.
- w2_zp (Optional[torch.Tensor]): Optional zero points to be used for
w2.
- a1q_scale (Optional[torch.Tensor]): Optional quantized scale to be
used for a1.
- a2_scale (Optional[torch.Tensor]): Optional scale to be used for a2.
- workspace13 (torch.Tensor): A scratch tensor used for gemm outputs
must be large enough to hold output of either MoE gemm.
- workspace2 (torch.Tensor): A scratch tensor used for the activation
function.
- expert_num_tokens: An optional tensor containing the number of tokens
assigned to each expert when using batched experts format input.
"""
raise
NotImplementedError
def
_chunk_scales
(
scales
:
Optional
[
torch
.
Tensor
],
start
:
int
,
...
...
@@ -596,3 +759,145 @@ class FusedMoEModularKernel(torch.nn.Module):
topk_ids
,
apply_router_weight_on_input
)
return
output
@
final
class
DeepGemmBannedFusedMoEModularKernel
(
torch
.
nn
.
Module
):
"""
This class combines a FusedMoEPrepareAndFinalize instance and
a FusedMoEPermuteExpertsUnpermute to provide an interface that
is compatible with the `fused_experts` function in fused_moe.py.
It takes care of managing any required scratch space.
Note: Instances of this class should only be used for a single model
layer due to any layer specific state that may be used by the component
objects.
"""
def
__init__
(
self
,
prepare_finalize
:
FusedMoEPrepareAndFinalize
,
fused_experts
:
CustomizedFusedMoEPermuteExpertsUnpermute
,
):
super
().
__init__
()
self
.
prepare_finalize
=
prepare_finalize
self
.
fused_experts
=
fused_experts
assert
prepare_finalize
.
activation_format
==
\
fused_experts
.
activation_formats
[
0
],
(
f
"
{
prepare_finalize
.
__class__
.
__name__
}
."
f
"
{
prepare_finalize
.
activation_format
}
== "
f
"
{
fused_experts
.
__class__
.
__name__
}
."
f
"
{
fused_experts
.
activation_formats
[
0
]
}
"
)
def
forward
(
self
,
hidden_states
:
torch
.
Tensor
,
w1
:
torch
.
Tensor
,
w2
:
torch
.
Tensor
,
topk_weights
:
torch
.
Tensor
,
topk_ids
:
torch
.
Tensor
,
inplace
:
bool
=
False
,
activation
:
str
=
"silu"
,
global_num_experts
:
int
=
-
1
,
expert_map
:
Optional
[
torch
.
Tensor
]
=
None
,
w1_scale
:
Optional
[
torch
.
Tensor
]
=
None
,
w2_scale
:
Optional
[
torch
.
Tensor
]
=
None
,
w1_zp
:
Optional
[
torch
.
Tensor
]
=
None
,
w2_zp
:
Optional
[
torch
.
Tensor
]
=
None
,
a1_scale
:
Optional
[
torch
.
Tensor
]
=
None
,
a2_scale
:
Optional
[
torch
.
Tensor
]
=
None
,
use_nn_moe
:
Optional
[
bool
]
=
False
,
apply_router_weight_on_input
:
bool
=
False
,
shared_output
:
Optional
[
torch
.
Tensor
]
=
None
,
routed_scaling_factor
:
Optional
[
float
]
=
None
,
)
->
torch
.
Tensor
:
"""
This function computes a Mixture of Experts (MoE) layer using two sets
of weights, w1 and w2, and top-k gating mechanism.
Parameters:
- hidden_states: (torch.Tensor): The input tensor to the MoE layer.
- w1 (torch.Tensor): The first set of expert weights.
- w2 (torch.Tensor): The second set of expert weights.
- topk_weights (torch.Tensor): The topk weights applied at the end of
the layer.
- topk_ids (torch.Tensor): A map of row to expert id.
- inplace (bool): If True, perform the operation in-place.
Defaults to False.
- activation (str): The activation function to apply after the first
MoE layer.
- global_num_experts (int): The total number of experts in the global
expert space.
- expert_map (Optional[torch.Tensor]): A tensor mapping expert indices
from the global expert space to the local expert space of the expert
parallel shard.
- w1_scale (Optional[torch.Tensor]): Optional scale to be used for w1.
- w2_scale (Optional[torch.Tensor]): Optional scale to be used for w2.
- w1_zp (Optional[torch.Tensor]): Optional zero points to be used for
w1.
- w2_zp (Optional[torch.Tensor]): Optional zero points to be used for
w2.
- a1_scale (Optional[torch.Tensor]): Optional scale to be used for a1.
- a2_scale (Optional[torch.Tensor]): Optional scale to be used for a2.
- apply_router_weight_on_input (bool): When true, the topk weights are
applied directly on the inputs. This is only applicable when topk is
1.
Returns:
- torch.Tensor: The output tensor after applying the MoE layer.
"""
a1
=
hidden_states
output
=
a1
if
inplace
else
torch
.
zeros_like
(
a1
)
local_num_experts
=
w1
.
size
(
0
)
if
global_num_experts
==
-
1
:
global_num_experts
=
local_num_experts
(
a1q
,
a1q_scale
,
expert_num_tokens
,
_expert_topk_ids
,
_expert_topk_weights
)
=
self
.
prepare_finalize
.
prepare
(
a1
,
a1_scale
,
a2_scale
,
topk_weights
,
topk_ids
,
global_num_experts
,
expert_map
,
apply_router_weight_on_input
,
self
.
fused_experts
.
quant_config
,
)
# Maybe prepare gathered topk_ids and topk_weights from other EP ranks.
topk_ids
=
topk_ids
if
_expert_topk_ids
is
None
else
_expert_topk_ids
topk_weights
=
(
topk_weights
if
_expert_topk_weights
is
None
else
_expert_topk_weights
)
fused_out
=
self
.
fused_experts
.
apply
(
None
,
a1q
,
w1
,
w2
,
topk_ids
,
topk_weights
=
topk_weights
,
activation
=
activation
,
global_num_experts
=
global_num_experts
,
expert_map
=
expert_map
,
w1_scale
=
w1_scale
,
w2_scale
=
w2_scale
,
w1_zp
=
w1_zp
,
w2_zp
=
w2_zp
,
a1q_scale
=
a1q_scale
,
a2_scale
=
a2_scale
,
workspace13
=
None
,
workspace2
=
None
,
use_nn_moe
=
use_nn_moe
,
expert_num_tokens
=
expert_num_tokens
,
shared_output
=
shared_output
,
routed_scaling_factor
=
routed_scaling_factor
,
)
self
.
prepare_finalize
.
finalize
(
output
,
fused_out
,
topk_weights
,
topk_ids
,
apply_router_weight_on_input
,
apply_weights_and_reduce
=
False
)
return
output
vllm/model_executor/layers/fused_moe/pplx_prepare_finalize.py
View file @
98b7432a
...
...
@@ -207,6 +207,7 @@ class PplxPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize):
topk_weights
:
torch
.
Tensor
,
topk_ids
:
torch
.
Tensor
,
apply_router_weight_on_input
:
bool
,
apply_weights_and_reduce
:
bool
=
True
)
->
None
:
# This argument is optional
# There's not much point setting this unless it is != topk_ids.size(0)
...
...
vllm/model_executor/layers/fused_moe/prepare_finalize.py
View file @
98b7432a
...
...
@@ -61,6 +61,7 @@ class MoEPrepareAndFinalizeNoEP(mk.FusedMoEPrepareAndFinalize):
topk_weights
:
torch
.
Tensor
,
topk_ids
:
torch
.
Tensor
,
apply_router_weight_on_input
:
bool
,
apply_weights_and_reduce
:
bool
=
True
)
->
None
:
_moe_unpermute_and_reduce
(
output
,
fused_expert_output
,
None
,
topk_weights
,
apply_router_weight_on_input
)
vllm/model_executor/layers/fused_moe/triton_group_gemm_moe.py
0 → 100644
View file @
98b7432a
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from
typing
import
Optional
import
torch
import
vllm.model_executor.layers.fused_moe.modular_kernel
as
mk
from
vllm.model_executor.layers.fused_moe.config
import
FusedMoEQuantConfig
from
vllm.model_executor.layers.fused_moe.deep_gemm_moe
import
(
DeepGemmExperts
,
_valid_deep_gemm
,
_valid_deep_gemm_shape
)
class
TritonOrGroupGemmExperts
(
mk
.
CustomizedFusedMoEPermuteExpertsUnpermute
):
def
__init__
(
self
,
use_fp8_w8a8
:
bool
=
False
,
use_int8_w8a8
:
bool
=
False
,
use_int8_w8a16
:
bool
=
False
,
use_int4_w4a16
:
bool
=
False
,
per_act_token_quant
:
bool
=
False
,
block_shape
:
Optional
[
list
[
int
]]
=
None
,
allow_group_gemm
:
bool
=
False
,
fused_experts
=
None
):
super
().
__init__
(
FusedMoEQuantConfig
.
make
(
use_fp8_w8a8
=
use_fp8_w8a8
,
use_int8_w8a8
=
use_int8_w8a8
,
use_int8_w8a16
=
use_int8_w8a16
,
use_int4_w4a16
=
use_int4_w4a16
,
per_act_token_quant
=
per_act_token_quant
,
block_shape
=
block_shape
,
))
self
.
fused_experts
=
fused_experts
@
property
def
activation_formats
(
self
)
->
tuple
[
mk
.
FusedMoEActivationFormat
,
mk
.
FusedMoEActivationFormat
]:
return
(
mk
.
FusedMoEActivationFormat
.
Standard
,
mk
.
FusedMoEActivationFormat
.
Standard
)
def
supports_chunking
(
self
)
->
bool
:
return
True
def
supports_expert_map
(
self
)
->
bool
:
return
True
def
workspace_shapes
(
self
,
a
:
torch
.
Tensor
,
aq
:
torch
.
Tensor
,
M
:
int
,
N
:
int
,
K
:
int
,
topk
:
int
,
global_num_experts
:
int
,
local_num_experts
:
int
,
)
->
tuple
[
tuple
[
int
,
...],
tuple
[
int
,
...],
tuple
[
int
,
...],
torch
.
dtype
]:
raise
NotImplementedError
def
apply
(
self
,
output
:
torch
.
Tensor
,
hidden_states
:
torch
.
Tensor
,
w1
:
torch
.
Tensor
,
w2
:
torch
.
Tensor
,
topk_ids
:
torch
.
Tensor
,
activation
:
str
,
global_num_experts
:
int
,
expert_map
:
Optional
[
torch
.
Tensor
],
w1_scale
:
Optional
[
torch
.
Tensor
],
w2_scale
:
Optional
[
torch
.
Tensor
],
w1_zp
:
Optional
[
torch
.
Tensor
],
w2_zp
:
Optional
[
torch
.
Tensor
],
a1q_scale
:
Optional
[
torch
.
Tensor
],
a2_scale
:
Optional
[
torch
.
Tensor
],
workspace13
:
torch
.
Tensor
,
workspace2
:
torch
.
Tensor
,
topk_weights
:
Optional
[
torch
.
Tensor
]
=
None
,
expert_num_tokens
:
Optional
[
torch
.
Tensor
]
=
None
,
use_nn_moe
:
Optional
[
bool
]
=
False
,
shared_output
:
Optional
[
torch
.
Tensor
]
=
None
,
routed_scaling_factor
:
Optional
[
float
]
=
None
,
):
assert
self
.
fused_experts
is
not
None
return
self
.
fused_experts
(
x
=
hidden_states
,
w1
=
w1
,
w2
=
w2
,
topk_ids
=
topk_ids
,
topk_weights
=
topk_weights
,
global_num_experts
=
global_num_experts
,
expert_map
=
expert_map
,
apply_router_weight_on_input
=
False
,
activation
=
activation
,
w1_scale
=
w1_scale
,
w2_scale
=
w2_scale
,
a1_scale
=
a1q_scale
,
a2_scale
=
a2_scale
,
expert_num_tokens
=
expert_num_tokens
,
use_nn_moe
=
use_nn_moe
,
shared_output
=
shared_output
,
routed_scaling_factor
=
routed_scaling_factor
)
vllm/model_executor/layers/quantization/slimquant_w4a8_marlin.py
View file @
98b7432a
...
...
@@ -4,10 +4,12 @@ import os
import
torch
from
torch.nn.parameter
import
Parameter
import
vllm.envs
as
envs
from
vllm
import
_custom_ops
as
ops
from
vllm.model_executor.utils
import
set_weight_attrs
from
vllm.distributed
import
get_tensor_model_parallel_world_size
from
vllm.distributed
import
get_tensor_model_parallel_world_size
,
get_dp_group
from
vllm.logger
import
init_logger
from
vllm.config
import
get_current_vllm_config
from
vllm.model_executor.layers.quantization
import
QuantizationMethods
from
vllm.model_executor.layers.linear
import
(
LinearBase
,
LinearMethodBase
)
from
vllm.model_executor.layers.quantization.base_config
import
(
QuantizationConfig
,
...
...
@@ -125,6 +127,10 @@ class SlimQuantW4A8Int8MarlinConfig(QuantizationConfig):
def
get_scaled_act_names
(
self
)
->
List
[
str
]:
return
[]
@
property
def
weight_block_size
(
self
):
return
[
128
,
128
]
class
SlimQuantW4A8Int8MarlinMoEMethod
:
...
...
@@ -154,6 +160,13 @@ class SlimQuantW4A8Int8MarlinMoEMethod:
def
__init__
(
self
,
quant_config
):
self
.
quant_config
=
quant_config
self
.
fused_experts
=
self
.
w4a8_marlin_forward
vllm_config
=
get_current_vllm_config
()
parallel_config
=
vllm_config
.
parallel_config
self
.
use_deepep
=
parallel_config
.
enable_expert_parallel
and
\
(
envs
.
VLLM_ALL2ALL_BACKEND
==
"deepep_high_throughput"
or
\
envs
.
VLLM_ALL2ALL_BACKEND
==
"deepep_low_latency"
)
def
create_weights
(
self
,
...
...
@@ -218,6 +231,50 @@ class SlimQuantW4A8Int8MarlinMoEMethod:
layer
.
w13_weight
=
Parameter
(
w4a8_weight_repack_impl
(
layer
.
w13_weight
),
requires_grad
=
False
)
layer
.
w2_weight
=
Parameter
(
w4a8_weight_repack_impl
(
layer
.
w2_weight
),
requires_grad
=
False
)
def
w4a8_marlin_forward
(
self
,
x
:
torch
.
Tensor
,
w1
:
torch
.
Tensor
,
w2
:
torch
.
Tensor
,
topk_ids
:
torch
.
Tensor
,
topk_weights
:
torch
.
Tensor
,
global_num_experts
:
int
=
-
1
,
expert_map
:
Optional
[
torch
.
Tensor
]
=
None
,
apply_router_weight_on_input
:
bool
=
False
,
activation
:
str
=
"silu"
,
w1_scale
:
Optional
[
torch
.
Tensor
]
=
None
,
w2_scale
:
Optional
[
torch
.
Tensor
]
=
None
,
a1_scale
:
Optional
[
torch
.
Tensor
]
=
None
,
a2_scale
:
Optional
[
torch
.
Tensor
]
=
None
,
expert_num_tokens
:
Optional
[
torch
.
Tensor
]
=
None
,
use_nn_moe
:
Optional
[
bool
]
=
False
,
routed_scaling_factor
:
Optional
[
float
]
=
None
,
shared_output
:
Optional
[
torch
.
Tensor
]
=
None
,
**
_
):
workspace
,
global_reduce_buffer
=
MarlinMoeWorkspace
(
x
.
device
).
get_buffers
()
return
fused_experts_impl_w4a8_marlin
(
x
,
w1
,
w2
,
topk_ids
=
topk_ids
,
topk_weights
=
topk_weights
,
workspace
=
workspace
,
global_reduce_buffer
=
global_reduce_buffer
,
inplace
=
True
,
use_int4_w4a8
=
True
,
per_channel_quant
=
True
,
activation
=
activation
,
expert_map
=
expert_map
,
apply_router_weight_on_input
=
apply_router_weight_on_input
,
global_num_experts
=
global_num_experts
,
w1_scale
=
w1_scale
,
w2_scale
=
w2_scale
,
a1_scale
=
a1_scale
,
a2_scale
=
a2_scale
,
use_nn_moe
=
use_nn_moe
,
shared_output
=
shared_output
,
routed_scaling_factor
=
routed_scaling_factor
,
)
def
apply_ep
(
#dp+ep
self
,
layer
:
torch
.
nn
.
Module
,
...
...
@@ -301,29 +358,25 @@ class SlimQuantW4A8Int8MarlinMoEMethod:
custom_routing_function
=
custom_routing_function
,
scoring_func
=
scoring_func
,
e_score_correction_bias
=
e_score_correction_bias
,
indices_type
=
torch
.
int64
if
self
.
use_deepep
else
None
,
routed_scaling_factor
=
routed_scaling_factor
,
use_fused_gate
=
use_fused_gate
)
workspace
,
global_reduce_buffer
=
MarlinMoeWorkspace
(
x
.
device
).
get_buffers
()
return
fused_experts_impl_w4a8_marlin
(
return
self
.
fused_experts
(
x
,
layer
.
w13_weight
,
layer
.
w2_weight
,
topk_weights
=
topk_weights
,
topk_ids
=
topk_ids
,
workspace
=
workspace
,
global_reduce_buffer
=
global_reduce_buffer
,
inplace
=
True
,
use_int4_w4a8
=
True
,
per_channel_quant
=
True
,
activation
=
activation
,
expert_map
=
expert_map
,
apply_router_weight_on_input
=
apply_router_weight_on_input
,
global_num_experts
=
global_num_experts
,
expert_map
=
expert_map
,
w1_scale
=
(
layer
.
w13_weight_scale
),
w2_scale
=
(
layer
.
w2_weight_scale
),
a1_scale
=
layer
.
w13_input_scale
,
a2_scale
=
layer
.
w2_input_scale
,
apply_router_weight_on_input
=
apply_router_weight_on_input
,
use_nn_moe
=
use_nn_moe
,
shared_output
=
shared_output
,
routed_scaling_factor
=
routed_scaling_factor
,
...
...
@@ -335,10 +388,7 @@ class SlimQuantW4A8Int8MarlinMoEMethod:
moe
:
FusedMoEConfig
,
)
->
FusedMoEPermuteExpertsUnpermute
:
from
vllm.model_executor.layers.fused_moe
import
(
BatchedGroupedGemmExperts
,
GroupedGemmGemmExperts
)
assert
not
self
.
rocm_aiter_moe_enabled
,
(
"ROCm AITER are not supported with all2all yet."
)
TritonOrGroupGemmExperts
)
if
(
prepare_finalize
.
activation_format
==
FusedMoEActivationFormat
.
BatchedExperts
):
...
...
@@ -350,21 +400,16 @@ class SlimQuantW4A8Int8MarlinMoEMethod:
"max_tokens_per_rank=%s, block_size=%s, per_act_token=%s"
,
self
.
__class__
.
__name__
,
max_num_tokens_per_rank
,
self
.
quant_config
.
weight_block_size
,
False
)
return
BatchedGroupedGemmExperts
(
max_num_tokens
=
max_num_tokens_per_rank
,
num_dispatchers
=
prepare_finalize
.
num_dispatchers
(),
use_fp8_w8a8
=
False
,
block_shape
=
self
.
quant_config
.
weight_block_size
,
per_act_token_quant
=
True
,
allow_deep_gemm
=
False
,
)
return
None
else
:
logger
.
debug
(
"
GroupedGemm
GemmExperts(%s): block_size=%s, per_act_token=%s"
,
"
TritonOrGroup
GemmExperts(%s): block_size=%s, per_act_token=%s"
,
self
.
__class__
.
__name__
,
self
.
quant_config
.
weight_block_size
,
False
)
return
GroupedGemmGemmExperts
(
return
TritonOrGroupGemmExperts
(
use_fp8_w8a8
=
False
,
block_shape
=
self
.
quant_config
.
weight_block_size
,
allow_deep_gemm
=
False
,
allow_group_gemm
=
False
,
fused_experts
=
self
.
w4a8_marlin_forward
)
vllm/model_executor/models/deepseek_mtp.py
View file @
98b7432a
...
...
@@ -178,7 +178,7 @@ class DeepSeekMTP(nn.Module, SupportsPP):
parallel_config
=
vllm_config
.
parallel_config
dp_size
=
get_dp_group
().
world_size
self
.
use_mori_ep
=
envs
.
VLLM_
USE_MORI_EP
and
dp_size
>
1
and
parallel_config
.
enable_expert_parallel
self
.
use_mori_ep
=
envs
.
VLLM_
ALL2ALL_BACKEND
==
'mori'
and
dp_size
>
1
and
parallel_config
.
enable_expert_parallel
def
forward
(
...
...
vllm/model_executor/models/deepseek_v2.py
View file @
98b7432a
...
...
@@ -167,7 +167,8 @@ class DeepseekV2MoE(nn.Module):
self
.
n_local_physical_experts
)
dp_size
=
get_dp_group
().
world_size
self
.
use_mori_ep
=
envs
.
VLLM_USE_MORI_EP
and
dp_size
>
1
and
parallel_config
.
enable_expert_parallel
self
.
use_mori_ep
=
envs
.
VLLM_ALL2ALL_BACKEND
==
'mori'
and
dp_size
>
1
and
parallel_config
.
enable_expert_parallel
self
.
enable_expert_parallel
=
parallel_config
.
enable_expert_parallel
moe_cls
=
FusedMoE
if
not
self
.
use_mori_ep
else
EPMoE
self
.
experts
=
moe_cls
(
...
...
@@ -226,10 +227,24 @@ class DeepseekV2MoE(nn.Module):
if
not
self
.
use_mori_ep
:
if
envs
.
envs
.
VLLM_USE_LIGHTOP_MOE_SUM_MUL_ADD
:
final_hidden_states
=
self
.
experts
(
hidden_states
=
hidden_states
,
router_logits
=
router_logits
,
shared_output
=
shared_output
)
if
self
.
enable_expert_parallel
:
final_hidden_states
=
self
.
experts
(
hidden_states
=
hidden_states
,
router_logits
=
router_logits
)
if
shared_output
is
not
None
:
if
hidden_states
.
dtype
!=
torch
.
float16
:
final_hidden_states
=
final_hidden_states
+
shared_output
else
:
# Fix FP16 overflow
# See DeepseekV2DecoderLayer for more details.
final_hidden_states
=
final_hidden_states
+
shared_output
\
*
(
1.
/
self
.
routed_scaling_factor
)
else
:
final_hidden_states
=
self
.
experts
(
hidden_states
=
hidden_states
,
router_logits
=
router_logits
,
shared_output
=
shared_output
)
else
:
if
hidden_states
.
dtype
!=
torch
.
float16
:
final_hidden_states
=
self
.
experts
(
...
...
@@ -927,7 +942,7 @@ class DeepseekV2ForCausalLM(nn.Module, SupportsPP, MixtureOfExperts):
parallel_config
=
vllm_config
.
parallel_config
dp_size
=
get_dp_group
().
world_size
self
.
use_mori_ep
=
envs
.
VLLM_
USE_MORI_EP
and
dp_size
>
1
and
parallel_config
.
enable_expert_parallel
self
.
use_mori_ep
=
envs
.
VLLM_
ALL2ALL_BACKEND
==
'mori'
and
dp_size
>
1
and
parallel_config
.
enable_expert_parallel
def
set_eplb_state
(
self
,
...
...
vllm/v1/spec_decode/eagle.py
View file @
98b7432a
...
...
@@ -513,8 +513,11 @@ class EagleProposer:
self
.
hidden_states
[:
num_tokens
],
)
if
self
.
dp_size
>
1
and
self
.
enable_expert_parallel
and
self
.
num_speculative_tokens
>
1
:
for
_
in
range
(
self
.
num_speculative_tokens
-
1
):
if
self
.
dp_size
>
1
and
self
.
enable_expert_parallel
and
self
.
num_speculative_tokens
>
1
:
for
_
in
range
(
self
.
num_speculative_tokens
-
1
):
with
set_forward_context
(
attn_metadata
,
self
.
vllm_config
,
num_tokens
=
num_tokens
):
self
.
model
(
self
.
input_ids
[:
num_tokens
],
self
.
positions
[:
num_tokens
],
...
...
vllm/v1/worker/gpu_model_runner.py
View file @
98b7432a
...
...
@@ -323,9 +323,6 @@ class GPUModelRunner(LoRAModelRunnerMixin):
# from the KV cache of `shared_kv_cache_layers[layer_name]`.
self
.
shared_kv_cache_layers
:
dict
[
str
,
str
]
=
{}
dp_size
=
self
.
vllm_config
.
parallel_config
.
data_parallel_size
self
.
use_mori_ep
=
envs
.
VLLM_USE_MORI_EP
and
dp_size
>
1
and
parallel_config
.
enable_expert_parallel
def
_may_reorder_batch
(
self
,
scheduler_output
:
"SchedulerOutput"
)
->
None
:
"""
Update the order of requests in the batch based on the attention
...
...
@@ -1238,7 +1235,7 @@ class GPUModelRunner(LoRAModelRunnerMixin):
# TODO(tms) : There are many cases where padding is enabled for
# prefills, causing unnecessary and excessive padding of activations.
if
dp_size
==
1
or
self
.
vllm_config
.
model_config
.
enforce_eager
or
self
.
use_mori_ep
:
if
dp_size
==
1
or
self
.
vllm_config
.
model_config
.
enforce_eager
or
envs
.
VLLM_ALL2ALL_BACKEND
==
'naive'
:
# Early exit.
return
0
,
None
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment