Commit 52423948 authored by Jehandad Khan's avatar Jehandad Khan
Browse files

Merge branch 'master' into jd_redux

parents b97af4ec 98a2cfcc
#ifndef CK_ARRAY_HPP
#define CK_ARRAY_HPP
#include "Sequence.hpp"
#include "sequence.hpp"
#include "functional2.hpp"
namespace ck {
template <class TData, index_t NSize>
template <typename TData, index_t NSize>
struct Array
{
using Type = Array<TData, NSize>;
using type = Array<TData, NSize>;
using data_type = TData;
static constexpr index_t nSize = NSize;
index_t mData[NSize];
index_t mData[nSize];
__host__ __device__ explicit constexpr Array() {}
template <class... Xs>
__host__ __device__ constexpr Array(Xs... xs) : mData{static_cast<TData>(xs)...}
template <typename X, typename... Xs>
__host__ __device__ constexpr Array(X x, Xs... xs)
: mData{static_cast<TData>(x), static_cast<TData>(xs)...}
{
static_assert(sizeof...(Xs) + 1 == NSize, "wrong! size");
}
__host__ __device__ static constexpr index_t GetSize() { return NSize; }
__host__ __device__ static constexpr index_t Size() { return NSize; }
// TODO: remove
__host__ __device__ static constexpr index_t GetSize() { return Size(); }
template <index_t I>
__host__ __device__ constexpr TData operator[](Number<I>) const
__host__ __device__ constexpr const TData& At(Number<I>) const
{
static_assert(I < NSize, "wrong!");
return mData[I];
}
__host__ __device__ constexpr TData operator[](index_t i) const { return mData[i]; }
template <index_t I>
__host__ __device__ TData& operator()(Number<I>)
__host__ __device__ constexpr TData& At(Number<I>)
{
static_assert(I < NSize, "wrong!");
return mData[I];
}
__host__ __device__ TData& operator()(index_t i) { return mData[i]; }
__host__ __device__ constexpr const TData& At(index_t i) const { return mData[i]; }
template <index_t I>
__host__ __device__ constexpr void Set(Number<I>, TData x)
__host__ __device__ constexpr TData& At(index_t i) { return mData[i]; }
template <typename I>
__host__ __device__ constexpr const TData& operator[](I i) const
{
static_assert(I < NSize, "wrong!");
return At(i);
}
mData[I] = x;
template <typename I>
__host__ __device__ constexpr TData& operator()(I i)
{
return At(i);
}
__host__ __device__ constexpr void Set(index_t I, TData x) { mData[I] = x; }
template <typename T>
__host__ __device__ constexpr type& operator=(const T& x)
{
static_for<0, Size(), 1>{}([&](auto i) { operator()(i) = x[i]; });
return *this;
}
struct lambda_PushBack // emulate constexpr lambda
{
......@@ -63,7 +82,7 @@ struct Array
template <index_t I>
__host__ __device__ constexpr void operator()(Number<I>) const
{
new_array.Set(Number<I>{}, old_array[I]);
new_array(Number<I>{}) = old_array[I];
}
};
......@@ -73,19 +92,96 @@ struct Array
static_for<0, NSize, 1>{}(lambda_PushBack(*this, new_array));
new_array.Set(Number<NSize>{}, x);
new_array(Number<NSize>{}) = x;
return new_array;
}
};
// Arr: Array
// Picks: Sequence<...>
template <typename Arr, typename Picks>
struct ArrayElementPicker
{
using type = ArrayElementPicker;
using data_type = typename Arr::data_type;
__host__ __device__ constexpr ArrayElementPicker() = delete;
__host__ __device__ explicit constexpr ArrayElementPicker(Arr& array) : mArray{array}
{
constexpr index_t imax = reduce_on_sequence(Picks{}, math::maxer<index_t>{}, Number<0>{});
static_assert(imax < Arr::Size(), "wrong! exceeding # array element");
}
__host__ __device__ static constexpr auto Size() { return Picks::Size(); }
template <index_t I>
__host__ __device__ constexpr const data_type& At(Number<I>) const
{
static_assert(I < Size(), "wrong!");
constexpr auto IP = Picks{}[I];
return mArray[IP];
}
template <index_t I>
__host__ __device__ constexpr data_type& At(Number<I>)
{
static_assert(I < Size(), "wrong!");
constexpr auto IP = Picks{}[I];
return mArray(IP);
}
template <typename I>
__host__ __device__ constexpr const data_type& operator[](I i) const
{
return At(i);
}
template <typename I>
__host__ __device__ constexpr data_type& operator()(I i)
{
return At(i);
}
template <typename T>
__host__ __device__ constexpr type& operator=(const T& a)
{
static_for<0, Size(), 1>{}([&](auto i) { operator()(i) = a[i]; });
return *this;
}
Arr& mArray;
};
template <typename Arr, typename Picks>
__host__ __device__ constexpr auto pick_array_element(Arr& a, Picks)
{
return ArrayElementPicker<Arr, Picks>(a);
}
template <typename T>
__host__ __device__ constexpr auto to_array(const T& x)
{
Array<typename T::data_type, T::Size()> y;
static_for<0, T::Size(), 1>{}([&](auto i) { y.At(i) = x.At(i); });
return y;
}
// TODO: remove this
template <index_t... Is>
__host__ __device__ constexpr auto sequence2array(Sequence<Is...>)
{
return Array<index_t, sizeof...(Is)>{Is...};
}
template <class TData, index_t NSize>
template <typename TData, index_t NSize>
__host__ __device__ constexpr auto make_zero_array()
{
constexpr auto zero_sequence = typename uniform_sequence_gen<NSize, 0>::type{};
......@@ -93,7 +189,7 @@ __host__ __device__ constexpr auto make_zero_array()
return zero_array;
}
template <class TData, index_t NSize, index_t... IRs>
template <typename TData, index_t NSize, index_t... IRs>
__host__ __device__ constexpr auto reorder_array_given_new2old(const Array<TData, NSize>& old_array,
Sequence<IRs...> /*new2old*/)
{
......@@ -104,7 +200,7 @@ __host__ __device__ constexpr auto reorder_array_given_new2old(const Array<TData
return Array<TData, NSize>{old_array[IRs]...};
}
template <class TData, index_t NSize, class MapOld2New>
template <typename TData, index_t NSize, typename MapOld2New>
struct lambda_reorder_array_given_old2new
{
const Array<TData, NSize>& old_array;
......@@ -121,13 +217,13 @@ struct lambda_reorder_array_given_old2new
{
TData old_data = old_array[IOldDim];
constexpr index_t INewDim = MapOld2New::Get(Number<IOldDim>{});
constexpr index_t INewDim = MapOld2New::At(Number<IOldDim>{});
new_array.Set(Number<INewDim>{}, old_data);
new_array(Number<INewDim>{}) = old_data;
}
};
template <class TData, index_t NSize, index_t... IRs>
template <typename TData, index_t NSize, index_t... IRs>
__host__ __device__ constexpr auto reorder_array_given_old2new(const Array<TData, NSize>& old_array,
Sequence<IRs...> /*old2new*/)
{
......@@ -143,7 +239,7 @@ __host__ __device__ constexpr auto reorder_array_given_old2new(const Array<TData
return new_array;
}
template <class TData, index_t NSize, class ExtractSeq>
template <typename TData, index_t NSize, typename ExtractSeq>
__host__ __device__ constexpr auto extract_array(const Array<TData, NSize>& old_array, ExtractSeq)
{
Array<TData, ExtractSeq::GetSize()> new_array;
......@@ -152,12 +248,13 @@ __host__ __device__ constexpr auto extract_array(const Array<TData, NSize>& old_
static_assert(new_size <= NSize, "wrong! too many extract");
static_for<0, new_size, 1>{}([&](auto I) { new_array(I) = old_array[ExtractSeq::Get(I)]; });
static_for<0, new_size, 1>{}([&](auto I) { new_array(I) = old_array[ExtractSeq::At(I)]; });
return new_array;
}
template <class F, class X, class Y, class Z> // emulate constepxr lambda for array math
// emulate constepxr lambda for array
template <typename F, typename X, typename Y, typename Z>
struct lambda_array_math
{
const F& f;
......@@ -174,13 +271,12 @@ struct lambda_array_math
__host__ __device__ constexpr void operator()(Number<IDim_>) const
{
constexpr auto IDim = Number<IDim_>{};
z.Set(IDim, f(x[IDim], y[IDim]));
z(IDim) = f(x[IDim], y[IDim]);
}
};
// Array = Array + Array
template <class TData, index_t NSize>
template <typename TData, index_t NSize>
__host__ __device__ constexpr auto operator+(Array<TData, NSize> a, Array<TData, NSize> b)
{
Array<TData, NSize> result;
......@@ -195,7 +291,7 @@ __host__ __device__ constexpr auto operator+(Array<TData, NSize> a, Array<TData,
}
// Array = Array - Array
template <class TData, index_t NSize>
template <typename TData, index_t NSize>
__host__ __device__ constexpr auto operator-(Array<TData, NSize> a, Array<TData, NSize> b)
{
Array<TData, NSize> result;
......@@ -210,7 +306,7 @@ __host__ __device__ constexpr auto operator-(Array<TData, NSize> a, Array<TData,
}
// Array += Array
template <class TData, index_t NSize>
template <typename TData, index_t NSize>
__host__ __device__ constexpr auto operator+=(Array<TData, NSize>& a, Array<TData, NSize> b)
{
a = a + b;
......@@ -218,14 +314,14 @@ __host__ __device__ constexpr auto operator+=(Array<TData, NSize>& a, Array<TDat
}
// Array -= Array
template <class TData, index_t NSize>
template <typename TData, index_t NSize>
__host__ __device__ constexpr auto operator-=(Array<TData, NSize>& a, Array<TData, NSize> b)
{
a = a - b;
return a;
}
// Array = Array + Sequence
template <class TData, index_t NSize, index_t... Is>
template <typename TData, index_t NSize, index_t... Is>
__host__ __device__ constexpr auto operator+(Array<TData, NSize> a, Sequence<Is...> b)
{
static_assert(sizeof...(Is) == NSize, "wrong! size not the same");
......@@ -242,7 +338,7 @@ __host__ __device__ constexpr auto operator+(Array<TData, NSize> a, Sequence<Is.
}
// Array = Array - Sequence
template <class TData, index_t NSize, index_t... Is>
template <typename TData, index_t NSize, index_t... Is>
__host__ __device__ constexpr auto operator-(Array<TData, NSize> a, Sequence<Is...> b)
{
static_assert(sizeof...(Is) == NSize, "wrong! size not the same");
......@@ -259,7 +355,7 @@ __host__ __device__ constexpr auto operator-(Array<TData, NSize> a, Sequence<Is.
}
// Array = Array * Sequence
template <class TData, index_t NSize, index_t... Is>
template <typename TData, index_t NSize, index_t... Is>
__host__ __device__ constexpr auto operator*(Array<TData, NSize> a, Sequence<Is...> b)
{
static_assert(sizeof...(Is) == NSize, "wrong! size not the same");
......@@ -276,7 +372,7 @@ __host__ __device__ constexpr auto operator*(Array<TData, NSize> a, Sequence<Is.
}
// Array = Sequence - Array
template <class TData, index_t NSize, index_t... Is>
template <typename TData, index_t NSize, index_t... Is>
__host__ __device__ constexpr auto operator-(Sequence<Is...> a, Array<TData, NSize> b)
{
static_assert(sizeof...(Is) == NSize, "wrong! size not the same");
......@@ -292,7 +388,21 @@ __host__ __device__ constexpr auto operator-(Sequence<Is...> a, Array<TData, NSi
return result;
}
template <class TData, index_t NSize, class Reduce>
// Array = Array * TData
template <typename TData, index_t NSize>
__host__ __device__ constexpr auto operator*(TData v, Array<TData, NSize> a)
{
Array<TData, NSize> result;
for(index_t i = 0; i < NSize; ++i)
{
result(i) = a[i] * v;
}
return result;
}
template <typename TData, index_t NSize, typename Reduce>
__host__ __device__ constexpr TData
accumulate_on_array(const Array<TData, NSize>& a, Reduce f, TData init)
{
......@@ -305,89 +415,5 @@ accumulate_on_array(const Array<TData, NSize>& a, Reduce f, TData init)
return result;
}
template <class T, index_t NSize>
__host__ __device__ void print_Array(const char* s, Array<T, NSize> a)
{
constexpr index_t nsize = a.GetSize();
static_assert(nsize > 0 && nsize <= 10, "wrong!");
static_if<nsize == 1>{}([&](auto) { printf("%s size %u, {%u}\n", s, nsize, a[0]); });
static_if<nsize == 2>{}([&](auto) { printf("%s size %u, {%u %u}\n", s, nsize, a[0], a[1]); });
static_if<nsize == 3>{}(
[&](auto) { printf("%s size %u, {%u %u %u}\n", s, nsize, a[0], a[1], a[2]); });
static_if<nsize == 4>{}(
[&](auto) { printf("%s size %u, {%u %u %u %u}\n", s, nsize, a[0], a[1], a[2], a[3]); });
static_if<nsize == 5>{}([&](auto) {
printf("%s size %u, {%u %u %u %u %u}\n", s, nsize, a[0], a[1], a[2], a[3], a[4]);
});
static_if<nsize == 6>{}([&](auto) {
printf("%s size %u, {%u %u %u %u %u %u}\n", s, nsize, a[0], a[1], a[2], a[3], a[4], a[5]);
});
static_if<nsize == 7>{}([&](auto) {
printf("%s size %u, {%u %u %u %u %u %u %u}\n",
s,
nsize,
a[0],
a[1],
a[2],
a[3],
a[4],
a[5],
a[6]);
});
static_if<nsize == 8>{}([&](auto) {
printf("%s size %u, {%u %u %u %u %u %u %u %u}\n",
s,
nsize,
a[0],
a[1],
a[2],
a[3],
a[4],
a[5],
a[6],
a[7]);
});
static_if<nsize == 9>{}([&](auto) {
printf("%s size %u, {%u %u %u %u %u %u %u %u %u}\n",
s,
nsize,
a[0],
a[1],
a[2],
a[3],
a[4],
a[5],
a[6],
a[7],
a[8]);
});
static_if<nsize == 10>{}([&](auto) {
printf("%s size %u, {%u %u %u %u %u %u %u %u %u %u}\n",
s,
nsize,
a[0],
a[1],
a[2],
a[3],
a[4],
a[5],
a[6],
a[7],
a[8],
a[9]);
});
}
} // namespace ck
#endif
#ifndef CK_ARRAY_HELPER_HPP
#define CK_ARRAY_HELPER_HPP
#include "array.hpp"
namespace ck {
template <index_t NSize>
__host__ __device__ void print_array(const char* s, Array<uint32_t, NSize> a)
{
constexpr index_t nsize = a.GetSize();
static_assert(nsize > 0 && nsize <= 10, "wrong!");
static_if<nsize == 1>{}([&](auto) { printf("%s size %u, {%u}\n", s, nsize, a[0]); });
static_if<nsize == 2>{}([&](auto) { printf("%s size %u, {%u %u}\n", s, nsize, a[0], a[1]); });
static_if<nsize == 3>{}(
[&](auto) { printf("%s size %u, {%u %u %u}\n", s, nsize, a[0], a[1], a[2]); });
static_if<nsize == 4>{}(
[&](auto) { printf("%s size %u, {%u %u %u %u}\n", s, nsize, a[0], a[1], a[2], a[3]); });
static_if<nsize == 5>{}([&](auto) {
printf("%s size %u, {%u %u %u %u %u}\n", s, nsize, a[0], a[1], a[2], a[3], a[4]);
});
static_if<nsize == 6>{}([&](auto) {
printf("%s size %u, {%u %u %u %u %u %u}\n", s, nsize, a[0], a[1], a[2], a[3], a[4], a[5]);
});
static_if<nsize == 7>{}([&](auto) {
printf("%s size %u, {%u %u %u %u %u %u %u}\n",
s,
nsize,
a[0],
a[1],
a[2],
a[3],
a[4],
a[5],
a[6]);
});
static_if<nsize == 8>{}([&](auto) {
printf("%s size %u, {%u %u %u %u %u %u %u %u}\n",
s,
nsize,
a[0],
a[1],
a[2],
a[3],
a[4],
a[5],
a[6],
a[7]);
});
static_if<nsize == 9>{}([&](auto) {
printf("%s size %u, {%u %u %u %u %u %u %u %u %u}\n",
s,
nsize,
a[0],
a[1],
a[2],
a[3],
a[4],
a[5],
a[6],
a[7],
a[8]);
});
static_if<nsize == 10>{}([&](auto) {
printf("%s size %u, {%u %u %u %u %u %u %u %u %u %u}\n",
s,
nsize,
a[0],
a[1],
a[2],
a[3],
a[4],
a[5],
a[6],
a[7],
a[8],
a[9]);
});
}
template <index_t NSize>
__host__ __device__ void print_array(const char* s, Array<int32_t, NSize> a)
{
constexpr index_t nsize = a.GetSize();
static_assert(nsize > 0 && nsize <= 10, "wrong!");
static_if<nsize == 1>{}([&](auto) { printf("%s size %d, {%d}\n", s, nsize, a[0]); });
static_if<nsize == 2>{}([&](auto) { printf("%s size %d, {%d %d}\n", s, nsize, a[0], a[1]); });
static_if<nsize == 3>{}(
[&](auto) { printf("%s size %d, {%d %d %d}\n", s, nsize, a[0], a[1], a[2]); });
static_if<nsize == 4>{}(
[&](auto) { printf("%s size %d, {%d %d %d %d}\n", s, nsize, a[0], a[1], a[2], a[3]); });
static_if<nsize == 5>{}([&](auto) {
printf("%s size %d, {%d %d %d %d %d}\n", s, nsize, a[0], a[1], a[2], a[3], a[4]);
});
static_if<nsize == 6>{}([&](auto) {
printf("%s size %d, {%d %d %d %d %d %d}\n", s, nsize, a[0], a[1], a[2], a[3], a[4], a[5]);
});
static_if<nsize == 7>{}([&](auto) {
printf("%s size %d, {%d %d %d %d %d %d %d}\n",
s,
nsize,
a[0],
a[1],
a[2],
a[3],
a[4],
a[5],
a[6]);
});
static_if<nsize == 8>{}([&](auto) {
printf("%s size %d, {%d %d %d %d %d %d %d %d}\n",
s,
nsize,
a[0],
a[1],
a[2],
a[3],
a[4],
a[5],
a[6],
a[7]);
});
static_if<nsize == 9>{}([&](auto) {
printf("%s size %d, {%d %d %d %d %d %d %d %d %d}\n",
s,
nsize,
a[0],
a[1],
a[2],
a[3],
a[4],
a[5],
a[6],
a[7],
a[8]);
});
static_if<nsize == 10>{}([&](auto) {
printf("%s size %d, {%d %d %d %d %d %d %d %d %d %d}\n",
s,
nsize,
a[0],
a[1],
a[2],
a[3],
a[4],
a[5],
a[6],
a[7],
a[8],
a[9]);
});
}
} // namespace ck
#endif
......@@ -4,16 +4,26 @@
#include "config.hpp"
#include "utility.hpp"
#include "integral_constant.hpp"
#include "number.hpp"
#include "type.hpp"
#include "tuple.hpp"
#include "math.hpp"
#include "vector_type.hpp"
#include "Sequence.hpp"
#include "Array.hpp"
#include "sequence.hpp"
#include "sequence_helper.hpp"
#include "array.hpp"
#include "array_helper.hpp"
#include "functional.hpp"
#include "functional2.hpp"
#include "functional3.hpp"
#include "functional4.hpp"
#if CK_USE_AMD_INLINE_ASM
#include "amd_inline_asm.hpp"
#endif
#if CK_USE_AMD_INTRINSIC
#include "amd_intrinsic.hpp"
#endif
#endif
......@@ -4,29 +4,47 @@
#include "hip/hip_runtime.h"
#include "hip/hip_fp16.h"
#define CK_UNSIGNED_INDEX_TYPE 0
#define CK_DEVICE_BACKEND_AMD 1
#define CK_USE_AMD_INTRINSIC 1
#define CK_USE_AMD_INLINE_ASM 1
#define CK_USE_AMD_INTRINSIC_BUFFER_LOAD_STORE 1
#define CK_EXPERIMENTAL_USE_MORE_COMPILE_STATIC_BLOCKWISE_GENERIC_SLICE_COPY_V1 1
#define CK_EXPERIMENTAL_USE_MORE_COMPILE_STATIC_THREADWISE_GENERIC_TENSOR_SLICE_COPY_V1 0
#define CK_EXPERIMENTAL_USE_MORE_COMPILE_STATIC_THREADWISE_GENERIC_TENSOR_SLICE_COPY_V1R1 0
#define CK_EXPERIMENTAL_USE_MORE_COMPILE_STATIC_THREADWISE_GENERIC_TENSOR_SLICE_COPY_V1R2 0
#define CK_EXPERIMENTAL_USE_MORE_COMPILE_STATIC_THREADWISE_GENERIC_TENSOR_SLICE_COPY_V2 0
#define CK_EXPERIMENTAL_USE_MORE_COMPILE_STATIC_THREADWISE_GENERIC_TENSOR_SLICE_COPY_V2R1 0
namespace ck {
enum address_space_t
{
generic = 0,
global = 3
};
#if CK_UNSIGNED_INDEX_TYPE
using index_t = uint32_t;
#else
using index_t = int32_t;
#endif
// For some reason, HIP compiler need this definition to generate optimal load and store
// instruction
typedef float float2_t __attribute__((ext_vector_type(2)));
typedef float float4_t __attribute__((ext_vector_type(4)));
using index_t = uint32_t;
typedef int32_t int32x4_t __attribute__((ext_vector_type(4)));
template <class T>
__device__ void fused_multiply_accumulate(T& d, const T& s0, const T& s1)
// data type conversion
template <typename T>
struct type_convert
{
d += s0 * s1;
}
template <typename X>
__device__ T operator()(const X& x) const
{
return static_cast<T>(x);
}
};
} // namespace ck
......
......@@ -6,17 +6,30 @@
#include "nvToolsExt.h"
#include "helper_cuda.h"
#define CK_UNSIGNED_INDEX_TYPE 0
#define CK_DEVICE_BACKEND_NVIDIA 1
#define CK_USE_AMD_INTRINSIC 0
#define CK_USE_AMD_INLINE_ASM 0
#define CK_USE_AMD_INTRINSIC_BUFFER_LOAD_STORE 0
#define CK_EXPERIMENTAL_USE_MORE_COMPILE_STATIC_BLOCKWISE_GENERIC_SLICE_COPY_V1 0
#define CK_EXPERIMENTAL_USE_MORE_COMPILE_STATIC_THREADWISE_GENERIC_TENSOR_SLICE_COPY_V1 0
#define CK_EXPERIMENTAL_USE_MORE_COMPILE_STATIC_THREADWISE_GENERIC_TENSOR_SLICE_COPY_V1R1 0
#define CK_EXPERIMENTAL_USE_MORE_COMPILE_STATIC_THREADWISE_GENERIC_TENSOR_SLICE_COPY_V1R2 0
#define CK_EXPERIMENTAL_USE_MORE_COMPILE_STATIC_THREADWISE_GENERIC_TENSOR_SLICE_COPY_V2 0
#define CK_EXPERIMENTAL_USE_MORE_COMPILE_STATIC_THREADWISE_GENERIC_TENSOR_SLICE_COPY_V2R1 0
namespace ck {
enum address_space_t
{
generic = 0,
global = generic
};
#if CK_UNSIGNED_INDEX_TYPE
using index_t = uint32_t;
#else
using index_t = int32_t;
#endif
// For some reason, CUDA need this definition, otherwise
// compiler won't generate optimal load and store instruction, and
// kernel would produce wrong result, indicating the compiler fail to generate correct
......@@ -24,7 +37,16 @@ namespace ck {
using float2_t = float2;
using float4_t = float4;
using index_t = uint32_t;
// data type conversion
template <typename T>
struct type_convert
{
template <typename X>
__device__ T operator()(const X& x) const
{
return static_cast<T>(x);
}
};
template <class T>
__device__ void fused_multiply_accumulate(T& d, const T& s0, const T& s1)
......
......@@ -2,10 +2,12 @@
#define CK_FUNCTIONAL_HPP
#include "integral_constant.hpp"
#include "Sequence.hpp"
#include "sequence.hpp"
#include "type.hpp"
namespace ck {
// TODO: right? wrong?
struct forwarder
{
template <typename T>
......@@ -17,12 +19,30 @@ struct forwarder
struct swallow
{
template <class... Ts>
template <typename... Ts>
__host__ __device__ constexpr swallow(Ts&&...)
{
}
};
template <typename T>
struct logical_and
{
constexpr bool operator()(const T& x, const T& y) const { return x && y; }
};
template <typename T>
struct logical_or
{
constexpr bool operator()(const T& x, const T& y) const { return x || y; }
};
template <typename T>
struct logical_not
{
constexpr bool operator()(const T& x) const { return !x; }
};
// Emulate if constexpr
template <bool>
struct static_if;
......@@ -32,7 +52,7 @@ struct static_if<true>
{
using Type = static_if<true>;
template <class F>
template <typename F>
__host__ __device__ constexpr auto operator()(F f) const
{
// This is a trick for compiler:
......@@ -43,7 +63,7 @@ struct static_if<true>
return Type{};
}
template <class F>
template <typename F>
__host__ __device__ static constexpr auto Else(F)
{
return Type{};
......@@ -55,13 +75,13 @@ struct static_if<false>
{
using Type = static_if<false>;
template <class F>
template <typename F>
__host__ __device__ constexpr auto operator()(F) const
{
return Type{};
}
template <class F>
template <typename F>
__host__ __device__ static constexpr auto Else(F f)
{
// This is a trick for compiler:
......@@ -73,5 +93,23 @@ struct static_if<false>
}
};
template <bool predicate, class X, class Y>
struct conditional;
template <class X, class Y>
struct conditional<true, X, Y>
{
using type = X;
};
template <class X, class Y>
struct conditional<false, X, Y>
{
using type = Y;
};
template <bool predicate, class X, class Y>
using conditional_t = typename conditional<predicate, X, Y>::type;
} // namespace ck
#endif
......@@ -2,10 +2,12 @@
#define CK_FUNCTIONAL2_HPP
#include "functional.hpp"
#include "Sequence.hpp"
#include "sequence.hpp"
namespace ck {
namespace detail {
template <class>
struct static_for_impl;
......@@ -19,6 +21,8 @@ struct static_for_impl<Sequence<Is...>>
}
};
} // namespace detail
// F signature: F(Number<Iter>)
template <index_t NBegin, index_t NEnd, index_t Increment>
struct static_for
......@@ -33,38 +37,10 @@ struct static_for
template <class F>
__host__ __device__ constexpr void operator()(F f) const
{
static_for_impl<typename arithmetic_sequence_gen<NBegin, NEnd, Increment>::type>{}(f);
detail::static_for_impl<typename arithmetic_sequence_gen<NBegin, NEnd, Increment>::type>{}(
f);
}
};
template <class Seq, class Reduce>
struct lambda_accumulate_on_sequence
{
const Reduce& f;
index_t& result;
__host__ __device__ constexpr lambda_accumulate_on_sequence(const Reduce& f_, index_t& result_)
: f(f_), result(result_)
{
}
template <class IDim>
__host__ __device__ constexpr index_t operator()(IDim) const
{
return result = f(result, Seq::Get(IDim{}));
}
};
template <class Seq, class Reduce, index_t Init>
__host__ __device__ constexpr index_t
accumulate_on_sequence(Seq, Reduce f, Number<Init> /*initial_value*/)
{
index_t result = Init;
static_for<0, Seq::mSize, 1>{}(lambda_accumulate_on_sequence<Seq, Reduce>(f, result));
return result;
}
} // namespace ck
#endif
......@@ -3,25 +3,12 @@
#include "functional.hpp"
#include "functional2.hpp"
#include "Sequence.hpp"
#include "Array.hpp"
#include "sequence.hpp"
#include "array.hpp"
namespace ck {
template <class>
struct is_static : integral_constant<bool, false>
{
};
template <class T, T X>
struct is_static<integral_constant<T, X>> : integral_constant<bool, true>
{
};
template <index_t... Is>
struct is_static<Sequence<Is...>> : integral_constant<bool, true>
{
};
namespace detail {
// RemainLengths: Sequence<...>
// Orders: Sequence<...>
......@@ -58,29 +45,6 @@ struct static_ford_impl<Sequence<>, Orders>
}
};
// Lengths is Sequence<...>, it is the length of each dimension for N-dimensional loop
// Orders is Sequence<...>, it is the order of dimension in which static_ford will loop over each
// dimension
template <class Lengths,
class Orders = typename arithmetic_sequence_gen<0, Lengths::GetSize(), 1>::type>
struct static_ford
{
__host__ __device__ constexpr static_ford()
{
static_assert(Lengths::GetSize() > 0, "wrong! Lengths is empty");
static_assert(Lengths::GetSize() == Orders::GetSize(), "wrong! inconsistent size");
}
// F signature: F(Sequence<...> multi_id)
// multi_id is the unordered multi-index
template <class F>
__host__ __device__ constexpr void operator()(F f) const
{
constexpr auto ordered_lengths = Lengths::ReorderGivenNew2Old(Orders{});
static_ford_impl<decltype(ordered_lengths), Orders>{}(f, Sequence<>{});
}
};
// RemainLengths: Sequence<...>
// Orders: Sequence<...>
template <class RemainLengths, class Orders>
......@@ -117,6 +81,31 @@ struct ford_impl<Sequence<>, Orders>
}
};
} // namespace detail
// Lengths is Sequence<...>, it is the length of each dimension for N-dimensional loop
// Orders is Sequence<...>, it is the order of dimension in which static_ford will loop over each
// dimension
template <class Lengths,
class Orders = typename arithmetic_sequence_gen<0, Lengths::GetSize(), 1>::type>
struct static_ford
{
__host__ __device__ constexpr static_ford()
{
static_assert(Lengths::GetSize() > 0, "wrong! Lengths is empty");
static_assert(Lengths::GetSize() == Orders::GetSize(), "wrong! inconsistent size");
}
// F signature: F(Sequence<...> multi_id)
// multi_id is the unordered multi-index
template <class F>
__host__ __device__ constexpr void operator()(F f) const
{
constexpr auto ordered_lengths = Lengths::ReorderGivenNew2Old(Orders{});
detail::static_ford_impl<decltype(ordered_lengths), Orders>{}(f, Sequence<>{});
}
};
// Lengths is Sequence<...>, it is the length of each dimension for N-dimensional loop
// Orders is Sequence<...>, it is the order of dimension in which ford will loop over each
// dimension
......@@ -139,7 +128,8 @@ struct ford
for(index_t i = 0; i < ordered_lengths.Front(); ++i)
{
ford_impl<decltype(ordered_lengths.PopFront()), Orders>{}(f, Array<index_t, 1>{i});
detail::ford_impl<decltype(ordered_lengths.PopFront()), Orders>{}(f,
Array<index_t, 1>{i});
}
}
};
......
#ifndef CK_FUNCTIONAL4_HPP
#define CK_FUNCTIONAL4_HPP
#include "sequence.hpp"
#include "tuple.hpp"
#include "array.hpp"
namespace ck {
namespace detail {
template <typename Indices>
struct unpack_impl;
template <index_t... Is>
struct unpack_impl<Sequence<Is...>>
{
template <typename F, typename X>
__host__ __device__ constexpr auto operator()(F f, const X& x) const
{
return f(x.At(Number<Is>{})...);
}
};
} // namespace detail
template <typename F, typename X>
__host__ __device__ constexpr auto unpack(F f, const X& x)
{
return detail::unpack_impl<typename arithmetic_sequence_gen<0, X::Size(), 1>::type>{}(f, x);
}
} // namespace ck
#endif
......@@ -13,51 +13,5 @@ struct integral_constant
__host__ __device__ constexpr value_type operator()() const noexcept { return value; }
};
template <class X, class Y>
struct is_same : public integral_constant<bool, false>
{
};
template <class X>
struct is_same<X, X> : public integral_constant<bool, true>
{
};
template <index_t N>
using Number = integral_constant<index_t, N>;
template <index_t X, index_t Y>
__host__ __device__ constexpr auto operator+(Number<X>, Number<Y>)
{
return Number<X + Y>{};
}
template <index_t X, index_t Y>
__host__ __device__ constexpr auto operator-(Number<X>, Number<Y>)
{
static_assert(Y <= X, "wrong!");
return Number<X - Y>{};
}
template <index_t X, index_t Y>
__host__ __device__ constexpr auto operator*(Number<X>, Number<Y>)
{
return Number<X * Y>{};
}
template <index_t X, index_t Y>
__host__ __device__ constexpr auto operator/(Number<X>, Number<Y>)
{
static_assert(Y > 0, "wrong!");
return Number<X / Y>{};
}
template <index_t X, index_t Y>
__host__ __device__ constexpr auto operator%(Number<X>, Number<Y>)
{
static_assert(Y > 0, "wrong!");
return Number<X % Y>{};
}
} // namespace ck
#endif
......@@ -3,6 +3,7 @@
#include "config.hpp"
#include "integral_constant.hpp"
#include "type.hpp"
namespace ck {
namespace math {
......@@ -31,6 +32,12 @@ struct multiplies
__host__ __device__ constexpr T operator()(T a, T b) const { return a * b; }
};
template <class T>
struct maxer
{
__host__ __device__ constexpr T operator()(T a, T b) const { return a >= b ? a : b; }
};
template <class T>
struct integer_divide_ceiler
{
......@@ -98,6 +105,18 @@ __host__ __device__ constexpr T lcm(T x, Ts... xs)
return max(x, xs...);
}
template <class T>
struct equal
{
__host__ __device__ constexpr bool operator()(T x, T y) const { return x == y; }
};
template <class T>
struct less
{
__host__ __device__ constexpr bool operator()(T x, T y) const { return x < y; }
};
} // namespace math
} // namspace ck
......
#ifndef CK_NUMBER_HPP
#define CK_NUMBER_HPP
#include "integral_constant.hpp"
namespace ck {
template <index_t N>
using Number = integral_constant<index_t, N>;
template <index_t X, index_t Y>
__host__ __device__ constexpr auto operator+(Number<X>, Number<Y>)
{
return Number<X + Y>{};
}
template <index_t X, index_t Y>
__host__ __device__ constexpr auto operator-(Number<X>, Number<Y>)
{
static_assert(Y <= X, "wrong!");
return Number<X - Y>{};
}
template <index_t X, index_t Y>
__host__ __device__ constexpr auto operator*(Number<X>, Number<Y>)
{
return Number<X * Y>{};
}
template <index_t X, index_t Y>
__host__ __device__ constexpr auto operator/(Number<X>, Number<Y>)
{
static_assert(Y > 0, "wrong!");
return Number<X / Y>{};
}
template <index_t X, index_t Y>
__host__ __device__ constexpr auto operator%(Number<X>, Number<Y>)
{
static_assert(Y > 0, "wrong!");
return Number<X % Y>{};
}
} // namespace ck
#endif
......@@ -2,29 +2,34 @@
#define CK_SEQUENCE_HPP
#include "integral_constant.hpp"
#include "type.hpp"
#include "functional.hpp"
#include "math.hpp"
namespace ck {
template <index_t, index_t, index_t>
struct static_for;
template <index_t...>
struct Sequence;
template <class Seq, index_t I>
template <typename Seq, index_t I>
struct sequence_split;
template <class>
template <typename>
struct sequence_reverse;
template <class>
template <typename>
struct sequence_map_inverse;
template <class>
template <typename>
struct is_valid_sequence_map;
template <index_t I, index_t... Is>
__host__ __device__ constexpr auto sequence_pop_front(Sequence<I, Is...>);
template <class Seq>
template <typename Seq>
__host__ __device__ constexpr auto sequence_pop_back(Seq);
template <index_t... Is>
......@@ -35,9 +40,11 @@ struct Sequence
static constexpr index_t mSize = sizeof...(Is);
__host__ __device__ static constexpr auto GetSize() { return Number<mSize>{}; }
__host__ __device__ static constexpr auto Size() { return Number<mSize>{}; }
__host__ __device__ static constexpr auto GetSize() { return Size(); }
__host__ __device__ static constexpr index_t GetImpl(index_t I)
__host__ __device__ static constexpr index_t At(index_t I)
{
// the last dummy element is to prevent compiler complain about empty array, when mSize = 0
const index_t mData[mSize + 1] = {Is..., 0};
......@@ -45,23 +52,24 @@ struct Sequence
}
template <index_t I>
__host__ __device__ static constexpr auto Get(Number<I>)
__host__ __device__ static constexpr auto At(Number<I>)
{
static_assert(I < mSize, "wrong! I too large");
return Number<GetImpl(Number<I>{})>{};
return Number<At(I)>{};
}
__host__ __device__ static constexpr auto Get(index_t I) { return GetImpl(I); }
template <index_t I>
__host__ __device__ constexpr auto operator[](Number<I>) const
__host__ __device__ static constexpr auto Get(Number<I>)
{
return Get(Number<I>{});
return At(Number<I>{});
}
// make sure I is constepxr if you want a constexpr return type
__host__ __device__ constexpr index_t operator[](index_t I) const { return GetImpl(I); }
template <typename I>
__host__ __device__ constexpr auto operator[](I i) const
{
return At(i);
}
template <index_t... IRs>
__host__ __device__ static constexpr auto ReorderGivenNew2Old(Sequence<IRs...> /*new2old*/)
......@@ -71,14 +79,14 @@ struct Sequence
static_assert(is_valid_sequence_map<Sequence<IRs...>>::value, "wrong! invalid reorder map");
return Sequence<Type::Get(Number<IRs>{})...>{};
return Sequence<Type::At(Number<IRs>{})...>{};
}
// MapOld2New is Sequence<...>
template <class MapOld2New>
template <typename MapOld2New>
__host__ __device__ static constexpr auto ReorderGivenOld2New(MapOld2New)
{
static_assert(MapOld2New::GetSize() == GetSize(),
static_assert(MapOld2New::Size() == Size(),
"wrong! reorder map should have the same size as Sequence to be rerodered");
static_assert(is_valid_sequence_map<MapOld2New>::value, "wrong! invalid reorder map");
......@@ -94,13 +102,13 @@ struct Sequence
__host__ __device__ static constexpr auto Front()
{
static_assert(mSize > 0, "wrong!");
return Get(Number<0>{});
return At(Number<0>{});
}
__host__ __device__ static constexpr auto Back()
{
static_assert(mSize > 0, "wrong!");
return Get(Number<mSize - 1>{});
return At(Number<mSize - 1>{});
}
__host__ __device__ static constexpr auto PopFront() { return sequence_pop_front(Type{}); }
......@@ -134,28 +142,28 @@ struct Sequence
template <index_t... Ns>
__host__ __device__ static constexpr auto Extract(Number<Ns>...)
{
return Sequence<Type::Get(Number<Ns>{})...>{};
return Sequence<Type::At(Number<Ns>{})...>{};
}
template <index_t... Ns>
__host__ __device__ static constexpr auto Extract(Sequence<Ns...>)
{
return Sequence<Type::Get(Number<Ns>{})...>{};
return Sequence<Type::At(Number<Ns>{})...>{};
}
template <index_t I, index_t X>
__host__ __device__ static constexpr auto Modify(Number<I>, Number<X>)
{
static_assert(I < GetSize(), "wrong!");
static_assert(I < Size(), "wrong!");
using seq_split = sequence_split<Type, I>;
constexpr auto seq_left = typename seq_split::SeqType0{};
constexpr auto seq_right = typename seq_split::SeqType1{}.PopFront();
constexpr auto seq_left = typename seq_split::left_type{};
constexpr auto seq_right = typename seq_split::right_type{}.PopFront();
return seq_left.PushBack(Number<X>{}).PushBack(seq_right);
}
template <class F>
template <typename F>
__host__ __device__ static constexpr auto Transform(F f)
{
return Sequence<f(Is)...>{};
......@@ -163,8 +171,11 @@ struct Sequence
};
// merge sequence
template <class, class>
struct sequence_merge;
template <typename Seq, typename... Seqs>
struct sequence_merge
{
using type = typename sequence_merge<Seq, typename sequence_merge<Seqs...>::type>::type;
};
template <index_t... Xs, index_t... Ys>
struct sequence_merge<Sequence<Xs...>, Sequence<Ys...>>
......@@ -172,35 +183,41 @@ struct sequence_merge<Sequence<Xs...>, Sequence<Ys...>>
using type = Sequence<Xs..., Ys...>;
};
// generate sequence
template <index_t IBegin, index_t NRemain, class F>
struct sequence_gen_impl
template <typename Seq>
struct sequence_merge<Seq>
{
static constexpr index_t NRemainLeft = NRemain / 2;
static constexpr index_t NRemainRight = NRemain - NRemainLeft;
static constexpr index_t IMiddle = IBegin + NRemainLeft;
using type =
typename sequence_merge<typename sequence_gen_impl<IBegin, NRemainLeft, F>::type,
typename sequence_gen_impl<IMiddle, NRemainRight, F>::type>::type;
using type = Seq;
};
template <index_t I, class F>
struct sequence_gen_impl<I, 1, F>
// generate sequence
template <index_t NSize, typename F>
struct sequence_gen
{
static constexpr index_t Is = F{}(Number<I>{});
using type = Sequence<Is>;
};
template <index_t IBegin, index_t NRemain, typename G>
struct sequence_gen_impl
{
static constexpr index_t NRemainLeft = NRemain / 2;
static constexpr index_t NRemainRight = NRemain - NRemainLeft;
static constexpr index_t IMiddle = IBegin + NRemainLeft;
template <index_t I, class F>
struct sequence_gen_impl<I, 0, F>
{
using type = Sequence<>;
};
using type = typename sequence_merge<
typename sequence_gen_impl<IBegin, NRemainLeft, G>::type,
typename sequence_gen_impl<IMiddle, NRemainRight, G>::type>::type;
};
template <index_t I, typename G>
struct sequence_gen_impl<I, 1, G>
{
static constexpr index_t Is = G{}(Number<I>{});
using type = Sequence<Is>;
};
template <index_t I, typename G>
struct sequence_gen_impl<I, 0, G>
{
using type = Sequence<>;
};
template <index_t NSize, class F>
struct sequence_gen
{
using type = typename sequence_gen_impl<0, NSize, F>::type;
};
......@@ -232,10 +249,10 @@ struct uniform_sequence_gen
};
// reverse inclusive scan (with init) sequence
template <class, class, index_t>
template <typename, typename, index_t>
struct sequence_reverse_inclusive_scan;
template <index_t I, index_t... Is, class Reduce, index_t Init>
template <index_t I, index_t... Is, typename Reduce, index_t Init>
struct sequence_reverse_inclusive_scan<Sequence<I, Is...>, Reduce, Init>
{
using old_scan = typename sequence_reverse_inclusive_scan<Sequence<Is...>, Reduce, Init>::type;
......@@ -245,41 +262,41 @@ struct sequence_reverse_inclusive_scan<Sequence<I, Is...>, Reduce, Init>
using type = typename sequence_merge<Sequence<new_reduce>, old_scan>::type;
};
template <index_t I, class Reduce, index_t Init>
template <index_t I, typename Reduce, index_t Init>
struct sequence_reverse_inclusive_scan<Sequence<I>, Reduce, Init>
{
using type = Sequence<Reduce{}(I, Init)>;
};
template <class Reduce, index_t Init>
template <typename Reduce, index_t Init>
struct sequence_reverse_inclusive_scan<Sequence<>, Reduce, Init>
{
using type = Sequence<>;
};
// split sequence
template <class Seq, index_t I>
template <typename Seq, index_t I>
struct sequence_split
{
static constexpr index_t NSize = Seq{}.GetSize();
static constexpr index_t NSize = Seq{}.Size();
using range0 = typename arithmetic_sequence_gen<0, I, 1>::type;
using range1 = typename arithmetic_sequence_gen<I, NSize, 1>::type;
using SeqType0 = decltype(Seq::Extract(range0{}));
using SeqType1 = decltype(Seq::Extract(range1{}));
using left_type = decltype(Seq::Extract(range0{}));
using right_type = decltype(Seq::Extract(range1{}));
};
// reverse sequence
template <class Seq>
template <typename Seq>
struct sequence_reverse
{
static constexpr index_t NSize = Seq{}.GetSize();
static constexpr index_t NSize = Seq{}.Size();
using seq_split = sequence_split<Seq, NSize / 2>;
using type = typename sequence_merge<
typename sequence_reverse<typename seq_split::SeqType1>::type,
typename sequence_reverse<typename seq_split::SeqType0>::type>::type;
typename sequence_reverse<typename seq_split::right_type>::type,
typename sequence_reverse<typename seq_split::left_type>::type>::type;
};
template <index_t I>
......@@ -294,44 +311,291 @@ struct sequence_reverse<Sequence<I0, I1>>
using type = Sequence<I1, I0>;
};
template <class Seq>
struct is_valid_sequence_map
#if 1
template <typename Reduce, typename Seq, typename... Seqs>
struct sequence_reduce
{
// not implemented yet, always return true
static constexpr integral_constant<bool, true> value = integral_constant<bool, true>{};
using type = typename sequence_reduce<Reduce,
Seq,
typename sequence_reduce<Reduce, Seqs...>::type>::type;
};
// TODO: add proper check for is_valid, something like:
// static constexpr bool value =
// is_same<typename arithmetic_sequence_gen<0, Seq::GetSize(), 1>::type,
// typename sequence_sort<Seq>::SortedSeqType>{};
template <typename Reduce, index_t... Xs, index_t... Ys>
struct sequence_reduce<Reduce, Sequence<Xs...>, Sequence<Ys...>>
{
using type = Sequence<Reduce{}(Xs, Ys)...>;
};
template <class X2Y, class WorkingY2X, index_t XBegin, index_t XRemain>
struct sequence_map_inverse_impl
template <typename Reduce, typename Seq>
struct sequence_reduce<Reduce, Seq>
{
private:
static constexpr auto new_y2x =
WorkingY2X::Modify(X2Y::Get(Number<XBegin>{}), Number<XBegin>{});
using type = Seq;
};
#endif
public:
using type =
typename sequence_map_inverse_impl<X2Y, decltype(new_y2x), XBegin + 1, XRemain - 1>::type;
template <typename Values, typename Ids, typename Compare>
struct sequence_sort_impl
{
template <typename LeftValues,
typename LeftIds,
typename RightValues,
typename RightIds,
typename MergedValues,
typename MergedIds,
typename Comp>
struct sorted_sequence_merge_impl
{
static constexpr bool choose_left = LeftValues::Front() < RightValues::Front();
static constexpr index_t chosen_value =
choose_left ? LeftValues::Front() : RightValues::Front();
static constexpr index_t chosen_id = choose_left ? LeftIds::Front() : RightIds::Front();
using new_merged_values = decltype(MergedValues::PushBack(Number<chosen_value>{}));
using new_merged_ids = decltype(MergedIds::PushBack(Number<chosen_id>{}));
using new_left_values =
typename conditional<choose_left, decltype(LeftValues::PopFront()), LeftValues>::type;
using new_left_ids =
typename conditional<choose_left, decltype(LeftIds::PopFront()), LeftIds>::type;
using new_right_values =
typename conditional<choose_left, RightValues, decltype(RightValues::PopFront())>::type;
using new_right_ids =
typename conditional<choose_left, RightIds, decltype(RightIds::PopFront())>::type;
using merge = sorted_sequence_merge_impl<new_left_values,
new_left_ids,
new_right_values,
new_right_ids,
new_merged_values,
new_merged_ids,
Comp>;
// this is output
using merged_values = typename merge::merged_values;
using merged_ids = typename merge::merged_ids;
};
template <typename LeftValues,
typename LeftIds,
typename MergedValues,
typename MergedIds,
typename Comp>
struct sorted_sequence_merge_impl<LeftValues,
LeftIds,
Sequence<>,
Sequence<>,
MergedValues,
MergedIds,
Comp>
{
using merged_values = typename sequence_merge<MergedValues, LeftValues>::type;
using merged_ids = typename sequence_merge<MergedIds, LeftIds>::type;
};
template <typename RightValues,
typename RightIds,
typename MergedValues,
typename MergedIds,
typename Comp>
struct sorted_sequence_merge_impl<Sequence<>,
Sequence<>,
RightValues,
RightIds,
MergedValues,
MergedIds,
Comp>
{
using merged_values = typename sequence_merge<MergedValues, RightValues>::type;
using merged_ids = typename sequence_merge<MergedIds, RightIds>::type;
};
template <typename LeftValues,
typename LeftIds,
typename RightValues,
typename RightIds,
typename Comp>
struct sorted_sequence_merge
{
using merge = sorted_sequence_merge_impl<LeftValues,
LeftIds,
RightValues,
RightIds,
Sequence<>,
Sequence<>,
Comp>;
using merged_values = typename merge::merged_values;
using merged_ids = typename merge::merged_ids;
};
static constexpr index_t nsize = Values::Size();
using split_unsorted_values = sequence_split<Values, nsize / 2>;
using split_unsorted_ids = sequence_split<Ids, nsize / 2>;
using left_unsorted_values = typename split_unsorted_values::left_type;
using left_unsorted_ids = typename split_unsorted_ids::left_type;
using left_sort = sequence_sort_impl<left_unsorted_values, left_unsorted_ids, Compare>;
using left_sorted_values = typename left_sort::sorted_values;
using left_sorted_ids = typename left_sort::sorted_ids;
using right_unsorted_values = typename split_unsorted_values::right_type;
using right_unsorted_ids = typename split_unsorted_ids::right_type;
using right_sort = sequence_sort_impl<right_unsorted_values, right_unsorted_ids, Compare>;
using right_sorted_values = typename right_sort::sorted_values;
using right_sorted_ids = typename right_sort::sorted_ids;
using merged_sorted = sorted_sequence_merge<left_sorted_values,
left_sorted_ids,
right_sorted_values,
right_sorted_ids,
Compare>;
using sorted_values = typename merged_sorted::merged_values;
using sorted_ids = typename merged_sorted::merged_ids;
};
template <index_t ValueX, index_t ValueY, index_t IdX, index_t IdY, typename Compare>
struct sequence_sort_impl<Sequence<ValueX, ValueY>, Sequence<IdX, IdY>, Compare>
{
static constexpr bool choose_x = Compare{}(ValueX, ValueY);
using sorted_values =
typename conditional<choose_x, Sequence<ValueX, ValueY>, Sequence<ValueY, ValueX>>::type;
using sorted_ids = typename conditional<choose_x, Sequence<IdX, IdY>, Sequence<IdY, IdX>>::type;
};
template <index_t Value, index_t Id, typename Compare>
struct sequence_sort_impl<Sequence<Value>, Sequence<Id>, Compare>
{
using sorted_values = Sequence<Value>;
using sorted_ids = Sequence<Id>;
};
template <class X2Y, class WorkingY2X, index_t XBegin>
struct sequence_map_inverse_impl<X2Y, WorkingY2X, XBegin, 0>
template <typename Compare>
struct sequence_sort_impl<Sequence<>, Sequence<>, Compare>
{
using type = WorkingY2X;
using sorted_values = Sequence<>;
using sorted_ids = Sequence<>;
};
template <class X2Y>
template <typename Values, typename Compare>
struct sequence_sort
{
using unsorted_ids = typename arithmetic_sequence_gen<0, Values::Size(), 1>::type;
using sort = sequence_sort_impl<Values, unsorted_ids, Compare>;
// this is output
using type = typename sort::sorted_values;
using sorted2unsorted_map = typename sort::sorted_ids;
};
template <typename Values, typename Less, typename Equal>
struct sequence_unique_sort
{
template <typename RemainValues,
typename RemainIds,
typename UniquifiedValues,
typename UniquifiedIds,
typename Eq>
struct sorted_sequence_uniquify_impl
{
static constexpr index_t current_value = RemainValues::Front();
static constexpr index_t current_id = RemainIds::Front();
static constexpr bool is_unique_value = (current_value != UniquifiedValues::Back());
using new_remain_values = decltype(RemainValues::PopFront());
using new_remain_ids = decltype(RemainIds::PopFront());
using new_uniquified_values =
typename conditional<is_unique_value,
decltype(UniquifiedValues::PushBack(Number<current_value>{})),
UniquifiedValues>::type;
using new_uniquified_ids =
typename conditional<is_unique_value,
decltype(UniquifiedIds::PushBack(Number<current_id>{})),
UniquifiedIds>::type;
using uniquify = sorted_sequence_uniquify_impl<new_remain_values,
new_remain_ids,
new_uniquified_values,
new_uniquified_ids,
Eq>;
// this is output
using uniquified_values = typename uniquify::uniquified_values;
using uniquified_ids = typename uniquify::uniquified_ids;
};
template <typename UniquifiedValues, typename UniquifiedIds, typename Eq>
struct sorted_sequence_uniquify_impl<Sequence<>,
Sequence<>,
UniquifiedValues,
UniquifiedIds,
Eq>
{
using uniquified_values = UniquifiedValues;
using uniquified_ids = UniquifiedIds;
};
template <typename SortedValues, typename SortedIds, typename Eq>
struct sorted_sequence_uniquify
{
using uniquify = sorted_sequence_uniquify_impl<decltype(SortedValues::PopFront()),
decltype(SortedIds::PopFront()),
Sequence<SortedValues::Front()>,
Sequence<SortedIds::Front()>,
Eq>;
using uniquified_values = typename uniquify::uniquified_values;
using uniquified_ids = typename uniquify::uniquified_ids;
};
using sort = sequence_sort<Values, Less>;
using sorted_values = typename sort::type;
using sorted_ids = typename sort::sorted2unsorted_map;
using uniquify = sorted_sequence_uniquify<sorted_values, sorted_ids, Equal>;
// this is output
using type = typename uniquify::uniquified_values;
using sorted2unsorted_map = typename uniquify::uniquified_ids;
};
template <typename SeqMap>
struct is_valid_sequence_map : is_same<typename arithmetic_sequence_gen<0, SeqMap::Size(), 1>::type,
typename sequence_sort<SeqMap, math::less<index_t>>::type>
{
};
template <typename SeqMap>
struct sequence_map_inverse
{
template <typename X2Y, typename WorkingY2X, index_t XBegin, index_t XRemain>
struct sequence_map_inverse_impl
{
static constexpr auto new_y2x =
WorkingY2X::Modify(X2Y::At(Number<XBegin>{}), Number<XBegin>{});
using type =
typename sequence_map_inverse_impl<X2Y, decltype(new_y2x), XBegin + 1, XRemain - 1>::
type;
};
template <typename X2Y, typename WorkingY2X, index_t XBegin>
struct sequence_map_inverse_impl<X2Y, WorkingY2X, XBegin, 0>
{
using type = WorkingY2X;
};
using type =
typename sequence_map_inverse_impl<X2Y,
typename uniform_sequence_gen<X2Y::GetSize(), 0>::type,
typename sequence_map_inverse_impl<SeqMap,
typename uniform_sequence_gen<SeqMap::Size(), 0>::type,
0,
X2Y::GetSize()>::type;
SeqMap::Size()>::type;
};
template <index_t... Xs, index_t... Ys>
......@@ -442,20 +706,26 @@ __host__ __device__ constexpr auto sequence_pop_front(Sequence<I, Is...>)
return Sequence<Is...>{};
}
template <class Seq>
template <typename Seq>
__host__ __device__ constexpr auto sequence_pop_back(Seq)
{
static_assert(Seq::GetSize() > 0, "wrong! cannot pop an empty Sequence!");
static_assert(Seq::Size() > 0, "wrong! cannot pop an empty Sequence!");
return sequence_pop_front(Seq::Reverse()).Reverse();
}
template <class F, index_t... Xs>
template <typename... Seqs>
__host__ __device__ constexpr auto merge_sequences(Seqs...)
{
return typename sequence_merge<Seqs...>::type{};
}
template <typename F, index_t... Xs>
__host__ __device__ constexpr auto transform_sequences(F f, Sequence<Xs...>)
{
return Sequence<f(Xs)...>{};
}
template <class F, index_t... Xs, index_t... Ys>
template <typename F, index_t... Xs, index_t... Ys>
__host__ __device__ constexpr auto transform_sequences(F f, Sequence<Xs...>, Sequence<Ys...>)
{
static_assert(Sequence<Xs...>::mSize == Sequence<Ys...>::mSize, "Dim not the same");
......@@ -463,7 +733,7 @@ __host__ __device__ constexpr auto transform_sequences(F f, Sequence<Xs...>, Seq
return Sequence<f(Xs, Ys)...>{};
}
template <class F, index_t... Xs, index_t... Ys, index_t... Zs>
template <typename F, index_t... Xs, index_t... Ys, index_t... Zs>
__host__ __device__ constexpr auto
transform_sequences(F f, Sequence<Xs...>, Sequence<Ys...>, Sequence<Zs...>)
{
......@@ -474,52 +744,123 @@ transform_sequences(F f, Sequence<Xs...>, Sequence<Ys...>, Sequence<Zs...>)
return Sequence<f(Xs, Ys, Zs)...>{};
}
template <class Seq, class Reduce, index_t Init>
template <typename Seq, typename Reduce, index_t Init>
__host__ __device__ constexpr auto reverse_inclusive_scan_sequence(Seq, Reduce, Number<Init>)
{
return typename sequence_reverse_inclusive_scan<Seq, Reduce, Init>::type{};
}
template <class Seq, class Reduce, index_t Init>
template <typename Seq, typename Reduce, index_t Init>
__host__ __device__ constexpr auto inclusive_scan_sequence(Seq, Reduce, Number<Init>)
{
return reverse_inclusive_scan_sequence(Seq{}.Reverse(), Reduce{}, Number<Init>{}).Reverse();
}
template <index_t... Xs>
__host__ __device__ void print_Sequence(const char* s, Sequence<Xs...>)
template <typename Seq, index_t... Is>
__host__ __device__ constexpr auto pick_sequence_elements_by_ids(Seq, Sequence<Is...> /* ids */)
{
return Sequence<Seq::At(Number<Is>{})...>{};
}
#if 1
namespace detail {
template <typename WorkSeq, typename RemainSeq, typename RemainMask>
struct pick_sequence_elements_by_mask_impl
{
constexpr index_t nsize = Sequence<Xs...>::GetSize();
using new_work_seq = typename conditional<RemainMask::Front(),
decltype(WorkSeq::PushBack(RemainSeq::Front())),
WorkSeq>::type;
static_assert(nsize <= 10, "wrong!");
using type =
typename pick_sequence_elements_by_mask_impl<new_work_seq,
decltype(RemainSeq::PopFront()),
decltype(RemainMask::PopFront())>::type;
};
template <typename WorkSeq>
struct pick_sequence_elements_by_mask_impl<WorkSeq, Sequence<>, Sequence<>>
{
using type = WorkSeq;
};
} // namespace detail
template <typename Seq, typename Mask>
__host__ __device__ constexpr auto pick_sequence_elements_by_mask(Seq, Mask)
{
static_assert(Seq::Size() == Mask::Size(), "wrong!");
return typename detail::pick_sequence_elements_by_mask_impl<Sequence<>, Seq, Mask>::type{};
}
namespace detail {
template <typename WorkSeq, typename RemainValues, typename RemainIds>
struct modify_sequence_elements_by_ids_impl
{
using new_work_seq = decltype(WorkSeq::Modify(RemainIds::Front(), RemainValues::Front()));
using type =
typename modify_sequence_elements_by_ids_impl<new_work_seq,
decltype(RemainValues::PopFront()),
decltype(RemainIds::PopFront())>::type;
};
template <typename WorkSeq>
struct modify_sequence_elements_by_ids_impl<WorkSeq, Sequence<>, Sequence<>>
{
using type = WorkSeq;
};
} // namespace detail
static_if<nsize == 0>{}([&](auto) { printf("%s size %u, {}\n", s, nsize, Xs...); });
template <typename Seq, typename Values, typename Ids>
__host__ __device__ constexpr auto modify_sequence_elements_by_ids(Seq, Values, Ids)
{
static_assert(Values::Size() == Ids::Size() && Seq::Size() >= Values::Size(), "wrong!");
static_if<nsize == 1>{}([&](auto) { printf("%s size %u, {%u}\n", s, nsize, Xs...); });
return typename detail::modify_sequence_elements_by_ids_impl<Seq, Values, Ids>::type{};
}
#endif
static_if<nsize == 2>{}([&](auto) { printf("%s size %u, {%u %u}\n", s, nsize, Xs...); });
template <typename Seq, typename Reduce, index_t Init>
__host__ __device__ constexpr index_t
reduce_on_sequence(Seq, Reduce f, Number<Init> /*initial_value*/)
{
index_t result = Init;
static_if<nsize == 3>{}([&](auto) { printf("%s size %u, {%u %u %u}\n", s, nsize, Xs...); });
for(index_t i = 0; i < Seq::Size(); ++i)
{
result = f(result, Seq::At(i));
}
static_if<nsize == 4>{}([&](auto) { printf("%s size %u, {%u %u %u %u}\n", s, nsize, Xs...); });
return result;
}
static_if<nsize == 5>{}(
[&](auto) { printf("%s size %u, {%u %u %u %u %u}\n", s, nsize, Xs...); });
// TODO: a generic any_of for any container
template <typename Seq, typename F>
__host__ __device__ constexpr bool sequence_any_of(Seq, F f)
{
bool flag = false;
static_if<nsize == 6>{}(
[&](auto) { printf("%s size %u, {%u %u %u %u %u %u}\n", s, nsize, Xs...); });
for(index_t i = 0; i < Seq::Size(); ++i)
{
flag = flag || f(Seq::At(i));
}
static_if<nsize == 7>{}(
[&](auto) { printf("%s size %u, {%u %u %u %u %u %u %u}\n", s, nsize, Xs...); });
return flag;
}
static_if<nsize == 8>{}(
[&](auto) { printf("%s size %u, {%u %u %u %u %u %u %u %u}\n", s, nsize, Xs...); });
// TODO: a generic all_of for any container
template <typename Seq, typename F>
__host__ __device__ constexpr bool sequence_all_of(Seq, F f)
{
bool flag = true;
static_if<nsize == 9>{}(
[&](auto) { printf("%s size %u, {%u %u %u %u %u %u %u %u %u}\n", s, nsize, Xs...); });
for(index_t i = 0; i < Seq::Size(); ++i)
{
flag = flag && f(Seq::At(i));
}
static_if<nsize == 10>{}(
[&](auto) { printf("%s size %u, {%u %u %u %u %u %u %u %u %u %u}\n", s, nsize, Xs...); });
return flag;
}
} // namespace ck
......
#ifndef CK_SEQUENCE_HELPER_HPP
#define CK_SEQUENCE_HELPER_HPP
#include "sequence.hpp"
namespace ck {
template <index_t... Xs>
__host__ __device__ void print_sequence(const char* s, Sequence<Xs...>)
{
constexpr index_t nsize = Sequence<Xs...>::Size();
static_assert(nsize <= 10, "wrong!");
static_if<nsize == 0>{}([&](auto) { printf("%s size %u, {}\n", s, nsize, Xs...); });
static_if<nsize == 1>{}([&](auto) { printf("%s size %u, {%u}\n", s, nsize, Xs...); });
static_if<nsize == 2>{}([&](auto) { printf("%s size %u, {%u %u}\n", s, nsize, Xs...); });
static_if<nsize == 3>{}([&](auto) { printf("%s size %u, {%u %u %u}\n", s, nsize, Xs...); });
static_if<nsize == 4>{}([&](auto) { printf("%s size %u, {%u %u %u %u}\n", s, nsize, Xs...); });
static_if<nsize == 5>{}(
[&](auto) { printf("%s size %u, {%u %u %u %u %u}\n", s, nsize, Xs...); });
static_if<nsize == 6>{}(
[&](auto) { printf("%s size %u, {%u %u %u %u %u %u}\n", s, nsize, Xs...); });
static_if<nsize == 7>{}(
[&](auto) { printf("%s size %u, {%u %u %u %u %u %u %u}\n", s, nsize, Xs...); });
static_if<nsize == 8>{}(
[&](auto) { printf("%s size %u, {%u %u %u %u %u %u %u %u}\n", s, nsize, Xs...); });
static_if<nsize == 9>{}(
[&](auto) { printf("%s size %u, {%u %u %u %u %u %u %u %u %u}\n", s, nsize, Xs...); });
static_if<nsize == 10>{}(
[&](auto) { printf("%s size %u, {%u %u %u %u %u %u %u %u %u %u}\n", s, nsize, Xs...); });
}
} // namespace ck
#endif
#ifndef CK_TUPLE_HPP
#define CK_TUPLE_HPP
#include "integral_constant.hpp"
#include "type.hpp"
#include "sequence.hpp"
namespace ck {
namespace detail {
template <index_t>
struct TupleElementKey
{
};
template <typename Key, typename Data>
struct TupleElement
{
__host__ __device__ explicit constexpr TupleElement() : mData() {}
template <typename T>
__host__ __device__ explicit constexpr TupleElement(T&& v) : mData(static_cast<T&&>(v))
{
}
Data mData;
};
template <typename Key, typename Data>
__host__ __device__ constexpr const Data& get_tuple_element(const TupleElement<Key, Data>& x)
{
return x.mData;
}
template <typename Key, typename Data>
__host__ __device__ constexpr Data& get_tuple_element(TupleElement<Key, Data>& x)
{
return x.mData;
}
template <typename Key, typename Data>
__host__ __device__ constexpr Data&& get_tuple_element(TupleElement<Key, Data>&& x)
{
return static_cast<Data&&>(x.mData);
}
template <typename Indices, typename... Xs>
struct TupleImpl;
template <index_t... Is, typename... Xs>
struct TupleImpl<Sequence<Is...>, Xs...> : TupleElement<TupleElementKey<Is>, Xs>...
{
__host__ __device__ explicit constexpr TupleImpl() : TupleElement<TupleElementKey<Is>, Xs>()...
{
}
template <typename... Ys>
__host__ __device__ explicit constexpr TupleImpl(Ys&&... ys)
: TupleElement<TupleElementKey<Is>, Xs>(static_cast<Ys&&>(ys))...
{
}
__host__ __device__ static constexpr index_t Size() { return sizeof...(Xs); }
template <index_t I>
__host__ __device__ constexpr const auto& GetElementByKey(TupleElementKey<I>) const
{
return get_tuple_element<TupleElementKey<I>>(*this);
}
template <index_t I>
__host__ __device__ constexpr auto& GetElementByKey(TupleElementKey<I>)
{
return get_tuple_element<TupleElementKey<I>>(*this);
}
};
} // namespace detail
template <typename... Xs>
struct Tuple : detail::TupleImpl<typename arithmetic_sequence_gen<0, sizeof...(Xs), 1>::type, Xs...>
{
using base =
detail::TupleImpl<typename arithmetic_sequence_gen<0, sizeof...(Xs), 1>::type, Xs...>;
template <typename... Ys>
__host__ __device__ explicit constexpr Tuple(Ys&&... ys) : base(static_cast<Ys&&>(ys)...)
{
}
template <index_t I>
__host__ __device__ constexpr const auto& At(Number<I>) const
{
static_assert(I < base::Size(), "wrong! out of range");
return base::GetElementByKey(detail::TupleElementKey<I>{});
}
template <index_t I>
__host__ __device__ constexpr auto& At(Number<I>)
{
static_assert(I < base::Size(), "wrong! out of range");
return base::GetElementByKey(detail::TupleElementKey<I>{});
}
};
template <typename... Xs>
__host__ __device__ constexpr auto make_tuple(Xs&&... xs)
{
return Tuple<remove_cv_t<remove_reference_t<Xs>>...>(std::forward<Xs>(xs)...);
}
namespace detail {
template <typename F, typename X, index_t... Is>
__host__ __device__ constexpr auto transform_tuples_impl(F f, const X& x, Sequence<Is...>)
{
return make_tuple(f(x.At(Number<Is>{}))...);
}
template <typename F, typename X, typename Y, index_t... Is>
__host__ __device__ constexpr auto
transform_tuples_impl(F f, const X& x, const Y& y, Sequence<Is...>)
{
return make_tuple(f(x.At(Number<Is>{}), y.At(Number<Is>{}))...);
}
template <typename F, typename X, typename Y, typename Z, index_t... Is>
__host__ __device__ constexpr auto
transform_tuples_impl(F f, const X& x, const Y& y, const Z& z, Sequence<Is...>)
{
return make_tuple(f(x.At(Number<Is>{}), y.At(Number<Is>{}), z.At(Number<Is>{}))...);
}
} // namespace detail
template <typename F, typename X>
__host__ __device__ constexpr auto transform_tuples(F f, const X& x)
{
return detail::transform_tuples_impl(
f, x, typename arithmetic_sequence_gen<0, X::Size(), 1>::type{});
}
template <typename F, typename X, typename Y>
__host__ __device__ constexpr auto transform_tuples(F f, const X& x, const Y& y)
{
return detail::transform_tuples_impl(
f, x, y, typename arithmetic_sequence_gen<0, X::Size(), 1>::type{});
}
template <typename F, typename X, typename Y, typename Z>
__host__ __device__ constexpr auto transform_tuples(F f, const X& x, const Y& y, const Z& z)
{
return detail::transform_tuples_impl(
f, x, y, z, typename arithmetic_sequence_gen<0, X::Size(), 1>::type{});
}
} // namespace ck
#endif
#ifndef CK_TYPE_HPP
#define CK_TYPE_HPP
#include "integral_constant.hpp"
namespace ck {
template <index_t... Is>
struct Sequence;
template <typename X, typename Y>
struct is_same : public integral_constant<bool, false>
{
};
template <typename X>
struct is_same<X, X> : public integral_constant<bool, true>
{
};
template <typename>
struct is_static : integral_constant<bool, false>
{
};
template <typename T, T X>
struct is_static<integral_constant<T, X>> : integral_constant<bool, true>
{
};
template <index_t... Is>
struct is_static<Sequence<Is...>> : integral_constant<bool, true>
{
};
template <typename T>
using remove_reference_t = typename std::remove_reference<T>::type;
template <typename T>
using remove_cv_t = typename std::remove_cv<T>::type;
} // namespace ck
#endif
......@@ -14,7 +14,7 @@ struct vector_type
template <>
struct vector_type<float, 1>
{
typedef float MemoryType;
using MemoryType = float;
template <index_t I>
__host__ __device__ static void SetScalar(MemoryType& v, float s, Number<I>)
......@@ -64,6 +64,24 @@ struct vector_type<float, 4>
}
};
template <>
struct vector_type<const float, 1>
{
using MemoryType = const float;
};
template <>
struct vector_type<const float, 2>
{
using MemoryType = const float2_t;
};
template <>
struct vector_type<const float, 4>
{
using MemoryType = const float4_t;
};
} // namespace ck
#endif
......@@ -107,42 +107,11 @@ void device_convolution_implicit_gemm_v1_chwn_cyxk_khwn(InDesc,
constexpr index_t GemmDataPerReadB = 4;
using InBlockCopyClusterLengths_CHWN = Sequence<4, 4, 2, 4>;
constexpr index_t InBlockCopyDataPerRead_N = 4;
constexpr index_t InBlockCopyDataPerAccess_N = 4;
constexpr index_t WeiBlockCopyDataPerRead_K = 4;
constexpr index_t WeiBlockCopyDataPerAccess_K = 4;
constexpr index_t OutThreadCopyDataPerWrite_N = 2;
#elif 0
// for 3x3, 34x34, v1r2, Pascal, in-block-copy1
constexpr index_t BlockSize = 128;
constexpr index_t NPerBlock = 4;
constexpr index_t KPerBlock = 64;
constexpr index_t CPerBlock = 8;
constexpr index_t HoPerBlock = 4;
constexpr index_t WoPerBlock = 8;
constexpr index_t NPerThread = 4;
constexpr index_t KPerThread = 8;
constexpr index_t HoPerThread = 1;
constexpr index_t WoPerThread = 2;
constexpr index_t GemmMPerThreadSubC = 4;
constexpr index_t GemmNPerThreadSubC = 4;
constexpr index_t GemmMLevel0Cluster = 4;
constexpr index_t GemmNLevel0Cluster = 2;
constexpr index_t GemmMLevel1Cluster = 2;
constexpr index_t GemmNLevel1Cluster = 2;
constexpr index_t GemmKPerThreadLoop = 1;
constexpr index_t GemmDataPerReadA = 4;
constexpr index_t GemmDataPerReadB = 4;
using InBlockCopyClusterLengths_CHWN = Sequence<0, 0, 0, 0>; // not used
constexpr index_t InBlockCopyDataPerRead_N = 4;
constexpr index_t WeiBlockCopyDataPerRead_K = 4;
constexpr index_t OutThreadCopyDataPerWrite_N = 2;
constexpr index_t OutThreadCopyDataPerAccess_N = 2;
#elif 1
// for 3x3, 34x34, v1r3, Pascal
// for 3x3, 28x28, v1r3, Pascal
......@@ -170,43 +139,15 @@ void device_convolution_implicit_gemm_v1_chwn_cyxk_khwn(InDesc,
constexpr index_t GemmDataPerReadA = 4;
constexpr index_t GemmDataPerReadB = 4;
using InBlockCopyClusterLengths_CHWN = Sequence<8, 2, 2, 4>;
constexpr index_t InBlockCopyDataPerRead_N = 4;
using InBlockCopySubLengths_CHWN = Sequence<1, 1, 1, 4>;
using InBlockCopyClusterLengths_CHWN = Sequence<8, 2, 2, 4>;
constexpr index_t InBlockCopyDataPerAccess_N = 4;
constexpr index_t WeiBlockCopyDataPerRead_K = 4;
using WeiBlockCopySubLengths_CK = Sequence<2, 4>;
using WeiBlockCopyClusterLengths_CK = Sequence<4, 32>;
constexpr index_t WeiBlockCopyDataPerAccess_K = 4;
constexpr index_t OutThreadCopyDataPerWrite_N = 2;
#elif 0
// for 3x3, 34x34, v1r3, Pascal, bad
constexpr index_t BlockSize = 128;
constexpr index_t NPerBlock = 1;
constexpr index_t KPerBlock = 128;
constexpr index_t CPerBlock = 8;
constexpr index_t HoPerBlock = 2;
constexpr index_t WoPerBlock = 32;
constexpr index_t NPerThread = 1;
constexpr index_t KPerThread = 8;
constexpr index_t HoPerThread = 1;
constexpr index_t WoPerThread = 8;
constexpr index_t GemmMPerThreadSubC = 4;
constexpr index_t GemmNPerThreadSubC = 4;
constexpr index_t GemmMLevel0Cluster = 4;
constexpr index_t GemmNLevel0Cluster = 2;
constexpr index_t GemmMLevel1Cluster = 4;
constexpr index_t GemmNLevel1Cluster = 2;
constexpr index_t GemmKPerThreadLoop = 1;
constexpr index_t GemmDataPerReadA = 4;
constexpr index_t GemmDataPerReadB = 4;
using InBlockCopyClusterLengths_CHWN = Sequence<2, 2, 32, 1>;
constexpr index_t InBlockCopyDataPerRead_N = 1;
constexpr index_t WeiBlockCopyDataPerRead_K = 2;
constexpr index_t OutThreadCopyDataPerWrite_N = 1;
constexpr index_t OutThreadCopyDataPerAccess_N = 2;
#elif 0
// for 3x3, 34x34, v1r1, Vega 20
constexpr index_t BlockSize = 256;
......@@ -232,12 +173,12 @@ void device_convolution_implicit_gemm_v1_chwn_cyxk_khwn(InDesc,
constexpr index_t GemmDataPerReadA = 4;
constexpr index_t GemmDataPerReadB = 4;
using InBlockCopyClusterLengths_CHWN = Sequence<4, 4, 2, 8>;
constexpr index_t InBlockCopyDataPerRead_N = 2;
using InBlockCopyClusterLengths_CHWN = Sequence<4, 4, 2, 8>;
constexpr index_t InBlockCopyDataPerAccess_N = 2;
constexpr index_t WeiBlockCopyDataPerRead_K = 2;
constexpr index_t WeiBlockCopyDataPerAccess_K = 2;
constexpr index_t OutThreadCopyDataPerWrite_N = 4;
constexpr index_t OutThreadCopyDataPerAccess_N = 4;
#elif 1
// for 3x3, 34x34, v1r3, Vega 20
constexpr index_t BlockSize = 256;
......@@ -263,12 +204,15 @@ void device_convolution_implicit_gemm_v1_chwn_cyxk_khwn(InDesc,
constexpr index_t GemmDataPerReadA = 4;
constexpr index_t GemmDataPerReadB = 4;
using InBlockCopyClusterLengths_CHWN = Sequence<8, 2, 4, 4>;
constexpr index_t InBlockCopyDataPerRead_N = 4;
using InBlockCopySubLengths_CHWN = Sequence<1, 1, 1, 4>;
using InBlockCopyClusterLengths_CHWN = Sequence<8, 2, 4, 4>;
constexpr index_t InBlockCopyDataPerAccess_N = 4;
constexpr index_t WeiBlockCopyDataPerRead_K = 4;
using WeiBlockCopySubLengths_CK = Sequence<1, 4>;
using WeiBlockCopyClusterLengths_CK = Sequence<8, 32>;
constexpr index_t WeiBlockCopyDataPerAccess_K = 4;
constexpr index_t OutThreadCopyDataPerWrite_N = 4;
constexpr index_t OutThreadCopyDataPerAccess_N = 4;
#elif 0
// for 3x3, 56x56, v1r1, Pascal
constexpr index_t NPerBlock = 32;
......@@ -282,13 +226,13 @@ void device_convolution_implicit_gemm_v1_chwn_cyxk_khwn(InDesc,
constexpr index_t HoPerThread = 1;
constexpr index_t WoPerThread = 2;
constexpr index_t InBlockCopy_ThreadPerDimC = 1;
constexpr index_t InBlockCopy_ThreadPerDimH = 4;
constexpr index_t InBlockCopy_ThreadPerDimW = 4;
constexpr index_t InBlockCopy_ThreadPerDimN = 8;
constexpr index_t InBlockCopyDataPerRead_N = 4;
constexpr index_t InBlockCopy_ThreadPerDimC = 1;
constexpr index_t InBlockCopy_ThreadPerDimH = 4;
constexpr index_t InBlockCopy_ThreadPerDimW = 4;
constexpr index_t InBlockCopy_ThreadPerDimN = 8;
constexpr index_t InBlockCopyDataPerAccess_N = 4;
constexpr index_t WeiBlockCopyDataPerRead_K = 4;
constexpr index_t WeiBlockCopyDataPerAccess_K = 4;
constexpr index_t GemmMPerThreadSubC = 4;
constexpr index_t GemmNPerThreadSubC = 4;
......@@ -298,7 +242,7 @@ void device_convolution_implicit_gemm_v1_chwn_cyxk_khwn(InDesc,
constexpr index_t GemmNLevel1Cluster = 4;
constexpr index_t GemmKPerThreadLoop = 1;
constexpr index_t OutThreadCopyDataPerWrite_N = 2;
constexpr index_t OutThreadCopyDataPerAccess_N = 2;
constexpr index_t BlockSize = 128;
#elif 0
......@@ -324,14 +268,14 @@ void device_convolution_implicit_gemm_v1_chwn_cyxk_khwn(InDesc,
constexpr index_t GemmDataPerReadA = 1;
constexpr index_t GemmDataPerReadB = 1;
constexpr index_t InBlockCopy_ThreadPerDimC = 1;
constexpr index_t InBlockCopy_ThreadPerDimH = 2;
constexpr index_t InBlockCopy_ThreadPerDimW = 4;
constexpr index_t InBlockCopy_ThreadPerDimN = 4;
constexpr index_t InBlockCopyDataPerRead_N = 4;
constexpr index_t InBlockCopy_ThreadPerDimC = 1;
constexpr index_t InBlockCopy_ThreadPerDimH = 2;
constexpr index_t InBlockCopy_ThreadPerDimW = 4;
constexpr index_t InBlockCopy_ThreadPerDimN = 4;
constexpr index_t InBlockCopyDataPerAccess_N = 4;
constexpr index_t WeiBlockCopyDataPerRead_K = 4;
constexpr index_t OutThreadCopyDataPerWrite_N = 4;
constexpr index_t WeiBlockCopyDataPerAccess_K = 4;
constexpr index_t OutThreadCopyDataPerAccess_N = 4;
constexpr index_t BlockSize = 128;
#elif 0
......@@ -347,13 +291,13 @@ void device_convolution_implicit_gemm_v1_chwn_cyxk_khwn(InDesc,
constexpr index_t HoPerThread = 1;
constexpr index_t WoPerThread = 2;
constexpr index_t InBlockCopy_ThreadPerDimC = 1;
constexpr index_t InBlockCopy_ThreadPerDimH = 4;
constexpr index_t InBlockCopy_ThreadPerDimW = 4;
constexpr index_t InBlockCopy_ThreadPerDimN = 8;
constexpr index_t InBlockCopyDataPerRead_N = 4;
constexpr index_t InBlockCopy_ThreadPerDimC = 1;
constexpr index_t InBlockCopy_ThreadPerDimH = 4;
constexpr index_t InBlockCopy_ThreadPerDimW = 4;
constexpr index_t InBlockCopy_ThreadPerDimN = 8;
constexpr index_t InBlockCopyDataPerAccess_N = 4;
constexpr index_t WeiBlockCopyDataPerRead_K = 4;
constexpr index_t WeiBlockCopyDataPerAccess_K = 4;
constexpr index_t GemmMPerThreadSubC = 4;
constexpr index_t GemmNPerThreadSubC = 4;
......@@ -365,7 +309,7 @@ void device_convolution_implicit_gemm_v1_chwn_cyxk_khwn(InDesc,
constexpr index_t GemmDataPerReadA = 4;
constexpr index_t GemmDataPerReadB = 4;
constexpr index_t OutThreadCopyDataPerWrite_N = 2;
constexpr index_t OutThreadCopyDataPerAccess_N = 2;
constexpr index_t BlockSize = 128;
#elif 0
......@@ -393,12 +337,12 @@ void device_convolution_implicit_gemm_v1_chwn_cyxk_khwn(InDesc,
constexpr index_t GemmDataPerReadA = 4;
constexpr index_t GemmDataPerReadB = 4;
using InBlockCopyClusterLengths_CHWN = Sequence<4, 2, 4, 4>;
constexpr index_t InBlockCopyDataPerRead_N = 4;
using InBlockCopyClusterLengths_CHWN = Sequence<4, 2, 4, 4>;
constexpr index_t InBlockCopyDataPerAccess_N = 4;
constexpr index_t WeiBlockCopyDataPerRead_K = 4;
constexpr index_t WeiBlockCopyDataPerAccess_K = 4;
constexpr index_t OutThreadCopyDataPerWrite_N = 2;
constexpr index_t OutThreadCopyDataPerAccess_N = 2;
#elif 0
// for 1x1, 28x28, v1r1, Pascal
constexpr index_t NPerBlock = 16;
......@@ -413,13 +357,13 @@ void device_convolution_implicit_gemm_v1_chwn_cyxk_khwn(InDesc,
constexpr index_t HoPerThread = 1;
constexpr index_t WoPerThread = 1;
constexpr index_t InBlockCopy_ThreadPerDimC = 8;
constexpr index_t InBlockCopy_ThreadPerDimH = 2;
constexpr index_t InBlockCopy_ThreadPerDimW = 2;
constexpr index_t InBlockCopy_ThreadPerDimN = 4;
constexpr index_t InBlockCopyDataPerRead_N = 4;
constexpr index_t InBlockCopy_ThreadPerDimC = 8;
constexpr index_t InBlockCopy_ThreadPerDimH = 2;
constexpr index_t InBlockCopy_ThreadPerDimW = 2;
constexpr index_t InBlockCopy_ThreadPerDimN = 4;
constexpr index_t InBlockCopyDataPerAccess_N = 4;
constexpr index_t WeiBlockCopyDataPerRead_K = 4;
constexpr index_t WeiBlockCopyDataPerAccess_K = 4;
constexpr index_t GemmMPerThreadSubC = 4;
constexpr index_t GemmNPerThreadSubC = 4;
......@@ -429,7 +373,7 @@ void device_convolution_implicit_gemm_v1_chwn_cyxk_khwn(InDesc,
constexpr index_t GemmNLevel1Cluster = 4;
constexpr index_t GemmKPerThreadLoop = 1;
constexpr index_t OutThreadCopyDataPerWrite_N = 2;
constexpr index_t OutThreadCopyDataPerAccess_N = 2;
constexpr index_t BlockSize = 128;
#elif 0
......@@ -453,65 +397,67 @@ void device_convolution_implicit_gemm_v1_chwn_cyxk_khwn(InDesc,
constexpr index_t GemmNLevel1Cluster = 2;
constexpr index_t GemmKPerThreadLoop = 1;
constexpr index_t InBlockCopy_ThreadPerDimC = 8;
constexpr index_t InBlockCopy_ThreadPerDimH = 2;
constexpr index_t InBlockCopy_ThreadPerDimW = 2;
constexpr index_t InBlockCopy_ThreadPerDimN = 4;
constexpr index_t InBlockCopyDataPerRead_N = 4;
constexpr index_t InBlockCopy_ThreadPerDimC = 8;
constexpr index_t InBlockCopy_ThreadPerDimH = 2;
constexpr index_t InBlockCopy_ThreadPerDimW = 2;
constexpr index_t InBlockCopy_ThreadPerDimN = 4;
constexpr index_t InBlockCopyDataPerAccess_N = 4;
constexpr index_t WeiBlockCopyDataPerRead_K = 4;
constexpr index_t OutThreadCopyDataPerWrite_N = 2;
constexpr index_t WeiBlockCopyDataPerAccess_K = 4;
constexpr index_t OutThreadCopyDataPerAccess_N = 2;
constexpr index_t BlockSize = 128;
#endif
constexpr index_t GridSize =
((N + NPerBlock - 1) / NPerBlock) * ((K + KPerBlock - 1) / KPerBlock) *
((Ho + HoPerBlock - 1) / HoPerBlock) * ((Wo + WoPerBlock - 1) / WoPerBlock);
(N / NPerBlock) * (K / KPerBlock) * (Ho / HoPerBlock) * (Wo / WoPerBlock);
printf("%s: BlockSize %u, GridSize %u \n", __func__, BlockSize, GridSize);
for(index_t i = 0; i < nrepeat; ++i)
{
constexpr auto gridwise_conv =
constexpr auto gridwise_conv =
#if 0
GridwiseConvolutionImplicitGemm_v1r1_chwn_cyxk_khwn
GridwiseConvolutionImplicitGemm_v1r1_chwn_cyxk_khwn
#elif 0
GridwiseConvolutionImplicitGemm_v1r2_chwn_cyxk_khwn
#elif 1
GridwiseConvolutionImplicitGemm_v1r3_chwn_cyxk_khwn
GridwiseConvolutionImplicitGemm_v1r2_chwn_cyxk_khwn
#elif 0
GridwiseConvolutionImplicitGemm_v1r3_chwn_cyxk_khwn_lds_double_buffer
GridwiseConvolutionImplicitGemm_v1r3_chwn_cyxk_khwn
#elif 1
GridwiseConvolutionImplicitGemm_v1r3_chwn_cyxk_khwn_lds_double_buffer
#endif
<GridSize,
BlockSize,
T,
decltype(in_chwn_desc),
decltype(wei_cyxk_desc),
decltype(out_khwn_desc),
NPerBlock,
KPerBlock,
CPerBlock,
HoPerBlock,
WoPerBlock,
NPerThread,
KPerThread,
HoPerThread,
WoPerThread,
GemmMPerThreadSubC,
GemmNPerThreadSubC,
GemmMLevel0Cluster,
GemmNLevel0Cluster,
GemmMLevel1Cluster,
GemmNLevel1Cluster,
GemmKPerThreadLoop,
GemmDataPerReadA,
GemmDataPerReadB,
InBlockCopyClusterLengths_CHWN,
InBlockCopyDataPerRead_N,
WeiBlockCopyDataPerRead_K,
OutThreadCopyDataPerWrite_N>{};
<GridSize,
BlockSize,
T,
decltype(in_chwn_desc),
decltype(wei_cyxk_desc),
decltype(out_khwn_desc),
NPerBlock,
KPerBlock,
CPerBlock,
HoPerBlock,
WoPerBlock,
NPerThread,
KPerThread,
HoPerThread,
WoPerThread,
GemmMPerThreadSubC,
GemmNPerThreadSubC,
GemmMLevel0Cluster,
GemmNLevel0Cluster,
GemmMLevel1Cluster,
GemmNLevel1Cluster,
GemmKPerThreadLoop,
GemmDataPerReadA,
GemmDataPerReadB,
InBlockCopySubLengths_CHWN,
InBlockCopyClusterLengths_CHWN,
InBlockCopyDataPerAccess_N,
WeiBlockCopySubLengths_CK,
WeiBlockCopyClusterLengths_CK,
WeiBlockCopyDataPerAccess_K,
OutThreadCopyDataPerAccess_N>{};
for(index_t i = 0; i < nrepeat; ++i)
{
float time = launch_kernel(run_gridwise_convolution_kernel<decltype(gridwise_conv), T>,
dim3(GridSize),
dim3(BlockSize),
......
#pragma once
#include <unistd.h>
#include "device.hpp"
#include "tensor.hpp"
#include "gridwise_convolution_implicit_gemm_v1r3_chwn_cyxk_khwn_padded.hpp"
using namespace ck;
template <typename T, class InDesc, class WeiDesc, class OutDesc, class LeftPads, class RightPads>
void device_convolution_implicit_gemm_v1_chwn_cyxk_khwn_padded(InDesc,
const Tensor<T>& in_nchw,
WeiDesc,
const Tensor<T>& wei_kcyx,
OutDesc,
Tensor<T>& out_nkhw,
LeftPads,
RightPads,
index_t nrepeat)
{
constexpr auto I0 = Number<0>{};
constexpr auto I1 = Number<1>{};
constexpr auto I2 = Number<2>{};
constexpr auto I3 = Number<3>{};
constexpr auto in_nchw_desc = InDesc{};
constexpr auto wei_kcyx_desc = WeiDesc{};
constexpr auto out_nkhw_desc = OutDesc{};
constexpr index_t Hi = in_nchw_desc.GetLength(I2);
constexpr index_t Wi = in_nchw_desc.GetLength(I3);
constexpr index_t N = out_nkhw_desc.GetLength(I0);
constexpr index_t Ho = out_nkhw_desc.GetLength(I2);
constexpr index_t Wo = out_nkhw_desc.GetLength(I3);
constexpr index_t K = wei_kcyx_desc.GetLength(I0);
constexpr index_t C = wei_kcyx_desc.GetLength(I1);
constexpr index_t Y = wei_kcyx_desc.GetLength(I2);
constexpr index_t X = wei_kcyx_desc.GetLength(I3);
// reorder weight
auto wei_cyxk_desc = make_ConstantTensorDescriptor_packed(Sequence<C, Y, X, K>{});
ostream_ConstantTensorDescriptor(wei_cyxk_desc, std::cout << "wei_cyxk_desc: ");
Tensor<T> wei_cyxk(make_TensorDescriptor(wei_cyxk_desc));
auto f_reorder_kcyx2cyxk = [&](auto k, auto c, auto y, auto x) {
wei_cyxk(c, y, x, k) = wei_kcyx(k, c, y, x);
};
make_ParallelTensorFunctor(f_reorder_kcyx2cyxk, K, C, Y, X)(
std::thread::hardware_concurrency());
// reorder input
auto in_chwn_desc = make_ConstantTensorDescriptor_packed(Sequence<C, Hi, Wi, N>{});
ostream_ConstantTensorDescriptor(in_chwn_desc, std::cout << "in_chwn_desc: ");
Tensor<T> in_chwn(make_TensorDescriptor(in_chwn_desc));
auto f_reorder_nchw2chwn = [&](auto n, auto c, auto hi, auto wi) {
in_chwn(c, hi, wi, n) = in_nchw(n, c, hi, wi);
};
make_ParallelTensorFunctor(f_reorder_nchw2chwn, N, C, Hi, Wi)(
std::thread::hardware_concurrency());
// output
auto out_khwn_desc = make_ConstantTensorDescriptor_packed(Sequence<K, Ho, Wo, N>{});
ostream_ConstantTensorDescriptor(out_khwn_desc, std::cout << "out_khwn_desc: ");
Tensor<T> out_khwn(make_TensorDescriptor(out_khwn_desc));
std::size_t data_sz = sizeof(T);
DeviceMem in_chwn_device_buf(data_sz * in_chwn.mDesc.GetElementSpace());
DeviceMem wei_cyxk_device_buf(data_sz * wei_cyxk.mDesc.GetElementSpace());
DeviceMem out_khwn_device_buf(data_sz * out_khwn.mDesc.GetElementSpace());
in_chwn_device_buf.ToDevice(in_chwn.mData.data());
wei_cyxk_device_buf.ToDevice(wei_cyxk.mData.data());
out_khwn_device_buf.ToDevice(out_khwn.mData.data());
#if 1
// v1r3, 3x3, 32x32, 1x1 pad
constexpr index_t BlockSize = 256;
constexpr index_t NPerBlock = 32;
constexpr index_t KPerBlock = 128;
constexpr index_t CPerBlock = 8;
constexpr index_t HoPerBlock = 2;
constexpr index_t WoPerBlock = 2;
constexpr index_t NPerThread = 4;
constexpr index_t KPerThread = 8;
constexpr index_t HoPerThread = 1;
constexpr index_t WoPerThread = 2;
constexpr index_t GemmMPerThreadSubC = 4;
constexpr index_t GemmNPerThreadSubC = 4;
constexpr index_t GemmMLevel0Cluster = 4;
constexpr index_t GemmNLevel0Cluster = 4;
constexpr index_t GemmMLevel1Cluster = 4;
constexpr index_t GemmNLevel1Cluster = 2;
constexpr index_t GemmKPerThreadLoop = 1;
constexpr index_t GemmDataPerReadA = 4;
constexpr index_t GemmDataPerReadB = 4;
using InBlockCopySubLengths_CHWN = Sequence<1, 1, 1, 4>;
using InBlockCopyClusterLengths_CHWN = Sequence<8, 2, 2, 8>;
constexpr index_t InBlockCopyDataPerAccess_N = 4;
using WeiBlockCopySubLengths_CK = Sequence<1, 4>;
using WeiBlockCopyClusterLengths_CK = Sequence<8, 32>;
constexpr index_t WeiBlockCopyDataPerAccess_K = 4;
constexpr index_t OutThreadCopyDataPerAccess_N = 4;
#endif
#if 1 // debug
constexpr index_t GridSize =
(N / NPerBlock) * (K / KPerBlock) * (Ho / HoPerBlock) * (Wo / WoPerBlock);
#else
constexpr index_t GridSize = 1;
#endif
printf("%s: BlockSize %u, GridSize %u \n", __func__, BlockSize, GridSize);
constexpr auto gridwise_conv =
GridwiseConvolutionImplicitGemm_v1r3_chwn_cyxk_khwn_padded<GridSize,
BlockSize,
T,
decltype(in_chwn_desc),
decltype(wei_cyxk_desc),
decltype(out_khwn_desc),
LeftPads,
RightPads,
NPerBlock,
KPerBlock,
CPerBlock,
HoPerBlock,
WoPerBlock,
NPerThread,
KPerThread,
HoPerThread,
WoPerThread,
GemmMPerThreadSubC,
GemmNPerThreadSubC,
GemmMLevel0Cluster,
GemmNLevel0Cluster,
GemmMLevel1Cluster,
GemmNLevel1Cluster,
GemmKPerThreadLoop,
GemmDataPerReadA,
GemmDataPerReadB,
InBlockCopySubLengths_CHWN,
InBlockCopyClusterLengths_CHWN,
InBlockCopyDataPerAccess_N,
WeiBlockCopySubLengths_CK,
WeiBlockCopyClusterLengths_CK,
WeiBlockCopyDataPerAccess_K,
OutThreadCopyDataPerAccess_N>{};
for(index_t i = 0; i < nrepeat; ++i)
{
float time = launch_kernel(run_gridwise_convolution_kernel<decltype(gridwise_conv), T>,
dim3(GridSize),
dim3(BlockSize),
0,
static_cast<T*>(in_chwn_device_buf.GetDeviceBuffer()),
static_cast<T*>(wei_cyxk_device_buf.GetDeviceBuffer()),
static_cast<T*>(out_khwn_device_buf.GetDeviceBuffer()));
printf("Elapsed time : %f ms, %f TFlop/s\n",
time,
(float)calculate_convolution_flops(InDesc{}, WeiDesc{}, OutDesc{}) /
(std::size_t(1000) * 1000 * 1000) / time);
usleep(std::min(time * 1000, float(10000)));
}
out_khwn_device_buf.FromDevice(out_khwn.mData.data());
// reorder output
auto f_reorder_khwn2nkhw = [&](auto k, auto ho, auto wo, auto n) {
out_nkhw(n, k, ho, wo) = out_khwn(k, ho, wo, n);
};
make_ParallelTensorFunctor(f_reorder_khwn2nkhw, K, Ho, Wo, N)(
std::thread::hardware_concurrency());
}
......@@ -33,18 +33,11 @@ void device_convolution_implicit_gemm_v4r1_nchw_kcyx_nkhw(InDesc,
constexpr auto wei_kcyx_desc = WeiDesc{};
constexpr auto out_nkhw_desc = OutDesc{};
constexpr index_t Hi = in_nchw_desc.GetLength(I2);
constexpr index_t Wi = in_nchw_desc.GetLength(I3);
constexpr index_t N = out_nkhw_desc.GetLength(I0);
constexpr index_t K = out_nkhw_desc.GetLength(I1);
constexpr index_t Ho = out_nkhw_desc.GetLength(I2);
constexpr index_t Wo = out_nkhw_desc.GetLength(I3);
constexpr index_t K = wei_kcyx_desc.GetLength(I0);
constexpr index_t C = wei_kcyx_desc.GetLength(I1);
constexpr index_t Y = wei_kcyx_desc.GetLength(I2);
constexpr index_t X = wei_kcyx_desc.GetLength(I3);
std::size_t data_sz = sizeof(T);
DeviceMem in_nchw_device_buf(data_sz * in_nchw.mDesc.GetElementSpace());
DeviceMem wei_kcyx_device_buf(data_sz * wei_kcyx.mDesc.GetElementSpace());
......@@ -54,19 +47,16 @@ void device_convolution_implicit_gemm_v4r1_nchw_kcyx_nkhw(InDesc,
wei_kcyx_device_buf.ToDevice(wei_kcyx.mData.data());
out_nkhw_device_buf.ToDevice(out_nkhw.mData.data());
constexpr index_t N1 = 2;
constexpr index_t N2 = 4;
constexpr index_t B = (N * Ho * Wo) / (N1 * N2);
#if 1
// each thread hold 64 data
// BlockSize = 256, blockwise-GEMM 128x128, each thread hold 64 data
constexpr index_t BlockSize = 256;
constexpr index_t BPerBlock = 16;
constexpr index_t KPerBlock = 128;
constexpr index_t EPerBlock = 8;
constexpr index_t GemmNRepeat = 2;
constexpr index_t GemmMPerThreadSubC = 4;
constexpr index_t GemmNPerThreadSubC = 4;
constexpr index_t GemmMLevel0Cluster = 4;
......@@ -80,65 +70,67 @@ void device_convolution_implicit_gemm_v4r1_nchw_kcyx_nkhw(InDesc,
using InBlockCopySubLengths_E_N1_B_N2 = Sequence<1, 1, 1, 4>;
using InBlockCopyClusterLengths_E_N1_B_N2 = Sequence<8, 2, 16, 1>;
using InBlockCopyThreadClusterArrangeOrder = Sequence<0, 1, 3, 2>; // [E, N1, N2, B]
using InBlockCopySrcAccessOrder = Sequence<0, 1, 3, 2>; // [E, N1, N2, B]
using InBlockCopySrcAccessOrder = Sequence<0, 2, 1, 3>; // [E, B, N1, N2]
using InBlockCopyDstAccessOrder = Sequence<0, 1, 2, 3>; // [E, N1, B, N2]
constexpr index_t InBlockCopySrcDataPerRead_B = 1;
constexpr index_t InBlockCopyDstDataPerWrite_N2 = 4;
using WeiBlockCopySubLengths_E_K = Sequence<1, 4>;
using WeiBlockCopyClusterLengths_E_K = Sequence<8, 32>;
using WeiBlockCopySubLengths_E_K = Sequence<4, 1>;
using WeiBlockCopyClusterLengths_E_K = Sequence<2, 128>;
using WeiBlockCopyThreadClusterArrangeOrder = Sequence<1, 0>; // [K, E]
using WeiBlockCopySrcAccessOrder = Sequence<1, 0>; // [K, E]
using WeiBlockCopyDstAccessOrder = Sequence<0, 1>; // [E, K]
constexpr index_t WeiBlockCopySrcDataPerRead_E = 1;
constexpr index_t WeiBlockCopyDstDataPerWrite_K = 4;
constexpr index_t OutThreadCopyDataPerAccess_W = 1;
#elif 1
// each thread hold 64 data
constexpr index_t BlockSize = 256;
constexpr index_t WeiBlockCopySrcDataPerRead_E = 4;
constexpr index_t WeiBlockCopyDstDataPerWrite_K = 1;
#elif 0
// BlockSize = 64, blockwise-GEMM 64x64, each thread hold 64 data
constexpr index_t BlockSize = 64;
constexpr index_t BPerBlock = 16;
constexpr index_t KPerBlock = 128;
constexpr index_t BPerBlock = 8;
constexpr index_t KPerBlock = 64;
constexpr index_t EPerBlock = 8;
constexpr index_t GemmNRepeat = 2;
constexpr index_t GemmMPerThreadSubC = 4;
constexpr index_t GemmNPerThreadSubC = 4;
constexpr index_t GemmMLevel0Cluster = 4;
constexpr index_t GemmNLevel0Cluster = 4;
constexpr index_t GemmMLevel1Cluster = 4;
constexpr index_t GemmNLevel1Cluster = 4;
constexpr index_t GemmMLevel1Cluster = 2;
constexpr index_t GemmNLevel1Cluster = 2;
constexpr index_t GemmKPerThreadLoop = 1;
constexpr index_t GemmDataPerReadA = 4;
constexpr index_t GemmDataPerReadB = 4;
using InBlockCopySubLengths_E_N1_B_N2 = Sequence<1, 1, 2, 2>;
using InBlockCopyClusterLengths_E_N1_B_N2 = Sequence<8, 2, 8, 2>;
using InBlockCopySubLengths_E_N1_B_N2 = Sequence<1, 2, 1, 4>;
using InBlockCopyClusterLengths_E_N1_B_N2 = Sequence<8, 1, 8, 1>;
using InBlockCopyThreadClusterArrangeOrder = Sequence<0, 1, 3, 2>; // [E, N1, N2, B]
using InBlockCopySrcAccessOrder = Sequence<0, 1, 3, 2>; // [E, N1, N2, B]
using InBlockCopySrcAccessOrder = Sequence<0, 2, 1, 3>; // [E, B, N1, N2]
using InBlockCopyDstAccessOrder = Sequence<0, 1, 2, 3>; // [E, N1, B, N2]
constexpr index_t InBlockCopySrcDataPerRead_B = 2;
constexpr index_t InBlockCopyDstDataPerWrite_N2 = 2;
constexpr index_t InBlockCopySrcDataPerRead_B = 1;
constexpr index_t InBlockCopyDstDataPerWrite_N2 = 4;
using WeiBlockCopySubLengths_E_K = Sequence<2, 2>;
using WeiBlockCopyClusterLengths_E_K = Sequence<4, 64>;
using WeiBlockCopySubLengths_E_K = Sequence<4, 2>;
using WeiBlockCopyClusterLengths_E_K = Sequence<2, 32>;
using WeiBlockCopyThreadClusterArrangeOrder = Sequence<1, 0>; // [K, E]
using WeiBlockCopySrcAccessOrder = Sequence<1, 0>; // [K, E]
using WeiBlockCopyDstAccessOrder = Sequence<0, 1>; // [E, K]
constexpr index_t WeiBlockCopySrcDataPerRead_E = 2;
constexpr index_t WeiBlockCopyDstDataPerWrite_K = 2;
#elif 0
// each thread hold 32 data
constexpr index_t WeiBlockCopySrcDataPerRead_E = 4;
constexpr index_t WeiBlockCopyDstDataPerWrite_K = 1;
#elif 1
// BlockSize = 256, blockwise-GEMM 64x128, each thread hold 32 data
constexpr index_t BlockSize = 256;
constexpr index_t BPerBlock = 16;
constexpr index_t KPerBlock = 64;
constexpr index_t EPerBlock = 8;
constexpr index_t GemmNRepeat = 2;
constexpr index_t GemmMPerThreadSubC = 2;
constexpr index_t GemmNPerThreadSubC = 4;
constexpr index_t GemmMLevel0Cluster = 4;
......@@ -152,7 +144,7 @@ void device_convolution_implicit_gemm_v4r1_nchw_kcyx_nkhw(InDesc,
using InBlockCopySubLengths_E_N1_B_N2 = Sequence<1, 1, 1, 4>;
using InBlockCopyClusterLengths_E_N1_B_N2 = Sequence<8, 2, 16, 1>;
using InBlockCopyThreadClusterArrangeOrder = Sequence<0, 1, 3, 2>; // [E, N1, N2, B]
using InBlockCopySrcAccessOrder = Sequence<0, 1, 3, 2>; // [E, N1, N2, B]
using InBlockCopySrcAccessOrder = Sequence<0, 2, 1, 3>; // [E, B, N1, N2]
using InBlockCopyDstAccessOrder = Sequence<0, 1, 2, 3>; // [E, N1, B, N2]
constexpr index_t InBlockCopySrcDataPerRead_B = 1;
......@@ -168,57 +160,60 @@ void device_convolution_implicit_gemm_v4r1_nchw_kcyx_nkhw(InDesc,
constexpr index_t WeiBlockCopyDstDataPerWrite_K = 1;
#endif
constexpr index_t N1 = GemmNRepeat;
constexpr index_t N2 = GemmNPerThreadSubC;
constexpr index_t B = (N * Ho * Wo) / (N1 * N2);
constexpr index_t GridSize =
((B + BPerBlock - 1) / BPerBlock) * ((K + KPerBlock - 1) / KPerBlock);
printf("%s: BlockSize %u, GridSize %u \n", __func__, BlockSize, GridSize);
for(index_t i = 0; i < nrepeat; ++i)
{
constexpr auto gridwise_conv =
constexpr auto gridwise_conv =
#if 0
GridwiseConvolutionImplicitGemm_v4r1_nchw_kcyx_nkhw
GridwiseConvolutionImplicitGemm_v4r1_nchw_kcyx_nkhw
#else
GridwiseConvolutionImplicitGemm_v4r1_nchw_kcyx_nkhw_lds_double_buffer
GridwiseConvolutionImplicitGemm_v4r1_nchw_kcyx_nkhw_lds_double_buffer
#endif
<GridSize,
BlockSize,
T,
decltype(in_nchw_desc),
decltype(wei_kcyx_desc),
decltype(out_nkhw_desc),
ConvStrides,
ConvDilations,
BPerBlock,
KPerBlock,
EPerBlock,
N1,
N2,
GemmMPerThreadSubC,
GemmNPerThreadSubC,
GemmMLevel0Cluster,
GemmNLevel0Cluster,
GemmMLevel1Cluster,
GemmNLevel1Cluster,
GemmKPerThreadLoop,
GemmDataPerReadA,
GemmDataPerReadB,
InBlockCopySubLengths_E_N1_B_N2,
InBlockCopyClusterLengths_E_N1_B_N2,
InBlockCopyThreadClusterArrangeOrder,
InBlockCopySrcAccessOrder,
InBlockCopyDstAccessOrder,
InBlockCopySrcDataPerRead_B,
InBlockCopyDstDataPerWrite_N2,
WeiBlockCopySubLengths_E_K,
WeiBlockCopyClusterLengths_E_K,
WeiBlockCopyThreadClusterArrangeOrder,
WeiBlockCopySrcAccessOrder,
WeiBlockCopyDstAccessOrder,
WeiBlockCopySrcDataPerRead_E,
WeiBlockCopyDstDataPerWrite_K,
OutThreadCopyDataPerAccess_W, ConvolutionDirection::BackwardWeights>{};
<GridSize,
BlockSize,
T,
decltype(in_nchw_desc),
decltype(wei_kcyx_desc),
decltype(out_nkhw_desc),
ConvStrides,
ConvDilations,
BPerBlock,
KPerBlock,
EPerBlock,
GemmNRepeat,
GemmMPerThreadSubC,
GemmNPerThreadSubC,
GemmMLevel0Cluster,
GemmNLevel0Cluster,
GemmMLevel1Cluster,
GemmNLevel1Cluster,
GemmKPerThreadLoop,
GemmDataPerReadA,
GemmDataPerReadB,
InBlockCopySubLengths_E_N1_B_N2,
InBlockCopyClusterLengths_E_N1_B_N2,
InBlockCopyThreadClusterArrangeOrder,
InBlockCopySrcAccessOrder,
InBlockCopyDstAccessOrder,
InBlockCopySrcDataPerRead_B,
InBlockCopyDstDataPerWrite_N2,
WeiBlockCopySubLengths_E_K,
WeiBlockCopyClusterLengths_E_K,
WeiBlockCopyThreadClusterArrangeOrder,
WeiBlockCopySrcAccessOrder,
WeiBlockCopyDstAccessOrder,
WeiBlockCopySrcDataPerRead_E,
WeiBlockCopyDstDataPerWrite_K>{};
for(index_t i = 0; i < nrepeat; ++i)
{
float time = launch_kernel(run_gridwise_convolution_kernel<decltype(gridwise_conv), T>,
dim3(GridSize),
dim3(BlockSize),
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment