gaoqiong / flash-attention · Commits

Commit 10dad612, authored Jan 14, 2024 by Tri Dao

apply_dropout now takes tensor of rowcol layout

Parent: d9cbcfb4

Showing 5 changed files with 33 additions and 19 deletions
csrc/flash_attn/src/flash_bwd_kernel.h    +2   -4
csrc/flash_attn/src/flash_fwd_kernel.h   +13  -13
csrc/flash_attn/src/softmax.h             +3   -1
csrc/flash_attn/src/utils.h              +14   -0
tests/test_flash_attn.py                  +1   -1
csrc/flash_attn/src/flash_bwd_kernel.h

@@ -886,9 +886,8 @@ inline __device__ void compute_dq_dk_dv_1colblock(const Params &params, const in
         // Need col to be multiples of 32, since we're doing dropout with block of 16 x 32
         static_assert(MMA_N_SdP % 2 == 0);
         int block_col_idx = n_block * (kBlockN / 32) + (warp_id / AtomLayoutMS) * (MMA_N_SdP / 2);
-        Tensor scores_dropped = make_tensor(scores.data(), flash::convert_layout_rowcol_Aregs<Kernel_traits::TiledMmaSdP>(scores.layout()));
         flash::apply_dropout</*encode_dropout_in_sign_bit=*/true>(
-            scores_dropped, params.p_dropout_in_uint8_t, seed, offset,
+            scores, params.p_dropout_in_uint8_t, seed, offset,
             block_row_idx, block_col_idx, AtomLayoutMS
         );
     }

@@ -1446,9 +1445,8 @@ inline __device__ void compute_dq_dk_dv_1rowblock(const Params &params, const in
         // Need col to be multiples of 32, since we're doing dropout with block of 16 x 32
         static_assert(MMA_N_SdP % 2 == 0);
         int block_col_idx = n_block * (kBlockN / 32) + (warp_id / AtomLayoutMS) * (MMA_N_SdP / 2);
-        Tensor scores_dropped = make_tensor(scores.data(), flash::convert_layout_rowcol_Aregs<Kernel_traits::TiledMmaSdP>(scores.layout()));
         flash::apply_dropout</*encode_dropout_in_sign_bit=*/true>(
-            scores_dropped, params.p_dropout_in_uint8_t, seed, offset,
+            scores, params.p_dropout_in_uint8_t, seed, offset,
             block_row_idx, block_col_idx, AtomLayoutMS
         );
     }
csrc/flash_attn/src/flash_fwd_kernel.h

@@ -399,27 +399,27 @@ inline __device__ void compute_attn_1rowblock(const Params &params, const int bi
         // Convert scores from fp32 to fp16/bf16
         Tensor rP = flash::convert_type<Element>(scores);
-        // Reshape rP from (nrow=(2, MMA_M), ncol=(2, MMA_N)) to ((2, 2, 2), MMA_M, MMA_N / 2)
-        // if using m16n8k16 or ((2, 2, 1), MMA_M, MMA_N) if using m16n8k8.
-        Tensor tOrP = make_tensor(rP.data(), flash::convert_layout_rowcol_Aregs<Kernel_traits::TiledMma>(rP.layout()));
         int block_row_idx = m_block * (kBlockM / 16) + tidx / 32;
         int block_col_idx = n_block * (kBlockN / 32);
         if (Return_softmax) {
             Tensor acc_s_f16 = flash::convert_type<Element>(acc_s);
-            Tensor tOrP_drop = make_tensor(acc_s_f16.data(), tOrP.layout());
+            Tensor acc_s_f16_drop = make_tensor(acc_s_f16.data(), rP.layout());
             flash::apply_dropout</*encode_dropout_in_sign_bit=*/true>(
-                tOrP_drop, params.p_dropout_in_uint8_t, seed, offset,
+                acc_s_f16_drop, params.p_dropout_in_uint8_t, seed, offset,
                 block_row_idx, block_col_idx, kNWarps
             );
             cute::copy(acc_s_f16, tSgS);
             tSgS.data() = tSgS.data() + (-kBlockN);
         }
         if (Is_dropout) {
-            flash::apply_dropout(tOrP, params.p_dropout_in_uint8_t, seed, offset,
+            flash::apply_dropout(rP, params.p_dropout_in_uint8_t, seed, offset,
                                  block_row_idx, block_col_idx, kNWarps);
         }
-        // if (cute::thread0()) { print(tOrP); }
+        // Reshape rP from (nrow=(2, MMA_M), ncol=(2, MMA_N)) to ((2, 2, 2), MMA_M, MMA_N / 2)
+        // if using m16n8k16 or ((2, 2, 1), MMA_M, MMA_N) if using m16n8k8.
+        Tensor tOrP = make_tensor(rP.data(), flash::convert_layout_rowcol_Aregs<Kernel_traits::TiledMma>(rP.layout()));
+        // if (cute::thread0()) { print(tOrP); }
         flash::gemm_rs(acc_o, tOrP, tOrVt, tOsVt, tiled_mma, smem_tiled_copy_V, smem_thr_copy_V);
         // if (cute::thread0()) { print(scores); }

@@ -484,26 +484,26 @@ inline __device__ void compute_attn_1rowblock(const Params &params, const int bi
         softmax_rescale_o</*Is_first=*/false, /*Check_inf=*/Is_local>(scores, scores_max, scores_sum, acc_o, params.scale_softmax_log2);
         Tensor rP = flash::convert_type<Element>(scores);
-        // Reshape rP from (nrow=(2, MMA_M), ncol=(2, MMA_N)) to ((2, 2, 2), MMA_M, MMA_N / 2)
-        // if using m16n8k16 or ((2, 2, 1), MMA_M, MMA_N) if using m16n8k8.
-        Tensor tOrP = make_tensor(rP.data(), flash::convert_layout_rowcol_Aregs<Kernel_traits::TiledMma>(rP.layout()));
         int block_row_idx = m_block * (kBlockM / 16) + tidx / 32;
         int block_col_idx = n_block * (kBlockN / 32);
         if (Return_softmax) {
             Tensor acc_s_f16 = flash::convert_type<Element>(acc_s);
-            Tensor tOrP_drop = make_tensor(acc_s_f16.data(), tOrP.layout());
+            Tensor acc_s_f16_drop = make_tensor(acc_s_f16.data(), rP.layout());
            flash::apply_dropout</*encode_dropout_in_sign_bit=*/true>(
-                tOrP_drop, params.p_dropout_in_uint8_t, seed, offset,
+                acc_s_f16_drop, params.p_dropout_in_uint8_t, seed, offset,
                 block_row_idx, block_col_idx, kNWarps
             );
             cute::copy(acc_s_f16, tSgS);
             tSgS.data() = tSgS.data() + (-kBlockN);
         }
         if (Is_dropout) {
-            flash::apply_dropout(tOrP, params.p_dropout_in_uint8_t, seed, offset,
+            flash::apply_dropout(rP, params.p_dropout_in_uint8_t, seed, offset,
                                  block_row_idx, block_col_idx, kNWarps);
         }
+        // Reshape rP from (nrow=(2, MMA_M), ncol=(2, MMA_N)) to ((2, 2, 2), MMA_M, MMA_N / 2)
+        // if using m16n8k16 or ((2, 2, 1), MMA_M, MMA_N) if using m16n8k8.
+        Tensor tOrP = make_tensor(rP.data(), flash::convert_layout_rowcol_Aregs<Kernel_traits::TiledMma>(rP.layout()));
         flash::gemm_rs(acc_o, tOrP, tOrVt, tOsVt, tiled_mma, smem_tiled_copy_V, smem_thr_copy_V);
     }
csrc/flash_attn/src/softmax.h

@@ -213,10 +213,12 @@ inline __device__ void apply_mask_causal_w_idx(
 }

 template <bool encode_dropout_in_sign_bit=false, typename Engine, typename Layout>
-inline __device__ void apply_dropout(Tensor<Engine, Layout> &tensor, uint8_t p_dropout_in_uint8_t,
+inline __device__ void apply_dropout(Tensor<Engine, Layout> &tensor_, uint8_t p_dropout_in_uint8_t,
                                      unsigned long long seed, unsigned long long offset,
                                      int block_row_start, int block_col_start,
                                      int block_row_stride) {
+    // tensor_ has shape (nrow=(2, MMA_M), ncol=(2, MMA_N))
+    Tensor tensor = make_tensor(tensor_.data(), flash::convert_layout_rowcol_dropout(tensor_.layout()));
     // tensor has shape (8, MMA_M, MMA_N / 2)
     using T = typename Engine::value_type;
     auto encode_dropout = [](bool keep, T val) {
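The net effect is visible in the kernel hunks above: callers no longer build a convert_layout_rowcol_Aregs view before calling apply_dropout; they pass the rowcol-layout scores/rP tensor directly, and the reshape now happens inside apply_dropout through the new convert_layout_rowcol_dropout helper (added in csrc/flash_attn/src/utils.h below). The host-only sketch below is illustrative and not part of the commit: it assumes the CUTLASS/CuTe (and CUDA toolkit) headers are on the include path, copies the helper as a host function, and uses made-up sizes MMA_M = 2, MMA_N = 4 to mimic that no-copy reinterpretation.

// Host-only sketch (illustrative, not part of this commit).
#include <cstdio>
#include <cute/tensor.hpp>

using namespace cute;

// Host copy of the helper added in csrc/flash_attn/src/utils.h (the __device__ qualifier is dropped).
template <typename Layout>
auto convert_layout_rowcol_dropout(Layout rowcol_layout) {
    using X = Underscore;
    static_assert(decltype(size<0, 0>(rowcol_layout))::value == 2);
    static_assert(decltype(size<1, 0>(rowcol_layout))::value == 2);
    auto l = logical_divide(rowcol_layout, Shape<X, Shape<X, Int<2>>>{});
    return make_layout(make_layout(get<1, 0>(l), get<0, 0>(l), get<1, 1, 0>(l)),
                       get<0, 1>(l), get<1, 1, 1>(l));
}

int main() {
    // A stand-in for the caller's register tensor: rowcol layout with MMA_M = 2, MMA_N = 4.
    auto rowcol = make_layout(make_shape(make_shape(Int<2>{}, Int<2>{}),
                                         make_shape(Int<2>{}, Int<4>{})));
    Tensor tensor_ = make_tensor<float>(rowcol);   // owning host tensor, shape ((2,2),(2,4))
    clear(tensor_);

    // Same move as the new apply_dropout body: rewrap the same storage, no copy.
    Tensor tensor = make_tensor(tensor_.data(), convert_layout_rowcol_dropout(tensor_.layout()));

    tensor(0, 0, 0) = 1.0f;                                         // write through the (8, MMA_M, MMA_N / 2) view ...
    printf("%f\n", tensor_(make_coord(0, 0), make_coord(0, 0)));    // ... and it is visible in the rowcol view
}

Because make_tensor(tensor_.data(), new_layout) only rewraps the same registers, whatever apply_dropout zeroes (or sign-flips) through the converted view is exactly what the caller sees in its original rowcol-layout tensor.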
csrc/flash_attn/src/utils.h

@@ -211,6 +211,20 @@ inline __device__ auto convert_layout_rowcol_Aregs(Layout rowcol_layout) {
 ////////////////////////////////////////////////////////////////////////////////////////////////////

+// Convert rowcol_layout from (nrow=(2, MMA_M), ncol=(2, MMA_N)) to ((2, 2, 2), MMA_M, MMA_N / 2)
+template <typename Layout>
+inline __device__ auto convert_layout_rowcol_dropout(Layout rowcol_layout) {
+    using X = Underscore;
+    static_assert(decltype(size<0, 0>(rowcol_layout))::value == 2);
+    static_assert(decltype(size<1, 0>(rowcol_layout))::value == 2);
+    auto l = logical_divide(rowcol_layout, Shape<X, Shape<X, Int<2>>>{});  // ((2, MMA_M), (2, (2, MMA_N / 2)))
+    return make_layout(make_layout(get<1, 0>(l), get<0, 0>(l), get<1, 1, 0>(l)),
+                       get<0, 1>(l),
+                       get<1, 1, 1>(l));
+};
+
+////////////////////////////////////////////////////////////////////////////////////////////////////
+
 template <typename To_type, typename Engine, typename Layout>
 inline __device__ auto convert_type(Tensor<Engine, Layout> const &tensor) {
     using From_type = typename Engine::value_type;
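To see the layout transform concretely, this second host-only sketch (again illustrative, not from the commit) prints a rowcol layout before and after convert_layout_rowcol_dropout, using made-up sizes MMA_M = 2, MMA_N = 4 with compact strides. As the comment in the hunk above states, (nrow=(2, MMA_M), ncol=(2, MMA_N)) becomes ((2, 2, 2), MMA_M, MMA_N / 2); the division of the MMA_N mode by 2 is also why MMA_N must be even, matching the static_assert(MMA_N_SdP % 2 == 0) in the backward-kernel hunks.

// Host-only sketch (illustrative, not part of this commit).
#include <cstdio>
#include <cute/tensor.hpp>

using namespace cute;

// Host copy of convert_layout_rowcol_dropout from the hunk above (__device__ dropped).
template <typename Layout>
auto convert_layout_rowcol_dropout(Layout rowcol_layout) {
    using X = Underscore;
    static_assert(decltype(size<0, 0>(rowcol_layout))::value == 2);
    static_assert(decltype(size<1, 0>(rowcol_layout))::value == 2);
    auto l = logical_divide(rowcol_layout, Shape<X, Shape<X, Int<2>>>{});  // ((2, MMA_M), (2, (2, MMA_N / 2)))
    return make_layout(make_layout(get<1, 0>(l), get<0, 0>(l), get<1, 1, 0>(l)),
                       get<0, 1>(l), get<1, 1, 1>(l));
}

int main() {
    // Rowcol layout (nrow=(2, MMA_M), ncol=(2, MMA_N)) with MMA_M = 2, MMA_N = 4; compact default
    // strides are enough for the demo, even though real register layouts have MMA-specific strides.
    auto rowcol = make_layout(make_shape(make_shape(Int<2>{}, Int<2>{}),
                                         make_shape(Int<2>{}, Int<4>{})));
    print(rowcol);                                printf("\n");  // ((_2,_2),(_2,_4)):...
    print(convert_layout_rowcol_dropout(rowcol)); printf("\n");  // ((_2,_2,_2),_2,_2):...
}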
tests/test_flash_attn.py

@@ -545,7 +545,7 @@ def get_dropout_fraction(
 @pytest.mark.parametrize("dtype", ([torch.float16] if is_sm75 else [torch.float16, torch.bfloat16]))
 # @pytest.mark.parametrize("dtype", [torch.float16])
 @pytest.mark.parametrize("deterministic", [False, True])
-# @pytest.mark.parametrize("deterministic", [True])
+# @pytest.mark.parametrize("deterministic", [False])
 @pytest.mark.parametrize("alibi", [False, True])
 # @pytest.mark.parametrize("alibi", [False])
 @pytest.mark.parametrize("local", [False, True])