gaoqiong / composable_kernel

Commit 1c5b049d authored Oct 19, 2023 by Adam Osewski
Add Cshuffle and results write to GMEM.
parent 92eb966d

Showing 2 changed files with 310 additions and 324 deletions
include/ck/tensor_operation/gpu/device/impl/device_grouped_gemm_multiple_d_splitk_xdl_cshuffle_tile_loop.hpp (+29 -21)
include/ck/tensor_operation/gpu/grid/gridwise_gemm_multiple_d_xdl_splitk_cshuffle_v2.hpp (+281 -303)
include/ck/tensor_operation/gpu/device/impl/device_grouped_gemm_multiple_d_splitk_xdl_cshuffle_tile_loop.hpp
@@ -53,6 +53,7 @@ template <typename GridwiseGemm,
          typename FloatA,
          typename FloatB,
          typename FloatC,
          typename DsDataType,
          typename Block2ETileMapKSplit,
          typename AElementwiseOperation,
          typename BElementwiseOperation,
@@ -124,12 +125,10 @@ __global__ void
        const auto p_a_grid = reinterpret_cast<const FloatA*>(gemm_desc_ptr[group_id].p_a_grid);
        const auto p_b_grid = reinterpret_cast<const FloatB*>(gemm_desc_ptr[group_id].p_b_grid);
        // const auto p_c_grid = reinterpret_cast<FloatC*>(gemm_desc_ptr[group_id].p_c_grid);
        const auto K       = gemm_desc_ptr[group_id].K;
        const auto StrideA = gemm_desc_ptr[group_id].StrideA;
        const auto StrideB = gemm_desc_ptr[group_id].StrideB;
        // const auto StrideC = gemm_desc_ptr[group_id].StrideC;
        auto gridwise_gemm   = GridwiseGemm();
        auto& results_buffer = gridwise_gemm.GetCThreadBuffer();
@@ -159,7 +158,6 @@ __global__ void
        // if (changed group_id || next [M,N] tile)
        if(!b2c_tile_map.IsFirstKSplitBlock())
        {
            // Store partial results to auxilliary workspace.
            gridwise_gemm.StorePartials(p_workspace);
        }
@@ -182,27 +180,33 @@ __global__ void
            gridwise_gemm.AccumulatePartials(p_workspace, flag_v);
            // TODO: do blockwise reduction from workspace (GMEM) to results_buffer (registers)
            // Signal waiting blocks that they can start use their workspace.
            work_scheduler.Reset(k_batch, output_tile_idx, output_tile_idx_offset);
            // TODO do fusion, cshuffle and store results to GMEM
            // gridwise_gemm.RunWrite(results_buffer,
            //                        p_c_grid,
            //                        M,
            //                        N,
            //                        K,
            //                        StrideA,
            //                        StrideB,
            //                        StrideC,
            //                        MPadded,
            //                        NPadded,
            //                        KPadded,
            //                        K0,
            //                        k_batch,
            //                        static_cast<void*>(p_shared),
            //                        b2c_tile_map);
            const auto p_e_grid  = reinterpret_cast<FloatC*>(gemm_desc_ptr[group_id].p_e_grid);
            const auto stride_e  = gemm_desc_ptr[group_id].StrideE;
            const auto stride_ds = gemm_desc_ptr[group_id].StrideDs;
            constexpr auto NumDTensor = DsDataType::Size();
            using DsGridPointer       = decltype(GridwiseGemm::MakeDsGridPointer());
            DsGridPointer p_ds_grid;
            static_for<0, NumDTensor, 1>{}([&](auto i) {
                using DDataType = remove_cvref_t<tuple_element_t<i.value, DsDataType>>;
                // D pointer
                p_ds_grid(i) =
                    static_cast<const DDataType*>(gemm_desc_ptr[group_id].p_ds_grid[i]);
            });
            gridwise_gemm.template RunWrite(p_ds_grid,
                                            p_e_grid,
                                            static_cast<void*>(p_shared),
                                            M,
                                            N,
                                            stride_ds,
                                            stride_e,
                                            cde_element_op,
                                            b2c_tile_map);
        }
        else
        {
@@ -303,6 +307,7 @@ struct DeviceGroupedGemmMultipleDSplitKXdlCShuffle
        AElementwiseOperation,
        BElementwiseOperation,
        CDEElementwiseOperation,
        InMemoryDataOperationEnum::Set,
        GemmSpec,
        NumGemmKPrefetchStage,
        BlockSize,
@@ -687,6 +692,7 @@ struct DeviceGroupedGemmMultipleDSplitKXdlCShuffle
        ADataType,
        BDataType,
        EDataType,
        DsDataType,
        Block2ETileMapKSplit,
        AElementwiseOperation,
        BElementwiseOperation,
@@ -819,6 +825,7 @@ struct DeviceGroupedGemmMultipleDSplitKXdlCShuffle
        ADataType,
        BDataType,
        EDataType,
        DsDataType,
        Block2ETileMapKSplit,
        AElementwiseOperation,
        BElementwiseOperation,
@@ -861,6 +868,7 @@ struct DeviceGroupedGemmMultipleDSplitKXdlCShuffle
        ADataType,
        BDataType,
        EDataType,
        DsDataType,
        Block2ETileMapKSplit,
        AElementwiseOperation,
        BElementwiseOperation,
include/ck/tensor_operation/gpu/grid/gridwise_gemm_multiple_d_xdl_splitk_cshuffle_v2.hpp
@@ -44,6 +44,7 @@ template <typename ADataType,
          typename AElementwiseOperation,
          typename BElementwiseOperation,
          typename CDEElementwiseOperation,
          InMemoryDataOperationEnum EGlobalMemoryDataOperation,
          tensor_operation::device::GemmSpecialization GemmSpec,
          index_t NumGemmKPrefetchStage,
          index_t BlockSize,
@@ -696,6 +697,7 @@ class GridwiseGemmMultipleD_xdl_splitk_cshuffle_v2
            block_2_etile_map);
    }
    // TODO Need to do CShuffle already here:
    __device__ void StorePartials(void* __restrict__ p_workspace)
    {
        // M0 = grid_size
@@ -849,15 +851,6 @@ class GridwiseGemmMultipleD_xdl_splitk_cshuffle_v2
        const auto w_grid_m0 = workspace_grid_desc_m0_n0_m1_n1.GetLength(I0);
        const auto w_grid_n0 = workspace_grid_desc_m0_n0_m1_n1.GetLength(I1);
        // if (threadIdx.x == 0)
        // {
        //     printf("w_grid_desc_m0_n0_m1_n1: [%d, %d, %d, %d]\n",
        //            workspace_grid_desc_m0_n0_m1_n1.GetLength(I0),
        //            workspace_grid_desc_m0_n0_m1_n1.GetLength(I1),
        //            workspace_grid_desc_m0_n0_m1_n1.GetLength(I2),
        //            workspace_grid_desc_m0_n0_m1_n1.GetLength(I3));
        // }
        // c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp is only used to get lengths
        constexpr auto c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp =
            BlockwiseGemmT::GetCBlockDescriptor_M0_N0_M1_N1_M2_M3_M4_N2();
@@ -981,300 +974,285 @@ class GridwiseGemmMultipleD_xdl_splitk_cshuffle_v2
            }
        }
// template <typename CThreadBufer,
// InMemoryDataOperationEnum EGlobalMemoryDataOperation,
// index_t NumDTensor_,
// typename DsDataType_,
// typename DsGridDesc_MBlock_MPerBlock_NBlock_NPerBlock,
// typename EGridDesc_MBlock_MPerBlock_NBlock_NPerBlock,
// typename CDEElementwiseOperation_,
// typename Block2ETileMap>
// __device__ void RunWrite(CThreadBufer c_thread_buf,
// const EDataType* __restrict__ p_workspace,
// DsGridPointer p_ds_grid,
// EDataType* __restrict__ p_e_grid,
// void* __restrict__ p_shared,
// const index_t KBatch,
// const CDEElementwiseOperation_& cde_element_op,
// const DsGridDesc_MBlock_MPerBlock_NBlock_NPerBlock&
// ds_grid_desc_mblock_mperblock_nblock_nperblock,
// const EGridDesc_MBlock_MPerBlock_NBlock_NPerBlock&
// e_grid_desc_mblock_mperblock_nblock_nperblock,
// const Block2ETileMap& block_2_etile_map)
// {
// using DsGridDesc_M_N =
// remove_cvref_t<decltype(MakeDsGridDescriptor_M_N<DsLayout, GemmSpec>({}, {}, {}))>;
// DsGridDesc_M_N ds_grid_desc_m_n;
// const auto ds_grid_buf = generate_tuple(
// [&](auto i) {
// return make_dynamic_buffer<AddressSpaceEnum::Global>(
// p_ds_grid[i],
// ds_grid_desc_mblock_mperblock_nblock_nperblock[i].GetElementSpaceSize());
// },
// Number<NumDTensor_>{});
// auto e_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
// p_e_grid, e_grid_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize());
// static_for<0, NumDTensor, 1>{}([&](auto j) {
// using DLayout = remove_cvref_t<tuple_element_t<j.value, DsLayout>>;
// ds_grid_desc_m_n(j) = MakeEGridDescriptor_M_N<DLayout>(M, N, StrideDs[j]);
// });
// const auto e_grid_desc_m_n = MakeEGridDescriptor_M_N<ELayout>(M, N, StrideE);
// // using DsGridDesc_MBlock_MPerBlock_NBlock_NPerBlock =
// // remove_cvref_t<decltype(MakeDsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(
// // DsGridDesc_M_N{}))>;
// // DsGridDesc_MBlock_MPerBlock_NBlock_NPerBlock
// ds_grid_desc_mblock_mperblock_nblock_nperblock;
// // static_for<0, NumDTensor, 1>{}([&](auto j) {
// // ds_grid_desc_mblock_mperblock_nblock_nperblock(j) =
// // MakeEGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(ds_grid_desc_m_n[j]);
// // });
// // const auto e_grid_desc_mblock_mperblock_nblock_nperblock =
// // MakeEGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(e_grid_desc_m_n);
// // shuffle C and write out
// static_assert(MXdlPerWave % CShuffleMXdlPerWavePerShuffle == 0 &&
// NXdlPerWave % CShuffleNXdlPerWavePerShuffle == 0,
// "wrong!");
// constexpr index_t MWave = MPerBlock / (MXdlPerWave * MPerXdl);
// constexpr index_t NWave = NPerBlock / (NXdlPerWave * NPerXdl);
// // TODO: hacky, fix it!
// constexpr auto c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2 =
// blockwise_gemm.GetCThreadDescriptor_M0_N0_M1_N1_M2_M3_M4_N2();
// // TODO: hacky, fix it!
// // c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp is only used to get lengths
// constexpr auto c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp =
// blockwise_gemm.GetCBlockDescriptor_M0_N0_M1_N1_M2_M3_M4_N2();
// constexpr auto M0 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I0);
// constexpr auto N0 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I1);
// constexpr auto M1 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I2);
// constexpr auto N1 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I3);
// constexpr auto M2 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I4);
// constexpr auto M3 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I5);
// constexpr auto M4 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I6);
// constexpr auto N2 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I7);
// constexpr auto c_shuffle_block_desc_mblock_mperblock_nblock_nperblock =
// GetCShuffleBlockDescriptor_MBlock_MPerBlock_NBlock_NPerBlock();
// auto c_shuffle_block_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(
// static_cast<CShuffleDataType*>(p_shared),
// c_shuffle_block_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize());
// constexpr auto c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2 = transform_tensor_descriptor(
// c_shuffle_block_desc_mblock_mperblock_nblock_nperblock,
// make_tuple(
// make_freeze_transform(I0),
// make_unmerge_transform(make_tuple(
// Number<CShuffleMXdlPerWavePerShuffle>{}, // M0 (MXdlPerWave) per shuffle
// M1, // M1 = MWave
// M2, // M2 * M3 * M4 = MPerXdl
// M3,
// M4)),
// make_freeze_transform(I0),
// make_unmerge_transform(make_tuple(
// Number<CShuffleNXdlPerWavePerShuffle>{}, // N0 (NXdlPerWave) per shuffle
// N1, // N1 = NWave
// N2))), // N2 = NPerXdl
// make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}),
// make_tuple(
// Sequence<>{}, Sequence<0, 2, 4, 5, 6>{}, Sequence<>{}, Sequence<1, 3, 7>{}));
// // calculate origin of thread output tensor on global memory
// // blockwise GEMM c matrix starting index
// const auto c_thread_mtx_on_block =
// blockwise_gemm.CalculateCThreadOriginDataIndex(I0, I0, I0, I0);
// const index_t m_thread_data_on_block = c_thread_mtx_on_block[I0];
// const index_t n_thread_data_on_block = c_thread_mtx_on_block[I1];
// const auto m_thread_data_on_block_to_m0_m1_m2_m3_m4_adaptor =
// make_single_stage_tensor_adaptor(
// make_tuple(make_merge_transform(make_tuple(M0, M1, M2, M3, M4))),
// make_tuple(Sequence<0, 1, 2, 3, 4>{}),
// make_tuple(Sequence<0>{}));
// const auto m_thread_data_on_block_idx =
// m_thread_data_on_block_to_m0_m1_m2_m3_m4_adaptor.CalculateBottomIndex(
// make_multi_index(m_thread_data_on_block));
// const auto n_thread_data_on_block_to_n0_n1_n2_adaptor =
// make_single_stage_tensor_adaptor(
// make_tuple(make_merge_transform(make_tuple(N0, N1, N2))),
// make_tuple(Sequence<0, 1, 2>{}),
// make_tuple(Sequence<0>{}));
// const auto n_thread_data_on_block_idx =
// n_thread_data_on_block_to_n0_n1_n2_adaptor.CalculateBottomIndex(
// make_multi_index(n_thread_data_on_block));
// // shuffle: threadwise copy C from VGPR to LDS
// auto c_thread_copy_vgpr_to_lds =
// ThreadwiseTensorSliceTransfer_v1r3<AccDataType,
// CShuffleDataType,
// decltype(c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2),
// decltype(c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2),
// ck::tensor_operation::element_wise::PassThrough,
// Sequence<CShuffleMXdlPerWavePerShuffle,
// CShuffleNXdlPerWavePerShuffle,
// I1,
// I1,
// M2,
// I1,
// M4,
// I1>,
// Sequence<0, 1, 2, 3, 4, 5, 6, 7>,
// 7,
// 1,
// InMemoryDataOperationEnum::Set,
// 1,
// true>{
// c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2,
// make_multi_index(0,
// 0,
// m_thread_data_on_block_idx[I1],
// n_thread_data_on_block_idx[I1],
// m_thread_data_on_block_idx[I2],
// m_thread_data_on_block_idx[I3],
// m_thread_data_on_block_idx[I4],
// n_thread_data_on_block_idx[I2]),
// ck::tensor_operation::element_wise::PassThrough{}};
// // tuple of reference to C/Ds tensor descriptors
// const auto c_ds_desc_refs = concat_tuple_of_reference(
// tie(c_shuffle_block_desc_mblock_mperblock_nblock_nperblock),
// generate_tie(
// [&](auto i) -> const auto& // return type should be reference
// { return ds_grid_desc_mblock_mperblock_nblock_nperblock[i]; },
// Number<NumDTensor_>{}));
// // tuple of reference to C/Ds tensor descriptors
// const auto c_ds_buf_refs = concat_tuple_of_reference(
// tie(c_shuffle_block_buf),
// generate_tie(
// [&](auto i) -> const auto& // return type should be reference
// { return ds_grid_buf[i]; },
// Number<NumDTensor_>{}));
// // tuple of starting index of C/Ds blockwise copy
// const auto idx_c_ds_block_begin = container_concat(
// make_tuple(make_multi_index(0, 0, 0, 0)),
// generate_tuple(
// [&](auto) {
// return make_multi_index(block_work_idx[I1], 0, block_work_idx[I2], 0);
// },
// Number<NumDTensor_>{}));
// // space filling curve for threadwise C in VGPR before shuffle
// constexpr auto sfc_c_vgpr =
// SpaceFillingCurve<Sequence<MXdlPerWave, NXdlPerWave, 1, 1, M2, 1, M4, 1>,
// Sequence<0, 1, 2, 3, 4, 5, 6, 7>,
// Sequence<CShuffleMXdlPerWavePerShuffle,
// CShuffleNXdlPerWavePerShuffle,
// 1,
// 1,
// M2,
// 1,
// M4,
// 1>>{};
// // space filling curve for shuffled blockwise C/D/E
// constexpr auto sfc_cde_block =
// SpaceFillingCurve<Sequence<1, MPerBlock, 1, NPerBlock>,
// Sequence<0, 2, 1, 3>,
// Sequence<1,
// CShuffleMXdlPerWavePerShuffle * MWave * MPerXdl,
// 1,
// CShuffleNXdlPerWavePerShuffle * NWave * NPerXdl>>{};
// constexpr index_t num_access = sfc_c_vgpr.GetNumOfAccess();
// static_assert(num_access == sfc_cde_block.GetNumOfAccess(), "wrong!");
// // blockwise copy C/D/E between LDS and global
// auto cde_block_copy_lds_and_global = ThreadGroupTensorSliceTransfer_v7<
// ThisThreadBlock,
// decltype(container_concat(make_tuple(CShuffleDataType{}), DsDataType_{})),
// Tuple<EDataType>,
// decltype(c_ds_desc_refs),
// decltype(tie(e_grid_desc_mblock_mperblock_nblock_nperblock)),
// CDEElementwiseOperation_,
// Sequence<static_cast<index_t>(EGlobalMemoryDataOperation)>, // FIXME: make
// // Sequence support
// // arbitray type
// Sequence<1,
// CShuffleMXdlPerWavePerShuffle * MWave * MPerXdl,
// 1,
// CShuffleNXdlPerWavePerShuffle * NWave * NPerXdl>, // BlockSliceLengths,
// CDEBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
// Sequence<0, 1, 2, 3>, // typename ThreadClusterArrangeOrder,
// Sequence<0, 1, 2, 3>, // typename DimAccessOrder,
// 3, // index_t VectorDim,
// CDEShuffleBlockTransferScalarPerVector_NPerBlock,
// sequence_merge_t<
// Sequence<true>,
// uniform_sequence_gen_t<NumDTensor_,
// false>>, // ThreadTransferSrcResetCoordinateAfterRunFlags
// Sequence<false>> // ThreadTransferDstResetCoordinateAfterRunFlags
// {c_ds_desc_refs,
// idx_c_ds_block_begin,
// tie(e_grid_desc_mblock_mperblock_nblock_nperblock),
// make_tuple(make_multi_index(block_work_idx[I1], 0, block_work_idx[I2], 0)),
// cde_element_op};
// static_for<0, num_access, 1>{}([&](auto access_id) {
// // make sure it's safe to write to LDS
// block_sync_lds();
// // each thread write its data from VGPR to LDS
// c_thread_copy_vgpr_to_lds.Run(c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2,
// sfc_c_vgpr.GetIndexTupleOfNumber(access_id),
// c_thread_buf,
// c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2,
// c_shuffle_block_buf);
// // make sure it's safe to read from LDS
// block_sync_lds();
// // each block copy its data from LDS to global
// cde_block_copy_lds_and_global.Run(
// c_ds_desc_refs,
// c_ds_buf_refs,
// tie(e_grid_desc_mblock_mperblock_nblock_nperblock),
// tie(e_grid_buf));
// if constexpr(access_id < num_access - 1)
// {
// constexpr auto cde_lds_and_global_step =
// sfc_cde_block.GetForwardStep(access_id);
// // move on Ds
// static_for<0, NumDTensor_, 1>{}([&](auto i) {
// cde_block_copy_lds_and_global.MoveSrcSliceWindow(
// c_ds_desc_refs, i + I1, cde_lds_and_global_step);
// });
// // move on E
// cde_block_copy_lds_and_global.MoveDstSliceWindow(
// tie(e_grid_desc_mblock_mperblock_nblock_nperblock),
// I0,
// cde_lds_and_global_step);
// }
// });
// }
    template <typename Block2ETileMap>
    __device__ void RunWrite(DsGridPointer p_ds_grid,
                             EDataType* __restrict__ p_e_grid,
                             void* __restrict__ p_shared,
                             const index_t M,
                             const index_t N,
                             const std::array<index_t, NumDTensor> StrideDs,
                             const index_t StrideE,
                             const CDEElementwiseOperation& cde_element_op,
                             const Block2ETileMap& block_2_etile_map)
    {
        using DsGridDesc_M_N =
            remove_cvref_t<decltype(MakeDsGridDescriptor_M_N({}, {}, {}))>;
        using DsGridDesc_MBlock_MPerBlock_NBlock_NPerBlock = remove_cvref_t<decltype(
            MakeDsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(DsGridDesc_M_N{}))>;

        DsGridDesc_M_N ds_grid_desc_m_n;
        DsGridDesc_MBlock_MPerBlock_NBlock_NPerBlock
            ds_grid_desc_mblock_mperblock_nblock_nperblock;

        static_for<0, NumDTensor, 1>{}([&](auto j) {
            using DLayout       = remove_cvref_t<tuple_element_t<j.value, DsLayout>>;
            ds_grid_desc_m_n(j) = MakeEGridDescriptor_M_N<DLayout>(M, N, StrideDs[j]);
        });

        static_for<0, NumDTensor, 1>{}([&](auto j) {
            ds_grid_desc_mblock_mperblock_nblock_nperblock(j) =
                MakeEGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(ds_grid_desc_m_n[j]);
        });

        const auto ds_grid_buf = generate_tuple(
            [&](auto i) {
                return make_dynamic_buffer<AddressSpaceEnum::Global>(
                    p_ds_grid[i],
                    ds_grid_desc_mblock_mperblock_nblock_nperblock[i].GetElementSpaceSize());
            },
            Number<NumDTensor>{});

        const auto e_grid_desc_m_n = MakeEGridDescriptor_M_N<ELayout>(M, N, StrideE);
        const auto e_grid_desc_mblock_mperblock_nblock_nperblock =
            MakeEGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(e_grid_desc_m_n);

        auto e_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
            p_e_grid, e_grid_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize());

        const auto& c_thread_buf = blockwise_gemm_.GetCThreadBuffer();

        // shuffle C and write out
        static_assert(MXdlPerWave % CShuffleMXdlPerWavePerShuffle == 0 &&
                          NXdlPerWave % CShuffleNXdlPerWavePerShuffle == 0,
                      "wrong!");

        constexpr index_t MWave = MPerBlock / (MXdlPerWave * MPerXdl);
        constexpr index_t NWave = NPerBlock / (NXdlPerWave * NPerXdl);

        // divide block work by [M, N, K]
        const auto block_work_idx = block_2_etile_map.GetBottomIndex();

        // TODO: hacky, fix it!
        constexpr auto c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2 =
            BlockwiseGemmT::GetCThreadDescriptor_M0_N0_M1_N1_M2_M3_M4_N2();

        // TODO: hacky, fix it!
        // c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp is only used to get lengths
        constexpr auto c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp =
            BlockwiseGemmT::GetCBlockDescriptor_M0_N0_M1_N1_M2_M3_M4_N2();

        constexpr auto M0 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I0);
        constexpr auto N0 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I1);
        constexpr auto M1 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I2);
        constexpr auto N1 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I3);
        constexpr auto M2 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I4);
        constexpr auto M3 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I5);
        constexpr auto M4 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I6);
        constexpr auto N2 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I7);

        constexpr auto c_shuffle_block_desc_mblock_mperblock_nblock_nperblock =
            GetCShuffleBlockDescriptor_MBlock_MPerBlock_NBlock_NPerBlock();

        auto c_shuffle_block_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(
            static_cast<CShuffleDataType*>(p_shared),
            c_shuffle_block_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize());

        constexpr auto c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2 = transform_tensor_descriptor(
            c_shuffle_block_desc_mblock_mperblock_nblock_nperblock,
            make_tuple(
                make_freeze_transform(I0),
                make_unmerge_transform(make_tuple(
                    Number<CShuffleMXdlPerWavePerShuffle>{}, // M0 (MXdlPerWave) per shuffle
                    M1,                                      // M1 = MWave
                    M2,                                      // M2 * M3 * M4 = MPerXdl
                    M3,
                    M4)),
                make_freeze_transform(I0),
                make_unmerge_transform(make_tuple(
                    Number<CShuffleNXdlPerWavePerShuffle>{}, // N0 (NXdlPerWave) per shuffle
                    N1,                                      // N1 = NWave
                    N2))),                                   // N2 = NPerXdl
            make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}),
            make_tuple(
                Sequence<>{}, Sequence<0, 2, 4, 5, 6>{}, Sequence<>{}, Sequence<1, 3, 7>{}));

        // calculate origin of thread output tensor on global memory
        // blockwise GEMM c matrix starting index
        const auto c_thread_mtx_on_block =
            blockwise_gemm_.CalculateCThreadOriginDataIndex(I0, I0, I0, I0);

        const index_t m_thread_data_on_block = c_thread_mtx_on_block[I0];
        const index_t n_thread_data_on_block = c_thread_mtx_on_block[I1];

        const auto m_thread_data_on_block_to_m0_m1_m2_m3_m4_adaptor =
            make_single_stage_tensor_adaptor(
                make_tuple(make_merge_transform(make_tuple(M0, M1, M2, M3, M4))),
                make_tuple(Sequence<0, 1, 2, 3, 4>{}),
                make_tuple(Sequence<0>{}));

        const auto m_thread_data_on_block_idx =
            m_thread_data_on_block_to_m0_m1_m2_m3_m4_adaptor.CalculateBottomIndex(
                make_multi_index(m_thread_data_on_block));

        const auto n_thread_data_on_block_to_n0_n1_n2_adaptor =
            make_single_stage_tensor_adaptor(
                make_tuple(make_merge_transform(make_tuple(N0, N1, N2))),
                make_tuple(Sequence<0, 1, 2>{}),
                make_tuple(Sequence<0>{}));

        const auto n_thread_data_on_block_idx =
            n_thread_data_on_block_to_n0_n1_n2_adaptor.CalculateBottomIndex(
                make_multi_index(n_thread_data_on_block));

        // shuffle: threadwise copy C from VGPR to LDS
        auto c_thread_copy_vgpr_to_lds =
            ThreadwiseTensorSliceTransfer_v1r3<AccDataType,
                                               CShuffleDataType,
                                               decltype(c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2),
                                               decltype(c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2),
                                               ck::tensor_operation::element_wise::PassThrough,
                                               Sequence<CShuffleMXdlPerWavePerShuffle,
                                                        CShuffleNXdlPerWavePerShuffle,
                                                        I1,
                                                        I1,
                                                        M2,
                                                        I1,
                                                        M4,
                                                        I1>,
                                               Sequence<0, 1, 2, 3, 4, 5, 6, 7>,
                                               7,
                                               1,
                                               InMemoryDataOperationEnum::Set,
                                               1,
                                               true>{
                c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2,
                make_multi_index(0,
                                 0,
                                 m_thread_data_on_block_idx[I1],
                                 n_thread_data_on_block_idx[I1],
                                 m_thread_data_on_block_idx[I2],
                                 m_thread_data_on_block_idx[I3],
                                 m_thread_data_on_block_idx[I4],
                                 n_thread_data_on_block_idx[I2]),
                ck::tensor_operation::element_wise::PassThrough{}};

        // tuple of reference to C/Ds tensor descriptors
        const auto c_ds_desc_refs = concat_tuple_of_reference(
            tie(c_shuffle_block_desc_mblock_mperblock_nblock_nperblock),
            generate_tie(
                [&](auto i) -> const auto& // return type should be reference
                { return ds_grid_desc_mblock_mperblock_nblock_nperblock[i]; },
                Number<NumDTensor>{}));

        // tuple of reference to C/Ds tensor descriptors
        const auto c_ds_buf_refs = concat_tuple_of_reference(
            tie(c_shuffle_block_buf),
            generate_tie(
                [&](auto i) -> const auto& // return type should be reference
                { return ds_grid_buf[i]; },
                Number<NumDTensor>{}));

        // tuple of starting index of C/Ds blockwise copy
        const auto idx_c_ds_block_begin = container_concat(
            make_tuple(make_multi_index(0, 0, 0, 0)),
            generate_tuple(
                [&](auto) {
                    return make_multi_index(block_work_idx[I0], 0, block_work_idx[I1], 0);
                },
                Number<NumDTensor>{}));

        // space filling curve for threadwise C in VGPR before shuffle
        constexpr auto sfc_c_vgpr =
            SpaceFillingCurve<Sequence<MXdlPerWave, NXdlPerWave, 1, 1, M2, 1, M4, 1>,
                              Sequence<0, 1, 2, 3, 4, 5, 6, 7>,
                              Sequence<CShuffleMXdlPerWavePerShuffle,
                                       CShuffleNXdlPerWavePerShuffle,
                                       1,
                                       1,
                                       M2,
                                       1,
                                       M4,
                                       1>>{};

        // space filling curve for shuffled blockwise C/D/E
        constexpr auto sfc_cde_block =
            SpaceFillingCurve<Sequence<1, MPerBlock, 1, NPerBlock>,
                              Sequence<0, 2, 1, 3>,
                              Sequence<1,
                                       CShuffleMXdlPerWavePerShuffle * MWave * MPerXdl,
                                       1,
                                       CShuffleNXdlPerWavePerShuffle * NWave * NPerXdl>>{};

        constexpr index_t num_access = sfc_c_vgpr.GetNumOfAccess();
        static_assert(num_access == sfc_cde_block.GetNumOfAccess(), "wrong!");

        // blockwise copy C/D/E between LDS and global
        auto cde_block_copy_lds_and_global = ThreadGroupTensorSliceTransfer_v7<
            ThisThreadBlock,
            decltype(container_concat(make_tuple(CShuffleDataType{}), DsDataType{})),
            Tuple<EDataType>,
            decltype(c_ds_desc_refs),
            decltype(tie(e_grid_desc_mblock_mperblock_nblock_nperblock)),
            CDEElementwiseOperation,
            Sequence<static_cast<index_t>(EGlobalMemoryDataOperation)>, // FIXME: make
                                                                        // Sequence support
                                                                        // arbitray type
            Sequence<1,
                     CShuffleMXdlPerWavePerShuffle * MWave * MPerXdl,
                     1,
                     CShuffleNXdlPerWavePerShuffle * NWave * NPerXdl>, // BlockSliceLengths,
            CDEBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
            Sequence<0, 1, 2, 3>, // typename ThreadClusterArrangeOrder,
            Sequence<0, 1, 2, 3>, // typename DimAccessOrder,
            3,                    // index_t VectorDim,
            CDEShuffleBlockTransferScalarPerVector_NPerBlock,
            sequence_merge_t<
                Sequence<true>,
                uniform_sequence_gen_t<NumDTensor,
                                       false>>, // ThreadTransferSrcResetCoordinateAfterRunFlags
            Sequence<false>>                    // ThreadTransferDstResetCoordinateAfterRunFlags
            {c_ds_desc_refs,
             idx_c_ds_block_begin,
             tie(e_grid_desc_mblock_mperblock_nblock_nperblock),
             make_tuple(make_multi_index(block_work_idx[I0], 0, block_work_idx[I1], 0)),
             cde_element_op};

        static_for<0, num_access, 1>{}([&](auto access_id) {
            // make sure it's safe to write to LDS
            block_sync_lds();

            // each thread write its data from VGPR to LDS
            c_thread_copy_vgpr_to_lds.Run(c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2,
                                          sfc_c_vgpr.GetIndexTupleOfNumber(access_id),
                                          c_thread_buf,
                                          c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2,
                                          c_shuffle_block_buf);

            // make sure it's safe to read from LDS
            block_sync_lds();

            // each block copy its data from LDS to global
            cde_block_copy_lds_and_global.Run(
                c_ds_desc_refs,
                c_ds_buf_refs,
                tie(e_grid_desc_mblock_mperblock_nblock_nperblock),
                tie(e_grid_buf));

            if constexpr(access_id < num_access - 1)
            {
                constexpr auto cde_lds_and_global_step =
                    sfc_cde_block.GetForwardStep(access_id);

                // move on Ds
                static_for<0, NumDTensor, 1>{}([&](auto i) {
                    cde_block_copy_lds_and_global.MoveSrcSliceWindow(
                        c_ds_desc_refs, i + I1, cde_lds_and_global_step);
                });

                // move on E
                cde_block_copy_lds_and_global.MoveDstSliceWindow(
                    tie(e_grid_desc_mblock_mperblock_nblock_nperblock),
                    I0,
                    cde_lds_and_global_step);
            }
        });
    }
};

} // namespace ck
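Background note (not part of the diff): the kernel in the first file coordinates the split-K flow around this RunWrite. Blocks owning earlier K-slices call StorePartials into the GMEM workspace, and the last block per output tile calls AccumulatePartials, resets the work scheduler, and runs the CShuffle write-out above. The standalone CUDA sketch below shows that coordination pattern in isolation; the kernel name, the trivial per-slice "partial", and the atomic counter are invented for illustration and are not the composable_kernel implementation.

// splitk_sketch.cu -- illustrative only; names and the trivial partials are invented.
// Each output tile is worked on by k_batch thread blocks, one per K-slice.
// Every block parks its partial result in a global workspace; the last block
// to arrive (tracked with an atomic counter, analogous to work_scheduler in the
// diff) reduces the partials and performs the final write to GMEM, where the
// CShuffle epilogue of RunWrite would normally run.
#include <cuda_runtime.h>
#include <cstdio>

__global__ void splitk_tile(float* __restrict__ workspace,   // k_batch partials per tile
                            float* __restrict__ e,           // final output
                            unsigned* __restrict__ counters, // one counter per output tile
                            int tile_elems,
                            int k_batch)
{
    const int tile_id = blockIdx.x / k_batch; // which [M,N] output tile
    const int k_id    = blockIdx.x % k_batch; // which K-slice of that tile

    // 1. "Compute" and store this block's partial result (here simply k_id + 1).
    float* my_slot = workspace + (tile_id * k_batch + k_id) * tile_elems;
    for(int i = threadIdx.x; i < tile_elems; i += blockDim.x)
        my_slot[i] = static_cast<float>(k_id + 1);
    __threadfence(); // make the partials visible to the other blocks

    // 2. The last block for this tile reduces all partials and writes E.
    __shared__ bool is_last;
    if(threadIdx.x == 0)
        is_last = (atomicAdd(&counters[tile_id], 1u) == unsigned(k_batch) - 1u);
    __syncthreads();
    if(!is_last)
        return;

    for(int i = threadIdx.x; i < tile_elems; i += blockDim.x)
    {
        float acc = 0.f;
        for(int k = 0; k < k_batch; ++k)
            acc += workspace[(tile_id * k_batch + k) * tile_elems + i];
        e[tile_id * tile_elems + i] = acc; // a fused elementwise/CShuffle epilogue would go here
    }
    if(threadIdx.x == 0)
        counters[tile_id] = 0; // allow workspace reuse, cf. work_scheduler.Reset in the diff
}

int main()
{
    const int tiles = 4, k_batch = 3, tile_elems = 256;
    float *ws, *e;
    unsigned* counters;
    cudaMalloc(&ws, tiles * k_batch * tile_elems * sizeof(float));
    cudaMalloc(&e, tiles * tile_elems * sizeof(float));
    cudaMalloc(&counters, tiles * sizeof(unsigned));
    cudaMemset(counters, 0, tiles * sizeof(unsigned));
    splitk_tile<<<tiles * k_batch, 128>>>(ws, e, counters, tile_elems, k_batch);
    float e0;
    cudaMemcpy(&e0, e, sizeof(float), cudaMemcpyDeviceToHost);
    printf("e[0] = %f (expected %d)\n", e0, k_batch * (k_batch + 1) / 2); // 1+2+3 = 6
    cudaFree(ws);
    cudaFree(e);
    cudaFree(counters);
    return 0;
}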