Commit dd6a8de4 authored by Jehandad Khan

Merge branch 'develop' into jd/dev_pkg

parents 0aa899aa abf4bdb9
@@ -17,8 +17,8 @@ namespace device {
template <typename InDataType,
typename AccDataType,
typename OutDataType,
int Rank,
typename ReduceDims,
index_t Rank,
index_t NumReduceDim,
typename ReduceOperation,
typename InElementwiseOperation,
typename AccElementwiseOperation,
@@ -39,13 +39,18 @@ struct DeviceReduceMultiBlockAtomicAdd
static_assert(BlockSize == MThreadClusterSize * KThreadClusterSize,
"Invalid thread cluster size assignments!");
static_assert(((InSrcVectorDim == 0 && MThreadSliceSize % InSrcVectorSize == 0) ||
(InSrcVectorDim == 1 && KThreadSliceSize % InSrcVectorSize == 0)) &&
(MThreadSliceSize % OutDstVectorSize == 0),
"Invalid thread slice sizes and/or vector sizes configuration, please check!");
using IndexDataType = int32_t;
using InvariantDims = decltype(get_invariant_dims<Rank, ReduceDims>());
static constexpr index_t NumInvariantDim = Rank - NumReduceDim;
static constexpr index_t srcDims = Rank;
static constexpr index_t dstDims = (InvariantDims::Size() == 0) ? 1 : InvariantDims::Size();
static constexpr bool reduceAllDims = (InvariantDims::Size() == 0);
static constexpr index_t numSrcDim = Rank;
static constexpr index_t numDstDim = (NumInvariantDim == 0) ? 1 : NumInvariantDim;
static constexpr bool reduceAllDim = (NumInvariantDim == 0);
static constexpr bool support_AtomicAdd =
std::is_same<OutDataType, float>::value || std::is_same<OutDataType, double>::value;
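// Replacing the compile-time ReduceDims sequence with a NumReduceDim count
// relies on shuffle_tensor_dimensions<Rank, NumReduceDim> (used in the
// Argument constructors below) reordering lengths/strides so that the
// invariant dimensions come first and the reduced dimensions fill the last
// NumReduceDim slots. For example (assumed behavior), with Rank = 4,
// inLengths = {16, 8, 32, 64} and reduceDims = {0, 2}:
//   shuffled lengths = {8, 64, 16, 32} // invariant {8, 64} first, reduce {16, 32} last
//   NumInvariantDim  = Rank - NumReduceDim = 4 - 2 = 2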
@@ -62,18 +67,18 @@ struct DeviceReduceMultiBlockAtomicAdd
int blkGroupSize,
int kBlockTileIterations)
{
const auto tupleSrcLengths = make_tuple_from_array(inLengths, Number<srcDims>{});
const auto tupleSrcStrides = make_tuple_from_array(inStrides, Number<srcDims>{});
const auto tupleSrcLengths = make_tuple_from_array(inLengths, Number<numSrcDim>{});
const auto tupleSrcStrides = make_tuple_from_array(inStrides, Number<numSrcDim>{});
const auto inDesc = make_naive_tensor_descriptor(tupleSrcLengths, tupleSrcStrides);
const auto in_grid_desc_m_k = [&]() {
if constexpr(reduceAllDims)
if constexpr(reduceAllDim)
{
const auto one_dim_inDesc = transform_tensor_descriptor(
inDesc,
make_tuple(make_merge_transform(tupleSrcLengths)),
make_tuple(typename arithmetic_sequence_gen<0, srcDims, 1>::type{}),
make_tuple(typename arithmetic_sequence_gen<0, numSrcDim, 1>::type{}),
make_tuple(Sequence<0>{}));
return transform_tensor_descriptor(one_dim_inDesc,
@@ -84,7 +89,10 @@ struct DeviceReduceMultiBlockAtomicAdd
}
else
{
const auto toReduceDimLengths =
using InvariantDims = typename arithmetic_sequence_gen<0, NumInvariantDim, 1>::type;
using ReduceDims = typename arithmetic_sequence_gen<NumInvariantDim, Rank, 1>::type;
const auto reduceDimLengths =
make_tuple_from_array_and_index_seq(inLengths, ReduceDims{});
const auto invariantDimLengths =
make_tuple_from_array_and_index_seq(inLengths, InvariantDims{});
@@ -92,23 +100,24 @@ struct DeviceReduceMultiBlockAtomicAdd
return transform_tensor_descriptor(
inDesc,
make_tuple(make_merge_transform(invariantDimLengths),
make_merge_transform(toReduceDimLengths)),
make_merge_transform(reduceDimLengths)),
make_tuple(InvariantDims{}, ReduceDims{}),
make_tuple(Sequence<0>{}, Sequence<1>{}));
}
}();
const auto outerLen = in_grid_desc_m_k.GetLength(Number<0>{});
const auto innerLen = in_grid_desc_m_k.GetLength(Number<1>{});
const auto invariantLength = in_grid_desc_m_k.GetLength(Number<0>{});
const auto reduceLength = in_grid_desc_m_k.GetLength(Number<1>{});
const int reduceSizePerBlock = K_BlockTileSize * kBlockTileIterations;
const auto inPad_M = math::integer_least_multiple(outerLen, M_BlockTileSize) - outerLen;
const auto inPad_K = reduceSizePerBlock * blkGroupSize - innerLen;
const auto inPad_M =
math::integer_least_multiple(invariantLength, M_BlockTileSize) - invariantLength;
const auto inPad_K = reduceSizePerBlock * blkGroupSize - reduceLength;
auto in_grid_desc_m_k_padded =
transform_tensor_descriptor(in_grid_desc_m_k,
make_tuple(make_right_pad_transform(outerLen, inPad_M),
make_right_pad_transform(innerLen, inPad_K)),
auto in_grid_desc_m_k_padded = transform_tensor_descriptor(
in_grid_desc_m_k,
make_tuple(make_right_pad_transform(invariantLength, inPad_M),
make_right_pad_transform(reduceLength, inPad_K)),
make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}));
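// Worked example of the padding above (values assumed): with
// M_BlockTileSize = 64, K_BlockTileSize = 256, kBlockTileIterations = 2,
// blkGroupSize = 3, invariantLength = 1000 and reduceLength = 1500:
//   inPad_M = integer_least_multiple(1000, 64) - 1000 = 1024 - 1000 = 24
//   reduceSizePerBlock = 256 * 2 = 512
//   inPad_K = 512 * 3 - 1500 = 36
// so the padded M x K view is 1024 x 1536: a whole number of block tiles in M,
// and exactly reduceSizePerBlock columns for each of the 3 block groups in K.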
@@ -118,24 +127,25 @@ struct DeviceReduceMultiBlockAtomicAdd
static auto MakeDst1dDescriptor(const std::vector<int>& outLengths,
const std::vector<int>& outStrides)
{
const auto tupleDstLengths = make_tuple_from_array(outLengths, Number<dstDims>{});
const auto tupleDstStrides = make_tuple_from_array(outStrides, Number<dstDims>{});
const auto tupleDstLengths = make_tuple_from_array(outLengths, Number<numDstDim>{});
const auto tupleDstStrides = make_tuple_from_array(outStrides, Number<numDstDim>{});
auto outDesc = make_naive_tensor_descriptor(tupleDstLengths, tupleDstStrides);
auto out_grid_desc_m = transform_tensor_descriptor(
outDesc,
make_tuple(make_merge_transform(tupleDstLengths)),
make_tuple(typename arithmetic_sequence_gen<0, dstDims, 1>::type{}),
make_tuple(typename arithmetic_sequence_gen<0, numDstDim, 1>::type{}),
make_tuple(Sequence<0>{}));
const auto outerLen = out_grid_desc_m.GetLength(Number<0>{});
const auto invariantLength = out_grid_desc_m.GetLength(Number<0>{});
const auto outPad = math::integer_least_multiple(outerLen, M_BlockTileSize) - outerLen;
const auto outPad =
math::integer_least_multiple(invariantLength, M_BlockTileSize) - invariantLength;
auto out_grid_desc_m_padded =
transform_tensor_descriptor(out_grid_desc_m,
make_tuple(make_right_pad_transform(outerLen, outPad)),
auto out_grid_desc_m_padded = transform_tensor_descriptor(
out_grid_desc_m,
make_tuple(make_right_pad_transform(invariantLength, outPad)),
make_tuple(Sequence<0>{}),
make_tuple(Sequence<0>{}));
return (out_grid_desc_m_padded);
@@ -143,43 +153,44 @@ struct DeviceReduceMultiBlockAtomicAdd
struct Argument : public BaseArgument
{
Argument(const std::vector<int>& inLengths,
const std::vector<int>& inStrides,
const std::vector<int>& outLengths,
const std::vector<int>& outStrides,
Argument(const std::vector<int> inLengths,
const std::vector<int> inStrides,
const std::vector<int> outLengths,
const std::vector<int> outStrides,
const std::vector<int> reduceDims,
float alpha,
float beta,
const InDataType* in_dev,
OutDataType* out_dev,
IndexDataType* out_indices_dev,
AccDataType* workspace_dev,
const InElementwiseOperation& in_elementwise_op,
const AccElementwiseOperation& acc_elementwise_op)
: in_dev_{in_dev}, out_dev_{out_dev}
const InElementwiseOperation in_elementwise_op,
const AccElementwiseOperation acc_elementwise_op)
: outLengths_{outLengths},
outStrides_{outStrides},
in_dev_{in_dev},
out_dev_{out_dev},
in_elementwise_op_{in_elementwise_op},
acc_elementwise_op_{acc_elementwise_op}
{
(void)out_indices_dev;
(void)workspace_dev;
inLengths_ = inLengths;
inStrides_ = inStrides;
outLengths_ = outLengths;
outStrides_ = outStrides;
in_elementwise_op_ = in_elementwise_op;
acc_elementwise_op_ = acc_elementwise_op;
inLengths_ = shuffle_tensor_dimensions<Rank, NumReduceDim>(inLengths, reduceDims);
inStrides_ = shuffle_tensor_dimensions<Rank, NumReduceDim>(inStrides, reduceDims);
alpha_ = static_cast<AccDataType>(alpha);
beta_ = static_cast<OutDataType>(beta);
alpha_ = type_convert<AccDataType>(alpha);
beta_ = type_convert<AccDataType>(beta);
std::tie(invariant_total_length, reduce_total_length) =
get_2d_lengths<Rank, ReduceDims>(inLengths);
get_2d_lengths<Rank, NumReduceDim>(inLengths_);
if constexpr(InvariantDims::Size() == 0)
if constexpr(NumInvariantDim == 0)
invariant_lowest_length = 1;
else
invariant_lowest_length = inLengths[InvariantDims::At(InvariantDims::Size() - 1)];
invariant_lowest_length = inLengths_[NumInvariantDim - 1];
reduce_lowest_length = inLengths[ReduceDims::At(ReduceDims::Size() - 1)];
reduce_lowest_length = inLengths_[Rank - 1];
int iterations = 1;
while(true)
@@ -212,7 +223,7 @@ struct DeviceReduceMultiBlockAtomicAdd
std::vector<int> outStrides_;
AccDataType alpha_;
OutDataType beta_;
AccDataType beta_;
const InDataType* in_dev_;
OutDataType* out_dev_;
@@ -330,18 +341,22 @@ struct DeviceReduceMultiBlockAtomicAdd
if constexpr(InSrcVectorDim == 0)
{
if constexpr(InvariantDims::Size() == 0)
if constexpr(NumInvariantDim == 0)
{
return (false);
if(pArg->inStrides_[InvariantDims::At(InvariantDims::Size() - 1)] != 1)
}
else
{
if(pArg->inStrides_[NumInvariantDim - 1] != 1)
return (false);
if(pArg->invariant_lowest_length % InSrcVectorSize != 0)
return (false);
};
}
else
{
if(pArg->inStrides_[ReduceDims::At(ReduceDims::Size() - 1)] != 1)
if(pArg->inStrides_[Rank - 1] != 1)
return (false);
if(pArg->reduce_lowest_length % InSrcVectorSize != 0)
@@ -367,23 +382,25 @@ struct DeviceReduceMultiBlockAtomicAdd
};
std::unique_ptr<BaseArgument>
MakeArgumentPointer(const std::vector<int>& inLengths,
const std::vector<int>& inStrides,
const std::vector<int>& outLengths,
const std::vector<int>& outStrides,
MakeArgumentPointer(const std::vector<int> inLengths,
const std::vector<int> inStrides,
const std::vector<int> outLengths,
const std::vector<int> outStrides,
const std::vector<int> reduceDims,
float alpha,
float beta,
const void* in_dev,
void* out_dev,
void* out_indices_dev,
void* workspace_dev,
const InElementwiseOperation& in_elementwise_op,
const AccElementwiseOperation& acc_elementwise_op) override
const InElementwiseOperation in_elementwise_op,
const AccElementwiseOperation acc_elementwise_op) override
{
return std::make_unique<Argument>(inLengths,
inStrides,
outLengths,
outStrides,
reduceDims,
alpha,
beta,
static_cast<const InDataType*>(in_dev),
......
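A hedged sketch of how a caller exercises the updated MakeArgumentPointer: the
reduceDims vector is new in this commit; the host-side names
(reduce_op_instance, in_dev, out_dev, the element-wise ops) are assumed for
illustration, and only the parameter list above comes from this diff.

    auto argument_ptr = reduce_op_instance.MakeArgumentPointer(
        {16, 64, 32},      // inLengths, Rank = 3
        {2048, 32, 1},     // inStrides (packed row-major)
        {64},              // outLengths: only dim 1 survives
        {1},               // outStrides
        {0, 2},            // reduceDims, NumReduceDim = 2 (new parameter)
        1.0f,              // alpha
        0.0f,              // beta
        in_dev,            // const void*
        out_dev,           // void*
        nullptr,           // out_indices_dev, ignored by this op
        nullptr,           // workspace_dev, unused by the AtomicAdd variant
        in_elementwise_op,
        acc_elementwise_op);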
@@ -15,8 +15,8 @@ namespace device {
template <typename InDataType,
typename AccDataType,
typename OutDataType,
int Rank,
typename ReduceDims,
index_t Rank,
index_t NumReduceDim,
typename ReduceOperation,
typename InElementwiseOperation,
typename AccElementwiseOperation,
@@ -37,26 +37,35 @@ struct DeviceReduceMultiBlockPartialReduce
static_assert(BlockSize == MThreadClusterSize * KThreadClusterSize,
"Invalid thread cluster size assignments!");
static_assert((InSrcVectorDim == 0 && MThreadSliceSize % InSrcVectorSize == 0) ||
(InSrcVectorDim == 1 && KThreadSliceSize % InSrcVectorSize == 0),
"Invalid thread slice sizes and/or vector sizes configuration, please check!");
static_assert(OutDstVectorSize == 1, "OutDstVectorSize must be 1 for MultiBlockPartialReduce!");
using IndexDataType = int32_t;
using InvariantDims = decltype(get_invariant_dims<Rank, ReduceDims>());
static constexpr index_t NumInvariantDim = Rank - NumReduceDim;
static constexpr index_t srcDims = Rank;
static constexpr index_t dstDims = (InvariantDims::Size() == 0) ? 1 : InvariantDims::Size();
static constexpr bool reduceAllDims = (InvariantDims::Size() == 0);
static constexpr index_t numSrcDim = Rank;
static constexpr index_t numDstDim = (NumInvariantDim == 0) ? 1 : NumInvariantDim;
static constexpr bool reduceAllDim = (NumInvariantDim == 0);
static constexpr int M_BlockTileSize = MThreadClusterSize * MThreadSliceSize;
static constexpr int K_BlockTileSize = KThreadClusterSize * KThreadSliceSize;
size_t GetWorkspaceSizeInBytes(const std::vector<int>& inLengths) override
static constexpr int MaxBlockGroupSize = 256;
long_index_t GetWorkspaceSizeInBytes(const std::vector<int> inLengths,
const std::vector<int> reduceDims) override
{
size_t invariant_total_length;
size_t reduce_total_length;
auto inLengths_ = shuffle_tensor_dimensions<Rank, NumReduceDim>(inLengths, reduceDims);
std::tie(invariant_total_length, reduce_total_length) =
get_2d_lengths<Rank, ReduceDims>(inLengths);
get_2d_lengths<Rank, NumReduceDim>(inLengths_);
int iterations = 1;
while(true)
@@ -64,8 +73,7 @@ struct DeviceReduceMultiBlockPartialReduce
int testBlkGroupSize = (reduce_total_length + (K_BlockTileSize * iterations) - 1) /
(K_BlockTileSize * iterations);
// we want the blkGroupSize to be no more than MaxBlockGroupSize
if(testBlkGroupSize <= 128)
if(testBlkGroupSize <= MaxBlockGroupSize)
break;
iterations++;
@@ -74,11 +82,12 @@ struct DeviceReduceMultiBlockPartialReduce
int blkGroupSize = (reduce_total_length + (K_BlockTileSize * iterations) - 1) /
(K_BlockTileSize * iterations);
size_t workspace_size = invariant_total_length * blkGroupSize;
long_index_t workspace_size = invariant_total_length * blkGroupSize;
size_t wsSizeInBytes =
!NeedIndices ? workspace_size * sizeof(AccDataType)
: workspace_size * (sizeof(AccDataType) + sizeof(int)) + 64 + sizeof(int);
long_index_t wsSizeInBytes =
!NeedIndices
? workspace_size * sizeof(AccDataType)
: workspace_size * (sizeof(AccDataType) + sizeof(int32_t)) + 64 + sizeof(int);
return (wsSizeInBytes);
};
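// Worked instance of the sizing logic above (values assumed):
// K_BlockTileSize = 256, reduce_total_length = 100000,
// invariant_total_length = 1024, AccDataType = float:
//   iterations = 1: blkGroupSize = ceil(100000 / 256) = 391  > MaxBlockGroupSize
//   iterations = 2: blkGroupSize = ceil(100000 / 512) = 196 <= MaxBlockGroupSize
//   workspace_size = 1024 * 196 = 200704 partial values
//   NeedIndices == false: 200704 * sizeof(float) = 802816 bytes
//   NeedIndices == true : 200704 * (4 + 4) + 64 + 4 bytes (values plus int32
//                         indices, plus alignment slack for the index section)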
@@ -90,18 +99,18 @@ struct DeviceReduceMultiBlockPartialReduce
int blkGroupSize,
int kBlockTileIterations)
{
const auto tupleSrcLengths = make_tuple_from_array(inLengths, Number<srcDims>{});
const auto tupleSrcStrides = make_tuple_from_array(inStrides, Number<srcDims>{});
const auto tupleSrcLengths = make_tuple_from_array(inLengths, Number<numSrcDim>{});
const auto tupleSrcStrides = make_tuple_from_array(inStrides, Number<numSrcDim>{});
const auto inDesc = make_naive_tensor_descriptor(tupleSrcLengths, tupleSrcStrides);
const auto in_grid_desc_m_k = [&]() {
if constexpr(reduceAllDims)
if constexpr(reduceAllDim)
{
const auto one_dim_inDesc = transform_tensor_descriptor(
inDesc,
make_tuple(make_merge_transform(tupleSrcLengths)),
make_tuple(typename arithmetic_sequence_gen<0, srcDims, 1>::type{}),
make_tuple(typename arithmetic_sequence_gen<0, numSrcDim, 1>::type{}),
make_tuple(Sequence<0>{}));
return transform_tensor_descriptor(one_dim_inDesc,
@@ -112,7 +121,10 @@ struct DeviceReduceMultiBlockPartialReduce
}
else
{
const auto toReduceDimLengths =
using InvariantDims = typename arithmetic_sequence_gen<0, NumInvariantDim, 1>::type;
using ReduceDims = typename arithmetic_sequence_gen<NumInvariantDim, Rank, 1>::type;
const auto reduceDimLengths =
make_tuple_from_array_and_index_seq(inLengths, ReduceDims{});
const auto invariantDimLengths =
make_tuple_from_array_and_index_seq(inLengths, InvariantDims{});
@@ -120,38 +132,41 @@ struct DeviceReduceMultiBlockPartialReduce
return transform_tensor_descriptor(
inDesc,
make_tuple(make_merge_transform(invariantDimLengths),
make_merge_transform(toReduceDimLengths)),
make_merge_transform(reduceDimLengths)),
make_tuple(InvariantDims{}, ReduceDims{}),
make_tuple(Sequence<0>{}, Sequence<1>{}));
}
}();
const auto outerLen = in_grid_desc_m_k.GetLength(Number<0>{});
const auto innerLen = in_grid_desc_m_k.GetLength(Number<1>{});
const auto invariantLength = in_grid_desc_m_k.GetLength(Number<0>{});
const auto reduceLength = in_grid_desc_m_k.GetLength(Number<1>{});
const int reduceSizePerBlock = K_BlockTileSize * kBlockTileIterations;
const auto inPad_M = math::integer_least_multiple(outerLen, M_BlockTileSize) - outerLen;
const auto inPad_K = reduceSizePerBlock * blkGroupSize - innerLen;
const auto inPad_M =
math::integer_least_multiple(invariantLength, M_BlockTileSize) - invariantLength;
const auto inPad_K = reduceSizePerBlock * blkGroupSize - reduceLength;
auto in_grid_desc_m_k_padded =
transform_tensor_descriptor(in_grid_desc_m_k,
make_tuple(make_right_pad_transform(outerLen, inPad_M),
make_right_pad_transform(innerLen, inPad_K)),
auto in_grid_desc_m_k_padded = transform_tensor_descriptor(
in_grid_desc_m_k,
make_tuple(make_right_pad_transform(invariantLength, inPad_M),
make_right_pad_transform(reduceLength, inPad_K)),
make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}));
return (in_grid_desc_m_k_padded);
};
static auto MakeWorkspace2dDescriptor(int outerLen, int blkGroupSize)
static auto MakeWorkspace2dDescriptor(int invariantLength, int blkGroupSize)
{
auto ws_desc_m_k = make_naive_tensor_descriptor_packed(make_tuple(outerLen, blkGroupSize));
auto ws_desc_m_k =
make_naive_tensor_descriptor_packed(make_tuple(invariantLength, blkGroupSize));
const auto wsPad = math::integer_least_multiple(outerLen, M_BlockTileSize) - outerLen;
const auto wsPad =
math::integer_least_multiple(invariantLength, M_BlockTileSize) - invariantLength;
auto ws_desc_m_k_padded =
transform_tensor_descriptor(ws_desc_m_k,
make_tuple(make_right_pad_transform(outerLen, wsPad),
make_tuple(make_right_pad_transform(invariantLength, wsPad),
make_pass_through_transform(blkGroupSize)),
make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}));
@@ -161,43 +176,43 @@ struct DeviceReduceMultiBlockPartialReduce
struct Argument : public BaseArgument
{
Argument(const std::vector<index_t>& inLengths,
const std::vector<index_t>& inStrides,
const std::vector<index_t>& outLengths,
const std::vector<index_t>& outStrides,
Argument(const std::vector<int> inLengths,
const std::vector<int> inStrides,
const std::vector<int> outLengths,
const std::vector<int> outStrides,
const std::vector<int> reduceDims,
float alpha,
float beta,
const InDataType* in_dev,
OutDataType* out_dev,
IndexDataType* out_indices_dev,
AccDataType* workspace_dev,
const InElementwiseOperation& in_elementwise_op,
const AccElementwiseOperation& acc_elementwise_op)
: in_dev_{in_dev},
const InElementwiseOperation in_elementwise_op,
const AccElementwiseOperation acc_elementwise_op)
: outLengths_{outLengths},
outStrides_{outStrides},
in_dev_{in_dev},
out_dev_{out_dev},
out_indices_dev_{out_indices_dev},
workspace_dev_{workspace_dev}
workspace_dev_{workspace_dev},
in_elementwise_op_{in_elementwise_op},
acc_elementwise_op_{acc_elementwise_op}
{
inLengths_ = inLengths;
inStrides_ = inStrides;
outLengths_ = outLengths;
outStrides_ = outStrides;
in_elementwise_op_ = in_elementwise_op;
acc_elementwise_op_ = acc_elementwise_op;
inLengths_ = shuffle_tensor_dimensions<Rank, NumReduceDim>(inLengths, reduceDims);
inStrides_ = shuffle_tensor_dimensions<Rank, NumReduceDim>(inStrides, reduceDims);
alpha_ = static_cast<AccDataType>(alpha);
beta_ = static_cast<OutDataType>(beta);
alpha_ = type_convert<AccDataType>(alpha);
beta_ = type_convert<AccDataType>(beta);
std::tie(invariant_total_length, reduce_total_length) =
get_2d_lengths<Rank, ReduceDims>(inLengths);
get_2d_lengths<Rank, NumReduceDim>(inLengths_);
if constexpr(InvariantDims::Size() == 0)
if constexpr(NumInvariantDim == 0)
invariant_lowest_length = 1;
else
invariant_lowest_length = inLengths[InvariantDims::At(InvariantDims::Size() - 1)];
invariant_lowest_length = inLengths_[NumInvariantDim - 1];
reduce_lowest_length = inLengths[ReduceDims::At(ReduceDims::Size() - 1)];
reduce_lowest_length = inLengths_[Rank - 1];
int iterations = 1;
while(true)
@@ -205,8 +220,7 @@ struct DeviceReduceMultiBlockPartialReduce
int testBlkGroupSize = (reduce_total_length + (K_BlockTileSize * iterations) - 1) /
(K_BlockTileSize * iterations);
// we want the blkGroupSize to be no more than MaxBlockGroupSize
if(testBlkGroupSize <= 128)
if(testBlkGroupSize <= MaxBlockGroupSize)
break;
iterations++;
@@ -236,7 +250,7 @@ struct DeviceReduceMultiBlockPartialReduce
std::vector<int> outStrides_;
AccDataType alpha_;
OutDataType beta_;
AccDataType beta_;
const InDataType* in_dev_;
OutDataType* out_dev_;
@@ -334,18 +348,22 @@ struct DeviceReduceMultiBlockPartialReduce
if constexpr(InSrcVectorDim == 0)
{
if constexpr(InvariantDims::Size() == 0)
if constexpr(NumInvariantDim == 0)
{
return (false);
if(pArg->inStrides_[InvariantDims::At(InvariantDims::Size() - 1)] != 1)
}
else
{
if(pArg->inStrides_[NumInvariantDim - 1] != 1)
return (false);
if(pArg->invariant_lowest_length % InSrcVectorSize != 0)
return (false);
};
}
else
{
if(pArg->inStrides_[ReduceDims::At(ReduceDims::Size() - 1)] != 1)
if(pArg->inStrides_[Rank - 1] != 1)
return (false);
if(pArg->reduce_lowest_length % InSrcVectorSize != 0)
@@ -368,23 +386,25 @@ struct DeviceReduceMultiBlockPartialReduce
};
std::unique_ptr<BaseArgument>
MakeArgumentPointer(const std::vector<int>& inLengths,
const std::vector<int>& inStrides,
const std::vector<int>& outLengths,
const std::vector<int>& outStrides,
MakeArgumentPointer(const std::vector<int> inLengths,
const std::vector<int> inStrides,
const std::vector<int> outLengths,
const std::vector<int> outStrides,
const std::vector<int> reduceDims,
float alpha,
float beta,
const void* in_dev,
void* out_dev,
void* out_indices_dev,
void* workspace_dev,
const InElementwiseOperation& in_elementwise_op,
const AccElementwiseOperation& acc_elementwise_op) override
const InElementwiseOperation in_elementwise_op,
const AccElementwiseOperation acc_elementwise_op) override
{
return std::make_unique<Argument>(inLengths,
inStrides,
outLengths,
outStrides,
reduceDims,
alpha,
beta,
static_cast<const InDataType*>(in_dev),
......
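For orientation, the partial-reduce op is the first half of a two-pass scheme;
a rough sketch of the intended data flow, with kernel_reduce_blockwise_second_call
(further down in this commit) as the assumed second pass:

    // pass 1 (this op): each of the blkGroupSize blocks assigned to a row m
    //   reduces its K-slice and writes one partial result to the workspace,
    //   laid out by MakeWorkspace2dDescriptor as an M x blkGroupSize matrix
    //   (plus a parallel int32 index matrix when NeedIndices is true);
    // pass 2: kernel_reduce_blockwise_second_call reduces each workspace row
    //   workspace[m][0 .. blkGroupSize) down to the final out[m].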
@@ -16,7 +16,7 @@ template <typename InDataType,
typename AccDataType,
typename OutDataType,
index_t Rank,
typename ReduceDims,
index_t NumReduceDim,
typename ReduceOperation,
typename InElementwiseOperation,
typename OutElementwiseOperation,
@@ -36,15 +36,20 @@ struct DeviceReduceThreadWise : public DeviceReduce<InElementwiseOperation, OutE
static_assert((BlockSize == MThreadClusterSize) && (KThreadClusterSize == 1),
"Threadwise can only be called with KThreadClusterSize be 1 !");
static_assert(((InSrcVectorDim == 0 && MThreadSliceSize % InSrcVectorSize == 0) ||
(InSrcVectorDim == 1 && KThreadSliceSize % InSrcVectorSize == 0)) &&
(MThreadSliceSize % OutDstVectorSize == 0),
"Invalid thread slice sizes and/or vector sizes configuration, please check!");
using IndexDataType = int32_t;
static constexpr bool BetaIsZero = NeedIndices;
using InvariantDims = decltype(get_invariant_dims<Rank, ReduceDims>());
static constexpr index_t NumInvariantDim = Rank - NumReduceDim;
static constexpr index_t srcDims = Rank;
static constexpr index_t dstDims = (InvariantDims::Size() == 0) ? 1 : InvariantDims::Size();
static constexpr bool reduceAllDims = (InvariantDims::Size() == 0);
static constexpr index_t numSrcDim = Rank;
static constexpr index_t numDstDim = (NumInvariantDim == 0) ? 1 : NumInvariantDim;
static constexpr bool reduceAllDim = (NumInvariantDim == 0);
static constexpr int M_BlockTileSize = MThreadClusterSize * MThreadSliceSize;
static constexpr int K_BlockTileSize = KThreadClusterSize * KThreadSliceSize;
@@ -52,18 +57,18 @@ struct DeviceReduceThreadWise : public DeviceReduce<InElementwiseOperation, OutE
static auto MakeSrc2dDescriptor(const std::vector<int>& inLengths,
const std::vector<int>& inStrides)
{
const auto tupleSrcLengths = make_tuple_from_array(inLengths, Number<srcDims>{});
const auto tupleSrcStrides = make_tuple_from_array(inStrides, Number<srcDims>{});
const auto tupleSrcLengths = make_tuple_from_array(inLengths, Number<numSrcDim>{});
const auto tupleSrcStrides = make_tuple_from_array(inStrides, Number<numSrcDim>{});
const auto inDesc = make_naive_tensor_descriptor(tupleSrcLengths, tupleSrcStrides);
const auto in_grid_desc_m_k = [&]() {
if constexpr(reduceAllDims)
if constexpr(reduceAllDim)
{
const auto one_dim_inDesc = transform_tensor_descriptor(
inDesc,
make_tuple(make_merge_transform(tupleSrcLengths)),
make_tuple(typename arithmetic_sequence_gen<0, srcDims, 1>::type{}),
make_tuple(typename arithmetic_sequence_gen<0, numSrcDim, 1>::type{}),
make_tuple(Sequence<0>{}));
return transform_tensor_descriptor(one_dim_inDesc,
@@ -74,7 +79,10 @@ struct DeviceReduceThreadWise : public DeviceReduce<InElementwiseOperation, OutE
}
else
{
const auto toReduceDimLengths =
using InvariantDims = typename arithmetic_sequence_gen<0, NumInvariantDim, 1>::type;
using ReduceDims = typename arithmetic_sequence_gen<NumInvariantDim, Rank, 1>::type;
const auto reduceDimLengths =
make_tuple_from_array_and_index_seq(inLengths, ReduceDims{});
const auto invariantDimLengths =
make_tuple_from_array_and_index_seq(inLengths, InvariantDims{});
@@ -82,22 +90,24 @@ struct DeviceReduceThreadWise : public DeviceReduce<InElementwiseOperation, OutE
return transform_tensor_descriptor(
inDesc,
make_tuple(make_merge_transform(invariantDimLengths),
make_merge_transform(toReduceDimLengths)),
make_merge_transform(reduceDimLengths)),
make_tuple(InvariantDims{}, ReduceDims{}),
make_tuple(Sequence<0>{}, Sequence<1>{}));
}
}();
const auto outerLen = in_grid_desc_m_k.GetLength(Number<0>{});
const auto innerLen = in_grid_desc_m_k.GetLength(Number<1>{});
const auto invariantLength = in_grid_desc_m_k.GetLength(Number<0>{});
const auto reduceLength = in_grid_desc_m_k.GetLength(Number<1>{});
const auto inPad_M = math::integer_least_multiple(outerLen, M_BlockTileSize) - outerLen;
const auto inPad_K = math::integer_least_multiple(innerLen, K_BlockTileSize) - innerLen;
const auto inPad_M =
math::integer_least_multiple(invariantLength, M_BlockTileSize) - invariantLength;
const auto inPad_K =
math::integer_least_multiple(reduceLength, K_BlockTileSize) - reduceLength;
auto in_grid_desc_m_k_padded =
transform_tensor_descriptor(in_grid_desc_m_k,
make_tuple(make_right_pad_transform(outerLen, inPad_M),
make_right_pad_transform(innerLen, inPad_K)),
auto in_grid_desc_m_k_padded = transform_tensor_descriptor(
in_grid_desc_m_k,
make_tuple(make_right_pad_transform(invariantLength, inPad_M),
make_right_pad_transform(reduceLength, inPad_K)),
make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}));
@@ -107,24 +117,25 @@ struct DeviceReduceThreadWise : public DeviceReduce<InElementwiseOperation, OutE
static auto MakeDst1dDescriptor(const std::vector<int>& outLengths,
const std::vector<int>& outStrides)
{
const auto tupleDstLengths = make_tuple_from_array(outLengths, Number<dstDims>{});
const auto tupleDstStrides = make_tuple_from_array(outStrides, Number<dstDims>{});
const auto tupleDstLengths = make_tuple_from_array(outLengths, Number<numDstDim>{});
const auto tupleDstStrides = make_tuple_from_array(outStrides, Number<numDstDim>{});
auto outDesc = make_naive_tensor_descriptor(tupleDstLengths, tupleDstStrides);
auto out_grid_desc_m = transform_tensor_descriptor(
outDesc,
make_tuple(make_merge_transform(tupleDstLengths)),
make_tuple(typename arithmetic_sequence_gen<0, dstDims, 1>::type{}),
make_tuple(typename arithmetic_sequence_gen<0, numDstDim, 1>::type{}),
make_tuple(Sequence<0>{}));
const auto outerLen = out_grid_desc_m.GetLength(Number<0>{});
const auto invariantLength = out_grid_desc_m.GetLength(Number<0>{});
const auto outPad = math::integer_least_multiple(outerLen, M_BlockTileSize) - outerLen;
const auto outPad =
math::integer_least_multiple(invariantLength, M_BlockTileSize) - invariantLength;
auto out_grid_desc_m_padded =
transform_tensor_descriptor(out_grid_desc_m,
make_tuple(make_right_pad_transform(outerLen, outPad)),
auto out_grid_desc_m_padded = transform_tensor_descriptor(
out_grid_desc_m,
make_tuple(make_right_pad_transform(invariantLength, outPad)),
make_tuple(Sequence<0>{}),
make_tuple(Sequence<0>{}));
return (out_grid_desc_m_padded);
@@ -132,42 +143,45 @@ struct DeviceReduceThreadWise : public DeviceReduce<InElementwiseOperation, OutE
struct Argument : public BaseArgument
{
Argument(const std::vector<int>& inLengths,
const std::vector<int>& inStrides,
const std::vector<int>& outLengths,
const std::vector<int>& outStrides,
Argument(const std::vector<int> inLengths,
const std::vector<int> inStrides,
const std::vector<int> outLengths,
const std::vector<int> outStrides,
const std::vector<int> reduceDims,
float alpha,
float beta,
const InDataType* in_dev,
OutDataType* out_dev,
IndexDataType* out_indices_dev,
AccDataType* workspace_dev,
const InElementwiseOperation& in_elementwise_op,
const OutElementwiseOperation& acc_elementwise_op)
: in_dev_{in_dev}, out_dev_{out_dev}, out_indices_dev_{out_indices_dev}
const InElementwiseOperation in_elementwise_op,
const OutElementwiseOperation acc_elementwise_op)
: outLengths_{outLengths},
outStrides_{outStrides},
in_dev_{in_dev},
out_dev_{out_dev},
out_indices_dev_{out_indices_dev},
in_elementwise_op_{in_elementwise_op},
acc_elementwise_op_{acc_elementwise_op}
{
(void)workspace_dev;
inLengths_ = inLengths;
inStrides_ = inStrides;
outLengths_ = outLengths;
outStrides_ = outStrides;
in_elementwise_op_ = in_elementwise_op;
acc_elementwise_op_ = acc_elementwise_op;
inLengths_ = shuffle_tensor_dimensions<Rank, NumReduceDim>(inLengths, reduceDims);
inStrides_ = shuffle_tensor_dimensions<Rank, NumReduceDim>(inStrides, reduceDims);
alpha_ = static_cast<AccDataType>(alpha);
beta_ = static_cast<OutDataType>(beta);
alpha_ = type_convert<AccDataType>(alpha);
beta_ = type_convert<AccDataType>(beta);
std::tie(invariant_total_length, reduce_total_length) =
get_2d_lengths<Rank, ReduceDims>(inLengths);
get_2d_lengths<Rank, NumReduceDim>(inLengths_);
if constexpr(InvariantDims::Size() == 0)
if constexpr(NumInvariantDim == 0)
invariant_lowest_length = 1;
else
invariant_lowest_length = inLengths[InvariantDims::At(InvariantDims::Size() - 1)];
invariant_lowest_length = inLengths_[NumInvariantDim - 1];
reduce_lowest_length = inLengths[ReduceDims::At(ReduceDims::Size() - 1)];
reduce_lowest_length = inLengths_[Rank - 1];
gridSize = math::integer_least_multiple(invariant_total_length, M_BlockTileSize) /
M_BlockTileSize;
@@ -179,7 +193,7 @@ struct DeviceReduceThreadWise : public DeviceReduce<InElementwiseOperation, OutE
std::vector<int> outStrides_;
AccDataType alpha_;
OutDataType beta_;
AccDataType beta_;
const InDataType* in_dev_;
OutDataType* out_dev_;
@@ -272,18 +286,22 @@ struct DeviceReduceThreadWise : public DeviceReduce<InElementwiseOperation, OutE
if constexpr(InSrcVectorDim == 0)
{
if constexpr(InvariantDims::Size() == 0)
if constexpr(NumInvariantDim == 0)
{
return (false);
if(pArg->inStrides_[InvariantDims::At(InvariantDims::Size() - 1)] != 1)
}
else
{
if(pArg->inStrides_[NumInvariantDim - 1] != 1)
return (false);
if(pArg->invariant_lowest_length % InSrcVectorSize != 0)
return (false);
};
}
else
{
if(pArg->inStrides_[ReduceDims::At(ReduceDims::Size() - 1)] != 1)
if(pArg->inStrides_[Rank - 1] != 1)
return (false);
if(pArg->reduce_lowest_length % InSrcVectorSize != 0)
@@ -304,23 +322,25 @@ struct DeviceReduceThreadWise : public DeviceReduce<InElementwiseOperation, OutE
};
std::unique_ptr<BaseArgument>
MakeArgumentPointer(const std::vector<int>& inLengths,
const std::vector<int>& inStrides,
const std::vector<int>& outLengths,
const std::vector<int>& outStrides,
MakeArgumentPointer(const std::vector<int> inLengths,
const std::vector<int> inStrides,
const std::vector<int> outLengths,
const std::vector<int> outStrides,
const std::vector<int> reduceDims,
float alpha,
float beta,
const void* in_dev,
void* out_dev,
void* out_indices_dev,
void* workspace_dev,
const InElementwiseOperation& in_elementwise_op,
const OutElementwiseOperation& acc_elementwise_op) override
const InElementwiseOperation in_elementwise_op,
const OutElementwiseOperation acc_elementwise_op) override
{
return std::make_unique<Argument>(inLengths,
inStrides,
outLengths,
outStrides,
reduceDims,
alpha,
beta,
static_cast<const InDataType*>(in_dev),
......
@@ -5,10 +5,16 @@ namespace ck {
namespace tensor_operation {
namespace device {
enum GemmSpecialization_t
enum struct GemmSpecialization
{
Default,
MPadding,
NPadding,
KPadding,
MNPadding,
MKPadding,
NKPadding,
MNKPadding,
};
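// Since the enum is now scoped (enum struct) and drops the _t suffix,
// enumerators require qualification and no longer convert implicitly to int;
// a minimal usage sketch:
//   GemmSpecialization spec = GemmSpecialization::MNKPadding;
//   int tag = static_cast<int>(spec); // explicit conversion now required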
} // namespace device
......
@@ -37,11 +37,11 @@ namespace ck {
// The boolean member "indexable" is also provided in reduce_binary_operator for
// easier checking by the upper-layer code in the kernels.
template <typename T, ReduceTensorOp_t Op>
template <typename T, ReduceTensorOp Op>
struct reduce_binary_operator;
template <typename T>
struct reduce_binary_operator<T, ReduceTensorOp_t::ADD>
struct reduce_binary_operator<T, ReduceTensorOp::ADD>
{
using opType = reduce::Add<T>;
using dataType = T;
@@ -50,7 +50,7 @@ struct reduce_binary_operator<T, ReduceTensorOp_t::ADD>
};
template <typename T>
struct reduce_binary_operator<T, ReduceTensorOp_t::MUL>
struct reduce_binary_operator<T, ReduceTensorOp::MUL>
{
using opType = reduce::Mul<T>;
using dataType = T;
@@ -59,7 +59,7 @@ struct reduce_binary_operator<T, ReduceTensorOp_t::MUL>
};
template <typename T>
struct reduce_binary_operator<T, ReduceTensorOp_t::MIN>
struct reduce_binary_operator<T, ReduceTensorOp::MIN>
{
using opType = reduce::Min<T>;
using dataType = T;
@@ -68,7 +68,7 @@ struct reduce_binary_operator<T, ReduceTensorOp_t::MIN>
};
template <typename T>
struct reduce_binary_operator<T, ReduceTensorOp_t::MAX>
struct reduce_binary_operator<T, ReduceTensorOp::MAX>
{
using opType = reduce::Max<T>;
using dataType = T;
@@ -77,7 +77,7 @@ struct reduce_binary_operator<T, ReduceTensorOp_t::MAX>
};
template <typename T>
struct reduce_binary_operator<T, ReduceTensorOp_t::AMAX>
struct reduce_binary_operator<T, ReduceTensorOp::AMAX>
{
using opType = reduce::AMax<T>;
using dataType = T;
@@ -86,7 +86,7 @@ struct reduce_binary_operator<T, ReduceTensorOp_t::AMAX>
};
template <typename T>
struct reduce_binary_operator<T, ReduceTensorOp_t::AVG>
struct reduce_binary_operator<T, ReduceTensorOp::AVG>
{
using opType = reduce::Add<T>;
using dataType = T;
@@ -95,7 +95,7 @@ struct reduce_binary_operator<T, ReduceTensorOp_t::AVG>
};
template <typename T>
struct reduce_binary_operator<T, ReduceTensorOp_t::NORM1>
struct reduce_binary_operator<T, ReduceTensorOp::NORM1>
{
using opType = reduce::Add<T>;
using dataType = T;
@@ -104,7 +104,7 @@ struct reduce_binary_operator<T, ReduceTensorOp_t::NORM1>
};
template <typename T>
struct reduce_binary_operator<T, ReduceTensorOp_t::NORM2>
struct reduce_binary_operator<T, ReduceTensorOp::NORM2>
{
using opType = reduce::Add<T>;
using dataType = T;
@@ -115,7 +115,7 @@ struct reduce_binary_operator<T, ReduceTensorOp_t::NORM2>
// The templated struct reduce_unary_operator maps the enum Ids of Reduce operators to two unary
// functor classes.
// The two unary functors are called before and after the Reduction is executed, respectively.
template <typename T, ReduceTensorOp_t Op, bool IsFirstReduce, bool IsLastReduce>
template <typename T, ReduceTensorOp Op, bool IsFirstReduce, bool IsLastReduce>
struct reduce_unary_operator
{
using InElementwiseOperation = tensor_operation::element_wise::UnaryIdentic<T, T>;
@@ -123,42 +123,42 @@ struct reduce_unary_operator
};
template <typename T, bool IsFirstReduce>
struct reduce_unary_operator<T, ReduceTensorOp_t::AVG, IsFirstReduce, true>
struct reduce_unary_operator<T, ReduceTensorOp::AVG, IsFirstReduce, true>
{
using InElementwiseOperation = tensor_operation::element_wise::UnaryIdentic<T, T>;
using AccElementwiseOperation = tensor_operation::element_wise::UnaryIdentic<T, T, true>;
};
template <typename T, bool IsLastReduce>
struct reduce_unary_operator<T, ReduceTensorOp_t::NORM1, true, IsLastReduce>
struct reduce_unary_operator<T, ReduceTensorOp::NORM1, true, IsLastReduce>
{
using InElementwiseOperation = tensor_operation::element_wise::UnaryAbs<T, T>;
using AccElementwiseOperation = tensor_operation::element_wise::UnaryIdentic<T, T>;
};
template <typename T, bool IsLastReduce>
struct reduce_unary_operator<T, ReduceTensorOp_t::AMAX, true, IsLastReduce>
struct reduce_unary_operator<T, ReduceTensorOp::AMAX, true, IsLastReduce>
{
using InElementwiseOperation = tensor_operation::element_wise::UnaryAbs<T, T>;
using AccElementwiseOperation = tensor_operation::element_wise::UnaryIdentic<T, T>;
};
template <typename T>
struct reduce_unary_operator<T, ReduceTensorOp_t::NORM2, true, false>
struct reduce_unary_operator<T, ReduceTensorOp::NORM2, true, false>
{
using InElementwiseOperation = tensor_operation::element_wise::UnarySquare<T, T>;
using AccElementwiseOperation = tensor_operation::element_wise::UnaryIdentic<T, T>;
};
template <typename T>
struct reduce_unary_operator<T, ReduceTensorOp_t::NORM2, true, true>
struct reduce_unary_operator<T, ReduceTensorOp::NORM2, true, true>
{
using InElementwiseOperation = tensor_operation::element_wise::UnarySquare<T, T>;
using AccElementwiseOperation = tensor_operation::element_wise::UnarySqrt<T, T>;
};
template <typename T>
struct reduce_unary_operator<T, ReduceTensorOp_t::NORM2, false, true>
struct reduce_unary_operator<T, ReduceTensorOp::NORM2, false, true>
{
using InElementwiseOperation = tensor_operation::element_wise::UnaryIdentic<T, T>;
using AccElementwiseOperation = tensor_operation::element_wise::UnarySqrt<T, T>;
......
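Taken together, the two maps let instantiation code derive the reduction and
the pre/post element-wise functors from a single ReduceTensorOp value; for
example, a single-pass NORM2 (IsFirstReduce and IsLastReduce both true)
resolves to:

    using Binary = reduce_binary_operator<float, ReduceTensorOp::NORM2>;
    using Unary  = reduce_unary_operator<float, ReduceTensorOp::NORM2, true, true>;

    using ReduceOperation         = typename Binary::opType;                 // reduce::Add<float>
    using InElementwiseOperation  = typename Unary::InElementwiseOperation;  // UnarySquare: x -> x * x
    using AccElementwiseOperation = typename Unary::AccElementwiseOperation; // UnarySqrt: acc -> sqrt(acc)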
#ifndef TENSOR_LAYOUT_HPP
#define TENSOR_LAYOUT_HPP
#pragma once
namespace ck {
namespace tensor_layout {
@@ -85,6 +84,7 @@ struct NKHW : public BaseTensorLayout
static constexpr const char* name = "NKHW";
};
// 3D Conv
struct NDHWC : public BaseTensorLayout
{
static constexpr const char* name = "NDHWC";
@@ -99,6 +99,20 @@ struct NDHWK : public BaseTensorLayout
{
static constexpr const char* name = "NDHWK";
};
struct NCDHW : public BaseTensorLayout
{
static constexpr const char* name = "NCDHW";
};
struct KCZYX : public BaseTensorLayout
{
static constexpr const char* name = "KCZYX";
};
struct NKDHW : public BaseTensorLayout
{
static constexpr const char* name = "NKDHW";
};
} // namespace convolution
@@ -113,4 +127,3 @@ std::ostream& operator<<(std::ostream& os, const Layout&)
} // namespace tensor_layout
} // namespace ck
#endif
#ifndef CK_ELEMENT_WISE_OPERATION_HPP
#define CK_ELEMENT_WISE_OPERATION_HPP
#include "data_type.hpp"
#pragma once
#include "data_type.hpp"
namespace ck {
@@ -19,6 +16,8 @@ struct PassThrough
__host__ __device__ void operator()(int32_t& y, const int32_t& x) const { y = x; }
__host__ __device__ void operator()(int8_t& y, const int8_t& x) const { y = x; }
__host__ __device__ void operator()(double& y, const double& x) const { y = x; }
};
struct Add
@@ -239,6 +238,24 @@ struct UnaryIdentic<int32_t, int32_t, false>
__host__ __device__ void operator()(int32_t& y, const int32_t& x) const { y = x; };
};
template <>
struct UnaryIdentic<int32_t, int32_t, true>
{
__host__ __device__ UnaryIdentic(const int32_t divider = 1) { divider_ = divider; };
__host__ __device__ void operator()(int32_t& y, const int32_t& x) const { y = x / divider_; };
int32_t divider_ = 1;
};
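// The HasDividing = true form exists for AVG-style reductions, where the
// accumulated sum is finalized by dividing by the reduce length (cf.
// reduce_unary_operator<T, ReduceTensorOp::AVG, ..., true> earlier in this
// commit, which selects UnaryIdentic<T, T, true>):
//   UnaryIdentic<int32_t, int32_t, true> avg_finalize{4}; // divider = 4
//   int32_t y;
//   avg_finalize(y, 10); // y = 10 / 4 = 2 (integer division)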
template <>
struct UnaryIdentic<int8_t, int8_t, false>
{
__host__ __device__ UnaryIdentic(const int8_t divider = 1) { (void)divider; };
__host__ __device__ void operator()(int8_t& y, const int8_t& x) const { y = x; };
};
template <typename Y, typename X, bool HasDividing = false>
struct UnarySquare;
@@ -311,6 +328,19 @@ struct UnaryAbs<double, double>
__host__ __device__ void operator()(double& y, const double& x) const { y = abs(x); };
};
template <>
struct UnaryAbs<int8_t, int8_t>
{
__host__ __device__ UnaryAbs(const int32_t divider = 1) { (void)divider; };
__host__ __device__ void operator()(int8_t& y, const int8_t& x) const
{
int8_t sgn = x >> (8 - 1);
y = (x ^ sgn) - sgn;
};
};
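// The branchless |x| above works by two's-complement identities: the
// arithmetic shift x >> 7 smears the sign bit, so sgn is 0 for x >= 0 and -1
// (all ones) for x < 0, making (x ^ sgn) - sgn equal to x or (~x + 1) = -x.
// Worked check: x = -5 (0xFB): sgn = -1 (0xFF), x ^ sgn = 4, 4 - (-1) = 5.
// As with any int8_t abs, x = -128 has no representable positive counterpart.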
template <typename Y, typename X>
struct UnarySqrt;
@@ -333,4 +363,3 @@ struct UnarySqrt<double, double>
} // namespace element_wise
} // namespace tensor_operation
} // namespace ck
#endif
#pragma once
#include "data_type.hpp"
namespace ck {
namespace tensor_operation {
namespace element_wise {
struct ReduceSum
{
__host__ __device__ static constexpr float GetReduceZeroValue() { return float(0); }
__host__ __device__ void Reduce(float& acc, float v) const { acc += v; }
};
struct ReduceSquareSum
{
__host__ __device__ static constexpr float GetReduceZeroValue() { return float(0); }
__host__ __device__ void Reduce(float& acc, float v) const { acc += v * v; }
};
} // namespace element_wise
} // namespace tensor_operation
} // namespace ck
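These functors expose the minimal interface a reduction loop needs, an
identity value plus an in-place accumulate; a minimal host-side illustration:

    float acc = ReduceSquareSum::GetReduceZeroValue(); // 0.0f
    ReduceSquareSum op;
    const float vals[] = {1.0f, 2.0f, 3.0f};
    for(float v : vals)
        op.Reduce(acc, v); // accumulates v * v: acc ends as 1 + 4 + 9 = 14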
@@ -31,8 +31,10 @@
#include "reduction_operator.hpp"
#include "reduction_functions_accumulate.hpp"
#include "reduction_functions_blockwise.hpp"
#include "reduction_functions_threadwise.hpp"
#include "threadwise_tensor_slice_transfer.hpp"
#include "cluster_descriptor.hpp"
#include "element_wise_operation.hpp"
namespace ck {
@@ -52,14 +54,16 @@ __global__ void kernel_reduce_blockwise(const InGridDesc_M_K in_grid_desc_m_k,
const OutElementwiseOperation acc_elementwise_op,
AccDataType alpha,
const InDataType* const __restrict__ p_in_global,
OutDataType beta,
AccDataType beta,
OutDataType* const __restrict__ p_out_global,
const IndexDataType* const __restrict__ p_ws_indices_global,
IndexDataType* const __restrict__ p_indices_global)
{
if constexpr(!NeedIndices)
{
GridwiseReduction::Run(in_grid_desc_m_k,
constexpr bool IsSecondCall = false;
GridwiseReduction::template Run<IsSecondCall>(in_grid_desc_m_k,
out_grid_desc_m,
in_elementwise_op,
acc_elementwise_op,
@@ -102,14 +106,16 @@ kernel_reduce_blockwise_second_call(const InGridDesc_M_K in_grid_desc_m_k,
const OutElementwiseOperation acc_elementwise_op,
AccDataType alpha,
const InDataType* const __restrict__ p_in_global,
OutDataType beta,
AccDataType beta,
OutDataType* const __restrict__ p_out_global,
const IndexDataType* const __restrict__ p_ws_indices_global,
IndexDataType* const __restrict__ p_indices_global)
{
if constexpr(!NeedIndices)
{
GridwiseReduction::Run(in_grid_desc_m_k,
constexpr bool IsSecondCall = true;
GridwiseReduction::template Run<IsSecondCall>(in_grid_desc_m_k,
out_grid_desc_m,
in_elementwise_op,
acc_elementwise_op,
@@ -156,64 +162,88 @@ template <typename InDataType,
index_t OutDstVectorSize>
struct GridwiseReduction_mk_to_m_blockwise
{
static_assert(((InSrcVectorDim == 0 && MThreadSliceSize % InSrcVectorSize == 0) ||
(InSrcVectorDim == 1 && KThreadSliceSize % InSrcVectorSize == 0)) &&
(MThreadSliceSize % OutDstVectorSize == 0),
"Invalid thread slice sizes and/or vector sizes configuration, please check!");
static constexpr bool reorder_thread_cluster = (InSrcVectorDim == 0);
static constexpr auto buffer_1d_desc =
make_naive_tensor_descriptor_packed(make_tuple(Number<BlockSize>{}));
using ThreadClusterLengths_M_K = Sequence<MThreadClusterSize, KThreadClusterSize>;
using ThreadBufferDimAccessOrder =
typename conditional<reorder_thread_cluster, Sequence<1, 0>, Sequence<0, 1>>::type;
template <typename T>
using PassThroughOp = tensor_operation::element_wise::UnaryIdentic<T, T>;
using ThreadClusterArrangeOrder =
typename conditional<reorder_thread_cluster, Sequence<1, 0>, Sequence<0, 1>>::type;
static constexpr auto thread_cluster_desc =
make_cluster_descriptor(ThreadClusterLengths_M_K{}, ThreadClusterArrangeOrder{});
using ThreadReduceSrcDesc_M_K = decltype(make_naive_tensor_descriptor_packed(
make_tuple(Number<MThreadSliceSize>{}, Number<KThreadSliceSize>{})));
using ThreadReduceDstDesc_M =
decltype(make_naive_tensor_descriptor_packed(make_tuple(Number<MThreadSliceSize>{})));
using PassThroughOp = tensor_operation::element_wise::PassThrough;
static constexpr auto I0 = Number<0>{};
static constexpr auto I1 = Number<1>{};
static constexpr index_t M_BlockTileSize = MThreadClusterSize * MThreadSliceSize;
static constexpr index_t K_BlockTileSize = KThreadClusterSize * KThreadSliceSize;
template <bool IsSecondCall>
__device__ static void Run(const InGridDesc_M_K& in_grid_desc_m_k,
const OutGridDesc_M& out_grid_desc_m,
const InElementwiseOperation& in_elementwise_op,
const OutElementwiseOperation& acc_elementwise_op,
AccDataType alpha,
const InDataType* const __restrict__ p_in_global,
OutDataType beta,
AccDataType beta,
OutDataType* const __restrict__ p_out_global,
const IndexDataType* const __restrict__ p_ws_indices_global,
IndexDataType* const __restrict__ p_indices_global)
{
using BlockwiseReduce = PartitionedBlockwiseReductionOn1dBuffer<decltype(buffer_1d_desc),
AccDataType,
if constexpr(IsSecondCall)
{
static_assert(InSrcVectorDim == 1,
"InSrcVectorDim must be 1 for BlockwiseSecondCall, please check!");
};
using BlockwiseReduce = PartitionedBlockwiseReduction<AccDataType,
BlockSize,
MThreadClusterSize,
KThreadClusterSize,
reorder_thread_cluster,
ThreadClusterLengths_M_K,
ThreadClusterArrangeOrder,
ReduceOperation,
PropagateNan>;
using ThreadwiseReduce = ThreadwiseReduction<AccDataType,
ThreadReduceSrcDesc_M_K,
ThreadReduceDstDesc_M,
ReduceOperation,
PropagateNan>;
using Accumulation =
detail::AccumulateWithNanCheck<PropagateNan, ReduceOperation, AccDataType>;
(void)p_ws_indices_global;
(void)p_indices_global;
// LDS
__shared__ AccDataType p_block_reduce_buffer[BlockSize];
__shared__ AccDataType p_reduce_work_buffer[BlockSize];
const auto zeroVal = ReduceOperation::GetReductionZeroVal();
const auto in_global_buf = make_dynamic_buffer<AddressSpaceEnum_t::Global>(
const auto in_global_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_in_global, in_grid_desc_m_k.GetElementSpaceSize(), type_convert<InDataType>(zeroVal));
auto out_global_buf = make_dynamic_buffer<AddressSpaceEnum_t::Global>(
auto out_global_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_out_global, out_grid_desc_m.GetElementSpaceSize());
auto block_reduce_buf =
make_dynamic_buffer<AddressSpaceEnum_t::Lds>(p_block_reduce_buffer, BlockSize);
auto reduce_work_buf =
make_dynamic_buffer<AddressSpaceEnum::Lds>(p_reduce_work_buffer, BlockSize);
StaticBuffer<AddressSpaceEnum_t::Vgpr,
AccDataType,
MThreadSliceSize * KThreadSliceSize,
true>
StaticBuffer<AddressSpaceEnum::Vgpr, AccDataType, MThreadSliceSize * KThreadSliceSize, true>
in_thread_buf;
StaticBuffer<AddressSpaceEnum_t::Vgpr, AccDataType, MThreadSliceSize, true> accu_value_buf;
StaticBuffer<AddressSpaceEnum::Vgpr, AccDataType, MThreadSliceSize, true> accu_value_buf;
static_for<0, MThreadSliceSize, 1>{}([&](auto I) { accu_value_buf(I) = zeroVal; });
@@ -221,28 +251,28 @@ struct GridwiseReduction_mk_to_m_blockwise
const index_t thread_local_id = get_thread_local_1d_id();
const index_t block_global_1d_id = get_block_1d_id();
const index_t thread_m_cluster_id =
reorder_thread_cluster ? thread_local_id % MThreadClusterSize
: ((thread_local_id / KThreadClusterSize) % MThreadClusterSize);
const index_t thread_k_cluster_id =
reorder_thread_cluster ? ((thread_local_id / MThreadClusterSize) % KThreadClusterSize)
: thread_local_id % KThreadClusterSize;
const auto thread_cluster_idx =
thread_cluster_desc.CalculateBottomIndex(make_multi_index(thread_local_id));
const auto thread_m_cluster_id = thread_cluster_idx[I0];
const auto thread_k_cluster_id = thread_cluster_idx[I1];
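// This replaces the previous hand-written decomposition of the thread id:
// with the natural arrange order Sequence<0, 1> the bottom index is
// (tid / KThreadClusterSize, tid % KThreadClusterSize), and with the
// reordered Sequence<1, 0> (used when InSrcVectorDim == 0) it is
// (tid % MThreadClusterSize, tid / MThreadClusterSize), matching the removed
// modulo/divide expressions.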
using ThreadBufferLengths = Sequence<MThreadSliceSize, KThreadSliceSize>;
constexpr auto thread_buffer_desc = make_naive_tensor_descriptor_packed(
make_tuple(Number<MThreadSliceSize>{}, Number<KThreadSliceSize>{}));
auto threadwise_src_load = ThreadwiseTensorSliceTransfer_v2<
InDataType,
auto threadwise_src_load = ThreadwiseTensorSliceTransfer_v2<InDataType,
AccDataType,
InGridDesc_M_K,
decltype(thread_buffer_desc),
ThreadBufferLengths,
typename conditional<InSrcVectorDim == 0, Sequence<1, 0>, Sequence<0, 1>>::type,
ThreadBufferDimAccessOrder,
InSrcVectorDim,
InSrcVectorSize,
1,
false>(in_grid_desc_m_k,
false>(
in_grid_desc_m_k,
make_multi_index(block_global_1d_id * M_BlockTileSize +
thread_m_cluster_id * MThreadSliceSize,
thread_k_cluster_id * KThreadSliceSize));
@@ -260,45 +290,26 @@ struct GridwiseReduction_mk_to_m_blockwise
make_tuple(I0, I0),
in_thread_buf);
static_for<0, MThreadSliceSize, 1>{}([&](auto I) {
static_for<0, MThreadSliceSize, 1>{}([&](auto iM) {
// do element-wise pre-reduction operation
static_for<0, KThreadSliceSize, 1>{}([&](auto J) {
constexpr auto offset = I * Number<KThreadSliceSize>{} + J;
in_elementwise_op(in_thread_buf(offset), in_thread_buf(offset));
});
// reduce on each thread-local slice
static_for<0, KThreadSliceSize, 1>{}([&](auto J) {
constexpr auto offset = I * Number<KThreadSliceSize>{} + J;
Accumulation::Calculate(accu_value_buf(I), in_thread_buf[offset]);
static_for<0, KThreadSliceSize, 1>{}([&](auto iK) {
constexpr auto offset = thread_buffer_desc.CalculateOffset(make_tuple(iM, iK));
in_elementwise_op(in_thread_buf(Number<offset>{}),
in_thread_buf(Number<offset>{}));
});
});
ThreadwiseReduce::Reduce(in_thread_buf, accu_value_buf);
threadwise_src_load.MoveSrcSliceWindow(in_grid_desc_m_k, in_thread_copy_step);
reducedTiles++;
} while(reducedTiles < toReduceTiles);
constexpr auto reduced_data_desc =
make_naive_tensor_descriptor_packed(make_tuple(Number<MThreadSliceSize>{}));
constexpr auto reduced_data_desc = ThreadReduceDstDesc_M{};
static_for<0, MThreadSliceSize, 1>{}([&](auto I) {
if constexpr(reorder_thread_cluster)
{
block_reduce_buf(thread_k_cluster_id * MThreadClusterSize + thread_m_cluster_id) =
accu_value_buf[I];
}
else
block_reduce_buf(thread_m_cluster_id * KThreadClusterSize + thread_k_cluster_id) =
accu_value_buf[I];
accu_value_buf(I) = zeroVal;
__syncthreads();
BlockwiseReduce::Reduce(
block_reduce_buf, accu_value_buf(I), thread_m_cluster_id, thread_k_cluster_id);
});
static_for<0, MThreadSliceSize, 1>{}(
[&](auto I) { BlockwiseReduce::Reduce(reduce_work_buf, accu_value_buf(I)); });
static_for<0, MThreadSliceSize, 1>{}([&](auto I) {
if(thread_k_cluster_id == 0)
@@ -315,7 +326,7 @@ struct GridwiseReduction_mk_to_m_blockwise
{
if(!float_equal_zero{}(beta))
{
StaticBuffer<AddressSpaceEnum_t::Vgpr, OutDataType, MThreadSliceSize, true>
StaticBuffer<AddressSpaceEnum::Vgpr, OutDataType, MThreadSliceSize, true>
priorDstValueBuf;
auto threadwise_dst_load =
@@ -340,7 +351,7 @@ struct GridwiseReduction_mk_to_m_blockwise
priorDstValueBuf);
static_for<0, MThreadSliceSize, 1>{}([&](auto I) {
accu_value_buf(I) += type_convert<AccDataType>(priorDstValueBuf[I] * beta);
accu_value_buf(I) += type_convert<AccDataType>(priorDstValueBuf[I]) * beta;
});
};
};
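// Note the conversion-order change above: beta is AccDataType now, so the
// prior output value is converted to AccDataType before scaling rather than
// scaled in OutDataType and converted afterwards, which keeps the beta blend
// in accumulator precision for narrow output types (an assumed motivation).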
@@ -350,18 +361,18 @@ struct GridwiseReduction_mk_to_m_blockwise
OutDataType,
decltype(reduced_data_desc),
OutGridDesc_M,
PassThroughOp<AccDataType>,
PassThroughOp,
Sequence<MThreadSliceSize>,
Sequence<0>,
0,
OutDstVectorSize,
InMemoryDataOperationEnum_t::Set,
InMemoryDataOperationEnum::Set,
1,
true>(
out_grid_desc_m,
make_multi_index(block_global_1d_id * M_BlockTileSize +
thread_m_cluster_id * MThreadSliceSize),
PassThroughOp<AccDataType>{});
PassThroughOp{});
threadwise_dst_store.Run(
reduced_data_desc, make_tuple(I0), accu_value_buf, out_grid_desc_m, out_global_buf);
@@ -374,19 +385,17 @@ struct GridwiseReduction_mk_to_m_blockwise
const OutElementwiseOperation& acc_elementwise_op,
AccDataType alpha,
const InDataType* const __restrict__ p_in_global,
OutDataType beta,
AccDataType beta,
OutDataType* const __restrict__ p_out_global,
const IndexDataType* const __restrict__ p_ws_indices_global,
IndexDataType* const __restrict__ p_indices_global)
{
using BlockwiseReduceWithIndex =
PartitionedBlockwiseReductionWithIndexOn1dBuffer<decltype(buffer_1d_desc),
AccDataType,
PartitionedBlockwiseReductionWithIndex<AccDataType,
IndexDataType,
BlockSize,
MThreadClusterSize,
KThreadClusterSize,
reorder_thread_cluster,
ThreadClusterLengths_M_K,
ThreadClusterArrangeOrder,
ReduceOperation,
PropagateNan>;
@@ -398,62 +407,61 @@ struct GridwiseReduction_mk_to_m_blockwise
(void)p_ws_indices_global;
// LDS
__shared__ AccDataType p_block_reduce_val_buffer[BlockSize];
__shared__ IndexDataType p_block_reduce_idx_buffer[BlockSize];
__shared__ AccDataType p_reduce_work_val_buffer[BlockSize];
__shared__ IndexDataType p_reduce_work_idx_buffer[BlockSize];
const auto zeroVal = ReduceOperation::GetReductionZeroVal();
const auto in_global_buf = make_dynamic_buffer<AddressSpaceEnum_t::Global>(
const auto in_global_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_in_global, in_grid_desc_m_k.GetElementSpaceSize(), type_convert<InDataType>(zeroVal));
auto out_global_val_buf = make_dynamic_buffer<AddressSpaceEnum_t::Global>(
auto out_global_val_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_out_global, out_grid_desc_m.GetElementSpaceSize());
auto out_global_idx_buf = make_dynamic_buffer<AddressSpaceEnum_t::Global>(
auto out_global_idx_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_indices_global, out_grid_desc_m.GetElementSpaceSize());
auto block_reduce_val_buf =
make_dynamic_buffer<AddressSpaceEnum_t::Lds>(p_block_reduce_val_buffer, BlockSize);
auto block_reduce_idx_buf =
make_dynamic_buffer<AddressSpaceEnum_t::Lds>(p_block_reduce_idx_buffer, BlockSize);
auto reduce_work_val_buf =
make_dynamic_buffer<AddressSpaceEnum::Lds>(p_reduce_work_val_buffer, BlockSize);
auto reduce_work_idx_buf =
make_dynamic_buffer<AddressSpaceEnum::Lds>(p_reduce_work_idx_buffer, BlockSize);
StaticBuffer<AddressSpaceEnum_t::Vgpr,
AccDataType,
MThreadSliceSize * KThreadSliceSize,
true>
StaticBuffer<AddressSpaceEnum::Vgpr, AccDataType, MThreadSliceSize * KThreadSliceSize, true>
in_thread_val_buf;
StaticBuffer<AddressSpaceEnum_t::Vgpr, index_t, MThreadSliceSize * KThreadSliceSize, true>
StaticBuffer<AddressSpaceEnum::Vgpr,
IndexDataType,
MThreadSliceSize * KThreadSliceSize,
true>
in_thread_idx_buf;
StaticBuffer<AddressSpaceEnum_t::Vgpr, AccDataType, MThreadSliceSize, true> accu_value_buf;
StaticBuffer<AddressSpaceEnum_t::Vgpr, IndexDataType, MThreadSliceSize, true>
accu_index_buf;
StaticBuffer<AddressSpaceEnum::Vgpr, AccDataType, MThreadSliceSize, true> accu_value_buf;
StaticBuffer<AddressSpaceEnum::Vgpr, IndexDataType, MThreadSliceSize, true> accu_index_buf;
const auto toReduceLength = in_grid_desc_m_k.GetLength(Number<1>{});
const index_t thread_local_id = get_thread_local_1d_id();
const index_t block_global_1d_id = get_block_1d_id();
const index_t thread_m_cluster_id =
reorder_thread_cluster ? thread_local_id % MThreadClusterSize
: ((thread_local_id / KThreadClusterSize) % MThreadClusterSize);
const index_t thread_k_cluster_id =
reorder_thread_cluster ? ((thread_local_id / MThreadClusterSize) % KThreadClusterSize)
: thread_local_id % KThreadClusterSize;
const auto thread_cluster_idx =
thread_cluster_desc.CalculateBottomIndex(make_multi_index(thread_local_id));
const auto thread_m_cluster_id = thread_cluster_idx[I0];
const auto thread_k_cluster_id = thread_cluster_idx[I1];
using ThreadBufferLengths = Sequence<MThreadSliceSize, KThreadSliceSize>;
constexpr auto thread_buffer_desc = make_naive_tensor_descriptor_packed(
make_tuple(Number<MThreadSliceSize>{}, Number<KThreadSliceSize>{}));
auto threadwise_src_load = ThreadwiseTensorSliceTransfer_v2<
InDataType,
auto threadwise_src_load = ThreadwiseTensorSliceTransfer_v2<InDataType,
AccDataType,
InGridDesc_M_K,
decltype(thread_buffer_desc),
ThreadBufferLengths,
typename conditional<InSrcVectorDim == 0, Sequence<1, 0>, Sequence<0, 1>>::type,
ThreadBufferDimAccessOrder,
InSrcVectorDim,
InSrcVectorSize,
1,
false>(in_grid_desc_m_k,
false>(
in_grid_desc_m_k,
make_multi_index(block_global_1d_id * M_BlockTileSize +
thread_m_cluster_id * MThreadSliceSize,
thread_k_cluster_id * KThreadSliceSize));
@@ -479,56 +487,36 @@ struct GridwiseReduction_mk_to_m_blockwise
make_tuple(I0, I0),
in_thread_val_buf);
static_for<0, MThreadSliceSize, 1>{}([&](auto I) {
static_for<0, KThreadSliceSize, 1>{}([&](auto J) {
constexpr auto offset = I * Number<KThreadSliceSize>{} + J;
static_for<0, MThreadSliceSize, 1>{}([&](auto iM) {
static_for<0, KThreadSliceSize, 1>{}([&](auto iK) {
constexpr auto offset = thread_buffer_desc.CalculateOffset(make_tuple(iM, iK));
// initialize the indices for the per-thread to-reduce values
in_thread_idx_buf(offset) =
indexOffset + thread_k_cluster_id * KThreadSliceSize + J();
in_thread_idx_buf(Number<offset>{}) =
indexOffset + thread_k_cluster_id * KThreadSliceSize + iK();
// do element-wise pre-reduction operation
in_elementwise_op(in_thread_val_buf(offset), in_thread_val_buf(offset));
in_elementwise_op(in_thread_val_buf(Number<offset>{}),
in_thread_val_buf(Number<offset>{}));
});
AccDataType tmpValue = zeroVal;
IndexDataType tmpIndex = 0;
static_for<0, KThreadSliceSize, 1>{}([&](auto J) {
constexpr auto offset = I * Number<KThreadSliceSize>{} + J;
static_for<0, KThreadSliceSize, 1>{}([&](auto iK) {
constexpr auto offset = thread_buffer_desc.CalculateOffset(make_tuple(iM, iK));
// reduce on the dim1 thread slice
AccumulationWithIndex::Calculate(
tmpValue, in_thread_val_buf[offset], tmpIndex, in_thread_idx_buf[offset]);
AccumulationWithIndex::Calculate(tmpValue,
in_thread_val_buf[Number<offset>{}],
tmpIndex,
in_thread_idx_buf[Number<offset>{}]);
});
// store thread local value to LDS for parallel reduction
if constexpr(reorder_thread_cluster)
{
block_reduce_val_buf(thread_k_cluster_id * MThreadClusterSize +
thread_m_cluster_id) = tmpValue;
block_reduce_idx_buf(thread_k_cluster_id * MThreadClusterSize +
thread_m_cluster_id) = tmpIndex;
}
else
{
block_reduce_val_buf(thread_m_cluster_id * KThreadClusterSize +
thread_k_cluster_id) = tmpValue;
block_reduce_idx_buf(thread_m_cluster_id * KThreadClusterSize +
thread_k_cluster_id) = tmpIndex;
}
__syncthreads();
BlockwiseReduceWithIndex::Reduce(block_reduce_val_buf,
block_reduce_idx_buf,
tmpValue,
tmpIndex,
thread_m_cluster_id,
thread_k_cluster_id);
BlockwiseReduceWithIndex::Reduce(
reduce_work_val_buf, reduce_work_idx_buf, tmpValue, tmpIndex);
AccumulationWithIndex::Calculate(
accu_value_buf(I), tmpValue, accu_index_buf(I), tmpIndex);
accu_value_buf(iM), tmpValue, accu_index_buf(iM), tmpIndex);
});
threadwise_src_load.MoveSrcSliceWindow(in_grid_desc_m_k, in_thread_copy_step);
......@@ -537,8 +525,7 @@ struct GridwiseReduction_mk_to_m_blockwise
reducedTiles++;
} while(reducedTiles < toReduceTiles);
constexpr auto reduced_data_desc =
make_naive_tensor_descriptor_packed(make_tuple(Number<MThreadSliceSize>{}));
constexpr auto reduced_data_desc = ThreadReduceDstDesc_M{};
static_for<0, MThreadSliceSize, 1>{}([&](auto I) {
if(thread_k_cluster_id == 0)
......@@ -556,7 +543,7 @@ struct GridwiseReduction_mk_to_m_blockwise
{
if(!float_equal_zero{}(beta))
{
StaticBuffer<AddressSpaceEnum_t::Vgpr, OutDataType, MThreadSliceSize, true>
StaticBuffer<AddressSpaceEnum::Vgpr, OutDataType, MThreadSliceSize, true>
priorDstValueBuf;
auto threadwise_dst_load =
......@@ -581,7 +568,7 @@ struct GridwiseReduction_mk_to_m_blockwise
priorDstValueBuf);
static_for<0, MThreadSliceSize, 1>{}([&](auto I) {
accu_value_buf(I) += type_convert<AccDataType>(priorDstValueBuf[I] * beta);
accu_value_buf(I) += type_convert<AccDataType>(priorDstValueBuf[I]) * beta;
});
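// beta is now AccDataType, so the prior destination value is converted to
// AccDataType before scaling, presumably to avoid rounding the product in the
// narrower OutDataType as the old form did.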
};
};
......@@ -591,36 +578,36 @@ struct GridwiseReduction_mk_to_m_blockwise
OutDataType,
decltype(reduced_data_desc),
OutGridDesc_M,
PassThroughOp<AccDataType>,
PassThroughOp,
Sequence<MThreadSliceSize>,
Sequence<0>,
0,
OutDstVectorSize,
InMemoryDataOperationEnum_t::Set,
InMemoryDataOperationEnum::Set,
1,
false>(
out_grid_desc_m,
make_multi_index(block_global_1d_id * M_BlockTileSize +
thread_m_cluster_id * MThreadSliceSize),
PassThroughOp<AccDataType>{});
PassThroughOp{});
auto threadwise_dst_idx_store =
ThreadwiseTensorSliceTransfer_v1r3<IndexDataType,
IndexDataType,
decltype(reduced_data_desc),
OutGridDesc_M,
PassThroughOp<index_t>,
PassThroughOp,
Sequence<MThreadSliceSize>,
Sequence<0>,
0,
OutDstVectorSize,
InMemoryDataOperationEnum_t::Set,
InMemoryDataOperationEnum::Set,
1,
false>(
out_grid_desc_m,
make_multi_index(block_global_1d_id * M_BlockTileSize +
thread_m_cluster_id * MThreadSliceSize),
PassThroughOp<index_t>{});
PassThroughOp{});
threadwise_dst_val_store.Run(reduced_data_desc,
make_tuple(I0),
......@@ -642,19 +629,20 @@ struct GridwiseReduction_mk_to_m_blockwise
const OutElementwiseOperation acc_elementwise_op,
AccDataType alpha,
const InDataType* const __restrict__ p_ws_values_global,
OutDataType beta,
AccDataType beta,
OutDataType* const __restrict__ p_out_global,
const IndexDataType* const __restrict__ p_ws_indices_global,
IndexDataType* const __restrict__ p_indices_global)
{
static_assert(InSrcVectorDim == 1,
"InSrcVectorDim must be 1 for BlockwiseSecondCall, please check!");
using BlockwiseReduceWithIndex =
PartitionedBlockwiseReductionWithIndexOn1dBuffer<decltype(buffer_1d_desc),
AccDataType,
PartitionedBlockwiseReductionWithIndex<AccDataType,
IndexDataType,
BlockSize,
MThreadClusterSize,
KThreadClusterSize,
reorder_thread_cluster,
Sequence<MThreadClusterSize, KThreadClusterSize>,
ThreadClusterArrangeOrder,
ReduceOperation,
PropagateNan>;
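// PartitionedBlockwiseReductionWithIndex now derives each thread's (M, K)
// cluster position from ThreadClusterArrangeOrder, so Reduce() takes only the
// LDS work buffers and the running value/index instead of explicit cluster ids.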
......@@ -666,90 +654,86 @@ struct GridwiseReduction_mk_to_m_blockwise
(void)in_elementwise_op;
// LDS
__shared__ AccDataType p_block_reduce_val_buffer[BlockSize];
__shared__ IndexDataType p_block_reduce_idx_buffer[BlockSize];
__shared__ AccDataType p_reduce_work_val_buffer[BlockSize];
__shared__ IndexDataType p_reduce_work_idx_buffer[BlockSize];
const auto zeroVal = ReduceOperation::GetReductionZeroVal();
const auto src_global_val_buf =
make_dynamic_buffer<AddressSpaceEnum_t::Global>(p_ws_values_global,
make_dynamic_buffer<AddressSpaceEnum::Global>(p_ws_values_global,
in_grid_desc_m_k.GetElementSpaceSize(),
type_convert<InDataType>(zeroVal));
const auto src_global_idx_buf = make_dynamic_buffer<AddressSpaceEnum_t::Global>(
const auto src_global_idx_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_ws_indices_global, in_grid_desc_m_k.GetElementSpaceSize());
auto out_global_val_buf = make_dynamic_buffer<AddressSpaceEnum_t::Global>(
auto out_global_val_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_out_global, out_grid_desc_m.GetElementSpaceSize());
auto out_global_idx_buf = make_dynamic_buffer<AddressSpaceEnum_t::Global>(
auto out_global_idx_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_indices_global, out_grid_desc_m.GetElementSpaceSize());
auto block_reduce_val_buf =
make_dynamic_buffer<AddressSpaceEnum_t::Lds>(p_block_reduce_val_buffer, BlockSize);
auto block_reduce_idx_buf =
make_dynamic_buffer<AddressSpaceEnum_t::Lds>(p_block_reduce_idx_buffer, BlockSize);
auto reduce_work_val_buf =
make_dynamic_buffer<AddressSpaceEnum::Lds>(p_reduce_work_val_buffer, BlockSize);
auto reduce_work_idx_buf =
make_dynamic_buffer<AddressSpaceEnum::Lds>(p_reduce_work_idx_buffer, BlockSize);
StaticBuffer<AddressSpaceEnum_t::Vgpr,
AccDataType,
MThreadSliceSize * KThreadSliceSize,
true>
StaticBuffer<AddressSpaceEnum::Vgpr, AccDataType, MThreadSliceSize * KThreadSliceSize, true>
in_thread_val_buf;
StaticBuffer<AddressSpaceEnum_t::Vgpr,
StaticBuffer<AddressSpaceEnum::Vgpr,
IndexDataType,
MThreadSliceSize * KThreadSliceSize,
true>
in_thread_idx_buf;
StaticBuffer<AddressSpaceEnum_t::Vgpr, AccDataType, MThreadSliceSize, true> accu_value_buf;
StaticBuffer<AddressSpaceEnum_t::Vgpr, IndexDataType, MThreadSliceSize, true>
accu_index_buf;
StaticBuffer<AddressSpaceEnum::Vgpr, AccDataType, MThreadSliceSize, true> accu_value_buf;
StaticBuffer<AddressSpaceEnum::Vgpr, IndexDataType, MThreadSliceSize, true> accu_index_buf;
const auto toReduceLength = in_grid_desc_m_k.GetLength(Number<1>{});
const index_t thread_local_id = get_thread_local_1d_id();
const index_t block_global_1d_id = get_block_1d_id();
const index_t thread_m_cluster_id =
reorder_thread_cluster ? thread_local_id % MThreadClusterSize
: ((thread_local_id / KThreadClusterSize) % MThreadClusterSize);
const index_t thread_k_cluster_id =
reorder_thread_cluster ? ((thread_local_id / MThreadClusterSize) % KThreadClusterSize)
: thread_local_id % KThreadClusterSize;
const auto thread_cluster_idx =
thread_cluster_desc.CalculateBottomIndex(make_multi_index(thread_local_id));
const auto thread_m_cluster_id = thread_cluster_idx[I0];
const auto thread_k_cluster_id = thread_cluster_idx[I1];
using ThreadBufferLengths = Sequence<MThreadSliceSize, KThreadSliceSize>;
constexpr auto thread_buffer_desc = make_naive_tensor_descriptor_packed(
make_tuple(Number<MThreadSliceSize>{}, Number<KThreadSliceSize>{}));
auto threadwise_src_val_load = ThreadwiseTensorSliceTransfer_v2<
InDataType,
auto threadwise_src_val_load =
ThreadwiseTensorSliceTransfer_v2<InDataType,
AccDataType,
InGridDesc_M_K,
decltype(thread_buffer_desc),
ThreadBufferLengths,
typename conditional<InSrcVectorDim == 0, Sequence<1, 0>, Sequence<0, 1>>::type,
ThreadBufferDimAccessOrder,
InSrcVectorDim,
InSrcVectorSize,
1,
false>(in_grid_desc_m_k,
false>(
in_grid_desc_m_k,
make_multi_index(block_global_1d_id * M_BlockTileSize +
thread_m_cluster_id * MThreadSliceSize,
thread_k_cluster_id * KThreadSliceSize));
auto threadwise_src_idx_load = ThreadwiseTensorSliceTransfer_v2<
IndexDataType,
auto threadwise_src_idx_load =
ThreadwiseTensorSliceTransfer_v2<IndexDataType,
IndexDataType,
InGridDesc_M_K,
decltype(thread_buffer_desc),
ThreadBufferLengths,
typename conditional<InSrcVectorDim == 0, Sequence<1, 0>, Sequence<0, 1>>::type,
ThreadBufferDimAccessOrder,
InSrcVectorDim,
InSrcVectorSize,
1,
false>(in_grid_desc_m_k,
false>(
in_grid_desc_m_k,
make_multi_index(block_global_1d_id * M_BlockTileSize +
thread_m_cluster_id * MThreadSliceSize,
thread_k_cluster_id * KThreadSliceSize));
// index_t indexOffset = 0;
static_for<0, MThreadSliceSize, 1>{}([&](auto I) {
accu_value_buf(I) = zeroVal;
accu_index_buf(I) = 0;
......@@ -774,56 +758,33 @@ struct GridwiseReduction_mk_to_m_blockwise
make_tuple(I0, I0),
in_thread_idx_buf);
static_for<0, MThreadSliceSize, 1>{}([&](auto I) {
static_for<0, MThreadSliceSize, 1>{}([&](auto iM) {
AccDataType tmpValue = zeroVal;
IndexDataType tmpIndex = 0;
static_for<0, KThreadSliceSize, 1>{}([&](auto J) {
constexpr auto offset = I * Number<KThreadSliceSize>{} + J;
static_for<0, KThreadSliceSize, 1>{}([&](auto iK) {
constexpr auto offset = thread_buffer_desc.CalculateOffset(make_tuple(iM, iK));
// reduce on the dim1 thread slice
AccumulationWithIndex::Calculate(
tmpValue, in_thread_val_buf[offset], tmpIndex, in_thread_idx_buf[offset]);
AccumulationWithIndex::Calculate(tmpValue,
in_thread_val_buf[Number<offset>{}],
tmpIndex,
in_thread_idx_buf[Number<offset>{}]);
});
// store thread local value to LDS for parallel reduction
if constexpr(reorder_thread_cluster)
{
block_reduce_val_buf(thread_k_cluster_id * MThreadClusterSize +
thread_m_cluster_id) = tmpValue;
block_reduce_idx_buf(thread_k_cluster_id * MThreadClusterSize +
thread_m_cluster_id) = tmpIndex;
}
else
{
block_reduce_val_buf(thread_m_cluster_id * KThreadClusterSize +
thread_k_cluster_id) = tmpValue;
block_reduce_idx_buf(thread_m_cluster_id * KThreadClusterSize +
thread_k_cluster_id) = tmpIndex;
}
__syncthreads();
BlockwiseReduceWithIndex::Reduce(block_reduce_val_buf,
block_reduce_idx_buf,
tmpValue,
tmpIndex,
thread_m_cluster_id,
thread_k_cluster_id);
BlockwiseReduceWithIndex::Reduce(
reduce_work_val_buf, reduce_work_idx_buf, tmpValue, tmpIndex);
AccumulationWithIndex::Calculate(
accu_value_buf(I), tmpValue, accu_index_buf(I), tmpIndex);
accu_value_buf(iM), tmpValue, accu_index_buf(iM), tmpIndex);
});
threadwise_src_val_load.MoveSrcSliceWindow(in_grid_desc_m_k, in_thread_copy_step);
threadwise_src_idx_load.MoveSrcSliceWindow(in_grid_desc_m_k, in_thread_copy_step);
// indexOffset += K_BlockTileSize;
reducedTiles++;
} while(reducedTiles < toReduceTiles);
constexpr auto reduced_data_desc =
make_naive_tensor_descriptor_packed(make_tuple(Number<MThreadSliceSize>{}));
constexpr auto reduced_data_desc = ThreadReduceDstDesc_M{};
static_for<0, MThreadSliceSize, 1>{}([&](auto I) {
if(thread_k_cluster_id == 0)
......@@ -841,7 +802,7 @@ struct GridwiseReduction_mk_to_m_blockwise
{
if(!float_equal_zero{}(beta))
{
StaticBuffer<AddressSpaceEnum_t::Vgpr, OutDataType, MThreadSliceSize, true>
StaticBuffer<AddressSpaceEnum::Vgpr, OutDataType, MThreadSliceSize, true>
priorDstValueBuf;
auto threadwise_dst_load =
......@@ -866,7 +827,7 @@ struct GridwiseReduction_mk_to_m_blockwise
priorDstValueBuf);
static_for<0, MThreadSliceSize, 1>{}([&](auto I) {
accu_value_buf(I) += type_convert<AccDataType>(priorDstValueBuf[I] * beta);
accu_value_buf(I) += type_convert<AccDataType>(priorDstValueBuf[I]) * beta;
});
};
};
......@@ -876,36 +837,36 @@ struct GridwiseReduction_mk_to_m_blockwise
OutDataType,
decltype(reduced_data_desc),
OutGridDesc_M,
PassThroughOp<AccDataType>,
PassThroughOp,
Sequence<MThreadSliceSize>,
Sequence<0>,
0,
OutDstVectorSize,
InMemoryDataOperationEnum_t::Set,
InMemoryDataOperationEnum::Set,
1,
true>(
out_grid_desc_m,
make_multi_index(block_global_1d_id * M_BlockTileSize +
thread_m_cluster_id * MThreadSliceSize),
PassThroughOp<AccDataType>{});
PassThroughOp{});
auto threadwise_dst_idx_store =
ThreadwiseTensorSliceTransfer_v1r3<IndexDataType,
IndexDataType,
decltype(reduced_data_desc),
OutGridDesc_M,
PassThroughOp<IndexDataType>,
PassThroughOp,
Sequence<MThreadSliceSize>,
Sequence<0>,
0,
OutDstVectorSize,
InMemoryDataOperationEnum_t::Set,
InMemoryDataOperationEnum::Set,
1,
true>(
out_grid_desc_m,
make_multi_index(block_global_1d_id * M_BlockTileSize +
thread_m_cluster_id * MThreadSliceSize),
PassThroughOp<index_t>{});
PassThroughOp{});
threadwise_dst_val_store.Run(reduced_data_desc,
make_tuple(I0),
......
......@@ -30,8 +30,10 @@
#include "reduction_operator.hpp"
#include "reduction_functions_accumulate.hpp"
#include "reduction_functions_blockwise.hpp"
#include "reduction_functions_threadwise.hpp"
#include "threadwise_tensor_slice_transfer.hpp"
#include "element_wise_operation.hpp"
namespace ck {
......@@ -84,24 +86,46 @@ template <typename InDataType,
index_t OutDstVectorSize>
struct GridwiseReduction_mk_to_m_multiblock_atomic_add
{
static_assert(((InSrcVectorDim == 0 && MThreadSliceSize % InSrcVectorSize == 0) ||
(InSrcVectorDim == 1 && KThreadSliceSize % InSrcVectorSize == 0)) &&
(MThreadSliceSize % OutDstVectorSize == 0),
"Invalid thread slice sizes and/or vector sizes configuration, please check!");
static constexpr bool reorder_thread_cluster = (InSrcVectorDim == 0);
static constexpr auto buffer_1d_desc =
make_naive_tensor_descriptor_packed(make_tuple(Number<BlockSize>{}));
using ThreadClusterLengths_M_K = Sequence<MThreadClusterSize, KThreadClusterSize>;
using blockwise_reduce = PartitionedBlockwiseReductionOn1dBuffer<decltype(buffer_1d_desc),
AccDataType,
using ThreadBufferDimAccessOrder =
typename conditional<reorder_thread_cluster, Sequence<1, 0>, Sequence<0, 1>>::type;
using ThreadClusterArrangeOrder =
typename conditional<reorder_thread_cluster, Sequence<1, 0>, Sequence<0, 1>>::type;
static constexpr auto thread_cluster_desc =
make_cluster_descriptor(ThreadClusterLengths_M_K{}, ThreadClusterArrangeOrder{});
using ThreadReduceSrcDesc_M_K = decltype(make_naive_tensor_descriptor_packed(
make_tuple(Number<MThreadSliceSize>{}, Number<KThreadSliceSize>{})));
using ThreadReduceDstDesc_M =
decltype(make_naive_tensor_descriptor_packed(make_tuple(Number<MThreadSliceSize>{})));
using BlockwiseReduce = PartitionedBlockwiseReduction<AccDataType,
BlockSize,
MThreadClusterSize,
KThreadClusterSize,
reorder_thread_cluster,
ThreadClusterLengths_M_K,
ThreadClusterArrangeOrder,
ReduceOperation,
PropagateNan>;
template <typename T>
using PassThroughOp = tensor_operation::element_wise::UnaryIdentic<T, T>;
using ThreadwiseReduce = ThreadwiseReduction<AccDataType,
ThreadReduceSrcDesc_M_K,
ThreadReduceDstDesc_M,
ReduceOperation,
PropagateNan>;
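// ThreadwiseReduction is assumed to fold the (MThreadSliceSize, KThreadSliceSize)
// thread buffer along K into the length-MThreadSliceSize accumulator, replacing
// the hand-written per-slice accumulation loops used before this change.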
using PassThroughOp = tensor_operation::element_wise::PassThrough;
static constexpr auto I0 = Number<0>{};
static constexpr auto I1 = Number<1>{};
static constexpr index_t M_BlockTileSize = MThreadClusterSize * MThreadSliceSize;
static constexpr index_t K_BlockTileSize = KThreadClusterSize * KThreadSliceSize;
......@@ -121,23 +145,20 @@ struct GridwiseReduction_mk_to_m_multiblock_atomic_add
const auto zeroVal = ReduceOperation::GetReductionZeroVal();
// LDS
__shared__ AccDataType p_block_reduce_buffer[BlockSize];
__shared__ AccDataType p_reduce_work_buffer[BlockSize];
const auto in_global_buf = make_dynamic_buffer<AddressSpaceEnum_t::Global>(
const auto in_global_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_in_global, in_grid_desc_m_k.GetElementSpaceSize(), type_convert<InDataType>(zeroVal));
auto out_global_buf = make_dynamic_buffer<AddressSpaceEnum_t::Global>(
auto out_global_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_out_global, out_grid_desc_m.GetElementSpaceSize());
auto block_reduce_buf =
make_dynamic_buffer<AddressSpaceEnum_t::Lds>(p_block_reduce_buffer, BlockSize);
auto reduce_work_buf =
make_dynamic_buffer<AddressSpaceEnum::Lds>(p_reduce_work_buffer, BlockSize);
StaticBuffer<AddressSpaceEnum_t::Vgpr,
AccDataType,
MThreadSliceSize * KThreadSliceSize,
true>
StaticBuffer<AddressSpaceEnum::Vgpr, AccDataType, MThreadSliceSize * KThreadSliceSize, true>
in_thread_buf;
StaticBuffer<AddressSpaceEnum_t::Vgpr, AccDataType, MThreadSliceSize, true> accu_value_buf;
StaticBuffer<AddressSpaceEnum::Vgpr, AccDataType, MThreadSliceSize, true> accu_value_buf;
static_for<0, MThreadSliceSize, 1>{}([&](auto I) { accu_value_buf(I) = zeroVal; });
......@@ -145,12 +166,12 @@ struct GridwiseReduction_mk_to_m_multiblock_atomic_add
const index_t block_global_id = get_block_1d_id();
const index_t blkgroup_id = block_global_id / block_group_size;
const index_t block_local_id = block_global_id % block_group_size;
const index_t thread_m_cluster_id =
reorder_thread_cluster ? thread_local_id % MThreadClusterSize
: ((thread_local_id / KThreadClusterSize) % MThreadClusterSize);
const index_t thread_k_cluster_id =
reorder_thread_cluster ? ((thread_local_id / MThreadClusterSize) % KThreadClusterSize)
: thread_local_id % KThreadClusterSize;
const auto thread_cluster_idx =
thread_cluster_desc.CalculateBottomIndex(make_multi_index(thread_local_id));
const auto thread_m_cluster_id = thread_cluster_idx[I0];
const auto thread_k_cluster_id = thread_cluster_idx[I1];
const index_t reduceSizePerBlock = K_BlockTileSize * num_k_block_tile_iteration;
......@@ -158,13 +179,12 @@ struct GridwiseReduction_mk_to_m_multiblock_atomic_add
constexpr auto thread_buffer_desc = make_naive_tensor_descriptor_packed(
make_tuple(Number<MThreadSliceSize>{}, Number<KThreadSliceSize>{}));
auto threadwise_src_load = ThreadwiseTensorSliceTransfer_v2<
InDataType,
auto threadwise_src_load = ThreadwiseTensorSliceTransfer_v2<InDataType,
AccDataType,
InGridDesc_M_K,
decltype(thread_buffer_desc),
ThreadBufferLengths,
typename conditional<InSrcVectorDim == 0, Sequence<1, 0>, Sequence<0, 1>>::type,
ThreadBufferDimAccessOrder,
InSrcVectorDim,
InSrcVectorSize,
1,
......@@ -185,49 +205,30 @@ struct GridwiseReduction_mk_to_m_multiblock_atomic_add
make_tuple(I0, I0),
in_thread_buf);
static_for<0, MThreadSliceSize, 1>{}([&](auto I) {
static_for<0, MThreadSliceSize, 1>{}([&](auto iM) {
// do element-wise pre-reduction operation
static_for<0, KThreadSliceSize, 1>{}([&](auto J) {
constexpr auto offset = I * Number<KThreadSliceSize>{} + J;
in_elementwise_op(in_thread_buf(offset), in_thread_buf(offset));
});
// reduce on each thread-local slice
static_for<0, KThreadSliceSize, 1>{}([&](auto J) {
constexpr auto offset = I * Number<KThreadSliceSize>{} + J;
Accumulation::Calculate(accu_value_buf(I), in_thread_buf[offset]);
static_for<0, KThreadSliceSize, 1>{}([&](auto iK) {
constexpr auto offset = thread_buffer_desc.CalculateOffset(make_tuple(iM, iK));
in_elementwise_op(in_thread_buf(Number<offset>{}),
in_thread_buf(Number<offset>{}));
});
});
ThreadwiseReduce::Reduce(in_thread_buf, accu_value_buf);
threadwise_src_load.MoveSrcSliceWindow(in_grid_desc_m_k, in_thread_copy_step);
reducedTiles++;
} while(reducedTiles < num_k_block_tile_iteration);
constexpr auto reduced_data_desc =
make_naive_tensor_descriptor_packed(make_tuple(Number<MThreadSliceSize>{}));
constexpr auto reduced_data_desc = ThreadReduceDstDesc_M{};
// Each block executes multiple parallel reductions on the LDS and atomic-adds its
// reduced output to the global location of each invariant dimension, so the
// contributions of all block groups combine into a consistent reduced result for
// that invariant dimension. Because of the vector loads, each block/thread is
// involved in multiple invariant dimensions.
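// A hedged sketch of the multiblock scheme: the blocks of one block group reduce
// disjoint K ranges of the same M tile, then combine their partial results via
//     out[m] += partial[m];   // InMemoryDataOperationEnum::AtomicAdd
// which yields the full reduction only if p_out_global was zero-initialized
// before the kernel launch.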
static_for<0, MThreadSliceSize, 1>{}([&](auto I) {
if constexpr(reorder_thread_cluster)
{
block_reduce_buf(thread_k_cluster_id * MThreadClusterSize + thread_m_cluster_id) =
accu_value_buf[I];
}
else
block_reduce_buf(thread_m_cluster_id * KThreadClusterSize + thread_k_cluster_id) =
accu_value_buf[I];
accu_value_buf(I) = zeroVal;
__syncthreads();
blockwise_reduce::Reduce(
block_reduce_buf, accu_value_buf(I), thread_m_cluster_id, thread_k_cluster_id);
});
static_for<0, MThreadSliceSize, 1>{}(
[&](auto I) { BlockwiseReduce::Reduce(reduce_work_buf, accu_value_buf(I)); });
static_for<0, MThreadSliceSize, 1>{}([&](auto I) {
if(thread_k_cluster_id == 0)
......@@ -245,18 +246,18 @@ struct GridwiseReduction_mk_to_m_multiblock_atomic_add
OutDataType,
decltype(reduced_data_desc),
OutGridDesc_M,
PassThroughOp<AccDataType>,
PassThroughOp,
Sequence<MThreadSliceSize>,
Sequence<0>,
0,
OutDstVectorSize,
InMemoryDataOperationEnum_t::AtomicAdd,
InMemoryDataOperationEnum::AtomicAdd,
1,
true>(
out_grid_desc_m,
make_multi_index(blkgroup_id * M_BlockTileSize +
thread_m_cluster_id * MThreadSliceSize),
PassThroughOp<AccDataType>{});
PassThroughOp{});
threadwise_dst_store.Run(
reduced_data_desc, make_tuple(I0), accu_value_buf, out_grid_desc_m, out_global_buf);
......
......@@ -23,15 +23,17 @@
* SOFTWARE.
*
*******************************************************************************/
#ifndef CK_GRIDWISE_2D_REDUCTION_MULTIBLOCK_TWO_CALL_HPP
#define CK_GRIDWISE_2D_REDUCTION_MULTIBLOCK_TWO_CALL_HPP
#ifndef CK_GRIDWISE_2D_REDUCTION_MULTIBLOCK_PARTIAL_REDUCE_HPP
#define CK_GRIDWISE_2D_REDUCTION_MULTIBLOCK_PARTIAL_REDUCE_HPP
#include "reduction_common.hpp"
#include "reduction_operator.hpp"
#include "reduction_functions_accumulate.hpp"
#include "reduction_functions_blockwise.hpp"
#include "reduction_functions_threadwise.hpp"
#include "threadwise_tensor_slice_transfer.hpp"
#include "cluster_descriptor.hpp"
#include "element_wise_operation.hpp"
namespace ck {
......@@ -101,15 +103,34 @@ template <typename InDataType,
index_t OutDstVectorSize>
struct GridwiseReduction_mk_to_mk_multiblock_partial_reduce
{
static_assert((InSrcVectorDim == 0 && MThreadSliceSize % InSrcVectorSize == 0) ||
(InSrcVectorDim == 1 && KThreadSliceSize % InSrcVectorSize == 0),
"Invalid thread slice sizes and/or vector sizes configuration, please check!");
static_assert(OutDstVectorSize == 1, "OutDstVectorSize must be 1 for MultiBlockPartialReduce!");
static constexpr bool reorder_thread_cluster = (InSrcVectorDim == 0);
static constexpr auto buffer1dDesc =
make_naive_tensor_descriptor_packed(make_tuple(Number<BlockSize>{}));
using ThreadClusterLengths_M_K = Sequence<MThreadClusterSize, KThreadClusterSize>;
using ThreadBufferDimAccessOrder =
typename conditional<reorder_thread_cluster, Sequence<1, 0>, Sequence<0, 1>>::type;
using ThreadClusterArrangeOrder =
typename conditional<reorder_thread_cluster, Sequence<1, 0>, Sequence<0, 1>>::type;
static constexpr auto thread_cluster_desc =
make_cluster_descriptor(ThreadClusterLengths_M_K{}, ThreadClusterArrangeOrder{});
using ThreadReduceSrcDesc_M_K = decltype(make_naive_tensor_descriptor_packed(
make_tuple(Number<MThreadSliceSize>{}, Number<KThreadSliceSize>{})));
using ThreadReduceDstDesc_M =
decltype(make_naive_tensor_descriptor_packed(make_tuple(Number<MThreadSliceSize>{})));
template <typename T>
using PassThroughOp = tensor_operation::element_wise::UnaryIdentic<T, T>;
using PassThroughOp = tensor_operation::element_wise::PassThrough;
static constexpr auto I0 = Number<0>{};
static constexpr auto I1 = Number<1>{};
static constexpr index_t M_BlockTileSize = MThreadClusterSize * MThreadSliceSize;
static constexpr index_t K_BlockTileSize = KThreadClusterSize * KThreadSliceSize;
......@@ -124,17 +145,18 @@ struct GridwiseReduction_mk_to_mk_multiblock_partial_reduce
AccDataType* const __restrict__ p_ws_values_global,
IndexDataType* const __restrict__ p_ws_indices_global)
{
using BlockwiseReduce = PartitionedBlockwiseReductionOn1dBuffer<decltype(buffer1dDesc),
AccDataType,
using BlockwiseReduce = PartitionedBlockwiseReduction<AccDataType,
BlockSize,
MThreadClusterSize,
KThreadClusterSize,
reorder_thread_cluster,
ThreadClusterLengths_M_K,
ThreadClusterArrangeOrder,
ReduceOperation,
PropagateNan>;
using Accumulation =
detail::AccumulateWithNanCheck<PropagateNan, ReduceOperation, AccDataType>;
using ThreadwiseReduce = ThreadwiseReduction<AccDataType,
ThreadReduceSrcDesc_M_K,
ThreadReduceDstDesc_M,
ReduceOperation,
PropagateNan>;
(void)p_ws_indices_global;
(void)acc_elementwise_op;
......@@ -142,25 +164,22 @@ struct GridwiseReduction_mk_to_mk_multiblock_partial_reduce
const auto zeroVal = ReduceOperation::GetReductionZeroVal();
// LDS
__shared__ AccDataType p_block_reduce_buffer[BlockSize];
__shared__ AccDataType p_reduce_work_buffer[BlockSize];
const auto in_global_buf =
make_dynamic_buffer<AddressSpaceEnum_t::Global>(p_src_global,
make_dynamic_buffer<AddressSpaceEnum::Global>(p_src_global,
in_grid_desc_m_k.GetElementSpaceSize(),
type_convert<InDataType>(zeroVal));
auto workspace_global_buf = make_dynamic_buffer<AddressSpaceEnum_t::Global>(
auto workspace_global_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_ws_values_global, workspace_desc_m_k.GetElementSpaceSize());
auto block_reduce_buf =
make_dynamic_buffer<AddressSpaceEnum_t::Lds>(p_block_reduce_buffer, BlockSize);
auto reduce_work_buf =
make_dynamic_buffer<AddressSpaceEnum::Lds>(p_reduce_work_buffer, BlockSize);
StaticBuffer<AddressSpaceEnum_t::Vgpr,
AccDataType,
MThreadSliceSize * KThreadSliceSize,
true>
StaticBuffer<AddressSpaceEnum::Vgpr, AccDataType, MThreadSliceSize * KThreadSliceSize, true>
in_thread_buf;
StaticBuffer<AddressSpaceEnum_t::Vgpr, AccDataType, MThreadSliceSize, true> accu_value_buf;
StaticBuffer<AddressSpaceEnum::Vgpr, AccDataType, MThreadSliceSize, true> accu_value_buf;
static_for<0, MThreadSliceSize, 1>{}([&](auto I) { accu_value_buf(I) = zeroVal; });
......@@ -168,12 +187,12 @@ struct GridwiseReduction_mk_to_mk_multiblock_partial_reduce
const index_t block_global_id = get_block_1d_id();
const index_t blkgroup_id = block_global_id / block_group_size;
const index_t block_local_id = block_global_id % block_group_size;
const index_t thread_m_cluster_id =
reorder_thread_cluster ? thread_local_id % MThreadClusterSize
: ((thread_local_id / KThreadClusterSize) % MThreadClusterSize);
const index_t thread_k_cluster_id =
reorder_thread_cluster ? ((thread_local_id / MThreadClusterSize) % KThreadClusterSize)
: thread_local_id % KThreadClusterSize;
const auto thread_cluster_idx =
thread_cluster_desc.CalculateBottomIndex(make_multi_index(thread_local_id));
const auto thread_m_cluster_id = thread_cluster_idx[I0];
const auto thread_k_cluster_id = thread_cluster_idx[I1];
const index_t reduceSizePerBlock = K_BlockTileSize * num_k_block_tile_iteration;
......@@ -181,13 +200,12 @@ struct GridwiseReduction_mk_to_mk_multiblock_partial_reduce
constexpr auto thread_buffer_desc = make_naive_tensor_descriptor_packed(
make_tuple(Number<MThreadSliceSize>{}, Number<KThreadSliceSize>{}));
auto threadwise_src_load = ThreadwiseTensorSliceTransfer_v2<
InDataType,
auto threadwise_src_load = ThreadwiseTensorSliceTransfer_v2<InDataType,
AccDataType,
InGridDesc_M_K,
decltype(thread_buffer_desc),
ThreadBufferLengths,
typename conditional<InSrcVectorDim == 0, Sequence<1, 0>, Sequence<0, 1>>::type,
ThreadBufferDimAccessOrder,
InSrcVectorDim,
InSrcVectorSize,
1,
......@@ -208,47 +226,29 @@ struct GridwiseReduction_mk_to_mk_multiblock_partial_reduce
make_tuple(I0, I0),
in_thread_buf);
static_for<0, MThreadSliceSize, 1>{}([&](auto I) {
static_for<0, MThreadSliceSize, 1>{}([&](auto iM) {
// do element-wise pre-reduction operation
static_for<0, KThreadSliceSize, 1>{}([&](auto J) {
constexpr auto offset = I * Number<KThreadSliceSize>{} + J;
in_elementwise_op(in_thread_buf(offset), in_thread_buf(offset));
});
// reduce on each thread-local slice
static_for<0, KThreadSliceSize, 1>{}([&](auto J) {
constexpr auto offset = I * Number<KThreadSliceSize>{} + J;
Accumulation::Calculate(accu_value_buf(I), in_thread_buf[offset]);
static_for<0, KThreadSliceSize, 1>{}([&](auto iK) {
constexpr auto offset = thread_buffer_desc.CalculateOffset(make_tuple(iM, iK));
in_elementwise_op(in_thread_buf(Number<offset>{}),
in_thread_buf(Number<offset>{}));
});
});
ThreadwiseReduce::Reduce(in_thread_buf, accu_value_buf);
threadwise_src_load.MoveSrcSliceWindow(in_grid_desc_m_k, in_thread_copy_step);
reducedTiles++;
} while(reducedTiles < num_k_block_tile_iteration);
constexpr auto reduced_data_desc = make_naive_tensor_descriptor_packed(
make_tuple(Number<MThreadSliceSize>{}, Number<1>{}));
// Each block executes multiple parallel reductions on the LDS; because of the
// vector loads, each block/thread is involved in multiple invariant dimensions.
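// Each block then writes its partial value for every m it owns to column
// block_local_id of the (M, block-group) workspace, so a second kernel call can
// finish the reduction along that dimension.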
static_for<0, MThreadSliceSize, 1>{}([&](auto I) {
if constexpr(reorder_thread_cluster)
{
block_reduce_buf(thread_k_cluster_id * MThreadClusterSize + thread_m_cluster_id) =
accu_value_buf[I];
}
else
block_reduce_buf(thread_m_cluster_id * KThreadClusterSize + thread_k_cluster_id) =
accu_value_buf[I];
accu_value_buf(I) = zeroVal;
static_for<0, MThreadSliceSize, 1>{}(
[&](auto I) { BlockwiseReduce::Reduce(reduce_work_buf, accu_value_buf(I)); });
__syncthreads();
BlockwiseReduce::Reduce(
block_reduce_buf, accu_value_buf(I), thread_m_cluster_id, thread_k_cluster_id);
});
constexpr auto reduced_data_desc = make_naive_tensor_descriptor_packed(
make_tuple(Number<MThreadSliceSize>{}, Number<1>{}));
if(thread_k_cluster_id == 0)
{
......@@ -257,19 +257,19 @@ struct GridwiseReduction_mk_to_mk_multiblock_partial_reduce
AccDataType,
decltype(reduced_data_desc),
WorkspaceDesc_M_K,
PassThroughOp<AccDataType>,
PassThroughOp,
Sequence<MThreadSliceSize, 1>,
Sequence<0, 1>,
1,
1,
InMemoryDataOperationEnum_t::Set,
InMemoryDataOperationEnum::Set,
1,
true>(
workspace_desc_m_k,
make_multi_index(blkgroup_id * M_BlockTileSize +
thread_m_cluster_id * MThreadSliceSize,
block_local_id),
PassThroughOp<AccDataType>{});
PassThroughOp{});
threadwise_workspace_store.Run(reduced_data_desc,
make_tuple(I0, I0),
......@@ -290,13 +290,11 @@ struct GridwiseReduction_mk_to_mk_multiblock_partial_reduce
IndexDataType* const __restrict__ p_ws_indices_global)
{
using BlockwiseReduceWithIndex =
PartitionedBlockwiseReductionWithIndexOn1dBuffer<decltype(buffer1dDesc),
AccDataType,
PartitionedBlockwiseReductionWithIndex<AccDataType,
IndexDataType,
BlockSize,
MThreadClusterSize,
KThreadClusterSize,
reorder_thread_cluster,
ThreadClusterLengths_M_K,
ThreadClusterArrangeOrder,
ReduceOperation,
PropagateNan>;
......@@ -310,48 +308,44 @@ struct GridwiseReduction_mk_to_mk_multiblock_partial_reduce
const auto zeroVal = ReduceOperation::GetReductionZeroVal();
// LDS
__shared__ AccDataType p_block_reduce_val_buffer[BlockSize];
__shared__ index_t p_block_reduce_idx_buffer[BlockSize];
__shared__ AccDataType p_reduce_work_val_buffer[BlockSize];
__shared__ index_t p_reduce_work_idx_buffer[BlockSize];
const auto in_global_buf =
make_dynamic_buffer<AddressSpaceEnum_t::Global>(p_src_global,
make_dynamic_buffer<AddressSpaceEnum::Global>(p_src_global,
in_grid_desc_m_k.GetElementSpaceSize(),
type_convert<InDataType>(zeroVal));
auto workspace_global_val_buf = make_dynamic_buffer<AddressSpaceEnum_t::Global>(
auto workspace_global_val_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_ws_values_global, workspace_desc_m_k.GetElementSpaceSize());
auto workspace_global_idx_buf = make_dynamic_buffer<AddressSpaceEnum_t::Global>(
auto workspace_global_idx_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_ws_indices_global, workspace_desc_m_k.GetElementSpaceSize());
auto block_reduce_val_buf =
make_dynamic_buffer<AddressSpaceEnum_t::Lds>(p_block_reduce_val_buffer, BlockSize);
auto block_reduce_idx_buf =
make_dynamic_buffer<AddressSpaceEnum_t::Lds>(p_block_reduce_idx_buffer, BlockSize);
auto reduce_work_val_buf =
make_dynamic_buffer<AddressSpaceEnum::Lds>(p_reduce_work_val_buffer, BlockSize);
auto reduce_work_idx_buf =
make_dynamic_buffer<AddressSpaceEnum::Lds>(p_reduce_work_idx_buffer, BlockSize);
StaticBuffer<AddressSpaceEnum_t::Vgpr,
AccDataType,
MThreadSliceSize * KThreadSliceSize,
true>
StaticBuffer<AddressSpaceEnum::Vgpr, AccDataType, MThreadSliceSize * KThreadSliceSize, true>
in_thread_val_buf;
StaticBuffer<AddressSpaceEnum_t::Vgpr,
StaticBuffer<AddressSpaceEnum::Vgpr,
IndexDataType,
MThreadSliceSize * KThreadSliceSize,
true>
in_thread_idx_buf;
StaticBuffer<AddressSpaceEnum_t::Vgpr, AccDataType, MThreadSliceSize, true> accu_value_buf;
StaticBuffer<AddressSpaceEnum_t::Vgpr, IndexDataType, MThreadSliceSize, true>
accu_index_buf;
StaticBuffer<AddressSpaceEnum::Vgpr, AccDataType, MThreadSliceSize, true> accu_value_buf;
StaticBuffer<AddressSpaceEnum::Vgpr, IndexDataType, MThreadSliceSize, true> accu_index_buf;
const index_t thread_local_id = get_thread_local_1d_id();
const index_t block_global_id = get_block_1d_id();
const index_t blkgroup_id = block_global_id / block_group_size;
const index_t block_local_id = block_global_id % block_group_size;
const index_t thread_m_cluster_id =
reorder_thread_cluster ? thread_local_id % MThreadClusterSize
: ((thread_local_id / KThreadClusterSize) % MThreadClusterSize);
const index_t thread_k_cluster_id =
reorder_thread_cluster ? ((thread_local_id / MThreadClusterSize) % KThreadClusterSize)
: thread_local_id % KThreadClusterSize;
const auto thread_cluster_idx =
thread_cluster_desc.CalculateBottomIndex(make_multi_index(thread_local_id));
const auto thread_m_cluster_id = thread_cluster_idx[I0];
const auto thread_k_cluster_id = thread_cluster_idx[I1];
const index_t reduceSizePerBlock = K_BlockTileSize * num_k_block_tile_iteration;
......@@ -359,13 +353,12 @@ struct GridwiseReduction_mk_to_mk_multiblock_partial_reduce
constexpr auto thread_buffer_desc = make_naive_tensor_descriptor_packed(
make_tuple(Number<MThreadSliceSize>{}, Number<KThreadSliceSize>{}));
auto threadwise_src_load = ThreadwiseTensorSliceTransfer_v2<
InDataType,
auto threadwise_src_load = ThreadwiseTensorSliceTransfer_v2<InDataType,
AccDataType,
InGridDesc_M_K,
decltype(thread_buffer_desc),
ThreadBufferLengths,
typename conditional<InSrcVectorDim == 0, Sequence<1, 0>, Sequence<0, 1>>::type,
ThreadBufferDimAccessOrder,
InSrcVectorDim,
InSrcVectorSize,
1,
......@@ -394,56 +387,36 @@ struct GridwiseReduction_mk_to_mk_multiblock_partial_reduce
make_tuple(I0, I0),
in_thread_val_buf);
static_for<0, MThreadSliceSize, 1>{}([&](auto I) {
static_for<0, KThreadSliceSize, 1>{}([&](auto J) {
constexpr auto offset = I * Number<KThreadSliceSize>{} + J;
static_for<0, MThreadSliceSize, 1>{}([&](auto iM) {
static_for<0, KThreadSliceSize, 1>{}([&](auto iK) {
constexpr auto offset = thread_buffer_desc.CalculateOffset(make_tuple(iM, iK));
// initialize the indices for the per-thread to-reduce values
in_thread_idx_buf(offset) =
indexOffset + thread_k_cluster_id * KThreadSliceSize + J();
in_thread_idx_buf(Number<offset>{}) =
indexOffset + thread_k_cluster_id * KThreadSliceSize + iK();
// do element-wise pre-reduction operation
in_elementwise_op(in_thread_val_buf(offset), in_thread_val_buf(offset));
in_elementwise_op(in_thread_val_buf(Number<offset>{}),
in_thread_val_buf(Number<offset>{}));
});
AccDataType tmpValue = zeroVal;
IndexDataType tmpIndex = 0;
static_for<0, KThreadSliceSize, 1>{}([&](auto J) {
constexpr auto offset = I * Number<KThreadSliceSize>{} + J;
static_for<0, KThreadSliceSize, 1>{}([&](auto iK) {
constexpr auto offset = thread_buffer_desc.CalculateOffset(make_tuple(iM, iK));
// reduce on the dim1 thread slice
AccumulationWithIndex::Calculate(
tmpValue, in_thread_val_buf[offset], tmpIndex, in_thread_idx_buf[offset]);
AccumulationWithIndex::Calculate(tmpValue,
in_thread_val_buf[Number<offset>{}],
tmpIndex,
in_thread_idx_buf[Number<offset>{}]);
});
// store thread local value to LDS for parallel reduction
if constexpr(reorder_thread_cluster)
{
block_reduce_val_buf(thread_k_cluster_id * MThreadClusterSize +
thread_m_cluster_id) = tmpValue;
block_reduce_idx_buf(thread_k_cluster_id * MThreadClusterSize +
thread_m_cluster_id) = tmpIndex;
}
else
{
block_reduce_val_buf(thread_m_cluster_id * KThreadClusterSize +
thread_k_cluster_id) = tmpValue;
block_reduce_idx_buf(thread_m_cluster_id * KThreadClusterSize +
thread_k_cluster_id) = tmpIndex;
}
__syncthreads();
BlockwiseReduceWithIndex::Reduce(block_reduce_val_buf,
block_reduce_idx_buf,
tmpValue,
tmpIndex,
thread_m_cluster_id,
thread_k_cluster_id);
BlockwiseReduceWithIndex::Reduce(
reduce_work_val_buf, reduce_work_idx_buf, tmpValue, tmpIndex);
AccumulationWithIndex::Calculate(
accu_value_buf(I), tmpValue, accu_index_buf(I), tmpIndex);
accu_value_buf(iM), tmpValue, accu_index_buf(iM), tmpIndex);
});
threadwise_src_load.MoveSrcSliceWindow(in_grid_desc_m_k, in_thread_copy_step);
......@@ -463,38 +436,38 @@ struct GridwiseReduction_mk_to_mk_multiblock_partial_reduce
AccDataType,
decltype(reduced_data_desc),
WorkspaceDesc_M_K,
PassThroughOp<AccDataType>,
PassThroughOp,
Sequence<MThreadSliceSize, 1>,
Sequence<0, 1>,
1,
1,
InMemoryDataOperationEnum_t::Set,
InMemoryDataOperationEnum::Set,
1,
true>(
workspace_desc_m_k,
make_multi_index(blkgroup_id * M_BlockTileSize +
thread_m_cluster_id * MThreadSliceSize,
block_local_id),
PassThroughOp<AccDataType>{});
PassThroughOp{});
auto threadwise_workspace_idx_store =
ThreadwiseTensorSliceTransfer_v1r3<IndexDataType,
IndexDataType,
decltype(reduced_data_desc),
WorkspaceDesc_M_K,
PassThroughOp<IndexDataType>,
PassThroughOp,
Sequence<MThreadSliceSize, 1>,
Sequence<0, 1>,
1,
1,
InMemoryDataOperationEnum_t::Set,
InMemoryDataOperationEnum::Set,
1,
true>(
workspace_desc_m_k,
make_multi_index(blkgroup_id * M_BlockTileSize +
thread_m_cluster_id * MThreadSliceSize,
block_local_id),
PassThroughOp<IndexDataType>{});
PassThroughOp{});
threadwise_workspace_val_store.Run(reduced_data_desc,
make_tuple(I0, I0),
......
......@@ -30,7 +30,9 @@
#include "reduction_common.hpp"
#include "reduction_operator.hpp"
#include "reduction_functions_accumulate.hpp"
#include "reduction_functions_threadwise.hpp"
#include "threadwise_tensor_slice_transfer.hpp"
#include "element_wise_operation.hpp"
namespace ck {
......@@ -50,7 +52,7 @@ __global__ void kernel_reduce_threadwise(const InGridDesc_M_K in_grid_desc_m_k,
const AccElementwiseOperation acc_elementwise_op,
AccDataType alpha,
const InDataType* const __restrict__ p_in_global,
OutDataType beta,
AccDataType beta,
OutDataType* const __restrict__ p_out_global,
IndexDataType* const __restrict__ p_indices_global)
{
......@@ -101,8 +103,20 @@ template <typename InDataType,
index_t OutDstVectorSize>
struct GridwiseReduction_mk_to_m_threadwise
{
template <typename T>
using PassThroughOp = tensor_operation::element_wise::UnaryIdentic<T, T>;
static_assert(((InSrcVectorDim == 0 && MThreadSliceSize % InSrcVectorSize == 0) ||
(InSrcVectorDim == 1 && KThreadSliceSize % InSrcVectorSize == 0)) &&
(MThreadSliceSize % OutDstVectorSize == 0),
"Invalid thread slice sizes and/or vector sizes configuration, please check!");
using ThreadBufferDimAccessOrder =
typename conditional<InSrcVectorDim == 0, Sequence<1, 0>, Sequence<0, 1>>::type;
using ThreadReduceSrcDesc_M_K = decltype(make_naive_tensor_descriptor_packed(
make_tuple(Number<MThreadSliceSize>{}, Number<KThreadSliceSize>{})));
using ThreadReduceDstDesc_M =
decltype(make_naive_tensor_descriptor_packed(make_tuple(Number<MThreadSliceSize>{})));
using PassThroughOp = tensor_operation::element_wise::PassThrough;
static constexpr auto I0 = Number<0>{};
......@@ -112,30 +126,29 @@ struct GridwiseReduction_mk_to_m_threadwise
const AccElementwiseOperation& acc_elementwise_op,
AccDataType alpha,
const InDataType* const __restrict__ p_in_global,
OutDataType beta,
AccDataType beta,
OutDataType* const __restrict__ p_out_global,
IndexDataType* const __restrict__ p_indices_global)
{
using Accumulation =
detail::AccumulateWithNanCheck<PropagateNan, ReduceOperation, AccDataType>;
using ThreadwiseReduce = ThreadwiseReduction<AccDataType,
ThreadReduceSrcDesc_M_K,
ThreadReduceDstDesc_M,
ReduceOperation,
PropagateNan>;
(void)p_indices_global;
const auto zeroVal = ReduceOperation::GetReductionZeroVal();
const auto in_global_buf = make_dynamic_buffer<AddressSpaceEnum_t::Global>(
const auto in_global_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_in_global, in_grid_desc_m_k.GetElementSpaceSize(), type_convert<InDataType>(zeroVal));
auto dst_global_buf = make_dynamic_buffer<AddressSpaceEnum_t::Global>(
auto dst_global_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_out_global, out_grid_desc_m.GetElementSpaceSize());
StaticBuffer<AddressSpaceEnum_t::Vgpr,
AccDataType,
MThreadSliceSize * KThreadSliceSize,
true>
StaticBuffer<AddressSpaceEnum::Vgpr, AccDataType, MThreadSliceSize * KThreadSliceSize, true>
in_thread_buf;
StaticBuffer<AddressSpaceEnum_t::Vgpr, AccDataType, MThreadSliceSize, true> accu_value_buf;
StaticBuffer<AddressSpaceEnum::Vgpr, AccDataType, MThreadSliceSize, true> accu_value_buf;
static_for<0, MThreadSliceSize, 1>{}([&](auto I) { accu_value_buf(I) = zeroVal; });
......@@ -147,17 +160,17 @@ struct GridwiseReduction_mk_to_m_threadwise
index_t thread_global_1d_id = get_block_1d_id() * BlockSize + get_thread_local_1d_id();
auto threadwise_src_load = ThreadwiseTensorSliceTransfer_v2<
InDataType,
auto threadwise_src_load = ThreadwiseTensorSliceTransfer_v2<InDataType,
AccDataType,
InGridDesc_M_K,
decltype(thread_buffer_desc),
ThreadBufferLengths,
typename conditional<InSrcVectorDim == 0, Sequence<1, 0>, Sequence<0, 1>>::type,
ThreadBufferDimAccessOrder,
InSrcVectorDim,
InSrcVectorSize,
1,
false>(in_grid_desc_m_k, make_multi_index(thread_global_1d_id * MThreadSliceSize, 0));
false>(
in_grid_desc_m_k, make_multi_index(thread_global_1d_id * MThreadSliceSize, 0));
constexpr auto in_thread_copy_step = make_multi_index(0, KThreadSliceSize);
......@@ -170,20 +183,17 @@ struct GridwiseReduction_mk_to_m_threadwise
make_tuple(I0, I0),
in_thread_buf);
static_for<0, MThreadSliceSize, 1>{}([&](auto I) {
static_for<0, MThreadSliceSize, 1>{}([&](auto iM) {
// do element-wise pre-reduction operation
static_for<0, KThreadSliceSize, 1>{}([&](auto J) {
constexpr auto offset = I * Number<KThreadSliceSize>{} + J;
in_elementwise_op(in_thread_buf(offset), in_thread_buf(offset));
});
// reduce on each thread-local slice
static_for<0, KThreadSliceSize, 1>{}([&](auto J) {
constexpr auto offset = I * Number<KThreadSliceSize>{} + J;
Accumulation::Calculate(accu_value_buf(I), in_thread_buf[offset]);
static_for<0, KThreadSliceSize, 1>{}([&](auto iK) {
constexpr auto offset = thread_buffer_desc.CalculateOffset(make_tuple(iM, iK));
in_elementwise_op(in_thread_buf(Number<offset>{}),
in_thread_buf(Number<offset>{}));
});
});
ThreadwiseReduce::Reduce(in_thread_buf, accu_value_buf);
threadwise_src_load.MoveSrcSliceWindow(in_grid_desc_m_k, in_thread_copy_step);
reducedLength += KThreadSliceSize;
......@@ -195,8 +205,7 @@ struct GridwiseReduction_mk_to_m_threadwise
accu_value_buf(I) *= alpha;
});
constexpr auto reduced_data_desc =
make_naive_tensor_descriptor_packed(make_tuple(Number<MThreadSliceSize>{}));
constexpr auto reduced_data_desc = ThreadReduceDstDesc_M{};
if constexpr(!BetaIsZero)
{
......@@ -215,7 +224,7 @@ struct GridwiseReduction_mk_to_m_threadwise
true>(
out_grid_desc_m, make_multi_index(thread_global_1d_id * MThreadSliceSize));
StaticBuffer<AddressSpaceEnum_t::Vgpr, OutDataType, MThreadSliceSize, true>
StaticBuffer<AddressSpaceEnum::Vgpr, OutDataType, MThreadSliceSize, true>
priorDstValue_buf;
threadwise_dst_load.Run(out_grid_desc_m,
......@@ -225,7 +234,7 @@ struct GridwiseReduction_mk_to_m_threadwise
priorDstValue_buf);
static_for<0, MThreadSliceSize, 1>{}([&](auto I) {
accu_value_buf(I) += type_convert<AccDataType>(priorDstValue_buf[I] * beta);
accu_value_buf(I) += type_convert<AccDataType>(priorDstValue_buf[I]) * beta;
});
};
};
......@@ -235,17 +244,17 @@ struct GridwiseReduction_mk_to_m_threadwise
OutDataType,
decltype(reduced_data_desc),
OutGridDesc_M,
PassThroughOp<AccDataType>,
PassThroughOp,
Sequence<MThreadSliceSize>,
Sequence<0>,
0,
OutDstVectorSize,
InMemoryDataOperationEnum_t::Set,
InMemoryDataOperationEnum::Set,
1,
false>(
out_grid_desc_m,
make_multi_index(thread_global_1d_id * MThreadSliceSize),
PassThroughOp<AccDataType>{});
PassThroughOp{});
threadwise_dst_store.Run(
reduced_data_desc, make_tuple(I0), accu_value_buf, out_grid_desc_m, dst_global_buf);
......@@ -257,34 +266,39 @@ struct GridwiseReduction_mk_to_m_threadwise
const AccElementwiseOperation& acc_elementwise_op,
AccDataType alpha,
const InDataType* const __restrict__ p_in_global,
OutDataType beta,
AccDataType beta,
OutDataType* const __restrict__ p_out_global,
IndexDataType* const __restrict__ p_indices_global)
{
using AccumulationWithIndex = detail::AccumulateWithIndexAndNanCheck<PropagateNan,
using ThreadwiseReduceWithIndex = ThreadwiseReductionWithIndex<AccDataType,
IndexDataType,
ThreadReduceSrcDesc_M_K,
ThreadReduceDstDesc_M,
ReduceOperation,
AccDataType,
IndexDataType>;
PropagateNan>;
(void)acc_elementwise_op;
const auto zeroVal = ReduceOperation::GetReductionZeroVal();
const auto in_global_buf = make_dynamic_buffer<AddressSpaceEnum_t::Global>(
const auto in_global_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_in_global, in_grid_desc_m_k.GetElementSpaceSize(), type_convert<InDataType>(zeroVal));
auto out_global_val_buf = make_dynamic_buffer<AddressSpaceEnum_t::Global>(
auto out_global_val_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_out_global, out_grid_desc_m.GetElementSpaceSize());
auto out_global_idx_buf = make_dynamic_buffer<AddressSpaceEnum_t::Global>(
auto out_global_idx_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_indices_global, out_grid_desc_m.GetElementSpaceSize());
StaticBuffer<AddressSpaceEnum_t::Vgpr,
AccDataType,
StaticBuffer<AddressSpaceEnum::Vgpr, AccDataType, MThreadSliceSize * KThreadSliceSize, true>
in_thread_val_buf;
StaticBuffer<AddressSpaceEnum::Vgpr,
IndexDataType,
MThreadSliceSize * KThreadSliceSize,
true>
in_thread_buf;
in_thread_idx_buf;
StaticBuffer<AddressSpaceEnum_t::Vgpr, AccDataType, MThreadSliceSize, true> accu_value_buf;
StaticBuffer<AddressSpaceEnum_t::Vgpr, IndexDataType, MThreadSliceSize, true>
accu_index_buf;
StaticBuffer<AddressSpaceEnum::Vgpr, AccDataType, MThreadSliceSize, true> accu_value_buf;
StaticBuffer<AddressSpaceEnum::Vgpr, IndexDataType, MThreadSliceSize, true> accu_index_buf;
static_for<0, MThreadSliceSize, 1>{}([&](auto I) {
accu_value_buf(I) = zeroVal;
......@@ -299,17 +313,17 @@ struct GridwiseReduction_mk_to_m_threadwise
index_t thread_global_1d_id = get_block_1d_id() * BlockSize + get_thread_local_1d_id();
auto threadwise_src_load = ThreadwiseTensorSliceTransfer_v2<
InDataType,
auto threadwise_src_load = ThreadwiseTensorSliceTransfer_v2<InDataType,
AccDataType,
InGridDesc_M_K,
decltype(thread_buffer_desc),
ThreadBufferLengths,
typename conditional<InSrcVectorDim == 0, Sequence<1, 0>, Sequence<0, 1>>::type,
ThreadBufferDimAccessOrder,
InSrcVectorDim,
InSrcVectorSize,
1,
false>(in_grid_desc_m_k, make_multi_index(thread_global_1d_id * MThreadSliceSize, 0));
false>(
in_grid_desc_m_k, make_multi_index(thread_global_1d_id * MThreadSliceSize, 0));
constexpr auto in_thread_copy_step = make_multi_index(0, KThreadSliceSize);
......@@ -321,26 +335,23 @@ struct GridwiseReduction_mk_to_m_threadwise
in_global_buf,
thread_buffer_desc,
make_tuple(I0, I0),
in_thread_buf);
in_thread_val_buf);
static_for<0, MThreadSliceSize, 1>{}([&](auto I) {
static_for<0, MThreadSliceSize, 1>{}([&](auto iM) {
// do element-wise pre-reduction operation
static_for<0, KThreadSliceSize, 1>{}([&](auto J) {
constexpr auto offset = I * Number<KThreadSliceSize>{} + J;
static_for<0, KThreadSliceSize, 1>{}([&](auto iK) {
constexpr auto offset = thread_buffer_desc.CalculateOffset(make_tuple(iM, iK));
in_elementwise_op(in_thread_buf(offset), in_thread_buf(offset));
});
in_thread_idx_buf(Number<offset>{}) = indexStart + iK();
// reduce on each thread-local slice
static_for<0, KThreadSliceSize, 1>{}([&](auto J) {
constexpr auto offset = I * Number<KThreadSliceSize>{} + J;
AccumulationWithIndex::Calculate(accu_value_buf(I),
in_thread_buf[offset],
accu_index_buf(I),
indexStart + J);
in_elementwise_op(in_thread_val_buf(Number<offset>{}),
in_thread_val_buf(Number<offset>{}));
});
});
ThreadwiseReduceWithIndex::Reduce(
in_thread_val_buf, in_thread_idx_buf, accu_value_buf, accu_index_buf);
threadwise_src_load.MoveSrcSliceWindow(in_grid_desc_m_k, in_thread_copy_step);
indexStart += KThreadSliceSize;
......@@ -354,8 +365,7 @@ struct GridwiseReduction_mk_to_m_threadwise
accu_value_buf(I) *= alpha;
});
constexpr auto reduced_data_desc =
make_naive_tensor_descriptor_packed(make_tuple(Number<MThreadSliceSize>{}));
constexpr auto reduced_data_desc = ThreadReduceDstDesc_M{};
if constexpr(!BetaIsZero)
{
......@@ -374,7 +384,7 @@ struct GridwiseReduction_mk_to_m_threadwise
false>(
out_grid_desc_m, make_multi_index(thread_global_1d_id * MThreadSliceSize));
StaticBuffer<AddressSpaceEnum_t::Vgpr, OutDataType, MThreadSliceSize, true>
StaticBuffer<AddressSpaceEnum::Vgpr, OutDataType, MThreadSliceSize, true>
priorDstValue_buf;
threadwise_dst_load.Run(out_grid_desc_m,
......@@ -384,7 +394,7 @@ struct GridwiseReduction_mk_to_m_threadwise
priorDstValue_buf);
static_for<0, MThreadSliceSize, 1>{}([&](auto I) {
accu_value_buf(I) += type_convert<AccDataType>(priorDstValue_buf[I] * beta);
accu_value_buf(I) += type_convert<AccDataType>(priorDstValue_buf[I]) * beta;
});
};
};
......@@ -394,34 +404,34 @@ struct GridwiseReduction_mk_to_m_threadwise
OutDataType,
decltype(reduced_data_desc),
OutGridDesc_M,
PassThroughOp<AccDataType>,
PassThroughOp,
Sequence<MThreadSliceSize>,
Sequence<0>,
0,
OutDstVectorSize,
InMemoryDataOperationEnum_t::Set,
InMemoryDataOperationEnum::Set,
1,
false>(
out_grid_desc_m,
make_multi_index(thread_global_1d_id * MThreadSliceSize),
PassThroughOp<AccDataType>{});
PassThroughOp{});
auto threadwise_dst_idx_store =
ThreadwiseTensorSliceTransfer_v1r3<IndexDataType,
IndexDataType,
decltype(reduced_data_desc),
OutGridDesc_M,
PassThroughOp<IndexDataType>,
PassThroughOp,
Sequence<MThreadSliceSize>,
Sequence<0>,
0,
OutDstVectorSize,
InMemoryDataOperationEnum_t::Set,
InMemoryDataOperationEnum::Set,
1,
false>(
out_grid_desc_m,
make_multi_index(thread_global_1d_id * MThreadSliceSize),
PassThroughOp<IndexDataType>{});
PassThroughOp{});
threadwise_dst_val_store.Run(
reduced_data_desc, make_tuple(I0), accu_value_buf, out_grid_desc_m, out_global_val_buf);
......
#ifndef CK_GRIDWISE_BATCHED_GEMM_XDLOPS_V2R3_HPP
#define CK_GRIDWISE_BATCHED_GEMM_XDLOPS_V2R3_HPP
#include "common_header.hpp"
#include "multi_index_transform_helper.hpp"
#include "tensor_descriptor.hpp"
#include "tensor_descriptor_helper.hpp"
#include "blockwise_gemm_xdlops.hpp"
#include "blockwise_tensor_slice_transfer_v4r1.hpp"
#include "threadwise_tensor_slice_transfer.hpp"
namespace ck {
template <typename GridwiseBatchedGemm,
typename FloatAB,
typename FloatC,
typename AGridDesc_G_K0_M_K1,
typename BGridDesc_G_K0_N_K1,
typename CGridDesc_G_M0_N0_M1_N1_M2_M3_M4_N2,
typename AElementwiseOperation,
typename BElementwiseOperation,
typename CElementwiseOperation,
typename Block2CTileMap,
bool HasMainKBlockLoop>
__global__ void
#if CK_USE_LAUNCH_BOUNDS
__launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU)
#endif
kernel_batched_gemm_xdlops_v2r3(
const FloatAB* __restrict__ p_a_grid,
const FloatAB* __restrict__ p_b_grid,
FloatC* __restrict__ p_c_grid,
const AGridDesc_G_K0_M_K1 a_grid_desc_g_k0_m_k1,
const BGridDesc_G_K0_N_K1 b_grid_desc_g_k0_n_k1,
const CGridDesc_G_M0_N0_M1_N1_M2_M3_M4_N2 c_grid_desc_g_m0_n0_m1_n1_m2_m3_m4_n2,
const AElementwiseOperation a_element_op,
const BElementwiseOperation b_element_op,
const CElementwiseOperation c_element_op,
const Block2CTileMap block_2_ctile_map)
{
__shared__ char p_shared[GridwiseBatchedGemm::GetSharedMemoryNumberOfByte()];
GridwiseBatchedGemm::template Run<HasMainKBlockLoop>(p_a_grid,
p_b_grid,
p_c_grid,
p_shared,
a_grid_desc_g_k0_m_k1,
b_grid_desc_g_k0_n_k1,
c_grid_desc_g_m0_n0_m1_n1_m2_m3_m4_n2,
a_element_op,
b_element_op,
c_element_op,
block_2_ctile_map);
}
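// All LDS is staged through a single raw char array sized by
// GetSharedMemoryNumberOfByte(); the gridwise GEMM carves the A and B tiles out
// of it, which is why the size computation must be alignment-aware.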
template <index_t BlockSize,
typename FloatAB,
typename FloatAcc,
typename FloatC,
InMemoryDataOperationEnum_t CGlobalMemoryDataOperation,
typename AGridDesc_G_K0_M_K1,
typename BGridDesc_G_K0_N_K1,
typename CGridDesc_G_M_N,
typename AElementwiseOperation,
typename BElementwiseOperation,
typename CElementwiseOperation,
index_t MPerBlock,
index_t NPerBlock,
index_t K0PerBlock,
index_t MPerXDL,
index_t NPerXDL,
index_t K1Value,
index_t MXdlPerWave,
index_t NXdlPerWave,
typename ABlockTransferThreadClusterLengths_G_K0_M_K1,
typename ABlockTransferThreadClusterArrangeOrder,
typename ABlockTransferSrcAccessOrder,
index_t ABlockTransferSrcVectorDim,
index_t ABlockTransferSrcScalarPerVector,
index_t ABlockTransferDstScalarPerVector_K1,
bool AThreadTransferSrcResetCoordinateAfterRun,
bool ABlockLdsExtraM,
typename BBlockTransferThreadClusterLengths_G_K0_N_K1,
typename BBlockTransferThreadClusterArrangeOrder,
typename BBlockTransferSrcAccessOrder,
index_t BBlockTransferSrcVectorDim,
index_t BBlockTransferSrcScalarPerVector,
index_t BBlockTransferDstScalarPerVector_K1,
bool BThreadTransferSrcResetCoordinateAfterRun,
bool BBlockLdsExtraN,
typename CThreadTransferSrcDstAccessOrder,
index_t CThreadTransferSrcDstVectorDim,
index_t CThreadTransferDstScalarPerVector>
struct GridwiseBatchedGemm_gk0mk1_gk0nk1_gmn_xdlops_v2r3
{
static constexpr auto I0 = Number<0>{};
static constexpr auto I1 = Number<1>{};
static constexpr auto I2 = Number<2>{};
static constexpr auto I3 = Number<3>{};
static constexpr auto I4 = Number<4>{};
static constexpr auto I5 = Number<5>{};
static constexpr auto I6 = Number<6>{};
static constexpr auto I7 = Number<7>{};
static constexpr auto I8 = Number<8>{};
// K1 should be Number<...>
static constexpr auto K1 = Number<K1Value>{};
__host__ __device__ static constexpr auto
GetABlockDescriptor_BatchCount_K0PerBlock_MPerBlock_K1()
{
constexpr auto max_lds_align = K1;
// A matrix in LDS memory, dst of blockwise copy
constexpr auto a_block_desc_g_k0_m_k1 = [&]() {
if constexpr(ABlockLdsExtraM)
{
return make_naive_tensor_descriptor(
make_tuple(I1, Number<K0PerBlock>{}, Number<MPerBlock>{}, K1),
make_tuple(Number<K0PerBlock>{} * Number<MPerBlock + 1>{} * K1,
Number<MPerBlock + 1>{} * K1,
K1,
I1));
}
else
{
return make_naive_tensor_descriptor_aligned(
make_tuple(I1, Number<K0PerBlock>{}, Number<MPerBlock>{}, K1), max_lds_align);
}
}();
return a_block_desc_g_k0_m_k1;
}
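// With ABlockLdsExtraM each K0 slab is padded by one extra K1-wide row
// ((MPerBlock + 1) * K1 instead of MPerBlock * K1 elements), presumably to
// stagger the slabs across LDS banks and reduce bank conflicts.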
__host__ __device__ static constexpr auto
GetBBlockDescriptor_BatchCount_K0PerBlock_NPerBlock_K1()
{
constexpr auto max_lds_align = K1;
// B matrix in LDS memory, dst of blockwise copy
constexpr auto b_block_desc_g_k0_n_k1 = [&]() {
if constexpr(BBlockLdsExtraN)
{
return make_naive_tensor_descriptor(
make_tuple(I1, Number<K0PerBlock>{}, Number<NPerBlock>{}, K1),
make_tuple(Number<K0PerBlock>{} * Number<NPerBlock + 1>{} * K1,
Number<NPerBlock + 1>{} * K1,
K1,
I1));
}
else
{
return make_naive_tensor_descriptor_aligned(
make_tuple(I1, Number<K0PerBlock>{}, Number<NPerBlock>{}, K1), max_lds_align);
}
}();
return b_block_desc_g_k0_n_k1;
}
__host__ __device__ static constexpr auto GetABlockDescriptor_K0PerBlock_MPerBlock_K1()
{
constexpr auto a_block_desc_g_k0_m_k1 =
GetABlockDescriptor_BatchCount_K0PerBlock_MPerBlock_K1();
constexpr auto K0 = a_block_desc_g_k0_m_k1.GetLength(I1);
constexpr auto M = a_block_desc_g_k0_m_k1.GetLength(I2);
constexpr auto a_block_desc_k0_m_k1 = transform_tensor_descriptor(
a_block_desc_g_k0_m_k1,
make_tuple(make_freeze_transform(I0),
make_pass_through_transform(K0),
make_pass_through_transform(M),
make_pass_through_transform(K1)),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}),
make_tuple(Sequence<>{}, Sequence<0>{}, Sequence<1>{}, Sequence<2>{}));
return a_block_desc_k0_m_k1;
}
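// make_freeze_transform(I0) pins the batch coordinate to 0 and, paired with the
// empty upper Sequence<>, drops that dimension entirely, leaving the 3-D
// (K0PerBlock, MPerBlock, K1) view the blockwise GEMM expects; B gets the same
// treatment below.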
__host__ __device__ static constexpr auto GetBBlockDescriptor_K0PerBlock_NPerBlock_K1()
{
constexpr auto b_block_desc_g_k0_n_k1 =
GetBBlockDescriptor_BatchCount_K0PerBlock_NPerBlock_K1();
constexpr auto K0 = b_block_desc_g_k0_n_k1.GetLength(I1);
constexpr auto N = b_block_desc_g_k0_n_k1.GetLength(I2);
constexpr auto b_block_desc_k0_n_k1 = transform_tensor_descriptor(
b_block_desc_g_k0_n_k1,
make_tuple(make_freeze_transform(I0),
make_pass_through_transform(K0),
make_pass_through_transform(N),
make_pass_through_transform(K1)),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}),
make_tuple(Sequence<>{}, Sequence<0>{}, Sequence<1>{}, Sequence<2>{}));
return b_block_desc_k0_n_k1;
}
__host__ __device__ static constexpr index_t GetSharedMemoryNumberOfByte()
{
// LDS allocation for A and B: be careful of alignment
constexpr auto a_block_desc_g_k0_m_k1 =
GetABlockDescriptor_BatchCount_K0PerBlock_MPerBlock_K1();
constexpr auto b_block_desc_g_k0_n_k1 =
GetBBlockDescriptor_BatchCount_K0PerBlock_NPerBlock_K1();
constexpr auto max_lds_align = K1;
constexpr auto a_block_space_size_aligned = math::integer_least_multiple(
a_block_desc_g_k0_m_k1.GetElementSpaceSize(), max_lds_align);
constexpr auto b_block_space_size_aligned = math::integer_least_multiple(
b_block_desc_g_k0_n_k1.GetElementSpaceSize(), max_lds_align);
return (a_block_space_size_aligned + b_block_space_size_aligned) * sizeof(FloatAB);
}
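// Worked example with hypothetical parameters: K0PerBlock = 4, MPerBlock =
// NPerBlock = 128, K1 = 8, FloatAB = half (2 bytes), no extra LDS padding:
//   A tile = B tile = 1 * 4 * 128 * 8 = 4096 elements (already K1-aligned),
//   so LDS = (4096 + 4096) * 2 bytes = 16 KiB.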
// block_id to matrix tile idx (m0, n0) mapping is controlled by {M01, N01}
__host__ __device__ static constexpr bool
CheckValidity(const AGridDesc_G_K0_M_K1& a_grid_desc_g_k0_m_k1,
const BGridDesc_G_K0_N_K1& b_grid_desc_g_k0_n_k1,
const CGridDesc_G_M_N& c_grid_desc_g_m_n,
index_t M01,
index_t N01)
{
static_assert(is_known_at_compile_time<remove_cv_t<decltype(K1)>>::value,
"wrong! K1 need to be known at compile-time");
static_assert((MPerBlock % (MPerXDL * MXdlPerWave) == 0) &&
(NPerBlock % (NXdlPerWave * NPerXDL)) == 0,
"Invalid tuning param!");
// const auto G = a_grid_desc_g_k0_m_k1.GetLength(I0);
const auto K0 = a_grid_desc_g_k0_m_k1.GetLength(I1);
const auto M = a_grid_desc_g_k0_m_k1.GetLength(I2);
const auto N = b_grid_desc_g_k0_n_k1.GetLength(I2);
if(!(M == c_grid_desc_g_m_n.GetLength(I1) && N == c_grid_desc_g_m_n.GetLength(I2) &&
K0 == b_grid_desc_g_k0_n_k1.GetLength(I1) &&
K1 == a_grid_desc_g_k0_m_k1.GetLength(I3) &&
K1 == b_grid_desc_g_k0_n_k1.GetLength(I3)))
return false;
if(!(M % MPerBlock == 0 && N % NPerBlock == 0 && K0 % K0PerBlock == 0))
return false;
// check M01, N01
constexpr auto M1 = Number<MPerBlock>{};
constexpr auto N1 = Number<NPerBlock>{};
const auto M0 = M / M1;
const auto N0 = N / N1;
if(!(M0 % M01 == 0 && N0 % N01 == 0))
return false;
// TODO: also check validity of all components (blockwise-copy, threadwise-copy, etc)
return true;
}
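// CheckValidity runs on the host before launch: the kernel has no
// partial-tile path, so e.g. M = 1000 with MPerBlock = 128 is rejected
// (1000 % 128 != 0) while M = 1024 passes, and mismatched A/B/C descriptor
// lengths are caught here instead of producing out-of-bounds accesses.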
__host__ __device__ static constexpr index_t
CalculateGridSize(const CGridDesc_G_M_N& c_grid_desc_g_m_n)
{
const auto G = c_grid_desc_g_m_n.GetLength(I0);
const auto M = c_grid_desc_g_m_n.GetLength(I1);
const auto N = c_grid_desc_g_m_n.GetLength(I2);
const index_t grid_size = G * (M / MPerBlock) * (N / NPerBlock);
return grid_size;
}
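// Example (illustrative values): G = 8, M = 1024, N = 768 with
// MPerBlock = 256 and NPerBlock = 128 gives grid_size = 8 * 4 * 6 = 192
// workgroups, one per (batch, M-tile, N-tile) triple.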
__host__ __device__ static constexpr bool CalculateHasMainK0BlockLoop(index_t K0)
{
const bool has_main_k0_block_loop = (K0 / K0PerBlock) > 1;
return has_main_k0_block_loop;
}
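// When K0 / K0PerBlock == 1 only the "tail" GEMM in Run() executes; the
// do/while main body is compiled out via the HasMainKBlockLoop template
// argument, so small-K problems pay no loop bookkeeping.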
__host__ __device__ static constexpr auto
MakeCGridDescriptor_G_M0_N0_M1_N1_M2_M3_M4_N2(const CGridDesc_G_M_N& c_grid_desc_g_m_n)
{
constexpr auto max_lds_align = K1;
// A matrix in LDS memory, dst of blockwise copy
constexpr auto a_block_desc_k0_m_k1 = [&]() {
if constexpr(ABlockLdsExtraM)
{
return make_naive_tensor_descriptor(
make_tuple(Number<K0PerBlock>{}, Number<MPerBlock>{}, K1),
make_tuple(Number<MPerBlock + 1>{} * K1, K1, I1));
}
else
{
return make_naive_tensor_descriptor_aligned(
make_tuple(Number<K0PerBlock>{}, Number<MPerBlock>{}, K1), max_lds_align);
}
}();
// B matrix in LDS memory, dst of blockwise copy
constexpr auto b_block_desc_k0_n_k1 = [&]() {
if constexpr(BBlockLdsExtraN)
{
return make_naive_tensor_descriptor(
make_tuple(Number<K0PerBlock>{}, Number<NPerBlock>{}, K1),
make_tuple(Number<NPerBlock + 1>{} * K1, K1, I1));
}
else
{
return make_naive_tensor_descriptor_aligned(
make_tuple(Number<K0PerBlock>{}, Number<NPerBlock>{}, K1), max_lds_align);
}
}();
using BlockwiseGemm =
BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1<BlockSize,
FloatAB,
FloatAcc,
decltype(a_block_desc_k0_m_k1),
decltype(b_block_desc_k0_n_k1),
MPerXDL,
NPerXDL,
MXdlPerWave,
NXdlPerWave,
K1>;
return BlockwiseGemm::MakeCGridDescriptor_G_M0_N0_M1_N1_M2_M3_M4_N2(c_grid_desc_g_m_n);
}
// return block_id to C matrix tile idx (m0, n0) mapping
__host__ __device__ static constexpr auto
MakeDefaultBlock2CTileMap(const CGridDesc_G_M_N& c_grid_desc_g_m_n, index_t M01, index_t N01)
{
const auto G = c_grid_desc_g_m_n.GetLength(I0);
const auto M = c_grid_desc_g_m_n.GetLength(I1);
const auto N = c_grid_desc_g_m_n.GetLength(I2);
constexpr auto M1 = Number<MPerBlock>{};
constexpr auto N1 = Number<NPerBlock>{};
const auto M0 = M / M1;
const auto N0 = N / N1;
const auto M00 = M0 / M01;
const auto N00 = N0 / N01;
const auto g_m00_m01_n00_n01_to_m0_n0_block_cluster_adaptor =
make_single_stage_tensor_adaptor(
make_tuple(make_pass_through_transform(G),
make_unmerge_transform(make_tuple(M00, M01)),
make_unmerge_transform(make_tuple(N00, N01))),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}),
make_tuple(Sequence<0>{}, Sequence<1, 3>{}, Sequence<2, 4>{}));
const auto cblockid_to_g_m00_m01_n00_n01_block_cluster_adaptor =
make_single_stage_tensor_adaptor(
make_tuple(make_merge_transform(make_tuple(G, M00, N00, M01, N01))),
make_tuple(Sequence<0, 1, 2, 3, 4>{}),
make_tuple(Sequence<0>{}));
const auto cblockid_to_g_m0_n0_block_cluster_adaptor =
chain_tensor_adaptors(g_m00_m01_n00_n01_to_m0_n0_block_cluster_adaptor,
cblockid_to_g_m00_m01_n00_n01_block_cluster_adaptor);
return cblockid_to_g_m0_n0_block_cluster_adaptor;
}
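// Decode example (illustrative values): with G = 2, M0 = N0 = 4 and
// M01 = N01 = 1, the merge above lays block ids out row-major over
// (g, m00, n00): block_id 13 -> g = 13 / 16 = 0, m00 = 13 / 4 = 3,
// n00 = 13 % 4 = 1, i.e. that workgroup computes batch 0, C tile (3, 1).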
using CGridDesc_G_M0_N0_M1_N1_M2_M3_M4_N2 =
decltype(MakeCGridDescriptor_G_M0_N0_M1_N1_M2_M3_M4_N2(CGridDesc_G_M_N{}));
using DefaultBlock2CTileMap = decltype(MakeDefaultBlock2CTileMap(CGridDesc_G_M_N{}, 1, 1));
template <bool HasMainKBlockLoop, typename Block2CTileMap = DefaultBlock2CTileMap>
__device__ static void
Run(const FloatAB* __restrict__ p_a_grid,
const FloatAB* __restrict__ p_b_grid,
FloatC* __restrict__ p_c_grid,
void* __restrict__ p_shared,
const AGridDesc_G_K0_M_K1& a_grid_desc_g_k0_m_k1,
const BGridDesc_G_K0_N_K1& b_grid_desc_g_k0_n_k1,
const CGridDesc_G_M0_N0_M1_N1_M2_M3_M4_N2& c_grid_desc_g_m0_n0_m1_n1_m2_m3_m4_n2,
const AElementwiseOperation& a_element_op,
const BElementwiseOperation& b_element_op,
const CElementwiseOperation& c_element_op,
const Block2CTileMap& block_2_ctile_map)
{
const auto a_grid_buf = make_dynamic_buffer<AddressSpaceEnum_t::Global>(
p_a_grid, a_grid_desc_g_k0_m_k1.GetElementSpaceSize());
const auto b_grid_buf = make_dynamic_buffer<AddressSpaceEnum_t::Global>(
p_b_grid, b_grid_desc_g_k0_n_k1.GetElementSpaceSize());
auto c_grid_buf = make_dynamic_buffer<AddressSpaceEnum_t::Global>(
p_c_grid, c_grid_desc_g_m0_n0_m1_n1_m2_m3_m4_n2.GetElementSpaceSize());
const auto K0 = a_grid_desc_g_k0_m_k1.GetLength(I1);
// divide block work by [M, N]
const auto block_work_idx =
block_2_ctile_map.CalculateBottomIndex(make_multi_index(get_block_1d_id()));
// HACK: this forces m/n_block_data_idx_on_grid into SGPR
const index_t g_idx_on_grid = __builtin_amdgcn_readfirstlane(block_work_idx[I0]);
const index_t m_block_data_idx_on_grid =
__builtin_amdgcn_readfirstlane(block_work_idx[I1] * MPerBlock);
const index_t n_block_data_idx_on_grid =
__builtin_amdgcn_readfirstlane(block_work_idx[I2] * NPerBlock);
// lds max alignment
constexpr auto max_lds_align = K1;
// A matrix in LDS memory, dst of blockwise copy
constexpr auto a_block_desc_g_k0_m_k1 =
GetABlockDescriptor_BatchCount_K0PerBlock_MPerBlock_K1();
// B matrix in LDS memory, dst of blockwise copy
constexpr auto b_block_desc_g_k0_n_k1 =
GetBBlockDescriptor_BatchCount_K0PerBlock_NPerBlock_K1();
// A matrix blockwise copy
auto a_blockwise_copy =
BlockwiseTensorSliceTransfer_v4r1<BlockSize,
AElementwiseOperation,
ck::tensor_operation::element_wise::PassThrough,
InMemoryDataOperationEnum_t::Set,
Sequence<1, K0PerBlock, MPerBlock, K1>,
ABlockTransferThreadClusterLengths_G_K0_M_K1,
ABlockTransferThreadClusterArrangeOrder,
FloatAB,
FloatAB,
decltype(a_grid_desc_g_k0_m_k1),
decltype(a_block_desc_g_k0_m_k1),
ABlockTransferSrcAccessOrder,
Sequence<0, 2, 1, 3>,
ABlockTransferSrcVectorDim,
3,
ABlockTransferSrcScalarPerVector,
ABlockTransferDstScalarPerVector_K1,
1,
1,
AThreadTransferSrcResetCoordinateAfterRun,
true>(
a_grid_desc_g_k0_m_k1,
make_multi_index(g_idx_on_grid, 0, m_block_data_idx_on_grid, 0),
a_element_op,
a_block_desc_g_k0_m_k1,
make_multi_index(0, 0, 0, 0),
ck::tensor_operation::element_wise::PassThrough{});
// B matrix blockwise copy
auto b_blockwise_copy =
BlockwiseTensorSliceTransfer_v4r1<BlockSize,
BElementwiseOperation,
ck::tensor_operation::element_wise::PassThrough,
InMemoryDataOperationEnum_t::Set,
Sequence<1, K0PerBlock, NPerBlock, K1>,
BBlockTransferThreadClusterLengths_G_K0_N_K1,
BBlockTransferThreadClusterArrangeOrder,
FloatAB,
FloatAB,
decltype(b_grid_desc_g_k0_n_k1),
decltype(b_block_desc_g_k0_n_k1),
BBlockTransferSrcAccessOrder,
Sequence<0, 2, 1, 3>,
BBlockTransferSrcVectorDim,
3,
BBlockTransferSrcScalarPerVector,
BBlockTransferDstScalarPerVector_K1,
1,
1,
BThreadTransferSrcResetCoordinateAfterRun,
true>(
b_grid_desc_g_k0_n_k1,
make_multi_index(g_idx_on_grid, 0, n_block_data_idx_on_grid, 0),
b_element_op,
b_block_desc_g_k0_n_k1,
make_multi_index(0, 0, 0, 0),
ck::tensor_operation::element_wise::PassThrough{});
// GEMM definition
// c_mtx += transpose(a_mtx) * b_mtx
// a_mtx[K0PerBlock, MPerBlock] is in LDS
// b_mtx[K0PerBlock, NPerBlock] is in LDS
// c_mtx[MPerBlock, NPerBlock] is distributed among threads and saved in
// registers
// sanity check
constexpr auto a_block_desc_k0_m_k1 = GetABlockDescriptor_K0PerBlock_MPerBlock_K1();
constexpr auto b_block_desc_k0_n_k1 = GetBBlockDescriptor_K0PerBlock_NPerBlock_K1();
auto blockwise_gemm =
BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1<BlockSize,
FloatAB,
FloatAcc,
decltype(a_block_desc_k0_m_k1),
decltype(b_block_desc_k0_n_k1),
MPerXDL,
NPerXDL,
MXdlPerWave,
NXdlPerWave,
K1>{};
auto c_thread_buf = blockwise_gemm.GetCThreadBuffer();
// LDS allocation for A and B: be careful of alignment
constexpr auto a_block_space_size_aligned = math::integer_least_multiple(
a_block_desc_g_k0_m_k1.GetElementSpaceSize(), max_lds_align);
auto a_block_buf = make_dynamic_buffer<AddressSpaceEnum_t::Lds>(
static_cast<FloatAB*>(p_shared), a_block_desc_g_k0_m_k1.GetElementSpaceSize());
auto b_block_buf = make_dynamic_buffer<AddressSpaceEnum_t::Lds>(
static_cast<FloatAB*>(p_shared) + a_block_space_size_aligned,
b_block_desc_g_k0_n_k1.GetElementSpaceSize());
constexpr auto a_block_slice_copy_step = make_multi_index(0, K0PerBlock, 0, 0);
constexpr auto b_block_slice_copy_step = make_multi_index(0, K0PerBlock, 0, 0);
// preload data into LDS
{
a_blockwise_copy.RunRead(a_grid_desc_g_k0_m_k1, a_grid_buf);
b_blockwise_copy.RunRead(b_grid_desc_g_k0_n_k1, b_grid_buf);
a_blockwise_copy.RunWrite(a_block_desc_g_k0_m_k1, a_block_buf);
b_blockwise_copy.RunWrite(b_block_desc_g_k0_n_k1, b_block_buf);
}
// Initialize C
c_thread_buf.Clear();
// main body
if constexpr(HasMainKBlockLoop)
{
index_t k0_block_data_begin = 0;
do
{
a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc_g_k0_m_k1, a_block_slice_copy_step);
b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc_g_k0_n_k1, b_block_slice_copy_step);
a_blockwise_copy.RunRead(a_grid_desc_g_k0_m_k1, a_grid_buf);
block_sync_lds();
b_blockwise_copy.RunRead(b_grid_desc_g_k0_n_k1, b_grid_buf);
blockwise_gemm.Run(a_block_buf, b_block_buf, c_thread_buf);
block_sync_lds();
a_blockwise_copy.RunWrite(a_block_desc_g_k0_m_k1, a_block_buf);
b_blockwise_copy.RunWrite(b_block_desc_g_k0_n_k1, b_block_buf);
k0_block_data_begin += K0PerBlock;
} while(k0_block_data_begin < (K0 - K0PerBlock));
}
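// Single-LDS-buffer pipeline: RunRead issues the next tile's global loads
// into registers while blockwise_gemm consumes the tile already in LDS; the
// first block_sync_lds() makes the previous RunWrite visible to the MFMA
// reads, and the second keeps RunWrite from overwriting LDS while those
// reads are still in flight.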
// tail
{
block_sync_lds();
blockwise_gemm.Run(a_block_buf, b_block_buf, c_thread_buf);
}
// output: register to global memory
{
constexpr auto c_thread_desc_g_m0_n0_m1_n1_m2_m3_m4_n2 =
blockwise_gemm.GetCThreadDescriptor_G_M0_N0_M1_N1_M2_M3_M4_N2();
constexpr auto c_block_desc_g_m0_n0_m1_n1_m2_m3_m4_n2 =
blockwise_gemm.GetCBlockDescriptor_G_M0_N0_M1_N1_M2_M3_M4_N2();
// constexpr auto G = c_block_desc_g_m0_n0_m1_n1_m2_m3_m4_n2.GetLength(I0);
constexpr auto M0 = c_block_desc_g_m0_n0_m1_n1_m2_m3_m4_n2.GetLength(I1);
constexpr auto N0 = c_block_desc_g_m0_n0_m1_n1_m2_m3_m4_n2.GetLength(I2);
constexpr auto M1 = c_block_desc_g_m0_n0_m1_n1_m2_m3_m4_n2.GetLength(I3);
constexpr auto N1 = c_block_desc_g_m0_n0_m1_n1_m2_m3_m4_n2.GetLength(I4);
constexpr auto M2 = c_block_desc_g_m0_n0_m1_n1_m2_m3_m4_n2.GetLength(I5);
constexpr auto M3 = c_block_desc_g_m0_n0_m1_n1_m2_m3_m4_n2.GetLength(I6);
constexpr auto M4 = c_block_desc_g_m0_n0_m1_n1_m2_m3_m4_n2.GetLength(I7);
constexpr auto N2 = c_block_desc_g_m0_n0_m1_n1_m2_m3_m4_n2.GetLength(I8);
// calculate origin of thread output tensor on global memory
// blockwise GEMM c matrix starting index
const auto c_thread_mtx_on_block =
blockwise_gemm.CalculateCThreadOriginDataIndex(I0, I0, I0, I0);
const index_t m_thread_data_on_grid =
m_block_data_idx_on_grid + c_thread_mtx_on_block[I0];
const index_t n_thread_data_on_grid =
n_block_data_idx_on_grid + c_thread_mtx_on_block[I1];
const auto m_thread_data_on_grid_to_m0_m1_m2_m3_m4_adaptor =
make_single_stage_tensor_adaptor(
make_tuple(make_merge_transform(make_tuple(M0, M1, M2, M3, M4))),
make_tuple(Sequence<0, 1, 2, 3, 4>{}),
make_tuple(Sequence<0>{}));
const auto m_thread_data_on_grid_idx =
m_thread_data_on_grid_to_m0_m1_m2_m3_m4_adaptor.CalculateBottomIndex(
make_multi_index(m_thread_data_on_grid));
const auto n_thread_data_on_grid_to_n0_n1_n2_adaptor = make_single_stage_tensor_adaptor(
make_tuple(make_merge_transform(make_tuple(N0, N1, N2))),
make_tuple(Sequence<0, 1, 2>{}),
make_tuple(Sequence<0>{}));
const auto n_thread_data_on_grid_idx =
n_thread_data_on_grid_to_n0_n1_n2_adaptor.CalculateBottomIndex(
make_multi_index(n_thread_data_on_grid));
auto c_thread_copy = ThreadwiseTensorSliceTransfer_v1r3<
FloatAcc,
FloatC,
decltype(c_thread_desc_g_m0_n0_m1_n1_m2_m3_m4_n2),
decltype(c_grid_desc_g_m0_n0_m1_n1_m2_m3_m4_n2),
CElementwiseOperation,
Sequence<I1, M0, N0, I1, I1, M2, I1, M4, I1>,
CThreadTransferSrcDstAccessOrder,
CThreadTransferSrcDstVectorDim,
CThreadTransferDstScalarPerVector,
CGlobalMemoryDataOperation,
1,
true>{c_grid_desc_g_m0_n0_m1_n1_m2_m3_m4_n2,
make_multi_index(g_idx_on_grid,
m_thread_data_on_grid_idx[I0],
n_thread_data_on_grid_idx[I0],
m_thread_data_on_grid_idx[I1],
n_thread_data_on_grid_idx[I1],
m_thread_data_on_grid_idx[I2],
m_thread_data_on_grid_idx[I3],
m_thread_data_on_grid_idx[I4],
n_thread_data_on_grid_idx[I2]),
c_element_op};
c_thread_copy.Run(c_thread_desc_g_m0_n0_m1_n1_m2_m3_m4_n2,
make_tuple(I0, I0, I0, I0, I0, I0, I0, I0, I0),
c_thread_buf,
c_grid_desc_g_m0_n0_m1_n1_m2_m3_m4_n2,
c_grid_buf);
}
}
};
} // namespace ck
#endif
......@@ -55,7 +55,7 @@ template <index_t BlockSize,
typename FloatAB,
typename FloatAcc,
typename FloatC,
InMemoryDataOperationEnum_t CGlobalMemoryDataOperation,
InMemoryDataOperationEnum CGlobalMemoryDataOperation,
typename AGridDesc_GK0_GM0_GM1_GK1,
typename BGridDesc_GK0_GN0_GN1_GK1,
typename CGridDesc_GM0_GM1_GN0_GN1,
......@@ -329,11 +329,11 @@ struct GridwiseContractionDlops_A_GK0_GM0_GM1_GK1_B_GK0_GN0_GN1_GK1_C_GM0_GM1_GN
integral_constant<bool, HasMainKBlockLoop>,
integral_constant<bool, HasDoubleTailKBlockLoop>)
{
const auto a_global_buf = make_dynamic_buffer<AddressSpaceEnum_t::Global>(
const auto a_global_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_a_grid, a_grid_desc_gk0_gm0_gm10_gm11_gk1.GetElementSpaceSize());
const auto b_global_buf = make_dynamic_buffer<AddressSpaceEnum_t::Global>(
const auto b_global_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_b_grid, b_grid_desc_gk0_gn0_gn10_gn11_gk1.GetElementSpaceSize());
auto c_grid_buf = make_dynamic_buffer<AddressSpaceEnum_t::Global>(
auto c_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_c_grid, c_grid_desc_gm10_bm0_bm1_gn10_bn0_bn1.GetElementSpaceSize());
const auto GK0 = a_grid_desc_gk0_gm0_gm10_gm11_gk1.GetLength(I0);
......@@ -383,7 +383,7 @@ struct GridwiseContractionDlops_A_GK0_GM0_GM1_GK1_B_GK0_GN0_GN1_GK1_C_GM0_GM1_GN
// A matrix blockwise copy
auto a_blockwise_copy = BlockwiseTensorSliceTransfer_v5r1<
BlockSize,
InMemoryDataOperationEnum_t::Set,
InMemoryDataOperationEnum::Set,
Sequence<GK0PerBlock, GM0, 1, GM1PerBlockGM11, GK1.value>,
ABlockTransferThreadSliceLengths_GK0_GM0_GM10_GM11_GK1,
ABlockTransferThreadClusterLengths_GK0_GM0_GM10_GM11_GK1,
......@@ -407,7 +407,7 @@ struct GridwiseContractionDlops_A_GK0_GM0_GM1_GK1_B_GK0_GN0_GN1_GK1_C_GM0_GM1_GN
// B matrix blockwise copy
auto b_blockwise_copy = BlockwiseTensorSliceTransfer_v5r1<
BlockSize,
InMemoryDataOperationEnum_t::Set,
InMemoryDataOperationEnum::Set,
Sequence<GK0PerBlock, GN0, 1, GN1PerBlockGN11, GK1.value>,
BBlockTransferThreadSliceLengths_GK0_GN0_GN10_GN11_GK1,
BBlockTransferThreadClusterLengths_GK0_GN0_GN10_GN11_GK1,
......@@ -467,7 +467,7 @@ struct GridwiseContractionDlops_A_GK0_GM0_GM1_GK1_B_GK0_GN0_GN1_GK1_C_GM0_GM1_GN
FloatAB* p_b_block_double = p_shared_block + 2 * a_block_aligned_space_size;
// register allocation for output
auto c_thread_buf = make_static_buffer<AddressSpaceEnum_t::Vgpr, FloatAcc>(
auto c_thread_buf = make_static_buffer<AddressSpaceEnum::Vgpr, FloatAcc>(
c_thread_desc_bm0_bm1_bn0_bn1.GetElementSpaceSize());
ThreadwiseTensorSliceSet_v1<FloatAcc,
......@@ -481,15 +481,15 @@ struct GridwiseContractionDlops_A_GK0_GM0_GM1_GK1_B_GK0_GN0_GN1_GK1_C_GM0_GM1_GN
constexpr auto a_block_slice_copy_step = make_multi_index(GK0PerBlock, 0, 0, 0, 0);
constexpr auto b_block_slice_copy_step = make_multi_index(GK0PerBlock, 0, 0, 0, 0);
auto a_block_even_buf = make_dynamic_buffer<AddressSpaceEnum_t::Lds>(
auto a_block_even_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(
p_a_block_double, a_block_desc_gk0_gm0_gm10_gm11_gk1.GetElementSpaceSize());
auto b_block_even_buf = make_dynamic_buffer<AddressSpaceEnum_t::Lds>(
auto b_block_even_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(
p_b_block_double, b_block_desc_gk0_gn0_gn10_gn11_gk1.GetElementSpaceSize());
auto a_block_odd_buf = make_dynamic_buffer<AddressSpaceEnum_t::Lds>(
auto a_block_odd_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(
p_a_block_double + a_block_aligned_space_size,
a_block_desc_gk0_gm0_gm10_gm11_gk1.GetElementSpaceSize());
auto b_block_odd_buf = make_dynamic_buffer<AddressSpaceEnum_t::Lds>(
auto b_block_odd_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(
p_b_block_double + b_block_aligned_space_size,
b_block_desc_gk0_gn0_gn10_gn11_gk1.GetElementSpaceSize());
......
......@@ -55,7 +55,7 @@ template <index_t BlockSize,
typename FloatAB,
typename FloatAcc,
typename FloatC,
InMemoryDataOperationEnum_t CGlobalMemoryDataOperation,
InMemoryDataOperationEnum CGlobalMemoryDataOperation,
typename AKMGridDesc,
typename BKNGridDesc,
typename CMNGridDesc,
......@@ -268,11 +268,11 @@ struct GridwiseGemmDlops_km_kn_mn_v1r2
integral_constant<bool, HasMainKBlockLoop>,
integral_constant<bool, HasDoubleTailKBlockLoop>)
{
const auto a_global_buf = make_dynamic_buffer<AddressSpaceEnum_t::Global>(
const auto a_global_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_a_grid, a_k_m0_m1_grid_desc.GetElementSpaceSize());
const auto b_global_buf = make_dynamic_buffer<AddressSpaceEnum_t::Global>(
const auto b_global_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_b_grid, b_k_n0_n1_grid_desc.GetElementSpaceSize());
auto c_grid_buf = make_dynamic_buffer<AddressSpaceEnum_t::Global>(
auto c_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_c_grid, c_m0_m10_m11_n0_n10_n11_grid_desc.GetElementSpaceSize());
const auto K = a_k_m0_m1_grid_desc.GetLength(I0);
......@@ -315,7 +315,7 @@ struct GridwiseGemmDlops_km_kn_mn_v1r2
// A matrix blockwise copy
auto a_blockwise_copy =
BlockwiseTensorSliceTransfer_v4<BlockSize,
InMemoryDataOperationEnum_t::Set,
InMemoryDataOperationEnum::Set,
Sequence<KPerBlock, 1, MPerBlockM1>,
ABlockTransferThreadSliceLengths_K_M0_M1,
ABlockTransferThreadClusterLengths_K_M0_M1,
......@@ -341,7 +341,7 @@ struct GridwiseGemmDlops_km_kn_mn_v1r2
// B matrix blockwise copy
auto b_blockwise_copy =
BlockwiseTensorSliceTransfer_v4<BlockSize,
InMemoryDataOperationEnum_t::Set,
InMemoryDataOperationEnum::Set,
Sequence<KPerBlock, 1, NPerBlockN1>,
BBlockTransferThreadSliceLengths_K_N0_N1,
BBlockTransferThreadClusterLengths_K_N0_N1,
......@@ -403,7 +403,7 @@ struct GridwiseGemmDlops_km_kn_mn_v1r2
FloatAB* p_b_block_double = p_shared_block + 2 * a_block_aligned_space_size;
// register allocation for output
auto c_thread_buf = make_static_buffer<AddressSpaceEnum_t::Vgpr, FloatAcc>(
auto c_thread_buf = make_static_buffer<AddressSpaceEnum::Vgpr, FloatAcc>(
c_m10_m11_n10_n11_thread_desc.GetElementSpaceSize());
ThreadwiseTensorSliceSet_v1<FloatAcc,
......@@ -428,15 +428,15 @@ struct GridwiseGemmDlops_km_kn_mn_v1r2
constexpr auto b_k_n0_n1_global_move_slice_window_step_hack =
BGridMoveSliceWindowStepHacks{};
auto a_block_even_buf = make_dynamic_buffer<AddressSpaceEnum_t::Lds>(
auto a_block_even_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(
p_a_block_double, a_k_m0_m1_block_desc.GetElementSpaceSize());
auto b_block_even_buf = make_dynamic_buffer<AddressSpaceEnum_t::Lds>(
auto b_block_even_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(
p_b_block_double, b_k_n0_n1_block_desc.GetElementSpaceSize());
auto a_block_odd_buf = make_dynamic_buffer<AddressSpaceEnum_t::Lds>(
auto a_block_odd_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(
p_a_block_double + a_block_aligned_space_size,
a_k_m0_m1_block_desc.GetElementSpaceSize());
auto b_block_odd_buf = make_dynamic_buffer<AddressSpaceEnum_t::Lds>(
auto b_block_odd_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(
p_b_block_double + b_block_aligned_space_size,
b_k_n0_n1_block_desc.GetElementSpaceSize());
......
......@@ -55,7 +55,7 @@ template <index_t BlockSize,
typename FloatAB,
typename FloatAcc,
typename FloatC,
InMemoryDataOperationEnum_t CGlobalMemoryDataOperation,
InMemoryDataOperationEnum CGlobalMemoryDataOperation,
typename AK0MK1GridDesc,
typename BK0NK1GridDesc,
typename CMNGridDesc,
......@@ -275,11 +275,11 @@ struct GridwiseGemmDlops_km_kn_mn_v1r3
integral_constant<bool, HasMainKBlockLoop>,
integral_constant<bool, HasDoubleTailKBlockLoop>)
{
const auto a_global_buf = make_dynamic_buffer<AddressSpaceEnum_t::Global>(
const auto a_global_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_a_grid, a_k0_m0_m1_k1_grid_desc.GetElementSpaceSize());
const auto b_global_buf = make_dynamic_buffer<AddressSpaceEnum_t::Global>(
const auto b_global_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_b_grid, b_k0_n0_n1_k1_grid_desc.GetElementSpaceSize());
auto c_grid_buf = make_dynamic_buffer<AddressSpaceEnum_t::Global>(
auto c_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_c_grid, c_m0_m10_m11_n0_n10_n11_grid_desc.GetElementSpaceSize());
// divide block work by [M, N]
......@@ -325,7 +325,7 @@ struct GridwiseGemmDlops_km_kn_mn_v1r3
// A matrix blockwise copy
auto a_blockwise_copy = BlockwiseTensorSliceTransfer_v5r1<
BlockSize,
InMemoryDataOperationEnum_t::Set,
InMemoryDataOperationEnum::Set,
Sequence<KPerBlock, 1, MPerBlockM1, K1.value>,
ABlockTransferThreadSliceLengths_K0_M0_M1_K1,
ABlockTransferThreadClusterLengths_K0_M0_M1_K1,
......@@ -349,7 +349,7 @@ struct GridwiseGemmDlops_km_kn_mn_v1r3
// B matrix blockwise copy
auto b_blockwise_copy = BlockwiseTensorSliceTransfer_v5r1<
BlockSize,
InMemoryDataOperationEnum_t::Set,
InMemoryDataOperationEnum::Set,
Sequence<KPerBlock, 1, NPerBlockN1, K1.value>,
BBlockTransferThreadSliceLengths_K0_N0_N1_K1,
BBlockTransferThreadClusterLengths_K0_N0_N1_K1,
......@@ -409,7 +409,7 @@ struct GridwiseGemmDlops_km_kn_mn_v1r3
FloatAB* p_b_block_double = p_shared_block + 2 * a_block_aligned_space_size;
// register allocation for output
auto c_thread_buf = make_static_buffer<AddressSpaceEnum_t::Vgpr, FloatAcc>(
auto c_thread_buf = make_static_buffer<AddressSpaceEnum::Vgpr, FloatAcc>(
c_m10_m11_n10_n11_thread_desc.GetElementSpaceSize());
ThreadwiseTensorSliceSet_v1<FloatAcc,
......@@ -423,15 +423,15 @@ struct GridwiseGemmDlops_km_kn_mn_v1r3
constexpr auto a_block_slice_copy_step = make_multi_index(KPerBlock, 0, 0, 0);
constexpr auto b_block_slice_copy_step = make_multi_index(KPerBlock, 0, 0, 0);
auto a_block_even_buf = make_dynamic_buffer<AddressSpaceEnum_t::Lds>(
auto a_block_even_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(
p_a_block_double, a_k0_m0_m1_k1_block_desc.GetElementSpaceSize());
auto b_block_even_buf = make_dynamic_buffer<AddressSpaceEnum_t::Lds>(
auto b_block_even_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(
p_b_block_double, b_k0_n0_n1_k1_block_desc.GetElementSpaceSize());
auto a_block_odd_buf = make_dynamic_buffer<AddressSpaceEnum_t::Lds>(
auto a_block_odd_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(
p_a_block_double + a_block_aligned_space_size,
a_k0_m0_m1_k1_block_desc.GetElementSpaceSize());
auto b_block_odd_buf = make_dynamic_buffer<AddressSpaceEnum_t::Lds>(
auto b_block_odd_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(
p_b_block_double + b_block_aligned_space_size,
b_k0_n0_n1_k1_block_desc.GetElementSpaceSize());
......
......@@ -15,7 +15,7 @@ template <index_t BlockSize,
typename FloatAB,
typename FloatAcc,
typename FloatC,
InMemoryDataOperationEnum_t CGlobalMemoryDataOperation,
InMemoryDataOperationEnum CGlobalMemoryDataOperation,
typename AGlobalDesc,
typename BGlobalDesc,
typename CGlobalDesc,
......@@ -84,11 +84,11 @@ struct GridwiseGemmDlops_km_kn_mn_v3
constexpr auto I2 = Number<2>{};
constexpr auto I3 = Number<3>{};
const auto a_global_buf = make_dynamic_buffer<AddressSpaceEnum_t::Global>(
const auto a_global_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_a_global, a_e_k_global_desc.GetElementSpaceSize());
const auto b_global_buf = make_dynamic_buffer<AddressSpaceEnum_t::Global>(
const auto b_global_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_b_global, b_e_n_ho_wo_global_desc.GetElementSpaceSize());
auto c_global_buf = make_dynamic_buffer<AddressSpaceEnum_t::Global>(
auto c_global_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_c_global, c_k_n_ho_wo_global_desc.GetElementSpaceSize());
constexpr auto E = EPerBlock * 3 * 3;
......@@ -181,7 +181,7 @@ struct GridwiseGemmDlops_km_kn_mn_v3
// A matrix blockwise copy
auto a_blockwise_copy =
BlockwiseTensorSliceTransfer_v4<BlockSize,
InMemoryDataOperationEnum_t::Set,
InMemoryDataOperationEnum::Set,
Sequence<E, KPerBlock>,
ABlockTransferThreadSliceLengths_E_K,
ABlockTransferThreadClusterLengths_E_K,
......@@ -221,11 +221,11 @@ struct GridwiseGemmDlops_km_kn_mn_v3
b_e_n_ho_wo_global_desc,
make_multi_index(0, 0, ho_thread_data_on_global, wo_thread_data_on_global));
auto a_block_buf = make_dynamic_buffer<AddressSpaceEnum_t::Lds>(
auto a_block_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(
p_shared_block, a_e_k_desc.GetElementSpaceSize());
// register allocation for output
StaticBuffer<AddressSpaceEnum_t::Vgpr,
StaticBuffer<AddressSpaceEnum::Vgpr,
FloatAcc,
c_k_n_ho_wo_thread_desc.GetElementSpaceSize(),
true>
......@@ -250,7 +250,7 @@ struct GridwiseGemmDlops_km_kn_mn_v3
BGlobalMoveSliceWindowStepHacks{};
// double register buffer for b
StaticBuffer<AddressSpaceEnum_t::Vgpr,
StaticBuffer<AddressSpaceEnum::Vgpr,
FloatAB,
b_e_n_ho_wo_thread_desc.GetElementSpaceSize(),
true>
......
......@@ -20,7 +20,7 @@ template <typename GridwiseGemm,
typename CGridDesc_K0_K1_N_H0_H1_H2_W0_W1_W2,
typename CBlockIdToBlockClusterAdaptor_K_N_H_W,
bool HasMainE0BlockLoop,
ActivTypeEnum_t ActivType>
ActivTypeEnum ActivType>
__global__ void
#if CK_USE_LAUNCH_BOUNDS
__launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU)
......@@ -50,7 +50,7 @@ __global__ void
c_k0_k1_n_h0_h1_h2_w0_w1_w2_grid_desc,
cblockid_to_k_n_h_w_block_cluster_adaptor,
integral_constant<bool, HasMainE0BlockLoop>{},
integral_constant<ActivTypeEnum_t, ActivType>{});
integral_constant<ActivTypeEnum, ActivType>{});
}
template <typename GridwiseGemm,
......@@ -62,7 +62,7 @@ template <typename GridwiseGemm,
typename DGridDesc_K0_K1_N_H0_H1_Hx_W0_W1_Wx,
typename CBlockIdToBlockClusterAdaptor_K_N_H_W,
bool HasMainE0BlockLoop,
ActivTypeEnum_t ActivType>
ActivTypeEnum ActivType>
__global__ void
#if CK_USE_LAUNCH_BOUNDS
__launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU)
......@@ -94,7 +94,7 @@ __global__ void
d_k0_k1_n_h0_h1_hx_w0_w1_wx_grid_desc,
cblockid_to_k_n_h_w_block_cluster_adaptor,
integral_constant<bool, HasMainE0BlockLoop>{},
integral_constant<ActivTypeEnum_t, ActivType>{});
integral_constant<ActivTypeEnum, ActivType>{});
}
template <typename GridwiseGemm,
......@@ -106,7 +106,7 @@ template <typename GridwiseGemm,
typename DGridDesc_K0_K1_N_H0_H1_Hx_W0_W1_Wx,
typename CBlockIdToBlockClusterAdaptor_K_N_H_W,
bool HasMainE0BlockLoop,
ActivTypeEnum_t ActivType>
ActivTypeEnum ActivType>
__global__ void
#if CK_USE_LAUNCH_BOUNDS
__launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU)
......@@ -140,14 +140,14 @@ __global__ void
d_k0_k1_n_h0_h1_hx_w0_w1_wx_grid_desc,
cblockid_to_k_n_h_w_block_cluster_adaptor,
integral_constant<bool, HasMainE0BlockLoop>{},
integral_constant<ActivTypeEnum_t, ActivType>{});
integral_constant<ActivTypeEnum, ActivType>{});
}
template <index_t BlockSize,
typename FloatAB,
typename FloatAcc,
typename FloatC,
InMemoryDataOperationEnum_t CGlobalMemoryDataOperation,
InMemoryDataOperationEnum CGlobalMemoryDataOperation,
typename AGridDesc_E0_E1_K_E2,
typename BGridDesc_E0_E1_N_Ho_Wo_E2,
typename CGridDesc_K_N_Ho_Wo,
......@@ -559,7 +559,7 @@ struct GridwiseGemmDlops_km_kn_mn_v3
constexpr auto bias_k0_k1_thread_desc =
make_naive_tensor_descriptor_packed(make_tuple(I1, Number<KPerThread>{}));
StaticBuffer<AddressSpaceEnum_t::Vgpr,
StaticBuffer<AddressSpaceEnum::Vgpr,
FloatC,
bias_k0_k1_thread_desc.GetElementSpaceSize(),
true>
......@@ -602,10 +602,10 @@ struct GridwiseGemmDlops_km_kn_mn_v3
});
}
template <typename CThreadBuff, typename CThreadDesc_K1_N_H2_W2, ActivTypeEnum_t activ_type_>
template <typename CThreadBuff, typename CThreadDesc_K1_N_H2_W2, ActivTypeEnum activ_type_>
__device__ static void Activation(CThreadBuff& c_thread_buf,
const CThreadDesc_K1_N_H2_W2&,
integral_constant<ActivTypeEnum_t, activ_type_>)
integral_constant<ActivTypeEnum, activ_type_>)
{
constexpr auto c_k1_n_h2_w2_thread_gemm_desc = CThreadDesc_K1_N_H2_W2{};
......@@ -737,7 +737,7 @@ struct GridwiseGemmDlops_km_kn_mn_v3
I1,
Number<WoPerThread_2>{}));
StaticBuffer<AddressSpaceEnum_t::Vgpr,
StaticBuffer<AddressSpaceEnum::Vgpr,
FloatC,
d_k0_k1_n_h0_h1_hx_w0_w1_wx_thread_desc.GetElementSpaceSize(),
true>
......@@ -783,7 +783,7 @@ struct GridwiseGemmDlops_km_kn_mn_v3
CThreadTransferSrcDstAccessOrder,
CThreadTransferSrcDstVectorDim,
CThreadTransferDstScalarPerVector,
InMemoryDataOperationEnum_t::Set,
InMemoryDataOperationEnum::Set,
1,
true>(d_k0_k1_n_h0_h1_hx_w0_w1_wx_grid_desc,
make_multi_index(k_block_work_id,
......@@ -843,7 +843,7 @@ struct GridwiseGemmDlops_km_kn_mn_v3
I1,
Number<WoPerThreadx2>{}));
StaticBuffer<AddressSpaceEnum_t::Vgpr,
StaticBuffer<AddressSpaceEnum::Vgpr,
FloatC,
d_k0_k1_n_h0_h1_hx_w0_w1_wx_thread_desc.GetElementSpaceSize(),
true>
......@@ -874,7 +874,7 @@ struct GridwiseGemmDlops_km_kn_mn_v3
CThreadTransferSrcDstAccessOrder,
CThreadTransferSrcDstVectorDim,
CThreadTransferDstScalarPerVector,
InMemoryDataOperationEnum_t::Add,
InMemoryDataOperationEnum::Add,
1,
true>(d_k0_k1_n_h0_h1_hx_w0_w1_wx_grid_desc,
make_multi_index(k_block_work_id,
......@@ -964,7 +964,7 @@ struct GridwiseGemmDlops_km_kn_mn_v3
// A matrix blockwise copy
auto a_blockwise_copy =
BlockwiseTensorSliceTransfer_v4<BlockSize,
InMemoryDataOperationEnum_t::Set,
InMemoryDataOperationEnum::Set,
Sequence<I1, E1, I1, KPerBlock, E2>,
ABlockTransferThreadSliceLengths_E0_E1_K0_K1_E2,
ABlockTransferThreadClusterLengths_E0_E1_K0_K1_E2,
......@@ -1023,11 +1023,11 @@ struct GridwiseGemmDlops_km_kn_mn_v3
0,
0));
auto a_block_buf = make_dynamic_buffer<AddressSpaceEnum_t::Lds>(
auto a_block_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(
p_shared_block, a_e0_e1_k0_k1_e2_block_copy_desc.GetElementSpaceSize());
//// register allocation for output
// StaticBuffer<AddressSpaceEnum_t::Vgpr,
// StaticBuffer<AddressSpaceEnum::Vgpr,
// FloatAcc,
// c_k1_n_h2_w2_thread_gemm_desc.GetElementSpaceSize(),
// true>
......@@ -1050,7 +1050,7 @@ struct GridwiseGemmDlops_km_kn_mn_v3
constexpr auto b_e0_e1_n_h0_h1_h2_w0_w1_w2_e2_global_step_hacks = BGlobalStepHacks{};
// double register buffer for b
StaticBuffer<AddressSpaceEnum_t::Vgpr,
StaticBuffer<AddressSpaceEnum::Vgpr,
FloatAB,
b_e0_e1_n_h0_h1_h2_w0_w1_w2_e2_thread_copy_desc.GetElementSpaceSize(),
true>
......@@ -1294,21 +1294,21 @@ struct GridwiseGemmDlops_km_kn_mn_v3
const auto bias_k0_k1_grid_desc =
MakeBiasK0K1GridDescriptor(c_k0_k1_n_h0_h1_h2_w0_w1_w2_grid_desc);
const auto a_global_buf = make_dynamic_buffer<AddressSpaceEnum_t::Global>(
const auto a_global_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_a_global, a_e0_e1_k0_k1_e2_grid_desc.GetElementSpaceSize());
const auto b_global_buf = make_dynamic_buffer<AddressSpaceEnum_t::Global>(
const auto b_global_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_b_global, b_e0_e1_n_h0_h1_h2_w0_w1_w2_e2_grid_desc.GetElementSpaceSize());
auto c_global_buf = make_dynamic_buffer<AddressSpaceEnum_t::Global>(
auto c_global_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_c_global, c_k0_k1_n_h0_h1_h2_w0_w1_w2_grid_desc.GetElementSpaceSize());
auto d_global_buf = make_dynamic_buffer<AddressSpaceEnum_t::Global>(
auto d_global_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_d_global, d_k0_k1_n_h0_h1_hx_w0_w1_wx_grid_desc.GetElementSpaceSize());
auto bias_global_buf = make_dynamic_buffer<AddressSpaceEnum_t::Global>(
auto bias_global_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_bias_global, bias_k0_k1_grid_desc.GetElementSpaceSize());
constexpr auto c_k1_n_h2_w2_thread_gemm_desc = MakeCK1NH2W2ThreadDescriptor();
// register allocation for output
StaticBuffer<AddressSpaceEnum_t::Vgpr,
StaticBuffer<AddressSpaceEnum::Vgpr,
FloatAcc,
c_k1_n_h2_w2_thread_gemm_desc.GetElementSpaceSize(),
true>
......@@ -1344,7 +1344,7 @@ struct GridwiseGemmDlops_km_kn_mn_v3
typename CGridDesc_K0_K1_N_H0_H1_H2_W0_W1_W2,
typename CBlockIdToBlockClusterAdaptor_K_N_H_W,
bool HasMainE0BlockLoop,
ActivTypeEnum_t ActivType>
ActivTypeEnum ActivType>
__device__ static void ConvBiasActiv(
const FloatAB* __restrict__ p_a_global,
const FloatAB* __restrict__ p_b_global,
......@@ -1356,26 +1356,26 @@ struct GridwiseGemmDlops_km_kn_mn_v3
const CGridDesc_K0_K1_N_H0_H1_H2_W0_W1_W2& c_k0_k1_n_h0_h1_h2_w0_w1_w2_grid_desc,
const CBlockIdToBlockClusterAdaptor_K_N_H_W& cblockid_to_k_n_h_w_block_cluster_adaptor,
integral_constant<bool, HasMainE0BlockLoop>,
integral_constant<ActivTypeEnum_t, ActivType>)
integral_constant<ActivTypeEnum, ActivType>)
{
static constexpr auto activ_type = integral_constant<ActivTypeEnum_t, ActivType>{};
static constexpr auto activ_type = integral_constant<ActivTypeEnum, ActivType>{};
const auto bias_k0_k1_grid_desc =
MakeBiasK0K1GridDescriptor(c_k0_k1_n_h0_h1_h2_w0_w1_w2_grid_desc);
const auto a_global_buf = make_dynamic_buffer<AddressSpaceEnum_t::Global>(
const auto a_global_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_a_global, a_e0_e1_k0_k1_e2_grid_desc.GetElementSpaceSize());
const auto b_global_buf = make_dynamic_buffer<AddressSpaceEnum_t::Global>(
const auto b_global_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_b_global, b_e0_e1_n_h0_h1_h2_w0_w1_w2_e2_grid_desc.GetElementSpaceSize());
auto c_global_buf = make_dynamic_buffer<AddressSpaceEnum_t::Global>(
auto c_global_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_c_global, c_k0_k1_n_h0_h1_h2_w0_w1_w2_grid_desc.GetElementSpaceSize());
auto bias_global_buf = make_dynamic_buffer<AddressSpaceEnum_t::Global>(
auto bias_global_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_bias_global, bias_k0_k1_grid_desc.GetElementSpaceSize());
constexpr auto c_k1_n_h2_w2_thread_gemm_desc = MakeCK1NH2W2ThreadDescriptor();
// register allocation for output
StaticBuffer<AddressSpaceEnum_t::Vgpr,
StaticBuffer<AddressSpaceEnum::Vgpr,
FloatAcc,
c_k1_n_h2_w2_thread_gemm_desc.GetElementSpaceSize(),
true>
......@@ -1423,7 +1423,7 @@ struct GridwiseGemmDlops_km_kn_mn_v3
typename DGridDesc_K0_K1_N_H0_H1_Hx_W0_W1_Wx,
typename CBlockIdToBlockClusterAdaptor_K_N_H_W,
bool HasMainE0BlockLoop,
ActivTypeEnum_t ActivType>
ActivTypeEnum ActivType>
__device__ static void ConvBiasActivMaxpool(
const FloatAB* __restrict__ p_a_global,
const FloatAB* __restrict__ p_b_global,
......@@ -1437,28 +1437,28 @@ struct GridwiseGemmDlops_km_kn_mn_v3
const DGridDesc_K0_K1_N_H0_H1_Hx_W0_W1_Wx& d_k0_k1_n_h0_h1_hx_w0_w1_wx_grid_desc,
const CBlockIdToBlockClusterAdaptor_K_N_H_W& cblockid_to_k_n_h_w_block_cluster_adaptor,
integral_constant<bool, HasMainE0BlockLoop>,
integral_constant<ActivTypeEnum_t, ActivType>)
integral_constant<ActivTypeEnum, ActivType>)
{
static constexpr auto activ_type = integral_constant<ActivTypeEnum_t, ActivType>{};
static constexpr auto activ_type = integral_constant<ActivTypeEnum, ActivType>{};
const auto bias_k0_k1_grid_desc =
MakeBiasK0K1GridDescriptor(c_k0_k1_n_h0_h1_h2_w0_w1_w2_grid_desc);
const auto a_global_buf = make_dynamic_buffer<AddressSpaceEnum_t::Global>(
const auto a_global_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_a_global, a_e0_e1_k0_k1_e2_grid_desc.GetElementSpaceSize());
const auto b_global_buf = make_dynamic_buffer<AddressSpaceEnum_t::Global>(
const auto b_global_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_b_global, b_e0_e1_n_h0_h1_h2_w0_w1_w2_e2_grid_desc.GetElementSpaceSize());
auto c_global_buf = make_dynamic_buffer<AddressSpaceEnum_t::Global>(
auto c_global_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_c_global, c_k0_k1_n_h0_h1_h2_w0_w1_w2_grid_desc.GetElementSpaceSize());
auto d_global_buf = make_dynamic_buffer<AddressSpaceEnum_t::Global>(
auto d_global_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_d_global, d_k0_k1_n_h0_h1_hx_w0_w1_wx_grid_desc.GetElementSpaceSize());
auto bias_global_buf = make_dynamic_buffer<AddressSpaceEnum_t::Global>(
auto bias_global_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_bias_global, bias_k0_k1_grid_desc.GetElementSpaceSize());
constexpr auto c_k1_n_h2_w2_thread_gemm_desc = MakeCK1NH2W2ThreadDescriptor();
// register allocation for output
StaticBuffer<AddressSpaceEnum_t::Vgpr,
StaticBuffer<AddressSpaceEnum::Vgpr,
FloatAcc,
c_k1_n_h2_w2_thread_gemm_desc.GetElementSpaceSize(),
true>
......@@ -1514,7 +1514,7 @@ struct GridwiseGemmDlops_km_kn_mn_v3
typename DGridDesc_K0_K1_N_H0_H1_Hx_W0_W1_Wx,
typename CBlockIdToBlockClusterAdaptor_K_N_H_W,
bool HasMainE0BlockLoop,
ActivTypeEnum_t ActivType>
ActivTypeEnum ActivType>
__device__ static void ConvBiasActivResizeAdd(
const FloatAB* __restrict__ p_a_global,
const FloatAB* __restrict__ p_b_global,
......@@ -1527,26 +1527,26 @@ struct GridwiseGemmDlops_km_kn_mn_v3
const DGridDesc_K0_K1_N_H0_H1_Hx_W0_W1_Wx& d_k0_k1_n_h0_h1_hx_w0_w1_wx_grid_desc,
const CBlockIdToBlockClusterAdaptor_K_N_H_W& cblockid_to_k_n_h_w_block_cluster_adaptor,
integral_constant<bool, HasMainE0BlockLoop>,
integral_constant<ActivTypeEnum_t, ActivType>)
integral_constant<ActivTypeEnum, ActivType>)
{
static constexpr auto activ_type = integral_constant<ActivTypeEnum_t, ActivType>{};
static constexpr auto activ_type = integral_constant<ActivTypeEnum, ActivType>{};
const auto bias_k0_k1_grid_desc =
MakeBiasK0K1GridDescriptor(c_k0_k1_n_h0_h1_h2_w0_w1_w2_grid_desc);
const auto a_global_buf = make_dynamic_buffer<AddressSpaceEnum_t::Global>(
const auto a_global_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_a_global, a_e0_e1_k0_k1_e2_grid_desc.GetElementSpaceSize());
const auto b_global_buf = make_dynamic_buffer<AddressSpaceEnum_t::Global>(
const auto b_global_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_b_global, b_e0_e1_n_h0_h1_h2_w0_w1_w2_e2_grid_desc.GetElementSpaceSize());
auto d_global_buf = make_dynamic_buffer<AddressSpaceEnum_t::Global>(
auto d_global_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_d_global, d_k0_k1_n_h0_h1_hx_w0_w1_wx_grid_desc.GetElementSpaceSize());
auto bias_global_buf = make_dynamic_buffer<AddressSpaceEnum_t::Global>(
auto bias_global_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_bias_global, bias_k0_k1_grid_desc.GetElementSpaceSize());
constexpr auto c_k1_n_h2_w2_thread_gemm_desc = MakeCK1NH2W2ThreadDescriptor();
// register allocation for output
StaticBuffer<AddressSpaceEnum_t::Vgpr,
StaticBuffer<AddressSpaceEnum::Vgpr,
FloatAcc,
c_k1_n_h2_w2_thread_gemm_desc.GetElementSpaceSize(),
true>
......
#pragma once
#include "common_header.hpp"
#include "multi_index_transform_helper.hpp"
#include "tensor_descriptor.hpp"
#include "tensor_descriptor_helper.hpp"
#include "blockwise_gemm_xdlops.hpp"
#include "blockwise_tensor_slice_transfer_v4r1.hpp"
#include "blockwise_tensor_slice_transfer_v6r1.hpp"
#include "threadwise_tensor_slice_transfer.hpp"
#include "gridwise_gemm_pipeline_v1.hpp"
namespace ck {
template <typename GridwiseGemm,
typename FloatAB,
typename FloatC,
typename FloatD,
typename AElementwiseOperation,
typename BElementwiseOperation,
typename CElementwiseOperation,
typename D0ReduceOperation,
typename D1ReduceOperation,
typename AGridDesc_AK0_M_AK1,
typename BGridDesc_BK0_N_BK1,
typename CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock,
typename DGridDescriptor_MBlock_MPerBlock,
typename Block2CTileMap,
bool HasMainK0BlockLoop>
__global__ void
#if CK_USE_LAUNCH_BOUNDS
__launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU)
#endif
kernel_gemm_reduce_xdl_cshuffle_v1(
const FloatAB* __restrict__ p_a_grid,
const FloatAB* __restrict__ p_b_grid,
FloatC* __restrict__ p_c_grid,
FloatD* __restrict__ p_d0_grid,
FloatD* __restrict__ p_d1_grid,
const AElementwiseOperation a_element_op,
const BElementwiseOperation b_element_op,
const CElementwiseOperation c_element_op,
const D0ReduceOperation d0_reduce_op,
const D1ReduceOperation d1_reduce_op,
const AGridDesc_AK0_M_AK1 a_grid_desc_ak0_m_ak1,
const BGridDesc_BK0_N_BK1 b_grid_desc_bk0_n_bk1,
const CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
c_grid_desc_mblock_mperblock_nblock_nperblock,
const DGridDescriptor_MBlock_MPerBlock d_grid_desc_mblock_mperblock,
const Block2CTileMap block_2_ctile_map)
{
__shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()];
GridwiseGemm::template Run<HasMainK0BlockLoop>(p_a_grid,
p_b_grid,
p_c_grid,
p_d0_grid,
p_d1_grid,
p_shared,
a_element_op,
b_element_op,
c_element_op,
d0_reduce_op,
d1_reduce_op,
a_grid_desc_ak0_m_ak1,
b_grid_desc_bk0_n_bk1,
c_grid_desc_mblock_mperblock_nblock_nperblock,
d_grid_desc_mblock_mperblock,
block_2_ctile_map);
}
template <typename FloatAB,
typename FloatGemmAcc,
typename FloatCShuffle,
typename FloatC,
typename FloatReduceAcc,
typename FloatD,
typename AElementwiseOperation,
typename BElementwiseOperation,
typename CElementwiseOperation,
typename D0ReduceOperation,
typename D1ReduceOperation,
InMemoryDataOperationEnum CGlobalMemoryDataOperation,
InMemoryDataOperationEnum DGlobalMemoryDataOperation,
typename AGridDesc_AK0_M_AK1,
typename BGridDesc_BK0_N_BK1,
typename CGridDesc_M_N,
typename DGridDesc_M,
index_t NumGemmKPrefetchStage,
index_t BlockSize,
index_t MPerBlock,
index_t NPerBlock,
index_t KPerBlock,
index_t AK1Value,
index_t BK1Value,
index_t MPerXdl,
index_t NPerXdl,
index_t MXdlPerWave,
index_t NXdlPerWave,
typename ABlockTransferThreadClusterLengths_AK0_M_AK1,
typename ABlockTransferThreadClusterArrangeOrder,
typename ABlockTransferSrcAccessOrder,
index_t ABlockTransferSrcVectorDim,
index_t ABlockTransferSrcScalarPerVector,
index_t ABlockTransferDstScalarPerVector_AK1,
bool AThreadTransferSrcResetCoordinateAfterRun,
index_t ABlockLdsExtraM,
typename BBlockTransferThreadClusterLengths_BK0_N_BK1,
typename BBlockTransferThreadClusterArrangeOrder,
typename BBlockTransferSrcAccessOrder,
index_t BBlockTransferSrcVectorDim,
index_t BBlockTransferSrcScalarPerVector,
index_t BBlockTransferDstScalarPerVector_BK1,
bool BThreadTransferSrcResetCoordinateAfterRun,
index_t BBlockLdsExtraN,
index_t CShuffleMXdlPerWavePerShuffle,
index_t CShuffleNXdlPerWavePerShuffle,
typename CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
index_t CShuffleBlockTransferScalarPerVector_NPerBlock,
typename CReduceThreadClusterLengths_MPerBlock_NPerBlock,
index_t CReduceThreadLds2VGprCopySrcDstScalarPerVector_NPerBlock,
index_t CReduceThreadVgpr2GlobalCopySrcDstScalarPerVector_MPerBlock>
struct GridwiseGemmReduce_k0mk1_k0nk1_mn_xdl_cshuffle_v1
{
static constexpr auto I0 = Number<0>{};
static constexpr auto I1 = Number<1>{};
static constexpr auto I2 = Number<2>{};
static constexpr auto I3 = Number<3>{};
static constexpr auto I4 = Number<4>{};
static constexpr auto I5 = Number<5>{};
static constexpr auto I6 = Number<6>{};
static constexpr auto I7 = Number<7>{};
// K1 should be Number<...>
static constexpr auto AK0 = Number<KPerBlock / AK1Value>{};
static constexpr auto BK0 = Number<KPerBlock / BK1Value>{};
static constexpr auto AK1 = Number<AK1Value>{};
static constexpr auto BK1 = Number<BK1Value>{};
__host__ __device__ static constexpr auto GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1()
{
// A matrix in LDS memory, dst of blockwise copy
return make_naive_tensor_descriptor(
make_tuple(AK0, Number<MPerBlock>{}, AK1),
make_tuple(Number<MPerBlock + ABlockLdsExtraM>{} * AK1, AK1, I1));
}
__host__ __device__ static constexpr auto GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1()
{
// B matrix in LDS memory, dst of blockwise copy
return make_naive_tensor_descriptor(
make_tuple(BK0, Number<NPerBlock>{}, BK1),
make_tuple(Number<NPerBlock + BBlockLdsExtraN>{} * BK1, BK1, I1));
}
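// The "+ ABlockLdsExtraM" / "+ BBlockLdsExtraN" row padding skews successive
// rows across LDS banks; with the typical value of 1 this is the usual
// bank-conflict-avoidance trick for the strided XDL read pattern, traded
// against a slightly larger LDS footprint.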
__host__ __device__ static constexpr auto
GetCShuffleBlockDescriptor_MBlock_MPerBlock_NBlock_NPerBlock()
{
constexpr index_t MWave = MPerBlock / (MXdlPerWave * MPerXdl);
constexpr index_t NWave = NPerBlock / (NXdlPerWave * NPerXdl);
constexpr auto c_shuffle_block_desc_mblock_mperblock_nblock_nperblock =
make_naive_tensor_descriptor_packed(
make_tuple(I1,
Number<CShuffleMXdlPerWavePerShuffle * MWave * MPerXdl>{},
I1,
Number<CShuffleNXdlPerWavePerShuffle * NWave * NPerXdl>{}));
return c_shuffle_block_desc_mblock_mperblock_nblock_nperblock;
}
__host__ __device__ static constexpr index_t GetSharedMemoryNumberOfByte()
{
// LDS allocation for A and B: be careful of alignment
constexpr auto a_block_desc_ak0_m_ak1 = GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1();
constexpr auto b_block_desc_bk0_n_bk1 = GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1();
// lds max alignment
constexpr auto max_lds_align = math::lcm(AK1, BK1);
constexpr auto a_block_space_size_aligned = math::integer_least_multiple(
a_block_desc_ak0_m_ak1.GetElementSpaceSize(), max_lds_align);
constexpr auto b_block_space_size_aligned = math::integer_least_multiple(
b_block_desc_bk0_n_bk1.GetElementSpaceSize(), max_lds_align);
// LDS allocation for C shuffle in LDS
constexpr auto c_shuffle_block_desc_mblock_mperblock_nblock_nperblock =
GetCShuffleBlockDescriptor_MBlock_MPerBlock_NBlock_NPerBlock();
constexpr auto c_block_size =
c_shuffle_block_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize();
return math::max((a_block_space_size_aligned + b_block_space_size_aligned) *
sizeof(FloatAB),
c_block_size * sizeof(FloatCShuffle));
}
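// The A/B staging buffers (GEMM phase) and the C shuffle tile (epilogue
// phase) are never live in LDS at the same time, so the allocation only
// needs the max of the two sizes rather than their sum; Run() below reuses
// p_shared for both phases.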
// block_id to matrix tile idx (m0, n0) mapping is controlled by {M01, N01}
__host__ __device__ static constexpr bool
CheckValidity(const AGridDesc_AK0_M_AK1& a_grid_desc_ak0_m_ak1,
const BGridDesc_BK0_N_BK1& b_grid_desc_bk0_n_bk1,
const CGridDesc_M_N& c_grid_desc_m_n)
{
// static_assert(is_known_at_compile_time<remove_cv_t<decltype(AK1)>>::value &&
// is_known_at_compile_time<remove_cv_t<decltype(BK1)>>::value,
// "wrong! K1 needs to be known at compile-time");
static_assert((MPerBlock % (MPerXdl * MXdlPerWave) == 0) &&
(NPerBlock % (NXdlPerWave * NPerXdl) == 0),
"Invalid tuning param!");
const auto M = a_grid_desc_ak0_m_ak1.GetLength(I1);
const auto N = b_grid_desc_bk0_n_bk1.GetLength(I1);
const auto K = a_grid_desc_ak0_m_ak1.GetLength(I0) * a_grid_desc_ak0_m_ak1.GetLength(I2);
if(!(M == c_grid_desc_m_n.GetLength(I0) && N == c_grid_desc_m_n.GetLength(I1)))
return false;
if(!(M % MPerBlock == 0 && N % NPerBlock == 0 && K % KPerBlock == 0))
return false;
// check NumGemmKPrefetchStage
if constexpr(NumGemmKPrefetchStage == 1)
{
// 1-stage prefetch always supported
}
else if constexpr(NumGemmKPrefetchStage == 2)
{
// 2-stage prefetch currently only supports an even number of K0 loop iterations
// TODO: add support for an odd number of K0 loop iterations
if(!((K / KPerBlock) % 2 == 0))
{
return false;
}
}
else
{
return false;
}
// TODO: also check validity of all components (blockwise-copy, threadwise-copy, etc)
return true;
}
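// Prefetch-constraint example (illustrative values): K = 1024 with
// KPerBlock = 64 gives 16 K-tiles, so NumGemmKPrefetchStage == 2 is
// accepted; K = 960 gives 15 tiles and the 2-stage pipeline is rejected
// until odd tile counts are supported.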
__host__ __device__ static constexpr index_t
CalculateGridSize(const CGridDesc_M_N& c_grid_desc_m_n)
{
const auto M = c_grid_desc_m_n.GetLength(I0);
const auto N = c_grid_desc_m_n.GetLength(I1);
const index_t grid_size = (M / MPerBlock) * (N / NPerBlock);
return grid_size;
}
// TODO: move this function into the GEMM-pipeline class
__host__ __device__ static constexpr bool CalculateHasMainK0BlockLoop(index_t K0)
{
const bool has_main_k0_block_loop = ((K0 * AK1) / (NumGemmKPrefetchStage * KPerBlock)) > 1;
return has_main_k0_block_loop;
}
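// K0 here is the grid-level AK0, so K0 * AK1 recovers the full K extent;
// a main loop exists only when there is more than one
// (NumGemmKPrefetchStage * KPerBlock)-sized chunk of K to iterate over.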
__host__ __device__ static constexpr auto
MakeCGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(const CGridDesc_M_N& c_grid_desc_m_n)
{
const auto M = c_grid_desc_m_n.GetLength(I0);
const auto N = c_grid_desc_m_n.GetLength(I1);
const auto MBlock = M / MPerBlock;
const auto NBlock = N / NPerBlock;
const auto c_grid_desc_mblock_mperblock_nblock_nperblock = transform_tensor_descriptor(
c_grid_desc_m_n,
make_tuple(make_unmerge_transform(make_tuple(MBlock, Number<MPerBlock>{})),
make_unmerge_transform(make_tuple(NBlock, Number<NPerBlock>{}))),
make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<0, 1>{}, Sequence<2, 3>{}));
return c_grid_desc_mblock_mperblock_nblock_nperblock;
}
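// Unmerge example (illustrative values): M = 512, N = 256 with
// MPerBlock = 128 and NPerBlock = 64 reshapes the [512, 256] C matrix into
// a [4, 128, 4, 64] view, letting a workgroup address its tile as
// (mblock, 0..127, nblock, 0..63) with no index arithmetic of its own.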
__host__ __device__ static constexpr auto
MakeDGridDescriptor_MBlock_MPerBlock(const DGridDesc_M& d_grid_desc_m)
{
const auto M = d_grid_desc_m.GetLength(I0);
const auto MBlock = M / MPerBlock;
const auto d_grid_desc_mblock_mperblock = transform_tensor_descriptor(
d_grid_desc_m,
make_tuple(make_unmerge_transform(make_tuple(MBlock, Number<MPerBlock>{}))),
make_tuple(Sequence<0>{}),
make_tuple(Sequence<0, 1>{}));
return d_grid_desc_mblock_mperblock;
}
// return block_id to C matrix tile idx (m0, n0) mapping
__host__ __device__ static constexpr auto
MakeDefaultBlock2CTileMap(const CGridDesc_M_N& c_grid_desc_m_n)
{
const auto M = c_grid_desc_m_n.GetLength(I0);
const auto N = c_grid_desc_m_n.GetLength(I1);
constexpr auto M1 = Number<MPerBlock>{};
constexpr auto N1 = Number<NPerBlock>{};
const auto M0 = M / M1;
const auto N0 = N / N1;
// FIXME: remove
constexpr auto M01 = I1;
constexpr auto N01 = I1;
const auto M00 = M0 / M01;
const auto N00 = N0 / N01;
const auto m00_m01_n00_n01_to_m0_n0_block_cluster_adaptor =
make_single_stage_tensor_adaptor(
make_tuple(make_unmerge_transform(make_tuple(M00, M01)),
make_unmerge_transform(make_tuple(N00, N01))),
make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<0, 2>{}, Sequence<1, 3>{}));
const auto cblockid_to_m00_m01_n00_n01_block_cluster_adaptor =
make_single_stage_tensor_adaptor(
make_tuple(make_merge_transform(make_tuple(M00, N00, M01, N01))),
make_tuple(Sequence<0, 1, 2, 3>{}),
make_tuple(Sequence<0>{}));
const auto cblockid_to_m0_n0_block_cluster_adaptor =
chain_tensor_adaptors(m00_m01_n00_n01_to_m0_n0_block_cluster_adaptor,
cblockid_to_m00_m01_n00_n01_block_cluster_adaptor);
return cblockid_to_m0_n0_block_cluster_adaptor;
}
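// With M01 = N01 = 1 hard-coded above, the chained adaptors reduce to a
// plain row-major mapping: block_id = m0 * N0 + n0.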
using CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock = remove_cvref_t<decltype(
MakeCGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(CGridDesc_M_N{}))>;
using DGridDescriptor_MBlock_MPerBlock =
remove_cvref_t<decltype(MakeDGridDescriptor_MBlock_MPerBlock(DGridDesc_M{}))>;
using DefaultBlock2CTileMap =
remove_cvref_t<decltype(MakeDefaultBlock2CTileMap(CGridDesc_M_N{}))>;
template <bool HasMainK0BlockLoop, typename Block2CTileMap>
__device__ static void Run(const FloatAB* __restrict__ p_a_grid,
const FloatAB* __restrict__ p_b_grid,
FloatC* __restrict__ p_c_grid,
FloatD* __restrict__ p_d0_grid,
FloatD* __restrict__ p_d1_grid,
void* __restrict__ p_shared,
const AElementwiseOperation& a_element_op,
const BElementwiseOperation& b_element_op,
const CElementwiseOperation& c_element_op,
const D0ReduceOperation& d0_reduce_op,
const D1ReduceOperation& d1_reduce_op,
const AGridDesc_AK0_M_AK1& a_grid_desc_ak0_m_ak1,
const BGridDesc_BK0_N_BK1& b_grid_desc_bk0_n_bk1,
const CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock&
c_grid_desc_mblock_mperblock_nblock_nperblock,
const DGridDescriptor_MBlock_MPerBlock& d_grid_desc_mblock_mperblock,
const Block2CTileMap& block_2_ctile_map)
{
const auto a_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_a_grid, a_grid_desc_ak0_m_ak1.GetElementSpaceSize());
const auto b_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_b_grid, b_grid_desc_bk0_n_bk1.GetElementSpaceSize());
auto c_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_c_grid, c_grid_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize());
auto d0_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_d0_grid, d_grid_desc_mblock_mperblock.GetElementSpaceSize());
auto d1_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_d1_grid, d_grid_desc_mblock_mperblock.GetElementSpaceSize());
// divide block work by [M, N]
const auto block_work_idx =
block_2_ctile_map.CalculateBottomIndex(make_multi_index(get_block_1d_id()));
// HACK: this forces m/n_block_data_idx_on_grid into SGPR
const index_t m_block_data_idx_on_grid =
__builtin_amdgcn_readfirstlane(block_work_idx[I0] * MPerBlock);
const index_t n_block_data_idx_on_grid =
__builtin_amdgcn_readfirstlane(block_work_idx[I1] * NPerBlock);
// lds max alignment
constexpr auto max_lds_align = math::lcm(AK1, BK1);
// A matrix in LDS memory, dst of blockwise copy
constexpr auto a_block_desc_ak0_m_ak1 = GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1();
// B matrix in LDS memory, dst of blockwise copy
constexpr auto b_block_desc_bk0_n_bk1 = GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1();
// A matrix blockwise copy
auto a_blockwise_copy =
BlockwiseTensorSliceTransfer_v4r1<BlockSize,
AElementwiseOperation,
ck::tensor_operation::element_wise::PassThrough,
InMemoryDataOperationEnum::Set,
Sequence<AK0, MPerBlock, AK1>,
ABlockTransferThreadClusterLengths_AK0_M_AK1,
ABlockTransferThreadClusterArrangeOrder,
FloatAB,
FloatAB,
decltype(a_grid_desc_ak0_m_ak1),
decltype(a_block_desc_ak0_m_ak1),
ABlockTransferSrcAccessOrder,
Sequence<1, 0, 2>,
ABlockTransferSrcVectorDim,
2,
ABlockTransferSrcScalarPerVector,
ABlockTransferDstScalarPerVector_AK1,
1,
1,
AThreadTransferSrcResetCoordinateAfterRun,
true,
NumGemmKPrefetchStage>(
a_grid_desc_ak0_m_ak1,
make_multi_index(0, m_block_data_idx_on_grid, 0),
a_element_op,
a_block_desc_ak0_m_ak1,
make_multi_index(0, 0, 0),
ck::tensor_operation::element_wise::PassThrough{});
// B matrix blockwise copy
auto b_blockwise_copy =
BlockwiseTensorSliceTransfer_v4r1<BlockSize,
BElementwiseOperation,
ck::tensor_operation::element_wise::PassThrough,
InMemoryDataOperationEnum::Set,
Sequence<BK0, NPerBlock, BK1>,
BBlockTransferThreadClusterLengths_BK0_N_BK1,
BBlockTransferThreadClusterArrangeOrder,
FloatAB,
FloatAB,
decltype(b_grid_desc_bk0_n_bk1),
decltype(b_block_desc_bk0_n_bk1),
BBlockTransferSrcAccessOrder,
Sequence<1, 0, 2>,
BBlockTransferSrcVectorDim,
2,
BBlockTransferSrcScalarPerVector,
BBlockTransferDstScalarPerVector_BK1,
1,
1,
BThreadTransferSrcResetCoordinateAfterRun,
true,
NumGemmKPrefetchStage>(
b_grid_desc_bk0_n_bk1,
make_multi_index(0, n_block_data_idx_on_grid, 0),
b_element_op,
b_block_desc_bk0_n_bk1,
make_multi_index(0, 0, 0),
ck::tensor_operation::element_wise::PassThrough{});
// GEMM definition
// c_mtx += transpose(a_mtx) * b_mtx
// a_mtx[K0PerBlock, MPerBlock] is in LDS
// b_mtx[K0PerBlock, NPerBlock] is in LDS
// c_mtx[MPerBlock, NPerBlock] is distributed among threads and saved in
// registers
// sanity check
constexpr index_t KPack = math::max(
math::lcm(AK1, BK1), MfmaSelector<FloatAB, MPerXdl, NPerXdl>::selected_mfma.k_per_blk);
auto blockwise_gemm =
BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1<BlockSize,
FloatAB,
FloatGemmAcc,
decltype(a_block_desc_ak0_m_ak1),
decltype(b_block_desc_bk0_n_bk1),
MPerXdl,
NPerXdl,
MXdlPerWave,
NXdlPerWave,
KPack>{};
auto c_thread_buf = blockwise_gemm.GetCThreadBuffer();
// LDS allocation for A and B: be careful of alignment
constexpr auto a_block_space_size_aligned = math::integer_least_multiple(
a_block_desc_ak0_m_ak1.GetElementSpaceSize(), max_lds_align);
auto a_block_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(
static_cast<FloatAB*>(p_shared), a_block_desc_ak0_m_ak1.GetElementSpaceSize());
auto b_block_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(
static_cast<FloatAB*>(p_shared) + a_block_space_size_aligned,
b_block_desc_bk0_n_bk1.GetElementSpaceSize());
constexpr auto a_block_slice_copy_step = make_multi_index(KPerBlock / AK1, 0, 0);
constexpr auto b_block_slice_copy_step = make_multi_index(KPerBlock / BK1, 0, 0);
// gridwise GEMM pipeline
const auto gridwise_gemm_pipeline =
GridwiseGemmPipeline_v1<remove_cvref_t<decltype(a_grid_desc_ak0_m_ak1)>,
remove_cvref_t<decltype(a_block_desc_ak0_m_ak1)>,
remove_cvref_t<decltype(a_blockwise_copy)>,
remove_cvref_t<decltype(a_grid_buf)>,
remove_cvref_t<decltype(a_block_buf)>,
remove_cvref_t<decltype(a_block_slice_copy_step)>,
remove_cvref_t<decltype(b_grid_desc_bk0_n_bk1)>,
remove_cvref_t<decltype(b_block_desc_bk0_n_bk1)>,
remove_cvref_t<decltype(b_blockwise_copy)>,
remove_cvref_t<decltype(b_grid_buf)>,
remove_cvref_t<decltype(b_block_buf)>,
remove_cvref_t<decltype(b_block_slice_copy_step)>,
remove_cvref_t<decltype(blockwise_gemm)>,
remove_cvref_t<decltype(c_thread_buf)>,
NumGemmKPrefetchStage,
HasMainK0BlockLoop>{};
const index_t num_k_block_main_loop = __builtin_amdgcn_readfirstlane(
(a_grid_desc_ak0_m_ak1.GetLength(I0) * a_grid_desc_ak0_m_ak1.GetLength(I2)) /
KPerBlock);
gridwise_gemm_pipeline.Run(a_grid_desc_ak0_m_ak1,
a_block_desc_ak0_m_ak1,
a_blockwise_copy,
a_grid_buf,
a_block_buf,
a_block_slice_copy_step,
b_grid_desc_bk0_n_bk1,
b_block_desc_bk0_n_bk1,
b_blockwise_copy,
b_grid_buf,
b_block_buf,
b_block_slice_copy_step,
blockwise_gemm,
c_thread_buf,
num_k_block_main_loop);
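// GridwiseGemmPipeline_v1 encapsulates the load/sync/MFMA interleaving that
// the batched kernel earlier in this commit spells out by hand, with
// NumGemmKPrefetchStage tiles in flight; num_k_block_main_loop is the total
// number of KPerBlock chunks along K.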
// shuffle C and write out
{
static_assert(MXdlPerWave % CShuffleMXdlPerWavePerShuffle == 0 &&
NXdlPerWave % CShuffleNXdlPerWavePerShuffle == 0,
"wrong!");
constexpr index_t MWave = MPerBlock / (MXdlPerWave * MPerXdl);
constexpr index_t NWave = NPerBlock / (NXdlPerWave * NPerXdl);
// TODO: hacky, fix it!
constexpr auto c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2 =
blockwise_gemm.GetCThreadDescriptor_M0_N0_M1_N1_M2_M3_M4_N2();
// TODO: hacky, fix it!
// c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp is only used to get lengths
constexpr auto c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp =
blockwise_gemm.GetCBlockDescriptor_M0_N0_M1_N1_M2_M3_M4_N2();
constexpr auto M0 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I0);
constexpr auto N0 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I1);
constexpr auto M1 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I2);
constexpr auto N1 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I3);
constexpr auto M2 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I4);
constexpr auto M3 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I5);
constexpr auto M4 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I6);
constexpr auto N2 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I7);
constexpr auto c_shuffle_block_desc_mblock_mperblock_nblock_nperblock =
GetCShuffleBlockDescriptor_MBlock_MPerBlock_NBlock_NPerBlock();
auto c_shuffle_block_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(
static_cast<FloatCShuffle*>(p_shared),
c_shuffle_block_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize());
constexpr auto c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2 = transform_tensor_descriptor(
c_shuffle_block_desc_mblock_mperblock_nblock_nperblock,
make_tuple(
make_freeze_transform(I0),
make_unmerge_transform(make_tuple(
Number<CShuffleMXdlPerWavePerShuffle>{}, // M0 (MXdlPerWave) per shuffle
M1, // M1 = MWave
M2, // M2 * M3 * M4 = MPerXdl
M3,
M4)),
make_freeze_transform(I0),
make_unmerge_transform(make_tuple(
Number<CShuffleNXdlPerWavePerShuffle>{}, // N0 (NXdlPerWave) per shuffle
N1, // N1 = NWave
N2))), // N2 = NPerXdl
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}),
make_tuple(
Sequence<>{}, Sequence<0, 2, 4, 5, 6>{}, Sequence<>{}, Sequence<1, 3, 7>{}));
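            // the MBlock/NBlock dimensions (length 1, only one shuffle tile lives in
            // LDS at a time) are frozen at 0; the MPerBlock extent is unmerged into
            // (M0 per shuffle, MWave, M2, M3, M4) and the NPerBlock extent into
            // (N0 per shuffle, NWave, NPerXdl) to match the thread-level C layout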
// calculate origin of thread output tensor on global memory
// blockwise GEMM c matrix starting index
const auto c_thread_mtx_on_block =
blockwise_gemm.CalculateCThreadOriginDataIndex(I0, I0, I0, I0);
const index_t m_thread_data_on_block = c_thread_mtx_on_block[I0];
const index_t n_thread_data_on_block = c_thread_mtx_on_block[I1];
const auto m_thread_data_on_block_to_m0_m1_m2_m3_m4_adaptor =
make_single_stage_tensor_adaptor(
make_tuple(make_merge_transform(make_tuple(M0, M1, M2, M3, M4))),
make_tuple(Sequence<0, 1, 2, 3, 4>{}),
make_tuple(Sequence<0>{}));
const auto m_thread_data_on_block_idx =
m_thread_data_on_block_to_m0_m1_m2_m3_m4_adaptor.CalculateBottomIndex(
make_multi_index(m_thread_data_on_block));
const auto n_thread_data_on_block_to_n0_n1_n2_adaptor =
make_single_stage_tensor_adaptor(
make_tuple(make_merge_transform(make_tuple(N0, N1, N2))),
make_tuple(Sequence<0, 1, 2>{}),
make_tuple(Sequence<0>{}));
const auto n_thread_data_on_block_idx =
n_thread_data_on_block_to_n0_n1_n2_adaptor.CalculateBottomIndex(
make_multi_index(n_thread_data_on_block));
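            // the two adaptors invert the merges above: given a thread's flat m/n
            // offset within the block tile, CalculateBottomIndex recovers its
            // (M0, M1, M2, M3, M4) and (N0, N1, N2) coordinates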
// shuffle: threadwise copy C from VGPR to LDS
auto c_thread_copy_vgpr_to_lds =
ThreadwiseTensorSliceTransfer_v1r3<FloatGemmAcc,
FloatCShuffle,
decltype(c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2),
decltype(c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2),
ck::tensor_operation::element_wise::PassThrough,
Sequence<CShuffleMXdlPerWavePerShuffle,
CShuffleNXdlPerWavePerShuffle,
I1,
I1,
M2,
I1,
M4,
I1>,
Sequence<0, 1, 2, 3, 4, 5, 6, 7>,
7,
1,
InMemoryDataOperationEnum::Set,
1,
true>{
c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2,
make_multi_index(0,
0,
m_thread_data_on_block_idx[I1],
n_thread_data_on_block_idx[I1],
m_thread_data_on_block_idx[I2],
m_thread_data_on_block_idx[I3],
m_thread_data_on_block_idx[I4],
n_thread_data_on_block_idx[I2]),
ck::tensor_operation::element_wise::PassThrough{}};
// shuffle: blockwise copy C from LDS to global
auto c_shuffle_block_copy_lds_to_global = BlockwiseTensorSliceTransfer_v6r1<
BlockSize, // index_t BlockSize,
CElementwiseOperation, // ElementwiseOperation,
CGlobalMemoryDataOperation, // DstInMemOp,
Sequence<1,
CShuffleMXdlPerWavePerShuffle * MWave * MPerXdl,
1,
CShuffleNXdlPerWavePerShuffle * NWave * NPerXdl>, // BlockSliceLengths,
CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
Sequence<0, 1, 2, 3>, // typename ThreadClusterArrangeOrder,
FloatCShuffle, // typename SrcData,
FloatC, // typename DstData,
decltype(c_shuffle_block_desc_mblock_mperblock_nblock_nperblock),
decltype(c_grid_desc_mblock_mperblock_nblock_nperblock),
Sequence<0, 1, 2, 3>, // typename DimAccessOrder,
3, // index_t VectorDim,
CShuffleBlockTransferScalarPerVector_NPerBlock, // index_t ScalarPerVector,
true, // bool ThreadTransferSrcResetCoordinateAfterRun,
false> // bool ThreadTransferDstResetCoordinateAfterRun>
{c_shuffle_block_desc_mblock_mperblock_nblock_nperblock,
make_multi_index(0, 0, 0, 0),
c_grid_desc_mblock_mperblock_nblock_nperblock,
make_multi_index(block_work_idx[I0], 0, block_work_idx[I1], 0),
c_element_op};
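            // CGlobalMemoryDataOperation selects how C is committed to global memory,
            // e.g. InMemoryDataOperationEnum::Set, or AtomicAdd when partial results
            // from several workgroups are accumulated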
// LDS c_reduce_block_desc_mperblock_nperblock
constexpr auto c_reduce_block_desc_mperblock_nperblock = transform_tensor_descriptor(
c_shuffle_block_desc_mblock_mperblock_nblock_nperblock,
make_tuple(
make_freeze_transform(I0),
make_pass_through_transform(
c_shuffle_block_desc_mblock_mperblock_nblock_nperblock.GetLength(I1)),
make_freeze_transform(I0),
make_pass_through_transform(
c_shuffle_block_desc_mblock_mperblock_nblock_nperblock.GetLength(I3))),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}),
make_tuple(Sequence<>{}, Sequence<0>{}, Sequence<>{}, Sequence<1>{}));
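            // the 4-D shuffle LDS view collapses to a 2-D (mperblock, nperblock)
            // view of the current shuffle tile for the reduction that follows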
            static_assert(CReduceThreadClusterLengths_MPerBlock_NPerBlock::At(I0) *
                                  CReduceThreadClusterLengths_MPerBlock_NPerBlock::At(I1) ==
                              BlockSize,
                          "wrong! the reduce thread cluster must cover the whole workgroup");
            static_assert((CShuffleMXdlPerWavePerShuffle * MWave * MPerXdl) %
                                      CReduceThreadClusterLengths_MPerBlock_NPerBlock::At(I0) ==
                                  0 &&
                              (CShuffleNXdlPerWavePerShuffle * NWave * NPerXdl) %
                                      CReduceThreadClusterLengths_MPerBlock_NPerBlock::At(I1) ==
                                  0,
                          "wrong! the shuffled C tile must be divisible by the reduce thread cluster");
constexpr index_t mreduce_per_thread =
(CShuffleMXdlPerWavePerShuffle * MWave * MPerXdl) /
CReduceThreadClusterLengths_MPerBlock_NPerBlock::At(I0);
constexpr index_t nreduce_per_thread =
(CShuffleNXdlPerWavePerShuffle * NWave * NPerXdl) /
CReduceThreadClusterLengths_MPerBlock_NPerBlock::At(I1);
constexpr auto c_reduce_thread_lengths_mperblock_nperblock =
Sequence<mreduce_per_thread, nreduce_per_thread>{};
// VGPR c_reduce_thread_desc_mperblock_nperblock
constexpr auto c_reduce_thread_desc_mperblock_nperblock =
make_naive_tensor_descriptor_packed(
make_tuple(Number<mreduce_per_thread>{}, Number<nreduce_per_thread>{}));
// VGPR d_reduce_thread_desc_mperblock
constexpr auto d_reduce_thread_desc_mperblock =
make_naive_tensor_descriptor_packed(make_tuple(Number<mreduce_per_thread>{}));
// VGPR d_reduce_thread_desc_mblock_mperblock
constexpr auto d_reduce_thread_desc_mblock_mperblock =
make_naive_tensor_descriptor_packed(make_tuple(I1, Number<mreduce_per_thread>{}));
// TODO: this should be implemented as a blockwise reduction
auto c_reduce_thread_buf = make_static_buffer<AddressSpaceEnum::Vgpr, FloatCShuffle>(
c_reduce_thread_desc_mperblock_nperblock.GetElementSpaceSize());
auto d0_thread_buf = make_static_buffer<AddressSpaceEnum::Vgpr, FloatCShuffle>(
d_reduce_thread_desc_mperblock.GetElementSpaceSize());
auto d1_thread_buf = make_static_buffer<AddressSpaceEnum::Vgpr, FloatCShuffle>(
d_reduce_thread_desc_mperblock.GetElementSpaceSize());
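            // c_reduce_thread_buf holds this thread's C sub-tile; d0/d1_thread_buf
            // hold one partial value per M row for each of the two reduce operators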
// reduce: threadwise copy from LDS to VGPR
constexpr auto c_reduce_thread_cluster_desc = make_cluster_descriptor(
CReduceThreadClusterLengths_MPerBlock_NPerBlock{}, Sequence<1, 0>{});
const auto c_reduce_thread_cluster_idx =
c_reduce_thread_cluster_desc.CalculateBottomIndex(
make_multi_index(get_thread_local_1d_id()));
const auto c_reduce_thread_data_idx_begin =
c_reduce_thread_cluster_idx * c_reduce_thread_lengths_mperblock_nperblock;
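            // map the flat thread id onto an (m, n) cluster coordinate; each thread
            // then owns an mreduce_per_thread x nreduce_per_thread sub-tile starting
            // at c_reduce_thread_data_idx_begin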
auto c_reduce_thread_copy_lds_to_vgpr = ThreadwiseTensorSliceTransfer_v2<
FloatCShuffle,
FloatCShuffle,
decltype(c_reduce_block_desc_mperblock_nperblock),
decltype(c_reduce_thread_desc_mperblock_nperblock),
decltype(c_reduce_thread_lengths_mperblock_nperblock),
Sequence<0, 1>,
1,
CReduceThreadLds2VGprCopySrcDstScalarPerVector_NPerBlock,
1,
true>{c_reduce_block_desc_mperblock_nperblock, c_reduce_thread_data_idx_begin};
// reduce: copy from VGPR to global
auto d0_reduce_thread_copy_vgpr_to_global = ThreadwiseTensorSliceTransfer_v1r3<
FloatCShuffle,
FloatD,
decltype(d_reduce_thread_desc_mblock_mperblock),
decltype(d_grid_desc_mblock_mperblock),
ck::tensor_operation::element_wise::PassThrough,
Sequence<1, mreduce_per_thread>,
Sequence<0, 1>,
1,
CReduceThreadVgpr2GlobalCopySrcDstScalarPerVector_MPerBlock,
DGlobalMemoryDataOperation,
1,
false>{d_grid_desc_mblock_mperblock,
make_multi_index(block_work_idx[I0], // mblock
c_reduce_thread_data_idx_begin[I0]), // mperblock
ck::tensor_operation::element_wise::PassThrough{}};
auto d1_reduce_thread_copy_vgpr_to_global = d0_reduce_thread_copy_vgpr_to_global;
// space filling curve for threadwise C in VGPR
constexpr auto sfc_c_vgpr =
SpaceFillingCurve<Sequence<MXdlPerWave, NXdlPerWave, 1, 1, M2, 1, M4, 1>,
Sequence<0, 1, 2, 3, 4, 5, 6, 7>,
Sequence<CShuffleMXdlPerWavePerShuffle,
CShuffleNXdlPerWavePerShuffle,
1,
1,
M2,
1,
M4,
1>>{};
// space filling curve for shuffled blockwise C in global mem
constexpr auto sfc_c_global =
SpaceFillingCurve<Sequence<1, MPerBlock, 1, NPerBlock>,
Sequence<0, 2, 1, 3>,
Sequence<1,
CShuffleMXdlPerWavePerShuffle * MWave * MPerXdl,
1,
CShuffleNXdlPerWavePerShuffle * NWave * NPerXdl>>{};
constexpr index_t num_access = sfc_c_vgpr.GetNumOfAccess();
            static_assert(num_access == sfc_c_global.GetNumOfAccess(),
                          "wrong! VGPR and global space-filling curves must have the same "
                          "number of accesses");
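            // sfc_c_vgpr walks the thread's registers while sfc_c_global walks the
            // block's C tile in lockstep, one shuffle-sized chunk
            // (CShuffleMXdlPerWavePerShuffle x CShuffleNXdlPerWavePerShuffle) per access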
static_for<0, num_access, 1>{}([&](auto access_id) {
// make sure it's safe to write to LDS
block_sync_lds();
                // each thread writes its data from VGPR to LDS
c_thread_copy_vgpr_to_lds.Run(c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2,
sfc_c_vgpr.GetIndexTupleOfNumber(access_id),
c_thread_buf,
c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2,
c_shuffle_block_buf);
// make sure it's safe to read from LDS
block_sync_lds();
                // each block copies its data from LDS to global
c_shuffle_block_copy_lds_to_global.Run(
c_shuffle_block_desc_mblock_mperblock_nblock_nperblock,
c_shuffle_block_buf,
c_grid_desc_mblock_mperblock_nblock_nperblock,
c_grid_buf);
// reduce
{
// copy from LDS to VGPR
c_reduce_thread_copy_lds_to_vgpr.Run(c_reduce_block_desc_mperblock_nperblock,
c_shuffle_block_buf,
c_reduce_thread_desc_mperblock_nperblock,
make_tuple(I0, I0),
c_reduce_thread_buf);
// reduce in VGPR
static_for<0, mreduce_per_thread, 1>{}([&](auto im) {
FloatReduceAcc d0_acc = d0_reduce_op.GetReduceZeroValue();
FloatReduceAcc d1_acc = d1_reduce_op.GetReduceZeroValue();
static_for<0, nreduce_per_thread, 1>{}([&](auto in) {
constexpr auto offset =
Number<c_reduce_thread_desc_mperblock_nperblock.CalculateOffset(
make_tuple(im, in))>{};
d0_reduce_op.Reduce(d0_acc, c_reduce_thread_buf[offset]);
d1_reduce_op.Reduce(d1_acc, c_reduce_thread_buf[offset]);
});
constexpr index_t out_offset =
d_reduce_thread_desc_mperblock.CalculateOffset(make_tuple(im));
d0_thread_buf(Number<out_offset>{}) = d0_acc;
d1_thread_buf(Number<out_offset>{}) = d1_acc;
});
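                    // each thread now holds a partial reduction over its own N columns;
                    // the remaining combination across the thread cluster (and across
                    // workgroups) relies on DGlobalMemoryDataOperation, e.g. AtomicAdd,
                    // when the partials are written out below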
                    // copy from VGPR to global
d0_reduce_thread_copy_vgpr_to_global.Run(d_reduce_thread_desc_mblock_mperblock,
make_tuple(I0, I0),
d0_thread_buf,
d_grid_desc_mblock_mperblock,
d0_grid_buf);
d1_reduce_thread_copy_vgpr_to_global.Run(d_reduce_thread_desc_mblock_mperblock,
make_tuple(I0, I0),
d1_thread_buf,
d_grid_desc_mblock_mperblock,
d1_grid_buf);
}
if constexpr(access_id < num_access - 1)
{
constexpr auto c_global_step = sfc_c_global.GetForwardStep(access_id);
                    // advance the destination window on C
c_shuffle_block_copy_lds_to_global.MoveDstSliceWindow(
c_grid_desc_mblock_mperblock_nblock_nperblock, c_global_step);
                    // advance the destination window on D0
d0_reduce_thread_copy_vgpr_to_global.MoveDstSliceWindow(
d_grid_desc_mblock_mperblock,
make_tuple(c_global_step[I0], c_global_step[I1]));
                    // advance the destination window on D1
d1_reduce_thread_copy_vgpr_to_global.MoveDstSliceWindow(
d_grid_desc_mblock_mperblock,
make_tuple(c_global_step[I0], c_global_step[I1]));
}
});
}
}
};
} // namespace ck
#pragma once
#include "common_header.hpp"
#include "multi_index_transform_helper.hpp"
#include "tensor_descriptor.hpp"
#include "tensor_descriptor_helper.hpp"
#include "blockwise_gemm_xdlops.hpp"
#include "blockwise_tensor_slice_transfer_v4r1.hpp"
#include "blockwise_tensor_slice_transfer_v6r1.hpp"
#include "threadwise_tensor_slice_transfer.hpp"
#include "gridwise_gemm_pipeline_v1.hpp"
namespace ck {
template <typename GridwiseGemm,
typename FloatAB,
typename FloatC,
typename AElementwiseOperation,
typename BElementwiseOperation,
typename CElementwiseOperation,
typename AGridDesc_AK0_M_AK1,
typename BGridDesc_BK0_N_BK1,
typename CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock,
typename Block2CTileMap,
bool HasMainK0BlockLoop>
__global__ void
#if CK_USE_LAUNCH_BOUNDS
__launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU)
#endif
kernel_gemm_xdl_cshuffle_v1(const FloatAB* __restrict__ p_a_grid,
const FloatAB* __restrict__ p_b_grid,
FloatC* __restrict__ p_c_grid,
const AElementwiseOperation a_element_op,
const BElementwiseOperation b_element_op,
const CElementwiseOperation c_element_op,
const AGridDesc_AK0_M_AK1 a_grid_desc_ak0_m_ak1,
const BGridDesc_BK0_N_BK1 b_grid_desc_bk0_n_bk1,
const CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
c_grid_desc_mblock_mperblock_nblock_nperblock,
const Block2CTileMap block_2_ctile_map)
{
__shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()];
GridwiseGemm::template Run<HasMainK0BlockLoop>(p_a_grid,
p_b_grid,
p_c_grid,
p_shared,
a_element_op,
b_element_op,
c_element_op,
a_grid_desc_ak0_m_ak1,
b_grid_desc_bk0_n_bk1,
c_grid_desc_mblock_mperblock_nblock_nperblock,
block_2_ctile_map);
}
template <typename FloatAB,
typename FloatGemmAcc,
typename FloatCShuffle,
typename FloatC,
typename AElementwiseOperation,
typename BElementwiseOperation,
typename CElementwiseOperation,
InMemoryDataOperationEnum CGlobalMemoryDataOperation,
typename AGridDesc_AK0_M_AK1,
typename BGridDesc_BK0_N_BK1,
typename CGridDesc_M_N,
index_t NumGemmKPrefetchStage,
index_t BlockSize,
index_t MPerBlock,
index_t NPerBlock,
index_t KPerBlock,
index_t AK1Value,
index_t BK1Value,
index_t MPerXdl,
index_t NPerXdl,
index_t MXdlPerWave,
index_t NXdlPerWave,
typename ABlockTransferThreadClusterLengths_AK0_M_AK1,
typename ABlockTransferThreadClusterArrangeOrder,
typename ABlockTransferSrcAccessOrder,
index_t ABlockTransferSrcVectorDim,
index_t ABlockTransferSrcScalarPerVector,
index_t ABlockTransferDstScalarPerVector_AK1,
bool AThreadTransferSrcResetCoordinateAfterRun,
index_t ABlockLdsExtraM,
typename BBlockTransferThreadClusterLengths_BK0_N_BK1,
typename BBlockTransferThreadClusterArrangeOrder,
typename BBlockTransferSrcAccessOrder,
index_t BBlockTransferSrcVectorDim,
index_t BBlockTransferSrcScalarPerVector,
index_t BBlockTransferDstScalarPerVector_BK1,
bool BThreadTransferSrcResetCoordinateAfterRun,
index_t BBlockLdsExtraN,
index_t CShuffleMXdlPerWavePerShuffle,
index_t CShuffleNXdlPerWavePerShuffle,
typename CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
index_t CShuffleBlockTransferScalarPerVector_NPerBlock>
struct GridwiseGemm_k0mk1_k0nk1_mn_xdl_cshuffle_v1
{
static constexpr auto I0 = Number<0>{};
static constexpr auto I1 = Number<1>{};
static constexpr auto I2 = Number<2>{};
static constexpr auto I3 = Number<3>{};
static constexpr auto I4 = Number<4>{};
static constexpr auto I5 = Number<5>{};
static constexpr auto I6 = Number<6>{};
static constexpr auto I7 = Number<7>{};
// K1 should be Number<...>
static constexpr auto AK0 = Number<KPerBlock / AK1Value>{};
static constexpr auto BK0 = Number<KPerBlock / BK1Value>{};
static constexpr auto AK1 = Number<AK1Value>{};
static constexpr auto BK1 = Number<BK1Value>{};
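    // e.g. KPerBlock = 32 with AK1Value = 8 gives AK0 = 4: a block's K axis is
    // viewed as (AK0, AK1) = (4, 8), where AK1 is the contiguous, vectorizable part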
__host__ __device__ static constexpr auto GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1()
{
// A matrix in LDS memory, dst of blockwise copy
return make_naive_tensor_descriptor(
make_tuple(AK0, Number<MPerBlock>{}, AK1),
make_tuple(Number<MPerBlock + ABlockLdsExtraM>{} * AK1, AK1, I1));
}
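    // note: both LDS block descriptors pad each row, via the stride
    // (MPerBlock + ABlockLdsExtraM) * AK1 and its B analogue; the extra elements
    // waste a little LDS but can stagger rows across banks to reduce bank conflicts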
__host__ __device__ static constexpr auto GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1()
{
// B matrix in LDS memory, dst of blockwise copy
return make_naive_tensor_descriptor(
make_tuple(BK0, Number<NPerBlock>{}, BK1),
make_tuple(Number<NPerBlock + BBlockLdsExtraN>{} * BK1, BK1, I1));
}
__host__ __device__ static constexpr auto
GetCShuffleBlockDescriptor_MBlock_MPerBlock_NBlock_NPerBlock()
{
constexpr index_t MWave = MPerBlock / (MXdlPerWave * MPerXdl);
constexpr index_t NWave = NPerBlock / (NXdlPerWave * NPerXdl);
constexpr auto c_shuffle_block_desc_mblock_mperblock_nblock_nperblock =
make_naive_tensor_descriptor_packed(
make_tuple(I1,
Number<CShuffleMXdlPerWavePerShuffle * MWave * MPerXdl>{},
I1,
Number<CShuffleNXdlPerWavePerShuffle * NWave * NPerXdl>{}));
return c_shuffle_block_desc_mblock_mperblock_nblock_nperblock;
}
__host__ __device__ static constexpr index_t GetSharedMemoryNumberOfByte()
{
// LDS allocation for A and B: be careful of alignment
constexpr auto a_block_desc_ak0_m_ak1 = GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1();
constexpr auto b_block_desc_bk0_n_bk1 = GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1();
// lds max alignment
constexpr auto max_lds_align = math::lcm(AK1, BK1);
constexpr auto a_block_space_size_aligned = math::integer_least_multiple(
a_block_desc_ak0_m_ak1.GetElementSpaceSize(), max_lds_align);
constexpr auto b_block_space_size_aligned = math::integer_least_multiple(
b_block_desc_bk0_n_bk1.GetElementSpaceSize(), max_lds_align);
        // LDS allocation for C shuffle
constexpr auto c_shuffle_block_desc_mblock_mperblock_nblock_nperblock =
GetCShuffleBlockDescriptor_MBlock_MPerBlock_NBlock_NPerBlock();
constexpr auto c_block_size =
c_shuffle_block_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize();
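        // the A/B main-loop tiles and the C-shuffle tile occupy the same LDS at
        // different times, so the allocation is the max of the two footprints,
        // not their sum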
return math::max((a_block_space_size_aligned + b_block_space_size_aligned) *
sizeof(FloatAB),
c_block_size * sizeof(FloatCShuffle));
}
    // block_id to matrix tile idx (m0, n0) mapping is controlled by {M01, N01}
__host__ __device__ static constexpr bool
CheckValidity(const AGridDesc_AK0_M_AK1& a_grid_desc_ak0_m_ak1,
const BGridDesc_BK0_N_BK1& b_grid_desc_bk0_n_bk1,
const CGridDesc_M_N& c_grid_desc_m_n)
{
// static_assert(is_known_at_compile_time<remove_cv_t<decltype(AK1)>>::value &&
// is_known_at_compile_time<remove_cv_t<decltype(BK1)>>::value,
// "wrong! K1 need to be known at compile-time");
static_assert((MPerBlock % (MPerXdl * MXdlPerWave) == 0) &&
(NPerBlock % (NXdlPerWave * NPerXdl)) == 0,
"Invalid tuning param!");
const auto M = a_grid_desc_ak0_m_ak1.GetLength(I1);
const auto N = b_grid_desc_bk0_n_bk1.GetLength(I1);
const auto K = a_grid_desc_ak0_m_ak1.GetLength(I0) * a_grid_desc_ak0_m_ak1.GetLength(I2);
if(!(M == c_grid_desc_m_n.GetLength(I0) && N == c_grid_desc_m_n.GetLength(I1)))
return false;
if(!(M % MPerBlock == 0 && N % NPerBlock == 0 && K % KPerBlock == 0))
return false;
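        // no tail handling: M, N and K must be exact multiples of the block tile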
// check NumGemmKPrefetchStage
if constexpr(NumGemmKPrefetchStage == 1)
{
// 1-stage prefetch always supported
}
else if constexpr(NumGemmKPrefetchStage == 2)
{
            // 2-stage prefetch currently only supports an even number of K0 loops
            // TODO: add support for an odd number of K0 loops
if(!((K / KPerBlock) % 2 == 0))
{
return false;
}
}
else
{
return false;
}
// TODO: also check validity of all components (blockwise-copy, threadwise-copy, etc)
return true;
}
__host__ __device__ static constexpr index_t
CalculateGridSize(const CGridDesc_M_N& c_grid_desc_m_n)
{
const auto M = c_grid_desc_m_n.GetLength(I0);
const auto N = c_grid_desc_m_n.GetLength(I1);
const index_t grid_size = (M / MPerBlock) * (N / NPerBlock);
return grid_size;
}
// TODO move this function into GEMM-pipeline class
__host__ __device__ static constexpr bool CalculateHasMainK0BlockLoop(index_t K0)
{
const bool has_main_k0_block_loop = ((K0 * AK1) / (NumGemmKPrefetchStage * KPerBlock)) > 1;
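        // K0 * AK1 is the total K extent; e.g. K = 256, KPerBlock = 32 and one
        // prefetch stage give 256 / (1 * 32) = 8 > 1, so the main loop is taken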
return has_main_k0_block_loop;
}
__host__ __device__ static constexpr auto
MakeCGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(const CGridDesc_M_N& c_grid_desc_m_n)
{
const auto M = c_grid_desc_m_n.GetLength(I0);
const auto N = c_grid_desc_m_n.GetLength(I1);
const auto MBlock = M / MPerBlock;
const auto NBlock = N / NPerBlock;
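        // unmerge M into (MBlock, MPerBlock) and N into (NBlock, NPerBlock) so the
        // epilogue copies can address C one block tile at a time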
const auto c_grid_desc_mblock_mperblock_nblock_nperblock = transform_tensor_descriptor(
c_grid_desc_m_n,
make_tuple(make_unmerge_transform(make_tuple(MBlock, Number<MPerBlock>{})),
make_unmerge_transform(make_tuple(NBlock, Number<NPerBlock>{}))),
make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<0, 1>{}, Sequence<2, 3>{}));
return c_grid_desc_mblock_mperblock_nblock_nperblock;
}
// return block_id to C matrix tile idx (m0, n0) mapping
__host__ __device__ static constexpr auto
MakeDefaultBlock2CTileMap(const CGridDesc_M_N& c_grid_desc_m_n)
{
const auto M = c_grid_desc_m_n.GetLength(I0);
const auto N = c_grid_desc_m_n.GetLength(I1);
constexpr auto M1 = Number<MPerBlock>{};
constexpr auto N1 = Number<NPerBlock>{};
const auto M0 = M / M1;
const auto N0 = N / N1;
// FIXME: remove
constexpr auto M01 = I1;
constexpr auto N01 = I1;
const auto M00 = M0 / M01;
const auto N00 = N0 / N01;
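        // two chained adaptors: the flat block id is first decomposed into
        // (m00, n00, m01, n01), then folded into the C tile index (m0, n0);
        // with M01 = N01 = 1 this reduces to a plain row-major tile mapping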
const auto m00_m01_n00_n01_to_m0_n0_block_cluster_adaptor =
make_single_stage_tensor_adaptor(
make_tuple(make_unmerge_transform(make_tuple(M00, M01)),
make_unmerge_transform(make_tuple(N00, N01))),
make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<0, 2>{}, Sequence<1, 3>{}));
const auto cblockid_to_m00_m01_n00_n01_block_cluster_adaptor =
make_single_stage_tensor_adaptor(
make_tuple(make_merge_transform(make_tuple(M00, N00, M01, N01))),
make_tuple(Sequence<0, 1, 2, 3>{}),
make_tuple(Sequence<0>{}));
const auto cblockid_to_m0_n0_block_cluster_adaptor =
chain_tensor_adaptors(m00_m01_n00_n01_to_m0_n0_block_cluster_adaptor,
cblockid_to_m00_m01_n00_n01_block_cluster_adaptor);
return cblockid_to_m0_n0_block_cluster_adaptor;
}
using CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock = remove_cvref_t<decltype(
MakeCGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(CGridDesc_M_N{}))>;
using DefaultBlock2CTileMap =
remove_cvref_t<decltype(MakeDefaultBlock2CTileMap(CGridDesc_M_N{}))>;
template <bool HasMainK0BlockLoop, typename Block2CTileMap>
__device__ static void Run(const FloatAB* __restrict__ p_a_grid,
const FloatAB* __restrict__ p_b_grid,
FloatC* __restrict__ p_c_grid,
void* __restrict__ p_shared,
const AElementwiseOperation& a_element_op,
const BElementwiseOperation& b_element_op,
const CElementwiseOperation& c_element_op,
const AGridDesc_AK0_M_AK1& a_grid_desc_ak0_m_ak1,
const BGridDesc_BK0_N_BK1& b_grid_desc_bk0_n_bk1,
const CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock&
c_grid_desc_mblock_mperblock_nblock_nperblock,
const Block2CTileMap& block_2_ctile_map)
{
const auto a_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_a_grid, a_grid_desc_ak0_m_ak1.GetElementSpaceSize());
const auto b_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_b_grid, b_grid_desc_bk0_n_bk1.GetElementSpaceSize());
auto c_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_c_grid, c_grid_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize());
// divide block work by [M, N]
const auto block_work_idx =
block_2_ctile_map.CalculateBottomIndex(make_multi_index(get_block_1d_id()));
        // HACK: this forces m/n_block_data_idx_on_grid into SGPR
const index_t m_block_data_idx_on_grid =
__builtin_amdgcn_readfirstlane(block_work_idx[I0] * MPerBlock);
const index_t n_block_data_idx_on_grid =
__builtin_amdgcn_readfirstlane(block_work_idx[I1] * NPerBlock);
// lds max alignment
constexpr auto max_lds_align = math::lcm(AK1, BK1);
// A matrix in LDS memory, dst of blockwise copy
constexpr auto a_block_desc_ak0_m_ak1 = GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1();
// B matrix in LDS memory, dst of blockwise copy
constexpr auto b_block_desc_bk0_n_bk1 = GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1();
// A matrix blockwise copy
auto a_blockwise_copy =
BlockwiseTensorSliceTransfer_v4r1<BlockSize,
AElementwiseOperation,
ck::tensor_operation::element_wise::PassThrough,
InMemoryDataOperationEnum::Set,
Sequence<AK0, MPerBlock, AK1>,
ABlockTransferThreadClusterLengths_AK0_M_AK1,
ABlockTransferThreadClusterArrangeOrder,
FloatAB,
FloatAB,
decltype(a_grid_desc_ak0_m_ak1),
decltype(a_block_desc_ak0_m_ak1),
ABlockTransferSrcAccessOrder,
Sequence<1, 0, 2>,
ABlockTransferSrcVectorDim,
2,
ABlockTransferSrcScalarPerVector,
ABlockTransferDstScalarPerVector_AK1,
1,
1,
AThreadTransferSrcResetCoordinateAfterRun,
true,
NumGemmKPrefetchStage>(
a_grid_desc_ak0_m_ak1,
make_multi_index(0, m_block_data_idx_on_grid, 0),
a_element_op,
a_block_desc_ak0_m_ak1,
make_multi_index(0, 0, 0),
ck::tensor_operation::element_wise::PassThrough{});
// B matrix blockwise copy
auto b_blockwise_copy =
BlockwiseTensorSliceTransfer_v4r1<BlockSize,
BElementwiseOperation,
ck::tensor_operation::element_wise::PassThrough,
InMemoryDataOperationEnum::Set,
Sequence<BK0, NPerBlock, BK1>,
BBlockTransferThreadClusterLengths_BK0_N_BK1,
BBlockTransferThreadClusterArrangeOrder,
FloatAB,
FloatAB,
decltype(b_grid_desc_bk0_n_bk1),
decltype(b_block_desc_bk0_n_bk1),
BBlockTransferSrcAccessOrder,
Sequence<1, 0, 2>,
BBlockTransferSrcVectorDim,
2,
BBlockTransferSrcScalarPerVector,
BBlockTransferDstScalarPerVector_BK1,
1,
1,
BThreadTransferSrcResetCoordinateAfterRun,
true,
NumGemmKPrefetchStage>(
b_grid_desc_bk0_n_bk1,
make_multi_index(0, n_block_data_idx_on_grid, 0),
b_element_op,
b_block_desc_bk0_n_bk1,
make_multi_index(0, 0, 0),
ck::tensor_operation::element_wise::PassThrough{});
        // GEMM definition
        // c_mtx += transpose(a_mtx) * b_mtx
        // a_mtx[K0PerBlock, MPerBlock] is in LDS
        // b_mtx[K0PerBlock, NPerBlock] is in LDS
        // c_mtx[MPerBlock, NPerBlock] is distributed among threads and kept in
        // registers
        // KPack: K packing per xdlops issue; it must cover both LDS vector widths
        // (lcm(AK1, BK1)) and the selected mfma instruction's k_per_blk
constexpr index_t KPack = math::max(
math::lcm(AK1, BK1), MfmaSelector<FloatAB, MPerXdl, NPerXdl>::selected_mfma.k_per_blk);
auto blockwise_gemm =
BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1<BlockSize,
FloatAB,
FloatGemmAcc,
decltype(a_block_desc_ak0_m_ak1),
decltype(b_block_desc_bk0_n_bk1),
MPerXdl,
NPerXdl,
MXdlPerWave,
NXdlPerWave,
KPack>{};
auto c_thread_buf = blockwise_gemm.GetCThreadBuffer();
// LDS allocation for A and B: be careful of alignment
constexpr auto a_block_space_size_aligned = math::integer_least_multiple(
a_block_desc_ak0_m_ak1.GetElementSpaceSize(), max_lds_align);
auto a_block_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(
static_cast<FloatAB*>(p_shared), a_block_desc_ak0_m_ak1.GetElementSpaceSize());
auto b_block_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(
static_cast<FloatAB*>(p_shared) + a_block_space_size_aligned,
b_block_desc_bk0_n_bk1.GetElementSpaceSize());
constexpr auto a_block_slice_copy_step = make_multi_index(KPerBlock / AK1, 0, 0);
constexpr auto b_block_slice_copy_step = make_multi_index(KPerBlock / BK1, 0, 0);
// gridwise GEMM pipeline
const auto gridwise_gemm_pipeline =
GridwiseGemmPipeline_v1<remove_cvref_t<decltype(a_grid_desc_ak0_m_ak1)>,
remove_cvref_t<decltype(a_block_desc_ak0_m_ak1)>,
remove_cvref_t<decltype(a_blockwise_copy)>,
remove_cvref_t<decltype(a_grid_buf)>,
remove_cvref_t<decltype(a_block_buf)>,
remove_cvref_t<decltype(a_block_slice_copy_step)>,
remove_cvref_t<decltype(b_grid_desc_bk0_n_bk1)>,
remove_cvref_t<decltype(b_block_desc_bk0_n_bk1)>,
remove_cvref_t<decltype(b_blockwise_copy)>,
remove_cvref_t<decltype(b_grid_buf)>,
remove_cvref_t<decltype(b_block_buf)>,
remove_cvref_t<decltype(b_block_slice_copy_step)>,
remove_cvref_t<decltype(blockwise_gemm)>,
remove_cvref_t<decltype(c_thread_buf)>,
NumGemmKPrefetchStage,
HasMainK0BlockLoop>{};
const index_t num_k_block_main_loop = __builtin_amdgcn_readfirstlane(
(a_grid_desc_ak0_m_ak1.GetLength(I0) * a_grid_desc_ak0_m_ak1.GetLength(I2)) /
KPerBlock);
gridwise_gemm_pipeline.Run(a_grid_desc_ak0_m_ak1,
a_block_desc_ak0_m_ak1,
a_blockwise_copy,
a_grid_buf,
a_block_buf,
a_block_slice_copy_step,
b_grid_desc_bk0_n_bk1,
b_block_desc_bk0_n_bk1,
b_blockwise_copy,
b_grid_buf,
b_block_buf,
b_block_slice_copy_step,
blockwise_gemm,
c_thread_buf,
num_k_block_main_loop);
// shuffle C and write out
{
            static_assert(MXdlPerWave % CShuffleMXdlPerWavePerShuffle == 0 &&
                              NXdlPerWave % CShuffleNXdlPerWavePerShuffle == 0,
                          "wrong! M/NXdlPerWave must be divisible by the per-shuffle sizes");
constexpr index_t MWave = MPerBlock / (MXdlPerWave * MPerXdl);
constexpr index_t NWave = NPerBlock / (NXdlPerWave * NPerXdl);
// TODO: hacky, fix it!
constexpr auto c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2 =
blockwise_gemm.GetCThreadDescriptor_M0_N0_M1_N1_M2_M3_M4_N2();
// TODO: hacky, fix it!
// c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp is only used to get lengths
constexpr auto c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp =
blockwise_gemm.GetCBlockDescriptor_M0_N0_M1_N1_M2_M3_M4_N2();
constexpr auto M0 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I0);
constexpr auto N0 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I1);
constexpr auto M1 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I2);
constexpr auto N1 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I3);
constexpr auto M2 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I4);
constexpr auto M3 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I5);
constexpr auto M4 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I6);
constexpr auto N2 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I7);
constexpr auto c_shuffle_block_desc_mblock_mperblock_nblock_nperblock =
GetCShuffleBlockDescriptor_MBlock_MPerBlock_NBlock_NPerBlock();
auto c_shuffle_block_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(
static_cast<FloatCShuffle*>(p_shared),
c_shuffle_block_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize());
constexpr auto c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2 = transform_tensor_descriptor(
c_shuffle_block_desc_mblock_mperblock_nblock_nperblock,
make_tuple(
make_freeze_transform(I0),
make_unmerge_transform(make_tuple(
Number<CShuffleMXdlPerWavePerShuffle>{}, // M0 (MXdlPerWave) per shuffle
M1, // M1 = MWave
M2, // M2 * M3 * M4 = MPerXdl
M3,
M4)),
make_freeze_transform(I0),
make_unmerge_transform(make_tuple(
Number<CShuffleNXdlPerWavePerShuffle>{}, // N0 (NXdlPerWave) per shuffle
N1, // N1 = NWave
N2))), // N2 = NPerXdl
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}),
make_tuple(
Sequence<>{}, Sequence<0, 2, 4, 5, 6>{}, Sequence<>{}, Sequence<1, 3, 7>{}));
// calculate origin of thread output tensor on global memory
// blockwise GEMM c matrix starting index
const auto c_thread_mtx_on_block =
blockwise_gemm.CalculateCThreadOriginDataIndex(I0, I0, I0, I0);
const index_t m_thread_data_on_block = c_thread_mtx_on_block[I0];
const index_t n_thread_data_on_block = c_thread_mtx_on_block[I1];
const auto m_thread_data_on_block_to_m0_m1_m2_m3_m4_adaptor =
make_single_stage_tensor_adaptor(
make_tuple(make_merge_transform(make_tuple(M0, M1, M2, M3, M4))),
make_tuple(Sequence<0, 1, 2, 3, 4>{}),
make_tuple(Sequence<0>{}));
const auto m_thread_data_on_block_idx =
m_thread_data_on_block_to_m0_m1_m2_m3_m4_adaptor.CalculateBottomIndex(
make_multi_index(m_thread_data_on_block));
const auto n_thread_data_on_block_to_n0_n1_n2_adaptor =
make_single_stage_tensor_adaptor(
make_tuple(make_merge_transform(make_tuple(N0, N1, N2))),
make_tuple(Sequence<0, 1, 2>{}),
make_tuple(Sequence<0>{}));
const auto n_thread_data_on_block_idx =
n_thread_data_on_block_to_n0_n1_n2_adaptor.CalculateBottomIndex(
make_multi_index(n_thread_data_on_block));
// shuffle: threadwise copy C from VGPR to LDS
auto c_thread_copy_vgpr_to_lds =
ThreadwiseTensorSliceTransfer_v1r3<FloatGemmAcc,
FloatCShuffle,
decltype(c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2),
decltype(c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2),
ck::tensor_operation::element_wise::PassThrough,
Sequence<CShuffleMXdlPerWavePerShuffle,
CShuffleNXdlPerWavePerShuffle,
I1,
I1,
M2,
I1,
M4,
I1>,
Sequence<0, 1, 2, 3, 4, 5, 6, 7>,
7,
1,
InMemoryDataOperationEnum::Set,
1,
true>{
c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2,
make_multi_index(0,
0,
m_thread_data_on_block_idx[I1],
n_thread_data_on_block_idx[I1],
m_thread_data_on_block_idx[I2],
m_thread_data_on_block_idx[I3],
m_thread_data_on_block_idx[I4],
n_thread_data_on_block_idx[I2]),
ck::tensor_operation::element_wise::PassThrough{}};
// shuffle: blockwise copy C from LDS to global
auto c_shuffle_block_copy_lds_to_global = BlockwiseTensorSliceTransfer_v6r1<
BlockSize, // index_t BlockSize,
CElementwiseOperation, // ElementwiseOperation,
CGlobalMemoryDataOperation, // DstInMemOp,
Sequence<1,
CShuffleMXdlPerWavePerShuffle * MWave * MPerXdl,
1,
CShuffleNXdlPerWavePerShuffle * NWave * NPerXdl>, // BlockSliceLengths,
CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
Sequence<0, 1, 2, 3>, // typename ThreadClusterArrangeOrder,
FloatCShuffle, // typename SrcData,
FloatC, // typename DstData,
decltype(c_shuffle_block_desc_mblock_mperblock_nblock_nperblock),
decltype(c_grid_desc_mblock_mperblock_nblock_nperblock),
Sequence<0, 1, 2, 3>, // typename DimAccessOrder,
3, // index_t VectorDim,
CShuffleBlockTransferScalarPerVector_NPerBlock, // index_t ScalarPerVector,
true, // bool ThreadTransferSrcResetCoordinateAfterRun,
false> // bool ThreadTransferDstResetCoordinateAfterRun>
{c_shuffle_block_desc_mblock_mperblock_nblock_nperblock,
make_multi_index(0, 0, 0, 0),
c_grid_desc_mblock_mperblock_nblock_nperblock,
make_multi_index(block_work_idx[I0], 0, block_work_idx[I1], 0),
c_element_op};
// space filling curve for threadwise C in VGPR
constexpr auto sfc_c_vgpr =
SpaceFillingCurve<Sequence<MXdlPerWave, NXdlPerWave, 1, 1, M2, 1, M4, 1>,
Sequence<0, 1, 2, 3, 4, 5, 6, 7>,
Sequence<CShuffleMXdlPerWavePerShuffle,
CShuffleNXdlPerWavePerShuffle,
1,
1,
M2,
1,
M4,
1>>{};
// space filling curve for shuffled blockwise C in global mem
constexpr auto sfc_c_global =
SpaceFillingCurve<Sequence<1, MPerBlock, 1, NPerBlock>,
Sequence<0, 2, 1, 3>,
Sequence<1,
CShuffleMXdlPerWavePerShuffle * MWave * MPerXdl,
1,
CShuffleNXdlPerWavePerShuffle * NWave * NPerXdl>>{};
constexpr index_t num_access = sfc_c_vgpr.GetNumOfAccess();
            static_assert(num_access == sfc_c_global.GetNumOfAccess(),
                          "wrong! VGPR and global space-filling curves must have the same "
                          "number of accesses");
static_for<0, num_access, 1>{}([&](auto access_id) {
// make sure it's safe to write to LDS
block_sync_lds();
                // each thread writes its data from VGPR to LDS
c_thread_copy_vgpr_to_lds.Run(c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2,
sfc_c_vgpr.GetIndexTupleOfNumber(access_id),
c_thread_buf,
c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2,
c_shuffle_block_buf);
// make sure it's safe to read from LDS
block_sync_lds();
                // each block copies its data from LDS to global
c_shuffle_block_copy_lds_to_global.Run(
c_shuffle_block_desc_mblock_mperblock_nblock_nperblock,
c_shuffle_block_buf,
c_grid_desc_mblock_mperblock_nblock_nperblock,
c_grid_buf);
if constexpr(access_id < num_access - 1)
{
constexpr auto c_global_step = sfc_c_global.GetForwardStep(access_id);
                    // advance the destination window on C
c_shuffle_block_copy_lds_to_global.MoveDstSliceWindow(
c_grid_desc_mblock_mperblock_nblock_nperblock, c_global_step);
}
});
}
}
};
} // namespace ck