sglang / Commit 10143e1a (unverified)
Memorypool chunked prefetch (#614)
Authored Jul 13, 2024 by Liangsheng Yin, committed via GitHub on Jul 13, 2024
Parent: 65c65776
Showing 5 changed files with 30 additions and 39 deletions (+30 / -39):
python/sglang/srt/layers/radix_attention.py                   +0  -7
python/sglang/srt/managers/controller/cuda_graph_runner.py    +0  -2
python/sglang/srt/managers/controller/infer_batch.py          +7  -24
python/sglang/srt/managers/controller/model_runner.py         +0  -2
python/sglang/srt/memory_pool.py                              +23 -4
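Taken together, the commit removes the contiguous-allocation fast path for decoding: alloc_contiguous is no longer called from Batch.prepare_for_decode and is marked deprecated, and the out_cache_cont_start / out_cache_cont_end plumbing disappears from Batch, InputMetadata, CudaGraphRunner, ModelRunner, and RadixAttention. In its place, TokenToKVPool gains a chunked prefetch buffer: alloc reserves free slots in chunks of prefetch_chunk_size = 256 and serves subsequent requests from that buffer. A standalone sketch of the allocator idea follows the memory_pool.py diff at the end.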
python/sglang/srt/layers/radix_attention.py

@@ -141,12 +141,5 @@ class RadixAttention(nn.Module):
         if input_metadata.out_cache_loc is not None:
             key_buffer[input_metadata.out_cache_loc] = cache_k
             value_buffer[input_metadata.out_cache_loc] = cache_v
-        elif input_metadata.out_cache_cont_start is not None:
-            key_buffer[
-                input_metadata.out_cache_cont_start : input_metadata.out_cache_cont_end
-            ] = cache_k
-            value_buffer[
-                input_metadata.out_cache_cont_start : input_metadata.out_cache_cont_end
-            ] = cache_v
         else:
             raise RuntimeError()
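With the contiguous branch removed, the cache write keeps only the index-assignment path, which handles arbitrary (not necessarily contiguous) slot indices. A tiny standalone illustration of that mechanism, using toy shapes rather than sglang's real buffers:

import torch

# Toy KV buffer: 8 cache slots, head dimension 4.
key_buffer = torch.zeros(8, 4)
cache_k = torch.ones(3, 4)

# out_cache_loc may point at any slots, contiguous or not.
out_cache_loc = torch.tensor([6, 2, 5])
key_buffer[out_cache_loc] = cache_k

print(key_buffer[2])  # tensor([1., 1., 1., 1.])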
python/sglang/srt/managers/controller/cuda_graph_runner.py

@@ -104,8 +104,6 @@ class CudaGraphRunner:
             prefix_lens=None,
             position_ids_offsets=position_ids_offsets,
             out_cache_loc=out_cache_loc,
-            out_cache_cont_start=None,
-            out_cache_cont_end=None,
             return_logprob=False,
             top_logprobs_nums=0,
             skip_flashinfer_init=True,
python/sglang/srt/managers/controller/infer_batch.py

@@ -275,8 +275,6 @@ class Batch:
     prefix_lens: torch.Tensor = None
     position_ids_offsets: torch.Tensor = None
     out_cache_loc: torch.Tensor = None
-    out_cache_cont_start: int = None
-    out_cache_cont_end: int = None

     # For processing logprobs
     return_logprob: bool = False

@@ -566,21 +564,12 @@ class Batch:
         # Alloc mem
         bs = len(self.reqs)
-        alloc_res = self.token_to_kv_pool.alloc_contiguous(bs)
-        if alloc_res is None:
-            self.out_cache_loc = self.token_to_kv_pool.alloc(bs)
-            if self.out_cache_loc is None:
-                print("Decode out of memory. This should never happen.")
-                self.tree_cache.pretty_print()
-                exit()
-            self.out_cache_cont_start = None
-            self.out_cache_cont_end = None
-        else:
-            self.out_cache_loc = alloc_res[0]
-            self.out_cache_cont_start = alloc_res[1]
-            self.out_cache_cont_end = alloc_res[2]
+        self.out_cache_loc = self.token_to_kv_pool.alloc(bs)
+        if self.out_cache_loc is None:
+            print("Decode out of memory. This should never happen.")
+            self.tree_cache.pretty_print()
+            exit()

         self.req_to_token_pool.req_to_token[self.req_pool_indices, self.seq_lens - 1

@@ -594,7 +583,7 @@ class Batch:
         self.req_pool_indices = self.req_pool_indices[new_indices]
         self.prefix_lens = None
         self.position_ids_offsets = self.position_ids_offsets[new_indices]
-        self.out_cache_loc = self.out_cache_cont_start = self.out_cache_cont_end = None
+        self.out_cache_loc = None
         self.top_logprobs_nums = [self.top_logprobs_nums[i] for i in unfinished_indices]
         self.return_logprob = any(req.return_logprob for req in self.reqs)

@@ -622,7 +611,7 @@ class Batch:
         self.position_ids_offsets = torch.concat(
             [self.position_ids_offsets, other.position_ids_offsets]
         )
-        self.out_cache_loc = self.out_cache_cont_start = self.out_cache_cont_end = None
+        self.out_cache_loc = None
         self.top_logprobs_nums.extend(other.top_logprobs_nums)
         self.return_logprob = any(req.return_logprob for req in self.reqs)

@@ -729,8 +718,6 @@ class InputMetadata:
     # Output location of the KV cache
     out_cache_loc: torch.Tensor = None
-    out_cache_cont_start: int = None
-    out_cache_cont_end: int = None

     # Output options
     return_logprob: bool = False

@@ -757,8 +744,6 @@ class InputMetadata:
         prefix_lens,
         position_ids_offsets,
         out_cache_loc,
-        out_cache_cont_start=None,
-        out_cache_cont_end=None,
         top_logprobs_nums=None,
         return_logprob=False,
         skip_flashinfer_init=False,

@@ -811,8 +796,6 @@ class InputMetadata:
             req_to_token_pool=model_runner.req_to_token_pool,
             token_to_kv_pool=model_runner.token_to_kv_pool,
             out_cache_loc=out_cache_loc,
-            out_cache_cont_start=out_cache_cont_start,
-            out_cache_cont_end=out_cache_cont_end,
             extend_seq_lens=extend_seq_lens,
             extend_start_loc=extend_start_loc,
             extend_no_prefix=extend_no_prefix,
python/sglang/srt/managers/controller/model_runner.py

@@ -245,8 +245,6 @@ class ModelRunner:
             prefix_lens=batch.prefix_lens,
             position_ids_offsets=batch.position_ids_offsets,
             out_cache_loc=batch.out_cache_loc,
-            out_cache_cont_start=batch.out_cache_cont_start,
-            out_cache_cont_end=batch.out_cache_cont_end,
             top_logprobs_nums=batch.top_logprobs_nums,
             return_logprob=batch.return_logprob,
         )
python/sglang/srt/memory_pool.py

@@ -50,6 +50,10 @@ class TokenToKVPool:
             for _ in range(layer_num)
         ]

+        # Prefetch buffer
+        self.prefetch_buffer = torch.empty(0, device="cuda", dtype=torch.int32)
+        self.prefetch_chunk_size = 256
+
         self.clear()

     def get_key_buffer(self, layer_id):

@@ -59,14 +63,29 @@ class TokenToKVPool:
         return self.kv_data[layer_id][:, 1]

     def alloc(self, need_size):
-        select_index = torch.nonzero(self.mem_state == 0).squeeze(1)[:need_size]
-        if select_index.shape[0] < need_size:
+        buffer_len = len(self.prefetch_buffer)
+        if need_size <= buffer_len:
+            select_index = self.prefetch_buffer[:need_size]
+            self.prefetch_buffer = self.prefetch_buffer[need_size:]
+            return select_index.to(torch.int32)
+
+        addition_size = need_size - buffer_len
+        alloc_size = max(addition_size, self.prefetch_chunk_size)
+        select_index = torch.nonzero(self.mem_state == 0).squeeze(1)[:alloc_size]
+        if select_index.shape[0] < addition_size:
             return None

         self.add_refs(select_index)
-        return select_index.to(torch.int32)
+        self.prefetch_buffer = torch.cat((self.prefetch_buffer, select_index))
+        ret_index = self.prefetch_buffer[:need_size]
+        self.prefetch_buffer = self.prefetch_buffer[need_size:]
+
+        return ret_index.to(torch.int32)

     def alloc_contiguous(self, need_size):
+        # NOTE: This function is deprecated.
         empty_index = torch.nonzero(self.mem_state == 0).squeeze(1)[:need_size]
         if empty_index.shape[0] < need_size:
             return None

@@ -89,7 +108,7 @@ class TokenToKVPool:
         return len(torch.nonzero(self.mem_state).squeeze(1))

     def available_size(self):
-        return torch.sum(self.mem_state == 0).item()
+        return torch.sum(self.mem_state == 0).item() + len(self.prefetch_buffer)

     def add_refs(self, token_index: torch.Tensor):
         self.total_ref_ct += len(token_index)
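The rewritten alloc amortizes the free-slot scan: when the prefetch buffer cannot cover a request, it runs torch.nonzero once for up to prefetch_chunk_size slots and parks the surplus for later calls, so steady-state decode allocations become cheap slices of an already-computed index tensor. A minimal self-contained sketch of the same idea (PrefetchPoolSketch is a made-up name, ref-counting is simplified to marking slots used, and it runs on CPU rather than CUDA):

import torch


class PrefetchPoolSketch:
    # Illustrative sketch of chunked-prefetch allocation, not the
    # repository's TokenToKVPool.
    def __init__(self, size, chunk_size=256):
        self.mem_state = torch.zeros(size, dtype=torch.int16)  # 0 means free
        self.prefetch_buffer = torch.empty(0, dtype=torch.int64)
        self.prefetch_chunk_size = chunk_size

    def alloc(self, need_size):
        # Fast path: serve the request from already-prefetched indices.
        if need_size <= len(self.prefetch_buffer):
            out = self.prefetch_buffer[:need_size]
            self.prefetch_buffer = self.prefetch_buffer[need_size:]
            return out

        # Slow path: scan for free slots once, grabbing a whole chunk so
        # the next few allocations can skip the scan.
        addition_size = need_size - len(self.prefetch_buffer)
        alloc_size = max(addition_size, self.prefetch_chunk_size)
        free = torch.nonzero(self.mem_state == 0).squeeze(1)[:alloc_size]
        if free.shape[0] < addition_size:
            return None  # out of memory
        self.mem_state[free] = 1  # mark as in use (stands in for add_refs)

        self.prefetch_buffer = torch.cat((self.prefetch_buffer, free))
        out = self.prefetch_buffer[:need_size]
        self.prefetch_buffer = self.prefetch_buffer[need_size:]
        return out


pool = PrefetchPoolSketch(size=1024)
a = pool.alloc(4)  # one scan, prefetches a 256-slot chunk
b = pool.alloc(4)  # served from the buffer, no scan
print(a.tolist(), b.tolist())  # [0, 1, 2, 3] [4, 5, 6, 7]

Consistent with this, available_size in the real pool now counts both the free slots in mem_state and the slots parked in the prefetch buffer, since prefetched-but-unused slots remain available to future alloc calls.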