Commit cb7ed650 authored by Chao Liu


overhauling DynamicTensorDescriptor and dynamic multi-index transform in preparation for partially compile-time and partially run-time tensor descriptor
parent 3990522d
@@ -139,6 +139,7 @@ struct DriverDynamicConvolutionForwardImplicitGemm_v4r4_nchw_kcyx_nkhw_pad
        const index_t GemmM0 = GemmM / GemmM1;
        const index_t GemmN0 = GemmN / GemmN1;
#if 0
        const auto out_gemmm0_gemmm1_gemmn0_gemmn1_global_desc =
            transform_dynamic_tensor_descriptor(
                out_gemmm_gemmn_global_desc,
@@ -146,6 +147,19 @@ struct DriverDynamicConvolutionForwardImplicitGemm_v4r4_nchw_kcyx_nkhw_pad
                           DynamicUnMerge<2>{make_multi_index(GemmN0, GemmN1)}),
                make_tuple(Sequence<0>{}, Sequence<1>{}),
                make_tuple(Sequence<0, 1>{}, Sequence<2, 3>{}));
#else
const auto GemmM0_GemmM1 = make_tuple(GemmM0, Number<GemmM1>{});
const auto GemmN0_GemmN1 = make_tuple(GemmN0, Number<GemmN1>{});
const auto out_gemmm0_gemmm1_gemmn0_gemmn1_global_desc =
transform_dynamic_tensor_descriptor(
out_gemmm_gemmn_global_desc,
make_tuple(
DynamicUnMerge<2, false, remove_cv_t<decltype(GemmM0_GemmM1)>>{GemmM0_GemmM1},
DynamicUnMerge<2, false, remove_cv_t<decltype(GemmN0_GemmN1)>>{GemmN0_GemmN1}),
make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<0, 1>{}, Sequence<2, 3>{}));
#endif
        // hack to control index calculation when iterating over a_k_m_global tensor
        constexpr auto a_k_m_global_iterator_hacks =
......
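For readers unfamiliar with the unmerge transform used above, the following standalone sketch (plain C++, illustrative values only, not part of this commit) shows the index arithmetic DynamicUnMerge<2> performs and why making GemmM1 a compile-time Number<> is useful:

#include <cstdio>

int main()
{
    constexpr int GemmM1 = 4;  // compile-time factor, like Number<GemmM1>{} in the driver
    const int GemmM      = 32; // run-time problem size
    const int GemmM0     = GemmM / GemmM1;

    // DynamicUnMerge<2> maps the split index (m0, m1) back to m = m0 * GemmM1 + m1;
    // with GemmM1 fixed at compile time the multiplier folds into the addressing code.
    const int m0 = 5, m1 = 3;
    const int m  = m0 * GemmM1 + m1;

    std::printf("GemmM0 = %d, m = %d\n", GemmM0, m); // prints "GemmM0 = 8, m = 23"
    return 0;
}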
@@ -326,95 +326,6 @@ struct DynamicRightPad
    }
};
#if 0
// idx_low = coefficients[0, ...nDimUp-1] * idx_up[0, ...nDimUp-1]
template <index_t NDimUp>
struct DynamicEmbed
{
using LowerIndex = MultiIndex<1>;
using UpperIndex = MultiIndex<NDimUp>;
UpperIndex up_lengths_;
UpperIndex coefficients_;
__host__ __device__ constexpr DynamicEmbed() = default;
__host__ __device__ constexpr DynamicEmbed(const UpperIndex& up_lengths,
const UpperIndex& coefficients)
: up_lengths_{up_lengths}, coefficients_{coefficients}
{
static_assert(UpperIndex::Size() == NDimUp, "wrong! # of dimensions not consistent");
}
__host__ __device__ static constexpr index_t GetNumOfLowerDimension() { return 1; }
__host__ __device__ static constexpr index_t GetNumOfUpperDimension() { return NDimUp; }
__host__ __device__ constexpr const auto& GetUpperLengths() const { return up_lengths_; }
template <typename LowIdx, typename UpIdx>
__host__ __device__ constexpr void CalculateLowerIndex(LowIdx& idx_low,
const UpIdx& idx_up) const
{
static_assert(LowIdx::Size() == 1 && UpIdx::Size() == NDimUp,
"wrong! inconsistent # of dimension");
idx_low(Number<0>{}) = 0;
static_for<0, NDimUp, 1>{}([&idx_low, &idx_up, this](auto i) {
idx_low(Number<0>{}) += idx_up[i] * this->coefficients_[i];
});
}
template <typename LowIdxDiff,
typename UpIdxDiff,
typename LowIdx,
typename UpIdx,
index_t Hack>
__host__ __device__ void UpdateLowerIndex(LowIdxDiff& idx_diff_low,
const UpIdxDiff& idx_diff_up,
LowIdx& idx_low,
const UpIdx& idx_up_new,
Number<Hack>) const
{
static_assert(LowIdxDiff::Size() == 1 && UpIdxDiff::Size() == NDimUp &&
LowIdx::Size() == 1 && UpIdx::Size() == NDimUp,
"wrong! inconsistent # of dimension");
idx_diff_low(Number<0>{}) = 0;
static_for<0, NDimUp, 1>{}(
[&](auto i) { idx_diff_low(Number<0>{}) += idx_diff_up[i] * coefficients_[i]; });
idx_low += idx_diff_low;
}
__host__ __device__ static constexpr bool IsLinearTransform() { return true; }
__host__ __device__ static constexpr bool IsValidUpperIndexAlwaysMappedToValidLowerIndex()
{
return true;
}
template <typename UpIdx>
__host__ __device__ static constexpr bool
IsValidUpperIndexMappedToValidLowerIndex(const UpIdx& /* idx_up */)
{
return true;
}
__host__ __device__ void Print() const
{
printf("{");
printf("DynamicEmbed, ");
printf("up_lengths_ ");
print_multi_index(up_lengths_);
printf("coefficients_ ");
print_multi_index(coefficients_);
printf("}");
}
};
#else
// idx_low = coefficients[0, ...nDimUp-1] * idx_up[0, ...nDimUp-1]
// UpLengths and Coefficients can be either of the followings:
//   1) Tuple of index_t, which is known at run-time, or
@@ -510,7 +421,6 @@ struct DynamicEmbed
        printf("}");
    }
};
#endif
template <index_t NDimLow>
struct DynamicMerge
@@ -1020,115 +930,27 @@ struct DynamicMerge
    }
};
#if 0
template <index_t NDimUp, bool Use24BitIntegerCalculation = false>
struct DynamicUnMerge
{
using LowerIndex = MultiIndex<1>;
using UpperIndex = MultiIndex<NDimUp>;
UpperIndex up_lengths_;
UpperIndex up_lengths_scan_;
__host__ __device__ constexpr DynamicUnMerge() = default;
__host__ __device__ constexpr DynamicUnMerge(const UpperIndex& up_lengths)
: up_lengths_{up_lengths},
up_lengths_scan_{
container_reverse_exclusive_scan(up_lengths, math::multiplies<index_t>(), index_t{1})}
{
}
__host__ __device__ static constexpr index_t GetNumOfLowerDimension() { return 1; }
__host__ __device__ static constexpr index_t GetNumOfUpperDimension() { return NDimUp; }
__host__ __device__ constexpr const auto& GetUpperLengths() const { return up_lengths_; }
template <typename LowIdx, typename UpIdx>
__host__ __device__ constexpr void CalculateLowerIndex(LowIdx& idx_low,
const UpIdx& idx_up) const
{
if constexpr(!Use24BitIntegerCalculation)
{
idx_low(Number<0>{}) = idx_up[Number<NDimUp - 1>{}];
static_for<0, NDimUp - 1, 1>{}(
[&](auto i) { idx_low(Number<0>{}) += idx_up[i] * up_lengths_scan_[i]; });
}
else
{
idx_low(Number<0>{}) = idx_up[Number<NDimUp - 1>{}];
static_for<0, NDimUp - 1, 1>{}([&](auto i) {
idx_low(Number<0>{}) =
(0x00ffffff & idx_low[Number<0>{}]) +
(0x00ffffff & idx_up[i]) * (0x00ffffff & up_lengths_scan_[i]);
});
}
}
template <typename LowIdxDiff,
typename UpIdxDiff,
typename LowIdx,
typename UpIdx,
index_t Hack>
__host__ __device__ void UpdateLowerIndex(LowIdxDiff& idx_diff_low,
const UpIdxDiff& idx_diff_up,
LowIdx& idx_low,
const UpIdx& idx_up_new,
Number<Hack>) const
{
CalculateLowerIndex(idx_diff_low, idx_diff_up);
idx_low += idx_diff_low;
}
__host__ __device__ static constexpr bool IsLinearTransform() { return true; }
__host__ __device__ static constexpr bool IsValidUpperIndexAlwaysMappedToValidLowerIndex()
{
return true;
}
template <typename UpIdx>
__host__ __device__ static constexpr bool
IsValidUpperIndexMappedToValidLowerIndex(const UpIdx& /* idx_up */)
{
return true;
}
__host__ __device__ void Print() const
{
printf("{");
printf("DynamicUnMerge, ");
print_multi_index(up_lengths_);
print_multi_index(up_lengths_scan_);
printf("}");
}
};
#else
template <index_t NDimUp,
          bool Use24BitIntegerCalculation = false,
          typename UpLengths = MultiIndex<NDimUp>,
-         typename UpLengthsScan = MultiIndex<NDimUp>,
-         typename std::enable_if<UpLengths::Size() == NDimUp && UpLengthsScan::Size() == NDimUp,
-                                 bool>::type = false>
+         typename std::enable_if<UpLengths::Size() == NDimUp, bool>::type = false>
struct DynamicUnMerge
{
    using LowerIndex = MultiIndex<1>;
    using UpperIndex = MultiIndex<NDimUp>;

+   using UpLengthsScan =
+       decltype(container_reverse_exclusive_scan(UpLengths{}, math::multiplies_v2{}, Number<1>{}));

    UpLengths up_lengths_;
    UpLengthsScan up_lengths_scan_;

    __host__ __device__ constexpr DynamicUnMerge() = default;

-   __host__ __device__ constexpr DynamicUnMerge(const UpperIndex& up_lengths)
+   __host__ __device__ constexpr DynamicUnMerge(const UpLengths& up_lengths)
        : up_lengths_{up_lengths},
          up_lengths_scan_{
-             container_reverse_exclusive_scan(up_lengths, math::multiplies<index_t>(), index_t{1})}
+             container_reverse_exclusive_scan(up_lengths, math::multiplies_v2{}, Number<1>{})}
    {
    }
@@ -1142,7 +964,6 @@ struct DynamicUnMerge
    __host__ __device__ constexpr void CalculateLowerIndex(LowIdx& idx_low,
                                                           const UpIdx& idx_up) const
    {
        if constexpr(!Use24BitIntegerCalculation)
        {
            idx_low(Number<0>{}) = idx_up[Number<NDimUp - 1>{}];
@@ -1201,7 +1022,6 @@ struct DynamicUnMerge
        printf("}");
    }
};
#endif
struct DynamicFreeze
{
......
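The reworked DynamicUnMerge above stores the reverse exclusive scan of the upper lengths and uses it as packed strides. A standalone sketch of that arithmetic, using only the standard library rather than the codebase's containers:

#include <array>
#include <cstdio>

int main()
{
    const std::array<int, 3> up_lengths = {2, 3, 4};

    // reverse exclusive scan with multiplication, init = 1  ->  {12, 4, 1}
    // (these are the packed strides held in up_lengths_scan_)
    std::array<int, 3> up_lengths_scan{};
    int r = 1;
    for(int i = 2; i >= 0; --i)
    {
        up_lengths_scan[i] = r;
        r *= up_lengths[i];
    }

    // CalculateLowerIndex: idx_low = sum_i idx_up[i] * up_lengths_scan[i]
    const std::array<int, 3> idx_up = {1, 2, 3};
    int idx_low = 0;
    for(int i = 0; i < 3; ++i)
    {
        idx_low += idx_up[i] * up_lengths_scan[i];
    }

    std::printf("idx_low = %d\n", idx_low); // 1*12 + 2*4 + 3*1 = 23
    return 0;
}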
@@ -82,7 +82,7 @@ struct DynamicTensorDescriptor
    __host__ __device__ constexpr DynamicTensorDescriptor() = default;

    __host__ __device__ constexpr DynamicTensorDescriptor(const Transforms& transforms,
-                                                         index_t element_space_size)
+                                                         ElementSpaceSize element_space_size)
        : transforms_{transforms},
          element_size_{InitializeElementSize(transforms)},
          element_space_size_{element_space_size}
@@ -106,7 +106,7 @@ struct DynamicTensorDescriptor
    {
        static_assert(IDim >= 0 && IDim < ndim_visible_, "wrong! out of range");

-       constexpr auto tmp = FindTransformAndItsUpperDimension(Number<IDim>{});
+       constexpr auto tmp = GetTransformAndItsUpperDimension(Number<IDim>{});

        constexpr index_t itran = tmp[Number<0>{}];
        constexpr index_t idim_up = tmp[Number<1>{}];
@@ -120,10 +120,7 @@ struct DynamicTensorDescriptor
    __host__ __device__ constexpr auto GetElementSize() const { return element_size_; }

-   __host__ __device__ constexpr auto GetElementSpaceSize() const
-   {
-       return element_space_size_;
-   }
+   __host__ __device__ constexpr auto GetElementSpaceSize() const { return element_space_size_; }

    template <typename Idx>
    __host__ __device__ constexpr index_t CalculateOffset(const Idx& idx) const
@@ -155,7 +152,7 @@ struct DynamicTensorDescriptor
    {
        const auto lengths = generate_tuple(
            [&](auto idim_visible) {
-               constexpr auto tmp = FindTransformAndItsUpperDimension(idim_visible);
+               constexpr auto tmp = GetTransformAndItsUpperDimension(idim_visible);

                constexpr index_t itran = tmp[Number<0>{}];
                constexpr index_t idim_up = tmp[Number<1>{}];
@@ -172,11 +169,11 @@ struct DynamicTensorDescriptor
            Number<ndim_visible_>{});

        // TODO: make container_reduce support tuple of Number and index_t
-       return container_reduce(lengths, math::multiplies<index_t>{}, index_t{1});
+       return container_reduce(lengths, math::multiplies_v2{}, Number<1>{});
    }

    template <index_t IDim>
-   __host__ __device__ static constexpr auto FindTransformAndItsUpperDimension(Number<IDim>)
+   __host__ __device__ static constexpr auto GetTransformAndItsUpperDimension(Number<IDim>)
    {
        constexpr auto idim_visible = Number<IDim>{};
@@ -552,12 +549,8 @@ coordinate_has_valid_offset_assuming_visible_index_is_valid(const TensorDesc& te
            const auto idx_up =
                get_container_subset(idx_hidden, TensorDesc::GetUpperDimensionIdss().At(itran));

-#if 0 // debug
-           // Comment: this implemenetation results in weird control flow in ISA
-           valid = valid && tran.IsValidUpperIndexMappedToValidLowerIndex(idx_up);
-#else
+           // Comment: using valid = valid && .. will result in weird control flow in ISA
            valid &= tran.IsValidUpperIndexMappedToValidLowerIndex(idx_up);
-#endif
        }
    });
......
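The retained `valid &= ...` form mirrors the comment in the hunk above: unlike `&&`, it does not short-circuit, so every transform's check is evaluated unconditionally. A minimal standalone sketch of the two forms (whether the optimizer actually emits a branch for `&&` depends on the target and compiler):

#include <cstdio>

bool check(int x) { return x >= 0; }

int main()
{
    const int xs[4] = {1, -2, 3, 4};

    bool valid_short_circuit = true;
    bool valid_accumulate    = true;

    for(int x : xs)
    {
        // the form the commit removed: && short-circuits, so check(x) may be skipped
        valid_short_circuit = valid_short_circuit && check(x);

        // the form the commit keeps: every check is evaluated, the data flow stays flat
        valid_accumulate &= check(x);
    }

    std::printf("%d %d\n", int(valid_short_circuit), int(valid_accumulate));
    return 0;
}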
@@ -39,7 +39,6 @@ template <index_t N>
__host__ __device__ constexpr auto
make_dynamic_naive_tensor_descriptor_packed(const MultiIndex<N>& lengths)
{
    const auto transforms = make_tuple(DynamicUnMerge<N>{lengths});
    constexpr auto low_dim_hidden_idss = make_tuple(Sequence<0>{});
    constexpr auto up_dim_hidden_idss =
@@ -56,6 +55,39 @@ make_dynamic_naive_tensor_descriptor_packed(const MultiIndex<N>& lengths)
        transforms, element_space_size};
}
// Is... can be:
// 1) index_t, which is known at run-time
// 2) Number<>, which is known at compile-time
template <typename... Is>
__host__ __device__ constexpr auto
make_dynamic_naive_tensor_descriptor_packed_v2(const Tuple<Is...>& lengths)
{
constexpr index_t N = sizeof...(Is);
using Lengths = remove_cv_t<remove_reference_t<decltype(lengths)>>;
const auto transforms = make_tuple(DynamicUnMerge<N, false, Lengths>{lengths});
constexpr auto low_dim_hidden_idss = make_tuple(Sequence<0>{});
constexpr auto up_dim_hidden_idss =
make_tuple(typename arithmetic_sequence_gen<1, N + 1, 1>::type{});
constexpr auto visible_dim_hidden_ids = typename arithmetic_sequence_gen<1, N + 1, 1>::type{};
const auto element_size = container_reduce(lengths, math::multiplies_v2{}, Number<1>{});
const auto element_space_size = element_size;
return DynamicTensorDescriptor<remove_cv_t<decltype(transforms)>,
remove_cv_t<decltype(low_dim_hidden_idss)>,
remove_cv_t<decltype(up_dim_hidden_idss)>,
remove_cv_t<decltype(visible_dim_hidden_ids)>,
remove_cv_t<decltype(element_size)>,
remove_cv_t<decltype(element_space_size)>>{transforms,
element_space_size};
}
template <index_t N>
__host__ __device__ constexpr auto
make_dynamic_naive_tensor_descriptor_aligned(const MultiIndex<N>& lengths, index_t align)
......
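A hedged usage sketch of the make_dynamic_naive_tensor_descriptor_packed_v2 helper added above; it assumes the codebase's Number<>, make_tuple and index_t are in scope, and the variable names are illustrative only:

// hypothetical call site; n and the literal lengths are examples, not from the commit
const index_t n = 128; // run-time length
const auto desc =
    make_dynamic_naive_tensor_descriptor_packed_v2(make_tuple(Number<8>{}, Number<4>{}, n));

// element space size is 8 * 4 * n; here it is a run-time value because one length is
// run-time, and would presumably stay a compile-time Number<> if all lengths were Number<>
// (that depends on Number<>'s arithmetic operators, which this commit does not show).
const auto element_space_size = desc.GetElementSpaceSize();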
@@ -389,14 +389,11 @@ struct GridwiseDynamicGemm_km_kn_mn_v1
        // define input tensor descriptor for threadwise copy

        // thread input tensor, src of threadwise copy
-#if 0 // debug
-       constexpr auto c_m0_m1_n0_n1_thread_desc =
-           make_dynamic_naive_tensor_descriptor_packed<4>(
-               make_multi_index(MRepeat, MPerThread, NRepeat, NPerThread));
-#else
-       constexpr auto c_m0_m1_n0_n1_thread_desc = make_native_tensor_descriptor_packed(
-           Sequence<MRepeat, MPerThread, NRepeat, NPerThread>{});
-#endif
+       constexpr auto c_m0_m1_n0_n1_thread_desc =
+           make_dynamic_naive_tensor_descriptor_packed_v2(make_tuple(Number<MRepeat>{},
+                                                                     Number<MPerThread>{},
+                                                                     Number<NRepeat>{},
+                                                                     Number<NPerThread>{}));

        // calculate origin of thread input tensor on global memory
        // blockwise GEMM c matrix starting index
......
@@ -80,6 +80,9 @@ struct ThreadwiseDynamicTensorSliceTransfer_v1r3
        constexpr auto I0 = Number<0>{};
        constexpr auto I1 = Number<1>{};
// Comments: src_desc is constexpr
constexpr auto src_desc = remove_cv_t<remove_reference_t<SrcDesc>>{};
        // scalar per access on each dim
        // TODO: don't use lambda_scalar_per_access
        constexpr auto dst_scalar_per_access = generate_sequence(
@@ -175,7 +178,7 @@ struct ThreadwiseDynamicTensorSliceTransfer_v1r3
                // assume src_slice_origin_idx is 0
                // TODO: support non-zero src_slice_oring_idx
                constexpr index_t src_offset =
-                   SrcDesc::CalculateOffset(dst_data_idx + i * dst_scalar_step_in_vector);
+                   src_desc.CalculateOffset(dst_data_idx + i * dst_scalar_step_in_vector);

                dst_vector(i) = p_src[Number<src_offset>{}];
            });
......
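The new constexpr src_desc above exists so that CalculateOffset can be evaluated at compile time for a fully static source descriptor. A standalone sketch of that idea with a stand-in descriptor type (not the codebase's class):

#include <cstdio>

struct PackedDesc2x4 // stand-in for a fully compile-time descriptor
{
    constexpr int CalculateOffset(int i0, int i1) const { return i0 * 4 + i1; }
};

int main()
{
    constexpr auto desc   = PackedDesc2x4{};
    constexpr int  offset = desc.CalculateOffset(1, 3); // evaluated at compile time
    static_assert(offset == 7, "offset should be a compile-time constant");
    std::printf("%d\n", offset);
    return 0;
}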
@@ -24,15 +24,15 @@ __host__ __device__ constexpr auto container_push_back(const Array<TData, NSize>
}

template <typename... Ts, typename T>
-__host__ __device__ constexpr auto container_push_back(const Tuple<Ts...>& a, const T& x)
+__host__ __device__ constexpr auto container_push_front(const Tuple<Ts...>& a, const T& x)
{
-   Tuple<Ts..., T> r;
-
-   static_for<0, sizeof...(Ts), 1>{}([&r, &a ](auto i) constexpr { r(i) = a[i]; });
-
-   r(Number<sizeof...(Ts)>{}) = x;
-
-   return r;
+   return container_cat(make_tuple(x), a);
+}
+
+template <typename... Ts, typename T>
+__host__ __device__ constexpr auto container_push_back(const Tuple<Ts...>& a, const T& x)
+{
+   return container_cat(a, make_tuple(x));
}
template <typename TData, index_t NSize, index_t... IRs>
@@ -97,17 +97,29 @@ __host__ __device__ constexpr auto container_reorder_given_old2new(Sequence<Is..
    return container_reorder_give_new2old(old_seq, new2old);
}
-template <typename TData, typename Container, typename Reduce>
-__host__ __device__ constexpr TData container_reduce(const Container& a, Reduce f, TData init)
+template <typename Container, typename Reduce, typename Init>
+__host__ __device__ constexpr auto container_reduce(const Container& x, Reduce reduce, Init init)
{
-   // static_assert(is_same<typename Arr::data_type, TData>::value, "wrong! different data type");
-
-   TData result = init;
-
-   static_for<0, Container::Size(), 1>{}([&](auto I) { result = f(result, a[I]); });
-
-   return result;
+   constexpr index_t NSize = Container::Size();
+
+   static_assert(Container::Size() > 0, "wrong");
+
+   // f is recursive function, fs is a dummy of f
+   // i is index, r_old is current reduction
+   auto f = [&](auto fs, auto i, auto r_old) {
+       auto r_new = reduce(x[i], r_old);
+
+       if constexpr(i.value > 0)
+       {
+           // recursively call f/fs
+           return fs(fs, i - Number<1>{}, r_new);
+       }
+       else
+       {
+           return r_new;
+       }
+   };
+
+   // start recursion
+   return f(f, Number<NSize - 1>{}, init);
}
template <typename TData, index_t NSize, typename Reduce>
@@ -147,30 +159,35 @@ container_reverse_exclusive_scan(const Array<TData, NSize>& x, Reduce f, TData i
    return y;
}
-// Here should use StaticallyIndexedArray<TData, NSize>, instead of Tuple<Xs...>,
-// although the former is the alias of the latter. This is because compiler cannot
-// infer the NSize if using StaticallyIndexedArray<TData, NSize>
-// TODO: how to fix this?
-template <typename... Xs, typename Reduce, typename TData>
+template <typename... Xs, typename Reduce, typename Init>
__host__ __device__ constexpr auto
-container_reverse_exclusive_scan(const Tuple<Xs...>& x, Reduce f, TData init)
+container_reverse_exclusive_scan(const Tuple<Xs...>& x, Reduce reduce, Init init)
{
    constexpr index_t NSize = sizeof...(Xs);

-   Tuple<Xs...> y;
-
-   TData r = init;
-
-   static_for<NSize - 1, 0, -1>{}([&](auto i) {
-       y(i) = r;
-       r = f(r, x[i]);
-   });
-
-   y(Number<0>{}) = r;
-
-   return y;
+   // f is recursive function, fs is a dummy of f
+   // i is index, y_old is current scan, r_old is current reduction
+   auto f = [&](auto fs, auto i, auto y_old, auto r_old) {
+       auto r_new = reduce(x[i], r_old);
+
+       auto y_new = container_push_front(y_old, r_new);
+
+       if constexpr(i.value > 1)
+       {
+           // recursively call f/fs
+           return fs(fs, i - Number<1>{}, y_new, r_new);
+       }
+       else
+       {
+           return y_new;
+       }
+   };
+
+   // start recursion
+   return f(f, Number<NSize - 1>{}, make_tuple(init), init);
}
// TODO: update to be like container_reverse_exclusive_scan to deal with Tuple of Number<>
template <typename... Xs, typename Reduce, typename TData>
__host__ __device__ constexpr auto
container_reverse_inclusive_scan(const Tuple<Xs...>& x, Reduce f, TData init)
......
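The new container_reduce and container_reverse_exclusive_scan pass the lambda to itself ("fs") so that each recursion step may return a different type, which a static_for loop with one mutable accumulator cannot express. A standalone sketch of the pattern (std-only, C++17; the Int alias and operator+ below are local stand-ins for Number<> arithmetic, not the codebase's types):

#include <cstdio>
#include <tuple>
#include <type_traits>

template <int N>
using Int = std::integral_constant<int, N>;

// compile-time addition so the accumulator type can change at every step
template <int A, int B>
constexpr Int<A + B> operator+(Int<A>, Int<B>) { return {}; }

int main()
{
    const auto x = std::make_tuple(Int<1>{}, Int<2>{}, Int<3>{});

    // reduce from the last element down to index 0, like the new container_reduce
    auto f = [&](auto fs, auto i, auto r_old) {
        constexpr int I = decltype(i)::value;
        auto r_new      = std::get<I>(x) + r_old;

        if constexpr(I > 0)
        {
            return fs(fs, Int<I - 1>{}, r_new); // recurse; r_new has a new type each time
        }
        else
        {
            return r_new;
        }
    };

    auto r = f(f, Int<2>{}, Int<0>{}); // Int<6>: the whole reduction stays compile-time
    static_assert(std::is_same_v<decltype(r), Int<6>>, "reduction lost its compile-time value");
    std::printf("%d\n", decltype(r)::value);
    return 0;
}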
@@ -33,6 +33,15 @@ struct multiplies
    __host__ __device__ constexpr T operator()(T a, T b) const { return a * b; }
};
struct multiplies_v2
{
template <typename A, typename B>
__host__ __device__ constexpr auto operator()(const A& a, const B& b) const
{
return a * b;
}
};
template <class T>
struct maxer
{
......
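multiplies_v2 differs from multiplies<T> in that it deduces both operand types, so compile-time operands can stay compile-time instead of being forced to index_t. A standalone sketch (std-only; the Int alias, operator*, and multiplies_generic are local stand-ins, not this codebase's types):

#include <cstdio>
#include <type_traits>

template <int N>
using Int = std::integral_constant<int, N>;

// compile-time multiply, analogous to what Number<> arithmetic would provide
template <int A, int B>
constexpr Int<A * B> operator*(Int<A>, Int<B>) { return {}; }

struct multiplies_generic // stand-in for multiplies_v2
{
    template <typename A, typename B>
    constexpr auto operator()(const A& a, const B& b) const
    {
        return a * b; // the operand types decide the result type
    }
};

int main()
{
    auto c = multiplies_generic{}(Int<6>{}, Int<7>{}); // stays a compile-time Int<42>
    int d  = multiplies_generic{}(Int<6>{}, 7);        // mixing in a run-time int gives int
    static_assert(std::is_same_v<decltype(c), Int<42>>, "product lost its compile-time value");
    std::printf("%d %d\n", decltype(c)::value, d);
    return 0;
}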