Commit 7835e2e7 authored by Po-Yen, Chen
Browse files

Remove '1d' in identifiers

parent 51b2b081
...@@ -137,11 +137,11 @@ struct DevicePermute : detail::DevicePermuteBase<DevicePermute<InDataType, ...@@ -137,11 +137,11 @@ struct DevicePermute : detail::DevicePermuteBase<DevicePermute<InDataType,
desc_n_h_w, make_tuple(NPerBlock, HPerBlock, WPerBlock), Sequence<true, true, true>{}); desc_n_h_w, make_tuple(NPerBlock, HPerBlock, WPerBlock), Sequence<true, true, true>{});
} }
using InGrid1dDesc = decltype(MakeDescriptor_N_H_W({1, 1}, {1, 1})); using InGridDesc = decltype(MakeDescriptor_N_H_W({1, 1}, {1, 1}));
using OutGrid1dDesc = decltype(MakeDescriptor_N_H_W({1, 1}, {1, 1})); using OutGridDesc = InGridDesc;
using GridwisePermute = GridwisePermute<InGrid1dDesc, using GridwisePermute = GridwisePermute<InGridDesc,
OutGrid1dDesc, OutGridDesc,
InDataTypePointer, InDataTypePointer,
OutDataTypePointer, OutDataTypePointer,
ElementwiseOperation, ElementwiseOperation,
...@@ -164,21 +164,21 @@ struct DevicePermute : detail::DevicePermuteBase<DevicePermute<InDataType, ...@@ -164,21 +164,21 @@ struct DevicePermute : detail::DevicePermuteBase<DevicePermute<InDataType,
ElementwiseOperation elementwise_op) ElementwiseOperation elementwise_op)
: in_dev_buffer_(static_cast<InDataTypePointer>(in_dev_buffer)), : in_dev_buffer_(static_cast<InDataTypePointer>(in_dev_buffer)),
out_dev_buffer_(static_cast<OutDataTypePointer>(out_dev_buffer)), out_dev_buffer_(static_cast<OutDataTypePointer>(out_dev_buffer)),
in_grid_1d_desc_(MakeDescriptor_N_H_W(inLengths, inStrides)), in_grid_desc_(MakeDescriptor_N_H_W(inLengths, inStrides)),
out_grid_1d_desc_(MakeDescriptor_N_H_W(inLengths, inStrides)), out_grid_desc_(MakeDescriptor_N_H_W(inLengths, inStrides)),
inLengths_(inLengths), inLengths_(inLengths),
inStrides_(inStrides), inStrides_(inStrides),
outLengths_(outLengths), outLengths_(outLengths),
outStrides_(outStrides), outStrides_(outStrides),
elementwise_op_(elementwise_op), elementwise_op_(elementwise_op),
block_2_tile_map_(GridwisePermute::MakeDefaultBlock2TileMap(in_grid_1d_desc_)) block_2_tile_map_(GridwisePermute::MakeDefaultBlock2TileMap(in_grid_desc_))
{ {
} }
InDataTypePointer in_dev_buffer_; InDataTypePointer in_dev_buffer_;
OutDataTypePointer out_dev_buffer_; OutDataTypePointer out_dev_buffer_;
InGrid1dDesc in_grid_1d_desc_; InGridDesc in_grid_desc_;
OutGrid1dDesc out_grid_1d_desc_; OutGridDesc out_grid_desc_;
std::array<index_t, NumDim> inLengths_; std::array<index_t, NumDim> inLengths_;
std::array<index_t, NumDim> inStrides_; std::array<index_t, NumDim> inStrides_;
...@@ -194,11 +194,11 @@ struct DevicePermute : detail::DevicePermuteBase<DevicePermute<InDataType, ...@@ -194,11 +194,11 @@ struct DevicePermute : detail::DevicePermuteBase<DevicePermute<InDataType,
{ {
static float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{}) static float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{})
{ {
const index_t grid_size = arg.block_2_tile_map_.CalculateGridSize(arg.in_grid_1d_desc_); const index_t grid_size = arg.block_2_tile_map_.CalculateGridSize(arg.in_grid_desc_);
const auto kernel = kernel_nd_permute<GridwisePermute, const auto kernel = kernel_nd_permute<GridwisePermute,
InGrid1dDesc, InGridDesc,
OutGrid1dDesc, OutGridDesc,
InDataTypePointer, InDataTypePointer,
OutDataTypePointer, OutDataTypePointer,
ElementwiseOperation, ElementwiseOperation,
...@@ -209,8 +209,8 @@ struct DevicePermute : detail::DevicePermuteBase<DevicePermute<InDataType, ...@@ -209,8 +209,8 @@ struct DevicePermute : detail::DevicePermuteBase<DevicePermute<InDataType,
dim3(grid_size), dim3(grid_size),
dim3(BlockSize), dim3(BlockSize),
0, 0,
arg.in_grid_1d_desc_, arg.in_grid_desc_,
arg.out_grid_1d_desc_, arg.out_grid_desc_,
arg.in_dev_buffer_, arg.in_dev_buffer_,
arg.out_dev_buffer_, arg.out_dev_buffer_,
arg.elementwise_op_, arg.elementwise_op_,
......
...@@ -90,14 +90,14 @@ struct Block2TileMap ...@@ -90,14 +90,14 @@ struct Block2TileMap
} // namespace detail } // namespace detail
template <typename GridwisePermute, template <typename GridwisePermute,
typename InGrid1dDesc, typename InGridDesc,
typename OutGrid1dDesc, typename OutGridDesc,
typename InDataTypePointer, typename InDataTypePointer,
typename OutDataTypePointer, typename OutDataTypePointer,
typename ElementwiseOperation, typename ElementwiseOperation,
typename Block2TileMap> typename Block2TileMap>
__global__ void kernel_nd_permute(const InGrid1dDesc in_grid_1d_desc, __global__ void kernel_nd_permute(const InGridDesc in_grid_desc,
const OutGrid1dDesc out_grid_1d_desc, const OutGridDesc out_grid_desc,
const InDataTypePointer p_in_global, const InDataTypePointer p_in_global,
const OutDataTypePointer p_out_global, const OutDataTypePointer p_out_global,
const ElementwiseOperation elementwise_op, const ElementwiseOperation elementwise_op,
...@@ -105,8 +105,8 @@ __global__ void kernel_nd_permute(const InGrid1dDesc in_grid_1d_desc, ...@@ -105,8 +105,8 @@ __global__ void kernel_nd_permute(const InGrid1dDesc in_grid_1d_desc,
{ {
__shared__ char p_shared[GridwisePermute::GetSharedMemoryNumberOfByte()]; __shared__ char p_shared[GridwisePermute::GetSharedMemoryNumberOfByte()];
GridwisePermute::Run(in_grid_1d_desc, GridwisePermute::Run(in_grid_desc,
out_grid_1d_desc, out_grid_desc,
p_in_global, p_in_global,
p_out_global, p_out_global,
p_shared, p_shared,
...@@ -114,8 +114,8 @@ __global__ void kernel_nd_permute(const InGrid1dDesc in_grid_1d_desc, ...@@ -114,8 +114,8 @@ __global__ void kernel_nd_permute(const InGrid1dDesc in_grid_1d_desc,
block_2_tile_map); block_2_tile_map);
} }
template <typename InGrid1dDesc, template <typename InGridDesc,
typename OutGrid1dDesc, typename OutGridDesc,
typename InDataTypePointer, typename InDataTypePointer,
typename OutDataTypePointer, typename OutDataTypePointer,
typename ElementwiseOperation, typename ElementwiseOperation,
...@@ -128,8 +128,7 @@ template <typename InGrid1dDesc, ...@@ -128,8 +128,7 @@ template <typename InGrid1dDesc,
index_t OutScalarPerVector> index_t OutScalarPerVector>
struct GridwisePermute struct GridwisePermute
{ {
static_assert(InGrid1dDesc::GetNumOfDimension() == 3 && static_assert(InGridDesc::GetNumOfDimension() == 3 && OutGridDesc::GetNumOfDimension() == 3);
OutGrid1dDesc::GetNumOfDimension() == 3);
static constexpr auto I0 = Number<0>{}; static constexpr auto I0 = Number<0>{};
static constexpr auto I1 = Number<1>{}; static constexpr auto I1 = Number<1>{};
...@@ -139,8 +138,7 @@ struct GridwisePermute ...@@ -139,8 +138,7 @@ struct GridwisePermute
using ThisThreadBlock = ThisThreadBlock<BlockSize>; using ThisThreadBlock = ThisThreadBlock<BlockSize>;
using DefaultBlock2TileMap = using DefaultBlock2TileMap = detail::Block2TileMap<Sequence<HPerBlock, WPerBlock>, InGridDesc>;
detail::Block2TileMap<Sequence<HPerBlock, WPerBlock>, InGrid1dDesc>;
__host__ __device__ static constexpr auto GetInBlockDescriptor() __host__ __device__ static constexpr auto GetInBlockDescriptor()
{ {
...@@ -169,21 +167,21 @@ struct GridwisePermute ...@@ -169,21 +167,21 @@ struct GridwisePermute
return a_block_space_size_aligned * sizeof(InDataType); return a_block_space_size_aligned * sizeof(InDataType);
} }
__host__ __device__ static constexpr auto MakeDefaultBlock2TileMap(const InGrid1dDesc& desc) __host__ __device__ static constexpr auto MakeDefaultBlock2TileMap(const InGridDesc& desc)
{ {
return DefaultBlock2TileMap{desc}; return DefaultBlock2TileMap{desc};
} }
template <typename Block2TileMap> template <typename Block2TileMap>
__device__ static void Run(const InGrid1dDesc in_grid_1d_desc, __device__ static void Run(const InGridDesc in_grid_desc,
const OutGrid1dDesc out_grid_1d_desc, const OutGridDesc out_grid_desc,
const InDataTypePointer p_in_global, const InDataTypePointer p_in_global,
const OutDataTypePointer p_out_global, const OutDataTypePointer p_out_global,
void* __restrict__ p_shared, void* __restrict__ p_shared,
const ElementwiseOperation elementwise_op, const ElementwiseOperation elementwise_op,
const Block2TileMap& block_2_tile_map) const Block2TileMap& block_2_tile_map)
{ {
// const index_t thread_global_id = get_thread_global_1d_id(); // const index_t thread_global_id = get_thread_global_id();
using InDataType = remove_cv_t<remove_pointer_t<InDataTypePointer>>; using InDataType = remove_cv_t<remove_pointer_t<InDataTypePointer>>;
// auto in_thread_buf = StaticBuffer<AddressSpaceEnum::Vgpr, InDataType, MPerThread, // auto in_thread_buf = StaticBuffer<AddressSpaceEnum::Vgpr, InDataType, MPerThread,
...@@ -194,16 +192,16 @@ struct GridwisePermute ...@@ -194,16 +192,16 @@ struct GridwisePermute
// true>{}; // true>{};
auto in_global_buf = make_dynamic_buffer<AddressSpaceEnum::Global>( auto in_global_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_in_global, in_grid_1d_desc.GetElementSpaceSize()); p_in_global, in_grid_desc.GetElementSpaceSize());
auto out_global_buf = make_dynamic_buffer<AddressSpaceEnum::Global>( auto out_global_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_out_global, out_grid_1d_desc.GetElementSpaceSize()); p_out_global, out_grid_desc.GetElementSpaceSize());
// const auto thread_global_offset = make_multi_index(thread_global_id * MPerThread); // const auto thread_global_offset = make_multi_index(thread_global_id * MPerThread);
// const index_t blockSize = get_block_size(); // const index_t blockSize = get_block_size();
// const index_t blockPerGrid = get_grid_size(); // const index_t blockPerGrid = get_grid_size();
// const auto M = in_grid_1d_desc.GetLength(I0); // const auto M = in_grid_desc.GetLength(I0);
// const index_t loop_step = blockPerGrid * blockSize * MPerThread; // const index_t loop_step = blockPerGrid * blockSize * MPerThread;
const auto loop_step_index = make_multi_index(1, 0, 0); const auto loop_step_index = make_multi_index(1, 0, 0);
...@@ -259,7 +257,7 @@ struct GridwisePermute ...@@ -259,7 +257,7 @@ struct GridwisePermute
ABlockTransferThreadClusterArrangeOrder, ABlockTransferThreadClusterArrangeOrder,
InDataType, InDataType,
InDataType, InDataType,
decltype(in_grid_1d_desc), decltype(in_grid_desc),
decltype(a_block_desc_ak0_m_ak1), decltype(a_block_desc_ak0_m_ak1),
ABlockTransferSrcAccessOrder, ABlockTransferSrcAccessOrder,
ABlockTransferDstAccessOrder, ABlockTransferDstAccessOrder,
...@@ -271,18 +269,18 @@ struct GridwisePermute ...@@ -271,18 +269,18 @@ struct GridwisePermute
1, 1,
true, true,
true>( true>(
in_grid_1d_desc, in_grid_desc,
make_multi_index(0, h_block_data_idx_on_grid, w_block_data_idx_on_grid), make_multi_index(0, h_block_data_idx_on_grid, w_block_data_idx_on_grid),
ck::tensor_operation::element_wise::PassThrough{}, ck::tensor_operation::element_wise::PassThrough{},
a_block_desc_ak0_m_ak1, a_block_desc_ak0_m_ak1,
make_multi_index(0, 0, 0), make_multi_index(0, 0, 0),
ck::tensor_operation::element_wise::PassThrough{}); ck::tensor_operation::element_wise::PassThrough{});
auto in_grid_1d_desc_tranformed = transform_tensor_descriptor( auto in_grid_desc_tranformed = transform_tensor_descriptor(
in_grid_1d_desc, in_grid_desc,
make_tuple(make_pass_through_transform(in_grid_1d_desc.GetLength(I0)), make_tuple(make_pass_through_transform(in_grid_desc.GetLength(I0)),
make_pass_through_transform(in_grid_1d_desc.GetLength(I1)), make_pass_through_transform(in_grid_desc.GetLength(I1)),
make_pass_through_transform(in_grid_1d_desc.GetLength(I2))), make_pass_through_transform(in_grid_desc.GetLength(I2))),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}), make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}),
make_tuple(Sequence<0>{}, Sequence<2>{}, Sequence<1>{})); make_tuple(Sequence<0>{}, Sequence<2>{}, Sequence<1>{}));
...@@ -297,7 +295,7 @@ struct GridwisePermute ...@@ -297,7 +295,7 @@ struct GridwisePermute
InDataType, InDataType,
OutDataType, OutDataType,
decltype(a_block_desc_ak0_m_ak1), decltype(a_block_desc_ak0_m_ak1),
decltype(in_grid_1d_desc_tranformed), decltype(in_grid_desc_tranformed),
Sequence<0, 1, 2>, // ABlockTransferSrcAccessOrder Sequence<0, 1, 2>, // ABlockTransferSrcAccessOrder
Sequence<0, 1, 2>, // ABlockTransferDstAccessOrder Sequence<0, 1, 2>, // ABlockTransferDstAccessOrder
1, // ABlockTransferSrcVectorDim 1, // ABlockTransferSrcVectorDim
...@@ -310,25 +308,22 @@ struct GridwisePermute ...@@ -310,25 +308,22 @@ struct GridwisePermute
true>(a_block_desc_ak0_m_ak1, true>(a_block_desc_ak0_m_ak1,
make_multi_index(0, 0, 0), make_multi_index(0, 0, 0),
ck::tensor_operation::element_wise::PassThrough{}, ck::tensor_operation::element_wise::PassThrough{},
in_grid_1d_desc_tranformed, in_grid_desc_tranformed,
make_multi_index(0, h_block_data_idx_on_grid, w_block_data_idx_on_grid), make_multi_index(0, h_block_data_idx_on_grid, w_block_data_idx_on_grid),
elementwise_op); elementwise_op);
index_t num_iter = in_grid_1d_desc.GetLength(I0); index_t num_iter = in_grid_desc.GetLength(I0);
do do
{ {
in_global_load.Run( in_global_load.Run(
in_grid_1d_desc, in_global_buf, a_block_desc_ak0_m_ak1, a_block_buf, I0); in_grid_desc, in_global_buf, a_block_desc_ak0_m_ak1, a_block_buf, I0);
in_global_load.MoveSrcSliceWindow(in_grid_1d_desc, loop_step_index); in_global_load.MoveSrcSliceWindow(in_grid_desc, loop_step_index);
out_global_store.Run(a_block_desc_ak0_m_ak1, out_global_store.Run(
a_block_buf, a_block_desc_ak0_m_ak1, a_block_buf, in_grid_desc_tranformed, out_global_buf, I0);
in_grid_1d_desc_tranformed,
out_global_buf,
I0);
out_global_store.MoveDstSliceWindow(in_grid_1d_desc_tranformed, loop_step_index); out_global_store.MoveDstSliceWindow(in_grid_desc_tranformed, loop_step_index);
} while(--num_iter); } while(--num_iter);
} }
}; };
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment