"git@developer.sourcefind.cn:modelzoo/resnet50_tensorflow.git" did not exist on "9a833e2c455b7cb19dbed7fd2b527270da82e3c2"
Commit 709f13a6 authored by Chao Liu's avatar Chao Liu
Browse files

use more constexpr

parent 498e71b0
...@@ -443,7 +443,7 @@ int main(int argc, char* argv[]) ...@@ -443,7 +443,7 @@ int main(int argc, char* argv[])
constexpr index_t HPad = 0; constexpr index_t HPad = 0;
constexpr index_t WPad = 0; constexpr index_t WPad = 0;
#elif 0 #elif 1
// 3x3 filter, 28x28 image // 3x3 filter, 28x28 image
constexpr index_t N = 128; constexpr index_t N = 128;
constexpr index_t C = 256; constexpr index_t C = 256;
...@@ -455,7 +455,7 @@ int main(int argc, char* argv[]) ...@@ -455,7 +455,7 @@ int main(int argc, char* argv[])
constexpr index_t HPad = 0; constexpr index_t HPad = 0;
constexpr index_t WPad = 0; constexpr index_t WPad = 0;
#elif 1 #elif 0
// 1x1 filter, 28x28 image // 1x1 filter, 28x28 image
constexpr index_t N = 128; constexpr index_t N = 128;
constexpr index_t C = 512; constexpr index_t C = 512;
...@@ -549,12 +549,24 @@ int main(int argc, char* argv[]) ...@@ -549,12 +549,24 @@ int main(int argc, char* argv[])
constexpr index_t Y = 1; constexpr index_t Y = 1;
constexpr index_t X = 1; constexpr index_t X = 1;
constexpr index_t HPad = 0;
constexpr index_t WPad = 0;
#elif 0
// 1x1 filter, 7x7 image
constexpr index_t N = 128;
constexpr index_t C = 512;
constexpr index_t HI = 7;
constexpr index_t WI = 7;
constexpr index_t K = 2048;
constexpr index_t Y = 1;
constexpr index_t X = 1;
constexpr index_t HPad = 0; constexpr index_t HPad = 0;
constexpr index_t WPad = 0; constexpr index_t WPad = 0;
#elif 0 #elif 0
// 1x1 filter, 73x73 image // 1x1 filter, 73x73 image
constexpr index_t N = 128; constexpr index_t N = 128;
constexpr index_t C = 64; constexpr index_t C = 512;
constexpr index_t HI = 73; constexpr index_t HI = 73;
constexpr index_t WI = 73; constexpr index_t WI = 73;
constexpr index_t K = 128; constexpr index_t K = 128;
......
#pragma once #pragma once
#include "Sequence.hip.hpp" #include "Sequence.hip.hpp"
#include "functional.hip.hpp" #include "functional2.hip.hpp"
template <class TData, index_t NSize> template <class TData, index_t NSize>
struct Array struct Array
...@@ -25,14 +25,17 @@ struct Array ...@@ -25,14 +25,17 @@ struct Array
template <index_t I> template <index_t I>
__host__ __device__ constexpr TData Get(Number<I>) const __host__ __device__ constexpr TData Get(Number<I>) const
{ {
static_assert(I < NSize, "wrong!");
return mData[I]; return mData[I];
} }
template <index_t I> template <index_t I>
__host__ __device__ constexpr bool Set(Number<I>, TData x) __host__ __device__ constexpr void Set(Number<I>, TData x)
{ {
static_assert(I < NSize, "wrong!");
mData[I] = x; mData[I] = x;
return true; // for constexpr
} }
__host__ __device__ constexpr auto PushBack(TData x) const __host__ __device__ constexpr auto PushBack(TData x) const
...@@ -59,6 +62,7 @@ __host__ __device__ constexpr auto sequence2array(Sequence<Is...>) ...@@ -59,6 +62,7 @@ __host__ __device__ constexpr auto sequence2array(Sequence<Is...>)
template <class TData, index_t NSize> template <class TData, index_t NSize>
__host__ __device__ constexpr auto make_zero_array() __host__ __device__ constexpr auto make_zero_array()
{ {
#if 0
Array<TData, NSize> a; Array<TData, NSize> a;
static_for<0, NSize, 1>{}([&](auto I) { static_for<0, NSize, 1>{}([&](auto I) {
...@@ -67,6 +71,11 @@ __host__ __device__ constexpr auto make_zero_array() ...@@ -67,6 +71,11 @@ __host__ __device__ constexpr auto make_zero_array()
}); });
return a; return a;
#else
constexpr auto zero_sequence = typename uniform_sequence_gen<NSize, 0>::SeqType{};
constexpr auto zero_array = sequence2array(zero_sequence);
return zero_array;
#endif
} }
template <class TData, index_t NSize, index_t... IRs> template <class TData, index_t NSize, index_t... IRs>
...@@ -85,6 +94,7 @@ __host__ __device__ constexpr auto reorder_array_given_new2old(const Array<TData ...@@ -85,6 +94,7 @@ __host__ __device__ constexpr auto reorder_array_given_new2old(const Array<TData
return new_array; return new_array;
} }
#if 0
template <class TData, index_t NSize, index_t... IRs> template <class TData, index_t NSize, index_t... IRs>
__host__ __device__ constexpr auto reorder_array_given_old2new(const Array<TData, NSize>& old_array, __host__ __device__ constexpr auto reorder_array_given_old2new(const Array<TData, NSize>& old_array,
Sequence<IRs...> old2new) Sequence<IRs...> old2new)
...@@ -100,6 +110,45 @@ __host__ __device__ constexpr auto reorder_array_given_old2new(const Array<TData ...@@ -100,6 +110,45 @@ __host__ __device__ constexpr auto reorder_array_given_old2new(const Array<TData
return new_array; return new_array;
} }
#else
// Functor invoked once per old dimension index by reorder_array_given_old2new.
// Each call copies element IOldDim of the source array into the slot of the
// destination array selected by MapOld2New::Get(Number<IOldDim>{}).
template <class TData, index_t NSize, class MapOld2New>
struct reorder_array_given_old2new_impl
{
    const Array<TData, NSize>& old_array_ref; // source array, only read
    Array<TData, NSize>& new_array_ref;       // destination array, written in place

    __host__ __device__ constexpr reorder_array_given_old2new_impl(
        const Array<TData, NSize>& old_array, Array<TData, NSize>& new_array)
        : old_array_ref(old_array), new_array_ref(new_array)
    {
    }

    // Moves a single element; IOldDim is its position in the source array.
    template <index_t IOldDim>
    __host__ __device__ constexpr void operator()(Number<IOldDim>) const
    {
        // fetch the value first, then resolve its destination at compile time
        TData moved_value = old_array_ref.Get(Number<IOldDim>{});

        constexpr index_t target_dim = MapOld2New::Get(Number<IOldDim>{});

        new_array_ref.Set(Number<target_dim>{}, moved_value);
    }
};
// Builds a reordered copy of old_array: for every index i in [0, NSize) the
// element old_array[i] is stored at the position Sequence<IRs...>::Get(Number<i>{})
// of the returned array.
// old2new: Sequence<IRs...>; only its type is used, and it must hold NSize entries
template <class TData, index_t NSize, index_t... IRs>
__host__ __device__ constexpr auto reorder_array_given_old2new(const Array<TData, NSize>& old_array,
                                                               Sequence<IRs...> old2new)
{
    static_assert(NSize == sizeof...(IRs), "NSize not consistent");

    using CopyOneElement = reorder_array_given_old2new_impl<TData, NSize, Sequence<IRs...>>;

    Array<TData, NSize> new_array;

    // visit every old index; the functor writes straight into new_array
    static_for<0, NSize, 1>{}(CopyOneElement(old_array, new_array));

    return new_array;
}
#endif
template <class TData, index_t NSize, class ExtractSeq> template <class TData, index_t NSize, class ExtractSeq>
__host__ __device__ constexpr auto extract_array(const Array<TData, NSize>& old_array, ExtractSeq) __host__ __device__ constexpr auto extract_array(const Array<TData, NSize>& old_array, ExtractSeq)
......
...@@ -115,15 +115,13 @@ struct ConstantMergedTensorDescriptor ...@@ -115,15 +115,13 @@ struct ConstantMergedTensorDescriptor
} }
template <index_t I> template <index_t I>
constexpr __host__ __device__ bool operator()(Number<I>) const __host__ __device__ constexpr void operator()(Number<I>) const
{ {
constexpr index_t idim_original = OriginalDimsPartial::Get(Number<I>{}); constexpr index_t idim_original = OriginalDimsPartial::Get(Number<I>{});
index_t itmp = original_multi_id_partial_ref.Get(Number<I>{}); index_t itmp = original_multi_id_partial_ref.Get(Number<I>{});
original_multi_id_ref.Set(Number<idim_original>{}, itmp); original_multi_id_ref.Set(Number<idim_original>{}, itmp);
return true;
} }
}; };
...@@ -139,7 +137,7 @@ struct ConstantMergedTensorDescriptor ...@@ -139,7 +137,7 @@ struct ConstantMergedTensorDescriptor
} }
template <index_t IDim> template <index_t IDim>
constexpr __host__ __device__ bool operator()(Number<IDim>) const __host__ __device__ constexpr void operator()(Number<IDim>) const
{ {
constexpr auto original_dims_partial = constexpr auto original_dims_partial =
std::get<IDim>(std::tuple<OriginalDimMergeSeqs...>{}); std::get<IDim>(std::tuple<OriginalDimMergeSeqs...>{});
...@@ -152,11 +150,10 @@ struct ConstantMergedTensorDescriptor ...@@ -152,11 +150,10 @@ struct ConstantMergedTensorDescriptor
static_for<0, original_dims_partial.GetSize(), 1>{}( static_for<0, original_dims_partial.GetSize(), 1>{}(
GetOriginalMultiIndexFromMultiIndex_impl1<decltype(original_dims_partial)>( GetOriginalMultiIndexFromMultiIndex_impl1<decltype(original_dims_partial)>(
original_multi_id_partial, original_multi_id_ref)); original_multi_id_partial, original_multi_id_ref));
return true;
} }
}; };
// return type is Array<...>
__host__ __device__ static constexpr auto __host__ __device__ static constexpr auto
GetOriginalMultiIndexFromMultiIndex(Array<index_t, nDim> multi_id) GetOriginalMultiIndexFromMultiIndex(Array<index_t, nDim> multi_id)
{ {
...@@ -179,16 +176,6 @@ struct ConstantMergedTensorDescriptor ...@@ -179,16 +176,6 @@ struct ConstantMergedTensorDescriptor
} }
#endif #endif
#if 0
// return type is Sequence<...>
template <index_t... Is>
__host__ __device__ static constexpr auto GetOriginalMultiIndexFromMultiIndex(Sequence<Is...>)
{
// not implemented
return Sequence<>{};
}
#endif
__host__ __device__ static constexpr index_t __host__ __device__ static constexpr index_t
GetOffsetFromMultiIndex(Array<index_t, nDim> multi_id) GetOffsetFromMultiIndex(Array<index_t, nDim> multi_id)
{ {
......
...@@ -37,10 +37,11 @@ struct Sequence ...@@ -37,10 +37,11 @@ struct Sequence
template <class MapOld2New> template <class MapOld2New>
__host__ __device__ static constexpr auto ReorderGivenOld2New(MapOld2New /*old2new*/) __host__ __device__ static constexpr auto ReorderGivenOld2New(MapOld2New /*old2new*/)
{ {
#if 0
static_assert(is_same<sequence_sort<MapOld2New>::SortedSeqType, static_assert(is_same<sequence_sort<MapOld2New>::SortedSeqType,
arithmetic_sequence_gen<0, mSize, 1>::SeqType>::value, arithmetic_sequence_gen<0, mSize, 1>::SeqType>::value,
"wrong! invalid old2new map"); "wrong! invalid old2new map");
#endif
constexpr auto map_new2old = typename sequence_map_inverse<MapOld2New>::SeqMapType{}; constexpr auto map_new2old = typename sequence_map_inverse<MapOld2New>::SeqMapType{};
return ReorderGivenNew2Old(map_new2old); return ReorderGivenNew2Old(map_new2old);
...@@ -99,6 +100,7 @@ struct Sequence ...@@ -99,6 +100,7 @@ struct Sequence
__host__ __device__ static constexpr auto Modify(Number<I>, Number<X>); __host__ __device__ static constexpr auto Modify(Number<I>, Number<X>);
}; };
// merge sequence
template <class, class> template <class, class>
struct sequence_merge; struct sequence_merge;
...@@ -108,6 +110,7 @@ struct sequence_merge<Sequence<Xs...>, Sequence<Ys...>> ...@@ -108,6 +110,7 @@ struct sequence_merge<Sequence<Xs...>, Sequence<Ys...>>
using SeqType = Sequence<Xs..., Ys...>; using SeqType = Sequence<Xs..., Ys...>;
}; };
// arithmetic sequence
template <index_t IBegin, index_t NSize, index_t Increment> template <index_t IBegin, index_t NSize, index_t Increment>
struct arithmetic_sequence_gen_impl struct arithmetic_sequence_gen_impl
{ {
...@@ -139,7 +142,31 @@ struct arithmetic_sequence_gen ...@@ -139,7 +142,31 @@ struct arithmetic_sequence_gen
typename arithmetic_sequence_gen_impl<IBegin, IEnd - IBegin, Increment>::SeqType; typename arithmetic_sequence_gen_impl<IBegin, IEnd - IBegin, Increment>::SeqType;
}; };
// reverse scan with init // transform sequence
template <class, class>
struct sequence_transform;
template <class F, index_t... Is>
struct sequence_transform<F, Sequence<Is...>>
{
using SeqType = Sequence<F{}(Is)...>;
};
// uniform sequence
template <index_t NSize, index_t I>
struct uniform_sequence_gen
{
struct return_constant
{
__host__ __device__ constexpr index_t operator()(index_t) const { return I; }
};
using SeqType = typename sequence_transform<
return_constant,
typename arithmetic_sequence_gen<0, NSize, 1>::SeqType>::SeqType;
};
// reverse inclusive scan (with init) sequence
template <class, class, index_t> template <class, class, index_t>
struct sequence_reverse_inclusive_scan; struct sequence_reverse_inclusive_scan;
...@@ -166,22 +193,7 @@ struct sequence_reverse_inclusive_scan<Sequence<>, Reduce, Init> ...@@ -166,22 +193,7 @@ struct sequence_reverse_inclusive_scan<Sequence<>, Reduce, Init>
using SeqType = Sequence<>; using SeqType = Sequence<>;
}; };
#if 0 // extract sequence
// reverse scan with token
template <class, class, index_t>
struct sequence_reverse_inclusive_token_scan;
template <index_t I, index_t... Is, class F, index_t Token>
struct sequence_reverse_inclusive_token_scan<Sequence<I, Is...>, F, Token>
{
using old_scan = typename sequence_reverse_inclusive_token_scan<Sequence<Is...>, F, Token>::SeqType;
static constexpr index_t new_reduce = Reduce{}(I, old_scan{}.Front());
using SeqType = typename sequence_merge<Sequence<new_reduce>, old_scan>::SeqType;
};
#endif
template <class, class> template <class, class>
struct sequence_extract; struct sequence_extract;
...@@ -191,6 +203,7 @@ struct sequence_extract<Seq, Sequence<Is...>> ...@@ -191,6 +203,7 @@ struct sequence_extract<Seq, Sequence<Is...>>
using SeqType = Sequence<Seq{}.Get(Number<Is>{})...>; using SeqType = Sequence<Seq{}.Get(Number<Is>{})...>;
}; };
// split sequence
template <class Seq, index_t I> template <class Seq, index_t I>
struct sequence_split struct sequence_split
{ {
...@@ -203,6 +216,7 @@ struct sequence_split ...@@ -203,6 +216,7 @@ struct sequence_split
using SeqType1 = typename sequence_extract<Seq, range1>::SeqType; using SeqType1 = typename sequence_extract<Seq, range1>::SeqType;
}; };
// reverse sequence
template <class Seq> template <class Seq>
struct sequence_reverse struct sequence_reverse
{ {
...@@ -308,8 +322,10 @@ __host__ __device__ constexpr auto operator-(Sequence<Xs...> seq_x, Sequence<Ys. ...@@ -308,8 +322,10 @@ __host__ __device__ constexpr auto operator-(Sequence<Xs...> seq_x, Sequence<Ys.
{ {
static_assert(sizeof...(Xs) == sizeof...(Ys), "wrong! inconsistent size"); static_assert(sizeof...(Xs) == sizeof...(Ys), "wrong! inconsistent size");
#if 0
static_for<0, seq_x.GetSize(), 1>{}( static_for<0, seq_x.GetSize(), 1>{}(
[&](auto I) { static_assert(seq_x.Get(I) >= seq_y.Get(I), "wrong! going to undeflow"); }); [&](auto I) { static_assert(seq_x.Get(I) >= seq_y.Get(I), "wrong! going to undeflow"); });
#endif
return Sequence<(Xs - Ys)...>{}; return Sequence<(Xs - Ys)...>{};
} }
...@@ -388,10 +404,12 @@ __host__ __device__ constexpr auto operator-(Number<Y>, Sequence<Xs...>) ...@@ -388,10 +404,12 @@ __host__ __device__ constexpr auto operator-(Number<Y>, Sequence<Xs...>)
{ {
constexpr auto seq_x = Sequence<Xs...>{}; constexpr auto seq_x = Sequence<Xs...>{};
#if 0
static_for<0, sizeof...(Xs), 1>{}([&](auto Iter) { static_for<0, sizeof...(Xs), 1>{}([&](auto Iter) {
constexpr auto I = decltype(Iter){}; constexpr auto I = decltype(Iter){};
static_assert(seq_x.Get(I) <= Y, "wrong! going to underflow"); static_assert(seq_x.Get(I) <= Y, "wrong! going to underflow");
}); });
#endif
return Sequence<(Y - Xs)...>{}; return Sequence<(Y - Xs)...>{};
} }
......
...@@ -256,6 +256,7 @@ struct BlockwiseGenericTensorSliceCopy_v1 ...@@ -256,6 +256,7 @@ struct BlockwiseGenericTensorSliceCopy_v1
make_ConstantTensorDescriptor_packed(thread_sub_tensor_lengths * repeat_lengths); make_ConstantTensorDescriptor_packed(thread_sub_tensor_lengths * repeat_lengths);
static_ford<decltype(repeat_lengths)>{}([&](auto repeat_multi_id_) { static_ford<decltype(repeat_lengths)>{}([&](auto repeat_multi_id_) {
#if 0
constexpr auto repeat_multi_id = sequence2array(decltype(repeat_multi_id_){}); constexpr auto repeat_multi_id = sequence2array(decltype(repeat_multi_id_){});
const auto clipboard_data_multi_id_begin = const auto clipboard_data_multi_id_begin =
...@@ -269,6 +270,18 @@ struct BlockwiseGenericTensorSliceCopy_v1 ...@@ -269,6 +270,18 @@ struct BlockwiseGenericTensorSliceCopy_v1
const index_t dst_offset = DstDesc{}.GetOffsetFromMultiIndex( const index_t dst_offset = DstDesc{}.GetOffsetFromMultiIndex(
dst_data_multi_id_begin); // cannot not constexpr, why? dst_data_multi_id_begin); // cannot not constexpr, why?
#else
constexpr auto clipboard_data_multi_id_begin =
repeat_multi_id_ * thread_sub_tensor_lengths;
constexpr auto dst_data_multi_id_begin = repeat_multi_id_ * data_per_cluster_per_dims;
constexpr index_t clipboard_offset =
thread_tensor_desc.GetOffsetFromMultiIndex(clipboard_data_multi_id_begin);
constexpr index_t dst_offset =
DstDesc{}.GetOffsetFromMultiIndex(dst_data_multi_id_begin);
#endif
threadwise_generic_tensor_slice_copy_v1(thread_tensor_desc, threadwise_generic_tensor_slice_copy_v1(thread_tensor_desc,
p_clipboard + clipboard_offset, p_clipboard + clipboard_offset,
......
...@@ -5,6 +5,7 @@ ...@@ -5,6 +5,7 @@
#include "Array.hip.hpp" #include "Array.hip.hpp"
#include "functional.hip.hpp" #include "functional.hip.hpp"
#include "functional2.hip.hpp" #include "functional2.hip.hpp"
#include "functional3.hip.hpp"
#if USE_AMD_INLINE_ASM #if USE_AMD_INLINE_ASM
#include "amd_inline_asm.hip.hpp" #include "amd_inline_asm.hip.hpp"
......
#pragma once #pragma once
#include "integral_constant.hip.hpp" #include "integral_constant.hip.hpp"
#include "Sequence.hip.hpp"
struct forwarder struct forwarder
{ {
...@@ -10,6 +11,14 @@ struct forwarder ...@@ -10,6 +11,14 @@ struct forwarder
} }
}; };
// Accepts any number of arguments of any type and discards all of them.
// The constructor body is empty; only the evaluation of the argument
// expressions themselves has any effect.
struct swallow
{
    template <class... Args>
    __host__ __device__ constexpr swallow(Args&&...)
    {
    }
};
#if 0 #if 0
template<class F> template<class F>
__host__ __device__ constexpr auto unpacker(F f) __host__ __device__ constexpr auto unpacker(F f)
...@@ -72,51 +81,6 @@ struct static_if<false> ...@@ -72,51 +81,6 @@ struct static_if<false>
return Type{}; return Type{};
} }
}; };
template <index_t Iter, index_t Remaining, index_t Increment>
struct static_for_impl
{
template <class F>
constexpr __host__ __device__ void operator()(F f) const
{
static_assert(Remaining % Increment == 0, "wrong! Remaining % Increment != 0");
static_assert(Increment <= Remaining, "will go out-of-range");
f(Number<Iter>{});
static_for_impl<Iter + Increment, Remaining - Increment, Increment>{}(f);
}
};
template <index_t Iter, index_t Increment>
struct static_for_impl<Iter, 0, Increment>
{
template <class F>
constexpr __host__ __device__ void operator()(F) const
{
// no work left, just return
return;
}
};
// F signature: F(Number<Iter>)
template <index_t NBegin, index_t NEnd, index_t Increment>
struct static_for
{
template <class F>
constexpr __host__ __device__ void operator()(F f) const
{
static_assert(NBegin <= NEnd, "wrongs! should have NBegin <= NEnd");
static_assert((NEnd - NBegin) % Increment == 0,
"Wrong! should satisfy (NEnd - NBegin) % Increment == 0");
#if 0
static_if<(NBegin < NEnd)>{}(
[&](auto fwd) { static_for_impl<NBegin, NEnd - NBegin, fwd(Increment)>{}(f); });
#else
static_for_impl<NBegin, NEnd - NBegin, Increment>{}(f);
#endif
}
};
template <index_t NLoop> template <index_t NLoop>
struct static_const_reduce_n struct static_const_reduce_n
......
#pragma once #pragma once
#include "functional.hip.hpp"
#include "Sequence.hip.hpp" #include "Sequence.hip.hpp"
// RemainLengths: Sequence<...> #if 0
template <class RemainLengths> template <index_t Iter, index_t Remaining, index_t Increment>
struct static_ford_impl struct static_for_impl
{ {
// F signature: F(Sequence<...> multi_id) template <class F>
// CurrentMultiIndex: Sequence<...> constexpr __host__ __device__ void operator()(F f) const
template <class F, class CurrentMultiIndex>
__host__ __device__ void operator()(F f, CurrentMultiIndex) const
{ {
static_assert(RemainLengths::GetSize() > 0, "wrong! should not get here"); static_assert(Remaining % Increment == 0, "wrong! Remaining % Increment != 0");
static_assert(Increment <= Remaining, "will go out-of-range");
static_for<0, RemainLengths::Front(), 1>{}([=](auto I) {
static_ford_impl<decltype(RemainLengths::PopFront())>{}(f,
CurrentMultiIndex::PushBack(I));
});
}
};
template <> f(Number<Iter>{});
struct static_ford_impl<Sequence<>> static_for_impl<Iter + Increment, Remaining - Increment, Increment>{}(f);
{
// F signature: F(Sequence<...> multi_id)
// CurrentMultiIndex: Sequence<...>
template <class F, class CurrentMultiIndex>
__host__ __device__ void operator()(F f, CurrentMultiIndex) const
{
f(CurrentMultiIndex{});
} }
}; };
// Lengths is Sequence<...> template <index_t Iter, index_t Increment>
template <class Lengths> struct static_for_impl<Iter, 0, Increment>
struct static_ford
{ {
// F signature: F(Sequence<...> multi_id)
template <class F> template <class F>
__host__ __device__ void operator()(F f) const constexpr __host__ __device__ void operator()(F) const
{ {
static_assert(Lengths::GetSize() > 0, "wrong! Lengths is empty"); // no work left, just return
return;
static_ford_impl<Lengths>{}(f, Sequence<>{});
} }
}; };
template <index_t RemainDim> // F signature: F(Number<Iter>)
struct ford_impl template <index_t NBegin, index_t NEnd, index_t Increment>
struct static_for
{ {
// F signature: F(Array<...> multi_id) template <class F>
// CurrentMultiIndex: Array<...> constexpr __host__ __device__ void operator()(F f) const
// RemainLengths: Sequence<...>
template <class F, class CurrentMultiIndex, class RemainLengths>
__host__ __device__ void
operator()(F f, CurrentMultiIndex current_multi_id, RemainLengths) const
{ {
static_assert(RemainLengths::GetSize() == RemainDim, "wrong!"); static_assert(NBegin <= NEnd, "wrongs! should have NBegin <= NEnd");
static_assert(RemainDim > 1, "wrong!");
constexpr auto next_length = RemainLengths{}.Front(); static_assert((NEnd - NBegin) % Increment == 0,
"Wrong! should satisfy (NEnd - NBegin) % Increment == 0");
for(index_t i = 0; i < next_length; ++i) #if 0
{ static_if<(NBegin < NEnd)>{}(
ford_impl<RemainDim - 1>{}(f, current_multi_id.PushBack(i), RemainLengths{}.PopFront()); [&](auto fwd) { static_for_impl<NBegin, NEnd - NBegin, fwd(Increment)>{}(f); });
} #else
static_for_impl<NBegin, NEnd - NBegin, Increment>{}(f);
#endif
} }
}; };
#else
template <class>
struct static_for_impl;
template <> template <index_t... Is>
struct ford_impl<1> struct static_for_impl<Sequence<Is...>>
{ {
// F signature: F(Array<...> multi_id) template <class F>
// CurrentMultiIndex: Array<...> __host__ __device__ constexpr void operator()(F f) const
// RemainLengths: Sequence<...>
template <class F, class CurrentMultiIndex, class RemainLengths>
__host__ __device__ void
operator()(F f, CurrentMultiIndex current_multi_id, RemainLengths) const
{ {
static_assert(RemainLengths::GetSize() == 1, "wrong!"); swallow{(f(Number<Is>{}), 0)...};
constexpr index_t last_length = RemainLengths{}.Front();
for(index_t i = 0; i < last_length; ++i)
{
f(current_multi_id.PushBack(i));
}
} }
}; };
// Lengths is Sequence<...> // F signature: F(Number<Iter>)
template <class Lengths> template <index_t NBegin, index_t NEnd, index_t Increment>
struct ford struct static_for
{ {
// F signature: F(Array<...> multi_id)
template <class F> template <class F>
__host__ __device__ void operator()(F f) const __host__ __device__ constexpr void operator()(F f) const
{ {
constexpr index_t first_length = Lengths{}.Front(); static_assert(NBegin <= NEnd, "wrongs! should have NBegin <= NEnd");
static_assert((NEnd - NBegin) % Increment == 0,
"Wrong! should satisfy (NEnd - NBegin) % Increment == 0");
for(index_t i = 0; i < first_length; ++i) static_for_impl<typename arithmetic_sequence_gen<NBegin, NEnd, Increment>::SeqType>{}(f);
{
ford_impl<Lengths::GetSize() - 1>{}(f, Array<index_t, 1>{i}, Lengths{}.PopFront());
}
} }
}; };
#endif
#pragma once
#include "functional.hip.hpp"
#include "functional2.hip.hpp"
#include "Sequence.hip.hpp"
#include "Array.hip.hpp"
// Recursive step of static_ford. For every index produced by
// static_for<0, RemainLengths::Front(), 1> it appends that index to the
// multi-index accumulated so far and recurses on RemainLengths::PopFront().
// RemainLengths: Sequence<...>
template <class RemainLengths>
struct static_ford_impl
{
    // F signature: F(Sequence<...> multi_id)
    // CurrentMultiIndex: Sequence<...>, the indices chosen by the outer levels
    template <class F, class CurrentMultiIndex>
    __host__ __device__ void operator()(F f, CurrentMultiIndex) const
    {
        static_assert(RemainLengths::GetSize() > 0, "wrong! should not get here");

        // lengths left over once the current (front) dimension is consumed
        using InnerLengths = decltype(RemainLengths::PopFront());

        static_for<0, RemainLengths::Front(), 1>{}([=](auto idim) {
            static_ford_impl<InnerLengths>{}(f, CurrentMultiIndex::PushBack(idim));
        });
    }
};
// End of the static_ford recursion: no lengths remain, so the accumulated
// multi-index is complete and is handed to f.
template <>
struct static_ford_impl<Sequence<>>
{
    // F signature: F(Sequence<...> multi_id)
    // CurrentMultiIndex: Sequence<...>, passed to f as a default-constructed value
    template <class F, class CurrentMultiIndex>
    __host__ __device__ void operator()(F f, CurrentMultiIndex /*multi_id*/) const
    {
        f(CurrentMultiIndex{});
    }
};
// Entry point of the compile-time multi-dimensional loop: starts the
// static_ford_impl recursion over Lengths with an empty multi-index.
// Lengths is Sequence<...> and must not be empty
template <class Lengths>
struct static_ford
{
    // F signature: F(Sequence<...> multi_id)
    template <class F>
    __host__ __device__ void operator()(F f) const
    {
        static_assert(Lengths::GetSize() > 0, "wrong! Lengths is empty");

        using EmptyMultiIndex = Sequence<>;

        static_ford_impl<Lengths>{}(f, EmptyMultiIndex{});
    }
};
// Recursive step of ford for RemainDim > 1: runs a run-time loop over the
// front entry of RemainLengths, appends the loop index to the multi-index
// built so far and recurses on the remaining lengths.
template <index_t RemainDim>
struct ford_impl
{
    // F signature: F(Array<...> multi_id)
    // CurrentMultiIndex: Array<...>
    // RemainLengths: Sequence<...>
    template <class F, class CurrentMultiIndex, class RemainLengths>
    __host__ __device__ void
    operator()(F f, CurrentMultiIndex current_multi_id, RemainLengths) const
    {
        static_assert(RemainLengths::GetSize() == RemainDim, "wrong!");
        static_assert(RemainDim > 1, "wrong!");

        constexpr auto next_length = RemainLengths{}.Front();

        for(index_t i = 0; i < next_length; ++i)
        {
            ford_impl<RemainDim - 1>{}(f, current_multi_id.PushBack(i), RemainLengths{}.PopFront());
        }
    }
};

// Innermost loop of ford: exactly one length remains, so each iteration
// completes the multi-index and calls f on it.
template <>
struct ford_impl<1>
{
    // F signature: F(Array<...> multi_id)
    // CurrentMultiIndex: Array<...>
    // RemainLengths: Sequence<...>
    template <class F, class CurrentMultiIndex, class RemainLengths>
    __host__ __device__ void
    operator()(F f, CurrentMultiIndex current_multi_id, RemainLengths) const
    {
        static_assert(RemainLengths::GetSize() == 1, "wrong!");

        constexpr index_t last_length = RemainLengths{}.Front();

        for(index_t i = 0; i < last_length; ++i)
        {
            f(current_multi_id.PushBack(i));
        }
    }
};

// Terminal case, reached only from ford when Lengths has a single dimension:
// ford's own loop already produced the complete one-element multi-index, so
// nothing is left to iterate over and f is called directly. Without this
// specialization a 1-D ford instantiates the primary template with
// RemainDim == 0 and fails its static_assert(RemainDim > 1).
template <>
struct ford_impl<0>
{
    // F signature: F(Array<...> multi_id)
    // CurrentMultiIndex: Array<...>
    // RemainLengths: Sequence<...>, expected to be empty here
    template <class F, class CurrentMultiIndex, class RemainLengths>
    __host__ __device__ void
    operator()(F f, CurrentMultiIndex current_multi_id, RemainLengths) const
    {
        static_assert(RemainLengths::GetSize() == 0, "wrong!");

        f(current_multi_id);
    }
};

// Run-time multi-dimensional loop: calls f once for every multi-index inside
// Lengths, with the first dimension iterated by the outermost loop.
// Lengths is Sequence<...> and must not be empty
template <class Lengths>
struct ford
{
    // F signature: F(Array<...> multi_id)
    template <class F>
    __host__ __device__ void operator()(F f) const
    {
        // same precondition as static_ford; gives a clear message for empty Lengths
        static_assert(Lengths::GetSize() > 0, "wrong! Lengths is empty");

        constexpr index_t first_length = Lengths{}.Front();

        for(index_t i = 0; i < first_length; ++i)
        {
            // for 1-D Lengths this dispatches to ford_impl<0>, which just calls f
            ford_impl<Lengths::GetSize() - 1>{}(f, Array<index_t, 1>{i}, Lengths{}.PopFront());
        }
    }
};
...@@ -246,7 +246,7 @@ struct GridwiseConvolutionImplicitGemm_v4_lds_double_buffer_nchw_kcyx_nkhw ...@@ -246,7 +246,7 @@ struct GridwiseConvolutionImplicitGemm_v4_lds_double_buffer_nchw_kcyx_nkhw
// choose GEMM implementation here // choose GEMM implementation here
const auto run_blockwise_gemm = [&](auto... Xs) { const auto run_blockwise_gemm = [&](auto... Xs) {
#if 0 #if 1
return blockwise_gemm.Run(Xs...); return blockwise_gemm.Run(Xs...);
#else #else
return blockwise_gemm.Run_asm(Xs...); return blockwise_gemm.Run_asm(Xs...);
......
...@@ -77,14 +77,14 @@ __device__ void threadwise_generic_tensor_slice_copy_v1( ...@@ -77,14 +77,14 @@ __device__ void threadwise_generic_tensor_slice_copy_v1(
*reinterpret_cast<const vector_t*>(&p_src[src_index]); *reinterpret_cast<const vector_t*>(&p_src[src_index]);
}); });
#else #else
static_ford<decltype(access_lengths)>{}([&](auto access_multi_id_) { static_ford<decltype(access_lengths)>{}([&](auto access_multi_id) {
const auto access_multi_id = sequence2array(access_multi_id_); constexpr index_t itmp = access_multi_id.Back() * DataPerAccess;
auto data_multi_id_in_access_order = access_multi_id; constexpr auto data_multi_id_in_access_order =
data_multi_id_in_access_order[nDim - 1] = access_multi_id[nDim - 1] * DataPerAccess; access_multi_id.Modify(Number<nDim - 1>{}, Number<itmp>{});
const auto data_multi_id = constexpr auto data_multi_id = reorder_array_given_old2new(
reorder_array_given_old2new(data_multi_id_in_access_order, DimAccessOrder{}); sequence2array(data_multi_id_in_access_order), DimAccessOrder{});
const index_t src_index = const index_t src_index =
SrcDesc::GetOffsetFromMultiIndex(src_multi_id_begin + data_multi_id); SrcDesc::GetOffsetFromMultiIndex(src_multi_id_begin + data_multi_id);
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment