Commit 5a9f6308 authored by Qianfeng Zhang
Browse files

Fix with regard to implementing GetZeroVal() in both kernel and host

parent a18e6481
...@@ -82,7 +82,7 @@ struct Max ...@@ -82,7 +82,7 @@ struct Max
{ {
using dataType = T; using dataType = T;
__device__ static T GetZeroVal() { return std::numeric_limits<T>::min(); }; __device__ static T GetZeroVal() { return std::numeric_limits<T>::lowest(); };
__device__ inline constexpr void operator()(T& a, T b) const __device__ inline constexpr void operator()(T& a, T b) const
{ {
...@@ -127,16 +127,45 @@ struct Min ...@@ -127,16 +127,45 @@ struct Min
static constexpr bool indexable = true; static constexpr bool indexable = true;
}; };
// Binary reduction functor that keeps the larger of two values in the accumulator.
// NOTE(review): the name suggests "absolute max"; no absolute value is taken here, so
// presumably the inputs have already been passed through an element-wise abs before
// this functor is applied -- confirm against the callers.
template <class T>
struct AMax
{
    using dataType = T;

    // NOTE(review): presumably tells the reduction framework that this operator can also
    // track the index of the selected element (via the `changed` overload) -- confirm.
    static constexpr bool indexable = true;

    // Initial accumulator value: 0.0f converted to T through type_convert.
    __device__ static T GetZeroVal()
    {
        return type_convert<T>{}(0.0f);
    }

    // Overwrites `a` with `b` when `a < b`; otherwise leaves `a` untouched.
    __device__ inline constexpr void operator()(T& a, T b) const
    {
        if(!(a < b))
            return;

        a = b;
    }

    // Same as above, and additionally sets `changed` to true when `a` is overwritten.
    // `changed` is never reset to false here; the caller owns its initial state.
    __device__ inline constexpr void operator()(T& a, T b, bool& changed) const
    {
        if(!(a < b))
            return;

        a       = b;
        changed = true;
    }
};
template <> template <>
__device__ half_t Max<half_t>::GetZeroVal() __device__ half_t Max<half_t>::GetZeroVal()
{ {
return type_convert<half_t>{}(std::numeric_limits<float>::min()); const unsigned short binary_lowest = 0xFBFF;
return *reinterpret_cast<const half_t*>(&binary_lowest);
}; };
template <> template <>
__device__ half_t Min<half_t>::GetZeroVal() __device__ half_t Min<half_t>::GetZeroVal()
{ {
return type_convert<half_t>{}(std::numeric_limits<float>::max()); const unsigned short binary_max = 0x7BFF;
return *reinterpret_cast<const half_t*>(&binary_max);
}; };
// Unary operators are usually called element-wisely before the reduction is executed on the // Unary operators are usually called element-wisely before the reduction is executed on the
...@@ -281,8 +310,6 @@ struct reduce_binary_operator<T, ReduceTensorOp_t::ADD> ...@@ -281,8 +310,6 @@ struct reduce_binary_operator<T, ReduceTensorOp_t::ADD>
using opType = reduce::Add<T>; using opType = reduce::Add<T>;
using dataType = T; using dataType = T;
__device__ static T GetZeroVal() { return reduce::Add<T>::GetZeroVal(); };
static constexpr bool indexable = reduce::Add<T>::indexable; static constexpr bool indexable = reduce::Add<T>::indexable;
}; };
...@@ -292,8 +319,6 @@ struct reduce_binary_operator<T, ReduceTensorOp_t::MUL> ...@@ -292,8 +319,6 @@ struct reduce_binary_operator<T, ReduceTensorOp_t::MUL>
using opType = reduce::Mul<T>; using opType = reduce::Mul<T>;
using dataType = T; using dataType = T;
__device__ static T GetZeroVal() { return reduce::Mul<T>::GetZeroVal(); };
static constexpr bool indexable = reduce::Mul<T>::indexable; static constexpr bool indexable = reduce::Mul<T>::indexable;
}; };
...@@ -303,8 +328,6 @@ struct reduce_binary_operator<T, ReduceTensorOp_t::MIN> ...@@ -303,8 +328,6 @@ struct reduce_binary_operator<T, ReduceTensorOp_t::MIN>
using opType = reduce::Min<T>; using opType = reduce::Min<T>;
using dataType = T; using dataType = T;
__device__ static T GetZeroVal() { return reduce::Min<T>::GetZeroVal(); };
static constexpr bool indexable = reduce::Min<T>::indexable; static constexpr bool indexable = reduce::Min<T>::indexable;
}; };
...@@ -314,19 +337,15 @@ struct reduce_binary_operator<T, ReduceTensorOp_t::MAX> ...@@ -314,19 +337,15 @@ struct reduce_binary_operator<T, ReduceTensorOp_t::MAX>
using opType = reduce::Max<T>; using opType = reduce::Max<T>;
using dataType = T; using dataType = T;
__device__ static T GetZeroVal() { return reduce::Max<T>::GetZeroVal(); };
static constexpr bool indexable = reduce::Max<T>::indexable; static constexpr bool indexable = reduce::Max<T>::indexable;
}; };
template <typename T> template <typename T>
struct reduce_binary_operator<T, ReduceTensorOp_t::AMAX> struct reduce_binary_operator<T, ReduceTensorOp_t::AMAX>
{ {
using opType = reduce::Max<T>; using opType = reduce::AMax<T>;
using dataType = T; using dataType = T;
__device__ static T GetZeroVal() { return reduce::Max<T>::GetZeroVal(); };
static constexpr bool indexable = reduce::Max<T>::indexable; static constexpr bool indexable = reduce::Max<T>::indexable;
}; };
...@@ -336,8 +355,6 @@ struct reduce_binary_operator<T, ReduceTensorOp_t::AVG> ...@@ -336,8 +355,6 @@ struct reduce_binary_operator<T, ReduceTensorOp_t::AVG>
using opType = reduce::Add<T>; using opType = reduce::Add<T>;
using dataType = T; using dataType = T;
__device__ static T GetZeroVal() { return reduce::Add<T>::GetZeroVal(); };
static constexpr bool indexable = reduce::Add<T>::indexable; static constexpr bool indexable = reduce::Add<T>::indexable;
}; };
...@@ -347,8 +364,6 @@ struct reduce_binary_operator<T, ReduceTensorOp_t::NORM1> ...@@ -347,8 +364,6 @@ struct reduce_binary_operator<T, ReduceTensorOp_t::NORM1>
using opType = reduce::Add<T>; using opType = reduce::Add<T>;
using dataType = T; using dataType = T;
__device__ static T GetZeroVal() { return reduce::Add<T>::GetZeroVal(); };
static constexpr bool indexable = reduce::Add<T>::indexable; static constexpr bool indexable = reduce::Add<T>::indexable;
}; };
...@@ -358,8 +373,6 @@ struct reduce_binary_operator<T, ReduceTensorOp_t::NORM2> ...@@ -358,8 +373,6 @@ struct reduce_binary_operator<T, ReduceTensorOp_t::NORM2>
using opType = reduce::Add<T>; using opType = reduce::Add<T>;
using dataType = T; using dataType = T;
__device__ static T GetZeroVal() { return reduce::Add<T>::GetZeroVal(); };
static constexpr bool indexable = reduce::Add<T>::indexable; static constexpr bool indexable = reduce::Add<T>::indexable;
}; };
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment