Commit 674c405f authored by Chao Liu's avatar Chao Liu
Browse files

refactor

parent ffa7e4be
......@@ -17,29 +17,34 @@ map_convolution_into_gemm_v2(const WeiDesc& wei_k_c_y_x_global_desc,
const Array<index_t, 2> in_left_pads,
const Array<index_t, 2> in_right_pads)
{
const index_t N = in_n_c_hi_wi_global_desc.GetLength(0);
const index_t C = in_n_c_hi_wi_global_desc.GetLength(1);
const index_t K = out_n_k_ho_wo_global_desc.GetLength(1);
constexpr auto i0 = Number<0>{};
constexpr auto i1 = Number<1>{};
constexpr auto i2 = Number<2>{};
constexpr auto i3 = Number<3>{};
const index_t Y = wei_k_c_y_x_global_desc.GetLength(2);
const index_t X = wei_k_c_y_x_global_desc.GetLength(3);
const index_t N = in_n_c_hi_wi_global_desc.GetLength(i0);
const index_t C = in_n_c_hi_wi_global_desc.GetLength(i1);
const index_t K = out_n_k_ho_wo_global_desc.GetLength(i1);
const index_t Hi = in_n_c_hi_wi_global_desc.GetLength(2);
const index_t Wi = in_n_c_hi_wi_global_desc.GetLength(3);
const index_t Y = wei_k_c_y_x_global_desc.GetLength(i2);
const index_t X = wei_k_c_y_x_global_desc.GetLength(i3);
const index_t Ho = out_n_k_ho_wo_global_desc.GetLength(2);
const index_t Wo = out_n_k_ho_wo_global_desc.GetLength(3);
const index_t Hi = in_n_c_hi_wi_global_desc.GetLength(i2);
const index_t Wi = in_n_c_hi_wi_global_desc.GetLength(i3);
const index_t ConvStrideH = conv_strides[0];
const index_t ConvStrideW = conv_strides[1];
const index_t Ho = out_n_k_ho_wo_global_desc.GetLength(i2);
const index_t Wo = out_n_k_ho_wo_global_desc.GetLength(i3);
const index_t ConvDilationH = conv_dilations[0];
const index_t ConvDilationW = conv_dilations[1];
const index_t ConvStrideH = conv_strides[i0];
const index_t ConvStrideW = conv_strides[i1];
const index_t InLeftPadH = in_left_pads[0];
const index_t InLeftPadW = in_left_pads[1];
const index_t InRightPadH = in_right_pads[0];
const index_t InRightPadW = in_right_pads[1];
const index_t ConvDilationH = conv_dilations[i0];
const index_t ConvDilationW = conv_dilations[i1];
const index_t InLeftPadH = in_left_pads[i0];
const index_t InLeftPadW = in_left_pads[i1];
const index_t InRightPadH = in_right_pads[i0];
const index_t InRightPadW = in_right_pads[i1];
// input tensor
const auto in_n_c_hip_wip_global_desc = transform_dynamic_tensor_descriptor_v2(
......@@ -58,8 +63,8 @@ map_convolution_into_gemm_v2(const WeiDesc& wei_k_c_y_x_global_desc,
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}));
const index_t Hip = in_n_c_hip_wip_global_desc.GetLength(2);
const index_t Wip = in_n_c_hip_wip_global_desc.GetLength(3);
const index_t Hip = in_n_c_hip_wip_global_desc.GetLength(i2);
const index_t Wip = in_n_c_hip_wip_global_desc.GetLength(i3);
const auto in_n_c_y_ho_x_wo_global_desc = transform_dynamic_tensor_descriptor_v2(
in_n_c_hip_wip_global_desc,
......
......@@ -31,7 +31,7 @@ struct DynamicPassThrough
static_assert(LowIdx::Size() == 1 && UpIdx::Size() == 1,
"wrong! inconsistent # of dimension");
idx_low(0) = idx_up[0];
idx_low(Number<0>{}) = idx_up[Number<0>{}];
}
template <typename LowIdxDiff, typename UpIdxDiff, typename LowIdx, typename UpIdx>
......@@ -44,7 +44,7 @@ struct DynamicPassThrough
UpIdx::Size() == 1,
"wrong! inconsistent # of dimension");
idx_diff_low(0) = idx_diff_up[0];
idx_diff_low(Number<0>{}) = idx_diff_up[Number<0>{}];
}
__host__ __device__ static constexpr bool IsLinearTransform() { return true; }
......@@ -92,7 +92,7 @@ struct DynamicLeftPad
static_assert(LowIdx::Size() == 1 && UpIdx::Size() == 1,
"wrong! inconsistent # of dimension");
idx_low(0) = idx_up[0] - left_pad_;
idx_low(Number<0>{}) = idx_up[Number<0>{}] - left_pad_;
}
template <typename LowIdxDiff, typename UpIdxDiff, typename LowIdx, typename UpIdx>
......@@ -106,7 +106,7 @@ struct DynamicLeftPad
UpIdx::Size() == 1,
"wrong! inconsistent # of dimension");
idx_diff_low(0) = idx_diff_up[0];
idx_diff_low(Number<0>{}) = idx_diff_up[Number<0>{}];
}
__host__ __device__ static constexpr bool IsLinearTransform() { return true; }
......@@ -120,7 +120,7 @@ struct DynamicLeftPad
__host__ __device__ constexpr bool
IsValidUpperIndexMappedToValidLowerIndex(const UpIdx& idx_up) const
{
return SkipIsValidCheck || (idx_up[0] >= left_pad_);
return SkipIsValidCheck || (idx_up[Number<0>{}] >= left_pad_);
}
};
......@@ -158,7 +158,7 @@ struct DynamicRightPad
static_assert(LowIdx::Size() == 1 && UpIdx::Size() == 1,
"wrong! inconsistent # of dimension");
idx_low(0) = idx_up[0];
idx_low(Number<0>{}) = idx_up[Number<0>{}];
}
template <typename LowIdxDiff, typename UpIdxDiff, typename LowIdx, typename UpIdx>
......@@ -172,7 +172,7 @@ struct DynamicRightPad
UpIdx::Size() == 1,
"wrong! inconsistent # of dimension");
idx_diff_low(0) = idx_diff_up[0];
idx_diff_low(Number<0>{}) = idx_diff_up[Number<0>{}];
}
__host__ __device__ static constexpr bool IsLinearTransform() { return true; }
......@@ -186,7 +186,7 @@ struct DynamicRightPad
__host__ __device__ constexpr bool
IsValidUpperIndexMappedToValidLowerIndex(const UpIdx& idx_up) const
{
return SkipIsValidCheck || (idx_up[0] < low_length_);
return SkipIsValidCheck || (idx_up[Number<0>{}] < low_length_);
}
};
......@@ -228,13 +228,11 @@ struct DynamicEmbed
static_assert(LowIdx::Size() == 1 && UpIdx::Size() == NDimUp,
"wrong! inconsistent # of dimension");
idx_low(0) = coefficients_[NDimUp];
idx_low(Number<0>{}) = coefficients_[Number<NDimUp>{}];
#pragma unroll
for(index_t i = 0; i < NDimUp; ++i)
{
idx_low(0) += idx_up[i] * coefficients_[i];
}
static_for<0, NDimUp, 1>{}([&idx_low, &idx_up, this](auto i) {
idx_low(Number<0>{}) += idx_up[i] * this->coefficients_[i];
});
}
template <typename LowIdxDiff, typename UpIdxDiff, typename LowIdx, typename UpIdx>
......@@ -247,13 +245,10 @@ struct DynamicEmbed
LowIdx::Size() == 1 && UpIdx::Size() == NDimUp,
"wrong! inconsistent # of dimension");
idx_diff_low(0) = 0;
idx_diff_low(Number<0>{}) = 0;
#pragma unroll
for(index_t i = 0; i < NDimUp; ++i)
{
idx_diff_low(0) += idx_diff_up[i] * coefficients_[i];
}
static_for<0, NDimUp, 1>{}(
[&](auto i) { idx_diff_low(Number<0>{}) += idx_diff_up[i] * coefficients_[i]; });
}
__host__ __device__ static constexpr bool IsLinearTransform() { return true; }
......@@ -310,16 +305,14 @@ struct DynamicMerge
static_assert(LowIdx::Size() == NDimLow && UpIdx::Size() == 1,
"wrong! inconsistent # of dimension");
index_t tmp = idx_up[0];
index_t tmp = idx_up[Number<0>{}];
#pragma unroll
for(index_t i = 0; i < NDimLow - 1; ++i)
{
idx_low(i) = tmp / low_lengths_scan_[i];
tmp -= idx_low[i] * low_lengths_scan_[i];
}
static_for<0, NDimLow - 1, 1>{}([&idx_low, &tmp, this](auto i) {
idx_low(i) = tmp / this->low_lengths_scan_[i];
tmp -= idx_low[i] * this->low_lengths_scan_[i];
});
idx_low(NDimLow - 1) = tmp;
idx_low(Number<NDimLow - 1>{}) = tmp;
}
// idx_diff_low depends on idx_low_old, so idx_low need to be up-to-date
......@@ -336,15 +329,13 @@ struct DynamicMerge
LowIdx::Size() == NDimLow && UpIdx::Size() == 1,
"wrong! inconsistent # of dimension");
#if 1
#if 0
// I only want to do this check, if idx_diff_up is know at compile-time
if(idx_diff_up[0] == 0)
{
#pragma unroll
for(index_t i = 0; i < NDimLow; ++i)
if(idx_diff_up[Number<0>{}] == 0)
{
static_for<0, NDimLow, 1>{}([&idx_diff_low](auto i){
idx_diff_low(i) = 0;
}
});
return;
}
......@@ -370,9 +361,7 @@ struct DynamicMerge
// do not need to check the first dimension
index_t carry = 0;
#pragma unroll
for(index_t i = NDimLow - 1; i > 0; --i)
{
static_for<NDimLow - 1, 0, -1>{}([&](auto i) {
// this should be saved in SGPR as well
index_t idx_low_length_minus_idx_diff_low_const =
low_lengths_[i] - idx_diff_low_const[i];
......@@ -401,9 +390,9 @@ struct DynamicMerge
#if 0
carry = do_borrow ? -1 : carry;
#endif
}
});
idx_diff_low(0) = idx_diff_low_const[0] + carry;
idx_diff_low(Number<0>{}) = idx_diff_low_const[Number<0>{}] + carry;
}
__host__ __device__ static constexpr bool IsLinearTransform() { return false; }
......@@ -453,13 +442,10 @@ struct DynamicUnMerge
__host__ __device__ constexpr void CalculateLowerIndex(LowIdx& idx_low,
const UpIdx& idx_up) const
{
idx_low(0) = idx_up[NDimUp];
idx_low(Number<0>{}) = idx_up[Number<NDimUp>{}];
#pragma unroll
for(index_t i = 0; i < NDimUp - 1; ++i)
{
idx_low(0) += idx_up[i] * up_lengths_scan_[i];
}
static_for<0, NDimUp - 1, 1>{}(
[&](auto i) { idx_low(Number<0>{}) += idx_up[i] * up_lengths_scan_[i]; });
}
template <typename LowIdxDiff, typename UpIdxDiff, typename LowIdx, typename UpIdx>
......@@ -512,7 +498,7 @@ struct DynamicFreeze
static_assert(LowIdx::Size() == 1 && UpIdx::Size() == 1,
"wrong! inconsistent # of dimension");
idx_low(0) = low_idx_;
idx_low(Number<0>{}) = low_idx_;
}
template <typename LowIdxDiff, typename UpIdxDiff, typename LowIdx, typename UpIdx>
......@@ -521,7 +507,7 @@ struct DynamicFreeze
const LowIdx& /* idx_low_old */,
const UpIdx& /* idx_up_old */)
{
idx_diff_low(0) = index_t{0};
idx_diff_low(Number<0>{}) = index_t{Number<0>{}};
}
__host__ __device__ static constexpr bool IsLinearTransform() { return true; }
......
......@@ -105,9 +105,10 @@ struct DynamicTensorDescriptor_v2
return GetNumOfVisibleDimension();
}
__host__ __device__ constexpr index_t GetLength(index_t idim) const
template <index_t IDim>
__host__ __device__ constexpr index_t GetLength(Number<IDim>) const
{
return visible_lengths_[idim];
return visible_lengths_[Number<IDim>{}];
}
__host__ __device__ constexpr const auto& GetLengths() const { return visible_lengths_; }
......
......@@ -161,82 +161,29 @@ struct Array
}
};
// Arr: Array
// Picks: Sequence<...>
template <typename Arr, typename Picks>
struct ArrayElementPicker
template <typename X, typename... Xs>
__host__ __device__ constexpr auto make_array(const X& x, const Xs&... xs)
{
using type = ArrayElementPicker;
using data_type = typename Arr::data_type;
return Array<X, sizeof...(xs) + 1>{{x, xs...}};
}
__host__ __device__ constexpr ArrayElementPicker() = delete;
__host__ __device__ explicit constexpr ArrayElementPicker(Arr& array) : mArray{array}
{
constexpr index_t imax = reduce_on_sequence(Picks{}, math::maxer<index_t>{}, Number<0>{});
static_assert(imax < Arr::Size(), "wrong! exceeding # array element");
}
__host__ __device__ static constexpr auto Size() { return Picks::Size(); }
template <index_t I>
__host__ __device__ constexpr const data_type& At(Number<I>) const
{
static_assert(I < Size(), "wrong!");
constexpr auto IP = Picks{}[I];
return mArray[IP];
}
template <index_t I>
__host__ __device__ constexpr data_type& At(Number<I>)
{
static_assert(I < Size(), "wrong!");
constexpr auto IP = Picks{}[I];
return mArray(IP);
}
__host__ __device__ constexpr const data_type& operator[](index_t i) const
{
index_t ip = Picks{}[i];
return mArray[ip];
}
__host__ __device__ constexpr data_type& operator()(index_t i)
{
index_t ip = Picks{}[i];
return mArray(ip);
}
template <typename T>
__host__ __device__ constexpr auto operator=(const T& a)
{
static_for<0, Size(), 1>{}([&](auto i) { operator()(i) = a[i]; });
return *this;
}
template <typename T>
__host__ __device__ constexpr auto operator+=(const T& a)
{
static_for<0, Size(), 1>{}([&](auto i) { operator()(i) += a[i]; });
return *this;
}
template <typename T>
__host__ __device__ constexpr auto to_array(const T& x)
{
Array<typename T::data_type, T::Size()> y;
template <typename T>
__host__ __device__ constexpr auto operator-=(const T& a)
{
static_for<0, Size(), 1>{}([&](auto i) { operator()(i) -= a[i]; });
static_for<0, T::Size(), 1>{}([&](auto i) { y.At(i) = x.At(i); });
return *this;
}
return y;
}
private:
Arr& mArray;
};
template <typename TData, index_t NSize>
__host__ __device__ constexpr auto make_zero_array()
{
constexpr auto zero_sequence = typename uniform_sequence_gen<NSize, 0>::type{};
constexpr auto zero_array = to_array(zero_sequence);
return zero_array;
}
} // namespace ck
#endif
#ifndef CK_ARRAY_ELEMENT_PICKER_HPP
#define CK_ARRAY_ELEMENT_PICKER_HPP
#include "functional2.hpp"
#include "sequence.hpp"
namespace ck {
// ArrayElementPicker exposes a compile-time selected subset of an existing array.
//   Arr:   type of the underlying array (Array or StaticallyIndexedArray)
//   Picks: Sequence<...>; entry I is the position, in the underlying array, of the
//          I-th element seen through the picker
// The picker stores a reference to the array passed to its constructor and reads and
// writes that array's elements in place; it does not own or copy them.
template <typename Arr, typename Picks>
struct ArrayElementPicker
{
    using type      = ArrayElementPicker;
    using data_type = typename Arr::data_type;

    __host__ __device__ constexpr ArrayElementPicker() = delete;

    // Binds the picker to `array`. Compilation fails if the largest index in Picks is
    // not a valid index of Arr.
    __host__ __device__ explicit constexpr ArrayElementPicker(Arr& array) : mArray{array}
    {
        constexpr index_t imax = reduce_on_sequence(Picks{}, math::maxer<index_t>{}, Number<0>{});

        static_assert(imax < Arr::Size(), "wrong! exceeding # array element");
    }

    // Number of elements visible through the picker (the length of Picks).
    __host__ __device__ static constexpr auto Size() { return Picks::Size(); }

    // Read-only access to the I-th picked element.
    template <index_t I>
    __host__ __device__ constexpr const data_type& At(Number<I>) const
    {
        static_assert(I < Size(), "wrong!");

        // Translate the picker index into an index of the underlying array, then index
        // with a compile-time Number<> so that containers whose operator[] only accepts
        // Number<> (Tuple-based StaticallyIndexedArray) work as well as Array.
        constexpr index_t IP = Picks{}[I];
        return mArray[Number<IP>{}];
    }

    // Mutable access to the I-th picked element.
    template <index_t I>
    __host__ __device__ constexpr data_type& At(Number<I>)
    {
        static_assert(I < Size(), "wrong!");

        // Same index translation as the const overload; see the note there.
        constexpr index_t IP = Picks{}[I];
        return mArray(Number<IP>{});
    }

    // operator[] is the read-only spelling of At().
    template <index_t I>
    __host__ __device__ constexpr const auto& operator[](Number<I> i) const
    {
        return At(i);
    }

    // operator() is the mutable spelling of At().
    template <index_t I>
    __host__ __device__ constexpr auto& operator()(Number<I> i)
    {
        return At(i);
    }

    // Element-wise assignment from any container `a` with the same static Size().
    // The return type is deduced as a value, so the result is a copy of this picker
    // (which still refers to the same underlying array).
    template <typename T>
    __host__ __device__ constexpr auto operator=(const T& a)
    {
        static_assert(T::Size() == Size(), "wrong! size not the same");

        static_for<0, Size(), 1>{}([&](auto i) { operator()(i) = a[i]; });

        return *this;
    }

    private:
    // Underlying array; the picker never copies it.
    Arr& mArray;
};
// Element-wise in-place addition through a picker: y(i) += x[i] for every picked
// element. X must report the same static Size() as the picker. The deduced return
// type is a value, so the caller receives a copy of the picker `y`.
template <typename Arr, typename Picks, typename X>
__host__ __device__ constexpr auto operator+=(ArrayElementPicker<Arr, Picks>& y, const X& x)
{
    constexpr index_t num_picked = ArrayElementPicker<Arr, Picks>::Size();

    static_assert(X::Size() == num_picked, "wrong! size not the same");

    static_for<0, num_picked, 1>{}([&y, &x](auto idim) { y(idim) += x[idim]; });

    return y;
}
// Element-wise in-place subtraction through a picker: y(i) -= x[i] for every picked
// element. X must report the same static Size() as the picker. The deduced return
// type is a value, so the caller receives a copy of the picker `y`.
template <typename Arr, typename Picks, typename X>
__host__ __device__ constexpr auto operator-=(ArrayElementPicker<Arr, Picks>& y, const X& x)
{
    constexpr index_t num_picked = ArrayElementPicker<Arr, Picks>::Size();

    static_assert(X::Size() == num_picked, "wrong! size not the same");

    static_for<0, num_picked, 1>{}([&y, &x](auto idim) { y(idim) -= x[idim]; });

    return y;
}
} // namespace ck
#endif
......@@ -2,39 +2,17 @@
#define CK_ARRAY_HELPER_HPP
#include "array.hpp"
#include "statically_indexed_array.hpp"
#include "array_element_picker.hpp"
namespace ck {
template <typename X, typename... Xs>
__host__ __device__ constexpr auto make_array(const X& x, const Xs&... xs)
{
return Array<X, sizeof...(xs) + 1>{{x, xs...}};
}
template <typename Arr, typename Picks>
__host__ __device__ constexpr auto pick_array_element(Arr& a, Picks)
{
return ArrayElementPicker<Arr, Picks>(a);
}
template <typename T>
__host__ __device__ constexpr auto to_array(const T& x)
{
Array<typename T::data_type, T::Size()> y;
static_for<0, T::Size(), 1>{}([&](auto i) { y.At(i) = x.At(i); });
return y;
}
template <typename TData, index_t NSize>
__host__ __device__ constexpr auto make_zero_array()
{
constexpr auto zero_sequence = typename uniform_sequence_gen<NSize, 0>::type{};
constexpr auto zero_array = to_array(zero_sequence);
return zero_array;
}
template <typename TData, index_t NSize, index_t... IRs>
__host__ __device__ constexpr auto reorder_array_given_new2old(const Array<TData, NSize>& old_array,
Sequence<IRs...> /*new2old*/)
......
......@@ -3,6 +3,8 @@
#include "array.hpp"
#include "array_helper.hpp"
#include "statically_indexed_array.hpp"
#include "array_element_picker.hpp"
#include "config.hpp"
#include "float_type.hpp"
#include "functional.hpp"
......
......@@ -14,197 +14,17 @@ __host__ __device__ void print_array(const char* s, T a)
constexpr index_t nsize = a.Size();
if constexpr(is_same<data_type, uint32_t>{})
{
if constexpr(nsize == 0)
{
printf("%s size %u\n", s, nsize);
}
else if constexpr(nsize == 1)
{
printf("%s size %u, {%u}\n", s, nsize, a[0]);
}
else if constexpr(nsize == 2)
{
printf("%s size %u, {%u %u}\n", s, nsize, a[0], a[1]);
}
else if constexpr(nsize == 3)
{
printf("%s size %u, {%u %u %u}\n", s, nsize, a[0], a[1], a[2]);
}
else if constexpr(nsize == 4)
{
printf("%s size %u, {%u %u %u %u}\n", s, nsize, a[0], a[1], a[2], a[3]);
}
else if constexpr(nsize == 5)
{
printf("%s size %u, {%u %u %u %u %u}\n", s, nsize, a[0], a[1], a[2], a[3], a[4]);
}
else if constexpr(nsize == 6)
{
printf(
"%s size %u, {%u %u %u %u %u %u}\n", s, nsize, a[0], a[1], a[2], a[3], a[4], a[5]);
}
else if constexpr(nsize == 7)
{
printf("%s size %u, {%u %u %u %u %u %u %u}\n",
s,
nsize,
a[0],
a[1],
a[2],
a[3],
a[4],
a[5],
a[6]);
}
else if constexpr(nsize == 8)
{
printf("%s size %u, {%u %u %u %u %u %u %u %u}\n",
s,
nsize,
a[0],
a[1],
a[2],
a[3],
a[4],
a[5],
a[6],
a[7]);
}
else if constexpr(nsize == 9)
{
printf("%s size %u, {%u %u %u %u %u %u %u %u %u}\n",
s,
nsize,
a[0],
a[1],
a[2],
a[3],
a[4],
a[5],
a[6],
a[7],
a[8]);
}
else if constexpr(nsize == 10)
{
printf("%s size %u, {%u %u %u %u %u %u %u %u %u %u}\n",
s,
nsize,
a[0],
a[1],
a[2],
a[3],
a[4],
a[5],
a[6],
a[7],
a[8],
a[9]);
}
else
{
printf("%s size %u, {", s, nsize);
static_for<0, nsize, 1>{}([&a](auto i) constexpr { printf("%u, ", a[i]); });
printf("}\n");
}
}
else if constexpr(is_same<data_type, int32_t>{})
{
if constexpr(nsize == 0)
{
printf("%s size %d\n", s, nsize);
}
else if constexpr(nsize == 1)
{
printf("%s size %d, {%d}\n", s, nsize, a[0]);
}
else if constexpr(nsize == 2)
{
printf("%s size %d, {%d %d}\n", s, nsize, a[0], a[1]);
}
else if constexpr(nsize == 3)
{
printf("%s size %d, {%d %d %d}\n", s, nsize, a[0], a[1], a[2]);
}
else if constexpr(nsize == 4)
{
printf("%s size %d, {%d %d %d %d}\n", s, nsize, a[0], a[1], a[2], a[3]);
}
else if constexpr(nsize == 5)
{
printf("%s size %d, {%d %d %d %d %d}\n", s, nsize, a[0], a[1], a[2], a[3], a[4]);
}
else if constexpr(nsize == 6)
{
printf(
"%s size %d, {%d %d %d %d %d %d}\n", s, nsize, a[0], a[1], a[2], a[3], a[4], a[5]);
}
else if constexpr(nsize == 7)
{
printf("%s size %d, {%d %d %d %d %d %d %d}\n",
s,
nsize,
a[0],
a[1],
a[2],
a[3],
a[4],
a[5],
a[6]);
}
else if constexpr(nsize == 8)
{
printf("%s size %d, {%d %d %d %d %d %d %d %d}\n",
s,
nsize,
a[0],
a[1],
a[2],
a[3],
a[4],
a[5],
a[6],
a[7]);
}
else if constexpr(nsize == 9)
{
printf("%s size %d, {%d %d %d %d %d %d %d %d %d}\n",
s,
nsize,
a[0],
a[1],
a[2],
a[3],
a[4],
a[5],
a[6],
a[7],
a[8]);
}
else if constexpr(nsize == 10)
{
printf("%s size %d, {%d %d %d %d %d %d %d %d %d %d}\n",
s,
nsize,
a[0],
a[1],
a[2],
a[3],
a[4],
a[5],
a[6],
a[7],
a[8],
a[9]);
}
else
{
printf("%s size %d, {", s, nsize);
static_for<0, nsize, 1>{}([&a](auto i) constexpr { printf("%d, ", a[i]); });
printf("}\n");
}
}
}
template <typename T>
......
#ifndef CK_STATICALLY_INDEXED_ARRAY_HPP
#define CK_STATICALLY_INDEXED_ARRAY_HPP
#include "functional2.hpp"
#include "sequence.hpp"
#include "tuple.hpp"
namespace ck {
template <typename TData, index_t NSize>
struct StaticallyIndexedArray
{
};
template <typename TData>
struct StaticallyIndexedArray<TData, 0> : Tuple<>
{
using data_type = TData;
};
template <typename TData>
struct StaticallyIndexedArray<TData, 1> : Tuple<TData>
{
using data_type = TData;
};
template <typename TData>
struct StaticallyIndexedArray<TData, 2> : Tuple<TData, TData>
{
using data_type = TData;
};
template <typename TData>
struct StaticallyIndexedArray<TData, 3> : Tuple<TData, TData, TData>
{
using data_type = TData;
};
template <typename TData>
struct StaticallyIndexedArray<TData, 4> : Tuple<TData, TData, TData, TData>
{
using data_type = TData;
};
template <typename TData>
struct StaticallyIndexedArray<TData, 5> : Tuple<TData, TData, TData, TData, TData>
{
using data_type = TData;
};
template <typename TData>
struct StaticallyIndexedArray<TData, 6> : Tuple<TData, TData, TData, TData, TData, TData>
{
using data_type = TData;
};
template <typename TData>
struct StaticallyIndexedArray<TData, 7> : Tuple<TData, TData, TData, TData, TData, TData, TData>
{
using data_type = TData;
};
template <typename TData>
struct StaticallyIndexedArray<TData, 8>
: Tuple<TData, TData, TData, TData, TData, TData, TData, TData>
{
using data_type = TData;
};
template <typename TData>
struct StaticallyIndexedArray<TData, 9>
: Tuple<TData, TData, TData, TData, TData, TData, TData, TData, TData>
{
using data_type = TData;
};
template <typename TData>
struct StaticallyIndexedArray<TData, 10>
: Tuple<TData, TData, TData, TData, TData, TData, TData, TData, TData, TData>
{
using data_type = TData;
};
template <typename TData>
struct StaticallyIndexedArray<TData, 11>
: Tuple<TData, TData, TData, TData, TData, TData, TData, TData, TData, TData, TData>
{
using data_type = TData;
};
template <typename TData>
struct StaticallyIndexedArray<TData, 12>
: Tuple<TData, TData, TData, TData, TData, TData, TData, TData, TData, TData, TData, TData>
{
using data_type = TData;
};
template <typename TData>
struct StaticallyIndexedArray<TData, 13> : Tuple<TData,
TData,
TData,
TData,
TData,
TData,
TData,
TData,
TData,
TData,
TData,
TData,
TData>
{
using data_type = TData;
};
template <typename TData>
struct StaticallyIndexedArray<TData, 14> : Tuple<TData,
TData,
TData,
TData,
TData,
TData,
TData,
TData,
TData,
TData,
TData,
TData,
TData,
TData>
{
using data_type = TData;
};
template <typename TData>
struct StaticallyIndexedArray<TData, 15> : Tuple<TData,
TData,
TData,
TData,
TData,
TData,
TData,
TData,
TData,
TData,
TData,
TData,
TData,
TData,
TData>
{
using data_type = TData;
};
template <typename TData>
struct StaticallyIndexedArray<TData, 16> : Tuple<TData,
TData,
TData,
TData,
TData,
TData,
TData,
TData,
TData,
TData,
TData,
TData,
TData,
TData,
TData,
TData>
{
using data_type = TData;
};
template <typename TData>
struct StaticallyIndexedArray<TData, 17> : Tuple<TData,
TData,
TData,
TData,
TData,
TData,
TData,
TData,
TData,
TData,
TData,
TData,
TData,
TData,
TData,
TData,
TData>
{
using data_type = TData;
};
template <typename TData>
struct StaticallyIndexedArray<TData, 18> : Tuple<TData,
TData,
TData,
TData,
TData,
TData,
TData,
TData,
TData,
TData,
TData,
TData,
TData,
TData,
TData,
TData,
TData,
TData>
{
using data_type = TData;
};
template <typename TData>
struct StaticallyIndexedArray<TData, 19> : Tuple<TData,
TData,
TData,
TData,
TData,
TData,
TData,
TData,
TData,
TData,
TData,
TData,
TData,
TData,
TData,
TData,
TData,
TData,
TData>
{
using data_type = TData;
};
template <typename TData>
struct StaticallyIndexedArray<TData, 20> : Tuple<TData,
TData,
TData,
TData,
TData,
TData,
TData,
TData,
TData,
TData,
TData,
TData,
TData,
TData,
TData,
TData,
TData,
TData,
TData,
TData>
{
using data_type = TData;
};
template <typename TData>
struct StaticallyIndexedArray<TData, 21> : Tuple<TData,
TData,
TData,
TData,
TData,
TData,
TData,
TData,
TData,
TData,
TData,
TData,
TData,
TData,
TData,
TData,
TData,
TData,
TData,
TData,
TData>
{
using data_type = TData;
};
template <typename TData>
struct StaticallyIndexedArray<TData, 22> : Tuple<TData,
TData,
TData,
TData,
TData,
TData,
TData,
TData,
TData,
TData,
TData,
TData,
TData,
TData,
TData,
TData,
TData,
TData,
TData,
TData,
TData,
TData>
{
using data_type = TData;
};
// Element-wise in-place addition: y(i) += x[i] for i in [0, NSize).
// X must report a static Size() equal to NSize. The deduced return type is a value,
// so the caller receives a copy of `y` after the update.
template <typename TData, index_t NSize, typename X>
__host__ __device__ constexpr auto operator+=(StaticallyIndexedArray<TData, NSize>& y, const X& x)
{
    static_assert(NSize == X::Size(), "wrong! size not the same");

    static_for<0, NSize, 1>{}([&y, &x](auto idim) { y(idim) += x[idim]; });

    return y;
}
// Element-wise in-place subtraction: y(i) -= x[i] for i in [0, NSize).
// X must report a static Size() equal to NSize. The deduced return type is a value,
// so the caller receives a copy of `y` after the update.
template <typename TData, index_t NSize, typename X>
__host__ __device__ constexpr auto operator-=(StaticallyIndexedArray<TData, NSize>& y, const X& x)
{
    static_assert(NSize == X::Size(), "wrong! size not the same");

    static_for<0, NSize, 1>{}([&y, &x](auto idim) { y(idim) -= x[idim]; });

    return y;
}
// Element-wise sum of `a` and `b`, returned as a new StaticallyIndexedArray.
// T must report a static Size() equal to NSize; neither operand is modified.
template <typename TData, index_t NSize, typename T>
__host__ __device__ constexpr auto operator+(const StaticallyIndexedArray<TData, NSize>& a,
                                             const T& b)
{
    static_assert(NSize == T::Size(), "wrong! size not the same");

    // Every element is written by the loop below before the result is returned.
    StaticallyIndexedArray<TData, NSize> sum;

    static_for<0, NSize, 1>{}([&sum, &a, &b](auto idim) { sum(idim) = a[idim] + b[idim]; });

    return sum;
}
// Element-wise difference `a - b`, returned as a new StaticallyIndexedArray.
// T must report a static Size() equal to NSize; neither operand is modified.
template <typename TData, index_t NSize, typename T>
__host__ __device__ constexpr auto operator-(const StaticallyIndexedArray<TData, NSize>& a,
                                             const T& b)
{
    static_assert(NSize == T::Size(), "wrong! size not the same");

    // Every element is written by the loop below before the result is returned.
    StaticallyIndexedArray<TData, NSize> diff;

    static_for<0, NSize, 1>{}([&diff, &a, &b](auto idim) { diff(idim) = a[idim] - b[idim]; });

    return diff;
}
} // namespace ck
#endif
......@@ -89,6 +89,8 @@ struct Tuple : detail::TupleImpl<typename arithmetic_sequence_gen<0, sizeof...(X
{
}
// Number of elements in the tuple, i.e. the length of the Xs... type pack.
__host__ __device__ static constexpr index_t Size() { return sizeof...(Xs); }
template <index_t I>
__host__ __device__ constexpr const auto& At(Number<I>) const
{
......@@ -102,6 +104,28 @@ struct Tuple : detail::TupleImpl<typename arithmetic_sequence_gen<0, sizeof...(X
static_assert(I < base::Size(), "wrong! out of range");
return base::GetElementByKey(detail::TupleElementKey<I>{});
}
// Read-only element access by compile-time index; forwards to the const At().
template <index_t I>
__host__ __device__ constexpr const auto& operator[](Number<I> i) const
{
    return this->At(i);
}
// Mutable element access by compile-time index; forwards to the non-const At().
template <index_t I>
__host__ __device__ constexpr auto& operator()(Number<I> i)
{
    return this->At(i);
}
// Element-wise assignment from any container `a` whose static Size() matches this
// tuple's. The deduced return type is a value, so a copy of this tuple is returned.
template <typename T>
__host__ __device__ constexpr auto operator=(const T& a)
{
    static_assert(Size() == T::Size(), "wrong! size not the same");

    static_for<0, Size(), 1>{}([&](auto idim) { (*this)(idim) = a[idim]; });

    return *this;
}
};
} // namespace ck
......
conv_bwd_data_driver.cpp
\ No newline at end of file
......@@ -561,7 +561,7 @@ int main(int argc, char* argv[])
LeftPads{},
RightPads{},
nrepeat);
#elif 0
#elif 1
device_convolution_implicit_gemm_v4r4_nchw_kcyx_nkhw(in_nchw_desc,
in_nchw,
wei_kcyx_desc,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment