norm / vllm / Commits / 96853af5

Commit 96853af5 (unverified), authored Jul 14, 2023 by Zhuohan Li, committed by GitHub on Jul 14, 2023

Optimize MQA Kernel (#452)

Parent: dbed6905

Showing 5 changed files with 84 additions and 72 deletions (+84 -72)
csrc/attention.cpp                           +1   -0
csrc/attention/attention_kernels.cu          +22  -9
vllm/config.py                               +7   -0
vllm/model_executor/layers/attention.py      +32  -11
vllm/model_executor/models/gpt_bigcode.py    +22  -52
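Background on why the kernel change matters (this note and the sketch are not part of the diff): before this commit, GPTBigCode checkpoints were expanded from MQA to full multi-head attention at load time (see the removed _expand_mqa_mha below), so the paged KV cache stored one KV head per query head. With native MQA support the cache only needs num_kv_heads heads per block, which for multi-query models shrinks it roughly by the query-head count. A rough sizing sketch with made-up numbers:

# Back-of-the-envelope KV-cache sizing, purely illustrative numbers (not from the diff).
num_layers, num_heads, head_size = 40, 48, 128
num_blocks, block_size, bytes_per_elem = 1000, 16, 2  # fp16 cache

def kv_cache_bytes(num_kv_heads: int) -> int:
    # One key block plus one value block per layer and per cache block.
    per_block = 2 * num_kv_heads * head_size * block_size * bytes_per_elem
    return num_layers * num_blocks * per_block

mha = kv_cache_bytes(num_heads)  # one KV head per query head
mqa = kv_cache_bytes(1)          # multi-query attention: a single shared KV head
print(mha // mqa)                # 48: the cache shrinks by the query-head count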
csrc/attention.cpp (View file @ 96853af5)

@@ -6,6 +6,7 @@ void single_query_cached_kv_attention(
   torch::Tensor& query,
   torch::Tensor& key_cache,
   torch::Tensor& value_cache,
+  torch::Tensor& head_mapping,
   float scale,
   torch::Tensor& block_tables,
   torch::Tensor& context_lens,
csrc/attention/attention_kernels.cu (View file @ 96853af5)

@@ -74,14 +74,17 @@ template<
 __global__ void single_query_cached_kv_attention_kernel(
   scalar_t* __restrict__ out,             // [num_seqs, num_heads, head_size]
   const scalar_t* __restrict__ q,         // [num_seqs, num_heads, head_size]
-  const scalar_t* __restrict__ k_cache,   // [num_blocks, num_heads, head_size/x, block_size, x]
-  const scalar_t* __restrict__ v_cache,   // [num_blocks, num_heads, head_size, block_size]
+  const scalar_t* __restrict__ k_cache,   // [num_blocks, num_kv_heads, head_size/x, block_size, x]
+  const scalar_t* __restrict__ v_cache,   // [num_blocks, num_kv_heads, head_size, block_size]
+  const int* __restrict__ head_mapping,   // [num_heads]
   const float scale,
   const int* __restrict__ block_tables,   // [num_seqs, max_num_blocks_per_seq]
   const int* __restrict__ context_lens,   // [num_seqs]
   const int max_num_blocks_per_seq,
   const float* __restrict__ alibi_slopes, // [num_heads]
-  const int q_stride) {
+  const int q_stride,
+  const int kv_block_stride,
+  const int kv_head_stride) {
   constexpr int THREAD_GROUP_SIZE = MAX(WARP_SIZE / BLOCK_SIZE, 1);
   constexpr int NUM_TOKENS_PER_THREAD_GROUP = (BLOCK_SIZE + WARP_SIZE - 1) / WARP_SIZE;
   constexpr int NUM_WARPS = NUM_THREADS / WARP_SIZE;

@@ -91,6 +94,7 @@ __global__ void single_query_cached_kv_attention_kernel(
   const int head_idx = blockIdx.x;
   const int num_heads = gridDim.x;
+  const int kv_head_idx = head_mapping[head_idx];
   const int seq_idx = blockIdx.y;
   const float alibi_slope = alibi_slopes == nullptr ? 0.f : alibi_slopes[head_idx];

@@ -158,8 +162,8 @@ __global__ void single_query_cached_kv_attention_kernel(
 #pragma unroll
     for (int j = 0; j < NUM_VECS_PER_THREAD; j++) {
-      const scalar_t* k_ptr = k_cache + physical_block_number * num_heads * HEAD_SIZE * BLOCK_SIZE
-                                      + head_idx * HEAD_SIZE * BLOCK_SIZE
+      const scalar_t* k_ptr = k_cache + physical_block_number * kv_block_stride
+                                      + kv_head_idx * kv_head_stride
                                       + physical_block_offset * x;
       const int vec_idx = thread_group_offset + j * THREAD_GROUP_SIZE;
       const int offset1 = (vec_idx * VEC_SIZE) / x;

@@ -246,8 +250,8 @@ __global__ void single_query_cached_kv_attention_kernel(
     L_vec logits_vec;
     from_float(logits_vec, *reinterpret_cast<Float_L_vec*>(logits + token_idx));
-    const scalar_t* v_ptr = v_cache + physical_block_number * num_heads * HEAD_SIZE * BLOCK_SIZE
-                                    + head_idx * HEAD_SIZE * BLOCK_SIZE;
+    const scalar_t* v_ptr = v_cache + physical_block_number * kv_block_stride
+                                    + kv_head_idx * kv_head_stride;
 #pragma unroll
     for (int i = 0; i < NUM_ROWS_PER_THREAD; i++) {
       const int row_idx = lane / NUM_V_VECS_PER_ROW + i * NUM_ROWS_PER_ITER;

@@ -328,12 +332,15 @@ __global__ void single_query_cached_kv_attention_kernel(
     query_ptr, \
     key_cache_ptr, \
     value_cache_ptr, \
+    head_mapping_ptr, \
     scale, \
     block_tables_ptr, \
     context_lens_ptr, \
     max_num_blocks_per_seq, \
     alibi_slopes_ptr, \
-    query_stride);
+    q_stride, \
+    kv_block_stride, \
+    kv_head_stride);

 // TODO(woosuk): Tune NUM_THREADS.
 template<

@@ -345,6 +352,7 @@ void single_query_cached_kv_attention_launcher(
   torch::Tensor& query,
   torch::Tensor& key_cache,
   torch::Tensor& value_cache,
+  torch::Tensor& head_mapping,
   float scale,
   torch::Tensor& block_tables,
   torch::Tensor& context_lens,

@@ -354,7 +362,9 @@ void single_query_cached_kv_attention_launcher(
   int num_heads = query.size(1);
   int head_size = query.size(2);
   int max_num_blocks_per_seq = block_tables.size(1);
-  int query_stride = query.stride(0);
+  int q_stride = query.stride(0);
+  int kv_block_stride = key_cache.stride(0);
+  int kv_head_stride = key_cache.stride(1);

   int thread_group_size = MAX(WARP_SIZE / BLOCK_SIZE, 1);
   assert(head_size % thread_group_size == 0);

@@ -368,6 +378,7 @@ void single_query_cached_kv_attention_launcher(
   T* query_ptr = reinterpret_cast<T*>(query.data_ptr());
   T* key_cache_ptr = reinterpret_cast<T*>(key_cache.data_ptr());
   T* value_cache_ptr = reinterpret_cast<T*>(value_cache.data_ptr());
+  int* head_mapping_ptr = reinterpret_cast<int*>(head_mapping.data_ptr());
   int* block_tables_ptr = block_tables.data_ptr<int>();
   int* context_lens_ptr = context_lens.data_ptr<int>();

@@ -422,6 +433,7 @@ void single_query_cached_kv_attention_launcher(
     query, \
     key_cache, \
     value_cache, \
+    head_mapping, \
     scale, \
     block_tables, \
     context_lens, \

@@ -469,6 +481,7 @@ void single_query_cached_kv_attention(
   torch::Tensor& query,           // [num_seqs, num_heads, head_size]
   torch::Tensor& key_cache,       // [num_blocks, num_heads, head_size/x, block_size, x]
   torch::Tensor& value_cache,     // [num_blocks, num_heads, head_size, block_size]
+  torch::Tensor& head_mapping,    // [num_heads]
   float scale,
   torch::Tensor& block_tables,    // [num_seqs, max_num_blocks_per_seq]
   torch::Tensor& context_lens,    // [num_seqs]
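For reference, the new kv_block_stride and kv_head_stride arguments are just the element strides of the key/value cache along its first two dimensions, so k_ptr and v_ptr land on the slice for one (physical block, KV head) pair instead of assuming num_heads entries per block. A small PyTorch check of that pointer arithmetic (toy shapes, CPU tensor, for illustration only):

import torch

# Hypothetical cache shape matching the kernel comment:
# key_cache: [num_blocks, num_kv_heads, head_size/x, block_size, x]
num_blocks, num_kv_heads, head_size, block_size, x = 16, 2, 128, 16, 8
key_cache = torch.randn(num_blocks, num_kv_heads, head_size // x, block_size, x)

# These are the values the launcher now passes to the kernel.
kv_block_stride = key_cache.stride(0)  # elements per cache block
kv_head_stride = key_cache.stride(1)   # elements per KV head within a block

# The kernel's offset for one (block, kv head) slice:
physical_block_number, kv_head_idx = 3, 1
offset = physical_block_number * kv_block_stride + kv_head_idx * kv_head_stride
assert torch.equal(
    key_cache.flatten()[offset:offset + kv_head_stride],
    key_cache[physical_block_number, kv_head_idx].flatten())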
vllm/config.py (View file @ 96853af5)

@@ -94,6 +94,13 @@ class ModelConfig:
         return self.hf_config.hidden_size // self.hf_config.num_attention_heads

     def get_num_heads(self, parallel_config: "ParallelConfig") -> int:
+        # For GPTBigCode:
+        if getattr(self.hf_config, "multi_query", False):
+            # Multi-query attention, only one KV head.
+            return 1
+        # For Falcon:
+        if getattr(self.hf_config, "n_head_kv", None) is not None:
+            return self.hf_config.n_head_kv
         total_num_attention_heads = self.hf_config.num_attention_heads
         return total_num_attention_heads // parallel_config.tensor_parallel_size
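A self-contained sketch of the precedence of the checks added above, using stand-in config objects rather than real transformers configs (the function name num_kv_heads is hypothetical; the real method is ModelConfig.get_num_heads):

from types import SimpleNamespace

def num_kv_heads(hf_config, tensor_parallel_size: int) -> int:
    # GPTBigCode exposes `multi_query`; MQA means a single KV head.
    if getattr(hf_config, "multi_query", False):
        return 1
    # Falcon exposes `n_head_kv` for its multi-query variants.
    if getattr(hf_config, "n_head_kv", None) is not None:
        return hf_config.n_head_kv
    # Otherwise one KV head per attention head, sharded by tensor parallelism.
    return hf_config.num_attention_heads // tensor_parallel_size

print(num_kv_heads(SimpleNamespace(multi_query=True, num_attention_heads=48), 1))  # 1
print(num_kv_heads(SimpleNamespace(n_head_kv=8, num_attention_heads=64), 1))       # 8
print(num_kv_heads(SimpleNamespace(num_attention_heads=32), 4))                    # 8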
vllm/model_executor/layers/attention.py (View file @ 96853af5)

@@ -44,12 +44,23 @@ class PagedAttention(nn.Module):
     5. Output a flattened 1D tensor.
     """

-    def __init__(self, num_heads: int, head_size: int, scale: float) -> None:
+    def __init__(self,
+                 num_heads: int,
+                 head_size: int,
+                 scale: float,
+                 num_kv_heads: Optional[int] = None) -> None:
         super().__init__()
         self.num_heads = num_heads
         self.head_size = head_size
         self.scale = float(scale)
         self.attn_op = xops.fmha.cutlass.FwOp()
+        self.num_kv_heads = num_heads if num_kv_heads is None else num_kv_heads
+
+        assert self.num_heads % self.num_kv_heads == 0
+        self.num_queries_per_kv = self.num_heads // self.num_kv_heads
+        self.head_mapping = torch.repeat_interleave(
+            torch.arange(self.num_kv_heads, dtype=torch.int32, device="cuda"),
+            self.num_queries_per_kv)

         if self.head_size not in _SUPPORTED_HEAD_SIZES:
             raise ValueError(f"head_size ({self.head_size}) is not supported. "
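To make the head_mapping construction above concrete, here is the same repeat_interleave on CPU with small, made-up head counts (the real code builds the tensor on CUDA):

import torch

# 8 query heads sharing 2 KV heads (illustrative); MQA is the case num_kv_heads == 1.
num_heads, num_kv_heads = 8, 2
num_queries_per_kv = num_heads // num_kv_heads

head_mapping = torch.repeat_interleave(
    torch.arange(num_kv_heads, dtype=torch.int32), num_queries_per_kv)
print(head_mapping)  # tensor([0, 0, 0, 0, 1, 1, 1, 1], dtype=torch.int32)
# Query heads 0-3 read KV head 0; query heads 4-7 read KV head 1.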
@@ -76,10 +87,18 @@ class PagedAttention(nn.Module):
         Args:
             output: shape = [num_prompt_tokens, num_heads, head_size]
             query: shape = [num_prompt_tokens, num_heads, head_size]
-            key: shape = [num_prompt_tokens, num_heads, head_size]
-            value: shape = [num_prompt_tokens, num_heads, head_size]
+            key: shape = [num_prompt_tokens, num_kv_heads, head_size]
+            value: shape = [num_prompt_tokens, num_kv_heads, head_size]
             input_metadata: metadata for paged attention.
         """
+        if self.num_kv_heads != self.num_heads:
+            # Project the key and value tensors to the desired number of heads.
+            key = torch.repeat_interleave(key, self.num_queries_per_kv, dim=1)
+            value = torch.repeat_interleave(value,
+                                            self.num_queries_per_kv,
+                                            dim=1)
+
         # TODO(woosuk): The unsqueeze op may incur some CPU overhead. Optimize.
         out = xops.memory_efficient_attention_forward(
             query.unsqueeze(0),
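The prompt path above still runs the xformers kernel with one KV head per query head, so the diff expands key and value with repeat_interleave before the call. A small standalone check of that expansion with made-up shapes:

import torch

num_prompt_tokens, num_heads, num_kv_heads, head_size = 5, 8, 2, 16
num_queries_per_kv = num_heads // num_kv_heads

key = torch.randn(num_prompt_tokens, num_kv_heads, head_size)
expanded = torch.repeat_interleave(key, num_queries_per_kv, dim=1)

print(expanded.shape)  # torch.Size([5, 8, 16])
# Query heads 0-3 now see copies of KV head 0, matching head_mapping above.
assert torch.equal(expanded[:, 0], key[:, 0]) and torch.equal(expanded[:, 3], key[:, 0])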
@@ -107,9 +126,9 @@ class PagedAttention(nn.Module):
         Args:
             output: shape = [num_generation_tokens, num_heads, head_size]
             query: shape = [num_generation_tokens, num_heads, head_size]
-            key_cache: shape = [num_blocks, num_heads, head_size/x,
+            key_cache: shape = [num_blocks, num_kv_heads, head_size/x,
                 block_size, x]
-            value_cache: shape = [num_blocks, num_heads, head_size, block_size]
+            value_cache: shape = [num_blocks, num_kv_heads, head_size, block_size]
             input_metadata: metadata for paged attention.
         """
         block_size = value_cache.shape[3]

@@ -118,6 +137,7 @@ class PagedAttention(nn.Module):
             query,
             key_cache,
             value_cache,
+            self.head_mapping,
             self.scale,
             input_metadata.block_tables,
             input_metadata.context_lens,

@@ -143,11 +163,12 @@ class PagedAttention(nn.Module):
         Args:
             query: shape = [num_tokens, num_heads * head_size]
-            key: shape = [num_tokens, num_heads * head_size]
-            value: shape = [num_tokens, num_heads * head_size]
-            key_cache: shape = [num_blocks, num_heads, head_size/x,
+            key: shape = [num_tokens, num_kv_heads * head_size]
+            value: shape = [num_tokens, num_kv_heads * head_size]
+            key_cache: shape = [num_blocks, num_kv_heads, head_size/x,
                 block_size, x]
-            value_cache: shape = [num_blocks, num_heads, head_size, block_size]
+            value_cache: shape = [num_blocks, num_kv_heads, head_size,
+                block_size]
             input_metadata: metadata for paged attention.
             cache_event: event to wait for the cache operations to finish.

@@ -157,8 +178,8 @@ class PagedAttention(nn.Module):
         # Reshape the query, key, and value tensors.
         query = query.view(-1, self.num_heads, self.head_size)
-        key = key.view(-1, self.num_heads, self.head_size)
-        value = value.view(-1, self.num_heads, self.head_size)
+        key = key.view(-1, self.num_kv_heads, self.head_size)
+        value = value.view(-1, self.num_kv_heads, self.head_size)

         # Pre-allocate the output tensor.
         output = torch.empty_like(query)
vllm/model_executor/models/gpt_bigcode.py (View file @ 96853af5)

@@ -26,7 +26,6 @@ from typing import Dict, List, Optional, Tuple
 import torch
 from torch import nn
-import numpy as np
 from transformers import GPTBigCodeConfig

 from vllm.model_executor.input_metadata import InputMetadata

@@ -55,10 +54,12 @@ class GPTBigCodeAttention(nn.Module):
         assert total_num_heads % tensor_model_parallel_world_size == 0
         self.num_heads = total_num_heads // tensor_model_parallel_world_size
         self.head_dim = self.hidden_size // total_num_heads
+        self.num_kv_heads = 1 if config.multi_query else self.num_heads
+        self.kv_dim = self.num_kv_heads * self.head_dim
         self.scale = self.head_dim**-0.5

         self.c_attn = ColumnParallelLinear(self.hidden_size,
-                                           3 * self.hidden_size,
+                                           self.hidden_size + 2 * self.kv_dim,
                                            bias=True,
                                            gather_output=False,
                                            perform_initialization=False)
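With MQA, the fused c_attn projection above produces hidden_size query features plus only 2 * kv_dim key/value features rather than 3 * hidden_size. A quick size check with made-up GPTBigCode-like numbers (illustrative only):

hidden_size, num_heads = 6144, 48
head_dim = hidden_size // num_heads          # 128
multi_query = True

num_kv_heads = 1 if multi_query else num_heads
kv_dim = num_kv_heads * head_dim             # 128 under MQA

old_out_features = 3 * hidden_size           # 18432: pre-diff layout
new_out_features = hidden_size + 2 * kv_dim  # 6400: matches the new c_attn
print(old_out_features, new_out_features)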
@@ -69,7 +70,8 @@ class GPTBigCodeAttention(nn.Module):
                                          perform_initialization=False)
         self.attn = PagedAttention(self.num_heads, self.head_dim,
-                                   scale=self.scale)
+                                   scale=self.scale,
+                                   num_kv_heads=self.num_kv_heads)

     def forward(
         self,

@@ -79,7 +81,8 @@ class GPTBigCodeAttention(nn.Module):
         cache_event: Optional[torch.cuda.Event],
     ) -> torch.Tensor:
         qkv, _ = self.c_attn(hidden_states)
-        q, k, v = qkv.chunk(chunks=3, dim=-1)
+        q, k, v = qkv.split([self.hidden_size, self.kv_dim, self.kv_dim],
+                            dim=-1)
         key_cache, value_cache = kv_cache
         attn_output = self.attn(q, k, v, key_cache, value_cache,
                                 input_metadata, cache_event)

@@ -263,36 +266,6 @@ class GPTBigCodeForCausalLM(nn.Module):
                 extra_rows = extra_rows.to(loaded_weight)
                 loaded_weight = torch.cat([loaded_weight, extra_rows], dim=0)

-            def _expand_mqa_mha(qkv_array, n_head, head_dim):
-                """manipulates along axis=0 from MQA to MHA
-                inputs: qkv_array.shape=((n_heads + 2) * head_dim, hidden_dim)
-                with n_heads for q, then 1 for k, 1 for 1 v, times head dim
-                return: qkv_array.shape=(3 * n_heads * head_dim, hidden_dim)
-
-                TODO: this function is no longer needed once vllm supports MQA.
-                """
-                qkv_array = qkv_array.numpy()
-
-                dims_q = n_head * head_dim
-                # pylint: disable=unbalanced-tuple-unpacking
-                q, k, v = np.split(qkv_array, (dims_q, dims_q + head_dim),
-                                   axis=0)
-                # q is fine, but k & v have not replicated shape along the first
-                # axis as long as MQA is not nativly supported, increase memory
-                # and replicated (head_dim, hidden_dim) to
-                # (n_heads * head_dim, hidden_dim)
-                if k.ndim == 2 and v.ndim == 2:
-                    replication = (n_head, 1)  # weights
-                else:
-                    replication = n_head  # biases
-                # replicate n_head times for q, v
-                k, v = np.tile(k, replication), np.tile(v, replication)
-                # concat q, k, v along the first axis
-                # (n_heads * head_dim, hidden_dim)
-                # to (3 * n_heads * head_dim, hidden_dim)
-                qkv_array = np.concatenate((q, k, v), axis=0)
-
-                return torch.from_numpy(qkv_array)
-
             # For the fused QKV linear layer, manually shard the weights.
             if "c_attn" in name:
                 # GPT-2's fused QKV has the shape of

@@ -300,30 +273,27 @@ class GPTBigCodeForCausalLM(nn.Module):
                 # When tensor parallelism is used, we shard the weights along
                 # the head dimension.
                 total_num_heads = self.config.num_attention_heads
+                total_num_kv_heads = (1 if self.config.multi_query else
+                                      total_num_heads)
                 hidden_size = self.config.hidden_size
                 head_size = hidden_size // total_num_heads
+                total_kv_size = head_size * total_num_kv_heads
                 num_heads = total_num_heads // tensor_model_parallel_world_size
                 head_start = tensor_model_parallel_rank * num_heads
                 head_end = (tensor_model_parallel_rank + 1) * num_heads
-                if name.endswith(".weight"):
-                    loaded_weight = _expand_mqa_mha(loaded_weight,
-                                                    n_head=total_num_heads,
-                                                    head_dim=head_size)
-                    loaded_weight = loaded_weight.view(3, total_num_heads,
-                                                       head_size, hidden_size)
-                    loaded_weight = loaded_weight[:, head_start:head_end, :, :]
-                    loaded_weight = loaded_weight.reshape(-1, hidden_size)
-                elif name.endswith(".bias"):
-                    loaded_weight = _expand_mqa_mha(loaded_weight,
-                                                    n_head=total_num_heads,
-                                                    head_dim=head_size)
-                    loaded_weight = loaded_weight.view(3, total_num_heads,
-                                                       head_size)
-                    loaded_weight = loaded_weight[:, head_start:head_end, :]
-                    loaded_weight = loaded_weight.reshape(-1)
-                else:
-                    raise ValueError(f"Unexpected parameter name {name}")
+
+                wq, wk, wv = torch.split(
+                    loaded_weight, [hidden_size, total_kv_size, total_kv_size],
+                    dim=0)
+
+                wq = wq[head_size * head_start:head_size * head_end]
+                if not self.config.multi_query:
+                    # Split the heads when using normal multi-head attention
+                    wk = wk[head_size * head_start:head_size * head_end]
+                    wv = wv[head_size * head_start:head_size * head_end]
+                # Else, keep the weights as is for multi-query attention
+
+                loaded_weight = torch.cat([wq, wk, wv], dim=0)

             load_tensor_parallel_weights(param, loaded_weight, name,
                                          self._column_parallel_weights,
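The new weight-loading path above keeps the checkpoint's MQA layout instead of expanding it to MHA: the fused c_attn tensor is split into q/k/v slices, only the query slice (and, for non-MQA models, the k/v slices) is sharded by head range, and the pieces are re-concatenated. A standalone sketch of that slicing with toy shapes (tensor-parallel rank handling simplified, numbers made up):

import torch

# Toy dimensions, not from a real checkpoint.
hidden_size, total_num_heads, tp_size, tp_rank = 64, 8, 2, 0
head_size = hidden_size // total_num_heads
multi_query = True
total_num_kv_heads = 1 if multi_query else total_num_heads
total_kv_size = head_size * total_num_kv_heads

# Checkpoint c_attn weight: [hidden_size + 2 * total_kv_size, hidden_size]
loaded_weight = torch.randn(hidden_size + 2 * total_kv_size, hidden_size)

num_heads = total_num_heads // tp_size
head_start, head_end = tp_rank * num_heads, (tp_rank + 1) * num_heads

wq, wk, wv = torch.split(loaded_weight,
                         [hidden_size, total_kv_size, total_kv_size], dim=0)
wq = wq[head_size * head_start:head_size * head_end]
if not multi_query:
    wk = wk[head_size * head_start:head_size * head_end]
    wv = wv[head_size * head_start:head_size * head_end]
sharded = torch.cat([wq, wk, wv], dim=0)

# Per-rank shard: hidden_size/tp_size query rows plus the full (shared) K and V rows.
print(sharded.shape)  # torch.Size([48, 64])  -> 32 + 8 + 8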