Commit f9cf57d4 authored by carlushuang's avatar carlushuang
Browse files

support YXCK filter

parent 71254ddd
#include <sstream>
#include "config.hpp"
#include "device.hpp"
#include "host_tensor.hpp"
#include "host_tensor_generator.hpp"
#include "tensor_layout.hpp"
#include "device_tensor.hpp"
#include "device_convnd_fwd_avx2_nhwc_kyxc_nhwk.hpp"
#include "element_wise_operation_cpu.hpp"
#include "reference_conv_fwd.hpp"
#include "element_wise_operation_cpu.hpp"
#include "dynamic_buffer_cpu.hpp"
#include <omp.h>
#define AVX2_DATA_ALIGNMENT 32
#define TEST_FUSION_PASSTHROUGH 0
#define TEST_FUSION_RELU 1
#define TEST_FUSION TEST_FUSION_PASSTHROUGH
#define TEST_LAYOUT_NHWC_KYXC_NHWK 0
#define TEST_LAYOUT_NHWC_KYXCK8_NHWK 1
#define TEST_LAYOUT TEST_LAYOUT_NHWC_KYXCK8_NHWK
using F32 = float;
using F16 = ck::half_t;
namespace ck {
namespace tensor_operation {
namespace cpu {
namespace device {
namespace device_conv2d_fwd_avx2_instance {
using PassThrough = ck::tensor_operation::cpu::element_wise::PassThrough;
using Relu = ck::tensor_operation::cpu::element_wise::Relu;
void add_device_conv2d_fwd_avx2_nhwc_kyxc_nhwk(
std::vector<DeviceConvFwdPtr<PassThrough, PassThrough, PassThrough>>& instances);
void add_device_conv2d_fwd_avx2_nhwc_kyxc_nhwk_local_c(
std::vector<DeviceConvFwdPtr<PassThrough, PassThrough, PassThrough>>& instances);
void add_device_conv2d_fwd_avx2_nhwc_kyxc_nhwk_mt(
std::vector<DeviceConvFwdPtr<PassThrough, PassThrough, PassThrough>>& instances);
void add_device_conv2d_fwd_avx2_nhwc_kyxc_nhwk_relu(
std::vector<DeviceConvFwdPtr<PassThrough, PassThrough, Relu>>& instances);
void add_device_conv2d_fwd_avx2_nhwc_kyxc_nhwk_local_c_relu(
std::vector<DeviceConvFwdPtr<PassThrough, PassThrough, Relu>>& instances);
void add_device_conv2d_fwd_avx2_nhwc_kyxc_nhwk_mt_relu(
std::vector<DeviceConvFwdPtr<PassThrough, PassThrough, Relu>>& instances);
void add_device_conv2d_fwd_avx2_nhwc_kyxck8_nhwk(
std::vector<DeviceConvFwdPtr<PassThrough, PassThrough, PassThrough>>& instances);
void add_device_conv2d_fwd_avx2_nhwc_kyxck8_nhwk_local_c(
std::vector<DeviceConvFwdPtr<PassThrough, PassThrough, PassThrough>>& instances);
void add_device_conv2d_fwd_avx2_nhwc_kyxck8_nhwk_mt(
std::vector<DeviceConvFwdPtr<PassThrough, PassThrough, PassThrough>>& instances);
void add_device_conv2d_fwd_avx2_nhwc_kyxck8_nhwk_relu(
std::vector<DeviceConvFwdPtr<PassThrough, PassThrough, Relu>>& instances);
void add_device_conv2d_fwd_avx2_nhwc_kyxck8_nhwk_local_c_relu(
std::vector<DeviceConvFwdPtr<PassThrough, PassThrough, Relu>>& instances);
void add_device_conv2d_fwd_avx2_nhwc_kyxck8_nhwk_mt_relu(
std::vector<DeviceConvFwdPtr<PassThrough, PassThrough, Relu>>& instances);
} // namespace device_conv2d_fwd_avx2_instance
} // namespace device
} // namespace cpu
} // namespace tensor_operation
} // namespace ck
using InElementOp = ck::tensor_operation::cpu::element_wise::PassThrough;
using WeiElementOp = ck::tensor_operation::cpu::element_wise::PassThrough;
#if TEST_FUSION == TEST_FUSION_PASSTHROUGH
using OutElementOp = ck::tensor_operation::cpu::element_wise::PassThrough;
#endif
#if TEST_FUSION == TEST_FUSION_RELU
using OutElementOp = ck::tensor_operation::cpu::element_wise::Relu;
#endif
template <typename T>
static bool
check_out(const Tensor<T>& ref, const Tensor<T>& result, double nrms, int per_pixel_check = 0)
{
int error_count = 0;
float max_diff = 1e-5;
double square_difference = .0;
double mag1 = .0;
double mag2 = .0;
for(int i = 0; i < ref.mData.size(); ++i)
{
double ri = (double)ref.mData[i];
double pi = (double)result.mData[i];
double d = ri - pi;
if(per_pixel_check)
{
if(max_diff < std::abs(d))
{
error_count++;
printf("idx:%3d, ref:%f, res:%f (diff:%f)\n",
i,
double(ref.mData[i]),
double(result.mData[i]),
d);
}
}
square_difference += d * d;
if(std::abs(mag1) < std::abs(ri))
mag1 = ri;
if(std::abs(mag2) < std::abs(pi))
mag2 = pi;
}
double mag = std::max({std::fabs(mag1), std::fabs(mag2), std::numeric_limits<double>::min()});
double computed_nrms = std::sqrt(square_difference) / (std::sqrt(ref.mData.size()) * mag);
if(computed_nrms >= nrms)
printf("nrms:%lf, mag1:%lf, mag2:%lf, expected_nrms is %1f\n",
computed_nrms,
mag1,
mag2,
nrms);
return computed_nrms < nrms && error_count == 0;
}
float calculate_gflops() {}
template <typename T>
void transpose_kyxc_2_kyxc8k(Tensor<T>& dst,
const Tensor<T>& src,
ck::index_t K,
ck::index_t Y,
ck::index_t X,
ck::index_t C)
{
ck::index_t batch = K / 8;
ck::index_t row = 8;
ck::index_t col = C * Y * X;
for(auto i_b = 0; i_b < batch; i_b++)
{
for(auto i_r = 0; i_r < row; i_r++)
{
for(auto i_c = 0; i_c < col; i_c++)
{
ck::index_t src_idx = i_b * row * col + i_r * col + i_c;
ck::index_t dst_idx = i_b * col * row + i_c * row + i_r;
dst.mData[dst_idx] = src.mData[src_idx];
}
}
}
}
int main(int argc, char* argv[])
{
int data_type = 0;
int init_method = 0;
// Conv shape
ck::index_t N = 2;
ck::index_t K = 256;
ck::index_t C = 192;
ck::index_t Y = 3;
ck::index_t X = 3;
ck::index_t Hi = 71;
ck::index_t Wi = 71;
ck::index_t conv_stride_h = 1;
ck::index_t conv_stride_w = 1;
ck::index_t conv_dilation_h = 1;
ck::index_t conv_dilation_w = 1;
ck::index_t in_left_pad_h = 1;
ck::index_t in_left_pad_w = 1;
ck::index_t in_right_pad_h = 1;
ck::index_t in_right_pad_w = 1;
if(argc == 1)
{
data_type = 0;
init_method = 1;
}
else if(argc == 3)
{
data_type = std::stoi(argv[1]);
init_method = std::stoi(argv[2]);
}
else if(argc == 18)
{
data_type = std::stoi(argv[1]);
init_method = std::stoi(argv[2]);
N = std::stoi(argv[3]);
K = std::stoi(argv[4]);
C = std::stoi(argv[5]);
Y = std::stoi(argv[6]);
X = std::stoi(argv[7]);
Hi = std::stoi(argv[8]);
Wi = std::stoi(argv[9]);
conv_stride_h = std::stoi(argv[10]);
conv_stride_w = std::stoi(argv[11]);
conv_dilation_h = std::stoi(argv[12]);
conv_dilation_w = std::stoi(argv[13]);
in_left_pad_h = std::stoi(argv[14]);
in_left_pad_w = std::stoi(argv[15]);
in_right_pad_h = std::stoi(argv[16]);
in_right_pad_w = std::stoi(argv[17]);
}
else
{
printf("arg1: data type (0=fp32, 1=fp16)\n");
printf("arg2: initialization (0=no init, 1=integer value, 2=decimal value)\n");
printf("arg3 to 17: N, K, C, Y, X, Hi, Wi, Sy, Sx, Dy, Dx, LeftPy, LeftPx, RightPy, "
"RightPx\n");
exit(1);
}
auto Run = [&](auto input_type, auto wei_type, auto out_type) {
using InDataType = decltype(input_type);
using WeiDataType = decltype(wei_type);
using OutDataType = decltype(out_type);
using ReferenceConvFwdInstance = ck::tensor_operation::host::ReferenceConvFwd<InDataType,
WeiDataType,
OutDataType,
InElementOp,
WeiElementOp,
OutElementOp>;
const ck::index_t YEff = (Y - 1) * conv_dilation_h + 1;
const ck::index_t XEff = (X - 1) * conv_dilation_w + 1;
const ck::index_t Ho = (Hi + in_left_pad_h + in_right_pad_h - YEff) / conv_stride_h + 1;
const ck::index_t Wo = (Wi + in_left_pad_w + in_right_pad_w - XEff) / conv_stride_w + 1;
const std::vector<ck::index_t> input_spatial_lengths{{Hi, Wi}};
const std::vector<ck::index_t> filter_spatial_lengths{{Y, X}};
const std::vector<ck::index_t> output_spatial_lengths{{Ho, Wo}};
const std::vector<ck::index_t> conv_filter_strides{{conv_stride_h, conv_stride_w}};
const std::vector<ck::index_t> conv_filter_dilations{{conv_dilation_h, conv_dilation_w}};
const std::vector<ck::index_t> input_left_pads{{in_left_pad_h, in_left_pad_w}};
const std::vector<ck::index_t> input_right_pads{{in_right_pad_h, in_right_pad_w}};
auto f_host_tensor_descriptor = [](std::size_t N_,
std::size_t C_,
std::size_t H_,
std::size_t W_) {
return HostTensorDescriptor(std::vector<std::size_t>({N_, C_, H_, W_}),
std::vector<std::size_t>({C_ * H_ * W_, 1, W_ * C_, C_}));
};
Tensor<InDataType> in_n_c_hi_wi(f_host_tensor_descriptor(N, C, Hi, Wi));
Tensor<WeiDataType> wei_k_c_y_x(f_host_tensor_descriptor(K, C, Y, X));
#if TEST_LAYOUT == TEST_LAYOUT_NHWC_KYXCK8_NHWK
Tensor<WeiDataType> wei_k_c_y_x_k8(
f_host_tensor_descriptor(K, C, Y, X)); // TODO: This is only to hold data
#endif
Tensor<OutDataType> out_n_k_ho_wo_host_result(f_host_tensor_descriptor(N, K, Ho, Wo));
Tensor<OutDataType> out_n_k_ho_wo_device_result(f_host_tensor_descriptor(N, K, Ho, Wo));
std::cout << "in (N, C, Hi, Wi): " << in_n_c_hi_wi.mDesc << std::endl;
std::cout << "wei(K, C, Y, X): " << wei_k_c_y_x.mDesc << std::endl;
std::cout << "out(N, K, Ho, Wo): " << out_n_k_ho_wo_host_result.mDesc << std::endl;
std::cout << "LPad(H, W):" << in_left_pad_h << "," << in_left_pad_w
<< ", RPad(H, W):" << in_right_pad_h << "," << in_right_pad_w
<< ", Stride(H, W):" << conv_stride_h << ", " << conv_stride_w
<< ", Dilation(H, W):" << conv_dilation_h << ", " << conv_dilation_w
<< ", Threads:" << omp_get_max_threads() << std::endl;
int per_pixel_check = 0;
switch(init_method)
{
case 0:
in_n_c_hi_wi.GenerateTensorValue(GeneratorTensor_1<InDataType>{});
wei_k_c_y_x.GenerateTensorValue(GeneratorTensor_1<WeiDataType>{});
per_pixel_check = 1;
break;
case 1:
in_n_c_hi_wi.GenerateTensorValue(GeneratorTensor_2<InDataType>{-5, 5});
// in_n_c_hi_wi.GenerateTensorValue(GeneratorTensor_1<InDataType>{});
wei_k_c_y_x.GenerateTensorValue(GeneratorTensor_2<WeiDataType>{-5, 5});
// wei_k_c_y_x.GenerateTensorValue(GeneratorTensor_1<WeiDataType>{});
per_pixel_check = 1;
break;
case 2:
in_n_c_hi_wi.GenerateTensorValue(GeneratorTensor_3<InDataType>{0.0, 1.0});
wei_k_c_y_x.GenerateTensorValue(GeneratorTensor_3<WeiDataType>{-0.5, 0.5});
break;
case 3:
#define PACK_32(v24, v16, v8, v0) \
(((v24 & 0xff) << 24) | ((v16 & 0xff) << 16) | ((v8 & 0xff) << 8) | ((v0 & 0xff) << 0))
for(auto i_n = 0; i_n < N; i_n++)
{
for(auto i_c = 0; i_c < C; i_c++)
{
for(auto i_hi = 0; i_hi < Hi; i_hi++)
{
for(auto i_wi = 0; i_wi < Wi; i_wi++)
{
uint32_t v = PACK_32(i_n, i_c, i_hi, i_wi);
in_n_c_hi_wi(i_n, i_c, i_hi, i_wi) = *reinterpret_cast<float*>(&v);
}
}
}
}
for(auto i_k = 0; i_k < K; i_k++)
{
for(auto i_c = 0; i_c < C; i_c++)
{
for(auto i_y = 0; i_y < Y; i_y++)
{
for(auto i_x = 0; i_x < X; i_x++)
{
uint32_t v = PACK_32(i_k, i_c, i_y, i_x);
wei_k_c_y_x(i_k, i_c, i_y, i_x) = *reinterpret_cast<float*>(&v);
}
}
}
}
break;
default:
in_n_c_hi_wi.GenerateTensorValue(GeneratorTensor_3<InDataType>{0, 1});
wei_k_c_y_x.GenerateTensorValue(GeneratorTensor_3<WeiDataType>{-1, 1});
}
DeviceAlignedMemCPU in_device_buf(sizeof(InDataType) * in_n_c_hi_wi.mDesc.GetElementSpace(),
AVX2_DATA_ALIGNMENT);
DeviceAlignedMemCPU wei_device_buf(
sizeof(WeiDataType) * wei_k_c_y_x.mDesc.GetElementSpace(), AVX2_DATA_ALIGNMENT);
DeviceAlignedMemCPU out_device_buf(sizeof(OutDataType) *
out_n_k_ho_wo_host_result.mDesc.GetElementSpace(),
AVX2_DATA_ALIGNMENT);
in_device_buf.ToDevice(in_n_c_hi_wi.mData.data());
#if TEST_LAYOUT == TEST_LAYOUT_NHWC_KYXC_NHWK
wei_device_buf.ToDevice(wei_k_c_y_x.mData.data());
#endif
#if TEST_LAYOUT == TEST_LAYOUT_NHWC_KYXCK8_NHWK
transpose_kyxc_2_kyxc8k(wei_k_c_y_x_k8, wei_k_c_y_x, K, Y, X, C);
wei_device_buf.ToDevice(wei_k_c_y_x_k8.mData.data());
#endif
// get host result
{
auto ref_conv = ReferenceConvFwdInstance{};
auto ref_invoker = ref_conv.MakeInvoker();
auto ref_argument = ref_conv.MakeArgument(in_n_c_hi_wi,
wei_k_c_y_x,
out_n_k_ho_wo_host_result,
conv_filter_strides,
conv_filter_dilations,
input_left_pads,
input_right_pads,
InElementOp{},
WeiElementOp{},
OutElementOp{});
ref_invoker.Run(ref_argument);
}
using PassThrough = ck::tensor_operation::cpu::element_wise::PassThrough;
using Relu = ck::tensor_operation::cpu::element_wise::Relu;
#if TEST_FUSION == TEST_FUSION_PASSTHROUGH
using DeviceConvFwdNoOpPtr = ck::tensor_operation::cpu::device::
DeviceConvFwdPtr<PassThrough, PassThrough, PassThrough>;
#endif
#if TEST_FUSION == TEST_FUSION_RELU
using DeviceConvFwdNoOpPtr =
ck::tensor_operation::cpu::device::DeviceConvFwdPtr<PassThrough, PassThrough, Relu>;
#endif
// add device Conv instances
std::vector<DeviceConvFwdNoOpPtr> conv_ptrs;
if constexpr(ck::is_same_v<ck::remove_cv_t<InDataType>, float> &&
ck::is_same_v<ck::remove_cv_t<WeiDataType>, float> &&
ck::is_same_v<ck::remove_cv_t<OutDataType>, float>)
{
#if TEST_LAYOUT == TEST_LAYOUT_NHWC_KYXC_NHWK
#if TEST_FUSION == TEST_FUSION_PASSTHROUGH
if(omp_get_max_threads() > 1)
{
ck::tensor_operation::cpu::device::device_conv2d_fwd_avx2_instance::
add_device_conv2d_fwd_avx2_nhwc_kyxc_nhwk_mt(conv_ptrs);
ck::tensor_operation::cpu::device::device_conv2d_fwd_avx2_instance::
add_device_conv2d_fwd_avx2_nhwc_kyxc_nhwk(conv_ptrs);
}
else
{
if(K % 8 == 0)
ck::tensor_operation::cpu::device::device_conv2d_fwd_avx2_instance::
add_device_conv2d_fwd_avx2_nhwc_kyxc_nhwk(conv_ptrs);
else
ck::tensor_operation::cpu::device::device_conv2d_fwd_avx2_instance::
add_device_conv2d_fwd_avx2_nhwc_kyxc_nhwk_local_c(conv_ptrs);
}
#endif
#if TEST_FUSION == TEST_FUSION_RELU
if(omp_get_max_threads() > 1)
{
ck::tensor_operation::cpu::device::device_conv2d_fwd_avx2_instance::
add_device_conv2d_fwd_avx2_nhwc_kyxc_nhwk_mt_relu(conv_ptrs);
ck::tensor_operation::cpu::device::device_conv2d_fwd_avx2_instance::
add_device_conv2d_fwd_avx2_nhwc_kyxc_nhwk_relu(conv_ptrs);
}
else
{
if(K % 8 == 0)
ck::tensor_operation::cpu::device::device_conv2d_fwd_avx2_instance::
add_device_conv2d_fwd_avx2_nhwc_kyxc_nhwk_relu(conv_ptrs);
else
ck::tensor_operation::cpu::device::device_conv2d_fwd_avx2_instance::
add_device_conv2d_fwd_avx2_nhwc_kyxc_nhwk_local_c_relu(conv_ptrs);
}
#endif
#endif
#if TEST_LAYOUT == TEST_LAYOUT_NHWC_KYXCK8_NHWK
#if TEST_FUSION == TEST_FUSION_PASSTHROUGH
if(omp_get_max_threads() > 1)
{
ck::tensor_operation::cpu::device::device_conv2d_fwd_avx2_instance::
add_device_conv2d_fwd_avx2_nhwc_kyxck8_nhwk_mt(conv_ptrs);
ck::tensor_operation::cpu::device::device_conv2d_fwd_avx2_instance::
add_device_conv2d_fwd_avx2_nhwc_kyxck8_nhwk(conv_ptrs);
}
else
{
if(K % 8 == 0)
ck::tensor_operation::cpu::device::device_conv2d_fwd_avx2_instance::
add_device_conv2d_fwd_avx2_nhwc_kyxck8_nhwk(conv_ptrs);
else
ck::tensor_operation::cpu::device::device_conv2d_fwd_avx2_instance::
add_device_conv2d_fwd_avx2_nhwc_kyxck8_nhwk_local_c(conv_ptrs);
}
#endif
#if TEST_FUSION == TEST_FUSION_RELU
if(omp_get_max_threads() > 1)
{
ck::tensor_operation::cpu::device::device_conv2d_fwd_avx2_instance::
add_device_conv2d_fwd_avx2_nhwc_kyxck8_nhwk_mt_relu(conv_ptrs);
ck::tensor_operation::cpu::device::device_conv2d_fwd_avx2_instance::
add_device_conv2d_fwd_avx2_nhwc_kyxck8_nhwk_relu(conv_ptrs);
}
else
{
if(K % 8 == 0)
ck::tensor_operation::cpu::device::device_conv2d_fwd_avx2_instance::
add_device_conv2d_fwd_avx2_nhwc_kyxck8_nhwk_relu(conv_ptrs);
else
ck::tensor_operation::cpu::device::device_conv2d_fwd_avx2_instance::
add_device_conv2d_fwd_avx2_nhwc_kyxck8_nhwk_local_c_relu(conv_ptrs);
}
#endif
#endif
}
if(conv_ptrs.size() <= 0)
{
throw std::runtime_error("wrong! no device Conv instance found");
}
// profile device Conv instances
bool success = true;
double fastest_kernel_time = std::numeric_limits<double>::max();
std::string fastest_kernel_name = "";
double fastest_kernel_gflops = 0;
for(auto& conv_ptr : conv_ptrs)
{
auto argument_ptr = conv_ptr->MakeArgumentPointer(
static_cast<InDataType*>(in_device_buf.GetDeviceBuffer()),
static_cast<WeiDataType*>(wei_device_buf.GetDeviceBuffer()),
static_cast<OutDataType*>(out_device_buf.GetDeviceBuffer()),
N,
K,
C,
input_spatial_lengths,
filter_spatial_lengths,
output_spatial_lengths,
conv_filter_strides,
conv_filter_dilations,
input_left_pads,
input_right_pads,
InElementOp{},
WeiElementOp{},
OutElementOp{});
if(conv_ptr->IsSupportedArgument(argument_ptr.get()))
{
auto invoker_ptr = conv_ptr->MakeInvokerPointer();
double time = invoker_ptr->Run(argument_ptr.get(), StreamConfig{}, 10);
double total_flop = static_cast<double>(2) * N * C * Ho * Wo * K * Y * X;
double gflops = (total_flop * 1e-6) / time;
out_device_buf.FromDevice(out_n_k_ho_wo_device_result.mData.data());
if(!check_out(out_n_k_ho_wo_host_result,
out_n_k_ho_wo_device_result,
1e-6,
per_pixel_check))
{
std::cout << "Fail Info: " << conv_ptr->GetTypeString() << std::endl;
success = false;
}
else
{
std::cout << "Pass Info: " << conv_ptr->GetTypeString() << ", Time:" << time
<< "ms, Gflops:" << gflops << std::endl;
if(time < fastest_kernel_time)
{
fastest_kernel_time = time;
fastest_kernel_name = conv_ptr->GetTypeString();
fastest_kernel_gflops = gflops;
}
}
}
else
{
std::cout << "Not support Info: " << conv_ptr->GetTypeString() << std::endl;
}
}
if(fastest_kernel_time != std::numeric_limits<double>::max())
{
std::cout << " fastest:" << fastest_kernel_name << ", time:" << fastest_kernel_time
<< "ms, Gflops:" << fastest_kernel_gflops << std::endl;
}
return 0;
// if(success)
// {
// std::cout << "test conv2d fwd cpu : Pass" << std::endl;
// return 0;
// }
// else
// {
// std::cout << "test conv2d fwd cpu: Fail " << std::endl;
// return -1;
// }
};
if(data_type == 0)
{
return Run(F32(), F32(), F32());
}
else
{
return 1;
}
}
#include <sstream>
#include "config.hpp"
#include "device.hpp"
#include "host_tensor.hpp"
#include "host_tensor_generator.hpp"
#include "tensor_layout.hpp"
#include "device_tensor.hpp"
#include "device_convnd_fwd_avx2_nhwc_kyxc_nhwk.hpp"
#include "element_wise_operation_cpu.hpp"
#include "reference_conv_fwd.hpp"
#include "element_wise_operation_cpu.hpp"
#include "dynamic_buffer_cpu.hpp"
#include <omp.h>
#define AVX2_DATA_ALIGNMENT 32
#define TEST_FUSION_PASSTHROUGH 0
#define TEST_FUSION_RELU 1
#define TEST_FUSION TEST_FUSION_PASSTHROUGH
#define TEST_LAYOUT_NHWC_KYXC_NHWK 0
#define TEST_LAYOUT_NHWC_KYXCK8_NHWK 1
#define TEST_LAYOUT_NHWC_YXCK_NHWK 2
#define TEST_LAYOUT TEST_LAYOUT_NHWC_YXCK_NHWK
using F32 = float;
using F16 = ck::half_t;
namespace ck {
namespace tensor_operation {
namespace cpu {
namespace device {
namespace device_conv2d_fwd_avx2_instance {
using PassThrough = ck::tensor_operation::cpu::element_wise::PassThrough;
using Relu = ck::tensor_operation::cpu::element_wise::Relu;
// ------------------ nhwc-kyxc-nhwk
void add_device_conv2d_fwd_avx2_nhwc_kyxc_nhwk(
std::vector<DeviceConvFwdPtr<PassThrough, PassThrough, PassThrough>>& instances);
void add_device_conv2d_fwd_avx2_nhwc_kyxc_nhwk_local_c(
std::vector<DeviceConvFwdPtr<PassThrough, PassThrough, PassThrough>>& instances);
void add_device_conv2d_fwd_avx2_nhwc_kyxc_nhwk_mt(
std::vector<DeviceConvFwdPtr<PassThrough, PassThrough, PassThrough>>& instances);
void add_device_conv2d_fwd_avx2_nhwc_kyxc_nhwk_relu(
std::vector<DeviceConvFwdPtr<PassThrough, PassThrough, Relu>>& instances);
void add_device_conv2d_fwd_avx2_nhwc_kyxc_nhwk_local_c_relu(
std::vector<DeviceConvFwdPtr<PassThrough, PassThrough, Relu>>& instances);
void add_device_conv2d_fwd_avx2_nhwc_kyxc_nhwk_mt_relu(
std::vector<DeviceConvFwdPtr<PassThrough, PassThrough, Relu>>& instances);
// ------------------ nhwc-kcyxk8-nhwk
void add_device_conv2d_fwd_avx2_nhwc_kyxck8_nhwk(
std::vector<DeviceConvFwdPtr<PassThrough, PassThrough, PassThrough>>& instances);
void add_device_conv2d_fwd_avx2_nhwc_kyxck8_nhwk_local_c(
std::vector<DeviceConvFwdPtr<PassThrough, PassThrough, PassThrough>>& instances);
void add_device_conv2d_fwd_avx2_nhwc_kyxck8_nhwk_mt(
std::vector<DeviceConvFwdPtr<PassThrough, PassThrough, PassThrough>>& instances);
void add_device_conv2d_fwd_avx2_nhwc_kyxck8_nhwk_relu(
std::vector<DeviceConvFwdPtr<PassThrough, PassThrough, Relu>>& instances);
void add_device_conv2d_fwd_avx2_nhwc_kyxck8_nhwk_local_c_relu(
std::vector<DeviceConvFwdPtr<PassThrough, PassThrough, Relu>>& instances);
void add_device_conv2d_fwd_avx2_nhwc_kyxck8_nhwk_mt_relu(
std::vector<DeviceConvFwdPtr<PassThrough, PassThrough, Relu>>& instances);
void add_device_conv2d_fwd_avx2_nhwc_yxck_nhwk(
std::vector<DeviceConvFwdPtr<PassThrough, PassThrough, PassThrough>>& instances);
// ------------------ nhwc-yxck-nhwk
void add_device_conv2d_fwd_avx2_nhwc_yxck_nhwk_local_c(
std::vector<DeviceConvFwdPtr<PassThrough, PassThrough, PassThrough>>& instances);
void add_device_conv2d_fwd_avx2_nhwc_yxck_nhwk_mt(
std::vector<DeviceConvFwdPtr<PassThrough, PassThrough, PassThrough>>& instances);
void add_device_conv2d_fwd_avx2_nhwc_yxck_nhwk_relu(
std::vector<DeviceConvFwdPtr<PassThrough, PassThrough, Relu>>& instances);
void add_device_conv2d_fwd_avx2_nhwc_yxck_nhwk_local_c_relu(
std::vector<DeviceConvFwdPtr<PassThrough, PassThrough, Relu>>& instances);
void add_device_conv2d_fwd_avx2_nhwc_yxck_nhwk_mt_relu(
std::vector<DeviceConvFwdPtr<PassThrough, PassThrough, Relu>>& instances);
} // namespace device_conv2d_fwd_avx2_instance
} // namespace device
} // namespace cpu
} // namespace tensor_operation
} // namespace ck
using InElementOp = ck::tensor_operation::cpu::element_wise::PassThrough;
using WeiElementOp = ck::tensor_operation::cpu::element_wise::PassThrough;
#if TEST_FUSION == TEST_FUSION_PASSTHROUGH
using OutElementOp = ck::tensor_operation::cpu::element_wise::PassThrough;
#endif
#if TEST_FUSION == TEST_FUSION_RELU
using OutElementOp = ck::tensor_operation::cpu::element_wise::Relu;
#endif
template <typename T>
static bool
check_out(const Tensor<T>& ref, const Tensor<T>& result, double nrms, int per_pixel_check = 0)
{
int error_count = 0;
float max_diff = 1e-5;
double square_difference = .0;
double mag1 = .0;
double mag2 = .0;
for(int i = 0; i < ref.mData.size(); ++i)
{
double ri = (double)ref.mData[i];
double pi = (double)result.mData[i];
double d = ri - pi;
if(per_pixel_check)
{
if(max_diff < std::abs(d))
{
error_count++;
printf("idx:%3d, ref:%f, res:%f (diff:%f)\n",
i,
double(ref.mData[i]),
double(result.mData[i]),
d);
}
}
square_difference += d * d;
if(std::abs(mag1) < std::abs(ri))
mag1 = ri;
if(std::abs(mag2) < std::abs(pi))
mag2 = pi;
}
double mag = std::max({std::fabs(mag1), std::fabs(mag2), std::numeric_limits<double>::min()});
double computed_nrms = std::sqrt(square_difference) / (std::sqrt(ref.mData.size()) * mag);
if(computed_nrms >= nrms)
printf("nrms:%lf, mag1:%lf, mag2:%lf, expected_nrms is %1f\n",
computed_nrms,
mag1,
mag2,
nrms);
return computed_nrms < nrms && error_count == 0;
}
float calculate_gflops() {}
template <typename T>
void transpose_kyxc_2_kyxc8k(Tensor<T>& dst,
const Tensor<T>& src,
ck::index_t K,
ck::index_t Y,
ck::index_t X,
ck::index_t C)
{
ck::index_t batch = K / 8;
ck::index_t row = 8;
ck::index_t col = C * Y * X;
for(auto i_b = 0; i_b < batch; i_b++)
{
for(auto i_r = 0; i_r < row; i_r++)
{
for(auto i_c = 0; i_c < col; i_c++)
{
ck::index_t src_idx = i_b * row * col + i_r * col + i_c;
ck::index_t dst_idx = i_b * col * row + i_c * row + i_r;
dst.mData[dst_idx] = src.mData[src_idx];
}
}
}
}
template <typename T>
void transpose_kyxc_2_yxck(Tensor<T>& dst,
const Tensor<T>& src,
ck::index_t K,
ck::index_t Y,
ck::index_t X,
ck::index_t C)
{
ck::index_t batch = 1;
ck::index_t row = K;
ck::index_t col = C * Y * X;
for(auto i_b = 0; i_b < batch; i_b++)
{
for(auto i_r = 0; i_r < row; i_r++)
{
for(auto i_c = 0; i_c < col; i_c++)
{
ck::index_t src_idx = i_b * row * col + i_r * col + i_c;
ck::index_t dst_idx = i_b * col * row + i_c * row + i_r;
dst.mData[dst_idx] = src.mData[src_idx];
}
}
}
}
int main(int argc, char* argv[])
{
int data_type = 0;
int init_method = 0;
// Conv shape
ck::index_t N = 2;
ck::index_t K = 256;
ck::index_t C = 192;
ck::index_t Y = 3;
ck::index_t X = 3;
ck::index_t Hi = 71;
ck::index_t Wi = 71;
ck::index_t conv_stride_h = 1;
ck::index_t conv_stride_w = 1;
ck::index_t conv_dilation_h = 1;
ck::index_t conv_dilation_w = 1;
ck::index_t in_left_pad_h = 1;
ck::index_t in_left_pad_w = 1;
ck::index_t in_right_pad_h = 1;
ck::index_t in_right_pad_w = 1;
if(argc == 1)
{
data_type = 0;
init_method = 1;
}
else if(argc == 3)
{
data_type = std::stoi(argv[1]);
init_method = std::stoi(argv[2]);
}
else if(argc == 18)
{
data_type = std::stoi(argv[1]);
init_method = std::stoi(argv[2]);
N = std::stoi(argv[3]);
K = std::stoi(argv[4]);
C = std::stoi(argv[5]);
Y = std::stoi(argv[6]);
X = std::stoi(argv[7]);
Hi = std::stoi(argv[8]);
Wi = std::stoi(argv[9]);
conv_stride_h = std::stoi(argv[10]);
conv_stride_w = std::stoi(argv[11]);
conv_dilation_h = std::stoi(argv[12]);
conv_dilation_w = std::stoi(argv[13]);
in_left_pad_h = std::stoi(argv[14]);
in_left_pad_w = std::stoi(argv[15]);
in_right_pad_h = std::stoi(argv[16]);
in_right_pad_w = std::stoi(argv[17]);
}
else
{
printf("arg1: data type (0=fp32, 1=fp16)\n");
printf("arg2: initialization (0=no init, 1=integer value, 2=decimal value)\n");
printf("arg3 to 17: N, K, C, Y, X, Hi, Wi, Sy, Sx, Dy, Dx, LeftPy, LeftPx, RightPy, "
"RightPx\n");
exit(1);
}
auto Run = [&](auto input_type, auto wei_type, auto out_type) {
using InDataType = decltype(input_type);
using WeiDataType = decltype(wei_type);
using OutDataType = decltype(out_type);
using ReferenceConvFwdInstance = ck::tensor_operation::host::ReferenceConvFwd<InDataType,
WeiDataType,
OutDataType,
InElementOp,
WeiElementOp,
OutElementOp>;
const ck::index_t YEff = (Y - 1) * conv_dilation_h + 1;
const ck::index_t XEff = (X - 1) * conv_dilation_w + 1;
const ck::index_t Ho = (Hi + in_left_pad_h + in_right_pad_h - YEff) / conv_stride_h + 1;
const ck::index_t Wo = (Wi + in_left_pad_w + in_right_pad_w - XEff) / conv_stride_w + 1;
const std::vector<ck::index_t> input_spatial_lengths{{Hi, Wi}};
const std::vector<ck::index_t> filter_spatial_lengths{{Y, X}};
const std::vector<ck::index_t> output_spatial_lengths{{Ho, Wo}};
const std::vector<ck::index_t> conv_filter_strides{{conv_stride_h, conv_stride_w}};
const std::vector<ck::index_t> conv_filter_dilations{{conv_dilation_h, conv_dilation_w}};
const std::vector<ck::index_t> input_left_pads{{in_left_pad_h, in_left_pad_w}};
const std::vector<ck::index_t> input_right_pads{{in_right_pad_h, in_right_pad_w}};
auto f_host_tensor_descriptor = [](std::size_t N_,
std::size_t C_,
std::size_t H_,
std::size_t W_) {
return HostTensorDescriptor(std::vector<std::size_t>({N_, C_, H_, W_}),
std::vector<std::size_t>({C_ * H_ * W_, 1, W_ * C_, C_}));
};
Tensor<InDataType> in_n_c_hi_wi(f_host_tensor_descriptor(N, C, Hi, Wi));
Tensor<WeiDataType> wei_k_c_y_x(f_host_tensor_descriptor(K, C, Y, X));
#if TEST_LAYOUT == TEST_LAYOUT_NHWC_KYXCK8_NHWK
Tensor<WeiDataType> wei_k_c_y_x_k8(
f_host_tensor_descriptor(K, C, Y, X)); // TODO: This is only to hold data
#endif
#if TEST_LAYOUT == TEST_LAYOUT_NHWC_YXCK_NHWK
Tensor<WeiDataType> wei_y_x_c_k(
f_host_tensor_descriptor(K, C, Y, X)); // TODO: This is only to hold data
#endif
Tensor<OutDataType> out_n_k_ho_wo_host_result(f_host_tensor_descriptor(N, K, Ho, Wo));
Tensor<OutDataType> out_n_k_ho_wo_device_result(f_host_tensor_descriptor(N, K, Ho, Wo));
std::cout << "in (N, C, Hi, Wi): " << in_n_c_hi_wi.mDesc << std::endl;
std::cout << "wei(K, C, Y, X): " << wei_k_c_y_x.mDesc << std::endl;
std::cout << "out(N, K, Ho, Wo): " << out_n_k_ho_wo_host_result.mDesc << std::endl;
std::cout << "LPad(H, W):" << in_left_pad_h << "," << in_left_pad_w
<< ", RPad(H, W):" << in_right_pad_h << "," << in_right_pad_w
<< ", Stride(H, W):" << conv_stride_h << ", " << conv_stride_w
<< ", Dilation(H, W):" << conv_dilation_h << ", " << conv_dilation_w
<< ", Threads:" << omp_get_max_threads() << std::endl;
int per_pixel_check = 0;
switch(init_method)
{
case 0:
in_n_c_hi_wi.GenerateTensorValue(GeneratorTensor_1<InDataType>{});
wei_k_c_y_x.GenerateTensorValue(GeneratorTensor_1<WeiDataType>{});
per_pixel_check = 1;
break;
case 1:
in_n_c_hi_wi.GenerateTensorValue(GeneratorTensor_2<InDataType>{-5, 5});
// in_n_c_hi_wi.GenerateTensorValue(GeneratorTensor_1<InDataType>{});
wei_k_c_y_x.GenerateTensorValue(GeneratorTensor_2<WeiDataType>{-5, 5});
// wei_k_c_y_x.GenerateTensorValue(GeneratorTensor_1<WeiDataType>{});
per_pixel_check = 1;
break;
case 2:
in_n_c_hi_wi.GenerateTensorValue(GeneratorTensor_3<InDataType>{0.0, 1.0});
wei_k_c_y_x.GenerateTensorValue(GeneratorTensor_3<WeiDataType>{-0.5, 0.5});
break;
case 3:
#define PACK_32(v24, v16, v8, v0) \
(((v24 & 0xff) << 24) | ((v16 & 0xff) << 16) | ((v8 & 0xff) << 8) | ((v0 & 0xff) << 0))
for(auto i_n = 0; i_n < N; i_n++)
{
for(auto i_c = 0; i_c < C; i_c++)
{
for(auto i_hi = 0; i_hi < Hi; i_hi++)
{
for(auto i_wi = 0; i_wi < Wi; i_wi++)
{
uint32_t v = PACK_32(i_n, i_c, i_hi, i_wi);
in_n_c_hi_wi(i_n, i_c, i_hi, i_wi) = *reinterpret_cast<float*>(&v);
}
}
}
}
for(auto i_k = 0; i_k < K; i_k++)
{
for(auto i_c = 0; i_c < C; i_c++)
{
for(auto i_y = 0; i_y < Y; i_y++)
{
for(auto i_x = 0; i_x < X; i_x++)
{
uint32_t v = PACK_32(i_k, i_c, i_y, i_x);
wei_k_c_y_x(i_k, i_c, i_y, i_x) = *reinterpret_cast<float*>(&v);
}
}
}
}
break;
default:
in_n_c_hi_wi.GenerateTensorValue(GeneratorTensor_3<InDataType>{0, 1});
wei_k_c_y_x.GenerateTensorValue(GeneratorTensor_3<WeiDataType>{-1, 1});
}
DeviceAlignedMemCPU in_device_buf(sizeof(InDataType) * in_n_c_hi_wi.mDesc.GetElementSpace(),
AVX2_DATA_ALIGNMENT);
DeviceAlignedMemCPU wei_device_buf(
sizeof(WeiDataType) * wei_k_c_y_x.mDesc.GetElementSpace(), AVX2_DATA_ALIGNMENT);
DeviceAlignedMemCPU out_device_buf(sizeof(OutDataType) *
out_n_k_ho_wo_host_result.mDesc.GetElementSpace(),
AVX2_DATA_ALIGNMENT);
in_device_buf.ToDevice(in_n_c_hi_wi.mData.data());
#if TEST_LAYOUT == TEST_LAYOUT_NHWC_KYXC_NHWK
wei_device_buf.ToDevice(wei_k_c_y_x.mData.data());
#endif
#if TEST_LAYOUT == TEST_LAYOUT_NHWC_KYXCK8_NHWK
transpose_kyxc_2_kyxc8k(wei_k_c_y_x_k8, wei_k_c_y_x, K, Y, X, C);
wei_device_buf.ToDevice(wei_k_c_y_x_k8.mData.data());
#endif
#if TEST_LAYOUT == TEST_LAYOUT_NHWC_YXCK_NHWK
transpose_kyxc_2_yxck(wei_y_x_c_k, wei_k_c_y_x, K, Y, X, C);
wei_device_buf.ToDevice(wei_y_x_c_k.mData.data());
#endif
// get host result
{
auto ref_conv = ReferenceConvFwdInstance{};
auto ref_invoker = ref_conv.MakeInvoker();
auto ref_argument = ref_conv.MakeArgument(in_n_c_hi_wi,
wei_k_c_y_x,
out_n_k_ho_wo_host_result,
conv_filter_strides,
conv_filter_dilations,
input_left_pads,
input_right_pads,
InElementOp{},
WeiElementOp{},
OutElementOp{});
ref_invoker.Run(ref_argument);
}
using PassThrough = ck::tensor_operation::cpu::element_wise::PassThrough;
using Relu = ck::tensor_operation::cpu::element_wise::Relu;
#if TEST_FUSION == TEST_FUSION_PASSTHROUGH
using DeviceConvFwdNoOpPtr = ck::tensor_operation::cpu::device::
DeviceConvFwdPtr<PassThrough, PassThrough, PassThrough>;
#endif
#if TEST_FUSION == TEST_FUSION_RELU
using DeviceConvFwdNoOpPtr =
ck::tensor_operation::cpu::device::DeviceConvFwdPtr<PassThrough, PassThrough, Relu>;
#endif
// add device Conv instances
std::vector<DeviceConvFwdNoOpPtr> conv_ptrs;
if constexpr(ck::is_same_v<ck::remove_cv_t<InDataType>, float> &&
ck::is_same_v<ck::remove_cv_t<WeiDataType>, float> &&
ck::is_same_v<ck::remove_cv_t<OutDataType>, float>)
{
#if TEST_LAYOUT == TEST_LAYOUT_NHWC_KYXC_NHWK
#if TEST_FUSION == TEST_FUSION_PASSTHROUGH
if(omp_get_max_threads() > 1)
{
ck::tensor_operation::cpu::device::device_conv2d_fwd_avx2_instance::
add_device_conv2d_fwd_avx2_nhwc_kyxc_nhwk_mt(conv_ptrs);
ck::tensor_operation::cpu::device::device_conv2d_fwd_avx2_instance::
add_device_conv2d_fwd_avx2_nhwc_kyxc_nhwk(conv_ptrs);
}
else
{
if(K % 8 == 0)
ck::tensor_operation::cpu::device::device_conv2d_fwd_avx2_instance::
add_device_conv2d_fwd_avx2_nhwc_kyxc_nhwk(conv_ptrs);
else
ck::tensor_operation::cpu::device::device_conv2d_fwd_avx2_instance::
add_device_conv2d_fwd_avx2_nhwc_kyxc_nhwk_local_c(conv_ptrs);
}
#endif
#if TEST_FUSION == TEST_FUSION_RELU
if(omp_get_max_threads() > 1)
{
ck::tensor_operation::cpu::device::device_conv2d_fwd_avx2_instance::
add_device_conv2d_fwd_avx2_nhwc_kyxc_nhwk_mt_relu(conv_ptrs);
ck::tensor_operation::cpu::device::device_conv2d_fwd_avx2_instance::
add_device_conv2d_fwd_avx2_nhwc_kyxc_nhwk_relu(conv_ptrs);
}
else
{
if(K % 8 == 0)
ck::tensor_operation::cpu::device::device_conv2d_fwd_avx2_instance::
add_device_conv2d_fwd_avx2_nhwc_kyxc_nhwk_relu(conv_ptrs);
else
ck::tensor_operation::cpu::device::device_conv2d_fwd_avx2_instance::
add_device_conv2d_fwd_avx2_nhwc_kyxc_nhwk_local_c_relu(conv_ptrs);
}
#endif
#endif
#if TEST_LAYOUT == TEST_LAYOUT_NHWC_KYXCK8_NHWK
#if TEST_FUSION == TEST_FUSION_PASSTHROUGH
if(omp_get_max_threads() > 1)
{
ck::tensor_operation::cpu::device::device_conv2d_fwd_avx2_instance::
add_device_conv2d_fwd_avx2_nhwc_kyxck8_nhwk_mt(conv_ptrs);
ck::tensor_operation::cpu::device::device_conv2d_fwd_avx2_instance::
add_device_conv2d_fwd_avx2_nhwc_kyxck8_nhwk(conv_ptrs);
}
else
{
if(K % 8 == 0)
ck::tensor_operation::cpu::device::device_conv2d_fwd_avx2_instance::
add_device_conv2d_fwd_avx2_nhwc_kyxck8_nhwk(conv_ptrs);
else
ck::tensor_operation::cpu::device::device_conv2d_fwd_avx2_instance::
add_device_conv2d_fwd_avx2_nhwc_kyxck8_nhwk_local_c(conv_ptrs);
}
#endif
#if TEST_FUSION == TEST_FUSION_RELU
if(omp_get_max_threads() > 1)
{
ck::tensor_operation::cpu::device::device_conv2d_fwd_avx2_instance::
add_device_conv2d_fwd_avx2_nhwc_kyxck8_nhwk_mt_relu(conv_ptrs);
ck::tensor_operation::cpu::device::device_conv2d_fwd_avx2_instance::
add_device_conv2d_fwd_avx2_nhwc_kyxck8_nhwk_relu(conv_ptrs);
}
else
{
if(K % 8 == 0)
ck::tensor_operation::cpu::device::device_conv2d_fwd_avx2_instance::
add_device_conv2d_fwd_avx2_nhwc_kyxck8_nhwk_relu(conv_ptrs);
else
ck::tensor_operation::cpu::device::device_conv2d_fwd_avx2_instance::
add_device_conv2d_fwd_avx2_nhwc_kyxck8_nhwk_local_c_relu(conv_ptrs);
}
#endif
#endif
#if TEST_LAYOUT == TEST_LAYOUT_NHWC_YXCK_NHWK
#if TEST_FUSION == TEST_FUSION_PASSTHROUGH
if(omp_get_max_threads() > 1)
{
ck::tensor_operation::cpu::device::device_conv2d_fwd_avx2_instance::
add_device_conv2d_fwd_avx2_nhwc_yxck_nhwk_mt(conv_ptrs);
ck::tensor_operation::cpu::device::device_conv2d_fwd_avx2_instance::
add_device_conv2d_fwd_avx2_nhwc_yxck_nhwk(conv_ptrs);
}
else
{
if(K % 8 == 0)
ck::tensor_operation::cpu::device::device_conv2d_fwd_avx2_instance::
add_device_conv2d_fwd_avx2_nhwc_yxck_nhwk(conv_ptrs);
else
ck::tensor_operation::cpu::device::device_conv2d_fwd_avx2_instance::
add_device_conv2d_fwd_avx2_nhwc_yxck_nhwk_local_c(conv_ptrs);
}
#endif
#if TEST_FUSION == TEST_FUSION_RELU
if(omp_get_max_threads() > 1)
{
ck::tensor_operation::cpu::device::device_conv2d_fwd_avx2_instance::
add_device_conv2d_fwd_avx2_nhwc_yxck_nhwk_mt_relu(conv_ptrs);
ck::tensor_operation::cpu::device::device_conv2d_fwd_avx2_instance::
add_device_conv2d_fwd_avx2_nhwc_yxck_nhwk_relu(conv_ptrs);
}
else
{
if(K % 8 == 0)
ck::tensor_operation::cpu::device::device_conv2d_fwd_avx2_instance::
add_device_conv2d_fwd_avx2_nhwc_yxck_nhwk_relu(conv_ptrs);
else
ck::tensor_operation::cpu::device::device_conv2d_fwd_avx2_instance::
add_device_conv2d_fwd_avx2_nhwc_yxck_nhwk_local_c_relu(conv_ptrs);
}
#endif
#endif
}
if(conv_ptrs.size() <= 0)
{
throw std::runtime_error("wrong! no device Conv instance found");
}
// profile device Conv instances
bool success = true;
double fastest_kernel_time = std::numeric_limits<double>::max();
std::string fastest_kernel_name = "";
double fastest_kernel_gflops = 0;
for(auto& conv_ptr : conv_ptrs)
{
auto argument_ptr = conv_ptr->MakeArgumentPointer(
static_cast<InDataType*>(in_device_buf.GetDeviceBuffer()),
static_cast<WeiDataType*>(wei_device_buf.GetDeviceBuffer()),
static_cast<OutDataType*>(out_device_buf.GetDeviceBuffer()),
N,
K,
C,
input_spatial_lengths,
filter_spatial_lengths,
output_spatial_lengths,
conv_filter_strides,
conv_filter_dilations,
input_left_pads,
input_right_pads,
InElementOp{},
WeiElementOp{},
OutElementOp{});
if(conv_ptr->IsSupportedArgument(argument_ptr.get()))
{
auto invoker_ptr = conv_ptr->MakeInvokerPointer();
double time = invoker_ptr->Run(argument_ptr.get(), StreamConfig{}, 10);
double total_flop = static_cast<double>(2) * N * C * Ho * Wo * K * Y * X;
double gflops = (total_flop * 1e-6) / time;
out_device_buf.FromDevice(out_n_k_ho_wo_device_result.mData.data());
if(!check_out(out_n_k_ho_wo_host_result,
out_n_k_ho_wo_device_result,
1e-6,
per_pixel_check))
{
std::cout << "Fail Info: " << conv_ptr->GetTypeString() << std::endl;
success = false;
}
else
{
std::cout << "Pass Info: " << conv_ptr->GetTypeString() << ", Time:" << time
<< "ms, Gflops:" << gflops << std::endl;
if(time < fastest_kernel_time)
{
fastest_kernel_time = time;
fastest_kernel_name = conv_ptr->GetTypeString();
fastest_kernel_gflops = gflops;
}
}
}
else
{
std::cout << "Not support Info: " << conv_ptr->GetTypeString() << std::endl;
}
}
if(fastest_kernel_time != std::numeric_limits<double>::max())
{
std::cout << " fastest:" << fastest_kernel_name << ", time:" << fastest_kernel_time
<< "ms, Gflops:" << fastest_kernel_gflops << std::endl;
}
return 0;
// if(success)
// {
// std::cout << "test conv2d fwd cpu : Pass" << std::endl;
// return 0;
// }
// else
// {
// std::cout << "test conv2d fwd cpu: Fail " << std::endl;
// return -1;
// }
};
if(data_type == 0)
{
return Run(F32(), F32(), F32());
}
else
{
return 1;
}
}
......@@ -16,7 +16,8 @@
#define TEST_LAYOUT_NHWC_KYXC_NHWK 0
#define TEST_LAYOUT_NHWC_KYXCK8_NHWK 1
#define TEST_LAYOUT TEST_LAYOUT_NHWC_KYXCK8_NHWK
#define TEST_LAYOUT_NHWC_YXCK_NHWK 1
#define TEST_LAYOUT TEST_LAYOUT_NHWC_KYXC_NHWK
using F32 = float;
using F16 = ck::half_t;
......@@ -30,6 +31,7 @@ namespace device_conv2d_fwd_bias_activation_add_avx2_instance {
using PassThrough = ck::tensor_operation::cpu::element_wise::PassThrough;
using AddReluAdd = ck::tensor_operation::cpu::element_wise::AddReluAdd;
// ------------------ nhwc-kyxc-nhwk
void add_device_conv2d_fwd_bias_activation_add_avx2_nhwc_kyxc_nhwk(
std::vector<DeviceConvFwdBiasActivationAddPtr<PassThrough, PassThrough, AddReluAdd>>&
instances);
......@@ -42,6 +44,7 @@ void add_device_conv2d_fwd_bias_activation_add_avx2_nhwc_kyxc_nhwk_mt(
std::vector<DeviceConvFwdBiasActivationAddPtr<PassThrough, PassThrough, AddReluAdd>>&
instances);
// ------------------ nhwc-kcyxk8-nhwk
void add_device_conv2d_fwd_bias_activation_add_avx2_nhwc_kyxck8_nhwk(
std::vector<DeviceConvFwdBiasActivationAddPtr<PassThrough, PassThrough, AddReluAdd>>&
instances);
......@@ -54,6 +57,19 @@ void add_device_conv2d_fwd_bias_activation_add_avx2_nhwc_kyxck8_nhwk_mt(
std::vector<DeviceConvFwdBiasActivationAddPtr<PassThrough, PassThrough, AddReluAdd>>&
instances);
// ------------------ nhwc-yxck-nhwk
void add_device_conv2d_fwd_bias_activation_add_avx2_nhwc_yxck_nhwk(
std::vector<DeviceConvFwdBiasActivationAddPtr<PassThrough, PassThrough, AddReluAdd>>&
instances);
void add_device_conv2d_fwd_bias_activation_add_avx2_nhwc_yxck_nhwk_local_c(
std::vector<DeviceConvFwdBiasActivationAddPtr<PassThrough, PassThrough, AddReluAdd>>&
instances);
void add_device_conv2d_fwd_bias_activation_add_avx2_nhwc_yxck_nhwk_mt(
std::vector<DeviceConvFwdBiasActivationAddPtr<PassThrough, PassThrough, AddReluAdd>>&
instances);
} // namespace device_conv2d_fwd_bias_activation_add_avx2_instance
} // namespace device
} // namespace cpu
......@@ -141,6 +157,31 @@ void transpose_kyxc_2_kyxc8k(Tensor<T>& dst,
}
}
template <typename T>
void transpose_kyxc_2_yxck(Tensor<T>& dst,
const Tensor<T>& src,
ck::index_t K,
ck::index_t Y,
ck::index_t X,
ck::index_t C)
{
ck::index_t batch = 1;
ck::index_t row = K;
ck::index_t col = C * Y * X;
for(auto i_b = 0; i_b < batch; i_b++)
{
for(auto i_r = 0; i_r < row; i_r++)
{
for(auto i_c = 0; i_c < col; i_c++)
{
ck::index_t src_idx = i_b * row * col + i_r * col + i_c;
ck::index_t dst_idx = i_b * col * row + i_c * row + i_r;
dst.mData[dst_idx] = src.mData[src_idx];
}
}
}
}
int main(int argc, char* argv[])
{
int data_type = 0;
......@@ -243,6 +284,10 @@ int main(int argc, char* argv[])
#if TEST_LAYOUT == TEST_LAYOUT_NHWC_KYXCK8_NHWK
Tensor<WeiDataType> wei_k_c_y_x_k8(
f_host_tensor_descriptor(K, C, Y, X)); // TODO: This is only to hold data
#endif
#if TEST_LAYOUT == TEST_LAYOUT_NHWC_YXCK_NHWK
Tensor<WeiDataType> wei_y_x_c_k(
f_host_tensor_descriptor(K, C, Y, X)); // TODO: This is only to hold data
#endif
Tensor<OutDataType> out_n_k_ho_wo_host_result(f_host_tensor_descriptor(N, K, Ho, Wo));
Tensor<OutDataType> out_n_k_ho_wo_device_result(f_host_tensor_descriptor(N, K, Ho, Wo));
......@@ -319,6 +364,10 @@ int main(int argc, char* argv[])
#if TEST_LAYOUT == TEST_LAYOUT_NHWC_KYXCK8_NHWK
transpose_kyxc_2_kyxc8k(wei_k_c_y_x_k8, wei_k_c_y_x, K, Y, X, C);
wei_device_buf.ToDevice(wei_k_c_y_x_k8.mData.data());
#endif
#if TEST_LAYOUT == TEST_LAYOUT_NHWC_YXCK_NHWK
transpose_kyxc_2_yxck(wei_y_x_c_k, wei_k_c_y_x, K, Y, X, C);
wei_device_buf.ToDevice(wei_y_x_c_k.mData.data());
#endif
bias_device_buf.ToDevice(bias.mData.data());
resi_device_buf.ToDevice(residual.mData.data());
......@@ -404,6 +453,30 @@ int main(int argc, char* argv[])
add_device_conv2d_fwd_bias_activation_add_avx2_nhwc_kyxck8_nhwk_local_c(
conv_ptrs);
}
#endif
#if TEST_LAYOUT == TEST_LAYOUT_NHWC_YXCK_NHWK
if(omp_get_max_threads() > 1)
{
ck::tensor_operation::cpu::device::
device_conv2d_fwd_bias_activation_add_avx2_instance::
add_device_conv2d_fwd_bias_activation_add_avx2_nhwc_yxck_nhwk_mt(conv_ptrs);
ck::tensor_operation::cpu::device::
device_conv2d_fwd_bias_activation_add_avx2_instance::
add_device_conv2d_fwd_bias_activation_add_avx2_nhwc_yxck_nhwk(conv_ptrs);
}
else
{
if(K % 8 == 0)
ck::tensor_operation::cpu::device::
device_conv2d_fwd_bias_activation_add_avx2_instance::
add_device_conv2d_fwd_bias_activation_add_avx2_nhwc_yxck_nhwk(
conv_ptrs);
else
ck::tensor_operation::cpu::device::
device_conv2d_fwd_bias_activation_add_avx2_instance::
add_device_conv2d_fwd_bias_activation_add_avx2_nhwc_yxck_nhwk_local_c(
conv_ptrs);
}
#endif
}
......
......@@ -199,8 +199,6 @@ struct BlockwiseGemmAvx2_MxN
auto ldb = GetBLeadingElement(b_block_desc) * sizeof(FloatB);
auto ldc = GetCLeadingElement(c_desc) * sizeof(FloatC);
// printf("lda:%d, ldb:%d, ldc:%d\n", lda, ldb, ldc);
const auto k_per_block = a_slice_length[Number<1>{}];
const auto m_per_block = c_slice_length[Number<0>{}];
const auto n_per_block = c_slice_length[Number<1>{}];
......@@ -215,8 +213,16 @@ struct BlockwiseGemmAvx2_MxN
param.alpha = 1.0f; // TODO
param.accmulate_c = is_accumulate_c ? 1 : 0;
// printf("xxx lda:%u, ldb:%u, ldc:%u, mpb:%u, npb:%u, kpb:%u\n", lda, ldb, ldc,
// m_per_block, n_per_block, k_per_block);
// printf("xxx lda:%u, ldb:%u, ldc:%u, mpb:%u, npb:%u, kpb:%u, mpt:%u, npt:%u\n",
// lda,
// ldb,
// ldc,
// m_per_block,
// n_per_block,
// k_per_block,
// m_per_thread,
// n_per_thread);
// fflush(stdout);
if constexpr(std::is_same<ThreadMNAccessOrder, ck::Sequence<0, 1>>::value)
{
......
#ifndef DEVICE_CONV2D_FWD_AVX2_NHWC_YXCK_NHWK_HPP
#define DEVICE_CONV2D_FWD_AVX2_NHWC_YXCK_NHWK_HPP
#include <iostream>
#include <sstream>
#include <numeric>
#include "device.hpp"
#include "device_base_cpu.hpp"
#include "device_conv_fwd_cpu.hpp"
#include "convolution_forward_specialization_cpu.hpp"
#include "common_header.hpp"
#include "../../gpu/device/tensor_layout.hpp"
#include "tensor_descriptor.hpp"
#include "tensor_descriptor_helper.hpp"
#include "gridwise_gemm_avx2.hpp"
#include "threadwise_gemm_avx2.hpp"
#include "threadwise_tensor_slice_transfer_avx2_specialization.hpp"
namespace ck {
namespace tensor_operation {
namespace cpu {
namespace device {
template <typename InDataType,
typename WeiDataType,
typename OutDataType,
typename InElementwiseOperation,
typename WeiElementwiseOperation,
typename OutElementwiseOperation,
ConvolutionForwardSpecialization_t ConvForwardSpecialization,
ConvolutionForwardGemmKSpecialization_t GemmKSpecialization,
ConvolutionForwardBlockLoopOverSpecialization_t BlockLoopOverSpecialization,
ck::index_t NumDimSpatial,
ck::index_t MPerBlock, // block means data are designed to fit in cache (L1/L2/L3)
ck::index_t NPerBlock,
ck::index_t KPerBlock,
ck::index_t MPerThread,
ck::index_t NPerThread,
bool UseALocalBuffer,
bool UseBLocalBuffer,
bool UseCLocalBuffer>
struct DeviceConvNDFwdAvx2_Input_N_Hi_Wi_C_Weight_Y_X_C_K_Output_N_Ho_Wo_K
: public DeviceConvFwd<InElementwiseOperation, WeiElementwiseOperation, OutElementwiseOperation>
{
using DeviceOp = DeviceConvNDFwdAvx2_Input_N_Hi_Wi_C_Weight_Y_X_C_K_Output_N_Ho_Wo_K;
using ADataType = InDataType;
using BDataType = WeiDataType;
using CDataType = OutDataType;
using AElementwiseOperation = InElementwiseOperation;
using BElementwiseOperation = WeiElementwiseOperation;
using CElementwiseOperation = OutElementwiseOperation;
// TODO make A/B datatype different
using ABDataType = InDataType;
static constexpr index_t NDimSpatial = NumDimSpatial;
static constexpr auto I0 = Number<0>{};
static constexpr auto I1 = Number<1>{};
static constexpr auto I2 = Number<2>{};
static constexpr auto I3 = Number<3>{};
static constexpr bool NonTemporalStore = false;
static constexpr auto GetBlockMNKAccessOrder()
{
if constexpr(BlockLoopOverSpecialization == DefaultBlockLoopOver ||
BlockLoopOverSpecialization == LoopOver_MNK)
return ck::Sequence<0, 1, 2>{};
else if constexpr(BlockLoopOverSpecialization == LoopOver_MKN)
return ck::Sequence<0, 2, 1>{};
}
using BlockMNKAccessOrder = decltype(GetBlockMNKAccessOrder());
static constexpr auto GetThreadwiseGemm_Dispatch()
{
if constexpr(MPerThread == 4 && NPerThread == 24)
{
return ck::cpu::ThreadwiseGemmAvx2_MxN_4x24_Dispatch<InDataType,
WeiDataType,
OutDataType,
ck::tensor_layout::gemm::RowMajor,
ck::tensor_layout::gemm::RowMajor,
NonTemporalStore>{};
}
else if constexpr(MPerThread == 6 && NPerThread == 16)
{
return ck::cpu::ThreadwiseGemmAvx2_MxN_6x16_Dispatch<InDataType,
WeiDataType,
OutDataType,
ck::tensor_layout::gemm::RowMajor,
ck::tensor_layout::gemm::RowMajor,
NonTemporalStore>{};
}
else
{
// static_assert(false, "invalid Mr/Nr");
}
}
using ThreadwiseGemm_Dispatch = decltype(GetThreadwiseGemm_Dispatch());
static auto GetWeightTensorDescriptor(ck::index_t gemm_k, ck::index_t gemm_n)
{
return make_naive_tensor_descriptor_packed(make_tuple(gemm_k, gemm_n));
}
static auto GetOutputTensorDescriptor(ck::index_t gemm_m, ck::index_t gemm_n)
{
const auto out_gemm_m_n_grid_desc =
make_naive_tensor_descriptor_packed(make_tuple(gemm_m, gemm_n));
return out_gemm_m_n_grid_desc;
}
template <ck::index_t NDim, typename std::enable_if<NDim == 1, bool>::type = false>
static auto GetInputTensorDescriptor(ck::index_t N,
ck::index_t C,
ck::index_t gemm_m,
ck::index_t gemm_k,
const std::vector<ck::index_t>& input_spatial_lengths,
const std::vector<ck::index_t>& filter_spatial_lengths,
const std::vector<ck::index_t>& output_spatial_lengths,
const std::vector<ck::index_t>& conv_filter_strides,
const std::vector<ck::index_t>& conv_filter_dilations,
const std::vector<ck::index_t>& input_left_pads,
const std::vector<ck::index_t>& input_right_pads)
{
const index_t Wi = input_spatial_lengths[0];
const index_t Wo = output_spatial_lengths[0];
const index_t ConvStrideW = conv_filter_strides[0];
if constexpr(ConvForwardSpecialization ==
ConvolutionForwardSpecialization_t::Filter1x1Stride1Pad0)
{
const auto in_gemm_m_k_grid_desc =
make_naive_tensor_descriptor_packed(make_tuple(gemm_m, gemm_k));
return in_gemm_m_k_grid_desc;
}
else if constexpr(ConvForwardSpecialization ==
ConvolutionForwardSpecialization_t::Filter1x1Pad0)
{
const auto in_n_wi_c_grid_desc =
make_naive_tensor_descriptor_packed(make_tuple(N, Wi, C));
const auto in_n_wo_c_grid_desc = transform_tensor_descriptor(
in_n_wi_c_grid_desc,
make_tuple(make_pass_through_transform(N),
make_embed_transform(make_tuple(Wo), make_tuple(ConvStrideW)),
make_pass_through_transform(C)),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}));
const auto in_gemm_m_k_grid_desc = transform_tensor_descriptor(
in_n_wo_c_grid_desc,
make_tuple(make_merge_transform(make_tuple(N, Wo)), make_pass_through_transform(C)),
make_tuple(Sequence<0, 1>{}, Sequence<2>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}));
return in_gemm_m_k_grid_desc;
}
else
{
const index_t X = filter_spatial_lengths[0];
const index_t ConvDilationW = conv_filter_dilations[0];
const index_t InLeftPadW = input_left_pads[0];
const index_t InRightPadW = input_right_pads[0];
const auto in_n_wi_c_grid_desc =
make_naive_tensor_descriptor_packed(make_tuple(N, Wi, C));
const auto in_n_wip_c_grid_desc = transform_tensor_descriptor(
in_n_wi_c_grid_desc,
make_tuple(make_pass_through_transform(N),
make_pad_transform(Wi, InLeftPadW, InRightPadW),
make_pass_through_transform(C)),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}));
const auto in_n_x_wo_c_grid_desc = transform_tensor_descriptor(
in_n_wip_c_grid_desc,
make_tuple(
make_pass_through_transform(N),
make_embed_transform(make_tuple(X, Wo), make_tuple(ConvDilationW, ConvStrideW)),
make_pass_through_transform(C)),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}),
make_tuple(Sequence<0>{}, Sequence<1, 2>{}, Sequence<3>{}));
const auto in_gemm_m_k_grid_desc =
transform_tensor_descriptor(in_n_x_wo_c_grid_desc,
make_tuple(make_merge_transform(make_tuple(N, Wo)),
make_merge_transform(make_tuple(X, C))),
make_tuple(Sequence<0, 2>{}, Sequence<1, 3>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}));
return in_gemm_m_k_grid_desc;
}
}
template <ck::index_t NDim, typename std::enable_if<NDim == 2, bool>::type = false>
static auto GetInputTensorDescriptor(ck::index_t N,
ck::index_t C,
ck::index_t gemm_m,
ck::index_t gemm_k,
const std::vector<ck::index_t>& input_spatial_lengths,
const std::vector<ck::index_t>& filter_spatial_lengths,
const std::vector<ck::index_t>& output_spatial_lengths,
const std::vector<ck::index_t>& conv_filter_strides,
const std::vector<ck::index_t>& conv_filter_dilations,
const std::vector<ck::index_t>& input_left_pads,
const std::vector<ck::index_t>& input_right_pads)
{
const index_t Hi = input_spatial_lengths[0];
const index_t Wi = input_spatial_lengths[1];
const index_t Ho = output_spatial_lengths[0];
const index_t Wo = output_spatial_lengths[1];
const index_t ConvStrideH = conv_filter_strides[0];
const index_t ConvStrideW = conv_filter_strides[1];
if constexpr(ConvForwardSpecialization ==
ConvolutionForwardSpecialization_t::Filter1x1Stride1Pad0)
{
const auto in_gemm_m_k_grid_desc =
make_naive_tensor_descriptor_packed(make_tuple(gemm_m, gemm_k));
return in_gemm_m_k_grid_desc;
}
else if constexpr(ConvForwardSpecialization ==
ConvolutionForwardSpecialization_t::Filter1x1Pad0)
{
const auto in_n_hi_wi_c_grid_desc =
make_naive_tensor_descriptor_packed(make_tuple(N, Hi, Wi, C));
const auto in_n_ho_wo_c_grid_desc = transform_tensor_descriptor(
in_n_hi_wi_c_grid_desc,
make_tuple(make_pass_through_transform(N),
make_embed_transform(make_tuple(Ho), make_tuple(ConvStrideH)),
make_embed_transform(make_tuple(Wo), make_tuple(ConvStrideW)),
make_pass_through_transform(C)),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}));
const auto in_gemm_m_k_grid_desc =
transform_tensor_descriptor(in_n_ho_wo_c_grid_desc,
make_tuple(make_merge_transform(make_tuple(N, Ho, Wo)),
make_pass_through_transform(C)),
make_tuple(Sequence<0, 1, 2>{}, Sequence<3>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}));
return in_gemm_m_k_grid_desc;
}
else
{
const index_t Y = filter_spatial_lengths[0];
const index_t X = filter_spatial_lengths[1];
const index_t ConvDilationH = conv_filter_dilations[0];
const index_t ConvDilationW = conv_filter_dilations[1];
const index_t InLeftPadH = input_left_pads[0];
const index_t InLeftPadW = input_left_pads[1];
const index_t InRightPadH = input_right_pads[0];
const index_t InRightPadW = input_right_pads[1];
const auto in_n_hi_wi_c_grid_desc =
make_naive_tensor_descriptor_packed(make_tuple(N, Hi, Wi, C));
const auto in_n_hip_wip_c_grid_desc = transform_tensor_descriptor(
in_n_hi_wi_c_grid_desc,
make_tuple(make_pass_through_transform(N),
make_pad_transform(Hi, InLeftPadH, InRightPadH),
make_pad_transform(Wi, InLeftPadW, InRightPadW),
make_pass_through_transform(C)),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}));
const auto in_n_y_ho_x_wo_c_grid_desc = transform_tensor_descriptor(
in_n_hip_wip_c_grid_desc,
make_tuple(
make_pass_through_transform(N),
make_embed_transform(make_tuple(Y, Ho), make_tuple(ConvDilationH, ConvStrideH)),
make_embed_transform(make_tuple(X, Wo), make_tuple(ConvDilationW, ConvStrideW)),
make_pass_through_transform(C)),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}),
make_tuple(Sequence<0>{}, Sequence<1, 2>{}, Sequence<3, 4>{}, Sequence<5>{}));
const auto in_gemm_m_k_grid_desc =
transform_tensor_descriptor(in_n_y_ho_x_wo_c_grid_desc,
make_tuple(make_merge_transform(make_tuple(N, Ho, Wo)),
make_merge_transform(make_tuple(Y, X, C))),
make_tuple(Sequence<0, 2, 4>{}, Sequence<1, 3, 5>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}));
return in_gemm_m_k_grid_desc;
}
}
template <ck::index_t NDim, typename std::enable_if<NDim == 3, bool>::type = false>
static auto GetInputTensorDescriptor(ck::index_t N,
ck::index_t C,
ck::index_t gemm_m,
ck::index_t gemm_k,
ck::index_t gemm_m_pad,
const std::vector<ck::index_t>& input_spatial_lengths,
const std::vector<ck::index_t>& filter_spatial_lengths,
const std::vector<ck::index_t>& output_spatial_lengths,
const std::vector<ck::index_t>& conv_filter_strides,
const std::vector<ck::index_t>& conv_filter_dilations,
const std::vector<ck::index_t>& input_left_pads,
const std::vector<ck::index_t>& input_right_pads)
{
const index_t Di = input_spatial_lengths[0];
const index_t Hi = input_spatial_lengths[1];
const index_t Wi = input_spatial_lengths[2];
const index_t Do = output_spatial_lengths[0];
const index_t Ho = output_spatial_lengths[1];
const index_t Wo = output_spatial_lengths[2];
const index_t ConvStrideD = conv_filter_strides[0];
const index_t ConvStrideH = conv_filter_strides[1];
const index_t ConvStrideW = conv_filter_strides[2];
if constexpr(ConvForwardSpecialization ==
ConvolutionForwardSpecialization_t::Filter1x1Stride1Pad0)
{
const auto in_gemm_m_k_grid_desc =
make_naive_tensor_descriptor_packed(make_tuple(gemm_m, gemm_k));
return in_gemm_m_k_grid_desc;
}
else if constexpr(ConvForwardSpecialization ==
ConvolutionForwardSpecialization_t::Filter1x1Pad0)
{
const auto in_n_di_hi_wi_c_grid_desc =
make_naive_tensor_descriptor_packed(make_tuple(N, Di, Hi, Wi, C));
const auto in_n_do_ho_wo_c_grid_desc = transform_tensor_descriptor(
in_n_di_hi_wi_c_grid_desc,
make_tuple(make_pass_through_transform(N),
make_embed_transform(make_tuple(Do), make_tuple(ConvStrideD)),
make_embed_transform(make_tuple(Ho), make_tuple(ConvStrideH)),
make_embed_transform(make_tuple(Wo), make_tuple(ConvStrideW)),
make_pass_through_transform(C)),
make_tuple(
Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}, Sequence<4>{}),
make_tuple(
Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}, Sequence<4>{}));
const auto in_gemm_m_k_grid_desc = transform_tensor_descriptor(
in_n_do_ho_wo_c_grid_desc,
make_tuple(make_merge_transform(make_tuple(N, Do, Ho, Wo)),
make_pass_through_transform(C)),
make_tuple(Sequence<0, 1, 2, 3>{}, Sequence<4>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}));
return in_gemm_m_k_grid_desc;
}
else
{
const index_t Z = filter_spatial_lengths[0];
const index_t Y = filter_spatial_lengths[1];
const index_t X = filter_spatial_lengths[2];
const index_t ConvDilationD = conv_filter_dilations[0];
const index_t ConvDilationH = conv_filter_dilations[1];
const index_t ConvDilationW = conv_filter_dilations[2];
const index_t InLeftPadD = input_left_pads[0];
const index_t InLeftPadH = input_left_pads[1];
const index_t InLeftPadW = input_left_pads[2];
const index_t InRightPadD = input_right_pads[0];
const index_t InRightPadH = input_right_pads[1];
const index_t InRightPadW = input_right_pads[2];
const auto in_n_di_hi_wi_c_grid_desc =
make_naive_tensor_descriptor_packed(make_tuple(N, Di, Hi, Wi, C));
const auto in_n_hip_wip_c_grid_desc = transform_tensor_descriptor(
in_n_di_hi_wi_c_grid_desc,
make_tuple(make_pass_through_transform(N),
make_pad_transform(Di, InLeftPadD, InRightPadD),
make_pad_transform(Hi, InLeftPadH, InRightPadH),
make_pad_transform(Wi, InLeftPadW, InRightPadW),
make_pass_through_transform(C)),
make_tuple(
Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}, Sequence<4>{}),
make_tuple(
Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}, Sequence<4>{}));
const auto in_n_z_do_y_ho_x_wo_c_grid_desc = transform_tensor_descriptor(
in_n_hip_wip_c_grid_desc,
make_tuple(
make_pass_through_transform(N),
make_embed_transform(make_tuple(Z, Do), make_tuple(ConvDilationD, ConvStrideD)),
make_embed_transform(make_tuple(Y, Ho), make_tuple(ConvDilationH, ConvStrideH)),
make_embed_transform(make_tuple(X, Wo), make_tuple(ConvDilationW, ConvStrideW)),
make_pass_through_transform(C)),
make_tuple(
Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}, Sequence<4>{}),
make_tuple(Sequence<0>{},
Sequence<1, 2>{},
Sequence<3, 4>{},
Sequence<5, 6>{},
Sequence<7>{}));
const auto in_gemm_m_k_grid_desc = transform_tensor_descriptor(
in_n_z_do_y_ho_x_wo_c_grid_desc,
make_tuple(make_merge_transform(make_tuple(N, Do, Ho, Wo)),
make_merge_transform(make_tuple(Z, Y, X, C))),
make_tuple(Sequence<0, 2, 4, 6>{}, Sequence<1, 3, 5, 7>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}));
return in_gemm_m_k_grid_desc;
}
}
static index_t GetGemmM(ck::index_t N, const std::vector<ck::index_t>& output_spatial_lengths)
{
return N * std::accumulate(std::begin(output_spatial_lengths),
std::end(output_spatial_lengths),
1,
std::multiplies<ck::index_t>());
}
static index_t GetGemmK(ck::index_t C, const std::vector<ck::index_t>& filter_spatial_lengths)
{
return C * std::accumulate(std::begin(filter_spatial_lengths),
std::end(filter_spatial_lengths),
1,
std::multiplies<ck::index_t>());
}
static index_t GetGemmN(ck::index_t K)
{
// return ck::math::integer_least_multiple(K,
// ThreadwiseGemm_Dispatch::MatrixBMinVectorSize);
return K;
}
static auto MakeABCGridDescriptor(ck::index_t N,
ck::index_t K,
ck::index_t C,
std::vector<ck::index_t> input_spatial_lengths,
std::vector<ck::index_t> filter_spatial_lengths,
std::vector<ck::index_t> output_spatial_lengths,
std::vector<ck::index_t> conv_filter_strides,
std::vector<ck::index_t> conv_filter_dilations,
std::vector<ck::index_t> input_left_pads,
std::vector<ck::index_t> input_right_pads)
{
using namespace ck;
const index_t GemmM = GetGemmM(N, output_spatial_lengths);
const index_t GemmN = GetGemmN(K);
const index_t GemmK = GetGemmK(C, filter_spatial_lengths);
// A:
const auto in_gemm_m_k_grid_desc =
GetInputTensorDescriptor<NumDimSpatial>(N,
C,
GemmM,
GemmK,
input_spatial_lengths,
filter_spatial_lengths,
output_spatial_lengths,
conv_filter_strides,
conv_filter_dilations,
input_left_pads,
input_right_pads);
// B:
const auto wei_gemm_k_n_grid_desc = GetWeightTensorDescriptor(GemmK, GemmN);
// C:
const auto out_gemm_m_n_grid_desc = GetOutputTensorDescriptor(GemmM, GemmN);
return make_tuple(in_gemm_m_k_grid_desc, wei_gemm_k_n_grid_desc, out_gemm_m_n_grid_desc);
}
template <ck::index_t NDim, typename std::enable_if<NDim == 1, bool>::type = false>
static auto GetABCGridDesc()
{
return MakeABCGridDescriptor(1, 1, 1, {1}, {1}, {1}, {1}, {1}, {1}, {1});
}
template <ck::index_t NDim, typename std::enable_if<NDim == 2, bool>::type = false>
static auto GetABCGridDesc()
{
return MakeABCGridDescriptor(
1, 1, 1, {1, 1}, {1, 1}, {1, 1}, {1, 1}, {1, 1}, {1, 1}, {1, 1});
}
template <ck::index_t NDim, typename std::enable_if<NDim == 3, bool>::type = false>
static auto GetABCGridDesc()
{
return MakeABCGridDescriptor(
1, 1, 1, {1, 1, 1}, {1, 1, 1}, {1, 1, 1}, {1, 1, 1}, {1, 1, 1}, {1, 1, 1}, {1, 1, 1});
}
using ABCGridDescs = decltype(GetABCGridDesc<NumDimSpatial>());
using AGridDesc = remove_cvref_t<decltype(ABCGridDescs{}[I0])>;
using BGridDesc = remove_cvref_t<decltype(ABCGridDescs{}[I1])>;
using CGridDesc = remove_cvref_t<decltype(ABCGridDescs{}[I2])>;
static constexpr auto GetInputBlockDescriptor()
{
if constexpr(UseALocalBuffer)
{
return make_naive_tensor_descriptor_packed(make_tuple(MPerBlock, KPerBlock));
}
else
{
return AGridDesc{};
}
}
static constexpr auto GetWeightBlockDescriptor()
{
if constexpr(UseBLocalBuffer)
{
return make_naive_tensor_descriptor_packed(make_tuple(KPerBlock, NPerBlock));
}
else
{
return BGridDesc{};
}
}
static constexpr auto GetOutputBlockDescriptor()
{
if constexpr(UseCLocalBuffer)
{
return make_naive_tensor_descriptor_packed(make_tuple(MPerBlock, NPerBlock));
}
else
{
return CGridDesc{};
}
}
// static constexpr bool UseCLocalBuffer = false;
using AThreadwiseCopy =
ck::cpu::ThreadwiseTensorSliceTransferAvx2Specialization_ConvFwd_In_NHWC<
InDataType,
InDataType,
AGridDesc,
decltype(GetInputBlockDescriptor()),
InElementwiseOperation,
!UseALocalBuffer,
ConvForwardSpecialization,
GemmKSpecialization>;
using BThreadwiseCopy =
ck::cpu::ThreadwiseTensorSliceTransferAvx2Specialization_ConvFwd_Wei_YXCK<
WeiDataType,
WeiDataType,
BGridDesc,
decltype(GetWeightBlockDescriptor()),
WeiElementwiseOperation,
!UseBLocalBuffer,
ConvForwardSpecialization,
GemmKSpecialization>;
using CThreadwiseCopy = ck::cpu::ThreadwiseTensorSliceTransferAvx2Specialization_MatC_Store_MxN<
OutDataType,
OutDataType,
CGridDesc,
decltype(GetOutputBlockDescriptor()),
OutElementwiseOperation,
!UseCLocalBuffer,
ConvForwardSpecialization,
GemmKSpecialization>;
using GridwiseGemm =
ck::cpu::GridwiseGemmAvx2_MxN<InDataType, // InDataType,
WeiDataType, // WeiDataType,
OutDataType, // OutDataType,
AGridDesc, // AGridDesc,
BGridDesc, // BGridDesc,
CGridDesc, // CGridDesc,
AElementwiseOperation, // AElementwiseOperation,
BElementwiseOperation, // BElementwiseOperation,
CElementwiseOperation, // CElementwiseOperation,
MPerBlock, // MPerBlock,
NPerBlock, // NPerBlock,
KPerBlock, // KPerBlock,
ThreadwiseGemm_Dispatch, // ThreadwiseGemm_Dispatch,
AThreadwiseCopy, // AThreadwiseCopy
BThreadwiseCopy, // BThreadwiseCopy
CThreadwiseCopy, // CThreadwiseCopy
BlockMNKAccessOrder, // BlockMNKAccessOrder,
ck::Sequence<0, 1>, // ThreadMNAccessOrder
UseALocalBuffer, // UseALocalBuffer
UseBLocalBuffer, // UseBLocalBuffer
UseCLocalBuffer // UseCLocalBuffer
>;
// Argument
struct Argument : public BaseArgument
{
Argument(const InDataType* p_in_grid,
const WeiDataType* p_wei_grid,
OutDataType* p_out_grid,
ck::index_t N,
ck::index_t K,
ck::index_t C,
std::vector<ck::index_t> input_spatial_lengths,
std::vector<ck::index_t> filter_spatial_lengths,
std::vector<ck::index_t> output_spatial_lengths,
std::vector<ck::index_t> conv_filter_strides,
std::vector<ck::index_t> conv_filter_dilations,
std::vector<ck::index_t> input_left_pads,
std::vector<ck::index_t> input_right_pads,
InElementwiseOperation in_element_op,
WeiElementwiseOperation wei_element_op,
OutElementwiseOperation out_element_op)
: p_a_grid_{p_in_grid},
p_b_grid_{p_wei_grid},
p_c_grid_{p_out_grid},
a_grid_desc_{},
b_grid_desc_{},
c_grid_desc_{},
a_element_op_{in_element_op},
b_element_op_{wei_element_op},
c_element_op_{out_element_op},
Conv_N_{N},
Conv_K_{K},
Conv_C_{C},
filter_spatial_lengths_{filter_spatial_lengths},
conv_filter_strides_{conv_filter_strides},
input_left_pads_{input_left_pads},
input_right_pads_{input_right_pads}
{
const auto descs = DeviceOp::MakeABCGridDescriptor(N,
K,
C,
input_spatial_lengths,
filter_spatial_lengths,
output_spatial_lengths,
conv_filter_strides,
conv_filter_dilations,
input_left_pads,
input_right_pads);
a_grid_desc_ = descs[I0];
b_grid_desc_ = descs[I1];
c_grid_desc_ = descs[I2];
}
// private:
const ADataType* p_a_grid_;
const BDataType* p_b_grid_;
CDataType* p_c_grid_;
AGridDesc a_grid_desc_;
BGridDesc b_grid_desc_;
CGridDesc c_grid_desc_;
AElementwiseOperation a_element_op_;
BElementwiseOperation b_element_op_;
CElementwiseOperation c_element_op_;
// for checking IsSupportedArgument()
index_t Conv_N_;
index_t Conv_K_;
index_t Conv_C_;
std::vector<index_t> filter_spatial_lengths_;
std::vector<index_t> conv_filter_strides_;
std::vector<index_t> input_left_pads_;
std::vector<index_t> input_right_pads_;
};
// Invoker
struct Invoker : public BaseInvoker
{
using Argument = DeviceOp::Argument;
float Run(const Argument& arg,
const StreamConfig& stream_config = StreamConfig{},
int nrepeat = 1)
{
if(!GridwiseGemm::CheckValidity(arg.a_grid_desc_, arg.b_grid_desc_, arg.c_grid_desc_))
{
throw std::runtime_error("wrong! GridwiseGemmAvx2_MxN has invalid setting");
}
memset(arg.p_c_grid_, 0, arg.c_grid_desc_.GetElementSpaceSize());
const auto kernel = ck::cpu::kernel_gemm_avx_mxn<GridwiseGemm,
InDataType,
WeiDataType,
OutDataType,
AGridDesc,
BGridDesc,
CGridDesc,
AElementwiseOperation,
BElementwiseOperation,
CElementwiseOperation>;
float ave_time = 0;
if(nrepeat != 1)
ave_time = launch_and_time_cpu_kernel(kernel,
nrepeat,
arg.p_a_grid_,
arg.p_b_grid_,
arg.p_c_grid_,
arg.a_grid_desc_,
arg.b_grid_desc_,
arg.c_grid_desc_,
arg.a_element_op_,
arg.b_element_op_,
arg.c_element_op_);
// TODO: this is for benchmark purpose, so last time we clear c buffer and calculate the
// result
memset(arg.p_c_grid_, 0, arg.c_grid_desc_.GetElementSpaceSize());
launch_cpu_kernel(kernel,
arg.p_a_grid_,
arg.p_b_grid_,
arg.p_c_grid_,
arg.a_grid_desc_,
arg.b_grid_desc_,
arg.c_grid_desc_,
arg.a_element_op_,
arg.b_element_op_,
arg.c_element_op_);
return ave_time;
}
float Run(const BaseArgument* p_arg,
const StreamConfig& stream_config = StreamConfig{},
int nrepeat = 1) override
{
return Run(*dynamic_cast<const Argument*>(p_arg), stream_config, nrepeat);
}
};
static constexpr bool IsValidCompilationParameter()
{
// TODO: properly implement this check
return true;
}
static bool IsSupportedArgument(const Argument& arg)
{
if constexpr(ConvForwardSpecialization ==
ConvolutionForwardSpecialization_t::Filter1x1Stride1Pad0)
{
// check if it's 1x1, stride=1 conv
if(!(arg.filter_spatial_lengths_[0] == 1 && arg.filter_spatial_lengths_[1] == 1 &&
arg.conv_filter_strides_[0] == 1 && arg.conv_filter_strides_[1] == 1 &&
arg.input_left_pads_[0] == 0 && arg.input_left_pads_[1] == 0 &&
arg.input_right_pads_[0] == 0 && arg.input_right_pads_[1] == 0))
{
return false;
}
}
else if constexpr(ConvForwardSpecialization ==
ConvolutionForwardSpecialization_t::Filter1x1Pad0)
{
// check if it's 1x1 conv
if(!(arg.filter_spatial_lengths_[0] == 1 && arg.filter_spatial_lengths_[1] == 1 &&
arg.input_left_pads_[0] == 0 && arg.input_left_pads_[1] == 0 &&
arg.input_right_pads_[0] == 0 && arg.input_right_pads_[1] == 0))
{
return false;
}
}
if constexpr(GemmKSpecialization ==
ConvolutionForwardGemmKSpecialization_t::NHWC_GemmKLoopOverC &&
ConvForwardSpecialization !=
ConvolutionForwardSpecialization_t::Filter1x1Stride1Pad0)
{
if(!(arg.Conv_C_ % KPerBlock == 0))
return false;
}
if constexpr(!UseALocalBuffer &&
ConvForwardSpecialization !=
ConvolutionForwardSpecialization_t::Filter1x1Stride1Pad0)
{
// TODO: We can support this in the future, as long as figure out how to express tensor
// transform
return false;
}
if constexpr(!UseBLocalBuffer)
{
if(!(arg.Conv_K_ % 8 == 0))
return false;
}
// Gridwise GEMM size
return GridwiseGemm::CheckValidity(arg.a_grid_desc_, arg.b_grid_desc_, arg.c_grid_desc_);
}
bool IsSupportedArgument(const BaseArgument* p_arg) override
{
return IsSupportedArgument(*dynamic_cast<const Argument*>(p_arg));
}
static auto MakeArgument(const InDataType* p_in_grid,
const WeiDataType* p_wei_grid,
OutDataType* p_out_grid,
ck::index_t N,
ck::index_t K,
ck::index_t C,
std::vector<ck::index_t> input_spatial_lengths,
std::vector<ck::index_t> filter_spatial_lengths,
std::vector<ck::index_t> output_spatial_lengths,
std::vector<ck::index_t> conv_filter_strides,
std::vector<ck::index_t> conv_filter_dilations,
std::vector<ck::index_t> input_left_pads,
std::vector<ck::index_t> input_right_pads,
InElementwiseOperation in_element_op,
WeiElementwiseOperation wei_element_op,
OutElementwiseOperation out_element_op)
{
return Argument{p_in_grid,
p_wei_grid,
p_out_grid,
N,
K,
C,
input_spatial_lengths,
filter_spatial_lengths,
output_spatial_lengths,
conv_filter_strides,
conv_filter_dilations,
input_left_pads,
input_right_pads,
in_element_op,
wei_element_op,
out_element_op};
}
static auto MakeInvoker() { return Invoker{}; }
std::unique_ptr<BaseArgument>
MakeArgumentPointer(const void* p_in_grid,
const void* p_wei_grid,
void* p_out_grid,
ck::index_t N,
ck::index_t K,
ck::index_t C,
std::vector<ck::index_t> input_spatial_lengths,
std::vector<ck::index_t> filter_spatial_lengths,
std::vector<ck::index_t> output_spatial_lengths,
std::vector<ck::index_t> conv_filter_strides,
std::vector<ck::index_t> conv_filter_dilations,
std::vector<ck::index_t> input_left_pads,
std::vector<ck::index_t> input_right_pads,
InElementwiseOperation in_element_op,
WeiElementwiseOperation wei_element_op,
OutElementwiseOperation out_element_op) override
{
return std::make_unique<Argument>(static_cast<const InDataType*>(p_in_grid),
static_cast<const WeiDataType*>(p_wei_grid),
static_cast<OutDataType*>(p_out_grid),
N,
K,
C,
input_spatial_lengths,
filter_spatial_lengths,
output_spatial_lengths,
conv_filter_strides,
conv_filter_dilations,
input_left_pads,
input_right_pads,
in_element_op,
wei_element_op,
out_element_op);
}
std::unique_ptr<BaseInvoker> MakeInvokerPointer() override
{
return std::make_unique<Invoker>(Invoker{});
}
std::string GetTypeString() const override
{
auto str = std::stringstream();
auto string_local_buffer = [](bool is_local_buffer) {
if(is_local_buffer)
return "L";
else
return "G";
};
// clang-format off
str << "DeviceConv" << std::to_string(NumDimSpatial)
<< "DFwdAvx2_NHWC_YXCK"
<<"_FS"<< static_cast<int>(ConvForwardSpecialization)
<<"_KS"<< static_cast<int>(GemmKSpecialization)
<<"_BS"<< static_cast<int>(BlockLoopOverSpecialization)
<< "_BT" << MPerBlock << "x" << NPerBlock << "x" << KPerBlock
<< "_TT" << MPerThread << "x" << NPerThread
<< "_A" << string_local_buffer(UseALocalBuffer)
<< "_B" << string_local_buffer(UseBLocalBuffer)
<< "_C" << string_local_buffer(UseCLocalBuffer)
;
if constexpr (!std::is_same<OutElementwiseOperation,
ck::tensor_operation::cpu::element_wise::PassThrough>::value)
{
str << "_" << OutElementwiseOperation::Name();
}
// clang-format on
return str.str();
}
};
} // namespace device
} // namespace cpu
} // namespace tensor_operation
} // namespace ck
#endif
#ifndef DEVICE_CONV2D_FWD_BIAS_ACTIVATION_ADD_AVX2_NHWC_YXCK_NHWK_HPP
#define DEVICE_CONV2D_FWD_BIAS_ACTIVATION_ADD_AVX2_NHWC_YXCK_NHWK_HPP
#include <iostream>
#include <sstream>
#include <numeric>
#include "device.hpp"
#include "device_base_cpu.hpp"
#include "device_conv_fwd_cpu.hpp"
#include "convolution_forward_specialization_cpu.hpp"
#include "common_header.hpp"
#include "../../gpu/device/tensor_layout.hpp"
#include "tensor_descriptor.hpp"
#include "tensor_descriptor_helper.hpp"
#include "gridwise_gemm_bias_activation_add_avx2.hpp"
#include "threadwise_gemm_avx2.hpp"
#include "threadwise_tensor_slice_transfer_avx2_specialization.hpp"
namespace ck {
namespace tensor_operation {
namespace cpu {
namespace device {
template <typename InDataType,
typename WeiDataType,
typename OutDataType,
typename BiasDataType,
typename AddDataType,
typename InElementwiseOperation,
typename WeiElementwiseOperation,
typename OutElementwiseOperation,
ConvolutionForwardSpecialization_t ConvForwardSpecialization,
ConvolutionForwardGemmKSpecialization_t GemmKSpecialization,
ConvolutionForwardBlockLoopOverSpecialization_t BlockLoopOverSpecialization,
ck::index_t NumDimSpatial,
ck::index_t MPerBlock, // block means data are designed to fit in cache (L1/L2/L3)
ck::index_t NPerBlock,
ck::index_t KPerBlock,
ck::index_t MPerThread,
ck::index_t NPerThread,
bool UseALocalBuffer,
bool UseBLocalBuffer,
bool UseCLocalBuffer,
bool BiasAlongGemmM>
struct DeviceConvNDFwdBiasActivationAddAvx2_Input_N_Hi_Wi_C_Weight_Y_X_C_K_Output_N_Ho_Wo_K
: public DeviceConvFwdBiasActivationAdd<InElementwiseOperation,
WeiElementwiseOperation,
OutElementwiseOperation>
{
using DeviceOp =
DeviceConvNDFwdBiasActivationAddAvx2_Input_N_Hi_Wi_C_Weight_Y_X_C_K_Output_N_Ho_Wo_K;
using ADataType = InDataType;
using BDataType = WeiDataType;
using CDataType = OutDataType;
using C0DataType = BiasDataType;
using C1DataType = AddDataType;
using AElementwiseOperation = InElementwiseOperation;
using BElementwiseOperation = WeiElementwiseOperation;
using CElementwiseOperation = OutElementwiseOperation;
// TODO make A/B datatype different
using ABDataType = InDataType;
static constexpr index_t NDimSpatial = NumDimSpatial;
static constexpr auto I0 = Number<0>{};
static constexpr auto I1 = Number<1>{};
static constexpr auto I2 = Number<2>{};
static constexpr auto I3 = Number<3>{};
static constexpr bool NonTemporalStore = false;
static constexpr auto GetBlockMNKAccessOrder()
{
if constexpr(BlockLoopOverSpecialization == DefaultBlockLoopOver ||
BlockLoopOverSpecialization == LoopOver_MNK)
return ck::Sequence<0, 1, 2>{};
else if constexpr(BlockLoopOverSpecialization == LoopOver_MKN)
return ck::Sequence<0, 2, 1>{};
}
using BlockMNKAccessOrder = decltype(GetBlockMNKAccessOrder());
static constexpr auto GetThreadwiseGemm_Dispatch()
{
if constexpr(MPerThread == 4 && NPerThread == 24)
{
return ck::cpu::ThreadwiseGemmAvx2_MxN_4x24_Dispatch<InDataType,
WeiDataType,
OutDataType,
ck::tensor_layout::gemm::RowMajor,
ck::tensor_layout::gemm::RowMajor,
NonTemporalStore>{};
}
else if constexpr(MPerThread == 6 && NPerThread == 16)
{
return ck::cpu::ThreadwiseGemmAvx2_MxN_6x16_Dispatch<InDataType,
WeiDataType,
OutDataType,
ck::tensor_layout::gemm::RowMajor,
ck::tensor_layout::gemm::RowMajor,
NonTemporalStore>{};
}
else
{
// static_assert(false, "invalid Mr/Nr");
}
}
using ThreadwiseGemm_Dispatch = decltype(GetThreadwiseGemm_Dispatch());
static constexpr auto GetInputBlockDescriptor()
{
if constexpr(UseALocalBuffer)
{
return make_naive_tensor_descriptor_packed(make_tuple(MPerBlock, KPerBlock));
}
else
{
return AGridDesc{};
}
}
static constexpr auto GetWeightBlockDescriptor()
{
if constexpr(UseBLocalBuffer)
{
return make_naive_tensor_descriptor_packed(make_tuple(KPerBlock, NPerBlock));
}
else
{
return BGridDesc{};
}
}
static constexpr auto GetOutputBlockDescriptor()
{
if constexpr(UseCLocalBuffer)
{
return make_naive_tensor_descriptor_packed(make_tuple(MPerBlock, NPerBlock));
}
else
{
return CGridDesc{};
}
}
static auto GetWeightTensorDescriptor(ck::index_t gemm_k, ck::index_t gemm_n)
{
return make_naive_tensor_descriptor_packed(make_tuple(gemm_k, gemm_n));
}
static auto GetOutputTensorDescriptor(ck::index_t gemm_m, ck::index_t gemm_n)
{
const auto out_gemm_m_n_grid_desc =
make_naive_tensor_descriptor_packed(make_tuple(gemm_m, gemm_n));
return out_gemm_m_n_grid_desc;
}
static auto MakeBiasTensorDescriptor(ck::index_t gemm_m, ck::index_t gemm_n)
{
if constexpr(BiasAlongGemmM)
{
return make_naive_tensor_descriptor_packed(make_tuple(gemm_m));
}
else
{
return make_naive_tensor_descriptor_packed(make_tuple(gemm_n));
}
}
template <ck::index_t NDim, typename std::enable_if<NDim == 1, bool>::type = false>
static auto GetInputTensorDescriptor(ck::index_t N,
ck::index_t C,
ck::index_t gemm_m,
ck::index_t gemm_k,
const std::vector<ck::index_t>& input_spatial_lengths,
const std::vector<ck::index_t>& filter_spatial_lengths,
const std::vector<ck::index_t>& output_spatial_lengths,
const std::vector<ck::index_t>& conv_filter_strides,
const std::vector<ck::index_t>& conv_filter_dilations,
const std::vector<ck::index_t>& input_left_pads,
const std::vector<ck::index_t>& input_right_pads)
{
const index_t Wi = input_spatial_lengths[0];
const index_t Wo = output_spatial_lengths[0];
const index_t ConvStrideW = conv_filter_strides[0];
if constexpr(ConvForwardSpecialization ==
ConvolutionForwardSpecialization_t::Filter1x1Stride1Pad0)
{
const auto in_gemm_m_k_grid_desc =
make_naive_tensor_descriptor_packed(make_tuple(gemm_m, gemm_k));
return in_gemm_m_k_grid_desc;
}
else if constexpr(ConvForwardSpecialization ==
ConvolutionForwardSpecialization_t::Filter1x1Pad0)
{
const auto in_n_wi_c_grid_desc =
make_naive_tensor_descriptor_packed(make_tuple(N, Wi, C));
const auto in_n_wo_c_grid_desc = transform_tensor_descriptor(
in_n_wi_c_grid_desc,
make_tuple(make_pass_through_transform(N),
make_embed_transform(make_tuple(Wo), make_tuple(ConvStrideW)),
make_pass_through_transform(C)),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}));
const auto in_gemm_m_k_grid_desc = transform_tensor_descriptor(
in_n_wo_c_grid_desc,
make_tuple(make_merge_transform(make_tuple(N, Wo)), make_pass_through_transform(C)),
make_tuple(Sequence<0, 1>{}, Sequence<2>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}));
return in_gemm_m_k_grid_desc;
}
else
{
const index_t X = filter_spatial_lengths[0];
const index_t ConvDilationW = conv_filter_dilations[0];
const index_t InLeftPadW = input_left_pads[0];
const index_t InRightPadW = input_right_pads[0];
const auto in_n_wi_c_grid_desc =
make_naive_tensor_descriptor_packed(make_tuple(N, Wi, C));
const auto in_n_wip_c_grid_desc = transform_tensor_descriptor(
in_n_wi_c_grid_desc,
make_tuple(make_pass_through_transform(N),
make_pad_transform(Wi, InLeftPadW, InRightPadW),
make_pass_through_transform(C)),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}));
const auto in_n_x_wo_c_grid_desc = transform_tensor_descriptor(
in_n_wip_c_grid_desc,
make_tuple(
make_pass_through_transform(N),
make_embed_transform(make_tuple(X, Wo), make_tuple(ConvDilationW, ConvStrideW)),
make_pass_through_transform(C)),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}),
make_tuple(Sequence<0>{}, Sequence<1, 2>{}, Sequence<3>{}));
const auto in_gemm_m_k_grid_desc =
transform_tensor_descriptor(in_n_x_wo_c_grid_desc,
make_tuple(make_merge_transform(make_tuple(N, Wo)),
make_merge_transform(make_tuple(X, C))),
make_tuple(Sequence<0, 2>{}, Sequence<1, 3>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}));
return in_gemm_m_k_grid_desc;
}
}
template <ck::index_t NDim, typename std::enable_if<NDim == 2, bool>::type = false>
static auto GetInputTensorDescriptor(ck::index_t N,
ck::index_t C,
ck::index_t gemm_m,
ck::index_t gemm_k,
const std::vector<ck::index_t>& input_spatial_lengths,
const std::vector<ck::index_t>& filter_spatial_lengths,
const std::vector<ck::index_t>& output_spatial_lengths,
const std::vector<ck::index_t>& conv_filter_strides,
const std::vector<ck::index_t>& conv_filter_dilations,
const std::vector<ck::index_t>& input_left_pads,
const std::vector<ck::index_t>& input_right_pads)
{
const index_t Hi = input_spatial_lengths[0];
const index_t Wi = input_spatial_lengths[1];
const index_t Ho = output_spatial_lengths[0];
const index_t Wo = output_spatial_lengths[1];
const index_t ConvStrideH = conv_filter_strides[0];
const index_t ConvStrideW = conv_filter_strides[1];
if constexpr(ConvForwardSpecialization ==
ConvolutionForwardSpecialization_t::Filter1x1Stride1Pad0)
{
const auto in_gemm_m_k_grid_desc =
make_naive_tensor_descriptor_packed(make_tuple(gemm_m, gemm_k));
return in_gemm_m_k_grid_desc;
}
else if constexpr(ConvForwardSpecialization ==
ConvolutionForwardSpecialization_t::Filter1x1Pad0)
{
const auto in_n_hi_wi_c_grid_desc =
make_naive_tensor_descriptor_packed(make_tuple(N, Hi, Wi, C));
const auto in_n_ho_wo_c_grid_desc = transform_tensor_descriptor(
in_n_hi_wi_c_grid_desc,
make_tuple(make_pass_through_transform(N),
make_embed_transform(make_tuple(Ho), make_tuple(ConvStrideH)),
make_embed_transform(make_tuple(Wo), make_tuple(ConvStrideW)),
make_pass_through_transform(C)),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}));
const auto in_gemm_m_k_grid_desc =
transform_tensor_descriptor(in_n_ho_wo_c_grid_desc,
make_tuple(make_merge_transform(make_tuple(N, Ho, Wo)),
make_pass_through_transform(C)),
make_tuple(Sequence<0, 1, 2>{}, Sequence<3>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}));
return in_gemm_m_k_grid_desc;
}
else
{
const index_t Y = filter_spatial_lengths[0];
const index_t X = filter_spatial_lengths[1];
const index_t ConvDilationH = conv_filter_dilations[0];
const index_t ConvDilationW = conv_filter_dilations[1];
const index_t InLeftPadH = input_left_pads[0];
const index_t InLeftPadW = input_left_pads[1];
const index_t InRightPadH = input_right_pads[0];
const index_t InRightPadW = input_right_pads[1];
const auto in_n_hi_wi_c_grid_desc =
make_naive_tensor_descriptor_packed(make_tuple(N, Hi, Wi, C));
const auto in_n_hip_wip_c_grid_desc = transform_tensor_descriptor(
in_n_hi_wi_c_grid_desc,
make_tuple(make_pass_through_transform(N),
make_pad_transform(Hi, InLeftPadH, InRightPadH),
make_pad_transform(Wi, InLeftPadW, InRightPadW),
make_pass_through_transform(C)),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}));
const auto in_n_y_ho_x_wo_c_grid_desc = transform_tensor_descriptor(
in_n_hip_wip_c_grid_desc,
make_tuple(
make_pass_through_transform(N),
make_embed_transform(make_tuple(Y, Ho), make_tuple(ConvDilationH, ConvStrideH)),
make_embed_transform(make_tuple(X, Wo), make_tuple(ConvDilationW, ConvStrideW)),
make_pass_through_transform(C)),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}),
make_tuple(Sequence<0>{}, Sequence<1, 2>{}, Sequence<3, 4>{}, Sequence<5>{}));
const auto in_gemm_m_k_grid_desc =
transform_tensor_descriptor(in_n_y_ho_x_wo_c_grid_desc,
make_tuple(make_merge_transform(make_tuple(N, Ho, Wo)),
make_merge_transform(make_tuple(Y, X, C))),
make_tuple(Sequence<0, 2, 4>{}, Sequence<1, 3, 5>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}));
return in_gemm_m_k_grid_desc;
}
}
template <ck::index_t NDim, typename std::enable_if<NDim == 3, bool>::type = false>
static auto GetInputTensorDescriptor(ck::index_t N,
ck::index_t C,
ck::index_t gemm_m,
ck::index_t gemm_k,
ck::index_t gemm_m_pad,
const std::vector<ck::index_t>& input_spatial_lengths,
const std::vector<ck::index_t>& filter_spatial_lengths,
const std::vector<ck::index_t>& output_spatial_lengths,
const std::vector<ck::index_t>& conv_filter_strides,
const std::vector<ck::index_t>& conv_filter_dilations,
const std::vector<ck::index_t>& input_left_pads,
const std::vector<ck::index_t>& input_right_pads)
{
const index_t Di = input_spatial_lengths[0];
const index_t Hi = input_spatial_lengths[1];
const index_t Wi = input_spatial_lengths[2];
const index_t Do = output_spatial_lengths[0];
const index_t Ho = output_spatial_lengths[1];
const index_t Wo = output_spatial_lengths[2];
const index_t ConvStrideD = conv_filter_strides[0];
const index_t ConvStrideH = conv_filter_strides[1];
const index_t ConvStrideW = conv_filter_strides[2];
if constexpr(ConvForwardSpecialization ==
ConvolutionForwardSpecialization_t::Filter1x1Stride1Pad0)
{
const auto in_gemm_m_k_grid_desc =
make_naive_tensor_descriptor_packed(make_tuple(gemm_m, gemm_k));
return in_gemm_m_k_grid_desc;
}
else if constexpr(ConvForwardSpecialization ==
ConvolutionForwardSpecialization_t::Filter1x1Pad0)
{
const auto in_n_di_hi_wi_c_grid_desc =
make_naive_tensor_descriptor_packed(make_tuple(N, Di, Hi, Wi, C));
const auto in_n_do_ho_wo_c_grid_desc = transform_tensor_descriptor(
in_n_di_hi_wi_c_grid_desc,
make_tuple(make_pass_through_transform(N),
make_embed_transform(make_tuple(Do), make_tuple(ConvStrideD)),
make_embed_transform(make_tuple(Ho), make_tuple(ConvStrideH)),
make_embed_transform(make_tuple(Wo), make_tuple(ConvStrideW)),
make_pass_through_transform(C)),
make_tuple(
Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}, Sequence<4>{}),
make_tuple(
Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}, Sequence<4>{}));
const auto in_gemm_m_k_grid_desc = transform_tensor_descriptor(
in_n_do_ho_wo_c_grid_desc,
make_tuple(make_merge_transform(make_tuple(N, Do, Ho, Wo)),
make_pass_through_transform(C)),
make_tuple(Sequence<0, 1, 2, 3>{}, Sequence<4>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}));
return in_gemm_m_k_grid_desc;
}
else
{
const index_t Z = filter_spatial_lengths[0];
const index_t Y = filter_spatial_lengths[1];
const index_t X = filter_spatial_lengths[2];
const index_t ConvDilationD = conv_filter_dilations[0];
const index_t ConvDilationH = conv_filter_dilations[1];
const index_t ConvDilationW = conv_filter_dilations[2];
const index_t InLeftPadD = input_left_pads[0];
const index_t InLeftPadH = input_left_pads[1];
const index_t InLeftPadW = input_left_pads[2];
const index_t InRightPadD = input_right_pads[0];
const index_t InRightPadH = input_right_pads[1];
const index_t InRightPadW = input_right_pads[2];
const auto in_n_di_hi_wi_c_grid_desc =
make_naive_tensor_descriptor_packed(make_tuple(N, Di, Hi, Wi, C));
const auto in_n_hip_wip_c_grid_desc = transform_tensor_descriptor(
in_n_di_hi_wi_c_grid_desc,
make_tuple(make_pass_through_transform(N),
make_pad_transform(Di, InLeftPadD, InRightPadD),
make_pad_transform(Hi, InLeftPadH, InRightPadH),
make_pad_transform(Wi, InLeftPadW, InRightPadW),
make_pass_through_transform(C)),
make_tuple(
Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}, Sequence<4>{}),
make_tuple(
Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}, Sequence<4>{}));
const auto in_n_z_do_y_ho_x_wo_c_grid_desc = transform_tensor_descriptor(
in_n_hip_wip_c_grid_desc,
make_tuple(
make_pass_through_transform(N),
make_embed_transform(make_tuple(Z, Do), make_tuple(ConvDilationD, ConvStrideD)),
make_embed_transform(make_tuple(Y, Ho), make_tuple(ConvDilationH, ConvStrideH)),
make_embed_transform(make_tuple(X, Wo), make_tuple(ConvDilationW, ConvStrideW)),
make_pass_through_transform(C)),
make_tuple(
Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}, Sequence<4>{}),
make_tuple(Sequence<0>{},
Sequence<1, 2>{},
Sequence<3, 4>{},
Sequence<5, 6>{},
Sequence<7>{}));
const auto in_gemm_m_k_grid_desc = transform_tensor_descriptor(
in_n_z_do_y_ho_x_wo_c_grid_desc,
make_tuple(make_merge_transform(make_tuple(N, Do, Ho, Wo)),
make_merge_transform(make_tuple(Z, Y, X, C))),
make_tuple(Sequence<0, 2, 4, 6>{}, Sequence<1, 3, 5, 7>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}));
return in_gemm_m_k_grid_desc;
}
}
static index_t GetGemmM(ck::index_t N, const std::vector<ck::index_t>& output_spatial_lengths)
{
return N * std::accumulate(std::begin(output_spatial_lengths),
std::end(output_spatial_lengths),
1,
std::multiplies<ck::index_t>());
}
static index_t GetGemmK(ck::index_t C, const std::vector<ck::index_t>& filter_spatial_lengths)
{
return C * std::accumulate(std::begin(filter_spatial_lengths),
std::end(filter_spatial_lengths),
1,
std::multiplies<ck::index_t>());
}
static index_t GetGemmN(ck::index_t K)
{
// return ck::math::integer_least_multiple(K,
// ThreadwiseGemm_Dispatch::MatrixBMinVectorSize);
return K;
}
static auto MakeABCGridDescriptor(ck::index_t N,
ck::index_t K,
ck::index_t C,
std::vector<ck::index_t> input_spatial_lengths,
std::vector<ck::index_t> filter_spatial_lengths,
std::vector<ck::index_t> output_spatial_lengths,
std::vector<ck::index_t> conv_filter_strides,
std::vector<ck::index_t> conv_filter_dilations,
std::vector<ck::index_t> input_left_pads,
std::vector<ck::index_t> input_right_pads)
{
using namespace ck;
const index_t GemmM = GetGemmM(N, output_spatial_lengths);
const index_t GemmN = GetGemmN(K);
const index_t GemmK = GetGemmK(C, filter_spatial_lengths);
// A:
const auto in_gemm_m_k_grid_desc =
GetInputTensorDescriptor<NumDimSpatial>(N,
C,
GemmM,
GemmK,
input_spatial_lengths,
filter_spatial_lengths,
output_spatial_lengths,
conv_filter_strides,
conv_filter_dilations,
input_left_pads,
input_right_pads);
// B:
const auto wei_gemm_k_n_grid_desc = GetWeightTensorDescriptor(GemmK, GemmN);
// C:
const auto out_gemm_m_n_grid_desc = GetOutputTensorDescriptor(GemmM, GemmN);
return make_tuple(in_gemm_m_k_grid_desc, wei_gemm_k_n_grid_desc, out_gemm_m_n_grid_desc);
}
template <ck::index_t NDim, typename std::enable_if<NDim == 1, bool>::type = false>
static auto GetABCGridDesc()
{
return MakeABCGridDescriptor(1, 1, 1, {1}, {1}, {1}, {1}, {1}, {1}, {1});
}
template <ck::index_t NDim, typename std::enable_if<NDim == 2, bool>::type = false>
static auto GetABCGridDesc()
{
return MakeABCGridDescriptor(
1, 1, 1, {1, 1}, {1, 1}, {1, 1}, {1, 1}, {1, 1}, {1, 1}, {1, 1});
}
template <ck::index_t NDim, typename std::enable_if<NDim == 3, bool>::type = false>
static auto GetABCGridDesc()
{
return MakeABCGridDescriptor(
1, 1, 1, {1, 1, 1}, {1, 1, 1}, {1, 1, 1}, {1, 1, 1}, {1, 1, 1}, {1, 1, 1}, {1, 1, 1});
}
using ABCGridDescs = decltype(GetABCGridDesc<NumDimSpatial>());
using AGridDesc = remove_cvref_t<decltype(ABCGridDescs{}[I0])>;
using BGridDesc = remove_cvref_t<decltype(ABCGridDescs{}[I1])>;
using CGridDesc = remove_cvref_t<decltype(ABCGridDescs{}[I2])>;
using C0GridDesc = remove_cvref_t<decltype(MakeBiasTensorDescriptor(1, 1))>;
using C1GridDesc = CGridDesc;
// static constexpr bool UseCLocalBuffer = false;
using AThreadwiseCopy =
ck::cpu::ThreadwiseTensorSliceTransferAvx2Specialization_ConvFwd_In_NHWC<
ADataType,
ADataType,
AGridDesc,
decltype(GetInputBlockDescriptor()),
InElementwiseOperation,
!UseALocalBuffer,
ConvForwardSpecialization,
GemmKSpecialization>;
using BThreadwiseCopy =
ck::cpu::ThreadwiseTensorSliceTransferAvx2Specialization_ConvFwd_Wei_YXCK<
BDataType,
BDataType,
BGridDesc,
decltype(GetWeightBlockDescriptor()),
WeiElementwiseOperation,
!UseBLocalBuffer,
ConvForwardSpecialization,
GemmKSpecialization>;
using CThreadwiseCopy =
ck::cpu::ThreadwiseTensorSliceTransferAvx2Specialization_MatC_Store_Bias_Residual_MxN<
CDataType,
C0DataType,
C1DataType,
CDataType,
CGridDesc,
C0GridDesc,
C1GridDesc,
decltype(GetOutputBlockDescriptor()),
OutElementwiseOperation,
!UseCLocalBuffer,
BiasAlongGemmM>;
using GridwiseGemm = ck::cpu::GridwiseGemmBiasActivationAddAvx2_MxN<
ADataType, // InDataType,
BDataType, // WeiDataType,
CDataType, // OutDataType,
C0DataType, // C0DataType
C1DataType, // C1DataType
AGridDesc, // AGridDesc,
BGridDesc, // BGridDesc,
CGridDesc, // CGridDesc,
C0GridDesc, // C0GridDesc,
C1GridDesc, // C1GridDesc,
AElementwiseOperation, // AElementwiseOperation,
BElementwiseOperation, // BElementwiseOperation,
CElementwiseOperation, // CElementwiseOperation,
MPerBlock, // MPerBlock,
NPerBlock, // NPerBlock,
KPerBlock, // KPerBlock,
ThreadwiseGemm_Dispatch, // ThreadwiseGemm_Dispatch,
AThreadwiseCopy, // AThreadwiseCopy
BThreadwiseCopy, // BThreadwiseCopy
CThreadwiseCopy, // CThreadwiseCopy
BlockMNKAccessOrder, // BlockMNKAccessOrder,
ck::Sequence<0, 1>, // ThreadMNAccessOrder
UseALocalBuffer, // UseALocalBuffer
UseBLocalBuffer, // UseBLocalBuffer
UseCLocalBuffer // UseCLocalBuffer
>;
// Argument
struct Argument : public BaseArgument
{
Argument(const InDataType* p_in_grid,
const WeiDataType* p_wei_grid,
OutDataType* p_out_grid,
const BiasDataType* p_bias_grid,
const AddDataType* p_add_grid,
ck::index_t N,
ck::index_t K,
ck::index_t C,
std::vector<ck::index_t> input_spatial_lengths,
std::vector<ck::index_t> filter_spatial_lengths,
std::vector<ck::index_t> output_spatial_lengths,
std::vector<ck::index_t> conv_filter_strides,
std::vector<ck::index_t> conv_filter_dilations,
std::vector<ck::index_t> input_left_pads,
std::vector<ck::index_t> input_right_pads,
InElementwiseOperation in_element_op,
WeiElementwiseOperation wei_element_op,
OutElementwiseOperation out_element_op)
: p_a_grid_{p_in_grid},
p_b_grid_{p_wei_grid},
p_c_grid_{p_out_grid},
p_c0_grid_{p_bias_grid},
p_c1_grid_{p_add_grid},
a_grid_desc_{},
b_grid_desc_{},
c_grid_desc_{},
c0_grid_desc_{},
c1_grid_desc_{},
a_element_op_{in_element_op},
b_element_op_{wei_element_op},
c_element_op_{out_element_op},
Conv_N_{N},
Conv_K_{K},
Conv_C_{C},
filter_spatial_lengths_{filter_spatial_lengths},
conv_filter_strides_{conv_filter_strides},
input_left_pads_{input_left_pads},
input_right_pads_{input_right_pads}
{
const auto descs = DeviceOp::MakeABCGridDescriptor(N,
K,
C,
input_spatial_lengths,
filter_spatial_lengths,
output_spatial_lengths,
conv_filter_strides,
conv_filter_dilations,
input_left_pads,
input_right_pads);
a_grid_desc_ = descs[I0];
b_grid_desc_ = descs[I1];
c_grid_desc_ = descs[I2];
c0_grid_desc_ = DeviceOp::MakeBiasTensorDescriptor(GetGemmM(N, output_spatial_lengths),
GetGemmN(K));
c1_grid_desc_ = descs[I2];
}
// private:
const ADataType* p_a_grid_;
const BDataType* p_b_grid_;
CDataType* p_c_grid_;
const C0DataType* p_c0_grid_;
const C1DataType* p_c1_grid_;
AGridDesc a_grid_desc_;
BGridDesc b_grid_desc_;
CGridDesc c_grid_desc_;
C0GridDesc c0_grid_desc_;
C1GridDesc c1_grid_desc_;
AElementwiseOperation a_element_op_;
BElementwiseOperation b_element_op_;
CElementwiseOperation c_element_op_;
// for checking IsSupportedArgument()
index_t Conv_N_;
index_t Conv_K_;
index_t Conv_C_;
std::vector<index_t> filter_spatial_lengths_;
std::vector<index_t> conv_filter_strides_;
std::vector<index_t> input_left_pads_;
std::vector<index_t> input_right_pads_;
};
// Invoker
struct Invoker : public BaseInvoker
{
using Argument = DeviceOp::Argument;
float Run(const Argument& arg,
const StreamConfig& stream_config = StreamConfig{},
int nrepeat = 1)
{
if(!GridwiseGemm::CheckValidity(arg.a_grid_desc_, arg.b_grid_desc_, arg.c_grid_desc_))
{
throw std::runtime_error("wrong! GridwiseGemmAvx2_MxN has invalid setting");
}
memset(arg.p_c_grid_, 0, arg.c_grid_desc_.GetElementSpaceSize());
const auto kernel =
ck::cpu::kernel_gemm_bias_activation_add_avx_mxn<GridwiseGemm,
ADataType,
BDataType,
CDataType,
C0DataType,
C1DataType,
AGridDesc,
BGridDesc,
CGridDesc,
C0GridDesc,
C1GridDesc,
AElementwiseOperation,
BElementwiseOperation,
CElementwiseOperation>;
float ave_time = 0;
if(nrepeat != 1)
ave_time = launch_and_time_cpu_kernel(kernel,
nrepeat,
arg.p_a_grid_,
arg.p_b_grid_,
arg.p_c_grid_,
arg.p_c0_grid_,
arg.p_c1_grid_,
arg.a_grid_desc_,
arg.b_grid_desc_,
arg.c_grid_desc_,
arg.c0_grid_desc_,
arg.c1_grid_desc_,
arg.a_element_op_,
arg.b_element_op_,
arg.c_element_op_);
// TODO: this is for benchmark purpose, so last time we clear c buffer and calculate the
// result
memset(arg.p_c_grid_, 0, arg.c_grid_desc_.GetElementSpaceSize());
launch_cpu_kernel(kernel,
arg.p_a_grid_,
arg.p_b_grid_,
arg.p_c_grid_,
arg.p_c0_grid_,
arg.p_c1_grid_,
arg.a_grid_desc_,
arg.b_grid_desc_,
arg.c_grid_desc_,
arg.c0_grid_desc_,
arg.c1_grid_desc_,
arg.a_element_op_,
arg.b_element_op_,
arg.c_element_op_);
return ave_time;
}
float Run(const BaseArgument* p_arg,
const StreamConfig& stream_config = StreamConfig{},
int nrepeat = 1) override
{
return Run(*dynamic_cast<const Argument*>(p_arg), stream_config, nrepeat);
}
};
static constexpr bool IsValidCompilationParameter()
{
// TODO: properly implement this check
return true;
}
static bool IsSupportedArgument(const Argument& arg)
{
if constexpr(ConvForwardSpecialization ==
ConvolutionForwardSpecialization_t::Filter1x1Stride1Pad0)
{
// check if it's 1x1, stride=1 conv
if(!(arg.filter_spatial_lengths_[0] == 1 && arg.filter_spatial_lengths_[1] == 1 &&
arg.conv_filter_strides_[0] == 1 && arg.conv_filter_strides_[1] == 1 &&
arg.input_left_pads_[0] == 0 && arg.input_left_pads_[1] == 0 &&
arg.input_right_pads_[0] == 0 && arg.input_right_pads_[1] == 0))
{
return false;
}
}
else if constexpr(ConvForwardSpecialization ==
ConvolutionForwardSpecialization_t::Filter1x1Pad0)
{
// check if it's 1x1 conv
if(!(arg.filter_spatial_lengths_[0] == 1 && arg.filter_spatial_lengths_[1] == 1 &&
arg.input_left_pads_[0] == 0 && arg.input_left_pads_[1] == 0 &&
arg.input_right_pads_[0] == 0 && arg.input_right_pads_[1] == 0))
{
return false;
}
}
if constexpr(GemmKSpecialization ==
ConvolutionForwardGemmKSpecialization_t::NHWC_GemmKLoopOverC &&
ConvForwardSpecialization !=
ConvolutionForwardSpecialization_t::Filter1x1Stride1Pad0)
{
if(!(arg.Conv_C_ % KPerBlock == 0))
return false;
}
if constexpr(!UseALocalBuffer &&
ConvForwardSpecialization !=
ConvolutionForwardSpecialization_t::Filter1x1Stride1Pad0)
{
// TODO: We can support this in the future, as long as figure out how to express tensor
// transform
return false;
}
if constexpr(!UseBLocalBuffer)
{
if(!(arg.Conv_K_ % 8 == 0))
return false;
}
// Gridwise GEMM size
return GridwiseGemm::CheckValidity(arg.a_grid_desc_, arg.b_grid_desc_, arg.c_grid_desc_);
}
bool IsSupportedArgument(const BaseArgument* p_arg) override
{
return IsSupportedArgument(*dynamic_cast<const Argument*>(p_arg));
}
static auto MakeArgument(const InDataType* p_in_grid,
const WeiDataType* p_wei_grid,
OutDataType* p_out_grid,
const BiasDataType* p_bias_grid,
const AddDataType* p_add_grid,
ck::index_t N,
ck::index_t K,
ck::index_t C,
std::vector<ck::index_t> input_spatial_lengths,
std::vector<ck::index_t> filter_spatial_lengths,
std::vector<ck::index_t> output_spatial_lengths,
std::vector<ck::index_t> conv_filter_strides,
std::vector<ck::index_t> conv_filter_dilations,
std::vector<ck::index_t> input_left_pads,
std::vector<ck::index_t> input_right_pads,
InElementwiseOperation in_element_op,
WeiElementwiseOperation wei_element_op,
OutElementwiseOperation out_element_op)
{
return Argument{p_in_grid,
p_wei_grid,
p_out_grid,
p_bias_grid,
p_add_grid,
N,
K,
C,
input_spatial_lengths,
filter_spatial_lengths,
output_spatial_lengths,
conv_filter_strides,
conv_filter_dilations,
input_left_pads,
input_right_pads,
in_element_op,
wei_element_op,
out_element_op};
}
static auto MakeInvoker() { return Invoker{}; }
std::unique_ptr<BaseArgument>
MakeArgumentPointer(const void* p_in_grid,
const void* p_wei_grid,
void* p_out_grid,
const void* p_bias_grid,
const void* p_add_grid,
ck::index_t N,
ck::index_t K,
ck::index_t C,
std::vector<ck::index_t> input_spatial_lengths,
std::vector<ck::index_t> filter_spatial_lengths,
std::vector<ck::index_t> output_spatial_lengths,
std::vector<ck::index_t> conv_filter_strides,
std::vector<ck::index_t> conv_filter_dilations,
std::vector<ck::index_t> input_left_pads,
std::vector<ck::index_t> input_right_pads,
InElementwiseOperation in_element_op,
WeiElementwiseOperation wei_element_op,
OutElementwiseOperation out_element_op) override
{
return std::make_unique<Argument>(static_cast<const InDataType*>(p_in_grid),
static_cast<const WeiDataType*>(p_wei_grid),
static_cast<OutDataType*>(p_out_grid),
static_cast<const BiasDataType*>(p_bias_grid),
static_cast<const AddDataType*>(p_add_grid),
N,
K,
C,
input_spatial_lengths,
filter_spatial_lengths,
output_spatial_lengths,
conv_filter_strides,
conv_filter_dilations,
input_left_pads,
input_right_pads,
in_element_op,
wei_element_op,
out_element_op);
}
std::unique_ptr<BaseInvoker> MakeInvokerPointer() override
{
return std::make_unique<Invoker>(Invoker{});
}
std::string GetTypeString() const override
{
auto str = std::stringstream();
auto string_local_buffer = [](bool is_local_buffer) {
if(is_local_buffer)
return "L";
else
return "G";
};
// clang-format off
str << "DeviceConv" << std::to_string(NumDimSpatial)
<< "DFwd_BAA_Avx2_NHWC_YXCK"
<<"_FS"<< static_cast<int>(ConvForwardSpecialization)
<<"_KS"<< static_cast<int>(GemmKSpecialization)
<<"_BS"<< static_cast<int>(BlockLoopOverSpecialization)
<< "_BT" << MPerBlock << "x" << NPerBlock << "x" << KPerBlock
<< "_TT" << MPerThread << "x" << NPerThread
<< "_A" << string_local_buffer(UseALocalBuffer)
<< "_B" << string_local_buffer(UseBLocalBuffer)
<< "_C" << string_local_buffer(UseCLocalBuffer)
;
if constexpr (!std::is_same<OutElementwiseOperation,
ck::tensor_operation::cpu::element_wise::PassThrough>::value)
{
str << "_" << OutElementwiseOperation::Name();
}
// clang-format on
return str.str();
}
};
} // namespace device
} // namespace cpu
} // namespace tensor_operation
} // namespace ck
#endif
......@@ -81,12 +81,8 @@ struct ThreadwiseGemmAvx2_MxN_6x16
"movq (%[m_param]), %%rax\n" // p_a
"movq 8(%[m_param]), %%rbx\n" // p_b
"movq 24(%[m_param]), %%rsi\n" // Kr
".if m_TransA != 0\n"
"movq 32(%[m_param]), %%rcx\n" // lda
".endif\n"
".if m_TransB == 0\n"
"movq 40(%[m_param]), %%rdx\n" // ldb
".endif\n"
".macro vbroadcastss_%= r_base, r_stride, i_scale, i_offset, ymm\n"
".if \\i_scale != 0\n"
......@@ -120,10 +116,14 @@ struct ThreadwiseGemmAvx2_MxN_6x16
".endif\n"
".endm\n"
".macro vbroadcast_a%= i_k, i_m, ymm\n" // A in rax(r8, r9), lda in rcx
".macro vbroadcast_a%= i_k, i_m, ymm\n" // A in rax(r8), lda in rcx
".if m_ABytes == 4\n"
".if m_TransA == 0\n"
"vbroadcastss_%= %%rax, 0, 0, ((\\i_m + \\i_k * m_Mr) * m_ABytes), \\ymm\n"
".if (\\i_k == 0) || (\\i_k == 1) || (\\i_k == 2)\n"
"vbroadcastss_%= %%rax, %%rcx, \\i_k, (\\i_m * m_ABytes), \\ymm\n"
".else\n"
"vbroadcastss_%= %%r8, %%rcx, (\\i_k-3), (\\i_m * m_ABytes), \\ymm\n"
".endif\n"
".else\n"
".if (\\i_m == 0) || (\\i_m == 1) || (\\i_m == 2)\n"
"vbroadcastss_%= %%rax, %%rcx, \\i_m, (\\i_k * m_ABytes), \\ymm\n"
......@@ -133,7 +133,11 @@ struct ThreadwiseGemmAvx2_MxN_6x16
".endif\n"
".else\n"
".if m_TransA == 0\n"
"vpbroadcastw_%= %%rax, 0, 0, ((\\i_m + \\i_k * m_Mr) * m_ABytes), %%xmm15\n"
".if (\\i_k == 0) || (\\i_k == 1) || (\\i_k == 2)\n"
"vpbroadcastw_%= %%rax, %%rcx, \\i_k, (\\i_m * m_ABytes), %%xmm15\n"
".else\n"
"vpbroadcastw_%= %%rax, %%rcx, (\\i_k-3), (\\i_m * m_ABytes), %%xmm15\n"
".endif\n"
".else\n"
".if (\\i_m == 0) || (\\i_m == 1) || (\\i_m == 2)\n"
"vpbroadcastw_%= %%rax, %%rcx, \\i_m, (\\i_k * m_ABytes), %%xmm15\n"
......@@ -145,18 +149,26 @@ struct ThreadwiseGemmAvx2_MxN_6x16
".endif\n"
".endm\n"
".macro vload_b%= i_k, i_n, ymm\n" // B in rbx, lda in rdx, i_n should be 0, 1
".macro vload_b%= i_k, i_n, ymm\n" // B in rbx(r9), lda in rdx, i_n should be 0, 1
".if m_BBytes == 4\n"
".if m_TransB == 0\n"
"vmovups_%= %%rbx, %%rdx, \\i_n, (\\i_k*m_BBytes*8), \\ymm\n"
".else\n"
"vmovups_%= %%rbx, 0, 0, ((\\i_k*m_Nr + \\i_n*8)*m_BBytes), \\ymm\n"
".if (\\i_k == 0) || (\\i_k == 1) || (\\i_k == 2)\n"
"vmovups_%= %%rbx, %%rdx, \\i_k, (\\i_n*m_BBytes*8), \\ymm\n"
".else\n"
"vmovups_%= %%r9, %%rdx, (\\i_k-3), (\\i_n*m_BBytes*8), \\ymm\n"
".endif\n"
".endif\n"
".else\n"
".if m_TransB == 0\n"
"vcvtph2ps_%= %%rbx, %%rdx, \\i_n, (\\i_k*m_BBytes*8), \\ymm\n"
".else\n"
"vcvtph2ps_%= %%rbx, 0, 0, ((\\i_k*m_Nr + \\i_n*8)*m_BBytes), \\ymm\n"
".if (\\i_k == 0) || (\\i_k == 1) || (\\i_k == 2)\n"
"vcvtph2ps_%= %%rbx, %%rdx, \\i_k, (\\i_n*m_BBytes*8), \\ymm\n"
".else\n"
"vcvtph2ps_%= %%r9, %%rdx, (\\i_k-3), (\\i_n*m_BBytes*8), \\ymm\n"
".endif\n"
".endif\n"
".endif\n"
".endm\n"
......@@ -179,6 +191,13 @@ struct ThreadwiseGemmAvx2_MxN_6x16
"lea (%%rcx, %%rcx, 2), %%r9\n"
"lea (%%rax, %%r9), %%r8\n"
".endif\n"
".else\n"
"lea (%%rcx, %%rcx, 2), %%r9\n"
"lea (%%rax, %%r9), %%r8\n"
".endif\n"
".if m_TransB != 0\n"
"lea (%%rdx, %%rdx, 2), %%rdi\n"
"lea (%%rbx, %%rdi), %%r9\n"
".endif\n"
"cmp $4, %%rsi\n"
......@@ -214,10 +233,12 @@ struct ThreadwiseGemmAvx2_MxN_6x16
" lea 4*m_ABytes(%%rax), %%rax\n"
".if m_Mr > 3\n lea 4*m_ABytes(%%r8), %%r8\n .endif\n"
".else\n"
" lea m_Mr * 4 * m_ABytes(%%rax), %%rax\n"
" lea (%%rax, %%rcx, 4), %%rax\n"
" lea (%%r8, %%rcx, 4), %%r8\n"
".endif\n"
".if m_TransB != 0\n"
" lea m_Nr * 4 * m_BBytes(%%rbx), %%rbx\n"
" lea (%%rbx, %%rdx, 4), %%rbx\n"
" lea (%%r9, %%rdx, 4), %%r9\n"
".else\n"
" lea 8 * 4 * m_BBytes(%%rbx), %%rbx\n"
".endif\n"
......@@ -256,10 +277,12 @@ struct ThreadwiseGemmAvx2_MxN_6x16
" lea m_ABytes(%%rax), %%rax\n"
".if m_Mr > 3\n lea m_ABytes(%%r8), %%r8\n .endif\n"
".else\n"
" lea m_Mr * m_ABytes(%%rax), %%rax\n"
" lea (%%rax, %%rcx, 1), %%rax\n"
" lea (%%r8, %%rcx, 1), %%r8\n"
".endif\n"
".if m_TransB != 0\n"
" lea m_Nr * m_BBytes(%%rbx), %%rbx\n"
" lea (%%rbx, %%rdx, 1), %%rbx\n"
" lea (%%r9, %%rdx, 1), %%r9\n"
".else\n"
" lea 8*m_BBytes(%%rbx), %%rbx\n"
".endif\n"
......@@ -381,7 +404,7 @@ struct ThreadwiseGemmAvx2_MxN_6x16
}
else
{
ymm = _mm256_broadcast_ss(p_a + i_k * Mr + i_m);
ymm = _mm256_broadcast_ss(p_a + i_k * lda + i_m);
}
}
else
......@@ -396,7 +419,7 @@ struct ThreadwiseGemmAvx2_MxN_6x16
}
else
{
ymm = _mm256_cvtph_ps(_mm_set1_epi16(*(p_a + i_k * Mr + i_m)));
ymm = _mm256_cvtph_ps(_mm_set1_epi16(*(p_a + i_k * lda + i_m)));
}
}
};
......@@ -406,7 +429,7 @@ struct ThreadwiseGemmAvx2_MxN_6x16
{
if constexpr(std::is_same<ck::tensor_layout::gemm::RowMajor, BLayout>::value)
{
ymm = _mm256_loadu_ps(p_b + i_k * Nr + i_n * 8);
ymm = _mm256_loadu_ps(p_b + i_k * ldb + i_n * 8);
}
else
{
......@@ -418,7 +441,7 @@ struct ThreadwiseGemmAvx2_MxN_6x16
if constexpr(std::is_same<ck::tensor_layout::gemm::RowMajor, BLayout>::value)
{
ymm = _mm256_cvtph_ps(_mm_loadu_si128(
reinterpret_cast<__m128i const*>(p_b + i_k * Nr + i_n * 8)));
reinterpret_cast<__m128i const*>(p_b + i_k * ldb + i_n * 8)));
}
else
{
......@@ -488,10 +511,10 @@ struct ThreadwiseGemmAvx2_MxN_6x16
if constexpr(std::is_same<ck::tensor_layout::gemm::RowMajor, ALayout>::value){
p_a += 4;
} else{
p_a += Mr * 4;
p_a += lda * 4;
}
if constexpr(std::is_same<ck::tensor_layout::gemm::RowMajor, BLayout>::value){
p_b += Nr * 4;
p_b += ldb * 4;
}else{
p_b += 4 * 8;
}
......@@ -525,10 +548,10 @@ struct ThreadwiseGemmAvx2_MxN_6x16
if constexpr(std::is_same<ck::tensor_layout::gemm::RowMajor, ALayout>::value){
p_a += 1;
} else{
p_a += Mr * 1;
p_a += lda * 1;
}
if constexpr(std::is_same<ck::tensor_layout::gemm::RowMajor, BLayout>::value){
p_b += Nr * 1;
p_b += ldb * 1;
}else{
p_b += 1 * 8;
}
......@@ -641,12 +664,8 @@ struct ThreadwiseGemmAvx2_MxN_4x24
"movq (%[m_param]), %%rax\n" // p_a
"movq 8(%[m_param]), %%rbx\n" // p_b
"movq 24(%[m_param]), %%rsi\n" // Kr
".if m_TransA != 0\n"
"movq 32(%[m_param]), %%rcx\n" // lda
".endif\n"
".if m_TransB == 0\n"
"movq 40(%[m_param]), %%rdx\n" // ldb
".endif\n"
".macro vbroadcastss_%= r_base, r_stride, i_scale, i_offset, ymm\n"
".if \\i_scale != 0\n"
......@@ -683,7 +702,11 @@ struct ThreadwiseGemmAvx2_MxN_4x24
".macro vbroadcast_a%= i_k, i_m, ymm\n" // A in rax(r8), lda in rcx
".if m_ABytes == 4\n"
".if m_TransA == 0\n"
"vbroadcastss_%= %%rax, 0, 0, ((\\i_m + \\i_k * m_Mr) * m_ABytes), \\ymm\n"
".if (\\i_k == 0) || (\\i_k == 1)\n"
"vbroadcastss_%= %%rax, %%rcx, \\i_k, (\\i_m * m_ABytes), \\ymm\n"
".else\n"
"vbroadcastss_%= %%r8, %%rcx, (\\i_k-2), (\\i_m * m_ABytes), \\ymm\n"
".endif\n"
".else\n"
".if (\\i_m == 0) || (\\i_m == 1)\n"
"vbroadcastss_%= %%rax, %%rcx, \\i_m, (\\i_k * m_ABytes), \\ymm\n"
......@@ -693,7 +716,11 @@ struct ThreadwiseGemmAvx2_MxN_4x24
".endif\n"
".else\n"
".if m_TransA == 0\n"
"vpbroadcastw_%= %%rax, 0, 0, ((\\i_m + \\i_k * m_Mr) * m_ABytes), %%xmm15\n"
".if (\\i_k == 0) || (\\i_k == 1)\n"
"vpbroadcastw_%= %%rax, %%rcx, \\i_k, (\\i_m * m_ABytes), %%xmm15\n"
".else\n"
"vpbroadcastw_%= %%r8, %%rcx, (\\i_k-2), (\\i_m * m_ABytes), %%xmm15\n"
".endif\n"
".else\n"
".if (\\i_m == 0) || (\\i_m == 1)\n"
"vpbroadcastw_%= %%rax, %%rcx, \\i_m, (\\i_k * m_ABytes), %%xmm15\n"
......@@ -710,13 +737,21 @@ struct ThreadwiseGemmAvx2_MxN_4x24
".if m_TransB == 0\n"
"vmovups_%= %%rbx, %%rdx, \\i_n, (\\i_k*m_BBytes*8), \\ymm\n"
".else\n"
"vmovups_%= %%rbx, 0, 0, ((\\i_k*m_Nr + \\i_n*8)*m_BBytes), \\ymm\n"
".if (\\i_k == 0) || (\\i_k == 1)\n"
"vmovups_%= %%rbx, %%rdx, \\i_k, (\\i_n*8*m_BBytes), \\ymm\n"
".else\n"
"vmovups_%= %%rdi, %%rdx, (\\i_k-2), (\\i_n*8*m_BBytes), \\ymm\n"
".endif\n"
".endif\n"
".else\n"
".if m_TransB == 0\n"
"vcvtph2ps_%= %%rbx, %%rdx, \\i_n, (\\i_k*m_BBytes*8), \\ymm\n"
".else\n"
"vcvtph2ps_%= %%rbx, 0, 0, ((\\i_k*m_Nr + \\i_n*8)*m_BBytes), \\ymm\n"
".if (\\i_k == 0) || (\\i_k == 1)\n"
"vcvtph2ps_%= %%rbx, %%rdx, \\i_k, (\\i_n*8*m_BBytes), \\ymm\n"
".else\n"
"vcvtph2ps_%= %%rdi, %%rdx, (\\i_k-2), (\\i_n*8*m_BBytes), \\ymm\n"
".endif\n"
".endif\n"
".endif\n"
".endm\n"
......@@ -738,6 +773,11 @@ struct ThreadwiseGemmAvx2_MxN_4x24
".if m_Mr > 2\n"
"lea (%%rax, %%rcx, 2), %%r8\n"
".endif\n"
".else\n"
"lea (%%rax, %%rcx, 2), %%r8\n"
".endif\n"
".if m_TransB != 0\n"
"lea (%%rbx, %%rdx, 2), %%rdi\n"
".endif\n"
"cmp $4, %%rsi\n"
......@@ -773,10 +813,12 @@ struct ThreadwiseGemmAvx2_MxN_4x24
" lea 4*m_ABytes(%%rax), %%rax\n"
".if m_Mr > 2\n lea 4*m_ABytes(%%r8), %%r8\n .endif\n"
".else\n"
" lea m_Mr * 4 * m_ABytes(%%rax), %%rax\n"
" lea (%%rax, %%rcx, 4), %%rax\n"
" lea (%%r8, %%rcx, 4), %%r8\n"
".endif\n"
".if m_TransB != 0\n"
" lea m_Nr * 4 * m_BBytes(%%rbx), %%rbx\n"
" lea (%%rbx, %%rdx, 4), %%rbx\n"
" lea (%%rdi, %%rdx, 4), %%rdi\n"
".else\n"
" lea 8 * 4 * m_BBytes(%%rbx), %%rbx\n"
".endif\n"
......@@ -815,10 +857,12 @@ struct ThreadwiseGemmAvx2_MxN_4x24
" lea m_ABytes(%%rax), %%rax\n"
".if m_Mr > 3\n lea m_ABytes(%%r8), %%r8\n .endif\n"
".else\n"
" lea m_Mr * m_ABytes(%%rax), %%rax\n"
" lea (%%rax, %%rcx, 1), %%rax\n"
" lea (%%r8, %%rcx, 1), %%r8\n"
".endif\n"
".if m_TransB != 0\n"
" lea m_Nr * m_BBytes(%%rbx), %%rbx\n"
" lea (%%rbx, %%rdx, 1), %%rbx\n"
" lea (%%rdi, %%rdx, 1), %%rdi\n"
".else\n"
" lea 8*m_BBytes(%%rbx), %%rbx\n"
".endif\n"
......@@ -937,7 +981,7 @@ struct ThreadwiseGemmAvx2_MxN_4x24
}
else
{
ymm = _mm256_broadcast_ss(p_a + i_k * Mr + i_m);
ymm = _mm256_broadcast_ss(p_a + i_k * lda + i_m);
}
}
else
......@@ -952,7 +996,7 @@ struct ThreadwiseGemmAvx2_MxN_4x24
}
else
{
ymm = _mm256_cvtph_ps(_mm_set1_epi16(*(p_a + i_k * Mr + i_m)));
ymm = _mm256_cvtph_ps(_mm_set1_epi16(*(p_a + i_k * lda + i_m)));
}
}
};
......@@ -962,7 +1006,7 @@ struct ThreadwiseGemmAvx2_MxN_4x24
{
if constexpr(std::is_same<ck::tensor_layout::gemm::RowMajor, BLayout>::value)
{
ymm = _mm256_loadu_ps(p_b + i_k * Nr + i_n * 8);
ymm = _mm256_loadu_ps(p_b + i_k * ldb + i_n * 8);
}
else
{
......@@ -974,7 +1018,7 @@ struct ThreadwiseGemmAvx2_MxN_4x24
if constexpr(std::is_same<ck::tensor_layout::gemm::RowMajor, BLayout>::value)
{
ymm = _mm256_cvtph_ps(_mm_loadu_si128(
reinterpret_cast<__m128i const*>(p_b + i_k * Nr + i_n * 8)));
reinterpret_cast<__m128i const*>(p_b + i_k * ldb + i_n * 8)));
}
else
{
......@@ -1044,10 +1088,10 @@ struct ThreadwiseGemmAvx2_MxN_4x24
if constexpr(std::is_same<ck::tensor_layout::gemm::RowMajor, ALayout>::value){
p_a += 4;
} else{
p_a += Mr * 4;
p_a += lda * 4;
}
if constexpr(std::is_same<ck::tensor_layout::gemm::RowMajor, BLayout>::value){
p_b += Nr * 4;
p_b += ldb * 4;
}else{
p_b += 4 * 8;
}
......@@ -1081,10 +1125,10 @@ struct ThreadwiseGemmAvx2_MxN_4x24
if constexpr(std::is_same<ck::tensor_layout::gemm::RowMajor, ALayout>::value){
p_a += 1;
} else{
p_a += Mr * 1;
p_a += lda * 1;
}
if constexpr(std::is_same<ck::tensor_layout::gemm::RowMajor, BLayout>::value){
p_b += Nr * 1;
p_b += ldb * 1;
}else{
p_b += 1 * 8;
}
......
......@@ -1277,6 +1277,138 @@ struct ThreadwiseTensorSliceTransferAvx2Specialization_ConvFwd_Wei_KYXCK8
intptr_t src_offset;
};
template <typename SrcData,
typename DstData,
typename SrcDesc,
typename DstDesc,
typename ElementwiseOperation,
bool BypassTransfer,
ConvolutionForwardSpecialization_t ConvForwardSpecialization,
ConvolutionForwardGemmKSpecialization_t GemmKSpecialization>
struct ThreadwiseTensorSliceTransferAvx2Specialization_ConvFwd_Wei_YXCK
{
static constexpr ck::index_t nDim = SrcDesc::GetNumOfDimension();
using Index = MultiIndex<nDim>;
constexpr ThreadwiseTensorSliceTransferAvx2Specialization_ConvFwd_Wei_YXCK(
const SrcDesc& src_desc,
const Index&,
const DstDesc&,
const Index&,
const ElementwiseOperation& element_op)
: element_op_(element_op)
{
GemmK = src_desc.GetTransforms()[Number<0>{}].GetUpperLengths()[Number<0>{}];
GemmN = src_desc.GetTransforms()[Number<0>{}].GetUpperLengths()[Number<1>{}];
}
void SetSrcSliceOrigin(const SrcDesc&, const Index& src_slice_origin_idx)
{
ck::index_t idx_k = src_slice_origin_idx[Number<0>{}];
ck::index_t idx_n = src_slice_origin_idx[Number<1>{}];
src_offset = idx_k * GemmN + idx_n;
}
void SetDstSliceOrigin(const DstDesc&, const Index&) {}
template <typename SrcBuffer, typename DstBuffer, typename SliceLengths>
void RunRead(const SrcDesc&,
SrcBuffer& src_buf,
const DstDesc& dst_desc,
DstBuffer& dst_buf,
const SliceLengths& slice_length)
{
if constexpr(BypassTransfer)
{
dst_buf.p_data_ = reinterpret_cast<float*>(src_buf.p_data_) + src_offset;
}
else
{
const ck::index_t k_per_block = slice_length[Number<0>{}];
const ck::index_t n_per_block = slice_length[Number<1>{}];
const float* p_src = reinterpret_cast<const float*>(src_buf.p_data_) + src_offset;
float* p_dst = reinterpret_cast<float*>(dst_buf.p_data_);
// k * n
index_t i_k_itr = k_per_block;
while(i_k_itr >= 8)
{
avx2_util::memcpy32_avx2(
p_dst + 0 * n_per_block, p_src + 0 * GemmN, n_per_block, element_op_);
avx2_util::memcpy32_avx2(
p_dst + 1 * n_per_block, p_src + 1 * GemmN, n_per_block, element_op_);
avx2_util::memcpy32_avx2(
p_dst + 2 * n_per_block, p_src + 2 * GemmN, n_per_block, element_op_);
avx2_util::memcpy32_avx2(
p_dst + 3 * n_per_block, p_src + 3 * GemmN, n_per_block, element_op_);
avx2_util::memcpy32_avx2(
p_dst + 4 * n_per_block, p_src + 4 * GemmN, n_per_block, element_op_);
avx2_util::memcpy32_avx2(
p_dst + 5 * n_per_block, p_src + 5 * GemmN, n_per_block, element_op_);
avx2_util::memcpy32_avx2(
p_dst + 6 * n_per_block, p_src + 6 * GemmN, n_per_block, element_op_);
avx2_util::memcpy32_avx2(
p_dst + 7 * n_per_block, p_src + 7 * GemmN, n_per_block, element_op_);
i_k_itr -= 8;
p_dst += 8 * n_per_block;
p_src += 8 * GemmN;
}
if(i_k_itr & 4)
{
avx2_util::memcpy32_avx2(
p_dst + 0 * n_per_block, p_src + 0 * GemmN, n_per_block, element_op_);
avx2_util::memcpy32_avx2(
p_dst + 1 * n_per_block, p_src + 1 * GemmN, n_per_block, element_op_);
avx2_util::memcpy32_avx2(
p_dst + 2 * n_per_block, p_src + 2 * GemmN, n_per_block, element_op_);
avx2_util::memcpy32_avx2(
p_dst + 3 * n_per_block, p_src + 3 * GemmN, n_per_block, element_op_);
p_dst += 4 * n_per_block;
p_src += 4 * GemmN;
}
if(i_k_itr & 2)
{
avx2_util::memcpy32_avx2(
p_dst + 0 * n_per_block, p_src + 0 * GemmN, n_per_block, element_op_);
avx2_util::memcpy32_avx2(
p_dst + 1 * n_per_block, p_src + 1 * GemmN, n_per_block, element_op_);
p_dst += 2 * n_per_block;
p_src += 2 * GemmN;
}
if(i_k_itr & 1)
{
avx2_util::memcpy32_avx2(
p_dst + 0 * n_per_block, p_src + 0 * GemmN, n_per_block, element_op_);
}
}
}
// src_slice_origin_step_idx need to be known at compile-time, for performance reason
void MoveSrcSliceWindow(const SrcDesc& src_desc, const Index& src_slice_origin_step_idx)
{
ck::index_t move_k = src_slice_origin_step_idx[Number<0>{}];
ck::index_t move_n = src_slice_origin_step_idx[Number<1>{}];
src_offset += move_k * GemmN + move_n;
}
// dst_slice_origin_step_idx need to be known at compile-time, for performance reason
void MoveDstSliceWindow(const DstDesc&, const Index&) {}
private:
const ElementwiseOperation element_op_;
ck::index_t GemmN;
ck::index_t GemmK;
intptr_t src_offset;
};
template <typename SrcData,
typename DstData,
typename SrcDesc,
......
......@@ -2,6 +2,7 @@
set(DEVICE_CONV2D_FWD_CPU_INSTANCE_SOURCE
device_conv2d_fwd_avx2_nhwc_kyxc_nhwk_instance.cpp
device_conv2d_fwd_avx2_nhwc_kyxck8_nhwk_instance.cpp
device_conv2d_fwd_avx2_nhwc_yxck_nhwk_instance.cpp
)
add_library(device_conv2d_fwd_cpu_instance SHARED ${DEVICE_CONV2D_FWD_CPU_INSTANCE_SOURCE})
target_compile_features(device_conv2d_fwd_cpu_instance PUBLIC)
......
#include <stdlib.h>
#include "config.hpp"
#include "convolution_forward_specialization_cpu.hpp"
#include "device_convnd_fwd_avx2_nhwc_yxck_nhwk.hpp"
#include "element_wise_operation_cpu.hpp"
#include "device_operation_instance.hpp"
namespace ck {
namespace tensor_operation {
namespace cpu {
namespace device {
namespace device_conv2d_fwd_avx2_instance {
using InType = float;
using WeiType = float;
using OutType = float;
using AccType = float;
static constexpr bool NonTemporalStore = false;
using PT = ck::tensor_operation::cpu::element_wise::PassThrough;
using Relu = ck::tensor_operation::cpu::element_wise::Relu;
static constexpr auto ConvFwdDefault =
ck::tensor_operation::cpu::device::ConvolutionForwardSpecialization_t::Default;
static constexpr auto ConvFwd1x1P0 =
ck::tensor_operation::cpu::device::ConvolutionForwardSpecialization_t::Filter1x1Pad0;
static constexpr auto ConvFwd1x1S1P0 =
ck::tensor_operation::cpu::device::ConvolutionForwardSpecialization_t::Filter1x1Stride1Pad0;
static constexpr auto DefaultGemmKLoop =
ck::tensor_operation::cpu::device::ConvolutionForwardGemmKSpecialization_t::DefaultGemmKLoop;
static constexpr auto GemmKLoopOverC =
ck::tensor_operation::cpu::device::ConvolutionForwardGemmKSpecialization_t::NHWC_GemmKLoopOverC;
static constexpr auto LoopOver_MNK = ck::tensor_operation::cpu::device::LoopOver_MNK;
static constexpr auto LoopOver_MKN = ck::tensor_operation::cpu::device::LoopOver_MKN;
// clang-format off
#define DEVICE_CONV2D_FWD_AVX2_NHWC_YXCK_NHWK_F32(a_elem_op, b_elem_op, c_elem_op, m_per_block, n_per_block, k_per_block, m_per_thread, n_per_thread, c_local_buf) \
DeviceConvNDFwdAvx2_Input_N_Hi_Wi_C_Weight_Y_X_C_K_Output_N_Ho_Wo_K<float , float , float, a_elem_op, b_elem_op, c_elem_op, ConvFwdDefault, GemmKLoopOverC , LoopOver_MNK, 2, m_per_block, n_per_block, k_per_block, m_per_thread, n_per_thread, true, true, c_local_buf>, \
DeviceConvNDFwdAvx2_Input_N_Hi_Wi_C_Weight_Y_X_C_K_Output_N_Ho_Wo_K<float , float , float, a_elem_op, b_elem_op, c_elem_op, ConvFwd1x1S1P0, GemmKLoopOverC , LoopOver_MNK, 2, m_per_block, n_per_block, k_per_block, m_per_thread, n_per_thread, true, true, c_local_buf>, \
DeviceConvNDFwdAvx2_Input_N_Hi_Wi_C_Weight_Y_X_C_K_Output_N_Ho_Wo_K<float , float , float, a_elem_op, b_elem_op, c_elem_op, ConvFwdDefault, DefaultGemmKLoop, LoopOver_MNK, 2, m_per_block, n_per_block, k_per_block, m_per_thread, n_per_thread, true, true, c_local_buf>, \
DeviceConvNDFwdAvx2_Input_N_Hi_Wi_C_Weight_Y_X_C_K_Output_N_Ho_Wo_K<float , float , float, a_elem_op, b_elem_op, c_elem_op, ConvFwd1x1S1P0, GemmKLoopOverC , LoopOver_MNK, 2, m_per_block, n_per_block, k_per_block, m_per_thread, n_per_thread, false, false, c_local_buf>, \
DeviceConvNDFwdAvx2_Input_N_Hi_Wi_C_Weight_Y_X_C_K_Output_N_Ho_Wo_K<float , float , float, a_elem_op, b_elem_op, c_elem_op, ConvFwdDefault, DefaultGemmKLoop, LoopOver_MNK, 2, m_per_block, n_per_block, k_per_block, m_per_thread, n_per_thread, true, false, c_local_buf>, \
\
DeviceConvNDFwdAvx2_Input_N_Hi_Wi_C_Weight_Y_X_C_K_Output_N_Ho_Wo_K<float , float , float, a_elem_op, b_elem_op, c_elem_op, ConvFwdDefault, GemmKLoopOverC , LoopOver_MKN, 2, m_per_block, n_per_block, k_per_block, m_per_thread, n_per_thread, true, true, c_local_buf>, \
DeviceConvNDFwdAvx2_Input_N_Hi_Wi_C_Weight_Y_X_C_K_Output_N_Ho_Wo_K<float , float , float, a_elem_op, b_elem_op, c_elem_op, ConvFwd1x1S1P0, GemmKLoopOverC , LoopOver_MKN, 2, m_per_block, n_per_block, k_per_block, m_per_thread, n_per_thread, true, true, c_local_buf>, \
DeviceConvNDFwdAvx2_Input_N_Hi_Wi_C_Weight_Y_X_C_K_Output_N_Ho_Wo_K<float , float , float, a_elem_op, b_elem_op, c_elem_op, ConvFwdDefault, DefaultGemmKLoop, LoopOver_MKN, 2, m_per_block, n_per_block, k_per_block, m_per_thread, n_per_thread, true, true, c_local_buf>, \
DeviceConvNDFwdAvx2_Input_N_Hi_Wi_C_Weight_Y_X_C_K_Output_N_Ho_Wo_K<float , float , float, a_elem_op, b_elem_op, c_elem_op, ConvFwd1x1S1P0, GemmKLoopOverC , LoopOver_MKN, 2, m_per_block, n_per_block, k_per_block, m_per_thread, n_per_thread, false, false, c_local_buf>, \
DeviceConvNDFwdAvx2_Input_N_Hi_Wi_C_Weight_Y_X_C_K_Output_N_Ho_Wo_K<float , float , float, a_elem_op, b_elem_op, c_elem_op, ConvFwdDefault, DefaultGemmKLoop, LoopOver_MKN, 2, m_per_block, n_per_block, k_per_block, m_per_thread, n_per_thread, true, false, c_local_buf>
// clang-format on
using device_conv2d_fwd_avx2_nhwc_yxck_nhwk_f32_instances = std::tuple<
// clang-format off
DEVICE_CONV2D_FWD_AVX2_NHWC_YXCK_NHWK_F32(PT, PT, PT, 256, 128, 64, 6, 16, false),
DEVICE_CONV2D_FWD_AVX2_NHWC_YXCK_NHWK_F32(PT, PT, PT, 256, 128, 128, 6, 16, false),
DEVICE_CONV2D_FWD_AVX2_NHWC_YXCK_NHWK_F32(PT, PT, PT, 128, 256, 128, 6, 16, false),
DEVICE_CONV2D_FWD_AVX2_NHWC_YXCK_NHWK_F32(PT, PT, PT, 512, 240, 128, 4, 24, false),
DEVICE_CONV2D_FWD_AVX2_NHWC_YXCK_NHWK_F32(PT, PT, PT, 512, 256, 128, 6, 16, false),
DEVICE_CONV2D_FWD_AVX2_NHWC_YXCK_NHWK_F32(PT, PT, PT, 768, 320, 128, 6, 16, false),
DEVICE_CONV2D_FWD_AVX2_NHWC_YXCK_NHWK_F32(PT, PT, PT, 896, 352, 128, 6, 16, false),
DEVICE_CONV2D_FWD_AVX2_NHWC_YXCK_NHWK_F32(PT, PT, PT, 1024, 416, 128, 6, 16, false)>;
// clang-format on
// use this in single thread, but gemm_n is not multiple of 8
using device_conv2d_fwd_avx2_nhwc_yxck_nhwk_f32_local_c_instances = std::tuple<
// clang-format off
DEVICE_CONV2D_FWD_AVX2_NHWC_YXCK_NHWK_F32(PT, PT, PT, 256, 128, 64, 6, 16, true),
DEVICE_CONV2D_FWD_AVX2_NHWC_YXCK_NHWK_F32(PT, PT, PT, 256, 128, 128, 6, 16, true),
DEVICE_CONV2D_FWD_AVX2_NHWC_YXCK_NHWK_F32(PT, PT, PT, 128, 256, 128, 6, 16, true),
DEVICE_CONV2D_FWD_AVX2_NHWC_YXCK_NHWK_F32(PT, PT, PT, 512, 240, 128, 4, 24, true),
DEVICE_CONV2D_FWD_AVX2_NHWC_YXCK_NHWK_F32(PT, PT, PT, 512, 256, 128, 6, 16, true),
DEVICE_CONV2D_FWD_AVX2_NHWC_YXCK_NHWK_F32(PT, PT, PT, 768, 320, 128, 6, 16, true),
DEVICE_CONV2D_FWD_AVX2_NHWC_YXCK_NHWK_F32(PT, PT, PT, 896, 352, 128, 6, 16, true),
DEVICE_CONV2D_FWD_AVX2_NHWC_YXCK_NHWK_F32(PT, PT, PT, 1024, 416, 128, 6, 16, true)>;
// clang-format on
// use this in multi thread environment (need local C buffer to avoid cache coherence, although some
// time no local c is better...)
using device_conv2d_fwd_avx2_nhwc_yxck_nhwk_f32_mt_instances = std::tuple<
// clang-format off
DEVICE_CONV2D_FWD_AVX2_NHWC_YXCK_NHWK_F32(PT, PT, PT, 24, 24, 256, 4, 24, false),
DEVICE_CONV2D_FWD_AVX2_NHWC_YXCK_NHWK_F32(PT, PT, PT, 32, 24, 256, 4, 24, false),
DEVICE_CONV2D_FWD_AVX2_NHWC_YXCK_NHWK_F32(PT, PT, PT, 40, 24, 256, 4, 24, false),
DEVICE_CONV2D_FWD_AVX2_NHWC_YXCK_NHWK_F32(PT, PT, PT, 48, 24, 256, 4, 24, false),
DEVICE_CONV2D_FWD_AVX2_NHWC_YXCK_NHWK_F32(PT, PT, PT, 48, 48, 256, 4, 24, false),
DEVICE_CONV2D_FWD_AVX2_NHWC_YXCK_NHWK_F32(PT, PT, PT, 56, 24, 256, 4, 24, false),
DEVICE_CONV2D_FWD_AVX2_NHWC_YXCK_NHWK_F32(PT, PT, PT, 72, 16, 128, 6, 16, false),
DEVICE_CONV2D_FWD_AVX2_NHWC_YXCK_NHWK_F32(PT, PT, PT, 72, 16, 256, 6, 16, false),
DEVICE_CONV2D_FWD_AVX2_NHWC_YXCK_NHWK_F32(PT, PT, PT, 72, 32, 128, 6, 16, false),
DEVICE_CONV2D_FWD_AVX2_NHWC_YXCK_NHWK_F32(PT, PT, PT, 72, 32, 256, 6, 16, false),
DEVICE_CONV2D_FWD_AVX2_NHWC_YXCK_NHWK_F32(PT, PT, PT, 96, 32, 128, 6, 16, false),
DEVICE_CONV2D_FWD_AVX2_NHWC_YXCK_NHWK_F32(PT, PT, PT, 96, 64, 128, 6, 16, false),
DEVICE_CONV2D_FWD_AVX2_NHWC_YXCK_NHWK_F32(PT, PT, PT, 120, 32, 128, 6, 16, false),
DEVICE_CONV2D_FWD_AVX2_NHWC_YXCK_NHWK_F32(PT, PT, PT, 120, 64, 128, 6, 16, false),
// DEVICE_CONV2D_FWD_AVX2_NHWC_YXCK_NHWK_F32(PT, PT, PT, 256, 128, 64, 6, 16, true),
DEVICE_CONV2D_FWD_AVX2_NHWC_YXCK_NHWK_F32(PT, PT, PT, 256, 128, 128, 6, 16, true),
DEVICE_CONV2D_FWD_AVX2_NHWC_YXCK_NHWK_F32(PT, PT, PT, 128, 256, 128, 6, 16, true),
DEVICE_CONV2D_FWD_AVX2_NHWC_YXCK_NHWK_F32(PT, PT, PT, 512, 240, 128, 4, 24, true),
DEVICE_CONV2D_FWD_AVX2_NHWC_YXCK_NHWK_F32(PT, PT, PT, 512, 256, 128, 6, 16, true),
DEVICE_CONV2D_FWD_AVX2_NHWC_YXCK_NHWK_F32(PT, PT, PT, 768, 320, 128, 6, 16, true),
DEVICE_CONV2D_FWD_AVX2_NHWC_YXCK_NHWK_F32(PT, PT, PT, 896, 352, 128, 6, 16, true),
DEVICE_CONV2D_FWD_AVX2_NHWC_YXCK_NHWK_F32(PT, PT, PT, 1024, 416, 128, 6, 16, true)>;
// clang-format on
using device_conv2d_fwd_avx2_nhwc_yxck_nhwk_f32_relu_instances = std::tuple<
// clang-format off
DEVICE_CONV2D_FWD_AVX2_NHWC_YXCK_NHWK_F32(PT, PT, Relu, 256, 128, 64, 6, 16, false),
DEVICE_CONV2D_FWD_AVX2_NHWC_YXCK_NHWK_F32(PT, PT, Relu, 256, 128, 128, 6, 16, false),
DEVICE_CONV2D_FWD_AVX2_NHWC_YXCK_NHWK_F32(PT, PT, Relu, 128, 256, 128, 6, 16, false),
DEVICE_CONV2D_FWD_AVX2_NHWC_YXCK_NHWK_F32(PT, PT, Relu, 512, 240, 128, 4, 24, false),
DEVICE_CONV2D_FWD_AVX2_NHWC_YXCK_NHWK_F32(PT, PT, Relu, 512, 256, 128, 6, 16, false),
DEVICE_CONV2D_FWD_AVX2_NHWC_YXCK_NHWK_F32(PT, PT, Relu, 768, 320, 128, 6, 16, false),
DEVICE_CONV2D_FWD_AVX2_NHWC_YXCK_NHWK_F32(PT, PT, Relu, 896, 352, 128, 6, 16, false),
DEVICE_CONV2D_FWD_AVX2_NHWC_YXCK_NHWK_F32(PT, PT, Relu, 1024, 416, 128, 6, 16, false)>;
// clang-format on
// use this in single thread, but gemm_n is not multiple of 8
using device_conv2d_fwd_avx2_nhwc_yxck_nhwk_f32_local_c_relu_instances = std::tuple<
// clang-format off
DEVICE_CONV2D_FWD_AVX2_NHWC_YXCK_NHWK_F32(PT, PT, Relu, 256, 128, 64, 6, 16, true),
DEVICE_CONV2D_FWD_AVX2_NHWC_YXCK_NHWK_F32(PT, PT, Relu, 256, 128, 128, 6, 16, true),
DEVICE_CONV2D_FWD_AVX2_NHWC_YXCK_NHWK_F32(PT, PT, Relu, 128, 256, 128, 6, 16, true),
DEVICE_CONV2D_FWD_AVX2_NHWC_YXCK_NHWK_F32(PT, PT, Relu, 512, 240, 128, 4, 24, true),
DEVICE_CONV2D_FWD_AVX2_NHWC_YXCK_NHWK_F32(PT, PT, Relu, 512, 256, 128, 6, 16, true),
DEVICE_CONV2D_FWD_AVX2_NHWC_YXCK_NHWK_F32(PT, PT, Relu, 768, 320, 128, 6, 16, true),
DEVICE_CONV2D_FWD_AVX2_NHWC_YXCK_NHWK_F32(PT, PT, Relu, 896, 352, 128, 6, 16, true),
DEVICE_CONV2D_FWD_AVX2_NHWC_YXCK_NHWK_F32(PT, PT, Relu, 1024, 416, 128, 6, 16, true)>;
// clang-format on
// use this in multi thread environment (need local C buffer to avoid cache coherence, although some
// time no local c is better...)
using device_conv2d_fwd_avx2_nhwc_yxck_nhwk_f32_mt_relu_instances = std::tuple<
// clang-format off
DEVICE_CONV2D_FWD_AVX2_NHWC_YXCK_NHWK_F32(PT, PT, Relu, 24, 24, 256, 4, 24, false),
DEVICE_CONV2D_FWD_AVX2_NHWC_YXCK_NHWK_F32(PT, PT, Relu, 32, 24, 256, 4, 24, false),
DEVICE_CONV2D_FWD_AVX2_NHWC_YXCK_NHWK_F32(PT, PT, Relu, 40, 24, 256, 4, 24, false),
DEVICE_CONV2D_FWD_AVX2_NHWC_YXCK_NHWK_F32(PT, PT, Relu, 24, 24, 256, 4, 24, false),
DEVICE_CONV2D_FWD_AVX2_NHWC_YXCK_NHWK_F32(PT, PT, Relu, 32, 24, 256, 4, 24, false),
DEVICE_CONV2D_FWD_AVX2_NHWC_YXCK_NHWK_F32(PT, PT, Relu, 40, 24, 256, 4, 24, false),
DEVICE_CONV2D_FWD_AVX2_NHWC_YXCK_NHWK_F32(PT, PT, Relu, 48, 24, 256, 4, 24, false),
DEVICE_CONV2D_FWD_AVX2_NHWC_YXCK_NHWK_F32(PT, PT, Relu, 48, 48, 256, 4, 24, false),
DEVICE_CONV2D_FWD_AVX2_NHWC_YXCK_NHWK_F32(PT, PT, Relu, 56, 24, 256, 4, 24, false),
DEVICE_CONV2D_FWD_AVX2_NHWC_YXCK_NHWK_F32(PT, PT, Relu, 72, 16, 128, 6, 16, false),
DEVICE_CONV2D_FWD_AVX2_NHWC_YXCK_NHWK_F32(PT, PT, Relu, 72, 16, 256, 6, 16, false),
DEVICE_CONV2D_FWD_AVX2_NHWC_YXCK_NHWK_F32(PT, PT, Relu, 72, 32, 128, 6, 16, false),
DEVICE_CONV2D_FWD_AVX2_NHWC_YXCK_NHWK_F32(PT, PT, Relu, 72, 32, 256, 6, 16, false),
DEVICE_CONV2D_FWD_AVX2_NHWC_YXCK_NHWK_F32(PT, PT, Relu, 96, 32, 128, 6, 16, false),
DEVICE_CONV2D_FWD_AVX2_NHWC_YXCK_NHWK_F32(PT, PT, Relu, 96, 64, 128, 6, 16, false),
DEVICE_CONV2D_FWD_AVX2_NHWC_YXCK_NHWK_F32(PT, PT, Relu, 120, 32, 128, 6, 16, false),
DEVICE_CONV2D_FWD_AVX2_NHWC_YXCK_NHWK_F32(PT, PT, Relu, 120, 64, 128, 6, 16, false),
// DEVICE_CONV2D_FWD_AVX2_NHWC_YXCK_NHWK_F32(PT, PT, PT, 256, 128, 64, 6, 16, true),
DEVICE_CONV2D_FWD_AVX2_NHWC_YXCK_NHWK_F32(PT, PT, Relu, 256, 128, 128, 6, 16, true),
DEVICE_CONV2D_FWD_AVX2_NHWC_YXCK_NHWK_F32(PT, PT, Relu, 128, 256, 128, 6, 16, true),
DEVICE_CONV2D_FWD_AVX2_NHWC_YXCK_NHWK_F32(PT, PT, Relu, 512, 240, 128, 4, 24, true),
DEVICE_CONV2D_FWD_AVX2_NHWC_YXCK_NHWK_F32(PT, PT, Relu, 512, 256, 128, 6, 16, true),
DEVICE_CONV2D_FWD_AVX2_NHWC_YXCK_NHWK_F32(PT, PT, Relu, 768, 320, 128, 6, 16, true),
DEVICE_CONV2D_FWD_AVX2_NHWC_YXCK_NHWK_F32(PT, PT, Relu, 896, 352, 128, 6, 16, true),
DEVICE_CONV2D_FWD_AVX2_NHWC_YXCK_NHWK_F32(PT, PT, Relu, 1024, 416, 128, 6, 16, true)>;
// clang-format on
void add_device_conv2d_fwd_avx2_nhwc_yxck_nhwk(std::vector<DeviceConvFwdPtr<PT, PT, PT>>& instances)
{
ck::tensor_operation::device::add_device_operation_instances(
instances, device_conv2d_fwd_avx2_nhwc_yxck_nhwk_f32_instances{});
}
void add_device_conv2d_fwd_avx2_nhwc_yxck_nhwk_local_c(
std::vector<DeviceConvFwdPtr<PT, PT, PT>>& instances)
{
ck::tensor_operation::device::add_device_operation_instances(
instances, device_conv2d_fwd_avx2_nhwc_yxck_nhwk_f32_local_c_instances{});
}
void add_device_conv2d_fwd_avx2_nhwc_yxck_nhwk_mt(
std::vector<DeviceConvFwdPtr<PT, PT, PT>>& instances)
{
ck::tensor_operation::device::add_device_operation_instances(
instances, device_conv2d_fwd_avx2_nhwc_yxck_nhwk_f32_mt_instances{});
}
void add_device_conv2d_fwd_avx2_nhwc_yxck_nhwk_relu(
std::vector<DeviceConvFwdPtr<PT, PT, Relu>>& instances)
{
ck::tensor_operation::device::add_device_operation_instances(
instances, device_conv2d_fwd_avx2_nhwc_yxck_nhwk_f32_relu_instances{});
}
void add_device_conv2d_fwd_avx2_nhwc_yxck_nhwk_local_c_relu(
std::vector<DeviceConvFwdPtr<PT, PT, Relu>>& instances)
{
ck::tensor_operation::device::add_device_operation_instances(
instances, device_conv2d_fwd_avx2_nhwc_yxck_nhwk_f32_local_c_relu_instances{});
}
void add_device_conv2d_fwd_avx2_nhwc_yxck_nhwk_mt_relu(
std::vector<DeviceConvFwdPtr<PT, PT, Relu>>& instances)
{
ck::tensor_operation::device::add_device_operation_instances(
instances, device_conv2d_fwd_avx2_nhwc_yxck_nhwk_f32_mt_relu_instances{});
}
} // namespace device_conv2d_fwd_avx2_instance
} // namespace device
} // namespace cpu
} // namespace tensor_operation
} // namespace ck
......@@ -2,6 +2,7 @@
set(DEVICE_CONV2D_FWD_CPU_INSTANCE_SOURCE
device_conv2d_bias_activation_add_avx2_nhwc_kyxc_nhwk_instance.cpp
device_conv2d_bias_activation_add_avx2_nhwc_kyxck8_nhwk_instance.cpp
device_conv2d_bias_activation_add_avx2_nhwc_yxck_nhwk_instance.cpp
)
add_library(device_conv2d_fwd_bias_activation_add_cpu_instance SHARED ${DEVICE_CONV2D_FWD_CPU_INSTANCE_SOURCE})
target_compile_features(device_conv2d_fwd_bias_activation_add_cpu_instance PUBLIC)
......
#include <stdlib.h>
#include "config.hpp"
#include "convolution_forward_specialization_cpu.hpp"
#include "device_convnd_fwd_bias_activation_add_avx2_nhwc_yxck_nhwk.hpp"
#include "element_wise_operation_cpu.hpp"
#include "device_operation_instance.hpp"
namespace ck {
namespace tensor_operation {
namespace cpu {
namespace device {
namespace device_conv2d_fwd_bias_activation_add_avx2_instance {
using InType = float;
using WeiType = float;
using OutType = float;
using AccType = float;
static constexpr bool NonTemporalStore = false;
using PT = ck::tensor_operation::cpu::element_wise::PassThrough;
using AddReluAdd = ck::tensor_operation::cpu::element_wise::AddReluAdd;
static constexpr auto ConvFwdDefault =
ck::tensor_operation::cpu::device::ConvolutionForwardSpecialization_t::Default;
static constexpr auto ConvFwd1x1P0 =
ck::tensor_operation::cpu::device::ConvolutionForwardSpecialization_t::Filter1x1Pad0;
static constexpr auto ConvFwd1x1S1P0 =
ck::tensor_operation::cpu::device::ConvolutionForwardSpecialization_t::Filter1x1Stride1Pad0;
static constexpr auto DefaultGemmKLoop =
ck::tensor_operation::cpu::device::ConvolutionForwardGemmKSpecialization_t::DefaultGemmKLoop;
static constexpr auto GemmKLoopOverC =
ck::tensor_operation::cpu::device::ConvolutionForwardGemmKSpecialization_t::NHWC_GemmKLoopOverC;
static constexpr auto LoopOver_MNK = ck::tensor_operation::cpu::device::LoopOver_MNK;
static constexpr auto LoopOver_MKN = ck::tensor_operation::cpu::device::LoopOver_MKN;
// clang-format off
#define DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_YXCK_NHWK_F32(a_elem_op, b_elem_op, c_elem_op, m_per_block, n_per_block, k_per_block, m_per_thread, n_per_thread, c_local_buf, bias_along_m) \
DeviceConvNDFwdBiasActivationAddAvx2_Input_N_Hi_Wi_C_Weight_Y_X_C_K_Output_N_Ho_Wo_K<float , float , float, float , float, a_elem_op, b_elem_op, c_elem_op, ConvFwdDefault, GemmKLoopOverC , LoopOver_MNK, 2, m_per_block, n_per_block, k_per_block, m_per_thread, n_per_thread, true , true , c_local_buf, bias_along_m>, \
DeviceConvNDFwdBiasActivationAddAvx2_Input_N_Hi_Wi_C_Weight_Y_X_C_K_Output_N_Ho_Wo_K<float , float , float, float , float, a_elem_op, b_elem_op, c_elem_op, ConvFwd1x1S1P0, GemmKLoopOverC , LoopOver_MNK, 2, m_per_block, n_per_block, k_per_block, m_per_thread, n_per_thread, true , true , c_local_buf, bias_along_m>, \
DeviceConvNDFwdBiasActivationAddAvx2_Input_N_Hi_Wi_C_Weight_Y_X_C_K_Output_N_Ho_Wo_K<float , float , float, float , float, a_elem_op, b_elem_op, c_elem_op, ConvFwdDefault, DefaultGemmKLoop, LoopOver_MNK, 2, m_per_block, n_per_block, k_per_block, m_per_thread, n_per_thread, true , true , c_local_buf, bias_along_m>, \
DeviceConvNDFwdBiasActivationAddAvx2_Input_N_Hi_Wi_C_Weight_Y_X_C_K_Output_N_Ho_Wo_K<float , float , float, float , float, a_elem_op, b_elem_op, c_elem_op, ConvFwd1x1S1P0, GemmKLoopOverC , LoopOver_MNK, 2, m_per_block, n_per_block, k_per_block, m_per_thread, n_per_thread, false, false, c_local_buf, bias_along_m>, \
DeviceConvNDFwdBiasActivationAddAvx2_Input_N_Hi_Wi_C_Weight_Y_X_C_K_Output_N_Ho_Wo_K<float , float , float, float , float, a_elem_op, b_elem_op, c_elem_op, ConvFwdDefault, DefaultGemmKLoop, LoopOver_MNK, 2, m_per_block, n_per_block, k_per_block, m_per_thread, n_per_thread, true , false, c_local_buf, bias_along_m>, \
\
DeviceConvNDFwdBiasActivationAddAvx2_Input_N_Hi_Wi_C_Weight_Y_X_C_K_Output_N_Ho_Wo_K<float , float , float, float , float, a_elem_op, b_elem_op, c_elem_op, ConvFwdDefault, GemmKLoopOverC , LoopOver_MKN, 2, m_per_block, n_per_block, k_per_block, m_per_thread, n_per_thread, true , true , c_local_buf, bias_along_m>, \
DeviceConvNDFwdBiasActivationAddAvx2_Input_N_Hi_Wi_C_Weight_Y_X_C_K_Output_N_Ho_Wo_K<float , float , float, float , float, a_elem_op, b_elem_op, c_elem_op, ConvFwd1x1S1P0, GemmKLoopOverC , LoopOver_MKN, 2, m_per_block, n_per_block, k_per_block, m_per_thread, n_per_thread, true , true , c_local_buf, bias_along_m>, \
DeviceConvNDFwdBiasActivationAddAvx2_Input_N_Hi_Wi_C_Weight_Y_X_C_K_Output_N_Ho_Wo_K<float , float , float, float , float, a_elem_op, b_elem_op, c_elem_op, ConvFwdDefault, DefaultGemmKLoop, LoopOver_MKN, 2, m_per_block, n_per_block, k_per_block, m_per_thread, n_per_thread, true , true , c_local_buf, bias_along_m>, \
DeviceConvNDFwdBiasActivationAddAvx2_Input_N_Hi_Wi_C_Weight_Y_X_C_K_Output_N_Ho_Wo_K<float , float , float, float , float, a_elem_op, b_elem_op, c_elem_op, ConvFwd1x1S1P0, GemmKLoopOverC , LoopOver_MKN, 2, m_per_block, n_per_block, k_per_block, m_per_thread, n_per_thread, false, false, c_local_buf, bias_along_m>, \
DeviceConvNDFwdBiasActivationAddAvx2_Input_N_Hi_Wi_C_Weight_Y_X_C_K_Output_N_Ho_Wo_K<float , float , float, float , float, a_elem_op, b_elem_op, c_elem_op, ConvFwdDefault, DefaultGemmKLoop, LoopOver_MKN, 2, m_per_block, n_per_block, k_per_block, m_per_thread, n_per_thread, true , false, c_local_buf, bias_along_m>
// clang-format on
using device_conv2d_fwd_bias_activation_add_avx2_nhwc_yxck_nhwk_f32_instances = std::tuple<
// clang-format off
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_YXCK_NHWK_F32(PT, PT, AddReluAdd, 256, 128, 64, 6, 16, false, false),
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_YXCK_NHWK_F32(PT, PT, AddReluAdd, 256, 128, 128, 6, 16, false, false),
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_YXCK_NHWK_F32(PT, PT, AddReluAdd, 128, 256, 128, 6, 16, false, false),
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_YXCK_NHWK_F32(PT, PT, AddReluAdd, 512, 240, 128, 4, 24, false, false),
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_YXCK_NHWK_F32(PT, PT, AddReluAdd, 512, 256, 128, 6, 16, false, false),
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_YXCK_NHWK_F32(PT, PT, AddReluAdd, 768, 320, 128, 6, 16, false, false),
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_YXCK_NHWK_F32(PT, PT, AddReluAdd, 896, 352, 128, 6, 16, false, false),
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_YXCK_NHWK_F32(PT, PT, AddReluAdd, 1024, 416, 128, 6, 16, false, false)>;
// clang-format on
// use this in single thread, but gemm_n is not multiple of 8
using device_conv2d_fwd_bias_activation_add_avx2_nhwc_yxck_nhwk_f32_local_c_instances = std::tuple<
// clang-format off
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_YXCK_NHWK_F32(PT, PT, AddReluAdd, 256, 128, 64, 6, 16, true, false),
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_YXCK_NHWK_F32(PT, PT, AddReluAdd, 256, 128, 128, 6, 16, true, false),
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_YXCK_NHWK_F32(PT, PT, AddReluAdd, 128, 256, 128, 6, 16, true, false),
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_YXCK_NHWK_F32(PT, PT, AddReluAdd, 512, 240, 128, 4, 24, true, false),
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_YXCK_NHWK_F32(PT, PT, AddReluAdd, 512, 256, 128, 6, 16, true, false),
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_YXCK_NHWK_F32(PT, PT, AddReluAdd, 768, 320, 128, 6, 16, true, false),
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_YXCK_NHWK_F32(PT, PT, AddReluAdd, 896, 352, 128, 6, 16, true, false),
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_YXCK_NHWK_F32(PT, PT, AddReluAdd, 1024, 416, 128, 6, 16, true, false)>;
// clang-format on
// use this in multi thread environment (need local C buffer to avoid cache coherence, although some
// time no local c is better...)
using device_conv2d_fwd_bias_activation_add_avx2_nhwc_yxck_nhwk_f32_mt_instances = std::tuple<
// clang-format off
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_YXCK_NHWK_F32(PT, PT, AddReluAdd, 24, 24, 256, 4, 24, false, false),
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_YXCK_NHWK_F32(PT, PT, AddReluAdd, 32, 24, 256, 4, 24, false, false),
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_YXCK_NHWK_F32(PT, PT, AddReluAdd, 40, 24, 256, 4, 24, false, false),
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_YXCK_NHWK_F32(PT, PT, AddReluAdd, 48, 24, 256, 4, 24, false, false),
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_YXCK_NHWK_F32(PT, PT, AddReluAdd, 48, 48, 256, 4, 24, false, false),
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_YXCK_NHWK_F32(PT, PT, AddReluAdd, 56, 24, 256, 4, 24, false, false),
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_YXCK_NHWK_F32(PT, PT, AddReluAdd, 72, 16, 128, 6, 16, false, false),
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_YXCK_NHWK_F32(PT, PT, AddReluAdd, 72, 16, 256, 6, 16, false, false),
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_YXCK_NHWK_F32(PT, PT, AddReluAdd, 72, 32, 128, 6, 16, false, false),
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_YXCK_NHWK_F32(PT, PT, AddReluAdd, 72, 32, 256, 6, 16, false, false),
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_YXCK_NHWK_F32(PT, PT, AddReluAdd, 96, 32, 128, 6, 16, false, false),
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_YXCK_NHWK_F32(PT, PT, AddReluAdd, 96, 64, 128, 6, 16, false, false),
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_YXCK_NHWK_F32(PT, PT, AddReluAdd, 120, 32, 128, 6, 16, false, false),
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_YXCK_NHWK_F32(PT, PT, AddReluAdd, 120, 64, 128, 6, 16, false, false),
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_YXCK_NHWK_F32(PT, PT, AddReluAdd, 256, 128, 128, 6, 16, true, false),
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_YXCK_NHWK_F32(PT, PT, AddReluAdd, 128, 256, 128, 6, 16, true, false),
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_YXCK_NHWK_F32(PT, PT, AddReluAdd, 512, 240, 128, 4, 24, true, false),
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_YXCK_NHWK_F32(PT, PT, AddReluAdd, 512, 256, 128, 6, 16, true, false),
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_YXCK_NHWK_F32(PT, PT, AddReluAdd, 768, 320, 128, 6, 16, true, false),
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_YXCK_NHWK_F32(PT, PT, AddReluAdd, 896, 352, 128, 6, 16, true, false),
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_YXCK_NHWK_F32(PT, PT, AddReluAdd, 1024, 416, 128, 6, 16, true, false)>;
// clang-format on
void add_device_conv2d_fwd_bias_activation_add_avx2_nhwc_yxck_nhwk(
std::vector<DeviceConvFwdBiasActivationAddPtr<PT, PT, AddReluAdd>>& instances)
{
ck::tensor_operation::device::add_device_operation_instances(
instances, device_conv2d_fwd_bias_activation_add_avx2_nhwc_yxck_nhwk_f32_instances{});
}
void add_device_conv2d_fwd_bias_activation_add_avx2_nhwc_yxck_nhwk_local_c(
std::vector<DeviceConvFwdBiasActivationAddPtr<PT, PT, AddReluAdd>>& instances)
{
ck::tensor_operation::device::add_device_operation_instances(
instances,
device_conv2d_fwd_bias_activation_add_avx2_nhwc_yxck_nhwk_f32_local_c_instances{});
}
void add_device_conv2d_fwd_bias_activation_add_avx2_nhwc_yxck_nhwk_mt(
std::vector<DeviceConvFwdBiasActivationAddPtr<PT, PT, AddReluAdd>>& instances)
{
ck::tensor_operation::device::add_device_operation_instances(
instances, device_conv2d_fwd_bias_activation_add_avx2_nhwc_yxck_nhwk_f32_mt_instances{});
}
} // namespace device_conv2d_fwd_bias_activation_add_avx2_instance
} // namespace device
} // namespace cpu
} // namespace tensor_operation
} // namespace ck
......@@ -233,68 +233,30 @@ void test_ukernel(ukenrel_t uk,
int max_threads = omp_get_max_threads();
auto invoke_uk = [&](ck::cpu::ThreadwiseGemmParam& param, float* current_mat_c) {
if constexpr(std::is_same<Row, ALayout>::value && std::is_same<Row, BLayout>::value)
assert(m % uk.ThreadMr == 0 && n % uk.ThreadNr == 0);
for(uint32_t i_m = 0; i_m < m; i_m += uk.ThreadMr)
{
assert(m % uk.ThreadMr == 0 && n == uk.ThreadNr);
FloatA* p_a = mat_a;
float* p_c = current_mat_c;
param.p_a = p_a;
param.p_c = p_c;
for(uint32_t i_m = 0; i_m < m; i_m += uk.ThreadMr)
if constexpr(std::is_same<Row, ALayout>::value)
{
uk.Run(&param);
p_a += uk.ThreadMr * k;
p_c += uk.ThreadMr * n;
param.p_a = p_a;
param.p_c = p_c;
param.p_a = mat_a + i_m * k;
}
}
else if constexpr(std::is_same<Row, ALayout>::value && std::is_same<Col, BLayout>::value)
{
assert(m % uk.ThreadMr == 0 && n % uk.ThreadNr == 0);
FloatA* p_a = mat_a;
float* p_c = current_mat_c;
param.p_a = p_a;
param.p_b = mat_b;
param.p_c = p_c;
for(uint32_t i_m = 0; i_m < m; i_m += uk.ThreadMr)
else
{
float* p_c_n = p_c;
FloatB* p_b_n = mat_b;
for(uint32_t i_n = 0; i_n < n; i_n += uk.ThreadNr)
{
uk.Run(&param);
p_b_n += uk.ThreadNr * k; // ThreadNr/8*k*8
p_c_n += uk.ThreadNr;
param.p_b = p_b_n;
param.p_c = p_c_n;
}
p_a += uk.ThreadMr * k;
p_c += uk.ThreadMr * n;
param.p_a = p_a;
param.p_b = mat_b;
param.p_c = p_c;
param.p_a = mat_a + i_m;
}
}
else if constexpr(std::is_same<Col, ALayout>::value && std::is_same<Row, BLayout>::value)
{
assert(m == uk.ThreadMr && n == uk.ThreadNr);
uk.Run(&param);
}
else
{
assert(m % uk.ThreadMr == 0 && n % uk.ThreadNr == 0);
FloatB* p_b = mat_b;
float* p_c = current_mat_c;
param.p_b = p_b;
param.p_c = p_c;
for(uint32_t i_n = 0; i_n < n; i_n += uk.ThreadNr)
{
if constexpr(std::is_same<Row, BLayout>::value)
{
param.p_b = mat_b + i_n;
}
else
{
param.p_b = mat_b + i_n * k;
}
param.p_c = current_mat_c + i_m * n + i_n;
uk.Run(&param);
p_b += uk.ThreadNr * k; // ThreadNr/8*k*8
p_c += uk.ThreadNr;
param.p_b = p_b;
param.p_c = p_c;
}
}
};
......@@ -358,7 +320,11 @@ void test_ukernel(ukenrel_t uk,
}
// implement small ukernel on L1
template <typename FloatA, typename FloatB, typename ALayout, typename BLayout>
template <typename FloatA,
typename FloatB,
typename ALayout,
typename BLayout,
typename thread_gemm_instance>
void test_cpu_ukernel(float alpha, uint32_t m, uint32_t n, uint32_t k)
{
int max_threads = omp_get_max_threads();
......@@ -382,17 +348,18 @@ void test_cpu_ukernel(float alpha, uint32_t m, uint32_t n, uint32_t k)
k);
// using thread_gemm_instance = thread_gemm_avx2_mxn_6x16_instances<ALayout, BLayout>;
using thread_gemm_instance = thread_gemm_avx2_mxn_4x24_instances<ALayout, BLayout>;
bool found = false;
// using thread_gemm_instance = thread_gemm_avx2_mxn_4x24_instances<ALayout, BLayout>;
bool found = false;
ck::static_for<0, std::tuple_size_v<thread_gemm_instance>, 1>{}([&](auto i) {
using uk_type = std::tuple_element_t<i, thread_gemm_instance>;
if(m % uk_type::ThreadMr != 0 || n % uk_type::ThreadNr != 0)
return;
if((m != uk_type::ThreadMr && std::is_same<typename uk_type::MatrixALayout, Col>::value) ||
(n != uk_type::ThreadNr && std::is_same<typename uk_type::MatrixBLayout, Row>::value))
// only k is the fast changing dim of A/B can we do muldiplt m, n
return;
// if((m != uk_type::ThreadMr && std::is_same<typename uk_type::MatrixALayout, Col>::value)
// ||
// (n != uk_type::ThreadNr && std::is_same<typename uk_type::MatrixBLayout, Row>::value))
// // only k is the fast changing dim of A/B can we do muldiplt m, n
// return;
if(found)
return;
......@@ -435,8 +402,21 @@ int main(int argc, char** argv)
omp_set_num_threads(1);
printf("max threads:%d\n", omp_get_max_threads());
test_cpu_ukernel<AType, BType, Row, Row>(alpha, m, n, k);
test_cpu_ukernel<AType, BType, Row, Col>(alpha, m, n, k);
test_cpu_ukernel<AType, BType, Col, Row>(alpha, m, n, k);
test_cpu_ukernel<AType, BType, Col, Col>(alpha, m, n, k);
test_cpu_ukernel<AType, BType, Row, Row, thread_gemm_avx2_mxn_4x24_instances<Row, Row>>(
alpha, m, n, k);
test_cpu_ukernel<AType, BType, Row, Col, thread_gemm_avx2_mxn_4x24_instances<Row, Col>>(
alpha, m, n, k);
test_cpu_ukernel<AType, BType, Col, Row, thread_gemm_avx2_mxn_4x24_instances<Col, Row>>(
alpha, m, n, k);
test_cpu_ukernel<AType, BType, Col, Col, thread_gemm_avx2_mxn_4x24_instances<Col, Col>>(
alpha, m, n, k);
test_cpu_ukernel<AType, BType, Row, Row, thread_gemm_avx2_mxn_6x16_instances<Row, Row>>(
alpha, m, n, k);
test_cpu_ukernel<AType, BType, Row, Col, thread_gemm_avx2_mxn_6x16_instances<Row, Col>>(
alpha, m, n, k);
test_cpu_ukernel<AType, BType, Col, Row, thread_gemm_avx2_mxn_6x16_instances<Col, Row>>(
alpha, m, n, k);
test_cpu_ukernel<AType, BType, Col, Col, thread_gemm_avx2_mxn_6x16_instances<Col, Col>>(
alpha, m, n, k);
}
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment