gaoqiong / flash-attention, commit 6c3a8c65

Implement cross attention

Authored Jun 30, 2022 by Tri Dao
Parent: 01947bc9

Showing 18 changed files with 815 additions and 486 deletions (+815 / -486)
Changed files:

- README.md (+5, -4)
- benchmarks/benchmark_flash_attention.py (+4, -2)
- csrc/flash_attn/fmha_api.cpp (+424, -273)
- csrc/flash_attn/src/fmha.h (+39, -22)
- csrc/flash_attn/src/fmha/gmem_tile.h (+13, -13)
- csrc/flash_attn/src/fmha/mask.h (+7, -8)
- csrc/flash_attn/src/fmha_block_dgrad_fp16_kernel_loop.sm80.cu (+6, -6)
- csrc/flash_attn/src/fmha_block_dgrad_kernel_1xN_loop.h (+15, -15)
- csrc/flash_attn/src/fmha_block_fprop_fp16_kernel.sm80.cu (+6, -6)
- csrc/flash_attn/src/fmha_block_fprop_kernel_1xN.h (+10, -9)
- csrc/flash_attn/src/fmha_blockmask.h (+1, -1)
- csrc/flash_attn/src/fmha_dgrad_fp16_kernel_loop.sm80.cu (+15, -15)
- csrc/flash_attn/src/fmha_dgrad_kernel_1xN_loop.h (+15, -16)
- csrc/flash_attn/src/fmha_fprop_fp16_kernel.sm80.cu (+23, -23)
- csrc/flash_attn/src/fmha_fprop_kernel_1xN.h (+10, -9)
- csrc/flash_attn/src/fmha_kernel.h (+9, -7)
- flash_attn/flash_attention.py (+16, -12)
- flash_attn/flash_attn_interface.py (+197, -45)
README.md

```diff
@@ -32,10 +32,11 @@ Our tentative roadmap:
 3. [Jun 2022] Refactor to use Cutlass.
 4. ~~[Jun 2022] Support SM75 GPUs (e.g. T4)~~[Done].
 5. [Jun 2022] Support bf16.
-6. ~~[Jul 2022] Support head dimension 128~~[Done].
-7. [Jul 2022] Support SM70 GPUs (V100).
-8. [Aug 2022] Fuse rotary embedding.
-9. [Aug 2022] Support Attention linear bias (e.g. ALiBi).
+6. ~~[Jul 2022] Implement cross-attention~~[Done].
+7. ~~[Jul 2022] Support head dimension 128~~[Done].
+8. [Jul 2022] Support SM70 GPUs (V100).
+9. [Aug 2022] Fuse rotary embedding.
+10. [Aug 2022] Support Attention linear bias (e.g. ALiBi).

 ## Speedup and Memory Savings
```
benchmarks/benchmark_flash_attention.py

```diff
@@ -8,7 +8,7 @@ from einops import rearrange, repeat
 from benchmarks.utils import benchmark_all, benchmark_forward, benchmark_backward, benchmark_combined
 from flash_attn.bert_padding import unpad_input, pad_input
-from flash_attn.flash_attn_interface import flash_attn_func
+from flash_attn.flash_attn_interface import flash_attn_unpadded_qkvpacked_func


 def attention_ref(qkv, attn_mask, dropout_p, upcast=False, causal=False):
@@ -62,7 +62,9 @@ qkv_unpad = rearrange(Wqkv(x_unpad), 'nnz (t h d) -> nnz t h d', t=3,
                       h=nheads).detach().requires_grad_()
 qkv = rearrange(Wqkv(x), 'b s (t h d) -> b s t h d', t=3, h=nheads).detach().requires_grad_()

-fn = lambda qkv_unpad: flash_attn_func(qkv_unpad, cu_seqlens, dropout_p, max_seqlen_in_batch, causal=causal)
+fn = lambda qkv_unpad: flash_attn_unpadded_qkvpacked_func(
+    qkv_unpad, cu_seqlens, max_seqlen_in_batch, dropout_p, causal=causal
+)
 benchmark_all(fn, qkv_unpad, repeats=repeats, desc='FlashAttention')
 fn = lambda qkv: attention_ref(qkv, attention_mask_bool, dropout_p, causal=causal)
 benchmark_all(fn, qkv, repeats=repeats, desc='PyTorch Standard Attention')
```
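For context, a minimal sketch of calling the renamed entry point the benchmark now uses. The function name and positional argument order are taken from the benchmark diff above; the tensor shapes, dtypes, and the expected output shape are assumptions, since the flash_attn_interface.py diff is not shown above.

```python
# Hypothetical usage sketch of the renamed benchmark entry point. Assumes a CUDA
# build of flash_attn from this commit; shapes/dtypes mirror the benchmark above.
import torch
from flash_attn.flash_attn_interface import flash_attn_unpadded_qkvpacked_func

batch_size, seqlen, nheads, headdim = 4, 512, 16, 64
dropout_p, causal = 0.0, False

# Packed, unpadded QKV: (total_tokens, 3, nheads, headdim), fp16 on the GPU.
qkv_unpad = torch.randn(batch_size * seqlen, 3, nheads, headdim,
                        dtype=torch.float16, device='cuda', requires_grad=True)
# Cumulative sequence lengths (int32, length batch_size + 1), as in bert_padding.
cu_seqlens = torch.arange(0, (batch_size + 1) * seqlen, step=seqlen,
                          dtype=torch.int32, device='cuda')

out = flash_attn_unpadded_qkvpacked_func(qkv_unpad, cu_seqlens, seqlen,
                                         dropout_p, causal=causal)
print(out.shape)  # expected: (total_tokens, nheads, headdim)
```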
csrc/flash_attn/fmha_api.cpp

(This diff is collapsed in the commit view and is not expanded here.)
csrc/flash_attn/src/fmha.h

```diff
@@ -42,9 +42,8 @@
 constexpr int TOTAL_DIM = 0;
-constexpr int THREE_DIM = 1;
-constexpr int H_DIM = 2;
-constexpr int D_DIM = 3;
+constexpr int H_DIM = 1;
+constexpr int D_DIM = 2;

 ////////////////////////////////////////////////////////////////////////////////////////////////////
@@ -72,10 +71,7 @@ struct Qkv_params {
 ////////////////////////////////////////////////////////////////////////////////////////////////////

-struct Fused_multihead_attention_fprop_params : public Qkv_params {
+struct FMHA_fprop_params : public Qkv_params {

-    // The dQKV matrices.
-    void * __restrict__ dqkv_ptr;
-
     // The O matrix (output).
     void * __restrict__ o_ptr;
@@ -90,10 +86,7 @@ struct Fused_multihead_attention_fprop_params : public Qkv_params {
     // the loop;
     void * __restrict__ o_tmp_ptr;

-    // The dO matrix .
-    void * __restrict__ do_ptr;
-
-    // The pointer to the S matrix.
+    // The pointer to the S matrix, overwritten by the dP matrix (bwd).
     void * __restrict__ s_ptr;

     // The stride between rows of the S matrix.
     // int64_t s_stride_in_bytes;
@@ -102,18 +95,16 @@ struct Fused_multihead_attention_fprop_params : public Qkv_params {
     // The pointer to the softmax sum.
     void * __restrict__ softmax_lse_ptr;

-    // The pointer to the softmax d sum.
-    void * __restrict__ dsoftmax_sum;
-
     // The dimensions.
-    int b, s, d;
+    int b, seqlen_q, seqlen_k, d, seqlen_q_rounded;

     // The scaling factors for the kernel.
     float scale_bmm1f;
-    uint32_t scale_bmm1, scale_softmax, scale_bmm2;
+    uint32_t scale_bmm1;

     // array of length b+1 holding starting offset of each sequence.
-    int * __restrict__ cu_seqlens;
+    int * __restrict__ cu_seqlens_q;
+    int * __restrict__ cu_seqlens_k;

     int * __restrict__ blockmask;
@@ -136,7 +127,33 @@ struct Fused_multihead_attention_fprop_params : public Qkv_params {
 ////////////////////////////////////////////////////////////////////////////////////////////////////

+struct FMHA_dgrad_params : public FMHA_fprop_params {
+
+    // The dQKV matrices.
+    void * __restrict__ dq_ptr;
+    void * __restrict__ dk_ptr;
+    void * __restrict__ dv_ptr;
+
+    // The stride between rows of the dQ, dK and dV matrices.
+    // TD [2022-04-16]: We're using 32-bit indexing to save registers.
+    // The code probably won't work for arrays larger than 2GB.
+    uint32_t dq_row_stride_in_elts;
+    uint32_t dk_row_stride_in_elts;
+    uint32_t dv_row_stride_in_elts;
+    uint32_t dq_head_stride_in_elts;
+    uint32_t dk_head_stride_in_elts;
+    uint32_t dv_head_stride_in_elts;
+
+    // The dO matrix. We assume it is contiguous.
+    void * __restrict__ do_ptr;
+
+    // The pointer to the softmax d sum.
+    void * __restrict__ dsoftmax_sum;
+};
+
+////////////////////////////////////////////////////////////////////////////////////////////////////
+
 template<typename Kernel_params>
 struct Launch_params{
     Launch_params(cudaDeviceProp * props_,
                   cudaStream_t stream_,
@@ -168,10 +185,10 @@ struct Launch_params{
 ////////////////////////////////////////////////////////////////////////////////////////////////////

-void run_fmha_fp16_sm80(Launch_params<Fused_multihead_attention_fprop_params> &launch_params, const bool configure);
+void run_fmha_fp16_sm80(Launch_params<FMHA_fprop_params> &launch_params, const bool configure);

-void run_fmha_dgrad_fp16_sm80(const Fused_multihead_attention_fprop_params &params, cudaStream_t stream);
+void run_fmha_dgrad_fp16_sm80(const FMHA_dgrad_params &params, cudaStream_t stream);

-void run_fmha_block_fp16_sm80(Launch_params<Fused_multihead_attention_fprop_params> &launch_params, const bool configure);
+void run_fmha_block_fp16_sm80(Launch_params<FMHA_fprop_params> &launch_params, const bool configure);

-void run_fmha_block_dgrad_fp16_sm80(const Fused_multihead_attention_fprop_params &params, cudaStream_t stream);
+void run_fmha_block_dgrad_fp16_sm80(const FMHA_dgrad_params &params, cudaStream_t stream);
```
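One consequence of the new parameter block: the single `cu_seqlens` array becomes `cu_seqlens_q` and `cu_seqlens_k`, so queries and keys/values can come from token streams of different lengths. A hedged host-side sketch of how such prefix-sum arrays are typically built; the actual construction lives in fmha_api.cpp and flash_attn_interface.py, whose diffs are not shown above, so the helper below is illustrative only.

```python
# Illustrative (hypothetical) helper building the two cumulative-length arrays
# that FMHA_fprop_params now expects; the real host code is in the collapsed
# fmha_api.cpp diff and is not reproduced here.
import torch

def build_cu_seqlens(seqlens):
    """Per-sequence token counts -> int32 prefix sums of length len(seqlens) + 1."""
    cu = torch.zeros(len(seqlens) + 1, dtype=torch.int32)
    cu[1:] = torch.cumsum(torch.tensor(seqlens, dtype=torch.int32), dim=0)
    return cu

# Cross-attention: query and key/value lengths are independent per sequence.
cu_seqlens_q = build_cu_seqlens([128, 64, 200])   # [0, 128, 192, 392]
cu_seqlens_k = build_cu_seqlens([512, 512, 384])  # [0, 512, 1024, 1408]
```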
csrc/flash_attn/src/fmha/gmem_tile.h

```diff
@@ -63,9 +63,9 @@ struct Gmem_tile_qkv {
     // Ctor.
     template< typename BInfo >
     inline __device__ Gmem_tile_qkv(void *ptr_, const uint32_t row_stride_in_elts,
-                                    const uint32_t head_stride_in_elts, const BInfo &binfo, const int tidx)
+                                    const uint32_t head_stride_in_elts, const BInfo &binfo, const int tidx, bool use_seqlen_q)
         : row_stride_in_bytes(row_stride_in_elts * BYTES_PER_ELEMENT)
-        , actual_seqlen(binfo.actual_seqlen)
+        , actual_seqlen(use_seqlen_q ? binfo.actual_seqlen_q : binfo.actual_seqlen_k)
         , ptr(reinterpret_cast<char *>(ptr_))
         , tidx_(tidx) {
@@ -80,7 +80,7 @@ struct Gmem_tile_qkv {
         // The row offset in the batched GEMM. For each seq element, we store QKV in that order.
         // int64_t row_offset = (int64_t)row * params.qkv_stride_in_bytes;
-        uint32_t row_offset = (uint32_t)((binfo.sum_s + row) * row_stride_in_bytes);
+        uint32_t row_offset = (uint32_t)(((use_seqlen_q ? binfo.sum_s_q : binfo.sum_s_k) + row) * row_stride_in_bytes);
         // Add the block index.
         // row_offset += (int64_t)((binfo.sum_s * NUM_MATS + qkv_offset) * binfo.h + binfo.bidh) * BYTES_PER_ROW;
         row_offset += (uint32_t)(binfo.bidh * head_stride_in_elts * BYTES_PER_ELEMENT);
@@ -193,7 +193,7 @@ struct Gmem_tile_o {
     inline __device__ Gmem_tile_o(void *ptr, const uint32_t row_stride_in_elts,
                                   const uint32_t head_stride_in_elts, const BInfo &binfo, const int tidx)
         : row_stride_in_bytes(row_stride_in_elts * BYTES_PER_ELEMENT)
-        , actual_seqlen(binfo.actual_seqlen)
+        , actual_seqlen_q(binfo.actual_seqlen_q)
         , ptr_(reinterpret_cast<char *>(ptr))
         , tidx_(tidx) {
@@ -207,7 +207,7 @@ struct Gmem_tile_o {
         // The row offset in the batched GEMM.
         // int64_t row_offset = (int64_t)row * row_stride_in_bytes + binfo.bidx * BYTES_PER_ROW;
-        uint32_t row_offset = (uint32_t)((binfo.sum_s + row) * row_stride_in_bytes);
+        uint32_t row_offset = (uint32_t)((binfo.sum_s_q + row) * row_stride_in_bytes);
         row_offset += (uint32_t)(binfo.bidh * head_stride_in_elts * BYTES_PER_ELEMENT);
         // Assemble the final pointer.
         ptr_ += row_offset + col * BYTES_PER_STG;
@@ -224,7 +224,7 @@ struct Gmem_tile_o {
         #pragma unroll
         for( int ii = 0; ii < STGS_PER_LOOP; ++ii ) {
             int jj = mi * STGS_PER_LOOP + ii;
-            if( row_ + jj * ROWS_PER_STG >= this->actual_seqlen ) {
+            if( row_ + jj * ROWS_PER_STG >= this->actual_seqlen_q ) {
                 break;
             }
@@ -252,7 +252,7 @@ struct Gmem_tile_o {
         #pragma unroll
         for( int ii = 0; ii < STGS_PER_LOOP; ++ii ) {
             int jj = mi * STGS_PER_LOOP + ii;
-            if( row_ + jj * ROWS_PER_STG >= this->actual_seqlen ) {
+            if( row_ + jj * ROWS_PER_STG >= this->actual_seqlen_q ) {
                 break;
             }
@@ -266,7 +266,7 @@ struct Gmem_tile_o {
         // row_ += ROWS * steps;
         // ptr_ += (int64_t)ROWS * row_stride_in_bytes * steps;
         ptr_ += (uint32_t)ROWS * row_stride_in_bytes * steps;
-        actual_seqlen -= ROWS * steps;
+        actual_seqlen_q -= ROWS * steps;
     }

     // The stride between rows for the QKV matrice.
@@ -277,7 +277,7 @@ struct Gmem_tile_o {
     // Is the thread active for the last STG?
     int is_active_for_last_stg_;
     // The length of the sequence loaded by that memory tile.
-    int actual_seqlen;
+    int actual_seqlen_q;
     const int tidx_;
 };
@@ -319,8 +319,8 @@ struct Gmem_tile_mma_sd {
         uint32_t bidx = bidb * params.h + bidh;
         // The distance between two blocks (in bytes).
-        // const size_t block_stride_bytes = params.s * params.s * BYTES_PER_ELEMENT;
-        const uint32_t block_stride_bytes = params.s * params.s * BYTES_PER_ELEMENT;
+        // const size_t block_stride_bytes = params.seqlen_q * params.seqlen_k * BYTES_PER_ELEMENT;
+        const uint32_t block_stride_bytes = params.seqlen_q * params.seqlen_k * BYTES_PER_ELEMENT;
         // Set store location for each thread at the beginning of the loop
         ptr_ += bidx * block_stride_bytes + tidx * BYTES_PER_STG;
     }
@@ -468,8 +468,8 @@ struct Gmem_summary_stats {
         int lane = tidx % Cta_tile::THREADS_PER_WARP;
         // The distance between two blocks (in bytes).
-        // size_t block_stride_bytes = params.s * BYTES_PER_ELEMENT;
-        uint32_t block_stride_bytes = params.s * BYTES_PER_ELEMENT;
+        // size_t block_stride_bytes = params.seqlen_q * BYTES_PER_ELEMENT;
+        uint32_t block_stride_bytes = params.seqlen_q * BYTES_PER_ELEMENT;
         // Set store location for each thread at the beginning of the loop
         ptr_row_ = ptr_ + bidx * block_stride_bytes;
```
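In plain terms, the new `use_seqlen_q` constructor flag selects which side's lengths and prefix sums a tile uses: Q/O (and dO) tiles index with the query-side values, K/V tiles with the key-side ones. A small illustrative sketch of the addressing difference; the names mirror `binfo.sum_s_q` / `binfo.sum_s_k` but this is not the kernel code.

```python
# Illustrative only: where a batch's rows start in the packed (unpadded) Q
# tensor versus the packed K/V tensor, mirroring the sum_s_q / sum_s_k split.
cu_seqlens_q = [0, 128, 192, 392]
cu_seqlens_k = [0, 512, 1024, 1408]

def tile_base_row(batch_idx, use_seqlen_q):
    # Q/O tiles use the query-side prefix sums, K/V tiles the key-side ones.
    return cu_seqlens_q[batch_idx] if use_seqlen_q else cu_seqlens_k[batch_idx]

print(tile_base_row(2, True))   # 192  -> batch 2's first row in packed Q
print(tile_base_row(2, False))  # 1024 -> batch 2's first row in packed K/V
```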
csrc/flash_attn/src/fmha/mask.h

```diff
@@ -35,8 +35,8 @@ struct Mask {
     using Mma_tile = fmha::Hmma_tile<Cta_tile>;

     template<typename BInfo>
-    __device__ Mask(const BInfo &blockInfo, int tidx, const int loop_step_idx_ = 0)
-        : actual_seqlen(blockInfo.actual_seqlen - loop_step_idx_ * Cta_tile::N)
+    __device__ Mask(const BInfo &binfo, int tidx, const int loop_step_idx_ = 0)
+        : actual_seqlen_k(binfo.actual_seqlen_k - loop_step_idx_ * Cta_tile::N)
         , loop_step_idx(loop_step_idx_) {

         const int warp = tidx / Cta_tile::THREADS_PER_WARP;
@@ -60,12 +60,11 @@ struct Mask {
         // const int current_col = (Is_causal ? loop_step_idx * Cta_tile::N : 0) + ni * Mma_tile::N_PER_MMA_PER_CTA + col + (jj & 2) * 4 + (jj & 1);
         const int current_col = ni * Mma_tile::N_PER_MMA_PER_CTA + col + (jj & 2) * 4 + (jj & 1);
         const int current_row = row_offset + ii * 8;
-        const bool col_valid = current_col < actual_seqlen;
-        // const bool col_valid = (ni * Mma_tile::N_PER_MMA_PER_CTA + col + (jj & 2) * 4 + (jj & 1)) < actual_seqlen;
-        //&& (row + mi * Mma_tile::M_PER_MMA_PER_CTA + ii * 8) < actual_seqlen;
-        bool all_valid = Is_causal ? col_valid && (current_col + loop_step_idx * Cta_tile::N <= current_row) : col_valid;
+        const bool col_valid = current_col < actual_seqlen_k;
+        // const bool col_valid = (ni * Mma_tile::N_PER_MMA_PER_CTA + col + (jj & 2) * 4 + (jj & 1)) < actual_seqlen_k;
+        //&& (row + mi * Mma_tile::M_PER_MMA_PER_CTA + ii * 8) < actual_seqlen_k;
         // if ((threadIdx.x == 0) && (blockIdx.x == 0) && (blockIdx.y == 0)) {
-        //     printf("current_col=%d, current_row=%d, actual_seqlen=%d, col_valid=%d, all_valid=%d\n", current_col, current_row, actual_seqlen, col_valid, all_valid);
+        //     printf("current_col=%d, current_row=%d, actual_seqlen_k=%d, col_valid=%d, all_valid=%d\n", current_col, current_row, actual_seqlen_k, col_valid, all_valid);
         // }
         return Is_causal ? col_valid && (current_col + loop_step_idx * Cta_tile::N <= current_row) : col_valid;
         // return row_valid && col_valid;
@@ -84,7 +83,7 @@ struct Mask {
     int row;
     int col;
     const int loop_step_idx;
-    const int actual_seqlen;
+    const int actual_seqlen_k;
 };

 }  // namespace fmha
```
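As a dense reference (not the kernel code), the rule the updated Mask struct encodes is: a (query row, key column) score survives only if the column lies inside the key sequence and, in the causal case, the global key index does not exceed the global query index, which is what `col_valid && (current_col + loop_step_idx * Cta_tile::N <= current_row)` expresses per tile.

```python
# Dense PyTorch reference of the masking rule in Mask::is_valid (illustrative,
# not the kernel code): keep (i, j) if j is a real key position and, for causal
# attention, j <= i. With seqlen_q != seqlen_k this aligns the causal diagonal
# to the start of both sequences, as the per-tile condition above appears to do.
import torch

def attention_valid_mask(seqlen_q, seqlen_k, causal):
    rows = torch.arange(seqlen_q).unsqueeze(1)   # query index i
    cols = torch.arange(seqlen_k).unsqueeze(0)   # key index j
    # In the kernel a tile can run past the end of the key sequence, hence the
    # col_valid bound; in this dense view every j < seqlen_k is in range.
    valid = cols.expand(seqlen_q, seqlen_k) < seqlen_k
    if causal:
        valid = valid & (cols <= rows)
    return valid

print(attention_valid_mask(3, 5, causal=True).int())
# tensor([[1, 0, 0, 0, 0],
#         [1, 1, 0, 0, 0],
#         [1, 1, 1, 0, 0]], dtype=torch.int32)
```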
csrc/flash_attn/src/fmha_block_dgrad_fp16_kernel_loop.sm80.cu

```diff
@@ -5,12 +5,12 @@
 #include "fmha_block_dgrad_kernel_1xN_loop.h"

 template<typename Kernel_traits, bool Is_dropout, bool Is_causal, int loop_steps=-1>
-__global__ void fmha_block_dgrad_fp16_sm80_dq_dk_dv_loop_kernel(Fused_multihead_attention_fprop_params params) {
+__global__ void fmha_block_dgrad_fp16_sm80_dq_dk_dv_loop_kernel(FMHA_dgrad_params params) {
     fmha::compute_block_dq_dk_dv_1xN<Kernel_traits, Is_dropout, Is_causal, loop_steps>(params);
 }

 template<typename Kernel_traits>
-void run_fmha_block_dgrad_fp16_sm80_loop_(const Fused_multihead_attention_fprop_params &params, cudaStream_t stream) {
+void run_fmha_block_dgrad_fp16_sm80_loop_(const FMHA_dgrad_params &params, cudaStream_t stream) {
     constexpr int smem_size_softmax = Kernel_traits::Cta_tile_p::M * Kernel_traits::Cta_tile_p::WARPS_N * sizeof(float);
     constexpr int smem_size_q = Kernel_traits::Smem_tile_q::BYTES_PER_TILE;
     constexpr int smem_size_v = Kernel_traits::Smem_tile_v::BYTES_PER_TILE;
@@ -30,12 +30,12 @@ void run_fmha_block_dgrad_fp16_sm80_loop_(const Fused_multihead_attention_fprop_
     auto kernel = is_dropout
         ? (is_causal ? &fmha_block_dgrad_fp16_sm80_dq_dk_dv_loop_kernel<Kernel_traits, true, true> : &fmha_block_dgrad_fp16_sm80_dq_dk_dv_loop_kernel<Kernel_traits, true, false>)
         : (is_causal ? &fmha_block_dgrad_fp16_sm80_dq_dk_dv_loop_kernel<Kernel_traits, false, true> : &fmha_block_dgrad_fp16_sm80_dq_dk_dv_loop_kernel<Kernel_traits, false, false>);
-    constexpr int N = Kernel_traits::Cta_tile_p::N;
-    if (params.s == N) {
+    constexpr int blocksize_c = Kernel_traits::Cta_tile_p::N;
+    if (params.seqlen_k == blocksize_c) {
         kernel = is_dropout
             ? (is_causal ? &fmha_block_dgrad_fp16_sm80_dq_dk_dv_loop_kernel<Kernel_traits, true, true, /*loop_steps=*/1> : &fmha_block_dgrad_fp16_sm80_dq_dk_dv_loop_kernel<Kernel_traits, true, false, /*loop_steps=*/1>)
             : (is_causal ? &fmha_block_dgrad_fp16_sm80_dq_dk_dv_loop_kernel<Kernel_traits, false, true, /*loop_steps=*/1> : &fmha_block_dgrad_fp16_sm80_dq_dk_dv_loop_kernel<Kernel_traits, false, false, /*loop_steps=*/1>);
-    } else if (params.s == N * 2) {
+    } else if (params.seqlen_k == blocksize_c * 2) {
         kernel = is_dropout
             ? (is_causal ? &fmha_block_dgrad_fp16_sm80_dq_dk_dv_loop_kernel<Kernel_traits, true, true, /*loop_steps=*/2> : &fmha_block_dgrad_fp16_sm80_dq_dk_dv_loop_kernel<Kernel_traits, true, false, /*loop_steps=*/2>)
             : (is_causal ? &fmha_block_dgrad_fp16_sm80_dq_dk_dv_loop_kernel<Kernel_traits, false, true, /*loop_steps=*/2> : &fmha_block_dgrad_fp16_sm80_dq_dk_dv_loop_kernel<Kernel_traits, false, false, /*loop_steps=*/2>);
@@ -50,7 +50,7 @@ void run_fmha_block_dgrad_fp16_sm80_loop_(const Fused_multihead_attention_fprop_
     FMHA_CHECK_CUDA(cudaPeekAtLastError());
 }

-void run_fmha_block_dgrad_fp16_sm80(const Fused_multihead_attention_fprop_params &params, cudaStream_t stream) {
+void run_fmha_block_dgrad_fp16_sm80(const FMHA_dgrad_params &params, cudaStream_t stream) {
     if( params.d == 16 ) {
         using Kernel_traits = FMHA_kernel_traits<256, 16, 16, 1, 8, 0x08u>;
         run_fmha_block_dgrad_fp16_sm80_loop_<Kernel_traits>(params, stream);
```
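Restated in plain Python for clarity: when the key length fits in exactly one or two column tiles, a compile-time loop_steps of 1 or 2 is selected; otherwise the generic looping kernel is used. This is an illustrative restatement of the dispatch above, not the CUDA code itself.

```python
# Plain-Python restatement (illustrative) of the loop_steps specialization the
# launcher performs on the key sequence length.
def choose_loop_steps(seqlen_k, blocksize_c):
    if seqlen_k == blocksize_c:
        return 1        # single outer iteration over K/V
    if seqlen_k == blocksize_c * 2:
        return 2        # exactly two outer iterations
    return -1           # generic kernel: step count computed at run time

for s_k in (256, 512, 1000):
    print(s_k, choose_loop_steps(s_k, 256))  # 256 1 / 512 2 / 1000 -1
```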
csrc/flash_attn/src/fmha_block_dgrad_kernel_1xN_loop.h

```diff
@@ -138,9 +138,9 @@ inline __device__ void compute_block_dq_dk_dv_1xN_one_iter(const Params &params,
     Gemm1 gemm_q_k(&smem_[Smem_tile_do::BYTES_PER_TILE], tidx);
     // Allocate the global memory tile loader for Q.
-    Gmem_tile_q gmem_q(params.q_ptr, params.q_row_stride_in_elts, params.q_head_stride_in_elts, binfo, tidx);
+    Gmem_tile_q gmem_q(params.q_ptr, params.q_row_stride_in_elts, params.q_head_stride_in_elts, binfo, tidx, true);
     // Allocate the global memory tile loader for dQ.
-    Gmem_tile_dq gmem_dq(params.dqkv_ptr, params.q_row_stride_in_elts, params.q_head_stride_in_elts, binfo, tidx);
+    Gmem_tile_dq gmem_dq(params.dq_ptr, params.dq_row_stride_in_elts, params.dq_head_stride_in_elts, binfo, tidx);
     Gmem_tile_dq_tmp gmem_dq_tmp(params.o_tmp_ptr, params.o_row_stride_in_elts, params.o_head_stride_in_elts, binfo, tidx);
     // Allocate the global memory tile loader for S.
     Gmem_tile_s gmem_s(params, binfo, tidx);
@@ -148,9 +148,9 @@ inline __device__ void compute_block_dq_dk_dv_1xN_one_iter(const Params &params,
     fmha::Mask<Cta_tile_p, Is_causal> mask(binfo, tidx, loop_step_idx);
     // Allocate the global memory tile loader for K.
-    Gmem_tile_k gmem_k(params.k_ptr, params.k_row_stride_in_elts, params.k_head_stride_in_elts, binfo, tidx);
+    Gmem_tile_k gmem_k(params.k_ptr, params.k_row_stride_in_elts, params.k_head_stride_in_elts, binfo, tidx, false);
     // Allocate the global memory tile loader for V.
-    Gmem_tile_v gmem_v(params.v_ptr, params.v_row_stride_in_elts, params.v_head_stride_in_elts, binfo, tidx);
+    Gmem_tile_v gmem_v(params.v_ptr, params.v_row_stride_in_elts, params.v_head_stride_in_elts, binfo, tidx, false);
     // The base pointer of smem_v;
     char *smem_v_ = &smem_[Smem_tile_do::BYTES_PER_TILE + Gemm1::SMEM_OFFSET_V];
@@ -160,7 +160,7 @@ inline __device__ void compute_block_dq_dk_dv_1xN_one_iter(const Params &params,
     Smem_tile_kt smem_kt(&smem_[Smem_tile_do::BYTES_PER_TILE + Gemm1::Smem_tile_q::BYTES_PER_TILE], tidx);
     // Allocate the global memory tile loader for dO.
-    Gmem_tile_do gmem_do(params.do_ptr, params.o_row_stride_in_elts, params.o_head_stride_in_elts, binfo, tidx);
+    Gmem_tile_do gmem_do(params.do_ptr, params.o_row_stride_in_elts, params.o_head_stride_in_elts, binfo, tidx, true);
     // Allocate the shared memory tile loader for dO.
     Smem_tile_do smem_do(&smem_[0], tidx);
     Smem_tile_dot smem_dot(&smem_[0], tidx);
@@ -172,7 +172,7 @@ inline __device__ void compute_block_dq_dk_dv_1xN_one_iter(const Params &params,
     Smem_tile_st smem_dp(&smem_[Smem_tile_do::BYTES_PER_TILE + Gemm1::SMEM_OFFSET_O + Smem_tile_dq::BYTES_PER_TILE + Smem_tile_st::BYTES_PER_TILE], tidx);
     // Allocate the global memory tile loader for O.
-    Gmem_tile_o gmem_o(params.o_ptr, params.o_row_stride_in_elts, params.o_head_stride_in_elts, binfo, tidx);
+    Gmem_tile_o gmem_o(params.o_ptr, params.o_row_stride_in_elts, params.o_head_stride_in_elts, binfo, tidx, true);
     // Allocate the shared memory tile loader for O. We use the same as K so be careful!!!
     Smem_tile_dq smem_dq(&smem_[Smem_tile_do::BYTES_PER_TILE + Gemm1::SMEM_OFFSET_O], tidx);
@@ -181,7 +181,7 @@ inline __device__ void compute_block_dq_dk_dv_1xN_one_iter(const Params &params,
     Gmem_softmax_sum gmem_softmax_d(params.dsoftmax_sum, params, tidx);
     static_assert(Cta_tile_p::N % Cta_tile_p::M == 0);
-    const int steps = params.s / Cta_tile_p::M;
+    const int steps = (params.seqlen_q + Cta_tile_p::M - 1) / Cta_tile_p::M;
     // Wind gmem tiles to the correct position.
     int block_row_idx_next = mask_val / 4;
@@ -316,7 +316,7 @@ inline __device__ void compute_block_dq_dk_dv_1xN_one_iter(const Params &params,
         // if ((threadIdx.x == 0) && (blockIdx.x == 0) && (blockIdx.y == 0)) {
         //     printf("block_row_idx = %d\n", block_row_idx);
         // }
-        if (block_row_idx * Cta_tile_p::M >= binfo.actual_seqlen) break;
+        if (block_row_idx * Cta_tile_p::M >= binfo.actual_seqlen_q) break;
         int mask_val_next = l < steps - 1 ? blockmask.mask_val(l + 1) : -1;
         // if ((threadIdx.x == 0) && (blockIdx.x == 0) && (blockIdx.y == 0)) {
@@ -629,7 +629,7 @@ inline __device__ void compute_block_dq_dk_dv_1xN_one_iter(const Params &params,
         const bool is_final_write =
             Is_last
-            || ((loop_step_idx + 1) * Cta_tile_p::N >= binfo.actual_seqlen)
+            || ((loop_step_idx + 1) * Cta_tile_p::N >= binfo.actual_seqlen_k)
             || ((mask_val & 0x2) != 0)
             || ((Is_causal) && (block_row_idx * Cta_tile_p::M < (loop_step_idx + 1) * Cta_tile_p::N));
         if (is_final_write) {
@@ -702,7 +702,7 @@ inline __device__ void compute_block_dq_dk_dv_1xN_one_iter(const Params &params,
     __syncthreads();
     uint4 dv_out[Smem_tile_dv::NUM_LDS];
     smem_dv.load(dv_out);
-    Gmem_tile_dv gmem_dv(params.dqkv_ptr + 2 * params.h * params.d * 2, params.v_row_stride_in_elts, params.v_head_stride_in_elts, binfo, tidx);
+    Gmem_tile_dv gmem_dv(params.dv_ptr, params.dv_row_stride_in_elts, params.dv_head_stride_in_elts, binfo, tidx, false);
     if (!Is_first) {
         gmem_dv.move(loop_step_idx);
     }
@@ -713,7 +713,7 @@ inline __device__ void compute_block_dq_dk_dv_1xN_one_iter(const Params &params,
     // for (int ii = 0; ii < Smem_tile_dk::NUM_LDS; ++ii) {
     //     dk_out[ii] = fmha::fmul4(dk_out[ii], params.scale_bmm1f);
     // }
-    Gmem_tile_dk gmem_dk(params.dqkv_ptr + params.h * params.d * 2, params.k_row_stride_in_elts, params.k_head_stride_in_elts, binfo, tidx);
+    Gmem_tile_dk gmem_dk(params.dk_ptr, params.dk_row_stride_in_elts, params.dk_head_stride_in_elts, binfo, tidx, false);
     if (!Is_first) {
         gmem_dk.move(loop_step_idx);
     }
@@ -722,11 +722,11 @@ inline __device__ void compute_block_dq_dk_dv_1xN_one_iter(const Params &params,
 ////////////////////////////////////////////////////////////////////////////////////////////////////
-// loop_steps = -1 means the number of steps will be params.s / Kernel_traits::Cta_tile_p::N.
+// loop_steps = -1 means the number of steps will be params.seqlen_k / Kernel_traits::Cta_tile_p::N.
 // This template parameter is there so we can specialize with loop_steps == 1 and loop_steps == 2.
 template<typename Kernel_traits, bool Is_dropout, bool Is_causal, int loop_steps=-1, typename Params>
 inline __device__ void compute_block_dq_dk_dv_1xN(const Params &params) {
-    constexpr int N_per_loop = Kernel_traits::Cta_tile_p::N;
+    constexpr int blocksize_c = Kernel_traits::Cta_tile_p::N;
     // The block index for the batch.
     const int bidb = blockIdx.x;
@@ -745,10 +745,10 @@ inline __device__ void compute_block_dq_dk_dv_1xN(const Params &params) {
         compute_block_dq_dk_dv_1xN_one_iter<Kernel_traits, Is_dropout, Is_causal, true, false>(params, ph, 0);
         compute_block_dq_dk_dv_1xN_one_iter<Kernel_traits, Is_dropout, Is_causal, false, true>(params, ph, 1);
     } else {
-        if (params.s == N_per_loop) {
+        if (params.seqlen_k == blocksize_c) {
             compute_block_dq_dk_dv_1xN_one_iter<Kernel_traits, Is_dropout, Is_causal, true, true>(params, ph, 0);
         } else {
-            const int max_loop_steps = (params.s + N_per_loop - 1) / N_per_loop;
+            const int max_loop_steps = (params.seqlen_k + blocksize_c - 1) / blocksize_c;
             compute_block_dq_dk_dv_1xN_one_iter<Kernel_traits, Is_dropout, Is_causal, true, false>(params, ph, 0);
             for (int loop_step_idx = 1; loop_step_idx < max_loop_steps - 1; loop_step_idx++) {
                 compute_block_dq_dk_dv_1xN_one_iter<Kernel_traits, Is_dropout, Is_causal, false, false>(params, ph, loop_step_idx);
```
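Both trip counts are now derived from the two sequence lengths with a ceiling division, so a partially filled final tile is still processed. A short worked example of the arithmetic, with illustrative tile sizes:

```python
# Worked example of the new trip-count arithmetic: the outer loop tiles the key
# sequence in chunks of blocksize_c (Cta_tile_p::N), the inner steps tile the
# query sequence in chunks of M (Cta_tile_p::M), and both round up so a partial
# final tile is still visited.
def ceil_div(a, b):
    return (a + b - 1) // b

M, blocksize_c = 16, 256   # illustrative values; the real ones come from Cta_tile_p
seqlen_q, seqlen_k = 200, 1000

steps = ceil_div(seqlen_q, M)                      # 13 query tiles (12 * 16 + 8 rows)
max_loop_steps = ceil_div(seqlen_k, blocksize_c)   # 4 key tiles (3 * 256 + 232 cols)
print(steps, max_loop_steps)  # 13 4
```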
csrc/flash_attn/src/fmha_block_fprop_fp16_kernel.sm80.cu

```diff
@@ -29,12 +29,12 @@
 #include "fmha_block_fprop_kernel_1xN.h"

 template<typename Kernel_traits, bool Is_dropout, bool Is_causal, bool Return_softmax>
-__global__ void fmha_block_fprop_fp16_sm80_loop_kernel(Fused_multihead_attention_fprop_params params) {
+__global__ void fmha_block_fprop_fp16_sm80_loop_kernel(FMHA_fprop_params params) {
     fmha::device_block_1xN_loop<Kernel_traits, Is_dropout, Is_causal, Return_softmax>(params);
 }

 template<typename Kernel_traits>
-void run_fmha_block_fp16_sm80_loop_(Launch_params<Fused_multihead_attention_fprop_params> &launch_params,
+void run_fmha_block_fp16_sm80_loop_(Launch_params<FMHA_fprop_params> &launch_params,
                                     const bool configure) {
     bool is_causal = launch_params.params.is_causal;
     // TD [2022-04-27]: This case work is pretty ugly, maybe there's a better way?
@@ -46,8 +46,8 @@ void run_fmha_block_fp16_sm80_loop_(Launch_params<Fused_multihead_attention_fpro
         ? (launch_params.return_softmax ? &fmha_block_fprop_fp16_sm80_loop_kernel<Kernel_traits, false, true, true> : &fmha_block_fprop_fp16_sm80_loop_kernel<Kernel_traits, false, true, false>)
         : (launch_params.return_softmax ? &fmha_block_fprop_fp16_sm80_loop_kernel<Kernel_traits, false, false, true> : &fmha_block_fprop_fp16_sm80_loop_kernel<Kernel_traits, false, false, false>));
-    constexpr int N = Kernel_traits::Cta_tile_p::N;
-    const int loop_steps = (launch_params.params.s + N - 1) / N;
+    constexpr int blocksize_c = Kernel_traits::Cta_tile_p::N;
+    const int loop_steps = (launch_params.params.seqlen_k + blocksize_c - 1) / blocksize_c;
     constexpr int smem_size_softmax_lse = Kernel_traits::Smem_dp_sum::BYTES_PER_TILE;
     // Don't need smem_size_softmax_lse if we're not looping
     const int smem_size = fmha::get_dynamic_smem_size<Kernel_traits>()
@@ -60,7 +60,7 @@ void run_fmha_block_fp16_sm80_loop_(Launch_params<Fused_multihead_attention_fpro
     if( configure ) {
         using Mma_tile_p = fmha::Hmma_tile<typename Kernel_traits::Cta_tile_p>;
         constexpr int M = Kernel_traits::Cta_tile_p::M;
-        size_t STEPS = (launch_params.params.s + M - 1) / M;
+        size_t STEPS = (launch_params.params.seqlen_q + M - 1) / M;
         constexpr size_t MMAS_M = Mma_tile_p::MMAS_M;
         constexpr size_t MMAS_N = Mma_tile_p::MMAS_N;
         size_t elts_per_head = STEPS * MMAS_M * MMAS_N * 8 * loop_steps;
@@ -75,7 +75,7 @@ void run_fmha_block_fp16_sm80_loop_(Launch_params<Fused_multihead_attention_fpro
     FMHA_CHECK_CUDA(cudaPeekAtLastError());
 }

-void run_fmha_block_fp16_sm80(Launch_params<Fused_multihead_attention_fprop_params> &launch_params,
+void run_fmha_block_fp16_sm80(Launch_params<FMHA_fprop_params> &launch_params,
                               const bool configure) {
     if( launch_params.params.d == 16 ) {
         using Kernel_traits = FMHA_kernel_traits<256, 16, 16, 1, 4, 0x08u>;
```
csrc/flash_attn/src/fmha_block_fprop_kernel_1xN.h

```diff
@@ -97,7 +97,7 @@ inline __device__ void device_block_1xN_(const Params &params, const int bidb, c
     Gemm1 gemm_q_k(smem_, tidx);
     // Allocate the global memory tile loader for Q.
-    Gmem_tile_q gmem_q(params.q_ptr, params.q_row_stride_in_elts, params.q_head_stride_in_elts, binfo, tidx);
+    Gmem_tile_q gmem_q(params.q_ptr, params.q_row_stride_in_elts, params.q_head_stride_in_elts, binfo, tidx, true);
     // Allocate the global memory tile loader for O.
     Gmem_tile_o gmem_o(params.o_ptr, params.o_row_stride_in_elts, params.o_head_stride_in_elts, binfo, tidx);
     Gmem_tile_o_tmp gmem_o_tmp(params.o_tmp_ptr, params.o_row_stride_in_elts, params.o_head_stride_in_elts, binfo, tidx);
@@ -122,9 +122,9 @@ inline __device__ void device_block_1xN_(const Params &params, const int bidb, c
     fmha::Mask<Cta_tile_p, Is_causal> mask(binfo, tidx, loop_step_idx);
     // Allocate the global memory tile loader for K.
-    Gmem_tile_k gmem_k(params.k_ptr, params.k_row_stride_in_elts, params.k_head_stride_in_elts, binfo, tidx);
+    Gmem_tile_k gmem_k(params.k_ptr, params.k_row_stride_in_elts, params.k_head_stride_in_elts, binfo, tidx, false);
     // Allocate the global memory tile loader for V.
-    Gmem_tile_v gmem_v(params.v_ptr, params.v_row_stride_in_elts, params.v_head_stride_in_elts, binfo, tidx);
+    Gmem_tile_v gmem_v(params.v_ptr, params.v_row_stride_in_elts, params.v_head_stride_in_elts, binfo, tidx, false);
     // The base pointer of smem_v;
     char *smem_v_ = &smem_[Gemm1::SMEM_OFFSET_V];
@@ -206,7 +206,7 @@ inline __device__ void device_block_1xN_(const Params &params, const int bidb, c
         // if ((threadIdx.x == 0) && (blockIdx.x == 0) && (blockIdx.y == 0)) {
         //     printf("block_row_idx = %d\n", block_row_idx);
         // }
-        if (block_row_idx * Cta_tile_p::M >= binfo.actual_seqlen) break;
+        if (block_row_idx * Cta_tile_p::M >= binfo.actual_seqlen_q) break;
         int mask_val_next = l < steps - 1 ? blockmask.mask_val(l + 1) : -1;
         // if ((threadIdx.x == 0) && (blockIdx.x == 0) && (blockIdx.y == 0)) {
@@ -443,7 +443,7 @@ inline __device__ void device_block_1xN_(const Params &params, const int bidb, c
         const bool is_final_write =
             Is_last
-            || ((loop_step_idx + 1) * Cta_tile_p::N >= binfo.actual_seqlen)
+            || ((loop_step_idx + 1) * Cta_tile_p::N >= binfo.actual_seqlen_k)
             || ((mask_val & 0x2) != 0)
             || ((Is_causal) && (block_row_idx * Cta_tile_p::M < (loop_step_idx + 1) * Cta_tile_p::N));
         // if ((threadIdx.x == 0) && (blockIdx.x == 0) && (blockIdx.y == 0)) {
@@ -507,13 +507,14 @@ inline __device__ void device_block_1xN_loop(const Params &params) {
     auto seeds = at::cuda::philox::unpack(params.philox_args);
     Philox ph0(std::get<0>(seeds), tidx_global, std::get<1>(seeds));
     Philox ph1(std::get<0>(seeds), tidx_global + blockDim.x, std::get<1>(seeds));
-    const int STEPS = params.s / Kernel_traits::Cta_tile_p::M;
+    constexpr int M = Kernel_traits::Cta_tile_p::M;
+    const int STEPS = (params.seqlen_q + M - 1) / M;

-    constexpr int N_per_loop = Kernel_traits::Cta_tile_p::N;
-    if (params.s == N_per_loop) {
+    constexpr int blocksize_c = Kernel_traits::Cta_tile_p::N;
+    if (params.seqlen_k == blocksize_c) {
         fmha::device_block_1xN_<Kernel_traits, Is_dropout, Is_causal, Return_softmax, true, true>(params, bidb, bidh, STEPS, ph0, ph1, 0);
     } else {
-        const int max_loop_steps = (params.s + N_per_loop - 1) / N_per_loop;
+        const int max_loop_steps = (params.seqlen_k + blocksize_c - 1) / blocksize_c;
         fmha::device_block_1xN_<Kernel_traits, Is_dropout, Is_causal, Return_softmax, true, false>(params, bidb, bidh, STEPS, ph0, ph1, 0);
         for (int loop_step_idx = 1; loop_step_idx < max_loop_steps - 1; loop_step_idx++) {
             fmha::device_block_1xN_<Kernel_traits, Is_dropout, Is_causal, Return_softmax, false, false>(params, bidb, bidh, STEPS, ph0, ph1, loop_step_idx);
```
csrc/flash_attn/src/fmha_blockmask.h

```diff
@@ -42,7 +42,7 @@ struct Blockmask {
     template<typename Params>
     __device__ Blockmask(const Params &params, int loop_step_idx) :
-        blockmask_ptr(params.blockmask + loop_step_idx * params.s / 16) {
+        blockmask_ptr(params.blockmask + loop_step_idx * params.seqlen_q / 16) {
     }

     __device__ int mask_val(int block_row_idx) const {
```
csrc/flash_attn/src/fmha_dgrad_fp16_kernel_loop.sm80.cu
View file @
6c3a8c65
...
@@ -5,12 +5,12 @@
...
@@ -5,12 +5,12 @@
#include "fmha_dgrad_kernel_1xN_loop.h"
#include "fmha_dgrad_kernel_1xN_loop.h"
template
<
typename
Kernel_traits
,
bool
Is_dropout
,
bool
Is_causal
,
int
loop_steps
=-
1
>
template
<
typename
Kernel_traits
,
bool
Is_dropout
,
bool
Is_causal
,
int
loop_steps
=-
1
>
__global__
void
fmha_dgrad_fp16_sm80_dq_dk_dv_loop_kernel
(
F
used_multihead_attention_fprop
_params
params
)
{
__global__
void
fmha_dgrad_fp16_sm80_dq_dk_dv_loop_kernel
(
F
MHA_dgrad
_params
params
)
{
fmha
::
compute_dq_dk_dv_1xN
<
Kernel_traits
,
Is_dropout
,
Is_causal
,
loop_steps
>
(
params
);
fmha
::
compute_dq_dk_dv_1xN
<
Kernel_traits
,
Is_dropout
,
Is_causal
,
loop_steps
>
(
params
);
}
}
template
<
typename
Kernel_traits
>
template
<
typename
Kernel_traits
>
void
run_fmha_dgrad_fp16_sm80_loop_
(
const
F
used_multihead_attention_fprop
_params
&
params
,
cudaStream_t
stream
)
{
void
run_fmha_dgrad_fp16_sm80_loop_
(
const
F
MHA_dgrad
_params
&
params
,
cudaStream_t
stream
)
{
constexpr
int
smem_size_softmax
=
Kernel_traits
::
Cta_tile_p
::
M
*
Kernel_traits
::
Cta_tile_p
::
WARPS_N
*
sizeof
(
float
);
constexpr
int
smem_size_softmax
=
Kernel_traits
::
Cta_tile_p
::
M
*
Kernel_traits
::
Cta_tile_p
::
WARPS_N
*
sizeof
(
float
);
constexpr
int
smem_size_q
=
Kernel_traits
::
Smem_tile_q
::
BYTES_PER_TILE
;
constexpr
int
smem_size_q
=
Kernel_traits
::
Smem_tile_q
::
BYTES_PER_TILE
;
constexpr
int
smem_size_v
=
Kernel_traits
::
Smem_tile_v
::
BYTES_PER_TILE
;
constexpr
int
smem_size_v
=
Kernel_traits
::
Smem_tile_v
::
BYTES_PER_TILE
;
...
@@ -28,18 +28,18 @@ void run_fmha_dgrad_fp16_sm80_loop_(const Fused_multihead_attention_fprop_params
...
@@ -28,18 +28,18 @@ void run_fmha_dgrad_fp16_sm80_loop_(const Fused_multihead_attention_fprop_params
auto
kernel
=
is_dropout
auto
kernel
=
is_dropout
?
(
is_causal
?
&
fmha_dgrad_fp16_sm80_dq_dk_dv_loop_kernel
<
Kernel_traits
,
true
,
true
>
:
&
fmha_dgrad_fp16_sm80_dq_dk_dv_loop_kernel
<
Kernel_traits
,
true
,
false
>
)
?
(
is_causal
?
&
fmha_dgrad_fp16_sm80_dq_dk_dv_loop_kernel
<
Kernel_traits
,
true
,
true
>
:
&
fmha_dgrad_fp16_sm80_dq_dk_dv_loop_kernel
<
Kernel_traits
,
true
,
false
>
)
:
(
is_causal
?
&
fmha_dgrad_fp16_sm80_dq_dk_dv_loop_kernel
<
Kernel_traits
,
false
,
true
>
:
&
fmha_dgrad_fp16_sm80_dq_dk_dv_loop_kernel
<
Kernel_traits
,
false
,
false
>
);
:
(
is_causal
?
&
fmha_dgrad_fp16_sm80_dq_dk_dv_loop_kernel
<
Kernel_traits
,
false
,
true
>
:
&
fmha_dgrad_fp16_sm80_dq_dk_dv_loop_kernel
<
Kernel_traits
,
false
,
false
>
);
constexpr
int
N
=
Kernel_traits
::
Cta_tile_p
::
N
;
constexpr
int
blocksize_c
=
Kernel_traits
::
Cta_tile_p
::
N
;
if
(
params
.
s
==
N
)
{
if
(
params
.
s
eqlen_k
==
blocksize_c
)
{
kernel
=
is_dropout
kernel
=
is_dropout
?
(
is_causal
?
&
fmha_dgrad_fp16_sm80_dq_dk_dv_loop_kernel
<
Kernel_traits
,
true
,
true
,
/*loop_steps=*/
1
>
:
&
fmha_dgrad_fp16_sm80_dq_dk_dv_loop_kernel
<
Kernel_traits
,
true
,
false
,
/*loop_steps=*/
1
>
)
?
(
is_causal
?
&
fmha_dgrad_fp16_sm80_dq_dk_dv_loop_kernel
<
Kernel_traits
,
true
,
true
,
/*loop_steps=*/
1
>
:
&
fmha_dgrad_fp16_sm80_dq_dk_dv_loop_kernel
<
Kernel_traits
,
true
,
false
,
/*loop_steps=*/
1
>
)
:
(
is_causal
?
&
fmha_dgrad_fp16_sm80_dq_dk_dv_loop_kernel
<
Kernel_traits
,
false
,
true
,
/*loop_steps=*/
1
>
:
&
fmha_dgrad_fp16_sm80_dq_dk_dv_loop_kernel
<
Kernel_traits
,
false
,
false
,
/*loop_steps=*/
1
>
);
:
(
is_causal
?
&
fmha_dgrad_fp16_sm80_dq_dk_dv_loop_kernel
<
Kernel_traits
,
false
,
true
,
/*loop_steps=*/
1
>
:
&
fmha_dgrad_fp16_sm80_dq_dk_dv_loop_kernel
<
Kernel_traits
,
false
,
false
,
/*loop_steps=*/
1
>
);
}
else
if
(
params
.
s
==
N
*
2
)
{
}
else
if
(
params
.
s
eqlen_k
==
blocksize_c
*
2
)
{
kernel
=
is_dropout
kernel
=
is_dropout
?
(
is_causal
?
&
fmha_dgrad_fp16_sm80_dq_dk_dv_loop_kernel
<
Kernel_traits
,
true
,
true
,
/*loop_steps=*/
2
>
:
&
fmha_dgrad_fp16_sm80_dq_dk_dv_loop_kernel
<
Kernel_traits
,
true
,
false
,
/*loop_steps=*/
2
>
)
?
(
is_causal
?
&
fmha_dgrad_fp16_sm80_dq_dk_dv_loop_kernel
<
Kernel_traits
,
true
,
true
,
/*loop_steps=*/
2
>
:
&
fmha_dgrad_fp16_sm80_dq_dk_dv_loop_kernel
<
Kernel_traits
,
true
,
false
,
/*loop_steps=*/
2
>
)
:
(
is_causal
?
&
fmha_dgrad_fp16_sm80_dq_dk_dv_loop_kernel
<
Kernel_traits
,
false
,
true
,
/*loop_steps=*/
2
>
:
&
fmha_dgrad_fp16_sm80_dq_dk_dv_loop_kernel
<
Kernel_traits
,
false
,
false
,
/*loop_steps=*/
2
>
);
:
(
is_causal
?
&
fmha_dgrad_fp16_sm80_dq_dk_dv_loop_kernel
<
Kernel_traits
,
false
,
true
,
/*loop_steps=*/
2
>
:
&
fmha_dgrad_fp16_sm80_dq_dk_dv_loop_kernel
<
Kernel_traits
,
false
,
false
,
/*loop_steps=*/
2
>
);
}
}
// printf("
N
= %d, WARPS_N = %d, Smem size = %d\n",
N
, Kernel_traits::Cta_tile_p::WARPS_N, smem_size_dq_dk_dv);
// printf("
blocksize_c
= %d, WARPS_N = %d, Smem size = %d\n",
blocksize_c
, Kernel_traits::Cta_tile_p::WARPS_N, smem_size_dq_dk_dv);
if
(
smem_size_dq_dk_dv
>=
48
*
1024
)
{
if
(
smem_size_dq_dk_dv
>=
48
*
1024
)
{
FMHA_CHECK_CUDA
(
cudaFuncSetAttribute
(
FMHA_CHECK_CUDA
(
cudaFuncSetAttribute
(
kernel
,
cudaFuncAttributeMaxDynamicSharedMemorySize
,
smem_size_dq_dk_dv
));
kernel
,
cudaFuncAttributeMaxDynamicSharedMemorySize
,
smem_size_dq_dk_dv
));
...
@@ -49,12 +49,12 @@ void run_fmha_dgrad_fp16_sm80_loop_(const Fused_multihead_attention_fprop_params
...
@@ -49,12 +49,12 @@ void run_fmha_dgrad_fp16_sm80_loop_(const Fused_multihead_attention_fprop_params
FMHA_CHECK_CUDA
(
cudaPeekAtLastError
());
FMHA_CHECK_CUDA
(
cudaPeekAtLastError
());
}
}
-void run_fmha_dgrad_fp16_sm80(const Fused_multihead_attention_fprop_params &params, cudaStream_t stream) {
+void run_fmha_dgrad_fp16_sm80(const FMHA_dgrad_params &params, cudaStream_t stream) {
     if( params.d == 16 ) {
-        if( params.s == 128 ) {
+        if( params.seqlen_k == 128 ) {
             using Kernel_traits = FMHA_kernel_traits<128, 16, 16, 1, 8, 0x08u>;
             run_fmha_dgrad_fp16_sm80_loop_<Kernel_traits>(params, stream);
-        } else if( params.s == 256 ) {
+        } else if( params.seqlen_k == 256 ) {
             using Kernel_traits = FMHA_kernel_traits<256, 16, 16, 1, 8, 0x08u>;
             run_fmha_dgrad_fp16_sm80_loop_<Kernel_traits>(params, stream);
         } else {
...
@@ -64,18 +64,18 @@ void run_fmha_dgrad_fp16_sm80(const Fused_multihead_attention_fprop_params &params
             run_fmha_dgrad_fp16_sm80_loop_<Kernel_traits>(params, stream);
         }
     } else if( params.d == 32 ) {
-        if( params.s == 128 ) {
+        if( params.seqlen_k == 128 ) {
             using Kernel_traits = FMHA_kernel_traits<128, 32, 16, 1, 8, 0x08u>;
             run_fmha_dgrad_fp16_sm80_loop_<Kernel_traits>(params, stream);
-        } else if( params.s >= 256 ) {
+        } else if( params.seqlen_k >= 256 ) {
             using Kernel_traits = FMHA_kernel_traits<256, 32, 16, 1, 8, 0x08u>;
             run_fmha_dgrad_fp16_sm80_loop_<Kernel_traits>(params, stream);
         }
     } else if( params.d == 64 ) {
-        if( params.s == 128 ) {
+        if( params.seqlen_k == 128 ) {
             using Kernel_traits = FMHA_kernel_traits<128, 64, 16, 1, 8, 0x08u>;
             run_fmha_dgrad_fp16_sm80_loop_<Kernel_traits>(params, stream);
-        } else if( params.s >= 256 ) {
+        } else if( params.seqlen_k >= 256 ) {
             auto dprops = at::cuda::getCurrentDeviceProperties();
             if (dprops->major == 8 && dprops->minor == 0) {
                 // Don't share smem for K & V, and don't keep V in registers
...
@@ -102,10 +102,10 @@ void run_fmha_dgrad_fp16_sm80(const Fused_multihead_attention_fprop_params &params
     //     using Kernel_traits = FMHA_kernel_traits<128, 64, 16, 1, 8, 0x08u>;
     //     run_fmha_dgrad_fp16_sm80_loop_<Kernel_traits>(params, stream);
     // } else {
-    //     if( params.s == 128 ) {
+    //     if( params.seqlen_k == 128 ) {
     //         using Kernel_traits = FMHA_kernel_traits<128, 64, 16, 1, 8, 0x08u>;
     //         run_fmha_dgrad_fp16_sm80_loop_<Kernel_traits>(params, stream);
-    //     } else if( params.s >= 256 ) {
+    //     } else if( params.seqlen_k >= 256 ) {
    //         if (dprops->major == 8 && dprops->minor == 0) {
    //             // Don't share smem for K & V, and don't keep V in registers
    //             // This speeds things up by 2-3% by avoiding register spills, but it
...
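
The launcher above only changes which field drives the dispatch (params.seqlen_k instead of params.s) and the name of the block size (blocksize_c). A minimal Python sketch of the same selection rule, for illustration only (select_loop_steps is a hypothetical helper, not part of the repo):

```python
# Hypothetical mirror of the C++ launcher logic above: the backward kernel is specialized
# on loop_steps when the key sequence fits in exactly one or two blocks of K/V columns.
def select_loop_steps(seqlen_k: int, blocksize_c: int) -> int:
    """Return the compile-time loop_steps specialization the launcher would pick."""
    if seqlen_k == blocksize_c:
        return 1          # whole K/V fits in a single block of columns
    if seqlen_k == blocksize_c * 2:
        return 2          # exactly two blocks of columns
    return -1             # generic path: loop ceil(seqlen_k / blocksize_c) times at runtime

assert select_loop_steps(128, 128) == 1
assert select_loop_steps(256, 128) == 2
assert select_loop_steps(512, 128) == -1
```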
csrc/flash_attn/src/fmha_dgrad_kernel_1xN_loop.h  View file @ 6c3a8c65

@@ -131,9 +131,9 @@ inline __device__ void compute_dq_dk_dv_1xN_one_iter(const Params &params, Prng
     Gemm1 gemm_q_k(&smem_[Smem_tile_do::BYTES_PER_TILE], tidx);
     // Allocate the global memory tile loader for Q.
-    Gmem_tile_q gmem_q(params.q_ptr, params.q_row_stride_in_elts, params.q_head_stride_in_elts, binfo, tidx);
+    Gmem_tile_q gmem_q(params.q_ptr, params.q_row_stride_in_elts, params.q_head_stride_in_elts, binfo, tidx, true);
     // Allocate the global memory tile loader for dQ.
-    Gmem_tile_dq gmem_dq(params.dqkv_ptr, params.q_row_stride_in_elts, params.q_head_stride_in_elts, binfo, tidx);
+    Gmem_tile_dq gmem_dq(params.dq_ptr, params.dq_row_stride_in_elts, params.dq_head_stride_in_elts, binfo, tidx);
     Gmem_tile_dq_tmp gmem_dq_tmp(params.o_tmp_ptr, params.o_row_stride_in_elts, params.o_head_stride_in_elts, binfo, tidx);
     // Allocate the global memory tile loader for S.
     Gmem_tile_s gmem_s(params, binfo, tidx);
...
@@ -141,9 +141,9 @@ inline __device__ void compute_dq_dk_dv_1xN_one_iter(const Params &params, Prng
     fmha::Mask<Cta_tile_p, Is_causal> mask(binfo, tidx, loop_step_idx);
     // Allocate the global memory tile loader for K.
-    Gmem_tile_k gmem_k(params.k_ptr, params.k_row_stride_in_elts, params.k_head_stride_in_elts, binfo, tidx);
+    Gmem_tile_k gmem_k(params.k_ptr, params.k_row_stride_in_elts, params.k_head_stride_in_elts, binfo, tidx, false);
     // Allocate the global memory tile loader for V.
-    Gmem_tile_v gmem_v(params.v_ptr, params.v_row_stride_in_elts, params.v_head_stride_in_elts, binfo, tidx);
+    Gmem_tile_v gmem_v(params.v_ptr, params.v_row_stride_in_elts, params.v_head_stride_in_elts, binfo, tidx, false);
     // The base pointer of smem_v;
     char *smem_v_ = &smem_[Smem_tile_do::BYTES_PER_TILE + Gemm1::SMEM_OFFSET_V];
...
@@ -153,7 +153,7 @@ inline __device__ void compute_dq_dk_dv_1xN_one_iter(const Params &params, Prng
     Smem_tile_kt smem_kt(&smem_[Smem_tile_do::BYTES_PER_TILE + Gemm1::Smem_tile_q::BYTES_PER_TILE], tidx);
     // Allocate the global memory tile loader for dO.
-    Gmem_tile_do gmem_do(params.do_ptr, params.o_row_stride_in_elts, params.o_head_stride_in_elts, binfo, tidx);
+    Gmem_tile_do gmem_do(params.do_ptr, params.o_row_stride_in_elts, params.o_head_stride_in_elts, binfo, tidx, true);
     // Allocate the shared memory tile loader for dO.
     Smem_tile_do smem_do(&smem_[0], tidx);
     Smem_tile_dot smem_dot(&smem_[0], tidx);
...
@@ -165,7 +165,7 @@ inline __device__ void compute_dq_dk_dv_1xN_one_iter(const Params &params, Prng
     Smem_tile_st smem_dp(&smem_[Smem_tile_do::BYTES_PER_TILE + Gemm1::SMEM_OFFSET_O + Smem_tile_dq::BYTES_PER_TILE + Smem_tile_st::BYTES_PER_TILE], tidx);
     // Allocate the global memory tile loader for O.
-    Gmem_tile_o gmem_o(params.o_ptr, params.o_row_stride_in_elts, params.o_head_stride_in_elts, binfo, tidx);
+    Gmem_tile_o gmem_o(params.o_ptr, params.o_row_stride_in_elts, params.o_head_stride_in_elts, binfo, tidx, true);
     // Allocate the shared memory tile loader for O. We use the same as K so be careful!!!
     Smem_tile_dq smem_dq(&smem_[Smem_tile_do::BYTES_PER_TILE + Gemm1::SMEM_OFFSET_O], tidx);
...
@@ -175,8 +175,7 @@ inline __device__ void compute_dq_dk_dv_1xN_one_iter(const Params &params, Prng
     static_assert(Cta_tile_p::N % Cta_tile_p::M == 0);
     const int begin = Is_causal ? loop_step_idx * Cta_tile_p::N / Cta_tile_p::M : 0;
-    // constexpr int steps = Cta_tile_p::N / Cta_tile_p::M;
-    const int steps = params.s / Cta_tile_p::M - begin;
+    const int steps = (params.seqlen_q + Cta_tile_p::M - 1) / Cta_tile_p::M - begin;
     // Wind gmem tiles to the correct position.
     gmem_q.move(begin);
...
@@ -294,7 +293,7 @@ inline __device__ void compute_dq_dk_dv_1xN_one_iter(const Params &params, Prng
     // Load over the entire sequence length.
     for( int l = 0; l < steps; l++ ) {
         const int loop = (begin + l) * Cta_tile_p::M;
-        if( loop >= binfo.actual_seqlen )
+        if( loop >= binfo.actual_seqlen_q )
             break;
         // Load the fragments for V.
...
@@ -584,7 +583,7 @@ inline __device__ void compute_dq_dk_dv_1xN_one_iter(const Params &params, Prng
         const bool is_final_write =
             Is_last
-            || ((loop_step_idx + 1) * Cta_tile_p::N >= binfo.actual_seqlen)
+            || ((loop_step_idx + 1) * Cta_tile_p::N >= binfo.actual_seqlen_k)
             || ((Is_causal) && ((begin + l) * Cta_tile_p::M < (loop_step_idx + 1) * Cta_tile_p::N));
         if (is_final_write) {
             // if (Is_dropout) {
...
@@ -656,7 +655,7 @@ inline __device__ void compute_dq_dk_dv_1xN_one_iter(const Params &params, Prng
     __syncthreads();
     uint4 dv_out[Smem_tile_dv::NUM_LDS];
     smem_dv.load(dv_out);
-    Gmem_tile_dv gmem_dv(params.dqkv_ptr + 2 * params.h * params.d * 2, params.v_row_stride_in_elts, params.v_head_stride_in_elts, binfo, tidx);
+    Gmem_tile_dv gmem_dv(params.dv_ptr, params.dv_row_stride_in_elts, params.dv_head_stride_in_elts, binfo, tidx, false);
     if( !Is_first ) {
         gmem_dv.move(loop_step_idx);
     }
...
@@ -667,7 +666,7 @@ inline __device__ void compute_dq_dk_dv_1xN_one_iter(const Params &params, Prng
     // for (int ii = 0; ii < Smem_tile_dk::NUM_LDS; ++ii) {
     //     dk_out[ii] = fmha::fmul4(dk_out[ii], params.scale_bmm1f);
     // }
-    Gmem_tile_dk gmem_dk(params.dqkv_ptr + params.h * params.d * 2, params.k_row_stride_in_elts, params.k_head_stride_in_elts, binfo, tidx);
+    Gmem_tile_dk gmem_dk(params.dk_ptr, params.dk_row_stride_in_elts, params.dk_head_stride_in_elts, binfo, tidx, false);
     if( !Is_first ) {
         gmem_dk.move(loop_step_idx);
     }
...
@@ -676,11 +675,11 @@ inline __device__ void compute_dq_dk_dv_1xN_one_iter(const Params &params, Prng
 ////////////////////////////////////////////////////////////////////////////////////////////////////

-// loop_steps = -1 means the number of steps will be params.s / Kernel_traits::Cta_tile_p::N.
+// loop_steps = -1 means the number of steps will be params.seqlen_k / Kernel_traits::Cta_tile_p::N.
 // This template parameter is there so we can specialize with loop_steps == 1 and loop_steps == 2.
 template<typename Kernel_traits, bool Is_dropout, bool Is_causal, int loop_steps=-1, typename Params>
 inline __device__ void compute_dq_dk_dv_1xN(const Params &params) {
-    constexpr int N_per_loop = Kernel_traits::Cta_tile_p::N;
+    constexpr int blocksize_c = Kernel_traits::Cta_tile_p::N;
     // The block index for the batch.
     const int bidb = blockIdx.x;
...
@@ -699,10 +698,10 @@ inline __device__ void compute_dq_dk_dv_1xN(const Params &params) {
         compute_dq_dk_dv_1xN_one_iter<Kernel_traits, Is_dropout, Is_causal, true, false>(params, ph, 0);
         compute_dq_dk_dv_1xN_one_iter<Kernel_traits, Is_dropout, Is_causal, false, true>(params, ph, 1);
     } else {
-        if (params.s == N_per_loop) {
+        if (params.seqlen_k == blocksize_c) {
             compute_dq_dk_dv_1xN_one_iter<Kernel_traits, Is_dropout, Is_causal, true, true>(params, ph, 0);
         } else {
-            const int max_loop_steps = (params.s + N_per_loop - 1) / N_per_loop;
+            const int max_loop_steps = (params.seqlen_k + blocksize_c - 1) / blocksize_c;
             compute_dq_dk_dv_1xN_one_iter<Kernel_traits, Is_dropout, Is_causal, true, false>(params, ph, 0);
             for (int loop_step_idx = 1; loop_step_idx < max_loop_steps - 1; loop_step_idx++) {
                 compute_dq_dk_dv_1xN_one_iter<Kernel_traits, Is_dropout, Is_causal, false, false>(params, ph, loop_step_idx);
...
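
The gradient kernel now writes dQ, dK and dV through separate pointers and bounds its loops with actual_seqlen_q and actual_seqlen_k, which is what makes cross attention (different query and key lengths) possible. A plain PyTorch reference, not the CUDA kernel, showing the gradient shapes this implies:

```python
# Minimal reference for cross-attention gradients: dQ has one row per query token,
# dK/dV have one row per key token.
import math
import torch

seqlen_q, seqlen_k, nheads, d = 5, 7, 2, 16
q = torch.randn(seqlen_q, nheads, d, requires_grad=True)
k = torch.randn(seqlen_k, nheads, d, requires_grad=True)
v = torch.randn(seqlen_k, nheads, d, requires_grad=True)

scores = torch.einsum('qhd,khd->hqk', q, k) / math.sqrt(d)   # (nheads, seqlen_q, seqlen_k)
out = torch.einsum('hqk,khd->qhd', scores.softmax(dim=-1), v)
out.sum().backward()

assert q.grad.shape == (seqlen_q, nheads, d)   # dQ follows seqlen_q
assert k.grad.shape == (seqlen_k, nheads, d)   # dK follows seqlen_k
assert v.grad.shape == (seqlen_k, nheads, d)   # dV follows seqlen_k
```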
csrc/flash_attn/src/fmha_fprop_fp16_kernel.sm80.cu  View file @ 6c3a8c65

@@ -29,12 +29,12 @@
 #include "fmha_fprop_kernel_1xN.h"

 template<typename Kernel_traits, bool Is_dropout, bool Is_causal, bool Return_softmax>
-__global__ void fmha_fprop_fp16_sm80_loop_kernel(Fused_multihead_attention_fprop_params params) {
+__global__ void fmha_fprop_fp16_sm80_loop_kernel(FMHA_fprop_params params) {
     fmha::device_1xN_loop<Kernel_traits, Is_dropout, Is_causal, Return_softmax>(params);
 }

 template<typename Kernel_traits>
-void run_fmha_fp16_sm80_loop_(Launch_params<Fused_multihead_attention_fprop_params> &launch_params,
+void run_fmha_fp16_sm80_loop_(Launch_params<FMHA_fprop_params> &launch_params,
                               const bool configure) {
     bool is_causal = launch_params.params.is_causal;
     // TD [2022-04-27]: This case work is pretty ugly, maybe there's a better way?
...
@@ -46,8 +46,8 @@ void run_fmha_fp16_sm80_loop_(Launch_params<Fused_multihead_attention_fprop_params> &launch_params,
         ? (launch_params.return_softmax ? &fmha_fprop_fp16_sm80_loop_kernel<Kernel_traits, false, true, true> : &fmha_fprop_fp16_sm80_loop_kernel<Kernel_traits, false, true, false>)
         : (launch_params.return_softmax ? &fmha_fprop_fp16_sm80_loop_kernel<Kernel_traits, false, false, true> : &fmha_fprop_fp16_sm80_loop_kernel<Kernel_traits, false, false, false>));
-    constexpr int N = Kernel_traits::Cta_tile_p::N;
-    const int loop_steps = (launch_params.params.s + N - 1) / N;
+    constexpr int blocksize_c = Kernel_traits::Cta_tile_p::N;
+    const int loop_steps = (launch_params.params.seqlen_k + blocksize_c - 1) / blocksize_c;
     constexpr int smem_size_softmax_lse = Kernel_traits::Smem_dp_sum::BYTES_PER_TILE;
     // Don't need smem_size_softmax_lse if we're not looping
     const int smem_size = fmha::get_dynamic_smem_size<Kernel_traits>()
...
@@ -60,7 +60,7 @@ void run_fmha_fp16_sm80_loop_(Launch_params<Fused_multihead_attention_fprop_params> &launch_params,
     if( configure ) {
         using Mma_tile_p = fmha::Hmma_tile<typename Kernel_traits::Cta_tile_p>;
         constexpr int M = Kernel_traits::Cta_tile_p::M;
-        size_t STEPS = (launch_params.params.s + M - 1) / M;
+        size_t STEPS = (launch_params.params.seqlen_q + M - 1) / M;
         constexpr size_t MMAS_M = Mma_tile_p::MMAS_M;
         constexpr size_t MMAS_N = Mma_tile_p::MMAS_N;
         size_t elts_per_head = STEPS * MMAS_M * MMAS_N * 8 * loop_steps;
...
@@ -75,13 +75,13 @@ void run_fmha_fp16_sm80_loop_(Launch_params<Fused_multihead_attention_fprop_params> &launch_params,
     FMHA_CHECK_CUDA(cudaPeekAtLastError());
 }

-void run_fmha_fp16_sm80(Launch_params<Fused_multihead_attention_fprop_params> &launch_params,
+void run_fmha_fp16_sm80(Launch_params<FMHA_fprop_params> &launch_params,
                         const bool configure) {
     if (launch_params.params.d == 16) {
-        if( launch_params.params.s == 128 ) {
+        if( launch_params.params.seqlen_k == 128 ) {
             using Kernel_traits = FMHA_kernel_traits<128, 16, 16, 1, 4, 0x08u>;
             run_fmha_fp16_sm80_loop_<Kernel_traits>(launch_params, configure);
-        } else if( launch_params.params.s == 256 ) {
+        } else if( launch_params.params.seqlen_k == 256 ) {
             using Kernel_traits = FMHA_kernel_traits<256, 16, 16, 1, 4, 0x08u>;
             run_fmha_fp16_sm80_loop_<Kernel_traits>(launch_params, configure);
         } else {
...
@@ -91,10 +91,10 @@ void run_fmha_fp16_sm80(Launch_params<Fused_multihead_attention_fprop_params> &launch_params,
             run_fmha_fp16_sm80_loop_<Kernel_traits>(launch_params, configure);
         }
     } else if (launch_params.params.d == 32) {
-        if( launch_params.params.s == 128 ) {
+        if( launch_params.params.seqlen_k == 128 ) {
             using Kernel_traits = FMHA_kernel_traits<128, 32, 16, 1, 4, 0x08u>;
             run_fmha_fp16_sm80_loop_<Kernel_traits>(launch_params, configure);
-        } else if( launch_params.params.s == 256 ) {
+        } else if( launch_params.params.seqlen_k == 256 ) {
             using Kernel_traits = FMHA_kernel_traits<256, 32, 16, 1, 4, 0x08u>;
             run_fmha_fp16_sm80_loop_<Kernel_traits>(launch_params, configure);
         } else {
...
@@ -102,10 +102,10 @@ void run_fmha_fp16_sm80(Launch_params<Fused_multihead_attention_fprop_params> &launch_params,
             run_fmha_fp16_sm80_loop_<Kernel_traits>(launch_params, configure);
         }
     } else if (launch_params.params.d == 64) {
-        if( launch_params.params.s == 128 ) {
+        if( launch_params.params.seqlen_k == 128 ) {
             using Kernel_traits = FMHA_kernel_traits<128, 64, 16, 1, 4, 0x08u>;
             run_fmha_fp16_sm80_loop_<Kernel_traits>(launch_params, configure);
-        } else if( launch_params.params.s >= 256 ) {
+        } else if( launch_params.params.seqlen_k >= 256 ) {
             auto dprops = at::cuda::getCurrentDeviceProperties();
             if (dprops->major == 8 && dprops->minor >= 0) {
                 using Kernel_traits = FMHA_kernel_traits<256, 64, 16, 1, 4, 0x08u>;
...
@@ -121,7 +121,7 @@ void run_fmha_fp16_sm80(Launch_params<Fused_multihead_attention_fprop_params> &launch_params,
             }
         }
     } else if (launch_params.params.d == 128) {
-        if( launch_params.params.s == 128 ) {
+        if( launch_params.params.seqlen_k == 128 ) {
             using Kernel_traits = FMHA_kernel_traits<128, 128, 16, 1, 4, 0x08u>;
             run_fmha_fp16_sm80_loop_<Kernel_traits>(launch_params, configure);
         } else {
...
@@ -145,27 +145,27 @@ void run_fmha_fp16_sm80(Launch_params<Fused_multihead_attention_fprop_params> &launch_params,
     //         run_fmha_fp16_sm80_loop_<Kernel_traits>(launch_params, configure);
     //     }
     // if (launch_params.params.d == 64) {
-    //     if( launch_params.params.s == 128 ) {
+    //     if( launch_params.params.seqlen_k == 128 ) {
     //         using Kernel_traits = FMHA_kernel_traits<128, 64, 16, 1, 4, 0x08u>;
     //         run_fmha_fp16_sm80_loop_<Kernel_traits>(launch_params, configure);
-    //     } else if( launch_params.params.s >= 256 ) {
+    //     } else if( launch_params.params.seqlen_k >= 256 ) {
     //         auto dprops = at::cuda::getCurrentDeviceProperties();
     //         if (dprops->major == 8 && dprops->minor >= 0) {
     //             using Kernel_traits = FMHA_kernel_traits<256, 64, 16, 1, 4, 0x08u>;
     //             run_fmha_fp16_sm80_loop_<Kernel_traits>(launch_params, configure);
     //         } else if (dprops->major == 7 && dprops->minor == 5) {
     //             if (launch_params.is_dropout) { // Need to use the same block size as backward
     //                 using Kernel_traits = FMHA_kernel_traits<128, 64, 16, 1, 4, 0x08u>;
     //                 run_fmha_fp16_sm80_loop_<Kernel_traits>(launch_params, configure);
     //             } else {
     //                 using Kernel_traits = FMHA_kernel_traits<256, 64, 16, 1, 4, 0x08u>;
     //                 run_fmha_fp16_sm80_loop_<Kernel_traits>(launch_params, configure);
     //             }
     //         }
     //     }
     // }
     // if (launch_params.params.d == 128) {
-    //     if( launch_params.params.s == 128 ) {
+    //     if( launch_params.params.seqlen_k == 128 ) {
     //         using Kernel_traits = FMHA_kernel_traits<128, 128, 16, 1, 4, 0x08u>;
     //         run_fmha_fp16_sm80_loop_<Kernel_traits>(launch_params, configure);
     //     } else {
...
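
After this change the forward launcher does two independent ceil divisions: output rows are stepped by the query length, while the number of K/V blocks comes from the key length. Illustrative arithmetic only; the helper names below are made up:

```python
# Two independent ceil divisions after the cross-attention change.
def num_row_steps(seqlen_q: int, M: int) -> int:
    """How many row tiles of size M cover the query sequence."""
    return (seqlen_q + M - 1) // M

def num_kv_blocks(seqlen_k: int, blocksize_c: int) -> int:
    """How many column blocks of size blocksize_c cover the key sequence (loop_steps)."""
    return (seqlen_k + blocksize_c - 1) // blocksize_c

# Example: 16 rows per step, 256 key columns per block.
assert num_row_steps(seqlen_q=200, M=16) == 13
assert num_kv_blocks(seqlen_k=500, blocksize_c=256) == 2
```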
csrc/flash_attn/src/fmha_fprop_kernel_1xN.h  View file @ 6c3a8c65

@@ -247,7 +247,7 @@ inline __device__ void device_1xN_(const Params &params, const int bidb, const i
     Gemm1 gemm_q_k(smem_, tidx);
     // Allocate the global memory tile loader for Q.
-    Gmem_tile_q gmem_q(params.q_ptr, params.q_row_stride_in_elts, params.q_head_stride_in_elts, binfo, tidx);
+    Gmem_tile_q gmem_q(params.q_ptr, params.q_row_stride_in_elts, params.q_head_stride_in_elts, binfo, tidx, true);
     // Allocate the global memory tile loader for O.
     Gmem_tile_o gmem_o(params.o_ptr, params.o_row_stride_in_elts, params.o_head_stride_in_elts, binfo, tidx);
     Gmem_tile_o_tmp gmem_o_tmp(params.o_tmp_ptr, params.o_row_stride_in_elts, params.o_head_stride_in_elts, binfo, tidx);
...
@@ -273,9 +273,9 @@ inline __device__ void device_1xN_(const Params &params, const int bidb, const i
     fmha::Mask<Cta_tile_p, Is_causal> mask(binfo, tidx, loop_step_idx);
     // Allocate the global memory tile loader for K.
-    Gmem_tile_k gmem_k(params.k_ptr, params.k_row_stride_in_elts, params.k_head_stride_in_elts, binfo, tidx);
+    Gmem_tile_k gmem_k(params.k_ptr, params.k_row_stride_in_elts, params.k_head_stride_in_elts, binfo, tidx, false);
     // Allocate the global memory tile loader for V.
-    Gmem_tile_v gmem_v(params.v_ptr, params.v_row_stride_in_elts, params.v_head_stride_in_elts, binfo, tidx);
+    Gmem_tile_v gmem_v(params.v_ptr, params.v_row_stride_in_elts, params.v_head_stride_in_elts, binfo, tidx, false);
     // The base pointer of smem_v;
     char *smem_v_ = &smem_[Gemm1::SMEM_OFFSET_V];
...
@@ -354,7 +354,7 @@ inline __device__ void device_1xN_(const Params &params, const int bidb, const i
     // Load over the entire sequence length.
     for( int l = 0; l < steps; l++ ) {
-        if( (begin + l) * Cta_tile_p::M >= binfo.actual_seqlen ) break;
+        if( (begin + l) * Cta_tile_p::M >= binfo.actual_seqlen_q ) break;
         // Declare the accumulators for the 1st gemm.
         fmha::Fragment_accumulator acc_p[Mma_tile_p::MMAS_M][Mma_tile_p::MMAS_N];
...
@@ -575,7 +575,7 @@ inline __device__ void device_1xN_(const Params &params, const int bidb, const i
         const bool is_final_write =
             Is_last
-            || ((loop_step_idx + 1) * Cta_tile_p::N >= binfo.actual_seqlen)
+            || ((loop_step_idx + 1) * Cta_tile_p::N >= binfo.actual_seqlen_k)
             || ((Is_causal) && ((begin + l) * Cta_tile_p::M < (loop_step_idx + 1) * Cta_tile_p::N));
         #pragma unroll
         for (int jj = 0; jj < Gmem_tile_o::STGS_PER_LOOP; jj++) {
...
@@ -631,13 +631,14 @@ inline __device__ void device_1xN_loop(const Params &params) {
     auto seeds = at::cuda::philox::unpack(params.philox_args);
     Philox ph0(std::get<0>(seeds), tidx_global, std::get<1>(seeds));
     Philox ph1(std::get<0>(seeds), tidx_global + blockDim.x, std::get<1>(seeds));
-    const int STEPS = params.s / Kernel_traits::Cta_tile_p::M;
+    constexpr int M = Kernel_traits::Cta_tile_p::M;
+    const int STEPS = (params.seqlen_q + M - 1) / M;

-    constexpr int N_per_loop = Kernel_traits::Cta_tile_p::N;
-    if (params.s == N_per_loop) {
+    constexpr int blocksize_c = Kernel_traits::Cta_tile_p::N;
+    if (params.seqlen_k == blocksize_c) {
         fmha::device_1xN_<Kernel_traits, Is_dropout, Is_causal, Return_softmax, true, true>(params, bidb, bidh, 0, STEPS, ph0, ph1, 0);
     } else {
-        const int max_loop_steps = (params.s + N_per_loop - 1) / N_per_loop;
+        const int max_loop_steps = (params.seqlen_k + blocksize_c - 1) / blocksize_c;
         fmha::device_1xN_<Kernel_traits, Is_dropout, Is_causal, Return_softmax, true, false>(params, bidb, bidh, 0, STEPS, ph0, ph1, 0);
         for (int loop_step_idx = 1; loop_step_idx < max_loop_steps - 1; loop_step_idx++) {
             fmha::device_1xN_<Kernel_traits, Is_dropout, Is_causal, Return_softmax, false, false>(params, bidb, bidh, 0, STEPS, ph0, ph1, loop_step_idx);
...
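
For reference, the is_final_write predicate above, transcribed into Python with the renamed field. This is only a sketch of the condition; the real check runs in device code, and row_block_start stands for (begin + l) * Cta_tile_p::M:

```python
# Sketch of the "is this the last K/V block that touches this row block?" test.
def is_final_write(is_last: bool, loop_step_idx: int, N: int, actual_seqlen_k: int,
                   is_causal: bool, row_block_start: int) -> bool:
    return (is_last
            or (loop_step_idx + 1) * N >= actual_seqlen_k          # past the end of the keys
            or (is_causal and row_block_start < (loop_step_idx + 1) * N))  # causal: later blocks masked out

# Second K/V block of 256 columns already covers a 500-token key sequence.
assert is_final_write(False, loop_step_idx=1, N=256, actual_seqlen_k=500,
                      is_causal=False, row_block_start=0)
```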
csrc/flash_attn/src/fmha_kernel.h  View file @ 6c3a8c65

@@ -51,20 +51,22 @@ struct BlockInfoPadded {
         : bidb(bidb), bidh(bidh), h(params.h) {
         // The block index.
-        sum_s = params.cu_seqlens[bidb];
-        actual_seqlen = params.cu_seqlens[bidb + 1] - sum_s;
-        bidx = sum_s * params.h + bidh;
+        sum_s_k = params.cu_seqlens_k[bidb];
+        actual_seqlen_k = params.cu_seqlens_k[bidb + 1] - sum_s_k;
+        sum_s_q = params.cu_seqlens_q[bidb];
+        actual_seqlen_q = params.cu_seqlens_q[bidb + 1] - sum_s_q;

         tidx_global = (bidb * params.h + bidh) * THREADS_PER_CTA + tidx;
     }

     __device__ bool stop_early(const int start_col = 0) const {
-        return actual_seqlen <= start_col;
+        return actual_seqlen_k <= start_col;
     }

-    int actual_seqlen;
-    int bidx;
-    int sum_s;
+    int actual_seqlen_q;
+    int actual_seqlen_k;
+    int sum_s_q;
+    int sum_s_k;
     int bidh;
     int bidb;
     int tidx_global;
...
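
BlockInfoPadded now reads two cumulative-length arrays instead of one. A small sketch of how the per-sequence offsets and lengths fall out of cu_seqlens_q / cu_seqlens_k (the tensor values here are made up):

```python
# How the per-block offsets in BlockInfoPadded are derived from the two cumulative
# sequence-length arrays; the struct does this per batch index bidb on the device.
import torch

cu_seqlens_q = torch.tensor([0, 3, 8, 10], dtype=torch.int32)   # 3 sequences, query lengths 3, 5, 2
cu_seqlens_k = torch.tensor([0, 4, 4, 9], dtype=torch.int32)    # key lengths 4, 0, 5

bidb = 2
sum_s_q = int(cu_seqlens_q[bidb])                        # start offset of this sequence's queries
actual_seqlen_q = int(cu_seqlens_q[bidb + 1]) - sum_s_q
sum_s_k = int(cu_seqlens_k[bidb])                        # start offset of this sequence's keys
actual_seqlen_k = int(cu_seqlens_k[bidb + 1]) - sum_s_k

assert (sum_s_q, actual_seqlen_q) == (8, 2)
assert (sum_s_k, actual_seqlen_k) == (4, 5)
```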
flash_attn/flash_attention.py  View file @ 6c3a8c65

@@ -5,7 +5,7 @@ import torch.nn as nn
 from einops import rearrange

 from flash_attn.rotary import RotaryEmbedding, RotaryEmbedding2D
-from flash_attn.flash_attn_interface import flash_attn_func
+from flash_attn.flash_attn_interface import flash_attn_unpadded_qkvpacked_func
 from flash_attn.bert_padding import unpad_input, pad_input, index_first_axis
...
@@ -13,15 +13,15 @@ class FlashAttention(nn.Module):
     """Implement the scaled dot product attention with softmax.
     Arguments
     ---------
-        softmax_temp: The temperature to use for the softmax attention.
+        softmax_scale: The temperature to use for the softmax attention.
                       (default: 1/sqrt(d_keys) where d_keys is computed at
                       runtime)
         attention_dropout: The dropout rate to apply to the attention
                            (default: 0.1)
     """
-    def __init__(self, softmax_temp=None, attention_dropout=0.0, device=None, dtype=None):
+    def __init__(self, softmax_scale=None, attention_dropout=0.0, device=None, dtype=None):
         super().__init__()
-        self.softmax_temp = softmax_temp
+        self.softmax_scale = softmax_scale
         self.dropout_p = attention_dropout

     def forward(self, qkv, attn_mask=None, key_padding_mask=None, causal=False, cu_seqlens=None,
...
@@ -49,8 +49,10 @@ class FlashAttention(nn.Module):
                 max_s = seqlen
                 cu_seqlens = torch.arange(0, (batch_size + 1) * seqlen, step=seqlen, dtype=torch.int32,
                                           device=qkv.device)
-                output = flash_attn_func(qkv, cu_seqlens, self.dropout_p if self.training else 0.0,
-                                         max_s, softmax_scale=self.softmax_temp, causal=causal)
+                output = flash_attn_unpadded_qkvpacked_func(
+                    qkv, cu_seqlens, max_s, self.dropout_p if self.training else 0.0,
+                    softmax_scale=self.softmax_scale, causal=causal
+                )
                 output = rearrange(output, '(b s) ... -> b s ...', b=batch_size)
             else:
                 key_padding_mask_bool = key_padding_mask.bool_matrix
...
@@ -58,17 +60,19 @@ class FlashAttention(nn.Module):
                 x = rearrange(qkv, 'b s three h d -> b s (three h d)')
                 x_unpad, indices, cu_seqlens, max_s = unpad_input(x, key_padding_mask_bool)
                 x_unpad = rearrange(x_unpad, 'nnz (three h d) -> nnz three h d', three=3, h=nheads)
-                output_unpad = flash_attn_func(x_unpad, cu_seqlens,
-                                               self.dropout_p if self.training else 0.0,
-                                               max_s, softmax_scale=self.softmax_temp, causal=causal)
+                output_unpad = flash_attn_unpadded_qkvpacked_func(
+                    x_unpad, cu_seqlens, max_s, self.dropout_p if self.training else 0.0,
+                    softmax_scale=self.softmax_scale, causal=causal
+                )
                 output = rearrange(pad_input(rearrange(output_unpad, 'nnz h d -> nnz (h d)'),
                                              indices, batch_size, seqlen),
                                    'b s (h d) -> b s h d', h=nheads)
         else:
             assert max_s is not None
-            output = flash_attn_func(qkv, cu_seqlens, self.dropout_p if self.training else 0.0,
-                                     max_s, softmax_scale=self.softmax_temp, causal=causal)
+            output = flash_attn_unpadded_qkvpacked_func(
+                qkv, cu_seqlens, max_s, self.dropout_p if self.training else 0.0,
+                softmax_scale=self.softmax_scale, causal=causal
+            )

         return output, None
...
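
Usage sketch for the renamed constructor argument (softmax_temp becomes softmax_scale). This assumes a CUDA build of flash_attn and fp16 inputs; the shapes are illustrative:

```python
# The module-level API after the rename; forward still returns (output, attention_weights=None).
import torch
from flash_attn.flash_attention import FlashAttention

attn = FlashAttention(softmax_scale=None, attention_dropout=0.1)  # scale defaults to 1/sqrt(headdim)
qkv = torch.randn(2, 128, 3, 8, 64, device='cuda', dtype=torch.float16)  # (batch, seqlen, 3, nheads, headdim)
out, _ = attn(qkv, causal=True)  # out: (2, 128, 8, 64)
```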
flash_attn/flash_attn_interface.py  View file @ 6c3a8c65

 # Adapted from https://github.com/mlcommons/training_results_v1.1/blob/main/NVIDIA/benchmarks/bert/implementations/pytorch/fmha.py
 import torch
 import torch.nn as nn

 import flash_attn_cuda


-def _flash_attn_forward(qkv, cu_seqlens, dropout_p, max_s, softmax_scale, causal, return_softmax):
-    context, softmax_lse, *rest = flash_attn_cuda.fwd(qkv, cu_seqlens, dropout_p, max_s, softmax_scale,
-                                                      False, causal, return_softmax, None)
-    # if context.isnan().any() or softmax_lse.isnan().any():
-    #     breakpoint()
-    S_dmask = rest[0] if return_softmax else None
-    return context, softmax_lse, S_dmask
+def _get_block_size(device, head_dim, is_dropout):
+    assert head_dim in [16, 32, 64, 128]
+    if head_dim in [16, 32]:
+        return 256
+    elif head_dim == 64:
+        return 128 if (torch.cuda.get_device_capability(device) == (7, 5) and is_dropout) else 256
+    elif head_dim == 128:
+        return 256 if (torch.cuda.get_device_capability(device) == (8, 0) and not is_dropout) else 128
+
+
+def _flash_attn_forward(q, k, v, cu_seqlens_q, cu_seqlens_k, max_seqlen_q, max_seqlen_k,
+                        dropout_p, softmax_scale, causal, return_softmax):
+    out, softmax_lse, *rest = flash_attn_cuda.fwd(
+        q, k, v, cu_seqlens_q, cu_seqlens_k, max_seqlen_q, max_seqlen_k, dropout_p,
+        softmax_scale, False, causal, return_softmax, None
+    )
+    # if out.isnan().any() or softmax_lse.isnan().any():
+    #     breakpoint()
+    S_dmask = rest[0] if return_softmax else None
+    return out, softmax_lse, S_dmask


-def _flash_attn_backward(dout, qkv, out, S_dmask, softmax_lse, cu_seqlens, dropout_p, max_s,
-                         softmax_scale, causal):
-    dqkv, dp, softmax_d = flash_attn_cuda.bwd(dout, qkv, out, S_dmask, softmax_lse, cu_seqlens,
-                                              dropout_p, softmax_scale, max_s, False, causal, None)
-    # if dqkv.isnan().any() or softmax_d.isnan().any():
-    #     breakpoint()
-    return dqkv
+def _flash_attn_backward(dout, q, k, v, out, softmax_lse, dq, dk, dv, cu_seqlens_q, cu_seqlens_k,
+                         max_seqlen_q, max_seqlen_k, dropout_p, softmax_scale, causal):
+    softmax_d = flash_attn_cuda.bwd(
+        dout, q, k, v, out, softmax_lse, dq, dk, dv, cu_seqlens_q, cu_seqlens_k,
+        max_seqlen_q, max_seqlen_k, dropout_p, softmax_scale, False, causal, None
+    )
+    # if dk.isnan().any() or dk.isnan().any() or dv.isnan().any() or softmax_d.isnan().any():
+    #     breakpoint()
+    return dq, dk, dv, softmax_d
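
The new private helpers take q, k and v separately; the packed autograd functions below simply slice their packed inputs before calling them. The slicing itself is ordinary indexing, shown here on CPU for clarity:

```python
# The packed layouts the new autograd functions slice apart before calling the kernel.
import torch

total, nheads, d = 10, 4, 64
qkv = torch.randn(total, 3, nheads, d)
q, k, v = qkv[:, 0], qkv[:, 1], qkv[:, 2]        # what FlashAttnQKVPackedFunc passes down
assert q.shape == k.shape == v.shape == (total, nheads, d)

kv = torch.randn(total, 2, nheads, d)
k2, v2 = kv[:, 0], kv[:, 1]                      # what FlashAttnKVPackedFunc passes down
assert k2.shape == v2.shape == (total, nheads, d)
```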
-class FlashAttnFun(torch.autograd.Function):
-    @staticmethod
-    def forward(ctx, qkv, cu_seqlens, dropout_p, max_s, softmax_scale, causal):
-        # Save rng_state because the backward pass will regenerate the dropout mask
-        rng_state = torch.cuda.get_rng_state() if dropout_p > 0 else None
-        if softmax_scale is None:
-            softmax_scale = qkv.shape[-1] ** (-0.5)
-        context, softmax_lse, S_dmask = _flash_attn_forward(
-            qkv, cu_seqlens, dropout_p, max_s, softmax_scale, causal=causal, return_softmax=False
-        )
-        ctx.save_for_backward(qkv, context, S_dmask, softmax_lse, cu_seqlens, rng_state)
-        ctx.dropout_p = dropout_p
-        ctx.max_s = max_s
-        ctx.softmax_scale = softmax_scale
-        ctx.causal = causal
-        return context
-
-    @staticmethod
-    def backward(ctx, dout):
-        qkv, context, S_dmask, softmax_lse, cu_seqlens, rng_state = ctx.saved_tensors
-        if rng_state is not None:
-            cur_rng_state = torch.cuda.get_rng_state()
-            torch.cuda.set_rng_state(rng_state)
-        # S_dmask is None, temporarily use another tensor just to get it running
-        dqkv = _flash_attn_backward(
-            dout, qkv, context, context, softmax_lse, cu_seqlens, ctx.dropout_p,
-            ctx.max_s, ctx.softmax_scale, ctx.causal
-        )
-        if rng_state is not None:
-            torch.cuda.set_rng_state(cur_rng_state)
-        return dqkv, None, None, None, None, None, None
+class FlashAttnQKVPackedFunc(torch.autograd.Function):
+    @staticmethod
+    def forward(ctx, qkv, cu_seqlens, max_seqlen, dropout_p, softmax_scale, causal, return_softmax):
+        # Save rng_state because the backward pass will regenerate the dropout mask
+        rng_state = torch.cuda.get_rng_state() if dropout_p > 0 else None
+        if softmax_scale is None:
+            softmax_scale = qkv.shape[-1] ** (-0.5)
+        out, softmax_lse, S_dmask = _flash_attn_forward(
+            qkv[:, 0], qkv[:, 1], qkv[:, 2], cu_seqlens, cu_seqlens, max_seqlen, max_seqlen,
+            dropout_p, softmax_scale, causal=causal, return_softmax=return_softmax
+        )
+        ctx.save_for_backward(qkv, out, softmax_lse, cu_seqlens, rng_state)
+        ctx.dropout_p = dropout_p
+        ctx.max_seqlen = max_seqlen
+        ctx.softmax_scale = softmax_scale
+        ctx.causal = causal
+        return out if not return_softmax else (out, softmax_lse, S_dmask)
+
+    @staticmethod
+    def backward(ctx, dout, *args):
+        qkv, out, softmax_lse, cu_seqlens, rng_state = ctx.saved_tensors
+        if rng_state is not None:
+            cur_rng_state = torch.cuda.get_rng_state()
+            torch.cuda.set_rng_state(rng_state)
+        dqkv = torch.empty_like(qkv)
+        _flash_attn_backward(
+            dout, qkv[:, 0], qkv[:, 1], qkv[:, 2], out, softmax_lse,
+            dqkv[:, 0], dqkv[:, 1], dqkv[:, 2], cu_seqlens, cu_seqlens,
+            ctx.max_seqlen, ctx.max_seqlen, ctx.dropout_p, ctx.softmax_scale, ctx.causal
+        )
+        if rng_state is not None:
+            torch.cuda.set_rng_state(cur_rng_state)
+        return dqkv, None, None, None, None, None, None
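
The rng_state bookkeeping in forward/backward exists so the backward pass regenerates exactly the dropout mask used in forward. The save/restore pattern in isolation, with torch.rand standing in for the kernel's dropout draw (assumes a CUDA device is available):

```python
# Save the CUDA RNG state, consume randomness, then restore the state and draw again:
# the second draw reproduces the first, which is what backward relies on.
import torch

if torch.cuda.is_available():
    rng_state = torch.cuda.get_rng_state()
    a = torch.rand(4, device='cuda')        # consumes RNG, like dropout in forward
    cur_rng_state = torch.cuda.get_rng_state()
    torch.cuda.set_rng_state(rng_state)
    b = torch.rand(4, device='cuda')        # same draw as `a`, like backward's recomputation
    torch.cuda.set_rng_state(cur_rng_state)  # leave the global RNG where it would have been
    assert torch.equal(a, b)
```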
-# We duplicate code to return both the output and the softmax for testing
-# Returning both makes backward a bit slower, so we want to keep using the other version for speed.
-class FlashAttnFunWithS(torch.autograd.Function):
-    @staticmethod
-    def forward(ctx, qkv, cu_seqlens, dropout_p, max_s, softmax_scale, causal):
-        # Save rng_state because the backward pass is gonna regenerate the dropout mask
-        rng_state = torch.cuda.get_rng_state() if dropout_p > 0 else None
-        if softmax_scale is None:
-            softmax_scale = qkv.shape[-1] ** (-0.5)
-        context, softmax_lse, S_dmask = _flash_attn_forward(
-            qkv, cu_seqlens, dropout_p, max_s, softmax_scale, causal=causal, return_softmax=True
-        )
-        ctx.save_for_backward(qkv, context, S_dmask, softmax_lse, cu_seqlens, rng_state)
-        ctx.dropout_p = dropout_p
-        ctx.max_s = max_s
-        ctx.softmax_scale = softmax_scale
-        ctx.causal = causal
-        return context, S_dmask, softmax_lse
-
-    @staticmethod
-    def backward(ctx, dout, _dS_dmask_ignored, _dsoftmax_sum_ignored):
-        qkv, context, S_dmask, softmax_lse, cu_seqlens, rng_state = ctx.saved_tensors
-        if rng_state is not None:
-            cur_rng_state = torch.cuda.get_rng_state()
-            torch.cuda.set_rng_state(rng_state)
-        dqkv = _flash_attn_backward(dout, qkv, context, S_dmask, softmax_lse, cu_seqlens,
-                                    ctx.dropout_p, ctx.max_s, ctx.softmax_scale, ctx.causal)
-        if rng_state is not None:
-            torch.cuda.set_rng_state(cur_rng_state)
-        return dqkv, None, None, None, None, None
+class FlashAttnKVPackedFunc(torch.autograd.Function):
+    @staticmethod
+    def forward(ctx, q, kv, cu_seqlens_q, cu_seqlens_k, max_seqlen_q, max_seqlen_k, dropout_p,
+                softmax_scale, causal, return_softmax):
+        # Save rng_state because the backward pass will regenerate the dropout mask
+        rng_state = torch.cuda.get_rng_state() if dropout_p > 0 else None
+        if softmax_scale is None:
+            softmax_scale = q.shape[-1] ** (-0.5)
+        out, softmax_lse, S_dmask = _flash_attn_forward(
+            q, kv[:, 0], kv[:, 1], cu_seqlens_q, cu_seqlens_k, max_seqlen_q, max_seqlen_k,
+            dropout_p, softmax_scale, causal=causal, return_softmax=return_softmax
+        )
+        ctx.save_for_backward(q, kv, out, softmax_lse, cu_seqlens_q, cu_seqlens_k, rng_state)
+        ctx.dropout_p = dropout_p
+        ctx.max_seqlen_q = max_seqlen_q
+        ctx.max_seqlen_k = max_seqlen_k
+        ctx.softmax_scale = softmax_scale
+        ctx.causal = causal
+        return out if not return_softmax else (out, softmax_lse, S_dmask)
+
+    @staticmethod
+    def backward(ctx, dout, *args):
+        q, kv, out, softmax_lse, cu_seqlens_q, cu_seqlens_k, rng_state = ctx.saved_tensors
+        if rng_state is not None:
+            cur_rng_state = torch.cuda.get_rng_state()
+            torch.cuda.set_rng_state(rng_state)
+        dq = torch.empty_like(q)
+        dkv = torch.empty_like(kv)
+        _flash_attn_backward(
+            dout, q, kv[:, 0], kv[:, 1], out, softmax_lse, dq, dkv[:, 0], dkv[:, 1],
+            cu_seqlens_q, cu_seqlens_k, ctx.max_seqlen_q, ctx.max_seqlen_k, ctx.dropout_p,
+            ctx.softmax_scale, ctx.causal
+        )
+        if rng_state is not None:
+            torch.cuda.set_rng_state(cur_rng_state)
+        return dq, dkv, None, None, None, None, None, None, None, None
+
+
+class FlashAttnFunc(torch.autograd.Function):
+    @staticmethod
+    def forward(ctx, q, k, v, cu_seqlens_q, cu_seqlens_k, max_seqlen_q, max_seqlen_k, dropout_p,
+                softmax_scale, causal, return_softmax):
+        # Save rng_state because the backward pass will regenerate the dropout mask
+        rng_state = torch.cuda.get_rng_state() if dropout_p > 0 else None
+        if softmax_scale is None:
+            softmax_scale = q.shape[-1] ** (-0.5)
+        out, softmax_lse, S_dmask = _flash_attn_forward(
+            q, k, v, cu_seqlens_q, cu_seqlens_k, max_seqlen_q, max_seqlen_k, dropout_p,
+            softmax_scale, causal=causal, return_softmax=return_softmax
+        )
+        ctx.save_for_backward(q, k, v, out, softmax_lse, cu_seqlens_q, cu_seqlens_k, rng_state)
+        ctx.dropout_p = dropout_p
+        ctx.max_seqlen_q = max_seqlen_q
+        ctx.max_seqlen_k = max_seqlen_k
+        ctx.softmax_scale = softmax_scale
+        ctx.causal = causal
+        return out if not return_softmax else (out, softmax_lse, S_dmask)
+
+    @staticmethod
+    def backward(ctx, dout, *args):
+        q, k, v, out, softmax_lse, cu_seqlens_q, cu_seqlens_k, rng_state = ctx.saved_tensors
+        if rng_state is not None:
+            cur_rng_state = torch.cuda.get_rng_state()
+            torch.cuda.set_rng_state(rng_state)
+        dq, dk, dv = torch.empty_like(q), torch.empty_like(k), torch.empty_like(v)
+        _flash_attn_backward(
+            dout, q, k, v, out, softmax_lse, dq, dk, dv, cu_seqlens_q, cu_seqlens_k,
+            ctx.max_seqlen_q, ctx.max_seqlen_k, ctx.dropout_p, ctx.softmax_scale, ctx.causal
+        )
+        if rng_state is not None:
+            torch.cuda.set_rng_state(cur_rng_state)
+        return dq, dk, dv, None, None, None, None, None, None, None, None
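
The KV-packed interface is the natural entry point for cross attention, where queries and keys have different lengths. A usage sketch with made-up sizes; it needs a CUDA build of flash_attn_cuda and fp16 inputs:

```python
# Cross-attention usage sketch: decoder queries of length 128 attending to encoder keys/values
# of length 256, all sequences in the batch assumed to be full length here.
import torch
from flash_attn.flash_attn_interface import flash_attn_unpadded_kvpacked_func

batch, seqlen_q, seqlen_k, nheads, d = 2, 128, 256, 8, 64
device, dtype = 'cuda', torch.float16

q = torch.randn(batch * seqlen_q, nheads, d, device=device, dtype=dtype)
kv = torch.randn(batch * seqlen_k, 2, nheads, d, device=device, dtype=dtype)
cu_seqlens_q = torch.arange(0, (batch + 1) * seqlen_q, seqlen_q, device=device, dtype=torch.int32)
cu_seqlens_k = torch.arange(0, (batch + 1) * seqlen_k, seqlen_k, device=device, dtype=torch.int32)

out = flash_attn_unpadded_kvpacked_func(
    q, kv, cu_seqlens_q, cu_seqlens_k, seqlen_q, seqlen_k,
    dropout_p=0.0, softmax_scale=None, causal=False)
# out: (batch * seqlen_q, nheads, d)
```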
def flash_attn_unpadded_qkvpacked_func(qkv, cu_seqlens, max_seqlen, dropout_p, softmax_scale=None,
                                       causal=False, return_attn_probs=False):
"""dropout_p should be set to 0.0 during evaluation
Arguments:
qkv: (total, 3, nheads, headdim), where total = total number of tokens in the batch.
cu_seqlens: (batch_size + 1,), dtype torch.int32. The cumulative sequence lengths
of the sequences in the batch, used to index into qkv.
max_seqlen: int. Maximum sequence length in the batch.
dropout_p: float. Dropout probability.
softmax_scale: float. The scaling of QK^T before applying softmax.
Default to 1 / sqrt(headdim).
causal: bool. Whether to apply causal attention mask (e.g., for auto-regressive modeling).
return_attn_probs: bool. Whether to return the attention probabilities. This option is for
testing only. The returned probabilities are not guaranteed to be correct
(they might not have the right scaling).
Return:
out: (total, nheads, headdim).
softmax_lse [optional, if return_attn_probs=True]: (batch_size, nheads, seqlen). The
logsumexp of each row of the matrix QK^T * scaling (e.g., log of the softmax
normalization factor).
S_dmask [optional, if return_attn_probs=True]: (batch_size, nheads, seqlen, seqlen).
The output of softmax (possibly with different scaling). It also encodes the dropout
pattern (negative means that location was dropped, nonnegative means it was kept).
"""
    return FlashAttnQKVPackedFunc.apply(qkv, cu_seqlens, max_seqlen, dropout_p, softmax_scale,
                                        causal, return_attn_probs)


def flash_attn_unpadded_kvpacked_func(q, kv, cu_seqlens_q, cu_seqlens_k, max_seqlen_q, max_seqlen_k,
                                      dropout_p, softmax_scale=None, causal=False,
                                      return_attn_probs=False):
"""dropout_p should be set to 0.0 during evaluation
Arguments:
q: (total_q, nheads, headdim), where total_q = total number of query tokens in the batch.
kv: (total_k, 2, nheads, headdim), where total_k = total number of key tokens in the batch.
cu_seqlens_q: (batch_size + 1,), dtype torch.int32. The cumulative sequence lengths
of the sequences in the batch, used to index into q.
cu_seqlens_k: (batch_size + 1,), dtype torch.int32. The cumulative sequence lengths
of the sequences in the batch, used to index into kv.
max_seqlen_q: int. Maximum query sequence length in the batch.
max_seqlen_k: int. Maximum key sequence length in the batch.
dropout_p: float. Dropout probability.
softmax_scale: float. The scaling of QK^T before applying softmax.
Default to 1 / sqrt(headdim).
causal: bool. Whether to apply causal attention mask (e.g., for auto-regressive modeling).
return_attn_probs: bool. Whether to return the attention probabilities. This option is for
testing only. The returned probabilities are not guaranteed to be correct
(they might not have the right scaling).
Return:
out: (total, nheads, headdim).
softmax_lse [optional, if return_attn_probs=True]: (batch_size, nheads, seqlen). The
logsumexp of each row of the matrix QK^T * scaling (e.g., log of the softmax
normalization factor).
S_dmask [optional, if return_attn_probs=True]: (batch_size, nheads, seqlen, seqlen).
The output of softmax (possibly with different scaling). It also encodes the dropout
pattern (negative means that location was dropped, nonnegative means it was kept).
"""
    return FlashAttnKVPackedFunc.apply(q, kv, cu_seqlens_q, cu_seqlens_k, max_seqlen_q, max_seqlen_k,
                                       dropout_p, softmax_scale, causal, return_attn_probs)


def flash_attn_unpadded_func(q, k, v, cu_seqlens_q, cu_seqlens_k, max_seqlen_q, max_seqlen_k,
                             dropout_p, softmax_scale=None, causal=False, return_attn_probs=False):
"""dropout_p should be set to 0.0 during evaluation
Arguments:
q: (total_q, nheads, headdim), where total_q = total number of query tokens in the batch.
k: (total_k, nheads, headdim), where total_k = total number of key tokens in the batch.
v: (total_k, nheads, headdim), where total_k = total number of key tokens in the batch.
cu_seqlens_q: (batch_size + 1,), dtype torch.int32. The cumulative sequence lengths
of the sequences in the batch, used to index into q.
cu_seqlens_k: (batch_size + 1,), dtype torch.int32. The cumulative sequence lengths
of the sequences in the batch, used to index into kv.
max_seqlen_q: int. Maximum query sequence length in the batch.
max_seqlen_k: int. Maximum key sequence length in the batch.
dropout_p: float. Dropout probability.
softmax_scale: float. The scaling of QK^T before applying softmax.
Default to 1 / sqrt(headdim).
causal: bool. Whether to apply causal attention mask (e.g., for auto-regressive modeling).
return_attn_probs: bool. Whether to return the attention probabilities. This option is for
testing only. The returned probabilities are not guaranteed to be correct
(they might not have the right scaling).
Return:
out: (total, nheads, headdim).
softmax_lse [optional, if return_attn_probs=True]: (batch_size, nheads, seqlen). The
logsumexp of each row of the matrix QK^T * scaling (e.g., log of the softmax
normalization factor).
S_dmask [optional, if return_attn_probs=True]: (batch_size, nheads, seqlen, seqlen).
The output of softmax (possibly with different scaling). It also encodes the dropout
pattern (negative means that location was dropped, nonnegative means it was kept).
"""
    return FlashAttnFunc.apply(q, k, v, cu_seqlens_q, cu_seqlens_k, max_seqlen_q, max_seqlen_k,
                               dropout_p, softmax_scale, causal, return_attn_probs)


 def flash_attn_func(qkv, cu_seqlens, dropout_p, max_s, softmax_scale=None, causal=False,
                     return_attn_probs=False):
-    """dropout_p should be set to 0.0 during evaluation
+    """For backward-compatibility only, will remove soon.
+    dropout_p should be set to 0.0 during evaluation
     """
-    func = FlashAttnFun if not return_attn_probs else FlashAttnFunWithS
-    return func.apply(qkv, cu_seqlens, dropout_p, max_s, softmax_scale, causal)
+    return flash_attn_unpadded_qkvpacked_func(qkv, cu_seqlens, max_s, dropout_p, softmax_scale,
+                                              causal, return_attn_probs)
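
The backward-compatible wrapper only reorders arguments: the legacy flash_attn_func took dropout_p before max_s, while the new unpadded API takes the max sequence length first. A quick check of the equivalence (CUDA build and fp16 inputs assumed; dropout disabled so both calls are deterministic):

```python
# Both calls run the same QKV-packed path; only the argument order differs.
import torch
from flash_attn.flash_attn_interface import flash_attn_func, flash_attn_unpadded_qkvpacked_func

batch, seqlen, nheads, d = 2, 128, 8, 64
qkv = torch.randn(batch * seqlen, 3, nheads, d, device='cuda', dtype=torch.float16)
cu_seqlens = torch.arange(0, (batch + 1) * seqlen, seqlen, device='cuda', dtype=torch.int32)

out_legacy = flash_attn_func(qkv, cu_seqlens, 0.0, seqlen)                     # old order
out_new = flash_attn_unpadded_qkvpacked_func(qkv, cu_seqlens, seqlen, 0.0)     # new order
assert torch.allclose(out_legacy, out_new)
```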