Commit cab8f2e5 authored by Jing Zhang's avatar Jing Zhang
Browse files

clean

parents c20aabc3 9a17e7fb
...@@ -101,6 +101,9 @@ template <typename InDataType, ...@@ -101,6 +101,9 @@ template <typename InDataType,
index_t OutDstVectorSize> index_t OutDstVectorSize>
struct GridwiseReduction_mk_to_m_threadwise struct GridwiseReduction_mk_to_m_threadwise
{ {
using ThreadBufferDimAccessOrder =
typename conditional<InSrcVectorDim == 0, Sequence<1, 0>, Sequence<0, 1>>::type;
template <typename T> template <typename T>
using PassThroughOp = tensor_operation::element_wise::UnaryIdentic<T, T>; using PassThroughOp = tensor_operation::element_wise::UnaryIdentic<T, T>;
...@@ -147,17 +150,17 @@ struct GridwiseReduction_mk_to_m_threadwise ...@@ -147,17 +150,17 @@ struct GridwiseReduction_mk_to_m_threadwise
index_t thread_global_1d_id = get_block_1d_id() * BlockSize + get_thread_local_1d_id(); index_t thread_global_1d_id = get_block_1d_id() * BlockSize + get_thread_local_1d_id();
auto threadwise_src_load = ThreadwiseTensorSliceTransfer_v2< auto threadwise_src_load = ThreadwiseTensorSliceTransfer_v2<InDataType,
InDataType, AccDataType,
AccDataType, InGridDesc_M_K,
InGridDesc_M_K, decltype(thread_buffer_desc),
decltype(thread_buffer_desc), ThreadBufferLengths,
ThreadBufferLengths, ThreadBufferDimAccessOrder,
typename conditional<InSrcVectorDim == 0, Sequence<1, 0>, Sequence<0, 1>>::type, InSrcVectorDim,
InSrcVectorDim, InSrcVectorSize,
InSrcVectorSize, 1,
1, false>(
false>(in_grid_desc_m_k, make_multi_index(thread_global_1d_id * MThreadSliceSize, 0)); in_grid_desc_m_k, make_multi_index(thread_global_1d_id * MThreadSliceSize, 0));
constexpr auto in_thread_copy_step = make_multi_index(0, KThreadSliceSize); constexpr auto in_thread_copy_step = make_multi_index(0, KThreadSliceSize);
...@@ -299,17 +302,17 @@ struct GridwiseReduction_mk_to_m_threadwise ...@@ -299,17 +302,17 @@ struct GridwiseReduction_mk_to_m_threadwise
index_t thread_global_1d_id = get_block_1d_id() * BlockSize + get_thread_local_1d_id(); index_t thread_global_1d_id = get_block_1d_id() * BlockSize + get_thread_local_1d_id();
auto threadwise_src_load = ThreadwiseTensorSliceTransfer_v2< auto threadwise_src_load = ThreadwiseTensorSliceTransfer_v2<InDataType,
InDataType, AccDataType,
AccDataType, InGridDesc_M_K,
InGridDesc_M_K, decltype(thread_buffer_desc),
decltype(thread_buffer_desc), ThreadBufferLengths,
ThreadBufferLengths, ThreadBufferDimAccessOrder,
typename conditional<InSrcVectorDim == 0, Sequence<1, 0>, Sequence<0, 1>>::type, InSrcVectorDim,
InSrcVectorDim, InSrcVectorSize,
InSrcVectorSize, 1,
1, false>(
false>(in_grid_desc_m_k, make_multi_index(thread_global_1d_id * MThreadSliceSize, 0)); in_grid_desc_m_k, make_multi_index(thread_global_1d_id * MThreadSliceSize, 0));
constexpr auto in_thread_copy_step = make_multi_index(0, KThreadSliceSize); constexpr auto in_thread_copy_step = make_multi_index(0, KThreadSliceSize);
......
...@@ -4,6 +4,7 @@ ...@@ -4,6 +4,7 @@
#include "common_header.hpp" #include "common_header.hpp"
#include "tensor_descriptor.hpp" #include "tensor_descriptor.hpp"
#include "tensor_descriptor_helper.hpp" #include "tensor_descriptor_helper.hpp"
#include "tensor_space_filling_curve.hpp"
namespace ck { namespace ck {
...@@ -85,16 +86,12 @@ struct ThreadwiseTensorSliceTransfer_v1r3 ...@@ -85,16 +86,12 @@ struct ThreadwiseTensorSliceTransfer_v1r3
dst_coord_ = make_tensor_coordinate(dst_desc, dst_slice_origin_idx); dst_coord_ = make_tensor_coordinate(dst_desc, dst_slice_origin_idx);
} }
template <typename SrcSliceOriginIdx, template <typename SrcSliceOriginIdx, typename SrcBuffer, typename DstBuffer>
typename SrcBuffer,
typename DstBuffer,
typename DstStepHacks>
__device__ void Run(const SrcDesc&, __device__ void Run(const SrcDesc&,
const SrcSliceOriginIdx&, const SrcSliceOriginIdx&,
const SrcBuffer& src_buf, const SrcBuffer& src_buf,
const DstDesc& dst_desc, const DstDesc& dst_desc,
DstBuffer& dst_buf, DstBuffer& dst_buf)
const DstStepHacks& dst_step_hacks)
{ {
static_assert(SrcDesc::IsKnownAtCompileTime(), static_assert(SrcDesc::IsKnownAtCompileTime(),
"wrong! SrcDesc need to known at compile-time"); "wrong! SrcDesc need to known at compile-time");
...@@ -108,9 +105,6 @@ struct ThreadwiseTensorSliceTransfer_v1r3 ...@@ -108,9 +105,6 @@ struct ThreadwiseTensorSliceTransfer_v1r3
constexpr auto src_desc = remove_cvref_t<SrcDesc>{}; constexpr auto src_desc = remove_cvref_t<SrcDesc>{};
constexpr auto src_slice_origin_idx = to_multi_index(SrcSliceOriginIdx{}); constexpr auto src_slice_origin_idx = to_multi_index(SrcSliceOriginIdx{});
constexpr auto I0 = Number<0>{};
constexpr auto I1 = Number<1>{};
// scalar per access on each dim // scalar per access on each dim
// TODO: don't use lambda_scalar_per_access // TODO: don't use lambda_scalar_per_access
constexpr auto dst_scalar_per_access = generate_sequence( constexpr auto dst_scalar_per_access = generate_sequence(
...@@ -119,85 +113,26 @@ struct ThreadwiseTensorSliceTransfer_v1r3 ...@@ -119,85 +113,26 @@ struct ThreadwiseTensorSliceTransfer_v1r3
constexpr auto dst_scalar_step_in_vector = constexpr auto dst_scalar_step_in_vector =
generate_sequence(detail::lambda_scalar_step_in_vector<DstVectorDim>{}, Number<nDim>{}); generate_sequence(detail::lambda_scalar_step_in_vector<DstVectorDim>{}, Number<nDim>{});
constexpr auto access_lengths = SliceLengths{} / dst_scalar_per_access; using SpaceFillingCurve = SpaceFillingCurve<SliceLengths,
DimAccessOrder,
constexpr auto dim_access_order = DimAccessOrder{}; remove_cv_t<decltype(dst_scalar_per_access)>>;
constexpr auto ordered_access_lengths =
container_reorder_given_new2old(access_lengths, dim_access_order);
// make forward steps
const auto dst_forward_steps = generate_tuple(
[&](auto i) {
Index forward_step_idx;
static_for<0, nDim, 1>{}([&](auto j) {
forward_step_idx(j) = (i.value == j.value) ? dst_scalar_per_access[i] : 0;
});
return make_tensor_coordinate_step(
dst_desc, forward_step_idx, dst_step_hacks[I0][i]);
},
Number<nDim>{});
// make backward steps
const auto dst_backward_steps = generate_tuple(
[&](auto i) {
Index backward_step_idx;
static_for<0, nDim, 1>{}([&](auto j) {
backward_step_idx(j) = (i.value == j.value) ? -dst_scalar_per_access[i] : 0;
});
return make_tensor_coordinate_step(
dst_desc, backward_step_idx, dst_step_hacks[I1][i]);
},
Number<nDim>{});
// loop over tensor and copy
static_ford<decltype(ordered_access_lengths)>{}([&](auto ordered_access_idx) {
// judge move forward or move backward
constexpr auto forward_sweep = [&]() {
StaticallyIndexedArray<bool, nDim> forward_sweep_;
forward_sweep_(I0) = true;
static_for<1, nDim, 1>{}([&](auto i) { // TODO: Use SpaceFillingCurve::ScalarsPerAccess instread of DstScalarPerVector?
index_t tmp = ordered_access_idx[I0]; static_assert(DstScalarPerVector == SpaceFillingCurve::ScalarPerVector,
"wrong!DstScalarPerVector != SpaceFillingCurve::ScalarPerVector");
typename vector_type_maker<DstData, DstScalarPerVector>::type dst_vector;
using dst_vector_t = typename vector_type_maker<DstData, DstScalarPerVector>::type::type;
static_for<1, i, 1>{}([&](auto j) { constexpr auto num_access = SpaceFillingCurve::GetNumOfAccess();
tmp = tmp * ordered_access_lengths[j] + ordered_access_idx[j];
});
forward_sweep_(i) = tmp % 2 == 0; static_for<0, num_access, 1>{}([&](auto idx_1d) {
}); constexpr auto idx_md = SpaceFillingCurve::GetIndex(idx_1d);
return forward_sweep_;
}();
// calculate dst data index
constexpr auto dst_data_idx = [&]() {
Index ordered_idx;
static_for<0, nDim, 1>{}([&](auto i) {
ordered_idx(i) = forward_sweep[i]
? ordered_access_idx[i]
: ordered_access_lengths[i] - 1 - ordered_access_idx[i];
});
return container_reorder_given_old2new(ordered_idx, dim_access_order) *
dst_scalar_per_access;
}();
typename vector_type_maker<DstData, DstScalarPerVector>::type dst_vector;
using dst_vector_t =
typename vector_type_maker<DstData, DstScalarPerVector>::type::type;
// copy data from src_buf into dst_vector // copy data from src_buf into dst_vector
// TODO: It's a hack here to use \p dst_scalar_step_in_vector. Use SpaceFillingCurve?
static_for<0, DstScalarPerVector, 1>{}([&](auto i) { static_for<0, DstScalarPerVector, 1>{}([&](auto i) {
constexpr index_t src_offset = src_desc.CalculateOffset( constexpr index_t src_offset = src_desc.CalculateOffset(
src_slice_origin_idx + dst_data_idx + i * dst_scalar_step_in_vector); src_slice_origin_idx + idx_md + i * dst_scalar_step_in_vector);
SrcData dst_v; SrcData dst_v;
...@@ -212,69 +147,18 @@ struct ThreadwiseTensorSliceTransfer_v1r3 ...@@ -212,69 +147,18 @@ struct ThreadwiseTensorSliceTransfer_v1r3
coordinate_has_valid_offset_assuming_visible_index_is_valid(dst_desc, dst_coord_); coordinate_has_valid_offset_assuming_visible_index_is_valid(dst_desc, dst_coord_);
// copy data from dst_vector into dst_buf // copy data from dst_vector into dst_buf
if constexpr(DstInMemOp == InMemoryDataOperationEnum_t::Set) dst_buf.template Update<DstInMemOp, dst_vector_t>(
{ dst_coord_.GetOffset(),
dst_buf.template Set<dst_vector_t>( is_dst_valid,
dst_coord_.GetOffset(), dst_vector.template AsType<dst_vector_t>()[Number<0>{}]);
is_dst_valid,
dst_vector.template AsType<dst_vector_t>()[Number<0>{}]);
}
else if constexpr(DstInMemOp == InMemoryDataOperationEnum_t::AtomicAdd)
{
dst_buf.template AtomicAdd<dst_vector_t>(
dst_coord_.GetOffset(),
is_dst_valid,
dst_vector.template AsType<dst_vector_t>()[Number<0>{}]);
}
else if constexpr(DstInMemOp == InMemoryDataOperationEnum_t::Add)
{
typename vector_type_maker<DstData, DstScalarPerVector>::type tmp;
tmp.template AsType<dst_vector_t>()(Number<0>{}) =
dst_buf.template Get<dst_vector_t>(dst_coord_.GetOffset(), is_dst_valid);
static_for<0, DstScalarPerVector, 1>{}([&](auto t) {
dst_vector.template AsType<DstData>()(t) += tmp.template AsType<DstData>()[t];
});
dst_buf.template Set<dst_vector_t>(
dst_coord_.GetOffset(),
is_dst_valid,
dst_vector.template AsType<dst_vector_t>()[Number<0>{}]);
}
constexpr auto move_on_dim = [&]() constexpr if constexpr(idx_1d.value != num_access - 1)
{ {
StaticallyIndexedArray<bool, nDim> move_on_dim_; constexpr auto forward_step = SpaceFillingCurve::GetForwardStep(idx_1d);
static_for<0, nDim, 1>{}([&](auto i) {
move_on_dim_(i) = ordered_access_idx[i] < ordered_access_lengths[i] - 1;
static_for<i + 1, nDim, 1>{}([&](auto j) {
move_on_dim_(i) &= ordered_access_idx[j] == ordered_access_lengths[j] - 1;
});
});
return move_on_dim_; move_tensor_coordinate(
dst_desc, dst_coord_, make_tensor_coordinate_step(dst_desc, forward_step));
} }
();
// move
static_for<0, nDim, 1>{}([&](auto i) {
if constexpr(move_on_dim[i])
{
if constexpr(forward_sweep[i])
{
move_tensor_coordinate(
dst_desc, dst_coord_, dst_forward_steps[dim_access_order[i]]);
}
else
{
move_tensor_coordinate(
dst_desc, dst_coord_, dst_backward_steps[dim_access_order[i]]);
}
}
});
}); });
// move dst coordinate back to slice origin (or not) // move dst coordinate back to slice origin (or not)
...@@ -287,82 +171,27 @@ struct ThreadwiseTensorSliceTransfer_v1r3 ...@@ -287,82 +171,27 @@ struct ThreadwiseTensorSliceTransfer_v1r3
} }
} }
template <typename SrcSliceOriginIdx, typename SrcBuffer, typename DstBuffer>
__device__ void Run(const SrcDesc&,
const SrcSliceOriginIdx&,
const SrcBuffer& src_buf,
const DstDesc& dst_desc,
DstBuffer& dst_buf)
{
constexpr index_t ntransform_dst = remove_cvref_t<DstDesc>::GetNumOfTransform();
constexpr auto zeros = typename uniform_sequence_gen<ntransform_dst, 0>::type{};
constexpr auto dst_step_hacks =
make_tuple(generate_tuple([&](auto) { return zeros; }, Number<nDim>{}),
generate_tuple([&](auto) { return zeros; }, Number<nDim>{}));
Run(SrcDesc{}, SrcSliceOriginIdx{}, src_buf, dst_desc, dst_buf, dst_step_hacks);
}
__device__ static constexpr auto GetDstCoordinateResetStep() __device__ static constexpr auto GetDstCoordinateResetStep()
{ {
constexpr auto I0 = Number<0>{};
// scalar per access on each dim
// TODO: don't use lambda_scalar_per_access
constexpr auto dst_scalar_per_access = generate_sequence( constexpr auto dst_scalar_per_access = generate_sequence(
detail::lambda_scalar_per_access<DstVectorDim, DstScalarPerVector>{}, Number<nDim>{}); detail::lambda_scalar_per_access<DstVectorDim, DstScalarPerVector>{}, Number<nDim>{});
constexpr auto access_lengths = SliceLengths{} / dst_scalar_per_access; using SpaceFillingCurve = SpaceFillingCurve<SliceLengths,
DimAccessOrder,
constexpr auto dim_access_order = DimAccessOrder{}; remove_cv_t<decltype(dst_scalar_per_access)>>;
constexpr auto ordered_access_lengths =
container_reorder_given_new2old(access_lengths, dim_access_order);
// judge move forward or move backward during the last iteration constexpr auto num_access = SpaceFillingCurve::GetNumOfAccess();
constexpr auto forward_sweep = [&]() { if constexpr(num_access == 0)
StaticallyIndexedArray<bool, nDim> forward_sweep_; {
return typename SpaceFillingCurve::Index{};
forward_sweep_(I0) = true; }
else
static_for<1, nDim, 1>{}([&](auto i) { {
index_t tmp = ordered_access_lengths[I0] - 1; constexpr auto reset_step =
SpaceFillingCurve::GetStepBetween(Number<num_access - 1>{}, Number<0>{});
static_for<1, i, 1>{}([&](auto j) {
tmp = tmp * ordered_access_lengths[j] + ordered_access_lengths[j] - 1;
});
forward_sweep_(i) = tmp % 2 == 0;
});
return forward_sweep_;
}();
// calculate dst data index after last iteration in Run(), if it has not being reset by
// RunWrite()
constexpr auto dst_data_idx = [&]() {
Index ordered_idx;
static_for<0, nDim, 1>{}([&](auto i) {
ordered_idx(i) = forward_sweep[i] ? ordered_access_lengths[i] - 1 : 0;
});
return container_reorder_given_old2new(ordered_idx, dim_access_order) *
dst_scalar_per_access;
}();
//
constexpr auto reset_dst_data_step = [&]() {
Index reset_dst_data_step_;
static_for<0, nDim, 1>{}([&](auto i) { reset_dst_data_step_(i) = -dst_data_idx[i]; });
return reset_dst_data_step_;
}();
return reset_dst_data_step; return reset_step;
}
} }
// dst_slice_origin_step_idx need to be known at compile-time, for performance reason // dst_slice_origin_step_idx need to be known at compile-time, for performance reason
...@@ -383,7 +212,7 @@ struct ThreadwiseTensorSliceTransfer_v1r3 ...@@ -383,7 +212,7 @@ struct ThreadwiseTensorSliceTransfer_v1r3
private: private:
DstCoord dst_coord_; DstCoord dst_coord_;
const DstElementwiseOperation dst_element_op_; const DstElementwiseOperation dst_element_op_;
}; // namespace ck }; // namespace ThreadwiseTensorSliceTransfer_v1r3
// Assume: // Assume:
// 1. src: // 1. src:
...@@ -428,16 +257,12 @@ struct ThreadwiseTensorSliceTransfer_v2 ...@@ -428,16 +257,12 @@ struct ThreadwiseTensorSliceTransfer_v2
src_coord_ = make_tensor_coordinate(src_desc, src_slice_origin_idx); src_coord_ = make_tensor_coordinate(src_desc, src_slice_origin_idx);
} }
template <typename SrcBuffer, template <typename SrcBuffer, typename DstBuffer, typename DstSliceOriginIdx>
typename DstBuffer,
typename DstSliceOriginIdx,
typename SrcStepHacks>
__device__ void Run(const SrcDesc& src_desc, __device__ void Run(const SrcDesc& src_desc,
const SrcBuffer& src_buf, const SrcBuffer& src_buf,
const DstDesc&, const DstDesc&,
const DstSliceOriginIdx&, const DstSliceOriginIdx&,
DstBuffer& dst_buf, DstBuffer& dst_buf)
const SrcStepHacks& src_step_hacks)
{ {
static_assert(DstDesc::IsKnownAtCompileTime(), static_assert(DstDesc::IsKnownAtCompileTime(),
"wrong! DstDesc need to known at compile-time"); "wrong! DstDesc need to known at compile-time");
...@@ -453,9 +278,6 @@ struct ThreadwiseTensorSliceTransfer_v2 ...@@ -453,9 +278,6 @@ struct ThreadwiseTensorSliceTransfer_v2
constexpr auto dst_desc = remove_cvref_t<DstDesc>{}; constexpr auto dst_desc = remove_cvref_t<DstDesc>{};
constexpr auto dst_slice_origin_idx = DstSliceOriginIdx{}; constexpr auto dst_slice_origin_idx = DstSliceOriginIdx{};
constexpr auto I0 = Number<0>{};
constexpr auto I1 = Number<1>{};
// scalar per access on each dim // scalar per access on each dim
// TODO: don't use lambda_scalar_per_access // TODO: don't use lambda_scalar_per_access
constexpr auto src_scalar_per_access = generate_sequence( constexpr auto src_scalar_per_access = generate_sequence(
...@@ -464,80 +286,19 @@ struct ThreadwiseTensorSliceTransfer_v2 ...@@ -464,80 +286,19 @@ struct ThreadwiseTensorSliceTransfer_v2
constexpr auto src_scalar_step_in_vector = constexpr auto src_scalar_step_in_vector =
generate_sequence(detail::lambda_scalar_step_in_vector<SrcVectorDim>{}, Number<nDim>{}); generate_sequence(detail::lambda_scalar_step_in_vector<SrcVectorDim>{}, Number<nDim>{});
constexpr auto access_lengths = SliceLengths{} / src_scalar_per_access; using SpaceFillingCurve = SpaceFillingCurve<SliceLengths,
DimAccessOrder,
constexpr auto dim_access_order = DimAccessOrder{}; remove_cv_t<decltype(src_scalar_per_access)>>;
constexpr auto ordered_access_lengths =
container_reorder_given_new2old(access_lengths, dim_access_order);
// make forward steps
const auto src_forward_steps = generate_tuple(
[&](auto i) {
Index forward_step_idx;
static_for<0, nDim, 1>{}([&](auto j) {
forward_step_idx(j) = (i.value == j.value) ? src_scalar_per_access[i] : 0;
});
return make_tensor_coordinate_step(
src_desc, forward_step_idx, src_step_hacks[I0][i]);
},
Number<nDim>{});
// make backward steps
const auto src_backward_steps = generate_tuple(
[&](auto i) {
Index backward_step_idx;
static_for<0, nDim, 1>{}([&](auto j) {
backward_step_idx(j) = (i.value == j.value) ? -src_scalar_per_access[i] : 0;
});
return make_tensor_coordinate_step(
src_desc, backward_step_idx, src_step_hacks[I1][i]);
},
Number<nDim>{});
// loop over tensor and copy // loop over tensor and copy
static_ford<decltype(ordered_access_lengths)>{}([&](auto ordered_access_idx) { constexpr auto num_access = SpaceFillingCurve::GetNumOfAccess();
// judge move forward or move backward
constexpr auto forward_sweep = [&]() {
StaticallyIndexedArray<bool, nDim> forward_sweep_;
forward_sweep_(I0) = true;
static_for<1, nDim, 1>{}([&](auto i) {
index_t tmp = ordered_access_idx[I0];
static_for<1, i, 1>{}([&](auto j) {
tmp = tmp * ordered_access_lengths[j] + ordered_access_idx[j];
});
forward_sweep_(i) = tmp % 2 == 0;
});
return forward_sweep_;
}();
// calculate src data index
constexpr auto src_data_idx = [&]() {
Index ordered_idx;
static_for<0, nDim, 1>{}([&](auto i) {
ordered_idx(i) = forward_sweep[i]
? ordered_access_idx[i]
: ordered_access_lengths[i] - 1 - ordered_access_idx[i];
});
return container_reorder_given_old2new(ordered_idx, dim_access_order) *
src_scalar_per_access;
}();
static_for<0, num_access, 1>{}([&](auto idx_1d) {
typename vector_type_maker<SrcData, SrcScalarPerVector>::type src_vector; typename vector_type_maker<SrcData, SrcScalarPerVector>::type src_vector;
using src_vector_t = using src_vector_t =
typename vector_type_maker<SrcData, SrcScalarPerVector>::type::type; typename vector_type_maker<SrcData, SrcScalarPerVector>::type::type;
constexpr auto src_data_idx = SpaceFillingCurve::GetIndex(idx_1d);
const bool is_src_valid = const bool is_src_valid =
coordinate_has_valid_offset_assuming_visible_index_is_valid(src_desc, src_coord_); coordinate_has_valid_offset_assuming_visible_index_is_valid(src_desc, src_coord_);
...@@ -555,38 +316,13 @@ struct ThreadwiseTensorSliceTransfer_v2 ...@@ -555,38 +316,13 @@ struct ThreadwiseTensorSliceTransfer_v2
dst_buf(Number<dst_offset>{}) = src_vector.template AsType<SrcData>()[i]; dst_buf(Number<dst_offset>{}) = src_vector.template AsType<SrcData>()[i];
}); });
constexpr auto move_on_dim = [&]() constexpr if constexpr(idx_1d.value != num_access - 1)
{ {
StaticallyIndexedArray<bool, nDim> move_on_dim_; constexpr auto forward_step = SpaceFillingCurve::GetForwardStep(idx_1d);
static_for<0, nDim, 1>{}([&](auto i) {
move_on_dim_(i) = ordered_access_idx[i] < ordered_access_lengths[i] - 1;
static_for<i + 1, nDim, 1>{}([&](auto j) {
move_on_dim_(i) &= ordered_access_idx[j] == ordered_access_lengths[j] - 1;
});
});
return move_on_dim_; move_tensor_coordinate(
src_desc, src_coord_, make_tensor_coordinate_step(src_desc, forward_step));
} }
();
// move
static_for<0, nDim, 1>{}([&](auto i) {
if constexpr(move_on_dim[i])
{
if constexpr(forward_sweep[i])
{
move_tensor_coordinate(
src_desc, src_coord_, src_forward_steps[dim_access_order[i]]);
}
else
{
move_tensor_coordinate(
src_desc, src_coord_, src_backward_steps[dim_access_order[i]]);
}
}
});
}); });
// move src coordinate back to slice origin (or not) // move src coordinate back to slice origin (or not)
...@@ -599,82 +335,27 @@ struct ThreadwiseTensorSliceTransfer_v2 ...@@ -599,82 +335,27 @@ struct ThreadwiseTensorSliceTransfer_v2
} }
} }
template <typename SrcBuffer, typename DstBuffer, typename DstSliceOriginIdx>
__device__ void Run(const SrcDesc& src_desc,
const SrcBuffer& src_buf,
const DstDesc&,
const DstSliceOriginIdx&,
DstBuffer& dst_buf)
{
constexpr index_t ntransform_src = SrcDesc::GetNumOfTransform();
constexpr auto zeros = typename uniform_sequence_gen<ntransform_src, 0>::type{};
constexpr auto src_step_hacks =
make_tuple(generate_tuple([&](auto) { return zeros; }, Number<nDim>{}),
generate_tuple([&](auto) { return zeros; }, Number<nDim>{}));
Run(src_desc, src_buf, DstDesc{}, DstSliceOriginIdx{}, dst_buf, src_step_hacks);
}
__device__ static constexpr auto GetSrcCoordinateResetStep() __device__ static constexpr auto GetSrcCoordinateResetStep()
{ {
constexpr auto I0 = Number<0>{};
// scalar per access on each dim
// TODO: don't use lambda_scalar_per_access
constexpr auto src_scalar_per_access = generate_sequence( constexpr auto src_scalar_per_access = generate_sequence(
detail::lambda_scalar_per_access<SrcVectorDim, SrcScalarPerVector>{}, Number<nDim>{}); detail::lambda_scalar_per_access<SrcVectorDim, SrcScalarPerVector>{}, Number<nDim>{});
constexpr auto access_lengths = SliceLengths{} / src_scalar_per_access; using SpaceFillingCurve = SpaceFillingCurve<SliceLengths,
DimAccessOrder,
constexpr auto dim_access_order = DimAccessOrder{}; remove_cv_t<decltype(src_scalar_per_access)>>;
constexpr auto ordered_access_lengths =
container_reorder_given_new2old(access_lengths, dim_access_order);
// judge move forward or move backward during the last iteration
constexpr auto forward_sweep = [&]() {
StaticallyIndexedArray<bool, nDim> forward_sweep_;
forward_sweep_(I0) = true; constexpr auto num_access = SpaceFillingCurve::GetNumOfAccess();
if constexpr(num_access == 0)
static_for<1, nDim, 1>{}([&](auto i) { {
index_t tmp = ordered_access_lengths[I0] - 1; return typename SpaceFillingCurve::Index{};
}
static_for<1, i, 1>{}([&](auto j) { else
tmp = tmp * ordered_access_lengths[j] + ordered_access_lengths[j] - 1; {
}); constexpr auto reset_step =
SpaceFillingCurve::GetStepBetween(Number<num_access - 1>{}, Number<0>{});
forward_sweep_(i) = tmp % 2 == 0;
});
return forward_sweep_;
}();
// calculate src data index after last iteration in Run(), if it has not being reset by
// RunWrite()
constexpr auto src_data_idx = [&]() {
Index ordered_idx;
static_for<0, nDim, 1>{}([&](auto i) {
ordered_idx(i) = forward_sweep[i] ? ordered_access_lengths[i] - 1 : 0;
});
return container_reorder_given_old2new(ordered_idx, dim_access_order) *
src_scalar_per_access;
}();
//
constexpr auto reset_src_data_step = [&]() {
Index reset_src_data_step_;
static_for<0, nDim, 1>{}([&](auto i) { reset_src_data_step_(i) = -src_data_idx[i]; });
return reset_src_data_step_;
}();
return reset_src_data_step; return reset_step;
}
} }
// dst_slice_origin_step_idx need to be known at compile-time, for performance reason // dst_slice_origin_step_idx need to be known at compile-time, for performance reason
......
...@@ -4,6 +4,7 @@ ...@@ -4,6 +4,7 @@
#include "common_header.hpp" #include "common_header.hpp"
#include "tensor_descriptor.hpp" #include "tensor_descriptor.hpp"
#include "tensor_descriptor_helper.hpp" #include "tensor_descriptor_helper.hpp"
#include "tensor_space_filling_curve.hpp"
namespace ck { namespace ck {
...@@ -40,9 +41,6 @@ struct ThreadwiseTensorSliceTransfer_v6r1 ...@@ -40,9 +41,6 @@ struct ThreadwiseTensorSliceTransfer_v6r1
using SrcCoord = decltype(make_tensor_coordinate(SrcDesc{}, Index{})); using SrcCoord = decltype(make_tensor_coordinate(SrcDesc{}, Index{}));
using DstCoord = decltype(make_tensor_coordinate(DstDesc{}, Index{})); using DstCoord = decltype(make_tensor_coordinate(DstDesc{}, Index{}));
using SrcCoordStep = decltype(make_tensor_coordinate_step(SrcDesc{}, Index{}));
using DstCoordStep = decltype(make_tensor_coordinate_step(DstDesc{}, Index{}));
static constexpr auto I0 = Number<0>{}; static constexpr auto I0 = Number<0>{};
__device__ constexpr ThreadwiseTensorSliceTransfer_v6r1(const SrcDesc& src_desc, __device__ constexpr ThreadwiseTensorSliceTransfer_v6r1(const SrcDesc& src_desc,
...@@ -79,70 +77,14 @@ struct ThreadwiseTensorSliceTransfer_v6r1 ...@@ -79,70 +77,14 @@ struct ThreadwiseTensorSliceTransfer_v6r1
constexpr auto scalar_per_access = generate_sequence( constexpr auto scalar_per_access = generate_sequence(
detail::lambda_scalar_per_access<VectorDim, ScalarPerVector>{}, Number<nDim>{}); detail::lambda_scalar_per_access<VectorDim, ScalarPerVector>{}, Number<nDim>{});
constexpr auto access_lengths = SliceLengths{} / scalar_per_access; using SpaceFillingCurve = SpaceFillingCurve<SliceLengths,
DimAccessOrder,
constexpr auto dim_access_order = DimAccessOrder{}; remove_cv_t<decltype(scalar_per_access)>>;
constexpr auto ordered_access_lengths =
container_reorder_given_new2old(access_lengths, dim_access_order);
auto make_forward_steps = [&](auto desc) {
return generate_tuple(
[&](auto i) {
Index forward_step_idx;
static_for<0, nDim, 1>{}([&](auto j) {
forward_step_idx(j) = (i.value == j.value) ? scalar_per_access[i] : 0;
});
return make_tensor_coordinate_step(desc, forward_step_idx);
},
Number<nDim>{});
};
auto make_backward_steps = [&](auto desc) {
return generate_tuple(
[&](auto i) {
Index backward_step_idx;
static_for<0, nDim, 1>{}([&](auto j) {
backward_step_idx(j) = (i.value == j.value) ? -scalar_per_access[i] : 0;
});
return make_tensor_coordinate_step(desc, backward_step_idx);
},
Number<nDim>{});
};
// make forward steps
const auto src_forward_steps = make_forward_steps(src_desc);
const auto dst_forward_steps = make_forward_steps(dst_desc);
// make backward steps
const auto src_backward_steps = make_backward_steps(src_desc);
const auto dst_backward_steps = make_backward_steps(dst_desc);
// loop over slice window
static_ford<decltype(ordered_access_lengths)>{}([&](auto ordered_access_idx) {
// judge move forward or move backward
constexpr auto forward_sweep = [&]() {
StaticallyIndexedArray<bool, nDim> forward_sweep_;
forward_sweep_(I0) = true; // loop over space-filling curve
constexpr auto num_access = SpaceFillingCurve::GetNumOfAccess();
static_for<1, nDim, 1>{}([&](auto i) {
index_t tmp = ordered_access_idx[I0];
static_for<1, i, 1>{}([&](auto j) {
tmp = tmp * ordered_access_lengths[j] + ordered_access_idx[j];
});
forward_sweep_(i) = tmp % 2 == 0;
});
return forward_sweep_;
}();
static_for<0, num_access, 1>{}([&](auto idx_1d) {
using src_vector_type = vector_type_maker_t<SrcData, ScalarPerVector>; using src_vector_type = vector_type_maker_t<SrcData, ScalarPerVector>;
using src_vector_t = typename src_vector_type::type; using src_vector_t = typename src_vector_type::type;
...@@ -168,59 +110,20 @@ struct ThreadwiseTensorSliceTransfer_v6r1 ...@@ -168,59 +110,20 @@ struct ThreadwiseTensorSliceTransfer_v6r1
coordinate_has_valid_offset_assuming_visible_index_is_valid(dst_desc, dst_coord_); coordinate_has_valid_offset_assuming_visible_index_is_valid(dst_desc, dst_coord_);
// copy data from dst_vector into dst_buf // copy data from dst_vector into dst_buf
if constexpr(DstInMemOp == InMemoryDataOperationEnum_t::Set) dst_buf.template Update<DstInMemOp, dst_vector_t>(
{ dst_coord_.GetOffset(),
dst_buf.template Set<dst_vector_t>( is_dst_valid,
dst_coord_.GetOffset(), dst_vector_container.template AsType<dst_vector_t>()[I0]);
is_dst_valid,
dst_vector_container.template AsType<dst_vector_t>()[I0]);
}
else if constexpr(DstInMemOp == InMemoryDataOperationEnum_t::AtomicAdd)
{
dst_buf.template AtomicAdd<dst_vector_t>(
dst_coord_.GetOffset(),
is_dst_valid,
dst_vector_container.template AsType<dst_vector_t>()[I0]);
}
constexpr auto move_on_dim = [&]() constexpr // move coordinate
if constexpr(idx_1d.value != num_access - 1)
{ {
StaticallyIndexedArray<bool, nDim> move_on_dim_; constexpr auto forward_step = SpaceFillingCurve::GetForwardStep(idx_1d);
move_tensor_coordinate(
static_for<0, nDim, 1>{}([&](auto i) { src_desc, src_coord_, make_tensor_coordinate_step(src_desc, forward_step));
move_on_dim_(i) = ordered_access_idx[i] < ordered_access_lengths[i] - 1; move_tensor_coordinate(
dst_desc, dst_coord_, make_tensor_coordinate_step(dst_desc, forward_step));
static_for<i + 1, nDim, 1>{}([&](auto j) {
move_on_dim_(i) &= ordered_access_idx[j] == ordered_access_lengths[j] - 1;
});
});
return move_on_dim_;
} }
();
// move coordinate
static_for<0, nDim, 1>{}([&](auto i) {
if constexpr(move_on_dim[i])
{
if constexpr(forward_sweep[i])
{
move_tensor_coordinate(
src_desc, src_coord_, src_forward_steps[dim_access_order[i]]);
move_tensor_coordinate(
dst_desc, dst_coord_, dst_forward_steps[dim_access_order[i]]);
}
else
{
move_tensor_coordinate(
src_desc, src_coord_, src_backward_steps[dim_access_order[i]]);
move_tensor_coordinate(
dst_desc, dst_coord_, dst_backward_steps[dim_access_order[i]]);
}
}
});
}); });
// move coordinate back to slice origin (or not) // move coordinate back to slice origin (or not)
...@@ -243,59 +146,25 @@ struct ThreadwiseTensorSliceTransfer_v6r1 ...@@ -243,59 +146,25 @@ struct ThreadwiseTensorSliceTransfer_v6r1
__device__ static constexpr auto GetCoordinateResetStep() __device__ static constexpr auto GetCoordinateResetStep()
{ {
// scalar per access on each dim
// TODO: don't use lambda_scalar_per_access
constexpr auto scalar_per_access = generate_sequence( constexpr auto scalar_per_access = generate_sequence(
detail::lambda_scalar_per_access<VectorDim, ScalarPerVector>{}, Number<nDim>{}); detail::lambda_scalar_per_access<VectorDim, ScalarPerVector>{}, Number<nDim>{});
constexpr auto access_lengths = SliceLengths{} / scalar_per_access; using SpaceFillingCurve = SpaceFillingCurve<SliceLengths,
DimAccessOrder,
constexpr auto dim_access_order = DimAccessOrder{}; remove_cv_t<decltype(scalar_per_access)>>;
constexpr auto ordered_access_lengths =
container_reorder_given_new2old(access_lengths, dim_access_order);
// judge move forward or move backward during the last iteration
constexpr auto forward_sweep = [&]() {
StaticallyIndexedArray<bool, nDim> forward_sweep_;
forward_sweep_(I0) = true;
static_for<1, nDim, 1>{}([&](auto i) {
index_t tmp = ordered_access_lengths[I0] - 1;
static_for<1, i, 1>{}([&](auto j) {
tmp = tmp * ordered_access_lengths[j] + ordered_access_lengths[j] - 1;
});
forward_sweep_(i) = tmp % 2 == 0;
});
return forward_sweep_;
}();
// calculate data index after last iteration in Run(), if it has not being reset
constexpr auto data_idx = [&]() {
Index ordered_idx;
static_for<0, nDim, 1>{}([&](auto i) {
ordered_idx(i) = forward_sweep[i] ? ordered_access_lengths[i] - 1 : 0;
});
return container_reorder_given_old2new(ordered_idx, dim_access_order) *
scalar_per_access;
}();
// constexpr auto num_access = SpaceFillingCurve::GetNumOfAccess();
constexpr auto reset_data_step = [&]() { if constexpr(num_access == 0)
Index reset_data_step_; {
return typename SpaceFillingCurve::Index{};
static_for<0, nDim, 1>{}([&](auto i) { reset_data_step_(i) = -data_idx[i]; }); }
else
return reset_data_step_; {
}(); constexpr auto reset_step =
SpaceFillingCurve::GetStepBetween(Number<num_access - 1>{}, Number<0>{});
return reset_data_step; return reset_step;
}
} }
// src_slice_origin_step_idx need to be known at compile-time, for performance reason // src_slice_origin_step_idx need to be known at compile-time, for performance reason
...@@ -332,7 +201,7 @@ struct ThreadwiseTensorSliceTransfer_v6r1 ...@@ -332,7 +201,7 @@ struct ThreadwiseTensorSliceTransfer_v6r1
SrcCoord src_coord_; SrcCoord src_coord_;
DstCoord dst_coord_; DstCoord dst_coord_;
const ElementwiseOperation element_op_; const ElementwiseOperation element_op_;
}; }; // namespace ck
} // namespace ck } // namespace ck
#endif #endif
...@@ -4,6 +4,7 @@ ...@@ -4,6 +4,7 @@
#include "common_header.hpp" #include "common_header.hpp"
#include "tensor_descriptor.hpp" #include "tensor_descriptor.hpp"
#include "tensor_descriptor_helper.hpp" #include "tensor_descriptor_helper.hpp"
#include "tensor_space_filling_curve.hpp"
namespace ck { namespace ck {
...@@ -44,10 +45,6 @@ struct ThreadwiseTensorSliceTransfer_v6r2 ...@@ -44,10 +45,6 @@ struct ThreadwiseTensorSliceTransfer_v6r2
using Src1Coord = decltype(make_tensor_coordinate(Src1Desc{}, Index{})); using Src1Coord = decltype(make_tensor_coordinate(Src1Desc{}, Index{}));
using DstCoord = decltype(make_tensor_coordinate(DstDesc{}, Index{})); using DstCoord = decltype(make_tensor_coordinate(DstDesc{}, Index{}));
using Src0CoordStep = decltype(make_tensor_coordinate_step(Src0Desc{}, Index{}));
using Src1CoordStep = decltype(make_tensor_coordinate_step(Src1Desc{}, Index{}));
using DstCoordStep = decltype(make_tensor_coordinate_step(DstDesc{}, Index{}));
static constexpr auto I0 = Number<0>{}; static constexpr auto I0 = Number<0>{};
__device__ constexpr ThreadwiseTensorSliceTransfer_v6r2(const Src0Desc& src0_desc, __device__ constexpr ThreadwiseTensorSliceTransfer_v6r2(const Src0Desc& src0_desc,
...@@ -96,72 +93,14 @@ struct ThreadwiseTensorSliceTransfer_v6r2 ...@@ -96,72 +93,14 @@ struct ThreadwiseTensorSliceTransfer_v6r2
constexpr auto scalar_per_access = generate_sequence( constexpr auto scalar_per_access = generate_sequence(
detail::lambda_scalar_per_access<VectorDim, ScalarPerVector>{}, Number<nDim>{}); detail::lambda_scalar_per_access<VectorDim, ScalarPerVector>{}, Number<nDim>{});
constexpr auto access_lengths = SliceLengths{} / scalar_per_access; using SpaceFillingCurve = SpaceFillingCurve<SliceLengths,
DimAccessOrder,
constexpr auto dim_access_order = DimAccessOrder{}; remove_cv_t<decltype(scalar_per_access)>>;
constexpr auto ordered_access_lengths =
container_reorder_given_new2old(access_lengths, dim_access_order);
auto make_forward_steps = [&](auto desc) {
return generate_tuple(
[&](auto i) {
Index forward_step_idx;
static_for<0, nDim, 1>{}([&](auto j) {
forward_step_idx(j) = (i.value == j.value) ? scalar_per_access[i] : 0;
});
return make_tensor_coordinate_step(desc, forward_step_idx);
},
Number<nDim>{});
};
auto make_backward_steps = [&](auto desc) {
return generate_tuple(
[&](auto i) {
Index backward_step_idx;
static_for<0, nDim, 1>{}([&](auto j) {
backward_step_idx(j) = (i.value == j.value) ? -scalar_per_access[i] : 0;
});
return make_tensor_coordinate_step(desc, backward_step_idx);
},
Number<nDim>{});
};
// make forward steps
const auto src0_forward_steps = make_forward_steps(src0_desc);
const auto src1_forward_steps = make_forward_steps(src1_desc);
const auto dst_forward_steps = make_forward_steps(dst_desc);
// make backward steps
const auto src0_backward_steps = make_backward_steps(src0_desc);
const auto src1_backward_steps = make_backward_steps(src1_desc);
const auto dst_backward_steps = make_backward_steps(dst_desc);
// loop over slice window
static_ford<decltype(ordered_access_lengths)>{}([&](auto ordered_access_idx) {
// judge move forward or move backward
constexpr auto forward_sweep = [&]() {
StaticallyIndexedArray<bool, nDim> forward_sweep_;
forward_sweep_(I0) = true; constexpr auto num_access = SpaceFillingCurve::GetNumOfAccess();
static_for<1, nDim, 1>{}([&](auto i) {
index_t tmp = ordered_access_idx[I0];
static_for<1, i, 1>{}([&](auto j) {
tmp = tmp * ordered_access_lengths[j] + ordered_access_idx[j];
});
forward_sweep_(i) = tmp % 2 == 0;
});
return forward_sweep_;
}();
// loop over space-filling curve
static_for<0, num_access, 1>{}([&](auto idx_1d) {
using src0_vector_type = vector_type_maker_t<Src0Data, ScalarPerVector>; using src0_vector_type = vector_type_maker_t<Src0Data, ScalarPerVector>;
using src0_vector_t = typename src0_vector_type::type; using src0_vector_t = typename src0_vector_type::type;
...@@ -197,65 +136,22 @@ struct ThreadwiseTensorSliceTransfer_v6r2 ...@@ -197,65 +136,22 @@ struct ThreadwiseTensorSliceTransfer_v6r2
coordinate_has_valid_offset_assuming_visible_index_is_valid(dst_desc, dst_coord_); coordinate_has_valid_offset_assuming_visible_index_is_valid(dst_desc, dst_coord_);
// copy data from dst_vector into dst_buf // copy data from dst_vector into dst_buf
if constexpr(DstInMemOp == InMemoryDataOperationEnum_t::Set) dst_buf.template Update<DstInMemOp, dst_vector_t>(
{ dst_coord_.GetOffset(),
dst_buf.template Set<dst_vector_t>( is_dst_valid,
dst_coord_.GetOffset(), dst_vector_container.template AsType<dst_vector_t>()[I0]);
is_dst_valid,
dst_vector_container.template AsType<dst_vector_t>()[I0]);
}
else if constexpr(DstInMemOp == InMemoryDataOperationEnum_t::AtomicAdd)
{
dst_buf.template AtomicAdd<dst_vector_t>(
dst_coord_.GetOffset(),
is_dst_valid,
dst_vector_container.template AsType<dst_vector_t>()[I0]);
}
constexpr auto move_on_dim = [&]() constexpr // move coordinate
if constexpr(idx_1d.value != num_access - 1)
{ {
StaticallyIndexedArray<bool, nDim> move_on_dim_; constexpr auto forward_step = SpaceFillingCurve::GetForwardStep(idx_1d);
move_tensor_coordinate(
static_for<0, nDim, 1>{}([&](auto i) { src0_desc, src0_coord_, make_tensor_coordinate_step(src0_desc, forward_step));
move_on_dim_(i) = ordered_access_idx[i] < ordered_access_lengths[i] - 1; move_tensor_coordinate(
src1_desc, src1_coord_, make_tensor_coordinate_step(src1_desc, forward_step));
static_for<i + 1, nDim, 1>{}([&](auto j) { move_tensor_coordinate(
move_on_dim_(i) &= ordered_access_idx[j] == ordered_access_lengths[j] - 1; dst_desc, dst_coord_, make_tensor_coordinate_step(dst_desc, forward_step));
});
});
return move_on_dim_;
} }
();
// move coordinate
static_for<0, nDim, 1>{}([&](auto i) {
if constexpr(move_on_dim[i])
{
if constexpr(forward_sweep[i])
{
move_tensor_coordinate(
src0_desc, src0_coord_, src0_forward_steps[dim_access_order[i]]);
move_tensor_coordinate(
src1_desc, src1_coord_, src1_forward_steps[dim_access_order[i]]);
move_tensor_coordinate(
dst_desc, dst_coord_, dst_forward_steps[dim_access_order[i]]);
}
else
{
move_tensor_coordinate(
src0_desc, src0_coord_, src0_backward_steps[dim_access_order[i]]);
move_tensor_coordinate(
src1_desc, src1_coord_, src1_backward_steps[dim_access_order[i]]);
move_tensor_coordinate(
dst_desc, dst_coord_, dst_backward_steps[dim_access_order[i]]);
}
}
});
}); });
// move coordinate back to slice origin (or not) // move coordinate back to slice origin (or not)
...@@ -286,59 +182,25 @@ struct ThreadwiseTensorSliceTransfer_v6r2 ...@@ -286,59 +182,25 @@ struct ThreadwiseTensorSliceTransfer_v6r2
__device__ static constexpr auto GetCoordinateResetStep() __device__ static constexpr auto GetCoordinateResetStep()
{ {
// scalar per access on each dim
// TODO: don't use lambda_scalar_per_access
constexpr auto scalar_per_access = generate_sequence( constexpr auto scalar_per_access = generate_sequence(
detail::lambda_scalar_per_access<VectorDim, ScalarPerVector>{}, Number<nDim>{}); detail::lambda_scalar_per_access<VectorDim, ScalarPerVector>{}, Number<nDim>{});
constexpr auto access_lengths = SliceLengths{} / scalar_per_access; using SpaceFillingCurve = SpaceFillingCurve<SliceLengths,
DimAccessOrder,
constexpr auto dim_access_order = DimAccessOrder{}; remove_cv_t<decltype(scalar_per_access)>>;
constexpr auto ordered_access_lengths =
container_reorder_given_new2old(access_lengths, dim_access_order);
// judge move forward or move backward during the last iteration
constexpr auto forward_sweep = [&]() {
StaticallyIndexedArray<bool, nDim> forward_sweep_;
forward_sweep_(I0) = true;
static_for<1, nDim, 1>{}([&](auto i) {
index_t tmp = ordered_access_lengths[I0] - 1;
static_for<1, i, 1>{}([&](auto j) {
tmp = tmp * ordered_access_lengths[j] + ordered_access_lengths[j] - 1;
});
forward_sweep_(i) = tmp % 2 == 0;
});
return forward_sweep_;
}();
// calculate data index after last iteration in Run(), if it has not being reset
constexpr auto data_idx = [&]() {
Index ordered_idx;
static_for<0, nDim, 1>{}([&](auto i) {
ordered_idx(i) = forward_sweep[i] ? ordered_access_lengths[i] - 1 : 0;
});
return container_reorder_given_old2new(ordered_idx, dim_access_order) *
scalar_per_access;
}();
// constexpr auto num_access = SpaceFillingCurve::GetNumOfAccess();
constexpr auto reset_data_step = [&]() { if constexpr(num_access == 0)
Index reset_data_step_; {
return typename SpaceFillingCurve::Index{};
static_for<0, nDim, 1>{}([&](auto i) { reset_data_step_(i) = -data_idx[i]; }); }
else
return reset_data_step_; {
}(); constexpr auto reset_step =
SpaceFillingCurve::GetStepBetween(Number<num_access - 1>{}, Number<0>{});
return reset_data_step; return reset_step;
}
} }
// src_slice_origin_step_idx need to be known at compile-time, for performance reason // src_slice_origin_step_idx need to be known at compile-time, for performance reason
......
...@@ -4,6 +4,7 @@ ...@@ -4,6 +4,7 @@
#include "common_header.hpp" #include "common_header.hpp"
#include "tensor_descriptor.hpp" #include "tensor_descriptor.hpp"
#include "tensor_descriptor_helper.hpp" #include "tensor_descriptor_helper.hpp"
#include "tensor_space_filling_curve.hpp"
namespace ck { namespace ck {
...@@ -48,11 +49,6 @@ struct ThreadwiseTensorSliceTransfer_v6r3 ...@@ -48,11 +49,6 @@ struct ThreadwiseTensorSliceTransfer_v6r3
using Src2Coord = decltype(make_tensor_coordinate(Src2Desc{}, Index{})); using Src2Coord = decltype(make_tensor_coordinate(Src2Desc{}, Index{}));
using DstCoord = decltype(make_tensor_coordinate(DstDesc{}, Index{})); using DstCoord = decltype(make_tensor_coordinate(DstDesc{}, Index{}));
using Src0CoordStep = decltype(make_tensor_coordinate_step(Src0Desc{}, Index{}));
using Src1CoordStep = decltype(make_tensor_coordinate_step(Src1Desc{}, Index{}));
using Src2CoordStep = decltype(make_tensor_coordinate_step(Src2Desc{}, Index{}));
using DstCoordStep = decltype(make_tensor_coordinate_step(DstDesc{}, Index{}));
static constexpr auto I0 = Number<0>{}; static constexpr auto I0 = Number<0>{};
__device__ constexpr ThreadwiseTensorSliceTransfer_v6r3(const Src0Desc& src0_desc, __device__ constexpr ThreadwiseTensorSliceTransfer_v6r3(const Src0Desc& src0_desc,
...@@ -112,74 +108,14 @@ struct ThreadwiseTensorSliceTransfer_v6r3 ...@@ -112,74 +108,14 @@ struct ThreadwiseTensorSliceTransfer_v6r3
constexpr auto scalar_per_access = generate_sequence( constexpr auto scalar_per_access = generate_sequence(
detail::lambda_scalar_per_access<VectorDim, ScalarPerVector>{}, Number<nDim>{}); detail::lambda_scalar_per_access<VectorDim, ScalarPerVector>{}, Number<nDim>{});
constexpr auto access_lengths = SliceLengths{} / scalar_per_access; using SpaceFillingCurve = SpaceFillingCurve<SliceLengths,
DimAccessOrder,
constexpr auto dim_access_order = DimAccessOrder{}; remove_cv_t<decltype(scalar_per_access)>>;
constexpr auto ordered_access_lengths =
container_reorder_given_new2old(access_lengths, dim_access_order);
auto make_forward_steps = [&](auto desc) {
return generate_tuple(
[&](auto i) {
Index forward_step_idx;
static_for<0, nDim, 1>{}([&](auto j) {
forward_step_idx(j) = (i.value == j.value) ? scalar_per_access[i] : 0;
});
return make_tensor_coordinate_step(desc, forward_step_idx);
},
Number<nDim>{});
};
auto make_backward_steps = [&](auto desc) {
return generate_tuple(
[&](auto i) {
Index backward_step_idx;
static_for<0, nDim, 1>{}([&](auto j) {
backward_step_idx(j) = (i.value == j.value) ? -scalar_per_access[i] : 0;
});
return make_tensor_coordinate_step(desc, backward_step_idx);
},
Number<nDim>{});
};
// make forward steps
const auto src0_forward_steps = make_forward_steps(src0_desc);
const auto src1_forward_steps = make_forward_steps(src1_desc);
const auto src2_forward_steps = make_forward_steps(src2_desc);
const auto dst_forward_steps = make_forward_steps(dst_desc);
// make backward steps
const auto src0_backward_steps = make_backward_steps(src0_desc);
const auto src1_backward_steps = make_backward_steps(src1_desc);
const auto src2_backward_steps = make_backward_steps(src2_desc);
const auto dst_backward_steps = make_backward_steps(dst_desc);
// loop over slice window
static_ford<decltype(ordered_access_lengths)>{}([&](auto ordered_access_idx) {
// judge move forward or move backward
constexpr auto forward_sweep = [&]() {
StaticallyIndexedArray<bool, nDim> forward_sweep_;
forward_sweep_(I0) = true; constexpr auto num_access = SpaceFillingCurve::GetNumOfAccess();
static_for<1, nDim, 1>{}([&](auto i) {
index_t tmp = ordered_access_idx[I0];
static_for<1, i, 1>{}([&](auto j) {
tmp = tmp * ordered_access_lengths[j] + ordered_access_idx[j];
});
forward_sweep_(i) = tmp % 2 == 0;
});
return forward_sweep_;
}();
// loop over space-filling curve
static_for<0, num_access, 1>{}([&](auto idx_1d) {
using src0_vector_type = vector_type_maker_t<Src0Data, ScalarPerVector>; using src0_vector_type = vector_type_maker_t<Src0Data, ScalarPerVector>;
using src0_vector_t = typename src0_vector_type::type; using src0_vector_t = typename src0_vector_type::type;
...@@ -224,72 +160,24 @@ struct ThreadwiseTensorSliceTransfer_v6r3 ...@@ -224,72 +160,24 @@ struct ThreadwiseTensorSliceTransfer_v6r3
const bool is_dst_valid = const bool is_dst_valid =
coordinate_has_valid_offset_assuming_visible_index_is_valid(dst_desc, dst_coord_); coordinate_has_valid_offset_assuming_visible_index_is_valid(dst_desc, dst_coord_);
// copy data from dst_vector into dst_buf dst_buf.template Update<DstInMemOp, dst_vector_t>(
if constexpr(DstInMemOp == InMemoryDataOperationEnum_t::Set) dst_coord_.GetOffset(),
{ is_dst_valid,
dst_buf.template Set<dst_vector_t>( dst_vector_container.template AsType<dst_vector_t>()[I0]);
dst_coord_.GetOffset(),
is_dst_valid,
dst_vector_container.template AsType<dst_vector_t>()[I0]);
}
else if constexpr(DstInMemOp == InMemoryDataOperationEnum_t::AtomicAdd)
{
dst_buf.template AtomicAdd<dst_vector_t>(
dst_coord_.GetOffset(),
is_dst_valid,
dst_vector_container.template AsType<dst_vector_t>()[I0]);
}
constexpr auto move_on_dim = [&]() constexpr // move coordinate
if constexpr(idx_1d.value != num_access - 1)
{ {
StaticallyIndexedArray<bool, nDim> move_on_dim_; constexpr auto forward_step = SpaceFillingCurve::GetForwardStep(idx_1d);
move_tensor_coordinate(
static_for<0, nDim, 1>{}([&](auto i) { src0_desc, src0_coord_, make_tensor_coordinate_step(src0_desc, forward_step));
move_on_dim_(i) = ordered_access_idx[i] < ordered_access_lengths[i] - 1; move_tensor_coordinate(
src1_desc, src1_coord_, make_tensor_coordinate_step(src1_desc, forward_step));
static_for<i + 1, nDim, 1>{}([&](auto j) { move_tensor_coordinate(
move_on_dim_(i) &= ordered_access_idx[j] == ordered_access_lengths[j] - 1; src2_desc, src2_coord_, make_tensor_coordinate_step(src2_desc, forward_step));
}); move_tensor_coordinate(
}); dst_desc, dst_coord_, make_tensor_coordinate_step(dst_desc, forward_step));
return move_on_dim_;
} }
();
// move coordinate
static_for<0, nDim, 1>{}([&](auto i) {
if constexpr(move_on_dim[i])
{
if constexpr(forward_sweep[i])
{
move_tensor_coordinate(
src0_desc, src0_coord_, src0_forward_steps[dim_access_order[i]]);
move_tensor_coordinate(
src1_desc, src1_coord_, src1_forward_steps[dim_access_order[i]]);
move_tensor_coordinate(
src2_desc, src2_coord_, src2_forward_steps[dim_access_order[i]]);
move_tensor_coordinate(
dst_desc, dst_coord_, dst_forward_steps[dim_access_order[i]]);
}
else
{
move_tensor_coordinate(
src0_desc, src0_coord_, src0_backward_steps[dim_access_order[i]]);
move_tensor_coordinate(
src1_desc, src1_coord_, src1_backward_steps[dim_access_order[i]]);
move_tensor_coordinate(
src2_desc, src2_coord_, src2_backward_steps[dim_access_order[i]]);
move_tensor_coordinate(
dst_desc, dst_coord_, dst_backward_steps[dim_access_order[i]]);
}
}
});
}); });
// move coordinate back to slice origin (or not) // move coordinate back to slice origin (or not)
...@@ -328,59 +216,25 @@ struct ThreadwiseTensorSliceTransfer_v6r3 ...@@ -328,59 +216,25 @@ struct ThreadwiseTensorSliceTransfer_v6r3
__device__ static constexpr auto GetCoordinateResetStep() __device__ static constexpr auto GetCoordinateResetStep()
{ {
// scalar per access on each dim
// TODO: don't use lambda_scalar_per_access
constexpr auto scalar_per_access = generate_sequence( constexpr auto scalar_per_access = generate_sequence(
detail::lambda_scalar_per_access<VectorDim, ScalarPerVector>{}, Number<nDim>{}); detail::lambda_scalar_per_access<VectorDim, ScalarPerVector>{}, Number<nDim>{});
constexpr auto access_lengths = SliceLengths{} / scalar_per_access; using SpaceFillingCurve = SpaceFillingCurve<SliceLengths,
DimAccessOrder,
constexpr auto dim_access_order = DimAccessOrder{}; remove_cv_t<decltype(scalar_per_access)>>;
constexpr auto ordered_access_lengths =
container_reorder_given_new2old(access_lengths, dim_access_order);
// judge move forward or move backward during the last iteration
constexpr auto forward_sweep = [&]() {
StaticallyIndexedArray<bool, nDim> forward_sweep_;
forward_sweep_(I0) = true;
static_for<1, nDim, 1>{}([&](auto i) {
index_t tmp = ordered_access_lengths[I0] - 1;
static_for<1, i, 1>{}([&](auto j) {
tmp = tmp * ordered_access_lengths[j] + ordered_access_lengths[j] - 1;
});
forward_sweep_(i) = tmp % 2 == 0;
});
return forward_sweep_;
}();
// calculate data index after last iteration in Run(), if it has not being reset
constexpr auto data_idx = [&]() {
Index ordered_idx;
static_for<0, nDim, 1>{}([&](auto i) {
ordered_idx(i) = forward_sweep[i] ? ordered_access_lengths[i] - 1 : 0;
});
return container_reorder_given_old2new(ordered_idx, dim_access_order) *
scalar_per_access;
}();
// constexpr auto num_access = SpaceFillingCurve::GetNumOfAccess();
constexpr auto reset_data_step = [&]() { if constexpr(num_access == 0)
Index reset_data_step_; {
return typename SpaceFillingCurve::Index{};
static_for<0, nDim, 1>{}([&](auto i) { reset_data_step_(i) = -data_idx[i]; }); }
else
return reset_data_step_; {
}(); constexpr auto reset_step =
SpaceFillingCurve::GetStepBetween(Number<num_access - 1>{}, Number<0>{});
return reset_data_step; return reset_step;
}
} }
// src_slice_origin_step_idx need to be known at compile-time, for performance reason // src_slice_origin_step_idx need to be known at compile-time, for performance reason
......
...@@ -57,7 +57,7 @@ template <typename InDataType, ...@@ -57,7 +57,7 @@ template <typename InDataType,
typename AccDataType, typename AccDataType,
typename OutDataType, typename OutDataType,
int Rank, int Rank,
typename ReduceDims, int NumReduceDim,
ReduceTensorOp_t ReduceOpId, ReduceTensorOp_t ReduceOpId,
NanPropagation_t NanOpt, NanPropagation_t NanOpt,
ReduceTensorIndices_t IndicesOpt> ReduceTensorIndices_t IndicesOpt>
...@@ -91,7 +91,7 @@ void add_device_reduce_instance_blockwise( ...@@ -91,7 +91,7 @@ void add_device_reduce_instance_blockwise(
AccDataType, AccDataType,
OutDataType, OutDataType,
Rank, Rank,
ReduceDims, NumReduceDim,
ReduceOperation, ReduceOperation,
InElementwiseOperation, InElementwiseOperation,
AccElementwiseOperation, AccElementwiseOperation,
...@@ -112,34 +112,36 @@ void add_device_reduce_instance_blockwise( ...@@ -112,34 +112,36 @@ void add_device_reduce_instance_blockwise(
}); });
}; };
#define ADD_BLOCKWISE_INST_BY_TYPE(inT, compT, outT, ReduceOpId, NanOpt, IndicesOpt, Rank, ...) \ #define ADD_BLOCKWISE_INST_BY_TYPE( \
template void add_device_reduce_instance_blockwise<inT, \ inT, compT, outT, ReduceOpId, NanOpt, IndicesOpt, Rank, NumReduceDim) \
compT, \ template void add_device_reduce_instance_blockwise<inT, \
outT, \ compT, \
Rank, \ outT, \
Sequence<__VA_ARGS__>, \ Rank, \
ReduceOpId, \ NumReduceDim, \
NanOpt, \ ReduceOpId, \
IndicesOpt>( \ NanOpt, \
IndicesOpt>( \
std::vector<deviceReduceBlockWisePtrType<compT, ReduceOpId>> & device_op_instances) std::vector<deviceReduceBlockWisePtrType<compT, ReduceOpId>> & device_op_instances)
#define ADD_BLOCKWISE_INST_BY_ID(inT, compT, outT, ReduceOpId, NanOpt, IndicesOpt, Rank, ...) \ #define ADD_BLOCKWISE_INST_BY_ID( \
ADD_BLOCKWISE_INST_BY_TYPE(inT, \ inT, compT, outT, ReduceOpId, NanOpt, IndicesOpt, Rank, NumReduceDim) \
compT, \ ADD_BLOCKWISE_INST_BY_TYPE(inT, \
outT, \ compT, \
static_cast<ReduceTensorOp_t>(ReduceOpId), \ outT, \
static_cast<NanPropagation_t>(NanOpt), \ static_cast<ReduceTensorOp_t>(ReduceOpId), \
static_cast<ReduceTensorIndices_t>(IndicesOpt), \ static_cast<NanPropagation_t>(NanOpt), \
Rank, \ static_cast<ReduceTensorIndices_t>(IndicesOpt), \
__VA_ARGS__) Rank, \
NumReduceDim)
#define ADD_BLOCKWISE_INST_REF_BY_TYPE( \ #define ADD_BLOCKWISE_INST_REF_BY_TYPE( \
inT, compT, outT, ReduceOpId, NanOpt, IndicesOpt, Rank, ...) \ inT, compT, outT, ReduceOpId, NanOpt, IndicesOpt, Rank, NumReduceDim) \
extern template void add_device_reduce_instance_blockwise<inT, \ extern template void add_device_reduce_instance_blockwise<inT, \
compT, \ compT, \
outT, \ outT, \
Rank, \ Rank, \
Sequence<__VA_ARGS__>, \ NumReduceDim, \
ReduceOpId, \ ReduceOpId, \
NanOpt, \ NanOpt, \
IndicesOpt>( \ IndicesOpt>( \
...@@ -149,15 +151,16 @@ void add_device_reduce_instance_blockwise( ...@@ -149,15 +151,16 @@ void add_device_reduce_instance_blockwise(
AccElementwiseOperation>> & \ AccElementwiseOperation>> & \
device_op_instances) device_op_instances)
#define ADD_BLOCKWISE_INST_REF_BY_ID(inT, compT, outT, ReduceOpId, NanOpt, IndicesOpt, Rank, ...) \ #define ADD_BLOCKWISE_INST_REF_BY_ID( \
ADD_BLOCKWISE_INST_REF_BY_TYPE(inT, \ inT, compT, outT, ReduceOpId, NanOpt, IndicesOpt, Rank, NumReduceDim) \
compT, \ ADD_BLOCKWISE_INST_REF_BY_TYPE(inT, \
outT, \ compT, \
static_cast<ReduceTensorOp_t>(ReduceOpId), \ outT, \
static_cast<NanPropagation_t>(NanOpt), \ static_cast<ReduceTensorOp_t>(ReduceOpId), \
static_cast<ReduceTensorIndices_t>(IndicesOpt), \ static_cast<NanPropagation_t>(NanOpt), \
Rank, \ static_cast<ReduceTensorIndices_t>(IndicesOpt), \
__VA_ARGS__) Rank, \
NumReduceDim)
} // namespace device_reduce_instance } // namespace device_reduce_instance
} // namespace device } // namespace device
......
...@@ -11,25 +11,25 @@ namespace device { ...@@ -11,25 +11,25 @@ namespace device {
namespace device_reduce_instance { namespace device_reduce_instance {
// clang-format off // clang-format off
// InDataType | AccDataType | OutDataType | ReduceOpId | NanPropaOpt | IndicesOpt | Rank | ReduceDims // InDataType | AccDataType | OutDataType | ReduceOpId | NanPropaOpt | IndicesOpt | Rank | NumReduceDim
ADD_BLOCKWISE_INST_REF_BY_ID(half_t, half_t, half_t, 2, 0, 0, 4, 0, 1, 2); // for MIN ADD_BLOCKWISE_INST_REF_BY_ID(half_t, half_t, half_t, 2, 0, 0, 4, 3); // for MIN
ADD_BLOCKWISE_INST_REF_BY_ID(half_t, half_t, half_t, 2, 0, 0, 4, 0); // ADD_BLOCKWISE_INST_REF_BY_ID(half_t, half_t, half_t, 2, 0, 0, 4, 1);
ADD_BLOCKWISE_INST_REF_BY_ID(half_t, half_t, half_t, 2, 0, 0, 2, 1); // ADD_BLOCKWISE_INST_REF_BY_ID(half_t, half_t, half_t, 2, 0, 0, 2, 1);
ADD_BLOCKWISE_INST_REF_BY_ID(half_t, half_t, half_t, 3, 0, 0, 4, 0, 1, 2); // for MAX ADD_BLOCKWISE_INST_REF_BY_ID(half_t, half_t, half_t, 3, 0, 0, 4, 3); // for MAX
ADD_BLOCKWISE_INST_REF_BY_ID(half_t, half_t, half_t, 3, 0, 0, 4, 0); // ADD_BLOCKWISE_INST_REF_BY_ID(half_t, half_t, half_t, 3, 0, 0, 4, 1);
ADD_BLOCKWISE_INST_REF_BY_ID(half_t, half_t, half_t, 3, 0, 0, 2, 1); // ADD_BLOCKWISE_INST_REF_BY_ID(half_t, half_t, half_t, 3, 0, 0, 2, 1);
ADD_BLOCKWISE_INST_REF_BY_ID(half_t, half_t, half_t, 4, 0, 0, 4, 0, 1, 2); // for AMAX ADD_BLOCKWISE_INST_REF_BY_ID(half_t, half_t, half_t, 4, 0, 0, 4, 3); // for AMAX
ADD_BLOCKWISE_INST_REF_BY_ID(half_t, half_t, half_t, 4, 0, 0, 4, 0); // ADD_BLOCKWISE_INST_REF_BY_ID(half_t, half_t, half_t, 4, 0, 0, 4, 1);
ADD_BLOCKWISE_INST_REF_BY_ID(half_t, half_t, half_t, 4, 0, 0, 2, 1); // ADD_BLOCKWISE_INST_REF_BY_ID(half_t, half_t, half_t, 4, 0, 0, 2, 1);
ADD_BLOCKWISE_INST_REF_BY_ID(half_t, half_t, half_t, 2, 0, 1, 4, 0, 1, 2); // for MIN ADD_BLOCKWISE_INST_REF_BY_ID(half_t, half_t, half_t, 2, 0, 1, 4, 3); // for MIN
ADD_BLOCKWISE_INST_REF_BY_ID(half_t, half_t, half_t, 2, 0, 1, 4, 0); // ADD_BLOCKWISE_INST_REF_BY_ID(half_t, half_t, half_t, 2, 0, 1, 4, 1);
ADD_BLOCKWISE_INST_REF_BY_ID(half_t, half_t, half_t, 2, 0, 1, 2, 1); // ADD_BLOCKWISE_INST_REF_BY_ID(half_t, half_t, half_t, 2, 0, 1, 2, 1);
ADD_BLOCKWISE_INST_REF_BY_ID(half_t, half_t, half_t, 3, 0, 1, 4, 0, 1, 2); // for MAX ADD_BLOCKWISE_INST_REF_BY_ID(half_t, half_t, half_t, 3, 0, 1, 4, 3); // for MAX
ADD_BLOCKWISE_INST_REF_BY_ID(half_t, half_t, half_t, 3, 0, 1, 4, 0); // ADD_BLOCKWISE_INST_REF_BY_ID(half_t, half_t, half_t, 3, 0, 1, 4, 1);
ADD_BLOCKWISE_INST_REF_BY_ID(half_t, half_t, half_t, 3, 0, 1, 2, 1); // ADD_BLOCKWISE_INST_REF_BY_ID(half_t, half_t, half_t, 3, 0, 1, 2, 1);
ADD_BLOCKWISE_INST_REF_BY_ID(half_t, half_t, half_t, 4, 0, 1, 4, 0, 1, 2); // for AMAX ADD_BLOCKWISE_INST_REF_BY_ID(half_t, half_t, half_t, 4, 0, 1, 4, 3); // for AMAX
ADD_BLOCKWISE_INST_REF_BY_ID(half_t, half_t, half_t, 4, 0, 1, 4, 0); // ADD_BLOCKWISE_INST_REF_BY_ID(half_t, half_t, half_t, 4, 0, 1, 4, 1);
ADD_BLOCKWISE_INST_REF_BY_ID(half_t, half_t, half_t, 4, 0, 1, 2, 1); // ADD_BLOCKWISE_INST_REF_BY_ID(half_t, half_t, half_t, 4, 0, 1, 2, 1);
// clang-format on // clang-format on
} // namespace device_reduce_instance } // namespace device_reduce_instance
......
...@@ -11,16 +11,16 @@ namespace device { ...@@ -11,16 +11,16 @@ namespace device {
namespace device_reduce_instance { namespace device_reduce_instance {
// clang-format off // clang-format off
// InDataType | AccDataType | OutDataType | ReduceOpId | NanPropaOpt | IndicesOpt | Rank | ReduceDims // InDataType | AccDataType | OutDataType | ReduceOpId | NanPropaOpt | IndicesOpt | Rank | NumReduceDim
ADD_BLOCKWISE_INST_REF_BY_ID(half_t, float, half_t, 0, 0, 0, 4, 0, 1, 2); // for ADD ADD_BLOCKWISE_INST_REF_BY_ID(half_t, float, half_t, 0, 0, 0, 4, 3); // for ADD
ADD_BLOCKWISE_INST_REF_BY_ID(half_t, float, half_t, 0, 0, 0, 4, 0); ADD_BLOCKWISE_INST_REF_BY_ID(half_t, float, half_t, 0, 0, 0, 4, 1);
ADD_BLOCKWISE_INST_REF_BY_ID(half_t, float, half_t, 0, 0, 0, 2, 1); ADD_BLOCKWISE_INST_REF_BY_ID(half_t, float, half_t, 0, 0, 0, 2, 1);
ADD_BLOCKWISE_INST_REF_BY_ID(half_t, float, half_t, 5, 0, 0, 4, 0, 1, 2); // for AVG ADD_BLOCKWISE_INST_REF_BY_ID(half_t, float, half_t, 5, 0, 0, 4, 3); // for AVG
ADD_BLOCKWISE_INST_REF_BY_ID(half_t, float, half_t, 5, 0, 0, 4, 0); // ADD_BLOCKWISE_INST_REF_BY_ID(half_t, float, half_t, 5, 0, 0, 4, 1);
ADD_BLOCKWISE_INST_REF_BY_ID(half_t, float, half_t, 5, 0, 0, 2, 1); // ADD_BLOCKWISE_INST_REF_BY_ID(half_t, float, half_t, 5, 0, 0, 2, 1);
ADD_BLOCKWISE_INST_REF_BY_ID(half_t, float, half_t, 7, 0, 0, 4, 0, 1, 2); // for NORM2 ADD_BLOCKWISE_INST_REF_BY_ID(half_t, float, half_t, 7, 0, 0, 4, 3); // for NORM2
ADD_BLOCKWISE_INST_REF_BY_ID(half_t, float, half_t, 7, 0, 0, 4, 0); // ADD_BLOCKWISE_INST_REF_BY_ID(half_t, float, half_t, 7, 0, 0, 4, 1);
ADD_BLOCKWISE_INST_REF_BY_ID(half_t, float, half_t, 7, 0, 0, 2, 1); // ADD_BLOCKWISE_INST_REF_BY_ID(half_t, float, half_t, 7, 0, 0, 2, 1);
// clang-format on // clang-format on
} // namespace device_reduce_instance } // namespace device_reduce_instance
......
...@@ -11,34 +11,34 @@ namespace device { ...@@ -11,34 +11,34 @@ namespace device {
namespace device_reduce_instance { namespace device_reduce_instance {
// clang-format off // clang-format off
// InDataType | AccDataType | OutDataType | ReduceOpId | NanPropaOpt | IndicesOpt | Rank | ReduceDims // InDataType | AccDataType | OutDataType | ReduceOpId | NanPropaOpt | IndicesOpt | Rank | NumReduceDim
ADD_BLOCKWISE_INST_REF_BY_ID(float, float, float, 0, 0, 0, 4, 0, 1, 2); // for ADD ADD_BLOCKWISE_INST_REF_BY_ID(float, float, float, 0, 0, 0, 4, 3); // for ADD
ADD_BLOCKWISE_INST_REF_BY_ID(float, float, float, 0, 0, 0, 4, 0); ADD_BLOCKWISE_INST_REF_BY_ID(float, float, float, 0, 0, 0, 4, 1);
ADD_BLOCKWISE_INST_REF_BY_ID(float, float, float, 0, 0, 0, 2, 1); ADD_BLOCKWISE_INST_REF_BY_ID(float, float, float, 0, 0, 0, 2, 1);
ADD_BLOCKWISE_INST_REF_BY_ID(float, float, float, 5, 0, 0, 4, 0, 1, 2); // for AVG ADD_BLOCKWISE_INST_REF_BY_ID(float, float, float, 5, 0, 0, 4, 3); // for AVG
ADD_BLOCKWISE_INST_REF_BY_ID(float, float, float, 5, 0, 0, 4, 0); // ADD_BLOCKWISE_INST_REF_BY_ID(float, float, float, 5, 0, 0, 4, 1);
ADD_BLOCKWISE_INST_REF_BY_ID(float, float, float, 5, 0, 0, 2, 1); // ADD_BLOCKWISE_INST_REF_BY_ID(float, float, float, 5, 0, 0, 2, 1);
ADD_BLOCKWISE_INST_REF_BY_ID(float, float, float, 7, 0, 0, 4, 0, 1, 2); // for NORM2 ADD_BLOCKWISE_INST_REF_BY_ID(float, float, float, 7, 0, 0, 4, 3); // for NORM2
ADD_BLOCKWISE_INST_REF_BY_ID(float, float, float, 7, 0, 0, 4, 0); // ADD_BLOCKWISE_INST_REF_BY_ID(float, float, float, 7, 0, 0, 4, 1);
ADD_BLOCKWISE_INST_REF_BY_ID(float, float, float, 7, 0, 0, 2, 1); // ADD_BLOCKWISE_INST_REF_BY_ID(float, float, float, 7, 0, 0, 2, 1);
ADD_BLOCKWISE_INST_REF_BY_ID(float, float, float, 2, 0, 0, 4, 0, 1, 2); // for MIN ADD_BLOCKWISE_INST_REF_BY_ID(float, float, float, 2, 0, 0, 4, 3); // for MIN
ADD_BLOCKWISE_INST_REF_BY_ID(float, float, float, 2, 0, 0, 4, 0); // ADD_BLOCKWISE_INST_REF_BY_ID(float, float, float, 2, 0, 0, 4, 1);
ADD_BLOCKWISE_INST_REF_BY_ID(float, float, float, 2, 0, 0, 2, 1); // ADD_BLOCKWISE_INST_REF_BY_ID(float, float, float, 2, 0, 0, 2, 1);
ADD_BLOCKWISE_INST_REF_BY_ID(float, float, float, 3, 0, 0, 4, 0, 1, 2); // for MAX ADD_BLOCKWISE_INST_REF_BY_ID(float, float, float, 3, 0, 0, 4, 3); // for MAX
ADD_BLOCKWISE_INST_REF_BY_ID(float, float, float, 3, 0, 0, 4, 0); // ADD_BLOCKWISE_INST_REF_BY_ID(float, float, float, 3, 0, 0, 4, 1);
ADD_BLOCKWISE_INST_REF_BY_ID(float, float, float, 3, 0, 0, 2, 1); // ADD_BLOCKWISE_INST_REF_BY_ID(float, float, float, 3, 0, 0, 2, 1);
ADD_BLOCKWISE_INST_REF_BY_ID(float, float, float, 4, 0, 0, 4, 0, 1, 2); // for AMAX ADD_BLOCKWISE_INST_REF_BY_ID(float, float, float, 4, 0, 0, 4, 3); // for AMAX
ADD_BLOCKWISE_INST_REF_BY_ID(float, float, float, 4, 0, 0, 4, 0); // ADD_BLOCKWISE_INST_REF_BY_ID(float, float, float, 4, 0, 0, 4, 1);
ADD_BLOCKWISE_INST_REF_BY_ID(float, float, float, 4, 0, 0, 2, 1); // ADD_BLOCKWISE_INST_REF_BY_ID(float, float, float, 4, 0, 0, 2, 1);
ADD_BLOCKWISE_INST_REF_BY_ID(float, float, float, 2, 0, 1, 4, 0, 1, 2); // for MIN ADD_BLOCKWISE_INST_REF_BY_ID(float, float, float, 2, 0, 1, 4, 3); // for MIN
ADD_BLOCKWISE_INST_REF_BY_ID(float, float, float, 2, 0, 1, 4, 0); // ADD_BLOCKWISE_INST_REF_BY_ID(float, float, float, 2, 0, 1, 4, 1);
ADD_BLOCKWISE_INST_REF_BY_ID(float, float, float, 2, 0, 1, 2, 1); // ADD_BLOCKWISE_INST_REF_BY_ID(float, float, float, 2, 0, 1, 2, 1);
ADD_BLOCKWISE_INST_REF_BY_ID(float, float, float, 3, 0, 1, 4, 0, 1, 2); // for MAX ADD_BLOCKWISE_INST_REF_BY_ID(float, float, float, 3, 0, 1, 4, 3); // for MAX
ADD_BLOCKWISE_INST_REF_BY_ID(float, float, float, 3, 0, 1, 4, 0); // ADD_BLOCKWISE_INST_REF_BY_ID(float, float, float, 3, 0, 1, 4, 1);
ADD_BLOCKWISE_INST_REF_BY_ID(float, float, float, 3, 0, 1, 2, 1); // ADD_BLOCKWISE_INST_REF_BY_ID(float, float, float, 3, 0, 1, 2, 1);
ADD_BLOCKWISE_INST_REF_BY_ID(float, float, float, 4, 0, 1, 4, 0, 1, 2); // for AMAX ADD_BLOCKWISE_INST_REF_BY_ID(float, float, float, 4, 0, 1, 4, 3); // for AMAX
ADD_BLOCKWISE_INST_REF_BY_ID(float, float, float, 4, 0, 1, 4, 0); // ADD_BLOCKWISE_INST_REF_BY_ID(float, float, float, 4, 0, 1, 4, 1);
ADD_BLOCKWISE_INST_REF_BY_ID(float, float, float, 4, 0, 1, 2, 1); // ADD_BLOCKWISE_INST_REF_BY_ID(float, float, float, 4, 0, 1, 2, 1);
// clang-format on // clang-format on
} // namespace device_reduce_instance } // namespace device_reduce_instance
......
...@@ -11,16 +11,16 @@ namespace device { ...@@ -11,16 +11,16 @@ namespace device {
namespace device_reduce_instance { namespace device_reduce_instance {
// clang-format off // clang-format off
// InDataType | AccDataType | OutDataType | ReduceOpId | NanPropaOpt | IndicesOpt | Rank | ReduceDims // InDataType | AccDataType | OutDataType | ReduceOpId | NanPropaOpt | IndicesOpt | Rank | NumReduceDim
ADD_BLOCKWISE_INST_REF_BY_ID(float, double, float, 0, 0, 0, 4, 0, 1, 2); // for ADD ADD_BLOCKWISE_INST_REF_BY_ID(float, double, float, 0, 0, 0, 4, 3); // for ADD
ADD_BLOCKWISE_INST_REF_BY_ID(float, double, float, 0, 0, 0, 4, 0); ADD_BLOCKWISE_INST_REF_BY_ID(float, double, float, 0, 0, 0, 4, 1);
ADD_BLOCKWISE_INST_REF_BY_ID(float, double, float, 0, 0, 0, 2, 1); ADD_BLOCKWISE_INST_REF_BY_ID(float, double, float, 0, 0, 0, 2, 1);
ADD_BLOCKWISE_INST_REF_BY_ID(float, double, float, 5, 0, 0, 4, 0, 1, 2); // for AVG ADD_BLOCKWISE_INST_REF_BY_ID(float, double, float, 5, 0, 0, 4, 3); // for AVG
ADD_BLOCKWISE_INST_REF_BY_ID(float, double, float, 5, 0, 0, 4, 0); // ADD_BLOCKWISE_INST_REF_BY_ID(float, double, float, 5, 0, 0, 4, 1);
ADD_BLOCKWISE_INST_REF_BY_ID(float, double, float, 5, 0, 0, 2, 1); // ADD_BLOCKWISE_INST_REF_BY_ID(float, double, float, 5, 0, 0, 2, 1);
ADD_BLOCKWISE_INST_REF_BY_ID(float, double, float, 7, 0, 0, 4, 0, 1, 2); // for NORM2 ADD_BLOCKWISE_INST_REF_BY_ID(float, double, float, 7, 0, 0, 4, 3); // for NORM2
ADD_BLOCKWISE_INST_REF_BY_ID(float, double, float, 7, 0, 0, 4, 0); // ADD_BLOCKWISE_INST_REF_BY_ID(float, double, float, 7, 0, 0, 4, 1);
ADD_BLOCKWISE_INST_REF_BY_ID(float, double, float, 7, 0, 0, 2, 1); // ADD_BLOCKWISE_INST_REF_BY_ID(float, double, float, 7, 0, 0, 2, 1);
// clang-format on // clang-format on
} // namespace device_reduce_instance } // namespace device_reduce_instance
......
...@@ -11,34 +11,34 @@ namespace device { ...@@ -11,34 +11,34 @@ namespace device {
namespace device_reduce_instance { namespace device_reduce_instance {
// clang-format off // clang-format off
// InDataType | AccDataType | OutDataType | ReduceOpId | NanPropaOpt | IndicesOpt | Rank | ReduceDims // InDataType | AccDataType | OutDataType | ReduceOpId | NanPropaOpt | IndicesOpt | Rank | NumReduceDim
ADD_BLOCKWISE_INST_REF_BY_ID(double, double, double, 0, 0, 0, 4, 0, 1, 2); // for ADD ADD_BLOCKWISE_INST_REF_BY_ID(double, double, double, 0, 0, 0, 4, 3); // for ADD
ADD_BLOCKWISE_INST_REF_BY_ID(double, double, double, 0, 0, 0, 4, 0); ADD_BLOCKWISE_INST_REF_BY_ID(double, double, double, 0, 0, 0, 4, 1);
ADD_BLOCKWISE_INST_REF_BY_ID(double, double, double, 0, 0, 0, 2, 1); ADD_BLOCKWISE_INST_REF_BY_ID(double, double, double, 0, 0, 0, 2, 1);
ADD_BLOCKWISE_INST_REF_BY_ID(double, double, double, 5, 0, 0, 4, 0, 1, 2); // for AVG ADD_BLOCKWISE_INST_REF_BY_ID(double, double, double, 5, 0, 0, 4, 3); // for AVG
ADD_BLOCKWISE_INST_REF_BY_ID(double, double, double, 5, 0, 0, 4, 0); // ADD_BLOCKWISE_INST_REF_BY_ID(double, double, double, 5, 0, 0, 4, 1);
ADD_BLOCKWISE_INST_REF_BY_ID(double, double, double, 5, 0, 0, 2, 1); // ADD_BLOCKWISE_INST_REF_BY_ID(double, double, double, 5, 0, 0, 2, 1);
ADD_BLOCKWISE_INST_REF_BY_ID(double, double, double, 7, 0, 0, 4, 0, 1, 2); // for NORM2 ADD_BLOCKWISE_INST_REF_BY_ID(double, double, double, 7, 0, 0, 4, 3); // for NORM2
ADD_BLOCKWISE_INST_REF_BY_ID(double, double, double, 7, 0, 0, 4, 0); // ADD_BLOCKWISE_INST_REF_BY_ID(double, double, double, 7, 0, 0, 4, 1);
ADD_BLOCKWISE_INST_REF_BY_ID(double, double, double, 7, 0, 0, 2, 1); // ADD_BLOCKWISE_INST_REF_BY_ID(double, double, double, 7, 0, 0, 2, 1);
ADD_BLOCKWISE_INST_REF_BY_ID(double, double, double, 2, 0, 0, 4, 0, 1, 2); // for MIN ADD_BLOCKWISE_INST_REF_BY_ID(double, double, double, 2, 0, 0, 4, 3); // for MIN
ADD_BLOCKWISE_INST_REF_BY_ID(double, double, double, 2, 0, 0, 4, 0); // ADD_BLOCKWISE_INST_REF_BY_ID(double, double, double, 2, 0, 0, 4, 1);
ADD_BLOCKWISE_INST_REF_BY_ID(double, double, double, 2, 0, 0, 2, 1); // ADD_BLOCKWISE_INST_REF_BY_ID(double, double, double, 2, 0, 0, 2, 1);
ADD_BLOCKWISE_INST_REF_BY_ID(double, double, double, 3, 0, 0, 4, 0, 1, 2); // for MAX ADD_BLOCKWISE_INST_REF_BY_ID(double, double, double, 3, 0, 0, 4, 3); // for MAX
ADD_BLOCKWISE_INST_REF_BY_ID(double, double, double, 3, 0, 0, 4, 0); // ADD_BLOCKWISE_INST_REF_BY_ID(double, double, double, 3, 0, 0, 4, 1);
ADD_BLOCKWISE_INST_REF_BY_ID(double, double, double, 3, 0, 0, 2, 1); // ADD_BLOCKWISE_INST_REF_BY_ID(double, double, double, 3, 0, 0, 2, 1);
ADD_BLOCKWISE_INST_REF_BY_ID(double, double, double, 4, 0, 0, 4, 0, 1, 2); // for AMAX ADD_BLOCKWISE_INST_REF_BY_ID(double, double, double, 4, 0, 0, 4, 3); // for AMAX
ADD_BLOCKWISE_INST_REF_BY_ID(double, double, double, 4, 0, 0, 4, 0); // ADD_BLOCKWISE_INST_REF_BY_ID(double, double, double, 4, 0, 0, 4, 1);
ADD_BLOCKWISE_INST_REF_BY_ID(double, double, double, 4, 0, 0, 2, 1); // ADD_BLOCKWISE_INST_REF_BY_ID(double, double, double, 4, 0, 0, 2, 1);
ADD_BLOCKWISE_INST_REF_BY_ID(double, double, double, 2, 0, 1, 4, 0, 1, 2); // for MIN ADD_BLOCKWISE_INST_REF_BY_ID(double, double, double, 2, 0, 1, 4, 3); // for MIN
ADD_BLOCKWISE_INST_REF_BY_ID(double, double, double, 2, 0, 1, 4, 0); // ADD_BLOCKWISE_INST_REF_BY_ID(double, double, double, 2, 0, 1, 4, 1);
ADD_BLOCKWISE_INST_REF_BY_ID(double, double, double, 2, 0, 1, 2, 1); // ADD_BLOCKWISE_INST_REF_BY_ID(double, double, double, 2, 0, 1, 2, 1);
ADD_BLOCKWISE_INST_REF_BY_ID(double, double, double, 3, 0, 1, 4, 0, 1, 2); // for MAX ADD_BLOCKWISE_INST_REF_BY_ID(double, double, double, 3, 0, 1, 4, 3); // for MAX
ADD_BLOCKWISE_INST_REF_BY_ID(double, double, double, 3, 0, 1, 4, 0); // ADD_BLOCKWISE_INST_REF_BY_ID(double, double, double, 3, 0, 1, 4, 1);
ADD_BLOCKWISE_INST_REF_BY_ID(double, double, double, 3, 0, 1, 2, 1); // ADD_BLOCKWISE_INST_REF_BY_ID(double, double, double, 3, 0, 1, 2, 1);
ADD_BLOCKWISE_INST_REF_BY_ID(double, double, double, 4, 0, 1, 4, 0, 1, 2); // for AMAX ADD_BLOCKWISE_INST_REF_BY_ID(double, double, double, 4, 0, 1, 4, 3); // for AMAX
ADD_BLOCKWISE_INST_REF_BY_ID(double, double, double, 4, 0, 1, 4, 0); // ADD_BLOCKWISE_INST_REF_BY_ID(double, double, double, 4, 0, 1, 4, 1);
ADD_BLOCKWISE_INST_REF_BY_ID(double, double, double, 4, 0, 1, 2, 1); // ADD_BLOCKWISE_INST_REF_BY_ID(double, double, double, 4, 0, 1, 2, 1);
// clang-format on // clang-format on
} // namespace device_reduce_instance } // namespace device_reduce_instance
......
...@@ -45,7 +45,7 @@ template <typename InDataType, ...@@ -45,7 +45,7 @@ template <typename InDataType,
typename AccDataType, typename AccDataType,
typename OutDataType, typename OutDataType,
int Rank, int Rank,
typename ReduceDims, int NumReduceDim,
ReduceTensorOp_t ReduceOpId, ReduceTensorOp_t ReduceOpId,
NanPropagation_t NanOpt, NanPropagation_t NanOpt,
ReduceTensorIndices_t IndicesOpt> ReduceTensorIndices_t IndicesOpt>
...@@ -86,7 +86,7 @@ void add_device_reduce_instance_blockwise_second_call( ...@@ -86,7 +86,7 @@ void add_device_reduce_instance_blockwise_second_call(
AccDataType, AccDataType,
OutDataType, OutDataType,
Rank, Rank,
ReduceDims, NumReduceDim,
ReduceOperation, ReduceOperation,
InElementwiseOperation, InElementwiseOperation,
AccElementwiseOperation, AccElementwiseOperation,
...@@ -106,21 +106,21 @@ void add_device_reduce_instance_blockwise_second_call( ...@@ -106,21 +106,21 @@ void add_device_reduce_instance_blockwise_second_call(
}); });
}; };
#define ADD_BLOCKWISE_SECOND_CALL_INST_BY_TYPE( \ #define ADD_BLOCKWISE_SECOND_CALL_INST_BY_TYPE( \
inT, compT, outT, ReduceOpId, NanOpt, IndicesOpt, Rank, ...) \ inT, compT, outT, ReduceOpId, NanOpt, IndicesOpt, Rank, NumReduceDim) \
template void add_device_reduce_instance_blockwise_second_call<inT, \ template void add_device_reduce_instance_blockwise_second_call<inT, \
compT, \ compT, \
outT, \ outT, \
Rank, \ Rank, \
Sequence<__VA_ARGS__>, \ NumReduceDim, \
ReduceOpId, \ ReduceOpId, \
NanOpt, \ NanOpt, \
IndicesOpt>( \ IndicesOpt>( \
std::vector<deviceReduceBlockWiseSecondCallPtrType<compT, ReduceOpId>> & \ std::vector<deviceReduceBlockWiseSecondCallPtrType<compT, ReduceOpId>> & \
device_op_instances) device_op_instances)
#define ADD_BLOCKWISE_SECOND_CALL_INST_BY_ID( \ #define ADD_BLOCKWISE_SECOND_CALL_INST_BY_ID( \
inT, compT, outT, ReduceOpId, NanOpt, IndicesOpt, Rank, ...) \ inT, compT, outT, ReduceOpId, NanOpt, IndicesOpt, Rank, NumReduceDim) \
ADD_BLOCKWISE_SECOND_CALL_INST_BY_TYPE(inT, \ ADD_BLOCKWISE_SECOND_CALL_INST_BY_TYPE(inT, \
compT, \ compT, \
outT, \ outT, \
...@@ -128,27 +128,27 @@ void add_device_reduce_instance_blockwise_second_call( ...@@ -128,27 +128,27 @@ void add_device_reduce_instance_blockwise_second_call(
static_cast<NanPropagation_t>(NanOpt), \ static_cast<NanPropagation_t>(NanOpt), \
static_cast<ReduceTensorIndices_t>(IndicesOpt), \ static_cast<ReduceTensorIndices_t>(IndicesOpt), \
Rank, \ Rank, \
__VA_ARGS__) NumReduceDim)
#define ADD_BLOCKWISE_SECOND_CALL_INST_REF_BY_TYPE( \ #define ADD_BLOCKWISE_SECOND_CALL_INST_REF_BY_TYPE( \
inT, compT, outT, ReduceOpId, NanOpt, IndicesOpt, Rank, ...) \ inT, compT, outT, ReduceOpId, NanOpt, IndicesOpt, Rank, NumReduceDim) \
extern template void add_device_reduce_instance_blockwise_second_call<inT, \ extern template void add_device_reduce_instance_blockwise_second_call<inT, \
compT, \ compT, \
outT, \ outT, \
Rank, \ Rank, \
Sequence<__VA_ARGS__>, \ NumReduceDim, \
ReduceOpId, \ ReduceOpId, \
NanOpt, \ NanOpt, \
IndicesOpt>( \ IndicesOpt>( \
std::vector< \ std::vector< \
DeviceReducePtr<typename reduce_unary_operator<compT, ReduceOpId, false, true>:: \ DeviceReducePtr<typename reduce_unary_operator<compT, ReduceOpId, false, true>:: \
InElementwiseOperation, \ InElementwiseOperation, \
typename reduce_unary_operator<compT, ReduceOpId, false, true>:: \ typename reduce_unary_operator<compT, ReduceOpId, false, true>:: \
AccElementwiseOperation>> & \ AccElementwiseOperation>> & \
device_op_instances) device_op_instances)
#define ADD_BLOCKWISE_SECOND_CALL_INST_REF_BY_ID( \ #define ADD_BLOCKWISE_SECOND_CALL_INST_REF_BY_ID( \
inT, compT, outT, ReduceOpId, NanOpt, IndicesOpt, Rank, ...) \ inT, compT, outT, ReduceOpId, NanOpt, IndicesOpt, Rank, NumReduceDim) \
ADD_BLOCKWISE_SECOND_CALL_INST_REF_BY_TYPE(inT, \ ADD_BLOCKWISE_SECOND_CALL_INST_REF_BY_TYPE(inT, \
compT, \ compT, \
outT, \ outT, \
...@@ -156,7 +156,7 @@ void add_device_reduce_instance_blockwise_second_call( ...@@ -156,7 +156,7 @@ void add_device_reduce_instance_blockwise_second_call(
static_cast<NanPropagation_t>(NanOpt), \ static_cast<NanPropagation_t>(NanOpt), \
static_cast<ReduceTensorIndices_t>(IndicesOpt), \ static_cast<ReduceTensorIndices_t>(IndicesOpt), \
Rank, \ Rank, \
__VA_ARGS__) NumReduceDim)
} // namespace device_reduce_instance } // namespace device_reduce_instance
} // namespace device } // namespace device
......
...@@ -11,25 +11,25 @@ namespace device { ...@@ -11,25 +11,25 @@ namespace device {
namespace device_reduce_instance { namespace device_reduce_instance {
// clang-format off // clang-format off
// InDataType | AccDataType | OutDataType | ReduceOpId | NanPropaOpt | IndicesOpt | Rank | ReduceDims // InDataType | AccDataType | OutDataType | ReduceOpId | NanPropaOpt | IndicesOpt | Rank | NumReduceDim
ADD_BLOCKWISE_SECOND_CALL_INST_REF_BY_ID(half_t, half_t, half_t, 2, 0, 0, 4, 0, 1, 2); // for MIN ADD_BLOCKWISE_SECOND_CALL_INST_REF_BY_ID(half_t, half_t, half_t, 2, 0, 0, 4, 3); // for MIN
ADD_BLOCKWISE_SECOND_CALL_INST_REF_BY_ID(half_t, half_t, half_t, 2, 0, 0, 4, 0); // ADD_BLOCKWISE_SECOND_CALL_INST_REF_BY_ID(half_t, half_t, half_t, 2, 0, 0, 4, 1);
ADD_BLOCKWISE_SECOND_CALL_INST_REF_BY_ID(half_t, half_t, half_t, 2, 0, 0, 2, 1); // ADD_BLOCKWISE_SECOND_CALL_INST_REF_BY_ID(half_t, half_t, half_t, 2, 0, 0, 2, 1);
ADD_BLOCKWISE_SECOND_CALL_INST_REF_BY_ID(half_t, half_t, half_t, 3, 0, 0, 4, 0, 1, 2); // for MAX ADD_BLOCKWISE_SECOND_CALL_INST_REF_BY_ID(half_t, half_t, half_t, 3, 0, 0, 4, 3); // for MAX
ADD_BLOCKWISE_SECOND_CALL_INST_REF_BY_ID(half_t, half_t, half_t, 3, 0, 0, 4, 0); // ADD_BLOCKWISE_SECOND_CALL_INST_REF_BY_ID(half_t, half_t, half_t, 3, 0, 0, 4, 1);
ADD_BLOCKWISE_SECOND_CALL_INST_REF_BY_ID(half_t, half_t, half_t, 3, 0, 0, 2, 1); // ADD_BLOCKWISE_SECOND_CALL_INST_REF_BY_ID(half_t, half_t, half_t, 3, 0, 0, 2, 1);
ADD_BLOCKWISE_SECOND_CALL_INST_REF_BY_ID(half_t, half_t, half_t, 4, 0, 0, 4, 0, 1, 2); // for AMAX ADD_BLOCKWISE_SECOND_CALL_INST_REF_BY_ID(half_t, half_t, half_t, 4, 0, 0, 4, 3); // for AMAX
ADD_BLOCKWISE_SECOND_CALL_INST_REF_BY_ID(half_t, half_t, half_t, 4, 0, 0, 4, 0); // ADD_BLOCKWISE_SECOND_CALL_INST_REF_BY_ID(half_t, half_t, half_t, 4, 0, 0, 4, 1);
ADD_BLOCKWISE_SECOND_CALL_INST_REF_BY_ID(half_t, half_t, half_t, 4, 0, 0, 2, 1); // ADD_BLOCKWISE_SECOND_CALL_INST_REF_BY_ID(half_t, half_t, half_t, 4, 0, 0, 2, 1);
ADD_BLOCKWISE_SECOND_CALL_INST_REF_BY_ID(half_t, half_t, half_t, 2, 0, 1, 4, 0, 1, 2); // for MIN ADD_BLOCKWISE_SECOND_CALL_INST_REF_BY_ID(half_t, half_t, half_t, 2, 0, 1, 4, 3); // for MIN
ADD_BLOCKWISE_SECOND_CALL_INST_REF_BY_ID(half_t, half_t, half_t, 2, 0, 1, 4, 0); // ADD_BLOCKWISE_SECOND_CALL_INST_REF_BY_ID(half_t, half_t, half_t, 2, 0, 1, 4, 1);
ADD_BLOCKWISE_SECOND_CALL_INST_REF_BY_ID(half_t, half_t, half_t, 2, 0, 1, 2, 1); // ADD_BLOCKWISE_SECOND_CALL_INST_REF_BY_ID(half_t, half_t, half_t, 2, 0, 1, 2, 1);
ADD_BLOCKWISE_SECOND_CALL_INST_REF_BY_ID(half_t, half_t, half_t, 3, 0, 1, 4, 0, 1, 2); // for MAX ADD_BLOCKWISE_SECOND_CALL_INST_REF_BY_ID(half_t, half_t, half_t, 3, 0, 1, 4, 3); // for MAX
ADD_BLOCKWISE_SECOND_CALL_INST_REF_BY_ID(half_t, half_t, half_t, 3, 0, 1, 4, 0); // ADD_BLOCKWISE_SECOND_CALL_INST_REF_BY_ID(half_t, half_t, half_t, 3, 0, 1, 4, 1);
ADD_BLOCKWISE_SECOND_CALL_INST_REF_BY_ID(half_t, half_t, half_t, 3, 0, 1, 2, 1); // ADD_BLOCKWISE_SECOND_CALL_INST_REF_BY_ID(half_t, half_t, half_t, 3, 0, 1, 2, 1);
ADD_BLOCKWISE_SECOND_CALL_INST_REF_BY_ID(half_t, half_t, half_t, 4, 0, 1, 4, 0, 1, 2); // for AMAX ADD_BLOCKWISE_SECOND_CALL_INST_REF_BY_ID(half_t, half_t, half_t, 4, 0, 1, 4, 3); // for AMAX
ADD_BLOCKWISE_SECOND_CALL_INST_REF_BY_ID(half_t, half_t, half_t, 4, 0, 1, 4, 0); // ADD_BLOCKWISE_SECOND_CALL_INST_REF_BY_ID(half_t, half_t, half_t, 4, 0, 1, 4, 1);
ADD_BLOCKWISE_SECOND_CALL_INST_REF_BY_ID(half_t, half_t, half_t, 4, 0, 1, 2, 1); // ADD_BLOCKWISE_SECOND_CALL_INST_REF_BY_ID(half_t, half_t, half_t, 4, 0, 1, 2, 1);
// clang-format on // clang-format on
} // namespace device_reduce_instance } // namespace device_reduce_instance
......
...@@ -11,16 +11,16 @@ namespace device { ...@@ -11,16 +11,16 @@ namespace device {
namespace device_reduce_instance { namespace device_reduce_instance {
// clang-format off // clang-format off
// InDataType | AccDataType | OutDataType | ReduceOpId | NanPropaOpt | IndicesOpt | Rank | ReduceDims // InDataType | AccDataType | OutDataType | ReduceOpId | NanPropaOpt | IndicesOpt | Rank | NumReduceDim
ADD_BLOCKWISE_SECOND_CALL_INST_REF_BY_ID(float, float, half_t, 0, 0, 0, 4, 0, 1, 2); // for ADD ADD_BLOCKWISE_SECOND_CALL_INST_REF_BY_ID(float, float, half_t, 0, 0, 0, 4, 3); // for ADD
ADD_BLOCKWISE_SECOND_CALL_INST_REF_BY_ID(float, float, half_t, 0, 0, 0, 4, 0); ADD_BLOCKWISE_SECOND_CALL_INST_REF_BY_ID(float, float, half_t, 0, 0, 0, 4, 1);
ADD_BLOCKWISE_SECOND_CALL_INST_REF_BY_ID(float, float, half_t, 0, 0, 0, 2, 1); ADD_BLOCKWISE_SECOND_CALL_INST_REF_BY_ID(float, float, half_t, 0, 0, 0, 2, 1);
ADD_BLOCKWISE_SECOND_CALL_INST_REF_BY_ID(float, float, half_t, 5, 0, 0, 4, 0, 1, 2); // for AVG ADD_BLOCKWISE_SECOND_CALL_INST_REF_BY_ID(float, float, half_t, 5, 0, 0, 4, 3); // for AVG
ADD_BLOCKWISE_SECOND_CALL_INST_REF_BY_ID(float, float, half_t, 5, 0, 0, 4, 0); // ADD_BLOCKWISE_SECOND_CALL_INST_REF_BY_ID(float, float, half_t, 5, 0, 0, 4, 1);
ADD_BLOCKWISE_SECOND_CALL_INST_REF_BY_ID(float, float, half_t, 5, 0, 0, 2, 1); // ADD_BLOCKWISE_SECOND_CALL_INST_REF_BY_ID(float, float, half_t, 5, 0, 0, 2, 1);
ADD_BLOCKWISE_SECOND_CALL_INST_REF_BY_ID(float, float, half_t, 7, 0, 0, 4, 0, 1, 2); // for NORM2 ADD_BLOCKWISE_SECOND_CALL_INST_REF_BY_ID(float, float, half_t, 7, 0, 0, 4, 3); // for NORM2
ADD_BLOCKWISE_SECOND_CALL_INST_REF_BY_ID(float, float, half_t, 7, 0, 0, 4, 0); // ADD_BLOCKWISE_SECOND_CALL_INST_REF_BY_ID(float, float, half_t, 7, 0, 0, 4, 1);
ADD_BLOCKWISE_SECOND_CALL_INST_REF_BY_ID(float, float, half_t, 7, 0, 0, 2, 1); // ADD_BLOCKWISE_SECOND_CALL_INST_REF_BY_ID(float, float, half_t, 7, 0, 0, 2, 1);
// clang-format on // clang-format on
} // namespace device_reduce_instance } // namespace device_reduce_instance
......
...@@ -11,34 +11,34 @@ namespace device { ...@@ -11,34 +11,34 @@ namespace device {
namespace device_reduce_instance { namespace device_reduce_instance {
// clang-format off // clang-format off
// InDataType | AccDataType | OutDataType | ReduceOpId | NanPropaOpt | IndicesOpt | Rank | ReduceDims // InDataType | AccDataType | OutDataType | ReduceOpId | NanPropaOpt | IndicesOpt | Rank | NumReduceDim
ADD_BLOCKWISE_SECOND_CALL_INST_REF_BY_ID(float, float, float, 0, 0, 0, 4, 0, 1, 2); // for ADD ADD_BLOCKWISE_SECOND_CALL_INST_REF_BY_ID(float, float, float, 0, 0, 0, 4, 3); // for ADD
ADD_BLOCKWISE_SECOND_CALL_INST_REF_BY_ID(float, float, float, 0, 0, 0, 4, 0); ADD_BLOCKWISE_SECOND_CALL_INST_REF_BY_ID(float, float, float, 0, 0, 0, 4, 1);
ADD_BLOCKWISE_SECOND_CALL_INST_REF_BY_ID(float, float, float, 0, 0, 0, 2, 1); ADD_BLOCKWISE_SECOND_CALL_INST_REF_BY_ID(float, float, float, 0, 0, 0, 2, 1);
ADD_BLOCKWISE_SECOND_CALL_INST_REF_BY_ID(float, float, float, 5, 0, 0, 4, 0, 1, 2); // for AVG ADD_BLOCKWISE_SECOND_CALL_INST_REF_BY_ID(float, float, float, 5, 0, 0, 4, 3); // for AVG
ADD_BLOCKWISE_SECOND_CALL_INST_REF_BY_ID(float, float, float, 5, 0, 0, 4, 0); // ADD_BLOCKWISE_SECOND_CALL_INST_REF_BY_ID(float, float, float, 5, 0, 0, 4, 1);
ADD_BLOCKWISE_SECOND_CALL_INST_REF_BY_ID(float, float, float, 5, 0, 0, 2, 1); // ADD_BLOCKWISE_SECOND_CALL_INST_REF_BY_ID(float, float, float, 5, 0, 0, 2, 1);
ADD_BLOCKWISE_SECOND_CALL_INST_REF_BY_ID(float, float, float, 7, 0, 0, 4, 0, 1, 2); // for NORM2 ADD_BLOCKWISE_SECOND_CALL_INST_REF_BY_ID(float, float, float, 7, 0, 0, 4, 3); // for NORM2
ADD_BLOCKWISE_SECOND_CALL_INST_REF_BY_ID(float, float, float, 7, 0, 0, 4, 0); // ADD_BLOCKWISE_SECOND_CALL_INST_REF_BY_ID(float, float, float, 7, 0, 0, 4, 1);
ADD_BLOCKWISE_SECOND_CALL_INST_REF_BY_ID(float, float, float, 7, 0, 0, 2, 1); // ADD_BLOCKWISE_SECOND_CALL_INST_REF_BY_ID(float, float, float, 7, 0, 0, 2, 1);
ADD_BLOCKWISE_SECOND_CALL_INST_REF_BY_ID(float, float, float, 2, 0, 0, 4, 0, 1, 2); // for MIN ADD_BLOCKWISE_SECOND_CALL_INST_REF_BY_ID(float, float, float, 2, 0, 0, 4, 3); // for MIN
ADD_BLOCKWISE_SECOND_CALL_INST_REF_BY_ID(float, float, float, 2, 0, 0, 4, 0); // ADD_BLOCKWISE_SECOND_CALL_INST_REF_BY_ID(float, float, float, 2, 0, 0, 4, 1);
ADD_BLOCKWISE_SECOND_CALL_INST_REF_BY_ID(float, float, float, 2, 0, 0, 2, 1); // ADD_BLOCKWISE_SECOND_CALL_INST_REF_BY_ID(float, float, float, 2, 0, 0, 2, 1);
ADD_BLOCKWISE_SECOND_CALL_INST_REF_BY_ID(float, float, float, 3, 0, 0, 4, 0, 1, 2); // for MAX ADD_BLOCKWISE_SECOND_CALL_INST_REF_BY_ID(float, float, float, 3, 0, 0, 4, 3); // for MAX
ADD_BLOCKWISE_SECOND_CALL_INST_REF_BY_ID(float, float, float, 3, 0, 0, 4, 0); // ADD_BLOCKWISE_SECOND_CALL_INST_REF_BY_ID(float, float, float, 3, 0, 0, 4, 1);
ADD_BLOCKWISE_SECOND_CALL_INST_REF_BY_ID(float, float, float, 3, 0, 0, 2, 1); // ADD_BLOCKWISE_SECOND_CALL_INST_REF_BY_ID(float, float, float, 3, 0, 0, 2, 1);
ADD_BLOCKWISE_SECOND_CALL_INST_REF_BY_ID(float, float, float, 4, 0, 0, 4, 0, 1, 2); // for AMAX ADD_BLOCKWISE_SECOND_CALL_INST_REF_BY_ID(float, float, float, 4, 0, 0, 4, 3); // for AMAX
ADD_BLOCKWISE_SECOND_CALL_INST_REF_BY_ID(float, float, float, 4, 0, 0, 4, 0); // ADD_BLOCKWISE_SECOND_CALL_INST_REF_BY_ID(float, float, float, 4, 0, 0, 4, 1);
ADD_BLOCKWISE_SECOND_CALL_INST_REF_BY_ID(float, float, float, 4, 0, 0, 2, 1); // ADD_BLOCKWISE_SECOND_CALL_INST_REF_BY_ID(float, float, float, 4, 0, 0, 2, 1);
ADD_BLOCKWISE_SECOND_CALL_INST_REF_BY_ID(float, float, float, 2, 0, 1, 4, 0, 1, 2); // for MIN ADD_BLOCKWISE_SECOND_CALL_INST_REF_BY_ID(float, float, float, 2, 0, 1, 4, 3); // for MIN
ADD_BLOCKWISE_SECOND_CALL_INST_REF_BY_ID(float, float, float, 2, 0, 1, 4, 0); // ADD_BLOCKWISE_SECOND_CALL_INST_REF_BY_ID(float, float, float, 2, 0, 1, 4, 1);
ADD_BLOCKWISE_SECOND_CALL_INST_REF_BY_ID(float, float, float, 2, 0, 1, 2, 1); // ADD_BLOCKWISE_SECOND_CALL_INST_REF_BY_ID(float, float, float, 2, 0, 1, 2, 1);
ADD_BLOCKWISE_SECOND_CALL_INST_REF_BY_ID(float, float, float, 3, 0, 1, 4, 0, 1, 2); // for MAX ADD_BLOCKWISE_SECOND_CALL_INST_REF_BY_ID(float, float, float, 3, 0, 1, 4, 3); // for MAX
ADD_BLOCKWISE_SECOND_CALL_INST_REF_BY_ID(float, float, float, 3, 0, 1, 4, 0); // ADD_BLOCKWISE_SECOND_CALL_INST_REF_BY_ID(float, float, float, 3, 0, 1, 4, 1);
ADD_BLOCKWISE_SECOND_CALL_INST_REF_BY_ID(float, float, float, 3, 0, 1, 2, 1); // ADD_BLOCKWISE_SECOND_CALL_INST_REF_BY_ID(float, float, float, 3, 0, 1, 2, 1);
ADD_BLOCKWISE_SECOND_CALL_INST_REF_BY_ID(float, float, float, 4, 0, 1, 4, 0, 1, 2); // for AMAX ADD_BLOCKWISE_SECOND_CALL_INST_REF_BY_ID(float, float, float, 4, 0, 1, 4, 3); // for AMAX
ADD_BLOCKWISE_SECOND_CALL_INST_REF_BY_ID(float, float, float, 4, 0, 1, 4, 0); // ADD_BLOCKWISE_SECOND_CALL_INST_REF_BY_ID(float, float, float, 4, 0, 1, 4, 1);
ADD_BLOCKWISE_SECOND_CALL_INST_REF_BY_ID(float, float, float, 4, 0, 1, 2, 1); // ADD_BLOCKWISE_SECOND_CALL_INST_REF_BY_ID(float, float, float, 4, 0, 1, 2, 1);
// clang-format on // clang-format on
} // namespace device_reduce_instance } // namespace device_reduce_instance
......
...@@ -11,16 +11,16 @@ namespace device { ...@@ -11,16 +11,16 @@ namespace device {
namespace device_reduce_instance { namespace device_reduce_instance {
// clang-format off // clang-format off
// InDataType | AccDataType | OutDataType | ReduceOpId | NanPropaOpt | IndicesOpt | Rank | ReduceDims // InDataType | AccDataType | OutDataType | ReduceOpId | NanPropaOpt | IndicesOpt | Rank | NumReduceDim
ADD_BLOCKWISE_SECOND_CALL_INST_REF_BY_ID(double, double, float, 0, 0, 0, 4, 0, 1, 2); // for ADD ADD_BLOCKWISE_SECOND_CALL_INST_REF_BY_ID(double, double, float, 0, 0, 0, 4, 3); // for ADD
ADD_BLOCKWISE_SECOND_CALL_INST_REF_BY_ID(double, double, float, 0, 0, 0, 4, 0); ADD_BLOCKWISE_SECOND_CALL_INST_REF_BY_ID(double, double, float, 0, 0, 0, 4, 1);
ADD_BLOCKWISE_SECOND_CALL_INST_REF_BY_ID(double, double, float, 0, 0, 0, 2, 1); ADD_BLOCKWISE_SECOND_CALL_INST_REF_BY_ID(double, double, float, 0, 0, 0, 2, 1);
ADD_BLOCKWISE_SECOND_CALL_INST_REF_BY_ID(double, double, float, 5, 0, 0, 4, 0, 1, 2); // for AVG ADD_BLOCKWISE_SECOND_CALL_INST_REF_BY_ID(double, double, float, 5, 0, 0, 4, 3); // for AVG
ADD_BLOCKWISE_SECOND_CALL_INST_REF_BY_ID(double, double, float, 5, 0, 0, 4, 0); // ADD_BLOCKWISE_SECOND_CALL_INST_REF_BY_ID(double, double, float, 5, 0, 0, 4, 1);
ADD_BLOCKWISE_SECOND_CALL_INST_REF_BY_ID(double, double, float, 5, 0, 0, 2, 1); // ADD_BLOCKWISE_SECOND_CALL_INST_REF_BY_ID(double, double, float, 5, 0, 0, 2, 1);
ADD_BLOCKWISE_SECOND_CALL_INST_REF_BY_ID(double, double, float, 7, 0, 0, 4, 0, 1, 2); // for NORM2 ADD_BLOCKWISE_SECOND_CALL_INST_REF_BY_ID(double, double, float, 7, 0, 0, 4, 3); // for NORM2
ADD_BLOCKWISE_SECOND_CALL_INST_REF_BY_ID(double, double, float, 7, 0, 0, 4, 0); // ADD_BLOCKWISE_SECOND_CALL_INST_REF_BY_ID(double, double, float, 7, 0, 0, 4, 1);
ADD_BLOCKWISE_SECOND_CALL_INST_REF_BY_ID(double, double, float, 7, 0, 0, 2, 1); // ADD_BLOCKWISE_SECOND_CALL_INST_REF_BY_ID(double, double, float, 7, 0, 0, 2, 1);
// clang-format on // clang-format on
} // namespace device_reduce_instance } // namespace device_reduce_instance
......
...@@ -11,34 +11,34 @@ namespace device { ...@@ -11,34 +11,34 @@ namespace device {
namespace device_reduce_instance { namespace device_reduce_instance {
// clang-format off // clang-format off
// InDataType | AccDataType | OutDataType | ReduceOpId | NanPropaOpt | IndicesOpt | Rank | ReduceDims // InDataType | AccDataType | OutDataType | ReduceOpId | NanPropaOpt | IndicesOpt | Rank | NumReduceDim
ADD_BLOCKWISE_SECOND_CALL_INST_REF_BY_ID(double, double, double, 0, 0, 0, 4, 0, 1, 2); // for ADD ADD_BLOCKWISE_SECOND_CALL_INST_REF_BY_ID(double, double, double, 0, 0, 0, 4, 3); // for ADD
ADD_BLOCKWISE_SECOND_CALL_INST_REF_BY_ID(double, double, double, 0, 0, 0, 4, 0); ADD_BLOCKWISE_SECOND_CALL_INST_REF_BY_ID(double, double, double, 0, 0, 0, 4, 1);
ADD_BLOCKWISE_SECOND_CALL_INST_REF_BY_ID(double, double, double, 0, 0, 0, 2, 1); ADD_BLOCKWISE_SECOND_CALL_INST_REF_BY_ID(double, double, double, 0, 0, 0, 2, 1);
ADD_BLOCKWISE_SECOND_CALL_INST_REF_BY_ID(double, double, double, 5, 0, 0, 4, 0, 1, 2); // for AVG ADD_BLOCKWISE_SECOND_CALL_INST_REF_BY_ID(double, double, double, 5, 0, 0, 4, 3); // for AVG
ADD_BLOCKWISE_SECOND_CALL_INST_REF_BY_ID(double, double, double, 5, 0, 0, 4, 0); // ADD_BLOCKWISE_SECOND_CALL_INST_REF_BY_ID(double, double, double, 5, 0, 0, 4, 1);
ADD_BLOCKWISE_SECOND_CALL_INST_REF_BY_ID(double, double, double, 5, 0, 0, 2, 1); // ADD_BLOCKWISE_SECOND_CALL_INST_REF_BY_ID(double, double, double, 5, 0, 0, 2, 1);
ADD_BLOCKWISE_SECOND_CALL_INST_REF_BY_ID(double, double, double, 7, 0, 0, 4, 0, 1, 2); // for NORM2 ADD_BLOCKWISE_SECOND_CALL_INST_REF_BY_ID(double, double, double, 7, 0, 0, 4, 3); // for NORM2
ADD_BLOCKWISE_SECOND_CALL_INST_REF_BY_ID(double, double, double, 7, 0, 0, 4, 0); // ADD_BLOCKWISE_SECOND_CALL_INST_REF_BY_ID(double, double, double, 7, 0, 0, 4, 1);
ADD_BLOCKWISE_SECOND_CALL_INST_REF_BY_ID(double, double, double, 7, 0, 0, 2, 1); // ADD_BLOCKWISE_SECOND_CALL_INST_REF_BY_ID(double, double, double, 7, 0, 0, 2, 1);
ADD_BLOCKWISE_SECOND_CALL_INST_REF_BY_ID(double, double, double, 2, 0, 0, 4, 0, 1, 2); // for MIN ADD_BLOCKWISE_SECOND_CALL_INST_REF_BY_ID(double, double, double, 2, 0, 0, 4, 3); // for MIN
ADD_BLOCKWISE_SECOND_CALL_INST_REF_BY_ID(double, double, double, 2, 0, 0, 4, 0); // ADD_BLOCKWISE_SECOND_CALL_INST_REF_BY_ID(double, double, double, 2, 0, 0, 4, 1);
ADD_BLOCKWISE_SECOND_CALL_INST_REF_BY_ID(double, double, double, 2, 0, 0, 2, 1); // ADD_BLOCKWISE_SECOND_CALL_INST_REF_BY_ID(double, double, double, 2, 0, 0, 2, 1);
ADD_BLOCKWISE_SECOND_CALL_INST_REF_BY_ID(double, double, double, 3, 0, 0, 4, 0, 1, 2); // for MAX ADD_BLOCKWISE_SECOND_CALL_INST_REF_BY_ID(double, double, double, 3, 0, 0, 4, 3); // for MAX
ADD_BLOCKWISE_SECOND_CALL_INST_REF_BY_ID(double, double, double, 3, 0, 0, 4, 0); // ADD_BLOCKWISE_SECOND_CALL_INST_REF_BY_ID(double, double, double, 3, 0, 0, 4, 1);
ADD_BLOCKWISE_SECOND_CALL_INST_REF_BY_ID(double, double, double, 3, 0, 0, 2, 1); // ADD_BLOCKWISE_SECOND_CALL_INST_REF_BY_ID(double, double, double, 3, 0, 0, 2, 1);
ADD_BLOCKWISE_SECOND_CALL_INST_REF_BY_ID(double, double, double, 4, 0, 0, 4, 0, 1, 2); // for AMAX ADD_BLOCKWISE_SECOND_CALL_INST_REF_BY_ID(double, double, double, 4, 0, 0, 4, 3); // for AMAX
ADD_BLOCKWISE_SECOND_CALL_INST_REF_BY_ID(double, double, double, 4, 0, 0, 4, 0); // ADD_BLOCKWISE_SECOND_CALL_INST_REF_BY_ID(double, double, double, 4, 0, 0, 4, 1);
ADD_BLOCKWISE_SECOND_CALL_INST_REF_BY_ID(double, double, double, 4, 0, 0, 2, 1); // ADD_BLOCKWISE_SECOND_CALL_INST_REF_BY_ID(double, double, double, 4, 0, 0, 2, 1);
ADD_BLOCKWISE_SECOND_CALL_INST_REF_BY_ID(double, double, double, 2, 0, 1, 4, 0, 1, 2); // for MIN ADD_BLOCKWISE_SECOND_CALL_INST_REF_BY_ID(double, double, double, 2, 0, 1, 4, 3); // for MIN
ADD_BLOCKWISE_SECOND_CALL_INST_REF_BY_ID(double, double, double, 2, 0, 1, 4, 0); // ADD_BLOCKWISE_SECOND_CALL_INST_REF_BY_ID(double, double, double, 2, 0, 1, 4, 1);
ADD_BLOCKWISE_SECOND_CALL_INST_REF_BY_ID(double, double, double, 2, 0, 1, 2, 1); // ADD_BLOCKWISE_SECOND_CALL_INST_REF_BY_ID(double, double, double, 2, 0, 1, 2, 1);
ADD_BLOCKWISE_SECOND_CALL_INST_REF_BY_ID(double, double, double, 3, 0, 1, 4, 0, 1, 2); // for MAX ADD_BLOCKWISE_SECOND_CALL_INST_REF_BY_ID(double, double, double, 3, 0, 1, 4, 3); // for MAX
ADD_BLOCKWISE_SECOND_CALL_INST_REF_BY_ID(double, double, double, 3, 0, 1, 4, 0); // ADD_BLOCKWISE_SECOND_CALL_INST_REF_BY_ID(double, double, double, 3, 0, 1, 4, 1);
ADD_BLOCKWISE_SECOND_CALL_INST_REF_BY_ID(double, double, double, 3, 0, 1, 2, 1); // ADD_BLOCKWISE_SECOND_CALL_INST_REF_BY_ID(double, double, double, 3, 0, 1, 2, 1);
ADD_BLOCKWISE_SECOND_CALL_INST_REF_BY_ID(double, double, double, 4, 0, 1, 4, 0, 1, 2); // for AMAX ADD_BLOCKWISE_SECOND_CALL_INST_REF_BY_ID(double, double, double, 4, 0, 1, 4, 3); // for AMAX
ADD_BLOCKWISE_SECOND_CALL_INST_REF_BY_ID(double, double, double, 4, 0, 1, 4, 0); // ADD_BLOCKWISE_SECOND_CALL_INST_REF_BY_ID(double, double, double, 4, 0, 1, 4, 1);
ADD_BLOCKWISE_SECOND_CALL_INST_REF_BY_ID(double, double, double, 4, 0, 1, 2, 1); // ADD_BLOCKWISE_SECOND_CALL_INST_REF_BY_ID(double, double, double, 4, 0, 1, 2, 1);
// clang-format on // clang-format on
} // namespace device_reduce_instance } // namespace device_reduce_instance
......
...@@ -59,7 +59,7 @@ template <typename InDataType, ...@@ -59,7 +59,7 @@ template <typename InDataType,
typename AccDataType, typename AccDataType,
typename OutDataType, typename OutDataType,
int Rank, int Rank,
typename ReduceDims, int NumReduceDim,
ReduceTensorOp_t ReduceOpId, ReduceTensorOp_t ReduceOpId,
NanPropagation_t NanOpt, NanPropagation_t NanOpt,
ReduceTensorIndices_t IndicesOpt> ReduceTensorIndices_t IndicesOpt>
...@@ -110,7 +110,7 @@ void add_device_reduce_instance_multiblock_atomic_add( ...@@ -110,7 +110,7 @@ void add_device_reduce_instance_multiblock_atomic_add(
AccDataType, AccDataType,
OutDataType, OutDataType,
Rank, Rank,
ReduceDims, NumReduceDim,
ReduceOperation, ReduceOperation,
InElementwiseOperation, InElementwiseOperation,
AccElementwiseOperation, AccElementwiseOperation,
...@@ -132,21 +132,21 @@ void add_device_reduce_instance_multiblock_atomic_add( ...@@ -132,21 +132,21 @@ void add_device_reduce_instance_multiblock_atomic_add(
} }
}; };
#define ADD_MULTIBLOCK_ATOMIC_ADD_INST_BY_TYPE( \ #define ADD_MULTIBLOCK_ATOMIC_ADD_INST_BY_TYPE( \
inT, compT, outT, ReduceOpId, NanOpt, IndicesOpt, Rank, ...) \ inT, compT, outT, ReduceOpId, NanOpt, IndicesOpt, Rank, NumReduceDim) \
template void add_device_reduce_instance_multiblock_atomic_add<inT, \ template void add_device_reduce_instance_multiblock_atomic_add<inT, \
compT, \ compT, \
outT, \ outT, \
Rank, \ Rank, \
Sequence<__VA_ARGS__>, \ NumReduceDim, \
ReduceOpId, \ ReduceOpId, \
NanOpt, \ NanOpt, \
IndicesOpt>( \ IndicesOpt>( \
std::vector<deviceReduceMultiBlockAtomicAddPtrType<compT, ReduceOpId>> & \ std::vector<deviceReduceMultiBlockAtomicAddPtrType<compT, ReduceOpId>> & \
device_op_instances) device_op_instances)
#define ADD_MULTIBLOCK_ATOMIC_ADD_INST_BY_ID( \ #define ADD_MULTIBLOCK_ATOMIC_ADD_INST_BY_ID( \
inT, compT, outT, ReduceOpId, NanOpt, IndicesOpt, Rank, ...) \ inT, compT, outT, ReduceOpId, NanOpt, IndicesOpt, Rank, NumReduceDim) \
ADD_MULTIBLOCK_ATOMIC_ADD_INST_BY_TYPE(inT, \ ADD_MULTIBLOCK_ATOMIC_ADD_INST_BY_TYPE(inT, \
compT, \ compT, \
outT, \ outT, \
...@@ -154,15 +154,15 @@ void add_device_reduce_instance_multiblock_atomic_add( ...@@ -154,15 +154,15 @@ void add_device_reduce_instance_multiblock_atomic_add(
static_cast<NanPropagation_t>(NanOpt), \ static_cast<NanPropagation_t>(NanOpt), \
static_cast<ReduceTensorIndices_t>(IndicesOpt), \ static_cast<ReduceTensorIndices_t>(IndicesOpt), \
Rank, \ Rank, \
__VA_ARGS__) NumReduceDim)
#define ADD_MULTIBLOCK_ATOMIC_ADD_INST_REF_BY_TYPE( \ #define ADD_MULTIBLOCK_ATOMIC_ADD_INST_REF_BY_TYPE( \
inT, compT, outT, ReduceOpId, NanOpt, IndicesOpt, Rank, ...) \ inT, compT, outT, ReduceOpId, NanOpt, IndicesOpt, Rank, NumReduceDim) \
extern template void add_device_reduce_instance_multiblock_atomic_add<inT, \ extern template void add_device_reduce_instance_multiblock_atomic_add<inT, \
compT, \ compT, \
outT, \ outT, \
Rank, \ Rank, \
Sequence<__VA_ARGS__>, \ NumReduceDim, \
ReduceOpId, \ ReduceOpId, \
NanOpt, \ NanOpt, \
IndicesOpt>( \ IndicesOpt>( \
...@@ -173,7 +173,7 @@ void add_device_reduce_instance_multiblock_atomic_add( ...@@ -173,7 +173,7 @@ void add_device_reduce_instance_multiblock_atomic_add(
device_op_instances) device_op_instances)
#define ADD_MULTIBLOCK_ATOMIC_ADD_INST_REF_BY_ID( \ #define ADD_MULTIBLOCK_ATOMIC_ADD_INST_REF_BY_ID( \
inT, compT, outT, ReduceOpId, NanOpt, IndicesOpt, Rank, ...) \ inT, compT, outT, ReduceOpId, NanOpt, IndicesOpt, Rank, NumReduceDim) \
ADD_MULTIBLOCK_ATOMIC_ADD_INST_REF_BY_TYPE(inT, \ ADD_MULTIBLOCK_ATOMIC_ADD_INST_REF_BY_TYPE(inT, \
compT, \ compT, \
outT, \ outT, \
...@@ -181,7 +181,7 @@ void add_device_reduce_instance_multiblock_atomic_add( ...@@ -181,7 +181,7 @@ void add_device_reduce_instance_multiblock_atomic_add(
static_cast<NanPropagation_t>(NanOpt), \ static_cast<NanPropagation_t>(NanOpt), \
static_cast<ReduceTensorIndices_t>(IndicesOpt), \ static_cast<ReduceTensorIndices_t>(IndicesOpt), \
Rank, \ Rank, \
__VA_ARGS__) NumReduceDim)
} // namespace device_reduce_instance } // namespace device_reduce_instance
} // namespace device } // namespace device
......
...@@ -11,13 +11,13 @@ namespace device { ...@@ -11,13 +11,13 @@ namespace device {
namespace device_reduce_instance { namespace device_reduce_instance {
// clang-format off // clang-format off
// InDataType | AccDataType | OutDataType | ReduceOpId | NanPropaOpt | IndicesOpt | Rank | ReduceDims // InDataType | AccDataType | OutDataType | ReduceOpId | NanPropaOpt | IndicesOpt | Rank | NumReduceDim
ADD_MULTIBLOCK_ATOMIC_ADD_INST_REF_BY_ID(half_t, float, float, 0, 0, 0, 4, 0, 1, 2); // for ADD ADD_MULTIBLOCK_ATOMIC_ADD_INST_REF_BY_ID(half_t, float, float, 0, 0, 0, 4, 3); // for ADD
ADD_MULTIBLOCK_ATOMIC_ADD_INST_REF_BY_ID(half_t, float, float, 0, 0, 0, 4, 0); ADD_MULTIBLOCK_ATOMIC_ADD_INST_REF_BY_ID(half_t, float, float, 0, 0, 0, 4, 1);
ADD_MULTIBLOCK_ATOMIC_ADD_INST_REF_BY_ID(half_t, float, float, 0, 0, 0, 2, 1); ADD_MULTIBLOCK_ATOMIC_ADD_INST_REF_BY_ID(half_t, float, float, 0, 0, 0, 2, 1);
ADD_MULTIBLOCK_ATOMIC_ADD_INST_REF_BY_ID(half_t, float, float, 5, 0, 0, 4, 0, 1, 2); // for AVG ADD_MULTIBLOCK_ATOMIC_ADD_INST_REF_BY_ID(half_t, float, float, 5, 0, 0, 4, 3); // for AVG
ADD_MULTIBLOCK_ATOMIC_ADD_INST_REF_BY_ID(half_t, float, float, 5, 0, 0, 4, 0); // ADD_MULTIBLOCK_ATOMIC_ADD_INST_REF_BY_ID(half_t, float, float, 5, 0, 0, 4, 1);
ADD_MULTIBLOCK_ATOMIC_ADD_INST_REF_BY_ID(half_t, float, float, 5, 0, 0, 2, 1); // ADD_MULTIBLOCK_ATOMIC_ADD_INST_REF_BY_ID(half_t, float, float, 5, 0, 0, 2, 1);
// clang-format on // clang-format on
} // namespace device_reduce_instance } // namespace device_reduce_instance
......
...@@ -11,13 +11,13 @@ namespace device { ...@@ -11,13 +11,13 @@ namespace device {
namespace device_reduce_instance { namespace device_reduce_instance {
// clang-format off // clang-format off
// InDataType | AccDataType | OutDataType | ReduceOpId | NanPropaOpt | IndicesOpt | Rank | ReduceDims // InDataType | AccDataType | OutDataType | ReduceOpId | NanPropaOpt | IndicesOpt | Rank | NumReduceDim
ADD_MULTIBLOCK_ATOMIC_ADD_INST_REF_BY_ID(float, float, float, 0, 0, 0, 4, 0, 1, 2); // for ADD ADD_MULTIBLOCK_ATOMIC_ADD_INST_REF_BY_ID(float, float, float, 0, 0, 0, 4, 3); // for ADD
ADD_MULTIBLOCK_ATOMIC_ADD_INST_REF_BY_ID(float, float, float, 0, 0, 0, 4, 0); ADD_MULTIBLOCK_ATOMIC_ADD_INST_REF_BY_ID(float, float, float, 0, 0, 0, 4, 1);
ADD_MULTIBLOCK_ATOMIC_ADD_INST_REF_BY_ID(float, float, float, 0, 0, 0, 2, 1); ADD_MULTIBLOCK_ATOMIC_ADD_INST_REF_BY_ID(float, float, float, 0, 0, 0, 2, 1);
ADD_MULTIBLOCK_ATOMIC_ADD_INST_REF_BY_ID(float, float, float, 5, 0, 0, 4, 0, 1, 2); // for AVG ADD_MULTIBLOCK_ATOMIC_ADD_INST_REF_BY_ID(float, float, float, 5, 0, 0, 4, 3); // for AVG
ADD_MULTIBLOCK_ATOMIC_ADD_INST_REF_BY_ID(float, float, float, 5, 0, 0, 4, 0); // ADD_MULTIBLOCK_ATOMIC_ADD_INST_REF_BY_ID(float, float, float, 5, 0, 0, 4, 1);
ADD_MULTIBLOCK_ATOMIC_ADD_INST_REF_BY_ID(float, float, float, 5, 0, 0, 2, 1); // ADD_MULTIBLOCK_ATOMIC_ADD_INST_REF_BY_ID(float, float, float, 5, 0, 0, 2, 1);
// clang-format on // clang-format on
} // namespace device_reduce_instance } // namespace device_reduce_instance
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment