Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
flash-attention
Commits
8f4d82cf
Commit
8f4d82cf
authored
Jan 20, 2024
by
Tri Dao
Browse files
Update cutlass to v3.4.0
parent
395e5a0d
Changes
4
Hide whitespace changes
Inline
Side-by-side
Showing
4 changed files
with
14 additions
and
16 deletions
+14
-16
csrc/cutlass
csrc/cutlass
+1
-1
csrc/flash_attn/src/flash_bwd_kernel.h
csrc/flash_attn/src/flash_bwd_kernel.h
+9
-8
csrc/flash_attn/src/flash_fwd_kernel.h
csrc/flash_attn/src/flash_fwd_kernel.h
+0
-1
csrc/flash_attn/src/kernel_traits.h
csrc/flash_attn/src/kernel_traits.h
+4
-6
No files found.
cutlass
@
751eb9a8
Compare
a75b4ac4
...
751eb9a8
Subproject commit a75b4ac483166189a45290783cb0a18af5ff0ea5
Subproject commit 751eb9a8859ac36bfc77551f9e4a957c31a5a8b1
csrc/flash_attn/src/flash_bwd_kernel.h
View file @
8f4d82cf
...
@@ -32,11 +32,12 @@ CUTE_HOST_DEVICE
...
@@ -32,11 +32,12 @@ CUTE_HOST_DEVICE
auto
auto
make_tiled_copy_B_warpcontiguousN
(
Copy_Atom
<
Args
...
>
const
&
copy_atom
,
make_tiled_copy_B_warpcontiguousN
(
Copy_Atom
<
Args
...
>
const
&
copy_atom
,
TiledMMA
const
&
tiled_mma
)
{
TiledMMA
const
&
tiled_mma
)
{
using
TileShape_MNK
=
typename
TiledMMA
::
TiledShape_MNK
;
constexpr
int
TileShape_N
=
decltype
(
tiled_mma
.
template
tile_size_mnk
<
1
>())
::
value
;
constexpr
int
TileShape_K
=
decltype
(
tiled_mma
.
template
tile_size_mnk
<
2
>())
::
value
;
using
AtomShape_MNK
=
typename
TiledMMA
::
AtomShape_MNK
;
using
AtomShape_MNK
=
typename
TiledMMA
::
AtomShape_MNK
;
constexpr
int
AtomShape_N
=
decltype
(
size
<
1
>
(
AtomShape_MNK
{}))
::
value
;
constexpr
int
AtomShape_N
=
decltype
(
size
<
1
>
(
AtomShape_MNK
{}))
::
value
;
// Divide by 2 because right now we always use 2 for the ValLayout
// Divide by 2 because right now we always use 2 for the ValLayout
constexpr
int
kNWarpsN
=
decltype
(
size
<
1
>
(
TileShape_MNK
{}))
::
value
/
AtomShape_N
/
2
;
constexpr
int
kNWarpsN
=
TileShape_N
/
AtomShape_N
/
2
;
constexpr
int
MMAStride_N
=
MMA_N
*
AtomShape_N
*
2
;
constexpr
int
MMAStride_N
=
MMA_N
*
AtomShape_N
*
2
;
// This gives the correct layout, idk why.
// This gives the correct layout, idk why.
// auto t = make_tile(Layout<Shape<Shape<_8, _2>, _2>,
// auto t = make_tile(Layout<Shape<Shape<_8, _2>, _2>,
...
@@ -45,7 +46,7 @@ make_tiled_copy_B_warpcontiguousN(Copy_Atom<Args...> const& copy_atom,
...
@@ -45,7 +46,7 @@ make_tiled_copy_B_warpcontiguousN(Copy_Atom<Args...> const& copy_atom,
// Stride<_1, _64, _8> >{},
// Stride<_1, _64, _8> >{},
auto
t
=
make_tile
(
Layout
<
Shape
<
Int
<
AtomShape_N
>
,
Int
<
kNWarpsN
>
,
_2
>
,
// (8, 2, 2) or (8, 4, 2)
auto
t
=
make_tile
(
Layout
<
Shape
<
Int
<
AtomShape_N
>
,
Int
<
kNWarpsN
>
,
_2
>
,
// (8, 2, 2) or (8, 4, 2)
Stride
<
_1
,
Int
<
MMAStride_N
>
,
_8
>
>
{},
// (1, 64, 8) or (1, 32, 8)
Stride
<
_1
,
Int
<
MMAStride_N
>
,
_8
>
>
{},
// (1, 64, 8) or (1, 32, 8)
make_layout
(
size
<
2
>
(
TileShape_MNK
{}))
)
;
make_layout
(
Int
<
TileShape_K
>
{}));
// if (cute::thread0()) {printf("make_tiled_copy_B_warpcontiguousN "); print(t); printf("\n"); }
// if (cute::thread0()) {printf("make_tiled_copy_B_warpcontiguousN "); print(t); printf("\n"); }
return
make_tiled_copy_impl
(
copy_atom
,
tiled_mma
.
get_layoutB_TV
(),
t
);
return
make_tiled_copy_impl
(
copy_atom
,
tiled_mma
.
get_layoutB_TV
(),
t
);
}
}
...
@@ -59,13 +60,14 @@ CUTE_HOST_DEVICE
...
@@ -59,13 +60,14 @@ CUTE_HOST_DEVICE
auto
auto
make_tiled_copy_C_warpcontiguousN
(
Copy_Atom
<
Args
...
>
const
&
copy_atom
,
make_tiled_copy_C_warpcontiguousN
(
Copy_Atom
<
Args
...
>
const
&
copy_atom
,
TiledMMA
const
&
tiled_mma
)
{
TiledMMA
const
&
tiled_mma
)
{
using
TileShape_MNK
=
typename
TiledMMA
::
TiledShape_MNK
;
constexpr
int
TileShape_M
=
decltype
(
tiled_mma
.
template
tile_size_mnk
<
0
>())
::
value
;
constexpr
int
TileShape_N
=
decltype
(
tiled_mma
.
template
tile_size_mnk
<
1
>())
::
value
;
using
AtomShape_MNK
=
typename
TiledMMA
::
AtomShape_MNK
;
using
AtomShape_MNK
=
typename
TiledMMA
::
AtomShape_MNK
;
constexpr
int
AtomShape_N
=
decltype
(
size
<
1
>
(
AtomShape_MNK
{}))
::
value
;
constexpr
int
AtomShape_N
=
decltype
(
size
<
1
>
(
AtomShape_MNK
{}))
::
value
;
// Divide by 2 because right now we always use 2 for the ValLayout
// Divide by 2 because right now we always use 2 for the ValLayout
constexpr
int
kNWarpsN
=
decltype
(
size
<
1
>
(
TileShape_MNK
{}))
::
value
/
AtomShape_N
/
2
;
constexpr
int
kNWarpsN
=
TileShape_N
/
AtomShape_N
/
2
;
constexpr
int
MMAStride_N
=
MMA_N
*
AtomShape_N
*
2
;
constexpr
int
MMAStride_N
=
MMA_N
*
AtomShape_N
*
2
;
auto
t
=
make_tile
(
make_layout
(
size
<
0
>
(
TileShape_MNK
{})
)
,
auto
t
=
make_tile
(
make_layout
(
Int
<
TileShape_M
>
{}),
Layout
<
Shape
<
Int
<
AtomShape_N
>
,
Int
<
kNWarpsN
>
,
_2
>
,
// (8, 2, 2) or (8, 4, 2)
Layout
<
Shape
<
Int
<
AtomShape_N
>
,
Int
<
kNWarpsN
>
,
_2
>
,
// (8, 2, 2) or (8, 4, 2)
Stride
<
_1
,
Int
<
MMAStride_N
>
,
_8
>
>
{});
// (1, 64, 8) or (1, 32, 8)
Stride
<
_1
,
Int
<
MMAStride_N
>
,
_8
>
>
{});
// (1, 64, 8) or (1, 32, 8)
// if (cute::thread0()) {printf("make_tiled_copy_C_warpcontiguousN "); print(t); printf("\n"); }
// if (cute::thread0()) {printf("make_tiled_copy_C_warpcontiguousN "); print(t); printf("\n"); }
...
@@ -90,8 +92,7 @@ inline __device__ void compute_dq_dk_dv_1colblock(const Params &params, const in
...
@@ -90,8 +92,7 @@ inline __device__ void compute_dq_dk_dv_1colblock(const Params &params, const in
constexpr
int
kBlockM
=
Kernel_traits
::
kBlockM
;
constexpr
int
kBlockM
=
Kernel_traits
::
kBlockM
;
constexpr
int
kBlockN
=
Kernel_traits
::
kBlockN
;
constexpr
int
kBlockN
=
Kernel_traits
::
kBlockN
;
constexpr
int
kHeadDim
=
Kernel_traits
::
kHeadDim
;
constexpr
int
kHeadDim
=
Kernel_traits
::
kHeadDim
;
// constexpr int kNWarps = Kernel_traits::kNWarps;
constexpr
int
MMA_N_SdP
=
kBlockN
/
decltype
(
typename
Kernel_traits
::
TiledMmaSdP
{}.
template
tile_size_mnk
<
1
>())
::
value
;
constexpr
int
MMA_N_SdP
=
kBlockN
/
decltype
(
size
<
1
>
(
typename
Kernel_traits
::
TiledMmaSdP
::
TiledShape_MNK
{}))
::
value
;
constexpr
int
AtomLayoutMS
=
Kernel_traits
::
AtomLayoutMSdP
;
constexpr
int
AtomLayoutMS
=
Kernel_traits
::
AtomLayoutMSdP
;
constexpr
bool
Double_buffer
=
!
Kernel_traits
::
No_double_buffer
;
constexpr
bool
Double_buffer
=
!
Kernel_traits
::
No_double_buffer
;
...
...
csrc/flash_attn/src/flash_fwd_kernel.h
View file @
8f4d82cf
...
@@ -41,7 +41,6 @@ inline __device__ void compute_attn_1rowblock(const Params &params, const int bi
...
@@ -41,7 +41,6 @@ inline __device__ void compute_attn_1rowblock(const Params &params, const int bi
constexpr
int
kBlockN
=
Kernel_traits
::
kBlockN
;
constexpr
int
kBlockN
=
Kernel_traits
::
kBlockN
;
constexpr
int
kHeadDim
=
Kernel_traits
::
kHeadDim
;
constexpr
int
kHeadDim
=
Kernel_traits
::
kHeadDim
;
constexpr
int
kNWarps
=
Kernel_traits
::
kNWarps
;
constexpr
int
kNWarps
=
Kernel_traits
::
kNWarps
;
constexpr
int
MMA_M
=
kBlockM
/
decltype
(
size
<
0
>
(
typename
Kernel_traits
::
TiledMma
::
TiledShape_MNK
{}))
::
value
;
auto
seed_offset
=
at
::
cuda
::
philox
::
unpack
(
params
.
philox_args
);
auto
seed_offset
=
at
::
cuda
::
philox
::
unpack
(
params
.
philox_args
);
flash
::
Dropout
dropout
(
std
::
get
<
0
>
(
seed_offset
),
std
::
get
<
1
>
(
seed_offset
),
params
.
p_dropout_in_uint8_t
,
flash
::
Dropout
dropout
(
std
::
get
<
0
>
(
seed_offset
),
std
::
get
<
1
>
(
seed_offset
),
params
.
p_dropout_in_uint8_t
,
...
...
csrc/flash_attn/src/kernel_traits.h
View file @
8f4d82cf
...
@@ -32,10 +32,8 @@ struct Flash_kernel_traits {
...
@@ -32,10 +32,8 @@ struct Flash_kernel_traits {
MMA_Atom
<
SM80_16x8x16_F32F16F16F32_TN
>
,
MMA_Atom
<
SM80_16x8x16_F32F16F16F32_TN
>
,
MMA_Atom
<
SM80_16x8x16_F32BF16BF16F32_TN
>
MMA_Atom
<
SM80_16x8x16_F32BF16BF16F32_TN
>
>
;
>
;
using
ValLayoutMNK
=
Layout
<
Shape
<
_1
,
_2
,
_1
>>
;
#else
#else
using
MMA_Atom_Arch
=
MMA_Atom
<
SM75_16x8x8_F32F16F16F32_TN
>
;
using
MMA_Atom_Arch
=
MMA_Atom
<
SM75_16x8x8_F32F16F16F32_TN
>
;
using
ValLayoutMNK
=
Layout
<
Shape
<
_1
,
_2
,
_2
>>
;
#endif
#endif
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 750
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 750
...
@@ -76,7 +74,7 @@ struct Flash_fwd_kernel_traits : public Base {
...
@@ -76,7 +74,7 @@ struct Flash_fwd_kernel_traits : public Base {
using
TiledMma
=
TiledMMA
<
using
TiledMma
=
TiledMMA
<
typename
Base
::
MMA_Atom_Arch
,
typename
Base
::
MMA_Atom_Arch
,
Layout
<
Shape
<
Int
<
kNWarps
>
,
_1
,
_1
>>
,
// 4x1x1 or 8x1x1 thread group
Layout
<
Shape
<
Int
<
kNWarps
>
,
_1
,
_1
>>
,
// 4x1x1 or 8x1x1 thread group
typename
Base
::
ValLayoutMNK
>
;
// 1x2x1 or 1x2x2 value group for 16x16x16 MMA and LDSM
Tile
<
Int
<
16
*
kNWarps
>
,
_16
,
_16
>>
;
using
SmemLayoutAtomQ
=
decltype
(
using
SmemLayoutAtomQ
=
decltype
(
composition
(
Swizzle
<
kSwizzle
,
3
,
3
>
{},
composition
(
Swizzle
<
kSwizzle
,
3
,
3
>
{},
...
@@ -197,17 +195,17 @@ struct Flash_bwd_kernel_traits : public Base {
...
@@ -197,17 +195,17 @@ struct Flash_bwd_kernel_traits : public Base {
using
TiledMmaSdP
=
TiledMMA
<
using
TiledMmaSdP
=
TiledMMA
<
typename
Base
::
MMA_Atom_Arch
,
typename
Base
::
MMA_Atom_Arch
,
Layout
<
Shape
<
Int
<
AtomLayoutMSdP
>
,
Int
<
kNWarps
/
AtomLayoutMSdP
>
,
_1
>>
,
Layout
<
Shape
<
Int
<
AtomLayoutMSdP
>
,
Int
<
kNWarps
/
AtomLayoutMSdP
>
,
_1
>>
,
typename
Base
::
ValLayoutMNK
>
;
// 1x2x1 or 1x2x2 value group for 16x16x16 MMA and LDSM
Tile
<
Int
<
16
*
AtomLayoutMSdP
>
,
Int
<
16
*
kNWarps
/
AtomLayoutMSdP
>
,
_16
>>
;
using
TiledMmadKV
=
TiledMMA
<
using
TiledMmadKV
=
TiledMMA
<
typename
Base
::
MMA_Atom_Arch
,
typename
Base
::
MMA_Atom_Arch
,
Layout
<
Shape
<
Int
<
AtomLayoutNdKV
>
,
Int
<
kNWarps
/
AtomLayoutNdKV
>
,
_1
>>
,
Layout
<
Shape
<
Int
<
AtomLayoutNdKV
>
,
Int
<
kNWarps
/
AtomLayoutNdKV
>
,
_1
>>
,
typename
Base
::
ValLayoutMNK
>
;
// 1x2x1 or 1x2x2 value group for 16x16x16 MMA and LDSM
Tile
<
Int
<
16
*
AtomLayoutNdKV
>
,
Int
<
16
*
kNWarps
/
AtomLayoutNdKV
>
,
_16
>>
;
using
TiledMmadQ
=
TiledMMA
<
using
TiledMmadQ
=
TiledMMA
<
typename
Base
::
MMA_Atom_Arch
,
typename
Base
::
MMA_Atom_Arch
,
Layout
<
Shape
<
Int
<
AtomLayoutMdQ
>
,
Int
<
kNWarps
/
AtomLayoutMdQ
>
,
_1
>>
,
// 2x4x1 or 4x2x1 thread group
Layout
<
Shape
<
Int
<
AtomLayoutMdQ
>
,
Int
<
kNWarps
/
AtomLayoutMdQ
>
,
_1
>>
,
// 2x4x1 or 4x2x1 thread group
typename
Base
::
ValLayoutMNK
>
;
// 1x2x1 or 1x2x2 value group for 16x16x16 MMA and LDSM
Tile
<
Int
<
16
*
AtomLayoutMdQ
>
,
Int
<
16
*
kNWarps
/
AtomLayoutMdQ
>
,
_16
>>
;
using
SmemLayoutAtomQdO
=
decltype
(
using
SmemLayoutAtomQdO
=
decltype
(
composition
(
Swizzle
<
kSwizzle
,
3
,
3
>
{},
composition
(
Swizzle
<
kSwizzle
,
3
,
3
>
{},
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment