Commit 674c405f authored by Chao Liu
Browse files

refactor

parent ffa7e4be
...@@ -17,29 +17,34 @@ map_convolution_into_gemm_v2(const WeiDesc& wei_k_c_y_x_global_desc, ...@@ -17,29 +17,34 @@ map_convolution_into_gemm_v2(const WeiDesc& wei_k_c_y_x_global_desc,
const Array<index_t, 2> in_left_pads, const Array<index_t, 2> in_left_pads,
const Array<index_t, 2> in_right_pads) const Array<index_t, 2> in_right_pads)
{ {
const index_t N = in_n_c_hi_wi_global_desc.GetLength(0); constexpr auto i0 = Number<0>{};
const index_t C = in_n_c_hi_wi_global_desc.GetLength(1); constexpr auto i1 = Number<1>{};
const index_t K = out_n_k_ho_wo_global_desc.GetLength(1); constexpr auto i2 = Number<2>{};
constexpr auto i3 = Number<3>{};
const index_t Y = wei_k_c_y_x_global_desc.GetLength(2); const index_t N = in_n_c_hi_wi_global_desc.GetLength(i0);
const index_t X = wei_k_c_y_x_global_desc.GetLength(3); const index_t C = in_n_c_hi_wi_global_desc.GetLength(i1);
const index_t K = out_n_k_ho_wo_global_desc.GetLength(i1);
const index_t Hi = in_n_c_hi_wi_global_desc.GetLength(2); const index_t Y = wei_k_c_y_x_global_desc.GetLength(i2);
const index_t Wi = in_n_c_hi_wi_global_desc.GetLength(3); const index_t X = wei_k_c_y_x_global_desc.GetLength(i3);
const index_t Ho = out_n_k_ho_wo_global_desc.GetLength(2); const index_t Hi = in_n_c_hi_wi_global_desc.GetLength(i2);
const index_t Wo = out_n_k_ho_wo_global_desc.GetLength(3); const index_t Wi = in_n_c_hi_wi_global_desc.GetLength(i3);
const index_t ConvStrideH = conv_strides[0]; const index_t Ho = out_n_k_ho_wo_global_desc.GetLength(i2);
const index_t ConvStrideW = conv_strides[1]; const index_t Wo = out_n_k_ho_wo_global_desc.GetLength(i3);
const index_t ConvDilationH = conv_dilations[0]; const index_t ConvStrideH = conv_strides[i0];
const index_t ConvDilationW = conv_dilations[1]; const index_t ConvStrideW = conv_strides[i1];
const index_t InLeftPadH = in_left_pads[0]; const index_t ConvDilationH = conv_dilations[i0];
const index_t InLeftPadW = in_left_pads[1]; const index_t ConvDilationW = conv_dilations[i1];
const index_t InRightPadH = in_right_pads[0];
const index_t InRightPadW = in_right_pads[1]; const index_t InLeftPadH = in_left_pads[i0];
const index_t InLeftPadW = in_left_pads[i1];
const index_t InRightPadH = in_right_pads[i0];
const index_t InRightPadW = in_right_pads[i1];
// input tensor // input tensor
const auto in_n_c_hip_wip_global_desc = transform_dynamic_tensor_descriptor_v2( const auto in_n_c_hip_wip_global_desc = transform_dynamic_tensor_descriptor_v2(
...@@ -58,8 +63,8 @@ map_convolution_into_gemm_v2(const WeiDesc& wei_k_c_y_x_global_desc, ...@@ -58,8 +63,8 @@ map_convolution_into_gemm_v2(const WeiDesc& wei_k_c_y_x_global_desc,
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}), make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{})); make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}));
const index_t Hip = in_n_c_hip_wip_global_desc.GetLength(2); const index_t Hip = in_n_c_hip_wip_global_desc.GetLength(i2);
const index_t Wip = in_n_c_hip_wip_global_desc.GetLength(3); const index_t Wip = in_n_c_hip_wip_global_desc.GetLength(i3);
const auto in_n_c_y_ho_x_wo_global_desc = transform_dynamic_tensor_descriptor_v2( const auto in_n_c_y_ho_x_wo_global_desc = transform_dynamic_tensor_descriptor_v2(
in_n_c_hip_wip_global_desc, in_n_c_hip_wip_global_desc,
......
...@@ -31,7 +31,7 @@ struct DynamicPassThrough ...@@ -31,7 +31,7 @@ struct DynamicPassThrough
static_assert(LowIdx::Size() == 1 && UpIdx::Size() == 1, static_assert(LowIdx::Size() == 1 && UpIdx::Size() == 1,
"wrong! inconsistent # of dimension"); "wrong! inconsistent # of dimension");
idx_low(0) = idx_up[0]; idx_low(Number<0>{}) = idx_up[Number<0>{}];
} }
template <typename LowIdxDiff, typename UpIdxDiff, typename LowIdx, typename UpIdx> template <typename LowIdxDiff, typename UpIdxDiff, typename LowIdx, typename UpIdx>
...@@ -44,7 +44,7 @@ struct DynamicPassThrough ...@@ -44,7 +44,7 @@ struct DynamicPassThrough
UpIdx::Size() == 1, UpIdx::Size() == 1,
"wrong! inconsistent # of dimension"); "wrong! inconsistent # of dimension");
idx_diff_low(0) = idx_diff_up[0]; idx_diff_low(Number<0>{}) = idx_diff_up[Number<0>{}];
} }
__host__ __device__ static constexpr bool IsLinearTransform() { return true; } __host__ __device__ static constexpr bool IsLinearTransform() { return true; }
...@@ -92,7 +92,7 @@ struct DynamicLeftPad ...@@ -92,7 +92,7 @@ struct DynamicLeftPad
static_assert(LowIdx::Size() == 1 && UpIdx::Size() == 1, static_assert(LowIdx::Size() == 1 && UpIdx::Size() == 1,
"wrong! inconsistent # of dimension"); "wrong! inconsistent # of dimension");
idx_low(0) = idx_up[0] - left_pad_; idx_low(Number<0>{}) = idx_up[Number<0>{}] - left_pad_;
} }
template <typename LowIdxDiff, typename UpIdxDiff, typename LowIdx, typename UpIdx> template <typename LowIdxDiff, typename UpIdxDiff, typename LowIdx, typename UpIdx>
...@@ -106,7 +106,7 @@ struct DynamicLeftPad ...@@ -106,7 +106,7 @@ struct DynamicLeftPad
UpIdx::Size() == 1, UpIdx::Size() == 1,
"wrong! inconsistent # of dimension"); "wrong! inconsistent # of dimension");
idx_diff_low(0) = idx_diff_up[0]; idx_diff_low(Number<0>{}) = idx_diff_up[Number<0>{}];
} }
__host__ __device__ static constexpr bool IsLinearTransform() { return true; } __host__ __device__ static constexpr bool IsLinearTransform() { return true; }
...@@ -120,7 +120,7 @@ struct DynamicLeftPad ...@@ -120,7 +120,7 @@ struct DynamicLeftPad
__host__ __device__ constexpr bool __host__ __device__ constexpr bool
IsValidUpperIndexMappedToValidLowerIndex(const UpIdx& idx_up) const IsValidUpperIndexMappedToValidLowerIndex(const UpIdx& idx_up) const
{ {
return SkipIsValidCheck || (idx_up[0] >= left_pad_); return SkipIsValidCheck || (idx_up[Number<0>{}] >= left_pad_);
} }
}; };
...@@ -158,7 +158,7 @@ struct DynamicRightPad ...@@ -158,7 +158,7 @@ struct DynamicRightPad
static_assert(LowIdx::Size() == 1 && UpIdx::Size() == 1, static_assert(LowIdx::Size() == 1 && UpIdx::Size() == 1,
"wrong! inconsistent # of dimension"); "wrong! inconsistent # of dimension");
idx_low(0) = idx_up[0]; idx_low(Number<0>{}) = idx_up[Number<0>{}];
} }
template <typename LowIdxDiff, typename UpIdxDiff, typename LowIdx, typename UpIdx> template <typename LowIdxDiff, typename UpIdxDiff, typename LowIdx, typename UpIdx>
...@@ -172,7 +172,7 @@ struct DynamicRightPad ...@@ -172,7 +172,7 @@ struct DynamicRightPad
UpIdx::Size() == 1, UpIdx::Size() == 1,
"wrong! inconsistent # of dimension"); "wrong! inconsistent # of dimension");
idx_diff_low(0) = idx_diff_up[0]; idx_diff_low(Number<0>{}) = idx_diff_up[Number<0>{}];
} }
__host__ __device__ static constexpr bool IsLinearTransform() { return true; } __host__ __device__ static constexpr bool IsLinearTransform() { return true; }
...@@ -186,7 +186,7 @@ struct DynamicRightPad ...@@ -186,7 +186,7 @@ struct DynamicRightPad
__host__ __device__ constexpr bool __host__ __device__ constexpr bool
IsValidUpperIndexMappedToValidLowerIndex(const UpIdx& idx_up) const IsValidUpperIndexMappedToValidLowerIndex(const UpIdx& idx_up) const
{ {
return SkipIsValidCheck || (idx_up[0] < low_length_); return SkipIsValidCheck || (idx_up[Number<0>{}] < low_length_);
} }
}; };
...@@ -228,13 +228,11 @@ struct DynamicEmbed ...@@ -228,13 +228,11 @@ struct DynamicEmbed
static_assert(LowIdx::Size() == 1 && UpIdx::Size() == NDimUp, static_assert(LowIdx::Size() == 1 && UpIdx::Size() == NDimUp,
"wrong! inconsistent # of dimension"); "wrong! inconsistent # of dimension");
idx_low(0) = coefficients_[NDimUp]; idx_low(Number<0>{}) = coefficients_[Number<NDimUp>{}];
#pragma unroll static_for<0, NDimUp, 1>{}([&idx_low, &idx_up, this](auto i) {
for(index_t i = 0; i < NDimUp; ++i) idx_low(Number<0>{}) += idx_up[i] * this->coefficients_[i];
{ });
idx_low(0) += idx_up[i] * coefficients_[i];
}
} }
template <typename LowIdxDiff, typename UpIdxDiff, typename LowIdx, typename UpIdx> template <typename LowIdxDiff, typename UpIdxDiff, typename LowIdx, typename UpIdx>
...@@ -247,13 +245,10 @@ struct DynamicEmbed ...@@ -247,13 +245,10 @@ struct DynamicEmbed
LowIdx::Size() == 1 && UpIdx::Size() == NDimUp, LowIdx::Size() == 1 && UpIdx::Size() == NDimUp,
"wrong! inconsistent # of dimension"); "wrong! inconsistent # of dimension");
idx_diff_low(0) = 0; idx_diff_low(Number<0>{}) = 0;
#pragma unroll static_for<0, NDimUp, 1>{}(
for(index_t i = 0; i < NDimUp; ++i) [&](auto i) { idx_diff_low(Number<0>{}) += idx_diff_up[i] * coefficients_[i]; });
{
idx_diff_low(0) += idx_diff_up[i] * coefficients_[i];
}
} }
__host__ __device__ static constexpr bool IsLinearTransform() { return true; } __host__ __device__ static constexpr bool IsLinearTransform() { return true; }
...@@ -310,16 +305,14 @@ struct DynamicMerge ...@@ -310,16 +305,14 @@ struct DynamicMerge
static_assert(LowIdx::Size() == NDimLow && UpIdx::Size() == 1, static_assert(LowIdx::Size() == NDimLow && UpIdx::Size() == 1,
"wrong! inconsistent # of dimension"); "wrong! inconsistent # of dimension");
index_t tmp = idx_up[0]; index_t tmp = idx_up[Number<0>{}];
#pragma unroll static_for<0, NDimLow - 1, 1>{}([&idx_low, &tmp, this](auto i) {
for(index_t i = 0; i < NDimLow - 1; ++i) idx_low(i) = tmp / this->low_lengths_scan_[i];
{ tmp -= idx_low[i] * this->low_lengths_scan_[i];
idx_low(i) = tmp / low_lengths_scan_[i]; });
tmp -= idx_low[i] * low_lengths_scan_[i];
}
idx_low(NDimLow - 1) = tmp; idx_low(Number<NDimLow - 1>{}) = tmp;
} }
// idx_diff_low depends on idx_low_old, so idx_low need to be up-to-date // idx_diff_low depends on idx_low_old, so idx_low need to be up-to-date
...@@ -336,15 +329,13 @@ struct DynamicMerge ...@@ -336,15 +329,13 @@ struct DynamicMerge
LowIdx::Size() == NDimLow && UpIdx::Size() == 1, LowIdx::Size() == NDimLow && UpIdx::Size() == 1,
"wrong! inconsistent # of dimension"); "wrong! inconsistent # of dimension");
#if 1 #if 0
// I only want to do this check, if idx_diff_up is know at compile-time // I only want to do this check, if idx_diff_up is know at compile-time
if(idx_diff_up[0] == 0) if(idx_diff_up[Number<0>{}] == 0)
{ {
#pragma unroll static_for<0, NDimLow, 1>{}([&idx_diff_low](auto i){
for(index_t i = 0; i < NDimLow; ++i)
{
idx_diff_low(i) = 0; idx_diff_low(i) = 0;
} });
return; return;
} }
...@@ -370,9 +361,7 @@ struct DynamicMerge ...@@ -370,9 +361,7 @@ struct DynamicMerge
// do not need to check the first dimension // do not need to check the first dimension
index_t carry = 0; index_t carry = 0;
#pragma unroll static_for<NDimLow - 1, 0, -1>{}([&](auto i) {
for(index_t i = NDimLow - 1; i > 0; --i)
{
// this should be saved in SGPR as well // this should be saved in SGPR as well
index_t idx_low_length_minus_idx_diff_low_const = index_t idx_low_length_minus_idx_diff_low_const =
low_lengths_[i] - idx_diff_low_const[i]; low_lengths_[i] - idx_diff_low_const[i];
...@@ -401,9 +390,9 @@ struct DynamicMerge ...@@ -401,9 +390,9 @@ struct DynamicMerge
#if 0 #if 0
carry = do_borrow ? -1 : carry; carry = do_borrow ? -1 : carry;
#endif #endif
} });
idx_diff_low(0) = idx_diff_low_const[0] + carry; idx_diff_low(Number<0>{}) = idx_diff_low_const[Number<0>{}] + carry;
} }
__host__ __device__ static constexpr bool IsLinearTransform() { return false; } __host__ __device__ static constexpr bool IsLinearTransform() { return false; }
...@@ -453,13 +442,10 @@ struct DynamicUnMerge ...@@ -453,13 +442,10 @@ struct DynamicUnMerge
__host__ __device__ constexpr void CalculateLowerIndex(LowIdx& idx_low, __host__ __device__ constexpr void CalculateLowerIndex(LowIdx& idx_low,
const UpIdx& idx_up) const const UpIdx& idx_up) const
{ {
idx_low(0) = idx_up[NDimUp]; idx_low(Number<0>{}) = idx_up[Number<NDimUp>{}];
#pragma unroll static_for<0, NDimUp - 1, 1>{}(
for(index_t i = 0; i < NDimUp - 1; ++i) [&](auto i) { idx_low(Number<0>{}) += idx_up[i] * up_lengths_scan_[i]; });
{
idx_low(0) += idx_up[i] * up_lengths_scan_[i];
}
} }
template <typename LowIdxDiff, typename UpIdxDiff, typename LowIdx, typename UpIdx> template <typename LowIdxDiff, typename UpIdxDiff, typename LowIdx, typename UpIdx>
...@@ -512,7 +498,7 @@ struct DynamicFreeze ...@@ -512,7 +498,7 @@ struct DynamicFreeze
static_assert(LowIdx::Size() == 1 && UpIdx::Size() == 1, static_assert(LowIdx::Size() == 1 && UpIdx::Size() == 1,
"wrong! inconsistent # of dimension"); "wrong! inconsistent # of dimension");
idx_low(0) = low_idx_; idx_low(Number<0>{}) = low_idx_;
} }
template <typename LowIdxDiff, typename UpIdxDiff, typename LowIdx, typename UpIdx> template <typename LowIdxDiff, typename UpIdxDiff, typename LowIdx, typename UpIdx>
...@@ -521,7 +507,7 @@ struct DynamicFreeze ...@@ -521,7 +507,7 @@ struct DynamicFreeze
const LowIdx& /* idx_low_old */, const LowIdx& /* idx_low_old */,
const UpIdx& /* idx_up_old */) const UpIdx& /* idx_up_old */)
{ {
idx_diff_low(0) = index_t{0}; idx_diff_low(Number<0>{}) = index_t{Number<0>{}};
} }
__host__ __device__ static constexpr bool IsLinearTransform() { return true; } __host__ __device__ static constexpr bool IsLinearTransform() { return true; }
......
...@@ -105,9 +105,10 @@ struct DynamicTensorDescriptor_v2 ...@@ -105,9 +105,10 @@ struct DynamicTensorDescriptor_v2
return GetNumOfVisibleDimension(); return GetNumOfVisibleDimension();
} }
__host__ __device__ constexpr index_t GetLength(index_t idim) const template <index_t IDim>
__host__ __device__ constexpr index_t GetLength(Number<IDim>) const
{ {
return visible_lengths_[idim]; return visible_lengths_[Number<IDim>{}];
} }
__host__ __device__ constexpr const auto& GetLengths() const { return visible_lengths_; } __host__ __device__ constexpr const auto& GetLengths() const { return visible_lengths_; }
......
...@@ -161,82 +161,29 @@ struct Array ...@@ -161,82 +161,29 @@ struct Array
} }
}; };
// Arr: Array template <typename X, typename... Xs>
// Picks: Sequence<...> __host__ __device__ constexpr auto make_array(const X& x, const Xs&... xs)
template <typename Arr, typename Picks>
struct ArrayElementPicker
{ {
using type = ArrayElementPicker; return Array<X, sizeof...(xs) + 1>{{x, xs...}};
using data_type = typename Arr::data_type; }
__host__ __device__ constexpr ArrayElementPicker() = delete; template <typename T>
__host__ __device__ constexpr auto to_array(const T& x)
__host__ __device__ explicit constexpr ArrayElementPicker(Arr& array) : mArray{array} {
{ Array<typename T::data_type, T::Size()> y;
constexpr index_t imax = reduce_on_sequence(Picks{}, math::maxer<index_t>{}, Number<0>{});
static_assert(imax < Arr::Size(), "wrong! exceeding # array element");
}
__host__ __device__ static constexpr auto Size() { return Picks::Size(); }
template <index_t I>
__host__ __device__ constexpr const data_type& At(Number<I>) const
{
static_assert(I < Size(), "wrong!");
constexpr auto IP = Picks{}[I];
return mArray[IP];
}
template <index_t I>
__host__ __device__ constexpr data_type& At(Number<I>)
{
static_assert(I < Size(), "wrong!");
constexpr auto IP = Picks{}[I];
return mArray(IP);
}
__host__ __device__ constexpr const data_type& operator[](index_t i) const
{
index_t ip = Picks{}[i];
return mArray[ip];
}
__host__ __device__ constexpr data_type& operator()(index_t i)
{
index_t ip = Picks{}[i];
return mArray(ip);
}
template <typename T>
__host__ __device__ constexpr auto operator=(const T& a)
{
static_for<0, Size(), 1>{}([&](auto i) { operator()(i) = a[i]; });
return *this;
}
template <typename T>
__host__ __device__ constexpr auto operator+=(const T& a)
{
static_for<0, Size(), 1>{}([&](auto i) { operator()(i) += a[i]; });
return *this;
}
template <typename T> static_for<0, T::Size(), 1>{}([&](auto i) { y.At(i) = x.At(i); });
__host__ __device__ constexpr auto operator-=(const T& a)
{
static_for<0, Size(), 1>{}([&](auto i) { operator()(i) -= a[i]; });
return *this; return y;
} }
private: template <typename TData, index_t NSize>
Arr& mArray; __host__ __device__ constexpr auto make_zero_array()
}; {
constexpr auto zero_sequence = typename uniform_sequence_gen<NSize, 0>::type{};
constexpr auto zero_array = to_array(zero_sequence);
return zero_array;
}
} // namespace ck } // namespace ck
#endif #endif
#ifndef CK_ARRAY_ELEMENT_PICKER_HPP
#define CK_ARRAY_ELEMENT_PICKER_HPP
#include "functional2.hpp"
#include "sequence.hpp"
namespace ck {
// A non-owning view over selected elements of an array-like container.
//
// Arr:   Array or StaticallyIndexedArray (indexed here through operator[] and operator())
// Picks: Sequence<...>; entry I is the index, inside the underlying container, of the
//        element that position I of this view refers to
//
// The picker only stores a reference to the container, so it must not outlive it.
template <typename Arr, typename Picks>
struct ArrayElementPicker
{
    using type      = ArrayElementPicker;
    using data_type = typename Arr::data_type;

    __host__ __device__ constexpr ArrayElementPicker() = delete;

    // Binds the view to `array`. The largest pick is checked against Arr::Size() at
    // compile time.
    __host__ __device__ explicit constexpr ArrayElementPicker(Arr& array) : mArray{array}
    {
        constexpr index_t imax = reduce_on_sequence(Picks{}, math::maxer<index_t>{}, Number<0>{});

        static_assert(imax < Arr::Size(), "wrong! exceeding # array element");
    }

    // Number of picked elements, i.e. the size of the view (not of the container).
    __host__ __device__ static constexpr auto Size() { return Picks::Size(); }

    // Read-only access to the I-th picked element.
    template <index_t I>
    __host__ __device__ constexpr const data_type& At(Number<I>) const
    {
        static_assert(I < Size(), "wrong!");

        // Look the pick up with a compile-time index so that IP can itself be a
        // compile-time index: a container whose operator[] only accepts Number<>
        // cannot be indexed with a plain index_t.
        // NOTE(review): assumes Sequence::operator[] returns a Number<> when given a
        // Number<>; if it returns index_t this behaves exactly like indexing with I.
        constexpr auto IP = Picks{}[Number<I>{}];

        return mArray[IP];
    }

    // Mutable access to the I-th picked element.
    template <index_t I>
    __host__ __device__ constexpr data_type& At(Number<I>)
    {
        static_assert(I < Size(), "wrong!");

        // See the const overload for why the lookup uses Number<I>.
        constexpr auto IP = Picks{}[Number<I>{}];

        return mArray(IP);
    }

    // operator[] is the read-only spelling of At().
    template <index_t I>
    __host__ __device__ constexpr const auto& operator[](Number<I> i) const
    {
        return At(i);
    }

    // operator() is the mutable spelling of At().
    template <index_t I>
    __host__ __device__ constexpr auto& operator()(Number<I> i)
    {
        return At(i);
    }

    // Element-wise assignment from any container `a` with the same static Size().
    // Writes through to the underlying container and returns this picker by value
    // (the copy refers to the same container).
    template <typename T>
    __host__ __device__ constexpr auto operator=(const T& a)
    {
        static_assert(T::Size() == Size(), "wrong! size not the same");

        static_for<0, Size(), 1>{}([&](auto i) { operator()(i) = a[i]; });

        return *this;
    }

    private:
    Arr& mArray; // the viewed container; not owned
};
// Element-wise in-place addition through a picker: y(i) += x[i] for every picked
// element. x must expose a static Size() equal to the number of picked elements.
// The picker is returned by value.
template <typename Arr, typename Picks, typename X>
__host__ __device__ constexpr auto operator+=(ArrayElementPicker<Arr, Picks>& y, const X& x)
{
    constexpr index_t num_picked = ArrayElementPicker<Arr, Picks>::Size();

    static_assert(num_picked == X::Size(), "wrong! size not the same");

    static_for<0, num_picked, 1>{}([&x, &y](auto idx) { y(idx) += x[idx]; });

    return y;
}
// Element-wise in-place subtraction through a picker: y(i) -= x[i] for every picked
// element. x must expose a static Size() equal to the number of picked elements.
// The picker is returned by value.
template <typename Arr, typename Picks, typename X>
__host__ __device__ constexpr auto operator-=(ArrayElementPicker<Arr, Picks>& y, const X& x)
{
    constexpr index_t num_picked = ArrayElementPicker<Arr, Picks>::Size();

    static_assert(num_picked == X::Size(), "wrong! size not the same");

    static_for<0, num_picked, 1>{}([&x, &y](auto idx) { y(idx) -= x[idx]; });

    return y;
}
} // namespace ck
#endif
...@@ -2,39 +2,17 @@ ...@@ -2,39 +2,17 @@
#define CK_ARRAY_HELPER_HPP #define CK_ARRAY_HELPER_HPP
#include "array.hpp" #include "array.hpp"
#include "statically_indexed_array.hpp"
#include "array_element_picker.hpp"
namespace ck { namespace ck {
// Builds an Array from the given values. The element type is the type of the first
// argument and the size is the number of arguments.
template <typename X, typename... Xs>
__host__ __device__ constexpr auto make_array(const X& x, const Xs&... xs)
{
    constexpr index_t num_elements = sizeof...(Xs) + 1;

    return Array<X, num_elements>{{x, xs...}};
}
template <typename Arr, typename Picks> template <typename Arr, typename Picks>
__host__ __device__ constexpr auto pick_array_element(Arr& a, Picks) __host__ __device__ constexpr auto pick_array_element(Arr& a, Picks)
{ {
return ArrayElementPicker<Arr, Picks>(a); return ArrayElementPicker<Arr, Picks>(a);
} }
// Copies every element of x into a freshly created Array.
// T must provide data_type, a static Size() and At(); the result is
// Array<typename T::data_type, T::Size()>.
template <typename T>
__host__ __device__ constexpr auto to_array(const T& x)
{
    using element_type = typename T::data_type;

    constexpr index_t count = T::Size();

    Array<element_type, count> result;

    static_for<0, count, 1>{}([&](auto idx) { result.At(idx) = x.At(idx); });

    return result;
}
// Returns the Array produced by to_array() from uniform_sequence_gen<NSize, 0>::type
// (presumably a sequence of NSize zeros -- confirm against uniform_sequence_gen).
// NOTE(review): TData is never used in the body; the element type of the result is
// whatever to_array() derives from the sequence's data_type. Confirm this is intended.
template <typename TData, index_t NSize>
__host__ __device__ constexpr auto make_zero_array()
{
    constexpr auto zeros = to_array(typename uniform_sequence_gen<NSize, 0>::type{});

    return zeros;
}
template <typename TData, index_t NSize, index_t... IRs> template <typename TData, index_t NSize, index_t... IRs>
__host__ __device__ constexpr auto reorder_array_given_new2old(const Array<TData, NSize>& old_array, __host__ __device__ constexpr auto reorder_array_given_new2old(const Array<TData, NSize>& old_array,
Sequence<IRs...> /*new2old*/) Sequence<IRs...> /*new2old*/)
......
...@@ -3,6 +3,8 @@ ...@@ -3,6 +3,8 @@
#include "array.hpp" #include "array.hpp"
#include "array_helper.hpp" #include "array_helper.hpp"
#include "statically_indexed_array.hpp"
#include "array_element_picker.hpp"
#include "config.hpp" #include "config.hpp"
#include "float_type.hpp" #include "float_type.hpp"
#include "functional.hpp" #include "functional.hpp"
......
...@@ -15,195 +15,15 @@ __host__ __device__ void print_array(const char* s, T a) ...@@ -15,195 +15,15 @@ __host__ __device__ void print_array(const char* s, T a)
if constexpr(is_same<data_type, uint32_t>{}) if constexpr(is_same<data_type, uint32_t>{})
{ {
if constexpr(nsize == 0) printf("%s size %u, {", s, nsize);
{ static_for<0, nsize, 1>{}([&a](auto i) constexpr { printf("%u, ", a[i]); });
printf("%s size %u\n", s, nsize); printf("}\n");
}
else if constexpr(nsize == 1)
{
printf("%s size %u, {%u}\n", s, nsize, a[0]);
}
else if constexpr(nsize == 2)
{
printf("%s size %u, {%u %u}\n", s, nsize, a[0], a[1]);
}
else if constexpr(nsize == 3)
{
printf("%s size %u, {%u %u %u}\n", s, nsize, a[0], a[1], a[2]);
}
else if constexpr(nsize == 4)
{
printf("%s size %u, {%u %u %u %u}\n", s, nsize, a[0], a[1], a[2], a[3]);
}
else if constexpr(nsize == 5)
{
printf("%s size %u, {%u %u %u %u %u}\n", s, nsize, a[0], a[1], a[2], a[3], a[4]);
}
else if constexpr(nsize == 6)
{
printf(
"%s size %u, {%u %u %u %u %u %u}\n", s, nsize, a[0], a[1], a[2], a[3], a[4], a[5]);
}
else if constexpr(nsize == 7)
{
printf("%s size %u, {%u %u %u %u %u %u %u}\n",
s,
nsize,
a[0],
a[1],
a[2],
a[3],
a[4],
a[5],
a[6]);
}
else if constexpr(nsize == 8)
{
printf("%s size %u, {%u %u %u %u %u %u %u %u}\n",
s,
nsize,
a[0],
a[1],
a[2],
a[3],
a[4],
a[5],
a[6],
a[7]);
}
else if constexpr(nsize == 9)
{
printf("%s size %u, {%u %u %u %u %u %u %u %u %u}\n",
s,
nsize,
a[0],
a[1],
a[2],
a[3],
a[4],
a[5],
a[6],
a[7],
a[8]);
}
else if constexpr(nsize == 10)
{
printf("%s size %u, {%u %u %u %u %u %u %u %u %u %u}\n",
s,
nsize,
a[0],
a[1],
a[2],
a[3],
a[4],
a[5],
a[6],
a[7],
a[8],
a[9]);
}
else
{
printf("%s size %u, {", s, nsize);
static_for<0, nsize, 1>{}([&a](auto i) constexpr { printf("%u, ", a[i]); });
printf("}\n");
}
} }
else if constexpr(is_same<data_type, int32_t>{}) else if constexpr(is_same<data_type, int32_t>{})
{ {
if constexpr(nsize == 0) printf("%s size %d, {", s, nsize);
{ static_for<0, nsize, 1>{}([&a](auto i) constexpr { printf("%d, ", a[i]); });
printf("%s size %d\n", s, nsize); printf("}\n");
}
else if constexpr(nsize == 1)
{
printf("%s size %d, {%d}\n", s, nsize, a[0]);
}
else if constexpr(nsize == 2)
{
printf("%s size %d, {%d %d}\n", s, nsize, a[0], a[1]);
}
else if constexpr(nsize == 3)
{
printf("%s size %d, {%d %d %d}\n", s, nsize, a[0], a[1], a[2]);
}
else if constexpr(nsize == 4)
{
printf("%s size %d, {%d %d %d %d}\n", s, nsize, a[0], a[1], a[2], a[3]);
}
else if constexpr(nsize == 5)
{
printf("%s size %d, {%d %d %d %d %d}\n", s, nsize, a[0], a[1], a[2], a[3], a[4]);
}
else if constexpr(nsize == 6)
{
printf(
"%s size %d, {%d %d %d %d %d %d}\n", s, nsize, a[0], a[1], a[2], a[3], a[4], a[5]);
}
else if constexpr(nsize == 7)
{
printf("%s size %d, {%d %d %d %d %d %d %d}\n",
s,
nsize,
a[0],
a[1],
a[2],
a[3],
a[4],
a[5],
a[6]);
}
else if constexpr(nsize == 8)
{
printf("%s size %d, {%d %d %d %d %d %d %d %d}\n",
s,
nsize,
a[0],
a[1],
a[2],
a[3],
a[4],
a[5],
a[6],
a[7]);
}
else if constexpr(nsize == 9)
{
printf("%s size %d, {%d %d %d %d %d %d %d %d %d}\n",
s,
nsize,
a[0],
a[1],
a[2],
a[3],
a[4],
a[5],
a[6],
a[7],
a[8]);
}
else if constexpr(nsize == 10)
{
printf("%s size %d, {%d %d %d %d %d %d %d %d %d %d}\n",
s,
nsize,
a[0],
a[1],
a[2],
a[3],
a[4],
a[5],
a[6],
a[7],
a[8],
a[9]);
}
else
{
printf("%s size %d, {", s, nsize);
static_for<0, nsize, 1>{}([&a](auto i) constexpr { printf("%d, ", a[i]); });
printf("}\n");
}
} }
} }
......
#ifndef CK_STATICALLY_INDEXED_ARRAY_HPP
#define CK_STATICALLY_INDEXED_ARRAY_HPP
#include "functional2.hpp"
#include "sequence.hpp"
#include "tuple.hpp"
namespace ck {
template <typename TData, index_t NSize>
struct StaticallyIndexedArray
{
};
template <typename TData>
struct StaticallyIndexedArray<TData, 0> : Tuple<>
{
using data_type = TData;
};
template <typename TData>
struct StaticallyIndexedArray<TData, 1> : Tuple<TData>
{
using data_type = TData;
};
template <typename TData>
struct StaticallyIndexedArray<TData, 2> : Tuple<TData, TData>
{
using data_type = TData;
};
template <typename TData>
struct StaticallyIndexedArray<TData, 3> : Tuple<TData, TData, TData>
{
using data_type = TData;
};
template <typename TData>
struct StaticallyIndexedArray<TData, 4> : Tuple<TData, TData, TData, TData>
{
using data_type = TData;
};
template <typename TData>
struct StaticallyIndexedArray<TData, 5> : Tuple<TData, TData, TData, TData, TData>
{
using data_type = TData;
};
template <typename TData>
struct StaticallyIndexedArray<TData, 6> : Tuple<TData, TData, TData, TData, TData, TData>
{
using data_type = TData;
};
template <typename TData>
struct StaticallyIndexedArray<TData, 7> : Tuple<TData, TData, TData, TData, TData, TData, TData>
{
using data_type = TData;
};
template <typename TData>
struct StaticallyIndexedArray<TData, 8>
: Tuple<TData, TData, TData, TData, TData, TData, TData, TData>
{
using data_type = TData;
};
template <typename TData>
struct StaticallyIndexedArray<TData, 9>
: Tuple<TData, TData, TData, TData, TData, TData, TData, TData, TData>
{
using data_type = TData;
};
template <typename TData>
struct StaticallyIndexedArray<TData, 10>
: Tuple<TData, TData, TData, TData, TData, TData, TData, TData, TData, TData>
{
using data_type = TData;
};
template <typename TData>
struct StaticallyIndexedArray<TData, 11>
: Tuple<TData, TData, TData, TData, TData, TData, TData, TData, TData, TData, TData>
{
using data_type = TData;
};
template <typename TData>
struct StaticallyIndexedArray<TData, 12>
: Tuple<TData, TData, TData, TData, TData, TData, TData, TData, TData, TData, TData, TData>
{
using data_type = TData;
};
template <typename TData>
struct StaticallyIndexedArray<TData, 13> : Tuple<TData,
TData,
TData,
TData,
TData,
TData,
TData,
TData,
TData,
TData,
TData,
TData,
TData>
{
using data_type = TData;
};
template <typename TData>
struct StaticallyIndexedArray<TData, 14> : Tuple<TData,
TData,
TData,
TData,
TData,
TData,
TData,
TData,
TData,
TData,
TData,
TData,
TData,
TData>
{
using data_type = TData;
};
template <typename TData>
struct StaticallyIndexedArray<TData, 15> : Tuple<TData,
TData,
TData,
TData,
TData,
TData,
TData,
TData,
TData,
TData,
TData,
TData,
TData,
TData,
TData>
{
using data_type = TData;
};
template <typename TData>
struct StaticallyIndexedArray<TData, 16> : Tuple<TData,
TData,
TData,
TData,
TData,
TData,
TData,
TData,
TData,
TData,
TData,
TData,
TData,
TData,
TData,
TData>
{
using data_type = TData;
};
template <typename TData>
struct StaticallyIndexedArray<TData, 17> : Tuple<TData,
TData,
TData,
TData,
TData,
TData,
TData,
TData,
TData,
TData,
TData,
TData,
TData,
TData,
TData,
TData,
TData>
{
using data_type = TData;
};
template <typename TData>
struct StaticallyIndexedArray<TData, 18> : Tuple<TData,
TData,
TData,
TData,
TData,
TData,
TData,
TData,
TData,
TData,
TData,
TData,
TData,
TData,
TData,
TData,
TData,
TData>
{
using data_type = TData;
};
template <typename TData>
struct StaticallyIndexedArray<TData, 19> : Tuple<TData,
TData,
TData,
TData,
TData,
TData,
TData,
TData,
TData,
TData,
TData,
TData,
TData,
TData,
TData,
TData,
TData,
TData,
TData>
{
using data_type = TData;
};
template <typename TData>
struct StaticallyIndexedArray<TData, 20> : Tuple<TData,
TData,
TData,
TData,
TData,
TData,
TData,
TData,
TData,
TData,
TData,
TData,
TData,
TData,
TData,
TData,
TData,
TData,
TData,
TData>
{
using data_type = TData;
};
template <typename TData>
struct StaticallyIndexedArray<TData, 21> : Tuple<TData,
TData,
TData,
TData,
TData,
TData,
TData,
TData,
TData,
TData,
TData,
TData,
TData,
TData,
TData,
TData,
TData,
TData,
TData,
TData,
TData>
{
using data_type = TData;
};
template <typename TData>
struct StaticallyIndexedArray<TData, 22> : Tuple<TData,
TData,
TData,
TData,
TData,
TData,
TData,
TData,
TData,
TData,
TData,
TData,
TData,
TData,
TData,
TData,
TData,
TData,
TData,
TData,
TData,
TData>
{
using data_type = TData;
};
// Element-wise in-place addition: y(i) += x[i] for every i in [0, NSize).
// x may be any container whose static Size() equals NSize and that can be indexed
// with the index objects static_for passes to its functor.
// Returns a reference to y (not a copy), so no array is copied and a chained
// compound assignment keeps operating on y itself.
template <typename TData, index_t NSize, typename X>
__host__ __device__ constexpr auto& operator+=(StaticallyIndexedArray<TData, NSize>& y, const X& x)
{
    static_assert(X::Size() == NSize, "wrong! size not the same");

    static_for<0, NSize, 1>{}([&](auto i) { y(i) += x[i]; });

    return y;
}
// Element-wise in-place subtraction: y(i) -= x[i] for every i in [0, NSize).
// x may be any container whose static Size() equals NSize and that can be indexed
// with the index objects static_for passes to its functor.
// Returns a reference to y (not a copy), so no array is copied and a chained
// compound assignment keeps operating on y itself.
template <typename TData, index_t NSize, typename X>
__host__ __device__ constexpr auto& operator-=(StaticallyIndexedArray<TData, NSize>& y, const X& x)
{
    static_assert(X::Size() == NSize, "wrong! size not the same");

    static_for<0, NSize, 1>{}([&](auto i) { y(i) -= x[i]; });

    return y;
}
// Element-wise sum: returns a new StaticallyIndexedArray whose i-th element is
// a[i] + b[i]. b must expose a static Size() equal to NSize.
template <typename TData, index_t NSize, typename T>
__host__ __device__ constexpr auto operator+(const StaticallyIndexedArray<TData, NSize>& a,
                                             const T& b)
{
    static_assert(T::Size() == NSize, "wrong! size not the same");

    StaticallyIndexedArray<TData, NSize> sum;

    static_for<0, NSize, 1>{}([&a, &b, &sum](auto idx) { sum(idx) = a[idx] + b[idx]; });

    return sum;
}
// Element-wise difference: returns a new StaticallyIndexedArray whose i-th element is
// a[i] - b[i]. b must expose a static Size() equal to NSize.
template <typename TData, index_t NSize, typename T>
__host__ __device__ constexpr auto operator-(const StaticallyIndexedArray<TData, NSize>& a,
                                             const T& b)
{
    static_assert(T::Size() == NSize, "wrong! size not the same");

    StaticallyIndexedArray<TData, NSize> difference;

    static_for<0, NSize, 1>{}(
        [&a, &b, &difference](auto idx) { difference(idx) = a[idx] - b[idx]; });

    return difference;
}
} // namespace ck
#endif
...@@ -89,6 +89,8 @@ struct Tuple : detail::TupleImpl<typename arithmetic_sequence_gen<0, sizeof...(X ...@@ -89,6 +89,8 @@ struct Tuple : detail::TupleImpl<typename arithmetic_sequence_gen<0, sizeof...(X
{ {
} }
__host__ __device__ static constexpr index_t Size() { return sizeof...(Xs); }
template <index_t I> template <index_t I>
__host__ __device__ constexpr const auto& At(Number<I>) const __host__ __device__ constexpr const auto& At(Number<I>) const
{ {
...@@ -102,6 +104,28 @@ struct Tuple : detail::TupleImpl<typename arithmetic_sequence_gen<0, sizeof...(X ...@@ -102,6 +104,28 @@ struct Tuple : detail::TupleImpl<typename arithmetic_sequence_gen<0, sizeof...(X
static_assert(I < base::Size(), "wrong! out of range"); static_assert(I < base::Size(), "wrong! out of range");
return base::GetElementByKey(detail::TupleElementKey<I>{}); return base::GetElementByKey(detail::TupleElementKey<I>{});
} }
// Read-only element access by compile-time index; forwards to the const At().
template <index_t I>
__host__ __device__ constexpr const auto& operator[](Number<I>) const
{
    return At(Number<I>{});
}
// Mutable element access by compile-time index; forwards to the non-const At().
template <index_t I>
__host__ __device__ constexpr auto& operator()(Number<I>)
{
    return At(Number<I>{});
}
// Element-wise assignment from any container `a` with the same static Size():
// element i of this tuple is assigned a[i].
// Returns a reference to this tuple rather than a copy, as is conventional for
// assignment operators, so the whole tuple is not copied on every assignment.
template <typename T>
__host__ __device__ constexpr auto& operator=(const T& a)
{
    static_assert(T::Size() == Size(), "wrong! size not the same");

    static_for<0, Size(), 1>{}([&](auto i) { operator()(i) = a[i]; });

    return *this;
}
}; };
} // namespace ck } // namespace ck
......
conv_bwd_data_driver.cpp
\ No newline at end of file
...@@ -561,7 +561,7 @@ int main(int argc, char* argv[]) ...@@ -561,7 +561,7 @@ int main(int argc, char* argv[])
LeftPads{}, LeftPads{},
RightPads{}, RightPads{},
nrepeat); nrepeat);
#elif 0 #elif 1
device_convolution_implicit_gemm_v4r4_nchw_kcyx_nkhw(in_nchw_desc, device_convolution_implicit_gemm_v4r4_nchw_kcyx_nkhw(in_nchw_desc,
in_nchw, in_nchw,
wei_kcyx_desc, wei_kcyx_desc,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment