Commit ddad386b authored by carlushuang

add cpu bias_relu_add example

parent 505194d7
@@ -58,3 +58,4 @@ add_subdirectory(16_gemm_reduce)
add_subdirectory(18_batched_gemm_reduce)
add_subdirectory(cpu_01_conv2d_fwd)
add_subdirectory(cpu_02_conv2d_fwd_bias_relu_add)
add_example_executable(example_cpu_conv2d_fwd_bias_relu_add cpu_conv2d_fwd_bias_relu_add.cpp)
target_link_libraries(example_cpu_conv2d_fwd_bias_relu_add PRIVATE device_conv2d_fwd_bias_activation_add_cpu_instance)
set_target_properties(example_cpu_conv2d_fwd_bias_relu_add PROPERTIES LINK_FLAGS "${OMP_LINK_FLAG}")
target_link_libraries(example_cpu_conv2d_fwd_bias_relu_add PRIVATE "${OMP_LIBRARY}")
target_compile_options(example_cpu_conv2d_fwd_bias_relu_add PRIVATE "${OMP_CXX_FLAG}")
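# Note: OMP_LINK_FLAG, OMP_LIBRARY and OMP_CXX_FLAG are assumed to be provided by the
# top-level CMake setup (e.g. an OpenMP probe); this example only consumes them, mirroring
# the existing cpu_01_conv2d_fwd target.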
#include <sstream>
#include "config.hpp"
#include "device.hpp"
#include "host_tensor.hpp"
#include "host_tensor_generator.hpp"
#include "tensor_layout.hpp"
#include "device_tensor.hpp"
#include "device_convnd_fwd_bias_activation_add_avx2_nhwc_kyxc_nhwk.hpp"
#include "element_wise_operation_cpu.hpp"
#include "reference_conv_fwd_bias_activation_add.hpp"
#include "element_wise_operation_cpu.hpp"
#include "dynamic_buffer_cpu.hpp"
#include <omp.h>
#define AVX2_DATA_ALIGNMENT 32
#define TEST_LAYOUT_NHWC_KYXC_NHWK 0
#define TEST_LAYOUT_NHWC_KYXCK8_NHWK 1
#define TEST_LAYOUT TEST_LAYOUT_NHWC_KYXC_NHWK
using F32 = float;
using F16 = ck::half_t;
namespace ck {
namespace tensor_operation {
namespace cpu {
namespace device {
namespace device_conv2d_fwd_bias_activation_add_avx2_instance {
using PassThrough = ck::tensor_operation::cpu::element_wise::PassThrough;
using AddReluAdd = ck::tensor_operation::cpu::element_wise::AddReluAdd;
void add_device_conv2d_fwd_bias_activation_add_avx2_nhwc_kyxc_nhwk(
std::vector<DeviceConvFwdBiasActivationAddPtr<PassThrough, PassThrough, AddReluAdd>>&
instances);
void add_device_conv2d_fwd_bias_activation_add_avx2_nhwc_kyxc_nhwk_local_c(
std::vector<DeviceConvFwdBiasActivationAddPtr<PassThrough, PassThrough, AddReluAdd>>&
instances);
void add_device_conv2d_fwd_bias_activation_add_avx2_nhwc_kyxc_nhwk_mt(
std::vector<DeviceConvFwdBiasActivationAddPtr<PassThrough, PassThrough, AddReluAdd>>&
instances);
} // namespace device_conv2d_fwd_bias_activation_add_avx2_instance
} // namespace device
} // namespace cpu
} // namespace tensor_operation
} // namespace ck
using InElementOp = ck::tensor_operation::cpu::element_wise::PassThrough;
using WeiElementOp = ck::tensor_operation::cpu::element_wise::PassThrough;
using OutElementOp = ck::tensor_operation::cpu::element_wise::AddReluAdd;
template <typename T>
static bool
check_out(const Tensor<T>& ref, const Tensor<T>& result, double nrms, int per_pixel_check = 0)
{
int error_count = 0;
float max_diff = 1e-5;
double square_difference = .0;
double mag1 = .0;
double mag2 = .0;
for(int i = 0; i < ref.mData.size(); ++i)
{
double ri = (double)ref.mData[i];
double pi = (double)result.mData[i];
double d = ri - pi;
if(per_pixel_check)
{
if(max_diff < std::abs(d))
{
error_count++;
printf("idx:%3d, ref:%f, res:%f (diff:%f)\n",
i,
double(ref.mData[i]),
double(result.mData[i]),
d);
}
}
square_difference += d * d;
if(std::abs(mag1) < std::abs(ri))
mag1 = ri;
if(std::abs(mag2) < std::abs(pi))
mag2 = pi;
}
double mag = std::max({std::fabs(mag1), std::fabs(mag2), std::numeric_limits<double>::min()});
double computed_nrms = std::sqrt(square_difference) / (std::sqrt(ref.mData.size()) * mag);
if(computed_nrms >= nrms)
printf("nrms:%lf, mag1:%lf, mag2:%lf, expected_nrms is %1f\n",
computed_nrms,
mag1,
mag2,
nrms);
return computed_nrms < nrms && error_count == 0;
}
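// For reference, the check above passes when the normalized RMS error
//   nrms = sqrt(sum_i (ref_i - res_i)^2) / (sqrt(n) * max_i |value_i|)
// stays below the given threshold. Small worked example: ref = {1, 2, 3}, res = {1, 2, 3.001}
// gives sqrt(1e-6) / (sqrt(3) * 3) ~= 1.9e-4, which passes a 1e-3 bound but not the 1e-6 one
// used below. The optional per-pixel check additionally reports any element whose absolute
// difference exceeds 1e-5.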
float calculate_gflops() {}
template <typename T>
void transpose_kyxc_2_kyxc8k(Tensor<T>& dst,
const Tensor<T>& src,
ck::index_t K,
ck::index_t Y,
ck::index_t X,
ck::index_t C)
{
ck::index_t batch = K / 8;
ck::index_t row = 8;
ck::index_t col = C * Y * X;
for(auto i_b = 0; i_b < batch; i_b++)
{
for(auto i_r = 0; i_r < row; i_r++)
{
for(auto i_c = 0; i_c < col; i_c++)
{
ck::index_t src_idx = i_b * row * col + i_r * col + i_c;
ck::index_t dst_idx = i_b * col * row + i_c * row + i_r;
dst.mData[dst_idx] = src.mData[src_idx];
}
}
}
}
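// Illustration (only used when TEST_LAYOUT_NHWC_KYXCK8_NHWK is selected): the helper above
// regroups the weight from [K, Y*X*C] into [K/8, Y*X*C, 8], so that groups of 8 output
// channels become the fastest-varying dimension. E.g. with K = 16 and Y*X*C = 2, source
// element (k = 9, yxc = 1) at flat index 9*2 + 1 = 19 moves to batch 1, column 1, row 1,
// i.e. destination index 1*2*8 + 1*8 + 1 = 25.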
int main(int argc, char* argv[])
{
int data_type = 0;
int init_method = 0;
// Conv shape
ck::index_t N = 2;
ck::index_t K = 256;
ck::index_t C = 192;
ck::index_t Y = 3;
ck::index_t X = 3;
ck::index_t Hi = 71;
ck::index_t Wi = 71;
ck::index_t conv_stride_h = 1;
ck::index_t conv_stride_w = 1;
ck::index_t conv_dilation_h = 1;
ck::index_t conv_dilation_w = 1;
ck::index_t in_left_pad_h = 1;
ck::index_t in_left_pad_w = 1;
ck::index_t in_right_pad_h = 1;
ck::index_t in_right_pad_w = 1;
if(argc == 1)
{
data_type = 0;
init_method = 1;
}
else if(argc == 3)
{
data_type = std::stoi(argv[1]);
init_method = std::stoi(argv[2]);
}
else if(argc == 18)
{
data_type = std::stoi(argv[1]);
init_method = std::stoi(argv[2]);
N = std::stoi(argv[3]);
K = std::stoi(argv[4]);
C = std::stoi(argv[5]);
Y = std::stoi(argv[6]);
X = std::stoi(argv[7]);
Hi = std::stoi(argv[8]);
Wi = std::stoi(argv[9]);
conv_stride_h = std::stoi(argv[10]);
conv_stride_w = std::stoi(argv[11]);
conv_dilation_h = std::stoi(argv[12]);
conv_dilation_w = std::stoi(argv[13]);
in_left_pad_h = std::stoi(argv[14]);
in_left_pad_w = std::stoi(argv[15]);
in_right_pad_h = std::stoi(argv[16]);
in_right_pad_w = std::stoi(argv[17]);
}
else
{
printf("arg1: data type (0=fp32, 1=fp16)\n");
printf("arg2: initialization (0=no init, 1=integer value, 2=decimal value)\n");
printf("arg3 to 17: N, K, C, Y, X, Hi, Wi, Sy, Sx, Dy, Dx, LeftPy, LeftPx, RightPy, "
"RightPx\n");
exit(1);
}
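// Example invocations (hypothetical paths; binary name taken from the CMakeLists above):
//   ./example_cpu_conv2d_fwd_bias_relu_add              fp32, integer init, default shape
//   ./example_cpu_conv2d_fwd_bias_relu_add 0 2          fp32, decimal init
//   ./example_cpu_conv2d_fwd_bias_relu_add 0 1 2 256 192 3 3 71 71 1 1 1 1 1 1 1 1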
auto Run = [&](auto input_type, auto wei_type, auto out_type) {
using InDataType = decltype(input_type);
using WeiDataType = decltype(wei_type);
using OutDataType = decltype(out_type);
using ReferenceConvFwdInstance =
ck::tensor_operation::host::ReferenceConvFwd_Bias_Activation_Add<InDataType,
WeiDataType,
OutDataType,
InElementOp,
WeiElementOp,
OutElementOp>;
const ck::index_t YEff = (Y - 1) * conv_dilation_h + 1;
const ck::index_t XEff = (X - 1) * conv_dilation_w + 1;
const ck::index_t Ho = (Hi + in_left_pad_h + in_right_pad_h - YEff) / conv_stride_h + 1;
const ck::index_t Wo = (Wi + in_left_pad_w + in_right_pad_w - XEff) / conv_stride_w + 1;
const std::vector<ck::index_t> input_spatial_lengths{{Hi, Wi}};
const std::vector<ck::index_t> filter_spatial_lengths{{Y, X}};
const std::vector<ck::index_t> output_spatial_lengths{{Ho, Wo}};
const std::vector<ck::index_t> conv_filter_strides{{conv_stride_h, conv_stride_w}};
const std::vector<ck::index_t> conv_filter_dilations{{conv_dilation_h, conv_dilation_w}};
const std::vector<ck::index_t> input_left_pads{{in_left_pad_h, in_left_pad_w}};
const std::vector<ck::index_t> input_right_pads{{in_right_pad_h, in_right_pad_w}};
auto f_host_tensor_descriptor = [](std::size_t N_,
std::size_t C_,
std::size_t H_,
std::size_t W_) {
return HostTensorDescriptor(std::vector<std::size_t>({N_, C_, H_, W_}),
std::vector<std::size_t>({C_ * H_ * W_, 1, W_ * C_, C_}));
};
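// The descriptor above lays out a logically (N, C, H, W) tensor in NHWC memory order via the
// strides {C*H*W, 1, W*C, C}: element (n, c, h, w) sits at offset n*C*H*W + h*W*C + w*C + c,
// so channels are contiguous, which is the layout the AVX2 NHWC/KYXC instances below operate on.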
Tensor<InDataType> in_n_c_hi_wi(f_host_tensor_descriptor(N, C, Hi, Wi));
Tensor<WeiDataType> wei_k_c_y_x(f_host_tensor_descriptor(K, C, Y, X));
#if TEST_LAYOUT == TEST_LAYOUT_NHWC_KYXCK8_NHWK
Tensor<WeiDataType> wei_k_c_y_x_k8(
f_host_tensor_descriptor(K, C, Y, X)); // TODO: This is only to hold data
#endif
Tensor<OutDataType> out_n_k_ho_wo_host_result(f_host_tensor_descriptor(N, K, Ho, Wo));
Tensor<OutDataType> out_n_k_ho_wo_device_result(f_host_tensor_descriptor(N, K, Ho, Wo));
// bias: assume contiguous 1d vector
Tensor<OutDataType> bias(
HostTensorDescriptor(std::vector<std::size_t>({static_cast<std::size_t>(K)})));
// residual: assume same layout as output tensor
Tensor<OutDataType> residual(f_host_tensor_descriptor(N, K, Ho, Wo));
std::cout << "in (N, C, Hi, Wi): " << in_n_c_hi_wi.mDesc << std::endl;
std::cout << "wei(K, C, Y, X): " << wei_k_c_y_x.mDesc << std::endl;
std::cout << "out(N, K, Ho, Wo): " << out_n_k_ho_wo_host_result.mDesc << std::endl;
std::cout << "bias: " << bias.mDesc << std::endl;
std::cout << "residual: " << residual.mDesc << std::endl;
std::cout << "LPad(H, W):" << in_left_pad_h << "," << in_left_pad_w
<< ", RPad(H, W):" << in_right_pad_h << "," << in_right_pad_w
<< ", Stride(H, W):" << conv_stride_h << ", " << conv_stride_w
<< ", Dilation(H, W):" << conv_dilation_h << ", " << conv_dilation_w
<< ", Threads:" << omp_get_max_threads() << std::endl;
int per_pixel_check = 0;
switch(init_method)
{
case 0:
in_n_c_hi_wi.GenerateTensorValue(GeneratorTensor_1<InDataType>{});
wei_k_c_y_x.GenerateTensorValue(GeneratorTensor_1<WeiDataType>{});
bias.GenerateTensorValue(GeneratorTensor_1<WeiDataType>{});
residual.GenerateTensorValue(GeneratorTensor_1<WeiDataType>{});
per_pixel_check = 1;
break;
case 1:
in_n_c_hi_wi.GenerateTensorValue(GeneratorTensor_2<InDataType>{-5, 5});
// in_n_c_hi_wi.GenerateTensorValue(GeneratorTensor_1<InDataType>{});
wei_k_c_y_x.GenerateTensorValue(GeneratorTensor_2<WeiDataType>{-5, 5});
// wei_k_c_y_x.GenerateTensorValue(GeneratorTensor_1<WeiDataType>{});
bias.GenerateTensorValue(GeneratorTensor_2<WeiDataType>{-5, 5});
residual.GenerateTensorValue(GeneratorTensor_2<WeiDataType>{-5, 5});
per_pixel_check = 1;
break;
case 2:
in_n_c_hi_wi.GenerateTensorValue(GeneratorTensor_3<InDataType>{0.0, 1.0});
wei_k_c_y_x.GenerateTensorValue(GeneratorTensor_3<WeiDataType>{-0.5, 0.5});
bias.GenerateTensorValue(GeneratorTensor_3<InDataType>{0.0, 1.0});
residual.GenerateTensorValue(GeneratorTensor_3<InDataType>{0.0, 1.0});
break;
default:
in_n_c_hi_wi.GenerateTensorValue(GeneratorTensor_3<InDataType>{0, 1});
wei_k_c_y_x.GenerateTensorValue(GeneratorTensor_3<WeiDataType>{-1, 1});
bias.GenerateTensorValue(GeneratorTensor_3<InDataType>{0.0, 1.0});
residual.GenerateTensorValue(GeneratorTensor_3<InDataType>{0.0, 1.0});
}
DeviceAlignedMemCPU in_device_buf(sizeof(InDataType) * in_n_c_hi_wi.mDesc.GetElementSpace(),
AVX2_DATA_ALIGNMENT);
DeviceAlignedMemCPU wei_device_buf(
sizeof(WeiDataType) * wei_k_c_y_x.mDesc.GetElementSpace(), AVX2_DATA_ALIGNMENT);
DeviceAlignedMemCPU out_device_buf(sizeof(OutDataType) *
out_n_k_ho_wo_host_result.mDesc.GetElementSpace(),
AVX2_DATA_ALIGNMENT);
DeviceAlignedMemCPU bias_device_buf(sizeof(OutDataType) * bias.mDesc.GetElementSpace(),
AVX2_DATA_ALIGNMENT);
DeviceAlignedMemCPU resi_device_buf(sizeof(OutDataType) * residual.mDesc.GetElementSpace(),
AVX2_DATA_ALIGNMENT);
in_device_buf.ToDevice(in_n_c_hi_wi.mData.data());
#if TEST_LAYOUT == TEST_LAYOUT_NHWC_KYXC_NHWK
wei_device_buf.ToDevice(wei_k_c_y_x.mData.data());
#endif
#if TEST_LAYOUT == TEST_LAYOUT_NHWC_KYXCK8_NHWK
transpose_kyxc_2_kyxc8k(wei_k_c_y_x_k8, wei_k_c_y_x, K, Y, X, C);
wei_device_buf.ToDevice(wei_k_c_y_x_k8.mData.data());
#endif
bias_device_buf.ToDevice(bias.mData.data());
resi_device_buf.ToDevice(residual.mData.data());
// get host result
{
auto ref_conv = ReferenceConvFwdInstance{};
auto ref_invoker = ref_conv.MakeInvoker();
auto ref_argument = ref_conv.MakeArgument(in_n_c_hi_wi,
wei_k_c_y_x,
out_n_k_ho_wo_host_result,
bias,
residual,
conv_filter_strides,
conv_filter_dilations,
input_left_pads,
input_right_pads,
InElementOp{},
WeiElementOp{},
OutElementOp{});
ref_invoker.Run(ref_argument);
}
using PassThrough = ck::tensor_operation::cpu::element_wise::PassThrough;
using AddReluAdd = ck::tensor_operation::cpu::element_wise::AddReluAdd;
using DeviceConvFwdNoOpPtr = ck::tensor_operation::cpu::device::
DeviceConvFwdBiasActivationAddPtr<PassThrough, PassThrough, AddReluAdd>;
// add device Conv instances
std::vector<DeviceConvFwdNoOpPtr> conv_ptrs;
if constexpr(ck::is_same_v<ck::remove_cv_t<InDataType>, float> &&
ck::is_same_v<ck::remove_cv_t<WeiDataType>, float> &&
ck::is_same_v<ck::remove_cv_t<OutDataType>, float>)
{
#if TEST_LAYOUT == TEST_LAYOUT_NHWC_KYXC_NHWK
if(omp_get_max_threads() > 1)
{
ck::tensor_operation::cpu::device::
device_conv2d_fwd_bias_activation_add_avx2_instance::
add_device_conv2d_fwd_bias_activation_add_avx2_nhwc_kyxc_nhwk_mt(conv_ptrs);
ck::tensor_operation::cpu::device::
device_conv2d_fwd_bias_activation_add_avx2_instance::
add_device_conv2d_fwd_bias_activation_add_avx2_nhwc_kyxc_nhwk(conv_ptrs);
}
else
{
if(K % 8 == 0)
ck::tensor_operation::cpu::device::
device_conv2d_fwd_bias_activation_add_avx2_instance::
add_device_conv2d_fwd_bias_activation_add_avx2_nhwc_kyxc_nhwk(
conv_ptrs);
else
ck::tensor_operation::cpu::device::
device_conv2d_fwd_bias_activation_add_avx2_instance::
add_device_conv2d_fwd_bias_activation_add_avx2_nhwc_kyxc_nhwk_local_c(
conv_ptrs);
}
#endif
#if TEST_LAYOUT == TEST_LAYOUT_NHWC_KYXCK8_NHWK
if(omp_get_max_threads() > 1)
{
ck::tensor_operation::cpu::device::
device_conv2d_fwd_bias_activation_add_avx2_instance::
add_device_conv2d_fwd_avx2_nhwc_kyxck8_nhwk_mt(conv_ptrs);
ck::tensor_operation::cpu::device::
device_conv2d_fwd_bias_activation_add_avx2_instance::
add_device_conv2d_fwd_avx2_nhwc_kyxck8_nhwk(conv_ptrs);
}
else
{
if(K % 8 == 0)
ck::tensor_operation::cpu::device::
device_conv2d_fwd_bias_activation_add_avx2_instance::
add_device_conv2d_fwd_avx2_nhwc_kyxck8_nhwk(conv_ptrs);
else
ck::tensor_operation::cpu::device::
device_conv2d_fwd_bias_activation_add_avx2_instance::
add_device_conv2d_fwd_avx2_nhwc_kyxck8_nhwk_local_c(conv_ptrs);
}
#endif
}
if(conv_ptrs.size() <= 0)
{
throw std::runtime_error("wrong! no device Conv instance found");
}
// profile device Conv instances
bool success = true;
double fastest_kernel_time = std::numeric_limits<double>::max();
std::string fastest_kernel_name = "";
double fastest_kernel_gflops = 0;
for(auto& conv_ptr : conv_ptrs)
{
auto argument_ptr = conv_ptr->MakeArgumentPointer(
static_cast<InDataType*>(in_device_buf.GetDeviceBuffer()),
static_cast<WeiDataType*>(wei_device_buf.GetDeviceBuffer()),
static_cast<OutDataType*>(out_device_buf.GetDeviceBuffer()),
static_cast<const OutDataType*>(bias_device_buf.GetDeviceBuffer()),
static_cast<const OutDataType*>(resi_device_buf.GetDeviceBuffer()),
N,
K,
C,
input_spatial_lengths,
filter_spatial_lengths,
output_spatial_lengths,
conv_filter_strides,
conv_filter_dilations,
input_left_pads,
input_right_pads,
InElementOp{},
WeiElementOp{},
OutElementOp{});
if(conv_ptr->IsSupportedArgument(argument_ptr.get()))
{
auto invoker_ptr = conv_ptr->MakeInvokerPointer();
double time = invoker_ptr->Run(argument_ptr.get(), StreamConfig{}, 10);
double total_flop = static_cast<double>(2) * N * C * Ho * Wo * K * Y * X;
double gflops = (total_flop * 1e-6) / time;
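// Rough sanity check for the default problem (N=2, C=192, K=256, Y=X=3, Ho=Wo=71):
// total_flop = 2*2*192*71*71*256*3*3 ~= 8.92e9, so e.g. a 10 ms run would report ~892
// Gflops (time is in milliseconds, hence the 1e-6 scaling).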
out_device_buf.FromDevice(out_n_k_ho_wo_device_result.mData.data());
if(!check_out(out_n_k_ho_wo_host_result,
out_n_k_ho_wo_device_result,
1e-6,
per_pixel_check))
{
std::cout << "Fail Info: " << conv_ptr->GetTypeString() << std::endl;
success = false;
}
else
{
std::cout << "Pass Info: " << conv_ptr->GetTypeString() << ", Time:" << time
<< "ms, Gflops:" << gflops << std::endl;
if(time < fastest_kernel_time)
{
fastest_kernel_time = time;
fastest_kernel_name = conv_ptr->GetTypeString();
fastest_kernel_gflops = gflops;
}
}
}
else
{
std::cout << "Not support Info: " << conv_ptr->GetTypeString() << std::endl;
}
}
if(fastest_kernel_time != std::numeric_limits<double>::max())
{
std::cout << " fastest:" << fastest_kernel_name << ", time:" << fastest_kernel_time
<< "ms, Gflops:" << fastest_kernel_gflops << std::endl;
}
return 0;
// if(success)
// {
// std::cout << "test conv2d fwd cpu : Pass" << std::endl;
// return 0;
// }
// else
// {
// std::cout << "test conv2d fwd cpu: Fail " << std::endl;
// return -1;
// }
};
if(data_type == 0)
{
return Run(F32(), F32(), F32());
}
else
{
return 1;
}
}
#ifndef DEVICE_CONV_FWD_CPU_HPP
#define DEVICE_CONV_FWD_CPU_HPP
#include <iostream>
#include "device_base_cpu.hpp"
namespace ck {
namespace tensor_operation {
namespace cpu {
namespace device {
template <typename InElementwiseOperation,
typename WeiElementwiseOperation,
typename OutElementwiseOperation>
struct DeviceConvFwd : public BaseOperator
{
virtual std::unique_ptr<BaseArgument>
MakeArgumentPointer(const void* p_in,
const void* p_wei,
void* p_out,
ck::index_t N,
ck::index_t K,
ck::index_t C,
std::vector<ck::index_t> input_spatial_lengths,
std::vector<ck::index_t> filter_spatial_lengths,
std::vector<ck::index_t> output_spatial_lengths,
std::vector<ck::index_t> conv_filter_strides,
std::vector<ck::index_t> conv_filter_dilations,
std::vector<ck::index_t> input_left_pads,
std::vector<ck::index_t> input_right_pads,
InElementwiseOperation in_element_op,
WeiElementwiseOperation wei_element_op,
OutElementwiseOperation out_element_op) = 0;
virtual std::unique_ptr<BaseInvoker> MakeInvokerPointer() = 0;
};
template <typename InElementwiseOperation,
typename WeiElementwiseOperation,
typename OutElementwiseOperation>
using DeviceConvFwdPtr = std::unique_ptr<
DeviceConvFwd<InElementwiseOperation, WeiElementwiseOperation, OutElementwiseOperation>>;
template <typename InElementwiseOperation,
typename WeiElementwiseOperation,
typename OutElementwiseOperation>
struct DeviceConvFwdBiasActivationAdd : public BaseOperator
{
virtual std::unique_ptr<BaseArgument>
MakeArgumentPointer(const void* p_in,
const void* p_wei,
void* p_out,
const void* p_bias_grid,
const void* p_add_grid,
ck::index_t N,
ck::index_t K,
ck::index_t C,
std::vector<ck::index_t> input_spatial_lengths,
std::vector<ck::index_t> filter_spatial_lengths,
std::vector<ck::index_t> output_spatial_lengths,
std::vector<ck::index_t> conv_filter_strides,
std::vector<ck::index_t> conv_filter_dilations,
std::vector<ck::index_t> input_left_pads,
std::vector<ck::index_t> input_right_pads,
InElementwiseOperation in_element_op,
WeiElementwiseOperation wei_element_op,
OutElementwiseOperation out_element_op) = 0;
virtual std::unique_ptr<BaseInvoker> MakeInvokerPointer() = 0;
};
template <typename InElementwiseOperation,
typename WeiElementwiseOperation,
typename OutElementwiseOperation>
using DeviceConvFwdBiasActivationAddPtr =
std::unique_ptr<DeviceConvFwdBiasActivationAdd<InElementwiseOperation,
WeiElementwiseOperation,
OutElementwiseOperation>>;
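// Sketch of the intended call sequence through this interface (it mirrors the example's
// main() above; names are illustrative):
//   auto arg = op->MakeArgumentPointer(p_in, p_wei, p_out, p_bias, p_add, N, K, C, ...);
//   if(op->IsSupportedArgument(arg.get()))
//       op->MakeInvokerPointer()->Run(arg.get(), StreamConfig{}, nrepeat);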
} // namespace device
} // namespace cpu
} // namespace tensor_operation
} // namespace ck
#endif
#ifndef DEVICE_CONV2D_FWD_BIAS_ACTIVATION_ADD_AVX2_NHWC_KYXC_NHWK_HPP
#define DEVICE_CONV2D_FWD_BIAS_ACTIVATION_ADD_AVX2_NHWC_KYXC_NHWK_HPP
#include <iostream>
#include <sstream>
#include <numeric>
#include "device.hpp"
#include "device_base_cpu.hpp"
#include "device_conv_fwd_cpu.hpp"
#include "convolution_forward_specialization_cpu.hpp"
#include "common_header.hpp"
#include "../../gpu/device/tensor_layout.hpp"
#include "tensor_descriptor.hpp"
#include "tensor_descriptor_helper.hpp"
#include "gridwise_gemm_bias_activation_add_avx2.hpp"
#include "threadwise_gemm_avx2.hpp"
#include "threadwise_tensor_slice_transfer_avx2_specialization.hpp"
namespace ck {
namespace tensor_operation {
namespace cpu {
namespace device {
// out[N, Ho, Wo, K] = in[N, Hi, Wi, C] * wei[K, Y, X, C]
template <typename InDataType,
typename WeiDataType,
typename OutDataType,
typename BiasDataType,
typename AddDataType,
typename InElementwiseOperation,
typename WeiElementwiseOperation,
typename OutElementwiseOperation,
ConvolutionForwardSpecialization_t ConvForwardSpecialization,
ConvolutionForwardGemmKSpecialization_t GemmKSpecialization,
ConvolutionForwardBlockLoopOverSpecialization_t BlockLoopOverSpecialization,
ck::index_t NumDimSpatial,
ck::index_t MPerBlock, // block means data are designed to fit in cache (L1/L2/L3)
ck::index_t NPerBlock,
ck::index_t KPerBlock,
ck::index_t MPerThread,
ck::index_t NPerThread,
bool UseALocalBuffer,
bool UseBLocalBuffer,
bool UseCLocalBuffer,
bool BiasAlongGemmM>
struct DeviceConvNDFwdBiasActivationAddAvx2_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K
: public DeviceConvFwdBiasActivationAdd<InElementwiseOperation,
WeiElementwiseOperation,
OutElementwiseOperation>
{
using DeviceOp =
DeviceConvNDFwdBiasActivationAddAvx2_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K;
using ADataType = InDataType;
using BDataType = WeiDataType;
using CDataType = OutDataType;
using C0DataType = BiasDataType;
using C1DataType = AddDataType;
using AElementwiseOperation = InElementwiseOperation;
using BElementwiseOperation = WeiElementwiseOperation;
using CElementwiseOperation = OutElementwiseOperation;
// TODO make A/B datatype different
using ABDataType = InDataType;
static constexpr index_t NDimSpatial = NumDimSpatial;
static constexpr auto I0 = Number<0>{};
static constexpr auto I1 = Number<1>{};
static constexpr auto I2 = Number<2>{};
static constexpr auto I3 = Number<3>{};
static constexpr bool NonTemporalStore = false;
static constexpr auto GetBlockMNKAccessOrder()
{
if constexpr(BlockLoopOverSpecialization == DefaultBlockLoopOver ||
BlockLoopOverSpecialization == LoopOver_MNK)
return ck::Sequence<0, 1, 2>{};
else if constexpr(BlockLoopOverSpecialization == LoopOver_MKN)
return ck::Sequence<0, 2, 1>{};
}
using BlockMNKAccessOrder = decltype(GetBlockMNKAccessOrder());
static constexpr auto GetThreadwiseGemm_Dispatch()
{
if constexpr(MPerThread == 4 && NPerThread == 24)
{
return ck::cpu::ThreadwiseGemmAvx2_MxN_4x24_Dispatch<
InDataType,
WeiDataType,
OutDataType,
ck::tensor_layout::gemm::RowMajor,
ck::tensor_layout::gemm::ColumnMajor,
NonTemporalStore>{};
}
else if constexpr(MPerThread == 6 && NPerThread == 16)
{
return ck::cpu::ThreadwiseGemmAvx2_MxN_6x16_Dispatch<
InDataType,
WeiDataType,
OutDataType,
ck::tensor_layout::gemm::RowMajor,
ck::tensor_layout::gemm::ColumnMajor,
NonTemporalStore>{};
}
else
{
// static_assert(false, "invalid Mr/Nr");
}
}
using ThreadwiseGemm_Dispatch = decltype(GetThreadwiseGemm_Dispatch());
static constexpr auto GetInputBlockDescriptor()
{
return make_naive_tensor_descriptor_packed(make_tuple(MPerBlock, KPerBlock));
}
static constexpr auto GetWeightBlockDescriptor()
{
return make_naive_tensor_descriptor_packed(make_tuple(
math::integer_divide_ceil(NPerBlock, ThreadwiseGemm_Dispatch::MatrixBMinVectorSize),
KPerBlock,
ThreadwiseGemm_Dispatch::MatrixBMinVectorSize));
}
static constexpr auto GetOutputBlockDescriptor()
{
return make_naive_tensor_descriptor_packed(make_tuple(MPerBlock, NPerBlock));
}
static auto GetWeightTensorDescriptor(ck::index_t gemm_k, ck::index_t gemm_n)
{
ck::index_t gemm_n_padded =
math::integer_least_multiple(gemm_n, ThreadwiseGemm_Dispatch::MatrixBMinVectorSize);
const auto wei_gemm_n_k_grid_desc =
make_naive_tensor_descriptor_packed(make_tuple(gemm_n, gemm_k));
const auto wei_gemm_padn_k_grid_desc = transform_tensor_descriptor(
wei_gemm_n_k_grid_desc,
make_tuple(make_right_pad_transform(gemm_n, gemm_n_padded - gemm_n),
make_pass_through_transform(gemm_k)),
ck::make_tuple(ck::Sequence<0>{}, ck::Sequence<1>{}),
ck::make_tuple(ck::Sequence<0>{}, ck::Sequence<1>{}));
const auto wei_gemm_n0_k_n1_grid_desc = transform_tensor_descriptor(
wei_gemm_padn_k_grid_desc,
ck::make_tuple(
ck::make_unmerge_transform(
ck::make_tuple(wei_gemm_padn_k_grid_desc.GetLength(I0) /
ThreadwiseGemm_Dispatch::MatrixBMinVectorSize,
ThreadwiseGemm_Dispatch::MatrixBMinVectorSize)),
ck::make_pass_through_transform(wei_gemm_padn_k_grid_desc.GetLength(I1))),
ck::make_tuple(ck::Sequence<0>{}, ck::Sequence<1>{}),
ck::make_tuple(ck::Sequence<0, 2>{}, ck::Sequence<1>{}));
return wei_gemm_n0_k_n1_grid_desc;
}
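// Example of the GemmN padding above (illustrative numbers): with gemm_n = K = 20 and
// MatrixBMinVectorSize = 8, gemm_n_padded = 24, so the weight is viewed as
// [N0 = 3, gemm_k, N1 = 8]; the 4 extra positions come from the right-pad transform and do
// not correspond to real output channels.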
static auto GetOutputTensorDescriptor(ck::index_t gemm_m, ck::index_t gemm_n)
{
const auto out_gemm_m_n_grid_desc =
make_naive_tensor_descriptor_packed(make_tuple(gemm_m, gemm_n));
return out_gemm_m_n_grid_desc;
}
static auto MakeBiasTensorDescriptor(ck::index_t gemm_m, ck::index_t gemm_n)
{
if constexpr(BiasAlongGemmM)
{
return make_naive_tensor_descriptor_packed(make_tuple(gemm_m));
}
else
{
return make_naive_tensor_descriptor_packed(make_tuple(gemm_n));
}
}
template <ck::index_t NDim, typename std::enable_if<NDim == 1, bool>::type = false>
static auto GetInputTensorDescriptor(ck::index_t N,
ck::index_t C,
ck::index_t gemm_m,
ck::index_t gemm_k,
const std::vector<ck::index_t>& input_spatial_lengths,
const std::vector<ck::index_t>& filter_spatial_lengths,
const std::vector<ck::index_t>& output_spatial_lengths,
const std::vector<ck::index_t>& conv_filter_strides,
const std::vector<ck::index_t>& conv_filter_dilations,
const std::vector<ck::index_t>& input_left_pads,
const std::vector<ck::index_t>& input_right_pads)
{
const index_t Wi = input_spatial_lengths[0];
const index_t Wo = output_spatial_lengths[0];
const index_t ConvStrideW = conv_filter_strides[0];
if constexpr(ConvForwardSpecialization ==
ConvolutionForwardSpecialization_t::Filter1x1Stride1Pad0)
{
const auto in_gemm_m_k_grid_desc =
make_naive_tensor_descriptor_packed(make_tuple(gemm_m, gemm_k));
return in_gemm_m_k_grid_desc;
}
else if constexpr(ConvForwardSpecialization ==
ConvolutionForwardSpecialization_t::Filter1x1Pad0)
{
const auto in_n_wi_c_grid_desc =
make_naive_tensor_descriptor_packed(make_tuple(N, Wi, C));
const auto in_n_wo_c_grid_desc = transform_tensor_descriptor(
in_n_wi_c_grid_desc,
make_tuple(make_pass_through_transform(N),
make_embed_transform(make_tuple(Wo), make_tuple(ConvStrideW)),
make_pass_through_transform(C)),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}));
const auto in_gemm_m_k_grid_desc = transform_tensor_descriptor(
in_n_wo_c_grid_desc,
make_tuple(make_merge_transform(make_tuple(N, Wo)), make_pass_through_transform(C)),
make_tuple(Sequence<0, 1>{}, Sequence<2>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}));
return in_gemm_m_k_grid_desc;
}
else
{
const index_t X = filter_spatial_lengths[0];
const index_t ConvDilationW = conv_filter_dilations[0];
const index_t InLeftPadW = input_left_pads[0];
const index_t InRightPadW = input_right_pads[0];
const auto in_n_wi_c_grid_desc =
make_naive_tensor_descriptor_packed(make_tuple(N, Wi, C));
const auto in_n_wip_c_grid_desc = transform_tensor_descriptor(
in_n_wi_c_grid_desc,
make_tuple(make_pass_through_transform(N),
make_pad_transform(Wi, InLeftPadW, InRightPadW),
make_pass_through_transform(C)),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}));
const auto in_n_x_wo_c_grid_desc = transform_tensor_descriptor(
in_n_wip_c_grid_desc,
make_tuple(
make_pass_through_transform(N),
make_embed_transform(make_tuple(X, Wo), make_tuple(ConvDilationW, ConvStrideW)),
make_pass_through_transform(C)),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}),
make_tuple(Sequence<0>{}, Sequence<1, 2>{}, Sequence<3>{}));
const auto in_gemm_m_k_grid_desc =
transform_tensor_descriptor(in_n_x_wo_c_grid_desc,
make_tuple(make_merge_transform(make_tuple(N, Wo)),
make_merge_transform(make_tuple(X, C))),
make_tuple(Sequence<0, 2>{}, Sequence<1, 3>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}));
return in_gemm_m_k_grid_desc;
}
}
template <ck::index_t NDim, typename std::enable_if<NDim == 2, bool>::type = false>
static auto GetInputTensorDescriptor(ck::index_t N,
ck::index_t C,
ck::index_t gemm_m,
ck::index_t gemm_k,
const std::vector<ck::index_t>& input_spatial_lengths,
const std::vector<ck::index_t>& filter_spatial_lengths,
const std::vector<ck::index_t>& output_spatial_lengths,
const std::vector<ck::index_t>& conv_filter_strides,
const std::vector<ck::index_t>& conv_filter_dilations,
const std::vector<ck::index_t>& input_left_pads,
const std::vector<ck::index_t>& input_right_pads)
{
const index_t Hi = input_spatial_lengths[0];
const index_t Wi = input_spatial_lengths[1];
const index_t Ho = output_spatial_lengths[0];
const index_t Wo = output_spatial_lengths[1];
const index_t ConvStrideH = conv_filter_strides[0];
const index_t ConvStrideW = conv_filter_strides[1];
if constexpr(ConvForwardSpecialization ==
ConvolutionForwardSpecialization_t::Filter1x1Stride1Pad0)
{
const auto in_gemm_m_k_grid_desc =
make_naive_tensor_descriptor_packed(make_tuple(gemm_m, gemm_k));
return in_gemm_m_k_grid_desc;
}
else if constexpr(ConvForwardSpecialization ==
ConvolutionForwardSpecialization_t::Filter1x1Pad0)
{
const auto in_n_hi_wi_c_grid_desc =
make_naive_tensor_descriptor_packed(make_tuple(N, Hi, Wi, C));
const auto in_n_ho_wo_c_grid_desc = transform_tensor_descriptor(
in_n_hi_wi_c_grid_desc,
make_tuple(make_pass_through_transform(N),
make_embed_transform(make_tuple(Ho), make_tuple(ConvStrideH)),
make_embed_transform(make_tuple(Wo), make_tuple(ConvStrideW)),
make_pass_through_transform(C)),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}));
const auto in_gemm_m_k_grid_desc =
transform_tensor_descriptor(in_n_ho_wo_c_grid_desc,
make_tuple(make_merge_transform(make_tuple(N, Ho, Wo)),
make_pass_through_transform(C)),
make_tuple(Sequence<0, 1, 2>{}, Sequence<3>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}));
return in_gemm_m_k_grid_desc;
}
else
{
const index_t Y = filter_spatial_lengths[0];
const index_t X = filter_spatial_lengths[1];
const index_t ConvDilationH = conv_filter_dilations[0];
const index_t ConvDilationW = conv_filter_dilations[1];
const index_t InLeftPadH = input_left_pads[0];
const index_t InLeftPadW = input_left_pads[1];
const index_t InRightPadH = input_right_pads[0];
const index_t InRightPadW = input_right_pads[1];
const auto in_n_hi_wi_c_grid_desc =
make_naive_tensor_descriptor_packed(make_tuple(N, Hi, Wi, C));
const auto in_n_hip_wip_c_grid_desc = transform_tensor_descriptor(
in_n_hi_wi_c_grid_desc,
make_tuple(make_pass_through_transform(N),
make_pad_transform(Hi, InLeftPadH, InRightPadH),
make_pad_transform(Wi, InLeftPadW, InRightPadW),
make_pass_through_transform(C)),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}));
const auto in_n_y_ho_x_wo_c_grid_desc = transform_tensor_descriptor(
in_n_hip_wip_c_grid_desc,
make_tuple(
make_pass_through_transform(N),
make_embed_transform(make_tuple(Y, Ho), make_tuple(ConvDilationH, ConvStrideH)),
make_embed_transform(make_tuple(X, Wo), make_tuple(ConvDilationW, ConvStrideW)),
make_pass_through_transform(C)),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}),
make_tuple(Sequence<0>{}, Sequence<1, 2>{}, Sequence<3, 4>{}, Sequence<5>{}));
const auto in_gemm_m_k_grid_desc =
transform_tensor_descriptor(in_n_y_ho_x_wo_c_grid_desc,
make_tuple(make_merge_transform(make_tuple(N, Ho, Wo)),
make_merge_transform(make_tuple(Y, X, C))),
make_tuple(Sequence<0, 2, 4>{}, Sequence<1, 3, 5>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}));
return in_gemm_m_k_grid_desc;
}
}
template <ck::index_t NDim, typename std::enable_if<NDim == 3, bool>::type = false>
static auto GetInputTensorDescriptor(ck::index_t N,
ck::index_t C,
ck::index_t gemm_m,
ck::index_t gemm_k,
ck::index_t gemm_m_pad,
const std::vector<ck::index_t>& input_spatial_lengths,
const std::vector<ck::index_t>& filter_spatial_lengths,
const std::vector<ck::index_t>& output_spatial_lengths,
const std::vector<ck::index_t>& conv_filter_strides,
const std::vector<ck::index_t>& conv_filter_dilations,
const std::vector<ck::index_t>& input_left_pads,
const std::vector<ck::index_t>& input_right_pads)
{
const index_t Di = input_spatial_lengths[0];
const index_t Hi = input_spatial_lengths[1];
const index_t Wi = input_spatial_lengths[2];
const index_t Do = output_spatial_lengths[0];
const index_t Ho = output_spatial_lengths[1];
const index_t Wo = output_spatial_lengths[2];
const index_t ConvStrideD = conv_filter_strides[0];
const index_t ConvStrideH = conv_filter_strides[1];
const index_t ConvStrideW = conv_filter_strides[2];
if constexpr(ConvForwardSpecialization ==
ConvolutionForwardSpecialization_t::Filter1x1Stride1Pad0)
{
const auto in_gemm_m_k_grid_desc =
make_naive_tensor_descriptor_packed(make_tuple(gemm_m, gemm_k));
return in_gemm_m_k_grid_desc;
}
else if constexpr(ConvForwardSpecialization ==
ConvolutionForwardSpecialization_t::Filter1x1Pad0)
{
const auto in_n_di_hi_wi_c_grid_desc =
make_naive_tensor_descriptor_packed(make_tuple(N, Di, Hi, Wi, C));
const auto in_n_do_ho_wo_c_grid_desc = transform_tensor_descriptor(
in_n_di_hi_wi_c_grid_desc,
make_tuple(make_pass_through_transform(N),
make_embed_transform(make_tuple(Do), make_tuple(ConvStrideD)),
make_embed_transform(make_tuple(Ho), make_tuple(ConvStrideH)),
make_embed_transform(make_tuple(Wo), make_tuple(ConvStrideW)),
make_pass_through_transform(C)),
make_tuple(
Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}, Sequence<4>{}),
make_tuple(
Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}, Sequence<4>{}));
const auto in_gemm_m_k_grid_desc = transform_tensor_descriptor(
in_n_do_ho_wo_c_grid_desc,
make_tuple(make_merge_transform(make_tuple(N, Do, Ho, Wo)),
make_pass_through_transform(C)),
make_tuple(Sequence<0, 1, 2, 3>{}, Sequence<4>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}));
return in_gemm_m_k_grid_desc;
}
else
{
const index_t Z = filter_spatial_lengths[0];
const index_t Y = filter_spatial_lengths[1];
const index_t X = filter_spatial_lengths[2];
const index_t ConvDilationD = conv_filter_dilations[0];
const index_t ConvDilationH = conv_filter_dilations[1];
const index_t ConvDilationW = conv_filter_dilations[2];
const index_t InLeftPadD = input_left_pads[0];
const index_t InLeftPadH = input_left_pads[1];
const index_t InLeftPadW = input_left_pads[2];
const index_t InRightPadD = input_right_pads[0];
const index_t InRightPadH = input_right_pads[1];
const index_t InRightPadW = input_right_pads[2];
const auto in_n_di_hi_wi_c_grid_desc =
make_naive_tensor_descriptor_packed(make_tuple(N, Di, Hi, Wi, C));
const auto in_n_hip_wip_c_grid_desc = transform_tensor_descriptor(
in_n_di_hi_wi_c_grid_desc,
make_tuple(make_pass_through_transform(N),
make_pad_transform(Di, InLeftPadD, InRightPadD),
make_pad_transform(Hi, InLeftPadH, InRightPadH),
make_pad_transform(Wi, InLeftPadW, InRightPadW),
make_pass_through_transform(C)),
make_tuple(
Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}, Sequence<4>{}),
make_tuple(
Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}, Sequence<4>{}));
const auto in_n_z_do_y_ho_x_wo_c_grid_desc = transform_tensor_descriptor(
in_n_hip_wip_c_grid_desc,
make_tuple(
make_pass_through_transform(N),
make_embed_transform(make_tuple(Z, Do), make_tuple(ConvDilationD, ConvStrideD)),
make_embed_transform(make_tuple(Y, Ho), make_tuple(ConvDilationH, ConvStrideH)),
make_embed_transform(make_tuple(X, Wo), make_tuple(ConvDilationW, ConvStrideW)),
make_pass_through_transform(C)),
make_tuple(
Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}, Sequence<4>{}),
make_tuple(Sequence<0>{},
Sequence<1, 2>{},
Sequence<3, 4>{},
Sequence<5, 6>{},
Sequence<7>{}));
const auto in_gemm_m_k_grid_desc = transform_tensor_descriptor(
in_n_z_do_y_ho_x_wo_c_grid_desc,
make_tuple(make_merge_transform(make_tuple(N, Do, Ho, Wo)),
make_merge_transform(make_tuple(Z, Y, X, C))),
make_tuple(Sequence<0, 2, 4, 6>{}, Sequence<1, 3, 5, 7>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}));
return in_gemm_m_k_grid_desc;
}
}
static index_t GetGemmM(ck::index_t N, const std::vector<ck::index_t>& output_spatial_lengths)
{
return N * std::accumulate(std::begin(output_spatial_lengths),
std::end(output_spatial_lengths),
1,
std::multiplies<ck::index_t>());
}
static index_t GetGemmK(ck::index_t C, const std::vector<ck::index_t>& filter_spatial_lengths)
{
return C * std::accumulate(std::begin(filter_spatial_lengths),
std::end(filter_spatial_lengths),
1,
std::multiplies<ck::index_t>());
}
static index_t GetGemmN(ck::index_t K)
{
// return ck::math::integer_least_multiple(K,
// ThreadwiseGemm_Dispatch::MatrixBMinVectorSize);
return K;
}
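// For the example's default 2D problem (N=2, C=192, K=256, Y=X=3, Ho=Wo=71) the implicit
// GEMM sizes become GemmM = 2*71*71 = 10082, GemmN = 256 and GemmK = 192*3*3 = 1728.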
static auto MakeABCGridDescriptor(ck::index_t N,
ck::index_t K,
ck::index_t C,
std::vector<ck::index_t> input_spatial_lengths,
std::vector<ck::index_t> filter_spatial_lengths,
std::vector<ck::index_t> output_spatial_lengths,
std::vector<ck::index_t> conv_filter_strides,
std::vector<ck::index_t> conv_filter_dilations,
std::vector<ck::index_t> input_left_pads,
std::vector<ck::index_t> input_right_pads)
{
using namespace ck;
const index_t GemmM = GetGemmM(N, output_spatial_lengths);
const index_t GemmN = GetGemmN(K);
const index_t GemmK = GetGemmK(C, filter_spatial_lengths);
// A:
const auto in_gemm_m_k_grid_desc =
GetInputTensorDescriptor<NumDimSpatial>(N,
C,
GemmM,
GemmK,
input_spatial_lengths,
filter_spatial_lengths,
output_spatial_lengths,
conv_filter_strides,
conv_filter_dilations,
input_left_pads,
input_right_pads);
// B:
const auto wei_gemm_n0_k_n1_grid_desc = GetWeightTensorDescriptor(GemmK, GemmN);
// C:
const auto out_gemm_m_n_grid_desc = GetOutputTensorDescriptor(GemmM, GemmN);
return make_tuple(
in_gemm_m_k_grid_desc, wei_gemm_n0_k_n1_grid_desc, out_gemm_m_n_grid_desc);
}
template <ck::index_t NDim, typename std::enable_if<NDim == 1, bool>::type = false>
static auto GetABCGridDesc()
{
return MakeABCGridDescriptor(1, 1, 1, {1}, {1}, {1}, {1}, {1}, {1}, {1});
}
template <ck::index_t NDim, typename std::enable_if<NDim == 2, bool>::type = false>
static auto GetABCGridDesc()
{
return MakeABCGridDescriptor(
1, 1, 1, {1, 1}, {1, 1}, {1, 1}, {1, 1}, {1, 1}, {1, 1}, {1, 1});
}
template <ck::index_t NDim, typename std::enable_if<NDim == 3, bool>::type = false>
static auto GetABCGridDesc()
{
return MakeABCGridDescriptor(
1, 1, 1, {1, 1, 1}, {1, 1, 1}, {1, 1, 1}, {1, 1, 1}, {1, 1, 1}, {1, 1, 1}, {1, 1, 1});
}
using ABCGridDescs = decltype(GetABCGridDesc<NumDimSpatial>());
using AGridDesc = remove_cvref_t<decltype(ABCGridDescs{}[I0])>;
using BGridDesc = remove_cvref_t<decltype(ABCGridDescs{}[I1])>;
using CGridDesc = remove_cvref_t<decltype(ABCGridDescs{}[I2])>;
using C0GridDesc = remove_cvref_t<decltype(MakeBiasTensorDescriptor(1, 1))>;
using C1GridDesc = CGridDesc;
// static constexpr bool UseCLocalBuffer = false;
using AThreadwiseCopy =
ck::cpu::ThreadwiseTensorSliceTransferAvx2Specialization_ConvFwd_In_NHWC<
ADataType,
ADataType,
AGridDesc,
decltype(GetInputBlockDescriptor()),
InElementwiseOperation,
false,
ConvForwardSpecialization,
GemmKSpecialization>;
using BThreadwiseCopy =
ck::cpu::ThreadwiseTensorSliceTransferAvx2Specialization_ConvFwd_Wei_KYXC<
BDataType,
BDataType,
BGridDesc,
decltype(GetWeightBlockDescriptor()),
WeiElementwiseOperation,
false,
ConvForwardSpecialization,
GemmKSpecialization>;
using CThreadwiseCopy =
ck::cpu::ThreadwiseTensorSliceTransferAvx2Specialization_MatC_Store_Bias_Residual_MxN<
CDataType,
C0DataType,
C1DataType,
CDataType,
CGridDesc,
C0GridDesc,
C1GridDesc,
decltype(GetOutputBlockDescriptor()),
OutElementwiseOperation,
!UseCLocalBuffer,
BiasAlongGemmM>;
using GridwiseGemm = ck::cpu::GridwiseGemmBiasActivationAddAvx2_MxN<
ADataType, // InDataType,
BDataType, // WeiDataType,
CDataType, // OutDataType,
C0DataType, // C0DataType
C1DataType, // C1DataType
AGridDesc, // AGridDesc,
BGridDesc, // BGridDesc,
CGridDesc, // CGridDesc,
C0GridDesc, // C0GridDesc,
C1GridDesc, // C1GridDesc,
AElementwiseOperation, // AElementwiseOperation,
BElementwiseOperation, // BElementwiseOperation,
CElementwiseOperation, // CElementwiseOperation,
MPerBlock, // MPerBlock,
NPerBlock, // NPerBlock,
KPerBlock, // KPerBlock,
ThreadwiseGemm_Dispatch, // ThreadwiseGemm_Dispatch,
AThreadwiseCopy, // AThreadwiseCopy
BThreadwiseCopy, // BThreadwiseCopy
CThreadwiseCopy, // CThreadwiseCopy
BlockMNKAccessOrder, // BlockMNKAccessOrder,
ck::Sequence<0, 1>, // ThreadMNAccessOrder
UseALocalBuffer, // UseALocalBuffer
UseBLocalBuffer, // UseBLocalBuffer
UseCLocalBuffer // UseCLocalBuffer
>;
// Argument
struct Argument : public BaseArgument
{
Argument(const InDataType* p_in_grid,
const WeiDataType* p_wei_grid,
OutDataType* p_out_grid,
const BiasDataType* p_bias_grid,
const AddDataType* p_add_grid,
ck::index_t N,
ck::index_t K,
ck::index_t C,
std::vector<ck::index_t> input_spatial_lengths,
std::vector<ck::index_t> filter_spatial_lengths,
std::vector<ck::index_t> output_spatial_lengths,
std::vector<ck::index_t> conv_filter_strides,
std::vector<ck::index_t> conv_filter_dilations,
std::vector<ck::index_t> input_left_pads,
std::vector<ck::index_t> input_right_pads,
InElementwiseOperation in_element_op,
WeiElementwiseOperation wei_element_op,
OutElementwiseOperation out_element_op)
: p_a_grid_{p_in_grid},
p_b_grid_{p_wei_grid},
p_c_grid_{p_out_grid},
p_c0_grid_{p_bias_grid},
p_c1_grid_{p_add_grid},
a_grid_desc_{},
b_grid_desc_{},
c_grid_desc_{},
c0_grid_desc_{},
c1_grid_desc_{},
a_element_op_{in_element_op},
b_element_op_{wei_element_op},
c_element_op_{out_element_op},
Conv_N_{N},
Conv_K_{K},
Conv_C_{C},
filter_spatial_lengths_{filter_spatial_lengths},
conv_filter_strides_{conv_filter_strides},
input_left_pads_{input_left_pads},
input_right_pads_{input_right_pads}
{
const auto descs = DeviceOp::MakeABCGridDescriptor(N,
K,
C,
input_spatial_lengths,
filter_spatial_lengths,
output_spatial_lengths,
conv_filter_strides,
conv_filter_dilations,
input_left_pads,
input_right_pads);
a_grid_desc_ = descs[I0];
b_grid_desc_ = descs[I1];
c_grid_desc_ = descs[I2];
c0_grid_desc_ = DeviceOp::MakeBiasTensorDescriptor(GetGemmM(N, output_spatial_lengths),
GetGemmN(K));
c1_grid_desc_ = descs[I2];
}
// private:
const ADataType* p_a_grid_;
const BDataType* p_b_grid_;
CDataType* p_c_grid_;
const C0DataType* p_c0_grid_;
const C1DataType* p_c1_grid_;
AGridDesc a_grid_desc_;
BGridDesc b_grid_desc_;
CGridDesc c_grid_desc_;
C0GridDesc c0_grid_desc_;
C1GridDesc c1_grid_desc_;
AElementwiseOperation a_element_op_;
BElementwiseOperation b_element_op_;
CElementwiseOperation c_element_op_;
// for checking IsSupportedArgument()
index_t Conv_N_;
index_t Conv_K_;
index_t Conv_C_;
std::vector<index_t> filter_spatial_lengths_;
std::vector<index_t> conv_filter_strides_;
std::vector<index_t> input_left_pads_;
std::vector<index_t> input_right_pads_;
};
// Invoker
struct Invoker : public BaseInvoker
{
using Argument = DeviceOp::Argument;
float Run(const Argument& arg,
const StreamConfig& stream_config = StreamConfig{},
int nrepeat = 1)
{
if(!GridwiseGemm::CheckValidity(arg.a_grid_desc_, arg.b_grid_desc_, arg.c_grid_desc_))
{
throw std::runtime_error("wrong! GridwiseGemmAvx2_MxN has invalid setting");
}
memset(arg.p_c_grid_, 0, arg.c_grid_desc_.GetElementSpaceSize() * sizeof(CDataType));
const auto kernel =
ck::cpu::kernel_gemm_bias_activation_add_avx_mxn<GridwiseGemm,
ADataType,
BDataType,
CDataType,
C0DataType,
C1DataType,
AGridDesc,
BGridDesc,
CGridDesc,
C0GridDesc,
C1GridDesc,
AElementwiseOperation,
BElementwiseOperation,
CElementwiseOperation>;
float ave_time = 0;
if(nrepeat != 1)
ave_time = launch_and_time_cpu_kernel(kernel,
nrepeat,
arg.p_a_grid_,
arg.p_b_grid_,
arg.p_c_grid_,
arg.p_c0_grid_,
arg.p_c1_grid_,
arg.a_grid_desc_,
arg.b_grid_desc_,
arg.c_grid_desc_,
arg.c0_grid_desc_,
arg.c1_grid_desc_,
arg.a_element_op_,
arg.b_element_op_,
arg.c_element_op_);
// TODO: this is for benchmark purposes; after timing, clear the C buffer once more and
// compute the final result
memset(arg.p_c_grid_, 0, arg.c_grid_desc_.GetElementSpaceSize() * sizeof(CDataType));
launch_cpu_kernel(kernel,
arg.p_a_grid_,
arg.p_b_grid_,
arg.p_c_grid_,
arg.p_c0_grid_,
arg.p_c1_grid_,
arg.a_grid_desc_,
arg.b_grid_desc_,
arg.c_grid_desc_,
arg.c0_grid_desc_,
arg.c1_grid_desc_,
arg.a_element_op_,
arg.b_element_op_,
arg.c_element_op_);
return ave_time;
}
float Run(const BaseArgument* p_arg,
const StreamConfig& stream_config = StreamConfig{},
int nrepeat = 1) override
{
return Run(*dynamic_cast<const Argument*>(p_arg), stream_config, nrepeat);
}
};
static constexpr bool IsValidCompilationParameter()
{
// TODO: properly implement this check
return true;
}
static bool IsSupportedArgument(const Argument& arg)
{
if constexpr(ConvForwardSpecialization ==
ConvolutionForwardSpecialization_t::Filter1x1Stride1Pad0)
{
// check if it's 1x1, stride=1 conv
if(!(arg.filter_spatial_lengths_[0] == 1 && arg.filter_spatial_lengths_[1] == 1 &&
arg.conv_filter_strides_[0] == 1 && arg.conv_filter_strides_[1] == 1 &&
arg.input_left_pads_[0] == 0 && arg.input_left_pads_[1] == 0 &&
arg.input_right_pads_[0] == 0 && arg.input_right_pads_[1] == 0))
{
return false;
}
}
else if constexpr(ConvForwardSpecialization ==
ConvolutionForwardSpecialization_t::Filter1x1Pad0)
{
// check if it's 1x1 conv
if(!(arg.filter_spatial_lengths_[0] == 1 && arg.filter_spatial_lengths_[1] == 1 &&
arg.input_left_pads_[0] == 0 && arg.input_left_pads_[1] == 0 &&
arg.input_right_pads_[0] == 0 && arg.input_right_pads_[1] == 0))
{
return false;
}
}
if constexpr(GemmKSpecialization ==
ConvolutionForwardGemmKSpecialization_t::NHWC_GemmKLoopOverC)
{
if(!(arg.Conv_C_ % KPerBlock == 0))
return false;
}
// Gridwise GEMM size
return GridwiseGemm::CheckValidity(arg.a_grid_desc_, arg.b_grid_desc_, arg.c_grid_desc_);
}
bool IsSupportedArgument(const BaseArgument* p_arg) override
{
return IsSupportedArgument(*dynamic_cast<const Argument*>(p_arg));
}
static auto MakeArgument(const InDataType* p_in_grid,
const WeiDataType* p_wei_grid,
OutDataType* p_out_grid,
const BiasDataType* p_bias_grid,
const AddDataType* p_add_grid,
ck::index_t N,
ck::index_t K,
ck::index_t C,
std::vector<ck::index_t> input_spatial_lengths,
std::vector<ck::index_t> filter_spatial_lengths,
std::vector<ck::index_t> output_spatial_lengths,
std::vector<ck::index_t> conv_filter_strides,
std::vector<ck::index_t> conv_filter_dilations,
std::vector<ck::index_t> input_left_pads,
std::vector<ck::index_t> input_right_pads,
InElementwiseOperation in_element_op,
WeiElementwiseOperation wei_element_op,
OutElementwiseOperation out_element_op)
{
return Argument{p_in_grid,
p_wei_grid,
p_out_grid,
p_bias_grid,
p_add_grid,
N,
K,
C,
input_spatial_lengths,
filter_spatial_lengths,
output_spatial_lengths,
conv_filter_strides,
conv_filter_dilations,
input_left_pads,
input_right_pads,
in_element_op,
wei_element_op,
out_element_op};
}
static auto MakeInvoker() { return Invoker{}; }
std::unique_ptr<BaseArgument>
MakeArgumentPointer(const void* p_in_grid,
const void* p_wei_grid,
void* p_out_grid,
const void* p_bias_grid,
const void* p_add_grid,
ck::index_t N,
ck::index_t K,
ck::index_t C,
std::vector<ck::index_t> input_spatial_lengths,
std::vector<ck::index_t> filter_spatial_lengths,
std::vector<ck::index_t> output_spatial_lengths,
std::vector<ck::index_t> conv_filter_strides,
std::vector<ck::index_t> conv_filter_dilations,
std::vector<ck::index_t> input_left_pads,
std::vector<ck::index_t> input_right_pads,
InElementwiseOperation in_element_op,
WeiElementwiseOperation wei_element_op,
OutElementwiseOperation out_element_op) override
{
return std::make_unique<Argument>(static_cast<const InDataType*>(p_in_grid),
static_cast<const WeiDataType*>(p_wei_grid),
static_cast<OutDataType*>(p_out_grid),
static_cast<const BiasDataType*>(p_bias_grid),
static_cast<const AddDataType*>(p_add_grid),
N,
K,
C,
input_spatial_lengths,
filter_spatial_lengths,
output_spatial_lengths,
conv_filter_strides,
conv_filter_dilations,
input_left_pads,
input_right_pads,
in_element_op,
wei_element_op,
out_element_op);
}
std::unique_ptr<BaseInvoker> MakeInvokerPointer() override
{
return std::make_unique<Invoker>(Invoker{});
}
std::string GetTypeString() const override
{
auto str = std::stringstream();
auto string_local_buffer = [](bool is_local_buffer) {
if(is_local_buffer)
return "L";
else
return "G";
};
// clang-format off
str << "DeviceConv" << std::to_string(NumDimSpatial)
<< "DFwd_BAA_Avx2_NHWC_KYXC"
<<"_FS"<< static_cast<int>(ConvForwardSpecialization)
<<"_KS"<< static_cast<int>(GemmKSpecialization)
<<"_BS"<< static_cast<int>(BlockLoopOverSpecialization)
<< "_BT" << MPerBlock << "x" << NPerBlock << "x" << KPerBlock
<< "_TT" << MPerThread << "x" << NPerThread
<< "_A" << string_local_buffer(UseALocalBuffer)
<< "_B" << string_local_buffer(UseBLocalBuffer)
<< "_C" << string_local_buffer(UseCLocalBuffer)
;
if constexpr (!std::is_same<OutElementwiseOperation,
ck::tensor_operation::cpu::element_wise::PassThrough>::value)
{
str << "_" << OutElementwiseOperation::Name();
}
// clang-format on
return str.str();
}
};
} // namespace device
} // namespace cpu
} // namespace tensor_operation
} // namespace ck
#endif
@@ -148,6 +148,27 @@ struct AddReluAdd
y = _mm256_add_ps(b, x2);
}
float Apply(const float& x0, const float& x1, const float& x2) const
{
float a = x0 + x1;
float b = a > 0 ? a : 0;
return b + x2;
}
float4_t Apply(const float4_t& x0, const float4_t& x1, const float4_t& x2) const
{
float4_t a = _mm_add_ps(x0, x1);
float4_t b = _mm_max_ps(a, _mm_setzero_ps());
return _mm_add_ps(b, x2);
}
float8_t Apply(const float8_t& x0, const float8_t& x1, const float8_t& x2) const
{
float8_t a = _mm256_add_ps(x0, x1);
float8_t b = _mm256_max_ps(a, _mm256_setzero_ps());
return _mm256_add_ps(b, x2);
}
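// The three overloads above compute y = max(x0 + x1, 0) + x2 for scalar, 128-bit and 256-bit
// lanes, e.g. Apply(-3.0f, 1.0f, 5.0f) -> max(-2, 0) + 5 = 5; in this example x0 is the conv
// result, x1 the bias and x2 the residual tensor (as wired through the C0/C1 operands).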
static constexpr char* Name() { return "AddReluAdd"; }
};
...
#ifndef CK_GRIDWISE_GEMM_BIAS_ACTIVATION_ADD_AVX2_HPP
#define CK_GRIDWISE_GEMM_BIAS_ACTIVATION_ADD_AVX2_HPP
#include "common_header.hpp"
#include "multi_index_transform_helper.hpp"
#include "tensor_descriptor.hpp"
#include "tensor_descriptor_helper.hpp"
#include "blockwise_gemm_avx2.hpp"
#include "threadwise_tensor_slice_transfer_avx2.hpp"
#include "threadwise_tensor_slice_transfer_avx2_specialization.hpp"
#include "dynamic_buffer_cpu.hpp"
#include <utility>
#include <unistd.h>
#include <omp.h>
#include <pthread.h>
namespace ck {
namespace cpu {
template <typename GridwiseGemm,
typename FloatA,
typename FloatB,
typename FloatC,
typename FloatC0,
typename FloatC1,
typename AGridDesc,
typename BGridDesc,
typename CGridDesc,
typename C0GridDesc,
typename C1GridDesc,
typename AElementwiseOperation,
typename BElementwiseOperation,
typename CElementwiseOperation>
void kernel_gemm_bias_activation_add_avx_mxn(const FloatA* __restrict__ p_a_grid,
const FloatB* __restrict__ p_b_grid,
FloatC* __restrict__ p_c_grid,
const FloatC0* __restrict__ p_c0_grid,
const FloatC1* __restrict__ p_c1_grid,
const AGridDesc& a_grid_desc,
const BGridDesc& b_grid_desc,
const CGridDesc& c_grid_desc,
const C0GridDesc& c0_grid_desc,
const C1GridDesc& c1_grid_desc,
const AElementwiseOperation& a_element_op,
const BElementwiseOperation& b_element_op,
const CElementwiseOperation& c_element_op)
{
GridwiseGemm::Run(p_a_grid,
p_b_grid,
p_c_grid,
p_c0_grid,
p_c1_grid,
a_grid_desc,
b_grid_desc,
c_grid_desc,
c0_grid_desc,
c1_grid_desc,
a_element_op,
b_element_op,
c_element_op);
}
template <typename FloatA,
typename FloatB,
typename FloatC,
typename FloatC0,
typename FloatC1,
typename AGridDesc,
typename BGridDesc,
typename CGridDesc,
typename C0GridDesc,
typename C1GridDesc,
typename AElementwiseOperation,
typename BElementwiseOperation,
typename CElementwiseOperation,
ck::index_t MPerBlock, // block means data are designed to fit in cache (L1/L2/L3)
ck::index_t NPerBlock,
ck::index_t KPerBlock,
typename ThreadwiseGemm_Dispatch,
typename AThreadwiseCopy,
typename BThreadwiseCopy,
typename CThreadwiseCopy,
typename BlockMNKAccessOrder, // how we access gemm MNK to better fit in cache
typename ThreadMNAccessOrder, // how we access gemm MN to utilize micro kernel
bool UseALocalBuffer,
bool UseBLocalBuffer,
bool UseCLocalBuffer // if true, will allocate a buffer and write to it in kernel, then
// copy back to block buffer (need CThreadwiseCopy).
// if false, will write to C directly (no need CThreadwiseCopy)
>
struct GridwiseGemmBiasActivationAddAvx2_MxN
{
static constexpr auto I0 = Number<0>{};
static constexpr auto I1 = Number<1>{};
// static constexpr auto Avx2RegisterVector = 8; // 8 floats
static constexpr index_t MemAlignmentByte = 32; // 256bit
static auto GetABlockDescriptor(const ck::index_t m_per_blk, const ck::index_t k_per_blk)
{
if constexpr(std::is_same<typename ThreadwiseGemm_Dispatch::MatrixALayout,
ck::tensor_layout::gemm::RowMajor>::value)
{
// A : M, K
auto a_block_desc_m_k =
make_naive_tensor_descriptor_packed(make_tuple(m_per_blk, k_per_blk));
return a_block_desc_m_k;
}
else
{
// A : K, M
auto a_block_desc_k_m = make_naive_tensor_descriptor_packed(
make_tuple(k_per_blk,
math::integer_least_multiple(
m_per_blk, ThreadwiseGemm_Dispatch::MatrixAMinVectorSize)));
return a_block_desc_k_m;
}
}
static auto GetBBlockDescriptor(const ck::index_t k_per_blk, const ck::index_t n_per_blk)
{
// n_per_blk should be 8x
if constexpr(std::is_same<typename ThreadwiseGemm_Dispatch::MatrixBLayout,
ck::tensor_layout::gemm::RowMajor>::value)
{
// B : K, N
auto b_block_desc_k_n =
make_naive_tensor_descriptor_packed(make_tuple(k_per_blk, n_per_blk));
return b_block_desc_k_n;
}
else
{
// B : N/8, K, N8
auto b_block_desc_n0_k_n1 = make_naive_tensor_descriptor_packed(make_tuple(
math::integer_divide_ceil(n_per_blk, ThreadwiseGemm_Dispatch::MatrixBMinVectorSize),
k_per_blk,
ThreadwiseGemm_Dispatch::MatrixBMinVectorSize));
return b_block_desc_n0_k_n1;
}
}
static auto GetCBlockDescriptor(const ck::index_t m_per_blk,
const ck::index_t n_per_blk,
const CGridDesc& c_grid_desc)
{
if constexpr(UseCLocalBuffer)
{
return make_naive_tensor_descriptor_packed(make_tuple(m_per_blk, n_per_blk));
}
else
return c_grid_desc;
}
static auto GetASliceLength(const ck::index_t m_per_blk, const ck::index_t k_per_blk)
{
if constexpr(std::is_same<typename ThreadwiseGemm_Dispatch::MatrixALayout,
ck::tensor_layout::gemm::RowMajor>::value)
{
// A : M, K
return ck::make_multi_index(m_per_blk, k_per_blk);
}
else
{
// A : K, M
return ck::make_multi_index(
k_per_blk,
math::integer_least_multiple(m_per_blk,
ThreadwiseGemm_Dispatch::MatrixAMinVectorSize));
}
}
static auto GetBSliceLength(const ck::index_t k_per_blk, const ck::index_t n_per_blk)
{
// n_per_blk should be 8x
if constexpr(std::is_same<typename ThreadwiseGemm_Dispatch::MatrixBLayout,
ck::tensor_layout::gemm::RowMajor>::value)
{
// B : K, N
return ck::make_multi_index(
k_per_blk,
math::integer_least_multiple(n_per_blk,
ThreadwiseGemm_Dispatch::MatrixBMinVectorSize));
}
else
{
// B : N/8, K, N8
return ck::make_multi_index(
math::integer_divide_ceil(n_per_blk, ThreadwiseGemm_Dispatch::MatrixBMinVectorSize),
k_per_blk,
ThreadwiseGemm_Dispatch::MatrixBMinVectorSize);
}
}
static auto GetCSliceLength(const ck::index_t m_per_blk, const ck::index_t n_per_blk)
{
return ck::make_multi_index(m_per_blk, n_per_blk);
}
static auto GetAIndex(const ck::index_t i_m, const ck::index_t i_k)
{
if constexpr(std::is_same<typename ThreadwiseGemm_Dispatch::MatrixALayout,
ck::tensor_layout::gemm::RowMajor>::value)
{
// A : M, K
return ck::make_multi_index(i_m, i_k);
}
else
{
// A : K, M
return ck::make_multi_index(i_k, i_m);
}
}
static auto GetBIndex(const ck::index_t i_k, const ck::index_t i_n)
{
// i_n should be 8x
if constexpr(std::is_same<typename ThreadwiseGemm_Dispatch::MatrixBLayout,
ck::tensor_layout::gemm::RowMajor>::value)
{
// B : K, N
return ck::make_multi_index(i_k, i_n);
}
else
{
// B : N/8, K, N8
return ck::make_multi_index(i_n / ThreadwiseGemm_Dispatch::MatrixBMinVectorSize,
i_k,
i_n % ThreadwiseGemm_Dispatch::MatrixBMinVectorSize);
}
}
static auto GetCIndex(const ck::index_t i_m, const ck::index_t i_n)
{
return ck::make_multi_index(i_m, i_n);
}
static constexpr bool CheckValidity(const AGridDesc& a_grid_desc,
const BGridDesc& b_grid_desc,
const CGridDesc& c_grid_desc)
{
// TODO: also check validity of all components (blockwise-copy, threadwise-copy, etc)
bool is_valid = true;
const auto GemmN = c_grid_desc.GetLength(I1);
if constexpr(UseCLocalBuffer)
{
if(std::is_same<BlockMNKAccessOrder, ck::Sequence<0, 2, 1>>::value && NPerBlock < GemmN)
is_valid &= false;
}
else
{
// TODO: need to check whether the C grid descriptor is a simple transform?
if(GemmN % 8 != 0)
is_valid &= false;
}
return is_valid;
}
static void Run(const FloatA* __restrict__ p_a_grid,
const FloatB* __restrict__ p_b_grid,
FloatC* __restrict__ p_c_grid,
const FloatC0* __restrict__ p_c0_grid,
const FloatC1* __restrict__ p_c1_grid,
const AGridDesc& a_grid_desc,
const BGridDesc& b_grid_desc,
const CGridDesc& c_grid_desc,
const C0GridDesc& c0_grid_desc,
const C1GridDesc& c1_grid_desc,
const AElementwiseOperation& a_element_op,
const BElementwiseOperation& b_element_op,
const CElementwiseOperation& c_element_op)
{
ck::index_t m_per_block = MPerBlock;
ck::index_t n_per_block = NPerBlock;
ck::index_t k_per_block = KPerBlock;
const auto GemmM = c_grid_desc.GetLength(I0);
const auto GemmN = c_grid_desc.GetLength(I1);
const auto GemmK = a_grid_desc.GetLength(I1);
constexpr auto a_block_copy_dim = AGridDesc::GetNumOfDimension();
constexpr auto b_block_copy_dim = BGridDesc::GetNumOfDimension();
auto a_grid_buf = ck::cpu::make_dynamic_buffer<ck::AddressSpaceEnum::Global>(
reinterpret_cast<const FloatA*>(p_a_grid), a_grid_desc.GetElementSpaceSize());
auto b_grid_buf = ck::cpu::make_dynamic_buffer<ck::AddressSpaceEnum::Global>(
reinterpret_cast<const FloatB*>(p_b_grid), b_grid_desc.GetElementSpaceSize());
auto c_grid_buf = ck::cpu::make_dynamic_buffer<ck::AddressSpaceEnum::Global>(
reinterpret_cast<FloatC*>(p_c_grid), c_grid_desc.GetElementSpaceSize());
auto c0_grid_buf = ck::cpu::make_dynamic_buffer<ck::AddressSpaceEnum::Global>(
reinterpret_cast<const FloatC0*>(p_c0_grid), c0_grid_desc.GetElementSpaceSize());
auto c1_grid_buf = ck::cpu::make_dynamic_buffer<ck::AddressSpaceEnum::Global>(
reinterpret_cast<const FloatC1*>(p_c1_grid), c1_grid_desc.GetElementSpaceSize());
auto blockwise_gemm = BlockwiseGemmAvx2_MxN<
FloatA, // FloatA,
FloatB, // FloatB,
FloatC, // FloatC,
decltype(GetABlockDescriptor(m_per_block, k_per_block)), // ABlockDesc,
decltype(GetBBlockDescriptor(k_per_block, n_per_block)), // BBlockDesc,
decltype(GetCBlockDescriptor(m_per_block, n_per_block, c_grid_desc)), // CBlockDesc,
KPerBlock, // KPerBlock,
ThreadwiseGemm_Dispatch, // ThreadwiseGemm_Dispatch,
            ThreadMNAccessOrder>{}; // ThreadMNAccessOrder: how we access the gemm M/N space
                                    // to utilize the micro kernel
int total_threads = omp_get_max_threads();
#if 0
if(total_threads > 1){
#pragma omp parallel
{
int tid = omp_get_thread_num();
cpu_set_t set;
CPU_ZERO(&set);
CPU_SET(tid, &set);
if (sched_setaffinity(0, sizeof(set), &set) == -1) {
throw std::runtime_error("wrong! fail to set thread affinity");
}
}
}
#endif
// TODO: openmp aware ordering
//
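        // BlockMNKAccessOrder == Sequence<0, 1, 2>: the flattened M x N block grid is dealt out
        // round-robin to the threads; each block loops over K internally and reloads its A and B
        // panels every K step (no K-panel reuse across neighbouring blocks).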
if constexpr(std::is_same<BlockMNKAccessOrder, ck::Sequence<0, 1, 2>>::value)
{
auto a_move_k_step = GetAIndex(0, k_per_block);
auto b_move_k_step = GetBIndex(k_per_block, 0);
const ck::index_t grid_m = math::integer_divide_ceil(GemmM, m_per_block);
const ck::index_t grid_n = math::integer_divide_ceil(GemmN, n_per_block);
const ck::index_t grid_size = grid_m * grid_n;
const ck::index_t grids_per_thread =
math::integer_divide_ceil(grid_size, total_threads);
            // This version does not consider K-panel reuse; it keeps the OpenMP work
            // partitioning simple.
#pragma omp parallel
{
auto a_threadwise_copy =
AThreadwiseCopy(a_grid_desc,
ck::make_zero_multi_index<a_block_copy_dim>(),
GetABlockDescriptor(m_per_block, k_per_block),
ck::make_zero_multi_index<a_block_copy_dim>(),
AElementwiseOperation{});
auto b_threadwise_copy =
BThreadwiseCopy(b_grid_desc,
ck::make_zero_multi_index<b_block_copy_dim>(),
GetBBlockDescriptor(k_per_block, n_per_block),
ck::make_zero_multi_index<b_block_copy_dim>(),
BElementwiseOperation{});
auto c_threadwise_copy =
CThreadwiseCopy(GetCBlockDescriptor(m_per_block, n_per_block, c_grid_desc),
ck::make_zero_multi_index<2>(),
c_grid_desc,
ck::make_zero_multi_index<2>(),
CElementwiseOperation{});
DeviceAlignedMemCPU a_block_mem(m_per_block * k_per_block * sizeof(FloatA),
MemAlignmentByte);
DeviceAlignedMemCPU b_block_mem(k_per_block * n_per_block * sizeof(FloatB),
MemAlignmentByte);
DeviceAlignedMemCPU c_block_mem(
UseCLocalBuffer ? (m_per_block * n_per_block * sizeof(FloatC)) : 0,
MemAlignmentByte);
auto a_block_buf = ck::cpu::make_dynamic_buffer<ck::AddressSpaceEnum::Global>(
reinterpret_cast<FloatA*>(a_block_mem.mpDeviceBuf),
a_block_mem.mMemSize / sizeof(FloatA));
auto b_block_buf = ck::cpu::make_dynamic_buffer<ck::AddressSpaceEnum::Global>(
reinterpret_cast<FloatB*>(b_block_mem.mpDeviceBuf),
b_block_mem.mMemSize / sizeof(FloatB));
auto c_block_buf = ck::cpu::make_dynamic_buffer<ck::AddressSpaceEnum::Global>(
UseCLocalBuffer ? reinterpret_cast<FloatC*>(c_block_mem.mpDeviceBuf)
: reinterpret_cast<FloatC*>(p_c_grid),
UseCLocalBuffer ? c_block_mem.mMemSize / sizeof(FloatC)
: c_grid_desc.GetElementSpaceSize());
const ck::index_t tid = omp_get_thread_num();
for(ck::index_t i_gpt = 0; i_gpt < grids_per_thread; i_gpt++)
{
ck::index_t gid = i_gpt * total_threads + tid;
if(gid >= grid_size)
break;
ck::index_t i_mc = (gid / grid_n) * m_per_block;
ck::index_t i_nc = (gid % grid_n) * n_per_block;
ck::index_t mc_size = ck::math::min(GemmM - i_mc, m_per_block);
ck::index_t nc_size =
                    ck::math::min(GemmN - i_nc, n_per_block); // TODO: nc needs to be a multiple of 8
nc_size = math::integer_least_multiple(
nc_size, ThreadwiseGemm_Dispatch::MatrixBMinVectorSize);
a_threadwise_copy.SetSrcSliceOrigin(a_grid_desc, GetAIndex(i_mc, 0));
b_threadwise_copy.SetSrcSliceOrigin(b_grid_desc, GetBIndex(0, i_nc));
auto c_block_desc = GetCBlockDescriptor(mc_size, nc_size, c_grid_desc);
c_threadwise_copy.SetSrc1SliceOrigin(c_block_desc, GetCIndex(i_mc, i_nc));
c_threadwise_copy.SetSrc2SliceOrigin(c_block_desc, GetCIndex(i_mc, i_nc));
if constexpr(!UseCLocalBuffer)
{
c_threadwise_copy.SetSrcSliceOrigin(c_block_desc, GetCIndex(i_mc, i_nc));
c_threadwise_copy.RunRead(c_grid_desc,
c_grid_buf,
c0_grid_desc,
c0_grid_buf,
c1_grid_desc,
c1_grid_buf,
c_block_desc,
c_block_buf,
GetCSliceLength(mc_size, nc_size));
}
for(ck::index_t i_kc = 0; i_kc < GemmK; i_kc += k_per_block)
{
ck::index_t kc_size = ck::math::min(GemmK - i_kc, k_per_block);
auto a_block_desc = GetABlockDescriptor(mc_size, kc_size);
auto b_block_desc = GetBBlockDescriptor(kc_size, nc_size);
a_threadwise_copy.RunRead(a_grid_desc,
a_grid_buf,
a_block_desc,
a_block_buf,
GetASliceLength(mc_size, kc_size));
b_threadwise_copy.RunRead(b_grid_desc,
b_grid_buf,
b_block_desc,
b_block_buf,
GetBSliceLength(kc_size, nc_size));
blockwise_gemm.Run(a_block_desc,
a_block_buf,
make_zero_multi_index<a_block_copy_dim>(),
b_block_desc,
b_block_buf,
make_zero_multi_index<b_block_copy_dim>(),
c_block_desc,
c_block_buf,
make_zero_multi_index<2>(),
i_kc != 0);
if((i_kc + k_per_block) < GemmK)
{
a_threadwise_copy.MoveSrcSliceWindow(a_grid_desc, a_move_k_step);
b_threadwise_copy.MoveSrcSliceWindow(b_grid_desc, b_move_k_step);
}
}
c_threadwise_copy.SetDstSliceOrigin(c_grid_desc, GetCIndex(i_mc, i_nc));
c_threadwise_copy.RunWrite(c_block_desc,
c_block_buf,
c0_grid_desc,
c0_grid_buf,
c1_grid_desc,
c1_grid_buf,
c_grid_desc,
c_grid_buf,
GetCSliceLength(mc_size, nc_size));
}
}
}
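        // BlockMNKAccessOrder == Sequence<0, 2, 1>: parallelize over M blocks only. For each K
        // step the A panel is loaded once and reused while the inner loop sweeps all N blocks of
        // B, trading extra C traffic for better A reuse.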
else if constexpr(std::is_same<BlockMNKAccessOrder, ck::Sequence<0, 2, 1>>::value)
{
auto a_move_k_step = GetAIndex(0, k_per_block);
            auto b_move_n_step = GetBIndex(0, n_per_block);
const ck::index_t grid_m = math::integer_divide_ceil(GemmM, m_per_block);
const ck::index_t grid_m_per_thread = math::integer_divide_ceil(grid_m, total_threads);
            // parallelize only along the GEMM M dimension
#pragma omp parallel
{
auto a_threadwise_copy =
AThreadwiseCopy(a_grid_desc,
ck::make_zero_multi_index<a_block_copy_dim>(),
GetABlockDescriptor(m_per_block, k_per_block),
ck::make_zero_multi_index<a_block_copy_dim>(),
AElementwiseOperation{});
auto b_threadwise_copy =
BThreadwiseCopy(b_grid_desc,
ck::make_zero_multi_index<b_block_copy_dim>(),
GetBBlockDescriptor(k_per_block, n_per_block),
ck::make_zero_multi_index<b_block_copy_dim>(),
BElementwiseOperation{});
auto c_threadwise_copy =
CThreadwiseCopy(GetCBlockDescriptor(m_per_block, n_per_block, c_grid_desc),
ck::make_zero_multi_index<2>(),
c_grid_desc,
ck::make_zero_multi_index<2>(),
CElementwiseOperation{});
DeviceAlignedMemCPU a_block_mem(m_per_block * k_per_block * sizeof(FloatA),
MemAlignmentByte);
DeviceAlignedMemCPU b_block_mem(k_per_block * n_per_block * sizeof(FloatB),
MemAlignmentByte);
DeviceAlignedMemCPU c_block_mem(
UseCLocalBuffer ? (m_per_block * n_per_block * sizeof(FloatC)) : 0,
MemAlignmentByte);
auto a_block_buf = ck::cpu::make_dynamic_buffer<ck::AddressSpaceEnum::Global>(
reinterpret_cast<FloatA*>(a_block_mem.mpDeviceBuf),
a_block_mem.mMemSize / sizeof(FloatA));
auto b_block_buf = ck::cpu::make_dynamic_buffer<ck::AddressSpaceEnum::Global>(
reinterpret_cast<FloatB*>(b_block_mem.mpDeviceBuf),
b_block_mem.mMemSize / sizeof(FloatB));
auto c_block_buf = ck::cpu::make_dynamic_buffer<ck::AddressSpaceEnum::Global>(
UseCLocalBuffer ? reinterpret_cast<FloatC*>(c_block_mem.mpDeviceBuf)
: reinterpret_cast<FloatC*>(p_c_grid),
UseCLocalBuffer ? c_block_mem.mMemSize / sizeof(FloatC)
: c_grid_desc.GetElementSpaceSize());
const ck::index_t tid = omp_get_thread_num();
for(ck::index_t i_gmpt = 0; i_gmpt < grid_m_per_thread; i_gmpt++)
{
ck::index_t i_mc = (i_gmpt * total_threads + tid) * m_per_block;
if(i_mc >= GemmM)
break;
ck::index_t mc_size = ck::math::min(GemmM - i_mc, m_per_block);
a_threadwise_copy.SetSrcSliceOrigin(a_grid_desc, GetAIndex(i_mc, 0));
for(ck::index_t i_kc = 0; i_kc < GemmK; i_kc += k_per_block)
{
ck::index_t kc_size = ck::math::min(GemmK - i_kc, k_per_block);
auto a_block_desc = GetABlockDescriptor(mc_size, kc_size);
a_threadwise_copy.RunRead(a_grid_desc,
a_grid_buf,
a_block_desc,
a_block_buf,
GetASliceLength(mc_size, kc_size));
b_threadwise_copy.SetSrcSliceOrigin(b_grid_desc, GetBIndex(i_kc, 0));
                    // TODO: if the local C buffer is used, this nc loop should only iterate once
for(ck::index_t i_nc = 0; i_nc < GemmN; i_nc += n_per_block)
{
ck::index_t nc_size =
                            ck::math::min(GemmN - i_nc, n_per_block); // TODO: nc needs to be a multiple of 8
nc_size = math::integer_least_multiple(
nc_size, ThreadwiseGemm_Dispatch::MatrixBMinVectorSize);
auto b_block_desc = GetBBlockDescriptor(kc_size, nc_size);
b_threadwise_copy.RunRead(b_grid_desc,
b_grid_buf,
b_block_desc,
b_block_buf,
GetBSliceLength(kc_size, nc_size));
auto c_block_desc = GetCBlockDescriptor(mc_size, nc_size, c_grid_desc);
c_threadwise_copy.SetSrc1SliceOrigin(c_block_desc,
GetCIndex(i_mc, i_nc));
c_threadwise_copy.SetSrc2SliceOrigin(c_block_desc,
GetCIndex(i_mc, i_nc));
if constexpr(!UseCLocalBuffer)
{
c_threadwise_copy.SetSrcSliceOrigin(c_block_desc,
GetCIndex(i_mc, i_nc));
c_threadwise_copy.RunRead(c_grid_desc,
c_grid_buf,
c0_grid_desc,
c0_grid_buf,
c1_grid_desc,
c1_grid_buf,
c_block_desc,
c_block_buf,
GetCSliceLength(mc_size, nc_size));
}
blockwise_gemm.Run(a_block_desc,
a_block_buf,
make_zero_multi_index<a_block_copy_dim>(),
b_block_desc,
b_block_buf,
make_zero_multi_index<b_block_copy_dim>(),
c_block_desc,
c_block_buf,
make_zero_multi_index<2>(),
i_kc != 0);
if((i_nc + n_per_block) < GemmN)
{
                            b_threadwise_copy.MoveSrcSliceWindow(b_grid_desc, b_move_n_step);
}
if constexpr(UseCLocalBuffer)
{
c_threadwise_copy.SetDstSliceOrigin(c_grid_desc,
GetCIndex(i_mc, i_nc));
c_threadwise_copy.RunWrite(c_block_desc,
c_block_buf,
c0_grid_desc,
c0_grid_buf,
c1_grid_desc,
c1_grid_buf,
c_grid_desc,
c_grid_buf,
GetCSliceLength(mc_size, nc_size));
}
else
{
                            // only write on the last K iteration, since RunWrite here just
                            // applies the elementwise op from global to global
if((i_kc + k_per_block) >= GemmK)
{
c_threadwise_copy.SetDstSliceOrigin(c_grid_desc,
GetCIndex(i_mc, i_nc));
c_threadwise_copy.RunWrite(c_block_desc,
c_block_buf,
c0_grid_desc,
c0_grid_buf,
c1_grid_desc,
c1_grid_buf,
c_grid_desc,
c_grid_buf,
GetCSliceLength(mc_size, nc_size));
}
}
}
if((i_kc + k_per_block) < GemmK)
a_threadwise_copy.MoveSrcSliceWindow(a_grid_desc, a_move_k_step);
}
}
}
}
}
};
} // namespace cpu
} // namespace ck
#endif
...@@ -62,6 +62,161 @@ void memcpy32_avx2(void* dst, const void* src, const ck::index_t n, const Elemen
    }
}
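// Vectorized row store with two extra sources: dst[i] = element_op(src[i], src1[i], src2[i]),
// intended for fused output ops such as AddReluAdd (accumulator + bias + residual). The length
// is processed in 16/8/4/2/1-element chunks with AVX2/SSE loads and stores.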
template <typename ElementwiseOp>
void memcpy32_avx2_with_extra_2src(void* dst,
const void* src,
const void* src1,
const void* src2,
const ck::index_t n,
const ElementwiseOp& element_op)
{
// 16-8-4-2-1 pattern
ck::index_t i_n = n;
float* p_dst = reinterpret_cast<float*>(dst);
const float* p_src = reinterpret_cast<const float*>(src);
const float* p_src1 = reinterpret_cast<const float*>(src1);
const float* p_src2 = reinterpret_cast<const float*>(src2);
while(i_n >= 16)
{
_mm256_storeu_ps(p_dst + 0,
element_op.Apply(_mm256_loadu_ps(p_src + 0),
_mm256_loadu_ps(p_src1 + 0),
_mm256_loadu_ps(p_src2 + 0)));
_mm256_storeu_ps(p_dst + 8,
element_op.Apply(_mm256_loadu_ps(p_src + 8),
_mm256_loadu_ps(p_src1 + 8),
_mm256_loadu_ps(p_src2 + 8)));
p_dst += 16;
p_src += 16;
p_src1 += 16;
p_src2 += 16;
i_n -= 16;
}
if(i_n & 8)
{
_mm256_storeu_ps(p_dst,
element_op.Apply(_mm256_loadu_ps(p_src),
_mm256_loadu_ps(p_src1),
_mm256_loadu_ps(p_src2)));
p_dst += 8;
p_src += 8;
p_src1 += 8;
p_src2 += 8;
}
if(i_n & 4)
{
_mm_storeu_ps(
p_dst,
element_op.Apply(_mm_loadu_ps(p_src), _mm_loadu_ps(p_src1), _mm_loadu_ps(p_src2)));
p_dst += 4;
p_src += 4;
p_src1 += 4;
p_src2 += 4;
}
if(i_n & 2)
{
#if defined(__GNUC__) && !defined(__clang__) && !defined(__llvm__)
        __m128i s  = _mm_loadu_si64(p_src);
        __m128i s1 = _mm_loadu_si64(p_src1);
        __m128i s2 = _mm_loadu_si64(p_src2);
        __m128 v   = element_op.Apply(*reinterpret_cast<__m128*>(&s),
                                    *reinterpret_cast<__m128*>(&s1),
                                    *reinterpret_cast<__m128*>(&s2));
        _mm_storeu_si64(p_dst, *reinterpret_cast<__m128i*>(&v));
#else
        __m128 v = element_op.Apply(_mm_castsi128_ps(_mm_loadu_si64(p_src)),
                                    _mm_castsi128_ps(_mm_loadu_si64(p_src1)),
                                    _mm_castsi128_ps(_mm_loadu_si64(p_src2)));
        _mm_storeu_si64(p_dst, _mm_castps_si128(v));
#endif
p_dst += 2;
p_src += 2;
p_src1 += 2;
p_src2 += 2;
}
if(i_n & 1)
{
*p_dst = element_op.Apply(*p_src, *p_src1, *p_src2);
}
}
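// Overload where src1 is a single scalar broadcast across the whole row (e.g. a per-row bias):
// dst[i] = element_op(src[i], v_src1, src2[i]).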
template <typename ElementwiseOp>
void memcpy32_avx2_with_extra_2src(void* dst,
const void* src,
float v_src1,
const void* src2,
const ck::index_t n,
const ElementwiseOp& element_op)
{
// 16-8-4-2-1 pattern
ck::index_t i_n = n;
float* p_dst = reinterpret_cast<float*>(dst);
const float* p_src = reinterpret_cast<const float*>(src);
const float* p_src2 = reinterpret_cast<const float*>(src2);
    __m256 ymm_src1 = _mm256_set1_ps(v_src1); // broadcast the scalar src1 into ymm/xmm registers
    __m128 xmm_src1 = _mm_set1_ps(v_src1);
while(i_n >= 16)
{
_mm256_storeu_ps(
p_dst + 0,
element_op.Apply(_mm256_loadu_ps(p_src + 0), ymm_src1, _mm256_loadu_ps(p_src2 + 0)));
_mm256_storeu_ps(
p_dst + 8,
element_op.Apply(_mm256_loadu_ps(p_src + 8), ymm_src1, _mm256_loadu_ps(p_src2 + 8)));
p_dst += 16;
p_src += 16;
p_src2 += 16;
i_n -= 16;
}
if(i_n & 8)
{
_mm256_storeu_ps(
p_dst, element_op.Apply(_mm256_loadu_ps(p_src), ymm_src1, _mm256_loadu_ps(p_src2)));
p_dst += 8;
p_src += 8;
p_src2 += 8;
}
if(i_n & 4)
{
_mm_storeu_ps(p_dst, element_op.Apply(_mm_loadu_ps(p_src), xmm_src1, _mm_loadu_ps(p_src2)));
p_dst += 4;
p_src += 4;
p_src2 += 4;
}
if(i_n & 2)
{
#if defined(__GNUC__) && !defined(__clang__) && !defined(__llvm__)
        __m128i s  = _mm_loadu_si64(p_src);
        __m128i s2 = _mm_loadu_si64(p_src2);
        __m128 v   = element_op.Apply(
            *reinterpret_cast<__m128*>(&s), xmm_src1, *reinterpret_cast<__m128*>(&s2));
        _mm_storeu_si64(p_dst, *reinterpret_cast<__m128i*>(&v));
#else
        __m128 v = element_op.Apply(_mm_castsi128_ps(_mm_loadu_si64(p_src)),
                                    xmm_src1,
                                    _mm_castsi128_ps(_mm_loadu_si64(p_src2)));
        _mm_storeu_si64(p_dst, _mm_castps_si128(v));
#endif
p_dst += 2;
p_src += 2;
p_src2 += 2;
}
if(i_n & 1)
{
*p_dst = element_op.Apply(*p_src, v_src1, *p_src2);
}
}
inline void memset32_avx2(void* dst, const int32_t value, const ck::index_t n)
{
    // 16-8-4-2-1 pattern
...@@ -1361,6 +1516,672 @@ struct ThreadwiseTensorSliceTransferAvx2Specialization_MatC_Store_MxN
    intptr_t dst_offset;
};
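// Threadwise C-store specialization that fuses bias (Src1, broadcast along one GEMM dimension)
// and residual (Src2, per output element) into the final write:
// dst = element_op(acc, bias, residual). With BypassTransfer the blockwise GEMM has already
// accumulated into the destination rows and the op is applied in place; otherwise the local C
// block is copied out to the grid while the op is applied.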
template <typename SrcData,
          typename Src1Data, // bias, broadcast along one GEMM dimension
          typename Src2Data, // residual, one value per output element
typename DstData,
typename SrcDesc,
typename Src1Desc,
typename Src2Desc,
typename DstDesc,
typename ElementwiseOperation,
bool BypassTransfer,
bool Src1AlongDim0> // if true, src1 has dim along M, false, src1 has dim along N
struct ThreadwiseTensorSliceTransferAvx2Specialization_MatC_Store_Bias_Residual_MxN
{
static constexpr ck::index_t nDim = SrcDesc::GetNumOfDimension();
using Index = MultiIndex<nDim>;
constexpr ThreadwiseTensorSliceTransferAvx2Specialization_MatC_Store_Bias_Residual_MxN(
const SrcDesc& src_desc,
const Index&,
const DstDesc& dst_desc,
const Index&,
const ElementwiseOperation& element_op)
: element_op_(element_op)
{
DstGemmM = dst_desc.GetTransforms()[Number<0>{}].GetUpperLengths()[Number<0>{}];
DstGemmN = dst_desc.GetTransforms()[Number<0>{}].GetUpperLengths()[Number<1>{}];
src_offset = 0;
src1_offset = 1;
src2_offset = 2;
dst_offset = 0;
}
void SetSrcSliceOrigin(const SrcDesc&, const Index& src_slice_origin_idx)
{
if constexpr(BypassTransfer)
{
auto i_src_gemm_m = src_slice_origin_idx[Number<0>{}];
auto i_src_gemm_n = src_slice_origin_idx[Number<1>{}];
src_offset = i_src_gemm_m * DstGemmN + i_src_gemm_n;
}
}
void SetSrc1SliceOrigin(const SrcDesc&, const Index& src_slice_origin_idx)
{
if constexpr(Src1AlongDim0)
{
auto i_src_gemm_m = src_slice_origin_idx[Number<0>{}];
// auto i_src_gemm_n = src_slice_origin_idx[Number<1>{}];
src1_offset = i_src_gemm_m;
}
else
{
auto i_src_gemm_n = src_slice_origin_idx[Number<1>{}];
src1_offset = i_src_gemm_n;
}
}
void SetSrc2SliceOrigin(const SrcDesc&, const Index& src_slice_origin_idx)
{
auto i_src_gemm_m = src_slice_origin_idx[Number<0>{}];
auto i_src_gemm_n = src_slice_origin_idx[Number<1>{}];
src2_offset = i_src_gemm_m * DstGemmN + i_src_gemm_n;
}
void SetDstSliceOrigin(const DstDesc&, const Index& dst_slice_origin_idx)
{
i_dst_gemm_m = dst_slice_origin_idx[Number<0>{}];
i_dst_gemm_n = dst_slice_origin_idx[Number<1>{}];
dst_offset = i_dst_gemm_m * DstGemmN + i_dst_gemm_n;
}
template <typename SrcBuffer,
typename Src1Buffer,
typename Src2Buffer,
typename DstBuffer,
typename SliceLengths>
void RunRead(const SrcDesc&,
SrcBuffer& src_buf,
const Src1Desc&,
Src1Buffer&,
const Src2Desc&,
Src2Buffer&,
const DstDesc&,
DstBuffer& dst_buf,
const SliceLengths&)
{
if constexpr(BypassTransfer)
{
dst_buf.p_data_ = reinterpret_cast<float*>(src_buf.p_data_) + src_offset;
}
}
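    // Writes one M x N block of the output. The M loop is manually unrolled in an 8-4-2-1
    // pattern; each row is handed to memcpy32_avx2_with_extra_2src, passing src1 either as a
    // per-row scalar (Src1AlongDim0) or as a per-column vector.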
template <typename SrcBuffer,
typename Src1Buffer,
typename Src2Buffer,
typename DstBuffer,
typename SliceLengths>
void RunWrite(const SrcDesc& src_desc,
SrcBuffer& src_buf,
const Src1Desc& src1_desc,
Src1Buffer& src1_buf,
const Src2Desc& src2_desc,
Src2Buffer& src2_buf,
const DstDesc& dst_desc,
DstBuffer& dst_buf,
const SliceLengths& slice_length)
{
if constexpr(BypassTransfer)
{
// src_buf.p_data_ = reinterpret_cast<float*>(dst_buf.p_data_) + src_offset;
if constexpr(!std::is_same<ElementwiseOperation,
ck::tensor_operation::cpu::element_wise::PassThrough>::value)
{
const ck::index_t m_per_block = slice_length[Number<0>{}];
const ck::index_t n_per_block = slice_length[Number<1>{}];
const ck::index_t current_n = ck::math::min(DstGemmN - i_dst_gemm_n, n_per_block);
float* p_dst = reinterpret_cast<float*>(dst_buf.p_data_) + dst_offset;
const float* p_src1 =
reinterpret_cast<const float*>(src1_buf.p_data_) + src1_offset;
const float* p_src2 =
reinterpret_cast<const float*>(src2_buf.p_data_) + src2_offset;
ck::index_t i_m_itr = m_per_block;
// printf("xxxx %d, current_n:%d, DstGemmN:%d, n_per_block:%d,
// dst_offset:%d\n",__LINE__, current_n,
// DstGemmN, n_per_block, dst_offset);fflush(stdout);
// standard 8-4-2-1 pattern
if constexpr(Src1AlongDim0)
{
while(i_m_itr >= 8)
{
avx2_util::memcpy32_avx2_with_extra_2src(p_dst + 0 * DstGemmN,
p_dst + 0 * DstGemmN,
*(p_src1 + 0),
p_src2 + 0 * DstGemmN,
current_n,
element_op_);
avx2_util::memcpy32_avx2_with_extra_2src(p_dst + 1 * DstGemmN,
p_dst + 1 * DstGemmN,
*(p_src1 + 1),
p_src2 + 1 * DstGemmN,
current_n,
element_op_);
avx2_util::memcpy32_avx2_with_extra_2src(p_dst + 2 * DstGemmN,
p_dst + 2 * DstGemmN,
*(p_src1 + 2),
p_src2 + 2 * DstGemmN,
current_n,
element_op_);
avx2_util::memcpy32_avx2_with_extra_2src(p_dst + 3 * DstGemmN,
p_dst + 3 * DstGemmN,
*(p_src1 + 3),
p_src2 + 3 * DstGemmN,
current_n,
element_op_);
avx2_util::memcpy32_avx2_with_extra_2src(p_dst + 4 * DstGemmN,
p_dst + 4 * DstGemmN,
*(p_src1 + 4),
p_src2 + 4 * DstGemmN,
current_n,
element_op_);
avx2_util::memcpy32_avx2_with_extra_2src(p_dst + 5 * DstGemmN,
p_dst + 5 * DstGemmN,
*(p_src1 + 5),
p_src2 + 5 * DstGemmN,
current_n,
element_op_);
avx2_util::memcpy32_avx2_with_extra_2src(p_dst + 6 * DstGemmN,
p_dst + 6 * DstGemmN,
*(p_src1 + 6),
p_src2 + 6 * DstGemmN,
current_n,
element_op_);
avx2_util::memcpy32_avx2_with_extra_2src(p_dst + 7 * DstGemmN,
p_dst + 7 * DstGemmN,
*(p_src1 + 7),
p_src2 + 7 * DstGemmN,
current_n,
element_op_);
i_m_itr -= 8;
p_dst += 8 * DstGemmN;
p_src1 += 8;
p_src2 += 8 * DstGemmN;
}
if(i_m_itr & 4)
{
avx2_util::memcpy32_avx2_with_extra_2src(p_dst + 0 * DstGemmN,
p_dst + 0 * DstGemmN,
*(p_src1 + 0),
p_src2 + 0 * DstGemmN,
current_n,
element_op_);
avx2_util::memcpy32_avx2_with_extra_2src(p_dst + 1 * DstGemmN,
p_dst + 1 * DstGemmN,
*(p_src1 + 1),
p_src2 + 1 * DstGemmN,
current_n,
element_op_);
avx2_util::memcpy32_avx2_with_extra_2src(p_dst + 2 * DstGemmN,
p_dst + 2 * DstGemmN,
*(p_src1 + 2),
p_src2 + 2 * DstGemmN,
current_n,
element_op_);
avx2_util::memcpy32_avx2_with_extra_2src(p_dst + 3 * DstGemmN,
p_dst + 3 * DstGemmN,
*(p_src1 + 3),
p_src2 + 3 * DstGemmN,
current_n,
element_op_);
p_dst += 4 * DstGemmN;
p_src1 += 4;
p_src2 += 4 * DstGemmN;
}
if(i_m_itr & 2)
{
avx2_util::memcpy32_avx2_with_extra_2src(p_dst + 0 * DstGemmN,
p_dst + 0 * DstGemmN,
*(p_src1 + 0),
p_src2 + 0 * DstGemmN,
current_n,
element_op_);
avx2_util::memcpy32_avx2_with_extra_2src(p_dst + 1 * DstGemmN,
p_dst + 1 * DstGemmN,
*(p_src1 + 1),
p_src2 + 1 * DstGemmN,
current_n,
element_op_);
p_dst += 2 * DstGemmN;
p_src1 += 2;
p_src2 += 2 * DstGemmN;
}
if(i_m_itr & 1)
{
avx2_util::memcpy32_avx2_with_extra_2src(p_dst + 0 * DstGemmN,
p_dst + 0 * DstGemmN,
*(p_src1 + 0),
p_src2 + 0 * DstGemmN,
current_n,
element_op_);
}
}
else
{
while(i_m_itr >= 8)
{
avx2_util::memcpy32_avx2_with_extra_2src(p_dst + 0 * DstGemmN,
p_dst + 0 * DstGemmN,
p_src1,
p_src2 + 0 * DstGemmN,
current_n,
element_op_);
avx2_util::memcpy32_avx2_with_extra_2src(p_dst + 1 * DstGemmN,
p_dst + 1 * DstGemmN,
p_src1,
p_src2 + 1 * DstGemmN,
current_n,
element_op_);
avx2_util::memcpy32_avx2_with_extra_2src(p_dst + 2 * DstGemmN,
p_dst + 2 * DstGemmN,
p_src1,
p_src2 + 2 * DstGemmN,
current_n,
element_op_);
avx2_util::memcpy32_avx2_with_extra_2src(p_dst + 3 * DstGemmN,
p_dst + 3 * DstGemmN,
p_src1,
p_src2 + 3 * DstGemmN,
current_n,
element_op_);
avx2_util::memcpy32_avx2_with_extra_2src(p_dst + 4 * DstGemmN,
p_dst + 4 * DstGemmN,
p_src1,
p_src2 + 4 * DstGemmN,
current_n,
element_op_);
avx2_util::memcpy32_avx2_with_extra_2src(p_dst + 5 * DstGemmN,
p_dst + 5 * DstGemmN,
p_src1,
p_src2 + 5 * DstGemmN,
current_n,
element_op_);
avx2_util::memcpy32_avx2_with_extra_2src(p_dst + 6 * DstGemmN,
p_dst + 6 * DstGemmN,
p_src1,
p_src2 + 6 * DstGemmN,
current_n,
element_op_);
avx2_util::memcpy32_avx2_with_extra_2src(p_dst + 7 * DstGemmN,
p_dst + 7 * DstGemmN,
p_src1,
p_src2 + 7 * DstGemmN,
current_n,
element_op_);
i_m_itr -= 8;
p_dst += 8 * DstGemmN;
p_src2 += 8 * DstGemmN;
}
if(i_m_itr & 4)
{
avx2_util::memcpy32_avx2_with_extra_2src(p_dst + 0 * DstGemmN,
p_dst + 0 * DstGemmN,
p_src1,
p_src2 + 0 * DstGemmN,
current_n,
element_op_);
avx2_util::memcpy32_avx2_with_extra_2src(p_dst + 1 * DstGemmN,
p_dst + 1 * DstGemmN,
p_src1,
p_src2 + 1 * DstGemmN,
current_n,
element_op_);
avx2_util::memcpy32_avx2_with_extra_2src(p_dst + 2 * DstGemmN,
p_dst + 2 * DstGemmN,
p_src1,
p_src2 + 2 * DstGemmN,
current_n,
element_op_);
avx2_util::memcpy32_avx2_with_extra_2src(p_dst + 3 * DstGemmN,
p_dst + 3 * DstGemmN,
p_src1,
p_src2 + 3 * DstGemmN,
current_n,
element_op_);
p_dst += 4 * DstGemmN;
p_src2 += 4 * DstGemmN;
}
if(i_m_itr & 2)
{
avx2_util::memcpy32_avx2_with_extra_2src(p_dst + 0 * DstGemmN,
p_dst + 0 * DstGemmN,
p_src1,
p_src2 + 0 * DstGemmN,
current_n,
element_op_);
avx2_util::memcpy32_avx2_with_extra_2src(p_dst + 1 * DstGemmN,
p_dst + 1 * DstGemmN,
p_src1,
p_src2 + 1 * DstGemmN,
current_n,
element_op_);
p_dst += 2 * DstGemmN;
p_src2 += 2 * DstGemmN;
}
if(i_m_itr & 1)
{
avx2_util::memcpy32_avx2_with_extra_2src(p_dst + 0 * DstGemmN,
p_dst + 0 * DstGemmN,
p_src1,
p_src2 + 0 * DstGemmN,
current_n,
element_op_);
}
}
}
}
else
{
const ck::index_t m_per_block = slice_length[Number<0>{}];
const ck::index_t n_per_block = slice_length[Number<1>{}];
const ck::index_t current_n = ck::math::min(DstGemmN - i_dst_gemm_n, n_per_block);
const float* p_src = reinterpret_cast<const float*>(src_buf.p_data_) + src_offset;
float* p_dst = reinterpret_cast<float*>(dst_buf.p_data_) + dst_offset;
const float* p_src1 = reinterpret_cast<const float*>(src1_buf.p_data_) + src1_offset;
const float* p_src2 = reinterpret_cast<const float*>(src2_buf.p_data_) + src2_offset;
ck::index_t i_m_itr = m_per_block;
// printf("xxxx %d, current_n:%d, DstGemmN:%d, n_per_block:%d\n",__LINE__, current_n,
// DstGemmN, n_per_block);fflush(stdout);
// standard 8-4-2-1 pattern
if constexpr(Src1AlongDim0)
{
while(i_m_itr >= 8)
{
avx2_util::memcpy32_avx2_with_extra_2src(p_dst + 0 * DstGemmN,
p_src + 0 * n_per_block,
*(p_src1 + 0),
p_src2 + 0 * DstGemmN,
current_n,
element_op_);
avx2_util::memcpy32_avx2_with_extra_2src(p_dst + 1 * DstGemmN,
p_src + 1 * n_per_block,
*(p_src1 + 1),
p_src2 + 1 * DstGemmN,
current_n,
element_op_);
avx2_util::memcpy32_avx2_with_extra_2src(p_dst + 2 * DstGemmN,
p_src + 2 * n_per_block,
*(p_src1 + 2),
p_src2 + 2 * DstGemmN,
current_n,
element_op_);
avx2_util::memcpy32_avx2_with_extra_2src(p_dst + 3 * DstGemmN,
p_src + 3 * n_per_block,
*(p_src1 + 3),
p_src2 + 3 * DstGemmN,
current_n,
element_op_);
avx2_util::memcpy32_avx2_with_extra_2src(p_dst + 4 * DstGemmN,
p_src + 4 * n_per_block,
*(p_src1 + 4),
p_src2 + 4 * DstGemmN,
current_n,
element_op_);
avx2_util::memcpy32_avx2_with_extra_2src(p_dst + 5 * DstGemmN,
p_src + 5 * n_per_block,
*(p_src1 + 5),
p_src2 + 5 * DstGemmN,
current_n,
element_op_);
avx2_util::memcpy32_avx2_with_extra_2src(p_dst + 6 * DstGemmN,
p_src + 6 * n_per_block,
*(p_src1 + 6),
p_src2 + 6 * DstGemmN,
current_n,
element_op_);
avx2_util::memcpy32_avx2_with_extra_2src(p_dst + 7 * DstGemmN,
p_src + 7 * n_per_block,
*(p_src1 + 7),
p_src2 + 7 * DstGemmN,
current_n,
element_op_);
i_m_itr -= 8;
p_dst += 8 * DstGemmN;
p_src += 8 * n_per_block;
p_src1 += 8;
p_src2 += 8 * DstGemmN;
}
if(i_m_itr & 4)
{
avx2_util::memcpy32_avx2_with_extra_2src(p_dst + 0 * DstGemmN,
p_src + 0 * n_per_block,
*(p_src1 + 0),
p_src2 + 0 * DstGemmN,
current_n,
element_op_);
avx2_util::memcpy32_avx2_with_extra_2src(p_dst + 1 * DstGemmN,
p_src + 1 * n_per_block,
*(p_src1 + 1),
p_src2 + 1 * DstGemmN,
current_n,
element_op_);
avx2_util::memcpy32_avx2_with_extra_2src(p_dst + 2 * DstGemmN,
p_src + 2 * n_per_block,
*(p_src1 + 2),
p_src2 + 2 * DstGemmN,
current_n,
element_op_);
avx2_util::memcpy32_avx2_with_extra_2src(p_dst + 3 * DstGemmN,
p_src + 3 * n_per_block,
*(p_src1 + 3),
p_src2 + 3 * DstGemmN,
current_n,
element_op_);
p_dst += 4 * DstGemmN;
p_src += 4 * n_per_block;
p_src1 += 4;
p_src2 += 4 * DstGemmN;
}
if(i_m_itr & 2)
{
avx2_util::memcpy32_avx2_with_extra_2src(p_dst + 0 * DstGemmN,
p_src + 0 * n_per_block,
*(p_src1 + 0),
p_src2 + 0 * DstGemmN,
current_n,
element_op_);
avx2_util::memcpy32_avx2_with_extra_2src(p_dst + 1 * DstGemmN,
p_src + 1 * n_per_block,
*(p_src1 + 1),
p_src2 + 1 * DstGemmN,
current_n,
element_op_);
p_dst += 2 * DstGemmN;
p_src += 2 * n_per_block;
p_src1 += 2;
p_src2 += 2 * DstGemmN;
}
if(i_m_itr & 1)
{
avx2_util::memcpy32_avx2_with_extra_2src(p_dst + 0 * DstGemmN,
p_src + 0 * n_per_block,
*(p_src1 + 0),
p_src2 + 0 * DstGemmN,
current_n,
element_op_);
}
}
else
{
while(i_m_itr >= 8)
{
avx2_util::memcpy32_avx2_with_extra_2src(p_dst + 0 * DstGemmN,
p_src + 0 * n_per_block,
p_src1,
p_src2 + 0 * DstGemmN,
current_n,
element_op_);
avx2_util::memcpy32_avx2_with_extra_2src(p_dst + 1 * DstGemmN,
p_src + 1 * n_per_block,
p_src1,
p_src2 + 1 * DstGemmN,
current_n,
element_op_);
avx2_util::memcpy32_avx2_with_extra_2src(p_dst + 2 * DstGemmN,
p_src + 2 * n_per_block,
p_src1,
p_src2 + 2 * DstGemmN,
current_n,
element_op_);
avx2_util::memcpy32_avx2_with_extra_2src(p_dst + 3 * DstGemmN,
p_src + 3 * n_per_block,
p_src1,
p_src2 + 3 * DstGemmN,
current_n,
element_op_);
avx2_util::memcpy32_avx2_with_extra_2src(p_dst + 4 * DstGemmN,
p_src + 4 * n_per_block,
p_src1,
p_src2 + 4 * DstGemmN,
current_n,
element_op_);
avx2_util::memcpy32_avx2_with_extra_2src(p_dst + 5 * DstGemmN,
p_src + 5 * n_per_block,
p_src1,
p_src2 + 5 * DstGemmN,
current_n,
element_op_);
avx2_util::memcpy32_avx2_with_extra_2src(p_dst + 6 * DstGemmN,
p_src + 6 * n_per_block,
p_src1,
p_src2 + 6 * DstGemmN,
current_n,
element_op_);
avx2_util::memcpy32_avx2_with_extra_2src(p_dst + 7 * DstGemmN,
p_src + 7 * n_per_block,
p_src1,
p_src2 + 7 * DstGemmN,
current_n,
element_op_);
i_m_itr -= 8;
p_dst += 8 * DstGemmN;
p_src += 8 * n_per_block;
p_src2 += 8 * DstGemmN;
}
if(i_m_itr & 4)
{
avx2_util::memcpy32_avx2_with_extra_2src(p_dst + 0 * DstGemmN,
p_src + 0 * n_per_block,
p_src1,
p_src2 + 0 * DstGemmN,
current_n,
element_op_);
avx2_util::memcpy32_avx2_with_extra_2src(p_dst + 1 * DstGemmN,
p_src + 1 * n_per_block,
p_src1,
p_src2 + 1 * DstGemmN,
current_n,
element_op_);
avx2_util::memcpy32_avx2_with_extra_2src(p_dst + 2 * DstGemmN,
p_src + 2 * n_per_block,
p_src1,
p_src2 + 2 * DstGemmN,
current_n,
element_op_);
avx2_util::memcpy32_avx2_with_extra_2src(p_dst + 3 * DstGemmN,
p_src + 3 * n_per_block,
p_src1,
p_src2 + 3 * DstGemmN,
current_n,
element_op_);
p_dst += 4 * DstGemmN;
p_src += 4 * n_per_block;
p_src2 += 4 * DstGemmN;
}
if(i_m_itr & 2)
{
avx2_util::memcpy32_avx2_with_extra_2src(p_dst + 0 * DstGemmN,
p_src + 0 * n_per_block,
p_src1,
p_src2 + 0 * DstGemmN,
current_n,
element_op_);
avx2_util::memcpy32_avx2_with_extra_2src(p_dst + 1 * DstGemmN,
p_src + 1 * n_per_block,
p_src1,
p_src2 + 1 * DstGemmN,
current_n,
element_op_);
p_dst += 2 * DstGemmN;
p_src += 2 * n_per_block;
p_src2 += 2 * DstGemmN;
}
if(i_m_itr & 1)
{
avx2_util::memcpy32_avx2_with_extra_2src(p_dst + 0 * DstGemmN,
p_src + 0 * n_per_block,
p_src1,
p_src2 + 0 * DstGemmN,
current_n,
element_op_);
}
}
// printf("xxxx %d\n",__LINE__);fflush(stdout);
}
}
// src_slice_origin_step_idx need to be known at compile-time, for performance reason
void MoveSrcSliceWindow(const SrcDesc&, const Index&) {}
// dst_slice_origin_step_idx need to be known at compile-time, for performance reason
void MoveDstSliceWindow(const DstDesc&, const Index&) {}
private:
const ElementwiseOperation element_op_;
ck::index_t i_dst_gemm_m;
ck::index_t i_dst_gemm_n;
ck::index_t DstGemmM;
ck::index_t DstGemmN;
intptr_t src_offset;
intptr_t src1_offset;
intptr_t src2_offset;
intptr_t dst_offset;
};
} // namespace cpu
} // namespace ck
...@@ -22,3 +22,4 @@ function(add_instance_library INSTANCE_NAME)
endfunction(add_instance_library INSTANCE_NAME)
add_subdirectory(conv2d_fwd)
add_subdirectory(conv2d_fwd_bias_activation_add)
# device_conv2d_fwd_bias_activation_add_cpu_instance
set(DEVICE_CONV2D_FWD_CPU_INSTANCE_SOURCE
device_conv2d_bias_activation_add_avx2_nhwc_kyxc_nhwk_instance.cpp
)
add_library(device_conv2d_fwd_bias_activation_add_cpu_instance SHARED ${DEVICE_CONV2D_FWD_CPU_INSTANCE_SOURCE})
target_compile_features(device_conv2d_fwd_bias_activation_add_cpu_instance PUBLIC)
set_target_properties(device_conv2d_fwd_bias_activation_add_cpu_instance PROPERTIES POSITION_INDEPENDENT_CODE ON)
target_link_libraries(device_conv2d_fwd_bias_activation_add_cpu_instance PRIVATE "${OMP_LIBRARY}")
target_compile_options(device_conv2d_fwd_bias_activation_add_cpu_instance PRIVATE "${OMP_CXX_FLAG}")
install(TARGETS device_conv2d_fwd_bias_activation_add_cpu_instance LIBRARY DESTINATION lib)
clang_tidy_check(device_conv2d_fwd_bias_activation_add_cpu_instance)
#include <stdlib.h>
#include "convolution_forward_specialization_cpu.hpp"
#include "config.hpp"
#include "device_convnd_fwd_bias_activation_add_avx2_nhwc_kyxc_nhwk.hpp"
#include "element_wise_operation_cpu.hpp"
#include "device_operation_instance.hpp"
namespace ck {
namespace tensor_operation {
namespace cpu {
namespace device {
namespace device_conv2d_fwd_bias_activation_add_avx2_instance {
using InType = float;
using WeiType = float;
using OutType = float;
using AccType = float;
using InLayout = ck::tensor_layout::gemm::RowMajor; // NHWC
using WeiLayout = ck::tensor_layout::gemm::ColumnMajor; // KYXC
static constexpr bool NonTemporalStore = false;
using PT = ck::tensor_operation::cpu::element_wise::PassThrough;
using AddReluAdd = ck::tensor_operation::cpu::element_wise::AddReluAdd;
static constexpr auto ConvFwdDefault =
ck::tensor_operation::cpu::device::ConvolutionForwardSpecialization_t::Default;
static constexpr auto ConvFwd1x1P0 =
ck::tensor_operation::cpu::device::ConvolutionForwardSpecialization_t::Filter1x1Pad0;
static constexpr auto ConvFwd1x1S1P0 =
ck::tensor_operation::cpu::device::ConvolutionForwardSpecialization_t::Filter1x1Stride1Pad0;
static constexpr auto DefaultGemmKLoop =
ck::tensor_operation::cpu::device::ConvolutionForwardGemmKSpecialization_t::DefaultGemmKLoop;
static constexpr auto GemmKLoopOverC =
ck::tensor_operation::cpu::device::ConvolutionForwardGemmKSpecialization_t::NHWC_GemmKLoopOverC;
static constexpr auto LoopOver_MNK = ck::tensor_operation::cpu::device::LoopOver_MNK;
static constexpr auto LoopOver_MKN = ck::tensor_operation::cpu::device::LoopOver_MKN;
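// Each use of this macro expands to eight instance types: {Default, Filter1x1Stride1Pad0}
// convolution specializations x {GemmKLoopOverC, DefaultGemmKLoop} K-loop schemes x
// {LoopOver_MNK, LoopOver_MKN} block access orders, all sharing the same blocking parameters.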
// clang-format off
#define DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_KYXC_NHWK_F32(a_elem_op, b_elem_op, c_elem_op, m_per_block, n_per_block, k_per_block, m_per_thread, n_per_thread, a_local_buf, b_local_buf, c_local_buf, bias_along_m) \
DeviceConvNDFwdBiasActivationAddAvx2_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K<float , float , float, float , float, a_elem_op, b_elem_op, c_elem_op, ConvFwdDefault, GemmKLoopOverC , LoopOver_MNK, 2, m_per_block, n_per_block, k_per_block, m_per_thread, n_per_thread, a_local_buf, b_local_buf, c_local_buf, bias_along_m>, \
DeviceConvNDFwdBiasActivationAddAvx2_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K<float , float , float, float , float, a_elem_op, b_elem_op, c_elem_op, ConvFwd1x1S1P0, GemmKLoopOverC , LoopOver_MNK, 2, m_per_block, n_per_block, k_per_block, m_per_thread, n_per_thread, a_local_buf, b_local_buf, c_local_buf, bias_along_m>, \
DeviceConvNDFwdBiasActivationAddAvx2_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K<float , float , float, float , float, a_elem_op, b_elem_op, c_elem_op, ConvFwdDefault, DefaultGemmKLoop, LoopOver_MNK, 2, m_per_block, n_per_block, k_per_block, m_per_thread, n_per_thread, a_local_buf, b_local_buf, c_local_buf, bias_along_m>, \
DeviceConvNDFwdBiasActivationAddAvx2_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K<float , float , float, float , float, a_elem_op, b_elem_op, c_elem_op, ConvFwd1x1S1P0, DefaultGemmKLoop, LoopOver_MNK, 2, m_per_block, n_per_block, k_per_block, m_per_thread, n_per_thread, a_local_buf, b_local_buf, c_local_buf, bias_along_m>, \
DeviceConvNDFwdBiasActivationAddAvx2_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K<float , float , float, float , float, a_elem_op, b_elem_op, c_elem_op, ConvFwdDefault, GemmKLoopOverC , LoopOver_MKN, 2, m_per_block, n_per_block, k_per_block, m_per_thread, n_per_thread, a_local_buf, b_local_buf, c_local_buf, bias_along_m>, \
DeviceConvNDFwdBiasActivationAddAvx2_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K<float , float , float, float , float, a_elem_op, b_elem_op, c_elem_op, ConvFwd1x1S1P0, GemmKLoopOverC , LoopOver_MKN, 2, m_per_block, n_per_block, k_per_block, m_per_thread, n_per_thread, a_local_buf, b_local_buf, c_local_buf, bias_along_m>, \
DeviceConvNDFwdBiasActivationAddAvx2_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K<float , float , float, float , float, a_elem_op, b_elem_op, c_elem_op, ConvFwdDefault, DefaultGemmKLoop, LoopOver_MKN, 2, m_per_block, n_per_block, k_per_block, m_per_thread, n_per_thread, a_local_buf, b_local_buf, c_local_buf, bias_along_m>, \
DeviceConvNDFwdBiasActivationAddAvx2_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K<float , float , float, float , float, a_elem_op, b_elem_op, c_elem_op, ConvFwd1x1S1P0, DefaultGemmKLoop, LoopOver_MKN, 2, m_per_block, n_per_block, k_per_block, m_per_thread, n_per_thread, a_local_buf, b_local_buf, c_local_buf, bias_along_m>
// clang-format on
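// default instances: accumulate directly into the output (no local C buffer), which requires
// gemm_n to be a multiple of 8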
using device_conv2d_fwd_bias_activation_add_avx2_nhwc_kyxc_nhwk_f32_instances = std::tuple<
// clang-format off
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_KYXC_NHWK_F32(PT, PT, AddReluAdd, 256, 128, 64, 6, 16, true, true, false, false),
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_KYXC_NHWK_F32(PT, PT, AddReluAdd, 256, 128, 128, 6, 16, true, true, false, false),
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_KYXC_NHWK_F32(PT, PT, AddReluAdd, 128, 256, 128, 6, 16, true, true, false, false),
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_KYXC_NHWK_F32(PT, PT, AddReluAdd, 512, 240, 128, 4, 24, true, true, false, false),
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_KYXC_NHWK_F32(PT, PT, AddReluAdd, 512, 256, 128, 6, 16, true, true, false, false),
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_KYXC_NHWK_F32(PT, PT, AddReluAdd, 768, 320, 128, 6, 16, true, true, false, false),
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_KYXC_NHWK_F32(PT, PT, AddReluAdd, 896, 352, 128, 6, 16, true, true, false, false),
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_KYXC_NHWK_F32(PT, PT, AddReluAdd, 1024, 416, 128, 6, 16, true, true, false, false)>;
// clang-format on
// use these in a single-thread setting when gemm_n is not a multiple of 8
using device_conv2d_fwd_bias_activation_add_avx2_nhwc_kyxc_nhwk_f32_local_c_instances = std::tuple<
// clang-format off
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_KYXC_NHWK_F32(PT, PT, AddReluAdd, 256, 128, 64, 6, 16, true, true, true, false),
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_KYXC_NHWK_F32(PT, PT, AddReluAdd, 256, 128, 128, 6, 16, true, true, true, false),
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_KYXC_NHWK_F32(PT, PT, AddReluAdd, 128, 256, 128, 6, 16, true, true, true, false),
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_KYXC_NHWK_F32(PT, PT, AddReluAdd, 512, 240, 128, 4, 24, true, true, true, false),
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_KYXC_NHWK_F32(PT, PT, AddReluAdd, 512, 256, 128, 6, 16, true, true, true, false),
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_KYXC_NHWK_F32(PT, PT, AddReluAdd, 768, 320, 128, 6, 16, true, true, true, false),
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_KYXC_NHWK_F32(PT, PT, AddReluAdd, 896, 352, 128, 6, 16, true, true, true, false),
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_KYXC_NHWK_F32(PT, PT, AddReluAdd, 1024, 416, 128, 6, 16, true, true, true, false)>;
// clang-format on
// use these in a multi-threaded environment (a local C buffer is needed to avoid cache-coherence
// traffic between threads, although sometimes not using a local C buffer is faster...)
using device_conv2d_fwd_bias_activation_add_avx2_nhwc_kyxc_nhwk_f32_mt_instances = std::tuple<
// clang-format off
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_KYXC_NHWK_F32(PT, PT, AddReluAdd, 48, 24, 128, 4, 24, true, true, true, false),
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_KYXC_NHWK_F32(PT, PT, AddReluAdd, 72, 16, 128, 6, 16, true, true, true, false),
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_KYXC_NHWK_F32(PT, PT, AddReluAdd, 72, 32, 128, 6, 16, true, true, true, false),
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_KYXC_NHWK_F32(PT, PT, AddReluAdd, 96, 32, 128, 6, 16, true, true, true, false),
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_KYXC_NHWK_F32(PT, PT, AddReluAdd, 96, 64, 128, 6, 16, true, true, true, false),
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_KYXC_NHWK_F32(PT, PT, AddReluAdd, 120, 32, 128, 6, 16, true, true, true, false),
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_KYXC_NHWK_F32(PT, PT, AddReluAdd, 120, 64, 128, 6, 16, true, true, true, false),
// DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_KYXC_NHWK_F32(PT, PT, PT, 256, 128, 64, 6, 16, true, true, true),
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_KYXC_NHWK_F32(PT, PT, AddReluAdd, 256, 128, 128, 6, 16, true, true, true, false),
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_KYXC_NHWK_F32(PT, PT, AddReluAdd, 128, 256, 128, 6, 16, true, true, true, false),
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_KYXC_NHWK_F32(PT, PT, AddReluAdd, 512, 240, 128, 4, 24, true, true, true, false),
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_KYXC_NHWK_F32(PT, PT, AddReluAdd, 512, 256, 128, 6, 16, true, true, true, false),
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_KYXC_NHWK_F32(PT, PT, AddReluAdd, 768, 320, 128, 6, 16, true, true, true, false),
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_KYXC_NHWK_F32(PT, PT, AddReluAdd, 896, 352, 128, 6, 16, true, true, true, false),
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_KYXC_NHWK_F32(PT, PT, AddReluAdd, 1024, 416, 128, 6, 16, true, true, true, false)>;
// clang-format on
void add_device_conv2d_fwd_bias_activation_add_avx2_nhwc_kyxc_nhwk(
std::vector<DeviceConvFwdBiasActivationAddPtr<PT, PT, AddReluAdd>>& instances)
{
ck::tensor_operation::device::add_device_operation_instances(
instances, device_conv2d_fwd_bias_activation_add_avx2_nhwc_kyxc_nhwk_f32_instances{});
}
void add_device_conv2d_fwd_bias_activation_add_avx2_nhwc_kyxc_nhwk_local_c(
std::vector<DeviceConvFwdBiasActivationAddPtr<PT, PT, AddReluAdd>>& instances)
{
ck::tensor_operation::device::add_device_operation_instances(
instances,
device_conv2d_fwd_bias_activation_add_avx2_nhwc_kyxc_nhwk_f32_local_c_instances{});
}
void add_device_conv2d_fwd_bias_activation_add_avx2_nhwc_kyxc_nhwk_mt(
std::vector<DeviceConvFwdBiasActivationAddPtr<PT, PT, AddReluAdd>>& instances)
{
ck::tensor_operation::device::add_device_operation_instances(
instances, device_conv2d_fwd_bias_activation_add_avx2_nhwc_kyxc_nhwk_f32_mt_instances{});
}
} // namespace device_conv2d_fwd_bias_activation_add_avx2_instance
} // namespace device
} // namespace cpu
} // namespace tensor_operation
} // namespace ck