Commit b2290854 authored by rocking's avatar rocking
Browse files

Merge commit '3e6c2610' into gemm_norm

parents 253f7ef2 3e6c2610
......@@ -9,6 +9,7 @@
#include "tensor_descriptor_helper.hpp"
#include "gridwise_gemm_xdl_cshuffle_v1.hpp"
#include "tensor_operation/gpu/device/gemm_specialization.hpp"
#include "device_prop.hpp"
namespace ck {
namespace tensor_operation {
......@@ -558,6 +559,11 @@ struct DeviceGemm_Xdl_CShuffle
static bool IsSupportedArgument(const Argument& arg)
{
if(!(ck::get_device_name() == "gfx908" || ck::get_device_name() == "gfx90a"))
{
return false;
}
return GridwiseGemm::CheckValidity(arg.a_grid_desc_ak0_m_ak1_,
arg.b_grid_desc_bk0_n_bk1_,
arg.c_grid_desc_m_n_,
......
......@@ -12,6 +12,7 @@
#include "tensor_descriptor_helper.hpp"
#include "gridwise_gemm_xdlops_v2r4.hpp"
#include "gemm_specialization.hpp"
#include "device_prop.hpp"
#ifndef CK_RUN_KERNEL_AND_TIME
#define CK_RUN_KERNEL_AND_TIME 1
......@@ -528,6 +529,11 @@ struct DeviceGemmXdlSplitK
static bool IsSupportedArgument(const Argument& arg)
{
if(!(ck::get_device_name() == "gfx908" || ck::get_device_name() == "gfx90a"))
{
return false;
}
return GridwiseGemm::CheckValidity(arg.a_grid_desc_kbatch_k0_m_k1_,
arg.b_grid_desc_kbatch_k0_n_k1_,
arg.c_grid_desc_m_n_,
......
......@@ -346,7 +346,6 @@ struct DeviceGroupedGemmXdl
return block_2_ctile_map_.CheckValidity(c_grid_desc_m_n);
}
private:
typename GridwiseGemm::DefaultBlock2CTileMap block_2_ctile_map_;
ck::index_t BlockStart_;
};
......@@ -418,9 +417,8 @@ struct DeviceGroupedGemmXdl
DeviceGroupedGemmXdl::MakeCGridDescriptor_M_N(M, N, StrideC);
const index_t grid_size_grp =
typename GroupedGemmBlock2CTileMap::UnderlyingBlock2CTileMap(
c_grid_desc_m_n_, M01, N01)
.CalculateGridSize(c_grid_desc_m_n_);
GroupedGemmBlock2CTileMap(c_grid_desc_m_n_, M01, N01, 0)
.block_2_ctile_map_.CalculateGridSize(c_grid_desc_m_n_);
const index_t BlockStart = grid_size_;
const index_t BlockEnd = grid_size_ + grid_size_grp;
......
......@@ -17,7 +17,7 @@ template <typename InDataType,
typename OutDataType,
typename AccDataType,
ck::ReduceTensorOp ReduceOpId,
bool NeedIndices,
bool OuputIndex,
ck::index_t BlockSize,
ck::index_t ReduceMThreadClusterSize,
ck::index_t ReduceKThreadClusterSize,
......@@ -44,8 +44,6 @@ struct DevicePool2dFwd_Input_N_Hi_Wi_C_Output_N_Ho_Wo_C : public DevicePool2dFwd
typename reduce_unary_operator<AccDataType, ReduceOpId, true, true>::
AccElementwiseOperation;
static constexpr bool BetaIsZero = true;
static constexpr index_t InSrcOutDstVectorDim =
0; // for NHWC, the dim C is the vector Dim for both input and output in memory, which is
// not reduced.
......@@ -206,28 +204,28 @@ struct DevicePool2dFwd_Input_N_Hi_Wi_C_Output_N_Ho_Wo_C : public DevicePool2dFwd
{
float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{})
{
using gridwise_reduce = GridwiseReduction_mk_to_m_threadwise<InDataType,
OutDataType,
AccDataType,
IndexDataType,
AGridDesc_M_K,
BGridDesc_M,
ReduceOperation,
InElementwiseOperation,
AccElementwiseOperation,
false, // propagate_nan
BetaIsZero,
BlockSize,
ReduceMThreadClusterSize,
ReduceKThreadClusterSize,
ReduceMThreadSliceSize,
ReduceKThreadSliceSize,
InSrcOutDstVectorDim,
InSrcOutDstVectorSize,
InSrcOutDstVectorSize>;
using gridwise_reduce =
GridwiseReduction_mk_to_m_threadwise<InDataType,
OutDataType,
AccDataType,
IndexDataType,
AGridDesc_M_K,
BGridDesc_M,
ReduceOperation,
InElementwiseOperation,
AccElementwiseOperation,
InMemoryDataOperationEnum::Set,
false, // propagate_nan
BlockSize,
ReduceMThreadSliceSize,
ReduceKThreadSliceSize,
InSrcOutDstVectorDim,
InSrcOutDstVectorSize,
InSrcOutDstVectorSize>;
const auto kernel = kernel_reduce_threadwise<gridwise_reduce,
NeedIndices,
OuputIndex,
false, // don't have index input
InDataType,
OutDataType,
AccDataType,
......@@ -252,6 +250,7 @@ struct DevicePool2dFwd_Input_N_Hi_Wi_C_Output_N_Ho_Wo_C : public DevicePool2dFwd
arg.acc_element_op_,
float(1),
arg.p_in_dev_,
nullptr,
float(0),
arg.p_out_dev_,
arg.p_out_indices_dev_);
......
......@@ -16,35 +16,18 @@ namespace device {
template <typename InElementwiseOperation, typename AccElementwiseOperation>
struct DeviceReduce : public BaseOperator
{
virtual long_index_t GetWorkspaceSizeInBytes(const std::vector<int> inLengths,
const std::vector<int> reduceDims)
{
(void)inLengths;
(void)reduceDims;
return (0);
};
virtual bool HasFurtherCall() { return (false); };
virtual std::vector<int> GetWorkspace2dLengths(const BaseArgument* argPtr)
{
(void)argPtr;
return (std::vector<int>{0, 0});
};
virtual std::unique_ptr<BaseArgument>
MakeArgumentPointer(const std::vector<int> inLengths,
const std::vector<int> inStrides,
const std::vector<int> outLengths,
const std::vector<int> outStrides,
MakeArgumentPointer(const std::vector<index_t> inLengths,
const std::vector<index_t> inStrides,
const std::vector<index_t> outLengths,
const std::vector<index_t> outStrides,
const std::vector<int> reduceDims,
float alpha,
float beta,
const void* in_dev,
const void* in_index_dev,
void* out_dev,
void* out_indices_dev,
void* workspace_dev,
void* out_index_dev,
const InElementwiseOperation in_elementwise_op,
const AccElementwiseOperation acc_elementwise_op) = 0;
......
#ifndef DEVICE_REDUCE_BLOCKWISE_HPP
#define DEVICE_REDUCE_BLOCKWISE_HPP
#include <iostream>
#include <sstream>
#include "device.hpp"
#include "device_reduce.hpp"
#include "device_reduce_common.hpp"
#include "gridwise_2d_reduction_blockwise.hpp"
namespace ck {
namespace tensor_operation {
namespace device {
template <typename InDataType,
typename AccDataType,
typename OutDataType,
index_t Rank,
index_t NumReduceDim,
typename ReduceOperation,
typename InElementwiseOperation,
typename AccElementwiseOperation,
bool PropagateNan,
bool NeedIndices,
index_t BlockSize,
index_t MThreadClusterSize,
index_t KThreadClusterSize,
index_t MThreadSliceSize,
index_t KThreadSliceSize,
index_t InSrcVectorDim,
index_t InSrcVectorSize,
index_t OutDstVectorSize>
struct DeviceReduceBlockWise : public DeviceReduce<InElementwiseOperation, AccElementwiseOperation>
{
static_assert(Rank <= 6, "Bigger Rank size is not supported!");
static_assert(BlockSize == MThreadClusterSize * KThreadClusterSize,
"Invalid thread cluster size assignments!");
static_assert(((InSrcVectorDim == 0 && MThreadSliceSize % InSrcVectorSize == 0) ||
(InSrcVectorDim == 1 && KThreadSliceSize % InSrcVectorSize == 0)) &&
(MThreadSliceSize % OutDstVectorSize == 0),
"Invalid thread slice sizes and/or vector sizes configuration, please check!");
using IndexDataType = int32_t;
static constexpr bool BetaIsZero = NeedIndices;
static constexpr index_t NumInvariantDim = Rank - NumReduceDim;
static constexpr index_t numSrcDim = Rank;
static constexpr index_t numDstDim = (NumInvariantDim == 0) ? 1 : NumInvariantDim;
static constexpr bool reduceAllDim = (NumInvariantDim == 0);
static constexpr int M_BlockTileSize = MThreadClusterSize * MThreadSliceSize;
static constexpr int K_BlockTileSize = KThreadClusterSize * KThreadSliceSize;
static auto MakeSrc2dDescriptor(const std::vector<int>& inLengths,
const std::vector<int>& inStrides)
{
const auto tupleSrcLengths = make_tuple_from_array(inLengths, Number<numSrcDim>{});
const auto tupleSrcStrides = make_tuple_from_array(inStrides, Number<numSrcDim>{});
const auto inDesc = make_naive_tensor_descriptor(tupleSrcLengths, tupleSrcStrides);
const auto in_grid_desc_m_k = [&]() {
if constexpr(reduceAllDim)
{
const auto one_dim_inDesc = transform_tensor_descriptor(
inDesc,
make_tuple(make_merge_transform(tupleSrcLengths)),
make_tuple(typename arithmetic_sequence_gen<0, numSrcDim, 1>::type{}),
make_tuple(Sequence<0>{}));
return transform_tensor_descriptor(one_dim_inDesc,
make_tuple(make_unmerge_transform(make_tuple(
1, one_dim_inDesc.GetLength(Number<0>{})))),
make_tuple(Sequence<0>{}),
make_tuple(Sequence<0, 1>{}));
}
else
{
using InvariantDims = typename arithmetic_sequence_gen<0, NumInvariantDim, 1>::type;
using ReduceDims = typename arithmetic_sequence_gen<NumInvariantDim, Rank, 1>::type;
const auto reduceDimLengths =
make_tuple_from_array_and_index_seq(inLengths, ReduceDims{});
const auto invariantDimLengths =
make_tuple_from_array_and_index_seq(inLengths, InvariantDims{});
return transform_tensor_descriptor(
inDesc,
make_tuple(make_merge_transform(invariantDimLengths),
make_merge_transform(reduceDimLengths)),
make_tuple(InvariantDims{}, ReduceDims{}),
make_tuple(Sequence<0>{}, Sequence<1>{}));
}
}();
const auto invariantLength = in_grid_desc_m_k.GetLength(Number<0>{});
const auto reduceLength = in_grid_desc_m_k.GetLength(Number<1>{});
const auto inPad_M =
math::integer_least_multiple(invariantLength, M_BlockTileSize) - invariantLength;
const auto inPad_K =
math::integer_least_multiple(reduceLength, K_BlockTileSize) - reduceLength;
auto in_grid_desc_m_k_padded = transform_tensor_descriptor(
in_grid_desc_m_k,
make_tuple(make_right_pad_transform(invariantLength, inPad_M),
make_right_pad_transform(reduceLength, inPad_K)),
make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}));
return (in_grid_desc_m_k_padded);
};
static auto MakeDst1dDescriptor(const std::vector<int>& outLengths,
const std::vector<int>& outStrides)
{
const auto tupleDstLengths = make_tuple_from_array(outLengths, Number<numDstDim>{});
const auto tupleDstStrides = make_tuple_from_array(outStrides, Number<numDstDim>{});
auto outDesc = make_naive_tensor_descriptor(tupleDstLengths, tupleDstStrides);
auto out_grid_desc_m = transform_tensor_descriptor(
outDesc,
make_tuple(make_merge_transform(tupleDstLengths)),
make_tuple(typename arithmetic_sequence_gen<0, numDstDim, 1>::type{}),
make_tuple(Sequence<0>{}));
const auto invariantLength = out_grid_desc_m.GetLength(Number<0>{});
const auto inPad =
math::integer_least_multiple(invariantLength, M_BlockTileSize) - invariantLength;
auto out_grid_desc_m_padded = transform_tensor_descriptor(
out_grid_desc_m,
make_tuple(make_right_pad_transform(invariantLength, inPad)),
make_tuple(Sequence<0>{}),
make_tuple(Sequence<0>{}));
return (out_grid_desc_m_padded);
};
struct Argument : public BaseArgument
{
Argument(const std::vector<int> inLengths,
const std::vector<int> inStrides,
const std::vector<int> outLengths,
const std::vector<int> outStrides,
const std::vector<int> reduceDims,
float alpha,
float beta,
const InDataType* in_dev,
OutDataType* out_dev,
IndexDataType* out_indices_dev,
AccDataType* workspace_dev,
const InElementwiseOperation in_elementwise_op,
const AccElementwiseOperation acc_elementwise_op)
: outLengths_{outLengths},
outStrides_{outStrides},
in_dev_{in_dev},
out_dev_{out_dev},
out_indices_dev_{out_indices_dev},
in_elementwise_op_{in_elementwise_op},
acc_elementwise_op_{acc_elementwise_op}
{
(void)workspace_dev;
inLengths_ = shuffle_tensor_dimensions<Rank, NumReduceDim>(inLengths, reduceDims);
inStrides_ = shuffle_tensor_dimensions<Rank, NumReduceDim>(inStrides, reduceDims);
alpha_ = type_convert<AccDataType>(alpha);
beta_ = type_convert<AccDataType>(beta);
std::tie(invariant_total_length, reduce_total_length) =
get_2d_lengths<Rank, NumReduceDim>(inLengths_);
if constexpr(NumInvariantDim == 0)
invariant_lowest_length = 1;
else
invariant_lowest_length = inLengths_[NumInvariantDim - 1];
reduce_lowest_length = inLengths_[Rank - 1];
gridSize = math::integer_least_multiple(invariant_total_length, M_BlockTileSize) /
M_BlockTileSize;
}
std::vector<int> inLengths_;
std::vector<int> inStrides_;
std::vector<int> outLengths_;
std::vector<int> outStrides_;
AccDataType alpha_;
AccDataType beta_;
const InDataType* in_dev_;
OutDataType* out_dev_;
IndexDataType* out_indices_dev_;
InElementwiseOperation in_elementwise_op_;
AccElementwiseOperation acc_elementwise_op_;
int invariant_lowest_length;
int reduce_lowest_length;
size_t invariant_total_length;
size_t reduce_total_length;
size_t gridSize;
};
struct Invoker : public BaseInvoker
{
float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{})
{
const auto in_grid_desc_m_k =
DeviceReduceBlockWise::MakeSrc2dDescriptor(arg.inLengths_, arg.inStrides_);
const auto out_grid_desc_m =
DeviceReduceBlockWise::MakeDst1dDescriptor(arg.outLengths_, arg.outStrides_);
using InGridDesc_M_K = decltype(in_grid_desc_m_k);
using OutGridDesc_M = decltype(out_grid_desc_m);
using GridwiseReduce = GridwiseReduction_mk_to_m_blockwise<InDataType,
OutDataType,
AccDataType,
IndexDataType,
InGridDesc_M_K,
OutGridDesc_M,
ReduceOperation,
InElementwiseOperation,
AccElementwiseOperation,
PropagateNan,
BetaIsZero,
BlockSize,
MThreadClusterSize,
KThreadClusterSize,
MThreadSliceSize,
KThreadSliceSize,
InSrcVectorDim,
InSrcVectorSize,
OutDstVectorSize>;
float avg_time = 0;
const auto kernel = kernel_reduce_blockwise<GridwiseReduce,
NeedIndices,
InDataType,
OutDataType,
AccDataType,
IndexDataType,
InGridDesc_M_K,
OutGridDesc_M,
InElementwiseOperation,
AccElementwiseOperation>;
avg_time = launch_and_time_kernel(stream_config,
kernel,
dim3(arg.gridSize),
dim3(BlockSize),
0,
in_grid_desc_m_k,
out_grid_desc_m,
arg.in_elementwise_op_,
arg.acc_elementwise_op_,
arg.alpha_,
arg.in_dev_,
arg.beta_,
arg.out_dev_,
nullptr,
arg.out_indices_dev_);
return (avg_time);
};
float Run(const BaseArgument* p_arg,
const StreamConfig& stream_config = StreamConfig{}) override
{
return Run(*dynamic_cast<const Argument*>(p_arg), stream_config);
};
};
bool IsSupportedArgument(const BaseArgument* p_arg) override
{
const Argument* pArg = dynamic_cast<const Argument*>(p_arg);
if constexpr(InSrcVectorDim == 0)
{
if constexpr(NumInvariantDim == 0)
{
return (false);
}
else
{
if(pArg->inStrides_[NumInvariantDim - 1] != 1)
return (false);
if(pArg->invariant_lowest_length % InSrcVectorSize != 0)
return (false);
};
}
else
{
if(pArg->inStrides_[Rank - 1] != 1)
return (false);
if(pArg->reduce_lowest_length % InSrcVectorSize != 0)
return (false);
};
// To improve
if(pArg->invariant_lowest_length % OutDstVectorSize != 0)
return (false);
// cases with very small reduce_total_length should be handled by the ThreadWise method
if(pArg->reduce_total_length / KThreadSliceSize < 2)
return (false);
return (true);
};
std::unique_ptr<BaseArgument>
MakeArgumentPointer(const std::vector<int> inLengths,
const std::vector<int> inStrides,
const std::vector<int> outLengths,
const std::vector<int> outStrides,
const std::vector<int> reduceDims,
float alpha,
float beta,
const void* in_dev,
void* out_dev,
void* out_indices_dev,
void* workspace_dev,
const InElementwiseOperation in_elementwise_op,
const AccElementwiseOperation acc_elementwise_op) override
{
return std::make_unique<Argument>(inLengths,
inStrides,
outLengths,
outStrides,
reduceDims,
alpha,
beta,
static_cast<const InDataType*>(in_dev),
static_cast<OutDataType*>(out_dev),
static_cast<IndexDataType*>(out_indices_dev),
static_cast<AccDataType*>(workspace_dev),
in_elementwise_op,
acc_elementwise_op);
};
std::unique_ptr<BaseInvoker> MakeInvokerPointer() override
{
return std::make_unique<Invoker>();
};
std::string GetTypeString() const override
{
auto str = std::stringstream();
// clang-format off
str << "DeviceReduceBlockWise<" << BlockSize << ",";
str << "M_C" << MThreadClusterSize << "_S" << MThreadSliceSize << ",";
str << "K_C" << KThreadClusterSize << "_S" << KThreadSliceSize << ",";
str << "InSrcVectorDim_" << InSrcVectorDim << "_InSrcVectorSize_" << InSrcVectorSize << "_OutDstVectorSize_" << OutDstVectorSize << ">";
// clang-format on
return str.str();
}
};
} // namespace device
} // namespace tensor_operation
} // namespace ck
#endif
#ifndef DEVICE_REDUCE_BLOCKWISE_SECOND_CALL_HPP
#define DEVICE_REDUCE_BLOCKWISE_SECOND_CALL_HPP
#include <iostream>
#include <sstream>
#include "device.hpp"
#include "device_reduce.hpp"
#include "device_reduce_common.hpp"
#include "gridwise_2d_reduction_blockwise.hpp"
namespace ck {
namespace tensor_operation {
namespace device {
template <typename InDataType,
typename AccDataType,
typename OutDataType,
index_t Rank,
index_t NumReduceDim,
typename ReduceOperation,
typename InElementwiseOperation,
typename AccElementwiseOperation,
bool PropagateNan,
bool NeedIndices,
index_t BlockSize,
index_t MThreadClusterSize,
index_t KThreadClusterSize,
index_t MThreadSliceSize,
index_t KThreadSliceSize,
index_t InSrcVectorDim,
index_t InSrcVectorSize,
index_t OutDstVectorSize>
struct DeviceReduceBlockWiseSecondCall
: public DeviceReduce<InElementwiseOperation, AccElementwiseOperation>
{
static_assert(Rank <= 6, "Bigger Rank size is not supported!");
static_assert(BlockSize == MThreadClusterSize * KThreadClusterSize,
"Invalid thread cluster size assignments!");
static_assert((InSrcVectorDim == 1 && KThreadSliceSize % InSrcVectorSize == 0) &&
(MThreadSliceSize % OutDstVectorSize == 0),
"Invalid thread slice sizes and/or vector sizes configuration, please check!");
using IndexDataType = int32_t;
static constexpr bool BetaIsZero = NeedIndices;
static_assert(
std::is_same<InDataType, AccDataType>::value,
"InDataType and AccDataType should be the same to use DEviceReduceBlockWiseSecondCall!");
static constexpr index_t NumInvariantDim = Rank - NumReduceDim;
static constexpr index_t numDstDim = (NumInvariantDim == 0) ? 1 : NumInvariantDim;
static constexpr int M_BlockTileSize = MThreadClusterSize * MThreadSliceSize;
static constexpr int K_BlockTileSize = KThreadClusterSize * KThreadSliceSize;
static auto MakeSrc2dDescriptor(const std::vector<int>& inLengths,
const std::vector<int>& inStrides)
{
const auto tupleSrcLengths = make_tuple_from_array(inLengths, Number<2>{});
const auto tupleSrcStrides = make_tuple_from_array(inStrides, Number<2>{});
const auto in_grid_desc_m_k =
make_naive_tensor_descriptor(tupleSrcLengths, tupleSrcStrides);
const auto invariantLength = in_grid_desc_m_k.GetLength(Number<0>{});
const auto reduceLength = in_grid_desc_m_k.GetLength(Number<1>{});
const auto inPad_M =
math::integer_least_multiple(invariantLength, M_BlockTileSize) - invariantLength;
const auto inPad_K =
math::integer_least_multiple(reduceLength, K_BlockTileSize) - reduceLength;
auto in_grid_desc_m_k_padded = transform_tensor_descriptor(
in_grid_desc_m_k,
make_tuple(make_right_pad_transform(invariantLength, inPad_M),
make_right_pad_transform(reduceLength, inPad_K)),
make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}));
return (in_grid_desc_m_k_padded);
};
static auto MakeDst1dDescriptor(const std::vector<int>& outLengths,
const std::vector<int>& outStrides)
{
const auto tupleDstLengths = make_tuple_from_array(outLengths, Number<numDstDim>{});
const auto tupleDstStrides = make_tuple_from_array(outStrides, Number<numDstDim>{});
auto outDesc = make_naive_tensor_descriptor(tupleDstLengths, tupleDstStrides);
auto out_grid_desc_m = transform_tensor_descriptor(
outDesc,
make_tuple(make_merge_transform(tupleDstLengths)),
make_tuple(typename arithmetic_sequence_gen<0, numDstDim, 1>::type{}),
make_tuple(Sequence<0>{}));
const auto invariantLength = out_grid_desc_m.GetLength(Number<0>{});
const auto outPad =
math::integer_least_multiple(invariantLength, M_BlockTileSize) - invariantLength;
auto out_grid_desc_m_padded = transform_tensor_descriptor(
out_grid_desc_m,
make_tuple(make_right_pad_transform(invariantLength, outPad)),
make_tuple(Sequence<0>{}),
make_tuple(Sequence<0>{}));
return (out_grid_desc_m_padded);
};
struct Argument : public BaseArgument
{
Argument(const std::vector<int>& inLengths,
const std::vector<int>& inStrides,
const std::vector<int>& outLengths,
const std::vector<int>& outStrides,
float alpha,
float beta,
const InDataType* in_dev,
OutDataType* out_dev,
IndexDataType* out_indices_dev,
AccDataType* workspace_dev,
const InElementwiseOperation& in_elementwise_op,
const AccElementwiseOperation& acc_elementwise_op)
: inLengths_(inLengths),
inStrides_(inStrides),
outLengths_(outLengths),
outStrides_(outStrides),
in_dev_{in_dev},
out_dev_{out_dev},
out_indices_dev_{out_indices_dev},
in_elementwise_op_(in_elementwise_op),
acc_elementwise_op_(acc_elementwise_op)
{
alpha_ = type_convert<AccDataType>(alpha);
beta_ = type_convert<AccDataType>(beta);
invariant_total_length = inLengths[0];
reduce_total_length = inLengths[1];
invariant_lowest_length = inLengths[0];
reduce_lowest_length = inLengths[1];
gridSize = math::integer_least_multiple(invariant_total_length, M_BlockTileSize) /
M_BlockTileSize;
size_t ws_buf2_bytes_offset = math::integer_least_multiple(
invariant_total_length * reduce_total_length * sizeof(AccDataType), 64);
if constexpr(NeedIndices)
workspace_indices_dev_ = reinterpret_cast<index_t*>(
reinterpret_cast<char*>(workspace_dev) + ws_buf2_bytes_offset);
else
workspace_indices_dev_ = nullptr;
}
std::vector<int> inLengths_;
std::vector<int> inStrides_;
std::vector<int> outLengths_;
std::vector<int> outStrides_;
AccDataType alpha_;
AccDataType beta_;
const InDataType* in_dev_;
OutDataType* out_dev_;
IndexDataType* out_indices_dev_;
IndexDataType* workspace_indices_dev_;
InElementwiseOperation in_elementwise_op_;
AccElementwiseOperation acc_elementwise_op_;
int invariant_lowest_length;
int reduce_lowest_length;
size_t invariant_total_length;
size_t reduce_total_length;
size_t gridSize;
};
struct Invoker : public BaseInvoker
{
float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{})
{
const auto in_grid_desc_m_k = DeviceReduceBlockWiseSecondCall::MakeSrc2dDescriptor(
arg.inLengths_, arg.inStrides_);
const auto out_grid_desc_m = DeviceReduceBlockWiseSecondCall::MakeDst1dDescriptor(
arg.outLengths_, arg.outStrides_);
using InGridDesc_M_K = decltype(in_grid_desc_m_k);
using OutGridDesc_M = decltype(out_grid_desc_m);
using GridwiseReduce = GridwiseReduction_mk_to_m_blockwise<InDataType,
OutDataType,
AccDataType,
IndexDataType,
InGridDesc_M_K,
OutGridDesc_M,
ReduceOperation,
InElementwiseOperation,
AccElementwiseOperation,
PropagateNan,
BetaIsZero,
BlockSize,
MThreadClusterSize,
KThreadClusterSize,
MThreadSliceSize,
KThreadSliceSize,
InSrcVectorDim,
InSrcVectorSize,
OutDstVectorSize>;
float avg_time = 0;
const auto kernel = kernel_reduce_blockwise_second_call<GridwiseReduce,
NeedIndices,
InDataType,
OutDataType,
AccDataType,
IndexDataType,
InGridDesc_M_K,
OutGridDesc_M,
InElementwiseOperation,
AccElementwiseOperation>;
avg_time = launch_and_time_kernel(stream_config,
kernel,
dim3(arg.gridSize),
dim3(BlockSize),
0,
in_grid_desc_m_k,
out_grid_desc_m,
arg.in_elementwise_op_,
arg.acc_elementwise_op_,
arg.alpha_,
arg.in_dev_,
arg.beta_,
arg.out_dev_,
arg.workspace_indices_dev_,
arg.out_indices_dev_);
return (avg_time);
};
float Run(const BaseArgument* p_arg,
const StreamConfig& stream_config = StreamConfig{}) override
{
return Run(*dynamic_cast<const Argument*>(p_arg), stream_config);
}
};
bool IsSupportedArgument(const BaseArgument* p_arg) override
{
const Argument* pArg = dynamic_cast<const Argument*>(p_arg);
if constexpr(InSrcVectorDim == 0)
return (false);
if(pArg->reduce_lowest_length % InSrcVectorSize != 0)
return (false);
// To improve
if(pArg->invariant_lowest_length % OutDstVectorSize != 0)
return (false);
// cases with very small reduce_total_length should be handled by the ThreadWise method
if(pArg->reduce_total_length / KThreadSliceSize < 2)
return (false);
return (true);
};
std::unique_ptr<BaseArgument>
MakeArgumentPointer(const std::vector<int> inLengths,
const std::vector<int> inStrides,
const std::vector<int> outLengths,
const std::vector<int> outStrides,
const std::vector<int> reduceDims,
float alpha,
float beta,
const void* in_dev,
void* out_dev,
void* out_indices_dev,
void* workspace_dev,
const InElementwiseOperation in_elementwise_op,
const AccElementwiseOperation acc_elementwise_op) override
{
(void)reduceDims;
return std::make_unique<Argument>(inLengths,
inStrides,
outLengths,
outStrides,
alpha,
beta,
static_cast<const InDataType*>(in_dev),
static_cast<OutDataType*>(out_dev),
static_cast<IndexDataType*>(out_indices_dev),
static_cast<AccDataType*>(workspace_dev),
in_elementwise_op,
acc_elementwise_op);
};
std::unique_ptr<BaseInvoker> MakeInvokerPointer() override
{
return std::make_unique<Invoker>();
};
std::string GetTypeString() const override
{
auto str = std::stringstream();
// clang-format off
str << "DeviceReduceBlockWiseSecondCall<" << BlockSize << ",";
str << "M_C" << MThreadClusterSize << "_S" << MThreadSliceSize << ",";
str << "K_C" << KThreadClusterSize << "_S" << KThreadSliceSize << ",";
str << "InSrcVectorDim_" << InSrcVectorDim << "_InSrcVectorSize_" << InSrcVectorSize << "_OutDstVectorSize_" << OutDstVectorSize << ">";
// clang-format on
return str.str();
}
};
} // namespace device
} // namespace tensor_operation
} // namespace ck
#endif
......@@ -14,13 +14,13 @@ namespace device {
// here, inLengths[] is already shuffled so that lengths of invariant dims are included before those
// of reduce dims
template <int Rank, int NumReduceDim>
std::pair<size_t, size_t> get_2d_lengths(const std::vector<int>& inLengths)
template <index_t Rank, int NumReduceDim>
std::pair<long_index_t, long_index_t> get_2d_lengths(const std::vector<index_t>& inLengths)
{
static_assert(Rank <= 6, "bigger Rank size not supported!");
size_t invariant_total_length = 1;
size_t reduce_total_length = 1;
long_index_t invariant_total_length = 1;
long_index_t reduce_total_length = 1;
constexpr int NumInvariantDim = Rank - NumReduceDim;
......@@ -35,13 +35,13 @@ std::pair<size_t, size_t> get_2d_lengths(const std::vector<int>& inLengths)
// helper functions using variadic template arguments
template <index_t... Ns>
auto make_tuple_from_array_and_index_seq(const std::vector<int>& lengths, Sequence<Ns...>)
auto make_tuple_from_array_and_index_seq(const std::vector<index_t>& lengths, Sequence<Ns...>)
{
return make_tuple(static_cast<index_t>(lengths[Ns])...);
};
template <index_t arraySize>
static auto make_tuple_from_array(const std::vector<int>& lengths, Number<arraySize>)
auto make_tuple_from_array(const std::vector<index_t>& lengths, Number<arraySize>)
{
static_assert(arraySize >= 1 && arraySize <= 6, "The tensor should have 1 to 6 dimensions");
......@@ -51,10 +51,10 @@ static auto make_tuple_from_array(const std::vector<int>& lengths, Number<arrayS
};
template <index_t Rank, index_t NumReduceDim>
std::vector<int> shuffle_tensor_dimensions(const std::vector<int>& origLengthsStrides,
const std::vector<int>& reduceDims)
std::vector<index_t> shuffle_tensor_dimensions(const std::vector<index_t>& origLengthsStrides,
const std::vector<int>& reduceDims)
{
std::vector<int> newLengthsStrides;
std::vector<index_t> newLengthsStrides;
assert(Rank == origLengthsStrides.size() && NumReduceDim == reduceDims.size());
......
#ifndef DEVICE_REDUCE_MULTIBLOCK_ATOMIC_ADD_HPP
#define DEVICE_REDUCE_MULTIBLOCK_ATOMIC_ADD_HPP
#ifndef DEVICE_REDUCE_MULTIBLOCK_HPP
#define DEVICE_REDUCE_MULTIBLOCK_HPP
#include <iostream>
#include <sstream>
......@@ -7,8 +7,9 @@
#include "device_base.hpp"
#include "device_reduce.hpp"
#include "device_reduce_common.hpp"
#include "gridwise_2d_reduction_multiblock_atomic_add.hpp"
#include "gridwise_2d_reduction_multiblock.hpp"
#include "gridwise_set_buffer_value.hpp"
#include "reduction_operator.hpp"
namespace ck {
namespace tensor_operation {
......@@ -22,8 +23,10 @@ template <typename InDataType,
typename ReduceOperation,
typename InElementwiseOperation,
typename AccElementwiseOperation,
InMemoryDataOperationEnum OutMemoryDataOperation,
bool PropagateNan,
bool NeedIndices,
bool OutputIndex,
bool HaveIndexInputIfOutputIndex,
index_t BlockSize,
index_t MThreadClusterSize,
index_t KThreadClusterSize,
......@@ -32,8 +35,7 @@ template <typename InDataType,
index_t InSrcVectorDim,
index_t InSrcVectorSize,
index_t OutDstVectorSize>
struct DeviceReduceMultiBlockAtomicAdd
: public DeviceReduce<InElementwiseOperation, AccElementwiseOperation>
struct DeviceReduceMultiBlock : public DeviceReduce<InElementwiseOperation, AccElementwiseOperation>
{
static_assert(Rank <= 6, "Bigger Rank size is not supported!");
static_assert(BlockSize == MThreadClusterSize * KThreadClusterSize,
......@@ -46,26 +48,40 @@ struct DeviceReduceMultiBlockAtomicAdd
using IndexDataType = int32_t;
static constexpr bool HaveIndexInput = OutputIndex && HaveIndexInputIfOutputIndex;
static constexpr index_t NumInvariantDim = Rank - NumReduceDim;
static constexpr index_t numSrcDim = Rank;
static constexpr index_t numDstDim = (NumInvariantDim == 0) ? 1 : NumInvariantDim;
static constexpr bool reduceAllDim = (NumInvariantDim == 0);
static constexpr bool support_AtomicAdd =
// So far, only AtomicAdd is considered, other Atomic Operation like AtomicMax can be added
// later
static constexpr bool use_multiblock =
(OutMemoryDataOperation == InMemoryDataOperationEnum::AtomicAdd);
static constexpr bool out_type_compatible_with_atomic_op =
std::is_same<OutDataType, float>::value || std::is_same<OutDataType, double>::value;
static_assert(!NeedIndices && support_AtomicAdd,
"MultiBlockAtomicAdd method can only be used with non-indiced operation and when "
"having float/double output type!");
static_assert(
!use_multiblock || (use_multiblock && out_type_compatible_with_atomic_op),
"The OutDataType must support the atomic operation for using MultiBlock reduction");
static_assert(!use_multiblock || (use_multiblock && !OutputIndex),
"MultiBlock reduction can only be used when outputing index is not required");
static_assert(
ReduceOperation::IsCompatibleInMemoryDataOperation(OutMemoryDataOperation),
"The reduction accumulation operation must be compatible with the OutMemoryDataOperation!");
static constexpr int M_BlockTileSize = MThreadClusterSize * MThreadSliceSize;
static constexpr int K_BlockTileSize = KThreadClusterSize * KThreadSliceSize;
static constexpr index_t M_BlockTileSize = MThreadClusterSize * MThreadSliceSize;
static constexpr index_t K_BlockTileSize = KThreadClusterSize * KThreadSliceSize;
static auto MakeSrc2dDescriptor(const std::vector<int>& inLengths,
const std::vector<int>& inStrides,
static auto MakeSrc2dDescriptor(const std::vector<index_t>& inLengths,
const std::vector<index_t>& inStrides,
int blkGroupSize,
int kBlockTileIterations)
int numBlockTileIteration)
{
const auto tupleSrcLengths = make_tuple_from_array(inLengths, Number<numSrcDim>{});
const auto tupleSrcStrides = make_tuple_from_array(inStrides, Number<numSrcDim>{});
......@@ -109,7 +125,7 @@ struct DeviceReduceMultiBlockAtomicAdd
const auto invariantLength = in_grid_desc_m_k.GetLength(Number<0>{});
const auto reduceLength = in_grid_desc_m_k.GetLength(Number<1>{});
const int reduceSizePerBlock = K_BlockTileSize * kBlockTileIterations;
const int reduceSizePerBlock = K_BlockTileSize * numBlockTileIteration;
const auto inPad_M =
math::integer_least_multiple(invariantLength, M_BlockTileSize) - invariantLength;
const auto inPad_K = reduceSizePerBlock * blkGroupSize - reduceLength;
......@@ -124,8 +140,8 @@ struct DeviceReduceMultiBlockAtomicAdd
return (in_grid_desc_m_k_padded);
};
static auto MakeDst1dDescriptor(const std::vector<int>& outLengths,
const std::vector<int>& outStrides)
static auto MakeDst1dDescriptor(const std::vector<index_t>& outLengths,
const std::vector<index_t>& outStrides)
{
const auto tupleDstLengths = make_tuple_from_array(outLengths, Number<numDstDim>{});
const auto tupleDstStrides = make_tuple_from_array(outStrides, Number<numDstDim>{});
......@@ -151,31 +167,56 @@ struct DeviceReduceMultiBlockAtomicAdd
return (out_grid_desc_m_padded);
};
static auto MakeDst1dDescriptorForBufferSet(const std::vector<index_t>& outLengths,
const std::vector<index_t>& outStrides)
{
const auto tupleDstLengths = make_tuple_from_array(outLengths, Number<numDstDim>{});
const auto tupleDstStrides = make_tuple_from_array(outStrides, Number<numDstDim>{});
auto outDesc = make_naive_tensor_descriptor(tupleDstLengths, tupleDstStrides);
auto out_grid_desc_m = transform_tensor_descriptor(
outDesc,
make_tuple(make_merge_transform(tupleDstLengths)),
make_tuple(typename arithmetic_sequence_gen<0, numDstDim, 1>::type{}),
make_tuple(Sequence<0>{}));
const auto length = out_grid_desc_m.GetLength(Number<0>{});
const auto pad = math::integer_least_multiple(length, BlockSize) - length;
auto out_grid_desc_m_padded =
transform_tensor_descriptor(out_grid_desc_m,
make_tuple(make_right_pad_transform(length, pad)),
make_tuple(Sequence<0>{}),
make_tuple(Sequence<0>{}));
return (out_grid_desc_m_padded);
};
struct Argument : public BaseArgument
{
Argument(const std::vector<int> inLengths,
const std::vector<int> inStrides,
const std::vector<int> outLengths,
const std::vector<int> outStrides,
Argument(const std::vector<index_t> inLengths,
const std::vector<index_t> inStrides,
const std::vector<index_t> outLengths,
const std::vector<index_t> outStrides,
const std::vector<int> reduceDims,
float alpha,
float beta,
const InDataType* in_dev,
const IndexDataType* in_index_dev,
OutDataType* out_dev,
IndexDataType* out_indices_dev,
AccDataType* workspace_dev,
IndexDataType* out_index_dev,
const InElementwiseOperation in_elementwise_op,
const AccElementwiseOperation acc_elementwise_op)
: outLengths_{outLengths},
outStrides_{outStrides},
in_dev_{in_dev},
in_index_dev_{in_index_dev},
out_dev_{out_dev},
out_index_dev_{out_index_dev},
in_elementwise_op_{in_elementwise_op},
acc_elementwise_op_{acc_elementwise_op}
{
(void)out_indices_dev;
(void)workspace_dev;
inLengths_ = shuffle_tensor_dimensions<Rank, NumReduceDim>(inLengths, reduceDims);
inStrides_ = shuffle_tensor_dimensions<Rank, NumReduceDim>(inStrides, reduceDims);
......@@ -192,23 +233,34 @@ struct DeviceReduceMultiBlockAtomicAdd
reduce_lowest_length = inLengths_[Rank - 1];
int iterations = 1;
while(true)
if constexpr(use_multiblock)
{
int testBlkGroupSize = (reduce_total_length + (K_BlockTileSize * iterations) - 1) /
(K_BlockTileSize * iterations);
// we want the blkGroupSize be not more than 128
if(testBlkGroupSize <= 128)
break;
int iterations = 1;
while(true)
{
int testBlkGroupSize =
(reduce_total_length + (K_BlockTileSize * iterations) - 1) /
(K_BlockTileSize * iterations);
iterations++;
};
// we want the blkGroupSize be not more than 128
if(testBlkGroupSize <= 128)
break;
blkGroupSize = (reduce_total_length + (K_BlockTileSize * iterations) - 1) /
(K_BlockTileSize * iterations);
iterations++;
};
kBlockTileIterations = iterations;
blkGroupSize = (reduce_total_length + (K_BlockTileSize * iterations) - 1) /
(K_BlockTileSize * iterations);
numBlockTileIteration = iterations;
}
else
{
blkGroupSize = 1;
numBlockTileIteration =
(reduce_total_length + K_BlockTileSize - 1) / K_BlockTileSize;
};
gridSize = math::integer_least_multiple(invariant_total_length, M_BlockTileSize) /
M_BlockTileSize * blkGroupSize;
......@@ -217,27 +269,29 @@ struct DeviceReduceMultiBlockAtomicAdd
math::integer_least_multiple(invariant_total_length, BlockSize) / BlockSize;
}
std::vector<int> inLengths_;
std::vector<int> inStrides_;
std::vector<int> outLengths_;
std::vector<int> outStrides_;
std::vector<index_t> inLengths_;
std::vector<index_t> inStrides_;
std::vector<index_t> outLengths_;
std::vector<index_t> outStrides_;
AccDataType alpha_;
AccDataType beta_;
const InDataType* in_dev_;
const IndexDataType* in_index_dev_;
OutDataType* out_dev_;
IndexDataType* out_index_dev_;
InElementwiseOperation in_elementwise_op_;
AccElementwiseOperation acc_elementwise_op_;
int invariant_lowest_length;
int reduce_lowest_length;
size_t invariant_total_length;
size_t reduce_total_length;
index_t invariant_lowest_length;
index_t reduce_lowest_length;
long_index_t invariant_total_length;
long_index_t reduce_total_length;
index_t blkGroupSize;
index_t kBlockTileIterations;
int blkGroupSize;
int numBlockTileIteration;
size_t gridSize;
size_t gridSize_pre;
......@@ -247,52 +301,69 @@ struct DeviceReduceMultiBlockAtomicAdd
{
float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{})
{
const auto in_grid_desc_m_k = DeviceReduceMultiBlockAtomicAdd::MakeSrc2dDescriptor(
arg.inLengths_, arg.inStrides_, arg.blkGroupSize, arg.kBlockTileIterations);
const auto out_grid_desc_m = DeviceReduceMultiBlockAtomicAdd::MakeDst1dDescriptor(
const auto in_grid_desc_m_k = DeviceReduceMultiBlock::MakeSrc2dDescriptor(
arg.inLengths_, arg.inStrides_, arg.blkGroupSize, arg.numBlockTileIteration);
const auto out_grid_desc_m =
DeviceReduceMultiBlock::MakeDst1dDescriptor(arg.outLengths_, arg.outStrides_);
const auto out_grid_desc_m_2 = DeviceReduceMultiBlock::MakeDst1dDescriptorForBufferSet(
arg.outLengths_, arg.outStrides_);
using InGridDesc_M_K = decltype(in_grid_desc_m_k);
using OutGridDesc_M = decltype(out_grid_desc_m);
using GridwiseReduce =
GridwiseReduction_mk_to_m_multiblock_atomic_add<InDataType,
OutDataType,
AccDataType,
InGridDesc_M_K,
OutGridDesc_M,
ReduceOperation,
InElementwiseOperation,
AccElementwiseOperation,
PropagateNan,
BlockSize,
MThreadClusterSize,
KThreadClusterSize,
MThreadSliceSize,
KThreadSliceSize,
InSrcVectorDim,
InSrcVectorSize,
OutDstVectorSize>;
float avg_time = 0;
using InGridDesc_M_K = decltype(in_grid_desc_m_k);
using OutGridDesc_M = decltype(out_grid_desc_m);
using OutGridDesc_M_2 = decltype(out_grid_desc_m_2);
using GridwiseReduce = GridwiseReduction_mk_to_m_multiblock<InDataType,
OutDataType,
AccDataType,
IndexDataType,
InGridDesc_M_K,
OutGridDesc_M,
ReduceOperation,
InElementwiseOperation,
AccElementwiseOperation,
OutMemoryDataOperation,
PropagateNan,
BlockSize,
MThreadClusterSize,
KThreadClusterSize,
MThreadSliceSize,
KThreadSliceSize,
InSrcVectorDim,
InSrcVectorSize,
OutDstVectorSize>;
const auto kernel_main = kernel_reduce_multiblock<GridwiseReduce,
OutputIndex,
HaveIndexInput,
InDataType,
OutDataType,
AccDataType,
int32_t,
InGridDesc_M_K,
OutGridDesc_M,
InElementwiseOperation,
AccElementwiseOperation>;
const auto kernel_pre = kernel_buffer_set_value<BlockSize, OutDataType, OutGridDesc_M>;
const auto kernel_main = kernel_reduce_multiblock_atocmi_add<GridwiseReduce,
InDataType,
OutDataType,
AccDataType,
InGridDesc_M_K,
OutGridDesc_M,
InElementwiseOperation,
AccElementwiseOperation>;
float avg_time = 0;
avg_time += launch_and_time_kernel(stream_config,
kernel_pre,
dim3(arg.gridSize_pre),
dim3(BlockSize),
0,
out_grid_desc_m,
arg.out_dev_,
static_cast<OutDataType>(0.0f));
if constexpr(use_multiblock)
{
const auto zeroVal =
ck::reduce::GetReductionZeroValueForInMemoryDataOperation<OutDataType>(
OutMemoryDataOperation);
const auto kernel_pre =
kernel_buffer_set_value<BlockSize, OutDataType, OutGridDesc_M_2>;
avg_time += launch_and_time_kernel(stream_config,
kernel_pre,
dim3(arg.gridSize_pre),
dim3(BlockSize),
0,
out_grid_desc_m_2,
arg.out_dev_,
zeroVal);
};
avg_time += launch_and_time_kernel(stream_config,
kernel_main,
......@@ -304,25 +375,34 @@ struct DeviceReduceMultiBlockAtomicAdd
arg.in_elementwise_op_,
arg.acc_elementwise_op_,
arg.blkGroupSize,
arg.kBlockTileIterations,
arg.numBlockTileIteration,
arg.alpha_,
arg.in_dev_,
arg.out_dev_);
arg.in_index_dev_,
arg.beta_,
arg.out_dev_,
arg.out_index_dev_);
return avg_time;
}
return (avg_time);
};
float Run(const BaseArgument* p_arg,
const StreamConfig& stream_config = StreamConfig{}) override
{
return Run(*dynamic_cast<const Argument*>(p_arg), stream_config);
}
};
};
bool IsSupportedArgument(const BaseArgument* p_arg) override
{
const Argument* pArg = dynamic_cast<const Argument*>(p_arg);
if constexpr(use_multiblock)
{
if(static_cast<float>(pArg->beta_) != 0.0f)
return (false);
};
if constexpr(InSrcVectorDim == 0)
{
if constexpr(NumInvariantDim == 0)
......@@ -347,36 +427,43 @@ struct DeviceReduceMultiBlockAtomicAdd
return (false);
};
if(static_cast<float>(pArg->beta_) != 0.0f)
return (false);
// To improve
if(pArg->invariant_lowest_length % OutDstVectorSize != 0)
return (false);
// cases with small reduce_total_length should be handled by the BlockWise method
if(pArg->reduce_total_length <= BlockSize * KThreadSliceSize)
return (false);
if constexpr(use_multiblock)
{
// blkGroupSize of 1 should be handled by Blockwise path using
// InMemoryDataOperationEnum::Set
if(pArg->blkGroupSize == 1)
return (false);
// This is very strong restriction, but needed to avoid some failure
if(pArg->invariant_lowest_length % M_BlockTileSize != 0)
return (false);
// This is very strong restriction, but needed to avoid some failure
if(pArg->invariant_lowest_length % M_BlockTileSize != 0)
return (false);
}
else
{
// cases with very small reduce_total_length should be handled by ThreadWise kernel
if(pArg->reduce_total_length / KThreadSliceSize < 2)
return (false);
};
return (true);
};
std::unique_ptr<BaseArgument>
MakeArgumentPointer(const std::vector<int> inLengths,
const std::vector<int> inStrides,
const std::vector<int> outLengths,
const std::vector<int> outStrides,
MakeArgumentPointer(const std::vector<index_t> inLengths,
const std::vector<index_t> inStrides,
const std::vector<index_t> outLengths,
const std::vector<index_t> outStrides,
const std::vector<int> reduceDims,
float alpha,
float beta,
const void* in_dev,
const void* in_index_dev,
void* out_dev,
void* out_indices_dev,
void* workspace_dev,
void* out_index_dev,
const InElementwiseOperation in_elementwise_op,
const AccElementwiseOperation acc_elementwise_op) override
{
......@@ -388,9 +475,9 @@ struct DeviceReduceMultiBlockAtomicAdd
alpha,
beta,
static_cast<const InDataType*>(in_dev),
static_cast<const IndexDataType*>(in_index_dev),
static_cast<OutDataType*>(out_dev),
static_cast<IndexDataType*>(out_indices_dev),
static_cast<AccDataType*>(workspace_dev),
static_cast<IndexDataType*>(out_index_dev),
in_elementwise_op,
acc_elementwise_op);
};
......
#ifndef DEVICE_REDUCE_MULTIBLOCK_PARTIAL_REDUCE_HPP
#define DEVICE_REDUCE_MULTIBLOCK_PARTIAL_REDUCE_HPP
#include <iostream>
#include <sstream>
#include "device.hpp"
#include "device_reduce.hpp"
#include "device_reduce_common.hpp"
#include "gridwise_2d_reduction_multiblock_partial_reduce.hpp"
namespace ck {
namespace tensor_operation {
namespace device {
template <typename InDataType,
typename AccDataType,
typename OutDataType,
index_t Rank,
index_t NumReduceDim,
typename ReduceOperation,
typename InElementwiseOperation,
typename AccElementwiseOperation,
bool PropagateNan,
bool NeedIndices,
index_t BlockSize,
index_t MThreadClusterSize,
index_t KThreadClusterSize,
index_t MThreadSliceSize,
index_t KThreadSliceSize,
index_t InSrcVectorDim,
index_t InSrcVectorSize,
index_t OutDstVectorSize>
struct DeviceReduceMultiBlockPartialReduce
: public DeviceReduce<InElementwiseOperation, AccElementwiseOperation>
{
static_assert(Rank <= 6, "Bigger Rank size is not supported!");
static_assert(BlockSize == MThreadClusterSize * KThreadClusterSize,
"Invalid thread cluster size assignments!");
static_assert((InSrcVectorDim == 0 && MThreadSliceSize % InSrcVectorSize == 0) ||
(InSrcVectorDim == 1 && KThreadSliceSize % InSrcVectorSize == 0),
"Invalid thread slice sizes and/or vector sizes configuration, please check!");
static_assert(OutDstVectorSize == 1, "OutDstVectorSize must be 1 for MultiBlockPartialReduce!");
using IndexDataType = int32_t;
static constexpr index_t NumInvariantDim = Rank - NumReduceDim;
static constexpr index_t numSrcDim = Rank;
static constexpr index_t numDstDim = (NumInvariantDim == 0) ? 1 : NumInvariantDim;
static constexpr bool reduceAllDim = (NumInvariantDim == 0);
static constexpr int M_BlockTileSize = MThreadClusterSize * MThreadSliceSize;
static constexpr int K_BlockTileSize = KThreadClusterSize * KThreadSliceSize;
static constexpr int MaxBlockGroupSize = 256;
long_index_t GetWorkspaceSizeInBytes(const std::vector<int> inLengths,
const std::vector<int> reduceDims) override
{
size_t invariant_total_length;
size_t reduce_total_length;
auto inLengths_ = shuffle_tensor_dimensions<Rank, NumReduceDim>(inLengths, reduceDims);
std::tie(invariant_total_length, reduce_total_length) =
get_2d_lengths<Rank, NumReduceDim>(inLengths_);
int iterations = 1;
while(true)
{
int testBlkGroupSize = (reduce_total_length + (K_BlockTileSize * iterations) - 1) /
(K_BlockTileSize * iterations);
if(testBlkGroupSize <= MaxBlockGroupSize)
break;
iterations++;
};
int blkGroupSize = (reduce_total_length + (K_BlockTileSize * iterations) - 1) /
(K_BlockTileSize * iterations);
long_index_t workspace_size = invariant_total_length * blkGroupSize;
long_index_t wsSizeInBytes =
!NeedIndices
? workspace_size * sizeof(AccDataType)
: workspace_size * (sizeof(AccDataType) + sizeof(int32_t)) + 64 + sizeof(int);
return (wsSizeInBytes);
};
bool HasFurtherCall() override { return (true); };
static auto MakeSrc2dDescriptor(const std::vector<int>& inLengths,
const std::vector<int>& inStrides,
int blkGroupSize,
int kBlockTileIterations)
{
const auto tupleSrcLengths = make_tuple_from_array(inLengths, Number<numSrcDim>{});
const auto tupleSrcStrides = make_tuple_from_array(inStrides, Number<numSrcDim>{});
const auto inDesc = make_naive_tensor_descriptor(tupleSrcLengths, tupleSrcStrides);
const auto in_grid_desc_m_k = [&]() {
if constexpr(reduceAllDim)
{
const auto one_dim_inDesc = transform_tensor_descriptor(
inDesc,
make_tuple(make_merge_transform(tupleSrcLengths)),
make_tuple(typename arithmetic_sequence_gen<0, numSrcDim, 1>::type{}),
make_tuple(Sequence<0>{}));
return transform_tensor_descriptor(one_dim_inDesc,
make_tuple(make_unmerge_transform(make_tuple(
1, one_dim_inDesc.GetLength(Number<0>{})))),
make_tuple(Sequence<0>{}),
make_tuple(Sequence<0, 1>{}));
}
else
{
using InvariantDims = typename arithmetic_sequence_gen<0, NumInvariantDim, 1>::type;
using ReduceDims = typename arithmetic_sequence_gen<NumInvariantDim, Rank, 1>::type;
const auto reduceDimLengths =
make_tuple_from_array_and_index_seq(inLengths, ReduceDims{});
const auto invariantDimLengths =
make_tuple_from_array_and_index_seq(inLengths, InvariantDims{});
return transform_tensor_descriptor(
inDesc,
make_tuple(make_merge_transform(invariantDimLengths),
make_merge_transform(reduceDimLengths)),
make_tuple(InvariantDims{}, ReduceDims{}),
make_tuple(Sequence<0>{}, Sequence<1>{}));
}
}();
const auto invariantLength = in_grid_desc_m_k.GetLength(Number<0>{});
const auto reduceLength = in_grid_desc_m_k.GetLength(Number<1>{});
const int reduceSizePerBlock = K_BlockTileSize * kBlockTileIterations;
const auto inPad_M =
math::integer_least_multiple(invariantLength, M_BlockTileSize) - invariantLength;
const auto inPad_K = reduceSizePerBlock * blkGroupSize - reduceLength;
auto in_grid_desc_m_k_padded = transform_tensor_descriptor(
in_grid_desc_m_k,
make_tuple(make_right_pad_transform(invariantLength, inPad_M),
make_right_pad_transform(reduceLength, inPad_K)),
make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}));
return (in_grid_desc_m_k_padded);
};
static auto MakeWorkspace2dDescriptor(int invariantLength, int blkGroupSize)
{
auto ws_desc_m_k =
make_naive_tensor_descriptor_packed(make_tuple(invariantLength, blkGroupSize));
const auto wsPad =
math::integer_least_multiple(invariantLength, M_BlockTileSize) - invariantLength;
auto ws_desc_m_k_padded =
transform_tensor_descriptor(ws_desc_m_k,
make_tuple(make_right_pad_transform(invariantLength, wsPad),
make_pass_through_transform(blkGroupSize)),
make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}));
return (ws_desc_m_k_padded);
};
struct Argument : public BaseArgument
{
Argument(const std::vector<int> inLengths,
const std::vector<int> inStrides,
const std::vector<int> outLengths,
const std::vector<int> outStrides,
const std::vector<int> reduceDims,
float alpha,
float beta,
const InDataType* in_dev,
OutDataType* out_dev,
IndexDataType* out_indices_dev,
AccDataType* workspace_dev,
const InElementwiseOperation in_elementwise_op,
const AccElementwiseOperation acc_elementwise_op)
: outLengths_{outLengths},
outStrides_{outStrides},
in_dev_{in_dev},
out_dev_{out_dev},
out_indices_dev_{out_indices_dev},
workspace_dev_{workspace_dev},
in_elementwise_op_{in_elementwise_op},
acc_elementwise_op_{acc_elementwise_op}
{
inLengths_ = shuffle_tensor_dimensions<Rank, NumReduceDim>(inLengths, reduceDims);
inStrides_ = shuffle_tensor_dimensions<Rank, NumReduceDim>(inStrides, reduceDims);
alpha_ = type_convert<AccDataType>(alpha);
beta_ = type_convert<AccDataType>(beta);
std::tie(invariant_total_length, reduce_total_length) =
get_2d_lengths<Rank, NumReduceDim>(inLengths_);
if constexpr(NumInvariantDim == 0)
invariant_lowest_length = 1;
else
invariant_lowest_length = inLengths_[NumInvariantDim - 1];
reduce_lowest_length = inLengths_[Rank - 1];
int iterations = 1;
while(true)
{
int testBlkGroupSize = (reduce_total_length + (K_BlockTileSize * iterations) - 1) /
(K_BlockTileSize * iterations);
if(testBlkGroupSize <= MaxBlockGroupSize)
break;
iterations++;
};
blkGroupSize = (reduce_total_length + (K_BlockTileSize * iterations) - 1) /
(K_BlockTileSize * iterations);
kBlockTileIterations = iterations;
gridSize = math::integer_least_multiple(invariant_total_length, M_BlockTileSize) /
M_BlockTileSize * blkGroupSize;
size_t ws_buf2_bytes_offset = math::integer_least_multiple(
invariant_total_length * blkGroupSize * sizeof(AccDataType), 64);
if constexpr(NeedIndices)
workspace_indices_dev_ = reinterpret_cast<int*>(
reinterpret_cast<char*>(workspace_dev_) + ws_buf2_bytes_offset);
else
workspace_indices_dev_ = nullptr;
}
std::vector<int> inLengths_;
std::vector<int> inStrides_;
std::vector<int> outLengths_;
std::vector<int> outStrides_;
AccDataType alpha_;
AccDataType beta_;
const InDataType* in_dev_;
OutDataType* out_dev_;
IndexDataType* out_indices_dev_;
AccDataType* workspace_dev_;
IndexDataType* workspace_indices_dev_;
InElementwiseOperation in_elementwise_op_;
AccElementwiseOperation acc_elementwise_op_;
int invariant_lowest_length;
int reduce_lowest_length;
size_t invariant_total_length;
size_t reduce_total_length;
index_t blkGroupSize;
index_t kBlockTileIterations;
size_t gridSize;
};
struct Invoker : public BaseInvoker
{
float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{})
{
const auto in_grid_desc_m_k = DeviceReduceMultiBlockPartialReduce::MakeSrc2dDescriptor(
arg.inLengths_, arg.inStrides_, arg.blkGroupSize, arg.kBlockTileIterations);
const auto ws_desc_m_k = DeviceReduceMultiBlockPartialReduce::MakeWorkspace2dDescriptor(
arg.invariant_total_length, arg.blkGroupSize);
using InGridDesc_M_K = decltype(in_grid_desc_m_k);
using WorkspaceDesc_M_K = decltype(ws_desc_m_k);
using GridwiseReduce =
GridwiseReduction_mk_to_mk_multiblock_partial_reduce<InDataType,
AccDataType,
IndexDataType,
InGridDesc_M_K,
WorkspaceDesc_M_K,
ReduceOperation,
InElementwiseOperation,
AccElementwiseOperation,
PropagateNan,
BlockSize,
MThreadClusterSize,
KThreadClusterSize,
MThreadSliceSize,
KThreadSliceSize,
InSrcVectorDim,
InSrcVectorSize,
OutDstVectorSize>;
float avg_time = 0;
const auto kernel = kernel_partial_reduce_multiblock<GridwiseReduce,
NeedIndices,
InDataType,
AccDataType,
IndexDataType,
InGridDesc_M_K,
WorkspaceDesc_M_K,
InElementwiseOperation,
AccElementwiseOperation>;
avg_time = launch_and_time_kernel(stream_config,
kernel,
dim3(arg.gridSize),
dim3(BlockSize),
0,
in_grid_desc_m_k,
ws_desc_m_k,
arg.in_elementwise_op_,
arg.acc_elementwise_op_,
arg.blkGroupSize,
arg.kBlockTileIterations,
arg.in_dev_,
arg.workspace_dev_,
arg.workspace_indices_dev_);
return (avg_time);
};
float Run(const BaseArgument* p_arg,
const StreamConfig& stream_config = StreamConfig{}) override
{
return Run(*dynamic_cast<const Argument*>(p_arg), stream_config);
}
};
bool IsSupportedArgument(const BaseArgument* p_arg) override
{
const Argument* pArg = dynamic_cast<const Argument*>(p_arg);
if constexpr(OutDstVectorSize != 1)
return (false);
if constexpr(InSrcVectorDim == 0)
{
if constexpr(NumInvariantDim == 0)
{
return (false);
}
else
{
if(pArg->inStrides_[NumInvariantDim - 1] != 1)
return (false);
if(pArg->invariant_lowest_length % InSrcVectorSize != 0)
return (false);
};
}
else
{
if(pArg->inStrides_[Rank - 1] != 1)
return (false);
if(pArg->reduce_lowest_length % InSrcVectorSize != 0)
return (false);
};
// cases with small reduce_total_length should be handled by the BlockWise method
if(pArg->reduce_total_length <= BlockSize * KThreadSliceSize)
return (false);
return (true);
};
std::vector<int> GetWorkspace2dLengths(const BaseArgument* p_arg) override
{
const Argument* pArg = dynamic_cast<const Argument*>(p_arg);
return (
std::vector<int>{static_cast<int>(pArg->invariant_total_length), pArg->blkGroupSize});
};
std::unique_ptr<BaseArgument>
MakeArgumentPointer(const std::vector<int> inLengths,
const std::vector<int> inStrides,
const std::vector<int> outLengths,
const std::vector<int> outStrides,
const std::vector<int> reduceDims,
float alpha,
float beta,
const void* in_dev,
void* out_dev,
void* out_indices_dev,
void* workspace_dev,
const InElementwiseOperation in_elementwise_op,
const AccElementwiseOperation acc_elementwise_op) override
{
return std::make_unique<Argument>(inLengths,
inStrides,
outLengths,
outStrides,
reduceDims,
alpha,
beta,
static_cast<const InDataType*>(in_dev),
static_cast<OutDataType*>(out_dev),
static_cast<IndexDataType*>(out_indices_dev),
static_cast<AccDataType*>(workspace_dev),
in_elementwise_op,
acc_elementwise_op);
};
std::unique_ptr<BaseInvoker> MakeInvokerPointer() override
{
return std::make_unique<Invoker>();
};
std::string GetTypeString() const override
{
auto str = std::stringstream();
// clang-format off
str << "DeviceReduceMultiBlockPartialReduce<" << BlockSize << ",";
str << "M_C" << MThreadClusterSize << "_S" << MThreadSliceSize << ",";
str << "K_C" << KThreadClusterSize << "_S" << KThreadSliceSize << ",";
str << "InSrcVectorDim_" << InSrcVectorDim << "_InSrcVectorSize_" << InSrcVectorSize << "_OutDstVectorSize_" << OutDstVectorSize << ">";
// clang-format on
return str.str();
}
};
} // namespace device
} // namespace tensor_operation
} // namespace ck
#endif
......@@ -6,6 +6,7 @@
#include "device.hpp"
#include "device_reduce.hpp"
#include "device_reduce_common.hpp"
#include "gridwise_2d_reduction_multiblock.hpp"
#include "gridwise_2d_reduction_threadwise.hpp"
namespace ck {
......@@ -19,22 +20,19 @@ template <typename InDataType,
index_t NumReduceDim,
typename ReduceOperation,
typename InElementwiseOperation,
typename OutElementwiseOperation,
typename AccElementwiseOperation,
bool PropagateNan,
bool NeedIndices,
bool OutputIndex,
bool HaveIndexInputIfOutputIndex,
index_t BlockSize,
index_t MThreadClusterSize,
index_t KThreadClusterSize,
index_t MThreadSliceSize,
index_t KThreadSliceSize,
index_t InSrcVectorDim,
index_t InSrcVectorSize,
index_t OutDstVectorSize>
struct DeviceReduceThreadWise : public DeviceReduce<InElementwiseOperation, OutElementwiseOperation>
struct DeviceReduceThreadWise : public DeviceReduce<InElementwiseOperation, AccElementwiseOperation>
{
static_assert(Rank <= 6, "Bigger Rank size is not supported!");
static_assert((BlockSize == MThreadClusterSize) && (KThreadClusterSize == 1),
"Threadwise can only be called with KThreadClusterSize be 1 !");
static_assert(((InSrcVectorDim == 0 && MThreadSliceSize % InSrcVectorSize == 0) ||
(InSrcVectorDim == 1 && KThreadSliceSize % InSrcVectorSize == 0)) &&
......@@ -43,7 +41,7 @@ struct DeviceReduceThreadWise : public DeviceReduce<InElementwiseOperation, OutE
using IndexDataType = int32_t;
static constexpr bool BetaIsZero = NeedIndices;
static constexpr bool HaveIndexInput = OutputIndex && HaveIndexInputIfOutputIndex;
static constexpr index_t NumInvariantDim = Rank - NumReduceDim;
......@@ -51,11 +49,11 @@ struct DeviceReduceThreadWise : public DeviceReduce<InElementwiseOperation, OutE
static constexpr index_t numDstDim = (NumInvariantDim == 0) ? 1 : NumInvariantDim;
static constexpr bool reduceAllDim = (NumInvariantDim == 0);
static constexpr int M_BlockTileSize = MThreadClusterSize * MThreadSliceSize;
static constexpr int K_BlockTileSize = KThreadClusterSize * KThreadSliceSize;
static constexpr index_t M_BlockTileSize = BlockSize * MThreadSliceSize;
static constexpr index_t K_BlockTileSize = 1 * KThreadSliceSize;
static auto MakeSrc2dDescriptor(const std::vector<int>& inLengths,
const std::vector<int>& inStrides)
static auto MakeSrc2dDescriptor(const std::vector<index_t>& inLengths,
const std::vector<index_t>& inStrides)
{
const auto tupleSrcLengths = make_tuple_from_array(inLengths, Number<numSrcDim>{});
const auto tupleSrcStrides = make_tuple_from_array(inStrides, Number<numSrcDim>{});
......@@ -114,8 +112,8 @@ struct DeviceReduceThreadWise : public DeviceReduce<InElementwiseOperation, OutE
return (in_grid_desc_m_k_padded);
};
static auto MakeDst1dDescriptor(const std::vector<int>& outLengths,
const std::vector<int>& outStrides)
static auto MakeDst1dDescriptor(const std::vector<index_t>& outLengths,
const std::vector<index_t>& outStrides)
{
const auto tupleDstLengths = make_tuple_from_array(outLengths, Number<numDstDim>{});
const auto tupleDstStrides = make_tuple_from_array(outStrides, Number<numDstDim>{});
......@@ -143,30 +141,26 @@ struct DeviceReduceThreadWise : public DeviceReduce<InElementwiseOperation, OutE
struct Argument : public BaseArgument
{
Argument(const std::vector<int> inLengths,
const std::vector<int> inStrides,
const std::vector<int> outLengths,
const std::vector<int> outStrides,
Argument(const std::vector<index_t> inLengths,
const std::vector<index_t> inStrides,
const std::vector<index_t> outLengths,
const std::vector<index_t> outStrides,
const std::vector<int> reduceDims,
float alpha,
float beta,
const InDataType* in_dev,
OutDataType* out_dev,
IndexDataType* out_indices_dev,
AccDataType* workspace_dev,
IndexDataType* out_index_dev,
const InElementwiseOperation in_elementwise_op,
const OutElementwiseOperation acc_elementwise_op)
const AccElementwiseOperation acc_elementwise_op)
: outLengths_{outLengths},
outStrides_{outStrides},
in_dev_{in_dev},
out_dev_{out_dev},
out_indices_dev_{out_indices_dev},
out_index_dev_{out_index_dev},
in_elementwise_op_{in_elementwise_op},
acc_elementwise_op_{acc_elementwise_op}
{
(void)workspace_dev;
inLengths_ = shuffle_tensor_dimensions<Rank, NumReduceDim>(inLengths, reduceDims);
inStrides_ = shuffle_tensor_dimensions<Rank, NumReduceDim>(inStrides, reduceDims);
......@@ -183,30 +177,33 @@ struct DeviceReduceThreadWise : public DeviceReduce<InElementwiseOperation, OutE
reduce_lowest_length = inLengths_[Rank - 1];
numBlockTileIteration = (reduce_total_length + K_BlockTileSize - 1) / K_BlockTileSize;
gridSize = math::integer_least_multiple(invariant_total_length, M_BlockTileSize) /
M_BlockTileSize;
}
std::vector<int> inLengths_;
std::vector<int> inStrides_;
std::vector<int> outLengths_;
std::vector<int> outStrides_;
std::vector<index_t> inLengths_;
std::vector<index_t> inStrides_;
std::vector<index_t> outLengths_;
std::vector<index_t> outStrides_;
AccDataType alpha_;
AccDataType beta_;
const InDataType* in_dev_;
OutDataType* out_dev_;
IndexDataType* out_indices_dev_;
IndexDataType* out_index_dev_;
InElementwiseOperation in_elementwise_op_;
OutElementwiseOperation acc_elementwise_op_;
AccElementwiseOperation acc_elementwise_op_;
int invariant_lowest_length;
int reduce_lowest_length;
size_t invariant_total_length;
size_t reduce_total_length;
index_t invariant_lowest_length;
index_t reduce_lowest_length;
long_index_t invariant_total_length;
long_index_t reduce_total_length;
int numBlockTileIteration;
size_t gridSize;
};
......@@ -221,30 +218,30 @@ struct DeviceReduceThreadWise : public DeviceReduce<InElementwiseOperation, OutE
using InGridDesc_M_K = decltype(in_grid_desc_m_k);
using OutGridDesc_M = decltype(out_grid_desc_m);
using GridwiseReduce = GridwiseReduction_mk_to_m_threadwise<InDataType,
OutDataType,
AccDataType,
IndexDataType,
InGridDesc_M_K,
OutGridDesc_M,
ReduceOperation,
InElementwiseOperation,
OutElementwiseOperation,
PropagateNan,
BetaIsZero,
BlockSize,
MThreadClusterSize,
KThreadClusterSize,
MThreadSliceSize,
KThreadSliceSize,
InSrcVectorDim,
InSrcVectorSize,
OutDstVectorSize>;
float avg_time = 0;
using GridwiseReduce =
GridwiseReduction_mk_to_m_threadwise<InDataType,
OutDataType,
AccDataType,
IndexDataType,
InGridDesc_M_K,
OutGridDesc_M,
ReduceOperation,
InElementwiseOperation,
AccElementwiseOperation,
InMemoryDataOperationEnum::Set,
PropagateNan,
BlockSize,
MThreadSliceSize,
KThreadSliceSize,
InSrcVectorDim,
InSrcVectorSize,
OutDstVectorSize>;
const auto kernel = kernel_reduce_threadwise<GridwiseReduce,
NeedIndices,
OutputIndex,
HaveIndexInput,
InDataType,
OutDataType,
AccDataType,
......@@ -252,7 +249,7 @@ struct DeviceReduceThreadWise : public DeviceReduce<InElementwiseOperation, OutE
InGridDesc_M_K,
OutGridDesc_M,
InElementwiseOperation,
OutElementwiseOperation>;
AccElementwiseOperation>;
avg_time = launch_and_time_kernel(stream_config,
kernel,
......@@ -265,9 +262,10 @@ struct DeviceReduceThreadWise : public DeviceReduce<InElementwiseOperation, OutE
arg.acc_elementwise_op_,
arg.alpha_,
arg.in_dev_,
nullptr,
arg.beta_,
arg.out_dev_,
arg.out_indices_dev_);
arg.out_index_dev_);
return (avg_time);
};
......@@ -276,7 +274,7 @@ struct DeviceReduceThreadWise : public DeviceReduce<InElementwiseOperation, OutE
const StreamConfig& stream_config = StreamConfig{}) override
{
return Run(*dynamic_cast<const Argument*>(p_arg), stream_config);
}
};
};
bool IsSupportedArgument(const BaseArgument* p_arg) override
......@@ -311,9 +309,7 @@ struct DeviceReduceThreadWise : public DeviceReduce<InElementwiseOperation, OutE
if(pArg->invariant_lowest_length % OutDstVectorSize != 0)
return (false);
// TODO: remove this. Should return true, as long as this DeviceOP instance support this
// case for bigger reduce_total_length size, we are supposed to use BlockWise method for
// better performance
// cases with big reduce_total_length should be handled by Blockwise kernel
if(pArg->reduce_total_length / KThreadSliceSize >= 32)
return (false);
......@@ -321,20 +317,22 @@ struct DeviceReduceThreadWise : public DeviceReduce<InElementwiseOperation, OutE
};
std::unique_ptr<BaseArgument>
MakeArgumentPointer(const std::vector<int> inLengths,
const std::vector<int> inStrides,
const std::vector<int> outLengths,
const std::vector<int> outStrides,
MakeArgumentPointer(const std::vector<index_t> inLengths,
const std::vector<index_t> inStrides,
const std::vector<index_t> outLengths,
const std::vector<index_t> outStrides,
const std::vector<int> reduceDims,
float alpha,
float beta,
const void* in_dev,
const void* in_index_dev,
void* out_dev,
void* out_indices_dev,
void* workspace_dev,
void* out_index_dev,
const InElementwiseOperation in_elementwise_op,
const OutElementwiseOperation acc_elementwise_op) override
const AccElementwiseOperation acc_elementwise_op) override
{
(void)in_index_dev;
return std::make_unique<Argument>(inLengths,
inStrides,
outLengths,
......@@ -344,8 +342,7 @@ struct DeviceReduceThreadWise : public DeviceReduce<InElementwiseOperation, OutE
beta,
static_cast<const InDataType*>(in_dev),
static_cast<OutDataType*>(out_dev),
static_cast<IndexDataType*>(out_indices_dev),
static_cast<AccDataType*>(workspace_dev),
static_cast<IndexDataType*>(out_index_dev),
in_elementwise_op,
acc_elementwise_op);
};
......@@ -360,9 +357,9 @@ struct DeviceReduceThreadWise : public DeviceReduce<InElementwiseOperation, OutE
auto str = std::stringstream();
// clang-format off
str << "DeviceReducceThreadWise<" << BlockSize << ",";
str << "M_C" << MThreadClusterSize << "_S" << MThreadSliceSize << ",";
str << "K_C" << KThreadClusterSize << "_S" << KThreadSliceSize << ",";
str << "DeviceReduceThreadWise<" << BlockSize << ",";
str << "M_C" << BlockSize << "_S" << MThreadSliceSize << ",";
str << "K_C" << 1 << "_S" << KThreadSliceSize << ",";
str << "InSrcVectorDim_" << InSrcVectorDim << "_InSrcVectorSize_" << InSrcVectorSize << "_OutDstVectorSize_" << OutDstVectorSize << ">";
// clang-format on
......
......@@ -8,6 +8,237 @@
namespace ck {
// Rows of column-vectors
template <index_t MPerBlock,
index_t NPerBlock,
typename CGridDesc_M_N,
bool DeviceCTileIndexCheck = false>
struct BlockToCTileMap_M00_N0_M01
{
static constexpr auto I0 = Number<0>{};
static constexpr auto I1 = Number<1>{};
static constexpr auto I2 = Number<2>{};
static constexpr auto I3 = Number<3>{};
__host__ __device__ BlockToCTileMap_M00_N0_M01() = default;
__host__ __device__ BlockToCTileMap_M00_N0_M01(const CGridDesc_M_N& c_grid_desc_m_n,
index_t M01 = 1)
: M01_(M01), underlying_map_(GetBlockToCTileMap(c_grid_desc_m_n, M01))
{
}
__host__ constexpr index_t CalculateGridSize(const CGridDesc_M_N& c_grid_desc_m_n) const
{
const auto M0 = math::integer_divide_ceil(c_grid_desc_m_n.GetLength(I0), MPerBlock);
const auto N0 = math::integer_divide_ceil(c_grid_desc_m_n.GetLength(I1), NPerBlock);
const auto M00 = math::integer_divide_ceil(M0, M01_);
const index_t grid_size = M00 * M01_ * N0;
return grid_size;
}
template <typename TopIdx>
__host__ __device__ constexpr auto CalculateBottomIndex(const TopIdx& idx_top) const
{
return underlying_map_.CalculateBottomIndex(idx_top);
}
template <typename CTileIdx, typename CTileDim>
__host__ __device__ bool ValidCTileIndex(const CTileIdx& c_tile_idx,
const CTileDim& c_tile_dim) const
{
if constexpr(DeviceCTileIndexCheck)
return DefaultValidCTileIndex(c_tile_idx, c_tile_dim);
else
return true;
}
__host__ bool CheckValidity(const CGridDesc_M_N& c_grid_desc_m_n) const
{
if constexpr(DeviceCTileIndexCheck)
return true; // validity check moved to kernel
const index_t M0 = math::integer_divide_ceil(c_grid_desc_m_n.GetLength(I0), MPerBlock);
if(M0 % M01_ == 0)
{
return true;
}
else
{
return false;
}
}
private:
__host__ __device__ static constexpr auto
GetBlockToCTileMap(const CGridDesc_M_N& c_grid_desc_m_n, index_t M01)
{
const auto M0 = math::integer_divide_ceil(c_grid_desc_m_n.GetLength(I0), MPerBlock);
const auto N0 = math::integer_divide_ceil(c_grid_desc_m_n.GetLength(I1), NPerBlock);
const auto M00 = math::integer_divide_ceil(M0, M01);
const auto m00_n0_m01_to_m0_n0_block_cluster_adaptor = make_single_stage_tensor_adaptor(
make_tuple(make_insert_transform(1),
make_unmerge_transform(make_tuple(M00, M01)),
make_pass_through_transform(make_tuple(N0))),
make_tuple(Sequence<>{}, Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<0>{}, Sequence<1, 3>{}, Sequence<2>{}));
const auto cblockid_to_m00_n0_m01_block_cluster_adaptor = make_single_stage_tensor_adaptor(
make_tuple(make_merge_transform(make_tuple(1, M00, N0, M01))),
make_tuple(Sequence<0, 1, 2, 3>{}),
make_tuple(Sequence<0>{}));
const auto cblockid_to_m0_n0_block_cluster_adaptor =
chain_tensor_adaptors(m00_n0_m01_to_m0_n0_block_cluster_adaptor,
cblockid_to_m00_n0_m01_block_cluster_adaptor);
return cblockid_to_m0_n0_block_cluster_adaptor;
}
index_t M01_;
using UnderlyingMap = decltype(GetBlockToCTileMap(CGridDesc_M_N{}, 1));
UnderlyingMap underlying_map_;
};
// Rows of column-vectors
// This C-tile map dynamically adjusts M01 when C-tile index is out of range
template <index_t MPerBlock, index_t NPerBlock, typename CGridDesc_M_N>
struct BlockToCTileMap_M00_N0_M01Adapt
{
static constexpr auto I0 = Number<0>{};
static constexpr auto I1 = Number<1>{};
static constexpr auto I2 = Number<2>{};
static constexpr auto I3 = Number<3>{};
__host__ __device__ BlockToCTileMap_M00_N0_M01Adapt() = default;
__host__ __device__ BlockToCTileMap_M00_N0_M01Adapt(const CGridDesc_M_N& c_grid_desc_m_n,
index_t M01 = 8)
: M01_(M01), c_grid_desc_m_n_(c_grid_desc_m_n)
{
}
__host__ constexpr index_t CalculateGridSize(const CGridDesc_M_N& c_grid_desc_m_n) const
{
const auto M0 = math::integer_divide_ceil(c_grid_desc_m_n.GetLength(I0), MPerBlock);
const auto N0 = math::integer_divide_ceil(c_grid_desc_m_n.GetLength(I1), NPerBlock);
const index_t grid_size = M0 * N0;
return grid_size;
}
template <typename TopIdx>
__host__ __device__ constexpr auto CalculateBottomIndex(const TopIdx& idx_top) const
{
auto block_1d_id = idx_top[I0];
const auto M0 = math::integer_divide_ceil(c_grid_desc_m_n_.GetLength(I0), MPerBlock);
const auto N0 = math::integer_divide_ceil(c_grid_desc_m_n_.GetLength(I1), NPerBlock);
block_1d_id = block_1d_id % (M0 * N0); // swallow batch index
index_t idx_N0 = block_1d_id % N0;
index_t idx_M0 = block_1d_id / N0;
const auto M01_adapt = (idx_M0 < M0 - M0 % M01_) ? M01_ : M0 % M01_;
index_t idx_M00 = idx_M0 / M01_;
index_t idx_M01 = idx_M0 % M01_;
index_t idx_N0_M01_local = idx_N0 + idx_M01 * N0;
return make_tuple(idx_N0_M01_local % M01_adapt + idx_M00 * M01_,
idx_N0_M01_local / M01_adapt);
}
template <typename CTileIdx, typename CTileDim>
__host__ __device__ bool ValidCTileIndex(const CTileIdx& /* c_tile_idx */,
const CTileDim& /* c_tile_dim */) const
{
return true; // always valid provided that user gets grid size from CalculateGridSize()
}
__host__ bool CheckValidity(const CGridDesc_M_N& /* c_grid_desc_m_n */) const { return true; }
private:
index_t M01_;
CGridDesc_M_N c_grid_desc_m_n_;
};
// 2D slices of column-vectors in 3D space
// This C-tile map dynamically adjusts M01 when C-tile index is out of range
template <index_t MPerBlock, index_t NPerBlock, typename CGridDesc_M_N>
struct BlockToCTileMap_KSplit_M00_N0_M01Adapt
{
static constexpr auto I0 = Number<0>{};
static constexpr auto I1 = Number<1>{};
static constexpr auto I2 = Number<2>{};
static constexpr auto I3 = Number<3>{};
__host__ __device__ BlockToCTileMap_KSplit_M00_N0_M01Adapt() = default;
__host__ __device__ BlockToCTileMap_KSplit_M00_N0_M01Adapt(const CGridDesc_M_N& c_grid_desc_m_n,
index_t M01 = 8,
index_t KSplit = 1)
: M01_(M01), KSplit_(KSplit), c_grid_desc_m_n_(c_grid_desc_m_n)
{
}
__host__ constexpr index_t CalculateGridSize(const CGridDesc_M_N& c_grid_desc_m_n) const
{
const auto M0 = math::integer_divide_ceil(c_grid_desc_m_n.GetLength(I0), MPerBlock);
const auto N0 = math::integer_divide_ceil(c_grid_desc_m_n.GetLength(I1), NPerBlock);
const index_t grid_size = M0 * N0 * KSplit_;
return grid_size;
}
template <typename TopIdx>
__host__ __device__ constexpr auto CalculateBottomIndex(const TopIdx& idx_top) const
{
auto block_1d_id = idx_top[I0];
const auto M0 = math::integer_divide_ceil(c_grid_desc_m_n_.GetLength(I0), MPerBlock);
const auto N0 = math::integer_divide_ceil(c_grid_desc_m_n_.GetLength(I1), NPerBlock);
const index_t idx_ksplit = block_1d_id / (M0 * N0);
block_1d_id = block_1d_id % (M0 * N0);
index_t idx_N0 = block_1d_id % N0;
index_t idx_M0 = block_1d_id / N0;
const auto M01_adapt = (idx_M0 < M0 - M0 % M01_) ? M01_ : M0 % M01_;
index_t idx_M00 = idx_M0 / M01_;
index_t idx_M01 = idx_M0 % M01_;
index_t idx_N0_M01_local = idx_N0 + idx_M01 * N0;
return make_tuple(idx_ksplit,
idx_N0_M01_local % M01_adapt + idx_M00 * M01_,
idx_N0_M01_local / M01_adapt);
}
template <typename CTileIdx, typename CTileDim>
__host__ __device__ bool ValidCTileIndex(const CTileIdx& /* c_tile_idx */,
const CTileDim& /* c_tile_dim */) const
{
return true; // always valid provided that user gets grid size from CalculateGridSize()
}
__host__ bool CheckValidity(const CGridDesc_M_N& /* c_grid_desc_m_n */) const { return true; }
private:
index_t M01_;
index_t KSplit_;
CGridDesc_M_N c_grid_desc_m_n_;
};
// Blocks of row-vectors
template <index_t MPerBlock,
index_t NPerBlock,
......
/*******************************************************************************
*
* MIT License
*
* Copyright (c) 2021 Advanced Micro Devices, Inc.
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
* in the Software without restriction, including without limitation the rights
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
* copies of the Software, and to permit persons to whom the Software is
* furnished to do so, subject to the following conditions:
*
* The above copyright notice and this permission notice shall be included in all
* copies or substantial portions of the Software.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
* SOFTWARE.
*
*******************************************************************************/
#ifndef CK_GRIDWISE_2D_REDUCTION_BLOCKWISE_HPP
#define CK_GRIDWISE_2D_REDUCTION_BLOCKWISE_HPP
#include "data_type.hpp"
#include "reduction_common.hpp"
#include "reduction_operator.hpp"
#include "reduction_functions_accumulate.hpp"
#include "reduction_functions_blockwise.hpp"
#include "reduction_functions_threadwise.hpp"
#include "threadwise_tensor_slice_transfer.hpp"
#include "cluster_descriptor.hpp"
#include "element_wise_operation.hpp"
namespace ck {
template <typename GridwiseReduction,
bool NeedIndices,
typename InDataType,
typename OutDataType,
typename AccDataType,
typename IndexDataType,
typename InGridDesc_M_K,
typename OutGridDesc_M,
typename InElementwiseOperation,
typename OutElementwiseOperation>
__global__ void kernel_reduce_blockwise(const InGridDesc_M_K in_grid_desc_m_k,
const OutGridDesc_M out_grid_desc_m,
const InElementwiseOperation in_elementwise_op,
const OutElementwiseOperation acc_elementwise_op,
AccDataType alpha,
const InDataType* const __restrict__ p_in_global,
AccDataType beta,
OutDataType* const __restrict__ p_out_global,
const IndexDataType* const __restrict__ p_ws_indices_global,
IndexDataType* const __restrict__ p_indices_global)
{
if constexpr(!NeedIndices)
{
constexpr bool IsSecondCall = false;
GridwiseReduction::template Run<IsSecondCall>(in_grid_desc_m_k,
out_grid_desc_m,
in_elementwise_op,
acc_elementwise_op,
alpha,
p_in_global,
beta,
p_out_global,
p_ws_indices_global,
p_indices_global);
}
else
{
GridwiseReduction::RunWithIndex(in_grid_desc_m_k,
out_grid_desc_m,
in_elementwise_op,
acc_elementwise_op,
alpha,
p_in_global,
beta,
p_out_global,
p_ws_indices_global,
p_indices_global);
};
};
template <typename GridwiseReduction,
bool NeedIndices,
typename InDataType,
typename OutDataType,
typename AccDataType,
typename IndexDataType,
typename InGridDesc_M_K,
typename OutGridDesc_M,
typename InElementwiseOperation,
typename OutElementwiseOperation>
__global__ void
kernel_reduce_blockwise_second_call(const InGridDesc_M_K in_grid_desc_m_k,
const OutGridDesc_M out_grid_desc_m,
const InElementwiseOperation in_elementwise_op,
const OutElementwiseOperation acc_elementwise_op,
AccDataType alpha,
const InDataType* const __restrict__ p_in_global,
AccDataType beta,
OutDataType* const __restrict__ p_out_global,
const IndexDataType* const __restrict__ p_ws_indices_global,
IndexDataType* const __restrict__ p_indices_global)
{
if constexpr(!NeedIndices)
{
constexpr bool IsSecondCall = true;
GridwiseReduction::template Run<IsSecondCall>(in_grid_desc_m_k,
out_grid_desc_m,
in_elementwise_op,
acc_elementwise_op,
alpha,
p_in_global,
beta,
p_out_global,
p_ws_indices_global,
p_indices_global);
}
else
{
GridwiseReduction::RunSecondCallWithIndex(in_grid_desc_m_k,
out_grid_desc_m,
in_elementwise_op,
acc_elementwise_op,
alpha,
p_in_global,
beta,
p_out_global,
p_ws_indices_global,
p_indices_global);
};
};
template <typename InDataType,
typename OutDataType,
typename AccDataType,
typename IndexDataType,
typename InGridDesc_M_K,
typename OutGridDesc_M,
typename ReduceOperation,
typename InElementwiseOperation,
typename OutElementwiseOperation,
bool PropagateNan,
bool BetaIsZero,
index_t BlockSize,
index_t MThreadClusterSize,
index_t KThreadClusterSize,
index_t MThreadSliceSize,
index_t KThreadSliceSize,
index_t InSrcVectorDim,
index_t InSrcVectorSize,
index_t OutDstVectorSize>
struct GridwiseReduction_mk_to_m_blockwise
{
static_assert(((InSrcVectorDim == 0 && MThreadSliceSize % InSrcVectorSize == 0) ||
(InSrcVectorDim == 1 && KThreadSliceSize % InSrcVectorSize == 0)) &&
(MThreadSliceSize % OutDstVectorSize == 0),
"Invalid thread slice sizes and/or vector sizes configuration, please check!");
static constexpr bool reorder_thread_cluster = (InSrcVectorDim == 0);
using ThreadClusterLengths_M_K = Sequence<MThreadClusterSize, KThreadClusterSize>;
using ThreadBufferDimAccessOrder =
typename conditional<reorder_thread_cluster, Sequence<1, 0>, Sequence<0, 1>>::type;
using ThreadClusterArrangeOrder =
typename conditional<reorder_thread_cluster, Sequence<1, 0>, Sequence<0, 1>>::type;
static constexpr auto thread_cluster_desc =
make_cluster_descriptor(ThreadClusterLengths_M_K{}, ThreadClusterArrangeOrder{});
using ThreadReduceSrcDesc_M_K = decltype(make_naive_tensor_descriptor_packed(
make_tuple(Number<MThreadSliceSize>{}, Number<KThreadSliceSize>{})));
using ThreadReduceDstDesc_M =
decltype(make_naive_tensor_descriptor_packed(make_tuple(Number<MThreadSliceSize>{})));
using PassThroughOp = tensor_operation::element_wise::PassThrough;
static constexpr auto I0 = Number<0>{};
static constexpr auto I1 = Number<1>{};
static constexpr index_t M_BlockTileSize = MThreadClusterSize * MThreadSliceSize;
static constexpr index_t K_BlockTileSize = KThreadClusterSize * KThreadSliceSize;
template <bool IsSecondCall>
__device__ static void Run(const InGridDesc_M_K& in_grid_desc_m_k,
const OutGridDesc_M& out_grid_desc_m,
const InElementwiseOperation& in_elementwise_op,
const OutElementwiseOperation& acc_elementwise_op,
AccDataType alpha,
const InDataType* const __restrict__ p_in_global,
AccDataType beta,
OutDataType* const __restrict__ p_out_global,
const IndexDataType* const __restrict__ p_ws_indices_global,
IndexDataType* const __restrict__ p_indices_global)
{
if constexpr(IsSecondCall)
{
static_assert(InSrcVectorDim == 1,
"InSrcVectorDim must be 1 for BlockwiseSecondCall, please check!");
};
using BlockwiseReduce = PartitionedBlockwiseReduction<AccDataType,
BlockSize,
ThreadClusterLengths_M_K,
ThreadClusterArrangeOrder,
ReduceOperation,
PropagateNan>;
using ThreadwiseReduce = ThreadwiseReduction<AccDataType,
ThreadReduceSrcDesc_M_K,
ThreadReduceDstDesc_M,
ReduceOperation,
PropagateNan>;
(void)p_ws_indices_global;
(void)p_indices_global;
// LDS
__shared__ AccDataType p_reduce_work_buffer[BlockSize];
const auto zeroVal = ReduceOperation::GetReductionZeroVal();
const auto in_global_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_in_global, in_grid_desc_m_k.GetElementSpaceSize(), type_convert<InDataType>(zeroVal));
auto out_global_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_out_global, out_grid_desc_m.GetElementSpaceSize());
auto reduce_work_buf =
make_dynamic_buffer<AddressSpaceEnum::Lds>(p_reduce_work_buffer, BlockSize);
StaticBuffer<AddressSpaceEnum::Vgpr, AccDataType, MThreadSliceSize * KThreadSliceSize, true>
in_thread_buf;
StaticBuffer<AddressSpaceEnum::Vgpr, AccDataType, MThreadSliceSize, true> accu_value_buf;
static_for<0, MThreadSliceSize, 1>{}([&](auto I) { accu_value_buf(I) = zeroVal; });
const auto toReduceLength = in_grid_desc_m_k.GetLength(Number<1>{});
const index_t thread_local_id = get_thread_local_1d_id();
const index_t block_global_1d_id = get_block_1d_id();
const auto thread_cluster_idx =
thread_cluster_desc.CalculateBottomIndex(make_multi_index(thread_local_id));
const auto thread_m_cluster_id = thread_cluster_idx[I0];
const auto thread_k_cluster_id = thread_cluster_idx[I1];
using ThreadBufferLengths = Sequence<MThreadSliceSize, KThreadSliceSize>;
constexpr auto thread_buffer_desc = make_naive_tensor_descriptor_packed(
make_tuple(Number<MThreadSliceSize>{}, Number<KThreadSliceSize>{}));
auto threadwise_src_load = ThreadwiseTensorSliceTransfer_v2<InDataType,
AccDataType,
InGridDesc_M_K,
decltype(thread_buffer_desc),
ThreadBufferLengths,
ThreadBufferDimAccessOrder,
InSrcVectorDim,
InSrcVectorSize,
1,
false>(
in_grid_desc_m_k,
make_multi_index(block_global_1d_id * M_BlockTileSize +
thread_m_cluster_id * MThreadSliceSize,
thread_k_cluster_id * KThreadSliceSize));
constexpr auto in_thread_copy_step = make_multi_index(0, K_BlockTileSize);
const index_t toReduceTiles = (toReduceLength + K_BlockTileSize - 1) / K_BlockTileSize;
index_t reducedTiles = 0;
do
{
threadwise_src_load.Run(in_grid_desc_m_k,
in_global_buf,
thread_buffer_desc,
make_tuple(I0, I0),
in_thread_buf);
static_for<0, MThreadSliceSize, 1>{}([&](auto iM) {
// do element-wise pre-reduction operation
static_for<0, KThreadSliceSize, 1>{}([&](auto iK) {
constexpr auto offset = thread_buffer_desc.CalculateOffset(make_tuple(iM, iK));
in_elementwise_op(in_thread_buf(Number<offset>{}),
in_thread_buf(Number<offset>{}));
});
});
ThreadwiseReduce::Reduce(in_thread_buf, accu_value_buf);
threadwise_src_load.MoveSrcSliceWindow(in_grid_desc_m_k, in_thread_copy_step);
reducedTiles++;
} while(reducedTiles < toReduceTiles);
constexpr auto reduced_data_desc = ThreadReduceDstDesc_M{};
static_for<0, MThreadSliceSize, 1>{}(
[&](auto I) { BlockwiseReduce::Reduce(reduce_work_buf, accu_value_buf(I)); });
static_for<0, MThreadSliceSize, 1>{}([&](auto I) {
if(thread_k_cluster_id == 0)
{
acc_elementwise_op(accu_value_buf(I), accu_value_buf(I));
accu_value_buf(I) *= alpha;
}
});
if(thread_k_cluster_id == 0)
{
if constexpr(!BetaIsZero)
{
if(!float_equal_zero{}(beta))
{
StaticBuffer<AddressSpaceEnum::Vgpr, OutDataType, MThreadSliceSize, true>
priorDstValueBuf;
auto threadwise_dst_load =
ThreadwiseTensorSliceTransfer_v2<OutDataType,
OutDataType,
OutGridDesc_M,
decltype(reduced_data_desc),
Sequence<MThreadSliceSize>,
Sequence<0>,
0,
OutDstVectorSize,
1,
false>(
out_grid_desc_m,
make_multi_index(block_global_1d_id * M_BlockTileSize +
thread_m_cluster_id * MThreadSliceSize));
threadwise_dst_load.Run(out_grid_desc_m,
out_global_buf,
reduced_data_desc,
make_tuple(I0),
priorDstValueBuf);
static_for<0, MThreadSliceSize, 1>{}([&](auto I) {
accu_value_buf(I) += type_convert<AccDataType>(priorDstValueBuf[I]) * beta;
});
};
};
auto threadwise_dst_store =
ThreadwiseTensorSliceTransfer_v1r3<AccDataType,
OutDataType,
decltype(reduced_data_desc),
OutGridDesc_M,
PassThroughOp,
Sequence<MThreadSliceSize>,
Sequence<0>,
0,
OutDstVectorSize,
InMemoryDataOperationEnum::Set,
1,
true>(
out_grid_desc_m,
make_multi_index(block_global_1d_id * M_BlockTileSize +
thread_m_cluster_id * MThreadSliceSize),
PassThroughOp{});
threadwise_dst_store.Run(
reduced_data_desc, make_tuple(I0), accu_value_buf, out_grid_desc_m, out_global_buf);
}
};
__device__ static void RunWithIndex(const InGridDesc_M_K& in_grid_desc_m_k,
const OutGridDesc_M& out_grid_desc_m,
const InElementwiseOperation& in_elementwise_op,
const OutElementwiseOperation& acc_elementwise_op,
AccDataType alpha,
const InDataType* const __restrict__ p_in_global,
AccDataType beta,
OutDataType* const __restrict__ p_out_global,
const IndexDataType* const __restrict__ p_ws_indices_global,
IndexDataType* const __restrict__ p_indices_global)
{
using BlockwiseReduceWithIndex =
PartitionedBlockwiseReductionWithIndex<AccDataType,
IndexDataType,
BlockSize,
ThreadClusterLengths_M_K,
ThreadClusterArrangeOrder,
ReduceOperation,
PropagateNan>;
using AccumulationWithIndex = detail::AccumulateWithIndexAndNanCheck<PropagateNan,
ReduceOperation,
AccDataType,
IndexDataType>;
(void)p_ws_indices_global;
// LDS
__shared__ AccDataType p_reduce_work_val_buffer[BlockSize];
__shared__ IndexDataType p_reduce_work_idx_buffer[BlockSize];
const auto zeroVal = ReduceOperation::GetReductionZeroVal();
const auto in_global_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_in_global, in_grid_desc_m_k.GetElementSpaceSize(), type_convert<InDataType>(zeroVal));
auto out_global_val_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_out_global, out_grid_desc_m.GetElementSpaceSize());
auto out_global_idx_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_indices_global, out_grid_desc_m.GetElementSpaceSize());
auto reduce_work_val_buf =
make_dynamic_buffer<AddressSpaceEnum::Lds>(p_reduce_work_val_buffer, BlockSize);
auto reduce_work_idx_buf =
make_dynamic_buffer<AddressSpaceEnum::Lds>(p_reduce_work_idx_buffer, BlockSize);
StaticBuffer<AddressSpaceEnum::Vgpr, AccDataType, MThreadSliceSize * KThreadSliceSize, true>
in_thread_val_buf;
StaticBuffer<AddressSpaceEnum::Vgpr,
IndexDataType,
MThreadSliceSize * KThreadSliceSize,
true>
in_thread_idx_buf;
StaticBuffer<AddressSpaceEnum::Vgpr, AccDataType, MThreadSliceSize, true> accu_value_buf;
StaticBuffer<AddressSpaceEnum::Vgpr, IndexDataType, MThreadSliceSize, true> accu_index_buf;
const auto toReduceLength = in_grid_desc_m_k.GetLength(Number<1>{});
const index_t thread_local_id = get_thread_local_1d_id();
const index_t block_global_1d_id = get_block_1d_id();
const auto thread_cluster_idx =
thread_cluster_desc.CalculateBottomIndex(make_multi_index(thread_local_id));
const auto thread_m_cluster_id = thread_cluster_idx[I0];
const auto thread_k_cluster_id = thread_cluster_idx[I1];
using ThreadBufferLengths = Sequence<MThreadSliceSize, KThreadSliceSize>;
constexpr auto thread_buffer_desc = make_naive_tensor_descriptor_packed(
make_tuple(Number<MThreadSliceSize>{}, Number<KThreadSliceSize>{}));
auto threadwise_src_load = ThreadwiseTensorSliceTransfer_v2<InDataType,
AccDataType,
InGridDesc_M_K,
decltype(thread_buffer_desc),
ThreadBufferLengths,
ThreadBufferDimAccessOrder,
InSrcVectorDim,
InSrcVectorSize,
1,
false>(
in_grid_desc_m_k,
make_multi_index(block_global_1d_id * M_BlockTileSize +
thread_m_cluster_id * MThreadSliceSize,
thread_k_cluster_id * KThreadSliceSize));
index_t indexOffset = 0;
static_for<0, MThreadSliceSize, 1>{}([&](auto I) {
accu_value_buf(I) = zeroVal;
accu_index_buf(I) = 0;
});
constexpr auto in_thread_copy_step = make_multi_index(0, K_BlockTileSize);
const index_t toReduceTiles = (toReduceLength + K_BlockTileSize - 1) / K_BlockTileSize;
index_t reducedTiles = 0;
do
{
// load the thread slice
threadwise_src_load.Run(in_grid_desc_m_k,
in_global_buf,
thread_buffer_desc,
make_tuple(I0, I0),
in_thread_val_buf);
static_for<0, MThreadSliceSize, 1>{}([&](auto iM) {
static_for<0, KThreadSliceSize, 1>{}([&](auto iK) {
constexpr auto offset = thread_buffer_desc.CalculateOffset(make_tuple(iM, iK));
// initialize the indices for the per-thread to-reduce values
in_thread_idx_buf(Number<offset>{}) =
indexOffset + thread_k_cluster_id * KThreadSliceSize + iK();
// do element-wise pre-reduction operation
in_elementwise_op(in_thread_val_buf(Number<offset>{}),
in_thread_val_buf(Number<offset>{}));
});
AccDataType tmpValue = zeroVal;
IndexDataType tmpIndex = 0;
static_for<0, KThreadSliceSize, 1>{}([&](auto iK) {
constexpr auto offset = thread_buffer_desc.CalculateOffset(make_tuple(iM, iK));
AccumulationWithIndex::Calculate(tmpValue,
in_thread_val_buf[Number<offset>{}],
tmpIndex,
in_thread_idx_buf[Number<offset>{}]);
});
BlockwiseReduceWithIndex::Reduce(
reduce_work_val_buf, reduce_work_idx_buf, tmpValue, tmpIndex);
AccumulationWithIndex::Calculate(
accu_value_buf(iM), tmpValue, accu_index_buf(iM), tmpIndex);
});
threadwise_src_load.MoveSrcSliceWindow(in_grid_desc_m_k, in_thread_copy_step);
indexOffset += K_BlockTileSize;
reducedTiles++;
} while(reducedTiles < toReduceTiles);
constexpr auto reduced_data_desc = ThreadReduceDstDesc_M{};
static_for<0, MThreadSliceSize, 1>{}([&](auto I) {
if(thread_k_cluster_id == 0)
{
// for indiced operation, acc_elementwise_op shoud do nothing
acc_elementwise_op(accu_value_buf(I), accu_value_buf(I));
accu_value_buf(I) *= alpha;
}
});
if(thread_k_cluster_id == 0)
{
if constexpr(!BetaIsZero)
{
if(!float_equal_zero{}(beta))
{
StaticBuffer<AddressSpaceEnum::Vgpr, OutDataType, MThreadSliceSize, true>
priorDstValueBuf;
auto threadwise_dst_load =
ThreadwiseTensorSliceTransfer_v2<OutDataType,
OutDataType,
OutGridDesc_M,
decltype(reduced_data_desc),
Sequence<MThreadSliceSize>,
Sequence<0>,
0,
OutDstVectorSize,
1,
false>(
out_grid_desc_m,
make_multi_index(block_global_1d_id * M_BlockTileSize +
thread_m_cluster_id * MThreadSliceSize));
threadwise_dst_load.Run(out_grid_desc_m,
out_global_val_buf,
reduced_data_desc,
make_tuple(I0),
priorDstValueBuf);
static_for<0, MThreadSliceSize, 1>{}([&](auto I) {
accu_value_buf(I) += type_convert<AccDataType>(priorDstValueBuf[I]) * beta;
});
};
};
auto threadwise_dst_val_store =
ThreadwiseTensorSliceTransfer_v1r3<AccDataType,
OutDataType,
decltype(reduced_data_desc),
OutGridDesc_M,
PassThroughOp,
Sequence<MThreadSliceSize>,
Sequence<0>,
0,
OutDstVectorSize,
InMemoryDataOperationEnum::Set,
1,
false>(
out_grid_desc_m,
make_multi_index(block_global_1d_id * M_BlockTileSize +
thread_m_cluster_id * MThreadSliceSize),
PassThroughOp{});
auto threadwise_dst_idx_store =
ThreadwiseTensorSliceTransfer_v1r3<IndexDataType,
IndexDataType,
decltype(reduced_data_desc),
OutGridDesc_M,
PassThroughOp,
Sequence<MThreadSliceSize>,
Sequence<0>,
0,
OutDstVectorSize,
InMemoryDataOperationEnum::Set,
1,
false>(
out_grid_desc_m,
make_multi_index(block_global_1d_id * M_BlockTileSize +
thread_m_cluster_id * MThreadSliceSize),
PassThroughOp{});
threadwise_dst_val_store.Run(reduced_data_desc,
make_tuple(I0),
accu_value_buf,
out_grid_desc_m,
out_global_val_buf);
threadwise_dst_idx_store.Run(reduced_data_desc,
make_tuple(I0),
accu_index_buf,
out_grid_desc_m,
out_global_idx_buf);
}
};
__device__ static void
RunSecondCallWithIndex(const InGridDesc_M_K& in_grid_desc_m_k,
const OutGridDesc_M& out_grid_desc_m,
const InElementwiseOperation in_elementwise_op,
const OutElementwiseOperation acc_elementwise_op,
AccDataType alpha,
const InDataType* const __restrict__ p_ws_values_global,
AccDataType beta,
OutDataType* const __restrict__ p_out_global,
const IndexDataType* const __restrict__ p_ws_indices_global,
IndexDataType* const __restrict__ p_indices_global)
{
static_assert(InSrcVectorDim == 1,
"InSrcVectorDim must be 1 for BlockwiseSecondCall, please check!");
using BlockwiseReduceWithIndex =
PartitionedBlockwiseReductionWithIndex<AccDataType,
IndexDataType,
BlockSize,
Sequence<MThreadClusterSize, KThreadClusterSize>,
ThreadClusterArrangeOrder,
ReduceOperation,
PropagateNan>;
using AccumulationWithIndex = detail::AccumulateWithIndexAndNanCheck<PropagateNan,
ReduceOperation,
AccDataType,
IndexDataType>;
(void)in_elementwise_op;
// LDS
__shared__ AccDataType p_reduce_work_val_buffer[BlockSize];
__shared__ IndexDataType p_reduce_work_idx_buffer[BlockSize];
const auto zeroVal = ReduceOperation::GetReductionZeroVal();
const auto src_global_val_buf =
make_dynamic_buffer<AddressSpaceEnum::Global>(p_ws_values_global,
in_grid_desc_m_k.GetElementSpaceSize(),
type_convert<InDataType>(zeroVal));
const auto src_global_idx_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_ws_indices_global, in_grid_desc_m_k.GetElementSpaceSize());
auto out_global_val_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_out_global, out_grid_desc_m.GetElementSpaceSize());
auto out_global_idx_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_indices_global, out_grid_desc_m.GetElementSpaceSize());
auto reduce_work_val_buf =
make_dynamic_buffer<AddressSpaceEnum::Lds>(p_reduce_work_val_buffer, BlockSize);
auto reduce_work_idx_buf =
make_dynamic_buffer<AddressSpaceEnum::Lds>(p_reduce_work_idx_buffer, BlockSize);
StaticBuffer<AddressSpaceEnum::Vgpr, AccDataType, MThreadSliceSize * KThreadSliceSize, true>
in_thread_val_buf;
StaticBuffer<AddressSpaceEnum::Vgpr,
IndexDataType,
MThreadSliceSize * KThreadSliceSize,
true>
in_thread_idx_buf;
StaticBuffer<AddressSpaceEnum::Vgpr, AccDataType, MThreadSliceSize, true> accu_value_buf;
StaticBuffer<AddressSpaceEnum::Vgpr, IndexDataType, MThreadSliceSize, true> accu_index_buf;
const auto toReduceLength = in_grid_desc_m_k.GetLength(Number<1>{});
const index_t thread_local_id = get_thread_local_1d_id();
const index_t block_global_1d_id = get_block_1d_id();
const auto thread_cluster_idx =
thread_cluster_desc.CalculateBottomIndex(make_multi_index(thread_local_id));
const auto thread_m_cluster_id = thread_cluster_idx[I0];
const auto thread_k_cluster_id = thread_cluster_idx[I1];
using ThreadBufferLengths = Sequence<MThreadSliceSize, KThreadSliceSize>;
constexpr auto thread_buffer_desc = make_naive_tensor_descriptor_packed(
make_tuple(Number<MThreadSliceSize>{}, Number<KThreadSliceSize>{}));
auto threadwise_src_val_load =
ThreadwiseTensorSliceTransfer_v2<InDataType,
AccDataType,
InGridDesc_M_K,
decltype(thread_buffer_desc),
ThreadBufferLengths,
ThreadBufferDimAccessOrder,
InSrcVectorDim,
InSrcVectorSize,
1,
false>(
in_grid_desc_m_k,
make_multi_index(block_global_1d_id * M_BlockTileSize +
thread_m_cluster_id * MThreadSliceSize,
thread_k_cluster_id * KThreadSliceSize));
auto threadwise_src_idx_load =
ThreadwiseTensorSliceTransfer_v2<IndexDataType,
IndexDataType,
InGridDesc_M_K,
decltype(thread_buffer_desc),
ThreadBufferLengths,
ThreadBufferDimAccessOrder,
InSrcVectorDim,
InSrcVectorSize,
1,
false>(
in_grid_desc_m_k,
make_multi_index(block_global_1d_id * M_BlockTileSize +
thread_m_cluster_id * MThreadSliceSize,
thread_k_cluster_id * KThreadSliceSize));
static_for<0, MThreadSliceSize, 1>{}([&](auto I) {
accu_value_buf(I) = zeroVal;
accu_index_buf(I) = 0;
});
constexpr auto in_thread_copy_step = make_multi_index(0, K_BlockTileSize);
const index_t toReduceTiles = (toReduceLength + K_BlockTileSize - 1) / K_BlockTileSize;
index_t reducedTiles = 0;
do
{
// load the thread slice
threadwise_src_val_load.Run(in_grid_desc_m_k,
src_global_val_buf,
thread_buffer_desc,
make_tuple(I0, I0),
in_thread_val_buf);
threadwise_src_idx_load.Run(in_grid_desc_m_k,
src_global_idx_buf,
thread_buffer_desc,
make_tuple(I0, I0),
in_thread_idx_buf);
static_for<0, MThreadSliceSize, 1>{}([&](auto iM) {
AccDataType tmpValue = zeroVal;
IndexDataType tmpIndex = 0;
static_for<0, KThreadSliceSize, 1>{}([&](auto iK) {
constexpr auto offset = thread_buffer_desc.CalculateOffset(make_tuple(iM, iK));
AccumulationWithIndex::Calculate(tmpValue,
in_thread_val_buf[Number<offset>{}],
tmpIndex,
in_thread_idx_buf[Number<offset>{}]);
});
BlockwiseReduceWithIndex::Reduce(
reduce_work_val_buf, reduce_work_idx_buf, tmpValue, tmpIndex);
AccumulationWithIndex::Calculate(
accu_value_buf(iM), tmpValue, accu_index_buf(iM), tmpIndex);
});
threadwise_src_val_load.MoveSrcSliceWindow(in_grid_desc_m_k, in_thread_copy_step);
threadwise_src_idx_load.MoveSrcSliceWindow(in_grid_desc_m_k, in_thread_copy_step);
reducedTiles++;
} while(reducedTiles < toReduceTiles);
constexpr auto reduced_data_desc = ThreadReduceDstDesc_M{};
static_for<0, MThreadSliceSize, 1>{}([&](auto I) {
if(thread_k_cluster_id == 0)
{
// for indiced operation, acc_elementwise_op shoud do nothing
acc_elementwise_op(accu_value_buf(I), accu_value_buf(I));
accu_value_buf(I) *= alpha;
}
});
if(thread_k_cluster_id == 0)
{
if constexpr(!BetaIsZero)
{
if(!float_equal_zero{}(beta))
{
StaticBuffer<AddressSpaceEnum::Vgpr, OutDataType, MThreadSliceSize, true>
priorDstValueBuf;
auto threadwise_dst_load =
ThreadwiseTensorSliceTransfer_v2<OutDataType,
OutDataType,
OutGridDesc_M,
decltype(reduced_data_desc),
Sequence<MThreadSliceSize>,
Sequence<0>,
0,
OutDstVectorSize,
1,
true>(
out_grid_desc_m,
make_multi_index(block_global_1d_id * M_BlockTileSize +
thread_m_cluster_id * MThreadSliceSize));
threadwise_dst_load.Run(out_grid_desc_m,
out_global_val_buf,
reduced_data_desc,
make_tuple(I0),
priorDstValueBuf);
static_for<0, MThreadSliceSize, 1>{}([&](auto I) {
accu_value_buf(I) += type_convert<AccDataType>(priorDstValueBuf[I]) * beta;
});
};
};
auto threadwise_dst_val_store =
ThreadwiseTensorSliceTransfer_v1r3<AccDataType,
OutDataType,
decltype(reduced_data_desc),
OutGridDesc_M,
PassThroughOp,
Sequence<MThreadSliceSize>,
Sequence<0>,
0,
OutDstVectorSize,
InMemoryDataOperationEnum::Set,
1,
true>(
out_grid_desc_m,
make_multi_index(block_global_1d_id * M_BlockTileSize +
thread_m_cluster_id * MThreadSliceSize),
PassThroughOp{});
auto threadwise_dst_idx_store =
ThreadwiseTensorSliceTransfer_v1r3<IndexDataType,
IndexDataType,
decltype(reduced_data_desc),
OutGridDesc_M,
PassThroughOp,
Sequence<MThreadSliceSize>,
Sequence<0>,
0,
OutDstVectorSize,
InMemoryDataOperationEnum::Set,
1,
true>(
out_grid_desc_m,
make_multi_index(block_global_1d_id * M_BlockTileSize +
thread_m_cluster_id * MThreadSliceSize),
PassThroughOp{});
threadwise_dst_val_store.Run(reduced_data_desc,
make_tuple(I0),
accu_value_buf,
out_grid_desc_m,
out_global_val_buf);
threadwise_dst_idx_store.Run(reduced_data_desc,
make_tuple(I0),
accu_index_buf,
out_grid_desc_m,
out_global_idx_buf);
}
};
};
} // namespace ck
#endif
......@@ -23,75 +23,86 @@
* SOFTWARE.
*
*******************************************************************************/
#ifndef CK_GRIDWISE_2D_REDUCTION_MULTIBLOCK_PARTIAL_REDUCE_HPP
#define CK_GRIDWISE_2D_REDUCTION_MULTIBLOCK_PARTIAL_REDUCE_HPP
#ifndef CK_GRIDWISE_2D_REDUCTION_MULTIBLOCK_HPP
#define CK_GRIDWISE_2D_REDUCTION_MULTIBLOCK_HPP
#include "reduction_common.hpp"
#include "reduction_operator.hpp"
#include "reduction_functions_accumulate.hpp"
#include "reduction_functions_blockwise.hpp"
#include "reduction_functions_threadwise.hpp"
#include "threadwise_tensor_slice_transfer.hpp"
#include "cluster_descriptor.hpp"
#include "element_wise_operation.hpp"
namespace ck {
template <typename GridwiseReduction,
bool NeedIndices,
bool OutputIndex,
bool HaveIndexInput,
typename InDataType,
typename OutDataType,
typename AccDataType,
typename IndexDataType,
typename InGridDesc_M_K,
typename WorkspaceDesc_M_K,
typename OutGridDesc_M,
typename InElementwiseOperation,
typename AccElementwiseOperation>
__global__ void
kernel_partial_reduce_multiblock(const InGridDesc_M_K in_grid_desc_m_k,
const WorkspaceDesc_M_K workspace_desc_m_k,
const InElementwiseOperation in_elementwise_op,
const AccElementwiseOperation acc_elementwise_op,
index_t block_group_size,
index_t num_k_block_tile_iteration,
const InDataType* const __restrict__ p_src_global,
AccDataType* const __restrict__ p_ws_values_global,
IndexDataType* const __restrict__ p_ws_indices_global)
__global__ void kernel_reduce_multiblock(const InGridDesc_M_K in_grid_desc_m_k,
const OutGridDesc_M out_grid_desc_m,
const InElementwiseOperation in_elementwise_op,
const AccElementwiseOperation acc_elementwise_op,
index_t block_group_size,
index_t num_k_block_tile_iteration,
AccDataType alpha,
const InDataType* const __restrict__ p_in_value_global,
const IndexDataType* const __restrict__ p_in_index_global,
AccDataType beta,
OutDataType* const __restrict__ p_out_value_global,
IndexDataType* const __restrict__ p_out_index_global)
{
if constexpr(!NeedIndices)
if constexpr(!OutputIndex)
{
(void)p_in_index_global;
(void)p_out_index_global;
GridwiseReduction::Run(in_grid_desc_m_k,
workspace_desc_m_k,
out_grid_desc_m,
in_elementwise_op,
acc_elementwise_op,
block_group_size,
num_k_block_tile_iteration,
p_src_global,
p_ws_values_global,
p_ws_indices_global);
alpha,
p_in_value_global,
beta,
p_out_value_global);
}
else
{
GridwiseReduction::RunWithIndex(in_grid_desc_m_k,
workspace_desc_m_k,
in_elementwise_op,
acc_elementwise_op,
block_group_size,
num_k_block_tile_iteration,
p_src_global,
p_ws_values_global,
p_ws_indices_global);
GridwiseReduction::template RunWithIndex<HaveIndexInput>(in_grid_desc_m_k,
out_grid_desc_m,
in_elementwise_op,
acc_elementwise_op,
num_k_block_tile_iteration,
alpha,
p_in_value_global,
p_in_index_global,
beta,
p_out_value_global,
p_out_index_global);
};
};
template <typename InDataType,
typename OutDataType,
typename AccDataType,
typename IndexDataType,
typename InGridDesc_M_K,
typename WorkspaceDesc_M_K,
typename OutGridDesc_M,
typename ReduceOperation,
typename InElementwiseOperation,
typename AccElementwiseOperation,
InMemoryDataOperationEnum OutMemoryDataOperation,
bool PropagateNan,
index_t BlockSize,
index_t MThreadClusterSize,
......@@ -101,14 +112,13 @@ template <typename InDataType,
index_t InSrcVectorDim,
index_t InSrcVectorSize,
index_t OutDstVectorSize>
struct GridwiseReduction_mk_to_mk_multiblock_partial_reduce
struct GridwiseReduction_mk_to_m_multiblock
{
static_assert((InSrcVectorDim == 0 && MThreadSliceSize % InSrcVectorSize == 0) ||
(InSrcVectorDim == 1 && KThreadSliceSize % InSrcVectorSize == 0),
static_assert(((InSrcVectorDim == 0 && MThreadSliceSize % InSrcVectorSize == 0) ||
(InSrcVectorDim == 1 && KThreadSliceSize % InSrcVectorSize == 0)) &&
(MThreadSliceSize % OutDstVectorSize == 0),
"Invalid thread slice sizes and/or vector sizes configuration, please check!");
static_assert(OutDstVectorSize == 1, "OutDstVectorSize must be 1 for MultiBlockPartialReduce!");
static constexpr bool reorder_thread_cluster = (InSrcVectorDim == 0);
using ThreadClusterLengths_M_K = Sequence<MThreadClusterSize, KThreadClusterSize>;
......@@ -127,6 +137,19 @@ struct GridwiseReduction_mk_to_mk_multiblock_partial_reduce
using ThreadReduceDstDesc_M =
decltype(make_naive_tensor_descriptor_packed(make_tuple(Number<MThreadSliceSize>{})));
using BlockwiseReduce = PartitionedBlockwiseReduction<AccDataType,
BlockSize,
ThreadClusterLengths_M_K,
ThreadClusterArrangeOrder,
ReduceOperation,
PropagateNan>;
using ThreadwiseReduce = ThreadwiseReduction<AccDataType,
ThreadReduceSrcDesc_M_K,
ThreadReduceDstDesc_M,
ReduceOperation,
PropagateNan>;
using PassThroughOp = tensor_operation::element_wise::PassThrough;
static constexpr auto I0 = Number<0>{};
......@@ -135,43 +158,30 @@ struct GridwiseReduction_mk_to_mk_multiblock_partial_reduce
static constexpr index_t M_BlockTileSize = MThreadClusterSize * MThreadSliceSize;
static constexpr index_t K_BlockTileSize = KThreadClusterSize * KThreadSliceSize;
using Accumulation = detail::AccumulateWithNanCheck<PropagateNan, ReduceOperation, AccDataType>;
__device__ static void Run(const InGridDesc_M_K& in_grid_desc_m_k,
const WorkspaceDesc_M_K& workspace_desc_m_k,
const OutGridDesc_M& out_grid_desc_m,
const InElementwiseOperation& in_elementwise_op,
const AccElementwiseOperation& acc_elementwise_op,
index_t block_group_size,
index_t num_k_block_tile_iteration,
const InDataType* const __restrict__ p_src_global,
AccDataType* const __restrict__ p_ws_values_global,
IndexDataType* const __restrict__ p_ws_indices_global)
AccDataType alpha,
const InDataType* const __restrict__ p_in_value_global,
AccDataType beta,
OutDataType* const __restrict__ p_out_value_global)
{
using BlockwiseReduce = PartitionedBlockwiseReduction<AccDataType,
BlockSize,
ThreadClusterLengths_M_K,
ThreadClusterArrangeOrder,
ReduceOperation,
PropagateNan>;
using ThreadwiseReduce = ThreadwiseReduction<AccDataType,
ThreadReduceSrcDesc_M_K,
ThreadReduceDstDesc_M,
ReduceOperation,
PropagateNan>;
(void)p_ws_indices_global;
(void)acc_elementwise_op;
const auto zeroVal = ReduceOperation::GetReductionZeroVal();
// LDS
__shared__ AccDataType p_reduce_work_buffer[BlockSize];
const auto in_global_buf =
make_dynamic_buffer<AddressSpaceEnum::Global>(p_src_global,
const auto in_global_val_buf =
make_dynamic_buffer<AddressSpaceEnum::Global>(p_in_value_global,
in_grid_desc_m_k.GetElementSpaceSize(),
type_convert<InDataType>(zeroVal));
auto workspace_global_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_ws_values_global, workspace_desc_m_k.GetElementSpaceSize());
auto out_global_val_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_out_value_global, out_grid_desc_m.GetElementSpaceSize());
auto reduce_work_buf =
make_dynamic_buffer<AddressSpaceEnum::Lds>(p_reduce_work_buffer, BlockSize);
......@@ -221,7 +231,7 @@ struct GridwiseReduction_mk_to_mk_multiblock_partial_reduce
do
{
threadwise_src_load.Run(in_grid_desc_m_k,
in_global_buf,
in_global_val_buf,
thread_buffer_desc,
make_tuple(I0, I0),
in_thread_buf);
......@@ -242,58 +252,97 @@ struct GridwiseReduction_mk_to_mk_multiblock_partial_reduce
reducedTiles++;
} while(reducedTiles < num_k_block_tile_iteration);
// Each block executes multiple parallel reductions on the LDS, and due to the using of
// vector_load, each block/thread is involved into multiple invarirant dimensions.
constexpr auto reduced_data_desc = ThreadReduceDstDesc_M{};
static_for<0, MThreadSliceSize, 1>{}(
[&](auto I) { BlockwiseReduce::Reduce(reduce_work_buf, accu_value_buf(I)); });
constexpr auto reduced_data_desc = make_naive_tensor_descriptor_packed(
make_tuple(Number<MThreadSliceSize>{}, Number<1>{}));
static_for<0, MThreadSliceSize, 1>{}([&](auto I) {
if(thread_k_cluster_id == 0)
{
acc_elementwise_op(accu_value_buf(I), accu_value_buf(I));
accu_value_buf(I) *= alpha;
}
});
if(thread_k_cluster_id == 0)
{
auto threadwise_workspace_store =
if(block_group_size == 0 && !float_equal_zero{}(beta))
{
StaticBuffer<AddressSpaceEnum::Vgpr, OutDataType, MThreadSliceSize, true>
priorDstValueBuf;
auto threadwise_dst_load =
ThreadwiseTensorSliceTransfer_v2<OutDataType,
OutDataType,
OutGridDesc_M,
decltype(reduced_data_desc),
Sequence<MThreadSliceSize>,
Sequence<0>,
0,
OutDstVectorSize,
1,
false>(
out_grid_desc_m,
make_multi_index(blkgroup_id * M_BlockTileSize +
thread_m_cluster_id * MThreadSliceSize));
threadwise_dst_load.Run(out_grid_desc_m,
out_global_val_buf,
reduced_data_desc,
make_tuple(I0),
priorDstValueBuf);
static_for<0, MThreadSliceSize, 1>{}([&](auto I) {
accu_value_buf(I) += type_convert<AccDataType>(priorDstValueBuf[I]) * beta;
});
};
auto threadwise_dst_store =
ThreadwiseTensorSliceTransfer_v1r3<AccDataType,
AccDataType,
OutDataType,
decltype(reduced_data_desc),
WorkspaceDesc_M_K,
OutGridDesc_M,
PassThroughOp,
Sequence<MThreadSliceSize, 1>,
Sequence<0, 1>,
1,
1,
InMemoryDataOperationEnum::Set,
Sequence<MThreadSliceSize>,
Sequence<0>,
0,
OutDstVectorSize,
OutMemoryDataOperation,
1,
true>(
workspace_desc_m_k,
out_grid_desc_m,
make_multi_index(blkgroup_id * M_BlockTileSize +
thread_m_cluster_id * MThreadSliceSize,
block_local_id),
thread_m_cluster_id * MThreadSliceSize),
PassThroughOp{});
threadwise_workspace_store.Run(reduced_data_desc,
make_tuple(I0, I0),
accu_value_buf,
workspace_desc_m_k,
workspace_global_buf);
threadwise_dst_store.Run(reduced_data_desc,
make_tuple(I0),
accu_value_buf,
out_grid_desc_m,
out_global_val_buf);
}
};
template <bool HaveIndexInput>
__device__ static void RunWithIndex(const InGridDesc_M_K& in_grid_desc_m_k,
const WorkspaceDesc_M_K& workspace_desc_m_k,
const InElementwiseOperation& in_elementwise_op,
const AccElementwiseOperation& acc_elementwise_op,
index_t block_group_size,
const OutGridDesc_M& out_grid_desc_m,
const InElementwiseOperation in_elementwise_op,
const AccElementwiseOperation acc_elementwise_op,
index_t num_k_block_tile_iteration,
const InDataType* const __restrict__ p_src_global,
AccDataType* const __restrict__ p_ws_values_global,
IndexDataType* const __restrict__ p_ws_indices_global)
AccDataType alpha,
const InDataType* const __restrict__ p_in_value_global,
const IndexDataType* const __restrict__ p_in_index_global,
AccDataType beta,
OutDataType* const __restrict__ p_out_value_global,
IndexDataType* const __restrict__ p_out_index_global)
{
using BlockwiseReduceWithIndex =
PartitionedBlockwiseReductionWithIndex<AccDataType,
IndexDataType,
BlockSize,
ThreadClusterLengths_M_K,
Sequence<MThreadClusterSize, KThreadClusterSize>,
ThreadClusterArrangeOrder,
ReduceOperation,
PropagateNan>;
......@@ -303,22 +352,24 @@ struct GridwiseReduction_mk_to_mk_multiblock_partial_reduce
AccDataType,
IndexDataType>;
(void)acc_elementwise_op;
const auto zeroVal = ReduceOperation::GetReductionZeroVal();
(void)in_elementwise_op;
// LDS
__shared__ AccDataType p_reduce_work_val_buffer[BlockSize];
__shared__ index_t p_reduce_work_idx_buffer[BlockSize];
__shared__ IndexDataType p_reduce_work_idx_buffer[BlockSize];
const auto in_global_buf =
make_dynamic_buffer<AddressSpaceEnum::Global>(p_src_global,
const auto zeroVal = ReduceOperation::GetReductionZeroVal();
const auto in_global_val_buf =
make_dynamic_buffer<AddressSpaceEnum::Global>(p_in_value_global,
in_grid_desc_m_k.GetElementSpaceSize(),
type_convert<InDataType>(zeroVal));
auto workspace_global_val_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_ws_values_global, workspace_desc_m_k.GetElementSpaceSize());
auto workspace_global_idx_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_ws_indices_global, workspace_desc_m_k.GetElementSpaceSize());
const auto in_global_idx_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_in_index_global, in_grid_desc_m_k.GetElementSpaceSize());
auto out_global_val_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_out_value_global, out_grid_desc_m.GetElementSpaceSize());
auto out_global_idx_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_out_index_global, out_grid_desc_m.GetElementSpaceSize());
auto reduce_work_val_buf =
make_dynamic_buffer<AddressSpaceEnum::Lds>(p_reduce_work_val_buffer, BlockSize);
......@@ -327,6 +378,7 @@ struct GridwiseReduction_mk_to_mk_multiblock_partial_reduce
StaticBuffer<AddressSpaceEnum::Vgpr, AccDataType, MThreadSliceSize * KThreadSliceSize, true>
in_thread_val_buf;
StaticBuffer<AddressSpaceEnum::Vgpr,
IndexDataType,
MThreadSliceSize * KThreadSliceSize,
......@@ -336,10 +388,8 @@ struct GridwiseReduction_mk_to_mk_multiblock_partial_reduce
StaticBuffer<AddressSpaceEnum::Vgpr, AccDataType, MThreadSliceSize, true> accu_value_buf;
StaticBuffer<AddressSpaceEnum::Vgpr, IndexDataType, MThreadSliceSize, true> accu_index_buf;
const index_t thread_local_id = get_thread_local_1d_id();
const index_t block_global_id = get_block_1d_id();
const index_t blkgroup_id = block_global_id / block_group_size;
const index_t block_local_id = block_global_id % block_group_size;
const index_t thread_local_id = get_thread_local_1d_id();
const index_t block_global_1d_id = get_block_1d_id();
const auto thread_cluster_idx =
thread_cluster_desc.CalculateBottomIndex(make_multi_index(thread_local_id));
......@@ -347,138 +397,239 @@ struct GridwiseReduction_mk_to_mk_multiblock_partial_reduce
const auto thread_m_cluster_id = thread_cluster_idx[I0];
const auto thread_k_cluster_id = thread_cluster_idx[I1];
const index_t reduceSizePerBlock = K_BlockTileSize * num_k_block_tile_iteration;
using ThreadBufferLengths = Sequence<MThreadSliceSize, KThreadSliceSize>;
constexpr auto thread_buffer_desc = make_naive_tensor_descriptor_packed(
make_tuple(Number<MThreadSliceSize>{}, Number<KThreadSliceSize>{}));
auto threadwise_src_load = ThreadwiseTensorSliceTransfer_v2<InDataType,
AccDataType,
InGridDesc_M_K,
decltype(thread_buffer_desc),
ThreadBufferLengths,
ThreadBufferDimAccessOrder,
InSrcVectorDim,
InSrcVectorSize,
1,
false>(
in_grid_desc_m_k,
make_multi_index(blkgroup_id * M_BlockTileSize + thread_m_cluster_id * MThreadSliceSize,
block_local_id * reduceSizePerBlock +
auto threadwise_src_val_load =
ThreadwiseTensorSliceTransfer_v2<InDataType,
AccDataType,
InGridDesc_M_K,
decltype(thread_buffer_desc),
ThreadBufferLengths,
ThreadBufferDimAccessOrder,
InSrcVectorDim,
InSrcVectorSize,
1,
false>(
in_grid_desc_m_k,
make_multi_index(block_global_1d_id * M_BlockTileSize +
thread_m_cluster_id * MThreadSliceSize,
thread_k_cluster_id * KThreadSliceSize));
constexpr auto in_thread_copy_step = make_multi_index(0, K_BlockTileSize);
index_t indexOffset = block_local_id * reduceSizePerBlock;
static_for<0, MThreadSliceSize, 1>{}([&](auto I) {
accu_value_buf(I) = zeroVal;
accu_index_buf(I) = 0;
});
index_t reducedTiles = 0;
do
{
// load the thread slice
threadwise_src_load.Run(in_grid_desc_m_k,
in_global_buf,
thread_buffer_desc,
make_tuple(I0, I0),
in_thread_val_buf);
static_for<0, MThreadSliceSize, 1>{}([&](auto iM) {
static_for<0, KThreadSliceSize, 1>{}([&](auto iK) {
constexpr auto offset = thread_buffer_desc.CalculateOffset(make_tuple(iM, iK));
constexpr auto in_thread_copy_step = make_multi_index(0, K_BlockTileSize);
// initialize the indices for the per-thread to-reduce values
in_thread_idx_buf(Number<offset>{}) =
indexOffset + thread_k_cluster_id * KThreadSliceSize + iK();
index_t reducedTiles = 0;
// do element-wise pre-reduction operation
in_elementwise_op(in_thread_val_buf(Number<offset>{}),
in_thread_val_buf(Number<offset>{}));
if constexpr(HaveIndexInput)
{
auto threadwise_src_idx_load =
ThreadwiseTensorSliceTransfer_v2<IndexDataType,
IndexDataType,
InGridDesc_M_K,
decltype(thread_buffer_desc),
ThreadBufferLengths,
ThreadBufferDimAccessOrder,
InSrcVectorDim,
InSrcVectorSize,
1,
false>(
in_grid_desc_m_k,
make_multi_index(block_global_1d_id * M_BlockTileSize +
thread_m_cluster_id * MThreadSliceSize,
thread_k_cluster_id * KThreadSliceSize));
do
{
// load the thread slice
threadwise_src_val_load.Run(in_grid_desc_m_k,
in_global_val_buf,
thread_buffer_desc,
make_tuple(I0, I0),
in_thread_val_buf);
threadwise_src_idx_load.Run(in_grid_desc_m_k,
in_global_idx_buf,
thread_buffer_desc,
make_tuple(I0, I0),
in_thread_idx_buf);
static_for<0, MThreadSliceSize, 1>{}([&](auto iM) {
AccDataType tmpValue = zeroVal;
IndexDataType tmpIndex = 0;
static_for<0, KThreadSliceSize, 1>{}([&](auto iK) {
constexpr auto offset =
thread_buffer_desc.CalculateOffset(make_tuple(iM, iK));
AccumulationWithIndex::Calculate(tmpValue,
in_thread_val_buf[Number<offset>{}],
tmpIndex,
in_thread_idx_buf[Number<offset>{}]);
});
BlockwiseReduceWithIndex::Reduce(
reduce_work_val_buf, reduce_work_idx_buf, tmpValue, tmpIndex);
AccumulationWithIndex::Calculate(
accu_value_buf(iM), tmpValue, accu_index_buf(iM), tmpIndex);
});
AccDataType tmpValue = zeroVal;
IndexDataType tmpIndex = 0;
static_for<0, KThreadSliceSize, 1>{}([&](auto iK) {
constexpr auto offset = thread_buffer_desc.CalculateOffset(make_tuple(iM, iK));
threadwise_src_val_load.MoveSrcSliceWindow(in_grid_desc_m_k, in_thread_copy_step);
threadwise_src_idx_load.MoveSrcSliceWindow(in_grid_desc_m_k, in_thread_copy_step);
AccumulationWithIndex::Calculate(tmpValue,
in_thread_val_buf[Number<offset>{}],
tmpIndex,
in_thread_idx_buf[Number<offset>{}]);
reducedTiles++;
} while(reducedTiles < num_k_block_tile_iteration);
}
else
{
index_t indexOffset = 0;
do
{
// load the thread slice
threadwise_src_val_load.Run(in_grid_desc_m_k,
in_global_val_buf,
thread_buffer_desc,
make_tuple(I0, I0),
in_thread_val_buf);
static_for<0, MThreadSliceSize, 1>{}([&](auto iM) {
static_for<0, KThreadSliceSize, 1>{}([&](auto iK) {
constexpr auto offset =
thread_buffer_desc.CalculateOffset(make_tuple(iM, iK));
// initialize the indices for the per-thread to-reduce values
in_thread_idx_buf(Number<offset>{}) =
indexOffset + thread_k_cluster_id * KThreadSliceSize + iK();
// do element-wise pre-reduction operation
in_elementwise_op(in_thread_val_buf(Number<offset>{}),
in_thread_val_buf(Number<offset>{}));
});
AccDataType tmpValue = zeroVal;
IndexDataType tmpIndex = 0;
static_for<0, KThreadSliceSize, 1>{}([&](auto iK) {
constexpr auto offset =
thread_buffer_desc.CalculateOffset(make_tuple(iM, iK));
AccumulationWithIndex::Calculate(tmpValue,
in_thread_val_buf[Number<offset>{}],
tmpIndex,
in_thread_idx_buf[Number<offset>{}]);
});
BlockwiseReduceWithIndex::Reduce(
reduce_work_val_buf, reduce_work_idx_buf, tmpValue, tmpIndex);
AccumulationWithIndex::Calculate(
accu_value_buf(iM), tmpValue, accu_index_buf(iM), tmpIndex);
});
BlockwiseReduceWithIndex::Reduce(
reduce_work_val_buf, reduce_work_idx_buf, tmpValue, tmpIndex);
AccumulationWithIndex::Calculate(
accu_value_buf(iM), tmpValue, accu_index_buf(iM), tmpIndex);
});
threadwise_src_val_load.MoveSrcSliceWindow(in_grid_desc_m_k, in_thread_copy_step);
threadwise_src_load.MoveSrcSliceWindow(in_grid_desc_m_k, in_thread_copy_step);
indexOffset += K_BlockTileSize;
reducedTiles++;
} while(reducedTiles < num_k_block_tile_iteration);
};
indexOffset += K_BlockTileSize;
constexpr auto reduced_data_desc = ThreadReduceDstDesc_M{};
reducedTiles++;
} while(reducedTiles < num_k_block_tile_iteration);
static_for<0, MThreadSliceSize, 1>{}([&](auto I) {
if(thread_k_cluster_id == 0)
{
// for indiced operation, acc_elementwise_op shoud do nothing
acc_elementwise_op(accu_value_buf(I), accu_value_buf(I));
constexpr auto reduced_data_desc = make_naive_tensor_descriptor_packed(
make_tuple(Number<MThreadSliceSize>{}, Number<1>{}));
accu_value_buf(I) *= alpha;
}
});
if(thread_k_cluster_id == 0)
{
auto threadwise_workspace_val_store =
if(!float_equal_zero{}(beta))
{
StaticBuffer<AddressSpaceEnum::Vgpr, OutDataType, MThreadSliceSize, true>
priorDstValueBuf;
auto threadwise_dst_load =
ThreadwiseTensorSliceTransfer_v2<OutDataType,
OutDataType,
OutGridDesc_M,
decltype(reduced_data_desc),
Sequence<MThreadSliceSize>,
Sequence<0>,
0,
OutDstVectorSize,
1,
true>(
out_grid_desc_m,
make_multi_index(block_global_1d_id * M_BlockTileSize +
thread_m_cluster_id * MThreadSliceSize));
threadwise_dst_load.Run(out_grid_desc_m,
out_global_val_buf,
reduced_data_desc,
make_tuple(I0),
priorDstValueBuf);
static_for<0, MThreadSliceSize, 1>{}([&](auto I) {
accu_value_buf(I) += type_convert<AccDataType>(priorDstValueBuf[I]) * beta;
});
};
auto threadwise_dst_val_store =
ThreadwiseTensorSliceTransfer_v1r3<AccDataType,
AccDataType,
OutDataType,
decltype(reduced_data_desc),
WorkspaceDesc_M_K,
OutGridDesc_M,
PassThroughOp,
Sequence<MThreadSliceSize, 1>,
Sequence<0, 1>,
1,
1,
Sequence<MThreadSliceSize>,
Sequence<0>,
0,
OutDstVectorSize,
InMemoryDataOperationEnum::Set,
1,
true>(
workspace_desc_m_k,
make_multi_index(blkgroup_id * M_BlockTileSize +
thread_m_cluster_id * MThreadSliceSize,
block_local_id),
out_grid_desc_m,
make_multi_index(block_global_1d_id * M_BlockTileSize +
thread_m_cluster_id * MThreadSliceSize),
PassThroughOp{});
auto threadwise_workspace_idx_store =
auto threadwise_dst_idx_store =
ThreadwiseTensorSliceTransfer_v1r3<IndexDataType,
IndexDataType,
decltype(reduced_data_desc),
WorkspaceDesc_M_K,
OutGridDesc_M,
PassThroughOp,
Sequence<MThreadSliceSize, 1>,
Sequence<0, 1>,
1,
1,
Sequence<MThreadSliceSize>,
Sequence<0>,
0,
OutDstVectorSize,
InMemoryDataOperationEnum::Set,
1,
true>(
workspace_desc_m_k,
make_multi_index(blkgroup_id * M_BlockTileSize +
thread_m_cluster_id * MThreadSliceSize,
block_local_id),
out_grid_desc_m,
make_multi_index(block_global_1d_id * M_BlockTileSize +
thread_m_cluster_id * MThreadSliceSize),
PassThroughOp{});
threadwise_workspace_val_store.Run(reduced_data_desc,
make_tuple(I0, I0),
accu_value_buf,
workspace_desc_m_k,
workspace_global_val_buf);
threadwise_workspace_idx_store.Run(reduced_data_desc,
make_tuple(I0, I0),
accu_index_buf,
workspace_desc_m_k,
workspace_global_idx_buf);
threadwise_dst_val_store.Run(reduced_data_desc,
make_tuple(I0),
accu_value_buf,
out_grid_desc_m,
out_global_val_buf);
threadwise_dst_idx_store.Run(reduced_data_desc,
make_tuple(I0),
accu_index_buf,
out_grid_desc_m,
out_global_idx_buf);
}
};
};
......
/*******************************************************************************
*
* MIT License
*
* Copyright (c) 2020 Advanced Micro Devices, Inc.
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
* in the Software without restriction, including without limitation the rights
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
* copies of the Software, and to permit persons to whom the Software is
* furnished to do so, subject to the following conditions:
*
* The above copyright notice and this permission notice shall be included in all
* copies or substantial portions of the Software.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
* SOFTWARE.
*
*******************************************************************************/
#ifndef CK_GRIDWISE_2D_REDUCTION_MULTIBLOCK_ATOMIC_ADD_HPP
#define CK_GRIDWISE_2D_REDUCTION_MULTIBLOCK_ATOMIC_ADD_HPP
#include "reduction_common.hpp"
#include "reduction_operator.hpp"
#include "reduction_functions_accumulate.hpp"
#include "reduction_functions_blockwise.hpp"
#include "reduction_functions_threadwise.hpp"
#include "threadwise_tensor_slice_transfer.hpp"
#include "element_wise_operation.hpp"
namespace ck {
template <typename GridwiseReduction,
typename InDataType,
typename OutDataType,
typename AccDataType,
typename InGridDesc_M_K,
typename OutGridDesc_M,
typename InElementwiseOperation,
typename AccElementwiseOperation>
__global__ void
kernel_reduce_multiblock_atocmi_add(const InGridDesc_M_K in_grid_desc_m_k,
const OutGridDesc_M out_grid_desc_m,
const InElementwiseOperation in_elementwise_op,
const AccElementwiseOperation acc_elementwise_op,
index_t block_group_size,
index_t num_k_block_tile_iteration,
AccDataType alpha,
const InDataType* const __restrict__ p_in_global,
OutDataType* const __restrict__ p_out_global)
{
GridwiseReduction::Run(in_grid_desc_m_k,
out_grid_desc_m,
in_elementwise_op,
acc_elementwise_op,
block_group_size,
num_k_block_tile_iteration,
alpha,
p_in_global,
p_out_global);
};
template <typename InDataType,
typename OutDataType,
typename AccDataType,
typename InGridDesc_M_K,
typename OutGridDesc_M,
typename ReduceOperation,
typename InElementwiseOperation,
typename AccElementwiseOperation,
bool PropagateNan,
index_t BlockSize,
index_t MThreadClusterSize,
index_t KThreadClusterSize,
index_t MThreadSliceSize,
index_t KThreadSliceSize,
index_t InSrcVectorDim,
index_t InSrcVectorSize,
index_t OutDstVectorSize>
struct GridwiseReduction_mk_to_m_multiblock_atomic_add
{
static_assert(((InSrcVectorDim == 0 && MThreadSliceSize % InSrcVectorSize == 0) ||
(InSrcVectorDim == 1 && KThreadSliceSize % InSrcVectorSize == 0)) &&
(MThreadSliceSize % OutDstVectorSize == 0),
"Invalid thread slice sizes and/or vector sizes configuration, please check!");
static constexpr bool reorder_thread_cluster = (InSrcVectorDim == 0);
using ThreadClusterLengths_M_K = Sequence<MThreadClusterSize, KThreadClusterSize>;
using ThreadBufferDimAccessOrder =
typename conditional<reorder_thread_cluster, Sequence<1, 0>, Sequence<0, 1>>::type;
using ThreadClusterArrangeOrder =
typename conditional<reorder_thread_cluster, Sequence<1, 0>, Sequence<0, 1>>::type;
static constexpr auto thread_cluster_desc =
make_cluster_descriptor(ThreadClusterLengths_M_K{}, ThreadClusterArrangeOrder{});
using ThreadReduceSrcDesc_M_K = decltype(make_naive_tensor_descriptor_packed(
make_tuple(Number<MThreadSliceSize>{}, Number<KThreadSliceSize>{})));
using ThreadReduceDstDesc_M =
decltype(make_naive_tensor_descriptor_packed(make_tuple(Number<MThreadSliceSize>{})));
using BlockwiseReduce = PartitionedBlockwiseReduction<AccDataType,
BlockSize,
ThreadClusterLengths_M_K,
ThreadClusterArrangeOrder,
ReduceOperation,
PropagateNan>;
using ThreadwiseReduce = ThreadwiseReduction<AccDataType,
ThreadReduceSrcDesc_M_K,
ThreadReduceDstDesc_M,
ReduceOperation,
PropagateNan>;
using PassThroughOp = tensor_operation::element_wise::PassThrough;
static constexpr auto I0 = Number<0>{};
static constexpr auto I1 = Number<1>{};
static constexpr index_t M_BlockTileSize = MThreadClusterSize * MThreadSliceSize;
static constexpr index_t K_BlockTileSize = KThreadClusterSize * KThreadSliceSize;
using Accumulation = detail::AccumulateWithNanCheck<PropagateNan, ReduceOperation, AccDataType>;
__device__ static void Run(const InGridDesc_M_K& in_grid_desc_m_k,
const OutGridDesc_M& out_grid_desc_m,
const InElementwiseOperation& in_elementwise_op,
const AccElementwiseOperation& acc_elementwise_op,
index_t block_group_size,
index_t num_k_block_tile_iteration,
AccDataType alpha,
const InDataType* const __restrict__ p_in_global,
OutDataType* const __restrict__ p_out_global)
{
const auto zeroVal = ReduceOperation::GetReductionZeroVal();
// LDS
__shared__ AccDataType p_reduce_work_buffer[BlockSize];
const auto in_global_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_in_global, in_grid_desc_m_k.GetElementSpaceSize(), type_convert<InDataType>(zeroVal));
auto out_global_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_out_global, out_grid_desc_m.GetElementSpaceSize());
auto reduce_work_buf =
make_dynamic_buffer<AddressSpaceEnum::Lds>(p_reduce_work_buffer, BlockSize);
StaticBuffer<AddressSpaceEnum::Vgpr, AccDataType, MThreadSliceSize * KThreadSliceSize, true>
in_thread_buf;
StaticBuffer<AddressSpaceEnum::Vgpr, AccDataType, MThreadSliceSize, true> accu_value_buf;
static_for<0, MThreadSliceSize, 1>{}([&](auto I) { accu_value_buf(I) = zeroVal; });
const index_t thread_local_id = get_thread_local_1d_id();
const index_t block_global_id = get_block_1d_id();
const index_t blkgroup_id = block_global_id / block_group_size;
const index_t block_local_id = block_global_id % block_group_size;
const auto thread_cluster_idx =
thread_cluster_desc.CalculateBottomIndex(make_multi_index(thread_local_id));
const auto thread_m_cluster_id = thread_cluster_idx[I0];
const auto thread_k_cluster_id = thread_cluster_idx[I1];
const index_t reduceSizePerBlock = K_BlockTileSize * num_k_block_tile_iteration;
using ThreadBufferLengths = Sequence<MThreadSliceSize, KThreadSliceSize>;
constexpr auto thread_buffer_desc = make_naive_tensor_descriptor_packed(
make_tuple(Number<MThreadSliceSize>{}, Number<KThreadSliceSize>{}));
auto threadwise_src_load = ThreadwiseTensorSliceTransfer_v2<InDataType,
AccDataType,
InGridDesc_M_K,
decltype(thread_buffer_desc),
ThreadBufferLengths,
ThreadBufferDimAccessOrder,
InSrcVectorDim,
InSrcVectorSize,
1,
false>(
in_grid_desc_m_k,
make_multi_index(blkgroup_id * M_BlockTileSize + thread_m_cluster_id * MThreadSliceSize,
block_local_id * reduceSizePerBlock +
thread_k_cluster_id * KThreadSliceSize));
constexpr auto in_thread_copy_step = make_multi_index(0, K_BlockTileSize);
index_t reducedTiles = 0;
do
{
threadwise_src_load.Run(in_grid_desc_m_k,
in_global_buf,
thread_buffer_desc,
make_tuple(I0, I0),
in_thread_buf);
static_for<0, MThreadSliceSize, 1>{}([&](auto iM) {
// do element-wise pre-reduction operation
static_for<0, KThreadSliceSize, 1>{}([&](auto iK) {
constexpr auto offset = thread_buffer_desc.CalculateOffset(make_tuple(iM, iK));
in_elementwise_op(in_thread_buf(Number<offset>{}),
in_thread_buf(Number<offset>{}));
});
});
ThreadwiseReduce::Reduce(in_thread_buf, accu_value_buf);
threadwise_src_load.MoveSrcSliceWindow(in_grid_desc_m_k, in_thread_copy_step);
reducedTiles++;
} while(reducedTiles < num_k_block_tile_iteration);
constexpr auto reduced_data_desc = ThreadReduceDstDesc_M{};
// Each block executes multiple parallel reductions on the LDS, and by atomic-adding its
// reduced output to the global location corresponding to each invariant dimension to get a
// consistent reduced result for that invariant dimension. due to the using of vector_load,
// each block/thread is involved into multiple invarirant dimensions.
static_for<0, MThreadSliceSize, 1>{}(
[&](auto I) { BlockwiseReduce::Reduce(reduce_work_buf, accu_value_buf(I)); });
static_for<0, MThreadSliceSize, 1>{}([&](auto I) {
if(thread_k_cluster_id == 0)
{
acc_elementwise_op(accu_value_buf(I), accu_value_buf(I));
accu_value_buf(I) *= alpha;
}
});
if(thread_k_cluster_id == 0)
{
auto threadwise_dst_store =
ThreadwiseTensorSliceTransfer_v1r3<AccDataType,
OutDataType,
decltype(reduced_data_desc),
OutGridDesc_M,
PassThroughOp,
Sequence<MThreadSliceSize>,
Sequence<0>,
0,
OutDstVectorSize,
InMemoryDataOperationEnum::AtomicAdd,
1,
true>(
out_grid_desc_m,
make_multi_index(blkgroup_id * M_BlockTileSize +
thread_m_cluster_id * MThreadSliceSize),
PassThroughOp{});
threadwise_dst_store.Run(
reduced_data_desc, make_tuple(I0), accu_value_buf, out_grid_desc_m, out_global_buf);
}
};
};
} // namespace ck
#endif
......@@ -37,7 +37,8 @@
namespace ck {
template <typename GridwiseReduction,
bool NeedIndices,
bool OutputIndex,
bool HaveIndexInput,
typename InDataType,
typename OutDataType,
typename AccDataType,
......@@ -51,34 +52,35 @@ __global__ void kernel_reduce_threadwise(const InGridDesc_M_K in_grid_desc_m_k,
const InElementwiseOperation in_elementwise_op,
const AccElementwiseOperation acc_elementwise_op,
AccDataType alpha,
const InDataType* const __restrict__ p_in_global,
const InDataType* const __restrict__ p_in_value_global,
const IndexDataType* const __restrict__ p_in_index_global,
AccDataType beta,
OutDataType* const __restrict__ p_out_global,
IndexDataType* const __restrict__ p_indices_global)
OutDataType* const __restrict__ p_out_value_global,
IndexDataType* const __restrict__ p_out_index_global)
{
if constexpr(!NeedIndices)
if constexpr(!OutputIndex)
{
GridwiseReduction::Run(in_grid_desc_m_k,
out_grid_desc_m,
in_elementwise_op,
acc_elementwise_op,
alpha,
p_in_global,
p_in_value_global,
beta,
p_out_global,
p_indices_global);
p_out_value_global);
}
else
{
GridwiseReduction::RunWithIndices(in_grid_desc_m_k,
out_grid_desc_m,
in_elementwise_op,
acc_elementwise_op,
alpha,
p_in_global,
beta,
p_out_global,
p_indices_global);
GridwiseReduction::template RunWithIndex<HaveIndexInput>(in_grid_desc_m_k,
out_grid_desc_m,
in_elementwise_op,
acc_elementwise_op,
alpha,
p_in_value_global,
p_in_index_global,
beta,
p_out_value_global,
p_out_index_global);
};
};
......@@ -91,11 +93,9 @@ template <typename InDataType,
typename ReduceOperation,
typename InElementwiseOperation,
typename AccElementwiseOperation,
InMemoryDataOperationEnum OutMemoryDataOperation,
bool PropagateNan,
bool BetaIsZero,
index_t BlockSize,
index_t MThreadClusterSize,
index_t KThreadClusterSize,
index_t MThreadSliceSize,
index_t KThreadSliceSize,
index_t InSrcVectorDim,
......@@ -125,10 +125,9 @@ struct GridwiseReduction_mk_to_m_threadwise
const InElementwiseOperation& in_elementwise_op,
const AccElementwiseOperation& acc_elementwise_op,
AccDataType alpha,
const InDataType* const __restrict__ p_in_global,
const InDataType* const __restrict__ p_in_value_global,
AccDataType beta,
OutDataType* const __restrict__ p_out_global,
IndexDataType* const __restrict__ p_indices_global)
OutDataType* const __restrict__ p_out_value_global)
{
using ThreadwiseReduce = ThreadwiseReduction<AccDataType,
ThreadReduceSrcDesc_M_K,
......@@ -136,14 +135,14 @@ struct GridwiseReduction_mk_to_m_threadwise
ReduceOperation,
PropagateNan>;
(void)p_indices_global;
const auto zeroVal = ReduceOperation::GetReductionZeroVal();
const auto in_global_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_in_global, in_grid_desc_m_k.GetElementSpaceSize(), type_convert<InDataType>(zeroVal));
const auto in_global_val_buf =
make_dynamic_buffer<AddressSpaceEnum::Global>(p_in_value_global,
in_grid_desc_m_k.GetElementSpaceSize(),
type_convert<InDataType>(zeroVal));
auto dst_global_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_out_global, out_grid_desc_m.GetElementSpaceSize());
p_out_value_global, out_grid_desc_m.GetElementSpaceSize());
StaticBuffer<AddressSpaceEnum::Vgpr, AccDataType, MThreadSliceSize * KThreadSliceSize, true>
in_thread_buf;
......@@ -160,28 +159,29 @@ struct GridwiseReduction_mk_to_m_threadwise
index_t thread_global_1d_id = get_block_1d_id() * BlockSize + get_thread_local_1d_id();
auto threadwise_src_load = ThreadwiseTensorSliceTransfer_v2<InDataType,
AccDataType,
InGridDesc_M_K,
decltype(thread_buffer_desc),
ThreadBufferLengths,
ThreadBufferDimAccessOrder,
InSrcVectorDim,
InSrcVectorSize,
1,
false>(
in_grid_desc_m_k, make_multi_index(thread_global_1d_id * MThreadSliceSize, 0));
auto threadwise_src_val_load =
ThreadwiseTensorSliceTransfer_v2<InDataType,
AccDataType,
InGridDesc_M_K,
decltype(thread_buffer_desc),
ThreadBufferLengths,
ThreadBufferDimAccessOrder,
InSrcVectorDim,
InSrcVectorSize,
1,
false>(
in_grid_desc_m_k, make_multi_index(thread_global_1d_id * MThreadSliceSize, 0));
constexpr auto in_thread_copy_step = make_multi_index(0, KThreadSliceSize);
index_t reducedLength = 0;
do
{
threadwise_src_load.Run(in_grid_desc_m_k,
in_global_buf,
thread_buffer_desc,
make_tuple(I0, I0),
in_thread_buf);
threadwise_src_val_load.Run(in_grid_desc_m_k,
in_global_val_buf,
thread_buffer_desc,
make_tuple(I0, I0),
in_thread_buf);
static_for<0, MThreadSliceSize, 1>{}([&](auto iM) {
// do element-wise pre-reduction operation
......@@ -194,7 +194,7 @@ struct GridwiseReduction_mk_to_m_threadwise
ThreadwiseReduce::Reduce(in_thread_buf, accu_value_buf);
threadwise_src_load.MoveSrcSliceWindow(in_grid_desc_m_k, in_thread_copy_step);
threadwise_src_val_load.MoveSrcSliceWindow(in_grid_desc_m_k, in_thread_copy_step);
reducedLength += KThreadSliceSize;
} while(reducedLength < toReduceLength);
......@@ -207,68 +207,65 @@ struct GridwiseReduction_mk_to_m_threadwise
constexpr auto reduced_data_desc = ThreadReduceDstDesc_M{};
if constexpr(!BetaIsZero)
if(!float_equal_zero{}(beta))
{
if(!float_equal_zero{}(beta))
{
auto threadwise_dst_load =
ThreadwiseTensorSliceTransfer_v2<OutDataType,
OutDataType,
OutGridDesc_M,
decltype(reduced_data_desc),
Sequence<MThreadSliceSize>,
Sequence<0>,
0,
1,
1,
true>(
out_grid_desc_m, make_multi_index(thread_global_1d_id * MThreadSliceSize));
StaticBuffer<AddressSpaceEnum::Vgpr, OutDataType, MThreadSliceSize, true>
priorDstValue_buf;
threadwise_dst_load.Run(out_grid_desc_m,
dst_global_buf,
reduced_data_desc,
make_tuple(I0),
priorDstValue_buf);
static_for<0, MThreadSliceSize, 1>{}([&](auto I) {
accu_value_buf(I) += type_convert<AccDataType>(priorDstValue_buf[I]) * beta;
});
};
auto threadwise_dst_load = ThreadwiseTensorSliceTransfer_v2<OutDataType,
OutDataType,
OutGridDesc_M,
decltype(reduced_data_desc),
Sequence<MThreadSliceSize>,
Sequence<0>,
0,
1,
1,
true>(
out_grid_desc_m, make_multi_index(thread_global_1d_id * MThreadSliceSize));
StaticBuffer<AddressSpaceEnum::Vgpr, OutDataType, MThreadSliceSize, true>
priorDstValue_buf;
threadwise_dst_load.Run(out_grid_desc_m,
dst_global_buf,
reduced_data_desc,
make_tuple(I0),
priorDstValue_buf);
static_for<0, MThreadSliceSize, 1>{}([&](auto I) {
accu_value_buf(I) += type_convert<AccDataType>(priorDstValue_buf[I]) * beta;
});
};
auto threadwise_dst_store =
ThreadwiseTensorSliceTransfer_v1r3<AccDataType,
OutDataType,
decltype(reduced_data_desc),
OutGridDesc_M,
PassThroughOp,
Sequence<MThreadSliceSize>,
Sequence<0>,
0,
OutDstVectorSize,
InMemoryDataOperationEnum::Set,
1,
false>(
out_grid_desc_m,
make_multi_index(thread_global_1d_id * MThreadSliceSize),
PassThroughOp{});
auto threadwise_dst_store = ThreadwiseTensorSliceTransfer_v1r3<AccDataType,
OutDataType,
decltype(reduced_data_desc),
OutGridDesc_M,
PassThroughOp,
Sequence<MThreadSliceSize>,
Sequence<0>,
0,
OutDstVectorSize,
OutMemoryDataOperation,
1,
false>(
out_grid_desc_m,
make_multi_index(thread_global_1d_id * MThreadSliceSize),
PassThroughOp{});
threadwise_dst_store.Run(
reduced_data_desc, make_tuple(I0), accu_value_buf, out_grid_desc_m, dst_global_buf);
};
__device__ static void RunWithIndices(const InGridDesc_M_K& in_grid_desc_m_k,
const OutGridDesc_M& out_grid_desc_m,
const InElementwiseOperation& in_elementwise_op,
const AccElementwiseOperation& acc_elementwise_op,
AccDataType alpha,
const InDataType* const __restrict__ p_in_global,
AccDataType beta,
OutDataType* const __restrict__ p_out_global,
IndexDataType* const __restrict__ p_indices_global)
template <bool HaveIndexInput>
__device__ static void RunWithIndex(const InGridDesc_M_K& in_grid_desc_m_k,
const OutGridDesc_M& out_grid_desc_m,
const InElementwiseOperation& in_elementwise_op,
const AccElementwiseOperation& acc_elementwise_op,
AccDataType alpha,
const InDataType* const __restrict__ p_in_value_global,
const IndexDataType* const __restrict__ p_in_index_global,
AccDataType beta,
OutDataType* const __restrict__ p_out_value_global,
IndexDataType* const __restrict__ p_out_index_global)
{
using ThreadwiseReduceWithIndex = ThreadwiseReductionWithIndex<AccDataType,
IndexDataType,
......@@ -281,12 +278,17 @@ struct GridwiseReduction_mk_to_m_threadwise
const auto zeroVal = ReduceOperation::GetReductionZeroVal();
const auto in_global_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_in_global, in_grid_desc_m_k.GetElementSpaceSize(), type_convert<InDataType>(zeroVal));
const auto in_global_val_buf =
make_dynamic_buffer<AddressSpaceEnum::Global>(p_in_value_global,
in_grid_desc_m_k.GetElementSpaceSize(),
type_convert<InDataType>(zeroVal));
const auto in_global_idx_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_in_index_global, in_grid_desc_m_k.GetElementSpaceSize());
auto out_global_val_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_out_global, out_grid_desc_m.GetElementSpaceSize());
p_out_value_global, out_grid_desc_m.GetElementSpaceSize());
auto out_global_idx_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_indices_global, out_grid_desc_m.GetElementSpaceSize());
p_out_index_global, out_grid_desc_m.GetElementSpaceSize());
StaticBuffer<AddressSpaceEnum::Vgpr, AccDataType, MThreadSliceSize * KThreadSliceSize, true>
in_thread_val_buf;
......@@ -313,50 +315,105 @@ struct GridwiseReduction_mk_to_m_threadwise
index_t thread_global_1d_id = get_block_1d_id() * BlockSize + get_thread_local_1d_id();
auto threadwise_src_load = ThreadwiseTensorSliceTransfer_v2<InDataType,
AccDataType,
InGridDesc_M_K,
decltype(thread_buffer_desc),
ThreadBufferLengths,
ThreadBufferDimAccessOrder,
InSrcVectorDim,
InSrcVectorSize,
1,
false>(
in_grid_desc_m_k, make_multi_index(thread_global_1d_id * MThreadSliceSize, 0));
auto threadwise_src_val_load =
ThreadwiseTensorSliceTransfer_v2<InDataType,
AccDataType,
InGridDesc_M_K,
decltype(thread_buffer_desc),
ThreadBufferLengths,
ThreadBufferDimAccessOrder,
InSrcVectorDim,
InSrcVectorSize,
1,
false>(
in_grid_desc_m_k, make_multi_index(thread_global_1d_id * MThreadSliceSize, 0));
constexpr auto in_thread_copy_step = make_multi_index(0, KThreadSliceSize);
index_t indexStart = 0;
index_t reducedLength = 0;
do
if constexpr(HaveIndexInput)
{
threadwise_src_load.Run(in_grid_desc_m_k,
in_global_buf,
thread_buffer_desc,
make_tuple(I0, I0),
in_thread_val_buf);
auto threadwise_src_idx_load =
ThreadwiseTensorSliceTransfer_v2<IndexDataType,
IndexDataType,
InGridDesc_M_K,
decltype(thread_buffer_desc),
ThreadBufferLengths,
ThreadBufferDimAccessOrder,
InSrcVectorDim,
InSrcVectorSize,
1,
false>(
in_grid_desc_m_k, make_multi_index(thread_global_1d_id * MThreadSliceSize, 0));
do
{
threadwise_src_val_load.Run(in_grid_desc_m_k,
in_global_val_buf,
thread_buffer_desc,
make_tuple(I0, I0),
in_thread_val_buf);
threadwise_src_idx_load.Run(in_grid_desc_m_k,
in_global_idx_buf,
thread_buffer_desc,
make_tuple(I0, I0),
in_thread_idx_buf);
static_for<0, MThreadSliceSize, 1>{}([&](auto iM) {
// do element-wise pre-reduction operation
static_for<0, KThreadSliceSize, 1>{}([&](auto iK) {
constexpr auto offset =
thread_buffer_desc.CalculateOffset(make_tuple(iM, iK));
in_elementwise_op(in_thread_val_buf(Number<offset>{}),
in_thread_val_buf(Number<offset>{}));
});
});
static_for<0, MThreadSliceSize, 1>{}([&](auto iM) {
// do element-wise pre-reduction operation
static_for<0, KThreadSliceSize, 1>{}([&](auto iK) {
constexpr auto offset = thread_buffer_desc.CalculateOffset(make_tuple(iM, iK));
ThreadwiseReduceWithIndex::Reduce(
in_thread_val_buf, in_thread_idx_buf, accu_value_buf, accu_index_buf);
in_thread_idx_buf(Number<offset>{}) = indexStart + iK();
threadwise_src_val_load.MoveSrcSliceWindow(in_grid_desc_m_k, in_thread_copy_step);
threadwise_src_idx_load.MoveSrcSliceWindow(in_grid_desc_m_k, in_thread_copy_step);
in_elementwise_op(in_thread_val_buf(Number<offset>{}),
in_thread_val_buf(Number<offset>{}));
indexStart += KThreadSliceSize;
reducedLength += KThreadSliceSize;
} while(reducedLength < toReduceLength);
}
else
{
do
{
threadwise_src_val_load.Run(in_grid_desc_m_k,
in_global_val_buf,
thread_buffer_desc,
make_tuple(I0, I0),
in_thread_val_buf);
static_for<0, MThreadSliceSize, 1>{}([&](auto iM) {
// do element-wise pre-reduction operation
static_for<0, KThreadSliceSize, 1>{}([&](auto iK) {
constexpr auto offset =
thread_buffer_desc.CalculateOffset(make_tuple(iM, iK));
in_thread_idx_buf(Number<offset>{}) = indexStart + iK();
in_elementwise_op(in_thread_val_buf(Number<offset>{}),
in_thread_val_buf(Number<offset>{}));
});
});
});
ThreadwiseReduceWithIndex::Reduce(
in_thread_val_buf, in_thread_idx_buf, accu_value_buf, accu_index_buf);
ThreadwiseReduceWithIndex::Reduce(
in_thread_val_buf, in_thread_idx_buf, accu_value_buf, accu_index_buf);
threadwise_src_load.MoveSrcSliceWindow(in_grid_desc_m_k, in_thread_copy_step);
threadwise_src_val_load.MoveSrcSliceWindow(in_grid_desc_m_k, in_thread_copy_step);
indexStart += KThreadSliceSize;
reducedLength += KThreadSliceSize;
} while(reducedLength < toReduceLength);
indexStart += KThreadSliceSize;
reducedLength += KThreadSliceSize;
} while(reducedLength < toReduceLength);
};
// for indiced operation, acc_elementwise_op shoud do nothing
static_for<0, MThreadSliceSize, 1>{}([&](auto I) {
......@@ -367,36 +424,32 @@ struct GridwiseReduction_mk_to_m_threadwise
constexpr auto reduced_data_desc = ThreadReduceDstDesc_M{};
if constexpr(!BetaIsZero)
if(!float_equal_zero{}(beta))
{
if(!float_equal_zero{}(beta))
{
auto threadwise_dst_load =
ThreadwiseTensorSliceTransfer_v2<OutDataType,
OutDataType,
OutGridDesc_M,
decltype(reduced_data_desc),
Sequence<MThreadSliceSize>,
Sequence<0>,
0,
1,
1,
false>(
out_grid_desc_m, make_multi_index(thread_global_1d_id * MThreadSliceSize));
StaticBuffer<AddressSpaceEnum::Vgpr, OutDataType, MThreadSliceSize, true>
priorDstValue_buf;
threadwise_dst_load.Run(out_grid_desc_m,
out_global_val_buf,
reduced_data_desc,
make_tuple(I0),
priorDstValue_buf);
static_for<0, MThreadSliceSize, 1>{}([&](auto I) {
accu_value_buf(I) += type_convert<AccDataType>(priorDstValue_buf[I]) * beta;
});
};
auto threadwise_dst_load = ThreadwiseTensorSliceTransfer_v2<OutDataType,
OutDataType,
OutGridDesc_M,
decltype(reduced_data_desc),
Sequence<MThreadSliceSize>,
Sequence<0>,
0,
1,
1,
false>(
out_grid_desc_m, make_multi_index(thread_global_1d_id * MThreadSliceSize));
StaticBuffer<AddressSpaceEnum::Vgpr, OutDataType, MThreadSliceSize, true>
priorDstValue_buf;
threadwise_dst_load.Run(out_grid_desc_m,
out_global_val_buf,
reduced_data_desc,
make_tuple(I0),
priorDstValue_buf);
static_for<0, MThreadSliceSize, 1>{}([&](auto I) {
accu_value_buf(I) += type_convert<AccDataType>(priorDstValue_buf[I]) * beta;
});
};
auto threadwise_dst_val_store =
......@@ -409,7 +462,7 @@ struct GridwiseReduction_mk_to_m_threadwise
Sequence<0>,
0,
OutDstVectorSize,
InMemoryDataOperationEnum::Set,
OutMemoryDataOperation,
1,
false>(
out_grid_desc_m,
......@@ -426,7 +479,7 @@ struct GridwiseReduction_mk_to_m_threadwise
Sequence<0>,
0,
OutDstVectorSize,
InMemoryDataOperationEnum::Set,
OutMemoryDataOperation,
1,
false>(
out_grid_desc_m,
......
......@@ -11,138 +11,140 @@ template <typename GridwiseBinEltwise,
typename ADataType,
typename BDataType,
typename CDataType,
typename GridDesc_M0,
typename AGridDesc_M,
typename BGridDesc_M,
typename CGridDesc_M,
typename ElementwiseFunctor>
__global__ void kernel_binary_elementwise_1d(const ADataType* __restrict__ p_a_global,
const BDataType* __restrict__ p_b_global,
CDataType* __restrict__ p_c_global,
const GridDesc_M0 a_grid_desc_m0,
const GridDesc_M0 b_grid_desc_m0,
const GridDesc_M0 c_grid_desc_m0,
const AGridDesc_M a_grid_desc_m,
const BGridDesc_M b_grid_desc_m,
const CGridDesc_M c_grid_desc_m,
const ElementwiseFunctor functor)
{
GridwiseBinEltwise::Run(p_a_global,
p_b_global,
p_c_global,
a_grid_desc_m0,
b_grid_desc_m0,
c_grid_desc_m0,
functor);
GridwiseBinEltwise::Run(
p_a_global, p_b_global, p_c_global, a_grid_desc_m, b_grid_desc_m, c_grid_desc_m, functor);
}
template <typename ADataType,
typename BDataType,
typename CDataType,
typename ComputeDataType,
typename GridDesc_M0,
typename AGridDesc_M,
typename BGridDesc_M,
typename CGridDesc_M,
typename ElementwiseFunctor,
index_t ScalarPerVector>
index_t MPerThread,
index_t AScalarPerVector,
index_t BScalarPerVector,
index_t CScalarPerVector>
struct GridwiseBinaryElementwise_1D
{
static constexpr auto I0 = Number<0>{};
static constexpr auto thread_desc_m0 =
make_naive_tensor_descriptor_packed(make_tuple(Number<ScalarPerVector>{}));
static constexpr auto thread_desc_m =
make_naive_tensor_descriptor_packed(make_tuple(Number<MPerThread>{}));
using PassThrough = tensor_operation::element_wise::PassThrough;
static __device__ auto CalculateElementwiseIndex()
{
const index_t global_thread_id = get_thread_global_1d_id();
return make_multi_index(global_thread_id * ScalarPerVector);
return make_multi_index(global_thread_id * MPerThread);
}
__device__ static void Run(const ADataType* __restrict__ p_a_global,
const BDataType* __restrict__ p_b_global,
CDataType* __restrict__ p_c_global,
const GridDesc_M0 a_grid_desc_m0,
const GridDesc_M0 b_grid_desc_m0,
const GridDesc_M0 c_grid_desc_m0,
const AGridDesc_M a_grid_desc_m,
const BGridDesc_M b_grid_desc_m,
const CGridDesc_M c_grid_desc_m,
const ElementwiseFunctor functor)
{
const auto a_global_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_a_global, a_grid_desc_m0.GetElementSpaceSize());
p_a_global, a_grid_desc_m.GetElementSpaceSize());
const auto b_global_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_b_global, b_grid_desc_m0.GetElementSpaceSize());
p_b_global, b_grid_desc_m.GetElementSpaceSize());
auto c_global_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_c_global, c_grid_desc_m0.GetElementSpaceSize());
p_c_global, c_grid_desc_m.GetElementSpaceSize());
StaticBuffer<AddressSpaceEnum::Vgpr, ComputeDataType, ScalarPerVector, true> a_thread_buf;
StaticBuffer<AddressSpaceEnum::Vgpr, ComputeDataType, ScalarPerVector, true> b_thread_buf;
StaticBuffer<AddressSpaceEnum::Vgpr, ComputeDataType, ScalarPerVector, true> c_thread_buf;
StaticBuffer<AddressSpaceEnum::Vgpr, ComputeDataType, MPerThread, true> a_thread_buf;
StaticBuffer<AddressSpaceEnum::Vgpr, ComputeDataType, MPerThread, true> b_thread_buf;
StaticBuffer<AddressSpaceEnum::Vgpr, ComputeDataType, MPerThread, true> c_thread_buf;
const auto thread_store_global_offset = CalculateElementwiseIndex();
auto a_global_load =
ThreadwiseTensorSliceTransfer_v2<ADataType,
ComputeDataType,
GridDesc_M0,
decltype(thread_desc_m0),
Sequence<ScalarPerVector>, // SliceLengths
Sequence<0>, // DimAccessOrder
0, // SrcVectorDim
ScalarPerVector,
1, // SrcScalarStrideInVector
false>{a_grid_desc_m0, thread_store_global_offset};
AGridDesc_M,
decltype(thread_desc_m),
Sequence<MPerThread>, // SliceLengths
Sequence<0>, // DimAccessOrder
0, // SrcVectorDim
AScalarPerVector, // ScalarPerVector
1, // SrcScalarStrideInVector
false>{a_grid_desc_m, thread_store_global_offset};
auto b_global_load =
ThreadwiseTensorSliceTransfer_v2<BDataType,
ComputeDataType,
GridDesc_M0,
decltype(thread_desc_m0),
Sequence<ScalarPerVector>, // SliceLengths
Sequence<0>, // DimAccessOrder
0, // SrcVectorDim
ScalarPerVector,
1, // SrcScalarStrideInVector
false>{b_grid_desc_m0, thread_store_global_offset};
BGridDesc_M,
decltype(thread_desc_m),
Sequence<MPerThread>, // SliceLengths
Sequence<0>, // DimAccessOrder
0, // SrcVectorDim
BScalarPerVector, // ScalarPerVector
1, // SrcScalarStrideInVector
false>{b_grid_desc_m, thread_store_global_offset};
auto c_global_write =
ThreadwiseTensorSliceTransfer_v1r3<ComputeDataType,
CDataType,
decltype(thread_desc_m0),
GridDesc_M0,
decltype(thread_desc_m),
CGridDesc_M,
PassThrough,
Sequence<ScalarPerVector>, // SliceLengths
Sequence<0>, // DimAccessOrder
0, // DstVectorDim
ScalarPerVector,
Sequence<MPerThread>, // SliceLengths
Sequence<0>, // DimAccessOrder
0, // DstVectorDim
CScalarPerVector, // ScalarPerVector
InMemoryDataOperationEnum::Set,
1, // DstScalarStrideInVector
false>{
c_grid_desc_m0, thread_store_global_offset, PassThrough{}};
c_grid_desc_m, thread_store_global_offset, PassThrough{}};
const index_t blockSize = get_block_size();
const index_t blockPerGrid = get_grid_size();
const auto m0 = c_grid_desc_m0.GetLength(I0);
const index_t loop_step = blockPerGrid * blockSize * ScalarPerVector;
const auto M = c_grid_desc_m.GetLength(I0);
const index_t loop_step = blockPerGrid * blockSize * MPerThread;
const auto loop_step_index = make_multi_index(loop_step);
index_t num_iter = m0 / (loop_step);
index_t num_iter = M / (loop_step);
do
{
// read and process ScalarPerVector elements
// read and process MPerThread elements
a_global_load.Run(
a_grid_desc_m0, a_global_buf, thread_desc_m0, make_tuple(I0), a_thread_buf);
a_grid_desc_m, a_global_buf, thread_desc_m, make_tuple(I0), a_thread_buf);
b_global_load.Run(
b_grid_desc_m0, b_global_buf, thread_desc_m0, make_tuple(I0), b_thread_buf);
b_grid_desc_m, b_global_buf, thread_desc_m, make_tuple(I0), b_thread_buf);
static_for<0, ScalarPerVector, 1>{}([&](auto m) {
constexpr auto offset = thread_desc_m0.CalculateOffset(make_tuple(m));
static_for<0, MPerThread, 1>{}([&](auto m) {
constexpr auto offset = thread_desc_m.CalculateOffset(make_tuple(m));
functor(c_thread_buf(Number<offset>{}),
a_thread_buf(Number<offset>{}),
b_thread_buf(Number<offset>{}));
});
c_global_write.Run(thread_desc_m0,
c_global_write.Run(thread_desc_m,
make_tuple(I0), // SrcSliceOriginIdx
c_thread_buf,
c_grid_desc_m0,
c_grid_desc_m,
c_global_buf);
a_global_load.MoveSrcSliceWindow(a_grid_desc_m0, loop_step_index);
b_global_load.MoveSrcSliceWindow(b_grid_desc_m0, loop_step_index);
c_global_write.MoveDstSliceWindow(c_grid_desc_m0, loop_step_index);
a_global_load.MoveSrcSliceWindow(a_grid_desc_m, loop_step_index);
b_global_load.MoveSrcSliceWindow(b_grid_desc_m, loop_step_index);
c_global_write.MoveDstSliceWindow(c_grid_desc_m, loop_step_index);
} while(--num_iter);
}
};
......
#ifndef CK_GRIDWISE_GEMM_V1R3_HPP
#define CK_GRIDWISE_GEMM_V1R3_HPP
#pragma once
#include "common_header.hpp"
#include "multi_index_transform_helper.hpp"
#include "tensor_descriptor.hpp"
#include "tensor_descriptor_helper.hpp"
#include "blockwise_gemm_dlops_v2r3.hpp"
#include "tensor_operation/gpu/grid/block_to_ctile_map.hpp"
#include "blockwise_gemm_dl_v2r3.hpp"
#include "blockwise_tensor_slice_transfer_v5r1.hpp"
#include "threadwise_tensor_slice_transfer_v2.hpp"
#include "threadwise_tensor_slice_transfer.hpp"
#include "threadwise_tensor_slice_set.hpp"
#include "element_wise_operation.hpp"
namespace ck {
template <typename GridwiseGemm,
typename FloatAB,
typename FloatC,
typename AK0M0M1K1GridDesc,
typename BK0N0N1K1GridDesc,
typename CM0M10M11N0N10N11GridDesc,
typename CBlockIdToM0N0BlockClusterAdaptor,
typename AGridDesc_K0_M0_M1_K1,
typename BGridDesc_K0_N0_N1_K1,
typename CGridDesc_M0_M10_M11_N0_N10_N11,
typename Block2CTileMap,
bool HasMainKBlockLoop,
bool HasDoubleTailKBlockLoop>
__global__ void
#if CK_USE_LAUNCH_BOUNDS
__launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU)
#endif
kernel_gemm_dlops_v1r3(
const FloatAB* __restrict__ p_a_grid,
const FloatAB* __restrict__ p_b_grid,
FloatC* __restrict__ p_c_grid,
const AK0M0M1K1GridDesc a_k0_m0_m1_k1_grid_desc,
const BK0N0N1K1GridDesc b_k0_n0_n1_k1_grid_desc,
const CM0M10M11N0N10N11GridDesc c_m0_m10_m11_n0_n10_n11_grid_desc,
const CBlockIdToM0N0BlockClusterAdaptor cblockid_to_m0_n0_block_cluster_adaptor)
kernel_gemm_dl_v1r3(const FloatAB* __restrict__ p_a_grid,
const FloatAB* __restrict__ p_b_grid,
FloatC* __restrict__ p_c_grid,
const AGridDesc_K0_M0_M1_K1 a_grid_desc_k0_m0_m1_k1,
const BGridDesc_K0_N0_N1_K1 b_grid_desc_k0_n0_n1_k1,
const CGridDesc_M0_M10_M11_N0_N10_N11 c_grid_desc_m0_m10_m11_n0_n10_n11,
const Block2CTileMap block_2_ctile_map)
{
constexpr index_t shared_block_size =
GridwiseGemm::GetSharedMemoryNumberOfByte() / sizeof(FloatAB);
......@@ -43,10 +43,10 @@ __global__ void
p_b_grid,
p_c_grid,
p_shared_block,
a_k0_m0_m1_k1_grid_desc,
b_k0_n0_n1_k1_grid_desc,
c_m0_m10_m11_n0_n10_n11_grid_desc,
cblockid_to_m0_n0_block_cluster_adaptor,
a_grid_desc_k0_m0_m1_k1,
b_grid_desc_k0_n0_n1_k1,
c_grid_desc_m0_m10_m11_n0_n10_n11,
block_2_ctile_map,
integral_constant<bool, HasMainKBlockLoop>{},
integral_constant<bool, HasDoubleTailKBlockLoop>{});
}
......@@ -56,12 +56,12 @@ template <index_t BlockSize,
typename FloatAcc,
typename FloatC,
InMemoryDataOperationEnum CGlobalMemoryDataOperation,
typename AK0MK1GridDesc,
typename BK0NK1GridDesc,
typename CMNGridDesc,
index_t MPerBlockM1,
index_t NPerBlockN1,
index_t KPerBlock,
typename AGridDesc_K0_M_K1,
typename BGridDesc_K0_N_K1,
typename CGridDesc_M_N,
index_t MPerBlock,
index_t NPerBlock,
index_t K0PerBlock,
index_t M1PerThreadM111,
index_t N1PerThreadN111,
index_t KPerThread,
......@@ -83,13 +83,8 @@ template <index_t BlockSize,
typename BBlockTransferDstVectorTensorLengths_K0_N0_N1_K1,
typename CThreadTransferSrcDstAccessOrder,
index_t CThreadTransferSrcDstVectorDim,
index_t CThreadTransferDstScalarPerVector,
typename AGridStepHacks,
typename BGridStepHacks,
typename CGridStepHacks,
typename AGridMoveSliceWindowStepHacks,
typename BGridMoveSliceWindowStepHacks>
struct GridwiseGemmDlops_km_kn_mn_v1r3
index_t CThreadTransferDstScalarPerVector>
struct GridwiseGemmDl_km_kn_mn_v1r3
{
static constexpr auto I0 = Number<0>{};
static constexpr auto I1 = Number<1>{};
......@@ -97,7 +92,7 @@ struct GridwiseGemmDlops_km_kn_mn_v1r3
static constexpr auto I3 = Number<3>{};
// K1 should be Number<...>
static constexpr auto K1 = AK0MK1GridDesc{}.GetLength(I2);
static constexpr auto K1 = AGridDesc_K0_M_K1{}.GetLength(I2);
__host__ __device__ static constexpr index_t GetSharedMemoryNumberOfByte()
{
......@@ -106,112 +101,112 @@ struct GridwiseGemmDlops_km_kn_mn_v1r3
// TODO: check alignment
// A matrix in LDS memory, dst of blockwise copy
constexpr auto a_k_m_block_desc = make_naive_tensor_descriptor_aligned(
make_tuple(Number<KPerBlock>{}, Number<MPerBlockM1>{}, K1), max_lds_align);
constexpr auto a_block_desc_k_m = make_naive_tensor_descriptor_aligned(
make_tuple(Number<K0PerBlock>{}, Number<MPerBlock>{}, K1), max_lds_align);
// TODO: check alignment
// B matrix in LDS memory, dst of blockwise copy
constexpr auto b_k_n_block_desc = make_naive_tensor_descriptor_aligned(
make_tuple(Number<KPerBlock>{}, Number<NPerBlockN1>{}, K1), max_lds_align);
constexpr auto b_block_desc_k_n = make_naive_tensor_descriptor_aligned(
make_tuple(Number<K0PerBlock>{}, Number<NPerBlock>{}, K1), max_lds_align);
// TODO: check alignment
// LDS allocation for A and B: be careful of alignment
constexpr auto a_block_aligned_space_size =
math::integer_least_multiple(a_k_m_block_desc.GetElementSpaceSize(), max_lds_align);
math::integer_least_multiple(a_block_desc_k_m.GetElementSpaceSize(), max_lds_align);
constexpr auto b_block_aligned_space_size =
math::integer_least_multiple(b_k_n_block_desc.GetElementSpaceSize(), max_lds_align);
math::integer_least_multiple(b_block_desc_k_n.GetElementSpaceSize(), max_lds_align);
return 2 * (a_block_aligned_space_size + b_block_aligned_space_size) * sizeof(FloatAB);
}
__host__ __device__ static constexpr bool
CheckValidity(const AK0MK1GridDesc& a_k0_m_k1_grid_desc,
const BK0NK1GridDesc& b_k0_n_k1_grid_desc,
const CMNGridDesc& c_m_n_grid_desc)
CheckValidity(const AGridDesc_K0_M_K1& a_grid_desc_k0_m_k1,
const BGridDesc_K0_N_K1& b_grid_desc_k0_n_k1,
const CGridDesc_M_N& c_grid_desc_m_n)
{
const auto M = a_k0_m_k1_grid_desc.GetLength(I1);
const auto N = b_k0_n_k1_grid_desc.GetLength(I1);
const auto K0 = a_k0_m_k1_grid_desc.GetLength(I0);
const auto M = a_grid_desc_k0_m_k1.GetLength(I1);
const auto N = b_grid_desc_k0_n_k1.GetLength(I1);
const auto K0 = a_grid_desc_k0_m_k1.GetLength(I0);
// TODO: also check validity of all components (blockwise-copy, threadwise-copy, etc)
return (M == c_m_n_grid_desc.GetLength(I0) && N == c_m_n_grid_desc.GetLength(I1) &&
K0 == b_k0_n_k1_grid_desc.GetLength(I0) &&
K1 == a_k0_m_k1_grid_desc.GetLength(I2) &&
K1 == b_k0_n_k1_grid_desc.GetLength(I2)) &&
(M % MPerBlockM1 == 0 && N % NPerBlockN1 == 0 && K0 % KPerBlock == 0);
return (M == c_grid_desc_m_n.GetLength(I0) && N == c_grid_desc_m_n.GetLength(I1) &&
K0 == b_grid_desc_k0_n_k1.GetLength(I0) &&
K1 == a_grid_desc_k0_m_k1.GetLength(I2) &&
K1 == b_grid_desc_k0_n_k1.GetLength(I2)) &&
(M % MPerBlock == 0 && N % NPerBlock == 0 && K0 % K0PerBlock == 0);
}
__host__ __device__ static constexpr index_t CalculateGridSize(index_t M, index_t N)
{
const index_t grid_size = (M / MPerBlockM1) * (N / NPerBlockN1);
const index_t grid_size = (M / MPerBlock) * (N / NPerBlock);
return grid_size;
}
__host__ __device__ static constexpr bool CalculateHasMainKBlockLoop(index_t K0)
{
const bool has_main_k_block_loop = (K0 + KPerBlock) / (2 * KPerBlock) > 1;
const bool has_main_k_block_loop = (K0 + K0PerBlock) / (2 * K0PerBlock) > 1;
return has_main_k_block_loop;
}
__host__ __device__ static constexpr bool CalculateHasDoubleTailKBlockLoop(index_t K0)
{
const bool has_double_tail_k_block_loop = (K0 / KPerBlock) % 2 == 0;
const bool has_double_tail_k_block_loop = (K0 / K0PerBlock) % 2 == 0;
return has_double_tail_k_block_loop;
}
__host__ __device__ static constexpr auto
MakeAK0M0M1K1GridDescriptor(const AK0MK1GridDesc& a_k0_m_k1_grid_desc)
MakeAGridDescriptor_K0_M0_M1_K1(const AGridDesc_K0_M_K1& a_grid_desc_k0_m_k1)
{
const auto K0 = a_k0_m_k1_grid_desc.GetLength(I0);
const auto M = a_k0_m_k1_grid_desc.GetLength(I1);
const auto K0 = a_grid_desc_k0_m_k1.GetLength(I0);
const auto M = a_grid_desc_k0_m_k1.GetLength(I1);
const auto M1 = Number<MPerBlockM1>{};
const auto M1 = Number<MPerBlock>{};
const auto M0 = M / M1;
const auto a_k0_m0_m1_k1_grid_desc =
transform_tensor_descriptor(a_k0_m_k1_grid_desc,
const auto a_grid_desc_k0_m0_m1_k1 =
transform_tensor_descriptor(a_grid_desc_k0_m_k1,
make_tuple(make_pass_through_transform(K0),
make_unmerge_transform(make_tuple(M0, M1)),
make_pass_through_transform(K1)),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}),
make_tuple(Sequence<0>{}, Sequence<1, 2>{}, Sequence<3>{}));
return a_k0_m0_m1_k1_grid_desc;
return a_grid_desc_k0_m0_m1_k1;
}
__host__ __device__ static constexpr auto
MakeBK0N0N1K1GridDescriptor(const BK0NK1GridDesc& b_k0_n_k1_grid_desc)
MakeBGridDescriptor_K0_N0_N1_K1(const BGridDesc_K0_N_K1& b_grid_desc_k0_n_k1)
{
const auto K0 = b_k0_n_k1_grid_desc.GetLength(I0);
const auto N = b_k0_n_k1_grid_desc.GetLength(I1);
const auto K0 = b_grid_desc_k0_n_k1.GetLength(I0);
const auto N = b_grid_desc_k0_n_k1.GetLength(I1);
const auto N1 = Number<NPerBlockN1>{};
const auto N1 = Number<NPerBlock>{};
const auto N0 = N / N1;
const auto b_k0_n0_n1_k1_grid_desc =
transform_tensor_descriptor(b_k0_n_k1_grid_desc,
const auto b_grid_desc_k0_n0_n1_k1 =
transform_tensor_descriptor(b_grid_desc_k0_n_k1,
make_tuple(make_pass_through_transform(K0),
make_unmerge_transform(make_tuple(N0, N1)),
make_pass_through_transform(K1)),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}),
make_tuple(Sequence<0>{}, Sequence<1, 2>{}, Sequence<3>{}));
return b_k0_n0_n1_k1_grid_desc;
return b_grid_desc_k0_n0_n1_k1;
}
__host__ __device__ static constexpr auto
MakeCM0M10M11N0N10N11GridDescriptor(const CMNGridDesc& c_m_n_grid_desc)
MakeCGridDescriptor_M0_M10_M11_N0_N10_N11(const CGridDesc_M_N& c_grid_desc_m_n)
{
const auto M = c_m_n_grid_desc.GetLength(I0);
const auto N = c_m_n_grid_desc.GetLength(I1);
const auto M = c_grid_desc_m_n.GetLength(I0);
const auto N = c_grid_desc_m_n.GetLength(I1);
constexpr auto M1 = Number<MPerBlockM1>{};
constexpr auto N1 = Number<NPerBlockN1>{};
constexpr auto M1 = Number<MPerBlock>{};
constexpr auto N1 = Number<NPerBlock>{};
const auto M0 = M / M1;
const auto N0 = N / N1;
......@@ -226,41 +221,29 @@ struct GridwiseGemmDlops_km_kn_mn_v1r3
constexpr auto M10 = M1 / M11;
constexpr auto N10 = N1 / N11;
const auto c_m0_m10_m11_n0_n10_n11_grid_desc = transform_tensor_descriptor(
c_m_n_grid_desc,
const auto c_grid_desc_m0_m10_m11_n0_n10_n11 = transform_tensor_descriptor(
c_grid_desc_m_n,
make_tuple(make_unmerge_transform(make_tuple(M0, M10, M11)),
make_unmerge_transform(make_tuple(N0, N10, N11))),
make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<0, 1, 2>{}, Sequence<3, 4, 5>{}));
return c_m0_m10_m11_n0_n10_n11_grid_desc;
return c_grid_desc_m0_m10_m11_n0_n10_n11;
}
// return block_id to C matrix tile idx (m0, n0) mapping
__host__ __device__ static constexpr auto
MakeCBlockIdToM0N0BlockClusterAdaptor(const CMNGridDesc& c_m_n_grid_desc)
MakeDefaultBlock2CTileMap(const CGridDesc_M_N& c_grid_desc_m_n)
{
const auto M = c_m_n_grid_desc.GetLength(I0);
const auto N = c_m_n_grid_desc.GetLength(I1);
constexpr auto M1 = Number<MPerBlockM1>{};
constexpr auto N1 = Number<NPerBlockN1>{};
const auto M0 = M / M1;
const auto N0 = N / N1;
const auto cblockid_to_m0_n0_block_cluster_adaptor =
make_single_stage_tensor_adaptor(make_tuple(make_merge_transform(make_tuple(M0, N0))),
make_tuple(Sequence<0, 1>{}),
make_tuple(Sequence<0>{}));
return cblockid_to_m0_n0_block_cluster_adaptor;
return BlockToCTileMap_M00_N00_M01_N01<MPerBlock, NPerBlock, CGridDesc_M_N>(
c_grid_desc_m_n);
}
using AK0M0M1K1GridDesc = decltype(MakeAK0M0M1K1GridDescriptor(AK0MK1GridDesc{}));
using BK0N0N1K1GridDesc = decltype(MakeBK0N0N1K1GridDescriptor(BK0NK1GridDesc{}));
using CM0M10M11N0N10N11GridDesc = decltype(MakeCM0M10M11N0N10N11GridDescriptor(CMNGridDesc{}));
using CBlockIdToM0N0BlockClusterAdaptor =
decltype(MakeCBlockIdToM0N0BlockClusterAdaptor(CMNGridDesc{}));
using AGridDesc_K0_M0_M1_K1 = decltype(MakeAGridDescriptor_K0_M0_M1_K1(AGridDesc_K0_M_K1{}));
using BGridDesc_K0_N0_N1_K1 = decltype(MakeBGridDescriptor_K0_N0_N1_K1(BGridDesc_K0_N_K1{}));
using CGridDesc_M0_M10_M11_N0_N10_N11 =
decltype(MakeCGridDescriptor_M0_M10_M11_N0_N10_N11(CGridDesc_M_N{}));
using Block2CTileMap = decltype(MakeDefaultBlock2CTileMap(CGridDesc_M_N{}));
template <bool HasMainKBlockLoop, bool HasDoubleTailKBlockLoop>
__device__ static void
......@@ -268,57 +251,64 @@ struct GridwiseGemmDlops_km_kn_mn_v1r3
const FloatAB* __restrict__ p_b_grid,
FloatC* __restrict__ p_c_grid,
FloatAB* __restrict__ p_shared_block,
const AK0M0M1K1GridDesc& a_k0_m0_m1_k1_grid_desc,
const BK0N0N1K1GridDesc& b_k0_n0_n1_k1_grid_desc,
const CM0M10M11N0N10N11GridDesc& c_m0_m10_m11_n0_n10_n11_grid_desc,
const CBlockIdToM0N0BlockClusterAdaptor& cblockid_to_m0_n0_block_cluster_adaptor,
const AGridDesc_K0_M0_M1_K1& a_grid_desc_k0_m0_m1_k1,
const BGridDesc_K0_N0_N1_K1& b_grid_desc_k0_n0_n1_k1,
const CGridDesc_M0_M10_M11_N0_N10_N11& c_grid_desc_m0_m10_m11_n0_n10_n11,
const Block2CTileMap& block_2_ctile_map,
integral_constant<bool, HasMainKBlockLoop>,
integral_constant<bool, HasDoubleTailKBlockLoop>)
{
const auto a_global_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_a_grid, a_k0_m0_m1_k1_grid_desc.GetElementSpaceSize());
p_a_grid, a_grid_desc_k0_m0_m1_k1.GetElementSpaceSize());
const auto b_global_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_b_grid, b_k0_n0_n1_k1_grid_desc.GetElementSpaceSize());
p_b_grid, b_grid_desc_k0_n0_n1_k1.GetElementSpaceSize());
auto c_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_c_grid, c_m0_m10_m11_n0_n10_n11_grid_desc.GetElementSpaceSize());
p_c_grid, c_grid_desc_m0_m10_m11_n0_n10_n11.GetElementSpaceSize());
// divide block work by [M, N]
const auto c_m0_n0_block_cluster_idx =
cblockid_to_m0_n0_block_cluster_adaptor.CalculateBottomIndex(
make_multi_index(get_block_1d_id()));
block_2_ctile_map.CalculateBottomIndex(make_multi_index(get_block_1d_id()));
// HACK: this force index data into SGPR
const index_t im0 = __builtin_amdgcn_readfirstlane(c_m0_n0_block_cluster_idx[I0]);
const index_t in0 = __builtin_amdgcn_readfirstlane(c_m0_n0_block_cluster_idx[I1]);
if(!block_2_ctile_map.ValidCTileIndex(
make_tuple(im0, in0),
make_tuple(c_grid_desc_m0_m10_m11_n0_n10_n11.GetLength(I0),
c_grid_desc_m0_m10_m11_n0_n10_n11.GetLength(I3))))
{
return;
}
// TODO: change this. I think it needs multi-dimensional alignment
constexpr auto max_lds_align = K1;
// TODO: check alignment
// A matrix in LDS memory, dst of blockwise copy
// be careful of LDS alignment
constexpr auto a_k0_m0_m1_k1_block_desc = make_naive_tensor_descriptor_aligned(
make_tuple(Number<KPerBlock>{}, I1, Number<MPerBlockM1>{}, K1), max_lds_align);
constexpr auto a_block_desc_k0_m0_m1_k1 = make_naive_tensor_descriptor_aligned(
make_tuple(Number<K0PerBlock>{}, I1, Number<MPerBlock>{}, K1), max_lds_align);
// TODO: check alignment
// B matrix in LDS memory, dst of blockwise copy
// be careful of LDS alignment
constexpr auto b_k0_n0_n1_k1_block_desc = make_naive_tensor_descriptor_aligned(
make_tuple(Number<KPerBlock>{}, I1, Number<NPerBlockN1>{}, K1), max_lds_align);
constexpr auto b_block_desc_k0_n0_n1_k1 = make_naive_tensor_descriptor_aligned(
make_tuple(Number<K0PerBlock>{}, I1, Number<NPerBlock>{}, K1), max_lds_align);
// TODO: check alignment
// A matrix in LDS memory, for blockwise GEMM
constexpr auto a_k0_m_k1_block_desc = make_naive_tensor_descriptor_aligned(
make_tuple(Number<KPerBlock>{}, Number<MPerBlockM1>{}, K1), max_lds_align);
make_tuple(Number<K0PerBlock>{}, Number<MPerBlock>{}, K1), max_lds_align);
// TODO: check alignment
// B matrix in LDS memory, for blockwise GEMM
constexpr auto b_k0_n_k1_block_desc = make_naive_tensor_descriptor_aligned(
make_tuple(Number<KPerBlock>{}, Number<NPerBlockN1>{}, K1), max_lds_align);
make_tuple(Number<K0PerBlock>{}, Number<NPerBlock>{}, K1), max_lds_align);
static_assert(a_k0_m0_m1_k1_block_desc.GetElementSpaceSize() ==
static_assert(a_block_desc_k0_m0_m1_k1.GetElementSpaceSize() ==
a_k0_m_k1_block_desc.GetElementSpaceSize() &&
b_k0_n0_n1_k1_block_desc.GetElementSpaceSize() ==
b_block_desc_k0_n0_n1_k1.GetElementSpaceSize() ==
b_k0_n_k1_block_desc.GetElementSpaceSize() &&
"wrong!");
......@@ -326,14 +316,14 @@ struct GridwiseGemmDlops_km_kn_mn_v1r3
auto a_blockwise_copy = BlockwiseTensorSliceTransfer_v5r1<
BlockSize,
InMemoryDataOperationEnum::Set,
Sequence<KPerBlock, 1, MPerBlockM1, K1.value>,
Sequence<K0PerBlock, 1, MPerBlock, K1.value>,
ABlockTransferThreadSliceLengths_K0_M0_M1_K1,
ABlockTransferThreadClusterLengths_K0_M0_M1_K1,
ABlockTransferThreadClusterArrangeOrder,
FloatAB,
FloatAB,
decltype(a_k0_m0_m1_k1_grid_desc),
decltype(a_k0_m0_m1_k1_block_desc),
remove_reference_t<decltype(a_grid_desc_k0_m0_m1_k1)>,
decltype(a_block_desc_k0_m0_m1_k1),
ABlockTransferSrcAccessOrder,
Sequence<0, 1, 2, 3>,
ABlockTransferSrcVectorTensorLengths_K0_M0_M1_K1, // SrcVectorTensorLengths
......@@ -341,23 +331,23 @@ struct GridwiseGemmDlops_km_kn_mn_v1r3
ABlockTransferSrcVectorTensorContiguousDimOrder, // SrcVectorTensorContiguousDimOrder
Sequence<0, 1, 2, 3>, // DstVectorTensorContiguousDimOrder
false,
true>(a_k0_m0_m1_k1_grid_desc,
true>(a_grid_desc_k0_m0_m1_k1,
make_multi_index(0, im0, 0, 0),
a_k0_m0_m1_k1_block_desc,
a_block_desc_k0_m0_m1_k1,
make_multi_index(0, 0, 0, 0));
// B matrix blockwise copy
auto b_blockwise_copy = BlockwiseTensorSliceTransfer_v5r1<
BlockSize,
InMemoryDataOperationEnum::Set,
Sequence<KPerBlock, 1, NPerBlockN1, K1.value>,
Sequence<K0PerBlock, 1, NPerBlock, K1.value>,
BBlockTransferThreadSliceLengths_K0_N0_N1_K1,
BBlockTransferThreadClusterLengths_K0_N0_N1_K1,
BBlockTransferThreadClusterArrangeOrder,
FloatAB,
FloatAB,
decltype(b_k0_n0_n1_k1_grid_desc),
decltype(b_k0_n0_n1_k1_block_desc),
remove_reference_t<decltype(b_grid_desc_k0_n0_n1_k1)>,
decltype(b_block_desc_k0_n0_n1_k1),
BBlockTransferSrcAccessOrder,
Sequence<0, 1, 2, 3>,
BBlockTransferSrcVectorTensorLengths_K0_N0_N1_K1, // SrcVectorTensorLengths
......@@ -365,19 +355,19 @@ struct GridwiseGemmDlops_km_kn_mn_v1r3
BBlockTransferSrcVectorTensorContiguousDimOrder, // SrcVectorTensorContiguousDimOrder
Sequence<0, 1, 2, 3>, // DstVectorTensorContiguousDimOrder
false,
true>(b_k0_n0_n1_k1_grid_desc,
true>(b_grid_desc_k0_n0_n1_k1,
make_multi_index(0, in0, 0, 0),
b_k0_n0_n1_k1_block_desc,
b_block_desc_k0_n0_n1_k1,
make_multi_index(0, 0, 0, 0));
// GEMM definition
// c_mtx += transpose(a_mtx) * b_mtx
// a_mtx[KPerBlock, MPerBlockM1] is in LDS
// b_mtx[KPerBlocl, NPerBlockN1] is in LDS
// c_mtx[MPerBlockM1, NPerBlockN1] is distributed among threads, and saved in
// a_mtx[K0PerBlock, MPerBlock] is in LDS
// b_mtx[KPerBlocl, NPerBlock] is in LDS
// c_mtx[MPerBlock, NPerBlock] is distributed among threads, and saved in
// register
const auto blockwise_gemm =
BlockwiseGemmDlops_A_BK0_BM_BK1_B_BK0_BN_BK1_C_BM0_BM1_BN0_BN1_pipeline_BM0_2_BN0_2<
BlockwiseGemmDl_A_BK0_BM_BK1_B_BK0_BN_BK1_C_BM0_BM1_BN0_BN1_pipeline_BM0_2_BN0_2<
BlockSize,
FloatAB,
FloatAB,
......@@ -395,58 +385,53 @@ struct GridwiseGemmDlops_km_kn_mn_v1r3
constexpr auto c_m10_m11_n10_n11_thread_tensor_lengths =
decltype(blockwise_gemm)::GetCThreadTensorLengths_BM0_BM1_BN0_BN1();
constexpr auto c_m10_m11_n10_n11_thread_desc = make_naive_tensor_descriptor_packed(
constexpr auto c_thread_desc_m10_m11_n10_n11 = make_naive_tensor_descriptor_packed(
sequence_to_tuple_of_number(c_m10_m11_n10_n11_thread_tensor_lengths));
// LDS allocation for A and B: be careful of alignment
constexpr auto a_block_aligned_space_size = math::integer_least_multiple(
a_k0_m0_m1_k1_block_desc.GetElementSpaceSize(), max_lds_align);
a_block_desc_k0_m0_m1_k1.GetElementSpaceSize(), max_lds_align);
constexpr auto b_block_aligned_space_size = math::integer_least_multiple(
b_k0_n0_n1_k1_block_desc.GetElementSpaceSize(), max_lds_align);
b_block_desc_k0_n0_n1_k1.GetElementSpaceSize(), max_lds_align);
FloatAB* p_a_block_double = p_shared_block;
FloatAB* p_b_block_double = p_shared_block + 2 * a_block_aligned_space_size;
// register allocation for output
auto c_thread_buf = make_static_buffer<AddressSpaceEnum::Vgpr, FloatAcc>(
c_m10_m11_n10_n11_thread_desc.GetElementSpaceSize());
c_thread_desc_m10_m11_n10_n11.GetElementSpaceSize());
ThreadwiseTensorSliceSet_v1<FloatAcc,
decltype(c_m10_m11_n10_n11_thread_desc),
decltype(c_m10_m11_n10_n11_thread_tensor_lengths)>{}
.Run(c_m10_m11_n10_n11_thread_desc,
make_tuple(I0, I0, I0, I0),
c_thread_buf,
FloatAcc{0});
// Initialize C
c_thread_buf.Clear();
constexpr auto a_block_slice_copy_step = make_multi_index(KPerBlock, 0, 0, 0);
constexpr auto b_block_slice_copy_step = make_multi_index(KPerBlock, 0, 0, 0);
constexpr auto a_block_slice_copy_step = make_multi_index(K0PerBlock, 0, 0, 0);
constexpr auto b_block_slice_copy_step = make_multi_index(K0PerBlock, 0, 0, 0);
auto a_block_even_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(
p_a_block_double, a_k0_m0_m1_k1_block_desc.GetElementSpaceSize());
p_a_block_double, a_block_desc_k0_m0_m1_k1.GetElementSpaceSize());
auto b_block_even_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(
p_b_block_double, b_k0_n0_n1_k1_block_desc.GetElementSpaceSize());
p_b_block_double, b_block_desc_k0_n0_n1_k1.GetElementSpaceSize());
auto a_block_odd_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(
p_a_block_double + a_block_aligned_space_size,
a_k0_m0_m1_k1_block_desc.GetElementSpaceSize());
a_block_desc_k0_m0_m1_k1.GetElementSpaceSize());
auto b_block_odd_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(
p_b_block_double + b_block_aligned_space_size,
b_k0_n0_n1_k1_block_desc.GetElementSpaceSize());
b_block_desc_k0_n0_n1_k1.GetElementSpaceSize());
// LDS double buffer: preload data into LDS
{
a_blockwise_copy.RunRead(a_k0_m0_m1_k1_grid_desc, a_global_buf, AGridStepHacks{});
b_blockwise_copy.RunRead(b_k0_n0_n1_k1_grid_desc, b_global_buf, BGridStepHacks{});
a_blockwise_copy.RunRead(a_grid_desc_k0_m0_m1_k1, a_global_buf);
b_blockwise_copy.RunRead(b_grid_desc_k0_n0_n1_k1, b_global_buf);
a_blockwise_copy.RunWrite(a_k0_m0_m1_k1_block_desc, a_block_even_buf);
b_blockwise_copy.RunWrite(b_k0_n0_n1_k1_block_desc, b_block_even_buf);
a_blockwise_copy.RunWrite(a_block_desc_k0_m0_m1_k1, a_block_even_buf);
b_blockwise_copy.RunWrite(b_block_desc_k0_n0_n1_k1, b_block_even_buf);
}
if constexpr(HasMainKBlockLoop)
{
const auto K0 = a_k0_m0_m1_k1_grid_desc.GetLength(I0);
const auto K0 = a_grid_desc_k0_m0_m1_k1.GetLength(I0);
index_t k_block_data_begin = 0;
......@@ -455,82 +440,76 @@ struct GridwiseGemmDlops_km_kn_mn_v1r3
do
{
// even iteration
a_blockwise_copy.MoveSrcSliceWindow(a_k0_m0_m1_k1_grid_desc,
a_block_slice_copy_step,
AGridMoveSliceWindowStepHacks{});
b_blockwise_copy.MoveSrcSliceWindow(b_k0_n0_n1_k1_grid_desc,
b_block_slice_copy_step,
BGridMoveSliceWindowStepHacks{});
__syncthreads();
a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc_k0_m0_m1_k1,
a_block_slice_copy_step);
b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc_k0_n0_n1_k1,
b_block_slice_copy_step);
// LDS doubel buffer: load next data from device mem
a_blockwise_copy.RunRead(a_k0_m0_m1_k1_grid_desc, a_global_buf, AGridStepHacks{});
b_blockwise_copy.RunRead(b_k0_n0_n1_k1_grid_desc, b_global_buf, BGridStepHacks{});
a_blockwise_copy.RunRead(a_grid_desc_k0_m0_m1_k1, a_global_buf);
b_blockwise_copy.RunRead(b_grid_desc_k0_n0_n1_k1, b_global_buf);
block_sync_lds();
// LDS double buffer: GEMM on current data
blockwise_gemm.Run(c_m10_m11_n10_n11_thread_desc,
blockwise_gemm.Run(c_thread_desc_m10_m11_n10_n11,
a_block_even_buf,
b_block_even_buf,
c_thread_buf);
// LDS double buffer: store next data to LDS
a_blockwise_copy.RunWrite(a_k0_m0_m1_k1_block_desc, a_block_odd_buf);
b_blockwise_copy.RunWrite(b_k0_n0_n1_k1_block_desc, b_block_odd_buf);
a_blockwise_copy.RunWrite(a_block_desc_k0_m0_m1_k1, a_block_odd_buf);
b_blockwise_copy.RunWrite(b_block_desc_k0_n0_n1_k1, b_block_odd_buf);
// odd iteration
a_blockwise_copy.MoveSrcSliceWindow(a_k0_m0_m1_k1_grid_desc,
a_block_slice_copy_step,
AGridMoveSliceWindowStepHacks{});
b_blockwise_copy.MoveSrcSliceWindow(b_k0_n0_n1_k1_grid_desc,
b_block_slice_copy_step,
BGridMoveSliceWindowStepHacks{});
__syncthreads();
a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc_k0_m0_m1_k1,
a_block_slice_copy_step);
b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc_k0_n0_n1_k1,
b_block_slice_copy_step);
// LDS doubel buffer: load next data from device mem
a_blockwise_copy.RunRead(a_k0_m0_m1_k1_grid_desc, a_global_buf, AGridStepHacks{});
b_blockwise_copy.RunRead(b_k0_n0_n1_k1_grid_desc, b_global_buf, BGridStepHacks{});
a_blockwise_copy.RunRead(a_grid_desc_k0_m0_m1_k1, a_global_buf);
b_blockwise_copy.RunRead(b_grid_desc_k0_n0_n1_k1, b_global_buf);
block_sync_lds();
// LDS double buffer: GEMM on current data
blockwise_gemm.Run(
c_m10_m11_n10_n11_thread_desc, a_block_odd_buf, b_block_odd_buf, c_thread_buf);
c_thread_desc_m10_m11_n10_n11, a_block_odd_buf, b_block_odd_buf, c_thread_buf);
// LDS double buffer: store next data to LDS
a_blockwise_copy.RunWrite(a_k0_m0_m1_k1_block_desc, a_block_even_buf);
b_blockwise_copy.RunWrite(b_k0_n0_n1_k1_block_desc, b_block_even_buf);
a_blockwise_copy.RunWrite(a_block_desc_k0_m0_m1_k1, a_block_even_buf);
b_blockwise_copy.RunWrite(b_block_desc_k0_n0_n1_k1, b_block_even_buf);
k_block_data_begin += 2 * KPerBlock;
} while(k_block_data_begin < K0 - 2 * KPerBlock);
k_block_data_begin += 2 * K0PerBlock;
} while(k_block_data_begin < K0 - 2 * K0PerBlock);
}
// LDS double buffer: tail
if constexpr(HasDoubleTailKBlockLoop) // if has 2 iteration left
{
a_blockwise_copy.MoveSrcSliceWindow(
a_k0_m0_m1_k1_grid_desc, a_block_slice_copy_step, AGridMoveSliceWindowStepHacks{});
b_blockwise_copy.MoveSrcSliceWindow(
b_k0_n0_n1_k1_grid_desc, b_block_slice_copy_step, BGridMoveSliceWindowStepHacks{});
a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc_k0_m0_m1_k1, a_block_slice_copy_step);
b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc_k0_n0_n1_k1, b_block_slice_copy_step);
__syncthreads();
block_sync_lds();
// LDS double buffer: load last data from device mem
a_blockwise_copy.RunRead(a_k0_m0_m1_k1_grid_desc, a_global_buf, AGridStepHacks{});
b_blockwise_copy.RunRead(b_k0_n0_n1_k1_grid_desc, b_global_buf, BGridStepHacks{});
a_blockwise_copy.RunRead(a_grid_desc_k0_m0_m1_k1, a_global_buf);
b_blockwise_copy.RunRead(b_grid_desc_k0_n0_n1_k1, b_global_buf);
// LDS double buffer: GEMM on 2nd-last data
blockwise_gemm.Run(
c_m10_m11_n10_n11_thread_desc, a_block_even_buf, b_block_even_buf, c_thread_buf);
c_thread_desc_m10_m11_n10_n11, a_block_even_buf, b_block_even_buf, c_thread_buf);
// LDS double buffer: store last data to LDS
a_blockwise_copy.RunWrite(a_k0_m0_m1_k1_block_desc, a_block_odd_buf);
b_blockwise_copy.RunWrite(b_k0_n0_n1_k1_block_desc, b_block_odd_buf);
a_blockwise_copy.RunWrite(a_block_desc_k0_m0_m1_k1, a_block_odd_buf);
b_blockwise_copy.RunWrite(b_block_desc_k0_n0_n1_k1, b_block_odd_buf);
__syncthreads();
block_sync_lds();
// LDS double buffer: GEMM on last data
blockwise_gemm.Run(
c_m10_m11_n10_n11_thread_desc, a_block_odd_buf, b_block_odd_buf, c_thread_buf);
c_thread_desc_m10_m11_n10_n11, a_block_odd_buf, b_block_odd_buf, c_thread_buf);
}
else // if has 1 iteration left
{
......@@ -538,12 +517,12 @@ struct GridwiseGemmDlops_km_kn_mn_v1r3
// LDS double buffer: GEMM on last data
blockwise_gemm.Run(
c_m10_m11_n10_n11_thread_desc, a_block_even_buf, b_block_even_buf, c_thread_buf);
c_thread_desc_m10_m11_n10_n11, a_block_even_buf, b_block_even_buf, c_thread_buf);
}
// output: register to global memory
{
constexpr auto c_m0_m10_m11_n0_n10_n11_thread_desc =
constexpr auto c_thread_desc_m0_m10_m11_n0_n10_n11 =
make_naive_tensor_descriptor_packed(
make_tuple(I1,
Number<c_m10_m11_n10_n11_thread_tensor_lengths[I0]>{},
......@@ -559,8 +538,9 @@ struct GridwiseGemmDlops_km_kn_mn_v1r3
ThreadwiseTensorSliceTransfer_v1r3<
FloatAcc,
FloatC,
decltype(c_m0_m10_m11_n0_n10_n11_thread_desc),
decltype(c_m0_m10_m11_n0_n10_n11_grid_desc),
decltype(c_thread_desc_m0_m10_m11_n0_n10_n11),
decltype(c_grid_desc_m0_m10_m11_n0_n10_n11),
ck::tensor_operation::element_wise::PassThrough,
Sequence<1,
c_m10_m11_n10_n11_thread_tensor_lengths[I0],
c_m10_m11_n10_n11_thread_tensor_lengths[I1],
......@@ -572,22 +552,21 @@ struct GridwiseGemmDlops_km_kn_mn_v1r3
CThreadTransferDstScalarPerVector,
CGlobalMemoryDataOperation,
1,
true>{c_m0_m10_m11_n0_n10_n11_grid_desc,
true>{c_grid_desc_m0_m10_m11_n0_n10_n11,
make_multi_index(im0,
c_m10_m11_n10_n11_thread_origin_idx_on_block[I0],
c_m10_m11_n10_n11_thread_origin_idx_on_block[I1],
in0,
c_m10_m11_n10_n11_thread_origin_idx_on_block[I2],
c_m10_m11_n10_n11_thread_origin_idx_on_block[I3])}
.Run(c_m0_m10_m11_n0_n10_n11_thread_desc,
c_m10_m11_n10_n11_thread_origin_idx_on_block[I3]),
ck::tensor_operation::element_wise::PassThrough{}}
.Run(c_thread_desc_m0_m10_m11_n0_n10_n11,
make_tuple(I0, I0, I0, I0, I0, I0),
c_thread_buf,
c_m0_m10_m11_n0_n10_n11_grid_desc,
c_grid_buf,
CGridStepHacks{});
c_grid_desc_m0_m10_m11_n0_n10_n11,
c_grid_buf);
}
}
};
} // namespace ck
#endif
......@@ -306,7 +306,7 @@ struct GridwiseGemmReduce_k0mk1_k0nk1_mn_xdl_cshuffle_v1
__host__ __device__ static constexpr auto
MakeDefaultBlock2CTileMap(const CGridDesc_M_N& c_grid_desc_m_n)
{
return BlockToCTileMap_M00_N00_M01_N01<MPerBlock, NPerBlock, CGridDesc_M_N>(
return BlockToCTileMap_M00_N0_M01Adapt<MPerBlock, NPerBlock, CGridDesc_M_N>(
c_grid_desc_m_n);
}
......
......@@ -259,7 +259,7 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdl_cshuffle_v1
__host__ __device__ static constexpr auto
MakeDefaultBlock2CTileMap(const CGridDesc_M_N& c_grid_desc_m_n)
{
return BlockToCTileMap_M00_N00_M01_N01<MPerBlock, NPerBlock, CGridDesc_M_N>(
return BlockToCTileMap_M00_N0_M01Adapt<MPerBlock, NPerBlock, CGridDesc_M_N>(
c_grid_desc_m_n);
}
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment