gaoqiong / flash-attention

Commit 3524e13c
Authored Aug 13, 2023 by Tri Dao

Update to Cutlass 3.1

Parent: 364a5b4a
Showing 5 changed files with 229 additions and 171 deletions:

    csrc/cutlass                              +1    -1
    csrc/flash_attn/src/flash_bwd_kernel.h    +164  -118
    csrc/flash_attn/src/flash_fwd_kernel.h    +39   -30
    csrc/flash_attn/src/utils.h               +19   -16
    tests/test_flash_attn.py                  +6    -6
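Aside from the Cutlass submodule bump, the kernel changes follow one mechanical pattern: instead of building a CuTe TiledCopy as a temporary and keeping only its per-thread slice, the TiledCopy is now kept as a named object, the ThrCopy is derived from it, and both are passed down to the flash:: helpers; bare copy()/clear() calls become explicit cute::copy()/cute::clear(). A condensed before/after of that pattern, using names that appear in the hunks below (not standalone code):

    // Before this commit: only the per-thread slice was kept around.
    auto smem_thr_copy_Q = make_tiled_copy_A(typename Kernel_traits::SmemCopyAtom{}, tiled_mma).get_thread_slice(tidx);

    // After this commit: keep the TiledCopy itself and derive the thread slice from it,
    // so both can be handed to helpers such as flash::gemm and flash::copy.
    auto smem_tiled_copy_Q = make_tiled_copy_A(typename Kernel_traits::SmemCopyAtom{}, tiled_mma);
    auto smem_thr_copy_Q = smem_tiled_copy_Q.get_thread_slice(tidx);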
csrc/cutlass @ 6f474202 (compare c4f6b8c6...6f474202)

-Subproject commit c4f6b8c6bc94ff69048492fb34df0dfaf1983933
+Subproject commit 6f47420213f757831fae65c686aa471749fa8d60
csrc/flash_attn/src/flash_bwd_kernel.h

(This diff is collapsed in the commit view and is not shown here.)
csrc/flash_attn/src/flash_fwd_kernel.h
@@ -77,7 +77,7 @@ inline __device__ void softmax_rescale_o(Tensor0 &scores, Tensor1 &scores_max, T
        flash::reduce_sum(scores, scores_sum);
    } else {
        Tensor scores_max_prev = make_fragment_like(scores_max);
-       copy(scores_max, scores_max_prev);
+       cute::copy(scores_max, scores_max_prev);
        flash::template reduce_max</*zero_init=*/false>(scores, scores_max);
        // Reshape acc_o from (MMA=4, MMA_M, MMA_K) to (nrow=(2, MMA_M), ncol=(2, MMA_K))
        Tensor acc_o_rowcol = make_tensor(acc_o.data(), flash::convert_layout_acc_rowcol(acc_o.layout()));

@@ -103,7 +103,7 @@ inline __device__ void softmax_rescale_o(Tensor0 &scores, Tensor1 &scores_max, T
template <typename Engine0, typename Layout0, typename Engine1, typename Layout1, typename TiledCopy>
inline __device__ void write_softmax_to_gmem(
-   Tensor<Engine0, Layout0> const &tOrP, Tensor<Engine1, Layout1> &tPgP, TiledCopy gmem_thr_copy_P
+   Tensor<Engine0, Layout0> const &tOrP, Tensor<Engine1, Layout1> &tPgP, TiledCopy gmem_tiled_copy_P
) {
    // Reshape tOrP from (8, MMA_M, MMA_N) to (8, MMA_M * MMA_N)
    Layout l = tOrP.layout();

@@ -112,7 +112,7 @@ inline __device__ void write_softmax_to_gmem(
    CUTE_STATIC_ASSERT_V(size<1>(tPrP) == size<1>(tPgP));
    #pragma unroll
    for (int mi = 0; mi < size<1>(tPrP); ++mi) {
-       copy(gmem_thr_copy_P, tPrP(_, mi), tPgP(_, mi, 0));
+       cute::copy(gmem_tiled_copy_P, tPrP(_, mi), tPgP(_, mi, 0));
    }
};

@@ -186,8 +186,10 @@ inline __device__ void compute_attn_1rowblock(const Params &params, const int bi
    Tensor sVt = make_tensor(sV.data(), typename Kernel_traits::SmemLayoutVtransposed{});
    Tensor sVtNoSwizzle = make_tensor(sV.data(), typename Kernel_traits::SmemLayoutVtransposedNoSwizzle{});
-   auto gmem_thr_copy_QKV = typename Kernel_traits::GmemTiledCopyQKV{}.get_thread_slice(tidx);
-   auto gmem_thr_copy_P = typename Kernel_traits::GmemTiledCopyP{}.get_thread_slice(tidx);
+   typename Kernel_traits::GmemTiledCopyQKV gmem_tiled_copy_QKV;
+   auto gmem_thr_copy_QKV = gmem_tiled_copy_QKV.get_thread_slice(tidx);
+   typename Kernel_traits::GmemTiledCopyP gmem_tiled_copy_P;
+   auto gmem_thr_copy_P = gmem_tiled_copy_P.get_thread_slice(tidx);
    Tensor tQgQ = gmem_thr_copy_QKV.partition_S(gQ);
    Tensor tQsQ = gmem_thr_copy_QKV.partition_D(sQ);

@@ -209,16 +211,19 @@ inline __device__ void compute_attn_1rowblock(const Params &params, const int bi
    // Copy Atom retiling
    //
-   auto smem_thr_copy_Q = make_tiled_copy_A(typename Kernel_traits::SmemCopyAtom{}, tiled_mma).get_thread_slice(tidx);
+   auto smem_tiled_copy_Q = make_tiled_copy_A(typename Kernel_traits::SmemCopyAtom{}, tiled_mma);
+   auto smem_thr_copy_Q = smem_tiled_copy_Q.get_thread_slice(tidx);
    // auto smem_thr_copy_Q = make_tiled_copy_A_warpcontiguousM<MMA_M>(typename Kernel_traits::SmemCopyAtom{}, tiled_mma).get_thread_slice(tidx);
    // if (cute::thread0()) {smem_thr_copy_Q.print_all();}
    Tensor tSsQ = smem_thr_copy_Q.partition_S(sQ);
    // if (cute::thread0()) {print(tSsQ.layout()); printf("\n");}
-   auto smem_thr_copy_K = make_tiled_copy_B(typename Kernel_traits::SmemCopyAtom{}, tiled_mma).get_thread_slice(tidx);
+   auto smem_tiled_copy_K = make_tiled_copy_B(typename Kernel_traits::SmemCopyAtom{}, tiled_mma);
+   auto smem_thr_copy_K = smem_tiled_copy_K.get_thread_slice(tidx);
    Tensor tSsK = smem_thr_copy_K.partition_S(sK);
-   auto smem_thr_copy_V = make_tiled_copy_B(typename Kernel_traits::SmemCopyAtomTransposed{}, tiled_mma).get_thread_slice(tidx);
+   auto smem_tiled_copy_V = make_tiled_copy_B(typename Kernel_traits::SmemCopyAtomTransposed{}, tiled_mma);
+   auto smem_thr_copy_V = smem_tiled_copy_V.get_thread_slice(tidx);
    Tensor tOsVt = smem_thr_copy_V.partition_S(sVt);
    // TODO: this might need to change if we change the mma instruction in SM70
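In this retiling block the division of labor is visible: the ThrCopy slices still produce the per-thread shared-memory and register views, while the TiledCopy is what later drives cute::copy. A condensed sketch stitched together from the Q-side hunks in this file (names come from the surrounding kernel, so this is an illustration rather than standalone code):

    auto smem_tiled_copy_Q = make_tiled_copy_A(typename Kernel_traits::SmemCopyAtom{}, tiled_mma);
    auto smem_thr_copy_Q = smem_tiled_copy_Q.get_thread_slice(tidx);

    Tensor tSsQ = smem_thr_copy_Q.partition_S(sQ);            // per-thread view of Q in shared memory
    Tensor tSrQ_copy_view = smem_thr_copy_Q.retile_D(tSrQ);   // per-thread register-fragment view
    cute::copy(smem_tiled_copy_Q, tSsQ, tSrQ_copy_view);      // the copy itself is driven by the TiledCopy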
@@ -269,7 +274,7 @@ inline __device__ void compute_attn_1rowblock(const Params &params, const int bi
    Tensor tQrQ = make_fragment_like(tQgQ);
    // We don't need to clear the sQ smem tiles since we'll only write out the valid outputs
-   flash::copy</*Is_even_MN=*/false, Is_even_K>(gmem_thr_copy_QKV, tQgQ, tQsQ, tQcQ, tQpQ,
+   flash::copy</*Is_even_MN=*/false, Is_even_K>(gmem_tiled_copy_QKV, tQgQ, tQsQ, tQcQ, tQpQ,
                                                 binfo.actual_seqlen_q - m_block * kBlockM);
    if (Kernel_traits::Is_Q_in_regs) { cute::cp_async_fence(); }

@@ -286,13 +291,13 @@ inline __device__ void compute_attn_1rowblock(const Params &params, const int bi
        __syncthreads();
        Tensor tSrQ_copy_view = smem_thr_copy_Q.retile_D(tSrQ);
        CUTE_STATIC_ASSERT_V(size<1>(tSsQ) == size<1>(tSrQ_copy_view));            // M
-       copy(smem_thr_copy_Q, tSsQ, tSrQ_copy_view);
+       cute::copy(smem_tiled_copy_Q, tSsQ, tSrQ_copy_view);
        __syncthreads();
    }
    int n_block = n_block_max - 1;
    // We don't need to clear the sK smem tiles since we'll mask out the scores anyway.
-   flash::copy<Is_even_N, Is_even_K>(gmem_thr_copy_QKV, tKgK, tKsK, tKVcKV, tKVpKV,
+   flash::copy<Is_even_N, Is_even_K>(gmem_tiled_copy_QKV, tKgK, tKsK, tKVcKV, tKVpKV,
                                      binfo.actual_seqlen_k - n_block * kBlockN);
    cute::cp_async_fence();
    // if (threadIdx.x == 0 && blockIdx.y == 0 && blockIdx.z < 2) { print(tKgK); }

@@ -303,7 +308,7 @@ inline __device__ void compute_attn_1rowblock(const Params &params, const int bi
        __syncthreads();
        Tensor tSrQ_copy_view = smem_thr_copy_Q.retile_D(tSrQ);
        CUTE_STATIC_ASSERT_V(size<1>(tSsQ) == size<1>(tSrQ_copy_view));            // M
-       copy(smem_thr_copy_Q, tSsQ, tSrQ_copy_view);
+       cute::copy(smem_tiled_copy_Q, tSsQ, tSrQ_copy_view);
    }
    auto seeds = at::cuda::philox::unpack(params.philox_args);

@@ -335,17 +340,18 @@ inline __device__ void compute_attn_1rowblock(const Params &params, const int bi
        // Advance gV
        if (masking_step > 0) {
            tVgV.data() = tVgV.data() + (-int(kBlockN * params.v_row_stride));
-           flash::copy</*Is_even_MN=*/true, Is_even_K>(gmem_thr_copy_QKV, tVgV, tVsV, tKVcKV, tKVpKV);
+           flash::copy</*Is_even_MN=*/true, Is_even_K>(gmem_tiled_copy_QKV, tVgV, tVsV, tKVcKV, tKVpKV);
        } else {
            // Clear the smem tiles to account for predicated off loads
            flash::copy<Is_even_N, Is_even_K, /*Clear_OOB_MN=*/true>(
-               gmem_thr_copy_QKV, tVgV, tVsV, tKVcKV, tKVpKV, binfo.actual_seqlen_k - n_block * kBlockN
+               gmem_tiled_copy_QKV, tVgV, tVsV, tKVcKV, tKVpKV, binfo.actual_seqlen_k - n_block * kBlockN
            );
        }
        cute::cp_async_fence();
        flash::gemm</*A_in_regs=*/Kernel_traits::Is_Q_in_regs>(
-           acc_s, tSrQ, tSrK, tSsQ, tSsK, tiled_mma, smem_thr_copy_Q, smem_thr_copy_K
+           acc_s, tSrQ, tSrK, tSsQ, tSsK, tiled_mma, smem_tiled_copy_Q, smem_tiled_copy_K,
+           smem_thr_copy_Q, smem_thr_copy_K
        );
        // if (cute::thread0()) { print(acc_s); }
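Written out as a single statement, the post-commit score GEMM call site passes the A- and B-side TiledCopy objects ahead of the corresponding ThrCopy slices, matching the widened flash::gemm signature in csrc/flash_attn/src/utils.h further down:

    flash::gemm</*A_in_regs=*/Kernel_traits::Is_Q_in_regs>(
        acc_s, tSrQ, tSrK, tSsQ, tSsK, tiled_mma,
        smem_tiled_copy_Q, smem_tiled_copy_K,   // TiledCopy objects, used by cute::copy inside flash::gemm
        smem_thr_copy_Q, smem_thr_copy_K        // ThrCopy slices, used for the retile_D views
    );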
@@ -382,7 +388,7 @@ inline __device__ void compute_attn_1rowblock(const Params &params, const int bi
        if (n_block > 0) {
            // Advance gK
            tKgK.data() = tKgK.data() + (-int(kBlockN * params.k_row_stride));
-           flash::copy</*Is_even_MN=*/true, Is_even_K>(gmem_thr_copy_QKV, tKgK, tKsK, tKVcKV, tKVpKV);
+           flash::copy</*Is_even_MN=*/true, Is_even_K>(gmem_tiled_copy_QKV, tKgK, tKsK, tKVcKV, tKVpKV);
            // This cp_async_fence needs to be in the if block, otherwise the synchronization
            // isn't right and we get race conditions.
            cute::cp_async_fence();

@@ -402,12 +408,12 @@ inline __device__ void compute_attn_1rowblock(const Params &params, const int bi
        uint32_t block_col_idx = n_block * (kBlockN / 32);
        if (Return_softmax) {
            Tensor tOrP_copy = make_fragment_like(tOrP);
-           copy(tOrP, tOrP_copy);
+           cute::copy(tOrP, tOrP_copy);
            flash::apply_dropout</*encode_dropout_in_sign_bit=*/true>(
                tOrP_copy, params.p_dropout_in_uint8_t, seed, offset,
                block_row_idx, block_col_idx, kNWarps
            );
-           flash::write_softmax_to_gmem(tOrP_copy, tPgP, gmem_thr_copy_P);
+           flash::write_softmax_to_gmem(tOrP_copy, tPgP, gmem_tiled_copy_P);
            tPgP.data() = tPgP.data() + (-kBlockN);
        }
        if (Is_dropout) {

@@ -416,7 +422,7 @@ inline __device__ void compute_attn_1rowblock(const Params &params, const int bi
        }
        // if (cute::thread0()) { print(tOrP); }
-       flash::gemm_A_in_regs(acc_o, tOrP, tOrVt, tOsVt, tiled_mma, smem_thr_copy_V);
+       flash::gemm_A_in_regs(acc_o, tOrP, tOrVt, tOsVt, tiled_mma, smem_tiled_copy_V, smem_thr_copy_V);
        // if (cute::thread0()) { print(scores); }
        // This check is at the end of the loop since we always have at least 1 iteration

@@ -434,11 +440,12 @@ inline __device__ void compute_attn_1rowblock(const Params &params, const int bi
        __syncthreads();
        // Advance gV
        tVgV.data() = tVgV.data() + (-int(kBlockN * params.v_row_stride));
-       flash::copy</*Is_even_MN=*/true, Is_even_K>(gmem_thr_copy_QKV, tVgV, tVsV, tKVcKV, tKVpKV);
+       flash::copy</*Is_even_MN=*/true, Is_even_K>(gmem_tiled_copy_QKV, tVgV, tVsV, tKVcKV, tKVpKV);
        cute::cp_async_fence();
        flash::gemm</*A_in_regs=*/Kernel_traits::Is_Q_in_regs>(
-           acc_s, tSrQ, tSrK, tSsQ, tSsK, tiled_mma, smem_thr_copy_Q, smem_thr_copy_K
+           acc_s, tSrQ, tSrK, tSsQ, tSsK, tiled_mma, smem_tiled_copy_Q, smem_tiled_copy_K,
+           smem_thr_copy_Q, smem_thr_copy_K
        );
        flash::cp_async_wait<0>();

@@ -446,7 +453,7 @@ inline __device__ void compute_attn_1rowblock(const Params &params, const int bi
        if (n_block > 0) {
            // Advance gK
            tKgK.data() = tKgK.data() + (-int(kBlockN * params.k_row_stride));
-           flash::copy</*Is_even_MN=*/true, Is_even_K>(gmem_thr_copy_QKV, tKgK, tKsK, tKVcKV, tKVpKV);
+           flash::copy</*Is_even_MN=*/true, Is_even_K>(gmem_tiled_copy_QKV, tKgK, tKsK, tKVcKV, tKVpKV);
            // This cp_async_fence needs to be in the if block, otherwise the synchronization
            // isn't right and we get race conditions.
            cute::cp_async_fence();

@@ -464,12 +471,12 @@ inline __device__ void compute_attn_1rowblock(const Params &params, const int bi
        uint32_t block_col_idx = n_block * (kBlockN / 32);
        if (Return_softmax) {
            Tensor tOrP_copy = make_fragment_like(tOrP);
-           copy(tOrP, tOrP_copy);
+           cute::copy(tOrP, tOrP_copy);
            flash::apply_dropout</*encode_dropout_in_sign_bit=*/true>(
                tOrP_copy, params.p_dropout_in_uint8_t, seed, offset,
                block_row_idx, block_col_idx, kNWarps
            );
-           flash::write_softmax_to_gmem(tOrP_copy, tPgP, gmem_thr_copy_P);
+           flash::write_softmax_to_gmem(tOrP_copy, tPgP, gmem_tiled_copy_P);
            tPgP.data() = tPgP.data() + (-kBlockN);
        }
        if (Is_dropout) {

@@ -477,7 +484,7 @@ inline __device__ void compute_attn_1rowblock(const Params &params, const int bi
                block_row_idx, block_col_idx, kNWarps
            );
        }
-       flash::gemm_A_in_regs(acc_o, tOrP, tOrVt, tOsVt, tiled_mma, smem_thr_copy_V);
+       flash::gemm_A_in_regs(acc_o, tOrP, tOrVt, tOsVt, tiled_mma, smem_tiled_copy_V, smem_thr_copy_V);
    }
    // Epilogue

@@ -501,7 +508,8 @@ inline __device__ void compute_attn_1rowblock(const Params &params, const int bi
    Tensor rO = flash::convert_type<Element>(acc_o);
    Tensor sO = make_tensor(sQ.data(), typename Kernel_traits::SmemLayoutO{});    // (SMEM_M,SMEM_N)
    // Partition sO to match the accumulator partitioning
-   auto smem_thr_copy_O = make_tiled_copy_C(typename Kernel_traits::SmemCopyAtomO{}, tiled_mma).get_thread_slice(tidx);
+   auto smem_tiled_copy_O = make_tiled_copy_C(typename Kernel_traits::SmemCopyAtomO{}, tiled_mma);
+   auto smem_thr_copy_O = smem_tiled_copy_O.get_thread_slice(tidx);
    // auto smem_thr_copy_O = make_tiled_copy_C_warpcontiguousM<MMA_M>(typename Kernel_traits::SmemCopyAtomO{}, tiled_mma).get_thread_slice(tidx);
    Tensor taccOrO = smem_thr_copy_O.retile_S(rO);        // ((Atom,AtomNum), MMA_M, MMA_N)
    Tensor taccOsO = smem_thr_copy_O.partition_D(sO);     // ((Atom,AtomNum),PIPE_M,PIPE_N)

@@ -509,7 +517,7 @@ inline __device__ void compute_attn_1rowblock(const Params &params, const int bi
    // sO has the same size as sQ, so we don't need to sync here.
    if (Kernel_traits::Share_Q_K_smem) { __syncthreads(); }
-   copy(smem_thr_copy_O, taccOrO, taccOsO);
+   cute::copy(smem_tiled_copy_O, taccOrO, taccOsO);
    const index_t row_offset_o = binfo.q_offset(params.o_batch_stride, params.o_row_stride, bidb)
        + m_block * kBlockM * params.o_row_stride + bidh * params.o_head_stride;

@@ -520,14 +528,15 @@ inline __device__ void compute_attn_1rowblock(const Params &params, const int bi
    Tensor gLSE = make_tensor(make_gmem_ptr(reinterpret_cast<ElementAccum *>(params.softmax_lse_ptr) + row_offset_lse),
                              Shape<Int<kBlockM>>{}, Stride<_1>{});
-   auto gmem_thr_copy_O = typename Kernel_traits::GmemTiledCopyO{}.get_thread_slice(tidx);
+   typename Kernel_traits::GmemTiledCopyO gmem_tiled_copy_O;
+   auto gmem_thr_copy_O = gmem_tiled_copy_O.get_thread_slice(tidx);
    Tensor tOsO = gmem_thr_copy_O.partition_S(sO);        // ((Atom,AtomNum),ATOM_M,ATOM_N)
    Tensor tOgO = gmem_thr_copy_O.partition_D(gO);
    __syncthreads();
    Tensor tOrO = make_tensor<Element>(shape(tOgO));
-   copy(gmem_thr_copy_O, tOsO, tOrO);
+   cute::copy(gmem_tiled_copy_O, tOsO, tOrO);
    Tensor caccO = make_identity_tensor(Shape<Int<kBlockM>, Int<kHeadDim>>{});    // (BLK_M,BLK_K) -> (blk_m,blk_k)
    Tensor taccOcO = thr_mma.partition_C(caccO);                                  // (MMA,MMA_M,MMA_K)

@@ -554,7 +563,7 @@ inline __device__ void compute_attn_1rowblock(const Params &params, const int bi
    }
    // Clear_OOB_K must be false since we don't want to write zeros to gmem
    flash::copy</*Is_even_MN=*/false, Is_even_K, /*Clear_OOB_MN=*/false, /*Clear_OOB_K=*/false>(
-       gmem_thr_copy_O, tOrO, tOgO, tOcO, tOpO, binfo.actual_seqlen_q - m_block * kBlockM
+       gmem_tiled_copy_O, tOrO, tOgO, tOcO, tOpO, binfo.actual_seqlen_q - m_block * kBlockM
    );
}
csrc/flash_attn/src/utils.h
@@ -173,10 +173,12 @@ static __device__ inline T run(T x, Operator &op) {
template<bool A_in_regs=false, bool B_in_regs=false, typename Tensor0, typename Tensor1,
         typename Tensor2, typename Tensor3, typename Tensor4,
-        typename TiledMma, typename TiledCopy0, typename TiledCopy1>
+        typename TiledMma, typename TiledCopyA, typename TiledCopyB,
+        typename ThrCopyA, typename ThrCopyB>
inline __device__ void gemm(Tensor0 &acc, Tensor1 &tCrA, Tensor2 &tCrB, Tensor3 const& tCsA,
                            Tensor4 const& tCsB, TiledMma tiled_mma,
-                           TiledCopy0 smem_thr_copy_A, TiledCopy1 smem_thr_copy_B) {
+                           TiledCopyA smem_tiled_copy_A, TiledCopyB smem_tiled_copy_B,
+                           ThrCopyA smem_thr_copy_A, ThrCopyB smem_thr_copy_B) {
    CUTE_STATIC_ASSERT_V(size<1>(tCrA) == size<1>(acc));                     // MMA_M
    CUTE_STATIC_ASSERT_V(size<1>(tCrB) == size<2>(acc));                     // MMA_N
    CUTE_STATIC_ASSERT_V(size<2>(tCrA) == size<2>(tCrB));                    // MMA_K

@@ -184,13 +186,13 @@ inline __device__ void gemm(Tensor0 &acc, Tensor1 &tCrA, Tensor2 &tCrB, Tensor3
    CUTE_STATIC_ASSERT_V(size<1>(tCsA) == size<1>(tCrA_copy_view));            // M
    Tensor tCrB_copy_view = smem_thr_copy_B.retile_D(tCrB);
    CUTE_STATIC_ASSERT_V(size<1>(tCsB) == size<1>(tCrB_copy_view));            // N
-   if (!A_in_regs) { copy(smem_thr_copy_A, tCsA(_, _, _0{}), tCrA_copy_view(_, _, _0{})); }
-   if (!B_in_regs) { copy(smem_thr_copy_B, tCsB(_, _, _0{}), tCrB_copy_view(_, _, _0{})); }
+   if (!A_in_regs) { cute::copy(smem_tiled_copy_A, tCsA(_, _, _0{}), tCrA_copy_view(_, _, _0{})); }
+   if (!B_in_regs) { cute::copy(smem_tiled_copy_B, tCsB(_, _, _0{}), tCrB_copy_view(_, _, _0{})); }
    #pragma unroll
    for (int i = 0; i < size<2>(tCrA); ++i) {
        if (i < size<2>(tCrA) - 1) {
-           if (!A_in_regs) { copy(smem_thr_copy_A, tCsA(_, _, i + 1), tCrA_copy_view(_, _, i + 1)); }
-           if (!B_in_regs) { copy(smem_thr_copy_B, tCsB(_, _, i + 1), tCrB_copy_view(_, _, i + 1)); }
+           if (!A_in_regs) { cute::copy(smem_tiled_copy_A, tCsA(_, _, i + 1), tCrA_copy_view(_, _, i + 1)); }
+           if (!B_in_regs) { cute::copy(smem_tiled_copy_B, tCsB(_, _, i + 1), tCrB_copy_view(_, _, i + 1)); }
        }
        cute::gemm(tiled_mma, tCrA(_, _, i), tCrB(_, _, i), acc);
    }
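Assembled from the hunk above, the flash::gemm helper declaration now reads (after this commit) roughly as follows; it takes the A- and B-side TiledCopy objects alongside the corresponding ThrCopy slices:

    template<bool A_in_regs=false, bool B_in_regs=false, typename Tensor0, typename Tensor1,
             typename Tensor2, typename Tensor3, typename Tensor4,
             typename TiledMma, typename TiledCopyA, typename TiledCopyB,
             typename ThrCopyA, typename ThrCopyB>
    inline __device__ void gemm(Tensor0 &acc, Tensor1 &tCrA, Tensor2 &tCrB, Tensor3 const& tCsA,
                                Tensor4 const& tCsB, TiledMma tiled_mma,
                                TiledCopyA smem_tiled_copy_A, TiledCopyB smem_tiled_copy_B,
                                ThrCopyA smem_thr_copy_A, ThrCopyB smem_thr_copy_B);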
@@ -199,19 +201,20 @@ inline __device__ void gemm(Tensor0 &acc, Tensor1 &tCrA, Tensor2 &tCrB, Tensor3
////////////////////////////////////////////////////////////////////////////////////////////////////

template<typename Tensor0, typename Tensor1, typename Tensor2, typename Tensor3,
-        typename TiledMma, typename TiledCopy>
+        typename TiledMma, typename TiledCopy, typename ThrCopy>
inline __device__ void gemm_A_in_regs(Tensor0 &acc, Tensor1 &tCrA, Tensor2 &tCrB, Tensor3 const& tCsB,
-                                      TiledMma tiled_mma, TiledCopy smem_thr_copy_B) {
+                                      TiledMma tiled_mma, TiledCopy smem_tiled_copy_B,
+                                      ThrCopy smem_thr_copy_B) {
    CUTE_STATIC_ASSERT_V(size<1>(tCrA) == size<1>(acc));                     // MMA_M
    CUTE_STATIC_ASSERT_V(size<1>(tCrB) == size<2>(acc));                     // MMA_N
    CUTE_STATIC_ASSERT_V(size<2>(tCrA) == size<2>(tCrB));                    // MMA_K
    Tensor tCrB_copy_view = smem_thr_copy_B.retile_D(tCrB);
    CUTE_STATIC_ASSERT_V(size<1>(tCsB) == size<1>(tCrB_copy_view));            // N
-   copy(smem_thr_copy_B, tCsB(_, _, _0{}), tCrB_copy_view(_, _, _0{}));
+   cute::copy(smem_tiled_copy_B, tCsB(_, _, _0{}), tCrB_copy_view(_, _, _0{}));
    #pragma unroll
    for (int i = 0; i < size<2>(tCrA); ++i) {
        if (i < size<2>(tCrA) - 1) {
-           copy(smem_thr_copy_B, tCsB(_, _, i + 1), tCrB_copy_view(_, _, i + 1));
+           cute::copy(smem_tiled_copy_B, tCsB(_, _, i + 1), tCrB_copy_view(_, _, i + 1));
        }
        cute::gemm(tiled_mma, tCrA(_, _, i), tCrB(_, _, i), acc);
    }

@@ -319,7 +322,7 @@ void cp_async_wait() {
template <bool Is_even_MN=true, bool Is_even_K=true, bool Clear_OOB_MN=false, bool Clear_OOB_K=true,
          typename TiledCopy, typename Engine0, typename Layout0, typename Engine1, typename Layout1,
          typename Engine2, typename Layout2, typename Engine3, typename Layout3>
-inline __device__ void copy(TiledCopy thr_copy, Tensor<Engine0, Layout0> const &S,
+inline __device__ void copy(TiledCopy tiled_copy, Tensor<Engine0, Layout0> const &S,
                            Tensor<Engine1, Layout1> &D, Tensor<Engine2, Layout2> const &identity_MN,
                            Tensor<Engine3, Layout3> const &predicate_K, int max_MN=0) {
    CUTE_STATIC_ASSERT_V(rank(S) == Int<3>{});
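The first parameter of flash::copy is now the TiledCopy itself rather than a thread slice. For reference, this is how the forward-kernel call sites above pass it after this commit (taken from the K-load hunk in flash_fwd_kernel.h):

    flash::copy<Is_even_N, Is_even_K>(gmem_tiled_copy_QKV, tKgK, tKsK, tKVcKV, tKVpKV,
                                      binfo.actual_seqlen_k - n_block * kBlockN);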
@@ -335,13 +338,13 @@ inline __device__ void copy(TiledCopy thr_copy, Tensor<Engine0, Layout0> const &
            #pragma unroll
            for (int k = 0; k < size<2>(S); ++k) {
                if (Is_even_K || predicate_K(k)) {
-                   copy(thr_copy, S(_, m, k), D(_, m, k));
+                   cute::copy(tiled_copy, S(_, m, k), D(_, m, k));
                } else if (Clear_OOB_K) {
-                   clear(D(_, m, k));
+                   cute::clear(D(_, m, k));
                }
            }
        } else if (Clear_OOB_MN) {
-           clear(D(_, m, _));
+           cute::clear(D(_, m, _));
        }
    }
    // TD [2023-04-13]: Strange that the code below can cause race condition.

@@ -350,7 +353,7 @@ inline __device__ void copy(TiledCopy thr_copy, Tensor<Engine0, Layout0> const &
    // #pragma unroll
    // for (int m = 0; m < size<1>(S); ++m) {
    //     if (Is_even_MN || get<0>(identity_MN(0, m, 0)) < max_MN) {
-   //         copy(thr_copy, S(_, m, _), D(_, m, _));
+   //         copy(tiled_copy, S(_, m, _), D(_, m, _));
    //     } else if (Clear_OOB_MN) {
    //         clear(D(_, m, _));
    //     }

@@ -362,7 +365,7 @@ inline __device__ void copy(TiledCopy thr_copy, Tensor<Engine0, Layout0> const &
    // #pragma unroll
    // for (int m = 0; m < size<1>(S); ++m) {
    //     if (Is_even_MN || get<0>(identity_MN(0, m, 0)) < max_MN) {
-   //         copy(thr_copy, S(_, m, k), D(_, m, k));
+   //         copy(tiled_copy, S(_, m, k), D(_, m, k));
    //     } else if (Clear_OOB_MN) {
    //         clear(D(_, m, k));
    //     }
tests/test_flash_attn.py
@@ -783,13 +783,13 @@ def test_flash_attn_varlen_output(seqlen_q, seqlen_k, d, dropout_p, causal, mha_
    assert (dv - dv_ref).abs().max().item() <= 2 * (dv_pt - dv_ref).abs().max().item()


-# @pytest.mark.parametrize('dtype', ([torch.float16] if is_sm75 else [torch.float16, torch.bfloat16]))
-@pytest.mark.parametrize('dtype', [torch.float16])
-# @pytest.mark.parametrize('causal', [False, True])
-@pytest.mark.parametrize('causal', [False])
+@pytest.mark.parametrize('dtype', ([torch.float16] if is_sm75 else [torch.float16, torch.bfloat16]))
+# @pytest.mark.parametrize('dtype', [torch.float16])
+@pytest.mark.parametrize('causal', [False, True])
+# @pytest.mark.parametrize('causal', [False])
# @pytest.mark.parametrize('d', [32, 56, 64, 80, 96, 128])
-# @pytest.mark.parametrize('d', [32, 64, 96, 128, 160, 192])
-@pytest.mark.parametrize('d', [128])
+@pytest.mark.parametrize('d', [32, 64, 96, 128, 160, 192])
+# @pytest.mark.parametrize('d', [128])
# @pytest.mark.parametrize('seqlen', [97, 128, 200, 256, 257, 384, 512, 768, 1024, 1025, 2048])
# @pytest.mark.parametrize('seqlen', [128, 256, 384, 512, 768, 1024, 2048])
@pytest.mark.parametrize('seqlen', [128])
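The test-side change only widens the parametrize grids back from the single-case debugging values (float16 only, causal=False, d=128) to the fuller sweeps over dtype, causal, and head dimension. Running `pytest tests/test_flash_attn.py` against a locally built flash-attn extension exercises the updated grids.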