gaoqiong / flash-attention · Commits

Commit 45567a25, authored Apr 15, 2023 by Kirthi Shankar Sivamani

only 1 thread writes to global mem in fprop

Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>
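Before this commit, every thread in the grid stored the same Philox seed and offset to params.rng_state; the change guards the store so that exactly one thread (block 0, thread 0) writes the RNG state to global memory. Below is a minimal standalone sketch of that pattern, assuming a hypothetical save_rng_state kernel; the kernel name, launch shape, and seed values are illustrative, not from this repo.

#include <cuda_runtime.h>
#include <cstdio>
#include <cstdint>

// Every thread receives the same two seed words; without a guard, every
// thread in the grid would issue a redundant store to rng_state[0..1].
__global__ void save_rng_state(uint64_t seed, uint64_t offset,
                               uint64_t *rng_state) {
    // Flattened block index, mirroring bidx = gridDim.x * bidh + bidb.
    const int bidx = gridDim.x * blockIdx.y + blockIdx.x;
    const int tidx = threadIdx.x;
    // Exactly one thread in the whole grid writes to global memory.
    if (bidx == 0 && tidx == 0) {
        rng_state[0] = seed;
        rng_state[1] = offset;
    }
}

int main() {
    uint64_t *rng_state;
    cudaMalloc(&rng_state, 2 * sizeof(uint64_t));
    dim3 grid(4, 8);  // e.g. batch x heads, as in the fprop launch
    save_rng_state<<<grid, 128>>>(0x1234, 0x5678, rng_state);
    uint64_t host[2];
    cudaMemcpy(host, rng_state, sizeof(host), cudaMemcpyDeviceToHost);
    printf("seed=%llu offset=%llu\n",
           (unsigned long long)host[0], (unsigned long long)host[1]);
    cudaFree(rng_state);
    return 0;
}

Since all threads unpack identical seed values, skipping the redundant stores changes no observable state; it only removes needless global-memory traffic and the repeated writes to the same address.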
Parent: a0997bc7
Showing 1 changed file with 6 additions and 2 deletions:
csrc/flash_attn/src/fmha_fprop_kernel_1xN.h (+6, −2)
csrc/flash_attn/src/fmha_fprop_kernel_1xN.h (view file @ 45567a25)

@@ -667,6 +667,8 @@ inline __device__ void device_1xN_loop(const Params &params) {
     const int bidb = blockIdx.x;
     // The block index for the head.
     const int bidh = blockIdx.y;
+    // The block index.
+    const int bidx = gridDim.x * bidh + bidb;
     // The thread index.
     const int tidx = threadIdx.x;
@@ -678,8 +680,10 @@ inline __device__ void device_1xN_loop(const Params &params) {
     // the attention matrix. This way, as long as we have the batch, head, and the location of
     // the 16 x 16 block within the attention matrix, we can generate the exact same dropout pattern.
     auto seeds = at::cuda::philox::unpack(params.philox_args);
-    params.rng_state[0] = std::get<0>(seeds);
-    params.rng_state[1] = std::get<1>(seeds);
+    if (bidx == 0 && tidx == 0) {
+        params.rng_state[0] = std::get<0>(seeds);
+        params.rng_state[1] = std::get<1>(seeds);
+    }
     Philox ph(std::get<0>(seeds), 0, std::get<1>(seeds) + (bidb * params.h + bidh) * 32 + tidx % 32);
     constexpr int M = Kernel_traits::Cta_tile_p::M;
     const int STEPS = (params.seqlen_q + M - 1) / M;
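For context on the Philox construction in the hunk above: the subsequence offset (bidb * params.h + bidh) * 32 + tidx % 32 is a pure function of batch index, head index, and warp lane, which is what makes the dropout pattern reproducible from the two saved rng_state words. A host-side sketch of that index calculation, assuming params.h is the number of heads (the helper name below is hypothetical):

#include <cstdint>
#include <cstdio>

// Subsequence offset used when constructing Philox in the kernel:
//   base_offset + (bidb * num_heads + bidh) * 32 + (tidx % 32)
// so every (batch, head, lane) triple maps to a fixed RNG stream.
uint64_t philox_subsequence(uint64_t base_offset, int bidb, int bidh,
                            int num_heads, int tidx) {
    return base_offset + (uint64_t)(bidb * num_heads + bidh) * 32 + tidx % 32;
}

int main() {
    // Threads 5 and 37 of the same block are the same lane (5) in
    // different warps, so they map to the same subsequence.
    printf("%llu\n", (unsigned long long)philox_subsequence(0, 1, 2, 8, 5));
    printf("%llu\n", (unsigned long long)philox_subsequence(0, 1, 2, 8, 37));
    return 0;
}

Two lanes with the same index in different warps of a block share a stream; presumably they stay decorrelated because they draw values at different counter positions elsewhere in the kernel.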