Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
9a31a817
Unverified
Commit
9a31a817
authored
May 16, 2024
by
Woosuk Kwon
Committed by
GitHub
May 16, 2024
Browse files
[Bugfix] Fix FP8 KV cache support (#4869)
parent
2060e936
Changes
6
Hide whitespace changes
Inline
Side-by-side
Showing
6 changed files
with
26 additions
and
26 deletions
+26
-26
vllm/attention/backends/flash_attn.py
vllm/attention/backends/flash_attn.py
+5
-5
vllm/attention/backends/flashinfer.py
vllm/attention/backends/flashinfer.py
+5
-5
vllm/attention/backends/rocm_flash_attn.py
vllm/attention/backends/rocm_flash_attn.py
+5
-5
vllm/attention/backends/torch_sdpa.py
vllm/attention/backends/torch_sdpa.py
+5
-5
vllm/attention/backends/xformers.py
vllm/attention/backends/xformers.py
+5
-5
vllm/attention/layer.py
vllm/attention/layer.py
+1
-1
No files found.
vllm/attention/backends/flash_attn.py
View file @
9a31a817
...
@@ -200,15 +200,15 @@ class FlashAttentionImpl(AttentionImpl):
...
@@ -200,15 +200,15 @@ class FlashAttentionImpl(AttentionImpl):
num_heads
:
int
,
num_heads
:
int
,
head_size
:
int
,
head_size
:
int
,
scale
:
float
,
scale
:
float
,
num_kv_heads
:
Optional
[
int
]
=
None
,
num_kv_heads
:
int
,
alibi_slopes
:
Optional
[
List
[
float
]]
=
None
,
alibi_slopes
:
Optional
[
List
[
float
]],
sliding_window
:
Optional
[
int
]
=
None
,
sliding_window
:
Optional
[
int
],
kv_cache_dtype
:
str
=
"auto"
,
kv_cache_dtype
:
str
,
)
->
None
:
)
->
None
:
self
.
num_heads
=
num_heads
self
.
num_heads
=
num_heads
self
.
head_size
=
head_size
self
.
head_size
=
head_size
self
.
scale
=
float
(
scale
)
self
.
scale
=
float
(
scale
)
self
.
num_kv_heads
=
num_heads
if
num_kv_heads
is
None
else
num_kv_heads
self
.
num_kv_heads
=
num_kv_heads
if
alibi_slopes
is
not
None
:
if
alibi_slopes
is
not
None
:
alibi_slopes
=
torch
.
tensor
(
alibi_slopes
,
dtype
=
torch
.
float32
)
alibi_slopes
=
torch
.
tensor
(
alibi_slopes
,
dtype
=
torch
.
float32
)
self
.
alibi_slopes
=
alibi_slopes
self
.
alibi_slopes
=
alibi_slopes
...
...
vllm/attention/backends/flashinfer.py
View file @
9a31a817
...
@@ -164,15 +164,15 @@ class FlashInferImpl(AttentionImpl):
...
@@ -164,15 +164,15 @@ class FlashInferImpl(AttentionImpl):
num_heads
:
int
,
num_heads
:
int
,
head_size
:
int
,
head_size
:
int
,
scale
:
float
,
scale
:
float
,
num_kv_heads
:
Optional
[
int
]
=
None
,
num_kv_heads
:
int
,
alibi_slopes
:
Optional
[
List
[
float
]]
=
None
,
alibi_slopes
:
Optional
[
List
[
float
]],
sliding_window
:
Optional
[
int
]
=
None
,
sliding_window
:
Optional
[
int
],
kv_cache_dtype
:
str
=
"auto"
,
kv_cache_dtype
:
str
,
)
->
None
:
)
->
None
:
self
.
num_heads
=
num_heads
self
.
num_heads
=
num_heads
self
.
head_size
=
head_size
self
.
head_size
=
head_size
self
.
scale
=
float
(
scale
)
self
.
scale
=
float
(
scale
)
self
.
num_kv_heads
=
num_heads
if
num_kv_heads
is
None
else
num_kv_heads
self
.
num_kv_heads
=
num_kv_heads
if
alibi_slopes
is
not
None
:
if
alibi_slopes
is
not
None
:
alibi_slopes
=
torch
.
tensor
(
alibi_slopes
,
dtype
=
torch
.
float32
)
alibi_slopes
=
torch
.
tensor
(
alibi_slopes
,
dtype
=
torch
.
float32
)
self
.
alibi_slopes
=
alibi_slopes
self
.
alibi_slopes
=
alibi_slopes
...
...
vllm/attention/backends/rocm_flash_attn.py
View file @
9a31a817
...
@@ -197,15 +197,15 @@ class ROCmFlashAttentionImpl(AttentionImpl):
...
@@ -197,15 +197,15 @@ class ROCmFlashAttentionImpl(AttentionImpl):
num_heads
:
int
,
num_heads
:
int
,
head_size
:
int
,
head_size
:
int
,
scale
:
float
,
scale
:
float
,
num_kv_heads
:
Optional
[
int
]
=
None
,
num_kv_heads
:
int
,
alibi_slopes
:
Optional
[
List
[
float
]]
=
None
,
alibi_slopes
:
Optional
[
List
[
float
]],
sliding_window
:
Optional
[
int
]
=
None
,
sliding_window
:
Optional
[
int
],
kv_cache_dtype
:
str
=
"auto"
,
kv_cache_dtype
:
str
,
)
->
None
:
)
->
None
:
self
.
num_heads
=
num_heads
self
.
num_heads
=
num_heads
self
.
head_size
=
head_size
self
.
head_size
=
head_size
self
.
scale
=
float
(
scale
)
self
.
scale
=
float
(
scale
)
self
.
num_kv_heads
=
num_heads
if
num_kv_heads
is
None
else
num_kv_heads
self
.
num_kv_heads
=
num_kv_heads
if
alibi_slopes
is
not
None
:
if
alibi_slopes
is
not
None
:
alibi_slopes
=
torch
.
tensor
(
alibi_slopes
,
dtype
=
torch
.
float32
)
alibi_slopes
=
torch
.
tensor
(
alibi_slopes
,
dtype
=
torch
.
float32
)
self
.
alibi_slopes
=
alibi_slopes
self
.
alibi_slopes
=
alibi_slopes
...
...
vllm/attention/backends/torch_sdpa.py
View file @
9a31a817
...
@@ -96,15 +96,15 @@ class TorchSDPABackendImpl(AttentionImpl[TorchSDPAMetadata]):
...
@@ -96,15 +96,15 @@ class TorchSDPABackendImpl(AttentionImpl[TorchSDPAMetadata]):
num_heads
:
int
,
num_heads
:
int
,
head_size
:
int
,
head_size
:
int
,
scale
:
float
,
scale
:
float
,
num_kv_heads
:
Optional
[
int
]
=
None
,
num_kv_heads
:
int
,
alibi_slopes
:
Optional
[
List
[
float
]]
=
None
,
alibi_slopes
:
Optional
[
List
[
float
]],
sliding_window
:
Optional
[
int
]
=
None
,
sliding_window
:
Optional
[
int
],
kv_cache_dtype
:
str
=
"auto"
,
kv_cache_dtype
:
str
,
)
->
None
:
)
->
None
:
self
.
num_heads
=
num_heads
self
.
num_heads
=
num_heads
self
.
head_size
=
head_size
self
.
head_size
=
head_size
self
.
scale
=
float
(
scale
)
self
.
scale
=
float
(
scale
)
self
.
num_kv_heads
=
num_heads
if
num_kv_heads
is
None
else
num_kv_heads
self
.
num_kv_heads
=
num_kv_heads
if
alibi_slopes
is
not
None
:
if
alibi_slopes
is
not
None
:
alibi_slopes
=
torch
.
tensor
(
alibi_slopes
,
dtype
=
torch
.
float32
)
alibi_slopes
=
torch
.
tensor
(
alibi_slopes
,
dtype
=
torch
.
float32
)
self
.
alibi_slopes
=
alibi_slopes
self
.
alibi_slopes
=
alibi_slopes
...
...
vllm/attention/backends/xformers.py
View file @
9a31a817
...
@@ -208,15 +208,15 @@ class XFormersImpl(AttentionImpl[XFormersMetadata]):
...
@@ -208,15 +208,15 @@ class XFormersImpl(AttentionImpl[XFormersMetadata]):
num_heads
:
int
,
num_heads
:
int
,
head_size
:
int
,
head_size
:
int
,
scale
:
float
,
scale
:
float
,
num_kv_heads
:
Optional
[
int
]
=
None
,
num_kv_heads
:
int
,
alibi_slopes
:
Optional
[
List
[
float
]]
=
None
,
alibi_slopes
:
Optional
[
List
[
float
]],
sliding_window
:
Optional
[
int
]
=
None
,
sliding_window
:
Optional
[
int
],
kv_cache_dtype
:
str
=
"auto"
,
kv_cache_dtype
:
str
,
)
->
None
:
)
->
None
:
self
.
num_heads
=
num_heads
self
.
num_heads
=
num_heads
self
.
head_size
=
head_size
self
.
head_size
=
head_size
self
.
scale
=
float
(
scale
)
self
.
scale
=
float
(
scale
)
self
.
num_kv_heads
=
num_heads
if
num_kv_heads
is
None
else
num_kv_heads
self
.
num_kv_heads
=
num_kv_heads
if
alibi_slopes
is
not
None
:
if
alibi_slopes
is
not
None
:
alibi_slopes
=
torch
.
tensor
(
alibi_slopes
,
dtype
=
torch
.
float32
)
alibi_slopes
=
torch
.
tensor
(
alibi_slopes
,
dtype
=
torch
.
float32
)
self
.
alibi_slopes
=
alibi_slopes
self
.
alibi_slopes
=
alibi_slopes
...
...
vllm/attention/layer.py
View file @
9a31a817
...
@@ -48,7 +48,7 @@ class Attention(nn.Module):
...
@@ -48,7 +48,7 @@ class Attention(nn.Module):
block_size
)
block_size
)
impl_cls
=
attn_backend
.
get_impl_cls
()
impl_cls
=
attn_backend
.
get_impl_cls
()
self
.
impl
=
impl_cls
(
num_heads
,
head_size
,
scale
,
num_kv_heads
,
self
.
impl
=
impl_cls
(
num_heads
,
head_size
,
scale
,
num_kv_heads
,
alibi_slopes
,
sliding_window
)
alibi_slopes
,
sliding_window
,
kv_cache_dtype
)
def
forward
(
def
forward
(
self
,
self
,
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment