Commit 5a9f6308 authored by Qianfeng Zhang's avatar Qianfeng Zhang
Browse files

Fix with regard to implementing GetZeroVal() in both kernel and host

parent a18e6481
......@@ -82,7 +82,7 @@ struct Max
{
using dataType = T;
__device__ static T GetZeroVal() { return std::numeric_limits<T>::min(); };
__device__ static T GetZeroVal() { return std::numeric_limits<T>::lowest(); };
__device__ inline constexpr void operator()(T& a, T b) const
{
......@@ -127,16 +127,45 @@ struct Min
static constexpr bool indexable = true;
};
template <class T>
struct AMax
{
using dataType = T;
__device__ static T GetZeroVal() { return type_convert<T>{}(0.0f); };
__device__ inline constexpr void operator()(T& a, T b) const
{
if(a < b)
a = b;
}
__device__ inline constexpr void operator()(T& a, T b, bool& changed) const
{
if(a < b)
{
a = b;
changed = true;
}
}
static constexpr bool indexable = true;
};
template <>
__device__ half_t Max<half_t>::GetZeroVal()
{
return type_convert<half_t>{}(std::numeric_limits<float>::min());
const unsigned short binary_lowest = 0xFBFF;
return *reinterpret_cast<const half_t*>(&binary_lowest);
};
template <>
__device__ half_t Min<half_t>::GetZeroVal()
{
return type_convert<half_t>{}(std::numeric_limits<float>::max());
const unsigned short binary_max = 0x7BFF;
return *reinterpret_cast<const half_t*>(&binary_max);
};
// Unary operators are usually called element-wisely before the reduction is executed on the
......@@ -281,8 +310,6 @@ struct reduce_binary_operator<T, ReduceTensorOp_t::ADD>
using opType = reduce::Add<T>;
using dataType = T;
__device__ static T GetZeroVal() { return reduce::Add<T>::GetZeroVal(); };
static constexpr bool indexable = reduce::Add<T>::indexable;
};
......@@ -292,8 +319,6 @@ struct reduce_binary_operator<T, ReduceTensorOp_t::MUL>
using opType = reduce::Mul<T>;
using dataType = T;
__device__ static T GetZeroVal() { return reduce::Mul<T>::GetZeroVal(); };
static constexpr bool indexable = reduce::Mul<T>::indexable;
};
......@@ -303,8 +328,6 @@ struct reduce_binary_operator<T, ReduceTensorOp_t::MIN>
using opType = reduce::Min<T>;
using dataType = T;
__device__ static T GetZeroVal() { return reduce::Min<T>::GetZeroVal(); };
static constexpr bool indexable = reduce::Min<T>::indexable;
};
......@@ -314,19 +337,15 @@ struct reduce_binary_operator<T, ReduceTensorOp_t::MAX>
using opType = reduce::Max<T>;
using dataType = T;
__device__ static T GetZeroVal() { return reduce::Max<T>::GetZeroVal(); };
static constexpr bool indexable = reduce::Max<T>::indexable;
};
template <typename T>
struct reduce_binary_operator<T, ReduceTensorOp_t::AMAX>
{
using opType = reduce::Max<T>;
using opType = reduce::AMax<T>;
using dataType = T;
__device__ static T GetZeroVal() { return reduce::Max<T>::GetZeroVal(); };
static constexpr bool indexable = reduce::Max<T>::indexable;
};
......@@ -336,8 +355,6 @@ struct reduce_binary_operator<T, ReduceTensorOp_t::AVG>
using opType = reduce::Add<T>;
using dataType = T;
__device__ static T GetZeroVal() { return reduce::Add<T>::GetZeroVal(); };
static constexpr bool indexable = reduce::Add<T>::indexable;
};
......@@ -347,8 +364,6 @@ struct reduce_binary_operator<T, ReduceTensorOp_t::NORM1>
using opType = reduce::Add<T>;
using dataType = T;
__device__ static T GetZeroVal() { return reduce::Add<T>::GetZeroVal(); };
static constexpr bool indexable = reduce::Add<T>::indexable;
};
......@@ -358,8 +373,6 @@ struct reduce_binary_operator<T, ReduceTensorOp_t::NORM2>
using opType = reduce::Add<T>;
using dataType = T;
__device__ static T GetZeroVal() { return reduce::Add<T>::GetZeroVal(); };
static constexpr bool indexable = reduce::Add<T>::indexable;
};
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment