"example/vscode:/vscode.git/clone" did not exist on "b00ae5df0b49891ce1bc5592e65663b0f9aad9ea"
Commit 0c23d6fa authored by Po-Yen, Chen's avatar Po-Yen, Chen
Browse files

Embed layout in the variable names

parent ad1a639b
......@@ -121,6 +121,25 @@ struct GridwisePermute
I1));
}
template <typename GridDesc>
__host__ __device__ static constexpr auto GetMergedDesc(const GridDesc& desc)
{
constexpr index_t NumDim = GridDesc::GetNumOfDimension();
static_assert(3 <= NumDim);
const auto merged_desc = transform_tensor_descriptor(
desc,
make_tuple(make_merge_transform(generate_tuple(
[&](auto I) { return desc.GetLength(I); }, Number<NumDim - 2>{})),
make_pass_through_transform(desc.GetLength(Number<NumDim - 2>{})),
make_pass_through_transform(desc.GetLength(Number<NumDim - 1>{}))),
make_tuple(generate_sequence_v2([&](auto I) { return I; }, Number<NumDim - 2>{}),
Sequence<NumDim - 2>{},
Sequence<NumDim - 1>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}));
return merged_desc;
}
__host__ __device__ static constexpr index_t GetSharedMemoryNumberOfByte()
{
constexpr auto in_block_desc = GetInBlockDesc();
......@@ -175,6 +194,8 @@ struct GridwisePermute
using ck::tensor_operation::element_wise::PassThrough;
const auto in_grid_desc_n_h_w = GetMergedDesc(in_grid_desc);
auto in_global_load =
ThreadGroupTensorSliceTransfer_v4r1<ThisThreadBlock,
ElementwiseOperation,
......@@ -185,7 +206,7 @@ struct GridwisePermute
ABlockTransferThreadClusterArrangeOrder,
InDataType,
InDataType,
decltype(in_grid_desc),
decltype(in_grid_desc_n_h_w),
decltype(in_block_desc),
ABlockTransferAccessOrder,
ABlockTransferAccessOrder,
......@@ -197,18 +218,20 @@ struct GridwisePermute
1,
true,
true>(
in_grid_desc,
in_grid_desc_n_h_w,
make_multi_index(0, h_block_data_idx_on_grid, w_block_data_idx_on_grid),
PassThrough{},
in_block_desc,
make_multi_index(0, 0, 0),
PassThrough{});
auto in_grid_desc_tranformed = transform_tensor_descriptor(
in_grid_desc,
make_tuple(make_pass_through_transform(in_grid_desc.GetLength(I0)),
make_pass_through_transform(in_grid_desc.GetLength(I1)),
make_pass_through_transform(in_grid_desc.GetLength(I2))),
const auto out_grid_desc_n_w_h = GetMergedDesc(out_grid_desc);
const auto out_grid_desc_n_h_w = transform_tensor_descriptor(
out_grid_desc_n_w_h,
make_tuple(make_pass_through_transform(out_grid_desc_n_w_h.GetLength(I0)),
make_pass_through_transform(out_grid_desc_n_w_h.GetLength(I1)),
make_pass_through_transform(out_grid_desc_n_w_h.GetLength(I2))),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}),
make_tuple(Sequence<0>{}, Sequence<2>{}, Sequence<1>{}));
......@@ -223,7 +246,7 @@ struct GridwisePermute
InDataType,
OutDataType,
decltype(in_block_desc),
decltype(in_grid_desc_tranformed),
decltype(out_grid_desc_n_h_w),
ABlockTransferAccessOrder,
ABlockTransferAccessOrder,
ABlockTransferSrcVectorDim,
......@@ -237,22 +260,22 @@ struct GridwisePermute
in_block_desc,
make_multi_index(0, 0, 0),
PassThrough{},
in_grid_desc_tranformed,
out_grid_desc_n_h_w,
make_multi_index(0, h_block_data_idx_on_grid, w_block_data_idx_on_grid),
elementwise_op);
const auto loop_step = make_multi_index(1, 0, 0);
index_t num_iter = in_grid_desc.GetLength(I0);
index_t num_iter = in_grid_desc_n_h_w.GetLength(I0);
do
{
in_global_load.Run(in_grid_desc, in_global_buf, in_block_desc, in_block_buf, I0);
in_global_load.Run(in_grid_desc_n_h_w, in_global_buf, in_block_desc, in_block_buf, I0);
in_global_load.MoveSrcSliceWindow(in_grid_desc, loop_step);
in_global_load.MoveSrcSliceWindow(in_grid_desc_n_h_w, loop_step);
out_global_store.Run(
in_block_desc, in_block_buf, in_grid_desc_tranformed, out_global_buf, I0);
in_block_desc, in_block_buf, out_grid_desc_n_h_w, out_global_buf, I0);
out_global_store.MoveDstSliceWindow(in_grid_desc_tranformed, loop_step);
out_global_store.MoveDstSliceWindow(out_grid_desc_n_h_w, loop_step);
} while(--num_iter);
}
};
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment