Commit f0019df3 authored by Qianfeng Zhang's avatar Qianfeng Zhang
Browse files

Add half_t support to NumericLimits and make constexpr GetZeroVal() of binary operator

parent eac1753d
...@@ -92,7 +92,7 @@ struct GridwiseReduction_xy_to_x_blockwise ...@@ -92,7 +92,7 @@ struct GridwiseReduction_xy_to_x_blockwise
// LDS // LDS
__shared__ compType p_in_block_buffer[BlockBufferSize]; __shared__ compType p_in_block_buffer[BlockBufferSize];
auto zeroVal = opReduce::GetZeroVal(); constexpr auto zeroVal = opReduce::GetZeroVal();
const auto src_global_buf = make_dynamic_buffer<AddressSpaceEnum_t::Global>( const auto src_global_buf = make_dynamic_buffer<AddressSpaceEnum_t::Global>(
p_src_global, src2dDesc.GetElementSpaceSize(), type_convert<srcDataType>{}(zeroVal)); p_src_global, src2dDesc.GetElementSpaceSize(), type_convert<srcDataType>{}(zeroVal));
...@@ -243,7 +243,7 @@ struct GridwiseReduction_xy_to_x_blockwise ...@@ -243,7 +243,7 @@ struct GridwiseReduction_xy_to_x_blockwise
__shared__ compType p_in_block_buffer[BlockBufferSize]; __shared__ compType p_in_block_buffer[BlockBufferSize];
__shared__ int block_indices_buffer[BlockBufferSize]; __shared__ int block_indices_buffer[BlockBufferSize];
auto zeroVal = opReduce::GetZeroVal(); constexpr auto zeroVal = opReduce::GetZeroVal();
const auto src_global_buf = make_dynamic_buffer<AddressSpaceEnum_t::Global>( const auto src_global_buf = make_dynamic_buffer<AddressSpaceEnum_t::Global>(
p_src_global, src2dDesc.GetElementSpaceSize(), type_convert<srcDataType>{}(zeroVal)); p_src_global, src2dDesc.GetElementSpaceSize(), type_convert<srcDataType>{}(zeroVal));
...@@ -431,7 +431,7 @@ struct GridwiseReduction_xy_to_x_blockwise ...@@ -431,7 +431,7 @@ struct GridwiseReduction_xy_to_x_blockwise
__shared__ compType p_in_block_buffer[BlockBufferSize]; __shared__ compType p_in_block_buffer[BlockBufferSize];
__shared__ int block_indices_buffer[BlockBufferSize]; __shared__ int block_indices_buffer[BlockBufferSize];
auto zeroVal = opReduce::GetZeroVal(); constexpr auto zeroVal = opReduce::GetZeroVal();
const auto src_global_val_buf = const auto src_global_val_buf =
make_dynamic_buffer<AddressSpaceEnum_t::Global>(ws_values_global, make_dynamic_buffer<AddressSpaceEnum_t::Global>(ws_values_global,
......
...@@ -82,7 +82,7 @@ struct GridwiseReduction_xy_to_x_direct_threadwise ...@@ -82,7 +82,7 @@ struct GridwiseReduction_xy_to_x_direct_threadwise
(void)ws_indices_global; (void)ws_indices_global;
(void)indices_global; (void)indices_global;
const auto zeroVal = opReduce::GetZeroVal(); constexpr auto zeroVal = opReduce::GetZeroVal();
const auto src_global_buf = make_dynamic_buffer<AddressSpaceEnum_t::Global>( const auto src_global_buf = make_dynamic_buffer<AddressSpaceEnum_t::Global>(
p_src_global, src2dDesc.GetElementSpaceSize(), type_convert<srcDataType>{}(zeroVal)); p_src_global, src2dDesc.GetElementSpaceSize(), type_convert<srcDataType>{}(zeroVal));
...@@ -204,7 +204,7 @@ struct GridwiseReduction_xy_to_x_direct_threadwise ...@@ -204,7 +204,7 @@ struct GridwiseReduction_xy_to_x_direct_threadwise
{ {
(void)ws_indices_global; (void)ws_indices_global;
const auto zeroVal = opReduce::GetZeroVal(); constexpr auto zeroVal = opReduce::GetZeroVal();
const auto src_global_buf = make_dynamic_buffer<AddressSpaceEnum_t::Global>( const auto src_global_buf = make_dynamic_buffer<AddressSpaceEnum_t::Global>(
p_src_global, src2dDesc.GetElementSpaceSize(), type_convert<srcDataType>{}(zeroVal)); p_src_global, src2dDesc.GetElementSpaceSize(), type_convert<srcDataType>{}(zeroVal));
...@@ -348,7 +348,7 @@ struct GridwiseReduction_xy_to_x_direct_threadwise ...@@ -348,7 +348,7 @@ struct GridwiseReduction_xy_to_x_direct_threadwise
{ {
(void)origReduceLen; (void)origReduceLen;
const auto zeroVal = opReduce::GetZeroVal(); constexpr auto zeroVal = opReduce::GetZeroVal();
const auto src_global_val_buf = const auto src_global_val_buf =
make_dynamic_buffer<AddressSpaceEnum_t::Global>(ws_values_global, make_dynamic_buffer<AddressSpaceEnum_t::Global>(ws_values_global,
......
...@@ -82,7 +82,7 @@ struct GridwiseReduction_xy_to_x_direct_warpwise ...@@ -82,7 +82,7 @@ struct GridwiseReduction_xy_to_x_direct_warpwise
(void)ws_indices_global; (void)ws_indices_global;
(void)indices_global; (void)indices_global;
auto zeroVal = opReduce::GetZeroVal(); constexpr auto zeroVal = opReduce::GetZeroVal();
const auto src_global_buf = make_dynamic_buffer<AddressSpaceEnum_t::Global>( const auto src_global_buf = make_dynamic_buffer<AddressSpaceEnum_t::Global>(
p_src_global, src2dDesc.GetElementSpaceSize(), type_convert<srcDataType>{}(zeroVal)); p_src_global, src2dDesc.GetElementSpaceSize(), type_convert<srcDataType>{}(zeroVal));
...@@ -215,7 +215,7 @@ struct GridwiseReduction_xy_to_x_direct_warpwise ...@@ -215,7 +215,7 @@ struct GridwiseReduction_xy_to_x_direct_warpwise
{ {
(void)ws_indices_global; (void)ws_indices_global;
auto zeroVal = opReduce::GetZeroVal(); constexpr auto zeroVal = opReduce::GetZeroVal();
const auto src_global_buf = make_dynamic_buffer<AddressSpaceEnum_t::Global>( const auto src_global_buf = make_dynamic_buffer<AddressSpaceEnum_t::Global>(
p_src_global, src2dDesc.GetElementSpaceSize(), type_convert<srcDataType>{}(zeroVal)); p_src_global, src2dDesc.GetElementSpaceSize(), type_convert<srcDataType>{}(zeroVal));
...@@ -373,7 +373,7 @@ struct GridwiseReduction_xy_to_x_direct_warpwise ...@@ -373,7 +373,7 @@ struct GridwiseReduction_xy_to_x_direct_warpwise
{ {
(void)origReduceLen; (void)origReduceLen;
auto zeroVal = opReduce::GetZeroVal(); constexpr auto zeroVal = opReduce::GetZeroVal();
const auto src_global_val_buf = const auto src_global_val_buf =
make_dynamic_buffer<AddressSpaceEnum_t::Global>(ws_values_global, make_dynamic_buffer<AddressSpaceEnum_t::Global>(ws_values_global,
......
...@@ -86,7 +86,7 @@ struct GridwiseReduction_xy_to_x_multiblock ...@@ -86,7 +86,7 @@ struct GridwiseReduction_xy_to_x_multiblock
(void)alpha; // unused (void)alpha; // unused
(void)beta; // unused (void)beta; // unused
auto zeroVal = opReduce::GetZeroVal(); constexpr auto zeroVal = opReduce::GetZeroVal();
// LDS // LDS
__shared__ compType p_in_block_buffer[BlockBufferSize]; __shared__ compType p_in_block_buffer[BlockBufferSize];
...@@ -216,7 +216,7 @@ struct GridwiseReduction_xy_to_x_multiblock ...@@ -216,7 +216,7 @@ struct GridwiseReduction_xy_to_x_multiblock
(void)alpha; // unused (void)alpha; // unused
(void)beta; // unused (void)beta; // unused
auto zeroVal = opReduce::GetZeroVal(); constexpr auto zeroVal = opReduce::GetZeroVal();
// LDS // LDS
__shared__ compType p_in_block_values_buffer[BlockBufferSize]; __shared__ compType p_in_block_values_buffer[BlockBufferSize];
......
...@@ -1008,20 +1008,27 @@ struct inner_product_with_conversion ...@@ -1008,20 +1008,27 @@ struct inner_product_with_conversion
}; };
template <typename T> template <typename T>
struct NumericLimits; struct NumericLimits
{
__host__ __device__ static constexpr T Min() { return std::numeric_limits<T>::min(); }
__host__ __device__ static constexpr T Max() { return std::numeric_limits<T>::max(); }
__host__ __device__ static constexpr T Lowest() { return std::numeric_limits<T>::lowest(); }
};
template <> template <>
struct NumericLimits<int32_t> struct NumericLimits<half_t>
{ {
__host__ __device__ static constexpr int32_t Min() static constexpr unsigned short binary_min = 0x0400;
{ static constexpr unsigned short binary_max = 0x7BFF;
return std::numeric_limits<int32_t>::min(); static constexpr unsigned short binary_lowest = 0xFBFF;
}
__host__ __device__ static constexpr int32_t Max() __host__ __device__ static constexpr half_t Min() { return as_type<half_t>(binary_min); }
{
return std::numeric_limits<int32_t>::max(); __host__ __device__ static constexpr half_t Max() { return as_type<half_t>(binary_max); }
}
__host__ __device__ static constexpr half_t Lowest() { return as_type<half_t>(binary_lowest); }
}; };
} // namespace ck } // namespace ck
......
...@@ -58,7 +58,7 @@ struct Add ...@@ -58,7 +58,7 @@ struct Add
{ {
using dataType = T; using dataType = T;
__device__ static T GetZeroVal() { return type_convert<T>{}(0.0f); }; __device__ static constexpr T GetZeroVal() { return static_cast<T>(0.0f); };
__device__ inline constexpr void operator()(T& a, T b) const { a = a + b; } __device__ inline constexpr void operator()(T& a, T b) const { a = a + b; }
...@@ -70,7 +70,7 @@ struct Mul ...@@ -70,7 +70,7 @@ struct Mul
{ {
using dataType = T; using dataType = T;
__device__ static T GetZeroVal() { return type_convert<T>{}(1.0f); }; __device__ static constexpr T GetZeroVal() { return static_cast<T>(1.0f); };
__device__ inline constexpr void operator()(T& a, T b) const { a = a * b; } __device__ inline constexpr void operator()(T& a, T b) const { a = a * b; }
...@@ -82,7 +82,7 @@ struct Max ...@@ -82,7 +82,7 @@ struct Max
{ {
using dataType = T; using dataType = T;
__device__ static T GetZeroVal() { return std::numeric_limits<T>::lowest(); }; __device__ static constexpr T GetZeroVal() { return NumericLimits<T>::lowest(); };
__device__ inline constexpr void operator()(T& a, T b) const __device__ inline constexpr void operator()(T& a, T b) const
{ {
...@@ -107,7 +107,7 @@ struct Min ...@@ -107,7 +107,7 @@ struct Min
{ {
using dataType = T; using dataType = T;
__device__ static T GetZeroVal() { return std::numeric_limits<T>::max(); }; __device__ static constexpr T GetZeroVal() { return NumericLimits<T>::Max(); };
__device__ inline constexpr void operator()(T& a, T b) const __device__ inline constexpr void operator()(T& a, T b) const
{ {
...@@ -132,7 +132,7 @@ struct AMax ...@@ -132,7 +132,7 @@ struct AMax
{ {
using dataType = T; using dataType = T;
__device__ static T GetZeroVal() { return type_convert<T>{}(0.0f); }; __device__ static constexpr T GetZeroVal() { return static_cast<T>(0.0f); };
__device__ inline constexpr void operator()(T& a, T b) const __device__ inline constexpr void operator()(T& a, T b) const
{ {
...@@ -152,22 +152,6 @@ struct AMax ...@@ -152,22 +152,6 @@ struct AMax
static constexpr bool indexable = true; static constexpr bool indexable = true;
}; };
template <>
__device__ half_t Max<half_t>::GetZeroVal()
{
const unsigned short binary_lowest = 0xFBFF;
return *reinterpret_cast<const half_t*>(&binary_lowest);
};
template <>
__device__ half_t Min<half_t>::GetZeroVal()
{
const unsigned short binary_max = 0x7BFF;
return *reinterpret_cast<const half_t*>(&binary_max);
};
// Unary operators are usually called element-wisely before the reduction is executed on the // Unary operators are usually called element-wisely before the reduction is executed on the
// elements. // elements.
// They are needed for easy implementation of reduction types of AVG, NRM1, NRM2 // They are needed for easy implementation of reduction types of AVG, NRM1, NRM2
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment