Commit f1541994 authored by Chao Liu

change clang-format to 5.0

parent 863e069b
@@ -195,7 +195,7 @@ struct DummyDynamicTransform
         carry = do_carry ? 1 : 0;
 #else // negative
         bool do_borrow = idx_low_tmp < 0;

         index_t idx_low_new = do_borrow ? idx_low_tmp + idx_low_bound : idx_low_tmp;
@@ -304,7 +304,7 @@ struct DummyDynamicTransform
         idx_low_diff -= negative_carry;

         negative_carry = do_borrow ? 1 : negative_carry;
 #endif
     };
...@@ -337,9 +337,9 @@ struct DummyDynamicTransform ...@@ -337,9 +337,9 @@ struct DummyDynamicTransform
const_tmp[i] = p_wei_global[i + 1]; const_tmp[i] = p_wei_global[i + 1];
} }
#else #else
const_tmp[0] = 0; const_tmp[0] = 0;
const_tmp[1] = 2; const_tmp[1] = 2;
const_tmp[2] = 2; const_tmp[2] = 2;
#endif #endif
// initialize idx // initialize idx
@@ -475,7 +475,7 @@ struct DummyDynamicTransform
         // padding check
         bool is_in_bound = idx[3] >= 0 && idx[3] < Hi && idx[4] >= 0 && idx[4] < Wi;
 #elif 0 // no pad
         // offset
         idx[0] += idx_diff[0];

         // C, Y, X
@@ -486,7 +486,7 @@ struct DummyDynamicTransform
         // padding check
         bool is_in_bound = true;
 #else // pad
         // offset
         idx[0] += idx_diff[0];

         // C, Y, X
...
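The hunks above step a transformed tensor index by a diff and propagate a carry or borrow between dimensions. A standalone sketch of the borrow case follows; the function name, loop order, and flat-array interface are assumptions for illustration, not the library's API.

#include <cstdint>

using index_t = std::int32_t;

// Sketch: update a multi-dimensional index by a (possibly negative) diff,
// wrapping each component by its bound and borrowing from the next-slower
// dimension, as in the do_borrow / idx_low_bound logic above.
template <int NDim>
void move_multi_index(index_t (&idx)[NDim],
                      const index_t (&idx_diff)[NDim],
                      const index_t (&bound)[NDim])
{
    index_t borrow = 0;
    for(int d = NDim - 1; d >= 0; --d) // fastest-varying dimension first
    {
        index_t idx_low_tmp = idx[d] + idx_diff[d] - borrow;
        bool do_borrow      = idx_low_tmp < 0; // component underflowed
        idx[d]              = do_borrow ? idx_low_tmp + bound[d] : idx_low_tmp;
        borrow              = do_borrow ? 1 : 0;
    }
}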
@@ -354,7 +354,7 @@ struct GridwiseConvolutionBackwardDataImplicitGemm_v1r2_nchw_kcyx_nkhw_lds_doubl
     {
 #if 1 // debug
         // input: register to global memory, atomic add
         constexpr auto in_memory_op = (Y <= ConvStrideH && X <= ConvStrideW)
                                           ? InMemoryDataOperation::Set
                                           : InMemoryDataOperation::AtomicAdd;
...
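The condition above decides how backward-data results are written to the input tensor: when the filter window fits inside one stride step (Y <= ConvStrideH and X <= ConvStrideW), no two filter taps ever write the same input pixel, so a plain store (Set) is safe; otherwise overlapping contributions must be accumulated with AtomicAdd. A minimal restatement of that selection, with the enum and function names as illustrative stand-ins:

enum class InMemoryDataOperation
{
    Set,      // plain store: each destination element is written at most once
    AtomicAdd // read-modify-write: overlapping contributions accumulate
};

// Sketch of the compile-time choice above: filter extent (Y, X) vs strides.
template <int Y, int X, int ConvStrideH, int ConvStrideW>
constexpr InMemoryDataOperation choose_in_memory_op()
{
    return (Y <= ConvStrideH && X <= ConvStrideW) ? InMemoryDataOperation::Set
                                                  : InMemoryDataOperation::AtomicAdd;
}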
...@@ -316,14 +316,13 @@ struct DynamicTransformedTensorDescriptor ...@@ -316,14 +316,13 @@ struct DynamicTransformedTensorDescriptor
constexpr bool is_valid_up_always_mapped_to_valid_low = constexpr bool is_valid_up_always_mapped_to_valid_low =
decltype(tran)::IsValidUpperIndexAlwaysMappedToValidLowerIndex(); decltype(tran)::IsValidUpperIndexAlwaysMappedToValidLowerIndex();
if if constexpr(!is_valid_up_always_mapped_to_valid_low)
constexpr(!is_valid_up_always_mapped_to_valid_low) {
{ const auto up_dims_part = UpDimensionIds{}.At(itran);
const auto up_dims_part = UpDimensionIds{}.At(itran); const auto idx_up_part = pick_array_element(idx_up, up_dims_part);
const auto idx_up_part = pick_array_element(idx_up, up_dims_part);
flag = flag && IsValidUpperIndexMappedToValidLowerIndex(idx_up_part);
flag = flag && IsValidUpperIndexMappedToValidLowerIndex(idx_up_part); }
}
}); });
return flag; return flag;
......
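clang-format 5.0 understands the C++17 if constexpr keyword pair as a unit; the older tool split if and constexpr(...) across lines, which is the whole change in this hunk. A self-contained sketch of the same dispatch pattern; the Transform interface shown here is an assumption modeled on the names in the diff, not the actual library API:

// Sketch: skip the runtime validity check when a transform guarantees that
// every valid upper index maps to a valid lower index. The discarded branch
// of if constexpr is never instantiated, so the check only has to compile
// for transforms that actually need it.
template <typename Transform, typename UpperIndex>
bool is_valid_lower(const UpperIndex& idx_up_part)
{
    if constexpr(!Transform::IsValidUpperIndexAlwaysMappedToValidLowerIndex())
    {
        return Transform::IsValidUpperIndexMappedToValidLowerIndex(idx_up_part);
    }
    else
    {
        return true; // always valid; no per-index check needed
    }
}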
@@ -267,7 +267,7 @@ struct TensorCoordinate
     private:
     template <typename... Ts>
     __host__ __device__ static constexpr auto
     MakeDummyTensorCoordinate(NativeTensorDescriptor<Ts...>)
     {
         return NativeTensorCoordinate<NativeTensorDescriptor<Ts...>>(
             make_zero_array<index_t, TensorDesc::GetNumOfDimension()>());
@@ -275,7 +275,7 @@ struct TensorCoordinate
     template <typename... Ts>
     __host__ __device__ static constexpr auto
     MakeDummyTensorCoordinate(TransformedTensorDescriptor<Ts...>)
     {
         return TransformedTensorCoordinate<TransformedTensorDescriptor<Ts...>>(
             make_zero_array<index_t, TensorDesc::GetNumOfDimension()>());
...
@@ -64,10 +64,10 @@ template <typename LowerTensorDescriptor,
           index_t... LowerDimensionIds,
           index_t... UpperDimensionIds>
 __host__ __device__ constexpr auto
 reorder_transformed_tensor_descriptor_impl(LowerTensorDescriptor,
                                            Sequence<LowerLengths...>,
                                            Sequence<LowerDimensionIds...>,
                                            Sequence<UpperDimensionIds...>)
 {
     return TransformedTensorDescriptor<LowerTensorDescriptor,
                                        Tuple<PassThrough<LowerLengths>...>,
@@ -78,7 +78,7 @@ __host__ __device__ constexpr auto
 // reorder a NativeTensorDescriptor
 template <typename... Ts, typename MapLower2Upper>
 __host__ __device__ constexpr auto
 reorder_tensor_descriptor_given_lower2upper(NativeTensorDescriptor<Ts...>, MapLower2Upper)
 {
     static_assert(is_valid_sequence_map<MapLower2Upper>{},
                   "wrong! MapLower2Upper is not a valid map");
@@ -96,7 +96,7 @@ __host__ __device__ constexpr auto
 // reorder a TransformedTensorDescriptor
 template <typename... Ts, typename MapLower2Upper>
 __host__ __device__ constexpr auto
 reorder_tensor_descriptor_given_lower2upper(TransformedTensorDescriptor<Ts...>, MapLower2Upper)
 {
     static_assert(is_valid_sequence_map<MapLower2Upper>{},
                   "wrong! MapLower2Upper is not a valid map");
...
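The static_assert above requires MapLower2Upper to be a valid sequence map, i.e. a permutation of the dimension ids. Sketched on plain arrays, under the assumption that entry i of the map names the upper dimension that lower dimension i becomes (the helper below is illustrative, not the library's API):

#include <array>
#include <cstddef>

// Sketch: dimension i of the lower descriptor becomes dimension
// map_lower2upper[i] of the upper descriptor; the map must be a
// permutation or some upper dimension would be left unset.
template <std::size_t N>
constexpr std::array<std::size_t, N>
reorder_lengths(const std::array<std::size_t, N>& lower_lengths,
                const std::array<std::size_t, N>& map_lower2upper)
{
    std::array<std::size_t, N> upper_lengths{};
    for(std::size_t i = 0; i < N; ++i)
        upper_lengths[map_lower2upper[i]] = lower_lengths[i];
    return upper_lengths;
}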
@@ -210,17 +210,16 @@ struct BlockwiseBatchGemmBlockABlockBThreadCTransANormalBNormalC_V2
 #pragma unroll
             for(index_t m_repeat = 0; m_repeat < MRepeat; ++m_repeat)
             {
-                threadwise_matrix_copy(
-                    a_block_mtx,
-                    p_a_block +
-                        a_block_mtx.GetOffsetFromMultiIndex(k_begin,
-                                                            m_repeat * MPerLevel1Cluster) +
-                        ib * BlockMatrixStrideA + mMyThreadOffsetA,
-                    a_thread_mtx,
-                    p_a_thread +
-                        a_thread_mtx.GetOffsetFromMultiIndex(0, m_repeat * MPerThreadSubC),
-                    a_thread_sub_mtx.GetLengths(),
-                    Number<DataPerReadA>{});
+                threadwise_matrix_copy(a_block_mtx,
+                                       p_a_block +
+                                           a_block_mtx.GetOffsetFromMultiIndex(
+                                               k_begin, m_repeat * MPerLevel1Cluster) +
+                                           ib * BlockMatrixStrideA + mMyThreadOffsetA,
+                                       a_thread_mtx,
+                                       p_a_thread + a_thread_mtx.GetOffsetFromMultiIndex(
+                                                        0, m_repeat * MPerThreadSubC),
+                                       a_thread_sub_mtx.GetLengths(),
+                                       Number<DataPerReadA>{});
             }
         }
@@ -229,17 +228,16 @@ struct BlockwiseBatchGemmBlockABlockBThreadCTransANormalBNormalC_V2
 #pragma unroll
             for(index_t n_repeat = 0; n_repeat < NRepeat; ++n_repeat)
             {
-                threadwise_matrix_copy(
-                    b_block_mtx,
-                    p_b_block +
-                        b_block_mtx.GetOffsetFromMultiIndex(k_begin,
-                                                            n_repeat * NPerLevel1Cluster) +
-                        ib * BlockMatrixStrideB + mMyThreadOffsetB,
-                    b_thread_mtx,
-                    p_b_thread +
-                        b_thread_mtx.GetOffsetFromMultiIndex(0, n_repeat * NPerThreadSubC),
-                    b_thread_sub_mtx.GetLengths(),
-                    Number<DataPerReadB>{});
+                threadwise_matrix_copy(b_block_mtx,
+                                       p_b_block +
+                                           b_block_mtx.GetOffsetFromMultiIndex(
+                                               k_begin, n_repeat * NPerLevel1Cluster) +
+                                           ib * BlockMatrixStrideB + mMyThreadOffsetB,
+                                       b_thread_mtx,
+                                       p_b_thread + b_thread_mtx.GetOffsetFromMultiIndex(
+                                                        0, n_repeat * NPerThreadSubC),
+                                       b_thread_sub_mtx.GetLengths(),
+                                       Number<DataPerReadB>{});
             }
         }
@@ -391,9 +389,8 @@ struct BlockwiseBatchGemmBlockABlockBThreadCTransANormalBNormalC_V2
                 {
                     threadwise_matrix_copy(
                         c_thread_sub_mtx,
-                        p_c_thread +
-                            c_thread_sub_mtx.GetOffsetFromMultiIndex(m_repeat * MPerLevel1Cluster,
-                                                                     n_repeat * NPerLevel1Cluster),
+                        p_c_thread + c_thread_sub_mtx.GetOffsetFromMultiIndex(
+                                         m_repeat * MPerLevel1Cluster, n_repeat * NPerLevel1Cluster),
                         c_block_mtx,
                         p_c_block +
                             c_block_mtx.GetOffsetFromMultiIndex(m_repeat * MPerLevel1Cluster,
@@ -405,5 +402,5 @@ struct BlockwiseBatchGemmBlockABlockBThreadCTransANormalBNormalC_V2
     }
 };
-} // namespace
+} // namespace ck
 #endif
@@ -336,9 +336,9 @@ struct BlockwiseGemmBlockABlockBThreadCTransANormalBNormalC_v2
         constexpr index_t MRepeat = MPerThread / MPerThreadSubC;
         constexpr index_t NRepeat = NPerThread / NPerThreadSubC;

-        static_if<MRepeat == 2 && NRepeat == 2>{}([&](auto) {
-            Run_pipelined_2x2(p_a_block, p_b_block, p_c_thread);
-        }).Else([&](auto) { Run_naive(p_a_block, p_b_block, p_c_thread); });
+        static_if<MRepeat == 2 && NRepeat == 2>{}(
+            [&](auto) { Run_pipelined_2x2(p_a_block, p_b_block, p_c_thread); })
+            .Else([&](auto) { Run_naive(p_a_block, p_b_block, p_c_thread); });
 #else
         Run_naive(p_a_block, p_b_block, p_c_thread);
 #endif
...
@@ -153,9 +153,8 @@ struct ThreadwiseGemmTransANormalBNormalC
                        (is_same<FloatA, half2_t>{} && is_same<FloatB, half2_t>{}) ||
                        (is_same<FloatA, half4_t>{} && is_same<FloatB, half4_t>{}));

-        static_if<has_amd_asm>{}([&](auto fwd) {
-            Run_amd_asm(p_a, p_b, fwd(p_c));
-        }).Else([&](auto) { Run_source(p_a, p_b, p_c); });
+        static_if<has_amd_asm>{}([&](auto fwd) { Run_amd_asm(p_a, p_b, fwd(p_c)); })
+            .Else([&](auto) { Run_source(p_a, p_b, p_c); });
 #else
         Run_source(p_a, p_b, p_c);
 #endif
...
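The static_if used here and in the previous hunk is the repository's pre-C++17 compile-time branch: only the lambda for the taken branch is invoked, and the extra fwd argument makes expressions such as fwd(p_c) type-dependent, so the body of the untaken branch is never instantiated. A minimal lookalike, as a sketch rather than the actual ck implementation:

#include <utility>

// Identity functor handed to the taken branch; wrapping a value in fwd(...)
// makes it dependent, keeping the other branch un-instantiated.
struct identity_fwd
{
    template <typename T>
    T&& operator()(T&& t) const
    {
        return std::forward<T>(t);
    }
};

template <bool>
struct static_if;

template <>
struct static_if<true>
{
    template <typename F>
    const static_if& operator()(F f) const
    {
        f(identity_fwd{}); // condition true: run the then-branch
        return *this;
    }
    template <typename F>
    void Else(F) const {} // else-branch discarded
};

template <>
struct static_if<false>
{
    template <typename F>
    const static_if& operator()(F) const { return *this; } // then-branch discarded
    template <typename F>
    void Else(F f) const { f(identity_fwd{}); } // condition false: run else
};

Usage mirrors the diff: static_if<cond>{}([&](auto fwd) { ... }).Else([&](auto) { ... });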
@@ -82,91 +82,95 @@ struct ThreadwiseGenericTensorSliceCopy_v4r2
         constexpr auto long_vector_access_lengths = SliceLengths::Modify(
             vector_access_dim, SliceLengths::Get(vector_access_dim) / long_vector_size);

-        ford<decltype(long_vector_access_lengths), SrcDstDimAccessOrder>{}([&](
-            auto long_vector_access_id) {
-
-            // data id w.r.t slicing-window
-            auto long_vector_data_begin_id = long_vector_access_id;
-            long_vector_data_begin_id(vector_access_dim) =
-                long_vector_size * long_vector_access_id[vector_access_dim];
-
-            // buffer to hold a src long-vector
-            SrcData p_src_long_vector[long_vector_size];
-
-#if 1
-            // zero out buffer
-            for(index_t i = 0; i < long_vector_size; ++i)
-            {
-                p_src_long_vector[i] = 0;
-            }
-#endif
-
-            // load data from src to the long-vector buffer
-            for(index_t i = 0; i < long_vector_size / src_data_per_access; ++i)
-            {
-                auto scalar_id = make_zero_array<index_t, nDim>();
-                scalar_id(vector_access_dim) = i * src_data_per_access;
-
-                const index_t buffer_offset = i * src_data_per_access;
-
-                const auto src_coord = mSrcSliceOrigin + (long_vector_data_begin_id + scalar_id);
-
-                // Check src data's valid mapping situation, only check the first data in this src
-                // vector. It's user's responsiblity to make sure all data in the src vector
-                // has the valid/invalid mapping situation
-                transfer_data<SrcData,
-                              SrcDataPerRead,
-                              SrcAddressSpace,
-                              AddressSpace::Vgpr,
-                              InMemoryDataOperation::Set,
-                              SrcDataStride,
-                              1>(p_src,
-                                 src_coord.GetOffset(),
-                                 src_coord.IsOffsetValidAssumingUpperIndexIsValid(),
-                                 SrcDesc::GetElementSpace(),
-                                 p_src_long_vector,
-                                 buffer_offset,
-                                 true,
-                                 long_vector_size);
-            }
-
-            // SrcData to DstData conversion
-            DstData p_dst_long_vector[long_vector_size];
-
-            for(index_t i = 0; i < long_vector_size; ++i)
-            {
-                p_dst_long_vector[i] = type_convert<DstData>{}(p_src_long_vector[i]);
-            }
-
-            // store data from the long-vector buffer to dst
-            for(index_t i = 0; i < long_vector_size / dst_data_per_access; ++i)
-            {
-                auto scalar_id = make_zero_array<index_t, nDim>();
-                scalar_id(vector_access_dim) = i * dst_data_per_access;
-
-                const index_t buffer_offset = i * dst_data_per_access;
-
-                const auto dst_coord = mDstSliceOrigin + (long_vector_data_begin_id + scalar_id);
-
-                // Check dst data's valid mapping situation, only check the first data in this dst
-                // vector. It's user's responsiblity to make sure all data in the dst vector
-                // has the valid/invalid mapping situation
-                transfer_data<DstData,
-                              DstDataPerWrite,
-                              AddressSpace::Vgpr,
-                              DstAddressSpace,
-                              DstInMemOp,
-                              1,
-                              DstDataStride>(p_dst_long_vector,
-                                             buffer_offset,
-                                             true,
-                                             long_vector_size,
-                                             p_dst,
-                                             dst_coord.GetOffset(),
-                                             dst_coord.IsOffsetValidAssumingUpperIndexIsValid(),
-                                             DstDesc::GetElementSpace());
-            }
-        });
+        ford<decltype(long_vector_access_lengths), SrcDstDimAccessOrder>{}(
+            [&](auto long_vector_access_id) {
+
+                // data id w.r.t slicing-window
+                auto long_vector_data_begin_id = long_vector_access_id;
+                long_vector_data_begin_id(vector_access_dim) =
+                    long_vector_size * long_vector_access_id[vector_access_dim];
+
+                // buffer to hold a src long-vector
+                SrcData p_src_long_vector[long_vector_size];
+
+#if 1
+                // zero out buffer
+                for(index_t i = 0; i < long_vector_size; ++i)
+                {
+                    p_src_long_vector[i] = 0;
+                }
+#endif
+
+                // load data from src to the long-vector buffer
+                for(index_t i = 0; i < long_vector_size / src_data_per_access; ++i)
+                {
+                    auto scalar_id = make_zero_array<index_t, nDim>();
+                    scalar_id(vector_access_dim) = i * src_data_per_access;
+
+                    const index_t buffer_offset = i * src_data_per_access;
+
+                    const auto src_coord =
+                        mSrcSliceOrigin + (long_vector_data_begin_id + scalar_id);
+
+                    // Check src data's valid mapping situation, only check the first data in this
+                    // src
+                    // vector. It's user's responsiblity to make sure all data in the src vector
+                    // has the valid/invalid mapping situation
+                    transfer_data<SrcData,
+                                  SrcDataPerRead,
+                                  SrcAddressSpace,
+                                  AddressSpace::Vgpr,
+                                  InMemoryDataOperation::Set,
+                                  SrcDataStride,
+                                  1>(p_src,
+                                     src_coord.GetOffset(),
+                                     src_coord.IsOffsetValidAssumingUpperIndexIsValid(),
+                                     SrcDesc::GetElementSpace(),
+                                     p_src_long_vector,
+                                     buffer_offset,
+                                     true,
+                                     long_vector_size);
+                }
+
+                // SrcData to DstData conversion
+                DstData p_dst_long_vector[long_vector_size];
+
+                for(index_t i = 0; i < long_vector_size; ++i)
+                {
+                    p_dst_long_vector[i] = type_convert<DstData>{}(p_src_long_vector[i]);
+                }
+
+                // store data from the long-vector buffer to dst
+                for(index_t i = 0; i < long_vector_size / dst_data_per_access; ++i)
+                {
+                    auto scalar_id = make_zero_array<index_t, nDim>();
+                    scalar_id(vector_access_dim) = i * dst_data_per_access;

+                    const index_t buffer_offset = i * dst_data_per_access;
+
+                    const auto dst_coord =
+                        mDstSliceOrigin + (long_vector_data_begin_id + scalar_id);
+
+                    // Check dst data's valid mapping situation, only check the first data in this
+                    // dst
+                    // vector. It's user's responsiblity to make sure all data in the dst vector
+                    // has the valid/invalid mapping situation
+                    transfer_data<DstData,
+                                  DstDataPerWrite,
+                                  AddressSpace::Vgpr,
+                                  DstAddressSpace,
+                                  DstInMemOp,
+                                  1,
+                                  DstDataStride>(p_dst_long_vector,
+                                                 buffer_offset,
+                                                 true,
+                                                 long_vector_size,
+                                                 p_dst,
+                                                 dst_coord.GetOffset(),
+                                                 dst_coord.IsOffsetValidAssumingUpperIndexIsValid(),
+                                                 DstDesc::GetElementSpace());
+                }
+            });
     }

     template <typename T, bool PositiveDirection>
@@ -175,9 +179,8 @@ struct ThreadwiseGenericTensorSliceCopy_v4r2
     {
         const auto step_sizes = to_array(step_sizes_);

-        static_if<PositiveDirection>{}([&](auto) {
-            mSrcSliceOrigin += to_array(step_sizes);
-        }).Else([&](auto) { mSrcSliceOrigin -= step_sizes; });
+        static_if<PositiveDirection>{}([&](auto) { mSrcSliceOrigin += to_array(step_sizes); })
+            .Else([&](auto) { mSrcSliceOrigin -= step_sizes; });
     }

     template <typename T, bool PositiveDirection>
@@ -186,9 +189,8 @@ struct ThreadwiseGenericTensorSliceCopy_v4r2
     {
         const auto step_sizes = to_array(step_sizes_);

-        static_if<PositiveDirection>{}([&](auto) {
-            mDstSliceOrigin += step_sizes;
-        }).Else([&](auto) { mDstSliceOrigin -= step_sizes; });
+        static_if<PositiveDirection>{}([&](auto) { mDstSliceOrigin += step_sizes; })
+            .Else([&](auto) { mDstSliceOrigin -= step_sizes; });
     }

     private:
...
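The Run hunk above is a three-stage copy: stage a "long vector" of source data in registers (zero-initialized so out-of-bound reads contribute zeros), load it in SrcDataPerRead-sized pieces, convert SrcData to DstData element-wise, then store it in DstDataPerWrite-sized pieces. Reduced to a host-side sketch, with the flat addressing and names as assumptions; the real code adds coordinate mapping, address spaces, and validity checks via transfer_data<>:

#include <cstddef>

// Sketch of the long-vector staging pattern: load in src-sized chunks,
// convert, store in dst-sized chunks.
template <typename SrcData,
          typename DstData,
          std::size_t LongVectorSize,
          std::size_t SrcPerAccess,
          std::size_t DstPerAccess>
void copy_long_vector(const SrcData* p_src, DstData* p_dst)
{
    SrcData src_buf[LongVectorSize] = {}; // zeroed staging buffer

    // load: LongVectorSize / SrcPerAccess accesses of SrcPerAccess elements
    for(std::size_t i = 0; i < LongVectorSize / SrcPerAccess; ++i)
        for(std::size_t j = 0; j < SrcPerAccess; ++j)
            src_buf[i * SrcPerAccess + j] = p_src[i * SrcPerAccess + j];

    // convert element-wise
    DstData dst_buf[LongVectorSize];
    for(std::size_t i = 0; i < LongVectorSize; ++i)
        dst_buf[i] = static_cast<DstData>(src_buf[i]);

    // store: LongVectorSize / DstPerAccess accesses of DstPerAccess elements
    for(std::size_t i = 0; i < LongVectorSize / DstPerAccess; ++i)
        for(std::size_t j = 0; j < DstPerAccess; ++j)
            p_dst[i * DstPerAccess + j] = dst_buf[i * DstPerAccess + j];
}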
@@ -28,7 +28,7 @@ void device_dummy_dynamic_transform(InDesc,
     using TDevice = typename conditional<is_same<half_float::half, T>::value, half_t, T>::type;

     const auto in_nchw_desc = make_dynamic_native_tensor_descriptor(to_array(InDesc::GetLengths()),
                                                                     to_array(InDesc::GetStrides()));
     const auto wei_kcyx_desc = make_dynamic_native_tensor_descriptor(
         to_array(WeiDesc::GetLengths()), to_array(WeiDesc::GetStrides()));
...
@@ -273,7 +273,7 @@ void host_winograd_3x3_convolution(const Tensor<TIn>& in_nchw,
                     std::size_t ho = HoPerTile * htile + j;
                     for(int i = 0; i < WoPerTile; ++i)
                     {
                         std::size_t wo = WoPerTile * wtile + i;
                         out_nkhw(n, k, ho, wo) = out_hold(n, k, htile, wtile, j, i);
                     }
                 }
...
@@ -508,8 +508,8 @@ template <bool B>
 struct bool_type : std::integral_constant<bool, B>
 {
 };
-using std::true_type;
 using std::false_type;
+using std::true_type;

 /// Type traits for floating-point types.
 template <typename T>
@@ -854,8 +854,8 @@ inline HALF_CONSTEXPR_NOERR unsigned int signal(unsigned int x, unsigned int y,
                ((x & 0x7FFF) > 0x7C00 && !(x & 0x200)) || ((y & 0x7FFF) > 0x7C00 && !(y & 0x200)) ||
                ((z & 0x7FFF) > 0x7C00 && !(z & 0x200)));
 #endif
-    return ((x & 0x7FFF) > 0x7C00) ? (x | 0x200) : ((y & 0x7FFF) > 0x7C00) ? (y | 0x200)
-                                                                           : (z | 0x200);
+    return ((x & 0x7FFF) > 0x7C00) ? (x | 0x200)
+                                   : ((y & 0x7FFF) > 0x7C00) ? (y | 0x200) : (z | 0x200);
 }

 /// Select value or signaling NaN.
@@ -1756,9 +1756,9 @@ uint32 mulhi(uint32 x, uint32 y)
     uint32 xy = (x >> 16) * (y & 0xFFFF), yx = (x & 0xFFFF) * (y >> 16),
            c = (xy & 0xFFFF) + (yx & 0xFFFF) + (((x & 0xFFFF) * (y & 0xFFFF)) >> 16);
     return (x >> 16) * (y >> 16) + (xy >> 16) + (yx >> 16) + (c >> 16) +
-           ((R == std::round_to_nearest) ? ((c >> 15) & 1) : (R == std::round_toward_infinity)
-                                                                 ? ((c & 0xFFFF) != 0)
-                                                                 : 0);
+           ((R == std::round_to_nearest)
+                ? ((c >> 15) & 1)
+                : (R == std::round_toward_infinity) ? ((c & 0xFFFF) != 0) : 0);
 }
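mulhi above assembles the high 32 bits of a 32x32-bit product from 16-bit partial products; c carries the rounding information of the discarded low half. A reference for the truncating case, using a 64-bit intermediate:

#include <cstdint>

// Reference: high half of the 32x32 -> 64-bit product. Matches the
// partial-product version above when the rounding term is zero
// (round toward zero).
inline std::uint32_t mulhi_ref(std::uint32_t x, std::uint32_t y)
{
    return static_cast<std::uint32_t>((static_cast<std::uint64_t>(x) * y) >> 32);
}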
 /// 64-bit multiplication.
@@ -2247,7 +2247,7 @@ unsigned int area(unsigned int arg)
         {
             if(expy < 0)
             {
                 r = 0x40000000 + ((expy > -30) ? ((r >> -expy) |
                                                   ((r & ((static_cast<uint32>(1) << -expy) - 1)) != 0))
                                                : 1);
                 expy = 0;
@@ -2379,10 +2379,12 @@ unsigned int erf(unsigned int arg)
             t /
             ((x2.exp < 0) ? f31(exp2((x2.exp > -32) ? (x2.m >> -x2.exp) : 0, 30), 0)
                           : f31(exp2((x2.m << x2.exp) & 0x7FFFFFFF, 22), x2.m >> (31 - x2.exp)));
-    return (!C || sign) ? fixed2half<R, 31, false, true, true>(
-                              0x80000000 - (e.m >> (C - e.exp)), 14 + C, sign & (C - 1U))
-                        : (e.exp < -25) ? underflow<R>() : fixed2half<R, 30, false, false, true>(
-                                                               e.m >> 1, e.exp + 14, 0, e.m & 1);
+    return (!C || sign)
+               ? fixed2half<R, 31, false, true, true>(
+                     0x80000000 - (e.m >> (C - e.exp)), 14 + C, sign & (C - 1U))
+               : (e.exp < -25)
+                     ? underflow<R>()
+                     : fixed2half<R, 30, false, false, true>(e.m >> 1, e.exp + 14, 0, e.m & 1);
 }

 /// Gamma function and postprocessing.
@@ -2402,8 +2404,7 @@ unsigned int gamma(unsigned int arg)
             for(unsigned int i=0; i<5; ++i)
                 s += p[i+1] / (arg+i);
             return std::log(s) + (arg-0.5)*std::log(t) - t;
-    */ static const f31
-        pi(0xC90FDAA2, 1),
+    */ static const f31 pi(0xC90FDAA2, 1),
         lbe(0xB8AA3B29, 0);
     unsigned int abs = arg & 0x7FFF, sign = arg & 0x8000;
     bool bsign = sign != 0;
@@ -2490,7 +2491,7 @@ unsigned int gamma(unsigned int arg)
         {
             if(z.exp < 0)
                 s = s * z;
             s = pi / s;
             if(s.exp < -24)
                 return underflow<R>(sign);
         }
@@ -2789,7 +2790,7 @@ inline half operator"" _h(long double value)
 {
     return half(detail::binary, detail::float2half<half::round_style>(value));
 }
-}
+} // namespace literal
 #endif

 namespace detail {
@@ -2837,8 +2838,8 @@ struct half_caster<half, half, R>
 {
     static half cast(half arg) { return arg; }
 };
-}
-}
+} // namespace detail
+} // namespace half_float

 /// Extensions to the C++ standard library.
 namespace std {
@@ -3003,7 +3004,7 @@ struct hash<half_float::half>
     }
 };
 #endif
-}
+} // namespace std

 namespace half_float {
 /// \anchor compop
@@ -3122,13 +3123,14 @@ inline half operator+(half x, half y)
         return half(detail::binary,
                     (absx > 0x7C00 || absy > 0x7C00)
                         ? detail::signal(x.data_, y.data_)
-                        : (absy != 0x7C00) ? x.data_ : (sub && absx == 0x7C00) ? detail::invalid()
-                                                                               : y.data_);
+                        : (absy != 0x7C00) ? x.data_
+                                           : (sub && absx == 0x7C00) ? detail::invalid() : y.data_);
     if(!absx)
-        return absy ? y : half(detail::binary,
-                               (half::round_style == std::round_toward_neg_infinity)
-                                   ? (x.data_ | y.data_)
-                                   : (x.data_ & y.data_));
+        return absy ? y
+                    : half(detail::binary,
+                           (half::round_style == std::round_toward_neg_infinity)
+                               ? (x.data_ | y.data_)
+                               : (x.data_ & y.data_));
     if(!absy)
         return x;
     unsigned int sign = ((sub && absy > absx) ? y.data_ : x.data_) & 0x8000;
@@ -3449,10 +3451,11 @@ inline half fma(half x, half y, half z)
                               : (sign | 0x7C00))
                        : z;
     if(!absx || !absy)
-        return absz ? z : half(detail::binary,
-                               (half::round_style == std::round_toward_neg_infinity)
-                                   ? (z.data_ | sign)
-                                   : (z.data_ & sign));
+        return absz
+                   ? z
+                   : half(detail::binary,
+                          (half::round_style == std::round_toward_neg_infinity) ? (z.data_ | sign)
+                                                                                : (z.data_ & sign));
     for(; absx < 0x400; absx <<= 1, --exp)
         ;
     for(; absy < 0x400; absy <<= 1, --exp)
@@ -3516,9 +3519,8 @@ inline half fma(half x, half y, half z)
 inline HALF_CONSTEXPR_NOERR half fmax(half x, half y)
 {
     return half(detail::binary,
-                (!isnan(y) && (isnan(x) ||
-                               (x.data_ ^ (0x8000 | (0x8000 - (x.data_ >> 15)))) <
-                                   (y.data_ ^ (0x8000 | (0x8000 - (y.data_ >> 15))))))
+                (!isnan(y) && (isnan(x) || (x.data_ ^ (0x8000 | (0x8000 - (x.data_ >> 15)))) <
+                                               (y.data_ ^ (0x8000 | (0x8000 - (y.data_ >> 15))))))
                     ? detail::select(y.data_, x.data_)
                     : detail::select(x.data_, y.data_));
 }
@@ -3533,9 +3535,8 @@ inline HALF_CONSTEXPR_NOERR half fmax(half x, half y)
 inline HALF_CONSTEXPR_NOERR half fmin(half x, half y)
 {
     return half(detail::binary,
-                (!isnan(y) && (isnan(x) ||
-                               (x.data_ ^ (0x8000 | (0x8000 - (x.data_ >> 15)))) >
-                                   (y.data_ ^ (0x8000 | (0x8000 - (y.data_ >> 15))))))
+                (!isnan(y) && (isnan(x) || (x.data_ ^ (0x8000 | (0x8000 - (x.data_ >> 15)))) >
+                                               (y.data_ ^ (0x8000 | (0x8000 - (y.data_ >> 15))))))
                     ? detail::select(y.data_, x.data_)
                     : detail::select(x.data_, y.data_));
 }
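fmax and fmin above compare half bit patterns directly. The expression x.data_ ^ (0x8000 | (0x8000 - (x.data_ >> 15))) maps the sign-magnitude encoding to an unsigned key that is monotone in the represented value: positive values get the sign bit set, negative values are bitwise inverted so larger magnitudes compare smaller. Isolated for clarity:

// Order-preserving key for sign-magnitude half bits (sketch):
//   sign == 0: key = bits ^ 0x8000 (positives land above all negatives)
//   sign == 1: key = bits ^ 0xFFFF (bigger negative magnitude -> smaller key)
inline unsigned ordered_key(unsigned bits)
{
    return bits ^ (0x8000 | (0x8000 - (bits >> 15)));
}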
@@ -3886,9 +3887,9 @@ inline half log1p(half arg)
 #else
     if(arg.data_ >= 0xBC00)
         return half(detail::binary,
-                    (arg.data_ == 0xBC00) ? detail::pole(0x8000) : (arg.data_ <= 0xFC00)
-                                                                       ? detail::invalid()
-                                                                       : detail::signal(arg.data_));
+                    (arg.data_ == 0xBC00)
+                        ? detail::pole(0x8000)
+                        : (arg.data_ <= 0xFC00) ? detail::invalid() : detail::signal(arg.data_));
     int abs = arg.data_ & 0x7FFF, exp = -15;
     if(!abs || abs >= 0x7C00)
         return (abs > 0x7C00) ? half(detail::binary, detail::signal(arg.data_)) : arg;
@@ -4395,7 +4396,7 @@ inline half cos(half arg)
     if(half::round_style != std::round_to_nearest && abs == 0x598C)
         return half(detail::binary, detail::rounded<half::round_style, true>(0x80FC, 1, 1));
     std::pair<detail::uint32, detail::uint32> sc = detail::sincos(detail::angle_arg(abs, k), 28);
     detail::uint32 sign = -static_cast<detail::uint32>(((k >> 1) ^ k) & 1);
     return half(detail::binary,
                 detail::fixed2half<half::round_style, 30, true, true, true>(
                     (((k & 1) ? sc.first : sc.second) ^ sign) - sign));
@@ -4439,7 +4440,7 @@ inline half tan(half arg)
     }
     std::pair<detail::uint32, detail::uint32> sc = detail::sincos(detail::angle_arg(abs, k), 30);
     if(k & 1)
         sc = std::make_pair(-sc.second, sc.first);
     detail::uint32 signy = detail::sign_mask(sc.first), signx = detail::sign_mask(sc.second);
     detail::uint32 my = (sc.first ^ signy) - signy, mx = (sc.second ^ signx) - signx;
     for(; my < 0x80000000; my <<= 1, --exp)
@@ -4517,7 +4518,7 @@ inline half acos(half arg)
                            ? detail::invalid()
                            : sign ? detail::rounded<half::round_style, true>(0x4248, 0, 1) : 0);
     std::pair<detail::uint32, detail::uint32> cs = detail::atan2_args(abs);
     detail::uint32 m = detail::atan2(cs.second, cs.first, 28);
     return half(detail::binary,
                 detail::fixed2half<half::round_style, 31, false, true, true>(
                     sign ? (0xC90FDAA2 - m) : m, 15, 0, sign));
@@ -5354,13 +5355,13 @@ inline HALF_CONSTEXPR half copysign(half x, half y)
 /// \retval FP_NORMAL for all other (normal) values
 inline HALF_CONSTEXPR int fpclassify(half arg)
 {
-    return !(arg.data_ & 0x7FFF) ? FP_ZERO : ((arg.data_ & 0x7FFF) < 0x400)
-                                                 ? FP_SUBNORMAL
-                                                 : ((arg.data_ & 0x7FFF) < 0x7C00)
-                                                       ? FP_NORMAL
-                                                       : ((arg.data_ & 0x7FFF) == 0x7C00)
-                                                             ? FP_INFINITE
-                                                             : FP_NAN;
+    return !(arg.data_ & 0x7FFF)
+               ? FP_ZERO
+               : ((arg.data_ & 0x7FFF) < 0x400)
+                     ? FP_SUBNORMAL
+                     : ((arg.data_ & 0x7FFF) < 0x7C00)
+                           ? FP_NORMAL
+                           : ((arg.data_ & 0x7FFF) == 0x7C00) ? FP_INFINITE : FP_NAN;
 }
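The ternary chain in fpclassify reads more easily unrolled. The thresholds come from the half layout (1 sign bit, 5-bit exponent, 10-bit mantissa): 0x7FFF masks off the sign, values below 0x400 have a zero exponent field, and 0x7C00 is the all-ones exponent. An equivalent if/else version, for illustration only:

#include <cmath> // FP_ZERO, FP_SUBNORMAL, FP_NORMAL, FP_INFINITE, FP_NAN

inline int fpclassify_half_bits(unsigned data)
{
    unsigned abs = data & 0x7FFF; // drop the sign bit
    if(abs == 0)
        return FP_ZERO;
    if(abs < 0x400) // exponent field is zero
        return FP_SUBNORMAL;
    if(abs < 0x7C00)
        return FP_NORMAL;
    if(abs == 0x7C00) // exponent all ones, mantissa zero
        return FP_INFINITE;
    return FP_NAN; // exponent all ones, mantissa nonzero
}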
 /// Check if finite number.
@@ -5652,7 +5653,7 @@ inline void fethrowexcept(int excepts, const char* msg = "")
         throw std::range_error(msg);
 }
 /// \}
-}
+} // namespace half_float

 #undef HALF_UNUSED_NOERR
 #undef HALF_CONSTEXPR
...