sglang · commit 471650de (Unverified)
Fix broadcast use cuda device lead to memory capacity unbalanced (#5416)
Authored Apr 15, 2025 by lambert0312; committed by GitHub on Apr 15, 2025
Parent: d06a83fb

Showing 3 changed files with 35 additions and 11 deletions (+35, -11)
python/sglang/srt/disaggregation/mooncake/conn.py   +30 -10
python/sglang/srt/entrypoints/verl_engine.py         +1  -0
python/sglang/srt/utils.py                            +4  -1
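The functional change lives in python/sglang/srt/utils.py: broadcast_pyobj used to place its serialized payload on the current CUDA device whenever CUDA was available, which the commit title describes as leading to unbalanced memory capacity across GPUs. The patch adds a force_cpu_device flag (default True) so the broadcast buffers stay on the CPU unless a caller opts back into CUDA, as VerlEngine does below. A minimal sketch of the new device-selection rule, using a hypothetical standalone helper rather than the actual sglang function:

import torch


def pick_broadcast_device(force_cpu_device: bool = True) -> torch.device:
    # Mirrors the selection logic added to broadcast_pyobj: CUDA is used only
    # when it is available AND the caller explicitly opts in with
    # force_cpu_device=False; otherwise the payload tensor stays on the CPU.
    return torch.device(
        "cuda" if torch.cuda.is_available() and not force_cpu_device else "cpu"
    )


print(pick_broadcast_device())                        # cpu (the new default)
print(pick_broadcast_device(force_cpu_device=False))  # cuda when available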
python/sglang/srt/disaggregation/mooncake/conn.py
@@ -31,6 +31,7 @@ from sglang.srt.utils import is_port_available
 logger = logging.getLogger(__name__)


 def find_available_ports(base_port: int, count: int) -> List[int]:
     """Find consecutive available ports starting from base_port."""
     available_ports = []
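The docstring above says find_available_ports scans for consecutive free ports starting at base_port; its body is collapsed in this diff. Purely as an illustration (not the code in this file), a sketch of such a scan, assuming the imported is_port_available(port) returns True when the port can be bound:

from typing import List

from sglang.srt.utils import is_port_available  # imported at the top of this file


def find_available_ports_sketch(base_port: int, count: int) -> List[int]:
    # Hypothetical illustration: walk upward from base_port until `count`
    # consecutive ports all pass the availability check, then return them.
    start = base_port
    while True:
        candidates = list(range(start, start + count))
        if all(is_port_available(port) for port in candidates):
            return candidates
        start += 1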
@@ -43,6 +44,7 @@ def find_available_ports(base_port: int, count: int) -> List[int]:
     return available_ports


 def group_concurrent_contiguous(
     src_indices: npt.NDArray[np.int64], dst_indices: npt.NDArray[np.int64]
 ) -> Tuple[List[npt.NDArray[np.int64]], List[npt.NDArray[np.int64]]]:
@@ -265,7 +267,9 @@ class MooncakeKVManager(BaseKVManager):
                 )
                 if ret != 0:
                     self.request_status[kv_chunk.room] = KVPoll.Failed
-                    self.sync_status_to_decode_endpoint(req.endpoint, req.dst_port, req.room)
+                    self.sync_status_to_decode_endpoint(
+                        req.endpoint, req.dst_port, req.room
+                    )
                     continue

                 if kv_chunk.is_last:
@@ -279,7 +283,9 @@ class MooncakeKVManager(BaseKVManager):
                     self.request_status[req.room] = (
                         KVPoll.Success if ret == 0 else KVPoll.Failed
                     )
-                    self.sync_status_to_decode_endpoint(req.endpoint, req.dst_port, req.room)
+                    self.sync_status_to_decode_endpoint(
+                        req.endpoint, req.dst_port, req.room
+                    )
                     self.transfer_infos.pop(req.room)

             except queue.Empty:
@@ -443,13 +449,14 @@ class MooncakeKVReceiver(BaseKVReceiver):
                 prefill_info = response.json()
                 return prefill_info
             else:
-                logger.error(f"Failed to get prefill server info: {response.status_code}, {response.text}")
+                logger.error(
+                    f"Failed to get prefill server info: {response.status_code}, {response.text}"
+                )
                 return None
         except Exception as e:
             logger.error(f"Error fetching prefill info from bootstrap: {e}")
             return None

     @cache
     def _connect(self, endpoint: str):
         socket = zmq.Context().socket(zmq.PUSH)
@@ -466,17 +473,25 @@ class MooncakeKVReceiver(BaseKVReceiver):
             )
             if prefill_info is None:
-                logger.error(f"Could not fetch prefill server info for tp_rank {self.kv_mgr.kv_args.engine_rank}")
+                logger.error(
+                    f"Could not fetch prefill server info for tp_rank {self.kv_mgr.kv_args.engine_rank}"
+                )
             else:
-                self.kv_mgr.connection_pool[self.kv_mgr.kv_args.engine_rank] = prefill_info
+                self.kv_mgr.connection_pool[self.kv_mgr.kv_args.engine_rank] = (
+                    prefill_info
+                )
         else:
             prefill_info = self.kv_mgr.connection_pool[self.kv_mgr.kv_args.engine_rank]

         if prefill_info:
-            self.prefill_server_url = f"{prefill_info['serve_ip']}:{prefill_info['serve_port']}"
-            logger.info(f"Fetched prefill server info: {prefill_info} for tp_rank {self.kv_mgr.kv_args.engine_rank}")
+            self.prefill_server_url = (
+                f"{prefill_info['serve_ip']}:{prefill_info['serve_port']}"
+            )
+            logger.info(
+                f"Fetched prefill server info: {prefill_info} for tp_rank {self.kv_mgr.kv_args.engine_rank}"
+            )
             self.handshake_prefill_server(kv_indices, aux_index)

     def handshake_prefill_server(
@@ -598,8 +613,13 @@ class MooncakeKVBootstrapServer(BaseKVBootstrapServer):
         # Add lock to make sure thread-safe
         if role == "Prefill":
             async with self.lock:
-                self.prefill_port_table[tp_rank] = {"serve_ip": serve_ip, "serve_port": serve_port}
-            logger.info(f"Registered Prefill tp_rank: {tp_rank} with serve_ip: {serve_ip} and serve_port: {serve_port}")
+                self.prefill_port_table[tp_rank] = {
+                    "serve_ip": serve_ip,
+                    "serve_port": serve_port,
+                }
+            logger.info(
+                f"Registered Prefill tp_rank: {tp_rank} with serve_ip: {serve_ip} and serve_port: {serve_port}"
+            )

         return web.Response(text="OK", status=200)
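The bootstrap-server hunk above keeps its "# Add lock to make sure thread-safe" intent: registration writes to prefill_port_table happen under an asyncio lock before the handler answers with HTTP 200. A compact, self-contained sketch of that pattern (hypothetical class and payload keys, aiohttp-style handler like the one in this file):

import asyncio

from aiohttp import web


class BootstrapRegistrySketch:
    def __init__(self):
        self.lock = asyncio.Lock()                 # guards concurrent registrations
        self.prefill_port_table: dict = {}

    async def handle_register(self, request: web.Request) -> web.Response:
        payload = await request.json()             # assumed JSON body with these keys
        async with self.lock:                      # serialize writes to the shared table
            self.prefill_port_table[payload["tp_rank"]] = {
                "serve_ip": payload["serve_ip"],
                "serve_port": payload["serve_port"],
            }
        return web.Response(text="OK", status=200)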
python/sglang/srt/entrypoints/verl_engine.py
@@ -118,6 +118,7 @@ class VerlEngine:
             rank=self._tp_rank,
             dist_group=self._device_mesh_cpu.get_group(),
             src=self._device_mesh_cpu.mesh[0].item(),
+            force_cpu_device=False,
         )

         return output
python/sglang/srt/utils.py
@@ -846,9 +846,12 @@ def broadcast_pyobj(
     rank: int,
     dist_group: Optional[torch.distributed.ProcessGroup] = None,
     src: int = 0,
+    force_cpu_device: bool = True,
 ):
     """Broadcast inputs from rank=0 to all other ranks with torch.dist backend."""
-    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+    device = torch.device(
+        "cuda" if torch.cuda.is_available() and not force_cpu_device else "cpu"
+    )

     if rank == 0:
         if len(data) == 0:
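Per its docstring, broadcast_pyobj sends the rank-0 inputs to all other ranks over the torch.distributed backend; the patch only changes which device holds the serialized payload and leaves the rest of the function untouched. Existing callers therefore keep working and now land on the CPU path, while VerlEngine opts back into CUDA explicitly (the force_cpu_device=False line above). A hedged usage sketch, assuming an initialized process group and assuming data (the name used in the hunk body) is the leading list-of-objects parameter:

import torch.distributed as dist

from sglang.srt.utils import broadcast_pyobj  # patched signature shown in the hunk above


def sync_metadata(metadata: list, rank: int, group: dist.ProcessGroup):
    # Default path after this commit: the payload is serialized into CPU tensors,
    # so no rank allocates extra CUDA memory just for control-plane objects.
    return broadcast_pyobj(metadata, rank=rank, dist_group=group)


def sync_metadata_on_gpu(metadata: list, rank: int, group: dist.ProcessGroup):
    # Opt back into the pre-patch behaviour, as VerlEngine does in this commit:
    # the payload tensor is placed on the current CUDA device when one exists.
    return broadcast_pyobj(metadata, rank=rank, dist_group=group, force_cpu_device=False)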