Unverified Commit 5c7cec11 authored by Chao Liu's avatar Chao Liu Committed by GitHub
Browse files

Code clean up (#20)



* tuning para,

* testing on v100

* add fp16

* remove deprecated tensor descriptor

* sync with miopen

* update build script
Co-authored-by: default avatarJing Zhang <jizhan@amd.com>
parent 7d09790a
#pragma once
#include <unistd.h>
#include "device.hpp"
#include "tensor.hpp"
#include "gridwise_convolution_implicit_gemm_v1r3_chwn_cyxk_khwn_padded.hpp"
using namespace ck;
template <typename T, class InDesc, class WeiDesc, class OutDesc, class LeftPads, class RightPads>
void device_convolution_implicit_gemm_v1_chwn_cyxk_khwn_padded(InDesc,
const Tensor<T>& in_nchw,
WeiDesc,
const Tensor<T>& wei_kcyx,
OutDesc,
Tensor<T>& out_nkhw,
LeftPads,
RightPads,
index_t nrepeat)
{
constexpr auto I0 = Number<0>{};
constexpr auto I1 = Number<1>{};
constexpr auto I2 = Number<2>{};
constexpr auto I3 = Number<3>{};
constexpr auto in_nchw_desc = InDesc{};
constexpr auto wei_kcyx_desc = WeiDesc{};
constexpr auto out_nkhw_desc = OutDesc{};
constexpr index_t Hi = in_nchw_desc.GetLength(I2);
constexpr index_t Wi = in_nchw_desc.GetLength(I3);
constexpr index_t N = out_nkhw_desc.GetLength(I0);
constexpr index_t Ho = out_nkhw_desc.GetLength(I2);
constexpr index_t Wo = out_nkhw_desc.GetLength(I3);
constexpr index_t K = wei_kcyx_desc.GetLength(I0);
constexpr index_t C = wei_kcyx_desc.GetLength(I1);
constexpr index_t Y = wei_kcyx_desc.GetLength(I2);
constexpr index_t X = wei_kcyx_desc.GetLength(I3);
// reorder weight
auto wei_cyxk_desc = make_ConstantTensorDescriptor_packed(Sequence<C, Y, X, K>{});
ostream_ConstantTensorDescriptor(wei_cyxk_desc, std::cout << "wei_cyxk_desc: ");
Tensor<T> wei_cyxk(make_TensorDescriptor(wei_cyxk_desc));
auto f_reorder_kcyx2cyxk = [&](auto k, auto c, auto y, auto x) {
wei_cyxk(c, y, x, k) = wei_kcyx(k, c, y, x);
};
make_ParallelTensorFunctor(f_reorder_kcyx2cyxk, K, C, Y, X)(
std::thread::hardware_concurrency());
// reorder input
auto in_chwn_desc = make_ConstantTensorDescriptor_packed(Sequence<C, Hi, Wi, N>{});
ostream_ConstantTensorDescriptor(in_chwn_desc, std::cout << "in_chwn_desc: ");
Tensor<T> in_chwn(make_TensorDescriptor(in_chwn_desc));
auto f_reorder_nchw2chwn = [&](auto n, auto c, auto hi, auto wi) {
in_chwn(c, hi, wi, n) = in_nchw(n, c, hi, wi);
};
make_ParallelTensorFunctor(f_reorder_nchw2chwn, N, C, Hi, Wi)(
std::thread::hardware_concurrency());
// output
auto out_khwn_desc = make_ConstantTensorDescriptor_packed(Sequence<K, Ho, Wo, N>{});
ostream_ConstantTensorDescriptor(out_khwn_desc, std::cout << "out_khwn_desc: ");
Tensor<T> out_khwn(make_TensorDescriptor(out_khwn_desc));
std::size_t data_sz = sizeof(T);
DeviceMem in_chwn_device_buf(data_sz * in_chwn.mDesc.GetElementSpace());
DeviceMem wei_cyxk_device_buf(data_sz * wei_cyxk.mDesc.GetElementSpace());
DeviceMem out_khwn_device_buf(data_sz * out_khwn.mDesc.GetElementSpace());
in_chwn_device_buf.ToDevice(in_chwn.mData.data());
wei_cyxk_device_buf.ToDevice(wei_cyxk.mData.data());
out_khwn_device_buf.ToDevice(out_khwn.mData.data());
#if 1
// v1r3, 3x3, 32x32, 1x1 pad
constexpr index_t BlockSize = 256;
constexpr index_t NPerBlock = 32;
constexpr index_t KPerBlock = 128;
constexpr index_t CPerBlock = 8;
constexpr index_t HoPerBlock = 2;
constexpr index_t WoPerBlock = 2;
constexpr index_t NPerThread = 4;
constexpr index_t KPerThread = 8;
constexpr index_t HoPerThread = 1;
constexpr index_t WoPerThread = 2;
constexpr index_t GemmMPerThreadSubC = 4;
constexpr index_t GemmNPerThreadSubC = 4;
constexpr index_t GemmMLevel0Cluster = 4;
constexpr index_t GemmNLevel0Cluster = 4;
constexpr index_t GemmMLevel1Cluster = 4;
constexpr index_t GemmNLevel1Cluster = 2;
constexpr index_t GemmKPerThreadLoop = 1;
constexpr index_t GemmDataPerReadA = 4;
constexpr index_t GemmDataPerReadB = 4;
using InBlockCopySubLengths_CHWN = Sequence<1, 1, 1, 4>;
using InBlockCopyClusterLengths_CHWN = Sequence<8, 2, 2, 8>;
constexpr index_t InBlockCopyDataPerAccess_N = 4;
using WeiBlockCopySubLengths_CK = Sequence<1, 4>;
using WeiBlockCopyClusterLengths_CK = Sequence<8, 32>;
constexpr index_t WeiBlockCopyDataPerAccess_K = 4;
constexpr index_t OutThreadCopyDataPerAccess_N = 4;
#endif
#if 1 // debug
constexpr index_t GridSize =
(N / NPerBlock) * (K / KPerBlock) * (Ho / HoPerBlock) * (Wo / WoPerBlock);
#else
constexpr index_t GridSize = 1;
#endif
printf("%s: BlockSize %u, GridSize %u \n", __func__, BlockSize, GridSize);
constexpr auto gridwise_conv =
GridwiseConvolutionImplicitGemm_v1r3_chwn_cyxk_khwn_padded<GridSize,
BlockSize,
T,
decltype(in_chwn_desc),
decltype(wei_cyxk_desc),
decltype(out_khwn_desc),
LeftPads,
RightPads,
NPerBlock,
KPerBlock,
CPerBlock,
HoPerBlock,
WoPerBlock,
NPerThread,
KPerThread,
HoPerThread,
WoPerThread,
GemmMPerThreadSubC,
GemmNPerThreadSubC,
GemmMLevel0Cluster,
GemmNLevel0Cluster,
GemmMLevel1Cluster,
GemmNLevel1Cluster,
GemmKPerThreadLoop,
GemmDataPerReadA,
GemmDataPerReadB,
InBlockCopySubLengths_CHWN,
InBlockCopyClusterLengths_CHWN,
InBlockCopyDataPerAccess_N,
WeiBlockCopySubLengths_CK,
WeiBlockCopyClusterLengths_CK,
WeiBlockCopyDataPerAccess_K,
OutThreadCopyDataPerAccess_N>{};
for(index_t i = 0; i < nrepeat; ++i)
{
float time =
launch_and_time_kernel(run_gridwise_convolution_kernel<decltype(gridwise_conv), T>,
dim3(GridSize),
dim3(BlockSize),
0,
static_cast<T*>(in_chwn_device_buf.GetDeviceBuffer()),
static_cast<T*>(wei_cyxk_device_buf.GetDeviceBuffer()),
static_cast<T*>(out_khwn_device_buf.GetDeviceBuffer()));
printf("Elapsed time : %f ms, %f TFlop/s\n",
time,
(float)calculate_convolution_flops(InDesc{}, WeiDesc{}, OutDesc{}) /
(std::size_t(1000) * 1000 * 1000) / time);
usleep(std::min(time * 1000, float(10000)));
}
out_khwn_device_buf.FromDevice(out_khwn.mData.data());
// reorder output
auto f_reorder_khwn2nkhw = [&](auto k, auto ho, auto wo, auto n) {
out_nkhw(n, k, ho, wo) = out_khwn(k, ho, wo, n);
};
make_ParallelTensorFunctor(f_reorder_khwn2nkhw, K, Ho, Wo, N)(
std::thread::hardware_concurrency());
}
#pragma once
#include <unistd.h>
#include "device.hpp"
#include "tensor.hpp"
#include "gridwise_convolution_kernel_wrapper.hpp"
#include "gridwise_convolution_implicit_gemm_v1r3_nchw_cyxk_nkhw.hpp"
#include "gridwise_convolution_implicit_gemm_v1r3_nchw_cyxk_nkhw_lds_double_buffer.hpp"
using namespace ck;
template <class T, class InDesc, class WeiDesc, class OutDesc>
void device_convolution_implicit_gemm_v1_nchw_cyxk_nkhw(InDesc,
const Tensor<T>& in_nchw,
WeiDesc,
const Tensor<T>& wei_kcyx,
OutDesc,
Tensor<T>& out_nkhw,
index_t nrepeat)
{
constexpr auto I0 = Number<0>{};
constexpr auto I1 = Number<1>{};
constexpr auto I2 = Number<2>{};
constexpr auto I3 = Number<3>{};
constexpr auto in_nchw_desc = InDesc{};
constexpr auto wei_kcyx_desc = WeiDesc{};
constexpr auto out_nkhw_desc = OutDesc{};
constexpr index_t Hi = in_nchw_desc.GetLength(I2);
constexpr index_t Wi = in_nchw_desc.GetLength(I3);
constexpr index_t N = out_nkhw_desc.GetLength(I0);
constexpr index_t Ho = out_nkhw_desc.GetLength(I2);
constexpr index_t Wo = out_nkhw_desc.GetLength(I3);
constexpr index_t K = wei_kcyx_desc.GetLength(I0);
constexpr index_t C = wei_kcyx_desc.GetLength(I1);
constexpr index_t Y = wei_kcyx_desc.GetLength(I2);
constexpr index_t X = wei_kcyx_desc.GetLength(I3);
// reorder weight
auto wei_cyxk_desc = make_ConstantTensorDescriptor_packed(Sequence<C, Y, X, K>{});
ostream_ConstantTensorDescriptor(wei_cyxk_desc, std::cout << "wei_cyxk_desc: ");
Tensor<T> wei_cyxk(make_TensorDescriptor(wei_cyxk_desc));
auto f_reorder_kcyx2cyxk = [&](auto k, auto c, auto y, auto x) {
wei_cyxk(c, y, x, k) = wei_kcyx(k, c, y, x);
};
make_ParallelTensorFunctor(f_reorder_kcyx2cyxk, K, C, Y, X)(
std::thread::hardware_concurrency());
std::size_t data_sz = sizeof(T);
DeviceMem in_nchw_device_buf(data_sz * in_nchw.mDesc.GetElementSpace());
DeviceMem wei_cyxk_device_buf(data_sz * wei_cyxk.mDesc.GetElementSpace());
DeviceMem out_nkhw_device_buf(data_sz * out_nkhw.mDesc.GetElementSpace());
in_nchw_device_buf.ToDevice(in_nchw.mData.data());
wei_cyxk_device_buf.ToDevice(wei_cyxk.mData.data());
out_nkhw_device_buf.ToDevice(out_nkhw.mData.data());
#if 0
// for 3x3, 34x34, v1r3, Pascal
constexpr index_t BlockSize = 128;
constexpr index_t NPerBlock = 2;
constexpr index_t KPerBlock = 128;
constexpr index_t CPerBlock = 8;
constexpr index_t HoPerBlock = 2;
constexpr index_t WoPerBlock = 16;
constexpr index_t NPerThread = 2;
constexpr index_t KPerThread = 8;
constexpr index_t HoPerThread = 1;
constexpr index_t WoPerThread = 4;
constexpr index_t GemmMPerThreadSubC = 4;
constexpr index_t GemmNPerThreadSubC = 4;
constexpr index_t GemmMLevel0Cluster = 4;
constexpr index_t GemmNLevel0Cluster = 2;
constexpr index_t GemmMLevel1Cluster = 4;
constexpr index_t GemmNLevel1Cluster = 2;
constexpr index_t GemmKPerThreadLoop = 1;
constexpr index_t GemmDataPerReadA = 4;
constexpr index_t GemmDataPerReadB = 4;
using InBlockReorderSrcSubLengths_NCHW = Sequence<2, 1, 2, 1>;
using InBlockReorderSrcClusterLengths_NCHW = Sequence<1, 8, 1, 16>;
using InBlockReorderMapThreadCluster2SrcCluster_CHNW2NCHW = Sequence<1, 2, 0, 3>;
constexpr index_t InBlockReorderDataPerRead_W = 1; // v1r3 cannot do vector load input for NCHW
constexpr index_t InBlockReorderDataPerWrite_N = 1;
using WeiBlockCopyClusterLengths = void;
constexpr index_t WeiBlockCopyDataPerRead_K = 4;
constexpr index_t OutThreadCopyDataPerWrite_W = 2;
#elif 0
// for 3x3, 34x34, v1r3, Vega 20, WoPerBlock = 32
constexpr index_t BlockSize = 256;
constexpr index_t NPerBlock = 1;
constexpr index_t KPerBlock = 128;
constexpr index_t CPerBlock = 8;
constexpr index_t HoPerBlock = 4;
constexpr index_t WoPerBlock = 32;
constexpr index_t NPerThread = 1;
constexpr index_t KPerThread = 8;
constexpr index_t HoPerThread = 1;
constexpr index_t WoPerThread = 8;
constexpr index_t GemmMPerThreadSubC = 4;
constexpr index_t GemmNPerThreadSubC = 4;
constexpr index_t GemmMLevel0Cluster = 4;
constexpr index_t GemmNLevel0Cluster = 2;
constexpr index_t GemmMLevel1Cluster = 4;
constexpr index_t GemmNLevel1Cluster = 2;
constexpr index_t GemmKPerThreadLoop = 1;
constexpr index_t GemmDataPerReadA = 4;
constexpr index_t GemmDataPerReadB = 4;
using InBlockReorderSrcSubLengths_NCHW = Sequence<1, 2, 2, 1>;
using InBlockReorderSrcClusterLengths_NCHW = Sequence<1, 4, 2, 32>;
using InBlockReorderMapThreadCluster2SrcCluster_CHNW2NCHW = Sequence<1, 2, 0, 3>;
constexpr index_t InBlockReorderDataPerRead_W = 1; // v1r3 cannot do vector load NCHW
constexpr index_t InBlockReorderDataPerWrite_N = 1;
using WeiBlockCopyClusterLengths = void;
constexpr index_t WeiBlockCopyDataPerRead_K = 4;
constexpr index_t OutThreadCopyDataPerWrite_W = 4;
#elif 1
// for 3x3, 34x34, v1r3, Vega 20, WoPerBlock = 16
constexpr index_t BlockSize = 256;
constexpr index_t NPerBlock = 2;
constexpr index_t KPerBlock = 128;
constexpr index_t CPerBlock = 8;
constexpr index_t HoPerBlock = 4;
constexpr index_t WoPerBlock = 16;
constexpr index_t NPerThread = 2;
constexpr index_t KPerThread = 8;
constexpr index_t HoPerThread = 1;
constexpr index_t WoPerThread = 4;
constexpr index_t GemmMPerThreadSubC = 4;
constexpr index_t GemmNPerThreadSubC = 4;
constexpr index_t GemmMLevel0Cluster = 4;
constexpr index_t GemmNLevel0Cluster = 2;
constexpr index_t GemmMLevel1Cluster = 4;
constexpr index_t GemmNLevel1Cluster = 2;
constexpr index_t GemmKPerThreadLoop = 1;
constexpr index_t GemmDataPerReadA = 4;
constexpr index_t GemmDataPerReadB = 4;
using InBlockReorderSrcSubLengths_NCHW = Sequence<2, 1, 2, 1>;
using InBlockReorderSrcClusterLengths_NCHW = Sequence<1, 8, 2, 16>;
using InBlockReorderMapThreadCluster2SrcCluster_CHNW2NCHW = Sequence<1, 2, 0, 3>;
constexpr index_t InBlockReorderDataPerRead_W = 1; // v1r3 cannot do vector load NCHW
constexpr index_t InBlockReorderDataPerWrite_N = 2;
using WeiBlockCopyClusterLengths = void;
constexpr index_t WeiBlockCopyDataPerRead_K = 4;
constexpr index_t OutThreadCopyDataPerWrite_W = 2;
#elif 0
// for 3x3, 34x34, v1r3, Vega 20, WoPerBlock = 8
constexpr index_t BlockSize = 256;
constexpr index_t NPerBlock = 4;
constexpr index_t KPerBlock = 128;
constexpr index_t CPerBlock = 8;
constexpr index_t HoPerBlock = 4;
constexpr index_t WoPerBlock = 8;
constexpr index_t NPerThread = 4;
constexpr index_t KPerThread = 8;
constexpr index_t HoPerThread = 1;
constexpr index_t WoPerThread = 2;
constexpr index_t GemmMPerThreadSubC = 4;
constexpr index_t GemmNPerThreadSubC = 4;
constexpr index_t GemmMLevel0Cluster = 4;
constexpr index_t GemmNLevel0Cluster = 2;
constexpr index_t GemmMLevel1Cluster = 4;
constexpr index_t GemmNLevel1Cluster = 2;
constexpr index_t GemmKPerThreadLoop = 1;
constexpr index_t GemmDataPerReadA = 4;
constexpr index_t GemmDataPerReadB = 4;
using InBlockReorderSrcSubLengths_NCHW = Sequence<4, 1, 1, 1>;
using InBlockReorderSrcClusterLengths_NCHW = Sequence<1, 8, 4, 8>;
using InBlockReorderMapThreadCluster2SrcCluster_CHNW2NCHW = Sequence<1, 2, 0, 3>;
constexpr index_t InBlockReorderDataPerRead_W = 1; // v1r3 cannot do vector load NCHW
constexpr index_t InBlockReorderDataPerWrite_N = 4;
using WeiBlockCopyClusterLengths = void;
constexpr index_t WeiBlockCopyDataPerRead_K = 4;
constexpr index_t OutThreadCopyDataPerWrite_W = 1;
#elif 0
// for 3x3, 34x34, v1r3, Vega 20, WoPerBlock = 4
constexpr index_t BlockSize = 256;
constexpr index_t NPerBlock = 8;
constexpr index_t KPerBlock = 128;
constexpr index_t CPerBlock = 8;
constexpr index_t HoPerBlock = 4;
constexpr index_t WoPerBlock = 4;
constexpr index_t NPerThread = 4;
constexpr index_t KPerThread = 8;
constexpr index_t HoPerThread = 1;
constexpr index_t WoPerThread = 2;
constexpr index_t GemmMPerThreadSubC = 4;
constexpr index_t GemmNPerThreadSubC = 4;
constexpr index_t GemmMLevel0Cluster = 4;
constexpr index_t GemmNLevel0Cluster = 2;
constexpr index_t GemmMLevel1Cluster = 4;
constexpr index_t GemmNLevel1Cluster = 2;
constexpr index_t GemmKPerThreadLoop = 1;
constexpr index_t GemmDataPerReadA = 4;
constexpr index_t GemmDataPerReadB = 4;
using InBlockReorderSrcSubLengths_NCHW = Sequence<4, 1, 1, 1>;
using InBlockReorderSrcClusterLengths_NCHW = Sequence<2, 8, 4, 4>;
using InBlockReorderMapThreadCluster2SrcCluster_CHNW2NCHW = Sequence<1, 2, 0, 3>;
constexpr index_t InBlockReorderDataPerRead_W = 1; // v1r3 cannot do vector load NCHW
constexpr index_t InBlockReorderDataPerWrite_N = 4;
using WeiBlockCopyClusterLengths = void;
constexpr index_t WeiBlockCopyDataPerRead_K = 4;
constexpr index_t OutThreadCopyDataPerWrite_W = 1;
#elif 0
// for 3x3, 34x34, v1r3, Vega 20, WoPerBlock = 2
constexpr index_t BlockSize = 256;
constexpr index_t NPerBlock = 32;
constexpr index_t KPerBlock = 128;
constexpr index_t CPerBlock = 8;
constexpr index_t HoPerBlock = 2;
constexpr index_t WoPerBlock = 2;
constexpr index_t NPerThread = 4;
constexpr index_t KPerThread = 8;
constexpr index_t HoPerThread = 1;
constexpr index_t WoPerThread = 2;
constexpr index_t GemmMPerThreadSubC = 4;
constexpr index_t GemmNPerThreadSubC = 4;
constexpr index_t GemmMLevel0Cluster = 4;
constexpr index_t GemmNLevel0Cluster = 4;
constexpr index_t GemmMLevel1Cluster = 4;
constexpr index_t GemmNLevel1Cluster = 2;
constexpr index_t GemmKPerThreadLoop = 1;
constexpr index_t GemmDataPerReadA = 4;
constexpr index_t GemmDataPerReadB = 4;
using InBlockReorderSrcSubLengths_NCHW = Sequence<4, 1, 1, 1>;
using InBlockReorderSrcClusterLengths_NCHW = Sequence<8, 8, 2, 2>;
using InBlockReorderMapThreadCluster2SrcCluster_CHNW2NCHW = Sequence<1, 2, 0, 3>;
constexpr index_t InBlockReorderDataPerRead_W = 1; // v1r3 cannot do vector load NCHW
constexpr index_t InBlockReorderDataPerWrite_N = 4;
using WeiBlockCopyClusterLengths = void;
constexpr index_t WeiBlockCopyDataPerRead_K = 4;
constexpr index_t OutThreadCopyDataPerWrite_W = 1;
#elif 1
// for 3x3, 28x28, v1r3, Pascal
constexpr index_t BlockSize = 128;
constexpr index_t NPerBlock = 16;
constexpr index_t KPerBlock = 128;
constexpr index_t CPerBlock = 8;
constexpr index_t HoPerBlock = 2;
constexpr index_t WoPerBlock = 2;
constexpr index_t NPerThread = 4;
constexpr index_t KPerThread = 8;
constexpr index_t HoPerThread = 1;
constexpr index_t WoPerThread = 2;
constexpr index_t GemmMPerThreadSubC = 4;
constexpr index_t GemmNPerThreadSubC = 4;
constexpr index_t GemmMLevel0Cluster = 4;
constexpr index_t GemmNLevel0Cluster = 2;
constexpr index_t GemmMLevel1Cluster = 4;
constexpr index_t GemmNLevel1Cluster = 2;
constexpr index_t GemmKPerThreadLoop = 1;
constexpr index_t GemmDataPerReadA = 4;
constexpr index_t GemmDataPerReadB = 4;
using InBlockReorderSrcSubLengths_NCHW = Sequence<4, 1, 1, 1>;
using InBlockReorderSrcClusterLengths_NCHW = Sequence<4, 8, 2, 2>;
using InBlockReorderMapThreadCluster2SrcCluster_CHNW2NCHW = Sequence<1, 2, 0, 3>;
constexpr index_t InBlockReorderDataPerRead_W = 1; // v1r3 cannot do vector load NCHW
constexpr index_t InBlockReorderDataPerWrite_N = 4;
using WeiBlockCopyClusterLengths = void;
constexpr index_t WeiBlockCopyDataPerRead_K = 4;
constexpr index_t OutThreadCopyDataPerWrite_W = 2;
#endif
constexpr index_t GridSize =
((N + NPerBlock - 1) / NPerBlock) * ((K + KPerBlock - 1) / KPerBlock) *
((Ho + HoPerBlock - 1) / HoPerBlock) * ((Wo + WoPerBlock - 1) / WoPerBlock);
printf("%s: BlockSize %u, GridSize %u \n", __func__, BlockSize, GridSize);
for(index_t i = 0; i < nrepeat; ++i)
{
constexpr auto gridwise_conv =
#if 0
GridwiseConvolutionImplicitGemm_v1r3_nchw_cyxk_nkhw
#else
GridwiseConvolutionImplicitGemm_v1r3_nchw_cyxk_nkhw_lds_double_buffer
#endif
<GridSize,
BlockSize,
T,
decltype(in_nchw_desc),
decltype(wei_cyxk_desc),
decltype(out_nkhw_desc),
NPerBlock,
KPerBlock,
CPerBlock,
HoPerBlock,
WoPerBlock,
NPerThread,
KPerThread,
HoPerThread,
WoPerThread,
GemmMPerThreadSubC,
GemmNPerThreadSubC,
GemmMLevel0Cluster,
GemmNLevel0Cluster,
GemmMLevel1Cluster,
GemmNLevel1Cluster,
GemmKPerThreadLoop,
GemmDataPerReadA,
GemmDataPerReadB,
InBlockReorderSrcSubLengths_NCHW,
InBlockReorderSrcClusterLengths_NCHW,
InBlockReorderMapThreadCluster2SrcCluster_CHNW2NCHW,
InBlockReorderDataPerRead_W,
InBlockReorderDataPerWrite_N,
WeiBlockCopyClusterLengths,
WeiBlockCopyDataPerRead_K,
OutThreadCopyDataPerWrite_W>{};
float time =
launch_and_time_kernel(run_gridwise_convolution_kernel<decltype(gridwise_conv), T>,
dim3(GridSize),
dim3(BlockSize),
0,
static_cast<T*>(in_nchw_device_buf.GetDeviceBuffer()),
static_cast<T*>(wei_cyxk_device_buf.GetDeviceBuffer()),
static_cast<T*>(out_nkhw_device_buf.GetDeviceBuffer()));
printf("Elapsed time : %f ms, %f TFlop/s\n",
time,
(float)calculate_convolution_flops(InDesc{}, WeiDesc{}, OutDesc{}) /
(std::size_t(1000) * 1000 * 1000) / time);
usleep(std::min(time * 1000, float(10000)));
}
out_nkhw_device_buf.FromDevice(out_nkhw.mData.data());
}
#pragma once
#include <unistd.h>
#include "device.hpp"
#include "tensor.hpp"
#include "gridwise_convolution_kernel_wrapper.hpp"
#include "gridwise_convolution_implicit_gemm_v2_chwn_cyxk_khwn.hpp"
#include "gridwise_convolution_implicit_gemm_v2_chwn_cyxk_khwn_lds_double_buffer.hpp"
using namespace ck;
template <class T, class InDesc, class WeiDesc, class OutDesc>
void device_convolution_implicit_gemm_v2_chwn_cyxk_khwn(InDesc,
const Tensor<T>& in_nchw,
WeiDesc,
const Tensor<T>& wei_kcyx,
OutDesc,
Tensor<T>& out_nkhw,
index_t nrepeat)
{
constexpr auto I0 = Number<0>{};
constexpr auto I1 = Number<1>{};
constexpr auto I2 = Number<2>{};
constexpr auto I3 = Number<3>{};
constexpr auto in_nchw_desc = InDesc{};
constexpr auto wei_kcyx_desc = WeiDesc{};
constexpr auto out_nkhw_desc = OutDesc{};
constexpr index_t N = in_nchw_desc.GetLength(I0);
constexpr index_t Hi = in_nchw_desc.GetLength(I2);
constexpr index_t Wi = in_nchw_desc.GetLength(I3);
constexpr index_t Ho = out_nkhw_desc.GetLength(I2);
constexpr index_t Wo = out_nkhw_desc.GetLength(I3);
constexpr index_t K = wei_kcyx_desc.GetLength(I0);
constexpr index_t C = wei_kcyx_desc.GetLength(I1);
constexpr index_t Y = wei_kcyx_desc.GetLength(I2);
constexpr index_t X = wei_kcyx_desc.GetLength(I3);
constexpr index_t BGhostRead = (Y - 1) * Wi + (X - 1);
// convert in_nchw to in_cnhw
auto in_chwn_desc = make_ConstantTensorDescriptor(Sequence<C, Hi, Wi, N>{});
ostream_ConstantTensorDescriptor(in_chwn_desc, std::cout << "in_chwn_desc: ");
Tensor<T> in_chwn(make_TensorDescriptor(in_chwn_desc));
make_ParallelTensorFunctor(
[&](auto n, auto c, auto hi, auto wi) { in_chwn(c, hi, wi, n) = in_nchw(n, c, hi, wi); },
N,
C,
Hi,
Wi)(std::thread::hardware_concurrency());
// convert wei_kcyx to wei_cyxk
auto wei_cyxk_desc = make_ConstantTensorDescriptor(Sequence<C, Y, X, K>{});
ostream_ConstantTensorDescriptor(wei_cyxk_desc, std::cout << "wei_cyxk_desc: ");
Tensor<T> wei_cyxk(make_TensorDescriptor(wei_cyxk_desc));
make_ParallelTensorFunctor(
[&](auto k, auto c, auto y, auto x) { wei_cyxk(c, y, x, k) = wei_kcyx(k, c, y, x); },
K,
C,
Y,
X)(std::thread::hardware_concurrency());
// conver out_nkhw to out_knhw
auto out_khwn_desc = make_ConstantTensorDescriptor(Sequence<K, Ho, Wo, N>{});
ostream_ConstantTensorDescriptor(out_khwn_desc, std::cout << "out_khwn_desc: ");
Tensor<T> out_khwn(make_TensorDescriptor(out_khwn_desc));
#if 0
// 3x3, 34x34
// need to use register double buffer for GEMM
constexpr index_t BPerBlock = 128;
constexpr index_t KPerBlock = 64;
constexpr index_t CPerBlock = 4;
constexpr index_t BPerThread = 8;
constexpr index_t KPerThread = 8;
constexpr index_t GemmMPerThreadSubC = 4;
constexpr index_t GemmNPerThreadSubC = 4;
constexpr index_t GemmMLevel0Cluster = 4;
constexpr index_t GemmNLevel0Cluster = 2;
constexpr index_t GemmMLevel1Cluster = 2;
constexpr index_t GemmNLevel1Cluster = 8;
constexpr index_t GemmKPerThreadLoop = 1;
constexpr index_t InBlockCopyThreadPerDim0 = 4;
constexpr index_t InBlockCopyThreadPerDim1 = 16;
constexpr index_t WeiBlockCopyThreadPerDim0 = 4;
constexpr index_t WeiBlockCopyThreadPerDim1 = 16;
constexpr index_t InBlockCopyDataPerRead = 4;
constexpr index_t WeiBlockCopyDataPerRead = 4;
constexpr index_t OutThreadCopyDataPerWrite = 4;
constexpr index_t BlockSize = 128;
#elif 0
// 1x1, 28x28, 64 threads
constexpr index_t BPerBlock = 64;
constexpr index_t KPerBlock = 64;
constexpr index_t CPerBlock = 8;
constexpr index_t BPerThread = 8;
constexpr index_t KPerThread = 8;
constexpr index_t GemmMPerThreadSubC = 4;
constexpr index_t GemmNPerThreadSubC = 4;
constexpr index_t GemmMLevel0Cluster = 4;
constexpr index_t GemmNLevel0Cluster = 2;
constexpr index_t GemmMLevel1Cluster = 2;
constexpr index_t GemmNLevel1Cluster = 4;
constexpr index_t GemmKPerThreadLoop = 1;
constexpr index_t GemmThreadPerColumnPerCluster = 8;
constexpr index_t GemmThreadPerRowPerCluster = 8;
constexpr index_t InBlockCopyThreadPerDim0 = 4;
constexpr index_t InBlockCopyThreadPerDim1 = 16;
constexpr index_t WeiBlockCopyThreadPerDim0 = 4;
constexpr index_t WeiBlockCopyThreadPerDim1 = 16;
constexpr index_t InBlockCopyDataPerRead = 4;
constexpr index_t WeiBlockCopyDataPerRead = 4;
constexpr index_t BlockSize = 64;
#elif 0
// 1x1, 28x28, 128 threads, no lds-double-buffer
// 1x1, 28x28, 128 threads, with lds-double-buffer, max_register = 128
constexpr index_t BPerBlock = 64;
constexpr index_t KPerBlock = 128;
constexpr index_t CPerBlock = 8;
constexpr index_t BPerThread = 8;
constexpr index_t KPerThread = 8;
constexpr index_t GemmMPerThreadSubC = 4;
constexpr index_t GemmNPerThreadSubC = 4;
constexpr index_t GemmMLevel0Cluster = 4;
constexpr index_t GemmNLevel0Cluster = 2;
constexpr index_t GemmMLevel1Cluster = 4;
constexpr index_t GemmNLevel1Cluster = 4;
constexpr index_t GemmKPerThreadLoop = 1;
constexpr index_t GemmThreadPerColumnPerCluster = 8;
constexpr index_t GemmThreadPerRowPerCluster = 8;
constexpr index_t InBlockCopyThreadPerDim0 = 4;
constexpr index_t InBlockCopyThreadPerDim1 = 16;
constexpr index_t WeiBlockCopyThreadPerDim0 = 4;
constexpr index_t WeiBlockCopyThreadPerDim1 = 16;
constexpr index_t InBlockCopyDataPerRead = 4;
constexpr index_t WeiBlockCopyDataPerRead = 4;
constexpr index_t BlockSize = 128;
#elif 0
// 1x1, 28x28, 256 thread
constexpr index_t BPerBlock = 128;
constexpr index_t KPerBlock = 128;
constexpr index_t CPerBlock = 8;
constexpr index_t BPerThread = 8;
constexpr index_t KPerThread = 8;
constexpr index_t GemmMPerThreadSubC = 4;
constexpr index_t GemmNPerThreadSubC = 4;
constexpr index_t GemmMLevel0Cluster = 4;
constexpr index_t GemmNLevel0Cluster = 4;
constexpr index_t GemmMLevel1Cluster = 4;
constexpr index_t GemmNLevel1Cluster = 4;
constexpr index_t GemmKPerThreadLoop = 1;
constexpr index_t GemmThreadPerColumnPerCluster = 8;
constexpr index_t GemmThreadPerRowPerCluster = 8;
constexpr index_t InBlockCopyThreadPerDim0 = 4;
constexpr index_t InBlockCopyThreadPerDim1 = 16;
constexpr index_t WeiBlockCopyThreadPerDim0 = 4;
constexpr index_t WeiBlockCopyThreadPerDim1 = 16;
constexpr index_t InBlockCopyDataPerRead = 4;
constexpr index_t WeiBlockCopyDataPerRead = 4;
constexpr index_t BlockSize = 256;
#elif 0
// 1x1, 14x14, Pascal, enable lds_double_buffer, disable register double buffer
constexpr index_t BPerBlock = 64;
constexpr index_t KPerBlock = 128;
constexpr index_t CPerBlock = 8;
constexpr index_t BPerThread = 8;
constexpr index_t KPerThread = 8;
constexpr index_t GemmMPerThreadSubC = 4;
constexpr index_t GemmNPerThreadSubC = 4;
constexpr index_t GemmMLevel0Cluster = 4;
constexpr index_t GemmNLevel0Cluster = 2;
constexpr index_t GemmMLevel1Cluster = 4;
constexpr index_t GemmNLevel1Cluster = 4;
constexpr index_t GemmKPerThreadLoop = 1;
constexpr index_t GemmDataPerReadA = 4;
constexpr index_t GemmDataPerReadB = 4;
constexpr index_t InBlockCopyThreadPerDim0 = 4;
constexpr index_t InBlockCopyThreadPerDim1 = 16;
constexpr index_t WeiBlockCopyThreadPerDim0 = 4;
constexpr index_t WeiBlockCopyThreadPerDim1 = 16;
constexpr index_t InBlockCopyDataPerRead = 4;
constexpr index_t WeiBlockCopyDataPerRead = 4;
constexpr index_t OutThreadCopyDataPerWrite = 4;
constexpr index_t BlockSize = 128;
#elif 1
// 1x1, 14x14, Vega 20, enable lds_double_buffer, disable register_double_buffer
constexpr index_t BPerBlock = 128;
constexpr index_t KPerBlock = 128;
constexpr index_t CPerBlock = 8;
constexpr index_t BPerThread = 8;
constexpr index_t KPerThread = 8;
constexpr index_t GemmMPerThreadSubC = 4;
constexpr index_t GemmNPerThreadSubC = 4;
constexpr index_t GemmMLevel0Cluster = 4;
constexpr index_t GemmNLevel0Cluster = 4;
constexpr index_t GemmMLevel1Cluster = 4;
constexpr index_t GemmNLevel1Cluster = 4;
constexpr index_t GemmKPerThreadLoop = 1;
constexpr index_t GemmDataPerReadA = 4;
constexpr index_t GemmDataPerReadB = 4;
constexpr index_t InBlockCopyThreadPerDim0 = 4;
constexpr index_t InBlockCopyThreadPerDim1 = 16;
constexpr index_t WeiBlockCopyThreadPerDim0 = 4;
constexpr index_t WeiBlockCopyThreadPerDim1 = 16;
constexpr index_t InBlockCopyDataPerRead = 4;
constexpr index_t WeiBlockCopyDataPerRead = 4;
constexpr index_t OutThreadCopyDataPerWrite = 4;
constexpr index_t BlockSize = 256;
#endif
constexpr index_t GridSize =
((N * Hi * Wi + BPerBlock - 1) / BPerBlock) * ((K + KPerBlock - 1) / KPerBlock);
printf("%s: BlockSize %u, GridSize %u \n", __func__, BlockSize, GridSize);
// mem
std::size_t data_sz = sizeof(T);
DeviceMem in_chwn_device_buf(data_sz * (in_chwn.mDesc.GetElementSpace() + BGhostRead +
BPerBlock)); // reserve extra space for BGhostRead
DeviceMem wei_cyxk_device_buf(data_sz * wei_cyxk.mDesc.GetElementSpace());
DeviceMem out_khwn_device_buf(data_sz * out_khwn.mDesc.GetElementSpace());
in_chwn_device_buf.ToDevice(in_chwn.mData.data());
wei_cyxk_device_buf.ToDevice(wei_cyxk.mData.data());
out_khwn_device_buf.ToDevice(out_khwn.mData.data());
for(index_t i = 0; i < nrepeat; ++i)
{
constexpr auto gridwise_conv =
#if 0
GridwiseConvolutionImplicitGemm_v2_chwn_cyxk_khwn
#else
GridwiseConvolutionImplicitGemm_v2_chwn_cyxk_khwn_lds_double_buffer
#endif
<GridSize,
BlockSize,
T,
decltype(in_chwn_desc),
decltype(wei_cyxk_desc),
decltype(out_khwn_desc),
BPerBlock,
KPerBlock,
CPerBlock,
BPerThread,
KPerThread,
GemmMPerThreadSubC,
GemmNPerThreadSubC,
GemmMLevel0Cluster,
GemmNLevel0Cluster,
GemmMLevel1Cluster,
GemmNLevel1Cluster,
GemmKPerThreadLoop,
GemmDataPerReadA,
GemmDataPerReadB,
InBlockCopyThreadPerDim0,
InBlockCopyThreadPerDim1,
WeiBlockCopyThreadPerDim0,
WeiBlockCopyThreadPerDim1,
InBlockCopyDataPerRead,
WeiBlockCopyDataPerRead,
OutThreadCopyDataPerWrite>{};
float time =
launch_and_time_kernel(run_gridwise_convolution_kernel<decltype(gridwise_conv), T>,
dim3(GridSize),
dim3(BlockSize),
0,
static_cast<T*>(in_chwn_device_buf.GetDeviceBuffer()),
static_cast<T*>(wei_cyxk_device_buf.GetDeviceBuffer()),
static_cast<T*>(out_khwn_device_buf.GetDeviceBuffer()));
printf("Elapsed time : %f ms, %f TFlop/s\n",
time,
(float)calculate_convolution_flops(InDesc{}, WeiDesc{}, OutDesc{}) /
(std::size_t(1000) * 1000 * 1000) / time);
usleep(std::min(time * 1000, float(10000)));
}
out_khwn_device_buf.FromDevice(out_khwn.mData.data());
// convert out_khwn to out_nkhw
make_ParallelTensorFunctor(
[&](auto n, auto k, auto ho, auto wo) { out_nkhw(n, k, ho, wo) = out_khwn(k, ho, wo, n); },
N,
K,
Ho,
Wo)(std::thread::hardware_concurrency());
}
#pragma once
#include <unistd.h>
#include "device.hpp"
#include "tensor.hpp"
#include "gridwise_convolution_kernel_wrapper.hpp"
#include "gridwise_convolution_implicit_gemm_v3_nchw_cyxk_nkhw.hpp"
#include "gridwise_convolution_implicit_gemm_v3_nchw_cyxk_nkhw_lds_double_buffer.hpp"
template <class T, class InDesc, class WeiDesc, class OutDesc>
void device_convolution_implicit_gemm_v3_nchw_cyxk_nkhw(InDesc,
const Tensor<T>& in_nchw,
WeiDesc,
const Tensor<T>& wei_kcyx,
OutDesc,
Tensor<T>& out_nkhw,
index_t nrepeat)
{
using namespace ck;
constexpr auto I0 = Number<0>{};
constexpr auto I1 = Number<1>{};
constexpr auto I2 = Number<2>{};
constexpr auto I3 = Number<3>{};
constexpr auto in_nchw_desc = InDesc{};
constexpr auto wei_kcyx_desc = WeiDesc{};
constexpr auto out_nkhw_desc = OutDesc{};
constexpr index_t Hi = in_nchw_desc.GetLength(I2);
constexpr index_t Wi = in_nchw_desc.GetLength(I3);
constexpr index_t N = out_nkhw_desc.GetLength(I0);
constexpr index_t Ho = out_nkhw_desc.GetLength(I2);
constexpr index_t Wo = out_nkhw_desc.GetLength(I3);
constexpr index_t K = wei_kcyx_desc.GetLength(I0);
constexpr index_t C = wei_kcyx_desc.GetLength(I1);
constexpr index_t Y = wei_kcyx_desc.GetLength(I2);
constexpr index_t X = wei_kcyx_desc.GetLength(I3);
// reorder weight
auto wei_cyxk_desc = make_ConstantTensorDescriptor_packed(Sequence<C, Y, X, K>{});
ostream_ConstantTensorDescriptor(wei_cyxk_desc, std::cout << "wei_cyxk_desc: ");
Tensor<T> wei_cyxk(make_TensorDescriptor(wei_cyxk_desc));
auto f_reorder_kcyx2cyxk = [&](auto k, auto c, auto y, auto x) {
wei_cyxk(c, y, x, k) = wei_kcyx(k, c, y, x);
};
make_ParallelTensorFunctor(f_reorder_kcyx2cyxk, K, C, Y, X)(
std::thread::hardware_concurrency());
std::size_t data_sz = sizeof(T);
DeviceMem in_nchw_device_buf(data_sz * in_nchw.mDesc.GetElementSpace());
DeviceMem wei_cyxk_device_buf(data_sz * wei_cyxk.mDesc.GetElementSpace());
DeviceMem out_nkhw_device_buf(data_sz * out_nkhw.mDesc.GetElementSpace());
in_nchw_device_buf.ToDevice(in_nchw.mData.data());
wei_cyxk_device_buf.ToDevice(wei_cyxk.mData.data());
out_nkhw_device_buf.ToDevice(out_nkhw.mData.data());
constexpr index_t N1 = 2;
constexpr index_t N2 = 4;
constexpr index_t B = (N * Ho * Wo) / (N1 * N2);
#if 1
constexpr index_t BlockSize = 256;
constexpr index_t BPerBlock = 16;
constexpr index_t KPerBlock = 128;
constexpr index_t CPerBlock = 8;
constexpr index_t GemmMPerThreadSubC = 4;
constexpr index_t GemmNPerThreadSubC = 4;
constexpr index_t GemmMLevel0Cluster = 4;
constexpr index_t GemmNLevel0Cluster = 4;
constexpr index_t GemmMLevel1Cluster = 4;
constexpr index_t GemmNLevel1Cluster = 4;
constexpr index_t GemmKPerThreadLoop = 1;
constexpr index_t GemmDataPerReadA = 4;
constexpr index_t GemmDataPerReadB = 4;
using InBlockCopySubLengths_C_N1_B_N2 = Sequence<1, 1, 1, 4>;
using InBlockCopyClusterLengths_C_N1_B_N2 = Sequence<8, 2, 16, 1>;
constexpr index_t InBlockCopySrcDataPerRead_B = 1;
constexpr index_t InBlockCopyDstDataPerWrite_N2 = 4;
using WeiBlockCopySubLengths_C_K = Sequence<1, 4>;
using WeiBlockCopyClusterLengths_C_K = Sequence<8, 32>;
constexpr index_t WeiBlockCopyDataPerAccess_K = 4;
#endif
constexpr index_t GridSize =
((B + BPerBlock - 1) / BPerBlock) * ((K + KPerBlock - 1) / KPerBlock);
printf("%s: BlockSize %u, GridSize %u \n", __func__, BlockSize, GridSize);
for(index_t i = 0; i < nrepeat; ++i)
{
constexpr auto gridwise_conv =
#if 0
GridwiseConvolutionImplicitGemm_v3_nchw_cyxk_nkhw
#else
GridwiseConvolutionImplicitGemm_v3_nchw_cyxk_nkhw_lds_double_buffer
#endif
<GridSize,
BlockSize,
T,
decltype(in_nchw_desc),
decltype(wei_cyxk_desc),
decltype(out_nkhw_desc),
BPerBlock,
KPerBlock,
CPerBlock,
N1,
N2,
GemmMPerThreadSubC,
GemmNPerThreadSubC,
GemmMLevel0Cluster,
GemmNLevel0Cluster,
GemmMLevel1Cluster,
GemmNLevel1Cluster,
GemmKPerThreadLoop,
GemmDataPerReadA,
GemmDataPerReadB,
InBlockCopySubLengths_C_N1_B_N2,
InBlockCopyClusterLengths_C_N1_B_N2,
InBlockCopySrcDataPerRead_B,
InBlockCopyDstDataPerWrite_N2,
WeiBlockCopySubLengths_C_K,
WeiBlockCopyClusterLengths_C_K,
WeiBlockCopyDataPerAccess_K>{};
float time =
launch_and_time_kernel(run_gridwise_convolution_kernel<decltype(gridwise_conv), T>,
dim3(GridSize),
dim3(BlockSize),
0,
static_cast<T*>(in_nchw_device_buf.GetDeviceBuffer()),
static_cast<T*>(wei_cyxk_device_buf.GetDeviceBuffer()),
static_cast<T*>(out_nkhw_device_buf.GetDeviceBuffer()));
printf("Elapsed time : %f ms, %f TFlop/s\n",
time,
(float)calculate_convolution_flops(InDesc{}, WeiDesc{}, OutDesc{}) /
(std::size_t(1000) * 1000 * 1000) / time);
usleep(std::min(time * 1000, float(10000)));
}
out_nkhw_device_buf.FromDevice(out_nkhw.mData.data());
}
#pragma once #pragma once
#include <unistd.h> #include <unistd.h>
#include "device.hpp" #include "device.hpp"
#include "tensor.hpp" #include "host_tensor.hpp"
#include "gridwise_operation_wrapper.hpp" #include "gridwise_operation_wrapper.hpp"
#include "convolution_common.hpp"
#include "gridwise_convolution_implicit_gemm_v4r1_nchw_kcyx_nkhw_lds_double_buffer.hpp" #include "gridwise_convolution_implicit_gemm_v4r1_nchw_kcyx_nkhw_lds_double_buffer.hpp"
template <typename T, template <typename T,
...@@ -28,6 +27,8 @@ void device_convolution_implicit_gemm_v4r1_nchw_kcyx_nkhw(InDesc, ...@@ -28,6 +27,8 @@ void device_convolution_implicit_gemm_v4r1_nchw_kcyx_nkhw(InDesc,
{ {
using namespace ck; using namespace ck;
using TDevice = typename conditional<is_same<half_float::half, T>::value, half_t, T>::type;
constexpr auto I0 = Number<0>{}; constexpr auto I0 = Number<0>{};
constexpr auto I1 = Number<1>{}; constexpr auto I1 = Number<1>{};
constexpr auto I2 = Number<2>{}; constexpr auto I2 = Number<2>{};
...@@ -55,25 +56,105 @@ void device_convolution_implicit_gemm_v4r1_nchw_kcyx_nkhw(InDesc, ...@@ -55,25 +56,105 @@ void device_convolution_implicit_gemm_v4r1_nchw_kcyx_nkhw(InDesc,
out_nkhw_device_buf.ToDevice(out_nkhw.mData.data()); out_nkhw_device_buf.ToDevice(out_nkhw.mData.data());
#if 0 #if 0
// BlockSize = 256, EperBlock = 8, each thread hold 64 data // cdata = 64, BlockSize = 256, 64x256x8
constexpr index_t BlockSize = 256; constexpr index_t BlockSize = 256;
constexpr index_t BPerBlock = 16; constexpr index_t KPerBlock = 64;
constexpr index_t KPerBlock = 128; constexpr index_t BPerBlock = 32;
constexpr index_t EPerBlock = 8; constexpr index_t EPerBlock = 8;
constexpr index_t GemmNRepeat = 2; constexpr index_t GemmNRepeat = 2;
constexpr index_t GemmMPerThreadSubC = 4; constexpr index_t GemmMPerThread = 4;
constexpr index_t GemmNPerThreadSubC = 4; constexpr index_t GemmNPerThread = 4;
constexpr index_t GemmMLevel0Cluster = 4; constexpr index_t GemmKPerThread = 1;
constexpr index_t GemmNLevel0Cluster = 4;
constexpr index_t GemmMLevel0Cluster = 2;
constexpr index_t GemmNLevel0Cluster = 2;
constexpr index_t GemmMLevel1Cluster = 4; constexpr index_t GemmMLevel1Cluster = 4;
constexpr index_t GemmNLevel1Cluster = 4; constexpr index_t GemmNLevel1Cluster = 16;
constexpr index_t GemmKPerThreadLoop = 1;
constexpr index_t GemmDataPerReadA = 4; constexpr index_t GemmDataPerReadA = 4;
constexpr index_t GemmDataPerReadB = 4; constexpr index_t GemmDataPerReadB = 4;
using InBlockCopySubLengths_E_N1_B_N2 = Sequence<1, 2, 1, 4>;
using InBlockCopyClusterLengths_E_N1_B_N2 = Sequence<8, 1, 32, 1>;
using InBlockCopyThreadClusterArrangeOrder = Sequence<0, 1, 3, 2>; // [E, N1, N2, B]
using InBlockCopySrcAccessOrder = Sequence<0, 2, 1, 3>; // [E, B, N1, N2]
using InBlockCopyDstAccessOrder = Sequence<0, 1, 2, 3>; // [E, N1, B, N2]
constexpr index_t InBlockCopySrcDataPerRead_B = 1;
constexpr index_t InBlockCopyDstDataPerWrite_N2 = 4;
using WeiBlockCopySubLengths_E_K = Sequence<2, 1>;
using WeiBlockCopyClusterLengths_E_K = Sequence<4, 64>;
using WeiBlockCopyThreadClusterArrangeOrder = Sequence<1, 0>; // [K, E]
using WeiBlockCopySrcAccessOrder = Sequence<1, 0>; // [K, E]
using WeiBlockCopyDstAccessOrder = Sequence<0, 1>; // [E, K]
constexpr index_t WeiBlockCopySrcDataPerRead_E = 2;
constexpr index_t WeiBlockCopyDstDataPerWrite_K = 1;
#elif 0
// cdata = 64, BlockSize = 256, 128x128x4
constexpr index_t BlockSize = 256;
constexpr index_t KPerBlock = 128;
constexpr index_t BPerBlock = 16;
constexpr index_t EPerBlock = 4;
constexpr index_t GemmNRepeat = 2;
constexpr index_t GemmMPerThread = 4;
constexpr index_t GemmNPerThread = 4;
constexpr index_t GemmKPerThread = 1;
constexpr index_t GemmMLevel0Cluster = 2;
constexpr index_t GemmNLevel0Cluster = 2;
constexpr index_t GemmMLevel1Cluster = 8;
constexpr index_t GemmNLevel1Cluster = 8;
constexpr index_t GemmDataPerReadA = 4;
constexpr index_t GemmDataPerReadB = 4;
using InBlockCopySubLengths_E_N1_B_N2 = Sequence<1, 1, 1, 2>;
using InBlockCopyClusterLengths_E_N1_B_N2 = Sequence<4, 2, 16, 2>;
using InBlockCopyThreadClusterArrangeOrder = Sequence<0, 1, 3, 2>; // [E, N1, N2, B]
using InBlockCopySrcAccessOrder = Sequence<0, 2, 1, 3>; // [E, B, N1, N2]
using InBlockCopyDstAccessOrder = Sequence<0, 1, 2, 3>; // [E, N1, B, N2]
constexpr index_t InBlockCopySrcDataPerRead_B = 1;
constexpr index_t InBlockCopyDstDataPerWrite_N2 = 2;
using WeiBlockCopySubLengths_E_K = Sequence<2, 1>;
using WeiBlockCopyClusterLengths_E_K = Sequence<2, 128>;
using WeiBlockCopyThreadClusterArrangeOrder = Sequence<1, 0>; // [K, E]
using WeiBlockCopySrcAccessOrder = Sequence<1, 0>; // [K, E]
using WeiBlockCopyDstAccessOrder = Sequence<0, 1>; // [E, K]
constexpr index_t WeiBlockCopySrcDataPerRead_E = 2;
constexpr index_t WeiBlockCopyDstDataPerWrite_K = 1;
#elif 1
// cdata = 64, BlockSize = 256, 128x128x8
constexpr index_t BlockSize = 256;
constexpr index_t KPerBlock = 128;
constexpr index_t BPerBlock = 16;
constexpr index_t EPerBlock = 8;
constexpr index_t GemmNRepeat = 2;
constexpr index_t GemmMPerThread = 4;
constexpr index_t GemmNPerThread = 4;
constexpr index_t GemmKPerThread = 1;
constexpr index_t GemmMLevel0Cluster = 2;
constexpr index_t GemmNLevel0Cluster = 2;
constexpr index_t GemmMLevel1Cluster = 8;
constexpr index_t GemmNLevel1Cluster = 8;
constexpr index_t GemmDataPerReadA = 4;
constexpr index_t GemmDataPerReadB = 4;
using InBlockCopySubLengths_E_N1_B_N2 = Sequence<1, 1, 1, 4>; using InBlockCopySubLengths_E_N1_B_N2 = Sequence<1, 1, 1, 4>;
using InBlockCopyClusterLengths_E_N1_B_N2 = Sequence<8, 2, 16, 1>; using InBlockCopyClusterLengths_E_N1_B_N2 = Sequence<8, 2, 16, 1>;
using InBlockCopyThreadClusterArrangeOrder = Sequence<0, 1, 3, 2>; // [E, N1, N2, B] using InBlockCopyThreadClusterArrangeOrder = Sequence<0, 1, 3, 2>; // [E, N1, N2, B]
...@@ -91,25 +172,27 @@ void device_convolution_implicit_gemm_v4r1_nchw_kcyx_nkhw(InDesc, ...@@ -91,25 +172,27 @@ void device_convolution_implicit_gemm_v4r1_nchw_kcyx_nkhw(InDesc,
constexpr index_t WeiBlockCopySrcDataPerRead_E = 4; constexpr index_t WeiBlockCopySrcDataPerRead_E = 4;
constexpr index_t WeiBlockCopyDstDataPerWrite_K = 1; constexpr index_t WeiBlockCopyDstDataPerWrite_K = 1;
#elif 1 #elif 0
// BlockSize = 256, EPerBlock = 16, each thread hold 64 data // cdata = 64, BlockSize = 256, 128x128x16
constexpr index_t BlockSize = 256; constexpr index_t BlockSize = 256;
constexpr index_t BPerBlock = 16;
constexpr index_t KPerBlock = 128; constexpr index_t KPerBlock = 128;
constexpr index_t BPerBlock = 16;
constexpr index_t EPerBlock = 16; constexpr index_t EPerBlock = 16;
constexpr index_t GemmNRepeat = 2; constexpr index_t GemmNRepeat = 2;
constexpr index_t GemmMPerThreadSubC = 4; constexpr index_t GemmMPerThread = 4;
constexpr index_t GemmNPerThreadSubC = 4; constexpr index_t GemmNPerThread = 4;
constexpr index_t GemmKPerThread = 1;
constexpr index_t GemmMLevel0Cluster = 4; constexpr index_t GemmMLevel0Cluster = 4;
constexpr index_t GemmNLevel0Cluster = 4; constexpr index_t GemmNLevel0Cluster = 4;
constexpr index_t GemmMLevel1Cluster = 4; constexpr index_t GemmMLevel1Cluster = 4;
constexpr index_t GemmNLevel1Cluster = 4; constexpr index_t GemmNLevel1Cluster = 4;
constexpr index_t GemmKPerThreadLoop = 1;
constexpr index_t GemmDataPerReadA = 4; constexpr index_t GemmDataPerReadA = 4;
constexpr index_t GemmDataPerReadB = 4; constexpr index_t GemmDataPerReadB = 4;
using InBlockCopySubLengths_E_N1_B_N2 = Sequence<1, 2, 1, 4>; using InBlockCopySubLengths_E_N1_B_N2 = Sequence<1, 2, 1, 4>;
using InBlockCopyClusterLengths_E_N1_B_N2 = Sequence<16, 1, 16, 1>; using InBlockCopyClusterLengths_E_N1_B_N2 = Sequence<16, 1, 16, 1>;
...@@ -128,26 +211,28 @@ void device_convolution_implicit_gemm_v4r1_nchw_kcyx_nkhw(InDesc, ...@@ -128,26 +211,28 @@ void device_convolution_implicit_gemm_v4r1_nchw_kcyx_nkhw(InDesc,
constexpr index_t WeiBlockCopySrcDataPerRead_E = 4; constexpr index_t WeiBlockCopySrcDataPerRead_E = 4;
constexpr index_t WeiBlockCopyDstDataPerWrite_K = 2; constexpr index_t WeiBlockCopyDstDataPerWrite_K = 2;
#elif 1 #elif 0
// BlockSize = 256, EPerBlock = 16, each thread hold 64 data // cdata = 4, BlockSize = 256, 128x128x16
// for 1x1 // for 1x1
constexpr index_t BlockSize = 256; constexpr index_t BlockSize = 256;
constexpr index_t BPerBlock = 16;
constexpr index_t KPerBlock = 128; constexpr index_t KPerBlock = 128;
constexpr index_t BPerBlock = 16;
constexpr index_t EPerBlock = 16; constexpr index_t EPerBlock = 16;
constexpr index_t GemmNRepeat = 2; constexpr index_t GemmNRepeat = 2;
constexpr index_t GemmMPerThreadSubC = 4; constexpr index_t GemmMPerThread = 4;
constexpr index_t GemmNPerThreadSubC = 4; constexpr index_t GemmNPerThread = 4;
constexpr index_t GemmKPerThread = 1;
constexpr index_t GemmMLevel0Cluster = 4; constexpr index_t GemmMLevel0Cluster = 4;
constexpr index_t GemmNLevel0Cluster = 4; constexpr index_t GemmNLevel0Cluster = 4;
constexpr index_t GemmMLevel1Cluster = 4; constexpr index_t GemmMLevel1Cluster = 4;
constexpr index_t GemmNLevel1Cluster = 4; constexpr index_t GemmNLevel1Cluster = 4;
constexpr index_t GemmKPerThreadLoop = 1;
constexpr index_t GemmDataPerReadA = 4; constexpr index_t GemmDataPerReadA = 4;
constexpr index_t GemmDataPerReadB = 4; constexpr index_t GemmDataPerReadB = 4;
using InBlockCopySubLengths_E_N1_B_N2 = Sequence<4, 1, 1, 2>; using InBlockCopySubLengths_E_N1_B_N2 = Sequence<4, 1, 1, 2>;
using InBlockCopyClusterLengths_E_N1_B_N2 = Sequence<4, 2, 16, 2>; using InBlockCopyClusterLengths_E_N1_B_N2 = Sequence<4, 2, 16, 2>;
...@@ -166,25 +251,261 @@ void device_convolution_implicit_gemm_v4r1_nchw_kcyx_nkhw(InDesc, ...@@ -166,25 +251,261 @@ void device_convolution_implicit_gemm_v4r1_nchw_kcyx_nkhw(InDesc,
constexpr index_t WeiBlockCopySrcDataPerRead_E = 4; constexpr index_t WeiBlockCopySrcDataPerRead_E = 4;
constexpr index_t WeiBlockCopyDstDataPerWrite_K = 2; constexpr index_t WeiBlockCopyDstDataPerWrite_K = 2;
#elif 1 #elif 0
// BlockSize = 64, each thread hold 64 data // cdata = 64, BlockSize = 128, 64x128x4
constexpr index_t BlockSize = 64; constexpr index_t BlockSize = 128;
constexpr index_t KPerBlock = 64;
constexpr index_t BPerBlock = 16;
constexpr index_t EPerBlock = 4;
constexpr index_t GemmNRepeat = 2;
constexpr index_t GemmMPerThread = 4;
constexpr index_t GemmNPerThread = 4;
constexpr index_t GemmKPerThread = 1;
constexpr index_t GemmMLevel0Cluster = 2;
constexpr index_t GemmNLevel0Cluster = 2;
constexpr index_t GemmMLevel1Cluster = 4;
constexpr index_t GemmNLevel1Cluster = 8;
constexpr index_t GemmDataPerReadA = 4;
constexpr index_t GemmDataPerReadB = 4;
using InBlockCopySubLengths_E_N1_B_N2 = Sequence<1, 1, 1, 4>;
using InBlockCopyClusterLengths_E_N1_B_N2 = Sequence<4, 2, 16, 1>;
using InBlockCopyThreadClusterArrangeOrder = Sequence<0, 1, 3, 2>; // [E, N1, N2, B]
using InBlockCopySrcAccessOrder = Sequence<0, 2, 1, 3>; // [E, B, N1, N2]
using InBlockCopyDstAccessOrder = Sequence<0, 1, 2, 3>; // [E, N1, B, N2]
constexpr index_t InBlockCopySrcDataPerRead_B = 1;
constexpr index_t InBlockCopyDstDataPerWrite_N2 = 4;
using WeiBlockCopySubLengths_E_K = Sequence<2, 1>;
using WeiBlockCopyClusterLengths_E_K = Sequence<2, 64>;
using WeiBlockCopyThreadClusterArrangeOrder = Sequence<1, 0>; // [K, E]
using WeiBlockCopySrcAccessOrder = Sequence<1, 0>; // [K, E]
using WeiBlockCopyDstAccessOrder = Sequence<0, 1>; // [E, K]
constexpr index_t WeiBlockCopySrcDataPerRead_E = 2;
constexpr index_t WeiBlockCopyDstDataPerWrite_K = 1;
#elif 0
// cdata = 64, BlockSize = 128, 64x128x8
constexpr index_t BlockSize = 128;
constexpr index_t KPerBlock = 64;
constexpr index_t BPerBlock = 16;
constexpr index_t EPerBlock = 8;
constexpr index_t GemmNRepeat = 2;
constexpr index_t GemmMPerThread = 4;
constexpr index_t GemmNPerThread = 4;
constexpr index_t GemmKPerThread = 1;
constexpr index_t GemmMLevel0Cluster = 2;
constexpr index_t GemmNLevel0Cluster = 2;
constexpr index_t GemmMLevel1Cluster = 4;
constexpr index_t GemmNLevel1Cluster = 8;
constexpr index_t GemmDataPerReadA = 4;
constexpr index_t GemmDataPerReadB = 4;
using InBlockCopySubLengths_E_N1_B_N2 = Sequence<1, 2, 1, 4>;
using InBlockCopyClusterLengths_E_N1_B_N2 = Sequence<8, 1, 16, 1>;
using InBlockCopyThreadClusterArrangeOrder = Sequence<0, 1, 3, 2>; // [E, N1, N2, B]
using InBlockCopySrcAccessOrder = Sequence<0, 2, 1, 3>; // [E, B, N1, N2]
using InBlockCopyDstAccessOrder = Sequence<0, 1, 2, 3>; // [E, N1, B, N2]
constexpr index_t InBlockCopySrcDataPerRead_B = 1;
constexpr index_t InBlockCopyDstDataPerWrite_N2 = 4;
using WeiBlockCopySubLengths_E_K = Sequence<4, 1>;
using WeiBlockCopyClusterLengths_E_K = Sequence<2, 64>;
using WeiBlockCopyThreadClusterArrangeOrder = Sequence<1, 0>; // [K, E]
using WeiBlockCopySrcAccessOrder = Sequence<1, 0>; // [K, E]
using WeiBlockCopyDstAccessOrder = Sequence<0, 1>; // [E, K]
constexpr index_t WeiBlockCopySrcDataPerRead_E = 4;
constexpr index_t WeiBlockCopyDstDataPerWrite_K = 1;
#elif 0
// cdata = 64, BlockSize = 128, 64x128x16
constexpr index_t BlockSize = 128;
constexpr index_t KPerBlock = 64;
constexpr index_t BPerBlock = 16;
constexpr index_t EPerBlock = 16;
constexpr index_t GemmNRepeat = 2;
constexpr index_t GemmMPerThread = 4;
constexpr index_t GemmNPerThread = 4;
constexpr index_t GemmKPerThread = 1;
constexpr index_t GemmMLevel0Cluster = 4;
constexpr index_t GemmNLevel0Cluster = 4;
constexpr index_t GemmMLevel1Cluster = 2;
constexpr index_t GemmNLevel1Cluster = 4;
constexpr index_t GemmDataPerReadA = 4;
constexpr index_t GemmDataPerReadB = 4;
using InBlockCopySubLengths_E_N1_B_N2 = Sequence<2, 2, 1, 4>;
using InBlockCopyClusterLengths_E_N1_B_N2 = Sequence<8, 1, 16, 1>;
using InBlockCopyThreadClusterArrangeOrder = Sequence<0, 1, 3, 2>; // [E, N1, N2, B]
using InBlockCopySrcAccessOrder = Sequence<0, 2, 1, 3>; // [E, B, N1, N2]
using InBlockCopyDstAccessOrder = Sequence<0, 1, 2, 3>; // [E, N1, B, N2]
constexpr index_t InBlockCopySrcDataPerRead_B = 1;
constexpr index_t InBlockCopyDstDataPerWrite_N2 = 4;
using WeiBlockCopySubLengths_E_K = Sequence<4, 2>;
using WeiBlockCopyClusterLengths_E_K = Sequence<4, 32>;
using WeiBlockCopyThreadClusterArrangeOrder = Sequence<1, 0>; // [K, E]
using WeiBlockCopySrcAccessOrder = Sequence<1, 0>; // [K, E]
using WeiBlockCopyDstAccessOrder = Sequence<0, 1>; // [E, K]
constexpr index_t WeiBlockCopySrcDataPerRead_E = 4;
constexpr index_t WeiBlockCopyDstDataPerWrite_K = 2;
#elif 0
// cdata = 64, BlockSize = 128, 128x64x4
constexpr index_t BlockSize = 128;
constexpr index_t KPerBlock = 128;
constexpr index_t BPerBlock = 8; constexpr index_t BPerBlock = 8;
constexpr index_t EPerBlock = 4;
constexpr index_t GemmNRepeat = 2;
constexpr index_t GemmMPerThread = 4;
constexpr index_t GemmNPerThread = 4;
constexpr index_t GemmKPerThread = 1;
constexpr index_t GemmMLevel0Cluster = 2;
constexpr index_t GemmNLevel0Cluster = 2;
constexpr index_t GemmMLevel1Cluster = 8;
constexpr index_t GemmNLevel1Cluster = 4;
constexpr index_t GemmDataPerReadA = 4;
constexpr index_t GemmDataPerReadB = 4;
using InBlockCopySubLengths_E_N1_B_N2 = Sequence<1, 1, 1, 2>;
using InBlockCopyClusterLengths_E_N1_B_N2 = Sequence<4, 2, 8, 2>;
using InBlockCopyThreadClusterArrangeOrder = Sequence<0, 1, 3, 2>; // [E, N1, N2, B]
using InBlockCopySrcAccessOrder = Sequence<0, 2, 1, 3>; // [E, B, N1, N2]
using InBlockCopyDstAccessOrder = Sequence<0, 1, 2, 3>; // [E, N1, B, N2]
constexpr index_t InBlockCopySrcDataPerRead_B = 1;
constexpr index_t InBlockCopyDstDataPerWrite_N2 = 2;
using WeiBlockCopySubLengths_E_K = Sequence<2, 2>;
using WeiBlockCopyClusterLengths_E_K = Sequence<2, 64>;
using WeiBlockCopyThreadClusterArrangeOrder = Sequence<1, 0>; // [K, E]
using WeiBlockCopySrcAccessOrder = Sequence<1, 0>; // [K, E]
using WeiBlockCopyDstAccessOrder = Sequence<0, 1>; // [E, K]
constexpr index_t WeiBlockCopySrcDataPerRead_E = 2;
constexpr index_t WeiBlockCopyDstDataPerWrite_K = 2;
#elif 0
// cdata = 64, BlockSize = 128, 128x64x8
constexpr index_t BlockSize = 128;
constexpr index_t KPerBlock = 128;
constexpr index_t BPerBlock = 8;
constexpr index_t EPerBlock = 8;
constexpr index_t GemmNRepeat = 2;
constexpr index_t GemmMPerThread = 4;
constexpr index_t GemmNPerThread = 4;
constexpr index_t GemmKPerThread = 1;
constexpr index_t GemmMLevel0Cluster = 2;
constexpr index_t GemmNLevel0Cluster = 2;
constexpr index_t GemmMLevel1Cluster = 8;
constexpr index_t GemmNLevel1Cluster = 4;
constexpr index_t GemmDataPerReadA = 4;
constexpr index_t GemmDataPerReadB = 4;
using InBlockCopySubLengths_E_N1_B_N2 = Sequence<1, 1, 1, 4>;
using InBlockCopyClusterLengths_E_N1_B_N2 = Sequence<8, 2, 8, 1>;
using InBlockCopyThreadClusterArrangeOrder = Sequence<0, 1, 3, 2>; // [E, N1, N2, B]
using InBlockCopySrcAccessOrder = Sequence<0, 2, 1, 3>; // [E, B, N1, N2]
using InBlockCopyDstAccessOrder = Sequence<0, 1, 2, 3>; // [E, N1, B, N2]
constexpr index_t InBlockCopySrcDataPerRead_B = 1;
constexpr index_t InBlockCopyDstDataPerWrite_N2 = 4;
using WeiBlockCopySubLengths_E_K = Sequence<4, 2>;
using WeiBlockCopyClusterLengths_E_K = Sequence<2, 64>;
using WeiBlockCopyThreadClusterArrangeOrder = Sequence<1, 0>; // [K, E]
using WeiBlockCopySrcAccessOrder = Sequence<1, 0>; // [K, E]
using WeiBlockCopyDstAccessOrder = Sequence<0, 1>; // [E, K]
constexpr index_t WeiBlockCopySrcDataPerRead_E = 4;
constexpr index_t WeiBlockCopyDstDataPerWrite_K = 2;
#elif 0
// cdata = 64, BlockSize = 128, 128x64x16
constexpr index_t BlockSize = 128;
constexpr index_t KPerBlock = 128;
constexpr index_t BPerBlock = 8;
constexpr index_t EPerBlock = 16;
constexpr index_t GemmNRepeat = 2;
constexpr index_t GemmMPerThread = 4;
constexpr index_t GemmNPerThread = 4;
constexpr index_t GemmKPerThread = 1;
constexpr index_t GemmMLevel0Cluster = 2;
constexpr index_t GemmNLevel0Cluster = 2;
constexpr index_t GemmMLevel1Cluster = 8;
constexpr index_t GemmNLevel1Cluster = 4;
constexpr index_t GemmDataPerReadA = 4;
constexpr index_t GemmDataPerReadB = 4;
using InBlockCopySubLengths_E_N1_B_N2 = Sequence<1, 2, 1, 4>;
using InBlockCopyClusterLengths_E_N1_B_N2 = Sequence<16, 1, 8, 1>;
using InBlockCopyThreadClusterArrangeOrder = Sequence<0, 1, 3, 2>; // [E, N1, N2, B]
using InBlockCopySrcAccessOrder = Sequence<0, 2, 1, 3>; // [E, B, N1, N2]
using InBlockCopyDstAccessOrder = Sequence<0, 1, 2, 3>; // [E, N1, B, N2]
constexpr index_t InBlockCopySrcDataPerRead_B = 1;
constexpr index_t InBlockCopyDstDataPerWrite_N2 = 4;
using WeiBlockCopySubLengths_E_K = Sequence<4, 4>;
using WeiBlockCopyClusterLengths_E_K = Sequence<4, 32>;
using WeiBlockCopyThreadClusterArrangeOrder = Sequence<1, 0>; // [K, E]
using WeiBlockCopySrcAccessOrder = Sequence<1, 0>; // [K, E]
using WeiBlockCopyDstAccessOrder = Sequence<0, 1>; // [E, K]
constexpr index_t WeiBlockCopySrcDataPerRead_E = 4;
constexpr index_t WeiBlockCopyDstDataPerWrite_K = 4;
#elif 0
// cdata = 64, BlockSize = 64, 64x64x8
constexpr index_t BlockSize = 64;
constexpr index_t KPerBlock = 64; constexpr index_t KPerBlock = 64;
constexpr index_t BPerBlock = 8;
constexpr index_t EPerBlock = 8; constexpr index_t EPerBlock = 8;
constexpr index_t GemmNRepeat = 2; constexpr index_t GemmNRepeat = 2;
constexpr index_t GemmMPerThreadSubC = 4; constexpr index_t GemmMPerThread = 4;
constexpr index_t GemmNPerThreadSubC = 4; constexpr index_t GemmNPerThread = 4;
constexpr index_t GemmKPerThread = 1;
constexpr index_t GemmMLevel0Cluster = 4; constexpr index_t GemmMLevel0Cluster = 4;
constexpr index_t GemmNLevel0Cluster = 4; constexpr index_t GemmNLevel0Cluster = 4;
constexpr index_t GemmMLevel1Cluster = 2; constexpr index_t GemmMLevel1Cluster = 2;
constexpr index_t GemmNLevel1Cluster = 2; constexpr index_t GemmNLevel1Cluster = 2;
constexpr index_t GemmKPerThreadLoop = 1;
constexpr index_t GemmDataPerReadA = 4; constexpr index_t GemmDataPerReadA = 4;
constexpr index_t GemmDataPerReadB = 4; constexpr index_t GemmDataPerReadB = 4;
using InBlockCopySubLengths_E_N1_B_N2 = Sequence<1, 2, 1, 4>; using InBlockCopySubLengths_E_N1_B_N2 = Sequence<1, 2, 1, 4>;
using InBlockCopyClusterLengths_E_N1_B_N2 = Sequence<8, 1, 8, 1>; using InBlockCopyClusterLengths_E_N1_B_N2 = Sequence<8, 1, 8, 1>;
...@@ -204,24 +525,221 @@ void device_convolution_implicit_gemm_v4r1_nchw_kcyx_nkhw(InDesc, ...@@ -204,24 +525,221 @@ void device_convolution_implicit_gemm_v4r1_nchw_kcyx_nkhw(InDesc,
constexpr index_t WeiBlockCopySrcDataPerRead_E = 4; constexpr index_t WeiBlockCopySrcDataPerRead_E = 4;
constexpr index_t WeiBlockCopyDstDataPerWrite_K = 1; constexpr index_t WeiBlockCopyDstDataPerWrite_K = 1;
#elif 0 #elif 0
// BlockSize = 256, blockwise-GEMM 64x128, each thread hold 32 data // cdata = 64, BlockSize = 32, 32x64x3
constexpr index_t BlockSize = 256; constexpr index_t BlockSize = 32;
constexpr index_t KPerBlock = 32;
constexpr index_t BPerBlock = 8;
constexpr index_t EPerBlock = 3;
constexpr index_t GemmNRepeat = 2;
constexpr index_t GemmMPerThread = 4;
constexpr index_t GemmNPerThread = 4;
constexpr index_t GemmKPerThread = 1;
constexpr index_t GemmMLevel0Cluster = 4;
constexpr index_t GemmNLevel0Cluster = 4;
constexpr index_t GemmMLevel1Cluster = 1;
constexpr index_t GemmNLevel1Cluster = 2;
constexpr index_t GemmDataPerReadA = 4;
constexpr index_t GemmDataPerReadB = 4;
using InBlockCopySubLengths_E_N1_B_N2 = Sequence<3, 1, 1, 2>;
using InBlockCopyClusterLengths_E_N1_B_N2 = Sequence<1, 2, 8, 2>;
using InBlockCopyThreadClusterArrangeOrder = Sequence<0, 1, 3, 2>; // [E, N1, N2, B]
using InBlockCopySrcAccessOrder = Sequence<0, 2, 1, 3>; // [E, B, N1, N2]
using InBlockCopyDstAccessOrder = Sequence<0, 1, 2, 3>; // [E, N1, B, N2]
constexpr index_t InBlockCopySrcDataPerRead_B = 1;
constexpr index_t InBlockCopyDstDataPerWrite_N2 = 2;
using WeiBlockCopySubLengths_E_K = Sequence<3, 1>;
using WeiBlockCopyClusterLengths_E_K = Sequence<1, 32>;
using WeiBlockCopyThreadClusterArrangeOrder = Sequence<1, 0>; // [K, E]
using WeiBlockCopySrcAccessOrder = Sequence<1, 0>; // [K, E]
using WeiBlockCopyDstAccessOrder = Sequence<0, 1>; // [E, K]
constexpr index_t WeiBlockCopySrcDataPerRead_E = 1;
constexpr index_t WeiBlockCopyDstDataPerWrite_K = 1;
#elif 0
// cdata = 64, BlockSize = 64, 32x128x3
constexpr index_t BlockSize = 64;
constexpr index_t KPerBlock = 32;
constexpr index_t BPerBlock = 16; constexpr index_t BPerBlock = 16;
constexpr index_t EPerBlock = 3;
constexpr index_t GemmNRepeat = 2;
constexpr index_t GemmMPerThread = 4;
constexpr index_t GemmNPerThread = 4;
constexpr index_t GemmKPerThread = 1;
constexpr index_t GemmMLevel0Cluster = 4;
constexpr index_t GemmNLevel0Cluster = 4;
constexpr index_t GemmMLevel1Cluster = 1;
constexpr index_t GemmNLevel1Cluster = 4;
constexpr index_t GemmDataPerReadA = 4;
constexpr index_t GemmDataPerReadB = 4;
using InBlockCopySubLengths_E_N1_B_N2 = Sequence<3, 1, 1, 2>;
using InBlockCopyClusterLengths_E_N1_B_N2 = Sequence<1, 2, 16, 2>;
using InBlockCopyThreadClusterArrangeOrder = Sequence<0, 1, 3, 2>; // [E, N1, N2, B]
using InBlockCopySrcAccessOrder = Sequence<0, 2, 1, 3>; // [E, B, N1, N2]
using InBlockCopyDstAccessOrder = Sequence<0, 1, 2, 3>; // [E, N1, B, N2]
constexpr index_t InBlockCopySrcDataPerRead_B = 1;
constexpr index_t InBlockCopyDstDataPerWrite_N2 = 2;
using WeiBlockCopySubLengths_E_K = Sequence<3, 1>;
using WeiBlockCopyClusterLengths_E_K = Sequence<1, 32>;
using WeiBlockCopyThreadClusterArrangeOrder = Sequence<1, 0>; // [K, E]
using WeiBlockCopySrcAccessOrder = Sequence<1, 0>; // [K, E]
using WeiBlockCopyDstAccessOrder = Sequence<0, 1>; // [E, K]
constexpr index_t WeiBlockCopySrcDataPerRead_E = 1;
constexpr index_t WeiBlockCopyDstDataPerWrite_K = 1;
#elif 0
// cdata = 64, BlockSize = 64, 64x64x3
constexpr index_t BlockSize = 64;
constexpr index_t KPerBlock = 64;
constexpr index_t BPerBlock = 8;
constexpr index_t EPerBlock = 3;
constexpr index_t GemmNRepeat = 2;
constexpr index_t GemmMPerThread = 4;
constexpr index_t GemmNPerThread = 4;
constexpr index_t GemmKPerThread = 1;
constexpr index_t GemmMLevel0Cluster = 2;
constexpr index_t GemmNLevel0Cluster = 2;
constexpr index_t GemmMLevel1Cluster = 4;
constexpr index_t GemmNLevel1Cluster = 4;
constexpr index_t GemmDataPerReadA = 4;
constexpr index_t GemmDataPerReadB = 4;
using InBlockCopySubLengths_E_N1_B_N2 = Sequence<3, 1, 1, 1>;
using InBlockCopyClusterLengths_E_N1_B_N2 = Sequence<1, 2, 8, 4>;
using InBlockCopyThreadClusterArrangeOrder = Sequence<0, 1, 3, 2>; // [E, N1, N2, B]
using InBlockCopySrcAccessOrder = Sequence<0, 2, 1, 3>; // [E, B, N1, N2]
using InBlockCopyDstAccessOrder = Sequence<0, 1, 2, 3>; // [E, N1, B, N2]
constexpr index_t InBlockCopySrcDataPerRead_B = 1;
constexpr index_t InBlockCopyDstDataPerWrite_N2 = 1;
using WeiBlockCopySubLengths_E_K = Sequence<3, 1>;
using WeiBlockCopyClusterLengths_E_K = Sequence<1, 64>;
using WeiBlockCopyThreadClusterArrangeOrder = Sequence<1, 0>; // [K, E]
using WeiBlockCopySrcAccessOrder = Sequence<1, 0>; // [K, E]
using WeiBlockCopyDstAccessOrder = Sequence<0, 1>; // [E, K]
constexpr index_t WeiBlockCopySrcDataPerRead_E = 1;
constexpr index_t WeiBlockCopyDstDataPerWrite_K = 1;
#elif 0
// cdata = 64, BlockSize = 64, 32x128x4
constexpr index_t BlockSize = 64;
constexpr index_t KPerBlock = 32;
constexpr index_t BPerBlock = 16;
constexpr index_t EPerBlock = 4;
constexpr index_t GemmNRepeat = 2;
constexpr index_t GemmMPerThread = 4;
constexpr index_t GemmNPerThread = 4;
constexpr index_t GemmKPerThread = 1;
constexpr index_t GemmMLevel0Cluster = 2;
constexpr index_t GemmNLevel0Cluster = 2;
constexpr index_t GemmMLevel1Cluster = 2;
constexpr index_t GemmNLevel1Cluster = 8;
constexpr index_t GemmDataPerReadA = 4;
constexpr index_t GemmDataPerReadB = 4;
using InBlockCopySubLengths_E_N1_B_N2 = Sequence<1, 2, 1, 4>;
using InBlockCopyClusterLengths_E_N1_B_N2 = Sequence<4, 1, 16, 1>;
using InBlockCopyThreadClusterArrangeOrder = Sequence<0, 1, 3, 2>; // [E, N1, N2, B]
using InBlockCopySrcAccessOrder = Sequence<0, 2, 1, 3>; // [E, B, N1, N2]
using InBlockCopyDstAccessOrder = Sequence<0, 1, 2, 3>; // [E, N1, B, N2]
constexpr index_t InBlockCopySrcDataPerRead_B = 1;
constexpr index_t InBlockCopyDstDataPerWrite_N2 = 4;
using WeiBlockCopySubLengths_E_K = Sequence<2, 1>;
using WeiBlockCopyClusterLengths_E_K = Sequence<2, 32>;
using WeiBlockCopyThreadClusterArrangeOrder = Sequence<1, 0>; // [K, E]
using WeiBlockCopySrcAccessOrder = Sequence<1, 0>; // [K, E]
using WeiBlockCopyDstAccessOrder = Sequence<0, 1>; // [E, K]
constexpr index_t WeiBlockCopySrcDataPerRead_E = 2;
constexpr index_t WeiBlockCopyDstDataPerWrite_K = 1;
#elif 0
// cdata = 64, BlockSize = 64, 32x128x8
constexpr index_t BlockSize = 64;
constexpr index_t KPerBlock = 32;
constexpr index_t BPerBlock = 16;
constexpr index_t EPerBlock = 8;
constexpr index_t GemmNRepeat = 2;
constexpr index_t GemmMPerThread = 4;
constexpr index_t GemmNPerThread = 4;
constexpr index_t GemmKPerThread = 1;
constexpr index_t GemmMLevel0Cluster = 4;
constexpr index_t GemmNLevel0Cluster = 4;
constexpr index_t GemmMLevel1Cluster = 1;
constexpr index_t GemmNLevel1Cluster = 4;
constexpr index_t GemmDataPerReadA = 4;
constexpr index_t GemmDataPerReadB = 4;
using InBlockCopySubLengths_E_N1_B_N2 = Sequence<2, 2, 1, 4>;
using InBlockCopyClusterLengths_E_N1_B_N2 = Sequence<4, 1, 16, 1>;
using InBlockCopyThreadClusterArrangeOrder = Sequence<0, 1, 3, 2>; // [E, N1, N2, B]
using InBlockCopySrcAccessOrder = Sequence<0, 2, 1, 3>; // [E, B, N1, N2]
using InBlockCopyDstAccessOrder = Sequence<0, 1, 2, 3>; // [E, N1, B, N2]
constexpr index_t InBlockCopySrcDataPerRead_B = 1;
constexpr index_t InBlockCopyDstDataPerWrite_N2 = 4;
using WeiBlockCopySubLengths_E_K = Sequence<4, 1>;
using WeiBlockCopyClusterLengths_E_K = Sequence<2, 32>;
using WeiBlockCopyThreadClusterArrangeOrder = Sequence<1, 0>; // [K, E]
using WeiBlockCopySrcAccessOrder = Sequence<1, 0>; // [K, E]
using WeiBlockCopyDstAccessOrder = Sequence<0, 1>; // [E, K]
constexpr index_t WeiBlockCopySrcDataPerRead_E = 4;
constexpr index_t WeiBlockCopyDstDataPerWrite_K = 1;
#elif 0
// cdata = 32, BlockSize = 256, 64x128x8
constexpr index_t BlockSize = 256;
constexpr index_t KPerBlock = 64; constexpr index_t KPerBlock = 64;
constexpr index_t BPerBlock = 16;
constexpr index_t EPerBlock = 8; constexpr index_t EPerBlock = 8;
constexpr index_t GemmNRepeat = 2; constexpr index_t GemmNRepeat = 2;
constexpr index_t GemmMPerThreadSubC = 2; constexpr index_t GemmMPerThread = 2;
constexpr index_t GemmNPerThreadSubC = 4; constexpr index_t GemmNPerThread = 4;
constexpr index_t GemmKPerThread = 1;
constexpr index_t GemmMLevel0Cluster = 4; constexpr index_t GemmMLevel0Cluster = 4;
constexpr index_t GemmNLevel0Cluster = 4; constexpr index_t GemmNLevel0Cluster = 4;
constexpr index_t GemmMLevel1Cluster = 4; constexpr index_t GemmMLevel1Cluster = 4;
constexpr index_t GemmNLevel1Cluster = 4; constexpr index_t GemmNLevel1Cluster = 4;
constexpr index_t GemmKPerThreadLoop = 1;
constexpr index_t GemmDataPerReadA = 2; constexpr index_t GemmDataPerReadA = 2;
constexpr index_t GemmDataPerReadB = 4; constexpr index_t GemmDataPerReadB = 4;
using InBlockCopySubLengths_E_N1_B_N2 = Sequence<1, 1, 1, 4>; using InBlockCopySubLengths_E_N1_B_N2 = Sequence<1, 1, 1, 4>;
using InBlockCopyClusterLengths_E_N1_B_N2 = Sequence<8, 2, 16, 1>; using InBlockCopyClusterLengths_E_N1_B_N2 = Sequence<8, 2, 16, 1>;
...@@ -243,7 +761,7 @@ void device_convolution_implicit_gemm_v4r1_nchw_kcyx_nkhw(InDesc, ...@@ -243,7 +761,7 @@ void device_convolution_implicit_gemm_v4r1_nchw_kcyx_nkhw(InDesc,
#endif #endif
constexpr index_t N1 = GemmNRepeat; constexpr index_t N1 = GemmNRepeat;
constexpr index_t N2 = GemmNPerThreadSubC; constexpr index_t N2 = GemmNPerThread;
constexpr index_t B = (N * Ho * Wo) / (N1 * N2); constexpr index_t B = (N * Ho * Wo) / (N1 * N2);
...@@ -252,72 +770,76 @@ void device_convolution_implicit_gemm_v4r1_nchw_kcyx_nkhw(InDesc, ...@@ -252,72 +770,76 @@ void device_convolution_implicit_gemm_v4r1_nchw_kcyx_nkhw(InDesc,
printf("%s: BlockSize %u, GridSize %u \n", __func__, BlockSize, GridSize); printf("%s: BlockSize %u, GridSize %u \n", __func__, BlockSize, GridSize);
constexpr auto gridwise_conv = using gridwise_conv = GridwiseConvolutionImplicitGemm_v4r1_nchw_kcyx_nkhw_lds_double_buffer<
GridwiseConvolutionImplicitGemm_v4r1_nchw_kcyx_nkhw_lds_double_buffer< GridSize,
GridSize, BlockSize,
BlockSize, T,
T, T,
T, decltype(in_nchw_desc),
decltype(in_nchw_desc), decltype(wei_kcyx_desc),
decltype(wei_kcyx_desc), decltype(out_nkhw_desc),
decltype(out_nkhw_desc), ConvStrides,
ConvStrides, ConvDilations,
ConvDilations, LeftPads,
LeftPads, RightPads,
RightPads, BPerBlock,
ConvolutionDirection::Forward, KPerBlock,
BPerBlock, EPerBlock,
KPerBlock, GemmNRepeat,
EPerBlock, GemmMPerThread,
GemmNRepeat, GemmNPerThread,
GemmMPerThreadSubC, GemmKPerThread,
GemmNPerThreadSubC, GemmMLevel0Cluster,
GemmMLevel0Cluster, GemmNLevel0Cluster,
GemmNLevel0Cluster, GemmMLevel1Cluster,
GemmMLevel1Cluster, GemmNLevel1Cluster,
GemmNLevel1Cluster, GemmDataPerReadA,
GemmKPerThreadLoop, GemmDataPerReadB,
GemmDataPerReadA, InBlockCopySubLengths_E_N1_B_N2,
GemmDataPerReadB, InBlockCopyClusterLengths_E_N1_B_N2,
InBlockCopySubLengths_E_N1_B_N2, InBlockCopyThreadClusterArrangeOrder,
InBlockCopyClusterLengths_E_N1_B_N2, InBlockCopySrcAccessOrder,
InBlockCopyThreadClusterArrangeOrder, InBlockCopyDstAccessOrder,
InBlockCopySrcAccessOrder, InBlockCopySrcDataPerRead_B,
InBlockCopyDstAccessOrder, InBlockCopyDstDataPerWrite_N2,
InBlockCopySrcDataPerRead_B, WeiBlockCopySubLengths_E_K,
InBlockCopyDstDataPerWrite_N2, WeiBlockCopyClusterLengths_E_K,
WeiBlockCopySubLengths_E_K, WeiBlockCopyThreadClusterArrangeOrder,
WeiBlockCopyClusterLengths_E_K, WeiBlockCopySrcAccessOrder,
WeiBlockCopyThreadClusterArrangeOrder, WeiBlockCopyDstAccessOrder,
WeiBlockCopySrcAccessOrder, WeiBlockCopySrcDataPerRead_E,
WeiBlockCopyDstAccessOrder, WeiBlockCopyDstDataPerWrite_K>;
WeiBlockCopySrcDataPerRead_E,
WeiBlockCopyDstDataPerWrite_K>{}; for(index_t i = 0; i < 5; ++i)
for(index_t i = 0; i < nrepeat; ++i)
{ {
float time = std::cout << "Start running " << nrepeat << " times..." << std::endl;
launch_and_time_kernel(run_gridwise_operation<decltype(gridwise_conv),
const T* const __restrict__, KernelTimer timer;
const T* const __restrict__, timer.Start();
T* const __restrict__>,
dim3(GridSize), for(index_t j = 0; j < nrepeat; ++j)
dim3(BlockSize), {
0, launch_kernel(run_gridwise_operation<gridwise_conv,
0, const TDevice* const __restrict__,
gridwise_conv, const TDevice* const __restrict__,
const_cast<const T* const __restrict__>( TDevice* const __restrict__>,
static_cast<T*>(in_nchw_device_buf.GetDeviceBuffer())), dim3(GridSize),
const_cast<const T* const __restrict__>( dim3(BlockSize),
static_cast<T*>(wei_kcyx_device_buf.GetDeviceBuffer())), 0,
const_cast<T* const __restrict__>( 0,
static_cast<T*>(out_nkhw_device_buf.GetDeviceBuffer()))); static_cast<TDevice*>(in_nchw_device_buf.GetDeviceBuffer()),
static_cast<TDevice*>(wei_kcyx_device_buf.GetDeviceBuffer()),
printf("Elapsed time : %f ms, %f TFlop/s\n", static_cast<TDevice*>(out_nkhw_device_buf.GetDeviceBuffer()));
time, }
(float)calculate_convolution_flops(InDesc{}, WeiDesc{}, OutDesc{}) /
(std::size_t(1000) * 1000 * 1000) / time); timer.End();
usleep(std::min(time * 1000, float(10000)));
float ave_time = timer.GetElapsedTime() / nrepeat;
float perf = (float)calculate_convolution_flops(InDesc{}, WeiDesc{}, OutDesc{}) /
(std::size_t(1000) * 1000 * 1000) / ave_time;
std::cout << "Average time : " << ave_time << " ms, " << perf << " TFlop/s" << std::endl;
} }
out_nkhw_device_buf.FromDevice(out_nkhw.mData.data()); out_nkhw_device_buf.FromDevice(out_nkhw.mData.data());
......
#pragma once
#include <unistd.h>
#include "device.hpp"
#include "tensor.hpp"
#include "gridwise_convolution_kernel_wrapper.hpp"
#include "gridwise_convolution_implicit_gemm_v4r1_nchw_kcyx_nkhw_lds_double_buffer_deprecated.hpp"
template <class T,
class InDesc,
class WeiDesc,
class OutDesc,
class ConvStrides,
class ConvDilations>
void device_convolution_implicit_gemm_v4r1_nchw_kcyx_nkhw_deprecated(InDesc,
const Tensor<T>& in_nchw,
WeiDesc,
const Tensor<T>& wei_kcyx,
OutDesc,
Tensor<T>& out_nkhw,
ConvStrides,
ConvDilations,
ck::index_t nrepeat)
{
using namespace ck;
constexpr auto I0 = Number<0>{};
constexpr auto I1 = Number<1>{};
constexpr auto I2 = Number<2>{};
constexpr auto I3 = Number<3>{};
constexpr auto in_nchw_desc = InDesc{};
constexpr auto wei_kcyx_desc = WeiDesc{};
constexpr auto out_nkhw_desc = OutDesc{};
constexpr index_t N = out_nkhw_desc.GetLength(I0);
constexpr index_t K = out_nkhw_desc.GetLength(I1);
constexpr index_t Ho = out_nkhw_desc.GetLength(I2);
constexpr index_t Wo = out_nkhw_desc.GetLength(I3);
std::size_t data_sz = sizeof(T);
DeviceMem in_nchw_device_buf(data_sz * in_nchw.mDesc.GetElementSpace());
DeviceMem wei_kcyx_device_buf(data_sz * wei_kcyx.mDesc.GetElementSpace());
DeviceMem out_nkhw_device_buf(data_sz * out_nkhw.mDesc.GetElementSpace());
in_nchw_device_buf.ToDevice(in_nchw.mData.data());
wei_kcyx_device_buf.ToDevice(wei_kcyx.mData.data());
out_nkhw_device_buf.ToDevice(out_nkhw.mData.data());
#if 0
// BlockSize = 256, blockwise-GEMM 128x128, each thread hold 64 data
constexpr index_t BlockSize = 256;
constexpr index_t BPerBlock = 16;
constexpr index_t KPerBlock = 128;
constexpr index_t EPerBlock = 8;
constexpr index_t GemmNRepeat = 2;
constexpr index_t GemmMPerThreadSubC = 4;
constexpr index_t GemmNPerThreadSubC = 4;
constexpr index_t GemmMLevel0Cluster = 4;
constexpr index_t GemmNLevel0Cluster = 4;
constexpr index_t GemmMLevel1Cluster = 4;
constexpr index_t GemmNLevel1Cluster = 4;
constexpr index_t GemmKPerThreadLoop = 1;
constexpr index_t GemmDataPerReadA = 4;
constexpr index_t GemmDataPerReadB = 4;
using InBlockCopySubLengths_E_N1_B_N2 = Sequence<1, 1, 1, 4>;
using InBlockCopyClusterLengths_E_N1_B_N2 = Sequence<8, 2, 16, 1>;
using InBlockCopyThreadClusterArrangeOrder = Sequence<0, 1, 3, 2>; // [E, N1, N2, B]
using InBlockCopySrcAccessOrder = Sequence<0, 2, 1, 3>; // [E, B, N1, N2]
using InBlockCopyDstAccessOrder = Sequence<0, 1, 2, 3>; // [E, N1, B, N2]
constexpr index_t InBlockCopySrcDataPerRead_B = 1;
constexpr index_t InBlockCopyDstDataPerWrite_N2 = 4;
using WeiBlockCopySubLengths_E_K = Sequence<4, 1>;
using WeiBlockCopyClusterLengths_E_K = Sequence<2, 128>;
using WeiBlockCopyThreadClusterArrangeOrder = Sequence<1, 0>; // [K, E]
using WeiBlockCopySrcAccessOrder = Sequence<1, 0>; // [K, E]
using WeiBlockCopyDstAccessOrder = Sequence<0, 1>; // [E, K]
constexpr index_t WeiBlockCopySrcDataPerRead_E = 4;
constexpr index_t WeiBlockCopyDstDataPerWrite_K = 1;
#elif 0
// BlockSize = 256, EPerBlock = 16, each thread hold 64 data
constexpr index_t BlockSize = 256;
constexpr index_t BPerBlock = 16;
constexpr index_t KPerBlock = 128;
constexpr index_t EPerBlock = 16;
constexpr index_t GemmNRepeat = 2;
constexpr index_t GemmMPerThreadSubC = 4;
constexpr index_t GemmNPerThreadSubC = 4;
constexpr index_t GemmMLevel0Cluster = 4;
constexpr index_t GemmNLevel0Cluster = 4;
constexpr index_t GemmMLevel1Cluster = 4;
constexpr index_t GemmNLevel1Cluster = 4;
constexpr index_t GemmKPerThreadLoop = 1;
constexpr index_t GemmDataPerReadA = 4;
constexpr index_t GemmDataPerReadB = 4;
using InBlockCopySubLengths_E_N1_B_N2 = Sequence<1, 2, 1, 4>;
using InBlockCopyClusterLengths_E_N1_B_N2 = Sequence<16, 1, 16, 1>;
using InBlockCopyThreadClusterArrangeOrder = Sequence<0, 1, 3, 2>; // [E, N1, N2, B]
using InBlockCopySrcAccessOrder = Sequence<0, 2, 1, 3>; // [E, B, N1, N2]
using InBlockCopyDstAccessOrder = Sequence<0, 1, 2, 3>; // [E, N1, B, N2]
constexpr index_t InBlockCopySrcDataPerRead_B = 1;
constexpr index_t InBlockCopyDstDataPerWrite_N2 = 4;
using WeiBlockCopySubLengths_E_K = Sequence<4, 2>;
using WeiBlockCopyClusterLengths_E_K = Sequence<4, 64>;
using WeiBlockCopyThreadClusterArrangeOrder = Sequence<1, 0>; // [K, E]
using WeiBlockCopySrcAccessOrder = Sequence<1, 0>; // [K, E]
using WeiBlockCopyDstAccessOrder = Sequence<0, 1>; // [E, K]
constexpr index_t WeiBlockCopySrcDataPerRead_E = 4;
constexpr index_t WeiBlockCopyDstDataPerWrite_K = 1;
#elif 0
// BlockSize = 64, blockwise-GEMM 64x64, each thread hold 64 data
constexpr index_t BlockSize = 64;
constexpr index_t BPerBlock = 8;
constexpr index_t KPerBlock = 64;
constexpr index_t EPerBlock = 8;
constexpr index_t GemmNRepeat = 2;
constexpr index_t GemmMPerThreadSubC = 4;
constexpr index_t GemmNPerThreadSubC = 4;
constexpr index_t GemmMLevel0Cluster = 4;
constexpr index_t GemmNLevel0Cluster = 4;
constexpr index_t GemmMLevel1Cluster = 2;
constexpr index_t GemmNLevel1Cluster = 2;
constexpr index_t GemmKPerThreadLoop = 1;
constexpr index_t GemmDataPerReadA = 4;
constexpr index_t GemmDataPerReadB = 4;
using InBlockCopySubLengths_E_N1_B_N2 = Sequence<1, 2, 1, 4>;
using InBlockCopyClusterLengths_E_N1_B_N2 = Sequence<8, 1, 8, 1>;
using InBlockCopyThreadClusterArrangeOrder = Sequence<0, 1, 3, 2>; // [E, N1, N2, B]
using InBlockCopySrcAccessOrder = Sequence<0, 2, 1, 3>; // [E, B, N1, N2]
using InBlockCopyDstAccessOrder = Sequence<0, 1, 2, 3>; // [E, N1, B, N2]
constexpr index_t InBlockCopySrcDataPerRead_B = 1;
constexpr index_t InBlockCopyDstDataPerWrite_N2 = 4;
using WeiBlockCopySubLengths_E_K = Sequence<4, 2>;
using WeiBlockCopyClusterLengths_E_K = Sequence<2, 32>;
using WeiBlockCopyThreadClusterArrangeOrder = Sequence<1, 0>; // [K, E]
using WeiBlockCopySrcAccessOrder = Sequence<1, 0>; // [K, E]
using WeiBlockCopyDstAccessOrder = Sequence<0, 1>; // [E, K]
constexpr index_t WeiBlockCopySrcDataPerRead_E = 4;
constexpr index_t WeiBlockCopyDstDataPerWrite_K = 1;
#elif 0
// BlockSize = 256, blockwise-GEMM 64x128, each thread hold 32 data
constexpr index_t BlockSize = 256;
constexpr index_t BPerBlock = 16;
constexpr index_t KPerBlock = 64;
constexpr index_t EPerBlock = 8;
constexpr index_t GemmNRepeat = 2;
constexpr index_t GemmMPerThreadSubC = 2;
constexpr index_t GemmNPerThreadSubC = 4;
constexpr index_t GemmMLevel0Cluster = 4;
constexpr index_t GemmNLevel0Cluster = 4;
constexpr index_t GemmMLevel1Cluster = 4;
constexpr index_t GemmNLevel1Cluster = 4;
constexpr index_t GemmKPerThreadLoop = 1;
constexpr index_t GemmDataPerReadA = 2;
constexpr index_t GemmDataPerReadB = 4;
using InBlockCopySubLengths_E_N1_B_N2 = Sequence<1, 1, 1, 4>;
using InBlockCopyClusterLengths_E_N1_B_N2 = Sequence<8, 2, 16, 1>;
using InBlockCopyThreadClusterArrangeOrder = Sequence<0, 1, 3, 2>; // [E, N1, N2, B]
using InBlockCopySrcAccessOrder = Sequence<0, 2, 1, 3>; // [E, B, N1, N2]
using InBlockCopyDstAccessOrder = Sequence<0, 1, 2, 3>; // [E, N1, B, N2]
constexpr index_t InBlockCopySrcDataPerRead_B = 1;
constexpr index_t InBlockCopyDstDataPerWrite_N2 = 4;
using WeiBlockCopySubLengths_E_K = Sequence<2, 1>;
using WeiBlockCopyClusterLengths_E_K = Sequence<4, 64>;
using WeiBlockCopyThreadClusterArrangeOrder = Sequence<1, 0>; // [K, E]
using WeiBlockCopySrcAccessOrder = Sequence<1, 0>; // [K, E]
using WeiBlockCopyDstAccessOrder = Sequence<0, 1>; // [E, K]
constexpr index_t WeiBlockCopySrcDataPerRead_E = 2;
constexpr index_t WeiBlockCopyDstDataPerWrite_K = 1;
#elif 1
constexpr index_t BlockSize = 64;
constexpr index_t BPerBlock = 16;
constexpr index_t KPerBlock = 32;
constexpr index_t EPerBlock = 4;
constexpr index_t GemmNRepeat = 2;
constexpr index_t GemmMPerThreadSubC = 4;
constexpr index_t GemmNPerThreadSubC = 4;
constexpr index_t GemmMLevel0Cluster = 1;
constexpr index_t GemmNLevel0Cluster = 4;
constexpr index_t GemmMLevel1Cluster = 4;
constexpr index_t GemmNLevel1Cluster = 4;
constexpr index_t GemmKPerThreadLoop = 1;
constexpr index_t GemmDataPerReadA = 4;
constexpr index_t GemmDataPerReadB = 4;
using InBlockCopySubLengths_E_N1_B_N2 = Sequence<1, 2, 1, 4>;
using InBlockCopyClusterLengths_E_N1_B_N2 = Sequence<4, 1, 16, 1>;
using InBlockCopyThreadClusterArrangeOrder = Sequence<0, 1, 3, 2>; // [E, N1, N2, B]
using InBlockCopySrcAccessOrder = Sequence<0, 2, 1, 3>; // [E, B, N1, N2]
using InBlockCopyDstAccessOrder = Sequence<0, 1, 2, 3>; // [E, N1, B, N2]
constexpr index_t InBlockCopySrcDataPerRead_B = 1;
constexpr index_t InBlockCopyDstDataPerWrite_N2 = 4;
using WeiBlockCopySubLengths_E_K = Sequence<1, 2>;
using WeiBlockCopyClusterLengths_E_K = Sequence<4, 16>;
using WeiBlockCopyThreadClusterArrangeOrder = Sequence<1, 0>; // [K, E]
using WeiBlockCopySrcAccessOrder = Sequence<1, 0>; // [K, E]
using WeiBlockCopyDstAccessOrder = Sequence<0, 1>; // [E, K]
constexpr index_t WeiBlockCopySrcDataPerRead_E = 1;
constexpr index_t WeiBlockCopyDstDataPerWrite_K = 2;
#endif
constexpr index_t N1 = GemmNRepeat;
constexpr index_t N2 = GemmNPerThreadSubC;
constexpr index_t B = (N * Ho * Wo) / (N1 * N2);
constexpr index_t GridSize =
((B + BPerBlock - 1) / BPerBlock) * ((K + KPerBlock - 1) / KPerBlock);
printf("%s: BlockSize %u, GridSize %u \n", __func__, BlockSize, GridSize);
constexpr auto gridwise_conv =
GridwiseConvolutionImplicitGemm_v4r1_nchw_kcyx_nkhw_lds_double_buffer_deprecated<
GridSize,
BlockSize,
T,
T,
decltype(in_nchw_desc),
decltype(wei_kcyx_desc),
decltype(out_nkhw_desc),
ConvStrides,
ConvDilations,
ConvolutionDirection::Forward,
BPerBlock,
KPerBlock,
EPerBlock,
GemmNRepeat,
GemmMPerThreadSubC,
GemmNPerThreadSubC,
GemmMLevel0Cluster,
GemmNLevel0Cluster,
GemmMLevel1Cluster,
GemmNLevel1Cluster,
GemmKPerThreadLoop,
GemmDataPerReadA,
GemmDataPerReadB,
InBlockCopySubLengths_E_N1_B_N2,
InBlockCopyClusterLengths_E_N1_B_N2,
InBlockCopyThreadClusterArrangeOrder,
InBlockCopySrcAccessOrder,
InBlockCopyDstAccessOrder,
InBlockCopySrcDataPerRead_B,
InBlockCopyDstDataPerWrite_N2,
WeiBlockCopySubLengths_E_K,
WeiBlockCopyClusterLengths_E_K,
WeiBlockCopyThreadClusterArrangeOrder,
WeiBlockCopySrcAccessOrder,
WeiBlockCopyDstAccessOrder,
WeiBlockCopySrcDataPerRead_E,
WeiBlockCopyDstDataPerWrite_K>{};
for(index_t i = 0; i < nrepeat; ++i)
{
float time =
launch_and_time_kernel(run_gridwise_convolution_kernel<decltype(gridwise_conv), T>,
dim3(GridSize),
dim3(BlockSize),
0,
0,
static_cast<T*>(in_nchw_device_buf.GetDeviceBuffer()),
static_cast<T*>(wei_kcyx_device_buf.GetDeviceBuffer()),
static_cast<T*>(out_nkhw_device_buf.GetDeviceBuffer()));
printf("Elapsed time : %f ms, %f TFlop/s\n",
time,
(float)calculate_convolution_flops(InDesc{}, WeiDesc{}, OutDesc{}) /
(std::size_t(1000) * 1000 * 1000) / time);
usleep(std::min(time * 1000, float(10000)));
}
out_nkhw_device_buf.FromDevice(out_nkhw.mData.data());
}
#pragma once
#include <unistd.h>
#include "device.hpp"
#include "tensor.hpp"
#include "gridwise_convolution_kernel_wrapper.hpp"
#include "gridwise_convolution_implicit_gemm_v4r2_nchw_kcyx_nkhw_lds_double_buffer.hpp"
using namespace ck;
template <class T,
class InDesc,
class WeiDesc,
class OutDesc,
class ConvStrides,
class ConvDilations>
void device_convolution_implicit_gemm_v4r2_nchw_kcyx_nkhw(InDesc,
const Tensor<T>& in_nchw,
WeiDesc,
const Tensor<T>& wei_kcyx,
OutDesc,
Tensor<T>& out_nkhw,
ConvStrides,
ConvDilations,
index_t nrepeat)
{
constexpr auto I0 = Number<0>{};
constexpr auto I1 = Number<1>{};
constexpr auto I2 = Number<2>{};
constexpr auto I3 = Number<3>{};
constexpr auto in_nchw_desc = InDesc{};
constexpr auto wei_kcyx_desc = WeiDesc{};
constexpr auto out_nkhw_desc = OutDesc{};
constexpr index_t Hi = in_nchw_desc.GetLength(I2);
constexpr index_t Wi = in_nchw_desc.GetLength(I3);
constexpr index_t N = out_nkhw_desc.GetLength(I0);
constexpr index_t Ho = out_nkhw_desc.GetLength(I2);
constexpr index_t Wo = out_nkhw_desc.GetLength(I3);
constexpr index_t K = wei_kcyx_desc.GetLength(I0);
constexpr index_t C = wei_kcyx_desc.GetLength(I1);
constexpr index_t Y = wei_kcyx_desc.GetLength(I2);
constexpr index_t X = wei_kcyx_desc.GetLength(I3);
std::size_t data_sz = sizeof(T);
DeviceMem in_nchw_device_buf(data_sz * in_nchw.mDesc.GetElementSpace());
DeviceMem wei_kcyx_device_buf(data_sz * wei_kcyx.mDesc.GetElementSpace());
DeviceMem out_nkhw_device_buf(data_sz * out_nkhw.mDesc.GetElementSpace());
in_nchw_device_buf.ToDevice(in_nchw.mData.data());
wei_kcyx_device_buf.ToDevice(wei_kcyx.mData.data());
out_nkhw_device_buf.ToDevice(out_nkhw.mData.data());
#if 0
// 1x1 filter, 8x8 image
constexpr index_t N0 = 1;
constexpr index_t Ho0 = 2;
constexpr index_t Wo0 = 1;
constexpr index_t N2 = 4;
constexpr index_t Ho2 = 1;
constexpr index_t Wo2 = 1;
constexpr index_t BlockSize = 256;
constexpr index_t BPerBlock = 16;
constexpr index_t KPerBlock = 128;
constexpr index_t EPerBlock = 8;
constexpr index_t GemmMPerThreadSubC = 4;
constexpr index_t GemmNPerThreadSubC = 4;
constexpr index_t GemmMLevel0Cluster = 4;
constexpr index_t GemmNLevel0Cluster = 4;
constexpr index_t GemmMLevel1Cluster = 4;
constexpr index_t GemmNLevel1Cluster = 4;
constexpr index_t GemmKPerThreadLoop = 1;
constexpr index_t GemmDataPerReadA = 4;
constexpr index_t GemmDataPerReadB = 4;
using InBlockCopySubLengths_E_N0_Ho0_Wo0_B_N2_Ho2_Wo2 = Sequence<1, 1, 1, 1, 1, 4, 1, 1>;
using InBlockCopyClusterLengths_E_N0_Ho0_Wo0_B_N2_Ho2_Wo2 = Sequence<8, 1, 2, 1, 16, 1, 1, 1>;
using InBlockCopyThreadClusterArrangeOrder =
Sequence<0, 1, 5, 2, 6, 3, 4, 7>; // [E, N0, N2, Ho0, Ho2, Wo0, B, Wo2]
using InBlockCopySrcAccessOrder =
Sequence<0, 1, 5, 2, 6, 3, 4, 7>; // [E, N0, N2, Ho0, Ho2, Wo0, B, Wo2]
using InBlockCopyDstAccessOrder =
Sequence<0, 1, 2, 3, 4, 5, 6, 7>; // [E, N0, Ho0, Wo0, B, N2, Ho2, Wo2]
constexpr index_t InBlockCopyDataPerAccess_W2 = 1;
using WeiBlockCopySubLengths_E_K = Sequence<4, 1>;
using WeiBlockCopyClusterLengths_E_K = Sequence<2, 128>;
using WeiBlockCopyThreadClusterArrangeOrder = Sequence<1, 0>; // [K, E]
using WeiBlockCopySrcAccessOrder = Sequence<1, 0>; // [K, E]
using WeiBlockCopyDstAccessOrder = Sequence<0, 1>; // [E, K]
constexpr index_t WeiBlockCopySrcDataPerRead_E = 4;
constexpr index_t WeiBlockCopyDstDataPerWrite_K = 1;
#elif 1
// 1x1 filter, 8x8 image
constexpr index_t N0 = 1;
constexpr index_t Ho0 = 2;
constexpr index_t Wo0 = 1;
constexpr index_t N2 = 2;
constexpr index_t Ho2 = 2;
constexpr index_t Wo2 = 1;
constexpr index_t BlockSize = 256;
constexpr index_t BPerBlock = 16;
constexpr index_t KPerBlock = 128;
constexpr index_t EPerBlock = 8;
constexpr index_t GemmMPerThreadSubC = 4;
constexpr index_t GemmNPerThreadSubC = 4;
constexpr index_t GemmMLevel0Cluster = 4;
constexpr index_t GemmNLevel0Cluster = 4;
constexpr index_t GemmMLevel1Cluster = 4;
constexpr index_t GemmNLevel1Cluster = 4;
constexpr index_t GemmKPerThreadLoop = 1;
constexpr index_t GemmDataPerReadA = 4;
constexpr index_t GemmDataPerReadB = 4;
using InBlockCopySubLengths_E_N0_Ho0_Wo0_B_N2_Ho2_Wo2 = Sequence<1, 1, 2, 1, 1, 2, 1, 1>;
using InBlockCopyClusterLengths_E_N0_Ho0_Wo0_B_N2_Ho2_Wo2 = Sequence<8, 1, 1, 1, 16, 1, 2, 1>;
using InBlockCopyThreadClusterArrangeOrder =
Sequence<0, 1, 5, 2, 6, 3, 4, 7>; // [E, N0, N2, Ho0, Ho2, Wo0, B, Wo2]
using InBlockCopySrcAccessOrder =
Sequence<0, 1, 5, 2, 6, 3, 4, 7>; // [E, N0, N2, Ho0, Ho2, Wo0, B, Wo2]
using InBlockCopyDstAccessOrder =
Sequence<0, 1, 2, 3, 4, 5, 6, 7>; // [E, N0, Ho0, Wo0, B, N2, Ho2, Wo2]
constexpr index_t InBlockCopyDataPerAccess_W2 = 1;
using WeiBlockCopySubLengths_E_K = Sequence<4, 1>;
using WeiBlockCopyClusterLengths_E_K = Sequence<2, 128>;
using WeiBlockCopyThreadClusterArrangeOrder = Sequence<1, 0>; // [K, E]
using WeiBlockCopySrcAccessOrder = Sequence<1, 0>; // [K, E]
using WeiBlockCopyDstAccessOrder = Sequence<0, 1>; // [E, K]
constexpr index_t WeiBlockCopySrcDataPerRead_E = 4;
constexpr index_t WeiBlockCopyDstDataPerWrite_K = 1;
#endif
constexpr index_t N1 = N / (N0 * N2);
constexpr index_t Ho1 = Ho / (Ho0 * Ho2);
constexpr index_t Wo1 = Wo / (Wo0 * Wo2);
constexpr index_t B = N1 * Ho1 * Wo1;
constexpr index_t GridSize =
((B + BPerBlock - 1) / BPerBlock) * ((K + KPerBlock - 1) / KPerBlock);
printf("%s: BlockSize %u, GridSize %u \n", __func__, BlockSize, GridSize);
for(index_t i = 0; i < nrepeat; ++i)
{
constexpr auto gridwise_conv =
GridwiseConvolutionImplicitGemm_v4r2_nchw_kcyx_nkhw_lds_double_buffer<
GridSize,
BlockSize,
T,
decltype(in_nchw_desc),
decltype(wei_kcyx_desc),
decltype(out_nkhw_desc),
ConvStrides,
ConvDilations,
N1,
N2,
Ho1,
Ho2,
Wo1,
Wo2,
BPerBlock,
KPerBlock,
EPerBlock,
GemmMPerThreadSubC,
GemmNPerThreadSubC,
GemmMLevel0Cluster,
GemmNLevel0Cluster,
GemmMLevel1Cluster,
GemmNLevel1Cluster,
GemmKPerThreadLoop,
GemmDataPerReadA,
GemmDataPerReadB,
InBlockCopySubLengths_E_N0_Ho0_Wo0_B_N2_Ho2_Wo2,
InBlockCopyClusterLengths_E_N0_Ho0_Wo0_B_N2_Ho2_Wo2,
InBlockCopyThreadClusterArrangeOrder,
InBlockCopySrcAccessOrder,
InBlockCopyDstAccessOrder,
InBlockCopyDataPerAccess_W2,
WeiBlockCopySubLengths_E_K,
WeiBlockCopyClusterLengths_E_K,
WeiBlockCopyThreadClusterArrangeOrder,
WeiBlockCopySrcAccessOrder,
WeiBlockCopyDstAccessOrder,
WeiBlockCopySrcDataPerRead_E,
WeiBlockCopyDstDataPerWrite_K>{};
float time =
launch_and_time_kernel(run_gridwise_convolution_kernel<decltype(gridwise_conv), T>,
dim3(GridSize),
dim3(BlockSize),
0,
static_cast<T*>(in_nchw_device_buf.GetDeviceBuffer()),
static_cast<T*>(wei_kcyx_device_buf.GetDeviceBuffer()),
static_cast<T*>(out_nkhw_device_buf.GetDeviceBuffer()));
printf("Elapsed time : %f ms, %f TFlop/s\n",
time,
(float)calculate_convolution_flops(InDesc{}, WeiDesc{}, OutDesc{}) /
(std::size_t(1000) * 1000 * 1000) / time);
usleep(std::min(time * 1000, float(10000)));
}
out_nkhw_device_buf.FromDevice(out_nkhw.mData.data());
}
#pragma once
#include <unistd.h>
#include "device.hpp"
#include "tensor.hpp"
#include "gridwise_convolution_kernel_wrapper.hpp"
#include "gridwise_convolution_implicit_gemm_v4r3_nchw_kcyx_nkhw_lds_double_buffer.hpp"
using namespace ck;
template <class T,
class InDesc,
class WeiDesc,
class OutDesc,
class ConvStrides,
class ConvDilations>
void device_convolution_implicit_gemm_v4r3_nchw_kcyx_nkhw(InDesc,
const Tensor<T>& in_nchw,
WeiDesc,
const Tensor<T>& wei_kcyx,
OutDesc,
Tensor<T>& out_nkhw,
ConvStrides,
ConvDilations,
index_t nrepeat)
{
constexpr auto I0 = Number<0>{};
constexpr auto I1 = Number<1>{};
constexpr auto I2 = Number<2>{};
constexpr auto I3 = Number<3>{};
constexpr auto in_nchw_desc = InDesc{};
constexpr auto wei_kcyx_desc = WeiDesc{};
constexpr auto out_nkhw_desc = OutDesc{};
constexpr index_t Hi = in_nchw_desc.GetLength(I2);
constexpr index_t Wi = in_nchw_desc.GetLength(I3);
constexpr index_t N = out_nkhw_desc.GetLength(I0);
constexpr index_t Ho = out_nkhw_desc.GetLength(I2);
constexpr index_t Wo = out_nkhw_desc.GetLength(I3);
constexpr index_t K = wei_kcyx_desc.GetLength(I0);
constexpr index_t C = wei_kcyx_desc.GetLength(I1);
constexpr index_t Y = wei_kcyx_desc.GetLength(I2);
constexpr index_t X = wei_kcyx_desc.GetLength(I3);
std::size_t data_sz = sizeof(T);
DeviceMem in_nchw_device_buf(data_sz * in_nchw.mDesc.GetElementSpace());
DeviceMem wei_kcyx_device_buf(data_sz * wei_kcyx.mDesc.GetElementSpace());
DeviceMem out_nkhw_device_buf(data_sz * out_nkhw.mDesc.GetElementSpace());
in_nchw_device_buf.ToDevice(in_nchw.mData.data());
wei_kcyx_device_buf.ToDevice(wei_kcyx.mData.data());
out_nkhw_device_buf.ToDevice(out_nkhw.mData.data());
#if 1
// 1x1 filter, 8x8 image
constexpr index_t N1 = 2;
constexpr index_t Ho1 = 1;
constexpr index_t Wo1 = 1;
constexpr index_t N2 = 1;
constexpr index_t Ho2 = 1;
constexpr index_t Wo2 = 4;
constexpr index_t BlockSize = 256;
constexpr index_t BPerBlock = 16;
constexpr index_t KPerBlock = 128;
constexpr index_t EPerBlock = 8;
constexpr index_t GemmMPerThreadSubC = 4;
constexpr index_t GemmNPerThreadSubC = 4;
constexpr index_t GemmMLevel0Cluster = 4;
constexpr index_t GemmNLevel0Cluster = 4;
constexpr index_t GemmMLevel1Cluster = 4;
constexpr index_t GemmNLevel1Cluster = 4;
constexpr index_t GemmKPerThreadLoop = 1;
constexpr index_t GemmDataPerReadA = 4;
constexpr index_t GemmDataPerReadB = 4;
using InBlockCopySubLengths_E_N1_Ho1_Wo1_B_N2_Ho2_Wo2 = Sequence<1, 1, 1, 1, 1, 1, 1, 4>;
using InBlockCopyClusterLengths_E_N1_Ho1_Wo1_B_N2_Ho2_Wo2 = Sequence<8, 2, 1, 1, 16, 1, 1, 1>;
using InBlockCopyThreadClusterArrangeOrder =
Sequence<0, 1, 5, 2, 6, 3, 4, 7>; // [E, N1, N2, Ho1, Ho2, Wo1, B, Wo2]
using InBlockCopySrcAccessOrder =
Sequence<0, 1, 5, 2, 6, 3, 4, 7>; // [E, N1, N2, Ho1, Ho2, Wo1, B, Wo2]
using InBlockCopyDstAccessOrder =
Sequence<0, 1, 2, 3, 4, 5, 6, 7>; // [E, N1, Ho1, Wo1, B, N2, Ho2, Wo2]
constexpr index_t InBlockCopyDataPerAccess_W2 = 4;
using WeiBlockCopySubLengths_E_K = Sequence<4, 1>;
using WeiBlockCopyClusterLengths_E_K = Sequence<2, 128>;
using WeiBlockCopyThreadClusterArrangeOrder = Sequence<1, 0>; // [K, E]
using WeiBlockCopySrcAccessOrder = Sequence<1, 0>; // [K, E]
using WeiBlockCopyDstAccessOrder = Sequence<0, 1>; // [E, K]
constexpr index_t WeiBlockCopySrcDataPerRead_E = 4;
constexpr index_t WeiBlockCopyDstDataPerWrite_K = 1;
#endif
constexpr index_t N0 = N / (N1 * N2);
constexpr index_t Ho0 = Ho / (Ho1 * Ho2);
constexpr index_t Wo0 = Wo / (Wo1 * Wo2);
constexpr index_t B = N0 * Ho0 * Wo0;
constexpr index_t GridSize =
((B + BPerBlock - 1) / BPerBlock) * ((K + KPerBlock - 1) / KPerBlock);
printf("%s: BlockSize %u, GridSize %u \n", __func__, BlockSize, GridSize);
for(index_t i = 0; i < nrepeat; ++i)
{
constexpr auto gridwise_conv =
GridwiseConvolutionImplicitGemm_v4r3_nchw_kcyx_nkhw_lds_double_buffer<
GridSize,
BlockSize,
T,
decltype(in_nchw_desc),
decltype(wei_kcyx_desc),
decltype(out_nkhw_desc),
ConvStrides,
ConvDilations,
N0,
N1,
N2,
Ho0,
Ho1,
Ho2,
Wo0,
Wo1,
Wo2,
BPerBlock,
KPerBlock,
EPerBlock,
GemmMPerThreadSubC,
GemmNPerThreadSubC,
GemmMLevel0Cluster,
GemmNLevel0Cluster,
GemmMLevel1Cluster,
GemmNLevel1Cluster,
GemmKPerThreadLoop,
GemmDataPerReadA,
GemmDataPerReadB,
InBlockCopySubLengths_E_N1_Ho1_Wo1_B_N2_Ho2_Wo2,
InBlockCopyClusterLengths_E_N1_Ho1_Wo1_B_N2_Ho2_Wo2,
InBlockCopyThreadClusterArrangeOrder,
InBlockCopySrcAccessOrder,
InBlockCopyDstAccessOrder,
InBlockCopyDataPerAccess_W2,
WeiBlockCopySubLengths_E_K,
WeiBlockCopyClusterLengths_E_K,
WeiBlockCopyThreadClusterArrangeOrder,
WeiBlockCopySrcAccessOrder,
WeiBlockCopyDstAccessOrder,
WeiBlockCopySrcDataPerRead_E,
WeiBlockCopyDstDataPerWrite_K>{};
float time =
launch_and_time_kernel(run_gridwise_convolution_kernel<decltype(gridwise_conv), T>,
dim3(GridSize),
dim3(BlockSize),
0,
static_cast<T*>(in_nchw_device_buf.GetDeviceBuffer()),
static_cast<T*>(wei_kcyx_device_buf.GetDeviceBuffer()),
static_cast<T*>(out_nkhw_device_buf.GetDeviceBuffer()));
printf("Elapsed time : %f ms, %f TFlop/s\n",
time,
(float)calculate_convolution_flops(InDesc{}, WeiDesc{}, OutDesc{}) /
(std::size_t(1000) * 1000 * 1000) / time);
usleep(std::min(time * 1000, float(10000)));
}
out_nkhw_device_buf.FromDevice(out_nkhw.mData.data());
}
#pragma once
#include <unistd.h> #include <unistd.h>
#include "device.hpp" #include "device.hpp"
#include "tensor.hpp" #include "host_tensor.hpp"
#include "gridwise_convolution_kernel_wrapper.hpp" #include "gridwise_operation_wrapper.hpp"
#include "gridwise_convolution_implicit_gemm_v4r4_nchw_kcyx_nkhw.hpp" #include "gridwise_convolution_implicit_gemm_v4r4_nchw_kcyx_nkhw.hpp"
template <class T, template <class T,
...@@ -27,6 +26,8 @@ void device_convolution_implicit_gemm_v4r4_nchw_kcyx_nkhw(InDesc, ...@@ -27,6 +26,8 @@ void device_convolution_implicit_gemm_v4r4_nchw_kcyx_nkhw(InDesc,
{ {
using namespace ck; using namespace ck;
using TDevice = typename conditional<is_same<half_float::half, T>::value, half_t, T>::type;
constexpr auto I0 = Number<0>{}; constexpr auto I0 = Number<0>{};
constexpr auto I1 = Number<1>{}; constexpr auto I1 = Number<1>{};
constexpr auto I2 = Number<2>{}; constexpr auto I2 = Number<2>{};
...@@ -54,20 +55,88 @@ void device_convolution_implicit_gemm_v4r4_nchw_kcyx_nkhw(InDesc, ...@@ -54,20 +55,88 @@ void device_convolution_implicit_gemm_v4r4_nchw_kcyx_nkhw(InDesc,
out_nkhw_device_buf.ToDevice(out_nkhw.mData.data()); out_nkhw_device_buf.ToDevice(out_nkhw.mData.data());
#if 0 #if 0
// BlockSize = 256, GemmKPerBlock = 8 // cdata = 64, BlockSize = 256, 64x256x8
constexpr index_t BlockSize = 256;
constexpr index_t GemmMPerBlock = 64;
constexpr index_t GemmNPerBlock = 256;
constexpr index_t GemmKPerBlock = 8;
constexpr index_t GemmMPerThread = 4;
constexpr index_t GemmNPerThread = 4;
constexpr index_t GemmKPerThread = 1;
constexpr index_t GemmMLevel0Cluster = 2;
constexpr index_t GemmNLevel0Cluster = 2;
constexpr index_t GemmMLevel1Cluster = 4;
constexpr index_t GemmNLevel1Cluster = 16;
constexpr index_t ThreadGemmDataPerReadM = 4;
constexpr index_t ThreadGemmDataPerReadN = 4;
using GemmABlockCopyThreadSliceLengths_GemmK_GemmM = Sequence<2, 1>;
using GemmABlockCopyThreadClusterLengths_GemmK_GemmM = Sequence<4, 64>;
constexpr index_t GemmABlockCopySrcDataPerRead_GemmK = 2;
constexpr index_t GemmABlockCopyDstDataPerWrite_GemmM = 1;
using GemmBBlockCopyThreadSliceLengths_GemmK_GemmN = Sequence<8, 1>;
using GemmBBlockCopyThreadClusterLengths_GemmK_GemmN = Sequence<1, 256>;
constexpr index_t GemmBBlockCopySrcDataPerRead_GemmN = 1;
constexpr index_t GemmBBlockCopyDstDataPerWrite_GemmN = 1;
constexpr index_t GemmCThreadCopyDstDataPerWrite_GemmN1 = 1;
#elif 0
// cdata = 64, BlockSize = 256, 128x128x4
constexpr index_t BlockSize = 256;
constexpr index_t GemmMPerBlock = 128;
constexpr index_t GemmNPerBlock = 128;
constexpr index_t GemmKPerBlock = 4;
constexpr index_t GemmMPerThread = 4;
constexpr index_t GemmNPerThread = 4;
constexpr index_t GemmKPerThread = 1;
constexpr index_t GemmMLevel0Cluster = 4;
constexpr index_t GemmNLevel0Cluster = 4;
constexpr index_t GemmMLevel1Cluster = 4;
constexpr index_t GemmNLevel1Cluster = 4;
constexpr index_t ThreadGemmDataPerReadM = 4;
constexpr index_t ThreadGemmDataPerReadN = 4;
using GemmABlockCopyThreadSliceLengths_GemmK_GemmM = Sequence<2, 1>;
using GemmABlockCopyThreadClusterLengths_GemmK_GemmM = Sequence<2, 128>;
constexpr index_t GemmABlockCopySrcDataPerRead_GemmK = 2;
constexpr index_t GemmABlockCopyDstDataPerWrite_GemmM = 1;
using GemmBBlockCopyThreadSliceLengths_GemmK_GemmN = Sequence<2, 1>;
using GemmBBlockCopyThreadClusterLengths_GemmK_GemmN = Sequence<2, 128>;
constexpr index_t GemmBBlockCopySrcDataPerRead_GemmN = 1;
constexpr index_t GemmBBlockCopyDstDataPerWrite_GemmN = 1;
constexpr index_t GemmCThreadCopyDstDataPerWrite_GemmN1 = 1;
#elif 1
// cdata = 64, BlockSize = 256, 128x128x8
constexpr index_t BlockSize = 256; constexpr index_t BlockSize = 256;
constexpr index_t GemmMPerBlock = 128; constexpr index_t GemmMPerBlock = 128;
constexpr index_t GemmNPerBlock = 128; constexpr index_t GemmNPerBlock = 128;
constexpr index_t GemmKPerBlock = 8; constexpr index_t GemmKPerBlock = 8;
constexpr index_t GemmMPerThreadSubC = 4; constexpr index_t GemmMPerThread = 4;
constexpr index_t GemmNPerThreadSubC = 4; constexpr index_t GemmNPerThread = 4;
constexpr index_t GemmMLevel0Cluster = 4; constexpr index_t GemmKPerThread = 1;
constexpr index_t GemmNLevel0Cluster = 4;
constexpr index_t GemmMLevel1Cluster = 4; constexpr index_t GemmMLevel0Cluster = 2;
constexpr index_t GemmNLevel1Cluster = 4; constexpr index_t GemmNLevel0Cluster = 2;
constexpr index_t GemmKPerThreadLoop = 1; constexpr index_t GemmMLevel1Cluster = 8;
constexpr index_t GemmNLevel1Cluster = 8;
constexpr index_t ThreadGemmDataPerReadM = 4; constexpr index_t ThreadGemmDataPerReadM = 4;
constexpr index_t ThreadGemmDataPerReadN = 4; constexpr index_t ThreadGemmDataPerReadN = 4;
...@@ -85,20 +154,56 @@ void device_convolution_implicit_gemm_v4r4_nchw_kcyx_nkhw(InDesc, ...@@ -85,20 +154,56 @@ void device_convolution_implicit_gemm_v4r4_nchw_kcyx_nkhw(InDesc,
constexpr index_t GemmCThreadCopyDstDataPerWrite_GemmN1 = 1; constexpr index_t GemmCThreadCopyDstDataPerWrite_GemmN1 = 1;
#elif 0 #elif 0
// BlockSize = 256, GemmKPerBlock = 16 // cdata = 64, BlockSize = 256, 128x128x8
// vector 4
constexpr index_t BlockSize = 256;
constexpr index_t GemmMPerBlock = 128;
constexpr index_t GemmNPerBlock = 128;
constexpr index_t GemmKPerBlock = 8;
constexpr index_t GemmMPerThread = 4;
constexpr index_t GemmNPerThread = 4;
constexpr index_t GemmKPerThread = 1;
constexpr index_t GemmMLevel0Cluster = 2;
constexpr index_t GemmNLevel0Cluster = 2;
constexpr index_t GemmMLevel1Cluster = 8;
constexpr index_t GemmNLevel1Cluster = 8;
constexpr index_t ThreadGemmDataPerReadM = 4;
constexpr index_t ThreadGemmDataPerReadN = 4;
using GemmABlockCopyThreadSliceLengths_GemmK_GemmM = Sequence<4, 1>;
using GemmABlockCopyThreadClusterLengths_GemmK_GemmM = Sequence<2, 128>;
constexpr index_t GemmABlockCopySrcDataPerRead_GemmK = 4;
constexpr index_t GemmABlockCopyDstDataPerWrite_GemmM = 1;
using GemmBBlockCopyThreadSliceLengths_GemmK_GemmN = Sequence<1, 4>;
using GemmBBlockCopyThreadClusterLengths_GemmK_GemmN = Sequence<8, 32>;
constexpr index_t GemmBBlockCopySrcDataPerRead_GemmN = 4;
constexpr index_t GemmBBlockCopyDstDataPerWrite_GemmN = 4;
constexpr index_t GemmCThreadCopyDstDataPerWrite_GemmN1 = 4;
#elif 0
// cdata = 64, BlockSize = 256, 128x128x16
constexpr index_t BlockSize = 256; constexpr index_t BlockSize = 256;
constexpr index_t GemmMPerBlock = 128; constexpr index_t GemmMPerBlock = 128;
constexpr index_t GemmNPerBlock = 128; constexpr index_t GemmNPerBlock = 128;
constexpr index_t GemmKPerBlock = 16; constexpr index_t GemmKPerBlock = 16;
constexpr index_t GemmMPerThreadSubC = 4; constexpr index_t GemmMPerThread = 4;
constexpr index_t GemmNPerThreadSubC = 4; constexpr index_t GemmNPerThread = 4;
constexpr index_t GemmMLevel0Cluster = 4; constexpr index_t GemmKPerThread = 1;
constexpr index_t GemmNLevel0Cluster = 4;
constexpr index_t GemmMLevel1Cluster = 4; constexpr index_t GemmMLevel0Cluster = 4;
constexpr index_t GemmNLevel1Cluster = 4; constexpr index_t GemmNLevel0Cluster = 4;
constexpr index_t GemmKPerThreadLoop = 1; constexpr index_t GemmMLevel1Cluster = 4;
constexpr index_t GemmNLevel1Cluster = 4;
constexpr index_t ThreadGemmDataPerReadM = 4; constexpr index_t ThreadGemmDataPerReadM = 4;
constexpr index_t ThreadGemmDataPerReadN = 4; constexpr index_t ThreadGemmDataPerReadN = 4;
...@@ -116,21 +221,24 @@ void device_convolution_implicit_gemm_v4r4_nchw_kcyx_nkhw(InDesc, ...@@ -116,21 +221,24 @@ void device_convolution_implicit_gemm_v4r4_nchw_kcyx_nkhw(InDesc,
constexpr index_t GemmCThreadCopyDstDataPerWrite_GemmN1 = 1; constexpr index_t GemmCThreadCopyDstDataPerWrite_GemmN1 = 1;
#elif 0 #elif 0
// BlockSize = 256, GemmKPerBlock = 8 // cdata = 64, BlockSize = 256, 128x128x8
// for 1x1 filter, vector-read-b = 4 // GemmBBlockCopySrcDataPerRead_GemmN = 4
// GemmCThreadCopyDstDataPerWrite_GemmN1 = 4
constexpr index_t BlockSize = 256; constexpr index_t BlockSize = 256;
constexpr index_t GemmMPerBlock = 128; constexpr index_t GemmMPerBlock = 128;
constexpr index_t GemmNPerBlock = 128; constexpr index_t GemmNPerBlock = 128;
constexpr index_t GemmKPerBlock = 8; constexpr index_t GemmKPerBlock = 8;
constexpr index_t GemmMPerThreadSubC = 4; constexpr index_t GemmMPerThread = 4;
constexpr index_t GemmNPerThreadSubC = 4; constexpr index_t GemmNPerThread = 4;
constexpr index_t GemmMLevel0Cluster = 4; constexpr index_t GemmKPerThread = 1;
constexpr index_t GemmNLevel0Cluster = 4;
constexpr index_t GemmMLevel1Cluster = 4; constexpr index_t GemmMLevel0Cluster = 4;
constexpr index_t GemmNLevel1Cluster = 4; constexpr index_t GemmNLevel0Cluster = 4;
constexpr index_t GemmKPerThreadLoop = 1; constexpr index_t GemmMLevel1Cluster = 4;
constexpr index_t GemmNLevel1Cluster = 4;
constexpr index_t ThreadGemmDataPerReadM = 4; constexpr index_t ThreadGemmDataPerReadM = 4;
constexpr index_t ThreadGemmDataPerReadN = 4; constexpr index_t ThreadGemmDataPerReadN = 4;
...@@ -147,22 +255,25 @@ void device_convolution_implicit_gemm_v4r4_nchw_kcyx_nkhw(InDesc, ...@@ -147,22 +255,25 @@ void device_convolution_implicit_gemm_v4r4_nchw_kcyx_nkhw(InDesc,
constexpr index_t GemmBBlockCopyDstDataPerWrite_GemmN = 4; constexpr index_t GemmBBlockCopyDstDataPerWrite_GemmN = 4;
constexpr index_t GemmCThreadCopyDstDataPerWrite_GemmN1 = 4; constexpr index_t GemmCThreadCopyDstDataPerWrite_GemmN1 = 4;
#elif 1 #elif 0
// BlockSize = 256, GemmKPerBlock = 16 // cdata = 64, BlockSize = 256, 128x128x16
// for 1x1 filter, vector-read-b = 4 // GemmBBlockCopySrcDataPerRead_GemmN = 4
// GemmCThreadCopyDstDataPerWrite_GemmN1 = 4
constexpr index_t BlockSize = 256; constexpr index_t BlockSize = 256;
constexpr index_t GemmMPerBlock = 128; constexpr index_t GemmMPerBlock = 128;
constexpr index_t GemmNPerBlock = 128; constexpr index_t GemmNPerBlock = 128;
constexpr index_t GemmKPerBlock = 16; constexpr index_t GemmKPerBlock = 16;
constexpr index_t GemmMPerThreadSubC = 4; constexpr index_t GemmMPerThread = 4;
constexpr index_t GemmNPerThreadSubC = 4; constexpr index_t GemmNPerThread = 4;
constexpr index_t GemmMLevel0Cluster = 4; constexpr index_t GemmKPerThread = 1;
constexpr index_t GemmNLevel0Cluster = 4;
constexpr index_t GemmMLevel1Cluster = 4; constexpr index_t GemmMLevel0Cluster = 4;
constexpr index_t GemmNLevel1Cluster = 4; constexpr index_t GemmNLevel0Cluster = 4;
constexpr index_t GemmKPerThreadLoop = 1; constexpr index_t GemmMLevel1Cluster = 4;
constexpr index_t GemmNLevel1Cluster = 4;
constexpr index_t ThreadGemmDataPerReadM = 4; constexpr index_t ThreadGemmDataPerReadM = 4;
constexpr index_t ThreadGemmDataPerReadN = 4; constexpr index_t ThreadGemmDataPerReadN = 4;
...@@ -179,37 +290,674 @@ void device_convolution_implicit_gemm_v4r4_nchw_kcyx_nkhw(InDesc, ...@@ -179,37 +290,674 @@ void device_convolution_implicit_gemm_v4r4_nchw_kcyx_nkhw(InDesc,
constexpr index_t GemmBBlockCopyDstDataPerWrite_GemmN = 4; constexpr index_t GemmBBlockCopyDstDataPerWrite_GemmN = 4;
constexpr index_t GemmCThreadCopyDstDataPerWrite_GemmN1 = 4; constexpr index_t GemmCThreadCopyDstDataPerWrite_GemmN1 = 4;
#elif 1 #elif 0
// 1x1 filter, 14x14 image // cdata = 64, BlockSize = 128, 128x64x4
constexpr index_t BlockSize = 256; constexpr index_t BlockSize = 128;
constexpr index_t GemmMPerBlock = 128;
constexpr index_t GemmNPerBlock = 64;
constexpr index_t GemmKPerBlock = 4;
constexpr index_t GemmMPerThread = 4;
constexpr index_t GemmNPerThread = 4;
constexpr index_t GemmKPerThread = 1;
constexpr index_t GemmMLevel0Cluster = 4;
constexpr index_t GemmNLevel0Cluster = 4;
constexpr index_t GemmMLevel1Cluster = 2;
constexpr index_t GemmNLevel1Cluster = 4;
constexpr index_t ThreadGemmDataPerReadM = 4;
constexpr index_t ThreadGemmDataPerReadN = 4;
using GemmABlockCopyThreadSliceLengths_GemmK_GemmM = Sequence<4, 1>;
using GemmABlockCopyThreadClusterLengths_GemmK_GemmM = Sequence<1, 128>;
constexpr index_t GemmABlockCopySrcDataPerRead_GemmK = 4;
constexpr index_t GemmABlockCopyDstDataPerWrite_GemmM = 1;
using GemmBBlockCopyThreadSliceLengths_GemmK_GemmN = Sequence<2, 1>;
using GemmBBlockCopyThreadClusterLengths_GemmK_GemmN = Sequence<2, 64>;
constexpr index_t GemmBBlockCopySrcDataPerRead_GemmN = 1;
constexpr index_t GemmBBlockCopyDstDataPerWrite_GemmN = 1;
constexpr index_t GemmCThreadCopyDstDataPerWrite_GemmN1 = 1;
#elif 0
// cdata = 64, BlockSize = 128, 128x64x4
// GemmBBlockCopySrcDataPerRead_GemmN = 2
// GemmCThreadCopyDstDataPerWrite_GemmN1 = 2
constexpr index_t BlockSize = 128;
constexpr index_t GemmMPerBlock = 128;
constexpr index_t GemmNPerBlock = 64;
constexpr index_t GemmKPerBlock = 4;
constexpr index_t GemmMPerThread = 4;
constexpr index_t GemmNPerThread = 4;
constexpr index_t GemmKPerThread = 1;
constexpr index_t GemmMLevel0Cluster = 2;
constexpr index_t GemmNLevel0Cluster = 2;
constexpr index_t GemmMLevel1Cluster = 4;
constexpr index_t GemmNLevel1Cluster = 8;
constexpr index_t ThreadGemmDataPerReadM = 4;
constexpr index_t ThreadGemmDataPerReadN = 4;
using GemmABlockCopyThreadSliceLengths_GemmK_GemmM = Sequence<4, 1>;
using GemmABlockCopyThreadClusterLengths_GemmK_GemmM = Sequence<1, 128>;
constexpr index_t GemmABlockCopySrcDataPerRead_GemmK = 4;
constexpr index_t GemmABlockCopyDstDataPerWrite_GemmM = 1;
using GemmBBlockCopyThreadSliceLengths_GemmK_GemmN = Sequence<1, 2>;
using GemmBBlockCopyThreadClusterLengths_GemmK_GemmN = Sequence<4, 32>;
constexpr index_t GemmBBlockCopySrcDataPerRead_GemmN = 2;
constexpr index_t GemmBBlockCopyDstDataPerWrite_GemmN = 2;
constexpr index_t GemmCThreadCopyDstDataPerWrite_GemmN1 = 2;
#elif 0
// cdata = 64, BlockSize = 128, 128x64x8
constexpr index_t BlockSize = 128;
constexpr index_t GemmMPerBlock = 128;
constexpr index_t GemmNPerBlock = 64;
constexpr index_t GemmKPerBlock = 8;
constexpr index_t GemmMPerThread = 4;
constexpr index_t GemmNPerThread = 4;
constexpr index_t GemmKPerThread = 1;
constexpr index_t GemmMLevel0Cluster = 2;
constexpr index_t GemmNLevel0Cluster = 2;
constexpr index_t GemmMLevel1Cluster = 4;
constexpr index_t GemmNLevel1Cluster = 8;
constexpr index_t ThreadGemmDataPerReadM = 4;
constexpr index_t ThreadGemmDataPerReadN = 4;
using GemmABlockCopyThreadSliceLengths_GemmK_GemmM = Sequence<4, 2>;
using GemmABlockCopyThreadClusterLengths_GemmK_GemmM = Sequence<2, 64>;
constexpr index_t GemmABlockCopySrcDataPerRead_GemmK = 4;
constexpr index_t GemmABlockCopyDstDataPerWrite_GemmM = 2;
using GemmBBlockCopyThreadSliceLengths_GemmK_GemmN = Sequence<4, 1>;
using GemmBBlockCopyThreadClusterLengths_GemmK_GemmN = Sequence<2, 64>;
constexpr index_t GemmBBlockCopySrcDataPerRead_GemmN = 1;
constexpr index_t GemmBBlockCopyDstDataPerWrite_GemmN = 1;
constexpr index_t GemmCThreadCopyDstDataPerWrite_GemmN1 = 1;
#elif 0
// cdata = 64, BlockSize = 128, 128x64x8
// GemmBBlockCopySrcDataPerRead_GemmN = 4
// GemmCThreadCopyDstDataPerWrite_GemmN1 = 4
constexpr index_t BlockSize = 128;
constexpr index_t GemmMPerBlock = 128;
constexpr index_t GemmNPerBlock = 64;
constexpr index_t GemmKPerBlock = 8;
constexpr index_t GemmMPerThread = 4;
constexpr index_t GemmNPerThread = 4;
constexpr index_t GemmKPerThread = 1;
constexpr index_t GemmMLevel0Cluster = 2;
constexpr index_t GemmNLevel0Cluster = 2;
constexpr index_t GemmMLevel1Cluster = 4;
constexpr index_t GemmNLevel1Cluster = 8;
constexpr index_t ThreadGemmDataPerReadM = 4;
constexpr index_t ThreadGemmDataPerReadN = 4;
using GemmABlockCopyThreadSliceLengths_GemmK_GemmM = Sequence<4, 2>;
using GemmABlockCopyThreadClusterLengths_GemmK_GemmM = Sequence<2, 64>;
constexpr index_t GemmABlockCopySrcDataPerRead_GemmK = 4;
constexpr index_t GemmABlockCopyDstDataPerWrite_GemmM = 2;
using GemmBBlockCopyThreadSliceLengths_GemmK_GemmN = Sequence<1, 4>;
using GemmBBlockCopyThreadClusterLengths_GemmK_GemmN = Sequence<8, 16>;
constexpr index_t GemmBBlockCopySrcDataPerRead_GemmN = 4;
constexpr index_t GemmBBlockCopyDstDataPerWrite_GemmN = 4;
constexpr index_t GemmCThreadCopyDstDataPerWrite_GemmN1 = 4;
#elif 0
// cdata = 64, BlockSize = 128, 128x64x16
constexpr index_t BlockSize = 128;
constexpr index_t GemmMPerBlock = 128; constexpr index_t GemmMPerBlock = 128;
constexpr index_t GemmNPerBlock = 64;
constexpr index_t GemmKPerBlock = 16;
constexpr index_t GemmMPerThread = 4;
constexpr index_t GemmNPerThread = 4;
constexpr index_t GemmKPerThread = 1;
constexpr index_t GemmMLevel0Cluster = 4;
constexpr index_t GemmNLevel0Cluster = 4;
constexpr index_t GemmMLevel1Cluster = 2;
constexpr index_t GemmNLevel1Cluster = 4;
constexpr index_t ThreadGemmDataPerReadM = 4;
constexpr index_t ThreadGemmDataPerReadN = 4;
using GemmABlockCopyThreadSliceLengths_GemmK_GemmM = Sequence<4, 4>;
using GemmABlockCopyThreadClusterLengths_GemmK_GemmM = Sequence<4, 32>;
constexpr index_t GemmABlockCopySrcDataPerRead_GemmK = 4;
constexpr index_t GemmABlockCopyDstDataPerWrite_GemmM = 4;
using GemmBBlockCopyThreadSliceLengths_GemmK_GemmN = Sequence<8, 1>;
using GemmBBlockCopyThreadClusterLengths_GemmK_GemmN = Sequence<2, 64>;
constexpr index_t GemmBBlockCopySrcDataPerRead_GemmN = 1;
constexpr index_t GemmBBlockCopyDstDataPerWrite_GemmN = 1;
constexpr index_t GemmCThreadCopyDstDataPerWrite_GemmN1 = 1;
#elif 0
// cdata = 64, BlockSize = 128, 64x128x4
constexpr index_t BlockSize = 128;
constexpr index_t GemmMPerBlock = 64;
constexpr index_t GemmNPerBlock = 128;
constexpr index_t GemmKPerBlock = 4;
constexpr index_t GemmMPerThread = 4;
constexpr index_t GemmNPerThread = 4;
constexpr index_t GemmKPerThread = 1;
constexpr index_t GemmMLevel0Cluster = 2;
constexpr index_t GemmNLevel0Cluster = 4;
constexpr index_t GemmMLevel1Cluster = 4;
constexpr index_t GemmNLevel1Cluster = 4;
constexpr index_t ThreadGemmDataPerReadM = 4;
constexpr index_t ThreadGemmDataPerReadN = 4;
using GemmABlockCopyThreadSliceLengths_GemmK_GemmM = Sequence<2, 1>;
using GemmABlockCopyThreadClusterLengths_GemmK_GemmM = Sequence<2, 64>;
constexpr index_t GemmABlockCopySrcDataPerRead_GemmK = 2;
constexpr index_t GemmABlockCopyDstDataPerWrite_GemmM = 1;
using GemmBBlockCopyThreadSliceLengths_GemmK_GemmN = Sequence<4, 1>;
using GemmBBlockCopyThreadClusterLengths_GemmK_GemmN = Sequence<1, 128>;
constexpr index_t GemmBBlockCopySrcDataPerRead_GemmN = 1;
constexpr index_t GemmBBlockCopyDstDataPerWrite_GemmN = 1;
constexpr index_t GemmCThreadCopyDstDataPerWrite_GemmN1 = 1;
#elif 0
// BlockSize = 128, 64x128x4
// GemmBBlockCopySrcDataPerRead_GemmN = 4
// GemmCThreadCopyDstDataPerWrite_GemmN1 = 4
constexpr index_t BlockSize = 128;
constexpr index_t GemmMPerBlock = 64;
constexpr index_t GemmNPerBlock = 128;
constexpr index_t GemmKPerBlock = 4;
constexpr index_t GemmMPerThread = 4;
constexpr index_t GemmNPerThread = 4;
constexpr index_t GemmKPerThread = 1;
constexpr index_t GemmMLevel0Cluster = 2;
constexpr index_t GemmNLevel0Cluster = 2;
constexpr index_t GemmMLevel1Cluster = 4;
constexpr index_t GemmNLevel1Cluster = 8;
constexpr index_t ThreadGemmDataPerReadM = 4;
constexpr index_t ThreadGemmDataPerReadN = 4;
using GemmABlockCopyThreadSliceLengths_GemmK_GemmM = Sequence<2, 1>;
using GemmABlockCopyThreadClusterLengths_GemmK_GemmM = Sequence<2, 64>;
constexpr index_t GemmABlockCopySrcDataPerRead_GemmK = 2;
constexpr index_t GemmABlockCopyDstDataPerWrite_GemmM = 1;
using GemmBBlockCopyThreadSliceLengths_GemmK_GemmN = Sequence<1, 4>;
using GemmBBlockCopyThreadClusterLengths_GemmK_GemmN = Sequence<4, 32>;
constexpr index_t GemmBBlockCopySrcDataPerRead_GemmN = 4;
constexpr index_t GemmBBlockCopyDstDataPerWrite_GemmN = 4;
constexpr index_t GemmCThreadCopyDstDataPerWrite_GemmN1 = 4;
#elif 0
// cdata = 64, BlockSize = 128, 64x128x8
constexpr index_t BlockSize = 128;
constexpr index_t GemmMPerBlock = 64;
constexpr index_t GemmNPerBlock = 128; constexpr index_t GemmNPerBlock = 128;
constexpr index_t GemmKPerBlock = 8; constexpr index_t GemmKPerBlock = 8;
constexpr index_t GemmMPerThreadSubC = 4; constexpr index_t GemmMPerThread = 4;
constexpr index_t GemmNPerThreadSubC = 4; constexpr index_t GemmNPerThread = 4;
constexpr index_t GemmMLevel0Cluster = 4; constexpr index_t GemmKPerThread = 1;
constexpr index_t GemmNLevel0Cluster = 4;
constexpr index_t GemmMLevel1Cluster = 4; constexpr index_t GemmMLevel0Cluster = 2;
constexpr index_t GemmNLevel1Cluster = 4; constexpr index_t GemmNLevel0Cluster = 2;
constexpr index_t GemmKPerThreadLoop = 1; constexpr index_t GemmMLevel1Cluster = 4;
constexpr index_t GemmNLevel1Cluster = 8;
constexpr index_t ThreadGemmDataPerReadM = 4; constexpr index_t ThreadGemmDataPerReadM = 4;
constexpr index_t ThreadGemmDataPerReadN = 4; constexpr index_t ThreadGemmDataPerReadN = 4;
using GemmABlockCopyThreadSliceLengths_GemmK_GemmM = Sequence<4, 1>; using GemmABlockCopyThreadSliceLengths_GemmK_GemmM = Sequence<4, 1>;
using GemmABlockCopyThreadClusterLengths_GemmK_GemmM = Sequence<2, 128>; using GemmABlockCopyThreadClusterLengths_GemmK_GemmM = Sequence<2, 64>;
constexpr index_t GemmABlockCopySrcDataPerRead_GemmK = 4; constexpr index_t GemmABlockCopySrcDataPerRead_GemmK = 4;
constexpr index_t GemmABlockCopyDstDataPerWrite_GemmM = 1; constexpr index_t GemmABlockCopyDstDataPerWrite_GemmM = 1;
using GemmBBlockCopyThreadSliceLengths_GemmK_GemmN = Sequence<8, 1>;
using GemmBBlockCopyThreadClusterLengths_GemmK_GemmN = Sequence<1, 128>;
constexpr index_t GemmBBlockCopySrcDataPerRead_GemmN = 1;
constexpr index_t GemmBBlockCopyDstDataPerWrite_GemmN = 1;
constexpr index_t GemmCThreadCopyDstDataPerWrite_GemmN1 = 1;
#elif 0
// BlockSize = 128, 64x128x8
// GemmBBlockCopySrcDataPerRead_GemmN = 4
// GemmCThreadCopyDstDataPerWrite_GemmN1 = 4
constexpr index_t BlockSize = 128;
constexpr index_t GemmMPerBlock = 64;
constexpr index_t GemmNPerBlock = 128;
constexpr index_t GemmKPerBlock = 8;
constexpr index_t GemmMPerThread = 4;
constexpr index_t GemmNPerThread = 4;
constexpr index_t GemmKPerThread = 1;
constexpr index_t GemmMLevel0Cluster = 2;
constexpr index_t GemmNLevel0Cluster = 2;
constexpr index_t GemmMLevel1Cluster = 4;
constexpr index_t GemmNLevel1Cluster = 8;
constexpr index_t ThreadGemmDataPerReadM = 4;
constexpr index_t ThreadGemmDataPerReadN = 4;
using GemmABlockCopyThreadSliceLengths_GemmK_GemmM = Sequence<4, 1>;
using GemmABlockCopyThreadClusterLengths_GemmK_GemmM = Sequence<2, 64>;
constexpr index_t GemmABlockCopySrcDataPerRead_GemmK = 4;
constexpr index_t GemmABlockCopyDstDataPerWrite_GemmM = 1;
using GemmBBlockCopyThreadSliceLengths_GemmK_GemmN = Sequence<2, 4>;
using GemmBBlockCopyThreadClusterLengths_GemmK_GemmN = Sequence<4, 32>;
constexpr index_t GemmBBlockCopySrcDataPerRead_GemmN = 4;
constexpr index_t GemmBBlockCopyDstDataPerWrite_GemmN = 4;
constexpr index_t GemmCThreadCopyDstDataPerWrite_GemmN1 = 4;
#elif 0
// BlockSize = 128, 64x128x16
constexpr index_t BlockSize = 128;
constexpr index_t GemmMPerBlock = 64;
constexpr index_t GemmNPerBlock = 128;
constexpr index_t GemmKPerBlock = 16;
constexpr index_t GemmMPerThread = 4;
constexpr index_t GemmNPerThread = 4;
constexpr index_t GemmKPerThread = 1;
constexpr index_t GemmMLevel0Cluster = 4;
constexpr index_t GemmNLevel0Cluster = 4;
constexpr index_t GemmMLevel1Cluster = 2;
constexpr index_t GemmNLevel1Cluster = 4;
constexpr index_t ThreadGemmDataPerReadM = 4;
constexpr index_t ThreadGemmDataPerReadN = 4;
using GemmABlockCopyThreadSliceLengths_GemmK_GemmM = Sequence<8, 1>;
using GemmABlockCopyThreadClusterLengths_GemmK_GemmM = Sequence<2, 64>;
constexpr index_t GemmABlockCopySrcDataPerRead_GemmK = 4;
constexpr index_t GemmABlockCopyDstDataPerWrite_GemmM = 1;
using GemmBBlockCopyThreadSliceLengths_GemmK_GemmN = Sequence<16, 1>;
using GemmBBlockCopyThreadClusterLengths_GemmK_GemmN = Sequence<1, 128>;
constexpr index_t GemmBBlockCopySrcDataPerRead_GemmN = 1;
constexpr index_t GemmBBlockCopyDstDataPerWrite_GemmN = 1;
constexpr index_t GemmCThreadCopyDstDataPerWrite_GemmN1 = 1;
#elif 0
// cdata = 64, BlockSize = 64, 64x64x4
constexpr index_t BlockSize = 64;
constexpr index_t GemmMPerBlock = 64;
constexpr index_t GemmNPerBlock = 64;
constexpr index_t GemmKPerBlock = 4;
constexpr index_t GemmMPerThread = 4;
constexpr index_t GemmNPerThread = 4;
constexpr index_t GemmKPerThread = 1;
constexpr index_t GemmMLevel0Cluster = 4;
constexpr index_t GemmNLevel0Cluster = 4;
constexpr index_t GemmMLevel1Cluster = 2;
constexpr index_t GemmNLevel1Cluster = 2;
constexpr index_t ThreadGemmDataPerReadM = 4;
constexpr index_t ThreadGemmDataPerReadN = 4;
using GemmABlockCopyThreadSliceLengths_GemmK_GemmM = Sequence<4, 1>;
using GemmABlockCopyThreadClusterLengths_GemmK_GemmM = Sequence<1, 64>;
constexpr index_t GemmABlockCopySrcDataPerRead_GemmK = 4;
constexpr index_t GemmABlockCopyDstDataPerWrite_GemmM = 1;
using GemmBBlockCopyThreadSliceLengths_GemmK_GemmN = Sequence<4, 1>;
using GemmBBlockCopyThreadClusterLengths_GemmK_GemmN = Sequence<1, 64>;
constexpr index_t GemmBBlockCopySrcDataPerRead_GemmN = 1;
constexpr index_t GemmBBlockCopyDstDataPerWrite_GemmN = 1;
constexpr index_t GemmCThreadCopyDstDataPerWrite_GemmN1 = 1;
#elif 0
// cdata = 64, BlockSize = 64, 64x64x8
constexpr index_t BlockSize = 64;
constexpr index_t GemmMPerBlock = 64;
constexpr index_t GemmNPerBlock = 64;
constexpr index_t GemmKPerBlock = 8;
constexpr index_t GemmMPerThread = 4;
constexpr index_t GemmNPerThread = 4;
constexpr index_t GemmKPerThread = 1;
constexpr index_t GemmMLevel0Cluster = 4;
constexpr index_t GemmNLevel0Cluster = 4;
constexpr index_t GemmMLevel1Cluster = 2;
constexpr index_t GemmNLevel1Cluster = 2;
constexpr index_t ThreadGemmDataPerReadM = 4;
constexpr index_t ThreadGemmDataPerReadN = 4;
using GemmABlockCopyThreadSliceLengths_GemmK_GemmM = Sequence<4, 2>;
using GemmABlockCopyThreadClusterLengths_GemmK_GemmM = Sequence<2, 32>;
constexpr index_t GemmABlockCopySrcDataPerRead_GemmK = 4;
constexpr index_t GemmABlockCopyDstDataPerWrite_GemmM = 2;
using GemmBBlockCopyThreadSliceLengths_GemmK_GemmN = Sequence<8, 1>;
using GemmBBlockCopyThreadClusterLengths_GemmK_GemmN = Sequence<1, 64>;
constexpr index_t GemmBBlockCopySrcDataPerRead_GemmN = 1;
constexpr index_t GemmBBlockCopyDstDataPerWrite_GemmN = 1;
constexpr index_t GemmCThreadCopyDstDataPerWrite_GemmN1 = 1;
#elif 0
// cdata = 64, BlockSize = 64, 32x128x2
constexpr index_t BlockSize = 64;
constexpr index_t GemmMPerBlock = 32;
constexpr index_t GemmNPerBlock = 128;
constexpr index_t GemmKPerBlock = 2;
constexpr index_t GemmMPerThread = 4;
constexpr index_t GemmNPerThread = 4;
constexpr index_t GemmKPerThread = 1;
constexpr index_t GemmMLevel0Cluster = 2;
constexpr index_t GemmNLevel0Cluster = 2;
constexpr index_t GemmMLevel1Cluster = 2;
constexpr index_t GemmNLevel1Cluster = 8;
constexpr index_t ThreadGemmDataPerReadM = 4;
constexpr index_t ThreadGemmDataPerReadN = 4;
using GemmABlockCopyThreadSliceLengths_GemmK_GemmM = Sequence<1, 1>;
using GemmABlockCopyThreadClusterLengths_GemmK_GemmM = Sequence<2, 32>;
constexpr index_t GemmABlockCopySrcDataPerRead_GemmK = 1;
constexpr index_t GemmABlockCopyDstDataPerWrite_GemmM = 1;
using GemmBBlockCopyThreadSliceLengths_GemmK_GemmN = Sequence<2, 2>; using GemmBBlockCopyThreadSliceLengths_GemmK_GemmN = Sequence<2, 2>;
using GemmBBlockCopyThreadClusterLengths_GemmK_GemmN = Sequence<4, 64>; using GemmBBlockCopyThreadClusterLengths_GemmK_GemmN = Sequence<1, 64>;
constexpr index_t GemmBBlockCopySrcDataPerRead_GemmN = 2; constexpr index_t GemmBBlockCopySrcDataPerRead_GemmN = 1;
constexpr index_t GemmBBlockCopyDstDataPerWrite_GemmN = 1;
constexpr index_t GemmCThreadCopyDstDataPerWrite_GemmN1 = 1;
#elif 0
// cdata = 64, BlockSize = 64, 32x128x4
constexpr index_t BlockSize = 64;
constexpr index_t GemmMPerBlock = 32;
constexpr index_t GemmNPerBlock = 128;
constexpr index_t GemmKPerBlock = 4;
constexpr index_t GemmMPerThread = 4;
constexpr index_t GemmNPerThread = 4;
constexpr index_t GemmKPerThread = 1;
constexpr index_t GemmMLevel0Cluster = 2;
constexpr index_t GemmNLevel0Cluster = 2;
constexpr index_t GemmMLevel1Cluster = 2;
constexpr index_t GemmNLevel1Cluster = 8;
constexpr index_t ThreadGemmDataPerReadM = 4;
constexpr index_t ThreadGemmDataPerReadN = 4;
using GemmABlockCopyThreadSliceLengths_GemmK_GemmM = Sequence<2, 1>;
using GemmABlockCopyThreadClusterLengths_GemmK_GemmM = Sequence<2, 32>;
constexpr index_t GemmABlockCopySrcDataPerRead_GemmK = 2;
constexpr index_t GemmABlockCopyDstDataPerWrite_GemmM = 1;
using GemmBBlockCopyThreadSliceLengths_GemmK_GemmN = Sequence<4, 2>;
using GemmBBlockCopyThreadClusterLengths_GemmK_GemmN = Sequence<1, 64>;
constexpr index_t GemmBBlockCopySrcDataPerRead_GemmN = 1;
constexpr index_t GemmBBlockCopyDstDataPerWrite_GemmN = 1;
constexpr index_t GemmCThreadCopyDstDataPerWrite_GemmN1 = 1;
#elif 0
// cdata = 64, BlockSize = 32, 32x64x3
constexpr index_t BlockSize = 32;
constexpr index_t GemmMPerBlock = 32;
constexpr index_t GemmNPerBlock = 64;
constexpr index_t GemmKPerBlock = 3;
constexpr index_t GemmMPerThread = 4;
constexpr index_t GemmNPerThread = 4;
constexpr index_t GemmKPerThread = 1;
constexpr index_t GemmMLevel0Cluster = 2;
constexpr index_t GemmNLevel0Cluster = 2;
constexpr index_t GemmMLevel1Cluster = 2;
constexpr index_t GemmNLevel1Cluster = 4;
constexpr index_t ThreadGemmDataPerReadM = 4;
constexpr index_t ThreadGemmDataPerReadN = 4;
using GemmABlockCopyThreadSliceLengths_GemmK_GemmM = Sequence<3, 1>;
using GemmABlockCopyThreadClusterLengths_GemmK_GemmM = Sequence<1, 32>;
constexpr index_t GemmABlockCopySrcDataPerRead_GemmK = 1;
constexpr index_t GemmABlockCopyDstDataPerWrite_GemmM = 1;
using GemmBBlockCopyThreadSliceLengths_GemmK_GemmN = Sequence<3, 2>;
using GemmBBlockCopyThreadClusterLengths_GemmK_GemmN = Sequence<1, 32>;
constexpr index_t GemmBBlockCopySrcDataPerRead_GemmN = 1;
constexpr index_t GemmBBlockCopyDstDataPerWrite_GemmN = 2; constexpr index_t GemmBBlockCopyDstDataPerWrite_GemmN = 2;
constexpr index_t GemmCThreadCopyDstDataPerWrite_GemmN1 = 2; constexpr index_t GemmCThreadCopyDstDataPerWrite_GemmN1 = 1;
#elif 0
// cdata = 64, BlockSize = 64, 32x128x3
constexpr index_t BlockSize = 64;
constexpr index_t GemmMPerBlock = 32;
constexpr index_t GemmNPerBlock = 128;
constexpr index_t GemmKPerBlock = 3;
constexpr index_t GemmMPerThread = 4;
constexpr index_t GemmNPerThread = 4;
constexpr index_t GemmKPerThread = 1;
constexpr index_t GemmMLevel0Cluster = 2;
constexpr index_t GemmNLevel0Cluster = 2;
constexpr index_t GemmMLevel1Cluster = 2;
constexpr index_t GemmNLevel1Cluster = 8;
constexpr index_t ThreadGemmDataPerReadM = 4;
constexpr index_t ThreadGemmDataPerReadN = 4;
using GemmABlockCopyThreadSliceLengths_GemmK_GemmM = Sequence<3, 1>;
using GemmABlockCopyThreadClusterLengths_GemmK_GemmM = Sequence<1, 32>;
constexpr index_t GemmABlockCopySrcDataPerRead_GemmK = 1;
constexpr index_t GemmABlockCopyDstDataPerWrite_GemmM = 1;
using GemmBBlockCopyThreadSliceLengths_GemmK_GemmN = Sequence<3, 2>;
using GemmBBlockCopyThreadClusterLengths_GemmK_GemmN = Sequence<1, 64>;
constexpr index_t GemmBBlockCopySrcDataPerRead_GemmN = 1;
constexpr index_t GemmBBlockCopyDstDataPerWrite_GemmN = 2;
constexpr index_t GemmCThreadCopyDstDataPerWrite_GemmN1 = 1;
#elif 1
// cdata = 64, BlockSize = 64, 64x64x3
constexpr index_t BlockSize = 64;
constexpr index_t GemmMPerBlock = 64;
constexpr index_t GemmNPerBlock = 64;
constexpr index_t GemmKPerBlock = 3;
constexpr index_t GemmMPerThread = 4;
constexpr index_t GemmNPerThread = 4;
constexpr index_t GemmKPerThread = 1;
constexpr index_t GemmMLevel0Cluster = 2;
constexpr index_t GemmNLevel0Cluster = 2;
constexpr index_t GemmMLevel1Cluster = 4;
constexpr index_t GemmNLevel1Cluster = 4;
constexpr index_t ThreadGemmDataPerReadM = 4;
constexpr index_t ThreadGemmDataPerReadN = 4;
using GemmABlockCopyThreadSliceLengths_GemmK_GemmM = Sequence<3, 1>;
using GemmABlockCopyThreadClusterLengths_GemmK_GemmM = Sequence<1, 64>;
constexpr index_t GemmABlockCopySrcDataPerRead_GemmK = 1;
constexpr index_t GemmABlockCopyDstDataPerWrite_GemmM = 1;
using GemmBBlockCopyThreadSliceLengths_GemmK_GemmN = Sequence<3, 1>;
using GemmBBlockCopyThreadClusterLengths_GemmK_GemmN = Sequence<1, 64>;
constexpr index_t GemmBBlockCopySrcDataPerRead_GemmN = 1;
constexpr index_t GemmBBlockCopyDstDataPerWrite_GemmN = 1;
constexpr index_t GemmCThreadCopyDstDataPerWrite_GemmN1 = 1;
#elif 0
// cdata = 64, BlockSize = 64, 32x128x8
constexpr index_t BlockSize = 64;
constexpr index_t GemmMPerBlock = 32;
constexpr index_t GemmNPerBlock = 128;
constexpr index_t GemmKPerBlock = 8;
constexpr index_t GemmMPerThread = 4;
constexpr index_t GemmNPerThread = 4;
constexpr index_t GemmKPerThread = 1;
constexpr index_t GemmMLevel0Cluster = 2;
constexpr index_t GemmNLevel0Cluster = 4;
constexpr index_t GemmMLevel1Cluster = 2;
constexpr index_t GemmNLevel1Cluster = 4;
constexpr index_t ThreadGemmDataPerReadM = 4;
constexpr index_t ThreadGemmDataPerReadN = 4;
using GemmABlockCopyThreadSliceLengths_GemmK_GemmM = Sequence<4, 1>;
using GemmABlockCopyThreadClusterLengths_GemmK_GemmM = Sequence<2, 32>;
constexpr index_t GemmABlockCopySrcDataPerRead_GemmK = 4;
constexpr index_t GemmABlockCopyDstDataPerWrite_GemmM = 1;
using GemmBBlockCopyThreadSliceLengths_GemmK_GemmN = Sequence<8, 2>;
using GemmBBlockCopyThreadClusterLengths_GemmK_GemmN = Sequence<1, 64>;
constexpr index_t GemmBBlockCopySrcDataPerRead_GemmN = 1;
constexpr index_t GemmBBlockCopyDstDataPerWrite_GemmN = 1;
constexpr index_t GemmCThreadCopyDstDataPerWrite_GemmN1 = 1;
#elif 1
// cdata = 32, BlockSize = 128, 32x128x8
constexpr index_t BlockSize = 128;
constexpr index_t GemmMPerBlock = 32;
constexpr index_t GemmNPerBlock = 128;
constexpr index_t GemmKPerBlock = 8;
constexpr index_t GemmMPerThread = 2;
constexpr index_t GemmNPerThread = 4;
constexpr index_t GemmKPerThread = 1;
constexpr index_t GemmMLevel0Cluster = 4;
constexpr index_t GemmNLevel0Cluster = 4;
constexpr index_t GemmMLevel1Cluster = 2;
constexpr index_t GemmNLevel1Cluster = 4;
constexpr index_t ThreadGemmDataPerReadM = 2;
constexpr index_t ThreadGemmDataPerReadN = 4;
using GemmABlockCopyThreadSliceLengths_GemmK_GemmM = Sequence<2, 1>;
using GemmABlockCopyThreadClusterLengths_GemmK_GemmM = Sequence<4, 32>;
constexpr index_t GemmABlockCopySrcDataPerRead_GemmK = 2;
constexpr index_t GemmABlockCopyDstDataPerWrite_GemmM = 1;
using GemmBBlockCopyThreadSliceLengths_GemmK_GemmN = Sequence<8, 1>;
using GemmBBlockCopyThreadClusterLengths_GemmK_GemmN = Sequence<1, 128>;
constexpr index_t GemmBBlockCopySrcDataPerRead_GemmN = 1;
constexpr index_t GemmBBlockCopyDstDataPerWrite_GemmN = 1;
constexpr index_t GemmCThreadCopyDstDataPerWrite_GemmN1 = 1;
#elif 0
// cdata = 32, BlockSize = 128, 32x128x16
constexpr index_t BlockSize = 128;
constexpr index_t GemmMPerBlock = 32;
constexpr index_t GemmNPerBlock = 128;
constexpr index_t GemmKPerBlock = 16;
constexpr index_t GemmMPerThread = 2;
constexpr index_t GemmNPerThread = 4;
constexpr index_t GemmKPerThread = 1;
constexpr index_t GemmMLevel0Cluster = 4;
constexpr index_t GemmNLevel0Cluster = 4;
constexpr index_t GemmMLevel1Cluster = 2;
constexpr index_t GemmNLevel1Cluster = 4;
constexpr index_t ThreadGemmDataPerReadM = 2;
constexpr index_t ThreadGemmDataPerReadN = 4;
using GemmABlockCopyThreadSliceLengths_GemmK_GemmM = Sequence<4, 1>;
using GemmABlockCopyThreadClusterLengths_GemmK_GemmM = Sequence<4, 32>;
constexpr index_t GemmABlockCopySrcDataPerRead_GemmK = 4;
constexpr index_t GemmABlockCopyDstDataPerWrite_GemmM = 1;
using GemmBBlockCopyThreadSliceLengths_GemmK_GemmN = Sequence<16, 1>;
using GemmBBlockCopyThreadClusterLengths_GemmK_GemmN = Sequence<1, 128>;
constexpr index_t GemmBBlockCopySrcDataPerRead_GemmN = 1;
constexpr index_t GemmBBlockCopyDstDataPerWrite_GemmN = 1;
constexpr index_t GemmCThreadCopyDstDataPerWrite_GemmN1 = 1;
#endif #endif
constexpr index_t GemmM = K; constexpr index_t GemmM = K;
...@@ -220,11 +968,11 @@ void device_convolution_implicit_gemm_v4r4_nchw_kcyx_nkhw(InDesc, ...@@ -220,11 +968,11 @@ void device_convolution_implicit_gemm_v4r4_nchw_kcyx_nkhw(InDesc,
printf("%s: BlockSize %u, GridSize %u \n", __func__, BlockSize, GridSize); printf("%s: BlockSize %u, GridSize %u \n", __func__, BlockSize, GridSize);
constexpr auto gridwise_conv = GridwiseConvolutionImplicitGemm_v4r4_nchw_kcyx_nkhw< using gridwise_conv = GridwiseConvolutionImplicitGemm_v4r4_nchw_kcyx_nkhw<
GridSize, GridSize,
BlockSize, BlockSize,
T, TDevice,
T, TDevice,
decltype(in_nchw_desc), decltype(in_nchw_desc),
decltype(wei_kcyx_desc), decltype(wei_kcyx_desc),
decltype(out_nkhw_desc), decltype(out_nkhw_desc),
...@@ -235,13 +983,13 @@ void device_convolution_implicit_gemm_v4r4_nchw_kcyx_nkhw(InDesc, ...@@ -235,13 +983,13 @@ void device_convolution_implicit_gemm_v4r4_nchw_kcyx_nkhw(InDesc,
GemmMPerBlock, GemmMPerBlock,
GemmNPerBlock, GemmNPerBlock,
GemmKPerBlock, GemmKPerBlock,
GemmMPerThreadSubC, GemmMPerThread,
GemmNPerThreadSubC, GemmNPerThread,
GemmKPerThread,
GemmMLevel0Cluster, GemmMLevel0Cluster,
GemmNLevel0Cluster, GemmNLevel0Cluster,
GemmMLevel1Cluster, GemmMLevel1Cluster,
GemmNLevel1Cluster, GemmNLevel1Cluster,
GemmKPerThreadLoop,
ThreadGemmDataPerReadM, ThreadGemmDataPerReadM,
ThreadGemmDataPerReadN, ThreadGemmDataPerReadN,
GemmABlockCopyThreadSliceLengths_GemmK_GemmM, GemmABlockCopyThreadSliceLengths_GemmK_GemmM,
...@@ -252,25 +1000,38 @@ void device_convolution_implicit_gemm_v4r4_nchw_kcyx_nkhw(InDesc, ...@@ -252,25 +1000,38 @@ void device_convolution_implicit_gemm_v4r4_nchw_kcyx_nkhw(InDesc,
GemmBBlockCopyThreadClusterLengths_GemmK_GemmN, GemmBBlockCopyThreadClusterLengths_GemmK_GemmN,
GemmBBlockCopySrcDataPerRead_GemmN, GemmBBlockCopySrcDataPerRead_GemmN,
GemmBBlockCopyDstDataPerWrite_GemmN, GemmBBlockCopyDstDataPerWrite_GemmN,
GemmCThreadCopyDstDataPerWrite_GemmN1>{}; GemmCThreadCopyDstDataPerWrite_GemmN1>;
for(index_t i = 0; i < nrepeat; ++i) for(index_t i = 0; i < 5; ++i)
{ {
float time = std::cout << "Start running " << nrepeat << " times..." << std::endl;
launch_and_time_kernel(run_gridwise_convolution_kernel<decltype(gridwise_conv), T>,
dim3(GridSize), KernelTimer timer;
dim3(BlockSize), timer.Start();
0,
0, for(index_t j = 0; j < nrepeat; ++j)
static_cast<T*>(in_nchw_device_buf.GetDeviceBuffer()), {
static_cast<T*>(wei_kcyx_device_buf.GetDeviceBuffer()), launch_kernel(run_gridwise_operation<gridwise_conv,
static_cast<T*>(out_nkhw_device_buf.GetDeviceBuffer())); const TDevice* const __restrict__,
const TDevice* const __restrict__,
printf("Elapsed time : %f ms, %f TFlop/s\n", TDevice* const __restrict__>,
time, dim3(GridSize),
(float)calculate_convolution_flops(InDesc{}, WeiDesc{}, OutDesc{}) / dim3(BlockSize),
(std::size_t(1000) * 1000 * 1000) / time); 0,
usleep(std::min(time * 1000, float(10000))); 0,
static_cast<TDevice*>(in_nchw_device_buf.GetDeviceBuffer()),
static_cast<TDevice*>(wei_kcyx_device_buf.GetDeviceBuffer()),
static_cast<TDevice*>(out_nkhw_device_buf.GetDeviceBuffer()));
}
timer.End();
float ave_time = timer.GetElapsedTime() / nrepeat;
float perf = (float)calculate_convolution_flops(InDesc{}, WeiDesc{}, OutDesc{}) /
(std::size_t(1000) * 1000 * 1000) / ave_time;
std::cout << "Average time : " << ave_time << " ms, " << perf << " TFlop/s" << std::endl;
} }
out_nkhw_device_buf.FromDevice(out_nkhw.mData.data()); out_nkhw_device_buf.FromDevice(out_nkhw.mData.data());
......
#pragma once
#include <unistd.h>
#include "device.hpp"
#include "tensor.hpp"
#include "gridwise_convolution_kernel_wrapper.hpp"
#include "gridwise_convolution_implicit_gemm_v4r4_nchw_kcyx_nkhw_lds_double_buffer_deprecated.hpp"
using namespace ck;
template <class T,
class InDesc,
class WeiDesc,
class OutDesc,
class ConvStrides,
class ConvDilations>
void device_convolution_implicit_gemm_v4r4_nchw_kcyx_nkhw_deprecated(InDesc,
const Tensor<T>& in_nchw,
WeiDesc,
const Tensor<T>& wei_kcyx,
OutDesc,
Tensor<T>& out_nkhw,
ConvStrides,
ConvDilations,
ck::index_t nrepeat)
{
constexpr auto I0 = Number<0>{};
constexpr auto I1 = Number<1>{};
constexpr auto I2 = Number<2>{};
constexpr auto I3 = Number<3>{};
constexpr auto in_nchw_desc = InDesc{};
constexpr auto wei_kcyx_desc = WeiDesc{};
constexpr auto out_nkhw_desc = OutDesc{};
constexpr index_t N = out_nkhw_desc.GetLength(I0);
constexpr index_t K = out_nkhw_desc.GetLength(I1);
constexpr index_t Ho = out_nkhw_desc.GetLength(I2);
constexpr index_t Wo = out_nkhw_desc.GetLength(I3);
std::size_t data_sz = sizeof(T);
DeviceMem in_nchw_device_buf(data_sz * in_nchw.mDesc.GetElementSpace());
DeviceMem wei_kcyx_device_buf(data_sz * wei_kcyx.mDesc.GetElementSpace());
DeviceMem out_nkhw_device_buf(data_sz * out_nkhw.mDesc.GetElementSpace());
in_nchw_device_buf.ToDevice(in_nchw.mData.data());
wei_kcyx_device_buf.ToDevice(wei_kcyx.mData.data());
out_nkhw_device_buf.ToDevice(out_nkhw.mData.data());
#if 1
constexpr index_t BlockSize = 256;
constexpr index_t BPerBlock = 128;
constexpr index_t KPerBlock = 128;
constexpr index_t EPerBlock = 8;
constexpr index_t GemmMPerThreadSubC = 4;
constexpr index_t GemmNPerThreadSubC = 4;
constexpr index_t GemmMLevel0Cluster = 4;
constexpr index_t GemmNLevel0Cluster = 4;
constexpr index_t GemmMLevel1Cluster = 4;
constexpr index_t GemmNLevel1Cluster = 4;
constexpr index_t GemmKPerThreadLoop = 1;
constexpr index_t GemmDataPerReadA = 4;
constexpr index_t GemmDataPerReadB = 4;
using InBlockCopySubLengths_E_B = Sequence<4, 1>;
using InBlockCopyClusterLengths_E_B = Sequence<2, 128>;
using InBlockCopyThreadClusterArrangeOrder = Sequence<0, 1>; // [E, B]
using InBlockCopySrcAccessOrder = Sequence<0, 1>; // [E, B]
using InBlockCopyDstAccessOrder = Sequence<0, 1>; // [E, B]
constexpr index_t InBlockCopyDataPerAccess_B = 1;
using WeiBlockCopySubLengths_E_K = Sequence<4, 1>;
using WeiBlockCopyClusterLengths_E_K = Sequence<2, 128>;
using WeiBlockCopyThreadClusterArrangeOrder = Sequence<1, 0>; // [K, E]
using WeiBlockCopySrcAccessOrder = Sequence<1, 0>; // [K, E]
using WeiBlockCopyDstAccessOrder = Sequence<0, 1>; // [E, K]
constexpr index_t WeiBlockCopySrcDataPerRead_E = 4;
constexpr index_t WeiBlockCopyDstDataPerWrite_K = 1;
constexpr index_t OutThreadCopyDataPerAccess_B = 1;
#elif 1
// 1x1 filter, 8x8 image
constexpr index_t BlockSize = 256;
constexpr index_t BPerBlock = 128;
constexpr index_t KPerBlock = 128;
constexpr index_t EPerBlock = 8;
constexpr index_t GemmMPerThreadSubC = 4;
constexpr index_t GemmNPerThreadSubC = 4;
constexpr index_t GemmMLevel0Cluster = 4;
constexpr index_t GemmNLevel0Cluster = 4;
constexpr index_t GemmMLevel1Cluster = 4;
constexpr index_t GemmNLevel1Cluster = 4;
constexpr index_t GemmKPerThreadLoop = 1;
constexpr index_t GemmDataPerReadA = 4;
constexpr index_t GemmDataPerReadB = 4;
using InBlockCopySubLengths_E_B = Sequence<1, 4>;
using InBlockCopyClusterLengths_E_B = Sequence<8, 32>;
using InBlockCopyThreadClusterArrangeOrder = Sequence<0, 1>; // [E, B]
using InBlockCopySrcAccessOrder = Sequence<0, 1>; // [E, B]
using InBlockCopyDstAccessOrder = Sequence<0, 1>; // [E, B]
constexpr index_t InBlockCopyDataPerAccess_B = 4;
using WeiBlockCopySubLengths_E_K = Sequence<4, 1>;
using WeiBlockCopyClusterLengths_E_K = Sequence<2, 128>;
using WeiBlockCopyThreadClusterArrangeOrder = Sequence<1, 0>; // [K, E]
using WeiBlockCopySrcAccessOrder = Sequence<1, 0>; // [K, E]
using WeiBlockCopyDstAccessOrder = Sequence<0, 1>; // [E, K]
constexpr index_t WeiBlockCopySrcDataPerRead_E = 4;
constexpr index_t WeiBlockCopyDstDataPerWrite_K = 1;
constexpr index_t OutThreadCopyDataPerAccess_B = 4;
#elif 0
// 1x1 filter, 14x14 image
constexpr index_t BlockSize = 256;
constexpr index_t BPerBlock = 128;
constexpr index_t KPerBlock = 128;
constexpr index_t EPerBlock = 8;
constexpr index_t GemmMPerThreadSubC = 4;
constexpr index_t GemmNPerThreadSubC = 4;
constexpr index_t GemmMLevel0Cluster = 4;
constexpr index_t GemmNLevel0Cluster = 4;
constexpr index_t GemmMLevel1Cluster = 4;
constexpr index_t GemmNLevel1Cluster = 4;
constexpr index_t GemmKPerThreadLoop = 1;
constexpr index_t GemmDataPerReadA = 4;
constexpr index_t GemmDataPerReadB = 4;
using InBlockCopySubLengths_E_B = Sequence<2, 2>;
using InBlockCopyClusterLengths_E_B = Sequence<4, 64>;
using InBlockCopyThreadClusterArrangeOrder = Sequence<0, 1>; // [E, B]
using InBlockCopySrcAccessOrder = Sequence<0, 1>; // [E, B]
using InBlockCopyDstAccessOrder = Sequence<0, 1>; // [E, B]
constexpr index_t InBlockCopyDataPerAccess_B = 2;
using WeiBlockCopySubLengths_E_K = Sequence<4, 1>;
using WeiBlockCopyClusterLengths_E_K = Sequence<2, 128>;
using WeiBlockCopyThreadClusterArrangeOrder = Sequence<1, 0>; // [K, E]
using WeiBlockCopySrcAccessOrder = Sequence<1, 0>; // [K, E]
using WeiBlockCopyDstAccessOrder = Sequence<0, 1>; // [E, K]
constexpr index_t WeiBlockCopySrcDataPerRead_E = 4;
constexpr index_t WeiBlockCopyDstDataPerWrite_K = 1;
constexpr index_t OutThreadCopyDataPerAccess_B = 2;
#endif
constexpr index_t B = N * Ho * Wo;
constexpr index_t GridSize =
((B + BPerBlock - 1) / BPerBlock) * ((K + KPerBlock - 1) / KPerBlock);
printf("%s: BlockSize %u, GridSize %u \n", __func__, BlockSize, GridSize);
constexpr auto gridwise_conv =
#if 0
GridwiseConvolutionImplicitGemm_v4r4_nchw_kcyx_nkhw
#else
GridwiseConvolutionImplicitGemm_v4r4_nchw_kcyx_nkhw_lds_double_buffer_deprecated
#endif
<GridSize,
BlockSize,
T,
decltype(in_nchw_desc),
decltype(wei_kcyx_desc),
decltype(out_nkhw_desc),
ConvStrides,
ConvDilations,
BPerBlock,
KPerBlock,
EPerBlock,
GemmMPerThreadSubC,
GemmNPerThreadSubC,
GemmMLevel0Cluster,
GemmNLevel0Cluster,
GemmMLevel1Cluster,
GemmNLevel1Cluster,
GemmKPerThreadLoop,
GemmDataPerReadA,
GemmDataPerReadB,
InBlockCopySubLengths_E_B,
InBlockCopyClusterLengths_E_B,
InBlockCopyThreadClusterArrangeOrder,
InBlockCopySrcAccessOrder,
InBlockCopyDstAccessOrder,
InBlockCopyDataPerAccess_B,
WeiBlockCopySubLengths_E_K,
WeiBlockCopyClusterLengths_E_K,
WeiBlockCopyThreadClusterArrangeOrder,
WeiBlockCopySrcAccessOrder,
WeiBlockCopyDstAccessOrder,
WeiBlockCopySrcDataPerRead_E,
WeiBlockCopyDstDataPerWrite_K,
OutThreadCopyDataPerAccess_B>{};
for(index_t i = 0; i < nrepeat; ++i)
{
float time =
launch_and_time_kernel(run_gridwise_convolution_kernel<decltype(gridwise_conv), T>,
dim3(GridSize),
dim3(BlockSize),
0,
static_cast<T*>(in_nchw_device_buf.GetDeviceBuffer()),
static_cast<T*>(wei_kcyx_device_buf.GetDeviceBuffer()),
static_cast<T*>(out_nkhw_device_buf.GetDeviceBuffer()));
printf("Elapsed time : %f ms, %f TFlop/s\n",
time,
(float)calculate_convolution_flops(InDesc{}, WeiDesc{}, OutDesc{}) /
(std::size_t(1000) * 1000 * 1000) / time);
usleep(std::min(time * 1000, float(10000)));
}
out_nkhw_device_buf.FromDevice(out_nkhw.mData.data());
}
#pragma once
#include <unistd.h>
#include "device.hpp"
#include "tensor.hpp"
#include "gridwise_direct_convolution_2_vectorized_nchw_kcyx_nkhw.hpp"
using namespace ck;
template <class TInWei, class TOut, class InDesc, class WeiDesc, class OutDesc>
void device_direct_convolution_2_vectorized_nchw_kcyx_nkhw(InDesc,
const Tensor<TInWei>& in_nchw,
WeiDesc,
const Tensor<TInWei>& wei_kcyx,
OutDesc,
Tensor<TOut>& out_nkhw,
index_t nrepeat)
{
// this suppose in / wei data type is int8x4
constexpr index_t NVector = 4;
using accum_t = int32_t;
using vector_t = vector_type<TInWei, NVector>;
using vector_mem_t = typename vector_t::MemoryType;
constexpr auto I0 = Number<0>{};
constexpr auto I1 = Number<1>{};
constexpr auto I2 = Number<2>{};
constexpr auto I3 = Number<3>{};
constexpr auto in_nchw_desc = InDesc{};
constexpr auto wei_kcyx_desc = WeiDesc{};
constexpr auto out_nkhw_desc = OutDesc{};
constexpr index_t Hi = in_nchw_desc.GetLength(I2);
constexpr index_t Wi = in_nchw_desc.GetLength(I3);
constexpr index_t N = out_nkhw_desc.GetLength(I0);
constexpr index_t Ho = out_nkhw_desc.GetLength(I2);
constexpr index_t Wo = out_nkhw_desc.GetLength(I3);
constexpr index_t K = wei_kcyx_desc.GetLength(I0);
constexpr index_t C = wei_kcyx_desc.GetLength(I1);
constexpr index_t Y = wei_kcyx_desc.GetLength(I2);
constexpr index_t X = wei_kcyx_desc.GetLength(I3);
// vectorized input
auto in_nchw_vec_desc = make_ConstantTensorDescriptor(Sequence<N, C / NVector, Hi, Wi>{});
ostream_ConstantTensorDescriptor(in_nchw_vec_desc, std::cout << "in_nchw_vec_desc: ");
Tensor<vector_mem_t> in_nchw_vec(make_TensorDescriptor(in_nchw_vec_desc));
auto f_vectorized_nchw = [&](auto n, auto c, auto h, auto w) {
#if 0
in_nchw_vec(n, c, h, w) = in_nchw(n, c, h, w);
#elif 0
in_nchw_vec(n, c, h, w) =
vector_t::Pack(in_nchw(n, 2 * c, h, w), in_nchw(n, 2 * c + 1, h, w));
#elif 1
in_nchw_vec(n, c, h, w) = vector_t::Pack(in_nchw(n, 4 * c, h, w),
in_nchw(n, 4 * c + 1, h, w),
in_nchw(n, 4 * c + 2, h, w),
in_nchw(n, 4 * c + 3, h, w));
#endif
};
make_ParallelTensorFunctor(f_vectorized_nchw, N, C / NVector, Hi, Wi)(
std::thread::hardware_concurrency());
// vectorize weight
auto wei_kcyx_vec_desc = make_ConstantTensorDescriptor(Sequence<K, C / NVector, Y, X>{});
ostream_ConstantTensorDescriptor(wei_kcyx_vec_desc, std::cout << "wei_kcyx_vec_desc: ");
Tensor<vector_mem_t> wei_kcyx_vec(make_TensorDescriptor(wei_kcyx_vec_desc));
auto f_vectorized_kcyx = [&](auto k, auto c, auto y, auto x) {
#if 0
wei_kcyx_vec(k, c, y, x) = wei_kcyx(k, c, y, x);
#elif 0
wei_kcyx_vec(k, c, y, x) =
vector_t::Pack(wei_kcyx(k, 2 * c, y, x), wei_kcyx(k, 2 * c + 1, y, x));
#elif 1
wei_kcyx_vec(k, c, y, x) = vector_t::Pack(wei_kcyx(k, 4 * c, y, x),
wei_kcyx(k, 4 * c + 1, y, x),
wei_kcyx(k, 4 * c + 2, y, x),
wei_kcyx(k, 4 * c + 3, y, x));
#endif
};
make_ParallelTensorFunctor(f_vectorized_kcyx, K, C / NVector, Y, X)(
std::thread::hardware_concurrency());
//
DeviceMem in_nchw_vec_device_buf(sizeof(vector_mem_t) * in_nchw_vec.mDesc.GetElementSpace());
DeviceMem wei_kcyx_vec_device_buf(sizeof(vector_mem_t) * wei_kcyx_vec.mDesc.GetElementSpace());
DeviceMem out_nkhw_device_buf(sizeof(TOut) * out_nkhw.mDesc.GetElementSpace());
in_nchw_vec_device_buf.ToDevice(in_nchw_vec.mData.data());
wei_kcyx_vec_device_buf.ToDevice(wei_kcyx_vec.mData.data());
out_nkhw_device_buf.ToDevice(out_nkhw.mData.data());
#if 0
// 3x3, 34x34, 128 thread, fp32, vector = 1
constexpr index_t NPerBlock = 2;
constexpr index_t KPerBlock = 32;
constexpr index_t CPerBlock = 4;
constexpr index_t HoPerBlock = 2;
constexpr index_t WoPerBlock = 32;
constexpr index_t NPerThread = 2;
constexpr index_t KPerThread = 4;
constexpr index_t CPerThread = 2;
constexpr index_t HoPerThread = 2;
constexpr index_t WoPerThread = 2;
constexpr index_t InBlockCopyDataPerRead = 2;
constexpr index_t WeiBlockCopyDataPerRead = 2;
constexpr index_t BlockSize = 128;
#elif 0
// 3x3, 34x34, 128 thread, fp32, vector = 2
constexpr index_t NPerBlock = 2;
constexpr index_t KPerBlock = 32;
constexpr index_t CPerBlock = 2;
constexpr index_t HoPerBlock = 2;
constexpr index_t WoPerBlock = 32;
constexpr index_t NPerThread = 2;
constexpr index_t KPerThread = 4;
constexpr index_t CPerThread = 1;
constexpr index_t HoPerThread = 2;
constexpr index_t WoPerThread = 2;
constexpr index_t InBlockCopyDataPerRead = 2;
constexpr index_t WeiBlockCopyDataPerRead = 2;
constexpr index_t BlockSize = 128;
#elif 0
// 3x3, 34x34, 128 thread, int8, vector = 4
constexpr index_t NPerBlock = 2;
constexpr index_t KPerBlock = 32;
constexpr index_t CPerBlock = 8;
constexpr index_t HoPerBlock = 4;
constexpr index_t WoPerBlock = 32;
constexpr index_t NPerThread = 1;
constexpr index_t KPerThread = 8;
constexpr index_t CPerThread = 2;
constexpr index_t HoPerThread = 4;
constexpr index_t WoPerThread = 2;
constexpr index_t InBlockCopyDataPerRead = 2;
constexpr index_t WeiBlockCopyDataPerRead = 2;
constexpr index_t BlockSize = 128;
#elif 1
// 1x1, 32x32, 128 thread, int8, vector = 4
constexpr index_t NPerBlock = 1;
constexpr index_t KPerBlock = 64;
constexpr index_t CPerBlock = 16;
constexpr index_t HoPerBlock = 4;
constexpr index_t WoPerBlock = 32;
constexpr index_t NPerThread = 1;
constexpr index_t KPerThread = 8;
constexpr index_t CPerThread = 2;
constexpr index_t HoPerThread = 4;
constexpr index_t WoPerThread = 2;
constexpr index_t InBlockCopyDataPerRead = 2;
constexpr index_t WeiBlockCopyDataPerRead = 2;
constexpr index_t BlockSize = 128;
#endif
constexpr index_t GridSize =
(N / NPerBlock) * (K / KPerBlock) * (Ho / HoPerBlock) * (Wo / WoPerBlock);
printf("%s: BlockSize %u, GridSize %u \n", __func__, BlockSize, GridSize);
for(index_t i = 0; i < nrepeat; ++i)
{
float time = launch_and_time_kernel(
gridwise_direct_convolution_2_vectorized_nchw_kcyx_nkhw<TInWei,
TOut,
accum_t,
decltype(in_nchw_vec_desc),
decltype(wei_kcyx_vec_desc),
decltype(out_nkhw_desc),
NVector,
NPerBlock,
KPerBlock,
CPerBlock,
HoPerBlock,
WoPerBlock,
NPerThread,
KPerThread,
CPerThread,
HoPerThread,
WoPerThread,
InBlockCopyDataPerRead,
WeiBlockCopyDataPerRead,
BlockSize,
GridSize>,
dim3(GridSize),
dim3(BlockSize),
static_cast<TInWei*>(in_nchw_vec_device_buf.GetDeviceBuffer()),
static_cast<TInWei*>(wei_kcyx_vec_device_buf.GetDeviceBuffer()),
static_cast<TInWei*>(out_nkhw_device_buf.GetDeviceBuffer()));
printf("Elapsed time : %f ms\n", time);
usleep(std::min(time * 1000, float(10000)));
}
out_nkhw_device_buf.FromDevice(out_nkhw.mData.data());
}
#pragma once #pragma once
#include "tensor.hpp" #include "host_tensor.hpp"
#include "common_header.hpp" #include "common_header.hpp"
#include "ConstantTensorDescriptor_deprecated.hpp"
#include "tensor_descriptor.hpp" #include "tensor_descriptor.hpp"
template <typename ConstTensorDesc, std::size_t... Is> template <typename TensorDesc, std::size_t... Is>
auto make_TensorDescriptor_impl(ConstTensorDesc, std::integer_sequence<std::size_t, Is...>) auto make_HostTensorDescriptor_impl(TensorDesc, std::integer_sequence<std::size_t, Is...>)
{ {
std::initializer_list<std::size_t> lengths = {ConstTensorDesc::GetLengths()[Is]...}; std::initializer_list<std::size_t> lengths = {TensorDesc::GetLengths()[Is]...};
std::initializer_list<std::size_t> strides = {ConstTensorDesc::GetStrides()[Is]...}; std::initializer_list<std::size_t> strides = {TensorDesc::GetStrides()[Is]...};
return TensorDescriptor(lengths, strides); return HostTensorDescriptor(lengths, strides);
} }
template <typename ConstTensorDesc> template <typename TensorDesc>
auto make_TensorDescriptor(ConstTensorDesc) auto make_HostTensorDescriptor(TensorDesc)
{ {
return make_TensorDescriptor_impl( return make_HostTensorDescriptor_impl(
ConstTensorDesc{}, TensorDesc{}, std::make_integer_sequence<std::size_t, TensorDesc::GetNumOfDimension()>{});
std::make_integer_sequence<std::size_t, ConstTensorDesc::GetNumOfDimension()>{});
} }
template <typename ConstTensorDesc> template <typename TensorDesc>
void ostream_ConstantTensorDescriptor(ConstTensorDesc, std::ostream& os = std::cout) void ostream_tensor_descriptor(TensorDesc, std::ostream& os = std::cout)
{ {
ostream_TensorDescriptor(make_TensorDescriptor(ConstTensorDesc{}), os); ostream_HostTensorDescriptor(make_HostTensorDescriptor(TensorDesc{}), os);
} }
#pragma once
#include "tensor.hpp"
template <typename T,
typename FilterSizes,
typename OutputSizes,
typename ConvStrides,
typename ConvDilations,
typename LeftPads,
typename RightPads>
void host_col2im(const Tensor<T>& in_eb,
Tensor<T>& in_nchw,
FilterSizes,
OutputSizes,
ConvStrides,
ConvDilations,
LeftPads,
RightPads)
{
using namespace ck;
int N = in_nchw.mDesc.GetLengths()[0];
int C = in_nchw.mDesc.GetLengths()[1];
int HI = in_nchw.mDesc.GetLengths()[2];
int WI = in_nchw.mDesc.GetLengths()[3];
int Y = FilterSizes{}[0];
int X = FilterSizes{}[1];
int HO = OutputSizes{}[0];
int WO = OutputSizes{}[1];
auto f = [&](auto n, auto c, auto hi, auto wi) {
double v = 0;
for(int y = 0; y < Y; ++y)
{
int h_tmp = hi + LeftPads{}[0] - y * ConvDilations{}[0];
if(h_tmp >= 0 && h_tmp < HI && h_tmp % ConvStrides{}[0] == 0)
{
int ho = h_tmp / ConvStrides{}[0];
for(int x = 0; x < X; ++x)
{
int w_tmp = wi + LeftPads{}[1] - x * ConvDilations{}[1];
if(w_tmp >= 0 && w_tmp < WI && w_tmp % ConvStrides{}[1] == 0)
{
int wo = w_tmp / ConvStrides{}[1];
int e = c * (Y * X) + y * X + x;
int b = n * (HO * WO) + ho * WO + wo;
v += in_eb(e, b);
}
}
}
}
in_nchw(n, c, hi, wi) = v;
};
auto f_par = make_ParallelTensorFunctor(f,
in_nchw.mDesc.GetLengths()[0],
in_nchw.mDesc.GetLengths()[1],
in_nchw.mDesc.GetLengths()[2],
in_nchw.mDesc.GetLengths()[3]);
f_par(std::thread::hardware_concurrency());
}
#pragma once #pragma once
#include "tensor.hpp" #include "host_tensor.hpp"
template <class TIn, template <class TIn,
class TWei, class TWei,
...@@ -34,7 +34,8 @@ void host_direct_convolution(const Tensor<TIn>& in_nchw, ...@@ -34,7 +34,8 @@ void host_direct_convolution(const Tensor<TIn>& in_nchw,
if(hi >= 0 && hi < in_nchw.mDesc.GetLengths()[2] && wi >= 0 && if(hi >= 0 && hi < in_nchw.mDesc.GetLengths()[2] && wi >= 0 &&
wi < in_nchw.mDesc.GetLengths()[3]) wi < in_nchw.mDesc.GetLengths()[3])
{ {
v += double(in_nchw(n, c, hi, wi)) * double(wei_kcyx(k, c, y, x)); v += static_cast<const double>(in_nchw(n, c, hi, wi)) *
static_cast<const double>(wei_kcyx(k, c, y, x));
} }
} }
} }
......
#pragma once #pragma once
#include "tensor.hpp" #include "host_tensor.hpp"
template <typename TIn, template <typename TIn,
typename TWei, typename TWei,
......
#ifndef TENSOR_HPP #ifndef HOST_TENSOR_HPP
#define TENSOR_HPP #define HOST_TENSOR_HPP
#include <thread> #include <thread>
#include <vector> #include <vector>
...@@ -65,26 +65,26 @@ auto construct_f_unpack_args(F, T args) ...@@ -65,26 +65,26 @@ auto construct_f_unpack_args(F, T args)
return construct_f_unpack_args_impl<F>(args, std::make_index_sequence<N>{}); return construct_f_unpack_args_impl<F>(args, std::make_index_sequence<N>{});
} }
struct TensorDescriptor struct HostTensorDescriptor
{ {
TensorDescriptor() = delete; HostTensorDescriptor() = delete;
template <typename X> template <typename X>
TensorDescriptor(std::vector<X> lens); HostTensorDescriptor(std::vector<X> lens);
template <typename X, typename Y> template <typename X, typename Y>
TensorDescriptor(std::vector<X> lens, std::vector<Y> strides); HostTensorDescriptor(std::vector<X> lens, std::vector<Y> strides);
void CalculateStrides(); void CalculateStrides();
template <class Range> template <class Range>
TensorDescriptor(const Range& lens) : mLens(lens.begin(), lens.end()) HostTensorDescriptor(const Range& lens) : mLens(lens.begin(), lens.end())
{ {
this->CalculateStrides(); this->CalculateStrides();
} }
template <class Range1, class Range2> template <class Range1, class Range2>
TensorDescriptor(const Range1& lens, const Range2& strides) HostTensorDescriptor(const Range1& lens, const Range2& strides)
: mLens(lens.begin(), lens.end()), mStrides(strides.begin(), strides.end()) : mLens(lens.begin(), lens.end()), mStrides(strides.begin(), strides.end())
{ {
} }
...@@ -205,7 +205,7 @@ struct Tensor ...@@ -205,7 +205,7 @@ struct Tensor
{ {
} }
Tensor(const TensorDescriptor& desc) : mDesc(desc), mData(mDesc.GetElementSpace()) {} Tensor(const HostTensorDescriptor& desc) : mDesc(desc), mData(mDesc.GetElementSpace()) {}
template <class G> template <class G>
void GenerateTensorValue(G g, std::size_t num_thread = 1) void GenerateTensorValue(G g, std::size_t num_thread = 1)
...@@ -267,11 +267,11 @@ struct Tensor ...@@ -267,11 +267,11 @@ struct Tensor
typename std::vector<T>::const_iterator end() const { return mData.end(); } typename std::vector<T>::const_iterator end() const { return mData.end(); }
TensorDescriptor mDesc; HostTensorDescriptor mDesc;
std::vector<T> mData; std::vector<T> mData;
}; };
void ostream_TensorDescriptor(const TensorDescriptor& desc, std::ostream& os = std::cout) void ostream_HostTensorDescriptor(const HostTensorDescriptor& desc, std::ostream& os = std::cout)
{ {
os << "dim " << desc.GetNumOfDimension() << ", "; os << "dim " << desc.GetNumOfDimension() << ", ";
......
#ifndef TENSOR_GENERATOR_HPP #ifndef HOST_TENSOR_GENERATOR_HPP
#define TENSOR_GENERATOR_HPP #define HOST_TENSOR_GENERATOR_HPP
#include "config.hpp" #include "config.hpp"
......
#include <iostream>
#include <numeric>
#include <initializer_list>
#include <cstdlib>
#include <stdlib.h>
#include "config.hpp"
#include "tensor_descriptor.hpp"
#include "tensor_descriptor_helper.hpp"
#include "print_array.hpp"
#include "print_sequence.hpp"
#include "device.hpp"
#include "tensor_generator.hpp"
#include "device_tensor.hpp"
#include "conv_common.hpp"
#include "host_col2im.hpp"
#include "device_col2im_eb_nchw.hpp"
int main(int argc, char* argv[])
{
using namespace ck;
#if 1
constexpr index_t N = 2;
constexpr index_t C = 8;
constexpr index_t HI = 8;
constexpr index_t WI = 8;
constexpr index_t K = 128;
constexpr index_t Y = 4;
constexpr index_t X = 4;
using ConvStrides = Sequence<1, 1>;
using ConvDilations = Sequence<1, 1>;
using LeftPads = Sequence<1, 1>;
using RightPads = Sequence<2, 2>;
#elif 0
// 3x3, 34x34
constexpr index_t N = 64;
constexpr index_t C = 256;
constexpr index_t HI = 34;
constexpr index_t WI = 34;
constexpr index_t K = 128;
constexpr index_t Y = 3;
constexpr index_t X = 3;
using ConvStrides = Sequence<1, 1>;
using ConvDilations = Sequence<1, 1>;
using LeftPads = Sequence<0, 0>;
using RightPads = Sequence<0, 0>;
#elif 0
// 1x1 filter, 8x8 image
// cudnn@V100 68%, ck@V100 72%, ck@P100 52%, ck@VII 42%
constexpr index_t N = 64;
constexpr index_t C = 1536;
constexpr index_t HI = 8;
constexpr index_t WI = 8;
constexpr index_t K = 256;
constexpr index_t Y = 1;
constexpr index_t X = 1;
using ConvStrides = Sequence<1, 1>;
using ConvDilations = Sequence<1, 1>;
using LeftPads = Sequence<0, 0>;
using RightPads = Sequence<0, 0>;
#elif 0
// 1x1 filter, 8x8 image
// cudnn@V100 77%, ck@V100 76%, ck@P100 79%, ck@VII 51%
constexpr index_t N = 128;
constexpr index_t C = 2048;
constexpr index_t HI = 8;
constexpr index_t WI = 8;
constexpr index_t K = 384;
constexpr index_t Y = 1;
constexpr index_t X = 1;
using ConvStrides = Sequence<1, 1>;
using ConvDilations = Sequence<1, 1>;
using LeftPads = Sequence<0, 0>;
using RightPads = Sequence<0, 0>;
#elif 0
// 1x1 filter, 7x7 image
// cudnn@V100 82%, ck@V100 76%, ck@P100 67%, ck@VII 64%
constexpr index_t N = 128;
constexpr index_t C = 832;
constexpr index_t HI = 7;
constexpr index_t WI = 7;
constexpr index_t K = 384;
constexpr index_t Y = 1;
constexpr index_t X = 1;
using ConvStrides = Sequence<1, 1>;
using ConvDilations = Sequence<1, 1>;
using LeftPads = Sequence<0, 0>;
using RightPads = Sequence<0, 0>;
#elif 0
// 1x1 filter, 8x8 image
// cudnn@V100 83%, ck@V100 75%, ck@P100 78%, ck@VII 65%
constexpr index_t N = 128;
constexpr index_t C = 1280;
constexpr index_t HI = 8;
constexpr index_t WI = 8;
constexpr index_t K = 384;
constexpr index_t Y = 1;
constexpr index_t X = 1;
using ConvStrides = Sequence<1, 1>;
using ConvDilations = Sequence<1, 1>;
using LeftPads = Sequence<0, 0>;
using RightPads = Sequence<0, 0>;
#elif 0
// 1x1 filter, 14x14 image
// cudnn@V100 62%, ck@V100 68%, ck@P100 70%, ck@VII 50%
constexpr index_t N = 128;
constexpr index_t C = 512;
constexpr index_t HI = 14;
constexpr index_t WI = 14;
constexpr index_t K = 128;
constexpr index_t Y = 1;
constexpr index_t X = 1;
using ConvStrides = Sequence<1, 1>;
using ConvDilations = Sequence<1, 1>;
using LeftPads = Sequence<0, 0>;
using RightPads = Sequence<0, 0>;
#elif 0
// 1x1 filter, 8x8 image
// cudnn@V100 74%, ck@V100 57%, ck@P100 78%, ck@VII 61%
constexpr index_t N = 64;
constexpr index_t C = 1536;
constexpr index_t HI = 8;
constexpr index_t WI = 8;
constexpr index_t K = 384;
constexpr index_t Y = 1;
constexpr index_t X = 1;
using ConvStrides = Sequence<1, 1>;
using ConvDilations = Sequence<1, 1>;
using LeftPads = Sequence<0, 0>;
using RightPads = Sequence<0, 0>;
#elif 0
// 1x1 filter, 28x28 image
// cudnn@V100 86%, ck@V100 84%, ck@P100 80%, ck@VII 69%
constexpr index_t N = 128;
constexpr index_t C = 256;
constexpr index_t HI = 28;
constexpr index_t WI = 28;
constexpr index_t K = 128;
constexpr index_t Y = 1;
constexpr index_t X = 1;
using ConvStrides = Sequence<1, 1>;
using ConvDilations = Sequence<1, 1>;
using LeftPads = Sequence<0, 0>;
using RightPads = Sequence<0, 0>;
#elif 0
// 1x1 filter, 7x7 image
// cudnn@V100 71%, ck@V100 55%, ck@P100 70%, ck@VII 62%
constexpr index_t N = 128;
constexpr index_t C = 832;
constexpr index_t HI = 7;
constexpr index_t WI = 7;
constexpr index_t K = 256;
constexpr index_t Y = 1;
constexpr index_t X = 1;
using ConvStrides = Sequence<1, 1>;
using ConvDilations = Sequence<1, 1>;
using LeftPads = Sequence<0, 0>;
using RightPads = Sequence<0, 0>;
#elif 0
// 1x1 filter, 17x17 input
// cudnn@V100 81%, ck@V100 76%, ck@P100 70%, ck@VII 76%
constexpr index_t N = 128;
constexpr index_t C = 768;
constexpr index_t HI = 17;
constexpr index_t WI = 17;
constexpr index_t K = 128;
constexpr index_t Y = 1;
constexpr index_t X = 1;
using ConvStrides = Sequence<1, 1>;
using ConvDilations = Sequence<1, 1>;
using LeftPads = Sequence<0, 0>;
using RightPads = Sequence<0, 0>;
#elif 0
// 1x1 filter, 14x14 image
// cudnn@V100 73%, ck@V100 71%, ck@P100 70%, ck@VII 64%
constexpr index_t N = 128;
constexpr index_t C = 528;
constexpr index_t HI = 14;
constexpr index_t WI = 14;
constexpr index_t K = 128;
constexpr index_t Y = 1;
constexpr index_t X = 1;
using ConvStrides = Sequence<1, 1>;
using ConvDilations = Sequence<1, 1>;
using LeftPads = Sequence<0, 0>;
using RightPads = Sequence<0, 0>;
#elif 0
// 1x1 filter, 14x14 image
// cudnn@V100 73%, ck@V100 72%, ck@P100 79%, ck@VII 75%
constexpr index_t N = 128;
constexpr index_t C = 528;
constexpr index_t HI = 14;
constexpr index_t WI = 14;
constexpr index_t K = 256;
constexpr index_t Y = 1;
constexpr index_t X = 1;
using ConvStrides = Sequence<1, 1>;
using ConvDilations = Sequence<1, 1>;
using LeftPads = Sequence<0, 0>;
using RightPads = Sequence<0, 0>;
#elif 0
// 1x1 filter, 7x7 image
// cudnn@V100 49%, ck@V100 50%, ck@P100 61%, ck@VII 52%
constexpr index_t N = 128;
constexpr index_t C = 832;
constexpr index_t HI = 7;
constexpr index_t WI = 7;
constexpr index_t K = 128;
constexpr index_t Y = 1;
constexpr index_t X = 1;
using ConvStrides = Sequence<1, 1>;
using ConvDilations = Sequence<1, 1>;
using LeftPads = Sequence<0, 0>;
using RightPads = Sequence<0, 0>;
#elif 0
// 3x3 filter, 2x2 stride, 35x35 input, 17x17 output
// cudnn@V100 90%, ck@V100 93%, ck@P100 83%, ck@VII 81%
constexpr index_t N = 128;
constexpr index_t C = 288;
constexpr index_t HI = 35;
constexpr index_t WI = 35;
constexpr index_t K = 384;
constexpr index_t Y = 3;
constexpr index_t X = 3;
using ConvStrides = Sequence<2, 2>;
using ConvDilations = Sequence<1, 1>;
using LeftPads = Sequence<0, 0>;
using RightPads = Sequence<0, 0>;
#elif 0
// 5x5 filter, 2x2 pad, 7x7 input
constexpr index_t N = 128;
constexpr index_t C = 48;
constexpr index_t HI = 7;
constexpr index_t WI = 7;
constexpr index_t K = 128;
constexpr index_t Y = 5;
constexpr index_t X = 5;
using ConvStrides = Sequence<1, 1>;
using ConvDilations = Sequence<1, 1>;
using LeftPads = Sequence<2, 2>;
using RightPads = Sequence<2, 2>;
#elif 0
// 7x1 filter, 3x0 pad, 17x17 input
constexpr index_t N = 128;
constexpr index_t C = 128;
constexpr index_t HI = 17;
constexpr index_t WI = 17;
constexpr index_t K = 128;
constexpr index_t Y = 7;
constexpr index_t X = 1;
using ConvStrides = Sequence<1, 1>;
using ConvDilations = Sequence<1, 1>;
using LeftPads = Sequence<3, 0>;
using RightPads = Sequence<3, 0>;
#elif 1
// 1x7 filter, 0x3 pad, 17x17 input
constexpr index_t N = 128;
constexpr index_t C = 128;
constexpr index_t HI = 17;
constexpr index_t WI = 17;
constexpr index_t K = 128;
constexpr index_t Y = 1;
constexpr index_t X = 7;
using ConvStrides = Sequence<1, 1>;
using ConvDilations = Sequence<1, 1>;
using LeftPads = Sequence<0, 3>;
using RightPads = Sequence<0, 3>;
#endif
constexpr auto img_nchw_desc = make_native_tensor_descriptor_packed(Sequence<N, C, HI, WI>{});
constexpr auto wei_kcyx_desc = make_native_tensor_descriptor_packed(Sequence<K, C, Y, X>{});
constexpr auto out_nkhw_desc = get_convolution_output_default_4d_tensor_descriptor(
img_nchw_desc, wei_kcyx_desc, ConvStrides{}, ConvDilations{}, LeftPads{}, RightPads{});
constexpr index_t HO = out_nkhw_desc.GetLengths()[2];
constexpr index_t WO = out_nkhw_desc.GetLengths()[3];
constexpr auto col_eb_desc =
make_native_tensor_descriptor_packed(Sequence<C * Y * X, N * HO * WO>{});
using FilterSizes = Sequence<Y, X>;
using OutputSizes = Sequence<HO, WO>;
ostream_ConstantTensorDescriptor(col_eb_desc, std::cout << "col_eb_desc: ");
ostream_ConstantTensorDescriptor(img_nchw_desc, std::cout << "img_nchw_desc: ");
print_sequence("FilterSizes", FilterSizes{});
print_sequence("OutputSizes", OutputSizes{});
print_sequence("LeftPads", LeftPads{});
print_sequence("LeftPads", LeftPads{});
print_sequence("RightPads", RightPads{});
print_sequence("ConvStrides", ConvStrides{});
print_sequence("ConvDilations", ConvDilations{});
Tensor<float> col_eb(make_TensorDescriptor(col_eb_desc));
Tensor<float> img_nchw_host(make_TensorDescriptor(img_nchw_desc));
Tensor<float> img_nchw_device(make_TensorDescriptor(img_nchw_desc));
std::size_t num_thread = std::thread::hardware_concurrency();
if(argc != 3)
{
printf("arg1: do_verification, arg2: nrepeat\n");
exit(1);
}
bool do_verification = atoi(argv[1]);
std::size_t nrepeat = atoi(argv[2]);
if(do_verification)
{
#if 0
col_eb.GenerateTensorValue(GeneratorTensor_1{}, num_thread);
#else
col_eb.GenerateTensorValue(GeneratorTensor_2{-5, 5}, num_thread);
#endif
}
device_col2im_eb_nchw(col_eb_desc,
col_eb,
img_nchw_desc,
img_nchw_device,
FilterSizes{},
OutputSizes{},
ConvStrides{},
ConvDilations{},
LeftPads{},
RightPads{},
nrepeat);
if(do_verification)
{
host_col2im(col_eb,
img_nchw_host,
FilterSizes{},
OutputSizes{},
ConvStrides{},
ConvDilations{},
LeftPads{},
RightPads{});
check_error(img_nchw_host, img_nchw_device);
#if 0
LogRange(std::cout << "col_eb : ", col_eb.mData, ",") << std::endl;
LogRange(std::cout << "img_nchw_host : ", img_nchw_host.mData, ",") << std::endl;
LogRange(std::cout << "img_nchw_device : ", img_nchw_device.mData, ",") << std::endl;
#endif
}
}
col2im_driver.cpp
\ No newline at end of file
...@@ -9,7 +9,7 @@ ...@@ -9,7 +9,7 @@
#include "print_array.hpp" #include "print_array.hpp"
#include "print_sequence.hpp" #include "print_sequence.hpp"
#include "device.hpp" #include "device.hpp"
#include "tensor_generator.hpp" #include "host_tensor_generator.hpp"
#include "device_tensor.hpp" #include "device_tensor.hpp"
#include "conv_common.hpp" #include "conv_common.hpp"
#include "host_conv_bwd_data.hpp" #include "host_conv_bwd_data.hpp"
...@@ -23,7 +23,7 @@ int main(int argc, char* argv[]) ...@@ -23,7 +23,7 @@ int main(int argc, char* argv[])
{ {
using namespace launcher; using namespace launcher;
#if 1 #if 0
constexpr index_t N = 64; constexpr index_t N = 64;
constexpr index_t C = 256; constexpr index_t C = 256;
constexpr index_t HI = 56; constexpr index_t HI = 56;
...@@ -160,10 +160,10 @@ int main(int argc, char* argv[]) ...@@ -160,10 +160,10 @@ int main(int argc, char* argv[])
#elif 0 #elif 0
// 1x7 filter, 0x3 pad, 17x17 input // 1x7 filter, 0x3 pad, 17x17 input
constexpr index_t N = 128; constexpr index_t N = 128;
constexpr index_t C = 128; constexpr index_t C = 1024;
constexpr index_t HI = 17; constexpr index_t HI = 17;
constexpr index_t WI = 17; constexpr index_t WI = 17;
constexpr index_t K = 128; constexpr index_t K = 1024;
constexpr index_t Y = 1; constexpr index_t Y = 1;
constexpr index_t X = 7; constexpr index_t X = 7;
...@@ -190,10 +190,10 @@ int main(int argc, char* argv[]) ...@@ -190,10 +190,10 @@ int main(int argc, char* argv[])
#elif 1 #elif 1
// 3x3 filter, 2x2 stride, 35x35 input, 17x17 output // 3x3 filter, 2x2 stride, 35x35 input, 17x17 output
constexpr index_t N = 128; constexpr index_t N = 128;
constexpr index_t C = 1024; constexpr index_t C = 128;
constexpr index_t HI = 35; constexpr index_t HI = 35;
constexpr index_t WI = 35; constexpr index_t WI = 35;
constexpr index_t K = 128; constexpr index_t K = 1024;
constexpr index_t Y = 3; constexpr index_t Y = 3;
constexpr index_t X = 3; constexpr index_t X = 3;
...@@ -209,19 +209,19 @@ int main(int argc, char* argv[]) ...@@ -209,19 +209,19 @@ int main(int argc, char* argv[])
constexpr auto out_nkhw_desc = get_convolution_output_default_4d_tensor_descriptor( constexpr auto out_nkhw_desc = get_convolution_output_default_4d_tensor_descriptor(
in_nchw_desc, wei_kcyx_desc, ConvStrides{}, ConvDilations{}, LeftPads{}, RightPads{}); in_nchw_desc, wei_kcyx_desc, ConvStrides{}, ConvDilations{}, LeftPads{}, RightPads{});
ostream_ConstantTensorDescriptor(in_nchw_desc, std::cout << "in_nchw_desc: "); ostream_tensor_descriptor(in_nchw_desc, std::cout << "in_nchw_desc: ");
ostream_ConstantTensorDescriptor(wei_kcyx_desc, std::cout << "wei_kcyx_desc: "); ostream_tensor_descriptor(wei_kcyx_desc, std::cout << "wei_kcyx_desc: ");
ostream_ConstantTensorDescriptor(out_nkhw_desc, std::cout << "out_nkhw_desc: "); ostream_tensor_descriptor(out_nkhw_desc, std::cout << "out_nkhw_desc: ");
print_sequence("LeftPads", LeftPads{}); print_sequence("LeftPads", LeftPads{});
print_sequence("LeftPads", LeftPads{}); print_sequence("LeftPads", LeftPads{});
print_sequence("RightPads", RightPads{}); print_sequence("RightPads", RightPads{});
print_sequence("ConvStrides", ConvStrides{}); print_sequence("ConvStrides", ConvStrides{});
print_sequence("ConvDilations", ConvDilations{}); print_sequence("ConvDilations", ConvDilations{});
Tensor<float> in_nchw_device(make_TensorDescriptor(in_nchw_desc)); Tensor<float> in_nchw_device(make_HostTensorDescriptor(in_nchw_desc));
Tensor<float> in_nchw_host(make_TensorDescriptor(in_nchw_desc)); Tensor<float> in_nchw_host(make_HostTensorDescriptor(in_nchw_desc));
Tensor<float> wei_kcyx(make_TensorDescriptor(wei_kcyx_desc)); Tensor<float> wei_kcyx(make_HostTensorDescriptor(wei_kcyx_desc));
Tensor<float> out_nkhw(make_TensorDescriptor(out_nkhw_desc)); Tensor<float> out_nkhw(make_HostTensorDescriptor(out_nkhw_desc));
std::size_t num_thread = std::thread::hardware_concurrency(); std::size_t num_thread = std::thread::hardware_concurrency();
...@@ -245,9 +245,9 @@ int main(int argc, char* argv[]) ...@@ -245,9 +245,9 @@ int main(int argc, char* argv[])
#endif #endif
} }
#if 1 #if 0
device_convolution_backward_data_implicit_gemm_v1r1_nchw_kcyx_nkhw device_convolution_backward_data_implicit_gemm_v1r1_nchw_kcyx_nkhw
#elif 0 #elif 1
device_convolution_backward_data_implicit_gemm_v1r2_nchw_kcyx_nkhw device_convolution_backward_data_implicit_gemm_v1r2_nchw_kcyx_nkhw
#elif 0 #elif 0
device_convolution_backward_data_implicit_gemm_v2r1_nchw_kcyx_nkhw device_convolution_backward_data_implicit_gemm_v2r1_nchw_kcyx_nkhw
...@@ -256,17 +256,17 @@ int main(int argc, char* argv[]) ...@@ -256,17 +256,17 @@ int main(int argc, char* argv[])
#elif 1 #elif 1
device_convolution_backward_data_implicit_gemm_v4r1_nchw_kcyx_nkhw device_convolution_backward_data_implicit_gemm_v4r1_nchw_kcyx_nkhw
#endif #endif
(in_nchw_desc, (in_nchw_desc,
in_nchw_device, in_nchw_device,
wei_kcyx_desc, wei_kcyx_desc,
wei_kcyx, wei_kcyx,
out_nkhw_desc, out_nkhw_desc,
out_nkhw, out_nkhw,
ConvStrides{}, ConvStrides{},
ConvDilations{}, ConvDilations{},
LeftPads{}, LeftPads{},
RightPads{}, RightPads{},
nrepeat); nrepeat);
if(do_verification) if(do_verification)
{ {
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment