Fix data parallel + tensor parallel (#4499)

Commit 5493c334 (unverified) in sglang, parent f2ab37e5.
Authored by Lianmin Zheng on Mar 17, 2025; committed via GitHub on Mar 17, 2025.
Showing 6 changed files with 53 additions and 16 deletions.

.github/workflows/pr-test.yml                            +1  -0
python/sglang/srt/layers/dp_attention.py                 +14 -3
python/sglang/srt/managers/data_parallel_controller.py   +30 -8
python/sglang/srt/managers/scheduler.py                  +1  -1
python/sglang/srt/model_executor/cuda_graph_runner.py    +4  -1
test/srt/test_data_parallelism.py                        +3  -3
.github/workflows/pr-test.yml

@@ -290,6 +290,7 @@ jobs:
          python3 test_moe_eval_accuracy_large.py

  finish:
    if: always()
    needs:
      [
        unit-test-frontend,
        unit-test-backend-1-gpu,
        unit-test-backend-2-gpu,
        performance-test-1-gpu-part-1,
        performance-test-1-gpu-part-2,
        performance-test-2-gpu,
python/sglang/srt/layers/dp_attention.py

@@ -38,7 +38,12 @@ def compute_dp_attention_world_info(enable_dp_attention, tp_rank, tp_size, dp_si
     return attn_tp_rank, attn_tp_size, dp_rank


-def initialize_dp_attention(enable_dp_attention, tp_rank, tp_size, dp_size):
+def initialize_dp_attention(
+    enable_dp_attention: bool,
+    tp_rank: int,
+    tp_size: int,
+    dp_size: int,
+):
     global _ATTN_TP_GROUP, _ATTN_TP_RANK, _ATTN_TP_SIZE, _DP_RANK, _DP_SIZE
     from sglang.srt.layers.sampler import SYNC_TOKEN_IDS_ACROSS_TP

@@ -46,7 +51,13 @@ def initialize_dp_attention(enable_dp_attention, tp_rank, tp_size, dp_size):
     _ATTN_TP_RANK, _ATTN_TP_SIZE, _DP_RANK = compute_dp_attention_world_info(
         enable_dp_attention, tp_rank, tp_size, dp_size
     )
-    _DP_SIZE = dp_size
+    if enable_dp_attention:
+        local_rank = tp_rank % (tp_size // dp_size)
+        _DP_SIZE = dp_size
+    else:
+        local_rank = tp_rank
+        _DP_SIZE = 1

     tp_group = get_tp_group()
     _ATTN_TP_GROUP = GroupCoordinator(

@@ -54,7 +65,7 @@ def initialize_dp_attention(enable_dp_attention, tp_rank, tp_size, dp_size):
             list(range(head, head + _ATTN_TP_SIZE))
             for head in range(0, tp_size, _ATTN_TP_SIZE)
         ],
-        tp_rank,
+        local_rank,
         torch.distributed.get_backend(tp_group.device_group),
         SYNC_TOKEN_IDS_ACROSS_TP,
         False,
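The key change above is that `initialize_dp_attention` now passes a per-group `local_rank` to `GroupCoordinator` instead of the global `tp_rank`. As a rough illustration of that arithmetic, here is a standalone sketch (not sglang code; the helper name `attention_group_layout` is made up) of how the TP ranks split into attention TP groups when DP attention is enabled:

```python
# Illustrative sketch only: how the local_rank computation in this diff
# partitions tp ranks into dp_size attention TP groups.

def attention_group_layout(enable_dp_attention: bool, tp_rank: int, tp_size: int, dp_size: int):
    if enable_dp_attention:
        attn_tp_size = tp_size // dp_size       # ranks per attention TP group
        dp_rank = tp_rank // attn_tp_size       # which DP replica this rank belongs to
        local_rank = tp_rank % attn_tp_size     # position inside that replica's group
    else:
        attn_tp_size = tp_size
        dp_rank = 0
        local_rank = tp_rank                    # single group, local rank == global rank
    return attn_tp_size, dp_rank, local_rank


if __name__ == "__main__":
    # tp_size=4, dp_size=2 -> two attention groups of 2 ranks: [0, 1] and [2, 3]
    for rank in range(4):
        print(rank, attention_group_layout(True, rank, tp_size=4, dp_size=2))
```

For `tp_size=4, dp_size=2` this gives groups `[0, 1]` and `[2, 3]`, which is consistent with the `range(0, tp_size, _ATTN_TP_SIZE)` comprehension in the diff.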
python/sglang/srt/managers/data_parallel_controller.py

@@ -82,10 +82,12 @@ class DataParallelController:
         self.scheduler_procs = []
         self.workers = [None] * server_args.dp_size

-        if not server_args.enable_dp_attention:
-            dp_port_args = self.launch_dp_schedulers(server_args, port_args)
-        else:
+        if server_args.enable_dp_attention:
             dp_port_args = self.launch_dp_attention_schedulers(server_args, port_args)
+            self.control_message_step = server_args.tp_size
+        else:
+            dp_port_args = self.launch_dp_schedulers(server_args, port_args)
+            self.control_message_step = 1

         # Only node rank 0 runs the real data parallel controller that dispatches the requests.
         if server_args.node_rank == 0:

@@ -105,6 +107,7 @@ class DataParallelController:
         threads = []
         sockets = []
         dp_port_args = []
+        ready_events = []
         for dp_rank in range(server_args.dp_size):
             tmp_port_args = PortArgs.init_new(server_args)
             tmp_port_args.tokenizer_ipc_name = port_args.tokenizer_ipc_name

@@ -115,10 +118,13 @@ class DataParallelController:
             # We hold it first so that the next dp worker gets a different port
             sockets.append(bind_port(tmp_port_args.nccl_port))

+            ready_event = threading.Event()
+            ready_events.append(ready_event)
+
             # Create a thread for each worker
             thread = threading.Thread(
-                target=self.launch_tensor_parallel_group,
-                args=(server_args, tmp_port_args, base_gpu_id, dp_rank),
+                target=self.launch_tensor_parallel_group_thread,
+                args=(server_args, tmp_port_args, base_gpu_id, dp_rank, ready_event),
             )
             threads.append(thread)
             base_gpu_id += server_args.tp_size * server_args.gpu_id_step

@@ -130,11 +136,27 @@ class DataParallelController:
         # Start all threads
         for thread in threads:
             thread.start()
-        for thread in threads:
-            thread.join()
+        for event in ready_events:
+            event.wait()

         return dp_port_args

+    def launch_tensor_parallel_group_thread(
+        self,
+        server_args: ServerArgs,
+        port_args: PortArgs,
+        base_gpu_id: int,
+        dp_rank: int,
+        ready_event: threading.Event,
+    ):
+        self.launch_tensor_parallel_group(server_args, port_args, base_gpu_id, dp_rank)
+        ready_event.set()
+
+        # This thread cannot be closed because otherwise the `kill_itself_when_parent_died`
+        # function in scheduler.py will kill the scheduler.
+        while True:
+            pass
+
     def launch_dp_attention_schedulers(self, server_args, port_args):
         self.launch_tensor_parallel_group(server_args, port_args, 0, None)
         dp_port_args = []

@@ -223,7 +245,7 @@ class DataParallelController:
                     self.dispatching(recv_req)
             else:
                 # Send other control messages to first worker of tp group
-                for worker in self.workers[:: self.server_args.tp_size]:
+                for worker in self.workers[:: self.control_message_step]:
                     worker.send_pyobj(recv_req)
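Two patterns in this diff may be worth spelling out: the launcher threads now report readiness through a `threading.Event` and are kept alive instead of being joined (so the scheduler's `kill_itself_when_parent_died` check does not fire), and control messages are fanned out to every `control_message_step`-th worker, where the step is `tp_size` only when DP attention is enabled. The following is a minimal, self-contained sketch of both ideas, not sglang code; `launch_worker_thread` and the toy `workers` list are made up for illustration:

```python
# Sketch of the ready-event hand-off used in this diff: each launcher thread
# sets an Event once its worker is up, then stays alive; the parent waits on
# the events rather than joining the threads.
import threading
import time


def launch_worker_thread(dp_rank: int, ready_event: threading.Event) -> None:
    time.sleep(0.1 * dp_rank)   # stand-in for launching a tensor parallel group
    ready_event.set()           # signal "worker ready" to the parent
    while True:                 # keep the launcher thread alive afterwards
        time.sleep(1)


threads, ready_events = [], []
for dp_rank in range(2):
    ev = threading.Event()
    ready_events.append(ev)
    t = threading.Thread(target=launch_worker_thread, args=(dp_rank, ev), daemon=True)
    threads.append(t)

for t in threads:
    t.start()
for ev in ready_events:
    ev.wait()                   # parent blocks until every worker has reported ready

# Control-message fan-out: with DP attention, step over the flat worker list by
# tp_size so only the first worker of each TP group receives the message.
workers = [f"worker{i}" for i in range(4)]   # toy list, e.g. dp=2 x tp=2
control_message_step = 2                     # tp_size when DP attention is on, else 1
print(workers[::control_message_step])       # -> ['worker0', 'worker2']
```

The daemon flag and `time.sleep` loop are only there so the sketch exits cleanly; the real code spins with `pass` precisely because the thread must never terminate.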
python/sglang/srt/managers/scheduler.py

@@ -1786,7 +1786,7 @@ def run_scheduler_process(
         prefix = f" DP{dp_rank} TP{tp_rank}"

     # Config the process
-    # kill_itself_when_parent_died()  # This is disabled because it does not work for `--dp 2`
+    kill_itself_when_parent_died()
     setproctitle.setproctitle(f"sglang::scheduler{prefix.replace(' ', '_')}")
     faulthandler.enable()
     parent_process = psutil.Process().parent()
python/sglang/srt/model_executor/cuda_graph_runner.py

@@ -16,6 +16,7 @@
 from __future__ import annotations

 import bisect
+import os
 from contextlib import contextmanager
 from typing import TYPE_CHECKING, Callable

@@ -81,7 +82,9 @@ def patch_model(
             # tp_group.ca_comm = None
             yield torch.compile(
                 torch.no_grad()(model.forward),
-                mode="max-autotune-no-cudagraphs",
+                mode=os.environ.get(
+                    "SGLANG_TORCH_COMPILE_MODE", "max-autotune-no-cudagraphs"
+                ),
                 dynamic=False,
             )
         else:
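This diff makes the `torch.compile` mode overridable through the `SGLANG_TORCH_COMPILE_MODE` environment variable while keeping the previous mode as the default. A minimal sketch of the same pattern outside sglang (the `forward` function here is a stand-in, not the model's):

```python
# Sketch of the env-var override added in this diff: read the compile mode
# from SGLANG_TORCH_COMPILE_MODE, falling back to the previous hard-coded mode.
import os

import torch


def forward(x: torch.Tensor) -> torch.Tensor:
    # stand-in for model.forward
    return torch.relu(x) + 1


compile_mode = os.environ.get("SGLANG_TORCH_COMPILE_MODE", "max-autotune-no-cudagraphs")
compiled_forward = torch.compile(torch.no_grad()(forward), mode=compile_mode, dynamic=False)

print(compiled_forward(torch.randn(4)))
```

Running the sketch with, say, `SGLANG_TORCH_COMPILE_MODE=default` would switch to a cheaper compile mode without touching the code, which appears to be the point of the change.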
test/srt/test_data_parallelism.py

@@ -23,7 +23,7 @@ class TestDataParallelism(unittest.TestCase):
             cls.model,
             cls.base_url,
             timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
-            other_args=["--dp", "2"],
+            other_args=["--dp", 2],
         )

     @classmethod

@@ -52,7 +52,7 @@ class TestDataParallelism(unittest.TestCase):
         assert response.status_code == 200

         # pause a few seconds then send again
-        time.sleep(5)
+        time.sleep(1)
         response = requests.post(
             self.base_url + "/update_weights_from_disk",

@@ -67,7 +67,7 @@ class TestDataParallelism(unittest.TestCase):
         response = requests.get(self.base_url + "/get_server_info")
         assert response.status_code == 200

-        time.sleep(5)
+        time.sleep(1)
         response = requests.get(self.base_url + "/get_server_info")
         assert response.status_code == 200