Commit 48df84a4 authored by Po-Yen, Chen
Browse files

Re-arrange template arguments for blockwise copy

parent f8510368
......@@ -111,8 +111,6 @@ struct GridwisePermute
static constexpr auto I1 = Number<1>{};
static constexpr auto I2 = Number<2>{};
using PassThroughOp = tensor_operation::element_wise::PassThrough;
using ThisThreadBlock = ThisThreadBlock<BlockSize>;
using DefaultBlock2TileMap = detail::Block2TileMap<HPerBlock, WPerBlock, InGridDesc>;
......@@ -152,8 +150,6 @@ struct GridwisePermute
auto out_global_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_out_global, out_grid_desc.GetElementSpaceSize());
const auto loop_step_index = make_multi_index(1, 0, 0);
const auto block_work_idx =
block_2_tile_map.CalculateBottomIndex(make_multi_index(get_block_1d_id()));
......@@ -169,20 +165,22 @@ struct GridwisePermute
auto in_block_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(
static_cast<InDataType*>(p_shared), in_block_desc.GetElementSpaceSize());
using SliceLengths = Sequence<1, HPerBlock, WPerBlock>;
using ABlockTransferThreadClusterLengths = Sequence<1, 16, BlockSize / 16>;
using ABlockTransferThreadClusterArrangeOrder = Sequence<0, 1, 2>;
using ABlockTransferSrcAccessOrder = Sequence<0, 1, 2>;
using ABlockTransferDstAccessOrder = Sequence<0, 1, 2>;
using SliceLengths = Sequence<1, HPerBlock, WPerBlock>;
using ABlockTransferThreadClusterLengths = Sequence<1, 16, BlockSize / 16>;
using ABlockTransferThreadClusterArrangeOrder = Sequence<0, 1, 2>;
using ABlockTransferAccessOrder = Sequence<0, 1, 2>;
constexpr index_t ABlockTransferSrcVectorDim = 2;
constexpr index_t ABlockTransferDstVectorDim = 2;
constexpr index_t ABlockTransferDstVectorDim = 1;
constexpr index_t ABlockTransferSrcScalarPerVector = 1;
constexpr index_t ABlockTransferDstScalarPerVector = 1;
using ck::tensor_operation::element_wise::PassThrough;
auto in_global_load =
ThreadGroupTensorSliceTransfer_v4r1<ThisThreadBlock,
ElementwiseOperation,
ck::tensor_operation::element_wise::PassThrough,
PassThrough,
InMemoryDataOperationEnum::Set,
SliceLengths,
ABlockTransferThreadClusterLengths,
......@@ -191,22 +189,22 @@ struct GridwisePermute
InDataType,
decltype(in_grid_desc),
decltype(in_block_desc),
ABlockTransferSrcAccessOrder,
ABlockTransferDstAccessOrder,
ABlockTransferAccessOrder,
ABlockTransferAccessOrder,
ABlockTransferSrcVectorDim,
ABlockTransferSrcVectorDim,
ABlockTransferDstVectorDim,
ABlockTransferSrcScalarPerVector,
ABlockTransferDstScalarPerVector,
ABlockTransferSrcScalarPerVector,
1,
1,
true,
true>(
in_grid_desc,
make_multi_index(0, h_block_data_idx_on_grid, w_block_data_idx_on_grid),
ck::tensor_operation::element_wise::PassThrough{},
PassThrough{},
in_block_desc,
make_multi_index(0, 0, 0),
ck::tensor_operation::element_wise::PassThrough{});
PassThrough{});
auto in_grid_desc_tranformed = transform_tensor_descriptor(
in_grid_desc,
......@@ -216,45 +214,47 @@ struct GridwisePermute
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}),
make_tuple(Sequence<0>{}, Sequence<2>{}, Sequence<1>{}));
auto out_global_store = ThreadGroupTensorSliceTransfer_v4r1<
ThisThreadBlock,
ElementwiseOperation,
ck::tensor_operation::element_wise::PassThrough,
InMemoryDataOperationEnum::Set,
Sequence<1, HPerBlock, WPerBlock>, // SliceLengths
ABlockTransferThreadClusterLengths,
Sequence<0, 1, 2>, // ABlockTransferThreadClusterArrangeOrder
InDataType,
OutDataType,
decltype(in_block_desc),
decltype(in_grid_desc_tranformed),
Sequence<0, 1, 2>, // ABlockTransferSrcAccessOrder
Sequence<0, 1, 2>, // ABlockTransferDstAccessOrder
1, // ABlockTransferSrcVectorDim
1, // ABlockTransferDstVectorDim
1, // ABlockTransferSrcScalarPerVector
1, // ABlockTransferDstScalarPerVector
1,
1,
true,
true>(in_block_desc,
make_multi_index(0, 0, 0),
ck::tensor_operation::element_wise::PassThrough{},
in_grid_desc_tranformed,
make_multi_index(0, h_block_data_idx_on_grid, w_block_data_idx_on_grid),
elementwise_op);
index_t num_iter = in_grid_desc.GetLength(I0);
auto out_global_store =
ThreadGroupTensorSliceTransfer_v4r1<ThisThreadBlock,
ElementwiseOperation,
PassThrough,
InMemoryDataOperationEnum::Set,
SliceLengths,
ABlockTransferThreadClusterLengths,
ABlockTransferThreadClusterArrangeOrder,
InDataType,
OutDataType,
decltype(in_block_desc),
decltype(in_grid_desc_tranformed),
ABlockTransferAccessOrder,
ABlockTransferAccessOrder,
ABlockTransferSrcVectorDim,
ABlockTransferDstVectorDim,
ABlockTransferSrcScalarPerVector,
ABlockTransferDstScalarPerVector,
1,
1,
true,
true>(
in_block_desc,
make_multi_index(0, 0, 0),
PassThrough{},
in_grid_desc_tranformed,
make_multi_index(0, h_block_data_idx_on_grid, w_block_data_idx_on_grid),
elementwise_op);
const auto loop_step = make_multi_index(1, 0, 0);
index_t num_iter = in_grid_desc.GetLength(I0);
do
{
in_global_load.Run(in_grid_desc, in_global_buf, in_block_desc, in_block_buf, I0);
in_global_load.MoveSrcSliceWindow(in_grid_desc, loop_step_index);
in_global_load.MoveSrcSliceWindow(in_grid_desc, loop_step);
out_global_store.Run(
in_block_desc, in_block_buf, in_grid_desc_tranformed, out_global_buf, I0);
out_global_store.MoveDstSliceWindow(in_grid_desc_tranformed, loop_step_index);
out_global_store.MoveDstSliceWindow(in_grid_desc_tranformed, loop_step);
} while(--num_iter);
}
};
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment