sglang commit 3bc43c68 (unverified)
Fix different device type adjustment in PP (#7760)

Authored Jul 15, 2025 by Qiaolin Yu; committed via GitHub on Jul 15, 2025
Parent: 7498522f

4 changed files, 25 additions and 27 deletions (+25, -27):
python/sglang/srt/distributed/parallel_state.py   +5  -7
python/sglang/srt/managers/scheduler.py           +5  -0
python/sglang/srt/managers/tp_worker.py           +1  -0
python/sglang/srt/utils.py                        +14 -20
python/sglang/srt/distributed/parallel_state.py

@@ -699,14 +699,14 @@ class GroupCoordinator:
         )
         # Serialize object to tensor and get the size as well
-        object_tensor = torch.frombuffer(pickle.dumps(obj), dtype=torch.uint8).cuda(
-            device=torch.cuda.current_device()
+        object_tensor = torch.frombuffer(pickle.dumps(obj), dtype=torch.uint8).to(
+            device=self.device
         )
         size_tensor = torch.tensor(
             [object_tensor.numel()],
             dtype=torch.long,
-            device=torch.cuda.current_device(),
+            device=self.device,
         )
         # Send object size

@@ -731,9 +731,7 @@ class GroupCoordinator:
             src != self.rank_in_group
         ), "Invalid source rank. Source rank is the same as the current rank."
-        size_tensor = torch.empty(
-            1, dtype=torch.long, device=torch.cuda.current_device()
-        )
+        size_tensor = torch.empty(1, dtype=torch.long, device=self.device)
         # Receive object size
         rank_size = torch.distributed.recv(

@@ -744,7 +742,7 @@ class GroupCoordinator:
         object_tensor = torch.empty(  # type: ignore[call-overload]
             size_tensor.item(),  # type: ignore[arg-type]
             dtype=torch.uint8,
-            device=torch.cuda.current_device(),
+            device=self.device,
         )
         rank_object = torch.distributed.recv(
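All three hunks apply one pattern: pickle the object, view the bytes as a uint8 tensor, stage it on the coordinator's own device (self.device) instead of a hard-coded torch.cuda.current_device(), then send the length followed by the payload. A minimal standalone sketch of that pattern, not the sglang implementation (it assumes an already-initialized process group and a caller-supplied device):

import pickle

import torch
import torch.distributed as dist


def send_pyobj(obj, dst: int, device: torch.device, group=None):
    # Pickle -> uint8 tensor, staged on the caller's device (self.device in
    # the patch) rather than on an assumed CUDA device.
    payload = torch.frombuffer(bytearray(pickle.dumps(obj)), dtype=torch.uint8).to(device)
    size = torch.tensor([payload.numel()], dtype=torch.long, device=device)
    dist.send(size, dst=dst, group=group)     # length first
    dist.send(payload, dst=dst, group=group)  # then the pickled bytes


def recv_pyobj(src: int, device: torch.device, group=None):
    size = torch.empty(1, dtype=torch.long, device=device)
    dist.recv(size, src=src, group=group)
    payload = torch.empty(int(size.item()), dtype=torch.uint8, device=device)
    dist.recv(payload, src=src, group=group)
    return pickle.loads(payload.cpu().numpy().tobytes())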
python/sglang/srt/managers/scheduler.py

@@ -962,6 +962,7 @@ class Scheduler(
                     self.world_group.device_group,
                     self.pp_rank * self.tp_size + dp_offset,
                     (self.pp_rank + 1) * self.tp_size + dp_offset,
+                    device=self.device,
                 )
             # send out proxy tensors to the next stage

@@ -1010,6 +1011,7 @@ class Scheduler(
                     self.world_group.device_group,
                     (self.pp_rank - 1) * self.tp_size + dp_offset,
                     self.pp_rank * self.tp_size + dp_offset,
+                    device=self.device,
                 )
         else:
             recv_reqs = None

@@ -1040,6 +1042,7 @@ class Scheduler(
                     self.attn_tp_group.rank,
                     self.attn_tp_cpu_group,
                     src=self.attn_tp_group.ranks[0],
+                    device=self.device,
                 )
             if self.tp_size != 1:
                 control_reqs = broadcast_pyobj(

@@ -1047,6 +1050,7 @@ class Scheduler(
                     self.tp_group.rank,
                     self.tp_cpu_group,
                     src=self.tp_group.ranks[0],
+                    device=self.device,
                 )
                 recv_reqs = work_reqs + control_reqs
         elif self.tp_size != 1:

@@ -1055,6 +1059,7 @@ class Scheduler(
                 self.tp_group.rank,
                 self.tp_cpu_group,
                 src=self.tp_group.ranks[0],
+                device=self.device,
             )
         return recv_reqs
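Aside from the new device= keyword, the arguments in these hunks encode the pipeline-parallel rank layout: stage pp_rank owns global ranks pp_rank * tp_size + dp_offset up to (but not including) (pp_rank + 1) * tp_size + dp_offset, so the previous and next stages are reached by shifting pp_rank by one. A toy illustration with assumed sizes (not values taken from the commit):

# Assumed toy layout: 3 pipeline stages, tp_size = 2, dp_offset = 0.
tp_size, dp_offset = 2, 0
for pp_rank in range(3):
    first = pp_rank * tp_size + dp_offset
    last = (pp_rank + 1) * tp_size + dp_offset  # exclusive upper bound
    print(f"pp stage {pp_rank}: global ranks {list(range(first, last))}")
# pp stage 0: global ranks [0, 1]
# pp stage 1: global ranks [2, 3]
# pp stage 2: global ranks [4, 5]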
python/sglang/srt/managers/tp_worker.py

@@ -144,6 +144,7 @@ class TpModelWorker:
             self.tp_size * self.pp_rank + tp_rank,
             self.world_group.cpu_group,
             src=self.world_group.ranks[0],
+            device=self.device,
         )[0]
         set_random_seed(self.random_seed)
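For comparison, stock PyTorch exposes the same knob: torch.distributed.broadcast_object_list accepts an optional device argument that controls where the serialized bytes are staged, which is the concern this commit addresses in broadcast_pyobj. A small self-contained usage sketch (single-process gloo group so it runs standalone; in sglang the group and device come from the worker itself):

import torch
import torch.distributed as dist

# World of size 1 over gloo, purely so the example runs without a launcher.
dist.init_process_group(
    backend="gloo", init_method="tcp://127.0.0.1:29501", rank=0, world_size=1
)
objs = [{"seed": 42}] if dist.get_rank() == 0 else [None]
# The device argument picks where the pickled bytes are staged before the
# collective; naming it explicitly mirrors the fix applied to broadcast_pyobj.
dist.broadcast_object_list(objs, src=0, device=torch.device("cpu"))
received = objs[0]  # callers take element 0, mirroring the `)[0]` above
dist.destroy_process_group()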
python/sglang/srt/utils.py

@@ -1094,15 +1094,15 @@ def broadcast_pyobj(
     rank: int,
     dist_group: Optional[torch.distributed.ProcessGroup] = None,
     src: int = 0,
-    force_cpu_device: bool = True,
+    device: Optional[str] = None,
 ):
     """Broadcast inputs from src rank to all other ranks with torch.dist backend.
     The `rank` here refer to the source rank on global process group (regardless
     of dist_group argument).
     """
-    device = torch.device(
-        "cuda" if torch.cuda.is_available() and not force_cpu_device else "cpu"
-    )
+    if device is None:
+        device = get_device()
     if rank == src:
         if len(data) == 0:

@@ -1142,44 +1142,38 @@ def point_to_point_pyobj(
     group: Optional[torch.distributed.ProcessGroup] = None,
     src: int = 0,
     dst: int = 1,
+    device: Optional[str] = None,
 ):
     """Send data from src to dst in group using DeviceToDevice communication."""
+    if device is None:
+        device = get_device()
     if rank == src:
         if len(data) == 0:
-            tensor_size = torch.tensor(
-                [0], dtype=torch.long, device=torch.cuda.current_device()
-            )
+            tensor_size = torch.tensor([0], dtype=torch.long, device=device)
             dist.send(tensor_size, dst=dst, group=group)
         else:
             serialized_data = pickle.dumps(data)
             size = len(serialized_data)
             tensor_data = torch.ByteTensor(
                 np.frombuffer(serialized_data, dtype=np.uint8)
-            ).cuda(
-                device=torch.cuda.current_device()
-            )  # Move to GPU
+            ).to(
+                device=device
+            )  # Move to Device
-            tensor_size = torch.tensor(
-                [size], dtype=torch.long, device=torch.cuda.current_device()
-            )
+            tensor_size = torch.tensor([size], dtype=torch.long, device=device)
             dist.send(tensor_size, dst=dst, group=group)
             dist.send(tensor_data, dst=dst, group=group)
         return data
     elif rank == dst:
-        tensor_size = torch.tensor(
-            [0], dtype=torch.long, device=torch.cuda.current_device()
-        )
+        tensor_size = torch.tensor([0], dtype=torch.long, device=device)
         dist.recv(tensor_size, src=src, group=group)
         size = tensor_size.item()
         if size == 0:
             return []
-        tensor_data = torch.empty(
-            size, dtype=torch.uint8, device=torch.cuda.current_device()
-        )
+        tensor_data = torch.empty(size, dtype=torch.uint8, device=device)
         dist.recv(tensor_data, src=src, group=group)
         serialized_data = bytes(
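Put together, the new signatures let callers name the staging device explicitly and fall back to a detected default otherwise. A hedged sketch of the broadcast flavor of the same pattern (pickle, broadcast the length, broadcast the payload) with that device handling; it is an illustration of the technique, not sglang's broadcast_pyobj, and the default-resolution logic below is an assumption standing in for get_device():

import pickle
from typing import Any, List, Optional

import torch
import torch.distributed as dist


def broadcast_pyobj_sketch(
    data: List[Any],
    rank: int,
    group=None,
    src: int = 0,
    device: Optional[str] = None,
):
    # Assumed stand-in for get_device(): prefer CUDA when visible, else CPU.
    if device is None:
        device = "cuda" if torch.cuda.is_available() else "cpu"

    if rank == src:
        payload = torch.frombuffer(bytearray(pickle.dumps(data)), dtype=torch.uint8).to(device)
        size = torch.tensor([payload.numel()], dtype=torch.long, device=device)
        dist.broadcast(size, src=src, group=group)
        dist.broadcast(payload, src=src, group=group)
        return data

    # Non-source ranks first learn the payload length, then receive the bytes
    # on the same device and unpickle them.
    size = torch.empty(1, dtype=torch.long, device=device)
    dist.broadcast(size, src=src, group=group)
    payload = torch.empty(int(size.item()), dtype=torch.uint8, device=device)
    dist.broadcast(payload, src=src, group=group)
    return pickle.loads(payload.cpu().numpy().tobytes())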