Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
6b3bb3ae
Commit
6b3bb3ae
authored
Mar 18, 2026
by
chenhw5
Browse files
sbo-deepep-gemm based on v0.9.2-dev-0316-dp
parent
236266a9
Changes
10
Hide whitespace changes
Inline
Side-by-side
Showing
10 changed files
with
495 additions
and
79 deletions
+495
-79
vllm/envs.py
vllm/envs.py
+5
-1
vllm/model_executor/layers/fused_moe/batched_deep_gemm_moe.py
.../model_executor/layers/fused_moe/batched_deep_gemm_moe.py
+47
-13
vllm/model_executor/layers/fused_moe/deepep_auto_prepare_finalize.py
...executor/layers/fused_moe/deepep_auto_prepare_finalize.py
+1
-0
vllm/model_executor/layers/fused_moe/deepep_ht_prepare_finalize.py
...l_executor/layers/fused_moe/deepep_ht_prepare_finalize.py
+1
-0
vllm/model_executor/layers/fused_moe/deepep_ll_prepare_finalize.py
...l_executor/layers/fused_moe/deepep_ll_prepare_finalize.py
+149
-12
vllm/model_executor/layers/fused_moe/fused_batched_moe.py
vllm/model_executor/layers/fused_moe/fused_batched_moe.py
+1
-0
vllm/model_executor/layers/fused_moe/modular_kernel.py
vllm/model_executor/layers/fused_moe/modular_kernel.py
+277
-42
vllm/model_executor/layers/fused_moe/pplx_prepare_finalize.py
.../model_executor/layers/fused_moe/pplx_prepare_finalize.py
+1
-0
vllm/model_executor/layers/fused_moe/prepare_finalize.py
vllm/model_executor/layers/fused_moe/prepare_finalize.py
+1
-0
vllm/version.py
vllm/version.py
+12
-11
No files found.
vllm/envs.py
View file @
6b3bb3ae
...
@@ -203,6 +203,7 @@ if TYPE_CHECKING:
...
@@ -203,6 +203,7 @@ if TYPE_CHECKING:
VLLM_USE_FUSED_RMS_ROPE
:
bool
=
False
VLLM_USE_FUSED_RMS_ROPE
:
bool
=
False
VLLM_V1_USE_REDUCED_TOPK_TOPP_SAMPLER
:
bool
=
False
VLLM_V1_USE_REDUCED_TOPK_TOPP_SAMPLER
:
bool
=
False
VLLM_USE_FUSED_FILL_RMS_CAT
:
bool
=
False
VLLM_USE_FUSED_FILL_RMS_CAT
:
bool
=
False
VLLM_EP_USE_SBO
:
bool
=
False
VLLM_ENABLE_DEEPEP_HT_DEEPGEMM
:
bool
=
True
VLLM_ENABLE_DEEPEP_HT_DEEPGEMM
:
bool
=
True
VLLM_ENABLE_DEEPEP_INT8_DISPATCH
:
bool
=
True
VLLM_ENABLE_DEEPEP_INT8_DISPATCH
:
bool
=
True
VLLM_ZERO_OVERHEAD_ENHANCE
:
bool
=
False
VLLM_ZERO_OVERHEAD_ENHANCE
:
bool
=
False
...
@@ -1332,7 +1333,10 @@ environment_variables: dict[str, Callable[[], Any]] = {
...
@@ -1332,7 +1333,10 @@ environment_variables: dict[str, Callable[[], Any]] = {
"VLLM_V1_USE_REDUCED_TOPK_TOPP_SAMPLER"
:
"VLLM_V1_USE_REDUCED_TOPK_TOPP_SAMPLER"
:
lambda
:
(
os
.
getenv
(
"VLLM_V1_USE_REDUCED_TOPK_TOPP_SAMPLER"
,
lambda
:
(
os
.
getenv
(
"VLLM_V1_USE_REDUCED_TOPK_TOPP_SAMPLER"
,
"0"
).
lower
()
in
(
"true"
,
"1"
)),
"0"
).
lower
()
in
(
"true"
,
"1"
)),
# Whether to use single batch overlapping optimization
"VLLM_EP_USE_SBO"
:
lambda
:
bool
(
int
(
os
.
getenv
(
"VLLM_EP_USE_SBO"
,
"0"
))),
# vLLM will use deepgemm kernel for deepep ht mode
# vLLM will use deepgemm kernel for deepep ht mode
"VLLM_ENABLE_DEEPEP_HT_DEEPGEMM"
:
"VLLM_ENABLE_DEEPEP_HT_DEEPGEMM"
:
lambda
:
(
os
.
getenv
(
'VLLM_ENABLE_DEEPEP_HT_DEEPGEMM'
,
'1'
).
lower
()
in
lambda
:
(
os
.
getenv
(
'VLLM_ENABLE_DEEPEP_HT_DEEPGEMM'
,
'1'
).
lower
()
in
...
...
vllm/model_executor/layers/fused_moe/batched_deep_gemm_moe.py
View file @
6b3bb3ae
...
@@ -10,6 +10,17 @@ from vllm.model_executor.layers.fused_moe.config import FusedMoEQuantConfig
...
@@ -10,6 +10,17 @@ from vllm.model_executor.layers.fused_moe.config import FusedMoEQuantConfig
from
vllm.model_executor.layers.fused_moe.utils
import
_resize_cache
from
vllm.model_executor.layers.fused_moe.utils
import
_resize_cache
from
vllm.triton_utils
import
tl
,
triton
from
vllm.triton_utils
import
tl
,
triton
from
vllm.utils.import_utils
import
has_deep_gemm
from
lightop
import
fuse_silu_mul_quant_ep
,
fuse_silu_mul_fp8_quant_ep
if
has_deep_gemm
():
from
deep_gemm
import
m_grouped_w8a8_gemm_nt_masked
,
m_grouped_fp8_gemm_nt_masked
else
:
from
lightop
import
m_grouped_w8a8_gemm_nt_masked
,
m_grouped_fp8_gemm_nt_masked
from
typing
import
Any
logger
=
init_logger
(
__name__
)
logger
=
init_logger
(
__name__
)
...
@@ -261,6 +272,8 @@ class BatchedDeepGemmExperts(mk.FusedMoEPermuteExpertsUnpermute):
...
@@ -261,6 +272,8 @@ class BatchedDeepGemmExperts(mk.FusedMoEPermuteExpertsUnpermute):
workspace13
:
torch
.
Tensor
,
workspace13
:
torch
.
Tensor
,
workspace2
:
torch
.
Tensor
,
workspace2
:
torch
.
Tensor
,
expert_num_tokens
:
Optional
[
torch
.
Tensor
],
expert_num_tokens
:
Optional
[
torch
.
Tensor
],
w2_gemm_overlap_args
:
Any
=
None
,
meta_overlap_args
:
dict
[
str
,
Any
]
|
None
=
None
,
):
):
import
deep_gemm
as
dg
import
deep_gemm
as
dg
assert
hidden_states
.
ndim
==
3
assert
hidden_states
.
ndim
==
3
...
@@ -281,18 +294,39 @@ class BatchedDeepGemmExperts(mk.FusedMoEPermuteExpertsUnpermute):
...
@@ -281,18 +294,39 @@ class BatchedDeepGemmExperts(mk.FusedMoEPermuteExpertsUnpermute):
# may lead to better performance.
# may lead to better performance.
expected_m
=
max_num_tokens
expected_m
=
max_num_tokens
dg
.
m_grouped_gemm_fp8_fp8_bf16_nt_masked
((
a1q
,
a1q_scale
),
m_grouped_fp8_gemm_nt_masked
(
(
w1
,
w1_scale
),
(
a1q
,
a1q_scale
),
out
=
workspace1
,
(
w1
,
w1_scale
),
masked_m
=
expert_num_tokens
,
workspace1
,
expected_m
=
expected_m
)
expert_num_tokens
,
expected_m
,
)
assert
expert_num_tokens
is
not
None
assert
expert_num_tokens
is
not
None
a2q
,
a2q_scale
=
silu_mul_fp8_quant_deep_gemm
(
workspace1
,
a2q
,
a2q_scale
=
fuse_silu_mul_quant_ep
(
workspace1
,
expert_num_tokens
)
expert_num_tokens
)
# When using SBO, we record event here to indicate
dg
.
m_grouped_gemm_fp8_fp8_bf16_nt_masked
((
a2q
,
a2q_scale
),
# that the signal tensor and input to deepep ll combine
(
w2
,
w2_scale
),
# are ready
out
=
output
,
enable_overlap
=
w2_gemm_overlap_args
is
not
None
masked_m
=
expert_num_tokens
,
signal
=
w2_gemm_overlap_args
.
signal
if
enable_overlap
else
None
expected_m
=
expected_m
)
if
enable_overlap
:
w2_gemm_overlap_args
.
start_event
.
record
()
block_m
,
threshold
=
m_grouped_w8a8_gemm_nt_masked
((
a2q
,
a2q_scale
),
(
w2
,
w2_scale
),
output
,
expert_num_tokens
,
expected_m
,
enable_overlap
,
signal
,
)
# return meta_overlap_args to DeepEP combine.
if
meta_overlap_args
is
not
None
:
if
block_m
is
not
None
:
meta_overlap_args
[
"block_m"
]
=
block_m
if
threshold
is
not
None
:
meta_overlap_args
[
"threshold"
]
=
threshold
vllm/model_executor/layers/fused_moe/deepep_auto_prepare_finalize.py
View file @
6b3bb3ae
...
@@ -100,6 +100,7 @@ class DeepEPAutoPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize):
...
@@ -100,6 +100,7 @@ class DeepEPAutoPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize):
topk_weights
:
torch
.
Tensor
,
topk_weights
:
torch
.
Tensor
,
topk_ids
:
torch
.
Tensor
,
topk_ids
:
torch
.
Tensor
,
num_experts
:
int
,
num_experts
:
int
,
local_num_experts
:
int
,
expert_map
:
Optional
[
torch
.
Tensor
],
expert_map
:
Optional
[
torch
.
Tensor
],
apply_router_weight_on_input
:
bool
,
apply_router_weight_on_input
:
bool
,
quant_config
:
FusedMoEQuantConfig
,
quant_config
:
FusedMoEQuantConfig
,
...
...
vllm/model_executor/layers/fused_moe/deepep_ht_prepare_finalize.py
View file @
6b3bb3ae
...
@@ -245,6 +245,7 @@ class DeepEPHTPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize):
...
@@ -245,6 +245,7 @@ class DeepEPHTPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize):
topk_weights
:
torch
.
Tensor
,
topk_weights
:
torch
.
Tensor
,
topk_ids
:
torch
.
Tensor
,
topk_ids
:
torch
.
Tensor
,
num_experts
:
int
,
num_experts
:
int
,
local_num_experts
:
int
,
expert_map
:
Optional
[
torch
.
Tensor
],
expert_map
:
Optional
[
torch
.
Tensor
],
apply_router_weight_on_input
:
bool
,
apply_router_weight_on_input
:
bool
,
quant_config
:
FusedMoEQuantConfig
,
quant_config
:
FusedMoEQuantConfig
,
...
...
vllm/model_executor/layers/fused_moe/deepep_ll_prepare_finalize.py
View file @
6b3bb3ae
...
@@ -15,6 +15,47 @@ from vllm.model_executor.layers.fused_moe.utils import (
...
@@ -15,6 +15,47 @@ from vllm.model_executor.layers.fused_moe.utils import (
DEEPEP_QUANT_BLOCK_SIZE
=
128
DEEPEP_QUANT_BLOCK_SIZE
=
128
DEEPEP_QUANT_BLOCK_SHAPE
=
[
DEEPEP_QUANT_BLOCK_SIZE
,
DEEPEP_QUANT_BLOCK_SIZE
]
DEEPEP_QUANT_BLOCK_SHAPE
=
[
DEEPEP_QUANT_BLOCK_SIZE
,
DEEPEP_QUANT_BLOCK_SIZE
]
from
contextlib
import
nullcontext
from
dataclasses
import
dataclass
from
typing
import
Any
alt_stream
=
torch
.
cuda
.
Stream
()
from
vllm
import
envs
@
dataclass
class
CombineOverlapArgs
:
# Whether to use overlap for w2 gemm and deepep ll combine
overlap
:
bool
# we launch deepep ll combine on this stream, which
# is different from the default compute stream
stream
:
torch
.
cuda
.
Stream
# We record this wait even in the compute stream between
# silu_mul_fp4_quantize and w2 gemm.
# And we wait for this even before deepep ll combine on the
# combine stream to ensure signal tensors have been allocated
wait_event
:
torch
.
cuda
.
Event
# Number of CU used for combine kernel, currently hardcoded to be 32
num_sms
:
int
# The signal tensor is shared by the w2 gemm and deepep ll combine.
# w2 gemm atomic_add to the tensor to signal deepep combine can start
# send data
signal
:
torch
.
Tensor
|
None
=
None
#
block_m
:
int
=
64
# Set to the number of CU used by W2 gemm, which is a persistent kernel
# So when all CU has completed the computation for an expert,
# combine kernel can start to send data for this expert
threshold
:
int
=
32
@
dataclass
class
W2GemmOverlapArgs
:
# Number of CU used by W2 gemm
num_sms
:
int
# Same signal tensor mentioned above
signal
:
torch
.
Tensor
# Same as the wait_even in CombineOverlapArgs
start_event
:
torch
.
cuda
.
Event
def
dequant_fp8
(
expert_x_fp8
:
torch
.
Tensor
,
def
dequant_fp8
(
expert_x_fp8
:
torch
.
Tensor
,
expert_x_scales
:
torch
.
Tensor
)
->
torch
.
Tensor
:
expert_x_scales
:
torch
.
Tensor
)
->
torch
.
Tensor
:
...
@@ -59,6 +100,11 @@ class DeepEPLLPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize):
...
@@ -59,6 +100,11 @@ class DeepEPLLPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize):
self
.
handle
=
None
self
.
handle
=
None
self
.
num_dispatchers_
=
num_dispatchers
self
.
num_dispatchers_
=
num_dispatchers
# SBO: DeepEP LL overlap 配置
self
.
combine_overlap_args
:
CombineOverlapArgs
|
None
=
None
self
.
meta_overlap_args
:
dict
[
str
,
Any
]
|
None
=
None
self
.
packed_recv_count
:
torch
.
Tensor
|
None
=
None
def
num_dispatchers
(
self
)
->
int
:
def
num_dispatchers
(
self
)
->
int
:
return
self
.
num_dispatchers_
return
self
.
num_dispatchers_
...
@@ -118,6 +164,14 @@ class DeepEPLLPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize):
...
@@ -118,6 +164,14 @@ class DeepEPLLPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize):
x_scales
=
normalize_batched_scales_shape
(
x_scales
,
num_experts
)
x_scales
=
normalize_batched_scales_shape
(
x_scales
,
num_experts
)
return
x
,
x_scales
return
x
,
x_scales
def
supports_async
(
self
)
->
bool
:
"""
Indicates whether or not this class implements prepare_async and
finalize_async.
"""
return
True
def
prepare_async
(
def
prepare_async
(
self
,
self
,
...
@@ -127,6 +181,7 @@ class DeepEPLLPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize):
...
@@ -127,6 +181,7 @@ class DeepEPLLPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize):
topk_weights
:
torch
.
Tensor
,
topk_weights
:
torch
.
Tensor
,
topk_ids
:
torch
.
Tensor
,
topk_ids
:
torch
.
Tensor
,
num_experts
:
int
,
num_experts
:
int
,
local_num_experts
:
int
,
expert_map
:
Optional
[
torch
.
Tensor
],
expert_map
:
Optional
[
torch
.
Tensor
],
apply_router_weight_on_input
:
bool
,
apply_router_weight_on_input
:
bool
,
quant_config
:
FusedMoEQuantConfig
,
quant_config
:
FusedMoEQuantConfig
,
...
@@ -166,6 +221,10 @@ class DeepEPLLPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize):
...
@@ -166,6 +221,10 @@ class DeepEPLLPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize):
return_recv_hook
=
True
,
return_recv_hook
=
True
,
)
)
# We need to pass w2_gemm_overlap_args to moe implementation,
# so return it as an output paramter
w2_gemm_overlap_args
=
self
.
_create_sbo_args
(
local_num_experts
,
a1
.
device
)
return
(
return
(
hook
,
hook
,
lambda
:
self
.
_receiver
(
lambda
:
self
.
_receiver
(
...
@@ -175,6 +234,7 @@ class DeepEPLLPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize):
...
@@ -175,6 +234,7 @@ class DeepEPLLPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize):
a1
.
dtype
,
a1
.
dtype
,
quant_config
,
quant_config
,
),
),
w2_gemm_overlap_args
,
)
)
def
_receiver
(
def
_receiver
(
...
@@ -194,6 +254,9 @@ class DeepEPLLPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize):
...
@@ -194,6 +254,9 @@ class DeepEPLLPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize):
expert_num_tokens
=
expert_num_tokens
,
expert_num_tokens_cpu
=
None
expert_num_tokens
=
expert_num_tokens
,
expert_num_tokens_cpu
=
None
)
)
# SBO: save last expert_num_tokens for packed_recv_count,needed by deepep ll combine when use SBO.
self
.
packed_recv_count
=
expert_num_tokens
return
expert_x
,
expert_x_scale
,
expert_tokens_meta
,
None
,
None
return
expert_x
,
expert_x_scale
,
expert_tokens_meta
,
None
,
None
def
prepare
(
def
prepare
(
...
@@ -204,17 +267,19 @@ class DeepEPLLPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize):
...
@@ -204,17 +267,19 @@ class DeepEPLLPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize):
topk_weights
:
torch
.
Tensor
,
topk_weights
:
torch
.
Tensor
,
topk_ids
:
torch
.
Tensor
,
topk_ids
:
torch
.
Tensor
,
num_experts
:
int
,
num_experts
:
int
,
local_num_experts
:
int
,
expert_map
:
Optional
[
torch
.
Tensor
],
expert_map
:
Optional
[
torch
.
Tensor
],
apply_router_weight_on_input
:
bool
,
apply_router_weight_on_input
:
bool
,
quant_config
:
FusedMoEQuantConfig
,
quant_config
:
FusedMoEQuantConfig
,
)
->
mk
.
PrepareResultType
:
)
->
mk
.
PrepareResultType
:
hook
,
receiver
=
self
.
prepare_async
(
hook
,
receiver
,
_
=
self
.
prepare_async
(
a1
,
a1
,
a1_scale
,
a1_scale
,
a2_scale
,
a2_scale
,
topk_weights
,
topk_weights
,
topk_ids
,
topk_ids
,
num_experts
,
num_experts
,
local_num_experts
,
expert_map
,
expert_map
,
apply_router_weight_on_input
,
apply_router_weight_on_input
,
quant_config
,
quant_config
,
...
@@ -241,18 +306,45 @@ class DeepEPLLPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize):
...
@@ -241,18 +306,45 @@ class DeepEPLLPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize):
combine_topk_weights
=
torch
.
ones_like
(
topk_weights
)
combine_topk_weights
=
torch
.
ones_like
(
topk_weights
)
# TODO (varun) : Enable zero copy mode
# TODO (varun) : Enable zero copy mode
_
,
_
,
recv_hook
=
self
.
buffer
.
low_latency_combine
(
ctx
=
nullcontext
()
fused_expert_output
,
if
self
.
combine_overlap_args
is
not
None
:
topk_ids
,
# For SBO, we need to wait for compute stream
combine_topk_weights
,
# to have completed signal tensor allocation
self
.
handle
,
self
.
combine_overlap_args
.
stream
.
wait_event
(
async_finish
=
False
,
self
.
combine_overlap_args
.
wait_event
zero_copy
=
False
,
)
return_recv_hook
=
do_recv_hook
,
# And we launch ll combine phase 1 in a separate stream
out
=
output
,
# for overlaping
)
ctx
=
torch
.
cuda
.
stream
(
self
.
combine_overlap_args
.
stream
)
overlap_args_dict
=
dict
(
overlap
=
self
.
combine_overlap_args
.
overlap
,
packed_recv_count
=
self
.
packed_recv_count
,
comp_signal
=
self
.
combine_overlap_args
.
signal
,
block_m
=
self
.
meta_overlap_args
[
"block_m"
],
threshold
=
self
.
meta_overlap_args
[
"threshold"
],
num_sms
=
self
.
combine_overlap_args
.
num_sms
,
)
else
:
overlap_args_dict
=
{}
with
ctx
:
_
,
_
,
recv_hook
=
self
.
buffer
.
low_latency_combine
(
fused_expert_output
,
combine_topk_ids
,
combine_topk_weights
,
handle
,
async_finish
=
False
,
zero_copy
=
False
,
return_recv_hook
=
do_recv_hook
,
out
=
output
,
**
overlap_args_dict
,
)
if
self
.
combine_overlap_args
is
not
None
:
return
recv_hook
,
lambda
:
self
.
_sbo_wait_stream
()
else
:
return
recv_hook
,
lambda
:
None
return
recv_hook
def
finalize_async
(
self
,
output
:
torch
.
Tensor
,
fused_expert_output
:
torch
.
Tensor
,
def
finalize_async
(
self
,
output
:
torch
.
Tensor
,
fused_expert_output
:
torch
.
Tensor
,
topk_weights
:
torch
.
Tensor
,
topk_ids
:
torch
.
Tensor
,
topk_weights
:
torch
.
Tensor
,
topk_ids
:
torch
.
Tensor
,
...
@@ -283,3 +375,48 @@ class DeepEPLLPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize):
...
@@ -283,3 +375,48 @@ class DeepEPLLPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize):
apply_weights_and_reduce
,
apply_weights_and_reduce
,
do_async
=
False
,
do_async
=
False
,
)
)
def
_create_sbo_args
(
self
,
local_num_experts
:
int
,
device
:
torch
.
device
)
->
W2GemmOverlapArgs
|
None
:
# None when SBO is not enabled
w2_gemm_overlap_args
=
None
self
.
combine_overlap_args
=
None
if
not
envs
.
VLLM_EP_USE_SBO
:
self
.
meta_overlap_args
=
None
return
None
else
:
# SBO enabled
self
.
meta_overlap_args
=
{}
# empty every time, avoid use history args.
total_num_sms
=
torch
.
cuda
.
get_device_properties
(
device
=
device
).
multi_processor_count
communicate_num_sms
=
32
compute_num_sms
=
total_num_sms
-
communicate_num_sms
combine_wait_event
=
torch
.
cuda
.
Event
()
combine_overlap_args
=
CombineOverlapArgs
(
num_sms
=
communicate_num_sms
,
stream
=
alt_stream
,
wait_event
=
combine_wait_event
,
)
combine_signal
=
torch
.
zeros
(
local_num_experts
,
dtype
=
torch
.
uint32
,
device
=
device
)
w2_gemm_overlap_args
=
W2GemmOverlapArgs
(
signal
=
combine_signal
,
start_event
=
combine_wait_event
,
num_sms
=
compute_num_sms
,
)
combine_overlap_args
.
overlap
=
True
combine_overlap_args
.
signal
=
combine_signal
self
.
combine_overlap_args
=
combine_overlap_args
return
w2_gemm_overlap_args
def
_sbo_wait_stream
(
self
)
->
None
:
# When SBO enabled, ll combine phase 2 is still launched
# on the main compute stream, but we need to wait for
# ll combine 1 to complete
if
self
.
combine_overlap_args
is
not
None
:
torch
.
cuda
.
current_stream
().
wait_stream
(
self
.
combine_overlap_args
.
stream
)
vllm/model_executor/layers/fused_moe/fused_batched_moe.py
View file @
6b3bb3ae
...
@@ -502,6 +502,7 @@ class BatchedPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize):
...
@@ -502,6 +502,7 @@ class BatchedPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize):
topk_weights
:
torch
.
Tensor
,
topk_weights
:
torch
.
Tensor
,
topk_ids
:
torch
.
Tensor
,
topk_ids
:
torch
.
Tensor
,
num_experts
:
int
,
num_experts
:
int
,
local_num_experts
:
int
,
expert_map
:
Optional
[
torch
.
Tensor
],
expert_map
:
Optional
[
torch
.
Tensor
],
apply_router_weight_on_input
:
bool
,
apply_router_weight_on_input
:
bool
,
quant_config
:
FusedMoEQuantConfig
,
quant_config
:
FusedMoEQuantConfig
,
...
...
vllm/model_executor/layers/fused_moe/modular_kernel.py
View file @
6b3bb3ae
...
@@ -15,6 +15,9 @@ import vllm.envs as envs
...
@@ -15,6 +15,9 @@ import vllm.envs as envs
from
vllm.model_executor.layers.fused_moe.config
import
FusedMoEQuantConfig
from
vllm.model_executor.layers.fused_moe.config
import
FusedMoEQuantConfig
from
vllm.model_executor.layers.fused_moe.utils
import
_resize_cache
from
vllm.model_executor.layers.fused_moe.utils
import
_resize_cache
from
vllm.utils
import
cdiv
from
vllm.utils
import
cdiv
from
vllm.model_executor.layers.fused_moe.batched_deep_gemm_moe
import
(
BatchedDeepGemmExperts
,
)
#
#
# This file defines a set of base classes used to make MoE kernels more modular.
# This file defines a set of base classes used to make MoE kernels more modular.
...
@@ -163,6 +166,7 @@ class FusedMoEPrepareAndFinalize(ABC):
...
@@ -163,6 +166,7 @@ class FusedMoEPrepareAndFinalize(ABC):
topk_weights
:
torch
.
Tensor
,
topk_weights
:
torch
.
Tensor
,
topk_ids
:
torch
.
Tensor
,
topk_ids
:
torch
.
Tensor
,
num_experts
:
int
,
num_experts
:
int
,
local_num_experts
:
int
,
expert_map
:
Optional
[
torch
.
Tensor
],
expert_map
:
Optional
[
torch
.
Tensor
],
apply_router_weight_on_input
:
bool
,
apply_router_weight_on_input
:
bool
,
quant_config
:
FusedMoEQuantConfig
,
quant_config
:
FusedMoEQuantConfig
,
...
@@ -192,6 +196,14 @@ class FusedMoEPrepareAndFinalize(ABC):
...
@@ -192,6 +196,14 @@ class FusedMoEPrepareAndFinalize(ABC):
- Optional dispatched expert topk weight
- Optional dispatched expert topk weight
"""
"""
raise
NotImplementedError
raise
NotImplementedError
def
supports_async
(
self
)
->
bool
:
"""
Indicates whether or not this class implements prepare_async and
finalize_async.
"""
return
False
@
abstractmethod
@
abstractmethod
def
finalize
(
def
finalize
(
...
@@ -610,6 +622,184 @@ class FusedMoEModularKernel(torch.nn.Module):
...
@@ -610,6 +622,184 @@ class FusedMoEModularKernel(torch.nn.Module):
f
"
{
prepare_finalize
.
activation_format
}
== "
f
"
{
prepare_finalize
.
activation_format
}
== "
f
"
{
fused_experts
.
__class__
.
__name__
}
."
f
"
{
fused_experts
.
__class__
.
__name__
}
."
f
"
{
fused_experts
.
activation_formats
[
0
]
}
"
)
f
"
{
fused_experts
.
activation_formats
[
0
]
}
"
)
if
self
.
shared_experts
is
not
None
:
self
.
alt_stream
=
alt_stream
()
self
.
alt_event
=
torch
.
cuda
.
Event
()
def
_prepare
(
self
,
hidden_states
:
torch
.
Tensor
,
a1_scale
:
Optional
[
torch
.
Tensor
],
a2_scale
:
Optional
[
torch
.
Tensor
],
topk_weights
:
torch
.
Tensor
,
topk_ids
:
torch
.
Tensor
,
global_num_experts
:
int
,
local_num_experts
:
int
,
expert_map
:
torch
.
Tensor
|
None
,
apply_router_weight_on_input
:
bool
,
)
->
tuple
[
torch
.
Tensor
,
torch
.
Tensor
|
None
,
ExpertTokensMetadata
|
None
,
torch
.
Tensor
,
torch
.
Tensor
,
object
|
None
,
]:
"""
The _prepare method is a wrapper around self.prepare_finalize.prepare
that handles DBO and async.
"""
w2_gemm_overlap_args
=
None
if
not
self
.
prepare_finalize
.
supports_async
():
# We shouldn't be running an a2a kernel that doesn't
# support async prepare/finalize
# TODO(lucas): enable in follow-up
#assert not dbo_enabled()
(
a1q
,
a1q_scale
,
expert_tokens_meta
,
_expert_topk_ids
,
_expert_topk_weights
,
)
=
self
.
prepare_finalize
.
prepare
(
hidden_states
,
a1_scale
,
a2_scale
,
topk_weights
,
topk_ids
,
global_num_experts
,
local_num_experts
,
expert_map
,
apply_router_weight_on_input
,
self
.
fused_experts
.
quant_config
,
)
else
:
# Overlap shared expert compute with all2all dispatch.
#dbo_maybe_run_recv_hook()
prepare_ret
=
self
.
prepare_finalize
.
prepare_async
(
hidden_states
,
a1_scale
,
a2_scale
,
topk_weights
,
topk_ids
,
global_num_experts
,
local_num_experts
,
expert_map
,
apply_router_weight_on_input
,
self
.
fused_experts
.
quant_config
,
)
# TODO(lucas): refactor this in the alternative schedules followup
# currently unpack if we have hook + receiver pair or just
# receiver (see finalize_async docstring)
hook
,
receiver
,
w2_gemm_overlap_args
=
(
prepare_ret
if
isinstance
(
prepare_ret
,
tuple
)
else
(
None
,
prepare_ret
)
)
if
hook
is
not
None
:
# if dbo_enabled():
# # If DBO is being used, register the hook with the ubatch
# # context and call it in dbo_maybe_run_recv_hook instead of
# # passing it to the receiver.
# dbo_register_recv_hook(hook)
# dbo_yield()
# else:
hook
()
(
a1q
,
a1q_scale
,
expert_tokens_meta
,
_expert_topk_ids
,
_expert_topk_weights
,
)
=
receiver
()
# Maybe prepare gathered topk_ids and topk_weights from other EP ranks.
topk_ids
=
topk_ids
if
_expert_topk_ids
is
None
else
_expert_topk_ids
topk_weights
=
(
topk_weights
if
_expert_topk_weights
is
None
else
_expert_topk_weights
)
return
a1q
,
a1q_scale
,
expert_tokens_meta
,
topk_ids
,
topk_weights
,
w2_gemm_overlap_args
def
_finalize
(
self
,
output
:
torch
.
Tensor
,
fused_out
:
torch
.
Tensor
,
hidden_states
:
torch
.
Tensor
,
topk_weights
:
torch
.
Tensor
,
topk_ids
:
torch
.
Tensor
,
apply_router_weight_on_input
:
bool
,
)
->
torch
.
Tensor
|
tuple
[
torch
.
Tensor
,
torch
.
Tensor
]:
"""
The _finalize method is a wrapper around self.prepare_finalize.finalize
that handles DBO, async and shared expert overlap.
"""
shared_output
:
torch
.
Tensor
|
None
=
None
if
not
self
.
prepare_finalize
.
supports_async
():
#assert not dbo_enabled()
self
.
prepare_finalize
.
finalize
(
output
,
fused_out
,
topk_weights
,
topk_ids
,
apply_router_weight_on_input
,
)
if
self
.
shared_experts
is
not
None
:
shared_output
=
self
.
shared_experts
(
hidden_states
)
else
:
self
.
alt_event
.
record
()
if
self
.
shared_experts
is
not
None
:
shared_output
=
self
.
shared_experts
(
hidden_states
)
current_stream
=
torch
.
cuda
.
current_stream
()
with
torch
.
cuda
.
stream
(
self
.
alt_stream
):
self
.
alt_stream
.
wait_event
(
self
.
alt_event
)
finalize_ret
=
self
.
prepare_finalize
.
finalize_async
(
output
,
fused_out
,
topk_weights
,
topk_ids
,
apply_router_weight_on_input
,
)
# TODO(lucas): refactor this in the alternative schedules followup
# currently unpack if we have hook + receiver pair or just
# receiver (see finalize_async docstring)
hook
,
receiver
=
(
finalize_ret
if
isinstance
(
finalize_ret
,
tuple
)
else
(
None
,
finalize_ret
)
)
if
hook
is
not
None
:
# if dbo_enabled():
# # If DBO is being used, register the hook with the ubatch
# # context and call it in dbo_maybe_run_recv_hook instead of
# # passing it to the receiver.
# dbo_register_recv_hook(hook)
# dbo_yield()
# else:
hook
()
receiver
()
self
.
alt_event
.
record
()
current_stream
.
wait_event
(
self
.
alt_event
)
if
self
.
shared_experts
is
None
:
return
output
else
:
assert
shared_output
is
not
None
return
shared_output
,
output
def
forward
(
def
forward
(
self
,
self
,
...
@@ -674,13 +864,14 @@ class FusedMoEModularKernel(torch.nn.Module):
...
@@ -674,13 +864,14 @@ class FusedMoEModularKernel(torch.nn.Module):
global_num_experts
=
local_num_experts
global_num_experts
=
local_num_experts
(
a1q
,
a1q_scale
,
expert_num_tokens
,
_expert_topk_ids
,
(
a1q
,
a1q_scale
,
expert_num_tokens
,
_expert_topk_ids
,
_expert_topk_weights
)
=
self
.
prepare_finalize
.
prepare
(
_expert_topk_weights
,
w2_gemm_overlap_args
)
=
self
.
_
prepare
(
a1
,
a1
,
a1_scale
,
a1_scale
,
a2_scale
,
a2_scale
,
topk_weights
,
topk_weights
,
topk_ids
,
topk_ids
,
global_num_experts
,
global_num_experts
,
local_num_experts
,
expert_map
,
expert_map
,
apply_router_weight_on_input
,
apply_router_weight_on_input
,
self
.
fused_experts
.
quant_config
,
self
.
fused_experts
.
quant_config
,
...
@@ -739,26 +930,48 @@ class FusedMoEModularKernel(torch.nn.Module):
...
@@ -739,26 +930,48 @@ class FusedMoEModularKernel(torch.nn.Module):
if
num_chunks
==
1
:
if
num_chunks
==
1
:
fused_out
=
_resize_cache
(
workspace13
,
fused_out_shape
)
fused_out
=
_resize_cache
(
workspace13
,
fused_out_shape
)
if
isinstance
(
self
.
fused_experts
,
BatchedDeepGemmExperts
):
self
.
fused_experts
.
apply
(
self
.
fused_experts
.
apply
(
fused_out
,
fused_out
,
a1q
,
a1q
,
w1
,
w1
,
w2
,
w2
,
topk_ids
,
topk_ids
,
activation
=
activation
,
activation
=
activation
,
global_num_experts
=
global_num_experts
,
global_num_experts
=
global_num_experts
,
expert_map
=
expert_map
,
expert_map
=
expert_map
,
w1_scale
=
w1_scale
,
w1_scale
=
w1_scale
,
w2_scale
=
w2_scale
,
w2_scale
=
w2_scale
,
w1_zp
=
w1_zp
,
w1_zp
=
w1_zp
,
w2_zp
=
w2_zp
,
w2_zp
=
w2_zp
,
a1q_scale
=
a1q_scale
,
a1q_scale
=
a1q_scale
,
a2_scale
=
a2_scale
,
a2_scale
=
a2_scale
,
workspace13
=
workspace13
,
workspace13
=
workspace13
,
workspace2
=
workspace2
,
workspace2
=
workspace2
,
expert_num_tokens
=
expert_num_tokens
,
expert_num_tokens
=
expert_num_tokens
,
)
w2_gemm_overlap_args
=
w2_gemm_overlap_args
,
meta_overlap_args
=
self
.
prepare_finalize
.
meta_overlap_args
,
)
else
:
self
.
fused_experts
.
apply
(
fused_out
,
a1q
,
w1
,
w2
,
topk_ids
,
activation
=
activation
,
global_num_experts
=
global_num_experts
,
expert_map
=
expert_map
,
w1_scale
=
w1_scale
,
w2_scale
=
w2_scale
,
w1_zp
=
w1_zp
,
w2_zp
=
w2_zp
,
a1q_scale
=
a1q_scale
,
a2_scale
=
a2_scale
,
workspace13
=
workspace13
,
workspace2
=
workspace2
,
expert_num_tokens
=
expert_num_tokens
,
)
else
:
else
:
# The leading output dimension may not be equal to M, so
# The leading output dimension may not be equal to M, so
# we compute output indices separately.
# we compute output indices separately.
...
@@ -786,28 +999,50 @@ class FusedMoEModularKernel(torch.nn.Module):
...
@@ -786,28 +999,50 @@ class FusedMoEModularKernel(torch.nn.Module):
curr_a2_scale
=
_chunk_scales
(
a2_scale
,
begin_chunk_idx
,
curr_a2_scale
=
_chunk_scales
(
a2_scale
,
begin_chunk_idx
,
end_chunk_idx
)
end_chunk_idx
)
curr_topk_ids
=
topk_ids
[
begin_chunk_idx
:
end_chunk_idx
]
curr_topk_ids
=
topk_ids
[
begin_chunk_idx
:
end_chunk_idx
]
if
isinstance
(
self
.
fused_experts
,
BatchedDeepGemmExperts
):
self
.
fused_experts
.
apply
(
fused_out
[
begin_out_idx
:
end_out_idx
],
curr_a1q
,
w1
,
w2
,
curr_topk_ids
,
activation
=
activation
,
global_num_experts
=
global_num_experts
,
expert_map
=
expert_map
,
w1_scale
=
w1_scale
,
w2_scale
=
w2_scale
,
w1_zp
=
w1_zp
,
w2_zp
=
w2_zp
,
a1q_scale
=
curr_a1q_scale
,
a2_scale
=
curr_a2_scale
,
workspace13
=
workspace13
,
workspace2
=
workspace2
,
expert_num_tokens
=
expert_num_tokens
,
w2_gemm_overlap_args
=
w2_gemm_overlap_args
,
meta_overlap_args
=
self
.
prepare_finalize
.
meta_overlap_args
,
)
else
:
self
.
fused_experts
.
apply
(
fused_out
[
begin_out_idx
:
end_out_idx
],
curr_a1q
,
w1
,
w2
,
curr_topk_ids
,
activation
=
activation
,
global_num_experts
=
global_num_experts
,
expert_map
=
expert_map
,
w1_scale
=
w1_scale
,
w2_scale
=
w2_scale
,
w1_zp
=
w1_zp
,
w2_zp
=
w2_zp
,
a1q_scale
=
curr_a1q_scale
,
a2_scale
=
curr_a2_scale
,
workspace13
=
workspace13
,
workspace2
=
workspace2
,
expert_num_tokens
=
expert_num_tokens
,
)
self
.
fused_experts
.
apply
(
self
.
_finalize
(
output
,
fused_out
,
hidden_states
,
topk_weights
,
fused_out
[
begin_out_idx
:
end_out_idx
],
curr_a1q
,
w1
,
w2
,
curr_topk_ids
,
activation
=
activation
,
global_num_experts
=
global_num_experts
,
expert_map
=
expert_map
,
w1_scale
=
w1_scale
,
w2_scale
=
w2_scale
,
w1_zp
=
w1_zp
,
w2_zp
=
w2_zp
,
a1q_scale
=
curr_a1q_scale
,
a2_scale
=
curr_a2_scale
,
workspace13
=
workspace13
,
workspace2
=
workspace2
,
expert_num_tokens
=
expert_num_tokens
,
)
self
.
prepare_finalize
.
finalize
(
output
,
fused_out
,
topk_weights
,
topk_ids
,
apply_router_weight_on_input
)
topk_ids
,
apply_router_weight_on_input
)
return
output
return
output
...
...
vllm/model_executor/layers/fused_moe/pplx_prepare_finalize.py
View file @
6b3bb3ae
...
@@ -91,6 +91,7 @@ class PplxPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize):
...
@@ -91,6 +91,7 @@ class PplxPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize):
topk_weights
:
torch
.
Tensor
,
topk_weights
:
torch
.
Tensor
,
topk_ids
:
torch
.
Tensor
,
topk_ids
:
torch
.
Tensor
,
num_experts
:
int
,
num_experts
:
int
,
local_num_experts
:
int
,
expert_map
:
Optional
[
torch
.
Tensor
],
expert_map
:
Optional
[
torch
.
Tensor
],
apply_router_weight_on_input
:
bool
,
apply_router_weight_on_input
:
bool
,
quant_config
:
FusedMoEQuantConfig
,
quant_config
:
FusedMoEQuantConfig
,
...
...
vllm/model_executor/layers/fused_moe/prepare_finalize.py
View file @
6b3bb3ae
...
@@ -35,6 +35,7 @@ class MoEPrepareAndFinalizeNoEP(mk.FusedMoEPrepareAndFinalize):
...
@@ -35,6 +35,7 @@ class MoEPrepareAndFinalizeNoEP(mk.FusedMoEPrepareAndFinalize):
topk_weights
:
torch
.
Tensor
,
topk_weights
:
torch
.
Tensor
,
topk_ids
:
torch
.
Tensor
,
topk_ids
:
torch
.
Tensor
,
num_experts
:
int
,
num_experts
:
int
,
local_num_experts
:
int
,
expert_map
:
Optional
[
torch
.
Tensor
],
expert_map
:
Optional
[
torch
.
Tensor
],
apply_router_weight_on_input
:
bool
,
apply_router_weight_on_input
:
bool
,
quant_config
:
FusedMoEQuantConfig
,
quant_config
:
FusedMoEQuantConfig
,
...
...
vllm/version.py
View file @
6b3bb3ae
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
try
:
try
:
from
._version
import
__version__
,
__version_tuple__
__version__
=
"0.9.2"
__version_tuple__
=
(
0
,
9
,
2
)
__hcu_version__
=
f
'0.9.2+das.opt6.dtk25044'
from
vllm.version
import
__version__
,
__version_tuple__
,
__hcu_version__
except
Exception
as
e
:
except
Exception
as
e
:
import
warnings
import
warnings
warnings
.
warn
(
f
"Failed to read commit hash:
\n
{
e
}
"
,
warnings
.
warn
(
f
"Failed to read commit hash:
\n
+ str(e)
"
,
RuntimeWarning
,
RuntimeWarning
,
stacklevel
=
2
)
stacklevel
=
2
)
__version__
=
"dev"
__version__
=
"dev"
__version_tuple__
=
(
0
,
0
,
__version__
)
__version_tuple__
=
(
0
,
0
,
__version__
)
def
_prev_minor_version_was
(
version_str
):
def
_prev_minor_version_was
(
version_str
):
"""
Check whether a given version matches the previous minor version.
'''
Check whether a given version matches the previous minor version.
Return True if version_str matches the previous minor version.
Return True if version_str matches the previous minor version.
...
@@ -23,19 +24,19 @@ def _prev_minor_version_was(version_str):
...
@@ -23,19 +24,19 @@ def _prev_minor_version_was(version_str):
supplied version_str is '0.6'.
supplied version_str is '0.6'.
Used for --show-hidden-metrics-for-version.
Used for --show-hidden-metrics-for-version.
"""
'''
# Match anything if this is a dev tree
# Match anything if this is a dev tree
if
__version_tuple__
[
0
:
2
]
==
(
0
,
0
):
if
__version_tuple__
[
0
:
2
]
==
(
0
,
0
):
return
True
return
True
# Note - this won't do the right thing when we release 1.0!
# Note - this won't do the right thing when we release 1.0!
assert
__version_tuple__
[
0
]
==
0
#
assert __version_tuple__[0] == 0
assert
isinstance
(
__version_tuple__
[
1
],
int
)
assert
isinstance
(
__version_tuple__
[
1
],
int
)
return
version_str
==
f
"
{
__version_tuple__
[
0
]
}
.
{
__version_tuple__
[
1
]
-
1
}
"
return
version_str
==
f
"
{
__version_tuple__
[
0
]
}
.
{
__version_tuple__
[
1
]
-
1
}
"
def
_prev_minor_version
():
def
_prev_minor_version
():
"""
For the purpose of testing, return a previous minor version number.
"""
'''
For the purpose of testing, return a previous minor version number.
'''
# In dev tree, this will return "0.-1", but that will work fine"
# In dev tree, this will return "0.-1", but that will work fine"
assert
isinstance
(
__version_tuple__
[
1
],
int
)
assert
isinstance
(
__version_tuple__
[
1
],
int
)
return
f
"
{
__version_tuple__
[
0
]
}
.
{
__version_tuple__
[
1
]
-
1
}
"
return
f
"
{
__version_tuple__
[
0
]
}
.
{
__version_tuple__
[
1
]
-
1
}
"
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment