Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
dbd0bda6
"...git@developer.sourcefind.cn:2222/OpenDAS/vllm_cscc.git" did not exist on "fda9537c5e61ea8226fa7e5b60912deda97a8aab"
Commit
dbd0bda6
authored
Aug 25, 2025
by
王敏
Browse files
临时上传大ep代码
parent
15347448
Changes
8
Expand all
Hide whitespace changes
Inline
Side-by-side
Showing
8 changed files
with
1844 additions
and
48 deletions
+1844
-48
vllm/distributed/communication_op.py
vllm/distributed/communication_op.py
+12
-1
vllm/model_executor/layers/fused_moe/ep_moe/ep_moe_utlis.py
vllm/model_executor/layers/fused_moe/ep_moe/ep_moe_utlis.py
+377
-0
vllm/model_executor/layers/fused_moe/ep_moe/kernels.py
vllm/model_executor/layers/fused_moe/ep_moe/kernels.py
+638
-0
vllm/model_executor/layers/fused_moe/ep_moe/layer.py
vllm/model_executor/layers/fused_moe/ep_moe/layer.py
+253
-0
vllm/model_executor/layers/fused_moe/ep_moe/token_dispatcher.py
...odel_executor/layers/fused_moe/ep_moe/token_dispatcher.py
+467
-0
vllm/model_executor/models/deepseek_v2.py
vllm/model_executor/models/deepseek_v2.py
+74
-41
vllm/v1/engine/utils.py
vllm/v1/engine/utils.py
+21
-6
vllm/v1/worker/gpu_model_runner.py
vllm/v1/worker/gpu_model_runner.py
+2
-0
No files found.
vllm/distributed/communication_op.py
View file @
dbd0bda6
...
@@ -6,7 +6,7 @@ from typing import Any, Optional, Union
...
@@ -6,7 +6,7 @@ from typing import Any, Optional, Union
import
torch
import
torch
import
torch.distributed
import
torch.distributed
from
.parallel_state
import
get_tp_group
from
.parallel_state
import
get_tp_group
,
get_ep_group
def
tensor_model_parallel_all_reduce
(
input_
:
torch
.
Tensor
)
->
torch
.
Tensor
:
def
tensor_model_parallel_all_reduce
(
input_
:
torch
.
Tensor
)
->
torch
.
Tensor
:
...
@@ -32,6 +32,17 @@ def tensor_model_parallel_gather(input_: torch.Tensor,
...
@@ -32,6 +32,17 @@ def tensor_model_parallel_gather(input_: torch.Tensor,
"""Gather the input tensor across model parallel group."""
"""Gather the input tensor across model parallel group."""
return
get_tp_group
().
gather
(
input_
,
dst
,
dim
)
return
get_tp_group
().
gather
(
input_
,
dst
,
dim
)
def expert_parallel_all_gather(input_: torch.Tensor,
                               dim: int = -1) -> torch.Tensor:
    """All-gather the input tensor across the expert parallel group.

    Thin wrapper that forwards to ``get_ep_group().all_gather``.

    Args:
        input_: Tensor contributed by the calling rank.
        dim: Dimension argument forwarded to the group's ``all_gather``
            (defaults to the last dimension).

    Returns:
        Whatever the expert parallel group's ``all_gather`` returns.
    """
    return get_ep_group().all_gather(input_, dim)
def expert_parallel_gather(input_: torch.Tensor,
                           dst: int = 0,
                           dim: int = -1) -> Optional[torch.Tensor]:
    """Gather the input tensor across the expert parallel group.

    Thin wrapper that forwards to ``get_ep_group().gather``.

    Args:
        input_: Tensor contributed by the calling rank.
        dst: Destination rank argument forwarded to the group's ``gather``.
        dim: Dimension argument forwarded to the group's ``gather``
            (defaults to the last dimension).

    Returns:
        Whatever the expert parallel group's ``gather`` returns.
        NOTE(review): presumably ``None`` on ranks other than ``dst``, as the
        ``Optional`` return annotation suggests -- confirm against the group
        coordinator implementation.
    """
    return get_ep_group().gather(input_, dst, dim)
def
broadcast_tensor_dict
(
tensor_dict
:
Optional
[
dict
[
Any
,
Union
[
torch
.
Tensor
,
def
broadcast_tensor_dict
(
tensor_dict
:
Optional
[
dict
[
Any
,
Union
[
torch
.
Tensor
,
Any
]]]
=
None
,
Any
]]]
=
None
,
...
...
vllm/model_executor/layers/fused_moe/ep_moe/ep_moe_utlis.py
0 → 100644
View file @
dbd0bda6
import
math
from
typing
import
Callable
,
List
,
Optional
,
Tuple
,
Union
from
dataclasses
import
dataclass
import
torch
from
torch
import
nn
from
vllm.model_executor.layers.quantization.base_config
import
(
QuantizationConfig
,
QuantizeMethodBase
)
from
vllm.model_executor.layers.linear
import
(
ColumnParallelLinear
,
MergedColumnParallelLinear
,
ReplicatedLinear
,
RowParallelLinear
)
from
vllm.model_executor.layers.activation
import
SiluAndMul
from
vllm.distributed
import
(
get_dp_group
,
get_ep_group
,
get_tensor_model_parallel_rank
,
get_tensor_model_parallel_world_size
,
tensor_model_parallel_all_reduce
)
# Transformer Engine is an optional dependency. When it can be imported, its
# fused permutation functions are exposed under the ``fused_*`` aliases and
# ``HAVE_TE`` is True; otherwise the aliases are None and ``HAVE_TE`` is False
# so that callers can check availability before using the fused path.
try:
    from transformer_engine.pytorch.permutation import (
        moe_permute,
        moe_sort_chunks_by_index,
        moe_unpermute,
    )

    fused_permute = moe_permute
    fused_unpermute = moe_unpermute
    fused_sort_chunks_by_index = moe_sort_chunks_by_index
    HAVE_TE = True
except ImportError:
    # Transformer Engine is not installed: disable every fused code path.
    fused_permute = None
    fused_unpermute = None
    fused_sort_chunks_by_index = None
    HAVE_TE = False
@dataclass
class EpMoeConfig:
    """Settings bundle handed to the expert-parallel MoE components."""

    # Number of experts selected per token by the router.
    moe_router_topk: int = 2
    # Whether the fused permutation functions should be used.
    moe_permute_fusion: bool = False
    # Whether shared-expert computation is overlapped with the dispatcher.
    moe_shared_expert_overlap: bool = False
    # Expert parallel world size.
    ep_size: int = 1
    # Total number of MoE experts.
    num_moe_experts: int = 256

    @staticmethod
    def make(moe_router_topk: int = 2,
             moe_permute_fusion: bool = False,
             moe_shared_expert_overlap: bool = False,
             ep_size: int = 1,
             num_moe_experts: int = 256) -> "EpMoeConfig":
        """Build an ``EpMoeConfig`` from the given field values.

        Every argument maps one-to-one onto the dataclass field of the same
        name and carries the same default.
        """
        field_values = {
            "moe_router_topk": moe_router_topk,
            "moe_permute_fusion": moe_permute_fusion,
            "moe_shared_expert_overlap": moe_shared_expert_overlap,
            "ep_size": ep_size,
            "num_moe_experts": num_moe_experts,
        }
        return EpMoeConfig(**field_values)
class EPSharedExperts(nn.Module):
    """Shared-expert MLP used alongside the expert-parallel MoE layer.

    The module is a gate/up projection, a SiLU-and-mul activation and a down
    projection. It can be run in one shot through ``forward``, or split into
    stages (``pre_forward_comm`` -> ``linear_fc1_forward_and_act`` ->
    ``linear_fc2_forward`` -> ``post_forward_comm`` -> ``get_output``) that
    execute on a private CUDA stream so the token dispatcher can interleave
    them with its own work. The staged API is only usable when
    ``moe_shared_expert_overlap`` is True.
    """

    def __init__(
        self,
        hidden_size: int,
        intermediate_size: int,
        hidden_act: str,
        quant_config: Optional["QuantizationConfig"] = None,
        reduce_results: bool = True,
        prefix: str = "",
        moe_shared_expert_overlap: bool = True,
    ) -> None:
        """Create the projections and, if requested, the overlap state.

        Args:
            hidden_size: Input/output feature size of the MLP.
            intermediate_size: Size of each of the two merged gate/up outputs
                and input size of the down projection.
            hidden_act: Activation name; only ``"silu"`` is accepted.
            quant_config: Quantization config forwarded to both projections.
            reduce_results: Forwarded to the down projection.
            prefix: Name prefix used for the two projection layers.
            moe_shared_expert_overlap: Enables the staged, side-stream API.

        Raises:
            ValueError: If ``hidden_act`` is not ``"silu"``.
        """
        super().__init__()
        self.gate_up_proj = MergedColumnParallelLinear(
            hidden_size, [intermediate_size] * 2,
            bias=False,
            quant_config=quant_config,
            prefix=f"{prefix}.gate_up_proj")
        self.down_proj = RowParallelLinear(intermediate_size,
                                           hidden_size,
                                           bias=False,
                                           quant_config=quant_config,
                                           reduce_results=reduce_results,
                                           prefix=f"{prefix}.down_proj")
        if hidden_act != "silu":
            raise ValueError(f"Unsupported activation: {hidden_act}. "
                             "Only silu is supported for now.")
        self.act_fn = SiluAndMul()

        self.moe_shared_expert_overlap = moe_shared_expert_overlap
        # The staging slots are always defined so that the staged methods fail
        # on their own assertions rather than on a missing attribute.
        self.cached_fc1_input = None
        self.cached_fc2_input = None
        self.cached_fc2_output = None
        self.cached_output = None
        self.gate_score = None
        # The side stream is only allocated when overlap is enabled, because
        # creating it requires CUDA.
        self.stream = (torch.cuda.Stream()
                       if self.moe_shared_expert_overlap else None)

    def forward(self, x):
        """Run gate/up projection, activation and down projection on ``x``."""
        gate_up, _ = self.gate_up_proj(x)
        x = self.act_fn(gate_up)
        x, _ = self.down_proj(x)
        return x

    def linear_fc1_forward_and_act(self, overlapped_comm_output=None):
        """
        Do Linear FC1 and activation function forward.
        This function is used to overlap shared experts with the dispatcher.
        It is only useful when --moe-shared-expert-overlap is set and may be changed.

        Consumes ``cached_fc1_input`` (set by ``pre_forward_comm``) and stores
        the activated result in ``cached_fc2_input``. The
        ``overlapped_comm_output`` argument is accepted but not used.
        """
        assert self.moe_shared_expert_overlap
        # Mirror the check done in linear_fc2_forward: the stage must not run
        # before its input has been cached.
        assert self.cached_fc1_input is not None
        with torch.cuda.stream(self.stream):
            # [s, b, 4 * h/p]
            intermediate_parallel, bias_parallel = self.gate_up_proj(
                self.cached_fc1_input)
            self.cached_fc1_input = None
            if bias_parallel is not None:
                intermediate_parallel = intermediate_parallel + bias_parallel
            intermediate_parallel = self.act_fn(intermediate_parallel)
            self.cached_fc2_input = intermediate_parallel

    def linear_fc2_forward(self, overlapped_comm_output=None):
        """
        Do Linear FC2 forward.
        This function is used to overlap shared experts with the dispatcher.
        It is only useful when --moe-shared-expert-overlap is set and may be changed.

        Consumes ``cached_fc2_input`` and stores the down-projection result in
        ``cached_fc2_output``. The ``overlapped_comm_output`` argument is
        accepted but not used.
        """
        assert self.moe_shared_expert_overlap
        assert self.cached_fc2_input is not None
        with torch.cuda.stream(self.stream):
            # [s, b, h]
            self.cached_fc2_output, _ = self.down_proj(self.cached_fc2_input)
            self.cached_fc2_input = None

    def pre_forward_comm(self, input):
        """
        Stage the input for the overlapped forward.
        This function is used to overlap shared experts with the dispatcher.
        It is only useful when --moe-shared-expert-overlap is set and may be changed.

        No collective is issued here: the side stream is made to wait for the
        current stream and ``input`` is stored in ``cached_fc1_input``.
        """
        assert self.moe_shared_expert_overlap
        assert self.cached_output is None
        # The side stream must not start before the producer of `input`
        # (running on the current stream) has finished.
        self.stream.wait_stream(torch.cuda.current_stream())
        with torch.cuda.stream(self.stream):
            self.cached_fc1_input = input

    def post_forward_comm(self):
        """
        All-reduce the FC2 output across the tensor model parallel group.
        This function is used to overlap shared experts with the dispatcher.
        It is only useful when --moe-shared-expert-overlap is set and may be changed.

        Consumes ``cached_fc2_output`` and stores the reduced tensor in
        ``cached_output``.
        NOTE(review): the all-reduce here is unconditional; if ``down_proj``
        was built with ``reduce_results=True`` the output would presumably be
        reduced twice -- confirm callers pass ``reduce_results=False``.
        """
        assert self.moe_shared_expert_overlap
        assert self.cached_fc2_output is not None
        with torch.cuda.stream(self.stream):
            self.cached_output = tensor_model_parallel_all_reduce(
                self.cached_fc2_output)
            self.cached_fc2_output = None

    def get_output(self):
        """
        Gets the module forward output.
        This function is used to overlap shared experts with the dispatcher.
        It is only useful when --moe-shared-expert-overlap is set and may be changed.

        Returns ``cached_output`` (clearing the slot) after making the current
        stream wait for the side stream.
        """
        assert self.moe_shared_expert_overlap
        assert self.cached_output is not None
        with torch.cuda.stream(self.stream):
            output = self.cached_output
            self.cached_output = None
        # Consumers run on the current stream; make them wait for the
        # side-stream work that produced `output`.
        torch.cuda.current_stream().wait_stream(self.stream)
        return output
def maybe_move_tensor_to_cpu(tensor, as_numpy=False, record_stream=False):
    """Return a CPU copy of ``tensor`` when it is a CUDA tensor.

    Anything that is not a CUDA tensor (CPU tensors, ``None``, other objects)
    is returned unchanged.

    Args:
        tensor (torch.Tensor or None): The value to move to CPU.
        as_numpy (bool): Convert the CPU copy to a numpy array.
        record_stream (bool): Record the current CUDA stream on the source
            tensor, to prevent memory leak when the DtoH data transfer is on
            a side stream.

    NOTE(review): the copy is issued with ``non_blocking=True``; presumably
    callers synchronize before reading the result -- confirm.
    """
    needs_copy = torch.is_tensor(tensor) and tensor.is_cuda
    if not needs_copy:
        return tensor

    host_copy = tensor.to(torch.device("cpu"), non_blocking=True)
    if as_numpy:
        host_copy = host_copy.numpy()
    if record_stream:
        tensor.record_stream(torch.cuda.current_stream())
    return host_copy
def sort_chunks_by_idxs(input: torch.Tensor,
                        split_sizes: torch.Tensor,
                        sorted_idxs: torch.Tensor,
                        fused: bool = False):
    """Split ``input`` along dim 0 and concatenate the chunks in a new order.

    Args:
        input: Tensor to split along its first dimension.
        split_sizes: Size of each chunk along dim 0.
        sorted_idxs: For each output position, the index of the chunk to
            place there.
        fused: Use the Transformer Engine implementation instead of the
            split/cat fallback.

    Raises:
        ValueError: If ``fused`` is requested but the fused function is not
            available.
    """
    if not fused:
        chunks = torch.split(input, split_sizes.tolist(), dim=0)
        order = sorted_idxs.tolist()
        return torch.cat([chunks[chunk_id] for chunk_id in order], dim=0)

    fused_unavailable = not HAVE_TE or fused_sort_chunks_by_index is None
    if fused_unavailable:
        raise ValueError(
            "fused_sort_chunks_by_index is not available. Please install TE >= 2.1.0."
        )
    return fused_sort_chunks_by_index(input, split_sizes, sorted_idxs)
def permute(
    tokens,
    routing_map,
    num_out_tokens: Optional[int] = None,
    fused: bool = False,
    drop_and_pad: bool = False,
):
    """Group token rows by the expert they are routed to.

    The result lists, expert by expert, the rows of ``tokens`` selected for
    that expert; a token routed to several experts appears once per expert.

    Args:
        tokens (torch.Tensor): Input tokens, [num_tokens, hidden].
        routing_map (torch.Tensor): Token-to-expert mask,
            [num_tokens, num_experts].
        num_out_tokens (int, optional): Total number of output rows. In the
            non-fused path it is only read when ``drop_and_pad`` is True,
            where it determines the per-expert capacity.
        fused (bool, optional): Use the fused permute function.
        drop_and_pad (bool, optional): The dispatcher drops/pads tokens so
            every column of ``routing_map`` has the same number of non-zeros
            (the expert capacity). This allows an argsort-based selection
            with a fixed output size.

    Returns:
        In the non-fused path, a tuple ``(permuted_input, sorted_indices)``
        where ``sorted_indices`` holds the source row of every permuted row.
        In the fused path, whatever the fused permute function returns.

    Raises:
        ValueError: If ``fused`` is requested but the fused function is not
            available.
    """
    if fused:
        if not HAVE_TE or fused_permute is None:
            raise ValueError(
                "fused_permute is not available. Please install TE >= 2.1.0."
            )
        return fused_permute(tokens, routing_map, num_out_tokens)

    # Unpacking also enforces that `tokens` is two-dimensional.
    num_tokens, _ = tokens.shape
    num_experts = routing_map.shape[1]

    if drop_and_pad and num_out_tokens is not None:
        capacity = num_out_tokens // num_experts
        assert not routing_map.requires_grad
        # mask [num_tokens, num_experts] -> [num_experts, num_tokens]
        expert_major = routing_map.to(dtype=torch.int8).T.contiguous()
        # A stable descending argsort moves the selected token ids to the
        # front of every row; the first `capacity` of them are kept.
        ranked = expert_major.argsort(dim=-1, descending=True, stable=True)
        kept = ranked[:, :capacity].contiguous()
        # flatten from [num_experts, capacity] to 1D
        sorted_indices = kept.view(-1)
    else:
        # mask [num_tokens, num_experts] -> [num_experts, num_tokens]
        expert_major = routing_map.bool().T.contiguous()
        # Row e of `token_ids` is 0..num_tokens-1; masking it with row e of
        # the mask yields the tokens routed to expert e.
        token_ids = torch.arange(num_tokens, device=expert_major.device)
        token_ids = token_ids.unsqueeze(0).expand(num_experts, -1)
        sorted_indices = token_ids.masked_select(expert_major)

    # use the mapping to permute the tokens
    return tokens.index_select(0, sorted_indices), sorted_indices
def unpermute(
    permuted_tokens: torch.Tensor,
    sorted_indices: torch.Tensor,
    restore_shape: torch.Size,
    probs: torch.Tensor = None,
    routing_map: torch.Tensor = None,
    fused: bool = False,
    drop_and_pad: bool = False,
):
    """
    Undo ``permute``: scatter-add permuted rows back to their source rows.
    When ``probs`` is given, every permuted row is first scaled by the
    probability of its (token, expert) pair.

    Args:
        permuted_tokens (torch.Tensor): The permuted token tensor.
        sorted_indices (torch.Tensor): Source row of every permuted row.
        restore_shape (torch.Size): Shape of the unpermuted tensor; must
            unpack into two values, the second being the hidden size.
        probs (torch.Tensor, optional): Unpermuted probs,
            [num_tokens, num_experts].
        routing_map (torch.Tensor, optional): Token to expert mapping, shape
            [num_tokens, num_experts]. Required whenever ``probs`` is given.
        fused (bool, optional): Use the fused unpermute function.
        drop_and_pad (bool, optional): The dispatcher drops/pads tokens to a
            fixed expert capacity; probs are then picked with index
            arithmetic instead of a mask.

    Returns:
        torch.Tensor: The tokens restored to their original order, cast back
        to the dtype ``permuted_tokens`` had on entry.

    Raises:
        ValueError: If ``fused`` is requested but the fused function is not
            available.
    """
    if fused:
        if not HAVE_TE or fused_unpermute is None:
            raise ValueError(
                "fused_unpermute is not available. Please install TE >= 2.1.0."
            )
        return fused_unpermute(permuted_tokens, sorted_indices, probs,
                               restore_shape)

    _, hidden = restore_shape
    input_dtype = permuted_tokens.dtype

    if probs is not None:
        assert routing_map is not None, "Mask must be provided to permute the probs."
        if not drop_and_pad:
            expert_major_probs = probs.T.contiguous()
            expert_major_mask = routing_map.T.contiguous()
            permuted_probs = expert_major_probs.masked_select(
                expert_major_mask)
        else:
            num_experts = routing_map.size(1)
            capacity = sorted_indices.size(0) // num_experts
            num_unpermuted_tokens = probs.size(0)
            # [num_unpermuted_tokens, num_experts] ->
            # num_experts * num_unpermuted_tokens
            flat_probs = probs.T.contiguous().view(-1)
            # Flat offset of (expert e, token t) is e * num_tokens + t.
            expert_ids = torch.arange(
                num_experts, device=routing_map.device).unsqueeze(-1)
            token_ids = sorted_indices.view(num_experts, capacity)
            flat_ids = (expert_ids * num_unpermuted_tokens +
                        token_ids).view(-1)
            permuted_probs = flat_probs.index_select(0, flat_ids)
        # The product may be promoted to the dtype of `probs` when that is
        # higher precision (fp32/fp64), which costs extra GPU memory; the
        # final cast below restores the input dtype.
        permuted_tokens = permuted_tokens * permuted_probs.unsqueeze(-1)

    # Accumulate into zeros so tokens sent to several experts are summed.
    output_tokens = torch.zeros(restore_shape,
                                dtype=permuted_tokens.dtype,
                                device=permuted_tokens.device)
    row_targets = sorted_indices.unsqueeze(1).expand(-1, hidden)
    output_tokens.scatter_add_(0, row_targets, permuted_tokens)
    return output_tokens.to(dtype=input_dtype)
def all_to_all(group, input, output_split_sizes, input_split_sizes):
    """Exchange rows of ``input`` between all ranks of ``group``.

    Args:
        group: Process group to communicate over.
        input: Tensor whose dim 0 is split across the ranks of ``group``.
        output_split_sizes: Number of rows received from each rank, or None
            for an equal split.
        input_split_sizes: Number of rows sent to each rank, or None for an
            equal split.

    Returns:
        ``input`` itself when the group has a single rank; otherwise a new
        tensor holding the received rows, with the same dtype and device as
        ``input``.
    """
    world_size = torch.distributed.get_world_size(group=group)
    # Bypass the function if we are using only 1 GPU.
    if world_size == 1:
        return input

    input = input.contiguous()
    if output_split_sizes is None:
        # Equal split (all2all)
        output = torch.empty_like(input)
    else:
        # Unequal split (all2all-v). `new_empty` inherits the device of
        # `input`, so the receive buffer always lives next to the send buffer
        # instead of on whatever CUDA device happens to be current.
        output = input.new_empty(
            size=[sum(output_split_sizes)] + list(input.size()[1:]),
            dtype=input.dtype,
        )
    torch.distributed.all_to_all_single(
        output,
        input,
        output_split_sizes=output_split_sizes,
        input_split_sizes=input_split_sizes,
        group=group,
    )
    return output
vllm/model_executor/layers/fused_moe/ep_moe/kernels.py
0 → 100644
View file @
dbd0bda6
This diff is collapsed.
Click to expand it.
vllm/model_executor/layers/fused_moe/ep_moe/layer.py
0 → 100644
View file @
dbd0bda6
import
logging
from
typing
import
Callable
,
List
,
Optional
,
Tuple
from
dataclasses
import
dataclass
import
torch
from
torch
import
nn
import
torch.nn.functional
as
F
from
vllm.logger
import
init_logger
from
vllm.model_executor.layers.quantization.base_config
import
(
QuantizationConfig
,
QuantizeMethodBase
)
from
vllm.model_executor.layers.fused_moe
import
FusedMoE
from
vllm.model_executor.layers.fused_moe.ep_moe.token_dispatcher
import
MoEAlltoAllTokenDispatcher
from
vllm.model_executor.layers.fused_moe.ep_moe.ep_moe_utlis
import
EPSharedExperts
,
EpMoeConfig
from
vllm.model_executor.layers.fused_moe.ep_moe.kernels
import
grouped_gemm_triton
logger
=
init_logger
(
__name__
)
class EPMoE(FusedMoE):
    """
    dp+ep MoE Expert Parallel Impl

    Extends ``FusedMoE`` with an all-to-all token dispatcher and a Triton
    grouped GEMM over the experts owned by this rank. ``forward`` selects
    experts, dispatches tokens, runs the two grouped GEMMs with the
    activation in between, and un-dispatches the result.
    """

    def __init__(
            self,
            num_experts: int,  # Global number of experts
            top_k: int,
            hidden_size: int,
            intermediate_size: int,
            params_dtype: Optional[torch.dtype] = None,
            reduce_results: bool = False,
            renormalize: bool = True,
            use_grouped_topk: bool = False,
            num_expert_group: Optional[int] = None,
            topk_group: Optional[int] = None,
            quant_config: Optional[QuantizationConfig] = None,
            tp_size: Optional[int] = None,
            ep_size: Optional[int] = None,
            dp_size: Optional[int] = None,
            prefix: str = "",
            custom_routing_function: Optional[Callable] = None,
            scoring_func: str = "softmax",
            e_score_correction_bias: Optional[torch.Tensor] = None,
            apply_router_weight_on_input: bool = False,
            activation: str = "silu",
            routed_scaling_factor: Optional[float] = None,
            moe_permute_fusion: bool = False,
            moe_shared_expert_overlap: bool = False):
        """Initialise the base layer, the dispatcher and quantization flags.

        All arguments up to ``routed_scaling_factor`` are forwarded to
        ``FusedMoE.__init__``. ``moe_permute_fusion`` and
        ``moe_shared_expert_overlap`` are stored in the ``EpMoeConfig`` given
        to the token dispatcher.
        """
        super().__init__(num_experts,
                         top_k,
                         hidden_size,
                         intermediate_size,
                         params_dtype,
                         reduce_results,
                         renormalize,
                         use_grouped_topk,
                         num_expert_group,
                         topk_group,
                         quant_config,
                         tp_size,
                         ep_size,
                         dp_size,
                         prefix,
                         custom_routing_function,
                         scoring_func,
                         e_score_correction_bias,
                         apply_router_weight_on_input,
                         activation,
                         routed_scaling_factor=routed_scaling_factor)
        self.ep_moe_config: EpMoeConfig = EpMoeConfig.make(
            moe_router_topk=self.top_k,
            # TODO: support fusion permute
            moe_permute_fusion=moe_permute_fusion,
            moe_shared_expert_overlap=moe_shared_expert_overlap,
            ep_size=self.ep_size,
            num_moe_experts=self.global_num_experts)

        # Global ids of the experts held by this rank: a contiguous range
        # starting at ep_rank * local_num_experts.
        local_expert_indices_offset = (self.ep_rank * self.local_num_experts)
        self.local_expert_indices = [
            local_expert_indices_offset + i
            for i in range(self.local_num_experts)
        ]
        self.shared_experts = None
        self.use_shared_expert = False
        self.token_dispatcher = MoEAlltoAllTokenDispatcher(
            self.local_num_experts,
            self.local_expert_indices,
            config=self.ep_moe_config)
        self.shared_expert_overlap = moe_shared_expert_overlap
        # Allocated lazily in forward(), once the device is known.
        self.seg_indptr = None

        if quant_config is None:
            self.use_fp8_w8a8 = False
            self.use_block_quant = False
            self.block_shape = None
            self.activation_scheme = None
            self.w13_weight_scale = None
            self.w2_weight_scale = None
        else:
            # NOTE(review): any non-None quant_config is treated as FP8
            # w8a8 and is expected to expose `activation_scheme` -- confirm
            # this holds for every quantization config used with this layer.
            self.use_fp8_w8a8 = True
            self.use_block_quant = getattr(self.quant_method, "block_quant",
                                           False)
            self.block_shape = (
                self.quant_method.quant_config.weight_block_size
                if self.use_block_quant else None)
            self.fp8_dtype = torch.float8_e4m3fn
            self.activation_scheme = quant_config.activation_scheme

    def set_shared_experts(self, shared_experts):
        """Attach (or detach, with None) the shared-experts module.

        When shared-expert overlap is enabled the module is also handed to
        the token dispatcher.
        """
        self.shared_experts = shared_experts
        self.use_shared_expert = shared_experts is not None
        if self.shared_expert_overlap:
            self.token_dispatcher.set_shared_experts(shared_experts)

    def triton_grouped_gemm_impl(self, hidden_states, tokens_per_expert,
                                 use_nn_moe):
        """Run gate/up GEMM, activation and down GEMM for the local experts.

        Args:
            hidden_states: Dispatched tokens, grouped by local expert.
            tokens_per_expert: Number of rows belonging to each local expert;
                its cumulative sum is written into ``self.seg_indptr[1:]``.
            use_nn_moe: Accepted but not used by this implementation.

        Returns:
            The output of the second grouped GEMM.

        Raises:
            ValueError: If ``self.activation`` is neither ``"silu"`` nor
                ``"gelu"``.
        """
        # seg_indptr[0] stays 0, so entries e and e+1 bound expert e's rows.
        torch.cumsum(tokens_per_expert, dim=0, out=self.seg_indptr[1:])
        _, N, _ = self.w13_weight.shape
        gateup_input = hidden_states
        weight_indices_cur_rank = torch.arange(
            0,
            self.local_num_experts,
            device=hidden_states.device,
            dtype=torch.int64,
        )

        # GroupGemm-0
        gateup_output = torch.empty(
            gateup_input.shape[0],
            self.w13_weight.shape[1],
            device=hidden_states.device,
            dtype=hidden_states.dtype,
        )
        gateup_output = grouped_gemm_triton(
            a=gateup_input,
            b=self.w13_weight,
            c=gateup_output,
            batch_size=self.local_num_experts,
            weight_column_major=True,
            seg_indptr=self.seg_indptr,
            weight_indices=weight_indices_cur_rank,
            use_fp8_w8a8=self.use_fp8_w8a8,
            scale_a=self.w13_input_scale
            if self.quant_config is not None else None,
            scale_b=(self.w13_weight_scale_inv
                     if self.use_block_quant else self.w13_weight_scale)
            if self.quant_config is not None else None,
            block_shape=self.block_shape,
        )

        # Act
        # The activation halves the feature dimension; its output is
        # allocated in FP8 only for per-tensor (non-block) FP8 quantization.
        down_input = torch.empty(
            gateup_output.shape[0],
            gateup_output.shape[1] // 2,
            device=gateup_output.device,
            dtype=(self.fp8_dtype if
                   (self.use_fp8_w8a8 and not self.use_block_quant) else
                   hidden_states.dtype),
        )
        if (self.quant_config is not None and self.w2_input_scale is None
                and not self.use_block_quant):
            # No input scale is available: fall back to a scale of 1 per
            # local expert.
            self.w2_input_scale = torch.ones(
                self.local_num_experts,
                dtype=torch.float32,
                device=hidden_states.device,
            )

        if self.activation == "silu":
            torch.ops._C.silu_and_mul(down_input, gateup_output.view(-1, N))
        elif self.activation == "gelu":
            torch.ops._C.gelu_and_mul(down_input, gateup_output.view(-1, N))
        else:
            raise ValueError(
                f"Unsupported FusedMoe activation: {self.activation}")

        # GroupGemm-1
        down_output = torch.empty(
            down_input.shape[0],
            self.w2_weight.shape[1],
            device=hidden_states.device,
            dtype=hidden_states.dtype,
        )
        down_output = grouped_gemm_triton(
            a=down_input,
            b=self.w2_weight,
            c=down_output,
            batch_size=self.local_num_experts,
            weight_column_major=True,
            seg_indptr=self.seg_indptr,
            weight_indices=weight_indices_cur_rank,
            use_fp8_w8a8=self.use_fp8_w8a8,
            scale_a=self.w2_input_scale
            if self.quant_config is not None else None,
            scale_b=(self.w2_weight_scale_inv
                     if self.use_block_quant else self.w2_weight_scale)
            if self.quant_config is not None else None,
            block_shape=self.block_shape,
        )
        return down_output

    def forward(self, hidden_states: torch.Tensor,
                router_logits: torch.Tensor):
        """Route, dispatch, compute and combine the MoE output.

        Args:
            hidden_states: Input tokens.
            router_logits: Router scores, one column per expert.

        Returns:
            The un-dispatched expert output, plus the shared-expert output
            when shared experts are attached and overlap is disabled.

        Raises:
            ValueError: In training mode when tensor parallelism is enabled
                without sequence parallelism.
        """
        # NOTE(review): `self.config` is not assigned anywhere in this class;
        # presumably it comes from the base class -- confirm, otherwise this
        # check raises AttributeError whenever the module is in training mode.
        if (self.training and self.config.tensor_model_parallel_size > 1
                and not self.config.sequence_parallel):
            raise ValueError(
                "During training, performance may degrade if MoE and tensor parallelism "
                "are enabled without also enabling sequence parallelism.")

        if self.seg_indptr is None:
            self.seg_indptr = torch.zeros(self.local_num_experts + 1,
                                          device=hidden_states.device,
                                          dtype=torch.int64)

        # process MoE
        topk_weights, topk_ids = self.select_experts(
            hidden_states=hidden_states,
            router_logits=router_logits,
            use_grouped_topk=self.use_grouped_topk,
            top_k=self.top_k,
            renormalize=self.renormalize,
            topk_group=self.topk_group,
            num_expert_group=self.num_expert_group,
            custom_routing_function=self.custom_routing_function,
            scoring_func=self.scoring_func,
            e_score_correction_bias=self.e_score_correction_bias,
            indices_type=torch.int64,
            routed_scaling_factor=self.routed_scaling_factor,
            use_fused_gate=self.use_fused_gate)

        # Expand the top-k selection into dense [tokens, experts] tensors:
        # `probs` holds the routing weight of each selected pair and
        # `routing_map` marks which pairs were selected.
        probs = torch.zeros_like(router_logits,
                                 dtype=topk_weights.dtype).scatter(
                                     1, topk_ids, topk_weights)
        routing_map = torch.zeros_like(router_logits).int().scatter(
            1, topk_ids, 1).bool()

        (dispatched_input,
         tokens_per_expert) = self.token_dispatcher.token_permutation(
             hidden_states, probs, routing_map)
        expert_output = self.triton_grouped_gemm_impl(dispatched_input,
                                                      tokens_per_expert,
                                                      self.use_nn_moe)
        output = self.token_dispatcher.token_unpermutation(expert_output)

        if self.use_shared_expert and not self.shared_expert_overlap:
            # if shared_expert_overlap is True, the expert calculation happens in
            # the token_dispatcher to overlap communications and computations
            output = output + self.shared_experts(hidden_states)
        return output
\ No newline at end of file
vllm/model_executor/layers/fused_moe/ep_moe/token_dispatcher.py
0 → 100644
View file @
dbd0bda6
This diff is collapsed.
Click to expand it.
vllm/model_executor/models/deepseek_v2.py
View file @
dbd0bda6
...
@@ -39,10 +39,12 @@ from vllm.attention import Attention
...
@@ -39,10 +39,12 @@ from vllm.attention import Attention
from
vllm.compilation.decorators
import
support_torch_compile
from
vllm.compilation.decorators
import
support_torch_compile
from
vllm.config
import
(
CacheConfig
,
ModelConfig
,
VllmConfig
,
from
vllm.config
import
(
CacheConfig
,
ModelConfig
,
VllmConfig
,
get_current_vllm_config
)
get_current_vllm_config
)
from
vllm.distributed
import
(
get_ep_group
,
get_pp_group
,
from
vllm.distributed
import
(
get_ep_group
,
get_pp_group
,
get_dp_group
,
get_tensor_model_parallel_world_size
)
get_tensor_model_parallel_world_size
)
from
vllm.model_executor.layers.activation
import
SiluAndMul
from
vllm.model_executor.layers.activation
import
SiluAndMul
from
vllm.model_executor.layers.fused_moe
import
FusedMoE
from
vllm.model_executor.layers.fused_moe
import
FusedMoE
from
vllm.model_executor.layers.fused_moe.ep_moe.layer
import
EPMoE
from
vllm.model_executor.layers.fused_moe.ep_moe.ep_moe_utlis
import
EPSharedExperts
from
vllm.model_executor.layers.layernorm
import
RMSNorm
from
vllm.model_executor.layers.layernorm
import
RMSNorm
from
vllm.model_executor.layers.linear
import
(
ColumnParallelLinear
,
from
vllm.model_executor.layers.linear
import
(
ColumnParallelLinear
,
MergedColumnParallelLinear
,
MergedColumnParallelLinear
,
...
@@ -151,45 +153,71 @@ class DeepseekV2MoE(nn.Module):
...
@@ -151,45 +153,71 @@ class DeepseekV2MoE(nn.Module):
self
.
n_local_physical_experts
)
self
.
n_local_physical_experts
)
self
.
physical_expert_end
=
(
self
.
physical_expert_start
+
self
.
physical_expert_end
=
(
self
.
physical_expert_start
+
self
.
n_local_physical_experts
)
self
.
n_local_physical_experts
)
self
.
experts
=
FusedMoE
(
dp_size
=
get_dp_group
().
world_size
num_experts
=
config
.
n_routed_experts
,
self
.
use_ep_opt
=
dp_size
>
1
and
parallel_config
.
enable_expert_parallel
top_k
=
config
.
num_experts_per_tok
,
self
.
shared_experts
=
None
hidden_size
=
config
.
hidden_size
,
intermediate_size
=
config
.
moe_intermediate_size
,
reduce_results
=
False
,
renormalize
=
config
.
norm_topk_prob
,
quant_config
=
quant_config
,
use_grouped_topk
=
True
,
num_expert_group
=
config
.
n_group
,
topk_group
=
config
.
topk_group
,
prefix
=
f
"
{
prefix
}
.experts"
,
scoring_func
=
config
.
scoring_func
,
e_score_correction_bias
=
self
.
gate
.
e_score_correction_bias
,
enable_eplb
=
self
.
enable_eplb
,
num_redundant_experts
=
self
.
n_redundant_experts
,
routed_scaling_factor
=
self
.
routed_scaling_factor
)
if
config
.
n_shared_experts
is
not
None
:
if
config
.
n_shared_experts
is
not
None
:
intermediate_size
=
(
config
.
moe_intermediate_size
*
intermediate_size
=
(
config
.
moe_intermediate_size
*
config
.
n_shared_experts
)
config
.
n_shared_experts
)
self
.
shared_experts
=
DeepseekV2MLP
(
shared_expert_cls
=
DeepseekV2MLP
if
not
self
.
use_ep_opt
else
EPSharedExperts
self
.
shared_experts
=
shared_expert_cls
(
hidden_size
=
config
.
hidden_size
,
hidden_size
=
config
.
hidden_size
,
intermediate_size
=
intermediate_size
,
intermediate_size
=
intermediate_size
,
hidden_act
=
config
.
hidden_act
,
hidden_act
=
config
.
hidden_act
,
quant_config
=
quant_config
,
quant_config
=
quant_config
,
reduce_results
=
self
.
experts
.
must_reduce_shared_expert_outputs
(
reduce_results
=
False
,
),
prefix
=
f
"
{
prefix
}
.shared_experts"
,
prefix
=
f
"
{
prefix
}
.shared_experts"
,
)
)
if
not
self
.
use_ep_opt
:
self
.
experts
=
FusedMoE
(
num_experts
=
config
.
n_routed_experts
,
top_k
=
config
.
num_experts_per_tok
,
hidden_size
=
config
.
hidden_size
,
intermediate_size
=
config
.
moe_intermediate_size
,
reduce_results
=
False
,
renormalize
=
config
.
norm_topk_prob
,
quant_config
=
quant_config
,
use_grouped_topk
=
True
,
num_expert_group
=
config
.
n_group
,
topk_group
=
config
.
topk_group
,
prefix
=
f
"
{
prefix
}
.experts"
,
scoring_func
=
config
.
scoring_func
,
e_score_correction_bias
=
self
.
gate
.
e_score_correction_bias
,
enable_eplb
=
self
.
enable_eplb
,
num_redundant_experts
=
self
.
n_redundant_experts
,
routed_scaling_factor
=
self
.
routed_scaling_factor
)
else
:
self
.
experts
=
EPMoE
(
num_experts
=
config
.
n_routed_experts
,
top_k
=
config
.
num_experts_per_tok
,
hidden_size
=
config
.
hidden_size
,
intermediate_size
=
config
.
moe_intermediate_size
,
reduce_results
=
False
,
renormalize
=
config
.
norm_topk_prob
,
quant_config
=
quant_config
,
use_grouped_topk
=
True
,
num_expert_group
=
config
.
n_group
,
topk_group
=
config
.
topk_group
,
prefix
=
f
"
{
prefix
}
.experts"
,
scoring_func
=
config
.
scoring_func
,
e_score_correction_bias
=
self
.
gate
.
e_score_correction_bias
,
routed_scaling_factor
=
self
.
routed_scaling_factor
)
if
self
.
use_ep_opt
:
self
.
experts
.
set_shared_experts
(
self
.
shared_experts
)
from
vllm.two_batch_overlap.two_batch_overlap
import
tbo_all_reduce
from
vllm.two_batch_overlap.two_batch_overlap
import
tbo_all_reduce
self
.
tbo_all_reduce
=
tbo_all_reduce
self
.
tbo_all_reduce
=
tbo_all_reduce
def
forward
(
self
,
hidden_states
:
torch
.
Tensor
)
->
torch
.
Tensor
:
def
forward
(
self
,
hidden_states
:
torch
.
Tensor
)
->
torch
.
Tensor
:
num_tokens
,
hidden_dim
=
hidden_states
.
shape
num_tokens
,
hidden_dim
=
hidden_states
.
shape
hidden_states
=
hidden_states
.
view
(
-
1
,
hidden_dim
)
hidden_states
=
hidden_states
.
view
(
-
1
,
hidden_dim
)
if
self
.
n_shared_experts
is
not
None
:
if
not
self
.
use_ep_opt
:
shared_output
=
self
.
shared_experts
(
hidden_states
)
if
self
.
n_shared_experts
is
not
None
:
shared_output
=
self
.
shared_experts
(
hidden_states
)
# router_logits: (num_tokens, n_experts)
# router_logits: (num_tokens, n_experts)
router_logits
,
_
=
self
.
gate
(
hidden_states
)
router_logits
,
_
=
self
.
gate
(
hidden_states
)
...
@@ -203,22 +231,23 @@ class DeepseekV2MoE(nn.Module):
...
@@ -203,22 +231,23 @@ class DeepseekV2MoE(nn.Module):
final_hidden_states
=
self
.
experts
(
hidden_states
=
hidden_states
,
final_hidden_states
=
self
.
experts
(
hidden_states
=
hidden_states
,
router_logits
=
router_logits
)
router_logits
=
router_logits
)
if
shared_output
is
not
None
:
if
not
self
.
use_ep_opt
:
if
hidden_states
.
dtype
!=
torch
.
float16
or
self
.
dpsk_fp16_quick
:
if
shared_output
is
not
None
:
final_hidden_states
=
final_hidden_states
+
shared_output
if
hidden_states
.
dtype
!=
torch
.
float16
or
self
.
dpsk_fp16_quick
:
else
:
final_hidden_states
=
final_hidden_states
+
shared_output
# Fix FP16 overflow
else
:
# See DeepseekV2DecoderLayer for more details.
# Fix FP16 overflow
final_hidden_states
=
final_hidden_states
+
shared_output
\
# See DeepseekV2DecoderLayer for more details.
*
(
1.
/
self
.
routed_scaling_factor
)
final_hidden_states
=
final_hidden_states
+
shared_output
\
*
(
1.
/
self
.
routed_scaling_factor
)
if
self
.
tp_size
>
1
:
if
envs
.
VLLM_ENABLE_TBO
:
if
self
.
tp_size
>
1
:
final_hidden_states
=
self
.
tbo_all_reduce
(
final_hidden_states
)
if
envs
.
VLLM_ENABLE_TBO
:
else
:
final_hidden_states
=
self
.
tbo_all_reduce
(
final_hidden_states
)
final_hidden_states
=
(
else
:
self
.
experts
.
maybe_all_reduce_tensor_model_parallel
(
final_hidden_states
=
(
final_hidden_states
))
self
.
experts
.
maybe_all_reduce_tensor_model_parallel
(
final_hidden_states
))
return
final_hidden_states
.
view
(
num_tokens
,
hidden_dim
)
return
final_hidden_states
.
view
(
num_tokens
,
hidden_dim
)
...
@@ -619,6 +648,8 @@ class DeepseekV2DecoderLayer(nn.Module):
...
@@ -619,6 +648,8 @@ class DeepseekV2DecoderLayer(nn.Module):
hidden_states
=
hidden_states
,
hidden_states
=
hidden_states
,
)
)
#ops.print_tensor(hidden_states)
if
hidden_states
.
dtype
==
torch
.
float16
and
not
self
.
dpsk_fp16_quick
:
if
hidden_states
.
dtype
==
torch
.
float16
and
not
self
.
dpsk_fp16_quick
:
# Fix FP16 overflow
# Fix FP16 overflow
# We scale both hidden_states and residual before
# We scale both hidden_states and residual before
...
@@ -714,7 +745,9 @@ class DeepseekV2Model(nn.Module):
...
@@ -714,7 +745,9 @@ class DeepseekV2Model(nn.Module):
residual
=
intermediate_tensors
[
"residual"
]
residual
=
intermediate_tensors
[
"residual"
]
for
layer
in
self
.
layers
[
self
.
start_layer
:
self
.
end_layer
]:
for
layer
in
self
.
layers
[
self
.
start_layer
:
self
.
end_layer
]:
hidden_states
,
residual
=
layer
(
positions
,
hidden_states
,
residual
)
hidden_states
,
residual
=
layer
(
positions
,
hidden_states
,
residual
)
\
#ops.print_tensor(hidden_states)
if
not
get_pp_group
().
is_last_rank
:
if
not
get_pp_group
().
is_last_rank
:
return
IntermediateTensors
({
return
IntermediateTensors
({
...
...
vllm/v1/engine/utils.py
View file @
dbd0bda6
...
@@ -244,11 +244,18 @@ class CoreEngineActorManager:
...
@@ -244,11 +244,18 @@ class CoreEngineActorManager:
local_engine_count
=
\
local_engine_count
=
\
vllm_config
.
parallel_config
.
data_parallel_size_local
vllm_config
.
parallel_config
.
data_parallel_size_local
nodes
=
sorted
(
list_nodes
(),
# nodes = sorted(list_nodes(),
key
=
lambda
node
:
node
.
node_ip
!=
dp_master_ip
)
# key=lambda node: node.node_ip != dp_master_ip)
assert
nodes
[
0
].
node_ip
==
dp_master_ip
,
(
# assert nodes[0].node_ip == dp_master_ip, (
# "The first node must be the head node")
# assert len(nodes) == 1 or nodes[1].node_ip != dp_master_ip, (
# "There can only be one head node")
nodes
=
ray
.
nodes
()
nodes
=
sorted
(
nodes
,
key
=
lambda
node
:
node
[
"NodeManagerAddress"
]
!=
dp_master_ip
)
assert
nodes
[
0
][
"NodeManagerAddress"
]
==
dp_master_ip
,
(
"The first node must be the head node"
)
"The first node must be the head node"
)
assert
len
(
nodes
)
==
1
or
nodes
[
1
]
.
node_ip
!=
dp_master_ip
,
(
assert
len
(
nodes
)
==
1
or
nodes
[
1
]
[
"NodeManagerAddress"
]
!=
dp_master_ip
,
(
"There can only be one head node"
)
"There can only be one head node"
)
available_resources
=
available_resources_per_node
()
available_resources
=
available_resources_per_node
()
...
@@ -257,8 +264,11 @@ class CoreEngineActorManager:
...
@@ -257,8 +264,11 @@ class CoreEngineActorManager:
local_dp_ranks
:
list
[
int
]
=
[]
local_dp_ranks
:
list
[
int
]
=
[]
for
node
in
nodes
:
for
node
in
nodes
:
node_ip
=
node
.
node_ip
# node_ip = node.node_ip
node_resources
=
available_resources
[
node
.
node_id
]
# node_resources = available_resources[node.node_id]
node_ip
=
node
[
"NodeManagerAddress"
]
node_resources
=
available_resources
[
node
[
"NodeID"
]]
# For now, each DP rank can only be assigned to one node
# For now, each DP rank can only be assigned to one node
# TODO(rui): support allocating a single DP rank
# TODO(rui): support allocating a single DP rank
# to multiple nodes
# to multiple nodes
...
@@ -428,6 +438,9 @@ def launch_core_engines(
...
@@ -428,6 +438,9 @@ def launch_core_engines(
else
:
else
:
local_engine_manager
=
None
local_engine_manager
=
None
import
torch
torch
.
cuda
.
synchronize
()
logger
.
info
((
"launch_core_engines end==============================="
))
yield
local_engine_manager
,
coordinator
,
addresses
yield
local_engine_manager
,
coordinator
,
addresses
# Now wait for engines to start.
# Now wait for engines to start.
...
@@ -440,6 +453,8 @@ def launch_core_engines(
...
@@ -440,6 +453,8 @@ def launch_core_engines(
local_engine_manager
,
local_engine_manager
,
coordinator
.
proc
if
coordinator
else
None
,
coordinator
.
proc
if
coordinator
else
None
,
)
)
torch
.
cuda
.
synchronize
()
logger
.
info
((
"engine startup==============================="
))
def
wait_for_engine_startup
(
def
wait_for_engine_startup
(
...
...
vllm/v1/worker/gpu_model_runner.py
View file @
dbd0bda6
...
@@ -2051,6 +2051,8 @@ class GPUModelRunner(LoRAModelRunnerMixin):
...
@@ -2051,6 +2051,8 @@ class GPUModelRunner(LoRAModelRunnerMixin):
input_ids
=
None
input_ids
=
None
inputs_embeds
=
self
.
inputs_embeds
[:
num_tokens
]
inputs_embeds
=
self
.
inputs_embeds
[:
num_tokens
]
else
:
else
:
#self.input_ids[:num_tokens] = torch.randint(0, 120000, (num_tokens,), dtype=torch.int32)
self
.
input_ids
[:
num_tokens
]
=
torch
.
arange
(
num_tokens
,
dtype
=
torch
.
int32
,
device
=
self
.
input_ids
.
device
)
input_ids
=
self
.
input_ids
[:
num_tokens
]
input_ids
=
self
.
input_ids
[:
num_tokens
]
inputs_embeds
=
None
inputs_embeds
=
None
if
self
.
uses_mrope
:
if
self
.
uses_mrope
:
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment