Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
change
sglang
Commits
88f9c347
Unverified
Commit
88f9c347
authored
Jun 15, 2025
by
Byron Hsu
Committed by
GitHub
Jun 15, 2025
Browse files
[PD] use int32 for kv indices & get num_reserved_decode_tokens from server_args (#7214)
parent
fff10809
Changes
8
Hide whitespace changes
Inline
Side-by-side
Showing
8 changed files
with
24 additions
and
26 deletions
+24
-26
python/sglang/srt/disaggregation/base/conn.py
python/sglang/srt/disaggregation/base/conn.py
+2
-2
python/sglang/srt/disaggregation/common/utils.py
python/sglang/srt/disaggregation/common/utils.py
+2
-2
python/sglang/srt/disaggregation/decode.py
python/sglang/srt/disaggregation/decode.py
+2
-4
python/sglang/srt/disaggregation/fake/conn.py
python/sglang/srt/disaggregation/fake/conn.py
+1
-1
python/sglang/srt/disaggregation/mooncake/conn.py
python/sglang/srt/disaggregation/mooncake/conn.py
+9
-9
python/sglang/srt/disaggregation/nixl/conn.py
python/sglang/srt/disaggregation/nixl/conn.py
+7
-7
python/sglang/srt/disaggregation/prefill.py
python/sglang/srt/disaggregation/prefill.py
+0
-1
python/sglang/srt/managers/scheduler.py
python/sglang/srt/managers/scheduler.py
+1
-0
No files found.
python/sglang/srt/disaggregation/base/conn.py
View file @
88f9c347
...
@@ -70,7 +70,7 @@ class BaseKVSender(ABC):
...
@@ -70,7 +70,7 @@ class BaseKVSender(ABC):
...
...
@
abstractmethod
@
abstractmethod
def
send
(
self
,
kv_indices
:
npt
.
NDArray
[
np
.
int
64
]):
def
send
(
self
,
kv_indices
:
npt
.
NDArray
[
np
.
int
32
]):
"""
"""
Send the kv cache at the given kv indices to the decoder server
Send the kv cache at the given kv indices to the decoder server
"""
"""
...
@@ -102,7 +102,7 @@ class BaseKVReceiver(ABC):
...
@@ -102,7 +102,7 @@ class BaseKVReceiver(ABC):
):
...
):
...
@
abstractmethod
@
abstractmethod
def
init
(
self
,
kv_indices
:
npt
.
NDArray
[
np
.
int
64
],
aux_index
:
Optional
[
int
]
=
None
):
def
init
(
self
,
kv_indices
:
npt
.
NDArray
[
np
.
int
32
],
aux_index
:
Optional
[
int
]
=
None
):
"""
"""
Notify the prefill server about the kv indices and aux index
Notify the prefill server about the kv indices and aux index
"""
"""
...
...
python/sglang/srt/disaggregation/common/utils.py
View file @
88f9c347
...
@@ -26,8 +26,8 @@ class FastQueue:
...
@@ -26,8 +26,8 @@ class FastQueue:
def
group_concurrent_contiguous
(
def
group_concurrent_contiguous
(
src_indices
:
npt
.
NDArray
[
np
.
int
64
],
dst_indices
:
npt
.
NDArray
[
np
.
int
64
]
src_indices
:
npt
.
NDArray
[
np
.
int
32
],
dst_indices
:
npt
.
NDArray
[
np
.
int
32
]
)
->
Tuple
[
List
[
npt
.
NDArray
[
np
.
int
64
]],
List
[
npt
.
NDArray
[
np
.
int
64
]]]:
)
->
Tuple
[
List
[
npt
.
NDArray
[
np
.
int
32
]],
List
[
npt
.
NDArray
[
np
.
int
32
]]]:
"""Vectorised NumPy implementation."""
"""Vectorised NumPy implementation."""
if
src_indices
.
size
==
0
:
if
src_indices
.
size
==
0
:
return
[],
[]
return
[],
[]
...
...
python/sglang/srt/disaggregation/decode.py
View file @
88f9c347
...
@@ -158,6 +158,7 @@ class DecodePreallocQueue:
...
@@ -158,6 +158,7 @@ class DecodePreallocQueue:
bootstrap_port
:
int
,
bootstrap_port
:
int
,
max_total_num_tokens
:
int
,
max_total_num_tokens
:
int
,
prefill_pp_size
:
int
,
prefill_pp_size
:
int
,
num_reserved_decode_tokens
:
int
,
transfer_backend
:
TransferBackend
,
transfer_backend
:
TransferBackend
,
):
):
self
.
req_to_token_pool
=
req_to_token_pool
self
.
req_to_token_pool
=
req_to_token_pool
...
@@ -178,9 +179,7 @@ class DecodePreallocQueue:
...
@@ -178,9 +179,7 @@ class DecodePreallocQueue:
self
.
bootstrap_port
=
bootstrap_port
self
.
bootstrap_port
=
bootstrap_port
self
.
max_total_num_tokens
=
max_total_num_tokens
self
.
max_total_num_tokens
=
max_total_num_tokens
self
.
prefill_pp_size
=
prefill_pp_size
self
.
prefill_pp_size
=
prefill_pp_size
self
.
num_reserved_decode_tokens
=
int
(
self
.
num_reserved_decode_tokens
=
num_reserved_decode_tokens
os
.
environ
.
get
(
"SGLANG_NUM_RESERVED_DECODE_TOKENS"
,
"512"
)
)
self
.
transfer_backend
=
transfer_backend
self
.
transfer_backend
=
transfer_backend
# Queue for requests pending pre-allocation
# Queue for requests pending pre-allocation
self
.
queue
:
List
[
DecodeRequest
]
=
[]
self
.
queue
:
List
[
DecodeRequest
]
=
[]
...
@@ -404,7 +403,6 @@ class DecodePreallocQueue:
...
@@ -404,7 +403,6 @@ class DecodePreallocQueue:
]
]
.
cpu
()
.
cpu
()
.
numpy
()
.
numpy
()
.
astype
(
np
.
int64
)
)
)
decode_req
.
metadata_buffer_index
=
(
decode_req
.
metadata_buffer_index
=
(
...
...
python/sglang/srt/disaggregation/fake/conn.py
View file @
88f9c347
...
@@ -48,7 +48,7 @@ class FakeKVSender(BaseKVSender):
...
@@ -48,7 +48,7 @@ class FakeKVSender(BaseKVSender):
def
send
(
def
send
(
self
,
self
,
kv_indices
:
npt
.
NDArray
[
np
.
int
64
],
kv_indices
:
npt
.
NDArray
[
np
.
int
32
],
):
):
self
.
has_sent
=
True
self
.
has_sent
=
True
logger
.
info
(
f
"FakeKVSender send with kv_indices:
{
kv_indices
}
"
)
logger
.
info
(
f
"FakeKVSender send with kv_indices:
{
kv_indices
}
"
)
...
...
python/sglang/srt/disaggregation/mooncake/conn.py
View file @
88f9c347
...
@@ -59,7 +59,7 @@ class KVTransferError(Exception):
...
@@ -59,7 +59,7 @@ class KVTransferError(Exception):
@
dataclasses
.
dataclass
@
dataclasses
.
dataclass
class
TransferKVChunk
:
class
TransferKVChunk
:
room
:
int
room
:
int
prefill_kv_indices
:
npt
.
NDArray
[
np
.
int
64
]
prefill_kv_indices
:
npt
.
NDArray
[
np
.
int
32
]
index_slice
:
slice
index_slice
:
slice
is_last
:
bool
is_last
:
bool
prefill_aux_index
:
Optional
[
int
]
prefill_aux_index
:
Optional
[
int
]
...
@@ -72,7 +72,7 @@ class TransferInfo:
...
@@ -72,7 +72,7 @@ class TransferInfo:
endpoint
:
str
endpoint
:
str
dst_port
:
int
dst_port
:
int
mooncake_session_id
:
str
mooncake_session_id
:
str
dst_kv_indices
:
npt
.
NDArray
[
np
.
int
64
]
dst_kv_indices
:
npt
.
NDArray
[
np
.
int
32
]
dst_aux_index
:
int
dst_aux_index
:
int
required_dst_info_num
:
int
required_dst_info_num
:
int
is_dummy
:
bool
is_dummy
:
bool
...
@@ -81,10 +81,10 @@ class TransferInfo:
...
@@ -81,10 +81,10 @@ class TransferInfo:
def
from_zmq
(
cls
,
msg
:
List
[
bytes
]):
def
from_zmq
(
cls
,
msg
:
List
[
bytes
]):
if
msg
[
4
]
==
b
""
and
msg
[
5
]
==
b
""
:
if
msg
[
4
]
==
b
""
and
msg
[
5
]
==
b
""
:
is_dummy
=
True
is_dummy
=
True
dst_kv_indices
=
np
.
array
([],
dtype
=
np
.
int
64
)
dst_kv_indices
=
np
.
array
([],
dtype
=
np
.
int
32
)
dst_aux_index
=
None
dst_aux_index
=
None
else
:
else
:
dst_kv_indices
=
np
.
frombuffer
(
msg
[
4
],
dtype
=
np
.
int
64
)
dst_kv_indices
=
np
.
frombuffer
(
msg
[
4
],
dtype
=
np
.
int
32
)
dst_aux_index
=
int
(
msg
[
5
].
decode
(
"ascii"
))
dst_aux_index
=
int
(
msg
[
5
].
decode
(
"ascii"
))
is_dummy
=
False
is_dummy
=
False
return
cls
(
return
cls
(
...
@@ -233,9 +233,9 @@ class MooncakeKVManager(BaseKVManager):
...
@@ -233,9 +233,9 @@ class MooncakeKVManager(BaseKVManager):
def
send_kvcache
(
def
send_kvcache
(
self
,
self
,
mooncake_session_id
:
str
,
mooncake_session_id
:
str
,
prefill_kv_indices
:
npt
.
NDArray
[
np
.
int
64
],
prefill_kv_indices
:
npt
.
NDArray
[
np
.
int
32
],
dst_kv_ptrs
:
list
[
int
],
dst_kv_ptrs
:
list
[
int
],
dst_kv_indices
:
npt
.
NDArray
[
np
.
int
64
],
dst_kv_indices
:
npt
.
NDArray
[
np
.
int
32
],
executor
:
concurrent
.
futures
.
ThreadPoolExecutor
,
executor
:
concurrent
.
futures
.
ThreadPoolExecutor
,
):
):
# Group by indices
# Group by indices
...
@@ -545,7 +545,7 @@ class MooncakeKVManager(BaseKVManager):
...
@@ -545,7 +545,7 @@ class MooncakeKVManager(BaseKVManager):
def
add_transfer_request
(
def
add_transfer_request
(
self
,
self
,
bootstrap_room
:
int
,
bootstrap_room
:
int
,
kv_indices
:
npt
.
NDArray
[
np
.
int
64
],
kv_indices
:
npt
.
NDArray
[
np
.
int
32
],
index_slice
:
slice
,
index_slice
:
slice
,
is_last
:
bool
,
is_last
:
bool
,
aux_index
:
Optional
[
int
]
=
None
,
aux_index
:
Optional
[
int
]
=
None
,
...
@@ -701,7 +701,7 @@ class MooncakeKVSender(BaseKVSender):
...
@@ -701,7 +701,7 @@ class MooncakeKVSender(BaseKVSender):
def
send
(
def
send
(
self
,
self
,
kv_indices
:
npt
.
NDArray
[
np
.
int
64
],
kv_indices
:
npt
.
NDArray
[
np
.
int
32
],
):
):
index_slice
=
slice
(
self
.
curr_idx
,
self
.
curr_idx
+
len
(
kv_indices
))
index_slice
=
slice
(
self
.
curr_idx
,
self
.
curr_idx
+
len
(
kv_indices
))
self
.
curr_idx
+=
len
(
kv_indices
)
self
.
curr_idx
+=
len
(
kv_indices
)
...
@@ -971,7 +971,7 @@ class MooncakeKVReceiver(BaseKVReceiver):
...
@@ -971,7 +971,7 @@ class MooncakeKVReceiver(BaseKVReceiver):
cls
.
_socket_locks
[
endpoint
]
=
threading
.
Lock
()
cls
.
_socket_locks
[
endpoint
]
=
threading
.
Lock
()
return
cls
.
_socket_cache
[
endpoint
],
cls
.
_socket_locks
[
endpoint
]
return
cls
.
_socket_cache
[
endpoint
],
cls
.
_socket_locks
[
endpoint
]
def
init
(
self
,
kv_indices
:
npt
.
NDArray
[
np
.
int
64
],
aux_index
:
Optional
[
int
]
=
None
):
def
init
(
self
,
kv_indices
:
npt
.
NDArray
[
np
.
int
32
],
aux_index
:
Optional
[
int
]
=
None
):
for
bootstrap_info
in
self
.
bootstrap_infos
:
for
bootstrap_info
in
self
.
bootstrap_infos
:
self
.
prefill_server_url
=
(
self
.
prefill_server_url
=
(
f
"
{
bootstrap_info
[
'rank_ip'
]
}
:
{
bootstrap_info
[
'rank_port'
]
}
"
f
"
{
bootstrap_info
[
'rank_ip'
]
}
:
{
bootstrap_info
[
'rank_port'
]
}
"
...
...
python/sglang/srt/disaggregation/nixl/conn.py
View file @
88f9c347
...
@@ -44,7 +44,7 @@ class TransferInfo:
...
@@ -44,7 +44,7 @@ class TransferInfo:
agent_metadata
:
bytes
agent_metadata
:
bytes
agent_name
:
str
agent_name
:
str
dst_kv_ptrs
:
list
[
int
]
dst_kv_ptrs
:
list
[
int
]
dst_kv_indices
:
npt
.
NDArray
[
np
.
int
64
]
dst_kv_indices
:
npt
.
NDArray
[
np
.
int
32
]
dst_aux_ptrs
:
list
[
int
]
dst_aux_ptrs
:
list
[
int
]
dst_aux_index
:
int
dst_aux_index
:
int
dst_gpu_id
:
int
dst_gpu_id
:
int
...
@@ -62,7 +62,7 @@ class TransferInfo:
...
@@ -62,7 +62,7 @@ class TransferInfo:
agent_metadata
=
msg
[
3
],
agent_metadata
=
msg
[
3
],
agent_name
=
msg
[
4
].
decode
(
"ascii"
),
agent_name
=
msg
[
4
].
decode
(
"ascii"
),
dst_kv_ptrs
=
list
(
struct
.
unpack
(
f
"
{
len
(
msg
[
5
])
//
8
}
Q"
,
msg
[
5
])),
dst_kv_ptrs
=
list
(
struct
.
unpack
(
f
"
{
len
(
msg
[
5
])
//
8
}
Q"
,
msg
[
5
])),
dst_kv_indices
=
np
.
frombuffer
(
msg
[
6
],
dtype
=
np
.
int
64
),
dst_kv_indices
=
np
.
frombuffer
(
msg
[
6
],
dtype
=
np
.
int
32
),
dst_aux_ptrs
=
list
(
struct
.
unpack
(
f
"
{
len
(
msg
[
7
])
//
8
}
Q"
,
msg
[
7
])),
dst_aux_ptrs
=
list
(
struct
.
unpack
(
f
"
{
len
(
msg
[
7
])
//
8
}
Q"
,
msg
[
7
])),
dst_aux_index
=
int
(
msg
[
8
].
decode
(
"ascii"
)),
dst_aux_index
=
int
(
msg
[
8
].
decode
(
"ascii"
)),
dst_gpu_id
=
int
(
msg
[
9
].
decode
(
"ascii"
)),
dst_gpu_id
=
int
(
msg
[
9
].
decode
(
"ascii"
)),
...
@@ -162,9 +162,9 @@ class NixlKVManager(CommonKVManager):
...
@@ -162,9 +162,9 @@ class NixlKVManager(CommonKVManager):
def
send_kvcache
(
def
send_kvcache
(
self
,
self
,
peer_name
:
str
,
peer_name
:
str
,
prefill_kv_indices
:
npt
.
NDArray
[
np
.
int
64
],
prefill_kv_indices
:
npt
.
NDArray
[
np
.
int
32
],
dst_kv_ptrs
:
list
[
int
],
dst_kv_ptrs
:
list
[
int
],
dst_kv_indices
:
npt
.
NDArray
[
np
.
int
64
],
dst_kv_indices
:
npt
.
NDArray
[
np
.
int
32
],
dst_gpu_id
:
int
,
dst_gpu_id
:
int
,
notif
:
str
,
notif
:
str
,
):
):
...
@@ -246,7 +246,7 @@ class NixlKVManager(CommonKVManager):
...
@@ -246,7 +246,7 @@ class NixlKVManager(CommonKVManager):
def
add_transfer_request
(
def
add_transfer_request
(
self
,
self
,
bootstrap_room
:
int
,
bootstrap_room
:
int
,
kv_indices
:
npt
.
NDArray
[
np
.
int
64
],
kv_indices
:
npt
.
NDArray
[
np
.
int
32
],
index_slice
:
slice
,
index_slice
:
slice
,
is_last
:
bool
,
is_last
:
bool
,
chunk_id
:
int
,
chunk_id
:
int
,
...
@@ -373,7 +373,7 @@ class NixlKVSender(BaseKVSender):
...
@@ -373,7 +373,7 @@ class NixlKVSender(BaseKVSender):
def
send
(
def
send
(
self
,
self
,
kv_indices
:
npt
.
NDArray
[
np
.
int
64
],
kv_indices
:
npt
.
NDArray
[
np
.
int
32
],
):
):
index_slice
=
slice
(
self
.
curr_idx
,
self
.
curr_idx
+
len
(
kv_indices
))
index_slice
=
slice
(
self
.
curr_idx
,
self
.
curr_idx
+
len
(
kv_indices
))
self
.
curr_idx
+=
len
(
kv_indices
)
self
.
curr_idx
+=
len
(
kv_indices
)
...
@@ -417,7 +417,7 @@ class NixlKVReceiver(CommonKVReceiver):
...
@@ -417,7 +417,7 @@ class NixlKVReceiver(CommonKVReceiver):
self
.
started_transfer
=
False
self
.
started_transfer
=
False
super
().
__init__
(
mgr
,
bootstrap_addr
,
bootstrap_room
,
data_parallel_rank
)
super
().
__init__
(
mgr
,
bootstrap_addr
,
bootstrap_room
,
data_parallel_rank
)
def
init
(
self
,
kv_indices
:
npt
.
NDArray
[
np
.
int
64
],
aux_index
:
Optional
[
int
]
=
None
):
def
init
(
self
,
kv_indices
:
npt
.
NDArray
[
np
.
int
32
],
aux_index
:
Optional
[
int
]
=
None
):
for
bootstrap_info
in
self
.
bootstrap_infos
:
for
bootstrap_info
in
self
.
bootstrap_infos
:
self
.
prefill_server_url
=
(
self
.
prefill_server_url
=
(
f
"
{
bootstrap_info
[
'rank_ip'
]
}
:
{
bootstrap_info
[
'rank_port'
]
}
"
f
"
{
bootstrap_info
[
'rank_ip'
]
}
:
{
bootstrap_info
[
'rank_port'
]
}
"
...
...
python/sglang/srt/disaggregation/prefill.py
View file @
88f9c347
...
@@ -576,7 +576,6 @@ class SchedulerDisaggregationPrefillMixin:
...
@@ -576,7 +576,6 @@ class SchedulerDisaggregationPrefillMixin:
self
.
req_to_token_pool
.
req_to_token
[
req
.
req_pool_idx
,
start_idx
:
end_idx
]
self
.
req_to_token_pool
.
req_to_token
[
req
.
req_pool_idx
,
start_idx
:
end_idx
]
.
cpu
()
.
cpu
()
.
numpy
()
.
numpy
()
.
astype
(
np
.
int64
)
)
)
req
.
start_send_idx
=
end_idx
req
.
start_send_idx
=
end_idx
if
last_chunk
:
if
last_chunk
:
...
...
python/sglang/srt/managers/scheduler.py
View file @
88f9c347
...
@@ -656,6 +656,7 @@ class Scheduler(
...
@@ -656,6 +656,7 @@ class Scheduler(
bootstrap_port
=
self
.
server_args
.
disaggregation_bootstrap_port
,
bootstrap_port
=
self
.
server_args
.
disaggregation_bootstrap_port
,
max_total_num_tokens
=
self
.
max_total_num_tokens
,
max_total_num_tokens
=
self
.
max_total_num_tokens
,
prefill_pp_size
=
self
.
server_args
.
disaggregation_prefill_pp
,
prefill_pp_size
=
self
.
server_args
.
disaggregation_prefill_pp
,
num_reserved_decode_tokens
=
self
.
server_args
.
num_reserved_decode_tokens
,
transfer_backend
=
self
.
transfer_backend
,
transfer_backend
=
self
.
transfer_backend
,
)
)
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment