Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
change
sglang
Commits
e0673969
Unverified
Commit
e0673969
authored
Apr 23, 2025
by
shangmingc
Committed by
GitHub
Apr 23, 2025
Browse files
[PD] Add support for dp attention with mooncake (#5530)
Signed-off-by:
Shangming Cai
<
caishangming@linux.alibaba.com
>
parent
127ff898
Changes
1
Hide whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
99 additions
and
10 deletions
+99
-10
python/sglang/srt/disaggregation/mooncake/conn.py
python/sglang/srt/disaggregation/mooncake/conn.py
+99
-10
No files found.
python/sglang/srt/disaggregation/mooncake/conn.py
View file @
e0673969
...
...
@@ -109,6 +109,13 @@ class MooncakeKVManager(BaseKVManager):
# for p/d multi node infer
self
.
bootstrap_port
=
server_args
.
disaggregation_bootstrap_port
self
.
dist_init_addr
=
server_args
.
dist_init_addr
self
.
tp_size
=
server_args
.
tp_size
self
.
dp_size
=
server_args
.
dp_size
self
.
enable_dp_attention
=
server_args
.
enable_dp_attention
if
not
server_args
.
enable_dp_attention
and
server_args
.
dp_size
!=
1
:
raise
ValueError
(
"If dp_attention is not enabled, dp size must be 1 in disaggregation mode."
)
self
.
request_status
:
Dict
[
int
,
KVPoll
]
=
{}
self
.
rank_port
=
None
self
.
server_socket
=
zmq
.
Context
().
socket
(
zmq
.
PULL
)
...
...
@@ -121,6 +128,7 @@ class MooncakeKVManager(BaseKVManager):
elif
self
.
disaggregation_mode
==
DisaggregationMode
.
DECODE
:
self
.
start_decode_thread
()
self
.
connection_pool
:
Dict
[
str
,
Dict
[
str
,
Union
[
str
,
int
]]]
=
{}
self
.
prefill_dp_size_table
:
Dict
[
str
,
int
]
=
{}
else
:
raise
ValueError
(
f
"Unsupported DisaggregationMode:
{
self
.
disaggregation_mode
}
"
...
...
@@ -331,6 +339,8 @@ class MooncakeKVManager(BaseKVManager):
url
=
f
"http://
{
bootstrap_server_url
}
/route"
payload
=
{
"role"
:
"Prefill"
,
"tp_size"
:
self
.
tp_size
,
"dp_size"
:
self
.
dp_size
,
"rank_ip"
:
get_local_ip_by_remote
(),
"rank_port"
:
self
.
rank_port
,
"engine_rank"
:
self
.
kv_args
.
engine_rank
,
...
...
@@ -408,12 +418,41 @@ class MooncakeKVReceiver(BaseKVReceiver):
self
.
session_id
=
self
.
kv_mgr
.
get_session_id
()
self
.
kv_mgr
.
update_status
(
bootstrap_room
,
KVPoll
.
Bootstrapping
)
if
not
self
.
kv_mgr
.
enable_dp_attention
:
# We assume dp_attention should be activated simultaneously for
# both prefill role and decode role. If the decode instance does
# not enable dp_attention, then dp_attention is not enabled on the
# prefill instance as well. Therefore, we should skip questioning
# the prefill dp size to reduce bootstrap overhead.
self
.
prefill_dp_size
=
1
elif
self
.
bootstrap_addr
not
in
self
.
kv_mgr
.
prefill_dp_size_table
:
self
.
prefill_dp_size
,
tp_size_per_dp_rank
=
(
self
.
_get_prefill_dp_size_from_server
()
)
# Currently, we don't allow prefill instance and decode instance to
# have different TP sizes per DP rank.
assert
tp_size_per_dp_rank
==
self
.
kv_mgr
.
tp_size
//
self
.
kv_mgr
.
dp_size
if
self
.
prefill_dp_size
is
None
:
logger
.
error
(
f
"Could not fetch prefill dp_size for bootstrap_addr:
{
self
.
bootstrap_addr
}
"
)
else
:
self
.
kv_mgr
.
prefill_dp_size_table
[
self
.
bootstrap_addr
]
=
(
self
.
prefill_dp_size
)
else
:
self
.
prefill_dp_size
=
self
.
kv_mgr
.
prefill_dp_size_table
[
self
.
bootstrap_addr
]
# NOTE: key distinguished by bootstrap_addr and engine_rank
self
.
target_dp_group
=
bootstrap_room
%
self
.
prefill_dp_size
bootstrap_key
=
f
"
{
self
.
bootstrap_addr
}
_
{
self
.
kv_mgr
.
kv_args
.
engine_rank
}
"
if
bootstrap_key
not
in
self
.
kv_mgr
.
connection_pool
:
self
.
bootstrap_info
=
self
.
_get_bootstrap_info_from_server
(
self
.
kv_mgr
.
kv_args
.
engine_rank
self
.
kv_mgr
.
kv_args
.
engine_rank
,
self
.
target_dp_group
,
)
if
self
.
bootstrap_info
is
None
:
logger
.
error
(
...
...
@@ -427,10 +466,10 @@ class MooncakeKVReceiver(BaseKVReceiver):
assert
self
.
bootstrap_info
is
not
None
self
.
kv_mgr
.
update_status
(
bootstrap_room
,
KVPoll
.
WaitingForInput
)
def
_get_bootstrap_info_from_server
(
self
,
engine_rank
):
def
_get_bootstrap_info_from_server
(
self
,
engine_rank
,
target_dp_group
):
"""Fetch the bootstrap info from the bootstrap server."""
try
:
url
=
f
"http://
{
self
.
bootstrap_addr
}
/route?engine_rank=
{
engine_rank
}
"
url
=
f
"http://
{
self
.
bootstrap_addr
}
/route?engine_rank=
{
engine_rank
}
&target_dp_group=
{
target_dp_group
}
"
response
=
requests
.
get
(
url
)
if
response
.
status_code
==
200
:
bootstrap_info
=
response
.
json
()
...
...
@@ -444,6 +483,25 @@ class MooncakeKVReceiver(BaseKVReceiver):
logger
.
error
(
f
"Error fetching prefill info from bootstrap:
{
e
}
"
)
return
None
def _get_prefill_dp_size_from_server(self) -> tuple:
    """Fetch the prefill parallel info from the bootstrap server.

    Sends a GET request to the bootstrap server's ``/route`` endpoint with
    ``engine_rank=-1`` and ``target_dp_group=-1``; the route handler in this
    file treats that pair as a request for parallel info instead of a
    per-rank route lookup.

    Returns:
        ``(prefill_dp_size, tp_size_per_dp_rank)`` as two ints when the
        server answers with HTTP 200.
        ``(None, None)`` when the status is not 200 or any exception is
        raised while requesting or parsing the response. A pair is returned
        (rather than a bare ``None``) because the caller unpacks two values
        and then checks the first one against ``None``.
    """
    try:
        url = f"http://{self.bootstrap_addr}/route?engine_rank={-1}&target_dp_group={-1}"
        response = requests.get(url)
        if response.status_code == 200:
            prefill_parallel_info = response.json()
            return int(prefill_parallel_info["prefill_dp_size"]), int(
                prefill_parallel_info["tp_size_per_dp_rank"]
            )
        else:
            logger.error(
                f"Failed to get prefill parallel info: {response.status_code}, {response.text}"
            )
            # Keep the two-element shape so tuple unpacking at the call
            # site does not raise TypeError on failure.
            return None, None
    except Exception as e:
        # Best-effort boundary: network errors, invalid JSON, missing keys
        # and non-numeric values are all logged and reported as failure.
        logger.error(f"Error fetching prefill parallel info from bootstrap: {e}")
        return None, None
@
classmethod
def
_connect
(
cls
,
endpoint
:
str
):
with
cls
.
_global_lock
:
...
...
@@ -497,7 +555,9 @@ class MooncakeKVBootstrapServer(BaseKVBootstrapServer):
self
.
store
=
dict
()
self
.
lock
=
asyncio
.
Lock
()
self
.
_setup_routes
()
self
.
prefill_port_table
:
Dict
[
int
,
Dict
[
str
,
Union
[
str
,
int
]]]
=
{}
self
.
dp_size
=
None
self
.
tp_size_per_dp_rank
=
None
self
.
prefill_port_table
:
Dict
[
int
,
Dict
[
int
,
Dict
[
str
,
Union
[
str
,
int
]]]]
=
{}
# Start bootstrap server
self
.
thread
=
threading
.
Thread
(
target
=
self
.
_run_server
,
daemon
=
True
)
...
...
@@ -523,35 +583,64 @@ class MooncakeKVBootstrapServer(BaseKVBootstrapServer):
async
def
_handle_route_put
(
self
,
request
:
web
.
Request
):
data
=
await
request
.
json
()
role
=
data
[
"role"
]
tp_size
=
data
[
"tp_size"
]
dp_size
=
data
[
"dp_size"
]
rank_ip
=
data
[
"rank_ip"
]
rank_port
=
int
(
data
[
"rank_port"
])
engine_rank
=
int
(
data
[
"engine_rank"
])
if
self
.
dp_size
is
None
:
self
.
dp_size
=
dp_size
tp_size_per_dp_rank
=
tp_size
//
dp_size
if
self
.
tp_size_per_dp_rank
==
None
:
self
.
tp_size_per_dp_rank
=
tp_size_per_dp_rank
# Add lock to make sure thread-safe
if
role
==
"Prefill"
:
self
.
prefill_port_table
[
engine_rank
]
=
{
dp_group
=
engine_rank
//
tp_size_per_dp_rank
tp_rank_in_dp_group
=
engine_rank
%
tp_size_per_dp_rank
async
with
self
.
lock
:
if
dp_group
not
in
self
.
prefill_port_table
:
self
.
prefill_port_table
[
dp_group
]
=
{}
self
.
prefill_port_table
[
dp_group
][
tp_rank_in_dp_group
]
=
{
"rank_ip"
:
rank_ip
,
"rank_port"
:
rank_port
,
}
logger
.
debug
(
f
"Registered Prefill boostrap:
{
engine_rank
}
with rank_ip:
{
rank_ip
}
and rank_port:
{
rank_port
}
"
f
"Registered Prefill boo
t
strap:
{
engine_rank
}
with rank_ip:
{
rank_ip
}
and rank_port:
{
rank_port
}
"
)
return
web
.
Response
(
text
=
"OK"
,
status
=
200
)
async
def
_handle_route_get
(
self
,
request
:
web
.
Request
):
engine_rank
=
request
.
query
.
get
(
"engine_rank"
)
if
not
engine_rank
:
return
web
.
Response
(
text
=
"Missing rank"
,
status
=
400
)
target_dp_group
=
request
.
query
.
get
(
"target_dp_group"
)
if
not
engine_rank
or
not
target_dp_group
:
return
web
.
Response
(
text
=
"Missing inputs for bootstrap server."
,
status
=
400
)
# Currently we use engine_rank == -1 and target_dp_group == -1 to sync dp size
if
int
(
engine_rank
)
==
-
1
and
int
(
target_dp_group
)
==
-
1
:
prefill_parallel_info
=
{
"prefill_dp_size"
:
self
.
dp_size
,
"tp_size_per_dp_rank"
:
self
.
tp_size_per_dp_rank
,
}
return
web
.
json_response
(
prefill_parallel_info
,
status
=
200
)
# Find corresponding prefill info
tp_rank_in_dp_group
=
int
(
engine_rank
)
%
self
.
tp_size_per_dp_rank
async
with
self
.
lock
:
bootstrap_info
=
self
.
prefill_port_table
.
get
(
int
(
engine_rank
))
bootstrap_info
=
self
.
prefill_port_table
[
int
(
target_dp_group
)][
tp_rank_in_dp_group
]
if
bootstrap_info
is
not
None
:
return
web
.
json_response
(
bootstrap_info
,
status
=
200
)
else
:
return
web
.
Response
(
text
=
"
N
ot Found"
,
status
=
404
)
return
web
.
Response
(
text
=
"
Bootstrap info n
ot Found"
,
status
=
404
)
def
_run_server
(
self
):
try
:
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment