Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
e64afa45
Unverified
Commit
e64afa45
authored
Mar 26, 2025
by
youkaichao
Committed by
GitHub
Mar 26, 2025
Browse files
multi-node offline DP+EP example (#15484)
Signed-off-by:
youkaichao
<
youkaichao@gmail.com
>
parent
1711b929
Changes
1
Hide whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
97 additions
and
23 deletions
+97
-23
examples/offline_inference/data_parallel.py
examples/offline_inference/data_parallel.py
+97
-23
No files found.
examples/offline_inference/data_parallel.py
View file @
e64afa45
# SPDX-License-Identifier: Apache-2.0
# usage:
# VLLM_USE_V1=1 python examples/offline_inference/data_parallel.py
# we need to have a launcher to create multiple data parallel
# ranks. And each rank will create a vLLM instance to process its own prompts.
"""
Usage:
Single node:
python examples/offline_inference/data_parallel.py
\
--model="ibm-research/PowerMoE-3b"
\
--dp-size=2
\
--tp-size=2
Multi-node:
Node 0 (assume the node has ip of 10.99.48.128):
python examples/offline_inference/data_parallel.py
\
--model="ibm-research/PowerMoE-3b"
\
--dp-size=2
\
--tp-size=2
\
--node-size=2
\
--node-rank=0
\
--master-addr=10.99.48.128
\
--master-port=13345
Node 1:
python examples/offline_inference/data_parallel.py
\
--model="ibm-research/PowerMoE-3b"
\
--dp-size=2
\
--tp-size=2
\
--node-size=2
\
--node-rank=1
\
--master-addr=10.99.48.128
\
--master-port=13345
"""
import
os
from
vllm
import
LLM
,
SamplingParams
from
vllm.utils
import
get_open_port
GPUs_per_dp_rank
=
2
DP_size
=
2
def
main
(
dp_size
,
dp_rank
,
dp_master_ip
,
dp_master_port
,
GPUs_per_dp_rank
):
os
.
environ
[
"VLLM_DP_RANK"
]
=
str
(
dp_rank
)
def
main
(
model
,
dp_size
,
local_dp_rank
,
global_dp_rank
,
dp_master_ip
,
dp_master_port
,
GPUs_per_dp_rank
):
os
.
environ
[
"VLLM_DP_RANK"
]
=
str
(
global_dp_rank
)
os
.
environ
[
"VLLM_DP_SIZE"
]
=
str
(
dp_size
)
os
.
environ
[
"VLLM_DP_MASTER_IP"
]
=
dp_master_ip
os
.
environ
[
"VLLM_DP_MASTER_PORT"
]
=
str
(
dp_master_port
)
# set devices for each dp_rank
os
.
environ
[
"CUDA_VISIBLE_DEVICES"
]
=
","
.
join
(
str
(
i
)
for
i
in
range
(
dp_rank
*
GPUs_per_dp_rank
,
(
dp_rank
+
1
)
*
GPUs_per_dp_rank
))
str
(
i
)
for
i
in
range
(
local_dp_rank
*
GPUs_per_dp_rank
,
(
local_dp_rank
+
1
)
*
GPUs_per_dp_rank
))
# Sample prompts.
prompts
=
[
...
...
@@ -28,20 +51,20 @@ def main(dp_size, dp_rank, dp_master_ip, dp_master_port, GPUs_per_dp_rank):
"The president of the United States is"
,
"The capital of France is"
,
"The future of AI is"
,
]
]
*
100
# with DP, each rank should process different prompts.
# usually all the DP ranks process a full dataset,
# and each rank processes a different part of the dataset.
promts_per_rank
=
len
(
prompts
)
//
dp_size
start
=
dp_rank
*
promts_per_rank
start
=
global_
dp_rank
*
promts_per_rank
end
=
start
+
promts_per_rank
prompts
=
prompts
[
start
:
end
]
if
len
(
prompts
)
==
0
:
# if any rank has no prompts to process,
# we need to set a placeholder prompt
prompts
=
[
"Placeholder"
]
print
(
f
"DP rank
{
dp_rank
}
needs to process
{
len
(
prompts
)
}
prompts"
)
print
(
f
"DP rank
{
global_
dp_rank
}
needs to process
{
len
(
prompts
)
}
prompts"
)
# Create a sampling params object.
# since we are doing data parallel, every rank can have different
...
...
@@ -49,31 +72,82 @@ def main(dp_size, dp_rank, dp_master_ip, dp_master_port, GPUs_per_dp_rank):
# ranks for demonstration.
sampling_params
=
SamplingParams
(
temperature
=
0.8
,
top_p
=
0.95
,
max_tokens
=
16
*
(
dp_rank
+
1
)
)
max_tokens
=
[
16
,
20
][
global_
dp_rank
%
2
]
)
# Create an LLM.
llm
=
LLM
(
model
=
"ibm-research/PowerMoE-3b"
,
llm
=
LLM
(
model
=
model
,
tensor_parallel_size
=
GPUs_per_dp_rank
,
enforce_eager
=
True
,
enable_expert_parallel
=
True
)
outputs
=
llm
.
generate
(
prompts
,
sampling_params
)
# Print the outputs.
for
output
in
outputs
:
for
i
,
output
in
enumerate
(
outputs
):
if
i
>=
5
:
# print only 5 outputs
break
prompt
=
output
.
prompt
generated_text
=
output
.
outputs
[
0
].
text
print
(
f
"DP rank
{
dp_rank
}
, Prompt:
{
prompt
!
r
}
, "
print
(
f
"DP rank
{
global_
dp_rank
}
, Prompt:
{
prompt
!
r
}
, "
f
"Generated text:
{
generated_text
!
r
}
"
)
if
__name__
==
"__main__"
:
import
argparse
parser
=
argparse
.
ArgumentParser
(
description
=
"Data Parallel Inference"
)
parser
.
add_argument
(
"--model"
,
type
=
str
,
default
=
"ibm-research/PowerMoE-3b"
,
help
=
"Model name or path"
)
parser
.
add_argument
(
"--dp-size"
,
type
=
int
,
default
=
2
,
help
=
"Data parallel size"
)
parser
.
add_argument
(
"--tp-size"
,
type
=
int
,
default
=
2
,
help
=
"Tensor parallel size"
)
parser
.
add_argument
(
"--node-size"
,
type
=
int
,
default
=
1
,
help
=
"Total number of nodes"
)
parser
.
add_argument
(
"--node-rank"
,
type
=
int
,
default
=
0
,
help
=
"Rank of the current node"
)
parser
.
add_argument
(
"--master-addr"
,
type
=
str
,
default
=
""
,
help
=
"Master node IP address"
)
parser
.
add_argument
(
"--master-port"
,
type
=
int
,
default
=
0
,
help
=
"Master node port"
)
args
=
parser
.
parse_args
()
dp_size
=
args
.
dp_size
tp_size
=
args
.
tp_size
node_size
=
args
.
node_size
node_rank
=
args
.
node_rank
if
node_size
==
1
:
dp_master_ip
=
"127.0.0.1"
dp_master_port
=
get_open_port
()
else
:
dp_master_ip
=
args
.
master_addr
dp_master_port
=
args
.
master_port
assert
dp_size
%
node_size
==
0
,
"dp_size should be divisible by node_size"
dp_per_node
=
dp_size
//
node_size
from
multiprocessing
import
Process
dp_master_ip
=
"127.0.0.1"
dp_master_port
=
get_open_port
()
procs
=
[]
for
i
in
range
(
DP_size
):
for
local_dp_rank
,
global_dp_rank
in
enumerate
(
range
(
node_rank
*
dp_per_node
,
(
node_rank
+
1
)
*
dp_per_node
)):
proc
=
Process
(
target
=
main
,
args
=
(
DP_size
,
i
,
dp_master_ip
,
dp_master_port
,
GPUs_per_dp_rank
))
args
=
(
args
.
model
,
dp_size
,
local_dp_rank
,
global_dp_rank
,
dp_master_ip
,
dp_master_port
,
tp_size
))
proc
.
start
()
procs
.
append
(
proc
)
exit_code
=
0
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment