gaoqiong / flash-attention / Commits

Commit 5953c4f5
Authored Sep 03, 2023 by Tri Dao
Remove unused sdPsum in dot_do_o function
Parent: b28ec236

Showing 2 changed files with 10 additions and 16 deletions (+10 -16):
    csrc/flash_attn/src/flash_bwd_kernel.h    +6  -16
    csrc/flash_attn/src/kernel_traits.h       +4  -0
csrc/flash_attn/src/flash_bwd_kernel.h
@@ -73,11 +73,9 @@ make_tiled_copy_C_warpcontiguousN(Copy_Atom<Args...> const& copy_atom,
 ////////////////////////////////////////////////////////////////////////////////////////////////////

-template<int THREADS_PER_ROW, typename Engine0, typename Layout0, typename Engine1, typename Layout1,
-         typename Engine2, typename Layout2>
+template<int THREADS_PER_ROW, typename Engine0, typename Layout0, typename Engine1, typename Layout1>
 inline __device__ void dot_do_o(Tensor<Engine0, Layout0> const &do_, Tensor<Engine0, Layout0> const &o,
-                                Tensor<Engine1, Layout1> &dP_sum, Tensor<Engine2, Layout2> &sdPsum,
+                                Tensor<Engine1, Layout1> &dP_sum,
                                 const int gdP_col_stride, const float scale) {
     static_assert(Layout0::rank == 3, "Only support 3D Tensor");
     static_assert(Layout1::rank == 1, "Only support 1D Tensor");
     CUTE_STATIC_ASSERT_V(do_.layout() == o.layout());
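For reference, dot_do_o computes, for every row of the dO and O tiles, their dot product over the head dimension, multiplied by scale, and writes the result into dP_sum. A minimal scalar sketch of that contract (plain C++, hypothetical array names and shapes, not the kernel's CuTe implementation):

    // Scalar reference for what dot_do_o produces per row of the tile:
    // dP_sum[r] = scale * sum_d dO[r][d] * O[r][d]
    void dot_do_o_reference(const float *dO, const float *O, float *dP_sum,
                            int rows, int head_dim, float scale) {
        for (int r = 0; r < rows; ++r) {
            float acc = 0.f;
            for (int d = 0; d < head_dim; ++d) {
                acc += dO[r * head_dim + d] * O[r * head_dim + d];
            }
            dP_sum[r] = acc * scale;  // scale is params.p_dropout at the call sites below
        }
    }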
@@ -100,7 +98,6 @@ inline __device__ void dot_do_o(Tensor<Engine0, Layout0> const &do_, Tensor<Engi
         dP_sum_cur = flash::Allreduce<THREADS_PER_ROW>::run(dP_sum_cur, sum_op) * scale;
         if (threadIdx.x % THREADS_PER_ROW == 0) {
             dP_sum(mi * gdP_col_stride + threadIdx.x / THREADS_PER_ROW) = dP_sum_cur;
-            // recast<float>(sdPsum)(mi * gdP_col_stride + threadIdx.x / THREADS_PER_ROW) = dP_sum;
         }
     }
 }
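flash::Allreduce<THREADS_PER_ROW> combines the per-thread partial sums of one row so that the lane with threadIdx.x % THREADS_PER_ROW == 0 can write the final value. A minimal sketch of a shuffle-based butterfly all-reduce in that spirit, assuming THREADS_PER_ROW is a power of two no larger than 32 and all lanes of the warp participate (an illustration, not the library's actual implementation):

    // Butterfly all-reduce: after the loop every lane holds the sum over its
    // aligned group of THREADS_PER_ROW lanes.
    template <int THREADS_PER_ROW>
    __device__ float allreduce_sum_sketch(float x) {
        #pragma unroll
        for (int offset = THREADS_PER_ROW / 2; offset > 0; offset /= 2) {
            x += __shfl_xor_sync(0xffffffffu, x, offset);
        }
        return x;
    }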
@@ -178,7 +175,7 @@ inline __device__ void compute_dot_do_o(const Params &params) {
     // By right we need to scale dP up by 1/p_dropout, but instead we don't and only scale the final
     // results (dQ and dK) by 1/p_dropout. So we need to keep dP_sum scaled down by p_dropout here,
     // so that (dP - dP_sum) is on the same scale.
-    dot_do_o<Kernel_traits::kGmemThreadsPerRow>(tdOrdO, tdOrO, dP_sum, dP_sum,
+    dot_do_o<Kernel_traits::kGmemThreadsPerRow>(tdOrdO, tdOrO, dP_sum,
                                                 Kernel_traits::kNThreads / (Kernel_traits::kGmemThreadsPerRow), params.p_dropout);
     if (Clear_dQaccum) {
         Tensor zero = make_fragment_like(tdQgdQaccum);
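The comment above is a pure bookkeeping argument: since dP is left scaled down by p_dropout, dP_sum must carry the same factor so that (dP - dP_sum) stays consistent, and the single 1/p_dropout correction can be applied once to dQ and dK at the end. A tiny host-side illustration with made-up numbers (names are hypothetical):

    #include <cassert>
    #include <cmath>

    int main() {
        const float p_dropout = 0.9f;             // keep probability, as in params.p_dropout
        const float dP = 1.25f, dP_sum = 0.75f;   // "true", unscaled values

        // Scale early: work with unscaled quantities throughout.
        const float early = dP - dP_sum;

        // Scale late (what the kernel does): both terms stay scaled by p_dropout,
        // and the 1/p_dropout correction is applied once at the end.
        const float late = (dP * p_dropout - dP_sum * p_dropout) / p_dropout;

        assert(std::fabs(early - late) < 1e-6f);  // same result either way
        return 0;
    }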
@@ -517,8 +514,6 @@ inline __device__ void compute_dq_dk_dv_1colblock(const Params &params, const in
     Tensor sPtNoSwizzle = make_tensor(sP.data(), typename Kernel_traits::SmemLayoutPdStransposedNoSwizzle{});
     // sP and sdQ share the same memory so be careful
     Tensor sdQ = make_tensor(sP.data(), typename Kernel_traits::SmemLayoutdQ{});
-    Tensor sdPsum = make_tensor(make_smem_ptr(reinterpret_cast<float2 *>((sP.data() + cute::max(size(sP), size(sdQ))).get())),
-                                Shape<Int<Kernel_traits::kSmemdPsumCount / 2>>{});

     typename Kernel_traits::GmemTiledCopyQKV gmem_tiled_copy_QKV;
     auto gmem_thr_copy_QKV = gmem_tiled_copy_QKV.get_thread_slice(tidx);
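The "sP and sdQ share the same memory" warning refers to the usual shared-memory aliasing trick: tiles whose live ranges do not overlap are mapped onto the same bytes, so the budget only pays for the larger of the two. A minimal CUDA sketch of the idea (names and sizes are illustrative, not the kernel's real layout):

    __global__ void alias_sketch_kernel() {
        __shared__ float smem[2048];
        float *sP  = smem;   // one tile uses these bytes first
        float *sdQ = smem;   // the other tile reuses the same bytes later
        // ... write and consume sP ...
        __syncthreads();     // every thread must be done with sP before sdQ overwrites it
        // ... write through sdQ ...
        (void)sP; (void)sdQ;
    }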
@@ -733,7 +728,7 @@ inline __device__ void compute_dq_dk_dv_1colblock(const Params &params, const in
     // if (cute::thread0()) { print(tdOgdO.layout()); printf("\n"); print(tdOrdO); print(tdOrO); }
     if (Is_first) {
         cute::copy(tdOrdO, tdOsdO);
-        dot_do_o<Kernel_traits::kGmemThreadsPerRow>(tdOrdO, tdOrO, gdPsum, sdPsum,
+        dot_do_o<Kernel_traits::kGmemThreadsPerRow>(tdOrdO, tdOrO, gdPsum,
                                                     Kernel_traits::kNThreads / (Kernel_traits::kGmemThreadsPerRow), params.p_dropout);
     }
@@ -930,11 +925,6 @@ inline __device__ void compute_dq_dk_dv_1colblock(const Params &params, const in
         #pragma unroll
         for (int mi = 0; mi < size(lse); ++mi) { lse(mi) = gLSE(get<0>(taccScS_row(mi))); }
         gdPsum.data() = gdPsum.data() + (-int(kBlockM));
-        // if (!Is_first && tidx < kBlockM / 2) {
-        //     sdPsum(tidx) = recast<float2>(gdPsum)(tidx);
-        // if (!Is_first && tidx < kBlockM) {
-        //     recast<float>(sdPsum)(tidx) = gdPsum(tidx);
-        // }
     }

     if (!Is_last) {
@@ -976,7 +966,7 @@ inline __device__ void compute_dq_dk_dv_1colblock(const Params &params, const in
     if (Is_first && m_block > m_block_min) {
         cute::copy(tdOrdO, tdOsdO);
-        dot_do_o<Kernel_traits::kGmemThreadsPerRow>(tdOrdO, tdOrO, gdPsum, sdPsum,
+        dot_do_o<Kernel_traits::kGmemThreadsPerRow>(tdOrdO, tdOrO, gdPsum,
                                                     Kernel_traits::kNThreads / (Kernel_traits::kGmemThreadsPerRow), params.p_dropout);
     }
@@ -1317,7 +1307,7 @@ inline __device__ void compute_dq_dk_dv_1rowblock(const Params &params, const in
     Tensor dP_sum = make_fragment_like(lse);
     cute::copy(tdOrdO, tdOsdO);
     dot_do_o<Kernel_traits::kGmemThreadsPerRow>(
-        tdOrdO, tdOrO, sdPsum, sdPsum,
+        tdOrdO, tdOrO, sdPsum,
         Kernel_traits::kNThreads / (Kernel_traits::kGmemThreadsPerRow), params.p_dropout
     );
     __syncthreads();
csrc/flash_attn/src/kernel_traits.h
@@ -321,6 +321,10 @@ struct Flash_bwd_kernel_traits : public Base {
     static constexpr int kSmemdSSize = kSmemdSCount * sizeof(Element);
     static constexpr int kSmemPSize = kSmemPCount * sizeof(Element);
     static constexpr int kSmemdQSize = kSmemdQCount * sizeof(Element);
+    static constexpr int kSmemSize = kSmemQdOSize
+        + (!Is_V_in_regs
+           ? kSmemKVSize + kSmemdSSize + std::max(kSmemPSize, kSmemdQSize)
+           : std::max(kSmemKVSize, kSmemKVSize / 2 + kSmemdSSize + std::max(kSmemPSize, kSmemdQSize)));
     static constexpr int kSmemSize1colblock = kSmemQdOSize
         + (!Is_V_in_regs
            ? kSmemKVSize + kSmemdSSize + kSmemPSize
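The new kSmemSize follows the same budgeting pattern as kSmemSize1colblock: the Q/dO tiles always get their own space, and the remainder depends on whether V is kept in registers. A small constexpr example with made-up byte counts (the real values come from the chosen kernel traits) showing how the added formula evaluates:

    #include <algorithm>

    constexpr int  kSmemQdOSize  = 32 * 1024;
    constexpr int  kSmemKVSize   = 32 * 1024;
    constexpr int  kSmemdSSize   = 8 * 1024;
    constexpr int  kSmemPSize    = 8 * 1024;
    constexpr int  kSmemdQSize   = 16 * 1024;
    constexpr bool Is_V_in_regs  = false;

    constexpr int kSmemSize = kSmemQdOSize
        + (!Is_V_in_regs
           ? kSmemKVSize + kSmemdSSize + std::max(kSmemPSize, kSmemdQSize)
           : std::max(kSmemKVSize, kSmemKVSize / 2 + kSmemdSSize + std::max(kSmemPSize, kSmemdQSize)));

    static_assert(kSmemSize == (32 + 32 + 8 + 16) * 1024, "88 KiB with these example sizes");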