"docs/en/understand_mmcv/config.md" did not exist on "be1be020271997d38fbfa7fa2a97fbe34c322407"
Commit f9cf57d4 authored by carlushuang

support YXCK filter

parent 71254ddd
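This commit adds a third weight layout to the AVX2 forward-convolution path: the filter is stored as Y, X, C, K, i.e. the plain transpose of the row-major K x (Y*X*C) view used by the existing KYXC layout, so the output-channel index K becomes the innermost, contiguous dimension. The test drivers below build such a filter with the new transpose_kyxc_2_yxck helper. A minimal standalone sketch of that index mapping (the function and buffer names here are illustrative only, not part of the commit):

#include <cassert>
#include <cstddef>
#include <vector>

// KYXC -> YXCK: view the filter as a row-major K x (Y*X*C) matrix and
// write out its transpose, a (Y*X*C) x K matrix with K contiguous.
void kyxc_to_yxck(std::vector<float>& dst, const std::vector<float>& src,
                  int K, int Y, int X, int C)
{
    const int col = Y * X * C; // combined (y, x, c) index, c fastest
    assert(dst.size() == src.size() &&
           src.size() == static_cast<std::size_t>(K) * col);
    for(int k = 0; k < K; ++k)
        for(int i = 0; i < col; ++i)
            dst[i * K + k] = src[k * col + i]; // dst[y][x][c][k] = src[k][y][x][c]
}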
#include <sstream>
#include "config.hpp"
#include "device.hpp"
#include "host_tensor.hpp"
#include "host_tensor_generator.hpp"
#include "tensor_layout.hpp"
#include "device_tensor.hpp"
#include "device_convnd_fwd_avx2_nhwc_kyxc_nhwk.hpp"
#include "element_wise_operation_cpu.hpp"
#include "reference_conv_fwd.hpp"
#include "element_wise_operation_cpu.hpp"
#include "dynamic_buffer_cpu.hpp"
#include <omp.h>

#define AVX2_DATA_ALIGNMENT 32

#define TEST_FUSION_PASSTHROUGH 0
#define TEST_FUSION_RELU 1
#define TEST_FUSION TEST_FUSION_PASSTHROUGH

#define TEST_LAYOUT_NHWC_KYXC_NHWK 0
#define TEST_LAYOUT_NHWC_KYXCK8_NHWK 1
#define TEST_LAYOUT_NHWC_YXCK_NHWK 2
#define TEST_LAYOUT TEST_LAYOUT_NHWC_YXCK_NHWK

using F32 = float;
using F16 = ck::half_t;

namespace ck {
namespace tensor_operation {
namespace cpu {
namespace device {
namespace device_conv2d_fwd_avx2_instance {

using PassThrough = ck::tensor_operation::cpu::element_wise::PassThrough;
using Relu = ck::tensor_operation::cpu::element_wise::Relu;

// ------------------ nhwc-kyxc-nhwk
void add_device_conv2d_fwd_avx2_nhwc_kyxc_nhwk(
    std::vector<DeviceConvFwdPtr<PassThrough, PassThrough, PassThrough>>& instances);

void add_device_conv2d_fwd_avx2_nhwc_kyxc_nhwk_local_c(
    std::vector<DeviceConvFwdPtr<PassThrough, PassThrough, PassThrough>>& instances);

void add_device_conv2d_fwd_avx2_nhwc_kyxc_nhwk_mt(
    std::vector<DeviceConvFwdPtr<PassThrough, PassThrough, PassThrough>>& instances);

void add_device_conv2d_fwd_avx2_nhwc_kyxc_nhwk_relu(
    std::vector<DeviceConvFwdPtr<PassThrough, PassThrough, Relu>>& instances);

void add_device_conv2d_fwd_avx2_nhwc_kyxc_nhwk_local_c_relu(
    std::vector<DeviceConvFwdPtr<PassThrough, PassThrough, Relu>>& instances);

void add_device_conv2d_fwd_avx2_nhwc_kyxc_nhwk_mt_relu(
    std::vector<DeviceConvFwdPtr<PassThrough, PassThrough, Relu>>& instances);

// ------------------ nhwc-kcyxk8-nhwk
void add_device_conv2d_fwd_avx2_nhwc_kyxck8_nhwk(
    std::vector<DeviceConvFwdPtr<PassThrough, PassThrough, PassThrough>>& instances);

void add_device_conv2d_fwd_avx2_nhwc_kyxck8_nhwk_local_c(
    std::vector<DeviceConvFwdPtr<PassThrough, PassThrough, PassThrough>>& instances);

void add_device_conv2d_fwd_avx2_nhwc_kyxck8_nhwk_mt(
    std::vector<DeviceConvFwdPtr<PassThrough, PassThrough, PassThrough>>& instances);

void add_device_conv2d_fwd_avx2_nhwc_kyxck8_nhwk_relu(
    std::vector<DeviceConvFwdPtr<PassThrough, PassThrough, Relu>>& instances);

void add_device_conv2d_fwd_avx2_nhwc_kyxck8_nhwk_local_c_relu(
    std::vector<DeviceConvFwdPtr<PassThrough, PassThrough, Relu>>& instances);

void add_device_conv2d_fwd_avx2_nhwc_kyxck8_nhwk_mt_relu(
    std::vector<DeviceConvFwdPtr<PassThrough, PassThrough, Relu>>& instances);

// ------------------ nhwc-yxck-nhwk
void add_device_conv2d_fwd_avx2_nhwc_yxck_nhwk(
    std::vector<DeviceConvFwdPtr<PassThrough, PassThrough, PassThrough>>& instances);

void add_device_conv2d_fwd_avx2_nhwc_yxck_nhwk_local_c(
    std::vector<DeviceConvFwdPtr<PassThrough, PassThrough, PassThrough>>& instances);

void add_device_conv2d_fwd_avx2_nhwc_yxck_nhwk_mt(
    std::vector<DeviceConvFwdPtr<PassThrough, PassThrough, PassThrough>>& instances);

void add_device_conv2d_fwd_avx2_nhwc_yxck_nhwk_relu(
    std::vector<DeviceConvFwdPtr<PassThrough, PassThrough, Relu>>& instances);

void add_device_conv2d_fwd_avx2_nhwc_yxck_nhwk_local_c_relu(
    std::vector<DeviceConvFwdPtr<PassThrough, PassThrough, Relu>>& instances);

void add_device_conv2d_fwd_avx2_nhwc_yxck_nhwk_mt_relu(
    std::vector<DeviceConvFwdPtr<PassThrough, PassThrough, Relu>>& instances);

} // namespace device_conv2d_fwd_avx2_instance
} // namespace device
} // namespace cpu
} // namespace tensor_operation
} // namespace ck

using InElementOp  = ck::tensor_operation::cpu::element_wise::PassThrough;
using WeiElementOp = ck::tensor_operation::cpu::element_wise::PassThrough;
#if TEST_FUSION == TEST_FUSION_PASSTHROUGH
using OutElementOp = ck::tensor_operation::cpu::element_wise::PassThrough;
#endif
#if TEST_FUSION == TEST_FUSION_RELU
using OutElementOp = ck::tensor_operation::cpu::element_wise::Relu;
#endif

template <typename T>
static bool
check_out(const Tensor<T>& ref, const Tensor<T>& result, double nrms, int per_pixel_check = 0)
{
    int error_count = 0;
    float max_diff  = 1e-5;

    double square_difference = .0;
    double mag1               = .0;
    double mag2               = .0;

    for(int i = 0; i < ref.mData.size(); ++i)
    {
        double ri = (double)ref.mData[i];
        double pi = (double)result.mData[i];
        double d  = ri - pi;

        if(per_pixel_check)
        {
            if(max_diff < std::abs(d))
            {
                error_count++;
                printf("idx:%3d, ref:%f, res:%f (diff:%f)\n",
                       i,
                       double(ref.mData[i]),
                       double(result.mData[i]),
                       d);
            }
        }

        square_difference += d * d;
        if(std::abs(mag1) < std::abs(ri))
            mag1 = ri;
        if(std::abs(mag2) < std::abs(pi))
            mag2 = pi;
    }

    double mag = std::max({std::fabs(mag1), std::fabs(mag2), std::numeric_limits<double>::min()});
    double computed_nrms = std::sqrt(square_difference) / (std::sqrt(ref.mData.size()) * mag);

    if(computed_nrms >= nrms)
        printf("nrms:%lf, mag1:%lf, mag2:%lf, expected_nrms is %1f\n",
               computed_nrms,
               mag1,
               mag2,
               nrms);

    return computed_nrms < nrms && error_count == 0;
}

float calculate_gflops() {}

template <typename T>
void transpose_kyxc_2_kyxc8k(Tensor<T>& dst,
                             const Tensor<T>& src,
                             ck::index_t K,
                             ck::index_t Y,
                             ck::index_t X,
                             ck::index_t C)
{
    ck::index_t batch = K / 8;
    ck::index_t row   = 8;
    ck::index_t col   = C * Y * X;
    for(auto i_b = 0; i_b < batch; i_b++)
    {
        for(auto i_r = 0; i_r < row; i_r++)
        {
            for(auto i_c = 0; i_c < col; i_c++)
            {
                ck::index_t src_idx = i_b * row * col + i_r * col + i_c;
                ck::index_t dst_idx = i_b * col * row + i_c * row + i_r;
                dst.mData[dst_idx]  = src.mData[src_idx];
            }
        }
    }
}

template <typename T>
void transpose_kyxc_2_yxck(Tensor<T>& dst,
                           const Tensor<T>& src,
                           ck::index_t K,
                           ck::index_t Y,
                           ck::index_t X,
                           ck::index_t C)
{
    ck::index_t batch = 1;
    ck::index_t row   = K;
    ck::index_t col   = C * Y * X;
    for(auto i_b = 0; i_b < batch; i_b++)
    {
        for(auto i_r = 0; i_r < row; i_r++)
        {
            for(auto i_c = 0; i_c < col; i_c++)
            {
                ck::index_t src_idx = i_b * row * col + i_r * col + i_c;
                ck::index_t dst_idx = i_b * col * row + i_c * row + i_r;
                dst.mData[dst_idx]  = src.mData[src_idx];
            }
        }
    }
}

int main(int argc, char* argv[])
{
    int data_type   = 0;
    int init_method = 0;

    // Conv shape
    ck::index_t N               = 2;
    ck::index_t K               = 256;
    ck::index_t C               = 192;
    ck::index_t Y               = 3;
    ck::index_t X               = 3;
    ck::index_t Hi              = 71;
    ck::index_t Wi              = 71;
    ck::index_t conv_stride_h   = 1;
    ck::index_t conv_stride_w   = 1;
    ck::index_t conv_dilation_h = 1;
    ck::index_t conv_dilation_w = 1;
    ck::index_t in_left_pad_h   = 1;
    ck::index_t in_left_pad_w   = 1;
    ck::index_t in_right_pad_h  = 1;
    ck::index_t in_right_pad_w  = 1;

    if(argc == 1)
    {
        data_type   = 0;
        init_method = 1;
    }
    else if(argc == 3)
    {
        data_type   = std::stoi(argv[1]);
        init_method = std::stoi(argv[2]);
    }
    else if(argc == 18)
    {
        data_type   = std::stoi(argv[1]);
        init_method = std::stoi(argv[2]);

        N               = std::stoi(argv[3]);
        K               = std::stoi(argv[4]);
        C               = std::stoi(argv[5]);
        Y               = std::stoi(argv[6]);
        X               = std::stoi(argv[7]);
        Hi              = std::stoi(argv[8]);
        Wi              = std::stoi(argv[9]);
        conv_stride_h   = std::stoi(argv[10]);
        conv_stride_w   = std::stoi(argv[11]);
        conv_dilation_h = std::stoi(argv[12]);
        conv_dilation_w = std::stoi(argv[13]);
        in_left_pad_h   = std::stoi(argv[14]);
        in_left_pad_w   = std::stoi(argv[15]);
        in_right_pad_h  = std::stoi(argv[16]);
        in_right_pad_w  = std::stoi(argv[17]);
    }
    else
    {
        printf("arg1: data type (0=fp32, 1=fp16)\n");
        printf("arg2: initialization (0=no init, 1=integer value, 2=decimal value)\n");
        printf("arg3 to 17: N, K, C, Y, X, Hi, Wi, Sy, Sx, Dy, Dx, LeftPy, LeftPx, RightPy, "
               "RightPx\n");
        exit(1);
    }

    auto Run = [&](auto input_type, auto wei_type, auto out_type) {
        using InDataType  = decltype(input_type);
        using WeiDataType = decltype(wei_type);
        using OutDataType = decltype(out_type);

        using ReferenceConvFwdInstance =
            ck::tensor_operation::host::ReferenceConvFwd<InDataType,
                                                         WeiDataType,
                                                         OutDataType,
                                                         InElementOp,
                                                         WeiElementOp,
                                                         OutElementOp>;

        const ck::index_t YEff = (Y - 1) * conv_dilation_h + 1;
        const ck::index_t XEff = (X - 1) * conv_dilation_w + 1;

        const ck::index_t Ho = (Hi + in_left_pad_h + in_right_pad_h - YEff) / conv_stride_h + 1;
        const ck::index_t Wo = (Wi + in_left_pad_w + in_right_pad_w - XEff) / conv_stride_w + 1;

        const std::vector<ck::index_t> input_spatial_lengths{{Hi, Wi}};
        const std::vector<ck::index_t> filter_spatial_lengths{{Y, X}};
        const std::vector<ck::index_t> output_spatial_lengths{{Ho, Wo}};
        const std::vector<ck::index_t> conv_filter_strides{{conv_stride_h, conv_stride_w}};
        const std::vector<ck::index_t> conv_filter_dilations{{conv_dilation_h, conv_dilation_w}};
        const std::vector<ck::index_t> input_left_pads{{in_left_pad_h, in_left_pad_w}};
        const std::vector<ck::index_t> input_right_pads{{in_right_pad_h, in_right_pad_w}};

        auto f_host_tensor_descriptor = [](std::size_t N_,
                                           std::size_t C_,
                                           std::size_t H_,
                                           std::size_t W_) {
            return HostTensorDescriptor(std::vector<std::size_t>({N_, C_, H_, W_}),
                                        std::vector<std::size_t>({C_ * H_ * W_, 1, W_ * C_, C_}));
        };

        Tensor<InDataType> in_n_c_hi_wi(f_host_tensor_descriptor(N, C, Hi, Wi));
        Tensor<WeiDataType> wei_k_c_y_x(f_host_tensor_descriptor(K, C, Y, X));
#if TEST_LAYOUT == TEST_LAYOUT_NHWC_KYXCK8_NHWK
        Tensor<WeiDataType> wei_k_c_y_x_k8(
            f_host_tensor_descriptor(K, C, Y, X)); // TODO: This is only to hold data
#endif
#if TEST_LAYOUT == TEST_LAYOUT_NHWC_YXCK_NHWK
        Tensor<WeiDataType> wei_y_x_c_k(
            f_host_tensor_descriptor(K, C, Y, X)); // TODO: This is only to hold data
#endif
        Tensor<OutDataType> out_n_k_ho_wo_host_result(f_host_tensor_descriptor(N, K, Ho, Wo));
        Tensor<OutDataType> out_n_k_ho_wo_device_result(f_host_tensor_descriptor(N, K, Ho, Wo));

        std::cout << "in (N, C, Hi, Wi): " << in_n_c_hi_wi.mDesc << std::endl;
        std::cout << "wei(K, C, Y, X): " << wei_k_c_y_x.mDesc << std::endl;
        std::cout << "out(N, K, Ho, Wo): " << out_n_k_ho_wo_host_result.mDesc << std::endl;
        std::cout << "LPad(H, W):" << in_left_pad_h << "," << in_left_pad_w
                  << ", RPad(H, W):" << in_right_pad_h << "," << in_right_pad_w
                  << ", Stride(H, W):" << conv_stride_h << ", " << conv_stride_w
                  << ", Dilation(H, W):" << conv_dilation_h << ", " << conv_dilation_w
                  << ", Threads:" << omp_get_max_threads() << std::endl;

        int per_pixel_check = 0;
        switch(init_method)
        {
        case 0:
            in_n_c_hi_wi.GenerateTensorValue(GeneratorTensor_1<InDataType>{});
            wei_k_c_y_x.GenerateTensorValue(GeneratorTensor_1<WeiDataType>{});
            per_pixel_check = 1;
            break;
        case 1:
            in_n_c_hi_wi.GenerateTensorValue(GeneratorTensor_2<InDataType>{-5, 5});
            // in_n_c_hi_wi.GenerateTensorValue(GeneratorTensor_1<InDataType>{});
            wei_k_c_y_x.GenerateTensorValue(GeneratorTensor_2<WeiDataType>{-5, 5});
            // wei_k_c_y_x.GenerateTensorValue(GeneratorTensor_1<WeiDataType>{});
            per_pixel_check = 1;
            break;

        case 2:
            in_n_c_hi_wi.GenerateTensorValue(GeneratorTensor_3<InDataType>{0.0, 1.0});
            wei_k_c_y_x.GenerateTensorValue(GeneratorTensor_3<WeiDataType>{-0.5, 0.5});
            break;

        case 3:

#define PACK_32(v24, v16, v8, v0) \
    (((v24 & 0xff) << 24) | ((v16 & 0xff) << 16) | ((v8 & 0xff) << 8) | ((v0 & 0xff) << 0))

            for(auto i_n = 0; i_n < N; i_n++)
            {
                for(auto i_c = 0; i_c < C; i_c++)
                {
                    for(auto i_hi = 0; i_hi < Hi; i_hi++)
                    {
                        for(auto i_wi = 0; i_wi < Wi; i_wi++)
                        {
                            uint32_t v = PACK_32(i_n, i_c, i_hi, i_wi);
                            in_n_c_hi_wi(i_n, i_c, i_hi, i_wi) = *reinterpret_cast<float*>(&v);
                        }
                    }
                }
            }

            for(auto i_k = 0; i_k < K; i_k++)
            {
                for(auto i_c = 0; i_c < C; i_c++)
                {
                    for(auto i_y = 0; i_y < Y; i_y++)
                    {
                        for(auto i_x = 0; i_x < X; i_x++)
                        {
                            uint32_t v = PACK_32(i_k, i_c, i_y, i_x);
                            wei_k_c_y_x(i_k, i_c, i_y, i_x) = *reinterpret_cast<float*>(&v);
                        }
                    }
                }
            }

            break;
        default:
            in_n_c_hi_wi.GenerateTensorValue(GeneratorTensor_3<InDataType>{0, 1});
            wei_k_c_y_x.GenerateTensorValue(GeneratorTensor_3<WeiDataType>{-1, 1});
        }

        DeviceAlignedMemCPU in_device_buf(sizeof(InDataType) * in_n_c_hi_wi.mDesc.GetElementSpace(),
                                          AVX2_DATA_ALIGNMENT);
        DeviceAlignedMemCPU wei_device_buf(
            sizeof(WeiDataType) * wei_k_c_y_x.mDesc.GetElementSpace(), AVX2_DATA_ALIGNMENT);
        DeviceAlignedMemCPU out_device_buf(sizeof(OutDataType) *
                                               out_n_k_ho_wo_host_result.mDesc.GetElementSpace(),
                                           AVX2_DATA_ALIGNMENT);

        in_device_buf.ToDevice(in_n_c_hi_wi.mData.data());
#if TEST_LAYOUT == TEST_LAYOUT_NHWC_KYXC_NHWK
        wei_device_buf.ToDevice(wei_k_c_y_x.mData.data());
#endif
#if TEST_LAYOUT == TEST_LAYOUT_NHWC_KYXCK8_NHWK
        transpose_kyxc_2_kyxc8k(wei_k_c_y_x_k8, wei_k_c_y_x, K, Y, X, C);
        wei_device_buf.ToDevice(wei_k_c_y_x_k8.mData.data());
#endif
#if TEST_LAYOUT == TEST_LAYOUT_NHWC_YXCK_NHWK
        transpose_kyxc_2_yxck(wei_y_x_c_k, wei_k_c_y_x, K, Y, X, C);
        wei_device_buf.ToDevice(wei_y_x_c_k.mData.data());
#endif

        // get host result
        {
            auto ref_conv    = ReferenceConvFwdInstance{};
            auto ref_invoker = ref_conv.MakeInvoker();

            auto ref_argument = ref_conv.MakeArgument(in_n_c_hi_wi,
                                                      wei_k_c_y_x,
                                                      out_n_k_ho_wo_host_result,
                                                      conv_filter_strides,
                                                      conv_filter_dilations,
                                                      input_left_pads,
                                                      input_right_pads,
                                                      InElementOp{},
                                                      WeiElementOp{},
                                                      OutElementOp{});
            ref_invoker.Run(ref_argument);
        }

        using PassThrough = ck::tensor_operation::cpu::element_wise::PassThrough;
        using Relu        = ck::tensor_operation::cpu::element_wise::Relu;
#if TEST_FUSION == TEST_FUSION_PASSTHROUGH
        using DeviceConvFwdNoOpPtr = ck::tensor_operation::cpu::device::
            DeviceConvFwdPtr<PassThrough, PassThrough, PassThrough>;
#endif
#if TEST_FUSION == TEST_FUSION_RELU
        using DeviceConvFwdNoOpPtr =
            ck::tensor_operation::cpu::device::DeviceConvFwdPtr<PassThrough, PassThrough, Relu>;
#endif

        // add device Conv instances
        std::vector<DeviceConvFwdNoOpPtr> conv_ptrs;

        if constexpr(ck::is_same_v<ck::remove_cv_t<InDataType>, float> &&
                     ck::is_same_v<ck::remove_cv_t<WeiDataType>, float> &&
                     ck::is_same_v<ck::remove_cv_t<OutDataType>, float>)
        {
#if TEST_LAYOUT == TEST_LAYOUT_NHWC_KYXC_NHWK
#if TEST_FUSION == TEST_FUSION_PASSTHROUGH
            if(omp_get_max_threads() > 1)
            {
                ck::tensor_operation::cpu::device::device_conv2d_fwd_avx2_instance::
                    add_device_conv2d_fwd_avx2_nhwc_kyxc_nhwk_mt(conv_ptrs);
                ck::tensor_operation::cpu::device::device_conv2d_fwd_avx2_instance::
                    add_device_conv2d_fwd_avx2_nhwc_kyxc_nhwk(conv_ptrs);
            }
            else
            {
                if(K % 8 == 0)
                    ck::tensor_operation::cpu::device::device_conv2d_fwd_avx2_instance::
                        add_device_conv2d_fwd_avx2_nhwc_kyxc_nhwk(conv_ptrs);
                else
                    ck::tensor_operation::cpu::device::device_conv2d_fwd_avx2_instance::
                        add_device_conv2d_fwd_avx2_nhwc_kyxc_nhwk_local_c(conv_ptrs);
            }
#endif
#if TEST_FUSION == TEST_FUSION_RELU
            if(omp_get_max_threads() > 1)
            {
                ck::tensor_operation::cpu::device::device_conv2d_fwd_avx2_instance::
                    add_device_conv2d_fwd_avx2_nhwc_kyxc_nhwk_mt_relu(conv_ptrs);
                ck::tensor_operation::cpu::device::device_conv2d_fwd_avx2_instance::
                    add_device_conv2d_fwd_avx2_nhwc_kyxc_nhwk_relu(conv_ptrs);
            }
            else
            {
                if(K % 8 == 0)
                    ck::tensor_operation::cpu::device::device_conv2d_fwd_avx2_instance::
                        add_device_conv2d_fwd_avx2_nhwc_kyxc_nhwk_relu(conv_ptrs);
                else
                    ck::tensor_operation::cpu::device::device_conv2d_fwd_avx2_instance::
                        add_device_conv2d_fwd_avx2_nhwc_kyxc_nhwk_local_c_relu(conv_ptrs);
            }
#endif
#endif
#if TEST_LAYOUT == TEST_LAYOUT_NHWC_KYXCK8_NHWK
#if TEST_FUSION == TEST_FUSION_PASSTHROUGH
            if(omp_get_max_threads() > 1)
            {
                ck::tensor_operation::cpu::device::device_conv2d_fwd_avx2_instance::
                    add_device_conv2d_fwd_avx2_nhwc_kyxck8_nhwk_mt(conv_ptrs);
                ck::tensor_operation::cpu::device::device_conv2d_fwd_avx2_instance::
                    add_device_conv2d_fwd_avx2_nhwc_kyxck8_nhwk(conv_ptrs);
            }
            else
            {
                if(K % 8 == 0)
                    ck::tensor_operation::cpu::device::device_conv2d_fwd_avx2_instance::
                        add_device_conv2d_fwd_avx2_nhwc_kyxck8_nhwk(conv_ptrs);
                else
                    ck::tensor_operation::cpu::device::device_conv2d_fwd_avx2_instance::
                        add_device_conv2d_fwd_avx2_nhwc_kyxck8_nhwk_local_c(conv_ptrs);
            }
#endif
#if TEST_FUSION == TEST_FUSION_RELU
            if(omp_get_max_threads() > 1)
            {
                ck::tensor_operation::cpu::device::device_conv2d_fwd_avx2_instance::
                    add_device_conv2d_fwd_avx2_nhwc_kyxck8_nhwk_mt_relu(conv_ptrs);
                ck::tensor_operation::cpu::device::device_conv2d_fwd_avx2_instance::
                    add_device_conv2d_fwd_avx2_nhwc_kyxck8_nhwk_relu(conv_ptrs);
            }
            else
            {
                if(K % 8 == 0)
                    ck::tensor_operation::cpu::device::device_conv2d_fwd_avx2_instance::
                        add_device_conv2d_fwd_avx2_nhwc_kyxck8_nhwk_relu(conv_ptrs);
                else
                    ck::tensor_operation::cpu::device::device_conv2d_fwd_avx2_instance::
                        add_device_conv2d_fwd_avx2_nhwc_kyxck8_nhwk_local_c_relu(conv_ptrs);
            }
#endif
#endif
#if TEST_LAYOUT == TEST_LAYOUT_NHWC_YXCK_NHWK
#if TEST_FUSION == TEST_FUSION_PASSTHROUGH
            if(omp_get_max_threads() > 1)
            {
                ck::tensor_operation::cpu::device::device_conv2d_fwd_avx2_instance::
                    add_device_conv2d_fwd_avx2_nhwc_yxck_nhwk_mt(conv_ptrs);
                ck::tensor_operation::cpu::device::device_conv2d_fwd_avx2_instance::
                    add_device_conv2d_fwd_avx2_nhwc_yxck_nhwk(conv_ptrs);
            }
            else
            {
                if(K % 8 == 0)
                    ck::tensor_operation::cpu::device::device_conv2d_fwd_avx2_instance::
                        add_device_conv2d_fwd_avx2_nhwc_yxck_nhwk(conv_ptrs);
                else
                    ck::tensor_operation::cpu::device::device_conv2d_fwd_avx2_instance::
                        add_device_conv2d_fwd_avx2_nhwc_yxck_nhwk_local_c(conv_ptrs);
            }
#endif
#if TEST_FUSION == TEST_FUSION_RELU
            if(omp_get_max_threads() > 1)
            {
                ck::tensor_operation::cpu::device::device_conv2d_fwd_avx2_instance::
                    add_device_conv2d_fwd_avx2_nhwc_yxck_nhwk_mt_relu(conv_ptrs);
                ck::tensor_operation::cpu::device::device_conv2d_fwd_avx2_instance::
                    add_device_conv2d_fwd_avx2_nhwc_yxck_nhwk_relu(conv_ptrs);
            }
            else
            {
                if(K % 8 == 0)
                    ck::tensor_operation::cpu::device::device_conv2d_fwd_avx2_instance::
                        add_device_conv2d_fwd_avx2_nhwc_yxck_nhwk_relu(conv_ptrs);
                else
                    ck::tensor_operation::cpu::device::device_conv2d_fwd_avx2_instance::
                        add_device_conv2d_fwd_avx2_nhwc_yxck_nhwk_local_c_relu(conv_ptrs);
            }
#endif
#endif
        }

        if(conv_ptrs.size() <= 0)
        {
            throw std::runtime_error("wrong! no device Conv instance found");
        }

        // profile device Conv instances
        bool success                    = true;
        double fastest_kernel_time      = std::numeric_limits<double>::max();
        std::string fastest_kernel_name = "";
        double fastest_kernel_gflops    = 0;
        for(auto& conv_ptr : conv_ptrs)
        {
            auto argument_ptr = conv_ptr->MakeArgumentPointer(
                static_cast<InDataType*>(in_device_buf.GetDeviceBuffer()),
                static_cast<WeiDataType*>(wei_device_buf.GetDeviceBuffer()),
                static_cast<OutDataType*>(out_device_buf.GetDeviceBuffer()),
                N,
                K,
                C,
                input_spatial_lengths,
                filter_spatial_lengths,
                output_spatial_lengths,
                conv_filter_strides,
                conv_filter_dilations,
                input_left_pads,
                input_right_pads,
                InElementOp{},
                WeiElementOp{},
                OutElementOp{});

            if(conv_ptr->IsSupportedArgument(argument_ptr.get()))
            {
                auto invoker_ptr = conv_ptr->MakeInvokerPointer();
                double time      = invoker_ptr->Run(argument_ptr.get(), StreamConfig{}, 10);

                double total_flop = static_cast<double>(2) * N * C * Ho * Wo * K * Y * X;

                double gflops = (total_flop * 1e-6) / time;

                out_device_buf.FromDevice(out_n_k_ho_wo_device_result.mData.data());

                if(!check_out(out_n_k_ho_wo_host_result,
                              out_n_k_ho_wo_device_result,
                              1e-6,
                              per_pixel_check))
                {
                    std::cout << "Fail Info: " << conv_ptr->GetTypeString() << std::endl;
                    success = false;
                }
                else
                {
                    std::cout << "Pass Info: " << conv_ptr->GetTypeString() << ", Time:" << time
                              << "ms, Gflops:" << gflops << std::endl;

                    if(time < fastest_kernel_time)
                    {
                        fastest_kernel_time   = time;
                        fastest_kernel_name   = conv_ptr->GetTypeString();
                        fastest_kernel_gflops = gflops;
                    }
                }
            }
            else
            {
                std::cout << "Not support Info: " << conv_ptr->GetTypeString() << std::endl;
            }
        }

        if(fastest_kernel_time != std::numeric_limits<double>::max())
        {
            std::cout << " fastest:" << fastest_kernel_name << ", time:" << fastest_kernel_time
                      << "ms, Gflops:" << fastest_kernel_gflops << std::endl;
        }
        return 0;
        // if(success)
        // {
        //     std::cout << "test conv2d fwd cpu : Pass" << std::endl;
        //     return 0;
        // }
        // else
        // {
        //     std::cout << "test conv2d fwd cpu: Fail " << std::endl;
        //     return -1;
        // }
    };

    if(data_type == 0)
    {
        return Run(F32(), F32(), F32());
    }
    else
    {
        return 1;
    }
}
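For reference, check_out above accepts a kernel result when the normalized RMS error against the host reference stays below the given threshold (1e-6 in these tests); the quantity it computes is

    nrms = \frac{\sqrt{\sum_i (\mathrm{ref}_i - \mathrm{res}_i)^2}}{\sqrt{n} \cdot \max\left(\max_i |\mathrm{ref}_i|,\ \max_i |\mathrm{res}_i|\right)}

where n is the number of output elements; with per_pixel_check enabled it additionally requires every element-wise difference to stay within 1e-5.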
@@ -16,7 +16,8 @@
#define TEST_LAYOUT_NHWC_KYXC_NHWK 0
#define TEST_LAYOUT_NHWC_KYXCK8_NHWK 1
#define TEST_LAYOUT_NHWC_YXCK_NHWK 1
#define TEST_LAYOUT TEST_LAYOUT_NHWC_KYXC_NHWK

using F32 = float;
using F16 = ck::half_t;

@@ -30,6 +31,7 @@ namespace device_conv2d_fwd_bias_activation_add_avx2_instance {
using PassThrough = ck::tensor_operation::cpu::element_wise::PassThrough;
using AddReluAdd  = ck::tensor_operation::cpu::element_wise::AddReluAdd;

// ------------------ nhwc-kyxc-nhwk
void add_device_conv2d_fwd_bias_activation_add_avx2_nhwc_kyxc_nhwk(
    std::vector<DeviceConvFwdBiasActivationAddPtr<PassThrough, PassThrough, AddReluAdd>>&
        instances);

@@ -42,6 +44,7 @@ void add_device_conv2d_fwd_bias_activation_add_avx2_nhwc_kyxc_nhwk_mt(
    std::vector<DeviceConvFwdBiasActivationAddPtr<PassThrough, PassThrough, AddReluAdd>>&
        instances);

// ------------------ nhwc-kcyxk8-nhwk
void add_device_conv2d_fwd_bias_activation_add_avx2_nhwc_kyxck8_nhwk(
    std::vector<DeviceConvFwdBiasActivationAddPtr<PassThrough, PassThrough, AddReluAdd>>&
        instances);

@@ -54,6 +57,19 @@ void add_device_conv2d_fwd_bias_activation_add_avx2_nhwc_kyxck8_nhwk_mt(
    std::vector<DeviceConvFwdBiasActivationAddPtr<PassThrough, PassThrough, AddReluAdd>>&
        instances);

// ------------------ nhwc-yxck-nhwk
void add_device_conv2d_fwd_bias_activation_add_avx2_nhwc_yxck_nhwk(
    std::vector<DeviceConvFwdBiasActivationAddPtr<PassThrough, PassThrough, AddReluAdd>>&
        instances);

void add_device_conv2d_fwd_bias_activation_add_avx2_nhwc_yxck_nhwk_local_c(
    std::vector<DeviceConvFwdBiasActivationAddPtr<PassThrough, PassThrough, AddReluAdd>>&
        instances);

void add_device_conv2d_fwd_bias_activation_add_avx2_nhwc_yxck_nhwk_mt(
    std::vector<DeviceConvFwdBiasActivationAddPtr<PassThrough, PassThrough, AddReluAdd>>&
        instances);

} // namespace device_conv2d_fwd_bias_activation_add_avx2_instance
} // namespace device
} // namespace cpu

@@ -141,6 +157,31 @@ void transpose_kyxc_2_kyxc8k(Tensor<T>& dst,
        }
    }
}

template <typename T>
void transpose_kyxc_2_yxck(Tensor<T>& dst,
                           const Tensor<T>& src,
                           ck::index_t K,
                           ck::index_t Y,
                           ck::index_t X,
                           ck::index_t C)
{
    ck::index_t batch = 1;
    ck::index_t row   = K;
    ck::index_t col   = C * Y * X;
    for(auto i_b = 0; i_b < batch; i_b++)
    {
        for(auto i_r = 0; i_r < row; i_r++)
        {
            for(auto i_c = 0; i_c < col; i_c++)
            {
                ck::index_t src_idx = i_b * row * col + i_r * col + i_c;
                ck::index_t dst_idx = i_b * col * row + i_c * row + i_r;
                dst.mData[dst_idx]  = src.mData[src_idx];
            }
        }
    }
}

int main(int argc, char* argv[])
{
    int data_type = 0;

@@ -243,6 +284,10 @@ int main(int argc, char* argv[])
#if TEST_LAYOUT == TEST_LAYOUT_NHWC_KYXCK8_NHWK
    Tensor<WeiDataType> wei_k_c_y_x_k8(
        f_host_tensor_descriptor(K, C, Y, X)); // TODO: This is only to hold data
#endif
#if TEST_LAYOUT == TEST_LAYOUT_NHWC_YXCK_NHWK
    Tensor<WeiDataType> wei_y_x_c_k(
        f_host_tensor_descriptor(K, C, Y, X)); // TODO: This is only to hold data
#endif
    Tensor<OutDataType> out_n_k_ho_wo_host_result(f_host_tensor_descriptor(N, K, Ho, Wo));
    Tensor<OutDataType> out_n_k_ho_wo_device_result(f_host_tensor_descriptor(N, K, Ho, Wo));

@@ -319,6 +364,10 @@ int main(int argc, char* argv[])
#if TEST_LAYOUT == TEST_LAYOUT_NHWC_KYXCK8_NHWK
    transpose_kyxc_2_kyxc8k(wei_k_c_y_x_k8, wei_k_c_y_x, K, Y, X, C);
    wei_device_buf.ToDevice(wei_k_c_y_x_k8.mData.data());
#endif
#if TEST_LAYOUT == TEST_LAYOUT_NHWC_YXCK_NHWK
    transpose_kyxc_2_yxck(wei_y_x_c_k, wei_k_c_y_x, K, Y, X, C);
    wei_device_buf.ToDevice(wei_y_x_c_k.mData.data());
#endif
    bias_device_buf.ToDevice(bias.mData.data());
    resi_device_buf.ToDevice(residual.mData.data());

@@ -404,6 +453,30 @@ int main(int argc, char* argv[])
                        add_device_conv2d_fwd_bias_activation_add_avx2_nhwc_kyxck8_nhwk_local_c(
                            conv_ptrs);
            }
#endif
#if TEST_LAYOUT == TEST_LAYOUT_NHWC_YXCK_NHWK
            if(omp_get_max_threads() > 1)
            {
                ck::tensor_operation::cpu::device::
                    device_conv2d_fwd_bias_activation_add_avx2_instance::
                        add_device_conv2d_fwd_bias_activation_add_avx2_nhwc_yxck_nhwk_mt(conv_ptrs);
                ck::tensor_operation::cpu::device::
                    device_conv2d_fwd_bias_activation_add_avx2_instance::
                        add_device_conv2d_fwd_bias_activation_add_avx2_nhwc_yxck_nhwk(conv_ptrs);
            }
            else
            {
                if(K % 8 == 0)
                    ck::tensor_operation::cpu::device::
                        device_conv2d_fwd_bias_activation_add_avx2_instance::
                            add_device_conv2d_fwd_bias_activation_add_avx2_nhwc_yxck_nhwk(
                                conv_ptrs);
                else
                    ck::tensor_operation::cpu::device::
                        device_conv2d_fwd_bias_activation_add_avx2_instance::
                            add_device_conv2d_fwd_bias_activation_add_avx2_nhwc_yxck_nhwk_local_c(
                                conv_ptrs);
            }
#endif
        }
@@ -199,8 +199,6 @@ struct BlockwiseGemmAvx2_MxN
        auto ldb = GetBLeadingElement(b_block_desc) * sizeof(FloatB);
        auto ldc = GetCLeadingElement(c_desc) * sizeof(FloatC);

        const auto k_per_block = a_slice_length[Number<1>{}];
        const auto m_per_block = c_slice_length[Number<0>{}];
        const auto n_per_block = c_slice_length[Number<1>{}];

@@ -215,8 +213,16 @@
        param.alpha = 1.0f; // TODO
        param.accmulate_c = is_accumulate_c ? 1 : 0;

        // printf("xxx lda:%u, ldb:%u, ldc:%u, mpb:%u, npb:%u, kpb:%u, mpt:%u, npt:%u\n",
        //        lda,
        //        ldb,
        //        ldc,
        //        m_per_block,
        //        n_per_block,
        //        k_per_block,
        //        m_per_thread,
        //        n_per_thread);
        // fflush(stdout);

        if constexpr(std::is_same<ThreadMNAccessOrder, ck::Sequence<0, 1>>::value)
        {
#ifndef DEVICE_CONV2D_FWD_AVX2_NHWC_YXCK_NHWK_HPP
#define DEVICE_CONV2D_FWD_AVX2_NHWC_YXCK_NHWK_HPP
#include <iostream>
#include <sstream>
#include <numeric>
#include "device.hpp"
#include "device_base_cpu.hpp"
#include "device_conv_fwd_cpu.hpp"
#include "convolution_forward_specialization_cpu.hpp"
#include "common_header.hpp"
#include "../../gpu/device/tensor_layout.hpp"
#include "tensor_descriptor.hpp"
#include "tensor_descriptor_helper.hpp"
#include "gridwise_gemm_avx2.hpp"
#include "threadwise_gemm_avx2.hpp"
#include "threadwise_tensor_slice_transfer_avx2_specialization.hpp"
namespace ck {
namespace tensor_operation {
namespace cpu {
namespace device {
template <typename InDataType,
typename WeiDataType,
typename OutDataType,
typename InElementwiseOperation,
typename WeiElementwiseOperation,
typename OutElementwiseOperation,
ConvolutionForwardSpecialization_t ConvForwardSpecialization,
ConvolutionForwardGemmKSpecialization_t GemmKSpecialization,
ConvolutionForwardBlockLoopOverSpecialization_t BlockLoopOverSpecialization,
ck::index_t NumDimSpatial,
ck::index_t MPerBlock, // block means data are designed to fit in cache (L1/L2/L3)
ck::index_t NPerBlock,
ck::index_t KPerBlock,
ck::index_t MPerThread,
ck::index_t NPerThread,
bool UseALocalBuffer,
bool UseBLocalBuffer,
bool UseCLocalBuffer>
struct DeviceConvNDFwdAvx2_Input_N_Hi_Wi_C_Weight_Y_X_C_K_Output_N_Ho_Wo_K
: public DeviceConvFwd<InElementwiseOperation, WeiElementwiseOperation, OutElementwiseOperation>
{
using DeviceOp = DeviceConvNDFwdAvx2_Input_N_Hi_Wi_C_Weight_Y_X_C_K_Output_N_Ho_Wo_K;
using ADataType = InDataType;
using BDataType = WeiDataType;
using CDataType = OutDataType;
using AElementwiseOperation = InElementwiseOperation;
using BElementwiseOperation = WeiElementwiseOperation;
using CElementwiseOperation = OutElementwiseOperation;
// TODO make A/B datatype different
using ABDataType = InDataType;
static constexpr index_t NDimSpatial = NumDimSpatial;
static constexpr auto I0 = Number<0>{};
static constexpr auto I1 = Number<1>{};
static constexpr auto I2 = Number<2>{};
static constexpr auto I3 = Number<3>{};
static constexpr bool NonTemporalStore = false;
static constexpr auto GetBlockMNKAccessOrder()
{
if constexpr(BlockLoopOverSpecialization == DefaultBlockLoopOver ||
BlockLoopOverSpecialization == LoopOver_MNK)
return ck::Sequence<0, 1, 2>{};
else if constexpr(BlockLoopOverSpecialization == LoopOver_MKN)
return ck::Sequence<0, 2, 1>{};
}
using BlockMNKAccessOrder = decltype(GetBlockMNKAccessOrder());
static constexpr auto GetThreadwiseGemm_Dispatch()
{
if constexpr(MPerThread == 4 && NPerThread == 24)
{
return ck::cpu::ThreadwiseGemmAvx2_MxN_4x24_Dispatch<InDataType,
WeiDataType,
OutDataType,
ck::tensor_layout::gemm::RowMajor,
ck::tensor_layout::gemm::RowMajor,
NonTemporalStore>{};
}
else if constexpr(MPerThread == 6 && NPerThread == 16)
{
return ck::cpu::ThreadwiseGemmAvx2_MxN_6x16_Dispatch<InDataType,
WeiDataType,
OutDataType,
ck::tensor_layout::gemm::RowMajor,
ck::tensor_layout::gemm::RowMajor,
NonTemporalStore>{};
}
else
{
// static_assert(false, "invalid Mr/Nr");
}
}
using ThreadwiseGemm_Dispatch = decltype(GetThreadwiseGemm_Dispatch());
static auto GetWeightTensorDescriptor(ck::index_t gemm_k, ck::index_t gemm_n)
{
return make_naive_tensor_descriptor_packed(make_tuple(gemm_k, gemm_n));
}
static auto GetOutputTensorDescriptor(ck::index_t gemm_m, ck::index_t gemm_n)
{
const auto out_gemm_m_n_grid_desc =
make_naive_tensor_descriptor_packed(make_tuple(gemm_m, gemm_n));
return out_gemm_m_n_grid_desc;
}
template <ck::index_t NDim, typename std::enable_if<NDim == 1, bool>::type = false>
static auto GetInputTensorDescriptor(ck::index_t N,
ck::index_t C,
ck::index_t gemm_m,
ck::index_t gemm_k,
const std::vector<ck::index_t>& input_spatial_lengths,
const std::vector<ck::index_t>& filter_spatial_lengths,
const std::vector<ck::index_t>& output_spatial_lengths,
const std::vector<ck::index_t>& conv_filter_strides,
const std::vector<ck::index_t>& conv_filter_dilations,
const std::vector<ck::index_t>& input_left_pads,
const std::vector<ck::index_t>& input_right_pads)
{
const index_t Wi = input_spatial_lengths[0];
const index_t Wo = output_spatial_lengths[0];
const index_t ConvStrideW = conv_filter_strides[0];
if constexpr(ConvForwardSpecialization ==
ConvolutionForwardSpecialization_t::Filter1x1Stride1Pad0)
{
const auto in_gemm_m_k_grid_desc =
make_naive_tensor_descriptor_packed(make_tuple(gemm_m, gemm_k));
return in_gemm_m_k_grid_desc;
}
else if constexpr(ConvForwardSpecialization ==
ConvolutionForwardSpecialization_t::Filter1x1Pad0)
{
const auto in_n_wi_c_grid_desc =
make_naive_tensor_descriptor_packed(make_tuple(N, Wi, C));
const auto in_n_wo_c_grid_desc = transform_tensor_descriptor(
in_n_wi_c_grid_desc,
make_tuple(make_pass_through_transform(N),
make_embed_transform(make_tuple(Wo), make_tuple(ConvStrideW)),
make_pass_through_transform(C)),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}));
const auto in_gemm_m_k_grid_desc = transform_tensor_descriptor(
in_n_wo_c_grid_desc,
make_tuple(make_merge_transform(make_tuple(N, Wo)), make_pass_through_transform(C)),
make_tuple(Sequence<0, 1>{}, Sequence<2>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}));
return in_gemm_m_k_grid_desc;
}
else
{
const index_t X = filter_spatial_lengths[0];
const index_t ConvDilationW = conv_filter_dilations[0];
const index_t InLeftPadW = input_left_pads[0];
const index_t InRightPadW = input_right_pads[0];
const auto in_n_wi_c_grid_desc =
make_naive_tensor_descriptor_packed(make_tuple(N, Wi, C));
const auto in_n_wip_c_grid_desc = transform_tensor_descriptor(
in_n_wi_c_grid_desc,
make_tuple(make_pass_through_transform(N),
make_pad_transform(Wi, InLeftPadW, InRightPadW),
make_pass_through_transform(C)),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}));
const auto in_n_x_wo_c_grid_desc = transform_tensor_descriptor(
in_n_wip_c_grid_desc,
make_tuple(
make_pass_through_transform(N),
make_embed_transform(make_tuple(X, Wo), make_tuple(ConvDilationW, ConvStrideW)),
make_pass_through_transform(C)),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}),
make_tuple(Sequence<0>{}, Sequence<1, 2>{}, Sequence<3>{}));
const auto in_gemm_m_k_grid_desc =
transform_tensor_descriptor(in_n_x_wo_c_grid_desc,
make_tuple(make_merge_transform(make_tuple(N, Wo)),
make_merge_transform(make_tuple(X, C))),
make_tuple(Sequence<0, 2>{}, Sequence<1, 3>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}));
return in_gemm_m_k_grid_desc;
}
}
template <ck::index_t NDim, typename std::enable_if<NDim == 2, bool>::type = false>
static auto GetInputTensorDescriptor(ck::index_t N,
ck::index_t C,
ck::index_t gemm_m,
ck::index_t gemm_k,
const std::vector<ck::index_t>& input_spatial_lengths,
const std::vector<ck::index_t>& filter_spatial_lengths,
const std::vector<ck::index_t>& output_spatial_lengths,
const std::vector<ck::index_t>& conv_filter_strides,
const std::vector<ck::index_t>& conv_filter_dilations,
const std::vector<ck::index_t>& input_left_pads,
const std::vector<ck::index_t>& input_right_pads)
{
const index_t Hi = input_spatial_lengths[0];
const index_t Wi = input_spatial_lengths[1];
const index_t Ho = output_spatial_lengths[0];
const index_t Wo = output_spatial_lengths[1];
const index_t ConvStrideH = conv_filter_strides[0];
const index_t ConvStrideW = conv_filter_strides[1];
if constexpr(ConvForwardSpecialization ==
ConvolutionForwardSpecialization_t::Filter1x1Stride1Pad0)
{
const auto in_gemm_m_k_grid_desc =
make_naive_tensor_descriptor_packed(make_tuple(gemm_m, gemm_k));
return in_gemm_m_k_grid_desc;
}
else if constexpr(ConvForwardSpecialization ==
ConvolutionForwardSpecialization_t::Filter1x1Pad0)
{
const auto in_n_hi_wi_c_grid_desc =
make_naive_tensor_descriptor_packed(make_tuple(N, Hi, Wi, C));
const auto in_n_ho_wo_c_grid_desc = transform_tensor_descriptor(
in_n_hi_wi_c_grid_desc,
make_tuple(make_pass_through_transform(N),
make_embed_transform(make_tuple(Ho), make_tuple(ConvStrideH)),
make_embed_transform(make_tuple(Wo), make_tuple(ConvStrideW)),
make_pass_through_transform(C)),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}));
const auto in_gemm_m_k_grid_desc =
transform_tensor_descriptor(in_n_ho_wo_c_grid_desc,
make_tuple(make_merge_transform(make_tuple(N, Ho, Wo)),
make_pass_through_transform(C)),
make_tuple(Sequence<0, 1, 2>{}, Sequence<3>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}));
return in_gemm_m_k_grid_desc;
}
else
{
const index_t Y = filter_spatial_lengths[0];
const index_t X = filter_spatial_lengths[1];
const index_t ConvDilationH = conv_filter_dilations[0];
const index_t ConvDilationW = conv_filter_dilations[1];
const index_t InLeftPadH = input_left_pads[0];
const index_t InLeftPadW = input_left_pads[1];
const index_t InRightPadH = input_right_pads[0];
const index_t InRightPadW = input_right_pads[1];
const auto in_n_hi_wi_c_grid_desc =
make_naive_tensor_descriptor_packed(make_tuple(N, Hi, Wi, C));
const auto in_n_hip_wip_c_grid_desc = transform_tensor_descriptor(
in_n_hi_wi_c_grid_desc,
make_tuple(make_pass_through_transform(N),
make_pad_transform(Hi, InLeftPadH, InRightPadH),
make_pad_transform(Wi, InLeftPadW, InRightPadW),
make_pass_through_transform(C)),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}));
const auto in_n_y_ho_x_wo_c_grid_desc = transform_tensor_descriptor(
in_n_hip_wip_c_grid_desc,
make_tuple(
make_pass_through_transform(N),
make_embed_transform(make_tuple(Y, Ho), make_tuple(ConvDilationH, ConvStrideH)),
make_embed_transform(make_tuple(X, Wo), make_tuple(ConvDilationW, ConvStrideW)),
make_pass_through_transform(C)),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}),
make_tuple(Sequence<0>{}, Sequence<1, 2>{}, Sequence<3, 4>{}, Sequence<5>{}));
const auto in_gemm_m_k_grid_desc =
transform_tensor_descriptor(in_n_y_ho_x_wo_c_grid_desc,
make_tuple(make_merge_transform(make_tuple(N, Ho, Wo)),
make_merge_transform(make_tuple(Y, X, C))),
make_tuple(Sequence<0, 2, 4>{}, Sequence<1, 3, 5>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}));
return in_gemm_m_k_grid_desc;
}
}
template <ck::index_t NDim, typename std::enable_if<NDim == 3, bool>::type = false>
static auto GetInputTensorDescriptor(ck::index_t N,
ck::index_t C,
ck::index_t gemm_m,
ck::index_t gemm_k,
ck::index_t gemm_m_pad,
const std::vector<ck::index_t>& input_spatial_lengths,
const std::vector<ck::index_t>& filter_spatial_lengths,
const std::vector<ck::index_t>& output_spatial_lengths,
const std::vector<ck::index_t>& conv_filter_strides,
const std::vector<ck::index_t>& conv_filter_dilations,
const std::vector<ck::index_t>& input_left_pads,
const std::vector<ck::index_t>& input_right_pads)
{
const index_t Di = input_spatial_lengths[0];
const index_t Hi = input_spatial_lengths[1];
const index_t Wi = input_spatial_lengths[2];
const index_t Do = output_spatial_lengths[0];
const index_t Ho = output_spatial_lengths[1];
const index_t Wo = output_spatial_lengths[2];
const index_t ConvStrideD = conv_filter_strides[0];
const index_t ConvStrideH = conv_filter_strides[1];
const index_t ConvStrideW = conv_filter_strides[2];
if constexpr(ConvForwardSpecialization ==
ConvolutionForwardSpecialization_t::Filter1x1Stride1Pad0)
{
const auto in_gemm_m_k_grid_desc =
make_naive_tensor_descriptor_packed(make_tuple(gemm_m, gemm_k));
return in_gemm_m_k_grid_desc;
}
else if constexpr(ConvForwardSpecialization ==
ConvolutionForwardSpecialization_t::Filter1x1Pad0)
{
const auto in_n_di_hi_wi_c_grid_desc =
make_naive_tensor_descriptor_packed(make_tuple(N, Di, Hi, Wi, C));
const auto in_n_do_ho_wo_c_grid_desc = transform_tensor_descriptor(
in_n_di_hi_wi_c_grid_desc,
make_tuple(make_pass_through_transform(N),
make_embed_transform(make_tuple(Do), make_tuple(ConvStrideD)),
make_embed_transform(make_tuple(Ho), make_tuple(ConvStrideH)),
make_embed_transform(make_tuple(Wo), make_tuple(ConvStrideW)),
make_pass_through_transform(C)),
make_tuple(
Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}, Sequence<4>{}),
make_tuple(
Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}, Sequence<4>{}));
const auto in_gemm_m_k_grid_desc = transform_tensor_descriptor(
in_n_do_ho_wo_c_grid_desc,
make_tuple(make_merge_transform(make_tuple(N, Do, Ho, Wo)),
make_pass_through_transform(C)),
make_tuple(Sequence<0, 1, 2, 3>{}, Sequence<4>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}));
return in_gemm_m_k_grid_desc;
}
else
{
const index_t Z = filter_spatial_lengths[0];
const index_t Y = filter_spatial_lengths[1];
const index_t X = filter_spatial_lengths[2];
const index_t ConvDilationD = conv_filter_dilations[0];
const index_t ConvDilationH = conv_filter_dilations[1];
const index_t ConvDilationW = conv_filter_dilations[2];
const index_t InLeftPadD = input_left_pads[0];
const index_t InLeftPadH = input_left_pads[1];
const index_t InLeftPadW = input_left_pads[2];
const index_t InRightPadD = input_right_pads[0];
const index_t InRightPadH = input_right_pads[1];
const index_t InRightPadW = input_right_pads[2];
const auto in_n_di_hi_wi_c_grid_desc =
make_naive_tensor_descriptor_packed(make_tuple(N, Di, Hi, Wi, C));
const auto in_n_dip_hip_wip_c_grid_desc = transform_tensor_descriptor(
in_n_di_hi_wi_c_grid_desc,
make_tuple(make_pass_through_transform(N),
make_pad_transform(Di, InLeftPadD, InRightPadD),
make_pad_transform(Hi, InLeftPadH, InRightPadH),
make_pad_transform(Wi, InLeftPadW, InRightPadW),
make_pass_through_transform(C)),
make_tuple(
Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}, Sequence<4>{}),
make_tuple(
Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}, Sequence<4>{}));
const auto in_n_z_do_y_ho_x_wo_c_grid_desc = transform_tensor_descriptor(
in_n_dip_hip_wip_c_grid_desc,
make_tuple(
make_pass_through_transform(N),
make_embed_transform(make_tuple(Z, Do), make_tuple(ConvDilationD, ConvStrideD)),
make_embed_transform(make_tuple(Y, Ho), make_tuple(ConvDilationH, ConvStrideH)),
make_embed_transform(make_tuple(X, Wo), make_tuple(ConvDilationW, ConvStrideW)),
make_pass_through_transform(C)),
make_tuple(
Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}, Sequence<4>{}),
make_tuple(Sequence<0>{},
Sequence<1, 2>{},
Sequence<3, 4>{},
Sequence<5, 6>{},
Sequence<7>{}));
const auto in_gemm_m_k_grid_desc = transform_tensor_descriptor(
in_n_z_do_y_ho_x_wo_c_grid_desc,
make_tuple(make_merge_transform(make_tuple(N, Do, Ho, Wo)),
make_merge_transform(make_tuple(Z, Y, X, C))),
make_tuple(Sequence<0, 2, 4, 6>{}, Sequence<1, 3, 5, 7>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}));
return in_gemm_m_k_grid_desc;
}
}
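// Note: the input descriptors above encode the im2col gather implicitly through
// pad/embed/merge transforms, so no explicit im2col scratch buffer is ever materialized;
// the Gemm{M,N,K} helpers below give the dimensions of that implicit matrix view.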
static index_t GetGemmM(ck::index_t N, const std::vector<ck::index_t>& output_spatial_lengths)
{
return N * std::accumulate(std::begin(output_spatial_lengths),
std::end(output_spatial_lengths),
1,
std::multiplies<ck::index_t>());
}
static index_t GetGemmK(ck::index_t C, const std::vector<ck::index_t>& filter_spatial_lengths)
{
return C * std::accumulate(std::begin(filter_spatial_lengths),
std::end(filter_spatial_lengths),
1,
std::multiplies<ck::index_t>());
}
static index_t GetGemmN(ck::index_t K)
{
// return ck::math::integer_least_multiple(K,
// ThreadwiseGemm_Dispatch::MatrixBMinVectorSize);
return K;
}
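// Worked example (hypothetical sizes, purely illustrative): for N = 2, K = 256, C = 64,
// a 3x3 filter and 56x56 output, GemmM = 2 * 56 * 56 = 6272, GemmK = 64 * 3 * 3 = 576 and
// GemmN = K = 256, i.e. the forward convolution becomes a [6272 x 576] * [576 x 256] GEMM.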
static auto MakeABCGridDescriptor(ck::index_t N,
ck::index_t K,
ck::index_t C,
std::vector<ck::index_t> input_spatial_lengths,
std::vector<ck::index_t> filter_spatial_lengths,
std::vector<ck::index_t> output_spatial_lengths,
std::vector<ck::index_t> conv_filter_strides,
std::vector<ck::index_t> conv_filter_dilations,
std::vector<ck::index_t> input_left_pads,
std::vector<ck::index_t> input_right_pads)
{
using namespace ck;
const index_t GemmM = GetGemmM(N, output_spatial_lengths);
const index_t GemmN = GetGemmN(K);
const index_t GemmK = GetGemmK(C, filter_spatial_lengths);
// A:
const auto in_gemm_m_k_grid_desc =
GetInputTensorDescriptor<NumDimSpatial>(N,
C,
GemmM,
GemmK,
input_spatial_lengths,
filter_spatial_lengths,
output_spatial_lengths,
conv_filter_strides,
conv_filter_dilations,
input_left_pads,
input_right_pads);
// B:
const auto wei_gemm_k_n_grid_desc = GetWeightTensorDescriptor(GemmK, GemmN);
// C:
const auto out_gemm_m_n_grid_desc = GetOutputTensorDescriptor(GemmM, GemmN);
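// B and C can stay naive packed 2-D descriptors because the YXCK weight layout is already
// a row-major [GemmK = Y*X*C] x [GemmN = K] matrix and the NHWK output is already a
// row-major [GemmM = N*Ho*Wo] x [GemmN = K] matrix; only A needs tensor transforms.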
return make_tuple(in_gemm_m_k_grid_desc, wei_gemm_k_n_grid_desc, out_gemm_m_n_grid_desc);
}
template <ck::index_t NDim, typename std::enable_if<NDim == 1, bool>::type = false>
static auto GetABCGridDesc()
{
return MakeABCGridDescriptor(1, 1, 1, {1}, {1}, {1}, {1}, {1}, {1}, {1});
}
template <ck::index_t NDim, typename std::enable_if<NDim == 2, bool>::type = false>
static auto GetABCGridDesc()
{
return MakeABCGridDescriptor(
1, 1, 1, {1, 1}, {1, 1}, {1, 1}, {1, 1}, {1, 1}, {1, 1}, {1, 1});
}
template <ck::index_t NDim, typename std::enable_if<NDim == 3, bool>::type = false>
static auto GetABCGridDesc()
{
return MakeABCGridDescriptor(
1, 1, 1, {1, 1, 1}, {1, 1, 1}, {1, 1, 1}, {1, 1, 1}, {1, 1, 1}, {1, 1, 1}, {1, 1, 1});
}
using ABCGridDescs = decltype(GetABCGridDesc<NumDimSpatial>());
using AGridDesc = remove_cvref_t<decltype(ABCGridDescs{}[I0])>;
using BGridDesc = remove_cvref_t<decltype(ABCGridDescs{}[I1])>;
using CGridDesc = remove_cvref_t<decltype(ABCGridDescs{}[I2])>;
static constexpr auto GetInputBlockDescriptor()
{
if constexpr(UseALocalBuffer)
{
return make_naive_tensor_descriptor_packed(make_tuple(MPerBlock, KPerBlock));
}
else
{
return AGridDesc{};
}
}
static constexpr auto GetWeightBlockDescriptor()
{
if constexpr(UseBLocalBuffer)
{
return make_naive_tensor_descriptor_packed(make_tuple(KPerBlock, NPerBlock));
}
else
{
return BGridDesc{};
}
}
static constexpr auto GetOutputBlockDescriptor()
{
if constexpr(UseCLocalBuffer)
{
return make_naive_tensor_descriptor_packed(make_tuple(MPerBlock, NPerBlock));
}
else
{
return CGridDesc{};
}
}
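// When Use{A,B,C}LocalBuffer is true, the corresponding block descriptor is a packed
// scratch tile (MPerBlock x KPerBlock, KPerBlock x NPerBlock, MPerBlock x NPerBlock) that
// the threadwise copies fill (and, for C, flush back to the grid); when false, the block
// addresses the global grid descriptor directly and skips the extra copy.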
// static constexpr bool UseCLocalBuffer = false;
using AThreadwiseCopy =
ck::cpu::ThreadwiseTensorSliceTransferAvx2Specialization_ConvFwd_In_NHWC<
InDataType,
InDataType,
AGridDesc,
decltype(GetInputBlockDescriptor()),
InElementwiseOperation,
!UseALocalBuffer,
ConvForwardSpecialization,
GemmKSpecialization>;
using BThreadwiseCopy =
ck::cpu::ThreadwiseTensorSliceTransferAvx2Specialization_ConvFwd_Wei_YXCK<
WeiDataType,
WeiDataType,
BGridDesc,
decltype(GetWeightBlockDescriptor()),
WeiElementwiseOperation,
!UseBLocalBuffer,
ConvForwardSpecialization,
GemmKSpecialization>;
using CThreadwiseCopy = ck::cpu::ThreadwiseTensorSliceTransferAvx2Specialization_MatC_Store_MxN<
OutDataType,
OutDataType,
CGridDesc,
decltype(GetOutputBlockDescriptor()),
OutElementwiseOperation,
!UseCLocalBuffer,
ConvForwardSpecialization,
GemmKSpecialization>;
using GridwiseGemm =
ck::cpu::GridwiseGemmAvx2_MxN<InDataType, // InDataType,
WeiDataType, // WeiDataType,
OutDataType, // OutDataType,
AGridDesc, // AGridDesc,
BGridDesc, // BGridDesc,
CGridDesc, // CGridDesc,
AElementwiseOperation, // AElementwiseOperation,
BElementwiseOperation, // BElementwiseOperation,
CElementwiseOperation, // CElementwiseOperation,
MPerBlock, // MPerBlock,
NPerBlock, // NPerBlock,
KPerBlock, // KPerBlock,
ThreadwiseGemm_Dispatch, // ThreadwiseGemm_Dispatch,
AThreadwiseCopy, // AThreadwiseCopy
BThreadwiseCopy, // BThreadwiseCopy
CThreadwiseCopy, // CThreadwiseCopy
BlockMNKAccessOrder, // BlockMNKAccessOrder,
ck::Sequence<0, 1>, // ThreadMNAccessOrder
UseALocalBuffer, // UseALocalBuffer
UseBLocalBuffer, // UseBLocalBuffer
UseCLocalBuffer // UseCLocalBuffer
>;
// Argument
struct Argument : public BaseArgument
{
Argument(const InDataType* p_in_grid,
const WeiDataType* p_wei_grid,
OutDataType* p_out_grid,
ck::index_t N,
ck::index_t K,
ck::index_t C,
std::vector<ck::index_t> input_spatial_lengths,
std::vector<ck::index_t> filter_spatial_lengths,
std::vector<ck::index_t> output_spatial_lengths,
std::vector<ck::index_t> conv_filter_strides,
std::vector<ck::index_t> conv_filter_dilations,
std::vector<ck::index_t> input_left_pads,
std::vector<ck::index_t> input_right_pads,
InElementwiseOperation in_element_op,
WeiElementwiseOperation wei_element_op,
OutElementwiseOperation out_element_op)
: p_a_grid_{p_in_grid},
p_b_grid_{p_wei_grid},
p_c_grid_{p_out_grid},
a_grid_desc_{},
b_grid_desc_{},
c_grid_desc_{},
a_element_op_{in_element_op},
b_element_op_{wei_element_op},
c_element_op_{out_element_op},
Conv_N_{N},
Conv_K_{K},
Conv_C_{C},
filter_spatial_lengths_{filter_spatial_lengths},
conv_filter_strides_{conv_filter_strides},
input_left_pads_{input_left_pads},
input_right_pads_{input_right_pads}
{
const auto descs = DeviceOp::MakeABCGridDescriptor(N,
K,
C,
input_spatial_lengths,
filter_spatial_lengths,
output_spatial_lengths,
conv_filter_strides,
conv_filter_dilations,
input_left_pads,
input_right_pads);
a_grid_desc_ = descs[I0];
b_grid_desc_ = descs[I1];
c_grid_desc_ = descs[I2];
}
// private:
const ADataType* p_a_grid_;
const BDataType* p_b_grid_;
CDataType* p_c_grid_;
AGridDesc a_grid_desc_;
BGridDesc b_grid_desc_;
CGridDesc c_grid_desc_;
AElementwiseOperation a_element_op_;
BElementwiseOperation b_element_op_;
CElementwiseOperation c_element_op_;
// for checking IsSupportedArgument()
index_t Conv_N_;
index_t Conv_K_;
index_t Conv_C_;
std::vector<index_t> filter_spatial_lengths_;
std::vector<index_t> conv_filter_strides_;
std::vector<index_t> input_left_pads_;
std::vector<index_t> input_right_pads_;
};
// Invoker
struct Invoker : public BaseInvoker
{
using Argument = DeviceOp::Argument;
float Run(const Argument& arg,
const StreamConfig& stream_config = StreamConfig{},
int nrepeat = 1)
{
if(!GridwiseGemm::CheckValidity(arg.a_grid_desc_, arg.b_grid_desc_, arg.c_grid_desc_))
{
throw std::runtime_error("wrong! GridwiseGemmAvx2_MxN has invalid setting");
}
memset(arg.p_c_grid_, 0, arg.c_grid_desc_.GetElementSpaceSize() * sizeof(OutDataType));
const auto kernel = ck::cpu::kernel_gemm_avx_mxn<GridwiseGemm,
InDataType,
WeiDataType,
OutDataType,
AGridDesc,
BGridDesc,
CGridDesc,
AElementwiseOperation,
BElementwiseOperation,
CElementwiseOperation>;
float ave_time = 0;
if(nrepeat != 1)
ave_time = launch_and_time_cpu_kernel(kernel,
nrepeat,
arg.p_a_grid_,
arg.p_b_grid_,
arg.p_c_grid_,
arg.a_grid_desc_,
arg.b_grid_desc_,
arg.c_grid_desc_,
arg.a_element_op_,
arg.b_element_op_,
arg.c_element_op_);
// NOTE: for benchmarking, we clear the C buffer one final time and run the kernel once more
// so that the returned buffer holds a freshly computed result.
memset(arg.p_c_grid_, 0, arg.c_grid_desc_.GetElementSpaceSize() * sizeof(OutDataType));
launch_cpu_kernel(kernel,
arg.p_a_grid_,
arg.p_b_grid_,
arg.p_c_grid_,
arg.a_grid_desc_,
arg.b_grid_desc_,
arg.c_grid_desc_,
arg.a_element_op_,
arg.b_element_op_,
arg.c_element_op_);
return ave_time;
}
float Run(const BaseArgument* p_arg,
const StreamConfig& stream_config = StreamConfig{},
int nrepeat = 1) override
{
return Run(*dynamic_cast<const Argument*>(p_arg), stream_config, nrepeat);
}
};
static constexpr bool IsValidCompilationParameter()
{
// TODO: properly implement this check
return true;
}
static bool IsSupportedArgument(const Argument& arg)
{
if constexpr(ConvForwardSpecialization ==
ConvolutionForwardSpecialization_t::Filter1x1Stride1Pad0)
{
// check if it's 1x1, stride=1 conv
if(!(arg.filter_spatial_lengths_[0] == 1 && arg.filter_spatial_lengths_[1] == 1 &&
arg.conv_filter_strides_[0] == 1 && arg.conv_filter_strides_[1] == 1 &&
arg.input_left_pads_[0] == 0 && arg.input_left_pads_[1] == 0 &&
arg.input_right_pads_[0] == 0 && arg.input_right_pads_[1] == 0))
{
return false;
}
}
else if constexpr(ConvForwardSpecialization ==
ConvolutionForwardSpecialization_t::Filter1x1Pad0)
{
// check if it's 1x1 conv
if(!(arg.filter_spatial_lengths_[0] == 1 && arg.filter_spatial_lengths_[1] == 1 &&
arg.input_left_pads_[0] == 0 && arg.input_left_pads_[1] == 0 &&
arg.input_right_pads_[0] == 0 && arg.input_right_pads_[1] == 0))
{
return false;
}
}
if constexpr(GemmKSpecialization ==
ConvolutionForwardGemmKSpecialization_t::NHWC_GemmKLoopOverC &&
ConvForwardSpecialization !=
ConvolutionForwardSpecialization_t::Filter1x1Stride1Pad0)
{
if(!(arg.Conv_C_ % KPerBlock == 0))
return false;
}
if constexpr(!UseALocalBuffer &&
ConvForwardSpecialization !=
ConvolutionForwardSpecialization_t::Filter1x1Stride1Pad0)
{
// TODO: this case could be supported in the future, once we figure out how to express the
// required tensor transform.
return false;
}
if constexpr(!UseBLocalBuffer)
{
if(!(arg.Conv_K_ % 8 == 0))
return false;
}
// Gridwise GEMM size
return GridwiseGemm::CheckValidity(arg.a_grid_desc_, arg.b_grid_desc_, arg.c_grid_desc_);
}
bool IsSupportedArgument(const BaseArgument* p_arg) override
{
return IsSupportedArgument(*dynamic_cast<const Argument*>(p_arg));
}
static auto MakeArgument(const InDataType* p_in_grid,
const WeiDataType* p_wei_grid,
OutDataType* p_out_grid,
ck::index_t N,
ck::index_t K,
ck::index_t C,
std::vector<ck::index_t> input_spatial_lengths,
std::vector<ck::index_t> filter_spatial_lengths,
std::vector<ck::index_t> output_spatial_lengths,
std::vector<ck::index_t> conv_filter_strides,
std::vector<ck::index_t> conv_filter_dilations,
std::vector<ck::index_t> input_left_pads,
std::vector<ck::index_t> input_right_pads,
InElementwiseOperation in_element_op,
WeiElementwiseOperation wei_element_op,
OutElementwiseOperation out_element_op)
{
return Argument{p_in_grid,
p_wei_grid,
p_out_grid,
N,
K,
C,
input_spatial_lengths,
filter_spatial_lengths,
output_spatial_lengths,
conv_filter_strides,
conv_filter_dilations,
input_left_pads,
input_right_pads,
in_element_op,
wei_element_op,
out_element_op};
}
static auto MakeInvoker() { return Invoker{}; }
std::unique_ptr<BaseArgument>
MakeArgumentPointer(const void* p_in_grid,
const void* p_wei_grid,
void* p_out_grid,
ck::index_t N,
ck::index_t K,
ck::index_t C,
std::vector<ck::index_t> input_spatial_lengths,
std::vector<ck::index_t> filter_spatial_lengths,
std::vector<ck::index_t> output_spatial_lengths,
std::vector<ck::index_t> conv_filter_strides,
std::vector<ck::index_t> conv_filter_dilations,
std::vector<ck::index_t> input_left_pads,
std::vector<ck::index_t> input_right_pads,
InElementwiseOperation in_element_op,
WeiElementwiseOperation wei_element_op,
OutElementwiseOperation out_element_op) override
{
return std::make_unique<Argument>(static_cast<const InDataType*>(p_in_grid),
static_cast<const WeiDataType*>(p_wei_grid),
static_cast<OutDataType*>(p_out_grid),
N,
K,
C,
input_spatial_lengths,
filter_spatial_lengths,
output_spatial_lengths,
conv_filter_strides,
conv_filter_dilations,
input_left_pads,
input_right_pads,
in_element_op,
wei_element_op,
out_element_op);
}
std::unique_ptr<BaseInvoker> MakeInvokerPointer() override
{
return std::make_unique<Invoker>(Invoker{});
}
std::string GetTypeString() const override
{
auto str = std::stringstream();
auto string_local_buffer = [](bool is_local_buffer) {
if(is_local_buffer)
return "L";
else
return "G";
};
// clang-format off
str << "DeviceConv" << std::to_string(NumDimSpatial)
<< "DFwdAvx2_NHWC_YXCK"
<<"_FS"<< static_cast<int>(ConvForwardSpecialization)
<<"_KS"<< static_cast<int>(GemmKSpecialization)
<<"_BS"<< static_cast<int>(BlockLoopOverSpecialization)
<< "_BT" << MPerBlock << "x" << NPerBlock << "x" << KPerBlock
<< "_TT" << MPerThread << "x" << NPerThread
<< "_A" << string_local_buffer(UseALocalBuffer)
<< "_B" << string_local_buffer(UseBLocalBuffer)
<< "_C" << string_local_buffer(UseCLocalBuffer)
;
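// Example of the generated name (hypothetical template parameters, for illustration only):
// "DeviceConv2DFwdAvx2_NHWC_YXCK_FS0_KS0_BS0_BT256x128x64_TT6x16_AL_BL_CG"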
if constexpr (!std::is_same<OutElementwiseOperation,
ck::tensor_operation::cpu::element_wise::PassThrough>::value)
{
str << "_" << OutElementwiseOperation::Name();
}
// clang-format on
return str.str();
}
};
} // namespace device
} // namespace cpu
} // namespace tensor_operation
} // namespace ck
#endif
#ifndef DEVICE_CONV2D_FWD_BIAS_ACTIVATION_ADD_AVX2_NHWC_YXCK_NHWK_HPP
#define DEVICE_CONV2D_FWD_BIAS_ACTIVATION_ADD_AVX2_NHWC_YXCK_NHWK_HPP
#include <iostream>
#include <sstream>
#include <numeric>
#include "device.hpp"
#include "device_base_cpu.hpp"
#include "device_conv_fwd_cpu.hpp"
#include "convolution_forward_specialization_cpu.hpp"
#include "common_header.hpp"
#include "../../gpu/device/tensor_layout.hpp"
#include "tensor_descriptor.hpp"
#include "tensor_descriptor_helper.hpp"
#include "gridwise_gemm_bias_activation_add_avx2.hpp"
#include "threadwise_gemm_avx2.hpp"
#include "threadwise_tensor_slice_transfer_avx2_specialization.hpp"
namespace ck {
namespace tensor_operation {
namespace cpu {
namespace device {
template <typename InDataType,
typename WeiDataType,
typename OutDataType,
typename BiasDataType,
typename AddDataType,
typename InElementwiseOperation,
typename WeiElementwiseOperation,
typename OutElementwiseOperation,
ConvolutionForwardSpecialization_t ConvForwardSpecialization,
ConvolutionForwardGemmKSpecialization_t GemmKSpecialization,
ConvolutionForwardBlockLoopOverSpecialization_t BlockLoopOverSpecialization,
ck::index_t NumDimSpatial,
ck::index_t MPerBlock, // "block" tile sizes are chosen so the working set fits in cache (L1/L2/L3)
ck::index_t NPerBlock,
ck::index_t KPerBlock,
ck::index_t MPerThread,
ck::index_t NPerThread,
bool UseALocalBuffer,
bool UseBLocalBuffer,
bool UseCLocalBuffer,
bool BiasAlongGemmM>
struct DeviceConvNDFwdBiasActivationAddAvx2_Input_N_Hi_Wi_C_Weight_Y_X_C_K_Output_N_Ho_Wo_K
: public DeviceConvFwdBiasActivationAdd<InElementwiseOperation,
WeiElementwiseOperation,
OutElementwiseOperation>
{
using DeviceOp =
DeviceConvNDFwdBiasActivationAddAvx2_Input_N_Hi_Wi_C_Weight_Y_X_C_K_Output_N_Ho_Wo_K;
using ADataType = InDataType;
using BDataType = WeiDataType;
using CDataType = OutDataType;
using C0DataType = BiasDataType;
using C1DataType = AddDataType;
using AElementwiseOperation = InElementwiseOperation;
using BElementwiseOperation = WeiElementwiseOperation;
using CElementwiseOperation = OutElementwiseOperation;
// TODO make A/B datatype different
using ABDataType = InDataType;
static constexpr index_t NDimSpatial = NumDimSpatial;
static constexpr auto I0 = Number<0>{};
static constexpr auto I1 = Number<1>{};
static constexpr auto I2 = Number<2>{};
static constexpr auto I3 = Number<3>{};
static constexpr bool NonTemporalStore = false;
static constexpr auto GetBlockMNKAccessOrder()
{
if constexpr(BlockLoopOverSpecialization == DefaultBlockLoopOver ||
BlockLoopOverSpecialization == LoopOver_MNK)
return ck::Sequence<0, 1, 2>{};
else if constexpr(BlockLoopOverSpecialization == LoopOver_MKN)
return ck::Sequence<0, 2, 1>{};
}
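// The Sequence encodes the block-level loop order over M(0)/N(1)/K(2); e.g. LoopOver_MKN
// (Sequence<0, 2, 1>) keeps an A tile (M x K) resident while sweeping across the N tiles.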
using BlockMNKAccessOrder = decltype(GetBlockMNKAccessOrder());
static constexpr auto GetThreadwiseGemm_Dispatch()
{
if constexpr(MPerThread == 4 && NPerThread == 24)
{
return ck::cpu::ThreadwiseGemmAvx2_MxN_4x24_Dispatch<InDataType,
WeiDataType,
OutDataType,
ck::tensor_layout::gemm::RowMajor,
ck::tensor_layout::gemm::RowMajor,
NonTemporalStore>{};
}
else if constexpr(MPerThread == 6 && NPerThread == 16)
{
return ck::cpu::ThreadwiseGemmAvx2_MxN_6x16_Dispatch<InDataType,
WeiDataType,
OutDataType,
ck::tensor_layout::gemm::RowMajor,
ck::tensor_layout::gemm::RowMajor,
NonTemporalStore>{};
}
else
{
// static_assert(false, "invalid Mr/Nr");
}
}
using ThreadwiseGemm_Dispatch = decltype(GetThreadwiseGemm_Dispatch());
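// The dispatched micro-kernels update an MPerThread x NPerThread tile of C per inner step:
// 6x16 uses 6 rows by two 8-float ymm columns, 4x24 uses 4 rows by three ymm columns, so
// both keep 12 ymm accumulators live, which is presumably why only these shapes are wired up.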
static constexpr auto GetInputBlockDescriptor()
{
if constexpr(UseALocalBuffer)
{
return make_naive_tensor_descriptor_packed(make_tuple(MPerBlock, KPerBlock));
}
else
{
return AGridDesc{};
}
}
static constexpr auto GetWeightBlockDescriptor()
{
if constexpr(UseBLocalBuffer)
{
return make_naive_tensor_descriptor_packed(make_tuple(KPerBlock, NPerBlock));
}
else
{
return BGridDesc{};
}
}
static constexpr auto GetOutputBlockDescriptor()
{
if constexpr(UseCLocalBuffer)
{
return make_naive_tensor_descriptor_packed(make_tuple(MPerBlock, NPerBlock));
}
else
{
return CGridDesc{};
}
}
static auto GetWeightTensorDescriptor(ck::index_t gemm_k, ck::index_t gemm_n)
{
return make_naive_tensor_descriptor_packed(make_tuple(gemm_k, gemm_n));
}
static auto GetOutputTensorDescriptor(ck::index_t gemm_m, ck::index_t gemm_n)
{
const auto out_gemm_m_n_grid_desc =
make_naive_tensor_descriptor_packed(make_tuple(gemm_m, gemm_n));
return out_gemm_m_n_grid_desc;
}
static auto MakeBiasTensorDescriptor(ck::index_t gemm_m, ck::index_t gemm_n)
{
if constexpr(BiasAlongGemmM)
{
return make_naive_tensor_descriptor_packed(make_tuple(gemm_m));
}
else
{
return make_naive_tensor_descriptor_packed(make_tuple(gemm_n));
}
}
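// Illustration: with BiasAlongGemmM == false the bias is a length-GemmN (= K) vector, one
// value per output channel broadcast over every output pixel; with BiasAlongGemmM == true it
// is a length-GemmM (= N*Ho*Wo) vector, one value per output pixel broadcast over channels.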
template <ck::index_t NDim, typename std::enable_if<NDim == 1, bool>::type = false>
static auto GetInputTensorDescriptor(ck::index_t N,
ck::index_t C,
ck::index_t gemm_m,
ck::index_t gemm_k,
const std::vector<ck::index_t>& input_spatial_lengths,
const std::vector<ck::index_t>& filter_spatial_lengths,
const std::vector<ck::index_t>& output_spatial_lengths,
const std::vector<ck::index_t>& conv_filter_strides,
const std::vector<ck::index_t>& conv_filter_dilations,
const std::vector<ck::index_t>& input_left_pads,
const std::vector<ck::index_t>& input_right_pads)
{
const index_t Wi = input_spatial_lengths[0];
const index_t Wo = output_spatial_lengths[0];
const index_t ConvStrideW = conv_filter_strides[0];
if constexpr(ConvForwardSpecialization ==
ConvolutionForwardSpecialization_t::Filter1x1Stride1Pad0)
{
const auto in_gemm_m_k_grid_desc =
make_naive_tensor_descriptor_packed(make_tuple(gemm_m, gemm_k));
return in_gemm_m_k_grid_desc;
}
else if constexpr(ConvForwardSpecialization ==
ConvolutionForwardSpecialization_t::Filter1x1Pad0)
{
const auto in_n_wi_c_grid_desc =
make_naive_tensor_descriptor_packed(make_tuple(N, Wi, C));
const auto in_n_wo_c_grid_desc = transform_tensor_descriptor(
in_n_wi_c_grid_desc,
make_tuple(make_pass_through_transform(N),
make_embed_transform(make_tuple(Wo), make_tuple(ConvStrideW)),
make_pass_through_transform(C)),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}));
const auto in_gemm_m_k_grid_desc = transform_tensor_descriptor(
in_n_wo_c_grid_desc,
make_tuple(make_merge_transform(make_tuple(N, Wo)), make_pass_through_transform(C)),
make_tuple(Sequence<0, 1>{}, Sequence<2>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}));
return in_gemm_m_k_grid_desc;
}
else
{
const index_t X = filter_spatial_lengths[0];
const index_t ConvDilationW = conv_filter_dilations[0];
const index_t InLeftPadW = input_left_pads[0];
const index_t InRightPadW = input_right_pads[0];
const auto in_n_wi_c_grid_desc =
make_naive_tensor_descriptor_packed(make_tuple(N, Wi, C));
const auto in_n_wip_c_grid_desc = transform_tensor_descriptor(
in_n_wi_c_grid_desc,
make_tuple(make_pass_through_transform(N),
make_pad_transform(Wi, InLeftPadW, InRightPadW),
make_pass_through_transform(C)),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}));
const auto in_n_x_wo_c_grid_desc = transform_tensor_descriptor(
in_n_wip_c_grid_desc,
make_tuple(
make_pass_through_transform(N),
make_embed_transform(make_tuple(X, Wo), make_tuple(ConvDilationW, ConvStrideW)),
make_pass_through_transform(C)),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}),
make_tuple(Sequence<0>{}, Sequence<1, 2>{}, Sequence<3>{}));
const auto in_gemm_m_k_grid_desc =
transform_tensor_descriptor(in_n_x_wo_c_grid_desc,
make_tuple(make_merge_transform(make_tuple(N, Wo)),
make_merge_transform(make_tuple(X, C))),
make_tuple(Sequence<0, 2>{}, Sequence<1, 3>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}));
return in_gemm_m_k_grid_desc;
}
}
template <ck::index_t NDim, typename std::enable_if<NDim == 2, bool>::type = false>
static auto GetInputTensorDescriptor(ck::index_t N,
ck::index_t C,
ck::index_t gemm_m,
ck::index_t gemm_k,
const std::vector<ck::index_t>& input_spatial_lengths,
const std::vector<ck::index_t>& filter_spatial_lengths,
const std::vector<ck::index_t>& output_spatial_lengths,
const std::vector<ck::index_t>& conv_filter_strides,
const std::vector<ck::index_t>& conv_filter_dilations,
const std::vector<ck::index_t>& input_left_pads,
const std::vector<ck::index_t>& input_right_pads)
{
const index_t Hi = input_spatial_lengths[0];
const index_t Wi = input_spatial_lengths[1];
const index_t Ho = output_spatial_lengths[0];
const index_t Wo = output_spatial_lengths[1];
const index_t ConvStrideH = conv_filter_strides[0];
const index_t ConvStrideW = conv_filter_strides[1];
if constexpr(ConvForwardSpecialization ==
ConvolutionForwardSpecialization_t::Filter1x1Stride1Pad0)
{
const auto in_gemm_m_k_grid_desc =
make_naive_tensor_descriptor_packed(make_tuple(gemm_m, gemm_k));
return in_gemm_m_k_grid_desc;
}
else if constexpr(ConvForwardSpecialization ==
ConvolutionForwardSpecialization_t::Filter1x1Pad0)
{
const auto in_n_hi_wi_c_grid_desc =
make_naive_tensor_descriptor_packed(make_tuple(N, Hi, Wi, C));
const auto in_n_ho_wo_c_grid_desc = transform_tensor_descriptor(
in_n_hi_wi_c_grid_desc,
make_tuple(make_pass_through_transform(N),
make_embed_transform(make_tuple(Ho), make_tuple(ConvStrideH)),
make_embed_transform(make_tuple(Wo), make_tuple(ConvStrideW)),
make_pass_through_transform(C)),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}));
const auto in_gemm_m_k_grid_desc =
transform_tensor_descriptor(in_n_ho_wo_c_grid_desc,
make_tuple(make_merge_transform(make_tuple(N, Ho, Wo)),
make_pass_through_transform(C)),
make_tuple(Sequence<0, 1, 2>{}, Sequence<3>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}));
return in_gemm_m_k_grid_desc;
}
else
{
const index_t Y = filter_spatial_lengths[0];
const index_t X = filter_spatial_lengths[1];
const index_t ConvDilationH = conv_filter_dilations[0];
const index_t ConvDilationW = conv_filter_dilations[1];
const index_t InLeftPadH = input_left_pads[0];
const index_t InLeftPadW = input_left_pads[1];
const index_t InRightPadH = input_right_pads[0];
const index_t InRightPadW = input_right_pads[1];
const auto in_n_hi_wi_c_grid_desc =
make_naive_tensor_descriptor_packed(make_tuple(N, Hi, Wi, C));
const auto in_n_hip_wip_c_grid_desc = transform_tensor_descriptor(
in_n_hi_wi_c_grid_desc,
make_tuple(make_pass_through_transform(N),
make_pad_transform(Hi, InLeftPadH, InRightPadH),
make_pad_transform(Wi, InLeftPadW, InRightPadW),
make_pass_through_transform(C)),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}));
const auto in_n_y_ho_x_wo_c_grid_desc = transform_tensor_descriptor(
in_n_hip_wip_c_grid_desc,
make_tuple(
make_pass_through_transform(N),
make_embed_transform(make_tuple(Y, Ho), make_tuple(ConvDilationH, ConvStrideH)),
make_embed_transform(make_tuple(X, Wo), make_tuple(ConvDilationW, ConvStrideW)),
make_pass_through_transform(C)),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}),
make_tuple(Sequence<0>{}, Sequence<1, 2>{}, Sequence<3, 4>{}, Sequence<5>{}));
const auto in_gemm_m_k_grid_desc =
transform_tensor_descriptor(in_n_y_ho_x_wo_c_grid_desc,
make_tuple(make_merge_transform(make_tuple(N, Ho, Wo)),
make_merge_transform(make_tuple(Y, X, C))),
make_tuple(Sequence<0, 2, 4>{}, Sequence<1, 3, 5>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}));
return in_gemm_m_k_grid_desc;
}
}
template <ck::index_t NDim, typename std::enable_if<NDim == 3, bool>::type = false>
static auto GetInputTensorDescriptor(ck::index_t N,
ck::index_t C,
ck::index_t gemm_m,
ck::index_t gemm_k,
const std::vector<ck::index_t>& input_spatial_lengths,
const std::vector<ck::index_t>& filter_spatial_lengths,
const std::vector<ck::index_t>& output_spatial_lengths,
const std::vector<ck::index_t>& conv_filter_strides,
const std::vector<ck::index_t>& conv_filter_dilations,
const std::vector<ck::index_t>& input_left_pads,
const std::vector<ck::index_t>& input_right_pads)
{
const index_t Di = input_spatial_lengths[0];
const index_t Hi = input_spatial_lengths[1];
const index_t Wi = input_spatial_lengths[2];
const index_t Do = output_spatial_lengths[0];
const index_t Ho = output_spatial_lengths[1];
const index_t Wo = output_spatial_lengths[2];
const index_t ConvStrideD = conv_filter_strides[0];
const index_t ConvStrideH = conv_filter_strides[1];
const index_t ConvStrideW = conv_filter_strides[2];
if constexpr(ConvForwardSpecialization ==
ConvolutionForwardSpecialization_t::Filter1x1Stride1Pad0)
{
const auto in_gemm_m_k_grid_desc =
make_naive_tensor_descriptor_packed(make_tuple(gemm_m, gemm_k));
return in_gemm_m_k_grid_desc;
}
else if constexpr(ConvForwardSpecialization ==
ConvolutionForwardSpecialization_t::Filter1x1Pad0)
{
const auto in_n_di_hi_wi_c_grid_desc =
make_naive_tensor_descriptor_packed(make_tuple(N, Di, Hi, Wi, C));
const auto in_n_do_ho_wo_c_grid_desc = transform_tensor_descriptor(
in_n_di_hi_wi_c_grid_desc,
make_tuple(make_pass_through_transform(N),
make_embed_transform(make_tuple(Do), make_tuple(ConvStrideD)),
make_embed_transform(make_tuple(Ho), make_tuple(ConvStrideH)),
make_embed_transform(make_tuple(Wo), make_tuple(ConvStrideW)),
make_pass_through_transform(C)),
make_tuple(
Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}, Sequence<4>{}),
make_tuple(
Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}, Sequence<4>{}));
const auto in_gemm_m_k_grid_desc = transform_tensor_descriptor(
in_n_do_ho_wo_c_grid_desc,
make_tuple(make_merge_transform(make_tuple(N, Do, Ho, Wo)),
make_pass_through_transform(C)),
make_tuple(Sequence<0, 1, 2, 3>{}, Sequence<4>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}));
return in_gemm_m_k_grid_desc;
}
else
{
const index_t Z = filter_spatial_lengths[0];
const index_t Y = filter_spatial_lengths[1];
const index_t X = filter_spatial_lengths[2];
const index_t ConvDilationD = conv_filter_dilations[0];
const index_t ConvDilationH = conv_filter_dilations[1];
const index_t ConvDilationW = conv_filter_dilations[2];
const index_t InLeftPadD = input_left_pads[0];
const index_t InLeftPadH = input_left_pads[1];
const index_t InLeftPadW = input_left_pads[2];
const index_t InRightPadD = input_right_pads[0];
const index_t InRightPadH = input_right_pads[1];
const index_t InRightPadW = input_right_pads[2];
const auto in_n_di_hi_wi_c_grid_desc =
make_naive_tensor_descriptor_packed(make_tuple(N, Di, Hi, Wi, C));
const auto in_n_dip_hip_wip_c_grid_desc = transform_tensor_descriptor(
in_n_di_hi_wi_c_grid_desc,
make_tuple(make_pass_through_transform(N),
make_pad_transform(Di, InLeftPadD, InRightPadD),
make_pad_transform(Hi, InLeftPadH, InRightPadH),
make_pad_transform(Wi, InLeftPadW, InRightPadW),
make_pass_through_transform(C)),
make_tuple(
Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}, Sequence<4>{}),
make_tuple(
Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}, Sequence<4>{}));
const auto in_n_z_do_y_ho_x_wo_c_grid_desc = transform_tensor_descriptor(
in_n_dip_hip_wip_c_grid_desc,
make_tuple(
make_pass_through_transform(N),
make_embed_transform(make_tuple(Z, Do), make_tuple(ConvDilationD, ConvStrideD)),
make_embed_transform(make_tuple(Y, Ho), make_tuple(ConvDilationH, ConvStrideH)),
make_embed_transform(make_tuple(X, Wo), make_tuple(ConvDilationW, ConvStrideW)),
make_pass_through_transform(C)),
make_tuple(
Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}, Sequence<4>{}),
make_tuple(Sequence<0>{},
Sequence<1, 2>{},
Sequence<3, 4>{},
Sequence<5, 6>{},
Sequence<7>{}));
const auto in_gemm_m_k_grid_desc = transform_tensor_descriptor(
in_n_z_do_y_ho_x_wo_c_grid_desc,
make_tuple(make_merge_transform(make_tuple(N, Do, Ho, Wo)),
make_merge_transform(make_tuple(Z, Y, X, C))),
make_tuple(Sequence<0, 2, 4, 6>{}, Sequence<1, 3, 5, 7>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}));
return in_gemm_m_k_grid_desc;
}
}
static index_t GetGemmM(ck::index_t N, const std::vector<ck::index_t>& output_spatial_lengths)
{
return N * std::accumulate(std::begin(output_spatial_lengths),
std::end(output_spatial_lengths),
1,
std::multiplies<ck::index_t>());
}
static index_t GetGemmK(ck::index_t C, const std::vector<ck::index_t>& filter_spatial_lengths)
{
return C * std::accumulate(std::begin(filter_spatial_lengths),
std::end(filter_spatial_lengths),
1,
std::multiplies<ck::index_t>());
}
static index_t GetGemmN(ck::index_t K)
{
// return ck::math::integer_least_multiple(K,
// ThreadwiseGemm_Dispatch::MatrixBMinVectorSize);
return K;
}
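// Note: the commented-out integer_least_multiple in GetGemmN above suggests GemmN could be
// padded up to the micro-kernel's B vector width; with K used unchanged, IsSupportedArgument
// below presumably has to reject K not divisible by 8 when B is read without a local buffer.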
static auto MakeABCGridDescriptor(ck::index_t N,
ck::index_t K,
ck::index_t C,
std::vector<ck::index_t> input_spatial_lengths,
std::vector<ck::index_t> filter_spatial_lengths,
std::vector<ck::index_t> output_spatial_lengths,
std::vector<ck::index_t> conv_filter_strides,
std::vector<ck::index_t> conv_filter_dilations,
std::vector<ck::index_t> input_left_pads,
std::vector<ck::index_t> input_right_pads)
{
using namespace ck;
const index_t GemmM = GetGemmM(N, output_spatial_lengths);
const index_t GemmN = GetGemmN(K);
const index_t GemmK = GetGemmK(C, filter_spatial_lengths);
// A:
const auto in_gemm_m_k_grid_desc =
GetInputTensorDescriptor<NumDimSpatial>(N,
C,
GemmM,
GemmK,
input_spatial_lengths,
filter_spatial_lengths,
output_spatial_lengths,
conv_filter_strides,
conv_filter_dilations,
input_left_pads,
input_right_pads);
// B:
const auto wei_gemm_k_n_grid_desc = GetWeightTensorDescriptor(GemmK, GemmN);
// C:
const auto out_gemm_m_n_grid_desc = GetOutputTensorDescriptor(GemmM, GemmN);
return make_tuple(in_gemm_m_k_grid_desc, wei_gemm_k_n_grid_desc, out_gemm_m_n_grid_desc);
}
template <ck::index_t NDim, typename std::enable_if<NDim == 1, bool>::type = false>
static auto GetABCGridDesc()
{
return MakeABCGridDescriptor(1, 1, 1, {1}, {1}, {1}, {1}, {1}, {1}, {1});
}
template <ck::index_t NDim, typename std::enable_if<NDim == 2, bool>::type = false>
static auto GetABCGridDesc()
{
return MakeABCGridDescriptor(
1, 1, 1, {1, 1}, {1, 1}, {1, 1}, {1, 1}, {1, 1}, {1, 1}, {1, 1});
}
template <ck::index_t NDim, typename std::enable_if<NDim == 3, bool>::type = false>
static auto GetABCGridDesc()
{
return MakeABCGridDescriptor(
1, 1, 1, {1, 1, 1}, {1, 1, 1}, {1, 1, 1}, {1, 1, 1}, {1, 1, 1}, {1, 1, 1}, {1, 1, 1});
}
using ABCGridDescs = decltype(GetABCGridDesc<NumDimSpatial>());
using AGridDesc = remove_cvref_t<decltype(ABCGridDescs{}[I0])>;
using BGridDesc = remove_cvref_t<decltype(ABCGridDescs{}[I1])>;
using CGridDesc = remove_cvref_t<decltype(ABCGridDescs{}[I2])>;
using C0GridDesc = remove_cvref_t<decltype(MakeBiasTensorDescriptor(1, 1))>;
using C1GridDesc = CGridDesc;
// static constexpr bool UseCLocalBuffer = false;
using AThreadwiseCopy =
ck::cpu::ThreadwiseTensorSliceTransferAvx2Specialization_ConvFwd_In_NHWC<
ADataType,
ADataType,
AGridDesc,
decltype(GetInputBlockDescriptor()),
InElementwiseOperation,
!UseALocalBuffer,
ConvForwardSpecialization,
GemmKSpecialization>;
using BThreadwiseCopy =
ck::cpu::ThreadwiseTensorSliceTransferAvx2Specialization_ConvFwd_Wei_YXCK<
BDataType,
BDataType,
BGridDesc,
decltype(GetWeightBlockDescriptor()),
WeiElementwiseOperation,
!UseBLocalBuffer,
ConvForwardSpecialization,
GemmKSpecialization>;
using CThreadwiseCopy =
ck::cpu::ThreadwiseTensorSliceTransferAvx2Specialization_MatC_Store_Bias_Residual_MxN<
CDataType,
C0DataType,
C1DataType,
CDataType,
CGridDesc,
C0GridDesc,
C1GridDesc,
decltype(GetOutputBlockDescriptor()),
OutElementwiseOperation,
!UseCLocalBuffer,
BiasAlongGemmM>;
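// Sketch of the fused epilogue (assuming the usual bias-activation-add semantics): when a C
// tile is stored, this copy also reads the bias (C0) and residual (C1) tensors and applies
// OutElementwiseOperation to the accumulator together with those two operands, so no separate
// bias/add passes over the output are needed.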
using GridwiseGemm = ck::cpu::GridwiseGemmBiasActivationAddAvx2_MxN<
ADataType, // InDataType,
BDataType, // WeiDataType,
CDataType, // OutDataType,
C0DataType, // C0DataType
C1DataType, // C1DataType
AGridDesc, // AGridDesc,
BGridDesc, // BGridDesc,
CGridDesc, // CGridDesc,
C0GridDesc, // C0GridDesc,
C1GridDesc, // C1GridDesc,
AElementwiseOperation, // AElementwiseOperation,
BElementwiseOperation, // BElementwiseOperation,
CElementwiseOperation, // CElementwiseOperation,
MPerBlock, // MPerBlock,
NPerBlock, // NPerBlock,
KPerBlock, // KPerBlock,
ThreadwiseGemm_Dispatch, // ThreadwiseGemm_Dispatch,
AThreadwiseCopy, // AThreadwiseCopy
BThreadwiseCopy, // BThreadwiseCopy
CThreadwiseCopy, // CThreadwiseCopy
BlockMNKAccessOrder, // BlockMNKAccessOrder,
ck::Sequence<0, 1>, // ThreadMNAccessOrder
UseALocalBuffer, // UseALocalBuffer
UseBLocalBuffer, // UseBLocalBuffer
UseCLocalBuffer // UseCLocalBuffer
>;
// Argument
struct Argument : public BaseArgument
{
Argument(const InDataType* p_in_grid,
const WeiDataType* p_wei_grid,
OutDataType* p_out_grid,
const BiasDataType* p_bias_grid,
const AddDataType* p_add_grid,
ck::index_t N,
ck::index_t K,
ck::index_t C,
std::vector<ck::index_t> input_spatial_lengths,
std::vector<ck::index_t> filter_spatial_lengths,
std::vector<ck::index_t> output_spatial_lengths,
std::vector<ck::index_t> conv_filter_strides,
std::vector<ck::index_t> conv_filter_dilations,
std::vector<ck::index_t> input_left_pads,
std::vector<ck::index_t> input_right_pads,
InElementwiseOperation in_element_op,
WeiElementwiseOperation wei_element_op,
OutElementwiseOperation out_element_op)
: p_a_grid_{p_in_grid},
p_b_grid_{p_wei_grid},
p_c_grid_{p_out_grid},
p_c0_grid_{p_bias_grid},
p_c1_grid_{p_add_grid},
a_grid_desc_{},
b_grid_desc_{},
c_grid_desc_{},
c0_grid_desc_{},
c1_grid_desc_{},
a_element_op_{in_element_op},
b_element_op_{wei_element_op},
c_element_op_{out_element_op},
Conv_N_{N},
Conv_K_{K},
Conv_C_{C},
filter_spatial_lengths_{filter_spatial_lengths},
conv_filter_strides_{conv_filter_strides},
input_left_pads_{input_left_pads},
input_right_pads_{input_right_pads}
{
const auto descs = DeviceOp::MakeABCGridDescriptor(N,
K,
C,
input_spatial_lengths,
filter_spatial_lengths,
output_spatial_lengths,
conv_filter_strides,
conv_filter_dilations,
input_left_pads,
input_right_pads);
a_grid_desc_ = descs[I0];
b_grid_desc_ = descs[I1];
c_grid_desc_ = descs[I2];
c0_grid_desc_ = DeviceOp::MakeBiasTensorDescriptor(GetGemmM(N, output_spatial_lengths),
GetGemmN(K));
c1_grid_desc_ = descs[I2];
}
// private:
const ADataType* p_a_grid_;
const BDataType* p_b_grid_;
CDataType* p_c_grid_;
const C0DataType* p_c0_grid_;
const C1DataType* p_c1_grid_;
AGridDesc a_grid_desc_;
BGridDesc b_grid_desc_;
CGridDesc c_grid_desc_;
C0GridDesc c0_grid_desc_;
C1GridDesc c1_grid_desc_;
AElementwiseOperation a_element_op_;
BElementwiseOperation b_element_op_;
CElementwiseOperation c_element_op_;
// for checking IsSupportedArgument()
index_t Conv_N_;
index_t Conv_K_;
index_t Conv_C_;
std::vector<index_t> filter_spatial_lengths_;
std::vector<index_t> conv_filter_strides_;
std::vector<index_t> input_left_pads_;
std::vector<index_t> input_right_pads_;
};
// Invoker
struct Invoker : public BaseInvoker
{
using Argument = DeviceOp::Argument;
float Run(const Argument& arg,
const StreamConfig& stream_config = StreamConfig{},
int nrepeat = 1)
{
if(!GridwiseGemm::CheckValidity(arg.a_grid_desc_, arg.b_grid_desc_, arg.c_grid_desc_))
{
throw std::runtime_error("wrong! GridwiseGemmBiasActivationAddAvx2_MxN has invalid setting");
}
memset(arg.p_c_grid_, 0, arg.c_grid_desc_.GetElementSpaceSize() * sizeof(CDataType));
const auto kernel =
ck::cpu::kernel_gemm_bias_activation_add_avx_mxn<GridwiseGemm,
ADataType,
BDataType,
CDataType,
C0DataType,
C1DataType,
AGridDesc,
BGridDesc,
CGridDesc,
C0GridDesc,
C1GridDesc,
AElementwiseOperation,
BElementwiseOperation,
CElementwiseOperation>;
float ave_time = 0;
if(nrepeat != 1)
ave_time = launch_and_time_cpu_kernel(kernel,
nrepeat,
arg.p_a_grid_,
arg.p_b_grid_,
arg.p_c_grid_,
arg.p_c0_grid_,
arg.p_c1_grid_,
arg.a_grid_desc_,
arg.b_grid_desc_,
arg.c_grid_desc_,
arg.c0_grid_desc_,
arg.c1_grid_desc_,
arg.a_element_op_,
arg.b_element_op_,
arg.c_element_op_);
// NOTE: for benchmarking, we clear the C buffer one final time and run the kernel once more
// so that the returned buffer holds a freshly computed result.
memset(arg.p_c_grid_, 0, arg.c_grid_desc_.GetElementSpaceSize() * sizeof(CDataType));
launch_cpu_kernel(kernel,
arg.p_a_grid_,
arg.p_b_grid_,
arg.p_c_grid_,
arg.p_c0_grid_,
arg.p_c1_grid_,
arg.a_grid_desc_,
arg.b_grid_desc_,
arg.c_grid_desc_,
arg.c0_grid_desc_,
arg.c1_grid_desc_,
arg.a_element_op_,
arg.b_element_op_,
arg.c_element_op_);
return ave_time;
}
float Run(const BaseArgument* p_arg,
const StreamConfig& stream_config = StreamConfig{},
int nrepeat = 1) override
{
return Run(*dynamic_cast<const Argument*>(p_arg), stream_config, nrepeat);
}
};
static constexpr bool IsValidCompilationParameter()
{
// TODO: properly implement this check
return true;
}
static bool IsSupportedArgument(const Argument& arg)
{
if constexpr(ConvForwardSpecialization ==
ConvolutionForwardSpecialization_t::Filter1x1Stride1Pad0)
{
// check if it's 1x1, stride=1 conv
if(!(arg.filter_spatial_lengths_[0] == 1 && arg.filter_spatial_lengths_[1] == 1 &&
arg.conv_filter_strides_[0] == 1 && arg.conv_filter_strides_[1] == 1 &&
arg.input_left_pads_[0] == 0 && arg.input_left_pads_[1] == 0 &&
arg.input_right_pads_[0] == 0 && arg.input_right_pads_[1] == 0))
{
return false;
}
}
else if constexpr(ConvForwardSpecialization ==
ConvolutionForwardSpecialization_t::Filter1x1Pad0)
{
// check if it's 1x1 conv
if(!(arg.filter_spatial_lengths_[0] == 1 && arg.filter_spatial_lengths_[1] == 1 &&
arg.input_left_pads_[0] == 0 && arg.input_left_pads_[1] == 0 &&
arg.input_right_pads_[0] == 0 && arg.input_right_pads_[1] == 0))
{
return false;
}
}
if constexpr(GemmKSpecialization ==
ConvolutionForwardGemmKSpecialization_t::NHWC_GemmKLoopOverC &&
ConvForwardSpecialization !=
ConvolutionForwardSpecialization_t::Filter1x1Stride1Pad0)
{
if(!(arg.Conv_C_ % KPerBlock == 0))
return false;
}
if constexpr(!UseALocalBuffer &&
ConvForwardSpecialization !=
ConvolutionForwardSpecialization_t::Filter1x1Stride1Pad0)
{
// TODO: this case could be supported in the future, once we figure out how to express the
// required tensor transform.
return false;
}
if constexpr(!UseBLocalBuffer)
{
if(!(arg.Conv_K_ % 8 == 0))
return false;
}
// Gridwise GEMM size
return GridwiseGemm::CheckValidity(arg.a_grid_desc_, arg.b_grid_desc_, arg.c_grid_desc_);
}
bool IsSupportedArgument(const BaseArgument* p_arg) override
{
return IsSupportedArgument(*dynamic_cast<const Argument*>(p_arg));
}
static auto MakeArgument(const InDataType* p_in_grid,
const WeiDataType* p_wei_grid,
OutDataType* p_out_grid,
const BiasDataType* p_bias_grid,
const AddDataType* p_add_grid,
ck::index_t N,
ck::index_t K,
ck::index_t C,
std::vector<ck::index_t> input_spatial_lengths,
std::vector<ck::index_t> filter_spatial_lengths,
std::vector<ck::index_t> output_spatial_lengths,
std::vector<ck::index_t> conv_filter_strides,
std::vector<ck::index_t> conv_filter_dilations,
std::vector<ck::index_t> input_left_pads,
std::vector<ck::index_t> input_right_pads,
InElementwiseOperation in_element_op,
WeiElementwiseOperation wei_element_op,
OutElementwiseOperation out_element_op)
{
return Argument{p_in_grid,
p_wei_grid,
p_out_grid,
p_bias_grid,
p_add_grid,
N,
K,
C,
input_spatial_lengths,
filter_spatial_lengths,
output_spatial_lengths,
conv_filter_strides,
conv_filter_dilations,
input_left_pads,
input_right_pads,
in_element_op,
wei_element_op,
out_element_op};
}
static auto MakeInvoker() { return Invoker{}; }
std::unique_ptr<BaseArgument>
MakeArgumentPointer(const void* p_in_grid,
const void* p_wei_grid,
void* p_out_grid,
const void* p_bias_grid,
const void* p_add_grid,
ck::index_t N,
ck::index_t K,
ck::index_t C,
std::vector<ck::index_t> input_spatial_lengths,
std::vector<ck::index_t> filter_spatial_lengths,
std::vector<ck::index_t> output_spatial_lengths,
std::vector<ck::index_t> conv_filter_strides,
std::vector<ck::index_t> conv_filter_dilations,
std::vector<ck::index_t> input_left_pads,
std::vector<ck::index_t> input_right_pads,
InElementwiseOperation in_element_op,
WeiElementwiseOperation wei_element_op,
OutElementwiseOperation out_element_op) override
{
return std::make_unique<Argument>(static_cast<const InDataType*>(p_in_grid),
static_cast<const WeiDataType*>(p_wei_grid),
static_cast<OutDataType*>(p_out_grid),
static_cast<const BiasDataType*>(p_bias_grid),
static_cast<const AddDataType*>(p_add_grid),
N,
K,
C,
input_spatial_lengths,
filter_spatial_lengths,
output_spatial_lengths,
conv_filter_strides,
conv_filter_dilations,
input_left_pads,
input_right_pads,
in_element_op,
wei_element_op,
out_element_op);
}
std::unique_ptr<BaseInvoker> MakeInvokerPointer() override
{
return std::make_unique<Invoker>(Invoker{});
}
std::string GetTypeString() const override
{
auto str = std::stringstream();
auto string_local_buffer = [](bool is_local_buffer) {
if(is_local_buffer)
return "L";
else
return "G";
};
// clang-format off
str << "DeviceConv" << std::to_string(NumDimSpatial)
<< "DFwd_BAA_Avx2_NHWC_YXCK"
<<"_FS"<< static_cast<int>(ConvForwardSpecialization)
<<"_KS"<< static_cast<int>(GemmKSpecialization)
<<"_BS"<< static_cast<int>(BlockLoopOverSpecialization)
<< "_BT" << MPerBlock << "x" << NPerBlock << "x" << KPerBlock
<< "_TT" << MPerThread << "x" << NPerThread
<< "_A" << string_local_buffer(UseALocalBuffer)
<< "_B" << string_local_buffer(UseBLocalBuffer)
<< "_C" << string_local_buffer(UseCLocalBuffer)
;
if constexpr (!std::is_same<OutElementwiseOperation,
ck::tensor_operation::cpu::element_wise::PassThrough>::value)
{
str << "_" << OutElementwiseOperation::Name();
}
// clang-format on
return str.str();
}
};
} // namespace device
} // namespace cpu
} // namespace tensor_operation
} // namespace ck
#endif
@@ -81,12 +81,8 @@ struct ThreadwiseGemmAvx2_MxN_6x16
"movq (%[m_param]), %%rax\n" // p_a "movq (%[m_param]), %%rax\n" // p_a
"movq 8(%[m_param]), %%rbx\n" // p_b "movq 8(%[m_param]), %%rbx\n" // p_b
"movq 24(%[m_param]), %%rsi\n" // Kr "movq 24(%[m_param]), %%rsi\n" // Kr
".if m_TransA != 0\n"
"movq 32(%[m_param]), %%rcx\n" // lda "movq 32(%[m_param]), %%rcx\n" // lda
".endif\n"
".if m_TransB == 0\n"
"movq 40(%[m_param]), %%rdx\n" // ldb "movq 40(%[m_param]), %%rdx\n" // ldb
".endif\n"
".macro vbroadcastss_%= r_base, r_stride, i_scale, i_offset, ymm\n" ".macro vbroadcastss_%= r_base, r_stride, i_scale, i_offset, ymm\n"
".if \\i_scale != 0\n" ".if \\i_scale != 0\n"
@@ -120,10 +116,14 @@ struct ThreadwiseGemmAvx2_MxN_6x16
".endif\n" ".endif\n"
".endm\n" ".endm\n"
".macro vbroadcast_a%= i_k, i_m, ymm\n" // A in rax(r8, r9), lda in rcx ".macro vbroadcast_a%= i_k, i_m, ymm\n" // A in rax(r8), lda in rcx
".if m_ABytes == 4\n" ".if m_ABytes == 4\n"
".if m_TransA == 0\n" ".if m_TransA == 0\n"
"vbroadcastss_%= %%rax, 0, 0, ((\\i_m + \\i_k * m_Mr) * m_ABytes), \\ymm\n" ".if (\\i_k == 0) || (\\i_k == 1) || (\\i_k == 2)\n"
"vbroadcastss_%= %%rax, %%rcx, \\i_k, (\\i_m * m_ABytes), \\ymm\n"
".else\n"
"vbroadcastss_%= %%r8, %%rcx, (\\i_k-3), (\\i_m * m_ABytes), \\ymm\n"
".endif\n"
".else\n" ".else\n"
".if (\\i_m == 0) || (\\i_m == 1) || (\\i_m == 2)\n" ".if (\\i_m == 0) || (\\i_m == 1) || (\\i_m == 2)\n"
"vbroadcastss_%= %%rax, %%rcx, \\i_m, (\\i_k * m_ABytes), \\ymm\n" "vbroadcastss_%= %%rax, %%rcx, \\i_m, (\\i_k * m_ABytes), \\ymm\n"
@@ -133,7 +133,11 @@ struct ThreadwiseGemmAvx2_MxN_6x16
".endif\n" ".endif\n"
".else\n" ".else\n"
".if m_TransA == 0\n" ".if m_TransA == 0\n"
"vpbroadcastw_%= %%rax, 0, 0, ((\\i_m + \\i_k * m_Mr) * m_ABytes), %%xmm15\n" ".if (\\i_k == 0) || (\\i_k == 1) || (\\i_k == 2)\n"
"vpbroadcastw_%= %%rax, %%rcx, \\i_k, (\\i_m * m_ABytes), %%xmm15\n"
".else\n"
"vpbroadcastw_%= %%rax, %%rcx, (\\i_k-3), (\\i_m * m_ABytes), %%xmm15\n"
".endif\n"
".else\n" ".else\n"
".if (\\i_m == 0) || (\\i_m == 1) || (\\i_m == 2)\n" ".if (\\i_m == 0) || (\\i_m == 1) || (\\i_m == 2)\n"
"vpbroadcastw_%= %%rax, %%rcx, \\i_m, (\\i_k * m_ABytes), %%xmm15\n" "vpbroadcastw_%= %%rax, %%rcx, \\i_m, (\\i_k * m_ABytes), %%xmm15\n"
@@ -145,18 +149,26 @@ struct ThreadwiseGemmAvx2_MxN_6x16
".endif\n" ".endif\n"
".endm\n" ".endm\n"
".macro vload_b%= i_k, i_n, ymm\n" // B in rbx, lda in rdx, i_n should be 0, 1 ".macro vload_b%= i_k, i_n, ymm\n" // B in rbx(r9), lda in rdx, i_n should be 0, 1
".if m_BBytes == 4\n" ".if m_BBytes == 4\n"
".if m_TransB == 0\n" ".if m_TransB == 0\n"
"vmovups_%= %%rbx, %%rdx, \\i_n, (\\i_k*m_BBytes*8), \\ymm\n" "vmovups_%= %%rbx, %%rdx, \\i_n, (\\i_k*m_BBytes*8), \\ymm\n"
".else\n" ".else\n"
"vmovups_%= %%rbx, 0, 0, ((\\i_k*m_Nr + \\i_n*8)*m_BBytes), \\ymm\n" ".if (\\i_k == 0) || (\\i_k == 1) || (\\i_k == 2)\n"
"vmovups_%= %%rbx, %%rdx, \\i_k, (\\i_n*m_BBytes*8), \\ymm\n"
".else\n"
"vmovups_%= %%r9, %%rdx, (\\i_k-3), (\\i_n*m_BBytes*8), \\ymm\n"
".endif\n"
".endif\n" ".endif\n"
".else\n" ".else\n"
".if m_TransB == 0\n" ".if m_TransB == 0\n"
"vcvtph2ps_%= %%rbx, %%rdx, \\i_n, (\\i_k*m_BBytes*8), \\ymm\n" "vcvtph2ps_%= %%rbx, %%rdx, \\i_n, (\\i_k*m_BBytes*8), \\ymm\n"
".else\n" ".else\n"
"vcvtph2ps_%= %%rbx, 0, 0, ((\\i_k*m_Nr + \\i_n*8)*m_BBytes), \\ymm\n" ".if (\\i_k == 0) || (\\i_k == 1) || (\\i_k == 2)\n"
"vcvtph2ps_%= %%rbx, %%rdx, \\i_k, (\\i_n*m_BBytes*8), \\ymm\n"
".else\n"
"vcvtph2ps_%= %%r9, %%rdx, (\\i_k-3), (\\i_n*m_BBytes*8), \\ymm\n"
".endif\n"
".endif\n" ".endif\n"
".endif\n" ".endif\n"
".endm\n" ".endm\n"
@@ -179,6 +191,13 @@ struct ThreadwiseGemmAvx2_MxN_6x16
"lea (%%rcx, %%rcx, 2), %%r9\n" "lea (%%rcx, %%rcx, 2), %%r9\n"
"lea (%%rax, %%r9), %%r8\n" "lea (%%rax, %%r9), %%r8\n"
".endif\n" ".endif\n"
".else\n"
"lea (%%rcx, %%rcx, 2), %%r9\n"
"lea (%%rax, %%r9), %%r8\n"
".endif\n"
".if m_TransB != 0\n"
"lea (%%rdx, %%rdx, 2), %%rdi\n"
"lea (%%rbx, %%rdi), %%r9\n"
".endif\n" ".endif\n"
"cmp $4, %%rsi\n" "cmp $4, %%rsi\n"
@@ -214,10 +233,12 @@ struct ThreadwiseGemmAvx2_MxN_6x16
" lea 4*m_ABytes(%%rax), %%rax\n" " lea 4*m_ABytes(%%rax), %%rax\n"
".if m_Mr > 3\n lea 4*m_ABytes(%%r8), %%r8\n .endif\n" ".if m_Mr > 3\n lea 4*m_ABytes(%%r8), %%r8\n .endif\n"
".else\n" ".else\n"
" lea m_Mr * 4 * m_ABytes(%%rax), %%rax\n" " lea (%%rax, %%rcx, 4), %%rax\n"
" lea (%%r8, %%rcx, 4), %%r8\n"
".endif\n" ".endif\n"
".if m_TransB != 0\n" ".if m_TransB != 0\n"
" lea m_Nr * 4 * m_BBytes(%%rbx), %%rbx\n" " lea (%%rbx, %%rdx, 4), %%rbx\n"
" lea (%%r9, %%rdx, 4), %%r9\n"
".else\n" ".else\n"
" lea 8 * 4 * m_BBytes(%%rbx), %%rbx\n" " lea 8 * 4 * m_BBytes(%%rbx), %%rbx\n"
".endif\n" ".endif\n"
@@ -256,10 +277,12 @@ struct ThreadwiseGemmAvx2_MxN_6x16
" lea m_ABytes(%%rax), %%rax\n" " lea m_ABytes(%%rax), %%rax\n"
".if m_Mr > 3\n lea m_ABytes(%%r8), %%r8\n .endif\n" ".if m_Mr > 3\n lea m_ABytes(%%r8), %%r8\n .endif\n"
".else\n" ".else\n"
" lea m_Mr * m_ABytes(%%rax), %%rax\n" " lea (%%rax, %%rcx, 1), %%rax\n"
" lea (%%r8, %%rcx, 1), %%r8\n"
".endif\n" ".endif\n"
".if m_TransB != 0\n" ".if m_TransB != 0\n"
" lea m_Nr * m_BBytes(%%rbx), %%rbx\n" " lea (%%rbx, %%rdx, 1), %%rbx\n"
" lea (%%r9, %%rdx, 1), %%r9\n"
".else\n" ".else\n"
" lea 8*m_BBytes(%%rbx), %%rbx\n" " lea 8*m_BBytes(%%rbx), %%rbx\n"
".endif\n" ".endif\n"
@@ -381,7 +404,7 @@ struct ThreadwiseGemmAvx2_MxN_6x16
} }
else else
{ {
ymm = _mm256_broadcast_ss(p_a + i_k * Mr + i_m); ymm = _mm256_broadcast_ss(p_a + i_k * lda + i_m);
} }
} }
else else
@@ -396,7 +419,7 @@ struct ThreadwiseGemmAvx2_MxN_6x16
} }
else else
{ {
ymm = _mm256_cvtph_ps(_mm_set1_epi16(*(p_a + i_k * Mr + i_m))); ymm = _mm256_cvtph_ps(_mm_set1_epi16(*(p_a + i_k * lda + i_m)));
} }
} }
}; };
@@ -406,7 +429,7 @@ struct ThreadwiseGemmAvx2_MxN_6x16
{ {
if constexpr(std::is_same<ck::tensor_layout::gemm::RowMajor, BLayout>::value) if constexpr(std::is_same<ck::tensor_layout::gemm::RowMajor, BLayout>::value)
{ {
ymm = _mm256_loadu_ps(p_b + i_k * Nr + i_n * 8); ymm = _mm256_loadu_ps(p_b + i_k * ldb + i_n * 8);
} }
else else
{ {
@@ -418,7 +441,7 @@ struct ThreadwiseGemmAvx2_MxN_6x16
if constexpr(std::is_same<ck::tensor_layout::gemm::RowMajor, BLayout>::value) if constexpr(std::is_same<ck::tensor_layout::gemm::RowMajor, BLayout>::value)
{ {
ymm = _mm256_cvtph_ps(_mm_loadu_si128( ymm = _mm256_cvtph_ps(_mm_loadu_si128(
reinterpret_cast<__m128i const*>(p_b + i_k * Nr + i_n * 8))); reinterpret_cast<__m128i const*>(p_b + i_k * ldb + i_n * 8)));
} }
else else
{ {
@@ -488,10 +511,10 @@ struct ThreadwiseGemmAvx2_MxN_6x16
if constexpr(std::is_same<ck::tensor_layout::gemm::RowMajor, ALayout>::value){ if constexpr(std::is_same<ck::tensor_layout::gemm::RowMajor, ALayout>::value){
p_a += 4; p_a += 4;
} else{ } else{
p_a += Mr * 4; p_a += lda * 4;
} }
if constexpr(std::is_same<ck::tensor_layout::gemm::RowMajor, BLayout>::value){ if constexpr(std::is_same<ck::tensor_layout::gemm::RowMajor, BLayout>::value){
p_b += Nr * 4; p_b += ldb * 4;
}else{ }else{
p_b += 4 * 8; p_b += 4 * 8;
} }
@@ -525,10 +548,10 @@ struct ThreadwiseGemmAvx2_MxN_6x16
if constexpr(std::is_same<ck::tensor_layout::gemm::RowMajor, ALayout>::value){ if constexpr(std::is_same<ck::tensor_layout::gemm::RowMajor, ALayout>::value){
p_a += 1; p_a += 1;
} else{ } else{
p_a += Mr * 1; p_a += lda * 1;
} }
if constexpr(std::is_same<ck::tensor_layout::gemm::RowMajor, BLayout>::value){ if constexpr(std::is_same<ck::tensor_layout::gemm::RowMajor, BLayout>::value){
p_b += Nr * 1; p_b += ldb * 1;
}else{ }else{
p_b += 1 * 8; p_b += 1 * 8;
} }
@@ -641,12 +664,8 @@ struct ThreadwiseGemmAvx2_MxN_4x24
"movq (%[m_param]), %%rax\n" // p_a "movq (%[m_param]), %%rax\n" // p_a
"movq 8(%[m_param]), %%rbx\n" // p_b "movq 8(%[m_param]), %%rbx\n" // p_b
"movq 24(%[m_param]), %%rsi\n" // Kr "movq 24(%[m_param]), %%rsi\n" // Kr
".if m_TransA != 0\n"
"movq 32(%[m_param]), %%rcx\n" // lda "movq 32(%[m_param]), %%rcx\n" // lda
".endif\n"
".if m_TransB == 0\n"
"movq 40(%[m_param]), %%rdx\n" // ldb "movq 40(%[m_param]), %%rdx\n" // ldb
".endif\n"
".macro vbroadcastss_%= r_base, r_stride, i_scale, i_offset, ymm\n" ".macro vbroadcastss_%= r_base, r_stride, i_scale, i_offset, ymm\n"
".if \\i_scale != 0\n" ".if \\i_scale != 0\n"
@@ -683,7 +702,11 @@ struct ThreadwiseGemmAvx2_MxN_4x24
".macro vbroadcast_a%= i_k, i_m, ymm\n" // A in rax(r8), lda in rcx ".macro vbroadcast_a%= i_k, i_m, ymm\n" // A in rax(r8), lda in rcx
".if m_ABytes == 4\n" ".if m_ABytes == 4\n"
".if m_TransA == 0\n" ".if m_TransA == 0\n"
"vbroadcastss_%= %%rax, 0, 0, ((\\i_m + \\i_k * m_Mr) * m_ABytes), \\ymm\n" ".if (\\i_k == 0) || (\\i_k == 1)\n"
"vbroadcastss_%= %%rax, %%rcx, \\i_k, (\\i_m * m_ABytes), \\ymm\n"
".else\n"
"vbroadcastss_%= %%r8, %%rcx, (\\i_k-2), (\\i_m * m_ABytes), \\ymm\n"
".endif\n"
".else\n" ".else\n"
".if (\\i_m == 0) || (\\i_m == 1)\n" ".if (\\i_m == 0) || (\\i_m == 1)\n"
"vbroadcastss_%= %%rax, %%rcx, \\i_m, (\\i_k * m_ABytes), \\ymm\n" "vbroadcastss_%= %%rax, %%rcx, \\i_m, (\\i_k * m_ABytes), \\ymm\n"
@@ -693,7 +716,11 @@ struct ThreadwiseGemmAvx2_MxN_4x24
".endif\n" ".endif\n"
".else\n" ".else\n"
".if m_TransA == 0\n" ".if m_TransA == 0\n"
"vpbroadcastw_%= %%rax, 0, 0, ((\\i_m + \\i_k * m_Mr) * m_ABytes), %%xmm15\n" ".if (\\i_k == 0) || (\\i_k == 1)\n"
"vpbroadcastw_%= %%rax, %%rcx, \\i_k, (\\i_m * m_ABytes), %%xmm15\n"
".else\n"
"vpbroadcastw_%= %%r8, %%rcx, (\\i_k-2), (\\i_m * m_ABytes), %%xmm15\n"
".endif\n"
".else\n" ".else\n"
".if (\\i_m == 0) || (\\i_m == 1)\n" ".if (\\i_m == 0) || (\\i_m == 1)\n"
"vpbroadcastw_%= %%rax, %%rcx, \\i_m, (\\i_k * m_ABytes), %%xmm15\n" "vpbroadcastw_%= %%rax, %%rcx, \\i_m, (\\i_k * m_ABytes), %%xmm15\n"
@@ -710,13 +737,21 @@ struct ThreadwiseGemmAvx2_MxN_4x24
".if m_TransB == 0\n" ".if m_TransB == 0\n"
"vmovups_%= %%rbx, %%rdx, \\i_n, (\\i_k*m_BBytes*8), \\ymm\n" "vmovups_%= %%rbx, %%rdx, \\i_n, (\\i_k*m_BBytes*8), \\ymm\n"
".else\n" ".else\n"
"vmovups_%= %%rbx, 0, 0, ((\\i_k*m_Nr + \\i_n*8)*m_BBytes), \\ymm\n" ".if (\\i_k == 0) || (\\i_k == 1)\n"
"vmovups_%= %%rbx, %%rdx, \\i_k, (\\i_n*8*m_BBytes), \\ymm\n"
".else\n"
"vmovups_%= %%rdi, %%rdx, (\\i_k-2), (\\i_n*8*m_BBytes), \\ymm\n"
".endif\n"
".endif\n" ".endif\n"
".else\n" ".else\n"
".if m_TransB == 0\n" ".if m_TransB == 0\n"
"vcvtph2ps_%= %%rbx, %%rdx, \\i_n, (\\i_k*m_BBytes*8), \\ymm\n" "vcvtph2ps_%= %%rbx, %%rdx, \\i_n, (\\i_k*m_BBytes*8), \\ymm\n"
".else\n" ".else\n"
"vcvtph2ps_%= %%rbx, 0, 0, ((\\i_k*m_Nr + \\i_n*8)*m_BBytes), \\ymm\n" ".if (\\i_k == 0) || (\\i_k == 1)\n"
"vcvtph2ps_%= %%rbx, %%rdx, \\i_k, (\\i_n*8*m_BBytes), \\ymm\n"
".else\n"
"vcvtph2ps_%= %%rdi, %%rdx, (\\i_k-2), (\\i_n*8*m_BBytes), \\ymm\n"
".endif\n"
".endif\n" ".endif\n"
".endif\n" ".endif\n"
".endm\n" ".endm\n"
@@ -738,6 +773,11 @@ struct ThreadwiseGemmAvx2_MxN_4x24
".if m_Mr > 2\n" ".if m_Mr > 2\n"
"lea (%%rax, %%rcx, 2), %%r8\n" "lea (%%rax, %%rcx, 2), %%r8\n"
".endif\n" ".endif\n"
".else\n"
"lea (%%rax, %%rcx, 2), %%r8\n"
".endif\n"
".if m_TransB != 0\n"
"lea (%%rbx, %%rdx, 2), %%rdi\n"
".endif\n" ".endif\n"
"cmp $4, %%rsi\n" "cmp $4, %%rsi\n"
@@ -773,10 +813,12 @@ struct ThreadwiseGemmAvx2_MxN_4x24
" lea 4*m_ABytes(%%rax), %%rax\n" " lea 4*m_ABytes(%%rax), %%rax\n"
".if m_Mr > 2\n lea 4*m_ABytes(%%r8), %%r8\n .endif\n" ".if m_Mr > 2\n lea 4*m_ABytes(%%r8), %%r8\n .endif\n"
".else\n" ".else\n"
" lea m_Mr * 4 * m_ABytes(%%rax), %%rax\n" " lea (%%rax, %%rcx, 4), %%rax\n"
" lea (%%r8, %%rcx, 4), %%r8\n"
".endif\n" ".endif\n"
".if m_TransB != 0\n" ".if m_TransB != 0\n"
" lea m_Nr * 4 * m_BBytes(%%rbx), %%rbx\n" " lea (%%rbx, %%rdx, 4), %%rbx\n"
" lea (%%rdi, %%rdx, 4), %%rdi\n"
".else\n" ".else\n"
" lea 8 * 4 * m_BBytes(%%rbx), %%rbx\n" " lea 8 * 4 * m_BBytes(%%rbx), %%rbx\n"
".endif\n" ".endif\n"
...@@ -815,10 +857,12 @@ struct ThreadwiseGemmAvx2_MxN_4x24 ...@@ -815,10 +857,12 @@ struct ThreadwiseGemmAvx2_MxN_4x24
" lea m_ABytes(%%rax), %%rax\n" " lea m_ABytes(%%rax), %%rax\n"
".if m_Mr > 3\n lea m_ABytes(%%r8), %%r8\n .endif\n" ".if m_Mr > 3\n lea m_ABytes(%%r8), %%r8\n .endif\n"
".else\n" ".else\n"
" lea m_Mr * m_ABytes(%%rax), %%rax\n" " lea (%%rax, %%rcx, 1), %%rax\n"
" lea (%%r8, %%rcx, 1), %%r8\n"
".endif\n" ".endif\n"
".if m_TransB != 0\n" ".if m_TransB != 0\n"
" lea m_Nr * m_BBytes(%%rbx), %%rbx\n" " lea (%%rbx, %%rdx, 1), %%rbx\n"
" lea (%%rdi, %%rdx, 1), %%rdi\n"
".else\n" ".else\n"
" lea 8*m_BBytes(%%rbx), %%rbx\n" " lea 8*m_BBytes(%%rbx), %%rbx\n"
".endif\n" ".endif\n"
...@@ -937,7 +981,7 @@ struct ThreadwiseGemmAvx2_MxN_4x24 ...@@ -937,7 +981,7 @@ struct ThreadwiseGemmAvx2_MxN_4x24
} }
else else
{ {
ymm = _mm256_broadcast_ss(p_a + i_k * Mr + i_m); ymm = _mm256_broadcast_ss(p_a + i_k * lda + i_m);
} }
} }
else else
...@@ -952,7 +996,7 @@ struct ThreadwiseGemmAvx2_MxN_4x24 ...@@ -952,7 +996,7 @@ struct ThreadwiseGemmAvx2_MxN_4x24
} }
else else
{ {
ymm = _mm256_cvtph_ps(_mm_set1_epi16(*(p_a + i_k * Mr + i_m))); ymm = _mm256_cvtph_ps(_mm_set1_epi16(*(p_a + i_k * lda + i_m)));
} }
} }
}; };
...@@ -962,7 +1006,7 @@ struct ThreadwiseGemmAvx2_MxN_4x24 ...@@ -962,7 +1006,7 @@ struct ThreadwiseGemmAvx2_MxN_4x24
{ {
if constexpr(std::is_same<ck::tensor_layout::gemm::RowMajor, BLayout>::value) if constexpr(std::is_same<ck::tensor_layout::gemm::RowMajor, BLayout>::value)
{ {
ymm = _mm256_loadu_ps(p_b + i_k * Nr + i_n * 8); ymm = _mm256_loadu_ps(p_b + i_k * ldb + i_n * 8);
} }
else else
{ {
...@@ -974,7 +1018,7 @@ struct ThreadwiseGemmAvx2_MxN_4x24 ...@@ -974,7 +1018,7 @@ struct ThreadwiseGemmAvx2_MxN_4x24
if constexpr(std::is_same<ck::tensor_layout::gemm::RowMajor, BLayout>::value) if constexpr(std::is_same<ck::tensor_layout::gemm::RowMajor, BLayout>::value)
{ {
ymm = _mm256_cvtph_ps(_mm_loadu_si128( ymm = _mm256_cvtph_ps(_mm_loadu_si128(
reinterpret_cast<__m128i const*>(p_b + i_k * Nr + i_n * 8))); reinterpret_cast<__m128i const*>(p_b + i_k * ldb + i_n * 8)));
} }
else else
{ {
...@@ -1044,10 +1088,10 @@ struct ThreadwiseGemmAvx2_MxN_4x24 ...@@ -1044,10 +1088,10 @@ struct ThreadwiseGemmAvx2_MxN_4x24
if constexpr(std::is_same<ck::tensor_layout::gemm::RowMajor, ALayout>::value){ if constexpr(std::is_same<ck::tensor_layout::gemm::RowMajor, ALayout>::value){
p_a += 4; p_a += 4;
} else{ } else{
p_a += Mr * 4; p_a += lda * 4;
} }
if constexpr(std::is_same<ck::tensor_layout::gemm::RowMajor, BLayout>::value){ if constexpr(std::is_same<ck::tensor_layout::gemm::RowMajor, BLayout>::value){
p_b += Nr * 4; p_b += ldb * 4;
}else{ }else{
p_b += 4 * 8; p_b += 4 * 8;
} }
...@@ -1081,10 +1125,10 @@ struct ThreadwiseGemmAvx2_MxN_4x24 ...@@ -1081,10 +1125,10 @@ struct ThreadwiseGemmAvx2_MxN_4x24
if constexpr(std::is_same<ck::tensor_layout::gemm::RowMajor, ALayout>::value){ if constexpr(std::is_same<ck::tensor_layout::gemm::RowMajor, ALayout>::value){
p_a += 1; p_a += 1;
} else{ } else{
p_a += Mr * 1; p_a += lda * 1;
} }
if constexpr(std::is_same<ck::tensor_layout::gemm::RowMajor, BLayout>::value){ if constexpr(std::is_same<ck::tensor_layout::gemm::RowMajor, BLayout>::value){
p_b += Nr * 1; p_b += ldb * 1;
}else{ }else{
p_b += 1 * 8; p_b += 1 * 8;
} }
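For reference, the intrinsics-path equivalent of the new strided addressing can be sketched as below. This is only an illustration (the helper names are not part of the kernel); it assumes float data and runtime leading dimensions lda/ldb, matching the p_a + i_k * lda + i_m and p_b + i_k * ldb + i_n * 8 expressions in the hunks above.

#include <immintrin.h>

// broadcast a(i_m, i_k) when A is column-major: consecutive k indices are lda floats apart
static inline __m256 broadcast_a_colmajor(const float* p_a, long lda, int i_k, int i_m)
{
    return _mm256_broadcast_ss(p_a + i_k * lda + i_m);
}

// load b(i_k, i_n*8 .. i_n*8+7) when B is row-major with leading dimension ldb
static inline __m256 load_b_rowmajor(const float* p_b, long ldb, int i_k, int i_n)
{
    return _mm256_loadu_ps(p_b + i_k * ldb + i_n * 8);
}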
......
@@ -1277,6 +1277,138 @@ struct ThreadwiseTensorSliceTransferAvx2Specialization_ConvFwd_Wei_KYXCK8
intptr_t src_offset;
};
template <typename SrcData,
typename DstData,
typename SrcDesc,
typename DstDesc,
typename ElementwiseOperation,
bool BypassTransfer,
ConvolutionForwardSpecialization_t ConvForwardSpecialization,
ConvolutionForwardGemmKSpecialization_t GemmKSpecialization>
struct ThreadwiseTensorSliceTransferAvx2Specialization_ConvFwd_Wei_YXCK
{
static constexpr ck::index_t nDim = SrcDesc::GetNumOfDimension();
using Index = MultiIndex<nDim>;
constexpr ThreadwiseTensorSliceTransferAvx2Specialization_ConvFwd_Wei_YXCK(
const SrcDesc& src_desc,
const Index&,
const DstDesc&,
const Index&,
const ElementwiseOperation& element_op)
: element_op_(element_op)
{
GemmK = src_desc.GetTransforms()[Number<0>{}].GetUpperLengths()[Number<0>{}];
GemmN = src_desc.GetTransforms()[Number<0>{}].GetUpperLengths()[Number<1>{}];
}
void SetSrcSliceOrigin(const SrcDesc&, const Index& src_slice_origin_idx)
{
ck::index_t idx_k = src_slice_origin_idx[Number<0>{}];
ck::index_t idx_n = src_slice_origin_idx[Number<1>{}];
src_offset = idx_k * GemmN + idx_n;
}
void SetDstSliceOrigin(const DstDesc&, const Index&) {}
template <typename SrcBuffer, typename DstBuffer, typename SliceLengths>
void RunRead(const SrcDesc&,
SrcBuffer& src_buf,
const DstDesc& dst_desc,
DstBuffer& dst_buf,
const SliceLengths& slice_length)
{
if constexpr(BypassTransfer)
{
dst_buf.p_data_ = reinterpret_cast<float*>(src_buf.p_data_) + src_offset;
}
else
{
const ck::index_t k_per_block = slice_length[Number<0>{}];
const ck::index_t n_per_block = slice_length[Number<1>{}];
const float* p_src = reinterpret_cast<const float*>(src_buf.p_data_) + src_offset;
float* p_dst = reinterpret_cast<float*>(dst_buf.p_data_);
// k * n
index_t i_k_itr = k_per_block;
while(i_k_itr >= 8)
{
avx2_util::memcpy32_avx2(
p_dst + 0 * n_per_block, p_src + 0 * GemmN, n_per_block, element_op_);
avx2_util::memcpy32_avx2(
p_dst + 1 * n_per_block, p_src + 1 * GemmN, n_per_block, element_op_);
avx2_util::memcpy32_avx2(
p_dst + 2 * n_per_block, p_src + 2 * GemmN, n_per_block, element_op_);
avx2_util::memcpy32_avx2(
p_dst + 3 * n_per_block, p_src + 3 * GemmN, n_per_block, element_op_);
avx2_util::memcpy32_avx2(
p_dst + 4 * n_per_block, p_src + 4 * GemmN, n_per_block, element_op_);
avx2_util::memcpy32_avx2(
p_dst + 5 * n_per_block, p_src + 5 * GemmN, n_per_block, element_op_);
avx2_util::memcpy32_avx2(
p_dst + 6 * n_per_block, p_src + 6 * GemmN, n_per_block, element_op_);
avx2_util::memcpy32_avx2(
p_dst + 7 * n_per_block, p_src + 7 * GemmN, n_per_block, element_op_);
i_k_itr -= 8;
p_dst += 8 * n_per_block;
p_src += 8 * GemmN;
}
if(i_k_itr & 4)
{
avx2_util::memcpy32_avx2(
p_dst + 0 * n_per_block, p_src + 0 * GemmN, n_per_block, element_op_);
avx2_util::memcpy32_avx2(
p_dst + 1 * n_per_block, p_src + 1 * GemmN, n_per_block, element_op_);
avx2_util::memcpy32_avx2(
p_dst + 2 * n_per_block, p_src + 2 * GemmN, n_per_block, element_op_);
avx2_util::memcpy32_avx2(
p_dst + 3 * n_per_block, p_src + 3 * GemmN, n_per_block, element_op_);
p_dst += 4 * n_per_block;
p_src += 4 * GemmN;
}
if(i_k_itr & 2)
{
avx2_util::memcpy32_avx2(
p_dst + 0 * n_per_block, p_src + 0 * GemmN, n_per_block, element_op_);
avx2_util::memcpy32_avx2(
p_dst + 1 * n_per_block, p_src + 1 * GemmN, n_per_block, element_op_);
p_dst += 2 * n_per_block;
p_src += 2 * GemmN;
}
if(i_k_itr & 1)
{
avx2_util::memcpy32_avx2(
p_dst + 0 * n_per_block, p_src + 0 * GemmN, n_per_block, element_op_);
}
}
}
// src_slice_origin_step_idx need to be known at compile-time, for performance reason
void MoveSrcSliceWindow(const SrcDesc& src_desc, const Index& src_slice_origin_step_idx)
{
ck::index_t move_k = src_slice_origin_step_idx[Number<0>{}];
ck::index_t move_n = src_slice_origin_step_idx[Number<1>{}];
src_offset += move_k * GemmN + move_n;
}
// dst_slice_origin_step_idx need to be known at compile-time, for performance reason
void MoveDstSliceWindow(const DstDesc&, const Index&) {}
private:
const ElementwiseOperation element_op_;
ck::index_t GemmN;
ck::index_t GemmK;
intptr_t src_offset;
};
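A minimal scalar sketch of what RunRead above does when BypassTransfer is false: the Y-X-C-K filter is viewed as a row-major [GemmK = Y*X*C, GemmN = K] matrix, and a k_per_block x n_per_block tile starting at src_offset is copied row by row into a packed destination. The function below is illustrative only; the real path uses avx2_util::memcpy32_avx2 with the fused element_op_ and an 8/4/2/1 unroll over the GEMM-K rows.

// illustrative scalar equivalent of the YXCK weight tile copy (assumes float weights)
static void copy_wei_yxck_tile(float* p_dst, const float* p_src,
                               int k_per_block, int n_per_block, int GemmN)
{
    for(int i_k = 0; i_k < k_per_block; ++i_k)     // one GEMM-K row (a fixed y, x, c) at a time
        for(int i_n = 0; i_n < n_per_block; ++i_n) // K (output channel) is the contiguous dimension
            p_dst[i_k * n_per_block + i_n] = p_src[i_k * GemmN + i_n];
}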
template <typename SrcData,
typename DstData,
typename SrcDesc,
......
@@ -2,6 +2,7 @@
set(DEVICE_CONV2D_FWD_CPU_INSTANCE_SOURCE
device_conv2d_fwd_avx2_nhwc_kyxc_nhwk_instance.cpp
device_conv2d_fwd_avx2_nhwc_kyxck8_nhwk_instance.cpp
device_conv2d_fwd_avx2_nhwc_yxck_nhwk_instance.cpp
)
add_library(device_conv2d_fwd_cpu_instance SHARED ${DEVICE_CONV2D_FWD_CPU_INSTANCE_SOURCE})
target_compile_features(device_conv2d_fwd_cpu_instance PUBLIC)
......
#include <stdlib.h>
#include "config.hpp"
#include "convolution_forward_specialization_cpu.hpp"
#include "device_convnd_fwd_avx2_nhwc_yxck_nhwk.hpp"
#include "element_wise_operation_cpu.hpp"
#include "device_operation_instance.hpp"
namespace ck {
namespace tensor_operation {
namespace cpu {
namespace device {
namespace device_conv2d_fwd_avx2_instance {
using InType = float;
using WeiType = float;
using OutType = float;
using AccType = float;
static constexpr bool NonTemporalStore = false;
using PT = ck::tensor_operation::cpu::element_wise::PassThrough;
using Relu = ck::tensor_operation::cpu::element_wise::Relu;
static constexpr auto ConvFwdDefault =
ck::tensor_operation::cpu::device::ConvolutionForwardSpecialization_t::Default;
static constexpr auto ConvFwd1x1P0 =
ck::tensor_operation::cpu::device::ConvolutionForwardSpecialization_t::Filter1x1Pad0;
static constexpr auto ConvFwd1x1S1P0 =
ck::tensor_operation::cpu::device::ConvolutionForwardSpecialization_t::Filter1x1Stride1Pad0;
static constexpr auto DefaultGemmKLoop =
ck::tensor_operation::cpu::device::ConvolutionForwardGemmKSpecialization_t::DefaultGemmKLoop;
static constexpr auto GemmKLoopOverC =
ck::tensor_operation::cpu::device::ConvolutionForwardGemmKSpecialization_t::NHWC_GemmKLoopOverC;
static constexpr auto LoopOver_MNK = ck::tensor_operation::cpu::device::LoopOver_MNK;
static constexpr auto LoopOver_MKN = ck::tensor_operation::cpu::device::LoopOver_MKN;
// clang-format off
#define DEVICE_CONV2D_FWD_AVX2_NHWC_YXCK_NHWK_F32(a_elem_op, b_elem_op, c_elem_op, m_per_block, n_per_block, k_per_block, m_per_thread, n_per_thread, c_local_buf) \
DeviceConvNDFwdAvx2_Input_N_Hi_Wi_C_Weight_Y_X_C_K_Output_N_Ho_Wo_K<float , float , float, a_elem_op, b_elem_op, c_elem_op, ConvFwdDefault, GemmKLoopOverC , LoopOver_MNK, 2, m_per_block, n_per_block, k_per_block, m_per_thread, n_per_thread, true, true, c_local_buf>, \
DeviceConvNDFwdAvx2_Input_N_Hi_Wi_C_Weight_Y_X_C_K_Output_N_Ho_Wo_K<float , float , float, a_elem_op, b_elem_op, c_elem_op, ConvFwd1x1S1P0, GemmKLoopOverC , LoopOver_MNK, 2, m_per_block, n_per_block, k_per_block, m_per_thread, n_per_thread, true, true, c_local_buf>, \
DeviceConvNDFwdAvx2_Input_N_Hi_Wi_C_Weight_Y_X_C_K_Output_N_Ho_Wo_K<float , float , float, a_elem_op, b_elem_op, c_elem_op, ConvFwdDefault, DefaultGemmKLoop, LoopOver_MNK, 2, m_per_block, n_per_block, k_per_block, m_per_thread, n_per_thread, true, true, c_local_buf>, \
DeviceConvNDFwdAvx2_Input_N_Hi_Wi_C_Weight_Y_X_C_K_Output_N_Ho_Wo_K<float , float , float, a_elem_op, b_elem_op, c_elem_op, ConvFwd1x1S1P0, GemmKLoopOverC , LoopOver_MNK, 2, m_per_block, n_per_block, k_per_block, m_per_thread, n_per_thread, false, false, c_local_buf>, \
DeviceConvNDFwdAvx2_Input_N_Hi_Wi_C_Weight_Y_X_C_K_Output_N_Ho_Wo_K<float , float , float, a_elem_op, b_elem_op, c_elem_op, ConvFwdDefault, DefaultGemmKLoop, LoopOver_MNK, 2, m_per_block, n_per_block, k_per_block, m_per_thread, n_per_thread, true, false, c_local_buf>, \
\
DeviceConvNDFwdAvx2_Input_N_Hi_Wi_C_Weight_Y_X_C_K_Output_N_Ho_Wo_K<float , float , float, a_elem_op, b_elem_op, c_elem_op, ConvFwdDefault, GemmKLoopOverC , LoopOver_MKN, 2, m_per_block, n_per_block, k_per_block, m_per_thread, n_per_thread, true, true, c_local_buf>, \
DeviceConvNDFwdAvx2_Input_N_Hi_Wi_C_Weight_Y_X_C_K_Output_N_Ho_Wo_K<float , float , float, a_elem_op, b_elem_op, c_elem_op, ConvFwd1x1S1P0, GemmKLoopOverC , LoopOver_MKN, 2, m_per_block, n_per_block, k_per_block, m_per_thread, n_per_thread, true, true, c_local_buf>, \
DeviceConvNDFwdAvx2_Input_N_Hi_Wi_C_Weight_Y_X_C_K_Output_N_Ho_Wo_K<float , float , float, a_elem_op, b_elem_op, c_elem_op, ConvFwdDefault, DefaultGemmKLoop, LoopOver_MKN, 2, m_per_block, n_per_block, k_per_block, m_per_thread, n_per_thread, true, true, c_local_buf>, \
DeviceConvNDFwdAvx2_Input_N_Hi_Wi_C_Weight_Y_X_C_K_Output_N_Ho_Wo_K<float , float , float, a_elem_op, b_elem_op, c_elem_op, ConvFwd1x1S1P0, GemmKLoopOverC , LoopOver_MKN, 2, m_per_block, n_per_block, k_per_block, m_per_thread, n_per_thread, false, false, c_local_buf>, \
DeviceConvNDFwdAvx2_Input_N_Hi_Wi_C_Weight_Y_X_C_K_Output_N_Ho_Wo_K<float , float , float, a_elem_op, b_elem_op, c_elem_op, ConvFwdDefault, DefaultGemmKLoop, LoopOver_MKN, 2, m_per_block, n_per_block, k_per_block, m_per_thread, n_per_thread, true, false, c_local_buf>
// clang-format on
using device_conv2d_fwd_avx2_nhwc_yxck_nhwk_f32_instances = std::tuple<
// clang-format off
DEVICE_CONV2D_FWD_AVX2_NHWC_YXCK_NHWK_F32(PT, PT, PT, 256, 128, 64, 6, 16, false),
DEVICE_CONV2D_FWD_AVX2_NHWC_YXCK_NHWK_F32(PT, PT, PT, 256, 128, 128, 6, 16, false),
DEVICE_CONV2D_FWD_AVX2_NHWC_YXCK_NHWK_F32(PT, PT, PT, 128, 256, 128, 6, 16, false),
DEVICE_CONV2D_FWD_AVX2_NHWC_YXCK_NHWK_F32(PT, PT, PT, 512, 240, 128, 4, 24, false),
DEVICE_CONV2D_FWD_AVX2_NHWC_YXCK_NHWK_F32(PT, PT, PT, 512, 256, 128, 6, 16, false),
DEVICE_CONV2D_FWD_AVX2_NHWC_YXCK_NHWK_F32(PT, PT, PT, 768, 320, 128, 6, 16, false),
DEVICE_CONV2D_FWD_AVX2_NHWC_YXCK_NHWK_F32(PT, PT, PT, 896, 352, 128, 6, 16, false),
DEVICE_CONV2D_FWD_AVX2_NHWC_YXCK_NHWK_F32(PT, PT, PT, 1024, 416, 128, 6, 16, false)>;
// clang-format on
// use this in single thread, when gemm_n is not a multiple of 8
using device_conv2d_fwd_avx2_nhwc_yxck_nhwk_f32_local_c_instances = std::tuple<
// clang-format off
DEVICE_CONV2D_FWD_AVX2_NHWC_YXCK_NHWK_F32(PT, PT, PT, 256, 128, 64, 6, 16, true),
DEVICE_CONV2D_FWD_AVX2_NHWC_YXCK_NHWK_F32(PT, PT, PT, 256, 128, 128, 6, 16, true),
DEVICE_CONV2D_FWD_AVX2_NHWC_YXCK_NHWK_F32(PT, PT, PT, 128, 256, 128, 6, 16, true),
DEVICE_CONV2D_FWD_AVX2_NHWC_YXCK_NHWK_F32(PT, PT, PT, 512, 240, 128, 4, 24, true),
DEVICE_CONV2D_FWD_AVX2_NHWC_YXCK_NHWK_F32(PT, PT, PT, 512, 256, 128, 6, 16, true),
DEVICE_CONV2D_FWD_AVX2_NHWC_YXCK_NHWK_F32(PT, PT, PT, 768, 320, 128, 6, 16, true),
DEVICE_CONV2D_FWD_AVX2_NHWC_YXCK_NHWK_F32(PT, PT, PT, 896, 352, 128, 6, 16, true),
DEVICE_CONV2D_FWD_AVX2_NHWC_YXCK_NHWK_F32(PT, PT, PT, 1024, 416, 128, 6, 16, true)>;
// clang-format on
// use this in a multi-thread environment (a local C buffer is needed to avoid cache-coherence
// traffic, although sometimes no local C buffer is better...)
using device_conv2d_fwd_avx2_nhwc_yxck_nhwk_f32_mt_instances = std::tuple<
// clang-format off
DEVICE_CONV2D_FWD_AVX2_NHWC_YXCK_NHWK_F32(PT, PT, PT, 24, 24, 256, 4, 24, false),
DEVICE_CONV2D_FWD_AVX2_NHWC_YXCK_NHWK_F32(PT, PT, PT, 32, 24, 256, 4, 24, false),
DEVICE_CONV2D_FWD_AVX2_NHWC_YXCK_NHWK_F32(PT, PT, PT, 40, 24, 256, 4, 24, false),
DEVICE_CONV2D_FWD_AVX2_NHWC_YXCK_NHWK_F32(PT, PT, PT, 48, 24, 256, 4, 24, false),
DEVICE_CONV2D_FWD_AVX2_NHWC_YXCK_NHWK_F32(PT, PT, PT, 48, 48, 256, 4, 24, false),
DEVICE_CONV2D_FWD_AVX2_NHWC_YXCK_NHWK_F32(PT, PT, PT, 56, 24, 256, 4, 24, false),
DEVICE_CONV2D_FWD_AVX2_NHWC_YXCK_NHWK_F32(PT, PT, PT, 72, 16, 128, 6, 16, false),
DEVICE_CONV2D_FWD_AVX2_NHWC_YXCK_NHWK_F32(PT, PT, PT, 72, 16, 256, 6, 16, false),
DEVICE_CONV2D_FWD_AVX2_NHWC_YXCK_NHWK_F32(PT, PT, PT, 72, 32, 128, 6, 16, false),
DEVICE_CONV2D_FWD_AVX2_NHWC_YXCK_NHWK_F32(PT, PT, PT, 72, 32, 256, 6, 16, false),
DEVICE_CONV2D_FWD_AVX2_NHWC_YXCK_NHWK_F32(PT, PT, PT, 96, 32, 128, 6, 16, false),
DEVICE_CONV2D_FWD_AVX2_NHWC_YXCK_NHWK_F32(PT, PT, PT, 96, 64, 128, 6, 16, false),
DEVICE_CONV2D_FWD_AVX2_NHWC_YXCK_NHWK_F32(PT, PT, PT, 120, 32, 128, 6, 16, false),
DEVICE_CONV2D_FWD_AVX2_NHWC_YXCK_NHWK_F32(PT, PT, PT, 120, 64, 128, 6, 16, false),
// DEVICE_CONV2D_FWD_AVX2_NHWC_YXCK_NHWK_F32(PT, PT, PT, 256, 128, 64, 6, 16, true),
DEVICE_CONV2D_FWD_AVX2_NHWC_YXCK_NHWK_F32(PT, PT, PT, 256, 128, 128, 6, 16, true),
DEVICE_CONV2D_FWD_AVX2_NHWC_YXCK_NHWK_F32(PT, PT, PT, 128, 256, 128, 6, 16, true),
DEVICE_CONV2D_FWD_AVX2_NHWC_YXCK_NHWK_F32(PT, PT, PT, 512, 240, 128, 4, 24, true),
DEVICE_CONV2D_FWD_AVX2_NHWC_YXCK_NHWK_F32(PT, PT, PT, 512, 256, 128, 6, 16, true),
DEVICE_CONV2D_FWD_AVX2_NHWC_YXCK_NHWK_F32(PT, PT, PT, 768, 320, 128, 6, 16, true),
DEVICE_CONV2D_FWD_AVX2_NHWC_YXCK_NHWK_F32(PT, PT, PT, 896, 352, 128, 6, 16, true),
DEVICE_CONV2D_FWD_AVX2_NHWC_YXCK_NHWK_F32(PT, PT, PT, 1024, 416, 128, 6, 16, true)>;
// clang-format on
using device_conv2d_fwd_avx2_nhwc_yxck_nhwk_f32_relu_instances = std::tuple<
// clang-format off
DEVICE_CONV2D_FWD_AVX2_NHWC_YXCK_NHWK_F32(PT, PT, Relu, 256, 128, 64, 6, 16, false),
DEVICE_CONV2D_FWD_AVX2_NHWC_YXCK_NHWK_F32(PT, PT, Relu, 256, 128, 128, 6, 16, false),
DEVICE_CONV2D_FWD_AVX2_NHWC_YXCK_NHWK_F32(PT, PT, Relu, 128, 256, 128, 6, 16, false),
DEVICE_CONV2D_FWD_AVX2_NHWC_YXCK_NHWK_F32(PT, PT, Relu, 512, 240, 128, 4, 24, false),
DEVICE_CONV2D_FWD_AVX2_NHWC_YXCK_NHWK_F32(PT, PT, Relu, 512, 256, 128, 6, 16, false),
DEVICE_CONV2D_FWD_AVX2_NHWC_YXCK_NHWK_F32(PT, PT, Relu, 768, 320, 128, 6, 16, false),
DEVICE_CONV2D_FWD_AVX2_NHWC_YXCK_NHWK_F32(PT, PT, Relu, 896, 352, 128, 6, 16, false),
DEVICE_CONV2D_FWD_AVX2_NHWC_YXCK_NHWK_F32(PT, PT, Relu, 1024, 416, 128, 6, 16, false)>;
// clang-format on
// use this in single thread, when gemm_n is not a multiple of 8
using device_conv2d_fwd_avx2_nhwc_yxck_nhwk_f32_local_c_relu_instances = std::tuple<
// clang-format off
DEVICE_CONV2D_FWD_AVX2_NHWC_YXCK_NHWK_F32(PT, PT, Relu, 256, 128, 64, 6, 16, true),
DEVICE_CONV2D_FWD_AVX2_NHWC_YXCK_NHWK_F32(PT, PT, Relu, 256, 128, 128, 6, 16, true),
DEVICE_CONV2D_FWD_AVX2_NHWC_YXCK_NHWK_F32(PT, PT, Relu, 128, 256, 128, 6, 16, true),
DEVICE_CONV2D_FWD_AVX2_NHWC_YXCK_NHWK_F32(PT, PT, Relu, 512, 240, 128, 4, 24, true),
DEVICE_CONV2D_FWD_AVX2_NHWC_YXCK_NHWK_F32(PT, PT, Relu, 512, 256, 128, 6, 16, true),
DEVICE_CONV2D_FWD_AVX2_NHWC_YXCK_NHWK_F32(PT, PT, Relu, 768, 320, 128, 6, 16, true),
DEVICE_CONV2D_FWD_AVX2_NHWC_YXCK_NHWK_F32(PT, PT, Relu, 896, 352, 128, 6, 16, true),
DEVICE_CONV2D_FWD_AVX2_NHWC_YXCK_NHWK_F32(PT, PT, Relu, 1024, 416, 128, 6, 16, true)>;
// clang-format on
// use this in a multi-thread environment (a local C buffer is needed to avoid cache-coherence
// traffic, although sometimes no local C buffer is better...)
using device_conv2d_fwd_avx2_nhwc_yxck_nhwk_f32_mt_relu_instances = std::tuple<
// clang-format off
DEVICE_CONV2D_FWD_AVX2_NHWC_YXCK_NHWK_F32(PT, PT, Relu, 24, 24, 256, 4, 24, false),
DEVICE_CONV2D_FWD_AVX2_NHWC_YXCK_NHWK_F32(PT, PT, Relu, 32, 24, 256, 4, 24, false),
DEVICE_CONV2D_FWD_AVX2_NHWC_YXCK_NHWK_F32(PT, PT, Relu, 40, 24, 256, 4, 24, false),
DEVICE_CONV2D_FWD_AVX2_NHWC_YXCK_NHWK_F32(PT, PT, Relu, 24, 24, 256, 4, 24, false),
DEVICE_CONV2D_FWD_AVX2_NHWC_YXCK_NHWK_F32(PT, PT, Relu, 32, 24, 256, 4, 24, false),
DEVICE_CONV2D_FWD_AVX2_NHWC_YXCK_NHWK_F32(PT, PT, Relu, 40, 24, 256, 4, 24, false),
DEVICE_CONV2D_FWD_AVX2_NHWC_YXCK_NHWK_F32(PT, PT, Relu, 48, 24, 256, 4, 24, false),
DEVICE_CONV2D_FWD_AVX2_NHWC_YXCK_NHWK_F32(PT, PT, Relu, 48, 48, 256, 4, 24, false),
DEVICE_CONV2D_FWD_AVX2_NHWC_YXCK_NHWK_F32(PT, PT, Relu, 56, 24, 256, 4, 24, false),
DEVICE_CONV2D_FWD_AVX2_NHWC_YXCK_NHWK_F32(PT, PT, Relu, 72, 16, 128, 6, 16, false),
DEVICE_CONV2D_FWD_AVX2_NHWC_YXCK_NHWK_F32(PT, PT, Relu, 72, 16, 256, 6, 16, false),
DEVICE_CONV2D_FWD_AVX2_NHWC_YXCK_NHWK_F32(PT, PT, Relu, 72, 32, 128, 6, 16, false),
DEVICE_CONV2D_FWD_AVX2_NHWC_YXCK_NHWK_F32(PT, PT, Relu, 72, 32, 256, 6, 16, false),
DEVICE_CONV2D_FWD_AVX2_NHWC_YXCK_NHWK_F32(PT, PT, Relu, 96, 32, 128, 6, 16, false),
DEVICE_CONV2D_FWD_AVX2_NHWC_YXCK_NHWK_F32(PT, PT, Relu, 96, 64, 128, 6, 16, false),
DEVICE_CONV2D_FWD_AVX2_NHWC_YXCK_NHWK_F32(PT, PT, Relu, 120, 32, 128, 6, 16, false),
DEVICE_CONV2D_FWD_AVX2_NHWC_YXCK_NHWK_F32(PT, PT, Relu, 120, 64, 128, 6, 16, false),
// DEVICE_CONV2D_FWD_AVX2_NHWC_YXCK_NHWK_F32(PT, PT, PT, 256, 128, 64, 6, 16, true),
DEVICE_CONV2D_FWD_AVX2_NHWC_YXCK_NHWK_F32(PT, PT, Relu, 256, 128, 128, 6, 16, true),
DEVICE_CONV2D_FWD_AVX2_NHWC_YXCK_NHWK_F32(PT, PT, Relu, 128, 256, 128, 6, 16, true),
DEVICE_CONV2D_FWD_AVX2_NHWC_YXCK_NHWK_F32(PT, PT, Relu, 512, 240, 128, 4, 24, true),
DEVICE_CONV2D_FWD_AVX2_NHWC_YXCK_NHWK_F32(PT, PT, Relu, 512, 256, 128, 6, 16, true),
DEVICE_CONV2D_FWD_AVX2_NHWC_YXCK_NHWK_F32(PT, PT, Relu, 768, 320, 128, 6, 16, true),
DEVICE_CONV2D_FWD_AVX2_NHWC_YXCK_NHWK_F32(PT, PT, Relu, 896, 352, 128, 6, 16, true),
DEVICE_CONV2D_FWD_AVX2_NHWC_YXCK_NHWK_F32(PT, PT, Relu, 1024, 416, 128, 6, 16, true)>;
// clang-format on
void add_device_conv2d_fwd_avx2_nhwc_yxck_nhwk(std::vector<DeviceConvFwdPtr<PT, PT, PT>>& instances)
{
ck::tensor_operation::device::add_device_operation_instances(
instances, device_conv2d_fwd_avx2_nhwc_yxck_nhwk_f32_instances{});
}
void add_device_conv2d_fwd_avx2_nhwc_yxck_nhwk_local_c(
std::vector<DeviceConvFwdPtr<PT, PT, PT>>& instances)
{
ck::tensor_operation::device::add_device_operation_instances(
instances, device_conv2d_fwd_avx2_nhwc_yxck_nhwk_f32_local_c_instances{});
}
void add_device_conv2d_fwd_avx2_nhwc_yxck_nhwk_mt(
std::vector<DeviceConvFwdPtr<PT, PT, PT>>& instances)
{
ck::tensor_operation::device::add_device_operation_instances(
instances, device_conv2d_fwd_avx2_nhwc_yxck_nhwk_f32_mt_instances{});
}
void add_device_conv2d_fwd_avx2_nhwc_yxck_nhwk_relu(
std::vector<DeviceConvFwdPtr<PT, PT, Relu>>& instances)
{
ck::tensor_operation::device::add_device_operation_instances(
instances, device_conv2d_fwd_avx2_nhwc_yxck_nhwk_f32_relu_instances{});
}
void add_device_conv2d_fwd_avx2_nhwc_yxck_nhwk_local_c_relu(
std::vector<DeviceConvFwdPtr<PT, PT, Relu>>& instances)
{
ck::tensor_operation::device::add_device_operation_instances(
instances, device_conv2d_fwd_avx2_nhwc_yxck_nhwk_f32_local_c_relu_instances{});
}
void add_device_conv2d_fwd_avx2_nhwc_yxck_nhwk_mt_relu(
std::vector<DeviceConvFwdPtr<PT, PT, Relu>>& instances)
{
ck::tensor_operation::device::add_device_operation_instances(
instances, device_conv2d_fwd_avx2_nhwc_yxck_nhwk_f32_mt_relu_instances{});
}
} // namespace device_conv2d_fwd_avx2_instance
} // namespace device
} // namespace cpu
} // namespace tensor_operation
} // namespace ck
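A hedged usage sketch of the new entry points, mirroring how the existing kyxc/kyxck8 factories are consumed: the caller owns a vector of DeviceConvFwdPtr and each add_* function appends its instances to it (PT here stands for the PassThrough alias used in this file; how an individual instance is ultimately invoked is outside the scope of this sketch).

// assumes the caller is inside (or using) namespace ck::tensor_operation::cpu::device
std::vector<DeviceConvFwdPtr<PT, PT, PT>> conv_ptrs;
device_conv2d_fwd_avx2_instance::add_device_conv2d_fwd_avx2_nhwc_yxck_nhwk(conv_ptrs);
device_conv2d_fwd_avx2_instance::add_device_conv2d_fwd_avx2_nhwc_yxck_nhwk_mt(conv_ptrs);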
@@ -2,6 +2,7 @@
set(DEVICE_CONV2D_FWD_CPU_INSTANCE_SOURCE
device_conv2d_bias_activation_add_avx2_nhwc_kyxc_nhwk_instance.cpp
device_conv2d_bias_activation_add_avx2_nhwc_kyxck8_nhwk_instance.cpp
device_conv2d_bias_activation_add_avx2_nhwc_yxck_nhwk_instance.cpp
)
add_library(device_conv2d_fwd_bias_activation_add_cpu_instance SHARED ${DEVICE_CONV2D_FWD_CPU_INSTANCE_SOURCE})
target_compile_features(device_conv2d_fwd_bias_activation_add_cpu_instance PUBLIC)
......
#include <stdlib.h>
#include "config.hpp"
#include "convolution_forward_specialization_cpu.hpp"
#include "device_convnd_fwd_bias_activation_add_avx2_nhwc_yxck_nhwk.hpp"
#include "element_wise_operation_cpu.hpp"
#include "device_operation_instance.hpp"
namespace ck {
namespace tensor_operation {
namespace cpu {
namespace device {
namespace device_conv2d_fwd_bias_activation_add_avx2_instance {
using InType = float;
using WeiType = float;
using OutType = float;
using AccType = float;
static constexpr bool NonTemporalStore = false;
using PT = ck::tensor_operation::cpu::element_wise::PassThrough;
using AddReluAdd = ck::tensor_operation::cpu::element_wise::AddReluAdd;
static constexpr auto ConvFwdDefault =
ck::tensor_operation::cpu::device::ConvolutionForwardSpecialization_t::Default;
static constexpr auto ConvFwd1x1P0 =
ck::tensor_operation::cpu::device::ConvolutionForwardSpecialization_t::Filter1x1Pad0;
static constexpr auto ConvFwd1x1S1P0 =
ck::tensor_operation::cpu::device::ConvolutionForwardSpecialization_t::Filter1x1Stride1Pad0;
static constexpr auto DefaultGemmKLoop =
ck::tensor_operation::cpu::device::ConvolutionForwardGemmKSpecialization_t::DefaultGemmKLoop;
static constexpr auto GemmKLoopOverC =
ck::tensor_operation::cpu::device::ConvolutionForwardGemmKSpecialization_t::NHWC_GemmKLoopOverC;
static constexpr auto LoopOver_MNK = ck::tensor_operation::cpu::device::LoopOver_MNK;
static constexpr auto LoopOver_MKN = ck::tensor_operation::cpu::device::LoopOver_MKN;
// clang-format off
#define DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_YXCK_NHWK_F32(a_elem_op, b_elem_op, c_elem_op, m_per_block, n_per_block, k_per_block, m_per_thread, n_per_thread, c_local_buf, bias_along_m) \
DeviceConvNDFwdBiasActivationAddAvx2_Input_N_Hi_Wi_C_Weight_Y_X_C_K_Output_N_Ho_Wo_K<float , float , float, float , float, a_elem_op, b_elem_op, c_elem_op, ConvFwdDefault, GemmKLoopOverC , LoopOver_MNK, 2, m_per_block, n_per_block, k_per_block, m_per_thread, n_per_thread, true , true , c_local_buf, bias_along_m>, \
DeviceConvNDFwdBiasActivationAddAvx2_Input_N_Hi_Wi_C_Weight_Y_X_C_K_Output_N_Ho_Wo_K<float , float , float, float , float, a_elem_op, b_elem_op, c_elem_op, ConvFwd1x1S1P0, GemmKLoopOverC , LoopOver_MNK, 2, m_per_block, n_per_block, k_per_block, m_per_thread, n_per_thread, true , true , c_local_buf, bias_along_m>, \
DeviceConvNDFwdBiasActivationAddAvx2_Input_N_Hi_Wi_C_Weight_Y_X_C_K_Output_N_Ho_Wo_K<float , float , float, float , float, a_elem_op, b_elem_op, c_elem_op, ConvFwdDefault, DefaultGemmKLoop, LoopOver_MNK, 2, m_per_block, n_per_block, k_per_block, m_per_thread, n_per_thread, true , true , c_local_buf, bias_along_m>, \
DeviceConvNDFwdBiasActivationAddAvx2_Input_N_Hi_Wi_C_Weight_Y_X_C_K_Output_N_Ho_Wo_K<float , float , float, float , float, a_elem_op, b_elem_op, c_elem_op, ConvFwd1x1S1P0, GemmKLoopOverC , LoopOver_MNK, 2, m_per_block, n_per_block, k_per_block, m_per_thread, n_per_thread, false, false, c_local_buf, bias_along_m>, \
DeviceConvNDFwdBiasActivationAddAvx2_Input_N_Hi_Wi_C_Weight_Y_X_C_K_Output_N_Ho_Wo_K<float , float , float, float , float, a_elem_op, b_elem_op, c_elem_op, ConvFwdDefault, DefaultGemmKLoop, LoopOver_MNK, 2, m_per_block, n_per_block, k_per_block, m_per_thread, n_per_thread, true , false, c_local_buf, bias_along_m>, \
\
DeviceConvNDFwdBiasActivationAddAvx2_Input_N_Hi_Wi_C_Weight_Y_X_C_K_Output_N_Ho_Wo_K<float , float , float, float , float, a_elem_op, b_elem_op, c_elem_op, ConvFwdDefault, GemmKLoopOverC , LoopOver_MKN, 2, m_per_block, n_per_block, k_per_block, m_per_thread, n_per_thread, true , true , c_local_buf, bias_along_m>, \
DeviceConvNDFwdBiasActivationAddAvx2_Input_N_Hi_Wi_C_Weight_Y_X_C_K_Output_N_Ho_Wo_K<float , float , float, float , float, a_elem_op, b_elem_op, c_elem_op, ConvFwd1x1S1P0, GemmKLoopOverC , LoopOver_MKN, 2, m_per_block, n_per_block, k_per_block, m_per_thread, n_per_thread, true , true , c_local_buf, bias_along_m>, \
DeviceConvNDFwdBiasActivationAddAvx2_Input_N_Hi_Wi_C_Weight_Y_X_C_K_Output_N_Ho_Wo_K<float , float , float, float , float, a_elem_op, b_elem_op, c_elem_op, ConvFwdDefault, DefaultGemmKLoop, LoopOver_MKN, 2, m_per_block, n_per_block, k_per_block, m_per_thread, n_per_thread, true , true , c_local_buf, bias_along_m>, \
DeviceConvNDFwdBiasActivationAddAvx2_Input_N_Hi_Wi_C_Weight_Y_X_C_K_Output_N_Ho_Wo_K<float , float , float, float , float, a_elem_op, b_elem_op, c_elem_op, ConvFwd1x1S1P0, GemmKLoopOverC , LoopOver_MKN, 2, m_per_block, n_per_block, k_per_block, m_per_thread, n_per_thread, false, false, c_local_buf, bias_along_m>, \
DeviceConvNDFwdBiasActivationAddAvx2_Input_N_Hi_Wi_C_Weight_Y_X_C_K_Output_N_Ho_Wo_K<float , float , float, float , float, a_elem_op, b_elem_op, c_elem_op, ConvFwdDefault, DefaultGemmKLoop, LoopOver_MKN, 2, m_per_block, n_per_block, k_per_block, m_per_thread, n_per_thread, true , false, c_local_buf, bias_along_m>
// clang-format on
using device_conv2d_fwd_bias_activation_add_avx2_nhwc_yxck_nhwk_f32_instances = std::tuple<
// clang-format off
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_YXCK_NHWK_F32(PT, PT, AddReluAdd, 256, 128, 64, 6, 16, false, false),
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_YXCK_NHWK_F32(PT, PT, AddReluAdd, 256, 128, 128, 6, 16, false, false),
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_YXCK_NHWK_F32(PT, PT, AddReluAdd, 128, 256, 128, 6, 16, false, false),
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_YXCK_NHWK_F32(PT, PT, AddReluAdd, 512, 240, 128, 4, 24, false, false),
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_YXCK_NHWK_F32(PT, PT, AddReluAdd, 512, 256, 128, 6, 16, false, false),
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_YXCK_NHWK_F32(PT, PT, AddReluAdd, 768, 320, 128, 6, 16, false, false),
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_YXCK_NHWK_F32(PT, PT, AddReluAdd, 896, 352, 128, 6, 16, false, false),
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_YXCK_NHWK_F32(PT, PT, AddReluAdd, 1024, 416, 128, 6, 16, false, false)>;
// clang-format on
// use this in single thread, when gemm_n is not a multiple of 8
using device_conv2d_fwd_bias_activation_add_avx2_nhwc_yxck_nhwk_f32_local_c_instances = std::tuple<
// clang-format off
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_YXCK_NHWK_F32(PT, PT, AddReluAdd, 256, 128, 64, 6, 16, true, false),
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_YXCK_NHWK_F32(PT, PT, AddReluAdd, 256, 128, 128, 6, 16, true, false),
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_YXCK_NHWK_F32(PT, PT, AddReluAdd, 128, 256, 128, 6, 16, true, false),
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_YXCK_NHWK_F32(PT, PT, AddReluAdd, 512, 240, 128, 4, 24, true, false),
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_YXCK_NHWK_F32(PT, PT, AddReluAdd, 512, 256, 128, 6, 16, true, false),
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_YXCK_NHWK_F32(PT, PT, AddReluAdd, 768, 320, 128, 6, 16, true, false),
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_YXCK_NHWK_F32(PT, PT, AddReluAdd, 896, 352, 128, 6, 16, true, false),
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_YXCK_NHWK_F32(PT, PT, AddReluAdd, 1024, 416, 128, 6, 16, true, false)>;
// clang-format on
// use this in a multi-thread environment (a local C buffer is needed to avoid cache-coherence
// traffic, although sometimes no local C buffer is better...)
using device_conv2d_fwd_bias_activation_add_avx2_nhwc_yxck_nhwk_f32_mt_instances = std::tuple<
// clang-format off
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_YXCK_NHWK_F32(PT, PT, AddReluAdd, 24, 24, 256, 4, 24, false, false),
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_YXCK_NHWK_F32(PT, PT, AddReluAdd, 32, 24, 256, 4, 24, false, false),
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_YXCK_NHWK_F32(PT, PT, AddReluAdd, 40, 24, 256, 4, 24, false, false),
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_YXCK_NHWK_F32(PT, PT, AddReluAdd, 48, 24, 256, 4, 24, false, false),
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_YXCK_NHWK_F32(PT, PT, AddReluAdd, 48, 48, 256, 4, 24, false, false),
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_YXCK_NHWK_F32(PT, PT, AddReluAdd, 56, 24, 256, 4, 24, false, false),
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_YXCK_NHWK_F32(PT, PT, AddReluAdd, 72, 16, 128, 6, 16, false, false),
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_YXCK_NHWK_F32(PT, PT, AddReluAdd, 72, 16, 256, 6, 16, false, false),
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_YXCK_NHWK_F32(PT, PT, AddReluAdd, 72, 32, 128, 6, 16, false, false),
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_YXCK_NHWK_F32(PT, PT, AddReluAdd, 72, 32, 256, 6, 16, false, false),
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_YXCK_NHWK_F32(PT, PT, AddReluAdd, 96, 32, 128, 6, 16, false, false),
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_YXCK_NHWK_F32(PT, PT, AddReluAdd, 96, 64, 128, 6, 16, false, false),
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_YXCK_NHWK_F32(PT, PT, AddReluAdd, 120, 32, 128, 6, 16, false, false),
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_YXCK_NHWK_F32(PT, PT, AddReluAdd, 120, 64, 128, 6, 16, false, false),
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_YXCK_NHWK_F32(PT, PT, AddReluAdd, 256, 128, 128, 6, 16, true, false),
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_YXCK_NHWK_F32(PT, PT, AddReluAdd, 128, 256, 128, 6, 16, true, false),
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_YXCK_NHWK_F32(PT, PT, AddReluAdd, 512, 240, 128, 4, 24, true, false),
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_YXCK_NHWK_F32(PT, PT, AddReluAdd, 512, 256, 128, 6, 16, true, false),
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_YXCK_NHWK_F32(PT, PT, AddReluAdd, 768, 320, 128, 6, 16, true, false),
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_YXCK_NHWK_F32(PT, PT, AddReluAdd, 896, 352, 128, 6, 16, true, false),
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_YXCK_NHWK_F32(PT, PT, AddReluAdd, 1024, 416, 128, 6, 16, true, false)>;
// clang-format on
void add_device_conv2d_fwd_bias_activation_add_avx2_nhwc_yxck_nhwk(
std::vector<DeviceConvFwdBiasActivationAddPtr<PT, PT, AddReluAdd>>& instances)
{
ck::tensor_operation::device::add_device_operation_instances(
instances, device_conv2d_fwd_bias_activation_add_avx2_nhwc_yxck_nhwk_f32_instances{});
}
void add_device_conv2d_fwd_bias_activation_add_avx2_nhwc_yxck_nhwk_local_c(
std::vector<DeviceConvFwdBiasActivationAddPtr<PT, PT, AddReluAdd>>& instances)
{
ck::tensor_operation::device::add_device_operation_instances(
instances,
device_conv2d_fwd_bias_activation_add_avx2_nhwc_yxck_nhwk_f32_local_c_instances{});
}
void add_device_conv2d_fwd_bias_activation_add_avx2_nhwc_yxck_nhwk_mt(
std::vector<DeviceConvFwdBiasActivationAddPtr<PT, PT, AddReluAdd>>& instances)
{
ck::tensor_operation::device::add_device_operation_instances(
instances, device_conv2d_fwd_bias_activation_add_avx2_nhwc_yxck_nhwk_f32_mt_instances{});
}
} // namespace device_conv2d_fwd_bias_activation_add_avx2_instance
} // namespace device
} // namespace cpu
} // namespace tensor_operation
} // namespace ck
@@ -233,68 +233,30 @@ void test_ukernel(ukenrel_t uk,
int max_threads = omp_get_max_threads();
auto invoke_uk = [&](ck::cpu::ThreadwiseGemmParam& param, float* current_mat_c) {
- if constexpr(std::is_same<Row, ALayout>::value && std::is_same<Row, BLayout>::value)
- {
-     assert(m % uk.ThreadMr == 0 && n == uk.ThreadNr);
-     FloatA* p_a = mat_a;
-     float* p_c = current_mat_c;
-     param.p_a = p_a;
-     param.p_c = p_c;
-     for(uint32_t i_m = 0; i_m < m; i_m += uk.ThreadMr)
-     {
-         uk.Run(&param);
-         p_a += uk.ThreadMr * k;
-         p_c += uk.ThreadMr * n;
-         param.p_a = p_a;
-         param.p_c = p_c;
-     }
- }
- else if constexpr(std::is_same<Row, ALayout>::value && std::is_same<Col, BLayout>::value)
- {
-     assert(m % uk.ThreadMr == 0 && n % uk.ThreadNr == 0);
-     FloatA* p_a = mat_a;
-     float* p_c = current_mat_c;
-     param.p_a = p_a;
-     param.p_b = mat_b;
-     param.p_c = p_c;
-     for(uint32_t i_m = 0; i_m < m; i_m += uk.ThreadMr)
-     {
-         float* p_c_n = p_c;
-         FloatB* p_b_n = mat_b;
-         for(uint32_t i_n = 0; i_n < n; i_n += uk.ThreadNr)
-         {
-             uk.Run(&param);
-             p_b_n += uk.ThreadNr * k; // ThreadNr/8*k*8
-             p_c_n += uk.ThreadNr;
-             param.p_b = p_b_n;
-             param.p_c = p_c_n;
-         }
-         p_a += uk.ThreadMr * k;
-         p_c += uk.ThreadMr * n;
-         param.p_a = p_a;
-         param.p_b = mat_b;
-         param.p_c = p_c;
-     }
- }
- else if constexpr(std::is_same<Col, ALayout>::value && std::is_same<Row, BLayout>::value)
- {
-     assert(m == uk.ThreadMr && n == uk.ThreadNr);
-     uk.Run(&param);
- }
- else
- {
-     assert(m % uk.ThreadMr == 0 && n % uk.ThreadNr == 0);
-     FloatB* p_b = mat_b;
-     float* p_c = current_mat_c;
-     param.p_b = p_b;
-     param.p_c = p_c;
-     for(uint32_t i_n = 0; i_n < n; i_n += uk.ThreadNr)
-     {
-         uk.Run(&param);
-         p_b += uk.ThreadNr * k; // ThreadNr/8*k*8
-         p_c += uk.ThreadNr;
-         param.p_b = p_b;
-         param.p_c = p_c;
-     }
- }
+ assert(m % uk.ThreadMr == 0 && n % uk.ThreadNr == 0);
+ for(uint32_t i_m = 0; i_m < m; i_m += uk.ThreadMr)
+ {
+     if constexpr(std::is_same<Row, ALayout>::value)
+     {
+         param.p_a = mat_a + i_m * k;
+     }
+     else
+     {
+         param.p_a = mat_a + i_m;
+     }
+     for(uint32_t i_n = 0; i_n < n; i_n += uk.ThreadNr)
+     {
+         if constexpr(std::is_same<Row, BLayout>::value)
+         {
+             param.p_b = mat_b + i_n;
+         }
+         else
+         {
+             param.p_b = mat_b + i_n * k;
+         }
+         param.p_c = current_mat_c + i_m * n + i_n;
+         uk.Run(&param);
+     }
+ }
};
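The tile-origin arithmetic this loop relies on can be summarized in a small, illustrative sketch (the helper below is not part of the test; it assumes packed matrices, i.e. A is m x k, B is k x n, C is m x n, and mirrors the param.p_a / param.p_b / param.p_c assignments above):

#include <cstddef>

struct TileOrigin { std::size_t a, b, c; };

// offsets of the (i_m, i_n) micro-tile inside packed A/B/C
template <bool ARowMajor, bool BRowMajor>
TileOrigin tile_origin(std::size_t i_m, std::size_t i_n, std::size_t n, std::size_t k)
{
    TileOrigin o;
    o.a = ARowMajor ? i_m * k : i_m; // row-major A: rows are k apart; col-major A: consecutive rows are adjacent
    o.b = BRowMajor ? i_n : i_n * k; // row-major B: columns are adjacent; col-major B: columns are k apart
    o.c = i_m * n + i_n;             // C is row-major m x n
    return o;
}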
@@ -358,7 +320,11 @@ void test_ukernel(ukenrel_t uk,
}
// implement small ukernel on L1
- template <typename FloatA, typename FloatB, typename ALayout, typename BLayout>
+ template <typename FloatA,
+           typename FloatB,
+           typename ALayout,
+           typename BLayout,
+           typename thread_gemm_instance>
void test_cpu_ukernel(float alpha, uint32_t m, uint32_t n, uint32_t k)
{
int max_threads = omp_get_max_threads();
@@ -382,17 +348,18 @@ void test_cpu_ukernel(float alpha, uint32_t m, uint32_t n, uint32_t k)
k);
// using thread_gemm_instance = thread_gemm_avx2_mxn_6x16_instances<ALayout, BLayout>;
- using thread_gemm_instance = thread_gemm_avx2_mxn_4x24_instances<ALayout, BLayout>;
+ // using thread_gemm_instance = thread_gemm_avx2_mxn_4x24_instances<ALayout, BLayout>;
bool found = false;
ck::static_for<0, std::tuple_size_v<thread_gemm_instance>, 1>{}([&](auto i) {
using uk_type = std::tuple_element_t<i, thread_gemm_instance>;
if(m % uk_type::ThreadMr != 0 || n % uk_type::ThreadNr != 0)
return;
- if((m != uk_type::ThreadMr && std::is_same<typename uk_type::MatrixALayout, Col>::value) ||
-     (n != uk_type::ThreadNr && std::is_same<typename uk_type::MatrixBLayout, Row>::value))
-     // only k is the fast changing dim of A/B can we do muldiplt m, n
-     return;
+ // if((m != uk_type::ThreadMr && std::is_same<typename uk_type::MatrixALayout, Col>::value)
+ // ||
+ // (n != uk_type::ThreadNr && std::is_same<typename uk_type::MatrixBLayout, Row>::value))
+ // // only k is the fast changing dim of A/B can we do muldiplt m, n
+ // return;
if(found)
return;
@@ -435,8 +402,21 @@ int main(int argc, char** argv)
omp_set_num_threads(1);
printf("max threads:%d\n", omp_get_max_threads());
- test_cpu_ukernel<AType, BType, Row, Row>(alpha, m, n, k);
- test_cpu_ukernel<AType, BType, Row, Col>(alpha, m, n, k);
- test_cpu_ukernel<AType, BType, Col, Row>(alpha, m, n, k);
- test_cpu_ukernel<AType, BType, Col, Col>(alpha, m, n, k);
+ test_cpu_ukernel<AType, BType, Row, Row, thread_gemm_avx2_mxn_4x24_instances<Row, Row>>(
+     alpha, m, n, k);
+ test_cpu_ukernel<AType, BType, Row, Col, thread_gemm_avx2_mxn_4x24_instances<Row, Col>>(
+     alpha, m, n, k);
+ test_cpu_ukernel<AType, BType, Col, Row, thread_gemm_avx2_mxn_4x24_instances<Col, Row>>(
+     alpha, m, n, k);
+ test_cpu_ukernel<AType, BType, Col, Col, thread_gemm_avx2_mxn_4x24_instances<Col, Col>>(
+     alpha, m, n, k);
+ test_cpu_ukernel<AType, BType, Row, Row, thread_gemm_avx2_mxn_6x16_instances<Row, Row>>(
+     alpha, m, n, k);
+ test_cpu_ukernel<AType, BType, Row, Col, thread_gemm_avx2_mxn_6x16_instances<Row, Col>>(
+     alpha, m, n, k);
+ test_cpu_ukernel<AType, BType, Col, Row, thread_gemm_avx2_mxn_6x16_instances<Col, Row>>(
+     alpha, m, n, k);
+ test_cpu_ukernel<AType, BType, Col, Col, thread_gemm_avx2_mxn_6x16_instances<Col, Col>>(
+     alpha, m, n, k);
}