Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
f1fc0510
Unverified
Commit
f1fc0510
authored
Jan 25, 2025
by
Isotr0py
Committed by
GitHub
Jan 25, 2025
Browse files
[Misc] Add FA2 support to ViT MHA layer (#12355)
Signed-off-by:
Isotr0py
<
2037008807@qq.com
>
parent
bf21481d
Changes
2
Show whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
146 additions
and
5 deletions
+146
-5
tests/kernels/test_mha_attn.py
tests/kernels/test_mha_attn.py
+126
-0
vllm/attention/layer.py
vllm/attention/layer.py
+20
-5
No files found.
tests/kernels/test_mha_attn.py
0 → 100644
View file @
f1fc0510
"""
Test:
* Tests for MultiHeadAttention layer
"""
from
unittest.mock
import
patch
import
pytest
import
torch
from
vllm.attention.layer
import
MultiHeadAttention
from
vllm.attention.selector
import
_Backend
,
_cached_get_attn_backend
from
vllm.platforms
import
current_platform
from
vllm.platforms.cpu
import
CpuPlatform
from
vllm.platforms.cuda
import
CudaPlatform
from
vllm.platforms.rocm
import
RocmPlatform
@pytest.fixture(autouse=True)
def clear_cache() -> None:
    """Reset the cached attention-backend lookup before every test.

    The fixture is autouse, so each test case in this module starts with an
    empty ``_cached_get_attn_backend`` cache instead of a result cached by a
    previous case.
    """
    _cached_get_attn_backend.cache_clear()
@pytest.mark.parametrize("device", ["cpu", "hip", "cuda"])
def test_mha_attn_platform(device: str):
    """
    Test which attention backend MultiHeadAttention selects per platform.

    ``vllm.attention.selector.current_platform`` is patched with the platform
    object matching ``device`` and the layer's ``attn_backend`` is compared
    against the expected backend:

    * "cpu"  (CpuPlatform):  head_size 64 -> TORCH_SDPA
    * "hip"  (RocmPlatform): head_size 64 -> TORCH_SDPA
    * "cuda" (CudaPlatform): head_size 64 -> FLASH_ATTN,
                             head_size 72 -> XFORMERS

    The default dtype is switched to float16 for the duration of the test and
    restored afterwards so the change does not leak into other tests.
    """
    # Each case is (platform class, head_size, expected backend). Platform
    # objects are instantiated lazily so only the one under test is built.
    if device == "cpu":
        cases = [(CpuPlatform, 64, _Backend.TORCH_SDPA)]
    elif device == "hip":
        cases = [(RocmPlatform, 64, _Backend.TORCH_SDPA)]
    else:
        cases = [
            (CudaPlatform, 64, _Backend.FLASH_ATTN),
            (CudaPlatform, 72, _Backend.XFORMERS),
        ]

    # MultiHeadAttention reads torch.get_default_dtype() when picking a
    # backend, so the default must be set globally -- but put it back after.
    previous_dtype = torch.get_default_dtype()
    torch.set_default_dtype(torch.float16)
    try:
        for platform_cls, head_size, expected_backend in cases:
            with patch("vllm.attention.selector.current_platform",
                       platform_cls()):
                attn = MultiHeadAttention(16, head_size, scale=1)
                assert attn.attn_backend == expected_backend
    finally:
        torch.set_default_dtype(previous_dtype)
def ref_attention(
    query: torch.Tensor,
    key: torch.Tensor,
    value: torch.Tensor,
    scale: float,
) -> torch.Tensor:
    """
    Reference (native PyTorch) scaled dot-product attention, no masking.

    - query, key, value: [batch_size, seq_len, num_heads, head_size]
    - scale: multiplier applied to the raw query-key scores before softmax
    - returns: tensor laid out like the inputs,
      [batch_size, seq_len, num_heads, head_size]
    """
    # Swap the seq_len and num_heads axes so matmul works per head.
    q_heads = query.transpose(1, 2)
    k_heads = key.transpose(1, 2)
    v_heads = value.transpose(1, 2)

    scores = scale * torch.matmul(q_heads, k_heads.transpose(2, 3))
    # Normalise over the key axis, then cast back to the dtype of the values.
    probs = torch.softmax(scores, dim=-1).to(v_heads.dtype)

    # Undo the axis swap so the result matches the input layout.
    return torch.matmul(probs, v_heads).transpose(1, 2)
# Parametrization grids for test_mha_attn_forward.
BATCH_SIZES = [1, 16]
SEQ_LENS = [1]
NUM_HEADS = [1, 16]
NUM_KV_HEADS = [1]
HEAD_SIZES = [64, 80]

# flshattF and tritonflashattF supported: {torch.float16, torch.bfloat16}
if current_platform.is_rocm():
    DTYPES = [torch.half, torch.bfloat16]
else:
    DTYPES = [torch.half, torch.bfloat16, torch.float]

CUDA_DEVICES = ["cuda"]
@pytest.mark.parametrize("batch_size", BATCH_SIZES)
@pytest.mark.parametrize("seq_len", SEQ_LENS)
@pytest.mark.parametrize("num_heads", NUM_HEADS)
@pytest.mark.parametrize("num_kv_heads", NUM_KV_HEADS)
@pytest.mark.parametrize("head_size", HEAD_SIZES)
@pytest.mark.parametrize("dtype", DTYPES)
@pytest.mark.parametrize("device", CUDA_DEVICES)
def test_mha_attn_forward(
    batch_size: int,
    seq_len: int,
    num_heads: int,
    num_kv_heads: int,
    head_size: int,
    dtype: torch.dtype,
    device: str,
):
    """
    Compare the MultiHeadAttention forward pass against ref_attention.

    Random q/k/v of shape [batch_size, seq_len, heads * head_size] are run
    through the layer; the same tensors are reshaped per head (with k/v
    repeated when num_heads > num_kv_heads) and fed to ref_attention, and the
    two outputs are compared with torch.testing.assert_close.

    The default dtype and device are only changed for the duration of the
    test, so they do not leak into other tests in the same session.
    """
    current_platform.seed_everything(0)

    # MultiHeadAttention reads torch.get_default_dtype() at construction
    # time, so the default must be changed globally -- and restored after.
    previous_dtype = torch.get_default_dtype()
    torch.set_default_dtype(dtype)
    try:
        # Scope the default device to this block rather than setting it
        # process-wide with torch.set_default_device.
        with torch.device(device):
            q = torch.randn(batch_size, seq_len, num_heads * head_size)
            k = torch.randn(batch_size, seq_len, num_kv_heads * head_size)
            v = torch.randn(batch_size, seq_len, num_kv_heads * head_size)
            scale = 1.0 / head_size**0.5
            attn = MultiHeadAttention(num_heads,
                                      head_size,
                                      scale=scale,
                                      num_kv_heads=num_kv_heads)
            output = attn(q, k, v)

            assert num_heads % num_kv_heads == 0
            num_queries_per_kv = num_heads // num_kv_heads

            q = q.reshape(batch_size, seq_len, num_heads, head_size)
            k = k.reshape(batch_size, seq_len, num_kv_heads, head_size)
            v = v.reshape(batch_size, seq_len, num_kv_heads, head_size)
            if num_queries_per_kv > 1:
                # Expand k/v along the head axis to match the query heads.
                k = torch.repeat_interleave(k, num_queries_per_kv, dim=2)
                v = torch.repeat_interleave(v, num_queries_per_kv, dim=2)

            ref_output = ref_attention(
                q,
                k,
                v,
                scale=scale,
            ).reshape(batch_size, seq_len, num_heads * head_size)
            torch.testing.assert_close(output, ref_output)
    finally:
        torch.set_default_dtype(previous_dtype)
vllm/attention/layer.py
View file @
f1fc0510
...
...
@@ -210,6 +210,9 @@ class MultiHeadAttention(nn.Module):
self
.
scale
=
scale
self
.
num_kv_heads
=
num_heads
if
num_kv_heads
is
None
else
num_kv_heads
assert
self
.
num_heads
%
self
.
num_kv_heads
==
0
self
.
num_queries_per_kv
=
self
.
num_heads
//
self
.
num_kv_heads
dtype
=
torch
.
get_default_dtype
()
attn_backend
=
get_attn_backend
(
head_size
,
dtype
,
...
...
@@ -217,11 +220,12 @@ class MultiHeadAttention(nn.Module):
block_size
=
16
,
is_attention_free
=
False
)
backend
=
backend_name_to_enum
(
attn_backend
.
get_name
())
if
backend
in
{
_Backend
.
FLASH_ATTN
,
_Backend
.
FLASH_ATTN_VLLM_V1
}:
backend
=
_Backend
.
XFORMERS
self
.
attn_backend
=
backend
if
backend
in
{
_Backend
.
TORCH_SDPA
,
_Backend
.
XFORMERS
_Backend
.
TORCH_SDPA
,
_Backend
.
XFORMERS
,
_Backend
.
FLASH_ATTN
,
_Backend
.
FLASH_ATTN_VLLM_V1
,
}
else
_Backend
.
TORCH_SDPA
def
forward
(
...
...
@@ -231,7 +235,6 @@ class MultiHeadAttention(nn.Module):
value
:
torch
.
Tensor
,
)
->
torch
.
Tensor
:
"""Input shape: batch_size x seq_len x hidden_size"""
# TODO(Isotr0py): Use existing backend implementations and support FA2
bsz
,
q_len
,
_
=
query
.
size
()
kv_len
=
key
.
size
(
1
)
...
...
@@ -239,7 +242,19 @@ class MultiHeadAttention(nn.Module):
key
=
key
.
view
(
bsz
,
kv_len
,
self
.
num_kv_heads
,
self
.
head_size
)
value
=
value
.
view
(
bsz
,
kv_len
,
self
.
num_kv_heads
,
self
.
head_size
)
if
self
.
attn_backend
==
_Backend
.
XFORMERS
:
if
(
num_repeat
:
=
self
.
num_queries_per_kv
)
>
1
:
# Handle MQA and GQA
key
=
torch
.
repeat_interleave
(
key
,
num_repeat
,
dim
=
2
)
value
=
torch
.
repeat_interleave
(
value
,
num_repeat
,
dim
=
2
)
if
self
.
attn_backend
in
{
_Backend
.
FLASH_ATTN
,
_Backend
.
FLASH_ATTN_VLLM_V1
,
}:
from
vllm.vllm_flash_attn
import
flash_attn_func
out
=
flash_attn_func
(
query
,
key
,
value
,
softmax_scale
=
self
.
scale
)
elif
self
.
attn_backend
==
_Backend
.
XFORMERS
:
from
xformers
import
ops
as
xops
out
=
xops
.
memory_efficient_attention_forward
(
query
,
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment