gaoqiong / flash-attention · Commits

Commit a7b66ae2, authored Jan 13, 2024 by Tri Dao

Simplify writing softmax to gmem

parent 8d1b169e
Showing 4 changed files with 46 additions and 107 deletions (+46 −107)
csrc/flash_attn/src/flash_fwd_kernel.h   +23 −47
csrc/flash_attn/src/kernel_traits.h      +8 −23
flash_attn/flash_attn_interface.py       +11 −11
tests/test_flash_attn.py                 +4 −26
csrc/flash_attn/src/flash_fwd_kernel.h
@@ -56,23 +56,6 @@ inline __device__ void softmax_rescale_o(Tensor0 &scores, Tensor1 &scores_max, T
 ////////////////////////////////////////////////////////////////////////////////////////////////////
 
-template<typename Engine0, typename Layout0, typename Engine1, typename Layout1, typename TiledCopy>
-inline __device__ void write_softmax_to_gmem(
-    Tensor<Engine0, Layout0> const &tOrP, Tensor<Engine1, Layout1> &tPgP, TiledCopy gmem_tiled_copy_P
-) {
-    // Reshape tOrP from (8, MMA_M, MMA_N) to (8, MMA_M * MMA_N)
-    Layout l = tOrP.layout();
-    Tensor tPrP = make_tensor(tOrP.data(), make_layout(get<0>(l), make_layout(get<1>(l), get<2>(l))));
-    CUTE_STATIC_ASSERT_V(size<2>(tPgP) == _1{});
-    CUTE_STATIC_ASSERT_V(size<1>(tPrP) == size<1>(tPgP));
-    #pragma unroll
-    for (int mi = 0; mi < size<1>(tPrP); ++mi) {
-        cute::copy(gmem_tiled_copy_P, tPrP(_, mi), tPgP(_, mi, 0));
-    }
-};
-
-////////////////////////////////////////////////////////////////////////////////////////////////////
 
 template<typename Kernel_traits, bool Is_dropout, bool Is_causal, bool Is_local, bool Has_alibi, bool Is_even_MN, bool Is_even_K, bool Return_softmax, typename Params>
 inline __device__ void compute_attn_1rowblock(const Params &params, const int bidb, const int bidh, const int m_block) {
@@ -92,6 +75,17 @@ inline __device__ void compute_attn_1rowblock(const Params &params, const int bi
     constexpr int kNWarps = Kernel_traits::kNWarps;
     constexpr int MMA_M = kBlockM / decltype(size<0>(typename Kernel_traits::TiledMma::TiledShape_MNK{}))::value;
 
+    auto seeds = at::cuda::philox::unpack(params.philox_args);
+    unsigned long long seed = std::get<0>(seeds);
+    unsigned long long offset = std::get<1>(seeds) + (bidb * params.h + bidh) * 32 + tidx % 32;
+
+    // Save seed and offset for backward, before any early exiting. Otherwise the 0-th thread block might
+    // exit early and no one saves the rng states.
+    if (Is_dropout && blockIdx.x == 0 && blockIdx.y == 0 && blockIdx.z == 0 && tidx == 0) {
+        params.rng_state[0] = seed;
+        params.rng_state[1] = std::get<1>(seeds);
+    }
+
     const BlockInfo</*Varlen=*/!Is_even_MN> binfo(params, bidb);
     if (m_block * kBlockM >= binfo.actual_seqlen_q) return;
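The hoisted Philox block above derives one RNG offset per (batch, head, warp lane): (bidb * params.h + bidh) * 32 + tidx % 32. A quick sketch of that arithmetic in plain Python (the formula is from the diff; the concrete sizes are hypothetical):

    # Each (batch, head) pair gets its own run of 32 offsets, one per warp lane,
    # so no two (batch, head, lane) triples collide.
    def philox_offset(base, bidb, bidh, h, tidx):
        return base + (bidb * h + bidh) * 32 + tidx % 32

    offsets = {
        philox_offset(0, bidb, bidh, h=4, tidx=lane)
        for bidb in range(2)
        for bidh in range(4)
        for lane in range(32)
    }
    assert len(offsets) == 2 * 4 * 32  # all distinct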
@@ -107,13 +101,6 @@ inline __device__ void compute_attn_1rowblock(const Params &params, const int bi
     // We exit early and write 0 to gO and gLSE. This also covers the case where actual_seqlen_k == 0.
     // Otherwise we might read OOB elements from gK and gV.
     if ((Is_causal || Is_local || !Is_even_MN) && n_block_max <= n_block_min) {
-        // Save seed and offset for backward. If we don't have this here, the 0-th thread block might
-        // exit early and no one saves the rng state.
-        if (Is_dropout && blockIdx.x == 0 && blockIdx.y == 0 && blockIdx.z == 0 && tidx == 0) {
-            auto seeds = at::cuda::philox::unpack(params.philox_args);
-            params.rng_state[0] = std::get<0>(seeds);
-            params.rng_state[1] = std::get<1>(seeds);
-        }
         const index_t row_offset_o = binfo.q_offset(params.o_batch_stride, params.o_row_stride, bidb)
             + m_block * kBlockM * params.o_row_stride + bidh * params.o_head_stride;
         const index_t row_offset_lse = (bidb * params.h + bidh) * params.seqlen_q + m_block * kBlockM;
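The early-exit path writes zeros at offsets computed from the same strides as the main path. A hypothetical instance of the row_offset_lse arithmetic, which flat-indexes an LSE buffer of shape (batch, heads, seqlen_q):

    # row_offset_lse = (bidb * h + bidh) * seqlen_q + m_block * kBlockM
    def row_offset_lse(bidb, bidh, h, seqlen_q, m_block, k_block_m):
        return (bidb * h + bidh) * seqlen_q + m_block * k_block_m

    # Hypothetical sizes: batch index 1, head 2 of 4, seqlen_q 512, row block 3 of size 128.
    assert row_offset_lse(1, 2, 4, 512, 3, 128) == 6 * 512 + 3 * 128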
@@ -188,8 +175,6 @@ inline __device__ void compute_attn_1rowblock(const Params &params, const int bi
     typename Kernel_traits::GmemTiledCopyQKV gmem_tiled_copy_QKV;
     auto gmem_thr_copy_QKV = gmem_tiled_copy_QKV.get_thread_slice(tidx);
-    typename Kernel_traits::GmemTiledCopyP gmem_tiled_copy_P;
-    auto gmem_thr_copy_P = gmem_tiled_copy_P.get_thread_slice(tidx);
 
     Tensor tQgQ = gmem_thr_copy_QKV.partition_S(gQ);
     Tensor tQsQ = gmem_thr_copy_QKV.partition_D(sQ);
@@ -197,7 +182,6 @@ inline __device__ void compute_attn_1rowblock(const Params &params, const int bi
     Tensor tKsK = gmem_thr_copy_QKV.partition_D(sK);
     Tensor tVgV = gmem_thr_copy_QKV.partition_S(gV);  // (VCPY, VCPY_N, VCPY_K)
     Tensor tVsV = gmem_thr_copy_QKV.partition_D(sV);
-    Tensor tPgP = gmem_thr_copy_P.partition_D(gP);
 
     typename Kernel_traits::TiledMma tiled_mma;
     auto thr_mma = tiled_mma.get_thread_slice(tidx);
@@ -205,6 +189,8 @@ inline __device__ void compute_attn_1rowblock(const Params ¶ms, const int bi
...
@@ -205,6 +189,8 @@ inline __device__ void compute_attn_1rowblock(const Params ¶ms, const int bi
Tensor
tSrK
=
thr_mma
.
partition_fragment_B
(
sK
);
// (MMA,MMA_N,MMA_K)
Tensor
tSrK
=
thr_mma
.
partition_fragment_B
(
sK
);
// (MMA,MMA_N,MMA_K)
Tensor
tOrVt
=
thr_mma
.
partition_fragment_B
(
sVtNoSwizzle
);
// (MMA, MMA_K,MMA_N)
Tensor
tOrVt
=
thr_mma
.
partition_fragment_B
(
sVtNoSwizzle
);
// (MMA, MMA_K,MMA_N)
Tensor
tSgS
=
thr_mma
.
partition_C
(
gP
);
Tensor
acc_o
=
partition_fragment_C
(
tiled_mma
,
Shape
<
Int
<
kBlockM
>
,
Int
<
kHeadDim
>>
{});
// MMA, MMA_M, MMA_K
Tensor
acc_o
=
partition_fragment_C
(
tiled_mma
,
Shape
<
Int
<
kBlockM
>
,
Int
<
kHeadDim
>>
{});
// MMA, MMA_M, MMA_K
//
//
...
@@ -310,16 +296,6 @@ inline __device__ void compute_attn_1rowblock(const Params ¶ms, const int bi
...
@@ -310,16 +296,6 @@ inline __device__ void compute_attn_1rowblock(const Params ¶ms, const int bi
cute
::
copy
(
smem_tiled_copy_Q
,
tSsQ
,
tSrQ_copy_view
);
cute
::
copy
(
smem_tiled_copy_Q
,
tSsQ
,
tSrQ_copy_view
);
}
}
auto
seeds
=
at
::
cuda
::
philox
::
unpack
(
params
.
philox_args
);
unsigned
long
long
seed
=
std
::
get
<
0
>
(
seeds
);
unsigned
long
long
offset
=
std
::
get
<
1
>
(
seeds
)
+
(
bidb
*
params
.
h
+
bidh
)
*
32
+
tidx
%
32
;
// Save seed and offset for backward.
if
(
Is_dropout
&&
blockIdx
.
x
==
0
&&
blockIdx
.
y
==
0
&&
blockIdx
.
z
==
0
&&
tidx
==
0
)
{
params
.
rng_state
[
0
]
=
seed
;
params
.
rng_state
[
1
]
=
std
::
get
<
1
>
(
seeds
);
}
clear
(
acc_o
);
clear
(
acc_o
);
float
alibi_slope
=
!
Has_alibi
?
0.0
f
:
reinterpret_cast
<
float
*>
(
params
.
alibi_slopes_ptr
)[
bidb
*
params
.
alibi_slopes_batch_stride
+
bidh
]
/
params
.
scale_softmax
;
float
alibi_slope
=
!
Has_alibi
?
0.0
f
:
reinterpret_cast
<
float
*>
(
params
.
alibi_slopes_ptr
)[
bidb
*
params
.
alibi_slopes_batch_stride
+
bidh
]
/
params
.
scale_softmax
;
@@ -429,14 +405,14 @@ inline __device__ void compute_attn_1rowblock(const Params &params, const int bi
         int block_row_idx = m_block * (kBlockM / 16) + tidx / 32;
         int block_col_idx = n_block * (kBlockN / 32);
         if (Return_softmax) {
-            Tensor tOrP_copy = make_fragment_like(tOrP);
-            cute::copy(tOrP, tOrP_copy);
+            Tensor acc_s_f16 = flash::convert_type<Element>(acc_s);
+            Tensor tOrPdrop = make_tensor(acc_s_f16.data(), tOrP.layout());
             flash::apply_dropout</*encode_dropout_in_sign_bit=*/true>(
-                tOrP_copy, params.p_dropout_in_uint8_t, seed, offset,
+                tOrPdrop, params.p_dropout_in_uint8_t, seed, offset,
                 block_row_idx, block_col_idx, kNWarps
             );
-            flash::write_softmax_to_gmem(tOrP_copy, tPgP, gmem_tiled_copy_P);
-            tPgP.data() = tPgP.data() + (-kBlockN);
+            cute::copy(acc_s_f16, tSgS);
+            tSgS.data() = tSgS.data() + (-kBlockN);
         }
         if (Is_dropout) {
             flash::apply_dropout(tOrP, params.p_dropout_in_uint8_t, seed, offset,
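With tSgS partitioned by the MMA, writing the softmax back for Return_softmax is now a single cute::copy of acc_s_f16, replacing the round trip through write_softmax_to_gmem and its dedicated tiled copy. The encode_dropout_in_sign_bit=true argument stores the dropout decision in the sign of the written values; a rough Python analogue of that encoding (a sketch of the idea only, not the kernel's Philox-based masking):

    import torch

    p = torch.rand(4, 4).clamp_min(1e-6)  # stand-in for softmax probabilities
    keep = torch.rand(4, 4) >= 0.1        # stand-in for the keep mask, p_dropout = 0.1
    s_out = torch.where(keep, p, -p)      # dropped entries get a flipped sign
    assert torch.equal(s_out >= 0, keep)  # the mask is recoverable from the sign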
@@ -514,14 +490,14 @@ inline __device__ void compute_attn_1rowblock(const Params &params, const int bi
         int block_row_idx = m_block * (kBlockM / 16) + tidx / 32;
         int block_col_idx = n_block * (kBlockN / 32);
         if (Return_softmax) {
-            Tensor tOrP_copy = make_fragment_like(tOrP);
-            cute::copy(tOrP, tOrP_copy);
+            Tensor acc_s_f16 = flash::convert_type<Element>(acc_s);
+            Tensor tOrPdrop = make_tensor(acc_s_f16.data(), tOrP.layout());
             flash::apply_dropout</*encode_dropout_in_sign_bit=*/true>(
-                tOrP_copy, params.p_dropout_in_uint8_t, seed, offset,
+                tOrPdrop, params.p_dropout_in_uint8_t, seed, offset,
                 block_row_idx, block_col_idx, kNWarps
             );
-            flash::write_softmax_to_gmem(tOrP_copy, tPgP, gmem_tiled_copy_P);
-            tPgP.data() = tPgP.data() + (-kBlockN);
+            cute::copy(acc_s_f16, tSgS);
+            tSgS.data() = tSgS.data() + (-kBlockN);
         }
         if (Is_dropout) {
             flash::apply_dropout(tOrP, params.p_dropout_in_uint8_t, seed, offset,
csrc/flash_attn/src/kernel_traits.h
@@ -106,10 +106,8 @@ struct Flash_fwd_kernel_traits : public Base {
     using SmemCopyAtomO = Copy_Atom<DefaultCopy, Element>;
     using SmemCopyAtomOaccum = Copy_Atom<DefaultCopy, ElementAccum>;
 
-    static constexpr int kSmemQCount = size(SmemLayoutQ{});
-    static constexpr int kSmemKVCount = size(SmemLayoutKV{}) * 2;
-    static constexpr int kSmemQSize = kSmemQCount * sizeof(Element);
-    static constexpr int kSmemKVSize = kSmemKVCount * sizeof(Element);
+    static constexpr int kSmemQSize = size(SmemLayoutQ{}) * sizeof(Element);
+    static constexpr int kSmemKVSize = size(SmemLayoutKV{}) * 2 * sizeof(Element);
     static constexpr int kSmemSize = Share_Q_K_smem ? std::max(kSmemQSize, kSmemKVSize) : kSmemQSize + kSmemKVSize;
 
     static constexpr int kGmemElemsPerLoad = sizeof(cute::uint128_t) / sizeof(Element);
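These size constants now fold the element count and sizeof(Element) into one expression; the values are unchanged. A back-of-the-envelope instance of the kSmemSize formula under hypothetical tile sizes (kBlockM = 128, kBlockN = 64, kHeadDim = 128, fp16 elements):

    elem_bytes = 2                               # sizeof(half)
    k_smem_q_size = 128 * 128 * elem_bytes       # size(SmemLayoutQ{}) * sizeof(Element)
    k_smem_kv_size = 64 * 128 * 2 * elem_bytes   # K and V tiles together

    for share_q_k_smem in (False, True):
        k_smem_size = (max(k_smem_q_size, k_smem_kv_size) if share_q_k_smem
                       else k_smem_q_size + k_smem_kv_size)
        print(share_q_k_smem, k_smem_size)       # 65536 bytes, then 32768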
@@ -139,15 +137,6 @@ struct Flash_fwd_kernel_traits : public Base {
         make_tiled_copy(Copy_Atom<DefaultCopy, Element>{},
                         GmemLayoutAtom{},
                         Layout<Shape<_1, _8>>{}));  // Val layout, 8 vals per store
-    static constexpr int kGmemThreadsPerRowP = kBlockN / kGmemElemsPerLoad;
-    static_assert(kNThreads % kGmemThreadsPerRowP == 0, "kNThreads must be a multiple of kGmemThreadsPerRowP");
-    using GmemLayoutAtomP = Layout<Shape<Int<kNThreads / kGmemThreadsPerRowP>, Int<kGmemThreadsPerRowP>>,
-                                   Stride<Int<kGmemThreadsPerRowP>, _1>>;
-
-    using GmemTiledCopyP = decltype(
-        make_tiled_copy(Copy_Atom<DefaultCopy, Element>{},
-                        GmemLayoutAtomP{},
-                        Layout<Shape<_1, _8>>{}));  // Val layout, 8 vals per store
 
     using GmemLayoutAtomOaccum = std::conditional_t<
         kBlockKSmem == 32,
@@ -285,16 +274,12 @@ struct Flash_bwd_kernel_traits : public Base {
             make_shape(Int<kBlockM>{}, Int<kHeadDim>{})));
     using SmemCopyAtomdQ = Copy_Atom<DefaultCopy, elem_type>;
 
-    static constexpr int kSmemQdOCount = size(SmemLayoutQdO{}) * (No_double_buffer ? 2 : 3);  // Double buffer for sQ
-    static constexpr int kSmemKVCount = size(SmemLayoutKV{}) * 2;
-    static constexpr int kSmemdSCount = size(SmemLayoutPdS{});
-    static constexpr int kSmemPCount = size(SmemLayoutPdS{});
-    static constexpr int kSmemdQCount = size(SmemLayoutdQ{});
-    static constexpr int kSmemQdOSize = kSmemQdOCount * sizeof(Element);
-    static constexpr int kSmemKVSize = kSmemKVCount * sizeof(Element);
-    static constexpr int kSmemdSSize = kSmemdSCount * sizeof(Element);
-    static constexpr int kSmemPSize = kSmemPCount * sizeof(Element);
-    static constexpr int kSmemdQSize = kSmemdQCount * sizeof(Element);
+    static constexpr int kSmemQdOSize = size(SmemLayoutQdO{}) * (No_double_buffer ? 2 : 3) * sizeof(Element);  // Double buffer for sQ
+    static constexpr int kSmemKVSize = size(SmemLayoutKV{}) * 2 * sizeof(Element);
+    static constexpr int kSmemdSSize = size(SmemLayoutPdS{}) * sizeof(Element);
+    static constexpr int kSmemPSize = size(SmemLayoutPdS{}) * sizeof(Element);
+    static constexpr int kSmemdQSize = size(SmemLayoutdQ{}) * sizeof(Element);
     static constexpr int kSmemSize = kSmemQdOSize
         + (!Is_V_in_regs ? kSmemKVSize + kSmemdSSize + std::max(kSmemPSize, kSmemdQSize)
flash_attn/flash_attn_interface.py
@@ -12,7 +12,7 @@ import flash_attn_2_cuda as flash_attn_cuda
 # isort: on
 
-def _get_block_size(device, head_dim, is_dropout, is_causal):
+def _get_block_size_n(device, head_dim, is_dropout, is_causal):
     # This should match the block sizes in the CUDA kernel
     assert head_dim <= 256
     major, minor = torch.cuda.get_device_capability(device)
@@ -20,27 +20,27 @@ def _get_block_size(device, head_dim, is_dropout, is_causal):
     is_sm80 = major == 8 and minor == 0
     is_sm90 = major == 9 and minor == 0
     if head_dim <= 32:
-        return 128, 128
+        return 128
     if head_dim <= 64:
-        return (128, 128) if not is_dropout else (128, 64)
+        return 128 if not is_dropout else 64
     elif head_dim <= 96:
-        return (64, 64) if (is_sm8x and is_causal) else (128, 64)
+        return 64
     elif head_dim <= 128:
         if is_sm8x:
-            return (64, 64) if (not is_dropout and is_causal) else (128, 32)
+            return 64 if (not is_dropout and is_causal) else 32
         else:
-            return 128, (64 if not is_dropout else 32)
+            return 64 if not is_dropout else 32
     elif head_dim <= 160:
         if is_sm8x:
-            return (128, 64) if not is_causal else (64, 64)
+            return 64
         else:
-            return 128, 32
+            return 32
     elif head_dim <= 192:
-        return (128, 64) if not is_dropout else (64, 64)
+        return 64
     elif head_dim <= 224:
-        return (128, 64) if (is_sm80 or is_sm90) else (64, 64)
+        return 64
     elif head_dim <= 256:
-        return (128, 64) if is_sm80 else (64, 64)
+        return 64
 
 
 def _flash_attn_forward(
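Since the kernel no longer materializes P through its own tiled copy, the Python helper only has to report the block size along the key dimension, hence the rename to _get_block_size_n and the scalar return. A hypothetical usage sketch (device and head_dim chosen for illustration):

    import torch

    # On an A100 (sm80, where is_sm8x is False), head_dim 128 without dropout
    # falls into the `return 64 if not is_dropout else 32` branch above.
    block_n = _get_block_size_n(torch.device("cuda"), head_dim=128,
                                is_dropout=False, is_causal=True)
    assert block_n == 64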
tests/test_flash_attn.py
@@ -14,7 +14,7 @@ from flash_attn import (
     flash_attn_with_kvcache,
 )
 from flash_attn.bert_padding import pad_input, unpad_input
-from flash_attn.flash_attn_interface import _get_block_size
+from flash_attn.flash_attn_interface import _get_block_size_n
 from flash_attn.layers.rotary import apply_rotary_emb
 
 MAX_HEADDIM_SM8x = 192
@@ -406,29 +406,7 @@ def convert_flash_attn_S_to_softmax(
     if causal:
         window_size = (window_size[0], 0)
     seqlen_q_rounded, seqlen_k_rounded = S.shape[-2:]
-    warps_n = 4
-    blocksize_m, blocksize_n = _get_block_size(S.device, head_dim, is_dropout, causal)
-    nblocks_n = (seqlen_k_rounded + blocksize_n - 1) // blocksize_n
-    nblocks_m = (seqlen_q_rounded + blocksize_m - 1) // blocksize_m
-    mmas_n = (blocksize_n + 16 - 1) // 16
-    S_flat = rearrange(
-        S,
-        "b h (nblocks_m blocksize_m) (nblocks_n blocksize_n) -> b h nblocks_m nblocks_n (blocksize_m blocksize_n)",
-        blocksize_m=blocksize_m,
-        blocksize_n=blocksize_n,
-    )
-    S_converted = rearrange(
-        S_flat,
-        "b h nblocks_m nblocks_n (mmas_n mmas_m warps_n eight four c2 c1 c0) -> b h (nblocks_m mmas_m warps_n c1 eight) (nblocks_n mmas_n c2 four c0)",
-        mmas_n=mmas_n,
-        warps_n=warps_n,
-        eight=8,
-        c0=2,
-        c1=2,
-        c2=2,
-        four=4,
-    )
+    S_converted = S
     if window_size[0] >= 0 or window_size[1] >= 0:
         local_mask = construct_local_mask(
             seqlen_q,
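Because the kernel now writes S through the MMA's C-partition (tSgS), the returned tensor already arrives in natural (seqlen_q, seqlen_k) order, so the test can drop the einops rearranges that previously undid the per-warp register layout and block tiling; S_converted = S is all that remains before masking.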
@@ -443,7 +421,7 @@ def convert_flash_attn_S_to_softmax(
             (0, seqlen_k_rounded - seqlen_k, 0, seqlen_q_rounded - seqlen_q),
             value=True,
         )
-        S_converted.masked_fill_(local_mask, 0.0)
+        S_converted = S_converted.masked_fill(local_mask, 0.0)
     # Need to zero out things not in attention_mask in case S was initialized with random values
     # and some of those values aren't overwritten.
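The switch from masked_fill_ to masked_fill matters now that S_converted = S aliases the input: the in-place variant would have clobbered the caller's S as well. A minimal sketch of the aliasing hazard:

    import torch

    S = torch.ones(2, 2)
    S_converted = S                                    # two names, one tensor
    S_converted = S_converted.masked_fill(S > 0, 0.0)  # out-of-place: S untouched
    assert S.eq(1).all() and S_converted.eq(0).all()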
@@ -504,7 +482,7 @@ def normalize_flash_attn_S(
         scores.masked_fill_(local_mask, float("-inf"))
     if attn_bias is not None:
         scores = scores + attn_bias.to(dtype=scores.dtype)
-    _, block_size_n = _get_block_size(scores.device, head_dim, is_dropout, causal)
+    block_size_n = _get_block_size_n(scores.device, head_dim, is_dropout, causal)
     scores_block = scores.split(block_size_n, dim=-1)
     lse_block = torch.stack([torch.logsumexp(s, dim=-1) for s in scores_block], dim=-1)
     lse = torch.logsumexp(lse_block, dim=-1)
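The reference LSE above is assembled blockwise, mirroring how the kernel visits K/V in kBlockN-sized tiles: the logsumexp of per-block logsumexps equals the logsumexp over the whole row. A small self-check (shapes hypothetical):

    import torch

    scores = torch.randn(2, 3, 8, 256)  # (batch, heads, seqlen_q, seqlen_k)
    block_size_n = 64                   # hypothetical, as _get_block_size_n might return
    scores_block = scores.split(block_size_n, dim=-1)
    lse_block = torch.stack([torch.logsumexp(s, dim=-1) for s in scores_block], dim=-1)
    lse = torch.logsumexp(lse_block, dim=-1)
    assert torch.allclose(lse, torch.logsumexp(scores, dim=-1))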