Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
zhaoyu6
sglang
Commits
3bc43c68
"vscode:/vscode.git/clone" did not exist on "dbb2434f5d2d976be26b594342a68cb46619ecea"
Unverified
Commit
3bc43c68
authored
Jul 15, 2025
by
Qiaolin Yu
Committed by
GitHub
Jul 15, 2025
Browse files
Fix different device type adjustment in PP (#7760)
parent
7498522f
Changes
4
Hide whitespace changes
Inline
Side-by-side
Showing
4 changed files
with
25 additions
and
27 deletions
+25
-27
python/sglang/srt/distributed/parallel_state.py
python/sglang/srt/distributed/parallel_state.py
+5
-7
python/sglang/srt/managers/scheduler.py
python/sglang/srt/managers/scheduler.py
+5
-0
python/sglang/srt/managers/tp_worker.py
python/sglang/srt/managers/tp_worker.py
+1
-0
python/sglang/srt/utils.py
python/sglang/srt/utils.py
+14
-20
No files found.
python/sglang/srt/distributed/parallel_state.py
View file @
3bc43c68
...
@@ -699,14 +699,14 @@ class GroupCoordinator:
...
@@ -699,14 +699,14 @@ class GroupCoordinator:
)
)
# Serialize object to tensor and get the size as well
# Serialize object to tensor and get the size as well
object_tensor
=
torch
.
frombuffer
(
pickle
.
dumps
(
obj
),
dtype
=
torch
.
uint8
).
cuda
(
object_tensor
=
torch
.
frombuffer
(
pickle
.
dumps
(
obj
),
dtype
=
torch
.
uint8
).
to
(
device
=
torch
.
cuda
.
current_
device
()
device
=
self
.
device
)
)
size_tensor
=
torch
.
tensor
(
size_tensor
=
torch
.
tensor
(
[
object_tensor
.
numel
()],
[
object_tensor
.
numel
()],
dtype
=
torch
.
long
,
dtype
=
torch
.
long
,
device
=
torch
.
cuda
.
current_
device
()
,
device
=
self
.
device
,
)
)
# Send object size
# Send object size
...
@@ -731,9 +731,7 @@ class GroupCoordinator:
...
@@ -731,9 +731,7 @@ class GroupCoordinator:
src
!=
self
.
rank_in_group
src
!=
self
.
rank_in_group
),
"Invalid source rank. Source rank is the same as the current rank."
),
"Invalid source rank. Source rank is the same as the current rank."
size_tensor
=
torch
.
empty
(
size_tensor
=
torch
.
empty
(
1
,
dtype
=
torch
.
long
,
device
=
self
.
device
)
1
,
dtype
=
torch
.
long
,
device
=
torch
.
cuda
.
current_device
()
)
# Receive object size
# Receive object size
rank_size
=
torch
.
distributed
.
recv
(
rank_size
=
torch
.
distributed
.
recv
(
...
@@ -744,7 +742,7 @@ class GroupCoordinator:
...
@@ -744,7 +742,7 @@ class GroupCoordinator:
object_tensor
=
torch
.
empty
(
# type: ignore[call-overload]
object_tensor
=
torch
.
empty
(
# type: ignore[call-overload]
size_tensor
.
item
(),
# type: ignore[arg-type]
size_tensor
.
item
(),
# type: ignore[arg-type]
dtype
=
torch
.
uint8
,
dtype
=
torch
.
uint8
,
device
=
torch
.
cuda
.
current_
device
()
,
device
=
self
.
device
,
)
)
rank_object
=
torch
.
distributed
.
recv
(
rank_object
=
torch
.
distributed
.
recv
(
...
...
python/sglang/srt/managers/scheduler.py
View file @
3bc43c68
...
@@ -962,6 +962,7 @@ class Scheduler(
...
@@ -962,6 +962,7 @@ class Scheduler(
self
.
world_group
.
device_group
,
self
.
world_group
.
device_group
,
self
.
pp_rank
*
self
.
tp_size
+
dp_offset
,
self
.
pp_rank
*
self
.
tp_size
+
dp_offset
,
(
self
.
pp_rank
+
1
)
*
self
.
tp_size
+
dp_offset
,
(
self
.
pp_rank
+
1
)
*
self
.
tp_size
+
dp_offset
,
device
=
self
.
device
,
)
)
# send out proxy tensors to the next stage
# send out proxy tensors to the next stage
...
@@ -1010,6 +1011,7 @@ class Scheduler(
...
@@ -1010,6 +1011,7 @@ class Scheduler(
self
.
world_group
.
device_group
,
self
.
world_group
.
device_group
,
(
self
.
pp_rank
-
1
)
*
self
.
tp_size
+
dp_offset
,
(
self
.
pp_rank
-
1
)
*
self
.
tp_size
+
dp_offset
,
self
.
pp_rank
*
self
.
tp_size
+
dp_offset
,
self
.
pp_rank
*
self
.
tp_size
+
dp_offset
,
device
=
self
.
device
,
)
)
else
:
else
:
recv_reqs
=
None
recv_reqs
=
None
...
@@ -1040,6 +1042,7 @@ class Scheduler(
...
@@ -1040,6 +1042,7 @@ class Scheduler(
self
.
attn_tp_group
.
rank
,
self
.
attn_tp_group
.
rank
,
self
.
attn_tp_cpu_group
,
self
.
attn_tp_cpu_group
,
src
=
self
.
attn_tp_group
.
ranks
[
0
],
src
=
self
.
attn_tp_group
.
ranks
[
0
],
device
=
self
.
device
,
)
)
if
self
.
tp_size
!=
1
:
if
self
.
tp_size
!=
1
:
control_reqs
=
broadcast_pyobj
(
control_reqs
=
broadcast_pyobj
(
...
@@ -1047,6 +1050,7 @@ class Scheduler(
...
@@ -1047,6 +1050,7 @@ class Scheduler(
self
.
tp_group
.
rank
,
self
.
tp_group
.
rank
,
self
.
tp_cpu_group
,
self
.
tp_cpu_group
,
src
=
self
.
tp_group
.
ranks
[
0
],
src
=
self
.
tp_group
.
ranks
[
0
],
device
=
self
.
device
,
)
)
recv_reqs
=
work_reqs
+
control_reqs
recv_reqs
=
work_reqs
+
control_reqs
elif
self
.
tp_size
!=
1
:
elif
self
.
tp_size
!=
1
:
...
@@ -1055,6 +1059,7 @@ class Scheduler(
...
@@ -1055,6 +1059,7 @@ class Scheduler(
self
.
tp_group
.
rank
,
self
.
tp_group
.
rank
,
self
.
tp_cpu_group
,
self
.
tp_cpu_group
,
src
=
self
.
tp_group
.
ranks
[
0
],
src
=
self
.
tp_group
.
ranks
[
0
],
device
=
self
.
device
,
)
)
return
recv_reqs
return
recv_reqs
...
...
python/sglang/srt/managers/tp_worker.py
View file @
3bc43c68
...
@@ -144,6 +144,7 @@ class TpModelWorker:
...
@@ -144,6 +144,7 @@ class TpModelWorker:
self
.
tp_size
*
self
.
pp_rank
+
tp_rank
,
self
.
tp_size
*
self
.
pp_rank
+
tp_rank
,
self
.
world_group
.
cpu_group
,
self
.
world_group
.
cpu_group
,
src
=
self
.
world_group
.
ranks
[
0
],
src
=
self
.
world_group
.
ranks
[
0
],
device
=
self
.
device
,
)[
0
]
)[
0
]
set_random_seed
(
self
.
random_seed
)
set_random_seed
(
self
.
random_seed
)
...
...
python/sglang/srt/utils.py
View file @
3bc43c68
...
@@ -1094,15 +1094,15 @@ def broadcast_pyobj(
...
@@ -1094,15 +1094,15 @@ def broadcast_pyobj(
rank
:
int
,
rank
:
int
,
dist_group
:
Optional
[
torch
.
distributed
.
ProcessGroup
]
=
None
,
dist_group
:
Optional
[
torch
.
distributed
.
ProcessGroup
]
=
None
,
src
:
int
=
0
,
src
:
int
=
0
,
force_cpu_device
:
bool
=
Tru
e
,
device
:
Optional
[
str
]
=
Non
e
,
):
):
"""Broadcast inputs from src rank to all other ranks with torch.dist backend.
"""Broadcast inputs from src rank to all other ranks with torch.dist backend.
The `rank` here refer to the source rank on global process group (regardless
The `rank` here refer to the source rank on global process group (regardless
of dist_group argument).
of dist_group argument).
"""
"""
device
=
torch
.
device
(
"cuda"
if
torch
.
cuda
.
is_available
()
and
not
force_cpu_device
else
"cpu"
if
device
is
None
:
)
device
=
get_device
(
)
if
rank
==
src
:
if
rank
==
src
:
if
len
(
data
)
==
0
:
if
len
(
data
)
==
0
:
...
@@ -1142,44 +1142,38 @@ def point_to_point_pyobj(
...
@@ -1142,44 +1142,38 @@ def point_to_point_pyobj(
group
:
Optional
[
torch
.
distributed
.
ProcessGroup
]
=
None
,
group
:
Optional
[
torch
.
distributed
.
ProcessGroup
]
=
None
,
src
:
int
=
0
,
src
:
int
=
0
,
dst
:
int
=
1
,
dst
:
int
=
1
,
device
:
Optional
[
str
]
=
None
,
):
):
"""Send data from src to dst in group using DeviceToDevice communication."""
"""Send data from src to dst in group using DeviceToDevice communication."""
if
device
is
None
:
device
=
get_device
()
if
rank
==
src
:
if
rank
==
src
:
if
len
(
data
)
==
0
:
if
len
(
data
)
==
0
:
tensor_size
=
torch
.
tensor
(
tensor_size
=
torch
.
tensor
([
0
],
dtype
=
torch
.
long
,
device
=
device
)
[
0
],
dtype
=
torch
.
long
,
device
=
torch
.
cuda
.
current_device
()
)
dist
.
send
(
tensor_size
,
dst
=
dst
,
group
=
group
)
dist
.
send
(
tensor_size
,
dst
=
dst
,
group
=
group
)
else
:
else
:
serialized_data
=
pickle
.
dumps
(
data
)
serialized_data
=
pickle
.
dumps
(
data
)
size
=
len
(
serialized_data
)
size
=
len
(
serialized_data
)
tensor_data
=
torch
.
ByteTensor
(
tensor_data
=
torch
.
ByteTensor
(
np
.
frombuffer
(
serialized_data
,
dtype
=
np
.
uint8
)
np
.
frombuffer
(
serialized_data
,
dtype
=
np
.
uint8
)
).
cuda
(
).
to
(
device
=
torch
.
cuda
.
current_device
()
device
=
device
)
# Move to GPU
)
# Move to Device
tensor_size
=
torch
.
tensor
(
tensor_size
=
torch
.
tensor
([
size
],
dtype
=
torch
.
long
,
device
=
device
)
[
size
],
dtype
=
torch
.
long
,
device
=
torch
.
cuda
.
current_device
()
)
dist
.
send
(
tensor_size
,
dst
=
dst
,
group
=
group
)
dist
.
send
(
tensor_size
,
dst
=
dst
,
group
=
group
)
dist
.
send
(
tensor_data
,
dst
=
dst
,
group
=
group
)
dist
.
send
(
tensor_data
,
dst
=
dst
,
group
=
group
)
return
data
return
data
elif
rank
==
dst
:
elif
rank
==
dst
:
tensor_size
=
torch
.
tensor
(
tensor_size
=
torch
.
tensor
([
0
],
dtype
=
torch
.
long
,
device
=
device
)
[
0
],
dtype
=
torch
.
long
,
device
=
torch
.
cuda
.
current_device
()
)
dist
.
recv
(
tensor_size
,
src
=
src
,
group
=
group
)
dist
.
recv
(
tensor_size
,
src
=
src
,
group
=
group
)
size
=
tensor_size
.
item
()
size
=
tensor_size
.
item
()
if
size
==
0
:
if
size
==
0
:
return
[]
return
[]
tensor_data
=
torch
.
empty
(
tensor_data
=
torch
.
empty
(
size
,
dtype
=
torch
.
uint8
,
device
=
device
)
size
,
dtype
=
torch
.
uint8
,
device
=
torch
.
cuda
.
current_device
()
)
dist
.
recv
(
tensor_data
,
src
=
src
,
group
=
group
)
dist
.
recv
(
tensor_data
,
src
=
src
,
group
=
group
)
serialized_data
=
bytes
(
serialized_data
=
bytes
(
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment