Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
composable_kernel
Commits
fa066d60
Commit
fa066d60
authored
Aug 09, 2023
by
letaoqin
Browse files
gridwise change to multiple D
parent
ee275d4d
Changes
2
Expand all
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
206 additions
and
115 deletions
+206
-115
include/ck/tensor_operation/gpu/device/impl/device_batched_mha_fwd_xdl_cshuffle_v2r2.hpp
.../device/impl/device_batched_mha_fwd_xdl_cshuffle_v2r2.hpp
+89
-57
include/ck/tensor_operation/gpu/grid/gridwise_batched_mha_fwd_xdl_cshuffle_v2r2.hpp
...n/gpu/grid/gridwise_batched_mha_fwd_xdl_cshuffle_v2r2.hpp
+117
-58
No files found.
include/ck/tensor_operation/gpu/device/impl/device_batched_mha_fwd_xdl_cshuffle_v2r2.hpp
View file @
fa066d60
This diff is collapsed.
Click to expand it.
include/ck/tensor_operation/gpu/grid/gridwise_batched_mha_fwd_xdl_cshuffle_v2r2.hpp
View file @
fa066d60
...
@@ -25,6 +25,7 @@ namespace ck {
...
@@ -25,6 +25,7 @@ namespace ck {
*
*
*/
*/
template
<
typename
FloatAB
,
template
<
typename
FloatAB
,
typename
D0sDataType
,
typename
ZDataType
,
typename
ZDataType
,
typename
FloatGemm
,
typename
FloatGemm
,
typename
FloatGemmAcc
,
typename
FloatGemmAcc
,
...
@@ -39,6 +40,7 @@ template <typename FloatAB,
...
@@ -39,6 +40,7 @@ template <typename FloatAB,
InMemoryDataOperationEnum
CGlobalMemoryDataOperation
,
InMemoryDataOperationEnum
CGlobalMemoryDataOperation
,
typename
AGridDesc_AK0_M_AK1
,
typename
AGridDesc_AK0_M_AK1
,
typename
BGridDesc_BK0_N_BK1
,
typename
BGridDesc_BK0_N_BK1
,
typename
D0sGridDesc_M_N
,
typename
B1GridDesc_BK0_N_BK1
,
typename
B1GridDesc_BK0_N_BK1
,
typename
CGridDesc_M_N
,
typename
CGridDesc_M_N
,
typename
ZGridDesc_M_N
,
typename
ZGridDesc_M_N
,
...
@@ -99,7 +101,7 @@ struct GridwiseBatchedMultiheadAttentionForward_Xdl_CShuffle_V2R2
...
@@ -99,7 +101,7 @@ struct GridwiseBatchedMultiheadAttentionForward_Xdl_CShuffle_V2R2
D0BlockTransferSrcScalarPerVector
==
2
||
D0BlockTransferSrcScalarPerVector
==
2
||
D0BlockTransferSrcScalarPerVector
==
4
,
D0BlockTransferSrcScalarPerVector
==
4
,
"D0BlockTransferSrcScalarPerVector must be 1 or 2 or 4"
);
"D0BlockTransferSrcScalarPerVector must be 1 or 2 or 4"
);
using
DDataType
=
FloatAB
;
static
constexpr
index_t
NumD0Tensor
=
D0sDataType
::
Size
()
;
static_assert
(
LoopSched
==
LoopScheduler
::
Default
,
static_assert
(
LoopSched
==
LoopScheduler
::
Default
,
"Non-default loop scheduler is currently not supported"
);
"Non-default loop scheduler is currently not supported"
);
...
@@ -414,6 +416,53 @@ struct GridwiseBatchedMultiheadAttentionForward_Xdl_CShuffle_V2R2
...
@@ -414,6 +416,53 @@ struct GridwiseBatchedMultiheadAttentionForward_Xdl_CShuffle_V2R2
c_grid_desc_m_n
);
c_grid_desc_m_n
);
}
}
static
constexpr
auto
MakeD0sGridPointer
()
{
return
generate_tuple
(
[
&
](
auto
i
)
{
using
D0DataType
=
remove_cvref_t
<
tuple_element_t
<
i
.
value
,
D0sDataType
>>
;
return
static_cast
<
const
D0DataType
*>
(
nullptr
);
},
Number
<
NumD0Tensor
>
{});
}
// Re-expresses a single 2-D (M, N) D0 grid descriptor as the 10-D
// M0_N0_M1_N1_M2_N2_M3_N3_N4_N5 descriptor used as the copy source for D0.
//
//   M is unmerged into (M / MPerBlock, MXdlPerWave, Gemm0MWaves, MPerXdl)
//     -> output dimensions 0, 2, 4, 6
//   N is unmerged into (N / NPerBlock, NXdlPerWave, Gemm0NWaves,
//                       num_groups_per_blk, num_input_blks, group_size)
//     -> output dimensions 1, 3, 5, 7, 8, 9
//
// The last three N factors are read from the MFMA configuration selected for
// <FloatAB, MPerXdl, NPerXdl>.
template <typename D0GridDesc_M_N>
__host__ __device__ static constexpr auto
MakeGemm0D0GridDescriptor_M0_N0_M1_N1_M2_N2_M3_N3_N4_N5(const D0GridDesc_M_N& d0_grid_desc_m_n)
{
    // Compile-time layout constants of the selected xdlops instruction.
    constexpr auto xdl_op         = MfmaSelector<FloatAB, MPerXdl, NPerXdl>::selected_mfma;
    constexpr auto groups_per_blk = xdl_op.num_groups_per_blk;
    constexpr auto input_blks     = xdl_op.num_input_blks;
    constexpr auto group_len      = xdl_op.group_size;

    // Extents of the incoming 2-D descriptor.
    const auto m_len = d0_grid_desc_m_n.GetLength(I0);
    const auto n_len = d0_grid_desc_m_n.GetLength(I1);

    // One unmerge transform per source dimension.
    const auto split_m = make_unmerge_transform(
        make_tuple(m_len / MPerBlock, MXdlPerWave, Gemm0MWaves, MPerXdl));
    const auto split_n = make_unmerge_transform(make_tuple(
        n_len / NPerBlock, NXdlPerWave, Gemm0NWaves, groups_per_blk, input_blks, group_len));

    // Source dimension ids and the interleaved destination ids they expand into.
    constexpr auto src_dim_ids = make_tuple(Sequence<0>{}, Sequence<1>{});
    constexpr auto dst_dim_ids =
        make_tuple(Sequence<0, 2, 4, 6>{}, Sequence<1, 3, 5, 7, 8, 9>{});

    return transform_tensor_descriptor(
        d0_grid_desc_m_n, make_tuple(split_m, split_n), src_dim_ids, dst_dim_ids);
}
// Tuple version of MakeGemm0D0GridDescriptor_M0_N0_M1_N1_M2_N2_M3_N3_N4_N5:
// converts each (M, N) descriptor stored in ds_grid_desc_m_n and returns the
// NumD0Tensor converted descriptors as a tuple, in the same order.
__host__ __device__ static constexpr auto
MakeD0sGridDescriptor_M0_N0_M1_N1_M2_N2_M3_N3_N4_N5(const D0sGridDesc_M_N& ds_grid_desc_m_n)
{
    // Converts the descriptor found at one compile-time index.
    const auto expand_one = [&](auto d0_idx) {
        const auto& desc_m_n = ds_grid_desc_m_n[d0_idx];

        return MakeGemm0D0GridDescriptor_M0_N0_M1_N1_M2_N2_M3_N3_N4_N5(desc_m_n);
    };

    return generate_tuple(expand_one, Number<NumD0Tensor>{});
}
// Type returned by MakeD0sGridPointer(): a tuple with one `const D0DataType*`
// entry per D0 tensor.
using D0sGridPointer = decltype(MakeD0sGridPointer());

// Type returned by MakeD0sGridDescriptor_M0_N0_M1_N1_M2_N2_M3_N3_N4_N5() when
// called with a default-constructed D0sGridDesc_M_N, with cv/ref qualifiers
// removed.
using D0sGridDescriptor_M0_N0_M1_N1_M2_N2_M3_N3_N4_N5 = remove_cvref_t<decltype(
    MakeD0sGridDescriptor_M0_N0_M1_N1_M2_N2_M3_N3_N4_N5(D0sGridDesc_M_N{}))>;
using
CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
=
remove_cvref_t
<
decltype
(
using
CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
=
remove_cvref_t
<
decltype
(
MakeCGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
(
CGridDesc_M_N
{}))
>
;
MakeCGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
(
CGridDesc_M_N
{}))
>
;
...
@@ -475,9 +524,9 @@ struct GridwiseBatchedMultiheadAttentionForward_Xdl_CShuffle_V2R2
...
@@ -475,9 +524,9 @@ struct GridwiseBatchedMultiheadAttentionForward_Xdl_CShuffle_V2R2
typename
C0MatrixMask
>
typename
C0MatrixMask
>
__device__
static
void
Run
(
const
FloatAB
*
__restrict__
p_a_grid
,
__device__
static
void
Run
(
const
FloatAB
*
__restrict__
p_a_grid
,
const
FloatAB
*
__restrict__
p_b_grid
,
const
FloatAB
*
__restrict__
p_b_grid
,
D0sGridPointer
p_d0s_grid
,
const
FloatAB
*
__restrict__
p_b1_grid
,
const
FloatAB
*
__restrict__
p_b1_grid
,
FloatC
*
__restrict__
p_c_grid
,
FloatC
*
__restrict__
p_c_grid
,
const
DDataType
*
__restrict__
p_d_grid
,
ZDataType
*
__restrict__
p_z_grid
,
ZDataType
*
__restrict__
p_z_grid
,
FloatLSE
*
__restrict__
p_lse_grid
,
FloatLSE
*
__restrict__
p_lse_grid
,
void
*
__restrict__
p_shared
,
void
*
__restrict__
p_shared
,
...
@@ -488,11 +537,11 @@ struct GridwiseBatchedMultiheadAttentionForward_Xdl_CShuffle_V2R2
...
@@ -488,11 +537,11 @@ struct GridwiseBatchedMultiheadAttentionForward_Xdl_CShuffle_V2R2
const
CElementwiseOperation
&
c_element_op
,
const
CElementwiseOperation
&
c_element_op
,
const
AGridDesc_AK0_M_AK1
&
a_grid_desc_ak0_m_ak1
,
const
AGridDesc_AK0_M_AK1
&
a_grid_desc_ak0_m_ak1
,
const
BGridDesc_BK0_N_BK1
&
b_grid_desc_bk0_n_bk1
,
const
BGridDesc_BK0_N_BK1
&
b_grid_desc_bk0_n_bk1
,
const
D0sGridDescriptor_M0_N0_M1_N1_M2_N2_M3_N3_N4_N5
&
d0s_griddesc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5
,
const
B1GridDesc_BK0_N_BK1
&
b1_grid_desc_bk0_n_bk1
,
const
B1GridDesc_BK0_N_BK1
&
b1_grid_desc_bk0_n_bk1
,
const
CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
&
const
CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
&
c_grid_desc_mblock_mperblock_nblock_nperblock
,
c_grid_desc_mblock_mperblock_nblock_nperblock
,
const
DGridDescriptor_M0_N0_M1_N1_M2_N2_M3_N3_N4_N5
&
d_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5
,
const
ZGridDescriptor_M0_N0_M1_N1_M2_N2_M3_N3_N4_N5
&
const
ZGridDescriptor_M0_N0_M1_N1_M2_N2_M3_N3_N4_N5
&
z_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5
,
z_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5
,
const
LSEGridDesc_M
&
lse_grid_desc_m
,
const
LSEGridDesc_M
&
lse_grid_desc_m
,
...
@@ -907,7 +956,7 @@ struct GridwiseBatchedMultiheadAttentionForward_Xdl_CShuffle_V2R2
...
@@ -907,7 +956,7 @@ struct GridwiseBatchedMultiheadAttentionForward_Xdl_CShuffle_V2R2
const
auto
wave_id
=
GetGemm0WaveIdx
();
const
auto
wave_id
=
GetGemm0WaveIdx
();
const
auto
wave_m_n_id
=
GetGemm0WaveMNIdx
(
wave_id
[
I2
]);
// I2: 0~63
const
auto
wave_m_n_id
=
GetGemm0WaveMNIdx
(
wave_id
[
I2
]);
// I2: 0~63
// bias (d matrix)
// bias (d matrix)
constexpr
auto
d_thread_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5
=
constexpr
auto
d
0
_thread_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5
=
make_naive_tensor_descriptor_packed
(
make_tuple
(
I1
,
// MBlockId
make_naive_tensor_descriptor_packed
(
make_tuple
(
I1
,
// MBlockId
I1
,
// NBlockId
I1
,
// NBlockId
m0
,
// MRepeat
m0
,
// MRepeat
...
@@ -919,36 +968,49 @@ struct GridwiseBatchedMultiheadAttentionForward_Xdl_CShuffle_V2R2
...
@@ -919,36 +968,49 @@ struct GridwiseBatchedMultiheadAttentionForward_Xdl_CShuffle_V2R2
n3
,
// NInputNum
n3
,
// NInputNum
n4
));
// RegisterNum
n4
));
// RegisterNum
auto
d_threadwise_copy_globla_vgpr
=
auto
d0s_threadwise_copy
=
generate_tuple
(
ThreadwiseTensorSliceTransfer_v2
<
DDataType
,
[
&
](
auto
i
)
{
DDataType
,
using
D0DataType
=
remove_cvref_t
<
tuple_element_t
<
i
.
value
,
D0sDataType
>>
;
decltype
(
d_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5
),
return
ThreadwiseTensorSliceTransfer_v2
<
decltype
(
d_thread_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5
),
D0DataType
,
Sequence
<
I1
,
// MBlockId
D0DataType
,
I1
,
// NBlockID
decltype
(
d0s_griddesc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5
[
i
]),
m0
,
// MRepeat
decltype
(
d0_thread_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5
),
n0
,
// NRepeat
Sequence
<
I1
,
// MBlockId
m1
,
// MWaveId
I1
,
// NBlockID
n1
,
// NWaveId
m0
,
// MRepeat
m2
,
// MPerXdl
n0
,
// NRepeat
n2
,
// NGroupNum
m1
,
// MWaveId
n3
,
// NInputNum
n1
,
// NWaveId
n4
>
,
m2
,
// MPerXdl
Sequence
<
0
,
1
,
2
,
3
,
4
,
5
,
6
,
7
,
8
,
9
>
,
n2
,
// NGroupNum
9
,
n3
,
// NInputNum
D0BlockTransferSrcScalarPerVector
,
n4
>
,
1
,
Sequence
<
0
,
1
,
2
,
3
,
4
,
5
,
6
,
7
,
8
,
9
>
,
false
>
(
d_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5
,
9
,
make_multi_index
(
block_work_idx_m
,
// MBlockId
D0BlockTransferSrcScalarPerVector
,
0
,
// NBlockId
1
,
0
,
// mrepeat
false
>
(
d0s_griddesc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5
[
i
],
0
,
// nrepeat
make_multi_index
(
block_work_idx_m
,
// MBlockId
wave_id
[
I0
],
// MWaveId
0
,
// NBlockId
wave_id
[
I1
],
// NWaveId
0
,
// mrepeat
wave_m_n_id
[
I1
],
// MPerXdl
0
,
// nrepeat
0
,
// group
wave_id
[
I0
],
// MWaveId
wave_m_n_id
[
I0
],
// NInputIndex
wave_id
[
I1
],
// NWaveId
0
));
// register number
wave_m_n_id
[
I1
],
// MPerXdl
0
,
// group
wave_m_n_id
[
I0
],
// NInputIndex
0
));
// register number
},
Number
<
NumD0Tensor
>
{});
const
auto
d0s_grid_buf
=
generate_tuple
(
[
&
](
auto
i
)
{
return
make_dynamic_buffer
<
AddressSpaceEnum
::
Global
>
(
p_d0s_grid
[
i
],
d0s_griddesc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5
[
i
].
GetElementSpaceSize
());
},
Number
<
NumD0Tensor
>
{});
// z is random number matrix for dropout verify
// z is random number matrix for dropout verify
//
//
...
@@ -1219,33 +1281,30 @@ struct GridwiseBatchedMultiheadAttentionForward_Xdl_CShuffle_V2R2
...
@@ -1219,33 +1281,30 @@ struct GridwiseBatchedMultiheadAttentionForward_Xdl_CShuffle_V2R2
block_sync_lds
();
// wait for lds read in gemm0 blockwise gemm
block_sync_lds
();
// wait for lds read in gemm0 blockwise gemm
// add bias
// add bias
if
(
p_d_grid
)
static_for
<
0
,
NumD0Tensor
,
1
>
{}([
&
](
auto
i
)
{
{
// get register
auto
d_grid_buf
=
make_dynamic_buffer
<
AddressSpaceEnum
::
Global
>
(
using
D0DataType
=
remove_cvref_t
<
tuple_element_t
<
i
.
value
,
D0sDataType
>>
;
p_d_grid
,
d_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5
.
GetElementSpaceSize
());
StaticBuffer
<
AddressSpaceEnum
::
Vgpr
,
StaticBuffer
<
AddressSpaceEnum
::
Vgpr
,
DDataType
,
D
0
DataType
,
d_thread_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5
.
GetElementSpaceSize
(),
d
0
_thread_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5
.
GetElementSpaceSize
(),
true
>
true
>
d_thread_buf
;
d0_thread_buf
;
d_threadwise_copy_globla_vgpr
.
Run
(
// load data from global
d_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5
,
d0s_threadwise_copy
(
i
).
Run
(
d0s_griddesc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5
[
i
],
d_grid_buf
,
d0s_grid_buf
[
i
],
d_thread_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5
,
d0_thread_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5
,
make_tuple
(
I0
,
I0
,
I0
,
I0
,
I0
,
I0
,
I0
,
I0
,
I0
,
I0
),
make_tuple
(
I0
,
I0
,
I0
,
I0
,
I0
,
I0
,
I0
,
I0
,
I0
,
I0
),
d_thread_buf
);
d0_thread_buf
);
static_for
<
0
,
m0
*
n0
*
n2
*
n4
,
1
>
{}([
&
](
auto
i
)
{
// acc add bias
acc_thread_buf
(
i
)
+=
d_thread_buf
[
i
];
});
d_threadwise_copy_globla_vgpr
.
MoveSrcSliceWindow
(
// acc add bias
d_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5
,
static_for
<
0
,
m0
*
n0
*
n2
*
n4
,
1
>
{}(
[
&
](
auto
j
)
{
acc_thread_buf
(
j
)
+=
d0_thread_buf
[
j
];
});
d0s_threadwise_copy
(
i
).
MoveSrcSliceWindow
(
d0s_griddesc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5
[
i
],
make_multi_index
(
0
,
1
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
));
make_multi_index
(
0
,
1
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
));
}
}
);
// softmax
// softmax
SoftmaxBuf
&
max
=
blockwise_softmax
.
max_value_buf
;
SoftmaxBuf
&
max
=
blockwise_softmax
.
max_value_buf
;
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment