gaoqiong / flash-attention · Commits

Commit a5a8806d, authored Oct 23, 2022 by Tri Dao

Split bwd on the seqlen_q dimension

Parent: 871db479

Showing 9 changed files with 205 additions and 120 deletions:
  csrc/flash_attn/fmha_api.cpp                              +24   -3
  csrc/flash_attn/src/fmha.h                                 +5   -1
  csrc/flash_attn/src/fmha/gmem_tile.h                      +42   -2
  csrc/flash_attn/src/fmha/kernel_traits.h                   +6   -0
  csrc/flash_attn/src/fmha_dgrad_fp16_kernel_loop.sm80.cu   +62  -40
  csrc/flash_attn/src/fmha_dgrad_kernel_1xN_loop.h          +52  -30
  csrc/flash_attn/src/fmha_fprop_fp16_kernel.sm80.cu         +0  -35
  flash_attn/flash_attn_interface.py                         +9   -6
  tests/test_flash_attn.py                                   +5   -3
csrc/flash_attn/fmha_api.cpp

@@ -241,7 +241,7 @@ mha_fwd(const at::Tensor &q, // total_q x num_heads x head_size, total_q
     CHECK_SHAPE(cu_seqlens_q, batch_size + 1);
     CHECK_SHAPE(cu_seqlens_k, batch_size + 1);

-    int blocksize_c = (head_size == 128 && (!is_sm80)) ? 128 : 256;
+    int blocksize_c = head_size == 128 ? 128 : 256;
     // Need to round max_seqlen_k to multiples of blocksize_c
     int max_seqlen_k = ((max_seqlen_k_ + blocksize_c - 1) / blocksize_c) * blocksize_c;
     if (max_seqlen_k_ <= 128) {
...
@@ -332,6 +332,7 @@ mha_bwd(const at::Tensor &dout, // total_q x num_heads, x head_size
         const float softmax_scale,
         const bool zero_tensors,
         const bool is_causal,
+        const int num_splits,
         c10::optional<at::Generator> gen_) {
     auto dprops = at::cuda::getCurrentDeviceProperties();
...
@@ -447,7 +448,22 @@ mha_bwd(const at::Tensor &dout, // total_q x num_heads, x head_size
                      p_dropout,
                      softmax_scale,
                      is_causal,
-                     /*num_splits=*/1);
+                     num_splits);

+    launch(params, stream, /*configure=*/true);
+
+    at::Tensor dk_accum, dv_accum;
+    if (params.num_splits > 1) {
+        // dk_accum = torch::zeros({total_k, num_heads, head_size}, opts.dtype(at::kFloat));
+        // dv_accum = torch::zeros({total_k, num_heads, head_size}, opts.dtype(at::kFloat));
+        // params.dk_accum_ptr = dk_accum.data_ptr();
+        // params.dv_accum_ptr = dv_accum.data_ptr();
+        dk.zero_();
+        dv.zero_();
+    } else {
+        // params.dk_accum_ptr = nullptr;
+        // params.dv_accum_ptr = nullptr;
+    }

     auto gen = at::get_generator_or_default<at::CUDAGeneratorImpl>(
         gen_, at::cuda::detail::getDefaultCUDAGenerator());
...
@@ -461,7 +477,12 @@ mha_bwd(const at::Tensor &dout, // total_q x num_heads, x head_size
         params.philox_args = gen->philox_cuda_state(counter_offset);
     }

-    launch(params, stream);
+    launch(params, stream, /*configure=*/false);
+
+    // if (params.num_splits > 1) {
+    //     dk.copy_(dk_accum);
+    //     dv.copy_(dv_accum);
+    // }

     return { dq, dk, dv, softmax_d };
}
...
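Taken together, the mha_bwd changes introduce a two-phase launch: a configure pass that only picks params.num_splits, followed by the real launch. A condensed sketch of that control flow, paraphrased from the diff above rather than quoted verbatim:

    launch(params, stream, /*configure=*/true);   // only fills params.num_splits via the occupancy heuristic
    if (params.num_splits > 1) {
        dk.zero_();   // the kernel accumulates partial dK/dV with atomic adds,
        dv.zero_();   // so the buffers must start out zeroed
    }
    launch(params, stream, /*configure=*/false);  // runs the backward kernel on a (batch, heads, num_splits) grid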
csrc/flash_attn/src/fmha.h

@@ -140,6 +140,10 @@ struct FMHA_dgrad_params : public FMHA_fprop_params {
     void *__restrict__ dk_ptr;
     void *__restrict__ dv_ptr;

+    // // To accumulate dK and dV in case we're splitting the bwd along seqlen_q dimension
+    // void *__restrict__ dk_accum_ptr;
+    // void *__restrict__ dv_accum_ptr;
+
     // The stride between rows of the dQ, dK and dV matrices.
     // TD [2022-04-16]: We're using 32-bit indexing to save registers.
     // The code probably won't work for arrays larger than 2GB.
...
@@ -193,7 +197,7 @@ struct Launch_params{
void run_fmha_fp16_sm80(Launch_params<FMHA_fprop_params> &launch_params);

-void run_fmha_dgrad_fp16_sm80(const FMHA_dgrad_params &params, cudaStream_t stream);
+void run_fmha_dgrad_fp16_sm80(FMHA_dgrad_params &params, cudaStream_t stream, const bool configure);

void run_fmha_block_fp16_sm80(Launch_params<FMHA_fprop_params> &launch_params, const bool configure);
...
csrc/flash_attn/src/fmha/gmem_tile.h

@@ -28,6 +28,9 @@

#pragma once

+#include <cuda_fp16.h>
+#include <cuda_bf16.h>
+
#include <fmha/utils.h>

namespace fmha {
...
@@ -41,7 +44,8 @@ template<
    // The number of rows of Q, K or V loaded by this tile.
    int ROWS_,
    // The number of columns.
-    int COLS
+    int COLS,
+    int BYTES_PER_LDGS_ = 16
>
struct Gmem_tile_qkv {
...
@@ -49,7 +53,7 @@ struct Gmem_tile_qkv {
    static constexpr int BYTES_PER_ELEMENT = BITS_PER_ELEMENT / 8;
    // The size of each LDG.
-    static constexpr int BYTES_PER_LDG = 16;
+    static constexpr int BYTES_PER_LDG = BYTES_PER_LDGS_;
    // The size of a row in bytes.
    static constexpr int BYTES_PER_ROW = COLS * BITS_PER_ELEMENT / 8;
...
@@ -130,6 +134,42 @@ struct Gmem_tile_qkv {
        }
    }

+    template <typename elem_type>
+    inline __device__ void atomic_add(const uint4 (&data)[LDGS]) {
+        int row_ = tidx_ / THREADS_PER_ROW;
+        #pragma unroll
+        for (int ii = 0; ii < LDGS; ++ii) {
+            using elem2_type = typename std::conditional<std::is_same<elem_type, __half>::value,
+                                                         __half2, __nv_bfloat162>::type;
+            // char *ptr_ = ptr + (int64_t)ii * ROWS_PER_LDG * row_stride_in_bytes;
+            elem2_type *ptr_ = reinterpret_cast<elem2_type *>(ptr + (uint32_t)ii * ROWS_PER_LDG * row_stride_in_bytes);
+            if ((row_ + ii * ROWS_PER_LDG) < min(ROWS, actual_seqlen)) {
+                #pragma unroll
+                for (int jj = 0; jj < 4; ++jj) {
+                    atomicAdd(ptr_ + jj, reinterpret_cast<const elem2_type (&)[4]>(data[ii])[jj]);
+                }
+            }
+        }
+    }
+
+    // Not being used. This only supports converting from fp16 -> fp32 for now (not bf16 -> fp32).
+    inline __device__ void atomic_add_float(const uint4 (&data)[LDGS]) {
+        static_assert(BYTES_PER_ELEMENT == 4);  // Only support fp32
+        int row_ = tidx_ / THREADS_PER_ROW;
+        #pragma unroll
+        for (int ii = 0; ii < LDGS; ++ii) {
+            // char *ptr_ = ptr + (int64_t)ii * ROWS_PER_LDG * row_stride_in_bytes;
+            float *ptr_ = reinterpret_cast<float *>(ptr + (uint32_t)ii * ROWS_PER_LDG * row_stride_in_bytes);
+            if ((row_ + ii * ROWS_PER_LDG) < min(ROWS, actual_seqlen)) {
+                #pragma unroll
+                for (int jj = 0; jj < 4; ++jj) {
+                    const float2 data_f = fmha::half2_unpack<__half>(reinterpret_cast<const uint32_t (&)[4]>(data[ii])[jj]);
+                    atomicAdd(ptr_ + jj * 2, data_f.x);
+                    atomicAdd(ptr_ + jj * 2 + 1, data_f.y);
+                }
+            }
+        }
+    }
+
    inline __device__ void move(const int steps = 1) {
        // ptr += (int64_t)ROWS * row_stride_in_bytes * steps;
        ptr += (uint32_t)ROWS * row_stride_in_bytes * steps;
...
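The new atomic_add path is what lets several blockIdx.z slices fold their partial dK/dV into the same fp16/bf16 buffer. A minimal standalone illustration of the underlying primitive (a hypothetical kernel, not part of this repo; packed __half2 atomicAdd requires sm_60+, and the __nv_bfloat162 variant requires sm_80+):

    #include <cuda_fp16.h>

    // Each split adds its partial gradient into dst, two fp16 lanes per hardware atomic,
    // mirroring what Gmem_tile_qkv::atomic_add does for every 16-byte LDG above.
    __global__ void accumulate_partial_fp16(__half2 *dst, const __half2 *partial, int n2) {
        int i = blockIdx.x * blockDim.x + threadIdx.x;
        if (i < n2) {
            atomicAdd(dst + i, partial[i]);
        }
    }

The commented-out atomic_add_float and the *_accum_ptr buffers in fmha_api.cpp / fmha.h sketch an alternative that would accumulate in fp32 instead, trading extra memory for lower numerical error.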
csrc/flash_attn/src/fmha/kernel_traits.h

@@ -76,6 +76,12 @@ struct FMHA_kernel_traits {
    using Gmem_tile_do = fmha::Gmem_tile_qkv<Cta_tile_p, fmha::BITS_PER_ELEMENT_A, STEP, D>;

+    // // The global memory tile to store the accumulated dK and dV
+    // // Hack: we set BYTES_PER_LDGS=32 to emulate the access pattern of dK and dV,
+    // // where there are 16 bits per element and 16 bytes per load. In reality we won't
+    // // be issuing any load or store of size 32 bytes.
+    // using Gmem_tile_dkv_accum = fmha::Gmem_tile_qkv<Cta_tile_o, 32, S, D, 32>;
+
    // The global memory tile to store the softmax sum.
    using Gmem_softmax_sum = fmha::Gmem_summary_stats<Cta_tile_p>;
...
csrc/flash_attn/src/fmha_dgrad_fp16_kernel_loop.sm80.cu

@@ -6,13 +6,45 @@

#include "fmha.h"
#include "fmha_dgrad_kernel_1xN_loop.h"

+// Find the number of splits that maximizes the occupancy. For example, if we have
+// batch * n_heads = 48 and we have 108 SMs, having 2 splits (efficiency = 0.89) is
+// better than having 3 splits (efficiency = 0.67). However, we also don't want too many
+// splits as that would incur more HBM reads/writes.
+// Moreover, more than 1 split incurs extra cost of zeroing out dk/dv and doing atomic add
+// instead of just writing.
+// So for num_splits > 1, we divide the efficiency by some factor (e.g. 1.25, depending on seqlen)
+// to account for this. Moreover, more splits means atomic add will be slower.
+int num_splits_heuristic_bwd(int batch_nheads, int num_SMs, int ctas_per_sm, int max_splits,
+                             int seqlen, bool is_causal) {
+    float max_efficiency = 0.f;
+    int best_num_splits = 1;
+    std::vector<float> efficiency;
+    efficiency.reserve(max_splits);
+    float discount_factor = 1.f + 512.0 / seqlen;  // 1.25 for seqlen 2k, 1.125 for 4k.
+    discount_factor *= is_causal ? 1.1 : 1.f;  // causal makes it even slower.
+    for (int num_splits = 1; num_splits <= max_splits; num_splits++) {
+        float n_waves = float(batch_nheads * num_splits) / (num_SMs * ctas_per_sm);
+        float eff_raw = n_waves / ceil(n_waves);
+        // Heuristic: each increase in num_splits results in 6% slowdown, up to maybe 8 splits.
+        float eff = num_splits == 1
+            ? eff_raw
+            : (eff_raw - 0.07 * std::min(num_splits - 2, 6)) / discount_factor;
+        // printf("num_splits = %d, eff_raw = %f, eff = %f\n", num_splits, eff_raw, eff);
+        if (eff > max_efficiency) { max_efficiency = eff; best_num_splits = num_splits; }
+        efficiency.push_back(eff);
+    }
+    // printf("num_splits chosen = %d\n", best_num_splits);
+    return best_num_splits;
+}
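To make the numbers in the comment concrete, here is a small standalone check of the wave-quantization efficiency the heuristic maximizes, using the quoted batch * n_heads = 48 and 108 SMs and assuming ctas_per_sm = 1 (an assumption made only for illustration):

    #include <cmath>
    #include <cstdio>

    // Prints the raw efficiency, before the discount factor for extra splits is applied.
    int main() {
        const int batch_nheads = 48, num_SMs = 108, ctas_per_sm = 1;
        for (int num_splits = 1; num_splits <= 3; ++num_splits) {
            float n_waves = float(batch_nheads * num_splits) / (num_SMs * ctas_per_sm);
            printf("num_splits = %d -> efficiency = %.2f\n", num_splits, n_waves / std::ceil(n_waves));
        }
        return 0;  // prints 0.44, 0.89, 0.67: two splits fill the GPU best
    }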
+
template<typename Kernel_traits, bool Is_dropout, bool Is_causal, int loop_steps=-1>
__global__ void fmha_dgrad_fp16_sm80_dq_dk_dv_loop_kernel(FMHA_dgrad_params params) {
    fmha::compute_dq_dk_dv_1xN<Kernel_traits, Is_dropout, Is_causal, loop_steps>(params);
}

template<typename Kernel_traits>
-void run_fmha_dgrad_fp16_sm80_loop_(const FMHA_dgrad_params &params, cudaStream_t stream) {
+void run_fmha_dgrad_fp16_sm80_loop_(FMHA_dgrad_params &params, cudaStream_t stream, const bool configure) {
    constexpr int smem_size_softmax = Kernel_traits::Cta_tile_p::M * Kernel_traits::Cta_tile_p::WARPS_N * sizeof(float);
    constexpr int smem_size_q = Kernel_traits::Smem_tile_q::BYTES_PER_TILE;
    constexpr int smem_size_v = Kernel_traits::Smem_tile_v::BYTES_PER_TILE;
...
@@ -46,41 +78,58 @@ void run_fmha_dgrad_fp16_sm80_loop_(const FMHA_dgrad_params &params, cudaStream_t stream)
        FMHA_CHECK_CUDA(cudaFuncSetAttribute(
            kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size_dq_dk_dv));
    }
-    dim3 grid(params.b, params.h);
+
+    // Automatically set num_splits to maximize occupancy
+    if (params.num_splits <= 0) {
+        int ctas_per_sm;
+        cudaError status_ = cudaOccupancyMaxActiveBlocksPerMultiprocessor(
+            &ctas_per_sm, kernel, Kernel_traits::THREADS, smem_size_dq_dk_dv);
+        auto dprops = at::cuda::getCurrentDeviceProperties();
+        // printf("CTAS_PER_SM = %d, nSMs = %d\n", ctas_per_sm, dprops->multiProcessorCount);
+        constexpr int M = Kernel_traits::Cta_tile_p::M;
+        // We don't want more than 10 splits due to numerical error.
+        // Numerical error on dk/dv scales as sqrt(num_splits).
+        params.num_splits = num_splits_heuristic_bwd(
+            params.b * params.h, dprops->multiProcessorCount, ctas_per_sm,
+            /*max_splits=*/std::min(10, (params.seqlen_q + M - 1) / M),
+            params.seqlen_k, params.is_causal);
+    }
+    if (configure) return;
+    dim3 grid(params.b, params.h, params.num_splits);
    kernel<<<grid, Kernel_traits::THREADS, smem_size_dq_dk_dv, stream>>>(params);
    FMHA_CHECK_CUDA(cudaPeekAtLastError());
});
}

-void run_fmha_dgrad_fp16_sm80(const FMHA_dgrad_params &params, cudaStream_t stream) {
+void run_fmha_dgrad_fp16_sm80(FMHA_dgrad_params &params, cudaStream_t stream, const bool configure) {
    // work around for MSVC issue
    FP16_SWITCH(params.is_bf16, [&] {
        auto dprops = at::cuda::getCurrentDeviceProperties();
        if (params.d == 16) {
            if (params.seqlen_k == 128) {
                using Kernel_traits = FMHA_kernel_traits<128, 16, 16, 1, 8, 0x08u, elem_type>;
-                run_fmha_dgrad_fp16_sm80_loop_<Kernel_traits>(params, stream);
+                run_fmha_dgrad_fp16_sm80_loop_<Kernel_traits>(params, stream, configure);
            } else if (params.seqlen_k == 256) {
                using Kernel_traits = FMHA_kernel_traits<256, 16, 16, 1, 8, 0x08u, elem_type>;
-                run_fmha_dgrad_fp16_sm80_loop_<Kernel_traits>(params, stream);
+                run_fmha_dgrad_fp16_sm80_loop_<Kernel_traits>(params, stream, configure);
            } else {
                // TD [2022-05-15] 512 gives wrong results rn
                // using Kernel_traits = FMHA_kernel_traits<512, 16, 16, 1, 8, 0x08u, elem_type>;
                using Kernel_traits = FMHA_kernel_traits<256, 16, 16, 1, 8, 0x08u, elem_type>;
-                run_fmha_dgrad_fp16_sm80_loop_<Kernel_traits>(params, stream);
+                run_fmha_dgrad_fp16_sm80_loop_<Kernel_traits>(params, stream, configure);
            }
        } else if (params.d == 32) {
            if (params.seqlen_k == 128) {
                using Kernel_traits = FMHA_kernel_traits<128, 32, 16, 1, 8, 0x08u, elem_type>;
-                run_fmha_dgrad_fp16_sm80_loop_<Kernel_traits>(params, stream);
+                run_fmha_dgrad_fp16_sm80_loop_<Kernel_traits>(params, stream, configure);
            } else if (params.seqlen_k >= 256) {
                using Kernel_traits = FMHA_kernel_traits<256, 32, 16, 1, 8, 0x08u, elem_type>;
-                run_fmha_dgrad_fp16_sm80_loop_<Kernel_traits>(params, stream);
+                run_fmha_dgrad_fp16_sm80_loop_<Kernel_traits>(params, stream, configure);
            }
        } else if (params.d == 64) {
            if (params.seqlen_k == 128) {
                using Kernel_traits = FMHA_kernel_traits<128, 64, 16, 1, 8, 0x08u, elem_type>;
-                run_fmha_dgrad_fp16_sm80_loop_<Kernel_traits>(params, stream);
+                run_fmha_dgrad_fp16_sm80_loop_<Kernel_traits>(params, stream, configure);
            } else if (params.seqlen_k >= 256) {
                if (dprops->major == 8 && dprops->minor == 0) {
                    // Don't share smem for K & V, and don't keep V in registers
...
@@ -88,45 +137,18 @@ void run_fmha_dgrad_fp16_sm80(const FMHA_dgrad_params &params, cudaStream_t stream)
                    // uses more shared memory, which is fine on A100 but not other GPUs.
                    // For other GPUs, we keep V in registers.
                    using Kernel_traits = FMHA_kernel_traits<256, 64, 16, 1, 8, 0x100u, elem_type>;
-                    run_fmha_dgrad_fp16_sm80_loop_<Kernel_traits>(params, stream);
+                    run_fmha_dgrad_fp16_sm80_loop_<Kernel_traits>(params, stream, configure);
                } else if (dprops->major == 8 && dprops->minor > 0) {
                    using Kernel_traits = FMHA_kernel_traits<256, 64, 16, 1, 8, 0x08u, elem_type>;
-                    run_fmha_dgrad_fp16_sm80_loop_<Kernel_traits>(params, stream);
+                    run_fmha_dgrad_fp16_sm80_loop_<Kernel_traits>(params, stream, configure);
                } else if (dprops->major == 7 && dprops->minor == 5) {
                    using Kernel_traits = FMHA_kernel_traits<128, 64, 16, 1, 8, 0x08u, elem_type>;
-                    run_fmha_dgrad_fp16_sm80_loop_<Kernel_traits>(params, stream);
+                    run_fmha_dgrad_fp16_sm80_loop_<Kernel_traits>(params, stream, configure);
                }
            }
        } else if (params.d == 128) {
            using Kernel_traits = FMHA_kernel_traits<128, 128, 16, 1, 8, 0x100u, elem_type>;
-            run_fmha_dgrad_fp16_sm80_loop_<Kernel_traits>(params, stream);
+            run_fmha_dgrad_fp16_sm80_loop_<Kernel_traits>(params, stream, configure);
        }
-        // if (params.d == 64) {
-        //     if (dprops->major == 7 && dprops->minor == 5) {
-        //         using Kernel_traits = FMHA_kernel_traits<128, 64, 16, 1, 8, 0x08u, elem_type>;
-        //         run_fmha_dgrad_fp16_sm80_loop_<Kernel_traits>(params, stream);
-        //     } else {
-        //         if( params.seqlen_k == 128 ) {
-        //             using Kernel_traits = FMHA_kernel_traits<128, 64, 16, 1, 8, 0x08u, elem_type>;
-        //             run_fmha_dgrad_fp16_sm80_loop_<Kernel_traits>(params, stream);
-        //         } else if( params.seqlen_k >= 256 ) {
-        //             if (dprops->major == 8 && dprops->minor == 0) {
-        //                 // Don't share smem for K & V, and don't keep V in registers
-        //                 // This speeds things up by 2-3% by avoiding register spills, but it
-        //                 // uses more shared memory, which is fine on A100 but not other GPUs.
-        //                 // For other GPUs, we keep V in registers.
-        //                 using Kernel_traits = FMHA_kernel_traits<256, 64, 16, 1, 8, 0x100u, elem_type>;
-        //                 run_fmha_dgrad_fp16_sm80_loop_<Kernel_traits>(params, stream);
-        //             } else if (dprops->major == 8 && dprops->minor > 0) {
-        //                 using Kernel_traits = FMHA_kernel_traits<256, 64, 16, 1, 8, 0x08u, elem_type>;
-        //                 run_fmha_dgrad_fp16_sm80_loop_<Kernel_traits>(params, stream);
-        //             }
-        //         }
-        //     }
-        // }
-        // if (params.d == 128) {
-        //     using Kernel_traits = FMHA_kernel_traits<128, 128, 16, 1, 8, 0x100u_elem_type>;
-        //     run_fmha_dgrad_fp16_sm80_loop_<Kernel_traits>(params, stream);
-        // }
    });
}
\ No newline at end of file
csrc/flash_attn/src/fmha_dgrad_kernel_1xN_loop.h

@@ -135,6 +135,9 @@ inline __device__ void compute_dq_dk_dv_1xN_one_iter(const Params &params, Prng
    // The thread index.
    const int tidx = threadIdx.x;

+    // How many steps to jump per iteration, which is the same as params.num_splits.
+    const int step_stride = gridDim.z;
+
    const BlockInfoPadded<Kernel_traits::THREADS> binfo(params, bidb, bidh, tidx);
    // if( binfo.stop_early() ) return;
    if( binfo.stop_early(loop_step_idx * Cta_tile_p::N) ) return;
...
@@ -184,18 +187,24 @@ inline __device__ void compute_dq_dk_dv_1xN_one_iter(const Params &params, Prng
    Gmem_softmax_sum gmem_softmax_d(params.dsoftmax_sum, params, tidx);

    static_assert(Cta_tile_p::N % Cta_tile_p::M == 0);
-    const int begin = Is_causal ? loop_step_idx * Cta_tile_p::N / Cta_tile_p::M : 0;
+    int begin = Is_causal ? loop_step_idx * Cta_tile_p::N / Cta_tile_p::M : 0;
+    // We want begin to be a multiple of gridDim.z
+    // This is because the row indices processed by each threadblock must align between the
+    // loop steps, otherwise we have a dependency between the blocks.
+    // For example, threadblock with blockIdx.z == 1 must process row indices that are
+    // k * gridDim.z + 1 for integer k.
+    const int begin_mod_z = begin % gridDim.z;
+    begin = begin_mod_z <= blockIdx.z ? begin - begin_mod_z : begin + gridDim.z - begin_mod_z;
    const int steps = (params.seqlen_q + Cta_tile_p::M - 1) / Cta_tile_p::M - begin;

    // Wind gmem tiles to the correct position.
-    gmem_q.move(begin);
-    gmem_do.move(begin);
-    gmem_o.move(begin);
-    gmem_dq.move(begin);
-    gmem_dq_tmp.move(begin);
+    gmem_q.move(begin + blockIdx.z);
+    gmem_do.move(begin + blockIdx.z);
+    gmem_o.move(begin + blockIdx.z);
+    gmem_dq.move(begin + blockIdx.z);
+    gmem_dq_tmp.move(begin + blockIdx.z);
    // TODO: need to move gmem_s if we want the intermediate result for debugging
-    gmem_softmax_lse.move(begin);
-    gmem_softmax_d.move(begin);
+    gmem_softmax_lse.move(begin + blockIdx.z);
+    gmem_softmax_d.move(begin + blockIdx.z);

    if( !Is_first ) {
        gmem_k.move(loop_step_idx);
...
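A host-side sketch of the rounding above, with made-up values: for gridDim.z = 4 and a causal lower bound of begin = 5, every z-slice starts on a row block congruent to its blockIdx.z modulo gridDim.z, so a given row block is handled by the same threadblock across all K/V loop steps:

    #include <cstdio>

    int main() {
        const int grid_z = 4;     // gridDim.z, i.e. num_splits (illustrative value)
        const int begin_raw = 5;  // loop_step_idx * Cta_tile_p::N / Cta_tile_p::M at some causal step
        for (int z = 0; z < grid_z; ++z) {
            int begin_mod_z = begin_raw % grid_z;
            int begin = begin_mod_z <= z ? begin_raw - begin_mod_z : begin_raw + grid_z - begin_mod_z;
            printf("blockIdx.z = %d starts at row block %d, then strides by %d\n", z, begin + z, grid_z);
        }
        return 0;  // start blocks 8, 5, 6, 7: all >= 5 and each congruent to z (mod 4)
    }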
@@ -215,7 +224,6 @@ inline __device__ void compute_dq_dk_dv_1xN_one_iter(const Params &params, Prng
    float p_lse[Mma_tile_p::MMAS_M * 2];
    gmem_softmax_lse.load(reinterpret_cast<uint32_t(&)[Mma_tile_p::MMAS_M * 2]>(p_lse));
-    gmem_softmax_lse.move();

    if( !Is_first ) { __syncthreads(); }
    // Commit the data for Q, dO, and V to shared memory.
...
@@ -265,7 +273,6 @@ inline __device__ void compute_dq_dk_dv_1xN_one_iter(const Params &params, Prng
    float dp_sum[Mma_tile_p::MMAS_M * 2];
    gmem_softmax_d.load(reinterpret_cast<uint32_t(&)[Mma_tile_p::MMAS_M * 2]>(dp_sum));
-    gmem_softmax_d.move();

    // Commit the data for V to shared memory if it has not been done already.
    if( Kernel_traits::SHARE_SMEM_FOR_K_AND_V ) {
...
@@ -301,9 +308,8 @@ inline __device__ void compute_dq_dk_dv_1xN_one_iter(const Params &params, Prng
    fmha::Clear_accumulator<fmha::Accumulator_type, Cta_tile_dkv::WARPS_K>::apply(acc_dk);

    // Load over the entire sequence length.
-    for( int l = 0; l < steps; l++ ) {
-        const int loop = (begin + l) * Cta_tile_p::M;
-        if( loop >= binfo.actual_seqlen_q )
+    for( int l = blockIdx.z; l < steps; l += step_stride ) {
+        if( (begin + l) * Cta_tile_p::M >= binfo.actual_seqlen_q )
            break;

        // Load the fragments for V.
...
@@ -352,9 +358,9 @@ inline __device__ void compute_dq_dk_dv_1xN_one_iter(const Params &params, Prng
        smem_s.store(frag_p);

        // Trigger the load for the next Q values.
-        if( l < steps - 1 ) {
+        if( l + step_stride < steps ) {
            gemm_q_k.smem_q.move_to_next_write_buffer();
-            gmem_q.move();
+            gmem_q.move(step_stride);
            gmem_q.load();
        }
...
@@ -427,12 +433,12 @@ inline __device__ void compute_dq_dk_dv_1xN_one_iter(const Params &params, Prng
        smem_kt.load(frag_kt[0], 0);

        // Trigger the load for the next dO values.
-        if( l < steps - 1 ) {
+        if( l + step_stride < steps ) {
            smem_do.move_to_next_write_buffer();
-            gmem_do.move();
+            gmem_do.move(step_stride);
            gmem_do.load();
            if (Is_first) {
-                gmem_o.move();
+                gmem_o.move(step_stride);
                gmem_o.load();
            }
        }
...
@@ -443,7 +449,7 @@ inline __device__ void compute_dq_dk_dv_1xN_one_iter(const Params &params, Prng
        smem_dp.store(frag_p);

        // gmem_s.store(frag_p, mask);
-        // gmem_s.move();
+        // gmem_s.move(step_stride);

        // Declare the accumulators for the 2nd gemm.
        fmha::Fragment_accumulator acc_dq[Mma_tile_dq::MMAS_M][Mma_tile_dq::MMAS_N];
...
@@ -520,7 +526,7 @@ inline __device__ void compute_dq_dk_dv_1xN_one_iter(const Params &params, Prng
        // }
        // __syncthreads();
        // Commit the values for Q and dO into shared memory.
-        if( l < steps - 1 ) {
+        if( l + step_stride < steps ) {
            gmem_q.commit(gemm_q_k.smem_q);
        }
...
@@ -529,15 +535,16 @@ inline __device__ void compute_dq_dk_dv_1xN_one_iter(const Params &params, Prng
        // __syncthreads();
        // Commit the values for Q and dO into shared memory.
-        if( l < steps - 1 ) {
+        if( l + step_stride < steps ) {
            gmem_do.commit(smem_do);
+            gmem_softmax_d.move(step_stride);
            if (Is_first) {
                dot_do_o<Gmem_tile_do::ROWS, Gmem_tile_do::THREADS_PER_ROW, elem_type>(
                    gmem_do.fetch_, gmem_o.fetch_, params.p_dropout, gmem_softmax_d, tidx);
            }
+            gmem_softmax_lse.move(step_stride);
            gmem_softmax_lse.load(reinterpret_cast<uint32_t(&)[Mma_tile_p::MMAS_M * 2]>(p_lse));
-            gmem_softmax_lse.move();
        }

        typename Smem_tile_st::Fragment frag_dpt[Mma_tile_dkv::MMAS_K][Mma_tile_dkv::MMAS_M];
...
@@ -567,9 +574,8 @@ inline __device__ void compute_dq_dk_dv_1xN_one_iter(const Params &params, Prng
        // Make sure dQ is in shared memory.
        __syncthreads();

-        if( l < steps - 1 ) {
+        if( l + step_stride < steps ) {
            gmem_softmax_d.load(reinterpret_cast<uint32_t(&)[Mma_tile_p::MMAS_M * 2]>(dp_sum));
-            gmem_softmax_d.move();
        }

        // Load from shared memory.
...
@@ -590,20 +596,20 @@ inline __device__ void compute_dq_dk_dv_1xN_one_iter(const Params &params, Prng
            // Output the values.
            gmem_dq.template store<elem_type>(dq_out, 0);
            // Move to the next part of the output.
-            gmem_dq.move();
+            gmem_dq.move(step_stride);
        } else {
            // Output the values.
            gmem_dq_tmp.store(dq_out, 0);
        }
        // Move to the next part of the output.
-        if( !(Is_first && Is_last) ) { gmem_dq_tmp.move(); }
+        if( !(Is_first && Is_last) ) { gmem_dq_tmp.move(step_stride); }

        // // Make sure the data is in shared memory.
        // __syncthreads();

        // Commit the values for Q and dO into shared memory.
-        if( l < steps - 1 ) {
+        if( l + step_stride < steps ) {
            gemm_q_k.smem_q.move_to_next_read_buffer();
            gemm_q_k.reload_q();
            smem_qt.move_to_next_read_buffer();
...
@@ -652,18 +658,34 @@ inline __device__ void compute_dq_dk_dv_1xN_one_iter(const Params &params, Prng
    uint4 dv_out[Smem_tile_dv::NUM_LDS];
    smem_dv.load(dv_out);
    Gmem_tile_dv gmem_dv(params.dv_ptr, params.dv_row_stride_in_elts, params.dv_head_stride_in_elts, binfo, tidx, false);
+    // using Gmem_tile_dkv_accum = typename Kernel_traits::Gmem_tile_dkv_accum;
+    // Gmem_tile_dkv_accum gmem_dv_accum(params.dv_accum_ptr, params.h * params.d, params.d, binfo, tidx, false);
+    // static_assert(Gmem_tile_dkv_accum::LDGS == Smem_tile_dv::NUM_LDS);
    if( !Is_first ) {
        gmem_dv.move(loop_step_idx);
+        // gmem_dv_accum.move(loop_step_idx);
    }
-    gmem_dv.store(dv_out);
+    if (gridDim.z == 1) {
+        gmem_dv.store(dv_out);
+    } else {
+        gmem_dv.template atomic_add<elem_type>(dv_out);
+        // gmem_dv_accum.atomic_add_float(dv_out);
+    }

    uint4 dk_out[Smem_tile_dk::NUM_LDS];
    smem_dk.load(dk_out);
    Gmem_tile_dk gmem_dk(params.dk_ptr, params.dk_row_stride_in_elts, params.dk_head_stride_in_elts, binfo, tidx, false);
+    // Gmem_tile_dkv_accum gmem_dk_accum(params.dk_accum_ptr, params.h * params.d, params.d, binfo, tidx, false);
    if( !Is_first ) {
        gmem_dk.move(loop_step_idx);
+        // gmem_dk_accum.move(loop_step_idx);
    }
-    gmem_dk.store(dk_out);
+    if (gridDim.z == 1) {
+        gmem_dk.store(dk_out);
+    } else {
+        gmem_dk.template atomic_add<elem_type>(dk_out);
+        // gmem_dk_accum.atomic_add_float(dk_out);
+    }
}

////////////////////////////////////////////////////////////////////////////////////////////////////
...
csrc/flash_attn/src/fmha_fprop_fp16_kernel.sm80.cu

@@ -162,40 +162,5 @@ void run_fmha_fp16_sm80(Launch_params<FMHA_fprop_params> &launch_params) {
        //     using Kernel_traits = FMHA_kernel_traits<128, 64, 16, 1, 4, 0x08u, elem_type>;
        //     run_fmha_fp16_sm80_loop_<Kernel_traits>(launch_params);
        // }
-        // if (launch_params.params.d == 64) {
-        //     if( launch_params.params.seqlen_k == 128 ) {
-        //         using Kernel_traits = FMHA_kernel_traits<128, 64, 16, 1, 4, 0x08u, elem_type>;
-        //         run_fmha_fp16_sm80_loop_<Kernel_traits>(launch_params);
-        //     } else if( launch_params.params.seqlen_k >= 256 ) {
-        //         if (dprops->major == 8 && dprops->minor >= 0) {
-        //             using Kernel_traits = FMHA_kernel_traits<256, 64, 16, 1, 4, 0x08u, elem_type>;
-        //             run_fmha_fp16_sm80_loop_<Kernel_traits>(launch_params);
-        //         } else if (dprops->major == 7 && dprops->minor == 5) {
-        //             if (launch_params.is_dropout) { // Need to use the same block size as backward
-        //                 using Kernel_traits = FMHA_kernel_traits<128, 64, 16, 1, 4, 0x08u, elem_type>;
-        //                 run_fmha_fp16_sm80_loop_<Kernel_traits>(launch_params);
-        //             } else {
-        //                 using Kernel_traits = FMHA_kernel_traits<256, 64, 16, 1, 4, 0x08u, elem_type>;
-        //                 run_fmha_fp16_sm80_loop_<Kernel_traits>(launch_params);
-        //             }
-        //         }
-        //     }
-        // }
-        // if (launch_params.params.d == 128) {
-        //     if( launch_params.params.seqlen_k == 128 ) {
-        //         using Kernel_traits = FMHA_kernel_traits<128, 128, 16, 1, 4, 0x08u, elem_type>;
-        //         run_fmha_fp16_sm80_loop_<Kernel_traits>(launch_params);
-        //     } else {
-        //         if (dprops->major == 8 && dprops->minor >= 0 && !launch_params.is_dropout) {
-        //             // TD [2022-06-05] Keep K in registers to reduce register spilling
-        //             // Gives about 6% speedup compared to using block size 128.
-        //             using Kernel_traits = FMHA_kernel_traits<256, 128, 16, 1, 4, 0x18u, elem_type>;
-        //             run_fmha_fp16_sm80_loop_<Kernel_traits>(launch_params);
-        //         } else { // Need to use the same block size as backward
-        //             using Kernel_traits = FMHA_kernel_traits<128, 128, 16, 1, 4, 0x08u, elem_type>;
-        //             run_fmha_fp16_sm80_loop_<Kernel_traits>(launch_params);
-        //         }
-        //     }
-        // }
    });
}
\ No newline at end of file
flash_attn/flash_attn_interface.py

@@ -7,10 +7,7 @@ import flash_attn_cuda

def _get_block_size(device, head_dim, is_dropout):
    assert head_dim in [16, 32, 64, 128]
-    if head_dim in [16, 32, 64]:
-        return 256
-    elif head_dim == 128:
-        return 256 if (torch.cuda.get_device_capability(device) == (8, 0)) else 128
+    return 256 if head_dim in [16, 32, 64] else 128


def _flash_attn_forward(q, k, v, out, cu_seqlens_q, cu_seqlens_k, max_seqlen_q, max_seqlen_k,
...
@@ -32,11 +29,17 @@ def _flash_attn_forward(q, k, v, out, cu_seqlens_q, cu_seqlens_k, max_seqlen_q,

def _flash_attn_backward(dout, q, k, v, out, softmax_lse, dq, dk, dv, cu_seqlens_q, cu_seqlens_k,
-                         max_seqlen_q, max_seqlen_k, dropout_p, softmax_scale, causal,
-                         generator=None):
+                         max_seqlen_q, max_seqlen_k, dropout_p, softmax_scale, causal,
+                         num_splits=0, generator=None):
+    """
+    num_splits: how much to parallelize over the seqlen_q dimension. num_splits=0 means
+    it will be set by an internal heuristic. Setting this too large (e.g. > 10) could make
+    the numerical error of dK and dV larger (scaling as sqrt(num_splits)).
+    This hyperparameter can be tuned for performance, but the default value (heuristic) should work fine.
+    """
    softmax_d = flash_attn_cuda.bwd(
        dout, q, k, v, out, softmax_lse, dq, dk, dv, cu_seqlens_q, cu_seqlens_k,
-        max_seqlen_q, max_seqlen_k, dropout_p, softmax_scale, False, causal, generator)
+        max_seqlen_q, max_seqlen_k, dropout_p, softmax_scale, False, causal, num_splits, generator)
    # if dk.isnan().any() or dk.isnan().any() or dv.isnan().any() or softmax_d.isnan().any():
    #     breakpoint()
    return dq, dk, dv, softmax_d
...
tests/test_flash_attn.py

@@ -356,7 +356,8 @@ def test_flash_attn_unpadded_qkvpacked(seqlen, d, dropout_p, causal, dtype):
    # rtol, atol = (3e-3, 3e-3) if not causal else (1e-3, 1e-3)
    # set seed
    torch.random.manual_seed(0)
-    batch_size = 32
+    # Set smaller batch size so it would trigger num_splits > 1
+    batch_size = 8
    nheads = 4
    x = torch.randn(batch_size, seqlen, nheads * d, device=device, dtype=dtype, requires_grad=True)
    Wqkv = torch.nn.Linear(nheads * d, 3 * nheads * d, device=device, dtype=dtype)
...
@@ -418,10 +419,11 @@ def test_flash_attn_unpadded_qkvpacked(seqlen, d, dropout_p, causal, dtype):
    if dropout_p == 0.0:
        assert dropout_mask.all()
    else:
-        assert 0.99 <= dropout_fraction / dropout_p <= 1.01
+        assert 0.98 <= dropout_fraction / dropout_p <= 1.02

    if is_sm80 or d < 128:  # Only run backward for d=128 on A100
-        assert (dqkv - dqkv_ref).abs().max().item() <= 2 * (dqkv_pt - dqkv_ref).abs().max().item()
+        # Error for dK and dV could be a bit higher if we're splitting along seqlen_q dimension
+        assert (dqkv - dqkv_ref).abs().max().item() <= 4 * (dqkv_pt - dqkv_ref).abs().max().item()
        # assert torch.allclose(dqkv, dqkv_ref, rtol=rtol, atol=atol)
...