Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
flash-attention
Commits
c422fee3
Commit
c422fee3
authored
Oct 24, 2022
by
Tri Dao
Browse files
Get rid of o_rows_are_valid since we don't have headdim=16 anymore
parent
46fd2a20
Changes
1
Hide whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
4 additions
and
13 deletions
+4
-13
csrc/flash_attn/src/fmha_fprop_kernel_1xN.h
csrc/flash_attn/src/fmha_fprop_kernel_1xN.h
+4
-13
No files found.
csrc/flash_attn/src/fmha_fprop_kernel_1xN.h
View file @
c422fee3
...
...
@@ -554,20 +554,13 @@ inline __device__ void device_1xN_(const Params ¶ms, const int bidb, const i
for
(
int
jj
=
0
;
jj
<
Gmem_tile_o
::
STGS_PER_LOOP
;
jj
++
)
{
rows
[
jj
]
=
tidx
/
Gmem_tile_o
::
THREADS_PER_ROW
+
jj
*
Gmem_tile_o
::
ROWS_PER_STG
;
}
// When d = 16, O only has 16 x 16 = 256 elements, and each of the 128 threads wants
// to write 4 elements, so only half of the thread should deal with O.
bool
o_rows_are_valid
=
(
Kernel_traits
::
THREADS
<=
Gmem_tile_o
::
THREADS_PER_ROW
*
Gmem_tile_o
::
ROWS
)
||
(
tidx
/
Gmem_tile_o
::
THREADS_PER_ROW
<
Gmem_tile_o
::
ROWS
);
if
(
o_rows_are_valid
)
{
softmax
.
reduce_max_after_sync_
(
p_max_o
,
rows
);
}
softmax
.
reduce_max_after_sync_
(
p_max_o
,
rows
);
static_assert
(
Mma_tile_o
::
MMAS_M
==
1
);
for
(
int
jj
=
0
;
jj
<
Gmem_tile_o
::
STGS_PER_LOOP
;
jj
++
)
{
p_max_o
[
jj
][
0
]
*=
params
.
scale_bmm1f
;
}
float
p_prev_scale_o
[
Gmem_tile_o
::
STGS_PER_LOOP
];
if
(
(
!
Is_first
)
&&
o_rows_are_valid
)
{
if
(
!
Is_first
)
{
smem_softmax_lse
.
load
(
p_prev_scale_o
,
rows
);
}
// if (!Is_first) {
...
...
@@ -586,9 +579,7 @@ inline __device__ void device_1xN_(const Params ¶ms, const int bidb, const i
static_assert
(
Mma_tile_o
::
MMAS_M
==
1
);
float
p_sum_o
[
Gmem_tile_o
::
STGS_PER_LOOP
][
Mma_tile_o
::
MMAS_M
];
if
(
o_rows_are_valid
)
{
softmax
.
reduce_sum_after_sync_
(
p_sum_o
,
rows
);
}
softmax
.
reduce_sum_after_sync_
(
p_sum_o
,
rows
);
if
(
!
Is_first
)
{
for
(
int
jj
=
0
;
jj
<
Gmem_tile_o
::
STGS_PER_LOOP
;
jj
++
)
{
p_prev_scale_o
[
jj
]
=
expf
(
p_prev_scale_o
[
jj
]
-
p_max_o
[
jj
][
0
]);
...
...
@@ -609,7 +600,7 @@ inline __device__ void device_1xN_(const Params ¶ms, const int bidb, const i
// printf("p_sum_log=%.6f\n", p_sum_log[jj][0]);
// }
// }
if
(
(
tidx
%
Gmem_tile_o
::
THREADS_PER_ROW
==
0
)
&&
o_rows_are_valid
)
{
if
(
tidx
%
Gmem_tile_o
::
THREADS_PER_ROW
==
0
)
{
gmem_softmax_lse
.
store_row
(
reinterpret_cast
<
uint32_t
(
&
)[
Mma_tile_p
::
MMAS_M
]
>
(
p_sum_log
[
jj
]),
rows
[
jj
]);
}
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment