norm / vllm · Commits · 96853af5
"mmdet3d/vscode:/vscode.git/clone" did not exist on "a800db23b8a7f69a95f972dc94d6f3ced29631f0"
Unverified commit 96853af5
Authored Jul 14, 2023 by Zhuohan Li; committed by GitHub on Jul 14, 2023

Optimize MQA Kernel (#452)

Parent: dbed6905
Changes: 5 changed files with 84 additions and 72 deletions (+84, -72)
csrc/attention.cpp                          +1   -0
csrc/attention/attention_kernels.cu         +22  -9
vllm/config.py                              +7   -0
vllm/model_executor/layers/attention.py     +32  -11
vllm/model_executor/models/gpt_bigcode.py   +22  -52
csrc/attention.cpp

@@ -6,6 +6,7 @@ void single_query_cached_kv_attention(
   torch::Tensor& query,
   torch::Tensor& key_cache,
   torch::Tensor& value_cache,
+  torch::Tensor& head_mapping,
   float scale,
   torch::Tensor& block_tables,
   torch::Tensor& context_lens,
csrc/attention/attention_kernels.cu

@@ -74,14 +74,17 @@ template<
 __global__ void single_query_cached_kv_attention_kernel(
   scalar_t* __restrict__ out,             // [num_seqs, num_heads, head_size]
   const scalar_t* __restrict__ q,         // [num_seqs, num_heads, head_size]
-  const scalar_t* __restrict__ k_cache,   // [num_blocks, num_heads, head_size/x, block_size, x]
-  const scalar_t* __restrict__ v_cache,   // [num_blocks, num_heads, head_size, block_size]
+  const scalar_t* __restrict__ k_cache,   // [num_blocks, num_kv_heads, head_size/x, block_size, x]
+  const scalar_t* __restrict__ v_cache,   // [num_blocks, num_kv_heads, head_size, block_size]
+  const int* __restrict__ head_mapping,   // [num_heads]
   const float scale,
   const int* __restrict__ block_tables,   // [num_seqs, max_num_blocks_per_seq]
   const int* __restrict__ context_lens,   // [num_seqs]
   const int max_num_blocks_per_seq,
   const float* __restrict__ alibi_slopes, // [num_heads]
-  const int q_stride) {
+  const int q_stride,
+  const int kv_block_stride,
+  const int kv_head_stride) {
   constexpr int THREAD_GROUP_SIZE = MAX(WARP_SIZE / BLOCK_SIZE, 1);
   constexpr int NUM_TOKENS_PER_THREAD_GROUP = (BLOCK_SIZE + WARP_SIZE - 1) / WARP_SIZE;
   constexpr int NUM_WARPS = NUM_THREADS / WARP_SIZE;

@@ -91,6 +94,7 @@ __global__ void single_query_cached_kv_attention_kernel(
   const int head_idx = blockIdx.x;
   const int num_heads = gridDim.x;
+  const int kv_head_idx = head_mapping[head_idx];
   const int seq_idx = blockIdx.y;
   const float alibi_slope = alibi_slopes == nullptr ? 0.f : alibi_slopes[head_idx];

@@ -158,8 +162,8 @@ __global__ void single_query_cached_kv_attention_kernel(
 #pragma unroll
     for (int j = 0; j < NUM_VECS_PER_THREAD; j++) {
-      const scalar_t* k_ptr = k_cache + physical_block_number * num_heads * HEAD_SIZE * BLOCK_SIZE
-                                      + head_idx * HEAD_SIZE * BLOCK_SIZE
+      const scalar_t* k_ptr = k_cache + physical_block_number * kv_block_stride
+                                      + kv_head_idx * kv_head_stride
                                       + physical_block_offset * x;
       const int vec_idx = thread_group_offset + j * THREAD_GROUP_SIZE;
       const int offset1 = (vec_idx * VEC_SIZE) / x;

@@ -246,8 +250,8 @@ __global__ void single_query_cached_kv_attention_kernel(
     L_vec logits_vec;
     from_float(logits_vec, *reinterpret_cast<Float_L_vec*>(logits + token_idx));

-    const scalar_t* v_ptr = v_cache + physical_block_number * num_heads * HEAD_SIZE * BLOCK_SIZE
-                                    + head_idx * HEAD_SIZE * BLOCK_SIZE;
+    const scalar_t* v_ptr = v_cache + physical_block_number * kv_block_stride
+                                    + kv_head_idx * kv_head_stride;
 #pragma unroll
     for (int i = 0; i < NUM_ROWS_PER_THREAD; i++) {
       const int row_idx = lane / NUM_V_VECS_PER_ROW + i * NUM_ROWS_PER_ITER;

@@ -328,12 +332,15 @@ __global__ void single_query_cached_kv_attention_kernel(
     query_ptr, \
     key_cache_ptr, \
     value_cache_ptr, \
+    head_mapping_ptr, \
     scale, \
     block_tables_ptr, \
     context_lens_ptr, \
     max_num_blocks_per_seq, \
     alibi_slopes_ptr, \
-    query_stride);
+    q_stride, \
+    kv_block_stride, \
+    kv_head_stride);

 // TODO(woosuk): Tune NUM_THREADS.
 template<

@@ -345,6 +352,7 @@ void single_query_cached_kv_attention_launcher(
   torch::Tensor& query,
   torch::Tensor& key_cache,
   torch::Tensor& value_cache,
+  torch::Tensor& head_mapping,
   float scale,
   torch::Tensor& block_tables,
   torch::Tensor& context_lens,

@@ -354,7 +362,9 @@ void single_query_cached_kv_attention_launcher(
   int num_heads = query.size(1);
   int head_size = query.size(2);
   int max_num_blocks_per_seq = block_tables.size(1);
-  int query_stride = query.stride(0);
+  int q_stride = query.stride(0);
+  int kv_block_stride = key_cache.stride(0);
+  int kv_head_stride = key_cache.stride(1);

   int thread_group_size = MAX(WARP_SIZE / BLOCK_SIZE, 1);
   assert(head_size % thread_group_size == 0);

@@ -368,6 +378,7 @@ void single_query_cached_kv_attention_launcher(
   T* query_ptr = reinterpret_cast<T*>(query.data_ptr());
   T* key_cache_ptr = reinterpret_cast<T*>(key_cache.data_ptr());
   T* value_cache_ptr = reinterpret_cast<T*>(value_cache.data_ptr());
+  int* head_mapping_ptr = reinterpret_cast<int*>(head_mapping.data_ptr());
   int* block_tables_ptr = block_tables.data_ptr<int>();
   int* context_lens_ptr = context_lens.data_ptr<int>();

@@ -422,6 +433,7 @@ void single_query_cached_kv_attention_launcher(
     query, \
     key_cache, \
     value_cache, \
+    head_mapping, \
     scale, \
     block_tables, \
     context_lens, \

@@ -469,6 +481,7 @@ void single_query_cached_kv_attention(
   torch::Tensor& query,           // [num_seqs, num_heads, head_size]
   torch::Tensor& key_cache,       // [num_blocks, num_heads, head_size/x, block_size, x]
   torch::Tensor& value_cache,     // [num_blocks, num_heads, head_size, block_size]
+  torch::Tensor& head_mapping,    // [num_heads]
   float scale,
   torch::Tensor& block_tables,    // [num_seqs, max_num_blocks_per_seq]
   torch::Tensor& context_lens,    // [num_seqs]
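Not part of the commit, but a rough PyTorch sketch of what the new indexing does: with the KV cache laid out as [num_blocks, num_kv_heads, head_size/x, block_size, x], kv_block_stride = key_cache.stride(0) and kv_head_stride = key_cache.stride(1) let the kernel jump straight to the block of the KV head that a query head maps to via head_mapping. All sizes and variable names below are made up for illustration.

import torch

# Hypothetical sizes, for illustration only.
num_blocks, num_kv_heads, head_size, block_size, x = 16, 2, 64, 8, 8
num_heads = 8  # 8 query heads sharing 2 KV heads (4 query heads per KV head)

key_cache = torch.randn(num_blocks, num_kv_heads, head_size // x, block_size, x)
head_mapping = torch.repeat_interleave(
    torch.arange(num_kv_heads, dtype=torch.int32), num_heads // num_kv_heads)

kv_block_stride = key_cache.stride(0)  # elements per physical block
kv_head_stride = key_cache.stride(1)   # elements per KV head within a block

head_idx = 5                               # query head handled by one CUDA block
kv_head_idx = int(head_mapping[head_idx])  # KV head this query head reads from
physical_block_number = 3

# Offset arithmetic analogous to what the kernel performs on the raw pointer:
flat = key_cache.flatten()
offset = physical_block_number * kv_block_stride + kv_head_idx * kv_head_stride
block_for_head = flat[offset:offset + kv_head_stride].view(
    head_size // x, block_size, x)

# Same data via ordinary tensor indexing:
assert torch.equal(block_for_head, key_cache[physical_block_number, kv_head_idx])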
vllm/config.py

@@ -94,6 +94,13 @@ class ModelConfig:
         return self.hf_config.hidden_size // self.hf_config.num_attention_heads
 
     def get_num_heads(self, parallel_config: "ParallelConfig") -> int:
+        # For GPTBigCode:
+        if getattr(self.hf_config, "multi_query", False):
+            # Multi-query attention, only one KV head.
+            return 1
+        # For Falcon:
+        if getattr(self.hf_config, "n_head_kv", None) is not None:
+            return self.hf_config.n_head_kv
         total_num_attention_heads = self.hf_config.num_attention_heads
         return total_num_attention_heads // parallel_config.tensor_parallel_size
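As a rough illustration of the branching above, using a stand-in namespace object instead of a real ModelConfig/ParallelConfig (the helper name, config objects, and head counts here are hypothetical, not from the commit):

from types import SimpleNamespace

def get_num_kv_heads(hf_config, tensor_parallel_size: int) -> int:
    # Mirrors ModelConfig.get_num_heads above: MQA models expose a single KV head.
    if getattr(hf_config, "multi_query", False):
        return 1                    # GPTBigCode-style multi-query attention
    if getattr(hf_config, "n_head_kv", None) is not None:
        return hf_config.n_head_kv  # Falcon-style grouped KV heads
    return hf_config.num_attention_heads // tensor_parallel_size

bigcode_like = SimpleNamespace(num_attention_heads=48, multi_query=True)
falcon_like = SimpleNamespace(num_attention_heads=64, multi_query=False, n_head_kv=8)
mha_like = SimpleNamespace(num_attention_heads=32, multi_query=False)

print(get_num_kv_heads(bigcode_like, tensor_parallel_size=2))  # 1
print(get_num_kv_heads(falcon_like, tensor_parallel_size=1))   # 8
print(get_num_kv_heads(mha_like, tensor_parallel_size=2))      # 16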
vllm/model_executor/layers/attention.py

@@ -44,12 +44,23 @@ class PagedAttention(nn.Module):
     5. Output a flattened 1D tensor.
     """
 
-    def __init__(self, num_heads: int, head_size: int, scale: float) -> None:
+    def __init__(self,
+                 num_heads: int,
+                 head_size: int,
+                 scale: float,
+                 num_kv_heads: Optional[int] = None) -> None:
         super().__init__()
         self.num_heads = num_heads
         self.head_size = head_size
         self.scale = float(scale)
         self.attn_op = xops.fmha.cutlass.FwOp()
+        self.num_kv_heads = num_heads if num_kv_heads is None else num_kv_heads
+
+        assert self.num_heads % self.num_kv_heads == 0
+        self.num_queries_per_kv = self.num_heads // self.num_kv_heads
+        self.head_mapping = torch.repeat_interleave(
+            torch.arange(self.num_kv_heads, dtype=torch.int32, device="cuda"),
+            self.num_queries_per_kv)
 
         if self.head_size not in _SUPPORTED_HEAD_SIZES:
             raise ValueError(f"head_size ({self.head_size}) is not supported. "
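To make head_mapping concrete (illustrative sizes, not from the commit): with 8 query heads and 2 KV heads, repeat_interleave assigns 4 consecutive query heads to each KV head.

import torch

num_heads, num_kv_heads = 8, 2  # hypothetical sizes
num_queries_per_kv = num_heads // num_kv_heads

head_mapping = torch.repeat_interleave(
    torch.arange(num_kv_heads, dtype=torch.int32), num_queries_per_kv)
print(head_mapping)  # tensor([0, 0, 0, 0, 1, 1, 1, 1], dtype=torch.int32)
# Query head i reads KV head head_mapping[i] from the cache; pure MQA is the
# special case num_kv_heads == 1, where every entry is 0.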
@@ -76,10 +87,18 @@ class PagedAttention(nn.Module):
         Args:
             output: shape = [num_prompt_tokens, num_heads, head_size]
             query: shape = [num_prompt_tokens, num_heads, head_size]
-            key: shape = [num_prompt_tokens, num_heads, head_size]
-            value: shape = [num_prompt_tokens, num_heads, head_size]
+            key: shape = [num_prompt_tokens, num_kv_heads, head_size]
+            value: shape = [num_prompt_tokens, num_kv_heads, head_size]
             input_metadata: metadata for paged attention.
         """
+        if self.num_kv_heads != self.num_heads:
+            # Project the key and value tensors to the desired number of heads.
+            key = torch.repeat_interleave(key, self.num_queries_per_kv, dim=1)
+            value = torch.repeat_interleave(value,
+                                            self.num_queries_per_kv,
+                                            dim=1)
+
         # TODO(woosuk): The unsqueeze op may incur some CPU overhead. Optimize.
         out = xops.memory_efficient_attention_forward(
             query.unsqueeze(0),
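A shape-level sketch of that prompt-phase projection (random tensors and illustrative sizes, only for orientation): when num_kv_heads != num_heads, key and value are expanded to num_heads before the xformers call.

import torch

num_prompt_tokens, num_heads, num_kv_heads, head_size = 5, 8, 2, 64
num_queries_per_kv = num_heads // num_kv_heads

key = torch.randn(num_prompt_tokens, num_kv_heads, head_size)
value = torch.randn(num_prompt_tokens, num_kv_heads, head_size)

# Duplicate each KV head for the query heads that share it.
key = torch.repeat_interleave(key, num_queries_per_kv, dim=1)
value = torch.repeat_interleave(value, num_queries_per_kv, dim=1)

print(key.shape, value.shape)  # torch.Size([5, 8, 64]) torch.Size([5, 8, 64])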
@@ -107,9 +126,9 @@ class PagedAttention(nn.Module):
         Args:
             output: shape = [num_generation_tokens, num_heads, head_size]
             query: shape = [num_generation_tokens, num_heads, head_size]
-            key_cache: shape = [num_blocks, num_heads, head_size/x,
+            key_cache: shape = [num_blocks, num_kv_heads, head_size/x,
                 block_size, x]
-            value_cache: shape = [num_blocks, num_heads, head_size, block_size]
+            value_cache: shape = [num_blocks, num_kv_heads, head_size, block_size]
             input_metadata: metadata for paged attention.
         """
         block_size = value_cache.shape[3]

@@ -118,6 +137,7 @@ class PagedAttention(nn.Module):
             query,
             key_cache,
             value_cache,
+            self.head_mapping,
             self.scale,
             input_metadata.block_tables,
             input_metadata.context_lens,

@@ -143,11 +163,12 @@ class PagedAttention(nn.Module):
         Args:
             query: shape = [num_tokens, num_heads * head_size]
-            key: shape = [num_tokens, num_heads * head_size]
-            value: shape = [num_tokens, num_heads * head_size]
-            key_cache: shape = [num_blocks, num_heads, head_size/x,
+            key: shape = [num_tokens, num_kv_heads * head_size]
+            value: shape = [num_tokens, num_kv_heads * head_size]
+            key_cache: shape = [num_blocks, num_kv_heads, head_size/x,
                 block_size, x]
-            value_cache: shape = [num_blocks, num_heads, head_size, block_size]
+            value_cache: shape = [num_blocks, num_kv_heads, head_size,
+                block_size]
             input_metadata: metadata for paged attention.
             cache_event: event to wait for the cache operations to finish.

@@ -157,8 +178,8 @@ class PagedAttention(nn.Module):
         # Reshape the query, key, and value tensors.
         query = query.view(-1, self.num_heads, self.head_size)
-        key = key.view(-1, self.num_heads, self.head_size)
-        value = value.view(-1, self.num_heads, self.head_size)
+        key = key.view(-1, self.num_kv_heads, self.head_size)
+        value = value.view(-1, self.num_kv_heads, self.head_size)
 
         # Pre-allocate the output tensor.
         output = torch.empty_like(query)
vllm/model_executor/models/gpt_bigcode.py

@@ -26,7 +26,6 @@ from typing import Dict, List, Optional, Tuple
 import torch
 from torch import nn
-import numpy as np
 from transformers import GPTBigCodeConfig
 
 from vllm.model_executor.input_metadata import InputMetadata

@@ -55,10 +54,12 @@ class GPTBigCodeAttention(nn.Module):
         assert total_num_heads % tensor_model_parallel_world_size == 0
         self.num_heads = total_num_heads // tensor_model_parallel_world_size
         self.head_dim = self.hidden_size // total_num_heads
+        self.num_kv_heads = 1 if config.multi_query else self.num_heads
+        self.kv_dim = self.num_kv_heads * self.head_dim
         self.scale = self.head_dim**-0.5
 
         self.c_attn = ColumnParallelLinear(self.hidden_size,
-                                           3 * self.hidden_size,
+                                           self.hidden_size + 2 * self.kv_dim,
                                            bias=True,
                                            gather_output=False,
                                            perform_initialization=False)

@@ -69,7 +70,8 @@ class GPTBigCodeAttention(nn.Module):
                                          perform_initialization=False)
         self.attn = PagedAttention(self.num_heads, self.head_dim,
-                                   scale=self.scale)
+                                   scale=self.scale,
+                                   num_kv_heads=self.num_kv_heads)
 
     def forward(
         self,

@@ -79,7 +81,8 @@ class GPTBigCodeAttention(nn.Module):
         cache_event: Optional[torch.cuda.Event],
     ) -> torch.Tensor:
         qkv, _ = self.c_attn(hidden_states)
-        q, k, v = qkv.chunk(chunks=3, dim=-1)
+        q, k, v = qkv.split([self.hidden_size, self.kv_dim, self.kv_dim],
+                            dim=-1)
         key_cache, value_cache = kv_cache
         attn_output = self.attn(q, k, v, key_cache, value_cache,
                                 input_metadata, cache_event)
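To see why the fused projection output is now split as [hidden_size, kv_dim, kv_dim] rather than chunked into three equal parts (illustrative sizes only): for a multi-query model, K and V each occupy just one head_dim worth of columns.

import torch

hidden_size, num_heads = 256, 8
head_dim = hidden_size // num_heads  # 32
num_kv_heads = 1                     # multi_query=True
kv_dim = num_kv_heads * head_dim     # 32

# Fused QKV output of c_attn: hidden_size + 2 * kv_dim columns, not 3 * hidden_size.
qkv = torch.randn(4, hidden_size + 2 * kv_dim)  # 4 tokens, 320 columns

q, k, v = qkv.split([hidden_size, kv_dim, kv_dim], dim=-1)
print(q.shape, k.shape, v.shape)
# torch.Size([4, 256]) torch.Size([4, 32]) torch.Size([4, 32])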
@@ -263,36 +266,6 @@ class GPTBigCodeForCausalLM(nn.Module):
                 extra_rows = extra_rows.to(loaded_weight)
                 loaded_weight = torch.cat([loaded_weight, extra_rows], dim=0)
 
-            def _expand_mqa_mha(qkv_array, n_head, head_dim):
-                """manipulates along axis=0 from MQA to MHA
-                inputs: qkv_array.shape=((n_heads + 2) * head_dim, hidden_dim)
-                with n_heads for q, then 1 for k, 1 for 1 v, times head dim
-                return: qkv_array.shape=(3 * n_heads * head_dim, hidden_dim)
-
-                TODO: this function is no longer needed once vllm supports MQA.
-                """
-                qkv_array = qkv_array.numpy()
-
-                dims_q = n_head * head_dim
-                # pylint: disable=unbalanced-tuple-unpacking
-                q, k, v = np.split(qkv_array, (dims_q, dims_q + head_dim),
-                                   axis=0)
-                # q is fine, but k & v have not replicated shape along the first
-                # axis as long as MQA is not nativly supported, increase memory
-                # and replicated (head_dim, hidden_dim) to
-                # (n_heads * head_dim, hidden_dim)
-                if k.ndim == 2 and v.ndim == 2:
-                    replication = (n_head, 1)  # weights
-                else:
-                    replication = n_head  # biases
-                # replicate n_head times for q, v
-                k, v = np.tile(k, replication), np.tile(v, replication)
-                # concat q, k, v along the first axis
-                # (n_heads * head_dim, hidden_dim)
-                # to (3 * n_heads * head_dim, hidden_dim)
-                qkv_array = np.concatenate((q, k, v), axis=0)
-                return torch.from_numpy(qkv_array)
-
             # For the fused QKV linear layer, manually shard the weights.
             if "c_attn" in name:
                 # GPT-2's fused QKV has the shape of

@@ -300,30 +273,27 @@ class GPTBigCodeForCausalLM(nn.Module):
                 # When tensor parallelism is used, we shard the weights along
                 # the head dimension.
                 total_num_heads = self.config.num_attention_heads
+                total_num_kv_heads = (1 if self.config.multi_query else
+                                      total_num_heads)
                 hidden_size = self.config.hidden_size
                 head_size = hidden_size // total_num_heads
+                total_kv_size = head_size * total_num_kv_heads
                 num_heads = total_num_heads // tensor_model_parallel_world_size
                 head_start = tensor_model_parallel_rank * num_heads
                 head_end = (tensor_model_parallel_rank + 1) * num_heads
-                if name.endswith(".weight"):
-                    loaded_weight = _expand_mqa_mha(loaded_weight,
-                                                    n_head=total_num_heads,
-                                                    head_dim=head_size)
-                    loaded_weight = loaded_weight.view(3, total_num_heads,
-                                                       head_size, hidden_size)
-                    loaded_weight = loaded_weight[:, head_start:head_end, :, :]
-                    loaded_weight = loaded_weight.reshape(-1, hidden_size)
-                elif name.endswith(".bias"):
-                    loaded_weight = _expand_mqa_mha(loaded_weight,
-                                                    n_head=total_num_heads,
-                                                    head_dim=head_size)
-                    loaded_weight = loaded_weight.view(3, total_num_heads,
-                                                       head_size)
-                    loaded_weight = loaded_weight[:, head_start:head_end, :]
-                    loaded_weight = loaded_weight.reshape(-1)
-                else:
-                    raise ValueError(f"Unexpected parameter name {name}")
+
+                wq, wk, wv = torch.split(
+                    loaded_weight, [hidden_size, total_kv_size, total_kv_size],
+                    dim=0)
+
+                wq = wq[head_size * head_start:head_size * head_end]
+                if not self.config.multi_query:
+                    # Split the heads when using normal multi-head attention
+                    wk = wk[head_size * head_start:head_size * head_end]
+                    wv = wv[head_size * head_start:head_size * head_end]
+                # Else, keep the weights as is for multi-query attention
+                loaded_weight = torch.cat([wq, wk, wv], dim=0)
+
             load_tensor_parallel_weights(param, loaded_weight, name,
                                          self._column_parallel_weights,
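A rough sketch of the reworked c_attn weight sharding (hypothetical sizes, rank 0 of a 2-way tensor-parallel setup): the Q rows are sharded per rank, while for a multi-query checkpoint the single shared K/V head is kept whole on every rank instead of being expanded and then sharded.

import torch

hidden_size, total_num_heads = 256, 8
head_size = hidden_size // total_num_heads      # 32
multi_query = True
total_num_kv_heads = 1 if multi_query else total_num_heads
total_kv_size = head_size * total_num_kv_heads  # 32

tp_world_size, tp_rank = 2, 0
num_heads = total_num_heads // tp_world_size
head_start, head_end = tp_rank * num_heads, (tp_rank + 1) * num_heads

# Checkpoint weight of c_attn: [hidden_size + 2 * total_kv_size, hidden_size].
loaded_weight = torch.randn(hidden_size + 2 * total_kv_size, hidden_size)

wq, wk, wv = torch.split(
    loaded_weight, [hidden_size, total_kv_size, total_kv_size], dim=0)
wq = wq[head_size * head_start:head_size * head_end]  # shard Q rows only
if not multi_query:
    wk = wk[head_size * head_start:head_size * head_end]
    wv = wv[head_size * head_start:head_size * head_end]
sharded = torch.cat([wq, wk, wv], dim=0)
print(sharded.shape)  # torch.Size([192, 256]): 128 Q rows + 32 K rows + 32 V rows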