gaoqiong / flash-attention · Commits

Commit ee77b931, authored Sep 10, 2023 by Tri Dao

Swap seqlen_q and nheads for MQA to speed it up (h/t Daniel Haziza)

parent 07005806
Showing 5 changed files with 29 additions and 14 deletions (+29 -14)
csrc/flash_attn/flash_api.cpp           +23  -8
csrc/flash_attn/src/block_info.h         +1  -1
csrc/flash_attn/src/flash.h              +1  -1
csrc/flash_attn/src/flash_fwd_kernel.h   +1  -1
flash_attn/flash_attn_interface.py       +3  -3
csrc/flash_attn/flash_api.cpp
@@ -992,15 +992,15 @@ mha_varlen_bwd(const at::Tensor &dout, // total_q x num_heads, x head_size
 }
 
 std::vector<at::Tensor>
-mha_fwd_kvcache(const at::Tensor &q,                 // batch_size x seqlen_q x num_heads x head_size
+mha_fwd_kvcache(at::Tensor &q,                       // batch_size x seqlen_q x num_heads x head_size
                 const at::Tensor &kcache,            // batch_size x seqlen_k x num_heads_k x head_size
                 const at::Tensor &vcache,            // batch_size x seqlen_k x num_heads_k x head_size
-                c10::optional<const at::Tensor> &k_, // batch_size x seqlen_q x num_heads_k x head_size
-                c10::optional<const at::Tensor> &v_, // batch_size x seqlen_q x num_heads_k x head_size
+                c10::optional<const at::Tensor> &k_, // batch_size x seqlen_knew x num_heads_k x head_size
+                c10::optional<const at::Tensor> &v_, // batch_size x seqlen_knew x num_heads_k x head_size
                 c10::optional<const at::Tensor> &seqlens_k_, // batch_size
                 c10::optional<at::Tensor> &out_,     // batch_size x seqlen_q x num_heads x head_size
                 const float softmax_scale,
-                const bool is_causal,
+                bool is_causal,
                 int num_splits
                 ) {
@@ -1032,8 +1032,8 @@ mha_fwd_kvcache(const at::Tensor &q, // batch_size x seqlen_q x
     const auto sizes = q.sizes();
 
     const int batch_size = sizes[0];
-    const int seqlen_q = sizes[1];
-    const int num_heads = sizes[2];
+    int seqlen_q = sizes[1];
+    int num_heads = sizes[2];
     const int head_size_og = sizes[3];
     const int seqlen_k = kcache.size(1);
     const int num_heads_k = kcache.size(2);
@@ -1041,6 +1041,15 @@ mha_fwd_kvcache(const at::Tensor &q, // batch_size x seqlen_q x
     TORCH_CHECK(head_size_og <= 256, "FlashAttention forward only supports head dimension at most 256");
     TORCH_CHECK(num_heads % num_heads_k == 0, "Number of heads in key/value must divide number of heads in query");
 
+    if (seqlen_q == 1) { is_causal = false; }  // causal=true is the same as causal=false in this case
+
+    // Faster to transpose q from (b, 1, h, d) to (b, h, 1, d) in this case
+    const int seqlenq_nheads_swapped = seqlen_q == 1 && num_heads_k == 1 && num_heads > 1;
+    if (seqlenq_nheads_swapped) {
+        q = q.transpose(1, 2);
+        std::swap(seqlen_q, num_heads);
+    }
+
     CHECK_SHAPE(q, batch_size, seqlen_q, num_heads, head_size_og);
     CHECK_SHAPE(kcache, batch_size, seqlen_k, num_heads_k, head_size_og);
     CHECK_SHAPE(vcache, batch_size, seqlen_k, num_heads_k, head_size_og);
@@ -1111,8 +1120,9 @@ mha_fwd_kvcache(const at::Tensor &q, // batch_size x seqlen_q x
         TORCH_CHECK(v.is_cuda(), "Value tensor must be on CUDA device");
         TORCH_CHECK(k.stride(-1) == 1, "Key tensor must have contiguous last dimension");
         TORCH_CHECK(v.stride(-1) == 1, "Value tensor must have contiguous last dimension");
-        CHECK_SHAPE(k, batch_size, seqlen_q, num_heads_k, head_size_og);
-        CHECK_SHAPE(v, batch_size, seqlen_q, num_heads_k, head_size_og);
+        int seqlen_knew = k.size(1);
+        CHECK_SHAPE(k, batch_size, seqlen_knew, num_heads_k, head_size_og);
+        CHECK_SHAPE(v, batch_size, seqlen_knew, num_heads_k, head_size_og);
         if (head_size_og % 8 != 0) {
             k_padded = torch::nn::functional::pad(k, torch::nn::functional::PadFuncOptions({0, 8 - head_size_og % 8}));
             v_padded = torch::nn::functional::pad(v, torch::nn::functional::PadFuncOptions({0, 8 - head_size_og % 8}));
@@ -1120,6 +1130,7 @@ mha_fwd_kvcache(const at::Tensor &q, // batch_size x seqlen_q x
             k_padded = k;
             v_padded = v;
         }
+        params.seqlen_knew = seqlen_knew;
         params.knew_ptr = k_padded.data_ptr();
         params.vnew_ptr = v_padded.data_ptr();
         // All stride are in elements, not bytes.
@@ -1175,6 +1186,10 @@ mha_fwd_kvcache(const at::Tensor &q, // batch_size x seqlen_q x
         }
     }
 
+    if (seqlenq_nheads_swapped) {
+        out = out.transpose(1, 2);
+        softmax_lse = softmax_lse.transpose(1, 2);
+    }
     return {out, softmax_lse};
 }
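Why the transpose in flash_api.cpp helps: during MQA/GQA decoding, seqlen_q is 1 and there is a single KV head, so the kernel gets almost no parallelism along the query-length dimension. Viewing q as (batch, nheads, 1, headdim) lets nheads stand in for seqlen_q, and the output and softmax_lse are swapped back at the end. A minimal PyTorch sketch (not from the repository; ref_attn and all shapes here are illustrative assumptions) showing that the swapped and unswapped computations agree:

import torch

# Illustrative MQA decode shapes: one query token, many query heads, one KV head.
b, h, d, s_k = 2, 16, 64, 128
q = torch.randn(b, 1, h, d)      # (batch, seqlen_q=1, nheads, headdim)
k = torch.randn(b, s_k, 1, d)    # (batch, seqlen_k, nheads_k=1, headdim)
v = torch.randn(b, s_k, 1, d)

def ref_attn(q, k, v):
    # Plain attention in (batch, nheads, seqlen, headdim) layout; the single
    # KV head broadcasts across all query heads.
    qt, kt, vt = (t.permute(0, 2, 1, 3) for t in (q, k, v))
    scores = (qt @ kt.transpose(-2, -1)) / qt.shape[-1] ** 0.5
    return (scores.softmax(dim=-1) @ vt).permute(0, 2, 1, 3)

out_direct = ref_attn(q, k, v)   # only 1 query row per head to parallelize over

# The commit's trick: treat the h query heads as h query positions of a single
# head, i.e. swap seqlen_q and nheads, run attention, then swap back.
out_swapped = ref_attn(q.transpose(1, 2), k, v).transpose(1, 2)

assert torch.allclose(out_direct, out_swapped, atol=1e-5)

The math is unchanged because with one KV head each query head attends to the same keys and values independently, so it makes no difference whether those heads live in the head axis or the sequence axis.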
csrc/flash_attn/src/block_info.h
@@ -19,7 +19,7 @@ struct BlockInfo {
         // If is_seqlens_k_cumulative, then seqlen_k is cu_seqlens_k[bidb + 1] - cu_seqlens_k[bidb].
         // Otherwise it's cu_seqlens_k[bidb], i.e., we use cu_seqlens_k to store the sequence lengths of K.
         , seqlen_k_cache(!Varlen || params.cu_seqlens_k == nullptr ? params.seqlen_k : (params.is_seqlens_k_cumulative ? params.cu_seqlens_k[bidb + 1] - sum_s_k : params.cu_seqlens_k[bidb]))
-        , actual_seqlen_k(seqlen_k_cache + (params.knew_ptr == nullptr ? 0 : params.seqlen_q))
+        , actual_seqlen_k(seqlen_k_cache + (params.knew_ptr == nullptr ? 0 : params.seqlen_knew))
        {
        }
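The one-line block_info.h change fixes the bookkeeping for appended K/V: the number of keys a query block actually attends to is the cached length plus the number of new K/V tokens, i.e. params.seqlen_knew, not params.seqlen_q (the two only coincide in the common one-token-per-step decode). A small Python sketch of that arithmetic, with names that merely mirror the C++ fields:

# Illustrative bookkeeping only; variable names echo the C++ fields but are assumptions.
def actual_seqlen_k(seqlen_k_cache: int, seqlen_knew: int, has_new_kv: bool) -> int:
    # Keys visible to the kernel = tokens already in the cache plus the new
    # K/V tokens being appended in this call (0 if none are passed).
    return seqlen_k_cache + (seqlen_knew if has_new_kv else 0)

# e.g. 100 cached tokens plus 5 newly appended tokens:
assert actual_seqlen_k(100, 5, True) == 105
assert actual_seqlen_k(100, 0, False) == 100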
csrc/flash_attn/src/flash.h
@@ -68,7 +68,7 @@ struct Flash_fwd_params : public Qkv_params {
     void * __restrict__ softmax_lseaccum_ptr;
 
     // The dimensions.
-    int b, seqlen_q, seqlen_k, d, seqlen_q_rounded, seqlen_k_rounded, d_rounded;
+    int b, seqlen_q, seqlen_k, seqlen_knew, d, seqlen_q_rounded, seqlen_k_rounded, d_rounded;
 
     // The scaling factors for the kernel.
     float scale_softmax;
csrc/flash_attn/src/flash_fwd_kernel.h
@@ -644,7 +644,7 @@ inline __device__ void compute_attn_1rowblock_splitkv(const Params &params, cons
     const BlockInfo</*Varlen=*/!Is_even_MN> binfo(params, bidb);
     // if (threadIdx.x == 0 && blockIdx.y == 0 && blockIdx.z == 0) { printf("Is_even_MN = %d, is_cumulativ = %d, seqlen_k_cache = %d, actual_seqlen_k = %d\n", Is_even_MN, params.is_seqlens_k_cumulative, binfo.seqlen_k_cache, binfo.actual_seqlen_k); }
-    // if (threadIdx.x == 0 && blockIdx.y == 1 && blockIdx.z == 0) { printf("params.knew_ptr = %p, seqlen_k_cache + seqlen_q = %d\n", params.knew_ptr, binfo.seqlen_k_cache + (params.knew_ptr == nullptr ? 0 : params.seqlen_q)); }
+    // if (threadIdx.x == 0 && blockIdx.y == 1 && blockIdx.z == 0) { printf("params.knew_ptr = %p, seqlen_k_cache + seqlen_knew = %d\n", params.knew_ptr, binfo.seqlen_k_cache + (params.knew_ptr == nullptr ? 0 : params.seqlen_knew)); }
     if (m_block * kBlockM >= binfo.actual_seqlen_q) return;
 
     const int n_blocks_per_split = ((params.seqlen_k + kBlockN - 1) / kBlockN + num_n_splits - 1) / num_n_splits;
flash_attn/flash_attn_interface.py
@@ -838,9 +838,9 @@ def flash_attn_with_kvcache(
         q: (batch_size, seqlen, nheads, headdim)
         k_cache: (batch_size, seqlen_cache, nheads_k, headdim)
         v_cache: (batch_size, seqlen_cache, nheads_k, headdim)
-        k [optional]: (batch_size, seqlen, nheads_k, headdim). If not None, we concatenate
-            k with k_cache, starting at the indices specified by cache_seqlens.
-        v [optional]: (batch_size, seqlen, nheads_k, headdim). Similar to k.
+        k [optional]: (batch_size, seqlen_new, nheads_k, headdim). If not None, we concatenate
+            k with k_cache, starting at the indices specified by cache_seqlens.
+        v [optional]: (batch_size, seqlen_new, nheads_k, headdim). Similar to k.
         cache_seqlens: int, or (batch_size,), dtype torch.int32. The sequence lengths of the
             KV cache.
         softmax_scale: float. The scaling of QK^T before applying softmax.
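For context, a hedged usage sketch of flash_attn_with_kvcache as documented above: one query token per sequence (the MQA decode case this commit speeds up), with seqlen_new tokens of K/V appended into the cache at the offsets given by cache_seqlens. Keyword names and shapes follow the docstring shown; the concrete sizes and any details beyond that docstring are assumptions.

import torch
from flash_attn import flash_attn_with_kvcache

batch, nheads, nheads_k, headdim = 2, 32, 1, 128   # MQA: a single KV head
seqlen_cache, seqlen_new = 512, 4                  # cache capacity, tokens being appended
device, dtype = "cuda", torch.float16

q = torch.randn(batch, 1, nheads, headdim, device=device, dtype=dtype)  # one query token per sequence
k_cache = torch.zeros(batch, seqlen_cache, nheads_k, headdim, device=device, dtype=dtype)
v_cache = torch.zeros(batch, seqlen_cache, nheads_k, headdim, device=device, dtype=dtype)
k_new = torch.randn(batch, seqlen_new, nheads_k, headdim, device=device, dtype=dtype)
v_new = torch.randn(batch, seqlen_new, nheads_k, headdim, device=device, dtype=dtype)

# Tokens already present in the cache for each sequence; the new K/V is written
# into k_cache/v_cache starting at these offsets.
cache_seqlens = torch.full((batch,), 100, dtype=torch.int32, device=device)

out = flash_attn_with_kvcache(q, k_cache, v_cache, k=k_new, v=v_new, cache_seqlens=cache_seqlens)
print(out.shape)  # expected (batch, 1, nheads, headdim)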