Commit f0019df3 authored by Qianfeng Zhang's avatar Qianfeng Zhang
Browse files

Add half_t support to NumericLimits and make constexpr GetZeroVal() of binary operator

parent eac1753d
......@@ -92,7 +92,7 @@ struct GridwiseReduction_xy_to_x_blockwise
// LDS
__shared__ compType p_in_block_buffer[BlockBufferSize];
auto zeroVal = opReduce::GetZeroVal();
constexpr auto zeroVal = opReduce::GetZeroVal();
const auto src_global_buf = make_dynamic_buffer<AddressSpaceEnum_t::Global>(
p_src_global, src2dDesc.GetElementSpaceSize(), type_convert<srcDataType>{}(zeroVal));
......@@ -243,7 +243,7 @@ struct GridwiseReduction_xy_to_x_blockwise
__shared__ compType p_in_block_buffer[BlockBufferSize];
__shared__ int block_indices_buffer[BlockBufferSize];
auto zeroVal = opReduce::GetZeroVal();
constexpr auto zeroVal = opReduce::GetZeroVal();
const auto src_global_buf = make_dynamic_buffer<AddressSpaceEnum_t::Global>(
p_src_global, src2dDesc.GetElementSpaceSize(), type_convert<srcDataType>{}(zeroVal));
......@@ -431,7 +431,7 @@ struct GridwiseReduction_xy_to_x_blockwise
__shared__ compType p_in_block_buffer[BlockBufferSize];
__shared__ int block_indices_buffer[BlockBufferSize];
auto zeroVal = opReduce::GetZeroVal();
constexpr auto zeroVal = opReduce::GetZeroVal();
const auto src_global_val_buf =
make_dynamic_buffer<AddressSpaceEnum_t::Global>(ws_values_global,
......
......@@ -82,7 +82,7 @@ struct GridwiseReduction_xy_to_x_direct_threadwise
(void)ws_indices_global;
(void)indices_global;
const auto zeroVal = opReduce::GetZeroVal();
constexpr auto zeroVal = opReduce::GetZeroVal();
const auto src_global_buf = make_dynamic_buffer<AddressSpaceEnum_t::Global>(
p_src_global, src2dDesc.GetElementSpaceSize(), type_convert<srcDataType>{}(zeroVal));
......@@ -204,7 +204,7 @@ struct GridwiseReduction_xy_to_x_direct_threadwise
{
(void)ws_indices_global;
const auto zeroVal = opReduce::GetZeroVal();
constexpr auto zeroVal = opReduce::GetZeroVal();
const auto src_global_buf = make_dynamic_buffer<AddressSpaceEnum_t::Global>(
p_src_global, src2dDesc.GetElementSpaceSize(), type_convert<srcDataType>{}(zeroVal));
......@@ -348,7 +348,7 @@ struct GridwiseReduction_xy_to_x_direct_threadwise
{
(void)origReduceLen;
const auto zeroVal = opReduce::GetZeroVal();
constexpr auto zeroVal = opReduce::GetZeroVal();
const auto src_global_val_buf =
make_dynamic_buffer<AddressSpaceEnum_t::Global>(ws_values_global,
......
......@@ -82,7 +82,7 @@ struct GridwiseReduction_xy_to_x_direct_warpwise
(void)ws_indices_global;
(void)indices_global;
auto zeroVal = opReduce::GetZeroVal();
constexpr auto zeroVal = opReduce::GetZeroVal();
const auto src_global_buf = make_dynamic_buffer<AddressSpaceEnum_t::Global>(
p_src_global, src2dDesc.GetElementSpaceSize(), type_convert<srcDataType>{}(zeroVal));
......@@ -215,7 +215,7 @@ struct GridwiseReduction_xy_to_x_direct_warpwise
{
(void)ws_indices_global;
auto zeroVal = opReduce::GetZeroVal();
constexpr auto zeroVal = opReduce::GetZeroVal();
const auto src_global_buf = make_dynamic_buffer<AddressSpaceEnum_t::Global>(
p_src_global, src2dDesc.GetElementSpaceSize(), type_convert<srcDataType>{}(zeroVal));
......@@ -373,7 +373,7 @@ struct GridwiseReduction_xy_to_x_direct_warpwise
{
(void)origReduceLen;
auto zeroVal = opReduce::GetZeroVal();
constexpr auto zeroVal = opReduce::GetZeroVal();
const auto src_global_val_buf =
make_dynamic_buffer<AddressSpaceEnum_t::Global>(ws_values_global,
......
......@@ -86,7 +86,7 @@ struct GridwiseReduction_xy_to_x_multiblock
(void)alpha; // unused
(void)beta; // unused
auto zeroVal = opReduce::GetZeroVal();
constexpr auto zeroVal = opReduce::GetZeroVal();
// LDS
__shared__ compType p_in_block_buffer[BlockBufferSize];
......@@ -216,7 +216,7 @@ struct GridwiseReduction_xy_to_x_multiblock
(void)alpha; // unused
(void)beta; // unused
auto zeroVal = opReduce::GetZeroVal();
constexpr auto zeroVal = opReduce::GetZeroVal();
// LDS
__shared__ compType p_in_block_values_buffer[BlockBufferSize];
......
......@@ -1008,20 +1008,27 @@ struct inner_product_with_conversion
};
template <typename T>
struct NumericLimits;
struct NumericLimits
{
__host__ __device__ static constexpr T Min() { return std::numeric_limits<T>::min(); }
__host__ __device__ static constexpr T Max() { return std::numeric_limits<T>::max(); }
__host__ __device__ static constexpr T Lowest() { return std::numeric_limits<T>::lowest(); }
};
template <>
struct NumericLimits<int32_t>
struct NumericLimits<half_t>
{
__host__ __device__ static constexpr int32_t Min()
{
return std::numeric_limits<int32_t>::min();
}
static constexpr unsigned short binary_min = 0x0400;
static constexpr unsigned short binary_max = 0x7BFF;
static constexpr unsigned short binary_lowest = 0xFBFF;
__host__ __device__ static constexpr int32_t Max()
{
return std::numeric_limits<int32_t>::max();
}
__host__ __device__ static constexpr half_t Min() { return as_type<half_t>(binary_min); }
__host__ __device__ static constexpr half_t Max() { return as_type<half_t>(binary_max); }
__host__ __device__ static constexpr half_t Lowest() { return as_type<half_t>(binary_lowest); }
};
} // namespace ck
......
......@@ -58,7 +58,7 @@ struct Add
{
using dataType = T;
__device__ static T GetZeroVal() { return type_convert<T>{}(0.0f); };
__device__ static constexpr T GetZeroVal() { return static_cast<T>(0.0f); };
__device__ inline constexpr void operator()(T& a, T b) const { a = a + b; }
......@@ -70,7 +70,7 @@ struct Mul
{
using dataType = T;
__device__ static T GetZeroVal() { return type_convert<T>{}(1.0f); };
__device__ static constexpr T GetZeroVal() { return static_cast<T>(1.0f); };
__device__ inline constexpr void operator()(T& a, T b) const { a = a * b; }
......@@ -82,7 +82,7 @@ struct Max
{
using dataType = T;
__device__ static T GetZeroVal() { return std::numeric_limits<T>::lowest(); };
__device__ static constexpr T GetZeroVal() { return NumericLimits<T>::lowest(); };
__device__ inline constexpr void operator()(T& a, T b) const
{
......@@ -107,7 +107,7 @@ struct Min
{
using dataType = T;
__device__ static T GetZeroVal() { return std::numeric_limits<T>::max(); };
__device__ static constexpr T GetZeroVal() { return NumericLimits<T>::Max(); };
__device__ inline constexpr void operator()(T& a, T b) const
{
......@@ -132,7 +132,7 @@ struct AMax
{
using dataType = T;
__device__ static T GetZeroVal() { return type_convert<T>{}(0.0f); };
__device__ static constexpr T GetZeroVal() { return static_cast<T>(0.0f); };
__device__ inline constexpr void operator()(T& a, T b) const
{
......@@ -152,22 +152,6 @@ struct AMax
static constexpr bool indexable = true;
};
template <>
__device__ half_t Max<half_t>::GetZeroVal()
{
const unsigned short binary_lowest = 0xFBFF;
return *reinterpret_cast<const half_t*>(&binary_lowest);
};
template <>
__device__ half_t Min<half_t>::GetZeroVal()
{
const unsigned short binary_max = 0x7BFF;
return *reinterpret_cast<const half_t*>(&binary_max);
};
// Unary operators are usually called element-wisely before the reduction is executed on the
// elements.
// They are needed for easy implementation of reduction types of AVG, NRM1, NRM2
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment