Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
cc2f9b32
Unverified
Commit
cc2f9b32
authored
Mar 06, 2025
by
Tyler Michael Smith
Committed by
GitHub
Mar 06, 2025
Browse files
[Distributed] Add enable_expert_parallel arg (#14305)
Signed-off-by:
Tyler Michael Smith
<
tyler@neuralmagic.com
>
parent
cd579352
Changes
5
Hide whitespace changes
Inline
Side-by-side
Showing
5 changed files
with
27 additions
and
21 deletions
+27
-21
examples/offline_inference/data_parallel.py
examples/offline_inference/data_parallel.py
+3
-3
vllm/config.py
vllm/config.py
+2
-1
vllm/engine/arg_utils.py
vllm/engine/arg_utils.py
+7
-0
vllm/envs.py
vllm/envs.py
+0
-7
vllm/model_executor/layers/fused_moe/layer.py
vllm/model_executor/layers/fused_moe/layer.py
+15
-10
No files found.
examples/offline_inference/data_parallel.py
View file @
cc2f9b32
# SPDX-License-Identifier: Apache-2.0
# usage:
# VLLM_TEST_ENABLE_EP=1 VLLM_USE_V1=1 \
# python examples/offline_inference/data_parallel.py
# VLLM_USE_V1=1 python examples/offline_inference/data_parallel.py
# we need to have a launcher to create multiple data parallel
# ranks. And each rank will create a vLLM instance to process its own prompts.
import
os
...
...
@@ -55,7 +54,8 @@ def main(dp_size, dp_rank, dp_master_ip, dp_master_port, GPUs_per_dp_rank):
# Create an LLM.
llm
=
LLM
(
model
=
"ibm-research/PowerMoE-3b"
,
tensor_parallel_size
=
GPUs_per_dp_rank
,
enforce_eager
=
True
)
enforce_eager
=
True
,
enable_expert_parallel
=
True
)
outputs
=
llm
.
generate
(
prompts
,
sampling_params
)
# Print the outputs.
for
output
in
outputs
:
...
...
vllm/config.py
View file @
cc2f9b32
...
...
@@ -754,7 +754,7 @@ class ModelConfig:
" must be divisible by tensor parallel size "
f
"(
{
tensor_parallel_size
}
)."
)
if
envs
.
VLLM_TEST_ENABLE_EP
:
if
parallel_config
.
enable_expert_parallel
:
self
.
_verify_with_expert_parallelism
()
pipeline_parallel_size
=
parallel_config
.
pipeline_parallel_size
...
...
@@ -1334,6 +1334,7 @@ class ParallelConfig:
# IP of the data parallel master.
data_parallel_master_ip
:
str
=
"127.0.0.1"
data_parallel_master_port
:
int
=
29500
# Port of the data parallel master.
enable_expert_parallel
:
bool
=
False
# Use EP instead of TP for MoE layers.
# Maximum number of multiple batches
# when load model sequentially. To avoid RAM OOM when using tensor
...
...
vllm/engine/arg_utils.py
View file @
cc2f9b32
...
...
@@ -114,6 +114,7 @@ class EngineArgs:
# number of P/D disaggregation (or other disaggregation) workers
pipeline_parallel_size
:
int
=
1
tensor_parallel_size
:
int
=
1
enable_expert_parallel
:
bool
=
False
max_parallel_loading_workers
:
Optional
[
int
]
=
None
block_size
:
Optional
[
int
]
=
None
enable_prefix_caching
:
Optional
[
bool
]
=
None
...
...
@@ -440,6 +441,11 @@ class EngineArgs:
type
=
int
,
default
=
EngineArgs
.
tensor_parallel_size
,
help
=
'Number of tensor parallel replicas.'
)
parser
.
add_argument
(
'--enable-expert-parallel'
,
action
=
'store_true'
,
help
=
'Use expert parallelism instead of tensor parallelism '
'for MoE layers.'
)
parser
.
add_argument
(
'--max-parallel-loading-workers'
,
type
=
int
,
...
...
@@ -1207,6 +1213,7 @@ class EngineArgs:
parallel_config
=
ParallelConfig
(
pipeline_parallel_size
=
self
.
pipeline_parallel_size
,
tensor_parallel_size
=
self
.
tensor_parallel_size
,
enable_expert_parallel
=
self
.
enable_expert_parallel
,
max_parallel_loading_workers
=
self
.
max_parallel_loading_workers
,
disable_custom_all_reduce
=
self
.
disable_custom_all_reduce
,
tokenizer_pool_config
=
TokenizerPoolConfig
.
create_config
(
...
...
vllm/envs.py
View file @
cc2f9b32
...
...
@@ -86,7 +86,6 @@ if TYPE_CHECKING:
VLLM_MLA_PERFORM_MATRIX_ABSORPTION
:
bool
=
True
VLLM_MLA_DISABLE_REQUANTIZATION
:
bool
=
False
VLLM_MLA_CUDA_MEM_ALIGN_KV_CACHE
:
bool
=
True
VLLM_TEST_ENABLE_EP
:
bool
=
False
VLLM_ENABLE_MOE_ALIGN_BLOCK_SIZE_TRITON
:
bool
=
False
VLLM_RAY_PER_WORKER_GPUS
:
float
=
1.0
VLLM_RAY_BUNDLE_INDICES
:
str
=
""
...
...
@@ -579,12 +578,6 @@ environment_variables: dict[str, Callable[[], Any]] = {
lambda
:
bool
(
int
(
os
.
getenv
(
"VLLM_ENABLE_MOE_ALIGN_BLOCK_SIZE_TRITON"
,
"0"
))
),
# If set, vLLM will use the experimental expert parallel implementation on
# the FusedMoE layer, using tensor parallelism size as expert parallelism
# size.
"VLLM_TEST_ENABLE_EP"
:
lambda
:
bool
(
int
(
os
.
getenv
(
"VLLM_TEST_ENABLE_EP"
,
"0"
))),
# Number of GPUs per worker in Ray, if it is set to be a fraction,
# it allows ray to schedule multiple actors on a single GPU,
# so that users can colocate other actors on the same GPUs as vLLM.
...
...
vllm/model_executor/layers/fused_moe/layer.py
View file @
cc2f9b32
...
...
@@ -7,7 +7,6 @@ from typing import Callable, List, Optional, Tuple
import
torch
from
torch.nn.parameter
import
UninitializedParameter
import
vllm.envs
as
envs
from
vllm.config
import
get_current_vllm_config
from
vllm.distributed
import
(
get_dp_group
,
get_tensor_model_parallel_rank
,
get_tensor_model_parallel_world_size
,
...
...
@@ -342,14 +341,6 @@ class FusedMoE(torch.nn.Module):
if
params_dtype
is
None
:
params_dtype
=
torch
.
get_default_dtype
()
# For smuggling this layer into the fused moe custom op
compilation_config
=
get_current_vllm_config
().
compilation_config
if
prefix
in
compilation_config
.
static_forward_context
:
raise
ValueError
(
"Duplicate layer name: {}"
.
format
(
prefix
))
compilation_config
.
static_forward_context
[
prefix
]
=
self
self
.
layer_name
=
prefix
self
.
use_direct_call
=
not
envs
.
VLLM_TEST_ENABLE_EP
# Note: here we guard against accessing the TP and DP groups when
# uninitialized (this happens when testing)
self
.
tp_size
=
(
tp_size
if
tp_size
is
not
None
else
...
...
@@ -361,7 +352,21 @@ class FusedMoE(torch.nn.Module):
if
self
.
dp_size
==
1
else
get_dp_group
().
rank_in_group
)
self
.
global_num_experts
=
num_experts
if
envs
.
VLLM_TEST_ENABLE_EP
:
# Use expert parallelism instead of tensor parallelism?
vllm_config
=
get_current_vllm_config
()
use_ep
=
(
vllm_config
.
parallel_config
.
enable_expert_parallel
and
self
.
tp_size
>
1
)
# For smuggling this layer into the fused moe custom op
self
.
use_direct_call
=
self
.
dp_size
==
1
if
self
.
use_direct_call
:
compilation_config
=
vllm_config
.
compilation_config
if
prefix
in
compilation_config
.
static_forward_context
:
raise
ValueError
(
"Duplicate layer name: {}"
.
format
(
prefix
))
compilation_config
.
static_forward_context
[
prefix
]
=
self
self
.
layer_name
=
prefix
if
use_ep
:
# Set TP size to 1 to adjust for EP and adjust EP size and rank
# for DP attention.
self
.
ep_rank
=
tp_rank
+
self
.
tp_size
*
self
.
dp_rank
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment