Commit 0c23d6fa authored by Po-Yen, Chen
Browse files

Embed layout in the variable names

parent ad1a639b
...@@ -121,6 +121,25 @@ struct GridwisePermute ...@@ -121,6 +121,25 @@ struct GridwisePermute
I1)); I1));
} }
// Build a 3-dimensional view of `desc` by folding every dimension except the
// last two into a single leading dimension.
//
//   upper dim 0 <- merge of original dims [0, NumDim - 2)
//   upper dim 1 <- original dim NumDim - 2 (passed through)
//   upper dim 2 <- original dim NumDim - 1 (passed through)
//
// GridDesc must report at least 3 dimensions (enforced at compile time).
// Returns the transformed descriptor by value.
template <typename GridDesc>
__host__ __device__ static constexpr auto GetMergedDesc(const GridDesc& desc)
{
    constexpr index_t NumDim = GridDesc::GetNumOfDimension();
    static_assert(3 <= NumDim);

    // Dims [0, NumLeadingDim) get merged; the two dims after them are kept as-is.
    constexpr index_t NumLeadingDim = NumDim - 2;
    constexpr index_t SecondLastDim = NumDim - 2;
    constexpr index_t LastDim       = NumDim - 1;

    // Lengths of the leading dimensions that will be merged together.
    const auto leading_lengths =
        generate_tuple([&](auto I) { return desc.GetLength(I); }, Number<NumLeadingDim>{});

    // One merge transform followed by two pass-through transforms.
    const auto transforms =
        make_tuple(make_merge_transform(leading_lengths),
                   make_pass_through_transform(desc.GetLength(Number<SecondLastDim>{})),
                   make_pass_through_transform(desc.GetLength(Number<LastDim>{})));

    // Which original dimensions feed each transform.
    const auto lower_dim_ids =
        make_tuple(generate_sequence_v2([&](auto I) { return I; }, Number<NumLeadingDim>{}),
                   Sequence<SecondLastDim>{},
                   Sequence<LastDim>{});

    // Which dimension of the resulting descriptor each transform produces.
    const auto upper_dim_ids = make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{});

    return transform_tensor_descriptor(desc, transforms, lower_dim_ids, upper_dim_ids);
}
__host__ __device__ static constexpr index_t GetSharedMemoryNumberOfByte() __host__ __device__ static constexpr index_t GetSharedMemoryNumberOfByte()
{ {
constexpr auto in_block_desc = GetInBlockDesc(); constexpr auto in_block_desc = GetInBlockDesc();
...@@ -175,6 +194,8 @@ struct GridwisePermute ...@@ -175,6 +194,8 @@ struct GridwisePermute
using ck::tensor_operation::element_wise::PassThrough; using ck::tensor_operation::element_wise::PassThrough;
const auto in_grid_desc_n_h_w = GetMergedDesc(in_grid_desc);
auto in_global_load = auto in_global_load =
ThreadGroupTensorSliceTransfer_v4r1<ThisThreadBlock, ThreadGroupTensorSliceTransfer_v4r1<ThisThreadBlock,
ElementwiseOperation, ElementwiseOperation,
...@@ -185,7 +206,7 @@ struct GridwisePermute ...@@ -185,7 +206,7 @@ struct GridwisePermute
ABlockTransferThreadClusterArrangeOrder, ABlockTransferThreadClusterArrangeOrder,
InDataType, InDataType,
InDataType, InDataType,
decltype(in_grid_desc), decltype(in_grid_desc_n_h_w),
decltype(in_block_desc), decltype(in_block_desc),
ABlockTransferAccessOrder, ABlockTransferAccessOrder,
ABlockTransferAccessOrder, ABlockTransferAccessOrder,
...@@ -197,18 +218,20 @@ struct GridwisePermute ...@@ -197,18 +218,20 @@ struct GridwisePermute
1, 1,
true, true,
true>( true>(
in_grid_desc, in_grid_desc_n_h_w,
make_multi_index(0, h_block_data_idx_on_grid, w_block_data_idx_on_grid), make_multi_index(0, h_block_data_idx_on_grid, w_block_data_idx_on_grid),
PassThrough{}, PassThrough{},
in_block_desc, in_block_desc,
make_multi_index(0, 0, 0), make_multi_index(0, 0, 0),
PassThrough{}); PassThrough{});
auto in_grid_desc_tranformed = transform_tensor_descriptor( const auto out_grid_desc_n_w_h = GetMergedDesc(out_grid_desc);
in_grid_desc,
make_tuple(make_pass_through_transform(in_grid_desc.GetLength(I0)), const auto out_grid_desc_n_h_w = transform_tensor_descriptor(
make_pass_through_transform(in_grid_desc.GetLength(I1)), out_grid_desc_n_w_h,
make_pass_through_transform(in_grid_desc.GetLength(I2))), make_tuple(make_pass_through_transform(out_grid_desc_n_w_h.GetLength(I0)),
make_pass_through_transform(out_grid_desc_n_w_h.GetLength(I1)),
make_pass_through_transform(out_grid_desc_n_w_h.GetLength(I2))),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}), make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}),
make_tuple(Sequence<0>{}, Sequence<2>{}, Sequence<1>{})); make_tuple(Sequence<0>{}, Sequence<2>{}, Sequence<1>{}));
...@@ -223,7 +246,7 @@ struct GridwisePermute ...@@ -223,7 +246,7 @@ struct GridwisePermute
InDataType, InDataType,
OutDataType, OutDataType,
decltype(in_block_desc), decltype(in_block_desc),
decltype(in_grid_desc_tranformed), decltype(out_grid_desc_n_h_w),
ABlockTransferAccessOrder, ABlockTransferAccessOrder,
ABlockTransferAccessOrder, ABlockTransferAccessOrder,
ABlockTransferSrcVectorDim, ABlockTransferSrcVectorDim,
...@@ -237,22 +260,22 @@ struct GridwisePermute ...@@ -237,22 +260,22 @@ struct GridwisePermute
in_block_desc, in_block_desc,
make_multi_index(0, 0, 0), make_multi_index(0, 0, 0),
PassThrough{}, PassThrough{},
in_grid_desc_tranformed, out_grid_desc_n_h_w,
make_multi_index(0, h_block_data_idx_on_grid, w_block_data_idx_on_grid), make_multi_index(0, h_block_data_idx_on_grid, w_block_data_idx_on_grid),
elementwise_op); elementwise_op);
const auto loop_step = make_multi_index(1, 0, 0); const auto loop_step = make_multi_index(1, 0, 0);
index_t num_iter = in_grid_desc.GetLength(I0); index_t num_iter = in_grid_desc_n_h_w.GetLength(I0);
do do
{ {
in_global_load.Run(in_grid_desc, in_global_buf, in_block_desc, in_block_buf, I0); in_global_load.Run(in_grid_desc_n_h_w, in_global_buf, in_block_desc, in_block_buf, I0);
in_global_load.MoveSrcSliceWindow(in_grid_desc, loop_step); in_global_load.MoveSrcSliceWindow(in_grid_desc_n_h_w, loop_step);
out_global_store.Run( out_global_store.Run(
in_block_desc, in_block_buf, in_grid_desc_tranformed, out_global_buf, I0); in_block_desc, in_block_buf, out_grid_desc_n_h_w, out_global_buf, I0);
out_global_store.MoveDstSliceWindow(in_grid_desc_tranformed, loop_step); out_global_store.MoveDstSliceWindow(out_grid_desc_n_h_w, loop_step);
} while(--num_iter); } while(--num_iter);
} }
}; };
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment