change / sglang / Commits / 10143e1a

Memorypool chunked prefetch (#614)

Commit 10143e1a (Unverified), authored Jul 13, 2024 by Liangsheng Yin, committed via GitHub on Jul 13, 2024.
parent 65c65776

Showing 5 changed files with 30 additions and 39 deletions (+30 / -39):
python/sglang/srt/layers/radix_attention.py                 +0  -7
python/sglang/srt/managers/controller/cuda_graph_runner.py  +0  -2
python/sglang/srt/managers/controller/infer_batch.py        +7  -24
python/sglang/srt/managers/controller/model_runner.py       +0  -2
python/sglang/srt/memory_pool.py                            +23 -4
python/sglang/srt/layers/radix_attention.py

```diff
@@ -141,12 +141,5 @@ class RadixAttention(nn.Module):
         if input_metadata.out_cache_loc is not None:
             key_buffer[input_metadata.out_cache_loc] = cache_k
             value_buffer[input_metadata.out_cache_loc] = cache_v
-        elif input_metadata.out_cache_cont_start is not None:
-            key_buffer[
-                input_metadata.out_cache_cont_start : input_metadata.out_cache_cont_end
-            ] = cache_k
-            value_buffer[
-                input_metadata.out_cache_cont_start : input_metadata.out_cache_cont_end
-            ] = cache_v
         else:
             raise RuntimeError()
```
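The branch removed here handled the case where a decode batch had been handed one contiguous slice of the pool (out_cache_cont_start / out_cache_cont_end). With contiguous allocation gone, every KV write goes through the index-based path, which also covers contiguous slots. A toy illustration with made-up shapes, not the project's actual buffers:

```python
import torch

# Toy KV buffer: one row per pool slot, head/dim dimensions folded together.
key_buffer = torch.zeros(8, 4)
cache_k = torch.ones(3, 4)

# Index-based write (the kept path): works for any set of slot indices.
out_cache_loc = torch.tensor([5, 2, 7])
key_buffer[out_cache_loc] = cache_k

# It also covers the contiguous case the removed elif used to special-case:
key_buffer[torch.arange(2, 5)] = cache_k  # same effect as key_buffer[2:5] = cache_k
```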
python/sglang/srt/managers/controller/cuda_graph_runner.py

```diff
@@ -104,8 +104,6 @@ class CudaGraphRunner:
             prefix_lens=None,
             position_ids_offsets=position_ids_offsets,
             out_cache_loc=out_cache_loc,
-            out_cache_cont_start=None,
-            out_cache_cont_end=None,
             return_logprob=False,
             top_logprobs_nums=0,
             skip_flashinfer_init=True,
```
python/sglang/srt/managers/controller/infer_batch.py

```diff
@@ -275,8 +275,6 @@ class Batch:
     prefix_lens: torch.Tensor = None
     position_ids_offsets: torch.Tensor = None
     out_cache_loc: torch.Tensor = None
-    out_cache_cont_start: int = None
-    out_cache_cont_end: int = None
 
     # For processing logprobs
     return_logprob: bool = False
```
```diff
@@ -566,21 +564,12 @@ class Batch:
         # Alloc mem
         bs = len(self.reqs)
-        alloc_res = self.token_to_kv_pool.alloc_contiguous(bs)
-        if alloc_res is None:
-            self.out_cache_loc = self.token_to_kv_pool.alloc(bs)
+        self.out_cache_loc = self.token_to_kv_pool.alloc(bs)
 
-            if self.out_cache_loc is None:
-                print("Decode out of memory. This should never happen.")
-                self.tree_cache.pretty_print()
-                exit()
-
-            self.out_cache_cont_start = None
-            self.out_cache_cont_end = None
-        else:
-            self.out_cache_loc = alloc_res[0]
-            self.out_cache_cont_start = alloc_res[1]
-            self.out_cache_cont_end = alloc_res[2]
+        if self.out_cache_loc is None:
+            print("Decode out of memory. This should never happen.")
+            self.tree_cache.pretty_print()
+            exit()
 
         self.req_to_token_pool.req_to_token[
             self.req_pool_indices, self.seq_lens - 1
```
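For orientation, a minimal sketch of what the simplified decode allocation now amounts to; the names and shapes below are illustrative stand-ins, not the actual pool classes. One KV slot is allocated per running request, and its (possibly non-contiguous) index is recorded at position seq_len - 1 of that request's row in the request-to-token table:

```python
import torch

req_to_token = torch.zeros(16, 32, dtype=torch.int32)   # stand-in index table
req_pool_indices = torch.tensor([0, 3, 5, 9])            # rows of the running requests
seq_lens = torch.tensor([7, 2, 11, 4])                   # current sequence lengths

# One new slot per request for the token being decoded; the indices no longer
# need to form a contiguous range.
out_cache_loc = torch.tensor([100, 101, 250, 17], dtype=torch.int32)

# Record where each request's newest token lives in the KV pool.
req_to_token[req_pool_indices, seq_lens - 1] = out_cache_loc
```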
```diff
@@ -594,7 +583,7 @@ class Batch:
         self.req_pool_indices = self.req_pool_indices[new_indices]
         self.prefix_lens = None
         self.position_ids_offsets = self.position_ids_offsets[new_indices]
-        self.out_cache_loc = self.out_cache_cont_start = self.out_cache_cont_end = None
+        self.out_cache_loc = None
         self.top_logprobs_nums = [self.top_logprobs_nums[i] for i in unfinished_indices]
         self.return_logprob = any(req.return_logprob for req in self.reqs)
```
```diff
@@ -622,7 +611,7 @@ class Batch:
         self.position_ids_offsets = torch.concat(
             [self.position_ids_offsets, other.position_ids_offsets]
         )
-        self.out_cache_loc = self.out_cache_cont_start = self.out_cache_cont_end = None
+        self.out_cache_loc = None
         self.top_logprobs_nums.extend(other.top_logprobs_nums)
         self.return_logprob = any(req.return_logprob for req in self.reqs)
```
```diff
@@ -729,8 +718,6 @@ class InputMetadata:
     # Output location of the KV cache
     out_cache_loc: torch.Tensor = None
-    out_cache_cont_start: int = None
-    out_cache_cont_end: int = None
 
     # Output options
     return_logprob: bool = False
```
```diff
@@ -757,8 +744,6 @@ class InputMetadata:
         prefix_lens,
         position_ids_offsets,
         out_cache_loc,
-        out_cache_cont_start=None,
-        out_cache_cont_end=None,
         top_logprobs_nums=None,
         return_logprob=False,
         skip_flashinfer_init=False,
```
```diff
@@ -811,8 +796,6 @@ class InputMetadata:
             req_to_token_pool=model_runner.req_to_token_pool,
             token_to_kv_pool=model_runner.token_to_kv_pool,
             out_cache_loc=out_cache_loc,
-            out_cache_cont_start=out_cache_cont_start,
-            out_cache_cont_end=out_cache_cont_end,
             extend_seq_lens=extend_seq_lens,
             extend_start_loc=extend_start_loc,
             extend_no_prefix=extend_no_prefix,
```
python/sglang/srt/managers/controller/model_runner.py

```diff
@@ -245,8 +245,6 @@ class ModelRunner:
             prefix_lens=batch.prefix_lens,
             position_ids_offsets=batch.position_ids_offsets,
             out_cache_loc=batch.out_cache_loc,
-            out_cache_cont_start=batch.out_cache_cont_start,
-            out_cache_cont_end=batch.out_cache_cont_end,
             top_logprobs_nums=batch.top_logprobs_nums,
             return_logprob=batch.return_logprob,
         )
```
python/sglang/srt/memory_pool.py

```diff
@@ -50,6 +50,10 @@ class TokenToKVPool:
             for _ in range(layer_num)
         ]
 
+        # Prefetch buffer
+        self.prefetch_buffer = torch.empty(0, device="cuda", dtype=torch.int32)
+        self.prefetch_chunk_size = 256
+
         self.clear()
 
     def get_key_buffer(self, layer_id):
```
```diff
@@ -59,14 +63,29 @@ class TokenToKVPool:
         return self.kv_data[layer_id][:, 1]
 
     def alloc(self, need_size):
-        select_index = torch.nonzero(self.mem_state == 0).squeeze(1)[:need_size]
-        if select_index.shape[0] < need_size:
+        buffer_len = len(self.prefetch_buffer)
+        if need_size <= buffer_len:
+            select_index = self.prefetch_buffer[:need_size]
+            self.prefetch_buffer = self.prefetch_buffer[need_size:]
+            return select_index.to(torch.int32)
+
+        addition_size = need_size - buffer_len
+        alloc_size = max(addition_size, self.prefetch_chunk_size)
+        select_index = torch.nonzero(self.mem_state == 0).squeeze(1)[:alloc_size]
+        if select_index.shape[0] < addition_size:
             return None
 
         self.add_refs(select_index)
-        return select_index.to(torch.int32)
+
+        self.prefetch_buffer = torch.cat((self.prefetch_buffer, select_index))
+        ret_index = self.prefetch_buffer[:need_size]
+        self.prefetch_buffer = self.prefetch_buffer[need_size:]
+
+        return ret_index.to(torch.int32)
 
     def alloc_contiguous(self, need_size):
+        # NOTE: This function is deprecated.
         empty_index = torch.nonzero(self.mem_state == 0).squeeze(1)[:need_size]
         if empty_index.shape[0] < need_size:
             return None
```
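To make the effect of the prefetch buffer concrete, here is a cut-down, stand-alone version of the same idea (CPU tensors, no reference counting, hypothetical class name). The point is that the torch.nonzero scan over the free-slot mask runs roughly once per prefetch_chunk_size allocated tokens instead of once per alloc call:

```python
import torch


class ChunkedAllocator:
    def __init__(self, size, chunk_size=256):
        self.mem_state = torch.zeros(size, dtype=torch.int16)   # 0 == free
        self.prefetch_buffer = torch.empty(0, dtype=torch.int32)
        self.prefetch_chunk_size = chunk_size
        self.scan_count = 0                                      # for illustration only

    def alloc(self, need_size):
        buffer_len = len(self.prefetch_buffer)
        if need_size <= buffer_len:
            # Fast path: serve entirely from the prefetched indices.
            out = self.prefetch_buffer[:need_size]
            self.prefetch_buffer = self.prefetch_buffer[need_size:]
            return out

        # Slow path: scan for free slots, but grab a whole chunk at once.
        addition_size = need_size - buffer_len
        alloc_size = max(addition_size, self.prefetch_chunk_size)
        self.scan_count += 1
        new_index = torch.nonzero(self.mem_state == 0).squeeze(1)[:alloc_size]
        if new_index.shape[0] < addition_size:
            return None                                          # out of memory
        self.mem_state[new_index] = 1                            # mark as used
        self.prefetch_buffer = torch.cat(
            (self.prefetch_buffer, new_index.to(torch.int32))
        )
        out = self.prefetch_buffer[:need_size]
        self.prefetch_buffer = self.prefetch_buffer[need_size:]
        return out


pool = ChunkedAllocator(size=4096, chunk_size=256)
for _ in range(100):
    pool.alloc(8)          # e.g. one decode step for a batch of 8
print(pool.scan_count)     # 4 scans for 800 tokens, not 100
```

The real pool additionally tracks reference counts via add_refs and keeps mem_state and the buffer on the GPU; the sketch only mirrors the control flow.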
```diff
@@ -89,7 +108,7 @@ class TokenToKVPool:
         return len(torch.nonzero(self.mem_state).squeeze(1))
 
     def available_size(self):
-        return torch.sum(self.mem_state == 0).item()
+        return torch.sum(self.mem_state == 0).item() + len(self.prefetch_buffer)
 
     def add_refs(self, token_index: torch.Tensor):
         self.total_ref_ct += len(token_index)
```
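The available_size change keeps the accounting consistent: indices parked in the prefetch buffer were already marked as used in mem_state when add_refs ran on the freshly scanned chunk, but they have not been handed to any request yet, so they still count as available. A small illustrative check with a hypothetical 16-slot mask:

```python
import torch

mem_state = torch.zeros(16, dtype=torch.int16)               # 0 == free slot
prefetch_buffer = torch.tensor([3, 4, 5], dtype=torch.int32)
mem_state[prefetch_buffer.long()] = 1                         # reserved for the buffer

free_in_state = torch.sum(mem_state == 0).item()              # 13
available = free_in_state + len(prefetch_buffer)              # 16 == total capacity
assert available == 16
```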