Commit 16effa76 authored by Chao Liu
Browse files

refactor

parent a91b68df
...@@ -43,6 +43,7 @@ message(STATUS "Build with HIP ${hip_VERSION}") ...@@ -43,6 +43,7 @@ message(STATUS "Build with HIP ${hip_VERSION}")
message("HALF_INCLUDE_DIR: ${HALF_INCLUDE_DIR}") message("HALF_INCLUDE_DIR: ${HALF_INCLUDE_DIR}")
# CMAKE_CXX_FLAGS # CMAKE_CXX_FLAGS
SET(BUILD_DEV ON CACHE BOOL "BUILD_DEV")
if(BUILD_DEV) if(BUILD_DEV)
string(APPEND CMAKE_CXX_FLAGS " -Werror -Weverything") string(APPEND CMAKE_CXX_FLAGS " -Werror -Weverything")
endif() endif()
......
...@@ -377,7 +377,7 @@ struct RightPad ...@@ -377,7 +377,7 @@ struct RightPad
// at compile-time // at compile-time
template <typename UpLengths, template <typename UpLengths,
typename Coefficients, typename Coefficients,
typename std::enable_if<UpLengths::Size() == Coefficients::Size(), bool>::type = false> typename enable_if<UpLengths::Size() == Coefficients::Size(), bool>::type = false>
struct Embed struct Embed
{ {
static constexpr index_t NDimUp = UpLengths::Size(); static constexpr index_t NDimUp = UpLengths::Size();
......
...@@ -42,7 +42,7 @@ __host__ __device__ constexpr auto make_right_pad_transform( ...@@ -42,7 +42,7 @@ __host__ __device__ constexpr auto make_right_pad_transform(
template <typename UpLengths, template <typename UpLengths,
typename Coefficients, typename Coefficients,
typename std::enable_if<UpLengths::Size() == Coefficients::Size(), bool>::type = false> typename enable_if<UpLengths::Size() == Coefficients::Size(), bool>::type = false>
__host__ __device__ constexpr auto make_embed_transform(const UpLengths& up_lengths, __host__ __device__ constexpr auto make_embed_transform(const UpLengths& up_lengths,
const Coefficients& coefficients) const Coefficients& coefficients)
{ {
......
...@@ -454,9 +454,7 @@ __host__ __device__ constexpr auto make_single_stage_tensor_adaptor(const Transf ...@@ -454,9 +454,7 @@ __host__ __device__ constexpr auto make_single_stage_tensor_adaptor(const Transf
remove_cv_t<decltype(top_dim_hidden_ids)>>{transforms}; remove_cv_t<decltype(top_dim_hidden_ids)>>{transforms};
} }
template <typename X, template <typename X, typename... Xs, typename enable_if<sizeof...(Xs) >= 2, bool>::type = false>
typename... Xs,
typename std::enable_if<sizeof...(Xs) >= 2, bool>::type = false>
__host__ __device__ constexpr auto chain_tensor_adaptors(const X& x, const Xs&... xs) __host__ __device__ constexpr auto chain_tensor_adaptors(const X& x, const Xs&... xs)
{ {
return chain_tensor_adaptors(x, chain_tensor_adaptors(xs...)); return chain_tensor_adaptors(x, chain_tensor_adaptors(xs...));
......
...@@ -37,7 +37,7 @@ __host__ __device__ constexpr auto calculate_element_space_size_impl(const Lengt ...@@ -37,7 +37,7 @@ __host__ __device__ constexpr auto calculate_element_space_size_impl(const Lengt
template <typename... Lengths, template <typename... Lengths,
typename... Strides, typename... Strides,
typename std::enable_if<sizeof...(Lengths) == sizeof...(Strides), bool>::type = false> typename enable_if<sizeof...(Lengths) == sizeof...(Strides), bool>::type = false>
__host__ __device__ constexpr auto make_naive_tensor_descriptor_v2(const Tuple<Lengths...>& lengths, __host__ __device__ constexpr auto make_naive_tensor_descriptor_v2(const Tuple<Lengths...>& lengths,
const Tuple<Strides...>& strides) const Tuple<Strides...>& strides)
{ {
......
...@@ -22,7 +22,8 @@ namespace ck { ...@@ -22,7 +22,8 @@ namespace ck {
// 2. CThreadBuffer is StaticBuffer // 2. CThreadBuffer is StaticBuffer
// Also assume: // Also assume:
// M0 = N0 = 2. It will do 2x2 pipelined read and fma (ABBA optimization) // M0 = N0 = 2. It will do 2x2 pipelined read and fma (ABBA optimization)
template <index_t BlockSize, template <
index_t BlockSize,
typename FloatA, typename FloatA,
typename FloatB, typename FloatB,
typename FloatC, typename FloatC,
...@@ -37,8 +38,7 @@ template <index_t BlockSize, ...@@ -37,8 +38,7 @@ template <index_t BlockSize,
index_t M1N1ThreadClusterN101, index_t M1N1ThreadClusterN101,
index_t AThreadCopyScalarPerVector_M11, index_t AThreadCopyScalarPerVector_M11,
index_t BThreadCopyScalarPerVector_N11, index_t BThreadCopyScalarPerVector_N11,
typename std::enable_if<AKMBlockDesc::IsKnownAtCompileTime() && typename enable_if<AKMBlockDesc::IsKnownAtCompileTime() && BKNBlockDesc::IsKnownAtCompileTime(),
BKNBlockDesc::IsKnownAtCompileTime(),
bool>::type = false> bool>::type = false>
struct BlockwiseGemmDlops_km_kn_m0m1n0n1_v2r2_pipeline_2x2 struct BlockwiseGemmDlops_km_kn_m0m1n0n1_v2r2_pipeline_2x2
{ {
......
...@@ -38,7 +38,7 @@ template <index_t BlockSize, ...@@ -38,7 +38,7 @@ template <index_t BlockSize,
// BM10BN10ThreadClusterBN101, ...> // BM10BN10ThreadClusterBN101, ...>
index_t AThreadCopyScalarPerVector_BM11, index_t AThreadCopyScalarPerVector_BM11,
index_t BThreadCopyScalarPerVector_BN11, index_t BThreadCopyScalarPerVector_BN11,
typename std::enable_if<ABlockDesc_BK0_BM_BK1::IsKnownAtCompileTime() && typename enable_if<ABlockDesc_BK0_BM_BK1::IsKnownAtCompileTime() &&
BBlockDesc_BK0_BN_BK1::IsKnownAtCompileTime(), BBlockDesc_BK0_BN_BK1::IsKnownAtCompileTime(),
bool>::type = false> bool>::type = false>
struct BlockwiseGemmDlops_A_BK0_BM_BK1_B_BK0_BN_BK1_C_BM0_BM1_BN0_BN1_pipeline_BM0_2_BN0_2 struct BlockwiseGemmDlops_A_BK0_BM_BK1_B_BK0_BN_BK1_C_BM0_BM1_BN0_BN1_pipeline_BM0_2_BN0_2
......
...@@ -21,7 +21,7 @@ template <typename FloatA, ...@@ -21,7 +21,7 @@ template <typename FloatA,
typename TKLengths, typename TKLengths,
typename TMLengths, typename TMLengths,
typename TNLengths, typename TNLengths,
typename std::enable_if<AThreadDesc_TK0_TM0_TM1_TK1::IsKnownAtCompileTime() && typename enable_if<AThreadDesc_TK0_TM0_TM1_TK1::IsKnownAtCompileTime() &&
BThreadDesc_TK0_TN0_TN1_TK1::IsKnownAtCompileTime() && BThreadDesc_TK0_TN0_TN1_TK1::IsKnownAtCompileTime() &&
CThreadDesc_TM0_TM1_TN0_TN1::IsKnownAtCompileTime(), CThreadDesc_TM0_TM1_TN0_TN1::IsKnownAtCompileTime(),
bool>::type = false> bool>::type = false>
...@@ -123,7 +123,7 @@ template <typename FloatA, ...@@ -123,7 +123,7 @@ template <typename FloatA,
typename TKLengths, typename TKLengths,
typename TMLengths, typename TMLengths,
typename TNLengths, typename TNLengths,
typename std::enable_if<AThreadDesc_TK0_TM0_TM1_TK1::IsKnownAtCompileTime() && typename enable_if<AThreadDesc_TK0_TM0_TM1_TK1::IsKnownAtCompileTime() &&
BThreadDesc_TK0_TN0_TN1_TK1::IsKnownAtCompileTime() && BThreadDesc_TK0_TN0_TN1_TK1::IsKnownAtCompileTime() &&
CThreadDesc_TM0_TM1_TN0_TN1::IsKnownAtCompileTime(), CThreadDesc_TM0_TM1_TN0_TN1::IsKnownAtCompileTime(),
bool>::type = false> bool>::type = false>
......
...@@ -19,7 +19,7 @@ template <typename FloatA, ...@@ -19,7 +19,7 @@ template <typename FloatA,
typename CDesc, typename CDesc,
index_t H, index_t H,
index_t W, index_t W,
typename std::enable_if<ADesc::IsKnownAtCompileTime() && BDesc::IsKnownAtCompileTime() && typename enable_if<ADesc::IsKnownAtCompileTime() && BDesc::IsKnownAtCompileTime() &&
CDesc::IsKnownAtCompileTime(), CDesc::IsKnownAtCompileTime(),
bool>::type = false> bool>::type = false>
struct ThreadwiseGemmDlops_km_kn_mn_v3 struct ThreadwiseGemmDlops_km_kn_mn_v3
......
...@@ -15,7 +15,7 @@ namespace ck { ...@@ -15,7 +15,7 @@ namespace ck {
template <typename Data, template <typename Data,
typename Desc, typename Desc,
typename SliceLengths, typename SliceLengths,
typename std::enable_if<Desc::IsKnownAtCompileTime(), bool>::type = false> typename enable_if<Desc::IsKnownAtCompileTime(), bool>::type = false>
struct ThreadwiseTensorSliceSet_v1 struct ThreadwiseTensorSliceSet_v1
{ {
static constexpr index_t nDim = SliceLengths::Size(); static constexpr index_t nDim = SliceLengths::Size();
......
...@@ -57,7 +57,7 @@ template <typename SrcData, ...@@ -57,7 +57,7 @@ template <typename SrcData,
InMemoryDataOperationEnum_t DstInMemOp, InMemoryDataOperationEnum_t DstInMemOp,
index_t DstScalarStrideInVector, index_t DstScalarStrideInVector,
bool DstResetCoordinateAfterRun, bool DstResetCoordinateAfterRun,
typename std::enable_if<SrcDesc::IsKnownAtCompileTime(), bool>::type = false> typename enable_if<SrcDesc::IsKnownAtCompileTime(), bool>::type = false>
struct ThreadwiseTensorSliceTransfer_v1r3 struct ThreadwiseTensorSliceTransfer_v1r3
{ {
static constexpr index_t nDim = SliceLengths::Size(); static constexpr index_t nDim = SliceLengths::Size();
...@@ -373,7 +373,7 @@ template <typename SrcData, ...@@ -373,7 +373,7 @@ template <typename SrcData,
index_t SrcScalarPerVector, index_t SrcScalarPerVector,
index_t SrcScalarStrideInVector, index_t SrcScalarStrideInVector,
bool SrcResetCoordinateAfterRun, bool SrcResetCoordinateAfterRun,
typename std::enable_if<DstDesc::IsKnownAtCompileTime(), bool>::type = false> typename enable_if<DstDesc::IsKnownAtCompileTime(), bool>::type = false>
struct ThreadwiseTensorSliceTransfer_v2 struct ThreadwiseTensorSliceTransfer_v2
{ {
static constexpr index_t nDim = SliceLengths::Size(); static constexpr index_t nDim = SliceLengths::Size();
...@@ -1261,8 +1261,7 @@ struct ThreadwiseTensorSliceTransfer_v3 ...@@ -1261,8 +1261,7 @@ struct ThreadwiseTensorSliceTransfer_v3
// 3. DstOriginIdx is known at compile-time // 3. DstOriginIdx is known at compile-time
// 4. use direct address calculation // 4. use direct address calculation
// 5. vector access on src // 5. vector access on src
template < template <typename SrcData,
typename SrcData,
typename DstData, typename DstData,
typename SrcDesc, typename SrcDesc,
typename DstDesc, typename DstDesc,
...@@ -1271,7 +1270,7 @@ template < ...@@ -1271,7 +1270,7 @@ template <
index_t SrcVectorDim, index_t SrcVectorDim,
index_t SrcScalarPerVector, index_t SrcScalarPerVector,
index_t SrcScalarStrideInVector, index_t SrcScalarStrideInVector,
typename std::enable_if<SrcDesc::IsKnownAtCompileTime() && DstDesc::IsKnownAtCompileTime(), typename enable_if<SrcDesc::IsKnownAtCompileTime() && DstDesc::IsKnownAtCompileTime(),
bool>::type = false> bool>::type = false>
struct ThreadwiseTensorSliceTransfer_v4 struct ThreadwiseTensorSliceTransfer_v4
{ {
......
...@@ -621,8 +621,7 @@ struct ThreadwiseTensorSliceTransfer_v3r1 ...@@ -621,8 +621,7 @@ struct ThreadwiseTensorSliceTransfer_v3r1
// 3. DstOriginIdx is known at compile-time // 3. DstOriginIdx is known at compile-time
// 4. use direct address calculation // 4. use direct address calculation
// 5. vector access on src // 5. vector access on src
template < template <typename SrcData,
typename SrcData,
typename DstData, typename DstData,
typename SrcDesc, typename SrcDesc,
typename DstDesc, typename DstDesc,
...@@ -630,7 +629,7 @@ template < ...@@ -630,7 +629,7 @@ template <
typename DimAccessOrder, typename DimAccessOrder,
typename SrcVectorTensorLengths, typename SrcVectorTensorLengths,
typename SrcVectorTensorContiguousDimOrder, typename SrcVectorTensorContiguousDimOrder,
typename std::enable_if<SrcDesc::IsKnownAtCompileTime() && DstDesc::IsKnownAtCompileTime(), typename enable_if<SrcDesc::IsKnownAtCompileTime() && DstDesc::IsKnownAtCompileTime(),
bool>::type = false> bool>::type = false>
struct ThreadwiseTensorSliceTransfer_v4r1 struct ThreadwiseTensorSliceTransfer_v4r1
{ {
......
...@@ -2,12 +2,13 @@ ...@@ -2,12 +2,13 @@
#define CK_C_STYLE_POINTER_CAST_HPP #define CK_C_STYLE_POINTER_CAST_HPP
#include "type.hpp" #include "type.hpp"
#include "enable_if.hpp"
namespace ck { namespace ck {
template <typename PY, template <typename PY,
typename PX, typename PX,
typename std::enable_if<is_pointer_v<PY> && is_pointer_v<PX>, bool>::type = false> typename enable_if<is_pointer_v<PY> && is_pointer_v<PX>, bool>::type = false>
__host__ __device__ PY c_style_pointer_cast(PX p_x) __host__ __device__ PY c_style_pointer_cast(PX p_x)
{ {
#pragma clang diagnostic push #pragma clang diagnostic push
......
...@@ -14,6 +14,7 @@ ...@@ -14,6 +14,7 @@
#include "functional2.hpp" #include "functional2.hpp"
#include "functional3.hpp" #include "functional3.hpp"
#include "functional4.hpp" #include "functional4.hpp"
#include "enable_if.hpp"
#include "integral_constant.hpp" #include "integral_constant.hpp"
#include "math.hpp" #include "math.hpp"
#include "number.hpp" #include "number.hpp"
......
...@@ -3,6 +3,7 @@ ...@@ -3,6 +3,7 @@
#include "amd_buffer_addressing.hpp" #include "amd_buffer_addressing.hpp"
#include "c_style_pointer_cast.hpp" #include "c_style_pointer_cast.hpp"
#include "enable_if.hpp"
namespace ck { namespace ck {
...@@ -38,7 +39,7 @@ struct DynamicBuffer ...@@ -38,7 +39,7 @@ struct DynamicBuffer
} }
template <typename X, template <typename X,
typename std::enable_if< typename enable_if<
is_same<typename scalar_type<remove_cv_t<remove_reference_t<X>>>::type, is_same<typename scalar_type<remove_cv_t<remove_reference_t<X>>>::type,
typename scalar_type<remove_cv_t<remove_reference_t<T>>>::type>::value, typename scalar_type<remove_cv_t<remove_reference_t<T>>>::type>::value,
bool>::type = false> bool>::type = false>
...@@ -93,7 +94,7 @@ struct DynamicBuffer ...@@ -93,7 +94,7 @@ struct DynamicBuffer
} }
template <typename X, template <typename X,
typename std::enable_if< typename enable_if<
is_same<typename scalar_type<remove_cv_t<remove_reference_t<X>>>::type, is_same<typename scalar_type<remove_cv_t<remove_reference_t<X>>>::type,
typename scalar_type<remove_cv_t<remove_reference_t<T>>>::type>::value, typename scalar_type<remove_cv_t<remove_reference_t<T>>>::type>::value,
bool>::type = false> bool>::type = false>
......
#ifndef CK_ENABLE_IF_HPP
#define CK_ENABLE_IF_HPP

// std::enable_if is declared in <type_traits>. Include it here so this header
// is self-contained and does not depend on what a previous include pulled in.
#include <type_traits>

namespace ck {

// ck::enable_if<B, T> is an alias of std::enable_if<B, T>.
// The nested member `type` names T when B is true and does not exist when B is
// false, so `typename enable_if<cond, bool>::type = false` in a template
// parameter list removes that template from consideration when cond is false.
template <bool B, typename T = void>
using enable_if = std::enable_if<B, T>;

// ck::enable_if_t<B, T> is shorthand for `typename enable_if<B, T>::type`.
// It names T when B is true and is a substitution failure when B is false.
template <bool B, typename T = void>
using enable_if_t = typename std::enable_if<B, T>::type;

} // namespace ck
#endif
...@@ -5,6 +5,7 @@ ...@@ -5,6 +5,7 @@
#include "integral_constant.hpp" #include "integral_constant.hpp"
#include "number.hpp" #include "number.hpp"
#include "type.hpp" #include "type.hpp"
#include "enable_if.hpp"
namespace ck { namespace ck {
namespace math { namespace math {
...@@ -184,9 +185,7 @@ __host__ __device__ constexpr auto gcd(Number<X>, Number<Y>) ...@@ -184,9 +185,7 @@ __host__ __device__ constexpr auto gcd(Number<X>, Number<Y>)
return Number<r>{}; return Number<r>{};
} }
template <typename X, template <typename X, typename... Ys, typename enable_if<sizeof...(Ys) >= 2, bool>::type = false>
typename... Ys,
typename std::enable_if<sizeof...(Ys) >= 2, bool>::type = false>
__host__ __device__ constexpr auto gcd(X x, Ys... ys) __host__ __device__ constexpr auto gcd(X x, Ys... ys)
{ {
return gcd(x, gcd(ys...)); return gcd(x, gcd(ys...));
...@@ -199,9 +198,7 @@ __host__ __device__ constexpr auto lcm(X x, Y y) ...@@ -199,9 +198,7 @@ __host__ __device__ constexpr auto lcm(X x, Y y)
return (x * y) / gcd(x, y); return (x * y) / gcd(x, y);
} }
template <typename X, template <typename X, typename... Ys, typename enable_if<sizeof...(Ys) >= 2, bool>::type = false>
typename... Ys,
typename std::enable_if<sizeof...(Ys) >= 2, bool>::type = false>
__host__ __device__ constexpr auto lcm(X x, Ys... ys) __host__ __device__ constexpr auto lcm(X x, Ys... ys)
{ {
return lcm(x, lcm(ys...)); return lcm(x, lcm(ys...));
......
...@@ -4,6 +4,7 @@ ...@@ -4,6 +4,7 @@
#include "integral_constant.hpp" #include "integral_constant.hpp"
#include "sequence.hpp" #include "sequence.hpp"
#include "type.hpp" #include "type.hpp"
#include "enable_if.hpp"
namespace ck { namespace ck {
...@@ -20,9 +21,8 @@ struct TupleElement ...@@ -20,9 +21,8 @@ struct TupleElement
{ {
__host__ __device__ constexpr TupleElement() = default; __host__ __device__ constexpr TupleElement() = default;
template < template <typename T,
typename T, typename enable_if<!is_same<remove_reference_t<remove_cv_t<T>>, TupleElement>::value,
typename std::enable_if<!is_same<remove_reference_t<remove_cv_t<T>>, TupleElement>::value,
bool>::type = false> bool>::type = false>
__host__ __device__ constexpr TupleElement(T&& v) : mData(std::forward<T>(v)) __host__ __device__ constexpr TupleElement(T&& v) : mData(std::forward<T>(v))
{ {
...@@ -58,9 +58,8 @@ struct TupleImpl<Sequence<Is...>, Xs...> : TupleElement<TupleElementKey<Is>, Xs> ...@@ -58,9 +58,8 @@ struct TupleImpl<Sequence<Is...>, Xs...> : TupleElement<TupleElementKey<Is>, Xs>
{ {
__host__ __device__ constexpr TupleImpl() = default; __host__ __device__ constexpr TupleImpl() = default;
template < template <typename Y,
typename Y, typename enable_if<sizeof...(Is) == 1 && sizeof...(Xs) == 1 &&
typename std::enable_if<sizeof...(Is) == 1 && sizeof...(Xs) == 1 &&
!is_same<remove_reference_t<remove_cv_t<Y>>, TupleImpl>::value, !is_same<remove_reference_t<remove_cv_t<Y>>, TupleImpl>::value,
bool>::type = false> bool>::type = false>
__host__ __device__ constexpr TupleImpl(Y&& y) __host__ __device__ constexpr TupleImpl(Y&& y)
...@@ -68,7 +67,7 @@ struct TupleImpl<Sequence<Is...>, Xs...> : TupleElement<TupleElementKey<Is>, Xs> ...@@ -68,7 +67,7 @@ struct TupleImpl<Sequence<Is...>, Xs...> : TupleElement<TupleElementKey<Is>, Xs>
{ {
} }
template <typename... Ys, typename std::enable_if<sizeof...(Ys) >= 2, bool>::type = false> template <typename... Ys, typename enable_if<sizeof...(Ys) >= 2, bool>::type = false>
__host__ __device__ constexpr TupleImpl(Ys&&... ys) __host__ __device__ constexpr TupleImpl(Ys&&... ys)
: TupleElement<TupleElementKey<Is>, Xs>(std::forward<Ys>(ys))... : TupleElement<TupleElementKey<Is>, Xs>(std::forward<Ys>(ys))...
{ {
...@@ -102,16 +101,16 @@ struct Tuple : detail::TupleImpl<typename arithmetic_sequence_gen<0, sizeof...(X ...@@ -102,16 +101,16 @@ struct Tuple : detail::TupleImpl<typename arithmetic_sequence_gen<0, sizeof...(X
__host__ __device__ constexpr Tuple() = default; __host__ __device__ constexpr Tuple() = default;
template <typename Y, template <typename Y,
typename std::enable_if< typename enable_if<sizeof...(Xs) == 1 &&
sizeof...(Xs) == 1 && !is_same<remove_reference_t<remove_cv_t<Y>>, Tuple>::value, !is_same<remove_reference_t<remove_cv_t<Y>>, Tuple>::value,
bool>::type = false> bool>::type = false>
__host__ __device__ constexpr Tuple(Y&& y) : base(std::forward<Y>(y)) __host__ __device__ constexpr Tuple(Y&& y) : base(std::forward<Y>(y))
{ {
} }
template <typename... Ys, template <typename... Ys,
typename std::enable_if<sizeof...(Ys) == sizeof...(Xs) && sizeof...(Ys) >= 2, typename enable_if<sizeof...(Ys) == sizeof...(Xs) && sizeof...(Ys) >= 2, bool>::type =
bool>::type = false> false>
__host__ __device__ constexpr Tuple(Ys&&... ys) : base(std::forward<Ys>(ys)...) __host__ __device__ constexpr Tuple(Ys&&... ys) : base(std::forward<Ys>(ys)...)
{ {
} }
......
...@@ -2,6 +2,7 @@ ...@@ -2,6 +2,7 @@
#define CK_TYPE_HPP #define CK_TYPE_HPP
#include "integral_constant.hpp" #include "integral_constant.hpp"
#include "enable_if.hpp"
namespace ck { namespace ck {
...@@ -39,9 +40,7 @@ struct is_known_at_compile_time<integral_constant<T, X>> ...@@ -39,9 +40,7 @@ struct is_known_at_compile_time<integral_constant<T, X>>
static constexpr bool value = true; static constexpr bool value = true;
}; };
template <typename Y, template <typename Y, typename X, typename enable_if<sizeof(X) == sizeof(Y), bool>::type = false>
typename X,
typename std::enable_if<sizeof(X) == sizeof(Y), bool>::type = false>
__host__ __device__ constexpr Y as_type(X x) __host__ __device__ constexpr Y as_type(X x)
{ {
union AsType union AsType
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment