Unverified commit e17c0d80 authored by Qianfeng, committed by GitHub

Reduction in Composable Kernel (#82)



* Initial adding of generic reduction

* Initial adding of generic reduction ...

* Updates to get the code compiling

* clang-format all files

* clang-format some files again

* Renaming in profiler/include/profile_reduce.hpp

* Updates to make BlockWise cases pass

* Updates to make ThreadWise and MultiBlockTwoCall cases pass

* Remove the support for MUL and NORM1 reduceOp from the profiler and the device instances

* Change to replace the dim0_max_vector_size/dim1_max_vector_size template argument in the device reduce classes

* format

* adding pooling

* added max and average pooling

* comment out cout and kernel timing

* Tiny simplification in profiler/reduce_profiler.cpp

* Add example for reduce_blockwise

* Tiny updates

* Change to pass the ElementWiseOp from device layer to kernel

* Fix the vectorDim and vectorSize in Device layer

* Enable vector load on both dim0 and dim1 for Threadwise method

* Tiny updates

* Change to let the user pass the preUnaryOp and posUnaryOp

* Make pooling example work

* split device_reduce_instance into two libraries

* Tiny update

* Replace nanPropaOpt enum by boolean propagate_nan

* Simplification in DeviceReduce layer codes

* update build

* Change to clarify the difference between ck::half_t and half_float::half

* Renaming in all the reduction codes

* Add VectorSize as template parameter for device layer

* Add BetaIsZero as a kernel template parameter and use AccDataType for alpha

* print

* Small updates for pooling

* Updates for host_generic_reduction for reference

* Update to make AVG pooling pass

* Update to make MAX pooling with indices output pass

* fix

* add OutDst vector store to threadwise reduction and pooling

* tweak

* turn off check_indices that caused build issue

* refactor pooling

* clean up

* turn off check_indices due to a build issue with php-compiler

* add more tile size for odd C

* tweak conv for odd C

* update script

* clean up elementwise op

* add hack in reduction_operator.hpp to avoid a compile error; to fix it properly, element_wise_op needs to be used in the reduction op

* Add OutVectorSize as device and kernel tunable, also update to Elementwise Operations

* Move reduce operator mapping to host layer file reduction_operator_mapping.hpp from reduction_operator.hpp

* Change to the unary operators

* Move the definitions of unary operations to element_wise_operation.hpp

* re-org files

* Refine in device interfaces and multiblock kernels

* Split the reduction configurations into instances for specific methods

* Update in getTypeString() of device pool2d

* Renaming in host and kernel

* Tiny update in profiler/src/profiler.cpp

* Uncomment in device_operation/CMakeLists.txt to enable the building of all operations

* Make check_indices a templated function to resolve a linking issue

* Renaming in the profiler reduce module

* Add support for double Reduction (but disable MultiblockAtomicAdd for double)

* Tiny correction of literal string

* Rename DevicePoolFwd to DevicePool2dFwd

* Split device_reduce_instance_xxx.cpp files according to the data types to speed up compiling

* Add comments for lists of configurations, lists of instances and references of add_reduce_instances_xxx

* Remove un-used header file gridwise_generic_reduction_wrapper_common.hpp

* Renaming and refining in the Reduction codes

* Tiny change in the unary operators

* Renaming symbols and files

* Renaming symbols in the kernels

* Move kernel kernel_set_buffer_value to separate file

* Add IndexDataType template parameter for kernels and use int32_t as index data type in device layer

* Tiny update in the kernels

* Remove definition of sqrtf()/isnan()/abs() for half_t due to some ADL issue

* Simplify a helper function in device layer

* Tiny adjustment in testing data initialization

* Renaming in kernel/device/host

* Add two testing scripts for reduction

* Refine the Unary operators in element_wise_operation.hpp

* Update in the reduce profiler module

* Update to the reduction testing scripts

* reduce compile parallelism

* change CI docker to rocm5.0

* remove unused variables

* fix build
Co-authored-by: Chao Liu <chao.liu2@amd.com>
parent 12dfba3d
/*******************************************************************************
*
* MIT License
*
* Copyright (c) 2021 Advanced Micro Devices, Inc.
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
* in the Software without restriction, including without limitation the rights
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
* copies of the Software, and to permit persons to whom the Software is
* furnished to do so, subject to the following conditions:
*
* The above copyright notice and this permission notice shall be included in all
* copies or substantial portions of the Software.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
* SOFTWARE.
*
*******************************************************************************/
#include "config.hpp"
#include "number.hpp"
#include "sequence.hpp"
#include "tensor_descriptor_helper.hpp"
#include "data_type_enum_helper.hpp"
#include "reduction_common.hpp"
#include "gridwise_generic_2d_reduction_multiblock.hpp"
using namespace ck;
using srcDataType =
typename get_datatype_from_enum<static_cast<DataTypeEnum_t>(CK_PARAM_SRC_DATATYPE)>::type;
using dstDataType =
typename get_datatype_from_enum<static_cast<DataTypeEnum_t>(CK_PARAM_DST_DATATYPE)>::type;
using compType =
typename get_datatype_from_enum<static_cast<DataTypeEnum_t>(CK_PARAM_REDUCE_COMPTYPE)>::type;
constexpr index_t BlockSize = CK_PARAM_BLOCKSIZE; // tunable
constexpr index_t srcDims = CK_PARAM_IN_DIMS;
constexpr ReduceTensorOp_t op = static_cast<ReduceTensorOp_t>(CK_PARAM_REDUCE_OP);
constexpr NanPropagation_t nanPropaOpt = CK_PARAM_NAN_PROPAGATE == 0
? NanPropagation_t::NOT_PROPAGATE_NAN
: NanPropagation_t::PROPAGATE_NAN;
constexpr ReduceTensorIndices_t reduceIndicesOpt = CK_PARAM_REDUCE_INDICES == 0
? ReduceTensorIndices_t::NO_INDICES
: ReduceTensorIndices_t::FLATTENED_INDICES;
constexpr bool src2d_need_padding = static_cast<bool>(CK_PARAM_SRC2D_PADDING);
constexpr bool dst1d_need_padding = static_cast<bool>(CK_PARAM_DST1D_PADDING);
constexpr bool indexable = reduce_binary_operator<compType, op>::indexable;
constexpr bool need_indices = indexable && (reduceIndicesOpt != ReduceTensorIndices_t::NO_INDICES);
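// Indices are only produced when the reduce operation is indexable (e.g. MIN/MAX-like
// operations) and the caller requested them; for non-indexable operations the
// FLATTENED_INDICES request is effectively ignored.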
constexpr index_t GredAccessesPerThreadInBlock = CK_PARAM_ACCESSES_PER_THREAD_INBLOCK; // tunable
// helper functions using variadic template arguments
template <index_t... Ns>
__device__ static auto make_tuple_from_array_and_index_seq(const int* lengths, Sequence<Ns...>)
{
return make_tuple(static_cast<index_t>(lengths[Ns])...);
};
template <index_t arraySize>
__device__ static auto make_tuple_from_array(const int* lengths, Number<arraySize>)
{
static_assert(arraySize >= 1 && arraySize <= 6, "The tensor should have 1 to 6 dimensions");
constexpr auto index_seq = typename arithmetic_sequence_gen<0, arraySize, 1>::type{};
return make_tuple_from_array_and_index_seq(lengths, index_seq);
};
template <index_t... Ns>
__device__ static constexpr auto make_tuple_from_seq(Sequence<Ns...>)
{
return make_tuple(Ns...);
};
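// Illustrative use of the helpers above (values are an example only):
//   const int lens[6] = {4, 5, 6, 1, 1, 1};
//   auto t1 = make_tuple_from_array(lens, Number<3>{}); // make_tuple(4, 5, 6)
//   auto t2 = make_tuple_from_seq(Sequence<2, 3>{});    // make_tuple(2, 3)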
extern "C" __global__ void gridwise_generic_reduce_1_prepare(int GridSize,
int BlkGroupSize,
int inLength0,
int inLength1,
int inLength2,
int inLength3,
int inLength4,
int inLength5,
int inStride0,
int inStride1,
int inStride2,
int inStride3,
int inStride4,
int inStride5,
void* __restrict__ ws_global)
{
(void)GridSize;
void* p_src2dDesc = ws_global;
void* p_dst1dDesc = static_cast<char*>(ws_global) + 2048;
const int srcLengths[6] = {inLength0, inLength1, inLength2, inLength3, inLength4, inLength5};
const int srcStrides[6] = {inStride0, inStride1, inStride2, inStride3, inStride4, inStride5};
const auto tupleSrcLengths = make_tuple_from_array(srcLengths, Number<srcDims>{});
const auto tupleSrcStrides = make_tuple_from_array(srcStrides, Number<srcDims>{});
const auto tupleDstLengths = make_tuple(1);
const auto tupleDstStrides = make_tuple(1);
const auto srcDesc = make_naive_tensor_descriptor(tupleSrcLengths, tupleSrcStrides);
auto dstDesc = make_naive_tensor_descriptor(tupleDstLengths, tupleDstStrides);
const auto one_dim_srcDesc = transform_tensor_descriptor(
srcDesc,
make_tuple(make_merge_transform(tupleSrcLengths)),
make_tuple(typename arithmetic_sequence_gen<0, srcDims, 1>::type{}),
make_tuple(Sequence<0>{}));
auto src2dDesc = transform_tensor_descriptor(
one_dim_srcDesc,
make_tuple(make_unmerge_transform(make_tuple(1, one_dim_srcDesc.GetLength(Number<0>{})))),
make_tuple(Sequence<0>{}),
make_tuple(Sequence<0, 1>{}));
constexpr int invariantLen = 1;
const auto toReduceLen = src2dDesc.GetLength(Number<1>{});
constexpr auto copySliceLen = BlockSize * GredAccessesPerThreadInBlock;
const index_t reduceSizePerBlock =
(((toReduceLen + BlkGroupSize - 1) / BlkGroupSize + copySliceLen - 1) / copySliceLen) *
copySliceLen;
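// reduceSizePerBlock rounds each block group's share of the reduce length up to a
// multiple of copySliceLen. Worked example (illustrative numbers only):
//   BlockSize = 256, GredAccessesPerThreadInBlock = 2  -> copySliceLen = 512
//   toReduceLen = 1000, BlkGroupSize = 4               -> ceil(1000 / 4) = 250
//   reduceSizePerBlock = ceil(250 / 512) * 512 = 512
//   srcPad (below)     = 512 * 4 - 1000 = 1048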
if constexpr(src2d_need_padding)
{
const auto srcPad = reduceSizePerBlock * BlkGroupSize - toReduceLen;
auto src2dDesc_2 =
transform_tensor_descriptor(src2dDesc,
make_tuple(make_pass_through_transform(invariantLen),
make_pad_transform(toReduceLen, 0, srcPad)),
make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}));
if(get_thread_local_1d_id() == 0)
*static_cast<decltype(src2dDesc_2)*>(p_src2dDesc) = src2dDesc_2;
}
else
{
if(get_thread_local_1d_id() == 0)
*static_cast<decltype(src2dDesc)*>(p_src2dDesc) = src2dDesc;
}
if(get_thread_local_1d_id() == 0)
*static_cast<decltype(dstDesc)*>(p_dst1dDesc) = dstDesc;
};
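// The prepare kernel stores dynamically built descriptors into ws_global, but the
// main kernel needs concrete C++ types to read them back. get_ref_desc_types below
// rebuilds the same chain of transforms on dummy lengths/strides (all 8) purely to
// obtain those descriptor types; only the types matter, not the stored values.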
template <index_t srcDims>
struct get_ref_desc_types
{
static constexpr auto ref_srcLengths = typename uniform_sequence_gen<srcDims, 8>::type{};
// accurate strides are not needed to obtain the expected reference type
static constexpr auto ref_srcDesc = make_naive_tensor_descriptor(
make_tuple_from_seq(ref_srcLengths), make_tuple_from_seq(ref_srcLengths));
static constexpr auto ref_dstDesc = make_naive_tensor_descriptor(make_tuple(1), make_tuple(1));
static constexpr auto ref_one_dim_srcDesc = transform_tensor_descriptor(
ref_srcDesc,
make_tuple(make_merge_transform(make_tuple_from_seq(ref_srcLengths))),
make_tuple(typename arithmetic_sequence_gen<0, srcDims, 1>::type{}),
make_tuple(Sequence<0>{}));
static constexpr auto ref_src2dDesc =
transform_tensor_descriptor(ref_one_dim_srcDesc,
make_tuple(make_unmerge_transform(
make_tuple(1, ref_one_dim_srcDesc.GetLength(Number<0>{})))),
make_tuple(Sequence<0>{}),
make_tuple(Sequence<0, 1>{}));
static constexpr auto ref_invariantLen = ref_src2dDesc.GetLength(Number<0>{});
static constexpr auto ref_toReduceLen = ref_src2dDesc.GetLength(Number<1>{});
// used by the BlockWise and MultiBlock method
using refType_src2dDesc_padded_34 = decltype(
transform_tensor_descriptor(ref_src2dDesc,
make_tuple(make_pass_through_transform(ref_invariantLen),
make_pad_transform(ref_toReduceLen, 0, 2)),
make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<0>{}, Sequence<1>{})));
using refType_dst1dDesc_padded =
decltype(transform_tensor_descriptor(ref_dstDesc,
make_tuple(make_pad_transform(ref_invariantLen, 0, 2)),
make_tuple(Sequence<0>{}),
make_tuple(Sequence<0>{})));
using refType_src2dDesc = decltype(ref_src2dDesc);
using refType_dst1dDesc = decltype(ref_dstDesc);
};
using refType_src2dDesc = typename get_ref_desc_types<srcDims>::refType_src2dDesc;
using refType_dst1dDesc = typename get_ref_desc_types<srcDims>::refType_dst1dDesc;
using refType_src2dDesc_padded_34 =
typename get_ref_desc_types<srcDims>::refType_src2dDesc_padded_34;
using refType_dst1dDesc_padded = typename get_ref_desc_types<srcDims>::refType_dst1dDesc_padded;
template <bool need_padding>
static __device__ auto get_reduction_src2d_descriptor(const void* p_src2dDesc)
{
if constexpr(need_padding)
return (*reinterpret_cast<const refType_src2dDesc_padded_34*>(p_src2dDesc));
else
return (*reinterpret_cast<const refType_src2dDesc*>(p_src2dDesc));
};
template <bool need_padding>
static __device__ auto get_reduction_dst1d_descriptor(const void* p_dst1dDesc)
{
if constexpr(need_padding)
return (*reinterpret_cast<const refType_dst1dDesc_padded*>(p_dst1dDesc));
else
return (*reinterpret_cast<const refType_dst1dDesc*>(p_dst1dDesc));
};
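// Workspace layout assumed here (it must match the prepare kernel above):
//   bytes [0, 2048)    : src2dDesc, written by gridwise_generic_reduce_1_prepare
//   bytes [2048, 4096) : dst1dDesc
//   bytes [4096, ...)  : ws_buf1_global, per-block partial results; when
//                        ws_buf2_bytes_offset > 0, an index buffer (ws_buf2_global)
//                        follows at that offset.
// This MultiBlock kernel only writes to the workspace (p_dst_global is unused); a
// follow-up kernel is expected to turn the partial results into the final output.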
extern "C" __global__ void gridwise_generic_reduce_1(int origReduceLen,
int BlkGroupSize,
float alpha,
const void* __restrict__ p_src_global,
float beta,
void* __restrict__ p_dst_global,
const void CONSTANT* ws_global,
long ws_buf2_bytes_offset,
void* __restrict__ indices_global)
{
(void)p_dst_global;
(void)indices_global;
const void* p_src2dDesc = cast_pointer_to_generic_address_space(ws_global);
const void* p_dst1dDesc = static_cast<const char*>(p_src2dDesc) + 2048;
void* ws_buf1_global = const_cast<char*>(static_cast<const char*>(p_src2dDesc) + 4096);
const auto src2dDesc = get_reduction_src2d_descriptor<src2d_need_padding>(p_src2dDesc);
const auto dst1dDesc = get_reduction_dst1d_descriptor<dst1d_need_padding>(p_dst1dDesc);
using gridwise_2d_reduce = GridwiseReduction_xy_to_x_multiblock<BlockSize,
srcDataType,
dstDataType,
compType,
decltype(src2dDesc),
decltype(dst1dDesc),
op,
nanPropaOpt,
reduceIndicesOpt,
GredAccessesPerThreadInBlock>;
void* const ws_buf2_global =
ws_buf2_bytes_offset > 0
? static_cast<void*>(static_cast<char*>(ws_buf1_global) + ws_buf2_bytes_offset)
: nullptr;
constexpr int RunId = need_indices ? 2 : 1;
gridwise_2d_reduce::template Run<RunId>(
src2dDesc,
dst1dDesc,
origReduceLen,
BlkGroupSize,
alpha,
static_cast<const srcDataType* const __restrict__>(p_src_global),
beta,
static_cast<srcDataType* const __restrict__>(ws_buf1_global),
static_cast<int* const __restrict__>(ws_buf2_global));
};
/*******************************************************************************
*
* MIT License
*
* Copyright (c) 2021 Advanced Micro Devices, Inc.
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
* in the Software without restriction, including without limitation the rights
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
* copies of the Software, and to permit persons to whom the Software is
* furnished to do so, subject to the following conditions:
*
* The above copyright notice and this permission notice shall be included in all
* copies or substantial portions of the Software.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
* SOFTWARE.
*
*******************************************************************************/
#include "config.hpp"
#include "number.hpp"
#include "sequence.hpp"
#include "tensor_descriptor_helper.hpp"
#include "data_type_enum_helper.hpp"
#include "reduction_common.hpp"
#include "gridwise_generic_2d_reduction_multiblock.hpp"
using namespace ck;
using srcDataType =
typename get_datatype_from_enum<static_cast<DataTypeEnum_t>(CK_PARAM_SRC_DATATYPE)>::type;
using dstDataType =
typename get_datatype_from_enum<static_cast<DataTypeEnum_t>(CK_PARAM_DST_DATATYPE)>::type;
using compType =
typename get_datatype_from_enum<static_cast<DataTypeEnum_t>(CK_PARAM_REDUCE_COMPTYPE)>::type;
constexpr index_t BlockSize = CK_PARAM_BLOCKSIZE; // tunable
constexpr index_t srcDims = CK_PARAM_IN_DIMS;
constexpr index_t dstDims = CK_PARAM_OUT_DIMS;
constexpr index_t num_toReduceDims = CK_PARAM_NUM_TOREDUCE_DIMS;
constexpr index_t num_invariantDims = srcDims - num_toReduceDims;
using invariantDims = typename arithmetic_sequence_gen<0, num_invariantDims, 1>::type;
using toReduceDims = typename arithmetic_sequence_gen<num_invariantDims, srcDims, 1>::type;
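// Convention used by this kernel: the leading num_invariantDims dimensions of the
// source are kept, the trailing num_toReduceDims dimensions are reduced. The prepare
// kernel merges each group into a single dimension, giving the gridwise reduction a
// 2D view of shape [invariantLen, toReduceLen].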
constexpr ReduceTensorOp_t op = static_cast<ReduceTensorOp_t>(CK_PARAM_REDUCE_OP);
constexpr NanPropagation_t nanPropaOpt = CK_PARAM_NAN_PROPAGATE == 0
? NanPropagation_t::NOT_PROPAGATE_NAN
: NanPropagation_t::PROPAGATE_NAN;
constexpr ReduceTensorIndices_t reduceIndicesOpt = CK_PARAM_REDUCE_INDICES == 0
? ReduceTensorIndices_t::NO_INDICES
: ReduceTensorIndices_t::FLATTENED_INDICES;
constexpr bool src2d_need_padding = static_cast<bool>(CK_PARAM_SRC2D_PADDING);
constexpr bool dst1d_need_padding = static_cast<bool>(CK_PARAM_DST1D_PADDING);
static_assert(num_invariantDims > 0, "This kernel requires at least one invariant (non-reduced) dimension!");
constexpr bool indexable = reduce_binary_operator<compType, op>::indexable;
constexpr bool need_indices = indexable && (reduceIndicesOpt != ReduceTensorIndices_t::NO_INDICES);
constexpr index_t GredAccessesPerThreadInBlock = CK_PARAM_ACCESSES_PER_THREAD_INBLOCK; // tunable
// helper functions using variadic template arguments
template <index_t... Ns>
__device__ static auto make_tuple_from_array_and_index_seq(const int* lengths, Sequence<Ns...>)
{
return make_tuple(static_cast<index_t>(lengths[Ns])...);
};
template <index_t arraySize>
__device__ static auto make_tuple_from_array(const int* lengths, Number<arraySize>)
{
static_assert(arraySize >= 1 && arraySize <= 6, "The tensor should have 1 to 6 dimensions");
constexpr auto index_seq = typename arithmetic_sequence_gen<0, arraySize, 1>::type{};
return make_tuple_from_array_and_index_seq(lengths, index_seq);
};
template <index_t... Ns>
__device__ static constexpr auto make_tuple_from_seq(Sequence<Ns...>)
{
return make_tuple(Ns...);
};
extern "C" __global__ void gridwise_generic_reduce_1_prepare(int GridSize,
int BlkGroupSize,
int inLength0,
int inLength1,
int inLength2,
int inLength3,
int inLength4,
int inLength5,
int inStride0,
int inStride1,
int inStride2,
int inStride3,
int inStride4,
int inStride5,
int outStride0,
int outStride1,
int outStride2,
int outStride3,
int outStride4,
int outStride5,
void* __restrict__ ws_global)
{
(void)GridSize;
void* p_src2dDesc = ws_global;
void* p_dst1dDesc = static_cast<char*>(ws_global) + 2048;
const int srcLengths[6] = {inLength0, inLength1, inLength2, inLength3, inLength4, inLength5};
const int srcStrides[6] = {inStride0, inStride1, inStride2, inStride3, inStride4, inStride5};
const int dstStrides[6] = {
outStride0, outStride1, outStride2, outStride3, outStride4, outStride5};
const auto tupleSrcLengths = make_tuple_from_array(srcLengths, Number<srcDims>{});
const auto tupleSrcStrides = make_tuple_from_array(srcStrides, Number<srcDims>{});
const auto tupleDstLengths = make_tuple_from_array(srcLengths, Number<dstDims>{});
const auto tupleDstStrides = make_tuple_from_array(dstStrides, Number<dstDims>{});
const auto srcDesc = make_naive_tensor_descriptor(tupleSrcLengths, tupleSrcStrides);
const auto dstDesc = make_naive_tensor_descriptor(tupleDstLengths, tupleDstStrides);
const auto toReduceDimLengths = make_tuple_from_array_and_index_seq(srcLengths, toReduceDims{});
const auto invariantDimLengths =
make_tuple_from_array_and_index_seq(srcLengths, invariantDims{});
auto src2dDesc =
transform_tensor_descriptor(srcDesc,
make_tuple(make_merge_transform(invariantDimLengths),
make_merge_transform(toReduceDimLengths)),
make_tuple(invariantDims{}, toReduceDims{}),
make_tuple(Sequence<0>{}, Sequence<1>{}));
auto dst1dDesc = transform_tensor_descriptor(
dstDesc,
make_tuple(make_merge_transform(tupleDstLengths)),
make_tuple(typename arithmetic_sequence_gen<0, dstDims, 1>::type{}),
make_tuple(Sequence<0>{}));
const auto invariantLen = src2dDesc.GetLength(Number<0>{});
const auto toReduceLen = src2dDesc.GetLength(Number<1>{});
constexpr auto copySliceLen = BlockSize * GredAccessesPerThreadInBlock;
const index_t reduceSizePerBlock =
(((toReduceLen + BlkGroupSize - 1) / BlkGroupSize + copySliceLen - 1) / copySliceLen) *
copySliceLen;
if constexpr(src2d_need_padding)
{
const auto srcPad = reduceSizePerBlock * BlkGroupSize - toReduceLen;
auto src2dDesc_2 =
transform_tensor_descriptor(src2dDesc,
make_tuple(make_pass_through_transform(invariantLen),
make_pad_transform(toReduceLen, 0, srcPad)),
make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}));
if(get_thread_local_1d_id() == 0)
*static_cast<decltype(src2dDesc_2)*>(p_src2dDesc) = src2dDesc_2;
}
else
{
if(get_thread_local_1d_id() == 0)
*static_cast<decltype(src2dDesc)*>(p_src2dDesc) = src2dDesc;
}
if(get_thread_local_1d_id() == 0)
*static_cast<decltype(dst1dDesc)*>(p_dst1dDesc) = dst1dDesc;
};
template <index_t srcDims, index_t dstDims, typename invariantDims, typename toReduceDims>
struct get_ref_desc_types
{
static constexpr auto ref_toReduceDimLengths =
typename uniform_sequence_gen<toReduceDims::Size(), 8>::type{};
static constexpr auto ref_invariantDimLengths =
typename uniform_sequence_gen<invariantDims::Size(), 8>::type{};
static constexpr auto ref_srcLengths = typename uniform_sequence_gen<srcDims, 8>::type{};
static constexpr auto ref_dstLengths = typename uniform_sequence_gen<dstDims, 8>::type{};
// accurate strides are not needed to obtain the expected reference type
static constexpr auto ref_srcDesc = make_naive_tensor_descriptor(
make_tuple_from_seq(ref_srcLengths), make_tuple_from_seq(ref_srcLengths));
static constexpr auto ref_dstDesc = make_naive_tensor_descriptor(
make_tuple_from_seq(ref_dstLengths), make_tuple_from_seq(ref_dstLengths));
static constexpr auto ref_src2dDesc = transform_tensor_descriptor(
ref_srcDesc,
make_tuple(make_merge_transform(make_tuple_from_seq(ref_invariantDimLengths)),
make_merge_transform(make_tuple_from_seq(ref_toReduceDimLengths))),
make_tuple(invariantDims{}, toReduceDims{}),
make_tuple(Sequence<0>{}, Sequence<1>{}));
static constexpr auto ref_dst1dDesc = transform_tensor_descriptor(
ref_dstDesc,
make_tuple(make_merge_transform(make_tuple_from_seq(ref_dstLengths))),
make_tuple(typename arithmetic_sequence_gen<0, dstDims, 1>::type{}),
make_tuple(Sequence<0>{}));
static constexpr auto ref_invariantLen = ref_src2dDesc.GetLength(Number<0>{});
static constexpr auto ref_toReduceLen = ref_src2dDesc.GetLength(Number<1>{});
// used by the BlockWise and MultiBlock method
using refType_src2dDesc_padded_34 = decltype(
transform_tensor_descriptor(ref_src2dDesc,
make_tuple(make_pass_through_transform(ref_invariantLen),
make_pad_transform(ref_toReduceLen, 0, 2)),
make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<0>{}, Sequence<1>{})));
using refType_dst1dDesc_padded =
decltype(transform_tensor_descriptor(ref_dst1dDesc,
make_tuple(make_pad_transform(ref_invariantLen, 0, 2)),
make_tuple(Sequence<0>{}),
make_tuple(Sequence<0>{})));
using refType_src2dDesc = decltype(ref_src2dDesc);
using refType_dst1dDesc = decltype(ref_dst1dDesc);
};
using refType_src2dDesc =
typename get_ref_desc_types<srcDims, dstDims, invariantDims, toReduceDims>::refType_src2dDesc;
using refType_dst1dDesc =
typename get_ref_desc_types<srcDims, dstDims, invariantDims, toReduceDims>::refType_dst1dDesc;
using refType_src2dDesc_padded_34 =
typename get_ref_desc_types<srcDims, dstDims, invariantDims, toReduceDims>::
refType_src2dDesc_padded_34;
using refType_dst1dDesc_padded =
typename get_ref_desc_types<srcDims, dstDims, invariantDims, toReduceDims>::
refType_dst1dDesc_padded;
template <bool need_padding>
static __device__ auto get_reduction_src2d_descriptor(const void* p_src2dDesc)
{
if constexpr(need_padding)
return (*reinterpret_cast<const refType_src2dDesc_padded_34*>(p_src2dDesc));
else
return (*reinterpret_cast<const refType_src2dDesc*>(p_src2dDesc));
};
template <bool need_padding>
static __device__ auto get_reduction_dst1d_descriptor(const void* p_dst1dDesc)
{
if constexpr(need_padding)
return (*reinterpret_cast<const refType_dst1dDesc_padded*>(p_dst1dDesc));
else
return (*reinterpret_cast<const refType_dst1dDesc*>(p_dst1dDesc));
};
extern "C" __global__ void gridwise_generic_reduce_1(int origReduceLen,
int BlkGroupSize,
float alpha,
const void* __restrict__ p_src_global,
float beta,
void* __restrict__ p_dst_global,
const void CONSTANT* ws_global,
long ws_buf2_bytes_offset,
void* __restrict__ indices_global)
{
(void)p_dst_global;
(void)indices_global;
const void* p_src2dDesc = cast_pointer_to_generic_address_space(ws_global);
const void* p_dst1dDesc = static_cast<const char*>(p_src2dDesc) + 2048;
void* ws_buf1_global = const_cast<char*>(static_cast<const char*>(p_src2dDesc) + 4096);
const auto src2dDesc = get_reduction_src2d_descriptor<src2d_need_padding>(p_src2dDesc);
const auto dst1dDesc = get_reduction_dst1d_descriptor<dst1d_need_padding>(p_dst1dDesc);
using gridwise_2d_reduce = GridwiseReduction_xy_to_x_multiblock<BlockSize,
srcDataType,
dstDataType,
compType,
decltype(src2dDesc),
decltype(dst1dDesc),
op,
nanPropaOpt,
reduceIndicesOpt,
GredAccessesPerThreadInBlock>;
void* const ws_buf2_global =
ws_buf2_bytes_offset > 0
? static_cast<void*>(static_cast<char*>(ws_buf1_global) + ws_buf2_bytes_offset)
: nullptr;
constexpr int RunId = need_indices ? 2 : 1;
gridwise_2d_reduce::template Run<RunId>(
src2dDesc,
dst1dDesc,
origReduceLen,
BlkGroupSize,
alpha,
static_cast<const srcDataType* const __restrict__>(p_src_global),
beta,
static_cast<srcDataType* const __restrict__>(ws_buf1_global),
static_cast<int* const __restrict__>(ws_buf2_global));
};
/*******************************************************************************
*
* MIT License
*
* Copyright (c) 2021 Advanced Micro Devices, Inc.
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
* in the Software without restriction, including without limitation the rights
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
* copies of the Software, and to permit persons to whom the Software is
* furnished to do so, subject to the following conditions:
*
* The above copyright notice and this permission notice shall be included in all
* copies or substantial portions of the Software.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
* SOFTWARE.
*
*******************************************************************************/
#include "config.hpp"
#include "number.hpp"
#include "sequence.hpp"
#include "tensor_descriptor_helper.hpp"
#include "data_type_enum_helper.hpp"
#include "reduction_common.hpp"
#include "gridwise_generic_2d_reduction_direct_threadwise.hpp"
using namespace ck;
using srcDataType =
typename get_datatype_from_enum<static_cast<DataTypeEnum_t>(CK_PARAM_SRC_DATATYPE)>::type;
using dstDataType =
typename get_datatype_from_enum<static_cast<DataTypeEnum_t>(CK_PARAM_DST_DATATYPE)>::type;
using compType =
typename get_datatype_from_enum<static_cast<DataTypeEnum_t>(CK_PARAM_REDUCE_COMPTYPE)>::type;
constexpr index_t BlockSize = CK_PARAM_BLOCKSIZE; // tunable
constexpr index_t srcDims = CK_PARAM_IN_DIMS;
constexpr ReduceTensorOp_t op = static_cast<ReduceTensorOp_t>(CK_PARAM_REDUCE_OP);
constexpr NanPropagation_t nanPropaOpt = CK_PARAM_NAN_PROPAGATE == 0
? NanPropagation_t::NOT_PROPAGATE_NAN
: NanPropagation_t::PROPAGATE_NAN;
constexpr ReduceTensorIndices_t reduceIndicesOpt = CK_PARAM_REDUCE_INDICES == 0
? ReduceTensorIndices_t::NO_INDICES
: ReduceTensorIndices_t::FLATTENED_INDICES;
constexpr bool src2d_need_padding = static_cast<bool>(CK_PARAM_SRC2D_PADDING);
constexpr bool dst1d_need_padding = static_cast<bool>(CK_PARAM_DST1D_PADDING);
constexpr bool indexable = reduce_binary_operator<compType, op>::indexable;
constexpr bool need_indices = indexable && (reduceIndicesOpt != ReduceTensorIndices_t::NO_INDICES);
constexpr index_t GredThreadBufferLength = CK_PARAM_THREAD_BUFFER_LENGTH; // tunable
// helper functions using variadic template arguments
template <index_t... Ns>
__device__ static auto make_tuple_from_array_and_index_seq(const int* lengths, Sequence<Ns...>)
{
return make_tuple(static_cast<index_t>(lengths[Ns])...);
};
template <index_t arraySize>
__device__ static auto make_tuple_from_array(const int* lengths, Number<arraySize>)
{
static_assert(arraySize >= 1 && arraySize <= 6, "The tensor should have 1 to 6 dimensions");
constexpr auto index_seq = typename arithmetic_sequence_gen<0, arraySize, 1>::type{};
return make_tuple_from_array_and_index_seq(lengths, index_seq);
};
template <index_t... Ns>
__device__ static constexpr auto make_tuple_from_seq(Sequence<Ns...>)
{
return make_tuple(Ns...);
};
extern "C" __global__ void gridwise_generic_reduce_1_prepare(int GridSize,
int BlkGroupSize,
int inLength0,
int inLength1,
int inLength2,
int inLength3,
int inLength4,
int inLength5,
int inStride0,
int inStride1,
int inStride2,
int inStride3,
int inStride4,
int inStride5,
void* __restrict__ ws_global)
{
(void)BlkGroupSize;
void* p_src2dDesc = ws_global;
void* p_dst1dDesc = static_cast<char*>(ws_global) + 2048;
const int srcLengths[6] = {inLength0, inLength1, inLength2, inLength3, inLength4, inLength5};
const int srcStrides[6] = {inStride0, inStride1, inStride2, inStride3, inStride4, inStride5};
const auto tupleSrcLengths = make_tuple_from_array(srcLengths, Number<srcDims>{});
const auto tupleSrcStrides = make_tuple_from_array(srcStrides, Number<srcDims>{});
const auto tupleDstLengths = make_tuple(1);
const auto tupleDstStrides = make_tuple(1);
const auto srcDesc = make_naive_tensor_descriptor(tupleSrcLengths, tupleSrcStrides);
auto dstDesc = make_naive_tensor_descriptor(tupleDstLengths, tupleDstStrides);
const auto one_dim_srcDesc = transform_tensor_descriptor(
srcDesc,
make_tuple(make_merge_transform(tupleSrcLengths)),
make_tuple(typename arithmetic_sequence_gen<0, srcDims, 1>::type{}),
make_tuple(Sequence<0>{}));
auto src2dDesc = transform_tensor_descriptor(
one_dim_srcDesc,
make_tuple(make_unmerge_transform(make_tuple(1, one_dim_srcDesc.GetLength(Number<0>{})))),
make_tuple(Sequence<0>{}),
make_tuple(Sequence<0, 1>{}));
constexpr int invariantLen = 1;
const auto toReduceLen = src2dDesc.GetLength(Number<1>{});
constexpr auto copySliceLen = GredThreadBufferLength;
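// For the DirectThreadWise method, padding (when enabled) extends dim0 up to
// GridSize * BlockSize (one output element per thread) and dim1 up to a multiple
// of the per-thread buffer length (copySliceLen).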
if constexpr(src2d_need_padding)
{
const auto srcPad1 = GridSize * BlockSize - invariantLen;
const auto srcPad2 =
((toReduceLen + copySliceLen - 1) / copySliceLen) * copySliceLen - toReduceLen;
auto src2dDesc_2 =
transform_tensor_descriptor(src2dDesc,
make_tuple(make_pad_transform(invariantLen, 0, srcPad1),
make_pad_transform(toReduceLen, 0, srcPad2)),
make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}));
if(get_thread_local_1d_id() == 0)
*static_cast<decltype(src2dDesc_2)*>(p_src2dDesc) = src2dDesc_2;
}
else
{
if(get_thread_local_1d_id() == 0)
*static_cast<decltype(src2dDesc)*>(p_src2dDesc) = src2dDesc;
}
if constexpr(dst1d_need_padding)
{
const auto dstPad = GridSize * BlockSize - invariantLen;
auto dst1dDesc_2 =
transform_tensor_descriptor(dstDesc,
make_tuple(make_pad_transform(invariantLen, 0, dstPad)),
make_tuple(Sequence<0>{}),
make_tuple(Sequence<0>{}));
if(get_thread_local_1d_id() == 0)
*static_cast<decltype(dst1dDesc_2)*>(p_dst1dDesc) = dst1dDesc_2;
}
else
{
if(get_thread_local_1d_id() == 0)
*static_cast<decltype(dstDesc)*>(p_dst1dDesc) = dstDesc;
}
};
template <index_t srcDims>
struct get_ref_desc_types
{
static constexpr auto ref_srcLengths = typename uniform_sequence_gen<srcDims, 8>::type{};
// accurate strides are not needed to obtain the expected reference type
static constexpr auto ref_srcDesc = make_naive_tensor_descriptor(
make_tuple_from_seq(ref_srcLengths), make_tuple_from_seq(ref_srcLengths));
static constexpr auto ref_dstDesc = make_naive_tensor_descriptor(make_tuple(1), make_tuple(1));
static constexpr auto ref_one_dim_srcDesc = transform_tensor_descriptor(
ref_srcDesc,
make_tuple(make_merge_transform(make_tuple_from_seq(ref_srcLengths))),
make_tuple(typename arithmetic_sequence_gen<0, srcDims, 1>::type{}),
make_tuple(Sequence<0>{}));
static constexpr auto ref_src2dDesc =
transform_tensor_descriptor(ref_one_dim_srcDesc,
make_tuple(make_unmerge_transform(
make_tuple(1, ref_one_dim_srcDesc.GetLength(Number<0>{})))),
make_tuple(Sequence<0>{}),
make_tuple(Sequence<0, 1>{}));
static constexpr auto ref_invariantLen = ref_src2dDesc.GetLength(Number<0>{});
static constexpr auto ref_toReduceLen = ref_src2dDesc.GetLength(Number<1>{});
// used by the DirectThreadWise and DirectWarpWise method
using refType_src2dDesc_padded_12 =
decltype(transform_tensor_descriptor(ref_src2dDesc,
make_tuple(make_pad_transform(ref_invariantLen, 0, 2),
make_pad_transform(ref_toReduceLen, 0, 2)),
make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<0>{}, Sequence<1>{})));
using refType_dst1dDesc_padded =
decltype(transform_tensor_descriptor(ref_dstDesc,
make_tuple(make_pad_transform(ref_invariantLen, 0, 2)),
make_tuple(Sequence<0>{}),
make_tuple(Sequence<0>{})));
using refType_src2dDesc = decltype(ref_src2dDesc);
using refType_dst1dDesc = decltype(ref_dstDesc);
};
using refType_src2dDesc = typename get_ref_desc_types<srcDims>::refType_src2dDesc;
using refType_dst1dDesc = typename get_ref_desc_types<srcDims>::refType_dst1dDesc;
using refType_src2dDesc_padded_12 =
typename get_ref_desc_types<srcDims>::refType_src2dDesc_padded_12;
using refType_dst1dDesc_padded = typename get_ref_desc_types<srcDims>::refType_dst1dDesc_padded;
template <bool need_padding>
static __device__ auto get_reduction_src2d_descriptor(const void* p_src2dDesc)
{
if constexpr(need_padding)
return (*reinterpret_cast<const refType_src2dDesc_padded_12*>(p_src2dDesc));
else
return (*reinterpret_cast<const refType_src2dDesc*>(p_src2dDesc));
};
template <bool need_padding>
static __device__ auto get_reduction_dst1d_descriptor(const void* p_dst1dDesc)
{
if constexpr(need_padding)
return (*reinterpret_cast<const refType_dst1dDesc_padded*>(p_dst1dDesc));
else
return (*reinterpret_cast<const refType_dst1dDesc*>(p_dst1dDesc));
};
extern "C" __global__ void gridwise_generic_reduce_1(int origReduceLen,
int BlkGroupSize,
float alpha,
const void* __restrict__ p_src_global,
float beta,
void* __restrict__ p_dst_global,
const void CONSTANT* ws_global,
long ws_buf2_bytes_offset,
void* __restrict__ indices_global)
{
(void)BlkGroupSize;
(void)ws_buf2_bytes_offset;
const void* p_src2dDesc = cast_pointer_to_generic_address_space(ws_global);
const void* p_dst1dDesc = static_cast<const char*>(p_src2dDesc) + 2048;
const auto src2dDesc = get_reduction_src2d_descriptor<src2d_need_padding>(p_src2dDesc);
const auto dst1dDesc = get_reduction_dst1d_descriptor<dst1d_need_padding>(p_dst1dDesc);
using gridwise_2d_reduce = GridwiseReduction_xy_to_x_direct_threadwise<BlockSize,
srcDataType,
dstDataType,
compType,
decltype(src2dDesc),
decltype(dst1dDesc),
op,
nanPropaOpt,
reduceIndicesOpt,
true,
true,
GredThreadBufferLength>;
constexpr int RunId = need_indices ? 2 : 1;
gridwise_2d_reduce::template Run<RunId>(
src2dDesc,
dst1dDesc,
origReduceLen,
alpha,
static_cast<const srcDataType* const __restrict__>(p_src_global),
beta,
static_cast<dstDataType* const __restrict__>(p_dst_global),
static_cast<const int* const __restrict__>(nullptr),
static_cast<int* const __restrict__>(indices_global));
};
/*******************************************************************************
*
* MIT License
*
* Copyright (c) 2021 Advanced Micro Devices, Inc.
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
* in the Software without restriction, including without limitation the rights
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
* copies of the Software, and to permit persons to whom the Software is
* furnished to do so, subject to the following conditions:
*
* The above copyright notice and this permission notice shall be included in all
* copies or substantial portions of the Software.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
* SOFTWARE.
*
*******************************************************************************/
#include "config.hpp"
#include "number.hpp"
#include "sequence.hpp"
#include "tensor_descriptor_helper.hpp"
#include "data_type_enum_helper.hpp"
#include "reduction_common.hpp"
#include "gridwise_generic_2d_reduction_direct_threadwise.hpp"
using namespace ck;
using srcDataType =
typename get_datatype_from_enum<static_cast<DataTypeEnum_t>(CK_PARAM_SRC_DATATYPE)>::type;
using dstDataType =
typename get_datatype_from_enum<static_cast<DataTypeEnum_t>(CK_PARAM_DST_DATATYPE)>::type;
using compType =
typename get_datatype_from_enum<static_cast<DataTypeEnum_t>(CK_PARAM_REDUCE_COMPTYPE)>::type;
constexpr index_t BlockSize = CK_PARAM_BLOCKSIZE; // tunable
constexpr index_t srcDims = CK_PARAM_IN_DIMS;
constexpr index_t dstDims = CK_PARAM_OUT_DIMS;
constexpr index_t num_toReduceDims = CK_PARAM_NUM_TOREDUCE_DIMS;
constexpr index_t num_invariantDims = srcDims - num_toReduceDims;
using invariantDims = typename arithmetic_sequence_gen<0, num_invariantDims, 1>::type;
using toReduceDims = typename arithmetic_sequence_gen<num_invariantDims, srcDims, 1>::type;
constexpr ReduceTensorOp_t op = static_cast<ReduceTensorOp_t>(CK_PARAM_REDUCE_OP);
constexpr NanPropagation_t nanPropaOpt = CK_PARAM_NAN_PROPAGATE == 0
? NanPropagation_t::NOT_PROPAGATE_NAN
: NanPropagation_t::PROPAGATE_NAN;
constexpr ReduceTensorIndices_t reduceIndicesOpt = CK_PARAM_REDUCE_INDICES == 0
? ReduceTensorIndices_t::NO_INDICES
: ReduceTensorIndices_t::FLATTENED_INDICES;
constexpr bool src2d_need_padding = static_cast<bool>(CK_PARAM_SRC2D_PADDING);
constexpr bool dst1d_need_padding = static_cast<bool>(CK_PARAM_DST1D_PADDING);
static_assert(num_invariantDims > 0, "This kernel requires at least one invariant (non-reduced) dimension!");
constexpr bool indexable = reduce_binary_operator<compType, op>::indexable;
constexpr bool need_indices = indexable && (reduceIndicesOpt != ReduceTensorIndices_t::NO_INDICES);
constexpr index_t GredThreadBufferLength = CK_PARAM_THREAD_BUFFER_LENGTH; // tunable
// helper functions using variadic template arguments
template <index_t... Ns>
__device__ static auto make_tuple_from_array_and_index_seq(const int* lengths, Sequence<Ns...>)
{
return make_tuple(static_cast<index_t>(lengths[Ns])...);
};
template <index_t arraySize>
__device__ static auto make_tuple_from_array(const int* lengths, Number<arraySize>)
{
static_assert(arraySize >= 1 && arraySize <= 6, "The tensor should have 1 to 6 dimensions");
constexpr auto index_seq = typename arithmetic_sequence_gen<0, arraySize, 1>::type{};
return make_tuple_from_array_and_index_seq(lengths, index_seq);
};
template <index_t... Ns>
__device__ static constexpr auto make_tuple_from_seq(Sequence<Ns...>)
{
return make_tuple(Ns...);
};
extern "C" __global__ void gridwise_generic_reduce_1_prepare(int GridSize,
int BlkGroupSize,
int inLength0,
int inLength1,
int inLength2,
int inLength3,
int inLength4,
int inLength5,
int inStride0,
int inStride1,
int inStride2,
int inStride3,
int inStride4,
int inStride5,
int outStride0,
int outStride1,
int outStride2,
int outStride3,
int outStride4,
int outStride5,
void* __restrict__ ws_global)
{
(void)BlkGroupSize;
void* p_src2dDesc = ws_global;
void* p_dst1dDesc = static_cast<char*>(ws_global) + 2048;
const int srcLengths[6] = {inLength0, inLength1, inLength2, inLength3, inLength4, inLength5};
const int srcStrides[6] = {inStride0, inStride1, inStride2, inStride3, inStride4, inStride5};
const int dstStrides[6] = {
outStride0, outStride1, outStride2, outStride3, outStride4, outStride5};
const auto tupleSrcLengths = make_tuple_from_array(srcLengths, Number<srcDims>{});
const auto tupleSrcStrides = make_tuple_from_array(srcStrides, Number<srcDims>{});
const auto tupleDstLengths = make_tuple_from_array(srcLengths, Number<dstDims>{});
const auto tupleDstStrides = make_tuple_from_array(dstStrides, Number<dstDims>{});
const auto srcDesc = make_naive_tensor_descriptor(tupleSrcLengths, tupleSrcStrides);
const auto dstDesc = make_naive_tensor_descriptor(tupleDstLengths, tupleDstStrides);
const auto toReduceDimLengths = make_tuple_from_array_and_index_seq(srcLengths, toReduceDims{});
const auto invariantDimLengths =
make_tuple_from_array_and_index_seq(srcLengths, invariantDims{});
auto src2dDesc =
transform_tensor_descriptor(srcDesc,
make_tuple(make_merge_transform(invariantDimLengths),
make_merge_transform(toReduceDimLengths)),
make_tuple(invariantDims{}, toReduceDims{}),
make_tuple(Sequence<0>{}, Sequence<1>{}));
auto dst1dDesc = transform_tensor_descriptor(
dstDesc,
make_tuple(make_merge_transform(tupleDstLengths)),
make_tuple(typename arithmetic_sequence_gen<0, dstDims, 1>::type{}),
make_tuple(Sequence<0>{}));
const auto invariantLen = src2dDesc.GetLength(Number<0>{});
const auto toReduceLen = src2dDesc.GetLength(Number<1>{});
constexpr auto copySliceLen = GredThreadBufferLength;
if constexpr(src2d_need_padding)
{
const auto srcPad1 = GridSize * BlockSize - invariantLen;
const auto srcPad2 =
((toReduceLen + copySliceLen - 1) / copySliceLen) * copySliceLen - toReduceLen;
auto src2dDesc_2 =
transform_tensor_descriptor(src2dDesc,
make_tuple(make_pad_transform(invariantLen, 0, srcPad1),
make_pad_transform(toReduceLen, 0, srcPad2)),
make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}));
if(get_thread_local_1d_id() == 0)
*static_cast<decltype(src2dDesc_2)*>(p_src2dDesc) = src2dDesc_2;
}
else
{
if(get_thread_local_1d_id() == 0)
*static_cast<decltype(src2dDesc)*>(p_src2dDesc) = src2dDesc;
}
if constexpr(dst1d_need_padding)
{
const auto dstPad = GridSize * BlockSize - invariantLen;
auto dst1dDesc_2 =
transform_tensor_descriptor(dst1dDesc,
make_tuple(make_pad_transform(invariantLen, 0, dstPad)),
make_tuple(Sequence<0>{}),
make_tuple(Sequence<0>{}));
if(get_thread_local_1d_id() == 0)
*static_cast<decltype(dst1dDesc_2)*>(p_dst1dDesc) = dst1dDesc_2;
}
else
{
if(get_thread_local_1d_id() == 0)
*static_cast<decltype(dst1dDesc)*>(p_dst1dDesc) = dst1dDesc;
}
};
template <index_t srcDims, index_t dstDims, typename invariantDims, typename toReduceDims>
struct get_ref_desc_types
{
static constexpr auto ref_toReduceDimLengths =
typename uniform_sequence_gen<toReduceDims::Size(), 8>::type{};
static constexpr auto ref_invariantDimLengths =
typename uniform_sequence_gen<invariantDims::Size(), 8>::type{};
static constexpr auto ref_srcLengths = typename uniform_sequence_gen<srcDims, 8>::type{};
static constexpr auto ref_dstLengths = typename uniform_sequence_gen<dstDims, 8>::type{};
// accurate strides are not needed to obtain the expected reference type
static constexpr auto ref_srcDesc = make_naive_tensor_descriptor(
make_tuple_from_seq(ref_srcLengths), make_tuple_from_seq(ref_srcLengths));
static constexpr auto ref_dstDesc = make_naive_tensor_descriptor(
make_tuple_from_seq(ref_dstLengths), make_tuple_from_seq(ref_dstLengths));
static constexpr auto ref_src2dDesc = transform_tensor_descriptor(
ref_srcDesc,
make_tuple(make_merge_transform(make_tuple_from_seq(ref_invariantDimLengths)),
make_merge_transform(make_tuple_from_seq(ref_toReduceDimLengths))),
make_tuple(invariantDims{}, toReduceDims{}),
make_tuple(Sequence<0>{}, Sequence<1>{}));
static constexpr auto ref_dst1dDesc = transform_tensor_descriptor(
ref_dstDesc,
make_tuple(make_merge_transform(make_tuple_from_seq(ref_dstLengths))),
make_tuple(typename arithmetic_sequence_gen<0, dstDims, 1>::type{}),
make_tuple(Sequence<0>{}));
static constexpr auto ref_invariantLen = ref_src2dDesc.GetLength(Number<0>{});
static constexpr auto ref_toReduceLen = ref_src2dDesc.GetLength(Number<1>{});
// used by the DirectThreadWise and DirectWarpWise method
using refType_src2dDesc_padded_12 =
decltype(transform_tensor_descriptor(ref_src2dDesc,
make_tuple(make_pad_transform(ref_invariantLen, 0, 2),
make_pad_transform(ref_toReduceLen, 0, 2)),
make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<0>{}, Sequence<1>{})));
using refType_dst1dDesc_padded =
decltype(transform_tensor_descriptor(ref_dst1dDesc,
make_tuple(make_pad_transform(ref_invariantLen, 0, 2)),
make_tuple(Sequence<0>{}),
make_tuple(Sequence<0>{})));
using refType_src2dDesc = decltype(ref_src2dDesc);
using refType_dst1dDesc = decltype(ref_dst1dDesc);
};
using refType_src2dDesc =
typename get_ref_desc_types<srcDims, dstDims, invariantDims, toReduceDims>::refType_src2dDesc;
using refType_dst1dDesc =
typename get_ref_desc_types<srcDims, dstDims, invariantDims, toReduceDims>::refType_dst1dDesc;
using refType_src2dDesc_padded_12 =
typename get_ref_desc_types<srcDims, dstDims, invariantDims, toReduceDims>::
refType_src2dDesc_padded_12;
using refType_dst1dDesc_padded =
typename get_ref_desc_types<srcDims, dstDims, invariantDims, toReduceDims>::
refType_dst1dDesc_padded;
template <bool need_padding>
static __device__ auto get_reduction_src2d_descriptor(const void* p_src2dDesc)
{
if constexpr(need_padding)
return (*reinterpret_cast<const refType_src2dDesc_padded_12*>(p_src2dDesc));
else
return (*reinterpret_cast<const refType_src2dDesc*>(p_src2dDesc));
};
template <bool need_padding>
static __device__ auto get_reduction_dst1d_descriptor(const void* p_dst1dDesc)
{
if constexpr(need_padding)
return (*reinterpret_cast<const refType_dst1dDesc_padded*>(p_dst1dDesc));
else
return (*reinterpret_cast<const refType_dst1dDesc*>(p_dst1dDesc));
};
extern "C" __global__ void gridwise_generic_reduce_1(int origReduceLen,
int BlkGroupSize,
float alpha,
const void* __restrict__ p_src_global,
float beta,
void* __restrict__ p_dst_global,
const void CONSTANT* ws_global,
long ws_buf2_bytes_offset,
void* __restrict__ indices_global)
{
(void)BlkGroupSize;
(void)ws_buf2_bytes_offset;
const void* p_src2dDesc = cast_pointer_to_generic_address_space(ws_global);
const void* p_dst1dDesc = static_cast<const char*>(p_src2dDesc) + 2048;
const auto src2dDesc = get_reduction_src2d_descriptor<src2d_need_padding>(p_src2dDesc);
const auto dst1dDesc = get_reduction_dst1d_descriptor<dst1d_need_padding>(p_dst1dDesc);
using gridwise_2d_reduce = GridwiseReduction_xy_to_x_direct_threadwise<BlockSize,
srcDataType,
dstDataType,
compType,
decltype(src2dDesc),
decltype(dst1dDesc),
op,
nanPropaOpt,
reduceIndicesOpt,
true,
true,
GredThreadBufferLength>;
constexpr int RunId = need_indices ? 2 : 1;
gridwise_2d_reduce::template Run<RunId>(
src2dDesc,
dst1dDesc,
origReduceLen,
alpha,
static_cast<const srcDataType* const __restrict__>(p_src_global),
beta,
static_cast<dstDataType* const __restrict__>(p_dst_global),
static_cast<const int* const __restrict__>(nullptr),
static_cast<int* const __restrict__>(indices_global));
};
/*******************************************************************************
*
* MIT License
*
* Copyright (c) 2021 Advanced Micro Devices, Inc.
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
* in the Software without restriction, including without limitation the rights
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
* copies of the Software, and to permit persons to whom the Software is
* furnished to do so, subject to the following conditions:
*
* The above copyright notice and this permission notice shall be included in all
* copies or substantial portions of the Software.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
* SOFTWARE.
*
*******************************************************************************/
#include "config.hpp"
#include "number.hpp"
#include "sequence.hpp"
#include "tensor_descriptor_helper.hpp"
#include "data_type_enum_helper.hpp"
#include "reduction_common.hpp"
#include "gridwise_generic_2d_reduction_direct_warpwise.hpp"
using namespace ck;
using srcDataType =
typename get_datatype_from_enum<static_cast<DataTypeEnum_t>(CK_PARAM_SRC_DATATYPE)>::type;
using dstDataType =
typename get_datatype_from_enum<static_cast<DataTypeEnum_t>(CK_PARAM_DST_DATATYPE)>::type;
using compType =
typename get_datatype_from_enum<static_cast<DataTypeEnum_t>(CK_PARAM_REDUCE_COMPTYPE)>::type;
constexpr index_t BlockSize = CK_PARAM_BLOCKSIZE; // tunable
constexpr index_t srcDims = CK_PARAM_IN_DIMS;
constexpr ReduceTensorOp_t op = static_cast<ReduceTensorOp_t>(CK_PARAM_REDUCE_OP);
constexpr NanPropagation_t nanPropaOpt = CK_PARAM_NAN_PROPAGATE == 0
? NanPropagation_t::NOT_PROPAGATE_NAN
: NanPropagation_t::PROPAGATE_NAN;
constexpr ReduceTensorIndices_t reduceIndicesOpt = CK_PARAM_REDUCE_INDICES == 0
? ReduceTensorIndices_t::NO_INDICES
: ReduceTensorIndices_t::FLATTENED_INDICES;
constexpr bool src2d_need_padding = static_cast<bool>(CK_PARAM_SRC2D_PADDING);
constexpr bool dst1d_need_padding = static_cast<bool>(CK_PARAM_DST1D_PADDING);
constexpr bool indexable = reduce_binary_operator<compType, op>::indexable;
constexpr bool need_indices = indexable && (reduceIndicesOpt != ReduceTensorIndices_t::NO_INDICES);
constexpr index_t GredAccessesPerThreadInWarp = CK_PARAM_ACCESSES_PER_THREAD_INWARP; // tunable
// helper functions using variadic template arguments
template <index_t... Ns>
__device__ static auto make_tuple_from_array_and_index_seq(const int* lengths, Sequence<Ns...>)
{
return make_tuple(static_cast<index_t>(lengths[Ns])...);
};
template <index_t arraySize>
__device__ static auto make_tuple_from_array(const int* lengths, Number<arraySize>)
{
static_assert(arraySize >= 1 && arraySize <= 6, "The tensor should have 1 to 6 dimensions");
constexpr auto index_seq = typename arithmetic_sequence_gen<0, arraySize, 1>::type{};
return make_tuple_from_array_and_index_seq(lengths, index_seq);
};
template <index_t... Ns>
__device__ static constexpr auto make_tuple_from_seq(Sequence<Ns...>)
{
return make_tuple(Ns...);
};
extern "C" __global__ void gridwise_generic_reduce_1_prepare(int GridSize,
int BlkGroupSize,
int inLength0,
int inLength1,
int inLength2,
int inLength3,
int inLength4,
int inLength5,
int inStride0,
int inStride1,
int inStride2,
int inStride3,
int inStride4,
int inStride5,
void* __restrict__ ws_global)
{
(void)BlkGroupSize;
void* p_src2dDesc = ws_global;
void* p_dst1dDesc = static_cast<char*>(ws_global) + 2048;
const int srcLengths[6] = {inLength0, inLength1, inLength2, inLength3, inLength4, inLength5};
const int srcStrides[6] = {inStride0, inStride1, inStride2, inStride3, inStride4, inStride5};
const auto tupleSrcLengths = make_tuple_from_array(srcLengths, Number<srcDims>{});
const auto tupleSrcStrides = make_tuple_from_array(srcStrides, Number<srcDims>{});
const auto tupleDstLengths = make_tuple(1);
const auto tupleDstStrides = make_tuple(1);
const auto srcDesc = make_naive_tensor_descriptor(tupleSrcLengths, tupleSrcStrides);
auto dstDesc = make_naive_tensor_descriptor(tupleDstLengths, tupleDstStrides);
const auto one_dim_srcDesc = transform_tensor_descriptor(
srcDesc,
make_tuple(make_merge_transform(tupleSrcLengths)),
make_tuple(typename arithmetic_sequence_gen<0, srcDims, 1>::type{}),
make_tuple(Sequence<0>{}));
auto src2dDesc = transform_tensor_descriptor(
one_dim_srcDesc,
make_tuple(make_unmerge_transform(make_tuple(1, one_dim_srcDesc.GetLength(Number<0>{})))),
make_tuple(Sequence<0>{}),
make_tuple(Sequence<0, 1>{}));
constexpr int invariantLen = 1;
const auto toReduceLen = src2dDesc.GetLength(Number<1>{});
constexpr auto copySliceLen = warpSize * GredAccessesPerThreadInWarp;
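// For the DirectWarpWise method, padding (when enabled) extends dim0 up to the
// number of warps in the grid (GridSize * BlockSize / warpSize, one output element
// per warp) and dim1 up to a multiple of warpSize * GredAccessesPerThreadInWarp.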
if constexpr(src2d_need_padding)
{
const auto srcPad1 = GridSize * BlockSize / warpSize - invariantLen;
const auto srcPad2 =
((toReduceLen + copySliceLen - 1) / copySliceLen) * copySliceLen - toReduceLen;
auto src2dDesc_2 =
transform_tensor_descriptor(src2dDesc,
make_tuple(make_pad_transform(invariantLen, 0, srcPad1),
make_pad_transform(toReduceLen, 0, srcPad2)),
make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}));
if(get_thread_local_1d_id() == 0)
*static_cast<decltype(src2dDesc_2)*>(p_src2dDesc) = src2dDesc_2;
}
else
{
if(get_thread_local_1d_id() == 0)
*static_cast<decltype(src2dDesc)*>(p_src2dDesc) = src2dDesc;
}
if constexpr(dst1d_need_padding)
{
const auto dstPad = GridSize * BlockSize / warpSize - invariantLen;
auto dst1dDesc_2 =
transform_tensor_descriptor(dstDesc,
make_tuple(make_pad_transform(invariantLen, 0, dstPad)),
make_tuple(Sequence<0>{}),
make_tuple(Sequence<0>{}));
if(get_thread_local_1d_id() == 0)
*static_cast<decltype(dst1dDesc_2)*>(p_dst1dDesc) = dst1dDesc_2;
}
else
{
if(get_thread_local_1d_id() == 0)
*static_cast<decltype(dstDesc)*>(p_dst1dDesc) = dstDesc;
}
};
template <index_t srcDims>
struct get_ref_desc_types
{
static constexpr auto ref_srcLengths = typename uniform_sequence_gen<srcDims, 8>::type{};
// accurate strides are not needed to obtain the expected reference type
static constexpr auto ref_srcDesc = make_naive_tensor_descriptor(
make_tuple_from_seq(ref_srcLengths), make_tuple_from_seq(ref_srcLengths));
static constexpr auto ref_dstDesc = make_naive_tensor_descriptor(make_tuple(1), make_tuple(1));
static constexpr auto ref_one_dim_srcDesc = transform_tensor_descriptor(
ref_srcDesc,
make_tuple(make_merge_transform(make_tuple_from_seq(ref_srcLengths))),
make_tuple(typename arithmetic_sequence_gen<0, srcDims, 1>::type{}),
make_tuple(Sequence<0>{}));
static constexpr auto ref_src2dDesc =
transform_tensor_descriptor(ref_one_dim_srcDesc,
make_tuple(make_unmerge_transform(
make_tuple(1, ref_one_dim_srcDesc.GetLength(Number<0>{})))),
make_tuple(Sequence<0>{}),
make_tuple(Sequence<0, 1>{}));
static constexpr auto ref_invariantLen = ref_src2dDesc.GetLength(Number<0>{});
static constexpr auto ref_toReduceLen = ref_src2dDesc.GetLength(Number<1>{});
// used by the DirectThreadWise and DirectWarpWise method
using refType_src2dDesc_padded_12 =
decltype(transform_tensor_descriptor(ref_src2dDesc,
make_tuple(make_pad_transform(ref_invariantLen, 0, 2),
make_pad_transform(ref_toReduceLen, 0, 2)),
make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<0>{}, Sequence<1>{})));
using refType_dst1dDesc_padded =
decltype(transform_tensor_descriptor(ref_dstDesc,
make_tuple(make_pad_transform(ref_invariantLen, 0, 2)),
make_tuple(Sequence<0>{}),
make_tuple(Sequence<0>{})));
using refType_src2dDesc = decltype(ref_src2dDesc);
using refType_dst1dDesc = decltype(ref_dstDesc);
};
using refType_src2dDesc = typename get_ref_desc_types<srcDims>::refType_src2dDesc;
using refType_dst1dDesc = typename get_ref_desc_types<srcDims>::refType_dst1dDesc;
using refType_src2dDesc_padded_12 =
typename get_ref_desc_types<srcDims>::refType_src2dDesc_padded_12;
using refType_dst1dDesc_padded = typename get_ref_desc_types<srcDims>::refType_dst1dDesc_padded;
template <bool need_padding>
static __device__ auto get_reduction_src2d_descriptor(const void* p_src2dDesc)
{
if constexpr(need_padding)
return (*reinterpret_cast<const refType_src2dDesc_padded_12*>(p_src2dDesc));
else
return (*reinterpret_cast<const refType_src2dDesc*>(p_src2dDesc));
};
template <bool need_padding>
static __device__ auto get_reduction_dst1d_descriptor(const void* p_dst1dDesc)
{
if constexpr(need_padding)
return (*reinterpret_cast<const refType_dst1dDesc_padded*>(p_dst1dDesc));
else
return (*reinterpret_cast<const refType_dst1dDesc*>(p_dst1dDesc));
};
extern "C" __global__ void gridwise_generic_reduce_1(int origReduceLen,
int BlkGroupSize,
float alpha,
const void* __restrict__ p_src_global,
float beta,
void* __restrict__ p_dst_global,
const void CONSTANT* ws_global,
long ws_buf2_bytes_offset,
void* __restrict__ indices_global)
{
(void)BlkGroupSize;
(void)ws_buf2_bytes_offset;
const void* p_src2dDesc = cast_pointer_to_generic_address_space(ws_global);
const void* p_dst1dDesc = static_cast<const char*>(p_src2dDesc) + 2048;
const auto src2dDesc = get_reduction_src2d_descriptor<src2d_need_padding>(p_src2dDesc);
const auto dst1dDesc = get_reduction_dst1d_descriptor<dst1d_need_padding>(p_dst1dDesc);
using gridwise_2d_reduce =
GridwiseReduction_xy_to_x_direct_warpwise<BlockSize,
srcDataType,
dstDataType,
compType,
decltype(src2dDesc),
decltype(dst1dDesc),
op,
nanPropaOpt,
reduceIndicesOpt,
true,
true,
GredAccessesPerThreadInWarp>;
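    // RunId selects the Run<> code path inside the gridwise reduction: 1 appears to be the plain
    // value-only reduction, while 2 also produces the flattened indices requested via
    // reduceIndicesOpt (the second-pass kernels use 3 to consume previously generated indices).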
constexpr int RunId = need_indices ? 2 : 1;
gridwise_2d_reduce::template Run<RunId>(
src2dDesc,
dst1dDesc,
origReduceLen,
alpha,
static_cast<const srcDataType* const __restrict__>(p_src_global),
beta,
static_cast<dstDataType* const __restrict__>(p_dst_global),
static_cast<const int* const __restrict__>(nullptr),
static_cast<int* const __restrict__>(indices_global));
};
/*******************************************************************************
*
* MIT License
*
* Copyright (c) 2021 Advanced Micro Devices, Inc.
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
* in the Software without restriction, including without limitation the rights
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
* copies of the Software, and to permit persons to whom the Software is
* furnished to do so, subject to the following conditions:
*
* The above copyright notice and this permission notice shall be included in all
* copies or substantial portions of the Software.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
* SOFTWARE.
*
*******************************************************************************/
#include "config.hpp"
#include "number.hpp"
#include "sequence.hpp"
#include "tensor_descriptor_helper.hpp"
#include "data_type_enum_helper.hpp"
#include "reduction_common.hpp"
#include "gridwise_generic_2d_reduction_direct_warpwise.hpp"
using namespace ck;
using srcDataType =
typename get_datatype_from_enum<static_cast<DataTypeEnum_t>(CK_PARAM_SRC_DATATYPE)>::type;
using dstDataType =
typename get_datatype_from_enum<static_cast<DataTypeEnum_t>(CK_PARAM_DST_DATATYPE)>::type;
using compType =
typename get_datatype_from_enum<static_cast<DataTypeEnum_t>(CK_PARAM_REDUCE_COMPTYPE)>::type;
constexpr index_t BlockSize = CK_PARAM_BLOCKSIZE; // tunable
constexpr index_t srcDims = CK_PARAM_IN_DIMS;
constexpr index_t dstDims = CK_PARAM_OUT_DIMS;
constexpr index_t num_toReduceDims = CK_PARAM_NUM_TOREDUCE_DIMS;
constexpr index_t num_invariantDims = srcDims - num_toReduceDims;
using invariantDims = typename arithmetic_sequence_gen<0, num_invariantDims, 1>::type;
using toReduceDims = typename arithmetic_sequence_gen<num_invariantDims, srcDims, 1>::type;
constexpr ReduceTensorOp_t op = static_cast<ReduceTensorOp_t>(CK_PARAM_REDUCE_OP);
constexpr NanPropagation_t nanPropaOpt = CK_PARAM_NAN_PROPAGATE == 0
? NanPropagation_t::NOT_PROPAGATE_NAN
: NanPropagation_t::PROPAGATE_NAN;
constexpr ReduceTensorIndices_t reduceIndicesOpt = CK_PARAM_REDUCE_INDICES == 0
? ReduceTensorIndices_t::NO_INDICES
: ReduceTensorIndices_t::FLATTENED_INDICES;
constexpr bool src2d_need_padding = static_cast<bool>(CK_PARAM_SRC2D_PADDING);
constexpr bool dst1d_need_padding = static_cast<bool>(CK_PARAM_DST1D_PADDING);
static_assert(num_invariantDims > 0,
              "This kernel requires at least one invariant (non-reduced) dimension!");
constexpr bool indexable = reduce_binary_operator<compType, op>::indexable;
constexpr bool need_indices = indexable && (reduceIndicesOpt != ReduceTensorIndices_t::NO_INDICES);
constexpr index_t GredAccessesPerThreadInWarp = CK_PARAM_ACCESSES_PER_THREAD_INWARP; // tunable
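// The CK_PARAM_* macros are expected to be supplied by the (presumably online) compilation step,
// e.g. with hypothetical values such as:
//   -DCK_PARAM_BLOCKSIZE=256 -DCK_PARAM_IN_DIMS=4 -DCK_PARAM_OUT_DIMS=4
//   -DCK_PARAM_NUM_TOREDUCE_DIMS=2 -DCK_PARAM_REDUCE_OP=0 -DCK_PARAM_NAN_PROPAGATE=0
//   -DCK_PARAM_REDUCE_INDICES=0 -DCK_PARAM_SRC2D_PADDING=1 -DCK_PARAM_DST1D_PADDING=0 ...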
// helper functions using variadic template arguments
template <index_t... Ns>
__device__ static auto make_tuple_from_array_and_index_seq(const int* lengths, Sequence<Ns...>)
{
return make_tuple(static_cast<index_t>(lengths[Ns])...);
};
template <index_t arraySize>
__device__ static auto make_tuple_from_array(const int* lengths, Number<arraySize>)
{
static_assert(arraySize >= 1 && arraySize <= 6, "The tensor should have 1 to 6 dimensions");
constexpr auto index_seq = typename arithmetic_sequence_gen<0, arraySize, 1>::type{};
return make_tuple_from_array_and_index_seq(lengths, index_seq);
};
template <index_t... Ns>
__device__ static constexpr auto make_tuple_from_seq(Sequence<Ns...>)
{
return make_tuple(Ns...);
};
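// The prepare kernel rebuilds the stride-dependent tensor descriptors on the device and
// serializes them into the workspace; only the first thread of each workgroup performs the store.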
extern "C" __global__ void gridwise_generic_reduce_1_prepare(int GridSize,
int BlkGroupSize,
int inLength0,
int inLength1,
int inLength2,
int inLength3,
int inLength4,
int inLength5,
int inStride0,
int inStride1,
int inStride2,
int inStride3,
int inStride4,
int inStride5,
int outStride0,
int outStride1,
int outStride2,
int outStride3,
int outStride4,
int outStride5,
void* __restrict__ ws_global)
{
(void)BlkGroupSize;
void* p_src2dDesc = ws_global;
void* p_dst1dDesc = static_cast<char*>(ws_global) + 2048;
const int srcLengths[6] = {inLength0, inLength1, inLength2, inLength3, inLength4, inLength5};
const int srcStrides[6] = {inStride0, inStride1, inStride2, inStride3, inStride4, inStride5};
const int dstStrides[6] = {
outStride0, outStride1, outStride2, outStride3, outStride4, outStride5};
const auto tupleSrcLengths = make_tuple_from_array(srcLengths, Number<srcDims>{});
const auto tupleSrcStrides = make_tuple_from_array(srcStrides, Number<srcDims>{});
const auto tupleDstLengths = make_tuple_from_array(srcLengths, Number<dstDims>{});
const auto tupleDstStrides = make_tuple_from_array(dstStrides, Number<dstDims>{});
const auto srcDesc = make_naive_tensor_descriptor(tupleSrcLengths, tupleSrcStrides);
const auto dstDesc = make_naive_tensor_descriptor(tupleDstLengths, tupleDstStrides);
const auto toReduceDimLengths = make_tuple_from_array_and_index_seq(srcLengths, toReduceDims{});
const auto invariantDimLengths =
make_tuple_from_array_and_index_seq(srcLengths, invariantDims{});
auto src2dDesc =
transform_tensor_descriptor(srcDesc,
make_tuple(make_merge_transform(invariantDimLengths),
make_merge_transform(toReduceDimLengths)),
make_tuple(invariantDims{}, toReduceDims{}),
make_tuple(Sequence<0>{}, Sequence<1>{}));
auto dst1dDesc = transform_tensor_descriptor(
dstDesc,
make_tuple(make_merge_transform(tupleDstLengths)),
make_tuple(typename arithmetic_sequence_gen<0, dstDims, 1>::type{}),
make_tuple(Sequence<0>{}));
const auto invariantLen = src2dDesc.GetLength(Number<0>{});
const auto toReduceLen = src2dDesc.GetLength(Number<1>{});
constexpr auto copySliceLen = warpSize * GredAccessesPerThreadInWarp;
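    // Pad the invariant length up to the number of warps in the grid (GridSize * BlockSize /
    // warpSize) and the reduce length up to a multiple of copySliceLen, so that every warp works
    // on complete slices.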
if constexpr(src2d_need_padding)
{
const auto srcPad1 = GridSize * BlockSize / warpSize - invariantLen;
const auto srcPad2 =
((toReduceLen + copySliceLen - 1) / copySliceLen) * copySliceLen - toReduceLen;
auto src2dDesc_2 =
transform_tensor_descriptor(src2dDesc,
make_tuple(make_pad_transform(invariantLen, 0, srcPad1),
make_pad_transform(toReduceLen, 0, srcPad2)),
make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}));
if(get_thread_local_1d_id() == 0)
*static_cast<decltype(src2dDesc_2)*>(p_src2dDesc) = src2dDesc_2;
}
else
{
if(get_thread_local_1d_id() == 0)
*static_cast<decltype(src2dDesc)*>(p_src2dDesc) = src2dDesc;
}
if constexpr(dst1d_need_padding)
{
const auto dstPad = GridSize * BlockSize / warpSize - invariantLen;
auto dst1dDesc_2 =
transform_tensor_descriptor(dst1dDesc,
make_tuple(make_pad_transform(invariantLen, 0, dstPad)),
make_tuple(Sequence<0>{}),
make_tuple(Sequence<0>{}));
if(get_thread_local_1d_id() == 0)
*static_cast<decltype(dst1dDesc_2)*>(p_dst1dDesc) = dst1dDesc_2;
}
else
{
if(get_thread_local_1d_id() == 0)
*static_cast<decltype(dst1dDesc)*>(p_dst1dDesc) = dst1dDesc;
}
};
template <index_t srcDims, index_t dstDims, typename invariantDims, typename toReduceDims>
struct get_ref_desc_types
{
static constexpr auto ref_toReduceDimLengths =
typename uniform_sequence_gen<toReduceDims::Size(), 8>::type{};
static constexpr auto ref_invariantDimLengths =
typename uniform_sequence_gen<invariantDims::Size(), 8>::type{};
static constexpr auto ref_srcLengths = typename uniform_sequence_gen<srcDims, 8>::type{};
static constexpr auto ref_dstLengths = typename uniform_sequence_gen<dstDims, 8>::type{};
    // accurate strides are not needed here; any values yield the expected reference descriptor type
static constexpr auto ref_srcDesc = make_naive_tensor_descriptor(
make_tuple_from_seq(ref_srcLengths), make_tuple_from_seq(ref_srcLengths));
static constexpr auto ref_dstDesc = make_naive_tensor_descriptor(
make_tuple_from_seq(ref_dstLengths), make_tuple_from_seq(ref_dstLengths));
static constexpr auto ref_src2dDesc = transform_tensor_descriptor(
ref_srcDesc,
make_tuple(make_merge_transform(make_tuple_from_seq(ref_invariantDimLengths)),
make_merge_transform(make_tuple_from_seq(ref_toReduceDimLengths))),
make_tuple(invariantDims{}, toReduceDims{}),
make_tuple(Sequence<0>{}, Sequence<1>{}));
static constexpr auto ref_dst1dDesc = transform_tensor_descriptor(
ref_dstDesc,
make_tuple(make_merge_transform(make_tuple_from_seq(ref_dstLengths))),
make_tuple(typename arithmetic_sequence_gen<0, dstDims, 1>::type{}),
make_tuple(Sequence<0>{}));
static constexpr auto ref_invariantLen = ref_src2dDesc.GetLength(Number<0>{});
static constexpr auto ref_toReduceLen = ref_src2dDesc.GetLength(Number<1>{});
// used by the DirectThreadWise and DirectWarpWise method
using refType_src2dDesc_padded_12 =
decltype(transform_tensor_descriptor(ref_src2dDesc,
make_tuple(make_pad_transform(ref_invariantLen, 0, 2),
make_pad_transform(ref_toReduceLen, 0, 2)),
make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<0>{}, Sequence<1>{})));
using refType_dst1dDesc_padded =
decltype(transform_tensor_descriptor(ref_dst1dDesc,
make_tuple(make_pad_transform(ref_invariantLen, 0, 2)),
make_tuple(Sequence<0>{}),
make_tuple(Sequence<0>{})));
using refType_src2dDesc = decltype(ref_src2dDesc);
using refType_dst1dDesc = decltype(ref_dst1dDesc);
};
using refType_src2dDesc =
typename get_ref_desc_types<srcDims, dstDims, invariantDims, toReduceDims>::refType_src2dDesc;
using refType_dst1dDesc =
typename get_ref_desc_types<srcDims, dstDims, invariantDims, toReduceDims>::refType_dst1dDesc;
using refType_src2dDesc_padded_12 =
typename get_ref_desc_types<srcDims, dstDims, invariantDims, toReduceDims>::
refType_src2dDesc_padded_12;
using refType_dst1dDesc_padded =
typename get_ref_desc_types<srcDims, dstDims, invariantDims, toReduceDims>::
refType_dst1dDesc_padded;
template <bool need_padding>
static __device__ auto get_reduction_src2d_descriptor(const void* p_src2dDesc)
{
if constexpr(need_padding)
return (*reinterpret_cast<const refType_src2dDesc_padded_12*>(p_src2dDesc));
else
return (*reinterpret_cast<const refType_src2dDesc*>(p_src2dDesc));
};
template <bool need_padding>
static __device__ auto get_reduction_dst1d_descriptor(const void* p_dst1dDesc)
{
if constexpr(need_padding)
return (*reinterpret_cast<const refType_dst1dDesc_padded*>(p_dst1dDesc));
else
return (*reinterpret_cast<const refType_dst1dDesc*>(p_dst1dDesc));
};
extern "C" __global__ void gridwise_generic_reduce_1(int origReduceLen,
int BlkGroupSize,
float alpha,
const void* __restrict__ p_src_global,
float beta,
void* __restrict__ p_dst_global,
const void CONSTANT* ws_global,
long ws_buf2_bytes_offset,
void* __restrict__ indices_global)
{
(void)BlkGroupSize;
(void)ws_buf2_bytes_offset;
const void* p_src2dDesc = cast_pointer_to_generic_address_space(ws_global);
const void* p_dst1dDesc = static_cast<const char*>(p_src2dDesc) + 2048;
const auto src2dDesc = get_reduction_src2d_descriptor<src2d_need_padding>(p_src2dDesc);
const auto dst1dDesc = get_reduction_dst1d_descriptor<dst1d_need_padding>(p_dst1dDesc);
using gridwise_2d_reduce =
GridwiseReduction_xy_to_x_direct_warpwise<BlockSize,
srcDataType,
dstDataType,
compType,
decltype(src2dDesc),
decltype(dst1dDesc),
op,
nanPropaOpt,
reduceIndicesOpt,
true,
true,
GredAccessesPerThreadInWarp>;
constexpr int RunId = need_indices ? 2 : 1;
gridwise_2d_reduce::template Run<RunId>(
src2dDesc,
dst1dDesc,
origReduceLen,
alpha,
static_cast<const srcDataType* const __restrict__>(p_src_global),
beta,
static_cast<dstDataType* const __restrict__>(p_dst_global),
static_cast<const int* const __restrict__>(nullptr),
static_cast<int* const __restrict__>(indices_global));
};
/*******************************************************************************
*
* MIT License
*
* Copyright (c) 2021 Advanced Micro Devices, Inc.
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
* in the Software without restriction, including without limitation the rights
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
* copies of the Software, and to permit persons to whom the Software is
* furnished to do so, subject to the following conditions:
*
* The above copyright notice and this permission notice shall be included in all
* copies or substantial portions of the Software.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
* SOFTWARE.
*
*******************************************************************************/
#include "config.hpp"
#include "number.hpp"
#include "sequence.hpp"
#include "tensor_descriptor_helper.hpp"
#include "data_type_enum_helper.hpp"
#include "reduction_common.hpp"
#include "gridwise_generic_2d_reduction_blockwise.hpp"
using namespace ck;
using srcDataType =
typename get_datatype_from_enum<static_cast<DataTypeEnum_t>(CK_PARAM_SRC_DATATYPE)>::type;
using dstDataType =
typename get_datatype_from_enum<static_cast<DataTypeEnum_t>(CK_PARAM_DST_DATATYPE)>::type;
using compType =
typename get_datatype_from_enum<static_cast<DataTypeEnum_t>(CK_PARAM_REDUCE_COMPTYPE)>::type;
constexpr index_t BlockSize = CK_PARAM_BLOCKSIZE; // tunable
constexpr ReduceTensorOp_t op = static_cast<ReduceTensorOp_t>(CK_PARAM_REDUCE_OP);
constexpr NanPropagation_t nanPropaOpt = CK_PARAM_NAN_PROPAGATE == 0
? NanPropagation_t::NOT_PROPAGATE_NAN
: NanPropagation_t::PROPAGATE_NAN;
constexpr ReduceTensorIndices_t reduceIndicesOpt = CK_PARAM_REDUCE_INDICES == 0
? ReduceTensorIndices_t::NO_INDICES
: ReduceTensorIndices_t::FLATTENED_INDICES;
constexpr bool src2d_need_padding = static_cast<bool>(CK_PARAM_SRC2D_PADDING);
constexpr bool dst1d_need_padding = static_cast<bool>(CK_PARAM_DST1D_PADDING);
constexpr bool indexable = reduce_binary_operator<compType, op>::indexable;
constexpr bool need_indices = indexable && (reduceIndicesOpt != ReduceTensorIndices_t::NO_INDICES);
constexpr index_t GredAccessesPerThreadInBlock = CK_PARAM_ACCESSES_PER_THREAD_INBLOCK; // tunable
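// Block-wise second pass: each workgroup apparently reduces one output value, consuming
// BlockSize * GredAccessesPerThreadInBlock source elements per iteration (copySliceLen below).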
extern "C" __global__ void
gridwise_generic_reduce_2_prepare(int GridSize, int BlkGroupSize, void* __restrict__ ws_global)
{
(void)GridSize;
void* p_src2dDesc = ws_global;
void* p_dst1dDesc = static_cast<char*>(ws_global) + 2048;
const auto tupleDstLengths = make_tuple(1);
const auto tupleDstStrides = make_tuple(1);
auto dstDesc = make_naive_tensor_descriptor(tupleDstLengths, tupleDstStrides);
const index_t invariantLen = dstDesc.GetLength(Number<0>{});
const index_t toReduceLen = BlkGroupSize;
auto src2dDesc = make_naive_tensor_descriptor_packed(make_tuple(invariantLen, toReduceLen));
constexpr auto copySliceLen = BlockSize * GredAccessesPerThreadInBlock;
if constexpr(src2d_need_padding)
{
const auto srcPad =
((toReduceLen + copySliceLen - 1) / copySliceLen) * copySliceLen - toReduceLen;
auto src2dDesc_2 =
transform_tensor_descriptor(src2dDesc,
make_tuple(make_pass_through_transform(invariantLen),
make_pad_transform(toReduceLen, 0, srcPad)),
make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}));
if(get_thread_local_1d_id() == 0)
*static_cast<decltype(src2dDesc_2)*>(p_src2dDesc) = src2dDesc_2;
}
else
{
if(get_thread_local_1d_id() == 0)
*static_cast<decltype(src2dDesc)*>(p_src2dDesc) = src2dDesc;
}
if(get_thread_local_1d_id() == 0)
*static_cast<decltype(dstDesc)*>(p_dst1dDesc) = dstDesc;
};
struct get_ref_desc_types
{
static constexpr auto ref_tupleDstLengths = make_tuple(8);
static constexpr auto ref_dstDesc =
make_naive_tensor_descriptor(ref_tupleDstLengths, ref_tupleDstLengths);
static constexpr index_t ref_invariantLen = ref_dstDesc.GetLength(Number<0>{});
static constexpr index_t ref_toReduceLen = 8;
static constexpr auto ref_src2dDesc =
make_naive_tensor_descriptor_packed(make_tuple(ref_invariantLen, ref_toReduceLen));
using refType_src2dDesc = decltype(ref_src2dDesc);
using refType_dst1dDesc = decltype(ref_dstDesc);
// used by the BlockWise and MultiBlock method
using refType_src2dDesc_padded_34 = decltype(
transform_tensor_descriptor(ref_src2dDesc,
make_tuple(make_pass_through_transform(ref_invariantLen),
make_pad_transform(ref_toReduceLen, 0, 2)),
make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<0>{}, Sequence<1>{})));
using refType_dst1dDesc_padded =
decltype(transform_tensor_descriptor(ref_dstDesc,
make_tuple(make_pad_transform(ref_invariantLen, 0, 2)),
make_tuple(Sequence<0>{}),
make_tuple(Sequence<0>{})));
};
using refType_src2dDesc = typename get_ref_desc_types::refType_src2dDesc;
using refType_dst1dDesc = typename get_ref_desc_types::refType_dst1dDesc;
using refType_src2dDesc_padded_34 = typename get_ref_desc_types::refType_src2dDesc_padded_34;
using refType_dst1dDesc_padded = typename get_ref_desc_types::refType_dst1dDesc_padded;
template <bool need_padding>
static __device__ auto get_reduction_src2d_descriptor(const void* p_src2dDesc)
{
if constexpr(need_padding)
return (*reinterpret_cast<const refType_src2dDesc_padded_34*>(p_src2dDesc));
else
return (*reinterpret_cast<const refType_src2dDesc*>(p_src2dDesc));
};
template <bool need_padding>
static __device__ auto get_reduction_dst1d_descriptor(const void* p_dst1dDesc)
{
if constexpr(need_padding)
return (*reinterpret_cast<const refType_dst1dDesc_padded*>(p_dst1dDesc));
else
return (*reinterpret_cast<const refType_dst1dDesc*>(p_dst1dDesc));
};
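// Second-pass kernel: p_src_global is ignored; the source values (presumably the partial results
// written by the first, multi-block pass) are read from the workspace at byte offset 4096, and
// when ws_buf2_bytes_offset > 0 the matching partial indices are read from
// ws_buf1_global + ws_buf2_bytes_offset.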
extern "C" __global__ void gridwise_generic_reduce_2(int origReduceLen,
float alpha,
const void* __restrict__ p_src_global,
float beta,
void* __restrict__ p_dst_global,
const void CONSTANT* ws_global,
long ws_buf2_bytes_offset,
void* __restrict__ indices_global)
{
(void)p_src_global;
const void* p_src2dDesc = cast_pointer_to_generic_address_space(ws_global);
const void* p_dst1dDesc = static_cast<const char*>(p_src2dDesc) + 2048;
void* ws_buf1_global = const_cast<char*>(static_cast<const char*>(p_src2dDesc) + 4096);
const auto src2dDesc = get_reduction_src2d_descriptor<src2d_need_padding>(p_src2dDesc);
const auto dst1dDesc = get_reduction_dst1d_descriptor<dst1d_need_padding>(p_dst1dDesc);
using gridwise_2d_reduce = GridwiseReduction_xy_to_x_blockwise<BlockSize,
srcDataType,
dstDataType,
compType,
decltype(src2dDesc),
decltype(dst1dDesc),
op,
nanPropaOpt,
reduceIndicesOpt,
false,
true,
GredAccessesPerThreadInBlock>;
void* const ws_buf2_global =
ws_buf2_bytes_offset > 0
? static_cast<void*>(static_cast<char*>(ws_buf1_global) + ws_buf2_bytes_offset)
: nullptr;
constexpr int RunId = need_indices ? 3 : 1;
gridwise_2d_reduce::template Run<RunId>(
src2dDesc,
dst1dDesc,
origReduceLen,
alpha,
static_cast<const srcDataType* const __restrict__>(ws_buf1_global),
beta,
static_cast<dstDataType* const __restrict__>(p_dst_global),
static_cast<const int* const __restrict__>(ws_buf2_global),
static_cast<int* const __restrict__>(indices_global));
};
/*******************************************************************************
*
* MIT License
*
* Copyright (c) 2021 Advanced Micro Devices, Inc.
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
* in the Software without restriction, including without limitation the rights
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
* copies of the Software, and to permit persons to whom the Software is
* furnished to do so, subject to the following conditions:
*
* The above copyright notice and this permission notice shall be included in all
* copies or substantial portions of the Software.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
* SOFTWARE.
*
*******************************************************************************/
#include "config.hpp"
#include "number.hpp"
#include "sequence.hpp"
#include "tensor_descriptor_helper.hpp"
#include "data_type_enum_helper.hpp"
#include "reduction_common.hpp"
#include "gridwise_generic_2d_reduction_blockwise.hpp"
using namespace ck;
using srcDataType =
typename get_datatype_from_enum<static_cast<DataTypeEnum_t>(CK_PARAM_SRC_DATATYPE)>::type;
using dstDataType =
typename get_datatype_from_enum<static_cast<DataTypeEnum_t>(CK_PARAM_DST_DATATYPE)>::type;
using compType =
typename get_datatype_from_enum<static_cast<DataTypeEnum_t>(CK_PARAM_REDUCE_COMPTYPE)>::type;
constexpr index_t BlockSize = CK_PARAM_BLOCKSIZE; // tunable
constexpr index_t dstDims = CK_PARAM_OUT_DIMS;
constexpr ReduceTensorOp_t op = static_cast<ReduceTensorOp_t>(CK_PARAM_REDUCE_OP);
constexpr NanPropagation_t nanPropaOpt = CK_PARAM_NAN_PROPAGATE == 0
? NanPropagation_t::NOT_PROPAGATE_NAN
: NanPropagation_t::PROPAGATE_NAN;
constexpr ReduceTensorIndices_t reduceIndicesOpt = CK_PARAM_REDUCE_INDICES == 0
? ReduceTensorIndices_t::NO_INDICES
: ReduceTensorIndices_t::FLATTENED_INDICES;
constexpr bool src2d_need_padding = static_cast<bool>(CK_PARAM_SRC2D_PADDING);
constexpr bool dst1d_need_padding = static_cast<bool>(CK_PARAM_DST1D_PADDING);
constexpr bool indexable = reduce_binary_operator<compType, op>::indexable;
constexpr bool need_indices = indexable && (reduceIndicesOpt != ReduceTensorIndices_t::NO_INDICES);
constexpr index_t GredAccessesPerThreadInBlock = CK_PARAM_ACCESSES_PER_THREAD_INBLOCK; // tunable
// helper functions using variadic template arguments
template <index_t... Ns>
__device__ static auto make_tuple_from_array_and_index_seq(const int* lengths, Sequence<Ns...>)
{
return make_tuple(static_cast<index_t>(lengths[Ns])...);
};
template <index_t arraySize>
__device__ static auto make_tuple_from_array(const int* lengths, Number<arraySize>)
{
static_assert(arraySize >= 1 && arraySize <= 6, "The tensor should have 1 to 6 dimensions");
constexpr auto index_seq = typename arithmetic_sequence_gen<0, arraySize, 1>::type{};
return make_tuple_from_array_and_index_seq(lengths, index_seq);
};
template <index_t... Ns>
__device__ static constexpr auto make_tuple_from_seq(Sequence<Ns...>)
{
return make_tuple(Ns...);
};
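// Unlike the scalar-output gridwise_generic_reduce_2_prepare variant above, this one receives the
// real output lengths/strides, merges them into a 1-D destination descriptor, and then builds the
// (invariantLen x BlkGroupSize) source descriptor for the second pass.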
extern "C" __global__ void gridwise_generic_reduce_2_prepare(int GridSize,
int BlkGroupSize,
int outLength0,
int outLength1,
int outLength2,
int outLength3,
int outLength4,
int outLength5,
int outStride0,
int outStride1,
int outStride2,
int outStride3,
int outStride4,
int outStride5,
void* __restrict__ ws_global)
{
(void)GridSize;
void* p_src2dDesc = ws_global;
void* p_dst1dDesc = static_cast<char*>(ws_global) + 2048;
const int dstLengths[6] = {
outLength0, outLength1, outLength2, outLength3, outLength4, outLength5};
const int dstStrides[6] = {
outStride0, outStride1, outStride2, outStride3, outStride4, outStride5};
const auto tupleDstLengths = make_tuple_from_array(dstLengths, Number<dstDims>{});
const auto tupleDstStrides = make_tuple_from_array(dstStrides, Number<dstDims>{});
const auto dstDesc = make_naive_tensor_descriptor(tupleDstLengths, tupleDstStrides);
auto dst1dDesc = transform_tensor_descriptor(
dstDesc,
make_tuple(make_merge_transform(tupleDstLengths)),
make_tuple(typename arithmetic_sequence_gen<0, dstDims, 1>::type{}),
make_tuple(Sequence<0>{}));
const index_t invariantLen = dst1dDesc.GetLength(Number<0>{});
const index_t toReduceLen = BlkGroupSize;
auto src2dDesc = make_naive_tensor_descriptor_packed(make_tuple(invariantLen, toReduceLen));
constexpr auto copySliceLen = BlockSize * GredAccessesPerThreadInBlock;
if constexpr(src2d_need_padding)
{
const auto srcPad =
((toReduceLen + copySliceLen - 1) / copySliceLen) * copySliceLen - toReduceLen;
auto src2dDesc_2 =
transform_tensor_descriptor(src2dDesc,
make_tuple(make_pass_through_transform(invariantLen),
make_pad_transform(toReduceLen, 0, srcPad)),
make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}));
if(get_thread_local_1d_id() == 0)
*static_cast<decltype(src2dDesc_2)*>(p_src2dDesc) = src2dDesc_2;
}
else
{
if(get_thread_local_1d_id() == 0)
*static_cast<decltype(src2dDesc)*>(p_src2dDesc) = src2dDesc;
}
if(get_thread_local_1d_id() == 0)
*static_cast<decltype(dst1dDesc)*>(p_dst1dDesc) = dst1dDesc;
};
template <index_t dstDims>
struct get_ref_desc_types
{
static constexpr auto ref_tupleDstLengths =
make_tuple_from_seq(typename uniform_sequence_gen<dstDims, 8>::type{});
static constexpr auto ref_dstDesc =
make_naive_tensor_descriptor(ref_tupleDstLengths, ref_tupleDstLengths);
static constexpr auto ref_dst1dDesc = transform_tensor_descriptor(
ref_dstDesc,
make_tuple(make_merge_transform(ref_tupleDstLengths)),
make_tuple(typename arithmetic_sequence_gen<0, dstDims, 1>::type{}),
make_tuple(Sequence<0>{}));
static constexpr index_t ref_invariantLen = ref_dst1dDesc.GetLength(Number<0>{});
static constexpr index_t ref_toReduceLen = 8;
static constexpr auto ref_src2dDesc =
make_naive_tensor_descriptor_packed(make_tuple(ref_invariantLen, ref_toReduceLen));
using refType_src2dDesc = decltype(ref_src2dDesc);
using refType_dst1dDesc = decltype(ref_dst1dDesc);
// used by the BlockWise and MultiBlock method
using refType_src2dDesc_padded_34 = decltype(
transform_tensor_descriptor(ref_src2dDesc,
make_tuple(make_pass_through_transform(ref_invariantLen),
make_pad_transform(ref_toReduceLen, 0, 2)),
make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<0>{}, Sequence<1>{})));
using refType_dst1dDesc_padded =
decltype(transform_tensor_descriptor(ref_dst1dDesc,
make_tuple(make_pad_transform(ref_invariantLen, 0, 2)),
make_tuple(Sequence<0>{}),
make_tuple(Sequence<0>{})));
};
using refType_src2dDesc = typename get_ref_desc_types<dstDims>::refType_src2dDesc;
using refType_dst1dDesc = typename get_ref_desc_types<dstDims>::refType_dst1dDesc;
using refType_src2dDesc_padded_34 =
typename get_ref_desc_types<dstDims>::refType_src2dDesc_padded_34;
using refType_dst1dDesc_padded = typename get_ref_desc_types<dstDims>::refType_dst1dDesc_padded;
template <bool need_padding>
static __device__ auto get_reduction_src2d_descriptor(const void* p_src2dDesc)
{
if constexpr(need_padding)
return (*reinterpret_cast<const refType_src2dDesc_padded_34*>(p_src2dDesc));
else
return (*reinterpret_cast<const refType_src2dDesc*>(p_src2dDesc));
};
template <bool need_padding>
static __device__ auto get_reduction_dst1d_descriptor(const void* p_dst1dDesc)
{
if constexpr(need_padding)
return (*reinterpret_cast<const refType_dst1dDesc_padded*>(p_dst1dDesc));
else
return (*reinterpret_cast<const refType_dst1dDesc*>(p_dst1dDesc));
};
extern "C" __global__ void gridwise_generic_reduce_2(int origReduceLen,
float alpha,
const void* __restrict__ p_src_global,
float beta,
void* __restrict__ p_dst_global,
const void CONSTANT* ws_global,
long ws_buf2_bytes_offset,
void* __restrict__ indices_global)
{
(void)p_src_global;
const void* p_src2dDesc = cast_pointer_to_generic_address_space(ws_global);
const void* p_dst1dDesc = static_cast<const char*>(p_src2dDesc) + 2048;
void* ws_buf1_global = const_cast<char*>(static_cast<const char*>(p_src2dDesc) + 4096);
const auto src2dDesc = get_reduction_src2d_descriptor<src2d_need_padding>(p_src2dDesc);
const auto dst1dDesc = get_reduction_dst1d_descriptor<dst1d_need_padding>(p_dst1dDesc);
using gridwise_2d_reduce = GridwiseReduction_xy_to_x_blockwise<BlockSize,
srcDataType,
dstDataType,
compType,
decltype(src2dDesc),
decltype(dst1dDesc),
op,
nanPropaOpt,
reduceIndicesOpt,
false,
true,
GredAccessesPerThreadInBlock>;
void* const ws_buf2_global =
ws_buf2_bytes_offset > 0
? static_cast<void*>(static_cast<char*>(ws_buf1_global) + ws_buf2_bytes_offset)
: nullptr;
constexpr int RunId = need_indices ? 3 : 1;
gridwise_2d_reduce::template Run<RunId>(
src2dDesc,
dst1dDesc,
origReduceLen,
alpha,
static_cast<const srcDataType* const __restrict__>(ws_buf1_global),
beta,
static_cast<dstDataType* const __restrict__>(p_dst_global),
static_cast<const int* const __restrict__>(ws_buf2_global),
static_cast<int* const __restrict__>(indices_global));
};
/*******************************************************************************
*
* MIT License
*
* Copyright (c) 2021 Advanced Micro Devices, Inc.
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
* in the Software without restriction, including without limitation the rights
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
* copies of the Software, and to permit persons to whom the Software is
* furnished to do so, subject to the following conditions:
*
* The above copyright notice and this permission notice shall be included in all
* copies or substantial portions of the Software.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
* SOFTWARE.
*
*******************************************************************************/
#include "config.hpp"
#include "number.hpp"
#include "sequence.hpp"
#include "tensor_descriptor_helper.hpp"
#include "data_type_enum_helper.hpp"
#include "reduction_common.hpp"
#include "gridwise_generic_2d_reduction_direct_threadwise.hpp"
using namespace ck;
using srcDataType =
typename get_datatype_from_enum<static_cast<DataTypeEnum_t>(CK_PARAM_SRC_DATATYPE)>::type;
using dstDataType =
typename get_datatype_from_enum<static_cast<DataTypeEnum_t>(CK_PARAM_DST_DATATYPE)>::type;
using compType =
typename get_datatype_from_enum<static_cast<DataTypeEnum_t>(CK_PARAM_REDUCE_COMPTYPE)>::type;
constexpr index_t BlockSize = CK_PARAM_BLOCKSIZE; // tunable
using toReduceDims = Sequence<CK_PARAM_TOREDUCE_DIMS>;
using invariantDims = Sequence<CK_PARAM_INVARIANT_DIMS>; // this could be empty
constexpr ReduceTensorOp_t op = static_cast<ReduceTensorOp_t>(CK_PARAM_REDUCE_OP);
constexpr NanPropagation_t nanPropaOpt = CK_PARAM_NAN_PROPAGATE == 0
? NanPropagation_t::NOT_PROPAGATE_NAN
: NanPropagation_t::PROPAGATE_NAN;
constexpr ReduceTensorIndices_t reduceIndicesOpt = CK_PARAM_REDUCE_INDICES == 0
? ReduceTensorIndices_t::NO_INDICES
: ReduceTensorIndices_t::FLATTENED_INDICES;
constexpr bool src2d_need_padding = static_cast<bool>(CK_PARAM_SRC2D_PADDING);
constexpr bool dst1d_need_padding = static_cast<bool>(CK_PARAM_DST1D_PADDING);
constexpr bool indexable = reduce_binary_operator<compType, op>::indexable;
constexpr bool need_indices = indexable && (reduceIndicesOpt != ReduceTensorIndices_t::NO_INDICES);
constexpr index_t GredThreadBufferLength = CK_PARAM_THREAD_BUFFER_LENGTH; // tunable
extern "C" __global__ void
gridwise_generic_reduce_2_prepare(int GridSize, int BlkGroupSize, void* __restrict__ ws_global)
{
void* p_src2dDesc = ws_global;
void* p_dst1dDesc = static_cast<char*>(ws_global) + 2048;
const auto tupleDstLengths = make_tuple(1);
const auto tupleDstStrides = make_tuple(1);
auto dstDesc = make_naive_tensor_descriptor(tupleDstLengths, tupleDstStrides);
const index_t invariantLen = dstDesc.GetLength(Number<0>{});
const index_t toReduceLen = BlkGroupSize;
auto src2dDesc = make_naive_tensor_descriptor_packed(make_tuple(invariantLen, toReduceLen));
constexpr auto copySliceLen = GredThreadBufferLength;
if constexpr(src2d_need_padding)
{
const auto srcPad1 = GridSize * BlockSize - invariantLen;
const auto srcPad2 =
((toReduceLen + copySliceLen - 1) / copySliceLen) * copySliceLen - toReduceLen;
auto src2dDesc_2 =
transform_tensor_descriptor(src2dDesc,
make_tuple(make_pad_transform(invariantLen, 0, srcPad1),
make_pad_transform(toReduceLen, 0, srcPad2)),
make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}));
if(get_thread_local_1d_id() == 0)
*static_cast<decltype(src2dDesc_2)*>(p_src2dDesc) = src2dDesc_2;
}
else
{
if(get_thread_local_1d_id() == 0)
*static_cast<decltype(src2dDesc)*>(p_src2dDesc) = src2dDesc;
}
if constexpr(dst1d_need_padding)
{
const auto dstPad = GridSize * BlockSize - invariantLen;
auto dst1dDesc_2 =
transform_tensor_descriptor(dstDesc,
make_tuple(make_pad_transform(invariantLen, 0, dstPad)),
make_tuple(Sequence<0>{}),
make_tuple(Sequence<0>{}));
if(get_thread_local_1d_id() == 0)
*static_cast<decltype(dst1dDesc_2)*>(p_dst1dDesc) = dst1dDesc_2;
}
else
{
if(get_thread_local_1d_id() == 0)
*static_cast<decltype(dstDesc)*>(p_dst1dDesc) = dstDesc;
}
};
struct get_ref_desc_types
{
static constexpr auto ref_tupleDstLengths = make_tuple(8);
static constexpr auto ref_dstDesc =
make_naive_tensor_descriptor(ref_tupleDstLengths, ref_tupleDstLengths);
static constexpr index_t ref_invariantLen = ref_dstDesc.GetLength(Number<0>{});
static constexpr index_t ref_toReduceLen = 8;
static constexpr auto ref_src2dDesc =
make_naive_tensor_descriptor_packed(make_tuple(ref_invariantLen, ref_toReduceLen));
using refType_src2dDesc = decltype(ref_src2dDesc);
using refType_dst1dDesc = decltype(ref_dstDesc);
// used by the DirectThreadWise and DirectWarpWise method
using refType_src2dDesc_padded_12 =
decltype(transform_tensor_descriptor(ref_src2dDesc,
make_tuple(make_pad_transform(ref_invariantLen, 0, 2),
make_pad_transform(ref_toReduceLen, 0, 2)),
make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<0>{}, Sequence<1>{})));
using refType_dst1dDesc_padded =
decltype(transform_tensor_descriptor(ref_dstDesc,
make_tuple(make_pad_transform(ref_invariantLen, 0, 2)),
make_tuple(Sequence<0>{}),
make_tuple(Sequence<0>{})));
};
using refType_src2dDesc = typename get_ref_desc_types::refType_src2dDesc;
using refType_dst1dDesc = typename get_ref_desc_types::refType_dst1dDesc;
using refType_src2dDesc_padded_12 = typename get_ref_desc_types::refType_src2dDesc_padded_12;
using refType_dst1dDesc_padded = typename get_ref_desc_types::refType_dst1dDesc_padded;
template <bool need_padding>
static __device__ auto get_reduction_src2d_descriptor(const void* p_src2dDesc)
{
if constexpr(need_padding)
return (*reinterpret_cast<const refType_src2dDesc_padded_12*>(p_src2dDesc));
else
return (*reinterpret_cast<const refType_src2dDesc*>(p_src2dDesc));
};
template <bool need_padding>
static __device__ auto get_reduction_dst1d_descriptor(const void* p_dst1dDesc)
{
if constexpr(need_padding)
return (*reinterpret_cast<const refType_dst1dDesc_padded*>(p_dst1dDesc));
else
return (*reinterpret_cast<const refType_dst1dDesc*>(p_dst1dDesc));
};
extern "C" __global__ void gridwise_generic_reduce_2(int origReduceLen,
float alpha,
const void* __restrict__ p_src_global,
float beta,
void* __restrict__ p_dst_global,
const void CONSTANT* ws_global,
long ws_buf2_bytes_offset,
void* __restrict__ indices_global)
{
(void)p_src_global;
const void* p_src2dDesc = cast_pointer_to_generic_address_space(ws_global);
const void* p_dst1dDesc = static_cast<const char*>(p_src2dDesc) + 2048;
void* ws_buf1_global = const_cast<char*>(static_cast<const char*>(p_src2dDesc) + 4096);
const auto src2dDesc = get_reduction_src2d_descriptor<src2d_need_padding>(p_src2dDesc);
const auto dst1dDesc = get_reduction_dst1d_descriptor<dst1d_need_padding>(p_dst1dDesc);
using gridwise_2d_reduce = GridwiseReduction_xy_to_x_direct_threadwise<BlockSize,
srcDataType,
dstDataType,
compType,
decltype(src2dDesc),
decltype(dst1dDesc),
op,
nanPropaOpt,
reduceIndicesOpt,
false,
true,
GredThreadBufferLength>;
void* const ws_buf2_global =
ws_buf2_bytes_offset > 0
? static_cast<void*>(static_cast<char*>(ws_buf1_global) + ws_buf2_bytes_offset)
: nullptr;
constexpr int RunId = need_indices ? 3 : 1;
gridwise_2d_reduce::template Run<RunId>(
src2dDesc,
dst1dDesc,
origReduceLen,
alpha,
static_cast<const srcDataType* const __restrict__>(ws_buf1_global),
beta,
static_cast<dstDataType* const __restrict__>(p_dst_global),
static_cast<const int* const __restrict__>(ws_buf2_global),
static_cast<int* const __restrict__>(indices_global));
};
/*******************************************************************************
*
* MIT License
*
* Copyright (c) 2021 Advanced Micro Devices, Inc.
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
* in the Software without restriction, including without limitation the rights
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
* copies of the Software, and to permit persons to whom the Software is
* furnished to do so, subject to the following conditions:
*
* The above copyright notice and this permission notice shall be included in all
* copies or substantial portions of the Software.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
* SOFTWARE.
*
*******************************************************************************/
#include "config.hpp"
#include "number.hpp"
#include "sequence.hpp"
#include "tensor_descriptor_helper.hpp"
#include "data_type_enum_helper.hpp"
#include "reduction_common.hpp"
#include "gridwise_generic_2d_reduction_direct_threadwise.hpp"
using namespace ck;
using srcDataType =
typename get_datatype_from_enum<static_cast<DataTypeEnum_t>(CK_PARAM_SRC_DATATYPE)>::type;
using dstDataType =
typename get_datatype_from_enum<static_cast<DataTypeEnum_t>(CK_PARAM_DST_DATATYPE)>::type;
using compType =
typename get_datatype_from_enum<static_cast<DataTypeEnum_t>(CK_PARAM_REDUCE_COMPTYPE)>::type;
constexpr index_t BlockSize = CK_PARAM_BLOCKSIZE; // tunable
constexpr index_t dstDims = CK_PARAM_OUT_DIMS;
constexpr ReduceTensorOp_t op = static_cast<ReduceTensorOp_t>(CK_PARAM_REDUCE_OP);
constexpr NanPropagation_t nanPropaOpt = CK_PARAM_NAN_PROPAGATE == 0
? NanPropagation_t::NOT_PROPAGATE_NAN
: NanPropagation_t::PROPAGATE_NAN;
constexpr ReduceTensorIndices_t reduceIndicesOpt = CK_PARAM_REDUCE_INDICES == 0
? ReduceTensorIndices_t::NO_INDICES
: ReduceTensorIndices_t::FLATTENED_INDICES;
constexpr bool src2d_need_padding = static_cast<bool>(CK_PARAM_SRC2D_PADDING);
constexpr bool dst1d_need_padding = static_cast<bool>(CK_PARAM_DST1D_PADDING);
constexpr bool indexable = reduce_binary_operator<compType, op>::indexable;
constexpr bool need_indices = indexable && (reduceIndicesOpt != ReduceTensorIndices_t::NO_INDICES);
constexpr index_t GredThreadBufferLength = CK_PARAM_THREAD_BUFFER_LENGTH; // tunable
// helper functions using variadic template arguments
template <index_t... Ns>
__device__ static auto make_tuple_from_array_and_index_seq(const int* lengths, Sequence<Ns...>)
{
return make_tuple(static_cast<index_t>(lengths[Ns])...);
};
template <index_t arraySize>
__device__ static auto make_tuple_from_array(const int* lengths, Number<arraySize>)
{
static_assert(arraySize >= 1 && arraySize <= 6, "The tensor should have 1 to 6 dimensions");
constexpr auto index_seq = typename arithmetic_sequence_gen<0, arraySize, 1>::type{};
return make_tuple_from_array_and_index_seq(lengths, index_seq);
};
template <index_t... Ns>
__device__ static constexpr auto make_tuple_from_seq(Sequence<Ns...>)
{
return make_tuple(Ns...);
};
extern "C" __global__ void gridwise_generic_reduce_2_prepare(int GridSize,
int BlkGroupSize,
int outLength0,
int outLength1,
int outLength2,
int outLength3,
int outLength4,
int outLength5,
int outStride0,
int outStride1,
int outStride2,
int outStride3,
int outStride4,
int outStride5,
void* __restrict__ ws_global)
{
void* p_src2dDesc = ws_global;
void* p_dst1dDesc = static_cast<char*>(ws_global) + 2048;
const int dstLengths[6] = {
outLength0, outLength1, outLength2, outLength3, outLength4, outLength5};
const int dstStrides[6] = {
outStride0, outStride1, outStride2, outStride3, outStride4, outStride5};
const auto tupleDstLengths = make_tuple_from_array(dstLengths, Number<dstDims>{});
const auto tupleDstStrides = make_tuple_from_array(dstStrides, Number<dstDims>{});
const auto dstDesc = make_naive_tensor_descriptor(tupleDstLengths, tupleDstStrides);
auto dst1dDesc = transform_tensor_descriptor(
dstDesc,
make_tuple(make_merge_transform(tupleDstLengths)),
make_tuple(typename arithmetic_sequence_gen<0, dstDims, 1>::type{}),
make_tuple(Sequence<0>{}));
const index_t invariantLen = dst1dDesc.GetLength(Number<0>{});
const index_t toReduceLen = BlkGroupSize;
auto src2dDesc = make_naive_tensor_descriptor_packed(make_tuple(invariantLen, toReduceLen));
constexpr auto copySliceLen = GredThreadBufferLength;
if constexpr(src2d_need_padding)
{
const auto srcPad1 = GridSize * BlockSize - invariantLen;
const auto srcPad2 =
((toReduceLen + copySliceLen - 1) / copySliceLen) * copySliceLen - toReduceLen;
auto src2dDesc_2 =
transform_tensor_descriptor(src2dDesc,
make_tuple(make_pad_transform(invariantLen, 0, srcPad1),
make_pad_transform(toReduceLen, 0, srcPad2)),
make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}));
if(get_thread_local_1d_id() == 0)
*static_cast<decltype(src2dDesc_2)*>(p_src2dDesc) = src2dDesc_2;
}
else
{
if(get_thread_local_1d_id() == 0)
*static_cast<decltype(src2dDesc)*>(p_src2dDesc) = src2dDesc;
}
if constexpr(dst1d_need_padding)
{
const auto dstPad = GridSize * BlockSize - invariantLen;
auto dst1dDesc_2 =
transform_tensor_descriptor(dst1dDesc,
make_tuple(make_pad_transform(invariantLen, 0, dstPad)),
make_tuple(Sequence<0>{}),
make_tuple(Sequence<0>{}));
if(get_thread_local_1d_id() == 0)
*static_cast<decltype(dst1dDesc_2)*>(p_dst1dDesc) = dst1dDesc_2;
}
else
{
if(get_thread_local_1d_id() == 0)
*static_cast<decltype(dst1dDesc)*>(p_dst1dDesc) = dst1dDesc;
}
};
template <index_t dstDims>
struct get_ref_desc_types
{
static constexpr auto ref_tupleDstLengths =
make_tuple_from_seq(typename uniform_sequence_gen<dstDims, 8>::type{});
static constexpr auto ref_dstDesc =
make_naive_tensor_descriptor(ref_tupleDstLengths, ref_tupleDstLengths);
static constexpr auto ref_dst1dDesc = transform_tensor_descriptor(
ref_dstDesc,
make_tuple(make_merge_transform(ref_tupleDstLengths)),
make_tuple(typename arithmetic_sequence_gen<0, dstDims, 1>::type{}),
make_tuple(Sequence<0>{}));
static constexpr index_t ref_invariantLen = ref_dst1dDesc.GetLength(Number<0>{});
static constexpr index_t ref_toReduceLen = 8;
static constexpr auto ref_src2dDesc =
make_naive_tensor_descriptor_packed(make_tuple(ref_invariantLen, ref_toReduceLen));
using refType_src2dDesc = decltype(ref_src2dDesc);
using refType_dst1dDesc = decltype(ref_dst1dDesc);
// used by the DirectThreadWise and DirectWarpWise method
using refType_src2dDesc_padded_12 =
decltype(transform_tensor_descriptor(ref_src2dDesc,
make_tuple(make_pad_transform(ref_invariantLen, 0, 2),
make_pad_transform(ref_toReduceLen, 0, 2)),
make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<0>{}, Sequence<1>{})));
using refType_dst1dDesc_padded =
decltype(transform_tensor_descriptor(ref_dst1dDesc,
make_tuple(make_pad_transform(ref_invariantLen, 0, 2)),
make_tuple(Sequence<0>{}),
make_tuple(Sequence<0>{})));
};
using refType_src2dDesc = typename get_ref_desc_types<dstDims>::refType_src2dDesc;
using refType_dst1dDesc = typename get_ref_desc_types<dstDims>::refType_dst1dDesc;
using refType_src2dDesc_padded_12 =
typename get_ref_desc_types<dstDims>::refType_src2dDesc_padded_12;
using refType_dst1dDesc_padded = typename get_ref_desc_types<dstDims>::refType_dst1dDesc_padded;
template <bool need_padding>
static __device__ auto get_reduction_src2d_descriptor(const void* p_src2dDesc)
{
if constexpr(need_padding)
return (*reinterpret_cast<const refType_src2dDesc_padded_12*>(p_src2dDesc));
else
return (*reinterpret_cast<const refType_src2dDesc*>(p_src2dDesc));
};
template <bool need_padding>
static __device__ auto get_reduction_dst1d_descriptor(const void* p_dst1dDesc)
{
if constexpr(need_padding)
return (*reinterpret_cast<const refType_dst1dDesc_padded*>(p_dst1dDesc));
else
return (*reinterpret_cast<const refType_dst1dDesc*>(p_dst1dDesc));
};
extern "C" __global__ void gridwise_generic_reduce_2(int origReduceLen,
float alpha,
const void* __restrict__ p_src_global,
float beta,
void* __restrict__ p_dst_global,
const void CONSTANT* ws_global,
long ws_buf2_bytes_offset,
void* __restrict__ indices_global)
{
(void)p_src_global;
const void* p_src2dDesc = cast_pointer_to_generic_address_space(ws_global);
const void* p_dst1dDesc = static_cast<const char*>(p_src2dDesc) + 2048;
void* ws_buf1_global = const_cast<char*>(static_cast<const char*>(p_src2dDesc) + 4096);
const auto src2dDesc = get_reduction_src2d_descriptor<src2d_need_padding>(p_src2dDesc);
const auto dst1dDesc = get_reduction_dst1d_descriptor<dst1d_need_padding>(p_dst1dDesc);
using gridwise_2d_reduce = GridwiseReduction_xy_to_x_direct_threadwise<BlockSize,
srcDataType,
dstDataType,
compType,
decltype(src2dDesc),
decltype(dst1dDesc),
op,
nanPropaOpt,
reduceIndicesOpt,
false,
true,
GredThreadBufferLength>;
void* const ws_buf2_global =
ws_buf2_bytes_offset > 0
? static_cast<void*>(static_cast<char*>(ws_buf1_global) + ws_buf2_bytes_offset)
: nullptr;
constexpr int RunId = need_indices ? 3 : 1;
gridwise_2d_reduce::template Run<RunId>(
src2dDesc,
dst1dDesc,
origReduceLen,
alpha,
static_cast<const srcDataType* const __restrict__>(ws_buf1_global),
beta,
static_cast<dstDataType* const __restrict__>(p_dst_global),
static_cast<const int* const __restrict__>(ws_buf2_global),
static_cast<int* const __restrict__>(indices_global));
};
/*******************************************************************************
*
* MIT License
*
* Copyright (c) 2021 Advanced Micro Devices, Inc.
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
* in the Software without restriction, including without limitation the rights
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
* copies of the Software, and to permit persons to whom the Software is
* furnished to do so, subject to the following conditions:
*
* The above copyright notice and this permission notice shall be included in all
* copies or substantial portions of the Software.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
* SOFTWARE.
*
*******************************************************************************/
#include "config.hpp"
#include "number.hpp"
#include "sequence.hpp"
#include "tensor_descriptor_helper.hpp"
#include "data_type_enum_helper.hpp"
#include "reduction_common.hpp"
#include "gridwise_generic_2d_reduction_direct_warpwise.hpp"
using namespace ck;
using srcDataType =
typename get_datatype_from_enum<static_cast<DataTypeEnum_t>(CK_PARAM_SRC_DATATYPE)>::type;
using dstDataType =
typename get_datatype_from_enum<static_cast<DataTypeEnum_t>(CK_PARAM_DST_DATATYPE)>::type;
using compType =
typename get_datatype_from_enum<static_cast<DataTypeEnum_t>(CK_PARAM_REDUCE_COMPTYPE)>::type;
constexpr index_t BlockSize = CK_PARAM_BLOCKSIZE; // tunable
constexpr ReduceTensorOp_t op = static_cast<ReduceTensorOp_t>(CK_PARAM_REDUCE_OP);
constexpr NanPropagation_t nanPropaOpt = CK_PARAM_NAN_PROPAGATE == 0
? NanPropagation_t::NOT_PROPAGATE_NAN
: NanPropagation_t::PROPAGATE_NAN;
constexpr ReduceTensorIndices_t reduceIndicesOpt = CK_PARAM_REDUCE_INDICES == 0
? ReduceTensorIndices_t::NO_INDICES
: ReduceTensorIndices_t::FLATTENED_INDICES;
constexpr bool src2d_need_padding = static_cast<bool>(CK_PARAM_SRC2D_PADDING);
constexpr bool dst1d_need_padding = static_cast<bool>(CK_PARAM_DST1D_PADDING);
constexpr bool indexable = reduce_binary_operator<compType, op>::indexable;
constexpr bool need_indices = indexable && (reduceIndicesOpt != ReduceTensorIndices_t::NO_INDICES);
constexpr index_t GredAccessesPerThreadInWarp = CK_PARAM_ACCESSES_PER_THREAD_INWARP; // tunable
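// In the direct warp-wise method each warp apparently reduces one output value, which is why the
// prepare kernel below pads the invariant length up to GridSize * BlockSize / warpSize.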
extern "C" __global__ void
gridwise_generic_reduce_2_prepare(int GridSize, int BlkGroupSize, void* __restrict__ ws_global)
{
void* p_src2dDesc = ws_global;
void* p_dst1dDesc = static_cast<char*>(ws_global) + 2048;
const auto tupleDstLengths = make_tuple(1);
const auto tupleDstStrides = make_tuple(1);
auto dstDesc = make_naive_tensor_descriptor(tupleDstLengths, tupleDstStrides);
const index_t invariantLen = dstDesc.GetLength(Number<0>{});
const index_t toReduceLen = BlkGroupSize;
auto src2dDesc = make_naive_tensor_descriptor_packed(make_tuple(invariantLen, toReduceLen));
constexpr auto copySliceLen = warpSize * GredAccessesPerThreadInWarp;
if constexpr(src2d_need_padding)
{
const auto srcPad1 = GridSize * BlockSize / warpSize - invariantLen;
const auto srcPad2 =
((toReduceLen + copySliceLen - 1) / copySliceLen) * copySliceLen - toReduceLen;
auto src2dDesc_2 =
transform_tensor_descriptor(src2dDesc,
make_tuple(make_pad_transform(invariantLen, 0, srcPad1),
make_pad_transform(toReduceLen, 0, srcPad2)),
make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}));
if(get_thread_local_1d_id() == 0)
*static_cast<decltype(src2dDesc_2)*>(p_src2dDesc) = src2dDesc_2;
}
else
{
if(get_thread_local_1d_id() == 0)
*static_cast<decltype(src2dDesc)*>(p_src2dDesc) = src2dDesc;
}
if constexpr(dst1d_need_padding)
{
const auto dstPad = GridSize * BlockSize / warpSize - invariantLen;
auto dst1dDesc_2 =
transform_tensor_descriptor(dstDesc,
make_tuple(make_pad_transform(invariantLen, 0, dstPad)),
make_tuple(Sequence<0>{}),
make_tuple(Sequence<0>{}));
if(get_thread_local_1d_id() == 0)
*static_cast<decltype(dst1dDesc_2)*>(p_dst1dDesc) = dst1dDesc_2;
}
else
{
if(get_thread_local_1d_id() == 0)
*static_cast<decltype(dstDesc)*>(p_dst1dDesc) = dstDesc;
}
};
struct get_ref_desc_types
{
static constexpr auto ref_tupleDstLengths = make_tuple(8);
static constexpr auto ref_dstDesc =
make_naive_tensor_descriptor(ref_tupleDstLengths, ref_tupleDstLengths);
static constexpr index_t ref_invariantLen = ref_dstDesc.GetLength(Number<0>{});
static constexpr index_t ref_toReduceLen = 8;
static constexpr auto ref_src2dDesc =
make_naive_tensor_descriptor_packed(make_tuple(ref_invariantLen, ref_toReduceLen));
using refType_src2dDesc = decltype(ref_src2dDesc);
using refType_dst1dDesc = decltype(ref_dstDesc);
// used by the DirectThreadWise and DirectWarpWise method
using refType_src2dDesc_padded_12 =
decltype(transform_tensor_descriptor(ref_src2dDesc,
make_tuple(make_pad_transform(ref_invariantLen, 0, 2),
make_pad_transform(ref_toReduceLen, 0, 2)),
make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<0>{}, Sequence<1>{})));
using refType_dst1dDesc_padded =
decltype(transform_tensor_descriptor(ref_dstDesc,
make_tuple(make_pad_transform(ref_invariantLen, 0, 2)),
make_tuple(Sequence<0>{}),
make_tuple(Sequence<0>{})));
};
using refType_src2dDesc = typename get_ref_desc_types::refType_src2dDesc;
using refType_dst1dDesc = typename get_ref_desc_types::refType_dst1dDesc;
using refType_src2dDesc_padded_12 = typename get_ref_desc_types::refType_src2dDesc_padded_12;
using refType_dst1dDesc_padded = typename get_ref_desc_types::refType_dst1dDesc_padded;
template <bool need_padding>
static __device__ auto get_reduction_src2d_descriptor(const void* p_src2dDesc)
{
if constexpr(need_padding)
return (*reinterpret_cast<const refType_src2dDesc_padded_12*>(p_src2dDesc));
else
return (*reinterpret_cast<const refType_src2dDesc*>(p_src2dDesc));
};
template <bool need_padding>
static __device__ auto get_reduction_dst1d_descriptor(const void* p_dst1dDesc)
{
if constexpr(need_padding)
return (*reinterpret_cast<const refType_dst1dDesc_padded*>(p_dst1dDesc));
else
return (*reinterpret_cast<const refType_dst1dDesc*>(p_dst1dDesc));
};
extern "C" __global__ void gridwise_generic_reduce_2(int origReduceLen,
float alpha,
const void* __restrict__ p_src_global,
float beta,
void* __restrict__ p_dst_global,
const void CONSTANT* ws_global,
long ws_buf2_bytes_offset,
void* __restrict__ indices_global)
{
(void)p_src_global;
const void* p_src2dDesc = cast_pointer_to_generic_address_space(ws_global);
const void* p_dst1dDesc = static_cast<const char*>(p_src2dDesc) + 2048;
void* ws_buf1_global = const_cast<char*>(static_cast<const char*>(p_src2dDesc) + 4096);
const auto src2dDesc = get_reduction_src2d_descriptor<src2d_need_padding>(p_src2dDesc);
const auto dst1dDesc = get_reduction_dst1d_descriptor<dst1d_need_padding>(p_dst1dDesc);
using gridwise_2d_reduce =
GridwiseReduction_xy_to_x_direct_warpwise<BlockSize,
srcDataType,
dstDataType,
compType,
decltype(src2dDesc),
decltype(dst1dDesc),
op,
nanPropaOpt,
reduceIndicesOpt,
false,
true,
GredAccessesPerThreadInWarp>;
void* const ws_buf2_global =
ws_buf2_bytes_offset > 0
? static_cast<void*>(static_cast<char*>(ws_buf1_global) + ws_buf2_bytes_offset)
: nullptr;
constexpr int RunId = need_indices ? 3 : 1;
gridwise_2d_reduce::template Run<RunId>(
src2dDesc,
dst1dDesc,
origReduceLen,
alpha,
static_cast<const srcDataType* const __restrict__>(ws_buf1_global),
beta,
static_cast<dstDataType* const __restrict__>(p_dst_global),
static_cast<const int* const __restrict__>(ws_buf2_global),
static_cast<int* const __restrict__>(indices_global));
};
/*******************************************************************************
*
* MIT License
*
* Copyright (c) 2021 Advanced Micro Devices, Inc.
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
* in the Software without restriction, including without limitation the rights
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
* copies of the Software, and to permit persons to whom the Software is
* furnished to do so, subject to the following conditions:
*
* The above copyright notice and this permission notice shall be included in all
* copies or substantial portions of the Software.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
* SOFTWARE.
*
*******************************************************************************/
#include "config.hpp"
#include "number.hpp"
#include "sequence.hpp"
#include "tensor_descriptor_helper.hpp"
#include "data_type_enum_helper.hpp"
#include "reduction_common.hpp"
#include "gridwise_generic_2d_reduction_direct_warpwise.hpp"
using namespace ck;
using srcDataType =
typename get_datatype_from_enum<static_cast<DataTypeEnum_t>(CK_PARAM_SRC_DATATYPE)>::type;
using dstDataType =
typename get_datatype_from_enum<static_cast<DataTypeEnum_t>(CK_PARAM_DST_DATATYPE)>::type;
using compType =
typename get_datatype_from_enum<static_cast<DataTypeEnum_t>(CK_PARAM_REDUCE_COMPTYPE)>::type;
constexpr index_t BlockSize = CK_PARAM_BLOCKSIZE; // tunable
constexpr index_t dstDims = CK_PARAM_OUT_DIMS;
constexpr ReduceTensorOp_t op = static_cast<ReduceTensorOp_t>(CK_PARAM_REDUCE_OP);
constexpr NanPropagation_t nanPropaOpt = CK_PARAM_NAN_PROPAGATE == 0
? NanPropagation_t::NOT_PROPAGATE_NAN
: NanPropagation_t::PROPAGATE_NAN;
constexpr ReduceTensorIndices_t reduceIndicesOpt = CK_PARAM_REDUCE_INDICES == 0
? ReduceTensorIndices_t::NO_INDICES
: ReduceTensorIndices_t::FLATTENED_INDICES;
constexpr bool src2d_need_padding = static_cast<bool>(CK_PARAM_SRC2D_PADDING);
constexpr bool dst1d_need_padding = static_cast<bool>(CK_PARAM_DST1D_PADDING);
constexpr bool indexable = reduce_binary_operator<compType, op>::indexable;
constexpr bool need_indices = indexable && (reduceIndicesOpt != ReduceTensorIndices_t::NO_INDICES);
constexpr index_t GredAccessesPerThreadInWarp = CK_PARAM_ACCESSES_PER_THREAD_INWARP; // tunable
// helper functions using variadic template arguments
template <index_t... Ns>
__device__ static auto make_tuple_from_array_and_index_seq(const int* lengths, Sequence<Ns...>)
{
return make_tuple(static_cast<index_t>(lengths[Ns])...);
};
template <index_t arraySize>
__device__ static auto make_tuple_from_array(const int* lengths, Number<arraySize>)
{
static_assert(arraySize >= 1 && arraySize <= 6, "The tensor should have 1 to 6 dimensions");
constexpr auto index_seq = typename arithmetic_sequence_gen<0, arraySize, 1>::type{};
return make_tuple_from_array_and_index_seq(lengths, index_seq);
};
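// e.g. make_tuple_from_array(lengths, Number<3>{}) expands to
// make_tuple(lengths[0], lengths[1], lengths[2])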
template <index_t... Ns>
__device__ static constexpr auto make_tuple_from_seq(Sequence<Ns...>)
{
return make_tuple(Ns...);
};
extern "C" __global__ void gridwise_generic_reduce_2_prepare(int GridSize,
int BlkGroupSize,
int outLength0,
int outLength1,
int outLength2,
int outLength3,
int outLength4,
int outLength5,
int outStride0,
int outStride1,
int outStride2,
int outStride3,
int outStride4,
int outStride5,
void* __restrict__ ws_global)
{
(void)BlkGroupSize;
void* p_src2dDesc = ws_global;
void* p_dst1dDesc = static_cast<char*>(ws_global) + 2048;
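// Workspace layout shared with gridwise_generic_reduce_2: the src2d descriptor is written at
// offset 0, the dst1d descriptor at offset 2048, and everything from offset 4096 on is the
// partial-result buffer produced by the first reduction pass.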
const int dstLengths[6] = {
outLength0, outLength1, outLength2, outLength3, outLength4, outLength5};
const int dstStrides[6] = {
outStride0, outStride1, outStride2, outStride3, outStride4, outStride5};
const auto tupleDstLengths = make_tuple_from_array(dstLengths, Number<dstDims>{});
const auto tupleDstStrides = make_tuple_from_array(dstStrides, Number<dstDims>{});
const auto dstDesc = make_naive_tensor_descriptor(tupleDstLengths, tupleDstStrides);
auto dst1dDesc = transform_tensor_descriptor(
dstDesc,
make_tuple(make_merge_transform(tupleDstLengths)),
make_tuple(typename arithmetic_sequence_gen<0, dstDims, 1>::type{}),
make_tuple(Sequence<0>{}));
const index_t invariantLen = dst1dDesc.GetLength(Number<0>{});
const index_t toReduceLen = BlkGroupSize;
auto src2dDesc = make_naive_tensor_descriptor_packed(make_tuple(invariantLen, toReduceLen));
constexpr auto copySliceLen = warpSize * GredAccessesPerThreadInWarp;
if constexpr(src2d_need_padding)
{
const auto srcPad1 = GridSize * BlockSize / warpSize - invariantLen;
const auto srcPad2 =
((toReduceLen + copySliceLen - 1) / copySliceLen) * copySliceLen - toReduceLen;
auto src2dDesc_2 =
transform_tensor_descriptor(src2dDesc,
make_tuple(make_pad_transform(invariantLen, 0, srcPad1),
make_pad_transform(toReduceLen, 0, srcPad2)),
make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}));
if(get_thread_local_1d_id() == 0)
*static_cast<decltype(src2dDesc_2)*>(p_src2dDesc) = src2dDesc_2;
}
else
{
if(get_thread_local_1d_id() == 0)
*static_cast<decltype(src2dDesc)*>(p_src2dDesc) = src2dDesc;
}
if constexpr(dst1d_need_padding)
{
const auto dstPad = GridSize * BlockSize / warpSize - invariantLen;
auto dst1dDesc_2 =
transform_tensor_descriptor(dst1dDesc,
make_tuple(make_pad_transform(invariantLen, 0, dstPad)),
make_tuple(Sequence<0>{}),
make_tuple(Sequence<0>{}));
if(get_thread_local_1d_id() == 0)
*static_cast<decltype(dst1dDesc_2)*>(p_dst1dDesc) = dst1dDesc_2;
}
else
{
if(get_thread_local_1d_id() == 0)
*static_cast<decltype(dst1dDesc)*>(p_dst1dDesc) = dst1dDesc;
}
};
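// Reference descriptor types: the real descriptors are built at runtime by the prepare kernel,
// but the reduce kernel still needs their static types to read them back from the workspace.
// Dummy compile-time lengths (8 per dimension) are used here; the resulting types depend only
// on the transform structure, not on the length values, so they match the descriptors stored by
// gridwise_generic_reduce_2_prepare.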
template <index_t dstDims>
struct get_ref_desc_types
{
static constexpr auto ref_tupleDstLengths =
make_tuple_from_seq(typename uniform_sequence_gen<dstDims, 8>::type{});
static constexpr auto ref_dstDesc =
make_naive_tensor_descriptor(ref_tupleDstLengths, ref_tupleDstLengths);
static constexpr auto ref_dst1dDesc = transform_tensor_descriptor(
ref_dstDesc,
make_tuple(make_merge_transform(ref_tupleDstLengths)),
make_tuple(typename arithmetic_sequence_gen<0, dstDims, 1>::type{}),
make_tuple(Sequence<0>{}));
static constexpr index_t ref_invariantLen = ref_dst1dDesc.GetLength(Number<0>{});
static constexpr index_t ref_toReduceLen = 8;
static constexpr auto ref_src2dDesc =
make_naive_tensor_descriptor_packed(make_tuple(ref_invariantLen, ref_toReduceLen));
using refType_src2dDesc = decltype(ref_src2dDesc);
using refType_dst1dDesc = decltype(ref_dst1dDesc);
// used by the DirectThreadWise and DirectWarpWise methods
using refType_src2dDesc_padded_12 =
decltype(transform_tensor_descriptor(ref_src2dDesc,
make_tuple(make_pad_transform(ref_invariantLen, 0, 2),
make_pad_transform(ref_toReduceLen, 0, 2)),
make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<0>{}, Sequence<1>{})));
using refType_dst1dDesc_padded =
decltype(transform_tensor_descriptor(ref_dst1dDesc,
make_tuple(make_pad_transform(ref_invariantLen, 0, 2)),
make_tuple(Sequence<0>{}),
make_tuple(Sequence<0>{})));
};
using refType_src2dDesc = typename get_ref_desc_types<dstDims>::refType_src2dDesc;
using refType_dst1dDesc = typename get_ref_desc_types<dstDims>::refType_dst1dDesc;
using refType_src2dDesc_padded_12 =
typename get_ref_desc_types<dstDims>::refType_src2dDesc_padded_12;
using refType_dst1dDesc_padded = typename get_ref_desc_types<dstDims>::refType_dst1dDesc_padded;
template <bool need_padding>
static __device__ auto get_reduction_src2d_descriptor(const void* p_src2dDesc)
{
if constexpr(need_padding)
return (*reinterpret_cast<const refType_src2dDesc_padded_12*>(p_src2dDesc));
else
return (*reinterpret_cast<const refType_src2dDesc*>(p_src2dDesc));
};
template <bool need_padding>
static __device__ auto get_reduction_dst1d_descriptor(const void* p_dst1dDesc)
{
if constexpr(need_padding)
return (*reinterpret_cast<const refType_dst1dDesc_padded*>(p_dst1dDesc));
else
return (*reinterpret_cast<const refType_dst1dDesc*>(p_dst1dDesc));
};
extern "C" __global__ void gridwise_generic_reduce_2(int origReduceLen,
float alpha,
const void* __restrict__ p_src_global,
float beta,
void* __restrict__ p_dst_global,
const void CONSTANT* ws_global,
long ws_buf2_bytes_offset,
void* __restrict__ indices_global)
{
(void)p_src_global;
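// This is the second call of a two-pass reduction: it consumes the partial results left in the
// workspace by the first pass, so the original source tensor is not read here.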
const void* p_src2dDesc = cast_pointer_to_generic_address_space(ws_global);
const void* p_dst1dDesc = static_cast<const char*>(p_src2dDesc) + 2048;
void* ws_buf1_global = const_cast<char*>(static_cast<const char*>(p_src2dDesc) + 4096);
const auto src2dDesc = get_reduction_src2d_descriptor<src2d_need_padding>(p_src2dDesc);
const auto dst1dDesc = get_reduction_dst1d_descriptor<dst1d_need_padding>(p_dst1dDesc);
using gridwise_2d_reduce =
GridwiseReduction_xy_to_x_direct_warpwise<BlockSize,
srcDataType,
dstDataType,
compType,
decltype(src2dDesc),
decltype(dst1dDesc),
op,
nanPropaOpt,
reduceIndicesOpt,
false,
true,
GredAccessesPerThreadInWarp>;
void* const ws_buf2_global =
ws_buf2_bytes_offset > 0
? static_cast<void*>(static_cast<char*>(ws_buf1_global) + ws_buf2_bytes_offset)
: nullptr;
constexpr int RunId = need_indices ? 3 : 1;
gridwise_2d_reduce::template Run<RunId>(
src2dDesc,
dst1dDesc,
origReduceLen,
alpha,
static_cast<const srcDataType* const __restrict__>(ws_buf1_global),
beta,
static_cast<dstDataType* const __restrict__>(p_dst_global),
static_cast<const int* const __restrict__>(ws_buf2_global),
static_cast<int* const __restrict__>(indices_global));
};
......@@ -111,7 +111,35 @@ set(DEVICE_CONV2D_BWD_DATA_INSTANCE_SOURCE
${PROJECT_SOURCE_DIR}/device_operation/src/device_conv2d_bwd_data_xdl_nhwc_kyxc_nhwk_int8_instance.cpp;
)
# device_reduce_instance
set(DEVICE_REDUCE_INSTANCE_SOURCE
${PROJECT_SOURCE_DIR}/device_operation/src/device_reduce_instance_blockwise_f16_f16_f16.cpp;
${PROJECT_SOURCE_DIR}/device_operation/src/device_reduce_instance_blockwise_f16_f32_f16.cpp;
${PROJECT_SOURCE_DIR}/device_operation/src/device_reduce_instance_blockwise_f32_f32_f32.cpp;
${PROJECT_SOURCE_DIR}/device_operation/src/device_reduce_instance_blockwise_f32_f64_f32.cpp;
${PROJECT_SOURCE_DIR}/device_operation/src/device_reduce_instance_blockwise_f64_f64_f64.cpp;
${PROJECT_SOURCE_DIR}/device_operation/src/device_reduce_instance_threadwise_f16_f16_f16.cpp;
${PROJECT_SOURCE_DIR}/device_operation/src/device_reduce_instance_threadwise_f16_f32_f16.cpp;
${PROJECT_SOURCE_DIR}/device_operation/src/device_reduce_instance_threadwise_f32_f32_f32.cpp;
${PROJECT_SOURCE_DIR}/device_operation/src/device_reduce_instance_threadwise_f32_f64_f32.cpp;
${PROJECT_SOURCE_DIR}/device_operation/src/device_reduce_instance_threadwise_f64_f64_f64.cpp;
${PROJECT_SOURCE_DIR}/device_operation/src/device_reduce_instance_blockwise_second_call_f16_f16_f16.cpp;
${PROJECT_SOURCE_DIR}/device_operation/src/device_reduce_instance_blockwise_second_call_f32_f32_f16.cpp;
${PROJECT_SOURCE_DIR}/device_operation/src/device_reduce_instance_blockwise_second_call_f32_f32_f32.cpp;
${PROJECT_SOURCE_DIR}/device_operation/src/device_reduce_instance_blockwise_second_call_f64_f64_f32.cpp;
${PROJECT_SOURCE_DIR}/device_operation/src/device_reduce_instance_blockwise_second_call_f64_f64_f64.cpp;
${PROJECT_SOURCE_DIR}/device_operation/src/device_reduce_instance_multiblock_atomic_add_f16_f32_f32.cpp;
${PROJECT_SOURCE_DIR}/device_operation/src/device_reduce_instance_multiblock_atomic_add_f32_f32_f32.cpp;
${PROJECT_SOURCE_DIR}/device_operation/src/device_reduce_instance_multiblock_atomic_add_f32_f64_f32.cpp;
${PROJECT_SOURCE_DIR}/device_operation/src/device_reduce_instance_multiblock_partial_reduce_f16_f16_f16.cpp;
${PROJECT_SOURCE_DIR}/device_operation/src/device_reduce_instance_multiblock_partial_reduce_f16_f32_f16.cpp;
${PROJECT_SOURCE_DIR}/device_operation/src/device_reduce_instance_multiblock_partial_reduce_f32_f32_f32.cpp;
${PROJECT_SOURCE_DIR}/device_operation/src/device_reduce_instance_multiblock_partial_reduce_f32_f64_f32.cpp;
${PROJECT_SOURCE_DIR}/device_operation/src/device_reduce_instance_multiblock_partial_reduce_f64_f64_f64.cpp;
)
add_library(device_gemm_instance SHARED ${DEVICE_GEMM_INSTANCE_SOURCE})
add_library(device_gemm_bias_2d_instance SHARED ${DEVICE_GEMM_BIAS_2D_INSTANCE_SOURCE})
add_library(device_gemm_bias_relu_instance SHARED ${DEVICE_GEMM_BIAS_RELU_INSTANCE_SOURCE})
add_library(device_gemm_bias_relu_add_instance SHARED ${DEVICE_GEMM_BIAS_RELU_ADD_INSTANCE_SOURCE})
add_library(device_batched_gemm_instance SHARED ${DEVICE_BATCHED_GEMM_INSTANCE_SOURCE})
......@@ -120,8 +148,8 @@ add_library(device_conv2d_fwd_instance SHARED ${DEVICE_CONV2D_FWD_INSTANCE_SOURC
add_library(device_conv2d_fwd_bias_relu_instance SHARED ${DEVICE_CONV2D_FWD_BIAS_RELU_INSTANCE_SOURCE})
add_library(device_conv2d_fwd_bias_relu_add_instance SHARED ${DEVICE_CONV2D_FWD_BIAS_RELU_ADD_INSTANCE_SOURCE})
add_library(device_conv2d_fwd_bias_relu_atomic_add_instance SHARED ${DEVICE_CONV2D_FWD_BIAS_RELU_ATOMIC_ADD_INSTANCE_SOURCE})
add_library(device_gemm_bias_2d_instance SHARED ${DEVICE_GEMM_BIAS_2D_INSTANCE_SOURCE})
add_library(device_conv2d_bwd_data_instance SHARED ${DEVICE_CONV2D_BWD_DATA_INSTANCE_SOURCE})
add_library(device_reduce_instance SHARED ${DEVICE_REDUCE_INSTANCE_SOURCE})
target_include_directories(device_gemm_instance SYSTEM PUBLIC $<BUILD_INTERFACE:${HALF_INCLUDE_DIR}>)
target_include_directories(device_gemm_bias_2d_instance SYSTEM PUBLIC $<BUILD_INTERFACE:${HALF_INCLUDE_DIR}>)
......@@ -134,6 +162,7 @@ target_include_directories(device_conv2d_fwd_bias_relu_instance SYSTEM PUBLIC $<
target_include_directories(device_conv2d_fwd_bias_relu_add_instance SYSTEM PUBLIC $<BUILD_INTERFACE:${HALF_INCLUDE_DIR}>)
target_include_directories(device_conv2d_fwd_bias_relu_atomic_add_instance SYSTEM PUBLIC $<BUILD_INTERFACE:${HALF_INCLUDE_DIR}>)
target_include_directories(device_conv2d_bwd_data_instance SYSTEM PUBLIC $<BUILD_INTERFACE:${HALF_INCLUDE_DIR}>)
target_include_directories(device_reduce_instance SYSTEM PUBLIC $<BUILD_INTERFACE:${HALF_INCLUDE_DIR}>)
target_compile_features(device_gemm_instance PUBLIC)
target_compile_features(device_gemm_bias_2d_instance PUBLIC)
......@@ -146,6 +175,7 @@ target_compile_features(device_conv2d_fwd_bias_relu_instance PUBLIC)
target_compile_features(device_conv2d_fwd_bias_relu_add_instance PUBLIC)
target_compile_features(device_conv2d_fwd_bias_relu_atomic_add_instance PUBLIC)
target_compile_features(device_conv2d_bwd_data_instance PUBLIC)
target_compile_features(device_reduce_instance PUBLIC)
set_target_properties(device_gemm_instance PROPERTIES POSITION_INDEPENDENT_CODE ON)
set_target_properties(device_gemm_bias_2d_instance PROPERTIES POSITION_INDEPENDENT_CODE ON)
......@@ -158,6 +188,7 @@ set_target_properties(device_conv2d_fwd_bias_relu_instance PROPERTIES POSITION_I
set_target_properties(device_conv2d_fwd_bias_relu_add_instance PROPERTIES POSITION_INDEPENDENT_CODE ON)
set_target_properties(device_conv2d_fwd_bias_relu_atomic_add_instance PROPERTIES POSITION_INDEPENDENT_CODE ON)
set_target_properties(device_conv2d_bwd_data_instance PROPERTIES POSITION_INDEPENDENT_CODE ON)
set_target_properties(device_reduce_instance PROPERTIES POSITION_INDEPENDENT_CODE ON)
install(TARGETS device_gemm_instance LIBRARY DESTINATION lib)
install(TARGETS device_gemm_bias_2d_instance LIBRARY DESTINATION lib)
......@@ -170,3 +201,4 @@ install(TARGETS device_conv2d_fwd_bias_relu_instance LIBRARY DESTINATION lib)
install(TARGETS device_conv2d_fwd_bias_relu_add_instance LIBRARY DESTINATION lib)
install(TARGETS device_conv2d_fwd_bias_relu_atomic_add_instance LIBRARY DESTINATION lib)
install(TARGETS device_conv2d_bwd_data_instance LIBRARY DESTINATION lib)
install(TARGETS device_reduce_instance LIBRARY DESTINATION lib)
......@@ -549,8 +549,11 @@ struct
Conv_N_{N},
Conv_K_{K},
Conv_C_{C},
input_spatial_lengths_{input_spatial_lengths},
filter_spatial_lengths_{filter_spatial_lengths},
output_spatial_lengths_{output_spatial_lengths},
conv_filter_strides_{conv_filter_strides},
conv_filter_dilations_{conv_filter_dilations},
input_left_pads_{input_left_pads},
input_right_pads_{input_right_pads}
{
......@@ -625,8 +628,11 @@ struct
index_t Conv_N_;
index_t Conv_K_;
index_t Conv_C_;
std::vector<index_t> input_spatial_lengths_;
std::vector<index_t> filter_spatial_lengths_;
std::vector<index_t> output_spatial_lengths_;
std::vector<index_t> conv_filter_strides_;
std::vector<index_t> conv_filter_dilations_;
std::vector<index_t> input_left_pads_;
std::vector<index_t> input_right_pads_;
};
......@@ -638,6 +644,28 @@ struct
float Run(const Argument& arg, int nrepeat = 1)
{
#if 0
{
std::cout << DeviceOp{}.GetTypeString() << std::endl;
std::cout << "N " << arg.Conv_N_ << ", "
<< "K " << arg.Conv_K_ << ", "
<< "C " << arg.Conv_C_ << ", " << std::endl;
std::cout << "Y X " << arg.filter_spatial_lengths_[0] << ", "
<< arg.filter_spatial_lengths_[1] << ", " << std::endl;
std::cout << "Hi Wi " << arg.input_spatial_lengths_[0] << ", "
<< arg.input_spatial_lengths_[1] << ", " << std::endl;
std::cout << "Ho Wo " << arg.output_spatial_lengths_[0] << ", "
<< arg.output_spatial_lengths_[1] << ", " << std::endl;
std::cout << "Strides " << arg.conv_filter_strides_[0] << ", "
<< arg.conv_filter_strides_[1] << ", " << std::endl;
std::cout << "Dilations " << arg.conv_filter_dilations_[0] << ", "
<< arg.conv_filter_dilations_[1] << ", " << std::endl;
std::cout << "InLeftPads " << arg.input_left_pads_[0] << ", "
<< arg.input_left_pads_[1] << ", " << std::endl;
std::cout << "InLeftPads " << arg.input_right_pads_[0] << ", "
<< arg.input_right_pads_[1] << ", " << std::endl;
}
{
std::cout << "arg.a_grid_desc_k0_m_k1_{" << arg.a_grid_desc_k0_m_k1_.GetLength(I0)
<< ", " << arg.a_grid_desc_k0_m_k1_.GetLength(I1) << ", "
......@@ -656,6 +684,7 @@ struct
std::cout << "arg.c1_grid_desc_m_n_{ " << arg.c1_grid_desc_m_n_.GetLength(I0)
<< ", " << arg.c1_grid_desc_m_n_.GetLength(I1) << "}" << std::endl;
}
#endif
if(!GridwiseGemm::CheckValidity(arg.a_grid_desc_k0_m_k1_,
arg.b_grid_desc_k0_n_k1_,
......
......@@ -526,8 +526,11 @@ struct DeviceConv2dFwdXdl_C_Shuffle_Bias_Activation_Input_N_Hi_Wi_C_Weight_K_Y_X
Conv_N_{N},
Conv_K_{K},
Conv_C_{C},
input_spatial_lengths_{input_spatial_lengths},
filter_spatial_lengths_{filter_spatial_lengths},
output_spatial_lengths_{output_spatial_lengths},
conv_filter_strides_{conv_filter_strides},
conv_filter_dilations_{conv_filter_dilations},
input_left_pads_{input_left_pads},
input_right_pads_{input_right_pads}
{
......@@ -590,8 +593,11 @@ struct DeviceConv2dFwdXdl_C_Shuffle_Bias_Activation_Input_N_Hi_Wi_C_Weight_K_Y_X
index_t Conv_N_;
index_t Conv_K_;
index_t Conv_C_;
std::vector<index_t> input_spatial_lengths_;
std::vector<index_t> filter_spatial_lengths_;
std::vector<index_t> output_spatial_lengths_;
std::vector<index_t> conv_filter_strides_;
std::vector<index_t> conv_filter_dilations_;
std::vector<index_t> input_left_pads_;
std::vector<index_t> input_right_pads_;
};
......@@ -603,6 +609,28 @@ struct DeviceConv2dFwdXdl_C_Shuffle_Bias_Activation_Input_N_Hi_Wi_C_Weight_K_Y_X
float Run(const Argument& arg, int nrepeat = 1)
{
#if 0
{
std::cout << DeviceOp{}.GetTypeString() << std::endl;
std::cout << "N " << arg.Conv_N_ << ", "
<< "K " << arg.Conv_K_ << ", "
<< "C " << arg.Conv_C_ << ", " << std::endl;
std::cout << "Y X " << arg.filter_spatial_lengths_[0] << ", "
<< arg.filter_spatial_lengths_[1] << ", " << std::endl;
std::cout << "Hi Wi " << arg.input_spatial_lengths_[0] << ", "
<< arg.input_spatial_lengths_[1] << ", " << std::endl;
std::cout << "Ho Wo " << arg.output_spatial_lengths_[0] << ", "
<< arg.output_spatial_lengths_[1] << ", " << std::endl;
std::cout << "Strides " << arg.conv_filter_strides_[0] << ", "
<< arg.conv_filter_strides_[1] << ", " << std::endl;
std::cout << "Dilations " << arg.conv_filter_dilations_[0] << ", "
<< arg.conv_filter_dilations_[1] << ", " << std::endl;
std::cout << "InLeftPads " << arg.input_left_pads_[0] << ", "
<< arg.input_left_pads_[1] << ", " << std::endl;
std::cout << "InLeftPads " << arg.input_right_pads_[0] << ", "
<< arg.input_right_pads_[1] << ", " << std::endl;
}
{
std::cout << "arg.a_grid_desc_k0_m_k1_{" << arg.a_grid_desc_k0_m_k1_.GetLength(I0)
<< ", " << arg.a_grid_desc_k0_m_k1_.GetLength(I1) << ", "
......@@ -618,6 +646,7 @@ struct DeviceConv2dFwdXdl_C_Shuffle_Bias_Activation_Input_N_Hi_Wi_C_Weight_K_Y_X
std::cout << "arg.c0_grid_desc_m_n_{ " << arg.c0_grid_desc_m_n_.GetLength(I0)
<< ", " << arg.c0_grid_desc_m_n_.GetLength(I1) << "}" << std::endl;
}
#endif
if(!GridwiseGemm::CheckValidity(arg.a_grid_desc_k0_m_k1_,
arg.b_grid_desc_k0_n_k1_,
......
......@@ -498,8 +498,11 @@ struct DeviceConv2dFwdXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_W
Conv_N_{N},
Conv_K_{K},
Conv_C_{C},
input_spatial_lengths_{input_spatial_lengths},
filter_spatial_lengths_{filter_spatial_lengths},
output_spatial_lengths_{output_spatial_lengths},
conv_filter_strides_{conv_filter_strides},
conv_filter_dilations_{conv_filter_dilations},
input_left_pads_{input_left_pads},
input_right_pads_{input_right_pads}
{
......@@ -551,8 +554,11 @@ struct DeviceConv2dFwdXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_W
index_t Conv_N_;
index_t Conv_K_;
index_t Conv_C_;
std::vector<index_t> input_spatial_lengths_;
std::vector<index_t> filter_spatial_lengths_;
std::vector<index_t> output_spatial_lengths_;
std::vector<index_t> conv_filter_strides_;
std::vector<index_t> conv_filter_dilations_;
std::vector<index_t> input_left_pads_;
std::vector<index_t> input_right_pads_;
};
......@@ -564,6 +570,28 @@ struct DeviceConv2dFwdXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_W
float Run(const Argument& arg, int nrepeat = 1)
{
#if 0
{
std::cout << DeviceOp{}.GetTypeString() << std::endl;
std::cout << "N " << arg.Conv_N_ << ", "
<< "K " << arg.Conv_K_ << ", "
<< "C " << arg.Conv_C_ << ", " << std::endl;
std::cout << "Y X " << arg.filter_spatial_lengths_[0] << ", "
<< arg.filter_spatial_lengths_[1] << ", " << std::endl;
std::cout << "Hi Wi " << arg.input_spatial_lengths_[0] << ", "
<< arg.input_spatial_lengths_[1] << ", " << std::endl;
std::cout << "Ho Wo " << arg.output_spatial_lengths_[0] << ", "
<< arg.output_spatial_lengths_[1] << ", " << std::endl;
std::cout << "Strides " << arg.conv_filter_strides_[0] << ", "
<< arg.conv_filter_strides_[1] << ", " << std::endl;
std::cout << "Dilations " << arg.conv_filter_dilations_[0] << ", "
<< arg.conv_filter_dilations_[1] << ", " << std::endl;
std::cout << "InLeftPads " << arg.input_left_pads_[0] << ", "
<< arg.input_left_pads_[1] << ", " << std::endl;
std::cout << "InLeftPads " << arg.input_right_pads_[0] << ", "
<< arg.input_right_pads_[1] << ", " << std::endl;
}
{
std::cout << "arg.a_grid_desc_k0_m_k1_{" << arg.a_grid_desc_k0_m_k1_.GetLength(I0)
<< ", " << arg.a_grid_desc_k0_m_k1_.GetLength(I1) << ", "
......@@ -598,6 +626,7 @@ struct DeviceConv2dFwdXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_W
.GetLength(I5)
<< "}" << std::endl;
}
#endif
if(!GridwiseGemm::CheckValidity(arg.a_grid_desc_k0_m_k1_,
arg.b_grid_desc_k0_n_k1_,
......
......@@ -452,6 +452,7 @@ struct DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K
float Run(const Argument& arg, int nrepeat = 1)
{
#if 0
{
std::cout << "arg.a_grid_desc_k0_m_k1_{" << arg.a_grid_desc_k0_m_k1_.GetLength(I0)
<< ", " << arg.a_grid_desc_k0_m_k1_.GetLength(I1) << ", "
......@@ -464,6 +465,7 @@ struct DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K
std::cout << "arg.c_grid_desc_m_n_{ " << arg.c_grid_desc_m_n_.GetLength(I0) << ", "
<< arg.c_grid_desc_m_n_.GetLength(I1) << "}" << std::endl;
}
#endif
if(!GridwiseGemm::CheckValidity(arg.a_grid_desc_k0_m_k1_,
arg.b_grid_desc_k0_n_k1_,
......
#ifndef DEVICE_POOL2D_FWD_HPP
#define DEVICE_POOL2D_FWD_HPP
#include <iostream>
#include <array>
#include <memory>
#include "device_base.hpp"
#include "reduction_enums.hpp"
namespace ck {
namespace tensor_operation {
namespace device {
template <ck::ReduceTensorOp_t ReduceOpId>
struct DevicePool2dFwd : public BaseOperator
{
virtual std::unique_ptr<BaseArgument>
MakeArgumentPointer(const void* in_dev,
void* out_dev,
void* out_indices_dev,
ck::index_t N,
ck::index_t C,
std::array<ck::index_t, 2> input_spatial_lengths,
std::array<ck::index_t, 2> window_spatial_lengths,
std::array<ck::index_t, 2> output_spatial_lengths,
std::array<ck::index_t, 2> window_strides,
std::array<ck::index_t, 2> input_left_pads,
std::array<ck::index_t, 2> input_right_pads) = 0;
virtual std::unique_ptr<BaseInvoker> MakeInvokerPointer() = 0;
};
template <ck::ReduceTensorOp_t ReduceOpId>
using DevicePool2dFwdPtr = std::unique_ptr<DevicePool2dFwd<ReduceOpId>>;
} // namespace device
} // namespace tensor_operation
} // namespace ck
#endif
#ifndef DEVICE_POOL2D_FWD_NHWC_NHWC_HPP
#define DEVICE_POOL2D_FWD_NHWC_NHWC_HPP
#include <iostream>
#include <sstream>
#include "device_pool2d_fwd.hpp"
#include "tensor_descriptor.hpp"
#include "tensor_descriptor_helper.hpp"
#include "reduction_operator_mapping.hpp"
#include "gridwise_2d_reduction_threadwise.hpp"
namespace ck {
namespace tensor_operation {
namespace device {
template <typename InDataType,
typename OutDataType,
typename AccDataType,
ck::ReduceTensorOp_t ReduceOpId,
bool NeedIndices,
ck::index_t BlockSize,
ck::index_t ReduceMThreadClusterSize,
ck::index_t ReduceKThreadClusterSize,
ck::index_t ReduceMThreadSliceSize,
ck::index_t ReduceKThreadSliceSize,
ck::index_t InSrcOutDstVectorSize>
struct DevicePool2dFwd_Input_N_Hi_Wi_C_Output_N_Ho_Wo_C : public DevicePool2dFwd<ReduceOpId>
{
static constexpr auto I0 = Number<0>{};
static constexpr auto I1 = Number<1>{};
static constexpr auto I2 = Number<2>{};
static constexpr auto I3 = Number<3>{};
static constexpr auto I4 = Number<4>{};
static constexpr auto I5 = Number<5>{};
using IndexDataType = int32_t;
using ReduceOperation = typename reduce_binary_operator<AccDataType, ReduceOpId>::opType;
using InElementwiseOperation =
typename reduce_unary_operator<AccDataType, ReduceOpId, true, true>::InElementwiseOperation;
using AccElementwiseOperation =
typename reduce_unary_operator<AccDataType, ReduceOpId, true, true>::
AccElementwiseOperation;
static constexpr bool BetaIsZero = true;
static constexpr index_t InSrcOutDstVectorDim =
0; // for NHWC, dim C is the fastest-varying dimension of both input and output in
   // memory and is not reduced, so it serves as the vector dimension
static constexpr ck::index_t ReduceM_BlockTileSize =
ReduceMThreadClusterSize * ReduceMThreadSliceSize;
static constexpr ck::index_t ReduceK_BlockTileSize =
ReduceKThreadClusterSize * ReduceKThreadSliceSize;
static auto MakeABGridDescriptor_A_M_K_B_M(ck::index_t N,
ck::index_t C,
std::array<ck::index_t, 2> input_spatial_lengths,
std::array<ck::index_t, 2> window_spatial_lengths,
std::array<ck::index_t, 2> output_spatial_lengths,
std::array<ck::index_t, 2> window_strides,
std::array<ck::index_t, 2> input_left_pads,
std::array<ck::index_t, 2> input_right_pads)
{
const index_t Hi = input_spatial_lengths[0];
const index_t Wi = input_spatial_lengths[1];
const index_t Ho = output_spatial_lengths[0];
const index_t Wo = output_spatial_lengths[1];
const index_t Y = window_spatial_lengths[0];
const index_t X = window_spatial_lengths[1];
const index_t ConvStrideH = window_strides[0];
const index_t ConvStrideW = window_strides[1];
const index_t InLeftPadH = input_left_pads[0];
const index_t InLeftPadW = input_left_pads[1];
const index_t InRightPadH = input_right_pads[0];
const index_t InRightPadW = input_right_pads[1];
const index_t ReduceMRaw = N * Ho * Wo * C;
const index_t ReduceMPad =
math::integer_least_multiple(ReduceMRaw, ReduceM_BlockTileSize) - ReduceMRaw;
const index_t ReduceKRaw = Y * X;
const index_t ReduceKPad =
math::integer_least_multiple(ReduceKRaw, ReduceK_BlockTileSize) - ReduceKRaw;
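// Pooling is expressed here as a generic 2d reduction: the invariant dimension ReduceM merges
// (N, Ho, Wo, C) and the reduced dimension ReduceK merges the pooling window (Y, X); both are
// right-padded to multiples of the corresponding block tile sizes.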
// A[ReduceM, ReduceK]
const auto in_grid_desc_n_hi_wi_c =
make_naive_tensor_descriptor_packed(make_tuple(N, Hi, Wi, C));
const auto in_grid_desc_n_hip_wip_c = transform_tensor_descriptor(
in_grid_desc_n_hi_wi_c,
make_tuple(make_pass_through_transform(N),
make_pad_transform(Hi, InLeftPadH, InRightPadH),
make_pad_transform(Wi, InLeftPadW, InRightPadW),
make_pass_through_transform(C)),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}));
const auto in_grid_desc_n_y_ho_x_wo_c = transform_tensor_descriptor(
in_grid_desc_n_hip_wip_c,
make_tuple(make_pass_through_transform(N),
make_embed_transform(make_tuple(Y, Ho), make_tuple(I1, ConvStrideH)),
make_embed_transform(make_tuple(X, Wo), make_tuple(I1, ConvStrideW)),
make_pass_through_transform(C)),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}),
make_tuple(Sequence<0>{}, Sequence<1, 2>{}, Sequence<3, 4>{}, Sequence<5>{}));
const auto in_grid_desc_reducemraw_reducekraw =
transform_tensor_descriptor(in_grid_desc_n_y_ho_x_wo_c,
make_tuple(make_merge_transform(make_tuple(N, Ho, Wo, C)),
make_merge_transform(make_tuple(Y, X))),
make_tuple(Sequence<0, 2, 4, 5>{}, Sequence<1, 3>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}));
const auto in_grid_desc_reducem_reducek = transform_tensor_descriptor(
in_grid_desc_reducemraw_reducekraw,
make_tuple(make_right_pad_transform(ReduceMRaw, ReduceMPad),
make_right_pad_transform(ReduceKRaw, ReduceKPad)),
make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}));
// B[ReduceM]
const auto out_grid_desc_reducemraw =
make_naive_tensor_descriptor_packed(make_tuple(N * Ho * Wo * C));
const auto out_grid_desc_reducem = transform_tensor_descriptor(
out_grid_desc_reducemraw,
make_tuple(make_right_pad_transform(ReduceMRaw, ReduceMPad)),
make_tuple(Sequence<0>{}),
make_tuple(Sequence<0>{}));
return make_tuple(in_grid_desc_reducem_reducek, out_grid_desc_reducem);
}
using ABGridDescs = decltype(
MakeABGridDescriptor_A_M_K_B_M(1, 1, {1, 1}, {1, 1}, {1, 1}, {1, 1}, {1, 1}, {1, 1}));
using AGridDesc_M_K = remove_cvref_t<decltype(ABGridDescs{}[I0])>;
using BGridDesc_M = remove_cvref_t<decltype(ABGridDescs{}[I1])>;
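// The descriptor types are deduced at compile time from a call with dummy lengths; Argument
// below rebuilds the descriptors with the real problem sizes.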
// TODO
struct Argument : public BaseArgument
{
Argument(const InDataType* p_in_dev,
OutDataType* p_out_dev,
int* p_out_indices_dev,
ck::index_t N,
ck::index_t C,
std::array<ck::index_t, 2>& input_spatial_lengths,
std::array<ck::index_t, 2>& window_spatial_lengths,
std::array<ck::index_t, 2>& output_spatial_lengths,
std::array<ck::index_t, 2>& window_strides,
std::array<ck::index_t, 2>& input_left_pads,
std::array<ck::index_t, 2>& input_right_pads)
: p_in_dev_{p_in_dev},
p_out_dev_{p_out_dev},
p_out_indices_dev_{p_out_indices_dev},
a_grid_desc_m_k_{},
b_grid_desc_m_{}
{
const auto descs = MakeABGridDescriptor_A_M_K_B_M(N,
C,
input_spatial_lengths,
window_spatial_lengths,
output_spatial_lengths,
window_strides,
input_left_pads,
input_right_pads);
a_grid_desc_m_k_ = descs[I0];
b_grid_desc_m_ = descs[I1];
invariant_lowest_length_ = C;
reduce_lowest_length_ = window_spatial_lengths[1];
// TODO: is this correct?
if constexpr(ReduceOpId == ck::ReduceTensorOp_t::AVG)
{
ck::index_t divider = window_spatial_lengths[0] * window_spatial_lengths[1];
in_element_op_ = InElementwiseOperation{divider};
acc_element_op_ = AccElementwiseOperation{divider};
}
}
const InDataType* p_in_dev_;
OutDataType* p_out_dev_;
int* p_out_indices_dev_;
AGridDesc_M_K a_grid_desc_m_k_;
BGridDesc_M b_grid_desc_m_;
InElementwiseOperation in_element_op_;
AccElementwiseOperation acc_element_op_;
// for checking vector load/store
ck::index_t invariant_lowest_length_;
ck::index_t reduce_lowest_length_;
};
struct Invoker : public BaseInvoker
{
float Run(const Argument& arg, int nrepeat = 1)
{
using gridwise_reduce = GridwiseReduction_mk_to_m_threadwise<InDataType,
OutDataType,
AccDataType,
IndexDataType,
AGridDesc_M_K,
BGridDesc_M,
ReduceOperation,
InElementwiseOperation,
AccElementwiseOperation,
false, // propagate_nan
BetaIsZero,
BlockSize,
ReduceMThreadClusterSize,
ReduceKThreadClusterSize,
ReduceMThreadSliceSize,
ReduceKThreadSliceSize,
InSrcOutDstVectorDim,
InSrcOutDstVectorSize,
InSrcOutDstVectorSize>;
const auto kernel = kernel_reduce_threadwise<gridwise_reduce,
NeedIndices,
InDataType,
OutDataType,
AccDataType,
IndexDataType,
AGridDesc_M_K,
BGridDesc_M,
InElementwiseOperation,
AccElementwiseOperation>;
ck::index_t ReduceM = arg.a_grid_desc_m_k_.GetLength(I0);
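// ReduceM was padded up to a multiple of ReduceM_BlockTileSize when the grid descriptor was
// built, so this division is exact.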
const index_t grid_size = (ReduceM / ReduceM_BlockTileSize);
return launch_and_time_kernel(kernel,
nrepeat,
dim3(grid_size),
dim3(BlockSize),
0,
arg.a_grid_desc_m_k_,
arg.b_grid_desc_m_,
arg.in_element_op_,
arg.acc_element_op_,
float(1),
arg.p_in_dev_,
float(0),
arg.p_out_dev_,
arg.p_out_indices_dev_);
}
float Run(const BaseArgument* p_arg, int nrepeat = 1) override
{
return Run(*dynamic_cast<const Argument*>(p_arg), nrepeat);
}
};
bool IsSupportedArgument(const BaseArgument* p_arg) override
{
const Argument* pArg = dynamic_cast<const Argument*>(p_arg);
if(pArg->invariant_lowest_length_ % InSrcOutDstVectorSize != 0)
{
return (false);
}
return (true);
}
std::unique_ptr<BaseArgument>
MakeArgumentPointer(const void* p_in_dev,
void* p_out_dev,
void* p_out_indices_dev,
ck::index_t N,
ck::index_t C,
std::array<ck::index_t, 2> input_spatial_lengths,
std::array<ck::index_t, 2> window_spatial_lengths,
std::array<ck::index_t, 2> output_spatial_lengths,
std::array<ck::index_t, 2> window_strides,
std::array<ck::index_t, 2> input_left_pads,
std::array<ck::index_t, 2> input_right_pads) override
{
return std::make_unique<Argument>(static_cast<const InDataType*>(p_in_dev),
static_cast<OutDataType*>(p_out_dev),
static_cast<int*>(p_out_indices_dev),
N,
C,
input_spatial_lengths,
window_spatial_lengths,
output_spatial_lengths,
window_strides,
input_left_pads,
input_right_pads);
}
std::unique_ptr<BaseInvoker> MakeInvokerPointer() override
{
return std::make_unique<Invoker>(Invoker{});
}
std::string GetTypeString() const override
{
auto str = std::stringstream();
// clang-format off
str << "DevicePool2dFwd_Input_N_Hi_Wi_C_Output_N_Ho_Wo_C<" << BlockSize << ",";
str << "M_C" << ReduceMThreadClusterSize << "_S" << ReduceMThreadSliceSize << ",";
str << "K_C" << ReduceKThreadClusterSize << "_S" << ReduceKThreadSliceSize << ",";
str <<"InSrcOutDstVectorSize_" << InSrcOutDstVectorSize << ">";
// clang-format on
return str.str();
}
}; // struct DevicePool2dFwd_Input_N_Hi_Wi_C_Output_N_Ho_Wo_C
} // namespace device
} // namespace tensor_operation
} // namespace ck
#endif
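For orientation, a minimal usage sketch of the pooling device op defined above follows. It is a sketch under stated assumptions, not a validated driver: the header name is inferred from the include guard, the template values are illustrative rather than tuned, the helper name run_max_pool_nhwc is made up for this example, and the device buffers are assumed to be allocated and initialized elsewhere.

// Hypothetical driver for the NHWC max-pooling device op defined above.
#include <array>
#include "device_pool2d_fwd_nhwc_nhwc.hpp" // header name assumed from its include guard

using MaxPoolOp = ck::tensor_operation::device::DevicePool2dFwd_Input_N_Hi_Wi_C_Output_N_Ho_Wo_C<
    float,                     // InDataType
    float,                     // OutDataType
    float,                     // AccDataType
    ck::ReduceTensorOp_t::MAX, // ReduceOpId
    false,                     // NeedIndices
    256,                       // BlockSize
    32,                        // ReduceMThreadClusterSize
    8,                         // ReduceKThreadClusterSize
    1,                         // ReduceMThreadSliceSize
    1,                         // ReduceKThreadSliceSize
    1>;                        // InSrcOutDstVectorSize

float run_max_pool_nhwc(const float* in_dev,
                        float* out_dev,
                        ck::index_t N,
                        ck::index_t C,
                        std::array<ck::index_t, 2> in_hw,
                        std::array<ck::index_t, 2> window_hw,
                        std::array<ck::index_t, 2> out_hw,
                        std::array<ck::index_t, 2> window_strides,
                        std::array<ck::index_t, 2> left_pads,
                        std::array<ck::index_t, 2> right_pads)
{
    MaxPoolOp pool_op;

    // NeedIndices is false, so no indices buffer is passed.
    auto arg_ptr = pool_op.MakeArgumentPointer(in_dev,
                                               out_dev,
                                               nullptr,
                                               N,
                                               C,
                                               in_hw,
                                               window_hw,
                                               out_hw,
                                               window_strides,
                                               left_pads,
                                               right_pads);

    if(!pool_op.IsSupportedArgument(arg_ptr.get()))
        return -1.0f; // e.g. C is not divisible by InSrcOutDstVectorSize

    auto invoker_ptr = pool_op.MakeInvokerPointer();
    return invoker_ptr->Run(arg_ptr.get(), 1); // returns the measured kernel time
}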
#ifndef DEVICE_REDUCE_HPP
#define DEVICE_REDUCE_HPP
#include <vector>
#include <memory>
#include <iostream>
#include "common_header.hpp"
#include "device_base.hpp"
#include "reduction_enums.hpp"
namespace ck {
namespace tensor_operation {
namespace device {
template <typename InElementwiseOperation, typename AccElementwiseOperation>
struct DeviceReduce : public BaseOperator
{
virtual size_t GetWorkspaceSizeInBytes(const std::vector<int>& inLengths)
{
(void)inLengths;
return (0);
};
virtual bool HasFurtherCall() { return (false); };
virtual std::vector<int> GetWorkspace2dLengths(const BaseArgument* argPtr)
{
(void)argPtr;
return (std::vector<int>{0, 0});
};
virtual std::unique_ptr<BaseArgument>
MakeArgumentPointer(const std::vector<int>& inLengths,
const std::vector<int>& inStrides,
const std::vector<int>& outLengths,
const std::vector<int>& outStrides,
float alpha,
float beta,
const void* in_dev,
void* out_dev,
void* out_indices_dev,
void* workspace_dev,
const InElementwiseOperation& inElementwiseOp,
const AccElementwiseOperation& accElementwiseOp) = 0;
virtual std::unique_ptr<BaseInvoker> MakeInvokerPointer() = 0;
};
template <typename InElementwiseOperation, typename AccElementwiseOperation>
using DeviceReducePtr =
std::unique_ptr<DeviceReduce<InElementwiseOperation, AccElementwiseOperation>>;
} // namespace device
} // namespace tensor_operation
} // namespace ck
#endif
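A minimal sketch of how a caller might drive this interface, assuming the DeviceReducePtr was obtained from one of the device_reduce_instance libraries built in the CMake section above (the factory that produces it is not shown here); the helper name, the parameter names, and the included header name are illustrative assumptions.

// Hypothetical helper that runs one DeviceReduce instance end to end.
#include <vector>
#include "device_reduce.hpp" // header name assumed from its include guard

template <typename InElementwiseOperation, typename AccElementwiseOperation>
float run_device_reduce(
    ck::tensor_operation::device::DeviceReducePtr<InElementwiseOperation,
                                                  AccElementwiseOperation>& reduce_ptr,
    const std::vector<int>& inLengths,
    const std::vector<int>& inStrides,
    const std::vector<int>& outLengths,
    const std::vector<int>& outStrides,
    float alpha,
    float beta,
    const void* in_dev,
    void* out_dev,
    void* out_indices_dev,
    void* workspace_dev, // caller allocates at least GetWorkspaceSizeInBytes(inLengths) bytes
    const InElementwiseOperation& in_op,
    const AccElementwiseOperation& acc_op)
{
    auto arg_ptr = reduce_ptr->MakeArgumentPointer(inLengths,
                                                   inStrides,
                                                   outLengths,
                                                   outStrides,
                                                   alpha,
                                                   beta,
                                                   in_dev,
                                                   out_dev,
                                                   out_indices_dev,
                                                   workspace_dev,
                                                   in_op,
                                                   acc_op);

    auto invoker_ptr = reduce_ptr->MakeInvokerPointer();
    float time = invoker_ptr->Run(arg_ptr.get(), 1);

    // Instances that need a second pass report HasFurtherCall() == true; the follow-up
    // reduction then operates on the 2d workspace described by GetWorkspace2dLengths(arg_ptr.get()).
    return time;
}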