gaoqiong / flash-attention / Commits / d380e87f

Commit d380e87f, authored Jun 04, 2022 by Tri Dao
Don't use Smem_dp_sum in backward pass
To reduce smem usage for SM75
parent b17c6fe2
Showing 3 changed files with 76 additions and 63 deletions
+26 -3   csrc/flash_attn/src/fmha_dgrad_fp16_kernel_loop.sm80.cu
+30 -55  csrc/flash_attn/src/fmha_dgrad_kernel_1xN_loop.h
+20 -5   csrc/flash_attn/src/fmha_fprop_fp16_kernel.sm80.cu
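Context for the change (not part of the commit itself): the backward launcher adds up the per-tile shared-memory sizes and, when the total exceeds the default 48 KB, opts the kernel into a larger dynamic shared-memory carve-out with cudaFuncSetAttribute, as the first diff below shows. Because the opt-in ceiling is lower on SM75 (Turing) than on SM80, dropping the Smem_dp_sum tile from that sum makes it easier for an SM75 configuration to fit. A minimal sketch of that budgeting pattern, with hypothetical tile sizes and a placeholder kernel name:

#include <cstdio>
#include <cuda_runtime.h>

// Placeholder standing in for the real fmha_dgrad_fp16_sm80_dq_dk_dv_loop_kernel.
__global__ void dgrad_kernel_stub(float *out) {
    extern __shared__ char smem[];
    if (out != nullptr) { out[0] = static_cast<float>(smem[0]); }
}

int main() {
    // Hypothetical per-tile byte counts standing in for the smem_size_* constants.
    constexpr int smem_size_q  = 12 * 1024;
    constexpr int smem_size_v  = 12 * 1024;
    constexpr int smem_size_dq =  8 * 1024;
    constexpr int smem_size_s  =  6 * 1024;
    // Total dynamic smem request; this commit drops the extra Smem_dp_sum term from the sum.
    constexpr int smem_total = smem_size_q * 2 + smem_size_v + smem_size_dq + smem_size_s * 2;

    // Requests of 48 KB or more need an explicit opt-in, exactly as the launcher does below.
    if (smem_total >= 48 * 1024) {
        cudaFuncSetAttribute(dgrad_kernel_stub,
                             cudaFuncAttributeMaxDynamicSharedMemorySize, smem_total);
    }

    // The opt-in ceiling is device dependent (lower on SM75 than on SM80),
    // which is why trimming even a small tile can matter on Turing.
    cudaDeviceProp prop;
    cudaGetDeviceProperties(&prop, 0);
    std::printf("requesting %d bytes; device allows up to %zu bytes per block (opt-in)\n",
                smem_total, prop.sharedMemPerBlockOptin);
    return 0;
}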
csrc/flash_attn/src/fmha_dgrad_fp16_kernel_loop.sm80.cu
@@ -15,15 +15,13 @@ void run_fmha_dgrad_fp16_sm80_loop_(const Fused_multihead_attention_fprop_params
     constexpr int smem_size_q = Kernel_traits::Smem_tile_q::BYTES_PER_TILE;
     constexpr int smem_size_v = Kernel_traits::Smem_tile_v::BYTES_PER_TILE;
     constexpr int smem_size_dq = Kernel_traits::Smem_tile_o::BYTES_PER_TILE;
-    constexpr int smem_size_dp_sum = Kernel_traits::Smem_dp_sum::BYTES_PER_TILE;

     using Smem_tile_s = fmha::Smem_tile_mma_transposed<typename Kernel_traits::Cta_tile_p>;
     constexpr int smem_size_s = Smem_tile_s::BYTES_PER_TILE;
     static_assert(smem_size_s == 16 * Kernel_traits::Cta_tile_p::N * 2);
     static_assert(smem_size_dq == 16 * Kernel_traits::Cta_tile_p::K * 4 * Kernel_traits::Cta_tile_p::WARPS_N);
-    static_assert(smem_size_dp_sum == 16 * 4 * 2);

-    constexpr int smem_size_dq_dk_dv = smem_size_q * 2 + smem_size_v * (Kernel_traits::V_IN_REGS ? 1 : 2) + smem_size_dq + smem_size_s * 2 + smem_size_dp_sum;
+    constexpr int smem_size_dq_dk_dv = smem_size_q * 2 + smem_size_v * (Kernel_traits::V_IN_REGS ? 1 : 2) + smem_size_dq + smem_size_s * 2;

     bool is_dropout = params.p_dropout < 1.f;  // params.p_dropout is the probability of "keeping"
     bool is_causal = params.is_causal;
@@ -41,6 +39,7 @@ void run_fmha_dgrad_fp16_sm80_loop_(const Fused_multihead_attention_fprop_params
                 : (is_causal
                     ? &fmha_dgrad_fp16_sm80_dq_dk_dv_loop_kernel<Kernel_traits, false, true, /*loop_steps=*/2>
                     : &fmha_dgrad_fp16_sm80_dq_dk_dv_loop_kernel<Kernel_traits, false, false, /*loop_steps=*/2>);
     }
+    // printf("N = %d, WARPS_N = %d, Smem size = %d\n", N, Kernel_traits::Cta_tile_p::WARPS_N, smem_size_dq_dk_dv);
     if( smem_size_dq_dk_dv >= 48 * 1024 ) {
         FMHA_CHECK_CUDA(cudaFuncSetAttribute(
             kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size_dq_dk_dv));
@@ -97,4 +96,28 @@ void run_fmha_dgrad_fp16_sm80(const Fused_multihead_attention_fprop_params &params
         using Kernel_traits = FMHA_kernel_traits<128, 128, 16, 1, 8, 0x100u>;
         run_fmha_dgrad_fp16_sm80_loop_<Kernel_traits>(params, stream);
     }
+    // if (params.d == 64) {
+    //     auto dprops = at::cuda::getCurrentDeviceProperties();
+    //     if (dprops->major == 7 && dprops->minor == 5) {
+    //         using Kernel_traits = FMHA_kernel_traits<128, 64, 16, 1, 8, 0x08u>;
+    //         run_fmha_dgrad_fp16_sm80_loop_<Kernel_traits>(params, stream);
+    //     } else {
+    //         if( params.s == 128 ) {
+    //             using Kernel_traits = FMHA_kernel_traits<128, 64, 16, 1, 8, 0x08u>;
+    //             run_fmha_dgrad_fp16_sm80_loop_<Kernel_traits>(params, stream);
+    //         } else if( params.s >= 256 ) {
+    //             if (dprops->major == 8 && dprops->minor == 0) {
+    //                 // Don't share smem for K & V, and don't keep V in registers
+    //                 // This speeds things up by 2-3% by avoiding register spills, but it
+    //                 // uses more shared memory, which is fine on A100 but not other GPUs.
+    //                 // For other GPUs, we keep V in registers.
+    //                 using Kernel_traits = FMHA_kernel_traits<256, 64, 16, 1, 8, 0x100u>;
+    //                 run_fmha_dgrad_fp16_sm80_loop_<Kernel_traits>(params, stream);
+    //             } else if (dprops->major == 8 && dprops->minor > 0) {
+    //                 using Kernel_traits = FMHA_kernel_traits<256, 64, 16, 1, 8, 0x08u>;
+    //                 run_fmha_dgrad_fp16_sm80_loop_<Kernel_traits>(params, stream);
+    //             }
+    //         }
+    //     }
+    // }
 }
\ No newline at end of file
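The commented-out block above shows how the launcher could pick a different FMHA_kernel_traits per GPU generation via at::cuda::getCurrentDeviceProperties. As a hedged illustration of that compute-capability dispatch pattern using the plain CUDA runtime instead of the PyTorch helper (the enum and function names here are hypothetical, and the branching is simplified):

#include <cuda_runtime.h>

// Hypothetical stand-ins for the FMHA_kernel_traits instantiations in the commented block.
enum class DgradConfig { Sm75_128x64, Sm80_256x64_more_smem, Sm8x_256x64 };

// Mirrors the shape of the commented-out dispatch: SM75 (Turing) gets the smaller-smem
// configuration, SM80 (A100) can afford the larger-smem variant for long sequences, and
// other SM8x parts fall back to a configuration that keeps V in registers.
inline DgradConfig pick_dgrad_config(int seqlen) {
    cudaDeviceProp props;
    cudaGetDeviceProperties(&props, 0);
    if (props.major == 7 && props.minor == 5) {
        return DgradConfig::Sm75_128x64;
    }
    if (seqlen >= 256 && props.major == 8 && props.minor == 0) {
        return DgradConfig::Sm80_256x64_more_smem;  // extra smem is fine on A100
    }
    return DgradConfig::Sm8x_256x64;
}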
csrc/flash_attn/src/fmha_dgrad_kernel_1xN_loop.h
@@ -12,16 +12,19 @@ namespace fmha {

////////////////////////////////////////////////////////////////////////////////////////////////////

-template <typename Smem_dp_sum, int M>
-inline __device__ void dot_do_o(float (&sum)[M], const uint4 (&do_)[M], const uint4 (&o)[M],
-                                Smem_dp_sum smem, const int buffer_idx) {
+template <int ROWS, int THREADS_PER_ROW, int M, typename Gmem_softmax_sum>
+inline __device__ void dot_do_o(const uint4 (&do_)[M], const uint4 (&o)[M],
+                                Gmem_softmax_sum gmem_softmax_d, int tidx) {
+    float sum[M];
+    fmha::SumOp<float> sum_op;
     #pragma unroll
     for (int mi = 0; mi < M; ++mi) {
-        sum[mi] = smem.reduce_warp(fmha::hmulsum8(do_[mi], o[mi]));
+        sum[mi] = fmha::Allreduce<THREADS_PER_ROW>::run(fmha::hmulsum8(do_[mi], o[mi]), sum_op);
     }
-    static_assert(M == 1);
-    smem.store(sum[0], buffer_idx);
-    // smem.store(sum, buffer_idx);
+    const int dp_sum_row = tidx / THREADS_PER_ROW;
+    if ((dp_sum_row < ROWS) && (tidx % THREADS_PER_ROW == 0)) {
+        gmem_softmax_d.store_row(reinterpret_cast<const uint32_t (&)[M]>(sum), dp_sum_row);
+    }
 }

////////////////////////////////////////////////////////////////////////////////////////////////////
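For intuition, and not as the library's actual implementation: the new dot_do_o has every group of THREADS_PER_ROW threads own one row, compute that row's dO·O dot product, all-reduce the partial sums within the group, and let one thread per row write the result straight to global memory (the softmax-d buffer), where later iterations reload it; the old version staged the value through the Smem_dp_sum shared-memory tile instead. A simplified, standalone CUDA sketch of that reduce-then-store-row pattern over plain float rows (the kernel name, data layout, and launch parameters are hypothetical):

#include <cuda_runtime.h>

// Simplified analogue of the new dot_do_o: THREADS_PER_ROW consecutive threads own one
// row of dO and O, reduce the row's dot product with warp shuffles, and one thread per
// row stores the result to global memory (the role gmem_softmax_d.store_row plays above).
// Data layout here is a plain row-major [rows x cols] float array, not uint4 fragments.
template <int THREADS_PER_ROW>
__global__ void dot_do_o_sketch(const float *do_, const float *o,
                                float *dp_sum_out, int rows, int cols) {
    static_assert(THREADS_PER_ROW > 0 && THREADS_PER_ROW <= 32 &&
                  (THREADS_PER_ROW & (THREADS_PER_ROW - 1)) == 0,
                  "row group must be a power of two within one warp");
    const int tidx = blockIdx.x * blockDim.x + threadIdx.x;
    const int row  = tidx / THREADS_PER_ROW;
    const int lane = tidx % THREADS_PER_ROW;

    // Each thread accumulates a strided slice of its row's dot product.
    float sum = 0.f;
    if (row < rows) {
        for (int c = lane; c < cols; c += THREADS_PER_ROW) {
            sum += do_[row * cols + c] * o[row * cols + c];
        }
    }

    // All-reduce across the THREADS_PER_ROW threads that own this row
    // (the counterpart of fmha::Allreduce<THREADS_PER_ROW>::run with a SumOp).
    // Every thread in the block reaches this loop, so the full-warp mask is valid
    // as long as blockDim.x is a multiple of 32.
    #pragma unroll
    for (int offset = THREADS_PER_ROW / 2; offset > 0; offset /= 2) {
        sum += __shfl_xor_sync(0xffffffffu, sum, offset);
    }

    // One thread per row writes the reduced value (the store_row step).
    if (row < rows && lane == 0) { dp_sum_out[row] = sum; }
}

// Example launch (hypothetical sizes): 128 rows of 64 floats, 8 threads per row.
//   dot_do_o_sketch<8><<<(128 * 8 + 255) / 256, 256>>>(d_do, d_o, d_dp_sum, 128, 64);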
@@ -101,8 +104,6 @@ inline __device__ void compute_dq_dk_dv_1xN_one_iter(const Params &params, Prng
     using Gmem_softmax_sum = typename Kernel_traits::Gmem_softmax_sum;

-    using Smem_dp_sum = typename Kernel_traits::Smem_dp_sum;
-
     // using Gemm1 = Gemm_Q_K<Kernel_traits, Kernel_traits::K_IN_REGS>;
     using Gemm1 = Gemm_Q_K<Kernel_traits, /*K-in_regs=*/false>;
@@ -208,26 +209,14 @@ inline __device__ void compute_dq_dk_dv_1xN_one_iter(const Params &params, Prng
     gmem_softmax_lse.load(reinterpret_cast<uint32_t(&)[Mma_tile_p::MMAS_M * 2]>(p_lse));
     gmem_softmax_lse.move();

-    float dp_sum[Mma_tile_p::MMAS_M * 2];
-    if (!Is_first) {
-        gmem_softmax_d.load(reinterpret_cast<uint32_t(&)[Mma_tile_p::MMAS_M * 2]>(dp_sum));
-        gmem_softmax_d.move();
-    }
-
-    float dp_sum_regs[Gmem_tile_do::LDGS];
-    Smem_dp_sum smem_dp_sum(reinterpret_cast<float *>(&smem_[Smem_tile_do::BYTES_PER_TILE + Gemm1::SMEM_OFFSET_O + Smem_tile_dq::BYTES_PER_TILE + Smem_tile_st::BYTES_PER_TILE * 2]), tidx);
-
     if (!Is_first) { __syncthreads(); }

     // Commit the data for Q, dO, and V to shared memory.
     gmem_q.commit(gemm_q_k.smem_q);
     gmem_do.commit(smem_do);
     if (Is_first) {
-        dot_do_o(dp_sum_regs, gmem_do.fetch_, gmem_o.fetch_, smem_dp_sum, 0);
-        const int dp_sum_row = tidx / Smem_dp_sum::THREADS_PER_ROW;
-        if ((dp_sum_row < Smem_dp_sum::ROWS) && (tidx % Smem_dp_sum::THREADS_PER_ROW == 0)) {
-            gmem_softmax_d.store_row(reinterpret_cast<uint32_t(&)[Gmem_tile_do::LDGS]>(dp_sum_regs), dp_sum_row);
-        }
-        gmem_softmax_d.move();
+        dot_do_o<Gmem_tile_do::ROWS, Gmem_tile_do::THREADS_PER_ROW>(gmem_do.fetch_, gmem_o.fetch_, gmem_softmax_d, tidx);
     }

     // Instead of scaling dP by rp_dropout, we scale V instead
@@ -266,6 +255,10 @@ inline __device__ void compute_dq_dk_dv_1xN_one_iter(const Params &params, Prng
         }
     }

+    float dp_sum[Mma_tile_p::MMAS_M * 2];
+    gmem_softmax_d.load(reinterpret_cast<uint32_t(&)[Mma_tile_p::MMAS_M * 2]>(dp_sum));
+    gmem_softmax_d.move();
+
     // Commit the data for V to shared memory if it has not been done already.
     if (Kernel_traits::SHARE_SMEM_FOR_K_AND_V) {
         // Make sure we are done loading the fragments for K.
@@ -357,21 +350,14 @@ inline __device__ void compute_dq_dk_dv_1xN_one_iter(const Params &params, Prng
     //     __syncthreads();
     // }

-    // TD [2022-04-24]: if Is_first, then it's faster to set acc_dp to zero then subtract by
-    // dp_sum later. If !Is_first, then it's faster to set acc_dp to -dp_sum and don't subtract
-    // later. This is because loading dp_sum earlier uses more registers.
     fmha::Fragment_accumulator acc_dp[Mma_tile_p::MMAS_M][Mma_tile_p::MMAS_N];
-    if (Is_first) {
-        fmha::Clear_accumulator<fmha::Accumulator_type, Cta_tile_p::WARPS_K>::apply(acc_dp);
-    } else {
-        #pragma unroll
-        for (int mi = 0; mi < Mma_tile_p::MMAS_M; ++mi) {
-            #pragma unroll
-            for (int ni = 0; ni < Mma_tile_p::MMAS_N; ++ni) {
-                #pragma unroll
-                for (int ii = 0; ii < 8; ++ii) {
-                    acc_dp[mi][ni].elt(ii) = -dp_sum[mi * 2 + ((ii / 2) % 2)];
-                }
-            }
-        }
-    }
+    #pragma unroll
+    for (int mi = 0; mi < Mma_tile_p::MMAS_M; ++mi) {
+        #pragma unroll
+        for (int ni = 0; ni < Mma_tile_p::MMAS_N; ++ni) {
+            #pragma unroll
+            for (int ii = 0; ii < 8; ++ii) {
+                acc_dp[mi][ni].elt(ii) = -dp_sum[mi * 2 + ((ii / 2) % 2)];
+            }
+        }
+    }
@@ -409,12 +395,6 @@ inline __device__ void compute_dq_dk_dv_1xN_one_iter(const Params &params, Prng
     typename Smem_tile_kt::Fragment frag_kt[2][Mma_tile_dq::MMAS_N];
     smem_kt.load(frag_kt[0], 0);

-    if (Is_first) {
-        const int quad = (tidx % Cta_tile_p::THREADS_PER_WARP) / 4;
-        const int row[2] = {quad, quad + 8};
-        smem_dp_sum.load(dp_sum, row, l % 2);
-    }
-
     // Trigger the load for the next dO values.
     if (l < steps - 1) {
         smem_do.move_to_next_write_buffer();
@@ -430,7 +410,6 @@ inline __device__ void compute_dq_dk_dv_1xN_one_iter(const Params &params, Prng
     // // TD [2022-04-01]: Don't need to apply mask since the corresponding value in softmax
     // // will be zero.
     // for (int mi = 0; mi < Mma_tile_p::MMAS_M * 2; mi++) { dp_sum[mi] *= params.p_dropout; }
-    if (Is_first) { softmax.subtract_dp_sum(dp_sum); }

     Frag_p frag_dp[Mma_tile_dq::MMAS_K][Mma_tile_dq::MMAS_M];
     softmax.pack(frag_dp);
@@ -547,21 +526,12 @@ inline __device__ void compute_dq_dk_dv_1xN_one_iter(const Params &params, Prng
     if (l < steps - 1) {
         gmem_do.commit(smem_do);
         if (Is_first) {
-            // dot_do_o(dp_sum_regs, gmem_do.fetch_, gmem_o.fetch_, smem_dp_sum);
-            // smem_dp_sum.move_to_next_write_buffer();
-            dot_do_o(dp_sum_regs, gmem_do.fetch_, gmem_o.fetch_, smem_dp_sum, (l + 1) % 2);
-            const int dp_sum_row_1 = tidx / Smem_dp_sum::THREADS_PER_ROW;
-            if ((dp_sum_row_1 < Smem_dp_sum::ROWS) && (tidx % Smem_dp_sum::THREADS_PER_ROW == 0)) {
-                gmem_softmax_d.store_row(reinterpret_cast<uint32_t(&)[Gmem_tile_do::LDGS]>(dp_sum_regs), dp_sum_row_1);
-            }
-            gmem_softmax_d.move();
+            dot_do_o<Gmem_tile_do::ROWS, Gmem_tile_do::THREADS_PER_ROW>(gmem_do.fetch_, gmem_o.fetch_, gmem_softmax_d, tidx);
         }
         gmem_softmax_lse.load(reinterpret_cast<uint32_t(&)[Mma_tile_p::MMAS_M * 2]>(p_lse));
         gmem_softmax_lse.move();
-        if (!Is_first) {
-            gmem_softmax_d.load(reinterpret_cast<uint32_t(&)[Mma_tile_p::MMAS_M * 2]>(dp_sum));
-            gmem_softmax_d.move();
-        }
     }

     typename Smem_tile_st::Fragment frag_dpt[Mma_tile_dkv::MMAS_K][Mma_tile_dkv::MMAS_M];
@@ -591,6 +561,11 @@ inline __device__ void compute_dq_dk_dv_1xN_one_iter(const Params &params, Prng
     // Make sure dQ is in shared memory.
     __syncthreads();

+    if (l < steps - 1) {
+        gmem_softmax_d.load(reinterpret_cast<uint32_t(&)[Mma_tile_p::MMAS_M * 2]>(dp_sum));
+        gmem_softmax_d.move();
+    }
+
     // Load from shared memory.
     smem_dq.template load</*zero_init=*/Is_first>(dq_out);
csrc/flash_attn/src/fmha_fprop_fp16_kernel.sm80.cu
@@ -120,10 +120,25 @@ void run_fmha_fp16_sm80(Launch_params<Fused_multihead_attention_fprop_params> &l
         run_fmha_fp16_sm80_loop_<Kernel_traits>(launch_params, configure);
     }
     // if (launch_params.params.d == 64) {
-    //     using Kernel_traits = FMHA_kernel_traits<128, 64, 16, 1, 4, 0x08u>;
-    //     using Kernel_traits = FMHA_kernel_traits<64, 64, 16, 1, 4, 0x08u>;
-    //     using Kernel_traits = FMHA_kernel_traits<512, 64, 16, 1, 8, 0x08u>;
-    //     using Kernel_traits = FMHA_kernel_traits<256, 64, 16, 1, 4, 0x08u>;
-    //     run_fmha_fp16_sm80_loop_<Kernel_traits>(launch_params, configure);
+    //     // using Kernel_traits = FMHA_kernel_traits<128, 64, 16, 1, 4, 0x08u>;
+    //     // using Kernel_traits = FMHA_kernel_traits<64, 64, 16, 1, 4, 0x08u>;
+    //     // using Kernel_traits = FMHA_kernel_traits<512, 64, 16, 1, 8, 0x08u>;
+    //     using Kernel_traits = FMHA_kernel_traits<128, 64, 16, 1, 4, 0x08u>;
+    //     run_fmha_fp16_sm80_loop_<Kernel_traits>(launch_params, configure);
     // }
+    // if (launch_params.params.d == 64) {
+    //     if( launch_params.params.s == 128 ) {
+    //         using Kernel_traits = FMHA_kernel_traits<128, 64, 16, 1, 4, 0x08u>;
+    //         run_fmha_fp16_sm80_loop_<Kernel_traits>(launch_params, configure);
+    //     } else if( launch_params.params.s >= 256 ) {
+    //         auto dprops = at::cuda::getCurrentDeviceProperties();
+    //         if (dprops->major == 8 && dprops->minor >= 0) {
+    //             using Kernel_traits = FMHA_kernel_traits<256, 64, 16, 1, 4, 0x08u>;
+    //             run_fmha_fp16_sm80_loop_<Kernel_traits>(launch_params, configure);
+    //         } else if (dprops->major == 7 && dprops->minor == 5) {
+    //             using Kernel_traits = FMHA_kernel_traits<128, 64, 16, 1, 4, 0x08u>;
+    //             run_fmha_fp16_sm80_loop_<Kernel_traits>(launch_params, configure);
+    //         }
+    //     }
+    // }
 }
\ No newline at end of file