Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
4fadef92
Commit
4fadef92
authored
Dec 17, 2025
by
王敏
Browse files
[feat]修复低延迟错误
parent
10400c58
Changes
7
Show whitespace changes
Inline
Side-by-side
Showing
7 changed files
with
33 additions
and
78 deletions
+33
-78
vllm/config.py
vllm/config.py
+1
-0
vllm/distributed/device_communicators/all2all.py
vllm/distributed/device_communicators/all2all.py
+1
-1
vllm/model_executor/layers/fused_moe/deepep_ht_prepare_finalize.py
...l_executor/layers/fused_moe/deepep_ht_prepare_finalize.py
+0
-9
vllm/model_executor/layers/fused_moe/deepep_ll_prepare_finalize.py
...l_executor/layers/fused_moe/deepep_ll_prepare_finalize.py
+20
-16
vllm/model_executor/layers/fused_moe/modular_kernel.py
vllm/model_executor/layers/fused_moe/modular_kernel.py
+1
-35
vllm/model_executor/layers/fused_moe/utils.py
vllm/model_executor/layers/fused_moe/utils.py
+0
-6
vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors_moe_marlin.py
...ation/compressed_tensors/compressed_tensors_moe_marlin.py
+10
-11
No files found.
vllm/config.py
View file @
4fadef92
...
...
@@ -4783,6 +4783,7 @@ class VllmConfig:
mtp_batch_size_capture_list
=
list
(
map
(
lambda
x
:
x
*
(
1
+
self
.
speculative_config
.
num_lookahead_slots
),
batch_size_capture_list
))
batch_size_capture_list
=
sorted
(
set
(
batch_size_capture_list
+
mtp_batch_size_capture_list
))
batch_size_capture_list
=
[
i
for
i
in
batch_size_capture_list
if
i
==
1
or
i
%
(
1
+
self
.
speculative_config
.
num_lookahead_slots
)
==
0
]
self
.
compilation_config
.
init_with_cudagraph_sizes
(
batch_size_capture_list
)
...
...
vllm/distributed/device_communicators/all2all.py
View file @
4fadef92
...
...
@@ -140,7 +140,7 @@ class DeepEPAll2AllManagerBase(All2AllManagerBase):
# This is the DeepEP default. Stick to it till we can establish
# reasonable defaults based on profiling.
self
.
num_sms
=
24
#2
0
self
.
num_sms
=
3
0
def
get_handle
(
self
,
kwargs
):
raise
NotImplementedError
...
...
vllm/model_executor/layers/fused_moe/deepep_ht_prepare_finalize.py
View file @
4fadef92
...
...
@@ -298,15 +298,6 @@ class DeepEPHTPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize):
)
->
Callable
|
None
:
assert
self
.
handle
is
not
None
# fused_expert_output can have 0 tokens - This happens when none of the
# tokens from the all2all reach this EP rank.
# if fused_expert_output.numel() != 0 and apply_weights_and_reduce:
# fused_expert_output = self._apply_weights_and_reduce(
# num_tokens=topk_ids.size(0),
# fused_expert_output=fused_expert_output,
# topk_weights=topk_weights,
# apply_router_weight_on_input=apply_router_weight_on_input,
# output_dtype=output.dtype)
combined_x
,
_
,
event
=
self
.
buffer
.
combine
(
# HT combine only supports BF16
x
=
fused_expert_output
,
...
...
vllm/model_executor/layers/fused_moe/deepep_ll_prepare_finalize.py
View file @
4fadef92
...
...
@@ -74,18 +74,19 @@ class DeepEPLLPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize):
def
_do_quant
(
self
,
x
:
Union
[
torch
.
Tensor
,
tuple
[
torch
.
Tensor
,
torch
.
Tensor
]
]
,
x
:
torch
.
Tensor
|
tuple
[
torch
.
Tensor
,
torch
.
Tensor
],
a1_scale
:
Optional
[
torch
.
Tensor
],
a2_scale
:
Optional
[
torch
.
Tensor
],
a1_dtype
:
torch
.
dtype
,
quant_dtype
:
Optional
[
torch
.
dtype
],
per_act_token_quant
:
bool
,
block_shape
:
Optional
[
list
[
int
]],
expert_num_tokens
:
Optional
[
torch
.
Tensor
]
=
None
,
quant_config
:
FusedMoEQuantConfig
,
expert_num_tokens
:
Optional
[
torch
.
Tensor
]
=
None
,
)
->
tuple
[
torch
.
Tensor
,
Optional
[
torch
.
Tensor
]]:
block_k
=
block_shape
[
1
]
if
block_shape
is
not
None
else
None
if
self
.
use_fp8_dispatch
:
block_k
=
(
quant_config
.
block_shape
[
1
]
if
quant_config
.
block_shape
is
not
None
else
None
)
if
block_k
==
DEEPEP_QUANT_BLOCK_SIZE
:
# DeepEP kernels did the quantization for us.
x
,
x_scales
=
x
...
...
@@ -102,14 +103,17 @@ class DeepEPLLPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize):
# TODO (varun): Optimization - Use a batched version of quant
if
expert_num_tokens
is
None
:
x
=
x
.
view
((
-
1
,
hidden_dim
))
x
,
x_scales
=
moe_kernel_quantize_input
(
x
,
a1_scale
,
quant_dtype
,
per_act_token_quant
,
block_shape
,
expert_num_tokens
)
if
expert_num_tokens
is
None
:
x
,
x_scales
=
moe_kernel_quantize_input
(
x
,
a1_scale
,
quant_config
.
quant_dtype
,
quant_config
.
per_act_token_quant
,
quant_config
.
block_shape
,
expert_num_tokens
)
x
=
x
.
view
((
num_experts
,
-
1
,
hidden_dim
))
if
quant_dtype
is
not
None
:
if
quant_config
.
quant_dtype
is
not
None
:
assert
x_scales
is
not
None
x_scales
=
normalize_batched_scales_shape
(
x_scales
,
num_experts
)
...
...
@@ -151,7 +155,7 @@ class DeepEPLLPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize):
a1
=
a1
*
topk_weights
.
to
(
a1
.
dtype
)
# Dispatch
expert_x
,
expert_num_tokens
,
self
.
handle
s
,
_
,
hook
=
self
.
buffer
.
low_latency_dispatch
(
expert_x
,
expert_num_tokens
,
self
.
handle
,
_
,
hook
=
self
.
buffer
.
low_latency_dispatch
(
a1
,
topk_ids
,
self
.
max_tokens_per_rank
,
...
...
@@ -181,7 +185,7 @@ class DeepEPLLPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize):
a1_dtype
:
torch
.
dtype
,
quant_config
:
FusedMoEQuantConfig
,
)
->
mk
.
PrepareResultType
:
expert_x
,
expert_x_scale
=
self
.
_do_quant
(
expert_x
,
a1_dtype
,
quant_config
)
expert_x
,
expert_x_scale
=
self
.
_do_quant
(
expert_x
,
a1_scale
,
a1_dtype
,
quant_config
,
expert_num_tokens
)
expert_tokens_meta
=
mk
.
ExpertTokensMetadata
(
expert_num_tokens
=
expert_num_tokens
,
expert_num_tokens_cpu
=
None
...
...
vllm/model_executor/layers/fused_moe/modular_kernel.py
View file @
4fadef92
...
...
@@ -12,7 +12,7 @@ import torch
import
vllm.envs
as
envs
from
vllm.model_executor.layers.fused_moe.config
import
FusedMoEQuantConfig
from
vllm.model_executor.layers.fused_moe.utils
import
_resize_cache
from
vllm.utils
import
cdiv
,
async_tensor_h2d
from
vllm.utils
import
cdiv
#
# This file defines a set of base classes used to make MoE kernels more modular.
...
...
@@ -112,9 +112,6 @@ class ExpertTokensMetadata:
def
make_from_list
(
expert_num_tokens_list
:
list
[
int
],
device
:
str
)
->
"ExpertTokensMetadata"
:
# expert_num_tokens_cpu = torch.tensor(
# expert_num_tokens_list, device="cpu", dtype=torch.int32
# )
expert_num_tokens_cpu
=
torch
.
tensor
(
expert_num_tokens_list
,
device
=
"cpu"
,
dtype
=
torch
.
int32
,
pin_memory
=
True
)
...
...
@@ -813,21 +810,6 @@ class FusedMoEModularKernel(torch.nn.Module):
return
output
_aux_stream
:
torch
.
cuda
.
Stream
|
None
=
None
def
aux_stream
()
->
torch
.
cuda
.
Stream
|
None
:
"""
Ensures aux_stream is initialized only once
"""
global
_aux_stream
# TODO: validate this works properly on ROCm platform.
if
_aux_stream
is
None
:
_aux_stream
=
torch
.
cuda
.
Stream
()
return
_aux_stream
@
final
class
DeepGemmDisabledFusedMoEModularKernel
(
torch
.
nn
.
Module
):
"""
...
...
@@ -853,10 +835,6 @@ class DeepGemmDisabledFusedMoEModularKernel(torch.nn.Module):
self
.
fused_experts
=
fused_experts
self
.
shared_experts
=
shared_experts
if
self
.
shared_experts
is
not
None
:
self
.
shared_experts_stream
=
aux_stream
()
self
.
shared_experts_overlap_event
=
torch
.
cuda
.
Event
()
# assert prepare_finalize.activation_format == \
# fused_experts.activation_formats[0], (
# f"{prepare_finalize.__class__.__name__}."
...
...
@@ -933,18 +911,6 @@ class DeepGemmDisabledFusedMoEModularKernel(torch.nn.Module):
if
global_num_experts
==
-
1
:
global_num_experts
=
local_num_experts
# (a1q, a1q_scale, expert_num_tokens, _expert_topk_ids,
# _expert_topk_weights) = self.prepare_finalize.prepare(
# a1,
# a1_scale,
# a2_scale,
# topk_weights,
# topk_ids,
# global_num_experts,
# expert_map,
# apply_router_weight_on_input,
# self.fused_experts.quant_config,
# )
prepare_ret
=
self
.
prepare_finalize
.
prepare_async
(
a1
,
a1_scale
,
...
...
vllm/model_executor/layers/fused_moe/utils.py
View file @
4fadef92
...
...
@@ -812,14 +812,8 @@ def deepgemm_moe_permute(
aq_scale_out
=
torch
.
empty
(
(
M_sum
,
aq_scale
.
shape
[
-
1
]),
device
=
device
,
dtype
=
torch
.
float32
#(M_sum, H // block_k), device=device, dtype=torch.float32
)
# maybe_has_empty_blocks = expert_num_tokens_cpu is None
# expert_ids_init = torch.zeros# if maybe_has_empty_blocks else torch.empty
# expert_ids = expert_ids_init((M_sum), device=device, dtype=torch.int32)
expert_ids
=
torch
.
full
(
(
M_sum
,),
-
1
,
dtype
=
torch
.
int32
,
device
=
device
)
...
...
vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors_moe_marlin.py
View file @
4fadef92
...
...
@@ -91,8 +91,7 @@ class CompressedTensorsW8A8Int8MarlinMoEMethod(CompressedTensorsMarlinMoEMethod)
(
envs
.
VLLM_ALL2ALL_BACKEND
==
"deepep_high_throughput"
or
\
envs
.
VLLM_ALL2ALL_BACKEND
==
"deepep_low_latency"
)
#self.use_deepep_ll = self.use_deepep and envs.VLLM_ALL2ALL_BACKEND == "deepep_low_latency"
self
.
use_deepgemm
=
False
if
self
.
use_deepep
:
all2all_manager
=
get_ep_group
().
device_communicator
.
all2all_manager
assert
all2all_manager
is
not
None
...
...
@@ -332,6 +331,12 @@ class CompressedTensorsW8A8Int8MarlinMoEMethod(CompressedTensorsMarlinMoEMethod)
mm1_out
=
_resize_cache
(
workspace13
,
(
M_sum
,
N
))
mm2_out
=
_resize_cache
(
workspace2
,
(
M_sum
,
K
))
# act_out = _resize_cache(workspace2.view(dtype=torch.int8), (M_sum, N // 2))
# act_out = _resize_cache(
# workspace13.view(dtype=torch.int8), (M_sum, N // 2)
# )
fused_out
=
_resize_cache
(
workspace13
,
fused_out_shape
)
a1q_perm
=
_resize_cache
(
workspace2
.
view
(
dtype
=
a1q
.
dtype
),
(
M_sum
,
K
))
...
...
@@ -349,18 +354,12 @@ class CompressedTensorsW8A8Int8MarlinMoEMethod(CompressedTensorsMarlinMoEMethod)
M_sum
=
M_sum
)
# if expert_map is not None:
# # DeepGemm (Grouped Contiguous) kernel needs a valid B index
# # for all rows of A. To that effect, simply compute with
# # the 0th weight matrix.
# # Note that this relies on the fact that corresponding topk
# # weights would be 0 during weight multiplication.
# expert_ids = torch.where(expert_ids == -1, 0, expert_ids)
m_grouped_w8a8_gemm_nt_contig_asm
(
(
a1q
,
a1q_scale
),
(
w1
,
w1_scale
),
mm1_out
,
expert_ids
)
#a2q, a2q_scale = fuse_silu_mul_quant(mm1_out, expert_ids=expert_ids)
a2q
,
a2q_scale
=
fuse_silu_mul_quant
(
mm1_out
)
#a2q, a2q_scale = fuse_silu_mul_quant(input=mm1_out, output=act_out, expert_ids=expert_ids)
m_grouped_w8a8_gemm_nt_contig_asm
(
(
a2q
,
a2q_scale
),
(
w2
,
w2_scale
),
mm2_out
,
expert_ids
)
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment