Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
0d3ae2fc
Commit
0d3ae2fc
authored
Dec 16, 2025
by
yangql
Browse files
up auto deepep
parent
94c4ca4d
Changes
7
Hide whitespace changes
Inline
Side-by-side
Showing
7 changed files
with
259 additions
and
7 deletions
+259
-7
vllm/distributed/device_communicators/all2all.py
vllm/distributed/device_communicators/all2all.py
+53
-0
vllm/distributed/device_communicators/cuda_communicator.py
vllm/distributed/device_communicators/cuda_communicator.py
+4
-0
vllm/model_executor/layers/fused_moe/config.py
vllm/model_executor/layers/fused_moe/config.py
+9
-0
vllm/model_executor/layers/fused_moe/deepep_auto_prepare_finalize.py
...executor/layers/fused_moe/deepep_auto_prepare_finalize.py
+132
-0
vllm/model_executor/layers/fused_moe/layer.py
vllm/model_executor/layers/fused_moe/layer.py
+43
-0
vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors_moe_marlin.py
...ation/compressed_tensors/compressed_tensors_moe_marlin.py
+6
-3
vllm/model_executor/models/deepseek_v2.py
vllm/model_executor/models/deepseek_v2.py
+12
-4
No files found.
vllm/distributed/device_communicators/all2all.py
View file @
0d3ae2fc
...
...
@@ -273,3 +273,56 @@ class DeepEPLLAll2AllManager(DeepEPAll2AllManagerBase):
# in get_or_create must be updated.
handle
.
set_num_sms
(
self
.
num_sms
)
return
handle
class
DeepEPAutoAll2AllManager
(
All2AllManagerBase
):
"""
Simplified auto manager that always builds handles through the
low-latency DeepEP manager. This avoids creating multiple buffer
instances and mirrors the sglang behavior of relying on LL buffers.
"""
def
__init__
(
self
,
cpu_group
):
super
().
__init__
(
cpu_group
)
self
.
ll_manager
=
DeepEPLLAll2AllManager
(
cpu_group
)
self
.
ht_manager
=
DeepEPHTAll2AllManager
(
cpu_group
)
def
get_handle
(
self
,
kwargs
):
"""
Build a DeepEP Buffer using LL args but sized to the larger of HT/LL
requirements (max of num_nvl_bytes/num_rdma_bytes).
"""
import
deep_ep
kwargs
=
dict
(
kwargs
)
# Build canonical kwargs for each path.
ll_kwargs
=
self
.
ll_manager
.
_make_all2all_kwargs
(
**
kwargs
)
ht_kwargs
=
self
.
ht_manager
.
_make_all2all_kwargs
()
# Take the max for buffer sizes to be compatible with both modes.
merged_kwargs
=
dict
(
ll_kwargs
)
merged_kwargs
[
"num_nvl_bytes"
]
=
max
(
ll_kwargs
[
"num_nvl_bytes"
],
ht_kwargs
[
"num_nvl_bytes"
])
merged_kwargs
[
"num_rdma_bytes"
]
=
max
(
ll_kwargs
[
"num_rdma_bytes"
],
ht_kwargs
[
"num_rdma_bytes"
])
logger
.
debug
(
"DeepEP auto merged args %s"
,
merged_kwargs
)
handle
:
deep_ep
.
Buffer
=
self
.
ll_manager
.
handle_cache
.
get_or_create
(
merged_kwargs
,
deep_ep
.
Buffer
)
handle
.
set_num_sms
(
self
.
ll_manager
.
num_sms
)
return
handle
def
dispatch
(
self
,
hidden_states
:
torch
.
Tensor
,
router_logits
:
torch
.
Tensor
):
raise
NotImplementedError
(
"DeepEPAutoAll2AllManager does not support dispatch directly; "
"use the underlying HT/LL managers."
)
def
combine
(
self
,
hidden_states
:
torch
.
Tensor
)
->
torch
.
Tensor
:
raise
NotImplementedError
(
"DeepEPAutoAll2AllManager does not support combine directly; "
"use the underlying HT/LL managers."
)
def
destroy
(
self
):
self
.
ll_manager
.
destroy
()
vllm/distributed/device_communicators/cuda_communicator.py
View file @
0d3ae2fc
...
...
@@ -87,6 +87,10 @@ class CudaCommunicator(DeviceCommunicatorBase):
from
.all2all
import
DeepEPLLAll2AllManager
self
.
all2all_manager
=
DeepEPLLAll2AllManager
(
self
.
cpu_group
)
logger
.
info
(
"Using DeepEP Low-Latency all2all manager."
)
elif
all2all_backend
==
"deepep_auto"
:
from
.all2all
import
DeepEPAutoAll2AllManager
self
.
all2all_manager
=
DeepEPAutoAll2AllManager
(
self
.
cpu_group
)
logger
.
info
(
"Using DeepEP Auto all2all manager."
)
elif
all2all_backend
==
"mori"
:
pass
else
:
...
...
vllm/model_executor/layers/fused_moe/config.py
View file @
0d3ae2fc
...
...
@@ -187,6 +187,11 @@ class FusedMoEParallelConfig:
return
(
self
.
use_all2all_kernels
and
envs
.
VLLM_ALL2ALL_BACKEND
==
"deepep_low_latency"
)
@
property
def
use_deepep_auto_kernels
(
self
):
return
(
self
.
use_all2all_kernels
and
envs
.
VLLM_ALL2ALL_BACKEND
==
"deepep_auto"
)
@
staticmethod
def
make
(
tp_size_
:
int
,
dp_size_
:
int
,
vllm_parallel_config
:
ParallelConfig
)
->
"FusedMoEParallelConfig"
:
...
...
@@ -385,6 +390,10 @@ class FusedMoEConfig:
def
use_deepep_ll_kernels
(
self
):
return
self
.
moe_parallel_config
.
use_deepep_ll_kernels
@
property
def
use_deepep_auto_kernels
(
self
):
return
self
.
moe_parallel_config
.
use_deepep_auto_kernels
@
staticmethod
def
make
(
num_experts
:
int
,
...
...
vllm/model_executor/layers/fused_moe/deepep_auto_prepare_finalize.py
0 → 100644
View file @
0d3ae2fc
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from
typing
import
Optional
import
torch
import
vllm.model_executor.layers.fused_moe.modular_kernel
as
mk
from
vllm.model_executor.layers.fused_moe.config
import
FusedMoEQuantConfig
from
vllm.forward_context
import
get_forward_context
class
DeepEPAutoPrepareAndFinalize
(
mk
.
FusedMoEPrepareAndFinalize
):
"""
Auto Prepare/Finalize that wraps both DeepEP High-Throughput and
Low-Latency implementations and selects one based on prefill/decode phase.
"""
def
__init__
(
self
,
ht_prepare_finalize
:
mk
.
FusedMoEPrepareAndFinalize
,
ll_prepare_finalize
:
mk
.
FusedMoEPrepareAndFinalize
):
super
().
__init__
()
self
.
ht_prepare_finalize
=
ht_prepare_finalize
self
.
ll_prepare_finalize
=
ll_prepare_finalize
self
.
_current_phase
=
"decode"
# default to prefill (HT)
def
_get_current_prepare_finalize
(
self
)
->
mk
.
FusedMoEPrepareAndFinalize
:
"""Get the appropriate prepare_finalize based on current phase."""
# Try to infer phase from forward_context if available
try
:
forward_context
=
get_forward_context
()
attn_metadata
=
forward_context
.
attn_metadata
# Handle both v0 (single AttentionMetadata) and v1 (dict) formats
if
isinstance
(
attn_metadata
,
dict
):
if
attn_metadata
:
attn_metadata
=
next
(
iter
(
attn_metadata
.
values
()))
else
:
attn_metadata
=
None
if
attn_metadata
is
not
None
and
hasattr
(
attn_metadata
,
'num_prefill_tokens'
)
and
hasattr
(
attn_metadata
,
'num_decode_tokens'
):
# Only use prefill mode when BOTH conditions are met:
# 1. There are prefill tokens and no decode tokens
# 2. skip_cuda_graphs is True
is_prefill_tokens
=
attn_metadata
.
num_prefill_tokens
>
0
and
attn_metadata
.
num_decode_tokens
==
0
skip_cuda_graphs
=
forward_context
.
skip_cuda_graphs
# Only use prefill (HT) when both conditions are satisfied
self
.
_current_phase
=
"prefill"
if
(
is_prefill_tokens
and
skip_cuda_graphs
)
else
"decode"
except
Exception
:
# If forward_context is not available, use stored phase
pass
# Prefill uses HT, decode uses LL
# print("self._current_phase",self._current_phase)
# if self._current_phase == "prefill":
# return self.ht_prepare_finalize
# else:
return
self
.
ll_prepare_finalize
@
property
def
activation_format
(
self
)
->
mk
.
FusedMoEActivationFormat
:
# Use the current prepare_finalize's activation format
# Note: HT uses Standard, LL uses BatchedExperts
# Dynamically return based on current phase
prepare_finalize
=
self
.
_get_current_prepare_finalize
()
return
prepare_finalize
.
activation_format
def
topk_indices_dtype
(
self
)
->
Optional
[
torch
.
dtype
]:
# Both HT and LL return int64
return
torch
.
int64
def
max_num_tokens_per_rank
(
self
)
->
Optional
[
int
]:
# LL has a limit, HT returns None
return
self
.
ll_prepare_finalize
.
max_num_tokens_per_rank
()
def
num_dispatchers
(
self
)
->
int
:
# Both should return the same value
return
self
.
ht_prepare_finalize
.
num_dispatchers
()
def
prepare
(
self
,
a1
:
torch
.
Tensor
,
a1_scale
:
Optional
[
torch
.
Tensor
],
a2_scale
:
Optional
[
torch
.
Tensor
],
topk_weights
:
torch
.
Tensor
,
topk_ids
:
torch
.
Tensor
,
num_experts
:
int
,
expert_map
:
Optional
[
torch
.
Tensor
],
apply_router_weight_on_input
:
bool
,
quant_config
:
FusedMoEQuantConfig
,
)
->
tuple
[
torch
.
Tensor
,
Optional
[
torch
.
Tensor
],
Optional
[
torch
.
Tensor
],
Optional
[
torch
.
Tensor
],
Optional
[
torch
.
Tensor
]]:
"""Route prepare call to the appropriate implementation."""
prepare_finalize
=
self
.
_get_current_prepare_finalize
()
return
prepare_finalize
.
prepare
(
a1
,
a1_scale
,
a2_scale
,
topk_weights
,
topk_ids
,
num_experts
,
expert_map
,
apply_router_weight_on_input
,
quant_config
)
def
finalize
(
self
,
output
:
torch
.
Tensor
,
fused_expert_output
:
torch
.
Tensor
,
topk_weights
:
torch
.
Tensor
,
topk_ids
:
torch
.
Tensor
,
apply_router_weight_on_input
:
bool
,
apply_weights_and_reduce
:
bool
=
True
)
->
None
:
"""Route finalize call to the appropriate implementation."""
prepare_finalize
=
self
.
_get_current_prepare_finalize
()
return
prepare_finalize
.
finalize
(
output
,
fused_expert_output
,
topk_weights
,
topk_ids
,
apply_router_weight_on_input
,
apply_weights_and_reduce
)
def
finalize_async
(
self
,
output
:
torch
.
Tensor
,
fused_expert_output
:
torch
.
Tensor
,
topk_weights
:
torch
.
Tensor
,
topk_ids
:
torch
.
Tensor
,
apply_router_weight_on_input
:
bool
,
apply_weights_and_reduce
:
bool
=
True
):
"""Route finalize_async call to the appropriate implementation if available."""
prepare_finalize
=
self
.
_get_current_prepare_finalize
()
if
hasattr
(
prepare_finalize
,
'finalize_async'
):
return
prepare_finalize
.
finalize_async
(
output
,
fused_expert_output
,
topk_weights
,
topk_ids
,
apply_router_weight_on_input
,
apply_weights_and_reduce
)
else
:
# Fallback to synchronous finalize
return
prepare_finalize
.
finalize
(
output
,
fused_expert_output
,
topk_weights
,
topk_ids
,
apply_router_weight_on_input
,
apply_weights_and_reduce
)
vllm/model_executor/layers/fused_moe/layer.py
View file @
0d3ae2fc
...
...
@@ -55,6 +55,7 @@ if current_platform.is_cuda_alike():
from
.deepep_ht_prepare_finalize
import
DeepEPHTPrepareAndFinalize
from
.deepep_ll_prepare_finalize
import
(
DEEPEP_QUANT_BLOCK_SHAPE
,
DeepEPLLPrepareAndFinalize
)
from
.deepep_auto_prepare_finalize
import
DeepEPAutoPrepareAndFinalize
else
:
fused_experts
=
None
# type: ignore
FusedMoEPermuteExpertsUnpermute
=
None
# type: ignore
...
...
@@ -140,6 +141,48 @@ class FusedMoEMethodBase(QuantizeMethodBase):
num_local_experts
=
moe
.
num_local_experts
,
num_dispatchers
=
num_dispatchers
,
)
elif
moe
.
use_deepep_auto_kernels
:
# Initialize both HT and LL prepare_finalize but reuse the single
# LL handle for both (sglang-style single handle)
assert
moe
.
dp_size
==
all2all_manager
.
dp_world_size
ll_all_to_all_args
=
dict
(
max_num_tokens_per_dp_rank
=
moe
.
max_num_tokens
,
token_hidden_size
=
moe
.
hidden_dim
,
num_ep_ranks
=
all2all_manager
.
world_size
,
num_global_experts
=
moe
.
num_experts
,
num_local_experts
=
moe
.
num_experts
//
all2all_manager
.
world_size
,
)
ll_handle
=
all2all_manager
.
get_handle
(
ll_all_to_all_args
)
# HT prepare/finalize built on the same LL handle per request
ht_prepare_finalize
=
DeepEPHTPrepareAndFinalize
(
ll_handle
,
num_dispatchers
=
all2all_manager
.
world_size
,
dp_size
=
all2all_manager
.
dp_world_size
,
rank_expert_offset
=
all2all_manager
.
rank
*
moe
.
num_local_experts
,
)
use_fp8_dispatch
=
(
moe
.
quant_config
is
not
None
and
moe
.
quant_config
.
quant_dtype
==
current_platform
.
fp8_dtype
()
and
moe
.
quant_config
.
block_shape
==
DEEPEP_QUANT_BLOCK_SHAPE
)
use_int8_dispatch
=
False
ll_prepare_finalize
=
DeepEPLLPrepareAndFinalize
(
ll_handle
,
max_tokens_per_rank
=
moe
.
max_num_tokens
,
num_dispatchers
=
all2all_manager
.
world_size
,
use_fp8_dispatch
=
use_fp8_dispatch
,
use_int8_dispatch
=
use_int8_dispatch
,
)
prepare_finalize
=
DeepEPAutoPrepareAndFinalize
(
ht_prepare_finalize
,
ll_prepare_finalize
)
elif
moe
.
use_deepep_ht_kernels
:
assert
moe
.
dp_size
==
all2all_manager
.
dp_world_size
...
...
vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors_moe_marlin.py
View file @
0d3ae2fc
...
...
@@ -84,11 +84,14 @@ class CompressedTensorsW8A8Int8MarlinMoEMethod(CompressedTensorsMarlinMoEMethod)
parallel_config
=
vllm_config
.
parallel_config
self
.
dp_size
=
get_dp_group
().
world_size
self
.
ep_size
=
get_ep_group
().
world_size
backend
=
envs
.
VLLM_ALL2ALL_BACKEND
self
.
use_deepep
=
self
.
dp_size
>
1
and
parallel_config
.
enable_expert_parallel
and
\
(
envs
.
VLLM_ALL2ALL_BACKEND
==
"deepep_high_throughput"
or
\
envs
.
VLLM_ALL2ALL_BACKEND
==
"deepep_low_latency"
)
(
backend
==
"deepep_high_throughput"
or
\
backend
==
"deepep_low_latency"
or
\
backend
==
"deepep_auto"
)
self
.
use_deepep_ll
=
self
.
use_deepep
and
envs
.
VLLM_ALL2ALL_BACKEND
==
"deepep_low_latency"
self
.
use_deepep_ll
=
self
.
use_deepep
and
(
backend
==
"deepep_low_latency"
or
\
(
backend
==
"deepep_auto"
))
if
self
.
use_deepep
:
all2all_manager
=
get_ep_group
().
device_communicator
.
all2all_manager
...
...
vllm/model_executor/models/deepseek_v2.py
View file @
0d3ae2fc
...
...
@@ -174,8 +174,12 @@ class DeepseekV2MoE(nn.Module):
dp_size
=
get_dp_group
().
world_size
self
.
use_mori_ep
=
parallel_config
.
enable_expert_parallel
and
dp_size
>
1
and
envs
.
VLLM_ALL2ALL_BACKEND
==
'mori'
self
.
enable_expert_parallel
=
parallel_config
.
enable_expert_parallel
self
.
use_deepep_ll
=
dp_size
>
1
and
parallel_config
.
enable_expert_parallel
and
\
envs
.
VLLM_ALL2ALL_BACKEND
==
"deepep_low_latency"
backend
=
envs
.
VLLM_ALL2ALL_BACKEND
self
.
use_deepep_ll
=
(
dp_size
>
1
and
parallel_config
.
enable_expert_parallel
and
(
backend
==
"deepep_low_latency"
or
backend
==
"deepep_auto"
)
)
if
not
self
.
use_deepep_ll
:
moe_cls
=
FusedMoE
if
not
self
.
use_mori_ep
else
MoriMoE
...
...
@@ -717,8 +721,12 @@ class DeepseekV2DecoderLayer(nn.Module):
self
.
dp_size
=
get_dp_group
().
world_size
vllm_config
=
get_current_vllm_config
()
parallel_config
=
vllm_config
.
parallel_config
self
.
use_deepep_ll
=
self
.
dp_size
>
1
and
parallel_config
.
enable_expert_parallel
and
\
envs
.
VLLM_ALL2ALL_BACKEND
==
"deepep_low_latency"
backend
=
envs
.
VLLM_ALL2ALL_BACKEND
self
.
use_deepep_ll
=
(
self
.
dp_size
>
1
and
parallel_config
.
enable_expert_parallel
and
(
backend
==
"deepep_low_latency"
or
backend
==
"deepep_auto"
)
)
self
.
tp_size
=
get_tensor_model_parallel_world_size
()
if
(
config
.
n_routed_experts
is
not
None
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment