Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
72f48804
Unverified
Commit
72f48804
authored
Jan 26, 2025
by
Tyler Michael Smith
Committed by
GitHub
Jan 26, 2025
Browse files
[Bugfix/CI] Fix broken kernels/test_mha.py (#12450)
parent
aa2cd2c4
Changes
2
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
10 additions
and
2 deletions
+10
-2
tests/kernels/test_mha_attn.py
tests/kernels/test_mha_attn.py
+2
-2
vllm/attention/layer.py
vllm/attention/layer.py
+8
-0
No files found.
tests/kernels/test_mha_attn.py
View file @
72f48804
...
@@ -26,7 +26,7 @@ def clear_cache():
...
@@ -26,7 +26,7 @@ def clear_cache():
@
pytest
.
mark
.
parametrize
(
"device"
,
[
"cpu"
,
"hip"
,
"cuda"
])
@
pytest
.
mark
.
parametrize
(
"device"
,
[
"cpu"
,
"hip"
,
"cuda"
])
def
test_mha_attn_platform
(
device
:
str
):
def
test_mha_attn_platform
(
device
:
str
):
"""
"""
Test
that
the attention selector between different platform and device.
Test the attention selector between different platform and device.
"""
"""
torch
.
set_default_dtype
(
torch
.
float16
)
torch
.
set_default_dtype
(
torch
.
float16
)
...
@@ -41,7 +41,7 @@ def test_mha_attn_platform(device: str):
...
@@ -41,7 +41,7 @@ def test_mha_attn_platform(device: str):
else
:
else
:
with
patch
(
"vllm.attention.selector.current_platform"
,
CudaPlatform
()):
with
patch
(
"vllm.attention.selector.current_platform"
,
CudaPlatform
()):
attn
=
MultiHeadAttention
(
16
,
64
,
scale
=
1
)
attn
=
MultiHeadAttention
(
16
,
64
,
scale
=
1
)
assert
attn
.
attn_backend
==
_Backend
.
FLASH_ATTN
assert
attn
.
attn_backend
==
_Backend
.
XFORMERS
with
patch
(
"vllm.attention.selector.current_platform"
,
CudaPlatform
()):
with
patch
(
"vllm.attention.selector.current_platform"
,
CudaPlatform
()):
attn
=
MultiHeadAttention
(
16
,
72
,
scale
=
1
)
attn
=
MultiHeadAttention
(
16
,
72
,
scale
=
1
)
...
...
vllm/attention/layer.py
View file @
72f48804
...
@@ -210,6 +210,9 @@ class MultiHeadAttention(nn.Module):
...
@@ -210,6 +210,9 @@ class MultiHeadAttention(nn.Module):
self
.
scale
=
scale
self
.
scale
=
scale
self
.
num_kv_heads
=
num_heads
if
num_kv_heads
is
None
else
num_kv_heads
self
.
num_kv_heads
=
num_heads
if
num_kv_heads
is
None
else
num_kv_heads
assert
self
.
num_heads
%
self
.
num_kv_heads
==
0
self
.
num_queries_per_kv
=
self
.
num_heads
//
self
.
num_kv_heads
dtype
=
torch
.
get_default_dtype
()
dtype
=
torch
.
get_default_dtype
()
attn_backend
=
get_attn_backend
(
head_size
,
attn_backend
=
get_attn_backend
(
head_size
,
dtype
,
dtype
,
...
@@ -240,6 +243,11 @@ class MultiHeadAttention(nn.Module):
...
@@ -240,6 +243,11 @@ class MultiHeadAttention(nn.Module):
key
=
key
.
view
(
bsz
,
kv_len
,
self
.
num_kv_heads
,
self
.
head_size
)
key
=
key
.
view
(
bsz
,
kv_len
,
self
.
num_kv_heads
,
self
.
head_size
)
value
=
value
.
view
(
bsz
,
kv_len
,
self
.
num_kv_heads
,
self
.
head_size
)
value
=
value
.
view
(
bsz
,
kv_len
,
self
.
num_kv_heads
,
self
.
head_size
)
if
(
num_repeat
:
=
self
.
num_queries_per_kv
)
>
1
:
# Handle MQA and GQA
key
=
torch
.
repeat_interleave
(
key
,
num_repeat
,
dim
=
2
)
value
=
torch
.
repeat_interleave
(
value
,
num_repeat
,
dim
=
2
)
if
self
.
attn_backend
==
_Backend
.
XFORMERS
:
if
self
.
attn_backend
==
_Backend
.
XFORMERS
:
from
xformers
import
ops
as
xops
from
xformers
import
ops
as
xops
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment