Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
zhaoyu6
sglang
Commits
e806f708
"train.py" did not exist on "de076fe9f46d8c27c705f823b0c55fb144f1ab3a"
Unverified
Commit
e806f708
authored
May 27, 2025
by
Trevor Morris
Committed by
GitHub
May 27, 2025
Browse files
[PD] Make bootstrap code common between NIXL and Mooncake (#6473)
parent
fa6723f0
Changes
6
Expand all
Show whitespace changes
Inline
Side-by-side
Showing
6 changed files
with
595 additions
and
428 deletions
+595
-428
docs/backend/pd_disaggregation.md
docs/backend/pd_disaggregation.md
+41
-0
python/sglang/srt/disaggregation/common/__init__.py
python/sglang/srt/disaggregation/common/__init__.py
+1
-0
python/sglang/srt/disaggregation/common/conn.py
python/sglang/srt/disaggregation/common/conn.py
+401
-0
python/sglang/srt/disaggregation/mooncake/conn.py
python/sglang/srt/disaggregation/mooncake/conn.py
+4
-18
python/sglang/srt/disaggregation/nixl/conn.py
python/sglang/srt/disaggregation/nixl/conn.py
+130
-409
python/sglang/srt/disaggregation/utils.py
python/sglang/srt/disaggregation/utils.py
+18
-1
No files found.
docs/backend/pd_disaggregation.md
View file @
e806f708
...
...
@@ -47,3 +47,44 @@ $ python -m sglang.launch_server --model-path deepseek-ai/DeepSeek-V3-0324 --dis
# decode 1
$
python
-m
sglang.launch_server
--model-path
deepseek-ai/DeepSeek-V3-0324
--disaggregation-ib-device
${
device_name
}
--disaggregation-mode
decode
--host
${
local_ip
}
--port
30001
--trust-remote-code
--dist-init-addr
${
decode_master_ip
}
:5000
--nnodes
2
--node-rank
1
--tp-size
16
--dp-size
8
--enable-dp-attention
--enable-deepep-moe
--deepep-mode
low_latency
--mem-fraction-static
0.8
--max-running-requests
128
```
## NIXL
### Requirements
Install via pip.
```
bash
pip
install
nixl
```
Or build from source - may be required if you already have UCX installed.
```
bash
git clone https://github.com/ai-dynamo/nixl.git
cd
nixl
pip
install
.
--config-settings
=
setup-args
=
"-Ducx_path=/path/to/ucx"
```
### Usage
### Llama Single Node
```
bash
$
python
-m
sglang.launch_server
--model-path
meta-llama/Llama-3.1-8B-Instruct
--disaggregation-mode
prefill
--disaggregation-transfer-backend
nixl
$
python
-m
sglang.launch_server
--model-path
meta-llama/Llama-3.1-8B-Instruct
--disaggregation-mode
decode
--port
30001
--base-gpu-id
1
--disaggregation-transfer-backend
nixl
$
python
-m
sglang.srt.disaggregation.mini_lb
--prefill
http://127.0.0.1:30000
--decode
http://127.0.0.1:30001
--host
0.0.0.0
--port
8000
```
### DeepSeek Multi-Node
```
bash
# prefill 0
$
python
-m
sglang.launch_server
--model-path
deepseek-ai/DeepSeek-V3-0324
--disaggregation-transfer-backend
nixl
--disaggregation-mode
prefill
--host
${
local_ip
}
--port
30000
--trust-remote-code
--dist-init-addr
${
prefill_master_ip
}
:5000
--nnodes
2
--node-rank
0
--tp-size
16
--dp-size
8
--enable-dp-attention
--enable-deepep-moe
--deepep-mode
normal
--mem-fraction-static
0.8
# prefill 1
$
python
-m
sglang.launch_server
--model-path
deepseek-ai/DeepSeek-V3-0324
--disaggregation-transfer-backend
nixl
--disaggregation-mode
prefill
--host
${
local_ip
}
--port
30000
--trust-remote-code
--dist-init-addr
${
prefill_master_ip
}
:5000
--nnodes
2
--node-rank
1
--tp-size
16
--dp-size
8
--enable-dp-attention
--enable-deepep-moe
--deepep-mode
normal
--mem-fraction-static
0.8
# decode 0
$
python
-m
sglang.launch_server
--model-path
deepseek-ai/DeepSeek-V3-0324
--disaggregation-transfer-backend
nixl
--disaggregation-mode
decode
--host
${
local_ip
}
--port
30001
--trust-remote-code
--dist-init-addr
${
decode_master_ip
}
:5000
--nnodes
2
--node-rank
0
--tp-size
16
--dp-size
8
--enable-dp-attention
--enable-deepep-moe
--deepep-mode
low_latency
--mem-fraction-static
0.8
--max-running-requests
128
# decode 1
$
python
-m
sglang.launch_server
--model-path
deepseek-ai/DeepSeek-V3-0324
--disaggregation-transfer-backend
nixl
--disaggregation-mode
decode
--host
${
local_ip
}
--port
30001
--trust-remote-code
--dist-init-addr
${
decode_master_ip
}
:5000
--nnodes
2
--node-rank
1
--tp-size
16
--dp-size
8
--enable-dp-attention
--enable-deepep-moe
--deepep-mode
low_latency
--mem-fraction-static
0.8
--max-running-requests
128
```
python/sglang/srt/disaggregation/common/__init__.py
0 → 100644
View file @
e806f708
from
.conn
import
CommonKVBootstrapServer
,
CommonKVManager
,
CommonKVReceiver
python/sglang/srt/disaggregation/common/conn.py
0 → 100644
View file @
e806f708
from
__future__
import
annotations
import
asyncio
import
logging
import
socket
import
threading
from
functools
import
cache
from
typing
import
Dict
,
List
,
Optional
,
Tuple
,
Union
import
numpy
as
np
import
numpy.typing
as
npt
import
requests
import
zmq
from
aiohttp
import
web
from
sglang.srt.disaggregation.base.conn
import
(
BaseKVBootstrapServer
,
BaseKVManager
,
BaseKVReceiver
,
BaseKVSender
,
KVArgs
,
KVPoll
,
)
from
sglang.srt.disaggregation.utils
import
DisaggregationMode
from
sglang.srt.server_args
import
ServerArgs
from
sglang.srt.utils
import
get_free_port
,
get_ip
,
get_local_ip_by_remote
logger
=
logging
.
getLogger
(
__name__
)
class CommonKVManager(BaseKVManager):
    """KV-transfer manager logic shared by the NIXL and Mooncake backends.

    In PREFILL mode it registers this rank with the bootstrap server; in
    DECODE mode it holds per-bootstrap-address caches (connection pool and
    prefill tp/dp sizes) that CommonKVReceiver instances populate.
    """

    def __init__(
        self,
        args: KVArgs,
        disaggregation_mode: DisaggregationMode,
        server_args: ServerArgs,
        is_mla_backend: Optional[bool] = False,
    ):
        self.kv_args = args
        self.is_mla_backend = is_mla_backend
        self.disaggregation_mode = disaggregation_mode
        # for p/d multi node infer
        self.bootstrap_port = server_args.disaggregation_bootstrap_port
        self.dist_init_addr = server_args.dist_init_addr
        self.tp_size = server_args.tp_size
        self.dp_size = server_args.dp_size
        self.enable_dp_attention = server_args.enable_dp_attention
        if not server_args.enable_dp_attention and server_args.dp_size != 1:
            raise ValueError(
                "If dp_attention is not enabled, dp size must be 1 in disaggregation mode."
            )

        # Per-endpoint ZMQ PUSH sockets created lazily by _connect().
        self._zmq_sockets: Dict[str, "zmq.Socket"] = {}

        self.rank_port = get_free_port()
        if self.disaggregation_mode == DisaggregationMode.PREFILL:
            self._register_to_bootstrap()
        elif self.disaggregation_mode == DisaggregationMode.DECODE:
            self.connection_pool: Dict[str, Dict[str, Union[str, int]]] = {}
            self.prefill_tp_size_table: Dict[str, int] = {}
            self.prefill_dp_size_table: Dict[str, int] = {}
        else:
            raise ValueError(
                f"Unsupported DisaggregationMode: {self.disaggregation_mode}"
            )

    def _register_to_bootstrap(self):
        """Register KVSender to bootstrap server via HTTP POST."""
        # Prefer the dist-init host so multi-node prefill ranks advertise a
        # reachable address; fall back to this machine's IP.
        if self.dist_init_addr:
            ip_address = socket.gethostbyname(self.dist_init_addr.split(":")[0])
        else:
            ip_address = get_ip()

        bootstrap_server_url = f"{ip_address}:{self.bootstrap_port}"
        url = f"http://{bootstrap_server_url}/route"
        payload = {
            "role": "Prefill",
            "tp_size": self.tp_size,
            "dp_size": self.dp_size,
            "rank_ip": get_local_ip_by_remote(),
            "rank_port": self.rank_port,
            "engine_rank": self.kv_args.engine_rank,
        }

        # Best-effort: failures are logged, not raised, so startup continues.
        try:
            response = requests.put(url, json=payload)
            if response.status_code == 200:
                logger.debug("Prefill successfully registered to bootstrap server.")
            else:
                logger.error(
                    f"Prefill Failed to connect to bootstrap server: {response.status_code}, {response.text}"
                )
        except Exception as e:
            logger.error(f"Prefill Failed to register to bootstrap server: {e}")

    def _connect(self, endpoint: str):
        """Return a ZMQ PUSH socket connected to *endpoint*, reusing one socket
        per endpoint for the lifetime of this manager.

        NOTE(review): the original decorated this bound method with
        @functools.cache, which keys the cache on `self` and therefore keeps
        every manager instance (and its sockets) alive for the process
        lifetime (ruff B019). It also bound a local named `socket`, shadowing
        the imported `socket` module. A per-instance dict gives identical
        reuse semantics without either problem.
        """
        sockets = getattr(self, "_zmq_sockets", None)
        if sockets is None:
            # Defensive for subclasses that skip this __init__.
            sockets = self._zmq_sockets = {}
        sock = sockets.get(endpoint)
        if sock is None:
            sock = zmq.Context().socket(zmq.PUSH)
            sock.connect(endpoint)
            sockets[endpoint] = sock
        return sock
class CommonKVReceiver(BaseKVReceiver):
    """Decode-side KV receiver shared by NIXL/Mooncake backends.

    Resolves which prefill TP rank(s) this decode rank must pull KV cache
    from (handling unequal TP sizes per DP rank for MLA models), fetching
    bootstrap info from the bootstrap server and caching it in the manager's
    connection pool.
    """

    # Process-wide ZMQ context plus socket/lock caches shared by all receivers.
    _ctx = zmq.Context()
    _socket_cache = {}
    _socket_locks = {}
    _global_lock = threading.Lock()

    def __init__(
        self,
        mgr: BaseKVManager,
        bootstrap_addr: str,
        bootstrap_room: Optional[int] = None,
    ):
        self.bootstrap_room = bootstrap_room
        self.bootstrap_addr = bootstrap_addr
        self.kv_mgr = mgr

        if self.bootstrap_addr not in self.kv_mgr.prefill_dp_size_table:
            self.prefill_tp_size, self.prefill_dp_size = (
                self._get_prefill_dp_size_from_server()
            )
            if self.prefill_tp_size is None or self.prefill_dp_size is None:
                logger.error(
                    f"Could not fetch prefill parallel info for bootstrap_addr: {self.bootstrap_addr}"
                )
            else:
                self.kv_mgr.prefill_tp_size_table[self.bootstrap_addr] = (
                    self.prefill_tp_size
                )
                self.kv_mgr.prefill_dp_size_table[self.bootstrap_addr] = (
                    self.prefill_dp_size
                )
        else:
            self.prefill_tp_size = self.kv_mgr.prefill_tp_size_table[
                self.bootstrap_addr
            ]
            self.prefill_dp_size = self.kv_mgr.prefill_dp_size_table[
                self.bootstrap_addr
            ]

        # Currently, we don't allow prefill instance and decode instance to
        # have different TP sizes per DP rank, except for models using MLA.
        local_tp_size_per_dp_rank = self.kv_mgr.tp_size // self.kv_mgr.dp_size
        prefill_tp_size_per_dp_rank = self.prefill_tp_size // self.prefill_dp_size
        if local_tp_size_per_dp_rank == prefill_tp_size_per_dp_rank:
            # Symmetric TP: one-to-one mapping between decode and prefill ranks.
            self.target_tp_rank = (
                self.kv_mgr.kv_args.engine_rank % local_tp_size_per_dp_rank
            )
            self.required_dst_info_num = 1
            self.target_tp_ranks = [self.target_tp_rank]
        elif local_tp_size_per_dp_rank > prefill_tp_size_per_dp_rank:
            # Decode TP is wider: several decode ranks share one prefill rank.
            assert (
                self.kv_mgr.is_mla_backend
            ), "PD with different TP sizes per DP rank is not yet supported for non-MLA models"
            self.target_tp_rank = (
                self.kv_mgr.kv_args.engine_rank % local_tp_size_per_dp_rank
            ) // (local_tp_size_per_dp_rank // prefill_tp_size_per_dp_rank)
            self.required_dst_info_num = (
                local_tp_size_per_dp_rank // prefill_tp_size_per_dp_rank
            )
            self.target_tp_ranks = [self.target_tp_rank]
        else:
            # Prefill TP is wider: this decode rank maps onto a contiguous
            # span of prefill ranks.
            assert (
                self.kv_mgr.is_mla_backend
            ), "PD with different TP sizes per DP rank is not yet supported for non-MLA models"
            self.target_tp_ranks = [
                rank
                for rank in range(
                    (self.kv_mgr.kv_args.engine_rank % local_tp_size_per_dp_rank)
                    * (prefill_tp_size_per_dp_rank // local_tp_size_per_dp_rank),
                    (self.kv_mgr.kv_args.engine_rank % local_tp_size_per_dp_rank + 1)
                    * (prefill_tp_size_per_dp_rank // local_tp_size_per_dp_rank),
                )
            ]
            # For MLA models, we can retrieve KVCache from only one prefill rank, but we still need to maintain
            # multiple connections in the connection pool and have to send dummy requests to other prefill ranks,
            # or the KVPoll will never be set correctly
            self.target_tp_rank = self.target_tp_ranks[0]
            self.required_dst_info_num = 1

        self.target_dp_group = bootstrap_room % self.prefill_dp_size

        # NOTE: key distinguished by bootstrap_addr, target_dp_group, and target_tp_rank
        bootstrap_key = (
            f"{self.bootstrap_addr}_{self.target_dp_group}_{self.target_tp_rank}"
        )

        if bootstrap_key not in self.kv_mgr.connection_pool:
            bootstrap_infos = []
            for target_tp_rank in self.target_tp_ranks:
                bootstrap_info = self._get_bootstrap_info_from_server(
                    target_tp_rank,
                    self.target_dp_group,
                )
                if bootstrap_info is not None:
                    # NOTE: only support MLA for now: select one prefill rank as real rank
                    bootstrap_info["is_dummy"] = not bool(
                        target_tp_rank == self.target_tp_rank
                        or self.target_tp_rank is None
                    )
                    bootstrap_infos.append(bootstrap_info)
                else:
                    logger.error(
                        f"Could not fetch bootstrap info for engine rank: {self.kv_mgr.kv_args.engine_rank} and target_dp_group: {self.target_dp_group}"
                    )
            self.bootstrap_infos = bootstrap_infos

            if len(self.bootstrap_infos) == 0:
                logger.error(
                    f"Could not fetch bootstrap info for engine rank: {self.kv_mgr.kv_args.engine_rank}"
                )
            else:
                self.kv_mgr.connection_pool[bootstrap_key] = self.bootstrap_infos
                # Register kv_args only once to prefill KVManager according to the info fetched from the bootstrap server
                self._register_kv_args()
        else:
            self.bootstrap_infos = self.kv_mgr.connection_pool[bootstrap_key]

        assert len(self.bootstrap_infos) > 0

    def _get_bootstrap_info_from_server(self, engine_rank, target_dp_group):
        """Fetch the bootstrap info from the bootstrap server.

        Returns the decoded JSON dict, or None on any HTTP/network failure.
        """
        try:
            url = f"http://{self.bootstrap_addr}/route?engine_rank={engine_rank}&target_dp_group={target_dp_group}"
            response = requests.get(url)
            if response.status_code == 200:
                bootstrap_info = response.json()
                return bootstrap_info
            else:
                logger.error(
                    f"Failed to get prefill server info: {response.status_code}, {response.text}"
                )
                return None
        except Exception as e:
            logger.error(f"Error fetching prefill info from bootstrap: {e}")
            return None

    def _get_prefill_dp_size_from_server(self) -> Tuple[Optional[int], Optional[int]]:
        """Fetch the prefill parallel info from the bootstrap server.

        Returns (prefill_tp_size, prefill_dp_size), or (None, None) on any
        failure.

        BUGFIX: the original was annotated ``-> int`` and returned a bare
        ``None`` on failure, which made the tuple-unpack at the call site in
        __init__ raise TypeError before its ``is None`` check could run.
        """
        try:
            # engine_rank == -1 and target_dp_group == -1 is the server's
            # convention for "report parallel sizes" (see _handle_route_get).
            url = f"http://{self.bootstrap_addr}/route?engine_rank={-1}&target_dp_group={-1}"
            response = requests.get(url)
            if response.status_code == 200:
                prefill_parallel_info = response.json()
                return int(prefill_parallel_info["prefill_tp_size"]), int(
                    prefill_parallel_info["prefill_dp_size"]
                )
            else:
                logger.error(
                    f"Failed to get prefill parallel info: {response.status_code}, {response.text}"
                )
                return None, None
        except Exception as e:
            logger.error(f"Error fetching prefill parallel info from bootstrap: {e}")
            return None, None

    @classmethod
    def _connect(cls, endpoint: str):
        """Return (socket, lock) for *endpoint*, creating and caching them on
        first use. The per-endpoint lock lets callers serialize sends on the
        shared PUSH socket.
        """
        with cls._global_lock:
            if endpoint not in cls._socket_cache:
                sock = cls._ctx.socket(zmq.PUSH)
                sock.connect(endpoint)
                cls._socket_cache[endpoint] = sock
                cls._socket_locks[endpoint] = threading.Lock()
            return cls._socket_cache[endpoint], cls._socket_locks[endpoint]

    def _register_kv_args(self):
        # Hook for subclasses: send kv_args to the prefill side once per
        # connection-pool entry. The common base has nothing to register.
        pass

    def failure_exception(self):
        raise Exception("Fake KVReceiver Exception")
class CommonKVBootstrapServer(BaseKVBootstrapServer):
    """aiohttp-based bootstrap server for PD disaggregation.

    Prefill ranks PUT their (rank_ip, rank_port) under /route; decode ranks
    GET the same path to look them up (or, with engine_rank == -1 and
    target_dp_group == -1, to query the prefill tp/dp sizes). The server runs
    its own asyncio loop on a daemon thread started from __init__.
    """

    def __init__(self, port: int):
        self.port = port
        self.app = web.Application()
        self.store = dict()
        self.lock = asyncio.Lock()
        self._setup_routes()
        self.tp_size = None
        self.dp_size = None
        self.tp_size_per_dp_rank = None
        # dp_group -> tp_rank_in_dp_group -> {"rank_ip": ..., "rank_port": ...}
        self.prefill_port_table: Dict[int, Dict[int, Dict[str, Union[str, int]]]] = {}

        # BUGFIX: initialize these before the thread starts so close() and
        # _run_server's finally block never touch unset attributes if the
        # server fails (or hasn't finished) starting up.
        self._loop = None
        self._runner = None

        # Start bootstrap server
        self.thread = threading.Thread(target=self._run_server, daemon=True)
        self.run()

    def run(self):
        self.thread.start()

    def _setup_routes(self):
        # One wildcard route; _handle_route dispatches on the HTTP method.
        self.app.router.add_route("*", "/route", self._handle_route)

    async def _handle_route(self, request: web.Request):
        method = request.method
        if method == "PUT":
            return await self._handle_route_put(request)
        elif method == "GET":
            return await self._handle_route_get(request)
        else:
            return web.Response(
                text="Method not allowed", status=405, content_type="application/json"
            )

    async def _handle_route_put(self, request: web.Request):
        """Record a prefill rank's address; also latches tp/dp sizes on first registration."""
        data = await request.json()
        role = data["role"]
        tp_size = data["tp_size"]
        dp_size = data["dp_size"]
        rank_ip = data["rank_ip"]
        rank_port = int(data["rank_port"])
        engine_rank = int(data["engine_rank"])

        # First registration wins for the cluster-wide sizes.
        if self.tp_size is None:
            self.tp_size = tp_size

        if self.dp_size is None:
            self.dp_size = dp_size

        tp_size_per_dp_rank = tp_size // dp_size
        # BUGFIX: was `== None`; identity comparison is the correct test.
        if self.tp_size_per_dp_rank is None:
            self.tp_size_per_dp_rank = tp_size_per_dp_rank

        # Add lock to make sure thread-safe
        if role == "Prefill":
            dp_group = engine_rank // tp_size_per_dp_rank
            tp_rank_in_dp_group = engine_rank % tp_size_per_dp_rank

            async with self.lock:
                if dp_group not in self.prefill_port_table:
                    self.prefill_port_table[dp_group] = {}
                self.prefill_port_table[dp_group][tp_rank_in_dp_group] = {
                    "rank_ip": rank_ip,
                    "rank_port": rank_port,
                }
            logger.debug(
                f"Register Prefill bootstrap: {engine_rank} with rank_ip: {rank_ip} and rank_port: {rank_port}"
            )

        return web.Response(text="OK", status=200)

    async def _handle_route_get(self, request: web.Request):
        """Look up a prefill rank's bootstrap info (or report parallel sizes)."""
        engine_rank = request.query.get("engine_rank")
        target_dp_group = request.query.get("target_dp_group")
        if not engine_rank or not target_dp_group:
            return web.Response(text="Missing inputs for bootstrap server.", status=400)

        # Currently we use engine_rank == -1 and target_dp_group == -1 to sync dp size
        if int(engine_rank) == -1 and int(target_dp_group) == -1:
            prefill_parallel_info = {
                "prefill_tp_size": self.tp_size,
                "prefill_dp_size": self.dp_size,
            }
            return web.json_response(prefill_parallel_info, status=200)

        # Find corresponding prefill info
        # BUGFIX: direct indexing raised KeyError for unknown keys, so the
        # 404 branch below was unreachable; use .get() so missing entries
        # produce the intended "not found" response.
        async with self.lock:
            bootstrap_info = self.prefill_port_table.get(int(target_dp_group), {}).get(
                int(engine_rank)
            )

        if bootstrap_info is not None:
            return web.json_response(bootstrap_info, status=200)
        else:
            return web.Response(text="Bootstrap info not Found", status=404)

    def _run_server(self):
        """Thread target: build an event loop, serve until stopped, then clean up."""
        try:
            # Event Loop
            self._loop = asyncio.new_event_loop()
            asyncio.set_event_loop(self._loop)

            self._runner = web.AppRunner(self.app)
            self._loop.run_until_complete(self._runner.setup())

            site = web.TCPSite(self._runner, port=self.port)
            self._loop.run_until_complete(site.start())
            self._loop.run_forever()
        except Exception as e:
            logger.error(f"Server error: {str(e)}")
        finally:
            # Cleanup — guard against startup having failed partway through.
            if self._runner is not None and self._loop is not None:
                self._loop.run_until_complete(self._runner.cleanup())
            if self._loop is not None:
                self._loop.close()

    def close(self):
        """Shutdown"""
        if self._loop is not None and self._loop.is_running():
            self._loop.call_soon_threadsafe(self._loop.stop)
            logger.info("Stopping server loop...")

        if self.thread.is_alive():
            self.thread.join(timeout=2)
            logger.info("Server thread stopped")

    def poll(self) -> KVPoll: ...
python/sglang/srt/disaggregation/mooncake/conn.py
View file @
e806f708
...
...
@@ -29,7 +29,10 @@ from sglang.srt.disaggregation.base.conn import (
KVPoll
,
)
from
sglang.srt.disaggregation.mooncake.transfer_engine
import
MooncakeTransferEngine
from
sglang.srt.disaggregation.utils
import
DisaggregationMode
from
sglang.srt.disaggregation.utils
import
(
DisaggregationMode
,
group_concurrent_contiguous
,
)
from
sglang.srt.server_args
import
ServerArgs
from
sglang.srt.utils
import
(
get_free_port
,
...
...
@@ -41,23 +44,6 @@ from sglang.srt.utils import (
logger
=
logging
.
getLogger
(
__name__
)
def group_concurrent_contiguous(
    src_indices: npt.NDArray[np.int64], dst_indices: npt.NDArray[np.int64]
) -> Tuple[List[npt.NDArray[np.int64]], List[npt.NDArray[np.int64]]]:
    """Partition paired index arrays into maximal runs that are contiguous
    (step of exactly 1) in both *src_indices* and *dst_indices*.

    Returns two parallel lists of Python-int lists; both empty when the
    inputs are empty.
    """
    if src_indices.size == 0:
        return [], []

    # A run continues only while BOTH sequences step by exactly 1.
    both_step_one = (np.diff(src_indices) == 1) & (np.diff(dst_indices) == 1)
    cut_points = np.where(~both_step_one)[0] + 1

    src_runs = [run.tolist() for run in np.split(src_indices, cut_points)]
    dst_runs = [run.tolist() for run in np.split(dst_indices, cut_points)]
    return src_runs, dst_runs
class
KVTransferError
(
Exception
):
def
__init__
(
self
,
bootstrap_room
:
int
,
failure_reason
:
str
):
super
().
__init__
(
failure_reason
)
...
...
python/sglang/srt/disaggregation/nixl/conn.py
View file @
e806f708
This diff is collapsed.
Click to expand it.
python/sglang/srt/disaggregation/utils.py
View file @
e806f708
...
...
@@ -13,7 +13,7 @@ import requests
import
torch
import
torch.distributed
as
dist
from
sglang.srt.utils
import
get_ip
from
sglang.srt.utils
import
get_ip
,
get_local_ip_by_remote
if
TYPE_CHECKING
:
from
sglang.srt.managers.schedule_batch
import
Req
...
...
@@ -279,3 +279,20 @@ class MetadataBuffers:
]
=
torch
.
tensor
(
req
.
output_top_logprobs_idx
[
0
],
dtype
=
torch
.
int32
,
device
=
"cpu"
)
def group_concurrent_contiguous(
    src_indices: npt.NDArray[np.int64], dst_indices: npt.NDArray[np.int64]
) -> Tuple[List[npt.NDArray[np.int64]], List[npt.NDArray[np.int64]]]:
    """Split the paired src/dst index arrays at every position where either
    sequence breaks contiguity (i.e. the step is not exactly 1), yielding
    parallel lists of contiguous runs as plain Python lists.
    """
    if src_indices.size == 0:
        return [], []

    src_steps = np.diff(src_indices)
    dst_steps = np.diff(dst_indices)
    # Indices where a new run must begin (offset by 1 past the break).
    breakpoints = np.flatnonzero((src_steps != 1) | (dst_steps != 1)) + 1

    src_groups = [chunk.tolist() for chunk in np.split(src_indices, breakpoints)]
    dst_groups = [chunk.tolist() for chunk in np.split(dst_indices, breakpoints)]
    return src_groups, dst_groups
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment