sglang · Commits · 10143e1a

Commit 10143e1a (unverified), authored Jul 13, 2024 by Liangsheng Yin, committed via GitHub on Jul 13, 2024.

Memorypool chunked prefetch (#614)

Parent: 65c65776

Showing 5 changed files, with 30 additions and 39 deletions (+30, -39).
Files changed:

  python/sglang/srt/layers/radix_attention.py                 +0   -7
  python/sglang/srt/managers/controller/cuda_graph_runner.py  +0   -2
  python/sglang/srt/managers/controller/infer_batch.py        +7   -24
  python/sglang/srt/managers/controller/model_runner.py       +0   -2
  python/sglang/srt/memory_pool.py                            +23  -4
python/sglang/srt/layers/radix_attention.py  (view file @ 10143e1a)

@@ -141,12 +141,5 @@ class RadixAttention(nn.Module):
         if input_metadata.out_cache_loc is not None:
             key_buffer[input_metadata.out_cache_loc] = cache_k
             value_buffer[input_metadata.out_cache_loc] = cache_v
-        elif input_metadata.out_cache_cont_start is not None:
-            key_buffer[
-                input_metadata.out_cache_cont_start : input_metadata.out_cache_cont_end
-            ] = cache_k
-            value_buffer[
-                input_metadata.out_cache_cont_start : input_metadata.out_cache_cont_end
-            ] = cache_v
         else:
             raise RuntimeError()
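The removed elif branch was the KV-cache write path for contiguously allocated slots. Since the pool's alloc() now always returns an index tensor (see memory_pool.py below), a single scatter-style write covers both cases. A minimal sketch, not from the commit (shapes and values are hypothetical), showing why the slice path is redundant:

import torch

# Illustrative shapes only.
num_slots, head_num, head_dim = 16, 2, 4
key_buffer = torch.zeros(num_slots, head_num, head_dim)
cache_k = torch.randn(3, head_num, head_dim)

# Scattered slot indices, as the new alloc() may return:
out_cache_loc = torch.tensor([2, 7, 11])
key_buffer[out_cache_loc] = cache_k  # scatter write via advanced indexing

# A contiguous run is just a special case of the same write, so the
# out_cache_cont_start / out_cache_cont_end slice path can be deleted.
out_cache_loc = torch.arange(4, 7)
key_buffer[out_cache_loc] = cache_k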
python/sglang/srt/managers/controller/cuda_graph_runner.py  (view file @ 10143e1a)

@@ -104,8 +104,6 @@ class CudaGraphRunner:
             prefix_lens=None,
             position_ids_offsets=position_ids_offsets,
             out_cache_loc=out_cache_loc,
-            out_cache_cont_start=None,
-            out_cache_cont_end=None,
             return_logprob=False,
             top_logprobs_nums=0,
             skip_flashinfer_init=True,
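CUDA graphs replay a fixed kernel sequence on fixed tensor storage, so the captured decode step can only ever exercise one write path; dropping the contiguous arguments pins it to the index-tensor path. A minimal sketch of that general pattern using stock PyTorch (the real CudaGraphRunner is more involved; names here are illustrative):

import torch

if torch.cuda.is_available():
    buf = torch.zeros(1024, 8, device="cuda")
    vals = torch.randn(4, 8, device="cuda")
    # Static index tensor: its storage address is baked into the graph.
    out_cache_loc = torch.zeros(4, dtype=torch.int64, device="cuda")

    # Warm up on a side stream before capture, as CUDA graphs require.
    s = torch.cuda.Stream()
    s.wait_stream(torch.cuda.current_stream())
    with torch.cuda.stream(s):
        buf[out_cache_loc] = vals
    torch.cuda.current_stream().wait_stream(s)

    g = torch.cuda.CUDAGraph()
    with torch.cuda.graph(g):
        buf[out_cache_loc] = vals  # captured: always the scatter path

    # Each step: refill the static tensors in place, then replay.
    out_cache_loc.copy_(torch.tensor([3, 17, 42, 99], device="cuda"))
    vals.copy_(torch.randn(4, 8, device="cuda"))
    g.replay()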
python/sglang/srt/managers/controller/infer_batch.py  (view file @ 10143e1a)

@@ -275,8 +275,6 @@ class Batch:
     prefix_lens: torch.Tensor = None
     position_ids_offsets: torch.Tensor = None
     out_cache_loc: torch.Tensor = None
-    out_cache_cont_start: int = None
-    out_cache_cont_end: int = None

     # For processing logprobs
     return_logprob: bool = False

@@ -566,8 +564,6 @@ class Batch:
         # Alloc mem
         bs = len(self.reqs)
-        alloc_res = self.token_to_kv_pool.alloc_contiguous(bs)
-        if alloc_res is None:
-            self.out_cache_loc = self.token_to_kv_pool.alloc(bs)
+        self.out_cache_loc = self.token_to_kv_pool.alloc(bs)

         if self.out_cache_loc is None:

@@ -575,13 +571,6 @@ class Batch:
                 self.tree_cache.pretty_print()
                 exit()

-            self.out_cache_cont_start = None
-            self.out_cache_cont_end = None
-        else:
-            self.out_cache_loc = alloc_res[0]
-            self.out_cache_cont_start = alloc_res[1]
-            self.out_cache_cont_end = alloc_res[2]
-
         self.req_to_token_pool.req_to_token[
             self.req_pool_indices, self.seq_lens - 1
         ] = self.out_cache_loc

@@ -594,7 +583,7 @@ class Batch:
         self.req_pool_indices = self.req_pool_indices[new_indices]
         self.prefix_lens = None
         self.position_ids_offsets = self.position_ids_offsets[new_indices]
-        self.out_cache_loc = self.out_cache_cont_start = self.out_cache_cont_end = None
+        self.out_cache_loc = None
         self.top_logprobs_nums = [self.top_logprobs_nums[i] for i in unfinished_indices]
         self.return_logprob = any(req.return_logprob for req in self.reqs)

@@ -622,7 +611,7 @@ class Batch:
         self.position_ids_offsets = torch.concat(
             [self.position_ids_offsets, other.position_ids_offsets]
         )
-        self.out_cache_loc = self.out_cache_cont_start = self.out_cache_cont_end = None
+        self.out_cache_loc = None
         self.top_logprobs_nums.extend(other.top_logprobs_nums)
         self.return_logprob = any(req.return_logprob for req in self.reqs)

@@ -729,8 +718,6 @@ class InputMetadata:
     # Output location of the KV cache
     out_cache_loc: torch.Tensor = None
-    out_cache_cont_start: int = None
-    out_cache_cont_end: int = None

     # Output options
     return_logprob: bool = False

@@ -757,8 +744,6 @@ class InputMetadata:
         prefix_lens,
         position_ids_offsets,
         out_cache_loc,
-        out_cache_cont_start=None,
-        out_cache_cont_end=None,
         top_logprobs_nums=None,
         return_logprob=False,
         skip_flashinfer_init=False,

@@ -811,8 +796,6 @@ class InputMetadata:
             req_to_token_pool=model_runner.req_to_token_pool,
             token_to_kv_pool=model_runner.token_to_kv_pool,
             out_cache_loc=out_cache_loc,
-            out_cache_cont_start=out_cache_cont_start,
-            out_cache_cont_end=out_cache_cont_end,
             extend_seq_lens=extend_seq_lens,
             extend_start_loc=extend_start_loc,
             extend_no_prefix=extend_no_prefix,
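The net effect on the decode path is that the two-path allocation collapses into a single call. A hedged sketch of the before/after control flow (hypothetical pool and bs names, simplified from the hunks above; not the literal sglang code):

def prepare_decode_alloc(pool, bs: int):
    # Before (removed): try a contiguous block first, fall back to indices.
    #   alloc_res = pool.alloc_contiguous(bs)
    #   if alloc_res is None:
    #       out_cache_loc = pool.alloc(bs)
    #   else:
    #       out_cache_loc, cont_start, cont_end = alloc_res
    #
    # After: one path; the prefetch buffer makes alloc() cheap anyway.
    out_cache_loc = pool.alloc(bs)
    if out_cache_loc is None:
        # The real code prints the radix tree and exits here.
        raise RuntimeError("Decode out of memory")
    return out_cache_loc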
python/sglang/srt/managers/controller/model_runner.py  (view file @ 10143e1a)

@@ -245,8 +245,6 @@ class ModelRunner:
             prefix_lens=batch.prefix_lens,
             position_ids_offsets=batch.position_ids_offsets,
             out_cache_loc=batch.out_cache_loc,
-            out_cache_cont_start=batch.out_cache_cont_start,
-            out_cache_cont_end=batch.out_cache_cont_end,
             top_logprobs_nums=batch.top_logprobs_nums,
             return_logprob=batch.return_logprob,
         )
python/sglang/srt/memory_pool.py  (view file @ 10143e1a)

@@ -50,6 +50,10 @@ class TokenToKVPool:
             for _ in range(layer_num)
         ]

+        # Prefetch buffer
+        self.prefetch_buffer = torch.empty(0, device="cuda", dtype=torch.int32)
+        self.prefetch_chunk_size = 256
+
         self.clear()

     def get_key_buffer(self, layer_id):

@@ -59,14 +63,29 @@ class TokenToKVPool:
         return self.kv_data[layer_id][:, 1]

     def alloc(self, need_size):
-        select_index = torch.nonzero(self.mem_state == 0).squeeze(1)[:need_size]
-        if select_index.shape[0] < need_size:
+        buffer_len = len(self.prefetch_buffer)
+        if need_size <= buffer_len:
+            select_index = self.prefetch_buffer[:need_size]
+            self.prefetch_buffer = self.prefetch_buffer[need_size:]
+            return select_index.to(torch.int32)
+
+        addition_size = need_size - buffer_len
+        alloc_size = max(addition_size, self.prefetch_chunk_size)
+        select_index = torch.nonzero(self.mem_state == 0).squeeze(1)[:alloc_size]
+
+        if select_index.shape[0] < addition_size:
             return None

         self.add_refs(select_index)
-        return select_index.to(torch.int32)
+
+        self.prefetch_buffer = torch.cat((self.prefetch_buffer, select_index))
+        ret_index = self.prefetch_buffer[:need_size]
+        self.prefetch_buffer = self.prefetch_buffer[need_size:]
+
+        return ret_index.to(torch.int32)

     def alloc_contiguous(self, need_size):
+        # NOTE: This function is deprecated.
         empty_index = torch.nonzero(self.mem_state == 0).squeeze(1)[:need_size]
         if empty_index.shape[0] < need_size:
             return None

@@ -89,7 +108,7 @@ class TokenToKVPool:
         return len(torch.nonzero(self.mem_state).squeeze(1))

     def available_size(self):
-        return torch.sum(self.mem_state == 0).item()
+        return torch.sum(self.mem_state == 0).item() + len(self.prefetch_buffer)

     def add_refs(self, token_index: torch.Tensor):
         self.total_ref_ct += len(token_index)
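The core of the commit is the amortized alloc(). torch.nonzero(self.mem_state == 0) scans the whole free map and, on a CUDA tensor, forces a host sync, so paying that cost on every decode step is expensive. The rewritten alloc() scans roughly once per prefetch_chunk_size (256) allocated tokens: a buffer miss grabs a whole chunk, ref-counts it up front, and parks the surplus in prefetch_buffer; later calls are pure tensor slicing. A self-contained sketch of the same technique (CPU tensors, simplified ref-counting; "mem_state == 0 means free" is inferred from the hunk, and the class name is hypothetical):

import torch


class ChunkedPrefetchPool:
    # Sketch of TokenToKVPool's chunked-prefetch allocation, under the
    # assumptions stated above.
    def __init__(self, size: int, chunk_size: int = 256):
        self.mem_state = torch.zeros(size, dtype=torch.int16)  # 0 = free slot
        self.prefetch_buffer = torch.empty(0, dtype=torch.int64)
        self.prefetch_chunk_size = chunk_size

    def alloc(self, need_size: int):
        buffer_len = len(self.prefetch_buffer)
        if need_size <= buffer_len:
            # Fast path: serve from the buffer; no scan, no sync.
            select_index = self.prefetch_buffer[:need_size]
            self.prefetch_buffer = self.prefetch_buffer[need_size:]
            return select_index

        # Slow path: one scan fetches a whole chunk, not just the shortfall.
        addition_size = need_size - buffer_len
        alloc_size = max(addition_size, self.prefetch_chunk_size)
        select_index = torch.nonzero(self.mem_state == 0).squeeze(1)[:alloc_size]
        if select_index.shape[0] < addition_size:
            return None  # genuinely out of memory

        self.mem_state[select_index] += 1  # stand-in for the real add_refs()

        self.prefetch_buffer = torch.cat((self.prefetch_buffer, select_index))
        ret_index = self.prefetch_buffer[:need_size]
        self.prefetch_buffer = self.prefetch_buffer[need_size:]
        return ret_index

    def available_size(self) -> int:
        # Buffered slots are ref-counted but still unassigned, so they count.
        return torch.sum(self.mem_state == 0).item() + len(self.prefetch_buffer)


pool = ChunkedPrefetchPool(size=1024, chunk_size=256)
a = pool.alloc(10)   # one scan; 246 spare slots end up buffered
b = pool.alloc(100)  # pure slicing, no scan
assert pool.available_size() == 1024 - 256 + 146  # 768 free + 146 buffered

This also explains the matching available_size() change in the last hunk: buffered slots are already marked used in mem_state but can still be handed out, so counting them as available keeps the scheduler's headroom estimate honest.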