Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
flash-attention
Commits
8d1b169e
Commit
8d1b169e
authored
Jan 12, 2024
by
Tri Dao
Browse files
Simplify SmemLayoutVtransposed in kernel_traits.h
parent
c9861a03
Changes
4
Show whitespace changes
Inline
Side-by-side
Showing
4 changed files
with
41 additions
and
83 deletions
+41
-83
csrc/flash_attn/src/flash_bwd_kernel.h
csrc/flash_attn/src/flash_bwd_kernel.h
+2
-2
csrc/flash_attn/src/flash_fwd_kernel.h
csrc/flash_attn/src/flash_fwd_kernel.h
+18
-18
csrc/flash_attn/src/kernel_traits.h
csrc/flash_attn/src/kernel_traits.h
+14
-49
csrc/flash_attn/src/utils.h
csrc/flash_attn/src/utils.h
+7
-14
No files found.
csrc/flash_attn/src/flash_bwd_kernel.h
View file @
8d1b169e
...
@@ -975,9 +975,9 @@ inline __device__ void compute_dq_dk_dv_1colblock(const Params ¶ms, const in
...
@@ -975,9 +975,9 @@ inline __device__ void compute_dq_dk_dv_1colblock(const Params ¶ms, const in
// Layout p_l = tPrP.layout();
// Layout p_l = tPrP.layout();
// Tensor tdVrPt = make_tensor(tPrP.data(), make_layout(get<0>(p_l), get<2>(p_l), get<1>(p_l)));
// Tensor tdVrPt = make_tensor(tPrP.data(), make_layout(get<0>(p_l), get<2>(p_l), get<1>(p_l)));
// flash::gemm_
A_in_reg
s(acc_dv, tdVrPt, tdVrdO, tdVsdOt, tiled_mma_dkv, smem_thr_copy_QdOt);
// flash::gemm_
r
s(acc_dv, tdVrPt, tdVrdO, tdVsdOt, tiled_mma_dkv, smem_thr_copy_QdOt);
// Tensor tdKrdSt = make_tensor(tdSrdS.data(), tdVrPt.layout());
// Tensor tdKrdSt = make_tensor(tdSrdS.data(), tdVrPt.layout());
// flash::gemm_
A_in_reg
s(acc_dk, tdKrdSt, tdKrQt, tdKsQt, tiled_mma_dkv, smem_thr_copy_QdOt);
// flash::gemm_
r
s(acc_dk, tdKrdSt, tdKrQt, tdKsQt, tiled_mma_dkv, smem_thr_copy_QdOt);
flash
::
gemm
(
acc_dv
,
tdVrPt
,
tdVrdO
,
tdVsPt
,
tdVsdOt
,
tiled_mma_dkv
,
flash
::
gemm
(
acc_dv
,
tdVrPt
,
tdVrdO
,
tdVsPt
,
tdVsdOt
,
tiled_mma_dkv
,
smem_tiled_copy_PdSt
,
smem_tiled_copy_QdOt
,
smem_thr_copy_PdSt
,
smem_thr_copy_QdOt
);
smem_tiled_copy_PdSt
,
smem_tiled_copy_QdOt
,
smem_thr_copy_PdSt
,
smem_thr_copy_QdOt
);
// if (cute::thread0() && n_block == 0 && m_block == 0) { print(tdVrPt); }
// if (cute::thread0() && n_block == 0 && m_block == 0) { print(tdVrPt); }
...
...
csrc/flash_attn/src/flash_fwd_kernel.h
View file @
8d1b169e
...
@@ -444,7 +444,7 @@ inline __device__ void compute_attn_1rowblock(const Params ¶ms, const int bi
...
@@ -444,7 +444,7 @@ inline __device__ void compute_attn_1rowblock(const Params ¶ms, const int bi
}
}
// if (cute::thread0()) { print(tOrP); }
// if (cute::thread0()) { print(tOrP); }
flash
::
gemm_
A_in_reg
s
(
acc_o
,
tOrP
,
tOrVt
,
tOsVt
,
tiled_mma
,
smem_tiled_copy_V
,
smem_thr_copy_V
);
flash
::
gemm_
r
s
(
acc_o
,
tOrP
,
tOrVt
,
tOsVt
,
tiled_mma
,
smem_tiled_copy_V
,
smem_thr_copy_V
);
// if (cute::thread0()) { print(scores); }
// if (cute::thread0()) { print(scores); }
// This check is at the end of the loop since we always have at least 1 iteration
// This check is at the end of the loop since we always have at least 1 iteration
...
@@ -528,7 +528,7 @@ inline __device__ void compute_attn_1rowblock(const Params ¶ms, const int bi
...
@@ -528,7 +528,7 @@ inline __device__ void compute_attn_1rowblock(const Params ¶ms, const int bi
block_row_idx
,
block_col_idx
,
kNWarps
);
block_row_idx
,
block_col_idx
,
kNWarps
);
}
}
flash
::
gemm_
A_in_reg
s
(
acc_o
,
tOrP
,
tOrVt
,
tOsVt
,
tiled_mma
,
smem_tiled_copy_V
,
smem_thr_copy_V
);
flash
::
gemm_
r
s
(
acc_o
,
tOrP
,
tOrVt
,
tOsVt
,
tiled_mma
,
smem_tiled_copy_V
,
smem_thr_copy_V
);
}
}
// Epilogue
// Epilogue
...
@@ -1027,7 +1027,7 @@ inline __device__ void compute_attn_1rowblock_splitkv(const Params ¶ms, cons
...
@@ -1027,7 +1027,7 @@ inline __device__ void compute_attn_1rowblock_splitkv(const Params ¶ms, cons
// if using m16n8k16 or ((2, 2, 1), MMA_M, MMA_N) if using m16n8k8.
// if using m16n8k16 or ((2, 2, 1), MMA_M, MMA_N) if using m16n8k8.
Tensor
tOrP
=
make_tensor
(
rP
.
data
(),
flash
::
convert_layout_rowcol_Aregs
<
Kernel_traits
::
TiledMma
>
(
rP
.
layout
()));
Tensor
tOrP
=
make_tensor
(
rP
.
data
(),
flash
::
convert_layout_rowcol_Aregs
<
Kernel_traits
::
TiledMma
>
(
rP
.
layout
()));
flash
::
gemm_
A_in_reg
s
(
acc_o
,
tOrP
,
tOrVt
,
tOsVt
,
tiled_mma
,
smem_tiled_copy_V
,
smem_thr_copy_V
);
flash
::
gemm_
r
s
(
acc_o
,
tOrP
,
tOrVt
,
tOsVt
,
tiled_mma
,
smem_tiled_copy_V
,
smem_thr_copy_V
);
// if (cute::thread0()) { print(scores); }
// if (cute::thread0()) { print(scores); }
// This check is at the end of the loop since we always have at least 1 iteration
// This check is at the end of the loop since we always have at least 1 iteration
...
@@ -1094,7 +1094,7 @@ inline __device__ void compute_attn_1rowblock_splitkv(const Params ¶ms, cons
...
@@ -1094,7 +1094,7 @@ inline __device__ void compute_attn_1rowblock_splitkv(const Params ¶ms, cons
// if using m16n8k16 or ((2, 2, 1), MMA_M, MMA_N) if using m16n8k8.
// if using m16n8k16 or ((2, 2, 1), MMA_M, MMA_N) if using m16n8k8.
Tensor
tOrP
=
make_tensor
(
rP
.
data
(),
flash
::
convert_layout_rowcol_Aregs
<
Kernel_traits
::
TiledMma
>
(
rP
.
layout
()));
Tensor
tOrP
=
make_tensor
(
rP
.
data
(),
flash
::
convert_layout_rowcol_Aregs
<
Kernel_traits
::
TiledMma
>
(
rP
.
layout
()));
flash
::
gemm_
A_in_reg
s
(
acc_o
,
tOrP
,
tOrVt
,
tOsVt
,
tiled_mma
,
smem_tiled_copy_V
,
smem_thr_copy_V
);
flash
::
gemm_
r
s
(
acc_o
,
tOrP
,
tOrVt
,
tOsVt
,
tiled_mma
,
smem_tiled_copy_V
,
smem_thr_copy_V
);
}
}
// Epilogue
// Epilogue
...
...
csrc/flash_attn/src/kernel_traits.h
View file @
8d1b169e
...
@@ -91,20 +91,10 @@ struct Flash_fwd_kernel_traits : public Base {
...
@@ -91,20 +91,10 @@ struct Flash_fwd_kernel_traits : public Base {
SmemLayoutAtomQ
{},
SmemLayoutAtomQ
{},
Shape
<
Int
<
kBlockN
>
,
Int
<
kHeadDim
>>
{}));
Shape
<
Int
<
kBlockN
>
,
Int
<
kHeadDim
>>
{}));
// This has to be kBlockN and not 8, otherwise we get wrong results for d=128
// https://github.com/ColfaxResearch/cutlass-kernels/blob/a222587e6d59b93ba704853d3946fb686d8b8892/src/fmha/fmha_forward.cu#L434
using
SmemLayoutAtomVtransposedNoSwizzle
=
Layout
<
Shape
<
Int
<
kBlockKSmem
>
,
Int
<
kBlockN
>>
,
using
SmemLayoutVtransposed
=
decltype
(
Stride
<
_1
,
Int
<
kBlockKSmem
>>>
;
composition
(
SmemLayoutKV
{},
make_layout
(
Shape
<
Int
<
kHeadDim
>
,
Int
<
kBlockN
>>
{},
GenRowMajor
{})));
using
SmemLayoutAtomVtransposed
=
decltype
(
using
SmemLayoutVtransposedNoSwizzle
=
decltype
(
get_nonswizzle_portion
(
SmemLayoutVtransposed
{}));
composition
(
Swizzle
<
kSwizzle
,
3
,
3
>
{},
SmemLayoutAtomVtransposedNoSwizzle
{}));
using
SmemLayoutVtransposed
=
decltype
(
tile_to_shape
(
SmemLayoutAtomVtransposed
{},
Shape
<
Int
<
kHeadDim
>
,
Int
<
kBlockN
>>
{}));
// Maybe the VtransposeNoSwizzle just needs to have the right shape
// And the strides don't matter?
using
SmemLayoutVtransposedNoSwizzle
=
decltype
(
tile_to_shape
(
SmemLayoutAtomVtransposedNoSwizzle
{},
Shape
<
Int
<
kHeadDim
>
,
Int
<
kBlockN
>>
{}));
// using SmemLayoutVtransposedNoSwizzle = decltype(SmemLayoutVtransposed{}.layout_fn());
using
SmemLayoutAtomO
=
decltype
(
using
SmemLayoutAtomO
=
decltype
(
composition
(
Swizzle
<
kSwizzle
,
3
,
3
>
{},
composition
(
Swizzle
<
kSwizzle
,
3
,
3
>
{},
...
@@ -247,19 +237,9 @@ struct Flash_bwd_kernel_traits : public Base {
...
@@ -247,19 +237,9 @@ struct Flash_bwd_kernel_traits : public Base {
SmemLayoutAtomKV
{},
SmemLayoutAtomKV
{},
make_shape
(
Int
<
kBlockN
>
{},
Int
<
kHeadDim
>
{})));
make_shape
(
Int
<
kBlockN
>
{},
Int
<
kHeadDim
>
{})));
using
SmemLayoutAtomKtransposedNoSwizzle
=
Layout
<
Shape
<
Int
<
kBlockKSmem
>
,
Int
<
kBlockN
>>
,
using
SmemLayoutKtransposed
=
decltype
(
Stride
<
_1
,
Int
<
kBlockKSmem
>>>
;
composition
(
SmemLayoutKV
{},
make_layout
(
Shape
<
Int
<
kHeadDim
>
,
Int
<
kBlockN
>>
{},
GenRowMajor
{})));
using
SmemLayoutAtomKtransposed
=
decltype
(
using
SmemLayoutKtransposedNoSwizzle
=
decltype
(
get_nonswizzle_portion
(
SmemLayoutKtransposed
{}));
composition
(
Swizzle
<
kSwizzle
,
3
,
3
>
{},
SmemLayoutAtomKtransposedNoSwizzle
{}));
using
SmemLayoutKtransposed
=
decltype
(
tile_to_shape
(
SmemLayoutAtomKtransposed
{},
make_shape
(
Int
<
kHeadDim
>
{},
Int
<
kBlockN
>
{})));
// Maybe the KtransposeNoSwizzle just needs to have the right shape
// And the strides don't matter?
using
SmemLayoutKtransposedNoSwizzle
=
decltype
(
tile_to_shape
(
SmemLayoutAtomKtransposedNoSwizzle
{},
make_shape
(
Int
<
kHeadDim
>
{},
Int
<
kBlockN
>
{})));
// using SmemLayoutKtransposedNoSwizzle = decltype(SmemLayoutKtransposed{}.layout_fn());
// TODO: generalize to other values of kBlockN
// TODO: generalize to other values of kBlockN
// TODO: what should be the Swizzle here? 3 is faster than 1, and 1 is faster than 2
// TODO: what should be the Swizzle here? 3 is faster than 1, and 1 is faster than 2
...
@@ -277,30 +257,15 @@ struct Flash_bwd_kernel_traits : public Base {
...
@@ -277,30 +257,15 @@ struct Flash_bwd_kernel_traits : public Base {
using
SmemLayoutPdS
=
decltype
(
tile_to_shape
(
using
SmemLayoutPdS
=
decltype
(
tile_to_shape
(
SmemLayoutAtomPdS
{},
SmemLayoutAtomPdS
{},
make_shape
(
Int
<
kBlockM
>
{},
Int
<
kBlockN
>
{})));
make_shape
(
Int
<
kBlockM
>
{},
Int
<
kBlockN
>
{})));
using
SmemLayoutAtomPdStransposedNoSwizzle
=
Layout
<
Shape
<
Int
<
kPBlockN
>
,
Int
<
kBlockM
>>
,
using
SmemLayoutPdStransposed
=
decltype
(
Stride
<
_1
,
Int
<
kPBlockN
>>>
;
composition
(
SmemLayoutPdS
{},
make_layout
(
Shape
<
Int
<
kBlockN
>
,
Int
<
kBlockM
>>
{},
GenRowMajor
{})));
using
SmemLayoutAtomPdStransposed
=
decltype
(
using
SmemLayoutPdStransposedNoSwizzle
=
decltype
(
get_nonswizzle_portion
(
SmemLayoutPdStransposed
{}));
composition
(
Swizzle
<
kSwizzlePdS
,
3
,
3
>
{},
SmemLayoutAtomPdStransposedNoSwizzle
{}));
using
SmemLayoutPdStransposed
=
decltype
(
tile_to_shape
(
SmemLayoutAtomPdStransposed
{},
make_shape
(
Int
<
kBlockN
>
{},
Int
<
kBlockM
>
{})));
using
SmemLayoutPdStransposedNoSwizzle
=
decltype
(
tile_to_shape
(
SmemLayoutAtomPdStransposedNoSwizzle
{},
make_shape
(
Int
<
kBlockN
>
{},
Int
<
kBlockM
>
{})));
// using SmemLayoutPdStransposedNoSwizzle = decltype(SmemLayoutPdStransposed{}.layout_fn());
using
SmemCopyAtomPdS
=
Copy_Atom
<
DefaultCopy
,
elem_type
>
;
using
SmemCopyAtomPdS
=
Copy_Atom
<
DefaultCopy
,
elem_type
>
;
using
SmemLayoutAtomQdOtransposedNoSwizzle
=
Layout
<
Shape
<
Int
<
kBlockKSmem
>
,
Int
<
kBlockM
>>
,
using
SmemLayoutQdOtransposed
=
decltype
(
Stride
<
_1
,
Int
<
kBlockKSmem
>>>
;
composition
(
SmemLayoutQdO
{},
make_layout
(
Shape
<
Int
<
kHeadDim
>
,
Int
<
kBlockM
>>
{},
GenRowMajor
{})));
using
SmemLayoutAtomQdOtransposed
=
decltype
(
using
SmemLayoutQdOtransposedNoSwizzle
=
decltype
(
get_nonswizzle_portion
(
SmemLayoutQdOtransposed
{}));
composition
(
Swizzle
<
kSwizzle
,
3
,
3
>
{},
SmemLayoutAtomQdOtransposedNoSwizzle
{}));
using
SmemLayoutQdOtransposed
=
decltype
(
tile_to_shape
(
SmemLayoutAtomQdOtransposed
{},
make_shape
(
Int
<
kHeadDim
>
{},
Int
<
kBlockM
>
{})));
using
SmemLayoutQdOtransposedNoSwizzle
=
decltype
(
tile_to_shape
(
SmemLayoutAtomQdOtransposedNoSwizzle
{},
make_shape
(
Int
<
kHeadDim
>
{},
Int
<
kBlockM
>
{})));
// using SmemLayoutQdOtransposedNoSwizzle = decltype(SmemLayoutQdOtransposed{}.layout_fn());
using
SmemLayoutAtomdKV
=
decltype
(
using
SmemLayoutAtomdKV
=
decltype
(
composition
(
Swizzle
<
kSwizzle
,
3
,
3
>
{},
composition
(
Swizzle
<
kSwizzle
,
3
,
3
>
{},
...
...
csrc/flash_attn/src/utils.h
View file @
8d1b169e
...
@@ -162,7 +162,7 @@ inline __device__ void gemm(Tensor0 &acc, Tensor1 &tCrA, Tensor2 &tCrB, Tensor3
...
@@ -162,7 +162,7 @@ inline __device__ void gemm(Tensor0 &acc, Tensor1 &tCrA, Tensor2 &tCrB, Tensor3
template
<
typename
Tensor0
,
typename
Tensor1
,
typename
Tensor2
,
typename
Tensor3
,
template
<
typename
Tensor0
,
typename
Tensor1
,
typename
Tensor2
,
typename
Tensor3
,
typename
TiledMma
,
typename
TiledCopy
,
typename
ThrCopy
>
typename
TiledMma
,
typename
TiledCopy
,
typename
ThrCopy
>
inline
__device__
void
gemm_
A_in_reg
s
(
Tensor0
&
acc
,
Tensor1
&
tCrA
,
Tensor2
&
tCrB
,
Tensor3
const
&
tCsB
,
inline
__device__
void
gemm_
r
s
(
Tensor0
&
acc
,
Tensor1
&
tCrA
,
Tensor2
&
tCrB
,
Tensor3
const
&
tCsB
,
TiledMma
tiled_mma
,
TiledCopy
smem_tiled_copy_B
,
TiledMma
tiled_mma
,
TiledCopy
smem_tiled_copy_B
,
ThrCopy
smem_thr_copy_B
)
{
ThrCopy
smem_thr_copy_B
)
{
CUTE_STATIC_ASSERT_V
(
size
<
1
>
(
tCrA
)
==
size
<
1
>
(
acc
));
// MMA_M
CUTE_STATIC_ASSERT_V
(
size
<
1
>
(
tCrA
)
==
size
<
1
>
(
acc
));
// MMA_M
...
@@ -188,10 +188,7 @@ inline __device__ auto convert_layout_acc_rowcol(Layout acc_layout) {
...
@@ -188,10 +188,7 @@ inline __device__ auto convert_layout_acc_rowcol(Layout acc_layout) {
static_assert
(
decltype
(
size
<
0
>
(
acc_layout
))
::
value
==
4
);
static_assert
(
decltype
(
size
<
0
>
(
acc_layout
))
::
value
==
4
);
static_assert
(
decltype
(
rank
(
acc_layout
))
::
value
==
3
);
static_assert
(
decltype
(
rank
(
acc_layout
))
::
value
==
3
);
auto
l
=
logical_divide
(
acc_layout
,
Shape
<
_2
>
{});
// ((2, 2), MMA_M, MMA_N)
auto
l
=
logical_divide
(
acc_layout
,
Shape
<
_2
>
{});
// ((2, 2), MMA_M, MMA_N)
// TD [2023-08-13]: Idk why but get<0, 1>(l) doesn't work for Cutlass 3.2, I'm getting
return
make_layout
(
make_layout
(
get
<
0
,
1
>
(
l
),
get
<
1
>
(
l
)),
make_layout
(
get
<
0
,
0
>
(
l
),
get
<
2
>
(
l
)));
// "int_tuple.hpp(74): error: conversion to inaccessible base class"
// return make_layout(make_layout(get<0, 1>(l), get<1>(l)), make_layout(get<0, 0>(l), get<2>(l)));
return
make_layout
(
make_layout
(
get
<
1
>
(
get
<
0
>
(
l
)),
get
<
1
>
(
l
)),
make_layout
(
get
<
0
>
(
get
<
0
>
(
l
)),
get
<
2
>
(
l
)));
};
};
////////////////////////////////////////////////////////////////////////////////////////////////////
////////////////////////////////////////////////////////////////////////////////////////////////////
...
@@ -207,13 +204,9 @@ inline __device__ auto convert_layout_rowcol_Aregs(Layout rowcol_layout) {
...
@@ -207,13 +204,9 @@ inline __device__ auto convert_layout_rowcol_Aregs(Layout rowcol_layout) {
static_assert
(
mma_shape_K
==
8
||
mma_shape_K
==
16
);
static_assert
(
mma_shape_K
==
8
||
mma_shape_K
==
16
);
constexpr
int
MMA_N_divisor
=
mma_shape_K
==
8
?
1
:
2
;
constexpr
int
MMA_N_divisor
=
mma_shape_K
==
8
?
1
:
2
;
auto
l
=
logical_divide
(
rowcol_layout
,
Shape
<
X
,
Shape
<
X
,
Int
<
MMA_N_divisor
>>>
{});
// ((2, MMA_M), (2, (2, MMA_N / 2)))
auto
l
=
logical_divide
(
rowcol_layout
,
Shape
<
X
,
Shape
<
X
,
Int
<
MMA_N_divisor
>>>
{});
// ((2, MMA_M), (2, (2, MMA_N / 2)))
// TD [2023-08-13]: Same error as above on Cutlass 3.2
return
make_layout
(
make_layout
(
get
<
1
,
0
>
(
l
),
get
<
0
,
0
>
(
l
),
get
<
1
,
1
,
0
>
(
l
)),
// return make_layout(make_layout(get<1, 0>(l), get<0, 0>(l), get<1, 1, 0>(l)),
get
<
0
,
1
>
(
l
),
// get<0, 1>(l),
get
<
1
,
1
,
1
>
(
l
));
// get<1, 1, 1>(l));
return
make_layout
(
make_layout
(
get
<
0
>
(
get
<
1
>
(
l
)),
get
<
0
>
(
get
<
0
>
(
l
)),
get
<
0
>
(
get
<
1
>
(
get
<
1
>
(
l
)))),
get
<
1
>
(
get
<
0
>
(
l
)),
get
<
1
>
(
get
<
1
>
(
get
<
1
>
(
l
))));
};
};
////////////////////////////////////////////////////////////////////////////////////////////////////
////////////////////////////////////////////////////////////////////////////////////////////////////
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment