Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
flash-attention
Commits
26d7d92f
Commit
26d7d92f
authored
Sep 03, 2023
by
Tri Dao
Browse files
Fix splitKV combine function when local LSEs are all -inf
parent
de2949f3
Changes
1
Hide whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
3 additions
and
1 deletion
+3
-1
csrc/flash_attn/src/flash_fwd_kernel.h
csrc/flash_attn/src/flash_fwd_kernel.h
+3
-1
No files found.
csrc/flash_attn/src/flash_fwd_kernel.h
View file @
26d7d92f
...
@@ -1124,7 +1124,9 @@ inline __device__ void combine_attn_seqk_parallel(const Params &params) {
...
@@ -1124,7 +1124,9 @@ inline __device__ void combine_attn_seqk_parallel(const Params &params) {
for
(
int
l
=
1
;
l
<
kNLsePerThread
;
++
l
)
{
lse_sum
+=
expf
(
lse_accum
(
l
)
-
lse_max
);
}
for
(
int
l
=
1
;
l
<
kNLsePerThread
;
++
l
)
{
lse_sum
+=
expf
(
lse_accum
(
l
)
-
lse_max
);
}
SumOp
<
float
>
sum_op
;
SumOp
<
float
>
sum_op
;
lse_sum
=
Allreduce
<
kRowsPerLoadTranspose
>::
run
(
lse_sum
,
sum_op
);
lse_sum
=
Allreduce
<
kRowsPerLoadTranspose
>::
run
(
lse_sum
,
sum_op
);
ElementAccum
lse_logsum
=
logf
(
lse_sum
)
+
lse_max
;
// For the case where all local lse == -INFINITY, we want to set lse_logsum to INFINITY. Otherwise
// lse_logsum is log(0.0) = -INFINITY and we get NaN when we do lse_accum(l) - lse_logsum.
ElementAccum
lse_logsum
=
(
lse_sum
==
0.
f
||
lse_sum
!=
lse_sum
)
?
INFINITY
:
logf
(
lse_sum
)
+
lse_max
;
// if (bidx == 0 && tidx < 32) { printf("tidx = %d, lse = %f, lse_max = %f, lse_logsum = %f\n", tidx, lse_accum(0), lse_max, lse_logsum); }
// if (bidx == 0 && tidx < 32) { printf("tidx = %d, lse = %f, lse_max = %f, lse_logsum = %f\n", tidx, lse_accum(0), lse_max, lse_logsum); }
if
(
tidx
%
kRowsPerLoadTranspose
==
0
&&
tidx
/
kRowsPerLoadTranspose
<
kBlockM
)
{
gLSE
(
tidx
/
kRowsPerLoadTranspose
)
=
lse_logsum
;
}
if
(
tidx
%
kRowsPerLoadTranspose
==
0
&&
tidx
/
kRowsPerLoadTranspose
<
kBlockM
)
{
gLSE
(
tidx
/
kRowsPerLoadTranspose
)
=
lse_logsum
;
}
// Store the scales exp(lse - lse_logsum) in shared memory.
// Store the scales exp(lse - lse_logsum) in shared memory.
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment