Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
norm
vllm
Commits
71bcaf99
Unverified
Commit
71bcaf99
authored
Feb 27, 2024
by
Tao He
Committed by
GitHub
Feb 27, 2024
Browse files
Enable GQA support in the prefix prefill kernels (#3007)
Signed-off-by:
Tao He
<
sighingnow@gmail.com
>
parent
8b430d7d
Changes
3
Hide whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
87 additions
and
47 deletions
+87
-47
tests/kernels/test_prefix_prefill.py
tests/kernels/test_prefix_prefill.py
+42
-19
vllm/model_executor/layers/attention.py
vllm/model_executor/layers/attention.py
+18
-16
vllm/model_executor/layers/triton_kernel/prefix_prefill.py
vllm/model_executor/layers/triton_kernel/prefix_prefill.py
+27
-12
No files found.
tests/kernels/test_prefix_prefill.py
View file @
71bcaf99
...
...
@@ -8,7 +8,8 @@ from vllm.model_executor.layers.triton_kernel.prefix_prefill import (
from
xformers
import
ops
as
xops
from
xformers.ops.fmha.attn_bias
import
BlockDiagonalCausalFromBottomRightMask
NUM_HEADS
=
[
12
]
NUM_HEADS
=
[
64
]
NUM_QUERIES_PER_KV
=
[
1
,
8
,
64
]
HEAD_SIZES
=
[
128
]
DTYPES
=
[
torch
.
float16
]
CUDA_DEVICES
=
[
...
...
@@ -17,12 +18,14 @@ CUDA_DEVICES = [
@
pytest
.
mark
.
parametrize
(
"num_heads"
,
NUM_HEADS
)
@
pytest
.
mark
.
parametrize
(
"num_queries_per_kv"
,
NUM_HEADS
)
@
pytest
.
mark
.
parametrize
(
"head_size"
,
HEAD_SIZES
)
@
pytest
.
mark
.
parametrize
(
"dtype"
,
DTYPES
)
@
pytest
.
mark
.
parametrize
(
"device"
,
CUDA_DEVICES
)
@
torch
.
inference_mode
()
def
test_contexted_kv_attention
(
num_heads
:
int
,
num_queries_per_kv
:
int
,
head_size
:
int
,
dtype
:
torch
.
dtype
,
device
:
str
,
...
...
@@ -41,28 +44,29 @@ def test_contexted_kv_attention(
subquery_lens
=
[
random
.
randint
(
16
,
MAX_SEQ_LEN
)
for
_
in
range
(
BS
)]
ctx_lens
=
[
random
.
randint
(
16
,
MAX_CTX_LEN
)
for
_
in
range
(
BS
)]
seq_lens
=
[
a
+
b
for
a
,
b
in
zip
(
subquery_lens
,
ctx_lens
)]
num_kv_heads
=
num_heads
//
num_queries_per_kv
num_tokens
=
sum
(
subquery_lens
)
query
=
torch
.
empty
(
num_tokens
,
num_heads
,
head_size
,
dtype
=
dtype
)
query
.
uniform_
(
-
1e-3
,
1e-3
)
output
=
torch
.
empty
(
num_tokens
,
num_heads
,
head_size
,
dtype
=
dtype
)
kv
=
torch
.
empty
(
sum
(
seq_lens
),
2
,
num_heads
,
head_size
,
dtype
=
dtype
)
kv
=
torch
.
empty
(
sum
(
seq_lens
),
2
,
num_
kv_
heads
,
head_size
,
dtype
=
dtype
)
kv
.
uniform_
(
-
1e-3
,
1e-3
)
key
,
value
=
kv
.
unbind
(
dim
=
1
)
k_cache
=
torch
.
zeros
(
cache_size
,
block_size
,
num_heads
,
num_
kv_
heads
,
head_size
,
dtype
=
dtype
)
v_cache
=
torch
.
zeros
(
cache_size
,
block_size
,
num_heads
,
num_
kv_
heads
,
head_size
,
dtype
=
dtype
)
k
=
torch
.
zeros
(
sum
(
subquery_lens
),
num_heads
,
head_size
,
dtype
=
dtype
)
v
=
torch
.
zeros
(
sum
(
subquery_lens
),
num_heads
,
head_size
,
dtype
=
dtype
)
k
=
torch
.
zeros
(
sum
(
subquery_lens
),
num_
kv_
heads
,
head_size
,
dtype
=
dtype
)
v
=
torch
.
zeros
(
sum
(
subquery_lens
),
num_
kv_
heads
,
head_size
,
dtype
=
dtype
)
values
=
torch
.
arange
(
0
,
cache_size
,
dtype
=
torch
.
long
)
values
=
values
[
torch
.
randperm
(
cache_size
)]
block_table
=
values
[:
BS
*
max_block_per_request
].
view
(
...
...
@@ -93,19 +97,21 @@ def test_contexted_kv_attention(
end_loc
=
start_loc
+
block_size
start_slot
=
block_table
[
i
,
block_id
]
*
block_size
end_slot
=
start_slot
+
end_loc
-
start_loc
k_cache
.
view
(
-
1
,
num_heads
,
head_size
)[
start_slot
:
end_slot
].
copy_
(
key
[
start_loc
:
end_loc
])
v_cache
.
view
(
-
1
,
num_heads
,
head_size
)[
start_slot
:
end_slot
].
copy_
(
value
[
start_loc
:
end_loc
])
k_cache
.
view
(
-
1
,
num_kv_heads
,
head_size
)[
start_slot
:
end_slot
].
copy_
(
key
[
start_loc
:
end_loc
])
v_cache
.
view
(
-
1
,
num_kv_heads
,
head_size
)[
start_slot
:
end_slot
].
copy_
(
value
[
start_loc
:
end_loc
])
cur_ctx
+=
block_size
block_id
+=
1
# transpose K_cache[num_blocks, block_size, num_kv_heads, head_size]
# to K_cache[num_blocks, num_kv_heads, head_size/8, block_size, 8]
k_cache
=
k_cache
.
view
(
-
1
,
block_size
,
num_heads
,
head_size
//
8
,
k_cache
=
k_cache
.
view
(
-
1
,
block_size
,
num_
kv_
heads
,
head_size
//
8
,
8
).
permute
(
0
,
2
,
3
,
1
,
4
).
contiguous
()
# transpose V_cache[num_blocks, block_size, num_kv_heads, head_size]
# to V_cache[num_blocks, num_kv_heads, head_size, block_size]
v_cache
=
v_cache
.
view
(
-
1
,
block_size
,
num_heads
,
v_cache
=
v_cache
.
view
(
-
1
,
block_size
,
num_
kv_
heads
,
head_size
).
permute
(
0
,
2
,
3
,
1
).
contiguous
()
# Warm up the Triton kernel by calling it once before actually measuring generation time
...
...
@@ -123,12 +129,29 @@ def test_contexted_kv_attention(
attn_op
=
xops
.
fmha
.
cutlass
.
FwOp
()
if
num_kv_heads
!=
num_heads
:
# As of Nov 2023, xformers only supports MHA. For MQA/GQA,
# project the key and value tensors to the desired number of
# heads.
#
# see also: vllm/model_executor/layers/attention.py
query
=
query
.
view
(
query
.
shape
[
0
],
num_kv_heads
,
num_queries_per_kv
,
query
.
shape
[
-
1
])
key
=
key
[:,
:,
None
,
:].
expand
(
key
.
shape
[
0
],
num_kv_heads
,
num_queries_per_kv
,
key
.
shape
[
-
1
])
value
=
value
[:,
:,
None
,
:].
expand
(
value
.
shape
[
0
],
num_kv_heads
,
num_queries_per_kv
,
value
.
shape
[
-
1
])
query
=
query
.
unsqueeze
(
0
)
key
=
key
.
unsqueeze
(
0
)
value
=
value
.
unsqueeze
(
0
)
attn_bias
=
BlockDiagonalCausalFromBottomRightMask
.
from_seqlens
(
subquery_lens
,
seq_lens
)
output_ref
=
xops
.
memory_efficient_attention_forward
(
query
.
unsqueeze
(
0
)
,
key
.
unsqueeze
(
0
)
,
value
.
unsqueeze
(
0
)
,
query
,
key
,
value
,
attn_bias
=
attn_bias
,
p
=
0.0
,
scale
=
scale
,
...
...
@@ -137,9 +160,9 @@ def test_contexted_kv_attention(
torch
.
cuda
.
synchronize
()
start_time
=
time
.
time
()
output_ref
=
xops
.
memory_efficient_attention_forward
(
query
.
unsqueeze
(
0
)
,
key
.
unsqueeze
(
0
)
,
value
.
unsqueeze
(
0
)
,
query
,
key
,
value
,
attn_bias
=
attn_bias
,
p
=
0.0
,
scale
=
scale
,
...
...
@@ -148,5 +171,5 @@ def test_contexted_kv_attention(
torch
.
cuda
.
synchronize
()
end_time
=
time
.
time
()
print
(
f
"xformers Time:
{
(
end_time
-
start_time
)
*
1000
:.
2
f
}
ms"
)
output_ref
=
output_ref
.
squeeze
(
0
)
output_ref
=
output_ref
.
squeeze
(
0
,
2
)
assert
torch
.
allclose
(
output_ref
,
output
,
atol
=
1e-6
,
rtol
=
0
)
vllm/model_executor/layers/attention.py
View file @
71bcaf99
...
...
@@ -137,25 +137,27 @@ class PagedAttention(nn.Module):
)
if
input_metadata
.
is_prompt
:
# Prompt run.
if
self
.
num_kv_heads
!=
self
.
num_heads
:
# As of Nov 2023, xformers only supports MHA. For MQA/GQA,
# project the key and value tensors to the desired number of
# heads.
# TODO(woosuk): Use MQA/GQA kernels for higher performance.
query
=
query
.
view
(
query
.
shape
[
0
],
self
.
num_kv_heads
,
self
.
num_queries_per_kv
,
query
.
shape
[
-
1
])
key
=
key
[:,
:,
None
,
:].
expand
(
key
.
shape
[
0
],
self
.
num_kv_heads
,
self
.
num_queries_per_kv
,
key
.
shape
[
-
1
])
value
=
value
[:,
:,
None
,
:].
expand
(
value
.
shape
[
0
],
self
.
num_kv_heads
,
self
.
num_queries_per_kv
,
value
.
shape
[
-
1
])
# normal attention
if
(
key_cache
is
None
or
value_cache
is
None
or
input_metadata
.
block_tables
.
numel
()
==
0
):
if
self
.
num_kv_heads
!=
self
.
num_heads
:
# As of Nov 2023, xformers only supports MHA. For MQA/GQA,
# project the key and value tensors to the desired number of
# heads.
# TODO(woosuk): Use MQA/GQA kernels for higher performance.
query
=
query
.
view
(
query
.
shape
[
0
],
self
.
num_kv_heads
,
self
.
num_queries_per_kv
,
query
.
shape
[
-
1
])
key
=
key
[:,
:,
None
,
:].
expand
(
key
.
shape
[
0
],
self
.
num_kv_heads
,
self
.
num_queries_per_kv
,
key
.
shape
[
-
1
])
value
=
value
[:,
:,
None
,
:].
expand
(
value
.
shape
[
0
],
self
.
num_kv_heads
,
self
.
num_queries_per_kv
,
value
.
shape
[
-
1
])
# Set attention bias if not provided. This typically happens at
# the very attention layer of every iteration.
# FIXME(woosuk): This is a hack.
...
...
vllm/model_executor/layers/triton_kernel/prefix_prefill.py
View file @
71bcaf99
...
...
@@ -45,6 +45,7 @@ if triton.__version__ >= "2.1.0":
stride_v_cache_h
,
stride_v_cache_d
,
stride_v_cache_bl
,
num_queries_per_kv
:
int
,
BLOCK_M
:
tl
.
constexpr
,
BLOCK_DMODEL
:
tl
.
constexpr
,
BLOCK_N
:
tl
.
constexpr
,
...
...
@@ -53,6 +54,8 @@ if triton.__version__ >= "2.1.0":
cur_head
=
tl
.
program_id
(
1
)
start_m
=
tl
.
program_id
(
2
)
cur_kv_head
=
cur_head
//
num_queries_per_kv
cur_batch_ctx_len
=
tl
.
load
(
B_Ctxlen
+
cur_batch
)
cur_batch_seq_len
=
tl
.
load
(
B_Seqlen
+
cur_batch
)
cur_batch_in_all_start_index
=
tl
.
load
(
B_Start_Loc
+
cur_batch
)
...
...
@@ -85,13 +88,14 @@ if triton.__version__ >= "2.1.0":
mask
=
(
start_n
+
offs_n
)
<
cur_batch_ctx_len
,
other
=
0
)
off_k
=
(
bn
[
None
,
:]
*
stride_k_cache_bs
+
cur_head
*
stride_k_cache_h
+
cur_
kv_
head
*
stride_k_cache_h
+
(
offs_d
[:,
None
]
//
x
)
*
stride_k_cache_d
+
((
start_n
+
offs_n
[
None
,
:])
%
block_size
)
*
stride_k_cache_bl
+
(
offs_d
[:,
None
]
%
x
)
*
stride_k_cache_x
)
off_v
=
(
bn
[:,
None
]
*
stride_v_cache_bs
+
cur_head
*
stride_v_cache_h
+
bn
[:,
None
]
*
stride_v_cache_bs
+
cur_kv_head
*
stride_v_cache_h
+
offs_d
[
None
,
:]
*
stride_v_cache_d
+
(
start_n
+
offs_n
[:,
None
])
%
block_size
*
stride_v_cache_bl
)
k
=
tl
.
load
(
K_cache
+
off_k
,
...
...
@@ -131,9 +135,9 @@ if triton.__version__ >= "2.1.0":
l_i
=
l_i_new
m_i
=
m_i_new
off_k
=
(
offs_n
[
None
,
:]
*
stride_kbs
+
cur_head
*
stride_kh
+
off_k
=
(
offs_n
[
None
,
:]
*
stride_kbs
+
cur_
kv_
head
*
stride_kh
+
offs_d
[:,
None
]
*
stride_kd
)
off_v
=
(
offs_n
[:,
None
]
*
stride_vbs
+
cur_head
*
stride_vh
+
off_v
=
(
offs_n
[:,
None
]
*
stride_vbs
+
cur_
kv_
head
*
stride_vh
+
offs_d
[
None
,
:]
*
stride_vd
)
k_ptrs
=
K
+
off_k
v_ptrs
=
V
+
off_v
...
...
@@ -232,6 +236,7 @@ if triton.__version__ >= "2.1.0":
stride_v_cache_h
,
stride_v_cache_d
,
stride_v_cache_bl
,
num_queries_per_kv
:
int
,
BLOCK_M
:
tl
.
constexpr
,
BLOCK_DMODEL
:
tl
.
constexpr
,
BLOCK_N
:
tl
.
constexpr
,
...
...
@@ -240,6 +245,8 @@ if triton.__version__ >= "2.1.0":
cur_head
=
tl
.
program_id
(
1
)
start_m
=
tl
.
program_id
(
2
)
cur_kv_head
=
cur_head
//
num_queries_per_kv
cur_batch_ctx_len
=
tl
.
load
(
B_Ctxlen
+
cur_batch
)
cur_batch_seq_len
=
tl
.
load
(
B_Seqlen
+
cur_batch
)
cur_batch_in_all_start_index
=
tl
.
load
(
B_Start_Loc
+
cur_batch
)
...
...
@@ -272,13 +279,14 @@ if triton.__version__ >= "2.1.0":
mask
=
(
start_n
+
offs_n
)
<
cur_batch_ctx_len
,
other
=
0
)
off_k
=
(
bn
[
None
,
:]
*
stride_k_cache_bs
+
cur_head
*
stride_k_cache_h
+
cur_
kv_
head
*
stride_k_cache_h
+
(
offs_d
[:,
None
]
//
x
)
*
stride_k_cache_d
+
((
start_n
+
offs_n
[
None
,
:])
%
block_size
)
*
stride_k_cache_bl
+
(
offs_d
[:,
None
]
%
x
)
*
stride_k_cache_x
)
off_v
=
(
bn
[:,
None
]
*
stride_v_cache_bs
+
cur_head
*
stride_v_cache_h
+
bn
[:,
None
]
*
stride_v_cache_bs
+
cur_kv_head
*
stride_v_cache_h
+
offs_d
[
None
,
:]
*
stride_v_cache_d
+
(
start_n
+
offs_n
[:,
None
])
%
block_size
*
stride_v_cache_bl
)
k
=
tl
.
load
(
K_cache
+
off_k
,
...
...
@@ -317,9 +325,9 @@ if triton.__version__ >= "2.1.0":
l_i
=
l_i_new
m_i
=
m_i_new
off_k
=
(
offs_n
[
None
,
:]
*
stride_kbs
+
cur_head
*
stride_kh
+
off_k
=
(
offs_n
[
None
,
:]
*
stride_kbs
+
cur_
kv_
head
*
stride_kh
+
offs_d
[:,
None
]
*
stride_kd
)
off_v
=
(
offs_n
[:,
None
]
*
stride_vbs
+
cur_head
*
stride_vh
+
off_v
=
(
offs_n
[:,
None
]
*
stride_vbs
+
cur_
kv_
head
*
stride_vh
+
offs_d
[
None
,
:]
*
stride_vd
)
k_ptrs
=
K
+
off_k
v_ptrs
=
V
+
off_v
...
...
@@ -420,6 +428,7 @@ if triton.__version__ >= "2.1.0":
stride_v_cache_h
,
stride_v_cache_d
,
stride_v_cache_bl
,
num_queries_per_kv
:
int
,
BLOCK_M
:
tl
.
constexpr
,
BLOCK_DMODEL
:
tl
.
constexpr
,
BLOCK_N
:
tl
.
constexpr
,
...
...
@@ -429,6 +438,8 @@ if triton.__version__ >= "2.1.0":
cur_head
=
tl
.
program_id
(
1
)
start_m
=
tl
.
program_id
(
2
)
cur_kv_head
=
cur_head
//
num_queries_per_kv
# cur_batch_seq_len: the length of prompts
# cur_batch_ctx_len: the length of prefix
# cur_batch_in_all_start_index: the start id of the dim=0
...
...
@@ -468,13 +479,14 @@ if triton.__version__ >= "2.1.0":
mask
=
(
start_n
+
offs_n
)
<
cur_batch_ctx_len
,
other
=
0
)
off_k
=
(
bn
[
None
,
:]
*
stride_k_cache_bs
+
cur_head
*
stride_k_cache_h
+
cur_
kv_
head
*
stride_k_cache_h
+
(
offs_d
[:,
None
]
//
x
)
*
stride_k_cache_d
+
((
start_n
+
offs_n
[
None
,
:])
%
block_size
)
*
stride_k_cache_bl
+
(
offs_d
[:,
None
]
%
x
)
*
stride_k_cache_x
)
off_v
=
(
bn
[:,
None
]
*
stride_v_cache_bs
+
cur_head
*
stride_v_cache_h
+
bn
[:,
None
]
*
stride_v_cache_bs
+
cur_kv_head
*
stride_v_cache_h
+
offs_d
[
None
,
:]
*
stride_v_cache_d
+
(
start_n
+
offs_n
[:,
None
])
%
block_size
*
stride_v_cache_bl
)
k
=
tl
.
load
(
K_cache
+
off_k
,
...
...
@@ -522,9 +534,9 @@ if triton.__version__ >= "2.1.0":
l_i
=
l_i_new
m_i
=
m_i_new
off_k
=
(
offs_n
[
None
,
:]
*
stride_kbs
+
cur_head
*
stride_kh
+
off_k
=
(
offs_n
[
None
,
:]
*
stride_kbs
+
cur_
kv_
head
*
stride_kh
+
offs_d
[:,
None
]
*
stride_kd
)
off_v
=
(
offs_n
[:,
None
]
*
stride_vbs
+
cur_head
*
stride_vh
+
off_v
=
(
offs_n
[:,
None
]
*
stride_vbs
+
cur_
kv_
head
*
stride_vh
+
offs_d
[
None
,
:]
*
stride_vd
)
k_ptrs
=
K
+
off_k
v_ptrs
=
V
+
off_v
...
...
@@ -628,6 +640,7 @@ if triton.__version__ >= "2.1.0":
sm_scale
=
1.0
/
(
Lq
**
0.5
)
batch
,
head
=
b_seq_len
.
shape
[
0
],
q
.
shape
[
1
]
num_queries_per_kv
=
q
.
shape
[
1
]
//
k
.
shape
[
1
]
grid
=
(
batch
,
head
,
triton
.
cdiv
(
max_input_len
,
BLOCK
))
# batch, head,
...
...
@@ -674,6 +687,7 @@ if triton.__version__ >= "2.1.0":
v_cache
.
stride
(
2
),
v_cache
.
stride
(
3
),
#[num_blocks, num_kv_heads, head_size, block_size]
num_queries_per_kv
=
num_queries_per_kv
,
BLOCK_M
=
BLOCK
,
BLOCK_DMODEL
=
Lk
,
BLOCK_N
=
BLOCK
,
...
...
@@ -721,6 +735,7 @@ if triton.__version__ >= "2.1.0":
v_cache
.
stride
(
2
),
v_cache
.
stride
(
3
),
#[num_blocks, num_kv_heads, head_size, block_size]
num_queries_per_kv
=
num_queries_per_kv
,
BLOCK_M
=
BLOCK
,
BLOCK_DMODEL
=
Lk
,
BLOCK_N
=
BLOCK
,
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment