sglang · Commit db0cc57e

Authored Jun 14, 2025 by Byron Hsu; committed via GitHub on Jun 14, 2025.
Parent commit: 349bb2c9

[PD] Support decode retract and update decode.py (#7196)

Showing 6 changed files with 376 additions and 41 deletions (+376, -41):

  python/sglang/srt/disaggregation/decode.py     +186  -38
  python/sglang/srt/disaggregation/prefill.py      +2   -0
  python/sglang/srt/managers/schedule_batch.py    +11   -0
  python/sglang/srt/managers/scheduler.py          +9   -3
  python/sglang/srt/mem_cache/memory_pool.py      +41   -0
  test/srt/test_disaggregation.py                +127   -0
python/sglang/srt/disaggregation/decode.py

@@ -31,7 +31,7 @@ import numpy as np
 import torch
 from torch.distributed import ProcessGroup

-from sglang.srt.disaggregation.base import BaseKVManager, BaseKVReceiver, KVArgs, KVPoll
+from sglang.srt.disaggregation.base import BaseKVManager, BaseKVReceiver, KVPoll
 from sglang.srt.disaggregation.utils import (
     FAKE_BOOTSTRAP_HOST,
     DisaggregationMode,

@@ -45,9 +45,17 @@ from sglang.srt.disaggregation.utils import (
     poll_and_all_reduce,
     prepare_abort,
 )
-from sglang.srt.managers.schedule_batch import FINISH_ABORT, ScheduleBatch
+from sglang.srt.managers.schedule_batch import (
+    FINISH_ABORT,
+    ScheduleBatch,
+    global_server_args_dict,
+)
 from sglang.srt.mem_cache.base_prefix_cache import BasePrefixCache
-from sglang.srt.mem_cache.memory_pool import ReqToTokenPool, TokenToKVPoolAllocator
+from sglang.srt.mem_cache.memory_pool import (
+    KVCache,
+    ReqToTokenPool,
+    TokenToKVPoolAllocator,
+)
 from sglang.srt.model_executor.forward_batch_info import ForwardMode
 from sglang.srt.torch_memory_saver_adapter import TorchMemorySaverAdapter
@@ -145,7 +153,11 @@ class DecodePreallocQueue:
         gloo_group: ProcessGroup,
         tp_rank: int,
         tp_size: int,
+        dp_size: int,
         gpu_id: int,
         bootstrap_port: int,
+        max_total_num_tokens: int,
+        prefill_pp_size: int,
         transfer_backend: TransferBackend,
     ):
         self.req_to_token_pool = req_to_token_pool

@@ -161,25 +173,35 @@ class DecodePreallocQueue:
         self.gloo_group = gloo_group
         self.tp_rank = tp_rank
         self.tp_size = tp_size
+        self.dp_size = dp_size
         self.gpu_id = gpu_id
         self.bootstrap_port = bootstrap_port
+        self.max_total_num_tokens = max_total_num_tokens
+        self.prefill_pp_size = prefill_pp_size
         self.num_reserved_decode_tokens = int(
             os.environ.get("SGLANG_NUM_RESERVED_DECODE_TOKENS", "512")
         )
         self.transfer_backend = transfer_backend
         # Queue for requests pending pre-allocation
         self.queue: List[DecodeRequest] = []
+        self.retracted_queue: List[Req] = []
         self.kv_manager = self._init_kv_manager()

     def _init_kv_manager(self) -> BaseKVManager:
-        kv_args = KVArgs()
-        kv_args.engine_rank = self.tp_rank
+        kv_args_class = get_kv_class(self.transfer_backend, KVClassType.KVARGS)
+        kv_args = kv_args_class()
+        attn_tp_size = self.tp_size // self.dp_size
+        kv_args.engine_rank = self.tp_rank % (attn_tp_size)
+        kv_args.decode_tp_size = attn_tp_size
+        kv_args.prefill_pp_size = self.prefill_pp_size
         kv_data_ptrs, kv_data_lens, kv_item_lens = (
             self.token_to_kv_pool.get_contiguous_buf_infos()
         )
         if self.draft_token_to_kv_pool is not None:
             # We should also transfer draft model kv cache. The indices are
             # always shared with a target model.
             draft_kv_data_ptrs, draft_kv_data_lens, draft_kv_item_lens = (
                 self.draft_token_to_kv_pool.get_contiguous_buf_infos()
             )
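The queue now reads SGLANG_NUM_RESERVED_DECODE_TOKENS (default 512) to decide how much KV-cache headroom to keep per request before admitting or resuming work. A minimal, self-contained sketch of that lookup, for illustration only; the variable name and default come from the hunk above, the helper function is hypothetical:

import os

# Mirrors the lookup in DecodePreallocQueue.__init__ above: the environment
# variable overrides the default of 512 reserved decode tokens per request.
def reserved_decode_tokens(default: int = 512) -> int:
    return int(os.environ.get("SGLANG_NUM_RESERVED_DECODE_TOKENS", str(default)))

# Example: export SGLANG_NUM_RESERVED_DECODE_TOKENS=1024 before launching a
# decode worker to keep more headroom and retract less often.
print(reserved_decode_tokens())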
@@ -194,6 +216,7 @@ class DecodePreallocQueue:
         kv_args.aux_data_ptrs, kv_args.aux_data_lens, kv_args.aux_item_lens = (
             self.metadata_buffers.get_buf_infos()
         )
         kv_args.ib_device = self.scheduler.server_args.disaggregation_ib_device
+        kv_args.gpu_id = self.scheduler.gpu_id
         kv_manager_class = get_kv_class(self.transfer_backend, KVClassType.MANAGER)
@@ -205,27 +228,83 @@ class DecodePreallocQueue:
         )
         return kv_manager

-    def add(self, req: Req) -> None:
+    def add(self, req: Req, is_retracted: bool = False) -> None:
         """Add a request to the pending queue."""
-        if req.bootstrap_host == FAKE_BOOTSTRAP_HOST:
-            # Fake transfer for warmup reqs
-            kv_receiver_class = get_kv_class(TransferBackend.FAKE, KVClassType.RECEIVER)
-        else:
-            kv_receiver_class = get_kv_class(self.transfer_backend, KVClassType.RECEIVER)
-        kv_receiver = kv_receiver_class(
-            mgr=self.kv_manager,
-            bootstrap_addr=f"{req.bootstrap_host}:{req.bootstrap_port}",
-            bootstrap_room=req.bootstrap_room,
-        )
-        self.queue.append(DecodeRequest(req=req, kv_receiver=kv_receiver))
-
-    def extend(self, reqs: List[Req]) -> None:
+        if self._check_if_req_exceed_kv_capacity(req):
+            return
+
+        if is_retracted:
+            self.retracted_queue.append(req)
+        else:
+            if req.bootstrap_host == FAKE_BOOTSTRAP_HOST:
+                # Fake transfer for warmup reqs
+                kv_receiver_class = get_kv_class(TransferBackend.FAKE, KVClassType.RECEIVER)
+            else:
+                kv_receiver_class = get_kv_class(self.transfer_backend, KVClassType.RECEIVER)
+            kv_receiver = kv_receiver_class(
+                mgr=self.kv_manager,
+                bootstrap_addr=f"{req.bootstrap_host}:{req.bootstrap_port}",
+                bootstrap_room=req.bootstrap_room,
+                data_parallel_rank=req.data_parallel_rank,
+            )
+            self.queue.append(
+                DecodeRequest(req=req, kv_receiver=kv_receiver, waiting_for_input=False)
+            )
+
+    def _check_if_req_exceed_kv_capacity(self, req: Req) -> bool:
+        if len(req.origin_input_ids) > self.max_total_num_tokens:
+            message = f"Request {req.rid} exceeds the maximum number of tokens: {len(req.origin_input_ids)} > {self.max_total_num_tokens}"
+            logger.error(message)
+            prepare_abort(req, message)
+            self.scheduler.stream_output([req], req.return_logprob)
+            return True
+        return False
+
+    def extend(self, reqs: List[Req], is_retracted: bool = False) -> None:
         """Add a request to the pending queue."""
         for req in reqs:
-            self.add(req)
+            self.add(req, is_retracted=is_retracted)
+
+    def resume_retracted_reqs(self) -> List[Req]:
+        # TODO refactor the scheduling part, reuse with the unified engine logic as much as possible
+        # allocate memory
+        resumed_reqs = []
+        indices_to_remove = set()
+        allocatable_tokens = self._allocatable_tokens(count_retracted=False)
+
+        for i, req in enumerate(self.retracted_queue):
+            if self.req_to_token_pool.available_size() <= 0:
+                break
+
+            required_tokens_for_request = (
+                len(req.origin_input_ids)
+                + len(req.output_ids)
+                + self.num_reserved_decode_tokens
+            )
+            if required_tokens_for_request > allocatable_tokens:
+                break
+
+            resumed_reqs.append(req)
+            indices_to_remove.add(i)
+            req.is_retracted = False
+            self._pre_alloc(req)
+            allocatable_tokens -= required_tokens_for_request
+            # load from cpu, release the cpu copy
+            req.load_kv_cache(self.req_to_token_pool, self.token_to_kv_pool_allocator)
+
+        self.retracted_queue = [
+            entry
+            for i, entry in enumerate(self.retracted_queue)
+            if i not in indices_to_remove
+        ]
+
+        return resumed_reqs

     def _update_handshake_waiters(self) -> None:
         if not self.queue:
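resume_retracted_reqs only wakes a retracted request when its full prompt, the tokens it has already generated, and the reserved decode headroom all fit in the allocatable budget. A small self-contained sketch of that admission test; the field names follow the diff, while the toy request class and the numbers are hypothetical:

from dataclasses import dataclass, field
from typing import List

@dataclass
class ToyReq:  # hypothetical stand-in for sglang's Req
    origin_input_ids: List[int]
    output_ids: List[int] = field(default_factory=list)

def can_resume(req: ToyReq, allocatable_tokens: int, num_reserved_decode_tokens: int = 512) -> bool:
    # Same arithmetic as resume_retracted_reqs: prompt + generated-so-far + reserved headroom.
    required = len(req.origin_input_ids) + len(req.output_ids) + num_reserved_decode_tokens
    return required <= allocatable_tokens

req = ToyReq(origin_input_ids=list(range(1000)), output_ids=list(range(200)))
print(can_resume(req, allocatable_tokens=2048))  # True: 1000 + 200 + 512 = 1712 <= 2048
print(can_resume(req, allocatable_tokens=1500))  # False: 1712 > 1500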
@@ -255,6 +334,8 @@ class DecodePreallocQueue:
                     error_message,
                     status_code=HTTPStatus.INTERNAL_SERVER_ERROR,
                 )
+            else:
+                raise ValueError(f"Unexpected poll case: {poll}")

     def pop_preallocated(self) -> List[DecodeRequest]:
         """Pop the preallocated requests from the pending queue (FIFO)."""

@@ -262,8 +343,16 @@ class DecodePreallocQueue:
         preallocated_reqs = []
         indices_to_remove = set()
-        allocatable_tokens = self._allocatable_tokens()
+        # We need to make sure that the sum of inflight tokens and allocatable tokens is greater than the maximum input+output length of each inflight request.
+        # Otherwise one request may run decode out of memory while all other requests sit in the transfer queue and cannot be retracted.
+        retractable_tokens = sum(
+            len(r.origin_input_ids) + len(r.output_ids)
+            for r in self.scheduler.running_batch.reqs
+        )
+        allocatable_tokens = self._allocatable_tokens(
+            retractable_tokens=retractable_tokens, count_retracted=True
+        )

         # First, remove all failed requests from the queue
         for i, decode_req in enumerate(self.queue):
             if isinstance(decode_req.req.finished_reason, FINISH_ABORT):

@@ -272,6 +361,7 @@ class DecodePreallocQueue:
                 )
                 indices_to_remove.add(i)

+        # Then, preallocate the remaining requests if possible
         for i, decode_req in enumerate(self.queue):
             if i in indices_to_remove:
                 continue
@@ -285,10 +375,23 @@ class DecodePreallocQueue:
             if self.req_to_metadata_buffer_idx_allocator.available_size() <= 0:
                 break

             # Memory estimation: don't add if the projected memory cannot be met
             # TODO: add new_token ratio
+            origin_input_len = len(decode_req.req.origin_input_ids)
             required_tokens_for_request = (
-                len(decode_req.req.origin_input_ids) + self.num_reserved_decode_tokens
+                origin_input_len + self.num_reserved_decode_tokens
             )

-            if required_tokens_for_request > allocatable_tokens:
+            if (
+                max(
+                    required_tokens_for_request,
+                    origin_input_len
+                    + decode_req.req.sampling_params.max_new_tokens
+                    - retractable_tokens,
+                )
+                > allocatable_tokens
+            ):
                 break
@@ -321,15 +424,35 @@ class DecodePreallocQueue:
         return preallocated_reqs

-    def _allocatable_tokens(self) -> int:
-        allocatable_tokens = (
-            self.token_to_kv_pool_allocator.available_size()
-            - self.num_reserved_decode_tokens
-            * (
-                len(self.scheduler.running_batch.reqs)
-                + len(self.transfer_queue.queue)
-                + len(self.scheduler.waiting_queue)
-            )
-        )
+    def _allocatable_tokens(
+        self, retractable_tokens: Optional[int] = None, count_retracted: bool = True
+    ) -> int:
+        need_space_for_single_req = (
+            max(
+                [
+                    x.sampling_params.max_new_tokens
+                    + len(x.origin_input_ids)
+                    - retractable_tokens
+                    for x in self.scheduler.running_batch.reqs
+                ]
+            )
+            if retractable_tokens is not None
+            and len(self.scheduler.running_batch.reqs) > 0
+            else 0
+        )
+
+        available_size = self.token_to_kv_pool_allocator.available_size()
+
+        allocatable_tokens = available_size - max(
+            # preserve some space for future decode
+            self.num_reserved_decode_tokens
+            * (
+                len(self.scheduler.running_batch.reqs)
+                + len(self.transfer_queue.queue)
+                + len(self.scheduler.waiting_queue)
+            ),
+            # make sure each request can finish if it reaches max_tokens with all other requests retracted
+            need_space_for_single_req,
+        )

         # Note: if the last fake extend just finishes, and we enter `pop_preallocated` immediately in the next iteration
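_allocatable_tokens now subtracts the larger of two safety margins from the free pool: the per-request decode reservation across running, transferring, and waiting requests, and the space a single request would still need to reach max_new_tokens if every other request were retracted. A worked example with made-up numbers; the formula follows the hunk above, the figures are purely illustrative:

def allocatable_tokens(
    available_size: int,
    num_reserved_decode_tokens: int,
    num_running: int,
    num_transfer: int,
    num_waiting: int,
    need_space_for_single_req: int,
) -> int:
    # Mirrors the new computation: keep the larger of the two margins free.
    return available_size - max(
        num_reserved_decode_tokens * (num_running + num_transfer + num_waiting),
        need_space_for_single_req,
    )

# Example: 100k free KV slots, 512 reserved per request, 60 requests in flight,
# and one request that could still need 40k tokens after retracting everyone else.
print(allocatable_tokens(100_000, 512, 40, 10, 10, 40_000))
# 100000 - max(512 * 60 = 30720, 40000) = 60000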
@@ -342,15 +465,27 @@ class DecodePreallocQueue:
             self.scheduler.last_batch.reqs
         )

+        if count_retracted:
+            allocatable_tokens -= sum(
+                [
+                    len(req.origin_input_ids)
+                    + len(req.output_ids)
+                    + self.num_reserved_decode_tokens
+                    for req in self.retracted_queue
+                ]
+            )
         return allocatable_tokens

     def _pre_alloc(self, req: Req) -> torch.Tensor:
         """Pre-allocate the memory for req_to_token and token_kv_pool"""
         req_pool_indices = self.req_to_token_pool.alloc(1)
-        assert req_pool_indices is not None
+        assert (
+            req_pool_indices is not None
+        ), "req_pool_indices is full! There is a bug in memory estimation."
         req.req_pool_idx = req_pool_indices[0]
         if self.token_to_kv_pool_allocator.page_size == 1:
             kv_loc = self.token_to_kv_pool_allocator.alloc(
                 len(req.origin_input_ids) + max(len(req.output_ids) - 1, 0)
             )
@@ -375,7 +510,10 @@ class DecodePreallocQueue:
             ),
             extend_num_tokens=num_tokens,
         )
-        assert kv_loc is not None
+        assert (
+            kv_loc is not None
+        ), "KV cache is full! There is a bug in memory estimation."

         self.req_to_token_pool.write((req.req_pool_idx, slice(0, len(kv_loc))), kv_loc)
@@ -395,6 +533,7 @@ class DecodeTransferQueue:
         self,
         gloo_group: ProcessGroup,
         req_to_metadata_buffer_idx_allocator: ReqToMetadataIdxAllocator,
+        tp_rank: int,
         metadata_buffers: MetadataBuffers,
         scheduler: Scheduler,
         tree_cache: BasePrefixCache,

@@ -402,6 +541,7 @@ class DecodeTransferQueue:
         self.queue: List[DecodeRequest] = []
         self.gloo_group = gloo_group
         self.req_to_metadata_buffer_idx_allocator = req_to_metadata_buffer_idx_allocator
+        self.tp_rank = tp_rank
         self.metadata_buffers = metadata_buffers
         self.scheduler = scheduler
         self.tree_cache = tree_cache
@@ -412,10 +552,9 @@ class DecodeTransferQueue:
     def extend(self, decode_reqs: List[DecodeRequest]) -> None:
         self.queue.extend(decode_reqs)

-    def pop_transferred(self) -> List[DecodeRequest]:
+    def pop_transferred(self) -> List[Req]:
         if not self.queue:
             return []

         polls = poll_and_all_reduce(
             [decode_req.kv_receiver for decode_req in self.queue], self.gloo_group
         )
@@ -424,7 +563,7 @@ class DecodeTransferQueue:
         indices_to_remove = set()
         for i, (decode_req, poll) in enumerate(zip(self.queue, polls)):
             if poll == KVPoll.Failed:
-                error_message = f"Decode transfer failed for request rank={self.scheduler.tp_rank} {decode_req.req.rid=} {decode_req.req.bootstrap_room=}"
+                error_message = f"Decode transfer failed for request rank={self.tp_rank} {decode_req.req.rid=} {decode_req.req.bootstrap_room=}"
                 try:
                     decode_req.kv_receiver.failure_exception()
                 except Exception as e:
@@ -543,7 +682,8 @@ class SchedulerDisaggregationDecodeMixin:
                 batch, _ = self._prepare_idle_batch_and_run(None)

             if batch is None and (
-                len(self.disagg_decode_transfer_queue.queue)
+                len(self.waiting_queue)
+                + len(self.disagg_decode_transfer_queue.queue)
                 + len(self.disagg_decode_prealloc_queue.queue)
                 == 0
             ):

@@ -622,7 +762,8 @@ class SchedulerDisaggregationDecodeMixin:
                 self.process_batch_result(tmp_batch, tmp_result)

             if batch is None and (
-                len(self.disagg_decode_transfer_queue.queue)
+                len(self.waiting_queue)
+                + len(self.disagg_decode_transfer_queue.queue)
                 + len(self.disagg_decode_prealloc_queue.queue)
                 == 0
             ):
@@ -716,6 +857,13 @@ class SchedulerDisaggregationDecodeMixin:
         return new_batch

     def process_decode_queue(self: Scheduler):
+        # Try to resume retracted requests if there is enough space for another
+        # `num_reserved_decode_tokens` decode steps.
+        resumed_reqs = self.disagg_decode_prealloc_queue.resume_retracted_reqs()
+        self.waiting_queue.extend(resumed_reqs)
+        if len(self.disagg_decode_prealloc_queue.retracted_queue) > 0:
+            # If there are still retracted requests, do not allocate new requests.
+            return
+
         req_conns = self.disagg_decode_prealloc_queue.pop_preallocated()
         self.disagg_decode_transfer_queue.extend(req_conns)
         alloc_reqs = (
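process_decode_queue now drains the retracted queue before admitting new work: resumed requests go back to the waiting queue, and as long as anything is still retracted, no new pre-allocations happen. A control-flow sketch with stubbed queues; the method and queue names mirror the diff, the stub classes and request ids are hypothetical:

from typing import List

class ToyPreallocQueue:  # hypothetical stand-in for DecodePreallocQueue
    def __init__(self):
        self.retracted_queue: List[str] = ["req-a", "req-b"]

    def resume_retracted_reqs(self) -> List[str]:
        # Pretend only one retracted request fits back into the KV cache per round.
        return [self.retracted_queue.pop(0)] if self.retracted_queue else []

    def pop_preallocated(self) -> List[str]:
        return ["req-new"]

def process_decode_queue(prealloc: ToyPreallocQueue, waiting_queue: List[str]) -> None:
    # Same ordering as the hunk above: resume first, and skip new allocations
    # while any request is still retracted.
    waiting_queue.extend(prealloc.resume_retracted_reqs())
    if len(prealloc.retracted_queue) > 0:
        return
    waiting_queue.extend(prealloc.pop_preallocated())

waiting: List[str] = []
q = ToyPreallocQueue()
process_decode_queue(q, waiting)  # resumes req-a; req-b still retracted -> no prealloc
print(waiting)                    # ['req-a']
process_decode_queue(q, waiting)  # resumes req-b; nothing retracted -> prealloc proceeds
print(waiting)                    # ['req-a', 'req-b', 'req-new']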
python/sglang/srt/disaggregation/prefill.py

@@ -25,6 +25,7 @@ from collections import deque
 from http import HTTPStatus
 from typing import TYPE_CHECKING, List, Optional

+import numpy as np
 import torch

 from sglang.srt.disaggregation.base import BaseKVManager, KVPoll

@@ -575,6 +576,7 @@ class SchedulerDisaggregationPrefillMixin:
                 self.req_to_token_pool.req_to_token[req.req_pool_idx, start_idx:end_idx]
                 .cpu()
                 .numpy()
+                .astype(np.int64)
             )
             req.start_send_idx = end_idx
             if last_chunk:
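The prefill side now widens the per-request KV indices to int64 on the host before handing them to the transfer layer. A small illustration of the same conversion chain using only torch and numpy; the tensor contents and slice bounds are made up:

import numpy as np
import torch

# A toy req_to_token row; in sglang this table lives on the GPU and is typically int32.
req_to_token = torch.arange(32, dtype=torch.int32).reshape(2, 16)

start_idx, end_idx = 4, 12
kv_indices = (
    req_to_token[0, start_idx:end_idx]
    .cpu()
    .numpy()
    .astype(np.int64)  # mirrors the .astype(np.int64) step in the hunk above
)
print(kv_indices.dtype, kv_indices)  # int64 [ 4  5  6  7  8  9 10 11]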
python/sglang/srt/managers/schedule_batch.py

@@ -1415,6 +1415,11 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
             req = self.reqs[idx]
             retracted_reqs.append(req)

+            if server_args.disaggregation_mode == "decode":
+                req.offload_kv_cache(
+                    self.req_to_token_pool, self.token_to_kv_pool_allocator
+                )
+
             if isinstance(self.tree_cache, ChunkCache):
                 # ChunkCache does not have eviction
                 token_indices = self.req_to_token_pool.req_to_token[

@@ -1446,6 +1451,12 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
             req.reset_for_retract()

+        if len(retracted_reqs) == 0:
+            # Corner case: only one request left
+            raise ValueError(
+                "Failed to retract any request. No space left for only one request."
+            )
+
         self.filter_batch(keep_indices=sorted_indices)

         # Reqs in batch are filtered
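On a decode server, retract_decode now calls req.offload_kv_cache(...) before the request's slots are freed, and resume_retracted_reqs later calls req.load_kv_cache(...). Neither method body is part of this diff; the sketch below only illustrates how such a pair could be built on the get_cpu_copy / load_cpu_copy hooks added in memory_pool.py, with hypothetical names, signatures, and a toy allocator:

from typing import Any, Dict, List

class ToyAllocator:
    """Hypothetical allocator exposing the two hooks added in memory_pool.py."""

    def __init__(self):
        self.gpu_store: Dict[int, str] = {}  # index -> value, stands in for the KV buffers

    def get_cpu_copy(self, indices: List[int]) -> Any:
        return {i: self.gpu_store[i] for i in indices}

    def load_cpu_copy(self, cpu_copy: Any, indices: List[int]) -> None:
        for i in indices:
            self.gpu_store[i] = cpu_copy[i]

class ToyReq:
    """Hypothetical request: offload keeps a CPU copy, load restores it."""

    def __init__(self, kv_indices: List[int]):
        self.kv_indices = kv_indices
        self.kv_cache_cpu = None

    def offload_kv_cache(self, allocator: ToyAllocator) -> None:
        self.kv_cache_cpu = allocator.get_cpu_copy(self.kv_indices)

    def load_kv_cache(self, allocator: ToyAllocator) -> None:
        allocator.load_cpu_copy(self.kv_cache_cpu, self.kv_indices)
        self.kv_cache_cpu = None  # release the CPU copy, as the decode.py comment says

alloc = ToyAllocator()
alloc.gpu_store = {0: "k0/v0", 1: "k1/v1"}
req = ToyReq(kv_indices=[0, 1])
req.offload_kv_cache(alloc)  # retract: copy to host
alloc.gpu_store.clear()      # the slots can now be reused by other requests
req.load_kv_cache(alloc)     # resume: copy back
print(alloc.gpu_store)       # {0: 'k0/v0', 1: 'k1/v1'}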
python/sglang/srt/managers/scheduler.py

@@ -628,6 +628,7 @@ class Scheduler(
             self.disagg_decode_transfer_queue = DecodeTransferQueue(
                 gloo_group=self.attn_tp_cpu_group,
                 req_to_metadata_buffer_idx_allocator=self.req_to_metadata_buffer_idx_allocator,
+                tp_rank=self.tp_rank,
                 metadata_buffers=self.disagg_metadata_buffers,
                 scheduler=self,
                 tree_cache=self.tree_cache,

@@ -650,7 +651,11 @@ class Scheduler(
                 gloo_group=self.attn_tp_cpu_group,
                 tp_rank=self.tp_rank,
                 tp_size=self.tp_size,
+                dp_size=self.server_args.dp_size,
                 gpu_id=self.gpu_id,
                 bootstrap_port=self.server_args.disaggregation_bootstrap_port,
+                max_total_num_tokens=self.max_total_num_tokens,
+                prefill_pp_size=self.server_args.disaggregation_prefill_pp,
                 transfer_backend=self.transfer_backend,
             )

@@ -1124,14 +1129,14 @@ class Scheduler(
         else:
             self.waiting_queue.append(req)

-    def _extend_requests_to_queue(self, reqs: List[Req]):
+    def _extend_requests_to_queue(self, reqs: List[Req], is_retracted: bool = False):
         if self.disaggregation_mode == DisaggregationMode.PREFILL:
             self.disagg_prefill_bootstrap_queue.extend(
                 reqs, self.model_config.num_key_value_heads
             )
         elif self.disaggregation_mode == DisaggregationMode.DECODE:
             # If this is a decode server, we put the request to the decode pending prealloc queue
-            self.disagg_decode_prealloc_queue.extend(reqs)
+            self.disagg_decode_prealloc_queue.extend(reqs, is_retracted)
         else:
             self.waiting_queue.extend(reqs)

@@ -1274,6 +1279,7 @@ class Scheduler(
         if self.disaggregation_mode == DisaggregationMode.DECODE:
             msg += f"pre-allocated usage: {self.num_tokens_pre_allocated / self.max_total_num_tokens:.2f}, "
+            msg += f"#retracted-req: {len(self.disagg_decode_prealloc_queue.retracted_queue)}, "

         msg += (
             f"cuda graph: {can_run_cuda_graph}, "

@@ -1575,7 +1581,7 @@ class Scheduler(
                 f"#retracted_reqs: {len(retracted_reqs)}, "
                 f"#new_token_ratio: {old_ratio:.4f} -> {self.new_token_ratio:.4f}"
             )
-            self._extend_requests_to_queue(retracted_reqs)
+            self._extend_requests_to_queue(retracted_reqs, is_retracted=True)
         else:
            self.new_token_ratio = max(
                 self.new_token_ratio - self.new_token_ratio_decay,
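With the new is_retracted flag, a retracted request re-enters the decode prealloc queue without opening another KV transfer: it lands in retracted_queue instead of getting a fresh kv_receiver. A minimal routing sketch; the flag and queue names come from the diff, the toy class and request ids are hypothetical:

from typing import List

class ToyDecodePreallocQueue:
    """Toy model of the routing added to DecodePreallocQueue.add/extend."""

    def __init__(self):
        self.queue: List[str] = []            # requests that still need a KV transfer
        self.retracted_queue: List[str] = []  # requests whose KV cache already lives on the CPU

    def extend(self, reqs: List[str], is_retracted: bool = False) -> None:
        for req in reqs:
            (self.retracted_queue if is_retracted else self.queue).append(req)

q = ToyDecodePreallocQueue()
q.extend(["fresh-req"])                         # normal path: waits for a prefill -> decode transfer
q.extend(["retracted-req"], is_retracted=True)  # retract path: resumed later from the CPU copy
print(q.queue, q.retracted_queue)               # ['fresh-req'] ['retracted-req']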
python/sglang/srt/mem_cache/memory_pool.py

@@ -234,6 +234,12 @@ class TokenToKVPoolAllocator:
         self.is_not_in_free_group = True
         self.free_group = []

+    def get_cpu_copy(self, indices):
+        return self._kvcache.get_cpu_copy(indices)
+
+    def load_cpu_copy(self, kv_cache_cpu, indices):
+        return self._kvcache.load_cpu_copy(kv_cache_cpu, indices)
+

 class MHATokenToKVPool(KVCache):

@@ -265,6 +271,8 @@ class MHATokenToKVPool(KVCache):
         self.head_dim = head_dim
         self._create_buffers()

+        # used for chunked cpu-offloading
+        self.chunk_size = 8192
         self.layer_transfer_counter = None
         self.device_module = torch.get_device_module(self.device)
         self.alt_stream = self.device_module.Stream() if _is_cuda else None

@@ -329,6 +337,39 @@ class MHATokenToKVPool(KVCache):
         ]
         return kv_data_ptrs, kv_data_lens, kv_item_lens

+    def get_cpu_copy(self, indices):
+        torch.cuda.synchronize()
+        kv_cache_cpu = []
+        for layer_id in range(self.layer_num):
+            kv_cache_cpu.append([])
+            for i in range(0, len(indices), self.chunk_size):
+                chunk_indices = indices[i : i + self.chunk_size]
+                k_cpu = self.k_buffer[layer_id][chunk_indices].to("cpu", non_blocking=True)
+                v_cpu = self.v_buffer[layer_id][chunk_indices].to("cpu", non_blocking=True)
+                kv_cache_cpu[-1].append([k_cpu, v_cpu])
+        torch.cuda.synchronize()
+        return kv_cache_cpu
+
+    def load_cpu_copy(self, kv_cache_cpu, indices):
+        torch.cuda.synchronize()
+        for layer_id in range(self.layer_num):
+            for i in range(0, len(indices), self.chunk_size):
+                chunk_indices = indices[i : i + self.chunk_size]
+                k_cpu, v_cpu = (
+                    kv_cache_cpu[layer_id][i // self.chunk_size][0],
+                    kv_cache_cpu[layer_id][i // self.chunk_size][1],
+                )
+                assert k_cpu.shape[0] == v_cpu.shape[0] == len(chunk_indices)
+                k_chunk = k_cpu.to(self.k_buffer[0].device, non_blocking=True)
+                v_chunk = v_cpu.to(self.v_buffer[0].device, non_blocking=True)
+                self.k_buffer[layer_id][chunk_indices] = k_chunk
+                self.v_buffer[layer_id][chunk_indices] = v_chunk
+        torch.cuda.synchronize()
+
     # Todo: different memory layout
     def get_flat_data(self, indices):
         # prepare a large chunk of contiguous data for efficient transfer
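get_cpu_copy / load_cpu_copy move the KV buffers in chunk_size slices with non-blocking copies bracketed by synchronize calls. Below is a CPU-only toy that exercises the same chunking pattern with plain torch tensors; the buffer shapes, chunk size, and index values are made up, and the real pool copies per-layer key/value buffers between GPU and host memory:

import torch

chunk_size = 4   # the real pool uses 8192
layer_num = 2
k_buffer = [torch.arange(10, dtype=torch.float32) * (layer_id + 1) for layer_id in range(layer_num)]

def get_cpu_copy(indices: torch.Tensor):
    # One list per layer, each holding chunk-sized copies, like the new method above.
    out = []
    for layer_id in range(layer_num):
        out.append([])
        for i in range(0, len(indices), chunk_size):
            chunk_indices = indices[i : i + chunk_size]
            out[-1].append(k_buffer[layer_id][chunk_indices].clone())
    return out

def load_cpu_copy(cpu_copy, indices: torch.Tensor):
    for layer_id in range(layer_num):
        for i in range(0, len(indices), chunk_size):
            chunk_indices = indices[i : i + chunk_size]
            k_buffer[layer_id][chunk_indices] = cpu_copy[layer_id][i // chunk_size]

idx = torch.tensor([1, 3, 5, 7, 9, 0])
saved = get_cpu_copy(idx)
k_buffer[0][idx] = 0.0      # pretend the slots were reused by another request
load_cpu_copy(saved, idx)
print(k_buffer[0])          # original values restored at the saved indices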
test/srt/test_disaggregation.py

@@ -469,5 +469,132 @@ class TestDisaggregationMooncakeSpec(CustomTestCase):
         self.assertGreater(metrics["accuracy"], 0.20)


+class TestDisaggregationSimulatedRetract(CustomTestCase):
+    @classmethod
+    def setUpClass(cls):
+        os.environ["SGLANG_TEST_RETRACT"] = "true"
+        cls.model = DEFAULT_MODEL_NAME_FOR_TEST
+        parsed_url = urlparse(DEFAULT_URL_FOR_TEST)
+        cls.base_host = parsed_url.hostname
+        base_port = str(parsed_url.port)
+        cls.lb_port = base_port
+        cls.prefill_port = f"{int(base_port) + 100}"
+        cls.decode_port = f"{int(base_port) + 200}"
+        cls.prefill_url = f"http://{cls.base_host}:{cls.prefill_port}"
+        cls.decode_url = f"http://{cls.base_host}:{cls.decode_port}"
+        cls.lb_url = f"http://{cls.base_host}:{cls.lb_port}"
+        print(f"{cls.base_host=} {cls.lb_port=} {cls.prefill_port=} {cls.decode_port=}")
+
+        # Start the prefill and decode servers without blocking
+        cls.start_prefill()
+        cls.start_decode()
+
+        # Block until both servers are ready
+        cls.wait_server_ready(cls.prefill_url + "/health")
+        cls.wait_server_ready(cls.decode_url + "/health")
+
+        lb_command = [
+            "python3",
+            "-m",
+            "sglang.srt.disaggregation.mini_lb",
+            "--prefill",
+            cls.prefill_url,
+            "--decode",
+            cls.decode_url,
+            "--host",
+            cls.base_host,
+            "--port",
+            cls.lb_port,
+        ]
+        print("Starting load balancer:", " ".join(lb_command))
+        cls.process_lb = subprocess.Popen(
+            lb_command, stdout=subprocess.PIPE, stderr=subprocess.PIPE
+        )
+        cls.wait_server_ready(cls.lb_url + "/health")
+
+    @classmethod
+    def start_prefill(cls):
+        prefill_args = [
+            "--trust-remote-code",
+            "--disaggregation-mode",
+            "prefill",
+            "--tp",
+            "1",
+            "--disaggregation-ib-device",
+            "mlx5_roce0",
+        ]
+        cls.process_prefill = popen_launch_pd_server(
+            cls.model,
+            cls.prefill_url,
+            timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
+            other_args=prefill_args,
+        )
+
+    @classmethod
+    def start_decode(cls):
+        decode_args = [
+            "--trust-remote-code",
+            "--disaggregation-mode",
+            "decode",
+            "--tp",
+            "1",
+            "--base-gpu-id",
+            "1",
+            "--disaggregation-ib-device",
+            "mlx5_roce1",
+        ]
+        cls.process_decode = popen_launch_pd_server(
+            cls.model,
+            cls.decode_url,
+            timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
+            other_args=decode_args,
+        )
+
+    @classmethod
+    def wait_server_ready(cls, url, timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH):
+        start_time = time.perf_counter()
+        while True:
+            try:
+                response = requests.get(url)
+                if response.status_code == 200:
+                    print(f"Server {url} is ready")
+                    return
+            except Exception:
+                pass
+
+            if time.perf_counter() - start_time > timeout:
+                raise RuntimeError(f"Server {url} failed to start in {timeout}s")
+            time.sleep(1)
+
+    @classmethod
+    def tearDownClass(cls):
+        os.environ.pop("SGLANG_TEST_RETRACT")
+        for process in [cls.process_lb, cls.process_decode, cls.process_prefill]:
+            if process:
+                try:
+                    kill_process_tree(process.pid)
+                except Exception as e:
+                    print(f"Error killing process {process.pid}: {e}")
+
+        # Wait for 5 seconds
+        time.sleep(5)
+
+    def test_gsm8k(self):
+        args = SimpleNamespace(
+            num_shots=5,
+            data_path=None,
+            num_questions=200,
+            max_new_tokens=512,
+            parallel=128,
+            host=f"http://{self.base_host}",
+            port=int(self.lb_port),
+        )
+
+        metrics = run_eval_few_shot_gsm8k(args)
+        print(f"Evaluation metrics: {metrics}")
+
+        self.assertGreater(metrics["accuracy"], 0.62)
+
+
 if __name__ == "__main__":
     unittest.main()