Commit cab8f2e5 authored by Jing Zhang's avatar Jing Zhang
Browse files

clean

parents c20aabc3 9a17e7fb
...@@ -26,11 +26,10 @@ using Row = ck::tensor_layout::gemm::RowMajor; ...@@ -26,11 +26,10 @@ using Row = ck::tensor_layout::gemm::RowMajor;
using Col = ck::tensor_layout::gemm::ColumnMajor; using Col = ck::tensor_layout::gemm::ColumnMajor;
using PassThrough = ck::tensor_operation::element_wise::PassThrough; using PassThrough = ck::tensor_operation::element_wise::PassThrough;
using RequantReluRequant = ck::tensor_operation::element_wise::RequantReluRequant;
using ADataType = int8_t; using ADataType = int8_t;
using BDataType = int8_t; using BDataType = int8_t;
using CDataType = int8_t; using CDataType = int32_t;
using AccDataType = int32_t; using AccDataType = int32_t;
using CShuffleDataType = int32_t; using CShuffleDataType = int32_t;
...@@ -50,7 +49,7 @@ using DeviceGemmInstance = ck::tensor_operation::device::DeviceGemmXdl_C_Shuffle ...@@ -50,7 +49,7 @@ using DeviceGemmInstance = ck::tensor_operation::device::DeviceGemmXdl_C_Shuffle
CLayout, // CLayout CLayout, // CLayout
PassThrough, // AElementwiseOperation PassThrough, // AElementwiseOperation
PassThrough, // BElementwiseOperation PassThrough, // BElementwiseOperation
RequantReluRequant, // CElementwiseOperation PassThrough, // CElementwiseOperation
256, // BlockSize 256, // BlockSize
256, // MPerBlock 256, // MPerBlock
128, // NPerBlock 128, // NPerBlock
...@@ -78,11 +77,11 @@ using DeviceGemmInstance = ck::tensor_operation::device::DeviceGemmXdl_C_Shuffle ...@@ -78,11 +77,11 @@ using DeviceGemmInstance = ck::tensor_operation::device::DeviceGemmXdl_C_Shuffle
1, // CShuffleMXdlPerWavePerShuffle 1, // CShuffleMXdlPerWavePerShuffle
1, // CShuffleNXdlPerWavePerShuffle 1, // CShuffleNXdlPerWavePerShuffle
S<1, 1, 32, 1, 1, 8>, // CBlockTransferClusterLengths_MBlock_MXdlPerWave_MWaveMPerXdl_NBlock_NXdlPerWave_NWaveNPerXdl S<1, 1, 32, 1, 1, 8>, // CBlockTransferClusterLengths_MBlock_MXdlPerWave_MWaveMPerXdl_NBlock_NXdlPerWave_NWaveNPerXdl
8>; // CBlockTransferScalarPerVector_NWaveNPerXdl 4>; // CBlockTransferScalarPerVector_NWaveNPerXdl
// clang-format on // clang-format on
using ReferenceGemmInstance = ck::tensor_operation::host:: using ReferenceGemmInstance = ck::tensor_operation::host::
ReferenceGemm<ADataType, BDataType, CDataType, PassThrough, PassThrough, RequantReluRequant>; ReferenceGemm<ADataType, BDataType, CDataType, PassThrough, PassThrough, PassThrough>;
int main(int argc, char* argv[]) int main(int argc, char* argv[])
{ {
...@@ -99,9 +98,6 @@ int main(int argc, char* argv[]) ...@@ -99,9 +98,6 @@ int main(int argc, char* argv[])
ck::index_t StrideB = 4096; ck::index_t StrideB = 4096;
ck::index_t StrideC = 4096; ck::index_t StrideC = 4096;
float scale_gemm = 0.03;
float scale_relu = 1;
if(argc == 4) if(argc == 4)
{ {
do_verification = std::stoi(argv[1]); do_verification = std::stoi(argv[1]);
...@@ -175,7 +171,7 @@ int main(int argc, char* argv[]) ...@@ -175,7 +171,7 @@ int main(int argc, char* argv[])
auto a_element_op = PassThrough{}; auto a_element_op = PassThrough{};
auto b_element_op = PassThrough{}; auto b_element_op = PassThrough{};
auto c_element_op = RequantReluRequant{scale_gemm, scale_relu}; auto c_element_op = PassThrough{};
// do GEMM // do GEMM
auto gemm = DeviceGemmInstance{}; auto gemm = DeviceGemmInstance{};
......
# Instructions for ```reduce_blockwise``` Example
## Docker script
```bash
docker run \
-it \
--rm \
--privileged \
--group-add sudo \
-w /root/workspace \
-v ${PATH_TO_LOCAL_WORKSPACE}:/root/workspace \
rocm/tensorflow:rocm4.3.1-tf2.6-dev \
/bin/bash
```
## Build ```reduce_blockwise```
```bash
mkdir build && cd build
```
```bash
# Need to specify target ID, example below is gfx908
cmake \
-D BUILD_DEV=OFF \
-D CMAKE_BUILD_TYPE=Release \
-D CMAKE_CXX_FLAGS="-DCK_AMD_GPU_GFX908 --amdgpu-target=gfx908 -O3 " \
-D CMAKE_CXX_COMPILER=/opt/rocm/bin/hipcc \
-D CMAKE_PREFIX_PATH=/opt/rocm \
..
```
```bash
make -j reduce_blockwise
```
## Run ```reduce_blockwise```
```bash
# -D <xxx> : input 4-d tensor lengths
# -v <x> : verification (0=no, 1=yes)
#arg1: initialization (0=no init, 1=integer value, 2=decimal value)
#arg2: run kernel # of times (>1)
./bin/reduce_blockwise -D 16,64,32,960 -v 1 1 10
```
Result
```
launch_and_time_kernel: grid_dim {240, 1, 1}, block_dim {256, 1, 1}
Warm up
Start running 3 times...
Perf: 0.23536 ms, 267.32 GB/s, DeviceReduceBlockWise<256,M_C4_S1,K_C64_S1,InSrcVectorDim_0_InSrcVectorSize_1_OutDstVectorSize_1>
error: 0
max_diff: 0, 529, 529
root@dc-smc-18:/data/composable_kernel/Build3# bin/reduce_blockwise -D 16,64,32,960 -v 1 1 10
launch_and_time_kernel: grid_dim {240, 1, 1}, block_dim {256, 1, 1}
Warm up
Start running 10 times...
Perf: 0.23392 ms, 268.966 GB/s, DeviceReduceBlockWise<256,M_C4_S1,K_C64_S1,InSrcVectorDim_0_InSrcVectorSize_1_OutDstVectorSize_1>
error: 0
max_diff: 0, 528, 528
```
...@@ -14,6 +14,7 @@ ...@@ -14,6 +14,7 @@
#include "device_reduce_blockwise.hpp" #include "device_reduce_blockwise.hpp"
#include "host_reduce_util.hpp" #include "host_reduce_util.hpp"
#include "host_generic_reduction.hpp" #include "host_generic_reduction.hpp"
#include "reduction_enums.hpp" #include "reduction_enums.hpp"
#include "reduction_operator_mapping.hpp" #include "reduction_operator_mapping.hpp"
...@@ -29,7 +30,7 @@ using kOutDataType = ck::half_t; ...@@ -29,7 +30,7 @@ using kOutDataType = ck::half_t;
using kAccDataType = float; using kAccDataType = float;
constexpr int Rank = 4; constexpr int Rank = 4;
using ReduceDims_ = ck::Sequence<0, 1, 2>; constexpr int NumReduceDim = 3;
constexpr ReduceTensorOp_t ReduceOpId = ReduceTensorOp_t::NORM2; constexpr ReduceTensorOp_t ReduceOpId = ReduceTensorOp_t::NORM2;
constexpr NanPropagation_t NanOpt = NanPropagation_t::PROPAGATE_NAN; constexpr NanPropagation_t NanOpt = NanPropagation_t::PROPAGATE_NAN;
...@@ -46,7 +47,7 @@ using DeviceReduceInstance = DeviceReduceBlockWise<kInDataType, ...@@ -46,7 +47,7 @@ using DeviceReduceInstance = DeviceReduceBlockWise<kInDataType,
kAccDataType, kAccDataType,
kOutDataType, kOutDataType,
Rank, Rank,
ReduceDims_, NumReduceDim,
ReduceOperation, ReduceOperation,
InElementwiseOperation, InElementwiseOperation,
AccElementwiseOperation, AccElementwiseOperation,
...@@ -192,39 +193,13 @@ class SimpleAppArgs ...@@ -192,39 +193,13 @@ class SimpleAppArgs
}; };
}; };
template <int Rank, typename ReduceDims>
static std::vector<int> get_reduce_dims()
{
std::vector<int> resDims;
static_for<0, ReduceDims::Size(), 1>{}([&](auto i) { resDims.push_back(ReduceDims::At(i)); });
return (resDims);
};
template <int Rank, typename ReduceDims>
static std::vector<int> get_invariant_dims()
{
std::vector<int> resDims;
unsigned int incFlag = 0;
static_for<0, ReduceDims::Size(), 1>{}(
[&](auto i) { incFlag = incFlag | (0x1 << ReduceDims::At(i)); });
for(int dim = 0; dim < Rank; dim++)
{
if(incFlag & (0x1 << dim))
continue;
resDims.push_back(dim);
};
return (resDims);
};
int main(int argc, char* argv[]) int main(int argc, char* argv[])
{ {
using namespace ck::host_reduce; using namespace ck::host_reduce;
const std::vector<int> reduceDims{0, 1, 2};
const std::vector<int> invariantDims{3};
SimpleAppArgs args; SimpleAppArgs args;
if(args.processArgs(argc, argv) < 0) if(args.processArgs(argc, argv) < 0)
...@@ -260,15 +235,12 @@ int main(int argc, char* argv[]) ...@@ -260,15 +235,12 @@ int main(int argc, char* argv[])
Tensor<InDataType> in(args.inLengths); Tensor<InDataType> in(args.inLengths);
const std::vector<int> InvariantDims = get_invariant_dims<Rank, ReduceDims_>();
const std::vector<int> ReduceDims = get_reduce_dims<Rank, ReduceDims_>();
std::vector<size_t> outLengths; std::vector<size_t> outLengths;
if(InvariantDims.empty()) if(invariantDims.empty())
outLengths.push_back(1); outLengths.push_back(1);
else else
for(auto dim : InvariantDims) for(auto dim : invariantDims)
outLengths.push_back(args.inLengths[dim]); outLengths.push_back(args.inLengths[dim]);
Tensor<OutDataType> out_ref(outLengths); Tensor<OutDataType> out_ref(outLengths);
...@@ -328,7 +300,7 @@ int main(int argc, char* argv[]) ...@@ -328,7 +300,7 @@ int main(int argc, char* argv[])
if(args.do_verification) if(args.do_verification)
{ {
ReductionHost<InDataType, AccDataType, OutDataType, ReduceOpId, PropagateNan, NeedIndices> ReductionHost<InDataType, AccDataType, OutDataType, ReduceOpId, PropagateNan, NeedIndices>
hostReduce(in.mDesc, out_ref.mDesc, InvariantDims, ReduceDims); hostReduce(in.mDesc, out_ref.mDesc, invariantDims, reduceDims);
hostReduce.Run( hostReduce.Run(
alpha, in.mData.data(), beta, out_ref.mData.data(), out_indices_ref.mData.data()); alpha, in.mData.data(), beta, out_ref.mData.data(), out_indices_ref.mData.data());
...@@ -350,6 +322,7 @@ int main(int argc, char* argv[]) ...@@ -350,6 +322,7 @@ int main(int argc, char* argv[])
i_inStrides, i_inStrides,
i_outLengths, i_outLengths,
i_outStrides, i_outStrides,
reduceDims,
alpha, alpha,
beta, beta,
in_dev.GetDeviceBuffer(), in_dev.GetDeviceBuffer(),
......
# Instructions for ```gemm_xdl``` Example # Instructions for ```pool2d_fwd``` Example
## Docker script ## Docker script
```bash ```bash
...@@ -13,7 +13,7 @@ rocm/tensorflow:rocm4.3.1-tf2.6-dev \ ...@@ -13,7 +13,7 @@ rocm/tensorflow:rocm4.3.1-tf2.6-dev \
/bin/bash /bin/bash
``` ```
## Build ```gemm_xdl``` ## Build ```pool2d_fwd```
```bash ```bash
mkdir build && cd build mkdir build && cd build
``` ```
...@@ -30,27 +30,26 @@ cmake \ ...@@ -30,27 +30,26 @@ cmake \
``` ```
```bash ```bash
make -j gemm_xdl make -j pool2d_fwd
``` ```
## Run ```gemm_xdl``` ## Run ```pool2d_fwd```
```bash ```bash
#arg1: verification (0=no, 1=yes) #arg1: verification (0=no, 1=yes)
#arg2: initialization (0=no init, 1=integer value, 2=decimal value) #arg2: initialization (0=no init, 1=integer value, 2=decimal value)
#arg3: run kernel # of times (>1) #arg3: run kernel # of times (>1)
./example/gemm_xdl 0 1 5 #arg4 to 15: N, C, Y, X, Hi, Wi, Sy, Sx, LeftPy, LeftPx, RightPy, RightPx
./example/pool2d_fwd 1 1 10
``` ```
Result (MI100 @ 1087Mhz, 133.5TFlops peak FP16) Result
``` ```
a_m_k: dim 2, lengths {3840, 4096}, strides {4096, 1} in_n_c_hi_wi: dim 4, lengths {128, 192, 71, 71}, strides {967872, 1, 13632, 192}
b_k_n: dim 2, lengths {4096, 4096}, strides {1, 4096} out_n_c_ho_wo: dim 4, lengths {128, 192, 36, 36}, strides {248832, 1, 6912, 192}
c_m_n: dim 2, lengths {3840, 4096}, strides {4096, 1} launch_and_time_kernel: grid_dim {124416, 1, 1}, block_dim {64, 1, 1}
arg.a_grid_desc_k0_m_k1_{512, 3840, 8}
arg.b_grid_desc_k0_n_k1_{512, 4096, 8}
arg.c_grid_desc_m_n_{ 3840, 4096}
launch_and_time_kernel: grid_dim {480, 1, 1}, block_dim {256, 1, 1}
Warm up Warm up
Start running 5 times... Start running 10 times...
Perf: 1.19685 ms, 107.657 TFlops, 78.8501 GB/s Perf: 0.415453 ms, 1.37996 TFlops, 749.726 GB/s
error: 0
max_diff: 0, 1, 1
``` ```
add_example_executable(example_gemm_xdl_requant_relu_requant_int8 gemm_xdl_requant_relu_requant_int8.cpp)
\ No newline at end of file
#include <iostream>
#include <numeric>
#include <initializer_list>
#include <cstdlib>
#include <stdlib.h>
#include <half.hpp>
#include "config.hpp"
#include "print.hpp"
#include "device.hpp"
#include "host_tensor.hpp"
#include "host_tensor_generator.hpp"
#include "host_gemm.hpp"
#include "device_tensor.hpp"
#include "device_gemm_xdl.hpp"
#include "device_gemm_xdl_c_shuffle.hpp"
#include "element_wise_operation.hpp"
#include "reference_gemm.hpp"
#include "gemm_specialization.hpp"
template <ck::index_t... Is>
using S = ck::Sequence<Is...>;
using F32 = float;
using Row = ck::tensor_layout::gemm::RowMajor;
using Col = ck::tensor_layout::gemm::ColumnMajor;
using PassThrough = ck::tensor_operation::element_wise::PassThrough;
using RequantReluRequant = ck::tensor_operation::element_wise::RequantReluRequant;
using ADataType = int8_t;
using BDataType = int8_t;
using CDataType = int8_t;
using AccDataType = int32_t;
using ShuffleDataType = int32_t;
using ALayout = ck::tensor_layout::gemm::RowMajor;
using BLayout = ck::tensor_layout::gemm::ColumnMajor;
using CLayout = ck::tensor_layout::gemm::RowMajor;
// clang-format off
using DeviceGemmInstance = ck::tensor_operation::device::DeviceGemmXdl_C_Shuffle<
ADataType, // ADataType
BDataType, // BDataType
CDataType, // CDataType
AccDataType, // AccDataType
ShuffleDataType, // ShuffleDataType
ALayout, // ALayout
BLayout, // BLayout
CLayout, // CLayout
PassThrough, // AElementwiseOperation
PassThrough, // BElementwiseOperation
RequantReluRequant, // CElementwiseOperation
256, // BlockSize
256, // MPerBlock
128, // NPerBlock
32, // KPerBlock
8, // AK1
8, // BK1
32, // MPerXDL
32, // NPerXDL
4, // MXdlPerWave
2, // NXdlPerWave
S<4, 64, 1>, // ABlockTransferThreadClusterLengths_K0_M_K1
S<1, 0, 2>, // ABlockTransferThreadClusterArrangeOrder
S<1, 0, 2>, // ABlockTransferSrcAccessOrder
2, // ABlockTransferSrcVectorDim
8, // ABlockTransferSrcScalarPerVector
8, // ABlockTransferDstScalarPerVector_K1
true, // ABlockLdsAddExtraM
S<4, 64, 1>, // BBlockTransferThreadClusterLengths_K0_N_K1
S<1, 0, 2>, // BBlockTransferThreadClusterArrangeOrder
S<1, 0, 2>, // BBlockTransferSrcAccessOrder
2, // BBlockTransferSrcVectorDim
8, // BBlockTransferSrcScalarPerVector
8, // BBlockTransferDstScalarPerVector_K1
true, // BBlockLdsAddExtraN
1, // CShuffleMXdlPerWavePerShuffle
1, // CShuffleNXdlPerWavePerShuffle
S<1, 1, 32, 1, 1, 8>, // CBlockTransferClusterLengths_MBlock_MXdlPerWave_MWaveMPerXdl_NBlock_NXdlPerWave_NWaveNPerXdl
8>; // CBlockTransferScalarPerVector_NWaveNPerXdl
// clang-format on
using ReferenceGemmInstance = ck::tensor_operation::host::
ReferenceGemm<ADataType, BDataType, CDataType, PassThrough, PassThrough, RequantReluRequant>;
int main(int argc, char* argv[])
{
bool do_verification = 0;
int init_method = 0;
int nrepeat = 5;
// GEMM shape
ck::index_t M = 3840;
ck::index_t N = 4096;
ck::index_t K = 4096;
ck::index_t StrideA = 4096;
ck::index_t StrideB = 4096;
ck::index_t StrideC = 4096;
float scale_gemm = 0.03;
float scale_relu = 1;
if(argc == 4)
{
do_verification = std::stoi(argv[1]);
init_method = std::stoi(argv[2]);
nrepeat = std::stoi(argv[3]);
}
else if(argc == 10)
{
do_verification = std::stoi(argv[1]);
init_method = std::stoi(argv[2]);
nrepeat = std::stoi(argv[3]);
M = std::stoi(argv[4]);
N = std::stoi(argv[5]);
K = std::stoi(argv[6]);
StrideA = std::stoi(argv[7]);
StrideB = std::stoi(argv[8]);
StrideC = std::stoi(argv[9]);
}
else
{
printf("arg1: verification (0=no, 1=yes)\n");
printf("arg2: initialization (0=no init, 1=integer value, 2=decimal value)\n");
printf("arg3: run kernel # of times (>1)\n");
printf("arg4 to 9: M (256x), N(128x), K(32x), StrideA, StrideB, StrideC\n");
exit(0);
}
auto f_host_tensor_descriptor =
[](std::size_t row, std::size_t col, std::size_t stride, auto layout) {
if(std::is_same<decltype(layout), ck::tensor_layout::gemm::RowMajor>::value)
{
return HostTensorDescriptor(std::vector<std::size_t>({row, col}),
std::vector<std::size_t>({stride, 1}));
}
else
{
return HostTensorDescriptor(std::vector<std::size_t>({row, col}),
std::vector<std::size_t>({1, stride}));
}
};
Tensor<ADataType> a_m_k(f_host_tensor_descriptor(M, K, StrideA, ALayout{}));
Tensor<BDataType> b_k_n(f_host_tensor_descriptor(K, N, StrideB, BLayout{}));
Tensor<CDataType> c_m_n_host_result(f_host_tensor_descriptor(M, N, StrideC, CLayout{}));
Tensor<CDataType> c_m_n_device_result(f_host_tensor_descriptor(M, N, StrideC, CLayout{}));
std::cout << "a_m_k: " << a_m_k.mDesc << std::endl;
std::cout << "b_k_n: " << b_k_n.mDesc << std::endl;
std::cout << "c_m_n: " << c_m_n_host_result.mDesc << std::endl;
switch(init_method)
{
case 0: break;
case 1:
a_m_k.GenerateTensorValue(GeneratorTensor_2<ADataType>{-5, 5});
b_k_n.GenerateTensorValue(GeneratorTensor_2<BDataType>{-5, 5});
break;
default:
a_m_k.GenerateTensorValue(GeneratorTensor_3<ADataType>{0.0, 1.0});
b_k_n.GenerateTensorValue(GeneratorTensor_3<BDataType>{-0.5, 0.5});
}
DeviceMem a_m_k_device_buf(sizeof(ADataType) * a_m_k.mDesc.GetElementSpace());
DeviceMem b_k_n_device_buf(sizeof(BDataType) * b_k_n.mDesc.GetElementSpace());
DeviceMem c_m_n_device_buf(sizeof(CDataType) * c_m_n_device_result.mDesc.GetElementSpace());
a_m_k_device_buf.ToDevice(a_m_k.mData.data());
b_k_n_device_buf.ToDevice(b_k_n.mData.data());
auto a_element_op = PassThrough{};
auto b_element_op = PassThrough{};
auto c_element_op = RequantReluRequant{scale_gemm, scale_relu};
// do GEMM
auto gemm = DeviceGemmInstance{};
auto invoker = gemm.MakeInvoker();
auto argument = gemm.MakeArgument(static_cast<ADataType*>(a_m_k_device_buf.GetDeviceBuffer()),
static_cast<BDataType*>(b_k_n_device_buf.GetDeviceBuffer()),
static_cast<CDataType*>(c_m_n_device_buf.GetDeviceBuffer()),
M,
N,
K,
StrideA,
StrideB,
StrideC,
a_element_op,
b_element_op,
c_element_op);
if(!gemm.IsSupportedArgument(argument))
{
throw std::runtime_error(
"wrong! device_gemm with the specified compilation parameters does "
"not support this GEMM problem");
}
float ave_time = invoker.Run(argument, nrepeat);
std::size_t flop = std::size_t(2) * M * N * K;
std::size_t num_btype =
sizeof(ADataType) * M * K + sizeof(BDataType) * K * N + sizeof(CDataType) * M * N;
float tflops = static_cast<float>(flop) / 1.E9 / ave_time;
float gb_per_sec = num_btype / 1.E6 / ave_time;
std::cout << "Perf: " << ave_time << " ms, " << tflops << " TFlops, " << gb_per_sec << " GB/s, "
<< gemm.GetTypeString() << std::endl;
c_m_n_device_buf.FromDevice(c_m_n_device_result.mData.data());
if(do_verification)
{
auto ref_gemm = ReferenceGemmInstance{};
auto ref_invoker = ref_gemm.MakeInvoker();
auto ref_argument = ref_gemm.MakeArgument(
a_m_k, b_k_n, c_m_n_host_result, a_element_op, b_element_op, c_element_op);
ref_invoker.Run(ref_argument);
check_error(c_m_n_host_result, c_m_n_device_result);
}
return 0;
}
add_example_executable(example_grouped_gemm_xdl_fp16 grouped_gemm_xdl_fp16.cpp)
#include <iostream>
#include <numeric>
#include <initializer_list>
#include <cstdlib>
#include <stdlib.h>
#include <half.hpp>
#include "config.hpp"
#include "print.hpp"
#include "device.hpp"
#include "host_tensor.hpp"
#include "host_tensor_generator.hpp"
#include "host_gemm.hpp"
#include "device_tensor.hpp"
#include "device_grouped_gemm_xdl.hpp"
#include "element_wise_operation.hpp"
#include "reference_gemm.hpp"
#include "gemm_specialization.hpp"
template <ck::index_t... Is>
using S = ck::Sequence<Is...>;
using F16 = ck::half_t;
using F32 = float;
using Row = ck::tensor_layout::gemm::RowMajor;
using Col = ck::tensor_layout::gemm::ColumnMajor;
using PassThrough = ck::tensor_operation::element_wise::PassThrough;
using ADataType = ck::half_t;
using BDataType = ck::half_t;
using CDataType = ck::half_t;
using AccDataType = float;
using ALayout = ck::tensor_layout::gemm::RowMajor;
using BLayout = ck::tensor_layout::gemm::ColumnMajor;
using CLayout = ck::tensor_layout::gemm::RowMajor;
using AElementOp = ck::tensor_operation::element_wise::PassThrough;
using BElementOp = ck::tensor_operation::element_wise::PassThrough;
using CElementOp = ck::tensor_operation::element_wise::PassThrough;
static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization_t::Default;
// static constexpr auto GemmMNPadding =
// ck::tensor_operation::device::GemmSpecialization_t::MNPadding;
// clang-format off
using DeviceGemmInstance = ck::tensor_operation::device::DeviceGroupedGemmXdl
//######| AData| BData| CData| AccData| ALayout| BLayout| CLayout| A| B| C| GEMM| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CThreadTransfer| CThreadTransfer| Num|
//######| Type| Type| Type| Type| | | | Elementwise| Elementwise| Elementwise|Spacialization| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| SrcDstVectorDim| DstScalar| Prefetch|
//######| | | | | | | | Operation| Operation| Operation| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | | PerVector| |
//######| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 256, 128, 4, 8, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1, 1>;
// clang-format on
using ReferenceGemmInstance = ck::tensor_operation::host::
ReferenceGemm<ADataType, BDataType, CDataType, AElementOp, BElementOp, CElementOp>;
int main(int argc, char* argv[])
{
bool do_verification = 0;
int init_method = 0;
int nrepeat = 5;
if(argc == 4)
{
do_verification = std::stoi(argv[1]);
init_method = std::stoi(argv[2]);
nrepeat = std::stoi(argv[3]);
}
else
{
printf("arg1: verification (0=no, 1=yes)\n");
printf("arg2: initialization (0=no init, 1=integer value, 2=decimal value)\n");
printf("arg3: run kernel # of times (>1)\n");
exit(0);
}
int group_count = 4;
// GEMM shape
std::vector<ck::GemmShape> gemm_shapes;
for(int i = 0; i < group_count; i++)
{
int M = 256 + 256 * i;
int N = 128 + 128 * i;
int K = 64 + 64 * i;
gemm_shapes.push_back({M, N, K, K, K, N, nullptr, nullptr, nullptr});
}
auto f_host_tensor_descriptor =
[](std::size_t row, std::size_t col, std::size_t stride, auto layout) {
if(std::is_same<decltype(layout), ck::tensor_layout::gemm::RowMajor>::value)
{
return HostTensorDescriptor(std::vector<std::size_t>({row, col}),
std::vector<std::size_t>({stride, 1}));
}
else
{
return HostTensorDescriptor(std::vector<std::size_t>({row, col}),
std::vector<std::size_t>({1, stride}));
}
};
std::vector<Tensor<ADataType>> a_tensors;
std::vector<Tensor<BDataType>> b_tensors;
std::vector<Tensor<CDataType>> c_host_tensors;
std::vector<Tensor<CDataType>> c_device_tensors;
using DeviceMemPtr = std::unique_ptr<DeviceMem>;
std::vector<DeviceMemPtr> a_tensors_device, b_tensors_device, c_tensors_device;
std::size_t flop = 0, num_btype = 0;
for(int i = 0; i < gemm_shapes.size(); i++)
{
a_tensors.push_back(Tensor<ADataType>(f_host_tensor_descriptor(
gemm_shapes[i].M, gemm_shapes[i].K, gemm_shapes[i].StrideA, ALayout{})));
b_tensors.push_back(Tensor<BDataType>(f_host_tensor_descriptor(
gemm_shapes[i].K, gemm_shapes[i].N, gemm_shapes[i].StrideB, BLayout{})));
c_host_tensors.push_back(Tensor<CDataType>(f_host_tensor_descriptor(
gemm_shapes[i].M, gemm_shapes[i].N, gemm_shapes[i].StrideC, CLayout{})));
c_device_tensors.push_back(Tensor<CDataType>(f_host_tensor_descriptor(
gemm_shapes[i].M, gemm_shapes[i].N, gemm_shapes[i].StrideC, CLayout{})));
std::cout << "gemm[" << i << "] a_m_k: " << a_tensors[i].mDesc
<< " b_k_n: " << b_tensors[i].mDesc << " c_m_n: " << c_device_tensors[i].mDesc
<< std::endl;
flop += std::size_t(2) * gemm_shapes[i].M * gemm_shapes[i].K * gemm_shapes[i].N;
num_btype += sizeof(ADataType) * a_tensors[i].mDesc.GetElementSize() +
sizeof(BDataType) * b_tensors[i].mDesc.GetElementSize() +
sizeof(CDataType) * c_device_tensors[i].mDesc.GetElementSize();
switch(init_method)
{
case 0: break;
case 1:
a_tensors[i].GenerateTensorValue(GeneratorTensor_2<ADataType>{-5, 5});
b_tensors[i].GenerateTensorValue(GeneratorTensor_2<BDataType>{-5, 5});
break;
case 2:
a_tensors[i].GenerateTensorValue(GeneratorTensor_3<ADataType>{0.0, 1.0});
b_tensors[i].GenerateTensorValue(GeneratorTensor_3<BDataType>{-0.5, 0.5});
break;
default:
a_tensors[i].GenerateTensorValue(GeneratorTensor_Sequential<0>{});
b_tensors[i].GenerateTensorValue(GeneratorTensor_Sequential<1>{});
}
}
for(int i = 0; i < gemm_shapes.size(); i++)
{
a_tensors_device.push_back(
std::make_unique<DeviceMem>(sizeof(ADataType) * a_tensors[i].mDesc.GetElementSize()));
b_tensors_device.push_back(
std::make_unique<DeviceMem>(sizeof(BDataType) * b_tensors[i].mDesc.GetElementSize()));
c_tensors_device.push_back(std::make_unique<DeviceMem>(
sizeof(CDataType) * c_device_tensors[i].mDesc.GetElementSize()));
a_tensors_device[i]->ToDevice(a_tensors[i].mData.data());
b_tensors_device[i]->ToDevice(b_tensors[i].mData.data());
gemm_shapes[i].p_a = a_tensors_device[i]->GetDeviceBuffer();
gemm_shapes[i].p_b = b_tensors_device[i]->GetDeviceBuffer();
gemm_shapes[i].p_c = c_tensors_device[i]->GetDeviceBuffer();
}
auto a_element_op = AElementOp{};
auto b_element_op = BElementOp{};
auto c_element_op = CElementOp{};
// do GEMM
auto gemm = DeviceGemmInstance{};
auto invoker = gemm.MakeInvoker();
auto argument = gemm.MakeArgument(gemm_shapes, a_element_op, b_element_op, c_element_op);
if(!gemm.IsSupportedArgument(argument))
{
throw std::runtime_error(
"wrong! device_gemm with the specified compilation parameters does "
"not support this GEMM problem");
}
float ave_time = invoker.Run(argument, nrepeat);
float tflops = static_cast<float>(flop) / 1.E9 / ave_time;
float gb_per_sec = num_btype / 1.E6 / ave_time;
std::cout << "Perf: " << ave_time << " ms, " << tflops << " TFlops, " << gb_per_sec << " GB/s, "
<< gemm.GetTypeString() << std::endl;
if(do_verification)
{
for(int i = 0; i < gemm_shapes.size(); i++)
{
c_tensors_device[i]->FromDevice(c_device_tensors[i].mData.data());
auto ref_gemm = ReferenceGemmInstance{};
auto ref_invoker = ref_gemm.MakeInvoker();
auto ref_argument = ref_gemm.MakeArgument(a_tensors[i],
b_tensors[i],
c_host_tensors[i],
a_element_op,
b_element_op,
c_element_op);
ref_invoker.Run(ref_argument);
check_error(c_host_tensors[i], c_device_tensors[i]);
}
}
return 0;
}
...@@ -38,4 +38,5 @@ add_subdirectory(10_conv2d_bwd_data) ...@@ -38,4 +38,5 @@ add_subdirectory(10_conv2d_bwd_data)
add_subdirectory(11_conv2d_bwd_wgt) add_subdirectory(11_conv2d_bwd_wgt)
add_subdirectory(12_reduce) add_subdirectory(12_reduce)
add_subdirectory(13_pool2d_fwd) add_subdirectory(13_pool2d_fwd)
add_subdirectory(14_grouped_gemm) add_subdirectory(14_gemm_xdl_requant_relu_requant)
add_subdirectory(15_grouped_gemm)
...@@ -32,57 +32,53 @@ ...@@ -32,57 +32,53 @@
#include "reduction_operator.hpp" #include "reduction_operator.hpp"
#include "reduction_functions_accumulate.hpp" #include "reduction_functions_accumulate.hpp"
#include "cluster_descriptor.hpp"
namespace ck { namespace ck {
template <typename Buffer1dDescType, template <typename AccDataType,
typename AccDataType,
index_t BlockSize, index_t BlockSize,
index_t MThreadClusterSize, typename ThreadClusterLengths_M_K,
index_t KThreadClusterSize, typename ThreadClusterArrangeOrder,
bool ReorderThreadClusters,
typename OpReduce, typename OpReduce,
bool PropagateNan> bool PropagateNan>
struct PartitionedBlockwiseReductionOn1dBuffer struct PartitionedBlockwiseReduction
{ {
static constexpr auto buffer_1d_desc = Buffer1dDescType{}; static_assert(BlockSize == ThreadClusterLengths_M_K::At(0) * ThreadClusterLengths_M_K::At(1),
static_assert(BlockSize == MThreadClusterSize * KThreadClusterSize,
"The product of cluster lengths should be same as BlockSize!"); "The product of cluster lengths should be same as BlockSize!");
static_assert(KThreadClusterSize > 1, "Parallel reduction need work on at least two elements");
static_assert(buffer_1d_desc.GetElementSize() == BlockSize, static constexpr auto BufferLength_M = ThreadClusterLengths_M_K::At(0);
"The buffer size should be the same as BlockSize!"); static constexpr auto BufferLength_K = ThreadClusterLengths_M_K::At(1);
static_assert(BufferLength_K > 1, "Parallel reduction need work on at least two elements");
static constexpr auto block_buf_desc_m_k = make_naive_tensor_descriptor_packed(
make_tuple(Number<BufferLength_M>{}, Number<BufferLength_K>{}));
static constexpr auto thread_cluster_desc =
make_cluster_descriptor(ThreadClusterLengths_M_K{}, ThreadClusterArrangeOrder{});
using Accumulation = detail::AccumulateWithNanCheck<PropagateNan, OpReduce, AccDataType>; using Accumulation = detail::AccumulateWithNanCheck<PropagateNan, OpReduce, AccDataType>;
template <typename BufferType> template <typename BufferType>
__device__ static void Reduce(BufferType& block_buffer, __device__ static void Reduce(BufferType& block_buffer, AccDataType& accuData)
AccDataType& accuData,
index_t thread_m_cluster_id,
index_t thread_k_cluster_id)
{ {
constexpr auto cluster_len_shift = get_shift<KThreadClusterSize>(); constexpr auto cluster_len_shift = get_shift<BufferLength_K>();
const auto thread_cluster_idx =
thread_cluster_desc.CalculateBottomIndex(make_multi_index(get_thread_local_1d_id()));
const auto thread_m_cluster_id = thread_cluster_idx[Number<0>{}];
const auto thread_k_cluster_id = thread_cluster_idx[Number<1>{}];
static_for<0, cluster_len_shift, 1>{}([&](auto I) { static_for<0, cluster_len_shift, 1>{}([&](auto I) {
constexpr index_t indOffset = 1 << (cluster_len_shift - 1 - I()); constexpr index_t indOffset = 1 << (cluster_len_shift - 1 - I());
if(thread_k_cluster_id < indOffset) if(thread_k_cluster_id < indOffset)
{ {
// consider the thread clusters order, ensure the contiguous locations are accessed index_t offset1 = block_buf_desc_m_k.CalculateOffset(thread_cluster_idx);
// by contiguous Thread-ID index_t offset2 = block_buf_desc_m_k.CalculateOffset(thread_cluster_idx +
index_t offset1 = make_tuple(0, indOffset));
ReorderThreadClusters
? buffer_1d_desc.CalculateOffset(make_tuple(
thread_k_cluster_id * MThreadClusterSize + thread_m_cluster_id))
: buffer_1d_desc.CalculateOffset(make_tuple(
thread_m_cluster_id * KThreadClusterSize + thread_k_cluster_id));
index_t offset2 = ReorderThreadClusters
? buffer_1d_desc.CalculateOffset(make_tuple(
(thread_k_cluster_id + indOffset) * MThreadClusterSize +
thread_m_cluster_id))
: buffer_1d_desc.CalculateOffset(
make_tuple(thread_m_cluster_id * KThreadClusterSize +
(thread_k_cluster_id + indOffset)));
AccDataType opData1 = type_convert<AccDataType>(block_buffer[offset1]); AccDataType opData1 = type_convert<AccDataType>(block_buffer[offset1]);
AccDataType opData2 = type_convert<AccDataType>(block_buffer[offset2]); AccDataType opData2 = type_convert<AccDataType>(block_buffer[offset2]);
...@@ -93,34 +89,34 @@ struct PartitionedBlockwiseReductionOn1dBuffer ...@@ -93,34 +89,34 @@ struct PartitionedBlockwiseReductionOn1dBuffer
__syncthreads(); __syncthreads();
}); });
index_t offset = ReorderThreadClusters index_t offset = block_buf_desc_m_k.CalculateOffset(make_tuple(thread_m_cluster_id, 0));
? buffer_1d_desc.CalculateOffset(make_tuple(thread_m_cluster_id))
: buffer_1d_desc.CalculateOffset(
make_tuple(thread_m_cluster_id * KThreadClusterSize));
accuData = type_convert<AccDataType>(block_buffer[offset]); accuData = type_convert<AccDataType>(block_buffer[offset]);
}; };
}; };
template <typename Buffer1dDescType, template <typename AccDataType,
typename AccDataType,
typename IndexDataType, typename IndexDataType,
index_t BlockSize, index_t BlockSize,
index_t MThreadClusterSize, typename ThreadClusterLengths_M_K,
index_t KThreadClusterSize, typename ThreadClusterArrangeOrder,
bool ReorderThreadClusters,
typename OpReduce, typename OpReduce,
bool PropagateNan> bool PropagateNan>
struct PartitionedBlockwiseReductionWithIndexOn1dBuffer struct PartitionedBlockwiseReductionWithIndex
{ {
static constexpr auto buffer_1d_desc = Buffer1dDescType{}; static_assert(BlockSize == ThreadClusterLengths_M_K::At(0) * ThreadClusterLengths_M_K::At(1),
static_assert(BlockSize == MThreadClusterSize * KThreadClusterSize,
"The product of cluster lengths should be same as BlockSize!"); "The product of cluster lengths should be same as BlockSize!");
static_assert(KThreadClusterSize > 1, "Parallel reduction need work on at least two elements");
static_assert(buffer_1d_desc.GetElementSize() == BlockSize, static constexpr auto BufferLength_M = ThreadClusterLengths_M_K::At(0);
"The buffer size should be the same as BlockSize!"); static constexpr auto BufferLength_K = ThreadClusterLengths_M_K::At(1);
static_assert(BufferLength_K > 1, "Parallel reduction need work on at least two elements");
static constexpr auto block_buf_desc_m_k = make_naive_tensor_descriptor_packed(
make_tuple(Number<BufferLength_M>{}, Number<BufferLength_K>{}));
static constexpr auto thread_cluster_desc =
make_cluster_descriptor(ThreadClusterLengths_M_K{}, ThreadClusterArrangeOrder{});
using Accumulation = using Accumulation =
detail::AccumulateWithIndexAndNanCheck<PropagateNan, OpReduce, AccDataType, IndexDataType>; detail::AccumulateWithIndexAndNanCheck<PropagateNan, OpReduce, AccDataType, IndexDataType>;
...@@ -130,32 +126,24 @@ struct PartitionedBlockwiseReductionWithIndexOn1dBuffer ...@@ -130,32 +126,24 @@ struct PartitionedBlockwiseReductionWithIndexOn1dBuffer
__device__ static void Reduce(BufferType& block_val_buffer, __device__ static void Reduce(BufferType& block_val_buffer,
IdxBufferType& block_idx_buffer, IdxBufferType& block_idx_buffer,
AccDataType& accuData, AccDataType& accuData,
IndexDataType& accuIndex, IndexDataType& accuIndex)
index_t thread_m_cluster_id,
index_t thread_k_cluster_id)
{ {
constexpr auto cluster_len_shift = get_shift<KThreadClusterSize>(); constexpr auto cluster_len_shift = get_shift<BufferLength_K>();
const auto thread_cluster_idx =
thread_cluster_desc.CalculateBottomIndex(make_multi_index(get_thread_local_1d_id()));
const auto thread_m_cluster_id = thread_cluster_idx[Number<0>{}];
const auto thread_k_cluster_id = thread_cluster_idx[Number<1>{}];
static_for<0, cluster_len_shift, 1>{}([&](auto I) { static_for<0, cluster_len_shift, 1>{}([&](auto I) {
constexpr index_t indOffset = 1 << I(); constexpr index_t indOffset = 1 << I();
if(thread_k_cluster_id % (indOffset * 2) == 0) if(thread_k_cluster_id % (indOffset * 2) == 0)
{ {
// consider the thread clusters order, ensure the contiguous locations are accessed index_t offset1 = block_buf_desc_m_k.CalculateOffset(thread_cluster_idx);
// by contiguous Thread-ID index_t offset2 = block_buf_desc_m_k.CalculateOffset(thread_cluster_idx +
index_t offset1 = make_tuple(0, indOffset));
ReorderThreadClusters
? buffer_1d_desc.CalculateOffset(make_tuple(
thread_k_cluster_id * MThreadClusterSize + thread_m_cluster_id))
: buffer_1d_desc.CalculateOffset(make_tuple(
thread_m_cluster_id * KThreadClusterSize + thread_k_cluster_id));
index_t offset2 = ReorderThreadClusters
? buffer_1d_desc.CalculateOffset(make_tuple(
(thread_k_cluster_id + indOffset) * MThreadClusterSize +
thread_m_cluster_id))
: buffer_1d_desc.CalculateOffset(
make_tuple(thread_m_cluster_id * KThreadClusterSize +
(thread_k_cluster_id + indOffset)));
AccDataType opData1 = type_convert<AccDataType>(block_val_buffer[offset1]); AccDataType opData1 = type_convert<AccDataType>(block_val_buffer[offset1]);
AccDataType opData2 = type_convert<AccDataType>(block_val_buffer[offset2]); AccDataType opData2 = type_convert<AccDataType>(block_val_buffer[offset2]);
...@@ -170,10 +158,7 @@ struct PartitionedBlockwiseReductionWithIndexOn1dBuffer ...@@ -170,10 +158,7 @@ struct PartitionedBlockwiseReductionWithIndexOn1dBuffer
__syncthreads(); __syncthreads();
}); });
index_t offset = ReorderThreadClusters index_t offset = block_buf_desc_m_k.CalculateOffset(make_tuple(thread_m_cluster_id, 0));
? buffer_1d_desc.CalculateOffset(make_tuple(thread_m_cluster_id))
: buffer_1d_desc.CalculateOffset(
make_tuple(thread_m_cluster_id * KThreadClusterSize));
accuData = type_convert<AccDataType>(block_val_buffer[offset]); accuData = type_convert<AccDataType>(block_val_buffer[offset]);
accuIndex = block_idx_buffer[offset]; accuIndex = block_idx_buffer[offset];
......
...@@ -36,14 +36,15 @@ struct DeviceReduce : public BaseOperator ...@@ -36,14 +36,15 @@ struct DeviceReduce : public BaseOperator
const std::vector<int>& inStrides, const std::vector<int>& inStrides,
const std::vector<int>& outLengths, const std::vector<int>& outLengths,
const std::vector<int>& outStrides, const std::vector<int>& outStrides,
const std::vector<int>& reduceDims,
float alpha, float alpha,
float beta, float beta,
const void* in_dev, const void* in_dev,
void* out_dev, void* out_dev,
void* out_indices_dev, void* out_indices_dev,
void* workspace_dev, void* workspace_dev,
const InElementwiseOperation& inElementwiseOp, const InElementwiseOperation& in_elementwise_op,
const AccElementwiseOperation& accElementwiseOp) = 0; const AccElementwiseOperation& acc_elementwise_op) = 0;
virtual std::unique_ptr<BaseInvoker> MakeInvokerPointer() = 0; virtual std::unique_ptr<BaseInvoker> MakeInvokerPointer() = 0;
}; };
......
...@@ -15,8 +15,8 @@ namespace device { ...@@ -15,8 +15,8 @@ namespace device {
template <typename InDataType, template <typename InDataType,
typename AccDataType, typename AccDataType,
typename OutDataType, typename OutDataType,
int Rank, index_t Rank,
typename ReduceDims, index_t NumReduceDim,
typename ReduceOperation, typename ReduceOperation,
typename InElementwiseOperation, typename InElementwiseOperation,
typename AccElementwiseOperation, typename AccElementwiseOperation,
...@@ -40,7 +40,12 @@ struct DeviceReduceBlockWise : public DeviceReduce<InElementwiseOperation, AccEl ...@@ -40,7 +40,12 @@ struct DeviceReduceBlockWise : public DeviceReduce<InElementwiseOperation, AccEl
static constexpr bool BetaIsZero = NeedIndices; static constexpr bool BetaIsZero = NeedIndices;
using InvariantDims = decltype(get_invariant_dims<Rank, ReduceDims>()); static constexpr index_t NumInvariantDim = Rank - NumReduceDim;
using InvariantDims =
typename conditional<NumInvariantDim == 0,
Sequence<>,
typename arithmetic_sequence_gen<0, NumInvariantDim, 1>::type>::type;
using ReduceDims = typename arithmetic_sequence_gen<NumInvariantDim, Rank, 1>::type;
static constexpr index_t srcDims = Rank; static constexpr index_t srcDims = Rank;
static constexpr index_t dstDims = (InvariantDims::Size() == 0) ? 1 : InvariantDims::Size(); static constexpr index_t dstDims = (InvariantDims::Size() == 0) ? 1 : InvariantDims::Size();
...@@ -74,7 +79,7 @@ struct DeviceReduceBlockWise : public DeviceReduce<InElementwiseOperation, AccEl ...@@ -74,7 +79,7 @@ struct DeviceReduceBlockWise : public DeviceReduce<InElementwiseOperation, AccEl
} }
else else
{ {
const auto toReduceDimLengths = const auto reduceDimLengths =
make_tuple_from_array_and_index_seq(inLengths, ReduceDims{}); make_tuple_from_array_and_index_seq(inLengths, ReduceDims{});
const auto invariantDimLengths = const auto invariantDimLengths =
make_tuple_from_array_and_index_seq(inLengths, InvariantDims{}); make_tuple_from_array_and_index_seq(inLengths, InvariantDims{});
...@@ -82,7 +87,7 @@ struct DeviceReduceBlockWise : public DeviceReduce<InElementwiseOperation, AccEl ...@@ -82,7 +87,7 @@ struct DeviceReduceBlockWise : public DeviceReduce<InElementwiseOperation, AccEl
return transform_tensor_descriptor( return transform_tensor_descriptor(
inDesc, inDesc,
make_tuple(make_merge_transform(invariantDimLengths), make_tuple(make_merge_transform(invariantDimLengths),
make_merge_transform(toReduceDimLengths)), make_merge_transform(reduceDimLengths)),
make_tuple(InvariantDims{}, ReduceDims{}), make_tuple(InvariantDims{}, ReduceDims{}),
make_tuple(Sequence<0>{}, Sequence<1>{})); make_tuple(Sequence<0>{}, Sequence<1>{}));
} }
...@@ -136,6 +141,7 @@ struct DeviceReduceBlockWise : public DeviceReduce<InElementwiseOperation, AccEl ...@@ -136,6 +141,7 @@ struct DeviceReduceBlockWise : public DeviceReduce<InElementwiseOperation, AccEl
const std::vector<int>& inStrides, const std::vector<int>& inStrides,
const std::vector<int>& outLengths, const std::vector<int>& outLengths,
const std::vector<int>& outStrides, const std::vector<int>& outStrides,
const std::vector<int>& reduceDims,
float alpha, float alpha,
float beta, float beta,
const InDataType* in_dev, const InDataType* in_dev,
...@@ -144,30 +150,31 @@ struct DeviceReduceBlockWise : public DeviceReduce<InElementwiseOperation, AccEl ...@@ -144,30 +150,31 @@ struct DeviceReduceBlockWise : public DeviceReduce<InElementwiseOperation, AccEl
AccDataType* workspace_dev, AccDataType* workspace_dev,
const InElementwiseOperation& in_elementwise_op, const InElementwiseOperation& in_elementwise_op,
const AccElementwiseOperation& acc_elementwise_op) const AccElementwiseOperation& acc_elementwise_op)
: in_dev_{in_dev}, out_dev_{out_dev}, out_indices_dev_{out_indices_dev} : outLengths_{outLengths},
outStrides_{outStrides},
in_dev_{in_dev},
out_dev_{out_dev},
out_indices_dev_{out_indices_dev},
in_elementwise_op_{in_elementwise_op},
acc_elementwise_op_{acc_elementwise_op}
{ {
(void)workspace_dev; (void)workspace_dev;
inLengths_ = inLengths; std::tie(inLengths_, inStrides_) =
inStrides_ = inStrides; shuffle_tensor_dimensions<Rank, NumReduceDim>(inLengths, inStrides, reduceDims);
outLengths_ = outLengths;
outStrides_ = outStrides;
in_elementwise_op_ = in_elementwise_op;
acc_elementwise_op_ = acc_elementwise_op;
alpha_ = static_cast<AccDataType>(alpha); alpha_ = static_cast<AccDataType>(alpha);
beta_ = static_cast<OutDataType>(beta); beta_ = static_cast<OutDataType>(beta);
std::tie(invariant_total_length, reduce_total_length) = std::tie(invariant_total_length, reduce_total_length) =
get_2d_lengths<Rank, ReduceDims>(inLengths); get_2d_lengths<Rank, ReduceDims>(inLengths_);
if constexpr(InvariantDims::Size() == 0) if constexpr(InvariantDims::Size() == 0)
invariant_lowest_length = 1; invariant_lowest_length = 1;
else else
invariant_lowest_length = inLengths[InvariantDims::At(InvariantDims::Size() - 1)]; invariant_lowest_length = inLengths_[InvariantDims::At(InvariantDims::Size() - 1)];
reduce_lowest_length = inLengths[ReduceDims::At(ReduceDims::Size() - 1)]; reduce_lowest_length = inLengths_[ReduceDims::At(ReduceDims::Size() - 1)];
gridSize = math::integer_least_multiple(invariant_total_length, M_BlockTileSize) / gridSize = math::integer_least_multiple(invariant_total_length, M_BlockTileSize) /
M_BlockTileSize; M_BlockTileSize;
...@@ -305,6 +312,7 @@ struct DeviceReduceBlockWise : public DeviceReduce<InElementwiseOperation, AccEl ...@@ -305,6 +312,7 @@ struct DeviceReduceBlockWise : public DeviceReduce<InElementwiseOperation, AccEl
const std::vector<int>& inStrides, const std::vector<int>& inStrides,
const std::vector<int>& outLengths, const std::vector<int>& outLengths,
const std::vector<int>& outStrides, const std::vector<int>& outStrides,
const std::vector<int>& reduceDims,
float alpha, float alpha,
float beta, float beta,
const void* in_dev, const void* in_dev,
...@@ -318,6 +326,7 @@ struct DeviceReduceBlockWise : public DeviceReduce<InElementwiseOperation, AccEl ...@@ -318,6 +326,7 @@ struct DeviceReduceBlockWise : public DeviceReduce<InElementwiseOperation, AccEl
inStrides, inStrides,
outLengths, outLengths,
outStrides, outStrides,
reduceDims,
alpha, alpha,
beta, beta,
static_cast<const InDataType*>(in_dev), static_cast<const InDataType*>(in_dev),
......
...@@ -15,8 +15,8 @@ namespace device { ...@@ -15,8 +15,8 @@ namespace device {
template <typename InDataType, template <typename InDataType,
typename AccDataType, typename AccDataType,
typename OutDataType, typename OutDataType,
int Rank, index_t Rank,
typename ReduceDims, index_t NumReduceDim,
typename ReduceOperation, typename ReduceOperation,
typename InElementwiseOperation, typename InElementwiseOperation,
typename AccElementwiseOperation, typename AccElementwiseOperation,
...@@ -45,7 +45,11 @@ struct DeviceReduceBlockWiseSecondCall ...@@ -45,7 +45,11 @@ struct DeviceReduceBlockWiseSecondCall
std::is_same<InDataType, AccDataType>::value, std::is_same<InDataType, AccDataType>::value,
"InDataType and AccDataType should be the same to use DEviceReduceBlockWiseSecondCall!"); "InDataType and AccDataType should be the same to use DEviceReduceBlockWiseSecondCall!");
using InvariantDims = decltype(get_invariant_dims<Rank, ReduceDims>()); static constexpr index_t NumInvariantDim = Rank - NumReduceDim;
using InvariantDims =
typename conditional<NumInvariantDim == 0,
Sequence<>,
typename arithmetic_sequence_gen<0, NumInvariantDim, 1>::type>::type;
static constexpr index_t dstDims = (InvariantDims::Size() == 0) ? 1 : InvariantDims::Size(); static constexpr index_t dstDims = (InvariantDims::Size() == 0) ? 1 : InvariantDims::Size();
...@@ -117,16 +121,16 @@ struct DeviceReduceBlockWiseSecondCall ...@@ -117,16 +121,16 @@ struct DeviceReduceBlockWiseSecondCall
AccDataType* workspace_dev, AccDataType* workspace_dev,
const InElementwiseOperation& in_elementwise_op, const InElementwiseOperation& in_elementwise_op,
const AccElementwiseOperation& acc_elementwise_op) const AccElementwiseOperation& acc_elementwise_op)
: in_dev_{in_dev}, out_dev_{out_dev}, out_indices_dev_{out_indices_dev} : inLengths_(inLengths),
inStrides_(inStrides),
outLengths_(outLengths),
outStrides_(outStrides),
in_dev_{in_dev},
out_dev_{out_dev},
out_indices_dev_{out_indices_dev},
in_elementwise_op_(in_elementwise_op),
acc_elementwise_op_(acc_elementwise_op)
{ {
inLengths_ = inLengths;
inStrides_ = inStrides;
outLengths_ = outLengths;
outStrides_ = outStrides;
in_elementwise_op_ = in_elementwise_op;
acc_elementwise_op_ = acc_elementwise_op;
alpha_ = static_cast<AccDataType>(alpha); alpha_ = static_cast<AccDataType>(alpha);
beta_ = static_cast<OutDataType>(beta); beta_ = static_cast<OutDataType>(beta);
...@@ -268,6 +272,7 @@ struct DeviceReduceBlockWiseSecondCall ...@@ -268,6 +272,7 @@ struct DeviceReduceBlockWiseSecondCall
const std::vector<int>& inStrides, const std::vector<int>& inStrides,
const std::vector<int>& outLengths, const std::vector<int>& outLengths,
const std::vector<int>& outStrides, const std::vector<int>& outStrides,
const std::vector<int>& reduceDims,
float alpha, float alpha,
float beta, float beta,
const void* in_dev, const void* in_dev,
...@@ -277,6 +282,8 @@ struct DeviceReduceBlockWiseSecondCall ...@@ -277,6 +282,8 @@ struct DeviceReduceBlockWiseSecondCall
const InElementwiseOperation& in_elementwise_op, const InElementwiseOperation& in_elementwise_op,
const AccElementwiseOperation& acc_elementwise_op) override const AccElementwiseOperation& acc_elementwise_op) override
{ {
(void)reduceDims;
return std::make_unique<Argument>(inLengths, return std::make_unique<Argument>(inLengths,
inStrides, inStrides,
outLengths, outLengths,
......
...@@ -2,6 +2,7 @@ ...@@ -2,6 +2,7 @@
#define DEVICE_REDUCE_COMMON_HPP #define DEVICE_REDUCE_COMMON_HPP
#include <vector> #include <vector>
#include <cassert>
#include "common_header.hpp" #include "common_header.hpp"
#include "reduction_enums.hpp" #include "reduction_enums.hpp"
...@@ -40,23 +41,6 @@ constexpr bool belong() ...@@ -40,23 +41,6 @@ constexpr bool belong()
return (inside); return (inside);
}; };
template <int Rank, typename ReduceDims, int start = 0>
constexpr auto get_invariant_dims()
{
static_assert(Rank <= 6, "bigger Rank size not supported!");
if constexpr(start >= Rank)
return Sequence<>{};
else
{
if constexpr(!belong<start, ReduceDims>())
return merge_sequences(Sequence<start>{},
get_invariant_dims<Rank, ReduceDims, start + 1>());
else
return get_invariant_dims<Rank, ReduceDims, start + 1>();
};
};
// helper functions using variadic template arguments // helper functions using variadic template arguments
template <index_t... Ns> template <index_t... Ns>
static auto make_tuple_from_array_and_index_seq(const std::vector<int>& lengths, Sequence<Ns...>) static auto make_tuple_from_array_and_index_seq(const std::vector<int>& lengths, Sequence<Ns...>)
...@@ -74,6 +58,45 @@ static auto make_tuple_from_array(const std::vector<int>& lengths, Number<arrayS ...@@ -74,6 +58,45 @@ static auto make_tuple_from_array(const std::vector<int>& lengths, Number<arrayS
return make_tuple_from_array_and_index_seq(lengths, index_seq); return make_tuple_from_array_and_index_seq(lengths, index_seq);
}; };
template <index_t Rank, index_t NumReduceDim>
static inline std::pair<std::vector<int>, std::vector<int>>
shuffle_tensor_dimensions(const std::vector<int>& dimLengths,
const std::vector<int>& dimStrides,
const std::vector<int>& reduceDims)
{
std::vector<int> newDimLengths;
std::vector<int> newDimStrides;
assert(Rank == dimLengths.size() && Rank == dimStrides.size() &&
NumReduceDim == reduceDims.size());
int reduceFlag = 0;
// flag the bits for the reduceDims
for(int i = 0; i < NumReduceDim; i++)
{
reduceFlag |= 1 << reduceDims[i];
};
// collect invariant dimensions
for(int i = 0; i < Rank; i++)
if((reduceFlag & (1 << i)) == 0)
{
newDimLengths.push_back(dimLengths[i]);
newDimStrides.push_back(dimStrides[i]);
};
// collect reduce dimensions
for(int i = 0; i < Rank; i++)
if((reduceFlag & (1 << i)) > 0)
{
newDimLengths.push_back(dimLengths[i]);
newDimStrides.push_back(dimStrides[i]);
};
return std::make_pair(newDimLengths, newDimStrides);
};
} // namespace device } // namespace device
} // namespace tensor_operation } // namespace tensor_operation
......
...@@ -17,8 +17,8 @@ namespace device { ...@@ -17,8 +17,8 @@ namespace device {
template <typename InDataType, template <typename InDataType,
typename AccDataType, typename AccDataType,
typename OutDataType, typename OutDataType,
int Rank, index_t Rank,
typename ReduceDims, index_t NumReduceDim,
typename ReduceOperation, typename ReduceOperation,
typename InElementwiseOperation, typename InElementwiseOperation,
typename AccElementwiseOperation, typename AccElementwiseOperation,
...@@ -41,7 +41,12 @@ struct DeviceReduceMultiBlockAtomicAdd ...@@ -41,7 +41,12 @@ struct DeviceReduceMultiBlockAtomicAdd
using IndexDataType = int32_t; using IndexDataType = int32_t;
using InvariantDims = decltype(get_invariant_dims<Rank, ReduceDims>()); static constexpr index_t NumInvariantDim = Rank - NumReduceDim;
using InvariantDims =
typename conditional<NumInvariantDim == 0,
Sequence<>,
typename arithmetic_sequence_gen<0, NumInvariantDim, 1>::type>::type;
using ReduceDims = typename arithmetic_sequence_gen<NumInvariantDim, Rank, 1>::type;
static constexpr index_t srcDims = Rank; static constexpr index_t srcDims = Rank;
static constexpr index_t dstDims = (InvariantDims::Size() == 0) ? 1 : InvariantDims::Size(); static constexpr index_t dstDims = (InvariantDims::Size() == 0) ? 1 : InvariantDims::Size();
...@@ -84,7 +89,7 @@ struct DeviceReduceMultiBlockAtomicAdd ...@@ -84,7 +89,7 @@ struct DeviceReduceMultiBlockAtomicAdd
} }
else else
{ {
const auto toReduceDimLengths = const auto reduceDimLengths =
make_tuple_from_array_and_index_seq(inLengths, ReduceDims{}); make_tuple_from_array_and_index_seq(inLengths, ReduceDims{});
const auto invariantDimLengths = const auto invariantDimLengths =
make_tuple_from_array_and_index_seq(inLengths, InvariantDims{}); make_tuple_from_array_and_index_seq(inLengths, InvariantDims{});
...@@ -92,7 +97,7 @@ struct DeviceReduceMultiBlockAtomicAdd ...@@ -92,7 +97,7 @@ struct DeviceReduceMultiBlockAtomicAdd
return transform_tensor_descriptor( return transform_tensor_descriptor(
inDesc, inDesc,
make_tuple(make_merge_transform(invariantDimLengths), make_tuple(make_merge_transform(invariantDimLengths),
make_merge_transform(toReduceDimLengths)), make_merge_transform(reduceDimLengths)),
make_tuple(InvariantDims{}, ReduceDims{}), make_tuple(InvariantDims{}, ReduceDims{}),
make_tuple(Sequence<0>{}, Sequence<1>{})); make_tuple(Sequence<0>{}, Sequence<1>{}));
} }
...@@ -147,6 +152,7 @@ struct DeviceReduceMultiBlockAtomicAdd ...@@ -147,6 +152,7 @@ struct DeviceReduceMultiBlockAtomicAdd
const std::vector<int>& inStrides, const std::vector<int>& inStrides,
const std::vector<int>& outLengths, const std::vector<int>& outLengths,
const std::vector<int>& outStrides, const std::vector<int>& outStrides,
const std::vector<int>& reduceDims,
float alpha, float alpha,
float beta, float beta,
const InDataType* in_dev, const InDataType* in_dev,
...@@ -155,31 +161,31 @@ struct DeviceReduceMultiBlockAtomicAdd ...@@ -155,31 +161,31 @@ struct DeviceReduceMultiBlockAtomicAdd
AccDataType* workspace_dev, AccDataType* workspace_dev,
const InElementwiseOperation& in_elementwise_op, const InElementwiseOperation& in_elementwise_op,
const AccElementwiseOperation& acc_elementwise_op) const AccElementwiseOperation& acc_elementwise_op)
: in_dev_{in_dev}, out_dev_{out_dev} : outLengths_{outLengths},
outStrides_{outStrides},
in_dev_{in_dev},
out_dev_{out_dev},
in_elementwise_op_{in_elementwise_op},
acc_elementwise_op_{acc_elementwise_op}
{ {
(void)out_indices_dev; (void)out_indices_dev;
(void)workspace_dev; (void)workspace_dev;
inLengths_ = inLengths; std::tie(inLengths_, inStrides_) =
inStrides_ = inStrides; shuffle_tensor_dimensions<Rank, NumReduceDim>(inLengths, inStrides, reduceDims);
outLengths_ = outLengths;
outStrides_ = outStrides;
in_elementwise_op_ = in_elementwise_op;
acc_elementwise_op_ = acc_elementwise_op;
alpha_ = static_cast<AccDataType>(alpha); alpha_ = static_cast<AccDataType>(alpha);
beta_ = static_cast<OutDataType>(beta); beta_ = static_cast<OutDataType>(beta);
std::tie(invariant_total_length, reduce_total_length) = std::tie(invariant_total_length, reduce_total_length) =
get_2d_lengths<Rank, ReduceDims>(inLengths); get_2d_lengths<Rank, ReduceDims>(inLengths_);
if constexpr(InvariantDims::Size() == 0) if constexpr(InvariantDims::Size() == 0)
invariant_lowest_length = 1; invariant_lowest_length = 1;
else else
invariant_lowest_length = inLengths[InvariantDims::At(InvariantDims::Size() - 1)]; invariant_lowest_length = inLengths_[InvariantDims::At(InvariantDims::Size() - 1)];
reduce_lowest_length = inLengths[ReduceDims::At(ReduceDims::Size() - 1)]; reduce_lowest_length = inLengths_[ReduceDims::At(ReduceDims::Size() - 1)];
int iterations = 1; int iterations = 1;
while(true) while(true)
...@@ -369,6 +375,7 @@ struct DeviceReduceMultiBlockAtomicAdd ...@@ -369,6 +375,7 @@ struct DeviceReduceMultiBlockAtomicAdd
const std::vector<int>& inStrides, const std::vector<int>& inStrides,
const std::vector<int>& outLengths, const std::vector<int>& outLengths,
const std::vector<int>& outStrides, const std::vector<int>& outStrides,
const std::vector<int>& reduceDims,
float alpha, float alpha,
float beta, float beta,
const void* in_dev, const void* in_dev,
...@@ -382,6 +389,7 @@ struct DeviceReduceMultiBlockAtomicAdd ...@@ -382,6 +389,7 @@ struct DeviceReduceMultiBlockAtomicAdd
inStrides, inStrides,
outLengths, outLengths,
outStrides, outStrides,
reduceDims,
alpha, alpha,
beta, beta,
static_cast<const InDataType*>(in_dev), static_cast<const InDataType*>(in_dev),
......
...@@ -15,8 +15,8 @@ namespace device { ...@@ -15,8 +15,8 @@ namespace device {
template <typename InDataType, template <typename InDataType,
typename AccDataType, typename AccDataType,
typename OutDataType, typename OutDataType,
int Rank, index_t Rank,
typename ReduceDims, index_t NumReduceDim,
typename ReduceOperation, typename ReduceOperation,
typename InElementwiseOperation, typename InElementwiseOperation,
typename AccElementwiseOperation, typename AccElementwiseOperation,
...@@ -41,7 +41,12 @@ struct DeviceReduceMultiBlockPartialReduce ...@@ -41,7 +41,12 @@ struct DeviceReduceMultiBlockPartialReduce
using IndexDataType = int32_t; using IndexDataType = int32_t;
using InvariantDims = decltype(get_invariant_dims<Rank, ReduceDims>()); static constexpr index_t NumInvariantDim = Rank - NumReduceDim;
using InvariantDims =
typename conditional<NumInvariantDim == 0,
Sequence<>,
typename arithmetic_sequence_gen<0, NumInvariantDim, 1>::type>::type;
using ReduceDims = typename arithmetic_sequence_gen<NumInvariantDim, Rank, 1>::type;
static constexpr index_t srcDims = Rank; static constexpr index_t srcDims = Rank;
static constexpr index_t dstDims = (InvariantDims::Size() == 0) ? 1 : InvariantDims::Size(); static constexpr index_t dstDims = (InvariantDims::Size() == 0) ? 1 : InvariantDims::Size();
...@@ -112,7 +117,7 @@ struct DeviceReduceMultiBlockPartialReduce ...@@ -112,7 +117,7 @@ struct DeviceReduceMultiBlockPartialReduce
} }
else else
{ {
const auto toReduceDimLengths = const auto reduceDimLengths =
make_tuple_from_array_and_index_seq(inLengths, ReduceDims{}); make_tuple_from_array_and_index_seq(inLengths, ReduceDims{});
const auto invariantDimLengths = const auto invariantDimLengths =
make_tuple_from_array_and_index_seq(inLengths, InvariantDims{}); make_tuple_from_array_and_index_seq(inLengths, InvariantDims{});
...@@ -120,7 +125,7 @@ struct DeviceReduceMultiBlockPartialReduce ...@@ -120,7 +125,7 @@ struct DeviceReduceMultiBlockPartialReduce
return transform_tensor_descriptor( return transform_tensor_descriptor(
inDesc, inDesc,
make_tuple(make_merge_transform(invariantDimLengths), make_tuple(make_merge_transform(invariantDimLengths),
make_merge_transform(toReduceDimLengths)), make_merge_transform(reduceDimLengths)),
make_tuple(InvariantDims{}, ReduceDims{}), make_tuple(InvariantDims{}, ReduceDims{}),
make_tuple(Sequence<0>{}, Sequence<1>{})); make_tuple(Sequence<0>{}, Sequence<1>{}));
} }
...@@ -161,10 +166,11 @@ struct DeviceReduceMultiBlockPartialReduce ...@@ -161,10 +166,11 @@ struct DeviceReduceMultiBlockPartialReduce
struct Argument : public BaseArgument struct Argument : public BaseArgument
{ {
Argument(const std::vector<index_t>& inLengths, Argument(const std::vector<int>& inLengths,
const std::vector<index_t>& inStrides, const std::vector<int>& inStrides,
const std::vector<index_t>& outLengths, const std::vector<int>& outLengths,
const std::vector<index_t>& outStrides, const std::vector<int>& outStrides,
const std::vector<int>& reduceDims,
float alpha, float alpha,
float beta, float beta,
const InDataType* in_dev, const InDataType* in_dev,
...@@ -173,31 +179,30 @@ struct DeviceReduceMultiBlockPartialReduce ...@@ -173,31 +179,30 @@ struct DeviceReduceMultiBlockPartialReduce
AccDataType* workspace_dev, AccDataType* workspace_dev,
const InElementwiseOperation& in_elementwise_op, const InElementwiseOperation& in_elementwise_op,
const AccElementwiseOperation& acc_elementwise_op) const AccElementwiseOperation& acc_elementwise_op)
: in_dev_{in_dev}, : outLengths_{outLengths},
outStrides_{outStrides},
in_dev_{in_dev},
out_dev_{out_dev}, out_dev_{out_dev},
out_indices_dev_{out_indices_dev}, out_indices_dev_{out_indices_dev},
workspace_dev_{workspace_dev} workspace_dev_{workspace_dev},
in_elementwise_op_{in_elementwise_op},
acc_elementwise_op_{acc_elementwise_op}
{ {
inLengths_ = inLengths; std::tie(inLengths_, inStrides_) =
inStrides_ = inStrides; shuffle_tensor_dimensions<Rank, NumReduceDim>(inLengths, inStrides, reduceDims);
outLengths_ = outLengths;
outStrides_ = outStrides;
in_elementwise_op_ = in_elementwise_op;
acc_elementwise_op_ = acc_elementwise_op;
alpha_ = static_cast<AccDataType>(alpha); alpha_ = static_cast<AccDataType>(alpha);
beta_ = static_cast<OutDataType>(beta); beta_ = static_cast<OutDataType>(beta);
std::tie(invariant_total_length, reduce_total_length) = std::tie(invariant_total_length, reduce_total_length) =
get_2d_lengths<Rank, ReduceDims>(inLengths); get_2d_lengths<Rank, ReduceDims>(inLengths_);
if constexpr(InvariantDims::Size() == 0) if constexpr(InvariantDims::Size() == 0)
invariant_lowest_length = 1; invariant_lowest_length = 1;
else else
invariant_lowest_length = inLengths[InvariantDims::At(InvariantDims::Size() - 1)]; invariant_lowest_length = inLengths_[InvariantDims::At(InvariantDims::Size() - 1)];
reduce_lowest_length = inLengths[ReduceDims::At(ReduceDims::Size() - 1)]; reduce_lowest_length = inLengths_[ReduceDims::At(ReduceDims::Size() - 1)];
int iterations = 1; int iterations = 1;
while(true) while(true)
...@@ -370,6 +375,7 @@ struct DeviceReduceMultiBlockPartialReduce ...@@ -370,6 +375,7 @@ struct DeviceReduceMultiBlockPartialReduce
const std::vector<int>& inStrides, const std::vector<int>& inStrides,
const std::vector<int>& outLengths, const std::vector<int>& outLengths,
const std::vector<int>& outStrides, const std::vector<int>& outStrides,
const std::vector<int>& reduceDims,
float alpha, float alpha,
float beta, float beta,
const void* in_dev, const void* in_dev,
...@@ -383,6 +389,7 @@ struct DeviceReduceMultiBlockPartialReduce ...@@ -383,6 +389,7 @@ struct DeviceReduceMultiBlockPartialReduce
inStrides, inStrides,
outLengths, outLengths,
outStrides, outStrides,
reduceDims,
alpha, alpha,
beta, beta,
static_cast<const InDataType*>(in_dev), static_cast<const InDataType*>(in_dev),
......
...@@ -16,7 +16,7 @@ template <typename InDataType, ...@@ -16,7 +16,7 @@ template <typename InDataType,
typename AccDataType, typename AccDataType,
typename OutDataType, typename OutDataType,
index_t Rank, index_t Rank,
typename ReduceDims, index_t NumReduceDim,
typename ReduceOperation, typename ReduceOperation,
typename InElementwiseOperation, typename InElementwiseOperation,
typename OutElementwiseOperation, typename OutElementwiseOperation,
...@@ -40,7 +40,12 @@ struct DeviceReduceThreadWise : public DeviceReduce<InElementwiseOperation, OutE ...@@ -40,7 +40,12 @@ struct DeviceReduceThreadWise : public DeviceReduce<InElementwiseOperation, OutE
static constexpr bool BetaIsZero = NeedIndices; static constexpr bool BetaIsZero = NeedIndices;
using InvariantDims = decltype(get_invariant_dims<Rank, ReduceDims>()); static constexpr index_t NumInvariantDim = Rank - NumReduceDim;
using InvariantDims =
typename conditional<NumInvariantDim == 0,
Sequence<>,
typename arithmetic_sequence_gen<0, NumInvariantDim, 1>::type>::type;
using ReduceDims = typename arithmetic_sequence_gen<NumInvariantDim, Rank, 1>::type;
static constexpr index_t srcDims = Rank; static constexpr index_t srcDims = Rank;
static constexpr index_t dstDims = (InvariantDims::Size() == 0) ? 1 : InvariantDims::Size(); static constexpr index_t dstDims = (InvariantDims::Size() == 0) ? 1 : InvariantDims::Size();
...@@ -74,7 +79,7 @@ struct DeviceReduceThreadWise : public DeviceReduce<InElementwiseOperation, OutE ...@@ -74,7 +79,7 @@ struct DeviceReduceThreadWise : public DeviceReduce<InElementwiseOperation, OutE
} }
else else
{ {
const auto toReduceDimLengths = const auto reduceDimLengths =
make_tuple_from_array_and_index_seq(inLengths, ReduceDims{}); make_tuple_from_array_and_index_seq(inLengths, ReduceDims{});
const auto invariantDimLengths = const auto invariantDimLengths =
make_tuple_from_array_and_index_seq(inLengths, InvariantDims{}); make_tuple_from_array_and_index_seq(inLengths, InvariantDims{});
...@@ -82,7 +87,7 @@ struct DeviceReduceThreadWise : public DeviceReduce<InElementwiseOperation, OutE ...@@ -82,7 +87,7 @@ struct DeviceReduceThreadWise : public DeviceReduce<InElementwiseOperation, OutE
return transform_tensor_descriptor( return transform_tensor_descriptor(
inDesc, inDesc,
make_tuple(make_merge_transform(invariantDimLengths), make_tuple(make_merge_transform(invariantDimLengths),
make_merge_transform(toReduceDimLengths)), make_merge_transform(reduceDimLengths)),
make_tuple(InvariantDims{}, ReduceDims{}), make_tuple(InvariantDims{}, ReduceDims{}),
make_tuple(Sequence<0>{}, Sequence<1>{})); make_tuple(Sequence<0>{}, Sequence<1>{}));
} }
...@@ -136,6 +141,7 @@ struct DeviceReduceThreadWise : public DeviceReduce<InElementwiseOperation, OutE ...@@ -136,6 +141,7 @@ struct DeviceReduceThreadWise : public DeviceReduce<InElementwiseOperation, OutE
const std::vector<int>& inStrides, const std::vector<int>& inStrides,
const std::vector<int>& outLengths, const std::vector<int>& outLengths,
const std::vector<int>& outStrides, const std::vector<int>& outStrides,
const std::vector<int>& reduceDims,
float alpha, float alpha,
float beta, float beta,
const InDataType* in_dev, const InDataType* in_dev,
...@@ -144,30 +150,32 @@ struct DeviceReduceThreadWise : public DeviceReduce<InElementwiseOperation, OutE ...@@ -144,30 +150,32 @@ struct DeviceReduceThreadWise : public DeviceReduce<InElementwiseOperation, OutE
AccDataType* workspace_dev, AccDataType* workspace_dev,
const InElementwiseOperation& in_elementwise_op, const InElementwiseOperation& in_elementwise_op,
const OutElementwiseOperation& acc_elementwise_op) const OutElementwiseOperation& acc_elementwise_op)
: in_dev_{in_dev}, out_dev_{out_dev}, out_indices_dev_{out_indices_dev} : outLengths_{outLengths},
outStrides_{outStrides},
in_dev_{in_dev},
out_dev_{out_dev},
out_indices_dev_{out_indices_dev},
in_elementwise_op_{in_elementwise_op},
acc_elementwise_op_{acc_elementwise_op}
{ {
(void)workspace_dev; (void)workspace_dev;
inLengths_ = inLengths; std::tie(inLengths_, inStrides_) =
inStrides_ = inStrides; shuffle_tensor_dimensions<Rank, NumReduceDim>(inLengths, inStrides, reduceDims);
outLengths_ = outLengths;
outStrides_ = outStrides;
in_elementwise_op_ = in_elementwise_op;
acc_elementwise_op_ = acc_elementwise_op;
alpha_ = static_cast<AccDataType>(alpha); alpha_ = static_cast<AccDataType>(alpha);
beta_ = static_cast<OutDataType>(beta); beta_ = static_cast<OutDataType>(beta);
std::tie(invariant_total_length, reduce_total_length) = std::tie(invariant_total_length, reduce_total_length) =
get_2d_lengths<Rank, ReduceDims>(inLengths); get_2d_lengths<Rank, ReduceDims>(inLengths_);
if constexpr(InvariantDims::Size() == 0) if constexpr(InvariantDims::Size() == 0)
invariant_lowest_length = 1; invariant_lowest_length = 1;
else else
invariant_lowest_length = inLengths[InvariantDims::At(InvariantDims::Size() - 1)]; invariant_lowest_length = inLengths_[InvariantDims::At(InvariantDims::Size() - 1)];
reduce_lowest_length = inLengths[ReduceDims::At(ReduceDims::Size() - 1)]; reduce_lowest_length = inLengths_[ReduceDims::At(ReduceDims::Size() - 1)];
gridSize = math::integer_least_multiple(invariant_total_length, M_BlockTileSize) / gridSize = math::integer_least_multiple(invariant_total_length, M_BlockTileSize) /
M_BlockTileSize; M_BlockTileSize;
...@@ -306,6 +314,7 @@ struct DeviceReduceThreadWise : public DeviceReduce<InElementwiseOperation, OutE ...@@ -306,6 +314,7 @@ struct DeviceReduceThreadWise : public DeviceReduce<InElementwiseOperation, OutE
const std::vector<int>& inStrides, const std::vector<int>& inStrides,
const std::vector<int>& outLengths, const std::vector<int>& outLengths,
const std::vector<int>& outStrides, const std::vector<int>& outStrides,
const std::vector<int>& reduceDims,
float alpha, float alpha,
float beta, float beta,
const void* in_dev, const void* in_dev,
...@@ -319,6 +328,7 @@ struct DeviceReduceThreadWise : public DeviceReduce<InElementwiseOperation, OutE ...@@ -319,6 +328,7 @@ struct DeviceReduceThreadWise : public DeviceReduce<InElementwiseOperation, OutE
inStrides, inStrides,
outLengths, outLengths,
outStrides, outStrides,
reduceDims,
alpha, alpha,
beta, beta,
static_cast<const InDataType*>(in_dev), static_cast<const InDataType*>(in_dev),
......
...@@ -31,8 +31,8 @@ ...@@ -31,8 +31,8 @@
#include "reduction_operator.hpp" #include "reduction_operator.hpp"
#include "reduction_functions_accumulate.hpp" #include "reduction_functions_accumulate.hpp"
#include "reduction_functions_blockwise.hpp" #include "reduction_functions_blockwise.hpp"
#include "threadwise_tensor_slice_transfer.hpp" #include "threadwise_tensor_slice_transfer.hpp"
#include "cluster_descriptor.hpp"
namespace ck { namespace ck {
...@@ -158,13 +158,27 @@ struct GridwiseReduction_mk_to_m_blockwise ...@@ -158,13 +158,27 @@ struct GridwiseReduction_mk_to_m_blockwise
{ {
static constexpr bool reorder_thread_cluster = (InSrcVectorDim == 0); static constexpr bool reorder_thread_cluster = (InSrcVectorDim == 0);
static constexpr auto buffer_1d_desc = using ThreadClusterLengths_M_K = Sequence<MThreadClusterSize, KThreadClusterSize>;
make_naive_tensor_descriptor_packed(make_tuple(Number<BlockSize>{}));
using ThreadBufferDimAccessOrder =
typename conditional<reorder_thread_cluster, Sequence<1, 0>, Sequence<0, 1>>::type;
using ThreadClusterArrangeOrder =
typename conditional<reorder_thread_cluster, Sequence<1, 0>, Sequence<0, 1>>::type;
static constexpr auto thread_cluster_desc =
make_cluster_descriptor(ThreadClusterLengths_M_K{}, ThreadClusterArrangeOrder{});
// For laying out the threads to do reducing on LDS buffer, for LDS buffer, we always use the
// Dim_K as the fastest one
static constexpr auto block_buf_desc_m_k = make_naive_tensor_descriptor_packed(
make_tuple(Number<MThreadClusterSize>{}, Number<KThreadClusterSize>{}));
template <typename T> template <typename T>
using PassThroughOp = tensor_operation::element_wise::UnaryIdentic<T, T>; using PassThroughOp = tensor_operation::element_wise::UnaryIdentic<T, T>;
static constexpr auto I0 = Number<0>{}; static constexpr auto I0 = Number<0>{};
static constexpr auto I1 = Number<1>{};
static constexpr index_t M_BlockTileSize = MThreadClusterSize * MThreadSliceSize; static constexpr index_t M_BlockTileSize = MThreadClusterSize * MThreadSliceSize;
static constexpr index_t K_BlockTileSize = KThreadClusterSize * KThreadSliceSize; static constexpr index_t K_BlockTileSize = KThreadClusterSize * KThreadSliceSize;
...@@ -180,12 +194,10 @@ struct GridwiseReduction_mk_to_m_blockwise ...@@ -180,12 +194,10 @@ struct GridwiseReduction_mk_to_m_blockwise
const IndexDataType* const __restrict__ p_ws_indices_global, const IndexDataType* const __restrict__ p_ws_indices_global,
IndexDataType* const __restrict__ p_indices_global) IndexDataType* const __restrict__ p_indices_global)
{ {
using BlockwiseReduce = PartitionedBlockwiseReductionOn1dBuffer<decltype(buffer_1d_desc), using BlockwiseReduce = PartitionedBlockwiseReduction<AccDataType,
AccDataType,
BlockSize, BlockSize,
MThreadClusterSize, ThreadClusterLengths_M_K,
KThreadClusterSize, ThreadClusterArrangeOrder,
reorder_thread_cluster,
ReduceOperation, ReduceOperation,
PropagateNan>; PropagateNan>;
using Accumulation = using Accumulation =
...@@ -221,28 +233,28 @@ struct GridwiseReduction_mk_to_m_blockwise ...@@ -221,28 +233,28 @@ struct GridwiseReduction_mk_to_m_blockwise
const index_t thread_local_id = get_thread_local_1d_id(); const index_t thread_local_id = get_thread_local_1d_id();
const index_t block_global_1d_id = get_block_1d_id(); const index_t block_global_1d_id = get_block_1d_id();
const index_t thread_m_cluster_id =
reorder_thread_cluster ? thread_local_id % MThreadClusterSize const auto thread_cluster_idx =
: ((thread_local_id / KThreadClusterSize) % MThreadClusterSize); thread_cluster_desc.CalculateBottomIndex(make_multi_index(thread_local_id));
const index_t thread_k_cluster_id =
reorder_thread_cluster ? ((thread_local_id / MThreadClusterSize) % KThreadClusterSize) const auto thread_m_cluster_id = thread_cluster_idx[I0];
: thread_local_id % KThreadClusterSize; const auto thread_k_cluster_id = thread_cluster_idx[I1];
using ThreadBufferLengths = Sequence<MThreadSliceSize, KThreadSliceSize>; using ThreadBufferLengths = Sequence<MThreadSliceSize, KThreadSliceSize>;
constexpr auto thread_buffer_desc = make_naive_tensor_descriptor_packed( constexpr auto thread_buffer_desc = make_naive_tensor_descriptor_packed(
make_tuple(Number<MThreadSliceSize>{}, Number<KThreadSliceSize>{})); make_tuple(Number<MThreadSliceSize>{}, Number<KThreadSliceSize>{}));
auto threadwise_src_load = ThreadwiseTensorSliceTransfer_v2< auto threadwise_src_load = ThreadwiseTensorSliceTransfer_v2<InDataType,
InDataType,
AccDataType, AccDataType,
InGridDesc_M_K, InGridDesc_M_K,
decltype(thread_buffer_desc), decltype(thread_buffer_desc),
ThreadBufferLengths, ThreadBufferLengths,
typename conditional<InSrcVectorDim == 0, Sequence<1, 0>, Sequence<0, 1>>::type, ThreadBufferDimAccessOrder,
InSrcVectorDim, InSrcVectorDim,
InSrcVectorSize, InSrcVectorSize,
1, 1,
false>(in_grid_desc_m_k, false>(
in_grid_desc_m_k,
make_multi_index(block_global_1d_id * M_BlockTileSize + make_multi_index(block_global_1d_id * M_BlockTileSize +
thread_m_cluster_id * MThreadSliceSize, thread_m_cluster_id * MThreadSliceSize,
thread_k_cluster_id * KThreadSliceSize)); thread_k_cluster_id * KThreadSliceSize));
...@@ -283,21 +295,14 @@ struct GridwiseReduction_mk_to_m_blockwise ...@@ -283,21 +295,14 @@ struct GridwiseReduction_mk_to_m_blockwise
make_naive_tensor_descriptor_packed(make_tuple(Number<MThreadSliceSize>{})); make_naive_tensor_descriptor_packed(make_tuple(Number<MThreadSliceSize>{}));
static_for<0, MThreadSliceSize, 1>{}([&](auto I) { static_for<0, MThreadSliceSize, 1>{}([&](auto I) {
if constexpr(reorder_thread_cluster) block_reduce_buf(block_buf_desc_m_k.CalculateOffset(thread_cluster_idx)) =
{
block_reduce_buf(thread_k_cluster_id * MThreadClusterSize + thread_m_cluster_id) =
accu_value_buf[I];
}
else
block_reduce_buf(thread_m_cluster_id * KThreadClusterSize + thread_k_cluster_id) =
accu_value_buf[I]; accu_value_buf[I];
accu_value_buf(I) = zeroVal; accu_value_buf(I) = zeroVal;
__syncthreads(); __syncthreads();
BlockwiseReduce::Reduce( BlockwiseReduce::Reduce(block_reduce_buf, accu_value_buf(I));
block_reduce_buf, accu_value_buf(I), thread_m_cluster_id, thread_k_cluster_id);
}); });
static_for<0, MThreadSliceSize, 1>{}([&](auto I) { static_for<0, MThreadSliceSize, 1>{}([&](auto I) {
...@@ -380,13 +385,11 @@ struct GridwiseReduction_mk_to_m_blockwise ...@@ -380,13 +385,11 @@ struct GridwiseReduction_mk_to_m_blockwise
IndexDataType* const __restrict__ p_indices_global) IndexDataType* const __restrict__ p_indices_global)
{ {
using BlockwiseReduceWithIndex = using BlockwiseReduceWithIndex =
PartitionedBlockwiseReductionWithIndexOn1dBuffer<decltype(buffer_1d_desc), PartitionedBlockwiseReductionWithIndex<AccDataType,
AccDataType,
IndexDataType, IndexDataType,
BlockSize, BlockSize,
MThreadClusterSize, ThreadClusterLengths_M_K,
KThreadClusterSize, ThreadClusterArrangeOrder,
reorder_thread_cluster,
ReduceOperation, ReduceOperation,
PropagateNan>; PropagateNan>;
...@@ -432,28 +435,28 @@ struct GridwiseReduction_mk_to_m_blockwise ...@@ -432,28 +435,28 @@ struct GridwiseReduction_mk_to_m_blockwise
const index_t thread_local_id = get_thread_local_1d_id(); const index_t thread_local_id = get_thread_local_1d_id();
const index_t block_global_1d_id = get_block_1d_id(); const index_t block_global_1d_id = get_block_1d_id();
const index_t thread_m_cluster_id =
reorder_thread_cluster ? thread_local_id % MThreadClusterSize const auto thread_cluster_idx =
: ((thread_local_id / KThreadClusterSize) % MThreadClusterSize); thread_cluster_desc.CalculateBottomIndex(make_multi_index(thread_local_id));
const index_t thread_k_cluster_id =
reorder_thread_cluster ? ((thread_local_id / MThreadClusterSize) % KThreadClusterSize) const auto thread_m_cluster_id = thread_cluster_idx[I0];
: thread_local_id % KThreadClusterSize; const auto thread_k_cluster_id = thread_cluster_idx[I1];
using ThreadBufferLengths = Sequence<MThreadSliceSize, KThreadSliceSize>; using ThreadBufferLengths = Sequence<MThreadSliceSize, KThreadSliceSize>;
constexpr auto thread_buffer_desc = make_naive_tensor_descriptor_packed( constexpr auto thread_buffer_desc = make_naive_tensor_descriptor_packed(
make_tuple(Number<MThreadSliceSize>{}, Number<KThreadSliceSize>{})); make_tuple(Number<MThreadSliceSize>{}, Number<KThreadSliceSize>{}));
auto threadwise_src_load = ThreadwiseTensorSliceTransfer_v2< auto threadwise_src_load = ThreadwiseTensorSliceTransfer_v2<InDataType,
InDataType,
AccDataType, AccDataType,
InGridDesc_M_K, InGridDesc_M_K,
decltype(thread_buffer_desc), decltype(thread_buffer_desc),
ThreadBufferLengths, ThreadBufferLengths,
typename conditional<InSrcVectorDim == 0, Sequence<1, 0>, Sequence<0, 1>>::type, ThreadBufferDimAccessOrder,
InSrcVectorDim, InSrcVectorDim,
InSrcVectorSize, InSrcVectorSize,
1, 1,
false>(in_grid_desc_m_k, false>(
in_grid_desc_m_k,
make_multi_index(block_global_1d_id * M_BlockTileSize + make_multi_index(block_global_1d_id * M_BlockTileSize +
thread_m_cluster_id * MThreadSliceSize, thread_m_cluster_id * MThreadSliceSize,
thread_k_cluster_id * KThreadSliceSize)); thread_k_cluster_id * KThreadSliceSize));
...@@ -503,29 +506,15 @@ struct GridwiseReduction_mk_to_m_blockwise ...@@ -503,29 +506,15 @@ struct GridwiseReduction_mk_to_m_blockwise
}); });
// store thread local value to LDS for parallel reduction // store thread local value to LDS for parallel reduction
if constexpr(reorder_thread_cluster) block_reduce_val_buf(block_buf_desc_m_k.CalculateOffset(thread_cluster_idx)) =
{ tmpValue;
block_reduce_val_buf(thread_k_cluster_id * MThreadClusterSize + block_reduce_idx_buf(block_buf_desc_m_k.CalculateOffset(thread_cluster_idx)) =
thread_m_cluster_id) = tmpValue; tmpIndex;
block_reduce_idx_buf(thread_k_cluster_id * MThreadClusterSize +
thread_m_cluster_id) = tmpIndex;
}
else
{
block_reduce_val_buf(thread_m_cluster_id * KThreadClusterSize +
thread_k_cluster_id) = tmpValue;
block_reduce_idx_buf(thread_m_cluster_id * KThreadClusterSize +
thread_k_cluster_id) = tmpIndex;
}
__syncthreads(); __syncthreads();
BlockwiseReduceWithIndex::Reduce(block_reduce_val_buf, BlockwiseReduceWithIndex::Reduce(
block_reduce_idx_buf, block_reduce_val_buf, block_reduce_idx_buf, tmpValue, tmpIndex);
tmpValue,
tmpIndex,
thread_m_cluster_id,
thread_k_cluster_id);
AccumulationWithIndex::Calculate( AccumulationWithIndex::Calculate(
accu_value_buf(I), tmpValue, accu_index_buf(I), tmpIndex); accu_value_buf(I), tmpValue, accu_index_buf(I), tmpIndex);
...@@ -648,13 +637,11 @@ struct GridwiseReduction_mk_to_m_blockwise ...@@ -648,13 +637,11 @@ struct GridwiseReduction_mk_to_m_blockwise
IndexDataType* const __restrict__ p_indices_global) IndexDataType* const __restrict__ p_indices_global)
{ {
using BlockwiseReduceWithIndex = using BlockwiseReduceWithIndex =
PartitionedBlockwiseReductionWithIndexOn1dBuffer<decltype(buffer_1d_desc), PartitionedBlockwiseReductionWithIndex<AccDataType,
AccDataType,
IndexDataType, IndexDataType,
BlockSize, BlockSize,
MThreadClusterSize, Sequence<MThreadClusterSize, KThreadClusterSize>,
KThreadClusterSize, ThreadClusterArrangeOrder,
reorder_thread_cluster,
ReduceOperation, ReduceOperation,
PropagateNan>; PropagateNan>;
...@@ -707,43 +694,45 @@ struct GridwiseReduction_mk_to_m_blockwise ...@@ -707,43 +694,45 @@ struct GridwiseReduction_mk_to_m_blockwise
const index_t thread_local_id = get_thread_local_1d_id(); const index_t thread_local_id = get_thread_local_1d_id();
const index_t block_global_1d_id = get_block_1d_id(); const index_t block_global_1d_id = get_block_1d_id();
const index_t thread_m_cluster_id =
reorder_thread_cluster ? thread_local_id % MThreadClusterSize const auto thread_cluster_idx =
: ((thread_local_id / KThreadClusterSize) % MThreadClusterSize); thread_cluster_desc.CalculateBottomIndex(make_multi_index(thread_local_id));
const index_t thread_k_cluster_id =
reorder_thread_cluster ? ((thread_local_id / MThreadClusterSize) % KThreadClusterSize) const auto thread_m_cluster_id = thread_cluster_idx[I0];
: thread_local_id % KThreadClusterSize; const auto thread_k_cluster_id = thread_cluster_idx[I1];
using ThreadBufferLengths = Sequence<MThreadSliceSize, KThreadSliceSize>; using ThreadBufferLengths = Sequence<MThreadSliceSize, KThreadSliceSize>;
constexpr auto thread_buffer_desc = make_naive_tensor_descriptor_packed( constexpr auto thread_buffer_desc = make_naive_tensor_descriptor_packed(
make_tuple(Number<MThreadSliceSize>{}, Number<KThreadSliceSize>{})); make_tuple(Number<MThreadSliceSize>{}, Number<KThreadSliceSize>{}));
auto threadwise_src_val_load = ThreadwiseTensorSliceTransfer_v2< auto threadwise_src_val_load =
InDataType, ThreadwiseTensorSliceTransfer_v2<InDataType,
AccDataType, AccDataType,
InGridDesc_M_K, InGridDesc_M_K,
decltype(thread_buffer_desc), decltype(thread_buffer_desc),
ThreadBufferLengths, ThreadBufferLengths,
typename conditional<InSrcVectorDim == 0, Sequence<1, 0>, Sequence<0, 1>>::type, ThreadBufferDimAccessOrder,
InSrcVectorDim, InSrcVectorDim,
InSrcVectorSize, InSrcVectorSize,
1, 1,
false>(in_grid_desc_m_k, false>(
in_grid_desc_m_k,
make_multi_index(block_global_1d_id * M_BlockTileSize + make_multi_index(block_global_1d_id * M_BlockTileSize +
thread_m_cluster_id * MThreadSliceSize, thread_m_cluster_id * MThreadSliceSize,
thread_k_cluster_id * KThreadSliceSize)); thread_k_cluster_id * KThreadSliceSize));
auto threadwise_src_idx_load = ThreadwiseTensorSliceTransfer_v2< auto threadwise_src_idx_load =
IndexDataType, ThreadwiseTensorSliceTransfer_v2<IndexDataType,
IndexDataType, IndexDataType,
InGridDesc_M_K, InGridDesc_M_K,
decltype(thread_buffer_desc), decltype(thread_buffer_desc),
ThreadBufferLengths, ThreadBufferLengths,
typename conditional<InSrcVectorDim == 0, Sequence<1, 0>, Sequence<0, 1>>::type, ThreadBufferDimAccessOrder,
InSrcVectorDim, InSrcVectorDim,
InSrcVectorSize, InSrcVectorSize,
1, 1,
false>(in_grid_desc_m_k, false>(
in_grid_desc_m_k,
make_multi_index(block_global_1d_id * M_BlockTileSize + make_multi_index(block_global_1d_id * M_BlockTileSize +
thread_m_cluster_id * MThreadSliceSize, thread_m_cluster_id * MThreadSliceSize,
thread_k_cluster_id * KThreadSliceSize)); thread_k_cluster_id * KThreadSliceSize));
...@@ -787,29 +776,15 @@ struct GridwiseReduction_mk_to_m_blockwise ...@@ -787,29 +776,15 @@ struct GridwiseReduction_mk_to_m_blockwise
}); });
// store thread local value to LDS for parallel reduction // store thread local value to LDS for parallel reduction
if constexpr(reorder_thread_cluster) block_reduce_val_buf(block_buf_desc_m_k.CalculateOffset(thread_cluster_idx)) =
{ tmpValue;
block_reduce_val_buf(thread_k_cluster_id * MThreadClusterSize + block_reduce_idx_buf(block_buf_desc_m_k.CalculateOffset(thread_cluster_idx)) =
thread_m_cluster_id) = tmpValue; tmpIndex;
block_reduce_idx_buf(thread_k_cluster_id * MThreadClusterSize +
thread_m_cluster_id) = tmpIndex;
}
else
{
block_reduce_val_buf(thread_m_cluster_id * KThreadClusterSize +
thread_k_cluster_id) = tmpValue;
block_reduce_idx_buf(thread_m_cluster_id * KThreadClusterSize +
thread_k_cluster_id) = tmpIndex;
}
__syncthreads(); __syncthreads();
BlockwiseReduceWithIndex::Reduce(block_reduce_val_buf, BlockwiseReduceWithIndex::Reduce(
block_reduce_idx_buf, block_reduce_val_buf, block_reduce_idx_buf, tmpValue, tmpIndex);
tmpValue,
tmpIndex,
thread_m_cluster_id,
thread_k_cluster_id);
AccumulationWithIndex::Calculate( AccumulationWithIndex::Calculate(
accu_value_buf(I), tmpValue, accu_index_buf(I), tmpIndex); accu_value_buf(I), tmpValue, accu_index_buf(I), tmpIndex);
......
...@@ -86,15 +86,26 @@ struct GridwiseReduction_mk_to_m_multiblock_atomic_add ...@@ -86,15 +86,26 @@ struct GridwiseReduction_mk_to_m_multiblock_atomic_add
{ {
static constexpr bool reorder_thread_cluster = (InSrcVectorDim == 0); static constexpr bool reorder_thread_cluster = (InSrcVectorDim == 0);
static constexpr auto buffer_1d_desc = using ThreadClusterLengths_M_K = Sequence<MThreadClusterSize, KThreadClusterSize>;
make_naive_tensor_descriptor_packed(make_tuple(Number<BlockSize>{}));
using blockwise_reduce = PartitionedBlockwiseReductionOn1dBuffer<decltype(buffer_1d_desc), using ThreadBufferDimAccessOrder =
AccDataType, typename conditional<reorder_thread_cluster, Sequence<1, 0>, Sequence<0, 1>>::type;
using ThreadClusterArrangeOrder =
typename conditional<reorder_thread_cluster, Sequence<1, 0>, Sequence<0, 1>>::type;
static constexpr auto thread_cluster_desc =
make_cluster_descriptor(ThreadClusterLengths_M_K{}, ThreadClusterArrangeOrder{});
// For laying out the threads to do reducing on LDS buffer, for LDS buffer, we always use the
// Dim_K as the fastest one
static constexpr auto block_buf_desc_m_k = make_naive_tensor_descriptor_packed(
make_tuple(Number<MThreadClusterSize>{}, Number<KThreadClusterSize>{}));
using BlockwiseReduce = PartitionedBlockwiseReduction<AccDataType,
BlockSize, BlockSize,
MThreadClusterSize, ThreadClusterLengths_M_K,
KThreadClusterSize, ThreadClusterArrangeOrder,
reorder_thread_cluster,
ReduceOperation, ReduceOperation,
PropagateNan>; PropagateNan>;
...@@ -102,6 +113,7 @@ struct GridwiseReduction_mk_to_m_multiblock_atomic_add ...@@ -102,6 +113,7 @@ struct GridwiseReduction_mk_to_m_multiblock_atomic_add
using PassThroughOp = tensor_operation::element_wise::UnaryIdentic<T, T>; using PassThroughOp = tensor_operation::element_wise::UnaryIdentic<T, T>;
static constexpr auto I0 = Number<0>{}; static constexpr auto I0 = Number<0>{};
static constexpr auto I1 = Number<1>{};
static constexpr index_t M_BlockTileSize = MThreadClusterSize * MThreadSliceSize; static constexpr index_t M_BlockTileSize = MThreadClusterSize * MThreadSliceSize;
static constexpr index_t K_BlockTileSize = KThreadClusterSize * KThreadSliceSize; static constexpr index_t K_BlockTileSize = KThreadClusterSize * KThreadSliceSize;
...@@ -145,12 +157,12 @@ struct GridwiseReduction_mk_to_m_multiblock_atomic_add ...@@ -145,12 +157,12 @@ struct GridwiseReduction_mk_to_m_multiblock_atomic_add
const index_t block_global_id = get_block_1d_id(); const index_t block_global_id = get_block_1d_id();
const index_t blkgroup_id = block_global_id / block_group_size; const index_t blkgroup_id = block_global_id / block_group_size;
const index_t block_local_id = block_global_id % block_group_size; const index_t block_local_id = block_global_id % block_group_size;
const index_t thread_m_cluster_id =
reorder_thread_cluster ? thread_local_id % MThreadClusterSize const auto thread_cluster_idx =
: ((thread_local_id / KThreadClusterSize) % MThreadClusterSize); thread_cluster_desc.CalculateBottomIndex(make_multi_index(thread_local_id));
const index_t thread_k_cluster_id =
reorder_thread_cluster ? ((thread_local_id / MThreadClusterSize) % KThreadClusterSize) const auto thread_m_cluster_id = thread_cluster_idx[I0];
: thread_local_id % KThreadClusterSize; const auto thread_k_cluster_id = thread_cluster_idx[I1];
const index_t reduceSizePerBlock = K_BlockTileSize * num_k_block_tile_iteration; const index_t reduceSizePerBlock = K_BlockTileSize * num_k_block_tile_iteration;
...@@ -158,13 +170,12 @@ struct GridwiseReduction_mk_to_m_multiblock_atomic_add ...@@ -158,13 +170,12 @@ struct GridwiseReduction_mk_to_m_multiblock_atomic_add
constexpr auto thread_buffer_desc = make_naive_tensor_descriptor_packed( constexpr auto thread_buffer_desc = make_naive_tensor_descriptor_packed(
make_tuple(Number<MThreadSliceSize>{}, Number<KThreadSliceSize>{})); make_tuple(Number<MThreadSliceSize>{}, Number<KThreadSliceSize>{}));
auto threadwise_src_load = ThreadwiseTensorSliceTransfer_v2< auto threadwise_src_load = ThreadwiseTensorSliceTransfer_v2<InDataType,
InDataType,
AccDataType, AccDataType,
InGridDesc_M_K, InGridDesc_M_K,
decltype(thread_buffer_desc), decltype(thread_buffer_desc),
ThreadBufferLengths, ThreadBufferLengths,
typename conditional<InSrcVectorDim == 0, Sequence<1, 0>, Sequence<0, 1>>::type, ThreadBufferDimAccessOrder,
InSrcVectorDim, InSrcVectorDim,
InSrcVectorSize, InSrcVectorSize,
1, 1,
...@@ -212,21 +223,14 @@ struct GridwiseReduction_mk_to_m_multiblock_atomic_add ...@@ -212,21 +223,14 @@ struct GridwiseReduction_mk_to_m_multiblock_atomic_add
// consistent reduced result for that invariant dimension. due to the using of vector_load, // consistent reduced result for that invariant dimension. due to the using of vector_load,
// each block/thread is involved into multiple invarirant dimensions. // each block/thread is involved into multiple invarirant dimensions.
static_for<0, MThreadSliceSize, 1>{}([&](auto I) { static_for<0, MThreadSliceSize, 1>{}([&](auto I) {
if constexpr(reorder_thread_cluster) block_reduce_buf(block_buf_desc_m_k.CalculateOffset(thread_cluster_idx)) =
{
block_reduce_buf(thread_k_cluster_id * MThreadClusterSize + thread_m_cluster_id) =
accu_value_buf[I];
}
else
block_reduce_buf(thread_m_cluster_id * KThreadClusterSize + thread_k_cluster_id) =
accu_value_buf[I]; accu_value_buf[I];
accu_value_buf(I) = zeroVal; accu_value_buf(I) = zeroVal;
__syncthreads(); __syncthreads();
blockwise_reduce::Reduce( BlockwiseReduce::Reduce(block_reduce_buf, accu_value_buf(I));
block_reduce_buf, accu_value_buf(I), thread_m_cluster_id, thread_k_cluster_id);
}); });
static_for<0, MThreadSliceSize, 1>{}([&](auto I) { static_for<0, MThreadSliceSize, 1>{}([&](auto I) {
......
...@@ -30,8 +30,8 @@ ...@@ -30,8 +30,8 @@
#include "reduction_operator.hpp" #include "reduction_operator.hpp"
#include "reduction_functions_accumulate.hpp" #include "reduction_functions_accumulate.hpp"
#include "reduction_functions_blockwise.hpp" #include "reduction_functions_blockwise.hpp"
#include "threadwise_tensor_slice_transfer.hpp" #include "threadwise_tensor_slice_transfer.hpp"
#include "cluster_descriptor.hpp"
namespace ck { namespace ck {
...@@ -103,13 +103,27 @@ struct GridwiseReduction_mk_to_mk_multiblock_partial_reduce ...@@ -103,13 +103,27 @@ struct GridwiseReduction_mk_to_mk_multiblock_partial_reduce
{ {
static constexpr bool reorder_thread_cluster = (InSrcVectorDim == 0); static constexpr bool reorder_thread_cluster = (InSrcVectorDim == 0);
static constexpr auto buffer1dDesc = using ThreadClusterLengths_M_K = Sequence<MThreadClusterSize, KThreadClusterSize>;
make_naive_tensor_descriptor_packed(make_tuple(Number<BlockSize>{}));
using ThreadBufferDimAccessOrder =
typename conditional<reorder_thread_cluster, Sequence<1, 0>, Sequence<0, 1>>::type;
using ThreadClusterArrangeOrder =
typename conditional<reorder_thread_cluster, Sequence<1, 0>, Sequence<0, 1>>::type;
static constexpr auto thread_cluster_desc =
make_cluster_descriptor(ThreadClusterLengths_M_K{}, ThreadClusterArrangeOrder{});
// For laying out the threads to do reducing on LDS buffer, for LDS buffer, we always use the
// Dim_K as the fastest one
static constexpr auto block_buf_desc_m_k = make_naive_tensor_descriptor_packed(
make_tuple(Number<MThreadClusterSize>{}, Number<KThreadClusterSize>{}));
template <typename T> template <typename T>
using PassThroughOp = tensor_operation::element_wise::UnaryIdentic<T, T>; using PassThroughOp = tensor_operation::element_wise::UnaryIdentic<T, T>;
static constexpr auto I0 = Number<0>{}; static constexpr auto I0 = Number<0>{};
static constexpr auto I1 = Number<1>{};
static constexpr index_t M_BlockTileSize = MThreadClusterSize * MThreadSliceSize; static constexpr index_t M_BlockTileSize = MThreadClusterSize * MThreadSliceSize;
static constexpr index_t K_BlockTileSize = KThreadClusterSize * KThreadSliceSize; static constexpr index_t K_BlockTileSize = KThreadClusterSize * KThreadSliceSize;
...@@ -124,12 +138,10 @@ struct GridwiseReduction_mk_to_mk_multiblock_partial_reduce ...@@ -124,12 +138,10 @@ struct GridwiseReduction_mk_to_mk_multiblock_partial_reduce
AccDataType* const __restrict__ p_ws_values_global, AccDataType* const __restrict__ p_ws_values_global,
IndexDataType* const __restrict__ p_ws_indices_global) IndexDataType* const __restrict__ p_ws_indices_global)
{ {
using BlockwiseReduce = PartitionedBlockwiseReductionOn1dBuffer<decltype(buffer1dDesc), using BlockwiseReduce = PartitionedBlockwiseReduction<AccDataType,
AccDataType,
BlockSize, BlockSize,
MThreadClusterSize, ThreadClusterLengths_M_K,
KThreadClusterSize, ThreadClusterArrangeOrder,
reorder_thread_cluster,
ReduceOperation, ReduceOperation,
PropagateNan>; PropagateNan>;
...@@ -168,12 +180,12 @@ struct GridwiseReduction_mk_to_mk_multiblock_partial_reduce ...@@ -168,12 +180,12 @@ struct GridwiseReduction_mk_to_mk_multiblock_partial_reduce
const index_t block_global_id = get_block_1d_id(); const index_t block_global_id = get_block_1d_id();
const index_t blkgroup_id = block_global_id / block_group_size; const index_t blkgroup_id = block_global_id / block_group_size;
const index_t block_local_id = block_global_id % block_group_size; const index_t block_local_id = block_global_id % block_group_size;
const index_t thread_m_cluster_id =
reorder_thread_cluster ? thread_local_id % MThreadClusterSize const auto thread_cluster_idx =
: ((thread_local_id / KThreadClusterSize) % MThreadClusterSize); thread_cluster_desc.CalculateBottomIndex(make_multi_index(thread_local_id));
const index_t thread_k_cluster_id =
reorder_thread_cluster ? ((thread_local_id / MThreadClusterSize) % KThreadClusterSize) const auto thread_m_cluster_id = thread_cluster_idx[I0];
: thread_local_id % KThreadClusterSize; const auto thread_k_cluster_id = thread_cluster_idx[I1];
const index_t reduceSizePerBlock = K_BlockTileSize * num_k_block_tile_iteration; const index_t reduceSizePerBlock = K_BlockTileSize * num_k_block_tile_iteration;
...@@ -181,13 +193,12 @@ struct GridwiseReduction_mk_to_mk_multiblock_partial_reduce ...@@ -181,13 +193,12 @@ struct GridwiseReduction_mk_to_mk_multiblock_partial_reduce
constexpr auto thread_buffer_desc = make_naive_tensor_descriptor_packed( constexpr auto thread_buffer_desc = make_naive_tensor_descriptor_packed(
make_tuple(Number<MThreadSliceSize>{}, Number<KThreadSliceSize>{})); make_tuple(Number<MThreadSliceSize>{}, Number<KThreadSliceSize>{}));
auto threadwise_src_load = ThreadwiseTensorSliceTransfer_v2< auto threadwise_src_load = ThreadwiseTensorSliceTransfer_v2<InDataType,
InDataType,
AccDataType, AccDataType,
InGridDesc_M_K, InGridDesc_M_K,
decltype(thread_buffer_desc), decltype(thread_buffer_desc),
ThreadBufferLengths, ThreadBufferLengths,
typename conditional<InSrcVectorDim == 0, Sequence<1, 0>, Sequence<0, 1>>::type, ThreadBufferDimAccessOrder,
InSrcVectorDim, InSrcVectorDim,
InSrcVectorSize, InSrcVectorSize,
1, 1,
...@@ -233,21 +244,14 @@ struct GridwiseReduction_mk_to_mk_multiblock_partial_reduce ...@@ -233,21 +244,14 @@ struct GridwiseReduction_mk_to_mk_multiblock_partial_reduce
// Each block executes multiple parallel reductions on the LDS, and due to the using of // Each block executes multiple parallel reductions on the LDS, and due to the using of
// vector_load, each block/thread is involved into multiple invarirant dimensions. // vector_load, each block/thread is involved into multiple invarirant dimensions.
static_for<0, MThreadSliceSize, 1>{}([&](auto I) { static_for<0, MThreadSliceSize, 1>{}([&](auto I) {
if constexpr(reorder_thread_cluster) block_reduce_buf(block_buf_desc_m_k.CalculateOffset(thread_cluster_idx)) =
{
block_reduce_buf(thread_k_cluster_id * MThreadClusterSize + thread_m_cluster_id) =
accu_value_buf[I];
}
else
block_reduce_buf(thread_m_cluster_id * KThreadClusterSize + thread_k_cluster_id) =
accu_value_buf[I]; accu_value_buf[I];
accu_value_buf(I) = zeroVal; accu_value_buf(I) = zeroVal;
__syncthreads(); __syncthreads();
BlockwiseReduce::Reduce( BlockwiseReduce::Reduce(block_reduce_buf, accu_value_buf(I));
block_reduce_buf, accu_value_buf(I), thread_m_cluster_id, thread_k_cluster_id);
}); });
if(thread_k_cluster_id == 0) if(thread_k_cluster_id == 0)
...@@ -290,13 +294,11 @@ struct GridwiseReduction_mk_to_mk_multiblock_partial_reduce ...@@ -290,13 +294,11 @@ struct GridwiseReduction_mk_to_mk_multiblock_partial_reduce
IndexDataType* const __restrict__ p_ws_indices_global) IndexDataType* const __restrict__ p_ws_indices_global)
{ {
using BlockwiseReduceWithIndex = using BlockwiseReduceWithIndex =
PartitionedBlockwiseReductionWithIndexOn1dBuffer<decltype(buffer1dDesc), PartitionedBlockwiseReductionWithIndex<AccDataType,
AccDataType,
IndexDataType, IndexDataType,
BlockSize, BlockSize,
MThreadClusterSize, ThreadClusterLengths_M_K,
KThreadClusterSize, ThreadClusterArrangeOrder,
reorder_thread_cluster,
ReduceOperation, ReduceOperation,
PropagateNan>; PropagateNan>;
...@@ -346,12 +348,12 @@ struct GridwiseReduction_mk_to_mk_multiblock_partial_reduce ...@@ -346,12 +348,12 @@ struct GridwiseReduction_mk_to_mk_multiblock_partial_reduce
const index_t block_global_id = get_block_1d_id(); const index_t block_global_id = get_block_1d_id();
const index_t blkgroup_id = block_global_id / block_group_size; const index_t blkgroup_id = block_global_id / block_group_size;
const index_t block_local_id = block_global_id % block_group_size; const index_t block_local_id = block_global_id % block_group_size;
const index_t thread_m_cluster_id =
reorder_thread_cluster ? thread_local_id % MThreadClusterSize const auto thread_cluster_idx =
: ((thread_local_id / KThreadClusterSize) % MThreadClusterSize); thread_cluster_desc.CalculateBottomIndex(make_multi_index(thread_local_id));
const index_t thread_k_cluster_id =
reorder_thread_cluster ? ((thread_local_id / MThreadClusterSize) % KThreadClusterSize) const auto thread_m_cluster_id = thread_cluster_idx[I0];
: thread_local_id % KThreadClusterSize; const auto thread_k_cluster_id = thread_cluster_idx[I1];
const index_t reduceSizePerBlock = K_BlockTileSize * num_k_block_tile_iteration; const index_t reduceSizePerBlock = K_BlockTileSize * num_k_block_tile_iteration;
...@@ -359,13 +361,12 @@ struct GridwiseReduction_mk_to_mk_multiblock_partial_reduce ...@@ -359,13 +361,12 @@ struct GridwiseReduction_mk_to_mk_multiblock_partial_reduce
constexpr auto thread_buffer_desc = make_naive_tensor_descriptor_packed( constexpr auto thread_buffer_desc = make_naive_tensor_descriptor_packed(
make_tuple(Number<MThreadSliceSize>{}, Number<KThreadSliceSize>{})); make_tuple(Number<MThreadSliceSize>{}, Number<KThreadSliceSize>{}));
auto threadwise_src_load = ThreadwiseTensorSliceTransfer_v2< auto threadwise_src_load = ThreadwiseTensorSliceTransfer_v2<InDataType,
InDataType,
AccDataType, AccDataType,
InGridDesc_M_K, InGridDesc_M_K,
decltype(thread_buffer_desc), decltype(thread_buffer_desc),
ThreadBufferLengths, ThreadBufferLengths,
typename conditional<InSrcVectorDim == 0, Sequence<1, 0>, Sequence<0, 1>>::type, ThreadBufferDimAccessOrder,
InSrcVectorDim, InSrcVectorDim,
InSrcVectorSize, InSrcVectorSize,
1, 1,
...@@ -418,29 +419,15 @@ struct GridwiseReduction_mk_to_mk_multiblock_partial_reduce ...@@ -418,29 +419,15 @@ struct GridwiseReduction_mk_to_mk_multiblock_partial_reduce
}); });
// store thread local value to LDS for parallel reduction // store thread local value to LDS for parallel reduction
if constexpr(reorder_thread_cluster) block_reduce_val_buf(block_buf_desc_m_k.CalculateOffset(thread_cluster_idx)) =
{ tmpValue;
block_reduce_val_buf(thread_k_cluster_id * MThreadClusterSize + block_reduce_idx_buf(block_buf_desc_m_k.CalculateOffset(thread_cluster_idx)) =
thread_m_cluster_id) = tmpValue; tmpIndex;
block_reduce_idx_buf(thread_k_cluster_id * MThreadClusterSize +
thread_m_cluster_id) = tmpIndex;
}
else
{
block_reduce_val_buf(thread_m_cluster_id * KThreadClusterSize +
thread_k_cluster_id) = tmpValue;
block_reduce_idx_buf(thread_m_cluster_id * KThreadClusterSize +
thread_k_cluster_id) = tmpIndex;
}
__syncthreads(); __syncthreads();
BlockwiseReduceWithIndex::Reduce(block_reduce_val_buf, BlockwiseReduceWithIndex::Reduce(
block_reduce_idx_buf, block_reduce_val_buf, block_reduce_idx_buf, tmpValue, tmpIndex);
tmpValue,
tmpIndex,
thread_m_cluster_id,
thread_k_cluster_id);
AccumulationWithIndex::Calculate( AccumulationWithIndex::Calculate(
accu_value_buf(I), tmpValue, accu_index_buf(I), tmpIndex); accu_value_buf(I), tmpValue, accu_index_buf(I), tmpIndex);
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment