Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
composable_kernel
Commits
957ab734
Commit
957ab734
authored
Jun 08, 2023
by
guangzlu
Browse files
removed global shuffle parameters
parent
a3c14b5f
Changes
1
Hide whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
1 addition
and
191 deletions
+1
-191
include/ck/tensor_operation/gpu/grid/gridwise_batched_multihead_attention_forward_xdl_cshuffle.hpp
...wise_batched_multihead_attention_forward_xdl_cshuffle.hpp
+1
-191
No files found.
include/ck/tensor_operation/gpu/grid/gridwise_batched_multihead_attention_forward_xdl_cshuffle.hpp
View file @
957ab734
...
@@ -143,87 +143,6 @@ struct GridwiseBatchedMultiheadAttentionForward_Xdl_CShuffle
...
@@ -143,87 +143,6 @@ struct GridwiseBatchedMultiheadAttentionForward_Xdl_CShuffle
make_tuple
(
Sequence
<
0
,
2
,
4
,
6
>
{},
Sequence
<
1
,
3
,
5
,
7
,
8
,
9
>
{}));
make_tuple
(
Sequence
<
0
,
2
,
4
,
6
>
{},
Sequence
<
1
,
3
,
5
,
7
,
8
,
9
>
{}));
}
}
//// Z desc for source in blockwise copy
//__host__ __device__ static constexpr auto GetZBlockDescriptor_M0_N0_M1_N1_M2_N2_N3_N4() ////=>
//for z use
//{
// //const auto M = z_grid_desc_m_n.GetLength(I0);
// //const auto N = z_grid_desc_m_n.GetLength(I1);
//
// constexpr auto mfma = MfmaSelector<FloatGemm, MPerXdl, NPerXdl>::selected_mfma;
// constexpr auto N3 = mfma.num_groups_per_blk;
// constexpr auto N4 = mfma.num_input_blks;
// constexpr auto N5 = mfma.group_size;
// return make_naive_tensor_descriptor_packed(
// make_tuple(Number<MXdlPerWave>{}, Number<NXdlPerWave>{}, Number<Gemm0MWaves>{},
// Number<Gemm0NWaves>{},
// Number<MPerXdl>{}, Number<N3>{}, Number<N4>{}, Number<N5>{}));
//}
// C shuffle desc for source in gridwise copy
// Builds an 11-dimensional view of the Z grid tensor from its 2-D (M, N) descriptor.
//
// The M length of z_grid_desc_m_n is unmerged into five sub-dimensions and the
// N length into six. The resulting pieces are interleaved in the returned
// descriptor: the M pieces land on upper dimensions {0, 2, 4, 6, 9} and the
// N pieces on upper dimensions {1, 3, 5, 7, 8, 10}.
//
// NOTE(review): upper dimensions 8 and 9 are therefore an N piece followed by
// an M piece, while the function name lists "M4" before "N4" — confirm which
// ordering is intended.
__host__ __device__ static constexpr auto
MakeCShuffleGridDescriptor_M0_N0_M1_N1_M2_N2_M3_N3_M4_N4_N5(const ZGridDesc_M_N& z_grid_desc_m_n)
{
    // Split factors taken from the MFMA instruction selected for this tile shape.
    constexpr auto xdl_instr      = MfmaSelector<FloatGemm, MPerXdl, NPerXdl>::selected_mfma;
    constexpr auto groups_per_blk = xdl_instr.num_groups_per_blk;
    constexpr auto input_blks     = xdl_instr.num_input_blks;
    constexpr auto group_len      = xdl_instr.group_size;

    const auto m_len = z_grid_desc_m_n.GetLength(I0);
    const auto n_len = z_grid_desc_m_n.GetLength(I1);

    // M -> (M / MPerBlock, MXdlPerWave, Gemm0MWaves, MPerXdl / group_size, group_size)
    const auto split_m = make_unmerge_transform(make_tuple(
        m_len / MPerBlock, MXdlPerWave, Gemm0MWaves, MPerXdl / group_len, group_len));

    // N -> (N / NPerBlock, NXdlPerWave, Gemm0NWaves, num_groups_per_blk, num_input_blks,
    //       group_size)
    const auto split_n = make_unmerge_transform(make_tuple(
        n_len / NPerBlock, NXdlPerWave, Gemm0NWaves, groups_per_blk, input_blks, group_len));

    return transform_tensor_descriptor(
        z_grid_desc_m_n,
        make_tuple(split_m, split_n),
        // Lower (source) dimensions: 0 is M, 1 is N.
        make_tuple(Sequence<0>{}, Sequence<1>{}),
        // Upper (result) dimensions produced by the M split and the N split respectively.
        make_tuple(Sequence<0, 2, 4, 6, 9>{}, Sequence<1, 3, 5, 7, 8, 10>{}));
}
// using ZBlockDescriptor_M0_N0_M1_N1_M2_N2_N3_N4 = remove_cvref_t<decltype(
// GetZBlockDescriptor_M0_N0_M1_N1_M2_N2_N3_N4())>;
// Z shuffle desc for source in blockwise copy
//__host__ __device__ static constexpr auto
// GetZShuffleBlockDescriptor_M0_N0_M1_N1_M2_N2_M3_N3_N4(const
// ZBlockDescriptor_M0_N0_M1_N1_M2_N2_N3_N4& z_block_desc_m0_n0_m1_n1_m2_n2_n3_n4) ////=> for z
// use to shuffle
//{
// //const auto M = z_grid_desc_m_n.GetLength(I0);
// //const auto N = z_grid_desc_m_n.GetLength(I1);
//
// constexpr auto mfma = MfmaSelector<FloatGemm, MPerXdl, NPerXdl>::selected_mfma;
// constexpr auto N3 = mfma.num_groups_per_blk;
// constexpr auto N4 = mfma.num_input_blks;
// constexpr auto N5 = mfma.group_size;
//
// constexpr auto z_block_shuffle_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4 =
// transform_tensor_descriptor(
// z_block_desc_m0_n0_m1_n1_m2_n2_n3_n4,
// make_tuple(
// make_freeze_transform(Number<MXdlPerWave>{}),
// make_freeze_transform(Number<NXdlPerWave>{}),
// make_freeze_transform(Number<Gemm0MWaves>{}),
// make_freeze_transform(Number<Gemm0NWaves>{}),
// make_unmerge_transform(make_tuple(Number<MPerXdl / N5>{}, Number<N5>{})),
// make_freeze_transform(Number<N3>{}),
// make_freeze_transform(Number<N4>{}),
// make_freeze_transform(Number<N5>{})),
// make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{},Sequence<4>{},
// Sequence<5>{}, Sequence<6>{}, Sequence<7>{}), make_tuple(Sequence<0>{}, Sequence<1>{},
// Sequence<2>{}, Sequence<3>{},Sequence<4,7>{}, Sequence<5>{}, Sequence<6>{},
// Sequence<8>{}));
// return z_block_shuffle_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4;
// //return make_naive_tensor_descriptor_packed(
// // make_tuple(Number<MXdlPerWave>{}, Number<NXdlPerWave>{}, Number<Gemm0MWaves>{},
// Number<Gemm0NWaves>{},
// // Number<MPerXdl / N5>{}, Number<N3>{}, Number<N4>{}, Number<N5>{},
// Number<N5>{}));
//}
__device__
static
auto
GetGemm0WaveIdx
()
__device__
static
auto
GetGemm0WaveIdx
()
{
{
const
index_t
thread_id
=
get_thread_local_1d_id
();
const
index_t
thread_id
=
get_thread_local_1d_id
();
...
@@ -462,9 +381,6 @@ struct GridwiseBatchedMultiheadAttentionForward_Xdl_CShuffle
...
@@ -462,9 +381,6 @@ struct GridwiseBatchedMultiheadAttentionForward_Xdl_CShuffle
using
ZGridDescriptor_M0_N0_M1_N1_M2_N2_M3_N3_N4_N5
=
remove_cvref_t
<
decltype
(
using
ZGridDescriptor_M0_N0_M1_N1_M2_N2_M3_N3_N4_N5
=
remove_cvref_t
<
decltype
(
MakeCGridDescriptor_M0_N0_M1_N1_M2_N2_M3_N3_N4_N5
(
ZGridDesc_M_N
{}))
>
;
MakeCGridDescriptor_M0_N0_M1_N1_M2_N2_M3_N3_N4_N5
(
ZGridDesc_M_N
{}))
>
;
// Descriptor type produced by
// MakeCShuffleGridDescriptor_M0_N0_M1_N1_M2_N2_M3_N3_M4_N4_N5 when it is given
// a default-constructed ZGridDesc_M_N, passed through remove_cvref_t.
using
ZShuffleGridDescriptor_M0_N0_M1_N1_M2_N2_M3_N3_M4_N4_N5
=
remove_cvref_t
<
decltype
(
MakeCShuffleGridDescriptor_M0_N0_M1_N1_M2_N2_M3_N3_M4_N4_N5
(
ZGridDesc_M_N
{}))
>
;
struct
SharedMemTrait
struct
SharedMemTrait
{
{
// LDS allocation for A and B: be careful of alignment
// LDS allocation for A and B: be careful of alignment
...
@@ -525,8 +441,6 @@ struct GridwiseBatchedMultiheadAttentionForward_Xdl_CShuffle
...
@@ -525,8 +441,6 @@ struct GridwiseBatchedMultiheadAttentionForward_Xdl_CShuffle
c_grid_desc_mblock_mperblock_nblock_nperblock
,
c_grid_desc_mblock_mperblock_nblock_nperblock
,
const
ZGridDescriptor_M0_N0_M1_N1_M2_N2_M3_N3_N4_N5
&
const
ZGridDescriptor_M0_N0_M1_N1_M2_N2_M3_N3_N4_N5
&
z_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5
,
z_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5
,
const
ZShuffleGridDescriptor_M0_N0_M1_N1_M2_N2_M3_N3_M4_N4_N5
&
z_grid_shuffle_desc_m0_n0_m1_n1_m2_n2_m3_n3_m4_n4_n5
,
const
LSEGridDesc_M
&
lse_grid_desc_m
,
const
LSEGridDesc_M
&
lse_grid_desc_m
,
const
Block2CTileMap
&
block_2_ctile_map
,
const
Block2CTileMap
&
block_2_ctile_map
,
const
C0MatrixMask
&
c0_matrix_mask
,
const
C0MatrixMask
&
c0_matrix_mask
,
...
@@ -971,20 +885,6 @@ struct GridwiseBatchedMultiheadAttentionForward_Xdl_CShuffle
...
@@ -971,20 +885,6 @@ struct GridwiseBatchedMultiheadAttentionForward_Xdl_CShuffle
n4
,
// m1 4
n4
,
// m1 4
I1
));
// n2 1
I1
));
// n2 1
// constexpr auto z_thread_shuffle_desc_m0_n0_m1_n1_m2_n2_m3_n3_m4_n4_n5 = //for gridwise
// copy
// make_naive_tensor_descriptor_packed(make_tuple(I1,
// I1,
// m0, //
// n0, //
// m1, //
// n1, //
// m2, // m0 1
// n2, // n0 4
// n3, // n1 1
// n4, // m1 4
// I1)); // n2 1
constexpr
auto
z_thread_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5
=
constexpr
auto
z_thread_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5
=
make_naive_tensor_descriptor_packed
(
make_tuple
(
I1
,
// MBlockId
make_naive_tensor_descriptor_packed
(
make_tuple
(
I1
,
// MBlockId
I1
,
// NBlockID
I1
,
// NBlockID
...
@@ -1060,26 +960,10 @@ struct GridwiseBatchedMultiheadAttentionForward_Xdl_CShuffle
...
@@ -1060,26 +960,10 @@ struct GridwiseBatchedMultiheadAttentionForward_Xdl_CShuffle
auto
z_grid_buf
=
make_dynamic_buffer
<
AddressSpaceEnum
::
Global
>
(
auto
z_grid_buf
=
make_dynamic_buffer
<
AddressSpaceEnum
::
Global
>
(
p_z_grid
,
z_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5
.
GetElementSpaceSize
());
p_z_grid
,
z_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5
.
GetElementSpaceSize
());
// auto z_grid_tmp_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
// p_z_grid,
// z_grid_shuffle_desc_m0_n0_m1_n1_m2_n2_m3_n3_m4_n4_n5.GetElementSpaceSize());
ignore
=
z_grid_shuffle_desc_m0_n0_m1_n1_m2_n2_m3_n3_m4_n4_n5
;
auto
z_block_buf
=
make_dynamic_buffer
<
AddressSpaceEnum
::
Lds
>
(
auto
z_block_buf
=
make_dynamic_buffer
<
AddressSpaceEnum
::
Lds
>
(
static_cast
<
ZDataType
*>
(
p_shared
),
static_cast
<
ZDataType
*>
(
p_shared
),
z_block_desc_m0_n0_m1_n1_m2_n2_n3_n4
.
GetElementSpaceSize
());
z_block_desc_m0_n0_m1_n1_m2_n2_n3_n4
.
GetElementSpaceSize
());
// ignore = z_block_buf;
// if(get_thread_global_1d_id()==0){
// printf("z_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5.GetElementSpaceSize() is %ld \n",
// z_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5.GetElementSpaceSize());
// printf("z_grid_shuffle_desc_m0_n0_m1_n1_m2_n2_m3_n3_m4_n4_n5.GetElementSpaceSize() is
// %ld \n", z_grid_shuffle_desc_m0_n0_m1_n1_m2_n2_m3_n3_m4_n4_n5.GetElementSpaceSize());
//
//}
const
auto
wave_id
=
GetGemm0WaveIdx
();
const
auto
wave_id
=
GetGemm0WaveIdx
();
const
auto
wave_m_n_id
=
GetGemm0WaveMNIdx
(
wave_id
[
I2
]);
// I2: 0~63
const
auto
wave_m_n_id
=
GetGemm0WaveMNIdx
(
wave_id
[
I2
]);
// I2: 0~63
...
@@ -1109,59 +993,6 @@ struct GridwiseBatchedMultiheadAttentionForward_Xdl_CShuffle
...
@@ -1109,59 +993,6 @@ struct GridwiseBatchedMultiheadAttentionForward_Xdl_CShuffle
// \n",wave_m_n_id[I0], wave_m_n_id[I1]);
// \n",wave_m_n_id[I0], wave_m_n_id[I1]);
// }
// }
//}
//}
/*
auto z_tmp_thread_copy_vgpr_to_global = ThreadwiseTensorSliceTransfer_v1r3<
ushort,
ZDataType,
decltype(z_thread_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5),
decltype(z_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5),
tensor_operation::element_wise::PassThrough,
Sequence<I1, // MBlockId
I1, // NBlockID
m0, // MRepeat
n0, // NRepeat
m1, // MWaveId
n1, // NWaveId
m2, // MPerXdl
n2, // NGroupNum
n3, // NInputNum
n4>,
Sequence<0, 1, 2, 3, 4, 5, 6, 7, 8, 9>,
9, // DstVectorDim
1, // DstScalarPerVector
InMemoryDataOperationEnum::Set,
1, // DstScalarStrideInVector
true>{z_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5,
make_multi_index(block_work_idx_m, // MBlockId
0, // NBlockId
0, // mrepeat
0, // nrepeat
wave_id[I0], // MWaveId
wave_id[I1], // NWaveId
wave_m_n_id[I1], // MPerXdl
0, // group
wave_m_n_id[I0], // NInputIndex
0),
tensor_operation::element_wise::PassThrough{}};
auto z_shuffle_thread_copy_global_to_vgpr = ThreadwiseTensorSliceTransfer_v2<
ZDataType,
ushort,
decltype(z_grid_shuffle_desc_m0_n0_m1_n1_m2_n2_m3_n3_m4_n4_n5),
decltype(z_thread_shuffle_desc_m0_n0_m1_n1_m2_n2_m3_n3_m4_n4_n5),
Sequence<I1, I1, m0, n0, m1, n1, m2, n2, n3, n4, I1>,
Sequence<0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10>,
10,
1,
1,
true >{z_grid_shuffle_desc_m0_n0_m1_n1_m2_n2_m3_n3_m4_n4_n5,
make_multi_index(block_work_idx_m, //
MBlockId 0, // NBlockId 0, // mrepeat 0, //
nrepeat wave_id[I0], // MWaveId wave_id[I1], // NWaveId
int(wave_m_n_id[I1] / 4), //
MPerXdl 0, // group wave_m_n_id[I0], // NInputIndex 0,
wave_m_n_id[I1] % 4)};
*/
auto
z_tmp_thread_copy_vgpr_to_lds
=
auto
z_tmp_thread_copy_vgpr_to_lds
=
ThreadwiseTensorSliceTransfer_v1r3
<
ushort
,
ThreadwiseTensorSliceTransfer_v1r3
<
ushort
,
...
@@ -1402,13 +1233,6 @@ struct GridwiseBatchedMultiheadAttentionForward_Xdl_CShuffle
...
@@ -1402,13 +1233,6 @@ struct GridwiseBatchedMultiheadAttentionForward_Xdl_CShuffle
z_block_desc_m0_n0_m1_n1_m2_n2_n3_n4
,
z_block_desc_m0_n0_m1_n1_m2_n2_n3_n4
,
z_block_buf
);
z_block_buf
);
// z_tmp_thread_copy_vgpr_to_global.Run(
// z_thread_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5,
// make_tuple(I0, I0, I0, I0, I0, I0, I0, I0, I0, I0),
// z_tenor_tmp_buffer,
// z_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5,
// z_grid_tmp_buf);
block_sync_lds
();
block_sync_lds
();
// ignore = z_shuffle_thread_copy_lds_to_vgpr;
// ignore = z_shuffle_thread_copy_lds_to_vgpr;
...
@@ -1420,17 +1244,11 @@ struct GridwiseBatchedMultiheadAttentionForward_Xdl_CShuffle
...
@@ -1420,17 +1244,11 @@ struct GridwiseBatchedMultiheadAttentionForward_Xdl_CShuffle
make_tuple
(
I0
,
I0
,
I0
,
I0
,
I0
,
I0
,
I0
,
I0
,
I0
),
make_tuple
(
I0
,
I0
,
I0
,
I0
,
I0
,
I0
,
I0
,
I0
,
I0
),
z_tenor_buffer
);
z_tenor_buffer
);
// z_shuffle_thread_copy_global_to_vgpr.Run(
// z_grid_shuffle_desc_m0_n0_m1_n1_m2_n2_m3_n3_m4_n4_n5,
// z_grid_tmp_buf,
// z_thread_shuffle_desc_m0_n0_m1_n1_m2_n2_m3_n3_m4_n4_n5,
// make_tuple(I0, I0, I0, I0, I0, I0, I0, I0, I0, I0, I0),
// z_tenor_buffer);
blockwise_dropout
.
template
ApplyDropoutWithZ
<
decltype
(
acc_thread_buf
),
blockwise_dropout
.
template
ApplyDropoutWithZ
<
decltype
(
acc_thread_buf
),
decltype
(
z_tenor_buffer
),
decltype
(
z_tenor_buffer
),
false
>(
acc_thread_buf
,
false
>(
acc_thread_buf
,
z_tenor_buffer
);
z_tenor_buffer
);
// ignore = z_tenor_buffer;
// ignore = z_tenor_buffer;
z_thread_copy_vgpr_to_global
.
Run
(
z_thread_copy_vgpr_to_global
.
Run
(
z_thread_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5
,
z_thread_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5
,
...
@@ -1441,14 +1259,6 @@ struct GridwiseBatchedMultiheadAttentionForward_Xdl_CShuffle
...
@@ -1441,14 +1259,6 @@ struct GridwiseBatchedMultiheadAttentionForward_Xdl_CShuffle
block_sync_lds
();
block_sync_lds
();
// z_tmp_thread_copy_vgpr_to_global.MoveDstSliceWindow(
// z_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5,
// make_multi_index(0, 1, 0, 0, 0, 0, 0, 0, 0, 0));
// z_shuffle_thread_copy_global_to_vgpr.MoveSrcSliceWindow(
// z_grid_shuffle_desc_m0_n0_m1_n1_m2_n2_m3_n3_m4_n4_n5,
// make_multi_index(0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0));
z_thread_copy_vgpr_to_global
.
MoveDstSliceWindow
(
z_thread_copy_vgpr_to_global
.
MoveDstSliceWindow
(
z_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5
,
z_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5
,
make_multi_index
(
0
,
1
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
));
make_multi_index
(
0
,
1
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
));
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment