Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
498eb5cf
Unverified
Commit
498eb5cf
authored
Apr 03, 2024
by
Woosuk Kwon
Committed by
GitHub
Apr 04, 2024
Browse files
[Bugfix] Add kv_scale input parameter to CPU backend (#3840)
parent
537ee25f
Changes
4
Show whitespace changes
Inline
Side-by-side
Showing
4 changed files
with
12 additions
and
5 deletions
+12
-5
csrc/cpu/attention.cpp
csrc/cpu/attention.cpp
+4
-2
csrc/cpu/cache.cpp
csrc/cpu/cache.cpp
+3
-1
vllm/attention/backends/torch_sdpa.py
vllm/attention/backends/torch_sdpa.py
+4
-1
vllm/attention/ops/paged_attn.py
vllm/attention/ops/paged_attn.py
+1
-1
No files found.
csrc/cpu/attention.cpp
View file @
498eb5cf
...
...
@@ -419,7 +419,8 @@ void paged_attention_v1(torch::Tensor &out, torch::Tensor &query,
torch
::
Tensor
&
context_lens
,
int
block_size
,
int
max_context_len
,
const
c10
::
optional
<
torch
::
Tensor
>
&
alibi_slopes
,
const
std
::
string
&
kv_cache_dtype
)
{
const
std
::
string
&
kv_cache_dtype
,
float
kv_scale
)
{
TORCH_CHECK
(
kv_scale
==
1.0
f
);
VLLM_DISPATCH_FLOATING_TYPES
(
query
.
scalar_type
(),
"paged_attention_v1_impl"
,
[
&
]
{
CPU_KERNEL_GUARD_IN
(
paged_attention_v1_impl
)
...
...
@@ -734,7 +735,8 @@ void paged_attention_v2(torch::Tensor &out, torch::Tensor &exp_sums,
torch
::
Tensor
&
context_lens
,
int
block_size
,
int
max_context_len
,
const
c10
::
optional
<
torch
::
Tensor
>
&
alibi_slopes
,
const
std
::
string
&
kv_cache_dtype
)
{
const
std
::
string
&
kv_cache_dtype
,
float
kv_scale
)
{
TORCH_CHECK
(
kv_scale
==
1.0
f
);
VLLM_DISPATCH_FLOATING_TYPES
(
query
.
scalar_type
(),
"paged_attention_v2_impl"
,
[
&
]
{
CPU_KERNEL_GUARD_IN
(
paged_attention_v2_impl
)
...
...
csrc/cpu/cache.cpp
View file @
498eb5cf
...
...
@@ -111,7 +111,9 @@ void copy_blocks(std::vector<torch::Tensor> &key_caches,
void
reshape_and_cache
(
torch
::
Tensor
&
key
,
torch
::
Tensor
&
value
,
torch
::
Tensor
&
key_cache
,
torch
::
Tensor
&
value_cache
,
torch
::
Tensor
&
slot_mapping
,
const
std
::
string
&
kv_cache_dtype
)
{
const
std
::
string
&
kv_cache_dtype
,
float
kv_scale
)
{
TORCH_CHECK
(
kv_scale
==
1.0
f
);
int
num_tokens
=
key
.
size
(
0
);
int
num_heads
=
key
.
size
(
1
);
int
head_size
=
key
.
size
(
2
);
...
...
vllm/attention/backends/torch_sdpa.py
View file @
498eb5cf
...
...
@@ -114,6 +114,7 @@ class TorchSDPABackendImpl(AttentionImpl):
value
:
torch
.
Tensor
,
kv_cache
:
Optional
[
torch
.
Tensor
],
attn_metadata
:
TorchSDPAMetadata
,
kv_scale
:
float
,
)
->
torch
.
Tensor
:
"""Forward pass with torch SDPA and PagedAttention.
...
...
@@ -138,7 +139,8 @@ class TorchSDPABackendImpl(AttentionImpl):
PagedAttention
.
write_to_paged_cache
(
key
,
value
,
key_cache
,
value_cache
,
attn_metadata
.
slot_mapping
,
attn_metadata
.
kv_cache_dtype
)
attn_metadata
.
kv_cache_dtype
,
kv_scale
)
if
attn_metadata
.
is_prompt
:
if
(
kv_cache
is
None
or
attn_metadata
.
block_tables
.
numel
()
==
0
):
...
...
@@ -199,6 +201,7 @@ class TorchSDPABackendImpl(AttentionImpl):
self
.
num_kv_heads
,
self
.
scale
,
self
.
alibi_slopes
,
kv_scale
,
)
# Reshape the output tensor.
...
...
vllm/attention/ops/paged_attn.py
View file @
498eb5cf
...
...
@@ -97,7 +97,7 @@ class PagedAttention:
num_kv_heads
:
int
,
scale
:
float
,
alibi_slopes
:
Optional
[
torch
.
Tensor
],
kv_scale
,
kv_scale
:
float
,
)
->
torch
.
Tensor
:
output
=
torch
.
empty_like
(
query
)
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment