Commit 7a3b49e5 authored by Chao Liu's avatar Chao Liu
Browse files

Merge remote-tracking branch 'origin/develop' into contraction

parents e07b3d8e d3051d75
#ifndef IS_KNOWN_AT_COMPILE_TIME_HPP // SPDX-License-Identifier: MIT
#define IS_KNOWN_AT_COMPILE_TIME_HPP // Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
#include "config.hpp" #pragma once
#include "ck/ck.hpp"
#include "integral_constant.hpp" #include "integral_constant.hpp"
#include "sequence.hpp" #include "sequence.hpp"
#include "tuple.hpp" #include "tuple.hpp"
...@@ -52,4 +54,3 @@ struct is_known_at_compile_time<Tuple<Ts...>> ...@@ -52,4 +54,3 @@ struct is_known_at_compile_time<Tuple<Ts...>>
}; };
} // namespace ck } // namespace ck
#endif
#ifndef CK_MAGIC_DIVISION_HPP // SPDX-License-Identifier: MIT
#define CK_MAGIC_DIVISION_HPP // Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
#include "config.hpp" #pragma once
#include "ck/ck.hpp"
#include "integral_constant.hpp" #include "integral_constant.hpp"
#include "number.hpp" #include "number.hpp"
#include "type.hpp" #include "type.hpp"
...@@ -156,5 +158,3 @@ struct MagicDivision ...@@ -156,5 +158,3 @@ struct MagicDivision
}; };
} // namespace ck } // namespace ck
#endif
#ifndef CK_MATH_HPP // SPDX-License-Identifier: MIT
#define CK_MATH_HPP // Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
#include "config.hpp" #pragma once
#include "ck/ck.hpp"
#include "integral_constant.hpp" #include "integral_constant.hpp"
#include "number.hpp" #include "number.hpp"
#include "type.hpp" #include "type.hpp"
...@@ -142,6 +144,22 @@ __host__ __device__ constexpr auto min(X x, Ys... ys) ...@@ -142,6 +144,22 @@ __host__ __device__ constexpr auto min(X x, Ys... ys)
return min(x, min(ys...)); return min(x, min(ys...));
} }
// disallow implicit type casting
template <typename T>
__device__ T exp(T x);
template <>
__device__ float exp<float>(float x)
{
return __expf(x);
}
template <>
__device__ double exp<double>(double x)
{
return exp(x);
}
// greatest common divisor, aka highest common factor // greatest common divisor, aka highest common factor
__host__ __device__ constexpr index_t gcd(index_t x, index_t y) __host__ __device__ constexpr index_t gcd(index_t x, index_t y)
{ {
...@@ -212,5 +230,3 @@ struct less ...@@ -212,5 +230,3 @@ struct less
} // namespace math } // namespace math
} // namespace ck } // namespace ck
#endif
#ifndef CK_MATH_V2_HPP // SPDX-License-Identifier: MIT
#define CK_MATH_V2_HPP // Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include <cmath> #include <cmath>
#include "data_type.hpp"
#include "half.hpp" #include "ck/utility/data_type.hpp"
#include "ck/utility/type.hpp"
namespace ck { namespace ck {
namespace math { namespace math {
// math functions for the host, some are implemented by calling C++ std functions
static inline __host__ float abs(float x) { return std::abs(x); }; static inline __host__ float abs(float x) { return std::abs(x); };
static inline __host__ double abs(double x) { return std::abs(x); }; static inline __host__ double abs(double x) { return std::abs(x); };
...@@ -28,26 +33,26 @@ static inline __host__ int32_t abs(int32_t x) ...@@ -28,26 +33,26 @@ static inline __host__ int32_t abs(int32_t x)
static inline __host__ half_t abs(half_t x) static inline __host__ half_t abs(half_t x)
{ {
half_float::half xx = *reinterpret_cast<half_float::half*>(&x); uint16_t xx = ck::bit_cast<uint16_t>(x);
half_float::half abs_xx = half_float::abs(xx); uint16_t abs_xx = xx & 0x7fff;
half_t abs_x = *reinterpret_cast<half_t*>(&abs_xx); half_t abs_x = ck::bit_cast<half_t>(abs_xx);
return abs_x; return abs_x;
}; };
static inline __host__ float isnan(float x) { return std::isnan(x); }; static inline __host__ bool isnan(float x) { return std::isnan(x); };
static inline __host__ double isnan(double x) { return std::isnan(x); }; static inline __host__ bool isnan(double x) { return std::isnan(x); };
static inline __host__ int8_t isnan(int8_t x) static inline __host__ bool isnan(int8_t x)
{ {
(void)x; (void)x;
return false; return false;
}; };
static inline __host__ int32_t isnan(int32_t x) static inline __host__ bool isnan(int32_t x)
{ {
(void)x; (void)x;
return false; return false;
...@@ -55,12 +60,58 @@ static inline __host__ int32_t isnan(int32_t x) ...@@ -55,12 +60,58 @@ static inline __host__ int32_t isnan(int32_t x)
static inline __host__ bool isnan(half_t x) static inline __host__ bool isnan(half_t x)
{ {
half_float::half xx = *reinterpret_cast<half_float::half*>(&x); uint16_t xx = ck::bit_cast<uint16_t>(x);
return half_float::isnan(xx); return (xx & 0x7FFF) > 0x7C00;
}; };
static inline __host__ float sqrt(float x) { return std::sqrt(x); };
static inline __host__ double sqrt(double x) { return std::sqrt(x); };
// math functions for the HIP kernel, some are implemented by calling hip builtin functions
static inline __device__ float abs(float x) { return ::abs(x); };
static inline __device__ double abs(double x) { return ::abs(x); };
static inline __device__ int8_t abs(int8_t x)
{
int8_t sgn = x >> (8 - 1);
return (x ^ sgn) - sgn;
};
static inline __device__ int32_t abs(int32_t x)
{
int32_t sgn = x >> (32 - 1);
return (x ^ sgn) - sgn;
};
static inline __device__ half_t abs(half_t x) { return ::__habs(x); };
static inline __device__ bool isnan(float x) { return ::isnan(x); };
static inline __device__ bool isnan(double x) { return ::isnan(x); };
static inline __device__ bool isnan(int8_t x)
{
(void)x;
return false;
};
static inline __device__ bool isnan(int32_t x)
{
(void)x;
return false;
};
static inline __device__ bool isnan(half_t x) { return ::__hisnan(x); };
static inline __device__ float sqrt(float x) { return ::sqrtf(x); };
static inline __device__ double sqrt(double x) { return ::sqrt(x); };
} // namespace math } // namespace math
} // namespace ck } // namespace ck
#endif
#ifndef CK_MULTI_INDEX_HPP // SPDX-License-Identifier: MIT
#define CK_MULTI_INDEX_HPP // Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "common_header.hpp" #include "common_header.hpp"
...@@ -8,5 +10,3 @@ ...@@ -8,5 +10,3 @@
#else #else
#include "statically_indexed_array_multi_index.hpp" #include "statically_indexed_array_multi_index.hpp"
#endif #endif
#endif
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
#ifndef CK_NUMBER_HPP #ifndef CK_NUMBER_HPP
#define CK_NUMBER_HPP #define CK_NUMBER_HPP
......
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
#ifndef CK_PRINT_HPP #ifndef CK_PRINT_HPP
#define CK_PRINT_HPP #define CK_PRINT_HPP
......
/******************************************************************************* // SPDX-License-Identifier: MIT
* // Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
* MIT License
*
* Copyright (c) 2020 Advanced Micro Devices, Inc.
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
* in the Software without restriction, including without limitation the rights
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
* copies of the Software, and to permit persons to whom the Software is
* furnished to do so, subject to the following conditions:
*
* The above copyright notice and this permission notice shall be included in all
* copies or substantial portions of the Software.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
* SOFTWARE.
*
*******************************************************************************/
#ifndef CK_REDUCTION_COMMON_HPP
#define CK_REDUCTION_COMMON_HPP
#include "reduction_enums.hpp" #pragma once
#include "ck/utility/reduction_enums.hpp"
namespace ck { namespace ck {
...@@ -60,6 +37,4 @@ constexpr __device__ index_t get_shift<1>() ...@@ -60,6 +37,4 @@ constexpr __device__ index_t get_shift<1>()
return (0); return (0);
} }
}; // end of namespace ck } // namespace ck
#endif
/******************************************************************************* // SPDX-License-Identifier: MIT
* // Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
* MIT License
* #pragma once
* Copyright (c) 2020 Advanced Micro Devices, Inc.
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
* in the Software without restriction, including without limitation the rights
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
* copies of the Software, and to permit persons to whom the Software is
* furnished to do so, subject to the following conditions:
*
* The above copyright notice and this permission notice shall be included in all
* copies or substantial portions of the Software.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
* SOFTWARE.
*
*******************************************************************************/
#ifndef CK_REDUCTION_ENUMS_HPP
#define CK_REDUCTION_ENUMS_HPP
namespace ck { namespace ck {
...@@ -61,6 +38,4 @@ enum struct IndicesType ...@@ -61,6 +38,4 @@ enum struct IndicesType
INDICES_8BIT = 3, INDICES_8BIT = 3,
}; };
}; // end of namespace ck } // namespace ck
#endif
/******************************************************************************* // SPDX-License-Identifier: MIT
* // Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
* MIT License
* #pragma once
* Copyright (c) 2020 Advanced Micro Devices, Inc.
* #include "ck/utility/data_type.hpp"
* Permission is hereby granted, free of charge, to any person obtaining a copy #include "ck/utility/math_v2.hpp"
* of this software and associated documentation files (the "Software"), to deal #include "ck/utility/reduction_common.hpp"
* in the Software without restriction, including without limitation the rights #include "ck/utility/reduction_operator.hpp"
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
* copies of the Software, and to permit persons to whom the Software is
* furnished to do so, subject to the following conditions:
*
* The above copyright notice and this permission notice shall be included in all
* copies or substantial portions of the Software.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
* SOFTWARE.
*
*******************************************************************************/
#ifndef CK_REDUCTION_FUNCTIONS_BINOP_HPP
#define CK_REDUCTION_FUNCTIONS_BINOP_HPP
#include "data_type.hpp"
#include "reduction_common.hpp"
#include "reduction_operator.hpp"
namespace ck { namespace ck {
namespace detail { namespace detail {
template <typename T> // Check for NaN; guarantee NaNs are NOT propagated to result (i.e., ignore NaNs)
static inline __device__ bool is_nan(T x) template <typename ReduceOperation, typename AccDataType>
{ struct AccumulateWithNanIgnore
return (isnan(x));
};
template <>
inline __device__ bool is_nan<half_t>(half_t x)
{ {
return (__hisnan(x)); __device__ static inline void Calculate(AccDataType& accuVal, AccDataType currVal)
{
if(!isnan(currVal))
{
ReduceOperation{}(accuVal, currVal);
}
};
}; };
template <bool PropagateNan, typename ReduceOperation, typename AccDataType> template <bool PropagateNan, typename ReduceOperation, typename AccDataType>
struct AccumulateWithNanCheck; struct AccumulateWithNanCheck;
// Does not check for NaN; does not guarantee NaNs be propagated to result
// e.g., given that max(a, b) = a > b ? a : b
// then max(NaN, 1) returns 1
// max(1, NaN) returns NaN
// since any comparison involving NaNs returns false
template <typename ReduceOperation, typename AccDataType> template <typename ReduceOperation, typename AccDataType>
struct AccumulateWithNanCheck<false, ReduceOperation, AccDataType> struct AccumulateWithNanCheck<false, ReduceOperation, AccDataType>
{ {
// cppcheck-suppress constParameter // cppcheck-suppress constParameter
__device__ static inline void Calculate(AccDataType& accuVal, AccDataType currVal) __host__ __device__ static inline void Calculate(AccDataType& accuVal, AccDataType currVal)
{ {
ReduceOperation{}(accuVal, currVal); ReduceOperation{}(accuVal, currVal);
}; };
}; };
// Check for NaN; guarantees NaNs be propagated to result
template <typename ReduceOperation, typename AccDataType> template <typename ReduceOperation, typename AccDataType>
struct AccumulateWithNanCheck<true, ReduceOperation, AccDataType> struct AccumulateWithNanCheck<true, ReduceOperation, AccDataType>
{ {
__device__ static inline void Calculate(AccDataType& accuVal, AccDataType currVal) __host__ __device__ static inline void Calculate(AccDataType& accuVal, AccDataType currVal)
{ {
if(is_nan(currVal)) using ck::math::isnan;
if(isnan(currVal))
{ {
accuVal = currVal; accuVal = currVal;
} }
...@@ -81,7 +67,7 @@ struct AccumulateWithIndexAndNanCheck; ...@@ -81,7 +67,7 @@ struct AccumulateWithIndexAndNanCheck;
template <typename ReduceOperation, typename AccDataType, typename IndexDataType> template <typename ReduceOperation, typename AccDataType, typename IndexDataType>
struct AccumulateWithIndexAndNanCheck<false, ReduceOperation, AccDataType, IndexDataType> struct AccumulateWithIndexAndNanCheck<false, ReduceOperation, AccDataType, IndexDataType>
{ {
__device__ static inline void __host__ __device__ static inline void
// cppcheck-suppress constParameter // cppcheck-suppress constParameter
Calculate(AccDataType& accuVal, Calculate(AccDataType& accuVal,
AccDataType currVal, AccDataType currVal,
...@@ -101,12 +87,14 @@ template <typename ReduceOperation, typename AccDataType, typename IndexDataType ...@@ -101,12 +87,14 @@ template <typename ReduceOperation, typename AccDataType, typename IndexDataType
struct AccumulateWithIndexAndNanCheck<true, ReduceOperation, AccDataType, IndexDataType> struct AccumulateWithIndexAndNanCheck<true, ReduceOperation, AccDataType, IndexDataType>
{ {
// The method is called when the ReduceOperation is indexable and the user asked for indices // The method is called when the ReduceOperation is indexable and the user asked for indices
__device__ static inline void Calculate(AccDataType& accuVal, __host__ __device__ static inline void Calculate(AccDataType& accuVal,
AccDataType currVal, AccDataType currVal,
IndexDataType& accuIndex, IndexDataType& accuIndex,
IndexDataType currIndex) IndexDataType currIndex)
{ {
if(is_nan(currVal)) using ck::math::isnan;
if(isnan(currVal))
{ {
accuVal = currVal; accuVal = currVal;
accuIndex = currIndex; accuIndex = currIndex;
...@@ -123,7 +111,5 @@ struct AccumulateWithIndexAndNanCheck<true, ReduceOperation, AccDataType, IndexD ...@@ -123,7 +111,5 @@ struct AccumulateWithIndexAndNanCheck<true, ReduceOperation, AccDataType, IndexD
}; };
}; };
}; // namespace detail } // namespace detail
}; // end of namespace ck } // namespace ck
#endif
/******************************************************************************* // SPDX-License-Identifier: MIT
* // Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
* MIT License
* #pragma once
* Copyright (c) 2020 Advanced Micro Devices, Inc.
* #include "ck/ck.hpp"
* Permission is hereby granted, free of charge, to any person obtaining a copy #include "ck/utility/data_type.hpp"
* of this software and associated documentation files (the "Software"), to deal #include "ck/utility/type.hpp"
* in the Software without restriction, including without limitation the rights
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
* copies of the Software, and to permit persons to whom the Software is
* furnished to do so, subject to the following conditions:
*
* The above copyright notice and this permission notice shall be included in all
* copies or substantial portions of the Software.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
* SOFTWARE.
*
*******************************************************************************/
#ifndef CK_REDUCTION_OPERATOR_HPP
#define CK_REDUCTION_OPERATOR_HPP
#include "config.hpp"
#include "data_type.hpp"
namespace ck { namespace ck {
...@@ -36,7 +14,7 @@ namespace reduce { ...@@ -36,7 +14,7 @@ namespace reduce {
// Every binary operator used in reduction is represented by a templated functor class. Each functor // Every binary operator used in reduction is represented by a templated functor class. Each functor
// class must provide at least // class must provide at least
// three members: // three members:
// 1) GetReductionZeroVal() -- the interface to return the "identity element" for the binary // 1) GetIdentityValue() -- the interface to return the "identity element" for the binary
// operator, "identity element" is the unique // operator, "identity element" is the unique
// element in the algebraic space that doesn't affect the value of other elements // element in the algebraic space that doesn't affect the value of other elements
// when operated against them, and the concept is similar to zero vector in // when operated against them, and the concept is similar to zero vector in
...@@ -54,64 +32,92 @@ namespace reduce { ...@@ -54,64 +32,92 @@ namespace reduce {
// accumulated index also need be // accumulated index also need be
// changed. // changed.
template <class T>
struct Add struct Add
{ {
using dataType = T; template <typename T>
__host__ __device__ static constexpr T GetIdentityValue()
__host__ __device__ static constexpr T GetReductionZeroVal() { return static_cast<T>(0.0f); }; {
return type_convert<T>(0.0f);
};
__device__ static constexpr bool __host__ __device__ static constexpr bool
IsCompatibleInMemoryDataOperation(InMemoryDataOperationEnum operation) IsCompatibleInMemoryDataOperation(InMemoryDataOperationEnum operation)
{ {
return operation == InMemoryDataOperationEnum::AtomicAdd || return operation == InMemoryDataOperationEnum::AtomicAdd ||
operation == InMemoryDataOperationEnum::Set; operation == InMemoryDataOperationEnum::Set;
}; };
__host__ __device__ inline constexpr void operator()(T& a, T b) const { a = a + b; } template <typename T>
__host__ __device__ inline constexpr void operator()(T& a, T b) const
{
static_assert(is_same<T, float>::value || is_same<T, double>::value ||
is_same<T, int32_t>::value,
"The data type is not supported by the Add accumulator!");
a = a + b;
}
}; };
template <class T>
struct Mul struct Mul
{ {
using dataType = T; template <typename T>
__host__ __device__ static constexpr T GetIdentityValue()
__host__ __device__ static constexpr T GetReductionZeroVal() { return static_cast<T>(1.0f); }; {
return type_convert<T>(1.0f);
};
__device__ static constexpr bool __host__ __device__ static constexpr bool
IsCompatibleInMemoryDataOperation(InMemoryDataOperationEnum operation) IsCompatibleInMemoryDataOperation(InMemoryDataOperationEnum operation)
{ {
return operation == InMemoryDataOperationEnum::Set; return operation == InMemoryDataOperationEnum::Set;
}; };
__host__ __device__ inline constexpr void operator()(T& a, T b) const { a = a * b; } template <typename T>
__host__ __device__ inline constexpr void operator()(T& a, T b) const
{
static_assert(is_same<T, float>::value || is_same<T, double>::value ||
is_same<T, int32_t>::value,
"The data type is not supported by the Mul accumulator!");
a = a * b;
}
}; };
template <class T>
struct Max struct Max
{ {
using dataType = T; template <typename T>
__host__ __device__ static constexpr T GetIdentityValue()
__host__ __device__ static constexpr T GetReductionZeroVal()
{ {
return NumericLimits<T>::Lowest(); return NumericLimits<T>::Lowest();
}; };
__device__ static constexpr bool __host__ __device__ static constexpr bool
IsCompatibleInMemoryDataOperation(InMemoryDataOperationEnum operation) IsCompatibleInMemoryDataOperation(InMemoryDataOperationEnum operation)
{ {
// ToChange: atomic_max to be added // ToChange: atomic_max to be added
return operation == InMemoryDataOperationEnum::Set; return operation == InMemoryDataOperationEnum::Set;
}; };
template <typename T>
__host__ __device__ inline constexpr void operator()(T& a, T b) const __host__ __device__ inline constexpr void operator()(T& a, T b) const
{ {
static_assert(is_same<T, float>::value || is_same<T, double>::value ||
is_same<T, half_t>::value || is_same<T, int32_t>::value ||
is_same<T, int8_t>::value,
"The data type is not supported by the Max accumulator!");
if(a < b) if(a < b)
a = b; a = b;
} }
template <typename T>
__host__ __device__ inline constexpr void operator()(T& a, T b, bool& changed) const __host__ __device__ inline constexpr void operator()(T& a, T b, bool& changed) const
{ {
static_assert(is_same<T, float>::value || is_same<T, double>::value ||
is_same<T, half_t>::value || is_same<T, int32_t>::value ||
is_same<T, int8_t>::value,
"The data type is not supported by the Max accumulator!");
if(a < b) if(a < b)
{ {
a = b; a = b;
...@@ -120,31 +126,41 @@ struct Max ...@@ -120,31 +126,41 @@ struct Max
} }
}; };
template <class T>
struct Min struct Min
{ {
using dataType = T; template <typename T>
__host__ __device__ static constexpr T GetIdentityValue()
__host__ __device__ static constexpr T GetReductionZeroVal()
{ {
return NumericLimits<T>::Max(); return NumericLimits<T>::Max();
}; };
__device__ static constexpr bool __host__ __device__ static constexpr bool
IsCompatibleInMemoryDataOperation(InMemoryDataOperationEnum operation) IsCompatibleInMemoryDataOperation(InMemoryDataOperationEnum operation)
{ {
// ToChange: atomic_min to be added // ToChange: atomic_min to be added
return operation == InMemoryDataOperationEnum::Set; return operation == InMemoryDataOperationEnum::Set;
}; };
template <typename T>
__host__ __device__ inline constexpr void operator()(T& a, T b) const __host__ __device__ inline constexpr void operator()(T& a, T b) const
{ {
static_assert(is_same<T, float>::value || is_same<T, double>::value ||
is_same<T, half_t>::value || is_same<T, int32_t>::value ||
is_same<T, int8_t>::value,
"The data type is not supported by the Min accumulator!");
if(a > b) if(a > b)
a = b; a = b;
} }
template <typename T>
__host__ __device__ inline constexpr void operator()(T& a, T b, bool& changed) const __host__ __device__ inline constexpr void operator()(T& a, T b, bool& changed) const
{ {
static_assert(is_same<T, float>::value || is_same<T, double>::value ||
is_same<T, half_t>::value || is_same<T, int32_t>::value ||
is_same<T, int8_t>::value,
"The data type is not supported by the Min accumulator!");
if(a > b) if(a > b)
{ {
a = b; a = b;
...@@ -153,28 +169,41 @@ struct Min ...@@ -153,28 +169,41 @@ struct Min
} }
}; };
template <class T>
struct AMax struct AMax
{ {
using dataType = T; template <typename T>
__host__ __device__ static constexpr T GetIdentityValue()
__host__ __device__ static constexpr T GetReductionZeroVal() { return static_cast<T>(0.0f); }; {
return type_convert<T>(0.0f);
};
__device__ static constexpr bool __host__ __device__ static constexpr bool
IsCompatibleInMemoryDataOperation(InMemoryDataOperationEnum operation) IsCompatibleInMemoryDataOperation(InMemoryDataOperationEnum operation)
{ {
// ToChange: atomic_max to be added // ToChange: atomic_max to be added
return operation == InMemoryDataOperationEnum::Set; return operation == InMemoryDataOperationEnum::Set;
}; };
template <typename T>
__host__ __device__ inline constexpr void operator()(T& a, T b) const __host__ __device__ inline constexpr void operator()(T& a, T b) const
{ {
static_assert(is_same<T, float>::value || is_same<T, double>::value ||
is_same<T, half_t>::value || is_same<T, int32_t>::value ||
is_same<T, int8_t>::value,
"The data type is not supported by the AMax accumulator!");
if(a < b) if(a < b)
a = b; a = b;
} }
template <typename T>
__host__ __device__ inline constexpr void operator()(T& a, T b, bool& changed) const __host__ __device__ inline constexpr void operator()(T& a, T b, bool& changed) const
{ {
static_assert(is_same<T, float>::value || is_same<T, double>::value ||
is_same<T, half_t>::value || is_same<T, int32_t>::value ||
is_same<T, int8_t>::value,
"The data type is not supported by the AMax accumulator!");
if(a < b) if(a < b)
{ {
a = b; a = b;
...@@ -184,7 +213,7 @@ struct AMax ...@@ -184,7 +213,7 @@ struct AMax
}; };
template <typename T> template <typename T>
T GetReductionZeroValueForInMemoryDataOperation(InMemoryDataOperationEnum operation) constexpr T GetIdentityValueForInMemoryDataOperation(InMemoryDataOperationEnum operation)
{ {
T result = ck::type_convert<T>(0.0f); T result = ck::type_convert<T>(0.0f);
...@@ -194,8 +223,43 @@ T GetReductionZeroValueForInMemoryDataOperation(InMemoryDataOperationEnum operat ...@@ -194,8 +223,43 @@ T GetReductionZeroValueForInMemoryDataOperation(InMemoryDataOperationEnum operat
return (result); return (result);
}; };
}; // end of namespace reduce template <InMemoryDataOperationEnum Operation, typename DataType>
struct InMemoryDataOperatonSupportedOnDataType
{
static constexpr bool value = false;
};
template <typename DataType>
struct InMemoryDataOperatonSupportedOnDataType<InMemoryDataOperationEnum::AtomicAdd, DataType>
{
static constexpr bool value =
is_same<DataType, float>::value || is_same<DataType, double>::value;
};
template <typename DataType>
struct InMemoryDataOperatonSupportedOnDataType<InMemoryDataOperationEnum::AtomicMax, DataType>
{
static constexpr bool value =
is_same<DataType, float>::value || is_same<DataType, double>::value;
};
} // end of namespace ck template <typename DataType>
struct InMemoryDataOperatonSupportedOnDataType<InMemoryDataOperationEnum::Set, DataType>
{
static constexpr bool value =
is_same<DataType, float>::value || is_same<DataType, double>::value ||
is_same<DataType, half_t>::value || is_same<DataType, bhalf_t>::value ||
is_same<DataType, int8_t>::value || is_same<DataType, int32_t>::value;
};
template <typename DataType>
struct InMemoryDataOperatonSupportedOnDataType<InMemoryDataOperationEnum::Add, DataType>
{
static constexpr bool value =
is_same<DataType, float>::value || is_same<DataType, double>::value ||
is_same<DataType, half_t>::value || is_same<DataType, int8_t>::value ||
is_same<DataType, int32_t>::value;
};
#endif } // namespace reduce
} // namespace ck
#ifndef CK_SEQUENCE_HPP // SPDX-License-Identifier: MIT
#define CK_SEQUENCE_HPP // Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "integral_constant.hpp" #include "integral_constant.hpp"
#include "type.hpp" #include "type.hpp"
...@@ -241,7 +243,13 @@ struct arithmetic_sequence_gen ...@@ -241,7 +243,13 @@ struct arithmetic_sequence_gen
} }
}; };
using type = typename sequence_gen<(IEnd - IBegin) / Increment, F>::type; using type0 = typename sequence_gen<(IEnd - IBegin) / Increment, F>::type;
using type1 = Sequence<>;
static constexpr bool kHasContent =
(Increment > 0 && IBegin < IEnd) || (Increment < 0 && IBegin > IEnd);
using type = typename conditional<kHasContent, type0, type1>::type;
}; };
// uniform sequence // uniform sequence
...@@ -882,5 +890,10 @@ __host__ __device__ constexpr bool sequence_all_of(Seq, F f) ...@@ -882,5 +890,10 @@ __host__ __device__ constexpr bool sequence_all_of(Seq, F f)
return flag; return flag;
} }
template <typename Sx, typename Sy>
using sequence_merge_t = typename sequence_merge<Sx, Sy>::type;
template <index_t NSize, index_t I>
using uniform_sequence_gen_t = typename uniform_sequence_gen<NSize, I>::type;
} // namespace ck } // namespace ck
#endif
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
#pragma once #pragma once
#include "tuple.hpp" #include "tuple.hpp"
......
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
#ifndef CK_STATIC_BUFFER_HPP #ifndef CK_STATIC_BUFFER_HPP
#define CK_STATIC_BUFFER_HPP #define CK_STATIC_BUFFER_HPP
......
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
#ifndef CK_STATICALLY_INDEXED_ARRAY_HPP #ifndef CK_STATICALLY_INDEXED_ARRAY_HPP
#define CK_STATICALLY_INDEXED_ARRAY_HPP #define CK_STATICALLY_INDEXED_ARRAY_HPP
......
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
#ifndef CK_STATICALLY_INDEXED_ARRAY_MULTI_INDEX_HPP #ifndef CK_STATICALLY_INDEXED_ARRAY_MULTI_INDEX_HPP
#define CK_STATICALLY_INDEXED_ARRAY_MULTI_INDEX_HPP #define CK_STATICALLY_INDEXED_ARRAY_MULTI_INDEX_HPP
......
#ifndef CK_SYNCHRONIZATION_AMD_HPP // SPDX-License-Identifier: MIT
#define CK_SYNCHRONIZATION_AMD_HPP // Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
#include "config.hpp" #pragma once
#include "ck/ck.hpp"
namespace ck { namespace ck {
...@@ -18,4 +20,3 @@ __device__ void block_sync_lds() ...@@ -18,4 +20,3 @@ __device__ void block_sync_lds()
} }
} // namespace ck } // namespace ck
#endif
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
#pragma once #pragma once
#include "get_id.hpp" #include "get_id.hpp"
......
#ifndef CK_TRANSPOSE_VECTORS_AMD_HPP // SPDX-License-Identifier: MIT
#define CK_TRANSPOSE_VECTORS_AMD_HPP // Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
#include "config.hpp" #pragma once
#include "ck/ck.hpp"
#include "statically_indexed_array.hpp" #include "statically_indexed_array.hpp"
#include "data_type.hpp" #include "data_type.hpp"
...@@ -165,4 +167,3 @@ struct transpose_vectors<int8_t, NX, NY> ...@@ -165,4 +167,3 @@ struct transpose_vectors<int8_t, NX, NY>
}; };
} // namespace ck } // namespace ck
#endif
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
#pragma once #pragma once
#include "integral_constant.hpp" #include "integral_constant.hpp"
...@@ -16,14 +19,18 @@ struct TupleElementKey ...@@ -16,14 +19,18 @@ struct TupleElementKey
}; };
template <typename Key, typename Data> template <typename Key, typename Data>
struct TupleElement struct TupleElementKeyData
{ {
__host__ __device__ constexpr TupleElement() = default; #if 0 // workaround compiler complaint about implicitly-deleted default constructor
__host__ __device__ constexpr TupleElementKeyData() = default;
template < #else
typename T, __host__ __device__ constexpr TupleElementKeyData() : mData{} {}
typename enable_if<!is_same<remove_cvref_t<T>, TupleElement>::value, bool>::type = false> #endif
__host__ __device__ constexpr TupleElement(T&& v) : mData(std::forward<T>(v))
template <typename T,
typename enable_if<!is_same<remove_cvref_t<T>, TupleElementKeyData>::value,
bool>::type = false>
__host__ __device__ constexpr TupleElementKeyData(T&& v) : mData(std::forward<T>(v))
{ {
} }
...@@ -31,20 +38,21 @@ struct TupleElement ...@@ -31,20 +38,21 @@ struct TupleElement
}; };
template <typename Key, typename Data> template <typename Key, typename Data>
__host__ __device__ constexpr const Data& get_tuple_element(const TupleElement<Key, Data>& x) __host__ __device__ constexpr const Data&
get_tuple_element_data(const TupleElementKeyData<Key, Data>& x)
{ {
return static_cast<const Data&>(x.mData); return static_cast<const Data&>(x.mData);
} }
template <typename Key, typename Data> template <typename Key, typename Data>
__host__ __device__ constexpr Data& get_tuple_element(TupleElement<Key, Data>& x) __host__ __device__ constexpr Data& get_tuple_element_data(TupleElementKeyData<Key, Data>& x)
{ {
return x.mData; return x.mData;
} }
// TODO: not sure the use of reference is correct // TODO: not sure the use of reference is correct
template <typename Key, typename Data> template <typename Key, typename Data>
__host__ __device__ constexpr Data&& get_tuple_element(TupleElement<Key, Data>&& x) __host__ __device__ constexpr Data&& get_tuple_element_data(TupleElementKeyData<Key, Data>&& x)
{ {
return static_cast<Data&&>(x.mData); return static_cast<Data&&>(x.mData);
} }
...@@ -53,7 +61,7 @@ template <typename Indices, typename... Xs> ...@@ -53,7 +61,7 @@ template <typename Indices, typename... Xs>
struct TupleImpl; struct TupleImpl;
template <index_t... Is, typename... Xs> template <index_t... Is, typename... Xs>
struct TupleImpl<Sequence<Is...>, Xs...> : TupleElement<TupleElementKey<Is>, Xs>... struct TupleImpl<Sequence<Is...>, Xs...> : TupleElementKeyData<TupleElementKey<Is>, Xs>...
{ {
__host__ __device__ constexpr TupleImpl() = default; __host__ __device__ constexpr TupleImpl() = default;
...@@ -62,13 +70,13 @@ struct TupleImpl<Sequence<Is...>, Xs...> : TupleElement<TupleElementKey<Is>, Xs> ...@@ -62,13 +70,13 @@ struct TupleImpl<Sequence<Is...>, Xs...> : TupleElement<TupleElementKey<Is>, Xs>
!is_same<remove_cvref_t<Y>, TupleImpl>::value, !is_same<remove_cvref_t<Y>, TupleImpl>::value,
bool>::type = false> bool>::type = false>
__host__ __device__ constexpr TupleImpl(Y&& y) __host__ __device__ constexpr TupleImpl(Y&& y)
: TupleElement<TupleElementKey<Is>, Xs>(std::forward<Y>(y))... : TupleElementKeyData<TupleElementKey<Is>, Xs>(std::forward<Y>(y))...
{ {
} }
template <typename... Ys, typename enable_if<sizeof...(Ys) >= 2, bool>::type = false> template <typename... Ys, typename enable_if<sizeof...(Ys) >= 2, bool>::type = false>
__host__ __device__ constexpr TupleImpl(Ys&&... ys) __host__ __device__ constexpr TupleImpl(Ys&&... ys)
: TupleElement<TupleElementKey<Is>, Xs>(std::forward<Ys>(ys))... : TupleElementKeyData<TupleElementKey<Is>, Xs>(std::forward<Ys>(ys))...
{ {
static_assert(sizeof...(Is) == sizeof...(Xs) && sizeof...(Is) == sizeof...(Ys), static_assert(sizeof...(Is) == sizeof...(Xs) && sizeof...(Is) == sizeof...(Ys),
"wrong! inconsistent size"); "wrong! inconsistent size");
...@@ -77,15 +85,15 @@ struct TupleImpl<Sequence<Is...>, Xs...> : TupleElement<TupleElementKey<Is>, Xs> ...@@ -77,15 +85,15 @@ struct TupleImpl<Sequence<Is...>, Xs...> : TupleElement<TupleElementKey<Is>, Xs>
__host__ __device__ static constexpr index_t Size() { return sizeof...(Xs); } __host__ __device__ static constexpr index_t Size() { return sizeof...(Xs); }
template <index_t I> template <index_t I>
__host__ __device__ constexpr const auto& GetElementByKey(TupleElementKey<I>) const __host__ __device__ constexpr const auto& GetElementDataByKey(TupleElementKey<I>) const
{ {
return get_tuple_element<TupleElementKey<I>>(*this); return get_tuple_element_data<TupleElementKey<I>>(*this);
} }
template <index_t I> template <index_t I>
__host__ __device__ constexpr auto& GetElementByKey(TupleElementKey<I>) __host__ __device__ constexpr auto& GetElementDataByKey(TupleElementKey<I>)
{ {
return get_tuple_element<TupleElementKey<I>>(*this); return get_tuple_element_data<TupleElementKey<I>>(*this);
} }
}; };
...@@ -120,7 +128,7 @@ struct Tuple : detail::TupleImpl<typename arithmetic_sequence_gen<0, sizeof...(X ...@@ -120,7 +128,7 @@ struct Tuple : detail::TupleImpl<typename arithmetic_sequence_gen<0, sizeof...(X
__host__ __device__ constexpr const auto& At(Number<I>) const __host__ __device__ constexpr const auto& At(Number<I>) const
{ {
static_assert(I < base::Size(), "wrong! out of range"); static_assert(I < base::Size(), "wrong! out of range");
return base::GetElementByKey(detail::TupleElementKey<I>{}); return base::GetElementDataByKey(detail::TupleElementKey<I>{});
} }
// write access // write access
...@@ -128,7 +136,7 @@ struct Tuple : detail::TupleImpl<typename arithmetic_sequence_gen<0, sizeof...(X ...@@ -128,7 +136,7 @@ struct Tuple : detail::TupleImpl<typename arithmetic_sequence_gen<0, sizeof...(X
__host__ __device__ constexpr auto& At(Number<I>) __host__ __device__ constexpr auto& At(Number<I>)
{ {
static_assert(I < base::Size(), "wrong! out of range"); static_assert(I < base::Size(), "wrong! out of range");
return base::GetElementByKey(detail::TupleElementKey<I>{}); return base::GetElementDataByKey(detail::TupleElementKey<I>{});
} }
// read access // read access
...@@ -158,6 +166,31 @@ struct Tuple : detail::TupleImpl<typename arithmetic_sequence_gen<0, sizeof...(X ...@@ -158,6 +166,31 @@ struct Tuple : detail::TupleImpl<typename arithmetic_sequence_gen<0, sizeof...(X
__host__ __device__ static constexpr bool IsStaticBuffer() { return true; } __host__ __device__ static constexpr bool IsStaticBuffer() { return true; }
}; };
template <>
struct Tuple<>
{
__host__ __device__ constexpr Tuple() = default;
__host__ __device__ static constexpr index_t Size() { return 0; }
template <typename T>
__host__ __device__ constexpr auto operator=(const T&)
{
return *this;
}
__host__ __device__ static constexpr bool IsStaticBuffer() { return true; }
};
template <index_t I, typename TTuple>
struct tuple_element
{
using type = decltype(TTuple{}.At(Number<I>{}));
};
template <index_t I, typename TTuple>
using tuple_element_t = typename tuple_element<I, TTuple>::type;
template <typename... Xs> template <typename... Xs>
__host__ __device__ constexpr auto make_tuple(Xs&&... xs) __host__ __device__ constexpr auto make_tuple(Xs&&... xs)
{ {
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment