Commit 16effa76 authored by Chao Liu's avatar Chao Liu
Browse files

refactor

parent a91b68df
...@@ -43,6 +43,7 @@ message(STATUS "Build with HIP ${hip_VERSION}") ...@@ -43,6 +43,7 @@ message(STATUS "Build with HIP ${hip_VERSION}")
message("HALF_INCLUDE_DIR: ${HALF_INCLUDE_DIR}") message("HALF_INCLUDE_DIR: ${HALF_INCLUDE_DIR}")
# CMAKE_CXX_FLAGS # CMAKE_CXX_FLAGS
SET(BUILD_DEV ON CACHE BOOL "BUILD_DEV")
if(BUILD_DEV) if(BUILD_DEV)
string(APPEND CMAKE_CXX_FLAGS " -Werror -Weverything") string(APPEND CMAKE_CXX_FLAGS " -Werror -Weverything")
endif() endif()
......
...@@ -377,7 +377,7 @@ struct RightPad ...@@ -377,7 +377,7 @@ struct RightPad
// at compile-time // at compile-time
template <typename UpLengths, template <typename UpLengths,
typename Coefficients, typename Coefficients,
typename std::enable_if<UpLengths::Size() == Coefficients::Size(), bool>::type = false> typename enable_if<UpLengths::Size() == Coefficients::Size(), bool>::type = false>
struct Embed struct Embed
{ {
static constexpr index_t NDimUp = UpLengths::Size(); static constexpr index_t NDimUp = UpLengths::Size();
......
...@@ -42,7 +42,7 @@ __host__ __device__ constexpr auto make_right_pad_transform( ...@@ -42,7 +42,7 @@ __host__ __device__ constexpr auto make_right_pad_transform(
template <typename UpLengths, template <typename UpLengths,
typename Coefficients, typename Coefficients,
typename std::enable_if<UpLengths::Size() == Coefficients::Size(), bool>::type = false> typename enable_if<UpLengths::Size() == Coefficients::Size(), bool>::type = false>
__host__ __device__ constexpr auto make_embed_transform(const UpLengths& up_lengths, __host__ __device__ constexpr auto make_embed_transform(const UpLengths& up_lengths,
const Coefficients& coefficients) const Coefficients& coefficients)
{ {
......
...@@ -454,9 +454,7 @@ __host__ __device__ constexpr auto make_single_stage_tensor_adaptor(const Transf ...@@ -454,9 +454,7 @@ __host__ __device__ constexpr auto make_single_stage_tensor_adaptor(const Transf
remove_cv_t<decltype(top_dim_hidden_ids)>>{transforms}; remove_cv_t<decltype(top_dim_hidden_ids)>>{transforms};
} }
template <typename X, template <typename X, typename... Xs, typename enable_if<sizeof...(Xs) >= 2, bool>::type = false>
typename... Xs,
typename std::enable_if<sizeof...(Xs) >= 2, bool>::type = false>
__host__ __device__ constexpr auto chain_tensor_adaptors(const X& x, const Xs&... xs) __host__ __device__ constexpr auto chain_tensor_adaptors(const X& x, const Xs&... xs)
{ {
return chain_tensor_adaptors(x, chain_tensor_adaptors(xs...)); return chain_tensor_adaptors(x, chain_tensor_adaptors(xs...));
......
...@@ -37,7 +37,7 @@ __host__ __device__ constexpr auto calculate_element_space_size_impl(const Lengt ...@@ -37,7 +37,7 @@ __host__ __device__ constexpr auto calculate_element_space_size_impl(const Lengt
template <typename... Lengths, template <typename... Lengths,
typename... Strides, typename... Strides,
typename std::enable_if<sizeof...(Lengths) == sizeof...(Strides), bool>::type = false> typename enable_if<sizeof...(Lengths) == sizeof...(Strides), bool>::type = false>
__host__ __device__ constexpr auto make_naive_tensor_descriptor_v2(const Tuple<Lengths...>& lengths, __host__ __device__ constexpr auto make_naive_tensor_descriptor_v2(const Tuple<Lengths...>& lengths,
const Tuple<Strides...>& strides) const Tuple<Strides...>& strides)
{ {
......
...@@ -22,24 +22,24 @@ namespace ck { ...@@ -22,24 +22,24 @@ namespace ck {
// 2. CThreadBuffer is StaticBuffer // 2. CThreadBuffer is StaticBuffer
// Also assume: // Also assume:
// M0 = N0 = 2. It will do 2x2 pipelined read and fma (ABBA optimization) // M0 = N0 = 2. It will do 2x2 pipelined read and fma (ABBA optimization)
template <index_t BlockSize, template <
typename FloatA, index_t BlockSize,
typename FloatB, typename FloatA,
typename FloatC, typename FloatB,
typename AKMBlockDesc, typename FloatC,
typename BKNBlockDesc, typename AKMBlockDesc,
index_t M1PerThreadM11, typename BKNBlockDesc,
index_t N1PerThreadN11, index_t M1PerThreadM11,
index_t KPerThread, index_t N1PerThreadN11,
index_t M1N1ThreadClusterM100, index_t KPerThread,
index_t M1N1ThreadClusterN100, index_t M1N1ThreadClusterM100,
index_t M1N1ThreadClusterM101, index_t M1N1ThreadClusterN100,
index_t M1N1ThreadClusterN101, index_t M1N1ThreadClusterM101,
index_t AThreadCopyScalarPerVector_M11, index_t M1N1ThreadClusterN101,
index_t BThreadCopyScalarPerVector_N11, index_t AThreadCopyScalarPerVector_M11,
typename std::enable_if<AKMBlockDesc::IsKnownAtCompileTime() && index_t BThreadCopyScalarPerVector_N11,
BKNBlockDesc::IsKnownAtCompileTime(), typename enable_if<AKMBlockDesc::IsKnownAtCompileTime() && BKNBlockDesc::IsKnownAtCompileTime(),
bool>::type = false> bool>::type = false>
struct BlockwiseGemmDlops_km_kn_m0m1n0n1_v2r2_pipeline_2x2 struct BlockwiseGemmDlops_km_kn_m0m1n0n1_v2r2_pipeline_2x2
{ {
using AIndex = MultiIndex<3>; using AIndex = MultiIndex<3>;
......
...@@ -38,9 +38,9 @@ template <index_t BlockSize, ...@@ -38,9 +38,9 @@ template <index_t BlockSize,
// BM10BN10ThreadClusterBN101, ...> // BM10BN10ThreadClusterBN101, ...>
index_t AThreadCopyScalarPerVector_BM11, index_t AThreadCopyScalarPerVector_BM11,
index_t BThreadCopyScalarPerVector_BN11, index_t BThreadCopyScalarPerVector_BN11,
typename std::enable_if<ABlockDesc_BK0_BM_BK1::IsKnownAtCompileTime() && typename enable_if<ABlockDesc_BK0_BM_BK1::IsKnownAtCompileTime() &&
BBlockDesc_BK0_BN_BK1::IsKnownAtCompileTime(), BBlockDesc_BK0_BN_BK1::IsKnownAtCompileTime(),
bool>::type = false> bool>::type = false>
struct BlockwiseGemmDlops_A_BK0_BM_BK1_B_BK0_BN_BK1_C_BM0_BM1_BN0_BN1_pipeline_BM0_2_BN0_2 struct BlockwiseGemmDlops_A_BK0_BM_BK1_B_BK0_BN_BK1_C_BM0_BM1_BN0_BN1_pipeline_BM0_2_BN0_2
{ {
using AIndex = MultiIndex<3>; using AIndex = MultiIndex<3>;
......
...@@ -21,10 +21,10 @@ template <typename FloatA, ...@@ -21,10 +21,10 @@ template <typename FloatA,
typename TKLengths, typename TKLengths,
typename TMLengths, typename TMLengths,
typename TNLengths, typename TNLengths,
typename std::enable_if<AThreadDesc_TK0_TM0_TM1_TK1::IsKnownAtCompileTime() && typename enable_if<AThreadDesc_TK0_TM0_TM1_TK1::IsKnownAtCompileTime() &&
BThreadDesc_TK0_TN0_TN1_TK1::IsKnownAtCompileTime() && BThreadDesc_TK0_TN0_TN1_TK1::IsKnownAtCompileTime() &&
CThreadDesc_TM0_TM1_TN0_TN1::IsKnownAtCompileTime(), CThreadDesc_TM0_TM1_TN0_TN1::IsKnownAtCompileTime(),
bool>::type = false> bool>::type = false>
struct ThreadwiseGemmDlops_km0m1_kn0n1_m0m1n0n1 struct ThreadwiseGemmDlops_km0m1_kn0n1_m0m1n0n1
{ {
__device__ constexpr ThreadwiseGemmDlops_km0m1_kn0n1_m0m1n0n1() __device__ constexpr ThreadwiseGemmDlops_km0m1_kn0n1_m0m1n0n1()
...@@ -123,10 +123,10 @@ template <typename FloatA, ...@@ -123,10 +123,10 @@ template <typename FloatA,
typename TKLengths, typename TKLengths,
typename TMLengths, typename TMLengths,
typename TNLengths, typename TNLengths,
typename std::enable_if<AThreadDesc_TK0_TM0_TM1_TK1::IsKnownAtCompileTime() && typename enable_if<AThreadDesc_TK0_TM0_TM1_TK1::IsKnownAtCompileTime() &&
BThreadDesc_TK0_TN0_TN1_TK1::IsKnownAtCompileTime() && BThreadDesc_TK0_TN0_TN1_TK1::IsKnownAtCompileTime() &&
CThreadDesc_TM0_TM1_TN0_TN1::IsKnownAtCompileTime(), CThreadDesc_TM0_TM1_TN0_TN1::IsKnownAtCompileTime(),
bool>::type = false> bool>::type = false>
struct ThreadwiseContractionDlops_A_TK0_TM0_TM1_TK1_B_TK0_TN0_TN1_TK1_C_TM0_TM1_TN0_TN1 struct ThreadwiseContractionDlops_A_TK0_TM0_TM1_TK1_B_TK0_TN0_TN1_TK1_C_TM0_TM1_TN0_TN1
{ {
__device__ constexpr ThreadwiseContractionDlops_A_TK0_TM0_TM1_TK1_B_TK0_TN0_TN1_TK1_C_TM0_TM1_TN0_TN1() __device__ constexpr ThreadwiseContractionDlops_A_TK0_TM0_TM1_TK1_B_TK0_TN0_TN1_TK1_C_TM0_TM1_TN0_TN1()
......
...@@ -19,9 +19,9 @@ template <typename FloatA, ...@@ -19,9 +19,9 @@ template <typename FloatA,
typename CDesc, typename CDesc,
index_t H, index_t H,
index_t W, index_t W,
typename std::enable_if<ADesc::IsKnownAtCompileTime() && BDesc::IsKnownAtCompileTime() && typename enable_if<ADesc::IsKnownAtCompileTime() && BDesc::IsKnownAtCompileTime() &&
CDesc::IsKnownAtCompileTime(), CDesc::IsKnownAtCompileTime(),
bool>::type = false> bool>::type = false>
struct ThreadwiseGemmDlops_km_kn_mn_v3 struct ThreadwiseGemmDlops_km_kn_mn_v3
{ {
template <typename ABuffer, template <typename ABuffer,
......
...@@ -15,7 +15,7 @@ namespace ck { ...@@ -15,7 +15,7 @@ namespace ck {
template <typename Data, template <typename Data,
typename Desc, typename Desc,
typename SliceLengths, typename SliceLengths,
typename std::enable_if<Desc::IsKnownAtCompileTime(), bool>::type = false> typename enable_if<Desc::IsKnownAtCompileTime(), bool>::type = false>
struct ThreadwiseTensorSliceSet_v1 struct ThreadwiseTensorSliceSet_v1
{ {
static constexpr index_t nDim = SliceLengths::Size(); static constexpr index_t nDim = SliceLengths::Size();
......
...@@ -57,7 +57,7 @@ template <typename SrcData, ...@@ -57,7 +57,7 @@ template <typename SrcData,
InMemoryDataOperationEnum_t DstInMemOp, InMemoryDataOperationEnum_t DstInMemOp,
index_t DstScalarStrideInVector, index_t DstScalarStrideInVector,
bool DstResetCoordinateAfterRun, bool DstResetCoordinateAfterRun,
typename std::enable_if<SrcDesc::IsKnownAtCompileTime(), bool>::type = false> typename enable_if<SrcDesc::IsKnownAtCompileTime(), bool>::type = false>
struct ThreadwiseTensorSliceTransfer_v1r3 struct ThreadwiseTensorSliceTransfer_v1r3
{ {
static constexpr index_t nDim = SliceLengths::Size(); static constexpr index_t nDim = SliceLengths::Size();
...@@ -373,7 +373,7 @@ template <typename SrcData, ...@@ -373,7 +373,7 @@ template <typename SrcData,
index_t SrcScalarPerVector, index_t SrcScalarPerVector,
index_t SrcScalarStrideInVector, index_t SrcScalarStrideInVector,
bool SrcResetCoordinateAfterRun, bool SrcResetCoordinateAfterRun,
typename std::enable_if<DstDesc::IsKnownAtCompileTime(), bool>::type = false> typename enable_if<DstDesc::IsKnownAtCompileTime(), bool>::type = false>
struct ThreadwiseTensorSliceTransfer_v2 struct ThreadwiseTensorSliceTransfer_v2
{ {
static constexpr index_t nDim = SliceLengths::Size(); static constexpr index_t nDim = SliceLengths::Size();
...@@ -1261,18 +1261,17 @@ struct ThreadwiseTensorSliceTransfer_v3 ...@@ -1261,18 +1261,17 @@ struct ThreadwiseTensorSliceTransfer_v3
// 3. DstOriginIdx is known at compile-time // 3. DstOriginIdx is known at compile-time
// 4. use direct address calculation // 4. use direct address calculation
// 3. vector access on src // 3. vector access on src
template < template <typename SrcData,
typename SrcData, typename DstData,
typename DstData, typename SrcDesc,
typename SrcDesc, typename DstDesc,
typename DstDesc, typename SliceLengths,
typename SliceLengths, typename DimAccessOrder,
typename DimAccessOrder, index_t SrcVectorDim,
index_t SrcVectorDim, index_t SrcScalarPerVector,
index_t SrcScalarPerVector, index_t SrcScalarStrideInVector,
index_t SrcScalarStrideInVector, typename enable_if<SrcDesc::IsKnownAtCompileTime() && DstDesc::IsKnownAtCompileTime(),
typename std::enable_if<SrcDesc::IsKnownAtCompileTime() && DstDesc::IsKnownAtCompileTime(), bool>::type = false>
bool>::type = false>
struct ThreadwiseTensorSliceTransfer_v4 struct ThreadwiseTensorSliceTransfer_v4
{ {
static constexpr index_t nDim = SliceLengths::Size(); static constexpr index_t nDim = SliceLengths::Size();
......
...@@ -621,17 +621,16 @@ struct ThreadwiseTensorSliceTransfer_v3r1 ...@@ -621,17 +621,16 @@ struct ThreadwiseTensorSliceTransfer_v3r1
// 3. DstOriginIdx is known at compile-time // 3. DstOriginIdx is known at compile-time
// 4. use direct address calculation // 4. use direct address calculation
// 3. vector access on src // 3. vector access on src
template < template <typename SrcData,
typename SrcData, typename DstData,
typename DstData, typename SrcDesc,
typename SrcDesc, typename DstDesc,
typename DstDesc, typename SliceLengths,
typename SliceLengths, typename DimAccessOrder,
typename DimAccessOrder, typename SrcVectorTensorLengths,
typename SrcVectorTensorLengths, typename SrcVectorTensorContiguousDimOrder,
typename SrcVectorTensorContiguousDimOrder, typename enable_if<SrcDesc::IsKnownAtCompileTime() && DstDesc::IsKnownAtCompileTime(),
typename std::enable_if<SrcDesc::IsKnownAtCompileTime() && DstDesc::IsKnownAtCompileTime(), bool>::type = false>
bool>::type = false>
struct ThreadwiseTensorSliceTransfer_v4r1 struct ThreadwiseTensorSliceTransfer_v4r1
{ {
static constexpr auto I0 = Number<0>{}; static constexpr auto I0 = Number<0>{};
......
...@@ -2,12 +2,13 @@ ...@@ -2,12 +2,13 @@
#define CK_C_STYLE_POINTER_CAST_HPP #define CK_C_STYLE_POINTER_CAST_HPP
#include "type.hpp" #include "type.hpp"
#include "enable_if.hpp"
namespace ck { namespace ck {
template <typename PY, template <typename PY,
typename PX, typename PX,
typename std::enable_if<is_pointer_v<PY> && is_pointer_v<PX>, bool>::type = false> typename enable_if<is_pointer_v<PY> && is_pointer_v<PX>, bool>::type = false>
__host__ __device__ PY c_style_pointer_cast(PX p_x) __host__ __device__ PY c_style_pointer_cast(PX p_x)
{ {
#pragma clang diagnostic push #pragma clang diagnostic push
......
...@@ -14,6 +14,7 @@ ...@@ -14,6 +14,7 @@
#include "functional2.hpp" #include "functional2.hpp"
#include "functional3.hpp" #include "functional3.hpp"
#include "functional4.hpp" #include "functional4.hpp"
#include "enable_if.hpp"
#include "integral_constant.hpp" #include "integral_constant.hpp"
#include "math.hpp" #include "math.hpp"
#include "number.hpp" #include "number.hpp"
......
...@@ -3,6 +3,7 @@ ...@@ -3,6 +3,7 @@
#include "amd_buffer_addressing.hpp" #include "amd_buffer_addressing.hpp"
#include "c_style_pointer_cast.hpp" #include "c_style_pointer_cast.hpp"
#include "enable_if.hpp"
namespace ck { namespace ck {
...@@ -38,7 +39,7 @@ struct DynamicBuffer ...@@ -38,7 +39,7 @@ struct DynamicBuffer
} }
template <typename X, template <typename X,
typename std::enable_if< typename enable_if<
is_same<typename scalar_type<remove_cv_t<remove_reference_t<X>>>::type, is_same<typename scalar_type<remove_cv_t<remove_reference_t<X>>>::type,
typename scalar_type<remove_cv_t<remove_reference_t<T>>>::type>::value, typename scalar_type<remove_cv_t<remove_reference_t<T>>>::type>::value,
bool>::type = false> bool>::type = false>
...@@ -93,7 +94,7 @@ struct DynamicBuffer ...@@ -93,7 +94,7 @@ struct DynamicBuffer
} }
template <typename X, template <typename X,
typename std::enable_if< typename enable_if<
is_same<typename scalar_type<remove_cv_t<remove_reference_t<X>>>::type, is_same<typename scalar_type<remove_cv_t<remove_reference_t<X>>>::type,
typename scalar_type<remove_cv_t<remove_reference_t<T>>>::type>::value, typename scalar_type<remove_cv_t<remove_reference_t<T>>>::type>::value,
bool>::type = false> bool>::type = false>
......
#ifndef CK_ENABLE_IF_HPP
#define CK_ENABLE_IF_HPP
namespace ck {
template <bool B, typename T = void>
using enable_if = std::enable_if<B, T>;
template <bool B, typename T = void>
using enable_if_t = typename std::enable_if<B, T>::type;
} // namespace ck
#endif
...@@ -5,6 +5,7 @@ ...@@ -5,6 +5,7 @@
#include "integral_constant.hpp" #include "integral_constant.hpp"
#include "number.hpp" #include "number.hpp"
#include "type.hpp" #include "type.hpp"
#include "enable_if.hpp"
namespace ck { namespace ck {
namespace math { namespace math {
...@@ -184,9 +185,7 @@ __host__ __device__ constexpr auto gcd(Number<X>, Number<Y>) ...@@ -184,9 +185,7 @@ __host__ __device__ constexpr auto gcd(Number<X>, Number<Y>)
return Number<r>{}; return Number<r>{};
} }
template <typename X, template <typename X, typename... Ys, typename enable_if<sizeof...(Ys) >= 2, bool>::type = false>
typename... Ys,
typename std::enable_if<sizeof...(Ys) >= 2, bool>::type = false>
__host__ __device__ constexpr auto gcd(X x, Ys... ys) __host__ __device__ constexpr auto gcd(X x, Ys... ys)
{ {
return gcd(x, gcd(ys...)); return gcd(x, gcd(ys...));
...@@ -199,9 +198,7 @@ __host__ __device__ constexpr auto lcm(X x, Y y) ...@@ -199,9 +198,7 @@ __host__ __device__ constexpr auto lcm(X x, Y y)
return (x * y) / gcd(x, y); return (x * y) / gcd(x, y);
} }
template <typename X, template <typename X, typename... Ys, typename enable_if<sizeof...(Ys) >= 2, bool>::type = false>
typename... Ys,
typename std::enable_if<sizeof...(Ys) >= 2, bool>::type = false>
__host__ __device__ constexpr auto lcm(X x, Ys... ys) __host__ __device__ constexpr auto lcm(X x, Ys... ys)
{ {
return lcm(x, lcm(ys...)); return lcm(x, lcm(ys...));
......
...@@ -4,6 +4,7 @@ ...@@ -4,6 +4,7 @@
#include "integral_constant.hpp" #include "integral_constant.hpp"
#include "sequence.hpp" #include "sequence.hpp"
#include "type.hpp" #include "type.hpp"
#include "enable_if.hpp"
namespace ck { namespace ck {
...@@ -20,10 +21,9 @@ struct TupleElement ...@@ -20,10 +21,9 @@ struct TupleElement
{ {
__host__ __device__ constexpr TupleElement() = default; __host__ __device__ constexpr TupleElement() = default;
template < template <typename T,
typename T, typename enable_if<!is_same<remove_reference_t<remove_cv_t<T>>, TupleElement>::value,
typename std::enable_if<!is_same<remove_reference_t<remove_cv_t<T>>, TupleElement>::value, bool>::type = false>
bool>::type = false>
__host__ __device__ constexpr TupleElement(T&& v) : mData(std::forward<T>(v)) __host__ __device__ constexpr TupleElement(T&& v) : mData(std::forward<T>(v))
{ {
} }
...@@ -58,17 +58,16 @@ struct TupleImpl<Sequence<Is...>, Xs...> : TupleElement<TupleElementKey<Is>, Xs> ...@@ -58,17 +58,16 @@ struct TupleImpl<Sequence<Is...>, Xs...> : TupleElement<TupleElementKey<Is>, Xs>
{ {
__host__ __device__ constexpr TupleImpl() = default; __host__ __device__ constexpr TupleImpl() = default;
template < template <typename Y,
typename Y, typename enable_if<sizeof...(Is) == 1 && sizeof...(Xs) == 1 &&
typename std::enable_if<sizeof...(Is) == 1 && sizeof...(Xs) == 1 && !is_same<remove_reference_t<remove_cv_t<Y>>, TupleImpl>::value,
!is_same<remove_reference_t<remove_cv_t<Y>>, TupleImpl>::value, bool>::type = false>
bool>::type = false>
__host__ __device__ constexpr TupleImpl(Y&& y) __host__ __device__ constexpr TupleImpl(Y&& y)
: TupleElement<TupleElementKey<Is>, Xs>(std::forward<Y>(y))... : TupleElement<TupleElementKey<Is>, Xs>(std::forward<Y>(y))...
{ {
} }
template <typename... Ys, typename std::enable_if<sizeof...(Ys) >= 2, bool>::type = false> template <typename... Ys, typename enable_if<sizeof...(Ys) >= 2, bool>::type = false>
__host__ __device__ constexpr TupleImpl(Ys&&... ys) __host__ __device__ constexpr TupleImpl(Ys&&... ys)
: TupleElement<TupleElementKey<Is>, Xs>(std::forward<Ys>(ys))... : TupleElement<TupleElementKey<Is>, Xs>(std::forward<Ys>(ys))...
{ {
...@@ -102,16 +101,16 @@ struct Tuple : detail::TupleImpl<typename arithmetic_sequence_gen<0, sizeof...(X ...@@ -102,16 +101,16 @@ struct Tuple : detail::TupleImpl<typename arithmetic_sequence_gen<0, sizeof...(X
__host__ __device__ constexpr Tuple() = default; __host__ __device__ constexpr Tuple() = default;
template <typename Y, template <typename Y,
typename std::enable_if< typename enable_if<sizeof...(Xs) == 1 &&
sizeof...(Xs) == 1 && !is_same<remove_reference_t<remove_cv_t<Y>>, Tuple>::value, !is_same<remove_reference_t<remove_cv_t<Y>>, Tuple>::value,
bool>::type = false> bool>::type = false>
__host__ __device__ constexpr Tuple(Y&& y) : base(std::forward<Y>(y)) __host__ __device__ constexpr Tuple(Y&& y) : base(std::forward<Y>(y))
{ {
} }
template <typename... Ys, template <typename... Ys,
typename std::enable_if<sizeof...(Ys) == sizeof...(Xs) && sizeof...(Ys) >= 2, typename enable_if<sizeof...(Ys) == sizeof...(Xs) && sizeof...(Ys) >= 2, bool>::type =
bool>::type = false> false>
__host__ __device__ constexpr Tuple(Ys&&... ys) : base(std::forward<Ys>(ys)...) __host__ __device__ constexpr Tuple(Ys&&... ys) : base(std::forward<Ys>(ys)...)
{ {
} }
......
...@@ -2,6 +2,7 @@ ...@@ -2,6 +2,7 @@
#define CK_TYPE_HPP #define CK_TYPE_HPP
#include "integral_constant.hpp" #include "integral_constant.hpp"
#include "enable_if.hpp"
namespace ck { namespace ck {
...@@ -39,9 +40,7 @@ struct is_known_at_compile_time<integral_constant<T, X>> ...@@ -39,9 +40,7 @@ struct is_known_at_compile_time<integral_constant<T, X>>
static constexpr bool value = true; static constexpr bool value = true;
}; };
template <typename Y, template <typename Y, typename X, typename enable_if<sizeof(X) == sizeof(Y), bool>::type = false>
typename X,
typename std::enable_if<sizeof(X) == sizeof(Y), bool>::type = false>
__host__ __device__ constexpr Y as_type(X x) __host__ __device__ constexpr Y as_type(X x)
{ {
union AsType union AsType
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment