Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
1ae8f58c
Commit
1ae8f58c
authored
Dec 02, 2025
by
王敏
Browse files
[feat]支持deepep低延迟与共享专家overlap
parent
bca29c66
Changes
12
Hide whitespace changes
Inline
Side-by-side
Showing
12 changed files
with
312 additions
and
76 deletions
+312
-76
vllm/config.py
vllm/config.py
+2
-1
vllm/distributed/device_communicators/all2all.py
vllm/distributed/device_communicators/all2all.py
+1
-2
vllm/distributed/device_communicators/base_device_communicator.py
...tributed/device_communicators/base_device_communicator.py
+3
-2
vllm/model_executor/layers/fused_moe/__init__.py
vllm/model_executor/layers/fused_moe/__init__.py
+2
-0
vllm/model_executor/layers/fused_moe/deepep_ll_prepare_finalize.py
...l_executor/layers/fused_moe/deepep_ll_prepare_finalize.py
+54
-11
vllm/model_executor/layers/fused_moe/layer.py
vllm/model_executor/layers/fused_moe/layer.py
+46
-3
vllm/model_executor/layers/fused_moe/modular_kernel.py
vllm/model_executor/layers/fused_moe/modular_kernel.py
+43
-4
vllm/model_executor/layers/fused_moe/mori_moe/layer.py
vllm/model_executor/layers/fused_moe/mori_moe/layer.py
+3
-6
vllm/model_executor/layers/fused_moe/shared_fused_moe.py
vllm/model_executor/layers/fused_moe/shared_fused_moe.py
+58
-0
vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors_moe_marlin.py
...ation/compressed_tensors/compressed_tensors_moe_marlin.py
+15
-6
vllm/model_executor/layers/quantization/utils/w8a8_utils.py
vllm/model_executor/layers/quantization/utils/w8a8_utils.py
+2
-2
vllm/model_executor/models/deepseek_v2.py
vllm/model_executor/models/deepseek_v2.py
+83
-39
No files found.
vllm/config.py
View file @
1ae8f58c
...
@@ -4780,8 +4780,9 @@ class VllmConfig:
...
@@ -4780,8 +4780,9 @@ class VllmConfig:
# add for spec decode
# add for spec decode
if
self
.
speculative_config
is
not
None
and
self
.
speculative_config
.
num_lookahead_slots
>
0
:
if
self
.
speculative_config
is
not
None
and
self
.
speculative_config
.
num_lookahead_slots
>
0
:
batch_size_capture_list
=
list
(
map
(
lambda
x
:
x
*
(
1
+
self
.
speculative_config
.
num_lookahead_slots
),
mtp_
batch_size_capture_list
=
list
(
map
(
lambda
x
:
x
*
(
1
+
self
.
speculative_config
.
num_lookahead_slots
),
batch_size_capture_list
))
batch_size_capture_list
))
batch_size_capture_list
=
sorted
(
set
(
batch_size_capture_list
+
mtp_batch_size_capture_list
))
self
.
compilation_config
.
init_with_cudagraph_sizes
(
self
.
compilation_config
.
init_with_cudagraph_sizes
(
batch_size_capture_list
)
batch_size_capture_list
)
...
...
vllm/distributed/device_communicators/all2all.py
View file @
1ae8f58c
...
@@ -192,8 +192,7 @@ class DeepEPHTAll2AllManager(DeepEPAll2AllManagerBase):
...
@@ -192,8 +192,7 @@ class DeepEPHTAll2AllManager(DeepEPAll2AllManagerBase):
num_rdma_bytes
=
num_rdma_bytes
,
num_rdma_bytes
=
num_rdma_bytes
,
low_latency_mode
=
False
,
low_latency_mode
=
False
,
num_qps_per_rank
=
num_qps_per_rank
,
num_qps_per_rank
=
num_qps_per_rank
,
explicitly_destroy
=
False
,
explicitly_destroy
=
False
)
use_default_stream_as_comm_stream
=
False
)
def
get_handle
(
self
,
kwargs
):
def
get_handle
(
self
,
kwargs
):
...
...
vllm/distributed/device_communicators/base_device_communicator.py
View file @
1ae8f58c
...
@@ -237,10 +237,11 @@ class DeviceCommunicatorBase:
...
@@ -237,10 +237,11 @@ class DeviceCommunicatorBase:
moe_modules
=
[
moe_modules
=
[
module
for
module
in
model
.
modules
()
module
for
module
in
model
.
modules
()
if
module
.
__class__
.
__name__
==
"FusedMoE"
if
(
module
.
__class__
.
__name__
==
"FusedMoE"
or
module
.
__class__
.
__name__
==
"SharedFusedMoE"
)
]
]
for
module
in
moe_modules
:
for
module
in
moe_modules
:
module
.
quant_method
.
init_prepare_finalize
(
module
.
moe_config
,
module
.
quant_method
.
init_prepare_finalize
(
module
,
module
.
moe_config
,
module
.
quant_config
)
module
.
quant_config
)
def
dispatch
(
def
dispatch
(
...
...
vllm/model_executor/layers/fused_moe/__init__.py
View file @
1ae8f58c
...
@@ -10,6 +10,7 @@ from vllm.model_executor.layers.fused_moe.layer import (
...
@@ -10,6 +10,7 @@ from vllm.model_executor.layers.fused_moe.layer import (
from
vllm.model_executor.layers.fused_moe.modular_kernel
import
(
from
vllm.model_executor.layers.fused_moe.modular_kernel
import
(
FusedMoEActivationFormat
,
FusedMoEPermuteExpertsUnpermute
,
FusedMoEActivationFormat
,
FusedMoEPermuteExpertsUnpermute
,
FusedMoEPrepareAndFinalize
)
FusedMoEPrepareAndFinalize
)
from
vllm.model_executor.layers.fused_moe.shared_fused_moe
import
SharedFusedMoE
from
vllm.triton_utils
import
HAS_TRITON
from
vllm.triton_utils
import
HAS_TRITON
_config
:
Optional
[
dict
[
str
,
Any
]]
=
None
_config
:
Optional
[
dict
[
str
,
Any
]]
=
None
...
@@ -38,6 +39,7 @@ __all__ = [
...
@@ -38,6 +39,7 @@ __all__ = [
"FusedMoEPrepareAndFinalize"
,
"FusedMoEPrepareAndFinalize"
,
"override_config"
,
"override_config"
,
"get_config"
,
"get_config"
,
"SharedFusedMoE"
,
]
]
if
HAS_TRITON
:
if
HAS_TRITON
:
...
...
vllm/model_executor/layers/fused_moe/deepep_ll_prepare_finalize.py
View file @
1ae8f58c
# SPDX-License-Identifier: Apache-2.0
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from
typing
import
Optional
,
Union
from
typing
import
Optional
,
Union
from
collections.abc
import
Callable
import
deep_ep
import
deep_ep
import
torch
import
torch
...
@@ -44,12 +45,14 @@ class DeepEPLLPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize):
...
@@ -44,12 +45,14 @@ class DeepEPLLPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize):
buffer
:
deep_ep
.
Buffer
,
buffer
:
deep_ep
.
Buffer
,
max_tokens_per_rank
:
int
,
max_tokens_per_rank
:
int
,
num_dispatchers
:
int
,
num_dispatchers
:
int
,
use_fp8_dispatch
:
bool
=
False
):
use_fp8_dispatch
:
bool
=
False
,
use_int8_dispatch
:
bool
=
False
):
super
().
__init__
()
super
().
__init__
()
self
.
buffer
=
buffer
self
.
buffer
=
buffer
self
.
max_tokens_per_rank
=
max_tokens_per_rank
self
.
max_tokens_per_rank
=
max_tokens_per_rank
self
.
use_fp8_dispatch
=
use_fp8_dispatch
self
.
use_fp8_dispatch
=
use_fp8_dispatch
self
.
use_int8_dispatch
=
use_int8_dispatch
# The dispatch function returns a handle that the combine function
# The dispatch function returns a handle that the combine function
# requires. We store the handle here so it is available to the
# requires. We store the handle here so it is available to the
# combine function.
# combine function.
...
@@ -154,7 +157,8 @@ class DeepEPLLPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize):
...
@@ -154,7 +157,8 @@ class DeepEPLLPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize):
topk_ids
,
topk_ids
,
self
.
max_tokens_per_rank
,
self
.
max_tokens_per_rank
,
num_experts
,
num_experts
,
use_fp8
=
self
.
use_fp8_dispatch
,
use_fp8
=
self
.
use_fp8_dispatch
or
self
.
use_int8_dispatch
,
use_int8
=
self
.
use_int8_dispatch
,
async_finish
=
False
,
async_finish
=
False
,
return_recv_hook
=
False
)
return_recv_hook
=
False
)
...
@@ -163,12 +167,18 @@ class DeepEPLLPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize):
...
@@ -163,12 +167,18 @@ class DeepEPLLPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize):
quant_config
.
per_act_token_quant
,
quant_config
.
block_shape
,
expert_num_tokens
)
quant_config
.
per_act_token_quant
,
quant_config
.
block_shape
,
expert_num_tokens
)
return
(
expert_x
,
expert_x_scale
,
expert_num_tokens
,
None
,
None
)
return
(
expert_x
,
expert_x_scale
,
expert_num_tokens
,
None
,
None
)
def
finalize
(
self
,
output
:
torch
.
Tensor
,
fused_expert_output
:
torch
.
Tensor
,
def
_finalize
(
topk_weights
:
torch
.
Tensor
,
topk_ids
:
torch
.
Tensor
,
self
,
apply_router_weight_on_input
:
bool
,
output
:
torch
.
Tensor
,
apply_weights_and_reduce
:
bool
=
True
)
->
None
:
fused_expert_output
:
torch
.
Tensor
,
topk_weights
:
torch
.
Tensor
,
topk_ids
:
torch
.
Tensor
,
apply_router_weight_on_input
:
bool
,
apply_weights_and_reduce
:
bool
,
do_async
:
bool
,
)
->
Callable
:
do_recv_hook
=
do_async
assert
self
.
handle
is
not
None
assert
self
.
handle
is
not
None
combine_topk_weights
=
topk_weights
combine_topk_weights
=
topk_weights
...
@@ -177,12 +187,45 @@ class DeepEPLLPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize):
...
@@ -177,12 +187,45 @@ class DeepEPLLPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize):
combine_topk_weights
=
torch
.
ones_like
(
topk_weights
)
combine_topk_weights
=
torch
.
ones_like
(
topk_weights
)
# TODO (varun) : Enable zero copy mode
# TODO (varun) : Enable zero copy mode
_
,
event
,
hook
=
self
.
buffer
.
low_latency_combine
(
_
,
_
,
recv_
hook
=
self
.
buffer
.
low_latency_combine
(
fused_expert_output
,
fused_expert_output
,
topk_ids
,
topk_ids
,
combine_topk_weights
,
combine_topk_weights
,
self
.
handle
,
self
.
handle
,
async_finish
=
False
,
async_finish
=
False
,
zero_copy
=
False
,
zero_copy
=
False
,
return_recv_hook
=
False
,
return_recv_hook
=
do_recv_hook
,
out
=
output
)
out
=
output
,
)
return
recv_hook
def
finalize_async
(
self
,
output
:
torch
.
Tensor
,
fused_expert_output
:
torch
.
Tensor
,
topk_weights
:
torch
.
Tensor
,
topk_ids
:
torch
.
Tensor
,
apply_router_weight_on_input
:
bool
,
apply_weights_and_reduce
:
bool
=
True
)
->
None
:
return
self
.
_finalize
(
output
,
fused_expert_output
,
topk_weights
,
topk_ids
,
apply_router_weight_on_input
,
apply_weights_and_reduce
,
do_async
=
True
,
)
def
finalize
(
self
,
output
:
torch
.
Tensor
,
fused_expert_output
:
torch
.
Tensor
,
topk_weights
:
torch
.
Tensor
,
topk_ids
:
torch
.
Tensor
,
apply_router_weight_on_input
:
bool
,
apply_weights_and_reduce
:
bool
=
True
)
->
None
:
self
.
_finalize
(
output
,
fused_expert_output
,
topk_weights
,
topk_ids
,
apply_router_weight_on_input
,
apply_weights_and_reduce
,
do_async
=
False
,
)
vllm/model_executor/layers/fused_moe/layer.py
View file @
1ae8f58c
...
@@ -92,7 +92,7 @@ class FusedMoEMethodBase(QuantizeMethodBase):
...
@@ -92,7 +92,7 @@ class FusedMoEMethodBase(QuantizeMethodBase):
params_dtype
:
torch
.
dtype
,
**
extra_weight_attrs
):
params_dtype
:
torch
.
dtype
,
**
extra_weight_attrs
):
raise
NotImplementedError
raise
NotImplementedError
def
init_prepare_finalize
(
self
,
moe
:
FusedMoEConfig
,
def
init_prepare_finalize
(
self
,
layer
,
moe
:
FusedMoEConfig
,
quant_config
:
Optional
[
QuantizationConfig
]):
quant_config
:
Optional
[
QuantizationConfig
]):
all2all_manager
=
get_ep_group
().
device_communicator
.
all2all_manager
all2all_manager
=
get_ep_group
().
device_communicator
.
all2all_manager
assert
all2all_manager
is
not
None
assert
all2all_manager
is
not
None
...
@@ -170,6 +170,8 @@ class FusedMoEMethodBase(QuantizeMethodBase):
...
@@ -170,6 +170,8 @@ class FusedMoEMethodBase(QuantizeMethodBase):
==
current_platform
.
fp8_dtype
()
==
current_platform
.
fp8_dtype
()
and
moe
.
quant_config
.
block_shape
and
moe
.
quant_config
.
block_shape
==
DEEPEP_QUANT_BLOCK_SHAPE
)
==
DEEPEP_QUANT_BLOCK_SHAPE
)
use_int8_dispatch
=
False
#moe.quant_config.quant_dtype == torch.int8
# Note (varun): Whether to use FP8 dispatch or not needs some
# Note (varun): Whether to use FP8 dispatch or not needs some
# profiling. Turning it off for now.
# profiling. Turning it off for now.
...
@@ -178,6 +180,7 @@ class FusedMoEMethodBase(QuantizeMethodBase):
...
@@ -178,6 +180,7 @@ class FusedMoEMethodBase(QuantizeMethodBase):
max_tokens_per_rank
=
moe
.
max_num_tokens
,
max_tokens_per_rank
=
moe
.
max_num_tokens
,
num_dispatchers
=
all2all_manager
.
world_size
,
num_dispatchers
=
all2all_manager
.
world_size
,
use_fp8_dispatch
=
use_fp8_dispatch
,
use_fp8_dispatch
=
use_fp8_dispatch
,
use_int8_dispatch
=
use_int8_dispatch
,
)
)
self
.
topk_indices_dtype
=
None
self
.
topk_indices_dtype
=
None
...
@@ -195,6 +198,7 @@ class FusedMoEMethodBase(QuantizeMethodBase):
...
@@ -195,6 +198,7 @@ class FusedMoEMethodBase(QuantizeMethodBase):
self
.
fused_experts
=
DeepGemmDisabledFusedMoEModularKernel
(
self
.
fused_experts
=
DeepGemmDisabledFusedMoEModularKernel
(
prepare_finalize
,
prepare_finalize
,
experts
,
experts
,
shared_experts
=
layer
.
shared_experts
if
hasattr
(
layer
,
"shared_experts"
)
else
None
,
)
)
def
select_gemm_impl
(
def
select_gemm_impl
(
...
@@ -912,6 +916,10 @@ class FusedMoE(torch.nn.Module):
...
@@ -912,6 +916,10 @@ class FusedMoE(torch.nn.Module):
@
property
@
property
def
use_deepep_ll_kernels
(
self
):
def
use_deepep_ll_kernels
(
self
):
return
self
.
moe_parallel_config
.
use_deepep_ll_kernels
return
self
.
moe_parallel_config
.
use_deepep_ll_kernels
@
property
def
shared_experts
(
self
)
->
Optional
[
torch
.
nn
.
Module
]:
return
None
def
_load_per_tensor_weight_scale
(
self
,
shard_id
:
str
,
def
_load_per_tensor_weight_scale
(
self
,
shard_id
:
str
,
param
:
torch
.
nn
.
Parameter
,
param
:
torch
.
nn
.
Parameter
,
...
@@ -1456,8 +1464,12 @@ class FusedMoE(torch.nn.Module):
...
@@ -1456,8 +1464,12 @@ class FusedMoE(torch.nn.Module):
if
current_platform
.
is_tpu
():
if
current_platform
.
is_tpu
():
return
self
.
forward_impl
(
hidden_states
,
router_logits
)
return
self
.
forward_impl
(
hidden_states
,
router_logits
)
else
:
else
:
return
torch
.
ops
.
vllm
.
moe_forward
(
hidden_states
,
router_logits
,
if
self
.
shared_experts
is
None
:
self
.
layer_name
,
shared_output
)
return
torch
.
ops
.
vllm
.
moe_forward
(
hidden_states
,
router_logits
,
self
.
layer_name
,
shared_output
)
else
:
return
torch
.
ops
.
vllm
.
moe_forward_shared
(
hidden_states
,
router_logits
,
self
.
layer_name
,
shared_output
)
def
forward_impl_chunked
(
self
,
full_hidden_states
:
torch
.
Tensor
,
def
forward_impl_chunked
(
self
,
full_hidden_states
:
torch
.
Tensor
,
full_router_logits
:
torch
.
Tensor
):
full_router_logits
:
torch
.
Tensor
):
...
@@ -1667,4 +1679,35 @@ direct_register_custom_op(
...
@@ -1667,4 +1679,35 @@ direct_register_custom_op(
fake_impl
=
moe_forward_fake
,
fake_impl
=
moe_forward_fake
,
dispatch_key
=
current_platform
.
dispatch_key
,
dispatch_key
=
current_platform
.
dispatch_key
,
tags
=
(
torch
.
Tag
.
needs_fixed_stride_order
,
),
tags
=
(
torch
.
Tag
.
needs_fixed_stride_order
,
),
)
def
moe_forward_shared
(
hidden_states
:
torch
.
Tensor
,
router_logits
:
torch
.
Tensor
,
layer_name
:
str
,
shared_output
:
Optional
[
torch
.
Tensor
]
=
None
)
->
tuple
[
torch
.
Tensor
,
torch
.
Tensor
]:
forward_context
:
ForwardContext
=
get_forward_context
()
self
=
forward_context
.
no_compile_layers
[
layer_name
]
assert
self
.
shared_experts
is
not
None
return
self
.
forward_impl
(
hidden_states
,
router_logits
,
shared_output
)
def
moe_forward_shared_fake
(
hidden_states
:
torch
.
Tensor
,
router_logits
:
torch
.
Tensor
,
layer_name
:
str
,
shared_output
:
Optional
[
torch
.
Tensor
]
=
None
)
->
tuple
[
torch
.
Tensor
,
torch
.
Tensor
]:
shared_out
=
torch
.
empty_like
(
hidden_states
)
fused_out
=
torch
.
empty_like
(
hidden_states
)
return
shared_out
,
fused_out
direct_register_custom_op
(
op_name
=
"moe_forward_shared"
,
op_func
=
moe_forward_shared
,
mutates_args
=
[
"hidden_states"
],
fake_impl
=
moe_forward_shared_fake
,
tags
=
(
torch
.
Tag
.
needs_fixed_stride_order
,),
)
)
\ No newline at end of file
vllm/model_executor/layers/fused_moe/modular_kernel.py
View file @
1ae8f58c
...
@@ -759,7 +759,21 @@ class FusedMoEModularKernel(torch.nn.Module):
...
@@ -759,7 +759,21 @@ class FusedMoEModularKernel(torch.nn.Module):
topk_ids
,
apply_router_weight_on_input
)
topk_ids
,
apply_router_weight_on_input
)
return
output
return
output
_aux_stream
:
torch
.
cuda
.
Stream
|
None
=
None
def
aux_stream
()
->
torch
.
cuda
.
Stream
|
None
:
"""
Ensures aux_stream is initialized only once
"""
global
_aux_stream
# TODO: validate this works properly on ROCm platform.
if
_aux_stream
is
None
:
_aux_stream
=
torch
.
cuda
.
Stream
()
return
_aux_stream
@
final
@
final
class
DeepGemmDisabledFusedMoEModularKernel
(
torch
.
nn
.
Module
):
class
DeepGemmDisabledFusedMoEModularKernel
(
torch
.
nn
.
Module
):
...
@@ -779,10 +793,17 @@ class DeepGemmDisabledFusedMoEModularKernel(torch.nn.Module):
...
@@ -779,10 +793,17 @@ class DeepGemmDisabledFusedMoEModularKernel(torch.nn.Module):
self
,
self
,
prepare_finalize
:
FusedMoEPrepareAndFinalize
,
prepare_finalize
:
FusedMoEPrepareAndFinalize
,
fused_experts
:
CustomizedFusedMoEPermuteExpertsUnpermute
,
fused_experts
:
CustomizedFusedMoEPermuteExpertsUnpermute
,
shared_experts
:
Optional
[
torch
.
nn
.
Module
]
=
None
,
):
):
super
().
__init__
()
super
().
__init__
()
self
.
prepare_finalize
=
prepare_finalize
self
.
prepare_finalize
=
prepare_finalize
self
.
fused_experts
=
fused_experts
self
.
fused_experts
=
fused_experts
self
.
shared_experts
=
shared_experts
if
self
.
shared_experts
is
not
None
:
self
.
shared_experts_stream
=
aux_stream
()
self
.
shared_experts_overlap_event
=
torch
.
cuda
.
Event
()
# assert prepare_finalize.activation_format == \
# assert prepare_finalize.activation_format == \
# fused_experts.activation_formats[0], (
# fused_experts.activation_formats[0], (
# f"{prepare_finalize.__class__.__name__}."
# f"{prepare_finalize.__class__.__name__}."
...
@@ -849,7 +870,11 @@ class DeepGemmDisabledFusedMoEModularKernel(torch.nn.Module):
...
@@ -849,7 +870,11 @@ class DeepGemmDisabledFusedMoEModularKernel(torch.nn.Module):
"""
"""
a1
=
hidden_states
a1
=
hidden_states
output
=
a1
if
inplace
else
torch
.
zeros_like
(
a1
)
if
inplace
and
self
.
shared_experts
is
None
:
output
=
hidden_states
else
:
output
=
torch
.
zeros_like
(
hidden_states
)
local_num_experts
=
w1
.
size
(
0
)
local_num_experts
=
w1
.
size
(
0
)
if
global_num_experts
==
-
1
:
if
global_num_experts
==
-
1
:
...
@@ -898,7 +923,21 @@ class DeepGemmDisabledFusedMoEModularKernel(torch.nn.Module):
...
@@ -898,7 +923,21 @@ class DeepGemmDisabledFusedMoEModularKernel(torch.nn.Module):
routed_scaling_factor
=
routed_scaling_factor
,
routed_scaling_factor
=
routed_scaling_factor
,
)
)
self
.
prepare_finalize
.
finalize
(
output
,
fused_out
,
topk_weights
,
shared_output
=
None
topk_ids
,
apply_router_weight_on_input
,
apply_weights_and_reduce
=
False
)
if
self
.
shared_experts
is
None
:
self
.
prepare_finalize
.
finalize
(
output
,
fused_out
,
topk_weights
,
topk_ids
,
apply_router_weight_on_input
,
apply_weights_and_reduce
=
False
)
else
:
hook
=
self
.
prepare_finalize
.
finalize_async
(
output
,
fused_out
,
topk_weights
,
topk_ids
,
apply_router_weight_on_input
,
apply_weights_and_reduce
=
False
)
if
self
.
shared_experts
is
not
None
:
shared_output
=
self
.
shared_experts
(
hidden_states
)
if
hook
is
not
None
:
hook
()
if
self
.
shared_experts
is
not
None
:
return
(
shared_output
,
output
)
return
output
return
output
vllm/model_executor/layers/fused_moe/mori_moe/layer.py
View file @
1ae8f58c
...
@@ -189,6 +189,8 @@ class MoriMoE(FusedMoE):
...
@@ -189,6 +189,8 @@ class MoriMoE(FusedMoE):
enable_eplb
:
bool
=
False
,
enable_eplb
:
bool
=
False
,
num_redundant_experts
:
int
=
0
,
num_redundant_experts
:
int
=
0
,
moe_permute_fusion
:
bool
=
False
,
moe_permute_fusion
:
bool
=
False
,
shared_experts
:
Optional
[
torch
.
nn
.
Module
]
=
None
,
use_overlapped
:
bool
=
False
,
):
):
super
().
__init__
(
num_experts
,
top_k
,
hidden_size
,
super
().
__init__
(
num_experts
,
top_k
,
hidden_size
,
intermediate_size
,
params_dtype
,
intermediate_size
,
params_dtype
,
...
@@ -214,6 +216,7 @@ class MoriMoE(FusedMoE):
...
@@ -214,6 +216,7 @@ class MoriMoE(FusedMoE):
routed_scaling_factor
=
self
.
routed_scaling_factor
,
routed_scaling_factor
=
self
.
routed_scaling_factor
,
apply_router_weight_on_input
=
self
.
apply_router_weight_on_input
apply_router_weight_on_input
=
self
.
apply_router_weight_on_input
)
)
self
.
shared_experts
=
shared_experts
local_expert_indices_offset
=
(
local_expert_indices_offset
=
(
self
.
ep_rank
*
self
.
local_num_experts
self
.
ep_rank
*
self
.
local_num_experts
...
@@ -221,8 +224,6 @@ class MoriMoE(FusedMoE):
...
@@ -221,8 +224,6 @@ class MoriMoE(FusedMoE):
self
.
local_expert_indices
=
[
self
.
local_expert_indices
=
[
local_expert_indices_offset
+
i
for
i
in
range
(
self
.
local_num_experts
)
local_expert_indices_offset
+
i
for
i
in
range
(
self
.
local_num_experts
)
]
]
self
.
shared_experts
=
None
self
.
scales
=
None
self
.
scales
=
None
self
.
use_int8_dispatch
=
True
self
.
use_int8_dispatch
=
True
...
@@ -267,10 +268,6 @@ class MoriMoE(FusedMoE):
...
@@ -267,10 +268,6 @@ class MoriMoE(FusedMoE):
return
_MORI_OP
return
_MORI_OP
def
set_shared_experts
(
self
,
shared_experts
:
torch
.
nn
.
Module
):
if
self
.
shared_experts
is
None
:
self
.
shared_experts
=
shared_experts
def
create_quant_method
(
self
,
moe
,
quant_config
,
prefix
):
def
create_quant_method
(
self
,
moe
,
quant_config
,
prefix
):
# Note: get_quant_method will look at the layer's local_num_experts
# Note: get_quant_method will look at the layer's local_num_experts
# for heuristic purposes, so it must be initialized first.
# for heuristic purposes, so it must be initialized first.
...
...
vllm/model_executor/layers/fused_moe/shared_fused_moe.py
0 → 100644
View file @
1ae8f58c
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from
typing
import
Optional
import
torch
from
vllm.distributed
import
tensor_model_parallel_all_reduce
from
vllm.model_executor.layers.fused_moe.layer
import
FusedMoE
# TODO(bnell): Add shared + fused combo function? e.g. +
class
SharedFusedMoE
(
FusedMoE
):
"""
A FusedMoE operation that also computes the results of shared experts.
If an all2all communicator is being used the shared expert computation
can be interleaved with the fused all2all dispatch communication step.
"""
def
__init__
(
self
,
shared_experts
:
torch
.
nn
.
Module
,
use_overlapped
:
bool
=
True
,
**
kwargs
,
):
super
().
__init__
(
**
kwargs
)
self
.
_shared_experts
=
shared_experts
self
.
use_overlapped
=
use_overlapped
@
property
def
shared_experts
(
self
)
->
Optional
[
torch
.
nn
.
Module
]:
return
self
.
_shared_experts
if
self
.
use_overlapped
else
None
def
forward
(
self
,
hidden_states
:
torch
.
Tensor
,
router_logits
:
torch
.
Tensor
,
)
->
torch
.
Tensor
:
if
not
self
.
use_overlapped
:
shared_out
=
self
.
_shared_experts
(
hidden_states
)
# Reduce outputs if necessary, since the MLP should
# have been created with reduce_results=False.
if
(
self
.
reduce_results
and
self
.
tp_size
>
1
and
self
.
must_reduce_shared_expert_outputs
()):
shared_out
=
tensor_model_parallel_all_reduce
(
shared_out
)
fused_out
=
super
().
forward
(
hidden_states
=
hidden_states
,
router_logits
=
router_logits
,
)
else
:
# Matrix multiply.
fused_out
=
super
().
forward
(
hidden_states
=
hidden_states
,
router_logits
=
router_logits
,
)
return
fused_out
vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors_moe_marlin.py
View file @
1ae8f58c
...
@@ -82,11 +82,13 @@ class CompressedTensorsW8A8Int8MarlinMoEMethod(CompressedTensorsMarlinMoEMethod)
...
@@ -82,11 +82,13 @@ class CompressedTensorsW8A8Int8MarlinMoEMethod(CompressedTensorsMarlinMoEMethod)
self
.
fused_experts
=
self
.
fused_moe_forward
self
.
fused_experts
=
self
.
fused_moe_forward
vllm_config
=
get_current_vllm_config
()
vllm_config
=
get_current_vllm_config
()
parallel_config
=
vllm_config
.
parallel_config
parallel_config
=
vllm_config
.
parallel_config
dp_size
=
get_dp_group
().
world_size
self
.
dp_size
=
get_dp_group
().
world_size
self
.
use_deepep
=
dp_size
>
1
and
parallel_config
.
enable_expert_parallel
and
\
self
.
use_deepep
=
self
.
dp_size
>
1
and
parallel_config
.
enable_expert_parallel
and
\
(
envs
.
VLLM_ALL2ALL_BACKEND
==
"deepep_high_throughput"
or
\
(
envs
.
VLLM_ALL2ALL_BACKEND
==
"deepep_high_throughput"
or
\
envs
.
VLLM_ALL2ALL_BACKEND
==
"deepep_low_latency"
)
envs
.
VLLM_ALL2ALL_BACKEND
==
"deepep_low_latency"
)
self
.
use_deepep_ll
=
self
.
use_deepep
and
envs
.
VLLM_ALL2ALL_BACKEND
==
"deepep_low_latency"
if
self
.
use_deepep
:
if
self
.
use_deepep
:
all2all_manager
=
get_ep_group
().
device_communicator
.
all2all_manager
all2all_manager
=
get_ep_group
().
device_communicator
.
all2all_manager
assert
all2all_manager
is
not
None
assert
all2all_manager
is
not
None
...
@@ -97,7 +99,7 @@ class CompressedTensorsW8A8Int8MarlinMoEMethod(CompressedTensorsMarlinMoEMethod)
...
@@ -97,7 +99,7 @@ class CompressedTensorsW8A8Int8MarlinMoEMethod(CompressedTensorsMarlinMoEMethod)
hidden_size
:
int
,
intermediate_size_per_partition
:
int
,
hidden_size
:
int
,
intermediate_size_per_partition
:
int
,
params_dtype
:
torch
.
dtype
,
**
extra_weight_attrs
):
params_dtype
:
torch
.
dtype
,
**
extra_weight_attrs
):
if
self
.
use_deepep
:
if
self
.
use_deepep
_ll
:
self
.
N
=
2
*
intermediate_size_per_partition
self
.
N
=
2
*
intermediate_size_per_partition
self
.
K
=
hidden_size
self
.
K
=
hidden_size
...
@@ -151,7 +153,7 @@ class CompressedTensorsW8A8Int8MarlinMoEMethod(CompressedTensorsMarlinMoEMethod)
...
@@ -151,7 +153,7 @@ class CompressedTensorsW8A8Int8MarlinMoEMethod(CompressedTensorsMarlinMoEMethod)
def
process_weights_after_loading
(
self
,
layer
:
torch
.
nn
.
Module
)
->
None
:
def
process_weights_after_loading
(
self
,
layer
:
torch
.
nn
.
Module
)
->
None
:
w1_marlin_list
=
[]
w1_marlin_list
=
[]
for
ii
in
range
(
layer
.
w13_weight
.
shape
[
0
]):
for
ii
in
range
(
layer
.
w13_weight
.
shape
[
0
]):
if
not
self
.
use_deepep
:
if
not
self
.
use_deepep
_ll
:
w1_marlin_in
=
get_w8a8_int8_marlin_weights
(
layer
.
w13_weight
[
ii
])
w1_marlin_in
=
get_w8a8_int8_marlin_weights
(
layer
.
w13_weight
[
ii
])
else
:
else
:
w1_marlin_in
=
w8a8_nt_kpack2_marlin_weight
(
layer
.
w13_weight
[
ii
])
w1_marlin_in
=
w8a8_nt_kpack2_marlin_weight
(
layer
.
w13_weight
[
ii
])
...
@@ -162,7 +164,7 @@ class CompressedTensorsW8A8Int8MarlinMoEMethod(CompressedTensorsMarlinMoEMethod)
...
@@ -162,7 +164,7 @@ class CompressedTensorsW8A8Int8MarlinMoEMethod(CompressedTensorsMarlinMoEMethod)
del
w1_marlin_list
del
w1_marlin_list
w2_marlin_list
=
[]
w2_marlin_list
=
[]
for
ii
in
range
(
layer
.
w2_weight
.
shape
[
0
]):
for
ii
in
range
(
layer
.
w2_weight
.
shape
[
0
]):
if
not
self
.
use_deepep
:
if
not
self
.
use_deepep
_ll
:
w2_marlin_in
=
get_w8a8_int8_marlin_weights
(
layer
.
w2_weight
[
ii
])
w2_marlin_in
=
get_w8a8_int8_marlin_weights
(
layer
.
w2_weight
[
ii
])
else
:
else
:
w2_marlin_in
=
w8a8_nt_kpack2_marlin_weight
(
layer
.
w2_weight
[
ii
])
w2_marlin_in
=
w8a8_nt_kpack2_marlin_weight
(
layer
.
w2_weight
[
ii
])
...
@@ -236,7 +238,14 @@ class CompressedTensorsW8A8Int8MarlinMoEMethod(CompressedTensorsMarlinMoEMethod)
...
@@ -236,7 +238,14 @@ class CompressedTensorsW8A8Int8MarlinMoEMethod(CompressedTensorsMarlinMoEMethod)
# (from deepgemm docs) : A value hint (which is a value on CPU)
# (from deepgemm docs) : A value hint (which is a value on CPU)
# for the M expectation of each batch, correctly setting this value
# for the M expectation of each batch, correctly setting this value
# may lead to better performance.
# may lead to better performance.
expected_m
=
max_num_tokens
#expected_m = max_num_tokens
ori_bs
=
x
.
shape
[
0
]
expected_m
=
ori_bs
*
self
.
dp_size
# expected_m = (
# x.shape[0] * self.dp_size * topk_ids.shape[1]
# + global_num_experts
# ) // global_num_experts
m_grouped_w8a8_gemm_nt_masked
((
q_x
,
a1_scale
),
m_grouped_w8a8_gemm_nt_masked
((
q_x
,
a1_scale
),
(
w1
,
w1_scale
),
(
w1
,
w1_scale
),
...
...
vllm/model_executor/layers/quantization/utils/w8a8_utils.py
View file @
1ae8f58c
...
@@ -466,9 +466,9 @@ def apply_int8_linear(
...
@@ -466,9 +466,9 @@ def apply_int8_linear(
m_
=
m
m_
=
m
#best_config=W8A8_TRITONJSON.triton_json_dict[0][f"{m}_{n}_{k}"]
#best_config=W8A8_TRITONJSON.triton_json_dict[0][f"{m}_{n}_{k}"]
elif
m
<=
64
:
elif
m
<=
64
:
m_
=
(
m
+
3
)
&
-
4
#取值到最近的4的倍数
m_
=
64
#
(m + 3) & -4 #取值到最近的4的倍数
elif
m
<=
160
:
elif
m
<=
160
:
m_
=
(
m
+
7
)
&
-
8
m_
=
160
#
(m + 7) & -8
elif
m
<
200
:
#256
elif
m
<
200
:
#256
m_
=
160
m_
=
160
...
...
vllm/model_executor/models/deepseek_v2.py
View file @
1ae8f58c
...
@@ -42,7 +42,8 @@ from vllm.config import (CacheConfig, ModelConfig, VllmConfig,
...
@@ -42,7 +42,8 @@ from vllm.config import (CacheConfig, ModelConfig, VllmConfig,
from
vllm.distributed
import
(
get_ep_group
,
get_pp_group
,
get_dp_group
,
from
vllm.distributed
import
(
get_ep_group
,
get_pp_group
,
get_dp_group
,
get_tensor_model_parallel_world_size
)
get_tensor_model_parallel_world_size
)
from
vllm.model_executor.layers.activation
import
SiluAndMul
from
vllm.model_executor.layers.activation
import
SiluAndMul
from
vllm.model_executor.layers.fused_moe
import
FusedMoE
from
vllm.model_executor.layers.fused_moe
import
FusedMoE
,
SharedFusedMoE
from
vllm.model_executor.layers.fused_moe.mori_moe.layer
import
MoriMoE
from
vllm.model_executor.layers.fused_moe.mori_moe.layer
import
MoriMoE
from
vllm.model_executor.layers.fused_moe.mori_moe.ep_moe_utlis
import
EPSharedExperts
from
vllm.model_executor.layers.fused_moe.mori_moe.ep_moe_utlis
import
EPSharedExperts
from
vllm.model_executor.layers.layernorm
import
RMSNorm
from
vllm.model_executor.layers.layernorm
import
RMSNorm
...
@@ -169,41 +170,72 @@ class DeepseekV2MoE(nn.Module):
...
@@ -169,41 +170,72 @@ class DeepseekV2MoE(nn.Module):
dp_size
=
get_dp_group
().
world_size
dp_size
=
get_dp_group
().
world_size
self
.
use_mori_ep
=
parallel_config
.
enable_expert_parallel
and
dp_size
>
1
and
envs
.
VLLM_ALL2ALL_BACKEND
==
'mori'
self
.
use_mori_ep
=
parallel_config
.
enable_expert_parallel
and
dp_size
>
1
and
envs
.
VLLM_ALL2ALL_BACKEND
==
'mori'
self
.
enable_expert_parallel
=
parallel_config
.
enable_expert_parallel
self
.
enable_expert_parallel
=
parallel_config
.
enable_expert_parallel
self
.
use_deepep_ll
=
dp_size
>
1
and
parallel_config
.
enable_expert_parallel
and
\
moe_cls
=
FusedMoE
if
not
self
.
use_mori_ep
else
MoriMoE
envs
.
VLLM_ALL2ALL_BACKEND
==
"deepep_low_latency"
self
.
experts
=
moe_cls
(
num_experts
=
config
.
n_routed_experts
,
if
not
self
.
use_deepep_ll
:
top_k
=
config
.
num_experts_per_tok
,
moe_cls
=
FusedMoE
if
not
self
.
use_mori_ep
else
MoriMoE
hidden_size
=
config
.
hidden_size
,
self
.
experts
=
moe_cls
(
intermediate_size
=
config
.
moe_intermediate_size
,
num_experts
=
config
.
n_routed_experts
,
reduce_results
=
False
,
top_k
=
config
.
num_experts_per_tok
,
renormalize
=
config
.
norm_topk_prob
,
quant_config
=
quant_config
,
use_grouped_topk
=
True
,
num_expert_group
=
config
.
n_group
,
topk_group
=
config
.
topk_group
,
prefix
=
f
"
{
prefix
}
.experts"
,
scoring_func
=
config
.
scoring_func
,
e_score_correction_bias
=
self
.
gate
.
e_score_correction_bias
,
enable_eplb
=
self
.
enable_eplb
,
num_redundant_experts
=
self
.
n_redundant_experts
,
routed_scaling_factor
=
self
.
routed_scaling_factor
)
if
config
.
n_shared_experts
is
not
None
:
intermediate_size
=
(
config
.
moe_intermediate_size
*
config
.
n_shared_experts
)
shared_expert_cls
=
DeepseekV2MLP
if
not
self
.
use_mori_ep
else
EPSharedExperts
self
.
shared_experts
=
shared_expert_cls
(
hidden_size
=
config
.
hidden_size
,
hidden_size
=
config
.
hidden_size
,
intermediate_size
=
intermediate_size
,
intermediate_size
=
config
.
moe_intermediate_size
,
hidden_act
=
config
.
hidden_act
,
reduce_results
=
False
,
renormalize
=
config
.
norm_topk_prob
,
quant_config
=
quant_config
,
quant_config
=
quant_config
,
reduce_results
=
self
.
experts
.
must_reduce_shared_expert_outputs
(
use_grouped_topk
=
True
,
),
num_expert_group
=
config
.
n_group
,
prefix
=
f
"
{
prefix
}
.shared_experts"
,
topk_group
=
config
.
topk_group
,
)
prefix
=
f
"
{
prefix
}
.experts"
,
if
self
.
use_mori_ep
:
scoring_func
=
config
.
scoring_func
,
self
.
experts
.
set_shared_experts
(
self
.
shared_experts
)
e_score_correction_bias
=
self
.
gate
.
e_score_correction_bias
,
enable_eplb
=
self
.
enable_eplb
,
num_redundant_experts
=
self
.
n_redundant_experts
,
routed_scaling_factor
=
self
.
routed_scaling_factor
)
if
config
.
n_shared_experts
is
not
None
:
intermediate_size
=
(
config
.
moe_intermediate_size
*
config
.
n_shared_experts
)
shared_expert_cls
=
DeepseekV2MLP
if
not
self
.
use_mori_ep
else
EPSharedExperts
self
.
shared_experts
=
shared_expert_cls
(
hidden_size
=
config
.
hidden_size
,
intermediate_size
=
intermediate_size
,
hidden_act
=
config
.
hidden_act
,
quant_config
=
quant_config
,
reduce_results
=
self
.
experts
.
must_reduce_shared_expert_outputs
(),
prefix
=
f
"
{
prefix
}
.shared_experts"
,
)
else
:
if
config
.
n_shared_experts
is
not
None
:
intermediate_size
=
(
config
.
moe_intermediate_size
*
config
.
n_shared_experts
)
shared_expert_cls
=
DeepseekV2MLP
if
not
self
.
use_mori_ep
else
EPSharedExperts
self
.
shared_experts
=
shared_expert_cls
(
hidden_size
=
config
.
hidden_size
,
intermediate_size
=
intermediate_size
,
hidden_act
=
config
.
hidden_act
,
quant_config
=
quant_config
,
reduce_results
=
dp_size
!=
self
.
ep_size
,
prefix
=
f
"
{
prefix
}
.shared_experts"
,
)
self
.
experts
=
SharedFusedMoE
(
num_experts
=
config
.
n_routed_experts
,
top_k
=
config
.
num_experts_per_tok
,
hidden_size
=
config
.
hidden_size
,
intermediate_size
=
config
.
moe_intermediate_size
,
reduce_results
=
False
,
renormalize
=
config
.
norm_topk_prob
,
quant_config
=
quant_config
,
use_grouped_topk
=
True
,
num_expert_group
=
config
.
n_group
,
topk_group
=
config
.
topk_group
,
prefix
=
f
"
{
prefix
}
.experts"
,
scoring_func
=
config
.
scoring_func
,
e_score_correction_bias
=
self
.
gate
.
e_score_correction_bias
,
enable_eplb
=
self
.
enable_eplb
,
num_redundant_experts
=
self
.
n_redundant_experts
,
routed_scaling_factor
=
self
.
routed_scaling_factor
,
shared_experts
=
self
.
shared_experts
)
from
vllm.two_batch_overlap.two_batch_overlap
import
tbo_all_reduce
from
vllm.two_batch_overlap.two_batch_overlap
import
tbo_all_reduce
self
.
tbo_all_reduce
=
tbo_all_reduce
self
.
tbo_all_reduce
=
tbo_all_reduce
...
@@ -215,7 +247,7 @@ class DeepseekV2MoE(nn.Module):
...
@@ -215,7 +247,7 @@ class DeepseekV2MoE(nn.Module):
num_tokens
,
hidden_dim
=
hidden_states
.
shape
num_tokens
,
hidden_dim
=
hidden_states
.
shape
hidden_states
=
hidden_states
.
view
(
-
1
,
hidden_dim
)
hidden_states
=
hidden_states
.
view
(
-
1
,
hidden_dim
)
if
not
self
.
use_mori_ep
:
if
not
self
.
use_mori_ep
and
not
self
.
use_deepep_ll
:
if
self
.
n_shared_experts
is
not
None
:
if
self
.
n_shared_experts
is
not
None
:
if
envs
.
USE_FUSED_RMS_QUANT
:
if
envs
.
USE_FUSED_RMS_QUANT
:
shared_output
,
new_resi
=
self
.
shared_experts
(
hidden_states
,
rms_weight
,
residual
,
update_hd
=
True
)
shared_output
,
new_resi
=
self
.
shared_experts
(
hidden_states
,
rms_weight
,
residual
,
update_hd
=
True
)
...
@@ -250,7 +282,22 @@ class DeepseekV2MoE(nn.Module):
...
@@ -250,7 +282,22 @@ class DeepseekV2MoE(nn.Module):
final_hidden_states
=
final_hidden_states
+
shared_output
\
final_hidden_states
=
final_hidden_states
+
shared_output
\
*
(
1.
/
self
.
routed_scaling_factor
)
*
(
1.
/
self
.
routed_scaling_factor
)
else
:
else
:
if
not
self
.
use_mori_ep
:
if
self
.
use_deepep_ll
:
shared_output
,
final_hidden_states
=
self
.
experts
(
hidden_states
=
hidden_states
,
router_logits
=
router_logits
)
if
shared_output
is
not
None
:
if
hidden_states
.
dtype
!=
torch
.
float16
:
final_hidden_states
=
final_hidden_states
+
shared_output
else
:
# Fix FP16 overflow
# See DeepseekV2DecoderLayer for more details.
final_hidden_states
=
final_hidden_states
+
shared_output
\
*
(
1.
/
self
.
routed_scaling_factor
)
elif
self
.
use_mori_ep
:
final_hidden_states
=
self
.
experts
(
hidden_states
=
hidden_states
,
router_logits
=
router_logits
)
else
:
final_hidden_states
=
self
.
experts
(
final_hidden_states
=
self
.
experts
(
hidden_states
=
hidden_states
,
hidden_states
=
hidden_states
,
router_logits
=
router_logits
)
router_logits
=
router_logits
)
...
@@ -263,9 +310,6 @@ class DeepseekV2MoE(nn.Module):
...
@@ -263,9 +310,6 @@ class DeepseekV2MoE(nn.Module):
# See DeepseekV2DecoderLayer for more details.
# See DeepseekV2DecoderLayer for more details.
final_hidden_states
=
final_hidden_states
+
shared_output
\
final_hidden_states
=
final_hidden_states
+
shared_output
\
*
(
1.
/
self
.
routed_scaling_factor
)
*
(
1.
/
self
.
routed_scaling_factor
)
else
:
final_hidden_states
=
self
.
experts
(
hidden_states
=
hidden_states
,
router_logits
=
router_logits
)
if
not
self
.
use_mori_ep
:
if
not
self
.
use_mori_ep
:
if
self
.
tp_size
>
1
:
if
self
.
tp_size
>
1
:
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment