"docs/git@developer.sourcefind.cn:hehl2/torchaudio.git" did not exist on "407df37dfb109b8e2410e9afb0b72615aa4d3b07"
Commit 2a358d50 authored by Chao Liu's avatar Chao Liu
Browse files

refactor MultiIndex

parent 9535f806
......@@ -567,7 +567,7 @@ struct DummyDynamicTransform_v1
for(index_t iter = 0; iter < niter; ++iter)
{
constexpr auto gemmk1_gemmn0 = MultiIndex<2>{{1, 0}};
constexpr auto gemmk1_gemmn0 = make_multi_index(1, 0);
in_gemmk_gemmn_coord += gemmk1_gemmn0;
......
......@@ -183,7 +183,7 @@ struct GridwiseConvolutionImplicitGemm_v4r1_nchw_kcyx_nkhw_lds_double_buffer
AddressSpace::Vgpr,
AddressSpace::Lds,
InMemoryDataOperation::Set>(
MultiIndex<4>{{0, 0, b_block_data_on_global, 0}}, MultiIndex<4>{{0, 0, 0, 0}});
make_multi_index(0, 0, b_block_data_on_global, 0), make_multi_index(0, 0, 0, 0));
// weight tensor
// global tensor in global memory, src of blockwise copy
......
......@@ -21,7 +21,7 @@ using MultiIndex = StaticallyIndexedArray<index_t, N>;
template <typename... Xs>
__host__ __device__ constexpr auto make_multi_index(Xs... xs)
{
return MultiIndex<sizeof...(Xs)>(static_cast<index_t>(xs)... r);
return make_statically_indexed_array<index_t>(xs...);
}
#endif
......
......@@ -17,11 +17,6 @@ struct StaticallyIndexedArray<TData, 0> : Tuple<>
{
using data_type = TData;
using base = Tuple<>;
template <typename... Ys>
__host__ __device__ explicit constexpr StaticallyIndexedArray(Ys&&... ys) : base(ys...)
{
}
};
template <typename TData>
......@@ -386,5 +381,11 @@ struct StaticallyIndexedArray<TData, 22> : Tuple<TData,
using data_type = TData;
};
template <typename X, typename... Xs>
__host__ __device__ constexpr auto make_statically_indexed_array(const X& x, const Xs&... xs)
{
return StaticallyIndexedArray<X, sizeof...(Xs) + 1>(x, static_cast<X>(xs)...);
}
} // namespace ck
#endif
......@@ -128,5 +128,11 @@ struct Tuple : detail::TupleImpl<typename arithmetic_sequence_gen<0, sizeof...(X
}
};
template <typename... Xs>
__host__ __device__ constexpr auto make_tuple(Xs&&... xs)
{
return Tuple<remove_cv_t<remove_reference_t<Xs>>...>(std::forward<Xs>(xs)...);
}
} // namespace ck
#endif
......@@ -5,12 +5,6 @@
namespace ck {
template <typename... Xs>
__host__ __device__ constexpr auto make_tuple(Xs&&... xs)
{
return Tuple<remove_cv_t<remove_reference_t<Xs>>...>(std::forward<Xs>(xs)...);
}
template <typename F, index_t N>
__host__ __device__ constexpr auto generate_tuple(F&& f, Number<N>)
{
......
......@@ -52,11 +52,11 @@ void device_dummy_dynamic_transform_v1(InDesc,
const auto in_gemmk_gemmn_global_desc = tensor_descs.At(Number<0>{});
auto in_gemmk_gemmn_coord =
make_dynamic_tensor_coordinate(in_gemmk_gemmn_global_desc, MultiIndex<2>{{0, 0}});
make_dynamic_tensor_coordinate(in_gemmk_gemmn_global_desc, make_multi_index(0, 0));
for(index_t iter = 0; iter < 10; ++iter)
{
constexpr auto gemmk1_gemmn0 = MultiIndex<2>{{1, 0}};
constexpr auto gemmk1_gemmn0 = make_multi_index(1, 0);
printf("iter %d\n", iter);
......
......@@ -53,10 +53,10 @@ void device_dummy_dynamic_transform_v2(InDesc,
// test on cpu
{
auto in_gemmk_gemmn_coord =
make_dynamic_tensor_coordinate_v2(in_gemmk_gemmn_global_desc, MultiIndex<2>{{0, 0}});
make_dynamic_tensor_coordinate_v2(in_gemmk_gemmn_global_desc, make_multi_index(0, 0));
const auto in_gemmk_gemmn_coord_step = make_dynamic_tensor_coordinate_step_v2(
in_gemmk_gemmn_global_desc, MultiIndex<2>{{1, 0}});
in_gemmk_gemmn_global_desc, make_multi_index(1, 0));
for(index_t iter = 0; iter < 10; ++iter)
{
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment