Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
045b5ad2
Commit
045b5ad2
authored
Jul 05, 2024
by
zhuwenwen
Browse files
skip cutlass fwop and xformers cuda backend
parent
bbf9488b
Changes
1
Hide whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
119 additions
and
115 deletions
+119
-115
tests/kernels/test_prefix_prefill.py
tests/kernels/test_prefix_prefill.py
+119
-115
No files found.
tests/kernels/test_prefix_prefill.py
View file @
045b5ad2
...
@@ -9,6 +9,7 @@ from xformers.ops.fmha.attn_bias import BlockDiagonalCausalFromBottomRightMask
...
@@ -9,6 +9,7 @@ from xformers.ops.fmha.attn_bias import BlockDiagonalCausalFromBottomRightMask
from
vllm.attention.backends.xformers
import
_make_alibi_bias
from
vllm.attention.backends.xformers
import
_make_alibi_bias
from
vllm.attention.ops.prefix_prefill
import
context_attention_fwd
from
vllm.attention.ops.prefix_prefill
import
context_attention_fwd
from
vllm.utils
import
is_hip
NUM_HEADS
=
[
64
]
NUM_HEADS
=
[
64
]
NUM_QUERIES_PER_KV
=
[
1
,
8
,
64
]
NUM_QUERIES_PER_KV
=
[
1
,
8
,
64
]
...
@@ -158,57 +159,58 @@ def test_contexted_kv_attention(
...
@@ -158,57 +159,58 @@ def test_contexted_kv_attention(
end_time
=
time
.
time
()
end_time
=
time
.
time
()
print
(
f
"triton Time:
{
(
end_time
-
start_time
)
*
1000
:.
2
f
}
ms"
)
print
(
f
"triton Time:
{
(
end_time
-
start_time
)
*
1000
:.
2
f
}
ms"
)
scale
=
float
(
1.0
/
(
head_size
**
0.5
))
if
not
is_hip
():
scale
=
float
(
1.0
/
(
head_size
**
0.5
))
attn_op
=
xops
.
fmha
.
cutlass
.
FwOp
()
attn_op
=
xops
.
fmha
.
cutlass
.
FwOp
()
if
num_kv_heads
!=
num_heads
:
if
num_kv_heads
!=
num_heads
:
# As of Nov 2023, xformers only supports MHA. For MQA/GQA,
# As of Nov 2023, xformers only supports MHA. For MQA/GQA,
# project the key and value tensors to the desired number of
# project the key and value tensors to the desired number of
# heads.
# heads.
#
#
# see also: vllm/model_executor/layers/attention.py
# see also: vllm/model_executor/layers/attention.py
query
=
query
.
view
(
query
.
shape
[
0
],
num_kv_heads
,
num_queries_per_kv
,
query
=
query
.
view
(
query
.
shape
[
0
],
num_kv_heads
,
num_queries_per_kv
,
query
.
shape
[
-
1
])
query
.
shape
[
-
1
])
key
=
key
[:,
:,
None
,
:].
expand
(
key
.
shape
[
0
],
num_kv_heads
,
key
=
key
[:,
:,
None
,
:].
expand
(
key
.
shape
[
0
],
num_kv_heads
,
num_queries_per_kv
,
key
.
shape
[
-
1
])
num_queries_per_kv
,
key
.
shape
[
-
1
])
value
=
value
[:,
:,
value
=
value
[:,
:,
None
,
:].
expand
(
value
.
shape
[
0
],
num_kv_heads
,
None
,
:].
expand
(
value
.
shape
[
0
],
num_kv_heads
,
num_queries_per_kv
,
value
.
shape
[
-
1
])
num_queries_per_kv
,
value
.
shape
[
-
1
])
query
=
query
.
unsqueeze
(
0
)
query
=
query
.
unsqueeze
(
0
)
key
=
key
.
unsqueeze
(
0
)
key
=
key
.
unsqueeze
(
0
)
value
=
value
.
unsqueeze
(
0
)
value
=
value
.
unsqueeze
(
0
)
attn_bias
=
BlockDiagonalCausalFromBottomRightMask
.
from_seqlens
(
attn_bias
=
BlockDiagonalCausalFromBottomRightMask
.
from_seqlens
(
query_lens
,
seq_lens
)
query_lens
,
seq_lens
)
if
sliding_window
>
0
:
if
sliding_window
>
0
:
attn_bias
=
attn_bias
.
make_local_attention_from_bottomright
(
attn_bias
=
attn_bias
.
make_local_attention_from_bottomright
(
sliding_window
)
sliding_window
)
output_ref
=
xops
.
memory_efficient_attention_forward
(
output_ref
=
xops
.
memory_efficient_attention_forward
(
query
,
query
,
key
,
key
,
value
,
value
,
attn_bias
=
attn_bias
,
attn_bias
=
attn_bias
,
p
=
0.0
,
p
=
0.0
,
scale
=
scale
,
scale
=
scale
,
op
=
attn_op
,
op
=
attn_op
,
)
)
torch
.
cuda
.
synchronize
()
torch
.
cuda
.
synchronize
()
start_time
=
time
.
time
()
start_time
=
time
.
time
()
output_ref
=
xops
.
memory_efficient_attention_forward
(
output_ref
=
xops
.
memory_efficient_attention_forward
(
query
,
query
,
key
,
key
,
value
,
value
,
attn_bias
=
attn_bias
,
attn_bias
=
attn_bias
,
p
=
0.0
,
p
=
0.0
,
scale
=
scale
,
scale
=
scale
,
op
=
attn_op
,
op
=
attn_op
,
)
)
torch
.
cuda
.
synchronize
()
torch
.
cuda
.
synchronize
()
end_time
=
time
.
time
()
end_time
=
time
.
time
()
print
(
f
"xformers Time:
{
(
end_time
-
start_time
)
*
1000
:.
2
f
}
ms"
)
print
(
f
"xformers Time:
{
(
end_time
-
start_time
)
*
1000
:.
2
f
}
ms"
)
output_ref
=
output_ref
.
reshape
(
output
.
shape
)
output_ref
=
output_ref
.
reshape
(
output
.
shape
)
assert
torch
.
allclose
(
output_ref
,
output
,
atol
=
1e-6
,
rtol
=
0
)
assert
torch
.
allclose
(
output_ref
,
output
,
atol
=
1e-6
,
rtol
=
0
)
@
pytest
.
mark
.
parametrize
(
"num_heads"
,
NUM_HEADS
)
@
pytest
.
mark
.
parametrize
(
"num_heads"
,
NUM_HEADS
)
...
@@ -373,78 +375,80 @@ def test_contexted_kv_attention_alibi(
...
@@ -373,78 +375,80 @@ def test_contexted_kv_attention_alibi(
torch
.
cuda
.
synchronize
()
torch
.
cuda
.
synchronize
()
end_time
=
time
.
time
()
end_time
=
time
.
time
()
print
(
f
"triton Time:
{
(
end_time
-
start_time
)
*
1000
:.
2
f
}
ms"
)
print
(
f
"triton Time:
{
(
end_time
-
start_time
)
*
1000
:.
2
f
}
ms"
)
scale
=
float
(
1.0
/
(
head_size
**
0.5
))
if
not
is_hip
():
scale
=
float
(
1.0
/
(
head_size
**
0.5
))
# NOTE(DefTruth): In order to reuse _make_alibi_bias function,
# we have to pad query tensor before MQA/GQA expanding.
if
query
.
shape
[
0
]
!=
key
.
shape
[
0
]:
query_pad
=
torch
.
empty
(
sum
(
seq_lens
),
num_heads
,
head_size
,
dtype
=
dtype
)
query_pad
.
uniform_
(
-
1e-3
,
1e-3
)
seq_start
=
0
query_start
=
0
for
i
,
(
query_len
,
seq_len
)
in
enumerate
(
zip
(
query_lens
,
seq_lens
)):
seq_end
=
seq_start
+
seq_len
query_end
=
query_start
+
query_len
query_pad
[
seq_start
:
seq_end
,
...]
=
torch
.
cat
([
torch
.
zeros
(
seq_len
-
query_len
,
num_heads
,
head_size
,
dtype
=
dtype
),
query
[
query_start
:
query_end
,
...]
],
dim
=
0
)
seq_start
+=
seq_len
query_start
+=
query_len
query
=
query_pad
if
num_kv_heads
!=
num_heads
:
# As of Nov 2023, xformers only supports MHA. For MQA/GQA,
# project the key and value tensors to the desired number of
# heads.
#
# see also: vllm/model_executor/layers/attention.py
query
=
query
.
view
(
query
.
shape
[
0
],
num_kv_heads
,
num_queries_per_kv
,
query
.
shape
[
-
1
])
key
=
key
[:,
:,
None
,
:].
expand
(
key
.
shape
[
0
],
num_kv_heads
,
num_queries_per_kv
,
key
.
shape
[
-
1
])
value
=
value
[:,
:,
None
,
:].
expand
(
value
.
shape
[
0
],
num_kv_heads
,
num_queries_per_kv
,
value
.
shape
[
-
1
])
query
=
query
.
unsqueeze
(
0
)
key
=
key
.
unsqueeze
(
0
)
value
=
value
.
unsqueeze
(
0
)
# NOTE(DefTruth): In order to reuse _make_alibi_bias function,
attn_bias
=
_make_alibi_bias
(
alibi_slopes
,
num_kv_heads
,
dtype
,
seq_lens
)
# we have to pad query tensor before MQA/GQA expanding.
output_ref
=
torch
.
empty_like
(
output
)
if
query
.
shape
[
0
]
!=
key
.
shape
[
0
]:
query_pad
=
torch
.
empty
(
sum
(
seq_lens
),
num_heads
,
head_size
,
dtype
=
dtype
)
query_pad
.
uniform_
(
-
1e-3
,
1e-3
)
seq_start
=
0
seq_start
=
0
query_start
=
0
query_start
=
0
start_time
=
time
.
time
()
# Attention with alibi slopes.
# FIXME(DefTruth): Because xformers does not support dynamic sequence
# lengths with custom attention bias, we process each prompt one by
# one. This is inefficient, especially when we have many short prompts.
# modified from: vllm/attention/backends/xformers.py#L343
for
i
,
(
query_len
,
seq_len
)
in
enumerate
(
zip
(
query_lens
,
seq_lens
)):
for
i
,
(
query_len
,
seq_len
)
in
enumerate
(
zip
(
query_lens
,
seq_lens
)):
seq_end
=
seq_start
+
seq_len
seq_end
=
seq_start
+
seq_len
query_end
=
query_start
+
query_len
query_end
=
query_start
+
query_len
query_pad
[
seq_start
:
seq_end
,
...]
=
torch
.
cat
([
out
=
xops
.
memory_efficient_attention_forward
(
query
[:,
torch
.
zeros
(
seq_start
:
seq_end
],
seq_len
-
query_len
,
num_heads
,
head_size
,
dtype
=
dtype
),
key
[:,
query
[
query_start
:
query_end
,
...]
seq_start
:
seq_end
],
],
value
[:,
dim
=
0
)
seq_start
:
seq_end
],
attn_bias
=
attn_bias
[
i
],
p
=
0.0
,
scale
=
scale
)
out
=
out
.
view_as
(
query
[:,
seq_start
:
seq_end
]).
view
(
seq_len
,
num_heads
,
head_size
)
output_ref
[
query_start
:
query_end
,
...].
copy_
(
out
[
seq_len
-
query_len
:,
...])
seq_start
+=
seq_len
seq_start
+=
seq_len
query_start
+=
query_len
query_start
+=
query_len
query
=
query_pad
torch
.
cuda
.
synchronize
()
end_time
=
time
.
time
()
if
num_kv_heads
!=
num_heads
:
print
(
f
"xformers Time:
{
(
end_time
-
start_time
)
*
1000
:.
2
f
}
ms"
)
# As of Nov 2023, xformers only supports MHA. For MQA/GQA,
assert
torch
.
allclose
(
output_ref
,
output
,
atol
=
1e-6
,
rtol
=
0
)
# project the key and value tensors to the desired number of
# heads.
#
# see also: vllm/model_executor/layers/attention.py
query
=
query
.
view
(
query
.
shape
[
0
],
num_kv_heads
,
num_queries_per_kv
,
query
.
shape
[
-
1
])
key
=
key
[:,
:,
None
,
:].
expand
(
key
.
shape
[
0
],
num_kv_heads
,
num_queries_per_kv
,
key
.
shape
[
-
1
])
value
=
value
[:,
:,
None
,
:].
expand
(
value
.
shape
[
0
],
num_kv_heads
,
num_queries_per_kv
,
value
.
shape
[
-
1
])
query
=
query
.
unsqueeze
(
0
)
key
=
key
.
unsqueeze
(
0
)
value
=
value
.
unsqueeze
(
0
)
attn_bias
=
_make_alibi_bias
(
alibi_slopes
,
num_kv_heads
,
dtype
,
seq_lens
)
output_ref
=
torch
.
empty_like
(
output
)
seq_start
=
0
query_start
=
0
start_time
=
time
.
time
()
# Attention with alibi slopes.
# FIXME(DefTruth): Because xformers does not support dynamic sequence
# lengths with custom attention bias, we process each prompt one by
# one. This is inefficient, especially when we have many short prompts.
# modified from: vllm/attention/backends/xformers.py#L343
for
i
,
(
query_len
,
seq_len
)
in
enumerate
(
zip
(
query_lens
,
seq_lens
)):
seq_end
=
seq_start
+
seq_len
query_end
=
query_start
+
query_len
out
=
xops
.
memory_efficient_attention_forward
(
query
[:,
seq_start
:
seq_end
],
key
[:,
seq_start
:
seq_end
],
value
[:,
seq_start
:
seq_end
],
attn_bias
=
attn_bias
[
i
],
p
=
0.0
,
scale
=
scale
)
out
=
out
.
view_as
(
query
[:,
seq_start
:
seq_end
]).
view
(
seq_len
,
num_heads
,
head_size
)
output_ref
[
query_start
:
query_end
,
...].
copy_
(
out
[
seq_len
-
query_len
:,
...])
seq_start
+=
seq_len
query_start
+=
query_len
torch
.
cuda
.
synchronize
()
end_time
=
time
.
time
()
print
(
f
"xformers Time:
{
(
end_time
-
start_time
)
*
1000
:.
2
f
}
ms"
)
assert
torch
.
allclose
(
output_ref
,
output
,
atol
=
1e-6
,
rtol
=
0
)
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment