Commit 3bc39592 authored by Chao Liu
Browse files

refactor MultiIndex, Tuple

parent c52c308d
...@@ -115,10 +115,7 @@ struct DummyDynamicTransform_v2_1 ...@@ -115,10 +115,7 @@ struct DummyDynamicTransform_v2_1
MultiIndex<2> idx; MultiIndex<2> idx;
// initialize idx // initialize idx
for(index_t i = 0; i < 2; ++i) static_for<0, 2, 1>{}([&](auto i) { idx(i) = p_wei_global[get_thread_local_1d_id() + i]; });
{
idx(i) = p_wei_global[get_thread_local_1d_id() + i];
}
auto in_gemmk_gemmn_coord = auto in_gemmk_gemmn_coord =
make_dynamic_tensor_coordinate_v2(in_gemmk_gemmn_global_desc, idx); make_dynamic_tensor_coordinate_v2(in_gemmk_gemmn_global_desc, idx);
...@@ -148,29 +145,34 @@ struct DummyDynamicTransform_v2_1 ...@@ -148,29 +145,34 @@ struct DummyDynamicTransform_v2_1
const MultiIndex<2> in_left_pads, const MultiIndex<2> in_left_pads,
const MultiIndex<2> in_right_pads) const const MultiIndex<2> in_right_pads) const
{ {
const index_t N = in_n_c_hi_wi_global_desc.GetLength(0); constexpr auto i0 = Number<0>{};
const index_t C = in_n_c_hi_wi_global_desc.GetLength(1); constexpr auto i1 = Number<1>{};
const index_t K = out_n_k_ho_wo_global_desc.GetLength(1); constexpr auto i2 = Number<2>{};
constexpr auto i3 = Number<3>{};
const index_t Y = wei_k_c_y_x_global_desc.GetLength(2); const index_t N = in_n_c_hi_wi_global_desc.GetLength(i0);
const index_t X = wei_k_c_y_x_global_desc.GetLength(3); const index_t C = in_n_c_hi_wi_global_desc.GetLength(i1);
const index_t K = out_n_k_ho_wo_global_desc.GetLength(i1);
const index_t Hi = in_n_c_hi_wi_global_desc.GetLength(2); const index_t Y = wei_k_c_y_x_global_desc.GetLength(i2);
const index_t Wi = in_n_c_hi_wi_global_desc.GetLength(3); const index_t X = wei_k_c_y_x_global_desc.GetLength(i3);
const index_t Ho = out_n_k_ho_wo_global_desc.GetLength(2); const index_t Hi = in_n_c_hi_wi_global_desc.GetLength(i2);
const index_t Wo = out_n_k_ho_wo_global_desc.GetLength(3); const index_t Wi = in_n_c_hi_wi_global_desc.GetLength(i3);
const index_t ConvStrideH = conv_strides[0]; const index_t Ho = out_n_k_ho_wo_global_desc.GetLength(i2);
const index_t ConvStrideW = conv_strides[1]; const index_t Wo = out_n_k_ho_wo_global_desc.GetLength(i3);
const index_t ConvDilationH = conv_dilations[0]; const index_t ConvStrideH = conv_strides[i0];
const index_t ConvDilationW = conv_dilations[1]; const index_t ConvStrideW = conv_strides[i1];
const index_t InLeftPadH = in_left_pads[0]; const index_t ConvDilationH = conv_dilations[i0];
const index_t InLeftPadW = in_left_pads[1]; const index_t ConvDilationW = conv_dilations[i1];
const index_t InRightPadH = in_right_pads[0];
const index_t InRightPadW = in_right_pads[1]; const index_t InLeftPadH = in_left_pads[i0];
const index_t InLeftPadW = in_left_pads[i1];
const index_t InRightPadH = in_right_pads[i0];
const index_t InRightPadW = in_right_pads[i1];
#if 0 #if 0
const auto in_n_c_hip_wip_global_desc = transform_dynamic_tensor_descriptor_v2( const auto in_n_c_hip_wip_global_desc = transform_dynamic_tensor_descriptor_v2(
...@@ -211,10 +213,7 @@ struct DummyDynamicTransform_v2_1 ...@@ -211,10 +213,7 @@ struct DummyDynamicTransform_v2_1
MultiIndex<4> idx; MultiIndex<4> idx;
// initialize idx // initialize idx
for(index_t i = 0; i < 4; ++i) static_for<0, 4, 1>{}([&](auto i) { idx(i) = p_wei_global[get_thread_local_1d_id() + i]; });
{
idx(i) = p_wei_global[get_thread_local_1d_id() + i];
}
#if 0 #if 0
const index_t niter = p_wei_global[10]; const index_t niter = p_wei_global[10];
......
...@@ -14,7 +14,7 @@ struct DynamicPassThrough ...@@ -14,7 +14,7 @@ struct DynamicPassThrough
const UpperIndex up_lengths_; const UpperIndex up_lengths_;
__host__ __device__ explicit constexpr DynamicPassThrough(const index_t& low_length) __host__ __device__ explicit constexpr DynamicPassThrough(const index_t& low_length)
: up_lengths_{low_length} : up_lengths_{make_multi_index(low_length)}
{ {
} }
...@@ -74,7 +74,7 @@ struct DynamicLeftPad ...@@ -74,7 +74,7 @@ struct DynamicLeftPad
__host__ __device__ explicit constexpr DynamicLeftPad(const index_t& low_length, __host__ __device__ explicit constexpr DynamicLeftPad(const index_t& low_length,
const index_t& left_pad) const index_t& left_pad)
: up_lengths_{low_length + left_pad}, left_pad_{left_pad} : up_lengths_{make_multi_index(low_length + left_pad)}, left_pad_{left_pad}
{ {
} }
...@@ -137,7 +137,9 @@ struct DynamicRightPad ...@@ -137,7 +137,9 @@ struct DynamicRightPad
__host__ __device__ explicit constexpr DynamicRightPad(const index_t& low_length, __host__ __device__ explicit constexpr DynamicRightPad(const index_t& low_length,
const index_t& right_pad) const index_t& right_pad)
: up_lengths_{low_length + right_pad}, low_length_{low_length}, right_pad_{right_pad} : up_lengths_{make_multi_index(low_length + right_pad)},
low_length_{low_length},
right_pad_{right_pad}
{ {
} }
......
...@@ -9,12 +9,20 @@ namespace ck { ...@@ -9,12 +9,20 @@ namespace ck {
template <index_t N> template <index_t N>
using MultiIndex = Array<index_t, N>; using MultiIndex = Array<index_t, N>;
#if 1 // debug
template <typename... Xs> template <typename... Xs>
__host__ __device__ constexpr auto make_multi_index(Xs... xs) __host__ __device__ constexpr auto make_multi_index(Xs... xs)
{ {
return make_array<index_t>(xs...); return make_array<index_t>(xs...);
} }
#else #else
template <typename... Xs>
__host__ __device__ constexpr auto make_multi_index(const Xs&... xs)
{
return make_array(xs...);
}
#endif
#else
template <index_t N> template <index_t N>
using MultiIndex = StaticallyIndexedArray<index_t, N>; using MultiIndex = StaticallyIndexedArray<index_t, N>;
......
...@@ -45,11 +45,19 @@ struct Array<TData, 0> ...@@ -45,11 +45,19 @@ struct Array<TData, 0>
__host__ __device__ static constexpr index_t Size() { return 0; } __host__ __device__ static constexpr index_t Size() { return 0; }
}; };
#if 1
template <typename X, typename... Xs> template <typename X, typename... Xs>
__host__ __device__ constexpr auto make_array(const X& x, const Xs&... xs) __host__ __device__ constexpr auto make_array(const X& x, const Xs&... xs)
{ {
return Array<X, sizeof...(Xs) + 1>{{x, static_cast<X>(xs)...}}; return Array<X, sizeof...(Xs) + 1>{{x, static_cast<X>(xs)...}};
} }
#else
template <typename X, typename... Xs>
__host__ __device__ constexpr auto make_array(X&& x, Xs&&... xs)
{
return Array<remove_cv_t<remove_reference_t<X>>, sizeof...(Xs) + 1>(x, xs...);
}
#endif
// make empty array // make empty array
template <typename X> template <typename X>
......
...@@ -21,7 +21,7 @@ using remove_reference_t = typename std::remove_reference<T>::type; ...@@ -21,7 +21,7 @@ using remove_reference_t = typename std::remove_reference<T>::type;
template <typename T> template <typename T>
using remove_cv_t = typename std::remove_cv<T>::type; using remove_cv_t = typename std::remove_cv<T>::type;
template <class T> template <typename T>
constexpr std::remove_reference_t<T>&& move(T&& t) noexcept constexpr std::remove_reference_t<T>&& move(T&& t) noexcept
{ {
return static_cast<typename std::remove_reference<T>::type&&>(t); return static_cast<typename std::remove_reference<T>::type&&>(t);
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment