"git@developer.sourcefind.cn:gaoqiong/composable_kernel.git" did not exist on "2c6d63d0317d1a765b4e9f9b85177bb51a373b88"
Commit 2a358d50 authored by Chao Liu
Browse files

refactor MultiIndex

parent 9535f806
...@@ -567,7 +567,7 @@ struct DummyDynamicTransform_v1 ...@@ -567,7 +567,7 @@ struct DummyDynamicTransform_v1
for(index_t iter = 0; iter < niter; ++iter) for(index_t iter = 0; iter < niter; ++iter)
{ {
constexpr auto gemmk1_gemmn0 = MultiIndex<2>{{1, 0}}; constexpr auto gemmk1_gemmn0 = make_multi_index(1, 0);
in_gemmk_gemmn_coord += gemmk1_gemmn0; in_gemmk_gemmn_coord += gemmk1_gemmn0;
......
...@@ -183,7 +183,7 @@ struct GridwiseConvolutionImplicitGemm_v4r1_nchw_kcyx_nkhw_lds_double_buffer ...@@ -183,7 +183,7 @@ struct GridwiseConvolutionImplicitGemm_v4r1_nchw_kcyx_nkhw_lds_double_buffer
AddressSpace::Vgpr, AddressSpace::Vgpr,
AddressSpace::Lds, AddressSpace::Lds,
InMemoryDataOperation::Set>( InMemoryDataOperation::Set>(
MultiIndex<4>{{0, 0, b_block_data_on_global, 0}}, MultiIndex<4>{{0, 0, 0, 0}}); make_multi_index(0, 0, b_block_data_on_global, 0), make_multi_index(0, 0, 0, 0));
// weight tensor // weight tensor
// global tensor in global memory, src of blockwise copy // global tensor in global memory, src of blockwise copy
......
...@@ -21,7 +21,7 @@ using MultiIndex = StaticallyIndexedArray<index_t, N>; ...@@ -21,7 +21,7 @@ using MultiIndex = StaticallyIndexedArray<index_t, N>;
template <typename... Xs> template <typename... Xs>
__host__ __device__ constexpr auto make_multi_index(Xs... xs) __host__ __device__ constexpr auto make_multi_index(Xs... xs)
{ {
return MultiIndex<sizeof...(Xs)>(static_cast<index_t>(xs)... r); return make_statically_indexed_array<index_t>(xs...);
} }
#endif #endif
......
...@@ -17,11 +17,6 @@ struct StaticallyIndexedArray<TData, 0> : Tuple<> ...@@ -17,11 +17,6 @@ struct StaticallyIndexedArray<TData, 0> : Tuple<>
{ {
using data_type = TData; using data_type = TData;
using base = Tuple<>; using base = Tuple<>;
// Explicit variadic constructor for the zero-length specialization: every
// argument it receives is handed to the `base` constructor.
// NOTE(review): `ys` are declared as forwarding references but are passed on
// as lvalues (no std::forward) — confirm this is intentional.
template <typename... Ys>
__host__ __device__ explicit constexpr StaticallyIndexedArray(Ys&&... ys) : base(ys...)
{
}
}; };
template <typename TData> template <typename TData>
...@@ -386,5 +381,11 @@ struct StaticallyIndexedArray<TData, 22> : Tuple<TData, ...@@ -386,5 +381,11 @@ struct StaticallyIndexedArray<TData, 22> : Tuple<TData,
using data_type = TData; using data_type = TData;
}; };
// Builds a StaticallyIndexedArray from a list of values.
//   X      element type of the resulting array (type of the first argument
//          unless supplied explicitly by the caller)
//   x, xs  the values; every value after the first is static_cast to X
// The array length equals the total number of arguments passed.
template <typename X, typename... Xs>
__host__ __device__ constexpr auto make_statically_indexed_array(const X& x, const Xs&... xs)
{
    using array_type = StaticallyIndexedArray<X, sizeof...(Xs) + 1>;

    return array_type(x, static_cast<X>(xs)...);
}
} // namespace ck } // namespace ck
#endif #endif
...@@ -128,5 +128,11 @@ struct Tuple : detail::TupleImpl<typename arithmetic_sequence_gen<0, sizeof...(X ...@@ -128,5 +128,11 @@ struct Tuple : detail::TupleImpl<typename arithmetic_sequence_gen<0, sizeof...(X
} }
}; };
// Packs the arguments into a Tuple.
// Each element type is the corresponding argument type with its reference
// and cv-qualifiers stripped; the arguments themselves are perfectly
// forwarded into the Tuple constructor.
template <typename... Xs>
__host__ __device__ constexpr auto make_tuple(Xs&&... xs)
{
    using tuple_type = Tuple<remove_cv_t<remove_reference_t<Xs>>...>;

    return tuple_type(std::forward<Xs>(xs)...);
}
} // namespace ck } // namespace ck
#endif #endif
...@@ -5,12 +5,6 @@ ...@@ -5,12 +5,6 @@
namespace ck { namespace ck {
// Creates a Tuple out of the given arguments.
// Element types are the argument types after removing references and
// cv-qualifiers; each argument is forwarded (std::forward) to the Tuple
// constructor.
template <typename... Xs>
__host__ __device__ constexpr auto make_tuple(Xs&&... xs)
{
    using result_type = Tuple<remove_cv_t<remove_reference_t<Xs>>...>;

    return result_type(std::forward<Xs>(xs)...);
}
template <typename F, index_t N> template <typename F, index_t N>
__host__ __device__ constexpr auto generate_tuple(F&& f, Number<N>) __host__ __device__ constexpr auto generate_tuple(F&& f, Number<N>)
{ {
......
...@@ -52,11 +52,11 @@ void device_dummy_dynamic_transform_v1(InDesc, ...@@ -52,11 +52,11 @@ void device_dummy_dynamic_transform_v1(InDesc,
const auto in_gemmk_gemmn_global_desc = tensor_descs.At(Number<0>{}); const auto in_gemmk_gemmn_global_desc = tensor_descs.At(Number<0>{});
auto in_gemmk_gemmn_coord = auto in_gemmk_gemmn_coord =
make_dynamic_tensor_coordinate(in_gemmk_gemmn_global_desc, MultiIndex<2>{{0, 0}}); make_dynamic_tensor_coordinate(in_gemmk_gemmn_global_desc, make_multi_index(0, 0));
for(index_t iter = 0; iter < 10; ++iter) for(index_t iter = 0; iter < 10; ++iter)
{ {
constexpr auto gemmk1_gemmn0 = MultiIndex<2>{{1, 0}}; constexpr auto gemmk1_gemmn0 = make_multi_index(1, 0);
printf("iter %d\n", iter); printf("iter %d\n", iter);
......
...@@ -53,10 +53,10 @@ void device_dummy_dynamic_transform_v2(InDesc, ...@@ -53,10 +53,10 @@ void device_dummy_dynamic_transform_v2(InDesc,
// test on cpu // test on cpu
{ {
auto in_gemmk_gemmn_coord = auto in_gemmk_gemmn_coord =
make_dynamic_tensor_coordinate_v2(in_gemmk_gemmn_global_desc, MultiIndex<2>{{0, 0}}); make_dynamic_tensor_coordinate_v2(in_gemmk_gemmn_global_desc, make_multi_index(0, 0));
const auto in_gemmk_gemmn_coord_step = make_dynamic_tensor_coordinate_step_v2( const auto in_gemmk_gemmn_coord_step = make_dynamic_tensor_coordinate_step_v2(
in_gemmk_gemmn_global_desc, MultiIndex<2>{{1, 0}}); in_gemmk_gemmn_global_desc, make_multi_index(1, 0));
for(index_t iter = 0; iter < 10; ++iter) for(index_t iter = 0; iter < 10; ++iter)
{ {
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment