gaoqiong / flash-attention

Commit 6f706eff, authored Jan 15, 2024 by Tri Dao
parent 4ea866ca

    Make Softmax an object

Showing 2 changed files with 66 additions and 68 deletions:

    csrc/flash_attn/src/flash_fwd_kernel.h   +12 -43
    csrc/flash_attn/src/softmax.h            +54 -25
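In short, the commit moves the per-kernel `scores_max`/`scores_sum` fragments and the hand-rolled LSE epilogue into a `flash::Softmax<kNRows>` object that owns the running row statistics (`row_max`, `row_sum`). A condensed view of the call-site change, assembled from the diff below:

```cpp
// Before: the kernel owns the running statistics and the epilogue loop.
Tensor scores_max = make_tensor<ElementAccum>(Shape<Int<2 * size<1>(acc_o)>>{});
Tensor scores_sum = make_fragment_like(scores_max);
flash::softmax_rescale_o</*Is_first=*/true>(scores, scores_max, scores_sum, acc_o, params.scale_softmax_log2);
// ... per-row epilogue loop computing lse and scaling acc_o ...

// After: the Softmax object owns row_max/row_sum and the epilogue.
flash::Softmax<2 * size<1>(acc_o)> softmax;
softmax.template softmax_rescale_o</*Is_first=*/true>(acc_s, acc_o, params.scale_softmax_log2);
Tensor lse = softmax.template normalize_softmax_lse<Is_dropout>(acc_o, params.scale_softmax, params.rp_dropout);
```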
csrc/flash_attn/src/flash_fwd_kernel.h
@@ -180,10 +180,6 @@ inline __device__ void compute_attn_1rowblock(const Params &params, const int bi
     auto smem_thr_copy_V = smem_tiled_copy_V.get_thread_slice(tidx);
     Tensor tOsVt = smem_thr_copy_V.partition_S(sVt);

-    // TODO: this might need to change if we change the mma instruction in SM70
-    Tensor scores_max = make_tensor<ElementAccum>(Shape<Int<2 * size<1>(acc_o)>>{});
-    Tensor scores_sum = make_fragment_like(scores_max);
-
     //
     // PREDICATES
     //
@@ -267,6 +263,8 @@ inline __device__ void compute_attn_1rowblock(const Params &params, const int bi
     clear(acc_o);

+    flash::Softmax<2 * size<1>(acc_o)> softmax;
+
     const float alibi_slope = !Has_alibi ? 0.0f : reinterpret_cast<float *>(params.alibi_slopes_ptr)[bidb * params.alibi_slopes_batch_stride + bidh] / params.scale_softmax;
     flash::Alibi<Is_causal> alibi(alibi_slope, binfo.actual_seqlen_k, binfo.actual_seqlen_q);
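The template argument `2 * size<1>(acc_o)` follows from the row-major view used throughout this file: `convert_layout_acc_rowcol` reshapes the (MMA=4, MMA_M, MMA_K) accumulator to (nrow=(2, MMA_M), ncol=(2, MMA_K)), so each thread carries `2 * MMA_M` rows of running max/sum statistics.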
@@ -357,8 +355,8 @@ inline __device__ void compute_attn_1rowblock(const Params &params, const int bi
         // TODO: when we have key_padding_mask we'll need to Check_inf
         masking_step == 0
-            ? flash::softmax_rescale_o</*Is_first=*/true,  /*Check_inf=*/Is_causal || Is_local>(scores, scores_max, scores_sum, acc_o, params.scale_softmax_log2)
-            : flash::softmax_rescale_o</*Is_first=*/false, /*Check_inf=*/Is_causal || Is_local>(scores, scores_max, scores_sum, acc_o, params.scale_softmax_log2);
+            ? softmax.template softmax_rescale_o</*Is_first=*/true,  /*Check_inf=*/Is_causal || Is_local>(acc_s, acc_o, params.scale_softmax_log2)
+            : softmax.template softmax_rescale_o</*Is_first=*/false, /*Check_inf=*/Is_causal || Is_local>(acc_s, acc_o, params.scale_softmax_log2);

         // Convert scores from fp32 to fp16/bf16
         Tensor rP = flash::convert_type<Element>(scores);
@@ -435,7 +433,7 @@ inline __device__ void compute_attn_1rowblock(const Params &params, const int bi
             );
         }

-        flash::softmax_rescale_o</*Is_first=*/false, /*Check_inf=*/Is_local>(scores, scores_max, scores_sum, acc_o, params.scale_softmax_log2);
+        softmax.template softmax_rescale_o</*Is_first=*/false, /*Check_inf=*/Is_local>(acc_s, acc_o, params.scale_softmax_log2);

         Tensor rP = flash::convert_type<Element>(scores);
         int block_row_idx = m_block * (kBlockM / 16) + tidx / 32;
@@ -461,20 +459,7 @@ inline __device__ void compute_attn_1rowblock(const Params &params, const int bi
     // Epilogue

-    // Reshape acc_o from (MMA=4, MMA_M, MMA_K) to (nrow=(2, MMA_M), ncol=(2, MMA_K))
-    Tensor acc_o_rowcol = make_tensor(acc_o.data(), flash::convert_layout_acc_rowcol(acc_o.layout()));
-    Tensor lse = make_fragment_like(scores_sum);
-    #pragma unroll
-    for (int mi = 0; mi < size<0>(acc_o_rowcol); ++mi) {
-        float sum = scores_sum(mi);
-        float inv_sum = (sum == 0.f || sum != sum) ? 1.f : 1.f / sum;
-        lse(mi) = (sum == 0.f || sum != sum) ? INFINITY : scores_max(mi) * params.scale_softmax + __logf(sum);
-        float scale = !Is_dropout ? inv_sum : inv_sum * params.rp_dropout;
-        #pragma unroll
-        for (int ni = 0; ni < size<1>(acc_o_rowcol); ++ni) { acc_o_rowcol(mi, ni) *= scale; }
-    }
-    // if (cute::thread0()) { print(acc_o_rowcol); }
+    Tensor lse = softmax.template normalize_softmax_lse<Is_dropout>(acc_o, params.scale_softmax, params.rp_dropout);

     // Convert acc_o from fp32 to fp16/bf16
     Tensor rO = flash::convert_type<Element>(acc_o);
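For reference, the epilogue that `normalize_softmax_lse` now encapsulates computes, for each row $i$ with running max $m_i$ and running sum $s_i$:

$$\mathrm{lse}_i = m_i \cdot \mathrm{softmax\_scale} + \log s_i, \qquad o_i \leftarrow \frac{o_i}{s_i}\ (\times\, \mathrm{rp\_dropout}\ \text{under dropout}),$$

with the degenerate case $s_i = 0$ or $s_i = \mathrm{NaN}$ mapped to $\mathrm{lse}_i = \infty$ and scale $1$, exactly as the removed loop above did.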
@@ -685,11 +670,6 @@ inline __device__ void compute_attn_1rowblock_splitkv(const Params &params, cons
     auto smem_thr_copy_V = smem_tiled_copy_V.get_thread_slice(tidx);
     Tensor tOsVt = smem_thr_copy_V.partition_S(sVt);

-    // TODO: this might need to change if we change the mma instruction in SM70
-    Tensor scores_max = make_tensor<ElementAccum>(Shape<Int<2 * size<1>(acc_o)>>{});
-    Tensor scores_sum = make_fragment_like(scores_max);
-
     //
     // PREDICATES
     //
@@ -862,6 +842,8 @@ inline __device__ void compute_attn_1rowblock_splitkv(const Params &params, cons
     clear(acc_o);

+    flash::Softmax<2 * size<1>(acc_o)> softmax;
+
     const float alibi_slope = !Has_alibi ? 0.0f : reinterpret_cast<float *>(params.alibi_slopes_ptr)[bidb * params.alibi_slopes_batch_stride + bidh] / params.scale_softmax;
     flash::Alibi<Is_causal> alibi(alibi_slope, binfo.actual_seqlen_k, binfo.actual_seqlen_q);
@@ -939,8 +921,8 @@ inline __device__ void compute_attn_1rowblock_splitkv(const Params &params, cons
         // We have key_padding_mask so we'll need to Check_inf
         masking_step == 0
-            ? flash::softmax_rescale_o</*Is_first=*/true,  /*Check_inf=*/Is_causal || Is_local || !Is_even_MN>(scores, scores_max, scores_sum, acc_o, params.scale_softmax_log2)
-            : flash::softmax_rescale_o</*Is_first=*/false, /*Check_inf=*/Is_causal || Is_local || !Is_even_MN>(scores, scores_max, scores_sum, acc_o, params.scale_softmax_log2);
+            ? softmax.template softmax_rescale_o</*Is_first=*/true,  /*Check_inf=*/Is_causal || Is_local || !Is_even_MN>(acc_s, acc_o, params.scale_softmax_log2)
+            : softmax.template softmax_rescale_o</*Is_first=*/false, /*Check_inf=*/Is_causal || Is_local || !Is_even_MN>(acc_s, acc_o, params.scale_softmax_log2);

         // if (cute::thread0()) { print(scores_max); print(scores_sum); print(scores); }
         // Convert scores from fp32 to fp16/bf16
@@ -1002,7 +984,7 @@ inline __device__ void compute_attn_1rowblock_splitkv(const Params &params, cons
                                 params.window_size_left, params.window_size_right);
         }

-        softmax_rescale_o</*Is_first=*/false, /*Check_inf=*/Is_local>(scores, scores_max, scores_sum, acc_o, params.scale_softmax_log2);
+        softmax.template softmax_rescale_o</*Is_first=*/false, /*Check_inf=*/Is_local>(acc_s, acc_o, params.scale_softmax_log2);

         Tensor rP = flash::convert_type<Element>(scores);
         // Reshape rP from (nrow=(2, MMA_M), ncol=(2, MMA_N)) to ((2, 2, 2), MMA_M, MMA_N / 2)
@@ -1014,21 +996,8 @@ inline __device__ void compute_attn_1rowblock_splitkv(const Params &params, cons
     // Epilogue

-    // Reshape acc_o from (MMA=4, MMA_M, MMA_K) to (nrow=(2, MMA_M), ncol=(2, MMA_K))
-    Tensor acc_o_rowcol = make_tensor(acc_o.data(), flash::convert_layout_acc_rowcol(acc_o.layout()));
-    // if (cute::thread0()) { print(acc_o_rowcol); }
-    Tensor lse = make_fragment_like(scores_sum);
-    #pragma unroll
-    for (int mi = 0; mi < size<0>(acc_o_rowcol); ++mi) {
-        float sum = scores_sum(mi);
-        float inv_sum = (sum == 0.f || sum != sum) ? 1.f : 1.f / sum;
-        lse(mi) = (sum == 0.f || sum != sum) ? (Split ? -INFINITY : INFINITY) : scores_max(mi) * params.scale_softmax + __logf(sum);
-        float scale = inv_sum;
-        #pragma unroll
-        for (int ni = 0; ni < size<1>(acc_o_rowcol); ++ni) { acc_o_rowcol(mi, ni) *= scale; }
-    }
+    Tensor lse = softmax.template normalize_softmax_lse</*Is_dropout=*/false, Split>(acc_o, params.scale_softmax);
+    // if (cute::thread0()) { print(lse); }

     Tensor sOaccum = make_tensor(make_smem_ptr(reinterpret_cast<ElementO *>(smem_)), typename Kernel_traits::SmemLayoutO{});  // (SMEM_M,SMEM_N)
     // Partition sO to match the accumulator partitioning
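One behavioral detail preserved by the refactor: under `Split`, a row whose sum is zero or NaN gets `lse = -INFINITY` rather than `INFINITY`, presumably so that the split-k combine step, which weights partial results by $e^{\mathrm{lse}}$, assigns such a partial result zero weight:

$$\mathrm{lse}_i = \begin{cases} -\infty & s_i \in \{0, \mathrm{NaN}\},\ \text{Split} \\ +\infty & s_i \in \{0, \mathrm{NaN}\},\ \text{no Split} \\ m_i \cdot \mathrm{softmax\_scale} + \log s_i & \text{otherwise.} \end{cases}$$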
csrc/flash_attn/src/softmax.h
@@ -117,35 +117,64 @@ inline __device__ void max_scale_exp2_sum(Tensor<Engine0, Layout0> &tensor, Tens

 ////////////////////////////////////////////////////////////////////////////////////////////////////

-template <bool Is_first, bool Check_inf=false, typename Tensor0, typename Tensor1, typename Tensor2>
-inline __device__ void softmax_rescale_o(Tensor0 &scores, Tensor1 &scores_max, Tensor1 &scores_sum,
-                                         Tensor2 &acc_o, float softmax_scale_log2) {
-    if (Is_first) {
-        flash::template reduce_max</*zero_init=*/true>(scores, scores_max);
-        flash::scale_apply_exp2(scores, scores_max, softmax_scale_log2);
-        flash::reduce_sum(scores, scores_sum);
-    } else {
-        Tensor scores_max_prev = make_fragment_like(scores_max);
-        cute::copy(scores_max, scores_max_prev);
-        flash::template reduce_max</*zero_init=*/false>(scores, scores_max);
-        // Reshape acc_o from (MMA=4, MMA_M, MMA_K) to (nrow=(2, MMA_M), ncol=(2, MMA_K))
-        Tensor acc_o_rowcol = make_tensor(acc_o.data(), flash::convert_layout_acc_rowcol(acc_o.layout()));
-        #pragma unroll
-        for (int mi = 0; mi < size(scores_max); ++mi) {
-            float scores_max_cur = !Check_inf
-                ? scores_max(mi)
-                : (scores_max(mi) == -INFINITY ? 0.0f : scores_max(mi));
-            float scores_scale = exp2f((scores_max_prev(mi) - scores_max_cur) * softmax_scale_log2);
-            scores_sum(mi) *= scores_scale;
-            #pragma unroll
-            for (int ni = 0; ni < size<1>(acc_o_rowcol); ++ni) { acc_o_rowcol(mi, ni) *= scores_scale; }
-        }
-        flash::scale_apply_exp2(scores, scores_max, softmax_scale_log2);
-        Tensor scores_sum_cur = make_fragment_like(scores_sum);
-        flash::reduce_sum(scores, scores_sum_cur);
-        #pragma unroll
-        for (int mi = 0; mi < size(scores_sum); ++mi) { scores_sum(mi) += scores_sum_cur(mi); }
-    }
-}
+template <int kNRows>
+struct Softmax {
+
+    using TensorT = decltype(make_tensor<float>(Shape<Int<kNRows>>{}));
+    TensorT row_max, row_sum;
+
+    inline __device__ Softmax() {};
+
+    template<bool Is_first, bool Check_inf=false, typename Tensor0, typename Tensor1>
+    inline __device__ void softmax_rescale_o(Tensor0 &acc_s, Tensor1 &acc_o, float softmax_scale_log2) {
+        // Reshape acc_s from (MMA=4, MMA_M, MMA_N) to (nrow=(2, MMA_M), ncol=(2, MMA_N))
+        Tensor scores = make_tensor(acc_s.data(), flash::convert_layout_acc_rowcol(acc_s.layout()));
+        static_assert(decltype(size<0>(scores))::value == kNRows);
+        if (Is_first) {
+            flash::template reduce_max</*zero_init=*/true>(scores, row_max);
+            flash::scale_apply_exp2(scores, row_max, softmax_scale_log2);
+            flash::reduce_sum(scores, row_sum);
+        } else {
+            Tensor scores_max_prev = make_fragment_like(row_max);
+            cute::copy(row_max, scores_max_prev);
+            flash::template reduce_max</*zero_init=*/false>(scores, row_max);
+            // Reshape acc_o from (MMA=4, MMA_M, MMA_K) to (nrow=(2, MMA_M), ncol=(2, MMA_K))
+            Tensor acc_o_rowcol = make_tensor(acc_o.data(), flash::convert_layout_acc_rowcol(acc_o.layout()));
+            static_assert(decltype(size<0>(acc_o_rowcol))::value == kNRows);
+            #pragma unroll
+            for (int mi = 0; mi < size(row_max); ++mi) {
+                float scores_max_cur = !Check_inf
+                    ? row_max(mi)
+                    : (row_max(mi) == -INFINITY ? 0.0f : row_max(mi));
+                float scores_scale = exp2f((scores_max_prev(mi) - scores_max_cur) * softmax_scale_log2);
+                row_sum(mi) *= scores_scale;
+                #pragma unroll
+                for (int ni = 0; ni < size<1>(acc_o_rowcol); ++ni) { acc_o_rowcol(mi, ni) *= scores_scale; }
+            }
+            flash::scale_apply_exp2(scores, row_max, softmax_scale_log2);
+            Tensor scores_sum_cur = make_fragment_like(row_sum);
+            flash::reduce_sum(scores, scores_sum_cur);
+            #pragma unroll
+            for (int mi = 0; mi < size(row_sum); ++mi) { row_sum(mi) += scores_sum_cur(mi); }
+        }
+    };
+
+    template<bool Is_dropout=false, bool Split=false, typename Tensor0>
+    inline __device__ TensorT normalize_softmax_lse(Tensor0 &acc_o, float softmax_scale, float rp_dropout=1.0) {
+        TensorT lse = make_fragment_like(row_sum);
+        Tensor acc_o_rowcol = make_tensor(acc_o.data(), flash::convert_layout_acc_rowcol(acc_o.layout()));
+        static_assert(decltype(size<0>(acc_o_rowcol))::value == kNRows);
+        #pragma unroll
+        for (int mi = 0; mi < size<0>(acc_o_rowcol); ++mi) {
+            float sum = row_sum(mi);
+            float inv_sum = (sum == 0.f || sum != sum) ? 1.f : 1.f / sum;
+            lse(mi) = (sum == 0.f || sum != sum) ? (Split ? -INFINITY : INFINITY) : row_max(mi) * softmax_scale + __logf(sum);
+            float scale = !Is_dropout ? inv_sum : inv_sum * rp_dropout;
+            #pragma unroll
+            for (int ni = 0; ni < size<1>(acc_o_rowcol); ++ni) { acc_o_rowcol(mi, ni) *= scale; }
+        }
+        return lse;
+    };
+
+};

 }  // namespace flash
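To make the recurrence inside `softmax_rescale_o` / `normalize_softmax_lse` easier to follow outside the CUTLASS/cute machinery, here is a minimal host-side C++ sketch of the same online-softmax algorithm. It is illustrative only: plain `std::vector` rows stand in for cute tensors, the `softmax_scale_log2`/`exp2f` optimization is dropped in favor of plain `exp`, and the names (`OnlineSoftmax`, `rescale_o`, `normalize_lse`) are hypothetical.

```cpp
// Sketch of the online-softmax recurrence the Softmax struct implements per
// row: keep a running max m and running sum s of exp(x - m); when a new block
// raises the max, rescale s and the output accumulator by exp(m_old - m_new)
// before folding the new block in.
#include <cmath>
#include <cstddef>
#include <cstdio>
#include <vector>

struct OnlineSoftmax {
    float row_max = -INFINITY;   // running max, like Softmax::row_max
    float row_sum = 0.f;         // running sum of exponentials, like Softmax::row_sum
    std::vector<float> acc_o;    // unnormalized output row

    // Analogue of softmax_rescale_o: fold one block of scores and values in.
    void rescale_o(const std::vector<float>& scores,
                   const std::vector<std::vector<float>>& v) {
        float new_max = row_max;
        for (float x : scores) new_max = std::fmax(new_max, x);
        float scale = std::exp(row_max - new_max);  // exp(-inf) == 0 on the first block
        row_sum *= scale;
        for (float& o : acc_o) o *= scale;
        for (std::size_t j = 0; j < scores.size(); ++j) {
            float p = std::exp(scores[j] - new_max);
            row_sum += p;
            for (std::size_t d = 0; d < acc_o.size(); ++d) acc_o[d] += p * v[j][d];
        }
        row_max = new_max;
    }

    // Analogue of normalize_softmax_lse: normalize and return the log-sum-exp.
    float normalize_lse() {
        bool degenerate = (row_sum == 0.f || row_sum != row_sum);  // zero or NaN sum
        float inv_sum = degenerate ? 1.f : 1.f / row_sum;
        for (float& o : acc_o) o *= inv_sum;
        return degenerate ? INFINITY : row_max + std::log(row_sum);
    }
};

int main() {
    OnlineSoftmax sm;
    sm.acc_o.assign(2, 0.f);
    sm.rescale_o({1.f, 2.f}, {{1.f, 0.f}, {0.f, 1.f}});
    sm.rescale_o({3.f, 0.5f}, {{1.f, 1.f}, {0.f, 0.f}});  // raises the max, rescales block 1
    float lse = sm.normalize_lse();
    std::printf("lse = %f, o = (%f, %f)\n", lse, sm.acc_o[0], sm.acc_o[1]);
}
```

Folding the statistics into an object, as the commit does, keeps the kernel bodies free of bookkeeping while leaving the numerical behavior unchanged.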