gaoqiong / flash-attention

Commit 46fd2a20, authored Oct 24, 2022 by Tri Dao
Parent: 97e13de2

Support all head dims that are multiples of 8, up to 128

Showing 11 changed files with 111 additions and 103 deletions (+111 / -103).
README.md                                                +1  -1
csrc/flash_attn/fmha_api.cpp                             +5  -5
csrc/flash_attn/src/fmha/gmem_tile.h                     +16 -10
csrc/flash_attn/src/fmha_block_dgrad_kernel_1xN_loop.h   +18 -9
csrc/flash_attn/src/fmha_block_fprop_kernel_1xN.h        +10 -5
csrc/flash_attn/src/fmha_dgrad_fp16_kernel_loop.sm80.cu  +7  -20
csrc/flash_attn/src/fmha_dgrad_kernel_1xN_loop.h         +18 -9
csrc/flash_attn/src/fmha_fprop_fp16_kernel.sm80.cu       +7  -20
csrc/flash_attn/src/fmha_fprop_kernel_1xN.h              +10 -5
flash_attn/flash_attn_interface.py                       +2  -2
tests/test_flash_attn.py                                 +17 -17
README.md

@@ -35,7 +35,7 @@ PYTHONPATH=$PWD python benchmarks/benchmark_flash_attention.py
 FlashAttention currently supports:
 1. Turing or Ampere GPUs (e.g., A100, RTX 3090, T4, RTX 2080).
 2. fp16 and bf16 (bf16 requires Ampere GPUs).
-3. Head dimensions 16, 32, 64, 128 (head dim 128 backward requires A100).
+3. Head dimensions that are multiples of 8, up to 128 (e.g., 8, 16, 24, ..., 128). Head dim > 64 backward requires A100.

 Our tentative roadmap:
 1. [Jun 2022] Make package pip-installable.
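The relaxed support rule is easiest to see as a predicate. Below is a minimal Python sketch (not part of the commit; the helper name is made up for illustration) of which head dimensions the new check accepts.

def head_dim_supported(head_dim: int) -> bool:
    # New rule from this commit: any positive multiple of 8, up to 128.
    return head_dim % 8 == 0 and 0 < head_dim <= 128

# Old rule: only 16, 32, 64, 128. New rule: 8, 16, 24, ..., 128.
print([d for d in range(8, 136, 8) if head_dim_supported(d)])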
csrc/flash_attn/fmha_api.cpp

@@ -232,7 +232,7 @@ mha_fwd(const at::Tensor &q, // total_q x num_heads x head_size, total_q
     const int head_size = sizes[D_DIM];
     const int total_k = k.size(TOTAL_DIM);
     TORCH_CHECK(batch_size > 0);
-    TORCH_CHECK(head_size == 16 || head_size == 32 || head_size == 64 || head_size == 128);
+    TORCH_CHECK((head_size % 8 == 0) && (head_size <= 128));
     CHECK_SHAPE(q, total_q, num_heads, head_size);
     CHECK_SHAPE(k, total_k, num_heads, head_size);

@@ -241,7 +241,7 @@ mha_fwd(const at::Tensor &q, // total_q x num_heads x head_size, total_q
     CHECK_SHAPE(cu_seqlens_q, batch_size + 1);
     CHECK_SHAPE(cu_seqlens_k, batch_size + 1);
-    int blocksize_c = head_size == 128 ? 128 : 256;
+    int blocksize_c = head_size > 64 ? 128 : 256;
     // Need to round max_seqlen_k to multiples of blocksize_c
     int max_seqlen_k = ((max_seqlen_k_ + blocksize_c - 1) / blocksize_c) * blocksize_c;
     if( max_seqlen_k_ <= 128 ) {

@@ -386,8 +386,8 @@ mha_bwd(const at::Tensor &dout, // total_q x num_heads, x head_size
     const int head_size = sizes[D_DIM];
     const int total_k = k.size(TOTAL_DIM);
     TORCH_CHECK(batch_size > 0);
-    TORCH_CHECK(head_size == 16 || head_size == 32 || head_size == 64 || head_size == 128);
-    if (head_size == 128) {  // TODO: eventually we should support SM86 and SM70 with d=128 as well
+    TORCH_CHECK((head_size % 8 == 0) && (head_size <= 128));
+    if (head_size > 64) {  // TODO: eventually we should support SM86 and SM70 with d=128 as well
         TORCH_CHECK(is_sm80);
     }

@@ -402,7 +402,7 @@ mha_bwd(const at::Tensor &dout, // total_q x num_heads, x head_size
     CHECK_SHAPE(cu_seqlens_q, batch_size + 1);
     CHECK_SHAPE(cu_seqlens_k, batch_size + 1);
-    int blocksize_c = (head_size == 128 || (is_sm75 && head_size == 64)) ? 128 : 256;
+    int blocksize_c = (head_size > 64 || (is_sm75 && head_size > 32)) ? 128 : 256;
     int max_seqlen_k = ((max_seqlen_k_ + blocksize_c - 1) / blocksize_c) * blocksize_c;
     if( max_seqlen_k_ <= 128 ) {
         max_seqlen_k = 128;
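Taken together, the forward-pass checks now accept any head size that is a multiple of 8 up to 128, and the key block size drops from 256 to 128 once the head size exceeds 64. A minimal Python sketch of that logic (an illustration, not code from the commit; the clamp to 128 follows the context lines shown in mha_bwd):

def fwd_config(head_size: int, max_seqlen_k_: int):
    # Relaxed check: any multiple of 8 up to 128 (was: head_size in {16, 32, 64, 128}).
    assert head_size % 8 == 0 and head_size <= 128
    # Larger heads use a smaller key block (was: 128 only for head_size == 128).
    blocksize_c = 128 if head_size > 64 else 256
    # Round max_seqlen_k up to a multiple of blocksize_c.
    max_seqlen_k = ((max_seqlen_k_ + blocksize_c - 1) // blocksize_c) * blocksize_c
    if max_seqlen_k_ <= 128:
        max_seqlen_k = 128
    return blocksize_c, max_seqlen_k

print(fwd_config(40, 1000))   # (256, 1024)
print(fwd_config(72, 1000))   # (128, 1024)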
csrc/flash_attn/src/fmha/gmem_tile.h

@@ -81,11 +81,13 @@ struct Gmem_tile_qkv {
     // Ctor.
     template< typename BInfo >
     inline __device__ Gmem_tile_qkv(void *ptr_, const uint32_t row_stride_in_elts,
-                                    const uint32_t head_stride_in_elts, const BInfo &binfo, const int tidx, bool use_seqlen_q)
+                                    const uint32_t head_stride_in_elts, const int headdim,
+                                    const BInfo &binfo, const int tidx, bool use_seqlen_q)
         : row_stride_in_bytes(row_stride_in_elts * BYTES_PER_ELEMENT)
         , actual_seqlen(use_seqlen_q ? binfo.actual_seqlen_q : binfo.actual_seqlen_k)
         , ptr(reinterpret_cast<char *>(ptr_))
-        , tidx_(tidx) {
+        , tidx_(tidx)
+        , col_predicate((tidx % THREADS_PER_ROW) * (BYTES_PER_LDG / BYTES_PER_ELEMENT) < headdim) {
         // Compute the position in the sequence (within the CTA for the moment).
         int row = tidx / THREADS_PER_ROW;

@@ -121,7 +123,7 @@ struct Gmem_tile_qkv {
         for( int ii = 0; ii < LDGS; ++ii ) {
             // ptrs[ii] = ptr + (int64_t)ii * ROWS_PER_LDG * row_stride_in_bytes;
             ptrs[ii] = ptr + (uint32_t)ii * ROWS_PER_LDG * row_stride_in_bytes;
-            preds[ii] = ((row_ + ii * ROWS_PER_LDG) < min(ROWS, actual_seqlen));
+            preds[ii] = col_predicate && ((row_ + ii * ROWS_PER_LDG) < min(ROWS, actual_seqlen));
             fetch_[ii] = make_uint4(0, 0, 0, 0);
         }

@@ -140,7 +142,7 @@ struct Gmem_tile_qkv {
         for( int ii = 0; ii < LDGS; ++ii ) {
             // char *ptr_ = ptr + (int64_t)ii * ROWS_PER_LDG * row_stride_in_bytes;
             char *ptr_ = ptr + (uint32_t)ii * ROWS_PER_LDG * row_stride_in_bytes;
-            if( (row_ + ii * ROWS_PER_LDG) < min(ROWS, actual_seqlen) ) {
+            if( col_predicate && (row_ + ii * ROWS_PER_LDG) < min(ROWS, actual_seqlen) ) {
                 fmha::stg(ptr_, data[ii]);
             }
         }

@@ -154,7 +156,7 @@ struct Gmem_tile_qkv {
             using elem2_type = typename std::conditional<std::is_same<elem_type, __half>::value, __half2, __nv_bfloat162>::type;
             // char *ptr_ = ptr + (int64_t)ii * ROWS_PER_LDG * row_stride_in_bytes;
             elem2_type *ptr_ = reinterpret_cast<elem2_type *>(ptr + (uint32_t)ii * ROWS_PER_LDG * row_stride_in_bytes);
-            if( (row_ + ii * ROWS_PER_LDG) < min(ROWS, actual_seqlen) ) {
+            if( col_predicate && (row_ + ii * ROWS_PER_LDG) < min(ROWS, actual_seqlen) ) {
                 #pragma unroll
                 for (int jj = 0; jj < 4; ++jj) {
                     atomicAdd(ptr_ + jj, reinterpret_cast<const elem2_type (&)[4]>(data[ii])[jj]);

@@ -172,7 +174,7 @@ struct Gmem_tile_qkv {
         for( int ii = 0; ii < LDGS; ++ii ) {
             // char *ptr_ = ptr + (int64_t)ii * ROWS_PER_LDG * row_stride_in_bytes;
             float *ptr_ = reinterpret_cast<float *>(ptr + (uint32_t)ii * ROWS_PER_LDG * row_stride_in_bytes);
-            if( (row_ + ii * ROWS_PER_LDG) < min(ROWS, actual_seqlen) ) {
+            if( col_predicate && (row_ + ii * ROWS_PER_LDG) < min(ROWS, actual_seqlen) ) {
                 #pragma unroll
                 for (int jj = 0; jj < 4; ++jj) {
                     const float2 data_f = fmha::half2_unpack<__half>(reinterpret_cast<const uint32_t (&)[4]>(data[ii])[jj]);

@@ -201,6 +203,7 @@ struct Gmem_tile_qkv {
     const int tidx_;
     // The length of the sequence loaded by that memory tile.
     int actual_seqlen;
+    const bool col_predicate;
 };

 ////////////////////////////////////////////////////////////////////////////////////////////////////

@@ -246,11 +249,13 @@ struct Gmem_tile_o {
     template<typename BInfo>
     // inline __device__ Gmem_tile_o(void *ptr, const size_t row_stride_in_elts, const BInfo &binfo, const int tidx)
     inline __device__ Gmem_tile_o(void *ptr, const uint32_t row_stride_in_elts,
-                                  const uint32_t head_stride_in_elts, const BInfo &binfo, const int tidx)
+                                  const uint32_t head_stride_in_elts, const int headdim,
+                                  const BInfo &binfo, const int tidx)
         : row_stride_in_bytes(row_stride_in_elts * BYTES_PER_ELEMENT)
         , actual_seqlen_q(binfo.actual_seqlen_q)
         , ptr_(reinterpret_cast<char *>(ptr))
-        , tidx_(tidx) {
+        , tidx_(tidx)
+        , col_predicate((tidx % THREADS_PER_ROW) * (BYTES_PER_STG / BYTES_PER_ELEMENT) < headdim) {
         // Compute the position in the sequence (within the CTA for the moment).
         int row = tidx / THREADS_PER_ROW;

@@ -280,7 +285,7 @@ struct Gmem_tile_o {
         #pragma unroll
         for( int ii = 0; ii < STGS_PER_LOOP; ++ii ) {
             int jj = mi * STGS_PER_LOOP + ii;
-            if( row_ + jj * ROWS_PER_STG >= this->actual_seqlen_q ) {
+            if( (!col_predicate) || (row_ + jj * ROWS_PER_STG >= this->actual_seqlen_q) ) {
                 break;
             }

@@ -308,7 +313,7 @@ struct Gmem_tile_o {
         #pragma unroll
         for( int ii = 0; ii < STGS_PER_LOOP; ++ii ) {
             int jj = mi * STGS_PER_LOOP + ii;
-            if( row_ + jj * ROWS_PER_STG >= this->actual_seqlen_q ) {
+            if( (!col_predicate) || (row_ + jj * ROWS_PER_STG >= this->actual_seqlen_q) ) {
                 break;
             }

@@ -335,6 +340,7 @@ struct Gmem_tile_o {
     // The length of the sequence loaded by that memory tile.
     int actual_seqlen_q;
     const int tidx_;
+    const bool col_predicate;
 };

 ////////////////////////////////////////////////////////////////////////////////////////////////////
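The new col_predicate member is what lets a head dimension that is not a power of two live inside the next larger tile: threads whose column offset falls beyond the actual head dimension are masked off in every load and store. A small Python sketch of that predicate (the constants are illustrative, not taken from the kernel traits):

THREADS_PER_ROW = 8
BYTES_PER_LDG = 16
BYTES_PER_ELEMENT = 2                                 # fp16 / bf16
ELTS_PER_LDG = BYTES_PER_LDG // BYTES_PER_ELEMENT     # 8 elements per 16-byte load

def col_predicate(tidx: int, headdim: int) -> bool:
    # Thread tidx covers columns starting at (tidx % THREADS_PER_ROW) * ELTS_PER_LDG.
    return (tidx % THREADS_PER_ROW) * ELTS_PER_LDG < headdim

# With headdim = 40, only the first 5 of the 8 threads in a row touch memory;
# the remaining 3 are masked, so a 40-wide head fits in a 64-wide tile.
print([col_predicate(t, 40) for t in range(THREADS_PER_ROW)])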
csrc/flash_attn/src/fmha_block_dgrad_kernel_1xN_loop.h

@@ -138,19 +138,24 @@ inline __device__ void compute_block_dq_dk_dv_1xN_one_iter(const Params &params,
     Gemm1 gemm_q_k(&smem_[Smem_tile_do::BYTES_PER_TILE], tidx);
     // Allocate the global memory tile loader for Q.
-    Gmem_tile_q gmem_q(params.q_ptr, params.q_row_stride_in_elts, params.q_head_stride_in_elts, binfo, tidx, true);
+    Gmem_tile_q gmem_q(params.q_ptr, params.q_row_stride_in_elts, params.q_head_stride_in_elts,
+                       params.d, binfo, tidx, true);
     // Allocate the global memory tile loader for dQ.
-    Gmem_tile_dq gmem_dq(params.dq_ptr, params.dq_row_stride_in_elts, params.dq_head_stride_in_elts, binfo, tidx);
-    Gmem_tile_dq_tmp gmem_dq_tmp(params.o_tmp_ptr, params.o_row_stride_in_elts, params.o_head_stride_in_elts, binfo, tidx);
+    Gmem_tile_dq gmem_dq(params.dq_ptr, params.dq_row_stride_in_elts, params.dq_head_stride_in_elts,
+                         params.d, binfo, tidx);
+    Gmem_tile_dq_tmp gmem_dq_tmp(params.o_tmp_ptr, params.o_row_stride_in_elts, params.o_head_stride_in_elts,
+                                 params.d, binfo, tidx);
     // Allocate the global memory tile loader for S.
     Gmem_tile_s gmem_s(params, binfo, tidx);
     fmha::Mask<Cta_tile_p, Is_causal> mask(binfo, tidx, loop_step_idx);
     // Allocate the global memory tile loader for K.
-    Gmem_tile_k gmem_k(params.k_ptr, params.k_row_stride_in_elts, params.k_head_stride_in_elts, binfo, tidx, false);
+    Gmem_tile_k gmem_k(params.k_ptr, params.k_row_stride_in_elts, params.k_head_stride_in_elts,
+                       params.d, binfo, tidx, false);
     // Allocate the global memory tile loader for V.
-    Gmem_tile_v gmem_v(params.v_ptr, params.v_row_stride_in_elts, params.v_head_stride_in_elts, binfo, tidx, false);
+    Gmem_tile_v gmem_v(params.v_ptr, params.v_row_stride_in_elts, params.v_head_stride_in_elts,
+                       params.d, binfo, tidx, false);
     // The base pointer of smem_v;
     char *smem_v_ = &smem_[Smem_tile_do::BYTES_PER_TILE + Gemm1::SMEM_OFFSET_V];

@@ -160,7 +165,8 @@ inline __device__ void compute_block_dq_dk_dv_1xN_one_iter(const Params &params,
     Smem_tile_kt smem_kt(&smem_[Smem_tile_do::BYTES_PER_TILE + Gemm1::Smem_tile_q::BYTES_PER_TILE], tidx);
     // Allocate the global memory tile loader for dO.
-    Gmem_tile_do gmem_do(params.do_ptr, params.o_row_stride_in_elts, params.o_head_stride_in_elts, binfo, tidx, true);
+    Gmem_tile_do gmem_do(params.do_ptr, params.o_row_stride_in_elts, params.o_head_stride_in_elts,
+                         params.d, binfo, tidx, true);
     // Allocate the shared memory tile loader for dO.
     Smem_tile_do smem_do(&smem_[0], tidx);
     Smem_tile_dot smem_dot(&smem_[0], tidx);

@@ -172,7 +178,8 @@ inline __device__ void compute_block_dq_dk_dv_1xN_one_iter(const Params &params,
     Smem_tile_st smem_dp(&smem_[Smem_tile_do::BYTES_PER_TILE + Gemm1::SMEM_OFFSET_O + Smem_tile_dq::BYTES_PER_TILE + Smem_tile_st::BYTES_PER_TILE], tidx);
     // Allocate the global memory tile loader for O.
-    Gmem_tile_o gmem_o(params.o_ptr, params.o_row_stride_in_elts, params.o_head_stride_in_elts, binfo, tidx, true);
+    Gmem_tile_o gmem_o(params.o_ptr, params.o_row_stride_in_elts, params.o_head_stride_in_elts,
+                       params.d, binfo, tidx, true);
     // Allocate the shared memory tile loader for O. We use the same as K so be careful!!!
     Smem_tile_dq smem_dq(&smem_[Smem_tile_do::BYTES_PER_TILE + Gemm1::SMEM_OFFSET_O], tidx);

@@ -702,7 +709,8 @@ inline __device__ void compute_block_dq_dk_dv_1xN_one_iter(const Params &params,
     __syncthreads();
     uint4 dv_out[Smem_tile_dv::NUM_LDS];
     smem_dv.load(dv_out);
-    Gmem_tile_dv gmem_dv(params.dv_ptr, params.dv_row_stride_in_elts, params.dv_head_stride_in_elts, binfo, tidx, false);
+    Gmem_tile_dv gmem_dv(params.dv_ptr, params.dv_row_stride_in_elts, params.dv_head_stride_in_elts,
+                         params.d, binfo, tidx, false);
     if( !Is_first ) {
         gmem_dv.move(loop_step_idx);
     }

@@ -713,7 +721,8 @@ inline __device__ void compute_block_dq_dk_dv_1xN_one_iter(const Params &params,
     // for (int ii = 0; ii < Smem_tile_dk::NUM_LDS; ++ii) {
     //     dk_out[ii] = fmha::fmul4(dk_out[ii], params.scale_bmm1f);
     // }
-    Gmem_tile_dk gmem_dk(params.dk_ptr, params.dk_row_stride_in_elts, params.dk_head_stride_in_elts, binfo, tidx, false);
+    Gmem_tile_dk gmem_dk(params.dk_ptr, params.dk_row_stride_in_elts, params.dk_head_stride_in_elts,
+                         params.d, binfo, tidx, false);
     if( !Is_first ) {
         gmem_dk.move(loop_step_idx);
     }
csrc/flash_attn/src/fmha_block_fprop_kernel_1xN.h

@@ -97,10 +97,13 @@ inline __device__ void device_block_1xN_(const Params &params, const int bidb, c
     Gemm1 gemm_q_k(smem_, tidx);
     // Allocate the global memory tile loader for Q.
-    Gmem_tile_q gmem_q(params.q_ptr, params.q_row_stride_in_elts, params.q_head_stride_in_elts, binfo, tidx, true);
+    Gmem_tile_q gmem_q(params.q_ptr, params.q_row_stride_in_elts, params.q_head_stride_in_elts,
+                       params.d, binfo, tidx, true);
     // Allocate the global memory tile loader for O.
-    Gmem_tile_o gmem_o(params.o_ptr, params.o_row_stride_in_elts, params.o_head_stride_in_elts, binfo, tidx);
-    Gmem_tile_o_tmp gmem_o_tmp(params.o_tmp_ptr, params.o_row_stride_in_elts, params.o_head_stride_in_elts, binfo, tidx);
+    Gmem_tile_o gmem_o(params.o_ptr, params.o_row_stride_in_elts, params.o_head_stride_in_elts,
+                       params.d, binfo, tidx);
+    Gmem_tile_o_tmp gmem_o_tmp(params.o_tmp_ptr, params.o_row_stride_in_elts, params.o_head_stride_in_elts,
+                               params.d, binfo, tidx);
     // Allocate the global memory tile loader for S.
     Gmem_tile_s gmem_s(params, binfo, tidx);
     Gmem_softmax_sum gmem_softmax_lse(params.softmax_lse_ptr, params, tidx);

@@ -122,9 +125,11 @@ inline __device__ void device_block_1xN_(const Params &params, const int bidb, c
     fmha::Mask<Cta_tile_p, Is_causal> mask(binfo, tidx, loop_step_idx);
     // Allocate the global memory tile loader for K.
-    Gmem_tile_k gmem_k(params.k_ptr, params.k_row_stride_in_elts, params.k_head_stride_in_elts, binfo, tidx, false);
+    Gmem_tile_k gmem_k(params.k_ptr, params.k_row_stride_in_elts, params.k_head_stride_in_elts,
+                       params.d, binfo, tidx, false);
     // Allocate the global memory tile loader for V.
-    Gmem_tile_v gmem_v(params.v_ptr, params.v_row_stride_in_elts, params.v_head_stride_in_elts, binfo, tidx, false);
+    Gmem_tile_v gmem_v(params.v_ptr, params.v_row_stride_in_elts, params.v_head_stride_in_elts,
+                       params.d, binfo, tidx, false);
     // The base pointer of smem_v;
     char *smem_v_ = &smem_[Gemm1::SMEM_OFFSET_V];
csrc/flash_attn/src/fmha_dgrad_fp16_kernel_loop.sm80.cu

@@ -105,32 +105,19 @@ void run_fmha_dgrad_fp16_sm80(FMHA_dgrad_params &params, cudaStream_t stream, co
     // work around for MSVC issue
     FP16_SWITCH(params.is_bf16, [&] {
         auto dprops = at::cuda::getCurrentDeviceProperties();
-        if (params.d == 16) {
-            if( params.seqlen_k == 128 ) {
-                using Kernel_traits = FMHA_kernel_traits<128, 16, 16, 1, 8, 0x08u, elem_type>;
-                run_fmha_dgrad_fp16_sm80_loop_<Kernel_traits>(params, stream, configure);
-            } else if( params.seqlen_k == 256 ) {
-                using Kernel_traits = FMHA_kernel_traits<256, 16, 16, 1, 8, 0x08u, elem_type>;
-                run_fmha_dgrad_fp16_sm80_loop_<Kernel_traits>(params, stream, configure);
-            } else {
-                // TD [2022-05-15] 512 gives wrong results rn
-                // using Kernel_traits = FMHA_kernel_traits<512, 16, 16, 1, 8, 0x08u, elem_type>;
-                using Kernel_traits = FMHA_kernel_traits<256, 16, 16, 1, 8, 0x08u, elem_type>;
-                run_fmha_dgrad_fp16_sm80_loop_<Kernel_traits>(params, stream, configure);
-            }
-        } else if (params.d == 32) {
+        if (params.d <= 32) {
             if( params.seqlen_k == 128 ) {
                 using Kernel_traits = FMHA_kernel_traits<128, 32, 16, 1, 8, 0x08u, elem_type>;
                 run_fmha_dgrad_fp16_sm80_loop_<Kernel_traits>(params, stream, configure);
             } else if( params.seqlen_k >= 256 ) {
                 using Kernel_traits = FMHA_kernel_traits<256, 32, 16, 1, 8, 0x08u, elem_type>;
                 run_fmha_dgrad_fp16_sm80_loop_<Kernel_traits>(params, stream, configure);
             }
-        } else if (params.d == 64) {
+        } else if (params.d <= 64) {
             if( params.seqlen_k == 128 ) {
                 using Kernel_traits = FMHA_kernel_traits<128, 64, 16, 1, 8, 0x08u, elem_type>;
                 run_fmha_dgrad_fp16_sm80_loop_<Kernel_traits>(params, stream, configure);
             } else if( params.seqlen_k >= 256 ) {
                 if (dprops->major == 8 && dprops->minor == 0) {
                     // Don't share smem for K & V, and don't keep V in registers
                     // This speeds things up by 2-3% by avoiding register spills, but it

@@ -146,7 +133,7 @@ void run_fmha_dgrad_fp16_sm80(FMHA_dgrad_params &params, cudaStream_t stream, co
                     run_fmha_dgrad_fp16_sm80_loop_<Kernel_traits>(params, stream, configure);
                 }
             }
-        } else if (params.d == 128) {
+        } else if (params.d <= 128) {
             using Kernel_traits = FMHA_kernel_traits<128, 128, 16, 1, 8, 0x100u, elem_type>;
             run_fmha_dgrad_fp16_sm80_loop_<Kernel_traits>(params, stream, configure);
         }
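The backward launcher therefore no longer matches the head dim exactly (16/32/64/128); any head dim up to the next supported tile width reuses that kernel, with the column predicate masking the padding. A small Python sketch (not part of the commit) of the new bucketing:

def dgrad_kernel_bucket(d: int) -> int:
    # Mirrors the dispatch: d <= 32 -> 32-wide traits, d <= 64 -> 64-wide, else 128-wide.
    assert d % 8 == 0 and d <= 128
    if d <= 32:
        return 32
    elif d <= 64:
        return 64
    else:
        return 128

print(dgrad_kernel_bucket(40))   # 64  -> FMHA_kernel_traits<..., 64, ...>
print(dgrad_kernel_bucket(80))   # 128 -> FMHA_kernel_traits<..., 128, ...>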
csrc/flash_attn/src/fmha_dgrad_kernel_1xN_loop.h

@@ -144,19 +144,24 @@ inline __device__ void compute_dq_dk_dv_1xN_one_iter(const Params &params, Prng
     Gemm1 gemm_q_k(&smem_[Smem_tile_do::BYTES_PER_TILE], tidx);
     // Allocate the global memory tile loader for Q.
-    Gmem_tile_q gmem_q(params.q_ptr, params.q_row_stride_in_elts, params.q_head_stride_in_elts, binfo, tidx, true);
+    Gmem_tile_q gmem_q(params.q_ptr, params.q_row_stride_in_elts, params.q_head_stride_in_elts,
+                       params.d, binfo, tidx, true);
     // Allocate the global memory tile loader for dQ.
-    Gmem_tile_dq gmem_dq(params.dq_ptr, params.dq_row_stride_in_elts, params.dq_head_stride_in_elts, binfo, tidx);
-    Gmem_tile_dq_tmp gmem_dq_tmp(params.o_tmp_ptr, params.o_row_stride_in_elts, params.o_head_stride_in_elts, binfo, tidx);
+    Gmem_tile_dq gmem_dq(params.dq_ptr, params.dq_row_stride_in_elts, params.dq_head_stride_in_elts,
+                         params.d, binfo, tidx);
+    Gmem_tile_dq_tmp gmem_dq_tmp(params.o_tmp_ptr, params.o_row_stride_in_elts, params.o_head_stride_in_elts,
+                                 params.d, binfo, tidx);
     // Allocate the global memory tile loader for S.
     Gmem_tile_s gmem_s(params, binfo, tidx);
     fmha::Mask<Cta_tile_p, Is_causal> mask(binfo, tidx, loop_step_idx);
     // Allocate the global memory tile loader for K.
-    Gmem_tile_k gmem_k(params.k_ptr, params.k_row_stride_in_elts, params.k_head_stride_in_elts, binfo, tidx, false);
+    Gmem_tile_k gmem_k(params.k_ptr, params.k_row_stride_in_elts, params.k_head_stride_in_elts,
+                       params.d, binfo, tidx, false);
     // Allocate the global memory tile loader for V.
-    Gmem_tile_v gmem_v(params.v_ptr, params.v_row_stride_in_elts, params.v_head_stride_in_elts, binfo, tidx, false);
+    Gmem_tile_v gmem_v(params.v_ptr, params.v_row_stride_in_elts, params.v_head_stride_in_elts,
+                       params.d, binfo, tidx, false);
     // The base pointer of smem_v;
     char *smem_v_ = &smem_[Smem_tile_do::BYTES_PER_TILE + Gemm1::SMEM_OFFSET_V];

@@ -166,7 +171,8 @@ inline __device__ void compute_dq_dk_dv_1xN_one_iter(const Params &params, Prng
     Smem_tile_kt smem_kt(&smem_[Smem_tile_do::BYTES_PER_TILE + Gemm1::Smem_tile_q::BYTES_PER_TILE], tidx);
     // Allocate the global memory tile loader for dO.
-    Gmem_tile_do gmem_do(params.do_ptr, params.o_row_stride_in_elts, params.o_head_stride_in_elts, binfo, tidx, true);
+    Gmem_tile_do gmem_do(params.do_ptr, params.o_row_stride_in_elts, params.o_head_stride_in_elts,
+                         params.d, binfo, tidx, true);
     // Allocate the shared memory tile loader for dO.
     Smem_tile_do smem_do(&smem_[0], tidx);
     Smem_tile_dot smem_dot(&smem_[0], tidx);

@@ -178,7 +184,8 @@ inline __device__ void compute_dq_dk_dv_1xN_one_iter(const Params &params, Prng
     Smem_tile_st smem_dp(&smem_[Smem_tile_do::BYTES_PER_TILE + Gemm1::SMEM_OFFSET_O + Smem_tile_dq::BYTES_PER_TILE + Smem_tile_st::BYTES_PER_TILE], tidx);
     // Allocate the global memory tile loader for O.
-    Gmem_tile_o gmem_o(params.o_ptr, params.o_row_stride_in_elts, params.o_head_stride_in_elts, binfo, tidx, true);
+    Gmem_tile_o gmem_o(params.o_ptr, params.o_row_stride_in_elts, params.o_head_stride_in_elts,
+                       params.d, binfo, tidx, true);
     // Allocate the shared memory tile loader for O. We use the same as K so be careful!!!
     Smem_tile_dq smem_dq(&smem_[Smem_tile_do::BYTES_PER_TILE + Gemm1::SMEM_OFFSET_O], tidx);

@@ -657,7 +664,8 @@ inline __device__ void compute_dq_dk_dv_1xN_one_iter(const Params &params, Prng
     __syncthreads();
     uint4 dv_out[Smem_tile_dv::NUM_LDS];
     smem_dv.load(dv_out);
-    Gmem_tile_dv gmem_dv(params.dv_ptr, params.dv_row_stride_in_elts, params.dv_head_stride_in_elts, binfo, tidx, false);
+    Gmem_tile_dv gmem_dv(params.dv_ptr, params.dv_row_stride_in_elts, params.dv_head_stride_in_elts,
+                         params.d, binfo, tidx, false);
     // using Gmem_tile_dkv_accum = typename Kernel_traits::Gmem_tile_dkv_accum;
     // Gmem_tile_dkv_accum gmem_dv_accum(params.dv_accum_ptr, params.h * params.d, params.d, binfo, tidx, false);
     // static_assert(Gmem_tile_dkv_accum::LDGS == Smem_tile_dv::NUM_LDS);

@@ -674,7 +682,8 @@ inline __device__ void compute_dq_dk_dv_1xN_one_iter(const Params &params, Prng
     uint4 dk_out[Smem_tile_dk::NUM_LDS];
     smem_dk.load(dk_out);
-    Gmem_tile_dk gmem_dk(params.dk_ptr, params.dk_row_stride_in_elts, params.dk_head_stride_in_elts, binfo, tidx, false);
+    Gmem_tile_dk gmem_dk(params.dk_ptr, params.dk_row_stride_in_elts, params.dk_head_stride_in_elts,
+                         params.d, binfo, tidx, false);
     // Gmem_tile_dkv_accum gmem_dk_accum(params.dk_accum_ptr, params.h * params.d, params.d, binfo, tidx, false);
     if( !Is_first ) {
         gmem_dk.move(loop_step_idx);
csrc/flash_attn/src/fmha_fprop_fp16_kernel.sm80.cu

@@ -114,36 +114,23 @@ void run_fmha_fp16_sm80_loop_(Launch_params<FMHA_fprop_params> &launch_params) {
 void run_fmha_fp16_sm80(Launch_params<FMHA_fprop_params> &launch_params) {
     FP16_SWITCH(launch_params.params.is_bf16, [&] {
         auto dprops = at::cuda::getCurrentDeviceProperties();
-        if (launch_params.params.d == 16) {
-            if( launch_params.params.seqlen_k == 128 ) {
-                using Kernel_traits = FMHA_kernel_traits<128, 16, 16, 1, 4, 0x08u, elem_type>;
-                run_fmha_fp16_sm80_loop_<Kernel_traits>(launch_params);
-            } else if( launch_params.params.seqlen_k == 256 ) {
-                using Kernel_traits = FMHA_kernel_traits<256, 16, 16, 1, 4, 0x08u, elem_type>;
-                run_fmha_fp16_sm80_loop_<Kernel_traits>(launch_params);
-            } else {
-                // TD [2022-05-15] 512 gives wrong results rn
-                // using Kernel_traits = FMHA_kernel_traits<512, 16, 16, 1, 4, 0x08u, elem_type>;
-                using Kernel_traits = FMHA_kernel_traits<256, 16, 16, 1, 4, 0x08u, elem_type>;
-                run_fmha_fp16_sm80_loop_<Kernel_traits>(launch_params);
-            }
-        } else if (launch_params.params.d == 32) {
+        if (launch_params.params.d <= 32) {
             if( launch_params.params.seqlen_k == 128 ) {
                 using Kernel_traits = FMHA_kernel_traits<128, 32, 16, 1, 4, 0x08u, elem_type>;
                 run_fmha_fp16_sm80_loop_<Kernel_traits>(launch_params);
             } else if( launch_params.params.seqlen_k >= 256 ) {
                 using Kernel_traits = FMHA_kernel_traits<256, 32, 16, 1, 4, 0x08u, elem_type>;
                 run_fmha_fp16_sm80_loop_<Kernel_traits>(launch_params);
             }
-        } else if (launch_params.params.d == 64) {
+        } else if (launch_params.params.d <= 64) {
             if( launch_params.params.seqlen_k == 128 ) {
                 using Kernel_traits = FMHA_kernel_traits<128, 64, 16, 1, 4, 0x08u, elem_type>;
                 run_fmha_fp16_sm80_loop_<Kernel_traits>(launch_params);
             } else if( launch_params.params.seqlen_k >= 256 ) {
                 using Kernel_traits = FMHA_kernel_traits<256, 64, 16, 1, 4, 0x08u, elem_type>;
                 run_fmha_fp16_sm80_loop_<Kernel_traits>(launch_params);
             }
-        } else if (launch_params.params.d == 128) {
+        } else if (launch_params.params.d <= 128) {
             // TD [2022-10-21]: Previously for SM80 we use block size 256 and keep K in shared memory
             // to reduce register spilling. However, that increases the smem usage from ~41KB to ~105KB,
             // reducing occupancy (only 1 kernel can be scheduled per SM instead of 2). This strategy gives
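The TD [2022-10-21] comment above explains the d=128 block-size choice in terms of shared-memory pressure. The arithmetic can be sanity-checked as below, assuming roughly 164 KB of usable shared memory per SM on A100 (that figure is an assumption, not stated in the commit):

SMEM_PER_SM_KB = 164   # assumed A100 shared-memory carveout per SM

print(SMEM_PER_SM_KB // 105)  # 1: at ~105 KB per CTA, smem alone caps the SM at a single CTA
print(SMEM_PER_SM_KB // 41)   # 4: at ~41 KB per CTA, smem no longer limits occupancy,
                              #    so the 2 CTAs allowed by other resources can be scheduled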
csrc/flash_attn/src/fmha_fprop_kernel_1xN.h

@@ -259,10 +259,13 @@ inline __device__ void device_1xN_(const Params &params, const int bidb, const i
     Gemm1 gemm_q_k(smem_, tidx);
     // Allocate the global memory tile loader for Q.
-    Gmem_tile_q gmem_q(params.q_ptr, params.q_row_stride_in_elts, params.q_head_stride_in_elts, binfo, tidx, true);
+    Gmem_tile_q gmem_q(params.q_ptr, params.q_row_stride_in_elts, params.q_head_stride_in_elts,
+                       params.d, binfo, tidx, true);
     // Allocate the global memory tile loader for O.
-    Gmem_tile_o gmem_o(params.o_ptr, params.o_row_stride_in_elts, params.o_head_stride_in_elts, binfo, tidx);
-    Gmem_tile_o_tmp gmem_o_tmp(params.o_tmp_ptr, params.o_tmp_row_stride_in_elts, params.o_tmp_head_stride_in_elts, binfo, tidx);
+    Gmem_tile_o gmem_o(params.o_ptr, params.o_row_stride_in_elts, params.o_head_stride_in_elts,
+                       params.d, binfo, tidx);
+    Gmem_tile_o_tmp gmem_o_tmp(params.o_tmp_ptr, params.o_tmp_row_stride_in_elts, params.o_tmp_head_stride_in_elts,
+                               params.d, binfo, tidx);
     // Allocate the global memory tile loader for S.
     Gmem_tile_s gmem_s(params, binfo, tidx);
     Gmem_softmax_sum gmem_softmax_lse(params.softmax_lse_ptr, params, tidx);

@@ -293,9 +296,11 @@ inline __device__ void device_1xN_(const Params &params, const int bidb, const i
     fmha::Mask<Cta_tile_p, Is_causal> mask(binfo, tidx, loop_step_idx);
     // Allocate the global memory tile loader for K.
-    Gmem_tile_k gmem_k(params.k_ptr, params.k_row_stride_in_elts, params.k_head_stride_in_elts, binfo, tidx, false);
+    Gmem_tile_k gmem_k(params.k_ptr, params.k_row_stride_in_elts, params.k_head_stride_in_elts,
+                       params.d, binfo, tidx, false);
     // Allocate the global memory tile loader for V.
-    Gmem_tile_v gmem_v(params.v_ptr, params.v_row_stride_in_elts, params.v_head_stride_in_elts, binfo, tidx, false);
+    Gmem_tile_v gmem_v(params.v_ptr, params.v_row_stride_in_elts, params.v_head_stride_in_elts,
+                       params.d, binfo, tidx, false);
     // The base pointer of smem_v;
     char *smem_v_ = &smem_[Gemm1::SMEM_OFFSET_V];
flash_attn/flash_attn_interface.py

@@ -6,8 +6,8 @@ import flash_attn_cuda


 def _get_block_size(device, head_dim, is_dropout):
-    assert head_dim in [16, 32, 64, 128]
-    return 256 if head_dim in [16, 32, 64] else 128
+    assert head_dim % 8 == 0 and head_dim <= 128
+    return 256 if head_dim <= 64 else 128


 def _flash_attn_forward(q, k, v, out, cu_seqlens_q, cu_seqlens_k, max_seqlen_q, max_seqlen_k,
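A quick usage check of the relaxed Python-side contract (the snippet re-states the new _get_block_size body so it runs standalone; in the package the function lives in flash_attn/flash_attn_interface.py):

def _get_block_size(device, head_dim, is_dropout):
    assert head_dim % 8 == 0 and head_dim <= 128
    return 256 if head_dim <= 64 else 128

for d in (8, 40, 64, 80, 128):
    print(d, _get_block_size('cuda', d, False))
# 8 -> 256, 40 -> 256, 64 -> 256, 80 -> 128, 128 -> 128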
tests/test_flash_attn.py

@@ -340,7 +340,7 @@ def get_dropout_fraction(dropout_mask, query_padding_mask=None, key_padding_mask
 # @pytest.mark.parametrize('dtype', [torch.float16])
 @pytest.mark.parametrize('causal', [False, True])
 # @pytest.mark.parametrize('causal', [False])
-@pytest.mark.parametrize('d', [128, 64, 32, 16])
+@pytest.mark.parametrize('d', [128, 64, 80, 40, 32, 16])
 # @pytest.mark.parametrize('d', [64])
 @pytest.mark.parametrize('seqlen', [97, 128, 200, 256, 257, 384, 512, 768, 1024, 1025, 2048])
 # @pytest.mark.parametrize('seqlen', [128])

@@ -362,8 +362,8 @@ def test_flash_attn_unpadded_qkvpacked(seqlen, d, dropout_p, causal, dtype):
     x = torch.randn(batch_size, seqlen, nheads * d, device=device, dtype=dtype, requires_grad=True)
     Wqkv = torch.nn.Linear(nheads * d, 3 * nheads * d, device=device, dtype=dtype)
-    key_padding_mask = generate_random_padding_mask(seqlen, batch_size, device, mode='random')
-    # key_padding_mask = generate_random_padding_mask(seqlen, batch_size, device, mode='full')
+    # key_padding_mask = generate_random_padding_mask(seqlen, batch_size, device, mode='random')
+    key_padding_mask = generate_random_padding_mask(seqlen, batch_size, device, mode='full')
     qkv_unpad, cu_seqlens, max_seqlen, qkv, output_pad_fn, dqkv_pad_fn = generate_qkv(
         x, Wqkv, nheads, key_padding_mask, key_padding_mask, qkvpacked=True

@@ -395,7 +395,7 @@ def test_flash_attn_unpadded_qkvpacked(seqlen, d, dropout_p, causal, dtype):
     print(f'Attention max diff: {(attn - attn_ref).abs().max().item()}')
     print(f'Attention Pytorch max diff: {(attn_pt - attn_ref).abs().max().item()}')
-    if is_sm80 or d < 128:  # Only run backward for d=128 on A100
+    if is_sm80 or d <= 64:  # Only run backward for d=128 on A100
         g = torch.randn_like(output)
         dqkv_unpad, = torch.autograd.grad(output, qkv_unpad, g)
         dqkv = dqkv_pad_fn(dqkv_unpad)

@@ -421,7 +421,7 @@ def test_flash_attn_unpadded_qkvpacked(seqlen, d, dropout_p, causal, dtype):
     else:
         assert 0.98 <= dropout_fraction / dropout_p <= 1.02
-    if is_sm80 or d < 128:  # Only run backward for d=128 on A100
+    if is_sm80 or d <= 64:  # Only run backward for d=128 on A100
         # Error for dK and dV could be a bit higher if we're splitting along seqlen_q dimension
         assert (dqkv - dqkv_ref).abs().max().item() <= 4 * (dqkv_pt - dqkv_ref).abs().max().item()
         # assert torch.allclose(dqkv, dqkv_ref, rtol=rtol, atol=atol)

@@ -430,7 +430,7 @@ def test_flash_attn_unpadded_qkvpacked(seqlen, d, dropout_p, causal, dtype):
 @pytest.mark.parametrize('dtype', ([torch.float16] if is_sm75 else [torch.float16, torch.bfloat16]))
 # @pytest.mark.parametrize('dtype', [torch.float16])
 @pytest.mark.parametrize('causal', [False, True])
-@pytest.mark.parametrize('d', [128, 64, 32, 16])
+@pytest.mark.parametrize('d', [128, 64, 80, 40, 32, 16])
 # @pytest.mark.parametrize('d', [64])
 @pytest.mark.parametrize('seqlen', [97, 128, 200, 256, 257, 384, 512, 768, 1024, 1025, 2048])
 # @pytest.mark.parametrize('seqlen', [128])

@@ -487,7 +487,7 @@ def test_flash_attn_unpadded_kvpacked(seqlen, d, dropout_p, causal, dtype):
     print(f'Attention max diff: {(attn - attn_ref).abs().max().item()}')
     print(f'Attention Pytorch max diff: {(attn_pt - attn_ref).abs().max().item()}')
-    if is_sm80 or d < 128:  # Only run backward for d=128 on A100
+    if is_sm80 or d <= 64:  # Only run backward for d=128 on A100
         g = torch.randn_like(output)
         dq_unpad, dkv_unpad, = torch.autograd.grad(output, (q_unpad, kv_unpad), g)
         dq = dq_pad_fn(dq_unpad)

@@ -512,7 +512,7 @@ def test_flash_attn_unpadded_kvpacked(seqlen, d, dropout_p, causal, dtype):
     else:
         assert 0.99 <= dropout_fraction / dropout_p <= 1.01
-    if is_sm80 or d < 128:  # Only run backward for d=128 on A100
+    if is_sm80 or d <= 64:  # Only run backward for d=128 on A100
         assert (dq - dq_ref).abs().max().item() <= 2 * (dq_pt - dq_ref).abs().max().item()
         assert (dkv - dkv_ref).abs().max().item() <= 2 * (dkv_pt - dkv_ref).abs().max().item()
         # assert torch.allclose(dq, dq_ref, rtol=rtol, atol=atol)

@@ -522,7 +522,7 @@ def test_flash_attn_unpadded_kvpacked(seqlen, d, dropout_p, causal, dtype):
 @pytest.mark.parametrize('dtype', ([torch.float16] if is_sm75 else [torch.float16, torch.bfloat16]))
 # @pytest.mark.parametrize('dtype', [torch.float16])
 @pytest.mark.parametrize('causal', [False, True])
-@pytest.mark.parametrize('d', [128, 64, 32, 16])
+@pytest.mark.parametrize('d', [128, 64, 80, 40, 32, 16])
 # @pytest.mark.parametrize('d', [64])
 @pytest.mark.parametrize('seqlen', [97, 128, 200, 256, 257, 384, 512, 768, 1024, 1025, 2048])
 # @pytest.mark.parametrize('seqlen', [128])

@@ -579,7 +579,7 @@ def test_flash_attn_unpadded(seqlen, d, dropout_p, causal, dtype):
     print(f'Attention max diff: {(attn - attn_ref).abs().max().item()}')
     print(f'Attention Pytorch max diff: {(attn_pt - attn_ref).abs().max().item()}')
-    if is_sm80 or d < 128:  # Only run backward for d=128 on A100
+    if is_sm80 or d <= 64:  # Only run backward for d=128 on A100
         g = torch.randn_like(output)
         dq_unpad, dk_unpad, dv_unpad, = torch.autograd.grad(output, (q_unpad, k_unpad, v_unpad), g)
         dq = dq_pad_fn(dq_unpad)

@@ -605,7 +605,7 @@ def test_flash_attn_unpadded(seqlen, d, dropout_p, causal, dtype):
     else:
         assert 0.99 <= dropout_fraction / dropout_p <= 1.01
-    if is_sm80 or d < 128:  # Only run backward for d=128 on A100
+    if is_sm80 or d <= 64:  # Only run backward for d=128 on A100
         assert (dq - dq_ref).abs().max().item() <= 2 * (dq_pt - dq_ref).abs().max().item()
         assert (dk - dk_ref).abs().max().item() <= 2 * (dk_pt - dk_ref).abs().max().item()
         assert (dv - dv_ref).abs().max().item() <= 2 * (dv_pt - dv_ref).abs().max().item()

@@ -618,7 +618,7 @@ def test_flash_attn_unpadded(seqlen, d, dropout_p, causal, dtype):
 # @pytest.mark.parametrize('dtype', [torch.float16])
 @pytest.mark.parametrize('causal', [False, True])
 # @pytest.mark.parametrize('causal', [False])
-@pytest.mark.parametrize('d', [128, 64, 32, 16])
+@pytest.mark.parametrize('d', [128, 64, 80, 40, 32, 16])
 # @pytest.mark.parametrize('d', [64])
 @pytest.mark.parametrize('seqlen', [512])
 @pytest.mark.parametrize('dropout_p', [0.0, 0.17])

@@ -681,7 +681,7 @@ def test_flash_attn_split(seqlen, d, dropout_p, causal, dtype):
     print(f'Attention max diff: {(attn - attn_ref).abs().max().item()}')
     print(f'Attention Pytorch max diff: {(attn_pt - attn_ref).abs().max().item()}')
-    if is_sm80 or d < 128:  # Only run backward for d=128 on A100
+    if is_sm80 or d <= 64:  # Only run backward for d=128 on A100
         g = torch.randn_like(output)
         dqkv_unpad, = torch.autograd.grad(output, qkv_unpad, g)
         dqkv = dqkv_pad_fn(dqkv_unpad)

@@ -707,7 +707,7 @@ def test_flash_attn_split(seqlen, d, dropout_p, causal, dtype):
     else:
         assert 0.99 <= dropout_fraction / dropout_p <= 1.01
-    if is_sm80 or d < 128:  # Only run backward for d=128 on A100
+    if is_sm80 or d <= 64:  # Only run backward for d=128 on A100
         assert (dqkv - dqkv_ref).abs().max().item() <= 2 * (dqkv_pt - dqkv_ref).abs().max().item()
         # assert torch.allclose(dqkv, dqkv_ref, rtol=rtol, atol=atol)

@@ -715,7 +715,7 @@ def test_flash_attn_split(seqlen, d, dropout_p, causal, dtype):
 @pytest.mark.parametrize('dtype', ([torch.float16] if is_sm75 else [torch.float16, torch.bfloat16]))
 # @pytest.mark.parametrize('dtype', [torch.float16])
 @pytest.mark.parametrize('causal', [False, True])
-@pytest.mark.parametrize('d', [128, 64, 32, 16])
+@pytest.mark.parametrize('d', [128, 64, 80, 40, 32, 16])
 # @pytest.mark.parametrize('d', [64])
 @pytest.mark.parametrize('seqlen', [97, 128, 200, 256, 257, 384, 512, 768, 1024, 1025, 2048])
 # @pytest.mark.parametrize('seqlen', [128])

@@ -749,7 +749,7 @@ def test_flash_attn_race_condition(seqlen, d, dropout_p, causal, dtype):
         S_dmask_0, query_padding_mask, key_padding_mask, d, dropout_p > 0.0, causal=causal
     )
-    if is_sm80 or d < 128:  # Only run backward for d=128 on A100
+    if is_sm80 or d <= 64:  # Only run backward for d=128 on A100
         g = torch.randn_like(output_unpad_0)
         dq_unpad_0, dk_unpad_0, dv_unpad_0, = torch.autograd.grad(output_unpad_0,
                                                                   (q_unpad, k_unpad, v_unpad), g)

@@ -768,7 +768,7 @@ def test_flash_attn_race_condition(seqlen, d, dropout_p, causal, dtype):
     # assert torch.equal(sm_lse, sm_lse_0)
     assert torch.equal(S_dmask_converted, S_dmask_converted_0)
-    if is_sm80 or d < 128:  # Only run backward for d=128 on A100
+    if is_sm80 or d <= 64:  # Only run backward for d=128 on A100
         dq_unpad, dk_unpad, dv_unpad, = torch.autograd.grad(output_unpad,
                                                             (q_unpad, k_unpad, v_unpad), g)
         assert torch.equal(dq_unpad, dq_unpad_0)
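The tests now exercise head dims 40 and 80 as well, and the backward-pass gate used throughout the file changed from "d < 128" to "d <= 64". A small sketch (not part of the test file) of which configurations run the backward under the new condition:

def runs_backward(is_sm80: bool, d: int) -> bool:
    # Mirrors the updated test gate: on A100 every supported head dim runs the
    # backward pass; on other GPUs only head dims up to 64 do.
    return is_sm80 or d <= 64

for d in (16, 32, 40, 64, 80, 128):
    print(d, runs_backward(True, d), runs_backward(False, d))
# d = 80 and d = 128 run the backward only on sm80 (A100).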