Commit 4d70c71b authored by Chao Liu
Browse files

refactor array

parent 5a2498d1
...@@ -13,39 +13,39 @@ __host__ __device__ constexpr auto ...@@ -13,39 +13,39 @@ __host__ __device__ constexpr auto
map_convolution_into_gemm_v1(const WeiDesc& wei_k_c_y_x_global_desc, map_convolution_into_gemm_v1(const WeiDesc& wei_k_c_y_x_global_desc,
const InDesc& in_n_c_hi_wi_global_desc, const InDesc& in_n_c_hi_wi_global_desc,
const OutDesc& out_n_k_ho_wo_global_desc, const OutDesc& out_n_k_ho_wo_global_desc,
const Array<index_t, 2> conv_strides, const MultiIndex<2>& conv_strides,
const Array<index_t, 2> conv_dilations, const MultiIndex<2>& conv_dilations,
const Array<index_t, 2> in_left_pads, const MultiIndex<2>& in_left_pads,
const Array<index_t, 2> in_right_pads) const MultiIndex<2>& in_right_pads)
{ {
constexpr auto i0 = Number<0>{}; constexpr auto I0 = Number<0>{};
constexpr auto i1 = Number<1>{}; constexpr auto I1 = Number<1>{};
constexpr auto i2 = Number<2>{}; constexpr auto I2 = Number<2>{};
constexpr auto i3 = Number<3>{}; constexpr auto I3 = Number<3>{};
const index_t N = in_n_c_hi_wi_global_desc.GetLength(i0); const index_t N = in_n_c_hi_wi_global_desc.GetLength(I0);
const index_t C = in_n_c_hi_wi_global_desc.GetLength(i1); const index_t C = in_n_c_hi_wi_global_desc.GetLength(I1);
const index_t K = out_n_k_ho_wo_global_desc.GetLength(i1); const index_t K = out_n_k_ho_wo_global_desc.GetLength(I1);
const index_t Y = wei_k_c_y_x_global_desc.GetLength(i2); const index_t Y = wei_k_c_y_x_global_desc.GetLength(I2);
const index_t X = wei_k_c_y_x_global_desc.GetLength(i3); const index_t X = wei_k_c_y_x_global_desc.GetLength(I3);
const index_t Hi = in_n_c_hi_wi_global_desc.GetLength(i2); const index_t Hi = in_n_c_hi_wi_global_desc.GetLength(I2);
const index_t Wi = in_n_c_hi_wi_global_desc.GetLength(i3); const index_t Wi = in_n_c_hi_wi_global_desc.GetLength(I3);
const index_t Ho = out_n_k_ho_wo_global_desc.GetLength(i2); const index_t Ho = out_n_k_ho_wo_global_desc.GetLength(I2);
const index_t Wo = out_n_k_ho_wo_global_desc.GetLength(i3); const index_t Wo = out_n_k_ho_wo_global_desc.GetLength(I3);
const index_t ConvStrideH = conv_strides[i0]; const index_t ConvStrideH = conv_strides[I0];
const index_t ConvStrideW = conv_strides[i1]; const index_t ConvStrideW = conv_strides[I1];
const index_t ConvDilationH = conv_dilations[i0]; const index_t ConvDilationH = conv_dilations[I0];
const index_t ConvDilationW = conv_dilations[i1]; const index_t ConvDilationW = conv_dilations[I1];
const index_t InLeftPadH = in_left_pads[i0]; const index_t InLeftPadH = in_left_pads[I0];
const index_t InLeftPadW = in_left_pads[i1]; const index_t InLeftPadW = in_left_pads[I1];
const index_t InRightPadH = in_right_pads[i0]; const index_t InRightPadH = in_right_pads[I0];
const index_t InRightPadW = in_right_pads[i1]; const index_t InRightPadW = in_right_pads[I1];
// input tensor // input tensor
const auto in_n_c_hip_wip_global_desc = transform_dynamic_tensor_descriptor( const auto in_n_c_hip_wip_global_desc = transform_dynamic_tensor_descriptor(
...@@ -64,8 +64,8 @@ map_convolution_into_gemm_v1(const WeiDesc& wei_k_c_y_x_global_desc, ...@@ -64,8 +64,8 @@ map_convolution_into_gemm_v1(const WeiDesc& wei_k_c_y_x_global_desc,
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}), make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{})); make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}));
const index_t Hip = in_n_c_hip_wip_global_desc.GetLength(i2); const index_t Hip = in_n_c_hip_wip_global_desc.GetLength(I2);
const index_t Wip = in_n_c_hip_wip_global_desc.GetLength(i3); const index_t Wip = in_n_c_hip_wip_global_desc.GetLength(I3);
const auto in_n_c_y_ho_x_wo_global_desc = transform_dynamic_tensor_descriptor( const auto in_n_c_y_ho_x_wo_global_desc = transform_dynamic_tensor_descriptor(
in_n_c_hip_wip_global_desc, in_n_c_hip_wip_global_desc,
...@@ -97,55 +97,60 @@ struct DummyDynamicTransform_v1 ...@@ -97,55 +97,60 @@ struct DummyDynamicTransform_v1
const WeiDesc wei_k_c_y_x_global_desc, const WeiDesc wei_k_c_y_x_global_desc,
const InDesc in_n_c_hi_wi_global_desc, const InDesc in_n_c_hi_wi_global_desc,
const OutDesc out_n_k_ho_wo_global_desc, const OutDesc out_n_k_ho_wo_global_desc,
const Array<index_t, 2> conv_strides, const MultiIndex<2>& conv_strides,
const Array<index_t, 2> conv_dilations, const MultiIndex<2>& conv_dilations,
const Array<index_t, 2> in_left_pads, const MultiIndex<2>& in_left_pads,
const Array<index_t, 2> in_right_pads) const const MultiIndex<2>& in_right_pads) const
{ {
constexpr auto I0 = Number<0>{};
constexpr auto I1 = Number<1>{};
constexpr auto I2 = Number<2>{};
constexpr auto I3 = Number<3>{};
#if 1 #if 1
const index_t N = in_n_c_hi_wi_global_desc.GetLength(0); const index_t N = in_n_c_hi_wi_global_desc.GetLength(I0);
const index_t C = in_n_c_hi_wi_global_desc.GetLength(1); const index_t C = in_n_c_hi_wi_global_desc.GetLength(I1);
const index_t K = out_n_k_ho_wo_global_desc.GetLength(1); const index_t K = out_n_k_ho_wo_global_desc.GetLength(I1);
const index_t Y = wei_k_c_y_x_global_desc.GetLength(2); const index_t Y = wei_k_c_y_x_global_desc.GetLength(I2);
const index_t X = wei_k_c_y_x_global_desc.GetLength(3); const index_t X = wei_k_c_y_x_global_desc.GetLength(I3);
const index_t Hi = in_n_c_hi_wi_global_desc.GetLength(2); const index_t Hi = in_n_c_hi_wi_global_desc.GetLength(I2);
const index_t Wi = in_n_c_hi_wi_global_desc.GetLength(3); const index_t Wi = in_n_c_hi_wi_global_desc.GetLength(I3);
const index_t Ho = out_n_k_ho_wo_global_desc.GetLength(2); const index_t Ho = out_n_k_ho_wo_global_desc.GetLength(I2);
const index_t Wo = out_n_k_ho_wo_global_desc.GetLength(3); const index_t Wo = out_n_k_ho_wo_global_desc.GetLength(I3);
const index_t ConvStrideH = conv_strides[0]; const index_t ConvStrideH = conv_strides[I0];
const index_t ConvStrideW = conv_strides[1]; const index_t ConvStrideW = conv_strides[I1];
const index_t ConvDilationH = conv_dilations[0]; const index_t ConvDilationH = conv_dilations[I0];
const index_t ConvDilationW = conv_dilations[1]; const index_t ConvDilationW = conv_dilations[I1];
const index_t InLeftPadH = in_left_pads[0]; const index_t InLeftPadH = in_left_pads[I0];
const index_t InLeftPadW = in_left_pads[1]; const index_t InLeftPadW = in_left_pads[I1];
const index_t InRightPadH = in_right_pads[0]; const index_t InRightPadH = in_right_pads[I0];
const index_t InRightPadW = in_right_pads[1]; const index_t InRightPadW = in_right_pads[I1];
#else #else
const index_t N = in_n_c_hi_wi_global_desc.GetLength(0); const index_t N = in_n_c_hi_wi_global_desc.GetLength(I0);
const index_t C = in_n_c_hi_wi_global_desc.GetLength(1); const index_t C = in_n_c_hi_wi_global_desc.GetLength(I1);
const index_t Y = 3; const index_t Y = 3;
const index_t X = 3; const index_t X = 3;
const index_t Hi = in_n_c_hi_wi_global_desc.GetLength(2); const index_t Hi = in_n_c_hi_wi_global_desc.GetLength(I2);
const index_t Wi = in_n_c_hi_wi_global_desc.GetLength(3); const index_t Wi = in_n_c_hi_wi_global_desc.GetLength(I3);
const index_t ConvStrideH = conv_strides[0]; const index_t ConvStrideH = conv_strides[I0];
const index_t ConvStrideW = conv_strides[1]; const index_t ConvStrideW = conv_strides[I1];
const index_t ConvDilationH = conv_dilations[0]; const index_t ConvDilationH = conv_dilations[I0];
const index_t ConvDilationW = conv_dilations[1]; const index_t ConvDilationW = conv_dilations[I1];
const index_t InLeftPadH = in_left_pads[0]; const index_t InLeftPadH = in_left_pads[I0];
const index_t InLeftPadW = in_left_pads[1]; const index_t InLeftPadW = in_left_pads[I1];
const index_t InRightPadH = in_right_pads[0]; const index_t InRightPadH = in_right_pads[I0];
const index_t InRightPadW = in_right_pads[1]; const index_t InRightPadW = in_right_pads[I1];
#endif #endif
// define transform // define transform
...@@ -537,10 +542,10 @@ struct DummyDynamicTransform_v1 ...@@ -537,10 +542,10 @@ struct DummyDynamicTransform_v1
const WeiDesc wei_k_c_y_x_global_desc, const WeiDesc wei_k_c_y_x_global_desc,
const InDesc in_n_c_hi_wi_global_desc, const InDesc in_n_c_hi_wi_global_desc,
const OutDesc out_n_k_ho_wo_global_desc, const OutDesc out_n_k_ho_wo_global_desc,
const Array<index_t, 2> conv_strides, const MultiIndex<2>& conv_strides,
const Array<index_t, 2> conv_dilations, const MultiIndex<2>& conv_dilations,
const Array<index_t, 2> in_left_pads, const MultiIndex<2>& in_left_pads,
const Array<index_t, 2> in_right_pads) const const MultiIndex<2>& in_right_pads) const
{ {
const auto transformed_tensor_descs = const auto transformed_tensor_descs =
map_convolution_into_gemm_v1(wei_k_c_y_x_global_desc, map_convolution_into_gemm_v1(wei_k_c_y_x_global_desc,
...@@ -598,10 +603,10 @@ struct DummyDynamicTransform_v1 ...@@ -598,10 +603,10 @@ struct DummyDynamicTransform_v1
const WeiDesc wei_k_c_y_x_global_desc, const WeiDesc wei_k_c_y_x_global_desc,
const InDesc in_n_c_hi_wi_global_desc, const InDesc in_n_c_hi_wi_global_desc,
const OutDesc out_n_k_ho_wo_global_desc, const OutDesc out_n_k_ho_wo_global_desc,
const Array<index_t, 2> conv_strides, const MultiIndex<2>& conv_strides,
const Array<index_t, 2> conv_dilations, const MultiIndex<2>& conv_dilations,
const Array<index_t, 2> in_left_pads, const MultiIndex<2>& in_left_pads,
const Array<index_t, 2> in_right_pads) const const MultiIndex<2>& in_right_pads) const
{ {
Run_2(p_wei_global, Run_2(p_wei_global,
p_in_global, p_in_global,
......
...@@ -31,9 +31,17 @@ struct DynamicNativeTensorDescriptor ...@@ -31,9 +31,17 @@ struct DynamicNativeTensorDescriptor
__host__ __device__ constexpr auto GetStrides() const { return strides_; } __host__ __device__ constexpr auto GetStrides() const { return strides_; }
__host__ __device__ constexpr index_t GetLength(index_t idim) const { return lengths_[idim]; } template <index_t IDim>
__host__ __device__ constexpr index_t GetLength(Number<IDim>) const
{
return lengths_[Number<IDim>{}];
}
__host__ __device__ constexpr index_t GetStride(index_t idim) const { return strides_[idim]; } template <index_t IDim>
__host__ __device__ constexpr index_t GetStride(Number<IDim>) const
{
return strides_[Number<IDim>{}];
}
__host__ __device__ constexpr index_t GetElementSize() const __host__ __device__ constexpr index_t GetElementSize() const
{ {
...@@ -44,11 +52,7 @@ struct DynamicNativeTensorDescriptor ...@@ -44,11 +52,7 @@ struct DynamicNativeTensorDescriptor
{ {
index_t space = 1; index_t space = 1;
#pragma unroll static_for<0, NDim, 1>{}([&](auto i) { space += (GetLength(i) - 1) * GetStride(i); });
for(index_t i = 0; i < NDim; ++i)
{
space += (GetLength(i) - 1) * GetStride(i);
}
return space; return space;
} }
...@@ -58,11 +62,7 @@ struct DynamicNativeTensorDescriptor ...@@ -58,11 +62,7 @@ struct DynamicNativeTensorDescriptor
{ {
index_t offset = 0; index_t offset = 0;
#pragma unroll static_for<0, NDim, 1>{}([&](auto i) { offset += idx[i] * GetStride(i); });
for(index_t i = 0; i < NDim; ++i)
{
offset += idx[i] * GetStride(i);
}
return offset; return offset;
} }
...@@ -78,11 +78,8 @@ struct DynamicNativeTensorDescriptor ...@@ -78,11 +78,8 @@ struct DynamicNativeTensorDescriptor
{ {
bool flag = true; bool flag = true;
#pragma unroll static_for<0, NDim, 1>{}(
for(index_t i = 0; i < NDim; ++i) [&](auto i) { flag = flag && idx[i] >= 0 && idx[i] < GetLength(i); });
{
flag = flag && idx[i] >= 0 && idx[i] < GetLength(i);
}
return flag; return flag;
} }
...@@ -139,7 +136,7 @@ struct DynamicTransformedTensorDescriptor ...@@ -139,7 +136,7 @@ struct DynamicTransformedTensorDescriptor
template <typename... Xs> template <typename... Xs>
__host__ __device__ constexpr auto operator()(Xs... xs) const __host__ __device__ constexpr auto operator()(Xs... xs) const
{ {
return merge_arrays(xs...); return array_cat(xs...);
} }
}; };
...@@ -306,11 +303,8 @@ struct DynamicTransformedTensorDescriptor ...@@ -306,11 +303,8 @@ struct DynamicTransformedTensorDescriptor
{ {
bool flag = true; bool flag = true;
#pragma unroll static_for<0, NDimUp, 1>{}(
for(index_t i = 0; i < NDimUp; ++i) [&](auto i) { flag = flag && idx_up[i] >= 0 && idx_up[i] < GetLength(i); });
{
flag = flag && idx_up[i] >= 0 && idx_up[i] < GetLength(i);
}
return flag; return flag;
} }
......
...@@ -10,9 +10,9 @@ template <typename Lengths, typename Strides> ...@@ -10,9 +10,9 @@ template <typename Lengths, typename Strides>
__host__ __device__ constexpr auto make_dynamic_native_tensor_descriptor(const Lengths& lengths, __host__ __device__ constexpr auto make_dynamic_native_tensor_descriptor(const Lengths& lengths,
const Strides& strides) const Strides& strides)
{ {
static_assert(Lengths::GetSize() == Strides::GetSize(), "wrong! Size not the same"); static_assert(Lengths::Size() == Strides::Size(), "wrong! Size not the same");
return DynamicNativeTensorDescriptor<Lengths::GetSize()>(lengths, strides); return DynamicNativeTensorDescriptor<Lengths::Size()>(lengths, strides);
} }
template <typename LowTensorDescriptor, template <typename LowTensorDescriptor,
......
...@@ -340,7 +340,7 @@ struct DynamicTensorCoordinateStep_v2 ...@@ -340,7 +340,7 @@ struct DynamicTensorCoordinateStep_v2
#endif #endif
}; };
// TODO: Fix this! This is insane, to use an ugly struct instead of lambda because lambda // TODO: How to fix this? It uses a struct instead of lambda because lambda
// doesn't have constructor, and to put it outside the scope where it is used // doesn't have constructor, and to put it outside the scope where it is used
// (transform_dynamic_tensor_descriptor_v2) because template cannot be defined inside a function // (transform_dynamic_tensor_descriptor_v2) because template cannot be defined inside a function
// template // template
...@@ -538,22 +538,25 @@ __host__ __device__ void move_dynamic_tensor_coordinate_v2(const TensorDesc& ten ...@@ -538,22 +538,25 @@ __host__ __device__ void move_dynamic_tensor_coordinate_v2(const TensorDesc& ten
idx_hidden_pick_visible += coord_step.GetIndexDiff(); idx_hidden_pick_visible += coord_step.GetIndexDiff();
// update rest of hidden index // update rest of hidden index
static_for<ntransform - 1, -1, -1>{}([&tensor_desc, &idx_hidden, &idx_diff_hidden](auto itran) { static_for<ntransform - 1, -1, -1>{}([&](auto itran) {
const auto& tran = tensor_desc.GetTransforms().At(itran); if(coord_step.do_transforms_[itran])
constexpr auto dims_low = TensorDesc::GetLowerDimensionIdss().At(itran); {
constexpr auto dims_up = TensorDesc::GetUpperDimensionIdss().At(itran); const auto& tran = tensor_desc.GetTransforms().At(itran);
constexpr auto dims_low = TensorDesc::GetLowerDimensionIdss().At(itran);
constexpr auto dims_up = TensorDesc::GetUpperDimensionIdss().At(itran);
// this const is for ArrayElementPicker, Array itself may not be const // this const is for ArrayElementPicker, Array itself may not be const
const auto idx_up = pick_array_element(idx_hidden, dims_up); const auto idx_up = pick_array_element(idx_hidden, dims_up);
auto idx_low = pick_array_element(idx_hidden, dims_low); auto idx_low = pick_array_element(idx_hidden, dims_low);
const auto idx_diff_up = pick_array_element(idx_diff_hidden, dims_up); const auto idx_diff_up = pick_array_element(idx_diff_hidden, dims_up);
auto idx_diff_low = pick_array_element(idx_diff_hidden, dims_low); auto idx_diff_low = pick_array_element(idx_diff_hidden, dims_low);
tran.CalculateLowerIndexDiff(idx_diff_low, idx_diff_up, idx_low, idx_up); tran.CalculateLowerIndexDiff(idx_diff_low, idx_diff_up, idx_low, idx_up);
// update idx_low // update idx_low
idx_low += idx_diff_low; idx_low += idx_diff_low;
}
}); });
} }
......
...@@ -59,17 +59,5 @@ __host__ __device__ constexpr auto make_array() ...@@ -59,17 +59,5 @@ __host__ __device__ constexpr auto make_array()
return Array<X, 0>{}; return Array<X, 0>{};
} }
// Returns a new Array holding every element of `a` followed by `x`.
//   a : source array of NSize elements; taken by non-const reference, but only read via a[i]
//   x : value stored in the extra, last slot of the result
// The result type is Array<TData, NSize + 1>.
template <typename TData, index_t NSize>
__host__ __device__ constexpr auto push_back(Array<TData, NSize>& a, const TData& x)
{
Array<TData, NSize + 1> r;
// copy the NSize existing elements into the front of the result
static_for<0, NSize, 1>{}([&r, &a ](auto i) constexpr { r(i) = a[i]; });
// the appended value goes into index NSize
r(Number<NSize>{}) = x;
return r;
}
} // namespace ck } // namespace ck
#endif #endif
...@@ -97,5 +97,11 @@ __host__ __device__ constexpr auto operator-=(ArrayElementPicker<Arr, Picks>& y, ...@@ -97,5 +97,11 @@ __host__ __device__ constexpr auto operator-=(ArrayElementPicker<Arr, Picks>& y,
return y; return y;
} }
// Wraps `a` in an ArrayElementPicker<Arr, Picks> and returns the picker by value.
//   a     : array to wrap, passed by non-const reference
//   Picks : unnamed tag argument; only its type is used, as the picker's second template argument
template <typename Arr, typename Picks>
__host__ __device__ constexpr auto pick_array_element(Arr& a, Picks)
{
return ArrayElementPicker<Arr, Picks>(a);
}
} // namespace ck } // namespace ck
#endif #endif
#ifndef CK_ARRAY_HELPER_HPP #ifndef CK_ARRAY_HELPER_HPP
#define CK_ARRAY_HELPER_HPP #define CK_ARRAY_HELPER_HPP
#include "sequence.hpp"
#include "sequence_helper.hpp"
#include "tuple.hpp"
#include "tuple_helper.hpp"
#include "array.hpp" #include "array.hpp"
#include "array_helper.hpp"
#include "statically_indexed_array.hpp" #include "statically_indexed_array.hpp"
#include "array_element_picker.hpp" #include "array_element_picker.hpp"
namespace ck { namespace ck {
template <typename Arr, typename Picks> template <typename TData, index_t NSize>
__host__ __device__ constexpr auto pick_array_element(Arr& a, Picks) __host__ __device__ constexpr auto push_back(const Array<TData, NSize>& a, const TData& x)
{ {
return ArrayElementPicker<Arr, Picks>(a); Array<TData, NSize + 1> r;
static_for<0, NSize, 1>{}([&r, &a ](auto i) constexpr { r(i) = a[i]; });
r(Number<NSize>{}) = x;
return r;
} }
template <typename TData, index_t NSize, index_t... IRs> template <typename TData, index_t NSize, index_t... IRs>
...@@ -63,20 +74,6 @@ __host__ __device__ constexpr auto reorder_array_given_old2new(const Array<TData ...@@ -63,20 +74,6 @@ __host__ __device__ constexpr auto reorder_array_given_old2new(const Array<TData
return new_array; return new_array;
} }
// Builds a new Array from selected elements of `old_array`.
//   old_array  : source array of NSize elements (read only)
//   ExtractSeq : unnamed tag argument; ExtractSeq::At(I) supplies the source index for output slot I
// The result has ExtractSeq::GetSize() elements. Requesting more than NSize elements is rejected
// at compile time by the static_assert below.
template <typename TData, index_t NSize, typename ExtractSeq>
__host__ __device__ constexpr auto extract_array(const Array<TData, NSize>& old_array, ExtractSeq)
{
Array<TData, ExtractSeq::GetSize()> new_array;
constexpr index_t new_size = ExtractSeq::GetSize();
static_assert(new_size <= NSize, "wrong! too many extract");
// output slot I takes the element of old_array at position ExtractSeq::At(I)
static_for<0, new_size, 1>{}([&](auto I) { new_array(I) = old_array[ExtractSeq::At(I)]; });
return new_array;
}
// emulate constexpr lambda for array // emulate constexpr lambda for array
template <typename F, typename X, typename Y, typename Z> template <typename F, typename X, typename Y, typename Z>
struct lambda_array_math struct lambda_array_math
...@@ -201,31 +198,25 @@ reverse_exclusive_scan_on_array(const Array<TData, NSize>& x, Reduce f, TData in ...@@ -201,31 +198,25 @@ reverse_exclusive_scan_on_array(const Array<TData, NSize>& x, Reduce f, TData in
} }
template <typename X, typename... Ys> template <typename X, typename... Ys>
__host__ __device__ constexpr auto merge_arrays(const X& x, const Ys&... ys) __host__ __device__ constexpr auto container_cat(const X& x, const Ys&... ys)
{ {
return merge_arrays(x, merge_arrays(ys...)); return container_cat(x, container_cat(ys...));
} }
template <typename T, index_t NX, index_t NY> template <typename T, index_t NX, index_t NY>
__host__ __device__ constexpr auto merge_arrays(const Array<T, NX>& x, const Array<T, NY>& y) __host__ __device__ constexpr auto container_cat(const Array<T, NX>& x, const Array<T, NY>& y)
{ {
Array<T, NX + NY> z; Array<T, NX + NY> z;
for(index_t i = 0; i < NX; ++i) static_for<0, NX, 1>{}([&](auto i) { z(i) = x[i]; });
{
z(i) = x[i];
}
for(index_t i = 0; i < NY; ++i) static_for<0, NY, 1>{}([&](auto i) { z(i + Number<NX>{}) = y[i]; });
{
z(i + NX) = y[i];
}
return z; return z;
} }
template <typename X> template <typename T, index_t N>
__host__ __device__ constexpr auto merge_arrays(const X& x) __host__ __device__ constexpr auto container_cat(const Array<T, N>& x)
{ {
return x; return x;
} }
......
...@@ -2,6 +2,7 @@ ...@@ -2,6 +2,7 @@
#define CK_PRINT_HPP #define CK_PRINT_HPP
#include "array.hpp" #include "array.hpp"
#include "statically_indexed_array.hpp"
#include "array_helper.hpp" #include "array_helper.hpp"
#include "sequence.hpp" #include "sequence.hpp"
...@@ -19,7 +20,7 @@ __host__ __device__ void print_array(const char* s, T a) ...@@ -19,7 +20,7 @@ __host__ __device__ void print_array(const char* s, T a)
static_for<0, nsize, 1>{}([&a](auto i) constexpr { printf("%u, ", a[i]); }); static_for<0, nsize, 1>{}([&a](auto i) constexpr { printf("%u, ", a[i]); });
printf("}\n"); printf("}\n");
} }
else if constexpr(is_same<data_type, int32_t>{}) else if constexpr(true)
{ {
printf("%s size %d, {", s, nsize); printf("%s size %d, {", s, nsize);
static_for<0, nsize, 1>{}([&a](auto i) constexpr { printf("%d, ", a[i]); }); static_for<0, nsize, 1>{}([&a](auto i) constexpr { printf("%d, ", a[i]); });
...@@ -39,7 +40,7 @@ __host__ __device__ void print_array_v2(const char* s, T a) ...@@ -39,7 +40,7 @@ __host__ __device__ void print_array_v2(const char* s, T a)
static_for<0, nsize, 1>{}([&a](auto i) constexpr { printf("[%u] %u, ", i.value, a[i]); }); static_for<0, nsize, 1>{}([&a](auto i) constexpr { printf("[%u] %u, ", i.value, a[i]); });
printf("}\n"); printf("}\n");
} }
else if constexpr(is_same<data_type, int32_t>{}) else if constexpr(true)
{ {
printf("%s size %d, {", s, nsize); printf("%s size %d, {", s, nsize);
static_for<0, nsize, 1>{}([&a](auto i) constexpr { printf("[%d] %d, ", i.value, a[i]); }); static_for<0, nsize, 1>{}([&a](auto i) constexpr { printf("[%d] %d, ", i.value, a[i]); });
......
...@@ -28,17 +28,17 @@ void device_dummy_dynamic_transform_v1(InDesc, ...@@ -28,17 +28,17 @@ void device_dummy_dynamic_transform_v1(InDesc,
using TDevice = typename conditional<is_same<half_float::half, T>::value, half_t, T>::type; using TDevice = typename conditional<is_same<half_float::half, T>::value, half_t, T>::type;
const auto in_nchw_desc = make_dynamic_native_tensor_descriptor(to_array(InDesc::GetLengths()), const auto in_nchw_desc = make_dynamic_native_tensor_descriptor(
to_array(InDesc::GetStrides())); to_multi_index(InDesc::GetLengths()), to_multi_index(InDesc::GetStrides()));
const auto wei_kcyx_desc = make_dynamic_native_tensor_descriptor( const auto wei_kcyx_desc = make_dynamic_native_tensor_descriptor(
to_array(WeiDesc::GetLengths()), to_array(WeiDesc::GetStrides())); to_multi_index(WeiDesc::GetLengths()), to_multi_index(WeiDesc::GetStrides()));
const auto out_nkhw_desc = make_dynamic_native_tensor_descriptor( const auto out_nkhw_desc = make_dynamic_native_tensor_descriptor(
to_array(OutDesc::GetLengths()), to_array(OutDesc::GetStrides())); to_multi_index(OutDesc::GetLengths()), to_multi_index(OutDesc::GetStrides()));
const auto conv_strides = to_array(ConvStrides{}); const auto conv_strides = to_multi_index(ConvStrides{});
const auto conv_dilations = to_array(ConvDilations{}); const auto conv_dilations = to_multi_index(ConvDilations{});
const auto in_left_pads = to_array(InLeftPads{}); const auto in_left_pads = to_multi_index(InLeftPads{});
const auto in_right_pads = to_array(InRightPads{}); const auto in_right_pads = to_multi_index(InRightPads{});
{ {
const auto tensor_descs = map_convolution_into_gemm_v1(wei_kcyx_desc, const auto tensor_descs = map_convolution_into_gemm_v1(wei_kcyx_desc,
......
...@@ -58,10 +58,13 @@ void device_dummy_dynamic_transform_v2(InDesc, ...@@ -58,10 +58,13 @@ void device_dummy_dynamic_transform_v2(InDesc,
const auto in_gemmk_gemmn_coord_step = make_dynamic_tensor_coordinate_step_v2( const auto in_gemmk_gemmn_coord_step = make_dynamic_tensor_coordinate_step_v2(
in_gemmk_gemmn_global_desc, make_multi_index(1, 0)); in_gemmk_gemmn_global_desc, make_multi_index(1, 0));
print_array("do_tansforms: ", in_gemmk_gemmn_coord_step.do_transforms_);
for(index_t iter = 0; iter < 10; ++iter) for(index_t iter = 0; iter < 10; ++iter)
{ {
printf("iter %d\n", iter); printf("iter %d\n", iter);
print_array("idx: ", in_gemmk_gemmn_coord.GetIndex()); print_array("idx: ", in_gemmk_gemmn_coord.GetIndex());
print_array("hidden idx: ", in_gemmk_gemmn_coord.GetHiddenIndex());
printf("offset: %d\n", in_gemmk_gemmn_coord.GetOffset()); printf("offset: %d\n", in_gemmk_gemmn_coord.GetOffset());
printf("\n"); printf("\n");
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment