"test/git@developer.sourcefind.cn:hehl2/torchaudio.git" did not exist on "6b242c29af317498ea51b465529cf8e68c2c88fd"
Commit 9d59a39a authored by Chao Liu's avatar Chao Liu
Browse files

refactoring

parent 33d1e0e2
......@@ -440,7 +440,7 @@ struct GridwiseConvolutionImplicitGemm_v1r3_nchw_cyxk_nkhw
wo_block_data_begin + wo_thread_data_begin),
make_zero_array<index_t, 10>(),
out_10d_thread_desc.GetLengths().ReorderGivenNew2Old(map_out_global2thread),
arithmetic_sequence_gen<0, 10, 1>::SeqType{},
arithmetic_sequence_gen<0, 10, 1>::type{},
Number<1>{});
#endif
});
......
......@@ -491,7 +491,7 @@ struct GridwiseConvolutionImplicitGemm_v1r3_nchw_cyxk_nkhw_lds_double_buffer
wo_block_data_begin + wo_thread_data_begin),
make_zero_array<index_t, 10>(),
out_10d_thread_desc.GetLengths().ReorderGivenNew2Old(map_out_global2thread),
arithmetic_sequence_gen<0, 10, 1>::SeqType{},
arithmetic_sequence_gen<0, 10, 1>::type{},
Number<1>{});
#endif
});
......
......@@ -367,7 +367,7 @@ struct GridwiseConvolutionImplicitGemm_v3_nchw_cyxk_nkhw
p_out_thread_on_global,
{0, 0, 0, 0, 0, 0, 0, 0},
out_n0_n1_n2_k0_k1_k2_h_w_thread_desc.GetLengths(),
arithmetic_sequence_gen<0, 8, 1>::SeqType{},
arithmetic_sequence_gen<0, 8, 1>::type{},
Number<1>{});
}
}
......
......@@ -394,7 +394,7 @@ struct GridwiseConvolutionImplicitGemm_v3_nchw_cyxk_nkhw_lds_double_buffer
p_out_thread_on_global,
{0, 0, 0, 0, 0, 0, 0, 0},
out_n0_n1_n2_k0_k1_k2_h_w_thread_desc.GetLengths(),
arithmetic_sequence_gen<0, 8, 1>::SeqType{},
arithmetic_sequence_gen<0, 8, 1>::type{},
Number<1>{});
}
}
......
......@@ -344,7 +344,7 @@ struct GridwiseConvolutionImplicitGemm_v4_nchw_kcyx_nkhw
p_out_thread_on_global,
{0, 0, 0, 0, 0, 0, 0, 0},
out_n0_n1_n2_k0_k1_k2_h_w_thread_desc.GetLengths(),
arithmetic_sequence_gen<0, 8, 1>::SeqType{},
arithmetic_sequence_gen<0, 8, 1>::type{},
Number<1>{});
}
}
......
......@@ -398,7 +398,7 @@ struct GridwiseConvolutionImplicitGemm_v4_nchw_kcyx_nkhw_lds_double_buffer
p_out_thread_on_global,
{0, 0, 0, 0, 0, 0, 0, 0},
out_n0_n1_n2_k0_k1_k2_h_w_thread_desc.GetLengths(),
arithmetic_sequence_gen<0, 8, 1>::SeqType{},
arithmetic_sequence_gen<0, 8, 1>::type{},
Number<1>{});
}
}
......
......@@ -305,8 +305,9 @@ struct ConstantTensorDescriptor
{
using leaf_tensor = ConstantTensorDescriptor<Ts...>;
return ConstantTensorDescriptor<decltype(GetLengths().Append(leaf_tensor::GetLengths())),
decltype(GetStrides().Append(leaf_tensor::GetStrides()))>{};
return ConstantTensorDescriptor<decltype(GetLengths().PushBack(leaf_tensor::GetLengths())),
decltype(
GetStrides().PushBack(leaf_tensor::GetStrides()))>{};
}
template <index_t IDim, index_t SliceLen>
......@@ -347,7 +348,7 @@ struct ConstantTensorDescriptor
// folded lengths
constexpr auto fold_lengths =
Sequence<unfold_length / fold_intervals_product>{}.Append(fold_intervals);
Sequence<unfold_length / fold_intervals_product>{}.PushBack(fold_intervals);
// folded strides
constexpr auto fold_strides =
......@@ -356,14 +357,14 @@ struct ConstantTensorDescriptor
fold_intervals.PushBack(Number<1>{}), math::multiplies<index_t>{}, Number<1>{});
// left and right
constexpr auto left = typename arithmetic_sequence_gen<0, IDim, 1>::SeqType{};
constexpr auto left = typename arithmetic_sequence_gen<0, IDim, 1>::type{};
constexpr auto right =
typename arithmetic_sequence_gen<IDim + 1, GetNumOfDimension(), 1>::SeqType{};
typename arithmetic_sequence_gen<IDim + 1, GetNumOfDimension(), 1>::type{};
constexpr auto new_lengths =
GetLengths().Extract(left).Append(fold_lengths).Append(GetLengths().Extract(right));
GetLengths().Extract(left).PushBack(fold_lengths).PushBack(GetLengths().Extract(right));
constexpr auto new_strides =
GetStrides().Extract(left).Append(fold_strides).Append(GetStrides().Extract(right));
GetStrides().Extract(left).PushBack(fold_strides).PushBack(GetStrides().Extract(right));
return ConstantTensorDescriptor<decltype(new_lengths), decltype(new_strides)>{};
}
......@@ -377,11 +378,11 @@ struct ConstantTensorDescriptor
"wrong! should have FirstUnfoldDim <= LastUnfoldDim!");
// left and right
constexpr auto left = typename arithmetic_sequence_gen<0, FirstUnfoldDim, 1>::SeqType{};
constexpr auto left = typename arithmetic_sequence_gen<0, FirstUnfoldDim, 1>::type{};
constexpr auto middle =
typename arithmetic_sequence_gen<FirstUnfoldDim, LastUnfoldDim + 1, 1>::SeqType{};
typename arithmetic_sequence_gen<FirstUnfoldDim, LastUnfoldDim + 1, 1>::type{};
constexpr auto right =
typename arithmetic_sequence_gen<LastUnfoldDim + 1, GetNumOfDimension(), 1>::SeqType{};
typename arithmetic_sequence_gen<LastUnfoldDim + 1, GetNumOfDimension(), 1>::type{};
// dimensions to be unfolded need to be continuous
static_assert(Type::Extract(middle).AreDimensionsContinuous(), "wrong! not unfoldable");
......@@ -396,12 +397,12 @@ struct ConstantTensorDescriptor
constexpr auto new_lengths = GetLengths()
.Extract(left)
.PushBack(Number<unfold_length>{})
.Append(GetLengths().Extract(right));
.PushBack(GetLengths().Extract(right));
constexpr auto new_strides = GetStrides()
.Extract(left)
.PushBack(Number<unfold_stride>{})
.Append(GetStrides().Extract(right));
.PushBack(GetStrides().Extract(right));
return ConstantTensorDescriptor<decltype(new_lengths), decltype(new_strides)>{};
}
......
......@@ -87,7 +87,7 @@ __host__ __device__ constexpr auto sequence2array(Sequence<Is...>)
template <class TData, index_t NSize>
__host__ __device__ constexpr auto make_zero_array()
{
constexpr auto zero_sequence = typename uniform_sequence_gen<NSize, 0>::SeqType{};
constexpr auto zero_sequence = typename uniform_sequence_gen<NSize, 0>::type{};
constexpr auto zero_array = sequence2array(zero_sequence);
return zero_array;
}
......
......@@ -29,12 +29,9 @@ struct Sequence
}
template <index_t I>
__host__ __device__ constexpr index_t operator[](Number<I>) const
__host__ __device__ constexpr auto operator[](Number<I>) const
{
static_assert(I < mSize, "wrong! I too large");
const index_t mData[mSize + 1] = {Is..., 0};
return mData[I];
return Number<Get(Number<I>{})>{};
}
// make sure I is constexpr
......@@ -69,24 +66,30 @@ struct Sequence
return mData[mSize - 1];
}
template <index_t I>
__host__ __device__ static constexpr auto PushFront(Number<I>)
__host__ __device__ static constexpr auto PopFront();
__host__ __device__ static constexpr auto PopBack();
template <index_t... Xs>
__host__ __device__ static constexpr auto PushFront(Sequence<Xs...>)
{
return Sequence<I, Is...>{};
return Sequence<Xs..., Is...>{};
}
template <index_t I>
__host__ __device__ static constexpr auto PushBack(Number<I>)
template <index_t... Xs>
__host__ __device__ static constexpr auto PushFront(Number<Xs>...)
{
return Sequence<Is..., I>{};
return Sequence<Xs..., Is...>{};
}
__host__ __device__ static constexpr auto PopFront();
__host__ __device__ static constexpr auto PopBack();
template <index_t... Xs>
__host__ __device__ static constexpr auto PushBack(Sequence<Xs...>)
{
return Sequence<Is..., Xs...>{};
}
template <index_t... Xs>
__host__ __device__ static constexpr auto Append(Sequence<Xs...>)
__host__ __device__ static constexpr auto PushBack(Number<Xs>...)
{
return Sequence<Is..., Xs...>{};
}
......@@ -105,6 +108,12 @@ struct Sequence
template <index_t I, index_t X>
__host__ __device__ static constexpr auto Modify(Number<I>, Number<X>);
template <class F>
__host__ __device__ static constexpr auto Transform(F f)
{
return Sequence<f(Is)...>{};
}
};
// merge sequence
......@@ -114,7 +123,7 @@ struct sequence_merge;
template <index_t... Xs, index_t... Ys>
struct sequence_merge<Sequence<Xs...>, Sequence<Ys...>>
{
using SeqType = Sequence<Xs..., Ys...>;
using type = Sequence<Xs..., Ys...>;
};
// arithmetic sequence
......@@ -123,40 +132,29 @@ struct arithmetic_sequence_gen_impl
{
static constexpr index_t NSizeLeft = NSize / 2;
using SeqType = typename sequence_merge<
typename arithmetic_sequence_gen_impl<IBegin, NSizeLeft, Increment>::SeqType,
using type = typename sequence_merge<
typename arithmetic_sequence_gen_impl<IBegin, NSizeLeft, Increment>::type,
typename arithmetic_sequence_gen_impl<IBegin + NSizeLeft * Increment,
NSize - NSizeLeft,
Increment>::SeqType>::SeqType;
Increment>::type>::type;
};
template <index_t IBegin, index_t Increment>
struct arithmetic_sequence_gen_impl<IBegin, 1, Increment>
{
using SeqType = Sequence<IBegin>;
using type = Sequence<IBegin>;
};
template <index_t IBegin, index_t Increment>
struct arithmetic_sequence_gen_impl<IBegin, 0, Increment>
{
using SeqType = Sequence<>;
using type = Sequence<>;
};
template <index_t IBegin, index_t IEnd, index_t Increment>
struct arithmetic_sequence_gen
{
using SeqType =
typename arithmetic_sequence_gen_impl<IBegin, IEnd - IBegin, Increment>::SeqType;
};
// transform sequence
template <class, class>
struct sequence_transform;
template <class F, index_t... Is>
struct sequence_transform<F, Sequence<Is...>>
{
using SeqType = Sequence<F{}(Is)...>;
using type = typename arithmetic_sequence_gen_impl<IBegin, IEnd - IBegin, Increment>::type;
};
// uniform sequence
......@@ -168,9 +166,8 @@ struct uniform_sequence_gen
__host__ __device__ constexpr index_t operator()(index_t) const { return I; }
};
using SeqType = typename sequence_transform<
return_constant,
typename arithmetic_sequence_gen<0, NSize, 1>::SeqType>::SeqType;
using type = decltype(
typename arithmetic_sequence_gen<0, NSize, 1>::type{}.Transform(return_constant{}));
};
// reverse inclusive scan (with init) sequence
......@@ -180,34 +177,23 @@ struct sequence_reverse_inclusive_scan;
template <index_t I, index_t... Is, class Reduce, index_t Init>
struct sequence_reverse_inclusive_scan<Sequence<I, Is...>, Reduce, Init>
{
using old_scan =
typename sequence_reverse_inclusive_scan<Sequence<Is...>, Reduce, Init>::SeqType;
using old_scan = typename sequence_reverse_inclusive_scan<Sequence<Is...>, Reduce, Init>::type;
static constexpr index_t new_reduce = Reduce{}(I, old_scan{}.Front());
using SeqType = typename sequence_merge<Sequence<new_reduce>, old_scan>::SeqType;
using type = typename sequence_merge<Sequence<new_reduce>, old_scan>::type;
};
template <index_t I, class Reduce, index_t Init>
struct sequence_reverse_inclusive_scan<Sequence<I>, Reduce, Init>
{
using SeqType = Sequence<Reduce{}(I, Init)>;
using type = Sequence<Reduce{}(I, Init)>;
};
template <class Reduce, index_t Init>
struct sequence_reverse_inclusive_scan<Sequence<>, Reduce, Init>
{
using SeqType = Sequence<>;
};
// extract sequence
template <class, class>
struct sequence_extract;
template <class Seq, index_t... Is>
struct sequence_extract<Seq, Sequence<Is...>>
{
using SeqType = Sequence<Seq{}.Get(Number<Is>{})...>;
using type = Sequence<>;
};
// split sequence
......@@ -216,11 +202,11 @@ struct sequence_split
{
static constexpr index_t NSize = Seq{}.GetSize();
using range0 = typename arithmetic_sequence_gen<0, I, 1>::SeqType;
using range1 = typename arithmetic_sequence_gen<I, NSize, 1>::SeqType;
using range0 = typename arithmetic_sequence_gen<0, I, 1>::type;
using range1 = typename arithmetic_sequence_gen<I, NSize, 1>::type;
using SeqType0 = typename sequence_extract<Seq, range0>::SeqType;
using SeqType1 = typename sequence_extract<Seq, range1>::SeqType;
using SeqType0 = decltype(Seq::Extract(range0{}));
using SeqType1 = decltype(Seq::Extract(range1{}));
};
// reverse sequence
......@@ -230,31 +216,31 @@ struct sequence_reverse
static constexpr index_t NSize = Seq{}.GetSize();
using seq_split = sequence_split<Seq, NSize / 2>;
using SeqType = typename sequence_merge<
typename sequence_reverse<typename seq_split::SeqType1>::SeqType,
typename sequence_reverse<typename seq_split::SeqType0>::SeqType>::SeqType;
using type = typename sequence_merge<
typename sequence_reverse<typename seq_split::SeqType1>::type,
typename sequence_reverse<typename seq_split::SeqType0>::type>::type;
};
template <index_t I>
struct sequence_reverse<Sequence<I>>
{
using SeqType = Sequence<I>;
using type = Sequence<I>;
};
template <index_t I0, index_t I1>
struct sequence_reverse<Sequence<I0, I1>>
{
using SeqType = Sequence<I1, I0>;
using type = Sequence<I1, I0>;
};
template <class Seq>
struct is_valid_sequence_map
{
static constexpr bool value = true;
static constexpr integral_constant<bool, true> value = integral_constant<bool, true>{};
// TODO: add proper check for is_valid, something like:
// static constexpr bool value =
// is_same<typename arithmetic_sequence_gen<0, Seq::GetSize(), 1>::SeqType,
// is_same<typename arithmetic_sequence_gen<0, Seq::GetSize(), 1>::type,
// typename sequence_sort<Seq>::SortedSeqType>{};
};
......@@ -401,7 +387,7 @@ transform_sequences(F f, Sequence<Xs...>, Sequence<Ys...>, Sequence<Zs...>)
template <class Seq, class Reduce, index_t Init>
__host__ __device__ constexpr auto reverse_inclusive_scan_sequence(Seq, Reduce, Number<Init>)
{
return typename sequence_reverse_inclusive_scan<Seq, Reduce, Init>::SeqType{};
return typename sequence_reverse_inclusive_scan<Seq, Reduce, Init>::type{};
}
template <class Seq, class Reduce, index_t Init>
......@@ -425,7 +411,7 @@ __host__ __device__ constexpr auto Sequence<Is...>::PopBack()
template <index_t... Is>
__host__ __device__ constexpr auto Sequence<Is...>::Reverse()
{
return typename sequence_reverse<Sequence<Is...>>::SeqType{};
return typename sequence_reverse<Sequence<Is...>>::type{};
}
template <index_t... Is>
......@@ -438,7 +424,7 @@ __host__ __device__ constexpr auto Sequence<Is...>::Modify(Number<I>, Number<X>)
constexpr auto seq_left = typename seq_split::SeqType0{};
constexpr auto seq_right = typename seq_split::SeqType1{}.PopFront();
return seq_left.PushBack(Number<X>{}).Append(seq_right);
return seq_left.PushBack(Number<X>{}).PushBack(seq_right);
}
template <index_t... Xs>
......
......@@ -31,7 +31,7 @@ struct static_for
static_assert((NEnd - NBegin) % Increment == 0,
"Wrong! should satisfy (NEnd - NBegin) % Increment == 0");
static_for_impl<typename arithmetic_sequence_gen<NBegin, NEnd, Increment>::SeqType>{}(f);
static_for_impl<typename arithmetic_sequence_gen<NBegin, NEnd, Increment>::type>{}(f);
}
};
......
......@@ -59,9 +59,9 @@ __host__ __device__ constexpr T integer_divide_ceil(T a, T b)
}
template <class T>
__host__ __device__ constexpr T max(T x, T y)
__host__ __device__ constexpr T max(T x)
{
return x > y ? x : y;
return x;
}
template <class T, class... Ts>
......@@ -77,9 +77,9 @@ __host__ __device__ constexpr T max(T x, Ts... xs)
}
template <class T>
__host__ __device__ constexpr T min(T x, T y)
__host__ __device__ constexpr T min(T x)
{
return x < y ? x : y;
return x;
}
template <class T, class... Ts>
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment