gaoqiong / flash-attention
"tests/pipelines/vscode:/vscode.git/clone" did not exist on "249d9bc0e76e55eb16c08d0e70b2d6057259c4a0"
Commit b32efb1a
Don't need to reduce row_sum during online softmax

Authored Feb 20, 2024 by Tri Dao
Parent: f45bbb4c
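Why deferring the cross-thread reduction is safe (an explanatory note, not part of the commit): the online-softmax running sum is only ever rescaled by exp(old_max - new_max) and then added to, and both operations are linear, so each thread can maintain just its own partial row sum; the sum across the threads that share a row can be taken once at normalization time instead of after every key block. Below is a minimal host-side C++ model of that argument, using plain exp instead of the kernel's exp2-based scaling; all names (kThreads, block_max, etc.) are illustrative, not the repo's identifiers.

// Model: one softmax row split across 4 "threads", processed in 2 key blocks.
// Illustrative names only; this is not the CUTE-based code from softmax.h.
#include <cmath>
#include <cstdio>

int main() {
    const int kThreads = 4, kBlocks = 2, kPerThread = 2;
    float scores[kBlocks][kThreads][kPerThread] = {
        {{1.f, 3.f}, {0.5f, 2.f}, {4.f, 1.f}, {2.f, 0.f}},
        {{5.f, 2.f}, {1.f, 6.f}, {0.f, 3.f}, {2.f, 4.f}},
    };

    float row_max = -INFINITY;     // allreduced across threads every block
    float row_sum[kThreads] = {};  // per-thread partials, never reduced mid-loop

    for (int b = 0; b < kBlocks; ++b) {
        // The max still needs a cross-thread allreduce each block, because
        // every thread must rescale with the same running maximum.
        float block_max = row_max;
        for (int t = 0; t < kThreads; ++t)
            for (int i = 0; i < kPerThread; ++i)
                block_max = std::fmax(block_max, scores[b][t][i]);
        // Rescaling is linear, so it can be applied to each thread's partial
        // sum independently; the cross-thread sum commutes with it.
        float scale = std::exp(row_max - block_max);
        for (int t = 0; t < kThreads; ++t) {
            row_sum[t] *= scale;
            for (int i = 0; i < kPerThread; ++i)
                row_sum[t] += std::exp(scores[b][t][i] - block_max);
        }
        row_max = block_max;
    }

    // Deferred reduce: sum across threads only once, at normalization time
    // (the commit moves this into normalize_softmax_lse as a quad allreduce).
    float total = 0.f;
    for (int t = 0; t < kThreads; ++t) total += row_sum[t];

    // Reference: softmax denominator computed directly over the whole row.
    float ref_max = -INFINITY, ref_sum = 0.f;
    for (int b = 0; b < kBlocks; ++b)
        for (int t = 0; t < kThreads; ++t)
            for (int i = 0; i < kPerThread; ++i)
                ref_max = std::fmax(ref_max, scores[b][t][i]);
    for (int b = 0; b < kBlocks; ++b)
        for (int t = 0; t < kThreads; ++t)
            for (int i = 0; i < kPerThread; ++i)
                ref_sum += std::exp(scores[b][t][i] - ref_max);

    std::printf("deferred reduce: %f  reference: %f\n", total, ref_sum);
    return 0;
}

Both totals agree, which is exactly the property the diff below exploits.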
Showing 1 changed file with 8 additions and 7 deletions.

csrc/flash_attn/src/softmax.h (+8, -7)
@@ -55,10 +55,10 @@ __device__ __forceinline__ void reduce_max(Tensor<Engine0, Layout0> const& tenso
     reduce_<zero_init>(tensor, max, max_op);
 }

-template<typename Engine0, typename Layout0, typename Engine1, typename Layout1>
+template<bool zero_init=true, typename Engine0, typename Layout0, typename Engine1, typename Layout1>
 __device__ __forceinline__ void reduce_sum(Tensor<Engine0, Layout0> const& tensor, Tensor<Engine1, Layout1> &sum){
     SumOp<float> sum_op;
-    reduce_(tensor, sum, sum_op);
+    thread_reduce_<zero_init>(tensor, sum, sum_op);
 }

 // Apply the exp to all the elements.
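The substance of this hunk is the renamed call: the old reduce_ did a thread-local fold followed by a cross-thread allreduce, while thread_reduce_ stops after the fold, and the new zero_init template parameter chooses whether the accumulator starts fresh or keeps its running value. A simplified stand-in for the per-thread half, on plain arrays instead of CUTE tensors (thread_reduce_sum is a hypothetical helper, not the repo's code):

#include <cstddef>

// Per-thread fold only: with zero_init=true the accumulator starts fresh,
// with zero_init=false new elements accumulate onto the running value.
// No shuffle-based exchange happens here; that is deferred to the caller.
template <bool zero_init = true, typename T, std::size_t N>
void thread_reduce_sum(const T (&vals)[N], T &sum) {
    if (zero_init) { sum = T(0); }
    for (std::size_t i = 0; i < N; ++i) { sum += vals[i]; }
}

With this switch, the call sites in the next hunks can pass /*zero_init=*/true for the first key block and /*zero_init=*/false afterwards, accumulating onto the rescaled running sum with no cross-thread traffic inside the loop.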
@@ -133,7 +133,7 @@ struct Softmax {
         if (Is_first) {
             flash::template reduce_max</*zero_init=*/true>(scores, row_max);
             flash::scale_apply_exp2(scores, row_max, softmax_scale_log2);
-            flash::reduce_sum(scores, row_sum);
+            flash::reduce_sum</*zero_init=*/true>(scores, row_sum);
         } else {
             Tensor scores_max_prev = make_fragment_like(row_max);
             cute::copy(row_max, scores_max_prev);
@@ -152,15 +152,16 @@ struct Softmax {
                 for (int ni = 0; ni < size<1>(acc_o_rowcol); ++ni) { acc_o_rowcol(mi, ni) *= scores_scale; }
             }
             flash::scale_apply_exp2(scores, row_max, softmax_scale_log2);
-            Tensor scores_sum_cur = make_fragment_like(row_sum);
-            flash::reduce_sum(scores, scores_sum_cur);
-            #pragma unroll
-            for (int mi = 0; mi < size(row_sum); ++mi) { row_sum(mi) += scores_sum_cur(mi); }
+            // We don't do the reduce across threads here since we don't need to use the row_sum.
+            // We do that reduce at the end when we need to normalize the softmax.
+            flash::reduce_sum</*zero_init=*/false>(scores, row_sum);
         }
     };

     template<bool Is_dropout=false, bool Split=false, typename Tensor0>
     __forceinline__ __device__ TensorT normalize_softmax_lse(Tensor0 &acc_o, float softmax_scale, float rp_dropout=1.0) {
+        SumOp<float> sum_op;
+        quad_allreduce_(row_sum, row_sum, sum_op);
         TensorT lse = make_fragment_like(row_sum);
         Tensor acc_o_rowcol = make_tensor(acc_o.data(), flash::convert_layout_acc_rowcol(acc_o.layout()));
         static_assert(decltype(size<0>(acc_o_rowcol))::value == kNRows);
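The two lines added to normalize_softmax_lse are where the deferred reduction finally happens: quad_allreduce_ combines the partial row sums held by the threads that share each softmax row (four lanes of a warp in this kernel's layout). In spirit it is a two-step butterfly of warp shuffles; a hedged CUDA sketch of that primitive, not the CUTE-fragment version in the repo:

// Butterfly allreduce over a quad of lanes: after two XOR-shuffle steps,
// all four lanes in the quad hold the sum of their four partial values.
__device__ __forceinline__ float quad_allreduce_sum(float x) {
    x += __shfl_xor_sync(0xffffffffu, x, 2);  // combine with lane ^ 2
    x += __shfl_xor_sync(0xffffffffu, x, 1);  // combine with lane ^ 1
    return x;
}

Running this once per output tile, rather than once per key block inside the online-softmax loop, is the saving the commit message refers to.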