Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
5ee10e99
Unverified
Commit
5ee10e99
authored
Mar 06, 2025
by
Nicolò Lucchesi
Committed by
GitHub
Mar 05, 2025
Browse files
[Bugfix][CI] ALiBi test case in xformers multi_query_kv_attention (#11301)
parent
3dbd2d81
Changes
3
Hide whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
83 additions
and
22 deletions
+83
-22
tests/kernels/test_attention.py
tests/kernels/test_attention.py
+78
-17
tests/kernels/test_prefix_prefill.py
tests/kernels/test_prefix_prefill.py
+5
-3
vllm/attention/backends/xformers.py
vllm/attention/backends/xformers.py
+0
-2
No files found.
tests/kernels/test_attention.py
View file @
5ee10e99
...
...
@@ -17,6 +17,8 @@ if not current_platform.is_rocm():
from
xformers
import
ops
as
xops
from
xformers.ops.fmha.attn_bias
import
BlockDiagonalCausalMask
from
vllm.attention.backends.xformers
import
_make_alibi_bias
FLOAT32_BYTES
=
torch
.
finfo
(
torch
.
float
).
bits
//
8
# This will change depending on the compute capability.
# - 512 as a buffer
...
...
@@ -345,20 +347,26 @@ def ref_multi_query_kv_attention(
key
:
torch
.
Tensor
,
value
:
torch
.
Tensor
,
scale
:
float
,
alibi_bias
:
Optional
[
list
[
torch
.
Tensor
]],
dtype
:
torch
.
dtype
,
)
->
torch
.
Tensor
:
num_seqs
=
len
(
cu_seq_lens
)
-
1
ref_outputs
:
list
[
torch
.
Tensor
]
=
[]
if
alibi_bias
:
assert
len
(
alibi_bias
)
==
num_seqs
for
i
in
range
(
num_seqs
):
start_idx
=
cu_seq_lens
[
i
]
end_idx
=
cu_seq_lens
[
i
+
1
]
seq_len
=
end_idx
-
start_idx
# Create attention mask.
attn_mask
=
torch
.
triu
(
torch
.
ones
(
seq_len
,
seq_len
,
dtype
=
dtype
),
diagonal
=
1
)
attn_mask
=
attn_mask
*
torch
.
finfo
(
dtype
).
min
attn_mask
=
attn_mask
.
to
(
dtype
=
dtype
)
# Create attention mask. ALiBi already includes a tril causal mask.
if
alibi_bias
:
attn_mask
=
alibi_bias
[
i
]
else
:
attn_mask
=
torch
.
triu
(
torch
.
ones
(
seq_len
,
seq_len
,
dtype
=
dtype
),
diagonal
=
1
)
attn_mask
=
attn_mask
*
torch
.
finfo
(
dtype
).
min
attn_mask
=
attn_mask
.
to
(
dtype
=
dtype
)
ref_output
=
ref_masked_attention
(
query
[
start_idx
:
end_idx
],
...
...
@@ -372,7 +380,6 @@ def ref_multi_query_kv_attention(
return
torch
.
cat
(
ref_outputs
,
dim
=
0
)
# TODO(woosuk): Add tests for USE_ALIBI=True.
@
pytest
.
mark
.
parametrize
(
"num_seqs"
,
NUM_PREFILL_SEQS
)
@
pytest
.
mark
.
parametrize
(
"num_heads"
,
NUM_HEADS
)
@
pytest
.
mark
.
parametrize
(
"head_size"
,
HEAD_SIZES
)
...
...
@@ -389,6 +396,7 @@ def test_multi_query_kv_attention(
dtype
:
torch
.
dtype
,
seed
:
int
,
device
:
str
,
use_alibi
:
bool
=
False
,
)
->
None
:
current_platform
.
seed_everything
(
seed
)
torch
.
set_default_device
(
device
)
...
...
@@ -414,16 +422,40 @@ def test_multi_query_kv_attention(
# Handle MQA and GQA
key
=
torch
.
repeat_interleave
(
key
,
num_queries_per_kv
,
dim
=
1
)
value
=
torch
.
repeat_interleave
(
value
,
num_queries_per_kv
,
dim
=
1
)
attn_bias
=
BlockDiagonalCausalMask
.
from_seqlens
(
seq_lens
)
output
=
xops
.
memory_efficient_attention_forward
(
query
.
unsqueeze
(
0
),
key
.
unsqueeze
(
0
),
value
.
unsqueeze
(
0
),
attn_bias
=
attn_bias
,
p
=
0.0
,
scale
=
scale
,
)
output
=
output
.
squeeze
(
0
)
alibi_bias
=
None
if
use_alibi
:
alibi_slopes
=
torch
.
randn
(
num_query_heads
,
dtype
=
torch
.
float
)
attn_bias
=
_make_alibi_bias
(
alibi_slopes
,
num_kv_heads
,
dtype
,
seq_lens
)
output
=
torch
.
empty_like
(
query
)
start
=
0
# Dynamic sequence length not supported with custom attn_bias.
for
i
,
seq_len
in
enumerate
(
seq_lens
):
end
=
start
+
seq_len
out
=
xops
.
memory_efficient_attention_forward
(
query
[
None
,
start
:
end
],
key
[
None
,
start
:
end
],
value
[
None
,
start
:
end
],
attn_bias
=
attn_bias
[
i
],
p
=
0.0
,
scale
=
scale
)
output
[
start
:
end
].
copy_
(
out
.
view_as
(
query
[
start
:
end
]))
start
+=
seq_len
# xformers.AttentionBias to Tensor for use in reference impl.
alibi_bias
=
[
b
.
materialize
(
b
.
shape
,
device
=
device
).
squeeze
()
for
b
in
attn_bias
]
else
:
attn_bias
=
BlockDiagonalCausalMask
.
from_seqlens
(
seq_lens
)
output
=
xops
.
memory_efficient_attention_forward
(
query
.
unsqueeze
(
0
),
key
.
unsqueeze
(
0
),
value
.
unsqueeze
(
0
),
attn_bias
=
attn_bias
,
p
=
0.0
,
scale
=
scale
,
)
output
=
output
.
squeeze
(
0
)
cu_seq_lens
=
[
0
]
for
seq_len
in
seq_lens
:
...
...
@@ -434,8 +466,37 @@ def test_multi_query_kv_attention(
key
,
value
,
scale
,
alibi_bias
,
dtype
,
)
atol
=
get_default_atol
(
output
)
if
current_platform
.
is_rocm
()
else
1e-3
rtol
=
get_default_rtol
(
output
)
if
current_platform
.
is_rocm
()
else
1e-5
torch
.
testing
.
assert_close
(
output
,
ref_output
,
atol
=
atol
,
rtol
=
rtol
)
\ No newline at end of file
torch
.
testing
.
assert_close
(
output
,
ref_output
,
atol
=
atol
,
rtol
=
rtol
)
@
pytest
.
mark
.
parametrize
(
"num_seqs"
,
NUM_PREFILL_SEQS
)
@
pytest
.
mark
.
parametrize
(
"num_heads"
,
NUM_HEADS
)
@
pytest
.
mark
.
parametrize
(
"head_size"
,
[
64
])
@
pytest
.
mark
.
parametrize
(
"dtype"
,
DTYPES
)
@
pytest
.
mark
.
parametrize
(
"seed"
,
SEEDS
)
@
pytest
.
mark
.
parametrize
(
"device"
,
CUDA_DEVICES
)
@
pytest
.
mark
.
skipif
(
current_platform
.
is_rocm
(),
reason
=
"Xformers backend is not supported on ROCm."
)
@
torch
.
inference_mode
()
def
test_multi_query_kv_attention_with_alibi
(
num_seqs
:
int
,
num_heads
:
tuple
[
int
,
int
],
head_size
:
int
,
dtype
:
torch
.
dtype
,
seed
:
int
,
device
:
str
,
)
->
None
:
return
test_multi_query_kv_attention
(
num_seqs
,
num_heads
,
head_size
,
dtype
,
seed
,
device
,
use_alibi
=
True
,
)
tests/kernels/test_prefix_prefill.py
View file @
5ee10e99
...
...
@@ -439,14 +439,16 @@ def test_contexted_kv_attention_alibi(
# heads.
#
# see also: vllm/model_executor/layers/attention.py
query
=
query
.
view
(
query
.
shape
[
0
],
num_kv_heads
,
num_queries_per_kv
,
query
.
shape
[
-
1
])
key
=
key
[:,
:,
None
,
:].
expand
(
key
.
shape
[
0
],
num_kv_heads
,
num_queries_per_kv
,
key
.
shape
[
-
1
])
value
=
value
[:,
:,
None
,
:].
expand
(
value
.
shape
[
0
],
num_kv_heads
,
num_queries_per_kv
,
value
.
shape
[
-
1
])
# [seq, num_kv_heads, num_queries_per_kv, dk]=>
# [seq, num_kv_heads*num_queries_per_kv, dk] to comply with rest of the
# codebase. We save some time reshaping alibi matrix at runtime.
key
=
key
.
reshape
(
key
.
shape
[
0
],
-
1
,
key
.
shape
[
-
1
])
value
=
value
.
reshape
(
value
.
shape
[
0
],
-
1
,
value
.
shape
[
-
1
])
query
=
query
.
unsqueeze
(
0
)
key
=
key
.
unsqueeze
(
0
)
value
=
value
.
unsqueeze
(
0
)
...
...
vllm/attention/backends/xformers.py
View file @
5ee10e99
...
...
@@ -788,8 +788,6 @@ def _make_alibi_bias(
dtype
=
dtype
,
)[:,
:,
:,
:
seq_len
].
copy_
(
bias
)
bias
.
mul_
(
alibi_slopes
[:,
None
,
None
])
if
num_heads
!=
num_kv_heads
:
bias
=
bias
.
unflatten
(
1
,
(
num_kv_heads
,
num_heads
//
num_kv_heads
))
attn_biases
.
append
(
LowerTriangularMaskWithTensorBias
(
bias
))
return
attn_biases
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment