change / sglang / Commits / 5493c334

Commit 5493c334 (Unverified)
Authored Mar 17, 2025 by Lianmin Zheng; committed by GitHub on Mar 17, 2025
Parent: f2ab37e5

Fix data parallel + tensor parallel (#4499)

Showing 6 changed files with 53 additions and 16 deletions (+53 -16)
.github/workflows/pr-test.yml                            +1   -0
python/sglang/srt/layers/dp_attention.py                 +14  -3
python/sglang/srt/managers/data_parallel_controller.py   +30  -8
python/sglang/srt/managers/scheduler.py                  +1   -1
python/sglang/srt/model_executor/cuda_graph_runner.py    +4   -1
test/srt/test_data_parallelism.py                        +3   -3
.github/workflows/pr-test.yml

@@ -290,6 +290,7 @@ jobs:
         python3 test_moe_eval_accuracy_large.py
 
   finish:
     if: always()
     needs: [
       unit-test-frontend, unit-test-backend-1-gpu, unit-test-backend-2-gpu,
       performance-test-1-gpu-part-1, performance-test-1-gpu-part-2, performance-test-2-gpu,
       ...
python/sglang/srt/layers/dp_attention.py

@@ -38,7 +38,12 @@ def compute_dp_attention_world_info(enable_dp_attention, tp_rank, tp_size, dp_si
     return attn_tp_rank, attn_tp_size, dp_rank
 
 
-def initialize_dp_attention(enable_dp_attention, tp_rank, tp_size, dp_size):
+def initialize_dp_attention(
+    enable_dp_attention: bool,
+    tp_rank: int,
+    tp_size: int,
+    dp_size: int,
+):
     global _ATTN_TP_GROUP, _ATTN_TP_RANK, _ATTN_TP_SIZE, _DP_RANK, _DP_SIZE
 
     from sglang.srt.layers.sampler import SYNC_TOKEN_IDS_ACROSS_TP

@@ -46,7 +51,13 @@ def initialize_dp_attention(enable_dp_attention, tp_rank, tp_size, dp_size):
     _ATTN_TP_RANK, _ATTN_TP_SIZE, _DP_RANK = compute_dp_attention_world_info(
         enable_dp_attention, tp_rank, tp_size, dp_size
     )
-    _DP_SIZE = dp_size
+    if enable_dp_attention:
+        local_rank = tp_rank % (tp_size // dp_size)
+        _DP_SIZE = dp_size
+    else:
+        local_rank = tp_rank
+        _DP_SIZE = 1
 
     tp_group = get_tp_group()
     _ATTN_TP_GROUP = GroupCoordinator(

@@ -54,7 +65,7 @@ def initialize_dp_attention(enable_dp_attention, tp_rank, tp_size, dp_size):
             list(range(head, head + _ATTN_TP_SIZE))
             for head in range(0, tp_size, _ATTN_TP_SIZE)
         ],
-        tp_rank,
+        local_rank,
         torch.distributed.get_backend(tp_group.device_group),
         SYNC_TOKEN_IDS_ACROSS_TP,
         False,
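For readers skimming the diff, here is a minimal standalone sketch (plain Python, not sglang internals; `dp_attention_layout` is a name invented for this example) of the rank arithmetic the new branch relies on: each data-parallel replica owns a contiguous block of `tp_size // dp_size` tensor-parallel ranks, and `local_rank = tp_rank % (tp_size // dp_size)` is the rank's position inside its own attention-TP group, which is what `GroupCoordinator` now receives instead of the global `tp_rank`.

# Standalone illustration of the group layout implied by the diff above.
def dp_attention_layout(tp_size: int, dp_size: int):
    attn_tp_size = tp_size // dp_size
    # Same comprehension shape as the GroupCoordinator group_ranks argument.
    group_ranks = [
        list(range(head, head + attn_tp_size))
        for head in range(0, tp_size, attn_tp_size)
    ]
    # Same formula as the new `local_rank` computation, applied to every rank.
    local_ranks = {tp_rank: tp_rank % attn_tp_size for tp_rank in range(tp_size)}
    return group_ranks, local_ranks

if __name__ == "__main__":
    groups, locals_ = dp_attention_layout(tp_size=4, dp_size=2)
    print(groups)   # [[0, 1], [2, 3]]
    print(locals_)  # {0: 0, 1: 1, 2: 0, 3: 1}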
python/sglang/srt/managers/data_parallel_controller.py

@@ -82,10 +82,12 @@ class DataParallelController:
         self.scheduler_procs = []
         self.workers = [None] * server_args.dp_size
 
-        if not server_args.enable_dp_attention:
-            dp_port_args = self.launch_dp_schedulers(server_args, port_args)
-        else:
+        if server_args.enable_dp_attention:
             dp_port_args = self.launch_dp_attention_schedulers(server_args, port_args)
+            self.control_message_step = server_args.tp_size
+        else:
+            dp_port_args = self.launch_dp_schedulers(server_args, port_args)
+            self.control_message_step = 1
 
         # Only node rank 0 runs the real data parallel controller that dispatches the requests.
         if server_args.node_rank == 0:
@@ -105,6 +107,7 @@ class DataParallelController:
         threads = []
         sockets = []
         dp_port_args = []
+        ready_events = []
         for dp_rank in range(server_args.dp_size):
             tmp_port_args = PortArgs.init_new(server_args)
             tmp_port_args.tokenizer_ipc_name = port_args.tokenizer_ipc_name
@@ -115,10 +118,13 @@ class DataParallelController:
             # We hold it first so that the next dp worker gets a different port
             sockets.append(bind_port(tmp_port_args.nccl_port))
 
+            ready_event = threading.Event()
+            ready_events.append(ready_event)
+
             # Create a thread for each worker
             thread = threading.Thread(
-                target=self.launch_tensor_parallel_group,
-                args=(server_args, tmp_port_args, base_gpu_id, dp_rank),
+                target=self.launch_tensor_parallel_group_thread,
+                args=(server_args, tmp_port_args, base_gpu_id, dp_rank, ready_event),
             )
             threads.append(thread)
             base_gpu_id += server_args.tp_size * server_args.gpu_id_step
@@ -130,11 +136,27 @@ class DataParallelController:
         # Start all threads
         for thread in threads:
             thread.start()
-        for thread in threads:
-            thread.join()
+        for event in ready_events:
+            event.wait()
 
         return dp_port_args
 
+    def launch_tensor_parallel_group_thread(
+        self,
+        server_args: ServerArgs,
+        port_args: PortArgs,
+        base_gpu_id: int,
+        dp_rank: int,
+        ready_event: threading.Event,
+    ):
+        self.launch_tensor_parallel_group(server_args, port_args, base_gpu_id, dp_rank)
+        ready_event.set()
+
+        # This thread cannot be closed because otherwise the `kill_itself_when_parent_died`
+        # function in scheduler.py will kill the scheduler.
+        while True:
+            pass
+
     def launch_dp_attention_schedulers(self, server_args, port_args):
         self.launch_tensor_parallel_group(server_args, port_args, 0, None)
         dp_port_args = []
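The hunk above swaps `thread.join()` for an event wait because the new thread target never returns: it parks in `while True: pass` after launching its workers so the parent thread stays alive. A minimal, self-contained sketch of that handshake pattern follows; `start_worker` is a hypothetical stand-in for `launch_tensor_parallel_group`, and the sketch parks on an `Event` instead of a busy loop.

import threading
import time

def start_worker(dp_rank: int) -> None:
    # Hypothetical stand-in for the real launch work; sleeps to mimic startup cost.
    time.sleep(0.1)

def worker_thread(dp_rank: int, ready_event: threading.Event) -> None:
    start_worker(dp_rank)
    ready_event.set()          # signal the parent that this replica is ready
    threading.Event().wait()   # park forever (mirrors the `while True: pass` above)

if __name__ == "__main__":
    dp_size = 2
    ready_events = []
    for dp_rank in range(dp_size):
        event = threading.Event()
        ready_events.append(event)
        threading.Thread(target=worker_thread, args=(dp_rank, event), daemon=True).start()

    # Wait on readiness events instead of joining threads that intentionally never exit.
    for event in ready_events:
        event.wait()
    print("all data-parallel workers reported ready")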
@@ -223,7 +245,7 @@ class DataParallelController:
                     self.dispatching(recv_req)
                 else:
                     # Send other control messages to first worker of tp group
-                    for worker in self.workers[:: self.server_args.tp_size]:
+                    for worker in self.workers[:: self.control_message_step]:
                         worker.send_pyobj(recv_req)
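A short sketch of what `control_message_step` buys in the dispatch loop above: with DP attention enabled the step is `tp_size`, so control messages reach only the first worker of each tensor-parallel group; otherwise the step is 1 and every worker receives them. Plain Python illustration; the list of labels stands in for the per-scheduler ZMQ sockets, and the layout (one socket per TP rank within each DP group) is an assumption suggested by the step choice, not shown in this diff.

# Hedged illustration of the slicing; labels are not sglang objects.
dp_size, tp_size = 2, 2
workers = [f"dp{d}-tp{t}" for d in range(dp_size) for t in range(tp_size)]

enable_dp_attention = True
control_message_step = tp_size if enable_dp_attention else 1

# Control messages go only to the first TP rank of every DP group.
print(workers[::control_message_step])  # ['dp0-tp0', 'dp1-tp0']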
python/sglang/srt/managers/scheduler.py

@@ -1786,7 +1786,7 @@ def run_scheduler_process(
         prefix = f" DP{dp_rank} TP{tp_rank}"
 
     # Config the process
-    # kill_itself_when_parent_died()  # This is disabled because it does not work for `--dp 2`
+    kill_itself_when_parent_died()
     setproctitle.setproctitle(f"sglang::scheduler{prefix.replace(' ', '_')}")
     faulthandler.enable()
     parent_process = psutil.Process().parent()
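Re-enabling `kill_itself_when_parent_died()` is what forces the launcher thread in data_parallel_controller.py to stay alive: a parent-death signal is typically delivered when the thread that spawned the child exits. The helper's body is not part of this diff; below is a hedged, Linux-only illustration of the general pattern it presumably follows, not sglang's actual implementation.

import ctypes
import signal

def kill_self_when_parent_dies() -> None:
    # Illustrative only. Linux-specific: ask the kernel to send SIGKILL to this
    # process when the thread that created it terminates.
    PR_SET_PDEATHSIG = 1
    libc = ctypes.CDLL("libc.so.6", use_errno=True)
    if libc.prctl(PR_SET_PDEATHSIG, signal.SIGKILL, 0, 0, 0) != 0:
        raise OSError(ctypes.get_errno(), "prctl(PR_SET_PDEATHSIG) failed")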
python/sglang/srt/model_executor/cuda_graph_runner.py

@@ -16,6 +16,7 @@
 from __future__ import annotations
 
 import bisect
+import os
 from contextlib import contextmanager
 from typing import TYPE_CHECKING, Callable

@@ -81,7 +82,9 @@ def patch_model(
             # tp_group.ca_comm = None
             yield torch.compile(
                 torch.no_grad()(model.forward),
-                mode="max-autotune-no-cudagraphs",
+                mode=os.environ.get(
+                    "SGLANG_TORCH_COMPILE_MODE", "max-autotune-no-cudagraphs"
+                ),
                 dynamic=False,
             )
         else:
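The change above makes the torch.compile mode overridable at runtime through the SGLANG_TORCH_COMPILE_MODE environment variable. A minimal standalone sketch of the same fallback pattern; `compile_forward` is an illustrative wrapper, not the sglang API.

import os
import torch

def compile_forward(forward_fn):
    # Honor SGLANG_TORCH_COMPILE_MODE if set, otherwise keep the previous default.
    mode = os.environ.get("SGLANG_TORCH_COMPILE_MODE", "max-autotune-no-cudagraphs")
    return torch.compile(torch.no_grad()(forward_fn), mode=mode, dynamic=False)

For example, launching with SGLANG_TORCH_COMPILE_MODE=default in the environment switches torch.compile to its "default" mode without touching the code.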
test/srt/test_data_parallelism.py

@@ -23,7 +23,7 @@ class TestDataParallelism(unittest.TestCase):
             cls.model,
             cls.base_url,
             timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
-            other_args=["--dp", "2"],
+            other_args=["--dp", 2],
         )
 
     @classmethod

@@ -52,7 +52,7 @@ class TestDataParallelism(unittest.TestCase):
         assert response.status_code == 200
 
         # pause a few seconds then send again
-        time.sleep(5)
+        time.sleep(1)
         response = requests.post(
             self.base_url + "/update_weights_from_disk",

@@ -67,7 +67,7 @@ class TestDataParallelism(unittest.TestCase):
         response = requests.get(self.base_url + "/get_server_info")
         assert response.status_code == 200
 
-        time.sleep(5)
+        time.sleep(1)
         response = requests.get(self.base_url + "/get_server_info")
         assert response.status_code == 200