gaoqiong / flash-attention · Commits

Commit df1418f9, authored Jan 14, 2024 by Tri Dao

    Move softmax_rescale_o to softmax.h

Parent: 6777336a
Showing 2 changed files with 39 additions and 39 deletions:

    csrc/flash_attn/src/flash_fwd_kernel.h   +5   -38
    csrc/flash_attn/src/softmax.h            +34  -1
csrc/flash_attn/src/flash_fwd_kernel.h

@@ -25,39 +25,6 @@ using namespace cute;
 ////////////////////////////////////////////////////////////////////////////////////////////////////

-template <bool Is_first, bool Check_inf=false, typename Tensor0, typename Tensor1, typename Tensor2>
-inline __device__ void softmax_rescale_o(Tensor0 &scores, Tensor1 &scores_max, Tensor1 &scores_sum,
-                                         Tensor2 &acc_o, float softmax_scale_log2) {
-    if (Is_first) {
-        flash::template reduce_max</*zero_init=*/true>(scores, scores_max);
-        flash::scale_apply_exp2(scores, scores_max, softmax_scale_log2);
-        flash::reduce_sum(scores, scores_sum);
-    } else {
-        Tensor scores_max_prev = make_fragment_like(scores_max);
-        cute::copy(scores_max, scores_max_prev);
-        flash::template reduce_max</*zero_init=*/false>(scores, scores_max);
-        // Reshape acc_o from (MMA=4, MMA_M, MMA_K) to (nrow=(2, MMA_M), ncol=(2, MMA_K))
-        Tensor acc_o_rowcol = make_tensor(acc_o.data(), flash::convert_layout_acc_rowcol(acc_o.layout()));
-        #pragma unroll
-        for (int mi = 0; mi < size(scores_max); ++mi) {
-            float scores_max_cur = !Check_inf
-                ? scores_max(mi)
-                : (scores_max(mi) == -INFINITY ? 0.0f : scores_max(mi));
-            float scores_scale = exp2f((scores_max_prev(mi) - scores_max_cur) * softmax_scale_log2);
-            scores_sum(mi) *= scores_scale;
-            #pragma unroll
-            for (int ni = 0; ni < size<1>(acc_o_rowcol); ++ni) { acc_o_rowcol(mi, ni) *= scores_scale; }
-        }
-        flash::scale_apply_exp2(scores, scores_max, softmax_scale_log2);
-        Tensor scores_sum_cur = make_fragment_like(scores_sum);
-        flash::reduce_sum(scores, scores_sum_cur);
-        #pragma unroll
-        for (int mi = 0; mi < size(scores_sum); ++mi) { scores_sum(mi) += scores_sum_cur(mi); }
-    }
-};
-
-////////////////////////////////////////////////////////////////////////////////////////////////////
-
 template<typename Kernel_traits, bool Is_dropout, bool Is_causal, bool Is_local, bool Has_alibi, bool Is_even_MN, bool Is_even_K, bool Return_softmax, typename Params>
 inline __device__ void compute_attn_1rowblock(const Params &params, const int bidb, const int bidh, const int m_block) {
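The function being moved is the "online softmax" update at the core of FlashAttention: scores arrive one key-block at a time, so the running row max can grow, and the previously accumulated row sum and output must be rescaled by exp(old_max - new_max) before the new block is folded in. Below is a minimal host-side C++ sketch of that update for a single row, with a scalar standing in for the acc_o accumulator; the names Row and online_softmax_step are invented for illustration and correspond to nothing in this repo.

// Illustrative sketch only -- not part of this commit. One row, plain floats,
// a 1-element "output" standing in for acc_o.
#include <cmath>
#include <cstdio>
#include <vector>

struct Row {
    float row_max = -INFINITY;  // running max of all scores seen so far
    float row_sum = 0.0f;       // running sum of exp(score - row_max)
    float acc_o   = 0.0f;       // unnormalized output accumulator
};

// Mirror of the Is_first=false path: bump the running max, rescale the old
// sum and accumulator by exp(old_max - new_max), then fold in the new block.
// On the very first block old_max is -inf, so the rescale factor is
// exp(-inf) = 0 and the (empty) prefix contributes nothing.
void online_softmax_step(Row &r, const std::vector<float> &scores,
                         const std::vector<float> &values) {
    float new_max = r.row_max;
    for (float s : scores) new_max = std::fmax(new_max, s);
    const float scale = std::exp(r.row_max - new_max);
    r.row_sum *= scale;
    r.acc_o   *= scale;
    r.row_max  = new_max;
    for (size_t i = 0; i < scores.size(); ++i) {
        const float p = std::exp(scores[i] - r.row_max);
        r.row_sum += p;
        r.acc_o   += p * values[i];
    }
}

int main() {
    // Two blocks processed online must match one softmax over all four scores.
    Row r;
    online_softmax_step(r, {1.0f, 3.0f}, {10.0f, 20.0f});
    online_softmax_step(r, {2.0f, 5.0f}, {30.0f, 40.0f});
    std::printf("online:   %f\n", r.acc_o / r.row_sum);

    std::vector<float> s{1.0f, 3.0f, 2.0f, 5.0f}, v{10.0f, 20.0f, 30.0f, 40.0f};
    float m = -INFINITY, sum = 0.0f, out = 0.0f;
    for (float x : s) m = std::fmax(m, x);
    for (size_t i = 0; i < s.size(); ++i) {
        const float p = std::exp(s[i] - m);
        sum += p;
        out += p * v[i];
    }
    std::printf("one-shot: %f\n", out / sum);  // prints the same value
}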
@@ -396,8 +363,8 @@ inline __device__ void compute_attn_1rowblock(const Params &params, const int bi
         // TODO: when we have key_padding_mask we'll need to Check_inf
         masking_step == 0
-            ? softmax_rescale_o</*Is_first=*/true,  /*Check_inf=*/Is_causal || Is_local>(scores, scores_max, scores_sum, acc_o, params.scale_softmax_log2)
-            : softmax_rescale_o</*Is_first=*/false, /*Check_inf=*/Is_causal || Is_local>(scores, scores_max, scores_sum, acc_o, params.scale_softmax_log2);
+            ? flash::softmax_rescale_o</*Is_first=*/true,  /*Check_inf=*/Is_causal || Is_local>(scores, scores_max, scores_sum, acc_o, params.scale_softmax_log2)
+            : flash::softmax_rescale_o</*Is_first=*/false, /*Check_inf=*/Is_causal || Is_local>(scores, scores_max, scores_sum, acc_o, params.scale_softmax_log2);

         // Convert scores from fp32 to fp16/bf16
         Tensor rP = flash::convert_type<Element>(scores);
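One detail worth noting at these call sites: the last argument, params.scale_softmax_log2, is the softmax scale pre-multiplied by log2(e) on the host side, so that exp2f(x * softmax_scale_log2) equals expf(x * softmax_scale) while mapping to the GPU's native 2^x instruction. A tiny host-side check of that identity (illustrative only; 0.125f is just an example scale such as 1/sqrt(64)):

// Illustrative check, not part of the commit: exp2f with a log2(e)-premultiplied
// scale reproduces expf with the plain scale.
#include <cmath>
#include <cstdio>

int main() {
    const float softmax_scale      = 0.125f;  // e.g. 1/sqrt(head_dim = 64)
    const float softmax_scale_log2 = softmax_scale * 1.4426950408889634f;  // * log2(e)
    const float diff = -1.7f;  // stands in for scores_max_prev(mi) - scores_max_cur

    std::printf("exp : %.9f\n", std::exp(diff * softmax_scale));
    std::printf("exp2: %.9f\n", std::exp2(diff * softmax_scale_log2));  // same value
}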
@@ -481,7 +448,7 @@ inline __device__ void compute_attn_1rowblock(const Params &params, const int bi
             );
         }
-        softmax_rescale_o</*Is_first=*/false, /*Check_inf=*/Is_local>(scores, scores_max, scores_sum, acc_o, params.scale_softmax_log2);
+        flash::softmax_rescale_o</*Is_first=*/false, /*Check_inf=*/Is_local>(scores, scores_max, scores_sum, acc_o, params.scale_softmax_log2);

         Tensor rP = flash::convert_type<Element>(scores);
         int block_row_idx = m_block * (kBlockM / 16) + tidx / 32;
@@ -991,8 +958,8 @@ inline __device__ void compute_attn_1rowblock_splitkv(const Params &params, cons
         // We have key_padding_mask so we'll need to Check_inf
         masking_step == 0
-            ? softmax_rescale_o</*Is_first=*/true,  /*Check_inf=*/Is_causal || Is_local || !Is_even_MN>(scores, scores_max, scores_sum, acc_o, params.scale_softmax_log2)
-            : softmax_rescale_o</*Is_first=*/false, /*Check_inf=*/Is_causal || Is_local || !Is_even_MN>(scores, scores_max, scores_sum, acc_o, params.scale_softmax_log2);
+            ? flash::softmax_rescale_o</*Is_first=*/true,  /*Check_inf=*/Is_causal || Is_local || !Is_even_MN>(scores, scores_max, scores_sum, acc_o, params.scale_softmax_log2)
+            : flash::softmax_rescale_o</*Is_first=*/false, /*Check_inf=*/Is_causal || Is_local || !Is_even_MN>(scores, scores_max, scores_sum, acc_o, params.scale_softmax_log2);
         // if (cute::thread0()) { print(scores_max); print(scores_sum); print(scores); }

         // Convert scores from fp32 to fp16/bf16
csrc/flash_attn/src/softmax.h
 /******************************************************************************
- * Copyright (c) 2023, Tri Dao.
+ * Copyright (c) 2024, Tri Dao.
  ******************************************************************************/

 #pragma once
@@ -115,4 +115,37 @@ inline __device__ void max_scale_exp2_sum(Tensor<Engine0, Layout0> &tensor, Tens
     }
 }

 ////////////////////////////////////////////////////////////////////////////////////////////////////

+template <bool Is_first, bool Check_inf=false, typename Tensor0, typename Tensor1, typename Tensor2>
+inline __device__ void softmax_rescale_o(Tensor0 &scores, Tensor1 &scores_max, Tensor1 &scores_sum,
+                                         Tensor2 &acc_o, float softmax_scale_log2) {
+    if (Is_first) {
+        flash::template reduce_max</*zero_init=*/true>(scores, scores_max);
+        flash::scale_apply_exp2(scores, scores_max, softmax_scale_log2);
+        flash::reduce_sum(scores, scores_sum);
+    } else {
+        Tensor scores_max_prev = make_fragment_like(scores_max);
+        cute::copy(scores_max, scores_max_prev);
+        flash::template reduce_max</*zero_init=*/false>(scores, scores_max);
+        // Reshape acc_o from (MMA=4, MMA_M, MMA_K) to (nrow=(2, MMA_M), ncol=(2, MMA_K))
+        Tensor acc_o_rowcol = make_tensor(acc_o.data(), flash::convert_layout_acc_rowcol(acc_o.layout()));
+        #pragma unroll
+        for (int mi = 0; mi < size(scores_max); ++mi) {
+            float scores_max_cur = !Check_inf
+                ? scores_max(mi)
+                : (scores_max(mi) == -INFINITY ? 0.0f : scores_max(mi));
+            float scores_scale = exp2f((scores_max_prev(mi) - scores_max_cur) * softmax_scale_log2);
+            scores_sum(mi) *= scores_scale;
+            #pragma unroll
+            for (int ni = 0; ni < size<1>(acc_o_rowcol); ++ni) { acc_o_rowcol(mi, ni) *= scores_scale; }
+        }
+        flash::scale_apply_exp2(scores, scores_max, softmax_scale_log2);
+        Tensor scores_sum_cur = make_fragment_like(scores_sum);
+        flash::reduce_sum(scores, scores_sum_cur);
+        #pragma unroll
+        for (int mi = 0; mi < size(scores_sum); ++mi) { scores_sum(mi) += scores_sum_cur(mi); }
+    }
+};

 }  // namespace flash
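As for the Check_inf template parameter of the function above: when a row is fully masked out (possible under causal or local masking, or with uneven key padding, which is why some call sites pass Check_inf = Is_causal || Is_local || !Is_even_MN), its running max is -INFINITY, and using it directly would compute exp2f((-inf) - (-inf)) = NaN in scores_scale, poisoning the accumulator. Clamping the current max to 0.0f keeps the factor finite, and the row's probabilities are all exp(-inf) = 0 anyway. A minimal host-side illustration, not from the commit:

// Illustrative sketch only: why a -INFINITY row max is replaced by 0.0f when
// Check_inf is set. For a fully masked row both the previous and the current
// running max are -inf, and (-inf) - (-inf) is NaN.
#include <cmath>
#include <cstdio>

int main() {
    const float scores_max_prev = -INFINITY;  // row fully masked so far
    const float scores_max_raw  = -INFINITY;  // ... and still fully masked

    const float without_check = std::exp2(scores_max_prev - scores_max_raw);   // NaN
    const float scores_max_cur = (scores_max_raw == -INFINITY) ? 0.0f : scores_max_raw;
    const float with_check = std::exp2(scores_max_prev - scores_max_cur);      // 0

    std::printf("without Check_inf: %f\n", without_check);  // nan
    std::printf("with Check_inf:    %f\n", with_check);     // 0.000000
}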