Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
flash-attention
Commits
dbd79237
"vscode:/vscode.git/clone" did not exist on "b85bb0753ee9bfde6ba4e9217203f8d820aef4c2"
Commit
dbd79237
authored
Aug 13, 2023
by
Tri Dao
Browse files
Prepare for Cutlass 3.2
parent
c5e87b11
Changes
2
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
40 additions
and
21 deletions
+40
-21
csrc/flash_attn/src/kernel_traits.h
csrc/flash_attn/src/kernel_traits.h
+29
-17
csrc/flash_attn/src/utils.h
csrc/flash_attn/src/utils.h
+11
-4
No files found.
csrc/flash_attn/src/kernel_traits.h
View file @
dbd79237
...
@@ -91,17 +91,20 @@ struct Flash_fwd_kernel_traits : public Base {
...
@@ -91,17 +91,20 @@ struct Flash_fwd_kernel_traits : public Base {
SmemLayoutAtomQ
{},
SmemLayoutAtomQ
{},
Shape
<
Int
<
kBlockN
>
,
Int
<
kHeadDim
>>
{}));
Shape
<
Int
<
kBlockN
>
,
Int
<
kHeadDim
>>
{}));
// This has to be kBlockN and not 8, otherwise we get wrong results for d=128
using
SmemLayoutAtomVtransposedNoSwizzle
=
Layout
<
Shape
<
Int
<
kBlockKSmem
>
,
Int
<
kBlockN
>>
,
Stride
<
_1
,
Int
<
kBlockKSmem
>>>
;
using
SmemLayoutAtomVtransposed
=
decltype
(
using
SmemLayoutAtomVtransposed
=
decltype
(
composition
(
Swizzle
<
kSwizzle
,
3
,
3
>
{},
composition
(
Swizzle
<
kSwizzle
,
3
,
3
>
{},
SmemLayoutAtomVtransposedNoSwizzle
{}));
// This has to be kBlockN and not 8, otherwise we get wrong results for d=128
Layout
<
Shape
<
Int
<
kBlockKSmem
>
,
Int
<
kBlockN
>>
,
Stride
<
_1
,
Int
<
kBlockKSmem
>>>
{}));
using
SmemLayoutVtransposed
=
decltype
(
tile_to_shape
(
using
SmemLayoutVtransposed
=
decltype
(
tile_to_shape
(
SmemLayoutAtomVtransposed
{},
SmemLayoutAtomVtransposed
{},
Shape
<
Int
<
kHeadDim
>
,
Int
<
kBlockN
>>
{}));
Shape
<
Int
<
kHeadDim
>
,
Int
<
kBlockN
>>
{}));
// Maybe the VtransposeNoSwizzle just needs to have the right shape
// Maybe the VtransposeNoSwizzle just needs to have the right shape
// And the strides don't matter?
// And the strides don't matter?
using
SmemLayoutVtransposedNoSwizzle
=
decltype
(
SmemLayoutVtransposed
{}.
layout_fn
());
using
SmemLayoutVtransposedNoSwizzle
=
decltype
(
tile_to_shape
(
SmemLayoutAtomVtransposedNoSwizzle
{},
Shape
<
Int
<
kHeadDim
>
,
Int
<
kBlockN
>>
{}));
// using SmemLayoutVtransposedNoSwizzle = decltype(SmemLayoutVtransposed{}.layout_fn());
using
SmemLayoutAtomO
=
decltype
(
using
SmemLayoutAtomO
=
decltype
(
composition
(
Swizzle
<
kSwizzle
,
3
,
3
>
{},
composition
(
Swizzle
<
kSwizzle
,
3
,
3
>
{},
...
@@ -223,16 +226,19 @@ struct Flash_bwd_kernel_traits : public Base {
...
@@ -223,16 +226,19 @@ struct Flash_bwd_kernel_traits : public Base {
SmemLayoutAtomKV
{},
SmemLayoutAtomKV
{},
make_shape
(
Int
<
kBlockN
>
{},
Int
<
kHeadDim
>
{})));
make_shape
(
Int
<
kBlockN
>
{},
Int
<
kHeadDim
>
{})));
using
SmemLayoutAtomKtransposedNoSwizzle
=
Layout
<
Shape
<
Int
<
kBlockKSmem
>
,
Int
<
kBlockN
>>
,
Stride
<
_1
,
Int
<
kBlockKSmem
>>>
;
using
SmemLayoutAtomKtransposed
=
decltype
(
using
SmemLayoutAtomKtransposed
=
decltype
(
composition
(
Swizzle
<
kSwizzle
,
3
,
3
>
{},
composition
(
Swizzle
<
kSwizzle
,
3
,
3
>
{},
SmemLayoutAtomKtransposedNoSwizzle
{}));
Layout
<
Shape
<
Int
<
kBlockKSmem
>
,
Int
<
kBlockN
>>
,
Stride
<
_1
,
Int
<
kBlockKSmem
>>>
{}));
using
SmemLayoutKtransposed
=
decltype
(
tile_to_shape
(
using
SmemLayoutKtransposed
=
decltype
(
tile_to_shape
(
SmemLayoutAtomKtransposed
{},
SmemLayoutAtomKtransposed
{},
make_shape
(
Int
<
kHeadDim
>
{},
Int
<
kBlockN
>
{})));
make_shape
(
Int
<
kHeadDim
>
{},
Int
<
kBlockN
>
{})));
// Maybe the KtransposeNoSwizzle just needs to have the right shape
// Maybe the KtransposeNoSwizzle just needs to have the right shape
// And the strides don't matter?
// And the strides don't matter?
using
SmemLayoutKtransposedNoSwizzle
=
decltype
(
SmemLayoutKtransposed
{}.
layout_fn
());
using
SmemLayoutKtransposedNoSwizzle
=
decltype
(
tile_to_shape
(
SmemLayoutAtomKtransposedNoSwizzle
{},
make_shape
(
Int
<
kHeadDim
>
{},
Int
<
kBlockN
>
{})));
// using SmemLayoutKtransposedNoSwizzle = decltype(SmemLayoutKtransposed{}.layout_fn());
// TODO: generalize to other values of kBlockN
// TODO: generalize to other values of kBlockN
// TODO: what should be the Swizzle here? 3 is faster than 1, and 1 is faster than 2
// TODO: what should be the Swizzle here? 3 is faster than 1, and 1 is faster than 2
...
@@ -250,24 +256,30 @@ struct Flash_bwd_kernel_traits : public Base {
...
@@ -250,24 +256,30 @@ struct Flash_bwd_kernel_traits : public Base {
using
SmemLayoutPdS
=
decltype
(
tile_to_shape
(
using
SmemLayoutPdS
=
decltype
(
tile_to_shape
(
SmemLayoutAtomPdS
{},
SmemLayoutAtomPdS
{},
make_shape
(
Int
<
kBlockM
>
{},
Int
<
kBlockN
>
{})));
make_shape
(
Int
<
kBlockM
>
{},
Int
<
kBlockN
>
{})));
using
SmemLayoutAtomPdStransposedNoSwizzle
=
Layout
<
Shape
<
Int
<
kPBlockN
>
,
Int
<
kBlockM
>>
,
Stride
<
_1
,
Int
<
kPBlockN
>>>
;
using
SmemLayoutAtomPdStransposed
=
decltype
(
using
SmemLayoutAtomPdStransposed
=
decltype
(
composition
(
Swizzle
<
kSwizzlePdS
,
3
,
3
>
{},
composition
(
Swizzle
<
kSwizzlePdS
,
3
,
3
>
{},
SmemLayoutAtomPdStransposedNoSwizzle
{}));
Layout
<
Shape
<
Int
<
kPBlockN
>
,
Int
<
kBlockM
>>
,
Stride
<
_1
,
Int
<
kPBlockN
>>>
{}));
using
SmemLayoutPdStransposed
=
decltype
(
tile_to_shape
(
using
SmemLayoutPdStransposed
=
decltype
(
tile_to_shape
(
SmemLayoutAtomPdStransposed
{},
SmemLayoutAtomPdStransposed
{},
make_shape
(
Int
<
kBlockN
>
{},
Int
<
kBlockM
>
{})));
make_shape
(
Int
<
kBlockN
>
{},
Int
<
kBlockM
>
{})));
using
SmemLayoutPdStransposedNoSwizzle
=
decltype
(
SmemLayoutPdStransposed
{}.
layout_fn
());
using
SmemLayoutPdStransposedNoSwizzle
=
decltype
(
tile_to_shape
(
SmemLayoutAtomPdStransposedNoSwizzle
{},
make_shape
(
Int
<
kBlockN
>
{},
Int
<
kBlockM
>
{})));
// using SmemLayoutPdStransposedNoSwizzle = decltype(SmemLayoutPdStransposed{}.layout_fn());
using
SmemCopyAtomPdS
=
Copy_Atom
<
DefaultCopy
,
elem_type
>
;
using
SmemCopyAtomPdS
=
Copy_Atom
<
DefaultCopy
,
elem_type
>
;
using
SmemLayoutAtomQdOtransposedNoSwizzle
=
Layout
<
Shape
<
Int
<
kBlockKSmem
>
,
Int
<
kBlockM
>>
,
Stride
<
_1
,
Int
<
kBlockKSmem
>>>
;
using
SmemLayoutAtomQdOtransposed
=
decltype
(
using
SmemLayoutAtomQdOtransposed
=
decltype
(
composition
(
Swizzle
<
kSwizzle
,
3
,
3
>
{},
composition
(
Swizzle
<
kSwizzle
,
3
,
3
>
{},
SmemLayoutAtomQdOtransposedNoSwizzle
{}));
Layout
<
Shape
<
Int
<
kBlockKSmem
>
,
Int
<
kBlockM
>>
,
Stride
<
_1
,
Int
<
kBlockKSmem
>>>
{}));
using
SmemLayoutQdOtransposed
=
decltype
(
tile_to_shape
(
using
SmemLayoutQdOtransposed
=
decltype
(
tile_to_shape
(
SmemLayoutAtomQdOtransposed
{},
SmemLayoutAtomQdOtransposed
{},
make_shape
(
Int
<
kHeadDim
>
{},
Int
<
kBlockM
>
{})));
make_shape
(
Int
<
kHeadDim
>
{},
Int
<
kBlockM
>
{})));
using
SmemLayoutQdOtransposedNoSwizzle
=
decltype
(
SmemLayoutQdOtransposed
{}.
layout_fn
());
using
SmemLayoutQdOtransposedNoSwizzle
=
decltype
(
tile_to_shape
(
SmemLayoutAtomQdOtransposedNoSwizzle
{},
make_shape
(
Int
<
kHeadDim
>
{},
Int
<
kBlockM
>
{})));
// using SmemLayoutQdOtransposedNoSwizzle = decltype(SmemLayoutQdOtransposed{}.layout_fn());
using
SmemLayoutAtomdKV
=
decltype
(
using
SmemLayoutAtomdKV
=
decltype
(
composition
(
Swizzle
<
kSwizzle
,
3
,
3
>
{},
composition
(
Swizzle
<
kSwizzle
,
3
,
3
>
{},
...
...
csrc/flash_attn/src/utils.h
View file @
dbd79237
...
@@ -228,7 +228,10 @@ inline __device__ auto convert_layout_acc_rowcol(Layout acc_layout) {
...
@@ -228,7 +228,10 @@ inline __device__ auto convert_layout_acc_rowcol(Layout acc_layout) {
static_assert
(
decltype
(
size
<
0
>
(
acc_layout
))
::
value
==
4
);
static_assert
(
decltype
(
size
<
0
>
(
acc_layout
))
::
value
==
4
);
static_assert
(
decltype
(
rank
(
acc_layout
))
::
value
==
3
);
static_assert
(
decltype
(
rank
(
acc_layout
))
::
value
==
3
);
auto
l
=
logical_divide
(
acc_layout
,
Shape
<
_2
>
{});
// ((2, 2), MMA_M, MMA_N)
auto
l
=
logical_divide
(
acc_layout
,
Shape
<
_2
>
{});
// ((2, 2), MMA_M, MMA_N)
return
make_layout
(
make_layout
(
get
<
0
,
1
>
(
l
),
get
<
1
>
(
l
)),
make_layout
(
get
<
0
,
0
>
(
l
),
get
<
2
>
(
l
)));
// TD [2023-08-13]: Idk why but get<0, 1>(l) doesn't work for Cutlass 3.2, I'm getting
// "int_tuple.hpp(74): error: conversion to inaccessible base class"
// return make_layout(make_layout(get<0, 1>(l), get<1>(l)), make_layout(get<0, 0>(l), get<2>(l)));
return
make_layout
(
make_layout
(
get
<
1
>
(
get
<
0
>
(
l
)),
get
<
1
>
(
l
)),
make_layout
(
get
<
0
>
(
get
<
0
>
(
l
)),
get
<
2
>
(
l
)));
};
};
////////////////////////////////////////////////////////////////////////////////////////////////////
////////////////////////////////////////////////////////////////////////////////////////////////////
...
@@ -244,9 +247,13 @@ inline __device__ auto convert_layout_rowcol_Aregs(Layout rowcol_layout) {
...
@@ -244,9 +247,13 @@ inline __device__ auto convert_layout_rowcol_Aregs(Layout rowcol_layout) {
static_assert
(
mma_shape_K
==
8
||
mma_shape_K
==
16
);
static_assert
(
mma_shape_K
==
8
||
mma_shape_K
==
16
);
constexpr
int
MMA_N_divisor
=
mma_shape_K
==
8
?
1
:
2
;
constexpr
int
MMA_N_divisor
=
mma_shape_K
==
8
?
1
:
2
;
auto
l
=
logical_divide
(
rowcol_layout
,
Shape
<
X
,
Shape
<
X
,
Int
<
MMA_N_divisor
>>>
{});
// ((2, MMA_M), (2, (2, MMA_N / 2)))
auto
l
=
logical_divide
(
rowcol_layout
,
Shape
<
X
,
Shape
<
X
,
Int
<
MMA_N_divisor
>>>
{});
// ((2, MMA_M), (2, (2, MMA_N / 2)))
return
make_layout
(
make_layout
(
get
<
1
,
0
>
(
l
),
get
<
0
,
0
>
(
l
),
get
<
1
,
1
,
0
>
(
l
)),
// TD [2023-08-13]: Same error as above on Cutlass 3.2
get
<
0
,
1
>
(
l
),
// return make_layout(make_layout(get<1, 0>(l), get<0, 0>(l), get<1, 1, 0>(l)),
get
<
1
,
1
,
1
>
(
l
));
// get<0, 1>(l),
// get<1, 1, 1>(l));
return
make_layout
(
make_layout
(
get
<
0
>
(
get
<
1
>
(
l
)),
get
<
0
>
(
get
<
0
>
(
l
)),
get
<
0
>
(
get
<
1
>
(
get
<
1
>
(
l
)))),
get
<
1
>
(
get
<
0
>
(
l
)),
get
<
1
>
(
get
<
1
>
(
get
<
1
>
(
l
))));
};
};
////////////////////////////////////////////////////////////////////////////////////////////////////
////////////////////////////////////////////////////////////////////////////////////////////////////
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment