"...composable_kernel.git" did not exist on "b2df701822843b357b14a476cd0800098d47a888"
Commit 5502bac2 authored by Po-Yen, Chen's avatar Po-Yen, Chen
Browse files

Embed shape info in name of descriptor constructor

parent 5a7c738d
...@@ -126,7 +126,7 @@ struct GridwisePermute ...@@ -126,7 +126,7 @@ struct GridwisePermute
using DefaultBlock2TileMap = using DefaultBlock2TileMap =
detail::GridwisePermuteBlock2TileMap<HPerBlock, WPerBlock, InGridDesc>; detail::GridwisePermuteBlock2TileMap<HPerBlock, WPerBlock, InGridDesc>;
__host__ __device__ static constexpr auto GetInBlockDesc() __host__ __device__ static constexpr auto GetInBlockDesc_1_HPerBlock_WPerBlock()
{ {
return make_naive_tensor_descriptor(make_tuple(1, Number<HPerBlock>{}, Number<WPerBlock>{}), return make_naive_tensor_descriptor(make_tuple(1, Number<HPerBlock>{}, Number<WPerBlock>{}),
make_tuple(Number<WPerBlock + InBlockLdsExtraW>{}, make_tuple(Number<WPerBlock + InBlockLdsExtraW>{},
...@@ -155,9 +155,9 @@ struct GridwisePermute ...@@ -155,9 +155,9 @@ struct GridwisePermute
__host__ __device__ static constexpr index_t GetSharedMemoryNumberOfByte() __host__ __device__ static constexpr index_t GetSharedMemoryNumberOfByte()
{ {
constexpr auto in_block_desc = GetInBlockDesc(); constexpr auto in_block_desc_1_hperblock_wperblock = GetInBlockDesc_1_HPerBlock_WPerBlock();
return in_block_desc.GetElementSpaceSize() * sizeof(InDataType); return in_block_desc_1_hperblock_wperblock.GetElementSpaceSize() * sizeof(InDataType);
} }
__host__ __device__ static constexpr auto MakeDefaultBlock2TileMap(const InGridDesc& desc) __host__ __device__ static constexpr auto MakeDefaultBlock2TileMap(const InGridDesc& desc)
...@@ -211,10 +211,11 @@ struct GridwisePermute ...@@ -211,10 +211,11 @@ struct GridwisePermute
__builtin_amdgcn_readfirstlane(block_work_idx[I1] * WPerBlock); __builtin_amdgcn_readfirstlane(block_work_idx[I1] * WPerBlock);
// Input slice in LDS memory, dst of blockwise copy // Input slice in LDS memory, dst of blockwise copy
constexpr auto in_block_desc = GetInBlockDesc(); constexpr auto in_block_desc_1_hperblock_wperblock = GetInBlockDesc_1_HPerBlock_WPerBlock();
auto in_block_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>( auto in_block_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(
static_cast<InDataType*>(p_shared), in_block_desc.GetElementSpaceSize()); static_cast<InDataType*>(p_shared),
in_block_desc_1_hperblock_wperblock.GetElementSpaceSize());
using BlockSliceLengths = Sequence<1, HPerBlock, WPerBlock>; using BlockSliceLengths = Sequence<1, HPerBlock, WPerBlock>;
using InBlockTransferAccessOrder = Sequence<0, 1, 2>; using InBlockTransferAccessOrder = Sequence<0, 1, 2>;
...@@ -238,7 +239,7 @@ struct GridwisePermute ...@@ -238,7 +239,7 @@ struct GridwisePermute
InDataType, InDataType,
InDataType, InDataType,
decltype(in_grid_desc_n_h_w), decltype(in_grid_desc_n_h_w),
decltype(in_block_desc), decltype(in_block_desc_1_hperblock_wperblock),
InBlockTransferAccessOrder, InBlockTransferAccessOrder,
InBlockTransferAccessOrder, InBlockTransferAccessOrder,
SrcVectorDimAfterMerge, SrcVectorDimAfterMerge,
...@@ -252,7 +253,7 @@ struct GridwisePermute ...@@ -252,7 +253,7 @@ struct GridwisePermute
in_grid_desc_n_h_w, in_grid_desc_n_h_w,
make_multi_index(0, h_block_data_idx_on_grid, w_block_data_idx_on_grid), make_multi_index(0, h_block_data_idx_on_grid, w_block_data_idx_on_grid),
PassThrough{}, PassThrough{},
in_block_desc, in_block_desc_1_hperblock_wperblock,
make_multi_index(0, 0, 0), make_multi_index(0, 0, 0),
PassThrough{}); PassThrough{});
...@@ -276,7 +277,7 @@ struct GridwisePermute ...@@ -276,7 +277,7 @@ struct GridwisePermute
InBlockTransferThreadClusterArrangeOrder, InBlockTransferThreadClusterArrangeOrder,
InDataType, InDataType,
OutDataType, OutDataType,
decltype(in_block_desc), decltype(in_block_desc_1_hperblock_wperblock),
decltype(out_grid_desc_n_h_w), decltype(out_grid_desc_n_h_w),
InBlockTransferAccessOrder, InBlockTransferAccessOrder,
InBlockTransferAccessOrder, InBlockTransferAccessOrder,
...@@ -288,7 +289,7 @@ struct GridwisePermute ...@@ -288,7 +289,7 @@ struct GridwisePermute
1, 1,
true, true,
true>( true>(
in_block_desc, in_block_desc_1_hperblock_wperblock,
make_multi_index(0, 0, 0), make_multi_index(0, 0, 0),
PassThrough{}, PassThrough{},
out_grid_desc_n_h_w, out_grid_desc_n_h_w,
...@@ -299,12 +300,19 @@ struct GridwisePermute ...@@ -299,12 +300,19 @@ struct GridwisePermute
index_t num_iter = in_grid_desc_n_h_w.GetLength(I0); index_t num_iter = in_grid_desc_n_h_w.GetLength(I0);
do do
{ {
in_global_load.Run(in_grid_desc_n_h_w, in_global_buf, in_block_desc, in_block_buf, I0); in_global_load.Run(in_grid_desc_n_h_w,
in_global_buf,
in_block_desc_1_hperblock_wperblock,
in_block_buf,
I0);
in_global_load.MoveSrcSliceWindow(in_grid_desc_n_h_w, loop_step); in_global_load.MoveSrcSliceWindow(in_grid_desc_n_h_w, loop_step);
out_global_store.Run( out_global_store.Run(in_block_desc_1_hperblock_wperblock,
in_block_desc, in_block_buf, out_grid_desc_n_h_w, out_global_buf, I0); in_block_buf,
out_grid_desc_n_h_w,
out_global_buf,
I0);
out_global_store.MoveDstSliceWindow(out_grid_desc_n_h_w, loop_step); out_global_store.MoveDstSliceWindow(out_grid_desc_n_h_w, loop_step);
} while(--num_iter); } while(--num_iter);
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment