Commit f650c019 authored by Po-Yen, Chen

Remove global load/store loop in kernel code

parent 684e012e
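
Note: the heart of this commit is the last hunk below. Previously each workgroup copied a 1 x HPerBlock x WPerBlock slice and walked the outer N dimension by moving its slice window; now a new NPerBlock template parameter tiles N as well, the block-to-tile map hands out 3-D tiles, and the per-block loop disappears. A condensed before/after sketch, excerpted from the hunks below with argument lists elided:

    // Before: per-block loop over the N dimension, one slice per iteration.
    const auto loop_step = make_multi_index(1, 0, 0);
    index_t num_iter     = in_grid_desc_n_h_w.GetLength(I0);
    do
    {
        in_global_load.Run(/* global -> LDS */);
        in_global_load.MoveSrcSliceWindow(in_grid_desc_n_h_w, loop_step);
        out_global_store.Run(/* LDS -> global */);
        out_global_store.MoveDstSliceWindow(out_grid_desc_n_h_w, loop_step);
    } while(--num_iter);

    // After: N is tiled by NPerBlock and covered by the grid, so each block
    // performs exactly one load/store pair.
    in_global_load.Run(/* global -> LDS */);
    out_global_store.Run(/* LDS -> global */);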
@@ -8,11 +8,11 @@ using BDataType = F16;
 // clang-format off
 using DevicePermuteInstance = ck::tensor_operation::device::DevicePermute
-// ######| InData| OutData| Elementwise| NumDim| Block| HPer| WPer| InBlock| InBlockTransfer| InBlockTransfer| Src| Dst| Src| Dst|
-// ######| Type| Type| Operation| | Size| Block| Block| LdsExtraW| ThreadClusterLengths| ThreadClusterArrangeOrder| VectorDim| VectorDim| ScalarPerVector| ScalarPerVector|
-// ######| | | | | | | | | | | | | | |
-// ######| | | | | | | | | | | | | | |
-         < ADataType, BDataType, PassThrough, 3, 256, 128, 128, 0, S<1, 16, 16>, S<0, 1, 2>, 2, 1, 1, 1>;
+// ######| InData| OutData| Elementwise| NumDim| Block| NPer| HPer| WPer| InBlock| InBlockTransfer| InBlockTransfer| Src| Dst| Src| Dst|
+// ######| Type| Type| Operation| | Size| Block| Block| Block| LdsExtraW| ThreadClusterLengths| ThreadClusterArrangeOrder| VectorDim| VectorDim| ScalarPerVector| ScalarPerVector|
+// ######| | | | | | | | | | | | | | | |
+// ######| | | | | | | | | | | | | | | |
+         < ADataType, BDataType, PassThrough, 3, 256, 16, 32, 32, 0, S<1, 16, 16>, S<0, 1, 2>, 2, 1, 1, 1>;
 // clang-format on
 #include "run_permute_example.inc"
@@ -8,11 +8,11 @@ using BDataType = F64;
 // clang-format off
 using DevicePermuteInstance = ck::tensor_operation::device::DevicePermute
-// ######| InData| OutData| Elementwise| NumDim| Block| HPer| WPer| InBlock| InBlockTransfer| InBlockTransfer| Src| Dst| Src| Dst|
-// ######| Type| Type| Operation| | Size| Block| Block| LdsExtraW| ThreadClusterLengths| ThreadClusterArrangeOrder| VectorDim| VectorDim| ScalarPerVector| ScalarPerVector|
-// ######| | | | | | | | | | | | | | |
-// ######| | | | | | | | | | | | | | |
-         < ADataType, BDataType, PassThrough, 3, 256, 16, 16, 0, S<1, 16, 16>, S<0, 1, 2>, 2, 1, 1, 1>;
+// ######| InData| OutData| Elementwise| NumDim| Block| NPer| HPer| WPer| InBlock| InBlockTransfer| InBlockTransfer| Src| Dst| Src| Dst|
+// ######| Type| Type| Operation| | Size| Block| Block| Block| LdsExtraW| ThreadClusterLengths| ThreadClusterArrangeOrder| VectorDim| VectorDim| ScalarPerVector| ScalarPerVector|
+// ######| | | | | | | | | | | | | | | |
+// ######| | | | | | | | | | | | | | | |
+         < ADataType, BDataType, PassThrough, 3, 256, 4, 16, 16, 0, S<1, 16, 16>, S<0, 1, 2>, 2, 1, 1, 1>;
 // clang-format on
 #define NUM_ELEMS_IN_BUNDLE 4
@@ -8,11 +8,11 @@ using BDataType = F16;
 // clang-format off
 using DevicePermuteInstance = ck::tensor_operation::device::DevicePermute
-// ######| InData| OutData| Elementwise| NumDim| Block| HPer| WPer| InBlock| InBlockTransfer| InBlockTransfer| Src| Dst| Src| Dst|
-// ######| Type| Type| Operation| | Size| Block| Block| LdsExtraW| ThreadClusterLengths| ThreadClusterArrangeOrder| VectorDim| VectorDim| ScalarPerVector| ScalarPerVector|
-// ######| | | | | | | | | | | | | | |
-// ######| | | | | | | | | | | | | | |
-         < ADataType, BDataType, PassThrough, 3, 256, 128, 128, 0, S<1, 16, 16>, S<0, 1, 2>, 2, 1, 1, 1>;
+// ######| InData| OutData| Elementwise| NumDim| Block| NPer| HPer| WPer| InBlock| InBlockTransfer| InBlockTransfer| Src| Dst| Src| Dst|
+// ######| Type| Type| Operation| | Size| Block| Block| Block| LdsExtraW| ThreadClusterLengths| ThreadClusterArrangeOrder| VectorDim| VectorDim| ScalarPerVector| ScalarPerVector|
+// ######| | | | | | | | | | | | | | | |
+// ######| | | | | | | | | | | | | | | |
+         < ADataType, BDataType, PassThrough, 3, 256, 16, 32, 32, 0, S<1, 16, 16>, S<0, 1, 2>, 2, 1, 1, 1>;
 // clang-format on
 #include "run_permute_example.inc"
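
Note: for the two F16 instances above, the new tile keeps the old per-transfer volume, it is just reshaped across three dimensions; with BlockSize = 256 each thread still moves 64 elements per load/store. (The F64 instance, by contrast, grows its tile from 1 x 16 x 16 to 4 x 16 x 16.) A quick arithmetic check, with values taken from the F16 instance and the static_assert purely illustrative:

    constexpr int old_tile = 1 * 128 * 128; // old 1 x HPerBlock x WPerBlock tile
    constexpr int new_tile = 16 * 32 * 32;  // new NPerBlock x HPerBlock x WPerBlock tile
    static_assert(old_tile == new_tile, "same volume, reshaped over N/H/W");
    constexpr int per_thread = new_tile / 256; // BlockSize = 256 -> 64 elements/thread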
@@ -79,6 +79,7 @@ template <typename InDataType,
           typename ElementwiseOperation,
           index_t NumDim,
           index_t BlockSize,
+          index_t NPerBlock,
           index_t HPerBlock,
           index_t WPerBlock,
           index_t InBlockLdsExtraW,
@@ -94,6 +95,7 @@ struct DevicePermute
                          ElementwiseOperation,
                          NumDim,
                          BlockSize,
+                         NPerBlock,
                          HPerBlock,
                          WPerBlock,
                          InBlockLdsExtraW,
@@ -141,7 +143,7 @@ struct DevicePermute
             make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}));
         return PadTensorDescriptor(
-            desc_n_h_w, make_tuple(1, HPerBlock, WPerBlock), Sequence<false, true, true>{});
+            desc_n_h_w, make_tuple(NPerBlock, HPerBlock, WPerBlock), Sequence<true, true, true>{});
     }

     using InGridDesc = decltype(MakeDescriptor_N_H_W({1, 1}, {1, 1}));
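
Note: the descriptor now pads all three merged dimensions (the pad flags flipped from Sequence<false, true, true> to Sequence<true, true, true>), so an arbitrary N no longer has to be absorbed by a per-block loop. A minimal sketch of the padding arithmetic, assuming PadTensorDescriptor rounds each flagged length up to a multiple of its tile size (helper name is hypothetical, not the CK API):

    // Hypothetical illustration of the per-dimension pad.
    int padded_length(int length, int per_block)
    {
        return ((length + per_block - 1) / per_block) * per_block; // round up to tile
    }
    // e.g. N = 10, NPerBlock = 4  ->  padded_length(10, 4) == 12 (3 tiles along N)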
@@ -154,6 +156,7 @@ struct DevicePermute
         OutDataType,
         ElementwiseOperation,
         BlockSize,
+        NPerBlock,
         HPerBlock,
         WPerBlock,
         InBlockLdsExtraW,
@@ -15,11 +15,11 @@
 namespace ck {
 namespace detail {

-template <index_t HPerBlock, index_t WPerBlock, typename GridDesc>
+template <index_t NPerBlock, index_t HPerBlock, index_t WPerBlock, typename GridDesc>
 struct GridwisePermuteBlock2TileMap
 {
     static constexpr index_t NumDim = GridDesc::GetNumOfDimension();
-    static_assert(2 <= NumDim);
+    static_assert(3 <= NumDim);

     static constexpr auto I0 = Number<0>{};
@@ -36,10 +36,11 @@ struct GridwisePermuteBlock2TileMap
     __host__ constexpr index_t CalculateGridSize(const GridDesc& desc) const
     {
+        const auto N0 = math::integer_divide_ceil(desc.GetLength(Number<NumDim - 3>{}), NPerBlock);
         const auto H0 = math::integer_divide_ceil(desc.GetLength(Number<NumDim - 2>{}), HPerBlock);
         const auto W0 = math::integer_divide_ceil(desc.GetLength(Number<NumDim - 1>{}), WPerBlock);

-        const index_t grid_size = H0 * W0;
+        const index_t grid_size = N0 * H0 * W0;

         return grid_size;
     }
@@ -51,15 +52,17 @@ struct GridwisePermuteBlock2TileMap
         auto block_1d_id = idx_top[I0];

+        const auto N0 = math::integer_divide_ceil(desc_.GetLength(Number<NumDim - 3>{}), NPerBlock);
         const auto H0 = math::integer_divide_ceil(desc_.GetLength(Number<NumDim - 2>{}), HPerBlock);
         const auto W0 = math::integer_divide_ceil(desc_.GetLength(Number<NumDim - 1>{}), WPerBlock);

-        block_1d_id = block_1d_id % (H0 * W0);
+        block_1d_id = block_1d_id % (N0 * H0 * W0);

-        index_t idx_H0 = block_1d_id / W0;
+        index_t idx_N0 = block_1d_id / (H0 * W0);
+        index_t idx_H0 = (block_1d_id % (H0 * W0)) / W0;
         index_t idx_W0 = block_1d_id % W0;

-        return make_tuple(idx_H0, idx_W0);
+        return make_tuple(idx_N0, idx_H0, idx_W0);
     }

 private:
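
Note: a self-contained sketch of the new 3-D block-to-tile decomposition above, useful for checking the index math on the host (names here are illustrative, not the CK API):

    #include <cassert>
    #include <tuple>

    // Decompose a linear block id into (n0, h0, w0) tile coordinates,
    // mirroring the arithmetic in CalculateBottomIndex.
    std::tuple<int, int, int> decompose(int block_1d_id, int N0, int H0, int W0)
    {
        block_1d_id = block_1d_id % (N0 * H0 * W0); // wrap ids past the last tile
        const int n0 = block_1d_id / (H0 * W0);
        const int h0 = (block_1d_id % (H0 * W0)) / W0;
        const int w0 = block_1d_id % W0;
        return {n0, h0, w0};
    }

    int main()
    {
        // e.g. a grid of 2 x 3 x 4 tiles: block 17 -> (1, 1, 1)
        auto [n0, h0, w0] = decompose(17, 2, 3, 4);
        assert(n0 == 1 && h0 == 1 && w0 == 1);
    }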
@@ -98,6 +101,7 @@ template <typename InGridDesc,
           typename OutDataType,
           typename ElementwiseOperation,
           index_t BlockSize,
+          index_t NPerBlock,
           index_t HPerBlock,
           index_t WPerBlock,
           index_t InBlockLdsExtraW,
@@ -124,14 +128,15 @@ struct GridwisePermute
     using ThisThreadBlock = ThisThreadBlock<BlockSize>;

     using DefaultBlock2TileMap =
-        detail::GridwisePermuteBlock2TileMap<HPerBlock, WPerBlock, InGridDesc>;
+        detail::GridwisePermuteBlock2TileMap<NPerBlock, HPerBlock, WPerBlock, InGridDesc>;

-    __host__ __device__ static constexpr auto GetInBlockDesc_1_HPerBlock_WPerBlock()
+    __host__ __device__ static constexpr auto GetInBlockDesc_NPerBlock_HPerBlock_WPerBlock()
     {
-        return make_naive_tensor_descriptor(make_tuple(1, Number<HPerBlock>{}, Number<WPerBlock>{}),
-                                            make_tuple(Number<WPerBlock + InBlockLdsExtraW>{},
-                                                       Number<WPerBlock + InBlockLdsExtraW>{},
-                                                       I1));
+        return make_naive_tensor_descriptor(
+            make_tuple(Number<NPerBlock>{}, Number<HPerBlock>{}, Number<WPerBlock>{}),
+            make_tuple(Number<HPerBlock*(WPerBlock + InBlockLdsExtraW)>{},
+                       Number<WPerBlock + InBlockLdsExtraW>{},
+                       I1));
     }

     template <typename GridDesc>
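
Note: the LDS tile is now a genuinely 3-D (NPerBlock, HPerBlock, WPerBlock) buffer. Each n-slice spans HPerBlock * (WPerBlock + InBlockLdsExtraW) elements, where InBlockLdsExtraW pads each W row, presumably so users can break LDS bank conflicts. A plain-C++ model of the addressing implied by the strides above (function name is illustrative; CK's descriptors do this generically):

    // Illustrative linearization of an (n, h, w) coordinate into LDS.
    int lds_offset(int n, int h, int w,
                   int HPerBlock, int WPerBlock, int InBlockLdsExtraW)
    {
        const int row_stride   = WPerBlock + InBlockLdsExtraW; // padded W row
        const int plane_stride = HPerBlock * row_stride;       // one n-slice
        return n * plane_stride + h * row_stride + w;
    }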
@@ -155,9 +160,11 @@ struct GridwisePermute
     __host__ __device__ static constexpr index_t GetSharedMemoryNumberOfByte()
     {
-        constexpr auto in_block_desc_1_hperblock_wperblock = GetInBlockDesc_1_HPerBlock_WPerBlock();
+        constexpr auto in_block_desc_nperblock_hperblock_wperblock =
+            GetInBlockDesc_NPerBlock_HPerBlock_WPerBlock();

-        return in_block_desc_1_hperblock_wperblock.GetElementSpaceSize() * sizeof(InDataType);
+        return in_block_desc_nperblock_hperblock_wperblock.GetElementSpaceSize() *
+               sizeof(InDataType);
     }

     __host__ __device__ static constexpr auto MakeDefaultBlock2TileMap(const InGridDesc& desc)
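
Note: with the F16 example instance (NPerBlock = 16, HPerBlock = 32, WPerBlock = 32, InBlockLdsExtraW = 0), the larger 3-D tile still fits comfortably in the 64 KiB of LDS available per workgroup on current AMD GPUs (an assumption about the target); a quick check of the arithmetic:

    constexpr int NPerBlock = 16, HPerBlock = 32, WPerBlock = 32, InBlockLdsExtraW = 0;
    constexpr int elems = NPerBlock * HPerBlock * (WPerBlock + InBlockLdsExtraW); // 16384
    constexpr int bytes = elems * 2; // sizeof(F16) == 2  ->  32768 bytes (32 KiB)
    static_assert(bytes <= 64 * 1024, "tile fits in a 64 KiB LDS allocation");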
@@ -204,20 +211,24 @@ struct GridwisePermute
         const auto block_work_idx =
             block_2_tile_map.CalculateBottomIndex(make_multi_index(get_block_1d_id()));

+        const index_t n_block_data_idx_on_grid =
+            __builtin_amdgcn_readfirstlane(block_work_idx[I0] * NPerBlock);
+
         const index_t h_block_data_idx_on_grid =
-            __builtin_amdgcn_readfirstlane(block_work_idx[I0] * HPerBlock);
+            __builtin_amdgcn_readfirstlane(block_work_idx[I1] * HPerBlock);

         const index_t w_block_data_idx_on_grid =
-            __builtin_amdgcn_readfirstlane(block_work_idx[I1] * WPerBlock);
+            __builtin_amdgcn_readfirstlane(block_work_idx[I2] * WPerBlock);

         // Input slice in LDS memory, dst of blockwise copy
-        constexpr auto in_block_desc_1_hperblock_wperblock = GetInBlockDesc_1_HPerBlock_WPerBlock();
+        constexpr auto in_block_desc_nperblock_hperblock_wperblock =
+            GetInBlockDesc_NPerBlock_HPerBlock_WPerBlock();

         auto in_block_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(
             static_cast<InDataType*>(p_shared),
-            in_block_desc_1_hperblock_wperblock.GetElementSpaceSize());
+            in_block_desc_nperblock_hperblock_wperblock.GetElementSpaceSize());

-        using BlockSliceLengths = Sequence<1, HPerBlock, WPerBlock>;
+        using BlockSliceLengths = Sequence<NPerBlock, HPerBlock, WPerBlock>;
         using InBlockTransferAccessOrder = Sequence<0, 1, 2>;

         constexpr index_t SrcVectorDimAfterMerge =
@@ -228,34 +239,34 @@ struct GridwisePermute
         const auto in_grid_desc_n_h_w = GetMergedDesc(in_grid_desc);

-        auto in_global_load =
-            ThreadGroupTensorSliceTransfer_v4r1<ThisThreadBlock,
-                                                ElementwiseOperation,
-                                                PassThrough,
-                                                InMemoryDataOperationEnum::Set,
-                                                BlockSliceLengths,
-                                                InBlockTransferThreadClusterLengths,
-                                                InBlockTransferThreadClusterArrangeOrder,
-                                                InDataType,
-                                                InDataType,
-                                                decltype(in_grid_desc_n_h_w),
-                                                decltype(in_block_desc_1_hperblock_wperblock),
-                                                InBlockTransferAccessOrder,
-                                                InBlockTransferAccessOrder,
-                                                SrcVectorDimAfterMerge,
-                                                2,
-                                                SrcScalarPerVector,
-                                                1,
-                                                1,
-                                                1,
-                                                true,
-                                                true>(
-                in_grid_desc_n_h_w,
-                make_multi_index(0, h_block_data_idx_on_grid, w_block_data_idx_on_grid),
-                PassThrough{},
-                in_block_desc_1_hperblock_wperblock,
-                make_multi_index(0, 0, 0),
-                PassThrough{});
+        auto in_global_load = ThreadGroupTensorSliceTransfer_v4r1<
+            ThisThreadBlock,
+            ElementwiseOperation,
+            PassThrough,
+            InMemoryDataOperationEnum::Set,
+            BlockSliceLengths,
+            InBlockTransferThreadClusterLengths,
+            InBlockTransferThreadClusterArrangeOrder,
+            InDataType,
+            InDataType,
+            decltype(in_grid_desc_n_h_w),
+            decltype(in_block_desc_nperblock_hperblock_wperblock),
+            InBlockTransferAccessOrder,
+            InBlockTransferAccessOrder,
+            SrcVectorDimAfterMerge,
+            2,
+            SrcScalarPerVector,
+            1,
+            1,
+            1,
+            true,
+            true>(in_grid_desc_n_h_w,
+                  make_multi_index(
+                      n_block_data_idx_on_grid, h_block_data_idx_on_grid, w_block_data_idx_on_grid),
+                  PassThrough{},
+                  in_block_desc_nperblock_hperblock_wperblock,
+                  make_multi_index(0, 0, 0),
+                  PassThrough{});

         const auto out_grid_desc_n_w_h = GetMergedDesc(out_grid_desc);
@@ -267,55 +278,46 @@ struct GridwisePermute
             make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}),
             make_tuple(Sequence<0>{}, Sequence<2>{}, Sequence<1>{}));

-        auto out_global_store =
-            ThreadGroupTensorSliceTransfer_v4r1<ThisThreadBlock,
-                                                ElementwiseOperation,
-                                                PassThrough,
-                                                InMemoryDataOperationEnum::Set,
-                                                BlockSliceLengths,
-                                                InBlockTransferThreadClusterLengths,
-                                                InBlockTransferThreadClusterArrangeOrder,
-                                                InDataType,
-                                                OutDataType,
-                                                decltype(in_block_desc_1_hperblock_wperblock),
-                                                decltype(out_grid_desc_n_h_w),
-                                                InBlockTransferAccessOrder,
-                                                InBlockTransferAccessOrder,
-                                                2,
-                                                DstVectorDimAfterMerge,
-                                                1,
-                                                DstScalarPerVector,
-                                                1,
-                                                1,
-                                                true,
-                                                true>(
-                in_block_desc_1_hperblock_wperblock,
-                make_multi_index(0, 0, 0),
-                PassThrough{},
-                out_grid_desc_n_h_w,
-                make_multi_index(0, h_block_data_idx_on_grid, w_block_data_idx_on_grid),
-                elementwise_op);
+        auto out_global_store = ThreadGroupTensorSliceTransfer_v4r1<
+            ThisThreadBlock,
+            ElementwiseOperation,
+            PassThrough,
+            InMemoryDataOperationEnum::Set,
+            BlockSliceLengths,
+            InBlockTransferThreadClusterLengths,
+            InBlockTransferThreadClusterArrangeOrder,
+            InDataType,
+            OutDataType,
+            decltype(in_block_desc_nperblock_hperblock_wperblock),
+            decltype(out_grid_desc_n_h_w),
+            InBlockTransferAccessOrder,
+            InBlockTransferAccessOrder,
+            2,
+            DstVectorDimAfterMerge,
+            1,
+            DstScalarPerVector,
+            1,
+            1,
+            true,
+            true>(in_block_desc_nperblock_hperblock_wperblock,
+                  make_multi_index(0, 0, 0),
+                  PassThrough{},
+                  out_grid_desc_n_h_w,
+                  make_multi_index(
+                      n_block_data_idx_on_grid, h_block_data_idx_on_grid, w_block_data_idx_on_grid),
+                  elementwise_op);

-        const auto loop_step = make_multi_index(1, 0, 0);
-        index_t num_iter     = in_grid_desc_n_h_w.GetLength(I0);
-        do
-        {
-            in_global_load.Run(in_grid_desc_n_h_w,
-                               in_global_buf,
-                               in_block_desc_1_hperblock_wperblock,
-                               in_block_buf,
-                               I0);
-
-            in_global_load.MoveSrcSliceWindow(in_grid_desc_n_h_w, loop_step);
-
-            out_global_store.Run(in_block_desc_1_hperblock_wperblock,
-                                 in_block_buf,
-                                 out_grid_desc_n_h_w,
-                                 out_global_buf,
-                                 I0);
-
-            out_global_store.MoveDstSliceWindow(out_grid_desc_n_h_w, loop_step);
-        } while(--num_iter);
+        in_global_load.Run(in_grid_desc_n_h_w,
+                           in_global_buf,
+                           in_block_desc_nperblock_hperblock_wperblock,
+                           in_block_buf,
+                           I0);
+
+        out_global_store.Run(in_block_desc_nperblock_hperblock_wperblock,
+                             in_block_buf,
+                             out_grid_desc_n_h_w,
+                             out_global_buf,
+                             I0);
     }
 };