Unverified commit 63eee2d9, authored by Qianfeng, committed by GitHub

Overhaul of Reduction and its dependents (#237)

* Tiny fix in dynamic_buffer.hpp to support vectorized AtomicAdd for double type

* Update to host layer and host reduction

* Merge and remove reduction kernels

* Merge and remove reduction device interfaces and update pooling device interface

* Merge and remove useless reduction device instances

* Update to reduction profiler and reduction ctests

* Update to reduction and pooling examples and add one reduction example

* Change reduction examples so they can be tested by ctest

* Add explicit pass checking for reduction and pooling examples

* Explicit assignment of tensor shapes in example reduce_blockwise_two_call

* Use atomic_add to replace atomicAdd and add atomic_add for the double type

* Add reduce ctest support for double data type

* Replace to_int_vector() with C++ std::vector::assign() (a minimal sketch follows this commit message)

* Keep DeviceReduceThreadWise separated from DeviceReduceBlockWise

* Merge DeviceReduceBlockWise and DeviceReduceMultiBlockAtomicAdd into DeviceReduceMultiBlock

* Add GetAtomicOperationZeroValue() support for AtomicMax

* Tiny change to reduce example README.md

* Fix some tiny issues due to branch merging

* Revert the previous change in dynamic_buffer.hpp and add atomic_add for double2_t

* Add reduce multiblock_atomic_add instances for fp64 to verify vectorized atomic_add on fp64

* Renaming

* Clean up the header includes in the device_reduce instance header files
parent 1085794d
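One of the bullets above drops the `to_int_vector()` helper in favor of `std::vector::assign()`. A minimal sketch of that idiom, assuming `ck::index_t` is `int32_t` as in the CK headers (the other names below are placeholders, not CK code):

```cpp
#include <cstddef>
#include <cstdint>
#include <vector>

using index_t = int32_t; // stand-in for ck::index_t

int main()
{
    // Tensor lengths are kept as size_t on the host side (as in SimpleAppArgs).
    const std::vector<std::size_t> inLengths = {16, 64, 32, 960};

    // assign() converts each size_t length to index_t element by element,
    // which is all the removed to_int_vector() helper did.
    std::vector<index_t> i_inLengths;
    i_inLengths.assign(inLengths.begin(), inLengths.end());

    return i_inLengths.size() == inLengths.size() ? 0 : 1;
}
```

The same `assign()` pattern shows up in the updated examples below wherever the `size_t` host shapes are handed to `MakeArgumentPointer()`.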
-add_example_executable(example_reduce_blockwise reduce_blockwise.cpp -D 16,64,32,960 -v 1 1 10)
+add_example_executable(example_reduce_blockwise reduce_blockwise.cpp)
+add_example_executable(example_reduce_blockwise_two_call reduce_blockwise_two_call.cpp)
@@ -5,23 +5,38 @@
 # -D <xxx> : input 4-d tensor lengths
 # -v <x> : verification (0=no, 1=yes)
 #arg1: initialization (0=no init, 1=single integer value, 2=scope integer value, 3=decimal value)
-#arg2: run kernel # of times (>1)
+#arg2: time kernel (0=no, 1=yes)
-./bin/example_reduce_blockwise -D 16,64,32,960 -v 1 1 10
+./bin/example_reduce_blockwise -D 16,64,32,960 -v 1 1 1
 ```
 Result
 ```
+./bin/example_reduce_blockwise -D 16,64,32,960 -v 1 1 1
 launch_and_time_kernel: grid_dim {240, 1, 1}, block_dim {256, 1, 1}
-Warm up
-Start running 3 times...
-Perf: 0.23536 ms, 267.32 GB/s, DeviceReduceBlockWise<256,M_C4_S1,K_C64_S1,InSrcVectorDim_0_InSrcVectorSize_1_OutDstVectorSize_1>
-error: 0
-max_diff: 0, 529, 529
-root@dc-smc-18:/data/composable_kernel/Build3# bin/example_reduce_blockwise -D 16,64,32,960 -v 1 1 10
-launch_and_time_kernel: grid_dim {240, 1, 1}, block_dim {256, 1, 1}
-Warm up
-Start running 10 times...
-Perf: 0.23392 ms, 268.966 GB/s, DeviceReduceBlockWise<256,M_C4_S1,K_C64_S1,InSrcVectorDim_0_InSrcVectorSize_1_OutDstVectorSize_1>
-error: 0
-max_diff: 0, 528, 528
+Warm up 1 time
+Start running 10 times...
+Perf: 0.282592 ms, 222.641 GB/s, DeviceReduceBlockWise<256,M_C4_S1,K_C64_S1,InSrcVectorDim_0_InSrcVectorSize_1_OutDstVectorSize_1>
+```
+# Instructions for ```example_reduce_blockwise_two_call```
+## Run ```example_reduce_blockwise_two_call```
+```bash
+#arg1: verification (0=no, 1=yes)
+#arg2: initialization (0=no init, 1=single integer value, 2=scope integer value, 3=decimal value)
+#arg3: time kernel (0=no, 1=yes)
+./bin/example_reduce_blockwise_two_call 1 2 1
+```
+Result
+```
+./bin/example_reduce_blockwise_two_call 1 2 1
+launch_and_time_kernel: grid_dim {204800, 1, 1}, block_dim {256, 1, 1}
+Warm up 1 time
+Start running 10 times...
+launch_and_time_kernel: grid_dim {6400, 1, 1}, block_dim {256, 1, 1}
+Warm up 1 time
+Start running 10 times...
+Perf: 2.1791 ms, 771.42 GB/s, DeviceReduceBlockWise<256,M_C32_S1,K_C8_S1,InSrcVectorDim_1_InSrcVectorSize_1_OutDstVectorSize_1> => DeviceReduceBlockWise<256,M_C256_S1,K_C1_S1,InSrcVectorDim_1_InSrcVectorSize_1_OutDstVectorSize_1>
 ```
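To make the two-call example above concrete: it computes a NORM2 reduction over the last two dimensions of a 5-D tensor by chaining two single-dimension device reductions, squaring the inputs in the first call and taking the square root in the second. A host-side sketch of the same arithmetic, with made-up small shapes in place of the example's {64, 320, 80, 4, 128}:

```cpp
#include <cmath>
#include <iostream>
#include <vector>

int main()
{
    const int M = 3, K1 = 4, K2 = 5; // invariant length and the two reduced lengths

    std::vector<float> in(M * K1 * K2, 0.5f);

    // Stage 1 (mirrors reduce_1): reduce the innermost dimension, squaring each
    // element on the way in (the InElementwiseOperation of a NORM2 reduction).
    std::vector<float> partial(M * K1, 0.0f);
    for(int m = 0; m < M; ++m)
        for(int k1 = 0; k1 < K1; ++k1)
            for(int k2 = 0; k2 < K2; ++k2)
            {
                const float v = in[(m * K1 + k1) * K2 + k2];
                partial[m * K1 + k1] += v * v;
            }

    // Stage 2 (mirrors reduce_2): reduce the remaining dimension and apply sqrt
    // at the end (the AccElementwiseOperation of a NORM2 reduction).
    std::vector<float> out(M, 0.0f);
    for(int m = 0; m < M; ++m)
    {
        for(int k1 = 0; k1 < K1; ++k1)
            out[m] += partial[m * K1 + k1];
        out[m] = std::sqrt(out[m]);
    }

    std::cout << out[0] << std::endl; // sqrt(4 * 5 * 0.25) = sqrt(5) ≈ 2.236
    return 0;
}
```

Splitting the square and the square root across the two calls lets the intermediate tensor hold plain partial sums, so the second reduction can consume it unchanged.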
@@ -12,8 +12,8 @@
#include "host_tensor_generator.hpp" #include "host_tensor_generator.hpp"
#include "device_tensor.hpp" #include "device_tensor.hpp"
#include "device_base.hpp" #include "device_base.hpp"
#include "device_reduce_blockwise.hpp" #include "device_reduce_multiblock.hpp"
#include "host_reduce_util.hpp" #include "host_common_util.hpp"
#include "host_reduction.hpp" #include "host_reduction.hpp"
#include "reduction_enums.hpp" #include "reduction_enums.hpp"
@@ -30,9 +30,8 @@ constexpr int Rank = 4;
constexpr int NumReduceDim = 3; constexpr int NumReduceDim = 3;
constexpr ReduceTensorOp ReduceOpId = ReduceTensorOp::NORM2; constexpr ReduceTensorOp ReduceOpId = ReduceTensorOp::NORM2;
constexpr NanPropagation NanOpt = NanPropagation::PROPAGATE_NAN; constexpr bool PropagateNan = true;
constexpr bool PropagateNan = (NanOpt == NanPropagation::NOT_PROPAGATE_NAN) ? false : true; constexpr bool OutputIndex = false;
constexpr ReduceTensorIndices IndicesOpt = ReduceTensorIndices::NO_INDICES;
using ReduceOperation = typename reduce_binary_operator<AccDataType, ReduceOpId>::opType; using ReduceOperation = typename reduce_binary_operator<AccDataType, ReduceOpId>::opType;
using InElementwiseOperation = using InElementwiseOperation =
@@ -40,85 +39,44 @@ using InElementwiseOperation =
using AccElementwiseOperation = using AccElementwiseOperation =
typename reduce_unary_operator<AccDataType, ReduceOpId, true, true>::AccElementwiseOperation; typename reduce_unary_operator<AccDataType, ReduceOpId, true, true>::AccElementwiseOperation;
using DeviceReduceInstance = DeviceReduceBlockWise<InDataType, using DeviceReduceInstance = DeviceReduceMultiBlock<InDataType,
AccDataType, AccDataType,
OutDataType, OutDataType,
Rank, Rank,
NumReduceDim, NumReduceDim,
ReduceOperation, ReduceOperation,
InElementwiseOperation, InElementwiseOperation,
AccElementwiseOperation, AccElementwiseOperation,
PropagateNan, InMemoryDataOperationEnum::Set,
false, PropagateNan,
256, OutputIndex,
4, false, // HaveIndexInputIfOutputIndex
64, 256,
1, 4,
1, 64,
0, 1,
1, 1,
1>; 0,
1,
1>;
static struct option long_options[] = {{"inLengths", required_argument, nullptr, 'D'}, static struct option long_options[] = {{"inLengths", required_argument, nullptr, 'D'},
{"scales", required_argument, nullptr, 'S'},
{"verify", required_argument, nullptr, 'v'}, {"verify", required_argument, nullptr, 'v'},
{"help", no_argument, nullptr, '?'}, {"help", no_argument, nullptr, '?'},
{nullptr, 0, nullptr, 0}}; {nullptr, 0, nullptr, 0}};
class SimpleAppArgs class SimpleAppArgs
{ {
template <typename T>
static T getSingleValueFromString(const std::string& valueStr)
{
std::istringstream iss(valueStr);
T ret;
iss >> ret;
return (ret);
};
template <typename T>
static std::vector<T> getTypeValuesFromString(const char* cstr_values)
{
std::string valuesStr(cstr_values);
std::vector<T> values;
std::size_t pos = 0;
std::size_t new_pos;
new_pos = valuesStr.find(',', pos);
while(new_pos != std::string::npos)
{
const std::string sliceStr = valuesStr.substr(pos, new_pos - pos);
T val = getSingleValueFromString<T>(sliceStr);
values.push_back(val);
pos = new_pos + 1;
new_pos = valuesStr.find(',', pos);
};
std::string sliceStr = valuesStr.substr(pos);
T val = getSingleValueFromString<T>(sliceStr);
values.push_back(val);
return (values);
};
private: private:
int option_index = 0; int option_index = 0;
public: public:
std::vector<size_t> inLengths; std::vector<size_t> inLengths = {16, 64, 32, 960};
std::vector<float> scales; std::vector<float> scales = {1.0f, 0.0f};
bool do_verification = true; bool do_verification = true;
int init_method = 1; int init_method = 1;
bool time_kernel = false; bool time_kernel = true;
public: public:
void show_usage(const char* cmd) void show_usage(const char* cmd)
@@ -126,24 +84,24 @@ class SimpleAppArgs
std::cout << "Usage of " << cmd << std::endl; std::cout << "Usage of " << cmd << std::endl;
std::cout << "--inLengths or -D, comma separated list of input tensor dimension lengths" std::cout << "--inLengths or -D, comma separated list of input tensor dimension lengths"
<< std::endl; << std::endl;
std::cout << "--scales or -S, comma separated two float values for alpha and beta"
<< std::endl;
std::cout << "--verify or -v, 1/0 to indicate whether to verify the reduction result by " std::cout << "--verify or -v, 1/0 to indicate whether to verify the reduction result by "
"comparing with the host-based reduction" "comparing with the host-based reduction"
<< std::endl; << std::endl;
std::cout << "Arg1 -- init method (0=no init, 1=single integer value, 2=scope integer " std::cout << "Arg1 -- init method (0=no init, 1=single integer value, 2=scope integer "
"value, 3=decimal value)" "value, 3=decimal value)"
<< std::endl; << std::endl;
std::cout << "Arg2 -- time kernel (0=n0, 1=yes)" << std::endl; std::cout << "Arg2 -- time kernel (0=no, 1=yes)" << std::endl;
}; };
int processArgs(int argc, char* argv[]) int processArgs(int argc, char* argv[])
{ {
using ck::host_common::getTypeValuesFromString;
int ch; int ch;
while(1) while(1)
{ {
ch = getopt_long(argc, argv, "D:S:v:l:", long_options, &option_index); ch = getopt_long(argc, argv, "D:v:l:", long_options, &option_index);
if(ch == -1) if(ch == -1)
break; break;
switch(ch) switch(ch)
@@ -154,12 +112,6 @@ class SimpleAppArgs
inLengths = getTypeValuesFromString<size_t>(optarg); inLengths = getTypeValuesFromString<size_t>(optarg);
break; break;
case 'S':
if(!optarg)
throw std::runtime_error("Invalid option format!");
scales = getTypeValuesFromString<float>(optarg);
break;
case 'v': case 'v':
if(!optarg) if(!optarg)
throw std::runtime_error("Invalid option format!"); throw std::runtime_error("Invalid option format!");
@@ -181,7 +133,7 @@ class SimpleAppArgs
throw std::runtime_error("Invalid cmd-line arguments, more argumetns are needed!"); throw std::runtime_error("Invalid cmd-line arguments, more argumetns are needed!");
init_method = std::atoi(argv[optind++]); init_method = std::atoi(argv[optind++]);
time_kernel = std::atoi(argv[optind]); time_kernel = static_cast<bool>(std::atoi(argv[optind]));
if(scales.empty()) if(scales.empty())
{ {
@@ -202,16 +154,16 @@ int main(int argc, char* argv[])
SimpleAppArgs args; SimpleAppArgs args;
if(args.processArgs(argc, argv) < 0) if(argc > 1)
return (-1); {
if(args.processArgs(argc, argv) < 0)
return (-1);
};
constexpr bool op_support_indices = constexpr bool op_support_indices =
(ReduceOpId == ReduceTensorOp::MIN || ReduceOpId == ReduceTensorOp::MAX || (ReduceOpId == ReduceTensorOp::MIN || ReduceOpId == ReduceTensorOp::MAX ||
ReduceOpId == ReduceTensorOp::AMAX); ReduceOpId == ReduceTensorOp::AMAX);
constexpr bool NeedIndices =
(op_support_indices && (IndicesOpt != ReduceTensorIndices::NO_INDICES));
// if input is half type, no reason to use float for indiced reduction operation and must use // if input is half type, no reason to use float for indiced reduction operation and must use
// float for non-indiced reduction operation for accuracy // float for non-indiced reduction operation for accuracy
constexpr bool invalid_reduce_1 = constexpr bool invalid_reduce_1 =
@@ -225,8 +177,7 @@ int main(int argc, char* argv[])
(op_support_indices && !std::is_same<AccDataType, float>::value); (op_support_indices && !std::is_same<AccDataType, float>::value);
// indices option can only be used when it is really needed // indices option can only be used when it is really needed
constexpr bool invalid_reduce_3 = constexpr bool invalid_reduce_3 = (!op_support_indices && OutputIndex);
(!op_support_indices && IndicesOpt != ReduceTensorIndices::NO_INDICES);
constexpr bool invalid_reduce = (invalid_reduce_1 || invalid_reduce_2 || invalid_reduce_3); constexpr bool invalid_reduce = (invalid_reduce_1 || invalid_reduce_2 || invalid_reduce_3);
@@ -294,9 +245,9 @@ int main(int argc, char* argv[])
if(beta != 0.0f) if(beta != 0.0f)
out_dev.ToDevice(out.mData.data()); out_dev.ToDevice(out.mData.data());
size_t indicesSizeInBytes = NeedIndices ? out.mDesc.GetElementSize() * sizeof(int32_t) : 0; size_t indicesSizeInBytes = OutputIndex ? out.mDesc.GetElementSize() * sizeof(int32_t) : 0;
DeviceMem out_indices_dev(indicesSizeInBytes); DeviceMem out_index_dev(indicesSizeInBytes);
if(args.do_verification) if(args.do_verification)
{ {
@@ -307,38 +258,39 @@ int main(int argc, char* argv[])
Rank, Rank,
NumReduceDim, NumReduceDim,
PropagateNan, PropagateNan,
NeedIndices> OutputIndex>
hostReduce(in.mDesc, out_ref.mDesc, invariantDims, reduceDims); hostReduce(in.mDesc, out_ref.mDesc, invariantDims, reduceDims);
hostReduce.Run( hostReduce.Run(
alpha, in.mData.data(), beta, out_ref.mData.data(), out_indices_ref.mData.data()); alpha, in.mData.data(), beta, out_ref.mData.data(), out_indices_ref.mData.data());
}; };
const auto i_inLengths = to_int_vector(args.inLengths); std::vector<ck::index_t> i_inLengths;
const auto i_inStrides = to_int_vector(inStrides); std::vector<ck::index_t> i_inStrides;
const auto i_outLengths = to_int_vector(outLengths); std::vector<ck::index_t> i_outLengths;
const auto i_outStrides = to_int_vector(outStrides); std::vector<ck::index_t> i_outStrides;
i_inLengths.assign(args.inLengths.begin(), args.inLengths.end());
i_inStrides.assign(inStrides.begin(), inStrides.end());
i_outLengths.assign(outLengths.begin(), outLengths.end());
i_outStrides.assign(outStrides.begin(), outStrides.end());
auto reduce = DeviceReduceInstance{}; auto reduce = DeviceReduceInstance{};
auto wsSizeInBytes = reduce.GetWorkspaceSizeInBytes(i_inLengths, reduceDims); auto argument_ptr = reduce.MakeArgumentPointer(
i_inLengths,
DeviceMem ws_dev(wsSizeInBytes); i_inStrides,
i_outLengths,
auto argument_ptr = i_outStrides,
reduce.MakeArgumentPointer(i_inLengths, reduceDims,
i_inStrides, alpha,
i_outLengths, beta,
i_outStrides, in_dev.GetDeviceBuffer(),
reduceDims, nullptr,
alpha, out_dev.GetDeviceBuffer(),
beta, out_index_dev.GetDeviceBuffer(),
in_dev.GetDeviceBuffer(), InElementwiseOperation{static_cast<int32_t>(reduce_total_length)},
out_dev.GetDeviceBuffer(), AccElementwiseOperation{static_cast<int32_t>(reduce_total_length)});
out_indices_dev.GetDeviceBuffer(),
ws_dev.GetDeviceBuffer(),
InElementwiseOperation{static_cast<int>(reduce_total_length)},
AccElementwiseOperation{static_cast<int>(reduce_total_length)});
if(!reduce.IsSupportedArgument(argument_ptr.get())) if(!reduce.IsSupportedArgument(argument_ptr.get()))
{ {
@@ -362,16 +314,18 @@ int main(int argc, char* argv[])
<< std::endl; << std::endl;
bool pass = true; bool pass = true;
if(args.do_verification) if(args.do_verification)
{ {
out_dev.FromDevice(out.mData.data()); out_dev.FromDevice(out.mData.data());
pass &= ck::utils::check_err(out.mData, out_ref.mData); pass = pass && ck::utils::check_err(out.mData, out_ref.mData);
if(NeedIndices) if(OutputIndex)
{ {
out_indices_dev.FromDevice(out_indices.mData.data()); out_index_dev.FromDevice(out_indices.mData.data());
pass &= ck::utils::check_err(out_indices.mData, out_indices_ref.mData); pass = pass && ck::utils::check_err(out_indices.mData, out_indices_ref.mData);
}; };
}; };
return pass ? 0 : 1;
return (pass ? 0 : 1);
} }
#include <iostream>
#include <numeric>
#include <sstream>
#include <initializer_list>
#include <cstdlib>
#include <getopt.h>
#include "check_err.hpp"
#include "config.hpp"
#include "print.hpp"
#include "device.hpp"
#include "host_tensor.hpp"
#include "host_tensor_generator.hpp"
#include "device_tensor.hpp"
#include "device_base.hpp"
#include "device_reduce_multiblock.hpp"
#include "host_common_util.hpp"
#include "host_reduction.hpp"
#include "reduction_enums.hpp"
#include "reduction_operator_mapping.hpp"
using namespace ck;
using namespace ck::tensor_operation::device;
using InOutDataType = ck::half_t;
using AccDataType = float;
constexpr ReduceTensorOp ReduceOpId = ReduceTensorOp::NORM2;
constexpr bool PropagateNan = true;
constexpr bool OutputIndex = false;
using ReduceOperation = typename reduce_binary_operator<AccDataType, ReduceOpId>::opType;
using InElementwiseOperation =
typename reduce_unary_operator<AccDataType, ReduceOpId, true, true>::InElementwiseOperation;
using AccElementwiseOperation =
typename reduce_unary_operator<AccDataType, ReduceOpId, true, true>::AccElementwiseOperation;
using PassThroughOp = tensor_operation::element_wise::UnaryIdentic<AccDataType, AccDataType>;
using DeviceReduceInstance_1 = DeviceReduceMultiBlock<InOutDataType,
AccDataType,
InOutDataType,
5, // Rank
1, // NumReduceDim
ReduceOperation,
InElementwiseOperation,
PassThroughOp,
InMemoryDataOperationEnum::Set,
PropagateNan,
OutputIndex,
false, // HaveIndexInputIfOutputIndex
256,
32,
8,
1,
1,
1, // vector dim
1,
1>;
using DeviceReduceInstance_2 = DeviceReduceMultiBlock<InOutDataType,
AccDataType,
InOutDataType,
4, // Rank
1, // NumReduceDim
ReduceOperation,
PassThroughOp,
AccElementwiseOperation,
InMemoryDataOperationEnum::Set,
PropagateNan,
OutputIndex,
false, // HaveIndexInputIfOutputIndex
256,
128,
2,
1,
1,
1, // vector dim
1,
1>;
static bool do_verify;
static int init_method;
static float alpha;
static float beta;
static bool time_kernel;
int main(int argc, char* argv[])
{
// used by the device reduction
const std::vector<int> reduceDims_1 = {4};
const std::vector<int> invariantDims_1 = {0, 1, 2, 3};
const std::vector<int> reduceDims_2 = {3};
const std::vector<int> invariantDims_2 = {0, 1, 2};
// used by the host reduction
const std::vector<int> reduceDims = {3, 4};
const std::vector<int> invariantDims = {0, 1, 2};
const std::vector<size_t> inLengths_1 = {64, 320, 80, 4, 128};
// input lengths of the second reduction, which is also the output lengths of the first
// reduction
const std::vector<size_t> inLengths_2 = {64, 320, 80, 4};
const std::vector<size_t> outLengths = {64, 320, 80};
using namespace ck::host_reduce;
if(argc == 1)
{
do_verify = true;
init_method = 2;
time_kernel = true;
}
else if(argc == 4)
{
do_verify = static_cast<bool>(atoi(argv[1]));
init_method = atoi(argv[2]);
time_kernel = static_cast<bool>(atoi(argv[3]));
}
else
{
std::ostringstream ostr;
ostr << "Wrong parameter! " << std::endl
<< "Usage: " << argv[0] << "[verify 0/1] init_method time_kernel" << std::endl;
throw std::runtime_error(ostr.str());
};
alpha = 1.0f;
beta = 0.0f;
Tensor<InOutDataType> in_1(inLengths_1);
Tensor<InOutDataType> out_ref(outLengths);
Tensor<InOutDataType> in_2(inLengths_2); // also the output tensor of the first reduction
Tensor<InOutDataType> out(outLengths);
auto inStrides_1 = in_1.mDesc.GetStrides();
auto inStrides_2 = in_2.mDesc.GetStrides();
auto outStrides = out.mDesc.GetStrides();
size_t invariant_total_length = out.mDesc.GetElementSize();
size_t reduce_total_length = in_1.mDesc.GetElementSize() / invariant_total_length;
std::size_t num_thread = 1;
if(do_verify)
{
switch(init_method)
{
case 0: break;
case 1:
in_1.GenerateTensorValue(GeneratorTensor_1<InOutDataType>{1}, num_thread);
if(beta != 0.0f)
out_ref.GenerateTensorValue(GeneratorTensor_1<InOutDataType>{1}, num_thread);
break;
case 2:
in_1.GenerateTensorValue(GeneratorTensor_2<InOutDataType>{-5, 5}, num_thread);
if(beta != 0.0f)
out_ref.GenerateTensorValue(GeneratorTensor_2<InOutDataType>{-5, 5}, num_thread);
break;
default:
in_1.GenerateTensorValue(GeneratorTensor_3<InOutDataType>{-5.0, 5.0}, num_thread);
if(beta != 0.0f)
out_ref.GenerateTensorValue(GeneratorTensor_3<InOutDataType>{-5.0, 5.0},
num_thread);
}
if(beta != 0.0f)
for(size_t i = 0; i < out_ref.mDesc.GetElementSpace(); i++)
out.mData[i] = out_ref.mData[i];
};
DeviceMem in_1_dev(sizeof(InOutDataType) * in_1.mDesc.GetElementSpace());
DeviceMem in_2_dev(sizeof(InOutDataType) * in_2.mDesc.GetElementSpace());
DeviceMem out_dev(sizeof(InOutDataType) * out.mDesc.GetElementSpace());
in_1_dev.ToDevice(in_1.mData.data());
if(beta != 0.0f)
out_dev.ToDevice(out.mData.data());
if(do_verify)
{
ReductionHost<InOutDataType,
AccDataType,
InOutDataType,
ReduceOpId,
5, // Rank
2, // NumReduceDim
PropagateNan,
OutputIndex>
hostReduce(in_1.mDesc, out_ref.mDesc, invariantDims, reduceDims);
hostReduce.Run(alpha, in_1.mData.data(), beta, out_ref.mData.data(), nullptr);
};
std::vector<ck::index_t> i_inLengths_1;
std::vector<ck::index_t> i_inStrides_1;
std::vector<ck::index_t> i_inLengths_2;
std::vector<ck::index_t> i_inStrides_2;
std::vector<ck::index_t> i_outLengths;
std::vector<ck::index_t> i_outStrides;
i_inLengths_1.assign(inLengths_1.begin(), inLengths_1.end());
i_inStrides_1.assign(inStrides_1.begin(), inStrides_1.end());
i_inLengths_2.assign(inLengths_2.begin(), inLengths_2.end());
i_inStrides_2.assign(inStrides_2.begin(), inStrides_2.end());
i_outLengths.assign(outLengths.begin(), outLengths.end());
i_outStrides.assign(outStrides.begin(), outStrides.end());
auto reduce_1 = DeviceReduceInstance_1{};
auto argument_ptr_1 = reduce_1.MakeArgumentPointer(
i_inLengths_1,
i_inStrides_1,
i_inLengths_2,
i_inStrides_2,
reduceDims_1,
1.0f,
0.0f,
in_1_dev.GetDeviceBuffer(),
nullptr,
in_2_dev.GetDeviceBuffer(),
nullptr,
InElementwiseOperation{static_cast<int32_t>(reduce_total_length)},
PassThroughOp{});
if(!reduce_1.IsSupportedArgument(argument_ptr_1.get()))
{
std::cout
<< "The runtime parameters seems not supported by the DeviceReduce instance, exiting!"
<< std::endl;
};
auto invoker_ptr_1 = reduce_1.MakeInvokerPointer();
auto reduce_2 = DeviceReduceInstance_2{};
auto argument_ptr_2 = reduce_2.MakeArgumentPointer(
i_inLengths_2,
i_inStrides_2,
i_outLengths,
i_outStrides,
reduceDims_2,
alpha,
beta,
in_2_dev.GetDeviceBuffer(),
nullptr,
out_dev.GetDeviceBuffer(),
nullptr,
PassThroughOp{},
AccElementwiseOperation{static_cast<int32_t>(reduce_total_length)});
if(!reduce_2.IsSupportedArgument(argument_ptr_2.get()))
{
std::cout
<< "The runtime parameters seems not supported by the DeviceReduce instance, exiting!"
<< std::endl;
};
auto invoker_ptr_2 = reduce_2.MakeInvokerPointer();
float avg_time_1 = invoker_ptr_1->Run(argument_ptr_1.get(), StreamConfig{nullptr, time_kernel});
float avg_time_2 = invoker_ptr_2->Run(argument_ptr_2.get(), StreamConfig{nullptr, time_kernel});
std::size_t num_bytes = invariant_total_length * reduce_total_length * sizeof(InOutDataType) +
invariant_total_length * sizeof(InOutDataType);
float gb_per_sec = num_bytes / 1.E6 / (avg_time_1 + avg_time_2);
std::cout << "Perf: " << avg_time_1 + avg_time_2 << " ms, " << gb_per_sec << " GB/s, "
<< reduce_1.GetTypeString() << " => " << reduce_2.GetTypeString() << std::endl;
bool pass = true;
if(do_verify)
{
out_dev.FromDevice(out.mData.data());
pass = pass && ck::utils::check_err(out.mData, out_ref.mData);
};
return (pass ? 0 : 1);
}
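The Perf line printed above (and quoted in the README earlier) is just the effective-bandwidth formula at the end of main(): total bytes read plus bytes written, divided by the summed kernel times. Worked through with the example's fixed fp16 shapes and the measured 2.1791 ms, it reproduces the reported figure:

```cpp
#include <cstdio>

int main()
{
    const double invariant = 64.0 * 320.0 * 80.0; // invariant_total_length (output elements)
    const double reduce    = 4.0 * 128.0;         // reduce_total_length per output element
    const double bytes     = invariant * reduce * 2.0   // read the fp16 input once
                           + invariant * 2.0;           // write the fp16 output once
    const double time_ms   = 2.1791;                    // avg_time_1 + avg_time_2 from the README

    std::printf("%.2f GB/s\n", bytes / 1.0e6 / time_ms); // prints ~771.42, matching the Perf line
    return 0;
}
```

Note that the formula counts only the external input and final output, not the intermediate in_2 tensor, so it reports end-to-end effective bandwidth rather than raw memory traffic.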
@@ -4,9 +4,9 @@
 ```bash
 #arg1: verification (0=no, 1=yes)
 #arg2: initialization (0=no init, 1=single integer value, 2=scope integer value, 3=decimal value)
-#arg3: run kernel # of times (>1)
+#arg3: time kernel (0=no, 1=yes)
 #arg4 to 15: N, C, Y, X, Hi, Wi, Sy, Sx, LeftPy, LeftPx, RightPy, RightPx
-./bin/example_pool2d_fwd 1 1 10
+./bin/example_pool2d_fwd 1 1 1
 ```
 Result
@@ -14,9 +14,7 @@ Result
 in_n_c_hi_wi: dim 4, lengths {128, 192, 71, 71}, strides {967872, 1, 13632, 192}
 out_n_c_ho_wo: dim 4, lengths {128, 192, 36, 36}, strides {248832, 1, 6912, 192}
 launch_and_time_kernel: grid_dim {124416, 1, 1}, block_dim {64, 1, 1}
-Warm up
+Warm up 1 time
 Start running 10 times...
-Perf: 0.415453 ms, 1.37996 TFlops, 749.726 GB/s
+Perf: 0.397436 ms, 1.44252 TFlops, 783.713 GB/s
-error: 0
-max_diff: 0, 1, 1
 ```
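For reference, the 71x71 to 36x36 spatial shapes in the log above follow the usual pooling output-size arithmetic. The 3x3 window and stride 2 are the example's built-in defaults and are assumptions here (they are not shown in this excerpt); the right pads of 1 appear in the source below and the left pads are assumed to match:

```cpp
#include <iostream>

int main()
{
    const int Hi = 71, Wi = 71;      // input spatial lengths from the log above
    const int Y = 3, X = 3;          // window lengths (assumed example defaults)
    const int Sy = 2, Sx = 2;        // window strides (assumed example defaults)
    const int LeftP = 1, RightP = 1; // per-side padding

    const int Ho = (Hi + LeftP + RightP - Y) / Sy + 1;
    const int Wo = (Wi + LeftP + RightP - X) / Sx + 1;

    std::cout << Ho << " x " << Wo << std::endl; // 36 x 36, matching out_n_c_ho_wo
    return 0;
}
```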
@@ -20,6 +20,8 @@ using InDataType = ck::half_t;
using OutDataType = ck::half_t; using OutDataType = ck::half_t;
using AccDataType = float; using AccDataType = float;
using IndexDataType = int32_t;
using InLayout = ck::tensor_layout::convolution::NHWC; using InLayout = ck::tensor_layout::convolution::NHWC;
using OutLayout = ck::tensor_layout::convolution::NHWC; using OutLayout = ck::tensor_layout::convolution::NHWC;
@@ -29,7 +31,7 @@ static constexpr auto ReduceOpId = ck::ReduceTensorOp::MAX;
static constexpr auto ReduceOpId = ck::ReduceTensorOp::AVG; static constexpr auto ReduceOpId = ck::ReduceTensorOp::AVG;
#endif #endif
static constexpr bool NeedIndices = false; static constexpr bool OutputIndex = false;
static constexpr bool PropagateNan = false; static constexpr bool PropagateNan = false;
using DevicePoolFwdInstance = using DevicePoolFwdInstance =
@@ -38,7 +40,7 @@ using DevicePoolFwdInstance =
OutDataType, // OutDataType OutDataType, // OutDataType
AccDataType, // AccDataType AccDataType, // AccDataType
ReduceOpId, ReduceOpId,
NeedIndices, OutputIndex,
64, // BlockSize 64, // BlockSize
64, // ReduceMThreadClusterSize 64, // ReduceMThreadClusterSize
1, // ReduceKThreadClusterSize 1, // ReduceKThreadClusterSize
@@ -51,10 +53,10 @@ template <typename InDataType,
typename AccDataType, typename AccDataType,
ck::ReduceTensorOp ReduceOpId, ck::ReduceTensorOp ReduceOpId,
bool PropagateNan, bool PropagateNan,
bool NeedIndices> bool OutputIndex>
static void pool_host_verify(const Tensor<InDataType>& in, static void pool_host_verify(const Tensor<InDataType>& in,
Tensor<OutDataType>& out, Tensor<OutDataType>& out,
Tensor<int>& out_indices, Tensor<IndexDataType>& out_indices,
const std::array<ck::index_t, 2>& window_spatial_lengths, const std::array<ck::index_t, 2>& window_spatial_lengths,
const std::array<ck::index_t, 2>& window_strides, const std::array<ck::index_t, 2>& window_strides,
const std::array<ck::index_t, 2>& in_left_pads, const std::array<ck::index_t, 2>& in_left_pads,
@@ -62,26 +64,26 @@ static void pool_host_verify(const Tensor<InDataType>& in,
{ {
using namespace ck::host_reduce; using namespace ck::host_reduce;
const int divider = window_spatial_lengths[0] * window_spatial_lengths[1]; const int32_t divider = window_spatial_lengths[0] * window_spatial_lengths[1];
const auto PreUnaryOp = PreUnaryOpFn<AccDataType, ReduceOpId>(divider); const auto PreUnaryOp = PreUnaryOpFn<AccDataType, ReduceOpId>(divider);
const auto PosUnaryOp = PosUnaryOpFn<AccDataType, ReduceOpId>(divider); const auto PosUnaryOp = PosUnaryOpFn<AccDataType, ReduceOpId>(divider);
if constexpr(!NeedIndices) if constexpr(!OutputIndex)
{ {
auto opReduce = ReduceOpFn<AccDataType, ReduceOpId>(); auto opReduce = ReduceOpFn<AccDataType, ReduceOpId>();
auto f_nchw = [&](auto n, auto c, auto ho, auto wo) { auto f_nchw = [&](auto n, auto c, auto ho, auto wo) {
auto accuVal = ReduceOpZeroVal<AccDataType, ReduceOpId>(); auto accuVal = ReduceOpZeroVal<AccDataType, ReduceOpId>();
for(int y = 0; y < window_spatial_lengths[0]; ++y) for(ck::index_t y = 0; y < window_spatial_lengths[0]; ++y)
{ {
int hi = ho * window_strides[0] + y - in_left_pads[0]; ck::index_t hi = ho * window_strides[0] + y - in_left_pads[0];
for(int x = 0; x < window_spatial_lengths[1]; ++x) for(ck::index_t x = 0; x < window_spatial_lengths[1]; ++x)
{ {
int wi = wo * window_strides[1] + x - in_left_pads[1]; ck::index_t wi = wo * window_strides[1] + x - in_left_pads[1];
if(hi >= 0 && hi < ck::type_convert<int>(in.mDesc.GetLengths()[2]) && wi >= 0 && if(hi >= 0 && hi < static_cast<ck::index_t>(in.mDesc.GetLengths()[2]) &&
wi < ck::type_convert<int>(in.mDesc.GetLengths()[3])) wi >= 0 && wi < static_cast<ck::index_t>(in.mDesc.GetLengths()[3]))
{ {
AccDataType currVal = static_cast<AccDataType>(in(n, c, hi, wi)); AccDataType currVal = static_cast<AccDataType>(in(n, c, hi, wi));
@@ -108,24 +110,24 @@ static void pool_host_verify(const Tensor<InDataType>& in,
auto opReduce = ReduceOpFn2<AccDataType, ReduceOpId>(); auto opReduce = ReduceOpFn2<AccDataType, ReduceOpId>();
auto f_nchw = [&](auto n, auto c, auto ho, auto wo) { auto f_nchw = [&](auto n, auto c, auto ho, auto wo) {
auto accuVal = ReduceOpZeroVal<AccDataType, ReduceOpId>(); auto accuVal = ReduceOpZeroVal<AccDataType, ReduceOpId>();
int accuIndex = 0; IndexDataType accuIndex = 0;
for(int y = 0; y < window_spatial_lengths[0]; ++y) for(ck::index_t y = 0; y < window_spatial_lengths[0]; ++y)
{ {
int hi = ho * window_strides[0] + y - in_left_pads[0]; ck::index_t hi = ho * window_strides[0] + y - in_left_pads[0];
for(int x = 0; x < window_spatial_lengths[1]; ++x) for(ck::index_t x = 0; x < window_spatial_lengths[1]; ++x)
{ {
int wi = wo * window_strides[1] + x - in_left_pads[1]; ck::index_t wi = wo * window_strides[1] + x - in_left_pads[1];
if(hi >= 0 && hi < in.mDesc.GetLengths()[2] && wi >= 0 && if(hi >= 0 && hi < in.mDesc.GetLengths()[2] && wi >= 0 &&
wi < in.mDesc.GetLengths()[3]) wi < in.mDesc.GetLengths()[3])
{ {
AccDataType currVal = static_cast<AccDataType>(in(n, c, hi, wi)); AccDataType currVal = static_cast<AccDataType>(in(n, c, hi, wi));
int currIndex = y * window_spatial_lengths[1] + x; IndexDataType currIndex = y * window_spatial_lengths[1] + x;
PreUnaryOp(currVal); PreUnaryOp(currVal);
binop_with_nan_check2<AccDataType, PropagateNan>( binop_with_index_and_nan_check<AccDataType, IndexDataType, PropagateNan>(
opReduce, accuVal, currVal, accuIndex, currIndex); opReduce, accuVal, currVal, accuIndex, currIndex);
} }
} }
@@ -149,9 +151,9 @@ int main(int argc, char* argv[])
{ {
using namespace ck::host_reduce; using namespace ck::host_reduce;
bool do_verification = true; bool do_verification;
int init_method = 1; int init_method;
bool time_kernel = false; bool time_kernel;
// Pool shape // Pool shape
ck::index_t N = 128; ck::index_t N = 128;
@@ -167,17 +169,23 @@ int main(int argc, char* argv[])
ck::index_t in_right_pad_h = 1; ck::index_t in_right_pad_h = 1;
ck::index_t in_right_pad_w = 1; ck::index_t in_right_pad_w = 1;
if(argc == 4) if(argc == 1)
{
do_verification = true;
init_method = 1;
time_kernel = true;
}
else if(argc == 4)
{ {
do_verification = std::stoi(argv[1]); do_verification = std::stoi(argv[1]);
init_method = std::stoi(argv[2]); init_method = std::stoi(argv[2]);
time_kernel = std::stoi(argv[3]); time_kernel = static_cast<bool>(std::stoi(argv[3]));
} }
else if(argc == 16) else if(argc == 16)
{ {
do_verification = std::stoi(argv[1]); do_verification = std::stoi(argv[1]);
init_method = std::stoi(argv[2]); init_method = std::stoi(argv[2]);
time_kernel = std::stoi(argv[3]); time_kernel = static_cast<bool>(std::stoi(argv[3]));
N = std::stoi(argv[4]); N = std::stoi(argv[4]);
C = std::stoi(argv[5]); C = std::stoi(argv[5]);
@@ -196,7 +204,7 @@ int main(int argc, char* argv[])
{ {
printf("arg1: verification (0=no, 1=yes)\n"); printf("arg1: verification (0=no, 1=yes)\n");
printf("arg2: initialization (0=no init, 1=integer value, 2=decimal value)\n"); printf("arg2: initialization (0=no init, 1=integer value, 2=decimal value)\n");
printf("arg3: time kernel (0=n0, 1=yes)\n"); printf("arg3: time kernel (0=no, 1=yes)\n");
printf("arg4 to 15: N, C, Y, X, Hi, Wi, Sy, Sx, LeftPy, LeftPx, RightPy, " printf("arg4 to 15: N, C, Y, X, Hi, Wi, Sy, Sx, LeftPy, LeftPx, RightPy, "
"RightPx\n"); "RightPx\n");
exit(0); exit(0);
@@ -228,9 +236,11 @@ int main(int argc, char* argv[])
Tensor<InDataType> in_n_c_hi_wi(f_host_tensor_descriptor(N, C, Hi, Wi, InLayout{})); Tensor<InDataType> in_n_c_hi_wi(f_host_tensor_descriptor(N, C, Hi, Wi, InLayout{}));
Tensor<OutDataType> out_n_c_ho_wo_host(f_host_tensor_descriptor(N, C, Ho, Wo, OutLayout{})); Tensor<OutDataType> out_n_c_ho_wo_host(f_host_tensor_descriptor(N, C, Ho, Wo, OutLayout{}));
Tensor<int> out_indices_n_c_ho_wo_host(f_host_tensor_descriptor(N, C, Ho, Wo, OutLayout{})); Tensor<IndexDataType> out_indices_n_c_ho_wo_host(
f_host_tensor_descriptor(N, C, Ho, Wo, OutLayout{}));
Tensor<OutDataType> out_n_c_ho_wo_device(f_host_tensor_descriptor(N, C, Ho, Wo, OutLayout{})); Tensor<OutDataType> out_n_c_ho_wo_device(f_host_tensor_descriptor(N, C, Ho, Wo, OutLayout{}));
Tensor<int> out_indices_n_c_ho_wo_device(f_host_tensor_descriptor(N, C, Ho, Wo, OutLayout{})); Tensor<IndexDataType> out_indices_n_c_ho_wo_device(
f_host_tensor_descriptor(N, C, Ho, Wo, OutLayout{}));
std::cout << "in_n_c_hi_wi: " << in_n_c_hi_wi.mDesc << std::endl; std::cout << "in_n_c_hi_wi: " << in_n_c_hi_wi.mDesc << std::endl;
std::cout << "out_n_c_ho_wo: " << out_n_c_ho_wo_host.mDesc << std::endl; std::cout << "out_n_c_ho_wo: " << out_n_c_ho_wo_host.mDesc << std::endl;
@@ -245,25 +255,25 @@ int main(int argc, char* argv[])
DeviceMem in_device_buf(sizeof(InDataType) * in_n_c_hi_wi.mDesc.GetElementSpace()); DeviceMem in_device_buf(sizeof(InDataType) * in_n_c_hi_wi.mDesc.GetElementSpace());
DeviceMem out_device_buf(sizeof(OutDataType) * out_n_c_ho_wo_device.mDesc.GetElementSpace()); DeviceMem out_device_buf(sizeof(OutDataType) * out_n_c_ho_wo_device.mDesc.GetElementSpace());
DeviceMem out_indices_device_buf(sizeof(int) * DeviceMem out_indices_device_buf(sizeof(IndexDataType) *
out_indices_n_c_ho_wo_device.mDesc.GetElementSpace()); out_indices_n_c_ho_wo_device.mDesc.GetElementSpace());
in_device_buf.ToDevice(in_n_c_hi_wi.mData.data()); in_device_buf.ToDevice(in_n_c_hi_wi.mData.data());
auto pool = DevicePoolFwdInstance{}; auto pool = DevicePoolFwdInstance{};
auto invoker_ptr = pool.MakeInvokerPointer(); auto invoker_ptr = pool.MakeInvokerPointer();
auto argument_ptr = auto argument_ptr = pool.MakeArgumentPointer(
pool.MakeArgumentPointer(static_cast<InDataType*>(in_device_buf.GetDeviceBuffer()), static_cast<InDataType*>(in_device_buf.GetDeviceBuffer()),
static_cast<OutDataType*>(out_device_buf.GetDeviceBuffer()), static_cast<OutDataType*>(out_device_buf.GetDeviceBuffer()),
static_cast<int*>(out_indices_device_buf.GetDeviceBuffer()), static_cast<IndexDataType*>(out_indices_device_buf.GetDeviceBuffer()),
N, N,
C, C,
std::array<ck::index_t, 2>{{Hi, Wi}}, std::array<ck::index_t, 2>{{Hi, Wi}},
std::array<ck::index_t, 2>{{Y, X}}, std::array<ck::index_t, 2>{{Y, X}},
std::array<ck::index_t, 2>{{Ho, Wo}}, std::array<ck::index_t, 2>{{Ho, Wo}},
window_strides, window_strides,
input_left_pads, input_left_pads,
input_right_pads); input_right_pads);
if(!pool.IsSupportedArgument(argument_ptr.get())) if(!pool.IsSupportedArgument(argument_ptr.get()))
{ {
@@ -286,6 +296,7 @@ int main(int argc, char* argv[])
<< std::endl; << std::endl;
bool pass = true; bool pass = true;
if(do_verification) if(do_verification)
{ {
pool_host_verify<InDataType, pool_host_verify<InDataType,
@@ -293,7 +304,7 @@ int main(int argc, char* argv[])
AccDataType, AccDataType,
ReduceOpId, ReduceOpId,
PropagateNan, PropagateNan,
NeedIndices>(in_n_c_hi_wi, OutputIndex>(in_n_c_hi_wi,
out_n_c_ho_wo_host, out_n_c_ho_wo_host,
out_indices_n_c_ho_wo_host, out_indices_n_c_ho_wo_host,
window_spatial_lengths, window_spatial_lengths,
@@ -303,15 +314,16 @@ int main(int argc, char* argv[])
out_device_buf.FromDevice(out_n_c_ho_wo_device.mData.data()); out_device_buf.FromDevice(out_n_c_ho_wo_device.mData.data());
pass &= ck::utils::check_err(out_n_c_ho_wo_device.mData, out_n_c_ho_wo_host.mData); pass = pass && ck::utils::check_err(out_n_c_ho_wo_device.mData, out_n_c_ho_wo_host.mData);
if constexpr(NeedIndices) if constexpr(OutputIndex)
{ {
out_indices_device_buf.FromDevice(out_indices_n_c_ho_wo_device.mData.data()); out_indices_device_buf.FromDevice(out_indices_n_c_ho_wo_device.mData.data());
pass &= ck::utils::check_err(out_indices_n_c_ho_wo_device.mData, pass = pass && ck::utils::check_err(out_indices_n_c_ho_wo_device.mData,
out_indices_n_c_ho_wo_host.mData); out_indices_n_c_ho_wo_host.mData);
}; };
} }
return pass ? 0 : 1;
return (pass ? 0 : 1);
} }
@@ -17,7 +17,7 @@ template <typename InDataType,
typename OutDataType, typename OutDataType,
typename AccDataType, typename AccDataType,
ck::ReduceTensorOp ReduceOpId, ck::ReduceTensorOp ReduceOpId,
bool NeedIndices, bool OuputIndex,
ck::index_t BlockSize, ck::index_t BlockSize,
ck::index_t ReduceMThreadClusterSize, ck::index_t ReduceMThreadClusterSize,
ck::index_t ReduceKThreadClusterSize, ck::index_t ReduceKThreadClusterSize,
@@ -44,8 +44,6 @@ struct DevicePool2dFwd_Input_N_Hi_Wi_C_Output_N_Ho_Wo_C : public DevicePool2dFwd
typename reduce_unary_operator<AccDataType, ReduceOpId, true, true>:: typename reduce_unary_operator<AccDataType, ReduceOpId, true, true>::
AccElementwiseOperation; AccElementwiseOperation;
static constexpr bool BetaIsZero = true;
static constexpr index_t InSrcOutDstVectorDim = static constexpr index_t InSrcOutDstVectorDim =
0; // for NHWC, the dim C is the vector Dim for both input and output in memory, which is 0; // for NHWC, the dim C is the vector Dim for both input and output in memory, which is
// not reduced. // not reduced.
@@ -206,28 +204,28 @@ struct DevicePool2dFwd_Input_N_Hi_Wi_C_Output_N_Ho_Wo_C : public DevicePool2dFwd
{ {
float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{}) float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{})
{ {
using gridwise_reduce = GridwiseReduction_mk_to_m_threadwise<InDataType, using gridwise_reduce =
OutDataType, GridwiseReduction_mk_to_m_threadwise<InDataType,
AccDataType, OutDataType,
IndexDataType, AccDataType,
AGridDesc_M_K, IndexDataType,
BGridDesc_M, AGridDesc_M_K,
ReduceOperation, BGridDesc_M,
InElementwiseOperation, ReduceOperation,
AccElementwiseOperation, InElementwiseOperation,
false, // propagate_nan AccElementwiseOperation,
BetaIsZero, InMemoryDataOperationEnum::Set,
BlockSize, false, // propagate_nan
ReduceMThreadClusterSize, BlockSize,
ReduceKThreadClusterSize, ReduceMThreadSliceSize,
ReduceMThreadSliceSize, ReduceKThreadSliceSize,
ReduceKThreadSliceSize, InSrcOutDstVectorDim,
InSrcOutDstVectorDim, InSrcOutDstVectorSize,
InSrcOutDstVectorSize, InSrcOutDstVectorSize>;
InSrcOutDstVectorSize>;
const auto kernel = kernel_reduce_threadwise<gridwise_reduce, const auto kernel = kernel_reduce_threadwise<gridwise_reduce,
NeedIndices, OuputIndex,
false, // don't have index input
InDataType, InDataType,
OutDataType, OutDataType,
AccDataType, AccDataType,
@@ -252,6 +250,7 @@ struct DevicePool2dFwd_Input_N_Hi_Wi_C_Output_N_Ho_Wo_C : public DevicePool2dFwd
arg.acc_element_op_, arg.acc_element_op_,
float(1), float(1),
arg.p_in_dev_, arg.p_in_dev_,
nullptr,
float(0), float(0),
arg.p_out_dev_, arg.p_out_dev_,
arg.p_out_indices_dev_); arg.p_out_indices_dev_);
......
@@ -16,35 +16,18 @@ namespace device {
template <typename InElementwiseOperation, typename AccElementwiseOperation> template <typename InElementwiseOperation, typename AccElementwiseOperation>
struct DeviceReduce : public BaseOperator struct DeviceReduce : public BaseOperator
{ {
virtual long_index_t GetWorkspaceSizeInBytes(const std::vector<int> inLengths,
const std::vector<int> reduceDims)
{
(void)inLengths;
(void)reduceDims;
return (0);
};
virtual bool HasFurtherCall() { return (false); };
virtual std::vector<int> GetWorkspace2dLengths(const BaseArgument* argPtr)
{
(void)argPtr;
return (std::vector<int>{0, 0});
};
virtual std::unique_ptr<BaseArgument> virtual std::unique_ptr<BaseArgument>
MakeArgumentPointer(const std::vector<int> inLengths, MakeArgumentPointer(const std::vector<index_t> inLengths,
const std::vector<int> inStrides, const std::vector<index_t> inStrides,
const std::vector<int> outLengths, const std::vector<index_t> outLengths,
const std::vector<int> outStrides, const std::vector<index_t> outStrides,
const std::vector<int> reduceDims, const std::vector<int> reduceDims,
float alpha, float alpha,
float beta, float beta,
const void* in_dev, const void* in_dev,
const void* in_index_dev,
void* out_dev, void* out_dev,
void* out_indices_dev, void* out_index_dev,
void* workspace_dev,
const InElementwiseOperation in_elementwise_op, const InElementwiseOperation in_elementwise_op,
const AccElementwiseOperation acc_elementwise_op) = 0; const AccElementwiseOperation acc_elementwise_op) = 0;
......
#ifndef DEVICE_REDUCE_BLOCKWISE_HPP
#define DEVICE_REDUCE_BLOCKWISE_HPP
#include <iostream>
#include <sstream>
#include "device.hpp"
#include "device_reduce.hpp"
#include "device_reduce_common.hpp"
#include "gridwise_2d_reduction_blockwise.hpp"
namespace ck {
namespace tensor_operation {
namespace device {
template <typename InDataType,
typename AccDataType,
typename OutDataType,
index_t Rank,
index_t NumReduceDim,
typename ReduceOperation,
typename InElementwiseOperation,
typename AccElementwiseOperation,
bool PropagateNan,
bool NeedIndices,
index_t BlockSize,
index_t MThreadClusterSize,
index_t KThreadClusterSize,
index_t MThreadSliceSize,
index_t KThreadSliceSize,
index_t InSrcVectorDim,
index_t InSrcVectorSize,
index_t OutDstVectorSize>
struct DeviceReduceBlockWise : public DeviceReduce<InElementwiseOperation, AccElementwiseOperation>
{
static_assert(Rank <= 6, "Bigger Rank size is not supported!");
static_assert(BlockSize == MThreadClusterSize * KThreadClusterSize,
"Invalid thread cluster size assignments!");
static_assert(((InSrcVectorDim == 0 && MThreadSliceSize % InSrcVectorSize == 0) ||
(InSrcVectorDim == 1 && KThreadSliceSize % InSrcVectorSize == 0)) &&
(MThreadSliceSize % OutDstVectorSize == 0),
"Invalid thread slice sizes and/or vector sizes configuration, please check!");
using IndexDataType = int32_t;
static constexpr bool BetaIsZero = NeedIndices;
static constexpr index_t NumInvariantDim = Rank - NumReduceDim;
static constexpr index_t numSrcDim = Rank;
static constexpr index_t numDstDim = (NumInvariantDim == 0) ? 1 : NumInvariantDim;
static constexpr bool reduceAllDim = (NumInvariantDim == 0);
static constexpr int M_BlockTileSize = MThreadClusterSize * MThreadSliceSize;
static constexpr int K_BlockTileSize = KThreadClusterSize * KThreadSliceSize;
static auto MakeSrc2dDescriptor(const std::vector<int>& inLengths,
const std::vector<int>& inStrides)
{
const auto tupleSrcLengths = make_tuple_from_array(inLengths, Number<numSrcDim>{});
const auto tupleSrcStrides = make_tuple_from_array(inStrides, Number<numSrcDim>{});
const auto inDesc = make_naive_tensor_descriptor(tupleSrcLengths, tupleSrcStrides);
const auto in_grid_desc_m_k = [&]() {
if constexpr(reduceAllDim)
{
const auto one_dim_inDesc = transform_tensor_descriptor(
inDesc,
make_tuple(make_merge_transform(tupleSrcLengths)),
make_tuple(typename arithmetic_sequence_gen<0, numSrcDim, 1>::type{}),
make_tuple(Sequence<0>{}));
return transform_tensor_descriptor(one_dim_inDesc,
make_tuple(make_unmerge_transform(make_tuple(
1, one_dim_inDesc.GetLength(Number<0>{})))),
make_tuple(Sequence<0>{}),
make_tuple(Sequence<0, 1>{}));
}
else
{
using InvariantDims = typename arithmetic_sequence_gen<0, NumInvariantDim, 1>::type;
using ReduceDims = typename arithmetic_sequence_gen<NumInvariantDim, Rank, 1>::type;
const auto reduceDimLengths =
make_tuple_from_array_and_index_seq(inLengths, ReduceDims{});
const auto invariantDimLengths =
make_tuple_from_array_and_index_seq(inLengths, InvariantDims{});
return transform_tensor_descriptor(
inDesc,
make_tuple(make_merge_transform(invariantDimLengths),
make_merge_transform(reduceDimLengths)),
make_tuple(InvariantDims{}, ReduceDims{}),
make_tuple(Sequence<0>{}, Sequence<1>{}));
}
}();
const auto invariantLength = in_grid_desc_m_k.GetLength(Number<0>{});
const auto reduceLength = in_grid_desc_m_k.GetLength(Number<1>{});
const auto inPad_M =
math::integer_least_multiple(invariantLength, M_BlockTileSize) - invariantLength;
const auto inPad_K =
math::integer_least_multiple(reduceLength, K_BlockTileSize) - reduceLength;
auto in_grid_desc_m_k_padded = transform_tensor_descriptor(
in_grid_desc_m_k,
make_tuple(make_right_pad_transform(invariantLength, inPad_M),
make_right_pad_transform(reduceLength, inPad_K)),
make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}));
return (in_grid_desc_m_k_padded);
};
static auto MakeDst1dDescriptor(const std::vector<int>& outLengths,
const std::vector<int>& outStrides)
{
const auto tupleDstLengths = make_tuple_from_array(outLengths, Number<numDstDim>{});
const auto tupleDstStrides = make_tuple_from_array(outStrides, Number<numDstDim>{});
auto outDesc = make_naive_tensor_descriptor(tupleDstLengths, tupleDstStrides);
auto out_grid_desc_m = transform_tensor_descriptor(
outDesc,
make_tuple(make_merge_transform(tupleDstLengths)),
make_tuple(typename arithmetic_sequence_gen<0, numDstDim, 1>::type{}),
make_tuple(Sequence<0>{}));
const auto invariantLength = out_grid_desc_m.GetLength(Number<0>{});
const auto inPad =
math::integer_least_multiple(invariantLength, M_BlockTileSize) - invariantLength;
auto out_grid_desc_m_padded = transform_tensor_descriptor(
out_grid_desc_m,
make_tuple(make_right_pad_transform(invariantLength, inPad)),
make_tuple(Sequence<0>{}),
make_tuple(Sequence<0>{}));
return (out_grid_desc_m_padded);
};
struct Argument : public BaseArgument
{
Argument(const std::vector<int> inLengths,
const std::vector<int> inStrides,
const std::vector<int> outLengths,
const std::vector<int> outStrides,
const std::vector<int> reduceDims,
float alpha,
float beta,
const InDataType* in_dev,
OutDataType* out_dev,
IndexDataType* out_indices_dev,
AccDataType* workspace_dev,
const InElementwiseOperation in_elementwise_op,
const AccElementwiseOperation acc_elementwise_op)
: outLengths_{outLengths},
outStrides_{outStrides},
in_dev_{in_dev},
out_dev_{out_dev},
out_indices_dev_{out_indices_dev},
in_elementwise_op_{in_elementwise_op},
acc_elementwise_op_{acc_elementwise_op}
{
(void)workspace_dev;
inLengths_ = shuffle_tensor_dimensions<Rank, NumReduceDim>(inLengths, reduceDims);
inStrides_ = shuffle_tensor_dimensions<Rank, NumReduceDim>(inStrides, reduceDims);
alpha_ = type_convert<AccDataType>(alpha);
beta_ = type_convert<AccDataType>(beta);
std::tie(invariant_total_length, reduce_total_length) =
get_2d_lengths<Rank, NumReduceDim>(inLengths_);
if constexpr(NumInvariantDim == 0)
invariant_lowest_length = 1;
else
invariant_lowest_length = inLengths_[NumInvariantDim - 1];
reduce_lowest_length = inLengths_[Rank - 1];
gridSize = math::integer_least_multiple(invariant_total_length, M_BlockTileSize) /
M_BlockTileSize;
}
std::vector<int> inLengths_;
std::vector<int> inStrides_;
std::vector<int> outLengths_;
std::vector<int> outStrides_;
AccDataType alpha_;
AccDataType beta_;
const InDataType* in_dev_;
OutDataType* out_dev_;
IndexDataType* out_indices_dev_;
InElementwiseOperation in_elementwise_op_;
AccElementwiseOperation acc_elementwise_op_;
int invariant_lowest_length;
int reduce_lowest_length;
size_t invariant_total_length;
size_t reduce_total_length;
size_t gridSize;
};
struct Invoker : public BaseInvoker
{
float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{})
{
const auto in_grid_desc_m_k =
DeviceReduceBlockWise::MakeSrc2dDescriptor(arg.inLengths_, arg.inStrides_);
const auto out_grid_desc_m =
DeviceReduceBlockWise::MakeDst1dDescriptor(arg.outLengths_, arg.outStrides_);
using InGridDesc_M_K = decltype(in_grid_desc_m_k);
using OutGridDesc_M = decltype(out_grid_desc_m);
using GridwiseReduce = GridwiseReduction_mk_to_m_blockwise<InDataType,
OutDataType,
AccDataType,
IndexDataType,
InGridDesc_M_K,
OutGridDesc_M,
ReduceOperation,
InElementwiseOperation,
AccElementwiseOperation,
PropagateNan,
BetaIsZero,
BlockSize,
MThreadClusterSize,
KThreadClusterSize,
MThreadSliceSize,
KThreadSliceSize,
InSrcVectorDim,
InSrcVectorSize,
OutDstVectorSize>;
float avg_time = 0;
const auto kernel = kernel_reduce_blockwise<GridwiseReduce,
NeedIndices,
InDataType,
OutDataType,
AccDataType,
IndexDataType,
InGridDesc_M_K,
OutGridDesc_M,
InElementwiseOperation,
AccElementwiseOperation>;
avg_time = launch_and_time_kernel(stream_config,
kernel,
dim3(arg.gridSize),
dim3(BlockSize),
0,
in_grid_desc_m_k,
out_grid_desc_m,
arg.in_elementwise_op_,
arg.acc_elementwise_op_,
arg.alpha_,
arg.in_dev_,
arg.beta_,
arg.out_dev_,
nullptr,
arg.out_indices_dev_);
return (avg_time);
};
float Run(const BaseArgument* p_arg,
const StreamConfig& stream_config = StreamConfig{}) override
{
return Run(*dynamic_cast<const Argument*>(p_arg), stream_config);
};
};
bool IsSupportedArgument(const BaseArgument* p_arg) override
{
const Argument* pArg = dynamic_cast<const Argument*>(p_arg);
if constexpr(InSrcVectorDim == 0)
{
if constexpr(NumInvariantDim == 0)
{
return (false);
}
else
{
if(pArg->inStrides_[NumInvariantDim - 1] != 1)
return (false);
if(pArg->invariant_lowest_length % InSrcVectorSize != 0)
return (false);
};
}
else
{
if(pArg->inStrides_[Rank - 1] != 1)
return (false);
if(pArg->reduce_lowest_length % InSrcVectorSize != 0)
return (false);
};
// To improve
if(pArg->invariant_lowest_length % OutDstVectorSize != 0)
return (false);
// cases with very small reduce_total_length should be handled by the ThreadWise method
if(pArg->reduce_total_length / KThreadSliceSize < 2)
return (false);
return (true);
};
std::unique_ptr<BaseArgument>
MakeArgumentPointer(const std::vector<int> inLengths,
const std::vector<int> inStrides,
const std::vector<int> outLengths,
const std::vector<int> outStrides,
const std::vector<int> reduceDims,
float alpha,
float beta,
const void* in_dev,
void* out_dev,
void* out_indices_dev,
void* workspace_dev,
const InElementwiseOperation in_elementwise_op,
const AccElementwiseOperation acc_elementwise_op) override
{
return std::make_unique<Argument>(inLengths,
inStrides,
outLengths,
outStrides,
reduceDims,
alpha,
beta,
static_cast<const InDataType*>(in_dev),
static_cast<OutDataType*>(out_dev),
static_cast<IndexDataType*>(out_indices_dev),
static_cast<AccDataType*>(workspace_dev),
in_elementwise_op,
acc_elementwise_op);
};
std::unique_ptr<BaseInvoker> MakeInvokerPointer() override
{
return std::make_unique<Invoker>();
};
std::string GetTypeString() const override
{
auto str = std::stringstream();
// clang-format off
str << "DeviceReduceBlockWise<" << BlockSize << ",";
str << "M_C" << MThreadClusterSize << "_S" << MThreadSliceSize << ",";
str << "K_C" << KThreadClusterSize << "_S" << KThreadSliceSize << ",";
str << "InSrcVectorDim_" << InSrcVectorDim << "_InSrcVectorSize_" << InSrcVectorSize << "_OutDstVectorSize_" << OutDstVectorSize << ">";
// clang-format on
return str.str();
}
};
} // namespace device
} // namespace tensor_operation
} // namespace ck
#endif
#ifndef DEVICE_REDUCE_BLOCKWISE_SECOND_CALL_HPP
#define DEVICE_REDUCE_BLOCKWISE_SECOND_CALL_HPP
#include <iostream>
#include <sstream>
#include "device.hpp"
#include "device_reduce.hpp"
#include "device_reduce_common.hpp"
#include "gridwise_2d_reduction_blockwise.hpp"
namespace ck {
namespace tensor_operation {
namespace device {
template <typename InDataType,
typename AccDataType,
typename OutDataType,
index_t Rank,
index_t NumReduceDim,
typename ReduceOperation,
typename InElementwiseOperation,
typename AccElementwiseOperation,
bool PropagateNan,
bool NeedIndices,
index_t BlockSize,
index_t MThreadClusterSize,
index_t KThreadClusterSize,
index_t MThreadSliceSize,
index_t KThreadSliceSize,
index_t InSrcVectorDim,
index_t InSrcVectorSize,
index_t OutDstVectorSize>
struct DeviceReduceBlockWiseSecondCall
: public DeviceReduce<InElementwiseOperation, AccElementwiseOperation>
{
static_assert(Rank <= 6, "Bigger Rank size is not supported!");
static_assert(BlockSize == MThreadClusterSize * KThreadClusterSize,
"Invalid thread cluster size assignments!");
static_assert((InSrcVectorDim == 1 && KThreadSliceSize % InSrcVectorSize == 0) &&
(MThreadSliceSize % OutDstVectorSize == 0),
"Invalid thread slice sizes and/or vector sizes configuration, please check!");
using IndexDataType = int32_t;
static constexpr bool BetaIsZero = NeedIndices;
static_assert(
std::is_same<InDataType, AccDataType>::value,
"InDataType and AccDataType should be the same to use DEviceReduceBlockWiseSecondCall!");
static constexpr index_t NumInvariantDim = Rank - NumReduceDim;
static constexpr index_t numDstDim = (NumInvariantDim == 0) ? 1 : NumInvariantDim;
static constexpr int M_BlockTileSize = MThreadClusterSize * MThreadSliceSize;
static constexpr int K_BlockTileSize = KThreadClusterSize * KThreadSliceSize;
static auto MakeSrc2dDescriptor(const std::vector<int>& inLengths,
const std::vector<int>& inStrides)
{
const auto tupleSrcLengths = make_tuple_from_array(inLengths, Number<2>{});
const auto tupleSrcStrides = make_tuple_from_array(inStrides, Number<2>{});
const auto in_grid_desc_m_k =
make_naive_tensor_descriptor(tupleSrcLengths, tupleSrcStrides);
const auto invariantLength = in_grid_desc_m_k.GetLength(Number<0>{});
const auto reduceLength = in_grid_desc_m_k.GetLength(Number<1>{});
const auto inPad_M =
math::integer_least_multiple(invariantLength, M_BlockTileSize) - invariantLength;
const auto inPad_K =
math::integer_least_multiple(reduceLength, K_BlockTileSize) - reduceLength;
auto in_grid_desc_m_k_padded = transform_tensor_descriptor(
in_grid_desc_m_k,
make_tuple(make_right_pad_transform(invariantLength, inPad_M),
make_right_pad_transform(reduceLength, inPad_K)),
make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}));
return (in_grid_desc_m_k_padded);
};
static auto MakeDst1dDescriptor(const std::vector<int>& outLengths,
const std::vector<int>& outStrides)
{
const auto tupleDstLengths = make_tuple_from_array(outLengths, Number<numDstDim>{});
const auto tupleDstStrides = make_tuple_from_array(outStrides, Number<numDstDim>{});
auto outDesc = make_naive_tensor_descriptor(tupleDstLengths, tupleDstStrides);
auto out_grid_desc_m = transform_tensor_descriptor(
outDesc,
make_tuple(make_merge_transform(tupleDstLengths)),
make_tuple(typename arithmetic_sequence_gen<0, numDstDim, 1>::type{}),
make_tuple(Sequence<0>{}));
const auto invariantLength = out_grid_desc_m.GetLength(Number<0>{});
const auto outPad =
math::integer_least_multiple(invariantLength, M_BlockTileSize) - invariantLength;
auto out_grid_desc_m_padded = transform_tensor_descriptor(
out_grid_desc_m,
make_tuple(make_right_pad_transform(invariantLength, outPad)),
make_tuple(Sequence<0>{}),
make_tuple(Sequence<0>{}));
return (out_grid_desc_m_padded);
};
struct Argument : public BaseArgument
{
Argument(const std::vector<int>& inLengths,
const std::vector<int>& inStrides,
const std::vector<int>& outLengths,
const std::vector<int>& outStrides,
float alpha,
float beta,
const InDataType* in_dev,
OutDataType* out_dev,
IndexDataType* out_indices_dev,
AccDataType* workspace_dev,
const InElementwiseOperation& in_elementwise_op,
const AccElementwiseOperation& acc_elementwise_op)
: inLengths_(inLengths),
inStrides_(inStrides),
outLengths_(outLengths),
outStrides_(outStrides),
in_dev_{in_dev},
out_dev_{out_dev},
out_indices_dev_{out_indices_dev},
in_elementwise_op_(in_elementwise_op),
acc_elementwise_op_(acc_elementwise_op)
{
alpha_ = type_convert<AccDataType>(alpha);
beta_ = type_convert<AccDataType>(beta);
invariant_total_length = inLengths[0];
reduce_total_length = inLengths[1];
invariant_lowest_length = inLengths[0];
reduce_lowest_length = inLengths[1];
gridSize = math::integer_least_multiple(invariant_total_length, M_BlockTileSize) /
M_BlockTileSize;
size_t ws_buf2_bytes_offset = math::integer_least_multiple(
invariant_total_length * reduce_total_length * sizeof(AccDataType), 64);
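            // When indices are produced, they are stored in the same workspace right after the
            // accumulated values, at a 64-byte aligned offset; this layout is presumably the one
            // used by the first-call kernel that filled the workspace.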
if constexpr(NeedIndices)
workspace_indices_dev_ = reinterpret_cast<index_t*>(
reinterpret_cast<char*>(workspace_dev) + ws_buf2_bytes_offset);
else
workspace_indices_dev_ = nullptr;
}
std::vector<int> inLengths_;
std::vector<int> inStrides_;
std::vector<int> outLengths_;
std::vector<int> outStrides_;
AccDataType alpha_;
AccDataType beta_;
const InDataType* in_dev_;
OutDataType* out_dev_;
IndexDataType* out_indices_dev_;
IndexDataType* workspace_indices_dev_;
InElementwiseOperation in_elementwise_op_;
AccElementwiseOperation acc_elementwise_op_;
int invariant_lowest_length;
int reduce_lowest_length;
size_t invariant_total_length;
size_t reduce_total_length;
size_t gridSize;
};
struct Invoker : public BaseInvoker
{
float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{})
{
const auto in_grid_desc_m_k = DeviceReduceBlockWiseSecondCall::MakeSrc2dDescriptor(
arg.inLengths_, arg.inStrides_);
const auto out_grid_desc_m = DeviceReduceBlockWiseSecondCall::MakeDst1dDescriptor(
arg.outLengths_, arg.outStrides_);
using InGridDesc_M_K = decltype(in_grid_desc_m_k);
using OutGridDesc_M = decltype(out_grid_desc_m);
using GridwiseReduce = GridwiseReduction_mk_to_m_blockwise<InDataType,
OutDataType,
AccDataType,
IndexDataType,
InGridDesc_M_K,
OutGridDesc_M,
ReduceOperation,
InElementwiseOperation,
AccElementwiseOperation,
PropagateNan,
BetaIsZero,
BlockSize,
MThreadClusterSize,
KThreadClusterSize,
MThreadSliceSize,
KThreadSliceSize,
InSrcVectorDim,
InSrcVectorSize,
OutDstVectorSize>;
float avg_time = 0;
const auto kernel = kernel_reduce_blockwise_second_call<GridwiseReduce,
NeedIndices,
InDataType,
OutDataType,
AccDataType,
IndexDataType,
InGridDesc_M_K,
OutGridDesc_M,
InElementwiseOperation,
AccElementwiseOperation>;
avg_time = launch_and_time_kernel(stream_config,
kernel,
dim3(arg.gridSize),
dim3(BlockSize),
0,
in_grid_desc_m_k,
out_grid_desc_m,
arg.in_elementwise_op_,
arg.acc_elementwise_op_,
arg.alpha_,
arg.in_dev_,
arg.beta_,
arg.out_dev_,
arg.workspace_indices_dev_,
arg.out_indices_dev_);
return (avg_time);
};
float Run(const BaseArgument* p_arg,
const StreamConfig& stream_config = StreamConfig{}) override
{
return Run(*dynamic_cast<const Argument*>(p_arg), stream_config);
}
};
bool IsSupportedArgument(const BaseArgument* p_arg) override
{
const Argument* pArg = dynamic_cast<const Argument*>(p_arg);
if constexpr(InSrcVectorDim == 0)
return (false);
if(pArg->reduce_lowest_length % InSrcVectorSize != 0)
return (false);
        // To improve: this divisibility requirement between the invariant lowest length and
        // OutDstVectorSize could be relaxed
if(pArg->invariant_lowest_length % OutDstVectorSize != 0)
return (false);
// cases with very small reduce_total_length should be handled by the ThreadWise method
if(pArg->reduce_total_length / KThreadSliceSize < 2)
return (false);
return (true);
};
std::unique_ptr<BaseArgument>
MakeArgumentPointer(const std::vector<int> inLengths,
const std::vector<int> inStrides,
const std::vector<int> outLengths,
const std::vector<int> outStrides,
const std::vector<int> reduceDims,
float alpha,
float beta,
const void* in_dev,
void* out_dev,
void* out_indices_dev,
void* workspace_dev,
const InElementwiseOperation in_elementwise_op,
const AccElementwiseOperation acc_elementwise_op) override
{
(void)reduceDims;
return std::make_unique<Argument>(inLengths,
inStrides,
outLengths,
outStrides,
alpha,
beta,
static_cast<const InDataType*>(in_dev),
static_cast<OutDataType*>(out_dev),
static_cast<IndexDataType*>(out_indices_dev),
static_cast<AccDataType*>(workspace_dev),
in_elementwise_op,
acc_elementwise_op);
};
std::unique_ptr<BaseInvoker> MakeInvokerPointer() override
{
return std::make_unique<Invoker>();
};
std::string GetTypeString() const override
{
auto str = std::stringstream();
// clang-format off
str << "DeviceReduceBlockWiseSecondCall<" << BlockSize << ",";
str << "M_C" << MThreadClusterSize << "_S" << MThreadSliceSize << ",";
str << "K_C" << KThreadClusterSize << "_S" << KThreadSliceSize << ",";
str << "InSrcVectorDim_" << InSrcVectorDim << "_InSrcVectorSize_" << InSrcVectorSize << "_OutDstVectorSize_" << OutDstVectorSize << ">";
// clang-format on
return str.str();
}
};
} // namespace device
} // namespace tensor_operation
} // namespace ck
#endif
...@@ -14,13 +14,13 @@ namespace device { ...@@ -14,13 +14,13 @@ namespace device {
// here, inLengths[] is already shuffled so that lengths of invariant dims are included before those // here, inLengths[] is already shuffled so that lengths of invariant dims are included before those
// of reduce dims // of reduce dims
template <int Rank, int NumReduceDim> template <index_t Rank, int NumReduceDim>
std::pair<size_t, size_t> get_2d_lengths(const std::vector<int>& inLengths) std::pair<long_index_t, long_index_t> get_2d_lengths(const std::vector<index_t>& inLengths)
{ {
static_assert(Rank <= 6, "bigger Rank size not supported!"); static_assert(Rank <= 6, "bigger Rank size not supported!");
size_t invariant_total_length = 1; long_index_t invariant_total_length = 1;
size_t reduce_total_length = 1; long_index_t reduce_total_length = 1;
constexpr int NumInvariantDim = Rank - NumReduceDim; constexpr int NumInvariantDim = Rank - NumReduceDim;
...@@ -35,13 +35,13 @@ std::pair<size_t, size_t> get_2d_lengths(const std::vector<int>& inLengths) ...@@ -35,13 +35,13 @@ std::pair<size_t, size_t> get_2d_lengths(const std::vector<int>& inLengths)
// helper functions using variadic template arguments // helper functions using variadic template arguments
template <index_t... Ns> template <index_t... Ns>
auto make_tuple_from_array_and_index_seq(const std::vector<int>& lengths, Sequence<Ns...>) auto make_tuple_from_array_and_index_seq(const std::vector<index_t>& lengths, Sequence<Ns...>)
{ {
return make_tuple(static_cast<index_t>(lengths[Ns])...); return make_tuple(static_cast<index_t>(lengths[Ns])...);
}; };
template <index_t arraySize> template <index_t arraySize>
static auto make_tuple_from_array(const std::vector<int>& lengths, Number<arraySize>) auto make_tuple_from_array(const std::vector<index_t>& lengths, Number<arraySize>)
{ {
static_assert(arraySize >= 1 && arraySize <= 6, "The tensor should have 1 to 6 dimensions"); static_assert(arraySize >= 1 && arraySize <= 6, "The tensor should have 1 to 6 dimensions");
...@@ -51,10 +51,10 @@ static auto make_tuple_from_array(const std::vector<int>& lengths, Number<arrayS ...@@ -51,10 +51,10 @@ static auto make_tuple_from_array(const std::vector<int>& lengths, Number<arrayS
}; };
template <index_t Rank, index_t NumReduceDim> template <index_t Rank, index_t NumReduceDim>
std::vector<int> shuffle_tensor_dimensions(const std::vector<int>& origLengthsStrides, std::vector<index_t> shuffle_tensor_dimensions(const std::vector<index_t>& origLengthsStrides,
const std::vector<int>& reduceDims) const std::vector<int>& reduceDims)
{ {
std::vector<int> newLengthsStrides; std::vector<index_t> newLengthsStrides;
assert(Rank == origLengthsStrides.size() && NumReduceDim == reduceDims.size()); assert(Rank == origLengthsStrides.size() && NumReduceDim == reduceDims.size());
......
#ifndef DEVICE_REDUCE_MULTIBLOCK_ATOMIC_ADD_HPP #ifndef DEVICE_REDUCE_MULTIBLOCK_HPP
#define DEVICE_REDUCE_MULTIBLOCK_ATOMIC_ADD_HPP #define DEVICE_REDUCE_MULTIBLOCK_HPP
#include <iostream> #include <iostream>
#include <sstream> #include <sstream>
...@@ -7,8 +7,9 @@ ...@@ -7,8 +7,9 @@
#include "device_base.hpp" #include "device_base.hpp"
#include "device_reduce.hpp" #include "device_reduce.hpp"
#include "device_reduce_common.hpp" #include "device_reduce_common.hpp"
#include "gridwise_2d_reduction_multiblock_atomic_add.hpp" #include "gridwise_2d_reduction_multiblock.hpp"
#include "gridwise_set_buffer_value.hpp" #include "gridwise_set_buffer_value.hpp"
#include "reduction_operator.hpp"
namespace ck { namespace ck {
namespace tensor_operation { namespace tensor_operation {
...@@ -22,8 +23,10 @@ template <typename InDataType, ...@@ -22,8 +23,10 @@ template <typename InDataType,
typename ReduceOperation, typename ReduceOperation,
typename InElementwiseOperation, typename InElementwiseOperation,
typename AccElementwiseOperation, typename AccElementwiseOperation,
InMemoryDataOperationEnum OutMemoryDataOperation,
bool PropagateNan, bool PropagateNan,
bool NeedIndices, bool OutputIndex,
bool HaveIndexInputIfOutputIndex,
index_t BlockSize, index_t BlockSize,
index_t MThreadClusterSize, index_t MThreadClusterSize,
index_t KThreadClusterSize, index_t KThreadClusterSize,
...@@ -32,8 +35,7 @@ template <typename InDataType, ...@@ -32,8 +35,7 @@ template <typename InDataType,
index_t InSrcVectorDim, index_t InSrcVectorDim,
index_t InSrcVectorSize, index_t InSrcVectorSize,
index_t OutDstVectorSize> index_t OutDstVectorSize>
struct DeviceReduceMultiBlockAtomicAdd struct DeviceReduceMultiBlock : public DeviceReduce<InElementwiseOperation, AccElementwiseOperation>
: public DeviceReduce<InElementwiseOperation, AccElementwiseOperation>
{ {
static_assert(Rank <= 6, "Bigger Rank size is not supported!"); static_assert(Rank <= 6, "Bigger Rank size is not supported!");
static_assert(BlockSize == MThreadClusterSize * KThreadClusterSize, static_assert(BlockSize == MThreadClusterSize * KThreadClusterSize,
...@@ -46,26 +48,40 @@ struct DeviceReduceMultiBlockAtomicAdd ...@@ -46,26 +48,40 @@ struct DeviceReduceMultiBlockAtomicAdd
using IndexDataType = int32_t; using IndexDataType = int32_t;
static constexpr bool HaveIndexInput = OutputIndex && HaveIndexInputIfOutputIndex;
static constexpr index_t NumInvariantDim = Rank - NumReduceDim; static constexpr index_t NumInvariantDim = Rank - NumReduceDim;
static constexpr index_t numSrcDim = Rank; static constexpr index_t numSrcDim = Rank;
static constexpr index_t numDstDim = (NumInvariantDim == 0) ? 1 : NumInvariantDim; static constexpr index_t numDstDim = (NumInvariantDim == 0) ? 1 : NumInvariantDim;
static constexpr bool reduceAllDim = (NumInvariantDim == 0); static constexpr bool reduceAllDim = (NumInvariantDim == 0);
static constexpr bool support_AtomicAdd = // So far, only AtomicAdd is considered, other Atomic Operation like AtomicMax can be added
// later
static constexpr bool use_multiblock =
(OutMemoryDataOperation == InMemoryDataOperationEnum::AtomicAdd);
static constexpr bool out_type_compatible_with_atomic_op =
std::is_same<OutDataType, float>::value || std::is_same<OutDataType, double>::value; std::is_same<OutDataType, float>::value || std::is_same<OutDataType, double>::value;
static_assert(!NeedIndices && support_AtomicAdd, static_assert(
"MultiBlockAtomicAdd method can only be used with non-indiced operation and when " !use_multiblock || (use_multiblock && out_type_compatible_with_atomic_op),
"having float/double output type!"); "The OutDataType must support the atomic operation for using MultiBlock reduction");
static_assert(!use_multiblock || (use_multiblock && !OutputIndex),
"MultiBlock reduction can only be used when outputing index is not required");
static_assert(
ReduceOperation::IsCompatibleInMemoryDataOperation(OutMemoryDataOperation),
"The reduction accumulation operation must be compatible with the OutMemoryDataOperation!");
static constexpr int M_BlockTileSize = MThreadClusterSize * MThreadSliceSize; static constexpr index_t M_BlockTileSize = MThreadClusterSize * MThreadSliceSize;
static constexpr int K_BlockTileSize = KThreadClusterSize * KThreadSliceSize; static constexpr index_t K_BlockTileSize = KThreadClusterSize * KThreadSliceSize;
static auto MakeSrc2dDescriptor(const std::vector<int>& inLengths, static auto MakeSrc2dDescriptor(const std::vector<index_t>& inLengths,
const std::vector<int>& inStrides, const std::vector<index_t>& inStrides,
int blkGroupSize, int blkGroupSize,
int kBlockTileIterations) int numBlockTileIteration)
{ {
const auto tupleSrcLengths = make_tuple_from_array(inLengths, Number<numSrcDim>{}); const auto tupleSrcLengths = make_tuple_from_array(inLengths, Number<numSrcDim>{});
const auto tupleSrcStrides = make_tuple_from_array(inStrides, Number<numSrcDim>{}); const auto tupleSrcStrides = make_tuple_from_array(inStrides, Number<numSrcDim>{});
...@@ -109,7 +125,7 @@ struct DeviceReduceMultiBlockAtomicAdd ...@@ -109,7 +125,7 @@ struct DeviceReduceMultiBlockAtomicAdd
const auto invariantLength = in_grid_desc_m_k.GetLength(Number<0>{}); const auto invariantLength = in_grid_desc_m_k.GetLength(Number<0>{});
const auto reduceLength = in_grid_desc_m_k.GetLength(Number<1>{}); const auto reduceLength = in_grid_desc_m_k.GetLength(Number<1>{});
const int reduceSizePerBlock = K_BlockTileSize * kBlockTileIterations; const int reduceSizePerBlock = K_BlockTileSize * numBlockTileIteration;
const auto inPad_M = const auto inPad_M =
math::integer_least_multiple(invariantLength, M_BlockTileSize) - invariantLength; math::integer_least_multiple(invariantLength, M_BlockTileSize) - invariantLength;
const auto inPad_K = reduceSizePerBlock * blkGroupSize - reduceLength; const auto inPad_K = reduceSizePerBlock * blkGroupSize - reduceLength;
...@@ -124,8 +140,8 @@ struct DeviceReduceMultiBlockAtomicAdd ...@@ -124,8 +140,8 @@ struct DeviceReduceMultiBlockAtomicAdd
return (in_grid_desc_m_k_padded); return (in_grid_desc_m_k_padded);
}; };
static auto MakeDst1dDescriptor(const std::vector<int>& outLengths, static auto MakeDst1dDescriptor(const std::vector<index_t>& outLengths,
const std::vector<int>& outStrides) const std::vector<index_t>& outStrides)
{ {
const auto tupleDstLengths = make_tuple_from_array(outLengths, Number<numDstDim>{}); const auto tupleDstLengths = make_tuple_from_array(outLengths, Number<numDstDim>{});
const auto tupleDstStrides = make_tuple_from_array(outStrides, Number<numDstDim>{}); const auto tupleDstStrides = make_tuple_from_array(outStrides, Number<numDstDim>{});
...@@ -151,31 +167,56 @@ struct DeviceReduceMultiBlockAtomicAdd ...@@ -151,31 +167,56 @@ struct DeviceReduceMultiBlockAtomicAdd
return (out_grid_desc_m_padded); return (out_grid_desc_m_padded);
}; };
static auto MakeDst1dDescriptorForBufferSet(const std::vector<index_t>& outLengths,
const std::vector<index_t>& outStrides)
{
const auto tupleDstLengths = make_tuple_from_array(outLengths, Number<numDstDim>{});
const auto tupleDstStrides = make_tuple_from_array(outStrides, Number<numDstDim>{});
auto outDesc = make_naive_tensor_descriptor(tupleDstLengths, tupleDstStrides);
auto out_grid_desc_m = transform_tensor_descriptor(
outDesc,
make_tuple(make_merge_transform(tupleDstLengths)),
make_tuple(typename arithmetic_sequence_gen<0, numDstDim, 1>::type{}),
make_tuple(Sequence<0>{}));
const auto length = out_grid_desc_m.GetLength(Number<0>{});
const auto pad = math::integer_least_multiple(length, BlockSize) - length;
auto out_grid_desc_m_padded =
transform_tensor_descriptor(out_grid_desc_m,
make_tuple(make_right_pad_transform(length, pad)),
make_tuple(Sequence<0>{}),
make_tuple(Sequence<0>{}));
return (out_grid_desc_m_padded);
};
struct Argument : public BaseArgument struct Argument : public BaseArgument
{ {
Argument(const std::vector<int> inLengths, Argument(const std::vector<index_t> inLengths,
const std::vector<int> inStrides, const std::vector<index_t> inStrides,
const std::vector<int> outLengths, const std::vector<index_t> outLengths,
const std::vector<int> outStrides, const std::vector<index_t> outStrides,
const std::vector<int> reduceDims, const std::vector<int> reduceDims,
float alpha, float alpha,
float beta, float beta,
const InDataType* in_dev, const InDataType* in_dev,
const IndexDataType* in_index_dev,
OutDataType* out_dev, OutDataType* out_dev,
IndexDataType* out_indices_dev, IndexDataType* out_index_dev,
AccDataType* workspace_dev,
const InElementwiseOperation in_elementwise_op, const InElementwiseOperation in_elementwise_op,
const AccElementwiseOperation acc_elementwise_op) const AccElementwiseOperation acc_elementwise_op)
: outLengths_{outLengths}, : outLengths_{outLengths},
outStrides_{outStrides}, outStrides_{outStrides},
in_dev_{in_dev}, in_dev_{in_dev},
in_index_dev_{in_index_dev},
out_dev_{out_dev}, out_dev_{out_dev},
out_index_dev_{out_index_dev},
in_elementwise_op_{in_elementwise_op}, in_elementwise_op_{in_elementwise_op},
acc_elementwise_op_{acc_elementwise_op} acc_elementwise_op_{acc_elementwise_op}
{ {
(void)out_indices_dev;
(void)workspace_dev;
inLengths_ = shuffle_tensor_dimensions<Rank, NumReduceDim>(inLengths, reduceDims); inLengths_ = shuffle_tensor_dimensions<Rank, NumReduceDim>(inLengths, reduceDims);
inStrides_ = shuffle_tensor_dimensions<Rank, NumReduceDim>(inStrides, reduceDims); inStrides_ = shuffle_tensor_dimensions<Rank, NumReduceDim>(inStrides, reduceDims);
...@@ -192,23 +233,34 @@ struct DeviceReduceMultiBlockAtomicAdd ...@@ -192,23 +233,34 @@ struct DeviceReduceMultiBlockAtomicAdd
reduce_lowest_length = inLengths_[Rank - 1]; reduce_lowest_length = inLengths_[Rank - 1];
int iterations = 1; if constexpr(use_multiblock)
while(true)
{ {
int testBlkGroupSize = (reduce_total_length + (K_BlockTileSize * iterations) - 1) /
(K_BlockTileSize * iterations);
// we want the blkGroupSize be not more than 128 int iterations = 1;
if(testBlkGroupSize <= 128) while(true)
break; {
int testBlkGroupSize =
(reduce_total_length + (K_BlockTileSize * iterations) - 1) /
(K_BlockTileSize * iterations);
iterations++; // we want the blkGroupSize be not more than 128
}; if(testBlkGroupSize <= 128)
break;
blkGroupSize = (reduce_total_length + (K_BlockTileSize * iterations) - 1) / iterations++;
(K_BlockTileSize * iterations); };
kBlockTileIterations = iterations; blkGroupSize = (reduce_total_length + (K_BlockTileSize * iterations) - 1) /
(K_BlockTileSize * iterations);
numBlockTileIteration = iterations;
}
else
{
blkGroupSize = 1;
numBlockTileIteration =
(reduce_total_length + K_BlockTileSize - 1) / K_BlockTileSize;
};
gridSize = math::integer_least_multiple(invariant_total_length, M_BlockTileSize) / gridSize = math::integer_least_multiple(invariant_total_length, M_BlockTileSize) /
M_BlockTileSize * blkGroupSize; M_BlockTileSize * blkGroupSize;
...@@ -217,27 +269,29 @@ struct DeviceReduceMultiBlockAtomicAdd ...@@ -217,27 +269,29 @@ struct DeviceReduceMultiBlockAtomicAdd
math::integer_least_multiple(invariant_total_length, BlockSize) / BlockSize; math::integer_least_multiple(invariant_total_length, BlockSize) / BlockSize;
} }
std::vector<int> inLengths_; std::vector<index_t> inLengths_;
std::vector<int> inStrides_; std::vector<index_t> inStrides_;
std::vector<int> outLengths_; std::vector<index_t> outLengths_;
std::vector<int> outStrides_; std::vector<index_t> outStrides_;
AccDataType alpha_; AccDataType alpha_;
AccDataType beta_; AccDataType beta_;
const InDataType* in_dev_; const InDataType* in_dev_;
const IndexDataType* in_index_dev_;
OutDataType* out_dev_; OutDataType* out_dev_;
IndexDataType* out_index_dev_;
InElementwiseOperation in_elementwise_op_; InElementwiseOperation in_elementwise_op_;
AccElementwiseOperation acc_elementwise_op_; AccElementwiseOperation acc_elementwise_op_;
int invariant_lowest_length; index_t invariant_lowest_length;
int reduce_lowest_length; index_t reduce_lowest_length;
size_t invariant_total_length; long_index_t invariant_total_length;
size_t reduce_total_length; long_index_t reduce_total_length;
index_t blkGroupSize; int blkGroupSize;
index_t kBlockTileIterations; int numBlockTileIteration;
size_t gridSize; size_t gridSize;
size_t gridSize_pre; size_t gridSize_pre;
...@@ -247,52 +301,69 @@ struct DeviceReduceMultiBlockAtomicAdd ...@@ -247,52 +301,69 @@ struct DeviceReduceMultiBlockAtomicAdd
{ {
float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{}) float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{})
{ {
const auto in_grid_desc_m_k = DeviceReduceMultiBlockAtomicAdd::MakeSrc2dDescriptor( const auto in_grid_desc_m_k = DeviceReduceMultiBlock::MakeSrc2dDescriptor(
arg.inLengths_, arg.inStrides_, arg.blkGroupSize, arg.kBlockTileIterations); arg.inLengths_, arg.inStrides_, arg.blkGroupSize, arg.numBlockTileIteration);
const auto out_grid_desc_m = DeviceReduceMultiBlockAtomicAdd::MakeDst1dDescriptor( const auto out_grid_desc_m =
DeviceReduceMultiBlock::MakeDst1dDescriptor(arg.outLengths_, arg.outStrides_);
const auto out_grid_desc_m_2 = DeviceReduceMultiBlock::MakeDst1dDescriptorForBufferSet(
arg.outLengths_, arg.outStrides_); arg.outLengths_, arg.outStrides_);
using InGridDesc_M_K = decltype(in_grid_desc_m_k);
using OutGridDesc_M = decltype(out_grid_desc_m);
using GridwiseReduce =
GridwiseReduction_mk_to_m_multiblock_atomic_add<InDataType,
OutDataType,
AccDataType,
InGridDesc_M_K,
OutGridDesc_M,
ReduceOperation,
InElementwiseOperation,
AccElementwiseOperation,
PropagateNan,
BlockSize,
MThreadClusterSize,
KThreadClusterSize,
MThreadSliceSize,
KThreadSliceSize,
InSrcVectorDim,
InSrcVectorSize,
OutDstVectorSize>;
float avg_time = 0; using InGridDesc_M_K = decltype(in_grid_desc_m_k);
using OutGridDesc_M = decltype(out_grid_desc_m);
using OutGridDesc_M_2 = decltype(out_grid_desc_m_2);
using GridwiseReduce = GridwiseReduction_mk_to_m_multiblock<InDataType,
OutDataType,
AccDataType,
IndexDataType,
InGridDesc_M_K,
OutGridDesc_M,
ReduceOperation,
InElementwiseOperation,
AccElementwiseOperation,
OutMemoryDataOperation,
PropagateNan,
BlockSize,
MThreadClusterSize,
KThreadClusterSize,
MThreadSliceSize,
KThreadSliceSize,
InSrcVectorDim,
InSrcVectorSize,
OutDstVectorSize>;
const auto kernel_main = kernel_reduce_multiblock<GridwiseReduce,
OutputIndex,
HaveIndexInput,
InDataType,
OutDataType,
AccDataType,
int32_t,
InGridDesc_M_K,
OutGridDesc_M,
InElementwiseOperation,
AccElementwiseOperation>;
const auto kernel_pre = kernel_buffer_set_value<BlockSize, OutDataType, OutGridDesc_M>; float avg_time = 0;
const auto kernel_main = kernel_reduce_multiblock_atocmi_add<GridwiseReduce,
InDataType,
OutDataType,
AccDataType,
InGridDesc_M_K,
OutGridDesc_M,
InElementwiseOperation,
AccElementwiseOperation>;
avg_time += launch_and_time_kernel(stream_config, if constexpr(use_multiblock)
kernel_pre, {
dim3(arg.gridSize_pre), const auto zeroVal =
dim3(BlockSize), ck::reduce::GetReductionZeroValueForInMemoryDataOperation<OutDataType>(
0, OutMemoryDataOperation);
out_grid_desc_m,
arg.out_dev_, const auto kernel_pre =
static_cast<OutDataType>(0.0f)); kernel_buffer_set_value<BlockSize, OutDataType, OutGridDesc_M_2>;
avg_time += launch_and_time_kernel(stream_config,
kernel_pre,
dim3(arg.gridSize_pre),
dim3(BlockSize),
0,
out_grid_desc_m_2,
arg.out_dev_,
zeroVal);
};
avg_time += launch_and_time_kernel(stream_config, avg_time += launch_and_time_kernel(stream_config,
kernel_main, kernel_main,
...@@ -304,25 +375,34 @@ struct DeviceReduceMultiBlockAtomicAdd ...@@ -304,25 +375,34 @@ struct DeviceReduceMultiBlockAtomicAdd
arg.in_elementwise_op_, arg.in_elementwise_op_,
arg.acc_elementwise_op_, arg.acc_elementwise_op_,
arg.blkGroupSize, arg.blkGroupSize,
arg.kBlockTileIterations, arg.numBlockTileIteration,
arg.alpha_, arg.alpha_,
arg.in_dev_, arg.in_dev_,
arg.out_dev_); arg.in_index_dev_,
arg.beta_,
arg.out_dev_,
arg.out_index_dev_);
return avg_time; return (avg_time);
} };
float Run(const BaseArgument* p_arg, float Run(const BaseArgument* p_arg,
const StreamConfig& stream_config = StreamConfig{}) override const StreamConfig& stream_config = StreamConfig{}) override
{ {
return Run(*dynamic_cast<const Argument*>(p_arg), stream_config); return Run(*dynamic_cast<const Argument*>(p_arg), stream_config);
} };
}; };
bool IsSupportedArgument(const BaseArgument* p_arg) override bool IsSupportedArgument(const BaseArgument* p_arg) override
{ {
const Argument* pArg = dynamic_cast<const Argument*>(p_arg); const Argument* pArg = dynamic_cast<const Argument*>(p_arg);
if constexpr(use_multiblock)
{
if(static_cast<float>(pArg->beta_) != 0.0f)
return (false);
};
if constexpr(InSrcVectorDim == 0) if constexpr(InSrcVectorDim == 0)
{ {
if constexpr(NumInvariantDim == 0) if constexpr(NumInvariantDim == 0)
...@@ -347,36 +427,43 @@ struct DeviceReduceMultiBlockAtomicAdd ...@@ -347,36 +427,43 @@ struct DeviceReduceMultiBlockAtomicAdd
return (false); return (false);
}; };
if(static_cast<float>(pArg->beta_) != 0.0f)
return (false);
// To improve // To improve
if(pArg->invariant_lowest_length % OutDstVectorSize != 0) if(pArg->invariant_lowest_length % OutDstVectorSize != 0)
return (false); return (false);
// cases with small reduce_total_length should be handled by the BlockWise method if constexpr(use_multiblock)
if(pArg->reduce_total_length <= BlockSize * KThreadSliceSize) {
return (false); // blkGroupSize of 1 should be handled by Blockwise path using
// InMemoryDataOperationEnum::Set
if(pArg->blkGroupSize == 1)
return (false);
// This is very strong restriction, but needed to avoid some failure // This is very strong restriction, but needed to avoid some failure
if(pArg->invariant_lowest_length % M_BlockTileSize != 0) if(pArg->invariant_lowest_length % M_BlockTileSize != 0)
return (false); return (false);
}
else
{
// cases with very small reduce_total_length should be handled by ThreadWise kernel
if(pArg->reduce_total_length / KThreadSliceSize < 2)
return (false);
};
return (true); return (true);
}; };
std::unique_ptr<BaseArgument> std::unique_ptr<BaseArgument>
MakeArgumentPointer(const std::vector<int> inLengths, MakeArgumentPointer(const std::vector<index_t> inLengths,
const std::vector<int> inStrides, const std::vector<index_t> inStrides,
const std::vector<int> outLengths, const std::vector<index_t> outLengths,
const std::vector<int> outStrides, const std::vector<index_t> outStrides,
const std::vector<int> reduceDims, const std::vector<int> reduceDims,
float alpha, float alpha,
float beta, float beta,
const void* in_dev, const void* in_dev,
const void* in_index_dev,
void* out_dev, void* out_dev,
void* out_indices_dev, void* out_index_dev,
void* workspace_dev,
const InElementwiseOperation in_elementwise_op, const InElementwiseOperation in_elementwise_op,
const AccElementwiseOperation acc_elementwise_op) override const AccElementwiseOperation acc_elementwise_op) override
{ {
...@@ -388,9 +475,9 @@ struct DeviceReduceMultiBlockAtomicAdd ...@@ -388,9 +475,9 @@ struct DeviceReduceMultiBlockAtomicAdd
alpha, alpha,
beta, beta,
static_cast<const InDataType*>(in_dev), static_cast<const InDataType*>(in_dev),
static_cast<const IndexDataType*>(in_index_dev),
static_cast<OutDataType*>(out_dev), static_cast<OutDataType*>(out_dev),
static_cast<IndexDataType*>(out_indices_dev), static_cast<IndexDataType*>(out_index_dev),
static_cast<AccDataType*>(workspace_dev),
in_elementwise_op, in_elementwise_op,
acc_elementwise_op); acc_elementwise_op);
}; };
......
#ifndef DEVICE_REDUCE_MULTIBLOCK_PARTIAL_REDUCE_HPP
#define DEVICE_REDUCE_MULTIBLOCK_PARTIAL_REDUCE_HPP
#include <iostream>
#include <sstream>
#include "device.hpp"
#include "device_reduce.hpp"
#include "device_reduce_common.hpp"
#include "gridwise_2d_reduction_multiblock_partial_reduce.hpp"
namespace ck {
namespace tensor_operation {
namespace device {
template <typename InDataType,
typename AccDataType,
typename OutDataType,
index_t Rank,
index_t NumReduceDim,
typename ReduceOperation,
typename InElementwiseOperation,
typename AccElementwiseOperation,
bool PropagateNan,
bool NeedIndices,
index_t BlockSize,
index_t MThreadClusterSize,
index_t KThreadClusterSize,
index_t MThreadSliceSize,
index_t KThreadSliceSize,
index_t InSrcVectorDim,
index_t InSrcVectorSize,
index_t OutDstVectorSize>
struct DeviceReduceMultiBlockPartialReduce
: public DeviceReduce<InElementwiseOperation, AccElementwiseOperation>
{
static_assert(Rank <= 6, "Bigger Rank size is not supported!");
static_assert(BlockSize == MThreadClusterSize * KThreadClusterSize,
"Invalid thread cluster size assignments!");
static_assert((InSrcVectorDim == 0 && MThreadSliceSize % InSrcVectorSize == 0) ||
(InSrcVectorDim == 1 && KThreadSliceSize % InSrcVectorSize == 0),
"Invalid thread slice sizes and/or vector sizes configuration, please check!");
static_assert(OutDstVectorSize == 1, "OutDstVectorSize must be 1 for MultiBlockPartialReduce!");
using IndexDataType = int32_t;
static constexpr index_t NumInvariantDim = Rank - NumReduceDim;
static constexpr index_t numSrcDim = Rank;
static constexpr index_t numDstDim = (NumInvariantDim == 0) ? 1 : NumInvariantDim;
static constexpr bool reduceAllDim = (NumInvariantDim == 0);
static constexpr int M_BlockTileSize = MThreadClusterSize * MThreadSliceSize;
static constexpr int K_BlockTileSize = KThreadClusterSize * KThreadSliceSize;
static constexpr int MaxBlockGroupSize = 256;
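    // The workspace holds one partial accumulation value per (output element, block group)
    // pair; when indices are requested, an int32 index per pair is additionally appended after
    // the value buffer together with some alignment padding.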
long_index_t GetWorkspaceSizeInBytes(const std::vector<int> inLengths,
const std::vector<int> reduceDims) override
{
size_t invariant_total_length;
size_t reduce_total_length;
auto inLengths_ = shuffle_tensor_dimensions<Rank, NumReduceDim>(inLengths, reduceDims);
std::tie(invariant_total_length, reduce_total_length) =
get_2d_lengths<Rank, NumReduceDim>(inLengths_);
int iterations = 1;
while(true)
{
int testBlkGroupSize = (reduce_total_length + (K_BlockTileSize * iterations) - 1) /
(K_BlockTileSize * iterations);
if(testBlkGroupSize <= MaxBlockGroupSize)
break;
iterations++;
};
int blkGroupSize = (reduce_total_length + (K_BlockTileSize * iterations) - 1) /
(K_BlockTileSize * iterations);
long_index_t workspace_size = invariant_total_length * blkGroupSize;
long_index_t wsSizeInBytes =
!NeedIndices
? workspace_size * sizeof(AccDataType)
: workspace_size * (sizeof(AccDataType) + sizeof(int32_t)) + 64 + sizeof(int);
return (wsSizeInBytes);
};
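    // The partial results left in the workspace still need to be reduced along the block-group
    // dimension, so this method always requires a further (second) reduction call.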
bool HasFurtherCall() override { return (true); };
static auto MakeSrc2dDescriptor(const std::vector<int>& inLengths,
const std::vector<int>& inStrides,
int blkGroupSize,
int kBlockTileIterations)
{
const auto tupleSrcLengths = make_tuple_from_array(inLengths, Number<numSrcDim>{});
const auto tupleSrcStrides = make_tuple_from_array(inStrides, Number<numSrcDim>{});
const auto inDesc = make_naive_tensor_descriptor(tupleSrcLengths, tupleSrcStrides);
const auto in_grid_desc_m_k = [&]() {
if constexpr(reduceAllDim)
{
const auto one_dim_inDesc = transform_tensor_descriptor(
inDesc,
make_tuple(make_merge_transform(tupleSrcLengths)),
make_tuple(typename arithmetic_sequence_gen<0, numSrcDim, 1>::type{}),
make_tuple(Sequence<0>{}));
return transform_tensor_descriptor(one_dim_inDesc,
make_tuple(make_unmerge_transform(make_tuple(
1, one_dim_inDesc.GetLength(Number<0>{})))),
make_tuple(Sequence<0>{}),
make_tuple(Sequence<0, 1>{}));
}
else
{
using InvariantDims = typename arithmetic_sequence_gen<0, NumInvariantDim, 1>::type;
using ReduceDims = typename arithmetic_sequence_gen<NumInvariantDim, Rank, 1>::type;
const auto reduceDimLengths =
make_tuple_from_array_and_index_seq(inLengths, ReduceDims{});
const auto invariantDimLengths =
make_tuple_from_array_and_index_seq(inLengths, InvariantDims{});
return transform_tensor_descriptor(
inDesc,
make_tuple(make_merge_transform(invariantDimLengths),
make_merge_transform(reduceDimLengths)),
make_tuple(InvariantDims{}, ReduceDims{}),
make_tuple(Sequence<0>{}, Sequence<1>{}));
}
}();
const auto invariantLength = in_grid_desc_m_k.GetLength(Number<0>{});
const auto reduceLength = in_grid_desc_m_k.GetLength(Number<1>{});
const int reduceSizePerBlock = K_BlockTileSize * kBlockTileIterations;
const auto inPad_M =
math::integer_least_multiple(invariantLength, M_BlockTileSize) - invariantLength;
const auto inPad_K = reduceSizePerBlock * blkGroupSize - reduceLength;
auto in_grid_desc_m_k_padded = transform_tensor_descriptor(
in_grid_desc_m_k,
make_tuple(make_right_pad_transform(invariantLength, inPad_M),
make_right_pad_transform(reduceLength, inPad_K)),
make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}));
return (in_grid_desc_m_k_padded);
};
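    // The workspace is viewed as a packed (invariantLength, blkGroupSize) matrix; only the
    // invariant (M) dimension needs padding, since each block group contributes exactly one
    // partial value per output element.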
static auto MakeWorkspace2dDescriptor(int invariantLength, int blkGroupSize)
{
auto ws_desc_m_k =
make_naive_tensor_descriptor_packed(make_tuple(invariantLength, blkGroupSize));
const auto wsPad =
math::integer_least_multiple(invariantLength, M_BlockTileSize) - invariantLength;
auto ws_desc_m_k_padded =
transform_tensor_descriptor(ws_desc_m_k,
make_tuple(make_right_pad_transform(invariantLength, wsPad),
make_pass_through_transform(blkGroupSize)),
make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}));
return (ws_desc_m_k_padded);
};
struct Argument : public BaseArgument
{
Argument(const std::vector<int> inLengths,
const std::vector<int> inStrides,
const std::vector<int> outLengths,
const std::vector<int> outStrides,
const std::vector<int> reduceDims,
float alpha,
float beta,
const InDataType* in_dev,
OutDataType* out_dev,
IndexDataType* out_indices_dev,
AccDataType* workspace_dev,
const InElementwiseOperation in_elementwise_op,
const AccElementwiseOperation acc_elementwise_op)
: outLengths_{outLengths},
outStrides_{outStrides},
in_dev_{in_dev},
out_dev_{out_dev},
out_indices_dev_{out_indices_dev},
workspace_dev_{workspace_dev},
in_elementwise_op_{in_elementwise_op},
acc_elementwise_op_{acc_elementwise_op}
{
inLengths_ = shuffle_tensor_dimensions<Rank, NumReduceDim>(inLengths, reduceDims);
inStrides_ = shuffle_tensor_dimensions<Rank, NumReduceDim>(inStrides, reduceDims);
alpha_ = type_convert<AccDataType>(alpha);
beta_ = type_convert<AccDataType>(beta);
std::tie(invariant_total_length, reduce_total_length) =
get_2d_lengths<Rank, NumReduceDim>(inLengths_);
if constexpr(NumInvariantDim == 0)
invariant_lowest_length = 1;
else
invariant_lowest_length = inLengths_[NumInvariantDim - 1];
reduce_lowest_length = inLengths_[Rank - 1];
int iterations = 1;
while(true)
{
int testBlkGroupSize = (reduce_total_length + (K_BlockTileSize * iterations) - 1) /
(K_BlockTileSize * iterations);
if(testBlkGroupSize <= MaxBlockGroupSize)
break;
iterations++;
};
blkGroupSize = (reduce_total_length + (K_BlockTileSize * iterations) - 1) /
(K_BlockTileSize * iterations);
kBlockTileIterations = iterations;
gridSize = math::integer_least_multiple(invariant_total_length, M_BlockTileSize) /
M_BlockTileSize * blkGroupSize;
size_t ws_buf2_bytes_offset = math::integer_least_multiple(
invariant_total_length * blkGroupSize * sizeof(AccDataType), 64);
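            // Indices (when requested) live after the partial values at a 64-byte aligned
            // offset, which should be consistent with GetWorkspaceSizeInBytes() above.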
if constexpr(NeedIndices)
workspace_indices_dev_ = reinterpret_cast<int*>(
reinterpret_cast<char*>(workspace_dev_) + ws_buf2_bytes_offset);
else
workspace_indices_dev_ = nullptr;
}
std::vector<int> inLengths_;
std::vector<int> inStrides_;
std::vector<int> outLengths_;
std::vector<int> outStrides_;
AccDataType alpha_;
AccDataType beta_;
const InDataType* in_dev_;
OutDataType* out_dev_;
IndexDataType* out_indices_dev_;
AccDataType* workspace_dev_;
IndexDataType* workspace_indices_dev_;
InElementwiseOperation in_elementwise_op_;
AccElementwiseOperation acc_elementwise_op_;
int invariant_lowest_length;
int reduce_lowest_length;
size_t invariant_total_length;
size_t reduce_total_length;
index_t blkGroupSize;
index_t kBlockTileIterations;
size_t gridSize;
};
struct Invoker : public BaseInvoker
{
float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{})
{
const auto in_grid_desc_m_k = DeviceReduceMultiBlockPartialReduce::MakeSrc2dDescriptor(
arg.inLengths_, arg.inStrides_, arg.blkGroupSize, arg.kBlockTileIterations);
const auto ws_desc_m_k = DeviceReduceMultiBlockPartialReduce::MakeWorkspace2dDescriptor(
arg.invariant_total_length, arg.blkGroupSize);
using InGridDesc_M_K = decltype(in_grid_desc_m_k);
using WorkspaceDesc_M_K = decltype(ws_desc_m_k);
using GridwiseReduce =
GridwiseReduction_mk_to_mk_multiblock_partial_reduce<InDataType,
AccDataType,
IndexDataType,
InGridDesc_M_K,
WorkspaceDesc_M_K,
ReduceOperation,
InElementwiseOperation,
AccElementwiseOperation,
PropagateNan,
BlockSize,
MThreadClusterSize,
KThreadClusterSize,
MThreadSliceSize,
KThreadSliceSize,
InSrcVectorDim,
InSrcVectorSize,
OutDstVectorSize>;
float avg_time = 0;
const auto kernel = kernel_partial_reduce_multiblock<GridwiseReduce,
NeedIndices,
InDataType,
AccDataType,
IndexDataType,
InGridDesc_M_K,
WorkspaceDesc_M_K,
InElementwiseOperation,
AccElementwiseOperation>;
avg_time = launch_and_time_kernel(stream_config,
kernel,
dim3(arg.gridSize),
dim3(BlockSize),
0,
in_grid_desc_m_k,
ws_desc_m_k,
arg.in_elementwise_op_,
arg.acc_elementwise_op_,
arg.blkGroupSize,
arg.kBlockTileIterations,
arg.in_dev_,
arg.workspace_dev_,
arg.workspace_indices_dev_);
return (avg_time);
};
float Run(const BaseArgument* p_arg,
const StreamConfig& stream_config = StreamConfig{}) override
{
return Run(*dynamic_cast<const Argument*>(p_arg), stream_config);
}
};
bool IsSupportedArgument(const BaseArgument* p_arg) override
{
const Argument* pArg = dynamic_cast<const Argument*>(p_arg);
if constexpr(OutDstVectorSize != 1)
return (false);
if constexpr(InSrcVectorDim == 0)
{
if constexpr(NumInvariantDim == 0)
{
return (false);
}
else
{
if(pArg->inStrides_[NumInvariantDim - 1] != 1)
return (false);
if(pArg->invariant_lowest_length % InSrcVectorSize != 0)
return (false);
};
}
else
{
if(pArg->inStrides_[Rank - 1] != 1)
return (false);
if(pArg->reduce_lowest_length % InSrcVectorSize != 0)
return (false);
};
// cases with small reduce_total_length should be handled by the BlockWise method
if(pArg->reduce_total_length <= BlockSize * KThreadSliceSize)
return (false);
return (true);
};
std::vector<int> GetWorkspace2dLengths(const BaseArgument* p_arg) override
{
const Argument* pArg = dynamic_cast<const Argument*>(p_arg);
return (
std::vector<int>{static_cast<int>(pArg->invariant_total_length), pArg->blkGroupSize});
};
std::unique_ptr<BaseArgument>
MakeArgumentPointer(const std::vector<int> inLengths,
const std::vector<int> inStrides,
const std::vector<int> outLengths,
const std::vector<int> outStrides,
const std::vector<int> reduceDims,
float alpha,
float beta,
const void* in_dev,
void* out_dev,
void* out_indices_dev,
void* workspace_dev,
const InElementwiseOperation in_elementwise_op,
const AccElementwiseOperation acc_elementwise_op) override
{
return std::make_unique<Argument>(inLengths,
inStrides,
outLengths,
outStrides,
reduceDims,
alpha,
beta,
static_cast<const InDataType*>(in_dev),
static_cast<OutDataType*>(out_dev),
static_cast<IndexDataType*>(out_indices_dev),
static_cast<AccDataType*>(workspace_dev),
in_elementwise_op,
acc_elementwise_op);
};
std::unique_ptr<BaseInvoker> MakeInvokerPointer() override
{
return std::make_unique<Invoker>();
};
std::string GetTypeString() const override
{
auto str = std::stringstream();
// clang-format off
str << "DeviceReduceMultiBlockPartialReduce<" << BlockSize << ",";
str << "M_C" << MThreadClusterSize << "_S" << MThreadSliceSize << ",";
str << "K_C" << KThreadClusterSize << "_S" << KThreadSliceSize << ",";
str << "InSrcVectorDim_" << InSrcVectorDim << "_InSrcVectorSize_" << InSrcVectorSize << "_OutDstVectorSize_" << OutDstVectorSize << ">";
// clang-format on
return str.str();
}
};
} // namespace device
} // namespace tensor_operation
} // namespace ck
#endif
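For orientation, the sketch below shows how the two-stage path defined above is typically driven from the host side. It is only an illustrative sketch, not library code: `PartialReduceInstance`, `SecondCallInstance` and `alloc_device_buffer` are placeholder names for a concrete `DeviceReduceMultiBlockPartialReduce` instantiation, a matching `DeviceReduceBlockWiseSecondCall` instantiation, and whatever device allocator the caller uses; the tensor lengths/strides, scaling factors, device pointers and element-wise operators are assumed to be set up already.

```cpp
// Illustrative host-side driver for the two-stage reduction (placeholder names, no error handling).
PartialReduceInstance partial_reduce; // a DeviceReduceMultiBlockPartialReduce<...> instance
SecondCallInstance    second_call;    // a DeviceReduceBlockWiseSecondCall<...> instance

// 1) allocate the workspace that receives the per-block-group partial results
auto ws_bytes = partial_reduce.GetWorkspaceSizeInBytes(inLengths, reduceDims);
void* ws_dev  = alloc_device_buffer(ws_bytes); // placeholder allocator

// 2) first stage: every block group writes one partial value per output element
//    (alpha/beta are only consumed by the second stage in this path)
auto arg1 = partial_reduce.MakeArgumentPointer(inLengths, inStrides,
                                               outLengths, outStrides, reduceDims,
                                               alpha, beta,
                                               in_dev, out_dev, out_indices_dev, ws_dev,
                                               in_elementwise_op, acc_elementwise_op);
if(partial_reduce.IsSupportedArgument(arg1.get()))
    partial_reduce.MakeInvokerPointer()->Run(arg1.get());

// 3) second stage: reduce the (invariant_total_length, blkGroupSize) workspace down to the
//    final output; the 2-d lengths come from the first-stage argument
auto ws2dLengths = partial_reduce.GetWorkspace2dLengths(arg1.get());
std::vector<int> ws2dStrides{ws2dLengths[1], 1}; // workspace is packed row-major
auto arg2 = second_call.MakeArgumentPointer(ws2dLengths, ws2dStrides,
                                            outLengths, outStrides, reduceDims,
                                            alpha, beta,
                                            ws_dev, out_dev, out_indices_dev, ws_dev,
                                            in_elementwise_op, acc_elementwise_op);
if(second_call.IsSupportedArgument(arg2.get()))
    second_call.MakeInvokerPointer()->Run(arg2.get());
```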
...@@ -6,6 +6,7 @@ ...@@ -6,6 +6,7 @@
#include "device.hpp" #include "device.hpp"
#include "device_reduce.hpp" #include "device_reduce.hpp"
#include "device_reduce_common.hpp" #include "device_reduce_common.hpp"
#include "gridwise_2d_reduction_multiblock.hpp"
#include "gridwise_2d_reduction_threadwise.hpp" #include "gridwise_2d_reduction_threadwise.hpp"
namespace ck { namespace ck {
...@@ -19,22 +20,19 @@ template <typename InDataType, ...@@ -19,22 +20,19 @@ template <typename InDataType,
index_t NumReduceDim, index_t NumReduceDim,
typename ReduceOperation, typename ReduceOperation,
typename InElementwiseOperation, typename InElementwiseOperation,
typename OutElementwiseOperation, typename AccElementwiseOperation,
bool PropagateNan, bool PropagateNan,
bool NeedIndices, bool OutputIndex,
bool HaveIndexInputIfOutputIndex,
index_t BlockSize, index_t BlockSize,
index_t MThreadClusterSize,
index_t KThreadClusterSize,
index_t MThreadSliceSize, index_t MThreadSliceSize,
index_t KThreadSliceSize, index_t KThreadSliceSize,
index_t InSrcVectorDim, index_t InSrcVectorDim,
index_t InSrcVectorSize, index_t InSrcVectorSize,
index_t OutDstVectorSize> index_t OutDstVectorSize>
struct DeviceReduceThreadWise : public DeviceReduce<InElementwiseOperation, OutElementwiseOperation> struct DeviceReduceThreadWise : public DeviceReduce<InElementwiseOperation, AccElementwiseOperation>
{ {
static_assert(Rank <= 6, "Bigger Rank size is not supported!"); static_assert(Rank <= 6, "Bigger Rank size is not supported!");
static_assert((BlockSize == MThreadClusterSize) && (KThreadClusterSize == 1),
"Threadwise can only be called with KThreadClusterSize be 1 !");
static_assert(((InSrcVectorDim == 0 && MThreadSliceSize % InSrcVectorSize == 0) || static_assert(((InSrcVectorDim == 0 && MThreadSliceSize % InSrcVectorSize == 0) ||
(InSrcVectorDim == 1 && KThreadSliceSize % InSrcVectorSize == 0)) && (InSrcVectorDim == 1 && KThreadSliceSize % InSrcVectorSize == 0)) &&
...@@ -43,7 +41,7 @@ struct DeviceReduceThreadWise : public DeviceReduce<InElementwiseOperation, OutE ...@@ -43,7 +41,7 @@ struct DeviceReduceThreadWise : public DeviceReduce<InElementwiseOperation, OutE
using IndexDataType = int32_t; using IndexDataType = int32_t;
static constexpr bool BetaIsZero = NeedIndices; static constexpr bool HaveIndexInput = OutputIndex && HaveIndexInputIfOutputIndex;
static constexpr index_t NumInvariantDim = Rank - NumReduceDim; static constexpr index_t NumInvariantDim = Rank - NumReduceDim;
...@@ -51,11 +49,11 @@ struct DeviceReduceThreadWise : public DeviceReduce<InElementwiseOperation, OutE ...@@ -51,11 +49,11 @@ struct DeviceReduceThreadWise : public DeviceReduce<InElementwiseOperation, OutE
static constexpr index_t numDstDim = (NumInvariantDim == 0) ? 1 : NumInvariantDim; static constexpr index_t numDstDim = (NumInvariantDim == 0) ? 1 : NumInvariantDim;
static constexpr bool reduceAllDim = (NumInvariantDim == 0); static constexpr bool reduceAllDim = (NumInvariantDim == 0);
static constexpr int M_BlockTileSize = MThreadClusterSize * MThreadSliceSize; static constexpr index_t M_BlockTileSize = BlockSize * MThreadSliceSize;
static constexpr int K_BlockTileSize = KThreadClusterSize * KThreadSliceSize; static constexpr index_t K_BlockTileSize = 1 * KThreadSliceSize;
static auto MakeSrc2dDescriptor(const std::vector<int>& inLengths, static auto MakeSrc2dDescriptor(const std::vector<index_t>& inLengths,
const std::vector<int>& inStrides) const std::vector<index_t>& inStrides)
{ {
const auto tupleSrcLengths = make_tuple_from_array(inLengths, Number<numSrcDim>{}); const auto tupleSrcLengths = make_tuple_from_array(inLengths, Number<numSrcDim>{});
const auto tupleSrcStrides = make_tuple_from_array(inStrides, Number<numSrcDim>{}); const auto tupleSrcStrides = make_tuple_from_array(inStrides, Number<numSrcDim>{});
...@@ -114,8 +112,8 @@ struct DeviceReduceThreadWise : public DeviceReduce<InElementwiseOperation, OutE ...@@ -114,8 +112,8 @@ struct DeviceReduceThreadWise : public DeviceReduce<InElementwiseOperation, OutE
return (in_grid_desc_m_k_padded); return (in_grid_desc_m_k_padded);
}; };
static auto MakeDst1dDescriptor(const std::vector<int>& outLengths, static auto MakeDst1dDescriptor(const std::vector<index_t>& outLengths,
const std::vector<int>& outStrides) const std::vector<index_t>& outStrides)
{ {
const auto tupleDstLengths = make_tuple_from_array(outLengths, Number<numDstDim>{}); const auto tupleDstLengths = make_tuple_from_array(outLengths, Number<numDstDim>{});
const auto tupleDstStrides = make_tuple_from_array(outStrides, Number<numDstDim>{}); const auto tupleDstStrides = make_tuple_from_array(outStrides, Number<numDstDim>{});
...@@ -143,30 +141,26 @@ struct DeviceReduceThreadWise : public DeviceReduce<InElementwiseOperation, OutE ...@@ -143,30 +141,26 @@ struct DeviceReduceThreadWise : public DeviceReduce<InElementwiseOperation, OutE
struct Argument : public BaseArgument struct Argument : public BaseArgument
{ {
Argument(const std::vector<int> inLengths, Argument(const std::vector<index_t> inLengths,
const std::vector<int> inStrides, const std::vector<index_t> inStrides,
const std::vector<int> outLengths, const std::vector<index_t> outLengths,
const std::vector<int> outStrides, const std::vector<index_t> outStrides,
const std::vector<int> reduceDims, const std::vector<int> reduceDims,
float alpha, float alpha,
float beta, float beta,
const InDataType* in_dev, const InDataType* in_dev,
OutDataType* out_dev, OutDataType* out_dev,
IndexDataType* out_indices_dev, IndexDataType* out_index_dev,
AccDataType* workspace_dev,
const InElementwiseOperation in_elementwise_op, const InElementwiseOperation in_elementwise_op,
const OutElementwiseOperation acc_elementwise_op) const AccElementwiseOperation acc_elementwise_op)
: outLengths_{outLengths}, : outLengths_{outLengths},
outStrides_{outStrides}, outStrides_{outStrides},
in_dev_{in_dev}, in_dev_{in_dev},
out_dev_{out_dev}, out_dev_{out_dev},
out_indices_dev_{out_indices_dev}, out_index_dev_{out_index_dev},
in_elementwise_op_{in_elementwise_op}, in_elementwise_op_{in_elementwise_op},
acc_elementwise_op_{acc_elementwise_op} acc_elementwise_op_{acc_elementwise_op}
{ {
(void)workspace_dev;
inLengths_ = shuffle_tensor_dimensions<Rank, NumReduceDim>(inLengths, reduceDims); inLengths_ = shuffle_tensor_dimensions<Rank, NumReduceDim>(inLengths, reduceDims);
inStrides_ = shuffle_tensor_dimensions<Rank, NumReduceDim>(inStrides, reduceDims); inStrides_ = shuffle_tensor_dimensions<Rank, NumReduceDim>(inStrides, reduceDims);
...@@ -183,30 +177,33 @@ struct DeviceReduceThreadWise : public DeviceReduce<InElementwiseOperation, OutE ...@@ -183,30 +177,33 @@ struct DeviceReduceThreadWise : public DeviceReduce<InElementwiseOperation, OutE
reduce_lowest_length = inLengths_[Rank - 1]; reduce_lowest_length = inLengths_[Rank - 1];
numBlockTileIteration = (reduce_total_length + K_BlockTileSize - 1) / K_BlockTileSize;
gridSize = math::integer_least_multiple(invariant_total_length, M_BlockTileSize) / gridSize = math::integer_least_multiple(invariant_total_length, M_BlockTileSize) /
M_BlockTileSize; M_BlockTileSize;
} }
std::vector<int> inLengths_; std::vector<index_t> inLengths_;
std::vector<int> inStrides_; std::vector<index_t> inStrides_;
std::vector<int> outLengths_; std::vector<index_t> outLengths_;
std::vector<int> outStrides_; std::vector<index_t> outStrides_;
AccDataType alpha_; AccDataType alpha_;
AccDataType beta_; AccDataType beta_;
const InDataType* in_dev_; const InDataType* in_dev_;
OutDataType* out_dev_; OutDataType* out_dev_;
IndexDataType* out_indices_dev_; IndexDataType* out_index_dev_;
InElementwiseOperation in_elementwise_op_; InElementwiseOperation in_elementwise_op_;
OutElementwiseOperation acc_elementwise_op_; AccElementwiseOperation acc_elementwise_op_;
int invariant_lowest_length; index_t invariant_lowest_length;
int reduce_lowest_length; index_t reduce_lowest_length;
size_t invariant_total_length; long_index_t invariant_total_length;
size_t reduce_total_length; long_index_t reduce_total_length;
int numBlockTileIteration;
size_t gridSize; size_t gridSize;
}; };
...@@ -221,30 +218,30 @@ struct DeviceReduceThreadWise : public DeviceReduce<InElementwiseOperation, OutE ...@@ -221,30 +218,30 @@ struct DeviceReduceThreadWise : public DeviceReduce<InElementwiseOperation, OutE
using InGridDesc_M_K = decltype(in_grid_desc_m_k); using InGridDesc_M_K = decltype(in_grid_desc_m_k);
using OutGridDesc_M = decltype(out_grid_desc_m); using OutGridDesc_M = decltype(out_grid_desc_m);
using GridwiseReduce = GridwiseReduction_mk_to_m_threadwise<InDataType,
OutDataType,
AccDataType,
IndexDataType,
InGridDesc_M_K,
OutGridDesc_M,
ReduceOperation,
InElementwiseOperation,
OutElementwiseOperation,
PropagateNan,
BetaIsZero,
BlockSize,
MThreadClusterSize,
KThreadClusterSize,
MThreadSliceSize,
KThreadSliceSize,
InSrcVectorDim,
InSrcVectorSize,
OutDstVectorSize>;
float avg_time = 0; float avg_time = 0;
using GridwiseReduce =
GridwiseReduction_mk_to_m_threadwise<InDataType,
OutDataType,
AccDataType,
IndexDataType,
InGridDesc_M_K,
OutGridDesc_M,
ReduceOperation,
InElementwiseOperation,
AccElementwiseOperation,
InMemoryDataOperationEnum::Set,
PropagateNan,
BlockSize,
MThreadSliceSize,
KThreadSliceSize,
InSrcVectorDim,
InSrcVectorSize,
OutDstVectorSize>;
const auto kernel = kernel_reduce_threadwise<GridwiseReduce, const auto kernel = kernel_reduce_threadwise<GridwiseReduce,
NeedIndices, OutputIndex,
HaveIndexInput,
InDataType, InDataType,
OutDataType, OutDataType,
AccDataType, AccDataType,
...@@ -252,7 +249,7 @@ struct DeviceReduceThreadWise : public DeviceReduce<InElementwiseOperation, OutE ...@@ -252,7 +249,7 @@ struct DeviceReduceThreadWise : public DeviceReduce<InElementwiseOperation, OutE
InGridDesc_M_K, InGridDesc_M_K,
OutGridDesc_M, OutGridDesc_M,
InElementwiseOperation, InElementwiseOperation,
OutElementwiseOperation>; AccElementwiseOperation>;
avg_time = launch_and_time_kernel(stream_config, avg_time = launch_and_time_kernel(stream_config,
kernel, kernel,
...@@ -265,9 +262,10 @@ struct DeviceReduceThreadWise : public DeviceReduce<InElementwiseOperation, OutE ...@@ -265,9 +262,10 @@ struct DeviceReduceThreadWise : public DeviceReduce<InElementwiseOperation, OutE
arg.acc_elementwise_op_, arg.acc_elementwise_op_,
arg.alpha_, arg.alpha_,
arg.in_dev_, arg.in_dev_,
nullptr,
arg.beta_, arg.beta_,
arg.out_dev_, arg.out_dev_,
arg.out_indices_dev_); arg.out_index_dev_);
return (avg_time); return (avg_time);
}; };
...@@ -276,7 +274,7 @@ struct DeviceReduceThreadWise : public DeviceReduce<InElementwiseOperation, OutE ...@@ -276,7 +274,7 @@ struct DeviceReduceThreadWise : public DeviceReduce<InElementwiseOperation, OutE
const StreamConfig& stream_config = StreamConfig{}) override const StreamConfig& stream_config = StreamConfig{}) override
{ {
return Run(*dynamic_cast<const Argument*>(p_arg), stream_config); return Run(*dynamic_cast<const Argument*>(p_arg), stream_config);
} };
}; };
bool IsSupportedArgument(const BaseArgument* p_arg) override bool IsSupportedArgument(const BaseArgument* p_arg) override
...@@ -311,9 +309,7 @@ struct DeviceReduceThreadWise : public DeviceReduce<InElementwiseOperation, OutE ...@@ -311,9 +309,7 @@ struct DeviceReduceThreadWise : public DeviceReduce<InElementwiseOperation, OutE
if(pArg->invariant_lowest_length % OutDstVectorSize != 0) if(pArg->invariant_lowest_length % OutDstVectorSize != 0)
return (false); return (false);
// TODO: remove this. Should return true, as long as this DeviceOP instance support this // cases with big reduce_total_length should be handled by Blockwise kernel
// case for bigger reduce_total_length size, we are supposed to use BlockWise method for
// better performance
if(pArg->reduce_total_length / KThreadSliceSize >= 32) if(pArg->reduce_total_length / KThreadSliceSize >= 32)
return (false); return (false);
...@@ -321,20 +317,22 @@ struct DeviceReduceThreadWise : public DeviceReduce<InElementwiseOperation, OutE ...@@ -321,20 +317,22 @@ struct DeviceReduceThreadWise : public DeviceReduce<InElementwiseOperation, OutE
}; };
std::unique_ptr<BaseArgument> std::unique_ptr<BaseArgument>
MakeArgumentPointer(const std::vector<int> inLengths, MakeArgumentPointer(const std::vector<index_t> inLengths,
const std::vector<int> inStrides, const std::vector<index_t> inStrides,
const std::vector<int> outLengths, const std::vector<index_t> outLengths,
const std::vector<int> outStrides, const std::vector<index_t> outStrides,
const std::vector<int> reduceDims, const std::vector<int> reduceDims,
float alpha, float alpha,
float beta, float beta,
const void* in_dev, const void* in_dev,
const void* in_index_dev,
void* out_dev, void* out_dev,
void* out_indices_dev, void* out_index_dev,
void* workspace_dev,
const InElementwiseOperation in_elementwise_op, const InElementwiseOperation in_elementwise_op,
const OutElementwiseOperation acc_elementwise_op) override const AccElementwiseOperation acc_elementwise_op) override
{ {
(void)in_index_dev;
return std::make_unique<Argument>(inLengths, return std::make_unique<Argument>(inLengths,
inStrides, inStrides,
outLengths, outLengths,
...@@ -344,8 +342,7 @@ struct DeviceReduceThreadWise : public DeviceReduce<InElementwiseOperation, OutE ...@@ -344,8 +342,7 @@ struct DeviceReduceThreadWise : public DeviceReduce<InElementwiseOperation, OutE
beta, beta,
static_cast<const InDataType*>(in_dev), static_cast<const InDataType*>(in_dev),
static_cast<OutDataType*>(out_dev), static_cast<OutDataType*>(out_dev),
static_cast<IndexDataType*>(out_indices_dev), static_cast<IndexDataType*>(out_index_dev),
static_cast<AccDataType*>(workspace_dev),
in_elementwise_op, in_elementwise_op,
acc_elementwise_op); acc_elementwise_op);
}; };
...@@ -360,9 +357,9 @@ struct DeviceReduceThreadWise : public DeviceReduce<InElementwiseOperation, OutE ...@@ -360,9 +357,9 @@ struct DeviceReduceThreadWise : public DeviceReduce<InElementwiseOperation, OutE
auto str = std::stringstream(); auto str = std::stringstream();
// clang-format off // clang-format off
str << "DeviceReducceThreadWise<" << BlockSize << ","; str << "DeviceReduceThreadWise<" << BlockSize << ",";
str << "M_C" << MThreadClusterSize << "_S" << MThreadSliceSize << ","; str << "M_C" << BlockSize << "_S" << MThreadSliceSize << ",";
str << "K_C" << KThreadClusterSize << "_S" << KThreadSliceSize << ","; str << "K_C" << 1 << "_S" << KThreadSliceSize << ",";
str << "InSrcVectorDim_" << InSrcVectorDim << "_InSrcVectorSize_" << InSrcVectorSize << "_OutDstVectorSize_" << OutDstVectorSize << ">"; str << "InSrcVectorDim_" << InSrcVectorDim << "_InSrcVectorSize_" << InSrcVectorSize << "_OutDstVectorSize_" << OutDstVectorSize << ">";
// clang-format on // clang-format on
......
/*******************************************************************************
*
* MIT License
*
* Copyright (c) 2021 Advanced Micro Devices, Inc.
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
* in the Software without restriction, including without limitation the rights
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
* copies of the Software, and to permit persons to whom the Software is
* furnished to do so, subject to the following conditions:
*
* The above copyright notice and this permission notice shall be included in all
* copies or substantial portions of the Software.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
* SOFTWARE.
*
*******************************************************************************/
#ifndef CK_GRIDWISE_2D_REDUCTION_BLOCKWISE_HPP
#define CK_GRIDWISE_2D_REDUCTION_BLOCKWISE_HPP
#include "data_type.hpp"
#include "reduction_common.hpp"
#include "reduction_operator.hpp"
#include "reduction_functions_accumulate.hpp"
#include "reduction_functions_blockwise.hpp"
#include "reduction_functions_threadwise.hpp"
#include "threadwise_tensor_slice_transfer.hpp"
#include "cluster_descriptor.hpp"
#include "element_wise_operation.hpp"
namespace ck {
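// One-pass blockwise reduction kernel: dispatches to GridwiseReduction::Run when no index
// output is needed, or to RunWithIndex when the reduction also has to report the index of the
// selected element.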
template <typename GridwiseReduction,
bool NeedIndices,
typename InDataType,
typename OutDataType,
typename AccDataType,
typename IndexDataType,
typename InGridDesc_M_K,
typename OutGridDesc_M,
typename InElementwiseOperation,
typename OutElementwiseOperation>
__global__ void kernel_reduce_blockwise(const InGridDesc_M_K in_grid_desc_m_k,
const OutGridDesc_M out_grid_desc_m,
const InElementwiseOperation in_elementwise_op,
const OutElementwiseOperation acc_elementwise_op,
AccDataType alpha,
const InDataType* const __restrict__ p_in_global,
AccDataType beta,
OutDataType* const __restrict__ p_out_global,
const IndexDataType* const __restrict__ p_ws_indices_global,
IndexDataType* const __restrict__ p_indices_global)
{
if constexpr(!NeedIndices)
{
constexpr bool IsSecondCall = false;
GridwiseReduction::template Run<IsSecondCall>(in_grid_desc_m_k,
out_grid_desc_m,
in_elementwise_op,
acc_elementwise_op,
alpha,
p_in_global,
beta,
p_out_global,
p_ws_indices_global,
p_indices_global);
}
else
{
GridwiseReduction::RunWithIndex(in_grid_desc_m_k,
out_grid_desc_m,
in_elementwise_op,
acc_elementwise_op,
alpha,
p_in_global,
beta,
p_out_global,
p_ws_indices_global,
p_indices_global);
};
};
template <typename GridwiseReduction,
bool NeedIndices,
typename InDataType,
typename OutDataType,
typename AccDataType,
typename IndexDataType,
typename InGridDesc_M_K,
typename OutGridDesc_M,
typename InElementwiseOperation,
typename OutElementwiseOperation>
__global__ void
kernel_reduce_blockwise_second_call(const InGridDesc_M_K in_grid_desc_m_k,
const OutGridDesc_M out_grid_desc_m,
const InElementwiseOperation in_elementwise_op,
const OutElementwiseOperation acc_elementwise_op,
AccDataType alpha,
const InDataType* const __restrict__ p_in_global,
AccDataType beta,
OutDataType* const __restrict__ p_out_global,
const IndexDataType* const __restrict__ p_ws_indices_global,
IndexDataType* const __restrict__ p_indices_global)
{
if constexpr(!NeedIndices)
{
constexpr bool IsSecondCall = true;
GridwiseReduction::template Run<IsSecondCall>(in_grid_desc_m_k,
out_grid_desc_m,
in_elementwise_op,
acc_elementwise_op,
alpha,
p_in_global,
beta,
p_out_global,
p_ws_indices_global,
p_indices_global);
}
else
{
GridwiseReduction::RunSecondCallWithIndex(in_grid_desc_m_k,
out_grid_desc_m,
in_elementwise_op,
acc_elementwise_op,
alpha,
p_in_global,
beta,
p_out_global,
p_ws_indices_global,
p_indices_global);
};
};
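// Blockwise gridwise reduction: reduces a 2-D (M, K) view along K so that each workgroup
// produces M_BlockTileSize output elements. Threads are arranged as an
// MThreadClusterSize x KThreadClusterSize cluster, each thread owning an
// MThreadSliceSize x KThreadSliceSize slice per tile.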
template <typename InDataType,
typename OutDataType,
typename AccDataType,
typename IndexDataType,
typename InGridDesc_M_K,
typename OutGridDesc_M,
typename ReduceOperation,
typename InElementwiseOperation,
typename OutElementwiseOperation,
bool PropagateNan,
bool BetaIsZero,
index_t BlockSize,
index_t MThreadClusterSize,
index_t KThreadClusterSize,
index_t MThreadSliceSize,
index_t KThreadSliceSize,
index_t InSrcVectorDim,
index_t InSrcVectorSize,
index_t OutDstVectorSize>
struct GridwiseReduction_mk_to_m_blockwise
{
static_assert(((InSrcVectorDim == 0 && MThreadSliceSize % InSrcVectorSize == 0) ||
(InSrcVectorDim == 1 && KThreadSliceSize % InSrcVectorSize == 0)) &&
(MThreadSliceSize % OutDstVectorSize == 0),
"Invalid thread slice sizes and/or vector sizes configuration, please check!");
static constexpr bool reorder_thread_cluster = (InSrcVectorDim == 0);
using ThreadClusterLengths_M_K = Sequence<MThreadClusterSize, KThreadClusterSize>;
using ThreadBufferDimAccessOrder =
typename conditional<reorder_thread_cluster, Sequence<1, 0>, Sequence<0, 1>>::type;
using ThreadClusterArrangeOrder =
typename conditional<reorder_thread_cluster, Sequence<1, 0>, Sequence<0, 1>>::type;
static constexpr auto thread_cluster_desc =
make_cluster_descriptor(ThreadClusterLengths_M_K{}, ThreadClusterArrangeOrder{});
using ThreadReduceSrcDesc_M_K = decltype(make_naive_tensor_descriptor_packed(
make_tuple(Number<MThreadSliceSize>{}, Number<KThreadSliceSize>{})));
using ThreadReduceDstDesc_M =
decltype(make_naive_tensor_descriptor_packed(make_tuple(Number<MThreadSliceSize>{})));
using PassThroughOp = tensor_operation::element_wise::PassThrough;
static constexpr auto I0 = Number<0>{};
static constexpr auto I1 = Number<1>{};
static constexpr index_t M_BlockTileSize = MThreadClusterSize * MThreadSliceSize;
static constexpr index_t K_BlockTileSize = KThreadClusterSize * KThreadSliceSize;
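// Value-only reduction: accumulate per-thread partial results over all K tiles, combine them
// across the block through LDS, then let the first K-cluster column apply the accumulator
// elementwise op, alpha scaling and optional beta blending before storing to global memory.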
template <bool IsSecondCall>
__device__ static void Run(const InGridDesc_M_K& in_grid_desc_m_k,
const OutGridDesc_M& out_grid_desc_m,
const InElementwiseOperation& in_elementwise_op,
const OutElementwiseOperation& acc_elementwise_op,
AccDataType alpha,
const InDataType* const __restrict__ p_in_global,
AccDataType beta,
OutDataType* const __restrict__ p_out_global,
const IndexDataType* const __restrict__ p_ws_indices_global,
IndexDataType* const __restrict__ p_indices_global)
{
if constexpr(IsSecondCall)
{
static_assert(InSrcVectorDim == 1,
"InSrcVectorDim must be 1 for BlockwiseSecondCall, please check!");
};
using BlockwiseReduce = PartitionedBlockwiseReduction<AccDataType,
BlockSize,
ThreadClusterLengths_M_K,
ThreadClusterArrangeOrder,
ReduceOperation,
PropagateNan>;
using ThreadwiseReduce = ThreadwiseReduction<AccDataType,
ThreadReduceSrcDesc_M_K,
ThreadReduceDstDesc_M,
ReduceOperation,
PropagateNan>;
(void)p_ws_indices_global;
(void)p_indices_global;
// LDS
__shared__ AccDataType p_reduce_work_buffer[BlockSize];
const auto zeroVal = ReduceOperation::GetReductionZeroVal();
const auto in_global_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_in_global, in_grid_desc_m_k.GetElementSpaceSize(), type_convert<InDataType>(zeroVal));
auto out_global_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_out_global, out_grid_desc_m.GetElementSpaceSize());
auto reduce_work_buf =
make_dynamic_buffer<AddressSpaceEnum::Lds>(p_reduce_work_buffer, BlockSize);
StaticBuffer<AddressSpaceEnum::Vgpr, AccDataType, MThreadSliceSize * KThreadSliceSize, true>
in_thread_buf;
StaticBuffer<AddressSpaceEnum::Vgpr, AccDataType, MThreadSliceSize, true> accu_value_buf;
static_for<0, MThreadSliceSize, 1>{}([&](auto I) { accu_value_buf(I) = zeroVal; });
const auto toReduceLength = in_grid_desc_m_k.GetLength(Number<1>{});
const index_t thread_local_id = get_thread_local_1d_id();
const index_t block_global_1d_id = get_block_1d_id();
const auto thread_cluster_idx =
thread_cluster_desc.CalculateBottomIndex(make_multi_index(thread_local_id));
const auto thread_m_cluster_id = thread_cluster_idx[I0];
const auto thread_k_cluster_id = thread_cluster_idx[I1];
using ThreadBufferLengths = Sequence<MThreadSliceSize, KThreadSliceSize>;
constexpr auto thread_buffer_desc = make_naive_tensor_descriptor_packed(
make_tuple(Number<MThreadSliceSize>{}, Number<KThreadSliceSize>{}));
auto threadwise_src_load = ThreadwiseTensorSliceTransfer_v2<InDataType,
AccDataType,
InGridDesc_M_K,
decltype(thread_buffer_desc),
ThreadBufferLengths,
ThreadBufferDimAccessOrder,
InSrcVectorDim,
InSrcVectorSize,
1,
false>(
in_grid_desc_m_k,
make_multi_index(block_global_1d_id * M_BlockTileSize +
thread_m_cluster_id * MThreadSliceSize,
thread_k_cluster_id * KThreadSliceSize));
constexpr auto in_thread_copy_step = make_multi_index(0, K_BlockTileSize);
const index_t toReduceTiles = (toReduceLength + K_BlockTileSize - 1) / K_BlockTileSize;
index_t reducedTiles = 0;
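// Sweep the K dimension tile by tile: load this thread's slice, apply the elementwise
// pre-reduction op in place, and fold the slice into the per-thread accumulators.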
do
{
threadwise_src_load.Run(in_grid_desc_m_k,
in_global_buf,
thread_buffer_desc,
make_tuple(I0, I0),
in_thread_buf);
static_for<0, MThreadSliceSize, 1>{}([&](auto iM) {
// do element-wise pre-reduction operation
static_for<0, KThreadSliceSize, 1>{}([&](auto iK) {
constexpr auto offset = thread_buffer_desc.CalculateOffset(make_tuple(iM, iK));
in_elementwise_op(in_thread_buf(Number<offset>{}),
in_thread_buf(Number<offset>{}));
});
});
ThreadwiseReduce::Reduce(in_thread_buf, accu_value_buf);
threadwise_src_load.MoveSrcSliceWindow(in_grid_desc_m_k, in_thread_copy_step);
reducedTiles++;
} while(reducedTiles < toReduceTiles);
constexpr auto reduced_data_desc = ThreadReduceDstDesc_M{};
static_for<0, MThreadSliceSize, 1>{}(
[&](auto I) { BlockwiseReduce::Reduce(reduce_work_buf, accu_value_buf(I)); });
static_for<0, MThreadSliceSize, 1>{}([&](auto I) {
if(thread_k_cluster_id == 0)
{
acc_elementwise_op(accu_value_buf(I), accu_value_buf(I));
accu_value_buf(I) *= alpha;
}
});
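// Only the threads in the first K-cluster column hold the block-reduced results; they blend in
// beta * prior output when required and write the final values back to global memory.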
if(thread_k_cluster_id == 0)
{
if constexpr(!BetaIsZero)
{
if(!float_equal_zero{}(beta))
{
StaticBuffer<AddressSpaceEnum::Vgpr, OutDataType, MThreadSliceSize, true>
priorDstValueBuf;
auto threadwise_dst_load =
ThreadwiseTensorSliceTransfer_v2<OutDataType,
OutDataType,
OutGridDesc_M,
decltype(reduced_data_desc),
Sequence<MThreadSliceSize>,
Sequence<0>,
0,
OutDstVectorSize,
1,
false>(
out_grid_desc_m,
make_multi_index(block_global_1d_id * M_BlockTileSize +
thread_m_cluster_id * MThreadSliceSize));
threadwise_dst_load.Run(out_grid_desc_m,
out_global_buf,
reduced_data_desc,
make_tuple(I0),
priorDstValueBuf);
static_for<0, MThreadSliceSize, 1>{}([&](auto I) {
accu_value_buf(I) += type_convert<AccDataType>(priorDstValueBuf[I]) * beta;
});
};
};
auto threadwise_dst_store =
ThreadwiseTensorSliceTransfer_v1r3<AccDataType,
OutDataType,
decltype(reduced_data_desc),
OutGridDesc_M,
PassThroughOp,
Sequence<MThreadSliceSize>,
Sequence<0>,
0,
OutDstVectorSize,
InMemoryDataOperationEnum::Set,
1,
true>(
out_grid_desc_m,
make_multi_index(block_global_1d_id * M_BlockTileSize +
thread_m_cluster_id * MThreadSliceSize),
PassThroughOp{});
threadwise_dst_store.Run(
reduced_data_desc, make_tuple(I0), accu_value_buf, out_grid_desc_m, out_global_buf);
}
};
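// Indexed reduction: alongside each reduced value, track the K position at which it was
// selected. Per-thread accumulation uses AccumulateWithIndexAndNanCheck, the cross-block step
// uses PartitionedBlockwiseReductionWithIndex, and both the value and its index are written out.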
__device__ static void RunWithIndex(const InGridDesc_M_K& in_grid_desc_m_k,
const OutGridDesc_M& out_grid_desc_m,
const InElementwiseOperation& in_elementwise_op,
const OutElementwiseOperation& acc_elementwise_op,
AccDataType alpha,
const InDataType* const __restrict__ p_in_global,
AccDataType beta,
OutDataType* const __restrict__ p_out_global,
const IndexDataType* const __restrict__ p_ws_indices_global,
IndexDataType* const __restrict__ p_indices_global)
{
using BlockwiseReduceWithIndex =
PartitionedBlockwiseReductionWithIndex<AccDataType,
IndexDataType,
BlockSize,
ThreadClusterLengths_M_K,
ThreadClusterArrangeOrder,
ReduceOperation,
PropagateNan>;
using AccumulationWithIndex = detail::AccumulateWithIndexAndNanCheck<PropagateNan,
ReduceOperation,
AccDataType,
IndexDataType>;
(void)p_ws_indices_global;
// LDS
__shared__ AccDataType p_reduce_work_val_buffer[BlockSize];
__shared__ IndexDataType p_reduce_work_idx_buffer[BlockSize];
const auto zeroVal = ReduceOperation::GetReductionZeroVal();
const auto in_global_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_in_global, in_grid_desc_m_k.GetElementSpaceSize(), type_convert<InDataType>(zeroVal));
auto out_global_val_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_out_global, out_grid_desc_m.GetElementSpaceSize());
auto out_global_idx_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_indices_global, out_grid_desc_m.GetElementSpaceSize());
auto reduce_work_val_buf =
make_dynamic_buffer<AddressSpaceEnum::Lds>(p_reduce_work_val_buffer, BlockSize);
auto reduce_work_idx_buf =
make_dynamic_buffer<AddressSpaceEnum::Lds>(p_reduce_work_idx_buffer, BlockSize);
StaticBuffer<AddressSpaceEnum::Vgpr, AccDataType, MThreadSliceSize * KThreadSliceSize, true>
in_thread_val_buf;
StaticBuffer<AddressSpaceEnum::Vgpr,
IndexDataType,
MThreadSliceSize * KThreadSliceSize,
true>
in_thread_idx_buf;
StaticBuffer<AddressSpaceEnum::Vgpr, AccDataType, MThreadSliceSize, true> accu_value_buf;
StaticBuffer<AddressSpaceEnum::Vgpr, IndexDataType, MThreadSliceSize, true> accu_index_buf;
const auto toReduceLength = in_grid_desc_m_k.GetLength(Number<1>{});
const index_t thread_local_id = get_thread_local_1d_id();
const index_t block_global_1d_id = get_block_1d_id();
const auto thread_cluster_idx =
thread_cluster_desc.CalculateBottomIndex(make_multi_index(thread_local_id));
const auto thread_m_cluster_id = thread_cluster_idx[I0];
const auto thread_k_cluster_id = thread_cluster_idx[I1];
using ThreadBufferLengths = Sequence<MThreadSliceSize, KThreadSliceSize>;
constexpr auto thread_buffer_desc = make_naive_tensor_descriptor_packed(
make_tuple(Number<MThreadSliceSize>{}, Number<KThreadSliceSize>{}));
auto threadwise_src_load = ThreadwiseTensorSliceTransfer_v2<InDataType,
AccDataType,
InGridDesc_M_K,
decltype(thread_buffer_desc),
ThreadBufferLengths,
ThreadBufferDimAccessOrder,
InSrcVectorDim,
InSrcVectorSize,
1,
false>(
in_grid_desc_m_k,
make_multi_index(block_global_1d_id * M_BlockTileSize +
thread_m_cluster_id * MThreadSliceSize,
thread_k_cluster_id * KThreadSliceSize));
index_t indexOffset = 0;
static_for<0, MThreadSliceSize, 1>{}([&](auto I) {
accu_value_buf(I) = zeroVal;
accu_index_buf(I) = 0;
});
constexpr auto in_thread_copy_step = make_multi_index(0, K_BlockTileSize);
const index_t toReduceTiles = (toReduceLength + K_BlockTileSize - 1) / K_BlockTileSize;
index_t reducedTiles = 0;
do
{
// load the thread slice
threadwise_src_load.Run(in_grid_desc_m_k,
in_global_buf,
thread_buffer_desc,
make_tuple(I0, I0),
in_thread_val_buf);
static_for<0, MThreadSliceSize, 1>{}([&](auto iM) {
static_for<0, KThreadSliceSize, 1>{}([&](auto iK) {
constexpr auto offset = thread_buffer_desc.CalculateOffset(make_tuple(iM, iK));
// initialize the indices for the per-thread to-reduce values
in_thread_idx_buf(Number<offset>{}) =
indexOffset + thread_k_cluster_id * KThreadSliceSize + iK();
// do element-wise pre-reduction operation
in_elementwise_op(in_thread_val_buf(Number<offset>{}),
in_thread_val_buf(Number<offset>{}));
});
AccDataType tmpValue = zeroVal;
IndexDataType tmpIndex = 0;
static_for<0, KThreadSliceSize, 1>{}([&](auto iK) {
constexpr auto offset = thread_buffer_desc.CalculateOffset(make_tuple(iM, iK));
AccumulationWithIndex::Calculate(tmpValue,
in_thread_val_buf[Number<offset>{}],
tmpIndex,
in_thread_idx_buf[Number<offset>{}]);
});
BlockwiseReduceWithIndex::Reduce(
reduce_work_val_buf, reduce_work_idx_buf, tmpValue, tmpIndex);
AccumulationWithIndex::Calculate(
accu_value_buf(iM), tmpValue, accu_index_buf(iM), tmpIndex);
});
threadwise_src_load.MoveSrcSliceWindow(in_grid_desc_m_k, in_thread_copy_step);
indexOffset += K_BlockTileSize;
reducedTiles++;
} while(reducedTiles < toReduceTiles);
constexpr auto reduced_data_desc = ThreadReduceDstDesc_M{};
static_for<0, MThreadSliceSize, 1>{}([&](auto I) {
if(thread_k_cluster_id == 0)
{
// for an indexed reduction, acc_elementwise_op should do nothing
acc_elementwise_op(accu_value_buf(I), accu_value_buf(I));
accu_value_buf(I) *= alpha;
}
});
if(thread_k_cluster_id == 0)
{
if constexpr(!BetaIsZero)
{
if(!float_equal_zero{}(beta))
{
StaticBuffer<AddressSpaceEnum::Vgpr, OutDataType, MThreadSliceSize, true>
priorDstValueBuf;
auto threadwise_dst_load =
ThreadwiseTensorSliceTransfer_v2<OutDataType,
OutDataType,
OutGridDesc_M,
decltype(reduced_data_desc),
Sequence<MThreadSliceSize>,
Sequence<0>,
0,
OutDstVectorSize,
1,
false>(
out_grid_desc_m,
make_multi_index(block_global_1d_id * M_BlockTileSize +
thread_m_cluster_id * MThreadSliceSize));
threadwise_dst_load.Run(out_grid_desc_m,
out_global_val_buf,
reduced_data_desc,
make_tuple(I0),
priorDstValueBuf);
static_for<0, MThreadSliceSize, 1>{}([&](auto I) {
accu_value_buf(I) += type_convert<AccDataType>(priorDstValueBuf[I]) * beta;
});
};
};
auto threadwise_dst_val_store =
ThreadwiseTensorSliceTransfer_v1r3<AccDataType,
OutDataType,
decltype(reduced_data_desc),
OutGridDesc_M,
PassThroughOp,
Sequence<MThreadSliceSize>,
Sequence<0>,
0,
OutDstVectorSize,
InMemoryDataOperationEnum::Set,
1,
false>(
out_grid_desc_m,
make_multi_index(block_global_1d_id * M_BlockTileSize +
thread_m_cluster_id * MThreadSliceSize),
PassThroughOp{});
auto threadwise_dst_idx_store =
ThreadwiseTensorSliceTransfer_v1r3<IndexDataType,
IndexDataType,
decltype(reduced_data_desc),
OutGridDesc_M,
PassThroughOp,
Sequence<MThreadSliceSize>,
Sequence<0>,
0,
OutDstVectorSize,
InMemoryDataOperationEnum::Set,
1,
false>(
out_grid_desc_m,
make_multi_index(block_global_1d_id * M_BlockTileSize +
thread_m_cluster_id * MThreadSliceSize),
PassThroughOp{});
threadwise_dst_val_store.Run(reduced_data_desc,
make_tuple(I0),
accu_value_buf,
out_grid_desc_m,
out_global_val_buf);
threadwise_dst_idx_store.Run(reduced_data_desc,
make_tuple(I0),
accu_index_buf,
out_grid_desc_m,
out_global_idx_buf);
}
};
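// Second stage of a two-pass indexed reduction: the values and indices produced by the first
// (multiblock) pass are re-loaded from the workspace and reduced again, so the index
// initialization and the input elementwise op are not needed here.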
__device__ static void
RunSecondCallWithIndex(const InGridDesc_M_K& in_grid_desc_m_k,
const OutGridDesc_M& out_grid_desc_m,
const InElementwiseOperation in_elementwise_op,
const OutElementwiseOperation acc_elementwise_op,
AccDataType alpha,
const InDataType* const __restrict__ p_ws_values_global,
AccDataType beta,
OutDataType* const __restrict__ p_out_global,
const IndexDataType* const __restrict__ p_ws_indices_global,
IndexDataType* const __restrict__ p_indices_global)
{
static_assert(InSrcVectorDim == 1,
"InSrcVectorDim must be 1 for BlockwiseSecondCall, please check!");
using BlockwiseReduceWithIndex =
PartitionedBlockwiseReductionWithIndex<AccDataType,
IndexDataType,
BlockSize,
Sequence<MThreadClusterSize, KThreadClusterSize>,
ThreadClusterArrangeOrder,
ReduceOperation,
PropagateNan>;
using AccumulationWithIndex = detail::AccumulateWithIndexAndNanCheck<PropagateNan,
ReduceOperation,
AccDataType,
IndexDataType>;
(void)in_elementwise_op;
// LDS
__shared__ AccDataType p_reduce_work_val_buffer[BlockSize];
__shared__ IndexDataType p_reduce_work_idx_buffer[BlockSize];
const auto zeroVal = ReduceOperation::GetReductionZeroVal();
const auto src_global_val_buf =
make_dynamic_buffer<AddressSpaceEnum::Global>(p_ws_values_global,
in_grid_desc_m_k.GetElementSpaceSize(),
type_convert<InDataType>(zeroVal));
const auto src_global_idx_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_ws_indices_global, in_grid_desc_m_k.GetElementSpaceSize());
auto out_global_val_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_out_global, out_grid_desc_m.GetElementSpaceSize());
auto out_global_idx_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_indices_global, out_grid_desc_m.GetElementSpaceSize());
auto reduce_work_val_buf =
make_dynamic_buffer<AddressSpaceEnum::Lds>(p_reduce_work_val_buffer, BlockSize);
auto reduce_work_idx_buf =
make_dynamic_buffer<AddressSpaceEnum::Lds>(p_reduce_work_idx_buffer, BlockSize);
StaticBuffer<AddressSpaceEnum::Vgpr, AccDataType, MThreadSliceSize * KThreadSliceSize, true>
in_thread_val_buf;
StaticBuffer<AddressSpaceEnum::Vgpr,
IndexDataType,
MThreadSliceSize * KThreadSliceSize,
true>
in_thread_idx_buf;
StaticBuffer<AddressSpaceEnum::Vgpr, AccDataType, MThreadSliceSize, true> accu_value_buf;
StaticBuffer<AddressSpaceEnum::Vgpr, IndexDataType, MThreadSliceSize, true> accu_index_buf;
const auto toReduceLength = in_grid_desc_m_k.GetLength(Number<1>{});
const index_t thread_local_id = get_thread_local_1d_id();
const index_t block_global_1d_id = get_block_1d_id();
const auto thread_cluster_idx =
thread_cluster_desc.CalculateBottomIndex(make_multi_index(thread_local_id));
const auto thread_m_cluster_id = thread_cluster_idx[I0];
const auto thread_k_cluster_id = thread_cluster_idx[I1];
using ThreadBufferLengths = Sequence<MThreadSliceSize, KThreadSliceSize>;
constexpr auto thread_buffer_desc = make_naive_tensor_descriptor_packed(
make_tuple(Number<MThreadSliceSize>{}, Number<KThreadSliceSize>{}));
auto threadwise_src_val_load =
ThreadwiseTensorSliceTransfer_v2<InDataType,
AccDataType,
InGridDesc_M_K,
decltype(thread_buffer_desc),
ThreadBufferLengths,
ThreadBufferDimAccessOrder,
InSrcVectorDim,
InSrcVectorSize,
1,
false>(
in_grid_desc_m_k,
make_multi_index(block_global_1d_id * M_BlockTileSize +
thread_m_cluster_id * MThreadSliceSize,
thread_k_cluster_id * KThreadSliceSize));
auto threadwise_src_idx_load =
ThreadwiseTensorSliceTransfer_v2<IndexDataType,
IndexDataType,
InGridDesc_M_K,
decltype(thread_buffer_desc),
ThreadBufferLengths,
ThreadBufferDimAccessOrder,
InSrcVectorDim,
InSrcVectorSize,
1,
false>(
in_grid_desc_m_k,
make_multi_index(block_global_1d_id * M_BlockTileSize +
thread_m_cluster_id * MThreadSliceSize,
thread_k_cluster_id * KThreadSliceSize));
static_for<0, MThreadSliceSize, 1>{}([&](auto I) {
accu_value_buf(I) = zeroVal;
accu_index_buf(I) = 0;
});
constexpr auto in_thread_copy_step = make_multi_index(0, K_BlockTileSize);
const index_t toReduceTiles = (toReduceLength + K_BlockTileSize - 1) / K_BlockTileSize;
index_t reducedTiles = 0;
do
{
// load the thread slice
threadwise_src_val_load.Run(in_grid_desc_m_k,
src_global_val_buf,
thread_buffer_desc,
make_tuple(I0, I0),
in_thread_val_buf);
threadwise_src_idx_load.Run(in_grid_desc_m_k,
src_global_idx_buf,
thread_buffer_desc,
make_tuple(I0, I0),
in_thread_idx_buf);
static_for<0, MThreadSliceSize, 1>{}([&](auto iM) {
AccDataType tmpValue = zeroVal;
IndexDataType tmpIndex = 0;
static_for<0, KThreadSliceSize, 1>{}([&](auto iK) {
constexpr auto offset = thread_buffer_desc.CalculateOffset(make_tuple(iM, iK));
AccumulationWithIndex::Calculate(tmpValue,
in_thread_val_buf[Number<offset>{}],
tmpIndex,
in_thread_idx_buf[Number<offset>{}]);
});
BlockwiseReduceWithIndex::Reduce(
reduce_work_val_buf, reduce_work_idx_buf, tmpValue, tmpIndex);
AccumulationWithIndex::Calculate(
accu_value_buf(iM), tmpValue, accu_index_buf(iM), tmpIndex);
});
threadwise_src_val_load.MoveSrcSliceWindow(in_grid_desc_m_k, in_thread_copy_step);
threadwise_src_idx_load.MoveSrcSliceWindow(in_grid_desc_m_k, in_thread_copy_step);
reducedTiles++;
} while(reducedTiles < toReduceTiles);
constexpr auto reduced_data_desc = ThreadReduceDstDesc_M{};
static_for<0, MThreadSliceSize, 1>{}([&](auto I) {
if(thread_k_cluster_id == 0)
{
// for an indexed reduction, acc_elementwise_op should do nothing
acc_elementwise_op(accu_value_buf(I), accu_value_buf(I));
accu_value_buf(I) *= alpha;
}
});
if(thread_k_cluster_id == 0)
{
if constexpr(!BetaIsZero)
{
if(!float_equal_zero{}(beta))
{
StaticBuffer<AddressSpaceEnum::Vgpr, OutDataType, MThreadSliceSize, true>
priorDstValueBuf;
auto threadwise_dst_load =
ThreadwiseTensorSliceTransfer_v2<OutDataType,
OutDataType,
OutGridDesc_M,
decltype(reduced_data_desc),
Sequence<MThreadSliceSize>,
Sequence<0>,
0,
OutDstVectorSize,
1,
true>(
out_grid_desc_m,
make_multi_index(block_global_1d_id * M_BlockTileSize +
thread_m_cluster_id * MThreadSliceSize));
threadwise_dst_load.Run(out_grid_desc_m,
out_global_val_buf,
reduced_data_desc,
make_tuple(I0),
priorDstValueBuf);
static_for<0, MThreadSliceSize, 1>{}([&](auto I) {
accu_value_buf(I) += type_convert<AccDataType>(priorDstValueBuf[I]) * beta;
});
};
};
auto threadwise_dst_val_store =
ThreadwiseTensorSliceTransfer_v1r3<AccDataType,
OutDataType,
decltype(reduced_data_desc),
OutGridDesc_M,
PassThroughOp,
Sequence<MThreadSliceSize>,
Sequence<0>,
0,
OutDstVectorSize,
InMemoryDataOperationEnum::Set,
1,
true>(
out_grid_desc_m,
make_multi_index(block_global_1d_id * M_BlockTileSize +
thread_m_cluster_id * MThreadSliceSize),
PassThroughOp{});
auto threadwise_dst_idx_store =
ThreadwiseTensorSliceTransfer_v1r3<IndexDataType,
IndexDataType,
decltype(reduced_data_desc),
OutGridDesc_M,
PassThroughOp,
Sequence<MThreadSliceSize>,
Sequence<0>,
0,
OutDstVectorSize,
InMemoryDataOperationEnum::Set,
1,
true>(
out_grid_desc_m,
make_multi_index(block_global_1d_id * M_BlockTileSize +
thread_m_cluster_id * MThreadSliceSize),
PassThroughOp{});
threadwise_dst_val_store.Run(reduced_data_desc,
make_tuple(I0),
accu_value_buf,
out_grid_desc_m,
out_global_val_buf);
threadwise_dst_idx_store.Run(reduced_data_desc,
make_tuple(I0),
accu_index_buf,
out_grid_desc_m,
out_global_idx_buf);
}
};
};
} // namespace ck
#endif
...@@ -23,75 +23,86 @@ ...@@ -23,75 +23,86 @@
* SOFTWARE. * SOFTWARE.
* *
*******************************************************************************/ *******************************************************************************/
#ifndef CK_GRIDWISE_2D_REDUCTION_MULTIBLOCK_PARTIAL_REDUCE_HPP #ifndef CK_GRIDWISE_2D_REDUCTION_MULTIBLOCK_HPP
#define CK_GRIDWISE_2D_REDUCTION_MULTIBLOCK_PARTIAL_REDUCE_HPP #define CK_GRIDWISE_2D_REDUCTION_MULTIBLOCK_HPP
#include "reduction_common.hpp" #include "reduction_common.hpp"
#include "reduction_operator.hpp" #include "reduction_operator.hpp"
#include "reduction_functions_accumulate.hpp" #include "reduction_functions_accumulate.hpp"
#include "reduction_functions_blockwise.hpp" #include "reduction_functions_blockwise.hpp"
#include "reduction_functions_threadwise.hpp" #include "reduction_functions_threadwise.hpp"
#include "threadwise_tensor_slice_transfer.hpp" #include "threadwise_tensor_slice_transfer.hpp"
#include "cluster_descriptor.hpp"
#include "element_wise_operation.hpp" #include "element_wise_operation.hpp"
namespace ck { namespace ck {
template <typename GridwiseReduction, template <typename GridwiseReduction,
bool NeedIndices, bool OutputIndex,
bool HaveIndexInput,
typename InDataType, typename InDataType,
typename OutDataType,
typename AccDataType, typename AccDataType,
typename IndexDataType, typename IndexDataType,
typename InGridDesc_M_K, typename InGridDesc_M_K,
typename WorkspaceDesc_M_K, typename OutGridDesc_M,
typename InElementwiseOperation, typename InElementwiseOperation,
typename AccElementwiseOperation> typename AccElementwiseOperation>
__global__ void __global__ void kernel_reduce_multiblock(const InGridDesc_M_K in_grid_desc_m_k,
kernel_partial_reduce_multiblock(const InGridDesc_M_K in_grid_desc_m_k, const OutGridDesc_M out_grid_desc_m,
const WorkspaceDesc_M_K workspace_desc_m_k, const InElementwiseOperation in_elementwise_op,
const InElementwiseOperation in_elementwise_op, const AccElementwiseOperation acc_elementwise_op,
const AccElementwiseOperation acc_elementwise_op, index_t block_group_size,
index_t block_group_size, index_t num_k_block_tile_iteration,
index_t num_k_block_tile_iteration, AccDataType alpha,
const InDataType* const __restrict__ p_src_global, const InDataType* const __restrict__ p_in_value_global,
AccDataType* const __restrict__ p_ws_values_global, const IndexDataType* const __restrict__ p_in_index_global,
IndexDataType* const __restrict__ p_ws_indices_global) AccDataType beta,
OutDataType* const __restrict__ p_out_value_global,
IndexDataType* const __restrict__ p_out_index_global)
{ {
if constexpr(!NeedIndices) if constexpr(!OutputIndex)
{ {
(void)p_in_index_global;
(void)p_out_index_global;
GridwiseReduction::Run(in_grid_desc_m_k, GridwiseReduction::Run(in_grid_desc_m_k,
workspace_desc_m_k, out_grid_desc_m,
in_elementwise_op, in_elementwise_op,
acc_elementwise_op, acc_elementwise_op,
block_group_size, block_group_size,
num_k_block_tile_iteration, num_k_block_tile_iteration,
p_src_global, alpha,
p_ws_values_global, p_in_value_global,
p_ws_indices_global); beta,
p_out_value_global);
} }
else else
{ {
GridwiseReduction::RunWithIndex(in_grid_desc_m_k, GridwiseReduction::template RunWithIndex<HaveIndexInput>(in_grid_desc_m_k,
workspace_desc_m_k, out_grid_desc_m,
in_elementwise_op, in_elementwise_op,
acc_elementwise_op, acc_elementwise_op,
block_group_size, num_k_block_tile_iteration,
num_k_block_tile_iteration, alpha,
p_src_global, p_in_value_global,
p_ws_values_global, p_in_index_global,
p_ws_indices_global); beta,
p_out_value_global,
p_out_index_global);
}; };
}; };
template <typename InDataType, template <typename InDataType,
typename OutDataType,
typename AccDataType, typename AccDataType,
typename IndexDataType, typename IndexDataType,
typename InGridDesc_M_K, typename InGridDesc_M_K,
typename WorkspaceDesc_M_K, typename OutGridDesc_M,
typename ReduceOperation, typename ReduceOperation,
typename InElementwiseOperation, typename InElementwiseOperation,
typename AccElementwiseOperation, typename AccElementwiseOperation,
InMemoryDataOperationEnum OutMemoryDataOperation,
bool PropagateNan, bool PropagateNan,
index_t BlockSize, index_t BlockSize,
index_t MThreadClusterSize, index_t MThreadClusterSize,
...@@ -101,14 +112,13 @@ template <typename InDataType, ...@@ -101,14 +112,13 @@ template <typename InDataType,
index_t InSrcVectorDim, index_t InSrcVectorDim,
index_t InSrcVectorSize, index_t InSrcVectorSize,
index_t OutDstVectorSize> index_t OutDstVectorSize>
struct GridwiseReduction_mk_to_mk_multiblock_partial_reduce struct GridwiseReduction_mk_to_m_multiblock
{ {
static_assert((InSrcVectorDim == 0 && MThreadSliceSize % InSrcVectorSize == 0) || static_assert(((InSrcVectorDim == 0 && MThreadSliceSize % InSrcVectorSize == 0) ||
(InSrcVectorDim == 1 && KThreadSliceSize % InSrcVectorSize == 0), (InSrcVectorDim == 1 && KThreadSliceSize % InSrcVectorSize == 0)) &&
(MThreadSliceSize % OutDstVectorSize == 0),
"Invalid thread slice sizes and/or vector sizes configuration, please check!"); "Invalid thread slice sizes and/or vector sizes configuration, please check!");
static_assert(OutDstVectorSize == 1, "OutDstVectorSize must be 1 for MultiBlockPartialReduce!");
static constexpr bool reorder_thread_cluster = (InSrcVectorDim == 0); static constexpr bool reorder_thread_cluster = (InSrcVectorDim == 0);
using ThreadClusterLengths_M_K = Sequence<MThreadClusterSize, KThreadClusterSize>; using ThreadClusterLengths_M_K = Sequence<MThreadClusterSize, KThreadClusterSize>;
...@@ -127,6 +137,19 @@ struct GridwiseReduction_mk_to_mk_multiblock_partial_reduce ...@@ -127,6 +137,19 @@ struct GridwiseReduction_mk_to_mk_multiblock_partial_reduce
using ThreadReduceDstDesc_M = using ThreadReduceDstDesc_M =
decltype(make_naive_tensor_descriptor_packed(make_tuple(Number<MThreadSliceSize>{}))); decltype(make_naive_tensor_descriptor_packed(make_tuple(Number<MThreadSliceSize>{})));
using BlockwiseReduce = PartitionedBlockwiseReduction<AccDataType,
BlockSize,
ThreadClusterLengths_M_K,
ThreadClusterArrangeOrder,
ReduceOperation,
PropagateNan>;
using ThreadwiseReduce = ThreadwiseReduction<AccDataType,
ThreadReduceSrcDesc_M_K,
ThreadReduceDstDesc_M,
ReduceOperation,
PropagateNan>;
using PassThroughOp = tensor_operation::element_wise::PassThrough; using PassThroughOp = tensor_operation::element_wise::PassThrough;
static constexpr auto I0 = Number<0>{}; static constexpr auto I0 = Number<0>{};
...@@ -135,43 +158,30 @@ struct GridwiseReduction_mk_to_mk_multiblock_partial_reduce ...@@ -135,43 +158,30 @@ struct GridwiseReduction_mk_to_mk_multiblock_partial_reduce
static constexpr index_t M_BlockTileSize = MThreadClusterSize * MThreadSliceSize; static constexpr index_t M_BlockTileSize = MThreadClusterSize * MThreadSliceSize;
static constexpr index_t K_BlockTileSize = KThreadClusterSize * KThreadSliceSize; static constexpr index_t K_BlockTileSize = KThreadClusterSize * KThreadSliceSize;
using Accumulation = detail::AccumulateWithNanCheck<PropagateNan, ReduceOperation, AccDataType>;
__device__ static void Run(const InGridDesc_M_K& in_grid_desc_m_k, __device__ static void Run(const InGridDesc_M_K& in_grid_desc_m_k,
const WorkspaceDesc_M_K& workspace_desc_m_k, const OutGridDesc_M& out_grid_desc_m,
const InElementwiseOperation& in_elementwise_op, const InElementwiseOperation& in_elementwise_op,
const AccElementwiseOperation& acc_elementwise_op, const AccElementwiseOperation& acc_elementwise_op,
index_t block_group_size, index_t block_group_size,
index_t num_k_block_tile_iteration, index_t num_k_block_tile_iteration,
const InDataType* const __restrict__ p_src_global, AccDataType alpha,
AccDataType* const __restrict__ p_ws_values_global, const InDataType* const __restrict__ p_in_value_global,
IndexDataType* const __restrict__ p_ws_indices_global) AccDataType beta,
OutDataType* const __restrict__ p_out_value_global)
{ {
using BlockwiseReduce = PartitionedBlockwiseReduction<AccDataType,
BlockSize,
ThreadClusterLengths_M_K,
ThreadClusterArrangeOrder,
ReduceOperation,
PropagateNan>;
using ThreadwiseReduce = ThreadwiseReduction<AccDataType,
ThreadReduceSrcDesc_M_K,
ThreadReduceDstDesc_M,
ReduceOperation,
PropagateNan>;
(void)p_ws_indices_global;
(void)acc_elementwise_op;
const auto zeroVal = ReduceOperation::GetReductionZeroVal(); const auto zeroVal = ReduceOperation::GetReductionZeroVal();
// LDS // LDS
__shared__ AccDataType p_reduce_work_buffer[BlockSize]; __shared__ AccDataType p_reduce_work_buffer[BlockSize];
const auto in_global_buf = const auto in_global_val_buf =
make_dynamic_buffer<AddressSpaceEnum::Global>(p_src_global, make_dynamic_buffer<AddressSpaceEnum::Global>(p_in_value_global,
in_grid_desc_m_k.GetElementSpaceSize(), in_grid_desc_m_k.GetElementSpaceSize(),
type_convert<InDataType>(zeroVal)); type_convert<InDataType>(zeroVal));
auto workspace_global_buf = make_dynamic_buffer<AddressSpaceEnum::Global>( auto out_global_val_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_ws_values_global, workspace_desc_m_k.GetElementSpaceSize()); p_out_value_global, out_grid_desc_m.GetElementSpaceSize());
auto reduce_work_buf = auto reduce_work_buf =
make_dynamic_buffer<AddressSpaceEnum::Lds>(p_reduce_work_buffer, BlockSize); make_dynamic_buffer<AddressSpaceEnum::Lds>(p_reduce_work_buffer, BlockSize);
...@@ -221,7 +231,7 @@ struct GridwiseReduction_mk_to_mk_multiblock_partial_reduce ...@@ -221,7 +231,7 @@ struct GridwiseReduction_mk_to_mk_multiblock_partial_reduce
do do
{ {
threadwise_src_load.Run(in_grid_desc_m_k, threadwise_src_load.Run(in_grid_desc_m_k,
in_global_buf, in_global_val_buf,
thread_buffer_desc, thread_buffer_desc,
make_tuple(I0, I0), make_tuple(I0, I0),
in_thread_buf); in_thread_buf);
...@@ -242,58 +252,97 @@ struct GridwiseReduction_mk_to_mk_multiblock_partial_reduce ...@@ -242,58 +252,97 @@ struct GridwiseReduction_mk_to_mk_multiblock_partial_reduce
reducedTiles++; reducedTiles++;
} while(reducedTiles < num_k_block_tile_iteration); } while(reducedTiles < num_k_block_tile_iteration);
// Each block executes multiple parallel reductions on the LDS, and due to the use of constexpr auto reduced_data_desc = ThreadReduceDstDesc_M{};
// vector_load, each block/thread is involved in multiple invariant dimensions.
static_for<0, MThreadSliceSize, 1>{}( static_for<0, MThreadSliceSize, 1>{}(
[&](auto I) { BlockwiseReduce::Reduce(reduce_work_buf, accu_value_buf(I)); }); [&](auto I) { BlockwiseReduce::Reduce(reduce_work_buf, accu_value_buf(I)); });
constexpr auto reduced_data_desc = make_naive_tensor_descriptor_packed( static_for<0, MThreadSliceSize, 1>{}([&](auto I) {
make_tuple(Number<MThreadSliceSize>{}, Number<1>{})); if(thread_k_cluster_id == 0)
{
acc_elementwise_op(accu_value_buf(I), accu_value_buf(I));
accu_value_buf(I) *= alpha;
}
});
if(thread_k_cluster_id == 0) if(thread_k_cluster_id == 0)
{ {
auto threadwise_workspace_store = if(block_group_size == 0 && !float_equal_zero{}(beta))
{
StaticBuffer<AddressSpaceEnum::Vgpr, OutDataType, MThreadSliceSize, true>
priorDstValueBuf;
auto threadwise_dst_load =
ThreadwiseTensorSliceTransfer_v2<OutDataType,
OutDataType,
OutGridDesc_M,
decltype(reduced_data_desc),
Sequence<MThreadSliceSize>,
Sequence<0>,
0,
OutDstVectorSize,
1,
false>(
out_grid_desc_m,
make_multi_index(blkgroup_id * M_BlockTileSize +
thread_m_cluster_id * MThreadSliceSize));
threadwise_dst_load.Run(out_grid_desc_m,
out_global_val_buf,
reduced_data_desc,
make_tuple(I0),
priorDstValueBuf);
static_for<0, MThreadSliceSize, 1>{}([&](auto I) {
accu_value_buf(I) += type_convert<AccDataType>(priorDstValueBuf[I]) * beta;
});
};
auto threadwise_dst_store =
ThreadwiseTensorSliceTransfer_v1r3<AccDataType, ThreadwiseTensorSliceTransfer_v1r3<AccDataType,
AccDataType, OutDataType,
decltype(reduced_data_desc), decltype(reduced_data_desc),
WorkspaceDesc_M_K, OutGridDesc_M,
PassThroughOp, PassThroughOp,
Sequence<MThreadSliceSize, 1>, Sequence<MThreadSliceSize>,
Sequence<0, 1>, Sequence<0>,
1, 0,
1, OutDstVectorSize,
InMemoryDataOperationEnum::Set, OutMemoryDataOperation,
1, 1,
true>( true>(
workspace_desc_m_k, out_grid_desc_m,
make_multi_index(blkgroup_id * M_BlockTileSize + make_multi_index(blkgroup_id * M_BlockTileSize +
thread_m_cluster_id * MThreadSliceSize, thread_m_cluster_id * MThreadSliceSize),
block_local_id),
PassThroughOp{}); PassThroughOp{});
threadwise_workspace_store.Run(reduced_data_desc, threadwise_dst_store.Run(reduced_data_desc,
make_tuple(I0, I0), make_tuple(I0),
accu_value_buf, accu_value_buf,
workspace_desc_m_k, out_grid_desc_m,
workspace_global_buf); out_global_val_buf);
} }
}; };
template <bool HaveIndexInput>
__device__ static void RunWithIndex(const InGridDesc_M_K& in_grid_desc_m_k, __device__ static void RunWithIndex(const InGridDesc_M_K& in_grid_desc_m_k,
const WorkspaceDesc_M_K& workspace_desc_m_k, const OutGridDesc_M& out_grid_desc_m,
const InElementwiseOperation& in_elementwise_op, const InElementwiseOperation in_elementwise_op,
const AccElementwiseOperation& acc_elementwise_op, const AccElementwiseOperation acc_elementwise_op,
index_t block_group_size,
index_t num_k_block_tile_iteration, index_t num_k_block_tile_iteration,
const InDataType* const __restrict__ p_src_global, AccDataType alpha,
AccDataType* const __restrict__ p_ws_values_global, const InDataType* const __restrict__ p_in_value_global,
IndexDataType* const __restrict__ p_ws_indices_global) const IndexDataType* const __restrict__ p_in_index_global,
AccDataType beta,
OutDataType* const __restrict__ p_out_value_global,
IndexDataType* const __restrict__ p_out_index_global)
{ {
using BlockwiseReduceWithIndex = using BlockwiseReduceWithIndex =
PartitionedBlockwiseReductionWithIndex<AccDataType, PartitionedBlockwiseReductionWithIndex<AccDataType,
IndexDataType, IndexDataType,
BlockSize, BlockSize,
ThreadClusterLengths_M_K, Sequence<MThreadClusterSize, KThreadClusterSize>,
ThreadClusterArrangeOrder, ThreadClusterArrangeOrder,
ReduceOperation, ReduceOperation,
PropagateNan>; PropagateNan>;
...@@ -303,22 +352,24 @@ struct GridwiseReduction_mk_to_mk_multiblock_partial_reduce ...@@ -303,22 +352,24 @@ struct GridwiseReduction_mk_to_mk_multiblock_partial_reduce
AccDataType, AccDataType,
IndexDataType>; IndexDataType>;
(void)acc_elementwise_op; (void)in_elementwise_op;
const auto zeroVal = ReduceOperation::GetReductionZeroVal();
// LDS // LDS
__shared__ AccDataType p_reduce_work_val_buffer[BlockSize]; __shared__ AccDataType p_reduce_work_val_buffer[BlockSize];
__shared__ index_t p_reduce_work_idx_buffer[BlockSize]; __shared__ IndexDataType p_reduce_work_idx_buffer[BlockSize];
const auto in_global_buf = const auto zeroVal = ReduceOperation::GetReductionZeroVal();
make_dynamic_buffer<AddressSpaceEnum::Global>(p_src_global,
const auto in_global_val_buf =
make_dynamic_buffer<AddressSpaceEnum::Global>(p_in_value_global,
in_grid_desc_m_k.GetElementSpaceSize(), in_grid_desc_m_k.GetElementSpaceSize(),
type_convert<InDataType>(zeroVal)); type_convert<InDataType>(zeroVal));
auto workspace_global_val_buf = make_dynamic_buffer<AddressSpaceEnum::Global>( const auto in_global_idx_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_ws_values_global, workspace_desc_m_k.GetElementSpaceSize()); p_in_index_global, in_grid_desc_m_k.GetElementSpaceSize());
auto workspace_global_idx_buf = make_dynamic_buffer<AddressSpaceEnum::Global>( auto out_global_val_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_ws_indices_global, workspace_desc_m_k.GetElementSpaceSize()); p_out_value_global, out_grid_desc_m.GetElementSpaceSize());
auto out_global_idx_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_out_index_global, out_grid_desc_m.GetElementSpaceSize());
auto reduce_work_val_buf = auto reduce_work_val_buf =
make_dynamic_buffer<AddressSpaceEnum::Lds>(p_reduce_work_val_buffer, BlockSize); make_dynamic_buffer<AddressSpaceEnum::Lds>(p_reduce_work_val_buffer, BlockSize);
...@@ -327,6 +378,7 @@ struct GridwiseReduction_mk_to_mk_multiblock_partial_reduce ...@@ -327,6 +378,7 @@ struct GridwiseReduction_mk_to_mk_multiblock_partial_reduce
StaticBuffer<AddressSpaceEnum::Vgpr, AccDataType, MThreadSliceSize * KThreadSliceSize, true> StaticBuffer<AddressSpaceEnum::Vgpr, AccDataType, MThreadSliceSize * KThreadSliceSize, true>
in_thread_val_buf; in_thread_val_buf;
StaticBuffer<AddressSpaceEnum::Vgpr, StaticBuffer<AddressSpaceEnum::Vgpr,
IndexDataType, IndexDataType,
MThreadSliceSize * KThreadSliceSize, MThreadSliceSize * KThreadSliceSize,
...@@ -336,10 +388,8 @@ struct GridwiseReduction_mk_to_mk_multiblock_partial_reduce ...@@ -336,10 +388,8 @@ struct GridwiseReduction_mk_to_mk_multiblock_partial_reduce
StaticBuffer<AddressSpaceEnum::Vgpr, AccDataType, MThreadSliceSize, true> accu_value_buf; StaticBuffer<AddressSpaceEnum::Vgpr, AccDataType, MThreadSliceSize, true> accu_value_buf;
StaticBuffer<AddressSpaceEnum::Vgpr, IndexDataType, MThreadSliceSize, true> accu_index_buf; StaticBuffer<AddressSpaceEnum::Vgpr, IndexDataType, MThreadSliceSize, true> accu_index_buf;
const index_t thread_local_id = get_thread_local_1d_id(); const index_t thread_local_id = get_thread_local_1d_id();
const index_t block_global_id = get_block_1d_id(); const index_t block_global_1d_id = get_block_1d_id();
const index_t blkgroup_id = block_global_id / block_group_size;
const index_t block_local_id = block_global_id % block_group_size;
const auto thread_cluster_idx = const auto thread_cluster_idx =
thread_cluster_desc.CalculateBottomIndex(make_multi_index(thread_local_id)); thread_cluster_desc.CalculateBottomIndex(make_multi_index(thread_local_id));
...@@ -347,138 +397,239 @@ struct GridwiseReduction_mk_to_mk_multiblock_partial_reduce ...@@ -347,138 +397,239 @@ struct GridwiseReduction_mk_to_mk_multiblock_partial_reduce
const auto thread_m_cluster_id = thread_cluster_idx[I0]; const auto thread_m_cluster_id = thread_cluster_idx[I0];
const auto thread_k_cluster_id = thread_cluster_idx[I1]; const auto thread_k_cluster_id = thread_cluster_idx[I1];
const index_t reduceSizePerBlock = K_BlockTileSize * num_k_block_tile_iteration;
using ThreadBufferLengths = Sequence<MThreadSliceSize, KThreadSliceSize>; using ThreadBufferLengths = Sequence<MThreadSliceSize, KThreadSliceSize>;
constexpr auto thread_buffer_desc = make_naive_tensor_descriptor_packed( constexpr auto thread_buffer_desc = make_naive_tensor_descriptor_packed(
make_tuple(Number<MThreadSliceSize>{}, Number<KThreadSliceSize>{})); make_tuple(Number<MThreadSliceSize>{}, Number<KThreadSliceSize>{}));
auto threadwise_src_load = ThreadwiseTensorSliceTransfer_v2<InDataType, auto threadwise_src_val_load =
AccDataType, ThreadwiseTensorSliceTransfer_v2<InDataType,
InGridDesc_M_K, AccDataType,
decltype(thread_buffer_desc), InGridDesc_M_K,
ThreadBufferLengths, decltype(thread_buffer_desc),
ThreadBufferDimAccessOrder, ThreadBufferLengths,
InSrcVectorDim, ThreadBufferDimAccessOrder,
InSrcVectorSize, InSrcVectorDim,
1, InSrcVectorSize,
false>( 1,
in_grid_desc_m_k, false>(
make_multi_index(blkgroup_id * M_BlockTileSize + thread_m_cluster_id * MThreadSliceSize, in_grid_desc_m_k,
block_local_id * reduceSizePerBlock + make_multi_index(block_global_1d_id * M_BlockTileSize +
thread_m_cluster_id * MThreadSliceSize,
thread_k_cluster_id * KThreadSliceSize)); thread_k_cluster_id * KThreadSliceSize));
constexpr auto in_thread_copy_step = make_multi_index(0, K_BlockTileSize);
index_t indexOffset = block_local_id * reduceSizePerBlock;
static_for<0, MThreadSliceSize, 1>{}([&](auto I) { static_for<0, MThreadSliceSize, 1>{}([&](auto I) {
accu_value_buf(I) = zeroVal; accu_value_buf(I) = zeroVal;
accu_index_buf(I) = 0; accu_index_buf(I) = 0;
}); });
index_t reducedTiles = 0; constexpr auto in_thread_copy_step = make_multi_index(0, K_BlockTileSize);
do
{
// load the thread slice
threadwise_src_load.Run(in_grid_desc_m_k,
in_global_buf,
thread_buffer_desc,
make_tuple(I0, I0),
in_thread_val_buf);
static_for<0, MThreadSliceSize, 1>{}([&](auto iM) {
static_for<0, KThreadSliceSize, 1>{}([&](auto iK) {
constexpr auto offset = thread_buffer_desc.CalculateOffset(make_tuple(iM, iK));
// initialize the indices for the per-thread to-reduce values index_t reducedTiles = 0;
in_thread_idx_buf(Number<offset>{}) =
indexOffset + thread_k_cluster_id * KThreadSliceSize + iK();
// do element-wise pre-reduction operation if constexpr(HaveIndexInput)
in_elementwise_op(in_thread_val_buf(Number<offset>{}), {
in_thread_val_buf(Number<offset>{})); auto threadwise_src_idx_load =
ThreadwiseTensorSliceTransfer_v2<IndexDataType,
IndexDataType,
InGridDesc_M_K,
decltype(thread_buffer_desc),
ThreadBufferLengths,
ThreadBufferDimAccessOrder,
InSrcVectorDim,
InSrcVectorSize,
1,
false>(
in_grid_desc_m_k,
make_multi_index(block_global_1d_id * M_BlockTileSize +
thread_m_cluster_id * MThreadSliceSize,
thread_k_cluster_id * KThreadSliceSize));
do
{
// load the thread slice
threadwise_src_val_load.Run(in_grid_desc_m_k,
in_global_val_buf,
thread_buffer_desc,
make_tuple(I0, I0),
in_thread_val_buf);
threadwise_src_idx_load.Run(in_grid_desc_m_k,
in_global_idx_buf,
thread_buffer_desc,
make_tuple(I0, I0),
in_thread_idx_buf);
static_for<0, MThreadSliceSize, 1>{}([&](auto iM) {
AccDataType tmpValue = zeroVal;
IndexDataType tmpIndex = 0;
static_for<0, KThreadSliceSize, 1>{}([&](auto iK) {
constexpr auto offset =
thread_buffer_desc.CalculateOffset(make_tuple(iM, iK));
AccumulationWithIndex::Calculate(tmpValue,
in_thread_val_buf[Number<offset>{}],
tmpIndex,
in_thread_idx_buf[Number<offset>{}]);
});
BlockwiseReduceWithIndex::Reduce(
reduce_work_val_buf, reduce_work_idx_buf, tmpValue, tmpIndex);
AccumulationWithIndex::Calculate(
accu_value_buf(iM), tmpValue, accu_index_buf(iM), tmpIndex);
}); });
AccDataType tmpValue = zeroVal; threadwise_src_val_load.MoveSrcSliceWindow(in_grid_desc_m_k, in_thread_copy_step);
IndexDataType tmpIndex = 0; threadwise_src_idx_load.MoveSrcSliceWindow(in_grid_desc_m_k, in_thread_copy_step);
static_for<0, KThreadSliceSize, 1>{}([&](auto iK) {
constexpr auto offset = thread_buffer_desc.CalculateOffset(make_tuple(iM, iK));
AccumulationWithIndex::Calculate(tmpValue, reducedTiles++;
in_thread_val_buf[Number<offset>{}], } while(reducedTiles < num_k_block_tile_iteration);
tmpIndex, }
in_thread_idx_buf[Number<offset>{}]); else
{
index_t indexOffset = 0;
do
{
// load the thread slice
threadwise_src_val_load.Run(in_grid_desc_m_k,
in_global_val_buf,
thread_buffer_desc,
make_tuple(I0, I0),
in_thread_val_buf);
static_for<0, MThreadSliceSize, 1>{}([&](auto iM) {
static_for<0, KThreadSliceSize, 1>{}([&](auto iK) {
constexpr auto offset =
thread_buffer_desc.CalculateOffset(make_tuple(iM, iK));
// initialize the indices for the per-thread to-reduce values
in_thread_idx_buf(Number<offset>{}) =
indexOffset + thread_k_cluster_id * KThreadSliceSize + iK();
// do element-wise pre-reduction operation
in_elementwise_op(in_thread_val_buf(Number<offset>{}),
in_thread_val_buf(Number<offset>{}));
});
AccDataType tmpValue = zeroVal;
IndexDataType tmpIndex = 0;
static_for<0, KThreadSliceSize, 1>{}([&](auto iK) {
constexpr auto offset =
thread_buffer_desc.CalculateOffset(make_tuple(iM, iK));
AccumulationWithIndex::Calculate(tmpValue,
in_thread_val_buf[Number<offset>{}],
tmpIndex,
in_thread_idx_buf[Number<offset>{}]);
});
BlockwiseReduceWithIndex::Reduce(
reduce_work_val_buf, reduce_work_idx_buf, tmpValue, tmpIndex);
AccumulationWithIndex::Calculate(
accu_value_buf(iM), tmpValue, accu_index_buf(iM), tmpIndex);
}); });
BlockwiseReduceWithIndex::Reduce( threadwise_src_val_load.MoveSrcSliceWindow(in_grid_desc_m_k, in_thread_copy_step);
reduce_work_val_buf, reduce_work_idx_buf, tmpValue, tmpIndex);
AccumulationWithIndex::Calculate(
accu_value_buf(iM), tmpValue, accu_index_buf(iM), tmpIndex);
});
threadwise_src_load.MoveSrcSliceWindow(in_grid_desc_m_k, in_thread_copy_step); indexOffset += K_BlockTileSize;
reducedTiles++;
} while(reducedTiles < num_k_block_tile_iteration);
};
indexOffset += K_BlockTileSize; constexpr auto reduced_data_desc = ThreadReduceDstDesc_M{};
reducedTiles++; static_for<0, MThreadSliceSize, 1>{}([&](auto I) {
} while(reducedTiles < num_k_block_tile_iteration); if(thread_k_cluster_id == 0)
{
// for an indexed reduction, acc_elementwise_op should do nothing
acc_elementwise_op(accu_value_buf(I), accu_value_buf(I));
constexpr auto reduced_data_desc = make_naive_tensor_descriptor_packed( accu_value_buf(I) *= alpha;
make_tuple(Number<MThreadSliceSize>{}, Number<1>{})); }
});
if(thread_k_cluster_id == 0) if(thread_k_cluster_id == 0)
{ {
auto threadwise_workspace_val_store = if(!float_equal_zero{}(beta))
{
StaticBuffer<AddressSpaceEnum::Vgpr, OutDataType, MThreadSliceSize, true>
priorDstValueBuf;
auto threadwise_dst_load =
ThreadwiseTensorSliceTransfer_v2<OutDataType,
OutDataType,
OutGridDesc_M,
decltype(reduced_data_desc),
Sequence<MThreadSliceSize>,
Sequence<0>,
0,
OutDstVectorSize,
1,
true>(
out_grid_desc_m,
make_multi_index(block_global_1d_id * M_BlockTileSize +
thread_m_cluster_id * MThreadSliceSize));
threadwise_dst_load.Run(out_grid_desc_m,
out_global_val_buf,
reduced_data_desc,
make_tuple(I0),
priorDstValueBuf);
static_for<0, MThreadSliceSize, 1>{}([&](auto I) {
accu_value_buf(I) += type_convert<AccDataType>(priorDstValueBuf[I]) * beta;
});
};
auto threadwise_dst_val_store =
ThreadwiseTensorSliceTransfer_v1r3<AccDataType, ThreadwiseTensorSliceTransfer_v1r3<AccDataType,
AccDataType, OutDataType,
decltype(reduced_data_desc), decltype(reduced_data_desc),
WorkspaceDesc_M_K, OutGridDesc_M,
PassThroughOp, PassThroughOp,
Sequence<MThreadSliceSize, 1>, Sequence<MThreadSliceSize>,
Sequence<0, 1>, Sequence<0>,
1, 0,
1, OutDstVectorSize,
InMemoryDataOperationEnum::Set, InMemoryDataOperationEnum::Set,
1, 1,
true>( true>(
workspace_desc_m_k, out_grid_desc_m,
make_multi_index(blkgroup_id * M_BlockTileSize + make_multi_index(block_global_1d_id * M_BlockTileSize +
thread_m_cluster_id * MThreadSliceSize, thread_m_cluster_id * MThreadSliceSize),
block_local_id),
PassThroughOp{}); PassThroughOp{});
auto threadwise_workspace_idx_store = auto threadwise_dst_idx_store =
ThreadwiseTensorSliceTransfer_v1r3<IndexDataType, ThreadwiseTensorSliceTransfer_v1r3<IndexDataType,
IndexDataType, IndexDataType,
decltype(reduced_data_desc), decltype(reduced_data_desc),
WorkspaceDesc_M_K, OutGridDesc_M,
PassThroughOp, PassThroughOp,
Sequence<MThreadSliceSize, 1>, Sequence<MThreadSliceSize>,
Sequence<0, 1>, Sequence<0>,
1, 0,
1, OutDstVectorSize,
InMemoryDataOperationEnum::Set, InMemoryDataOperationEnum::Set,
1, 1,
true>( true>(
workspace_desc_m_k, out_grid_desc_m,
make_multi_index(blkgroup_id * M_BlockTileSize + make_multi_index(block_global_1d_id * M_BlockTileSize +
thread_m_cluster_id * MThreadSliceSize, thread_m_cluster_id * MThreadSliceSize),
block_local_id),
PassThroughOp{}); PassThroughOp{});
threadwise_workspace_val_store.Run(reduced_data_desc, threadwise_dst_val_store.Run(reduced_data_desc,
make_tuple(I0, I0), make_tuple(I0),
accu_value_buf, accu_value_buf,
workspace_desc_m_k, out_grid_desc_m,
workspace_global_val_buf); out_global_val_buf);
threadwise_workspace_idx_store.Run(reduced_data_desc, threadwise_dst_idx_store.Run(reduced_data_desc,
make_tuple(I0, I0), make_tuple(I0),
accu_index_buf, accu_index_buf,
workspace_desc_m_k, out_grid_desc_m,
workspace_global_idx_buf); out_global_idx_buf);
} }
}; };
}; };
......
/*******************************************************************************
*
* MIT License
*
* Copyright (c) 2020 Advanced Micro Devices, Inc.
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
* in the Software without restriction, including without limitation the rights
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
* copies of the Software, and to permit persons to whom the Software is
* furnished to do so, subject to the following conditions:
*
* The above copyright notice and this permission notice shall be included in all
* copies or substantial portions of the Software.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
* SOFTWARE.
*
*******************************************************************************/
#ifndef CK_GRIDWISE_2D_REDUCTION_MULTIBLOCK_ATOMIC_ADD_HPP
#define CK_GRIDWISE_2D_REDUCTION_MULTIBLOCK_ATOMIC_ADD_HPP
#include "reduction_common.hpp"
#include "reduction_operator.hpp"
#include "reduction_functions_accumulate.hpp"
#include "reduction_functions_blockwise.hpp"
#include "reduction_functions_threadwise.hpp"
#include "threadwise_tensor_slice_transfer.hpp"
#include "element_wise_operation.hpp"
namespace ck {
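// Grid-level entry point for the atomic-add flavour of the multiblock reduction: every block in
// a block group reduces its share of the K dimension and accumulates the partial result into the
// output via InMemoryDataOperationEnum::AtomicAdd, so no separate workspace or second pass is
// required.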
template <typename GridwiseReduction,
typename InDataType,
typename OutDataType,
typename AccDataType,
typename InGridDesc_M_K,
typename OutGridDesc_M,
typename InElementwiseOperation,
typename AccElementwiseOperation>
__global__ void
kernel_reduce_multiblock_atocmi_add(const InGridDesc_M_K in_grid_desc_m_k,
const OutGridDesc_M out_grid_desc_m,
const InElementwiseOperation in_elementwise_op,
const AccElementwiseOperation acc_elementwise_op,
index_t block_group_size,
index_t num_k_block_tile_iteration,
AccDataType alpha,
const InDataType* const __restrict__ p_in_global,
OutDataType* const __restrict__ p_out_global)
{
GridwiseReduction::Run(in_grid_desc_m_k,
out_grid_desc_m,
in_elementwise_op,
acc_elementwise_op,
block_group_size,
num_k_block_tile_iteration,
alpha,
p_in_global,
p_out_global);
};
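// Multiblock gridwise reduction that relies on atomic adds for the final accumulation: each
// block computes a partial reduction over num_k_block_tile_iteration K tiles and atomically adds
// the alpha-scaled result to the output. The output buffer is assumed to be pre-initialized by
// the caller, and the reduction's accumulation must be compatible with addition.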
template <typename InDataType,
typename OutDataType,
typename AccDataType,
typename InGridDesc_M_K,
typename OutGridDesc_M,
typename ReduceOperation,
typename InElementwiseOperation,
typename AccElementwiseOperation,
bool PropagateNan,
index_t BlockSize,
index_t MThreadClusterSize,
index_t KThreadClusterSize,
index_t MThreadSliceSize,
index_t KThreadSliceSize,
index_t InSrcVectorDim,
index_t InSrcVectorSize,
index_t OutDstVectorSize>
struct GridwiseReduction_mk_to_m_multiblock_atomic_add
{
static_assert(((InSrcVectorDim == 0 && MThreadSliceSize % InSrcVectorSize == 0) ||
(InSrcVectorDim == 1 && KThreadSliceSize % InSrcVectorSize == 0)) &&
(MThreadSliceSize % OutDstVectorSize == 0),
"Invalid thread slice sizes and/or vector sizes configuration, please check!");
static constexpr bool reorder_thread_cluster = (InSrcVectorDim == 0);
using ThreadClusterLengths_M_K = Sequence<MThreadClusterSize, KThreadClusterSize>;
using ThreadBufferDimAccessOrder =
typename conditional<reorder_thread_cluster, Sequence<1, 0>, Sequence<0, 1>>::type;
using ThreadClusterArrangeOrder =
typename conditional<reorder_thread_cluster, Sequence<1, 0>, Sequence<0, 1>>::type;
static constexpr auto thread_cluster_desc =
make_cluster_descriptor(ThreadClusterLengths_M_K{}, ThreadClusterArrangeOrder{});
using ThreadReduceSrcDesc_M_K = decltype(make_naive_tensor_descriptor_packed(
make_tuple(Number<MThreadSliceSize>{}, Number<KThreadSliceSize>{})));
using ThreadReduceDstDesc_M =
decltype(make_naive_tensor_descriptor_packed(make_tuple(Number<MThreadSliceSize>{})));
using BlockwiseReduce = PartitionedBlockwiseReduction<AccDataType,
BlockSize,
ThreadClusterLengths_M_K,
ThreadClusterArrangeOrder,
ReduceOperation,
PropagateNan>;
using ThreadwiseReduce = ThreadwiseReduction<AccDataType,
ThreadReduceSrcDesc_M_K,
ThreadReduceDstDesc_M,
ReduceOperation,
PropagateNan>;
using PassThroughOp = tensor_operation::element_wise::PassThrough;
static constexpr auto I0 = Number<0>{};
static constexpr auto I1 = Number<1>{};
static constexpr index_t M_BlockTileSize = MThreadClusterSize * MThreadSliceSize;
static constexpr index_t K_BlockTileSize = KThreadClusterSize * KThreadSliceSize;
using Accumulation = detail::AccumulateWithNanCheck<PropagateNan, ReduceOperation, AccDataType>;
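// Per-block work: accumulate this block's K tiles threadwise, reduce across the block in LDS,
// apply acc_elementwise_op and alpha, then atomic-add the partial result to the global output.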
__device__ static void Run(const InGridDesc_M_K& in_grid_desc_m_k,
const OutGridDesc_M& out_grid_desc_m,
const InElementwiseOperation& in_elementwise_op,
const AccElementwiseOperation& acc_elementwise_op,
index_t block_group_size,
index_t num_k_block_tile_iteration,
AccDataType alpha,
const InDataType* const __restrict__ p_in_global,
OutDataType* const __restrict__ p_out_global)
{
const auto zeroVal = ReduceOperation::GetReductionZeroVal();
// LDS
__shared__ AccDataType p_reduce_work_buffer[BlockSize];
const auto in_global_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_in_global, in_grid_desc_m_k.GetElementSpaceSize(), type_convert<InDataType>(zeroVal));
auto out_global_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_out_global, out_grid_desc_m.GetElementSpaceSize());
auto reduce_work_buf =
make_dynamic_buffer<AddressSpaceEnum::Lds>(p_reduce_work_buffer, BlockSize);
StaticBuffer<AddressSpaceEnum::Vgpr, AccDataType, MThreadSliceSize * KThreadSliceSize, true>
in_thread_buf;
StaticBuffer<AddressSpaceEnum::Vgpr, AccDataType, MThreadSliceSize, true> accu_value_buf;
static_for<0, MThreadSliceSize, 1>{}([&](auto I) { accu_value_buf(I) = zeroVal; });
const index_t thread_local_id = get_thread_local_1d_id();
const index_t block_global_id = get_block_1d_id();
const index_t blkgroup_id = block_global_id / block_group_size;
const index_t block_local_id = block_global_id % block_group_size;
const auto thread_cluster_idx =
thread_cluster_desc.CalculateBottomIndex(make_multi_index(thread_local_id));
const auto thread_m_cluster_id = thread_cluster_idx[I0];
const auto thread_k_cluster_id = thread_cluster_idx[I1];
const index_t reduceSizePerBlock = K_BlockTileSize * num_k_block_tile_iteration;
using ThreadBufferLengths = Sequence<MThreadSliceSize, KThreadSliceSize>;
constexpr auto thread_buffer_desc = make_naive_tensor_descriptor_packed(
make_tuple(Number<MThreadSliceSize>{}, Number<KThreadSliceSize>{}));
auto threadwise_src_load = ThreadwiseTensorSliceTransfer_v2<InDataType,
AccDataType,
InGridDesc_M_K,
decltype(thread_buffer_desc),
ThreadBufferLengths,
ThreadBufferDimAccessOrder,
InSrcVectorDim,
InSrcVectorSize,
1,
false>(
in_grid_desc_m_k,
make_multi_index(blkgroup_id * M_BlockTileSize + thread_m_cluster_id * MThreadSliceSize,
block_local_id * reduceSizePerBlock +
thread_k_cluster_id * KThreadSliceSize));
constexpr auto in_thread_copy_step = make_multi_index(0, K_BlockTileSize);
index_t reducedTiles = 0;
do
{
threadwise_src_load.Run(in_grid_desc_m_k,
in_global_buf,
thread_buffer_desc,
make_tuple(I0, I0),
in_thread_buf);
static_for<0, MThreadSliceSize, 1>{}([&](auto iM) {
// do element-wise pre-reduction operation
static_for<0, KThreadSliceSize, 1>{}([&](auto iK) {
constexpr auto offset = thread_buffer_desc.CalculateOffset(make_tuple(iM, iK));
in_elementwise_op(in_thread_buf(Number<offset>{}),
in_thread_buf(Number<offset>{}));
});
});
ThreadwiseReduce::Reduce(in_thread_buf, accu_value_buf);
threadwise_src_load.MoveSrcSliceWindow(in_grid_desc_m_k, in_thread_copy_step);
reducedTiles++;
} while(reducedTiles < num_k_block_tile_iteration);
constexpr auto reduced_data_desc = ThreadReduceDstDesc_M{};
// Each block executes multiple parallel reductions on the LDS and atomically adds its
// reduced output to the global location of each invariant dimension, so that a consistent
// reduced result is obtained for that invariant dimension. Because vector_load is used,
// each block/thread covers multiple invariant dimensions.
static_for<0, MThreadSliceSize, 1>{}(
[&](auto I) { BlockwiseReduce::Reduce(reduce_work_buf, accu_value_buf(I)); });
static_for<0, MThreadSliceSize, 1>{}([&](auto I) {
if(thread_k_cluster_id == 0)
{
acc_elementwise_op(accu_value_buf(I), accu_value_buf(I));
accu_value_buf(I) *= alpha;
}
});
if(thread_k_cluster_id == 0)
{
auto threadwise_dst_store =
ThreadwiseTensorSliceTransfer_v1r3<AccDataType,
OutDataType,
decltype(reduced_data_desc),
OutGridDesc_M,
PassThroughOp,
Sequence<MThreadSliceSize>,
Sequence<0>,
0,
OutDstVectorSize,
InMemoryDataOperationEnum::AtomicAdd,
1,
true>(
out_grid_desc_m,
make_multi_index(blkgroup_id * M_BlockTileSize +
thread_m_cluster_id * MThreadSliceSize),
PassThroughOp{});
threadwise_dst_store.Run(
reduced_data_desc, make_tuple(I0), accu_value_buf, out_grid_desc_m, out_global_buf);
}
};
};
} // namespace ck
#endif
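The comment near the end of the kernel above summarizes the multiblock strategy: every block in a block group reduces its own slice of the K axis into a per-thread accumulator, and the partial results are combined by atomically adding them into the same output element, so no second combining kernel is needed. Below is a minimal host-side C++ model of that accumulation pattern; it uses `std::atomic` and illustrative names (`block_group_size`, `partial`), not anything from composable_kernel, and it assumes the output buffer starts at the reduction identity, as any atomic-add based reduction requires.

```cpp
// Simplified host-side model (not library code) of the multiblock AtomicAdd
// reduction: several "blocks" reduce disjoint slices of the K axis and merge
// their partial sums into one output slot with an atomic add.
#include <atomic>
#include <cstdio>
#include <thread>
#include <vector>

int main()
{
    constexpr int M                = 4;    // invariant dimension
    constexpr int K                = 1024; // dimension being reduced
    constexpr int block_group_size = 4;    // blocks cooperating on one output element

    std::vector<float> in(static_cast<size_t>(M) * K, 1.0f);
    std::vector<std::atomic<float>> out(M);
    for(auto& v : out)
        v.store(0.0f); // output must start at the reduction identity for atomic add

    std::vector<std::thread> blocks;
    for(int m = 0; m < M; ++m)
        for(int g = 0; g < block_group_size; ++g)
            blocks.emplace_back([&, m, g] {
                const int k_per_block = K / block_group_size;
                float partial         = 0.0f; // per-block partial result (like accu_value_buf)
                for(int k = g * k_per_block; k < (g + 1) * k_per_block; ++k)
                    partial += in[static_cast<size_t>(m) * K + k];

                // counterpart of the AtomicAdd destination store in the kernel above
                float expected = out[m].load();
                while(!out[m].compare_exchange_weak(expected, expected + partial))
                {
                }
            });

    for(auto& t : blocks)
        t.join();

    std::printf("out[0] = %f (expected %d)\n", out[0].load(), K);
}
```

The real kernel performs the same combine through `InMemoryDataOperationEnum::AtomicAdd` in the destination transfer, which is why an atomic-add based reduction needs its output (or workspace) cleared before the kernel runs.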
...@@ -37,7 +37,8 @@ ...@@ -37,7 +37,8 @@
namespace ck { namespace ck {
template <typename GridwiseReduction, template <typename GridwiseReduction,
bool NeedIndices, bool OutputIndex,
bool HaveIndexInput,
typename InDataType, typename InDataType,
typename OutDataType, typename OutDataType,
typename AccDataType, typename AccDataType,
...@@ -51,34 +52,35 @@ __global__ void kernel_reduce_threadwise(const InGridDesc_M_K in_grid_desc_m_k, ...@@ -51,34 +52,35 @@ __global__ void kernel_reduce_threadwise(const InGridDesc_M_K in_grid_desc_m_k,
const InElementwiseOperation in_elementwise_op, const InElementwiseOperation in_elementwise_op,
const AccElementwiseOperation acc_elementwise_op, const AccElementwiseOperation acc_elementwise_op,
AccDataType alpha, AccDataType alpha,
const InDataType* const __restrict__ p_in_global, const InDataType* const __restrict__ p_in_value_global,
const IndexDataType* const __restrict__ p_in_index_global,
AccDataType beta, AccDataType beta,
OutDataType* const __restrict__ p_out_global, OutDataType* const __restrict__ p_out_value_global,
IndexDataType* const __restrict__ p_indices_global) IndexDataType* const __restrict__ p_out_index_global)
{ {
if constexpr(!NeedIndices) if constexpr(!OutputIndex)
{ {
GridwiseReduction::Run(in_grid_desc_m_k, GridwiseReduction::Run(in_grid_desc_m_k,
out_grid_desc_m, out_grid_desc_m,
in_elementwise_op, in_elementwise_op,
acc_elementwise_op, acc_elementwise_op,
alpha, alpha,
p_in_global, p_in_value_global,
beta, beta,
p_out_global, p_out_value_global);
p_indices_global);
} }
else else
{ {
GridwiseReduction::RunWithIndices(in_grid_desc_m_k, GridwiseReduction::template RunWithIndex<HaveIndexInput>(in_grid_desc_m_k,
out_grid_desc_m, out_grid_desc_m,
in_elementwise_op, in_elementwise_op,
acc_elementwise_op, acc_elementwise_op,
alpha, alpha,
p_in_global, p_in_value_global,
beta, p_in_index_global,
p_out_global, beta,
p_indices_global); p_out_value_global,
p_out_index_global);
}; };
}; };
...@@ -91,11 +93,9 @@ template <typename InDataType, ...@@ -91,11 +93,9 @@ template <typename InDataType,
typename ReduceOperation, typename ReduceOperation,
typename InElementwiseOperation, typename InElementwiseOperation,
typename AccElementwiseOperation, typename AccElementwiseOperation,
InMemoryDataOperationEnum OutMemoryDataOperation,
bool PropagateNan, bool PropagateNan,
bool BetaIsZero,
index_t BlockSize, index_t BlockSize,
index_t MThreadClusterSize,
index_t KThreadClusterSize,
index_t MThreadSliceSize, index_t MThreadSliceSize,
index_t KThreadSliceSize, index_t KThreadSliceSize,
index_t InSrcVectorDim, index_t InSrcVectorDim,
...@@ -125,10 +125,9 @@ struct GridwiseReduction_mk_to_m_threadwise ...@@ -125,10 +125,9 @@ struct GridwiseReduction_mk_to_m_threadwise
const InElementwiseOperation& in_elementwise_op, const InElementwiseOperation& in_elementwise_op,
const AccElementwiseOperation& acc_elementwise_op, const AccElementwiseOperation& acc_elementwise_op,
AccDataType alpha, AccDataType alpha,
const InDataType* const __restrict__ p_in_global, const InDataType* const __restrict__ p_in_value_global,
AccDataType beta, AccDataType beta,
OutDataType* const __restrict__ p_out_global, OutDataType* const __restrict__ p_out_value_global)
IndexDataType* const __restrict__ p_indices_global)
{ {
using ThreadwiseReduce = ThreadwiseReduction<AccDataType, using ThreadwiseReduce = ThreadwiseReduction<AccDataType,
ThreadReduceSrcDesc_M_K, ThreadReduceSrcDesc_M_K,
...@@ -136,14 +135,14 @@ struct GridwiseReduction_mk_to_m_threadwise ...@@ -136,14 +135,14 @@ struct GridwiseReduction_mk_to_m_threadwise
ReduceOperation, ReduceOperation,
PropagateNan>; PropagateNan>;
(void)p_indices_global;
const auto zeroVal = ReduceOperation::GetReductionZeroVal(); const auto zeroVal = ReduceOperation::GetReductionZeroVal();
const auto in_global_buf = make_dynamic_buffer<AddressSpaceEnum::Global>( const auto in_global_val_buf =
p_in_global, in_grid_desc_m_k.GetElementSpaceSize(), type_convert<InDataType>(zeroVal)); make_dynamic_buffer<AddressSpaceEnum::Global>(p_in_value_global,
in_grid_desc_m_k.GetElementSpaceSize(),
type_convert<InDataType>(zeroVal));
auto dst_global_buf = make_dynamic_buffer<AddressSpaceEnum::Global>( auto dst_global_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_out_global, out_grid_desc_m.GetElementSpaceSize()); p_out_value_global, out_grid_desc_m.GetElementSpaceSize());
StaticBuffer<AddressSpaceEnum::Vgpr, AccDataType, MThreadSliceSize * KThreadSliceSize, true> StaticBuffer<AddressSpaceEnum::Vgpr, AccDataType, MThreadSliceSize * KThreadSliceSize, true>
in_thread_buf; in_thread_buf;
...@@ -160,28 +159,29 @@ struct GridwiseReduction_mk_to_m_threadwise ...@@ -160,28 +159,29 @@ struct GridwiseReduction_mk_to_m_threadwise
index_t thread_global_1d_id = get_block_1d_id() * BlockSize + get_thread_local_1d_id(); index_t thread_global_1d_id = get_block_1d_id() * BlockSize + get_thread_local_1d_id();
auto threadwise_src_load = ThreadwiseTensorSliceTransfer_v2<InDataType, auto threadwise_src_val_load =
AccDataType, ThreadwiseTensorSliceTransfer_v2<InDataType,
InGridDesc_M_K, AccDataType,
decltype(thread_buffer_desc), InGridDesc_M_K,
ThreadBufferLengths, decltype(thread_buffer_desc),
ThreadBufferDimAccessOrder, ThreadBufferLengths,
InSrcVectorDim, ThreadBufferDimAccessOrder,
InSrcVectorSize, InSrcVectorDim,
1, InSrcVectorSize,
false>( 1,
in_grid_desc_m_k, make_multi_index(thread_global_1d_id * MThreadSliceSize, 0)); false>(
in_grid_desc_m_k, make_multi_index(thread_global_1d_id * MThreadSliceSize, 0));
constexpr auto in_thread_copy_step = make_multi_index(0, KThreadSliceSize); constexpr auto in_thread_copy_step = make_multi_index(0, KThreadSliceSize);
index_t reducedLength = 0; index_t reducedLength = 0;
do do
{ {
threadwise_src_load.Run(in_grid_desc_m_k, threadwise_src_val_load.Run(in_grid_desc_m_k,
in_global_buf, in_global_val_buf,
thread_buffer_desc, thread_buffer_desc,
make_tuple(I0, I0), make_tuple(I0, I0),
in_thread_buf); in_thread_buf);
static_for<0, MThreadSliceSize, 1>{}([&](auto iM) { static_for<0, MThreadSliceSize, 1>{}([&](auto iM) {
// do element-wise pre-reduction operation // do element-wise pre-reduction operation
...@@ -194,7 +194,7 @@ struct GridwiseReduction_mk_to_m_threadwise ...@@ -194,7 +194,7 @@ struct GridwiseReduction_mk_to_m_threadwise
ThreadwiseReduce::Reduce(in_thread_buf, accu_value_buf); ThreadwiseReduce::Reduce(in_thread_buf, accu_value_buf);
threadwise_src_load.MoveSrcSliceWindow(in_grid_desc_m_k, in_thread_copy_step); threadwise_src_val_load.MoveSrcSliceWindow(in_grid_desc_m_k, in_thread_copy_step);
reducedLength += KThreadSliceSize; reducedLength += KThreadSliceSize;
} while(reducedLength < toReduceLength); } while(reducedLength < toReduceLength);
...@@ -207,68 +207,65 @@ struct GridwiseReduction_mk_to_m_threadwise ...@@ -207,68 +207,65 @@ struct GridwiseReduction_mk_to_m_threadwise
constexpr auto reduced_data_desc = ThreadReduceDstDesc_M{}; constexpr auto reduced_data_desc = ThreadReduceDstDesc_M{};
if constexpr(!BetaIsZero) if(!float_equal_zero{}(beta))
{ {
if(!float_equal_zero{}(beta)) auto threadwise_dst_load = ThreadwiseTensorSliceTransfer_v2<OutDataType,
{ OutDataType,
auto threadwise_dst_load = OutGridDesc_M,
ThreadwiseTensorSliceTransfer_v2<OutDataType, decltype(reduced_data_desc),
OutDataType, Sequence<MThreadSliceSize>,
OutGridDesc_M, Sequence<0>,
decltype(reduced_data_desc), 0,
Sequence<MThreadSliceSize>, 1,
Sequence<0>, 1,
0, true>(
1, out_grid_desc_m, make_multi_index(thread_global_1d_id * MThreadSliceSize));
1,
true>( StaticBuffer<AddressSpaceEnum::Vgpr, OutDataType, MThreadSliceSize, true>
out_grid_desc_m, make_multi_index(thread_global_1d_id * MThreadSliceSize)); priorDstValue_buf;
StaticBuffer<AddressSpaceEnum::Vgpr, OutDataType, MThreadSliceSize, true> threadwise_dst_load.Run(out_grid_desc_m,
priorDstValue_buf; dst_global_buf,
reduced_data_desc,
threadwise_dst_load.Run(out_grid_desc_m, make_tuple(I0),
dst_global_buf, priorDstValue_buf);
reduced_data_desc,
make_tuple(I0), static_for<0, MThreadSliceSize, 1>{}([&](auto I) {
priorDstValue_buf); accu_value_buf(I) += type_convert<AccDataType>(priorDstValue_buf[I]) * beta;
});
static_for<0, MThreadSliceSize, 1>{}([&](auto I) {
accu_value_buf(I) += type_convert<AccDataType>(priorDstValue_buf[I]) * beta;
});
};
}; };
auto threadwise_dst_store = auto threadwise_dst_store = ThreadwiseTensorSliceTransfer_v1r3<AccDataType,
ThreadwiseTensorSliceTransfer_v1r3<AccDataType, OutDataType,
OutDataType, decltype(reduced_data_desc),
decltype(reduced_data_desc), OutGridDesc_M,
OutGridDesc_M, PassThroughOp,
PassThroughOp, Sequence<MThreadSliceSize>,
Sequence<MThreadSliceSize>, Sequence<0>,
Sequence<0>, 0,
0, OutDstVectorSize,
OutDstVectorSize, OutMemoryDataOperation,
InMemoryDataOperationEnum::Set, 1,
1, false>(
false>( out_grid_desc_m,
out_grid_desc_m, make_multi_index(thread_global_1d_id * MThreadSliceSize),
make_multi_index(thread_global_1d_id * MThreadSliceSize), PassThroughOp{});
PassThroughOp{});
threadwise_dst_store.Run( threadwise_dst_store.Run(
reduced_data_desc, make_tuple(I0), accu_value_buf, out_grid_desc_m, dst_global_buf); reduced_data_desc, make_tuple(I0), accu_value_buf, out_grid_desc_m, dst_global_buf);
}; };
__device__ static void RunWithIndices(const InGridDesc_M_K& in_grid_desc_m_k, template <bool HaveIndexInput>
const OutGridDesc_M& out_grid_desc_m, __device__ static void RunWithIndex(const InGridDesc_M_K& in_grid_desc_m_k,
const InElementwiseOperation& in_elementwise_op, const OutGridDesc_M& out_grid_desc_m,
const AccElementwiseOperation& acc_elementwise_op, const InElementwiseOperation& in_elementwise_op,
AccDataType alpha, const AccElementwiseOperation& acc_elementwise_op,
const InDataType* const __restrict__ p_in_global, AccDataType alpha,
AccDataType beta, const InDataType* const __restrict__ p_in_value_global,
OutDataType* const __restrict__ p_out_global, const IndexDataType* const __restrict__ p_in_index_global,
IndexDataType* const __restrict__ p_indices_global) AccDataType beta,
OutDataType* const __restrict__ p_out_value_global,
IndexDataType* const __restrict__ p_out_index_global)
{ {
using ThreadwiseReduceWithIndex = ThreadwiseReductionWithIndex<AccDataType, using ThreadwiseReduceWithIndex = ThreadwiseReductionWithIndex<AccDataType,
IndexDataType, IndexDataType,
...@@ -281,12 +278,17 @@ struct GridwiseReduction_mk_to_m_threadwise ...@@ -281,12 +278,17 @@ struct GridwiseReduction_mk_to_m_threadwise
const auto zeroVal = ReduceOperation::GetReductionZeroVal(); const auto zeroVal = ReduceOperation::GetReductionZeroVal();
const auto in_global_buf = make_dynamic_buffer<AddressSpaceEnum::Global>( const auto in_global_val_buf =
p_in_global, in_grid_desc_m_k.GetElementSpaceSize(), type_convert<InDataType>(zeroVal)); make_dynamic_buffer<AddressSpaceEnum::Global>(p_in_value_global,
in_grid_desc_m_k.GetElementSpaceSize(),
type_convert<InDataType>(zeroVal));
const auto in_global_idx_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_in_index_global, in_grid_desc_m_k.GetElementSpaceSize());
auto out_global_val_buf = make_dynamic_buffer<AddressSpaceEnum::Global>( auto out_global_val_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_out_global, out_grid_desc_m.GetElementSpaceSize()); p_out_value_global, out_grid_desc_m.GetElementSpaceSize());
auto out_global_idx_buf = make_dynamic_buffer<AddressSpaceEnum::Global>( auto out_global_idx_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_indices_global, out_grid_desc_m.GetElementSpaceSize()); p_out_index_global, out_grid_desc_m.GetElementSpaceSize());
StaticBuffer<AddressSpaceEnum::Vgpr, AccDataType, MThreadSliceSize * KThreadSliceSize, true> StaticBuffer<AddressSpaceEnum::Vgpr, AccDataType, MThreadSliceSize * KThreadSliceSize, true>
in_thread_val_buf; in_thread_val_buf;
...@@ -313,50 +315,105 @@ struct GridwiseReduction_mk_to_m_threadwise ...@@ -313,50 +315,105 @@ struct GridwiseReduction_mk_to_m_threadwise
index_t thread_global_1d_id = get_block_1d_id() * BlockSize + get_thread_local_1d_id(); index_t thread_global_1d_id = get_block_1d_id() * BlockSize + get_thread_local_1d_id();
auto threadwise_src_load = ThreadwiseTensorSliceTransfer_v2<InDataType, auto threadwise_src_val_load =
AccDataType, ThreadwiseTensorSliceTransfer_v2<InDataType,
InGridDesc_M_K, AccDataType,
decltype(thread_buffer_desc), InGridDesc_M_K,
ThreadBufferLengths, decltype(thread_buffer_desc),
ThreadBufferDimAccessOrder, ThreadBufferLengths,
InSrcVectorDim, ThreadBufferDimAccessOrder,
InSrcVectorSize, InSrcVectorDim,
1, InSrcVectorSize,
false>( 1,
in_grid_desc_m_k, make_multi_index(thread_global_1d_id * MThreadSliceSize, 0)); false>(
in_grid_desc_m_k, make_multi_index(thread_global_1d_id * MThreadSliceSize, 0));
constexpr auto in_thread_copy_step = make_multi_index(0, KThreadSliceSize); constexpr auto in_thread_copy_step = make_multi_index(0, KThreadSliceSize);
index_t indexStart = 0; index_t indexStart = 0;
index_t reducedLength = 0; index_t reducedLength = 0;
do if constexpr(HaveIndexInput)
{ {
threadwise_src_load.Run(in_grid_desc_m_k, auto threadwise_src_idx_load =
in_global_buf, ThreadwiseTensorSliceTransfer_v2<IndexDataType,
thread_buffer_desc, IndexDataType,
make_tuple(I0, I0), InGridDesc_M_K,
in_thread_val_buf); decltype(thread_buffer_desc),
ThreadBufferLengths,
ThreadBufferDimAccessOrder,
InSrcVectorDim,
InSrcVectorSize,
1,
false>(
in_grid_desc_m_k, make_multi_index(thread_global_1d_id * MThreadSliceSize, 0));
do
{
threadwise_src_val_load.Run(in_grid_desc_m_k,
in_global_val_buf,
thread_buffer_desc,
make_tuple(I0, I0),
in_thread_val_buf);
threadwise_src_idx_load.Run(in_grid_desc_m_k,
in_global_idx_buf,
thread_buffer_desc,
make_tuple(I0, I0),
in_thread_idx_buf);
static_for<0, MThreadSliceSize, 1>{}([&](auto iM) {
// do element-wise pre-reduction operation
static_for<0, KThreadSliceSize, 1>{}([&](auto iK) {
constexpr auto offset =
thread_buffer_desc.CalculateOffset(make_tuple(iM, iK));
in_elementwise_op(in_thread_val_buf(Number<offset>{}),
in_thread_val_buf(Number<offset>{}));
});
});
static_for<0, MThreadSliceSize, 1>{}([&](auto iM) { ThreadwiseReduceWithIndex::Reduce(
// do element-wise pre-reduction operation in_thread_val_buf, in_thread_idx_buf, accu_value_buf, accu_index_buf);
static_for<0, KThreadSliceSize, 1>{}([&](auto iK) {
constexpr auto offset = thread_buffer_desc.CalculateOffset(make_tuple(iM, iK));
in_thread_idx_buf(Number<offset>{}) = indexStart + iK(); threadwise_src_val_load.MoveSrcSliceWindow(in_grid_desc_m_k, in_thread_copy_step);
threadwise_src_idx_load.MoveSrcSliceWindow(in_grid_desc_m_k, in_thread_copy_step);
in_elementwise_op(in_thread_val_buf(Number<offset>{}), indexStart += KThreadSliceSize;
in_thread_val_buf(Number<offset>{})); reducedLength += KThreadSliceSize;
} while(reducedLength < toReduceLength);
}
else
{
do
{
threadwise_src_val_load.Run(in_grid_desc_m_k,
in_global_val_buf,
thread_buffer_desc,
make_tuple(I0, I0),
in_thread_val_buf);
static_for<0, MThreadSliceSize, 1>{}([&](auto iM) {
// do element-wise pre-reduction operation
static_for<0, KThreadSliceSize, 1>{}([&](auto iK) {
constexpr auto offset =
thread_buffer_desc.CalculateOffset(make_tuple(iM, iK));
in_thread_idx_buf(Number<offset>{}) = indexStart + iK();
in_elementwise_op(in_thread_val_buf(Number<offset>{}),
in_thread_val_buf(Number<offset>{}));
});
}); });
});
ThreadwiseReduceWithIndex::Reduce( ThreadwiseReduceWithIndex::Reduce(
in_thread_val_buf, in_thread_idx_buf, accu_value_buf, accu_index_buf); in_thread_val_buf, in_thread_idx_buf, accu_value_buf, accu_index_buf);
threadwise_src_load.MoveSrcSliceWindow(in_grid_desc_m_k, in_thread_copy_step); threadwise_src_val_load.MoveSrcSliceWindow(in_grid_desc_m_k, in_thread_copy_step);
indexStart += KThreadSliceSize; indexStart += KThreadSliceSize;
reducedLength += KThreadSliceSize; reducedLength += KThreadSliceSize;
} while(reducedLength < toReduceLength); } while(reducedLength < toReduceLength);
};
// for indexed operation, acc_elementwise_op should do nothing // for indexed operation, acc_elementwise_op should do nothing
static_for<0, MThreadSliceSize, 1>{}([&](auto I) { static_for<0, MThreadSliceSize, 1>{}([&](auto I) {
...@@ -367,36 +424,32 @@ struct GridwiseReduction_mk_to_m_threadwise ...@@ -367,36 +424,32 @@ struct GridwiseReduction_mk_to_m_threadwise
constexpr auto reduced_data_desc = ThreadReduceDstDesc_M{}; constexpr auto reduced_data_desc = ThreadReduceDstDesc_M{};
if constexpr(!BetaIsZero) if(!float_equal_zero{}(beta))
{ {
if(!float_equal_zero{}(beta)) auto threadwise_dst_load = ThreadwiseTensorSliceTransfer_v2<OutDataType,
{ OutDataType,
auto threadwise_dst_load = OutGridDesc_M,
ThreadwiseTensorSliceTransfer_v2<OutDataType, decltype(reduced_data_desc),
OutDataType, Sequence<MThreadSliceSize>,
OutGridDesc_M, Sequence<0>,
decltype(reduced_data_desc), 0,
Sequence<MThreadSliceSize>, 1,
Sequence<0>, 1,
0, false>(
1, out_grid_desc_m, make_multi_index(thread_global_1d_id * MThreadSliceSize));
1,
false>( StaticBuffer<AddressSpaceEnum::Vgpr, OutDataType, MThreadSliceSize, true>
out_grid_desc_m, make_multi_index(thread_global_1d_id * MThreadSliceSize)); priorDstValue_buf;
StaticBuffer<AddressSpaceEnum::Vgpr, OutDataType, MThreadSliceSize, true> threadwise_dst_load.Run(out_grid_desc_m,
priorDstValue_buf; out_global_val_buf,
reduced_data_desc,
threadwise_dst_load.Run(out_grid_desc_m, make_tuple(I0),
out_global_val_buf, priorDstValue_buf);
reduced_data_desc,
make_tuple(I0), static_for<0, MThreadSliceSize, 1>{}([&](auto I) {
priorDstValue_buf); accu_value_buf(I) += type_convert<AccDataType>(priorDstValue_buf[I]) * beta;
});
static_for<0, MThreadSliceSize, 1>{}([&](auto I) {
accu_value_buf(I) += type_convert<AccDataType>(priorDstValue_buf[I]) * beta;
});
};
}; };
auto threadwise_dst_val_store = auto threadwise_dst_val_store =
...@@ -409,7 +462,7 @@ struct GridwiseReduction_mk_to_m_threadwise ...@@ -409,7 +462,7 @@ struct GridwiseReduction_mk_to_m_threadwise
Sequence<0>, Sequence<0>,
0, 0,
OutDstVectorSize, OutDstVectorSize,
InMemoryDataOperationEnum::Set, OutMemoryDataOperation,
1, 1,
false>( false>(
out_grid_desc_m, out_grid_desc_m,
...@@ -426,7 +479,7 @@ struct GridwiseReduction_mk_to_m_threadwise ...@@ -426,7 +479,7 @@ struct GridwiseReduction_mk_to_m_threadwise
Sequence<0>, Sequence<0>,
0, 0,
OutDstVectorSize, OutDstVectorSize,
InMemoryDataOperationEnum::Set, OutMemoryDataOperation,
1, 1,
false>( false>(
out_grid_desc_m, out_grid_desc_m,
......
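The interface change above gives the threadwise kernel three compile-time paths: a value-only reduction, an index-producing reduction that generates indices on the fly, and an index-producing reduction that reuses indices written by an earlier pass (the two-call case). A minimal sketch of that dispatch, using hypothetical names rather than the library's real templates, could look like this:

```cpp
// Hypothetical sketch of the compile-time dispatch shown in the diff above:
// OutputIndex selects between the value-only path and the index path, and
// HaveIndexInput tells the index path whether a previous kernel already
// produced the indices it should reuse.
#include <cstdio>

struct GridwiseReductionSketch
{
    static void Run() { std::printf("value-only reduction\n"); }

    template <bool HaveIndexInput>
    static void RunWithIndex()
    {
        if constexpr(HaveIndexInput)
            std::printf("index reduction, reusing indices from a previous pass\n");
        else
            std::printf("index reduction, generating indices from scratch\n");
    }
};

template <typename GridwiseReduction, bool OutputIndex, bool HaveIndexInput>
void kernel_entry()
{
    if constexpr(!OutputIndex)
        GridwiseReduction::Run();
    else
        GridwiseReduction::template RunWithIndex<HaveIndexInput>();
}

int main()
{
    kernel_entry<GridwiseReductionSketch, false, false>(); // plain reduction
    kernel_entry<GridwiseReductionSketch, true, false>();  // e.g. first call of an ArgMax
    kernel_entry<GridwiseReductionSketch, true, true>();   // e.g. second call of a two-call ArgMax
}
```

Keeping both flags as template parameters means the unused branch is discarded at compile time, which matches the `if constexpr` structure of the kernel in the diff.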
...@@ -325,7 +325,7 @@ struct DynamicBuffer ...@@ -325,7 +325,7 @@ struct DynamicBuffer
{ {
if(is_valid_element) if(is_valid_element)
{ {
atomic_add<X>(c_style_pointer_cast<X*>(&p_data_[i]), x); atomic_add(c_style_pointer_cast<X*>(&p_data_[i]), x);
} }
} }
} }
......
...@@ -28,6 +28,12 @@ __device__ float atomic_add<float>(float* p_dst, const float& x) ...@@ -28,6 +28,12 @@ __device__ float atomic_add<float>(float* p_dst, const float& x)
return atomicAdd(p_dst, x); return atomicAdd(p_dst, x);
} }
template <>
__device__ double atomic_add<double>(double* p_dst, const double& x)
{
return atomicAdd(p_dst, x);
}
template <> template <>
__device__ float2_t atomic_add<float2_t>(float2_t* p_dst, const float2_t& x) __device__ float2_t atomic_add<float2_t>(float2_t* p_dst, const float2_t& x)
{ {
...@@ -45,6 +51,23 @@ __device__ float2_t atomic_add<float2_t>(float2_t* p_dst, const float2_t& x) ...@@ -45,6 +51,23 @@ __device__ float2_t atomic_add<float2_t>(float2_t* p_dst, const float2_t& x)
return vy.template AsType<float2_t>()[I0]; return vy.template AsType<float2_t>()[I0];
} }
template <>
__device__ double2_t atomic_add<double2_t>(double2_t* p_dst, const double2_t& x)
{
constexpr auto I0 = Number<0>{};
constexpr auto I1 = Number<1>{};
const vector_type<double, 2> vx{x};
vector_type<double, 2> vy{0};
vy.template AsType<double>()(I0) =
atomicAdd(c_style_pointer_cast<double*>(p_dst), vx.template AsType<double>()[I0]);
vy.template AsType<double>()(I1) =
atomicAdd(c_style_pointer_cast<double*>(p_dst) + 1, vx.template AsType<double>()[I1]);
return vy.template AsType<double2_t>()[I0];
}
// Caution: DO NOT REMOVE // Caution: DO NOT REMOVE
// intentionally have only declaration but no definition to cause compilation failure when trying to // intentionally have only declaration but no definition to cause compilation failure when trying to
// instantiate this template. The purpose is to make the implementation of atomic_max explicit for // instantiate this template. The purpose is to make the implementation of atomic_max explicit for
......
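The "Caution: DO NOT REMOVE" comment refers to a deliberate pattern: `atomic_max` is declared but never given a generic definition, so any data type without an explicit specialization fails at build time instead of silently falling back to an incorrect generic implementation. A standalone illustration of the same trick, with made-up names rather than the library's `atomic_max`:

```cpp
// Standalone illustration (made-up names) of "declare but don't define":
// the primary template has no body, so using a type without an explicit
// specialization produces a build error rather than a wrong generic result.
#include <atomic>
#include <cstdio>

template <typename T>
T atomic_max_sketch(std::atomic<T>* p_dst, const T& x); // intentionally no definition

template <>
int atomic_max_sketch<int>(std::atomic<int>* p_dst, const int& x)
{
    int prev = p_dst->load();
    while(prev < x && !p_dst->compare_exchange_weak(prev, x))
    {
    }
    return prev; // value observed before the (possible) update
}

int main()
{
    std::atomic<int> v{3};
    std::printf("previous value: %d\n", atomic_max_sketch(&v, 7)); // int is specialized: OK
    // std::atomic<long long> w{1};
    // atomic_max_sketch(&w, 2LL); // no specialization: fails at link time
}
```

The new `double` and `double2_t` overloads of `atomic_add` in the hunk above follow the same explicit-specialization style; the vector case simply splits into two scalar `atomicAdd` calls on consecutive elements.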