sglang commit 69d19188 (unverified)
Authored Jul 20, 2024 by Liangsheng Yin; committed via GitHub, Jul 20, 2024

Decouple kv (#679)

Parent: 4b4a67f8
Showing 3 changed files with 19 additions and 38 deletions (+19, -38).
python/sglang/srt/layers/radix_attention.py   +6  -32
python/sglang/srt/memory_pool.py              +12 -5
python/sglang/srt/server.py                   +1  -1
python/sglang/srt/layers/radix_attention.py
@@ -99,7 +99,7 @@ class RadixAttention(nn.Module):
         else:
             o2, s2 = input_metadata.flashinfer_prefill_wrapper_paged.forward_return_lse(
                 q.contiguous().view(-1, self.tp_q_head_num, self.head_dim),
-                input_metadata.token_to_kv_pool.kv_data[self.layer_id],
+                input_metadata.token_to_kv_pool.get_kv_buffer(self.layer_id),
                 causal=False,
                 sm_scale=self.scaling,
                 logits_soft_cap=self.logit_cap,
@@ -119,7 +119,7 @@ class RadixAttention(nn.Module):
         o = input_metadata.flashinfer_decode_wrapper.forward(
             q.contiguous().view(-1, self.tp_q_head_num, self.head_dim),
-            input_metadata.token_to_kv_pool.kv_data[self.layer_id],
+            input_metadata.token_to_kv_pool.get_kv_buffer(self.layer_id),
             sm_scale=self.scaling,
             logits_soft_cap=self.logit_cap,
         )
@@ -136,33 +136,7 @@ class RadixAttention(nn.Module):
             return self.decode_forward(q, k, v, input_metadata)

     def store_kv_cache(self, cache_k, cache_v, input_metadata: InputMetadata):
-        kv_cache = input_metadata.token_to_kv_pool.kv_data[self.layer_id]
-        _store_kv_cache(cache_k, cache_v, kv_cache, input_metadata.out_cache_loc)
-
-
-try:
-
-    @torch.library.custom_op("mylib::store_kv_cache", mutates_args={"kv_cache"})
-    def _store_kv_cache(
-        k: torch.Tensor,
-        v: torch.Tensor,
-        kv_cache: torch.Tensor,
-        cache_loc: torch.Tensor,
-    ) -> None:
-        kv_cache[cache_loc, 0] = k
-        kv_cache[cache_loc, 1] = v
-
-    @_store_kv_cache.register_fake
-    def _(k, v, kv_cache, cache_loc):
-        pass
-
-except:
-
-    def _store_kv_cache(
-        k: torch.Tensor,
-        v: torch.Tensor,
-        kv_cache: torch.Tensor,
-        cache_loc: torch.Tensor,
-    ) -> None:
-        kv_cache[cache_loc, 0] = k
-        kv_cache[cache_loc, 1] = v
+        k_cache = input_metadata.token_to_kv_pool.get_key_buffer(self.layer_id)
+        v_cache = input_metadata.token_to_kv_pool.get_value_buffer(self.layer_id)
+        k_cache[input_metadata.out_cache_loc] = cache_k
+        v_cache[input_metadata.out_cache_loc] = cache_v
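Net effect of this file: the attention layer stops slicing a fused [size, 2, head_num, head_dim] tensor (and no longer needs the torch.library.custom_op wrapper around the cache write) and instead writes straight into separate key and value buffers. A minimal standalone sketch of that write path, with illustrative shapes and names that are assumptions, not taken from sglang:

import torch

# Illustrative sizes; the real pool derives these from the model config.
size, head_num, head_dim = 8, 2, 4
k_cache = torch.zeros((size + 1, head_num, head_dim))
v_cache = torch.zeros((size + 1, head_num, head_dim))

# Slots assigned to the current tokens (out_cache_loc in the diff above).
out_cache_loc = torch.tensor([3, 5])
cache_k = torch.randn(2, head_num, head_dim)
cache_v = torch.randn(2, head_num, head_dim)

# Advanced indexing scatters each token's K/V into its slot; this replaces
# the old fused writes kv_cache[cache_loc, 0] = k / kv_cache[cache_loc, 1] = v.
k_cache[out_cache_loc] = cache_k
v_cache[out_cache_loc] = cache_v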
python/sglang/srt/memory_pool.py
@@ -57,9 +57,13 @@ class TokenToKVPool:
         # We also add one slot. This slot is used for writing dummy output from padded tokens.
         self.mem_state = torch.ones((self.size + 1,), dtype=torch.bool, device="cuda")

-        # [size, key/value, head_num, head_dim] for each layer
-        self.kv_data = [
-            torch.empty((size + 1, 2, head_num, head_dim), dtype=dtype, device="cuda")
+        # [size, head_num, head_dim] for each layer
+        self.k_buffer = [
+            torch.empty((size + 1, head_num, head_dim), dtype=dtype, device="cuda")
             for _ in range(layer_num)
         ]
+        self.v_buffer = [
+            torch.empty((size + 1, head_num, head_dim), dtype=dtype, device="cuda")
+            for _ in range(layer_num)
+        ]
@@ -71,10 +75,13 @@ class TokenToKVPool:
         self.clear()

     def get_key_buffer(self, layer_id: int):
-        return self.kv_data[layer_id][:, 0]
+        return self.k_buffer[layer_id]

     def get_value_buffer(self, layer_id: int):
-        return self.kv_data[layer_id][:, 1]
+        return self.v_buffer[layer_id]

+    def get_kv_buffer(self, layer_id: int):
+        return self.k_buffer[layer_id], self.v_buffer[layer_id]
+
     def available_size(self):
         return self.can_use_mem_size + len(self.prefetch_buffer)
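The pool now allocates one key tensor and one value tensor per layer instead of a single fused tensor, so get_key_buffer / get_value_buffer return whole tensors rather than [:, 0] / [:, 1] views, and the new get_kv_buffer hands callers the (K, V) pair in one call. A stripped-down sketch of this layout (an assumed simplification on CPU for portability, not sglang's full class):

import torch

class MiniKVPool:
    def __init__(self, size, layer_num, head_num, head_dim, dtype=torch.float16):
        # +1 slot for dummy output from padded tokens, as in the original.
        self.k_buffer = [
            torch.empty((size + 1, head_num, head_dim), dtype=dtype)
            for _ in range(layer_num)
        ]
        self.v_buffer = [
            torch.empty((size + 1, head_num, head_dim), dtype=dtype)
            for _ in range(layer_num)
        ]

    def get_key_buffer(self, layer_id: int):
        return self.k_buffer[layer_id]

    def get_value_buffer(self, layer_id: int):
        return self.v_buffer[layer_id]

    def get_kv_buffer(self, layer_id: int):
        return self.k_buffer[layer_id], self.v_buffer[layer_id]

pool = MiniKVPool(size=16, layer_num=2, head_num=2, head_dim=4)
k, v = pool.get_kv_buffer(0)  # a (K, V) pair; no [:, 0] / [:, 1] slicing needed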
python/sglang/srt/server.py
@@ -182,7 +182,7 @@ def launch_server(
     if not server_args.disable_flashinfer:
         assert_pkg_version(
             "flashinfer",
-            "0.1.0",
+            "0.1.1",
             "Please uninstall the old version and "
             "reinstall the latest version by following the instructions "
             "at https://docs.flashinfer.ai/installation.html.",
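The minimum flashinfer version moves from 0.1.0 to 0.1.1 to match the kernel API the decoupled buffers rely on. assert_pkg_version is sglang's own helper; purely as an illustration of what such a version gate might look like (this re-implementation is an assumption, not the project's code):

from importlib.metadata import PackageNotFoundError, version
from packaging.version import Version

def assert_pkg_version(pkg: str, min_version: str, message: str) -> None:
    # Raise if the package is missing or older than the required version.
    try:
        installed = version(pkg)
    except PackageNotFoundError:
        raise RuntimeError(f"{pkg} is not installed. {message}")
    if Version(installed) < Version(min_version):
        raise RuntimeError(
            f"{pkg}=={installed} is older than required {min_version}. {message}"
        )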