gaoqiong / flash-attention · Commit a157cc8c
Authored Jul 22, 2023 by Tri Dao

[FT] Implement MQA/GQA

Parent: 75e334d4

Showing 3 changed files with 50 additions and 30 deletions:

  csrc/ft_attention/decoder_masked_multihead_attention.h             +5   -1
  csrc/ft_attention/decoder_masked_multihead_attention_template.hpp  +27  -19
  csrc/ft_attention/ft_attention.cpp                                  +18  -10
csrc/ft_attention/decoder_masked_multihead_attention.h

@@ -69,7 +69,9 @@ struct Multihead_attention_params_base {
     const int* cache_indir = nullptr;
     // Stride to handle the case when KQV is a single buffer
-    int stride = 0;
+    int stride_q = 0;
+    int stride_k = 0;
+    int stride_v = 0;
     // The batch size.
     int batch_size = 0;
@@ -79,6 +81,8 @@ struct Multihead_attention_params_base {
     int memory_max_len = 0;
     // The number of heads (H).
     int num_heads = 0;
+    int num_heads_kv = 0;
+    int num_heads_q_kv_ratio = 0;
     // The hidden dimension per head (Dh).
     int hidden_size_per_head = 0;
     // The per-head latent space reserved for rotary embeddings.
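Note (not part of the commit): the new fields separate the per-batch strides of Q, K, and V and record how many query heads share each key/value head. Below is a minimal, self-contained sketch of how the fields relate, assuming a hypothetical GQA shape of 32 query heads, 4 KV heads, and head dimension 128; the ParamsSketch stand-in struct and all concrete numbers are illustrative, not taken from the commit.

    #include <cstdio>

    // Stand-in for Multihead_attention_params_base, keeping only the fields this
    // commit touches; the real struct lives in the header above.
    struct ParamsSketch {
        int stride_q = 0, stride_k = 0, stride_v = 0;
        int num_heads = 0, num_heads_kv = 0, num_heads_q_kv_ratio = 0;
        int hidden_size_per_head = 0;
    };

    int main() {
        ParamsSketch p;
        p.num_heads            = 32;   // H: query heads (assumed)
        p.num_heads_kv         = 4;    // H_kv: shared key/value heads (assumed)
        p.num_heads_q_kv_ratio = p.num_heads / p.num_heads_kv;  // 8 query heads per KV head
        p.hidden_size_per_head = 128;  // Dh (assumed)
        // With separate Q/K/V buffers the per-batch strides can now differ:
        p.stride_q = p.num_heads    * p.hidden_size_per_head;   // 4096 elements per batch entry
        p.stride_k = p.num_heads_kv * p.hidden_size_per_head;   //  512
        p.stride_v = p.num_heads_kv * p.hidden_size_per_head;   //  512
        std::printf("ratio=%d stride_q=%d stride_k=%d stride_v=%d\n",
                    p.num_heads_q_kv_ratio, p.stride_q, p.stride_k, p.stride_v);
        return 0;
    }

For MQA, num_heads_kv is simply 1, so every query head maps to the single KV head and the ratio equals num_heads.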
csrc/ft_attention/decoder_masked_multihead_attention_template.hpp

@@ -943,10 +943,12 @@ __global__ void masked_multihead_attention_kernel(Multihead_attention_params<T,
     // The head.
     // const int hi = blockIdx.x;
     const int hi = params.nnz_head_idx == nullptr ? blockIdx.x : params.nnz_head_idx[blockIdx.x];
+    const int hi_kv = hi / params.num_heads_q_kv_ratio;
     // Combine the batch and the head indices.
     const int bhi = bi * params.num_heads + hi;
+    const int bhi_kv = bi * params.num_heads_kv + hi_kv;
     // Combine the "beam-aware" batch idx and the head indices.
-    const int bbhi = bbi * params.beam_width * params.num_heads + hi;
+    const int bbhi = bbi * params.beam_width * params.num_heads_kv + hi_kv;
     // The thread in the block.
     const int tidx = threadIdx.x;
@@ -957,7 +959,9 @@ __global__ void masked_multihead_attention_kernel(Multihead_attention_params<T,
     float qk = 0.0F;
-    int qkv_base_offset = (params.stride == 0) ? bhi * Dh : bi * params.stride + hi * Dh;
+    int q_base_offset = (params.stride_q == 0) ? bhi * Dh : bi * params.stride_q + hi * Dh;
+    int k_base_offset = (params.stride_k == 0) ? bhi_kv * Dh : bi * params.stride_k + hi_kv * Dh;
+    int v_base_offset = (params.stride_v == 0) ? bhi_kv * Dh : bi * params.stride_v + hi_kv * Dh;
     const size_t bi_seq_len_offset = bi * params.memory_max_len;
@@ -973,9 +977,11 @@ __global__ void masked_multihead_attention_kernel(Multihead_attention_params<T,
     const bool is_masked = tidx >= QK_VECS_PER_WARP;
     // The offset in the Q and K buffer also accounts for the batch.
-    int qk_offset = qkv_base_offset + tidx * QK_VEC_SIZE;
+    int q_offset = q_base_offset + tidx * QK_VEC_SIZE;
+    int k_offset = k_base_offset + tidx * QK_VEC_SIZE;
     // The offset in the bias buffer.
-    int qk_bias_offset = hi * Dh + tidx * QK_VEC_SIZE;
+    int q_bias_offset = hi * Dh + tidx * QK_VEC_SIZE;
+    int k_bias_offset = hi_kv * Dh + tidx * QK_VEC_SIZE;
     const bool do_ia3 = handle_kv && params.ia3_tasks != nullptr;
     const int ia3_task_id = do_ia3 ? params.ia3_tasks[bbi] : 0;
@@ -989,12 +995,12 @@ __global__ void masked_multihead_attention_kernel(Multihead_attention_params<T,
             using Packed_Float_t = typename packed_type<float, num_elems<Qk_vec>::value>::type;
             const auto q_scaling = params.qkv_scale_out[0];
             const auto q_quant =
-                *reinterpret_cast<const Packed_Int8_t*>(&reinterpret_cast<const int8_t*>(params.q)[qk_offset]);
+                *reinterpret_cast<const Packed_Int8_t*>(&reinterpret_cast<const int8_t*>(params.q)[q_offset]);
             convert_from_float(q, mul<Packed_Float_t, float>(q_scaling, float_from_int8(q_quant)));
         }
         else {
-            q = *reinterpret_cast<const Qk_vec*>(&params.q[qk_offset]);
+            q = *reinterpret_cast<const Qk_vec*>(&params.q[q_offset]);
         }
     }
@@ -1007,7 +1013,7 @@ __global__ void masked_multihead_attention_kernel(Multihead_attention_params<T,
         int ci = tidx % QK_VECS_IN_16B * QK_VEC_SIZE;
         // Two chunks are separated by L * x elements. A thread write QK_VEC_SIZE elements.
-        int offset = bhi * params.memory_max_len * Dh + co * params.memory_max_len * QK_ELTS_IN_16B +
+        int offset = bhi_kv * params.memory_max_len * Dh + co * params.memory_max_len * QK_ELTS_IN_16B +
                      // params.timestep*QK_ELTS_IN_16B +
                      tlength * QK_ELTS_IN_16B + ci;
         k = !is_masked && (Dh == Dh_MAX || tidx * QK_VEC_SIZE < Dh) ?
@@ -1021,12 +1027,12 @@ __global__ void masked_multihead_attention_kernel(Multihead_attention_params<T,
             using Packed_Float_t = typename packed_type<float, num_elems<Qk_vec>::value>::type;
             const auto k_scaling = params.qkv_scale_out[1];
             const auto k_quant =
-                *reinterpret_cast<const Packed_Int8_t*>(&reinterpret_cast<const int8_t*>(params.k)[qk_offset]);
+                *reinterpret_cast<const Packed_Int8_t*>(&reinterpret_cast<const int8_t*>(params.k)[k_offset]);
             convert_from_float(k, mul<Packed_Float_t, float>(k_scaling, float_from_int8(k_quant)));
         }
         else {
-            k = *reinterpret_cast<const Qk_vec*>(&params.k[qk_offset]);
+            k = *reinterpret_cast<const Qk_vec*>(&params.k[k_offset]);
         }
     }
 }
@@ -1035,14 +1041,14 @@ __global__ void masked_multihead_attention_kernel(Multihead_attention_params<T,
     Qk_vec q_bias;
     zero(q_bias);
     q_bias = (!is_masked && Dh == Dh_MAX || tidx * QK_VEC_SIZE < Dh) && params.q_bias != nullptr ?
-                 *reinterpret_cast<const Qk_vec*>(&params.q_bias[qk_bias_offset]) :
+                 *reinterpret_cast<const Qk_vec*>(&params.q_bias[q_bias_offset]) :
                  q_bias;
     Qk_vec k_bias;
     zero(k_bias);
     if (handle_kv) {
         k_bias = !is_masked && (Dh == Dh_MAX || tidx * QK_VEC_SIZE < Dh) && params.k_bias != nullptr ?
-                     *reinterpret_cast<const Qk_vec*>(&params.k_bias[qk_bias_offset]) :
+                     *reinterpret_cast<const Qk_vec*>(&params.k_bias[k_bias_offset]) :
                      k_bias;
     }
@@ -1172,11 +1178,11 @@ __global__ void masked_multihead_attention_kernel(Multihead_attention_params<T,
         int ci = tidx % QK_VECS_IN_16B * QK_VEC_SIZE;
         // Two chunks are separated by L * x elements. A thread write QK_VEC_SIZE elements.
-        int offset = bhi * params.memory_max_len * Dh + co * params.memory_max_len * QK_ELTS_IN_16B +
+        int offset = bhi_kv * params.memory_max_len * Dh + co * params.memory_max_len * QK_ELTS_IN_16B +
                      // params.timestep*QK_ELTS_IN_16B +
                      tlength_circ * QK_ELTS_IN_16B + ci;
-        if (handle_kv) {
+        if (handle_kv && hi % params.num_heads_q_kv_ratio == 0) {
             // Trigger the stores to global memory.
             if (Dh == Dh_MAX || co < Dh / QK_ELTS_IN_16B) {
                 *reinterpret_cast<Qk_vec*>(&params.k_cache[offset]) = k;
@@ -1263,7 +1269,7 @@ __global__ void masked_multihead_attention_kernel(Multihead_attention_params<T,
     constexpr int K_PER_WARP = WARP_SIZE / THREADS_PER_KEY;
     // The base pointer for the key in the cache buffer.
-    T* k_cache = &params.k_cache[bhi * params.memory_max_len * Dh + ki];
+    T* k_cache = &params.k_cache[bhi_kv * params.memory_max_len * Dh + ki];
     // Base pointer for the beam's batch, before offsetting with indirection buffer
     T* k_cache_batch = &params.k_cache[bbhi * params.memory_max_len * Dh + ki];
@@ -1427,7 +1433,7 @@ __global__ void masked_multihead_attention_kernel(Multihead_attention_params<T,
     int vi = tidx % THREADS_PER_VALUE * V_VEC_SIZE;
     // The base pointer for the value in the cache buffer.
-    T* v_cache = &params.v_cache[bhi * params.memory_max_len * Dh + vi];
+    T* v_cache = &params.v_cache[bhi_kv * params.memory_max_len * Dh + vi];
     // Base pointer for the beam's batch, before offsetting with indirection buffer
     T* v_cache_batch = &params.v_cache[bbhi * params.memory_max_len * Dh + vi];
@@ -1443,7 +1449,7 @@ __global__ void masked_multihead_attention_kernel(Multihead_attention_params<T,
     if (vo == tlength % V_PER_ITER) {
         // Trigger the loads from the V bias buffer.
         if (params.v_bias != nullptr) {
-            v_bias = *reinterpret_cast<const V_vec*>(&params.v_bias[hi * Dh + vi]);
+            v_bias = *reinterpret_cast<const V_vec*>(&params.v_bias[hi_kv * Dh + vi]);
         }
         if (DO_CROSS_ATTENTION) {
             *reinterpret_cast<V_vec*>(&bias_smem[vi]) = v_bias;
@@ -1510,7 +1516,7 @@ __global__ void masked_multihead_attention_kernel(Multihead_attention_params<T,
         }
         else {
             // Trigger the loads from the V buffer.
-            const auto v_offset = qkv_base_offset + vi;
+            const auto v_offset = v_base_offset + vi;
             if (params.int8_mode == 2) {
                 using Packed_Int8_t  = typename packed_type<int8_t, num_elems<V_vec>::value>::type;
                 using Packed_Float_t = typename packed_type<float, num_elems<V_vec>::value>::type;
@@ -1539,9 +1545,11 @@ __global__ void masked_multihead_attention_kernel(Multihead_attention_params<T,
             }
             // Store the values with bias back to global memory in the cache for V.
+            if (hi % params.num_heads_q_kv_ratio == 0) {
                 //*reinterpret_cast<V_vec*>(&v_cache[params.timestep*Dh]) = v;
                 *reinterpret_cast<V_vec*>(&v_cache[tlength_circ * Dh]) = v;
+            }
         }
     // Initialize the output value with the current timestep.
 #if defined(MMHA_USE_FP32_ACUM_FOR_LOGITS)
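Note (not part of the commit): the kernel changes above all follow one mapping. Each query head hi reads the cache rows of KV head hi_kv = hi / num_heads_q_kv_ratio, and only the first query head of each group (hi % num_heads_q_kv_ratio == 0) writes those rows, so each shared K/V cache entry is stored exactly once. A self-contained sketch of that index arithmetic, with assumed sizes (8 query heads, 2 KV heads, batch index 1):

    #include <cstdio>

    int main() {
        const int num_heads    = 8;                         // query heads (assumed)
        const int num_heads_kv = 2;                         // KV heads (assumed)
        const int ratio        = num_heads / num_heads_kv;  // num_heads_q_kv_ratio
        const int bi           = 1;                         // an arbitrary batch index

        for (int hi = 0; hi < num_heads; ++hi) {
            const int hi_kv  = hi / ratio;                 // KV head serving this query head
            const int bhi    = bi * num_heads    + hi;     // per-query-head row (Q side)
            const int bhi_kv = bi * num_heads_kv + hi_kv;  // per-KV-head row (K/V cache side)
            const bool writes_kv = (hi % ratio == 0);      // mirrors the new handle_kv guard
            std::printf("hi=%d -> hi_kv=%d, bhi=%d, bhi_kv=%d, writes_kv_cache=%d\n",
                        hi, hi_kv, bhi, bhi_kv, writes_kv ? 1 : 0);
        }
        return 0;
    }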
csrc/ft_attention/ft_attention.cpp

@@ -50,13 +50,16 @@ template <typename T>
 void set_params(Masked_multihead_attention_params<T> &params,
                 const size_t batch_size,
                 const size_t nheads,
+                const size_t nheads_kv,
                 const size_t memory_max_seqlen,
                 const size_t headdim,
                 const int timestep,
                 const int rotary_embedding_dim,
                 const float rotary_base,
                 const bool neox_rotary_style,
-                const int qkv_batch_stride,
+                const int q_batch_stride,
+                const int k_batch_stride,
+                const int v_batch_stride,
                 const int nnz_heads,
                 T *q_ptr,
                 T *k_ptr,
@@ -80,11 +83,15 @@ void set_params(Masked_multihead_attention_params<T> &params,
     params.v_cache = v_cache_ptr;
     params.out = out_ptr;
     params.cache_indir = nullptr;
-    params.stride = qkv_batch_stride;
+    params.stride_q = q_batch_stride;
+    params.stride_k = k_batch_stride;
+    params.stride_v = v_batch_stride;
     params.batch_size = batch_size;
     params.beam_width = 1;
     params.memory_max_len = memory_max_seqlen;
     params.num_heads = nheads;
+    params.num_heads_kv = nheads_kv;
+    params.num_heads_q_kv_ratio = nheads / nheads_kv;
     params.nnz_heads = nnz_heads;
     params.hidden_size_per_head = headdim;
     params.rotary_embedding_dim = rotary_embedding_dim;
@@ -124,23 +131,23 @@ torch::Tensor single_query_attention(const torch::Tensor q,
                                      const bool neox_rotary_style=true) {
     CHECK_DEVICE(q); CHECK_DEVICE(k); CHECK_DEVICE(v); CHECK_DEVICE(k_cache); CHECK_DEVICE(v_cache);
     int batch_size = v_cache.size(0);
-    int nheads = v_cache.size(1);
+    int nheads = q.size(1);
+    int nheads_kv = v_cache.size(1);
     int memory_max_seqlen = v_cache.size(2);
     int headdim = v_cache.size(3);
     auto input_type = q.scalar_type();
     TORCH_CHECK(input_type == at::ScalarType::Float || input_type == at::ScalarType::Half
                 || input_type == at::ScalarType::BFloat16);
     CHECK_SHAPE(q, batch_size, nheads, headdim);
-    CHECK_SHAPE(k, batch_size, nheads, headdim);
-    CHECK_SHAPE(v, batch_size, nheads, headdim);
-    CHECK_SHAPE(v_cache, batch_size, nheads, memory_max_seqlen, headdim);
+    CHECK_SHAPE(k, batch_size, nheads_kv, headdim);
+    CHECK_SHAPE(v, batch_size, nheads_kv, headdim);
+    CHECK_SHAPE(v_cache, batch_size, nheads_kv, memory_max_seqlen, headdim);
     // k_cache shape: [B, H, Dh/x, L, x] where x=8 for fp16 and x=4 for fp32
     int packsize = k_cache.dtype() == torch::kFloat32 ? 4 : 8;
-    CHECK_SHAPE(k_cache, batch_size, nheads, headdim / packsize, memory_max_seqlen, packsize);
+    CHECK_SHAPE(k_cache, batch_size, nheads_kv, headdim / packsize, memory_max_seqlen, packsize);
     TORCH_CHECK(q.stride(2) == 1 && q.stride(1) == headdim);
     TORCH_CHECK(k.stride(2) == 1 && k.stride(1) == headdim);
     TORCH_CHECK(v.stride(2) == 1 && v.stride(1) == headdim);
-    TORCH_CHECK(q.stride(0) == k.stride(0) && q.stride(0) == v.stride(0));
     CHECK_CONTIGUOUS(v_cache); CHECK_CONTIGUOUS(k_cache);
     TORCH_CHECK(q.scalar_type() == input_type);
@@ -191,8 +198,9 @@ torch::Tensor single_query_attention(const torch::Tensor q,
     DISPATCH_FLOAT_AND_HALF_AND_BF16(q.scalar_type(), "single_query_attention", [&] {
         using DataType = typename SATypeConverter<scalar_t>::Type;
         Masked_multihead_attention_params<DataType> params;
-        set_params(params, batch_size, nheads, memory_max_seqlen, headdim, timestep,
-                   rotary_embedding_dim, rotary_base, neox_rotary_style, q.stride(0),
+        set_params(params, batch_size, nheads, nheads_kv, memory_max_seqlen, headdim, timestep,
+                   rotary_embedding_dim, rotary_base, neox_rotary_style, q.stride(0),
+                   k.stride(0), v.stride(0),
                    nnz_head_idx_.has_value() ? nnz_head_idx_.value().size(0) : 0,
                    reinterpret_cast<DataType*>(q.data_ptr()),
                    reinterpret_cast<DataType*>(k.data_ptr()),
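Note (not part of the commit): the updated checks accept K/V tensors and caches whose head dimension is nheads_kv rather than nheads, while q keeps nheads; the per-batch stride equality check is dropped, presumably because the q and k/v batch strides can now differ. A small sketch restating the expected shapes, with assumed sizes (batch 2, 32 query heads, 4 KV heads, head dim 128, cache length 2048, packsize 8 as for fp16):

    #include <cstdio>

    int main() {
        const int batch_size        = 2;     // assumed
        const int nheads            = 32;    // query heads (assumed)
        const int nheads_kv         = 4;     // KV heads (assumed)
        const int headdim           = 128;   // assumed
        const int memory_max_seqlen = 2048;  // assumed
        const int packsize          = 8;     // 8 for fp16, 4 for fp32 (per the k_cache comment)

        std::printf("q:       [%d, %d, %d]\n", batch_size, nheads, headdim);
        std::printf("k, v:    [%d, %d, %d]\n", batch_size, nheads_kv, headdim);
        std::printf("v_cache: [%d, %d, %d, %d]\n", batch_size, nheads_kv, memory_max_seqlen, headdim);
        std::printf("k_cache: [%d, %d, %d, %d, %d]\n",
                    batch_size, nheads_kv, headdim / packsize, memory_max_seqlen, packsize);
        return 0;
    }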