Commit c52c308d authored by Chao Liu's avatar Chao Liu
Browse files

refactor MultiIndex

parent 2a358d50
...@@ -555,11 +555,7 @@ struct DummyDynamicTransform_v1 ...@@ -555,11 +555,7 @@ struct DummyDynamicTransform_v1
MultiIndex<2> idx; MultiIndex<2> idx;
// initialize idx static_for<0, 2, 1>{}([&](auto i) { idx(i) = p_wei_global[get_thread_local_1d_id() + i]; });
for(index_t i = 0; i < 2; ++i)
{
idx(i) = p_wei_global[get_thread_local_1d_id() + i];
}
const index_t niter = p_wei_global[10]; const index_t niter = p_wei_global[10];
......
...@@ -12,10 +12,10 @@ __host__ __device__ constexpr auto ...@@ -12,10 +12,10 @@ __host__ __device__ constexpr auto
map_convolution_into_gemm_v2(const WeiDesc& wei_k_c_y_x_global_desc, map_convolution_into_gemm_v2(const WeiDesc& wei_k_c_y_x_global_desc,
const InDesc& in_n_c_hi_wi_global_desc, const InDesc& in_n_c_hi_wi_global_desc,
const OutDesc& out_n_k_ho_wo_global_desc, const OutDesc& out_n_k_ho_wo_global_desc,
const Array<index_t, 2> conv_strides, const MultiIndex<2> conv_strides,
const Array<index_t, 2> conv_dilations, const MultiIndex<2> conv_dilations,
const Array<index_t, 2> in_left_pads, const MultiIndex<2> in_left_pads,
const Array<index_t, 2> in_right_pads) const MultiIndex<2> in_right_pads)
{ {
constexpr auto i0 = Number<0>{}; constexpr auto i0 = Number<0>{};
constexpr auto i1 = Number<1>{}; constexpr auto i1 = Number<1>{};
...@@ -96,10 +96,10 @@ struct DummyDynamicTransform_v2_1 ...@@ -96,10 +96,10 @@ struct DummyDynamicTransform_v2_1
const WeiDesc wei_k_c_y_x_global_desc, const WeiDesc wei_k_c_y_x_global_desc,
const InDesc in_n_c_hi_wi_global_desc, const InDesc in_n_c_hi_wi_global_desc,
const OutDesc out_n_k_ho_wo_global_desc, const OutDesc out_n_k_ho_wo_global_desc,
const Array<index_t, 2> conv_strides, const MultiIndex<2> conv_strides,
const Array<index_t, 2> conv_dilations, const MultiIndex<2> conv_dilations,
const Array<index_t, 2> in_left_pads, const MultiIndex<2> in_left_pads,
const Array<index_t, 2> in_right_pads) const const MultiIndex<2> in_right_pads) const
{ {
const auto transformed_tensor_descs = const auto transformed_tensor_descs =
map_convolution_into_gemm_v2(move(wei_k_c_y_x_global_desc), map_convolution_into_gemm_v2(move(wei_k_c_y_x_global_desc),
...@@ -124,7 +124,7 @@ struct DummyDynamicTransform_v2_1 ...@@ -124,7 +124,7 @@ struct DummyDynamicTransform_v2_1
make_dynamic_tensor_coordinate_v2(in_gemmk_gemmn_global_desc, idx); make_dynamic_tensor_coordinate_v2(in_gemmk_gemmn_global_desc, idx);
const auto in_gemmk_gemmn_coord_step = make_dynamic_tensor_coordinate_step_v2( const auto in_gemmk_gemmn_coord_step = make_dynamic_tensor_coordinate_step_v2(
in_gemmk_gemmn_global_desc, MultiIndex<2>{{1, 0}}); in_gemmk_gemmn_global_desc, make_multi_index(1, 0));
#pragma unroll #pragma unroll
for(index_t i = 0; i < 10; ++i) for(index_t i = 0; i < 10; ++i)
...@@ -143,10 +143,10 @@ struct DummyDynamicTransform_v2_1 ...@@ -143,10 +143,10 @@ struct DummyDynamicTransform_v2_1
const WeiDesc wei_k_c_y_x_global_desc, const WeiDesc wei_k_c_y_x_global_desc,
const InDesc in_n_c_hi_wi_global_desc, const InDesc in_n_c_hi_wi_global_desc,
const OutDesc out_n_k_ho_wo_global_desc, const OutDesc out_n_k_ho_wo_global_desc,
const Array<index_t, 2> conv_strides, const MultiIndex<2> conv_strides,
const Array<index_t, 2> conv_dilations, const MultiIndex<2> conv_dilations,
const Array<index_t, 2> in_left_pads, const MultiIndex<2> in_left_pads,
const Array<index_t, 2> in_right_pads) const const MultiIndex<2> in_right_pads) const
{ {
const index_t N = in_n_c_hi_wi_global_desc.GetLength(0); const index_t N = in_n_c_hi_wi_global_desc.GetLength(0);
const index_t C = in_n_c_hi_wi_global_desc.GetLength(1); const index_t C = in_n_c_hi_wi_global_desc.GetLength(1);
...@@ -262,10 +262,10 @@ struct DummyDynamicTransform_v2_1 ...@@ -262,10 +262,10 @@ struct DummyDynamicTransform_v2_1
const WeiDesc wei_k_c_y_x_global_desc, const WeiDesc wei_k_c_y_x_global_desc,
const InDesc in_n_c_hi_wi_global_desc, const InDesc in_n_c_hi_wi_global_desc,
const OutDesc out_n_k_ho_wo_global_desc, const OutDesc out_n_k_ho_wo_global_desc,
const Array<index_t, 2> conv_strides, const MultiIndex<2> conv_strides,
const Array<index_t, 2> conv_dilations, const MultiIndex<2> conv_dilations,
const Array<index_t, 2> in_left_pads, const MultiIndex<2> in_left_pads,
const Array<index_t, 2> in_right_pads) const const MultiIndex<2> in_right_pads) const
{ {
Run_1(p_wei_global, Run_1(p_wei_global,
p_in_global, p_in_global,
......
...@@ -279,7 +279,8 @@ struct DynamicMerge ...@@ -279,7 +279,8 @@ struct DynamicMerge
: low_lengths_{low_lengths}, : low_lengths_{low_lengths},
low_lengths_scan_{reverse_exclusive_scan_on_array( low_lengths_scan_{reverse_exclusive_scan_on_array(
low_lengths, math::multiplies<index_t>{}, index_t{1})}, low_lengths, math::multiplies<index_t>{}, index_t{1})},
up_lengths_{{reduce_on_array(low_lengths, math::multiplies<index_t>(), index_t{1})}} up_lengths_{make_multi_index(
reduce_on_array(low_lengths, math::multiplies<index_t>(), index_t{1}))}
{ {
static_assert(LowerIndex::Size() == NDimLow, "wrong!"); static_assert(LowerIndex::Size() == NDimLow, "wrong!");
} }
......
...@@ -178,10 +178,10 @@ struct DynamicTensorDescriptor_v2 ...@@ -178,10 +178,10 @@ struct DynamicTensorDescriptor_v2
index_t element_space_size) index_t element_space_size)
{ {
// zero initialization // zero initialization
HiddenIndex hidden_lengths{{0}}; HiddenIndex hidden_lengths = make_zero_multi_index<ndim_hidden_>();
// this is the orignal tensor element space size // this is the orignal tensor element space size
hidden_lengths(0) = element_space_size; hidden_lengths(Number<0>{}) = element_space_size;
// lengths for all other hidden dimensions // lengths for all other hidden dimensions
static_for<0, ntransform_, 1>{}([&transforms, &hidden_lengths](auto itran) { static_for<0, ntransform_, 1>{}([&transforms, &hidden_lengths](auto itran) {
......
...@@ -5,7 +5,7 @@ ...@@ -5,7 +5,7 @@
namespace ck { namespace ck {
#if 1 // debug #if 1
template <index_t N> template <index_t N>
using MultiIndex = Array<index_t, N>; using MultiIndex = Array<index_t, N>;
......
...@@ -51,6 +51,13 @@ __host__ __device__ constexpr auto make_array(const X& x, const Xs&... xs) ...@@ -51,6 +51,13 @@ __host__ __device__ constexpr auto make_array(const X& x, const Xs&... xs)
return Array<X, sizeof...(Xs) + 1>{{x, static_cast<X>(xs)...}}; return Array<X, sizeof...(Xs) + 1>{{x, static_cast<X>(xs)...}};
} }
// make empty array
template <typename X>
__host__ __device__ constexpr auto make_array()
{
return Array<X, 0>{};
}
template <typename TData, index_t NSize> template <typename TData, index_t NSize>
__host__ __device__ constexpr auto push_back(Array<TData, NSize>& a, const TData& x) __host__ __device__ constexpr auto push_back(Array<TData, NSize>& a, const TData& x)
{ {
......
...@@ -17,6 +17,8 @@ struct StaticallyIndexedArray<TData, 0> : Tuple<> ...@@ -17,6 +17,8 @@ struct StaticallyIndexedArray<TData, 0> : Tuple<>
{ {
using data_type = TData; using data_type = TData;
using base = Tuple<>; using base = Tuple<>;
__host__ __device__ explicit constexpr StaticallyIndexedArray() : base() {}
}; };
template <typename TData> template <typename TData>
...@@ -387,5 +389,12 @@ __host__ __device__ constexpr auto make_statically_indexed_array(const X& x, con ...@@ -387,5 +389,12 @@ __host__ __device__ constexpr auto make_statically_indexed_array(const X& x, con
return StaticallyIndexedArray<X, sizeof...(Xs) + 1>(x, static_cast<X>(xs)...); return StaticallyIndexedArray<X, sizeof...(Xs) + 1>(x, static_cast<X>(xs)...);
} }
// make empty StaticallyIndexedArray
template <typename X>
__host__ __device__ constexpr auto make_statically_indexed_array()
{
return StaticallyIndexedArray<X, 0>();
}
} // namespace ck } // namespace ck
#endif #endif
...@@ -101,10 +101,10 @@ void device_dummy_dynamic_transform_v2(InDesc, ...@@ -101,10 +101,10 @@ void device_dummy_dynamic_transform_v2(InDesc,
const decltype(wei_kcyx_desc), const decltype(wei_kcyx_desc),
const decltype(in_nchw_desc), const decltype(in_nchw_desc),
const decltype(out_nkhw_desc), const decltype(out_nkhw_desc),
const Array<index_t, 2>, const MultiIndex<2>,
const Array<index_t, 2>, const MultiIndex<2>,
const Array<index_t, 2>, const MultiIndex<2>,
const Array<index_t, 2>>, const MultiIndex<2>>,
dim3(GridSize), dim3(GridSize),
dim3(BlockSize), dim3(BlockSize),
0, 0,
...@@ -125,10 +125,10 @@ void device_dummy_dynamic_transform_v2(InDesc, ...@@ -125,10 +125,10 @@ void device_dummy_dynamic_transform_v2(InDesc,
float* const, float* const,
float* const, float* const,
const decltype(in_gemmk_gemmn_global_desc), const decltype(in_gemmk_gemmn_global_desc),
const Array<index_t, 2>, const MultiIndex<2>,
const Array<index_t, 2>, const MultiIndex<2>,
const Array<index_t, 2>, const MultiIndex<2>,
const Array<index_t, 2>>, const MultiIndex<2>>,
dim3(GridSize), dim3(GridSize),
dim3(BlockSize), dim3(BlockSize),
0, 0,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment