"...en/git@developer.sourcefind.cn:renzhc/diffusers_dcu.git" did not exist on "2171f77ac588be72e272ee2190836db434208fb2"
Commit 7a970877 authored by Chao Liu's avatar Chao Liu
Browse files

refactor

parent 43cd8529
#pragma once
#include <unistd.h>
#include "device.hpp"
#include "gridwise_implicit_gemm_convolution_1_chwn_csrk_khwn.hip.hpp"
template <class T, class InDesc, class WeiDesc, class OutDesc>
void device_implicit_gemm_convolution_1_chwn_csrk_khwn(InDesc,
const Tensor<T>& in_nchw,
WeiDesc,
const Tensor<T>& wei_kcsr,
OutDesc,
Tensor<T>& out_nkhw,
unsigned nrepeat)
{
constexpr auto I0 = Number<0>{};
constexpr auto I1 = Number<1>{};
constexpr auto I2 = Number<2>{};
constexpr auto I3 = Number<3>{};
constexpr auto in_nchw_desc = InDesc{};
constexpr auto wei_kcsr_desc = WeiDesc{};
constexpr auto out_nkhw_desc = OutDesc{};
constexpr unsigned Hi = in_nchw_desc.GetLength(I2);
constexpr unsigned Wi = in_nchw_desc.GetLength(I3);
constexpr unsigned N = out_nkhw_desc.GetLength(I0);
constexpr unsigned Ho = out_nkhw_desc.GetLength(I2);
constexpr unsigned Wo = out_nkhw_desc.GetLength(I3);
constexpr unsigned K = wei_kcsr_desc.GetLength(I0);
constexpr unsigned C = wei_kcsr_desc.GetLength(I1);
constexpr unsigned Y = wei_kcsr_desc.GetLength(I2);
constexpr unsigned X = wei_kcsr_desc.GetLength(I3);
// reorder weight
auto wei_csrk_desc = make_ConstantTensorDescriptor(Sequence<C, Y, X, K>{});
ostream_ConstantTensorDescriptor(wei_csrk_desc, std::cout << "wei_csrk_desc: ");
Tensor<T> wei_csrk(make_TensorDescriptor(wei_csrk_desc));
auto f_reorder_kcsr2csrk = [&](auto k, auto c, auto y, auto x) {
wei_csrk(c, y, x, k) = wei_kcsr(k, c, y, x);
};
make_ParallelTensorFunctor(f_reorder_kcsr2csrk, K, C, Y, X)(
std::thread::hardware_concurrency());
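// note: the reorder functor visits every (k, c, y, x) index once; the 4d index
// space is divided across std::thread::hardware_concurrency() host threads.
// e.g. (hypothetical sizes) K=64, C=256, Y=X=3 gives 64*256*3*3 = 147456
// independent element copies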
// reorder input
auto in_chwn_desc = make_ConstantTensorDescriptor(Sequence<C, Hi, Wi, N>{});
ostream_ConstantTensorDescriptor(in_chwn_desc, std::cout << "in_chwn_desc: ");
Tensor<T> in_chwn(make_TensorDescriptor(in_chwn_desc));
auto f_reorder_nchw2chwn = [&](auto n, auto c, auto hi, auto wi) {
in_chwn(c, hi, wi, n) = in_nchw(n, c, hi, wi);
};
make_ParallelTensorFunctor(f_reorder_nchw2chwn, N, C, Hi, Wi)(
std::thread::hardware_concurrency());
// output
auto out_khwn_desc = make_ConstantTensorDescriptor(Sequence<K, Ho, Wo, N>{});
ostream_ConstantTensorDescriptor(out_khwn_desc, std::cout << "out_khwn_desc: ");
Tensor<T> out_khwn(make_TensorDescriptor(out_khwn_desc));
std::size_t data_sz = sizeof(T);
DeviceMem in_chwn_device_buf(data_sz * in_chwn.mDesc.GetElementSpace());
DeviceMem wei_csrk_device_buf(data_sz * wei_csrk.mDesc.GetElementSpace());
DeviceMem out_khwn_device_buf(data_sz * out_khwn.mDesc.GetElementSpace());
in_chwn_device_buf.ToDevice(in_chwn.mData.data());
wei_csrk_device_buf.ToDevice(wei_csrk.mData.data());
out_khwn_device_buf.ToDevice(out_khwn.mData.data());
#if 1
// for 3x3, 34x34
constexpr unsigned NPerBlock = 16;
constexpr unsigned KPerBlock = 64;
constexpr unsigned CPerBlock = 4;
constexpr unsigned HoPerBlock = 2;
constexpr unsigned WoPerBlock = 4;
constexpr unsigned NPerThread = 8;
constexpr unsigned KPerThread = 8;
constexpr unsigned HoPerThread = 1;
constexpr unsigned WoPerThread = 1;
constexpr unsigned InBlockCopy_ThreadPerDimC = 4;
constexpr unsigned InBlockCopy_ThreadPerDimH = 4;
constexpr unsigned InBlockCopy_ThreadPerDimW = 2;
constexpr unsigned InBlockCopy_ThreadPerDimN = 4;
constexpr unsigned InBlockCopyDataPerRead = 4;
constexpr unsigned WeiBlockCopyDataPerRead = 4;
constexpr unsigned GemmMPerThreadSubC = 4;
constexpr unsigned GemmNPerThreadSubC = 4;
constexpr unsigned GemmMLevel0Cluster = 4;
constexpr unsigned GemmNLevel0Cluster = 2;
constexpr unsigned GemmMLevel1Cluster = 2;
constexpr unsigned GemmNLevel1Cluster = 4;
constexpr unsigned GemmKPerThreadLoop = 1;
constexpr unsigned OutThreadCopyDataPerWrite = 2;
constexpr unsigned BlockSize = 128;
#elif 0
// for 5x5, 36x36
constexpr unsigned NPerBlock = 16;
constexpr unsigned KPerBlock = 64;
constexpr unsigned CPerBlock = 2;
constexpr unsigned HoPerBlock = 2;
constexpr unsigned WoPerBlock = 4;
constexpr unsigned NPerThread = 8;
constexpr unsigned KPerThread = 8;
constexpr unsigned HoPerThread = 1;
constexpr unsigned WoPerThread = 1;
constexpr unsigned WeiBlockCopyThreadPerDim0 = 4;
constexpr unsigned WeiBlockCopyThreadPerDim1 = 32;
constexpr unsigned InBlockCopy_ThreadPerDimC = 2;
constexpr unsigned InBlockCopy_ThreadPerDimH = 2;
constexpr unsigned InBlockCopy_ThreadPerDimW = 4;
constexpr unsigned InBlockCopy_ThreadPerDimN = 4;
constexpr unsigned InBlockCopyDataPerRead = 4;
constexpr unsigned WeiBlockCopyDataPerRead = 2;
constexpr unsigned GemmMPerThreadSubC = 4;
constexpr unsigned GemmNPerThreadSubC = 4;
constexpr unsigned GemmMLevel0Cluster = 4;
constexpr unsigned GemmNLevel0Cluster = 2;
constexpr unsigned GemmMLevel1Cluster = 2;
constexpr unsigned GemmNLevel1Cluster = 4;
constexpr unsigned GemmKPerThreadLoop = 1;
constexpr unsigned OutThreadCopyDataPerWrite = 2;
constexpr unsigned BlockSize = 128;
#elif 0
// 3x3 58x58, NKC = 64, 64, 256
constexpr unsigned NPerBlock = 16;
constexpr unsigned KPerBlock = 64;
constexpr unsigned CPerBlock = 4;
constexpr unsigned HoPerBlock = 2;
constexpr unsigned WoPerBlock = 4;
constexpr unsigned NPerThread = 4;
constexpr unsigned KPerThread = 16;
constexpr unsigned CPerThread = 1;
constexpr unsigned HoPerThread = 1;
constexpr unsigned WoPerThread = 1;
constexpr unsigned WeiBlockCopyThreadPerDim0 = 4;
constexpr unsigned WeiBlockCopyThreadPerDim1 = 32;
constexpr unsigned InBlockCopyDataPerRead = 2; // not used, yet
constexpr unsigned WeiBlockCopyDataPerRead = 4;
constexpr unsigned BlockSize = 128;
#elif 0
// 3x3 58x58, NKC = 16,256,128
constexpr unsigned NPerBlock = 8;
constexpr unsigned KPerBlock = 64;
constexpr unsigned CPerBlock = 2;
constexpr unsigned HoPerBlock = 4;
constexpr unsigned WoPerBlock = 4;
constexpr unsigned NPerThread = 4;
constexpr unsigned KPerThread = 16;
constexpr unsigned CPerThread = 1;
constexpr unsigned HoPerThread = 1;
constexpr unsigned WoPerThread = 1;
constexpr unsigned BlockSize = 128;
#elif 0
// for 7x7, 38x38
constexpr unsigned NPerBlock = 8;
constexpr unsigned KPerBlock = 64;
constexpr unsigned CPerBlock = 1;
constexpr unsigned HoPerBlock = 4;
constexpr unsigned WoPerBlock = 4;
constexpr unsigned NPerThread = 4;
constexpr unsigned KPerThread = 16;
constexpr unsigned CPerThread = 1;
constexpr unsigned HoPerThread = 1;
constexpr unsigned WoPerThread = 1;
constexpr unsigned WeiBlockCopyThreadPerDim0 = 4;
constexpr unsigned WeiBlockCopyThreadPerDim1 = 32;
constexpr unsigned InBlockCopyDataPerRead = 4; // not used, yet
constexpr unsigned WeiBlockCopyDataPerRead = 4;
constexpr unsigned BlockSize = 128;
#elif 0
// for 3x3, 56x56
constexpr unsigned NPerBlock = 32;
constexpr unsigned KPerBlock = 64;
constexpr unsigned CPerBlock = 4;
constexpr unsigned HoPerBlock = 2;
constexpr unsigned WoPerBlock = 2;
constexpr unsigned NPerThread = 4;
constexpr unsigned KPerThread = 16;
constexpr unsigned CPerThread = 1;
constexpr unsigned HoPerThread = 1;
constexpr unsigned WoPerThread = 1;
constexpr unsigned BlockSize = 128;
#elif 1
// for 1x1, 28x28
constexpr unsigned NPerBlock = 16;
constexpr unsigned KPerBlock = 128;
constexpr unsigned CPerBlock = 8;
constexpr unsigned HoPerBlock = 2;
constexpr unsigned WoPerBlock = 2;
constexpr unsigned NPerThread = 4;
constexpr unsigned KPerThread = 16;
constexpr unsigned CPerThread = 1;
constexpr unsigned HoPerThread = 1;
constexpr unsigned WoPerThread = 1;
constexpr unsigned InBlockCopy_ThreadPerDimC = 8;
constexpr unsigned InBlockCopy_ThreadPerDimH = 2;
constexpr unsigned InBlockCopy_ThreadPerDimW = 2;
constexpr unsigned InBlockCopy_ThreadPerDimN = 4;
constexpr unsigned InBlockCopyDataPerRead = 4;
constexpr unsigned WeiBlockCopyDataPerRead = 4;
constexpr unsigned GemmMPerThreadSubC = 4;
constexpr unsigned GemmNPerThreadSubC = 4;
constexpr unsigned GemmMLevel0Cluster = 4;
constexpr unsigned GemmNLevel0Cluster = 2;
constexpr unsigned GemmMLevel1Cluster = 2;
constexpr unsigned GemmNLevel1Cluster = 4;
constexpr unsigned GemmKPerThreadLoop = 1;
constexpr unsigned OutThreadCopyDataPerWrite = 2;
constexpr unsigned BlockSize = 128;
#endif
constexpr unsigned GridSize =
((N + NPerBlock - 1) / NPerBlock) * ((K + KPerBlock - 1) / KPerBlock) *
((Ho + HoPerBlock - 1) / HoPerBlock) * ((Wo + WoPerBlock - 1) / WoPerBlock);
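// one block per [KPerBlock, HoPerBlock, WoPerBlock, NPerBlock] output tile;
// e.g. (hypothetical sizes) N=64, K=64, Ho=Wo=32 with the 3x3/34x34 config above:
// GridSize = 4 * 1 * 16 * 8 = 512 blocks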
printf("%s: BlockSize %u, GridSize %u \n", __func__, BlockSize, GridSize);
for(unsigned i = 0; i < nrepeat; ++i)
{
float time = launch_kernel(
gridwise_implicit_gemm_convolution_1_chwn_csrk_khwn<GridSize,
BlockSize,
T,
decltype(in_chwn_desc),
decltype(wei_csrk_desc),
decltype(out_khwn_desc),
NPerBlock,
KPerBlock,
CPerBlock,
HoPerBlock,
WoPerBlock,
NPerThread,
KPerThread,
HoPerThread,
WoPerThread,
Sequence<InBlockCopy_ThreadPerDimC,
InBlockCopy_ThreadPerDimH,
InBlockCopy_ThreadPerDimW,
InBlockCopy_ThreadPerDimN>,
InBlockCopyDataPerRead,
WeiBlockCopyDataPerRead,
GemmMPerThreadSubC,
GemmNPerThreadSubC,
GemmMLevel0Cluster,
GemmNLevel0Cluster,
GemmMLevel1Cluster,
GemmNLevel1Cluster,
GemmKPerThreadLoop,
OutThreadCopyDataPerWrite>,
dim3(GridSize),
dim3(BlockSize),
static_cast<T*>(in_chwn_device_buf.GetDeviceBuffer()),
static_cast<T*>(wei_csrk_device_buf.GetDeviceBuffer()),
static_cast<T*>(out_khwn_device_buf.GetDeviceBuffer()));
printf("Elapsed time : %f ms\n", time);
usleep(std::min(time * 1000, float(10000)));
}
out_khwn_device_buf.FromDevice(out_khwn.mData.data());
// reorder output
auto f_reorder_khwn2nkhw = [&](auto k, auto ho, auto wo, auto n) {
out_nkhw(n, k, ho, wo) = out_khwn(k, ho, wo, n);
};
make_ParallelTensorFunctor(f_reorder_khwn2nkhw, K, Ho, Wo, N)(
std::thread::hardware_concurrency());
}
#pragma once
#include <unistd.h>
#include "device.hpp"
#include "gridwise_implicit_gemm_convolution_1_chwn_csrk_khwn_padded.hip.hpp"
template <class T, class InDesc, class WeiDesc, class OutDesc, class LowerPads, class UpperPads>
void device_implicit_gemm_convolution_1_chwn_csrk_khwn_padded(InDesc,
const Tensor<T>& in_nchw,
WeiDesc,
const Tensor<T>& wei_kcsr,
OutDesc,
Tensor<T>& out_nkhw,
LowerPads,
UpperPads,
unsigned nrepeat)
{
constexpr auto I0 = Number<0>{};
constexpr auto I1 = Number<1>{};
constexpr auto I2 = Number<2>{};
constexpr auto I3 = Number<3>{};
constexpr auto in_nchw_desc = InDesc{};
constexpr auto wei_kcsr_desc = WeiDesc{};
constexpr auto out_nkhw_desc = OutDesc{};
constexpr unsigned Hi = in_nchw_desc.GetLength(I2);
constexpr unsigned Wi = in_nchw_desc.GetLength(I3);
constexpr unsigned N = out_nkhw_desc.GetLength(I0);
constexpr unsigned Ho = out_nkhw_desc.GetLength(I2);
constexpr unsigned Wo = out_nkhw_desc.GetLength(I3);
constexpr unsigned K = wei_kcsr_desc.GetLength(I0);
constexpr unsigned C = wei_kcsr_desc.GetLength(I1);
constexpr unsigned Y = wei_kcsr_desc.GetLength(I2);
constexpr unsigned X = wei_kcsr_desc.GetLength(I3);
// reorder weight
auto wei_csrk_desc = make_ConstantTensorDescriptor(Sequence<C, Y, X, K>{});
ostream_ConstantTensorDescriptor(wei_csrk_desc, std::cout << "wei_csrk_desc: ");
Tensor<T> wei_csrk(make_TensorDescriptor(wei_csrk_desc));
auto f_reorder_kcsr2csrk = [&](auto k, auto c, auto y, auto x) {
wei_csrk(c, y, x, k) = wei_kcsr(k, c, y, x);
};
make_ParallelTensorFunctor(f_reorder_kcsr2csrk, K, C, Y, X)(
std::thread::hardware_concurrency());
// reorder input
auto in_chwn_desc = make_ConstantTensorDescriptor(Sequence<C, Hi, Wi, N>{});
ostream_ConstantTensorDescriptor(in_chwn_desc, std::cout << "in_chwn_desc: ");
Tensor<T> in_chwn(make_TensorDescriptor(in_chwn_desc));
auto f_reorder_nchw2chwn = [&](auto n, auto c, auto hi, auto wi) {
in_chwn(c, hi, wi, n) = in_nchw(n, c, hi, wi);
};
make_ParallelTensorFunctor(f_reorder_nchw2chwn, N, C, Hi, Wi)(
std::thread::hardware_concurrency());
// output
auto out_khwn_desc = make_ConstantTensorDescriptor(Sequence<K, Ho, Wo, N>{});
ostream_ConstantTensorDescriptor(out_khwn_desc, std::cout << "out_khwn_desc: ");
Tensor<T> out_khwn(make_TensorDescriptor(out_khwn_desc));
std::size_t data_sz = sizeof(T);
DeviceMem in_chwn_device_buf(data_sz * in_chwn.mDesc.GetElementSpace());
DeviceMem wei_csrk_device_buf(data_sz * wei_csrk.mDesc.GetElementSpace());
DeviceMem out_khwn_device_buf(data_sz * out_khwn.mDesc.GetElementSpace());
in_chwn_device_buf.ToDevice(in_chwn.mData.data());
wei_csrk_device_buf.ToDevice(wei_csrk.mData.data());
out_khwn_device_buf.ToDevice(out_khwn.mData.data());
#if 0
constexpr unsigned NPerBlock = 1;
constexpr unsigned KPerBlock = 1;
constexpr unsigned CPerBlock = 1;
constexpr unsigned HoPerBlock = 2;
constexpr unsigned WoPerBlock = 4;
constexpr unsigned NPerThread = 1;
constexpr unsigned KPerThread = 1;
constexpr unsigned CPerThread = 1;
constexpr unsigned HoPerThread = 1;
constexpr unsigned WoPerThread = 1;
constexpr unsigned WeiBlockCopyThreadPerDim0 = 1;
constexpr unsigned WeiBlockCopyThreadPerDim1 = 1;
constexpr unsigned BlockSize = 8;
#elif 1
// for 3x3, 34x34 | 3x3 58x58, NKC = 64, 64, 256
constexpr unsigned NPerBlock = 16;
constexpr unsigned KPerBlock = 64;
constexpr unsigned CPerBlock = 4;
constexpr unsigned HoPerBlock = 2;
constexpr unsigned WoPerBlock = 4;
constexpr unsigned NPerThread = 4;
constexpr unsigned KPerThread = 16;
constexpr unsigned CPerThread = 1;
constexpr unsigned HoPerThread = 1;
constexpr unsigned WoPerThread = 1;
constexpr unsigned WeiBlockCopyThreadPerDim0 = 4;
constexpr unsigned WeiBlockCopyThreadPerDim1 = 32;
constexpr unsigned BlockSize = 128;
#elif 0
// 3x3 58x58, NKC = 16,256,128
constexpr unsigned NPerBlock = 8;
constexpr unsigned KPerBlock = 64;
constexpr unsigned CPerBlock = 2;
constexpr unsigned HoPerBlock = 4;
constexpr unsigned WoPerBlock = 4;
constexpr unsigned NPerThread = 4;
constexpr unsigned KPerThread = 16;
constexpr unsigned CPerThread = 1;
constexpr unsigned HoPerThread = 1;
constexpr unsigned WoPerThread = 1;
constexpr unsigned BlockSize = 128;
#elif 0
// for 5x5, 36x36
constexpr unsigned NPerBlock = 16;
constexpr unsigned KPerBlock = 64;
constexpr unsigned CPerBlock = 2;
constexpr unsigned HoPerBlock = 2;
constexpr unsigned WoPerBlock = 4;
constexpr unsigned NPerThread = 4;
constexpr unsigned KPerThread = 16;
constexpr unsigned CPerThread = 1;
constexpr unsigned HoPerThread = 1;
constexpr unsigned WoPerThread = 1;
constexpr unsigned BlockSize = 128;
#elif 0
// for 7x7, 38x38
constexpr unsigned NPerBlock = 8;
constexpr unsigned KPerBlock = 64;
constexpr unsigned CPerBlock = 2;
constexpr unsigned HoPerBlock = 4;
constexpr unsigned WoPerBlock = 4;
constexpr unsigned NPerThread = 4;
constexpr unsigned KPerThread = 16;
constexpr unsigned CPerThread = 1;
constexpr unsigned HoPerThread = 1;
constexpr unsigned WoPerThread = 1;
constexpr unsigned BlockSize = 128;
#elif 0
// for 3x3, 56x56
constexpr unsigned NPerBlock = 32;
constexpr unsigned KPerBlock = 64;
constexpr unsigned CPerBlock = 4;
constexpr unsigned HoPerBlock = 2;
constexpr unsigned WoPerBlock = 2;
constexpr unsigned NPerThread = 4;
constexpr unsigned KPerThread = 16;
constexpr unsigned CPerThread = 1;
constexpr unsigned HoPerThread = 1;
constexpr unsigned WoPerThread = 1;
constexpr unsigned BlockSize = 128;
#elif 1
// 3x3 56x56, NKC = 16,256,128, with padding
// 3x3 28x28, NKC = 16,512,256, with padding
// 3x3 20x84, NKC = 16,256,256, with padding
constexpr unsigned NPerBlock = 16;
constexpr unsigned KPerBlock = 64;
constexpr unsigned CPerBlock = 2;
constexpr unsigned HoPerBlock = 2;
constexpr unsigned WoPerBlock = 4;
constexpr unsigned NPerThread = 4;
constexpr unsigned KPerThread = 16;
constexpr unsigned CPerThread = 1;
constexpr unsigned HoPerThread = 1;
constexpr unsigned WoPerThread = 1;
constexpr unsigned WeiBlockCopyThreadPerDim0 = 2;
constexpr unsigned WeiBlockCopyThreadPerDim1 = 64;
constexpr unsigned BlockSize = 128;
#elif 0
// for 5x5 filter, 20x84 image, 1x1 padding
constexpr unsigned NPerBlock = 16;
constexpr unsigned KPerBlock = 64;
constexpr unsigned CPerBlock = 1;
constexpr unsigned HoPerBlock = 2;
constexpr unsigned WoPerBlock = 4;
constexpr unsigned NPerThread = 4;
constexpr unsigned KPerThread = 16;
constexpr unsigned CPerThread = 1;
constexpr unsigned HoPerThread = 1;
constexpr unsigned WoPerThread = 1;
constexpr unsigned BlockSize = 128;
#elif 0
// 5x5 filter, 28x28 image, 2x2 padding
constexpr unsigned NPerBlock = 16;
constexpr unsigned KPerBlock = 32;
constexpr unsigned CPerBlock = 2;
constexpr unsigned HoPerBlock = 4;
constexpr unsigned WoPerBlock = 4;
constexpr unsigned NPerThread = 4;
constexpr unsigned KPerThread = 16;
constexpr unsigned CPerThread = 1;
constexpr unsigned HoPerThread = 1;
constexpr unsigned WoPerThread = 1;
constexpr unsigned BlockSize = 128;
#elif 0
// for 1x1, 28x28
constexpr unsigned NPerBlock = 16;
constexpr unsigned KPerBlock = 128;
constexpr unsigned CPerBlock = 8;
constexpr unsigned HoPerBlock = 2;
constexpr unsigned WoPerBlock = 2;
constexpr unsigned NPerThread = 4;
constexpr unsigned KPerThread = 16;
constexpr unsigned CPerThread = 2;
constexpr unsigned HoPerThread = 1;
constexpr unsigned WoPerThread = 1;
constexpr unsigned WeiBlockCopyThreadPerDim0 = 4;
constexpr unsigned WeiBlockCopyThreadPerDim1 = 32;
constexpr unsigned BlockSize = 128;
#endif
constexpr unsigned GridSize =
((N + NPerBlock - 1) / NPerBlock) * ((K + KPerBlock - 1) / KPerBlock) *
((Ho + HoPerBlock - 1) / HoPerBlock) * ((Wo + WoPerBlock - 1) / WoPerBlock);
printf("%s: BlockSize %u, GridSize %u \n", __func__, BlockSize, GridSize);
for(unsigned i = 0; i < nrepeat; ++i)
{
float time = launch_kernel(
gridwise_implicit_gemm_convolution_1_chwn_csrk_khwn_padded<GridSize,
BlockSize,
T,
decltype(in_chwn_desc),
decltype(wei_csrk_desc),
decltype(out_khwn_desc),
LowerPads,
UpperPads,
NPerBlock,
KPerBlock,
CPerBlock,
HoPerBlock,
WoPerBlock,
NPerThread,
KPerThread,
CPerThread,
HoPerThread,
WoPerThread,
WeiBlockCopyThreadPerDim0,
WeiBlockCopyThreadPerDim1>,
dim3(GridSize),
dim3(BlockSize),
static_cast<T*>(in_chwn_device_buf.GetDeviceBuffer()),
static_cast<T*>(wei_csrk_device_buf.GetDeviceBuffer()),
static_cast<T*>(out_khwn_device_buf.GetDeviceBuffer()));
printf("Elapsed time : %f ms\n", time);
usleep(std::min(time * 1000, float(10000)));
}
out_khwn_device_buf.FromDevice(out_khwn.mData.data());
// reorder output
auto f_reorder_khwn2nkhw = [&](auto k, auto ho, auto wo, auto n) {
out_nkhw(n, k, ho, wo) = out_khwn(k, ho, wo, n);
};
make_ParallelTensorFunctor(f_reorder_khwn2nkhw, K, Ho, Wo, N)(
std::thread::hardware_concurrency());
}
#pragma once
#include <unistd.h>
#include "device.hpp"
#include "gridwise_implicit_gemm_convolution_2_chwn_csrk_khwn_lds_double_buffer.hip.hpp"
template <class T, class InDesc, class WeiDesc, class OutDesc>
void device_implicit_gemm_convolution_2_chwn_csrk_khwn(InDesc,
const Tensor<T>& in_nchw,
WeiDesc,
const Tensor<T>& wei_kcsr,
OutDesc,
Tensor<T>& out_nkhw,
unsigned nrepeat)
{
constexpr auto I0 = Number<0>{};
constexpr auto I1 = Number<1>{};
constexpr auto I2 = Number<2>{};
constexpr auto I3 = Number<3>{};
constexpr auto in_nchw_desc = InDesc{};
constexpr auto wei_kcsr_desc = WeiDesc{};
constexpr auto out_nkhw_desc = OutDesc{};
constexpr unsigned N = in_nchw_desc.GetLength(I0);
constexpr unsigned Hi = in_nchw_desc.GetLength(I2);
constexpr unsigned Wi = in_nchw_desc.GetLength(I3);
constexpr unsigned Ho = out_nkhw_desc.GetLength(I2);
constexpr unsigned Wo = out_nkhw_desc.GetLength(I3);
constexpr unsigned K = wei_kcsr_desc.GetLength(I0);
constexpr unsigned C = wei_kcsr_desc.GetLength(I1);
constexpr unsigned Y = wei_kcsr_desc.GetLength(I2);
constexpr unsigned X = wei_kcsr_desc.GetLength(I3);
constexpr unsigned BGhostRead = (Y - 1) * Wi + (X - 1);
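// sliding a YxX window over the flattened input means the last output positions
// read up to (Y-1) rows plus (X-1) columns past the in-bounds data;
// e.g. (hypothetical) Y=X=3, Wi=34 -> BGhostRead = 2*34 + 2 = 70.
// the input device buffer below reserves this much extra space (plus BPerBlock)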
// convert in_nchw to in_chwn
auto in_chwn_desc = make_ConstantTensorDescriptor(Sequence<C, Hi, Wi, N>{});
ostream_ConstantTensorDescriptor(in_chwn_desc, std::cout << "in_chwn_desc: ");
Tensor<T> in_chwn(make_TensorDescriptor(in_chwn_desc));
make_ParallelTensorFunctor(
[&](auto n, auto c, auto hi, auto wi) { in_chwn(c, hi, wi, n) = in_nchw(n, c, hi, wi); },
N,
C,
Hi,
Wi)(std::thread::hardware_concurrency());
// convert wei_kcsr to wei_csrk
auto wei_csrk_desc = make_ConstantTensorDescriptor(Sequence<C, Y, X, K>{});
ostream_ConstantTensorDescriptor(wei_csrk_desc, std::cout << "wei_csrk_desc: ");
Tensor<T> wei_csrk(make_TensorDescriptor(wei_csrk_desc));
make_ParallelTensorFunctor(
[&](auto k, auto c, auto y, auto x) { wei_csrk(c, y, x, k) = wei_kcsr(k, c, y, x); },
K,
C,
Y,
X)(std::thread::hardware_concurrency());
// convert out_nkhw to out_khwn
auto out_khwn_desc = make_ConstantTensorDescriptor(Sequence<K, Ho, Wo, N>{});
ostream_ConstantTensorDescriptor(out_khwn_desc, std::cout << "out_khwn_desc: ");
Tensor<T> out_khwn(make_TensorDescriptor(out_khwn_desc));
#if 0
// 3x3, 34x34
// need to use register double buffer for GEMM
constexpr unsigned BPerBlock = 128;
constexpr unsigned KPerBlock = 64;
constexpr unsigned CPerBlock = 4;
constexpr unsigned BPerThread = 8;
constexpr unsigned KPerThread = 8;
constexpr unsigned GemmMPerThreadSubC = 4;
constexpr unsigned GemmNPerThreadSubC = 4;
constexpr unsigned GemmMLevel0Cluster = 4;
constexpr unsigned GemmNLevel0Cluster = 2;
constexpr unsigned GemmMLevel1Cluster = 2;
constexpr unsigned GemmNLevel1Cluster = 8;
constexpr unsigned GemmKPerThreadLoop = 1;
constexpr unsigned GemmThreadPerColumnPerCluster = 8;
constexpr unsigned GemmThreadPerRowPerCluster = 8;
constexpr unsigned InBlockCopyThreadPerDim0 = 4;
constexpr unsigned InBlockCopyThreadPerDim1 = 16;
constexpr unsigned WeiBlockCopyThreadPerDim0 = 4;
constexpr unsigned WeiBlockCopyThreadPerDim1 = 16;
constexpr unsigned InBlockCopyDataPerRead = 4;
constexpr unsigned WeiBlockCopyDataPerRead = 4;
constexpr unsigned BlockSize = 128;
#elif 0
// 1x1, 28x28, 64 threads
constexpr unsigned BPerBlock = 64;
constexpr unsigned KPerBlock = 64;
constexpr unsigned CPerBlock = 8;
constexpr unsigned BPerThread = 8;
constexpr unsigned KPerThread = 8;
constexpr unsigned GemmMPerThreadSubC = 4;
constexpr unsigned GemmNPerThreadSubC = 4;
constexpr unsigned GemmMLevel0Cluster = 4;
constexpr unsigned GemmNLevel0Cluster = 2;
constexpr unsigned GemmMLevel1Cluster = 2;
constexpr unsigned GemmNLevel1Cluster = 4;
constexpr unsigned GemmKPerThreadLoop = 1;
constexpr unsigned GemmThreadPerColumnPerCluster = 8;
constexpr unsigned GemmThreadPerRowPerCluster = 8;
constexpr unsigned InBlockCopyThreadPerDim0 = 4;
constexpr unsigned InBlockCopyThreadPerDim1 = 16;
constexpr unsigned WeiBlockCopyThreadPerDim0 = 4;
constexpr unsigned WeiBlockCopyThreadPerDim1 = 16;
constexpr unsigned InBlockCopyDataPerRead = 4;
constexpr unsigned WeiBlockCopyDataPerRead = 4;
constexpr unsigned BlockSize = 64;
#elif 1
// 1x1, 28x28, 128 threads, no lds-double-buffer
// 1x1, 28x28, 128 threads, with lds-double-buffer, max_register = 128
constexpr unsigned BPerBlock = 64;
constexpr unsigned KPerBlock = 128;
constexpr unsigned CPerBlock = 8;
constexpr unsigned BPerThread = 8;
constexpr unsigned KPerThread = 8;
constexpr unsigned GemmMPerThreadSubC = 4;
constexpr unsigned GemmNPerThreadSubC = 4;
constexpr unsigned GemmMLevel0Cluster = 4;
constexpr unsigned GemmNLevel0Cluster = 2;
constexpr unsigned GemmMLevel1Cluster = 4;
constexpr unsigned GemmNLevel1Cluster = 4;
constexpr unsigned GemmKPerThreadLoop = 1;
constexpr unsigned GemmThreadPerColumnPerCluster = 8;
constexpr unsigned GemmThreadPerRowPerCluster = 8;
constexpr unsigned InBlockCopyThreadPerDim0 = 4;
constexpr unsigned InBlockCopyThreadPerDim1 = 16;
constexpr unsigned WeiBlockCopyThreadPerDim0 = 4;
constexpr unsigned WeiBlockCopyThreadPerDim1 = 16;
constexpr unsigned InBlockCopyDataPerRead = 4;
constexpr unsigned WeiBlockCopyDataPerRead = 4;
constexpr unsigned BlockSize = 128;
#elif 1
// 1x1, 28x28, 256 threads
constexpr unsigned BPerBlock = 128;
constexpr unsigned KPerBlock = 128;
constexpr unsigned CPerBlock = 8;
constexpr unsigned BPerThread = 8;
constexpr unsigned KPerThread = 8;
constexpr unsigned GemmMPerThreadSubC = 4;
constexpr unsigned GemmNPerThreadSubC = 4;
constexpr unsigned GemmMLevel0Cluster = 4;
constexpr unsigned GemmNLevel0Cluster = 4;
constexpr unsigned GemmMLevel1Cluster = 4;
constexpr unsigned GemmNLevel1Cluster = 4;
constexpr unsigned GemmKPerThreadLoop = 1;
constexpr unsigned GemmThreadPerColumnPerCluster = 8;
constexpr unsigned GemmThreadPerRowPerCluster = 8;
constexpr unsigned InBlockCopyThreadPerDim0 = 4;
constexpr unsigned InBlockCopyThreadPerDim1 = 16;
constexpr unsigned WeiBlockCopyThreadPerDim0 = 4;
constexpr unsigned WeiBlockCopyThreadPerDim1 = 16;
constexpr unsigned InBlockCopyDataPerRead = 4;
constexpr unsigned WeiBlockCopyDataPerRead = 4;
constexpr unsigned BlockSize = 256;
#endif
constexpr unsigned GridSize =
((N * Hi * Wi + BPerBlock - 1) / BPerBlock) * ((K + KPerBlock - 1) / KPerBlock);
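// here the GEMM "B" dimension is the flattened N*Hi*Wi; e.g. (hypothetical
// sizes) N=16, Hi=Wi=28, K=128 with BPerBlock=64, KPerBlock=128:
// GridSize = (16*28*28/64) * (128/128) = 196 blocks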
printf("%s: BlockSize %u, GridSize %u \n", __func__, BlockSize, GridSize);
// mem
std::size_t data_sz = sizeof(T);
DeviceMem in_chwn_device_buf(data_sz * (in_chwn.mDesc.GetElementSpace() + BGhostRead +
BPerBlock)); // reserve extra space for BGhostRead
DeviceMem wei_csrk_device_buf(data_sz * wei_csrk.mDesc.GetElementSpace());
DeviceMem out_khwn_device_buf(data_sz * out_khwn.mDesc.GetElementSpace());
in_chwn_device_buf.ToDevice(in_chwn.mData.data());
wei_csrk_device_buf.ToDevice(wei_csrk.mData.data());
out_khwn_device_buf.ToDevice(out_khwn.mData.data());
for(unsigned i = 0; i < nrepeat; ++i)
{
float time =
launch_kernel(gridwise_implicit_gemm_convolution_2_chwn_csrk_khwn_lds_double_buffer<
GridSize,
BlockSize,
T,
decltype(in_chwn_desc),
decltype(wei_csrk_desc),
decltype(out_khwn_desc),
BPerBlock,
KPerBlock,
CPerBlock,
BPerThread,
KPerThread,
GemmThreadPerColumnPerCluster,
GemmThreadPerRowPerCluster,
GemmMPerThreadSubC,
GemmNPerThreadSubC,
GemmMLevel0Cluster,
GemmNLevel0Cluster,
GemmMLevel1Cluster,
GemmNLevel1Cluster,
GemmKPerThreadLoop,
InBlockCopyThreadPerDim0,
InBlockCopyThreadPerDim1,
WeiBlockCopyThreadPerDim0,
WeiBlockCopyThreadPerDim1,
InBlockCopyDataPerRead,
WeiBlockCopyDataPerRead>,
dim3(GridSize),
dim3(BlockSize),
static_cast<T*>(in_chwn_device_buf.GetDeviceBuffer()),
static_cast<T*>(wei_csrk_device_buf.GetDeviceBuffer()),
static_cast<T*>(out_khwn_device_buf.GetDeviceBuffer()));
printf("Elapsed time : %f ms\n", time);
usleep(std::min(time * 1000, float(10000)));
}
out_khwn_device_buf.FromDevice(out_khwn.mData.data());
// convert out_khwn to out_nkhw
make_ParallelTensorFunctor(
[&](auto n, auto k, auto ho, auto wo) { out_nkhw(n, k, ho, wo) = out_khwn(k, ho, wo, n); },
N,
K,
Ho,
Wo)(std::thread::hardware_concurrency());
}
@@ -9,9 +9,9 @@
 #include "conv_common.hip.hpp"
 #include "device_direct_convolution_1.hpp"
 #include "device_direct_convolution_2.hpp"
-#include "device_implicit_gemm_convolution_1_chwn_csrk_khwn.hpp"
-#include "device_implicit_gemm_convolution_1_chwn_csrk_khwn_padded.hpp"
-#include "device_implicit_gemm_convolution_2_chwn_csrk_khwn.hpp"
+#include "device_implicit_gemm_convolution_1_chwn_cyxk_khwn.hpp"
+#include "device_implicit_gemm_convolution_1_chwn_cyxk_khwn_padded.hpp"
+#include "device_implicit_gemm_convolution_2_chwn_cyxk_khwn.hpp"
 
 struct GeneratorTensor_1
 {
@@ -108,7 +108,7 @@ auto make_TensorDescriptor(TConstTensorDesc)
 template <class T, class LowerPads, class UpperPads>
 void host_direct_convolution(
-    const Tensor<T>& in_nchw, const Tensor<T>& wei_kcsr, Tensor<T>& out, LowerPads, UpperPads)
+    const Tensor<T>& in_nchw, const Tensor<T>& wei_kcyx, Tensor<T>& out, LowerPads, UpperPads)
 {
     unsigned h_pad_low = LowerPads{}.Get(Number<0>{});
     unsigned w_pad_low = LowerPads{}.Get(Number<1>{});
@@ -118,18 +118,18 @@ void host_direct_convolution(
     auto f = [&](auto n, auto k, auto ho, auto wo) {
         double v = 0;
-        for(int c = 0; c < wei_kcsr.mDesc.GetLengths()[1]; ++c)
+        for(int c = 0; c < wei_kcyx.mDesc.GetLengths()[1]; ++c)
         {
-            for(int y = 0; y < wei_kcsr.mDesc.GetLengths()[2]; ++y)
+            for(int y = 0; y < wei_kcyx.mDesc.GetLengths()[2]; ++y)
             {
                 int hi = ho + y - h_pad_low;
-                for(int x = 0; x < wei_kcsr.mDesc.GetLengths()[3]; ++x)
+                for(int x = 0; x < wei_kcyx.mDesc.GetLengths()[3]; ++x)
                 {
                     int wi = wo + x - w_pad_low;
                     if(hi >= 0 && hi < in_nchw.mDesc.GetLengths()[2] && wi >= 0 &&
                        wi < in_nchw.mDesc.GetLengths()[3])
                     {
-                        v += in_nchw(n, c, hi, wi) * wei_kcsr(k, c, y, x);
+                        v += in_nchw(n, c, hi, wi) * wei_kcyx(k, c, y, x);
                     }
                 }
             }
@@ -148,7 +148,7 @@ void host_direct_convolution(
 template <class T, class LowerPads, class UpperPads>
 void host_winograd_3x3_convolution(
-    const Tensor<T>& in_nchw, const Tensor<T>& wei_kcsr, Tensor<T>& out, LowerPads, UpperPads)
+    const Tensor<T>& in_nchw, const Tensor<T>& wei_kcyx, Tensor<T>& out, LowerPads, UpperPads)
 {
     constexpr std::size_t HoPerTile = 2;
     constexpr std::size_t WoPerTile = 2;
@@ -158,9 +158,9 @@ void host_winograd_3x3_convolution(
     std::size_t HI = in_nchw.mDesc.GetLengths()[2];
     std::size_t WI = in_nchw.mDesc.GetLengths()[3];
 
-    std::size_t K = wei_kcsr.mDesc.GetLengths()[0];
-    std::size_t Y = wei_kcsr.mDesc.GetLengths()[2];
-    std::size_t X = wei_kcsr.mDesc.GetLengths()[3];
+    std::size_t K = wei_kcyx.mDesc.GetLengths()[0];
+    std::size_t Y = wei_kcyx.mDesc.GetLengths()[2];
+    std::size_t X = wei_kcyx.mDesc.GetLengths()[3];
 
     std::size_t HO = out.mDesc.GetLengths()[2];
     std::size_t WO = out.mDesc.GetLengths()[3];
@@ -259,49 +259,49 @@ void host_winograd_3x3_convolution(
     };
 
     auto f_wei_transform = [&](auto k, auto c) {
-        wei_transform(k, c, 0, 0) = wei_kcsr(k, c, 0, 0);
+        wei_transform(k, c, 0, 0) = wei_kcyx(k, c, 0, 0);
         wei_transform(k, c, 0, 1) =
-            0.5 * wei_kcsr(k, c, 0, 0) + 0.5 * wei_kcsr(k, c, 0, 1) + 0.5 * wei_kcsr(k, c, 0, 2);
+            0.5 * wei_kcyx(k, c, 0, 0) + 0.5 * wei_kcyx(k, c, 0, 1) + 0.5 * wei_kcyx(k, c, 0, 2);
         wei_transform(k, c, 0, 2) =
-            0.5 * wei_kcsr(k, c, 0, 0) - 0.5 * wei_kcsr(k, c, 0, 1) + 0.5 * wei_kcsr(k, c, 0, 2);
-        wei_transform(k, c, 0, 3) = wei_kcsr(k, c, 0, 2);
+            0.5 * wei_kcyx(k, c, 0, 0) - 0.5 * wei_kcyx(k, c, 0, 1) + 0.5 * wei_kcyx(k, c, 0, 2);
+        wei_transform(k, c, 0, 3) = wei_kcyx(k, c, 0, 2);
 
         wei_transform(k, c, 1, 0) =
-            0.5 * wei_kcsr(k, c, 0, 0) + 0.5 * wei_kcsr(k, c, 1, 0) + 0.5 * wei_kcsr(k, c, 2, 0);
-        wei_transform(k, c, 1, 1) = 0.25 * wei_kcsr(k, c, 0, 0) + 0.25 * wei_kcsr(k, c, 0, 1) +
-                                    0.25 * wei_kcsr(k, c, 0, 2) + 0.25 * wei_kcsr(k, c, 1, 0) +
-                                    0.25 * wei_kcsr(k, c, 1, 1) + 0.25 * wei_kcsr(k, c, 1, 2) +
-                                    0.25 * wei_kcsr(k, c, 2, 0) + 0.25 * wei_kcsr(k, c, 2, 1) +
-                                    0.25 * wei_kcsr(k, c, 2, 2);
-        wei_transform(k, c, 1, 2) = 0.25 * wei_kcsr(k, c, 0, 0) - 0.25 * wei_kcsr(k, c, 0, 1) +
-                                    0.25 * wei_kcsr(k, c, 0, 2) + 0.25 * wei_kcsr(k, c, 1, 0) -
-                                    0.25 * wei_kcsr(k, c, 1, 1) + 0.25 * wei_kcsr(k, c, 1, 2) +
-                                    0.25 * wei_kcsr(k, c, 2, 0) - 0.25 * wei_kcsr(k, c, 2, 1) +
-                                    0.25 * wei_kcsr(k, c, 2, 2);
+            0.5 * wei_kcyx(k, c, 0, 0) + 0.5 * wei_kcyx(k, c, 1, 0) + 0.5 * wei_kcyx(k, c, 2, 0);
+        wei_transform(k, c, 1, 1) = 0.25 * wei_kcyx(k, c, 0, 0) + 0.25 * wei_kcyx(k, c, 0, 1) +
+                                    0.25 * wei_kcyx(k, c, 0, 2) + 0.25 * wei_kcyx(k, c, 1, 0) +
+                                    0.25 * wei_kcyx(k, c, 1, 1) + 0.25 * wei_kcyx(k, c, 1, 2) +
+                                    0.25 * wei_kcyx(k, c, 2, 0) + 0.25 * wei_kcyx(k, c, 2, 1) +
+                                    0.25 * wei_kcyx(k, c, 2, 2);
+        wei_transform(k, c, 1, 2) = 0.25 * wei_kcyx(k, c, 0, 0) - 0.25 * wei_kcyx(k, c, 0, 1) +
+                                    0.25 * wei_kcyx(k, c, 0, 2) + 0.25 * wei_kcyx(k, c, 1, 0) -
+                                    0.25 * wei_kcyx(k, c, 1, 1) + 0.25 * wei_kcyx(k, c, 1, 2) +
+                                    0.25 * wei_kcyx(k, c, 2, 0) - 0.25 * wei_kcyx(k, c, 2, 1) +
+                                    0.25 * wei_kcyx(k, c, 2, 2);
         wei_transform(k, c, 1, 3) =
-            0.5 * wei_kcsr(k, c, 0, 2) + 0.5 * wei_kcsr(k, c, 1, 2) + 0.5 * wei_kcsr(k, c, 2, 2);
+            0.5 * wei_kcyx(k, c, 0, 2) + 0.5 * wei_kcyx(k, c, 1, 2) + 0.5 * wei_kcyx(k, c, 2, 2);
 
         wei_transform(k, c, 2, 0) =
-            0.5 * wei_kcsr(k, c, 0, 0) - 0.5 * wei_kcsr(k, c, 1, 0) + 0.5 * wei_kcsr(k, c, 2, 0);
-        wei_transform(k, c, 2, 1) = 0.25 * wei_kcsr(k, c, 0, 0) + 0.25 * wei_kcsr(k, c, 0, 1) +
-                                    0.25 * wei_kcsr(k, c, 0, 2) - 0.25 * wei_kcsr(k, c, 1, 0) -
-                                    0.25 * wei_kcsr(k, c, 1, 1) - 0.25 * wei_kcsr(k, c, 1, 2) +
-                                    0.25 * wei_kcsr(k, c, 2, 0) + 0.25 * wei_kcsr(k, c, 2, 1) +
-                                    0.25 * wei_kcsr(k, c, 2, 2);
-        wei_transform(k, c, 2, 2) = 0.25 * wei_kcsr(k, c, 0, 0) - 0.25 * wei_kcsr(k, c, 0, 1) +
-                                    0.25 * wei_kcsr(k, c, 0, 2) - 0.25 * wei_kcsr(k, c, 1, 0) +
-                                    0.25 * wei_kcsr(k, c, 1, 1) - 0.25 * wei_kcsr(k, c, 1, 2) +
-                                    0.25 * wei_kcsr(k, c, 2, 0) - 0.25 * wei_kcsr(k, c, 2, 1) +
-                                    0.25 * wei_kcsr(k, c, 2, 2);
+            0.5 * wei_kcyx(k, c, 0, 0) - 0.5 * wei_kcyx(k, c, 1, 0) + 0.5 * wei_kcyx(k, c, 2, 0);
+        wei_transform(k, c, 2, 1) = 0.25 * wei_kcyx(k, c, 0, 0) + 0.25 * wei_kcyx(k, c, 0, 1) +
+                                    0.25 * wei_kcyx(k, c, 0, 2) - 0.25 * wei_kcyx(k, c, 1, 0) -
+                                    0.25 * wei_kcyx(k, c, 1, 1) - 0.25 * wei_kcyx(k, c, 1, 2) +
+                                    0.25 * wei_kcyx(k, c, 2, 0) + 0.25 * wei_kcyx(k, c, 2, 1) +
+                                    0.25 * wei_kcyx(k, c, 2, 2);
+        wei_transform(k, c, 2, 2) = 0.25 * wei_kcyx(k, c, 0, 0) - 0.25 * wei_kcyx(k, c, 0, 1) +
+                                    0.25 * wei_kcyx(k, c, 0, 2) - 0.25 * wei_kcyx(k, c, 1, 0) +
+                                    0.25 * wei_kcyx(k, c, 1, 1) - 0.25 * wei_kcyx(k, c, 1, 2) +
+                                    0.25 * wei_kcyx(k, c, 2, 0) - 0.25 * wei_kcyx(k, c, 2, 1) +
+                                    0.25 * wei_kcyx(k, c, 2, 2);
         wei_transform(k, c, 2, 3) =
-            0.5 * wei_kcsr(k, c, 0, 2) - 0.5 * wei_kcsr(k, c, 1, 2) + 0.5 * wei_kcsr(k, c, 2, 2);
-        wei_transform(k, c, 3, 0) = wei_kcsr(k, c, 2, 0);
+            0.5 * wei_kcyx(k, c, 0, 2) - 0.5 * wei_kcyx(k, c, 1, 2) + 0.5 * wei_kcyx(k, c, 2, 2);
+        wei_transform(k, c, 3, 0) = wei_kcyx(k, c, 2, 0);
         wei_transform(k, c, 3, 1) =
-            0.5 * wei_kcsr(k, c, 2, 0) + 0.5 * wei_kcsr(k, c, 2, 1) + 0.5 * wei_kcsr(k, c, 2, 2);
+            0.5 * wei_kcyx(k, c, 2, 0) + 0.5 * wei_kcyx(k, c, 2, 1) + 0.5 * wei_kcyx(k, c, 2, 2);
         wei_transform(k, c, 3, 2) =
-            0.5 * wei_kcsr(k, c, 2, 0) - 0.5 * wei_kcsr(k, c, 2, 1) + 0.5 * wei_kcsr(k, c, 2, 2);
-        wei_transform(k, c, 3, 3) = wei_kcsr(k, c, 2, 2);
+            0.5 * wei_kcyx(k, c, 2, 0) - 0.5 * wei_kcyx(k, c, 2, 1) + 0.5 * wei_kcyx(k, c, 2, 2);
+        wei_transform(k, c, 3, 3) = wei_kcyx(k, c, 2, 2);
     };
 
     auto f_out_transform = [&](auto n, auto k, auto htile, auto wtile) {
@@ -569,16 +569,16 @@ int main(int argc, char* argv[])
     auto upper_pads = Sequence<HPad, WPad>{};
 
     auto in_nchw_desc = make_ConstantTensorDescriptor(Sequence<N, C, HI, WI>{});
-    auto wei_kcsr_desc = make_ConstantTensorDescriptor(Sequence<K, C, Y, X>{});
+    auto wei_kcyx_desc = make_ConstantTensorDescriptor(Sequence<K, C, Y, X>{});
     auto out_nkhw_desc = get_convolution_with_padding_output_default_4d_tensor_descriptor(
-        in_nchw_desc, wei_kcsr_desc, lower_pads, upper_pads);
+        in_nchw_desc, wei_kcyx_desc, lower_pads, upper_pads);
 
     ostream_ConstantTensorDescriptor(in_nchw_desc, std::cout << "in_nchw_desc: ");
-    ostream_ConstantTensorDescriptor(wei_kcsr_desc, std::cout << "wei_kcsr_desc: ");
+    ostream_ConstantTensorDescriptor(wei_kcyx_desc, std::cout << "wei_kcyx_desc: ");
    ostream_ConstantTensorDescriptor(out_nkhw_desc, std::cout << "out_nkhw_desc: ");
 
     Tensor<float> in_nchw(make_TensorDescriptor(in_nchw_desc));
-    Tensor<float> wei_kcsr(make_TensorDescriptor(wei_kcsr_desc));
+    Tensor<float> wei_kcyx(make_TensorDescriptor(wei_kcyx_desc));
     Tensor<float> out_nkhw_host(make_TensorDescriptor(out_nkhw_desc));
     Tensor<float> out_nkhw_device(make_TensorDescriptor(out_nkhw_desc));
@@ -597,13 +597,13 @@ int main(int argc, char* argv[])
     {
 #if 0
         in_nchw.GenerateTensorValue(GeneratorTensor_1{}, num_thread);
-        wei_kcsr.GenerateTensorValue(GeneratorTensor_1{}, num_thread);
+        wei_kcyx.GenerateTensorValue(GeneratorTensor_1{}, num_thread);
 #elif 1
         in_nchw.GenerateTensorValue(GeneratorTensor_2{-5, 5}, num_thread);
-        wei_kcsr.GenerateTensorValue(GeneratorTensor_2{-5, 5}, num_thread);
+        wei_kcyx.GenerateTensorValue(GeneratorTensor_2{-5, 5}, num_thread);
 #elif 1
         in_nchw.GenerateTensorValue(GeneratorTensor_2{-2, 2}, num_thread);
-        wei_kcsr.GenerateTensorValue(GeneratorTensor_1{}, num_thread);
+        wei_kcyx.GenerateTensorValue(GeneratorTensor_1{}, num_thread);
 #endif
     }
@@ -613,17 +613,17 @@ int main(int argc, char* argv[])
 #elif 0
     device_direct_convolution_2
 #elif 1
-    device_implicit_gemm_convolution_1_chwn_csrk_khwn
+    device_implicit_gemm_convolution_1_chwn_cyxk_khwn
 #elif 0
-    device_implicit_gemm_convolution_2_chwn_csrk_khwn
+    device_implicit_gemm_convolution_2_chwn_cyxk_khwn
 #endif
-    (in_nchw_desc, in_nchw, wei_kcsr_desc, wei_kcsr, out_nkhw_desc, out_nkhw_device, nrepeat);
+    (in_nchw_desc, in_nchw, wei_kcyx_desc, wei_kcyx, out_nkhw_desc, out_nkhw_device, nrepeat);
 #elif 1
-    device_implicit_gemm_convolution_1_chwn_csrk_khwn_padded(in_nchw_desc,
+    device_implicit_gemm_convolution_1_chwn_cyxk_khwn_padded(in_nchw_desc,
                                                              in_nchw,
-                                                             wei_kcsr_desc,
-                                                             wei_kcsr,
+                                                             wei_kcyx_desc,
+                                                             wei_kcyx,
                                                              out_nkhw_desc,
                                                              out_nkhw_device,
                                                              lower_pads,
@@ -636,18 +636,18 @@ int main(int argc, char* argv[])
 #if 1
     if(Y == 3 && X == 3)
     {
-        host_winograd_3x3_convolution(in_nchw, wei_kcsr, out_nkhw_host, lower_pads, upper_pads);
+        host_winograd_3x3_convolution(in_nchw, wei_kcyx, out_nkhw_host, lower_pads, upper_pads);
     }
     else
    {
-        host_direct_convolution(in_nchw, wei_kcsr, out_nkhw_host, lower_pads, upper_pads);
+        host_direct_convolution(in_nchw, wei_kcyx, out_nkhw_host, lower_pads, upper_pads);
     }
 
     check_error(out_nkhw_host, out_nkhw_device);
 #endif
 
 #if 0
     LogRange(std::cout << "in_nchw : ", in_nchw.mData, ",") << std::endl;
-    LogRange(std::cout << "wei_kcsr: ", wei_kcsr.mData, ",") << std::endl;
+    LogRange(std::cout << "wei_kcyx: ", wei_kcyx.mData, ",") << std::endl;
     LogRange(std::cout << "out_nkhw_host : ", out_nkhw_host.mData, ",") << std::endl;
     LogRange(std::cout << "out_nkhw_device: ", out_nkhw_device.mData, ",") << std::endl;
 #endif
...
#pragma once
#include "common.hip.hpp"
#include "ConstantTensorDescriptor.hip.hpp"
#include "ConstantMatrixDescriptor.hip.hpp"
#include "blockwise_4d_tensor_op.hip.hpp"
#include "blockwise_2d_tensor_op.hip.hpp"
#include "threadwise_nd_tensor_op.hip.hpp"
#include "threadwise_4d_tensor_op.hip.hpp"
#include "blockwise_gemm.hip.hpp"
template <unsigned GridSize,
unsigned BlockSize,
class Float,
class InGlobalDesc,
class WeiGlobalDesc,
class OutGlobalDesc,
unsigned NPerBlock,
unsigned KPerBlock,
unsigned CPerBlock,
unsigned HoPerBlock,
unsigned WoPerBlock,
unsigned NPerThread,
unsigned KPerThread,
unsigned HoPerThread,
unsigned WoPerThread,
class InBlockCopyThreadPerDims,
unsigned InBlockCopyDataPerRead,
unsigned WeiBlockCopyDataPerRead,
unsigned GemmMPerThreadSubC,
unsigned GemmNPerThreadSubC,
unsigned GemmMLevel0Cluster,
unsigned GemmNLevel0Cluster,
unsigned GemmMLevel1Cluster,
unsigned GemmNLevel1Cluster,
unsigned GemmKPerThreadLoop,
unsigned OutThreadCopyDataPerWrite>
__global__ void
gridwise_implicit_gemm_convolution_1_chwn_csrk_khwn(const Float* const __restrict__ p_in_global,
const Float* const __restrict__ p_wei_global,
Float* const __restrict__ p_out_global)
{
// NPerThread == NPerBlock, because the input format in LDS is [C,Hi,Wi,N]:
// for the GEMM trans([C,K]) * [C,Wo*N], one thread has to cover all of "N";
// with an LDS layout that splits N, NPerThread could differ from NPerBlock
static_assert(NPerBlock % NPerThread == 0, "wrong! NPerBlock % NPerThread !=0");
static_assert((NPerThread < NPerBlock && WoPerThread == 1) || NPerThread == NPerBlock,
"wrong!");
constexpr auto I0 = Number<0>{};
constexpr auto I1 = Number<1>{};
constexpr auto I2 = Number<2>{};
constexpr auto I3 = Number<3>{};
constexpr auto in_chwn_global_desc = InGlobalDesc{};
constexpr auto wei_csrk_global_desc = WeiGlobalDesc{};
constexpr auto out_khwn_global_desc = OutGlobalDesc{};
constexpr unsigned C = in_chwn_global_desc.GetLength(I0);
constexpr unsigned K = out_khwn_global_desc.GetLength(I0);
constexpr unsigned Ho = out_khwn_global_desc.GetLength(I1);
constexpr unsigned Wo = out_khwn_global_desc.GetLength(I2);
constexpr unsigned N = out_khwn_global_desc.GetLength(I3);
constexpr unsigned Y = wei_csrk_global_desc.GetLength(I1);
constexpr unsigned X = wei_csrk_global_desc.GetLength(I2);
constexpr unsigned HiPerBlock = HoPerBlock + Y - 1;
constexpr unsigned WiPerBlock = WoPerBlock + X - 1;
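// halo: computing HoPerBlock output rows of a YxX convolution (unit stride,
// no dilation) needs HoPerBlock + Y - 1 input rows, and likewise for columns;
// e.g. HoPerBlock=2, Y=3 -> HiPerBlock=4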
// divide block work: [K, Ho, Wo, N]
constexpr unsigned KBlockWork = (K + KPerBlock - 1) / KPerBlock;
constexpr unsigned HBlockWork = (Ho + HoPerBlock - 1) / HoPerBlock;
constexpr unsigned WBlockWork = (Wo + WoPerBlock - 1) / WoPerBlock;
constexpr unsigned NBlockWork = (N + NPerBlock - 1) / NPerBlock;
const unsigned k_block_work_id = get_block_1d_id() / (HBlockWork * WBlockWork * NBlockWork);
unsigned itmp = get_block_1d_id() - k_block_work_id * (HBlockWork * WBlockWork * NBlockWork);
const unsigned h_block_work_id = itmp / (WBlockWork * NBlockWork);
itmp -= h_block_work_id * (WBlockWork * NBlockWork);
const unsigned w_block_work_id = itmp / NBlockWork;
const unsigned n_block_work_id = itmp - w_block_work_id * NBlockWork;
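// decode example (hypothetical counts): with HBlockWork=16, WBlockWork=8,
// NBlockWork=4, block 1000 -> k_block_work_id = 1000/512 = 1, itmp = 488,
// h_block_work_id = 488/32 = 15, itmp = 8, w_block_work_id = 2, n_block_work_id = 0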
const unsigned k_block_data_begin = k_block_work_id * KPerBlock;
const unsigned ho_block_data_begin = h_block_work_id * HoPerBlock;
const unsigned wo_block_data_begin = w_block_work_id * WoPerBlock;
const unsigned n_block_data_begin = n_block_work_id * NPerBlock;
const unsigned hi_block_data_begin = ho_block_data_begin;
const unsigned wi_block_data_begin = wo_block_data_begin;
// flattened (2d) tensor view of the weight in global mem
constexpr auto wei_ek_global_desc = make_ConstantTensorDescriptor(Sequence<C * Y * X, K>{});
// tensor view of blockwise input and weight in LDS
// be careful of alignment
constexpr auto in_chwn_block_desc = make_ConstantTensorDescriptor_aligned(
Sequence<CPerBlock, HiPerBlock, WiPerBlock, NPerBlock>{}, Number<InBlockCopyDataPerRead>{});
constexpr auto wei_ek_block_desc = make_ConstantTensorDescriptor_aligned(
Sequence<CPerBlock * Y * X, KPerBlock>{}, Number<WeiBlockCopyDataPerRead>{});
constexpr auto wei_csrk_block_desc = make_ConstantTensorDescriptor_aligned(
Sequence<CPerBlock, Y, X, KPerBlock>{}, Number<WeiBlockCopyDataPerRead>{});
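// the [C,Y,X,K] weight tile is equivalently a flattened [E,K] matrix with
// E = CPerBlock*Y*X, so the blockwise copy below can run as a single 2d copy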
// tensor view of threadwise output in register
constexpr auto out_khwn_thread_desc =
make_ConstantTensorDescriptor(Sequence<KPerThread, HoPerThread, WoPerThread, NPerThread>{});
// blockwise copy
// input: format is [C, Hi, Wi, N]
const auto blockwise_in_copy = Blockwise4dTensorCopy3<BlockSize,
Float,
decltype(in_chwn_global_desc),
decltype(in_chwn_block_desc),
decltype(in_chwn_block_desc.GetLengths()),
InBlockCopyThreadPerDims,
InBlockCopyDataPerRead>{};
// blockwise wei copy
// format is [CPerBlock*Y*X,KPerBlock]
const auto blockwise_wei_copy = Blockwise2dTensorCopy3<BlockSize,
Float,
decltype(wei_ek_global_desc),
decltype(wei_ek_block_desc),
decltype(wei_ek_block_desc.GetLengths()),
WeiBlockCopyDataPerRead>{};
// a series of blockwise batched GEMM
// C_matrix += transpose(A_matrix) * B_matrix
// A_matrix and B_matrix saved in LDS, C_matrix saved in register
// A_matrix[C,K] is a sub-matrix of wei_block[C,Y,X,K]
// B_matrix[C,Wo*N] is a sub-matrix of in_block[C,Hi,Wi,N]
// C_matrix[K,Wo*N] is a sub-matrix of out_block[K,Ho,Wo,N]
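// i.e. each (y, x) step of the loop below accumulates, within the block's tile,
//     out[k, ho, wo, n] += sum_c wei[c, y, x, k] * in[c, ho + y, wo + x, n]
// (unit stride and no dilation assumed)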
constexpr auto a_cxk_block_mtx_desc = make_ConstantMatrixDescriptor(
Number<CPerBlock>{}, Number<KPerBlock>{}, Number<wei_csrk_block_desc.GetStride(I0)>{});
constexpr auto b_cxwn_block_mtx_desc =
make_ConstantMatrixDescriptor(Number<CPerBlock>{},
Number<WoPerBlock * NPerBlock>{},
Number<in_chwn_block_desc.GetStride(I0)>{});
constexpr auto c_kxwn_thread_mtx_desc =
make_ConstantMatrixDescriptor(Number<KPerThread>{},
Number<WoPerThread * NPerThread>{},
Number<out_khwn_thread_desc.GetStride(I1)>{});
const auto blockwise_batch_gemm = BlockwiseBatchGemmBlockABlockBThreadCTransANormalBNormalC_V2<
BlockSize,
decltype(a_cxk_block_mtx_desc),
decltype(b_cxwn_block_mtx_desc),
decltype(c_kxwn_thread_mtx_desc),
0,
in_chwn_block_desc.GetStride(I1),
out_khwn_thread_desc.GetStride(I1),
HoPerBlock,
GemmMPerThreadSubC,
GemmNPerThreadSubC,
GemmMLevel0Cluster,
GemmNLevel0Cluster,
GemmMLevel1Cluster,
GemmNLevel1Cluster,
GemmKPerThreadLoop,
HoPerThread>{};
// LDS: be careful of alignment
constexpr unsigned in_block_size =
in_chwn_block_desc.GetElementSpace(Number<InBlockCopyDataPerRead>{});
constexpr unsigned wei_block_size =
wei_csrk_block_desc.GetElementSpace(Number<WeiBlockCopyDataPerRead>{});
constexpr unsigned max_align = InBlockCopyDataPerRead > WeiBlockCopyDataPerRead
? InBlockCopyDataPerRead
: WeiBlockCopyDataPerRead;
__shared__ Float p_in_block[max_align * ((in_block_size + max_align - 1) / max_align)];
__shared__ Float p_wei_block[max_align * ((wei_block_size + max_align - 1) / max_align)];
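// both LDS sizes are rounded up to a multiple of max_align so that the
// vectorized InBlockCopyDataPerRead / WeiBlockCopyDataPerRead accesses stay
// aligned; e.g. (hypothetical) in_block_size = 1091, max_align = 4 -> 1092 floats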
// register
Float p_out_thread[out_khwn_thread_desc.GetElementSpace()];
// set threadwise output tensor to 0
threadwise_4d_tensor_set_zero(out_khwn_thread_desc, p_out_thread);
const Float* p_in_global_block_begin =
p_in_global + in_chwn_global_desc.Get1dIndex(
0, hi_block_data_begin, wi_block_data_begin, n_block_data_begin);
const Float* p_wei_global_block_begin =
p_wei_global + wei_csrk_global_desc.Get1dIndex(0, 0, 0, k_block_data_begin);
for(unsigned c_block_data_begin = 0; c_block_data_begin < C; c_block_data_begin += CPerBlock,
p_in_global_block_begin += CPerBlock * in_chwn_global_desc.GetStride(I0),
p_wei_global_block_begin += CPerBlock * wei_csrk_global_desc.GetStride(I0),
__syncthreads())
{
// input: global mem to LDS
blockwise_in_copy.Run(p_in_global_block_begin, p_in_block);
// weight: global mem to LDS
blockwise_wei_copy.Run(p_wei_global_block_begin, p_wei_block);
__syncthreads();
// a series of batched GEMM
for(unsigned y = 0; y < Y; ++y)
{
for(unsigned x = 0; x < X; ++x)
{
blockwise_batch_gemm.Run(p_wei_block + wei_csrk_block_desc.Get1dIndex(0, y, x, 0),
p_in_block + in_chwn_block_desc.Get1dIndex(0, y, x, 0),
p_out_thread,
[](auto& acc, const auto&& v) { acc += v; });
}
}
}
// output: register to global mem,
#if 0
const auto c_thread_mtx_begin =
blockwise_batch_gemm.GetBeginOfThreadMatrixC(get_thread_local_1d_id());
for(unsigned k = 0; k < out_khwn_thread_desc.GetLength(I0); ++k)
{
for(unsigned ho = 0; ho < out_khwn_thread_desc.GetLength(I1); ++ho)
{
for(unsigned wo = 0; wo < out_khwn_thread_desc.GetLength(I2); ++wo)
{
for(unsigned n = 0; n < out_khwn_thread_desc.GetLength(I3); ++n)
{
const unsigned b = out_khwn_thread_desc.Get1dIndex(0, 0, wo, n);
const auto c_thread_mtx_distance =
blockwise_batch_gemm.GetDistanceFromBeginOfThreadMatrixC(ho, k, b);
const unsigned ho_thread =
c_thread_mtx_begin.batch + c_thread_mtx_distance.batch;
const unsigned k_thread = c_thread_mtx_begin.row + c_thread_mtx_distance.row;
const unsigned b_thread = c_thread_mtx_begin.col + c_thread_mtx_distance.col;
const unsigned wo_thread = b_thread / NPerBlock;
const unsigned n_thread = b_thread % NPerBlock;
p_out_global[out_khwn_global_desc.Get1dIndex(k_block_data_begin + k_thread,
ho_block_data_begin + ho_thread,
wo_block_data_begin + wo_thread,
n_block_data_begin + n_thread)] =
p_out_thread[out_khwn_thread_desc.Get1dIndex(k, ho, wo, n)];
}
}
}
}
#elif 1
const auto c_thread_mtx_begin =
blockwise_batch_gemm.GetBeginOfThreadMatrixC(get_thread_local_1d_id());
const unsigned k_thread_data_begin = c_thread_mtx_begin.row;
const unsigned ho_thread_data_begin = c_thread_mtx_begin.batch;
const unsigned wo_thread_data_begin = c_thread_mtx_begin.col / NPerBlock;
const unsigned n_thread_data_begin = c_thread_mtx_begin.col - NPerBlock * wo_thread_data_begin;
// this is for the v2 GEMM
// the output is viewed as an 8d tensor
if(NPerThread < NPerBlock && WoPerThread == 1)
{
constexpr unsigned N1_ = GemmNPerThreadSubC;
constexpr unsigned W1_ = WoPerBlock / ((WoPerThread * NPerThread) / GemmNPerThreadSubC);
constexpr unsigned K2_ = GemmMPerThreadSubC;
constexpr unsigned K1_ = KPerBlock / KPerThread;
constexpr auto out_8d_global_desc = make_ConstantTensorDescriptor(
Sequence<K / (K1_ * K2_), K1_, K2_, Ho, Wo / W1_, W1_, N / N1_, N1_>{});
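// K is split as [K/(K1_*K2_), K1_, K2_], Wo as [Wo/W1_, W1_] and N as
// [N/N1_, N1_], mirroring how the blockwise GEMM scatters each thread's
// C sub-tiles over the output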
constexpr auto out_8d_thread_desc = make_ConstantTensorDescriptor(
Sequence<KPerBlock / (K1_ * K2_), 1, K2_, HoPerThread, WoPerBlock / W1_, 1, 1, N1_>{});
#if 0
if(get_thread_local_1d_id() == 0 && get_block_1d_id() == 0)
{
print_ConstantTensorDescriptor(out_khwn_thread_desc, "out_khwn_thread_desc");
print_ConstantTensorDescriptor(out_8d_thread_desc, "out_8d_thread_desc");
print_ConstantTensorDescriptor(out_khwn_global_desc, "out_khwn_global_desc");
print_ConstantTensorDescriptor(out_8d_global_desc, "out_8d_global_desc");
}
#endif
threadwise_8d_tensor_copy(out_8d_thread_desc,
p_out_thread,
out_8d_global_desc,
p_out_global + out_khwn_global_desc.Get1dIndex(
k_block_data_begin + k_thread_data_begin,
ho_block_data_begin + ho_thread_data_begin,
wo_block_data_begin + wo_thread_data_begin,
n_block_data_begin + n_thread_data_begin),
out_8d_thread_desc.GetLengths(),
Number<OutThreadCopyDataPerWrite>{});
}
else if(NPerThread == NPerBlock)
{
// not implemented yet
assert(false);
}
else
{
assert(false);
}
#endif
}
#pragma once
#include "common.hip.hpp"
#include "ConstantTensorDescriptor.hip.hpp"
#include "ConstantMatrixDescriptor.hip.hpp"
#include "blockwise_4d_tensor_op.hip.hpp"
#include "blockwise_2d_tensor_op.hip.hpp"
#include "threadwise_4d_tensor_op.hip.hpp"
#include "blockwise_gemm.hip.hpp"
template <unsigned GridSize,
unsigned BlockSize,
class Float,
class InGlobalDesc,
class WeiGlobalDesc,
class OutGlobalDesc,
class LowerPads,
class UpperPads,
unsigned NPerBlock,
unsigned KPerBlock,
unsigned CPerBlock,
unsigned HoPerBlock,
unsigned WoPerBlock,
unsigned NPerThread,
unsigned KPerThread,
unsigned CPerThread,
unsigned HoPerThread,
unsigned WoPerThread,
unsigned WeiBlockCopyThreadPerDim0,
unsigned WeiBlockCopyThreadPerDim1>
__global__ void gridwise_implicit_gemm_convolution_1_chwn_csrk_khwn_padded(
const Float* const __restrict__ p_in_global,
const Float* const __restrict__ p_wei_global,
Float* const __restrict__ p_out_global)
{
// NPerThread == NPerBlock, because the input format in LDS is [C,Hi,Wi,N]:
// for the GEMM trans([C,K]) * [C,Wo*N], one thread has to cover all of "N";
// with an LDS layout that splits N, NPerThread could differ from NPerBlock
static_assert(NPerBlock % NPerThread == 0, "wrong! NPerBlock % NPerThread !=0");
static_assert((NPerThread < NPerBlock && WoPerThread == 1) || NPerThread == NPerBlock,
"wrong!");
constexpr auto I0 = Number<0>{};
constexpr auto I1 = Number<1>{};
constexpr auto I2 = Number<2>{};
constexpr auto I3 = Number<3>{};
constexpr auto in_chwn_global_desc = InGlobalDesc{};
constexpr auto wei_csrk_global_desc = WeiGlobalDesc{};
constexpr auto out_khwn_global_desc = OutGlobalDesc{};
constexpr unsigned C = in_chwn_global_desc.GetLength(I0);
constexpr unsigned K = out_khwn_global_desc.GetLength(I0);
constexpr unsigned Ho = out_khwn_global_desc.GetLength(I1);
constexpr unsigned Wo = out_khwn_global_desc.GetLength(I2);
constexpr unsigned N = out_khwn_global_desc.GetLength(I3);
constexpr unsigned Y = wei_csrk_global_desc.GetLength(I1);
constexpr unsigned X = wei_csrk_global_desc.GetLength(I2);
constexpr unsigned HPadLow = LowerPads{}.Get(I0);
constexpr unsigned WPadLow = LowerPads{}.Get(I1);
constexpr unsigned HPadUp = UpperPads{}.Get(I0);
constexpr unsigned WPadUp = UpperPads{}.Get(I1);
constexpr unsigned HiPerBlock = HoPerBlock + Y - 1;
constexpr unsigned WiPerBlock = WoPerBlock + X - 1;
// divide block work: [K, Ho, Wo, N]
constexpr unsigned KBlockWork = (K + KPerBlock - 1) / KPerBlock;
constexpr unsigned HBlockWork = (Ho + HoPerBlock - 1) / HoPerBlock;
constexpr unsigned WBlockWork = (Wo + WoPerBlock - 1) / WoPerBlock;
constexpr unsigned NBlockWork = (N + NPerBlock - 1) / NPerBlock;
const unsigned k_block_work_id = get_block_1d_id() / (HBlockWork * WBlockWork * NBlockWork);
unsigned itmp = get_block_1d_id() - k_block_work_id * (HBlockWork * WBlockWork * NBlockWork);
const unsigned h_block_work_id = itmp / (WBlockWork * NBlockWork);
itmp -= h_block_work_id * (WBlockWork * NBlockWork);
const unsigned w_block_work_id = itmp / NBlockWork;
const unsigned n_block_work_id = itmp - w_block_work_id * NBlockWork;
const unsigned k_block_data_begin = k_block_work_id * KPerBlock;
const unsigned ho_block_data_begin = h_block_work_id * HoPerBlock;
const unsigned wo_block_data_begin = w_block_work_id * WoPerBlock;
const unsigned n_block_data_begin = n_block_work_id * NPerBlock;
// flattened (2d) tensor view of wei in global mem
constexpr auto wei_ek_global_desc = make_ConstantTensorDescriptor(Sequence<C * Y * X, K>{});
// tensor view of blockwise input and weight in LDS
constexpr auto in_chwn_block_desc =
make_ConstantTensorDescriptor(Sequence<CPerBlock, HiPerBlock, WiPerBlock, NPerBlock>{});
constexpr auto wei_csrk_block_desc =
make_ConstantTensorDescriptor(Sequence<CPerBlock, Y, X, KPerBlock>{});
// flattened (2d) tensor view of wei in LDS
constexpr auto wei_ek_block_desc =
make_ConstantTensorDescriptor(Sequence<CPerBlock * Y * X, KPerBlock>{});
// tensor view of threadwise output in register
constexpr auto out_hkwn_thread_desc =
make_ConstantTensorDescriptor(Sequence<HoPerThread, KPerThread, WoPerThread, NPerThread>{});
#if 0
if(get_thread_local_1d_id() == 0 && get_block_1d_id() == 0)
{
print_ConstantTensorDescriptor(in_chwn_block_desc, "in_chwn_block_desc");
print_ConstantTensorDescriptor(wei_csrk_block_desc, "wei_csrk_block_desc");
print_ConstantTensorDescriptor(out_hkwn_thread_desc, "out_hkwn_thread_desc");
}
#endif
// blockwise copy
// input: format is [C, Hi, Wi, N]
const unsigned h_block_pad_low = h_block_work_id == 0 ? HPadLow : 0;
const unsigned w_block_pad_low = w_block_work_id == 0 ? WPadLow : 0;
const unsigned h_block_pad_up = h_block_work_id == HBlockWork - 1 ? HPadUp : 0;
const unsigned w_block_pad_up = w_block_work_id == WBlockWork - 1 ? WPadUp : 0;
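// only boundary blocks see the pads: the first block along H/W applies the
// low pad, the last one the up pad; interior blocks copy without padding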
#if 0
if(get_thread_local_1d_id() == 0)
{
printf(
"%u %u, h_block_pad_low %u w_block_pad_low %u h_block_pad_up %u w_block_pad_up %u\n",
get_block_1d_id(),
get_thread_local_1d_id(),
h_block_pad_low,
w_block_pad_low,
h_block_pad_up,
w_block_pad_up);
}
#endif
constexpr auto blockwise_in_copy =
BlockwiseChwnTensorCopyPadded<BlockSize,
Float,
decltype(in_chwn_global_desc),
decltype(in_chwn_block_desc),
decltype(in_chwn_block_desc.GetLengths()),
LowerPads>{};
#if 0
// weight: format is [C,Y,X,K]
constexpr auto blockwise_wei_copy =
Blockwise4dTensorCopy1<BlockSize,
Float,
decltype(wei_csrk_global_desc),
decltype(wei_csrk_block_desc),
decltype(wei_csrk_block_desc.GetLengths())>{};
#elif 0
// weight: format is [C*Y*X,K]
constexpr auto blockwise_wei_copy =
Blockwise2dTensorCopy1<BlockSize,
Float,
decltype(wei_ek_global_desc),
decltype(wei_ek_block_desc),
decltype(wei_ek_block_desc.GetLengths())>{};
#elif 1
// weight: format is [C*Y*X,K]
const auto blockwise_wei_copy = Blockwise2dTensorCopy2<BlockSize,
Float,
decltype(wei_ek_global_desc),
decltype(wei_ek_block_desc),
decltype(wei_ek_block_desc.GetLengths()),
WeiBlockCopyThreadPerDim0,
WeiBlockCopyThreadPerDim1>{};
#endif
// a series of blockwise batched GEMM
// C_matrix += transpose(A_matrix) * B_matrix
// A_matrix and B_matrix saved in LDS, C_matrix saved in register
// A_matrix[C,K] is a sub-matrix of wei_block[C,Y,X,K]
// B_matrix[C,Wo*N] is a sub-matrix of in_block[C,Hi,Wi,N]
// C_matrix[K,Wo*N] is a sub-matrix of out_block[Ho,K,Wo,N]
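    // In matrix terms, per batch (one batch per output row of the block):
    //   C[KPerBlock, WoPerBlock*NPerBlock] +=
    //       transpose(A[CPerBlock, KPerBlock]) * B[CPerBlock, WoPerBlock*NPerBlock],
    // with each thread accumulating only its KPerThread x (WoPerThread*NPerThread)
    // sub-tile of C.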
constexpr auto a_cxk_block_mtx_desc = make_ConstantMatrixDescriptor(
Number<CPerBlock>{}, Number<KPerBlock>{}, Number<wei_csrk_block_desc.GetStride(I0)>{});
constexpr auto b_cxwn_block_mtx_desc =
make_ConstantMatrixDescriptor(Number<CPerBlock>{},
Number<WoPerBlock * NPerBlock>{},
Number<in_chwn_block_desc.GetStride(I0)>{});
constexpr auto c_kxwn_thread_mtx_desc =
make_ConstantMatrixDescriptor(Number<KPerThread>{}, Number<WoPerThread * NPerThread>{});
const auto blockwise_batch_gemm =
Blockwise1dStridedBatchedGemmBlockABlockBThreadC<BlockSize,
decltype(a_cxk_block_mtx_desc),
decltype(b_cxwn_block_mtx_desc),
decltype(c_kxwn_thread_mtx_desc),
true,
false,
false,
0,
in_chwn_block_desc.GetStride(I1),
out_hkwn_thread_desc.GetStride(I0),
HoPerBlock,
HoPerThread,
CPerThread,
true>{};
// LDS
constexpr unsigned in_block_size = in_chwn_block_desc.GetElementSpace();
constexpr unsigned wei_block_size = wei_csrk_block_desc.GetElementSpace();
__shared__ Float p_in_block[in_block_size];
__shared__ Float p_wei_block[wei_block_size];
// register
Float p_out_thread[out_hkwn_thread_desc.GetElementSpace()];
// set threadwise output tensor to 0
threadwise_4d_tensor_set_zero(out_hkwn_thread_desc, p_out_thread);
const Float* p_wei_global_block_begin =
p_wei_global + wei_ek_global_desc.Get1dIndex(0, k_block_data_begin);
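    // March over the reduction (channel) dimension in CPerBlock chunks. The
    // __syncthreads() placed in the increment expression runs at the end of each
    // iteration, so the GEMM finishes reading the LDS tiles before the next
    // iteration's copies overwrite them.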
for(unsigned c_block_data_begin = 0; c_block_data_begin < C; c_block_data_begin += CPerBlock,
p_wei_global_block_begin += CPerBlock * wei_ek_global_desc.GetStride(I0),
__syncthreads())
{
#if 1
// input: global mem to LDS,
blockwise_in_copy.Run(p_in_global,
c_block_data_begin,
ho_block_data_begin,
wo_block_data_begin,
n_block_data_begin,
p_in_block,
h_block_pad_low,
w_block_pad_low,
h_block_pad_up,
w_block_pad_up);
#endif
#if 1
// weight: global mem to LDS,
blockwise_wei_copy.Run(p_wei_global_block_begin, p_wei_block);
#endif
__syncthreads();
// a series of batched GEMM
for(unsigned y = 0; y < Y; ++y)
{
for(unsigned x = 0; x < X; ++x)
{
                auto f_accum = [](auto& acc, const auto& v) { acc += v; };
blockwise_batch_gemm.Run(p_wei_block + wei_csrk_block_desc.Get1dIndex(0, y, x, 0),
p_in_block + in_chwn_block_desc.Get1dIndex(0, y, x, 0),
p_out_thread,
f_accum);
}
}
}
const auto matrix_c_index =
blockwise_batch_gemm.GetBeginOfThreadMatrixC(get_thread_local_1d_id());
const unsigned ho_thread_data_begin = matrix_c_index.batch;
const unsigned k_thread_data_begin = matrix_c_index.row;
const unsigned wo_thread_data_begin = matrix_c_index.col / NPerBlock;
const unsigned n_thread_data_begin = matrix_c_index.col - wo_thread_data_begin * NPerBlock;
#if 0
printf("block %u %u, %u %u %u %u, %u %u %u %u, %f \n",
get_block_1d_id(), get_thread_local_1d_id(),
ho_block_data_begin, k_block_data_begin, wo_block_data_begin, n_block_data_begin,
ho_thread_data_begin, k_thread_data_begin, wo_thread_data_begin, n_thread_data_begin,
p_out_thread[0]);
#endif
// output: register to global mem,
// convert out_thread[Ho,K,Wo,N] to out_global[K,Ho,Wo,N]
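    // Sequence<1, 0, 2, 3> maps dst dim d to src dim seq[d]: out_global's K (dim 0)
    // comes from out_thread's dim 1, Ho (dim 1) from dim 0; Wo and N are unchanged.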
constexpr auto reorder_khwn_from_hkwn = Sequence<1, 0, 2, 3>{};
threadwise_4d_tensor_copy_reorder_by_get_dst_from_src(
out_hkwn_thread_desc,
p_out_thread,
out_khwn_global_desc,
p_out_global + out_khwn_global_desc.Get1dIndex(k_block_data_begin + k_thread_data_begin,
ho_block_data_begin + ho_thread_data_begin,
wo_block_data_begin + wo_thread_data_begin,
n_block_data_begin + n_thread_data_begin),
out_hkwn_thread_desc.GetLengths(),
reorder_khwn_from_hkwn);
}
#pragma once
#include "common.hip.hpp"
#include "ConstantTensorDescriptor.hip.hpp"
#include "ConstantMatrixDescriptor.hip.hpp"
#include "blockwise_4d_tensor_op.hip.hpp"
#include "blockwise_2d_tensor_op.hip.hpp"
#include "threadwise_2d_tensor_op.hip.hpp"
#include "blockwise_gemm.hip.hpp"
// define B = flatten(Hi, Wi, N) of the CHWN input (N is the fastest-moving dimension)
template <unsigned GridSize,
unsigned BlockSize,
class Float,
class InGlobalDesc,
class WeiGlobalDesc,
class OutGlobalDesc,
unsigned BPerBlock,
unsigned KPerBlock,
unsigned CPerBlock,
unsigned BPerThread,
unsigned KPerThread,
unsigned GemmThreadPerColumnPerCluster,
unsigned GemmThreadPerRowPerCluster,
unsigned GemmMPerThreadSubC,
unsigned GemmNPerThreadSubC,
unsigned GemmMLevel0Cluster,
unsigned GemmNLevel0Cluster,
unsigned GemmMLevel1Cluster,
unsigned GemmNLevel1Cluster,
unsigned GemmKPerThreadLoop,
unsigned InBlockCopyThreadPerDim0,
unsigned InBlockCopyThreadPerDim1,
unsigned WeiBlockCopyThreadPerDim0,
unsigned WeiBlockCopyThreadPerDim1,
unsigned InBlockCopyDataPerRead,
unsigned WeiBlockCopyDataPerRead>
__global__ void gridwise_implicit_gemm_convolution_2_chwn_csrk_khwn_lds_double_buffer(
const Float* const __restrict__ p_in_global,
const Float* const __restrict__ p_wei_global,
Float* const __restrict__ p_out_global)
{
constexpr auto I0 = Number<0>{};
constexpr auto I1 = Number<1>{};
constexpr auto I2 = Number<2>{};
constexpr auto I3 = Number<3>{};
constexpr auto in_chwn_global_desc = InGlobalDesc{};
constexpr auto wei_csrk_global_desc = WeiGlobalDesc{};
constexpr auto out_khwn_global_desc = OutGlobalDesc{};
constexpr unsigned C = in_chwn_global_desc.GetLength(I0);
constexpr unsigned Hi = in_chwn_global_desc.GetLength(I1);
constexpr unsigned Wi = in_chwn_global_desc.GetLength(I2);
constexpr unsigned N = in_chwn_global_desc.GetLength(I3);
constexpr unsigned K = out_khwn_global_desc.GetLength(I0);
constexpr unsigned Ho = out_khwn_global_desc.GetLength(I1);
constexpr unsigned Wo = out_khwn_global_desc.GetLength(I2);
constexpr unsigned Y = wei_csrk_global_desc.GetLength(I1);
constexpr unsigned X = wei_csrk_global_desc.GetLength(I2);
constexpr unsigned B = N * Hi * Wi;
constexpr unsigned BGhostRead = (Y - 1) * Wi + (X - 1);
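    // BGhostRead is the over-read past BPerBlock needed so every (y, x) filter tap
    // can slide within the block's tile: Y - 1 extra rows of length Wi plus X - 1
    // extra columns, e.g. Y = X = 3, Wi = 34 (illustrative sizes) gives
    // 2 * 34 + 2 = 70 extra elements along b.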
// divide block work by 2d: [K, B]
constexpr unsigned KBlockWork = (K + KPerBlock - 1) / KPerBlock;
constexpr unsigned BBlockWork = (B + BPerBlock - 1) / BPerBlock;
const unsigned k_block_work_id = get_block_1d_id() / BBlockWork;
const unsigned b_block_work_id = get_block_1d_id() - k_block_work_id * BBlockWork;
const unsigned k_block_data_begin = k_block_work_id * KPerBlock;
const unsigned b_block_data_begin = b_block_work_id * BPerBlock;
    // flattened (2d) tensor views of gridwise input and weight
constexpr auto in_cb_global_desc = make_ConstantTensorDescriptor(Sequence<C, B>{});
constexpr auto wei_ek_global_desc = make_ConstantTensorDescriptor(Sequence<C * Y * X, K>{});
// tensor view of blockwise input and weight
// be careful of alignment
constexpr auto in_cb_block_desc = make_ConstantTensorDescriptor_aligned(
Sequence<CPerBlock, BPerBlock + BGhostRead>{}, Number<InBlockCopyDataPerRead>{});
constexpr auto wei_ek_block_desc = make_ConstantTensorDescriptor_aligned(
Sequence<CPerBlock * Y * X, KPerBlock>{}, Number<WeiBlockCopyDataPerRead>{});
constexpr auto wei_csrk_block_desc = make_ConstantTensorDescriptor_aligned(
Sequence<CPerBlock, Y, X, KPerBlock>{}, Number<WeiBlockCopyDataPerRead>{});
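    // The _aligned descriptors round the row stride up to a multiple of the given
    // vector width, so InBlockCopyDataPerRead/WeiBlockCopyDataPerRead-wide accesses
    // start on aligned addresses.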
// tensor view of threadwise output in register
constexpr auto out_kb_thread_desc =
make_ConstantTensorDescriptor(Sequence<KPerThread, BPerThread>{});
#if 0
if(get_thread_local_1d_id() == 0 && get_block_1d_id() == 0)
{
print_ConstantTensorDescriptor(in_chwn_global_desc, "in_chwn_global_desc");
print_ConstantTensorDescriptor(wei_csrk_global_desc, "wei_csrk_global_desc");
print_ConstantTensorDescriptor(out_khwn_global_desc, "out_khwn_global_desc");
print_ConstantTensorDescriptor(in_cb_global_desc, "in_cb_global_desc");
print_ConstantTensorDescriptor(wei_ek_global_desc, "wei_ek_global_desc");
print_ConstantTensorDescriptor(in_cb_block_desc, "in_cb_block_desc");
print_ConstantTensorDescriptor(wei_csrk_block_desc, "wei_csrk_block_desc");
print_ConstantTensorDescriptor(wei_ek_block_desc, "wei_ek_block_desc");
print_ConstantTensorDescriptor(out_kb_thread_desc, "out_kb_thread_desc");
printf("KPerBlock %u\n", KPerBlock);
}
#endif
// blockwise in copy
    // format is [CPerBlock, BPerBlock + BGhostRead]
#if 0
const auto blockwise_in_copy =
Blockwise2dTensorCopy1<BlockSize,
Float,
decltype(in_cb_global_desc),
decltype(in_cb_block_desc),
decltype(in_cb_block_desc.GetLengths())>{};
#elif 0
const auto blockwise_in_copy = Blockwise2dTensorCopy2<BlockSize,
Float,
decltype(in_cb_global_desc),
decltype(in_cb_block_desc),
decltype(in_cb_block_desc.GetLengths()),
InBlockCopyThreadPerDim0,
InBlockCopyThreadPerDim1>{};
#elif 1
const auto blockwise_in_copy = Blockwise2dTensorCopy3<BlockSize,
Float,
decltype(in_cb_global_desc),
decltype(in_cb_block_desc),
decltype(in_cb_block_desc.GetLengths()),
InBlockCopyDataPerRead>{};
#endif
// blockwise wei copy
// format is [CPerBlock*Y*X,KPerBlock]
#if 0
const auto blockwise_wei_copy =
Blockwise2dTensorCopy1<BlockSize,
Float,
decltype(wei_ek_global_desc),
decltype(wei_ek_block_desc),
decltype(wei_ek_block_desc.GetLengths())>{};
#elif 0
const auto blockwise_wei_copy = Blockwise2dTensorCopy2<BlockSize,
Float,
decltype(wei_ek_global_desc),
decltype(wei_ek_block_desc),
decltype(wei_ek_block_desc.GetLengths()),
WeiBlockCopyThreadPerDim0,
WeiBlockCopyThreadPerDim1>{};
#elif 1
const auto blockwise_wei_copy = Blockwise2dTensorCopy3<BlockSize,
Float,
decltype(wei_ek_global_desc),
decltype(wei_ek_block_desc),
decltype(wei_ek_block_desc.GetLengths()),
WeiBlockCopyDataPerRead>{};
#endif
// a series of blockwise GEMM
// c_mtx += transpose(a_mtx) * b_mtx
// a_mtx and b_mtx saved in LDS, c_mtx saved in register
// a_mtx[C,K] is a sub-matrix of wei_block[C,Y,X,K]
// b_mtx[C,B] is a subset of in_block[C,B + BGhostRead]
// c_mtx[K,B] is out_block[K,B]
constexpr auto a_cxk_block_mtx_desc = make_ConstantMatrixDescriptor(
Number<CPerBlock>{}, Number<KPerBlock>{}, Number<wei_csrk_block_desc.GetStride(I0)>{});
constexpr auto b_cxb_block_mtx_desc = make_ConstantMatrixDescriptor(
Number<CPerBlock>{}, Number<BPerBlock>{}, Number<in_cb_block_desc.GetStride(I0)>{});
constexpr auto c_kxb_thread_mtx_desc =
make_ConstantMatrixDescriptor(Number<KPerThread>{}, Number<BPerThread>{});
#if 0
const auto blockwise_gemm = BlockwiseGemmBlockABlockBThreadC<BlockSize,
decltype(a_cxk_block_mtx_desc),
decltype(b_cxb_block_mtx_desc),
decltype(c_kxb_thread_mtx_desc),
true,
false,
false,
GemmKPerThreadLoop,
GemmThreadPerColumnPerCluster,
GemmThreadPerRowPerCluster,
true>{};
#else
const auto blockwise_gemm =
BlockwiseGemmBlockABlockBThreadCTransANormalBNormalC_v2<BlockSize,
decltype(a_cxk_block_mtx_desc),
decltype(b_cxb_block_mtx_desc),
decltype(c_kxb_thread_mtx_desc),
GemmMPerThreadSubC,
GemmNPerThreadSubC,
GemmMLevel0Cluster,
GemmNLevel0Cluster,
GemmMLevel1Cluster,
GemmNLevel1Cluster,
GemmKPerThreadLoop>{};
#endif
// LDS: be careful of alignment
constexpr unsigned in_block_size =
in_cb_block_desc.GetElementSpace(Number<InBlockCopyDataPerRead>{});
constexpr unsigned wei_block_size =
wei_csrk_block_desc.GetElementSpace(Number<WeiBlockCopyDataPerRead>{});
constexpr unsigned max_align = InBlockCopyDataPerRead > WeiBlockCopyDataPerRead
? InBlockCopyDataPerRead
: WeiBlockCopyDataPerRead;
// LDS double buffer
__shared__ Float p_in_block_0[max_align * ((in_block_size + max_align - 1) / max_align)];
__shared__ Float p_wei_block_0[max_align * ((wei_block_size + max_align - 1) / max_align)];
__shared__ Float p_in_block_1[max_align * ((in_block_size + max_align - 1) / max_align)];
__shared__ Float p_wei_block_1[max_align * ((wei_block_size + max_align - 1) / max_align)];
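    // Ping-pong scheme: while the GEMM consumes the "now" pair of buffers, the
    // blockwise copies fill the "next" pair, overlapping the global-memory latency
    // of chunk i+1 with the math on chunk i.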
const Float* p_in_global_block_offset =
p_in_global + in_cb_global_desc.Get1dIndex(0, b_block_data_begin);
const Float* p_wei_global_block_offset =
p_wei_global + wei_csrk_global_desc.Get1dIndex(0, 0, 0, k_block_data_begin);
// preload data into LDS
blockwise_in_copy.Run(p_in_global_block_offset, p_in_block_0);
blockwise_wei_copy.Run(p_wei_global_block_offset, p_wei_block_0);
p_in_global_block_offset += CPerBlock * in_cb_global_desc.GetStride(I0);
p_wei_global_block_offset += CPerBlock * wei_csrk_global_desc.GetStride(I0);
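    // Buffer 0 now holds chunk 0; the offsets already point at chunk 1, which the
    // first loop iteration prefetches into buffer 1.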
// register
Float p_out_thread[out_kb_thread_desc.GetElementSpace()];
// set threadwise output tensor to 0
threadwise_2d_tensor_set_zero(out_kb_thread_desc, p_out_thread);
bool even_loop = true;
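    // The loop condition (c_block_data_begin + CPerBlock < C) stops one chunk
    // early: the final chunk has nothing left to prefetch and is computed
    // separately below.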
for(unsigned c_block_data_begin = 0; c_block_data_begin + CPerBlock < C;
c_block_data_begin += CPerBlock,
p_in_global_block_offset += CPerBlock * in_cb_global_desc.GetStride(I0),
p_wei_global_block_offset += CPerBlock * wei_csrk_global_desc.GetStride(I0),
even_loop = !even_loop)
{
Float* p_in_block_now = even_loop ? p_in_block_0 : p_in_block_1;
Float* p_wei_block_now = even_loop ? p_wei_block_0 : p_wei_block_1;
Float* p_in_block_next = even_loop ? p_in_block_1 : p_in_block_0;
Float* p_wei_block_next = even_loop ? p_wei_block_1 : p_wei_block_0;
__syncthreads();
// load next data
#if 1
blockwise_in_copy.Run(p_in_global_block_offset, p_in_block_next);
blockwise_wei_copy.Run(p_wei_global_block_offset, p_wei_block_next);
#elif 1
Float p_in_register_clipboard[blockwise_in_copy.GetRegisterClipboardSize()];
Float p_wei_register_clipboard[blockwise_wei_copy.GetRegisterClipboardSize()];
blockwise_in_copy.RunLoadRegisterClipboard(p_in_global_block_offset,
p_in_register_clipboard);
blockwise_wei_copy.RunLoadRegisterClipboard(p_wei_global_block_offset,
p_wei_register_clipboard);
#endif
// compute on current data
// a series of GEMM
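        // For each filter tap (y, x), A starts at wei_block[0, y, x, 0] and B is
        // the input tile shifted by y rows and x columns within the flattened b
        // dimension; the products for all Y * X taps accumulate into the same
        // per-thread C tile.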
for(unsigned y = 0; y < Y; ++y)
{
for(unsigned x = 0; x < X; ++x)
{
                auto f_accum = [](auto& acc, const auto& v) { acc += v; };
#if 1
blockwise_gemm.Run
#else
blockwise_gemm.Run_RegisterDoubleBuffer
#endif
(p_wei_block_now + wei_csrk_block_desc.Get1dIndex(0, y, x, 0),
p_in_block_now + y * Wi + x,
p_out_thread,
f_accum);
}
}
#if 0
blockwise_in_copy.RunStoreRegisterClipboard(p_in_register_clipboard, p_in_block_next);
blockwise_wei_copy.RunStoreRegisterClipboard(p_wei_register_clipboard, p_wei_block_next);
#endif
}
// last computation
{
Float* p_in_block_now = even_loop ? p_in_block_0 : p_in_block_1;
Float* p_wei_block_now = even_loop ? p_wei_block_0 : p_wei_block_1;
__syncthreads();
for(unsigned y = 0; y < Y; ++y)
{
for(unsigned x = 0; x < X; ++x)
{
                auto f_accum = [](auto& acc, const auto& v) { acc += v; };
#if 0
blockwise_gemm.Run
#else
blockwise_gemm.Run_RegisterDoubleBuffer
#endif
(p_wei_block_now + wei_csrk_block_desc.Get1dIndex(0, y, x, 0),
p_in_block_now + y * Wi + x,
p_out_thread,
f_accum);
}
}
}
// output: register to global mem,
const auto c_thread_mtx_begin =
blockwise_gemm.GetBeginOfThreadMatrixC(get_thread_local_1d_id());
const unsigned k_thread_data_begin = k_block_data_begin + c_thread_mtx_begin.row;
const unsigned b_thread_data_begin = b_block_data_begin + c_thread_mtx_begin.col;
#if 0
if(get_block_1d_id() == 0)
{
printf("%u %u, row %u col %u, k_data_begin %u b_data_begin %u, %f %f %f %f\n",
get_block_1d_id(),
get_thread_local_1d_id(),
               c_thread_mtx_begin.row,
               c_thread_mtx_begin.col,
               k_thread_data_begin,
               b_thread_data_begin,
p_out_thread[0], p_out_thread[1], p_out_thread[2], p_out_thread[3]);
}
#endif
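    // Write-back: each thread walks its KPerThread x BPerThread register tile,
    // decodes the flat b index as b = (h * Wi + w) * N + n, and skips positions
    // that landed in the ghost-read region (h >= Ho or w >= Wo).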
for(unsigned k = 0; k < out_kb_thread_desc.GetLength(I0); ++k)
{
for(unsigned b = 0; b < out_kb_thread_desc.GetLength(I1); ++b)
{
const auto c_thread_mtx_distance =
blockwise_gemm.GetDistanceFromBeginOfThreadMatrixC(k, b);
unsigned k_data = k_thread_data_begin + c_thread_mtx_distance.row;
unsigned b_data = b_thread_data_begin + c_thread_mtx_distance.col;
unsigned h_data = b_data / (Wi * N);
unsigned itmp = b_data - h_data * (Wi * N);
unsigned w_data = itmp / N;
unsigned n_data = itmp - w_data * N;
if(n_data < N && h_data < Ho && w_data < Wo)
{
p_out_global[out_khwn_global_desc.Get1dIndex(k_data, h_data, w_data, n_data)] =
p_out_thread[out_kb_thread_desc.Get1dIndex(k, b)];
}
}
}
}