Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
change
sglang
Commits
e806f708
Unverified
Commit
e806f708
authored
May 27, 2025
by
Trevor Morris
Committed by
GitHub
May 27, 2025
Browse files
[PD] Make bootstrap code common between NIXL and Mooncake (#6473)
parent
fa6723f0
Changes
6
Hide whitespace changes
Inline
Side-by-side
Showing
6 changed files
with
595 additions
and
428 deletions
+595
-428
docs/backend/pd_disaggregation.md
docs/backend/pd_disaggregation.md
+41
-0
python/sglang/srt/disaggregation/common/__init__.py
python/sglang/srt/disaggregation/common/__init__.py
+1
-0
python/sglang/srt/disaggregation/common/conn.py
python/sglang/srt/disaggregation/common/conn.py
+401
-0
python/sglang/srt/disaggregation/mooncake/conn.py
python/sglang/srt/disaggregation/mooncake/conn.py
+4
-18
python/sglang/srt/disaggregation/nixl/conn.py
python/sglang/srt/disaggregation/nixl/conn.py
+130
-409
python/sglang/srt/disaggregation/utils.py
python/sglang/srt/disaggregation/utils.py
+18
-1
No files found.
docs/backend/pd_disaggregation.md
View file @
e806f708
...
...
@@ -47,3 +47,44 @@ $ python -m sglang.launch_server --model-path deepseek-ai/DeepSeek-V3-0324 --dis
# decode 1
$
python
-m
sglang.launch_server
--model-path
deepseek-ai/DeepSeek-V3-0324
--disaggregation-ib-device
${
device_name
}
--disaggregation-mode
decode
--host
${
local_ip
}
--port
30001
--trust-remote-code
--dist-init-addr
${
decode_master_ip
}
:5000
--nnodes
2
--node-rank
1
--tp-size
16
--dp-size
8
--enable-dp-attention
--enable-deepep-moe
--deepep-mode
low_latency
--mem-fraction-static
0.8
--max-running-requests
128
```
## NIXL
### Requirements
Install via pip.
```
bash
pip
install
nixl
```
Or build from source - may be required if you already have UCX installed.
```
bash
git clone https://github.com/ai-dynamo/nixl.git
cd
nixl
pip
install
.
--config-settings
=
setup-args
=
"-Ducx_path=/path/to/ucx"
```
### Usage
### Llama Single Node
```
bash
$
python
-m
sglang.launch_server
--model-path
meta-llama/Llama-3.1-8B-Instruct
--disaggregation-mode
prefill
--disaggregation-transfer-backend
nixl
$
python
-m
sglang.launch_server
--model-path
meta-llama/Llama-3.1-8B-Instruct
--disaggregation-mode
decode
--port
30001
--base-gpu-id
1
--disaggregation-transfer-backend
nixl
$
python
-m
sglang.srt.disaggregation.mini_lb
--prefill
http://127.0.0.1:30000
--decode
http://127.0.0.1:30001
--host
0.0.0.0
--port
8000
```
### DeepSeek Multi-Node
```
bash
# prefill 0
$
python
-m
sglang.launch_server
--model-path
deepseek-ai/DeepSeek-V3-0324
--disaggregation-transfer-backend
nixl
--disaggregation-mode
prefill
--host
${
local_ip
}
--port
30000
--trust-remote-code
--dist-init-addr
${
prefill_master_ip
}
:5000
--nnodes
2
--node-rank
0
--tp-size
16
--dp-size
8
--enable-dp-attention
--enable-deepep-moe
--deepep-mode
normal
--mem-fraction-static
0.8
# prefill 1
$
python
-m
sglang.launch_server
--model-path
deepseek-ai/DeepSeek-V3-0324
--disaggregation-transfer-backend
nixl
--disaggregation-mode
prefill
--host
${
local_ip
}
--port
30000
--trust-remote-code
--dist-init-addr
${
prefill_master_ip
}
:5000
--nnodes
2
--node-rank
1
--tp-size
16
--dp-size
8
--enable-dp-attention
--enable-deepep-moe
--deepep-mode
normal
--mem-fraction-static
0.8
# decode 0
$
python
-m
sglang.launch_server
--model-path
deepseek-ai/DeepSeek-V3-0324
--disaggregation-transfer-backend
nixl
--disaggregation-mode
decode
--host
${
local_ip
}
--port
30001
--trust-remote-code
--dist-init-addr
${
decode_master_ip
}
:5000
--nnodes
2
--node-rank
0
--tp-size
16
--dp-size
8
--enable-dp-attention
--enable-deepep-moe
--deepep-mode
low_latency
--mem-fraction-static
0.8
--max-running-requests
128
# decode 1
$
python
-m
sglang.launch_server
--model-path
deepseek-ai/DeepSeek-V3-0324
--disaggregation-transfer-backend
nixl
--disaggregation-mode
decode
--host
${
local_ip
}
--port
30001
--trust-remote-code
--dist-init-addr
${
decode_master_ip
}
:5000
--nnodes
2
--node-rank
1
--tp-size
16
--dp-size
8
--enable-dp-attention
--enable-deepep-moe
--deepep-mode
low_latency
--mem-fraction-static
0.8
--max-running-requests
128
```
python/sglang/srt/disaggregation/common/__init__.py
0 → 100644
View file @
e806f708
from
.conn
import
CommonKVBootstrapServer
,
CommonKVManager
,
CommonKVReceiver
python/sglang/srt/disaggregation/common/conn.py
0 → 100644
View file @
e806f708
from
__future__
import
annotations
import
asyncio
import
logging
import
socket
import
threading
from
functools
import
cache
from
typing
import
Dict
,
List
,
Optional
,
Tuple
,
Union
import
numpy
as
np
import
numpy.typing
as
npt
import
requests
import
zmq
from
aiohttp
import
web
from
sglang.srt.disaggregation.base.conn
import
(
BaseKVBootstrapServer
,
BaseKVManager
,
BaseKVReceiver
,
BaseKVSender
,
KVArgs
,
KVPoll
,
)
from
sglang.srt.disaggregation.utils
import
DisaggregationMode
from
sglang.srt.server_args
import
ServerArgs
from
sglang.srt.utils
import
get_free_port
,
get_ip
,
get_local_ip_by_remote
logger
=
logging
.
getLogger
(
__name__
)
class CommonKVManager(BaseKVManager):
    """Transfer-backend-agnostic KV manager shared by the NIXL and Mooncake
    disaggregation backends.

    In PREFILL mode, this registers the local rank with the bootstrap server.
    In DECODE mode, it holds per-bootstrap-address caches (connection pool and
    prefill TP/DP size tables) that ``CommonKVReceiver`` consults.
    """

    def __init__(
        self,
        args: KVArgs,
        disaggregation_mode: DisaggregationMode,
        server_args: ServerArgs,
        is_mla_backend: Optional[bool] = False,
    ):
        self.kv_args = args
        self.is_mla_backend = is_mla_backend
        self.disaggregation_mode = disaggregation_mode
        # for p/d multi node infer
        self.bootstrap_port = server_args.disaggregation_bootstrap_port
        self.dist_init_addr = server_args.dist_init_addr
        self.tp_size = server_args.tp_size
        self.dp_size = server_args.dp_size
        self.enable_dp_attention = server_args.enable_dp_attention
        if not server_args.enable_dp_attention and server_args.dp_size != 1:
            raise ValueError(
                "If dp_attention is not enabled, dp size must be 1 in disaggregation mode."
            )

        # Per-endpoint ZMQ PUSH socket cache (see _connect). A plain dict is
        # used instead of functools.cache on the method, which would keep
        # `self` (and hence the whole manager) alive forever (ruff B019).
        self._socket_cache: Dict[str, zmq.Socket] = {}

        self.rank_port = get_free_port()
        if self.disaggregation_mode == DisaggregationMode.PREFILL:
            self._register_to_bootstrap()
        elif self.disaggregation_mode == DisaggregationMode.DECODE:
            # bootstrap_key -> list of bootstrap infos fetched from the server
            self.connection_pool: Dict[str, Dict[str, Union[str, int]]] = {}
            # bootstrap_addr -> prefill instance's TP / DP world sizes
            self.prefill_tp_size_table: Dict[str, int] = {}
            self.prefill_dp_size_table: Dict[str, int] = {}
        else:
            raise ValueError(
                f"Unsupported DisaggregationMode: {self.disaggregation_mode}"
            )

    def _register_to_bootstrap(self):
        """Register this prefill rank to the bootstrap server via HTTP PUT.

        Failures are logged but not raised: registration is best-effort and
        the decode side will report missing info when it tries to connect.
        """
        if self.dist_init_addr:
            # dist_init_addr may carry a ":port" suffix; resolve the host only.
            ip_address = socket.gethostbyname(self.dist_init_addr.split(":")[0])
        else:
            ip_address = get_ip()
        bootstrap_server_url = f"{ip_address}:{self.bootstrap_port}"
        url = f"http://{bootstrap_server_url}/route"
        payload = {
            "role": "Prefill",
            "tp_size": self.tp_size,
            "dp_size": self.dp_size,
            "rank_ip": get_local_ip_by_remote(),
            "rank_port": self.rank_port,
            "engine_rank": self.kv_args.engine_rank,
        }

        try:
            response = requests.put(url, json=payload)
            if response.status_code == 200:
                logger.debug("Prefill successfully registered to bootstrap server.")
            else:
                logger.error(
                    f"Prefill Failed to connect to bootstrap server: {response.status_code}, {response.text}"
                )
        except Exception as e:
            logger.error(f"Prefill Failed to register to bootstrap server: {e}")

    def _connect(self, endpoint: str):
        """Return a ZMQ PUSH socket connected to *endpoint*, one per endpoint.

        Sockets are memoized in ``self._socket_cache`` so repeated calls for
        the same endpoint reuse the existing connection.
        """
        # NOTE: do not name the local variable `socket` — that would shadow
        # the stdlib `socket` module used by _register_to_bootstrap.
        sock = self._socket_cache.get(endpoint)
        if sock is None:
            sock = zmq.Context().socket(zmq.PUSH)
            sock.connect(endpoint)
            self._socket_cache[endpoint] = sock
        return sock
class CommonKVReceiver(BaseKVReceiver):
    """Decode-side KV receiver shared by NIXL and Mooncake backends.

    On construction it resolves which prefill TP rank(s) this decode rank
    must pull KV cache from, fetching (and caching, via the manager) the
    prefill instance's parallel layout from the bootstrap server.
    """

    # Class-level ZMQ context plus endpoint->socket/lock caches, shared by
    # all receiver instances in the process (see _connect).
    _ctx = zmq.Context()
    _socket_cache = {}
    _socket_locks = {}
    _global_lock = threading.Lock()

    def __init__(
        self,
        mgr: BaseKVManager,
        bootstrap_addr: str,
        bootstrap_room: Optional[int] = None,
    ):
        self.bootstrap_room = bootstrap_room
        self.bootstrap_addr = bootstrap_addr
        self.kv_mgr = mgr

        # First receiver for a given bootstrap address queries the prefill
        # parallel layout from the server; later ones hit the manager's cache.
        if self.bootstrap_addr not in self.kv_mgr.prefill_dp_size_table:
            self.prefill_tp_size, self.prefill_dp_size = (
                self._get_prefill_dp_size_from_server()
            )
            if self.prefill_tp_size is None or self.prefill_dp_size is None:
                logger.error(
                    f"Could not fetch prefill parallel info for bootstrap_addr: {self.bootstrap_addr}"
                )
            else:
                self.kv_mgr.prefill_tp_size_table[self.bootstrap_addr] = (
                    self.prefill_tp_size
                )
                self.kv_mgr.prefill_dp_size_table[self.bootstrap_addr] = (
                    self.prefill_dp_size
                )
        else:
            self.prefill_tp_size = self.kv_mgr.prefill_tp_size_table[
                self.bootstrap_addr
            ]
            self.prefill_dp_size = self.kv_mgr.prefill_dp_size_table[
                self.bootstrap_addr
            ]

        # Currently, we don't allow prefill instance and decode instance to
        # have different TP sizes per DP rank, except for models using MLA.
        local_tp_size_per_dp_rank = self.kv_mgr.tp_size // self.kv_mgr.dp_size
        prefill_tp_size_per_dp_rank = self.prefill_tp_size // self.prefill_dp_size
        if local_tp_size_per_dp_rank == prefill_tp_size_per_dp_rank:
            # Equal TP per DP rank: 1:1 mapping between decode and prefill ranks.
            self.target_tp_rank = (
                self.kv_mgr.kv_args.engine_rank % local_tp_size_per_dp_rank
            )
            self.required_dst_info_num = 1
            self.target_tp_ranks = [self.target_tp_rank]
        elif local_tp_size_per_dp_rank > prefill_tp_size_per_dp_rank:
            # Decode TP is wider: several decode ranks map onto one prefill
            # rank, so that prefill rank must collect info from all of them.
            assert (
                self.kv_mgr.is_mla_backend
            ), "PD with different TP sizes per DP rank is not yet supported for non-MLA models"
            self.target_tp_rank = (
                self.kv_mgr.kv_args.engine_rank % local_tp_size_per_dp_rank
            ) // (local_tp_size_per_dp_rank // prefill_tp_size_per_dp_rank)
            self.required_dst_info_num = (
                local_tp_size_per_dp_rank // prefill_tp_size_per_dp_rank
            )
            self.target_tp_ranks = [self.target_tp_rank]
        else:
            # Prefill TP is wider: this decode rank corresponds to a span of
            # prefill ranks.
            assert (
                self.kv_mgr.is_mla_backend
            ), "PD with different TP sizes per DP rank is not yet supported for non-MLA models"
            # For non-MLA models, one decode rank needs to retrieve KVCache from multiple prefill ranks for non MLA models;
            self.target_tp_ranks = [
                rank
                for rank in range(
                    (self.kv_mgr.kv_args.engine_rank % local_tp_size_per_dp_rank)
                    * (prefill_tp_size_per_dp_rank // local_tp_size_per_dp_rank),
                    (self.kv_mgr.kv_args.engine_rank % local_tp_size_per_dp_rank + 1)
                    * (prefill_tp_size_per_dp_rank // local_tp_size_per_dp_rank),
                )
            ]

            # For MLA models, we can retrieve KVCache from only one prefill rank, but we still need to maintain
            # multiple connections in the connection pool and have to send dummy requests to other prefill ranks,
            # or the KVPoll will never be set correctly
            self.target_tp_rank = self.target_tp_ranks[0]
            self.required_dst_info_num = 1

        # NOTE(review): bootstrap_room defaults to None, but is used here with
        # `%` — presumably callers always pass a room id; confirm at call sites.
        self.target_dp_group = bootstrap_room % self.prefill_dp_size

        # NOTE: key distinguished by bootstrap_addr, target_dp_group, and target_tp_rank
        bootstrap_key = (
            f"{self.bootstrap_addr}_{self.target_dp_group}_{self.target_tp_rank}"
        )

        if bootstrap_key not in self.kv_mgr.connection_pool:
            bootstrap_infos = []
            for target_tp_rank in self.target_tp_ranks:
                bootstrap_info = self._get_bootstrap_info_from_server(
                    target_tp_rank,
                    self.target_dp_group,
                )
                if bootstrap_info is not None:
                    # NOTE: only support MLA for now: select one prefill rank as real rank
                    bootstrap_info["is_dummy"] = not bool(
                        target_tp_rank == self.target_tp_rank
                        or self.target_tp_rank is None
                    )
                    bootstrap_infos.append(bootstrap_info)
                else:
                    logger.error(
                        f"Could not fetch bootstrap info for engine rank: {self.kv_mgr.kv_args.engine_rank} and target_dp_group: {self.target_dp_group}"
                    )
            self.bootstrap_infos = bootstrap_infos

            if len(self.bootstrap_infos) == 0:
                logger.error(
                    f"Could not fetch bootstrap info for engine rank: {self.kv_mgr.kv_args.engine_rank}"
                )
            else:
                self.kv_mgr.connection_pool[bootstrap_key] = self.bootstrap_infos
                # Register kv_args only once to prefill KVManager according to the info fetched from the bootstrap server
                self._register_kv_args()
        else:
            self.bootstrap_infos = self.kv_mgr.connection_pool[bootstrap_key]

        assert len(self.bootstrap_infos) > 0

    def _get_bootstrap_info_from_server(self, engine_rank, target_dp_group):
        """Fetch the bootstrap info from the bootstrap server.

        Returns the parsed JSON dict on success, or None on any failure
        (non-200 status or a network error) after logging the problem.
        """
        try:
            url = f"http://{self.bootstrap_addr}/route?engine_rank={engine_rank}&target_dp_group={target_dp_group}"
            response = requests.get(url)
            if response.status_code == 200:
                bootstrap_info = response.json()
                return bootstrap_info
            else:
                logger.error(
                    f"Failed to get prefill server info: {response.status_code}, {response.text}"
                )
                return None
        except Exception as e:
            logger.error(f"Error fetching prefill info from bootstrap: {e}")
            return None

    def _get_prefill_dp_size_from_server(self) -> Optional[Tuple[int, int]]:
        """Fetch the prefill parallel info from the bootstrap server.

        Returns a ``(prefill_tp_size, prefill_dp_size)`` tuple, or None on
        failure. The sentinel query ``engine_rank=-1&target_dp_group=-1``
        asks the server for its parallel layout instead of a rank's info.
        """
        try:
            url = f"http://{self.bootstrap_addr}/route?engine_rank={-1}&target_dp_group={-1}"
            response = requests.get(url)
            if response.status_code == 200:
                prefill_parallel_info = response.json()
                return int(prefill_parallel_info["prefill_tp_size"]), int(
                    prefill_parallel_info["prefill_dp_size"]
                )
            else:
                logger.error(
                    f"Failed to get prefill parallel info: {response.status_code}, {response.text}"
                )
                return None
        except Exception as e:
            logger.error(f"Error fetching prefill parallel info from bootstrap: {e}")
            return None

    @classmethod
    def _connect(cls, endpoint: str):
        """Return a (socket, lock) pair for *endpoint*, creating and caching
        a ZMQ PUSH socket on first use. Thread-safe via the class-wide lock."""
        with cls._global_lock:
            if endpoint not in cls._socket_cache:
                sock = cls._ctx.socket(zmq.PUSH)
                sock.connect(endpoint)
                cls._socket_cache[endpoint] = sock
                cls._socket_locks[endpoint] = threading.Lock()
            return cls._socket_cache[endpoint], cls._socket_locks[endpoint]

    def _register_kv_args(self):
        # Hook for backend subclasses: called once per new connection-pool
        # entry to push local kv_args to the prefill side. No-op by default.
        pass

    def failure_exception(self):
        raise Exception("Fake KVReceiver Exception")
class CommonKVBootstrapServer(BaseKVBootstrapServer):
    """Bootstrap HTTP server shared by NIXL and Mooncake backends.

    Runs an aiohttp app on a daemon thread. Prefill ranks register themselves
    with PUT /route; decode ranks look up rank info (or the prefill parallel
    layout, using the -1/-1 sentinel query) with GET /route.
    """

    def __init__(self, port: int):
        self.port = port
        self.app = web.Application()
        self.store = dict()  # currently unused scratch store; kept for compat
        self.lock = asyncio.Lock()
        self._setup_routes()
        self.tp_size = None
        self.dp_size = None
        self.tp_size_per_dp_rank = None
        # dp_group -> (tp_rank_in_dp_group -> {"rank_ip", "rank_port"})
        self.prefill_port_table: Dict[int, Dict[int, Dict[str, Union[str, int]]]] = {}

        # Initialized by _run_server on the server thread; pre-set to None so
        # close() is safe even if the thread never got to start the loop.
        self._loop = None
        self._runner = None

        # Start bootstrap server
        self.thread = threading.Thread(target=self._run_server, daemon=True)
        self.run()

    def run(self):
        """Start the server thread."""
        self.thread.start()

    def _setup_routes(self):
        # One handler for every method; _handle_route dispatches on verb.
        self.app.router.add_route("*", "/route", self._handle_route)

    async def _handle_route(self, request: web.Request):
        """Dispatch /route requests to the PUT/GET handlers; 405 otherwise."""
        method = request.method
        if method == "PUT":
            return await self._handle_route_put(request)
        elif method == "GET":
            return await self._handle_route_get(request)
        else:
            return web.Response(
                text="Method not allowed", status=405, content_type="application/json"
            )

    async def _handle_route_put(self, request: web.Request):
        """Register a prefill rank's (ip, port) under its DP group / TP rank."""
        data = await request.json()
        role = data["role"]
        tp_size = data["tp_size"]
        dp_size = data["dp_size"]
        rank_ip = data["rank_ip"]
        rank_port = int(data["rank_port"])
        engine_rank = int(data["engine_rank"])

        # The first registration fixes the prefill parallel layout.
        if self.tp_size is None:
            self.tp_size = tp_size

        if self.dp_size is None:
            self.dp_size = dp_size

        tp_size_per_dp_rank = tp_size // dp_size
        if self.tp_size_per_dp_rank is None:
            self.tp_size_per_dp_rank = tp_size_per_dp_rank

        # Add lock to make sure thread-safe
        if role == "Prefill":
            dp_group = engine_rank // tp_size_per_dp_rank
            tp_rank_in_dp_group = engine_rank % tp_size_per_dp_rank

            async with self.lock:
                if dp_group not in self.prefill_port_table:
                    self.prefill_port_table[dp_group] = {}

                self.prefill_port_table[dp_group][tp_rank_in_dp_group] = {
                    "rank_ip": rank_ip,
                    "rank_port": rank_port,
                }
            logger.debug(
                f"Register Prefill bootstrap: {engine_rank} with rank_ip: {rank_ip} and rank_port: {rank_port}"
            )

        return web.Response(text="OK", status=200)

    async def _handle_route_get(self, request: web.Request):
        """Serve either the prefill parallel layout (sentinel -1/-1 query) or
        a single rank's registered (ip, port) info."""
        engine_rank = request.query.get("engine_rank")
        target_dp_group = request.query.get("target_dp_group")
        if not engine_rank or not target_dp_group:
            return web.Response(
                text="Missing inputs for bootstrap server.", status=400
            )

        # Currently we use engine_rank == -1 and target_dp_group == -1 to sync dp size
        if int(engine_rank) == -1 and int(target_dp_group) == -1:
            prefill_parallel_info = {
                "prefill_tp_size": self.tp_size,
                "prefill_dp_size": self.dp_size,
            }
            return web.json_response(prefill_parallel_info, status=200)

        # Find corresponding prefill info
        async with self.lock:
            bootstrap_info = self.prefill_port_table[int(target_dp_group)][
                int(engine_rank)
            ]

        if bootstrap_info is not None:
            return web.json_response(bootstrap_info, status=200)
        else:
            return web.Response(text="Bootstrap info not Found", status=404)

    def _run_server(self):
        """Thread target: run the aiohttp app on a private event loop."""
        try:
            # Event Loop
            self._loop = asyncio.new_event_loop()
            asyncio.set_event_loop(self._loop)

            self._runner = web.AppRunner(self.app)
            self._loop.run_until_complete(self._runner.setup())

            site = web.TCPSite(self._runner, port=self.port)
            self._loop.run_until_complete(site.start())
            self._loop.run_forever()
        except Exception as e:
            logger.error(f"Server error: {str(e)}")
        finally:
            # Cleanup — guard against failures before loop/runner were created.
            if self._loop is not None:
                if self._runner is not None:
                    self._loop.run_until_complete(self._runner.cleanup())
                self._loop.close()

    def close(self):
        """Shutdown: stop the event loop and join the server thread."""
        if self._loop is not None and self._loop.is_running():
            self._loop.call_soon_threadsafe(self._loop.stop)
            logger.info("Stopping server loop...")

        if self.thread.is_alive():
            self.thread.join(timeout=2)
            logger.info("Server thread stopped")

    def poll(self) -> KVPoll: ...
python/sglang/srt/disaggregation/mooncake/conn.py
View file @
e806f708
...
...
@@ -29,7 +29,10 @@ from sglang.srt.disaggregation.base.conn import (
KVPoll
,
)
from
sglang.srt.disaggregation.mooncake.transfer_engine
import
MooncakeTransferEngine
from
sglang.srt.disaggregation.utils
import
DisaggregationMode
from
sglang.srt.disaggregation.utils
import
(
DisaggregationMode
,
group_concurrent_contiguous
,
)
from
sglang.srt.server_args
import
ServerArgs
from
sglang.srt.utils
import
(
get_free_port
,
...
...
@@ -41,23 +44,6 @@ from sglang.srt.utils import (
logger
=
logging
.
getLogger
(
__name__
)
def group_concurrent_contiguous(
    src_indices: npt.NDArray[np.int64], dst_indices: npt.NDArray[np.int64]
) -> Tuple[List[npt.NDArray[np.int64]], List[npt.NDArray[np.int64]]]:
    """Split the two index arrays into maximal runs that are contiguous
    (step of exactly 1) in BOTH src and dst simultaneously.

    Returns a pair of parallel lists of runs (as Python lists).
    """
    if src_indices.size == 0:
        return [], []

    # A run ends wherever either sequence jumps by something other than +1.
    src_breaks = np.diff(src_indices) != 1
    dst_breaks = np.diff(dst_indices) != 1
    cut_points = np.flatnonzero(src_breaks | dst_breaks) + 1

    grouped_src = [chunk.tolist() for chunk in np.split(src_indices, cut_points)]
    grouped_dst = [chunk.tolist() for chunk in np.split(dst_indices, cut_points)]

    return grouped_src, grouped_dst
class
KVTransferError
(
Exception
):
def
__init__
(
self
,
bootstrap_room
:
int
,
failure_reason
:
str
):
super
().
__init__
(
failure_reason
)
...
...
python/sglang/srt/disaggregation/nixl/conn.py
View file @
e806f708
...
...
@@ -18,40 +18,23 @@ import requests
import
zmq
from
aiohttp
import
web
from
sglang.srt.disaggregation.base.conn
import
(
BaseKVBootstrapServer
,
BaseKVManager
,
BaseKVReceiver
,
BaseKVSender
,
KVArgs
,
KVPoll
,
from
sglang.srt.disaggregation.base.conn
import
BaseKVSender
,
KVArgs
,
KVPoll
from
sglang.srt.disaggregation.common.conn
import
(
CommonKVBootstrapServer
,
CommonKVManager
,
CommonKVReceiver
,
)
from
sglang.srt.disaggregation.utils
import
(
DisaggregationMode
,
group_concurrent_contiguous
,
)
from
sglang.srt.disaggregation.utils
import
DisaggregationMode
from
sglang.srt.server_args
import
ServerArgs
from
sglang.srt.utils
import
get_free_port
,
get_ip
,
get_local_ip_by_remote
from
sglang.srt.utils
import
get_local_ip_by_remote
logger
=
logging
.
getLogger
(
__name__
)
NixlEngineInfo
:
TypeAlias
=
Dict
[
str
,
Union
[
str
,
int
]]
def group_concurrent_contiguous(
    src_indices: npt.NDArray[np.int64], dst_indices: npt.NDArray[np.int64]
) -> Tuple[List[npt.NDArray[np.int64]], List[npt.NDArray[np.int64]]]:
    """Partition src/dst index arrays into runs where both advance by +1.

    A new group starts at every position where either array's step differs
    from 1; the two returned lists of groups are always the same length.
    """
    if src_indices.size == 0:
        return [], []

    step_breaks = (np.diff(src_indices) != 1) | (np.diff(dst_indices) != 1)
    boundaries = np.flatnonzero(step_breaks) + 1

    src_runs = np.split(src_indices, boundaries)
    dst_runs = np.split(dst_indices, boundaries)

    return [run.tolist() for run in src_runs], [run.tolist() for run in dst_runs]
GUARD
=
"NixlMsgGuard"
.
encode
(
"ascii"
)
...
...
@@ -61,11 +44,13 @@ class TransferInfo:
endpoint
:
str
dst_port
:
int
agent_metadata
:
bytes
agent_name
:
str
dst_kv_ptrs
:
list
[
int
]
dst_kv_indices
:
npt
.
NDArray
[
np
.
int64
]
dst_aux_ptrs
:
list
[
int
]
dst_aux_index
:
int
dst_gpu_id
:
int
required_dst_info_num
:
int
def
is_dummy
(
self
):
return
self
.
endpoint
==
""
...
...
@@ -79,11 +64,13 @@ class TransferInfo:
endpoint
=
""
,
dst_port
=
0
,
agent_metadata
=
b
""
,
agent_name
=
""
,
dst_kv_ptrs
=
[],
dst_kv_indices
=
np
.
array
([],
dtype
=
np
.
int64
),
dst_aux_ptrs
=
[],
dst_aux_index
=
0
,
dst_gpu_id
=
0
,
required_dst_info_num
=
0
,
)
else
:
return
cls
(
...
...
@@ -91,11 +78,13 @@ class TransferInfo:
endpoint
=
msg
[
1
].
decode
(
"ascii"
),
dst_port
=
int
(
msg
[
2
].
decode
(
"ascii"
)),
agent_metadata
=
msg
[
3
],
dst_kv_ptrs
=
list
(
struct
.
unpack
(
f
"
{
len
(
msg
[
4
])
//
8
}
Q"
,
msg
[
4
])),
dst_kv_indices
=
np
.
frombuffer
(
msg
[
5
],
dtype
=
np
.
int64
),
dst_aux_ptrs
=
list
(
struct
.
unpack
(
f
"
{
len
(
msg
[
6
])
//
8
}
Q"
,
msg
[
6
])),
dst_aux_index
=
int
(
msg
[
7
].
decode
(
"ascii"
)),
dst_gpu_id
=
int
(
msg
[
8
].
decode
(
"ascii"
)),
agent_name
=
msg
[
4
].
decode
(
"ascii"
),
dst_kv_ptrs
=
list
(
struct
.
unpack
(
f
"
{
len
(
msg
[
5
])
//
8
}
Q"
,
msg
[
5
])),
dst_kv_indices
=
np
.
frombuffer
(
msg
[
6
],
dtype
=
np
.
int64
),
dst_aux_ptrs
=
list
(
struct
.
unpack
(
f
"
{
len
(
msg
[
7
])
//
8
}
Q"
,
msg
[
7
])),
dst_aux_index
=
int
(
msg
[
8
].
decode
(
"ascii"
)),
dst_gpu_id
=
int
(
msg
[
9
].
decode
(
"ascii"
)),
required_dst_info_num
=
int
(
msg
[
10
].
decode
(
"ascii"
)),
)
...
...
@@ -116,7 +105,7 @@ class TransferStatus:
return
self
.
num_kvs_expected
==
len
(
self
.
received_kvs
)
and
self
.
received_aux
class
NixlKVManager
(
Base
KVManager
):
class
NixlKVManager
(
Common
KVManager
):
def
__init__
(
self
,
args
:
KVArgs
,
...
...
@@ -124,6 +113,7 @@ class NixlKVManager(BaseKVManager):
server_args
:
ServerArgs
,
is_mla_backend
:
Optional
[
bool
]
=
False
,
):
super
().
__init__
(
args
,
disaggregation_mode
,
server_args
,
is_mla_backend
)
try
:
from
nixl._api
import
nixl_agent
except
ImportError
as
e
:
...
...
@@ -133,38 +123,15 @@ class NixlKVManager(BaseKVManager):
"to run SGLang with NixlTransferEngine."
)
from
e
self
.
agent
=
nixl_agent
(
str
(
uuid
.
uuid4
()))
self
.
kv_args
=
args
self
.
disaggregation_mode
=
disaggregation_mode
# for p/d multi node infer
self
.
bootstrap_port
=
server_args
.
disaggregation_bootstrap_port
self
.
dist_init_addr
=
server_args
.
dist_init_addr
self
.
tp_size
=
server_args
.
tp_size
self
.
tp_rank
=
args
.
engine_rank
self
.
enable_dp_attention
=
server_args
.
enable_dp_attention
if
self
.
enable_dp_attention
:
assert
(
server_args
.
dp_size
>
1
),
"If dp_attention is enabled, dp size must be greater than 1 in disaggregation mode."
self
.
dp_size
=
server_args
.
dp_size
self
.
tp_size_of_dp
=
server_args
.
tp_size
//
server_args
.
dp_size
self
.
attn_tp_rank
=
args
.
engine_rank
%
self
.
tp_size_of_dp
self
.
dp_rank
=
args
.
engine_rank
//
self
.
tp_size_of_dp
self
.
rank_port
=
None
self
.
server_socket
=
zmq
.
Context
().
socket
(
zmq
.
PULL
)
self
.
register_buffer_to_engine
()
self
.
rank_port
=
get_free_port
()
if
self
.
disaggregation_mode
==
DisaggregationMode
.
PREFILL
:
self
.
request_status
=
{}
self
.
transfer_infos
:
Dict
[
int
,
TransferInfo
]
=
{}
self
.
condition
=
threading
.
Condition
()
self
.
peer_names
:
Dict
[
int
,
str
]
=
{}
self
.
peer_names
:
Dict
[
str
,
str
]
=
{}
self
.
_start_bootstrap_thread
()
self
.
_register_to_bootstrap
()
elif
self
.
disaggregation_mode
==
DisaggregationMode
.
DECODE
:
# bootstrap key -> (remote_engine_rank -> possible remote source info)
self
.
prefill_peer_infos
:
Dict
[
str
,
list
[
Dict
[
int
,
NixlEngineInfo
]]]
=
{}
self
.
transfer_statuses
:
Dict
[
int
,
TransferStatus
]
=
defaultdict
(
TransferStatus
)
...
...
@@ -173,6 +140,18 @@ class NixlKVManager(BaseKVManager):
f
"Unsupported DisaggregationMode:
{
self
.
disaggregation_mode
}
"
)
def
check_status
(
self
,
bootstrap_room
:
int
):
return
self
.
request_status
[
bootstrap_room
]
def
update_status
(
self
,
bootstrap_room
:
int
,
status
:
KVPoll
):
if
bootstrap_room
not
in
self
.
request_status
:
self
.
request_status
[
bootstrap_room
]
=
status
else
:
# NOTE: The prefill engine could recv bootstrapping first
self
.
request_status
[
bootstrap_room
]
=
max
(
self
.
request_status
[
bootstrap_room
],
status
)
def
register_buffer_to_engine
(
self
):
kv_addrs
=
[]
for
kv_data_ptr
,
kv_data_len
in
zip
(
...
...
@@ -193,16 +172,10 @@ class NixlKVManager(BaseKVManager):
if
not
self
.
aux_descs
:
raise
Exception
(
"NIXL memory registration failed for aux tensors"
)
@
cache
def
_connect
(
self
,
endpoint
:
str
):
socket
=
zmq
.
Context
().
socket
(
zmq
.
PUSH
)
socket
.
connect
(
endpoint
)
return
socket
def
_add_remote
(
self
,
room
:
int
,
agent_metadata
:
bytes
):
if
room
not
in
self
.
peer_names
:
self
.
peer_names
[
room
]
=
self
.
agent
.
add_remote_agent
(
agent_metadata
)
return
self
.
peer_names
[
room
]
def
_add_remote
(
self
,
agent_name
:
str
,
agent_metadata
:
bytes
):
if
agent_name
not
in
self
.
peer_names
:
self
.
peer_names
[
agent_name
]
=
self
.
agent
.
add_remote_agent
(
agent_metadata
)
return
self
.
peer_names
[
agent_name
]
def
send_kvcache
(
self
,
...
...
@@ -300,40 +273,38 @@ class NixlKVManager(BaseKVManager):
assert
self
.
disaggregation_mode
==
DisaggregationMode
.
PREFILL
assert
not
is_last
or
(
is_last
and
aux_index
is
not
None
)
# Wait for transfer info to be populated by bootstrap thread.
with
self
.
condition
:
self
.
condition
.
wait_for
(
lambda
:
bootstrap_room
in
self
.
transfer_infos
)
req
=
self
.
transfer_infos
[
bootstrap_room
]
assert
bootstrap_room
==
req
.
room
if
req
.
is_dummy
():
return
[]
reqs_to_be_processed
=
self
.
transfer_infos
[
bootstrap_room
].
values
()
handles
=
[]
for
req
in
reqs_to_be_processed
:
assert
bootstrap_room
==
req
.
room
if
req
.
is_dummy
():
return
[]
peer_name
=
self
.
_add_remote
(
bootstrap_room
,
req
.
agent_metadata
)
chunked_dst_kv_indice
=
req
.
dst_kv_indices
[
index_slice
]
assert
len
(
chunked_dst_kv_indice
)
==
len
(
kv_indices
)
peer_name
=
self
.
_add_remote
(
req
.
agent_name
,
req
.
agent_metadata
)
chunked_dst_kv_indice
=
req
.
dst_kv_indices
[
index_slice
]
assert
len
(
chunked_dst_kv_indice
)
==
len
(
kv_indices
)
notif
=
"_"
.
join
([
str
(
req
.
room
),
"kv"
,
str
(
chunk_id
),
str
(
int
(
is_last
))])
kv_xfer_handle
=
self
.
send_kvcache
(
peer_name
,
kv_indices
,
req
.
dst_kv_ptrs
,
chunked_dst_kv_indice
,
req
.
dst_gpu_id
,
notif
,
)
handles
=
[
kv_xfer_handle
]
# Only the last chunk we need to send the aux data.
if
is_last
:
assert
aux_index
is
not
None
aux_xfer_handle
=
self
.
send_aux
(
notif
=
"_"
.
join
([
str
(
req
.
room
),
"kv"
,
str
(
chunk_id
),
str
(
int
(
is_last
))])
kv_xfer_handle
=
self
.
send_kvcache
(
peer_name
,
aux_index
,
req
.
dst_aux_ptrs
,
req
.
dst_aux_index
,
str
(
req
.
room
)
+
"_aux"
,
kv_indices
,
req
.
dst_kv_ptrs
,
chunked_dst_kv_indice
,
req
.
dst_gpu_id
,
notif
,
)
handles
.
append
(
aux_xfer_handle
)
handles
.
append
(
kv_xfer_handle
)
# Only the last chunk we need to send the aux data.
if
is_last
:
assert
aux_index
is
not
None
aux_xfer_handle
=
self
.
send_aux
(
peer_name
,
aux_index
,
req
.
dst_aux_ptrs
,
req
.
dst_aux_index
,
str
(
req
.
room
)
+
"_aux"
,
)
handles
.
append
(
aux_xfer_handle
)
return
handles
def
update_transfer_status
(
self
):
...
...
@@ -348,7 +319,7 @@ class NixlKVManager(BaseKVManager):
room
=
int
(
components
[
0
])
if
components
[
1
]
==
"kv"
:
chunk_id
=
int
(
components
[
2
])
is_last
=
bool
(
components
[
3
])
is_last
=
bool
(
int
(
components
[
3
])
)
self
.
transfer_statuses
[
room
].
received_kvs
.
add
(
chunk_id
)
if
is_last
:
self
.
transfer_statuses
[
room
].
num_kvs_expected
=
chunk_id
+
1
...
...
@@ -360,34 +331,6 @@ class NixlKVManager(BaseKVManager):
return
False
return
self
.
transfer_statuses
[
room
].
is_done
()
def
_register_to_bootstrap
(
self
):
"""Register KVSender to bootstrap server via HTTP POST."""
if
self
.
dist_init_addr
:
ip_address
=
socket
.
gethostbyname
(
self
.
dist_init_addr
.
split
(
":"
)[
0
])
else
:
ip_address
=
get_ip
()
bootstrap_server_url
=
f
"
{
ip_address
}
:
{
self
.
bootstrap_port
}
"
url
=
f
"http://
{
bootstrap_server_url
}
/route"
payload
=
{
"role"
:
"Prefill"
,
"rank_ip"
:
get_local_ip_by_remote
(),
"rank_port"
:
self
.
rank_port
,
"engine_rank"
:
self
.
kv_args
.
engine_rank
,
"agent_name"
:
self
.
agent
.
name
,
}
try
:
response
=
requests
.
put
(
url
,
json
=
payload
)
if
response
.
status_code
==
200
:
logger
.
debug
(
"Prefill successfully registered to bootstrap server."
)
else
:
logger
.
error
(
f
"Prefill Failed to connect to bootstrap server:
{
response
.
status_code
}
,
{
response
.
text
}
"
)
except
Exception
as
e
:
logger
.
error
(
f
"Prefill Failed to register to bootstrap server:
{
e
}
"
)
def
_start_bootstrap_thread
(
self
):
self
.
server_socket
.
bind
(
f
"tcp://
{
get_local_ip_by_remote
()
}
:
{
self
.
rank_port
}
"
)
...
...
@@ -405,10 +348,19 @@ class NixlKVManager(BaseKVManager):
room
=
waiting_req_bytes
[
0
].
decode
(
"ascii"
)
if
room
==
"None"
:
continue
required_dst_info_num
=
int
(
waiting_req_bytes
[
10
].
decode
(
"ascii"
))
room
=
int
(
room
)
with
self
.
condition
:
self
.
transfer_infos
[
room
]
=
TransferInfo
.
from_zmq
(
waiting_req_bytes
)
self
.
condition
.
notify_all
()
agent_name
=
waiting_req_bytes
[
4
].
decode
(
"ascii"
)
if
room
not
in
self
.
transfer_infos
:
self
.
transfer_infos
[
room
]
=
{}
self
.
transfer_infos
[
room
][
agent_name
]
=
TransferInfo
.
from_zmq
(
waiting_req_bytes
)
logger
.
debug
(
f
"got info
{
room
=
}
{
agent_name
=
}
{
required_dst_info_num
=
}
"
)
if
len
(
self
.
transfer_infos
[
room
])
==
required_dst_info_num
:
logger
.
debug
(
f
"
{
room
=
}
is bootstrapped"
)
self
.
update_status
(
room
,
KVPoll
.
WaitingForInput
)
threading
.
Thread
(
target
=
bootstrap_thread
).
start
()
...
...
@@ -423,6 +375,9 @@ class NixlKVSender(BaseKVSender):
self
.
xfer_handles
=
[]
self
.
has_sent
=
False
self
.
chunk_id
=
0
self
.
kv_mgr
.
update_status
(
self
.
bootstrap_room
,
KVPoll
.
Bootstrapping
)
# inner state
self
.
curr_idx
=
0
def
init
(
self
,
num_kv_indices
:
int
,
aux_index
:
Optional
[
int
]
=
None
):
self
.
num_kv_indices
=
num_kv_indices
...
...
@@ -431,9 +386,11 @@ class NixlKVSender(BaseKVSender):
def
send
(
self
,
kv_indices
:
npt
.
NDArray
[
np
.
int64
],
index_slice
:
slice
,
is_last
:
bool
,
):
index_slice
=
slice
(
self
.
curr_idx
,
self
.
curr_idx
+
len
(
kv_indices
))
self
.
curr_idx
+=
len
(
kv_indices
)
is_last
=
self
.
curr_idx
==
self
.
num_kv_indices
new_xfer_handles
=
self
.
kv_mgr
.
add_transfer_request
(
self
.
bootstrap_room
,
kv_indices
,
...
...
@@ -449,7 +406,7 @@ class NixlKVSender(BaseKVSender):
def
poll
(
self
)
->
KVPoll
:
if
not
self
.
has_sent
:
return
KVPoll
.
WaitingForInput
# type: ignore
return
self
.
kv_mgr
.
check_status
(
self
.
bootstrap_room
)
states
=
[
self
.
kv_mgr
.
agent
.
check_xfer_state
(
x
)
for
x
in
self
.
xfer_handles
]
if
all
([
x
==
"DONE"
for
x
in
states
]):
return
KVPoll
.
Success
# type: ignore
...
...
@@ -461,128 +418,40 @@ class NixlKVSender(BaseKVSender):
raise
Exception
(
"Fake KVSender Exception"
)
class
NixlKVReceiver
(
BaseKVReceiver
):
class
NixlKVReceiver
(
CommonKVReceiver
):
def
__init__
(
self
,
mgr
:
NixlKVManager
,
bootstrap_addr
:
str
,
bootstrap_room
:
Optional
[
int
]
=
None
,
):
self
.
bootstrap_room
=
bootstrap_room
self
.
bootstrap_addr
=
bootstrap_addr
self
.
kv_mgr
=
mgr
self
.
started_transfer
=
False
# NOTE: key distinguished by bootstrap_addr and engine_rank
bootstrap_key
=
f
"
{
self
.
bootstrap_addr
}
_
{
self
.
kv_mgr
.
kv_args
.
engine_rank
}
"
if
bootstrap_key
not
in
self
.
kv_mgr
.
prefill_peer_infos
:
self
.
bootstrap_info
=
self
.
_get_bootstrap_info_from_server
(
self
.
kv_mgr
.
kv_args
.
engine_rank
)
if
self
.
bootstrap_info
is
None
:
logger
.
error
(
f
"Could not fetch bootstrap info for engine rank:
{
self
.
kv_mgr
.
kv_args
.
engine_rank
}
"
)
else
:
self
.
kv_mgr
.
prefill_peer_infos
[
bootstrap_key
]
=
self
.
bootstrap_info
else
:
self
.
bootstrap_info
=
self
.
kv_mgr
.
prefill_peer_infos
[
bootstrap_key
]
assert
self
.
bootstrap_info
is
not
None
# return a list of remotes in a dict, [(remote_engine_rank -> NixlEngineInfo), ...]
# In each dict, there are multiple possible remotes named "equal sources".
# We only need to select one to split the traffic. i.e. we totally select len(list) remotes.
def _get_bootstrap_info_from_server(
    self, engine_rank
) -> Optional[List[Dict[int, NixlEngineInfo]]]:
    """Fetch the bootstrap info from the bootstrap server.

    Queries ``http://{bootstrap_addr}/route``. With DP attention enabled it
    fetches all prefill ranks and keeps only the subset this decode rank is
    responsible for; otherwise it fetches the single entry for
    ``engine_rank``. Returns ``None`` on any HTTP or network failure.
    """
    try:
        if self.kv_mgr.enable_dp_attention:
            url = f"http://{self.bootstrap_addr}/route"
            response = requests.get(url)
            if response.status_code != 200:
                logger.error(
                    f"Failed to get prefill server info: {response.status_code}, {response.text}"
                )
                return None

            bootstrap_info = response.json()
            assert isinstance(bootstrap_info, dict)
            # JSON object keys arrive as strings; restore the int engine ranks.
            bootstrap_info = {int(k): v for k, v in bootstrap_info.items()}

            # split out who need to send to this rank.
            # currently for dpsk mla model, those ranks share the same latent cache.
            # pick one as the real source
            prefill_tp_size = len(bootstrap_info.keys())
            assert (
                prefill_tp_size >= self.kv_mgr.tp_size_of_dp
            ), f"Only support Prefill TP size >= Decode TP size of DP, now we have {prefill_tp_size} vs {self.kv_mgr.tp_size_of_dp}"
            # How many remote prefill ranks each local attention-TP rank handles.
            num_remote_tp_rank_we_managed = (
                prefill_tp_size // self.kv_mgr.tp_size_of_dp
            )
            # We handle [num * self.attn_tp_rank, num * self.attn_tp_rank + num)
            remote_tp_ranks = list(range(0, prefill_tp_size))
            # split it into tp_size_of_dp parts and get our part
            # NOTE(review): the range steps by tp_size_of_dp while each slice
            # takes num_remote_tp_rank_we_managed elements; these only tile all
            # prefill ranks when the two values are equal — confirm intended
            # grouping for other TP ratios.
            remote_tp_ranks_grouped = [
                remote_tp_ranks[i : i + num_remote_tp_rank_we_managed]
                for i in range(0, prefill_tp_size, self.kv_mgr.tp_size_of_dp)
            ]
            managed_ranks = remote_tp_ranks_grouped[self.kv_mgr.attn_tp_rank]
            assert len(managed_ranks) == num_remote_tp_rank_we_managed
            logger.debug(
                f"Rank {self.kv_mgr.kv_args.engine_rank} source can be {managed_ranks}"
            )
            # One dict of "equal sources": any of these remote ranks can serve us.
            return [
                {
                    rk: bootstrap_info[rk]
                    for rk in bootstrap_info.keys()
                    if rk in managed_ranks
                }
            ]
        else:
            url = f"http://{self.bootstrap_addr}/route?engine_rank={engine_rank}"
            response = requests.get(url)
            if response.status_code == 200:
                bootstrap_info = response.json()
                return [{engine_rank: bootstrap_info}]
            else:
                logger.error(
                    f"Failed to get prefill server info: {response.status_code}, {response.text}"
                )
                return None
    except Exception as e:
        logger.error(f"Error fetching prefill info from bootstrap: {e}")
        return None
@cache
def _connect(self, endpoint: str):
    """Return a ZMQ PUSH socket connected to *endpoint*.

    Memoized via ``@cache``, so repeated calls with the same endpoint reuse
    one socket instead of reconnecting.
    """
    push_sock = zmq.Context().socket(zmq.PUSH)
    push_sock.connect(endpoint)
    return push_sock
super
().
__init__
(
mgr
,
bootstrap_addr
,
bootstrap_room
)
def
init
(
self
,
kv_indices
:
npt
.
NDArray
[
np
.
int64
],
aux_index
:
Optional
[
int
]
=
None
):
assert
self
.
bootstrap_info
is
not
None
assert
self
.
bootstrap_room
is
not
None
for
equal_sources
in
self
.
bootstrap_info
:
remote_rank
=
list
(
equal_sources
.
keys
())[
self
.
bootstrap_room
%
len
(
equal_sources
)
]
self
.
prefill_server_url
=
f
"
{
equal_sources
[
remote_rank
][
'rank_ip'
]
}
:
{
equal_sources
[
remote_rank
][
'rank_port'
]
}
"
for
bootstrap_info
in
self
.
bootstrap_infos
:
self
.
prefill_server_url
=
(
f
"
{
bootstrap_info
[
'rank_ip'
]
}
:
{
bootstrap_info
[
'rank_port'
]
}
"
)
logger
.
debug
(
f
"Fetched bootstrap info for engine rank:
{
self
.
kv_mgr
.
kv_args
.
engine_rank
}
, source:
{
remote_rank
}
, all:
{
list
(
equal_sources
.
keys
())
}
"
f
"Fetched bootstrap info
:
{
bootstrap_info
}
for engine rank:
{
self
.
kv_mgr
.
kv_args
.
engine_rank
}
"
)
is_dummy
=
bootstrap_info
[
"is_dummy"
]
# TODO: just send "" for indices for dummy
if
is_dummy
:
# TODO: need to set success??
sock
,
lock
=
self
.
_connect
(
"tcp://"
+
self
.
prefill_server_url
)
with
lock
:
sock
.
send_multipart
(
[
GUARD
,
str
(
self
.
bootstrap_room
).
encode
(
"ascii"
),
]
)
continue
# TODO: send_kv_args earlier
packed_kv_data_ptrs
=
b
""
.
join
(
struct
.
pack
(
"Q"
,
ptr
)
for
ptr
in
self
.
kv_mgr
.
kv_args
.
kv_data_ptrs
)
...
...
@@ -593,30 +462,22 @@ class NixlKVReceiver(BaseKVReceiver):
logger
.
debug
(
f
"Sending to
{
self
.
prefill_server_url
}
with bootstrap room
{
self
.
bootstrap_room
}
"
)
self
.
_connect
(
"tcp://"
+
self
.
prefill_server_url
).
send_multipart
(
[
GUARD
,
str
(
self
.
bootstrap_room
).
encode
(
"ascii"
),
get_local_ip_by_remote
().
encode
(
"ascii"
),
str
(
self
.
kv_mgr
.
rank_port
).
encode
(
"ascii"
),
self
.
kv_mgr
.
agent
.
get_agent_metadata
(),
packed_kv_data_ptrs
,
kv_indices
.
tobytes
(),
packed_aux_data_ptrs
,
str
(
aux_index
).
encode
(
"ascii"
),
str
(
self
.
kv_mgr
.
kv_args
.
gpu_id
).
encode
(
"ascii"
),
]
)
for
dummy_rank
in
equal_sources
.
keys
():
if
dummy_rank
==
remote_rank
:
continue
dummy_info
=
equal_sources
[
dummy_rank
]
dummy_url
=
f
"
{
dummy_info
[
'rank_ip'
]
}
:
{
dummy_info
[
'rank_port'
]
}
"
self
.
_connect
(
"tcp://"
+
dummy_url
).
send_multipart
(
sock
,
lock
=
self
.
_connect
(
"tcp://"
+
self
.
prefill_server_url
)
with
lock
:
sock
.
send_multipart
(
[
GUARD
,
str
(
self
.
bootstrap_room
).
encode
(
"ascii"
),
get_local_ip_by_remote
().
encode
(
"ascii"
),
str
(
self
.
kv_mgr
.
rank_port
).
encode
(
"ascii"
),
self
.
kv_mgr
.
agent
.
get_agent_metadata
(),
self
.
kv_mgr
.
agent
.
name
.
encode
(
"ascii"
),
packed_kv_data_ptrs
,
kv_indices
.
tobytes
(),
packed_aux_data_ptrs
,
str
(
aux_index
).
encode
(
"ascii"
),
str
(
self
.
kv_mgr
.
kv_args
.
gpu_id
).
encode
(
"ascii"
),
str
(
self
.
required_dst_info_num
).
encode
(
"ascii"
),
]
)
...
...
@@ -632,152 +493,12 @@ class NixlKVReceiver(BaseKVReceiver):
return
KVPoll
.
Success
# type: ignore
return
KVPoll
.
WaitingForInput
# type: ignore
def
_register_kv_args
(
self
):
pass
def
failure_exception
(
self
):
raise
Exception
(
"Fake KVReceiver Exception"
)
class NixlKVBootstrapServer(BaseKVBootstrapServer):
    """HTTP rendezvous server for prefill/decode bootstrap.

    Serves an aiohttp application on a daemon thread with two endpoints:

    - ``/metadata``: a small in-memory key/value store of opaque bytes
      (GET / PUT / DELETE, keyed by the ``key`` query parameter).
    - ``/route``: prefill ranks PUT their ``rank_ip``/``rank_port``/
      ``agent_name``; the decode side GETs one rank's entry (via
      ``engine_rank``) or the whole table when no rank is given.
    """

    def __init__(self, port: int):
        # port: TCP port the aiohttp site listens on.
        logger.debug(f"NixlKVBootstrapServer started on port {port}")
        self.port = port
        self.app = web.Application()
        # Opaque metadata blobs, key -> bytes (written by _handle_metadata_put).
        self.store = dict()
        # Guards both `store` and `prefill_port_table` across request handlers.
        self.lock = asyncio.Lock()
        self._setup_routes()
        # engine_rank -> {"rank_ip", "rank_port", "agent_name"} registered by prefill.
        self.prefill_port_table: Dict[int, Dict[str, Union[str, int]]] = {}

        # Start bootstrap server
        self.thread = threading.Thread(target=self._run_server, daemon=True)
        self.run()

    def run(self):
        """Start the server thread (the event loop runs in _run_server)."""
        self.thread.start()

    def _setup_routes(self):
        # "*" routes: method dispatch is done manually inside each handler.
        self.app.router.add_route("*", "/metadata", self._handle_metadata)
        self.app.router.add_route("*", "/route", self._handle_route)

    async def _handle_metadata(self, request: web.Request):
        """Dispatch /metadata by HTTP method; 405 for anything unsupported."""
        key = request.query.get("key", "")

        if request.method == "GET":
            return await self._handle_metadata_get(key)
        elif request.method == "PUT":
            return await self._handle_metadata_put(key, request)
        elif request.method == "DELETE":
            return await self._handle_metadata_delete(key)
        return web.Response(
            text="Method not allowed", status=405, content_type="application/json"
        )

    async def _handle_metadata_get(self, key):
        """Return the stored bytes for `key`, or 404 if absent."""
        async with self.lock:
            value = self.store.get(key)
        if value is None:
            return web.Response(
                text="metadata not found", status=404, content_type="application/json"
            )
        return web.Response(body=value, status=200, content_type="application/json")

    async def _handle_metadata_put(self, key, request):
        """Store the raw request body under `key` (overwrites any prior value)."""
        data = await request.read()
        async with self.lock:
            self.store[key] = data
        return web.Response(
            text="metadata updated", status=200, content_type="application/json"
        )

    async def _handle_metadata_delete(self, key):
        """Delete `key` from the store; 404 if it was never stored."""
        async with self.lock:
            if key not in self.store:
                return web.Response(
                    text="metadata not found",
                    status=404,
                    content_type="application/json",
                )
            del self.store[key]
        return web.Response(
            text="metadata deleted", status=200, content_type="application/json"
        )

    async def _handle_route(self, request: web.Request):
        """Dispatch /route: PUT registers a prefill rank, GET queries the table."""
        method = request.method
        if method == "PUT":
            return await self._handle_route_put(request)
        elif method == "GET":
            return await self._handle_route_get(request)
        else:
            return web.Response(
                text="Method not allowed", status=405, content_type="application/json"
            )

    async def _handle_route_put(self, request: web.Request):
        """Register a rank's connection info; only role == "Prefill" is recorded."""
        data = await request.json()
        role = data["role"]
        rank_ip = data["rank_ip"]
        rank_port = int(data["rank_port"])
        engine_rank = int(data["engine_rank"])
        agent_name = data["agent_name"]

        if role == "Prefill":
            async with self.lock:
                self.prefill_port_table[engine_rank] = {
                    "rank_ip": rank_ip,
                    "rank_port": rank_port,
                    "agent_name": agent_name,
                }
            logger.info(
                f"Registered Prefill boostrap: {engine_rank} with rank_ip: {rank_ip} and rank_port: {rank_port} and name: {agent_name}"
            )

        # Non-Prefill roles are accepted but ignored (still 200).
        return web.Response(text="OK", status=200)

    async def _handle_route_get(self, request: web.Request):
        """Return one rank's info, or the whole table when engine_rank is omitted."""
        engine_rank = request.query.get("engine_rank")
        if not engine_rank:
            logger.debug(
                f"No engine_rank specified, return all {len(self.prefill_port_table)} engine infos as a dict"
            )
            # Return a dict of all engine_rank
            async with self.lock:
                bootstrap_info = self.prefill_port_table
            return web.json_response(bootstrap_info, status=200)

        # Find corresponding prefill info
        async with self.lock:
            bootstrap_info = self.prefill_port_table.get(int(engine_rank))
        if bootstrap_info is not None:
            return web.json_response(bootstrap_info, status=200)
        else:
            return web.Response(text="Not Found", status=404)

    def _run_server(self):
        """Thread target: own event loop running the aiohttp site until stopped."""
        try:
            # Event Loop
            self._loop = asyncio.new_event_loop()
            asyncio.set_event_loop(self._loop)

            self._runner = web.AppRunner(self.app)
            self._loop.run_until_complete(self._runner.setup())

            site = web.TCPSite(self._runner, port=self.port)
            self._loop.run_until_complete(site.start())
            self._loop.run_forever()
        except Exception as e:
            logger.error(f"Server error: {str(e)}")
        finally:
            # Cleanup
            self._loop.run_until_complete(self._runner.cleanup())
            self._loop.close()

    def close(self):
        """Shutdown"""
        if self._loop is not None and self._loop.is_running():
            # Stop the loop from outside its thread; _run_server then cleans up.
            self._loop.call_soon_threadsafe(self._loop.stop)
            logger.info("Stopping server loop...")

        if self.thread.is_alive():
            self.thread.join(timeout=2)
            logger.info("Server thread stopped")

    def poll(self) -> KVPoll: ...
class NixlKVBootstrapServer(CommonKVBootstrapServer):
    """NIXL bootstrap server; all behavior is inherited from CommonKVBootstrapServer."""

    pass
python/sglang/srt/disaggregation/utils.py
View file @
e806f708
...
...
@@ -13,7 +13,7 @@ import requests
import
torch
import
torch.distributed
as
dist
from
sglang.srt.utils
import
get_ip
from
sglang.srt.utils
import
get_ip
,
get_local_ip_by_remote
if
TYPE_CHECKING
:
from
sglang.srt.managers.schedule_batch
import
Req
...
...
@@ -279,3 +279,20 @@ class MetadataBuffers:
]
=
torch
.
tensor
(
req
.
output_top_logprobs_idx
[
0
],
dtype
=
torch
.
int32
,
device
=
"cpu"
)
def group_concurrent_contiguous(
    src_indices: npt.NDArray[np.int64], dst_indices: npt.NDArray[np.int64]
) -> Tuple[List[List[int]], List[List[int]]]:
    """Split two parallel index arrays into runs contiguous on both sides.

    A new group starts wherever either ``src_indices`` or ``dst_indices``
    jumps by anything other than +1, so each returned group can be moved as
    one contiguous copy on both the source and the destination.

    Args:
        src_indices: 1-D array of source indices.
        dst_indices: 1-D array of destination indices, same length as
            ``src_indices``.

    Returns:
        ``(src_groups, dst_groups)`` — parallel lists of groups. Note the
        groups are plain Python int lists (``.tolist()`` is applied), not
        ndarrays; the original annotation claimed ndarray groups, which was
        wrong.
    """
    if src_indices.size == 0:
        return [], []

    # Indices where either sequence breaks contiguity; the +1 converts
    # np.diff offsets back into split points in the original arrays.
    brk = np.where((np.diff(src_indices) != 1) | (np.diff(dst_indices) != 1))[0] + 1
    src_groups = [g.tolist() for g in np.split(src_indices, brk)]
    dst_groups = [g.tolist() for g in np.split(dst_indices, brk)]

    return src_groups, dst_groups
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment