Commit e806f708 (unverified), authored May 27, 2025 by Trevor Morris, committed by GitHub on May 27, 2025

[PD] Make bootstrap code common between NIXL and Mooncake (#6473)

Parent: fa6723f0

Showing 6 changed files with 595 additions and 428 deletions (+595, -428)
docs/backend/pd_disaggregation.md                     +41   -0
python/sglang/srt/disaggregation/common/__init__.py    +1   -0
python/sglang/srt/disaggregation/common/conn.py       +401   -0
python/sglang/srt/disaggregation/mooncake/conn.py       +4  -18
python/sglang/srt/disaggregation/nixl/conn.py         +130 -409
python/sglang/srt/disaggregation/utils.py              +18   -1
docs/backend/pd_disaggregation.md (view file @ e806f708)

@@ -47,3 +47,44 @@
# decode 1
$ python -m sglang.launch_server --model-path deepseek-ai/DeepSeek-V3-0324 --disaggregation-ib-device ${device_name} --disaggregation-mode decode --host ${local_ip} --port 30001 --trust-remote-code --dist-init-addr ${decode_master_ip}:5000 --nnodes 2 --node-rank 1 --tp-size 16 --dp-size 8 --enable-dp-attention --enable-deepep-moe --deepep-mode low_latency --mem-fraction-static 0.8 --max-running-requests 128
```
## NIXL

### Requirements

Install via pip.

```bash
pip install nixl
```

Or build from source, which may be required if you already have UCX installed.

```bash
git clone https://github.com/ai-dynamo/nixl.git
cd nixl
pip install . --config-settings=setup-args="-Ducx_path=/path/to/ucx"
```

### Usage

### Llama Single Node

```bash
$ python -m sglang.launch_server --model-path meta-llama/Llama-3.1-8B-Instruct --disaggregation-mode prefill --disaggregation-transfer-backend nixl
$ python -m sglang.launch_server --model-path meta-llama/Llama-3.1-8B-Instruct --disaggregation-mode decode --port 30001 --base-gpu-id 1 --disaggregation-transfer-backend nixl
$ python -m sglang.srt.disaggregation.mini_lb --prefill http://127.0.0.1:30000 --decode http://127.0.0.1:30001 --host 0.0.0.0 --port 8000
```
### DeepSeek Multi-Node

```bash
# prefill 0
$ python -m sglang.launch_server --model-path deepseek-ai/DeepSeek-V3-0324 --disaggregation-transfer-backend nixl --disaggregation-mode prefill --host ${local_ip} --port 30000 --trust-remote-code --dist-init-addr ${prefill_master_ip}:5000 --nnodes 2 --node-rank 0 --tp-size 16 --dp-size 8 --enable-dp-attention --enable-deepep-moe --deepep-mode normal --mem-fraction-static 0.8
# prefill 1
$ python -m sglang.launch_server --model-path deepseek-ai/DeepSeek-V3-0324 --disaggregation-transfer-backend nixl --disaggregation-mode prefill --host ${local_ip} --port 30000 --trust-remote-code --dist-init-addr ${prefill_master_ip}:5000 --nnodes 2 --node-rank 1 --tp-size 16 --dp-size 8 --enable-dp-attention --enable-deepep-moe --deepep-mode normal --mem-fraction-static 0.8
# decode 0
$ python -m sglang.launch_server --model-path deepseek-ai/DeepSeek-V3-0324 --disaggregation-transfer-backend nixl --disaggregation-mode decode --host ${local_ip} --port 30001 --trust-remote-code --dist-init-addr ${decode_master_ip}:5000 --nnodes 2 --node-rank 0 --tp-size 16 --dp-size 8 --enable-dp-attention --enable-deepep-moe --deepep-mode low_latency --mem-fraction-static 0.8 --max-running-requests 128
# decode 1
$ python -m sglang.launch_server --model-path deepseek-ai/DeepSeek-V3-0324 --disaggregation-transfer-backend nixl --disaggregation-mode decode --host ${local_ip} --port 30001 --trust-remote-code --dist-init-addr ${decode_master_ip}:5000 --nnodes 2 --node-rank 1 --tp-size 16 --dp-size 8 --enable-dp-attention --enable-deepep-moe --deepep-mode low_latency --mem-fraction-static 0.8 --max-running-requests 128
```
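Once the prefill and decode servers and the mini load balancer are running, a request routed through the balancer is an easy end-to-end check. The snippet below is a minimal sketch, assuming the single-node addresses used above (load balancer on http://127.0.0.1:8000); the prompt and sampling parameters are placeholders, so adjust them for your deployment.

```python
# Minimal end-to-end smoke test against the PD mini load balancer.
# Assumes the balancer from the single-node example listens on 127.0.0.1:8000.
import requests

resp = requests.post(
    "http://127.0.0.1:8000/generate",
    json={
        "text": "The capital of France is",
        "sampling_params": {"max_new_tokens": 16, "temperature": 0},
    },
)
resp.raise_for_status()
print(resp.json())
```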
python/sglang/srt/disaggregation/common/__init__.py (new file, mode 100644, view file @ e806f708)

from .conn import CommonKVBootstrapServer, CommonKVManager, CommonKVReceiver
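This package is the shared home for the bootstrap logic that the NIXL and Mooncake backends previously duplicated; a backend inherits the common classes and fills in only its transfer-specific pieces. The sketch below is a hypothetical illustration of that pattern, not code from this commit: `MyBackendKVManager` and `MyBackendKVReceiver` are made-up names.

```python
# Hypothetical illustration of reusing the shared bootstrap logic;
# the class names here are placeholders, not part of this commit.
from typing import Optional

from sglang.srt.disaggregation.base.conn import KVArgs
from sglang.srt.disaggregation.common import CommonKVManager, CommonKVReceiver
from sglang.srt.disaggregation.utils import DisaggregationMode
from sglang.srt.server_args import ServerArgs


class MyBackendKVManager(CommonKVManager):
    def __init__(
        self,
        args: KVArgs,
        disaggregation_mode: DisaggregationMode,
        server_args: ServerArgs,
        is_mla_backend: Optional[bool] = False,
    ):
        # CommonKVManager handles bootstrap registration (prefill) and the
        # connection/parallel-info tables (decode); only backend-specific
        # transfer-engine setup would go here.
        super().__init__(args, disaggregation_mode, server_args, is_mla_backend)


class MyBackendKVReceiver(CommonKVReceiver):
    def _register_kv_args(self):
        # Override the no-op hook to push backend-specific KV metadata to the
        # prefill ranks resolved from the bootstrap server.
        pass
```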
python/sglang/srt/disaggregation/common/conn.py (new file, mode 100644, view file @ e806f708)
from __future__ import annotations

import asyncio
import logging
import socket
import threading
from functools import cache
from typing import Dict, List, Optional, Tuple, Union

import numpy as np
import numpy.typing as npt
import requests
import zmq
from aiohttp import web

from sglang.srt.disaggregation.base.conn import (
    BaseKVBootstrapServer,
    BaseKVManager,
    BaseKVReceiver,
    BaseKVSender,
    KVArgs,
    KVPoll,
)
from sglang.srt.disaggregation.utils import DisaggregationMode
from sglang.srt.server_args import ServerArgs
from sglang.srt.utils import get_free_port, get_ip, get_local_ip_by_remote

logger = logging.getLogger(__name__)
class CommonKVManager(BaseKVManager):
    def __init__(
        self,
        args: KVArgs,
        disaggregation_mode: DisaggregationMode,
        server_args: ServerArgs,
        is_mla_backend: Optional[bool] = False,
    ):
        self.kv_args = args
        self.is_mla_backend = is_mla_backend
        self.disaggregation_mode = disaggregation_mode
        # for p/d multi node infer
        self.bootstrap_port = server_args.disaggregation_bootstrap_port
        self.dist_init_addr = server_args.dist_init_addr
        self.tp_size = server_args.tp_size
        self.dp_size = server_args.dp_size
        self.enable_dp_attention = server_args.enable_dp_attention
        if not server_args.enable_dp_attention and server_args.dp_size != 1:
            raise ValueError(
                "If dp_attention is not enabled, dp size must be 1 in disaggregation mode."
            )

        self.rank_port = get_free_port()
        if self.disaggregation_mode == DisaggregationMode.PREFILL:
            self._register_to_bootstrap()
        elif self.disaggregation_mode == DisaggregationMode.DECODE:
            self.connection_pool: Dict[str, Dict[str, Union[str, int]]] = {}
            self.prefill_tp_size_table: Dict[str, int] = {}
            self.prefill_dp_size_table: Dict[str, int] = {}
        else:
            raise ValueError(
                f"Unsupported DisaggregationMode: {self.disaggregation_mode}"
            )
    def _register_to_bootstrap(self):
        """Register KVSender to bootstrap server via HTTP PUT."""
        if self.dist_init_addr:
            ip_address = socket.gethostbyname(self.dist_init_addr.split(":")[0])
        else:
            ip_address = get_ip()

        bootstrap_server_url = f"{ip_address}:{self.bootstrap_port}"
        url = f"http://{bootstrap_server_url}/route"
        payload = {
            "role": "Prefill",
            "tp_size": self.tp_size,
            "dp_size": self.dp_size,
            "rank_ip": get_local_ip_by_remote(),
            "rank_port": self.rank_port,
            "engine_rank": self.kv_args.engine_rank,
        }

        try:
            response = requests.put(url, json=payload)
            if response.status_code == 200:
                logger.debug("Prefill successfully registered to bootstrap server.")
            else:
                logger.error(
                    f"Prefill failed to connect to bootstrap server: {response.status_code}, {response.text}"
                )
        except Exception as e:
            logger.error(f"Prefill failed to register to bootstrap server: {e}")
    @cache
    def _connect(self, endpoint: str):
        socket = zmq.Context().socket(zmq.PUSH)
        socket.connect(endpoint)
        return socket
class CommonKVReceiver(BaseKVReceiver):
    _ctx = zmq.Context()
    _socket_cache = {}
    _socket_locks = {}
    _global_lock = threading.Lock()

    def __init__(
        self,
        mgr: BaseKVManager,
        bootstrap_addr: str,
        bootstrap_room: Optional[int] = None,
    ):
        self.bootstrap_room = bootstrap_room
        self.bootstrap_addr = bootstrap_addr
        self.kv_mgr = mgr

        if self.bootstrap_addr not in self.kv_mgr.prefill_dp_size_table:
            self.prefill_tp_size, self.prefill_dp_size = (
                self._get_prefill_dp_size_from_server()
            )
            if self.prefill_tp_size is None or self.prefill_dp_size is None:
                logger.error(
                    f"Could not fetch prefill parallel info for bootstrap_addr: {self.bootstrap_addr}"
                )
            else:
                self.kv_mgr.prefill_tp_size_table[self.bootstrap_addr] = (
                    self.prefill_tp_size
                )
                self.kv_mgr.prefill_dp_size_table[self.bootstrap_addr] = (
                    self.prefill_dp_size
                )
        else:
            self.prefill_tp_size = self.kv_mgr.prefill_tp_size_table[
                self.bootstrap_addr
            ]
            self.prefill_dp_size = self.kv_mgr.prefill_dp_size_table[
                self.bootstrap_addr
            ]

        # Currently, we don't allow prefill instance and decode instance to
        # have different TP sizes per DP rank, except for models using MLA.
        local_tp_size_per_dp_rank = self.kv_mgr.tp_size // self.kv_mgr.dp_size
        prefill_tp_size_per_dp_rank = self.prefill_tp_size // self.prefill_dp_size
        if local_tp_size_per_dp_rank == prefill_tp_size_per_dp_rank:
            self.target_tp_rank = (
                self.kv_mgr.kv_args.engine_rank % local_tp_size_per_dp_rank
            )
            self.required_dst_info_num = 1
            self.target_tp_ranks = [self.target_tp_rank]
        elif local_tp_size_per_dp_rank > prefill_tp_size_per_dp_rank:
            assert (
                self.kv_mgr.is_mla_backend
            ), "PD with different TP sizes per DP rank is not yet supported for non-MLA models"
            self.target_tp_rank = (
                self.kv_mgr.kv_args.engine_rank % local_tp_size_per_dp_rank
            ) // (local_tp_size_per_dp_rank // prefill_tp_size_per_dp_rank)
            self.required_dst_info_num = (
                local_tp_size_per_dp_rank // prefill_tp_size_per_dp_rank
            )
            self.target_tp_ranks = [self.target_tp_rank]
        else:
            assert (
                self.kv_mgr.is_mla_backend
            ), "PD with different TP sizes per DP rank is not yet supported for non-MLA models"
            # For non-MLA models, one decode rank would need to retrieve KVCache
            # from multiple prefill ranks;
            self.target_tp_ranks = [
                rank
                for rank in range(
                    (self.kv_mgr.kv_args.engine_rank % local_tp_size_per_dp_rank)
                    * (prefill_tp_size_per_dp_rank // local_tp_size_per_dp_rank),
                    (self.kv_mgr.kv_args.engine_rank % local_tp_size_per_dp_rank + 1)
                    * (prefill_tp_size_per_dp_rank // local_tp_size_per_dp_rank),
                )
            ]
            # For MLA models, we can retrieve KVCache from only one prefill rank, but we still need to maintain
            # multiple connections in the connection pool and have to send dummy requests to other prefill ranks,
            # or the KVPoll will never be set correctly
            self.target_tp_rank = self.target_tp_ranks[0]
            self.required_dst_info_num = 1

        self.target_dp_group = bootstrap_room % self.prefill_dp_size

        # NOTE: key distinguished by bootstrap_addr, target_dp_group, and target_tp_rank
        bootstrap_key = (
            f"{self.bootstrap_addr}_{self.target_dp_group}_{self.target_tp_rank}"
        )

        if bootstrap_key not in self.kv_mgr.connection_pool:
            bootstrap_infos = []
            for target_tp_rank in self.target_tp_ranks:
                bootstrap_info = self._get_bootstrap_info_from_server(
                    target_tp_rank,
                    self.target_dp_group,
                )
                if bootstrap_info is not None:
                    # NOTE: only support MLA for now: select one prefill rank as real rank
                    bootstrap_info["is_dummy"] = not bool(
                        target_tp_rank == self.target_tp_rank
                        or self.target_tp_rank is None
                    )
                    bootstrap_infos.append(bootstrap_info)
                else:
                    logger.error(
                        f"Could not fetch bootstrap info for engine rank: {self.kv_mgr.kv_args.engine_rank} and target_dp_group: {self.target_dp_group}"
                    )
            self.bootstrap_infos = bootstrap_infos

            if len(self.bootstrap_infos) == 0:
                logger.error(
                    f"Could not fetch bootstrap info for engine rank: {self.kv_mgr.kv_args.engine_rank}"
                )
            else:
                self.kv_mgr.connection_pool[bootstrap_key] = self.bootstrap_infos
                # Register kv_args only once to prefill KVManager according to the info fetched from the bootstrap server
                self._register_kv_args()
        else:
            self.bootstrap_infos = self.kv_mgr.connection_pool[bootstrap_key]

        assert len(self.bootstrap_infos) > 0
    def _get_bootstrap_info_from_server(self, engine_rank, target_dp_group):
        """Fetch the bootstrap info from the bootstrap server."""
        try:
            url = f"http://{self.bootstrap_addr}/route?engine_rank={engine_rank}&target_dp_group={target_dp_group}"
            response = requests.get(url)
            if response.status_code == 200:
                bootstrap_info = response.json()
                return bootstrap_info
            else:
                logger.error(
                    f"Failed to get prefill server info: {response.status_code}, {response.text}"
                )
                return None
        except Exception as e:
            logger.error(f"Error fetching prefill info from bootstrap: {e}")
            return None
    def _get_prefill_dp_size_from_server(self) -> Optional[Tuple[int, int]]:
        """Fetch the prefill parallel info (tp_size, dp_size) from the bootstrap server."""
        try:
            url = f"http://{self.bootstrap_addr}/route?engine_rank={-1}&target_dp_group={-1}"
            response = requests.get(url)
            if response.status_code == 200:
                prefill_parallel_info = response.json()
                return int(prefill_parallel_info["prefill_tp_size"]), int(
                    prefill_parallel_info["prefill_dp_size"]
                )
            else:
                logger.error(
                    f"Failed to get prefill parallel info: {response.status_code}, {response.text}"
                )
                return None
        except Exception as e:
            logger.error(f"Error fetching prefill parallel info from bootstrap: {e}")
            return None
    @classmethod
    def _connect(cls, endpoint: str):
        with cls._global_lock:
            if endpoint not in cls._socket_cache:
                sock = cls._ctx.socket(zmq.PUSH)
                sock.connect(endpoint)
                cls._socket_cache[endpoint] = sock
                cls._socket_locks[endpoint] = threading.Lock()
            return cls._socket_cache[endpoint], cls._socket_locks[endpoint]
    def _register_kv_args(self):
        pass

    def failure_exception(self):
        raise Exception("Fake KVReceiver Exception")
class CommonKVBootstrapServer(BaseKVBootstrapServer):
    def __init__(self, port: int):
        self.port = port
        self.app = web.Application()
        self.store = dict()
        self.lock = asyncio.Lock()
        self._setup_routes()
        self.tp_size = None
        self.dp_size = None
        self.tp_size_per_dp_rank = None
        self.prefill_port_table: Dict[int, Dict[int, Dict[str, Union[str, int]]]] = {}

        # Start bootstrap server
        self.thread = threading.Thread(target=self._run_server, daemon=True)
        self.run()
    def run(self):
        self.thread.start()

    def _setup_routes(self):
        self.app.router.add_route("*", "/route", self._handle_route)
    async def _handle_route(self, request: web.Request):
        method = request.method
        if method == "PUT":
            return await self._handle_route_put(request)
        elif method == "GET":
            return await self._handle_route_get(request)
        else:
            return web.Response(
                text="Method not allowed", status=405, content_type="application/json"
            )
    async def _handle_route_put(self, request: web.Request):
        data = await request.json()
        role = data["role"]
        tp_size = data["tp_size"]
        dp_size = data["dp_size"]
        rank_ip = data["rank_ip"]
        rank_port = int(data["rank_port"])
        engine_rank = int(data["engine_rank"])

        if self.tp_size is None:
            self.tp_size = tp_size

        if self.dp_size is None:
            self.dp_size = dp_size

        tp_size_per_dp_rank = tp_size // dp_size
        if self.tp_size_per_dp_rank is None:
            self.tp_size_per_dp_rank = tp_size_per_dp_rank

        # Use the lock to keep registration thread-safe
        if role == "Prefill":
            dp_group = engine_rank // tp_size_per_dp_rank
            tp_rank_in_dp_group = engine_rank % tp_size_per_dp_rank

            async with self.lock:
                if dp_group not in self.prefill_port_table:
                    self.prefill_port_table[dp_group] = {}

                self.prefill_port_table[dp_group][tp_rank_in_dp_group] = {
                    "rank_ip": rank_ip,
                    "rank_port": rank_port,
                }
            logger.debug(
                f"Register Prefill bootstrap: {engine_rank} with rank_ip: {rank_ip} and rank_port: {rank_port}"
            )

        return web.Response(text="OK", status=200)
    async def _handle_route_get(self, request: web.Request):
        engine_rank = request.query.get("engine_rank")
        target_dp_group = request.query.get("target_dp_group")
        if not engine_rank or not target_dp_group:
            return web.Response(text="Missing inputs for bootstrap server.", status=400)

        # Currently we use engine_rank == -1 and target_dp_group == -1 to sync dp size
        if int(engine_rank) == -1 and int(target_dp_group) == -1:
            prefill_parallel_info = {
                "prefill_tp_size": self.tp_size,
                "prefill_dp_size": self.dp_size,
            }
            return web.json_response(prefill_parallel_info, status=200)

        # Find corresponding prefill info
        async with self.lock:
            bootstrap_info = self.prefill_port_table[int(target_dp_group)][
                int(engine_rank)
            ]

        if bootstrap_info is not None:
            return web.json_response(bootstrap_info, status=200)
        else:
            return web.Response(text="Bootstrap info not Found", status=404)
    def _run_server(self):
        try:
            # Event Loop
            self._loop = asyncio.new_event_loop()
            asyncio.set_event_loop(self._loop)

            self._runner = web.AppRunner(self.app)
            self._loop.run_until_complete(self._runner.setup())

            site = web.TCPSite(self._runner, port=self.port)
            self._loop.run_until_complete(site.start())
            self._loop.run_forever()
        except Exception as e:
            logger.error(f"Server error: {str(e)}")
        finally:
            # Cleanup
            self._loop.run_until_complete(self._runner.cleanup())
            self._loop.close()
    def close(self):
        """Shutdown"""
        if self._loop is not None and self._loop.is_running():
            self._loop.call_soon_threadsafe(self._loop.stop)
            logger.info("Stopping server loop...")

        if self.thread.is_alive():
            self.thread.join(timeout=2)
            logger.info("Server thread stopped")

    def poll(self) -> KVPoll: ...
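The bootstrap server above serves a single `/route` endpoint: prefill ranks register themselves with a PUT, while decode ranks use GET both to sync the prefill parallel sizes (via the `engine_rank == -1` / `target_dp_group == -1` sentinel) and to look up a specific prefill rank. The snippet below is a minimal sketch of that exchange from a client's point of view; the bootstrap address and payload values are placeholders, not taken from this commit.

```python
# Sketch of the /route protocol served by CommonKVBootstrapServer.
# The bootstrap address and payload values below are placeholders.
import requests

bootstrap = "http://10.0.0.1:8998"

# 1) A prefill rank registers itself (mirrors CommonKVManager._register_to_bootstrap).
requests.put(
    f"{bootstrap}/route",
    json={
        "role": "Prefill",
        "tp_size": 16,
        "dp_size": 8,
        "rank_ip": "10.0.0.2",
        "rank_port": 31000,
        "engine_rank": 0,
    },
)

# 2) A decode rank first syncs the prefill parallel sizes using the -1/-1 sentinel ...
info = requests.get(f"{bootstrap}/route?engine_rank=-1&target_dp_group=-1").json()
print(info["prefill_tp_size"], info["prefill_dp_size"])

# 3) ... then resolves the rank_ip / rank_port of a specific prefill rank.
peer = requests.get(f"{bootstrap}/route?engine_rank=0&target_dp_group=0").json()
print(peer["rank_ip"], peer["rank_port"])
```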
python/sglang/srt/disaggregation/mooncake/conn.py (view file @ e806f708)

@@ -29,7 +29,10 @@ from sglang.srt.disaggregation.base.conn import (
     KVPoll,
 )
 from sglang.srt.disaggregation.mooncake.transfer_engine import MooncakeTransferEngine
-from sglang.srt.disaggregation.utils import DisaggregationMode
+from sglang.srt.disaggregation.utils import (
+    DisaggregationMode,
+    group_concurrent_contiguous,
+)
 from sglang.srt.server_args import ServerArgs
 from sglang.srt.utils import (
     get_free_port,
@@ -41,23 +44,6 @@ from sglang.srt.utils import (
 logger = logging.getLogger(__name__)

-def group_concurrent_contiguous(
-    src_indices: npt.NDArray[np.int64], dst_indices: npt.NDArray[np.int64]
-) -> Tuple[List[npt.NDArray[np.int64]], List[npt.NDArray[np.int64]]]:
-    """Vectorised NumPy implementation."""
-    if src_indices.size == 0:
-        return [], []
-
-    brk = np.where((np.diff(src_indices) != 1) | (np.diff(dst_indices) != 1))[0] + 1
-    src_groups = np.split(src_indices, brk)
-    dst_groups = np.split(dst_indices, brk)
-
-    src_groups = [g.tolist() for g in src_groups]
-    dst_groups = [g.tolist() for g in dst_groups]
-
-    return src_groups, dst_groups
-
 class KVTransferError(Exception):
     def __init__(self, bootstrap_room: int, failure_reason: str):
         super().__init__(failure_reason)
python/sglang/srt/disaggregation/nixl/conn.py (view file @ e806f708)

This diff is collapsed and not shown here.
python/sglang/srt/disaggregation/utils.py (view file @ e806f708)

@@ -13,7 +13,7 @@ import requests
 import torch
 import torch.distributed as dist

-from sglang.srt.utils import get_ip
+from sglang.srt.utils import get_ip, get_local_ip_by_remote

 if TYPE_CHECKING:
     from sglang.srt.managers.schedule_batch import Req
@@ -279,3 +279,20 @@ class MetadataBuffers:
             ] = torch.tensor(
                 req.output_top_logprobs_idx[0], dtype=torch.int32, device="cpu"
             )
+
+
+def group_concurrent_contiguous(
+    src_indices: npt.NDArray[np.int64], dst_indices: npt.NDArray[np.int64]
+) -> Tuple[List[npt.NDArray[np.int64]], List[npt.NDArray[np.int64]]]:
+    """Vectorised NumPy implementation."""
+    if src_indices.size == 0:
+        return [], []
+
+    brk = np.where((np.diff(src_indices) != 1) | (np.diff(dst_indices) != 1))[0] + 1
+    src_groups = np.split(src_indices, brk)
+    dst_groups = np.split(dst_indices, brk)
+
+    src_groups = [g.tolist() for g in src_groups]
+    dst_groups = [g.tolist() for g in dst_groups]
+
+    return src_groups, dst_groups
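`group_concurrent_contiguous` now lives in `sglang.srt.disaggregation.utils` and is shared by the transfer backends: it splits two parallel index arrays into runs that are contiguous in both the source and the destination, so each run can be transferred as one block. Below is a small self-contained check of that behaviour, assuming sglang is installed so the import resolves.

```python
# Quick check of the grouping behaviour of group_concurrent_contiguous.
import numpy as np

from sglang.srt.disaggregation.utils import group_concurrent_contiguous

src = np.array([3, 4, 5, 10, 11, 20], dtype=np.int64)
dst = np.array([7, 8, 9, 30, 31, 40], dtype=np.int64)

src_groups, dst_groups = group_concurrent_contiguous(src, dst)
# A run breaks wherever either sequence stops increasing by exactly 1:
# src_groups == [[3, 4, 5], [10, 11], [20]]
# dst_groups == [[7, 8, 9], [30, 31], [40]]
print(src_groups, dst_groups)
```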