Commit 52423948 authored by Jehandad Khan

Merge branch 'master' into jd_redux

parents b97af4ec 98a2cfcc
#ifndef CK_ARRAY_HPP
#define CK_ARRAY_HPP

#include "sequence.hpp"
#include "functional2.hpp"

namespace ck {

template <typename TData, index_t NSize>
struct Array
{
    using type      = Array<TData, NSize>;
    using data_type = TData;

    TData mData[NSize];

    __host__ __device__ explicit constexpr Array() {}

    template <typename X, typename... Xs>
    __host__ __device__ constexpr Array(X x, Xs... xs)
        : mData{static_cast<TData>(x), static_cast<TData>(xs)...}
    {
        static_assert(sizeof...(Xs) + 1 == NSize, "wrong! size");
    }

    __host__ __device__ static constexpr index_t Size() { return NSize; }

    // TODO: remove
    __host__ __device__ static constexpr index_t GetSize() { return Size(); }

    template <index_t I>
    __host__ __device__ constexpr const TData& At(Number<I>) const
    {
        static_assert(I < NSize, "wrong!");

        return mData[I];
    }

    template <index_t I>
    __host__ __device__ constexpr TData& At(Number<I>)
    {
        static_assert(I < NSize, "wrong!");

        return mData[I];
    }

    __host__ __device__ constexpr const TData& At(index_t i) const { return mData[i]; }

    __host__ __device__ constexpr TData& At(index_t i) { return mData[i]; }

    template <typename I>
    __host__ __device__ constexpr const TData& operator[](I i) const
    {
        return At(i);
    }

    template <typename I>
    __host__ __device__ constexpr TData& operator()(I i)
    {
        return At(i);
    }

    template <typename T>
    __host__ __device__ constexpr type& operator=(const T& x)
    {
        static_for<0, Size(), 1>{}([&](auto i) { operator()(i) = x[i]; });
        return *this;
    }

    struct lambda_PushBack // emulate constexpr lambda
    {
@@ -63,7 +82,7 @@ struct Array
        template <index_t I>
        __host__ __device__ constexpr void operator()(Number<I>) const
        {
            new_array(Number<I>{}) = old_array[I];
        }
    };
@@ -73,19 +92,96 @@ struct Array
        static_for<0, NSize, 1>{}(lambda_PushBack(*this, new_array));

        new_array(Number<NSize>{}) = x;

        return new_array;
    }
};
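// Usage sketch (illustrative, not part of this commit): the basic Array
// interface. Assumes PushBack(TData), as suggested by lambda_PushBack above.
__host__ __device__ inline void array_example()
{
    Array<index_t, 3> a{1, 2, 3}; // variadic constructor, size checked at compile time

    index_t x = a[Number<1>{}];   // compile-time indexed read -> 2
    a(Number<0>{}) = 7;           // compile-time indexed write
    a(2) = a[0] + x;              // runtime indexed access

    auto b = a.PushBack(x);       // Array<index_t, 4>
    (void)b;
}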
// Arr: Array
// Picks: Sequence<...>
template <typename Arr, typename Picks>
struct ArrayElementPicker
{
    using type      = ArrayElementPicker;
    using data_type = typename Arr::data_type;

    __host__ __device__ constexpr ArrayElementPicker() = delete;

    __host__ __device__ explicit constexpr ArrayElementPicker(Arr& array) : mArray{array}
    {
        constexpr index_t imax = reduce_on_sequence(Picks{}, math::maxer<index_t>{}, Number<0>{});

        static_assert(imax < Arr::Size(), "wrong! exceeding # array element");
    }

    __host__ __device__ static constexpr auto Size() { return Picks::Size(); }

    template <index_t I>
    __host__ __device__ constexpr const data_type& At(Number<I>) const
    {
        static_assert(I < Size(), "wrong!");

        constexpr auto IP = Picks{}[I];
        return mArray[IP];
    }

    template <index_t I>
    __host__ __device__ constexpr data_type& At(Number<I>)
    {
        static_assert(I < Size(), "wrong!");

        constexpr auto IP = Picks{}[I];
        return mArray(IP);
    }

    template <typename I>
    __host__ __device__ constexpr const data_type& operator[](I i) const
    {
        return At(i);
    }

    template <typename I>
    __host__ __device__ constexpr data_type& operator()(I i)
    {
        return At(i);
    }

    template <typename T>
    __host__ __device__ constexpr type& operator=(const T& a)
    {
        static_for<0, Size(), 1>{}([&](auto i) { operator()(i) = a[i]; });
        return *this;
    }

    Arr& mArray;
};

template <typename Arr, typename Picks>
__host__ __device__ constexpr auto pick_array_element(Arr& a, Picks)
{
    return ArrayElementPicker<Arr, Picks>(a);
}
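// Usage sketch (illustrative, not part of this commit): pick_array_element
// creates a writable view of selected elements; assigning through the view
// writes through to the underlying Array.
__host__ __device__ inline void pick_array_element_example()
{
    Array<index_t, 4> a{0, 0, 0, 0};

    auto v = pick_array_element(a, Sequence<1, 3>{});
    v      = Array<index_t, 2>{8, 9}; // a becomes {0, 8, 0, 9}
}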
template <typename T>
__host__ __device__ constexpr auto to_array(const T& x)
{
    Array<typename T::data_type, T::Size()> y;

    static_for<0, T::Size(), 1>{}([&](auto i) { y.At(i) = x.At(i); });

    return y;
}
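// Usage sketch (illustrative, not part of this commit): materialize a picked
// view back into a standalone Array.
__host__ __device__ inline void to_array_example()
{
    Array<index_t, 4> a{10, 11, 12, 13};

    const auto v = pick_array_element(a, Sequence<0, 2>{});

    auto b = to_array(v); // Array<index_t, 2>{10, 12}
    (void)b;
}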
// TODO: remove this
template <index_t... Is>
__host__ __device__ constexpr auto sequence2array(Sequence<Is...>)
{
    return Array<index_t, sizeof...(Is)>{Is...};
}

template <typename TData, index_t NSize>
__host__ __device__ constexpr auto make_zero_array()
{
    constexpr auto zero_sequence = typename uniform_sequence_gen<NSize, 0>::type{};
@@ -93,7 +189,7 @@ __host__ __device__ constexpr auto make_zero_array()
    return zero_array;
}
template <typename TData, index_t NSize, index_t... IRs>
__host__ __device__ constexpr auto reorder_array_given_new2old(const Array<TData, NSize>& old_array,
                                                               Sequence<IRs...> /*new2old*/)
{
@@ -104,7 +200,7 @@ __host__ __device__ constexpr auto reorder_array_given_new2old(const Array<TData
    return Array<TData, NSize>{old_array[IRs]...};
}

template <typename TData, index_t NSize, typename MapOld2New>
struct lambda_reorder_array_given_old2new
{
    const Array<TData, NSize>& old_array;
@@ -121,13 +217,13 @@ struct lambda_reorder_array_given_old2new
    {
        TData old_data = old_array[IOldDim];

        constexpr index_t INewDim = MapOld2New::At(Number<IOldDim>{});

        new_array(Number<INewDim>{}) = old_data;
    }
};

template <typename TData, index_t NSize, index_t... IRs>
__host__ __device__ constexpr auto reorder_array_given_old2new(const Array<TData, NSize>& old_array,
                                                               Sequence<IRs...> /*old2new*/)
{
@@ -143,7 +239,7 @@ __host__ __device__ constexpr auto reorder_array_given_old2new(const Array<TData
    return new_array;
}

template <typename TData, index_t NSize, typename ExtractSeq>
__host__ __device__ constexpr auto extract_array(const Array<TData, NSize>& old_array, ExtractSeq)
{
    Array<TData, ExtractSeq::GetSize()> new_array;
@@ -152,12 +248,13 @@ __host__ __device__ constexpr auto extract_array(const Array<TData, NSize>& old_
    static_assert(new_size <= NSize, "wrong! too many extract");

    static_for<0, new_size, 1>{}([&](auto I) { new_array(I) = old_array[ExtractSeq::At(I)]; });

    return new_array;
}
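// Usage sketch (illustrative, not part of this commit) for the reorder/extract
// helpers above.
__host__ __device__ inline void reorder_extract_example()
{
    Array<index_t, 3> a{10, 20, 30};

    // new2old map: new dim i takes old dim map[i] -> {30, 10, 20}
    auto r = reorder_array_given_new2old(a, Sequence<2, 0, 1>{});

    // keep elements 0 and 2 -> {10, 30}
    auto e = extract_array(a, Sequence<0, 2>{});

    (void)r;
    (void)e;
}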
// emulate constexpr lambda for array
template <typename F, typename X, typename Y, typename Z>
struct lambda_array_math
{
    const F& f;
@@ -174,13 +271,12 @@ struct lambda_array_math
    __host__ __device__ constexpr void operator()(Number<IDim_>) const
    {
        constexpr auto IDim = Number<IDim_>{};

        z(IDim) = f(x[IDim], y[IDim]);
    }
};
// Array = Array + Array
template <typename TData, index_t NSize>
__host__ __device__ constexpr auto operator+(Array<TData, NSize> a, Array<TData, NSize> b)
{
    Array<TData, NSize> result;
@@ -195,7 +291,7 @@ __host__ __device__ constexpr auto operator+(Array<TData, NSize> a, Array<TData,
}

// Array = Array - Array
template <typename TData, index_t NSize>
__host__ __device__ constexpr auto operator-(Array<TData, NSize> a, Array<TData, NSize> b)
{
    Array<TData, NSize> result;
@@ -210,7 +306,7 @@ __host__ __device__ constexpr auto operator-(Array<TData, NSize> a, Array<TData,
}

// Array += Array
template <typename TData, index_t NSize>
__host__ __device__ constexpr auto operator+=(Array<TData, NSize>& a, Array<TData, NSize> b)
{
    a = a + b;
@@ -218,14 +314,14 @@ __host__ __device__ constexpr auto operator+=(Array<TData, NSize>& a, Array<TDat
}

// Array -= Array
template <typename TData, index_t NSize>
__host__ __device__ constexpr auto operator-=(Array<TData, NSize>& a, Array<TData, NSize> b)
{
    a = a - b;
    return a;
}

// Array = Array + Sequence
template <typename TData, index_t NSize, index_t... Is>
__host__ __device__ constexpr auto operator+(Array<TData, NSize> a, Sequence<Is...> b)
{
    static_assert(sizeof...(Is) == NSize, "wrong! size not the same");
@@ -242,7 +338,7 @@ __host__ __device__ constexpr auto operator+(Array<TData, NSize> a, Sequence<Is.
}

// Array = Array - Sequence
template <typename TData, index_t NSize, index_t... Is>
__host__ __device__ constexpr auto operator-(Array<TData, NSize> a, Sequence<Is...> b)
{
    static_assert(sizeof...(Is) == NSize, "wrong! size not the same");
@@ -259,7 +355,7 @@ __host__ __device__ constexpr auto operator-(Array<TData, NSize> a, Sequence<Is.
}

// Array = Array * Sequence
template <typename TData, index_t NSize, index_t... Is>
__host__ __device__ constexpr auto operator*(Array<TData, NSize> a, Sequence<Is...> b)
{
    static_assert(sizeof...(Is) == NSize, "wrong! size not the same");
@@ -276,7 +372,7 @@ __host__ __device__ constexpr auto operator*(Array<TData, NSize> a, Sequence<Is.
}

// Array = Sequence - Array
template <typename TData, index_t NSize, index_t... Is>
__host__ __device__ constexpr auto operator-(Sequence<Is...> a, Array<TData, NSize> b)
{
    static_assert(sizeof...(Is) == NSize, "wrong! size not the same");
@@ -292,7 +388,21 @@ __host__ __device__ constexpr auto operator-(Sequence<Is...> a, Array<TData, NSi
    return result;
}
// Array = Array * TData
template <typename TData, index_t NSize>
__host__ __device__ constexpr auto operator*(TData v, Array<TData, NSize> a)
{
    Array<TData, NSize> result;

    for(index_t i = 0; i < NSize; ++i)
    {
        result(i) = a[i] * v;
    }

    return result;
}
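// Usage sketch (illustrative, not part of this commit) for the element-wise
// Array/Sequence arithmetic above.
__host__ __device__ inline void array_math_example()
{
    Array<index_t, 3> a{1, 2, 3};
    Array<index_t, 3> b{4, 5, 6};

    auto c = a + b;                       // {5, 7, 9}
    auto d = c - Sequence<1, 1, 1>{};     // {4, 6, 8}
    auto e = static_cast<index_t>(2) * d; // {8, 12, 16}
    (void)e;
}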
template <typename TData, index_t NSize, typename Reduce>
__host__ __device__ constexpr TData
accumulate_on_array(const Array<TData, NSize>& a, Reduce f, TData init)
{
@@ -305,89 +415,5 @@ accumulate_on_array(const Array<TData, NSize>& a, Reduce f, TData init)
    return result;
}
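// Usage sketch (illustrative, not part of this commit): fold an Array with one
// of the math functors, e.g. to compute the product of lengths.
__host__ __device__ inline index_t accumulate_example()
{
    Array<index_t, 3> lengths{2, 3, 4};

    return accumulate_on_array(lengths, math::multiplies<index_t>{}, index_t{1}); // 24
}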
} // namespace ck
#endif
#ifndef CK_ARRAY_HELPER_HPP
#define CK_ARRAY_HELPER_HPP

#include "array.hpp"

namespace ck {

template <index_t NSize>
__host__ __device__ void print_array(const char* s, Array<uint32_t, NSize> a)
{
    constexpr index_t nsize = a.GetSize();

    static_assert(nsize > 0 && nsize <= 10, "wrong!");

    static_if<nsize == 1>{}([&](auto) { printf("%s size %u, {%u}\n", s, nsize, a[0]); });
    static_if<nsize == 2>{}([&](auto) { printf("%s size %u, {%u %u}\n", s, nsize, a[0], a[1]); });
    static_if<nsize == 3>{}(
        [&](auto) { printf("%s size %u, {%u %u %u}\n", s, nsize, a[0], a[1], a[2]); });
    static_if<nsize == 4>{}(
        [&](auto) { printf("%s size %u, {%u %u %u %u}\n", s, nsize, a[0], a[1], a[2], a[3]); });
    static_if<nsize == 5>{}([&](auto) {
        printf("%s size %u, {%u %u %u %u %u}\n", s, nsize, a[0], a[1], a[2], a[3], a[4]);
    });
    static_if<nsize == 6>{}([&](auto) {
        printf("%s size %u, {%u %u %u %u %u %u}\n", s, nsize, a[0], a[1], a[2], a[3], a[4], a[5]);
    });
    static_if<nsize == 7>{}([&](auto) {
        printf("%s size %u, {%u %u %u %u %u %u %u}\n",
               s, nsize, a[0], a[1], a[2], a[3], a[4], a[5], a[6]);
    });
    static_if<nsize == 8>{}([&](auto) {
        printf("%s size %u, {%u %u %u %u %u %u %u %u}\n",
               s, nsize, a[0], a[1], a[2], a[3], a[4], a[5], a[6], a[7]);
    });
    static_if<nsize == 9>{}([&](auto) {
        printf("%s size %u, {%u %u %u %u %u %u %u %u %u}\n",
               s, nsize, a[0], a[1], a[2], a[3], a[4], a[5], a[6], a[7], a[8]);
    });
    static_if<nsize == 10>{}([&](auto) {
        printf("%s size %u, {%u %u %u %u %u %u %u %u %u %u}\n",
               s, nsize, a[0], a[1], a[2], a[3], a[4], a[5], a[6], a[7], a[8], a[9]);
    });
}

template <index_t NSize>
__host__ __device__ void print_array(const char* s, Array<int32_t, NSize> a)
{
    constexpr index_t nsize = a.GetSize();

    static_assert(nsize > 0 && nsize <= 10, "wrong!");

    static_if<nsize == 1>{}([&](auto) { printf("%s size %d, {%d}\n", s, nsize, a[0]); });
    static_if<nsize == 2>{}([&](auto) { printf("%s size %d, {%d %d}\n", s, nsize, a[0], a[1]); });
    static_if<nsize == 3>{}(
        [&](auto) { printf("%s size %d, {%d %d %d}\n", s, nsize, a[0], a[1], a[2]); });
    static_if<nsize == 4>{}(
        [&](auto) { printf("%s size %d, {%d %d %d %d}\n", s, nsize, a[0], a[1], a[2], a[3]); });
    static_if<nsize == 5>{}([&](auto) {
        printf("%s size %d, {%d %d %d %d %d}\n", s, nsize, a[0], a[1], a[2], a[3], a[4]);
    });
    static_if<nsize == 6>{}([&](auto) {
        printf("%s size %d, {%d %d %d %d %d %d}\n", s, nsize, a[0], a[1], a[2], a[3], a[4], a[5]);
    });
    static_if<nsize == 7>{}([&](auto) {
        printf("%s size %d, {%d %d %d %d %d %d %d}\n",
               s, nsize, a[0], a[1], a[2], a[3], a[4], a[5], a[6]);
    });
    static_if<nsize == 8>{}([&](auto) {
        printf("%s size %d, {%d %d %d %d %d %d %d %d}\n",
               s, nsize, a[0], a[1], a[2], a[3], a[4], a[5], a[6], a[7]);
    });
    static_if<nsize == 9>{}([&](auto) {
        printf("%s size %d, {%d %d %d %d %d %d %d %d %d}\n",
               s, nsize, a[0], a[1], a[2], a[3], a[4], a[5], a[6], a[7], a[8]);
    });
    static_if<nsize == 10>{}([&](auto) {
        printf("%s size %d, {%d %d %d %d %d %d %d %d %d %d}\n",
               s, nsize, a[0], a[1], a[2], a[3], a[4], a[5], a[6], a[7], a[8], a[9]);
    });
}
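// Usage sketch (illustrative, not part of this commit) for the print_array
// helpers above.
__host__ __device__ inline void print_array_example()
{
    Array<int32_t, 3> a{1, 2, 3};
    print_array("a", a); // prints: a size 3, {1 2 3}
}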
} // namespace ck
#endif
@@ -4,16 +4,26 @@
#include "config.hpp"
#include "utility.hpp"
#include "integral_constant.hpp"
#include "number.hpp"
#include "type.hpp"
#include "tuple.hpp"
#include "math.hpp"
#include "vector_type.hpp"
#include "sequence.hpp"
#include "sequence_helper.hpp"
#include "array.hpp"
#include "array_helper.hpp"
#include "functional.hpp"
#include "functional2.hpp"
#include "functional3.hpp"
#include "functional4.hpp"

#if CK_USE_AMD_INLINE_ASM
#include "amd_inline_asm.hpp"
#endif

#if CK_USE_AMD_INTRINSIC
#include "amd_intrinsic.hpp"
#endif

#endif
@@ -4,29 +4,47 @@
#include "hip/hip_runtime.h"
#include "hip/hip_fp16.h"

#define CK_UNSIGNED_INDEX_TYPE 0

#define CK_DEVICE_BACKEND_AMD 1

#define CK_USE_AMD_INTRINSIC 1
#define CK_USE_AMD_INLINE_ASM 1
#define CK_USE_AMD_INTRINSIC_BUFFER_LOAD_STORE 1

#define CK_EXPERIMENTAL_USE_MORE_COMPILE_STATIC_BLOCKWISE_GENERIC_SLICE_COPY_V1 1
#define CK_EXPERIMENTAL_USE_MORE_COMPILE_STATIC_THREADWISE_GENERIC_TENSOR_SLICE_COPY_V1 0
#define CK_EXPERIMENTAL_USE_MORE_COMPILE_STATIC_THREADWISE_GENERIC_TENSOR_SLICE_COPY_V1R1 0
#define CK_EXPERIMENTAL_USE_MORE_COMPILE_STATIC_THREADWISE_GENERIC_TENSOR_SLICE_COPY_V1R2 0
#define CK_EXPERIMENTAL_USE_MORE_COMPILE_STATIC_THREADWISE_GENERIC_TENSOR_SLICE_COPY_V2 0
#define CK_EXPERIMENTAL_USE_MORE_COMPILE_STATIC_THREADWISE_GENERIC_TENSOR_SLICE_COPY_V2R1 0

namespace ck {

enum address_space_t
{
    generic = 0,
    global  = 3
};

#if CK_UNSIGNED_INDEX_TYPE
using index_t = uint32_t;
#else
using index_t = int32_t;
#endif

// For some reason, the HIP compiler needs these definitions to generate optimal
// load and store instructions
typedef float float2_t __attribute__((ext_vector_type(2)));
typedef float float4_t __attribute__((ext_vector_type(4)));
typedef int32_t int32x4_t __attribute__((ext_vector_type(4)));

// data type conversion
template <typename T>
struct type_convert
{
    template <typename X>
    __device__ T operator()(const X& x) const
    {
        return static_cast<T>(x);
    }
};
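// Usage sketch (illustrative, not part of this commit): type_convert is a
// functor, so a conversion can be passed around like the math functors.
template <typename Dst, typename Src>
__device__ Dst convert_example(Src x)
{
    return type_convert<Dst>{}(x); // e.g. float -> int32_t
}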
} // namespace ck
...
@@ -6,17 +6,30 @@
#include "nvToolsExt.h"
#include "helper_cuda.h"

#define CK_UNSIGNED_INDEX_TYPE 0

#define CK_DEVICE_BACKEND_NVIDIA 1

#define CK_USE_AMD_INTRINSIC 0
#define CK_USE_AMD_INLINE_ASM 0
#define CK_USE_AMD_INTRINSIC_BUFFER_LOAD_STORE 0

#define CK_EXPERIMENTAL_USE_MORE_COMPILE_STATIC_BLOCKWISE_GENERIC_SLICE_COPY_V1 0
#define CK_EXPERIMENTAL_USE_MORE_COMPILE_STATIC_THREADWISE_GENERIC_TENSOR_SLICE_COPY_V1 0
#define CK_EXPERIMENTAL_USE_MORE_COMPILE_STATIC_THREADWISE_GENERIC_TENSOR_SLICE_COPY_V1R1 0
#define CK_EXPERIMENTAL_USE_MORE_COMPILE_STATIC_THREADWISE_GENERIC_TENSOR_SLICE_COPY_V1R2 0
#define CK_EXPERIMENTAL_USE_MORE_COMPILE_STATIC_THREADWISE_GENERIC_TENSOR_SLICE_COPY_V2 0
#define CK_EXPERIMENTAL_USE_MORE_COMPILE_STATIC_THREADWISE_GENERIC_TENSOR_SLICE_COPY_V2R1 0

namespace ck {

enum address_space_t
{
    generic = 0,
    global  = generic
};

#if CK_UNSIGNED_INDEX_TYPE
using index_t = uint32_t;
#else
using index_t = int32_t;
#endif

// For some reason, CUDA needs these definitions; otherwise the compiler won't
// generate optimal load and store instructions, and the kernel would produce
// wrong results, indicating that the compiler fails to generate correct
@@ -24,7 +37,16 @@ namespace ck {
using float2_t = float2;
using float4_t = float4;

// data type conversion
template <typename T>
struct type_convert
{
    template <typename X>
    __device__ T operator()(const X& x) const
    {
        return static_cast<T>(x);
    }
};

template <class T>
__device__ void fused_multiply_accumulate(T& d, const T& s0, const T& s1)
...
@@ -2,10 +2,12 @@
#define CK_FUNCTIONAL_HPP

#include "integral_constant.hpp"
#include "sequence.hpp"
#include "type.hpp"

namespace ck {

// TODO: right? wrong?
struct forwarder
{
    template <typename T>
@@ -17,12 +19,30 @@ struct forwarder

struct swallow
{
    template <typename... Ts>
    __host__ __device__ constexpr swallow(Ts&&...)
    {
    }
};

template <typename T>
struct logical_and
{
    constexpr bool operator()(const T& x, const T& y) const { return x && y; }
};

template <typename T>
struct logical_or
{
    constexpr bool operator()(const T& x, const T& y) const { return x || y; }
};

template <typename T>
struct logical_not
{
    constexpr bool operator()(const T& x) const { return !x; }
};

// Emulate if constexpr
template <bool>
struct static_if;
@@ -32,7 +52,7 @@ struct static_if<true>
{
    using Type = static_if<true>;

    template <typename F>
    __host__ __device__ constexpr auto operator()(F f) const
    {
        // This is a trick for compiler:
@@ -43,7 +63,7 @@ struct static_if<true>
        return Type{};
    }

    template <typename F>
    __host__ __device__ static constexpr auto Else(F)
    {
        return Type{};
@@ -55,13 +75,13 @@ struct static_if<false>
{
    using Type = static_if<false>;

    template <typename F>
    __host__ __device__ constexpr auto operator()(F) const
    {
        return Type{};
    }

    template <typename F>
    __host__ __device__ static constexpr auto Else(F f)
    {
        // This is a trick for compiler:
@@ -73,5 +93,23 @@ struct static_if<false>
    }
};
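// Usage sketch (illustrative, not part of this commit): static_if emulates
// "if constexpr"; only the taken branch's functor is invoked, and the functor
// takes a dummy argument so the untaken branch stays dependent.
template <index_t N>
__host__ __device__ index_t static_if_example()
{
    index_t r = 0;

    static_if<(N > 4)>{}([&](auto) { r = N / 4; }).Else([&](auto) { r = N; });

    return r;
}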
template <bool predicate, class X, class Y>
struct conditional;

template <class X, class Y>
struct conditional<true, X, Y>
{
    using type = X;
};

template <class X, class Y>
struct conditional<false, X, Y>
{
    using type = Y;
};

template <bool predicate, class X, class Y>
using conditional_t = typename conditional<predicate, X, Y>::type;
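// Usage sketch (illustrative, not part of this commit): conditional_t mirrors
// std::conditional_t; is_same is assumed from type.hpp.
static_assert(is_same<conditional_t<true, int32_t, float>, int32_t>::value, "example");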
} // namespace ck
#endif
@@ -2,10 +2,12 @@
#define CK_FUNCTIONAL2_HPP

#include "functional.hpp"
#include "sequence.hpp"

namespace ck {

namespace detail {

template <class>
struct static_for_impl;
@@ -19,6 +21,8 @@ struct static_for_impl<Sequence<Is...>>
    }
};

} // namespace detail

// F signature: F(Number<Iter>)
template <index_t NBegin, index_t NEnd, index_t Increment>
struct static_for
@@ -33,38 +37,10 @@ struct static_for
    template <class F>
    __host__ __device__ constexpr void operator()(F f) const
    {
        detail::static_for_impl<typename arithmetic_sequence_gen<NBegin, NEnd, Increment>::type>{}(
            f);
    }
};
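// Usage sketch (illustrative, not part of this commit): static_for drives a
// compile-time counted loop; the functor receives Number<i>.
__host__ __device__ inline index_t static_for_example()
{
    index_t sum = 0;

    static_for<0, 4, 1>{}([&](auto i) { sum += i.value; }); // i is Number<0>..Number<3>

    return sum; // 0 + 1 + 2 + 3 = 6
}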
} // namespace ck
#endif
@@ -3,25 +3,12 @@
#include "functional.hpp"
#include "functional2.hpp"
#include "sequence.hpp"
#include "array.hpp"

namespace ck {

namespace detail {
// RemainLengths: Sequence<...>
// Orders: Sequence<...>
@@ -58,29 +45,6 @@ struct static_ford_impl<Sequence<>, Orders>
    }
};
// RemainLengths: Sequence<...>
// Orders: Sequence<...>
template <class RemainLengths, class Orders>
@@ -117,6 +81,31 @@ struct ford_impl<Sequence<>, Orders>
    }
};

} // namespace detail

// Lengths is Sequence<...>, the length of each dimension of the N-dimensional loop
// Orders is Sequence<...>, the order of dimensions in which static_ford will loop
// over each dimension
template <class Lengths,
          class Orders = typename arithmetic_sequence_gen<0, Lengths::GetSize(), 1>::type>
struct static_ford
{
    __host__ __device__ constexpr static_ford()
    {
        static_assert(Lengths::GetSize() > 0, "wrong! Lengths is empty");
        static_assert(Lengths::GetSize() == Orders::GetSize(), "wrong! inconsistent size");
    }

    // F signature: F(Sequence<...> multi_id)
    // multi_id is the unordered multi-index
    template <class F>
    __host__ __device__ constexpr void operator()(F f) const
    {
        constexpr auto ordered_lengths = Lengths::ReorderGivenNew2Old(Orders{});
        detail::static_ford_impl<decltype(ordered_lengths), Orders>{}(f, Sequence<>{});
    }
};
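// Usage sketch (illustrative, not part of this commit): static_ford visits
// every multi-index of a compile-time N-d loop nest; here 2 x 3 = 6 visits.
__host__ __device__ inline index_t static_ford_example()
{
    index_t count = 0;

    static_ford<Sequence<2, 3>>{}([&](auto multi_id) {
        (void)multi_id; // multi_id is a Sequence<i, j>
        ++count;
    });

    return count; // 6
}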
// Lengths is Sequence<...>, the length of each dimension of the N-dimensional loop
// Orders is Sequence<...>, the order of dimensions in which ford will loop over
// each dimension
@@ -139,7 +128,8 @@ struct ford
        for(index_t i = 0; i < ordered_lengths.Front(); ++i)
        {
            detail::ford_impl<decltype(ordered_lengths.PopFront()), Orders>{}(f,
                                                                              Array<index_t, 1>{i});
        }
    }
};
...
#ifndef CK_FUNCTIONAL4_HPP
#define CK_FUNCTIONAL4_HPP

#include "sequence.hpp"
#include "tuple.hpp"
#include "array.hpp"

namespace ck {

namespace detail {

template <typename Indices>
struct unpack_impl;

template <index_t... Is>
struct unpack_impl<Sequence<Is...>>
{
    template <typename F, typename X>
    __host__ __device__ constexpr auto operator()(F f, const X& x) const
    {
        return f(x.At(Number<Is>{})...);
    }
};

} // namespace detail

template <typename F, typename X>
__host__ __device__ constexpr auto unpack(F f, const X& x)
{
    return detail::unpack_impl<typename arithmetic_sequence_gen<0, X::Size(), 1>::type>{}(f, x);
}
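// Usage sketch (illustrative, not part of this commit): unpack forwards the
// elements of an Array (or anything exposing At(Number<I>)) as separate arguments.
__host__ __device__ inline index_t unpack_example()
{
    Array<index_t, 3> a{1, 2, 3};

    return unpack([](index_t x, index_t y, index_t z) { return x + y + z; }, a); // 6
}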
} // namespace ck
#endif
@@ -13,51 +13,5 @@ struct integral_constant
    __host__ __device__ constexpr value_type operator()() const noexcept { return value; }
};
} // namespace ck
#endif
@@ -3,6 +3,7 @@
#include "config.hpp"
#include "integral_constant.hpp"
#include "type.hpp"

namespace ck {
namespace math {
@@ -31,6 +32,12 @@ struct multiplies
    __host__ __device__ constexpr T operator()(T a, T b) const { return a * b; }
};

template <class T>
struct maxer
{
    __host__ __device__ constexpr T operator()(T a, T b) const { return a >= b ? a : b; }
};

template <class T>
struct integer_divide_ceiler
{
@@ -98,6 +105,18 @@ __host__ __device__ constexpr T lcm(T x, Ts... xs)
    return max(x, xs...);
}

template <class T>
struct equal
{
    __host__ __device__ constexpr bool operator()(T x, T y) const { return x == y; }
};

template <class T>
struct less
{
    __host__ __device__ constexpr bool operator()(T x, T y) const { return x < y; }
};

} // namespace math
} // namespace ck
...
#ifndef CK_NUMBER_HPP
#define CK_NUMBER_HPP

#include "integral_constant.hpp"

namespace ck {

template <index_t N>
using Number = integral_constant<index_t, N>;

template <index_t X, index_t Y>
__host__ __device__ constexpr auto operator+(Number<X>, Number<Y>)
{
    return Number<X + Y>{};
}

template <index_t X, index_t Y>
__host__ __device__ constexpr auto operator-(Number<X>, Number<Y>)
{
    static_assert(Y <= X, "wrong!");
    return Number<X - Y>{};
}

template <index_t X, index_t Y>
__host__ __device__ constexpr auto operator*(Number<X>, Number<Y>)
{
    return Number<X * Y>{};
}

template <index_t X, index_t Y>
__host__ __device__ constexpr auto operator/(Number<X>, Number<Y>)
{
    static_assert(Y > 0, "wrong!");
    return Number<X / Y>{};
}

template <index_t X, index_t Y>
__host__ __device__ constexpr auto operator%(Number<X>, Number<Y>)
{
    static_assert(Y > 0, "wrong!");
    return Number<X % Y>{};
}
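// Usage sketch (illustrative, not part of this commit): arithmetic on Number<>
// stays in the type system, so results remain usable as template arguments.
static_assert((Number<6>{} / Number<2>{} + Number<1>{}).value == 4, "example");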
} // namespace ck
#endif
@@ -2,29 +2,34 @@
#define CK_SEQUENCE_HPP

#include "integral_constant.hpp"
#include "type.hpp"
#include "functional.hpp"
#include "math.hpp"

namespace ck {

template <index_t, index_t, index_t>
struct static_for;

template <index_t...>
struct Sequence;

template <typename Seq, index_t I>
struct sequence_split;

template <typename>
struct sequence_reverse;

template <typename>
struct sequence_map_inverse;

template <typename>
struct is_valid_sequence_map;

template <index_t I, index_t... Is>
__host__ __device__ constexpr auto sequence_pop_front(Sequence<I, Is...>);

template <typename Seq>
__host__ __device__ constexpr auto sequence_pop_back(Seq);

template <index_t... Is>
@@ -35,9 +40,11 @@ struct Sequence
    static constexpr index_t mSize = sizeof...(Is);

    __host__ __device__ static constexpr auto Size() { return Number<mSize>{}; }

    __host__ __device__ static constexpr auto GetSize() { return Size(); }

    __host__ __device__ static constexpr index_t At(index_t I)
    {
        // the last dummy element prevents the compiler complaining about an empty array when mSize = 0
        const index_t mData[mSize + 1] = {Is..., 0};
@@ -45,23 +52,24 @@ struct Sequence
    }

    template <index_t I>
    __host__ __device__ static constexpr auto At(Number<I>)
    {
        static_assert(I < mSize, "wrong! I too large");

        return Number<At(I)>{};
    }

    template <index_t I>
    __host__ __device__ static constexpr auto Get(Number<I>)
    {
        return At(Number<I>{});
    }

    template <typename I>
    __host__ __device__ constexpr auto operator[](I i) const
    {
        return At(i);
    }
    template <index_t... IRs>
    __host__ __device__ static constexpr auto ReorderGivenNew2Old(Sequence<IRs...> /*new2old*/)
@@ -71,14 +79,14 @@ struct Sequence
        static_assert(is_valid_sequence_map<Sequence<IRs...>>::value, "wrong! invalid reorder map");

        return Sequence<Type::At(Number<IRs>{})...>{};
    }

    // MapOld2New is Sequence<...>
    template <typename MapOld2New>
    __host__ __device__ static constexpr auto ReorderGivenOld2New(MapOld2New)
    {
        static_assert(MapOld2New::Size() == Size(),
                      "wrong! reorder map should have the same size as Sequence to be reordered");
        static_assert(is_valid_sequence_map<MapOld2New>::value, "wrong! invalid reorder map");
@@ -94,13 +102,13 @@ struct Sequence

    __host__ __device__ static constexpr auto Front()
    {
        static_assert(mSize > 0, "wrong!");
        return At(Number<0>{});
    }

    __host__ __device__ static constexpr auto Back()
    {
        static_assert(mSize > 0, "wrong!");
        return At(Number<mSize - 1>{});
    }

    __host__ __device__ static constexpr auto PopFront() { return sequence_pop_front(Type{}); }
@@ -134,28 +142,28 @@ struct Sequence
    template <index_t... Ns>
    __host__ __device__ static constexpr auto Extract(Number<Ns>...)
    {
        return Sequence<Type::At(Number<Ns>{})...>{};
    }

    template <index_t... Ns>
    __host__ __device__ static constexpr auto Extract(Sequence<Ns...>)
    {
        return Sequence<Type::At(Number<Ns>{})...>{};
    }

    template <index_t I, index_t X>
    __host__ __device__ static constexpr auto Modify(Number<I>, Number<X>)
    {
        static_assert(I < Size(), "wrong!");

        using seq_split = sequence_split<Type, I>;

        constexpr auto seq_left  = typename seq_split::left_type{};
        constexpr auto seq_right = typename seq_split::right_type{}.PopFront();

        return seq_left.PushBack(Number<X>{}).PushBack(seq_right);
    }

    template <typename F>
    __host__ __device__ static constexpr auto Transform(F f)
    {
        return Sequence<f(Is)...>{};
@@ -163,8 +171,11 @@ struct Sequence
};
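// Usage sketch (illustrative, not part of this commit) of the Sequence
// interface above.
using seq_example = Sequence<4, 5, 6>;
static_assert(seq_example::At(Number<1>{}).value == 5, "example");
static_assert(seq_example::Front().value == 4 && seq_example::Back().value == 6, "example");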
// merge sequence
template <typename Seq, typename... Seqs>
struct sequence_merge
{
    using type = typename sequence_merge<Seq, typename sequence_merge<Seqs...>::type>::type;
};

template <index_t... Xs, index_t... Ys>
struct sequence_merge<Sequence<Xs...>, Sequence<Ys...>>
@@ -172,35 +183,41 @@ struct sequence_merge<Sequence<Xs...>, Sequence<Ys...>>
    using type = Sequence<Xs..., Ys...>;
};

template <typename Seq>
struct sequence_merge<Seq>
{
    using type = Seq;
};

// generate sequence
template <index_t NSize, typename F>
struct sequence_gen
{
    template <index_t IBegin, index_t NRemain, typename G>
    struct sequence_gen_impl
    {
        static constexpr index_t NRemainLeft  = NRemain / 2;
        static constexpr index_t NRemainRight = NRemain - NRemainLeft;
        static constexpr index_t IMiddle      = IBegin + NRemainLeft;

        using type = typename sequence_merge<
            typename sequence_gen_impl<IBegin, NRemainLeft, G>::type,
            typename sequence_gen_impl<IMiddle, NRemainRight, G>::type>::type;
    };

    template <index_t I, typename G>
    struct sequence_gen_impl<I, 1, G>
    {
        static constexpr index_t Is = G{}(Number<I>{});
        using type = Sequence<Is>;
    };

    template <index_t I, typename G>
    struct sequence_gen_impl<I, 0, G>
    {
        using type = Sequence<>;
    };

    using type = typename sequence_gen_impl<0, NSize, F>::type;
};
@@ -232,10 +249,10 @@ struct uniform_sequence_gen
};

// reverse inclusive scan (with init) sequence
template <typename, typename, index_t>
struct sequence_reverse_inclusive_scan;

template <index_t I, index_t... Is, typename Reduce, index_t Init>
struct sequence_reverse_inclusive_scan<Sequence<I, Is...>, Reduce, Init>
{
    using old_scan = typename sequence_reverse_inclusive_scan<Sequence<Is...>, Reduce, Init>::type;
@@ -245,41 +262,41 @@ struct sequence_reverse_inclusive_scan<Sequence<I, Is...>, Reduce, Init>
    using type = typename sequence_merge<Sequence<new_reduce>, old_scan>::type;
};

template <index_t I, typename Reduce, index_t Init>
struct sequence_reverse_inclusive_scan<Sequence<I>, Reduce, Init>
{
    using type = Sequence<Reduce{}(I, Init)>;
};

template <typename Reduce, index_t Init>
struct sequence_reverse_inclusive_scan<Sequence<>, Reduce, Init>
{
    using type = Sequence<>;
};
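// Usage sketch (illustrative, not part of this commit): a reverse inclusive
// scan of {1, 2, 3} under addition with init 0 yields {6, 5, 3}.
struct scan_plus_example
{
    __host__ __device__ constexpr index_t operator()(index_t x, index_t y) const { return x + y; }
};
static_assert(is_same<typename sequence_reverse_inclusive_scan<Sequence<1, 2, 3>,
                                                               scan_plus_example,
                                                               0>::type,
                      Sequence<6, 5, 3>>::value,
              "example");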
// split sequence
template <typename Seq, index_t I>
struct sequence_split
{
    static constexpr index_t NSize = Seq{}.Size();

    using range0 = typename arithmetic_sequence_gen<0, I, 1>::type;
    using range1 = typename arithmetic_sequence_gen<I, NSize, 1>::type;

    using left_type  = decltype(Seq::Extract(range0{}));
    using right_type = decltype(Seq::Extract(range1{}));
};

// reverse sequence
template <typename Seq>
struct sequence_reverse
{
    static constexpr index_t NSize = Seq{}.Size();

    using seq_split = sequence_split<Seq, NSize / 2>;

    using type = typename sequence_merge<
        typename sequence_reverse<typename seq_split::right_type>::type,
        typename sequence_reverse<typename seq_split::left_type>::type>::type;
};

template <index_t I>
@@ -294,44 +311,291 @@ struct sequence_reverse<Sequence<I0, I1>>
    using type = Sequence<I1, I0>;
};
#if 1
template <typename Reduce, typename Seq, typename... Seqs>
struct sequence_reduce
{
    using type = typename sequence_reduce<Reduce,
                                          Seq,
                                          typename sequence_reduce<Reduce, Seqs...>::type>::type;
};

template <typename Reduce, index_t... Xs, index_t... Ys>
struct sequence_reduce<Reduce, Sequence<Xs...>, Sequence<Ys...>>
{
    using type = Sequence<Reduce{}(Xs, Ys)...>;
};

template <typename Reduce, typename Seq>
struct sequence_reduce<Reduce, Seq>
{
    using type = Seq;
};
#endif
template <typename Values, typename Ids, typename Compare>
struct sequence_sort_impl
{
    template <typename LeftValues,
              typename LeftIds,
              typename RightValues,
              typename RightIds,
              typename MergedValues,
              typename MergedIds,
              typename Comp>
    struct sorted_sequence_merge_impl
    {
        static constexpr bool choose_left = Comp{}(LeftValues::Front(), RightValues::Front());

        static constexpr index_t chosen_value =
            choose_left ? LeftValues::Front() : RightValues::Front();
        static constexpr index_t chosen_id = choose_left ? LeftIds::Front() : RightIds::Front();

        using new_merged_values = decltype(MergedValues::PushBack(Number<chosen_value>{}));
        using new_merged_ids    = decltype(MergedIds::PushBack(Number<chosen_id>{}));

        using new_left_values =
            typename conditional<choose_left, decltype(LeftValues::PopFront()), LeftValues>::type;
        using new_left_ids =
            typename conditional<choose_left, decltype(LeftIds::PopFront()), LeftIds>::type;

        using new_right_values =
            typename conditional<choose_left, RightValues, decltype(RightValues::PopFront())>::type;
        using new_right_ids =
            typename conditional<choose_left, RightIds, decltype(RightIds::PopFront())>::type;

        using merge = sorted_sequence_merge_impl<new_left_values,
                                                 new_left_ids,
                                                 new_right_values,
                                                 new_right_ids,
                                                 new_merged_values,
                                                 new_merged_ids,
                                                 Comp>;

        // this is output
        using merged_values = typename merge::merged_values;
        using merged_ids    = typename merge::merged_ids;
    };

    template <typename LeftValues,
              typename LeftIds,
              typename MergedValues,
              typename MergedIds,
              typename Comp>
    struct sorted_sequence_merge_impl<LeftValues,
                                      LeftIds,
                                      Sequence<>,
                                      Sequence<>,
                                      MergedValues,
                                      MergedIds,
                                      Comp>
    {
        using merged_values = typename sequence_merge<MergedValues, LeftValues>::type;
        using merged_ids    = typename sequence_merge<MergedIds, LeftIds>::type;
    };

    template <typename RightValues,
              typename RightIds,
              typename MergedValues,
              typename MergedIds,
              typename Comp>
    struct sorted_sequence_merge_impl<Sequence<>,
                                      Sequence<>,
                                      RightValues,
                                      RightIds,
                                      MergedValues,
                                      MergedIds,
                                      Comp>
    {
        using merged_values = typename sequence_merge<MergedValues, RightValues>::type;
        using merged_ids    = typename sequence_merge<MergedIds, RightIds>::type;
    };

    template <typename LeftValues,
              typename LeftIds,
              typename RightValues,
              typename RightIds,
              typename Comp>
    struct sorted_sequence_merge
    {
        using merge = sorted_sequence_merge_impl<LeftValues,
                                                 LeftIds,
                                                 RightValues,
                                                 RightIds,
                                                 Sequence<>,
                                                 Sequence<>,
                                                 Comp>;

        using merged_values = typename merge::merged_values;
        using merged_ids    = typename merge::merged_ids;
    };

    static constexpr index_t nsize = Values::Size();

    using split_unsorted_values = sequence_split<Values, nsize / 2>;
    using split_unsorted_ids    = sequence_split<Ids, nsize / 2>;

    using left_unsorted_values = typename split_unsorted_values::left_type;
    using left_unsorted_ids    = typename split_unsorted_ids::left_type;
    using left_sort            = sequence_sort_impl<left_unsorted_values, left_unsorted_ids, Compare>;
    using left_sorted_values   = typename left_sort::sorted_values;
    using left_sorted_ids      = typename left_sort::sorted_ids;

    using right_unsorted_values = typename split_unsorted_values::right_type;
    using right_unsorted_ids    = typename split_unsorted_ids::right_type;
    using right_sort            = sequence_sort_impl<right_unsorted_values, right_unsorted_ids, Compare>;
    using right_sorted_values   = typename right_sort::sorted_values;
    using right_sorted_ids      = typename right_sort::sorted_ids;

    using merged_sorted = sorted_sequence_merge<left_sorted_values,
                                                left_sorted_ids,
                                                right_sorted_values,
                                                right_sorted_ids,
                                                Compare>;

    using sorted_values = typename merged_sorted::merged_values;
    using sorted_ids    = typename merged_sorted::merged_ids;
};

template <index_t ValueX, index_t ValueY, index_t IdX, index_t IdY, typename Compare>
struct sequence_sort_impl<Sequence<ValueX, ValueY>, Sequence<IdX, IdY>, Compare>
{
    static constexpr bool choose_x = Compare{}(ValueX, ValueY);

    using sorted_values =
        typename conditional<choose_x, Sequence<ValueX, ValueY>, Sequence<ValueY, ValueX>>::type;
    using sorted_ids = typename conditional<choose_x, Sequence<IdX, IdY>, Sequence<IdY, IdX>>::type;
};

template <index_t Value, index_t Id, typename Compare>
struct sequence_sort_impl<Sequence<Value>, Sequence<Id>, Compare>
{
    using sorted_values = Sequence<Value>;
    using sorted_ids    = Sequence<Id>;
};

template <typename Compare>
struct sequence_sort_impl<Sequence<>, Sequence<>, Compare>
{
    using sorted_values = Sequence<>;
    using sorted_ids    = Sequence<>;
};

template <typename Values, typename Compare>
struct sequence_sort
{
    using unsorted_ids = typename arithmetic_sequence_gen<0, Values::Size(), 1>::type;
    using sort         = sequence_sort_impl<Values, unsorted_ids, Compare>;

    // this is output
    using type                = typename sort::sorted_values;
    using sorted2unsorted_map = typename sort::sorted_ids;
};
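// Usage sketch (illustrative, not part of this commit): sequence_sort sorts
// values at compile time and also yields the sorted-to-unsorted index map.
using sort_example = sequence_sort<Sequence<3, 1, 2>, math::less<index_t>>;
static_assert(is_same<typename sort_example::type, Sequence<1, 2, 3>>::value, "example");
static_assert(is_same<typename sort_example::sorted2unsorted_map, Sequence<1, 2, 0>>::value,
              "example");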
template <typename Values, typename Less, typename Equal>
struct sequence_unique_sort
{
template <typename RemainValues,
typename RemainIds,
typename UniquifiedValues,
typename UniquifiedIds,
typename Eq>
struct sorted_sequence_uniquify_impl
{
static constexpr index_t current_value = RemainValues::Front();
static constexpr index_t current_id = RemainIds::Front();
static constexpr bool is_unique_value = (current_value != UniquifiedValues::Back());
using new_remain_values = decltype(RemainValues::PopFront());
using new_remain_ids = decltype(RemainIds::PopFront());
using new_uniquified_values =
typename conditional<is_unique_value,
decltype(UniquifiedValues::PushBack(Number<current_value>{})),
UniquifiedValues>::type;
using new_uniquified_ids =
typename conditional<is_unique_value,
decltype(UniquifiedIds::PushBack(Number<current_id>{})),
UniquifiedIds>::type;
using uniquify = sorted_sequence_uniquify_impl<new_remain_values,
new_remain_ids,
new_uniquified_values,
new_uniquified_ids,
Eq>;
// this is output
using uniquified_values = typename uniquify::uniquified_values;
using uniquified_ids = typename uniquify::uniquified_ids;
};
template <typename UniquifiedValues, typename UniquifiedIds, typename Eq>
struct sorted_sequence_uniquify_impl<Sequence<>,
Sequence<>,
UniquifiedValues,
UniquifiedIds,
Eq>
{
using uniquified_values = UniquifiedValues;
using uniquified_ids = UniquifiedIds;
};
template <typename SortedValues, typename SortedIds, typename Eq>
struct sorted_sequence_uniquify
{
using uniquify = sorted_sequence_uniquify_impl<decltype(SortedValues::PopFront()),
decltype(SortedIds::PopFront()),
Sequence<SortedValues::Front()>,
Sequence<SortedIds::Front()>,
Eq>;
using uniquified_values = typename uniquify::uniquified_values;
using uniquified_ids = typename uniquify::uniquified_ids;
};
using sort = sequence_sort<Values, Less>;
using sorted_values = typename sort::type;
using sorted_ids = typename sort::sorted2unsorted_map;
using uniquify = sorted_sequence_uniquify<sorted_values, sorted_ids, Equal>;
// this is output
using type = typename uniquify::uniquified_values;
using sorted2unsorted_map = typename uniquify::uniquified_ids;
};
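// Illustrative check (added example): "example_equal_to" is a hypothetical
// predicate, since this file does not itself name the Equal functor it expects.
struct example_equal_to
{
    __host__ __device__ constexpr bool operator()(index_t a, index_t b) const { return a == b; }
};
static_assert(is_same<typename sequence_unique_sort<Sequence<2, 0, 2, 1>,
                                                    math::less<index_t>,
                                                    example_equal_to>::type,
                      Sequence<0, 1, 2>>::value,
              "wrong! sequence_unique_sort example");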
template <typename SeqMap>
struct is_valid_sequence_map : is_same<typename arithmetic_sequence_gen<0, SeqMap::Size(), 1>::type,
typename sequence_sort<SeqMap, math::less<index_t>>::type>
{
};
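// Illustrative check (added example): a valid map is a permutation of
// 0, ..., Size()-1, so a repeated index is rejected.
static_assert(is_valid_sequence_map<Sequence<1, 3, 0, 2>>::value &&
                  !is_valid_sequence_map<Sequence<0, 0, 2>>::value,
              "wrong! is_valid_sequence_map example");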
template <typename SeqMap>
struct sequence_map_inverse
{
template <typename X2Y, typename WorkingY2X, index_t XBegin, index_t XRemain>
struct sequence_map_inverse_impl
{
static constexpr auto new_y2x =
WorkingY2X::Modify(X2Y::At(Number<XBegin>{}), Number<XBegin>{});
using type =
typename sequence_map_inverse_impl<X2Y, decltype(new_y2x), XBegin + 1, XRemain - 1>::
type;
};
template <typename X2Y, typename WorkingY2X, index_t XBegin>
struct sequence_map_inverse_impl<X2Y, WorkingY2X, XBegin, 0>
{
using type = WorkingY2X;
};
using type =
typename sequence_map_inverse_impl<SeqMap,
typename uniform_sequence_gen<SeqMap::Size(), 0>::type,
0,
SeqMap::Size()>::type;
};
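// Illustrative check (added example): y2x[x2y[x]] == x, so the inverse of
// {1, 3, 0, 2} is {2, 0, 3, 1}.
static_assert(is_same<typename sequence_map_inverse<Sequence<1, 3, 0, 2>>::type,
                      Sequence<2, 0, 3, 1>>::value,
              "wrong! sequence_map_inverse example");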
template <index_t... Xs, index_t... Ys>
...

template <index_t I, index_t... Is>
__host__ __device__ constexpr auto sequence_pop_front(Sequence<I, Is...>)
{
return Sequence<Is...>{};
}
template <typename Seq>
__host__ __device__ constexpr auto sequence_pop_back(Seq)
{
static_assert(Seq::Size() > 0, "wrong! cannot pop an empty Sequence!");
return sequence_pop_front(Seq::Reverse()).Reverse();
}
template <typename... Seqs>
__host__ __device__ constexpr auto merge_sequences(Seqs...)
{
return typename sequence_merge<Seqs...>::type{};
}

template <typename F, index_t... Xs>
__host__ __device__ constexpr auto transform_sequences(F f, Sequence<Xs...>)
{
return Sequence<f(Xs)...>{};
}
template <typename F, index_t... Xs, index_t... Ys>
__host__ __device__ constexpr auto transform_sequences(F f, Sequence<Xs...>, Sequence<Ys...>)
{
static_assert(Sequence<Xs...>::mSize == Sequence<Ys...>::mSize, "Dim not the same");
...
return Sequence<f(Xs, Ys)...>{};
}

template <typename F, index_t... Xs, index_t... Ys, index_t... Zs>
__host__ __device__ constexpr auto
transform_sequences(F f, Sequence<Xs...>, Sequence<Ys...>, Sequence<Zs...>)
{
...
return Sequence<f(Xs, Ys, Zs)...>{};
}
template <typename Seq, typename Reduce, index_t Init>
__host__ __device__ constexpr auto reverse_inclusive_scan_sequence(Seq, Reduce, Number<Init>)
{
return typename sequence_reverse_inclusive_scan<Seq, Reduce, Init>::type{};
}

template <typename Seq, typename Reduce, index_t Init>
__host__ __device__ constexpr auto inclusive_scan_sequence(Seq, Reduce, Number<Init>)
{
return reverse_inclusive_scan_sequence(Seq{}.Reverse(), Reduce{}, Number<Init>{}).Reverse();
}
template <typename Seq, index_t... Is>
__host__ __device__ constexpr auto pick_sequence_elements_by_ids(Seq, Sequence<Is...> /* ids */)
{
return Sequence<Seq::At(Number<Is>{})...>{};
}
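// Illustrative check (added example): ids both select and reorder.
static_assert(is_same<decltype(pick_sequence_elements_by_ids(Sequence<10, 11, 12>{}, Sequence<2, 0>{})),
                      Sequence<12, 10>>::value,
              "wrong! pick_sequence_elements_by_ids example");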
#if 1
namespace detail {
template <typename WorkSeq, typename RemainSeq, typename RemainMask>
struct pick_sequence_elements_by_mask_impl
{
using new_work_seq = typename conditional<RemainMask::Front(),
decltype(WorkSeq::PushBack(RemainSeq::Front())),
WorkSeq>::type;

using type =
typename pick_sequence_elements_by_mask_impl<new_work_seq,
decltype(RemainSeq::PopFront()),
decltype(RemainMask::PopFront())>::type;
};
template <typename WorkSeq>
struct pick_sequence_elements_by_mask_impl<WorkSeq, Sequence<>, Sequence<>>
{
using type = WorkSeq;
};
} // namespace detail
template <typename Seq, typename Mask>
__host__ __device__ constexpr auto pick_sequence_elements_by_mask(Seq, Mask)
{
static_assert(Seq::Size() == Mask::Size(), "wrong!");
return typename detail::pick_sequence_elements_by_mask_impl<Sequence<>, Seq, Mask>::type{};
}
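// Illustrative check (added example): a 0/1 mask keeps the flagged elements in
// their original order.
static_assert(is_same<decltype(pick_sequence_elements_by_mask(Sequence<10, 11, 12>{}, Sequence<1, 0, 1>{})),
                      Sequence<10, 12>>::value,
              "wrong! pick_sequence_elements_by_mask example");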
namespace detail {
template <typename WorkSeq, typename RemainValues, typename RemainIds>
struct modify_sequence_elements_by_ids_impl
{
using new_work_seq = decltype(WorkSeq::Modify(RemainIds::Front(), RemainValues::Front()));
using type =
typename modify_sequence_elements_by_ids_impl<new_work_seq,
decltype(RemainValues::PopFront()),
decltype(RemainIds::PopFront())>::type;
};
template <typename WorkSeq>
struct modify_sequence_elements_by_ids_impl<WorkSeq, Sequence<>, Sequence<>>
{
using type = WorkSeq;
};
} // namespace detail
template <typename Seq, typename Values, typename Ids>
__host__ __device__ constexpr auto modify_sequence_elements_by_ids(Seq, Values, Ids)
{
static_assert(Values::Size() == Ids::Size() && Seq::Size() >= Values::Size(), "wrong!");

return typename detail::modify_sequence_elements_by_ids_impl<Seq, Values, Ids>::type{};
}
#endif
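// Illustrative check (added example): write Values into Seq at positions Ids.
static_assert(is_same<decltype(modify_sequence_elements_by_ids(
                          Sequence<10, 11, 12>{}, Sequence<99>{}, Sequence<1>{})),
                      Sequence<10, 99, 12>>::value,
              "wrong! modify_sequence_elements_by_ids example");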
template <typename Seq, typename Reduce, index_t Init>
__host__ __device__ constexpr index_t
reduce_on_sequence(Seq, Reduce f, Number<Init> /*initial_value*/)
{
index_t result = Init;

for(index_t i = 0; i < Seq::Size(); ++i)
{
result = f(result, Seq::At(i));
}

return result;
}
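// Illustrative check (added example): "example_sum" is a hypothetical functor
// standing in for whatever reduction the caller supplies.
struct example_sum
{
    __host__ __device__ constexpr index_t operator()(index_t a, index_t b) const { return a + b; }
};
static_assert(reduce_on_sequence(Sequence<1, 2, 3>{}, example_sum{}, Number<10>{}) == 16,
              "wrong! reduce_on_sequence example");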
// TODO: a generic any_of for any container
template <typename Seq, typename F>
__host__ __device__ constexpr bool sequence_any_of(Seq, F f)
{
bool flag = false;

for(index_t i = 0; i < Seq::Size(); ++i)
{
flag = flag || f(Seq::At(i));
}

return flag;
}
// TODO: a generic all_of for any container
template <typename Seq, typename F>
__host__ __device__ constexpr bool sequence_all_of(Seq, F f)
{
bool flag = true;

for(index_t i = 0; i < Seq::Size(); ++i)
{
flag = flag && f(Seq::At(i));
}

return flag;
}
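// Illustrative check (added example): "example_is_even" is a hypothetical
// predicate exercising the two folds above.
struct example_is_even
{
    __host__ __device__ constexpr bool operator()(index_t x) const { return x % 2 == 0; }
};
static_assert(sequence_any_of(Sequence<1, 2, 3>{}, example_is_even{}) &&
                  !sequence_all_of(Sequence<1, 2, 3>{}, example_is_even{}),
              "wrong! sequence_any_of/sequence_all_of example");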
} // namespace ck
...
#ifndef CK_SEQUENCE_HELPER_HPP
#define CK_SEQUENCE_HELPER_HPP
#include "sequence.hpp"
namespace ck {
template <index_t... Xs>
__host__ __device__ void print_sequence(const char* s, Sequence<Xs...>)
{
constexpr index_t nsize = Sequence<Xs...>::Size();
static_assert(nsize <= 10, "wrong!");
static_if<nsize == 0>{}([&](auto) { printf("%s size %u, {}\n", s, nsize, Xs...); });
static_if<nsize == 1>{}([&](auto) { printf("%s size %u, {%u}\n", s, nsize, Xs...); });
static_if<nsize == 2>{}([&](auto) { printf("%s size %u, {%u %u}\n", s, nsize, Xs...); });
static_if<nsize == 3>{}([&](auto) { printf("%s size %u, {%u %u %u}\n", s, nsize, Xs...); });
static_if<nsize == 4>{}([&](auto) { printf("%s size %u, {%u %u %u %u}\n", s, nsize, Xs...); });
static_if<nsize == 5>{}(
[&](auto) { printf("%s size %u, {%u %u %u %u %u}\n", s, nsize, Xs...); });
static_if<nsize == 6>{}(
[&](auto) { printf("%s size %u, {%u %u %u %u %u %u}\n", s, nsize, Xs...); });
static_if<nsize == 7>{}(
[&](auto) { printf("%s size %u, {%u %u %u %u %u %u %u}\n", s, nsize, Xs...); });
static_if<nsize == 8>{}(
[&](auto) { printf("%s size %u, {%u %u %u %u %u %u %u %u}\n", s, nsize, Xs...); });
static_if<nsize == 9>{}(
[&](auto) { printf("%s size %u, {%u %u %u %u %u %u %u %u %u}\n", s, nsize, Xs...); });
static_if<nsize == 10>{}(
[&](auto) { printf("%s size %u, {%u %u %u %u %u %u %u %u %u %u}\n", s, nsize, Xs...); });
}
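// Usage sketch (added example): print_sequence("lengths", Sequence<1, 2, 3>{})
// prints "lengths size 3, {1 2 3}".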
} // namespace ck
#endif
#ifndef CK_TUPLE_HPP
#define CK_TUPLE_HPP
#include "integral_constant.hpp"
#include "type.hpp"
#include "sequence.hpp"
namespace ck {
namespace detail {
template <index_t>
struct TupleElementKey
{
};
template <typename Key, typename Data>
struct TupleElement
{
__host__ __device__ explicit constexpr TupleElement() : mData() {}
template <typename T>
__host__ __device__ explicit constexpr TupleElement(T&& v) : mData(static_cast<T&&>(v))
{
}
Data mData;
};
template <typename Key, typename Data>
__host__ __device__ constexpr const Data& get_tuple_element(const TupleElement<Key, Data>& x)
{
return x.mData;
}
template <typename Key, typename Data>
__host__ __device__ constexpr Data& get_tuple_element(TupleElement<Key, Data>& x)
{
return x.mData;
}
template <typename Key, typename Data>
__host__ __device__ constexpr Data&& get_tuple_element(TupleElement<Key, Data>&& x)
{
return static_cast<Data&&>(x.mData);
}
template <typename Indices, typename... Xs>
struct TupleImpl;
template <index_t... Is, typename... Xs>
struct TupleImpl<Sequence<Is...>, Xs...> : TupleElement<TupleElementKey<Is>, Xs>...
{
__host__ __device__ explicit constexpr TupleImpl() : TupleElement<TupleElementKey<Is>, Xs>()...
{
}
template <typename... Ys>
__host__ __device__ explicit constexpr TupleImpl(Ys&&... ys)
: TupleElement<TupleElementKey<Is>, Xs>(static_cast<Ys&&>(ys))...
{
}
__host__ __device__ static constexpr index_t Size() { return sizeof...(Xs); }
template <index_t I>
__host__ __device__ constexpr const auto& GetElementByKey(TupleElementKey<I>) const
{
return get_tuple_element<TupleElementKey<I>>(*this);
}
template <index_t I>
__host__ __device__ constexpr auto& GetElementByKey(TupleElementKey<I>)
{
return get_tuple_element<TupleElementKey<I>>(*this);
}
};
} // namespace detail
template <typename... Xs>
struct Tuple : detail::TupleImpl<typename arithmetic_sequence_gen<0, sizeof...(Xs), 1>::type, Xs...>
{
using base =
detail::TupleImpl<typename arithmetic_sequence_gen<0, sizeof...(Xs), 1>::type, Xs...>;
template <typename... Ys>
__host__ __device__ explicit constexpr Tuple(Ys&&... ys) : base(static_cast<Ys&&>(ys)...)
{
}
template <index_t I>
__host__ __device__ constexpr const auto& At(Number<I>) const
{
static_assert(I < base::Size(), "wrong! out of range");
return base::GetElementByKey(detail::TupleElementKey<I>{});
}
template <index_t I>
__host__ __device__ constexpr auto& At(Number<I>)
{
static_assert(I < base::Size(), "wrong! out of range");
return base::GetElementByKey(detail::TupleElementKey<I>{});
}
};
template <typename... Xs>
__host__ __device__ constexpr auto make_tuple(Xs&&... xs)
{
return Tuple<remove_cv_t<remove_reference_t<Xs>>...>(std::forward<Xs>(xs)...);
}
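// Usage sketch (added example, not from the original source):
//   auto t = make_tuple(index_t{1}, 2.5f); // Tuple<index_t, float>
//   index_t a = t.At(Number<0>{});         // a == 1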
namespace detail {
template <typename F, typename X, index_t... Is>
__host__ __device__ constexpr auto transform_tuples_impl(F f, const X& x, Sequence<Is...>)
{
return make_tuple(f(x.At(Number<Is>{}))...);
}
template <typename F, typename X, typename Y, index_t... Is>
__host__ __device__ constexpr auto
transform_tuples_impl(F f, const X& x, const Y& y, Sequence<Is...>)
{
return make_tuple(f(x.At(Number<Is>{}), y.At(Number<Is>{}))...);
}
template <typename F, typename X, typename Y, typename Z, index_t... Is>
__host__ __device__ constexpr auto
transform_tuples_impl(F f, const X& x, const Y& y, const Z& z, Sequence<Is...>)
{
return make_tuple(f(x.At(Number<Is>{}), y.At(Number<Is>{}), z.At(Number<Is>{}))...);
}
} // namespace detail
template <typename F, typename X>
__host__ __device__ constexpr auto transform_tuples(F f, const X& x)
{
return detail::transform_tuples_impl(
f, x, typename arithmetic_sequence_gen<0, X::Size(), 1>::type{});
}
template <typename F, typename X, typename Y>
__host__ __device__ constexpr auto transform_tuples(F f, const X& x, const Y& y)
{
return detail::transform_tuples_impl(
f, x, y, typename arithmetic_sequence_gen<0, X::Size(), 1>::type{});
}
template <typename F, typename X, typename Y, typename Z>
__host__ __device__ constexpr auto transform_tuples(F f, const X& x, const Y& y, const Z& z)
{
return detail::transform_tuples_impl(
f, x, y, z, typename arithmetic_sequence_gen<0, X::Size(), 1>::type{});
}
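// Usage sketch (added example): element-wise combination of equal-size tuples,
// assuming f is callable on each pair of element types:
//   auto c = transform_tuples([](auto x, auto y) { return x + y; },
//                             make_tuple(1, 2.f), make_tuple(3, 4.f)); // {4, 6.f}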
} // namespace ck
#endif
#ifndef CK_TYPE_HPP
#define CK_TYPE_HPP
#include "integral_constant.hpp"
namespace ck {
template <index_t... Is>
struct Sequence;
template <typename X, typename Y>
struct is_same : public integral_constant<bool, false>
{
};
template <typename X>
struct is_same<X, X> : public integral_constant<bool, true>
{
};
template <typename>
struct is_static : integral_constant<bool, false>
{
};
template <typename T, T X>
struct is_static<integral_constant<T, X>> : integral_constant<bool, true>
{
};
template <index_t... Is>
struct is_static<Sequence<Is...>> : integral_constant<bool, true>
{
};
template <typename T>
using remove_reference_t = typename std::remove_reference<T>::type;
template <typename T>
using remove_cv_t = typename std::remove_cv<T>::type;
} // namespace ck
#endif
...

template <>
struct vector_type<float, 1>
{
using MemoryType = float;

template <index_t I>
__host__ __device__ static void SetScalar(MemoryType& v, float s, Number<I>)
...
}
};
template <>
struct vector_type<const float, 1>
{
using MemoryType = const float;
};
template <>
struct vector_type<const float, 2>
{
using MemoryType = const float2_t;
};
template <>
struct vector_type<const float, 4>
{
using MemoryType = const float4_t;
};
} // namespace ck

#endif
...
constexpr index_t GemmDataPerReadB = 4;

using InBlockCopyClusterLengths_CHWN = Sequence<4, 4, 2, 4>;
constexpr index_t InBlockCopyDataPerAccess_N = 4;

constexpr index_t WeiBlockCopyDataPerAccess_K = 4;

constexpr index_t OutThreadCopyDataPerAccess_N = 2;
#elif 0
// for 3x3, 34x34, v1r2, Pascal, in-block-copy1
constexpr index_t BlockSize = 128;
constexpr index_t NPerBlock = 4;
constexpr index_t KPerBlock = 64;
constexpr index_t CPerBlock = 8;
constexpr index_t HoPerBlock = 4;
constexpr index_t WoPerBlock = 8;
constexpr index_t NPerThread = 4;
constexpr index_t KPerThread = 8;
constexpr index_t HoPerThread = 1;
constexpr index_t WoPerThread = 2;
constexpr index_t GemmMPerThreadSubC = 4;
constexpr index_t GemmNPerThreadSubC = 4;
constexpr index_t GemmMLevel0Cluster = 4;
constexpr index_t GemmNLevel0Cluster = 2;
constexpr index_t GemmMLevel1Cluster = 2;
constexpr index_t GemmNLevel1Cluster = 2;
constexpr index_t GemmKPerThreadLoop = 1;
constexpr index_t GemmDataPerReadA = 4;
constexpr index_t GemmDataPerReadB = 4;
using InBlockCopyClusterLengths_CHWN = Sequence<0, 0, 0, 0>; // not used
constexpr index_t InBlockCopyDataPerRead_N = 4;
constexpr index_t WeiBlockCopyDataPerRead_K = 4;
constexpr index_t OutThreadCopyDataPerWrite_N = 2;
#elif 1
// for 3x3, 34x34, v1r3, Pascal
// for 3x3, 28x28, v1r3, Pascal
...
constexpr index_t GemmDataPerReadA = 4;
constexpr index_t GemmDataPerReadB = 4;

using InBlockCopySubLengths_CHWN     = Sequence<1, 1, 1, 4>;
using InBlockCopyClusterLengths_CHWN = Sequence<8, 2, 2, 4>;
constexpr index_t InBlockCopyDataPerAccess_N = 4;

using WeiBlockCopySubLengths_CK     = Sequence<2, 4>;
using WeiBlockCopyClusterLengths_CK = Sequence<4, 32>;
constexpr index_t WeiBlockCopyDataPerAccess_K = 4;

constexpr index_t OutThreadCopyDataPerAccess_N = 2;
#elif 0
// for 3x3, 34x34, v1r3, Pascal, bad
constexpr index_t BlockSize = 128;
constexpr index_t NPerBlock = 1;
constexpr index_t KPerBlock = 128;
constexpr index_t CPerBlock = 8;
constexpr index_t HoPerBlock = 2;
constexpr index_t WoPerBlock = 32;
constexpr index_t NPerThread = 1;
constexpr index_t KPerThread = 8;
constexpr index_t HoPerThread = 1;
constexpr index_t WoPerThread = 8;
constexpr index_t GemmMPerThreadSubC = 4;
constexpr index_t GemmNPerThreadSubC = 4;
constexpr index_t GemmMLevel0Cluster = 4;
constexpr index_t GemmNLevel0Cluster = 2;
constexpr index_t GemmMLevel1Cluster = 4;
constexpr index_t GemmNLevel1Cluster = 2;
constexpr index_t GemmKPerThreadLoop = 1;
constexpr index_t GemmDataPerReadA = 4;
constexpr index_t GemmDataPerReadB = 4;
using InBlockCopyClusterLengths_CHWN = Sequence<2, 2, 32, 1>;
constexpr index_t InBlockCopyDataPerRead_N = 1;
constexpr index_t WeiBlockCopyDataPerRead_K = 2;
constexpr index_t OutThreadCopyDataPerWrite_N = 1;
#elif 0
// for 3x3, 34x34, v1r1, Vega 20
constexpr index_t BlockSize = 256;
...
constexpr index_t GemmDataPerReadA = 4;
constexpr index_t GemmDataPerReadB = 4;

using InBlockCopyClusterLengths_CHWN = Sequence<4, 4, 2, 8>;
constexpr index_t InBlockCopyDataPerAccess_N = 2;

constexpr index_t WeiBlockCopyDataPerAccess_K = 2;

constexpr index_t OutThreadCopyDataPerAccess_N = 4;
#elif 1
// for 3x3, 34x34, v1r3, Vega 20
constexpr index_t BlockSize = 256;
...
constexpr index_t GemmDataPerReadA = 4;
constexpr index_t GemmDataPerReadB = 4;

using InBlockCopySubLengths_CHWN     = Sequence<1, 1, 1, 4>;
using InBlockCopyClusterLengths_CHWN = Sequence<8, 2, 4, 4>;
constexpr index_t InBlockCopyDataPerAccess_N = 4;

using WeiBlockCopySubLengths_CK     = Sequence<1, 4>;
using WeiBlockCopyClusterLengths_CK = Sequence<8, 32>;
constexpr index_t WeiBlockCopyDataPerAccess_K = 4;

constexpr index_t OutThreadCopyDataPerAccess_N = 4;
#elif 0
// for 3x3, 56x56, v1r1, Pascal
constexpr index_t NPerBlock = 32;
...
constexpr index_t HoPerThread = 1;
constexpr index_t WoPerThread = 2;

constexpr index_t InBlockCopy_ThreadPerDimC = 1;
constexpr index_t InBlockCopy_ThreadPerDimH = 4;
constexpr index_t InBlockCopy_ThreadPerDimW = 4;
constexpr index_t InBlockCopy_ThreadPerDimN = 8;
constexpr index_t InBlockCopyDataPerAccess_N = 4;

constexpr index_t WeiBlockCopyDataPerAccess_K = 4;

constexpr index_t GemmMPerThreadSubC = 4;
constexpr index_t GemmNPerThreadSubC = 4;
...
constexpr index_t GemmNLevel1Cluster = 4;
constexpr index_t GemmKPerThreadLoop = 1;

constexpr index_t OutThreadCopyDataPerAccess_N = 2;

constexpr index_t BlockSize = 128;
#elif 0
...
constexpr index_t GemmDataPerReadA = 1;
constexpr index_t GemmDataPerReadB = 1;

constexpr index_t InBlockCopy_ThreadPerDimC = 1;
constexpr index_t InBlockCopy_ThreadPerDimH = 2;
constexpr index_t InBlockCopy_ThreadPerDimW = 4;
constexpr index_t InBlockCopy_ThreadPerDimN = 4;
constexpr index_t InBlockCopyDataPerAccess_N = 4;

constexpr index_t WeiBlockCopyDataPerAccess_K = 4;

constexpr index_t OutThreadCopyDataPerAccess_N = 4;

constexpr index_t BlockSize = 128;
#elif 0
...
constexpr index_t HoPerThread = 1;
constexpr index_t WoPerThread = 2;

constexpr index_t InBlockCopy_ThreadPerDimC = 1;
constexpr index_t InBlockCopy_ThreadPerDimH = 4;
constexpr index_t InBlockCopy_ThreadPerDimW = 4;
constexpr index_t InBlockCopy_ThreadPerDimN = 8;
constexpr index_t InBlockCopyDataPerAccess_N = 4;

constexpr index_t WeiBlockCopyDataPerAccess_K = 4;

constexpr index_t GemmMPerThreadSubC = 4;
constexpr index_t GemmNPerThreadSubC = 4;
...
constexpr index_t GemmDataPerReadA = 4;
constexpr index_t GemmDataPerReadB = 4;

constexpr index_t OutThreadCopyDataPerAccess_N = 2;

constexpr index_t BlockSize = 128;
#elif 0
...
constexpr index_t GemmDataPerReadA = 4;
constexpr index_t GemmDataPerReadB = 4;

using InBlockCopyClusterLengths_CHWN = Sequence<4, 2, 4, 4>;
constexpr index_t InBlockCopyDataPerAccess_N = 4;

constexpr index_t WeiBlockCopyDataPerAccess_K = 4;

constexpr index_t OutThreadCopyDataPerAccess_N = 2;
#elif 0
// for 1x1, 28x28, v1r1, Pascal
constexpr index_t NPerBlock = 16;
...
constexpr index_t HoPerThread = 1;
constexpr index_t WoPerThread = 1;

constexpr index_t InBlockCopy_ThreadPerDimC = 8;
constexpr index_t InBlockCopy_ThreadPerDimH = 2;
constexpr index_t InBlockCopy_ThreadPerDimW = 2;
constexpr index_t InBlockCopy_ThreadPerDimN = 4;
constexpr index_t InBlockCopyDataPerAccess_N = 4;

constexpr index_t WeiBlockCopyDataPerAccess_K = 4;

constexpr index_t GemmMPerThreadSubC = 4;
constexpr index_t GemmNPerThreadSubC = 4;
...
constexpr index_t GemmNLevel1Cluster = 4;
constexpr index_t GemmKPerThreadLoop = 1;

constexpr index_t OutThreadCopyDataPerAccess_N = 2;

constexpr index_t BlockSize = 128;
#elif 0
...
constexpr index_t GemmNLevel1Cluster = 2;
constexpr index_t GemmKPerThreadLoop = 1;

constexpr index_t InBlockCopy_ThreadPerDimC = 8;
constexpr index_t InBlockCopy_ThreadPerDimH = 2;
constexpr index_t InBlockCopy_ThreadPerDimW = 2;
constexpr index_t InBlockCopy_ThreadPerDimN = 4;
constexpr index_t InBlockCopyDataPerAccess_N = 4;

constexpr index_t WeiBlockCopyDataPerAccess_K = 4;

constexpr index_t OutThreadCopyDataPerAccess_N = 2;

constexpr index_t BlockSize = 128;
#endif

constexpr index_t GridSize =
(N / NPerBlock) * (K / KPerBlock) * (Ho / HoPerBlock) * (Wo / WoPerBlock);

printf("%s: BlockSize %u, GridSize %u \n", __func__, BlockSize, GridSize);
constexpr auto gridwise_conv =
#if 0
GridwiseConvolutionImplicitGemm_v1r1_chwn_cyxk_khwn
#elif 0
GridwiseConvolutionImplicitGemm_v1r2_chwn_cyxk_khwn
#elif 0
GridwiseConvolutionImplicitGemm_v1r3_chwn_cyxk_khwn
#elif 1
GridwiseConvolutionImplicitGemm_v1r3_chwn_cyxk_khwn_lds_double_buffer
#endif
<GridSize,
BlockSize,
T,
decltype(in_chwn_desc),
decltype(wei_cyxk_desc),
decltype(out_khwn_desc),
NPerBlock,
KPerBlock,
CPerBlock,
HoPerBlock,
WoPerBlock,
NPerThread,
KPerThread,
HoPerThread,
WoPerThread,
GemmMPerThreadSubC,
GemmNPerThreadSubC,
GemmMLevel0Cluster,
GemmNLevel0Cluster,
GemmMLevel1Cluster,
GemmNLevel1Cluster,
GemmKPerThreadLoop,
GemmDataPerReadA,
GemmDataPerReadB,
InBlockCopySubLengths_CHWN,
InBlockCopyClusterLengths_CHWN,
InBlockCopyDataPerAccess_N,
WeiBlockCopySubLengths_CK,
WeiBlockCopyClusterLengths_CK,
WeiBlockCopyDataPerAccess_K,
OutThreadCopyDataPerAccess_N>{};

for(index_t i = 0; i < nrepeat; ++i)
{
float time = launch_kernel(run_gridwise_convolution_kernel<decltype(gridwise_conv), T>,
dim3(GridSize),
dim3(BlockSize),
...
#pragma once
#include <unistd.h>
#include "device.hpp"
#include "tensor.hpp"
#include "gridwise_convolution_implicit_gemm_v1r3_chwn_cyxk_khwn_padded.hpp"
using namespace ck;
template <typename T, typename InDesc, typename WeiDesc, typename OutDesc, typename LeftPads, typename RightPads>
void device_convolution_implicit_gemm_v1_chwn_cyxk_khwn_padded(InDesc,
const Tensor<T>& in_nchw,
WeiDesc,
const Tensor<T>& wei_kcyx,
OutDesc,
Tensor<T>& out_nkhw,
LeftPads,
RightPads,
index_t nrepeat)
{
constexpr auto I0 = Number<0>{};
constexpr auto I1 = Number<1>{};
constexpr auto I2 = Number<2>{};
constexpr auto I3 = Number<3>{};
constexpr auto in_nchw_desc = InDesc{};
constexpr auto wei_kcyx_desc = WeiDesc{};
constexpr auto out_nkhw_desc = OutDesc{};
constexpr index_t Hi = in_nchw_desc.GetLength(I2);
constexpr index_t Wi = in_nchw_desc.GetLength(I3);
constexpr index_t N = out_nkhw_desc.GetLength(I0);
constexpr index_t Ho = out_nkhw_desc.GetLength(I2);
constexpr index_t Wo = out_nkhw_desc.GetLength(I3);
constexpr index_t K = wei_kcyx_desc.GetLength(I0);
constexpr index_t C = wei_kcyx_desc.GetLength(I1);
constexpr index_t Y = wei_kcyx_desc.GetLength(I2);
constexpr index_t X = wei_kcyx_desc.GetLength(I3);
// reorder weight
auto wei_cyxk_desc = make_ConstantTensorDescriptor_packed(Sequence<C, Y, X, K>{});
ostream_ConstantTensorDescriptor(wei_cyxk_desc, std::cout << "wei_cyxk_desc: ");
Tensor<T> wei_cyxk(make_TensorDescriptor(wei_cyxk_desc));
auto f_reorder_kcyx2cyxk = [&](auto k, auto c, auto y, auto x) {
wei_cyxk(c, y, x, k) = wei_kcyx(k, c, y, x);
};
make_ParallelTensorFunctor(f_reorder_kcyx2cyxk, K, C, Y, X)(
std::thread::hardware_concurrency());
// reorder input
auto in_chwn_desc = make_ConstantTensorDescriptor_packed(Sequence<C, Hi, Wi, N>{});
ostream_ConstantTensorDescriptor(in_chwn_desc, std::cout << "in_chwn_desc: ");
Tensor<T> in_chwn(make_TensorDescriptor(in_chwn_desc));
auto f_reorder_nchw2chwn = [&](auto n, auto c, auto hi, auto wi) {
in_chwn(c, hi, wi, n) = in_nchw(n, c, hi, wi);
};
make_ParallelTensorFunctor(f_reorder_nchw2chwn, N, C, Hi, Wi)(
std::thread::hardware_concurrency());
// output
auto out_khwn_desc = make_ConstantTensorDescriptor_packed(Sequence<K, Ho, Wo, N>{});
ostream_ConstantTensorDescriptor(out_khwn_desc, std::cout << "out_khwn_desc: ");
Tensor<T> out_khwn(make_TensorDescriptor(out_khwn_desc));
std::size_t data_sz = sizeof(T);
DeviceMem in_chwn_device_buf(data_sz * in_chwn.mDesc.GetElementSpace());
DeviceMem wei_cyxk_device_buf(data_sz * wei_cyxk.mDesc.GetElementSpace());
DeviceMem out_khwn_device_buf(data_sz * out_khwn.mDesc.GetElementSpace());
in_chwn_device_buf.ToDevice(in_chwn.mData.data());
wei_cyxk_device_buf.ToDevice(wei_cyxk.mData.data());
out_khwn_device_buf.ToDevice(out_khwn.mData.data());
#if 1
// v1r3, 3x3, 32x32, 1x1 pad
constexpr index_t BlockSize = 256;
constexpr index_t NPerBlock = 32;
constexpr index_t KPerBlock = 128;
constexpr index_t CPerBlock = 8;
constexpr index_t HoPerBlock = 2;
constexpr index_t WoPerBlock = 2;
constexpr index_t NPerThread = 4;
constexpr index_t KPerThread = 8;
constexpr index_t HoPerThread = 1;
constexpr index_t WoPerThread = 2;
constexpr index_t GemmMPerThreadSubC = 4;
constexpr index_t GemmNPerThreadSubC = 4;
constexpr index_t GemmMLevel0Cluster = 4;
constexpr index_t GemmNLevel0Cluster = 4;
constexpr index_t GemmMLevel1Cluster = 4;
constexpr index_t GemmNLevel1Cluster = 2;
constexpr index_t GemmKPerThreadLoop = 1;
constexpr index_t GemmDataPerReadA = 4;
constexpr index_t GemmDataPerReadB = 4;
using InBlockCopySubLengths_CHWN = Sequence<1, 1, 1, 4>;
using InBlockCopyClusterLengths_CHWN = Sequence<8, 2, 2, 8>;
constexpr index_t InBlockCopyDataPerAccess_N = 4;
using WeiBlockCopySubLengths_CK = Sequence<1, 4>;
using WeiBlockCopyClusterLengths_CK = Sequence<8, 32>;
constexpr index_t WeiBlockCopyDataPerAccess_K = 4;
constexpr index_t OutThreadCopyDataPerAccess_N = 4;
#endif
#if 1 // debug
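// Note (added comment): the plain integer divisions below assume N, K, Ho and
// Wo are exactly divisible by their per-block tile sizes; no tail blocks are launched.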
constexpr index_t GridSize =
(N / NPerBlock) * (K / KPerBlock) * (Ho / HoPerBlock) * (Wo / WoPerBlock);
#else
constexpr index_t GridSize = 1;
#endif
printf("%s: BlockSize %u, GridSize %u \n", __func__, BlockSize, GridSize);
constexpr auto gridwise_conv =
GridwiseConvolutionImplicitGemm_v1r3_chwn_cyxk_khwn_padded<GridSize,
BlockSize,
T,
decltype(in_chwn_desc),
decltype(wei_cyxk_desc),
decltype(out_khwn_desc),
LeftPads,
RightPads,
NPerBlock,
KPerBlock,
CPerBlock,
HoPerBlock,
WoPerBlock,
NPerThread,
KPerThread,
HoPerThread,
WoPerThread,
GemmMPerThreadSubC,
GemmNPerThreadSubC,
GemmMLevel0Cluster,
GemmNLevel0Cluster,
GemmMLevel1Cluster,
GemmNLevel1Cluster,
GemmKPerThreadLoop,
GemmDataPerReadA,
GemmDataPerReadB,
InBlockCopySubLengths_CHWN,
InBlockCopyClusterLengths_CHWN,
InBlockCopyDataPerAccess_N,
WeiBlockCopySubLengths_CK,
WeiBlockCopyClusterLengths_CK,
WeiBlockCopyDataPerAccess_K,
OutThreadCopyDataPerAccess_N>{};
for(index_t i = 0; i < nrepeat; ++i)
{
float time = launch_kernel(run_gridwise_convolution_kernel<decltype(gridwise_conv), T>,
dim3(GridSize),
dim3(BlockSize),
0,
static_cast<T*>(in_chwn_device_buf.GetDeviceBuffer()),
static_cast<T*>(wei_cyxk_device_buf.GetDeviceBuffer()),
static_cast<T*>(out_khwn_device_buf.GetDeviceBuffer()));
printf("Elapsed time : %f ms, %f TFlop/s\n",
time,
(float)calculate_convolution_flops(InDesc{}, WeiDesc{}, OutDesc{}) /
(std::size_t(1000) * 1000 * 1000) / time);
usleep(std::min(time * 1000, float(10000)));
}
out_khwn_device_buf.FromDevice(out_khwn.mData.data());
// reorder output
auto f_reorder_khwn2nkhw = [&](auto k, auto ho, auto wo, auto n) {
out_nkhw(n, k, ho, wo) = out_khwn(k, ho, wo, n);
};
make_ParallelTensorFunctor(f_reorder_khwn2nkhw, K, Ho, Wo, N)(
std::thread::hardware_concurrency());
}
...
constexpr auto wei_kcyx_desc = WeiDesc{};
constexpr auto out_nkhw_desc = OutDesc{};

constexpr index_t N  = out_nkhw_desc.GetLength(I0);
constexpr index_t K  = out_nkhw_desc.GetLength(I1);
constexpr index_t Ho = out_nkhw_desc.GetLength(I2);
constexpr index_t Wo = out_nkhw_desc.GetLength(I3);
std::size_t data_sz = sizeof(T);
DeviceMem in_nchw_device_buf(data_sz * in_nchw.mDesc.GetElementSpace());
DeviceMem wei_kcyx_device_buf(data_sz * wei_kcyx.mDesc.GetElementSpace());
...
wei_kcyx_device_buf.ToDevice(wei_kcyx.mData.data());
out_nkhw_device_buf.ToDevice(out_nkhw.mData.data());
#if 1
// BlockSize = 256, blockwise-GEMM 128x128, each thread hold 64 data
constexpr index_t BlockSize = 256;

constexpr index_t BPerBlock = 16;
constexpr index_t KPerBlock = 128;
constexpr index_t EPerBlock = 8;

constexpr index_t GemmNRepeat = 2;

constexpr index_t GemmMPerThreadSubC = 4;
constexpr index_t GemmNPerThreadSubC = 4;
constexpr index_t GemmMLevel0Cluster = 4;
...
using InBlockCopySubLengths_E_N1_B_N2     = Sequence<1, 1, 1, 4>;
using InBlockCopyClusterLengths_E_N1_B_N2 = Sequence<8, 2, 16, 1>;
using InBlockCopyThreadClusterArrangeOrder = Sequence<0, 1, 3, 2>; // [E, N1, N2, B]
using InBlockCopySrcAccessOrder            = Sequence<0, 2, 1, 3>; // [E, B, N1, N2]
using InBlockCopyDstAccessOrder            = Sequence<0, 1, 2, 3>; // [E, N1, B, N2]

constexpr index_t InBlockCopySrcDataPerRead_B   = 1;
constexpr index_t InBlockCopyDstDataPerWrite_N2 = 4;

using WeiBlockCopySubLengths_E_K     = Sequence<4, 1>;
using WeiBlockCopyClusterLengths_E_K = Sequence<2, 128>;
using WeiBlockCopyThreadClusterArrangeOrder = Sequence<1, 0>; // [K, E]
using WeiBlockCopySrcAccessOrder            = Sequence<1, 0>; // [K, E]
using WeiBlockCopyDstAccessOrder            = Sequence<0, 1>; // [E, K]

constexpr index_t WeiBlockCopySrcDataPerRead_E  = 4;
constexpr index_t WeiBlockCopyDstDataPerWrite_K = 1;
#elif 0
// BlockSize = 64, blockwise-GEMM 64x64, each thread hold 64 data
constexpr index_t BlockSize = 64;

constexpr index_t BPerBlock = 8;
constexpr index_t KPerBlock = 64;
constexpr index_t EPerBlock = 8;

constexpr index_t GemmNRepeat = 2;

constexpr index_t GemmMPerThreadSubC = 4;
constexpr index_t GemmNPerThreadSubC = 4;
constexpr index_t GemmMLevel0Cluster = 4;
constexpr index_t GemmNLevel0Cluster = 4;
constexpr index_t GemmMLevel1Cluster = 2;
constexpr index_t GemmNLevel1Cluster = 2;
constexpr index_t GemmKPerThreadLoop = 1;
constexpr index_t GemmDataPerReadA   = 4;
constexpr index_t GemmDataPerReadB   = 4;

using InBlockCopySubLengths_E_N1_B_N2     = Sequence<1, 2, 1, 4>;
using InBlockCopyClusterLengths_E_N1_B_N2 = Sequence<8, 1, 8, 1>;
using InBlockCopyThreadClusterArrangeOrder = Sequence<0, 1, 3, 2>; // [E, N1, N2, B]
using InBlockCopySrcAccessOrder            = Sequence<0, 2, 1, 3>; // [E, B, N1, N2]
using InBlockCopyDstAccessOrder            = Sequence<0, 1, 2, 3>; // [E, N1, B, N2]

constexpr index_t InBlockCopySrcDataPerRead_B   = 1;
constexpr index_t InBlockCopyDstDataPerWrite_N2 = 4;

using WeiBlockCopySubLengths_E_K     = Sequence<4, 2>;
using WeiBlockCopyClusterLengths_E_K = Sequence<2, 32>;
using WeiBlockCopyThreadClusterArrangeOrder = Sequence<1, 0>; // [K, E]
using WeiBlockCopySrcAccessOrder            = Sequence<1, 0>; // [K, E]
using WeiBlockCopyDstAccessOrder            = Sequence<0, 1>; // [E, K]

constexpr index_t WeiBlockCopySrcDataPerRead_E  = 4;
constexpr index_t WeiBlockCopyDstDataPerWrite_K = 1;
#elif 1
// BlockSize = 256, blockwise-GEMM 64x128, each thread hold 32 data
constexpr index_t BlockSize = 256;

constexpr index_t BPerBlock = 16;
constexpr index_t KPerBlock = 64;
constexpr index_t EPerBlock = 8;

constexpr index_t GemmNRepeat = 2;

constexpr index_t GemmMPerThreadSubC = 2;
constexpr index_t GemmNPerThreadSubC = 4;
constexpr index_t GemmMLevel0Cluster = 4;
...
using InBlockCopySubLengths_E_N1_B_N2     = Sequence<1, 1, 1, 4>;
using InBlockCopyClusterLengths_E_N1_B_N2 = Sequence<8, 2, 16, 1>;
using InBlockCopyThreadClusterArrangeOrder = Sequence<0, 1, 3, 2>; // [E, N1, N2, B]
using InBlockCopySrcAccessOrder            = Sequence<0, 2, 1, 3>; // [E, B, N1, N2]
using InBlockCopyDstAccessOrder            = Sequence<0, 1, 2, 3>; // [E, N1, B, N2]

constexpr index_t InBlockCopySrcDataPerRead_B = 1;
...
constexpr index_t WeiBlockCopyDstDataPerWrite_K = 1;
#endif

constexpr index_t N1 = GemmNRepeat;
constexpr index_t N2 = GemmNPerThreadSubC;

constexpr index_t B = (N * Ho * Wo) / (N1 * N2);

constexpr index_t GridSize =
((B + BPerBlock - 1) / BPerBlock) * ((K + KPerBlock - 1) / KPerBlock);

printf("%s: BlockSize %u, GridSize %u \n", __func__, BlockSize, GridSize);
constexpr auto gridwise_conv =
#if 0
GridwiseConvolutionImplicitGemm_v4r1_nchw_kcyx_nkhw
#else
GridwiseConvolutionImplicitGemm_v4r1_nchw_kcyx_nkhw_lds_double_buffer
#endif
<GridSize,
BlockSize,
T,
decltype(in_nchw_desc),
decltype(wei_kcyx_desc),
decltype(out_nkhw_desc),
ConvStrides,
ConvDilations,
BPerBlock,
KPerBlock,
EPerBlock,
GemmNRepeat,
GemmMPerThreadSubC,
GemmNPerThreadSubC,
GemmMLevel0Cluster,
GemmNLevel0Cluster,
GemmMLevel1Cluster,
GemmNLevel1Cluster,
GemmKPerThreadLoop,
GemmDataPerReadA,
GemmDataPerReadB,
InBlockCopySubLengths_E_N1_B_N2,
InBlockCopyClusterLengths_E_N1_B_N2,
InBlockCopyThreadClusterArrangeOrder,
InBlockCopySrcAccessOrder,
InBlockCopyDstAccessOrder,
InBlockCopySrcDataPerRead_B,
InBlockCopyDstDataPerWrite_N2,
WeiBlockCopySubLengths_E_K,
WeiBlockCopyClusterLengths_E_K,
WeiBlockCopyThreadClusterArrangeOrder,
WeiBlockCopySrcAccessOrder,
WeiBlockCopyDstAccessOrder,
WeiBlockCopySrcDataPerRead_E,
WeiBlockCopyDstDataPerWrite_K>{};

for(index_t i = 0; i < nrepeat; ++i)
{
float time = launch_kernel(run_gridwise_convolution_kernel<decltype(gridwise_conv), T>,
dim3(GridSize),
dim3(BlockSize),
...