Commit dfb80c4e authored by Qianfeng, committed by GitHub

[Enhancements] Several bugfixes and refactoring of dynamic generic reduction (#1156)

* Squashed 'src/composable_kernel/' content from commit f6edda61

git-subtree-dir: src/composable_kernel
git-subtree-split: f6edda61

* add solver ConvIgemmFwdV6r1DlopsNchwKcyxNkhw; rename static ck source files

* Squashed 'src/composable_kernel/' changes from f6edda61..5781adf5

5781adf5 Update develop (#5) (#6)
97e6d514 Merge pull request #4 from ROCmSoftwarePlatform/separate_online_compile
7b1ec41e refactor
49c33aae refactor
54b3e73d rename

git-subtree-dir: src/composable_kernel
git-subtree-split: 5781adf5



* fix

* refactor

* remove online compilation from CK

* refactor

* fix

* add ctest

* tidy

* add tidy

* tidy

* tidy

* tidy

* tidy

* tidy

* tidy

* tidy

* tidy

* tidy

* add c-style pointer cast

* vector/scalar pointer casts use C-style pointer casts instead of reinterpret_cast
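  For context, a minimal sketch of the kind of cast this refers to; the vector type and
  function below are illustrative, not the actual CK definitions, and the pointer is assumed
  to be suitably aligned:

      // Illustrative only: reinterpreting a scalar pointer as a vector pointer for a wide load.
      // Casts of this shape were switched from reinterpret_cast to C-style casts in this commit.
      typedef float float4_t __attribute__((ext_vector_type(4)));

      __device__ float4_t load_float4(const float* p)
      {
          return *(const float4_t*)p; // previously: *reinterpret_cast<const float4_t*>(p)
      }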

* fix clang warning suppression

* tidy

* suppress cppcheck

* fix enum issue

* revert changes to hip build

* fix kernel filename

* update CK build script

* rename

* rename

* make inner product compatible on gfx900

* Update src/include/miopen/solver/ck_utility_common.hpp
Co-authored-by: JD <Jehandad.Khan@amd.com>

* compiler parameter use stream
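  The wrappers added in this commit are configured entirely through -D macros
  (the CK_PARAM_* values in the sources below). A hedged sketch of assembling such an
  option string with a stream on the host side; the builder function itself is hypothetical,
  not the MIOpen code:

      #include <sstream>
      #include <string>

      // Illustrative only: build the -D options consumed by the kernel wrappers.
      // The CK_PARAM_* names come from the sources in this commit; this helper does not.
      std::string build_comp_params(int block_size, int reduce_op, int nan_propagate)
      {
          std::ostringstream oss;
          oss << " -DCK_PARAM_BLOCKSIZE=" << block_size
              << " -DCK_PARAM_REDUCE_OP=" << reduce_op
              << " -DCK_PARAM_NAN_PROPAGATE=" << nan_propagate;
          return oss.str();
      }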

* use int instead of index_t in kernel wrapper

* DynamicBuffer, StaticBuffer, amd_buffer_load support customized value for invalid element
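  The idea, in a deliberately simplified and hypothetical form (this is not the CK interface):
  an element load that returns a caller-chosen value, rather than always zero, when the accessed
  element lies outside the valid region. For reductions this matters because the fill value must
  be the identity of the operation (e.g. the lowest representable value for MAX), which ties into
  the GetReductionZeroVal() changes further below.

      // Hypothetical sketch, not the CK signature: out-of-range elements yield a
      // caller-supplied value instead of a hard-coded zero.
      template <typename T>
      __device__ T load_element_or(const T* p, int i, int valid_size, T invalid_element_value)
      {
          return (i < valid_size) ? p[i] : invalid_element_value;
      }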

* refactor

* refactor

* change cmakelist

* change ck common utility

* fix

* Squashed 'src/composable_kernel/' changes from 5781adf5..31b40352

31b40352 Merge pull request #16 from ROCmSoftwarePlatform/develop
b62bf8c3 Merge pull request #14 from ROCmSoftwarePlatform/miopen_downstream_init_integration
ccc4a1d3 Merge pull request #8 from ROCmSoftwarePlatform/miopen_downstream_init_integration
67ad47e7 refactor
16effa76 refactor
a91b68df DynamicBuffer, StaticBuffer, amd_buffer_load support customized value for invalid element
2cbabbba use int instead of index_t in kernel wrapper
0834bc76 compiler parameter use stream
f2ac7832 make innner product compatiable on gfx900
4e57b30a rename
c03045ce rename
b2589957 update CK build script
2c48039d fix kernel filename
d626dccc fix enum issue
643ebd4f tidy
ddd49ec9 fix clang warning suppression
4f566c62 vector/scalar pointer cast use c-style pointer cast instead of reinterpret_cast
172036d7 add c-style pointer cast
76f31319 tidy
d1842890 tidy
f885c131 tidy
80120f0a tidy
c3efeb5e tidy
56fc0842 tidy
54fba515 tidy
e62bae7a tidy
24c87289 add tidy
61487e0a fix
ae98b52a remove online compilation from CK
cb954213 refactor
73ca9701 Merge commit '437cc595c6e206dfebb118985b5171bbc1e29eab' into composable_kernel_init_integration_v3
3b866461 Merge pull request #7 from ROCmSoftwarePlatform/master
d09ea4f4 Update develop (#5)
3d32ae94 add solver ConvIgemmFwdV6r1DlopsNchwKcyxNkhw; rename static ck source files

git-subtree-dir: src/composable_kernel
git-subtree-split: 31b40352



* Small fix to the use of data-type template parameters in the blockwise and direct_threadwise kernels

* Fix the implementation of GetZeroVal() in both the kernel and the host code

* Avoid converting to compType from dstDataType before writing the output value

* Add half_t support to NumericLimits and make GetZeroVal() of the binary operators constexpr
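  For illustration, a constexpr reduction identity on a binary operator might look like the
  sketch below; the struct and names are illustrative, the CK operators differ:

      // Sketch: the "zero value" is the identity element of the binary operation and is
      // constexpr so the kernels can consume it at compile time.
      template <typename T>
      struct AddOpSketch
      {
          static constexpr T GetZeroVal() { return static_cast<T>(0); }

          __device__ void operator()(T& a, T b) const { a = a + b; }
      };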

* Add CONSTANT decorator for descriptor read buffer
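  The CONSTANT decorator appears below on the ws_global parameter of gridwise_generic_reduce_2.
  The sketch that follows is an assumption about what such a decorator typically expands to on
  AMDGPU, not the verified CK definition:

      // Assumption for illustration: CONSTANT marks a pointer as living in the AMDGPU constant
      // address space so descriptor reads can go through scalar/constant loads; the kernels
      // convert it back with cast_pointer_to_generic_address_space() before ordinary use.
      #define CONSTANT_SKETCH __attribute__((address_space(4)))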

* Use get_thread_local_1d_id() for thread local Id
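  What the helper abstracts is just the 1-D thread index within the block; a sketch of its
  shape (the real helper lives in CK, this one is only illustrative):

      // Sketch only: the CK helper used in place of the HIP-specific hipThreadIdx_x macro.
      // Relies on HIP's threadIdx being available in device code.
      __device__ inline int get_thread_local_1d_id_sketch() { return threadIdx.x; }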

* Rename GetZeroVal() to GetReductionZeroVal() in the kernels

* Remove constexpr from the initialized zeroVal and make a small fix in reduction_operator.hpp

* Minor simplifications and updates in the kernel files

* Update src/reducetensor.cpp to pass consistent IDs to the kernel

* Re-order tensor dimensions on the host, split the second_call kernel wrapper files, and simplify the reduce_all kernel wrappers

* Update to remove OpenCL tidy checking failures

* Small updates in src/reducetensor.cpp

* Update for better readability

* Remove unused code and unneeded template parameters in the kernel wrappers

Co-authored-by: Chao Liu <chao.liu2@amd.com>
Co-authored-by: JD <Jehandad.Khan@amd.com>
parent 9e80cdce
/*******************************************************************************
*
* MIT License
*
* Copyright (c) 2021 Advanced Micro Devices, Inc.
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
* in the Software without restriction, including without limitation the rights
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
* copies of the Software, and to permit persons to whom the Software is
* furnished to do so, subject to the following conditions:
*
* The above copyright notice and this permission notice shall be included in all
* copies or substantial portions of the Software.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
* SOFTWARE.
*
*******************************************************************************/
#include "config.hpp"
#include "number.hpp"
#include "sequence.hpp"
#include "tensor_descriptor_helper.hpp"
#include "data_type_enum_helper.hpp"
#include "reduction_common.hpp"
#include "gridwise_generic_2d_reduction_direct_threadwise.hpp"
using namespace ck;
using srcDataType =
typename get_datatype_from_enum<static_cast<DataTypeEnum_t>(CK_PARAM_SRC_DATATYPE)>::type;
using dstDataType =
typename get_datatype_from_enum<static_cast<DataTypeEnum_t>(CK_PARAM_DST_DATATYPE)>::type;
using compType =
typename get_datatype_from_enum<static_cast<DataTypeEnum_t>(CK_PARAM_REDUCE_COMPTYPE)>::type;
constexpr index_t BlockSize = CK_PARAM_BLOCKSIZE; // tunable
using toReduceDims = Sequence<CK_PARAM_TOREDUCE_DIMS>;
using invariantDims = Sequence<CK_PARAM_INVARIANT_DIMS>; // this could be empty
constexpr ReduceTensorOp_t op = static_cast<ReduceTensorOp_t>(CK_PARAM_REDUCE_OP);
constexpr NanPropagation_t nanPropaOpt = CK_PARAM_NAN_PROPAGATE == 0
? NanPropagation_t::NOT_PROPAGATE_NAN
: NanPropagation_t::PROPAGATE_NAN;
constexpr ReduceTensorIndices_t reduceIndicesOpt = CK_PARAM_REDUCE_INDICES == 0
? ReduceTensorIndices_t::NO_INDICES
: ReduceTensorIndices_t::FLATTENED_INDICES;
constexpr bool src2d_need_padding = static_cast<bool>(CK_PARAM_SRC2D_PADDING);
constexpr bool dst1d_need_padding = static_cast<bool>(CK_PARAM_DST1D_PADDING);
constexpr bool indexable = reduce_binary_operator<compType, op>::indexable;
constexpr bool need_indices = indexable && (reduceIndicesOpt != ReduceTensorIndices_t::NO_INDICES);
constexpr index_t GredThreadBufferLength = CK_PARAM_THREAD_BUFFER_LENGTH; // tunable
extern "C" __global__ void
gridwise_generic_reduce_2_prepare(int GridSize, int BlkGroupSize, void* __restrict__ ws_global)
{
(void)BlkGroupSize;
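// Workspace layout written by this prepare kernel and consumed by gridwise_generic_reduce_2
// below: the (possibly padded) src2dDesc is stored at byte offset 0 of ws_global, the
// dst1dDesc at byte offset 2048, and the reduction data begins at byte offset 4096.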
void* p_src2dDesc = ws_global;
void* p_dst1dDesc = static_cast<char*>(ws_global) + 2048;
const auto tupleDstLengths = make_tuple(1);
const auto tupleDstStrides = make_tuple(1);
auto dstDesc = make_naive_tensor_descriptor(tupleDstLengths, tupleDstStrides);
const index_t invariantLen = dstDesc.GetLength(Number<0>{});
const index_t toReduceLen = BlkGroupSize;
auto src2dDesc = make_naive_tensor_descriptor_packed(make_tuple(invariantLen, toReduceLen));
constexpr auto copySliceLen = GredThreadBufferLength;
if constexpr(src2d_need_padding)
{
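// srcPad1 pads the invariant (row) dimension up to the total thread count (GridSize * BlockSize);
// srcPad2 rounds the to-reduce (column) dimension up to a multiple of copySliceLen, the
// per-thread buffer length.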
const auto srcPad1 = GridSize * BlockSize - invariantLen;
const auto srcPad2 =
((toReduceLen + copySliceLen - 1) / copySliceLen) * copySliceLen - toReduceLen;
auto src2dDesc_2 =
transform_tensor_descriptor(src2dDesc,
make_tuple(make_pad_transform(invariantLen, 0, srcPad1),
make_pad_transform(toReduceLen, 0, srcPad2)),
make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}));
if(get_thread_local_1d_id() == 0)
*static_cast<decltype(src2dDesc_2)*>(p_src2dDesc) = src2dDesc_2;
}
else
{
if(get_thread_local_1d_id() == 0)
*static_cast<decltype(src2dDesc)*>(p_src2dDesc) = src2dDesc;
}
if constexpr(dst1d_need_padding)
{
const auto dstPad = GridSize * BlockSize - invariantLen;
auto dst1dDesc_2 =
transform_tensor_descriptor(dstDesc,
make_tuple(make_pad_transform(invariantLen, 0, dstPad)),
make_tuple(Sequence<0>{}),
make_tuple(Sequence<0>{}));
if(get_thread_local_1d_id() == 0)
*static_cast<decltype(dst1dDesc_2)*>(p_dst1dDesc) = dst1dDesc_2;
}
else
{
if(get_thread_local_1d_id() == 0)
*static_cast<decltype(dstDesc)*>(p_dst1dDesc) = dstDesc;
}
};
struct get_ref_desc_types
{
static constexpr auto ref_tupleDstLengths = make_tuple(8);
static constexpr auto ref_dstDesc =
make_naive_tensor_descriptor(ref_tupleDstLengths, ref_tupleDstLengths);
static constexpr index_t ref_invariantLen = ref_dstDesc.GetLength(Number<0>{});
static constexpr index_t ref_toReduceLen = 8;
static constexpr auto ref_src2dDesc =
make_naive_tensor_descriptor_packed(make_tuple(ref_invariantLen, ref_toReduceLen));
using refType_src2dDesc = decltype(ref_src2dDesc);
using refType_dst1dDesc = decltype(ref_dstDesc);
// used by the DirectThreadWise and DirectWarpWise method
using refType_src2dDesc_padded_12 =
decltype(transform_tensor_descriptor(ref_src2dDesc,
make_tuple(make_pad_transform(ref_invariantLen, 0, 2),
make_pad_transform(ref_toReduceLen, 0, 2)),
make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<0>{}, Sequence<1>{})));
using refType_dst1dDesc_padded =
decltype(transform_tensor_descriptor(ref_dstDesc,
make_tuple(make_pad_transform(ref_invariantLen, 0, 2)),
make_tuple(Sequence<0>{}),
make_tuple(Sequence<0>{})));
};
using refType_src2dDesc = typename get_ref_desc_types::refType_src2dDesc;
using refType_dst1dDesc = typename get_ref_desc_types::refType_dst1dDesc;
using refType_src2dDesc_padded_12 = typename get_ref_desc_types::refType_src2dDesc_padded_12;
using refType_dst1dDesc_padded = typename get_ref_desc_types::refType_dst1dDesc_padded;
template <bool need_padding>
static __device__ auto get_reduction_src2d_descriptor(const void* p_src2dDesc)
{
if constexpr(need_padding)
return (*reinterpret_cast<const refType_src2dDesc_padded_12*>(p_src2dDesc));
else
return (*reinterpret_cast<const refType_src2dDesc*>(p_src2dDesc));
};
template <bool need_padding>
static __device__ auto get_reduction_dst1d_descriptor(const void* p_dst1dDesc)
{
if constexpr(need_padding)
return (*reinterpret_cast<const refType_dst1dDesc_padded*>(p_dst1dDesc));
else
return (*reinterpret_cast<const refType_dst1dDesc*>(p_dst1dDesc));
};
extern "C" __global__ void gridwise_generic_reduce_2(int origReduceLen,
float alpha,
const void* __restrict__ p_src_global,
float beta,
void* __restrict__ p_dst_global,
const void CONSTANT* ws_global,
long ws_buf2_bytes_offset,
void* __restrict__ indices_global)
{
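// ws_global was filled by gridwise_generic_reduce_2_prepare: the tensor descriptors sit at byte
// offsets 0 and 2048, and the values reduced in this second pass (the per-block partial results
// of the first pass) start at byte offset 4096. A positive ws_buf2_bytes_offset locates the
// accompanying index buffer relative to that data region.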
(void)p_src_global;
const void* p_src2dDesc = cast_pointer_to_generic_address_space(ws_global);
const void* p_dst1dDesc = static_cast<const char*>(p_src2dDesc) + 2048;
void* ws_buf1_global = const_cast<char*>(static_cast<const char*>(p_src2dDesc) + 4096);
const auto src2dDesc = get_reduction_src2d_descriptor<src2d_need_padding>(p_src2dDesc);
const auto dst1dDesc = get_reduction_dst1d_descriptor<dst1d_need_padding>(p_dst1dDesc);
using gridwise_2d_reduce = GridwiseReduction_xy_to_x_direct_threadwise<BlockSize,
srcDataType,
dstDataType,
compType,
decltype(src2dDesc),
decltype(dst1dDesc),
op,
nanPropaOpt,
reduceIndicesOpt,
false,
true,
GredThreadBufferLength>;
void* const ws_buf2_global =
ws_buf2_bytes_offset > 0
? static_cast<void*>(static_cast<char*>(ws_buf1_global) + ws_buf2_bytes_offset)
: nullptr;
constexpr int RunId = need_indices ? 3 : 1;
gridwise_2d_reduce::template Run<RunId>(
src2dDesc,
dst1dDesc,
origReduceLen,
alpha,
static_cast<const srcDataType* const __restrict__>(ws_buf1_global),
beta,
static_cast<dstDataType* const __restrict__>(p_dst_global),
static_cast<const int* const __restrict__>(ws_buf2_global),
static_cast<int* const __restrict__>(indices_global));
};
@@ -42,12 +42,8 @@ using compType =
 constexpr index_t BlockSize = CK_PARAM_BLOCKSIZE; // tunable

-constexpr index_t srcDims = CK_PARAM_IN_DIMS;
 constexpr index_t dstDims = CK_PARAM_OUT_DIMS;

-using toReduceDims = Sequence<CK_PARAM_TOREDUCE_DIMS>;
-using invariantDims = Sequence<CK_PARAM_INVARIANT_DIMS>; // this could be empty
-
 constexpr ReduceTensorOp_t op = static_cast<ReduceTensorOp_t>(CK_PARAM_REDUCE_OP);
 constexpr NanPropagation_t nanPropaOpt = CK_PARAM_NAN_PROPAGATE == 0
                                              ? NanPropagation_t::NOT_PROPAGATE_NAN
@@ -59,16 +55,6 @@ constexpr ReduceTensorIndices_t reduceIndicesOpt = CK_PARAM_REDUCE_INDICES == 0
 constexpr bool src2d_need_padding = static_cast<bool>(CK_PARAM_SRC2D_PADDING);
 constexpr bool dst1d_need_padding = static_cast<bool>(CK_PARAM_DST1D_PADDING);

-////////////////////////////////////////////////////////////////////////////////////////
-using specDims = typename sequence_merge<invariantDims, toReduceDims>::type;
-
-static_assert(is_valid_sequence_map<specDims>::value && specDims::Size() == srcDims,
-              "Wrong invariant and/or toReduce dimensions!");
-
-// The number of invariant dimensions can be zero if all dimension are to be reduced
-static_assert(invariantDims::Size() > 0 || dstDims == 1,
-              "If all source dimensions are reduced, the dest should have only one dimension !!");
-
 constexpr bool indexable = reduce_binary_operator<compType, op>::indexable;
 constexpr bool need_indices = indexable && (reduceIndicesOpt != ReduceTensorIndices_t::NO_INDICES);
@@ -152,12 +138,12 @@ extern "C" __global__ void gridwise_generic_reduce_2_prepare(int GridSize,
                                                    make_pad_transform(toReduceLen, 0, srcPad2)),
                                         make_tuple(Sequence<0>{}, Sequence<1>{}),
                                         make_tuple(Sequence<0>{}, Sequence<1>{}));
-        if(hipThreadIdx_x == 0)
+        if(get_thread_local_1d_id() == 0)
             *static_cast<decltype(src2dDesc_2)*>(p_src2dDesc) = src2dDesc_2;
     }
     else
     {
-        if(hipThreadIdx_x == 0)
+        if(get_thread_local_1d_id() == 0)
             *static_cast<decltype(src2dDesc)*>(p_src2dDesc) = src2dDesc;
     }

@@ -169,17 +155,17 @@ extern "C" __global__ void gridwise_generic_reduce_2_prepare(int GridSize,
                                         make_tuple(make_pad_transform(invariantLen, 0, dstPad)),
                                         make_tuple(Sequence<0>{}),
                                         make_tuple(Sequence<0>{}));
-        if(hipThreadIdx_x == 0)
+        if(get_thread_local_1d_id() == 0)
             *static_cast<decltype(dst1dDesc_2)*>(p_dst1dDesc) = dst1dDesc_2;
     }
     else
     {
-        if(hipThreadIdx_x == 0)
+        if(get_thread_local_1d_id() == 0)
             *static_cast<decltype(dst1dDesc)*>(p_dst1dDesc) = dst1dDesc;
     }
 };

-template <index_t srcDims, index_t dstDims, typename invariantDims, typename toReduceDims>
+template <index_t dstDims>
 struct get_ref_desc_types
 {
     static constexpr auto ref_tupleDstLengths =
@@ -217,16 +203,11 @@ struct get_ref_desc_types
                                              make_tuple(Sequence<0>{})));
 };

-using refType_src2dDesc =
-    typename get_ref_desc_types<srcDims, dstDims, invariantDims, toReduceDims>::refType_src2dDesc;
-using refType_dst1dDesc =
-    typename get_ref_desc_types<srcDims, dstDims, invariantDims, toReduceDims>::refType_dst1dDesc;
-using refType_src2dDesc_padded_12 =
-    typename get_ref_desc_types<srcDims, dstDims, invariantDims, toReduceDims>::
-        refType_src2dDesc_padded_12;
-using refType_dst1dDesc_padded =
-    typename get_ref_desc_types<srcDims, dstDims, invariantDims, toReduceDims>::
-        refType_dst1dDesc_padded;
+using refType_src2dDesc = typename get_ref_desc_types<dstDims>::refType_src2dDesc;
+using refType_dst1dDesc = typename get_ref_desc_types<dstDims>::refType_dst1dDesc;
+using refType_src2dDesc_padded_12 =
+    typename get_ref_desc_types<dstDims>::refType_src2dDesc_padded_12;
+using refType_dst1dDesc_padded = typename get_ref_desc_types<dstDims>::refType_dst1dDesc_padded;

 template <bool need_padding>
 static __device__ auto get_reduction_src2d_descriptor(const void* p_src2dDesc)
@@ -251,15 +232,15 @@ extern "C" __global__ void gridwise_generic_reduce_2(int origReduceLen,
                                                      const void* __restrict__ p_src_global,
                                                      float beta,
                                                      void* __restrict__ p_dst_global,
-                                                     void* __restrict__ ws_global,
+                                                     const void CONSTANT* ws_global,
                                                      long ws_buf2_bytes_offset,
                                                      void* __restrict__ indices_global)
 {
     (void)p_src_global;

-    const void* p_src2dDesc = ws_global;
-    const void* p_dst1dDesc = static_cast<char*>(ws_global) + 2048;
-    void* ws_buf1_global = static_cast<char*>(ws_global) + 4096;
+    const void* p_src2dDesc = cast_pointer_to_generic_address_space(ws_global);
+    const void* p_dst1dDesc = static_cast<const char*>(p_src2dDesc) + 2048;
+    void* ws_buf1_global = const_cast<char*>(static_cast<const char*>(p_src2dDesc) + 4096);

     const auto src2dDesc = get_reduction_src2d_descriptor<src2d_need_padding>(p_src2dDesc);
     const auto dst1dDesc = get_reduction_dst1d_descriptor<dst1d_need_padding>(p_dst1dDesc);
...
/*******************************************************************************
*
* MIT License
*
* Copyright (c) 2021 Advanced Micro Devices, Inc.
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
* in the Software without restriction, including without limitation the rights
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
* copies of the Software, and to permit persons to whom the Software is
* furnished to do so, subject to the following conditions:
*
* The above copyright notice and this permission notice shall be included in all
* copies or substantial portions of the Software.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
* SOFTWARE.
*
*******************************************************************************/
#include "config.hpp"
#include "number.hpp"
#include "sequence.hpp"
#include "tensor_descriptor_helper.hpp"
#include "data_type_enum_helper.hpp"
#include "reduction_common.hpp"
#include "gridwise_generic_2d_reduction_direct_warpwise.hpp"
using namespace ck;
using srcDataType =
typename get_datatype_from_enum<static_cast<DataTypeEnum_t>(CK_PARAM_SRC_DATATYPE)>::type;
using dstDataType =
typename get_datatype_from_enum<static_cast<DataTypeEnum_t>(CK_PARAM_DST_DATATYPE)>::type;
using compType =
typename get_datatype_from_enum<static_cast<DataTypeEnum_t>(CK_PARAM_REDUCE_COMPTYPE)>::type;
constexpr index_t BlockSize = CK_PARAM_BLOCKSIZE; // tunable
constexpr ReduceTensorOp_t op = static_cast<ReduceTensorOp_t>(CK_PARAM_REDUCE_OP);
constexpr NanPropagation_t nanPropaOpt = CK_PARAM_NAN_PROPAGATE == 0
? NanPropagation_t::NOT_PROPAGATE_NAN
: NanPropagation_t::PROPAGATE_NAN;
constexpr ReduceTensorIndices_t reduceIndicesOpt = CK_PARAM_REDUCE_INDICES == 0
? ReduceTensorIndices_t::NO_INDICES
: ReduceTensorIndices_t::FLATTENED_INDICES;
constexpr bool src2d_need_padding = static_cast<bool>(CK_PARAM_SRC2D_PADDING);
constexpr bool dst1d_need_padding = static_cast<bool>(CK_PARAM_DST1D_PADDING);
constexpr bool indexable = reduce_binary_operator<compType, op>::indexable;
constexpr bool need_indices = indexable && (reduceIndicesOpt != ReduceTensorIndices_t::NO_INDICES);
constexpr index_t GredAccessesPerThreadInWarp = CK_PARAM_ACCESSES_PER_THREAD_INWARP; // tunable
extern "C" __global__ void
gridwise_generic_reduce_2_prepare(int GridSize, int BlkGroupSize, void* __restrict__ ws_global)
{
(void)BlkGroupSize;
void* p_src2dDesc = ws_global;
void* p_dst1dDesc = static_cast<char*>(ws_global) + 2048;
const auto tupleDstLengths = make_tuple(1);
const auto tupleDstStrides = make_tuple(1);
auto dstDesc = make_naive_tensor_descriptor(tupleDstLengths, tupleDstStrides);
const index_t invariantLen = dstDesc.GetLength(Number<0>{});
const index_t toReduceLen = BlkGroupSize;
auto src2dDesc = make_naive_tensor_descriptor_packed(make_tuple(invariantLen, toReduceLen));
constexpr auto copySliceLen = warpSize * GredAccessesPerThreadInWarp;
if constexpr(src2d_need_padding)
{
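// In the warp-wise method each warp reduces one row, so the invariant dimension is padded up to
// the number of warps launched (GridSize * BlockSize / warpSize), and the to-reduce dimension is
// rounded up to a multiple of copySliceLen = warpSize * GredAccessesPerThreadInWarp.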
const auto srcPad1 = GridSize * BlockSize / warpSize - invariantLen;
const auto srcPad2 =
((toReduceLen + copySliceLen - 1) / copySliceLen) * copySliceLen - toReduceLen;
auto src2dDesc_2 =
transform_tensor_descriptor(src2dDesc,
make_tuple(make_pad_transform(invariantLen, 0, srcPad1),
make_pad_transform(toReduceLen, 0, srcPad2)),
make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}));
if(get_thread_local_1d_id() == 0)
*static_cast<decltype(src2dDesc_2)*>(p_src2dDesc) = src2dDesc_2;
}
else
{
if(get_thread_local_1d_id() == 0)
*static_cast<decltype(src2dDesc)*>(p_src2dDesc) = src2dDesc;
}
if constexpr(dst1d_need_padding)
{
const auto dstPad = GridSize * BlockSize / warpSize - invariantLen;
auto dst1dDesc_2 =
transform_tensor_descriptor(dstDesc,
make_tuple(make_pad_transform(invariantLen, 0, dstPad)),
make_tuple(Sequence<0>{}),
make_tuple(Sequence<0>{}));
if(get_thread_local_1d_id() == 0)
*static_cast<decltype(dst1dDesc_2)*>(p_dst1dDesc) = dst1dDesc_2;
}
else
{
if(get_thread_local_1d_id() == 0)
*static_cast<decltype(dstDesc)*>(p_dst1dDesc) = dstDesc;
}
};
struct get_ref_desc_types
{
static constexpr auto ref_tupleDstLengths = make_tuple(8);
static constexpr auto ref_dstDesc =
make_naive_tensor_descriptor(ref_tupleDstLengths, ref_tupleDstLengths);
static constexpr index_t ref_invariantLen = ref_dstDesc.GetLength(Number<0>{});
static constexpr index_t ref_toReduceLen = 8;
static constexpr auto ref_src2dDesc =
make_naive_tensor_descriptor_packed(make_tuple(ref_invariantLen, ref_toReduceLen));
using refType_src2dDesc = decltype(ref_src2dDesc);
using refType_dst1dDesc = decltype(ref_dstDesc);
// used by the DirectThreadWise and DirectWarpWise method
using refType_src2dDesc_padded_12 =
decltype(transform_tensor_descriptor(ref_src2dDesc,
make_tuple(make_pad_transform(ref_invariantLen, 0, 2),
make_pad_transform(ref_toReduceLen, 0, 2)),
make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<0>{}, Sequence<1>{})));
using refType_dst1dDesc_padded =
decltype(transform_tensor_descriptor(ref_dstDesc,
make_tuple(make_pad_transform(ref_invariantLen, 0, 2)),
make_tuple(Sequence<0>{}),
make_tuple(Sequence<0>{})));
};
using refType_src2dDesc = typename get_ref_desc_types::refType_src2dDesc;
using refType_dst1dDesc = typename get_ref_desc_types::refType_dst1dDesc;
using refType_src2dDesc_padded_12 = typename get_ref_desc_types::refType_src2dDesc_padded_12;
using refType_dst1dDesc_padded = typename get_ref_desc_types::refType_dst1dDesc_padded;
template <bool need_padding>
static __device__ auto get_reduction_src2d_descriptor(const void* p_src2dDesc)
{
if constexpr(need_padding)
return (*reinterpret_cast<const refType_src2dDesc_padded_12*>(p_src2dDesc));
else
return (*reinterpret_cast<const refType_src2dDesc*>(p_src2dDesc));
};
template <bool need_padding>
static __device__ auto get_reduction_dst1d_descriptor(const void* p_dst1dDesc)
{
if constexpr(need_padding)
return (*reinterpret_cast<const refType_dst1dDesc_padded*>(p_dst1dDesc));
else
return (*reinterpret_cast<const refType_dst1dDesc*>(p_dst1dDesc));
};
extern "C" __global__ void gridwise_generic_reduce_2(int origReduceLen,
float alpha,
const void* __restrict__ p_src_global,
float beta,
void* __restrict__ p_dst_global,
const void CONSTANT* ws_global,
long ws_buf2_bytes_offset,
void* __restrict__ indices_global)
{
(void)p_src_global;
const void* p_src2dDesc = cast_pointer_to_generic_address_space(ws_global);
const void* p_dst1dDesc = static_cast<const char*>(p_src2dDesc) + 2048;
void* ws_buf1_global = const_cast<char*>(static_cast<const char*>(p_src2dDesc) + 4096);
const auto src2dDesc = get_reduction_src2d_descriptor<src2d_need_padding>(p_src2dDesc);
const auto dst1dDesc = get_reduction_dst1d_descriptor<dst1d_need_padding>(p_dst1dDesc);
using gridwise_2d_reduce =
GridwiseReduction_xy_to_x_direct_warpwise<BlockSize,
srcDataType,
dstDataType,
compType,
decltype(src2dDesc),
decltype(dst1dDesc),
op,
nanPropaOpt,
reduceIndicesOpt,
false,
true,
GredAccessesPerThreadInWarp>;
void* const ws_buf2_global =
ws_buf2_bytes_offset > 0
? static_cast<void*>(static_cast<char*>(ws_buf1_global) + ws_buf2_bytes_offset)
: nullptr;
constexpr int RunId = need_indices ? 3 : 1;
gridwise_2d_reduce::template Run<RunId>(
src2dDesc,
dst1dDesc,
origReduceLen,
alpha,
static_cast<const srcDataType* const __restrict__>(ws_buf1_global),
beta,
static_cast<dstDataType* const __restrict__>(p_dst_global),
static_cast<const int* const __restrict__>(ws_buf2_global),
static_cast<int* const __restrict__>(indices_global));
};
@@ -42,12 +42,8 @@ using compType =
 constexpr index_t BlockSize = CK_PARAM_BLOCKSIZE; // tunable

-constexpr index_t srcDims = CK_PARAM_IN_DIMS;
 constexpr index_t dstDims = CK_PARAM_OUT_DIMS;

-using toReduceDims = Sequence<CK_PARAM_TOREDUCE_DIMS>;
-using invariantDims = Sequence<CK_PARAM_INVARIANT_DIMS>; // this could be empty
-
 constexpr ReduceTensorOp_t op = static_cast<ReduceTensorOp_t>(CK_PARAM_REDUCE_OP);
 constexpr NanPropagation_t nanPropaOpt = CK_PARAM_NAN_PROPAGATE == 0
                                              ? NanPropagation_t::NOT_PROPAGATE_NAN
@@ -59,16 +55,6 @@ constexpr ReduceTensorIndices_t reduceIndicesOpt = CK_PARAM_REDUCE_INDICES == 0
 constexpr bool src2d_need_padding = static_cast<bool>(CK_PARAM_SRC2D_PADDING);
 constexpr bool dst1d_need_padding = static_cast<bool>(CK_PARAM_DST1D_PADDING);

-////////////////////////////////////////////////////////////////////////////////////////
-using specDims = typename sequence_merge<invariantDims, toReduceDims>::type;
-
-static_assert(is_valid_sequence_map<specDims>::value && specDims::Size() == srcDims,
-              "Wrong invariant and/or toReduce dimensions!");
-
-// The number of invariant dimensions can be zero if all dimension are to be reduced
-static_assert(invariantDims::Size() > 0 || dstDims == 1,
-              "If all source dimensions are reduced, the dest should have only one dimension !!");
-
 constexpr bool indexable = reduce_binary_operator<compType, op>::indexable;
 constexpr bool need_indices = indexable && (reduceIndicesOpt != ReduceTensorIndices_t::NO_INDICES);
@@ -153,12 +139,12 @@ extern "C" __global__ void gridwise_generic_reduce_2_prepare(int GridSize,
                                                    make_pad_transform(toReduceLen, 0, srcPad2)),
                                         make_tuple(Sequence<0>{}, Sequence<1>{}),
                                         make_tuple(Sequence<0>{}, Sequence<1>{}));
-        if(hipThreadIdx_x == 0)
+        if(get_thread_local_1d_id() == 0)
             *static_cast<decltype(src2dDesc_2)*>(p_src2dDesc) = src2dDesc_2;
     }
     else
     {
-        if(hipThreadIdx_x == 0)
+        if(get_thread_local_1d_id() == 0)
             *static_cast<decltype(src2dDesc)*>(p_src2dDesc) = src2dDesc;
     }

@@ -170,17 +156,17 @@ extern "C" __global__ void gridwise_generic_reduce_2_prepare(int GridSize,
                                         make_tuple(make_pad_transform(invariantLen, 0, dstPad)),
                                         make_tuple(Sequence<0>{}),
                                         make_tuple(Sequence<0>{}));
-        if(hipThreadIdx_x == 0)
+        if(get_thread_local_1d_id() == 0)
             *static_cast<decltype(dst1dDesc_2)*>(p_dst1dDesc) = dst1dDesc_2;
     }
     else
     {
-        if(hipThreadIdx_x == 0)
+        if(get_thread_local_1d_id() == 0)
             *static_cast<decltype(dst1dDesc)*>(p_dst1dDesc) = dst1dDesc;
     }
 };

-template <index_t srcDims, index_t dstDims, typename invariantDims, typename toReduceDims>
+template <index_t dstDims>
 struct get_ref_desc_types
 {
     static constexpr auto ref_tupleDstLengths =
@@ -218,16 +204,11 @@ struct get_ref_desc_types
                                              make_tuple(Sequence<0>{})));
 };

-using refType_src2dDesc =
-    typename get_ref_desc_types<srcDims, dstDims, invariantDims, toReduceDims>::refType_src2dDesc;
-using refType_dst1dDesc =
-    typename get_ref_desc_types<srcDims, dstDims, invariantDims, toReduceDims>::refType_dst1dDesc;
-using refType_src2dDesc_padded_12 =
-    typename get_ref_desc_types<srcDims, dstDims, invariantDims, toReduceDims>::
-        refType_src2dDesc_padded_12;
-using refType_dst1dDesc_padded =
-    typename get_ref_desc_types<srcDims, dstDims, invariantDims, toReduceDims>::
-        refType_dst1dDesc_padded;
+using refType_src2dDesc = typename get_ref_desc_types<dstDims>::refType_src2dDesc;
+using refType_dst1dDesc = typename get_ref_desc_types<dstDims>::refType_dst1dDesc;
+using refType_src2dDesc_padded_12 =
+    typename get_ref_desc_types<dstDims>::refType_src2dDesc_padded_12;
+using refType_dst1dDesc_padded = typename get_ref_desc_types<dstDims>::refType_dst1dDesc_padded;

 template <bool need_padding>
 static __device__ auto get_reduction_src2d_descriptor(const void* p_src2dDesc)
@@ -252,15 +233,15 @@ extern "C" __global__ void gridwise_generic_reduce_2(int origReduceLen,
                                                      const void* __restrict__ p_src_global,
                                                      float beta,
                                                      void* __restrict__ p_dst_global,
-                                                     void* __restrict__ ws_global,
+                                                     const void CONSTANT* ws_global,
                                                      long ws_buf2_bytes_offset,
                                                      void* __restrict__ indices_global)
 {
     (void)p_src_global;

-    const void* p_src2dDesc = ws_global;
-    const void* p_dst1dDesc = static_cast<char*>(ws_global) + 2048;
-    void* ws_buf1_global = static_cast<char*>(ws_global) + 4096;
+    const void* p_src2dDesc = cast_pointer_to_generic_address_space(ws_global);
+    const void* p_dst1dDesc = static_cast<const char*>(p_src2dDesc) + 2048;
+    void* ws_buf1_global = const_cast<char*>(static_cast<const char*>(p_src2dDesc) + 4096);

     const auto src2dDesc = get_reduction_src2d_descriptor<src2d_need_padding>(p_src2dDesc);
     const auto dst1dDesc = get_reduction_dst1d_descriptor<dst1d_need_padding>(p_dst1dDesc);
...