"docs/git@developer.sourcefind.cn:OpenDAS/mmdetection3d.git" did not exist on "53488dad6e63236bd31aa4f6414c2fb12ecdc6d8"
Unverified Commit 9a8ee8a3 authored by Qianfeng, committed by GitHub

Reduction for int8 and bfloat16 (#125)



* Use thread cluster descriptor and explicit M_K 2d descriptor to simplify Blockwise Reduction

* Replace ReduceDims with NumReduceDims as the Device Reduce interface template parameter

* Rename the folders of the pool2d and reduce examples

* Update to reduction test scripts

* Add Readme for pool2d_fwd and reduce_blockwise examples

* Add support for int8_t reduction (ADD/AVG, MIN/MAX/AMAX)

* Tiny fix in reduce profiler and tiny update in reduce testing scripts

* Tiny fix in testing script profile_reduce_no_index.sh

* Tiny fix in testing script profile_reduce_no_index.sh

* Add support for bfp16 reduction (using bhalf_t = ushort); a bf16 conversion sketch follows this commit message

* Tiny fix in amd_buffer_addressing.hpp

* Tiny change in script/profile_reduce_with_index.sh

* Use AccDataType for Beta value and use element_wise::PassThrough

* Use type_convert for type conversion in the host-layer reduction

* Renaming and refining in Reduction profiler/device layer/examples

* Renaming and refining in Reduction profiler/device layer/examples

* Renaming all NumReduceDims to NumReduceDim

* Fix the leaked type_convert in ThreadwiseTensorSliceTransfer_v2

* Update to testing scripts to add bf16 support

* Added more static_assert checks

* Remove buggy tunable configurations defined in device_reduce_instance_xxx.hpp

* Add static_assert to give compile-time warning for incorrect thread slice-size/vector-size configurations

* minor change

* Refine and fix (in GetWorkspaceSizeInBytes of MultiBlockPartialReduce) to make int8 completely pass

* Tiny renaming in gridwise_2d_reduction_multiblock_partial_reduce.hpp

* Tiny fix in script/profile_reduce_no_index.sh

* Refine the DeviceReduce layer with regard to using NumInvariantDim/NumReduceDim instead of InvariantDims/ReduceDims

* Generic renaming in host reduction and DeviceReduce layer

* Add support for 4-d all dimension reduction in the profiler and add_device_reduce_xxx instances

* Use multi-threading and simplify the host Reduction implementation

* Add ctest for reduction

* Update to clarify the use of the data init method in produce_reduce/example_reduce/test_reduce/

* Update to the reduce CTest executables to enable default testing behavior when no command argument is given

* Renaming
Co-authored-by: Jianfeng yan <jfyan008@gmail.com>
parent cb87b049
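The bf16 path in this change stores values in a ushort carrier (bhalf_t = ushort), so reductions widen each element to float before accumulating and narrow the result on the way back out. The snippet below is a minimal, self-contained sketch of that widening and narrowing; the helper names (bf16_to_float, float_to_bf16, reduce_add_bf16) are placeholders for illustration only and are not the library's type_convert implementation, which may round rather than truncate.

```cpp
#include <cstddef>
#include <cstdint>
#include <cstring>

// Hypothetical helpers (not the library's type_convert): bf16 keeps the upper
// 16 bits of an IEEE-754 binary32 value, and bhalf_t is just a ushort carrier.
inline float bf16_to_float(std::uint16_t x)
{
    std::uint32_t bits = static_cast<std::uint32_t>(x) << 16; // low mantissa bits become zero
    float f;
    std::memcpy(&f, &bits, sizeof(f));
    return f;
}

inline std::uint16_t float_to_bf16(float f)
{
    std::uint32_t bits;
    std::memcpy(&bits, &f, sizeof(bits));
    return static_cast<std::uint16_t>(bits >> 16); // truncation; a real conversion may round
}

// Host reference for an ADD reduction over a bf16 buffer, accumulating in float.
inline float reduce_add_bf16(const std::uint16_t* in, std::size_t n)
{
    float acc = 0.0f;
    for(std::size_t i = 0; i < n; ++i)
        acc += bf16_to_float(in[i]);
    return acc;
}
```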
...@@ -37,7 +37,7 @@ cmake \ ...@@ -37,7 +37,7 @@ cmake \
```bash ```bash
# -D <xxx> : input 4-d tensor lengths # -D <xxx> : input 4-d tensor lengths
# -v <x> : verification (0=no, 1=yes) # -v <x> : verification (0=no, 1=yes)
#arg1: initialization (0=no init, 1=integer value, 2=decimal value) #arg1: initialization (0=no init, 1=single integer value, 2=scope integer value, 3=decimal value)
#arg2: run kernel # of times (>1) #arg2: run kernel # of times (>1)
./bin/reduce_blockwise -D 16,64,32,960 -v 1 1 10 ./bin/reduce_blockwise -D 16,64,32,960 -v 1 1 10
``` ```
......
...@@ -13,7 +13,7 @@ ...@@ -13,7 +13,7 @@
#include "device_base.hpp" #include "device_base.hpp"
#include "device_reduce_blockwise.hpp" #include "device_reduce_blockwise.hpp"
#include "host_reduce_util.hpp" #include "host_reduce_util.hpp"
#include "host_generic_reduction.hpp" #include "host_reduction.hpp"
#include "reduction_enums.hpp" #include "reduction_enums.hpp"
#include "reduction_operator_mapping.hpp" #include "reduction_operator_mapping.hpp"
...@@ -21,13 +21,13 @@ ...@@ -21,13 +21,13 @@
using namespace ck; using namespace ck;
using namespace ck::tensor_operation::device; using namespace ck::tensor_operation::device;
using InDataType = half_float::half; using InDataType = ck::half_t;
using OutDataType = half_float::half; using OutDataType = ck::half_t;
using AccDataType = float; using AccDataType = float;
using kInDataType = ck::half_t; using HostInDataType = half_float::half;
using kOutDataType = ck::half_t; using HostOutDataType = half_float::half;
using kAccDataType = float; using HostAccDataType = float;
constexpr int Rank = 4; constexpr int Rank = 4;
constexpr int NumReduceDim = 3; constexpr int NumReduceDim = 3;
...@@ -43,9 +43,9 @@ using InElementwiseOperation = ...@@ -43,9 +43,9 @@ using InElementwiseOperation =
using AccElementwiseOperation = using AccElementwiseOperation =
typename reduce_unary_operator<AccDataType, ReduceOpId, true, true>::AccElementwiseOperation; typename reduce_unary_operator<AccDataType, ReduceOpId, true, true>::AccElementwiseOperation;
using DeviceReduceInstance = DeviceReduceBlockWise<kInDataType, using DeviceReduceInstance = DeviceReduceBlockWise<InDataType,
kAccDataType, AccDataType,
kOutDataType, OutDataType,
Rank, Rank,
NumReduceDim, NumReduceDim,
ReduceOperation, ReduceOperation,
...@@ -135,6 +135,10 @@ class SimpleAppArgs ...@@ -135,6 +135,10 @@ class SimpleAppArgs
std::cout << "--verify or -v, 1/0 to indicate whether to verify the reduction result by " std::cout << "--verify or -v, 1/0 to indicate whether to verify the reduction result by "
"comparing with the host-based reduction" "comparing with the host-based reduction"
<< std::endl; << std::endl;
std::cout << "Arg1 -- init method (0=no init, 1=single integer value, 2=scope integer "
"value, 3=decimal value)"
<< std::endl;
std::cout << "Arg2 -- number of repeats to run the kernel" << std::endl;
}; };
int processArgs(int argc, char* argv[]) int processArgs(int argc, char* argv[])
...@@ -263,20 +267,21 @@ int main(int argc, char* argv[]) ...@@ -263,20 +267,21 @@ int main(int argc, char* argv[])
{ {
switch(args.init_method) switch(args.init_method)
{ {
case 0: case 0: break;
in.GenerateTensorValue(GeneratorTensor_1<InDataType>{}, num_thread); case 1:
in.GenerateTensorValue(GeneratorTensor_1<InDataType>{1}, num_thread);
if(beta != 0.0f) if(beta != 0.0f)
out_ref.GenerateTensorValue(GeneratorTensor_1<InDataType>{}, num_thread); out_ref.GenerateTensorValue(GeneratorTensor_1<InDataType>{1}, num_thread);
break; break;
case 1: case 2:
in.GenerateTensorValue(GeneratorTensor_2<InDataType>{-5, 5}, num_thread); in.GenerateTensorValue(GeneratorTensor_2<InDataType>{-5, 5}, num_thread);
if(beta != 0.0f) if(beta != 0.0f)
out_ref.GenerateTensorValue(GeneratorTensor_2<InDataType>{-5, 5}, num_thread); out_ref.GenerateTensorValue(GeneratorTensor_2<InDataType>{-5, 5}, num_thread);
break; break;
default: default:
in.GenerateTensorValue(GeneratorTensor_2<InDataType>{1, 5}, num_thread); in.GenerateTensorValue(GeneratorTensor_3<InDataType>{-5.0, 5.0}, num_thread);
if(beta != 0.0f) if(beta != 0.0f)
out_ref.GenerateTensorValue(GeneratorTensor_2<InDataType>{1, 5}, num_thread); out_ref.GenerateTensorValue(GeneratorTensor_3<InDataType>{-5.0, 5.0}, num_thread);
} }
if(beta != 0.0f) if(beta != 0.0f)
...@@ -293,17 +298,27 @@ int main(int argc, char* argv[]) ...@@ -293,17 +298,27 @@ int main(int argc, char* argv[])
if(beta != 0.0f) if(beta != 0.0f)
out_dev.ToDevice(out.mData.data()); out_dev.ToDevice(out.mData.data());
size_t indicesSizeInBytes = NeedIndices ? out.mDesc.GetElementSize() * sizeof(int) : 0; size_t indicesSizeInBytes = NeedIndices ? out.mDesc.GetElementSize() * sizeof(int32_t) : 0;
DeviceMem out_indices_dev(indicesSizeInBytes); DeviceMem out_indices_dev(indicesSizeInBytes);
if(args.do_verification) if(args.do_verification)
{ {
ReductionHost<InDataType, AccDataType, OutDataType, ReduceOpId, PropagateNan, NeedIndices> ReductionHost<HostInDataType,
HostAccDataType,
HostOutDataType,
ReduceOpId,
Rank,
NumReduceDim,
PropagateNan,
NeedIndices>
hostReduce(in.mDesc, out_ref.mDesc, invariantDims, reduceDims); hostReduce(in.mDesc, out_ref.mDesc, invariantDims, reduceDims);
hostReduce.Run( hostReduce.Run(alpha,
alpha, in.mData.data(), beta, out_ref.mData.data(), out_indices_ref.mData.data()); reinterpret_cast<const HostInDataType*>(in.mData.data()),
beta,
reinterpret_cast<HostOutDataType*>(out_ref.mData.data()),
out_indices_ref.mData.data());
}; };
const auto i_inLengths = to_int_vector(args.inLengths); const auto i_inLengths = to_int_vector(args.inLengths);
...@@ -313,7 +328,7 @@ int main(int argc, char* argv[]) ...@@ -313,7 +328,7 @@ int main(int argc, char* argv[])
auto reduce = DeviceReduceInstance{}; auto reduce = DeviceReduceInstance{};
auto wsSizeInBytes = reduce.GetWorkspaceSizeInBytes(i_inLengths); auto wsSizeInBytes = reduce.GetWorkspaceSizeInBytes(i_inLengths, reduceDims);
DeviceMem ws_dev(wsSizeInBytes); DeviceMem ws_dev(wsSizeInBytes);
......
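In the updated example above, alpha and beta are both applied in AccDataType and data is converted only at the input and output boundaries, while verification goes through ReductionHost with separate host-side types. Below is a single-threaded sketch of that reference pattern, assuming a plain value conversion between types; the template and parameter names are placeholders and it does not reproduce ReductionHost's multi-threaded implementation.

```cpp
#include <cstddef>
#include <vector>

// Illustrative reference: out[m] = alpha * reduce_k(op, in[m*K+k]) + beta * out[m].
// InT/OutT/AccT stand in for InDataType/OutDataType/AccDataType.
template <typename InT, typename OutT, typename AccT, typename Op>
void reference_reduce(const std::vector<InT>& in,
                      std::vector<OutT>& out,
                      std::size_t invariant_len,
                      std::size_t reduce_len,
                      AccT alpha,
                      AccT beta,
                      Op op,
                      AccT identity)
{
    for(std::size_t m = 0; m < invariant_len; ++m)
    {
        AccT acc = identity;
        for(std::size_t k = 0; k < reduce_len; ++k)
            acc = op(acc, static_cast<AccT>(in[m * reduce_len + k])); // widen before accumulating

        acc *= alpha;
        if(beta != static_cast<AccT>(0))
            acc += beta * static_cast<AccT>(out[m]); // beta applied in AccDataType, not OutDataType

        out[m] = static_cast<OutT>(acc); // narrow only when writing the result
    }
}
```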
...@@ -36,7 +36,7 @@ cmake \ ...@@ -36,7 +36,7 @@ cmake \
## Run ```pool2d_fwd``` ## Run ```pool2d_fwd```
```bash ```bash
#arg1: verification (0=no, 1=yes) #arg1: verification (0=no, 1=yes)
#arg2: initialization (0=no init, 1=integer value, 2=decimal value) #arg2: initialization (0=no init, 1=single integer value, 2=scope integer value, 3=decimal value)
#arg3: run kernel # of times (>1) #arg3: run kernel # of times (>1)
#arg4 to 15: N, C, Y, X, Hi, Wi, Sy, Sx, LeftPy, LeftPx, RightPy, RightPx #arg4 to 15: N, C, Y, X, Hi, Wi, Sy, Sx, LeftPy, LeftPx, RightPy, RightPx
./example/pool2d_fwd 1 1 10 ./example/pool2d_fwd 1 1 10
......
...@@ -236,8 +236,9 @@ int main(int argc, char* argv[]) ...@@ -236,8 +236,9 @@ int main(int argc, char* argv[])
switch(init_method) switch(init_method)
{ {
case 0: break; case 0: break;
case 1: in_n_c_hi_wi.GenerateTensorValue(GeneratorTensor_2<InDataType>{-5, 5}); break; case 1: in_n_c_hi_wi.GenerateTensorValue(GeneratorTensor_1<InDataType>{1}); break;
default: in_n_c_hi_wi.GenerateTensorValue(GeneratorTensor_3<InDataType>{0.0, 1.0}); case 2: in_n_c_hi_wi.GenerateTensorValue(GeneratorTensor_2<InDataType>{-5, 5}); break;
default: in_n_c_hi_wi.GenerateTensorValue(GeneratorTensor_3<InDataType>{-5.0, 5.0});
} }
DeviceMem in_device_buf(sizeof(InDataType) * in_n_c_hi_wi.mDesc.GetElementSpace()); DeviceMem in_device_buf(sizeof(InDataType) * in_n_c_hi_wi.mDesc.GetElementSpace());
......
...@@ -16,9 +16,11 @@ namespace device { ...@@ -16,9 +16,11 @@ namespace device {
template <typename InElementwiseOperation, typename AccElementwiseOperation> template <typename InElementwiseOperation, typename AccElementwiseOperation>
struct DeviceReduce : public BaseOperator struct DeviceReduce : public BaseOperator
{ {
virtual size_t GetWorkspaceSizeInBytes(const std::vector<int>& inLengths) virtual long_index_t GetWorkspaceSizeInBytes(const std::vector<int> inLengths,
const std::vector<int> reduceDims)
{ {
(void)inLengths; (void)inLengths;
(void)reduceDims;
return (0); return (0);
}; };
...@@ -32,19 +34,19 @@ struct DeviceReduce : public BaseOperator ...@@ -32,19 +34,19 @@ struct DeviceReduce : public BaseOperator
}; };
virtual std::unique_ptr<BaseArgument> virtual std::unique_ptr<BaseArgument>
MakeArgumentPointer(const std::vector<int>& inLengths, MakeArgumentPointer(const std::vector<int> inLengths,
const std::vector<int>& inStrides, const std::vector<int> inStrides,
const std::vector<int>& outLengths, const std::vector<int> outLengths,
const std::vector<int>& outStrides, const std::vector<int> outStrides,
const std::vector<int>& reduceDims, const std::vector<int> reduceDims,
float alpha, float alpha,
float beta, float beta,
const void* in_dev, const void* in_dev,
void* out_dev, void* out_dev,
void* out_indices_dev, void* out_indices_dev,
void* workspace_dev, void* workspace_dev,
const InElementwiseOperation& in_elementwise_op, const InElementwiseOperation in_elementwise_op,
const AccElementwiseOperation& acc_elementwise_op) = 0; const AccElementwiseOperation acc_elementwise_op) = 0;
virtual std::unique_ptr<BaseInvoker> MakeInvokerPointer() = 0; virtual std::unique_ptr<BaseInvoker> MakeInvokerPointer() = 0;
}; };
......
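With the interface above, a caller now passes reduceDims both when querying the workspace size and when building the argument, so the two computations shuffle the tensor lengths consistently. A condensed caller-side sketch, modeled on the reduce_blockwise example earlier in this diff, follows; the default-constructed element-wise operations only hold for simple ops such as ADD, the buffers come from the example, and the invoker's Run() signature is assumed from other CK device operations rather than taken from this diff.

```cpp
// Sketch only: types, buffers and the invoker's Run() are assumed from the example above.
std::vector<int> inLengths{16, 64, 32, 960};
std::vector<int> inStrides{64 * 32 * 960, 32 * 960, 960, 1}; // packed strides
std::vector<int> outLengths{16};
std::vector<int> outStrides{1};
std::vector<int> reduceDims{1, 2, 3};

auto reduce = DeviceReduceInstance{};

// The workspace query now needs reduceDims so it can shuffle the lengths
// exactly the way the Argument constructor does before sizing the 2d problem.
auto wsSizeInBytes = reduce.GetWorkspaceSizeInBytes(inLengths, reduceDims);
DeviceMem ws_dev(wsSizeInBytes);

auto argument_ptr = reduce.MakeArgumentPointer(inLengths,
                                               inStrides,
                                               outLengths,
                                               outStrides,
                                               reduceDims,
                                               1.0f, // alpha
                                               0.0f, // beta
                                               in_dev.GetDeviceBuffer(),  // from the example
                                               out_dev.GetDeviceBuffer(), // from the example
                                               nullptr, // no indices requested here
                                               ws_dev.GetDeviceBuffer(),
                                               InElementwiseOperation{},
                                               AccElementwiseOperation{});

auto invoker_ptr = reduce.MakeInvokerPointer();
invoker_ptr->Run(argument_ptr.get()); // assumed entry point; check BaseInvoker in device_base.hpp
```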
...@@ -36,20 +36,20 @@ struct DeviceReduceBlockWise : public DeviceReduce<InElementwiseOperation, AccEl ...@@ -36,20 +36,20 @@ struct DeviceReduceBlockWise : public DeviceReduce<InElementwiseOperation, AccEl
static_assert(BlockSize == MThreadClusterSize * KThreadClusterSize, static_assert(BlockSize == MThreadClusterSize * KThreadClusterSize,
"Invalid thread cluster size assignments!"); "Invalid thread cluster size assignments!");
static_assert(((InSrcVectorDim == 0 && MThreadSliceSize % InSrcVectorSize == 0) ||
(InSrcVectorDim == 1 && KThreadSliceSize % InSrcVectorSize == 0)) &&
(MThreadSliceSize % OutDstVectorSize == 0),
"Invalid thread slice sizes and/or vector sizes configuration, please check!");
using IndexDataType = int32_t; using IndexDataType = int32_t;
static constexpr bool BetaIsZero = NeedIndices; static constexpr bool BetaIsZero = NeedIndices;
static constexpr index_t NumInvariantDim = Rank - NumReduceDim; static constexpr index_t NumInvariantDim = Rank - NumReduceDim;
using InvariantDims =
typename conditional<NumInvariantDim == 0,
Sequence<>,
typename arithmetic_sequence_gen<0, NumInvariantDim, 1>::type>::type;
using ReduceDims = typename arithmetic_sequence_gen<NumInvariantDim, Rank, 1>::type;
static constexpr index_t srcDims = Rank; static constexpr index_t numSrcDim = Rank;
static constexpr index_t dstDims = (InvariantDims::Size() == 0) ? 1 : InvariantDims::Size(); static constexpr index_t numDstDim = (NumInvariantDim == 0) ? 1 : NumInvariantDim;
static constexpr bool reduceAllDims = (InvariantDims::Size() == 0); static constexpr bool reduceAllDim = (NumInvariantDim == 0);
static constexpr int M_BlockTileSize = MThreadClusterSize * MThreadSliceSize; static constexpr int M_BlockTileSize = MThreadClusterSize * MThreadSliceSize;
static constexpr int K_BlockTileSize = KThreadClusterSize * KThreadSliceSize; static constexpr int K_BlockTileSize = KThreadClusterSize * KThreadSliceSize;
...@@ -57,18 +57,18 @@ struct DeviceReduceBlockWise : public DeviceReduce<InElementwiseOperation, AccEl ...@@ -57,18 +57,18 @@ struct DeviceReduceBlockWise : public DeviceReduce<InElementwiseOperation, AccEl
static auto MakeSrc2dDescriptor(const std::vector<int>& inLengths, static auto MakeSrc2dDescriptor(const std::vector<int>& inLengths,
const std::vector<int>& inStrides) const std::vector<int>& inStrides)
{ {
const auto tupleSrcLengths = make_tuple_from_array(inLengths, Number<srcDims>{}); const auto tupleSrcLengths = make_tuple_from_array(inLengths, Number<numSrcDim>{});
const auto tupleSrcStrides = make_tuple_from_array(inStrides, Number<srcDims>{}); const auto tupleSrcStrides = make_tuple_from_array(inStrides, Number<numSrcDim>{});
const auto inDesc = make_naive_tensor_descriptor(tupleSrcLengths, tupleSrcStrides); const auto inDesc = make_naive_tensor_descriptor(tupleSrcLengths, tupleSrcStrides);
const auto in_grid_desc_m_k = [&]() { const auto in_grid_desc_m_k = [&]() {
if constexpr(reduceAllDims) if constexpr(reduceAllDim)
{ {
const auto one_dim_inDesc = transform_tensor_descriptor( const auto one_dim_inDesc = transform_tensor_descriptor(
inDesc, inDesc,
make_tuple(make_merge_transform(tupleSrcLengths)), make_tuple(make_merge_transform(tupleSrcLengths)),
make_tuple(typename arithmetic_sequence_gen<0, srcDims, 1>::type{}), make_tuple(typename arithmetic_sequence_gen<0, numSrcDim, 1>::type{}),
make_tuple(Sequence<0>{})); make_tuple(Sequence<0>{}));
return transform_tensor_descriptor(one_dim_inDesc, return transform_tensor_descriptor(one_dim_inDesc,
...@@ -79,6 +79,9 @@ struct DeviceReduceBlockWise : public DeviceReduce<InElementwiseOperation, AccEl ...@@ -79,6 +79,9 @@ struct DeviceReduceBlockWise : public DeviceReduce<InElementwiseOperation, AccEl
} }
else else
{ {
using InvariantDims = typename arithmetic_sequence_gen<0, NumInvariantDim, 1>::type;
using ReduceDims = typename arithmetic_sequence_gen<NumInvariantDim, Rank, 1>::type;
const auto reduceDimLengths = const auto reduceDimLengths =
make_tuple_from_array_and_index_seq(inLengths, ReduceDims{}); make_tuple_from_array_and_index_seq(inLengths, ReduceDims{});
const auto invariantDimLengths = const auto invariantDimLengths =
...@@ -93,18 +96,20 @@ struct DeviceReduceBlockWise : public DeviceReduce<InElementwiseOperation, AccEl ...@@ -93,18 +96,20 @@ struct DeviceReduceBlockWise : public DeviceReduce<InElementwiseOperation, AccEl
} }
}(); }();
const auto outerLen = in_grid_desc_m_k.GetLength(Number<0>{}); const auto invariantLength = in_grid_desc_m_k.GetLength(Number<0>{});
const auto innerLen = in_grid_desc_m_k.GetLength(Number<1>{}); const auto reduceLength = in_grid_desc_m_k.GetLength(Number<1>{});
const auto inPad_M = math::integer_least_multiple(outerLen, M_BlockTileSize) - outerLen; const auto inPad_M =
const auto inPad_K = math::integer_least_multiple(innerLen, K_BlockTileSize) - innerLen; math::integer_least_multiple(invariantLength, M_BlockTileSize) - invariantLength;
const auto inPad_K =
math::integer_least_multiple(reduceLength, K_BlockTileSize) - reduceLength;
auto in_grid_desc_m_k_padded = auto in_grid_desc_m_k_padded = transform_tensor_descriptor(
transform_tensor_descriptor(in_grid_desc_m_k, in_grid_desc_m_k,
make_tuple(make_right_pad_transform(outerLen, inPad_M), make_tuple(make_right_pad_transform(invariantLength, inPad_M),
make_right_pad_transform(innerLen, inPad_K)), make_right_pad_transform(reduceLength, inPad_K)),
make_tuple(Sequence<0>{}, Sequence<1>{}), make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<0>{}, Sequence<1>{})); make_tuple(Sequence<0>{}, Sequence<1>{}));
return (in_grid_desc_m_k_padded); return (in_grid_desc_m_k_padded);
}; };
...@@ -112,44 +117,45 @@ struct DeviceReduceBlockWise : public DeviceReduce<InElementwiseOperation, AccEl ...@@ -112,44 +117,45 @@ struct DeviceReduceBlockWise : public DeviceReduce<InElementwiseOperation, AccEl
static auto MakeDst1dDescriptor(const std::vector<int>& outLengths, static auto MakeDst1dDescriptor(const std::vector<int>& outLengths,
const std::vector<int>& outStrides) const std::vector<int>& outStrides)
{ {
const auto tupleDstLengths = make_tuple_from_array(outLengths, Number<dstDims>{}); const auto tupleDstLengths = make_tuple_from_array(outLengths, Number<numDstDim>{});
const auto tupleDstStrides = make_tuple_from_array(outStrides, Number<dstDims>{}); const auto tupleDstStrides = make_tuple_from_array(outStrides, Number<numDstDim>{});
auto outDesc = make_naive_tensor_descriptor(tupleDstLengths, tupleDstStrides); auto outDesc = make_naive_tensor_descriptor(tupleDstLengths, tupleDstStrides);
auto out_grid_desc_m = transform_tensor_descriptor( auto out_grid_desc_m = transform_tensor_descriptor(
outDesc, outDesc,
make_tuple(make_merge_transform(tupleDstLengths)), make_tuple(make_merge_transform(tupleDstLengths)),
make_tuple(typename arithmetic_sequence_gen<0, dstDims, 1>::type{}), make_tuple(typename arithmetic_sequence_gen<0, numDstDim, 1>::type{}),
make_tuple(Sequence<0>{})); make_tuple(Sequence<0>{}));
const auto outerLen = out_grid_desc_m.GetLength(Number<0>{}); const auto invariantLength = out_grid_desc_m.GetLength(Number<0>{});
const auto inPad = math::integer_least_multiple(outerLen, M_BlockTileSize) - outerLen; const auto inPad =
math::integer_least_multiple(invariantLength, M_BlockTileSize) - invariantLength;
auto out_grid_desc_m_padded = auto out_grid_desc_m_padded = transform_tensor_descriptor(
transform_tensor_descriptor(out_grid_desc_m, out_grid_desc_m,
make_tuple(make_right_pad_transform(outerLen, inPad)), make_tuple(make_right_pad_transform(invariantLength, inPad)),
make_tuple(Sequence<0>{}), make_tuple(Sequence<0>{}),
make_tuple(Sequence<0>{})); make_tuple(Sequence<0>{}));
return (out_grid_desc_m_padded); return (out_grid_desc_m_padded);
}; };
struct Argument : public BaseArgument struct Argument : public BaseArgument
{ {
Argument(const std::vector<int>& inLengths, Argument(const std::vector<int> inLengths,
const std::vector<int>& inStrides, const std::vector<int> inStrides,
const std::vector<int>& outLengths, const std::vector<int> outLengths,
const std::vector<int>& outStrides, const std::vector<int> outStrides,
const std::vector<int>& reduceDims, const std::vector<int> reduceDims,
float alpha, float alpha,
float beta, float beta,
const InDataType* in_dev, const InDataType* in_dev,
OutDataType* out_dev, OutDataType* out_dev,
IndexDataType* out_indices_dev, IndexDataType* out_indices_dev,
AccDataType* workspace_dev, AccDataType* workspace_dev,
const InElementwiseOperation& in_elementwise_op, const InElementwiseOperation in_elementwise_op,
const AccElementwiseOperation& acc_elementwise_op) const AccElementwiseOperation acc_elementwise_op)
: outLengths_{outLengths}, : outLengths_{outLengths},
outStrides_{outStrides}, outStrides_{outStrides},
in_dev_{in_dev}, in_dev_{in_dev},
...@@ -160,21 +166,21 @@ struct DeviceReduceBlockWise : public DeviceReduce<InElementwiseOperation, AccEl ...@@ -160,21 +166,21 @@ struct DeviceReduceBlockWise : public DeviceReduce<InElementwiseOperation, AccEl
{ {
(void)workspace_dev; (void)workspace_dev;
std::tie(inLengths_, inStrides_) = inLengths_ = shuffle_tensor_dimensions<Rank, NumReduceDim>(inLengths, reduceDims);
shuffle_tensor_dimensions<Rank, NumReduceDim>(inLengths, inStrides, reduceDims); inStrides_ = shuffle_tensor_dimensions<Rank, NumReduceDim>(inStrides, reduceDims);
alpha_ = static_cast<AccDataType>(alpha); alpha_ = type_convert<AccDataType>(alpha);
beta_ = static_cast<OutDataType>(beta); beta_ = type_convert<AccDataType>(beta);
std::tie(invariant_total_length, reduce_total_length) = std::tie(invariant_total_length, reduce_total_length) =
get_2d_lengths<Rank, ReduceDims>(inLengths_); get_2d_lengths<Rank, NumReduceDim>(inLengths_);
if constexpr(InvariantDims::Size() == 0) if constexpr(NumInvariantDim == 0)
invariant_lowest_length = 1; invariant_lowest_length = 1;
else else
invariant_lowest_length = inLengths_[InvariantDims::At(InvariantDims::Size() - 1)]; invariant_lowest_length = inLengths_[NumInvariantDim - 1];
reduce_lowest_length = inLengths_[ReduceDims::At(ReduceDims::Size() - 1)]; reduce_lowest_length = inLengths_[Rank - 1];
gridSize = math::integer_least_multiple(invariant_total_length, M_BlockTileSize) / gridSize = math::integer_least_multiple(invariant_total_length, M_BlockTileSize) /
M_BlockTileSize; M_BlockTileSize;
...@@ -186,7 +192,7 @@ struct DeviceReduceBlockWise : public DeviceReduce<InElementwiseOperation, AccEl ...@@ -186,7 +192,7 @@ struct DeviceReduceBlockWise : public DeviceReduce<InElementwiseOperation, AccEl
std::vector<int> outStrides_; std::vector<int> outStrides_;
AccDataType alpha_; AccDataType alpha_;
OutDataType beta_; AccDataType beta_;
const InDataType* in_dev_; const InDataType* in_dev_;
OutDataType* out_dev_; OutDataType* out_dev_;
...@@ -278,18 +284,22 @@ struct DeviceReduceBlockWise : public DeviceReduce<InElementwiseOperation, AccEl ...@@ -278,18 +284,22 @@ struct DeviceReduceBlockWise : public DeviceReduce<InElementwiseOperation, AccEl
if constexpr(InSrcVectorDim == 0) if constexpr(InSrcVectorDim == 0)
{ {
if constexpr(InvariantDims::Size() == 0) if constexpr(NumInvariantDim == 0)
return (false); {
if(pArg->inStrides_[InvariantDims::At(InvariantDims::Size() - 1)] != 1)
return (false); return (false);
}
else
{
if(pArg->inStrides_[NumInvariantDim - 1] != 1)
return (false);
if(pArg->invariant_lowest_length % InSrcVectorSize != 0) if(pArg->invariant_lowest_length % InSrcVectorSize != 0)
return (false); return (false);
};
} }
else else
{ {
if(pArg->inStrides_[ReduceDims::At(ReduceDims::Size() - 1)] != 1) if(pArg->inStrides_[Rank - 1] != 1)
return (false); return (false);
if(pArg->reduce_lowest_length % InSrcVectorSize != 0) if(pArg->reduce_lowest_length % InSrcVectorSize != 0)
...@@ -308,19 +318,19 @@ struct DeviceReduceBlockWise : public DeviceReduce<InElementwiseOperation, AccEl ...@@ -308,19 +318,19 @@ struct DeviceReduceBlockWise : public DeviceReduce<InElementwiseOperation, AccEl
}; };
std::unique_ptr<BaseArgument> std::unique_ptr<BaseArgument>
MakeArgumentPointer(const std::vector<int>& inLengths, MakeArgumentPointer(const std::vector<int> inLengths,
const std::vector<int>& inStrides, const std::vector<int> inStrides,
const std::vector<int>& outLengths, const std::vector<int> outLengths,
const std::vector<int>& outStrides, const std::vector<int> outStrides,
const std::vector<int>& reduceDims, const std::vector<int> reduceDims,
float alpha, float alpha,
float beta, float beta,
const void* in_dev, const void* in_dev,
void* out_dev, void* out_dev,
void* out_indices_dev, void* out_indices_dev,
void* workspace_dev, void* workspace_dev,
const InElementwiseOperation& in_elementwise_op, const InElementwiseOperation in_elementwise_op,
const AccElementwiseOperation& acc_elementwise_op) override const AccElementwiseOperation acc_elementwise_op) override
{ {
return std::make_unique<Argument>(inLengths, return std::make_unique<Argument>(inLengths,
inStrides, inStrides,
......
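The renamed descriptor code above pads the invariant (M) and reduce (K) lengths up to multiples of the block tile sizes before tiling. The padding is a plain round-up; the standalone sketch below uses a local stand-in for math::integer_least_multiple and made-up tunable values.

```cpp
#include <cassert>

// Stand-in for math::integer_least_multiple: smallest multiple of `align` that is >= x.
constexpr int integer_least_multiple(int x, int align) { return ((x + align - 1) / align) * align; }

int main()
{
    constexpr int MThreadClusterSize = 8, MThreadSliceSize = 4;  // example tunables
    constexpr int KThreadClusterSize = 32, KThreadSliceSize = 8;
    constexpr int M_BlockTileSize = MThreadClusterSize * MThreadSliceSize; // 32
    constexpr int K_BlockTileSize = KThreadClusterSize * KThreadSliceSize; // 256

    int invariantLength = 1000, reduceLength = 600;

    int inPad_M = integer_least_multiple(invariantLength, M_BlockTileSize) - invariantLength; // 24
    int inPad_K = integer_least_multiple(reduceLength, K_BlockTileSize) - reduceLength;       // 168

    assert(inPad_M == 24 && inPad_K == 168);
    return 0;
}
```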
...@@ -37,6 +37,10 @@ struct DeviceReduceBlockWiseSecondCall ...@@ -37,6 +37,10 @@ struct DeviceReduceBlockWiseSecondCall
static_assert(BlockSize == MThreadClusterSize * KThreadClusterSize, static_assert(BlockSize == MThreadClusterSize * KThreadClusterSize,
"Invalid thread cluster size assignments!"); "Invalid thread cluster size assignments!");
static_assert((InSrcVectorDim == 1 && KThreadSliceSize % InSrcVectorSize == 0) &&
(MThreadSliceSize % OutDstVectorSize == 0),
"Invalid thread slice sizes and/or vector sizes configuration, please check!");
using IndexDataType = int32_t; using IndexDataType = int32_t;
static constexpr bool BetaIsZero = NeedIndices; static constexpr bool BetaIsZero = NeedIndices;
...@@ -46,12 +50,8 @@ struct DeviceReduceBlockWiseSecondCall ...@@ -46,12 +50,8 @@ struct DeviceReduceBlockWiseSecondCall
"InDataType and AccDataType should be the same to use DEviceReduceBlockWiseSecondCall!"); "InDataType and AccDataType should be the same to use DEviceReduceBlockWiseSecondCall!");
static constexpr index_t NumInvariantDim = Rank - NumReduceDim; static constexpr index_t NumInvariantDim = Rank - NumReduceDim;
using InvariantDims =
typename conditional<NumInvariantDim == 0,
Sequence<>,
typename arithmetic_sequence_gen<0, NumInvariantDim, 1>::type>::type;
static constexpr index_t dstDims = (InvariantDims::Size() == 0) ? 1 : InvariantDims::Size(); static constexpr index_t numDstDim = (NumInvariantDim == 0) ? 1 : NumInvariantDim;
static constexpr int M_BlockTileSize = MThreadClusterSize * MThreadSliceSize; static constexpr int M_BlockTileSize = MThreadClusterSize * MThreadSliceSize;
static constexpr int K_BlockTileSize = KThreadClusterSize * KThreadSliceSize; static constexpr int K_BlockTileSize = KThreadClusterSize * KThreadSliceSize;
...@@ -65,18 +65,20 @@ struct DeviceReduceBlockWiseSecondCall ...@@ -65,18 +65,20 @@ struct DeviceReduceBlockWiseSecondCall
const auto in_grid_desc_m_k = const auto in_grid_desc_m_k =
make_naive_tensor_descriptor(tupleSrcLengths, tupleSrcStrides); make_naive_tensor_descriptor(tupleSrcLengths, tupleSrcStrides);
const auto outerLen = in_grid_desc_m_k.GetLength(Number<0>{}); const auto invariantLength = in_grid_desc_m_k.GetLength(Number<0>{});
const auto innerLen = in_grid_desc_m_k.GetLength(Number<1>{}); const auto reduceLength = in_grid_desc_m_k.GetLength(Number<1>{});
const auto inPad_M = math::integer_least_multiple(outerLen, M_BlockTileSize) - outerLen; const auto inPad_M =
const auto inPad_K = math::integer_least_multiple(innerLen, K_BlockTileSize) - innerLen; math::integer_least_multiple(invariantLength, M_BlockTileSize) - invariantLength;
const auto inPad_K =
math::integer_least_multiple(reduceLength, K_BlockTileSize) - reduceLength;
auto in_grid_desc_m_k_padded = auto in_grid_desc_m_k_padded = transform_tensor_descriptor(
transform_tensor_descriptor(in_grid_desc_m_k, in_grid_desc_m_k,
make_tuple(make_right_pad_transform(outerLen, inPad_M), make_tuple(make_right_pad_transform(invariantLength, inPad_M),
make_right_pad_transform(innerLen, inPad_K)), make_right_pad_transform(reduceLength, inPad_K)),
make_tuple(Sequence<0>{}, Sequence<1>{}), make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<0>{}, Sequence<1>{})); make_tuple(Sequence<0>{}, Sequence<1>{}));
return (in_grid_desc_m_k_padded); return (in_grid_desc_m_k_padded);
}; };
...@@ -84,26 +86,27 @@ struct DeviceReduceBlockWiseSecondCall ...@@ -84,26 +86,27 @@ struct DeviceReduceBlockWiseSecondCall
static auto MakeDst1dDescriptor(const std::vector<int>& outLengths, static auto MakeDst1dDescriptor(const std::vector<int>& outLengths,
const std::vector<int>& outStrides) const std::vector<int>& outStrides)
{ {
const auto tupleDstLengths = make_tuple_from_array(outLengths, Number<dstDims>{}); const auto tupleDstLengths = make_tuple_from_array(outLengths, Number<numDstDim>{});
const auto tupleDstStrides = make_tuple_from_array(outStrides, Number<dstDims>{}); const auto tupleDstStrides = make_tuple_from_array(outStrides, Number<numDstDim>{});
auto outDesc = make_naive_tensor_descriptor(tupleDstLengths, tupleDstStrides); auto outDesc = make_naive_tensor_descriptor(tupleDstLengths, tupleDstStrides);
auto out_grid_desc_m = transform_tensor_descriptor( auto out_grid_desc_m = transform_tensor_descriptor(
outDesc, outDesc,
make_tuple(make_merge_transform(tupleDstLengths)), make_tuple(make_merge_transform(tupleDstLengths)),
make_tuple(typename arithmetic_sequence_gen<0, dstDims, 1>::type{}), make_tuple(typename arithmetic_sequence_gen<0, numDstDim, 1>::type{}),
make_tuple(Sequence<0>{})); make_tuple(Sequence<0>{}));
const auto outerLen = out_grid_desc_m.GetLength(Number<0>{}); const auto invariantLength = out_grid_desc_m.GetLength(Number<0>{});
const auto outPad = math::integer_least_multiple(outerLen, M_BlockTileSize) - outerLen; const auto outPad =
math::integer_least_multiple(invariantLength, M_BlockTileSize) - invariantLength;
auto out_grid_desc_m_padded = auto out_grid_desc_m_padded = transform_tensor_descriptor(
transform_tensor_descriptor(out_grid_desc_m, out_grid_desc_m,
make_tuple(make_right_pad_transform(outerLen, outPad)), make_tuple(make_right_pad_transform(invariantLength, outPad)),
make_tuple(Sequence<0>{}), make_tuple(Sequence<0>{}),
make_tuple(Sequence<0>{})); make_tuple(Sequence<0>{}));
return (out_grid_desc_m_padded); return (out_grid_desc_m_padded);
}; };
...@@ -131,8 +134,8 @@ struct DeviceReduceBlockWiseSecondCall ...@@ -131,8 +134,8 @@ struct DeviceReduceBlockWiseSecondCall
in_elementwise_op_(in_elementwise_op), in_elementwise_op_(in_elementwise_op),
acc_elementwise_op_(acc_elementwise_op) acc_elementwise_op_(acc_elementwise_op)
{ {
alpha_ = static_cast<AccDataType>(alpha); alpha_ = type_convert<AccDataType>(alpha);
beta_ = static_cast<OutDataType>(beta); beta_ = type_convert<AccDataType>(beta);
invariant_total_length = inLengths[0]; invariant_total_length = inLengths[0];
reduce_total_length = inLengths[1]; reduce_total_length = inLengths[1];
...@@ -159,7 +162,7 @@ struct DeviceReduceBlockWiseSecondCall ...@@ -159,7 +162,7 @@ struct DeviceReduceBlockWiseSecondCall
std::vector<int> outStrides_; std::vector<int> outStrides_;
AccDataType alpha_; AccDataType alpha_;
OutDataType beta_; AccDataType beta_;
const InDataType* in_dev_; const InDataType* in_dev_;
OutDataType* out_dev_; OutDataType* out_dev_;
...@@ -268,19 +271,19 @@ struct DeviceReduceBlockWiseSecondCall ...@@ -268,19 +271,19 @@ struct DeviceReduceBlockWiseSecondCall
}; };
std::unique_ptr<BaseArgument> std::unique_ptr<BaseArgument>
MakeArgumentPointer(const std::vector<int>& inLengths, MakeArgumentPointer(const std::vector<int> inLengths,
const std::vector<int>& inStrides, const std::vector<int> inStrides,
const std::vector<int>& outLengths, const std::vector<int> outLengths,
const std::vector<int>& outStrides, const std::vector<int> outStrides,
const std::vector<int>& reduceDims, const std::vector<int> reduceDims,
float alpha, float alpha,
float beta, float beta,
const void* in_dev, const void* in_dev,
void* out_dev, void* out_dev,
void* out_indices_dev, void* out_indices_dev,
void* workspace_dev, void* workspace_dev,
const InElementwiseOperation& in_elementwise_op, const InElementwiseOperation in_elementwise_op,
const AccElementwiseOperation& acc_elementwise_op) override const AccElementwiseOperation acc_elementwise_op) override
{ {
(void)reduceDims; (void)reduceDims;
......
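Both this header and device_reduce_blockwise.hpp now reject mismatched slice-size/vector-size combinations at compile time instead of generating broken kernels: the thread slice along the vector-load dimension must be divisible by InSrcVectorSize, and the M slice by OutDstVectorSize. A stripped-down illustration of the same guard, with hypothetical tunable values, is below.

```cpp
// Each thread reads InSrcVectorSize elements per vector load along InSrcVectorDim
// (0 = invariant/M dim, 1 = reduce/K dim), so the slice size along that dim must
// be divisible by the vector size; likewise for vectorized stores of the output.
template <int InSrcVectorDim,
          int MThreadSliceSize,
          int KThreadSliceSize,
          int InSrcVectorSize,
          int OutDstVectorSize>
struct CheckSliceVectorConfig
{
    static_assert(((InSrcVectorDim == 0 && MThreadSliceSize % InSrcVectorSize == 0) ||
                   (InSrcVectorDim == 1 && KThreadSliceSize % InSrcVectorSize == 0)) &&
                      (MThreadSliceSize % OutDstVectorSize == 0),
                  "Invalid thread slice sizes and/or vector sizes configuration, please check!");
};

// OK: vector loads of 4 along K (8 % 4 == 0), vector stores of 2 along M (4 % 2 == 0).
template struct CheckSliceVectorConfig<1, 4, 8, 4, 2>;
// Uncommenting the line below would fail to compile: 8 % 3 != 0.
// template struct CheckSliceVectorConfig<1, 4, 8, 3, 2>;
```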
...@@ -12,38 +12,30 @@ namespace ck { ...@@ -12,38 +12,30 @@ namespace ck {
namespace tensor_operation { namespace tensor_operation {
namespace device { namespace device {
// template <typename preUnaryOpType, typename posUnaryOpType> // here, inLengths[] is already shuffled so that lengths of invariant dims are included before those
// using DeviceReducePtr = std::unique_ptr<DeviceReduce<preUnaryOpType, posUnaryOpType>>; // of reduce dims
template <int Rank, int NumReduceDim>
template <int Rank, typename ReduceDims>
std::pair<size_t, size_t> get_2d_lengths(const std::vector<int>& inLengths) std::pair<size_t, size_t> get_2d_lengths(const std::vector<int>& inLengths)
{ {
static_assert(Rank <= 6, "bigger Rank size not supported!"); static_assert(Rank <= 6, "bigger Rank size not supported!");
size_t tensor_total_length = 1; size_t invariant_total_length = 1;
size_t reduce_total_length = 1; size_t reduce_total_length = 1;
static_for<0, ReduceDims::Size(), 1>{}(
[&](auto i) { reduce_total_length *= inLengths[ReduceDims::At(i)]; });
static_for<0, Rank, 1>{}([&](auto i) { tensor_total_length *= inLengths[i.value]; }); constexpr int NumInvariantDim = Rank - NumReduceDim;
return std::make_pair(tensor_total_length / reduce_total_length, reduce_total_length); for(int i = NumInvariantDim; i < Rank; i++)
}; reduce_total_length *= inLengths[i];
template <int x, typename Seq>
constexpr bool belong()
{
bool inside = false;
static_for<0, Seq::Size(), 1>{}([&](auto i) { inside = (inside || (x == Seq::At(i))); }); for(int i = 0; i < NumInvariantDim; i++)
invariant_total_length *= inLengths[i];
return (inside); return std::make_pair(invariant_total_length, reduce_total_length);
}; };
// helper functions using variadic template arguments // helper functions using variadic template arguments
template <index_t... Ns> template <index_t... Ns>
static auto make_tuple_from_array_and_index_seq(const std::vector<int>& lengths, Sequence<Ns...>) auto make_tuple_from_array_and_index_seq(const std::vector<int>& lengths, Sequence<Ns...>)
{ {
return make_tuple(static_cast<index_t>(lengths[Ns])...); return make_tuple(static_cast<index_t>(lengths[Ns])...);
}; };
...@@ -59,16 +51,12 @@ static auto make_tuple_from_array(const std::vector<int>& lengths, Number<arrayS ...@@ -59,16 +51,12 @@ static auto make_tuple_from_array(const std::vector<int>& lengths, Number<arrayS
}; };
template <index_t Rank, index_t NumReduceDim> template <index_t Rank, index_t NumReduceDim>
static inline std::pair<std::vector<int>, std::vector<int>> std::vector<int> shuffle_tensor_dimensions(const std::vector<int>& origLengthsStrides,
shuffle_tensor_dimensions(const std::vector<int>& dimLengths, const std::vector<int>& reduceDims)
const std::vector<int>& dimStrides,
const std::vector<int>& reduceDims)
{ {
std::vector<int> newDimLengths; std::vector<int> newLengthsStrides;
std::vector<int> newDimStrides;
assert(Rank == dimLengths.size() && Rank == dimStrides.size() && assert(Rank == origLengthsStrides.size() && NumReduceDim == reduceDims.size());
NumReduceDim == reduceDims.size());
int reduceFlag = 0; int reduceFlag = 0;
...@@ -82,19 +70,17 @@ shuffle_tensor_dimensions(const std::vector<int>& dimLengths, ...@@ -82,19 +70,17 @@ shuffle_tensor_dimensions(const std::vector<int>& dimLengths,
for(int i = 0; i < Rank; i++) for(int i = 0; i < Rank; i++)
if((reduceFlag & (1 << i)) == 0) if((reduceFlag & (1 << i)) == 0)
{ {
newDimLengths.push_back(dimLengths[i]); newLengthsStrides.push_back(origLengthsStrides[i]);
newDimStrides.push_back(dimStrides[i]);
}; };
// collect reduce dimensions // collect reduce dimensions
for(int i = 0; i < Rank; i++) for(int i = 0; i < Rank; i++)
if((reduceFlag & (1 << i)) > 0) if((reduceFlag & (1 << i)) > 0)
{ {
newDimLengths.push_back(dimLengths[i]); newLengthsStrides.push_back(origLengthsStrides[i]);
newDimStrides.push_back(dimStrides[i]);
}; };
return std::make_pair(newDimLengths, newDimStrides); return newLengthsStrides;
}; };
} // namespace device } // namespace device
......
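The two helpers above drop the Sequence-based InvariantDims/ReduceDims bookkeeping: lengths or strides are shuffled so that reduce dimensions come last, and get_2d_lengths then simply multiplies the two halves. A small standalone sketch of the same idea, with hypothetical values (Rank = 4, reduceDims = {1, 3}) and not the library's implementation:

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// Illustrative reimplementation: move the reduced dimensions to the back,
// keeping the relative order within each group.
std::vector<int> shuffle_dims(const std::vector<int>& v, const std::vector<int>& reduceDims)
{
    std::vector<int> out;
    std::vector<bool> isReduce(v.size(), false);
    for(int d : reduceDims)
        isReduce[d] = true;

    for(std::size_t i = 0; i < v.size(); ++i)
        if(!isReduce[i])
            out.push_back(v[i]); // invariant dims first
    for(std::size_t i = 0; i < v.size(); ++i)
        if(isReduce[i])
            out.push_back(v[i]); // reduce dims last
    return out;
}

int main()
{
    std::vector<int> lengths{16, 64, 32, 960};
    std::vector<int> reduceDims{1, 3};

    auto shuffled = shuffle_dims(lengths, reduceDims); // {16, 32, 64, 960}

    // get_2d_lengths<Rank, NumReduceDim> then reduces to two products over the shuffled lengths:
    std::size_t invariant_total = 16UL * 32;  // product of the first Rank - NumReduceDim entries
    std::size_t reduce_total    = 64UL * 960; // product of the last NumReduceDim entries

    assert(shuffled == (std::vector<int>{16, 32, 64, 960}));
    assert(invariant_total == 512 && reduce_total == 61440);
    return 0;
}
```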
...@@ -39,18 +39,18 @@ struct DeviceReduceMultiBlockAtomicAdd ...@@ -39,18 +39,18 @@ struct DeviceReduceMultiBlockAtomicAdd
static_assert(BlockSize == MThreadClusterSize * KThreadClusterSize, static_assert(BlockSize == MThreadClusterSize * KThreadClusterSize,
"Invalid thread cluster size assignments!"); "Invalid thread cluster size assignments!");
static_assert(((InSrcVectorDim == 0 && MThreadSliceSize % InSrcVectorSize == 0) ||
(InSrcVectorDim == 1 && KThreadSliceSize % InSrcVectorSize == 0)) &&
(MThreadSliceSize % OutDstVectorSize == 0),
"Invalid thread slice sizes and/or vector sizes configuration, please check!");
using IndexDataType = int32_t; using IndexDataType = int32_t;
static constexpr index_t NumInvariantDim = Rank - NumReduceDim; static constexpr index_t NumInvariantDim = Rank - NumReduceDim;
using InvariantDims =
typename conditional<NumInvariantDim == 0,
Sequence<>,
typename arithmetic_sequence_gen<0, NumInvariantDim, 1>::type>::type;
using ReduceDims = typename arithmetic_sequence_gen<NumInvariantDim, Rank, 1>::type;
static constexpr index_t srcDims = Rank; static constexpr index_t numSrcDim = Rank;
static constexpr index_t dstDims = (InvariantDims::Size() == 0) ? 1 : InvariantDims::Size(); static constexpr index_t numDstDim = (NumInvariantDim == 0) ? 1 : NumInvariantDim;
static constexpr bool reduceAllDims = (InvariantDims::Size() == 0); static constexpr bool reduceAllDim = (NumInvariantDim == 0);
static constexpr bool support_AtomicAdd = static constexpr bool support_AtomicAdd =
std::is_same<OutDataType, float>::value || std::is_same<OutDataType, double>::value; std::is_same<OutDataType, float>::value || std::is_same<OutDataType, double>::value;
...@@ -67,18 +67,18 @@ struct DeviceReduceMultiBlockAtomicAdd ...@@ -67,18 +67,18 @@ struct DeviceReduceMultiBlockAtomicAdd
int blkGroupSize, int blkGroupSize,
int kBlockTileIterations) int kBlockTileIterations)
{ {
const auto tupleSrcLengths = make_tuple_from_array(inLengths, Number<srcDims>{}); const auto tupleSrcLengths = make_tuple_from_array(inLengths, Number<numSrcDim>{});
const auto tupleSrcStrides = make_tuple_from_array(inStrides, Number<srcDims>{}); const auto tupleSrcStrides = make_tuple_from_array(inStrides, Number<numSrcDim>{});
const auto inDesc = make_naive_tensor_descriptor(tupleSrcLengths, tupleSrcStrides); const auto inDesc = make_naive_tensor_descriptor(tupleSrcLengths, tupleSrcStrides);
const auto in_grid_desc_m_k = [&]() { const auto in_grid_desc_m_k = [&]() {
if constexpr(reduceAllDims) if constexpr(reduceAllDim)
{ {
const auto one_dim_inDesc = transform_tensor_descriptor( const auto one_dim_inDesc = transform_tensor_descriptor(
inDesc, inDesc,
make_tuple(make_merge_transform(tupleSrcLengths)), make_tuple(make_merge_transform(tupleSrcLengths)),
make_tuple(typename arithmetic_sequence_gen<0, srcDims, 1>::type{}), make_tuple(typename arithmetic_sequence_gen<0, numSrcDim, 1>::type{}),
make_tuple(Sequence<0>{})); make_tuple(Sequence<0>{}));
return transform_tensor_descriptor(one_dim_inDesc, return transform_tensor_descriptor(one_dim_inDesc,
...@@ -89,6 +89,9 @@ struct DeviceReduceMultiBlockAtomicAdd ...@@ -89,6 +89,9 @@ struct DeviceReduceMultiBlockAtomicAdd
} }
else else
{ {
using InvariantDims = typename arithmetic_sequence_gen<0, NumInvariantDim, 1>::type;
using ReduceDims = typename arithmetic_sequence_gen<NumInvariantDim, Rank, 1>::type;
const auto reduceDimLengths = const auto reduceDimLengths =
make_tuple_from_array_and_index_seq(inLengths, ReduceDims{}); make_tuple_from_array_and_index_seq(inLengths, ReduceDims{});
const auto invariantDimLengths = const auto invariantDimLengths =
...@@ -103,19 +106,20 @@ struct DeviceReduceMultiBlockAtomicAdd ...@@ -103,19 +106,20 @@ struct DeviceReduceMultiBlockAtomicAdd
} }
}(); }();
const auto outerLen = in_grid_desc_m_k.GetLength(Number<0>{}); const auto invariantLength = in_grid_desc_m_k.GetLength(Number<0>{});
const auto innerLen = in_grid_desc_m_k.GetLength(Number<1>{}); const auto reduceLength = in_grid_desc_m_k.GetLength(Number<1>{});
const int reduceSizePerBlock = K_BlockTileSize * kBlockTileIterations; const int reduceSizePerBlock = K_BlockTileSize * kBlockTileIterations;
const auto inPad_M = math::integer_least_multiple(outerLen, M_BlockTileSize) - outerLen; const auto inPad_M =
const auto inPad_K = reduceSizePerBlock * blkGroupSize - innerLen; math::integer_least_multiple(invariantLength, M_BlockTileSize) - invariantLength;
const auto inPad_K = reduceSizePerBlock * blkGroupSize - reduceLength;
auto in_grid_desc_m_k_padded = auto in_grid_desc_m_k_padded = transform_tensor_descriptor(
transform_tensor_descriptor(in_grid_desc_m_k, in_grid_desc_m_k,
make_tuple(make_right_pad_transform(outerLen, inPad_M), make_tuple(make_right_pad_transform(invariantLength, inPad_M),
make_right_pad_transform(innerLen, inPad_K)), make_right_pad_transform(reduceLength, inPad_K)),
make_tuple(Sequence<0>{}, Sequence<1>{}), make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<0>{}, Sequence<1>{})); make_tuple(Sequence<0>{}, Sequence<1>{}));
return (in_grid_desc_m_k_padded); return (in_grid_desc_m_k_padded);
}; };
...@@ -123,44 +127,45 @@ struct DeviceReduceMultiBlockAtomicAdd ...@@ -123,44 +127,45 @@ struct DeviceReduceMultiBlockAtomicAdd
static auto MakeDst1dDescriptor(const std::vector<int>& outLengths, static auto MakeDst1dDescriptor(const std::vector<int>& outLengths,
const std::vector<int>& outStrides) const std::vector<int>& outStrides)
{ {
const auto tupleDstLengths = make_tuple_from_array(outLengths, Number<dstDims>{}); const auto tupleDstLengths = make_tuple_from_array(outLengths, Number<numDstDim>{});
const auto tupleDstStrides = make_tuple_from_array(outStrides, Number<dstDims>{}); const auto tupleDstStrides = make_tuple_from_array(outStrides, Number<numDstDim>{});
auto outDesc = make_naive_tensor_descriptor(tupleDstLengths, tupleDstStrides); auto outDesc = make_naive_tensor_descriptor(tupleDstLengths, tupleDstStrides);
auto out_grid_desc_m = transform_tensor_descriptor( auto out_grid_desc_m = transform_tensor_descriptor(
outDesc, outDesc,
make_tuple(make_merge_transform(tupleDstLengths)), make_tuple(make_merge_transform(tupleDstLengths)),
make_tuple(typename arithmetic_sequence_gen<0, dstDims, 1>::type{}), make_tuple(typename arithmetic_sequence_gen<0, numDstDim, 1>::type{}),
make_tuple(Sequence<0>{})); make_tuple(Sequence<0>{}));
const auto outerLen = out_grid_desc_m.GetLength(Number<0>{}); const auto invariantLength = out_grid_desc_m.GetLength(Number<0>{});
const auto outPad = math::integer_least_multiple(outerLen, M_BlockTileSize) - outerLen; const auto outPad =
math::integer_least_multiple(invariantLength, M_BlockTileSize) - invariantLength;
auto out_grid_desc_m_padded = auto out_grid_desc_m_padded = transform_tensor_descriptor(
transform_tensor_descriptor(out_grid_desc_m, out_grid_desc_m,
make_tuple(make_right_pad_transform(outerLen, outPad)), make_tuple(make_right_pad_transform(invariantLength, outPad)),
make_tuple(Sequence<0>{}), make_tuple(Sequence<0>{}),
make_tuple(Sequence<0>{})); make_tuple(Sequence<0>{}));
return (out_grid_desc_m_padded); return (out_grid_desc_m_padded);
}; };
struct Argument : public BaseArgument struct Argument : public BaseArgument
{ {
Argument(const std::vector<int>& inLengths, Argument(const std::vector<int> inLengths,
const std::vector<int>& inStrides, const std::vector<int> inStrides,
const std::vector<int>& outLengths, const std::vector<int> outLengths,
const std::vector<int>& outStrides, const std::vector<int> outStrides,
const std::vector<int>& reduceDims, const std::vector<int> reduceDims,
float alpha, float alpha,
float beta, float beta,
const InDataType* in_dev, const InDataType* in_dev,
OutDataType* out_dev, OutDataType* out_dev,
IndexDataType* out_indices_dev, IndexDataType* out_indices_dev,
AccDataType* workspace_dev, AccDataType* workspace_dev,
const InElementwiseOperation& in_elementwise_op, const InElementwiseOperation in_elementwise_op,
const AccElementwiseOperation& acc_elementwise_op) const AccElementwiseOperation acc_elementwise_op)
: outLengths_{outLengths}, : outLengths_{outLengths},
outStrides_{outStrides}, outStrides_{outStrides},
in_dev_{in_dev}, in_dev_{in_dev},
...@@ -171,21 +176,21 @@ struct DeviceReduceMultiBlockAtomicAdd ...@@ -171,21 +176,21 @@ struct DeviceReduceMultiBlockAtomicAdd
(void)out_indices_dev; (void)out_indices_dev;
(void)workspace_dev; (void)workspace_dev;
std::tie(inLengths_, inStrides_) = inLengths_ = shuffle_tensor_dimensions<Rank, NumReduceDim>(inLengths, reduceDims);
shuffle_tensor_dimensions<Rank, NumReduceDim>(inLengths, inStrides, reduceDims); inStrides_ = shuffle_tensor_dimensions<Rank, NumReduceDim>(inStrides, reduceDims);
alpha_ = static_cast<AccDataType>(alpha); alpha_ = type_convert<AccDataType>(alpha);
beta_ = static_cast<OutDataType>(beta); beta_ = type_convert<AccDataType>(beta);
std::tie(invariant_total_length, reduce_total_length) = std::tie(invariant_total_length, reduce_total_length) =
get_2d_lengths<Rank, ReduceDims>(inLengths_); get_2d_lengths<Rank, NumReduceDim>(inLengths_);
if constexpr(InvariantDims::Size() == 0) if constexpr(NumInvariantDim == 0)
invariant_lowest_length = 1; invariant_lowest_length = 1;
else else
invariant_lowest_length = inLengths_[InvariantDims::At(InvariantDims::Size() - 1)]; invariant_lowest_length = inLengths_[NumInvariantDim - 1];
reduce_lowest_length = inLengths_[ReduceDims::At(ReduceDims::Size() - 1)]; reduce_lowest_length = inLengths_[Rank - 1];
int iterations = 1; int iterations = 1;
while(true) while(true)
...@@ -218,7 +223,7 @@ struct DeviceReduceMultiBlockAtomicAdd ...@@ -218,7 +223,7 @@ struct DeviceReduceMultiBlockAtomicAdd
std::vector<int> outStrides_; std::vector<int> outStrides_;
AccDataType alpha_; AccDataType alpha_;
OutDataType beta_; AccDataType beta_;
const InDataType* in_dev_; const InDataType* in_dev_;
OutDataType* out_dev_; OutDataType* out_dev_;
...@@ -334,18 +339,22 @@ struct DeviceReduceMultiBlockAtomicAdd ...@@ -334,18 +339,22 @@ struct DeviceReduceMultiBlockAtomicAdd
if constexpr(InSrcVectorDim == 0) if constexpr(InSrcVectorDim == 0)
{ {
if constexpr(InvariantDims::Size() == 0) if constexpr(NumInvariantDim == 0)
return (false); {
if(pArg->inStrides_[InvariantDims::At(InvariantDims::Size() - 1)] != 1)
return (false); return (false);
}
else
{
if(pArg->inStrides_[NumInvariantDim - 1] != 1)
return (false);
if(pArg->invariant_lowest_length % InSrcVectorSize != 0) if(pArg->invariant_lowest_length % InSrcVectorSize != 0)
return (false); return (false);
};
} }
else else
{ {
if(pArg->inStrides_[ReduceDims::At(ReduceDims::Size() - 1)] != 1) if(pArg->inStrides_[Rank - 1] != 1)
return (false); return (false);
if(pArg->reduce_lowest_length % InSrcVectorSize != 0) if(pArg->reduce_lowest_length % InSrcVectorSize != 0)
...@@ -371,19 +380,19 @@ struct DeviceReduceMultiBlockAtomicAdd ...@@ -371,19 +380,19 @@ struct DeviceReduceMultiBlockAtomicAdd
}; };
std::unique_ptr<BaseArgument> std::unique_ptr<BaseArgument>
MakeArgumentPointer(const std::vector<int>& inLengths, MakeArgumentPointer(const std::vector<int> inLengths,
const std::vector<int>& inStrides, const std::vector<int> inStrides,
const std::vector<int>& outLengths, const std::vector<int> outLengths,
const std::vector<int>& outStrides, const std::vector<int> outStrides,
const std::vector<int>& reduceDims, const std::vector<int> reduceDims,
float alpha, float alpha,
float beta, float beta,
const void* in_dev, const void* in_dev,
void* out_dev, void* out_dev,
void* out_indices_dev, void* out_indices_dev,
void* workspace_dev, void* workspace_dev,
const InElementwiseOperation& in_elementwise_op, const InElementwiseOperation in_elementwise_op,
const AccElementwiseOperation& acc_elementwise_op) override const AccElementwiseOperation acc_elementwise_op) override
{ {
return std::make_unique<Argument>(inLengths, return std::make_unique<Argument>(inLengths,
inStrides, inStrides,
......
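The atomic-add variant above only applies when the output type has a usable hardware atomic add, which the header restricts to float and double via support_AtomicAdd; other output types, including the newly supported int8 and ushort-backed bf16, cannot use this operation. A compile-time expression of that restriction, written independently of the header as an illustration:

```cpp
#include <cstdint>
#include <type_traits>

// Mirrors the support_AtomicAdd condition in the diff: only float and double
// outputs may take the atomic-add path.
template <typename OutDataType>
constexpr bool support_AtomicAdd =
    std::is_same<OutDataType, float>::value || std::is_same<OutDataType, double>::value;

static_assert(support_AtomicAdd<float>, "float output can use atomic add");
static_assert(support_AtomicAdd<double>, "double output can use atomic add");
static_assert(!support_AtomicAdd<std::int8_t>, "int8 output cannot use this path");
static_assert(!support_AtomicAdd<std::uint16_t>, "a ushort-backed bf16 output cannot either");
```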
...@@ -37,31 +37,35 @@ struct DeviceReduceMultiBlockPartialReduce ...@@ -37,31 +37,35 @@ struct DeviceReduceMultiBlockPartialReduce
static_assert(BlockSize == MThreadClusterSize * KThreadClusterSize, static_assert(BlockSize == MThreadClusterSize * KThreadClusterSize,
"Invalid thread cluster size assignments!"); "Invalid thread cluster size assignments!");
static_assert((InSrcVectorDim == 0 && MThreadSliceSize % InSrcVectorSize == 0) ||
(InSrcVectorDim == 1 && KThreadSliceSize % InSrcVectorSize == 0),
"Invalid thread slice sizes and/or vector sizes configuration, please check!");
static_assert(OutDstVectorSize == 1, "OutDstVectorSize must be 1 for MultiBlockPartialReduce!"); static_assert(OutDstVectorSize == 1, "OutDstVectorSize must be 1 for MultiBlockPartialReduce!");
using IndexDataType = int32_t; using IndexDataType = int32_t;
static constexpr index_t NumInvariantDim = Rank - NumReduceDim; static constexpr index_t NumInvariantDim = Rank - NumReduceDim;
using InvariantDims =
typename conditional<NumInvariantDim == 0,
Sequence<>,
typename arithmetic_sequence_gen<0, NumInvariantDim, 1>::type>::type;
using ReduceDims = typename arithmetic_sequence_gen<NumInvariantDim, Rank, 1>::type;
static constexpr index_t srcDims = Rank; static constexpr index_t numSrcDim = Rank;
static constexpr index_t dstDims = (InvariantDims::Size() == 0) ? 1 : InvariantDims::Size(); static constexpr index_t numDstDim = (NumInvariantDim == 0) ? 1 : NumInvariantDim;
static constexpr bool reduceAllDims = (InvariantDims::Size() == 0); static constexpr bool reduceAllDim = (NumInvariantDim == 0);
static constexpr int M_BlockTileSize = MThreadClusterSize * MThreadSliceSize; static constexpr int M_BlockTileSize = MThreadClusterSize * MThreadSliceSize;
static constexpr int K_BlockTileSize = KThreadClusterSize * KThreadSliceSize; static constexpr int K_BlockTileSize = KThreadClusterSize * KThreadSliceSize;
size_t GetWorkspaceSizeInBytes(const std::vector<int>& inLengths) override static constexpr int MaxBlockGroupSize = 256;
long_index_t GetWorkspaceSizeInBytes(const std::vector<int> inLengths,
const std::vector<int> reduceDims) override
{ {
size_t invariant_total_length; size_t invariant_total_length;
size_t reduce_total_length; size_t reduce_total_length;
auto inLengths_ = shuffle_tensor_dimensions<Rank, NumReduceDim>(inLengths, reduceDims);
std::tie(invariant_total_length, reduce_total_length) = std::tie(invariant_total_length, reduce_total_length) =
get_2d_lengths<Rank, ReduceDims>(inLengths); get_2d_lengths<Rank, NumReduceDim>(inLengths_);
int iterations = 1; int iterations = 1;
while(true) while(true)
...@@ -69,8 +73,7 @@ struct DeviceReduceMultiBlockPartialReduce ...@@ -69,8 +73,7 @@ struct DeviceReduceMultiBlockPartialReduce
int testBlkGroupSize = (reduce_total_length + (K_BlockTileSize * iterations) - 1) / int testBlkGroupSize = (reduce_total_length + (K_BlockTileSize * iterations) - 1) /
(K_BlockTileSize * iterations); (K_BlockTileSize * iterations);
// we want the blkGroupSize be not more than 128 if(testBlkGroupSize <= MaxBlockGroupSize)
if(testBlkGroupSize <= 128)
break; break;
iterations++; iterations++;
...@@ -79,11 +82,12 @@ struct DeviceReduceMultiBlockPartialReduce ...@@ -79,11 +82,12 @@ struct DeviceReduceMultiBlockPartialReduce
int blkGroupSize = (reduce_total_length + (K_BlockTileSize * iterations) - 1) / int blkGroupSize = (reduce_total_length + (K_BlockTileSize * iterations) - 1) /
(K_BlockTileSize * iterations); (K_BlockTileSize * iterations);
size_t workspace_size = invariant_total_length * blkGroupSize; long_index_t workspace_size = invariant_total_length * blkGroupSize;
size_t wsSizeInBytes = long_index_t wsSizeInBytes =
!NeedIndices ? workspace_size * sizeof(AccDataType) !NeedIndices
: workspace_size * (sizeof(AccDataType) + sizeof(int)) + 64 + sizeof(int); ? workspace_size * sizeof(AccDataType)
: workspace_size * (sizeof(AccDataType) + sizeof(int32_t)) + 64 + sizeof(int);
return (wsSizeInBytes); return (wsSizeInBytes);
}; };
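GetWorkspaceSizeInBytes above first grows the per-block iteration count until the derived block-group size fits within MaxBlockGroupSize (now 256 instead of the previous hard-coded 128), then sizes the workspace as one AccDataType partial result per (output element, block group), plus an int32 per slot and a small alignment allowance when indices are needed. The same arithmetic in isolation, with made-up problem sizes:

```cpp
#include <cstddef>
#include <cstdint>
#include <cstdio>

int main()
{
    const int K_BlockTileSize   = 256; // example tile size
    const int MaxBlockGroupSize = 256;
    const std::size_t invariant_total_length = 4096;
    const std::size_t reduce_total_length    = 1 << 20; // ~1M elements reduced per output

    // Grow the per-block iteration count until the block-group size fits the cap.
    int iterations = 1;
    while(true)
    {
        int testBlkGroupSize = (reduce_total_length + (K_BlockTileSize * iterations) - 1) /
                               (K_BlockTileSize * iterations);
        if(testBlkGroupSize <= MaxBlockGroupSize)
            break;
        iterations++;
    }

    int blkGroupSize = (reduce_total_length + (K_BlockTileSize * iterations) - 1) /
                       (K_BlockTileSize * iterations);

    // One AccDataType partial result per (output element, block group); indices add an
    // int32 per slot plus a little alignment/bookkeeping slack (the "+ 64 + sizeof(int)").
    const bool NeedIndices = false;
    using AccDataType      = float;
    long long workspace_size = static_cast<long long>(invariant_total_length) * blkGroupSize;
    long long wsSizeInBytes =
        !NeedIndices ? workspace_size * sizeof(AccDataType)
                     : workspace_size * (sizeof(AccDataType) + sizeof(std::int32_t)) + 64 + sizeof(int);

    std::printf("iterations=%d blkGroupSize=%d workspaceBytes=%lld\n",
                iterations, blkGroupSize, wsSizeInBytes);
    return 0;
}
```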
...@@ -95,18 +99,18 @@ struct DeviceReduceMultiBlockPartialReduce ...@@ -95,18 +99,18 @@ struct DeviceReduceMultiBlockPartialReduce
int blkGroupSize, int blkGroupSize,
int kBlockTileIterations) int kBlockTileIterations)
{ {
const auto tupleSrcLengths = make_tuple_from_array(inLengths, Number<srcDims>{}); const auto tupleSrcLengths = make_tuple_from_array(inLengths, Number<numSrcDim>{});
const auto tupleSrcStrides = make_tuple_from_array(inStrides, Number<srcDims>{}); const auto tupleSrcStrides = make_tuple_from_array(inStrides, Number<numSrcDim>{});
const auto inDesc = make_naive_tensor_descriptor(tupleSrcLengths, tupleSrcStrides); const auto inDesc = make_naive_tensor_descriptor(tupleSrcLengths, tupleSrcStrides);
const auto in_grid_desc_m_k = [&]() { const auto in_grid_desc_m_k = [&]() {
if constexpr(reduceAllDims) if constexpr(reduceAllDim)
{ {
const auto one_dim_inDesc = transform_tensor_descriptor( const auto one_dim_inDesc = transform_tensor_descriptor(
inDesc, inDesc,
make_tuple(make_merge_transform(tupleSrcLengths)), make_tuple(make_merge_transform(tupleSrcLengths)),
make_tuple(typename arithmetic_sequence_gen<0, srcDims, 1>::type{}), make_tuple(typename arithmetic_sequence_gen<0, numSrcDim, 1>::type{}),
make_tuple(Sequence<0>{})); make_tuple(Sequence<0>{}));
return transform_tensor_descriptor(one_dim_inDesc, return transform_tensor_descriptor(one_dim_inDesc,
...@@ -117,6 +121,9 @@ struct DeviceReduceMultiBlockPartialReduce ...@@ -117,6 +121,9 @@ struct DeviceReduceMultiBlockPartialReduce
} }
else else
{ {
using InvariantDims = typename arithmetic_sequence_gen<0, NumInvariantDim, 1>::type;
using ReduceDims = typename arithmetic_sequence_gen<NumInvariantDim, Rank, 1>::type;
const auto reduceDimLengths = const auto reduceDimLengths =
make_tuple_from_array_and_index_seq(inLengths, ReduceDims{}); make_tuple_from_array_and_index_seq(inLengths, ReduceDims{});
const auto invariantDimLengths = const auto invariantDimLengths =
...@@ -131,32 +138,35 @@ struct DeviceReduceMultiBlockPartialReduce ...@@ -131,32 +138,35 @@ struct DeviceReduceMultiBlockPartialReduce
} }
}(); }();
const auto outerLen = in_grid_desc_m_k.GetLength(Number<0>{}); const auto invariantLength = in_grid_desc_m_k.GetLength(Number<0>{});
const auto innerLen = in_grid_desc_m_k.GetLength(Number<1>{}); const auto reduceLength = in_grid_desc_m_k.GetLength(Number<1>{});
const int reduceSizePerBlock = K_BlockTileSize * kBlockTileIterations; const int reduceSizePerBlock = K_BlockTileSize * kBlockTileIterations;
const auto inPad_M = math::integer_least_multiple(outerLen, M_BlockTileSize) - outerLen; const auto inPad_M =
const auto inPad_K = reduceSizePerBlock * blkGroupSize - innerLen; math::integer_least_multiple(invariantLength, M_BlockTileSize) - invariantLength;
const auto inPad_K = reduceSizePerBlock * blkGroupSize - reduceLength;
auto in_grid_desc_m_k_padded = auto in_grid_desc_m_k_padded = transform_tensor_descriptor(
transform_tensor_descriptor(in_grid_desc_m_k, in_grid_desc_m_k,
make_tuple(make_right_pad_transform(outerLen, inPad_M), make_tuple(make_right_pad_transform(invariantLength, inPad_M),
make_right_pad_transform(innerLen, inPad_K)), make_right_pad_transform(reduceLength, inPad_K)),
make_tuple(Sequence<0>{}, Sequence<1>{}), make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<0>{}, Sequence<1>{})); make_tuple(Sequence<0>{}, Sequence<1>{}));
return (in_grid_desc_m_k_padded); return (in_grid_desc_m_k_padded);
}; };
static auto MakeWorkspace2dDescriptor(int outerLen, int blkGroupSize) static auto MakeWorkspace2dDescriptor(int invariantLength, int blkGroupSize)
{ {
auto ws_desc_m_k = make_naive_tensor_descriptor_packed(make_tuple(outerLen, blkGroupSize)); auto ws_desc_m_k =
make_naive_tensor_descriptor_packed(make_tuple(invariantLength, blkGroupSize));
const auto wsPad = math::integer_least_multiple(outerLen, M_BlockTileSize) - outerLen; const auto wsPad =
math::integer_least_multiple(invariantLength, M_BlockTileSize) - invariantLength;
auto ws_desc_m_k_padded = auto ws_desc_m_k_padded =
transform_tensor_descriptor(ws_desc_m_k, transform_tensor_descriptor(ws_desc_m_k,
make_tuple(make_right_pad_transform(outerLen, wsPad), make_tuple(make_right_pad_transform(invariantLength, wsPad),
make_pass_through_transform(blkGroupSize)), make_pass_through_transform(blkGroupSize)),
make_tuple(Sequence<0>{}, Sequence<1>{}), make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<0>{}, Sequence<1>{})); make_tuple(Sequence<0>{}, Sequence<1>{}));
...@@ -166,19 +176,19 @@ struct DeviceReduceMultiBlockPartialReduce ...@@ -166,19 +176,19 @@ struct DeviceReduceMultiBlockPartialReduce
struct Argument : public BaseArgument struct Argument : public BaseArgument
{ {
Argument(const std::vector<int>& inLengths, Argument(const std::vector<int> inLengths,
const std::vector<int>& inStrides, const std::vector<int> inStrides,
const std::vector<int>& outLengths, const std::vector<int> outLengths,
const std::vector<int>& outStrides, const std::vector<int> outStrides,
const std::vector<int>& reduceDims, const std::vector<int> reduceDims,
float alpha, float alpha,
float beta, float beta,
const InDataType* in_dev, const InDataType* in_dev,
OutDataType* out_dev, OutDataType* out_dev,
IndexDataType* out_indices_dev, IndexDataType* out_indices_dev,
AccDataType* workspace_dev, AccDataType* workspace_dev,
const InElementwiseOperation& in_elementwise_op, const InElementwiseOperation in_elementwise_op,
const AccElementwiseOperation& acc_elementwise_op) const AccElementwiseOperation acc_elementwise_op)
: outLengths_{outLengths}, : outLengths_{outLengths},
outStrides_{outStrides}, outStrides_{outStrides},
in_dev_{in_dev}, in_dev_{in_dev},
...@@ -188,21 +198,21 @@ struct DeviceReduceMultiBlockPartialReduce ...@@ -188,21 +198,21 @@ struct DeviceReduceMultiBlockPartialReduce
in_elementwise_op_{in_elementwise_op}, in_elementwise_op_{in_elementwise_op},
acc_elementwise_op_{acc_elementwise_op} acc_elementwise_op_{acc_elementwise_op}
{ {
std::tie(inLengths_, inStrides_) = inLengths_ = shuffle_tensor_dimensions<Rank, NumReduceDim>(inLengths, reduceDims);
shuffle_tensor_dimensions<Rank, NumReduceDim>(inLengths, inStrides, reduceDims); inStrides_ = shuffle_tensor_dimensions<Rank, NumReduceDim>(inStrides, reduceDims);
alpha_ = static_cast<AccDataType>(alpha); alpha_ = type_convert<AccDataType>(alpha);
beta_ = static_cast<OutDataType>(beta); beta_ = type_convert<AccDataType>(beta);
std::tie(invariant_total_length, reduce_total_length) = std::tie(invariant_total_length, reduce_total_length) =
get_2d_lengths<Rank, ReduceDims>(inLengths_); get_2d_lengths<Rank, NumReduceDim>(inLengths_);
if constexpr(InvariantDims::Size() == 0) if constexpr(NumInvariantDim == 0)
invariant_lowest_length = 1; invariant_lowest_length = 1;
else else
invariant_lowest_length = inLengths_[InvariantDims::At(InvariantDims::Size() - 1)]; invariant_lowest_length = inLengths_[NumInvariantDim - 1];
reduce_lowest_length = inLengths_[ReduceDims::At(ReduceDims::Size() - 1)]; reduce_lowest_length = inLengths_[Rank - 1];
int iterations = 1; int iterations = 1;
while(true) while(true)
...@@ -210,8 +220,7 @@ struct DeviceReduceMultiBlockPartialReduce ...@@ -210,8 +220,7 @@ struct DeviceReduceMultiBlockPartialReduce
int testBlkGroupSize = (reduce_total_length + (K_BlockTileSize * iterations) - 1) / int testBlkGroupSize = (reduce_total_length + (K_BlockTileSize * iterations) - 1) /
(K_BlockTileSize * iterations); (K_BlockTileSize * iterations);
// we want the blkGroupSize be not more than 128 if(testBlkGroupSize <= MaxBlockGroupSize)
if(testBlkGroupSize <= 128)
break; break;
iterations++; iterations++;
...@@ -241,7 +250,7 @@ struct DeviceReduceMultiBlockPartialReduce ...@@ -241,7 +250,7 @@ struct DeviceReduceMultiBlockPartialReduce
std::vector<int> outStrides_; std::vector<int> outStrides_;
AccDataType alpha_; AccDataType alpha_;
OutDataType beta_; AccDataType beta_;
const InDataType* in_dev_; const InDataType* in_dev_;
OutDataType* out_dev_; OutDataType* out_dev_;
...@@ -337,18 +346,22 @@ struct DeviceReduceMultiBlockPartialReduce ...@@ -337,18 +346,22 @@ struct DeviceReduceMultiBlockPartialReduce
if constexpr(InSrcVectorDim == 0) if constexpr(InSrcVectorDim == 0)
{ {
if constexpr(InvariantDims::Size() == 0) if constexpr(NumInvariantDim == 0)
return (false); {
if(pArg->inStrides_[InvariantDims::At(InvariantDims::Size() - 1)] != 1)
return (false); return (false);
}
else
{
if(pArg->inStrides_[NumInvariantDim - 1] != 1)
return (false);
if(pArg->invariant_lowest_length % InSrcVectorSize != 0) if(pArg->invariant_lowest_length % InSrcVectorSize != 0)
return (false); return (false);
};
} }
else else
{ {
if(pArg->inStrides_[ReduceDims::At(ReduceDims::Size() - 1)] != 1) if(pArg->inStrides_[Rank - 1] != 1)
return (false); return (false);
if(pArg->reduce_lowest_length % InSrcVectorSize != 0) if(pArg->reduce_lowest_length % InSrcVectorSize != 0)
...@@ -371,19 +384,19 @@ struct DeviceReduceMultiBlockPartialReduce ...@@ -371,19 +384,19 @@ struct DeviceReduceMultiBlockPartialReduce
}; };
std::unique_ptr<BaseArgument> std::unique_ptr<BaseArgument>
MakeArgumentPointer(const std::vector<int>& inLengths, MakeArgumentPointer(const std::vector<int> inLengths,
const std::vector<int>& inStrides, const std::vector<int> inStrides,
const std::vector<int>& outLengths, const std::vector<int> outLengths,
const std::vector<int>& outStrides, const std::vector<int> outStrides,
const std::vector<int>& reduceDims, const std::vector<int> reduceDims,
float alpha, float alpha,
float beta, float beta,
const void* in_dev, const void* in_dev,
void* out_dev, void* out_dev,
void* out_indices_dev, void* out_indices_dev,
void* workspace_dev, void* workspace_dev,
const InElementwiseOperation& in_elementwise_op, const InElementwiseOperation in_elementwise_op,
const AccElementwiseOperation& acc_elementwise_op) override const AccElementwiseOperation acc_elementwise_op) override
{ {
return std::make_unique<Argument>(inLengths, return std::make_unique<Argument>(inLengths,
inStrides, inStrides,
......
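The loop above that replaces the hard-coded `128` with `MaxBlockGroupSize` searches for the smallest number of K-tile iterations per block such that the resulting block-group count stays under the cap. A minimal host-side sketch of that sizing logic (function and parameter names here are illustrative, not the library's):

```cpp
#include <cstdio>
#include <utility>

// Returns {iterations, blkGroupSize}: the smallest iteration count whose
// implied block-group size does not exceed the cap (MaxBlockGroupSize above).
std::pair<int, int> pick_block_group(int reduce_total_length,
                                     int kBlockTileSize,
                                     int maxBlockGroupSize)
{
    int iterations = 1;
    while(true)
    {
        // ceil(reduce_total_length / (kBlockTileSize * iterations))
        int testBlkGroupSize = (reduce_total_length + kBlockTileSize * iterations - 1) /
                               (kBlockTileSize * iterations);

        if(testBlkGroupSize <= maxBlockGroupSize)
            return {iterations, testBlkGroupSize};

        iterations++;
    }
}

int main()
{
    auto [iters, groups] = pick_block_group(1000000, 256, 128);
    std::printf("iterations=%d blkGroupSize=%d\n", iters, groups);
}
```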
...@@ -36,20 +36,20 @@ struct DeviceReduceThreadWise : public DeviceReduce<InElementwiseOperation, OutE ...@@ -36,20 +36,20 @@ struct DeviceReduceThreadWise : public DeviceReduce<InElementwiseOperation, OutE
static_assert((BlockSize == MThreadClusterSize) && (KThreadClusterSize == 1), static_assert((BlockSize == MThreadClusterSize) && (KThreadClusterSize == 1),
"Threadwise can only be called with KThreadClusterSize be 1 !"); "Threadwise can only be called with KThreadClusterSize be 1 !");
static_assert(((InSrcVectorDim == 0 && MThreadSliceSize % InSrcVectorSize == 0) ||
(InSrcVectorDim == 1 && KThreadSliceSize % InSrcVectorSize == 0)) &&
(MThreadSliceSize % OutDstVectorSize == 0),
"Invalid thread slice sizes and/or vector sizes configuration, please check!");
using IndexDataType = int32_t; using IndexDataType = int32_t;
static constexpr bool BetaIsZero = NeedIndices; static constexpr bool BetaIsZero = NeedIndices;
static constexpr index_t NumInvariantDim = Rank - NumReduceDim; static constexpr index_t NumInvariantDim = Rank - NumReduceDim;
using InvariantDims =
typename conditional<NumInvariantDim == 0,
Sequence<>,
typename arithmetic_sequence_gen<0, NumInvariantDim, 1>::type>::type;
using ReduceDims = typename arithmetic_sequence_gen<NumInvariantDim, Rank, 1>::type;
static constexpr index_t srcDims = Rank; static constexpr index_t numSrcDim = Rank;
static constexpr index_t dstDims = (InvariantDims::Size() == 0) ? 1 : InvariantDims::Size(); static constexpr index_t numDstDim = (NumInvariantDim == 0) ? 1 : NumInvariantDim;
static constexpr bool reduceAllDims = (InvariantDims::Size() == 0); static constexpr bool reduceAllDim = (NumInvariantDim == 0);
static constexpr int M_BlockTileSize = MThreadClusterSize * MThreadSliceSize; static constexpr int M_BlockTileSize = MThreadClusterSize * MThreadSliceSize;
static constexpr int K_BlockTileSize = KThreadClusterSize * KThreadSliceSize; static constexpr int K_BlockTileSize = KThreadClusterSize * KThreadSliceSize;
...@@ -57,18 +57,18 @@ struct DeviceReduceThreadWise : public DeviceReduce<InElementwiseOperation, OutE ...@@ -57,18 +57,18 @@ struct DeviceReduceThreadWise : public DeviceReduce<InElementwiseOperation, OutE
static auto MakeSrc2dDescriptor(const std::vector<int>& inLengths, static auto MakeSrc2dDescriptor(const std::vector<int>& inLengths,
const std::vector<int>& inStrides) const std::vector<int>& inStrides)
{ {
const auto tupleSrcLengths = make_tuple_from_array(inLengths, Number<srcDims>{}); const auto tupleSrcLengths = make_tuple_from_array(inLengths, Number<numSrcDim>{});
const auto tupleSrcStrides = make_tuple_from_array(inStrides, Number<srcDims>{}); const auto tupleSrcStrides = make_tuple_from_array(inStrides, Number<numSrcDim>{});
const auto inDesc = make_naive_tensor_descriptor(tupleSrcLengths, tupleSrcStrides); const auto inDesc = make_naive_tensor_descriptor(tupleSrcLengths, tupleSrcStrides);
const auto in_grid_desc_m_k = [&]() { const auto in_grid_desc_m_k = [&]() {
if constexpr(reduceAllDims) if constexpr(reduceAllDim)
{ {
const auto one_dim_inDesc = transform_tensor_descriptor( const auto one_dim_inDesc = transform_tensor_descriptor(
inDesc, inDesc,
make_tuple(make_merge_transform(tupleSrcLengths)), make_tuple(make_merge_transform(tupleSrcLengths)),
make_tuple(typename arithmetic_sequence_gen<0, srcDims, 1>::type{}), make_tuple(typename arithmetic_sequence_gen<0, numSrcDim, 1>::type{}),
make_tuple(Sequence<0>{})); make_tuple(Sequence<0>{}));
return transform_tensor_descriptor(one_dim_inDesc, return transform_tensor_descriptor(one_dim_inDesc,
...@@ -79,6 +79,9 @@ struct DeviceReduceThreadWise : public DeviceReduce<InElementwiseOperation, OutE ...@@ -79,6 +79,9 @@ struct DeviceReduceThreadWise : public DeviceReduce<InElementwiseOperation, OutE
} }
else else
{ {
using InvariantDims = typename arithmetic_sequence_gen<0, NumInvariantDim, 1>::type;
using ReduceDims = typename arithmetic_sequence_gen<NumInvariantDim, Rank, 1>::type;
const auto reduceDimLengths = const auto reduceDimLengths =
make_tuple_from_array_and_index_seq(inLengths, ReduceDims{}); make_tuple_from_array_and_index_seq(inLengths, ReduceDims{});
const auto invariantDimLengths = const auto invariantDimLengths =
...@@ -93,18 +96,20 @@ struct DeviceReduceThreadWise : public DeviceReduce<InElementwiseOperation, OutE ...@@ -93,18 +96,20 @@ struct DeviceReduceThreadWise : public DeviceReduce<InElementwiseOperation, OutE
} }
}(); }();
const auto outerLen = in_grid_desc_m_k.GetLength(Number<0>{}); const auto invariantLength = in_grid_desc_m_k.GetLength(Number<0>{});
const auto innerLen = in_grid_desc_m_k.GetLength(Number<1>{}); const auto reduceLength = in_grid_desc_m_k.GetLength(Number<1>{});
const auto inPad_M = math::integer_least_multiple(outerLen, M_BlockTileSize) - outerLen; const auto inPad_M =
const auto inPad_K = math::integer_least_multiple(innerLen, K_BlockTileSize) - innerLen; math::integer_least_multiple(invariantLength, M_BlockTileSize) - invariantLength;
const auto inPad_K =
math::integer_least_multiple(reduceLength, K_BlockTileSize) - reduceLength;
auto in_grid_desc_m_k_padded = auto in_grid_desc_m_k_padded = transform_tensor_descriptor(
transform_tensor_descriptor(in_grid_desc_m_k, in_grid_desc_m_k,
make_tuple(make_right_pad_transform(outerLen, inPad_M), make_tuple(make_right_pad_transform(invariantLength, inPad_M),
make_right_pad_transform(innerLen, inPad_K)), make_right_pad_transform(reduceLength, inPad_K)),
make_tuple(Sequence<0>{}, Sequence<1>{}), make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<0>{}, Sequence<1>{})); make_tuple(Sequence<0>{}, Sequence<1>{}));
return (in_grid_desc_m_k_padded); return (in_grid_desc_m_k_padded);
}; };
...@@ -112,44 +117,45 @@ struct DeviceReduceThreadWise : public DeviceReduce<InElementwiseOperation, OutE ...@@ -112,44 +117,45 @@ struct DeviceReduceThreadWise : public DeviceReduce<InElementwiseOperation, OutE
static auto MakeDst1dDescriptor(const std::vector<int>& outLengths, static auto MakeDst1dDescriptor(const std::vector<int>& outLengths,
const std::vector<int>& outStrides) const std::vector<int>& outStrides)
{ {
const auto tupleDstLengths = make_tuple_from_array(outLengths, Number<dstDims>{}); const auto tupleDstLengths = make_tuple_from_array(outLengths, Number<numDstDim>{});
const auto tupleDstStrides = make_tuple_from_array(outStrides, Number<dstDims>{}); const auto tupleDstStrides = make_tuple_from_array(outStrides, Number<numDstDim>{});
auto outDesc = make_naive_tensor_descriptor(tupleDstLengths, tupleDstStrides); auto outDesc = make_naive_tensor_descriptor(tupleDstLengths, tupleDstStrides);
auto out_grid_desc_m = transform_tensor_descriptor( auto out_grid_desc_m = transform_tensor_descriptor(
outDesc, outDesc,
make_tuple(make_merge_transform(tupleDstLengths)), make_tuple(make_merge_transform(tupleDstLengths)),
make_tuple(typename arithmetic_sequence_gen<0, dstDims, 1>::type{}), make_tuple(typename arithmetic_sequence_gen<0, numDstDim, 1>::type{}),
make_tuple(Sequence<0>{})); make_tuple(Sequence<0>{}));
const auto outerLen = out_grid_desc_m.GetLength(Number<0>{}); const auto invariantLength = out_grid_desc_m.GetLength(Number<0>{});
const auto outPad = math::integer_least_multiple(outerLen, M_BlockTileSize) - outerLen; const auto outPad =
math::integer_least_multiple(invariantLength, M_BlockTileSize) - invariantLength;
auto out_grid_desc_m_padded = auto out_grid_desc_m_padded = transform_tensor_descriptor(
transform_tensor_descriptor(out_grid_desc_m, out_grid_desc_m,
make_tuple(make_right_pad_transform(outerLen, outPad)), make_tuple(make_right_pad_transform(invariantLength, outPad)),
make_tuple(Sequence<0>{}), make_tuple(Sequence<0>{}),
make_tuple(Sequence<0>{})); make_tuple(Sequence<0>{}));
return (out_grid_desc_m_padded); return (out_grid_desc_m_padded);
}; };
struct Argument : public BaseArgument struct Argument : public BaseArgument
{ {
Argument(const std::vector<int>& inLengths, Argument(const std::vector<int> inLengths,
const std::vector<int>& inStrides, const std::vector<int> inStrides,
const std::vector<int>& outLengths, const std::vector<int> outLengths,
const std::vector<int>& outStrides, const std::vector<int> outStrides,
const std::vector<int>& reduceDims, const std::vector<int> reduceDims,
float alpha, float alpha,
float beta, float beta,
const InDataType* in_dev, const InDataType* in_dev,
OutDataType* out_dev, OutDataType* out_dev,
IndexDataType* out_indices_dev, IndexDataType* out_indices_dev,
AccDataType* workspace_dev, AccDataType* workspace_dev,
const InElementwiseOperation& in_elementwise_op, const InElementwiseOperation in_elementwise_op,
const OutElementwiseOperation& acc_elementwise_op) const OutElementwiseOperation acc_elementwise_op)
: outLengths_{outLengths}, : outLengths_{outLengths},
outStrides_{outStrides}, outStrides_{outStrides},
in_dev_{in_dev}, in_dev_{in_dev},
...@@ -161,21 +167,21 @@ struct DeviceReduceThreadWise : public DeviceReduce<InElementwiseOperation, OutE ...@@ -161,21 +167,21 @@ struct DeviceReduceThreadWise : public DeviceReduce<InElementwiseOperation, OutE
{ {
(void)workspace_dev; (void)workspace_dev;
std::tie(inLengths_, inStrides_) = inLengths_ = shuffle_tensor_dimensions<Rank, NumReduceDim>(inLengths, reduceDims);
shuffle_tensor_dimensions<Rank, NumReduceDim>(inLengths, inStrides, reduceDims); inStrides_ = shuffle_tensor_dimensions<Rank, NumReduceDim>(inStrides, reduceDims);
alpha_ = static_cast<AccDataType>(alpha); alpha_ = type_convert<AccDataType>(alpha);
beta_ = static_cast<OutDataType>(beta); beta_ = type_convert<AccDataType>(beta);
std::tie(invariant_total_length, reduce_total_length) = std::tie(invariant_total_length, reduce_total_length) =
get_2d_lengths<Rank, ReduceDims>(inLengths_); get_2d_lengths<Rank, NumReduceDim>(inLengths_);
if constexpr(InvariantDims::Size() == 0) if constexpr(NumInvariantDim == 0)
invariant_lowest_length = 1; invariant_lowest_length = 1;
else else
invariant_lowest_length = inLengths_[InvariantDims::At(InvariantDims::Size() - 1)]; invariant_lowest_length = inLengths_[NumInvariantDim - 1];
reduce_lowest_length = inLengths_[ReduceDims::At(ReduceDims::Size() - 1)]; reduce_lowest_length = inLengths_[Rank - 1];
gridSize = math::integer_least_multiple(invariant_total_length, M_BlockTileSize) / gridSize = math::integer_least_multiple(invariant_total_length, M_BlockTileSize) /
M_BlockTileSize; M_BlockTileSize;
...@@ -187,7 +193,7 @@ struct DeviceReduceThreadWise : public DeviceReduce<InElementwiseOperation, OutE ...@@ -187,7 +193,7 @@ struct DeviceReduceThreadWise : public DeviceReduce<InElementwiseOperation, OutE
std::vector<int> outStrides_; std::vector<int> outStrides_;
AccDataType alpha_; AccDataType alpha_;
OutDataType beta_; AccDataType beta_;
const InDataType* in_dev_; const InDataType* in_dev_;
OutDataType* out_dev_; OutDataType* out_dev_;
...@@ -278,18 +284,22 @@ struct DeviceReduceThreadWise : public DeviceReduce<InElementwiseOperation, OutE ...@@ -278,18 +284,22 @@ struct DeviceReduceThreadWise : public DeviceReduce<InElementwiseOperation, OutE
if constexpr(InSrcVectorDim == 0) if constexpr(InSrcVectorDim == 0)
{ {
if constexpr(InvariantDims::Size() == 0) if constexpr(NumInvariantDim == 0)
return (false); {
if(pArg->inStrides_[InvariantDims::At(InvariantDims::Size() - 1)] != 1)
return (false); return (false);
}
else
{
if(pArg->inStrides_[NumInvariantDim - 1] != 1)
return (false);
if(pArg->invariant_lowest_length % InSrcVectorSize != 0) if(pArg->invariant_lowest_length % InSrcVectorSize != 0)
return (false); return (false);
};
} }
else else
{ {
if(pArg->inStrides_[ReduceDims::At(ReduceDims::Size() - 1)] != 1) if(pArg->inStrides_[Rank - 1] != 1)
return (false); return (false);
if(pArg->reduce_lowest_length % InSrcVectorSize != 0) if(pArg->reduce_lowest_length % InSrcVectorSize != 0)
...@@ -310,19 +320,19 @@ struct DeviceReduceThreadWise : public DeviceReduce<InElementwiseOperation, OutE ...@@ -310,19 +320,19 @@ struct DeviceReduceThreadWise : public DeviceReduce<InElementwiseOperation, OutE
}; };
std::unique_ptr<BaseArgument> std::unique_ptr<BaseArgument>
MakeArgumentPointer(const std::vector<int>& inLengths, MakeArgumentPointer(const std::vector<int> inLengths,
const std::vector<int>& inStrides, const std::vector<int> inStrides,
const std::vector<int>& outLengths, const std::vector<int> outLengths,
const std::vector<int>& outStrides, const std::vector<int> outStrides,
const std::vector<int>& reduceDims, const std::vector<int> reduceDims,
float alpha, float alpha,
float beta, float beta,
const void* in_dev, const void* in_dev,
void* out_dev, void* out_dev,
void* out_indices_dev, void* out_indices_dev,
void* workspace_dev, void* workspace_dev,
const InElementwiseOperation& in_elementwise_op, const InElementwiseOperation in_elementwise_op,
const OutElementwiseOperation& acc_elementwise_op) override const OutElementwiseOperation acc_elementwise_op) override
{ {
return std::make_unique<Argument>(inLengths, return std::make_unique<Argument>(inLengths,
inStrides, inStrides,
......
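Both device layers now call `shuffle_tensor_dimensions<Rank, NumReduceDim>` once for the lengths and once for the strides, instead of returning both in a tuple. A minimal sketch of the dimension shuffle it presumably performs, moving the reduced dimensions to the end (runtime vectors stand in for the library's templated implementation; this is an illustration of the idea, not library code):

```cpp
#include <algorithm>
#include <cstdio>
#include <vector>

// Reorder per-dimension values so invariant dims come first, reduced dims last.
std::vector<int> shuffle_dims(const std::vector<int>& dims,
                              const std::vector<int>& reduceDims)
{
    std::vector<int> shuffled;

    // invariant dimensions first, in their original order
    for(int d = 0; d < static_cast<int>(dims.size()); ++d)
        if(std::find(reduceDims.begin(), reduceDims.end(), d) == reduceDims.end())
            shuffled.push_back(dims[d]);

    // reduced dimensions last
    for(int d : reduceDims)
        shuffled.push_back(dims[d]);

    return shuffled;
}

int main()
{
    std::vector<int> lengths{16, 64, 32, 960};
    std::vector<int> strides{64 * 32 * 960, 32 * 960, 960, 1};
    std::vector<int> reduceDims{1, 3};

    auto l = shuffle_dims(lengths, reduceDims); // {16, 32, 64, 960}
    auto s = shuffle_dims(strides, reduceDims); // strides shuffled the same way
    (void)s;

    for(int x : l)
        std::printf("%d ", x);
    std::printf("\n");
}
```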
#ifndef CK_ELEMENT_WISE_OPERATION_HPP #ifndef CK_ELEMENT_WISE_OPERATION_HPP
#define CK_ELEMENT_WISE_OPERATION_HPP #define CK_ELEMENT_WISE_OPERATION_HPP
#include "data_type.hpp"
#include "data_type.hpp" #include "data_type.hpp"
...@@ -19,6 +18,8 @@ struct PassThrough ...@@ -19,6 +18,8 @@ struct PassThrough
__host__ __device__ void operator()(int32_t& y, const int32_t& x) const { y = x; } __host__ __device__ void operator()(int32_t& y, const int32_t& x) const { y = x; }
__host__ __device__ void operator()(int8_t& y, const int8_t& x) const { y = x; } __host__ __device__ void operator()(int8_t& y, const int8_t& x) const { y = x; }
__host__ __device__ void operator()(double& y, const double& x) const { y = x; }
}; };
struct Add struct Add
...@@ -239,6 +240,24 @@ struct UnaryIdentic<int32_t, int32_t, false> ...@@ -239,6 +240,24 @@ struct UnaryIdentic<int32_t, int32_t, false>
__host__ __device__ void operator()(int32_t& y, const int32_t& x) const { y = x; }; __host__ __device__ void operator()(int32_t& y, const int32_t& x) const { y = x; };
}; };
template <>
struct UnaryIdentic<int32_t, int32_t, true>
{
__host__ __device__ UnaryIdentic(const int32_t divider = 1) { divider_ = divider; };
__host__ __device__ void operator()(int32_t& y, const int32_t& x) const { y = x / divider_; };
int32_t divider_ = 1;
};
template <>
struct UnaryIdentic<int8_t, int8_t, false>
{
__host__ __device__ UnaryIdentic(const int8_t divider = 1) { (void)divider; };
__host__ __device__ void operator()(int8_t& y, const int8_t& x) const { y = x; };
};
template <typename Y, typename X, bool HasDividing = false> template <typename Y, typename X, bool HasDividing = false>
struct UnarySquare; struct UnarySquare;
...@@ -311,6 +330,19 @@ struct UnaryAbs<double, double> ...@@ -311,6 +330,19 @@ struct UnaryAbs<double, double>
__host__ __device__ void operator()(double& y, const double& x) const { y = abs(x); }; __host__ __device__ void operator()(double& y, const double& x) const { y = abs(x); };
}; };
template <>
struct UnaryAbs<int8_t, int8_t>
{
__host__ __device__ UnaryAbs(const int32_t divider = 1) { (void)divider; };
__host__ __device__ void operator()(int8_t& y, const int8_t& x) const
{
int8_t sgn = x >> (8 - 1);
y = (x ^ sgn) - sgn;
};
};
template <typename Y, typename X> template <typename Y, typename X>
struct UnarySqrt; struct UnarySqrt;
......
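The new `UnaryAbs<int8_t, int8_t>` specialization computes the absolute value without a branch: an arithmetic right shift by 7 yields `0` for non-negative inputs and `-1` (all ones) for negative ones, so `(x ^ sgn) - sgn` flips the bits and adds one exactly when the input is negative. A small standalone sketch of the same trick:

```cpp
#include <cstdint>
#include <cstdio>

int8_t abs_i8(int8_t x)
{
    int8_t sgn = x >> (8 - 1); // 0 if x >= 0, -1 (all ones) if x < 0
    return (x ^ sgn) - sgn;    // identity for x >= 0, two's-complement negate otherwise
}

int main()
{
    for(int v : {-128, -7, 0, 5, 127})
        std::printf("abs(%4d) = %4d\n", v, abs_i8(static_cast<int8_t>(v)));
    // Note: as with std::abs, the result for -128 cannot be represented in int8_t.
}
```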
...@@ -33,6 +33,7 @@ ...@@ -33,6 +33,7 @@
#include "reduction_functions_blockwise.hpp" #include "reduction_functions_blockwise.hpp"
#include "threadwise_tensor_slice_transfer.hpp" #include "threadwise_tensor_slice_transfer.hpp"
#include "cluster_descriptor.hpp" #include "cluster_descriptor.hpp"
#include "element_wise_operation.hpp"
namespace ck { namespace ck {
...@@ -52,23 +53,25 @@ __global__ void kernel_reduce_blockwise(const InGridDesc_M_K in_grid_desc_m_k, ...@@ -52,23 +53,25 @@ __global__ void kernel_reduce_blockwise(const InGridDesc_M_K in_grid_desc_m_k,
const OutElementwiseOperation acc_elementwise_op, const OutElementwiseOperation acc_elementwise_op,
AccDataType alpha, AccDataType alpha,
const InDataType* const __restrict__ p_in_global, const InDataType* const __restrict__ p_in_global,
OutDataType beta, AccDataType beta,
OutDataType* const __restrict__ p_out_global, OutDataType* const __restrict__ p_out_global,
const IndexDataType* const __restrict__ p_ws_indices_global, const IndexDataType* const __restrict__ p_ws_indices_global,
IndexDataType* const __restrict__ p_indices_global) IndexDataType* const __restrict__ p_indices_global)
{ {
if constexpr(!NeedIndices) if constexpr(!NeedIndices)
{ {
GridwiseReduction::Run(in_grid_desc_m_k, constexpr bool IsSecondCall = false;
out_grid_desc_m,
in_elementwise_op, GridwiseReduction::template Run<IsSecondCall>(in_grid_desc_m_k,
acc_elementwise_op, out_grid_desc_m,
alpha, in_elementwise_op,
p_in_global, acc_elementwise_op,
beta, alpha,
p_out_global, p_in_global,
p_ws_indices_global, beta,
p_indices_global); p_out_global,
p_ws_indices_global,
p_indices_global);
} }
else else
{ {
...@@ -102,23 +105,25 @@ kernel_reduce_blockwise_second_call(const InGridDesc_M_K in_grid_desc_m_k, ...@@ -102,23 +105,25 @@ kernel_reduce_blockwise_second_call(const InGridDesc_M_K in_grid_desc_m_k,
const OutElementwiseOperation acc_elementwise_op, const OutElementwiseOperation acc_elementwise_op,
AccDataType alpha, AccDataType alpha,
const InDataType* const __restrict__ p_in_global, const InDataType* const __restrict__ p_in_global,
OutDataType beta, AccDataType beta,
OutDataType* const __restrict__ p_out_global, OutDataType* const __restrict__ p_out_global,
const IndexDataType* const __restrict__ p_ws_indices_global, const IndexDataType* const __restrict__ p_ws_indices_global,
IndexDataType* const __restrict__ p_indices_global) IndexDataType* const __restrict__ p_indices_global)
{ {
if constexpr(!NeedIndices) if constexpr(!NeedIndices)
{ {
GridwiseReduction::Run(in_grid_desc_m_k, constexpr bool IsSecondCall = true;
out_grid_desc_m,
in_elementwise_op, GridwiseReduction::template Run<IsSecondCall>(in_grid_desc_m_k,
acc_elementwise_op, out_grid_desc_m,
alpha, in_elementwise_op,
p_in_global, acc_elementwise_op,
beta, alpha,
p_out_global, p_in_global,
p_ws_indices_global, beta,
p_indices_global); p_out_global,
p_ws_indices_global,
p_indices_global);
} }
else else
{ {
...@@ -156,6 +161,11 @@ template <typename InDataType, ...@@ -156,6 +161,11 @@ template <typename InDataType,
index_t OutDstVectorSize> index_t OutDstVectorSize>
struct GridwiseReduction_mk_to_m_blockwise struct GridwiseReduction_mk_to_m_blockwise
{ {
static_assert(((InSrcVectorDim == 0 && MThreadSliceSize % InSrcVectorSize == 0) ||
(InSrcVectorDim == 1 && KThreadSliceSize % InSrcVectorSize == 0)) &&
(MThreadSliceSize % OutDstVectorSize == 0),
"Invalid thread slice sizes and/or vector sizes configuration, please check!");
static constexpr bool reorder_thread_cluster = (InSrcVectorDim == 0); static constexpr bool reorder_thread_cluster = (InSrcVectorDim == 0);
using ThreadClusterLengths_M_K = Sequence<MThreadClusterSize, KThreadClusterSize>; using ThreadClusterLengths_M_K = Sequence<MThreadClusterSize, KThreadClusterSize>;
...@@ -174,8 +184,7 @@ struct GridwiseReduction_mk_to_m_blockwise ...@@ -174,8 +184,7 @@ struct GridwiseReduction_mk_to_m_blockwise
static constexpr auto block_buf_desc_m_k = make_naive_tensor_descriptor_packed( static constexpr auto block_buf_desc_m_k = make_naive_tensor_descriptor_packed(
make_tuple(Number<MThreadClusterSize>{}, Number<KThreadClusterSize>{})); make_tuple(Number<MThreadClusterSize>{}, Number<KThreadClusterSize>{}));
template <typename T> using PassThroughOp = tensor_operation::element_wise::PassThrough;
using PassThroughOp = tensor_operation::element_wise::UnaryIdentic<T, T>;
static constexpr auto I0 = Number<0>{}; static constexpr auto I0 = Number<0>{};
static constexpr auto I1 = Number<1>{}; static constexpr auto I1 = Number<1>{};
...@@ -183,17 +192,24 @@ struct GridwiseReduction_mk_to_m_blockwise ...@@ -183,17 +192,24 @@ struct GridwiseReduction_mk_to_m_blockwise
static constexpr index_t M_BlockTileSize = MThreadClusterSize * MThreadSliceSize; static constexpr index_t M_BlockTileSize = MThreadClusterSize * MThreadSliceSize;
static constexpr index_t K_BlockTileSize = KThreadClusterSize * KThreadSliceSize; static constexpr index_t K_BlockTileSize = KThreadClusterSize * KThreadSliceSize;
template <bool IsSecondCall>
__device__ static void Run(const InGridDesc_M_K& in_grid_desc_m_k, __device__ static void Run(const InGridDesc_M_K& in_grid_desc_m_k,
const OutGridDesc_M& out_grid_desc_m, const OutGridDesc_M& out_grid_desc_m,
const InElementwiseOperation& in_elementwise_op, const InElementwiseOperation& in_elementwise_op,
const OutElementwiseOperation& acc_elementwise_op, const OutElementwiseOperation& acc_elementwise_op,
AccDataType alpha, AccDataType alpha,
const InDataType* const __restrict__ p_in_global, const InDataType* const __restrict__ p_in_global,
OutDataType beta, AccDataType beta,
OutDataType* const __restrict__ p_out_global, OutDataType* const __restrict__ p_out_global,
const IndexDataType* const __restrict__ p_ws_indices_global, const IndexDataType* const __restrict__ p_ws_indices_global,
IndexDataType* const __restrict__ p_indices_global) IndexDataType* const __restrict__ p_indices_global)
{ {
if constexpr(IsSecondCall)
{
static_assert(InSrcVectorDim == 1,
"InSrcVectorDim must be 1 for BlockwiseSecondCall, please check!");
};
using BlockwiseReduce = PartitionedBlockwiseReduction<AccDataType, using BlockwiseReduce = PartitionedBlockwiseReduction<AccDataType,
BlockSize, BlockSize,
ThreadClusterLengths_M_K, ThreadClusterLengths_M_K,
...@@ -345,7 +361,7 @@ struct GridwiseReduction_mk_to_m_blockwise ...@@ -345,7 +361,7 @@ struct GridwiseReduction_mk_to_m_blockwise
priorDstValueBuf); priorDstValueBuf);
static_for<0, MThreadSliceSize, 1>{}([&](auto I) { static_for<0, MThreadSliceSize, 1>{}([&](auto I) {
accu_value_buf(I) += type_convert<AccDataType>(priorDstValueBuf[I] * beta); accu_value_buf(I) += type_convert<AccDataType>(priorDstValueBuf[I]) * beta;
}); });
}; };
}; };
...@@ -355,7 +371,7 @@ struct GridwiseReduction_mk_to_m_blockwise ...@@ -355,7 +371,7 @@ struct GridwiseReduction_mk_to_m_blockwise
OutDataType, OutDataType,
decltype(reduced_data_desc), decltype(reduced_data_desc),
OutGridDesc_M, OutGridDesc_M,
PassThroughOp<AccDataType>, PassThroughOp,
Sequence<MThreadSliceSize>, Sequence<MThreadSliceSize>,
Sequence<0>, Sequence<0>,
0, 0,
...@@ -366,7 +382,7 @@ struct GridwiseReduction_mk_to_m_blockwise ...@@ -366,7 +382,7 @@ struct GridwiseReduction_mk_to_m_blockwise
out_grid_desc_m, out_grid_desc_m,
make_multi_index(block_global_1d_id * M_BlockTileSize + make_multi_index(block_global_1d_id * M_BlockTileSize +
thread_m_cluster_id * MThreadSliceSize), thread_m_cluster_id * MThreadSliceSize),
PassThroughOp<AccDataType>{}); PassThroughOp{});
threadwise_dst_store.Run( threadwise_dst_store.Run(
reduced_data_desc, make_tuple(I0), accu_value_buf, out_grid_desc_m, out_global_buf); reduced_data_desc, make_tuple(I0), accu_value_buf, out_grid_desc_m, out_global_buf);
...@@ -379,7 +395,7 @@ struct GridwiseReduction_mk_to_m_blockwise ...@@ -379,7 +395,7 @@ struct GridwiseReduction_mk_to_m_blockwise
const OutElementwiseOperation& acc_elementwise_op, const OutElementwiseOperation& acc_elementwise_op,
AccDataType alpha, AccDataType alpha,
const InDataType* const __restrict__ p_in_global, const InDataType* const __restrict__ p_in_global,
OutDataType beta, AccDataType beta,
OutDataType* const __restrict__ p_out_global, OutDataType* const __restrict__ p_out_global,
const IndexDataType* const __restrict__ p_ws_indices_global, const IndexDataType* const __restrict__ p_ws_indices_global,
IndexDataType* const __restrict__ p_indices_global) IndexDataType* const __restrict__ p_indices_global)
...@@ -570,7 +586,7 @@ struct GridwiseReduction_mk_to_m_blockwise ...@@ -570,7 +586,7 @@ struct GridwiseReduction_mk_to_m_blockwise
priorDstValueBuf); priorDstValueBuf);
static_for<0, MThreadSliceSize, 1>{}([&](auto I) { static_for<0, MThreadSliceSize, 1>{}([&](auto I) {
accu_value_buf(I) += type_convert<AccDataType>(priorDstValueBuf[I] * beta); accu_value_buf(I) += type_convert<AccDataType>(priorDstValueBuf[I]) * beta;
}); });
}; };
}; };
...@@ -580,7 +596,7 @@ struct GridwiseReduction_mk_to_m_blockwise ...@@ -580,7 +596,7 @@ struct GridwiseReduction_mk_to_m_blockwise
OutDataType, OutDataType,
decltype(reduced_data_desc), decltype(reduced_data_desc),
OutGridDesc_M, OutGridDesc_M,
PassThroughOp<AccDataType>, PassThroughOp,
Sequence<MThreadSliceSize>, Sequence<MThreadSliceSize>,
Sequence<0>, Sequence<0>,
0, 0,
...@@ -591,14 +607,14 @@ struct GridwiseReduction_mk_to_m_blockwise ...@@ -591,14 +607,14 @@ struct GridwiseReduction_mk_to_m_blockwise
out_grid_desc_m, out_grid_desc_m,
make_multi_index(block_global_1d_id * M_BlockTileSize + make_multi_index(block_global_1d_id * M_BlockTileSize +
thread_m_cluster_id * MThreadSliceSize), thread_m_cluster_id * MThreadSliceSize),
PassThroughOp<AccDataType>{}); PassThroughOp{});
auto threadwise_dst_idx_store = auto threadwise_dst_idx_store =
ThreadwiseTensorSliceTransfer_v1r3<IndexDataType, ThreadwiseTensorSliceTransfer_v1r3<IndexDataType,
IndexDataType, IndexDataType,
decltype(reduced_data_desc), decltype(reduced_data_desc),
OutGridDesc_M, OutGridDesc_M,
PassThroughOp<index_t>, PassThroughOp,
Sequence<MThreadSliceSize>, Sequence<MThreadSliceSize>,
Sequence<0>, Sequence<0>,
0, 0,
...@@ -609,7 +625,7 @@ struct GridwiseReduction_mk_to_m_blockwise ...@@ -609,7 +625,7 @@ struct GridwiseReduction_mk_to_m_blockwise
out_grid_desc_m, out_grid_desc_m,
make_multi_index(block_global_1d_id * M_BlockTileSize + make_multi_index(block_global_1d_id * M_BlockTileSize +
thread_m_cluster_id * MThreadSliceSize), thread_m_cluster_id * MThreadSliceSize),
PassThroughOp<index_t>{}); PassThroughOp{});
threadwise_dst_val_store.Run(reduced_data_desc, threadwise_dst_val_store.Run(reduced_data_desc,
make_tuple(I0), make_tuple(I0),
...@@ -631,11 +647,14 @@ struct GridwiseReduction_mk_to_m_blockwise ...@@ -631,11 +647,14 @@ struct GridwiseReduction_mk_to_m_blockwise
const OutElementwiseOperation acc_elementwise_op, const OutElementwiseOperation acc_elementwise_op,
AccDataType alpha, AccDataType alpha,
const InDataType* const __restrict__ p_ws_values_global, const InDataType* const __restrict__ p_ws_values_global,
OutDataType beta, AccDataType beta,
OutDataType* const __restrict__ p_out_global, OutDataType* const __restrict__ p_out_global,
const IndexDataType* const __restrict__ p_ws_indices_global, const IndexDataType* const __restrict__ p_ws_indices_global,
IndexDataType* const __restrict__ p_indices_global) IndexDataType* const __restrict__ p_indices_global)
{ {
static_assert(InSrcVectorDim == 1,
"InSrcVectorDim must be 1 for BlockwiseSecondCall, please check!");
using BlockwiseReduceWithIndex = using BlockwiseReduceWithIndex =
PartitionedBlockwiseReductionWithIndex<AccDataType, PartitionedBlockwiseReductionWithIndex<AccDataType,
IndexDataType, IndexDataType,
...@@ -841,7 +860,7 @@ struct GridwiseReduction_mk_to_m_blockwise ...@@ -841,7 +860,7 @@ struct GridwiseReduction_mk_to_m_blockwise
priorDstValueBuf); priorDstValueBuf);
static_for<0, MThreadSliceSize, 1>{}([&](auto I) { static_for<0, MThreadSliceSize, 1>{}([&](auto I) {
accu_value_buf(I) += type_convert<AccDataType>(priorDstValueBuf[I] * beta); accu_value_buf(I) += type_convert<AccDataType>(priorDstValueBuf[I]) * beta;
}); });
}; };
}; };
...@@ -851,7 +870,7 @@ struct GridwiseReduction_mk_to_m_blockwise ...@@ -851,7 +870,7 @@ struct GridwiseReduction_mk_to_m_blockwise
OutDataType, OutDataType,
decltype(reduced_data_desc), decltype(reduced_data_desc),
OutGridDesc_M, OutGridDesc_M,
PassThroughOp<AccDataType>, PassThroughOp,
Sequence<MThreadSliceSize>, Sequence<MThreadSliceSize>,
Sequence<0>, Sequence<0>,
0, 0,
...@@ -862,14 +881,14 @@ struct GridwiseReduction_mk_to_m_blockwise ...@@ -862,14 +881,14 @@ struct GridwiseReduction_mk_to_m_blockwise
out_grid_desc_m, out_grid_desc_m,
make_multi_index(block_global_1d_id * M_BlockTileSize + make_multi_index(block_global_1d_id * M_BlockTileSize +
thread_m_cluster_id * MThreadSliceSize), thread_m_cluster_id * MThreadSliceSize),
PassThroughOp<AccDataType>{}); PassThroughOp{});
auto threadwise_dst_idx_store = auto threadwise_dst_idx_store =
ThreadwiseTensorSliceTransfer_v1r3<IndexDataType, ThreadwiseTensorSliceTransfer_v1r3<IndexDataType,
IndexDataType, IndexDataType,
decltype(reduced_data_desc), decltype(reduced_data_desc),
OutGridDesc_M, OutGridDesc_M,
PassThroughOp<IndexDataType>, PassThroughOp,
Sequence<MThreadSliceSize>, Sequence<MThreadSliceSize>,
Sequence<0>, Sequence<0>,
0, 0,
...@@ -880,7 +899,7 @@ struct GridwiseReduction_mk_to_m_blockwise ...@@ -880,7 +899,7 @@ struct GridwiseReduction_mk_to_m_blockwise
out_grid_desc_m, out_grid_desc_m,
make_multi_index(block_global_1d_id * M_BlockTileSize + make_multi_index(block_global_1d_id * M_BlockTileSize +
thread_m_cluster_id * MThreadSliceSize), thread_m_cluster_id * MThreadSliceSize),
PassThroughOp<index_t>{}); PassThroughOp{});
threadwise_dst_val_store.Run(reduced_data_desc, threadwise_dst_val_store.Run(reduced_data_desc,
make_tuple(I0), make_tuple(I0),
......
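Several of the hunks above change `type_convert<AccDataType>(priorDstValueBuf[I] * beta)` into `type_convert<AccDataType>(priorDstValueBuf[I]) * beta` and keep `beta` in `AccDataType` rather than `OutDataType`. A minimal sketch (assuming int8 output and float accumulation; not library code) of why converting before the multiply matters for the `y = alpha * reduce(x) + beta * y_prev` blend:

```cpp
#include <cstdint>
#include <cstdio>

int main()
{
    using OutDataType = int8_t;
    using AccDataType = float;

    OutDataType prior = 100;  // previously stored output value
    float       beta  = 0.5f; // user-supplied blend factor

    // old style: beta narrowed to OutDataType first -> 0.5 truncates to 0,
    // so the prior output is silently dropped from the blend
    OutDataType beta_out   = static_cast<OutDataType>(beta);
    AccDataType old_result = static_cast<AccDataType>(prior * beta_out); // 0.0

    // new style: convert the prior value to AccDataType and keep beta wide
    AccDataType beta_acc   = static_cast<AccDataType>(beta);
    AccDataType new_result = static_cast<AccDataType>(prior) * beta_acc; // 50.0

    std::printf("old: %f  new: %f\n", old_result, new_result);
}
```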
...@@ -32,6 +32,7 @@ ...@@ -32,6 +32,7 @@
#include "reduction_functions_blockwise.hpp" #include "reduction_functions_blockwise.hpp"
#include "threadwise_tensor_slice_transfer.hpp" #include "threadwise_tensor_slice_transfer.hpp"
#include "element_wise_operation.hpp"
namespace ck { namespace ck {
...@@ -84,6 +85,11 @@ template <typename InDataType, ...@@ -84,6 +85,11 @@ template <typename InDataType,
index_t OutDstVectorSize> index_t OutDstVectorSize>
struct GridwiseReduction_mk_to_m_multiblock_atomic_add struct GridwiseReduction_mk_to_m_multiblock_atomic_add
{ {
static_assert(((InSrcVectorDim == 0 && MThreadSliceSize % InSrcVectorSize == 0) ||
(InSrcVectorDim == 1 && KThreadSliceSize % InSrcVectorSize == 0)) &&
(MThreadSliceSize % OutDstVectorSize == 0),
"Invalid thread slice sizes and/or vector sizes configuration, please check!");
static constexpr bool reorder_thread_cluster = (InSrcVectorDim == 0); static constexpr bool reorder_thread_cluster = (InSrcVectorDim == 0);
using ThreadClusterLengths_M_K = Sequence<MThreadClusterSize, KThreadClusterSize>; using ThreadClusterLengths_M_K = Sequence<MThreadClusterSize, KThreadClusterSize>;
...@@ -109,8 +115,7 @@ struct GridwiseReduction_mk_to_m_multiblock_atomic_add ...@@ -109,8 +115,7 @@ struct GridwiseReduction_mk_to_m_multiblock_atomic_add
ReduceOperation, ReduceOperation,
PropagateNan>; PropagateNan>;
template <typename T> using PassThroughOp = tensor_operation::element_wise::PassThrough;
using PassThroughOp = tensor_operation::element_wise::UnaryIdentic<T, T>;
static constexpr auto I0 = Number<0>{}; static constexpr auto I0 = Number<0>{};
static constexpr auto I1 = Number<1>{}; static constexpr auto I1 = Number<1>{};
...@@ -249,7 +254,7 @@ struct GridwiseReduction_mk_to_m_multiblock_atomic_add ...@@ -249,7 +254,7 @@ struct GridwiseReduction_mk_to_m_multiblock_atomic_add
OutDataType, OutDataType,
decltype(reduced_data_desc), decltype(reduced_data_desc),
OutGridDesc_M, OutGridDesc_M,
PassThroughOp<AccDataType>, PassThroughOp,
Sequence<MThreadSliceSize>, Sequence<MThreadSliceSize>,
Sequence<0>, Sequence<0>,
0, 0,
...@@ -260,7 +265,7 @@ struct GridwiseReduction_mk_to_m_multiblock_atomic_add ...@@ -260,7 +265,7 @@ struct GridwiseReduction_mk_to_m_multiblock_atomic_add
out_grid_desc_m, out_grid_desc_m,
make_multi_index(blkgroup_id * M_BlockTileSize + make_multi_index(blkgroup_id * M_BlockTileSize +
thread_m_cluster_id * MThreadSliceSize), thread_m_cluster_id * MThreadSliceSize),
PassThroughOp<AccDataType>{}); PassThroughOp{});
threadwise_dst_store.Run( threadwise_dst_store.Run(
reduced_data_desc, make_tuple(I0), accu_value_buf, out_grid_desc_m, out_global_buf); reduced_data_desc, make_tuple(I0), accu_value_buf, out_grid_desc_m, out_global_buf);
......
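These kernels also swap the templated `UnaryIdentic<T, T>` for the non-templated `PassThrough` element-wise op, so the transfer types can simply name `PassThroughOp` with no type argument. A rough approximation of such a pass-through functor (the library version uses explicit per-type overloads, as the `element_wise_operation.hpp` hunk above shows; the generic template here is only a sketch):

```cpp
#include <cstdint>
#include <cstdio>

struct PassThrough
{
    template <typename T>
    void operator()(T& y, const T& x) const { y = x; } // identity copy
};

int main()
{
    PassThrough op;
    float  f = 0.f;
    int8_t i = 0;
    op(f, 1.5f);
    op(i, static_cast<int8_t>(-3));
    std::printf("%f %d\n", f, static_cast<int>(i));
}
```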
...@@ -23,8 +23,8 @@ ...@@ -23,8 +23,8 @@
* SOFTWARE. * SOFTWARE.
* *
*******************************************************************************/ *******************************************************************************/
#ifndef CK_GRIDWISE_2D_REDUCTION_MULTIBLOCK_TWO_CALL_HPP #ifndef CK_GRIDWISE_2D_REDUCTION_MULTIBLOCK_PARTIAL_REDUCE_HPP
#define CK_GRIDWISE_2D_REDUCTION_MULTIBLOCK_TWO_CALL_HPP #define CK_GRIDWISE_2D_REDUCTION_MULTIBLOCK_PARTIAL_REDUCE_HPP
#include "reduction_common.hpp" #include "reduction_common.hpp"
#include "reduction_operator.hpp" #include "reduction_operator.hpp"
...@@ -32,6 +32,7 @@ ...@@ -32,6 +32,7 @@
#include "reduction_functions_blockwise.hpp" #include "reduction_functions_blockwise.hpp"
#include "threadwise_tensor_slice_transfer.hpp" #include "threadwise_tensor_slice_transfer.hpp"
#include "cluster_descriptor.hpp" #include "cluster_descriptor.hpp"
#include "element_wise_operation.hpp"
namespace ck { namespace ck {
...@@ -101,6 +102,12 @@ template <typename InDataType, ...@@ -101,6 +102,12 @@ template <typename InDataType,
index_t OutDstVectorSize> index_t OutDstVectorSize>
struct GridwiseReduction_mk_to_mk_multiblock_partial_reduce struct GridwiseReduction_mk_to_mk_multiblock_partial_reduce
{ {
static_assert((InSrcVectorDim == 0 && MThreadSliceSize % InSrcVectorSize == 0) ||
(InSrcVectorDim == 1 && KThreadSliceSize % InSrcVectorSize == 0),
"Invalid thread slice sizes and/or vector sizes configuration, please check!");
static_assert(OutDstVectorSize == 1, "OutDstVectorSize must be 1 for MultiBlockPartialReduce!");
static constexpr bool reorder_thread_cluster = (InSrcVectorDim == 0); static constexpr bool reorder_thread_cluster = (InSrcVectorDim == 0);
using ThreadClusterLengths_M_K = Sequence<MThreadClusterSize, KThreadClusterSize>; using ThreadClusterLengths_M_K = Sequence<MThreadClusterSize, KThreadClusterSize>;
...@@ -119,8 +126,7 @@ struct GridwiseReduction_mk_to_mk_multiblock_partial_reduce ...@@ -119,8 +126,7 @@ struct GridwiseReduction_mk_to_mk_multiblock_partial_reduce
static constexpr auto block_buf_desc_m_k = make_naive_tensor_descriptor_packed( static constexpr auto block_buf_desc_m_k = make_naive_tensor_descriptor_packed(
make_tuple(Number<MThreadClusterSize>{}, Number<KThreadClusterSize>{})); make_tuple(Number<MThreadClusterSize>{}, Number<KThreadClusterSize>{}));
template <typename T> using PassThroughOp = tensor_operation::element_wise::PassThrough;
using PassThroughOp = tensor_operation::element_wise::UnaryIdentic<T, T>;
static constexpr auto I0 = Number<0>{}; static constexpr auto I0 = Number<0>{};
static constexpr auto I1 = Number<1>{}; static constexpr auto I1 = Number<1>{};
...@@ -238,9 +244,6 @@ struct GridwiseReduction_mk_to_mk_multiblock_partial_reduce ...@@ -238,9 +244,6 @@ struct GridwiseReduction_mk_to_mk_multiblock_partial_reduce
reducedTiles++; reducedTiles++;
} while(reducedTiles < num_k_block_tile_iteration); } while(reducedTiles < num_k_block_tile_iteration);
constexpr auto reduced_data_desc = make_naive_tensor_descriptor_packed(
make_tuple(Number<MThreadSliceSize>{}, Number<1>{}));
// Each block executes multiple parallel reductions on the LDS, and due to the use of // Each block executes multiple parallel reductions on the LDS, and due to the use of
// vector_load, each block/thread is involved in multiple invariant dimensions. // vector_load, each block/thread is involved in multiple invariant dimensions.

static_for<0, MThreadSliceSize, 1>{}([&](auto I) { static_for<0, MThreadSliceSize, 1>{}([&](auto I) {
...@@ -254,6 +257,9 @@ struct GridwiseReduction_mk_to_mk_multiblock_partial_reduce ...@@ -254,6 +257,9 @@ struct GridwiseReduction_mk_to_mk_multiblock_partial_reduce
BlockwiseReduce::Reduce(block_reduce_buf, accu_value_buf(I)); BlockwiseReduce::Reduce(block_reduce_buf, accu_value_buf(I));
}); });
constexpr auto reduced_data_desc = make_naive_tensor_descriptor_packed(
make_tuple(Number<MThreadSliceSize>{}, Number<1>{}));
if(thread_k_cluster_id == 0) if(thread_k_cluster_id == 0)
{ {
auto threadwise_workspace_store = auto threadwise_workspace_store =
...@@ -261,7 +267,7 @@ struct GridwiseReduction_mk_to_mk_multiblock_partial_reduce ...@@ -261,7 +267,7 @@ struct GridwiseReduction_mk_to_mk_multiblock_partial_reduce
AccDataType, AccDataType,
decltype(reduced_data_desc), decltype(reduced_data_desc),
WorkspaceDesc_M_K, WorkspaceDesc_M_K,
PassThroughOp<AccDataType>, PassThroughOp,
Sequence<MThreadSliceSize, 1>, Sequence<MThreadSliceSize, 1>,
Sequence<0, 1>, Sequence<0, 1>,
1, 1,
...@@ -273,7 +279,7 @@ struct GridwiseReduction_mk_to_mk_multiblock_partial_reduce ...@@ -273,7 +279,7 @@ struct GridwiseReduction_mk_to_mk_multiblock_partial_reduce
make_multi_index(blkgroup_id * M_BlockTileSize + make_multi_index(blkgroup_id * M_BlockTileSize +
thread_m_cluster_id * MThreadSliceSize, thread_m_cluster_id * MThreadSliceSize,
block_local_id), block_local_id),
PassThroughOp<AccDataType>{}); PassThroughOp{});
threadwise_workspace_store.Run(reduced_data_desc, threadwise_workspace_store.Run(reduced_data_desc,
make_tuple(I0, I0), make_tuple(I0, I0),
...@@ -450,7 +456,7 @@ struct GridwiseReduction_mk_to_mk_multiblock_partial_reduce ...@@ -450,7 +456,7 @@ struct GridwiseReduction_mk_to_mk_multiblock_partial_reduce
AccDataType, AccDataType,
decltype(reduced_data_desc), decltype(reduced_data_desc),
WorkspaceDesc_M_K, WorkspaceDesc_M_K,
PassThroughOp<AccDataType>, PassThroughOp,
Sequence<MThreadSliceSize, 1>, Sequence<MThreadSliceSize, 1>,
Sequence<0, 1>, Sequence<0, 1>,
1, 1,
...@@ -462,14 +468,14 @@ struct GridwiseReduction_mk_to_mk_multiblock_partial_reduce ...@@ -462,14 +468,14 @@ struct GridwiseReduction_mk_to_mk_multiblock_partial_reduce
make_multi_index(blkgroup_id * M_BlockTileSize + make_multi_index(blkgroup_id * M_BlockTileSize +
thread_m_cluster_id * MThreadSliceSize, thread_m_cluster_id * MThreadSliceSize,
block_local_id), block_local_id),
PassThroughOp<AccDataType>{}); PassThroughOp{});
auto threadwise_workspace_idx_store = auto threadwise_workspace_idx_store =
ThreadwiseTensorSliceTransfer_v1r3<IndexDataType, ThreadwiseTensorSliceTransfer_v1r3<IndexDataType,
IndexDataType, IndexDataType,
decltype(reduced_data_desc), decltype(reduced_data_desc),
WorkspaceDesc_M_K, WorkspaceDesc_M_K,
PassThroughOp<IndexDataType>, PassThroughOp,
Sequence<MThreadSliceSize, 1>, Sequence<MThreadSliceSize, 1>,
Sequence<0, 1>, Sequence<0, 1>,
1, 1,
...@@ -481,7 +487,7 @@ struct GridwiseReduction_mk_to_mk_multiblock_partial_reduce ...@@ -481,7 +487,7 @@ struct GridwiseReduction_mk_to_mk_multiblock_partial_reduce
make_multi_index(blkgroup_id * M_BlockTileSize + make_multi_index(blkgroup_id * M_BlockTileSize +
thread_m_cluster_id * MThreadSliceSize, thread_m_cluster_id * MThreadSliceSize,
block_local_id), block_local_id),
PassThroughOp<IndexDataType>{}); PassThroughOp{});
threadwise_workspace_val_store.Run(reduced_data_desc, threadwise_workspace_val_store.Run(reduced_data_desc,
make_tuple(I0, I0), make_tuple(I0, I0),
......
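A plain-C++ sketch of the two-stage scheme this file implements on the GPU: each block group writes one partial result per (m, block) pair into the workspace, and the blockwise "second call" kernel then reduces the workspace along the block dimension. A summation stands in for the generic reduce op, and the loops below are only an illustration of the data flow:

```cpp
#include <algorithm>
#include <cstdio>
#include <vector>

int main()
{
    const int M = 4, K = 1000, blkGroupSize = 8;
    const int kPerBlock = (K + blkGroupSize - 1) / blkGroupSize;

    std::vector<float> in(M * K, 1.0f);
    std::vector<float> workspace(M * blkGroupSize, 0.0f); // stage-1 output

    // stage 1: multiblock partial reduce (one partial value per block group)
    for(int m = 0; m < M; ++m)
        for(int b = 0; b < blkGroupSize; ++b)
            for(int k = b * kPerBlock; k < std::min(K, (b + 1) * kPerBlock); ++k)
                workspace[m * blkGroupSize + b] += in[m * K + k];

    // stage 2: the blockwise "second call" reduces the workspace along the block dim
    std::vector<float> out(M, 0.0f);
    for(int m = 0; m < M; ++m)
        for(int b = 0; b < blkGroupSize; ++b)
            out[m] += workspace[m * blkGroupSize + b];

    std::printf("out[0] = %f (expected %d)\n", out[0], K);
}
```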
...@@ -31,6 +31,7 @@ ...@@ -31,6 +31,7 @@
#include "reduction_operator.hpp" #include "reduction_operator.hpp"
#include "reduction_functions_accumulate.hpp" #include "reduction_functions_accumulate.hpp"
#include "threadwise_tensor_slice_transfer.hpp" #include "threadwise_tensor_slice_transfer.hpp"
#include "element_wise_operation.hpp"
namespace ck { namespace ck {
...@@ -50,7 +51,7 @@ __global__ void kernel_reduce_threadwise(const InGridDesc_M_K in_grid_desc_m_k, ...@@ -50,7 +51,7 @@ __global__ void kernel_reduce_threadwise(const InGridDesc_M_K in_grid_desc_m_k,
const AccElementwiseOperation acc_elementwise_op, const AccElementwiseOperation acc_elementwise_op,
AccDataType alpha, AccDataType alpha,
const InDataType* const __restrict__ p_in_global, const InDataType* const __restrict__ p_in_global,
OutDataType beta, AccDataType beta,
OutDataType* const __restrict__ p_out_global, OutDataType* const __restrict__ p_out_global,
IndexDataType* const __restrict__ p_indices_global) IndexDataType* const __restrict__ p_indices_global)
{ {
...@@ -101,11 +102,15 @@ template <typename InDataType, ...@@ -101,11 +102,15 @@ template <typename InDataType,
index_t OutDstVectorSize> index_t OutDstVectorSize>
struct GridwiseReduction_mk_to_m_threadwise struct GridwiseReduction_mk_to_m_threadwise
{ {
static_assert(((InSrcVectorDim == 0 && MThreadSliceSize % InSrcVectorSize == 0) ||
(InSrcVectorDim == 1 && KThreadSliceSize % InSrcVectorSize == 0)) &&
(MThreadSliceSize % OutDstVectorSize == 0),
"Invalid thread slice sizes and/or vector sizes configuration, please check!");
using ThreadBufferDimAccessOrder = using ThreadBufferDimAccessOrder =
typename conditional<InSrcVectorDim == 0, Sequence<1, 0>, Sequence<0, 1>>::type; typename conditional<InSrcVectorDim == 0, Sequence<1, 0>, Sequence<0, 1>>::type;
template <typename T> using PassThroughOp = tensor_operation::element_wise::PassThrough;
using PassThroughOp = tensor_operation::element_wise::UnaryIdentic<T, T>;
static constexpr auto I0 = Number<0>{}; static constexpr auto I0 = Number<0>{};
...@@ -115,7 +120,7 @@ struct GridwiseReduction_mk_to_m_threadwise ...@@ -115,7 +120,7 @@ struct GridwiseReduction_mk_to_m_threadwise
const AccElementwiseOperation& acc_elementwise_op, const AccElementwiseOperation& acc_elementwise_op,
AccDataType alpha, AccDataType alpha,
const InDataType* const __restrict__ p_in_global, const InDataType* const __restrict__ p_in_global,
OutDataType beta, AccDataType beta,
OutDataType* const __restrict__ p_out_global, OutDataType* const __restrict__ p_out_global,
IndexDataType* const __restrict__ p_indices_global) IndexDataType* const __restrict__ p_indices_global)
{ {
...@@ -228,7 +233,7 @@ struct GridwiseReduction_mk_to_m_threadwise ...@@ -228,7 +233,7 @@ struct GridwiseReduction_mk_to_m_threadwise
priorDstValue_buf); priorDstValue_buf);
static_for<0, MThreadSliceSize, 1>{}([&](auto I) { static_for<0, MThreadSliceSize, 1>{}([&](auto I) {
accu_value_buf(I) += type_convert<AccDataType>(priorDstValue_buf[I] * beta); accu_value_buf(I) += type_convert<AccDataType>(priorDstValue_buf[I]) * beta;
}); });
}; };
}; };
...@@ -238,7 +243,7 @@ struct GridwiseReduction_mk_to_m_threadwise ...@@ -238,7 +243,7 @@ struct GridwiseReduction_mk_to_m_threadwise
OutDataType, OutDataType,
decltype(reduced_data_desc), decltype(reduced_data_desc),
OutGridDesc_M, OutGridDesc_M,
PassThroughOp<AccDataType>, PassThroughOp,
Sequence<MThreadSliceSize>, Sequence<MThreadSliceSize>,
Sequence<0>, Sequence<0>,
0, 0,
...@@ -248,7 +253,7 @@ struct GridwiseReduction_mk_to_m_threadwise ...@@ -248,7 +253,7 @@ struct GridwiseReduction_mk_to_m_threadwise
false>( false>(
out_grid_desc_m, out_grid_desc_m,
make_multi_index(thread_global_1d_id * MThreadSliceSize), make_multi_index(thread_global_1d_id * MThreadSliceSize),
PassThroughOp<AccDataType>{}); PassThroughOp{});
threadwise_dst_store.Run( threadwise_dst_store.Run(
reduced_data_desc, make_tuple(I0), accu_value_buf, out_grid_desc_m, dst_global_buf); reduced_data_desc, make_tuple(I0), accu_value_buf, out_grid_desc_m, dst_global_buf);
...@@ -260,7 +265,7 @@ struct GridwiseReduction_mk_to_m_threadwise ...@@ -260,7 +265,7 @@ struct GridwiseReduction_mk_to_m_threadwise
const AccElementwiseOperation& acc_elementwise_op, const AccElementwiseOperation& acc_elementwise_op,
AccDataType alpha, AccDataType alpha,
const InDataType* const __restrict__ p_in_global, const InDataType* const __restrict__ p_in_global,
OutDataType beta, AccDataType beta,
OutDataType* const __restrict__ p_out_global, OutDataType* const __restrict__ p_out_global,
IndexDataType* const __restrict__ p_indices_global) IndexDataType* const __restrict__ p_indices_global)
{ {
...@@ -387,7 +392,7 @@ struct GridwiseReduction_mk_to_m_threadwise ...@@ -387,7 +392,7 @@ struct GridwiseReduction_mk_to_m_threadwise
priorDstValue_buf); priorDstValue_buf);
static_for<0, MThreadSliceSize, 1>{}([&](auto I) { static_for<0, MThreadSliceSize, 1>{}([&](auto I) {
accu_value_buf(I) += type_convert<AccDataType>(priorDstValue_buf[I] * beta); accu_value_buf(I) += type_convert<AccDataType>(priorDstValue_buf[I]) * beta;
}); });
}; };
}; };
...@@ -397,7 +402,7 @@ struct GridwiseReduction_mk_to_m_threadwise ...@@ -397,7 +402,7 @@ struct GridwiseReduction_mk_to_m_threadwise
OutDataType, OutDataType,
decltype(reduced_data_desc), decltype(reduced_data_desc),
OutGridDesc_M, OutGridDesc_M,
PassThroughOp<AccDataType>, PassThroughOp,
Sequence<MThreadSliceSize>, Sequence<MThreadSliceSize>,
Sequence<0>, Sequence<0>,
0, 0,
...@@ -407,14 +412,14 @@ struct GridwiseReduction_mk_to_m_threadwise ...@@ -407,14 +412,14 @@ struct GridwiseReduction_mk_to_m_threadwise
false>( false>(
out_grid_desc_m, out_grid_desc_m,
make_multi_index(thread_global_1d_id * MThreadSliceSize), make_multi_index(thread_global_1d_id * MThreadSliceSize),
PassThroughOp<AccDataType>{}); PassThroughOp{});
auto threadwise_dst_idx_store = auto threadwise_dst_idx_store =
ThreadwiseTensorSliceTransfer_v1r3<IndexDataType, ThreadwiseTensorSliceTransfer_v1r3<IndexDataType,
IndexDataType, IndexDataType,
decltype(reduced_data_desc), decltype(reduced_data_desc),
OutGridDesc_M, OutGridDesc_M,
PassThroughOp<IndexDataType>, PassThroughOp,
Sequence<MThreadSliceSize>, Sequence<MThreadSliceSize>,
Sequence<0>, Sequence<0>,
0, 0,
...@@ -424,7 +429,7 @@ struct GridwiseReduction_mk_to_m_threadwise ...@@ -424,7 +429,7 @@ struct GridwiseReduction_mk_to_m_threadwise
false>( false>(
out_grid_desc_m, out_grid_desc_m,
make_multi_index(thread_global_1d_id * MThreadSliceSize), make_multi_index(thread_global_1d_id * MThreadSliceSize),
PassThroughOp<IndexDataType>{}); PassThroughOp{});
threadwise_dst_val_store.Run( threadwise_dst_val_store.Run(
reduced_data_desc, make_tuple(I0), accu_value_buf, out_grid_desc_m, out_global_val_buf); reduced_data_desc, make_tuple(I0), accu_value_buf, out_grid_desc_m, out_global_val_buf);
......
...@@ -79,6 +79,8 @@ struct ThreadwiseTensorSliceTransfer_v1r3 ...@@ -79,6 +79,8 @@ struct ThreadwiseTensorSliceTransfer_v1r3
{ {
static_assert(SrcDesc::IsKnownAtCompileTime(), static_assert(SrcDesc::IsKnownAtCompileTime(),
"wrong! SrcDesc need to known at compile-time"); "wrong! SrcDesc need to known at compile-time");
static_assert(SliceLengths::At(Number<DstVectorDim>{}) % DstScalarPerVector == 0,
"wrong! Not divisible");
} }
__device__ void SetDstSliceOrigin(const DstDesc& dst_desc, const Index& dst_slice_origin_idx) __device__ void SetDstSliceOrigin(const DstDesc& dst_desc, const Index& dst_slice_origin_idx)
...@@ -250,6 +252,8 @@ struct ThreadwiseTensorSliceTransfer_v2 ...@@ -250,6 +252,8 @@ struct ThreadwiseTensorSliceTransfer_v2
{ {
static_assert(DstDesc::IsKnownAtCompileTime(), static_assert(DstDesc::IsKnownAtCompileTime(),
"wrong! SrcDesc need to known at compile-time"); "wrong! SrcDesc need to known at compile-time");
static_assert(SliceLengths::At(Number<SrcVectorDim>{}) % SrcScalarPerVector == 0,
"wrong! Not divisible");
} }
__device__ void SetSrcSliceOrigin(const SrcDesc& src_desc, const Index& src_slice_origin_idx) __device__ void SetSrcSliceOrigin(const SrcDesc& src_desc, const Index& src_slice_origin_idx)
...@@ -313,7 +317,8 @@ struct ThreadwiseTensorSliceTransfer_v2 ...@@ -313,7 +317,8 @@ struct ThreadwiseTensorSliceTransfer_v2
dst_desc.CalculateOffset(to_multi_index(dst_slice_origin_idx) + src_data_idx + dst_desc.CalculateOffset(to_multi_index(dst_slice_origin_idx) + src_data_idx +
i * src_scalar_step_in_vector); i * src_scalar_step_in_vector);
dst_buf(Number<dst_offset>{}) = src_vector.template AsType<SrcData>()[i]; dst_buf(Number<dst_offset>{}) =
type_convert<DstData>(src_vector.template AsType<SrcData>()[i]);
}); });
if constexpr(idx_1d.value != num_access - 1) if constexpr(idx_1d.value != num_access - 1)
...@@ -439,6 +444,10 @@ struct ThreadwiseTensorSliceTransfer_v3 ...@@ -439,6 +444,10 @@ struct ThreadwiseTensorSliceTransfer_v3
: src_coord_(make_tensor_coordinate(src_desc, src_slice_origin)), : src_coord_(make_tensor_coordinate(src_desc, src_slice_origin)),
dst_coord_(make_tensor_coordinate(dst_desc, dst_slice_origin)) dst_coord_(make_tensor_coordinate(dst_desc, dst_slice_origin))
{ {
static_assert(SliceLengths::At(Number<SrcVectorDim>{}) % SrcScalarPerVector == 0,
"wrong! Not divisible");
static_assert(SliceLengths::At(Number<DstVectorDim>{}) % DstScalarPerVector == 0,
"wrong! Not divisible");
} }
__device__ void SetSrcSliceOrigin(const SrcDesc& src_desc, const Index& src_slice_origin_idx) __device__ void SetSrcSliceOrigin(const SrcDesc& src_desc, const Index& src_slice_origin_idx)
...@@ -1016,7 +1025,8 @@ struct ThreadwiseTensorSliceTransfer_v4 ...@@ -1016,7 +1025,8 @@ struct ThreadwiseTensorSliceTransfer_v4
static_assert(SrcDesc::IsKnownAtCompileTime() && DstDesc::IsKnownAtCompileTime(), static_assert(SrcDesc::IsKnownAtCompileTime() && DstDesc::IsKnownAtCompileTime(),
"wrong! SrcDesc and DstDesc need to known at compile-time"); "wrong! SrcDesc and DstDesc need to known at compile-time");
static_assert(SliceLengths::At(Number<SrcVectorDim>{}) % SrcScalarPerVector == 0, "wrong!"); static_assert(SliceLengths::At(Number<SrcVectorDim>{}) % SrcScalarPerVector == 0,
"wrong! Not divisible");
} }
template <typename SrcRefToOriginDisplacement, template <typename SrcRefToOriginDisplacement,
......
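The new `static_assert`s in the transfer constructors reject configurations whose slice length along the vectorized dimension is not a multiple of the scalars-per-vector, turning a silently wrong transfer into a compile error. An illustrative compile-time guard with plain integer parameters in place of the library's `Sequence`/`Number` machinery:

```cpp
template <int SliceLengthAlongVectorDim, int ScalarPerVector>
struct CheckVectorizedTransfer
{
    static_assert(SliceLengthAlongVectorDim % ScalarPerVector == 0,
                  "wrong! Not divisible");
};

int main()
{
    CheckVectorizedTransfer<8, 4> ok{}; // compiles: 8 is a multiple of 4
    // CheckVectorizedTransfer<8, 3> bad{}; // would fail with "wrong! Not divisible"
    (void)ok;
    return 0;
}
```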
...@@ -637,19 +637,19 @@ __device__ void amd_buffer_store_impl(const typename vector_type<T, N>::type src ...@@ -637,19 +637,19 @@ __device__ void amd_buffer_store_impl(const typename vector_type<T, N>::type src
} }
else if constexpr(N == 2) else if constexpr(N == 2)
{ {
llvm_amdgcn_raw_buffer_store_fp16x2(src_thread_data, llvm_amdgcn_raw_buffer_store_i16x2(src_thread_data,
dst_wave_buffer_resource, dst_wave_buffer_resource,
dst_thread_addr_offset, dst_thread_addr_offset,
dst_wave_addr_offset, dst_wave_addr_offset,
0); 0);
} }
else if constexpr(N == 4) else if constexpr(N == 4)
{ {
llvm_amdgcn_raw_buffer_store_fp16x4(src_thread_data, llvm_amdgcn_raw_buffer_store_i16x4(src_thread_data,
dst_wave_buffer_resource, dst_wave_buffer_resource,
dst_thread_addr_offset, dst_thread_addr_offset,
dst_wave_addr_offset, dst_wave_addr_offset,
0); 0);
} }
else if constexpr(N == 8) else if constexpr(N == 8)
{ {
......
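Switching the `bhalf_t` path from the `fp16x2`/`fp16x4` to the `i16x2`/`i16x4` buffer-store intrinsics works because the library represents bfloat16 as a `ushort` bit pattern, so the store only has to move 16-bit payloads. A host-side sketch of that bit-level view (truncation-only conversion, a simplification rather than the library's exact rounding):

```cpp
#include <cstdint>
#include <cstdio>
#include <cstring>

uint16_t float_to_bf16_bits(float x)
{
    uint32_t u;
    std::memcpy(&u, &x, sizeof(u));
    return static_cast<uint16_t>(u >> 16); // keep sign, exponent, top 7 mantissa bits
}

float bf16_bits_to_float(uint16_t b)
{
    uint32_t u = static_cast<uint32_t>(b) << 16;
    float x;
    std::memcpy(&x, &u, sizeof(x));
    return x;
}

int main()
{
    float    v    = 3.14159f;
    uint16_t bits = float_to_bf16_bits(v); // what travels through the 16-bit stores
    std::printf("%f -> 0x%04x -> %f\n", v, static_cast<unsigned>(bits), bf16_bits_to_float(bits));
}
```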
...@@ -606,6 +606,12 @@ struct sequence_map_inverse ...@@ -606,6 +606,12 @@ struct sequence_map_inverse
SeqMap::Size()>::type; SeqMap::Size()>::type;
}; };
template <index_t... Xs, index_t... Ys>
__host__ __device__ constexpr bool operator==(Sequence<Xs...>, Sequence<Ys...>)
{
return ((Xs == Ys) && ...);
}
template <index_t... Xs, index_t... Ys> template <index_t... Xs, index_t... Ys>
__host__ __device__ constexpr auto operator+(Sequence<Xs...>, Sequence<Ys...>) __host__ __device__ constexpr auto operator+(Sequence<Xs...>, Sequence<Ys...>)
{ {
......
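The added `operator==` compares two compile-time `Sequence`s element by element with a C++17 fold expression. A tiny self-contained sketch (a stand-in `Sequence`, not `ck::Sequence`) showing the comparison at compile time:

```cpp
#include <cstddef>

using index_t = int;

template <index_t... Is>
struct Sequence
{
    static constexpr std::size_t Size() { return sizeof...(Is); }
};

template <index_t... Xs, index_t... Ys>
constexpr bool operator==(Sequence<Xs...>, Sequence<Ys...>)
{
    return ((Xs == Ys) && ...); // true for empty sequences, element-wise otherwise
}

static_assert(Sequence<1, 2, 3>{} == Sequence<1, 2, 3>{});
static_assert(!(Sequence<1, 2, 3>{} == Sequence<1, 2, 4>{}));

int main() { return 0; }
```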
...@@ -37,6 +37,10 @@ struct SpaceFillingCurve ...@@ -37,6 +37,10 @@ struct SpaceFillingCurve
__host__ __device__ static constexpr index_t GetNumOfAccess() __host__ __device__ static constexpr index_t GetNumOfAccess()
{ {
static_assert(TensorLengths::Size() == ScalarsPerAccess::Size());
static_assert(TensorLengths{} % ScalarsPerAccess{} ==
typename uniform_sequence_gen<TensorLengths::Size(), 0>::type{});
return reduce_on_sequence(TensorLengths{}, math::multiplies{}, Number<1>{}) / return reduce_on_sequence(TensorLengths{}, math::multiplies{}, Number<1>{}) /
ScalarPerVector; ScalarPerVector;
} }
......
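The new `static_assert`s require the per-dimension access widths to match the tensor rank and to divide the corresponding lengths, so `GetNumOfAccess()` is an exact quotient. A runtime sketch of the same arithmetic (assuming `ScalarPerVector` is the product of the per-dimension scalars per access; the library does this on compile-time Sequences):

```cpp
#include <cassert>
#include <cstddef>
#include <cstdio>
#include <vector>

int num_of_access(const std::vector<int>& lengths,
                  const std::vector<int>& scalars_per_access)
{
    assert(lengths.size() == scalars_per_access.size()); // mirrors the rank check

    int total = 1, per_access = 1;
    for(std::size_t i = 0; i < lengths.size(); ++i)
    {
        assert(lengths[i] % scalars_per_access[i] == 0); // mirrors the divisibility check
        total      *= lengths[i];
        per_access *= scalars_per_access[i];
    }
    return total / per_access;
}

int main()
{
    // e.g. a 4x8 tile traversed with 1x4 vectors -> 8 accesses
    std::printf("%d\n", num_of_access({4, 8}, {1, 4}));
}
```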