Commit 6f92737f authored by Chao Liu's avatar Chao Liu
Browse files

use remove_cvref_t

parent 21444cfc
...@@ -189,8 +189,7 @@ struct TensorAdaptor ...@@ -189,8 +189,7 @@ struct TensorAdaptor
bool is_known = true; bool is_known = true;
static_for<0, Transforms::Size(), 1>{}([&](auto i) { static_for<0, Transforms::Size(), 1>{}([&](auto i) {
is_known &= is_known &= remove_cvref_t<decltype(Transforms{}[i])>::IsKnownAtCompileTime();
remove_cv_t<remove_reference_t<decltype(Transforms{}[i])>>::IsKnownAtCompileTime();
}); });
return is_known && is_known_at_compile_time<ElementSize>::value; return is_known && is_known_at_compile_time<ElementSize>::value;
......
...@@ -185,8 +185,7 @@ struct TensorDescriptor ...@@ -185,8 +185,7 @@ struct TensorDescriptor
bool is_known = true; bool is_known = true;
static_for<0, Transforms::Size(), 1>{}([&](auto i) { static_for<0, Transforms::Size(), 1>{}([&](auto i) {
is_known &= is_known &= remove_cvref_t<decltype(Transforms{}[i])>::IsKnownAtCompileTime();
remove_cv_t<remove_reference_t<decltype(Transforms{}[i])>>::IsKnownAtCompileTime();
}); });
return is_known && is_known_at_compile_time<ElementSize>::value && return is_known && is_known_at_compile_time<ElementSize>::value &&
...@@ -587,11 +586,11 @@ __host__ __device__ constexpr bool coordinate_has_valid_offset(const TensorDesc& ...@@ -587,11 +586,11 @@ __host__ __device__ constexpr bool coordinate_has_valid_offset(const TensorDesc&
template <typename TensorDesc> template <typename TensorDesc>
using TensorCoordinate_t = decltype(make_tensor_coordinate( using TensorCoordinate_t = decltype(make_tensor_coordinate(
TensorDesc{}, MultiIndex<remove_cv_t<remove_reference_t<TensorDesc>>::GetNumOfDimension()>{})); TensorDesc{}, MultiIndex<remove_cvref_t<TensorDesc>::GetNumOfDimension()>{}));
template <typename TensorDesc> template <typename TensorDesc>
using TensorCoordinateStep_t = decltype(make_tensor_coordinate_step( using TensorCoordinateStep_t = decltype(make_tensor_coordinate_step(
TensorDesc{}, MultiIndex<remove_cv_t<remove_reference_t<TensorDesc>>::GetNumOfDimension()>{})); TensorDesc{}, MultiIndex<remove_cvref_t<TensorDesc>::GetNumOfDimension()>{}));
} // namespace ck } // namespace ck
#endif #endif
...@@ -110,13 +110,11 @@ struct BlockwiseGemmDlops_km_kn_m0m1n0n1_v3 ...@@ -110,13 +110,11 @@ struct BlockwiseGemmDlops_km_kn_m0m1n0n1_v3
const BThreadBuffer& b_thread_buf, const BThreadBuffer& b_thread_buf,
CThreadBuffer& c_thread_buf) const CThreadBuffer& c_thread_buf) const
{ {
static_assert(is_same<remove_cv_t<remove_reference_t<typename ABlockBuffer::type>>, static_assert(
remove_cv_t<remove_reference_t<FloatA>>>::value && is_same<remove_cvref_t<typename ABlockBuffer::type>, remove_cvref_t<FloatA>>::value &&
is_same<remove_cv_t<remove_reference_t<typename BThreadBuffer::type>>, is_same<remove_cvref_t<typename BThreadBuffer::type>, remove_cvref_t<FloatB>>::value &&
remove_cv_t<remove_reference_t<FloatB>>>::value && is_same<remove_cvref_t<typename CThreadBuffer::type>, remove_cvref_t<FloatC>>::value &&
is_same<remove_cv_t<remove_reference_t<typename CThreadBuffer::type>>, "wrong! inconsistent type");
remove_cv_t<remove_reference_t<FloatC>>>::value &&
"wrong! inconsistent type");
constexpr auto I0 = Number<0>{}; constexpr auto I0 = Number<0>{};
......
...@@ -55,19 +55,16 @@ struct ThreadwiseGemmDlops_km0m1_kn0n1_m0m1n0n1 ...@@ -55,19 +55,16 @@ struct ThreadwiseGemmDlops_km0m1_kn0n1_m0m1n0n1
CBuffer& c_buf, CBuffer& c_buf,
COriginIdx) COriginIdx)
{ {
static_assert(is_known_at_compile_time<remove_cvref_t<AOriginIdx>>::value &&
is_known_at_compile_time<remove_cvref_t<BOriginIdx>>::value &&
is_known_at_compile_time<remove_cvref_t<COriginIdx>>::value,
"wrong! AOriginIdx, BOriginIdx, COringinIdx should be known at compile-time");
static_assert( static_assert(
is_known_at_compile_time<remove_cv_t<remove_reference_t<AOriginIdx>>>::value && is_same<remove_cvref_t<typename ABuffer::type>, remove_cvref_t<FloatA>>::value &&
is_known_at_compile_time<remove_cv_t<remove_reference_t<BOriginIdx>>>::value && is_same<remove_cvref_t<typename BBuffer::type>, remove_cvref_t<FloatB>>::value &&
is_known_at_compile_time<remove_cv_t<remove_reference_t<COriginIdx>>>::value, is_same<remove_cvref_t<typename CBuffer::type>, remove_cvref_t<FloatC>>::value &&
"wrong! AOriginIdx, BOriginIdx, COringinIdx should be known at compile-time"); "wrong! inconsistent type");
static_assert(is_same<remove_cv_t<remove_reference_t<typename ABuffer::type>>,
remove_cv_t<remove_reference_t<FloatA>>>::value &&
is_same<remove_cv_t<remove_reference_t<typename BBuffer::type>>,
remove_cv_t<remove_reference_t<FloatB>>>::value &&
is_same<remove_cv_t<remove_reference_t<typename CBuffer::type>>,
remove_cv_t<remove_reference_t<FloatC>>>::value &&
"wrong! inconsistent type");
constexpr auto I0 = Number<0>{}; constexpr auto I0 = Number<0>{};
constexpr auto I1 = Number<1>{}; constexpr auto I1 = Number<1>{};
...@@ -157,19 +154,16 @@ struct ThreadwiseContractionDlops_A_TK0_TM0_TM1_TK1_B_TK0_TN0_TN1_TK1_C_TM0_TM1_ ...@@ -157,19 +154,16 @@ struct ThreadwiseContractionDlops_A_TK0_TM0_TM1_TK1_B_TK0_TN0_TN1_TK1_C_TM0_TM1_
CBuffer& c_buf, CBuffer& c_buf,
COriginIdx) COriginIdx)
{ {
static_assert(is_known_at_compile_time<remove_cvref_t<AOriginIdx>>::value &&
is_known_at_compile_time<remove_cvref_t<BOriginIdx>>::value &&
is_known_at_compile_time<remove_cvref_t<COriginIdx>>::value,
"wrong! AOriginIdx, BOriginIdx, COringinIdx should be known at compile-time");
static_assert( static_assert(
is_known_at_compile_time<remove_cv_t<remove_reference_t<AOriginIdx>>>::value && is_same<remove_cvref_t<typename ABuffer::type>, remove_cvref_t<FloatA>>::value &&
is_known_at_compile_time<remove_cv_t<remove_reference_t<BOriginIdx>>>::value && is_same<remove_cvref_t<typename BBuffer::type>, remove_cvref_t<FloatB>>::value &&
is_known_at_compile_time<remove_cv_t<remove_reference_t<COriginIdx>>>::value, is_same<remove_cvref_t<typename CBuffer::type>, remove_cvref_t<FloatC>>::value &&
"wrong! AOriginIdx, BOriginIdx, COringinIdx should be known at compile-time"); "wrong! inconsistent type");
static_assert(is_same<remove_cv_t<remove_reference_t<typename ABuffer::type>>,
remove_cv_t<remove_reference_t<FloatA>>>::value &&
is_same<remove_cv_t<remove_reference_t<typename BBuffer::type>>,
remove_cv_t<remove_reference_t<FloatB>>>::value &&
is_same<remove_cv_t<remove_reference_t<typename CBuffer::type>>,
remove_cv_t<remove_reference_t<FloatC>>>::value &&
"wrong! inconsistent type");
constexpr auto I0 = Number<0>{}; constexpr auto I0 = Number<0>{};
constexpr auto I1 = Number<1>{}; constexpr auto I1 = Number<1>{};
......
...@@ -41,19 +41,16 @@ struct ThreadwiseGemmDlops_km_kn_mn_v3 ...@@ -41,19 +41,16 @@ struct ThreadwiseGemmDlops_km_kn_mn_v3
CDesc::IsKnownAtCompileTime(), CDesc::IsKnownAtCompileTime(),
"wrong! Desc should be known at compile-time"); "wrong! Desc should be known at compile-time");
static_assert(is_known_at_compile_time<remove_cvref_t<AOriginIdx>>::value &&
is_known_at_compile_time<remove_cvref_t<BOriginIdx>>::value &&
is_known_at_compile_time<remove_cvref_t<COriginIdx>>::value,
"wrong! AOriginIdx, BOriginIdx, COringinIdx should be known at compile-time");
static_assert( static_assert(
is_known_at_compile_time<remove_cv_t<remove_reference_t<AOriginIdx>>>::value && is_same<remove_cvref_t<typename ABuffer::type>, remove_cvref_t<FloatA>>::value &&
is_known_at_compile_time<remove_cv_t<remove_reference_t<BOriginIdx>>>::value && is_same<remove_cvref_t<typename BBuffer::type>, remove_cvref_t<FloatB>>::value &&
is_known_at_compile_time<remove_cv_t<remove_reference_t<COriginIdx>>>::value, is_same<remove_cvref_t<typename CBuffer::type>, remove_cvref_t<FloatC>>::value &&
"wrong! AOriginIdx, BOriginIdx, COringinIdx should be known at compile-time"); "wrong! inconsistent type");
static_assert(is_same<remove_cv_t<remove_reference_t<typename ABuffer::type>>,
remove_cv_t<remove_reference_t<FloatA>>>::value &&
is_same<remove_cv_t<remove_reference_t<typename BBuffer::type>>,
remove_cv_t<remove_reference_t<FloatB>>>::value &&
is_same<remove_cv_t<remove_reference_t<typename CBuffer::type>>,
remove_cv_t<remove_reference_t<FloatC>>>::value &&
"wrong! inconsistent type");
constexpr auto I0 = Number<0>{}; constexpr auto I0 = Number<0>{};
constexpr auto I1 = Number<1>{}; constexpr auto I1 = Number<1>{};
......
...@@ -30,11 +30,11 @@ struct ThreadwiseTensorSliceSet_v1 ...@@ -30,11 +30,11 @@ struct ThreadwiseTensorSliceSet_v1
static_assert(Buffer::IsStaticBuffer(), "wrong! DstBuffer need to be StaticBuffer"); static_assert(Buffer::IsStaticBuffer(), "wrong! DstBuffer need to be StaticBuffer");
static_assert(is_known_at_compile_time<remove_cv_t<remove_reference_t<OriginIdx>>>::value, static_assert(is_known_at_compile_time<remove_cvref_t<OriginIdx>>::value,
"wrong! OriginIdx need to be known at compile-time"); "wrong! OriginIdx need to be known at compile-time");
// Desc is known at compile-time // Desc is known at compile-time
constexpr auto desc = remove_cv_t<remove_reference_t<Desc>>{}; constexpr auto desc = remove_cvref_t<Desc>{};
// OriginIdx is known at compile-time // OriginIdx is known at compile-time
constexpr auto origin_idx = to_multi_index(OriginIdx{}); constexpr auto origin_idx = to_multi_index(OriginIdx{});
......
...@@ -95,18 +95,13 @@ struct ThreadwiseTensorSliceTransfer_v1r3 ...@@ -95,18 +95,13 @@ struct ThreadwiseTensorSliceTransfer_v1r3
static_assert(SrcDesc::IsKnownAtCompileTime(), static_assert(SrcDesc::IsKnownAtCompileTime(),
"wrong! SrcDesc need to known at compile-time"); "wrong! SrcDesc need to known at compile-time");
static_assert( static_assert(is_known_at_compile_time<remove_cvref_t<SrcSliceOriginIdx>>::value,
is_known_at_compile_time<remove_cv_t<remove_reference_t<SrcSliceOriginIdx>>>::value, "wrong! SrcSliceOrigin need to known at compile-time");
"wrong! SrcSliceOrigin need to known at compile-time");
static_assert(SrcBuffer::IsStaticBuffer(), "wrong! SrcBuffer need to be StaticBuffer"); static_assert(SrcBuffer::IsStaticBuffer(), "wrong! SrcBuffer need to be StaticBuffer");
// static_assert(is_same<remove_cv_t<remove_reference_t<typename SrcBuffer::type>>,
// remove_cv_t<remove_reference_t<SrcData>>>::value,
//"wrong! SrcBuffer data type is wrong");
// SrcDesc and src_slice_origin_idx are known at compile-time // SrcDesc and src_slice_origin_idx are known at compile-time
constexpr auto src_desc = remove_cv_t<remove_reference_t<SrcDesc>>{}; constexpr auto src_desc = remove_cvref_t<SrcDesc>{};
constexpr auto src_slice_origin_idx = to_multi_index(SrcSliceOriginIdx{}); constexpr auto src_slice_origin_idx = to_multi_index(SrcSliceOriginIdx{});
constexpr auto I0 = Number<0>{}; constexpr auto I0 = Number<0>{};
...@@ -421,16 +416,15 @@ struct ThreadwiseTensorSliceTransfer_v2 ...@@ -421,16 +416,15 @@ struct ThreadwiseTensorSliceTransfer_v2
static_assert(DstDesc::IsKnownAtCompileTime(), static_assert(DstDesc::IsKnownAtCompileTime(),
"wrong! DstDesc need to known at compile-time"); "wrong! DstDesc need to known at compile-time");
static_assert( static_assert(is_known_at_compile_time<remove_cvref_t<DstSliceOriginIdx>>::value,
is_known_at_compile_time<remove_cv_t<remove_reference_t<DstSliceOriginIdx>>>::value, "wrong! DstSliceOrigin need to known at compile-time");
"wrong! DstSliceOrigin need to known at compile-time");
static_assert(is_same<remove_cv_t<remove_reference_t<typename DstBuffer::type>>, static_assert(
remove_cv_t<remove_reference_t<DstData>>>::value && is_same<remove_cvref_t<typename DstBuffer::type>, remove_cvref_t<DstData>>::value &&
"wrong! inconsistent type"); "wrong! inconsistent type");
// DstDesc and dst_slice_origin_idx are known at compile-time // DstDesc and dst_slice_origin_idx are known at compile-time
constexpr auto dst_desc = remove_cv_t<remove_reference_t<DstDesc>>{}; constexpr auto dst_desc = remove_cvref_t<DstDesc>{};
constexpr auto dst_slice_origin_idx = DstSliceOriginIdx{}; constexpr auto dst_slice_origin_idx = DstSliceOriginIdx{};
constexpr auto I0 = Number<0>{}; constexpr auto I0 = Number<0>{};
...@@ -742,9 +736,9 @@ struct ThreadwiseTensorSliceTransfer_v3 ...@@ -742,9 +736,9 @@ struct ThreadwiseTensorSliceTransfer_v3
SrcBuffer::GetAddressSpace() == AddressSpaceEnum_t::Lds, SrcBuffer::GetAddressSpace() == AddressSpaceEnum_t::Lds,
"wrong!"); "wrong!");
static_assert(is_same<remove_cv_t<remove_reference_t<typename SrcBuffer::type>>, static_assert(
remove_cv_t<remove_reference_t<SrcData>>>::value, is_same<remove_cvref_t<typename SrcBuffer::type>, remove_cvref_t<SrcData>>::value,
"wrong! SrcBuffer and SrcData data type are inconsistent"); "wrong! SrcBuffer and SrcData data type are inconsistent");
constexpr auto I0 = Number<0>{}; constexpr auto I0 = Number<0>{};
constexpr auto I1 = Number<1>{}; constexpr auto I1 = Number<1>{};
...@@ -899,9 +893,9 @@ struct ThreadwiseTensorSliceTransfer_v3 ...@@ -899,9 +893,9 @@ struct ThreadwiseTensorSliceTransfer_v3
DstBuffer::GetAddressSpace() == AddressSpaceEnum_t::Lds, DstBuffer::GetAddressSpace() == AddressSpaceEnum_t::Lds,
"wrong!"); "wrong!");
static_assert(is_same<remove_cv_t<remove_reference_t<typename DstBuffer::type>>, static_assert(
remove_cv_t<remove_reference_t<DstData>>>::value, is_same<remove_cvref_t<typename DstBuffer::type>, remove_cvref_t<DstData>>::value,
"wrong! SrcBuffer or DstBuffer data type is wrong"); "wrong! SrcBuffer or DstBuffer data type is wrong");
constexpr auto I0 = Number<0>{}; constexpr auto I0 = Number<0>{};
constexpr auto I1 = Number<1>{}; constexpr auto I1 = Number<1>{};
...@@ -1315,24 +1309,21 @@ struct ThreadwiseTensorSliceTransfer_v4 ...@@ -1315,24 +1309,21 @@ struct ThreadwiseTensorSliceTransfer_v4
static_assert(SrcDesc::IsKnownAtCompileTime() && DstDesc::IsKnownAtCompileTime(), static_assert(SrcDesc::IsKnownAtCompileTime() && DstDesc::IsKnownAtCompileTime(),
"wrong! SrcDesc and DstDesc need to known at compile-time"); "wrong! SrcDesc and DstDesc need to known at compile-time");
static_assert(is_same<remove_cv_t<remove_reference_t<typename SrcBuffer::type>>, static_assert(
remove_cv_t<remove_reference_t<SrcData>>>::value && is_same<remove_cvref_t<typename SrcBuffer::type>, remove_cvref_t<SrcData>>::value &&
is_same<remove_cv_t<remove_reference_t<typename DstBuffer::type>>, is_same<remove_cvref_t<typename DstBuffer::type>, remove_cvref_t<DstData>>::value,
remove_cv_t<remove_reference_t<DstData>>>::value, "wrong! SrcBuffer or DstBuffer data type is wrong");
"wrong! SrcBuffer or DstBuffer data type is wrong");
static_assert(DstBuffer::IsStaticBuffer(), "wrong! DstBuffer need to be StaticBuffer"); static_assert(DstBuffer::IsStaticBuffer(), "wrong! DstBuffer need to be StaticBuffer");
static_assert( static_assert(is_known_at_compile_time<remove_cvref_t<SrcRefToOriginDisplacement>>::value &&
is_known_at_compile_time< is_known_at_compile_time<remove_cvref_t<DstOriginIdx>>::value,
remove_cv_t<remove_reference_t<SrcRefToOriginDisplacement>>>::value && "wrong! SrcOriginToRefDistance and DstOriginToRefDistance need to be known "
is_known_at_compile_time<remove_cv_t<remove_reference_t<DstOriginIdx>>>::value, "at compile-time");
"wrong! SrcOriginToRefDistance and DstOriginToRefDistance need to be known "
"at compile-time");
// SrcDesc and DstDesc are known at compile-time // SrcDesc and DstDesc are known at compile-time
constexpr auto src_desc = remove_cv_t<remove_reference_t<SrcDesc>>{}; constexpr auto src_desc = remove_cvref_t<SrcDesc>{};
constexpr auto dst_desc = remove_cv_t<remove_reference_t<DstDesc>>{}; constexpr auto dst_desc = remove_cvref_t<DstDesc>{};
// SrcOriginToRefDisttance and DstOriginToRefDistance are known at compile-time // SrcOriginToRefDisttance and DstOriginToRefDistance are known at compile-time
constexpr auto src_ref_to_origin_disp_idx = to_multi_index(SrcRefToOriginDisplacement{}); constexpr auto src_ref_to_origin_disp_idx = to_multi_index(SrcRefToOriginDisplacement{});
......
...@@ -80,9 +80,9 @@ struct ThreadwiseTensorSliceTransfer_v3r1 ...@@ -80,9 +80,9 @@ struct ThreadwiseTensorSliceTransfer_v3r1
SrcBuffer::GetAddressSpace() == AddressSpaceEnum_t::Lds, SrcBuffer::GetAddressSpace() == AddressSpaceEnum_t::Lds,
"wrong!"); "wrong!");
static_assert(is_same<remove_cv_t<remove_reference_t<typename SrcBuffer::type>>, static_assert(
remove_cv_t<remove_reference_t<SrcData>>>::value, is_same<remove_cvref_t<typename SrcBuffer::type>, remove_cvref_t<SrcData>>::value,
"wrong! SrcBuffer and SrcData data type are inconsistent"); "wrong! SrcBuffer and SrcData data type are inconsistent");
// tensor descriptor for src_vector // tensor descriptor for src_vector
constexpr auto src_vector_tensor_lengths = SrcVectorTensorLengths{}; constexpr auto src_vector_tensor_lengths = SrcVectorTensorLengths{};
...@@ -248,9 +248,9 @@ struct ThreadwiseTensorSliceTransfer_v3r1 ...@@ -248,9 +248,9 @@ struct ThreadwiseTensorSliceTransfer_v3r1
DstBuffer::GetAddressSpace() == AddressSpaceEnum_t::Lds, DstBuffer::GetAddressSpace() == AddressSpaceEnum_t::Lds,
"wrong!"); "wrong!");
static_assert(is_same<remove_cv_t<remove_reference_t<typename DstBuffer::type>>, static_assert(
remove_cv_t<remove_reference_t<DstData>>>::value, is_same<remove_cvref_t<typename DstBuffer::type>, remove_cvref_t<DstData>>::value,
"wrong! SrcBuffer or DstBuffer data type is wrong"); "wrong! SrcBuffer or DstBuffer data type is wrong");
// tensor descriptor for dst_vector // tensor descriptor for dst_vector
constexpr auto dst_vector_tensor_lengths = DstVectorTensorLengths{}; constexpr auto dst_vector_tensor_lengths = DstVectorTensorLengths{};
...@@ -669,24 +669,21 @@ struct ThreadwiseTensorSliceTransfer_v4r1 ...@@ -669,24 +669,21 @@ struct ThreadwiseTensorSliceTransfer_v4r1
static_assert(SrcDesc::IsKnownAtCompileTime() && DstDesc::IsKnownAtCompileTime(), static_assert(SrcDesc::IsKnownAtCompileTime() && DstDesc::IsKnownAtCompileTime(),
"wrong! SrcDesc and DstDesc need to known at compile-time"); "wrong! SrcDesc and DstDesc need to known at compile-time");
static_assert(is_same<remove_cv_t<remove_reference_t<typename SrcBuffer::type>>, static_assert(
remove_cv_t<remove_reference_t<SrcData>>>::value && is_same<remove_cvref_t<typename SrcBuffer::type>, remove_cvref_t<SrcData>>::value &&
is_same<remove_cv_t<remove_reference_t<typename DstBuffer::type>>, is_same<remove_cvref_t<typename DstBuffer::type>, remove_cvref_t<DstData>>::value,
remove_cv_t<remove_reference_t<DstData>>>::value, "wrong! SrcBuffer or DstBuffer data type is wrong");
"wrong! SrcBuffer or DstBuffer data type is wrong");
static_assert(DstBuffer::IsStaticBuffer(), "wrong! DstBuffer need to be StaticBuffer"); static_assert(DstBuffer::IsStaticBuffer(), "wrong! DstBuffer need to be StaticBuffer");
static_assert( static_assert(is_known_at_compile_time<remove_cvref_t<SrcRefToOriginDisplacement>>::value &&
is_known_at_compile_time< is_known_at_compile_time<remove_cvref_t<DstOriginIdx>>::value,
remove_cv_t<remove_reference_t<SrcRefToOriginDisplacement>>>::value && "wrong! SrcOriginToRefDistance and DstOriginToRefDistance need to be known "
is_known_at_compile_time<remove_cv_t<remove_reference_t<DstOriginIdx>>>::value, "at compile-time");
"wrong! SrcOriginToRefDistance and DstOriginToRefDistance need to be known "
"at compile-time");
// SrcDesc and DstDesc are known at compile-time // SrcDesc and DstDesc are known at compile-time
constexpr auto src_desc = remove_cv_t<remove_reference_t<SrcDesc>>{}; constexpr auto src_desc = remove_cvref_t<SrcDesc>{};
constexpr auto dst_desc = remove_cv_t<remove_reference_t<DstDesc>>{}; constexpr auto dst_desc = remove_cvref_t<DstDesc>{};
// SrcOriginToRefDisttance and DstOriginToRefDistance are known at compile-time // SrcOriginToRefDisttance and DstOriginToRefDistance are known at compile-time
constexpr auto src_ref_to_origin_disp_idx = to_multi_index(SrcRefToOriginDisplacement{}); constexpr auto src_ref_to_origin_disp_idx = to_multi_index(SrcRefToOriginDisplacement{});
......
...@@ -48,7 +48,7 @@ struct Array<TData, 0> ...@@ -48,7 +48,7 @@ struct Array<TData, 0>
template <typename X, typename... Xs> template <typename X, typename... Xs>
__host__ __device__ constexpr auto make_array(X&& x, Xs&&... xs) __host__ __device__ constexpr auto make_array(X&& x, Xs&&... xs)
{ {
using data_type = remove_cv_t<remove_reference_t<X>>; using data_type = remove_cvref_t<X>;
return Array<data_type, sizeof...(Xs) + 1>{{std::forward<X>(x), std::forward<Xs>(xs)...}}; return Array<data_type, sizeof...(Xs) + 1>{{std::forward<X>(x), std::forward<Xs>(xs)...}};
} }
......
...@@ -39,18 +39,15 @@ struct DynamicBuffer ...@@ -39,18 +39,15 @@ struct DynamicBuffer
} }
template <typename X, template <typename X,
typename enable_if< typename enable_if<is_same<typename scalar_type<remove_cvref_t<X>>::type,
is_same<typename scalar_type<remove_cv_t<remove_reference_t<X>>>::type, typename scalar_type<remove_cvref_t<T>>::type>::value,
typename scalar_type<remove_cv_t<remove_reference_t<T>>>::type>::value, bool>::type = false>
bool>::type = false>
__host__ __device__ constexpr auto Get(index_t i, bool is_valid_element) const __host__ __device__ constexpr auto Get(index_t i, bool is_valid_element) const
{ {
// X contains multiple T // X contains multiple T
constexpr index_t scalar_per_t_vector = constexpr index_t scalar_per_t_vector = scalar_type<remove_cvref_t<T>>::vector_size;
scalar_type<remove_cv_t<remove_reference_t<T>>>::vector_size;
constexpr index_t scalar_per_x_vector = constexpr index_t scalar_per_x_vector = scalar_type<remove_cvref_t<X>>::vector_size;
scalar_type<remove_cv_t<remove_reference_t<X>>>::vector_size;
static_assert(scalar_per_x_vector % scalar_per_t_vector == 0, static_assert(scalar_per_x_vector % scalar_per_t_vector == 0,
"wrong! X need to be multiple T"); "wrong! X need to be multiple T");
...@@ -67,15 +64,14 @@ struct DynamicBuffer ...@@ -67,15 +64,14 @@ struct DynamicBuffer
if constexpr(InvalidElementUseNumericalZeroValue) if constexpr(InvalidElementUseNumericalZeroValue)
{ {
return amd_buffer_load_invalid_element_return_return_zero< return amd_buffer_load_invalid_element_return_return_zero<remove_cvref_t<T>,
remove_cv_t<remove_reference_t<T>>, t_per_x>(
t_per_x>(p_data_, i, is_valid_element, element_space_size_); p_data_, i, is_valid_element, element_space_size_);
} }
else else
{ {
return amd_buffer_load_invalid_element_return_customized_value< return amd_buffer_load_invalid_element_return_customized_value<remove_cvref_t<T>,
remove_cv_t<remove_reference_t<T>>, t_per_x>(
t_per_x>(
p_data_, i, is_valid_element, element_space_size_, invalid_element_value_); p_data_, i, is_valid_element, element_space_size_, invalid_element_value_);
} }
} }
...@@ -94,18 +90,15 @@ struct DynamicBuffer ...@@ -94,18 +90,15 @@ struct DynamicBuffer
} }
template <typename X, template <typename X,
typename enable_if< typename enable_if<is_same<typename scalar_type<remove_cvref_t<X>>::type,
is_same<typename scalar_type<remove_cv_t<remove_reference_t<X>>>::type, typename scalar_type<remove_cvref_t<T>>::type>::value,
typename scalar_type<remove_cv_t<remove_reference_t<T>>>::type>::value, bool>::type = false>
bool>::type = false>
__host__ __device__ void Set(index_t i, bool is_valid_element, const X& x) __host__ __device__ void Set(index_t i, bool is_valid_element, const X& x)
{ {
// X contains multiple T // X contains multiple T
constexpr index_t scalar_per_t_vector = constexpr index_t scalar_per_t_vector = scalar_type<remove_cvref_t<T>>::vector_size;
scalar_type<remove_cv_t<remove_reference_t<T>>>::vector_size;
constexpr index_t scalar_per_x_vector = constexpr index_t scalar_per_x_vector = scalar_type<remove_cvref_t<X>>::vector_size;
scalar_type<remove_cv_t<remove_reference_t<X>>>::vector_size;
static_assert(scalar_per_x_vector % scalar_per_t_vector == 0, static_assert(scalar_per_x_vector % scalar_per_t_vector == 0,
"wrong! X need to be multiple T"); "wrong! X need to be multiple T");
...@@ -115,7 +108,7 @@ struct DynamicBuffer ...@@ -115,7 +108,7 @@ struct DynamicBuffer
#if CK_USE_AMD_BUFFER_ADDRESSING #if CK_USE_AMD_BUFFER_ADDRESSING
constexpr index_t t_per_x = scalar_per_x_vector / scalar_per_t_vector; constexpr index_t t_per_x = scalar_per_x_vector / scalar_per_t_vector;
amd_buffer_store<remove_cv_t<remove_reference_t<T>>, t_per_x>( amd_buffer_store<remove_cvref_t<T>, t_per_x>(
x, p_data_, i, is_valid_element, element_space_size_); x, p_data_, i, is_valid_element, element_space_size_);
#else #else
if(is_valid_element) if(is_valid_element)
...@@ -136,70 +129,65 @@ struct DynamicBuffer ...@@ -136,70 +129,65 @@ struct DynamicBuffer
// ISA, so I try to let compiler emit IR "store<i32, 4>" which would be lower to // ISA, so I try to let compiler emit IR "store<i32, 4>" which would be lower to
// ds_write_b128 // ds_write_b128
// TODO: remove this after compiler fix // TODO: remove this after compiler fix
if constexpr(is_same<typename scalar_type<remove_cv_t<remove_reference_t<T>>>::type, if constexpr(is_same<typename scalar_type<remove_cvref_t<T>>::type, int8_t>::value)
int8_t>::value)
{ {
static_assert( static_assert((is_same<remove_cvref_t<T>, int8_t>::value &&
(is_same<remove_cv_t<remove_reference_t<T>>, int8_t>::value && is_same<remove_cvref_t<X>, int8_t>::value) ||
is_same<remove_cv_t<remove_reference_t<X>>, int8_t>::value) || (is_same<remove_cvref_t<T>, int8_t>::value &&
(is_same<remove_cv_t<remove_reference_t<T>>, int8_t>::value && is_same<remove_cvref_t<X>, int8x2_t>::value) ||
is_same<remove_cv_t<remove_reference_t<X>>, int8x2_t>::value) || (is_same<remove_cvref_t<T>, int8_t>::value &&
(is_same<remove_cv_t<remove_reference_t<T>>, int8_t>::value && is_same<remove_cvref_t<X>, int8x4_t>::value) ||
is_same<remove_cv_t<remove_reference_t<X>>, int8x4_t>::value) || (is_same<remove_cvref_t<T>, int8x4_t>::value &&
(is_same<remove_cv_t<remove_reference_t<T>>, int8x4_t>::value && is_same<remove_cvref_t<X>, int8x4_t>::value) ||
is_same<remove_cv_t<remove_reference_t<X>>, int8x4_t>::value) || (is_same<remove_cvref_t<T>, int8x8_t>::value &&
(is_same<remove_cv_t<remove_reference_t<T>>, int8x8_t>::value && is_same<remove_cvref_t<X>, int8x8_t>::value) ||
is_same<remove_cv_t<remove_reference_t<X>>, int8x8_t>::value) || (is_same<remove_cvref_t<T>, int8x16_t>::value &&
(is_same<remove_cv_t<remove_reference_t<T>>, int8x16_t>::value && is_same<remove_cvref_t<X>, int8x16_t>::value),
is_same<remove_cv_t<remove_reference_t<X>>, int8x16_t>::value), "wrong! not implemented for this combination, please add "
"wrong! not implemented for this combination, please add " "implementation");
"implementation");
if constexpr(is_same<remove_cv_t<remove_reference_t<T>>, int8_t>::value && if constexpr(is_same<remove_cvref_t<T>, int8_t>::value &&
is_same<remove_cv_t<remove_reference_t<X>>, int8_t>::value) is_same<remove_cvref_t<X>, int8_t>::value)
{ {
// HACK: cast pointer of x is bad // HACK: cast pointer of x is bad
// TODO: remove this after compiler fix // TODO: remove this after compiler fix
*c_style_pointer_cast<int8_t*>(&p_data_[i]) = *c_style_pointer_cast<int8_t*>(&p_data_[i]) =
*c_style_pointer_cast<const int8_t*>(&x); *c_style_pointer_cast<const int8_t*>(&x);
} }
else if constexpr(is_same<remove_cv_t<remove_reference_t<T>>, int8_t>::value && else if constexpr(is_same<remove_cvref_t<T>, int8_t>::value &&
is_same<remove_cv_t<remove_reference_t<X>>, int8x2_t>::value) is_same<remove_cvref_t<X>, int8x2_t>::value)
{ {
// HACK: cast pointer of x is bad // HACK: cast pointer of x is bad
// TODO: remove this after compiler fix // TODO: remove this after compiler fix
*c_style_pointer_cast<int16_t*>(&p_data_[i]) = *c_style_pointer_cast<int16_t*>(&p_data_[i]) =
*c_style_pointer_cast<const int16_t*>(&x); *c_style_pointer_cast<const int16_t*>(&x);
} }
else if constexpr(is_same<remove_cv_t<remove_reference_t<T>>, int8_t>::value && else if constexpr(is_same<remove_cvref_t<T>, int8_t>::value &&
is_same<remove_cv_t<remove_reference_t<X>>, int8x4_t>::value) is_same<remove_cvref_t<X>, int8x4_t>::value)
{ {
// HACK: cast pointer of x is bad // HACK: cast pointer of x is bad
// TODO: remove this after compiler fix // TODO: remove this after compiler fix
*c_style_pointer_cast<int32_t*>(&p_data_[i]) = *c_style_pointer_cast<int32_t*>(&p_data_[i]) =
*c_style_pointer_cast<const int32_t*>(&x); *c_style_pointer_cast<const int32_t*>(&x);
} }
else if constexpr(is_same<remove_cv_t<remove_reference_t<T>>, else if constexpr(is_same<remove_cvref_t<T>, int8x4_t>::value &&
int8x4_t>::value && is_same<remove_cvref_t<X>, int8x4_t>::value)
is_same<remove_cv_t<remove_reference_t<X>>, int8x4_t>::value)
{ {
// HACK: cast pointer of x is bad // HACK: cast pointer of x is bad
// TODO: remove this after compiler fix // TODO: remove this after compiler fix
*c_style_pointer_cast<int32_t*>(&p_data_[i]) = *c_style_pointer_cast<int32_t*>(&p_data_[i]) =
*c_style_pointer_cast<const int32_t*>(&x); *c_style_pointer_cast<const int32_t*>(&x);
} }
else if constexpr(is_same<remove_cv_t<remove_reference_t<T>>, else if constexpr(is_same<remove_cvref_t<T>, int8x8_t>::value &&
int8x8_t>::value && is_same<remove_cvref_t<X>, int8x8_t>::value)
is_same<remove_cv_t<remove_reference_t<X>>, int8x8_t>::value)
{ {
// HACK: cast pointer of x is bad // HACK: cast pointer of x is bad
// TODO: remove this after compiler fix // TODO: remove this after compiler fix
*c_style_pointer_cast<int32x2_t*>(&p_data_[i]) = *c_style_pointer_cast<int32x2_t*>(&p_data_[i]) =
*c_style_pointer_cast<const int32x2_t*>(&x); *c_style_pointer_cast<const int32x2_t*>(&x);
} }
else if constexpr(is_same<remove_cv_t<remove_reference_t<T>>, else if constexpr(is_same<remove_cvref_t<T>, int8x16_t>::value &&
int8x16_t>::value && is_same<remove_cvref_t<X>, int8x16_t>::value)
is_same<remove_cv_t<remove_reference_t<X>>, int8x16_t>::value)
{ {
// HACK: cast pointer of x is bad // HACK: cast pointer of x is bad
// TODO: remove this after compiler fix // TODO: remove this after compiler fix
...@@ -224,18 +212,15 @@ struct DynamicBuffer ...@@ -224,18 +212,15 @@ struct DynamicBuffer
} }
template <typename X, template <typename X,
typename enable_if< typename enable_if<is_same<typename scalar_type<remove_cvref_t<X>>::type,
is_same<typename scalar_type<remove_cv_t<remove_reference_t<X>>>::type, typename scalar_type<remove_cvref_t<T>>::type>::value,
typename scalar_type<remove_cv_t<remove_reference_t<T>>>::type>::value, bool>::type = false>
bool>::type = false>
__host__ __device__ void AtomicAdd(index_t i, bool is_valid_element, const X& x) __host__ __device__ void AtomicAdd(index_t i, bool is_valid_element, const X& x)
{ {
// X contains multiple T // X contains multiple T
constexpr index_t scalar_per_t_vector = constexpr index_t scalar_per_t_vector = scalar_type<remove_cvref_t<T>>::vector_size;
scalar_type<remove_cv_t<remove_reference_t<T>>>::vector_size;
constexpr index_t scalar_per_x_vector = constexpr index_t scalar_per_x_vector = scalar_type<remove_cvref_t<X>>::vector_size;
scalar_type<remove_cv_t<remove_reference_t<X>>>::vector_size;
static_assert(scalar_per_x_vector % scalar_per_t_vector == 0, static_assert(scalar_per_x_vector % scalar_per_t_vector == 0,
"wrong! X need to be multiple T"); "wrong! X need to be multiple T");
...@@ -245,7 +230,7 @@ struct DynamicBuffer ...@@ -245,7 +230,7 @@ struct DynamicBuffer
#if CK_USE_AMD_BUFFER_ADDRESSING #if CK_USE_AMD_BUFFER_ADDRESSING
constexpr index_t t_per_x = scalar_per_x_vector / scalar_per_t_vector; constexpr index_t t_per_x = scalar_per_x_vector / scalar_per_t_vector;
amd_buffer_atomic_add<remove_cv_t<remove_reference_t<T>>, t_per_x>( amd_buffer_atomic_add<remove_cvref_t<T>, t_per_x>(
x, p_data_, i, is_valid_element, element_space_size_); x, p_data_, i, is_valid_element, element_space_size_);
#else #else
if(is_valid_element) if(is_valid_element)
......
...@@ -159,7 +159,7 @@ struct Tuple : detail::TupleImpl<typename arithmetic_sequence_gen<0, sizeof...(X ...@@ -159,7 +159,7 @@ struct Tuple : detail::TupleImpl<typename arithmetic_sequence_gen<0, sizeof...(X
template <typename... Xs> template <typename... Xs>
__host__ __device__ constexpr auto make_tuple(Xs&&... xs) __host__ __device__ constexpr auto make_tuple(Xs&&... xs)
{ {
return Tuple<remove_cv_t<remove_reference_t<Xs>>...>(std::forward<Xs>(xs)...); return Tuple<remove_cvref_t<Xs>...>(std::forward<Xs>(xs)...);
} }
} // namespace ck } // namespace ck
......
...@@ -14,9 +14,7 @@ struct is_known_at_compile_time<Tuple<Ts...>> ...@@ -14,9 +14,7 @@ struct is_known_at_compile_time<Tuple<Ts...>>
return container_reduce( return container_reduce(
Tuple<Ts...>{}, Tuple<Ts...>{},
[](auto x, bool r) { [](auto x, bool r) {
return is_known_at_compile_time< return is_known_at_compile_time<remove_cvref_t<decltype(x)>>::value & r;
remove_cv_t<remove_reference_t<decltype(x)>>>::value &
r;
}, },
true); true);
} }
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment