Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
change
sglang
Commits
ab4b5606
Unverified
Commit
ab4b5606
authored
Apr 19, 2025
by
Byron Hsu
Committed by
GitHub
Apr 19, 2025
Browse files
[PD] Support page size > 1 (#5561)
parent
20f1c8e3
Changes
4
Show whitespace changes
Inline
Side-by-side
Showing
4 changed files
with
58 additions
and
9 deletions
+58
-9
python/sglang/srt/disaggregation/decode.py
python/sglang/srt/disaggregation/decode.py
+29
-5
python/sglang/srt/disaggregation/prefill.py
python/sglang/srt/disaggregation/prefill.py
+8
-2
python/sglang/srt/disaggregation/utils.py
python/sglang/srt/disaggregation/utils.py
+15
-0
python/sglang/srt/mem_cache/memory_pool.py
python/sglang/srt/mem_cache/memory_pool.py
+6
-2
No files found.
python/sglang/srt/disaggregation/decode.py
View file @
ab4b5606
...
@@ -35,6 +35,7 @@ from sglang.srt.disaggregation.utils import (
...
@@ -35,6 +35,7 @@ from sglang.srt.disaggregation.utils import (
ReqToMetadataIdxAllocator
,
ReqToMetadataIdxAllocator
,
TransferBackend
,
TransferBackend
,
get_kv_class
,
get_kv_class
,
kv_to_page_indices
,
poll_and_all_reduce
,
poll_and_all_reduce
,
)
)
from
sglang.srt.mem_cache.base_prefix_cache
import
BasePrefixCache
from
sglang.srt.mem_cache.base_prefix_cache
import
BasePrefixCache
...
@@ -205,7 +206,10 @@ class DecodePreallocQueue:
...
@@ -205,7 +206,10 @@ class DecodePreallocQueue:
self
.
req_to_metadata_buffer_idx_allocator
.
alloc
()
self
.
req_to_metadata_buffer_idx_allocator
.
alloc
()
)
)
assert
decode_req
.
metadata_buffer_index
is
not
None
assert
decode_req
.
metadata_buffer_index
is
not
None
decode_req
.
kv_receiver
.
init
(
kv_indices
,
decode_req
.
metadata_buffer_index
)
page_indices
=
kv_to_page_indices
(
kv_indices
,
self
.
token_to_kv_pool_allocator
.
page_size
)
decode_req
.
kv_receiver
.
init
(
page_indices
,
decode_req
.
metadata_buffer_index
)
preallocated_reqs
.
append
(
decode_req
)
preallocated_reqs
.
append
(
decode_req
)
indices_to_remove
.
add
(
i
)
indices_to_remove
.
add
(
i
)
...
@@ -245,10 +249,30 @@ class DecodePreallocQueue:
...
@@ -245,10 +249,30 @@ class DecodePreallocQueue:
assert
req_pool_indices
is
not
None
assert
req_pool_indices
is
not
None
req
.
req_pool_idx
=
req_pool_indices
[
0
]
req
.
req_pool_idx
=
req_pool_indices
[
0
]
if
self
.
token_to_kv_pool_allocator
.
page_size
==
1
:
kv_loc
=
self
.
token_to_kv_pool_allocator
.
alloc
(
kv_loc
=
self
.
token_to_kv_pool_allocator
.
alloc
(
len
(
req
.
origin_input_ids
)
+
max
(
len
(
req
.
output_ids
)
-
1
,
0
)
len
(
req
.
origin_input_ids
)
+
max
(
len
(
req
.
output_ids
)
-
1
,
0
)
)
)
else
:
num_tokens
=
len
(
req
.
origin_input_ids
)
+
max
(
len
(
req
.
output_ids
)
-
1
,
0
)
kv_loc
=
self
.
token_to_kv_pool_allocator
.
alloc_extend
(
prefix_lens
=
torch
.
tensor
(
[
0
],
dtype
=
torch
.
int64
,
device
=
self
.
token_to_kv_pool_allocator
.
device
,
),
seq_lens
=
torch
.
tensor
(
[
num_tokens
],
dtype
=
torch
.
int64
,
device
=
self
.
token_to_kv_pool_allocator
.
device
,
),
last_loc
=
torch
.
tensor
(
[
-
1
],
dtype
=
torch
.
int64
,
device
=
self
.
token_to_kv_pool_allocator
.
device
,
),
extend_num_tokens
=
num_tokens
,
)
assert
kv_loc
is
not
None
assert
kv_loc
is
not
None
self
.
req_to_token_pool
.
write
((
req
.
req_pool_idx
,
slice
(
0
,
len
(
kv_loc
))),
kv_loc
)
self
.
req_to_token_pool
.
write
((
req
.
req_pool_idx
,
slice
(
0
,
len
(
kv_loc
))),
kv_loc
)
...
...
python/sglang/srt/disaggregation/prefill.py
View file @
ab4b5606
...
@@ -31,6 +31,8 @@ from sglang.srt.disaggregation.utils import (
...
@@ -31,6 +31,8 @@ from sglang.srt.disaggregation.utils import (
ReqToMetadataIdxAllocator
,
ReqToMetadataIdxAllocator
,
TransferBackend
,
TransferBackend
,
get_kv_class
,
get_kv_class
,
kv_to_page_indices
,
kv_to_page_num
,
poll_and_all_reduce
,
poll_and_all_reduce
,
)
)
from
sglang.srt.managers.schedule_batch
import
FINISH_LENGTH
,
Req
,
ScheduleBatch
from
sglang.srt.managers.schedule_batch
import
FINISH_LENGTH
,
Req
,
ScheduleBatch
...
@@ -154,7 +156,8 @@ class PrefillBootstrapQueue:
...
@@ -154,7 +156,8 @@ class PrefillBootstrapQueue:
self
.
req_to_metadata_buffer_idx_allocator
.
alloc
()
self
.
req_to_metadata_buffer_idx_allocator
.
alloc
()
)
)
assert
req
.
metadata_buffer_index
is
not
None
assert
req
.
metadata_buffer_index
is
not
None
req
.
disagg_kv_sender
.
init
(
num_kv_indices
,
req
.
metadata_buffer_index
)
num_pages
=
kv_to_page_num
(
num_kv_indices
,
self
.
token_to_kv_pool
.
page_size
)
req
.
disagg_kv_sender
.
init
(
num_pages
,
req
.
metadata_buffer_index
)
bootstrapped_reqs
.
append
(
req
)
bootstrapped_reqs
.
append
(
req
)
indices_to_remove
.
add
(
i
)
indices_to_remove
.
add
(
i
)
...
@@ -300,4 +303,7 @@ class SchedulerDisaggregationPrefillMixin:
...
@@ -300,4 +303,7 @@ class SchedulerDisaggregationPrefillMixin:
req
.
metadata_buffer_index
,
token_id
req
.
metadata_buffer_index
,
token_id
)
)
is_last
=
token_id
is
not
None
is_last
=
token_id
is
not
None
req
.
disagg_kv_sender
.
send
(
kv_indices
,
slice
(
start_idx
,
end_idx
),
is_last
)
page_indices
=
kv_to_page_indices
(
kv_indices
,
self
.
token_to_kv_pool_allocator
.
page_size
)
req
.
disagg_kv_sender
.
send
(
page_indices
,
slice
(
start_idx
,
end_idx
),
is_last
)
python/sglang/srt/disaggregation/utils.py
View file @
ab4b5606
...
@@ -4,6 +4,7 @@ from collections import deque
...
@@ -4,6 +4,7 @@ from collections import deque
from
enum
import
Enum
from
enum
import
Enum
from
typing
import
List
from
typing
import
List
import
numpy
as
np
import
torch
import
torch
import
torch.distributed
as
dist
import
torch.distributed
as
dist
...
@@ -73,3 +74,17 @@ def get_kv_class(transfer_backend: TransferBackend, class_type: KVClassType):
...
@@ -73,3 +74,17 @@ def get_kv_class(transfer_backend: TransferBackend, class_type: KVClassType):
}
}
return
class_mapping
.
get
(
class_type
)
return
class_mapping
.
get
(
class_type
)
raise
ValueError
(
f
"Unsupported transfer backend:
{
transfer_backend
}
"
)
raise
ValueError
(
f
"Unsupported transfer backend:
{
transfer_backend
}
"
)
def kv_to_page_indices(kv_indices: np.ndarray, page_size: int) -> np.ndarray:
    """Convert per-token KV indices into the indices of the pages holding them.

    The conversion relies on two properties:

    1. Every page is guaranteed to be full except the last page.
    2. page index = kv_index // page_size

    Taking every ``page_size``-th entry therefore selects one KV index per
    page, and floor-dividing that entry by ``page_size`` yields the page index.

    Args:
        kv_indices: KV indices to convert.
            NOTE(review): assumes the indices are ordered page by page and
            start on a page boundary, as implied by property 1 -- confirm
            against the allocator.
        page_size: Number of KV slots per page.

    Returns:
        ``kv_indices`` itself (the same object, not a copy) when
        ``page_size == 1``; otherwise ``kv_indices[::page_size] // page_size``.
    """
    if page_size == 1:
        # Shortcut: with one slot per page, KV indices already are page indices.
        return kv_indices

    return kv_indices[::page_size] // page_size
def kv_to_page_num(num_kv_indices: int, page_size: int) -> int:
    """Return the number of pages needed to hold ``num_kv_indices`` KV slots.

    Computes ``ceil(num_kv_indices / page_size)`` using integer arithmetic
    only, so no floating-point rounding is involved.

    Args:
        num_kv_indices: Number of KV indices to store.
        page_size: Number of KV slots per page.

    Returns:
        The page count, with a partially filled last page counted as one page.
    """
    # Adding (page_size - 1) before floor division rounds the quotient up.
    return (num_kv_indices + page_size - 1) // page_size
python/sglang/srt/mem_cache/memory_pool.py
View file @
ab4b5606
...
@@ -286,8 +286,12 @@ class MHATokenToKVPool(KVCache):
...
@@ -286,8 +286,12 @@ class MHATokenToKVPool(KVCache):
self
.
get_key_buffer
(
i
).
nbytes
for
i
in
range
(
self
.
layer_num
)
self
.
get_key_buffer
(
i
).
nbytes
for
i
in
range
(
self
.
layer_num
)
]
+
[
self
.
get_value_buffer
(
i
).
nbytes
for
i
in
range
(
self
.
layer_num
)]
]
+
[
self
.
get_value_buffer
(
i
).
nbytes
for
i
in
range
(
self
.
layer_num
)]
kv_item_lens
=
[
kv_item_lens
=
[
self
.
get_key_buffer
(
i
)[
0
].
nbytes
for
i
in
range
(
self
.
layer_num
)
self
.
get_key_buffer
(
i
)[
0
].
nbytes
*
self
.
page_size
]
+
[
self
.
get_value_buffer
(
i
)[
0
].
nbytes
for
i
in
range
(
self
.
layer_num
)]
for
i
in
range
(
self
.
layer_num
)
]
+
[
self
.
get_value_buffer
(
i
)[
0
].
nbytes
*
self
.
page_size
for
i
in
range
(
self
.
layer_num
)
]
return
kv_data_ptrs
,
kv_data_lens
,
kv_item_lens
return
kv_data_ptrs
,
kv_data_lens
,
kv_item_lens
# Todo: different memory layout
# Todo: different memory layout
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment