Commit 23cc66f7 (unverified)
Add back data parallelism (#1635)
Authored Oct 11, 2024 by Lianmin Zheng; committed via GitHub on Oct 11, 2024
Parent: 5d09ca57
Showing 7 changed files with 228 additions and 39 deletions (+228 −39)
.github/workflows/pr-test.yml (+5 −6)
python/sglang/bench_latency.py (+1 −1)
python/sglang/srt/managers/data_parallel_controller.py (+177 −0)
python/sglang/srt/managers/scheduler.py (+7 −2)
python/sglang/srt/model_executor/model_runner.py (+1 −1)
python/sglang/srt/server.py (+33 −20)
python/sglang/srt/server_args.py (+4 −9)
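
The knobs this change re-enables live in ServerArgs and PortArgs (see the server_args.py diff at the end). A minimal sketch of a data parallel configuration, assuming the field names from that diff and a placeholder model path:

from sglang.srt.server_args import PortArgs, ServerArgs

# Placeholder model path; dp_size, tp_size, and load_balance_method are the
# settings exercised by this commit.
server_args = ServerArgs(
    model_path="placeholder/model",
    dp_size=2,                          # number of data parallel replicas
    tp_size=1,                          # tensor parallel size per replica
    load_balance_method="round_robin",  # "shortest_queue" is declared but not implemented yet
)
port_args = PortArgs.init_new(server_args)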
.github/workflows/pr-test.yml
@@ -255,12 +255,11 @@ jobs:
           python3 test_mla.py
           python3 test_mla_fp8.py
 
-      # Temporarily disabled
-      #- name: Evaluate Data Parallelism Accuracy (TP=2)
-      #  timeout-minutes: 10
-      #  run: |
-      #    cd test/srt
-      #    python3 test_data_parallelism.py
+      - name: Evaluate Data Parallelism Accuracy (DP=2)
+        timeout-minutes: 10
+        run: |
+          cd test/srt
+          python3 test_data_parallelism.py
 
   finish:
     needs: [
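
A rough local equivalent of the re-enabled step, assuming a repository checkout and enough GPUs for DP=2:

import subprocess

# Mirrors the workflow step above: run the data parallelism test from test/srt.
subprocess.run(["python3", "test_data_parallelism.py"], cwd="test/srt", check=True)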
python/sglang/bench_latency.py
@@ -139,7 +139,7 @@ def load_model(server_args, port_args, tp_rank):
         gpu_id=tp_rank,
         tp_rank=tp_rank,
         tp_size=server_args.tp_size,
-        nccl_port=port_args.nccl_ports[0],
+        nccl_port=port_args.nccl_port,
         server_args=server_args,
     )
     rank_print(f"max_total_num_tokens={model_runner.max_total_num_tokens}")
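
This one-line change tracks the PortArgs rename in server_args.py (last diff below): a single nccl_port int replaces the old nccl_ports list. A minimal sketch with a placeholder model path:

from sglang.srt.server_args import PortArgs, ServerArgs

server_args = ServerArgs(model_path="placeholder/model")
port_args = PortArgs.init_new(server_args)
print(port_args.nccl_port)  # a single int after this commit; previously nccl_ports[0]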
python/sglang/srt/managers/data_parallel_controller.py (new file, mode 100644, +177 −0)

"""
Copyright 2023-2024 SGLang Team
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""

"""A controller that dispatches requests to multiple data parallel workers."""

import logging
import multiprocessing as mp
from enum import Enum, auto

import zmq

from sglang.srt.managers.io_struct import (
    TokenizedEmbeddingReqInput,
    TokenizedGenerateReqInput,
    TokenizedRewardReqInput,
)
from sglang.srt.managers.scheduler import run_scheduler_process
from sglang.srt.server_args import PortArgs, ServerArgs
from sglang.srt.utils import (
    configure_logger,
    kill_parent_process,
    suppress_other_loggers,
)
from sglang.utils import get_exception_traceback

logger = logging.getLogger(__name__)


class LoadBalanceMethod(Enum):
    """Load balance method."""

    ROUND_ROBIN = auto()
    SHORTEST_QUEUE = auto()

    @classmethod
    def from_str(cls, method: str):
        method = method.upper()
        try:
            return cls[method]
        except KeyError as exc:
            raise ValueError(f"Invalid load balance method: {method}") from exc


class DataParallelController:
    """A controller that dispatches requests to multiple data parallel workers."""

    def __init__(self, server_args, port_args) -> None:
        # Parse args
        self.server_args = server_args
        self.port_args = port_args
        self.load_balance_method = LoadBalanceMethod.from_str(
            server_args.load_balance_method
        )

        # Init inter-process communication
        self.context = zmq.Context(1 + server_args.dp_size)
        self.recv_from_tokenizer = self.context.socket(zmq.PULL)
        self.recv_from_tokenizer.bind(f"ipc://{port_args.scheduler_input_ipc_name}")

        # Dispatch method
        self.round_robin_counter = 0
        dispatch_lookup = {
            LoadBalanceMethod.ROUND_ROBIN: self.round_robin_scheduler,
            LoadBalanceMethod.SHORTEST_QUEUE: self.shortest_queue_scheduler,
        }
        self.dispatching = dispatch_lookup[self.load_balance_method]

        # Start data parallel workers
        base_gpu_id = 0
        self.workers = []
        for dp_rank in range(server_args.dp_size):
            tmp_port_args = PortArgs.init_new(server_args)
            tmp_port_args.detokenizer_ipc_name = port_args.detokenizer_ipc_name

            send_to = self.launch_tensor_parallel_group(
                server_args,
                tmp_port_args,
                base_gpu_id,
                dp_rank,
            )

            self.workers.append(send_to)
            base_gpu_id += server_args.tp_size

    def launch_tensor_parallel_group(
        self,
        server_args: ServerArgs,
        port_args: PortArgs,
        base_gpu_id: int,
        dp_rank: int,
    ):
        # Launch tensor parallel scheduler processes
        scheduler_procs = []
        scheduler_pipe_readers = []
        tp_size_per_node = server_args.tp_size // server_args.nnodes
        tp_rank_range = range(
            tp_size_per_node * server_args.node_rank,
            tp_size_per_node * (server_args.node_rank + 1),
        )
        for tp_rank in tp_rank_range:
            reader, writer = mp.Pipe(duplex=False)
            gpu_id = base_gpu_id + tp_rank % tp_size_per_node
            proc = mp.Process(
                target=run_scheduler_process,
                args=(server_args, port_args, gpu_id, tp_rank, dp_rank, writer),
            )
            proc.start()
            scheduler_procs.append(proc)
            scheduler_pipe_readers.append(reader)

        send_to = self.context.socket(zmq.PUSH)
        send_to.connect(f"ipc://{port_args.scheduler_input_ipc_name}")

        # Wait for model to finish loading
        for i in range(len(scheduler_pipe_readers)):
            scheduler_pipe_readers[i].recv()

        return send_to

    def round_robin_scheduler(self, req):
        self.workers[self.round_robin_counter].send_pyobj(req)
        self.round_robin_counter = (self.round_robin_counter + 1) % len(self.workers)

    def shortest_queue_scheduler(self, input_requests):
        raise NotImplementedError()

    def event_loop(self):
        while True:
            while True:
                try:
                    recv_req = self.recv_from_tokenizer.recv_pyobj(zmq.NOBLOCK)
                except zmq.ZMQError:
                    break

                if isinstance(
                    recv_req,
                    (
                        TokenizedGenerateReqInput,
                        TokenizedEmbeddingReqInput,
                        TokenizedRewardReqInput,
                    ),
                ):
                    self.dispatching(recv_req)
                else:
                    # Send other control messages to all workers
                    for worker in self.workers:
                        worker.queue.put(recv_req)


def run_data_parallel_controller_process(
    server_args: ServerArgs,
    port_args: PortArgs,
    pipe_writer,
):
    configure_logger(server_args)
    suppress_other_loggers()

    try:
        controller = DataParallelController(server_args, port_args)
        pipe_writer.send("ready")
        controller.event_loop()
    except Exception:
        msg = get_exception_traceback()
        logger.error(msg)
        kill_parent_process()
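
The controller dispatches with a plain round-robin over the dp_size PUSH sockets, and base_gpu_id advances by tp_size per replica, so with dp_size=2 and tp_size=2 the replicas occupy GPUs 0-1 and 2-3. A standalone sketch of the round-robin behavior using mock workers instead of ZMQ sockets (names are illustrative, not from the repo):

class MockWorker:
    """Stands in for a zmq PUSH socket connected to one DP replica."""

    def __init__(self):
        self.received = []

    def send_pyobj(self, req):
        self.received.append(req)


workers = [MockWorker() for _ in range(2)]  # dp_size = 2
round_robin_counter = 0
for i in range(5):
    workers[round_robin_counter].send_pyobj(f"req-{i}")
    round_robin_counter = (round_robin_counter + 1) % len(workers)

print([w.received for w in workers])
# [['req-0', 'req-2', 'req-4'], ['req-1', 'req-3']]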
python/sglang/srt/managers/scheduler.py
@@ -142,7 +142,7 @@ class Scheduler:
             gpu_id=gpu_id,
             tp_rank=tp_rank,
             server_args=server_args,
-            nccl_port=port_args.nccl_ports[0],
+            nccl_port=port_args.nccl_port,
         )
         self.tp_cpu_group = self.tp_worker.model_runner.tp_group.cpu_group
@@ -1042,9 +1042,14 @@ def run_scheduler_process(
     port_args: PortArgs,
     gpu_id: int,
     tp_rank: int,
+    dp_rank: Optional[int],
     pipe_writer,
 ):
-    configure_logger(server_args, prefix=f" TP{tp_rank}")
+    if dp_rank is None:
+        configure_logger(server_args, prefix=f" TP{tp_rank}")
+    else:
+        configure_logger(server_args, prefix=f" DP{dp_rank} TP{tp_rank}")
+
     suppress_other_loggers()
 
     try:
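
The new dp_rank argument only changes the logging prefix in this function; a tiny illustrative helper (not in the repo) that captures the same branch:

from typing import Optional


def log_prefix(dp_rank: Optional[int], tp_rank: int) -> str:
    # Mirrors the if/else added to run_scheduler_process.
    if dp_rank is None:
        return f" TP{tp_rank}"
    return f" DP{dp_rank} TP{tp_rank}"


assert log_prefix(None, 0) == " TP0"
assert log_prefix(1, 3) == " DP1 TP3"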
python/sglang/srt/model_executor/model_runner.py
@@ -141,7 +141,7 @@ class ModelRunner:
         self.init_attention_backend()
 
     def init_torch_distributed(self):
-        logger.info("Init torch distributed begin.")
+        logger.info("Init torch distributed begin.")
         # Init torch distributed
         if self.device == "cuda":
             torch.cuda.set_device(self.gpu_id)
python/sglang/srt/server.py
@@ -44,6 +44,9 @@ from fastapi.responses import JSONResponse, Response, StreamingResponse
 from sglang.lang.backend.runtime_endpoint import RuntimeEndpoint
 from sglang.srt.hf_transformers_utils import get_tokenizer
+from sglang.srt.managers.data_parallel_controller import (
+    run_data_parallel_controller_process,
+)
 from sglang.srt.managers.detokenizer_manager import run_detokenizer_process
 from sglang.srt.managers.io_struct import (
     EmbeddingReqInput,
@@ -337,30 +340,40 @@ def launch_engine(
         server_args.model_path, server_args.tokenizer_path
     )
 
-    # Launch tensor parallel scheduler processes
-    scheduler_procs = []
-    scheduler_pipe_readers = []
-    tp_size_per_node = server_args.tp_size // server_args.nnodes
-    tp_rank_range = range(
-        tp_size_per_node * server_args.node_rank,
-        tp_size_per_node * (server_args.node_rank + 1),
-    )
-    for tp_rank in tp_rank_range:
-        reader, writer = mp.Pipe(duplex=False)
-        gpu_id = tp_rank % tp_size_per_node
-        proc = mp.Process(
-            target=run_scheduler_process,
-            args=(server_args, port_args, gpu_id, tp_rank, writer),
-        )
-        proc.start()
-        scheduler_procs.append(proc)
-        scheduler_pipe_readers.append(reader)
-
-    if server_args.node_rank >= 1:
-        # For other nodes, they do not need to run tokenizer or detokenizer,
-        # so they can just wait here.
-        while True:
-            pass
+    if server_args.dp_size == 1:
+        # Launch tensor parallel scheduler processes
+        scheduler_procs = []
+        scheduler_pipe_readers = []
+        tp_size_per_node = server_args.tp_size // server_args.nnodes
+        tp_rank_range = range(
+            tp_size_per_node * server_args.node_rank,
+            tp_size_per_node * (server_args.node_rank + 1),
+        )
+        for tp_rank in tp_rank_range:
+            reader, writer = mp.Pipe(duplex=False)
+            gpu_id = tp_rank % tp_size_per_node
+            proc = mp.Process(
+                target=run_scheduler_process,
+                args=(server_args, port_args, gpu_id, tp_rank, None, writer),
+            )
+            proc.start()
+            scheduler_procs.append(proc)
+            scheduler_pipe_readers.append(reader)
+
+        if server_args.node_rank >= 1:
+            # For other nodes, they do not need to run tokenizer or detokenizer,
+            # so they can just wait here.
+            while True:
+                pass
+    else:
+        # Launch the data parallel controller
+        reader, writer = mp.Pipe(duplex=False)
+        scheduler_pipe_readers = [reader]
+        proc = mp.Process(
+            target=run_data_parallel_controller_process,
+            args=(server_args, port_args, writer),
+        )
+        proc.start()
 
     # Launch detokenizer process
     detoken_proc = mp.Process(
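
With dp_size > 1, launch_engine now starts a single data parallel controller process instead of the per-rank schedulers and waits on a pipe for its readiness message. A simplified sketch of that branch, assuming a placeholder model path and enough GPUs for dp_size × tp_size:

import multiprocessing as mp

from sglang.srt.managers.data_parallel_controller import (
    run_data_parallel_controller_process,
)
from sglang.srt.server_args import PortArgs, ServerArgs

if __name__ == "__main__":
    server_args = ServerArgs(model_path="placeholder/model", dp_size=2)
    port_args = PortArgs.init_new(server_args)

    reader, writer = mp.Pipe(duplex=False)
    proc = mp.Process(
        target=run_data_parallel_controller_process,
        args=(server_args, port_args, writer),
    )
    proc.start()
    reader.recv()  # the controller sends "ready" once every worker has loaded the model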
python/sglang/srt/server_args.py
@@ -574,7 +574,7 @@ class ServerArgs:
             self.tp_size % self.nnodes == 0
         ), "tp_size must be divisible by number of nodes"
         assert not (
-            self.dp_size > 1 and self.node_rank is not None
+            self.dp_size > 1 and self.nnodes != 1
         ), "multi-node data parallel is not supported"
         assert (
             self.max_loras_per_batch > 0
@@ -583,11 +583,6 @@ class ServerArgs:
             and (self.lora_paths is None or self.disable_radix_cache)
         ), "compatibility of lora and cuda graph and radix attention is in progress"
 
-        assert self.dp_size == 1, (
-            "The support for data parallelism is temporarily disabled during refactor. "
-            "Please use sglang<=0.3.2 or wait for later updates."
-        )
-
         if isinstance(self.lora_paths, list):
             lora_paths = self.lora_paths
             self.lora_paths = {}
@@ -626,8 +621,8 @@ class PortArgs:
     # The ipc filename for detokenizer to receive inputs from scheduler (zmq)
     detokenizer_ipc_name: str
 
-    # The port for nccl initialization for multiple TP groups (torch.dist)
-    nccl_ports: List[int]
+    # The port for nccl initialization (torch.dist)
+    nccl_port: int
 
     @staticmethod
     def init_new(server_args) -> "PortArgs":
@@ -641,7 +636,7 @@ class PortArgs:
             tokenizer_ipc_name=tempfile.NamedTemporaryFile(delete=False).name,
             scheduler_input_ipc_name=tempfile.NamedTemporaryFile(delete=False).name,
             detokenizer_ipc_name=tempfile.NamedTemporaryFile(delete=False).name,
-            nccl_ports=[port],
+            nccl_port=port,
         )
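
The net effect on validation: the blanket dp_size == 1 assert is removed, and the remaining guard only forbids combining data parallelism with multi-node deployment. An illustrative standalone check mirroring the new assert:

def check_dp_args(dp_size: int, nnodes: int) -> None:
    # Mirrors the assert in ServerArgs: DP is allowed again, but single-node only.
    assert not (
        dp_size > 1 and nnodes != 1
    ), "multi-node data parallel is not supported"


check_dp_args(dp_size=2, nnodes=1)    # ok
# check_dp_args(dp_size=2, nnodes=2)  # would raise AssertionError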