Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
e8cc7967
Unverified
Commit e8cc7967 authored Apr 18, 2024 by Michał Moskal
Committed by GitHub Apr 18, 2024
Browse files
[Bugfix][Kernel] allow non-power-of-two head sizes in prefix prefill (#4128)
parent
53b018ed
Changes
2
Hide whitespace changes
Inline
Side-by-side
Showing 2 changed files with 28 additions and 18 deletions
+28
-18
tests/kernels/test_prefix_prefill.py
tests/kernels/test_prefix_prefill.py
+1
-1
vllm/attention/ops/prefix_prefill.py
vllm/attention/ops/prefix_prefill.py
+27
-17
No files found.
tests/kernels/test_prefix_prefill.py
View file @
e8cc7967
...
@@ -10,7 +10,7 @@ from vllm.attention.ops.prefix_prefill import context_attention_fwd
...
@@ -10,7 +10,7 @@ from vllm.attention.ops.prefix_prefill import context_attention_fwd
NUM_HEADS
=
[
64
]
NUM_HEADS
=
[
64
]
NUM_QUERIES_PER_KV
=
[
1
,
8
,
64
]
NUM_QUERIES_PER_KV
=
[
1
,
8
,
64
]
-HEAD_SIZES = [128]
+HEAD_SIZES = [128, 96]
DTYPES
=
[
torch
.
float16
]
DTYPES
=
[
torch
.
float16
]
CUDA_DEVICES
=
[
CUDA_DEVICES
=
[
f
"cuda:
{
i
}
"
for
i
in
range
(
1
if
torch
.
cuda
.
device_count
()
==
1
else
2
)
f
"cuda:
{
i
}
"
for
i
in
range
(
1
if
torch
.
cuda
.
device_count
()
==
1
else
2
)
...
...
vllm/attention/ops/prefix_prefill.py
View file @
e8cc7967
...
@@ -47,7 +47,8 @@ if triton.__version__ >= "2.1.0":
...
@@ -47,7 +47,8 @@ if triton.__version__ >= "2.1.0":
stride_v_cache_bl
,
stride_v_cache_bl
,
num_queries_per_kv
:
int
,
num_queries_per_kv
:
int
,
BLOCK_M
:
tl
.
constexpr
,
BLOCK_M
:
tl
.
constexpr
,
-BLOCK_DMODEL: tl.constexpr,
+BLOCK_DMODEL: tl.constexpr,  # head size
+BLOCK_DMODEL_PADDED: tl.constexpr,  # head size padded to a power of 2
BLOCK_N
:
tl
.
constexpr
,
BLOCK_N
:
tl
.
constexpr
,
):
):
cur_batch
=
tl
.
program_id
(
0
)
cur_batch
=
tl
.
program_id
(
0
)
...
@@ -59,26 +60,30 @@ if triton.__version__ >= "2.1.0":
...
@@ -59,26 +60,30 @@ if triton.__version__ >= "2.1.0":
cur_batch_ctx_len
=
tl
.
load
(
B_Ctxlen
+
cur_batch
)
cur_batch_ctx_len
=
tl
.
load
(
B_Ctxlen
+
cur_batch
)
cur_batch_seq_len
=
tl
.
load
(
B_Seqlen
+
cur_batch
)
cur_batch_seq_len
=
tl
.
load
(
B_Seqlen
+
cur_batch
)
cur_batch_in_all_start_index
=
tl
.
load
(
B_Start_Loc
+
cur_batch
)
cur_batch_in_all_start_index
=
tl
.
load
(
B_Start_Loc
+
cur_batch
)
+cur_batch_query_len = cur_batch_seq_len - cur_batch_ctx_len
block_start_loc
=
BLOCK_M
*
start_m
block_start_loc
=
BLOCK_M
*
start_m
# initialize offsets
# initialize offsets
offs_n
=
tl
.
arange
(
0
,
BLOCK_N
)
offs_n
=
tl
.
arange
(
0
,
BLOCK_N
)
-offs_d = tl.arange(0, BLOCK_DMODEL)
+offs_d = tl.arange(0, BLOCK_DMODEL_PADDED)
offs_m
=
start_m
*
BLOCK_M
+
tl
.
arange
(
0
,
BLOCK_M
)
offs_m
=
start_m
*
BLOCK_M
+
tl
.
arange
(
0
,
BLOCK_M
)
off_q
=
(
off_q
=
(
(
cur_batch_in_all_start_index
+
offs_m
[:,
None
])
*
stride_qbs
+
(
cur_batch_in_all_start_index
+
offs_m
[:,
None
])
*
stride_qbs
+
cur_head
*
stride_qh
+
offs_d
[
None
,
:]
*
stride_qd
)
cur_head
*
stride_qh
+
offs_d
[
None
,
:]
*
stride_qd
)
-q = tl.load(Q + off_q,
-            mask=offs_m[:, None] < cur_batch_seq_len - cur_batch_ctx_len,
-            other=0.0)
+dim_mask = tl.where(
+    tl.arange(0, BLOCK_DMODEL_PADDED) < BLOCK_DMODEL, 1, 0).to(tl.int1)
+q = tl.load(Q + off_q,
+            mask=dim_mask[None, :] & (offs_m[:, None] < cur_batch_query_len),
+            other=0.0)
# # initialize pointer to m and l
# # initialize pointer to m and l
m_i
=
tl
.
zeros
([
BLOCK_M
],
dtype
=
tl
.
float32
)
-
float
(
"inf"
)
m_i
=
tl
.
zeros
([
BLOCK_M
],
dtype
=
tl
.
float32
)
-
float
(
"inf"
)
l_i
=
tl
.
zeros
([
BLOCK_M
],
dtype
=
tl
.
float32
)
l_i
=
tl
.
zeros
([
BLOCK_M
],
dtype
=
tl
.
float32
)
-acc = tl.zeros([BLOCK_M, BLOCK_DMODEL], dtype=tl.float32)
+acc = tl.zeros([BLOCK_M, BLOCK_DMODEL_PADDED], dtype=tl.float32)
for
start_n
in
range
(
0
,
cur_batch_ctx_len
,
BLOCK_N
):
for
start_n
in
range
(
0
,
cur_batch_ctx_len
,
BLOCK_N
):
start_n
=
tl
.
multiple_of
(
start_n
,
BLOCK_N
)
start_n
=
tl
.
multiple_of
(
start_n
,
BLOCK_N
)
...
@@ -99,7 +104,8 @@ if triton.__version__ >= "2.1.0":
...
@@ -99,7 +104,8 @@ if triton.__version__ >= "2.1.0":
offs_d
[
None
,
:]
*
stride_v_cache_d
+
offs_d
[
None
,
:]
*
stride_v_cache_d
+
(
start_n
+
offs_n
[:,
None
])
%
block_size
*
stride_v_cache_bl
)
(
start_n
+
offs_n
[:,
None
])
%
block_size
*
stride_v_cache_bl
)
-k = tl.load(K_cache + off_k,
-            mask=(start_n + offs_n[None, :]) < cur_batch_ctx_len,
-            other=0.0)
+k = tl.load(K_cache + off_k,
+            mask=dim_mask[:, None] &
+            ((start_n + offs_n[None, :]) < cur_batch_ctx_len),
+            other=0.0)
qk
=
tl
.
zeros
([
BLOCK_M
,
BLOCK_N
],
dtype
=
tl
.
float32
)
qk
=
tl
.
zeros
([
BLOCK_M
,
BLOCK_N
],
dtype
=
tl
.
float32
)
...
@@ -126,7 +132,8 @@ if triton.__version__ >= "2.1.0":
...
@@ -126,7 +132,8 @@ if triton.__version__ >= "2.1.0":
acc
=
acc
*
acc_scale
[:,
None
]
acc
=
acc
*
acc_scale
[:,
None
]
# update acc
# update acc
-v = tl.load(V_cache + off_v,
-            mask=(start_n + offs_n[:, None]) < cur_batch_ctx_len,
-            other=0.0)
+v = tl.load(V_cache + off_v,
+            mask=dim_mask[None, :] &
+            ((start_n + offs_n[:, None]) < cur_batch_ctx_len),
+            other=0.0)
p
=
p
.
to
(
v
.
dtype
)
p
=
p
.
to
(
v
.
dtype
)
...
@@ -142,16 +149,15 @@ if triton.__version__ >= "2.1.0":
...
@@ -142,16 +149,15 @@ if triton.__version__ >= "2.1.0":
k_ptrs
=
K
+
off_k
k_ptrs
=
K
+
off_k
v_ptrs
=
V
+
off_v
v_ptrs
=
V
+
off_v
-block_mask = tl.where(
-    block_start_loc < cur_batch_seq_len - cur_batch_ctx_len, 1, 0)
+block_mask = tl.where(block_start_loc < cur_batch_query_len, 1, 0)
for
start_n
in
range
(
0
,
block_mask
*
(
start_m
+
1
)
*
BLOCK_M
,
BLOCK_N
):
for
start_n
in
range
(
0
,
block_mask
*
(
start_m
+
1
)
*
BLOCK_M
,
BLOCK_N
):
start_n
=
tl
.
multiple_of
(
start_n
,
BLOCK_N
)
start_n
=
tl
.
multiple_of
(
start_n
,
BLOCK_N
)
# -- compute qk ----
# -- compute qk ----
-k = tl.load(k_ptrs +
-            (cur_batch_in_all_start_index + start_n) * stride_kbs,
-            mask=(start_n + offs_n[None, :]) <
-            cur_batch_seq_len - cur_batch_ctx_len,
-            other=0.0)
+k = tl.load(k_ptrs +
+            (cur_batch_in_all_start_index + start_n) * stride_kbs,
+            mask=dim_mask[:, None] &
+            ((start_n + offs_n[None, :]) < cur_batch_query_len),
+            other=0.0)
qk
=
tl
.
zeros
([
BLOCK_M
,
BLOCK_N
],
dtype
=
tl
.
float32
)
qk
=
tl
.
zeros
([
BLOCK_M
,
BLOCK_N
],
dtype
=
tl
.
float32
)
...
@@ -179,8 +185,8 @@ if triton.__version__ >= "2.1.0":
...
@@ -179,8 +185,8 @@ if triton.__version__ >= "2.1.0":
# update acc
# update acc
-v = tl.load(v_ptrs +
-            (cur_batch_in_all_start_index + start_n) * stride_vbs,
-            mask=(start_n + offs_n[:, None]) <
-            cur_batch_seq_len - cur_batch_ctx_len,
-            other=0.0)
+v = tl.load(v_ptrs +
+            (cur_batch_in_all_start_index + start_n) * stride_vbs,
+            mask=dim_mask[None, :] &
+            ((start_n + offs_n[:, None]) < cur_batch_query_len),
+            other=0.0)
p
=
p
.
to
(
v
.
dtype
)
p
=
p
.
to
(
v
.
dtype
)
...
@@ -195,7 +201,8 @@ if triton.__version__ >= "2.1.0":
...
@@ -195,7 +201,8 @@ if triton.__version__ >= "2.1.0":
out_ptrs
=
Out
+
off_o
out_ptrs
=
Out
+
off_o
-tl.store(out_ptrs,
-         acc,
-         mask=offs_m[:, None] < cur_batch_seq_len - cur_batch_ctx_len)
+tl.store(out_ptrs,
+         acc,
+         mask=dim_mask[None, :] & (offs_m[:, None] < cur_batch_query_len))
return
return
@
triton
.
jit
@
triton
.
jit
...
@@ -636,7 +643,8 @@ if triton.__version__ >= "2.1.0":
...
@@ -636,7 +643,8 @@ if triton.__version__ >= "2.1.0":
# shape constraints
# shape constraints
Lq
,
Lk
,
Lv
=
q
.
shape
[
-
1
],
k
.
shape
[
-
1
],
v
.
shape
[
-
1
]
Lq
,
Lk
,
Lv
=
q
.
shape
[
-
1
],
k
.
shape
[
-
1
],
v
.
shape
[
-
1
]
assert
Lq
==
Lk
and
Lk
==
Lv
assert
Lq
==
Lk
and
Lk
==
Lv
-assert Lk in {16, 32, 64, 128}
+# round up Lk to a power of 2 - this is required for Triton block size
+Lk_padded = 2**((Lk - 1).bit_length())
sm_scale
=
1.0
/
(
Lq
**
0.5
)
sm_scale
=
1.0
/
(
Lq
**
0.5
)
batch
,
head
=
b_seq_len
.
shape
[
0
],
q
.
shape
[
1
]
batch
,
head
=
b_seq_len
.
shape
[
0
],
q
.
shape
[
1
]
...
@@ -646,6 +654,7 @@ if triton.__version__ >= "2.1.0":
...
@@ -646,6 +654,7 @@ if triton.__version__ >= "2.1.0":
num_warps
=
8
if
Lk
<=
64
else
8
num_warps
=
8
if
Lk
<=
64
else
8
if
alibi_slopes
is
not
None
:
if
alibi_slopes
is
not
None
:
+assert Lk == Lk_padded
_fwd_kernel_alibi
[
grid
](
_fwd_kernel_alibi
[
grid
](
q
,
q
,
k
,
k
,
...
@@ -738,6 +747,7 @@ if triton.__version__ >= "2.1.0":
...
@@ -738,6 +747,7 @@ if triton.__version__ >= "2.1.0":
num_queries_per_kv
=
num_queries_per_kv
,
num_queries_per_kv
=
num_queries_per_kv
,
BLOCK_M
=
BLOCK
,
BLOCK_M
=
BLOCK
,
BLOCK_DMODEL
=
Lk
,
BLOCK_DMODEL
=
Lk
,
+BLOCK_DMODEL_PADDED=Lk_padded,
BLOCK_N
=
BLOCK
,
BLOCK_N
=
BLOCK
,
num_warps
=
num_warps
,
num_warps
=
num_warps
,
num_stages
=
1
,
num_stages
=
1
,
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment