xdb4_94051 / vllm, commit 71bcaf99 (unverified)

Enable GQA support in the prefix prefill kernels (#3007)

Authored Feb 27, 2024 by Tao He; committed by GitHub on Feb 27, 2024.
Parent: 8b430d7d

Signed-off-by: Tao He <sighingnow@gmail.com>
Showing 3 changed files with 87 additions and 47 deletions.
tests/kernels/test_prefix_prefill.py                          +42 -19
vllm/model_executor/layers/attention.py                       +18 -16
vllm/model_executor/layers/triton_kernel/prefix_prefill.py    +27 -12
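The change is mechanical but touches every kernel: in grouped-query attention (GQA), several query heads share one KV head, so every place the kernels index K/V by query head must instead divide down to the owning KV head. A minimal PyTorch sketch of that head mapping, for orientation only (illustrative, without the causal mask; gqa_reference is a hypothetical name, not vllm code):

    import torch

    def gqa_reference(query, key, value, scale):
        # query: [T, num_heads, D]; key/value: [T, num_kv_heads, D]
        num_queries_per_kv = query.shape[1] // key.shape[1]
        # Query head h reads KV head h // num_queries_per_kv, so repeating
        # each KV head num_queries_per_kv times reduces GQA to plain MHA.
        key = key.repeat_interleave(num_queries_per_kv, dim=1)
        value = value.repeat_interleave(num_queries_per_kv, dim=1)
        attn = torch.einsum("qhd,khd->hqk", query, key) * scale
        return torch.einsum("hqk,khd->qhd", attn.softmax(dim=-1), value)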
tests/kernels/test_prefix_prefill.py
@@ -8,7 +8,8 @@ from vllm.model_executor.layers.triton_kernel.prefix_prefill import (
 from xformers import ops as xops
 from xformers.ops.fmha.attn_bias import BlockDiagonalCausalFromBottomRightMask

-NUM_HEADS = [12]
+NUM_HEADS = [64]
+NUM_QUERIES_PER_KV = [1, 8, 64]
 HEAD_SIZES = [128]
 DTYPES = [torch.float16]
 CUDA_DEVICES = [
@@ -17,12 +18,14 @@ CUDA_DEVICES = [

 @pytest.mark.parametrize("num_heads", NUM_HEADS)
+@pytest.mark.parametrize("num_queries_per_kv", NUM_QUERIES_PER_KV)
 @pytest.mark.parametrize("head_size", HEAD_SIZES)
 @pytest.mark.parametrize("dtype", DTYPES)
 @pytest.mark.parametrize("device", CUDA_DEVICES)
 @torch.inference_mode()
 def test_contexted_kv_attention(
     num_heads: int,
+    num_queries_per_kv: int,
     head_size: int,
     dtype: torch.dtype,
     device: str,
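With NUM_HEADS fixed at 64, the new num_queries_per_kv axis walks the whole spectrum of head sharing; a quick check of what each value means (illustrative):

    # 1 -> 64 KV heads (plain MHA); 8 -> 8 KV heads (GQA); 64 -> 1 (MQA)
    for num_queries_per_kv in (1, 8, 64):
        assert 64 % num_queries_per_kv == 0
        num_kv_heads = 64 // num_queries_per_kv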
@@ -41,28 +44,29 @@ def test_contexted_kv_attention(
     subquery_lens = [random.randint(16, MAX_SEQ_LEN) for _ in range(BS)]
     ctx_lens = [random.randint(16, MAX_CTX_LEN) for _ in range(BS)]
     seq_lens = [a + b for a, b in zip(subquery_lens, ctx_lens)]
+    num_kv_heads = num_heads // num_queries_per_kv

     num_tokens = sum(subquery_lens)
     query = torch.empty(num_tokens, num_heads, head_size, dtype=dtype)
     query.uniform_(-1e-3, 1e-3)
     output = torch.empty(num_tokens, num_heads, head_size, dtype=dtype)

-    kv = torch.empty(sum(seq_lens), 2, num_heads, head_size, dtype=dtype)
+    kv = torch.empty(sum(seq_lens), 2, num_kv_heads, head_size, dtype=dtype)
     kv.uniform_(-1e-3, 1e-3)
     key, value = kv.unbind(dim=1)

-    k_cache = torch.zeros(cache_size, block_size, num_heads, head_size,
-                          dtype=dtype)
-    v_cache = torch.zeros(cache_size, block_size, num_heads, head_size,
-                          dtype=dtype)
-    k = torch.zeros(sum(subquery_lens), num_heads, head_size, dtype=dtype)
-    v = torch.zeros(sum(subquery_lens), num_heads, head_size, dtype=dtype)
+    k_cache = torch.zeros(cache_size, block_size, num_kv_heads, head_size,
+                          dtype=dtype)
+    v_cache = torch.zeros(cache_size, block_size, num_kv_heads, head_size,
+                          dtype=dtype)
+    k = torch.zeros(sum(subquery_lens), num_kv_heads, head_size, dtype=dtype)
+    v = torch.zeros(sum(subquery_lens), num_kv_heads, head_size, dtype=dtype)
     values = torch.arange(0, cache_size, dtype=torch.long)
     values = values[torch.randperm(cache_size)]
     block_table = values[:BS * max_block_per_request].view(
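Only the query and output keep the full head count; everything derived from K/V now uses num_kv_heads. For one parametrization the shapes work out as follows (hypothetical lengths; names mirror the test):

    import torch

    num_heads, num_queries_per_kv, head_size = 64, 8, 128
    num_kv_heads = num_heads // num_queries_per_kv            # 8
    subquery_lens, ctx_lens = [20, 30], [100, 50]             # hypothetical
    seq_lens = [a + b for a, b in zip(subquery_lens, ctx_lens)]

    query = torch.empty(sum(subquery_lens), num_heads, head_size)
    kv = torch.empty(sum(seq_lens), 2, num_kv_heads, head_size)
    assert query.shape == (50, 64, 128)
    assert kv.shape == (200, 2, 8, 128)   # 8x fewer KV heads than the query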
@@ -93,19 +97,21 @@ def test_contexted_kv_attention(
             end_loc = start_loc + block_size
             start_slot = block_table[i, block_id] * block_size
             end_slot = start_slot + end_loc - start_loc
-            k_cache.view(-1, num_heads, head_size)[start_slot:end_slot].copy_(
+            k_cache.view(-1, num_kv_heads, head_size)[start_slot:end_slot].copy_(
                 key[start_loc:end_loc])
-            v_cache.view(-1, num_heads, head_size)[start_slot:end_slot].copy_(
+            v_cache.view(-1, num_kv_heads, head_size)[start_slot:end_slot].copy_(
                 value[start_loc:end_loc])
             cur_ctx += block_size
             block_id += 1
     # transpose K_cache[num_blocks, block_size, num_kv_heads, head_size]
     # to K_cache[num_blocks, num_kv_heads, head_size/8, block_size, 8]
-    k_cache = k_cache.view(-1, block_size, num_heads, head_size // 8,
+    k_cache = k_cache.view(-1, block_size, num_kv_heads, head_size // 8,
                            8).permute(0, 2, 3, 1, 4).contiguous()
     # transpose V_cache[num_blocks, block_size, num_kv_heads, head_size]
     # to V_cache[num_blocks, num_kv_heads, head_size, block_size]
-    v_cache = v_cache.view(-1, block_size, num_heads,
+    v_cache = v_cache.view(-1, block_size, num_kv_heads,
                            head_size).permute(0, 2, 3, 1).contiguous()
     # Warm up the Triton kernel by calling it once before actually measuring generation time
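The two transposes reproduce vllm's paged-cache layouts with the reduced head count (the hard-coded 8 presumably matches the kernels' x parameter, 16 bytes / 2 bytes per fp16 element). Replayed on dummy tensors to show the resulting shapes (sizes are illustrative):

    import torch

    num_blocks, block_size, num_kv_heads, head_size = 4, 16, 8, 128
    k_cache = torch.zeros(num_blocks, block_size, num_kv_heads, head_size)
    v_cache = torch.zeros(num_blocks, block_size, num_kv_heads, head_size)

    k_cache = k_cache.view(-1, block_size, num_kv_heads, head_size // 8,
                           8).permute(0, 2, 3, 1, 4).contiguous()
    v_cache = v_cache.view(-1, block_size, num_kv_heads,
                           head_size).permute(0, 2, 3, 1).contiguous()

    assert k_cache.shape == (num_blocks, num_kv_heads, head_size // 8,
                             block_size, 8)
    assert v_cache.shape == (num_blocks, num_kv_heads, head_size, block_size)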
@@ -123,12 +129,29 @@ def test_contexted_kv_attention(
     attn_op = xops.fmha.cutlass.FwOp()

+    if num_kv_heads != num_heads:
+        # As of Nov 2023, xformers only supports MHA. For MQA/GQA,
+        # project the key and value tensors to the desired number of
+        # heads.
+        #
+        # see also: vllm/model_executor/layers/attention.py
+        query = query.view(query.shape[0], num_kv_heads, num_queries_per_kv,
+                           query.shape[-1])
+        key = key[:, :, None, :].expand(key.shape[0], num_kv_heads,
+                                        num_queries_per_kv, key.shape[-1])
+        value = value[:, :, None, :].expand(value.shape[0], num_kv_heads,
+                                            num_queries_per_kv,
+                                            value.shape[-1])
+    query = query.unsqueeze(0)
+    key = key.unsqueeze(0)
+    value = value.unsqueeze(0)
+
     attn_bias = BlockDiagonalCausalFromBottomRightMask.from_seqlens(
         subquery_lens, seq_lens)
     output_ref = xops.memory_efficient_attention_forward(
-        query.unsqueeze(0),
-        key.unsqueeze(0),
-        value.unsqueeze(0),
+        query,
+        key,
+        value,
         attn_bias=attn_bias,
         p=0.0,
         scale=scale,
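The projection mirrors the long-standing workaround in vllm/model_executor/layers/attention.py: reshape the query into [tokens, num_kv_heads, num_queries_per_kv, head_size] and broadcast K/V to match, so the MHA-only xformers kernel sees matching head axes. Replayed standalone (illustrative sizes):

    import torch

    T, num_kv_heads, num_queries_per_kv, D = 50, 8, 8, 128
    query = torch.randn(T, num_kv_heads * num_queries_per_kv, D)
    key = torch.randn(T, num_kv_heads, D)

    query = query.view(T, num_kv_heads, num_queries_per_kv, D)
    key = key[:, :, None, :].expand(T, num_kv_heads, num_queries_per_kv, D)
    assert query.shape == key.shape == (T, num_kv_heads, num_queries_per_kv, D)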
@@ -137,9 +160,9 @@ def test_contexted_kv_attention(
     torch.cuda.synchronize()
     start_time = time.time()
     output_ref = xops.memory_efficient_attention_forward(
-        query.unsqueeze(0),
-        key.unsqueeze(0),
-        value.unsqueeze(0),
+        query,
+        key,
+        value,
         attn_bias=attn_bias,
         p=0.0,
         scale=scale,
@@ -148,5 +171,5 @@ def test_contexted_kv_attention(
     torch.cuda.synchronize()
     end_time = time.time()
     print(f"xformers Time: {(end_time - start_time)*1000:.2f} ms")
-    output_ref = output_ref.squeeze(0)
+    output_ref = output_ref.squeeze(0, 2)
     assert torch.allclose(output_ref, output, atol=1e-6, rtol=0)
vllm/model_executor/layers/attention.py
@@ -137,25 +137,27 @@ class PagedAttention(nn.Module):
         )

         if input_metadata.is_prompt:
             # Prompt run.
-            # normal attention
-            if (key_cache is None or value_cache is None
-                    or input_metadata.block_tables.numel() == 0):
-                if self.num_kv_heads != self.num_heads:
-                    # As of Nov 2023, xformers only supports MHA. For MQA/GQA,
-                    # project the key and value tensors to the desired number of
-                    # heads.
-                    # TODO(woosuk): Use MQA/GQA kernels for higher performance.
-                    query = query.view(query.shape[0], self.num_kv_heads,
-                                       self.num_queries_per_kv,
-                                       query.shape[-1])
-                    key = key[:, :, None, :].expand(key.shape[0],
-                                                    self.num_kv_heads,
-                                                    self.num_queries_per_kv,
-                                                    key.shape[-1])
-                    value = value[:, :, None, :].expand(value.shape[0],
-                                                        self.num_kv_heads,
-                                                        self.num_queries_per_kv,
-                                                        value.shape[-1])
+            if self.num_kv_heads != self.num_heads:
+                # As of Nov 2023, xformers only supports MHA. For MQA/GQA,
+                # project the key and value tensors to the desired number of
+                # heads.
+                # TODO(woosuk): Use MQA/GQA kernels for higher performance.
+                query = query.view(query.shape[0], self.num_kv_heads,
+                                   self.num_queries_per_kv, query.shape[-1])
+                key = key[:, :, None, :].expand(key.shape[0], self.num_kv_heads,
+                                                self.num_queries_per_kv,
+                                                key.shape[-1])
+                value = value[:, :, None, :].expand(value.shape[0],
+                                                    self.num_kv_heads,
+                                                    self.num_queries_per_kv,
+                                                    value.shape[-1])
+
+            # normal attention
+            if (key_cache is None or value_cache is None
+                    or input_metadata.block_tables.numel() == 0):
                 # Set attention bias if not provided. This typically happens at
                 # the very attention layer of every iteration.
                 # FIXME(woosuk): This is a hack.
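expand() produces a broadcast view rather than a copy, which is why hoisting this per-layer projection ahead of both prompt paths stays cheap even for long prompts; a quick demonstration (illustrative sizes):

    import torch

    key = torch.randn(50, 8, 128)                   # [tokens, num_kv_heads, D]
    expanded = key[:, :, None, :].expand(50, 8, 8, 128)
    assert expanded.data_ptr() == key.data_ptr()    # same storage, no copy
    assert expanded.stride(2) == 0                  # broadcast (zero-stride) dim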
vllm/model_executor/layers/triton_kernel/prefix_prefill.py
@@ -45,6 +45,7 @@ if triton.__version__ >= "2.1.0":
         stride_v_cache_h,
         stride_v_cache_d,
         stride_v_cache_bl,
+        num_queries_per_kv: int,
         BLOCK_M: tl.constexpr,
         BLOCK_DMODEL: tl.constexpr,
         BLOCK_N: tl.constexpr,
@@ -53,6 +54,8 @@ if triton.__version__ >= "2.1.0":
         cur_head = tl.program_id(1)
         start_m = tl.program_id(2)

+        cur_kv_head = cur_head // num_queries_per_kv
+
         cur_batch_ctx_len = tl.load(B_Ctxlen + cur_batch)
         cur_batch_seq_len = tl.load(B_Seqlen + cur_batch)
         cur_batch_in_all_start_index = tl.load(B_Start_Loc + cur_batch)
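The launch grid still enumerates query heads on axis 1; the single added line is where GQA happens, with each program integer-dividing its query head down to the shared KV head. The mapping, traced on the host (illustrative values):

    num_heads, num_queries_per_kv = 64, 8
    kv_head_of = [h // num_queries_per_kv for h in range(num_heads)]
    # query heads 0..7 -> KV head 0, 8..15 -> KV head 1, ..., 56..63 -> 7
    assert kv_head_of[:9] == [0, 0, 0, 0, 0, 0, 0, 0, 1]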
@@ -85,13 +88,14 @@ if triton.__version__ >= "2.1.0":
                 mask=(start_n + offs_n) < cur_batch_ctx_len,
                 other=0)
             off_k = (bn[None, :] * stride_k_cache_bs +
-                     cur_head * stride_k_cache_h +
+                     cur_kv_head * stride_k_cache_h +
                      (offs_d[:, None] // x) * stride_k_cache_d +
                      ((start_n + offs_n[None, :]) % block_size) *
                      stride_k_cache_bl +
                      (offs_d[:, None] % x) * stride_k_cache_x)
-            off_v = (bn[:, None] * stride_v_cache_bs +
-                     cur_head * stride_v_cache_h +
-                     offs_d[None, :] * stride_v_cache_d +
-                     (start_n + offs_n[:, None]) % block_size * stride_v_cache_bl)
+            off_v = (
+                bn[:, None] * stride_v_cache_bs +
+                cur_kv_head * stride_v_cache_h +
+                offs_d[None, :] * stride_v_cache_d +
+                (start_n + offs_n[:, None]) % block_size * stride_v_cache_bl)
             k = tl.load(K_cache + off_k,
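The surrounding arithmetic addresses a K cache laid out as [num_blocks, num_kv_heads, head_size // x, block_size, x]; only the head term changed. A scalar replay of the same offset computation for one (d, n) element, assuming a contiguous cache of that shape (illustrative sizes):

    num_kv_heads, head_size, block_size, x = 8, 128, 16, 8

    # Strides of a contiguous [num_blocks, num_kv_heads, head_size // x,
    # block_size, x] tensor, innermost dimension last.
    stride_k_cache_x = 1
    stride_k_cache_bl = x
    stride_k_cache_d = block_size * x
    stride_k_cache_h = (head_size // x) * stride_k_cache_d
    stride_k_cache_bs = num_kv_heads * stride_k_cache_h

    def off_k(bn, cur_kv_head, d, n):
        # bn: physical block id; d: head-dim index; n: context position
        return (bn * stride_k_cache_bs + cur_kv_head * stride_k_cache_h +
                (d // x) * stride_k_cache_d +
                (n % block_size) * stride_k_cache_bl +
                (d % x) * stride_k_cache_x)

    print(off_k(3, 2, 9, 37))   # e.g. element (d=9, n=37) of block 3, KV head 2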
@@ -131,9 +135,9 @@ if triton.__version__ >= "2.1.0":
         l_i = l_i_new
         m_i = m_i_new

-        off_k = (offs_n[None, :] * stride_kbs + cur_head * stride_kh +
+        off_k = (offs_n[None, :] * stride_kbs + cur_kv_head * stride_kh +
                  offs_d[:, None] * stride_kd)
-        off_v = (offs_n[:, None] * stride_vbs + cur_head * stride_vh +
+        off_v = (offs_n[:, None] * stride_vbs + cur_kv_head * stride_vh +
                  offs_d[None, :] * stride_vd)
         k_ptrs = K + off_k
         v_ptrs = V + off_v
@@ -232,6 +236,7 @@ if triton.__version__ >= "2.1.0":
         stride_v_cache_h,
         stride_v_cache_d,
         stride_v_cache_bl,
+        num_queries_per_kv: int,
         BLOCK_M: tl.constexpr,
         BLOCK_DMODEL: tl.constexpr,
         BLOCK_N: tl.constexpr,
@@ -240,6 +245,8 @@ if triton.__version__ >= "2.1.0":
         cur_head = tl.program_id(1)
         start_m = tl.program_id(2)

+        cur_kv_head = cur_head // num_queries_per_kv
+
         cur_batch_ctx_len = tl.load(B_Ctxlen + cur_batch)
         cur_batch_seq_len = tl.load(B_Seqlen + cur_batch)
         cur_batch_in_all_start_index = tl.load(B_Start_Loc + cur_batch)
@@ -272,13 +279,14 @@ if triton.__version__ >= "2.1.0":
                 mask=(start_n + offs_n) < cur_batch_ctx_len,
                 other=0)
             off_k = (bn[None, :] * stride_k_cache_bs +
-                     cur_head * stride_k_cache_h +
+                     cur_kv_head * stride_k_cache_h +
                      (offs_d[:, None] // x) * stride_k_cache_d +
                      ((start_n + offs_n[None, :]) % block_size) *
                      stride_k_cache_bl +
                      (offs_d[:, None] % x) * stride_k_cache_x)
-            off_v = (bn[:, None] * stride_v_cache_bs +
-                     cur_head * stride_v_cache_h +
-                     offs_d[None, :] * stride_v_cache_d +
-                     (start_n + offs_n[:, None]) % block_size * stride_v_cache_bl)
+            off_v = (
+                bn[:, None] * stride_v_cache_bs +
+                cur_kv_head * stride_v_cache_h +
+                offs_d[None, :] * stride_v_cache_d +
+                (start_n + offs_n[:, None]) % block_size * stride_v_cache_bl)
             k = tl.load(K_cache + off_k,
@@ -317,9 +325,9 @@ if triton.__version__ >= "2.1.0":
         l_i = l_i_new
         m_i = m_i_new

-        off_k = (offs_n[None, :] * stride_kbs + cur_head * stride_kh +
+        off_k = (offs_n[None, :] * stride_kbs + cur_kv_head * stride_kh +
                  offs_d[:, None] * stride_kd)
-        off_v = (offs_n[:, None] * stride_vbs + cur_head * stride_vh +
+        off_v = (offs_n[:, None] * stride_vbs + cur_kv_head * stride_vh +
                  offs_d[None, :] * stride_vd)
         k_ptrs = K + off_k
         v_ptrs = V + off_v
@@ -420,6 +428,7 @@ if triton.__version__ >= "2.1.0":
         stride_v_cache_h,
         stride_v_cache_d,
         stride_v_cache_bl,
+        num_queries_per_kv: int,
         BLOCK_M: tl.constexpr,
         BLOCK_DMODEL: tl.constexpr,
         BLOCK_N: tl.constexpr,
@@ -429,6 +438,8 @@ if triton.__version__ >= "2.1.0":
         cur_head = tl.program_id(1)
         start_m = tl.program_id(2)

+        cur_kv_head = cur_head // num_queries_per_kv
+
         # cur_batch_seq_len: the length of prompts
         # cur_batch_ctx_len: the length of prefix
         # cur_batch_in_all_start_index: the start id of the dim=0
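These per-batch scalars are how each program finds its slice of the packed token dimension: all sequences' new tokens are concatenated along dim 0, and the query length of sequence i is b_seq_len[i] - b_ctx_len[i]. Hypothetical metadata for three sequences:

    subquery_lens = [20, 30, 10]    # new (uncached) tokens per sequence
    ctx_lens = [100, 50, 0]         # cached prefix length per sequence
    b_seq_len = [a + b for a, b in zip(subquery_lens, ctx_lens)]  # totals
    b_start_loc = [0, 20, 50]       # running sum: starts in the packed dim 0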
@@ -468,13 +479,14 @@ if triton.__version__ >= "2.1.0":
                 mask=(start_n + offs_n) < cur_batch_ctx_len,
                 other=0)
             off_k = (bn[None, :] * stride_k_cache_bs +
-                     cur_head * stride_k_cache_h +
+                     cur_kv_head * stride_k_cache_h +
                      (offs_d[:, None] // x) * stride_k_cache_d +
                      ((start_n + offs_n[None, :]) % block_size) *
                      stride_k_cache_bl +
                      (offs_d[:, None] % x) * stride_k_cache_x)
-            off_v = (bn[:, None] * stride_v_cache_bs +
-                     cur_head * stride_v_cache_h +
-                     offs_d[None, :] * stride_v_cache_d +
-                     (start_n + offs_n[:, None]) % block_size * stride_v_cache_bl)
+            off_v = (
+                bn[:, None] * stride_v_cache_bs +
+                cur_kv_head * stride_v_cache_h +
+                offs_d[None, :] * stride_v_cache_d +
+                (start_n + offs_n[:, None]) % block_size * stride_v_cache_bl)
             k = tl.load(K_cache + off_k,
@@ -522,9 +534,9 @@ if triton.__version__ >= "2.1.0":
         l_i = l_i_new
         m_i = m_i_new

-        off_k = (offs_n[None, :] * stride_kbs + cur_head * stride_kh +
+        off_k = (offs_n[None, :] * stride_kbs + cur_kv_head * stride_kh +
                  offs_d[:, None] * stride_kd)
-        off_v = (offs_n[:, None] * stride_vbs + cur_head * stride_vh +
+        off_v = (offs_n[:, None] * stride_vbs + cur_kv_head * stride_vh +
                  offs_d[None, :] * stride_vd)
         k_ptrs = K + off_k
         v_ptrs = V + off_v
@@ -628,6 +640,7 @@ if triton.__version__ >= "2.1.0":
         sm_scale = 1.0 / (Lq**0.5)
         batch, head = b_seq_len.shape[0], q.shape[1]
+        num_queries_per_kv = q.shape[1] // k.shape[1]

         grid = (batch, head, triton.cdiv(max_input_len, BLOCK))  # batch, head,
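The wrapper infers the ratio from the tensors themselves, so existing call sites need no new argument: dim 1 of q carries num_heads and dim 1 of k carries num_kv_heads. The implied contract (hypothetical shapes):

    import torch

    q = torch.empty(50, 64, 128)    # [num_tokens, num_heads, head_size]
    k = torch.empty(200, 8, 128)    # [num_tokens, num_kv_heads, head_size]
    num_queries_per_kv = q.shape[1] // k.shape[1]    # 64 // 8 == 8
    assert q.shape[1] % k.shape[1] == 0              # head counts must divide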
@@ -674,6 +687,7 @@ if triton.__version__ >= "2.1.0":
             v_cache.stride(2),
             v_cache.stride(3),  #[num_blocks, num_kv_heads, head_size, block_size]
+            num_queries_per_kv=num_queries_per_kv,
             BLOCK_M=BLOCK,
             BLOCK_DMODEL=Lk,
             BLOCK_N=BLOCK,
@@ -721,6 +735,7 @@ if triton.__version__ >= "2.1.0":
             v_cache.stride(2),
             v_cache.stride(3),  #[num_blocks, num_kv_heads, head_size, block_size]
+            num_queries_per_kv=num_queries_per_kv,
             BLOCK_M=BLOCK,
             BLOCK_DMODEL=Lk,
             BLOCK_N=BLOCK,