Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
cfbb8c93
Unverified
Commit
cfbb8c93
authored
Mar 21, 2025
by
Nicolò Lucchesi
Committed by
GitHub
Mar 21, 2025
Browse files
[TPU][V1] MHA Pallas backend (#15288)
Signed-off-by:
NickLucche
<
nlucches@redhat.com
>
parent
baec0d4d
Changes
2
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
117 additions
and
2 deletions
+117
-2
tests/v1/tpu/test_mha_attn.py
tests/v1/tpu/test_mha_attn.py
+109
-0
vllm/attention/layer.py
vllm/attention/layer.py
+8
-2
No files found.
tests/v1/tpu/test_mha_attn.py
0 → 100644
View file @
cfbb8c93
# SPDX-License-Identifier: Apache-2.0
"""
Test:
* Tests for MultiHeadAttention layer
"""
import
pytest
import
torch
import
torch_xla
import
torch_xla.core
import
torch_xla.core.xla_model
from
vllm
import
envs
from
vllm.attention.layer
import
MultiHeadAttention
from
vllm.attention.selector
import
_cached_get_attn_backend
from
vllm.platforms
import
current_platform
# Every test in this module is a V1 test: skip the whole module at
# collection time unless VLLM_USE_V1 is enabled.
if not envs.VLLM_USE_V1:
    pytest.skip(
        "Skipping V1 tests. Rerun with `VLLM_USE_V1=1` to test.",
        allow_module_level=True,
    )
@pytest.fixture(autouse=True)
def clear_cache():
    """Empty the cache of `_cached_get_attn_backend` during test setup.

    Registered with `autouse=True`, so it applies to every test in this
    module without being requested explicitly; each test case therefore
    runs without a previously cached result.
    """
    _cached_get_attn_backend.cache_clear()
def ref_attention(
    query: torch.Tensor,
    key: torch.Tensor,
    value: torch.Tensor,
    scale: float,
) -> torch.Tensor:
    """Compute unmasked scaled dot-product attention with plain torch ops.

    Args:
        query: [batch_size, seq_len, num_heads, head_size]
        key: [batch_size, seq_len, num_heads, head_size]
        value: [batch_size, seq_len, num_heads, head_size]
        scale: Factor multiplied into the query-key scores before softmax.

    Returns:
        The attention output in the same layout as the inputs:
        [batch_size, seq_len, num_heads, head_size].
    """
    # Swap the seq_len and num_heads axes so matmul batches over
    # (batch, head) and contracts over head_size / seq_len.
    q_heads = query.transpose(1, 2)
    k_heads = key.transpose(1, 2)
    v_heads = value.transpose(1, 2)

    scores = scale * torch.matmul(q_heads, k_heads.transpose(2, 3))
    # Normalise over the key axis, then cast back to the value dtype.
    probs = torch.softmax(scores, dim=-1).to(v_heads.dtype)

    # Undo the axis swap so the result lines up with the inputs.
    return torch.matmul(probs, v_heads).transpose(1, 2)
# Parameter grids fed to the `pytest.mark.parametrize` decorators of
# `test_mha_attn_forward`; the test runs once per combination.
BATCH_SIZES = [1, 16]
SEQ_LENS = [1]
NUM_HEADS = [1, 16]
NUM_KV_HEADS = [1]
HEAD_SIZES = [64, 80]
@pytest.mark.skipif(not current_platform.is_tpu(),
                    reason="This test needs a TPU")
@pytest.mark.parametrize("batch_size", BATCH_SIZES)
@pytest.mark.parametrize("seq_len", SEQ_LENS)
@pytest.mark.parametrize("num_heads", NUM_HEADS)
@pytest.mark.parametrize("num_kv_heads", NUM_KV_HEADS)
@pytest.mark.parametrize("head_size", HEAD_SIZES)
@pytest.mark.parametrize("device", [torch_xla.core.xla_model.xla_device()])
def test_mha_attn_forward(
    batch_size: int,
    seq_len: int,
    num_heads: int,
    num_kv_heads: int,
    head_size: int,
    device: torch.device,
):
    """Compare `MultiHeadAttention` forward output with `ref_attention`.

    Random q/k/v tensors are created on the XLA device, passed through the
    layer, and the result is checked against the native torch reference
    within a loose tolerance.

    Args:
        batch_size: Leading dimension of the q/k/v inputs.
        seq_len: Second dimension of the q/k/v inputs.
        num_heads: Number of query heads.
        num_kv_heads: Number of key/value heads; must divide `num_heads`.
        head_size: Size of each attention head.
        device: The device returned by `xla_device()`, on which the inputs
            are allocated.
    """
    current_platform.seed_everything(0)

    # These are expected to be f32
    q = torch.randn(batch_size, seq_len, num_heads * head_size, device=device)
    k = torch.randn(batch_size,
                    seq_len,
                    num_kv_heads * head_size,
                    device=device)
    v = torch.randn(batch_size,
                    seq_len,
                    num_kv_heads * head_size,
                    device=device)

    scale = 1.0 / head_size**0.5
    attn = MultiHeadAttention(num_heads,
                              head_size,
                              scale=scale,
                              num_kv_heads=num_kv_heads)
    output = attn(q, k, v)

    assert num_heads % num_kv_heads == 0
    num_queries_per_kv = num_heads // num_kv_heads

    # Split the flattened hidden dimension into (heads, head_size), the
    # layout `ref_attention` expects.
    q = q.reshape(batch_size, seq_len, num_heads, head_size)
    k = k.reshape(batch_size, seq_len, num_kv_heads, head_size)
    v = v.reshape(batch_size, seq_len, num_kv_heads, head_size)

    # The reference has no notion of shared KV heads, so repeat each KV
    # head until there is one per query head.
    if num_queries_per_kv > 1:
        k = torch.repeat_interleave(k, num_queries_per_kv, dim=2)
        v = torch.repeat_interleave(v, num_queries_per_kv, dim=2)

    ref_output = ref_attention(
        q,
        k,
        v,
        scale=scale,
    ).reshape(batch_size, seq_len, num_heads * head_size)

    # torch_xla flash_attn kernel is less accurate but much faster
    torch.testing.assert_close(output, ref_output, atol=1e-2, rtol=1e-3)
vllm/attention/layer.py
View file @
cfbb8c93
...
@@ -281,8 +281,7 @@ class MultiHeadAttention(nn.Module):
...
@@ -281,8 +281,7 @@ class MultiHeadAttention(nn.Module):
backend
=
_Backend
.
XFORMERS
backend
=
_Backend
.
XFORMERS
self
.
attn_backend
=
backend
if
backend
in
{
self
.
attn_backend
=
backend
if
backend
in
{
_Backend
.
TORCH_SDPA
,
_Backend
.
TORCH_SDPA
,
_Backend
.
XFORMERS
,
_Backend
.
PALLAS_VLLM_V1
_Backend
.
XFORMERS
,
}
else
_Backend
.
TORCH_SDPA
}
else
_Backend
.
TORCH_SDPA
def
forward
(
def
forward
(
...
@@ -320,6 +319,13 @@ class MultiHeadAttention(nn.Module):
...
@@ -320,6 +319,13 @@ class MultiHeadAttention(nn.Module):
value
,
value
,
scale
=
self
.
scale
)
scale
=
self
.
scale
)
out
=
out
.
transpose
(
1
,
2
)
out
=
out
.
transpose
(
1
,
2
)
elif
self
.
attn_backend
==
_Backend
.
PALLAS_VLLM_V1
:
query
,
key
,
value
=
(
x
.
transpose
(
1
,
2
)
for
x
in
(
query
,
key
,
value
))
from
torch_xla.experimental.custom_kernel
import
flash_attention
out
=
flash_attention
(
query
,
key
,
value
,
sm_scale
=
self
.
scale
)
out
=
out
.
transpose
(
1
,
2
)
return
out
.
reshape
(
bsz
,
q_len
,
-
1
)
return
out
.
reshape
(
bsz
,
q_len
,
-
1
)
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment