Commit 48df84a4 authored by Po-Yen, Chen's avatar Po-Yen, Chen
Browse files

Re-arrange template arguments for blockwise copy

parent f8510368
...@@ -111,8 +111,6 @@ struct GridwisePermute ...@@ -111,8 +111,6 @@ struct GridwisePermute
static constexpr auto I1 = Number<1>{}; static constexpr auto I1 = Number<1>{};
static constexpr auto I2 = Number<2>{}; static constexpr auto I2 = Number<2>{};
using PassThroughOp = tensor_operation::element_wise::PassThrough;
using ThisThreadBlock = ThisThreadBlock<BlockSize>; using ThisThreadBlock = ThisThreadBlock<BlockSize>;
using DefaultBlock2TileMap = detail::Block2TileMap<HPerBlock, WPerBlock, InGridDesc>; using DefaultBlock2TileMap = detail::Block2TileMap<HPerBlock, WPerBlock, InGridDesc>;
...@@ -152,8 +150,6 @@ struct GridwisePermute ...@@ -152,8 +150,6 @@ struct GridwisePermute
auto out_global_buf = make_dynamic_buffer<AddressSpaceEnum::Global>( auto out_global_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_out_global, out_grid_desc.GetElementSpaceSize()); p_out_global, out_grid_desc.GetElementSpaceSize());
const auto loop_step_index = make_multi_index(1, 0, 0);
const auto block_work_idx = const auto block_work_idx =
block_2_tile_map.CalculateBottomIndex(make_multi_index(get_block_1d_id())); block_2_tile_map.CalculateBottomIndex(make_multi_index(get_block_1d_id()));
...@@ -169,20 +165,22 @@ struct GridwisePermute ...@@ -169,20 +165,22 @@ struct GridwisePermute
auto in_block_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>( auto in_block_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(
static_cast<InDataType*>(p_shared), in_block_desc.GetElementSpaceSize()); static_cast<InDataType*>(p_shared), in_block_desc.GetElementSpaceSize());
using SliceLengths = Sequence<1, HPerBlock, WPerBlock>; using SliceLengths = Sequence<1, HPerBlock, WPerBlock>;
using ABlockTransferThreadClusterLengths = Sequence<1, 16, BlockSize / 16>; using ABlockTransferThreadClusterLengths = Sequence<1, 16, BlockSize / 16>;
using ABlockTransferThreadClusterArrangeOrder = Sequence<0, 1, 2>; using ABlockTransferThreadClusterArrangeOrder = Sequence<0, 1, 2>;
using ABlockTransferSrcAccessOrder = Sequence<0, 1, 2>; using ABlockTransferAccessOrder = Sequence<0, 1, 2>;
using ABlockTransferDstAccessOrder = Sequence<0, 1, 2>;
constexpr index_t ABlockTransferSrcVectorDim = 2; constexpr index_t ABlockTransferSrcVectorDim = 2;
constexpr index_t ABlockTransferDstVectorDim = 2; constexpr index_t ABlockTransferDstVectorDim = 1;
constexpr index_t ABlockTransferSrcScalarPerVector = 1; constexpr index_t ABlockTransferSrcScalarPerVector = 1;
constexpr index_t ABlockTransferDstScalarPerVector = 1; constexpr index_t ABlockTransferDstScalarPerVector = 1;
using ck::tensor_operation::element_wise::PassThrough;
auto in_global_load = auto in_global_load =
ThreadGroupTensorSliceTransfer_v4r1<ThisThreadBlock, ThreadGroupTensorSliceTransfer_v4r1<ThisThreadBlock,
ElementwiseOperation, ElementwiseOperation,
ck::tensor_operation::element_wise::PassThrough, PassThrough,
InMemoryDataOperationEnum::Set, InMemoryDataOperationEnum::Set,
SliceLengths, SliceLengths,
ABlockTransferThreadClusterLengths, ABlockTransferThreadClusterLengths,
...@@ -191,22 +189,22 @@ struct GridwisePermute ...@@ -191,22 +189,22 @@ struct GridwisePermute
InDataType, InDataType,
decltype(in_grid_desc), decltype(in_grid_desc),
decltype(in_block_desc), decltype(in_block_desc),
ABlockTransferSrcAccessOrder, ABlockTransferAccessOrder,
ABlockTransferDstAccessOrder, ABlockTransferAccessOrder,
ABlockTransferSrcVectorDim,
ABlockTransferSrcVectorDim, ABlockTransferSrcVectorDim,
ABlockTransferDstVectorDim,
ABlockTransferSrcScalarPerVector, ABlockTransferSrcScalarPerVector,
ABlockTransferDstScalarPerVector, ABlockTransferSrcScalarPerVector,
1, 1,
1, 1,
true, true,
true>( true>(
in_grid_desc, in_grid_desc,
make_multi_index(0, h_block_data_idx_on_grid, w_block_data_idx_on_grid), make_multi_index(0, h_block_data_idx_on_grid, w_block_data_idx_on_grid),
ck::tensor_operation::element_wise::PassThrough{}, PassThrough{},
in_block_desc, in_block_desc,
make_multi_index(0, 0, 0), make_multi_index(0, 0, 0),
ck::tensor_operation::element_wise::PassThrough{}); PassThrough{});
auto in_grid_desc_tranformed = transform_tensor_descriptor( auto in_grid_desc_tranformed = transform_tensor_descriptor(
in_grid_desc, in_grid_desc,
...@@ -216,45 +214,47 @@ struct GridwisePermute ...@@ -216,45 +214,47 @@ struct GridwisePermute
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}), make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}),
make_tuple(Sequence<0>{}, Sequence<2>{}, Sequence<1>{})); make_tuple(Sequence<0>{}, Sequence<2>{}, Sequence<1>{}));
auto out_global_store = ThreadGroupTensorSliceTransfer_v4r1< auto out_global_store =
ThisThreadBlock, ThreadGroupTensorSliceTransfer_v4r1<ThisThreadBlock,
ElementwiseOperation, ElementwiseOperation,
ck::tensor_operation::element_wise::PassThrough, PassThrough,
InMemoryDataOperationEnum::Set, InMemoryDataOperationEnum::Set,
Sequence<1, HPerBlock, WPerBlock>, // SliceLengths SliceLengths,
ABlockTransferThreadClusterLengths, ABlockTransferThreadClusterLengths,
Sequence<0, 1, 2>, // ABlockTransferThreadClusterArrangeOrder ABlockTransferThreadClusterArrangeOrder,
InDataType, InDataType,
OutDataType, OutDataType,
decltype(in_block_desc), decltype(in_block_desc),
decltype(in_grid_desc_tranformed), decltype(in_grid_desc_tranformed),
Sequence<0, 1, 2>, // ABlockTransferSrcAccessOrder ABlockTransferAccessOrder,
Sequence<0, 1, 2>, // ABlockTransferDstAccessOrder ABlockTransferAccessOrder,
1, // ABlockTransferSrcVectorDim ABlockTransferSrcVectorDim,
1, // ABlockTransferDstVectorDim ABlockTransferDstVectorDim,
1, // ABlockTransferSrcScalarPerVector ABlockTransferSrcScalarPerVector,
1, // ABlockTransferDstScalarPerVector ABlockTransferDstScalarPerVector,
1, 1,
1, 1,
true, true,
true>(in_block_desc, true>(
make_multi_index(0, 0, 0), in_block_desc,
ck::tensor_operation::element_wise::PassThrough{}, make_multi_index(0, 0, 0),
in_grid_desc_tranformed, PassThrough{},
make_multi_index(0, h_block_data_idx_on_grid, w_block_data_idx_on_grid), in_grid_desc_tranformed,
elementwise_op); make_multi_index(0, h_block_data_idx_on_grid, w_block_data_idx_on_grid),
elementwise_op);
index_t num_iter = in_grid_desc.GetLength(I0);
const auto loop_step = make_multi_index(1, 0, 0);
index_t num_iter = in_grid_desc.GetLength(I0);
do do
{ {
in_global_load.Run(in_grid_desc, in_global_buf, in_block_desc, in_block_buf, I0); in_global_load.Run(in_grid_desc, in_global_buf, in_block_desc, in_block_buf, I0);
in_global_load.MoveSrcSliceWindow(in_grid_desc, loop_step_index); in_global_load.MoveSrcSliceWindow(in_grid_desc, loop_step);
out_global_store.Run( out_global_store.Run(
in_block_desc, in_block_buf, in_grid_desc_tranformed, out_global_buf, I0); in_block_desc, in_block_buf, in_grid_desc_tranformed, out_global_buf, I0);
out_global_store.MoveDstSliceWindow(in_grid_desc_tranformed, loop_step_index); out_global_store.MoveDstSliceWindow(in_grid_desc_tranformed, loop_step);
} while(--num_iter); } while(--num_iter);
} }
}; };
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment