Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
32881f3f
Unverified
Commit
32881f3f
authored
May 02, 2024
by
Michał Moskal
Committed by
GitHub
May 02, 2024
Browse files
[kernel] fix sliding window in prefix prefill Triton kernel (#4405)
Co-authored-by:
SangBin Cho
<
rkooo567@gmail.com
>
parent
5b8a7c1c
Changes
6
Hide whitespace changes
Inline
Side-by-side
Showing
6 changed files
with
91 additions
and
23 deletions
+91
-23
tests/kernels/test_prefix_prefill.py
tests/kernels/test_prefix_prefill.py
+30
-4
vllm/attention/backends/flash_attn.py
vllm/attention/backends/flash_attn.py
+1
-0
vllm/attention/backends/rocm_flash_attn.py
vllm/attention/backends/rocm_flash_attn.py
+1
-0
vllm/attention/backends/xformers.py
vllm/attention/backends/xformers.py
+1
-0
vllm/attention/ops/paged_attn.py
vllm/attention/ops/paged_attn.py
+2
-0
vllm/attention/ops/prefix_prefill.py
vllm/attention/ops/prefix_prefill.py
+56
-19
No files found.
tests/kernels/test_prefix_prefill.py
View file @
32881f3f
...
...
@@ -15,6 +15,7 @@ DTYPES = [torch.float16]
CUDA_DEVICES
=
[
f
"cuda:
{
i
}
"
for
i
in
range
(
1
if
torch
.
cuda
.
device_count
()
==
1
else
2
)
]
SLIDING_WINDOW
=
[
0
,
16
,
64
,
128
,
256
,
512
,
2048
]
@
pytest
.
mark
.
parametrize
(
"num_heads"
,
NUM_HEADS
)
...
...
@@ -22,11 +23,13 @@ CUDA_DEVICES = [
@
pytest
.
mark
.
parametrize
(
"head_size"
,
HEAD_SIZES
)
@
pytest
.
mark
.
parametrize
(
"dtype"
,
DTYPES
)
@
pytest
.
mark
.
parametrize
(
"device"
,
CUDA_DEVICES
)
@
pytest
.
mark
.
parametrize
(
"sliding_window"
,
SLIDING_WINDOW
)
@
torch
.
inference_mode
()
def
test_contexted_kv_attention
(
num_heads
:
int
,
num_queries_per_kv
:
int
,
head_size
:
int
,
sliding_window
:
int
,
dtype
:
torch
.
dtype
,
device
:
str
,
)
->
None
:
...
...
@@ -123,12 +126,32 @@ def test_contexted_kv_attention(
# Warm up the Triton kernel by calling it once before actually measuring
# generation time
context_attention_fwd
(
query
,
k
,
v
,
output
,
k_cache
,
v_cache
,
block_table
,
b_start_loc
,
b_seq_len
,
b_ctx_len
,
max_input_len
)
context_attention_fwd
(
query
,
k
,
v
,
output
,
k_cache
,
v_cache
,
block_table
,
b_start_loc
,
b_seq_len
,
b_ctx_len
,
max_input_len
,
sliding_window
=
sliding_window
)
torch
.
cuda
.
synchronize
()
start_time
=
time
.
time
()
context_attention_fwd
(
query
,
k
,
v
,
output
,
k_cache
,
v_cache
,
block_table
,
b_start_loc
,
b_seq_len
,
b_ctx_len
,
max_input_len
)
context_attention_fwd
(
query
,
k
,
v
,
output
,
k_cache
,
v_cache
,
block_table
,
b_start_loc
,
b_seq_len
,
b_ctx_len
,
max_input_len
,
sliding_window
=
sliding_window
)
torch
.
cuda
.
synchronize
()
end_time
=
time
.
time
()
print
(
f
"triton Time:
{
(
end_time
-
start_time
)
*
1000
:.
2
f
}
ms"
)
...
...
@@ -156,6 +179,9 @@ def test_contexted_kv_attention(
attn_bias
=
BlockDiagonalCausalFromBottomRightMask
.
from_seqlens
(
subquery_lens
,
seq_lens
)
if
sliding_window
>
0
:
attn_bias
=
attn_bias
.
make_local_attention_from_bottomright
(
sliding_window
)
output_ref
=
xops
.
memory_efficient_attention_forward
(
query
,
key
,
...
...
vllm/attention/backends/flash_attn.py
View file @
32881f3f
...
...
@@ -249,6 +249,7 @@ class FlashAttentionImpl(AttentionImpl):
prefill_meta
.
context_lens
,
prefill_meta
.
max_subquery_len
,
self
.
alibi_slopes
,
self
.
sliding_window
[
0
],
)
if
decode_meta
:
=
attn_metadata
.
decode_metadata
:
# Decoding run.
...
...
vllm/attention/backends/rocm_flash_attn.py
View file @
32881f3f
...
...
@@ -307,6 +307,7 @@ class ROCmFlashAttentionImpl(AttentionImpl):
prefill_meta
.
context_lens
,
prefill_meta
.
max_subquery_len
,
self
.
alibi_slopes
,
self
.
sliding_window
[
0
],
)
if
decode_meta
:
=
attn_metadata
.
decode_metadata
:
...
...
vllm/attention/backends/xformers.py
View file @
32881f3f
...
...
@@ -246,6 +246,7 @@ class XFormersImpl(AttentionImpl):
prefill_meta
.
context_lens
,
prefill_meta
.
max_subquery_len
,
self
.
alibi_slopes
,
self
.
sliding_window
,
)
assert
output
[:
num_prefill_tokens
].
shape
==
out
.
shape
output
[:
num_prefill_tokens
]
=
out
...
...
vllm/attention/ops/paged_attn.py
View file @
32881f3f
...
...
@@ -172,6 +172,7 @@ class PagedAttention:
context_lens
:
torch
.
Tensor
,
max_subquery_len
:
int
,
alibi_slopes
:
Optional
[
torch
.
Tensor
],
sliding_window
:
Optional
[
int
],
)
->
torch
.
Tensor
:
output
=
torch
.
empty_like
(
query
)
context_attention_fwd
(
...
...
@@ -188,6 +189,7 @@ class PagedAttention:
context_lens
,
max_subquery_len
,
alibi_slopes
,
sliding_window
,
)
return
output
...
...
vllm/attention/ops/prefix_prefill.py
View file @
32881f3f
...
...
@@ -50,6 +50,7 @@ if triton.__version__ >= "2.1.0":
BLOCK_DMODEL
:
tl
.
constexpr
,
# head size
BLOCK_DMODEL_PADDED
:
tl
.
constexpr
,
# head size padded to a power of 2
BLOCK_N
:
tl
.
constexpr
,
SLIDING_WINDOW
:
tl
.
constexpr
,
):
cur_batch
=
tl
.
program_id
(
0
)
cur_head
=
tl
.
program_id
(
1
)
...
...
@@ -62,42 +63,53 @@ if triton.__version__ >= "2.1.0":
cur_batch_in_all_start_index
=
tl
.
load
(
B_Start_Loc
+
cur_batch
)
cur_batch_query_len
=
cur_batch_seq_len
-
cur_batch_ctx_len
# start position inside of the query
# generally, N goes over kv, while M goes over query_len
block_start_loc
=
BLOCK_M
*
start_m
# initialize offsets
# [N]; starts at 0
offs_n
=
tl
.
arange
(
0
,
BLOCK_N
)
# [D]; starts at 0
offs_d
=
tl
.
arange
(
0
,
BLOCK_DMODEL_PADDED
)
# [M]; starts at current position in query
offs_m
=
start_m
*
BLOCK_M
+
tl
.
arange
(
0
,
BLOCK_M
)
# [M,D]
off_q
=
(
(
cur_batch_in_all_start_index
+
offs_m
[:,
None
])
*
stride_qbs
+
cur_head
*
stride_qh
+
offs_d
[
None
,
:]
*
stride_qd
)
dim_mask
=
tl
.
where
(
tl
.
arange
(
0
,
BLOCK_DMODEL_PADDED
)
<
BLOCK_DMODEL
,
1
,
0
).
to
(
tl
.
int1
)
tl
.
arange
(
0
,
BLOCK_DMODEL_PADDED
)
<
BLOCK_DMODEL
,
1
,
0
).
to
(
tl
.
int1
)
# [D]
q
=
tl
.
load
(
Q
+
off_q
,
mask
=
dim_mask
[
None
,
:]
&
(
offs_m
[:,
None
]
<
cur_batch_query_len
),
other
=
0.0
)
other
=
0.0
)
# [M,D]
# # initialize pointer to m and l
m_i
=
tl
.
zeros
([
BLOCK_M
],
dtype
=
tl
.
float32
)
-
float
(
"inf"
)
l_i
=
tl
.
zeros
([
BLOCK_M
],
dtype
=
tl
.
float32
)
acc
=
tl
.
zeros
([
BLOCK_M
,
BLOCK_DMODEL_PADDED
],
dtype
=
tl
.
float32
)
# initialize pointer to m and l
m_i
=
tl
.
zeros
([
BLOCK_M
],
dtype
=
tl
.
float32
)
-
float
(
"inf"
)
# [M]
l_i
=
tl
.
zeros
([
BLOCK_M
],
dtype
=
tl
.
float32
)
# [M]
acc
=
tl
.
zeros
([
BLOCK_M
,
BLOCK_DMODEL_PADDED
],
dtype
=
tl
.
float32
)
# [M,D]
# compute query against context (no causal mask here)
for
start_n
in
range
(
0
,
cur_batch_ctx_len
,
BLOCK_N
):
start_n
=
tl
.
multiple_of
(
start_n
,
BLOCK_N
)
# -- compute qk ----
bn
=
tl
.
load
(
B_Loc
+
cur_batch
*
stride_b_loc_b
+
((
start_n
+
offs_n
)
//
block_size
)
*
stride_b_loc_s
,
mask
=
(
start_n
+
offs_n
)
<
cur_batch_ctx_len
,
other
=
0
)
other
=
0
)
# [N]
# [D,N]
off_k
=
(
bn
[
None
,
:]
*
stride_k_cache_bs
+
cur_kv_head
*
stride_k_cache_h
+
(
offs_d
[:,
None
]
//
x
)
*
stride_k_cache_d
+
((
start_n
+
offs_n
[
None
,
:])
%
block_size
)
*
stride_k_cache_bl
+
(
offs_d
[:,
None
]
%
x
)
*
stride_k_cache_x
)
# [N,D]
off_v
=
(
bn
[:,
None
]
*
stride_v_cache_bs
+
cur_kv_head
*
stride_v_cache_h
+
...
...
@@ -106,23 +118,39 @@ if triton.__version__ >= "2.1.0":
k
=
tl
.
load
(
K_cache
+
off_k
,
mask
=
dim_mask
[:,
None
]
&
((
start_n
+
offs_n
[
None
,
:])
<
cur_batch_ctx_len
),
other
=
0.0
)
other
=
0.0
)
# [D,N]
qk
=
tl
.
zeros
([
BLOCK_M
,
BLOCK_N
],
dtype
=
tl
.
float32
)
qk
=
tl
.
zeros
([
BLOCK_M
,
BLOCK_N
],
dtype
=
tl
.
float32
)
# [M,N]
qk
+=
tl
.
dot
(
q
,
k
)
qk
=
tl
.
where
((
start_n
+
offs_n
[
None
,
:])
<
cur_batch_ctx_len
,
qk
,
float
(
"-inf"
))
qk
*=
sm_scale
if
SLIDING_WINDOW
>
0
:
# (cur_batch_ctx_len + offs_m[:, None]) are the positions of
# Q entries in sequence
# (start_n + offs_n[None, :]) are the positions of
# KV entries in sequence
# So the condition makes sure each entry in Q only attends
# to KV entries not more than SLIDING_WINDOW away.
#
# We can't use -inf here, because the
# sliding window may lead to the entire row being masked.
# This then makes m_ij contain -inf, which causes NaNs in
# exp().
qk
=
tl
.
where
((
cur_batch_ctx_len
+
offs_m
[:,
None
])
-
(
start_n
+
offs_n
[
None
,
:])
<
SLIDING_WINDOW
,
qk
,
-
10000
)
# -- compute m_ij, p, l_ij
m_ij
=
tl
.
max
(
qk
,
1
)
p
=
tl
.
exp
(
qk
-
m_ij
[:,
None
])
l_ij
=
tl
.
sum
(
p
,
1
)
m_ij
=
tl
.
max
(
qk
,
1
)
# [M]
p
=
tl
.
exp
(
qk
-
m_ij
[:,
None
])
# [M,N]
l_ij
=
tl
.
sum
(
p
,
1
)
# [M]
# -- update m_i and l_i
m_i_new
=
tl
.
maximum
(
m_i
,
m_ij
)
alpha
=
tl
.
exp
(
m_i
-
m_i_new
)
beta
=
tl
.
exp
(
m_ij
-
m_i_new
)
l_i_new
=
alpha
*
l_i
+
beta
*
l_ij
m_i_new
=
tl
.
maximum
(
m_i
,
m_ij
)
# [M]
alpha
=
tl
.
exp
(
m_i
-
m_i_new
)
# [M]
beta
=
tl
.
exp
(
m_ij
-
m_i_new
)
# [M]
l_i_new
=
alpha
*
l_i
+
beta
*
l_ij
# [M]
# -- update output accumulator --
# scale p
p_scale
=
beta
/
l_i_new
...
...
@@ -134,7 +162,7 @@ if triton.__version__ >= "2.1.0":
v
=
tl
.
load
(
V_cache
+
off_v
,
mask
=
dim_mask
[
None
,
:]
&
((
start_n
+
offs_n
[:,
None
])
<
cur_batch_ctx_len
),
other
=
0.0
)
other
=
0.0
)
# [N,D]
p
=
p
.
to
(
v
.
dtype
)
acc
+=
tl
.
dot
(
p
,
v
)
...
...
@@ -149,8 +177,10 @@ if triton.__version__ >= "2.1.0":
k_ptrs
=
K
+
off_k
v_ptrs
=
V
+
off_v
# block_mask is 0 when we're already past the current query length
block_mask
=
tl
.
where
(
block_start_loc
<
cur_batch_query_len
,
1
,
0
)
# compute query against itself (with causal mask)
for
start_n
in
range
(
0
,
block_mask
*
(
start_m
+
1
)
*
BLOCK_M
,
BLOCK_N
):
start_n
=
tl
.
multiple_of
(
start_n
,
BLOCK_N
)
# -- compute qk ----
...
...
@@ -163,8 +193,13 @@ if triton.__version__ >= "2.1.0":
qk
=
tl
.
zeros
([
BLOCK_M
,
BLOCK_N
],
dtype
=
tl
.
float32
)
qk
+=
tl
.
dot
(
q
,
k
)
qk
*=
sm_scale
# apply causal mask
qk
=
tl
.
where
(
offs_m
[:,
None
]
>=
(
start_n
+
offs_n
[
None
,
:]),
qk
,
float
(
"-inf"
))
if
SLIDING_WINDOW
>
0
:
qk
=
tl
.
where
(
offs_m
[:,
None
]
-
(
start_n
+
offs_n
[
None
,
:])
<
SLIDING_WINDOW
,
qk
,
-
10000
)
# -- compute m_ij, p, l_ij
m_ij
=
tl
.
max
(
qk
,
1
)
...
...
@@ -636,7 +671,8 @@ if triton.__version__ >= "2.1.0":
b_seq_len
,
b_ctx_len
,
max_input_len
,
alibi_slopes
=
None
):
alibi_slopes
=
None
,
sliding_window
=
None
):
cap
=
torch
.
cuda
.
get_device_capability
()
BLOCK
=
128
if
cap
[
0
]
>=
8
else
64
...
...
@@ -644,7 +680,7 @@ if triton.__version__ >= "2.1.0":
Lq
,
Lk
,
Lv
=
q
.
shape
[
-
1
],
k
.
shape
[
-
1
],
v
.
shape
[
-
1
]
assert
Lq
==
Lk
and
Lk
==
Lv
# round up Lk to a power of 2 - this is required for Triton block size
Lk_padded
=
2
**
((
Lk
-
1
).
bit_length
()
)
Lk_padded
=
triton
.
next_power_of_2
(
Lk
)
sm_scale
=
1.0
/
(
Lq
**
0.5
)
batch
,
head
=
b_seq_len
.
shape
[
0
],
q
.
shape
[
1
]
...
...
@@ -749,6 +785,7 @@ if triton.__version__ >= "2.1.0":
BLOCK_DMODEL
=
Lk
,
BLOCK_DMODEL_PADDED
=
Lk_padded
,
BLOCK_N
=
BLOCK
,
SLIDING_WINDOW
=
sliding_window
if
sliding_window
is
not
None
else
0
,
num_warps
=
num_warps
,
num_stages
=
1
,
)
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment