Commit 22ee67a9 authored by root's avatar root
Browse files

Add reduce_threadwise_multi_d: a threadwise multi-D reduction device operation (DeviceReduceThreadWiseMultiD) and its DeviceReduceMultiD base interface.

parent b17ce193
...@@ -66,7 +66,7 @@ else() ...@@ -66,7 +66,7 @@ else()
-Wunreachable-code -Wunreachable-code
-Wunused -Wunused
-Wno-reserved-identifier -Wno-reserved-identifier
-Werror # -Werror
-Wno-option-ignored -Wno-option-ignored
-Wsign-compare -Wsign-compare
-Wno-extra-semi-stmt -Wno-extra-semi-stmt
......
...@@ -235,8 +235,8 @@ int main(int argc, char* argv[]) ...@@ -235,8 +235,8 @@ int main(int argc, char* argv[])
else else
{ {
// for testing half_t // for testing half_t
pass = pass = pass &&
pass && reduce_threadwise_test<ck::half_t, float, ReduceOpId, PropagateNan, OutputIndex>( reduce_threadwise_test<ck::half_t, float, ReduceOpId, PropagateNan, OutputIndex>(
true, 2, true, {16, 64, 32, 960}, {0}, 1.0f, 0.0f); true, 2, true, {16, 64, 32, 960}, {0}, 1.0f, 0.0f);
// for testing float // for testing float
......
...@@ -89,27 +89,25 @@ int reduce_threadwise_impl(bool do_verification, ...@@ -89,27 +89,25 @@ int reduce_threadwise_impl(bool do_verification,
return (-1); return (-1);
}; };
using PassThrough = tensor_operation::element_wise::PassThrough;
// using Add = tensor_operation::element_wise::Add;
using ReduceOperation = typename reduce_binary_operator<ReduceOpId>::opType; using ReduceOperation = typename reduce_binary_operator<ReduceOpId>::opType;
using InElementwiseOperation = using InElementwiseOperation = PassThrough;
typename reduce_unary_operator<ReduceOpId, true, true>::InElementwiseOperation; using OutElementwiseOperation = PassThrough;
using AccElementwiseOperation =
typename reduce_unary_operator<ReduceOpId, true, true>::AccElementwiseOperation;
using InOutDataTypeInDevice = InOutDataType; using InOutDataTypeInDevice = InOutDataType;
using DeviceReduceInstance = using DeviceReduceInstance =
ck::tensor_operation::device::DeviceReduceThreadWiseMultiD<InOutDataTypeInDevice, ck::tensor_operation::device::DeviceReduceThreadWiseMultiD<InOutDataTypeInDevice,
ck::Tuple<>,
AccDataType, AccDataType,
InOutDataTypeInDevice, InOutDataTypeInDevice,
Rank, Rank,
NumReduceDim, NumReduceDim,
ReduceOperation, ReduceOperation,
InElementwiseOperation, InElementwiseOperation,
AccElementwiseOperation, OutElementwiseOperation,
PropagateNan,
OutputIndex,
false,
false, // HaveIndexInputIfOutputIndex
256, // BlockSize 256, // BlockSize
4, // MThreadSliceSize 4, // MThreadSliceSize
1, // KThreadSliceSize 1, // KThreadSliceSize
...@@ -173,7 +171,6 @@ int reduce_threadwise_impl(bool do_verification, ...@@ -173,7 +171,6 @@ int reduce_threadwise_impl(bool do_verification,
DeviceMem in_dev(sizeof(InOutDataTypeInDevice) * in.mDesc.GetElementSpaceSize()); DeviceMem in_dev(sizeof(InOutDataTypeInDevice) * in.mDesc.GetElementSpaceSize());
DeviceMem out_dev(sizeof(InOutDataTypeInDevice) * out.mDesc.GetElementSpaceSize()); DeviceMem out_dev(sizeof(InOutDataTypeInDevice) * out.mDesc.GetElementSpaceSize());
in_dev.ToDevice(in.mData.data()); in_dev.ToDevice(in.mData.data());
if(beta != 0.0f) if(beta != 0.0f)
...@@ -187,11 +184,7 @@ int reduce_threadwise_impl(bool do_verification, ...@@ -187,11 +184,7 @@ int reduce_threadwise_impl(bool do_verification,
DeviceMem out_index_dev(indicesSizeInBytes); DeviceMem out_index_dev(indicesSizeInBytes);
InElementwiseOperation in_elementwise_op; InElementwiseOperation in_elementwise_op;
AccElementwiseOperation acc_elementwise_op; OutElementwiseOperation out_elementwise_op;
std::tie(in_elementwise_op, acc_elementwise_op) =
reduce_unary_operator<ReduceOpId, true, true>::GetElementwiseOperator(
static_cast<int32_t>(reduce_total_length));
std::array<index_t, Rank> arrInLengths; std::array<index_t, Rank> arrInLengths;
std::array<index_t, Rank> arrInStrides; std::array<index_t, Rank> arrInStrides;
...@@ -213,7 +206,7 @@ int reduce_threadwise_impl(bool do_verification, ...@@ -213,7 +206,7 @@ int reduce_threadwise_impl(bool do_verification,
NumReduceDim, NumReduceDim,
ReduceOperation, ReduceOperation,
InElementwiseOperation, InElementwiseOperation,
AccElementwiseOperation, PassThrough,
PropagateNan, PropagateNan,
OutputIndex>; OutputIndex>;
...@@ -231,7 +224,7 @@ int reduce_threadwise_impl(bool do_verification, ...@@ -231,7 +224,7 @@ int reduce_threadwise_impl(bool do_verification,
out_ref.mData.data(), out_ref.mData.data(),
out_indices_ref.mData.data(), out_indices_ref.mData.data(),
in_elementwise_op, in_elementwise_op,
acc_elementwise_op); PassThrough{});
if(!reduce_ref.IsSupportedArgument(argument_ptr_ref.get())) if(!reduce_ref.IsSupportedArgument(argument_ptr_ref.get()))
{ {
...@@ -249,17 +242,16 @@ int reduce_threadwise_impl(bool do_verification, ...@@ -249,17 +242,16 @@ int reduce_threadwise_impl(bool do_verification,
auto argument_ptr = reduce.MakeArgumentPointer(arrInLengths, auto argument_ptr = reduce.MakeArgumentPointer(arrInLengths,
arrInStrides, arrInStrides,
{},
{},
arrOutLengths, arrOutLengths,
arrOutStrides, arrOutStrides,
reduceDims, reduceDims,
static_cast<double>(alpha),
static_cast<double>(beta),
in_dev.GetDeviceBuffer(), in_dev.GetDeviceBuffer(),
nullptr, {},
out_dev.GetDeviceBuffer(), out_dev.GetDeviceBuffer(),
out_index_dev.GetDeviceBuffer(),
in_elementwise_op, in_elementwise_op,
acc_elementwise_op); out_elementwise_op);
if(!reduce.IsSupportedArgument(argument_ptr.get())) if(!reduce.IsSupportedArgument(argument_ptr.get()))
{ {
......
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include <array>
#include <memory>
#include "ck/ck.hpp"
#include "ck/tensor_operation/gpu/device/device_base.hpp"
namespace ck {
namespace tensor_operation {
namespace device {
// Abstract base (device-op interface) for a tensor reduction that, in addition to the
// main input tensor, consumes NumDTensor auxiliary "D" tensors (one per element of
// DsDataType) whose shapes match the reduced output. Concrete implementations
// (e.g. DeviceReduceThreadWiseMultiD) derive from this and override the two factories.
//
// Template parameters:
//   InDataType / AccDataType / OutDataType - element types for input, accumulation,
//                                            and output.
//   DsDataType              - ck::Tuple of the D-tensor element types; its Size()
//                             fixes NumDTensor.
//   Rank                    - rank of the input tensor.
//   NumReduceDim            - how many of those dimensions are reduced.
//   ReduceOperation         - binary reduction op (e.g. add, max).
//   InElementwiseOperation  - applied to each input element before reduction.
//   OutElementwiseOperation - applied to each result element before the store.
template <typename InDataType,
typename DsDataType,
typename AccDataType,
typename OutDataType,
index_t Rank,
index_t NumReduceDim,
typename ReduceOperation,
typename InElementwiseOperation,
typename OutElementwiseOperation>
struct DeviceReduceMultiD : public BaseOperator
{
// Output keeps the invariant (non-reduced) dims; a full reduction still yields one dim.
static constexpr index_t NumOutDim = (Rank - NumReduceDim == 0) ? 1 : Rank - NumReduceDim;
// Number of auxiliary D tensors, taken from the DsDataType tuple.
static constexpr index_t NumDTensor = DsDataType::Size();
// Build a type-erased argument object for one launch.
//   inLengths/inStrides   - input tensor descriptor (Rank entries each).
//   DsLengths/DsStrides   - per-D-tensor descriptors; each D tensor has NumOutDim dims.
//   outLengths/outStrides - output tensor descriptor (NumOutDim entries each).
//   reduceDims            - which input dimensions to reduce.
//   in_dev/ds_dev/out_dev - device pointers for input, D tensors, and output.
//   in_elementwise_op / out_elementwise_op - element-wise ops as described above.
virtual std::unique_ptr<BaseArgument>
MakeArgumentPointer(const std::array<index_t, Rank> inLengths,
const std::array<index_t, Rank> inStrides,
const std::array<std::array<index_t, NumOutDim>, NumDTensor> DsLengths,
const std::array<std::array<index_t, NumOutDim>, NumDTensor> DsStrides,
const std::array<index_t, NumOutDim> outLengths,
const std::array<index_t, NumOutDim> outStrides,
const std::array<int, NumReduceDim> reduceDims,
const void* in_dev,
const std::array<const void*, NumDTensor> ds_dev,
void* out_dev,
const InElementwiseOperation in_elementwise_op,
const OutElementwiseOperation out_elementwise_op) = 0;
// Build the invoker that launches the kernel for an argument created above.
virtual std::unique_ptr<BaseInvoker> MakeInvokerPointer() = 0;
};
// Owning smart-pointer alias for a DeviceReduceMultiD instance; deletion through this
// pointer relies on BaseOperator's destructor (declared elsewhere) being virtual.
template <typename InDataType,
typename DsDataType,
typename AccDataType,
typename OutDataType,
index_t Rank,
index_t NumReduceDim,
typename ReduceOperation,
typename InElementwiseOperation,
typename OutElementwiseOperation>
using DeviceReduceMultiDPtr = std::unique_ptr<DeviceReduceMultiD<InDataType,
DsDataType,
AccDataType,
OutDataType,
Rank,
NumReduceDim,
ReduceOperation,
InElementwiseOperation,
OutElementwiseOperation>>;
} // namespace device
} // namespace tensor_operation
} // namespace ck
...@@ -9,7 +9,7 @@ ...@@ -9,7 +9,7 @@
#include "ck/host_utility/device_prop.hpp" #include "ck/host_utility/device_prop.hpp"
#include "ck/host_utility/kernel_launch.hpp" #include "ck/host_utility/kernel_launch.hpp"
#include "ck/tensor_operation/gpu/device/device_reduce.hpp" #include "ck/tensor_operation/gpu/device/device_reduce_multi_d.hpp"
#include "ck/tensor_operation/gpu/device/impl/device_reduce_common.hpp" #include "ck/tensor_operation/gpu/device/impl/device_reduce_common.hpp"
#include "ck/tensor_operation/gpu/grid/gridwise_2d_reduction_threadwise_multi_d.hpp" #include "ck/tensor_operation/gpu/grid/gridwise_2d_reduction_threadwise_multi_d.hpp"
...@@ -19,33 +19,29 @@ namespace tensor_operation { ...@@ -19,33 +19,29 @@ namespace tensor_operation {
namespace device { namespace device {
template <typename InDataType, template <typename InDataType,
typename DsDataType,
typename AccDataType, typename AccDataType,
typename OutDataType, typename OutDataType,
index_t Rank, index_t Rank,
index_t NumReduceDim, index_t NumReduceDim,
typename ReduceOperation, typename ReduceOperation,
typename InElementwiseOperation, typename InElementwiseOperation,
typename AccElementwiseOperation, typename OutElementwiseOperation,
bool PropagateNan,
bool OutputIndex,
bool TransformIndexKtoGlobal,
bool HaveIndexInputIfOutputIndex,
index_t BlockSize, index_t BlockSize,
index_t MThreadSliceSize, index_t MThreadSliceSize,
index_t KThreadSliceSize, index_t KThreadSliceSize,
index_t InSrcVectorDim, index_t InSrcVectorDim,
index_t InSrcVectorSize, index_t InSrcVectorSize,
index_t OutDstVectorSize> index_t OutDstVectorSize>
struct DeviceReduceThreadWiseMultiD : public DeviceReduce<InDataType, struct DeviceReduceThreadWiseMultiD : public DeviceReduceMultiD<InDataType,
DsDataType,
AccDataType, AccDataType,
OutDataType, OutDataType,
Rank, Rank,
NumReduceDim, NumReduceDim,
ReduceOperation, ReduceOperation,
InElementwiseOperation, InElementwiseOperation,
AccElementwiseOperation, OutElementwiseOperation>
PropagateNan,
OutputIndex>
{ {
static_assert(Rank <= 6, "Bigger Rank size is not supported!"); static_assert(Rank <= 6, "Bigger Rank size is not supported!");
...@@ -57,10 +53,10 @@ struct DeviceReduceThreadWiseMultiD : public DeviceReduce<InDataType, ...@@ -57,10 +53,10 @@ struct DeviceReduceThreadWiseMultiD : public DeviceReduce<InDataType,
using IndexDataType = int32_t; using IndexDataType = int32_t;
static constexpr bool HaveIndexInput = OutputIndex && HaveIndexInputIfOutputIndex;
static constexpr index_t NumInvariantDim = Rank - NumReduceDim; static constexpr index_t NumInvariantDim = Rank - NumReduceDim;
static constexpr index_t NumDTensor = DsDataType::Size();
static constexpr index_t NumSrcDim = Rank; static constexpr index_t NumSrcDim = Rank;
static constexpr index_t NumDstDim = (NumInvariantDim == 0) ? 1 : NumInvariantDim; static constexpr index_t NumDstDim = (NumInvariantDim == 0) ? 1 : NumInvariantDim;
static constexpr bool reduceAllDim = (NumInvariantDim == 0); static constexpr bool reduceAllDim = (NumInvariantDim == 0);
...@@ -159,34 +155,69 @@ struct DeviceReduceThreadWiseMultiD : public DeviceReduce<InDataType, ...@@ -159,34 +155,69 @@ struct DeviceReduceThreadWiseMultiD : public DeviceReduce<InDataType,
return (out_grid_desc_m_padded); return (out_grid_desc_m_padded);
}; };
static auto
MakeDsDescriptor(const std::array<std::array<index_t, NumDstDim>, NumDTensor> DsLengths,
std::array<std::array<index_t, NumDstDim>, NumDTensor> DsStrides)
{
return generate_tuple(
[&](auto i) {
return DeviceReduceThreadWiseMultiD::MakeDst1dDescriptor(DsLengths[i],
DsStrides[i]);
},
Number<NumDTensor>{});
}
using InGridDesc_M_K = decltype(MakeSrc2dDescriptor({}, {}));
using OutGridDesc_M = decltype(MakeDst1dDescriptor({}, {}));
using DsGridDesc_M = decltype(MakeDsDescriptor({}, {}));
using GridwiseReduce =
GridwiseReduction_mk_to_m_threadwise_multi_d<InDataType,
DsDataType,
OutDataType,
AccDataType,
InGridDesc_M_K,
DsGridDesc_M,
OutGridDesc_M,
ReduceOperation,
InElementwiseOperation,
OutElementwiseOperation,
InMemoryDataOperationEnum::Set,
BlockSize,
MThreadSliceSize,
KThreadSliceSize,
InSrcVectorDim,
InSrcVectorSize,
OutDstVectorSize>;
using DsGridPointer = typename GridwiseReduce::DsGridPointer;
struct Argument : public BaseArgument struct Argument : public BaseArgument
{ {
Argument(const std::array<index_t, Rank> inLengths, Argument(const std::array<index_t, Rank> inLengths,
const std::array<index_t, Rank> inStrides, const std::array<index_t, Rank> inStrides,
const std::array<std::array<index_t, NumDstDim>, NumDTensor> DsLengths,
const std::array<std::array<index_t, NumDstDim>, NumDTensor> DsStrides,
const std::array<index_t, NumDstDim> outLengths, const std::array<index_t, NumDstDim> outLengths,
const std::array<index_t, NumDstDim> outStrides, const std::array<index_t, NumDstDim> outStrides,
const std::array<int, NumReduceDim> reduceDims, const std::array<int, NumReduceDim> reduceDims,
double alpha,
double beta,
const InDataType* in_dev, const InDataType* in_dev,
const std::array<const void*, NumDTensor> ds_dev,
OutDataType* out_dev, OutDataType* out_dev,
IndexDataType* out_index_dev,
const InElementwiseOperation in_elementwise_op, const InElementwiseOperation in_elementwise_op,
const AccElementwiseOperation acc_elementwise_op) const OutElementwiseOperation out_elementwise_op)
: outLengths_{outLengths}, : DsLengths_{DsLengths},
DsStrides_{DsStrides},
outLengths_{outLengths},
outStrides_{outStrides}, outStrides_{outStrides},
in_dev_{in_dev}, in_dev_{in_dev},
out_dev_{out_dev}, out_dev_{out_dev},
out_index_dev_{out_index_dev},
in_elementwise_op_{in_elementwise_op}, in_elementwise_op_{in_elementwise_op},
acc_elementwise_op_{acc_elementwise_op} out_elementwise_op_{out_elementwise_op}
{ {
inLengths_ = shuffle_tensor_dimensions<Rank, NumReduceDim>(inLengths, reduceDims); inLengths_ = shuffle_tensor_dimensions<Rank, NumReduceDim>(inLengths, reduceDims);
inStrides_ = shuffle_tensor_dimensions<Rank, NumReduceDim>(inStrides, reduceDims); inStrides_ = shuffle_tensor_dimensions<Rank, NumReduceDim>(inStrides, reduceDims);
alpha_ = type_convert<AccDataType>(alpha);
beta_ = type_convert<AccDataType>(beta);
std::tie(invariant_total_length, reduce_total_length) = std::tie(invariant_total_length, reduce_total_length) =
get_2d_lengths<Rank, NumReduceDim>(inLengths_); get_2d_lengths<Rank, NumReduceDim>(inLengths_);
...@@ -201,22 +232,33 @@ struct DeviceReduceThreadWiseMultiD : public DeviceReduce<InDataType, ...@@ -201,22 +232,33 @@ struct DeviceReduceThreadWiseMultiD : public DeviceReduce<InDataType,
gridSize = math::integer_least_multiple(invariant_total_length, M_BlockTileSize) / gridSize = math::integer_least_multiple(invariant_total_length, M_BlockTileSize) /
M_BlockTileSize; M_BlockTileSize;
static_for<0, NumDTensor, 1>{}([&](auto i) {
using DDataType = remove_cvref_t<tuple_element_t<i.value, DsDataType>>;
p_ds_grid_(i) = static_cast<const DDataType*>(ds_dev[i]);
});
ds_grid_desc_m_ = MakeDsDescriptor(DsLengths, DsStrides);
} }
std::array<index_t, Rank> inLengths_; std::array<index_t, Rank> inLengths_;
std::array<index_t, Rank> inStrides_; std::array<index_t, Rank> inStrides_;
std::array<std::array<index_t, NumDstDim>, NumDTensor> DsLengths_;
std::array<std::array<index_t, NumDstDim>, NumDTensor> DsStrides_;
std::array<index_t, NumDstDim> outLengths_; std::array<index_t, NumDstDim> outLengths_;
std::array<index_t, NumDstDim> outStrides_; std::array<index_t, NumDstDim> outStrides_;
AccDataType alpha_;
AccDataType beta_;
const InDataType* in_dev_; const InDataType* in_dev_;
OutDataType* out_dev_; OutDataType* out_dev_;
IndexDataType* out_index_dev_;
DsGridPointer p_ds_grid_;
InElementwiseOperation in_elementwise_op_; InElementwiseOperation in_elementwise_op_;
AccElementwiseOperation acc_elementwise_op_; OutElementwiseOperation out_elementwise_op_;
DsGridDesc_M ds_grid_desc_m_;
index_t invariant_lowest_length; index_t invariant_lowest_length;
index_t reduce_lowest_length; index_t reduce_lowest_length;
...@@ -236,44 +278,8 @@ struct DeviceReduceThreadWiseMultiD : public DeviceReduce<InDataType, ...@@ -236,44 +278,8 @@ struct DeviceReduceThreadWiseMultiD : public DeviceReduce<InDataType,
const auto out_grid_desc_m = const auto out_grid_desc_m =
DeviceReduceThreadWiseMultiD::MakeDst1dDescriptor(arg.outLengths_, arg.outStrides_); DeviceReduceThreadWiseMultiD::MakeDst1dDescriptor(arg.outLengths_, arg.outStrides_);
const auto ds_grid_desc_m = generate_tuple(
[&](auto i) {
ignore = i;
return DeviceReduceThreadWiseMultiD::MakeDst1dDescriptor(arg.outLengths_,
arg.outStrides_);
},
Number<1>{});
using InGridDesc_M_K = decltype(in_grid_desc_m_k);
using OutGridDesc_M = decltype(out_grid_desc_m);
using DsGridDesc_M = decltype(ds_grid_desc_m);
float avg_time = 0; float avg_time = 0;
using Add = tensor_operation::element_wise::Add;
using GridwiseReduce =
GridwiseReduction_mk_to_m_threadwise_multi_d<InDataType,
Tuple<OutDataType>,
OutDataType,
AccDataType,
InGridDesc_M_K,
DsGridDesc_M,
OutGridDesc_M,
ReduceOperation,
InElementwiseOperation,
Add,
InMemoryDataOperationEnum::Set,
PropagateNan,
BlockSize,
MThreadSliceSize,
KThreadSliceSize,
InSrcVectorDim,
InSrcVectorSize,
OutDstVectorSize>;
using DsGridPointer = typename GridwiseReduce::DsGridPointer;
const auto kernel = kernel_reduce_threadwise_multi_d<GridwiseReduce, const auto kernel = kernel_reduce_threadwise_multi_d<GridwiseReduce,
InDataType, InDataType,
OutDataType, OutDataType,
...@@ -282,23 +288,21 @@ struct DeviceReduceThreadWiseMultiD : public DeviceReduce<InDataType, ...@@ -282,23 +288,21 @@ struct DeviceReduceThreadWiseMultiD : public DeviceReduce<InDataType,
DsGridDesc_M, DsGridDesc_M,
OutGridDesc_M, OutGridDesc_M,
InElementwiseOperation, InElementwiseOperation,
Add, OutElementwiseOperation,
DsGridPointer>; DsGridPointer>;
DsGridPointer p_ds_grid_;
avg_time = launch_and_time_kernel(stream_config, avg_time = launch_and_time_kernel(stream_config,
kernel, kernel,
dim3(arg.gridSize), dim3(arg.gridSize),
dim3(BlockSize), dim3(BlockSize),
0, 0,
in_grid_desc_m_k, in_grid_desc_m_k,
ds_grid_desc_m, arg.ds_grid_desc_m_,
out_grid_desc_m, out_grid_desc_m,
arg.in_elementwise_op_, arg.in_elementwise_op_,
Add{}, arg.out_elementwise_op_,
arg.in_dev_, arg.in_dev_,
p_ds_grid_, arg.p_ds_grid_,
arg.out_dev_); arg.out_dev_);
return (avg_time); return (avg_time);
...@@ -356,32 +360,29 @@ struct DeviceReduceThreadWiseMultiD : public DeviceReduce<InDataType, ...@@ -356,32 +360,29 @@ struct DeviceReduceThreadWiseMultiD : public DeviceReduce<InDataType,
std::unique_ptr<BaseArgument> std::unique_ptr<BaseArgument>
MakeArgumentPointer(const std::array<index_t, Rank> inLengths, MakeArgumentPointer(const std::array<index_t, Rank> inLengths,
const std::array<index_t, Rank> inStrides, const std::array<index_t, Rank> inStrides,
const std::array<std::array<index_t, NumDstDim>, NumDTensor> DsLengths,
const std::array<std::array<index_t, NumDstDim>, NumDTensor> DsStrides,
const std::array<index_t, NumDstDim> outLengths, const std::array<index_t, NumDstDim> outLengths,
const std::array<index_t, NumDstDim> outStrides, const std::array<index_t, NumDstDim> outStrides,
const std::array<int, NumReduceDim> reduceDims, const std::array<int, NumReduceDim> reduceDims,
double alpha,
double beta,
const void* in_dev, const void* in_dev,
const void* in_index_dev, const std::array<const void*, NumDTensor> ds_dev,
void* out_dev, void* out_dev,
void* out_index_dev,
const InElementwiseOperation in_elementwise_op, const InElementwiseOperation in_elementwise_op,
const AccElementwiseOperation acc_elementwise_op) override const OutElementwiseOperation out_elementwise_op) override
{ {
(void)in_index_dev;
return std::make_unique<Argument>(inLengths, return std::make_unique<Argument>(inLengths,
inStrides, inStrides,
DsLengths,
DsStrides,
outLengths, outLengths,
outStrides, outStrides,
reduceDims, reduceDims,
alpha,
beta,
static_cast<const InDataType*>(in_dev), static_cast<const InDataType*>(in_dev),
ds_dev,
static_cast<OutDataType*>(out_dev), static_cast<OutDataType*>(out_dev),
static_cast<IndexDataType*>(out_index_dev),
in_elementwise_op, in_elementwise_op,
acc_elementwise_op); out_elementwise_op);
}; };
std::unique_ptr<BaseInvoker> MakeInvokerPointer() override std::unique_ptr<BaseInvoker> MakeInvokerPointer() override
......
...@@ -55,7 +55,6 @@ template <typename InDataType, ...@@ -55,7 +55,6 @@ template <typename InDataType,
typename InElementwiseOperation, typename InElementwiseOperation,
typename OutElementwiseOperation, typename OutElementwiseOperation,
InMemoryDataOperationEnum OutMemoryDataOperation, InMemoryDataOperationEnum OutMemoryDataOperation,
bool PropagateNan,
index_t BlockSize, index_t BlockSize,
index_t MThreadSliceSize, index_t MThreadSliceSize,
index_t KThreadSliceSize, index_t KThreadSliceSize,
...@@ -110,7 +109,7 @@ struct GridwiseReduction_mk_to_m_threadwise_multi_d ...@@ -110,7 +109,7 @@ struct GridwiseReduction_mk_to_m_threadwise_multi_d
ThreadReduceSrcDesc_M_K, ThreadReduceSrcDesc_M_K,
ThreadReduceDstDesc_M, ThreadReduceDstDesc_M,
ReduceOperation, ReduceOperation,
PropagateNan>; false>;
const auto identityVal = ReduceOperation::template GetIdentityValue<AccDataType>(); const auto identityVal = ReduceOperation::template GetIdentityValue<AccDataType>();
...@@ -189,8 +188,6 @@ struct GridwiseReduction_mk_to_m_threadwise_multi_d ...@@ -189,8 +188,6 @@ struct GridwiseReduction_mk_to_m_threadwise_multi_d
auto ds_global_buf = generate_tuple( auto ds_global_buf = generate_tuple(
[&](auto I) { [&](auto I) {
// static_assert(ds_grid_desc_m[I].GetNumOfDimension() == 1, "");
return make_dynamic_buffer<AddressSpaceEnum::Global>( return make_dynamic_buffer<AddressSpaceEnum::Global>(
p_ds_grid[I], ds_grid_desc_m[I].GetElementSpaceSize()); p_ds_grid[I], ds_grid_desc_m[I].GetElementSpaceSize());
}, },
...@@ -208,9 +205,9 @@ struct GridwiseReduction_mk_to_m_threadwise_multi_d ...@@ -208,9 +205,9 @@ struct GridwiseReduction_mk_to_m_threadwise_multi_d
Sequence<MThreadSliceSize>, // SliceLengths Sequence<MThreadSliceSize>, // SliceLengths
Sequence<0>, // DimAccessOrder Sequence<0>, // DimAccessOrder
0, // SrcVectorDim 0, // SrcVectorDim
OutDstVectorSize, 1,
1, // SrcScalarStrideInVector 1, // SrcScalarStrideInVector
false>{ true>{
ds_grid_desc_m[I], make_multi_index(thread_global_1d_id * MThreadSliceSize)}; ds_grid_desc_m[I], make_multi_index(thread_global_1d_id * MThreadSliceSize)};
}, },
Number<NumDTensor>{}); Number<NumDTensor>{});
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment