"vscode:/vscode.git/clone" did not exist on "14fd82f3156ecffc53defd42e8eb66d19f220697"
Commit 5cfa0368 authored by Po-Yen, Chen's avatar Po-Yen, Chen
Browse files

Add comments in 'GridwisePermute'

parent 057ffb90
...@@ -133,6 +133,7 @@ struct GridwisePermute ...@@ -133,6 +133,7 @@ struct GridwisePermute
using DefaultBlock2TileMap = Block2TileMap; using DefaultBlock2TileMap = Block2TileMap;
// use an [NPerBlock, HPerBlock, WPerBlock] tensor as element-copy relay
__host__ __device__ static constexpr auto GetInBlockDesc_NPerBlock_HPerBlock_WPerBlock() __host__ __device__ static constexpr auto GetInBlockDesc_NPerBlock_HPerBlock_WPerBlock()
{ {
return make_naive_tensor_descriptor( return make_naive_tensor_descriptor(
...@@ -142,6 +143,9 @@ struct GridwisePermute ...@@ -142,6 +143,9 @@ struct GridwisePermute
I1)); I1));
} }
// For an N-dimensional descriptor, keep its last 2 dimensions and merge all leading dimensions
// into a single one, forming a 3D descriptor: [d(0), d(1), ..., d(N - 2), d(N - 1)] ->
// [(d(0) x d(1) x ...), d(N - 2), d(N - 1)]
template <typename GridDesc> template <typename GridDesc>
__host__ __device__ static constexpr auto GetMergedDesc(const GridDesc& desc) __host__ __device__ static constexpr auto GetMergedDesc(const GridDesc& desc)
{ {
...@@ -211,6 +215,7 @@ struct GridwisePermute ...@@ -211,6 +215,7 @@ struct GridwisePermute
auto out_global_buf = make_dynamic_buffer<AddressSpaceEnum::Global>( auto out_global_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_out_global, out_grid_desc.GetElementSpaceSize()); p_out_global, out_grid_desc.GetElementSpaceSize());
// each workgroup handles an [NPerBlock, HPerBlock, WPerBlock] slice-transpose problem
const auto block_work_idx = const auto block_work_idx =
block_2_tile_map.CalculateBottomIndex(make_multi_index(get_block_1d_id())); block_2_tile_map.CalculateBottomIndex(make_multi_index(get_block_1d_id()));
...@@ -223,7 +228,7 @@ struct GridwisePermute ...@@ -223,7 +228,7 @@ struct GridwisePermute
const index_t w_block_data_idx_on_grid = const index_t w_block_data_idx_on_grid =
__builtin_amdgcn_readfirstlane(block_work_idx[I2] * WPerBlock); __builtin_amdgcn_readfirstlane(block_work_idx[I2] * WPerBlock);
// Input slice in LDS memory, dst of blockwise copy // create [NPerBlock, HPerBlock, WPerBlock] shaped LDS buffer
constexpr auto in_block_desc_nperblock_hperblock_wperblock = constexpr auto in_block_desc_nperblock_hperblock_wperblock =
GetInBlockDesc_NPerBlock_HPerBlock_WPerBlock(); GetInBlockDesc_NPerBlock_HPerBlock_WPerBlock();
...@@ -240,8 +245,11 @@ struct GridwisePermute ...@@ -240,8 +245,11 @@ struct GridwisePermute
using ck::tensor_operation::element_wise::PassThrough; using ck::tensor_operation::element_wise::PassThrough;
// merge input descriptor into [(in_grid_desc.GetLength(0) x in_grid_desc.GetLength(1) x
// ...), in_grid_desc.GetLength(NumDim - 2), in_grid_desc.GetLength(NumDim - 1)]
const auto in_grid_desc_n_h_w = GetMergedDesc(in_grid_desc); const auto in_grid_desc_n_h_w = GetMergedDesc(in_grid_desc);
// a workgroup copies an [NPerBlock, HPerBlock, WPerBlock] slice from global memory to LDS
auto in_global_load = ThreadGroupTensorSliceTransfer_v4r1< auto in_global_load = ThreadGroupTensorSliceTransfer_v4r1<
ThisThreadBlock, ThisThreadBlock,
ElementwiseOperation, ElementwiseOperation,
...@@ -271,8 +279,11 @@ struct GridwisePermute ...@@ -271,8 +279,11 @@ struct GridwisePermute
make_multi_index(0, 0, 0), make_multi_index(0, 0, 0),
PassThrough{}); PassThrough{});
// merge output descriptor into [(out_grid_desc.GetLength(0) x out_grid_desc.GetLength(1) x
// ...), out_grid_desc.GetLength(NumDim - 2), out_grid_desc.GetLength(NumDim - 1)]
const auto out_grid_desc_n_w_h = GetMergedDesc(out_grid_desc); const auto out_grid_desc_n_w_h = GetMergedDesc(out_grid_desc);
// create transposed view of output tensor
const auto out_grid_desc_n_h_w = transform_tensor_descriptor( const auto out_grid_desc_n_h_w = transform_tensor_descriptor(
out_grid_desc_n_w_h, out_grid_desc_n_w_h,
make_tuple(make_pass_through_transform(out_grid_desc_n_w_h.GetLength(I0)), make_tuple(make_pass_through_transform(out_grid_desc_n_w_h.GetLength(I0)),
...@@ -281,6 +292,7 @@ struct GridwisePermute ...@@ -281,6 +292,7 @@ struct GridwisePermute
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}), make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}),
make_tuple(Sequence<0>{}, Sequence<2>{}, Sequence<1>{})); make_tuple(Sequence<0>{}, Sequence<2>{}, Sequence<1>{}));
// a workgroup copies an [NPerBlock, HPerBlock, WPerBlock] slice from LDS to global memory
auto out_global_store = ThreadGroupTensorSliceTransfer_v4r1< auto out_global_store = ThreadGroupTensorSliceTransfer_v4r1<
ThisThreadBlock, ThisThreadBlock,
ElementwiseOperation, ElementwiseOperation,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment