Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
8b317c6d
Unverified
Commit
8b317c6d
authored
Apr 10, 2024
by
James Whedbee
Committed by
GitHub
Apr 10, 2024
Browse files
[Model][AMD] ROCm support for 256 head dims for Gemma (#3972)
parent
bd3c144e
Changes
1
Hide whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
2 additions
and
3 deletions
+2
-3
vllm/attention/ops/triton_flash_attention.py
vllm/attention/ops/triton_flash_attention.py
+2
-3
No files found.
vllm/attention/ops/triton_flash_attention.py
View file @
8b317c6d
...
@@ -677,8 +677,7 @@ def check_args(
...
@@ -677,8 +677,7 @@ def check_args(
assert
q
.
shape
[
-
1
]
==
k
.
shape
[
-
1
]
and
q
.
shape
[
-
1
]
==
v
.
shape
[
-
1
]
assert
q
.
shape
[
-
1
]
==
k
.
shape
[
-
1
]
and
q
.
shape
[
-
1
]
==
v
.
shape
[
-
1
]
# TODO: Change assert if we support qkl f8 and v f16
# TODO: Change assert if we support qkl f8 and v f16
assert
q
.
dtype
==
k
.
dtype
and
q
.
dtype
==
v
.
dtype
assert
q
.
dtype
==
k
.
dtype
and
q
.
dtype
==
v
.
dtype
# TODO: Fix assert to check head size <=256 once supported
assert
head_size
<=
256
assert
head_size
<=
128
assert
o
.
shape
==
q
.
shape
assert
o
.
shape
==
q
.
shape
assert
(
nheads_q
%
nheads_k
)
==
0
assert
(
nheads_q
%
nheads_k
)
==
0
...
@@ -729,7 +728,7 @@ class _attention(torch.autograd.Function):
...
@@ -729,7 +728,7 @@ class _attention(torch.autograd.Function):
o_strides
=
(
o
.
stride
(
0
),
o
.
stride
(
2
),
o
.
stride
(
1
),
o
.
stride
(
3
))
o_strides
=
(
o
.
stride
(
0
),
o
.
stride
(
2
),
o
.
stride
(
1
),
o
.
stride
(
3
))
# Get closest power of 2 over or equal to 32.
# Get closest power of 2 over or equal to 32.
unpadded_head_dims
=
{
32
,
64
,
128
}
unpadded_head_dims
=
{
32
,
64
,
128
,
256
}
if
head_size
not
in
unpadded_head_dims
:
if
head_size
not
in
unpadded_head_dims
:
padded_d_model
=
None
padded_d_model
=
None
for
i
in
unpadded_head_dims
:
for
i
in
unpadded_head_dims
:
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment