"git@developer.sourcefind.cn:change/sglang.git" did not exist on "75f4ccb7ddea2fd1abaa6475855da141b6c63980"
Commit 5db79de0 authored by carlushuang's avatar carlushuang
Browse files

add a direct bias-relu-add implementation

parent 5024f317
...@@ -20,7 +20,7 @@ ...@@ -20,7 +20,7 @@
#define TEST_FUSION_BIAS_RELU 1 #define TEST_FUSION_BIAS_RELU 1
#define TEST_FUSION_BIAS 2 #define TEST_FUSION_BIAS 2
#define TEST_FUSION_BIAS_ADD_RELU 3 #define TEST_FUSION_BIAS_ADD_RELU 3
#define TEST_FUSION TEST_FUSION_BIAS_ADD_RELU #define TEST_FUSION TEST_FUSION_BIAS_RELU_ADD
#define TEST_LAYOUT_NHWC_KYXC_NHWK 0 #define TEST_LAYOUT_NHWC_KYXC_NHWK 0
#define TEST_LAYOUT_NHWC_KYXCK8_NHWK 1 #define TEST_LAYOUT_NHWC_KYXCK8_NHWK 1
...@@ -171,6 +171,11 @@ void add_device_conv2d_fwd_bias_add_relu_avx2_nhwc_yxck_nhwk_mt( ...@@ -171,6 +171,11 @@ void add_device_conv2d_fwd_bias_add_relu_avx2_nhwc_yxck_nhwk_mt(
std::vector<DeviceConvFwdBiasActivationAddPtr<PassThrough, PassThrough, AddAddRelu>>& std::vector<DeviceConvFwdBiasActivationAddPtr<PassThrough, PassThrough, AddAddRelu>>&
instances); instances);
// ------------------ direct-conv nhwc-kcyxk8-nhwk
void add_device_conv2d_direct_fwd_bias_activation_add_avx2_nhwc_kyxck8_nhwk(
std::vector<DeviceConvFwdBiasActivationAddPtr<PassThrough, PassThrough, AddReluAdd>>&
instances);
} // namespace device_conv2d_fwd_bias_activation_add_avx2_instance } // namespace device_conv2d_fwd_bias_activation_add_avx2_instance
} // namespace device } // namespace device
} // namespace cpu } // namespace cpu
...@@ -623,6 +628,8 @@ int main(int argc, char* argv[]) ...@@ -623,6 +628,8 @@ int main(int argc, char* argv[])
add_device_conv2d_fwd_bias_relu_add_avx2_nhwc_kyxck8_nhwk_local_c( add_device_conv2d_fwd_bias_relu_add_avx2_nhwc_kyxck8_nhwk_local_c(
conv_ptrs); conv_ptrs);
} }
ck::tensor_operation::cpu::device::device_conv2d_fwd_bias_activation_add_avx2_instance::
add_device_conv2d_direct_fwd_bias_activation_add_avx2_nhwc_kyxck8_nhwk(conv_ptrs);
#endif #endif
#if TEST_LAYOUT == TEST_LAYOUT_NHWC_YXCK_NHWK #if TEST_LAYOUT == TEST_LAYOUT_NHWC_YXCK_NHWK
if(omp_get_max_threads() > 1) if(omp_get_max_threads() > 1)
......
#ifndef DEVICE_CONV2D_DIRECT_FWD_BIAS_ACTIVATION_ADD_AVX2_NHWC_KYXCK8_NHWK_HPP
#define DEVICE_CONV2D_DIRECT_FWD_BIAS_ACTIVATION_ADD_AVX2_NHWC_KYXCK8_NHWK_HPP
#include <iostream>
#include <sstream>
#include <numeric>
#include <memory>
#include <vector>
#include "ck/device_utility/kernel_launch.hpp"
#include "ck/library/host_tensor/device_memory.hpp"
#include "ck/tensor_operation/cpu/device/device_base_cpu.hpp"
#include "ck/tensor_operation/cpu/device/device_conv_fwd_cpu.hpp"
#include "ck/tensor_operation/cpu/device/convolution_forward_specialization_cpu.hpp"
#include "ck/utility/common_header.hpp"
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
#include "ck/tensor_description/tensor_descriptor.hpp"
#include "ck/tensor_description/tensor_descriptor_helper.hpp"
#include "ck/tensor_operation/cpu/grid/gridwise_direct_conv_bias_activation_add_avx2.hpp"
#include "ck/tensor_operation/cpu/thread/threadwise_gemm_avx2.hpp"
#include "ck/tensor_operation/cpu/thread/threadwise_tensor_slice_transfer_avx2_specialization.hpp"
namespace ck {
namespace tensor_operation {
namespace cpu {
namespace device {
template <typename InDataType,
typename WeiDataType,
typename OutDataType,
typename BiasDataType,
typename AddDataType,
typename InElementwiseOperation,
typename WeiElementwiseOperation,
typename OutElementwiseOperation,
ConvolutionForwardSpecialization_t ConvForwardSpecialization,
ck::index_t NumDimSpatial,
ck::index_t MPerThread,
ck::index_t NPerThread,
bool UseALocalBuffer,
bool UseBLocalBuffer,
bool UseCLocalBuffer,
bool FuseBias,
bool FuseAdd,
bool BiasAlongGemmM>
struct DeviceConvNDDirectFwdBiasActivationAddAvx2_Input_N_Hi_Wi_C_Weight_K_Y_X_C_K8_Output_N_Ho_Wo_K
: public DeviceConvFwdBiasActivationAdd<InElementwiseOperation,
WeiElementwiseOperation,
OutElementwiseOperation>
{
using DeviceOp =
DeviceConvNDDirectFwdBiasActivationAddAvx2_Input_N_Hi_Wi_C_Weight_K_Y_X_C_K8_Output_N_Ho_Wo_K;
using ADataType = InDataType;
using BDataType = WeiDataType;
using CDataType = OutDataType;
using C0DataType = BiasDataType;
using C1DataType = AddDataType;
using AElementwiseOperation = InElementwiseOperation;
using BElementwiseOperation = WeiElementwiseOperation;
using CElementwiseOperation = OutElementwiseOperation;
// TODO make A/B datatype different
using ABDataType = InDataType;
static constexpr index_t NDimSpatial = NumDimSpatial;
static constexpr auto I0 = Number<0>{};
static constexpr auto I1 = Number<1>{};
static constexpr auto I2 = Number<2>{};
static constexpr auto I3 = Number<3>{};
static constexpr bool NonTemporalStore = false;
DeviceConvNDDirectFwdBiasActivationAddAvx2_Input_N_Hi_Wi_C_Weight_K_Y_X_C_K8_Output_N_Ho_Wo_K(
const DeviceConvFwdDynamicTunable& dtune)
: gridwise_gemm(dtune)
{
}
static constexpr auto GetThreadwiseGemm_Dispatch()
{
if constexpr(MPerThread == 4 && NPerThread == 24)
{
return ck::cpu::ThreadwiseGemmAvx2_MxN_4x24_Dispatch<
InDataType,
WeiDataType,
OutDataType,
ck::tensor_layout::gemm::RowMajor,
ck::tensor_layout::gemm::ColumnMajor,
NonTemporalStore>{};
}
else if constexpr(MPerThread == 6 && NPerThread == 16)
{
return ck::cpu::ThreadwiseGemmAvx2_MxN_6x16_Dispatch<
InDataType,
WeiDataType,
OutDataType,
ck::tensor_layout::gemm::RowMajor,
ck::tensor_layout::gemm::ColumnMajor,
NonTemporalStore>{};
}
else
{
// static_assert(false, "invalid Mr/Nr");
}
}
using ThreadwiseGemm_Dispatch = decltype(GetThreadwiseGemm_Dispatch());
static auto GetWeightTensorDescriptor(ck::index_t gemm_k, ck::index_t gemm_n)
{
return make_naive_tensor_descriptor_packed(make_tuple(gemm_n / 8, gemm_k, 8));
}
static auto GetOutputTensorDescriptor(ck::index_t gemm_m, ck::index_t gemm_n)
{
const auto out_gemm_m_n_grid_desc =
make_naive_tensor_descriptor_packed(make_tuple(gemm_m, gemm_n));
return out_gemm_m_n_grid_desc;
}
static auto MakeBiasTensorDescriptor(ck::index_t gemm_m, ck::index_t gemm_n)
{
if constexpr(BiasAlongGemmM)
{
return make_naive_tensor_descriptor_packed(make_tuple(gemm_m));
}
else
{
return make_naive_tensor_descriptor_packed(make_tuple(gemm_n));
}
}
template <ck::index_t NDim, typename std::enable_if<NDim == 1, bool>::type = false>
static auto GetInputTensorDescriptor(ck::index_t N,
ck::index_t C,
ck::index_t gemm_m,
ck::index_t gemm_k,
const std::vector<ck::index_t>& input_spatial_lengths,
const std::vector<ck::index_t>& filter_spatial_lengths,
const std::vector<ck::index_t>& output_spatial_lengths,
const std::vector<ck::index_t>& conv_filter_strides,
const std::vector<ck::index_t>& conv_filter_dilations,
const std::vector<ck::index_t>& input_left_pads,
const std::vector<ck::index_t>& input_right_pads)
{
const index_t Wi = input_spatial_lengths[0];
const index_t Wo = output_spatial_lengths[0];
const index_t ConvStrideW = conv_filter_strides[0];
if constexpr(ConvForwardSpecialization ==
ConvolutionForwardSpecialization_t::Filter1x1Stride1Pad0)
{
const auto in_gemm_m_k_grid_desc =
make_naive_tensor_descriptor_packed(make_tuple(gemm_m, gemm_k));
return in_gemm_m_k_grid_desc;
}
else if constexpr(ConvForwardSpecialization ==
ConvolutionForwardSpecialization_t::Filter1x1Pad0)
{
const auto in_n_wi_c_grid_desc =
make_naive_tensor_descriptor_packed(make_tuple(N, Wi, C));
const auto in_n_wo_c_grid_desc = transform_tensor_descriptor(
in_n_wi_c_grid_desc,
make_tuple(make_pass_through_transform(N),
make_embed_transform(make_tuple(Wo), make_tuple(ConvStrideW)),
make_pass_through_transform(C)),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}));
const auto in_gemm_m_k_grid_desc = transform_tensor_descriptor(
in_n_wo_c_grid_desc,
make_tuple(make_merge_transform(make_tuple(N, Wo)), make_pass_through_transform(C)),
make_tuple(Sequence<0, 1>{}, Sequence<2>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}));
return in_gemm_m_k_grid_desc;
}
else
{
const index_t X = filter_spatial_lengths[0];
const index_t ConvDilationW = conv_filter_dilations[0];
const index_t InLeftPadW = input_left_pads[0];
const index_t InRightPadW = input_right_pads[0];
const auto in_n_wi_c_grid_desc =
make_naive_tensor_descriptor_packed(make_tuple(N, Wi, C));
const auto in_n_wip_c_grid_desc = transform_tensor_descriptor(
in_n_wi_c_grid_desc,
make_tuple(make_pass_through_transform(N),
make_pad_transform(Wi, InLeftPadW, InRightPadW),
make_pass_through_transform(C)),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}));
const auto in_n_x_wo_c_grid_desc = transform_tensor_descriptor(
in_n_wip_c_grid_desc,
make_tuple(
make_pass_through_transform(N),
make_embed_transform(make_tuple(X, Wo), make_tuple(ConvDilationW, ConvStrideW)),
make_pass_through_transform(C)),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}),
make_tuple(Sequence<0>{}, Sequence<1, 2>{}, Sequence<3>{}));
const auto in_gemm_m_k_grid_desc =
transform_tensor_descriptor(in_n_x_wo_c_grid_desc,
make_tuple(make_merge_transform(make_tuple(N, Wo)),
make_merge_transform(make_tuple(X, C))),
make_tuple(Sequence<0, 2>{}, Sequence<1, 3>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}));
return in_gemm_m_k_grid_desc;
}
}
template <ck::index_t NDim, typename std::enable_if<NDim == 2, bool>::type = false>
static auto GetInputTensorDescriptor(ck::index_t N,
ck::index_t C,
ck::index_t gemm_m,
ck::index_t gemm_k,
const std::vector<ck::index_t>& input_spatial_lengths,
const std::vector<ck::index_t>& filter_spatial_lengths,
const std::vector<ck::index_t>& output_spatial_lengths,
const std::vector<ck::index_t>& conv_filter_strides,
const std::vector<ck::index_t>& conv_filter_dilations,
const std::vector<ck::index_t>& input_left_pads,
const std::vector<ck::index_t>& input_right_pads)
{
const index_t Hi = input_spatial_lengths[0];
const index_t Wi = input_spatial_lengths[1];
const index_t Ho = output_spatial_lengths[0];
const index_t Wo = output_spatial_lengths[1];
const index_t ConvStrideH = conv_filter_strides[0];
const index_t ConvStrideW = conv_filter_strides[1];
if constexpr(ConvForwardSpecialization ==
ConvolutionForwardSpecialization_t::Filter1x1Stride1Pad0)
{
const auto in_gemm_m_k_grid_desc =
make_naive_tensor_descriptor_packed(make_tuple(gemm_m, gemm_k));
return in_gemm_m_k_grid_desc;
}
else if constexpr(ConvForwardSpecialization ==
ConvolutionForwardSpecialization_t::Filter1x1Pad0)
{
const auto in_n_hi_wi_c_grid_desc =
make_naive_tensor_descriptor_packed(make_tuple(N, Hi, Wi, C));
const auto in_n_ho_wo_c_grid_desc = transform_tensor_descriptor(
in_n_hi_wi_c_grid_desc,
make_tuple(make_pass_through_transform(N),
make_embed_transform(make_tuple(Ho), make_tuple(ConvStrideH)),
make_embed_transform(make_tuple(Wo), make_tuple(ConvStrideW)),
make_pass_through_transform(C)),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}));
const auto in_gemm_m_k_grid_desc =
transform_tensor_descriptor(in_n_ho_wo_c_grid_desc,
make_tuple(make_merge_transform(make_tuple(N, Ho, Wo)),
make_pass_through_transform(C)),
make_tuple(Sequence<0, 1, 2>{}, Sequence<3>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}));
return in_gemm_m_k_grid_desc;
}
else
{
const index_t Y = filter_spatial_lengths[0];
const index_t X = filter_spatial_lengths[1];
const index_t ConvDilationH = conv_filter_dilations[0];
const index_t ConvDilationW = conv_filter_dilations[1];
const index_t InLeftPadH = input_left_pads[0];
const index_t InLeftPadW = input_left_pads[1];
const index_t InRightPadH = input_right_pads[0];
const index_t InRightPadW = input_right_pads[1];
const auto in_n_hi_wi_c_grid_desc =
make_naive_tensor_descriptor_packed(make_tuple(N, Hi, Wi, C));
const auto in_n_hip_wip_c_grid_desc = transform_tensor_descriptor(
in_n_hi_wi_c_grid_desc,
make_tuple(make_pass_through_transform(N),
make_pad_transform(Hi, InLeftPadH, InRightPadH),
make_pad_transform(Wi, InLeftPadW, InRightPadW),
make_pass_through_transform(C)),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}));
const auto in_n_y_ho_x_wo_c_grid_desc = transform_tensor_descriptor(
in_n_hip_wip_c_grid_desc,
make_tuple(
make_pass_through_transform(N),
make_embed_transform(make_tuple(Y, Ho), make_tuple(ConvDilationH, ConvStrideH)),
make_embed_transform(make_tuple(X, Wo), make_tuple(ConvDilationW, ConvStrideW)),
make_pass_through_transform(C)),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}),
make_tuple(Sequence<0>{}, Sequence<1, 2>{}, Sequence<3, 4>{}, Sequence<5>{}));
const auto in_gemm_m_k_grid_desc =
transform_tensor_descriptor(in_n_y_ho_x_wo_c_grid_desc,
make_tuple(make_merge_transform(make_tuple(N, Ho, Wo)),
make_merge_transform(make_tuple(Y, X, C))),
make_tuple(Sequence<0, 2, 4>{}, Sequence<1, 3, 5>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}));
return in_gemm_m_k_grid_desc;
}
}
template <ck::index_t NDim, typename std::enable_if<NDim == 3, bool>::type = false>
static auto GetInputTensorDescriptor(ck::index_t N,
ck::index_t C,
ck::index_t gemm_m,
ck::index_t gemm_k,
ck::index_t gemm_m_pad,
const std::vector<ck::index_t>& input_spatial_lengths,
const std::vector<ck::index_t>& filter_spatial_lengths,
const std::vector<ck::index_t>& output_spatial_lengths,
const std::vector<ck::index_t>& conv_filter_strides,
const std::vector<ck::index_t>& conv_filter_dilations,
const std::vector<ck::index_t>& input_left_pads,
const std::vector<ck::index_t>& input_right_pads)
{
const index_t Di = input_spatial_lengths[0];
const index_t Hi = input_spatial_lengths[1];
const index_t Wi = input_spatial_lengths[2];
const index_t Do = output_spatial_lengths[0];
const index_t Ho = output_spatial_lengths[1];
const index_t Wo = output_spatial_lengths[2];
const index_t ConvStrideD = conv_filter_strides[0];
const index_t ConvStrideH = conv_filter_strides[1];
const index_t ConvStrideW = conv_filter_strides[2];
if constexpr(ConvForwardSpecialization ==
ConvolutionForwardSpecialization_t::Filter1x1Stride1Pad0)
{
const auto in_gemm_m_k_grid_desc =
make_naive_tensor_descriptor_packed(make_tuple(gemm_m, gemm_k));
return in_gemm_m_k_grid_desc;
}
else if constexpr(ConvForwardSpecialization ==
ConvolutionForwardSpecialization_t::Filter1x1Pad0)
{
const auto in_n_di_hi_wi_c_grid_desc =
make_naive_tensor_descriptor_packed(make_tuple(N, Di, Hi, Wi, C));
const auto in_n_do_ho_wo_c_grid_desc = transform_tensor_descriptor(
in_n_di_hi_wi_c_grid_desc,
make_tuple(make_pass_through_transform(N),
make_embed_transform(make_tuple(Do), make_tuple(ConvStrideD)),
make_embed_transform(make_tuple(Ho), make_tuple(ConvStrideH)),
make_embed_transform(make_tuple(Wo), make_tuple(ConvStrideW)),
make_pass_through_transform(C)),
make_tuple(
Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}, Sequence<4>{}),
make_tuple(
Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}, Sequence<4>{}));
const auto in_gemm_m_k_grid_desc = transform_tensor_descriptor(
in_n_do_ho_wo_c_grid_desc,
make_tuple(make_merge_transform(make_tuple(N, Do, Ho, Wo)),
make_pass_through_transform(C)),
make_tuple(Sequence<0, 1, 2, 3>{}, Sequence<4>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}));
return in_gemm_m_k_grid_desc;
}
else
{
const index_t Z = filter_spatial_lengths[0];
const index_t Y = filter_spatial_lengths[1];
const index_t X = filter_spatial_lengths[2];
const index_t ConvDilationD = conv_filter_dilations[0];
const index_t ConvDilationH = conv_filter_dilations[1];
const index_t ConvDilationW = conv_filter_dilations[2];
const index_t InLeftPadD = input_left_pads[0];
const index_t InLeftPadH = input_left_pads[1];
const index_t InLeftPadW = input_left_pads[2];
const index_t InRightPadD = input_right_pads[0];
const index_t InRightPadH = input_right_pads[1];
const index_t InRightPadW = input_right_pads[2];
const auto in_n_di_hi_wi_c_grid_desc =
make_naive_tensor_descriptor_packed(make_tuple(N, Di, Hi, Wi, C));
const auto in_n_hip_wip_c_grid_desc = transform_tensor_descriptor(
in_n_di_hi_wi_c_grid_desc,
make_tuple(make_pass_through_transform(N),
make_pad_transform(Di, InLeftPadD, InRightPadD),
make_pad_transform(Hi, InLeftPadH, InRightPadH),
make_pad_transform(Wi, InLeftPadW, InRightPadW),
make_pass_through_transform(C)),
make_tuple(
Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}, Sequence<4>{}),
make_tuple(
Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}, Sequence<4>{}));
const auto in_n_z_do_y_ho_x_wo_c_grid_desc = transform_tensor_descriptor(
in_n_hip_wip_c_grid_desc,
make_tuple(
make_pass_through_transform(N),
make_embed_transform(make_tuple(Z, Do), make_tuple(ConvDilationD, ConvStrideD)),
make_embed_transform(make_tuple(Y, Ho), make_tuple(ConvDilationH, ConvStrideH)),
make_embed_transform(make_tuple(X, Wo), make_tuple(ConvDilationW, ConvStrideW)),
make_pass_through_transform(C)),
make_tuple(
Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}, Sequence<4>{}),
make_tuple(Sequence<0>{},
Sequence<1, 2>{},
Sequence<3, 4>{},
Sequence<5, 6>{},
Sequence<7>{}));
const auto in_gemm_m_k_grid_desc = transform_tensor_descriptor(
in_n_z_do_y_ho_x_wo_c_grid_desc,
make_tuple(make_merge_transform(make_tuple(N, Do, Ho, Wo)),
make_merge_transform(make_tuple(Z, Y, X, C))),
make_tuple(Sequence<0, 2, 4, 6>{}, Sequence<1, 3, 5, 7>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}));
return in_gemm_m_k_grid_desc;
}
}
static index_t GetGemmM(ck::index_t N, const std::vector<ck::index_t>& output_spatial_lengths)
{
return N * std::accumulate(std::begin(output_spatial_lengths),
std::end(output_spatial_lengths),
1,
std::multiplies<ck::index_t>());
}
static index_t GetGemmK(ck::index_t C, const std::vector<ck::index_t>& filter_spatial_lengths)
{
return C * std::accumulate(std::begin(filter_spatial_lengths),
std::end(filter_spatial_lengths),
1,
std::multiplies<ck::index_t>());
}
static index_t GetGemmN(ck::index_t K)
{
// return ck::math::integer_least_multiple(K,
// ThreadwiseGemm_Dispatch::MatrixBMinVectorSize);
return K;
}
static auto MakeABCGridDescriptor(ck::index_t N,
ck::index_t K,
ck::index_t C,
std::vector<ck::index_t> input_spatial_lengths,
std::vector<ck::index_t> filter_spatial_lengths,
std::vector<ck::index_t> output_spatial_lengths,
std::vector<ck::index_t> conv_filter_strides,
std::vector<ck::index_t> conv_filter_dilations,
std::vector<ck::index_t> input_left_pads,
std::vector<ck::index_t> input_right_pads)
{
using namespace ck;
const index_t GemmM = GetGemmM(N, output_spatial_lengths);
const index_t GemmN = GetGemmN(K);
const index_t GemmK = GetGemmK(C, filter_spatial_lengths);
// A:
const auto in_gemm_m_k_grid_desc =
GetInputTensorDescriptor<NumDimSpatial>(N,
C,
GemmM,
GemmK,
input_spatial_lengths,
filter_spatial_lengths,
output_spatial_lengths,
conv_filter_strides,
conv_filter_dilations,
input_left_pads,
input_right_pads);
// B:
const auto wei_gemm_n0_k_n1_grid_desc = GetWeightTensorDescriptor(GemmK, GemmN);
// C:
const auto out_gemm_m_n_grid_desc = GetOutputTensorDescriptor(GemmM, GemmN);
return make_tuple(
in_gemm_m_k_grid_desc, wei_gemm_n0_k_n1_grid_desc, out_gemm_m_n_grid_desc);
}
template <ck::index_t NDim, typename std::enable_if<NDim == 1, bool>::type = false>
static auto GetABCGridDesc()
{
return MakeABCGridDescriptor(1, 1, 1, {1}, {1}, {1}, {1}, {1}, {1}, {1});
}
template <ck::index_t NDim, typename std::enable_if<NDim == 2, bool>::type = false>
static auto GetABCGridDesc()
{
return MakeABCGridDescriptor(
1, 1, 1, {1, 1}, {1, 1}, {1, 1}, {1, 1}, {1, 1}, {1, 1}, {1, 1});
}
template <ck::index_t NDim, typename std::enable_if<NDim == 3, bool>::type = false>
static auto GetABCGridDesc()
{
return MakeABCGridDescriptor(
1, 1, 1, {1, 1, 1}, {1, 1, 1}, {1, 1, 1}, {1, 1, 1}, {1, 1, 1}, {1, 1, 1}, {1, 1, 1});
}
using ABCGridDescs = decltype(GetABCGridDesc<NumDimSpatial>());
using AGridDesc = remove_cvref_t<decltype(ABCGridDescs{}[I0])>;
using BGridDesc = remove_cvref_t<decltype(ABCGridDescs{}[I1])>;
using CGridDesc = remove_cvref_t<decltype(ABCGridDescs{}[I2])>;
using C0GridDesc = remove_cvref_t<decltype(MakeBiasTensorDescriptor(1, 1))>;
using C1GridDesc = CGridDesc;
static constexpr auto GetInputBlockDescriptor()
{
if constexpr(UseALocalBuffer)
{
// return make_naive_tensor_descriptor_packed(make_tuple(MPerBlock, KPerBlock));
return make_naive_tensor_descriptor_packed(make_tuple(0, 0));
}
else
{
return AGridDesc{};
}
}
static constexpr auto GetWeightBlockDescriptor()
{
if constexpr(UseBLocalBuffer)
{
// return make_naive_tensor_descriptor_packed(make_tuple(
// math::integer_divide_ceil(NPerBlock,
// ThreadwiseGemm_Dispatch::MatrixBMinVectorSize), KPerBlock,
// ThreadwiseGemm_Dispatch::MatrixBMinVectorSize));
return make_naive_tensor_descriptor_packed(make_tuple(0, 0, 0));
}
else
{
return BGridDesc{};
}
}
static constexpr auto GetOutputBlockDescriptor()
{
if constexpr(UseCLocalBuffer)
{
// return make_naive_tensor_descriptor_packed(make_tuple(MPerBlock, NPerBlock));
return make_naive_tensor_descriptor_packed(make_tuple(0, 0));
}
else
{
return CGridDesc{};
}
}
// static constexpr bool UseCLocalBuffer = false;
using AThreadwiseCopy =
ck::cpu::ThreadwiseTensorSliceTransferAvx2Specialization_ConvFwd_In_NHWC<
ADataType,
ADataType,
AGridDesc,
decltype(GetInputBlockDescriptor()),
InElementwiseOperation,
!UseALocalBuffer,
ConvForwardSpecialization>;
using BThreadwiseCopy =
ck::cpu::ThreadwiseTensorSliceTransferAvx2Specialization_ConvFwd_Wei_KYXCK8<
BDataType,
BDataType,
BGridDesc,
decltype(GetWeightBlockDescriptor()),
WeiElementwiseOperation,
!UseBLocalBuffer,
ConvForwardSpecialization>;
static constexpr auto GetCThreadwiseCopy()
{
constexpr ck::index_t C_nDim = CGridDesc::GetNumOfDimension();
if constexpr(FuseBias && FuseAdd)
{
return ck::cpu::
ThreadwiseTensorSliceTransferAvx2Specialization_MatC_Store_Bias_Residual_MxN<
CDataType,
C0DataType,
C1DataType,
CDataType,
CGridDesc,
C0GridDesc,
C1GridDesc,
decltype(GetOutputBlockDescriptor()),
OutElementwiseOperation,
!UseCLocalBuffer,
BiasAlongGemmM>(CGridDesc{},
ck::make_zero_multi_index<C_nDim>(),
GetOutputBlockDescriptor(),
ck::make_zero_multi_index<C_nDim>(),
OutElementwiseOperation{});
}
else if constexpr(FuseBias && !FuseAdd)
{
return ck::cpu::ThreadwiseTensorSliceTransferAvx2Specialization_MatC_Store_Bias_MxN<
CDataType,
C0DataType,
C1DataType,
CDataType,
CGridDesc,
C0GridDesc,
C1GridDesc,
decltype(GetOutputBlockDescriptor()),
OutElementwiseOperation,
!UseCLocalBuffer,
BiasAlongGemmM>(CGridDesc{},
ck::make_zero_multi_index<C_nDim>(),
GetOutputBlockDescriptor(),
ck::make_zero_multi_index<C_nDim>(),
OutElementwiseOperation{});
}
}
using CThreadwiseCopy = decltype(GetCThreadwiseCopy());
using GridwiseGemm = ck::cpu::GridwiseDirectConvNHWCBiasActivationAddAvx2<
ADataType, // InDataType,
BDataType, // WeiDataType,
CDataType, // OutDataType,
C0DataType, // C0DataType
C1DataType, // C1DataType
AGridDesc, // AGridDesc,
BGridDesc, // BGridDesc,
CGridDesc, // CGridDesc,
C0GridDesc, // C0GridDesc,
C1GridDesc, // C1GridDesc,
AElementwiseOperation, // AElementwiseOperation,
BElementwiseOperation, // BElementwiseOperation,
CElementwiseOperation, // CElementwiseOperation,
ThreadwiseGemm_Dispatch, // ThreadwiseGemm_Dispatch,
AThreadwiseCopy, // AThreadwiseCopy
BThreadwiseCopy, // BThreadwiseCopy
CThreadwiseCopy, // CThreadwiseCopy
ck::Sequence<0, 1>, // ThreadMNAccessOrder
UseALocalBuffer, // UseALocalBuffer
UseBLocalBuffer, // UseBLocalBuffer
UseCLocalBuffer // UseCLocalBuffer
>;
GridwiseGemm gridwise_gemm;
// Argument
struct Argument : public BaseArgument
{
Argument(const InDataType* p_in_grid,
const WeiDataType* p_wei_grid,
OutDataType* p_out_grid,
const BiasDataType* p_bias_grid,
const AddDataType* p_add_grid,
ck::index_t N,
ck::index_t K,
ck::index_t C,
std::vector<ck::index_t> input_spatial_lengths,
std::vector<ck::index_t> filter_spatial_lengths,
std::vector<ck::index_t> output_spatial_lengths,
std::vector<ck::index_t> conv_filter_strides,
std::vector<ck::index_t> conv_filter_dilations,
std::vector<ck::index_t> input_left_pads,
std::vector<ck::index_t> input_right_pads,
InElementwiseOperation in_element_op,
WeiElementwiseOperation wei_element_op,
OutElementwiseOperation out_element_op)
: p_a_grid_{p_in_grid},
p_b_grid_{p_wei_grid},
p_c_grid_{p_out_grid},
p_c0_grid_{p_bias_grid},
p_c1_grid_{p_add_grid},
a_grid_desc_{},
b_grid_desc_{},
c_grid_desc_{},
c0_grid_desc_{},
c1_grid_desc_{},
a_element_op_{in_element_op},
b_element_op_{wei_element_op},
c_element_op_{out_element_op},
Conv_N_{N},
Conv_K_{K},
Conv_C_{C},
input_spatial_lengths_{input_spatial_lengths},
filter_spatial_lengths_{filter_spatial_lengths},
conv_filter_strides_{conv_filter_strides},
conv_filter_dilations_{conv_filter_dilations},
output_spatial_lengths_{output_spatial_lengths},
input_left_pads_{input_left_pads},
input_right_pads_{input_right_pads}
{
const auto descs = DeviceOp::MakeABCGridDescriptor(N,
K,
C,
input_spatial_lengths,
filter_spatial_lengths,
output_spatial_lengths,
conv_filter_strides,
conv_filter_dilations,
input_left_pads,
input_right_pads);
a_grid_desc_ = descs[I0];
b_grid_desc_ = descs[I1];
c_grid_desc_ = descs[I2];
c0_grid_desc_ = DeviceOp::MakeBiasTensorDescriptor(GetGemmM(N, output_spatial_lengths),
GetGemmN(K));
c1_grid_desc_ = descs[I2];
}
// private:
const ADataType* p_a_grid_;
const BDataType* p_b_grid_;
CDataType* p_c_grid_;
const C0DataType* p_c0_grid_;
const C1DataType* p_c1_grid_;
AGridDesc a_grid_desc_;
BGridDesc b_grid_desc_;
CGridDesc c_grid_desc_;
C0GridDesc c0_grid_desc_;
C1GridDesc c1_grid_desc_;
AElementwiseOperation a_element_op_;
BElementwiseOperation b_element_op_;
CElementwiseOperation c_element_op_;
// for checking IsSupportedArgument()
index_t Conv_N_;
index_t Conv_K_;
index_t Conv_C_;
std::vector<index_t> input_spatial_lengths_;
std::vector<index_t> filter_spatial_lengths_;
std::vector<index_t> output_spatial_lengths_;
std::vector<index_t> conv_filter_strides_;
std::vector<index_t> conv_filter_dilations_;
std::vector<index_t> input_left_pads_;
std::vector<index_t> input_right_pads_;
};
// Invoker
struct Invoker : public BaseInvoker
{
using Argument = DeviceOp::Argument;
GridwiseGemm gridwise_gemm;
Invoker(const GridwiseGemm& gridwise_gemm_) : gridwise_gemm(gridwise_gemm_) {}
float Run(const Argument& arg,
const StreamConfig& stream_config = StreamConfig{},
int nrepeat = 1)
{
if(!gridwise_gemm.CheckValidity(arg.a_grid_desc_, arg.b_grid_desc_, arg.c_grid_desc_))
{
throw std::runtime_error("wrong! GridwiseGemmAvx2_MxN has invalid setting");
}
// memset(arg.p_c_grid_, 0, arg.c_grid_desc_.GetElementSpaceSize());
const auto kernel =
ck::cpu::kernel_direct_conv_nhwc_bias_activation_add_avx_mxn<GridwiseGemm,
ADataType,
BDataType,
CDataType,
C0DataType,
C1DataType,
AGridDesc,
BGridDesc,
CGridDesc,
C0GridDesc,
C1GridDesc,
AElementwiseOperation,
BElementwiseOperation,
CElementwiseOperation>;
float ave_time = 0;
if(nrepeat != 1)
ave_time = launch_and_time_cpu_kernel(kernel,
nrepeat,
gridwise_gemm,
arg.p_a_grid_,
arg.p_b_grid_,
arg.p_c_grid_,
arg.p_c0_grid_,
arg.p_c1_grid_,
arg.a_grid_desc_,
arg.b_grid_desc_,
arg.c_grid_desc_,
arg.c0_grid_desc_,
arg.c1_grid_desc_,
arg.Conv_N_,
arg.Conv_K_,
arg.Conv_C_,
arg.input_spatial_lengths_,
arg.filter_spatial_lengths_,
arg.output_spatial_lengths_,
arg.conv_filter_strides_,
arg.conv_filter_dilations_,
arg.input_left_pads_,
arg.input_right_pads_,
arg.a_element_op_,
arg.b_element_op_,
arg.c_element_op_);
// TODO: this is for benchmark purpose, so last time we clear c buffer and calculate the
// result
// memset(arg.p_c_grid_, 0xfe, arg.c_grid_desc_.GetElementSpaceSize());
launch_cpu_kernel(kernel,
gridwise_gemm,
arg.p_a_grid_,
arg.p_b_grid_,
arg.p_c_grid_,
arg.p_c0_grid_,
arg.p_c1_grid_,
arg.a_grid_desc_,
arg.b_grid_desc_,
arg.c_grid_desc_,
arg.c0_grid_desc_,
arg.c1_grid_desc_,
arg.Conv_N_,
arg.Conv_K_,
arg.Conv_C_,
arg.input_spatial_lengths_,
arg.filter_spatial_lengths_,
arg.output_spatial_lengths_,
arg.conv_filter_strides_,
arg.conv_filter_dilations_,
arg.input_left_pads_,
arg.input_right_pads_,
arg.a_element_op_,
arg.b_element_op_,
arg.c_element_op_);
return ave_time;
}
float Run(const BaseArgument* p_arg,
const StreamConfig& stream_config = StreamConfig{},
int nrepeat = 1) override
{
return Run(*dynamic_cast<const Argument*>(p_arg), stream_config, nrepeat);
}
};
static constexpr bool IsValidCompilationParameter()
{
// TODO: properly implement this check
return true;
}
bool IsSupportedArgument(const Argument& arg)
{
if constexpr(ConvForwardSpecialization ==
ConvolutionForwardSpecialization_t::Filter1x1Stride1Pad0)
{
// check if it's 1x1, stride=1 conv
if(!(arg.filter_spatial_lengths_[0] == 1 && arg.filter_spatial_lengths_[1] == 1 &&
arg.conv_filter_strides_[0] == 1 && arg.conv_filter_strides_[1] == 1 &&
arg.input_left_pads_[0] == 0 && arg.input_left_pads_[1] == 0 &&
arg.input_right_pads_[0] == 0 && arg.input_right_pads_[1] == 0))
{
return false;
}
}
else if constexpr(ConvForwardSpecialization ==
ConvolutionForwardSpecialization_t::Filter1x1Pad0)
{
// check if it's 1x1 conv
if(!(arg.filter_spatial_lengths_[0] == 1 && arg.filter_spatial_lengths_[1] == 1 &&
arg.input_left_pads_[0] == 0 && arg.input_left_pads_[1] == 0 &&
arg.input_right_pads_[0] == 0 && arg.input_right_pads_[1] == 0))
{
return false;
}
}
// if(gridwise_gemm.dynamic_tunable.gemm_k_spec ==
// ConvolutionForwardGemmKSpecialization_t::NHWC_GemmKLoopOverC &&
// ConvForwardSpecialization != ConvolutionForwardSpecialization_t::Filter1x1Stride1Pad0)
// {
// if(!(arg.Conv_C_ % gridwise_gemm.dynamic_tunable.k_per_block == 0))
// return false;
// }
if(!(arg.Conv_K_ % 8 == 0))
return false;
// if constexpr(!UseALocalBuffer &&
// ConvForwardSpecialization !=
// ConvolutionForwardSpecialization_t::Filter1x1Stride1Pad0)
// {
// // TODO: We can support this in the future, as long as figure out how to express
// tensor
// // transform
// return false;
// }
// Gridwise GEMM size
return gridwise_gemm.CheckValidity(arg.a_grid_desc_, arg.b_grid_desc_, arg.c_grid_desc_);
}
bool IsSupportedArgument(const BaseArgument* p_arg) override
{
return IsSupportedArgument(*dynamic_cast<const Argument*>(p_arg));
}
static auto MakeArgument(const InDataType* p_in_grid,
const WeiDataType* p_wei_grid,
OutDataType* p_out_grid,
const BiasDataType* p_bias_grid,
const AddDataType* p_add_grid,
ck::index_t N,
ck::index_t K,
ck::index_t C,
std::vector<ck::index_t> input_spatial_lengths,
std::vector<ck::index_t> filter_spatial_lengths,
std::vector<ck::index_t> output_spatial_lengths,
std::vector<ck::index_t> conv_filter_strides,
std::vector<ck::index_t> conv_filter_dilations,
std::vector<ck::index_t> input_left_pads,
std::vector<ck::index_t> input_right_pads,
InElementwiseOperation in_element_op,
WeiElementwiseOperation wei_element_op,
OutElementwiseOperation out_element_op)
{
return Argument{p_in_grid,
p_wei_grid,
p_out_grid,
p_bias_grid,
p_add_grid,
N,
K,
C,
input_spatial_lengths,
filter_spatial_lengths,
output_spatial_lengths,
conv_filter_strides,
conv_filter_dilations,
input_left_pads,
input_right_pads,
in_element_op,
wei_element_op,
out_element_op};
}
auto MakeInvoker() { return Invoker{gridwise_gemm}; }
std::unique_ptr<BaseArgument>
MakeArgumentPointer(const void* p_in_grid,
const void* p_wei_grid,
void* p_out_grid,
const void* p_bias_grid,
const void* p_add_grid,
ck::index_t N,
ck::index_t K,
ck::index_t C,
std::vector<ck::index_t> input_spatial_lengths,
std::vector<ck::index_t> filter_spatial_lengths,
std::vector<ck::index_t> output_spatial_lengths,
std::vector<ck::index_t> conv_filter_strides,
std::vector<ck::index_t> conv_filter_dilations,
std::vector<ck::index_t> input_left_pads,
std::vector<ck::index_t> input_right_pads,
InElementwiseOperation in_element_op,
WeiElementwiseOperation wei_element_op,
OutElementwiseOperation out_element_op) override
{
return std::make_unique<Argument>(static_cast<const InDataType*>(p_in_grid),
static_cast<const WeiDataType*>(p_wei_grid),
static_cast<OutDataType*>(p_out_grid),
static_cast<const BiasDataType*>(p_bias_grid),
static_cast<const AddDataType*>(p_add_grid),
N,
K,
C,
input_spatial_lengths,
filter_spatial_lengths,
output_spatial_lengths,
conv_filter_strides,
conv_filter_dilations,
input_left_pads,
input_right_pads,
in_element_op,
wei_element_op,
out_element_op);
}
std::unique_ptr<BaseInvoker> MakeInvokerPointer() override
{
return std::make_unique<Invoker>(Invoker{gridwise_gemm});
}
std::string GetTypeString() const override
{
auto str = std::stringstream();
auto string_local_buffer = [](bool is_local_buffer) {
if(is_local_buffer)
return "L";
else
return "G";
};
// clang-format off
str << "DeviceConv" << std::to_string(NumDimSpatial)
<< "DDirectFwd_BBAA_vx2_NHWC_KYXCK8"
// <<"_FS"<< static_cast<int>(ConvForwardSpecialization)
// <<"_KS"<< static_cast<int>(gridwise_gemm.dynamic_tunable.gemm_k_spec)
<<"_BS"<< static_cast<int>(gridwise_gemm.dynamic_tunable.loop_over_spec)
// << "_BT" << gridwise_gemm.dynamic_tunable.m_per_block << "x" << gridwise_gemm.dynamic_tunable.n_per_block << "x" << gridwise_gemm.dynamic_tunable.k_per_block
<< "_TT" << MPerThread << "x" << NPerThread
<< "_A" << string_local_buffer(UseALocalBuffer)
<< "_B" << string_local_buffer(UseBLocalBuffer)
<< "_C" << string_local_buffer(UseCLocalBuffer)
;
if constexpr (!std::is_same<OutElementwiseOperation,
ck::tensor_operation::cpu::element_wise::PassThrough>::value)
{
str << "_" << OutElementwiseOperation::Name();
}
// clang-format on
return str.str();
}
};
} // namespace device
} // namespace cpu
} // namespace tensor_operation
} // namespace ck
#endif
#ifndef CK_GRIDWISE_DIRECT_CONV_BIAS_ACTIVATION_ADD_AVX2_HPP
#define CK_GRIDWISE_DIRECT_CONV_BIAS_ACTIVATION_ADD_AVX2_HPP
#include "ck/utility/common_header.hpp"
#include "ck/tensor_description/multi_index_transform_helper.hpp"
#include "ck/tensor_description/tensor_descriptor.hpp"
#include "ck/tensor_description/tensor_descriptor_helper.hpp"
#include "ck/tensor_operation/cpu/block/blockwise_gemm_avx2.hpp"
#include "ck/tensor_operation/cpu/thread/threadwise_tensor_slice_transfer_avx2.hpp"
#include "ck/tensor_operation/cpu/thread/threadwise_tensor_slice_transfer_avx2_specialization.hpp"
#include "ck/utility/dynamic_buffer_cpu.hpp"
#include "ck/utility/envvar.hpp"
#include <utility>
#include <unistd.h>
#include <omp.h>
#include <pthread.h>
namespace ck {
namespace cpu {
template <typename GridwiseDirectConv,
typename FloatA,
typename FloatB,
typename FloatC,
typename FloatC0,
typename FloatC1,
typename AGridDesc,
typename BGridDesc,
typename CGridDesc,
typename C0GridDesc,
typename C1GridDesc,
typename AElementwiseOperation,
typename BElementwiseOperation,
typename CElementwiseOperation>
void kernel_direct_conv_nhwc_bias_activation_add_avx_mxn(
const GridwiseDirectConv& gridwise_direct_conv,
const FloatA* __restrict__ p_a_grid,
const FloatB* __restrict__ p_b_grid,
FloatC* __restrict__ p_c_grid,
const FloatC0* __restrict__ p_c0_grid,
const FloatC1* __restrict__ p_c1_grid,
const AGridDesc& a_grid_desc,
const BGridDesc& b_grid_desc,
const CGridDesc& c_grid_desc,
const C0GridDesc& c0_grid_desc,
const C1GridDesc& c1_grid_desc,
ck::index_t N,
ck::index_t K,
ck::index_t C,
std::vector<ck::index_t> input_spatial_lengths,
std::vector<ck::index_t> filter_spatial_lengths,
std::vector<ck::index_t> output_spatial_lengths,
std::vector<ck::index_t> conv_filter_strides,
std::vector<ck::index_t> conv_filter_dilations,
std::vector<ck::index_t> input_left_pads,
std::vector<ck::index_t> input_right_pads,
const AElementwiseOperation& a_element_op,
const BElementwiseOperation& b_element_op,
const CElementwiseOperation& c_element_op)
{
gridwise_direct_conv.Run(p_a_grid,
p_b_grid,
p_c_grid,
p_c0_grid,
p_c1_grid,
a_grid_desc,
b_grid_desc,
c_grid_desc,
c0_grid_desc,
c1_grid_desc,
N,
K,
C,
input_spatial_lengths,
filter_spatial_lengths,
output_spatial_lengths,
conv_filter_strides,
conv_filter_dilations,
input_left_pads,
input_right_pads,
a_element_op,
b_element_op,
c_element_op);
}
template <typename FloatA,
typename FloatB,
typename FloatC,
typename FloatC0,
typename FloatC1,
typename AGridDesc,
typename BGridDesc,
typename CGridDesc,
typename C0GridDesc,
typename C1GridDesc,
typename AElementwiseOperation,
typename BElementwiseOperation,
typename CElementwiseOperation,
typename ThreadwiseGemm_Dispatch,
typename AThreadwiseCopy,
typename BThreadwiseCopy,
typename CThreadwiseCopy,
typename ThreadMNAccessOrder, // how we acces gemm MN to utilize micro kernel
bool UseALocalBuffer,
bool UseBLocalBuffer,
bool UseCLocalBuffer // if true, will allocate a buffer and write to it in kernel, then
// copy back to block buffer (need CThreadwiseCopy).
// if false, will write to C directly (no need CThreadwiseCopy)
>
struct GridwiseDirectConvNHWCBiasActivationAddAvx2
{
ck::tensor_operation::cpu::device::DeviceConvFwdDynamicTunable dynamic_tunable;
static constexpr auto I0 = Number<0>{};
static constexpr auto I1 = Number<1>{};
// static constexpr auto Avx2RegisterVector = 8; // 8 floats
static constexpr index_t MemAlignmentByte = 32; // 256bit
GridwiseDirectConvNHWCBiasActivationAddAvx2(
const ck::tensor_operation::cpu::device::DeviceConvFwdDynamicTunable dynamic_tunable_)
: dynamic_tunable(dynamic_tunable_)
{
}
static auto GetABlockDescriptor(const ck::index_t m_per_blk,
const ck::index_t k_per_blk,
const AGridDesc& a_grid_desc)
{
if constexpr(UseALocalBuffer)
{
if constexpr(std::is_same<typename ThreadwiseGemm_Dispatch::MatrixALayout,
ck::tensor_layout::gemm::RowMajor>::value)
{
// A : M, K
auto a_block_desc_m_k =
make_naive_tensor_descriptor_packed(make_tuple(m_per_blk, k_per_blk));
return a_block_desc_m_k;
}
else
{
// A : K, M
auto a_block_desc_k_m = make_naive_tensor_descriptor_packed(
make_tuple(k_per_blk,
math::integer_least_multiple(
m_per_blk, ThreadwiseGemm_Dispatch::MatrixAMinVectorSize)));
return a_block_desc_k_m;
}
}
else
{
return a_grid_desc;
}
}
static auto GetBBlockDescriptor(const ck::index_t k_per_blk,
const ck::index_t n_per_blk,
const BGridDesc& b_grid_desc)
{
if constexpr(UseBLocalBuffer)
{
// n_per_blk should be 8x
if constexpr(std::is_same<typename ThreadwiseGemm_Dispatch::MatrixBLayout,
ck::tensor_layout::gemm::RowMajor>::value)
{
// B : K, N
auto b_block_desc_k_n =
make_naive_tensor_descriptor_packed(make_tuple(k_per_blk, n_per_blk));
return b_block_desc_k_n;
}
else
{
// B : N/8, K, N8
auto b_block_desc_n0_k_n1 = make_naive_tensor_descriptor_packed(
make_tuple(math::integer_divide_ceil(
n_per_blk, ThreadwiseGemm_Dispatch::MatrixBMinVectorSize),
k_per_blk,
ThreadwiseGemm_Dispatch::MatrixBMinVectorSize));
return b_block_desc_n0_k_n1;
}
}
else
{
return b_grid_desc;
}
}
static auto GetCBlockDescriptor(const ck::index_t m_per_blk,
const ck::index_t n_per_blk,
const CGridDesc& c_grid_desc)
{
if constexpr(UseCLocalBuffer)
{
return make_naive_tensor_descriptor_packed(make_tuple(m_per_blk, n_per_blk));
}
else
return c_grid_desc;
}
static auto GetASliceLength(const ck::index_t m_per_blk, const ck::index_t k_per_blk)
{
if constexpr(std::is_same<typename ThreadwiseGemm_Dispatch::MatrixALayout,
ck::tensor_layout::gemm::RowMajor>::value)
{
// A : M, K
return ck::make_multi_index(m_per_blk, k_per_blk);
}
else
{
// A : K, M
return ck::make_multi_index(
k_per_blk,
math::integer_least_multiple(m_per_blk,
ThreadwiseGemm_Dispatch::MatrixAMinVectorSize));
}
}
static auto GetBSliceLength(const ck::index_t k_per_blk, const ck::index_t n_per_blk)
{
// n_per_blk should be 8x
if constexpr(std::is_same<typename ThreadwiseGemm_Dispatch::MatrixBLayout,
ck::tensor_layout::gemm::RowMajor>::value)
{
// B : K, N
return ck::make_multi_index(
k_per_blk,
math::integer_least_multiple(n_per_blk,
ThreadwiseGemm_Dispatch::MatrixBMinVectorSize));
}
else
{
// B : N/8, K, N8
return ck::make_multi_index(
math::integer_divide_ceil(n_per_blk, ThreadwiseGemm_Dispatch::MatrixBMinVectorSize),
k_per_blk,
ThreadwiseGemm_Dispatch::MatrixBMinVectorSize);
}
}
static auto GetCSliceLength(const ck::index_t m_per_blk, const ck::index_t n_per_blk)
{
return ck::make_multi_index(m_per_blk, n_per_blk);
}
static auto GetAIndex(const ck::index_t i_m, const ck::index_t i_k)
{
if constexpr(std::is_same<typename ThreadwiseGemm_Dispatch::MatrixALayout,
ck::tensor_layout::gemm::RowMajor>::value)
{
// A : M, K
return ck::make_multi_index(i_m, i_k);
}
else
{
// A : K, M
return ck::make_multi_index(i_k, i_m);
}
}
static auto GetBIndex(const ck::index_t i_k, const ck::index_t i_n)
{
// i_n should be 8x
if constexpr(std::is_same<typename ThreadwiseGemm_Dispatch::MatrixBLayout,
ck::tensor_layout::gemm::RowMajor>::value)
{
// B : K, N
return ck::make_multi_index(i_k, i_n);
}
else
{
// B : N/8, K, N8
return ck::make_multi_index(i_n / ThreadwiseGemm_Dispatch::MatrixBMinVectorSize,
i_k,
i_n % ThreadwiseGemm_Dispatch::MatrixBMinVectorSize);
}
}
static auto GetCIndex(const ck::index_t i_m, const ck::index_t i_n)
{
return ck::make_multi_index(i_m, i_n);
}
bool CheckValidity(const AGridDesc& a_grid_desc,
const BGridDesc& b_grid_desc,
const CGridDesc& c_grid_desc)
{
// TODO: also check validity of all components (blockwise-copy, threadwise-copy, etc)
bool is_valid = true;
const auto GemmN = c_grid_desc.GetLength(I1);
if constexpr(UseCLocalBuffer)
{
if(dynamic_tunable.loop_over_spec ==
ck::tensor_operation::cpu::device::
ConvolutionForwardBlockLoopOverSpecialization_t::LoopOver_MKN &&
dynamic_tunable.n_per_block < GemmN)
is_valid &= false;
}
else
{
// TODO: need check c grid is simple transform?
if(GemmN % 8 != 0)
is_valid &= false;
}
return is_valid;
}
static intptr_t
GetBBlockStartOffset(const BGridDesc& b_grid_desc, const intptr_t i_k, const intptr_t i_n)
{
if constexpr(std::is_same<typename ThreadwiseGemm_Dispatch::MatrixBLayout,
ck::tensor_layout::gemm::RowMajor>::value)
{
// K * N
return i_n;
}
else
{
// N/8 * K * 8
return i_n * b_grid_desc.GetTransforms()[Number<0>{}].GetUpperLengths()[Number<1>{}] +
i_k * 8;
}
}
static intptr_t
GetCBlockStartOffset(const CGridDesc& c_grid_desc, const intptr_t i_m, const intptr_t i_n)
{
return i_m * c_grid_desc.GetTransforms()[Number<0>{}].GetUpperLengths()[Number<1>{}] + i_n;
}
static intptr_t GetBLeadingElement(const BGridDesc& b_grid_desc)
{
if constexpr(std::is_same<typename ThreadwiseGemm_Dispatch::MatrixBLayout,
ck::tensor_layout::gemm::RowMajor>::value)
{
// K * N
return b_grid_desc.GetTransforms()[Number<0>{}].GetUpperLengths()[Number<1>{}];
}
else
{
// N/8 * K * 8
return b_grid_desc.GetLength(Number<1>{}) * b_grid_desc.GetLength(Number<2>{});
}
}
static intptr_t GetCLeadingElement(const CGridDesc& c_grid_desc)
{
return c_grid_desc.GetTransforms()[Number<0>{}].GetUpperLengths()[Number<1>{}];
}
void Run(const FloatA* __restrict__ p_a_grid,
const FloatB* __restrict__ p_b_grid,
FloatC* __restrict__ p_c_grid,
const FloatC0* __restrict__ p_c0_grid,
const FloatC1* __restrict__ p_c1_grid,
const AGridDesc& a_grid_desc,
const BGridDesc& b_grid_desc,
const CGridDesc& c_grid_desc,
const C0GridDesc& c0_grid_desc,
const C1GridDesc& c1_grid_desc,
ck::index_t N,
ck::index_t K,
ck::index_t C,
std::vector<ck::index_t> input_spatial_lengths,
std::vector<ck::index_t> filter_spatial_lengths,
std::vector<ck::index_t> output_spatial_lengths,
std::vector<ck::index_t> conv_filter_strides,
std::vector<ck::index_t> conv_filter_dilations,
std::vector<ck::index_t> input_left_pads,
std::vector<ck::index_t> input_right_pads,
const AElementwiseOperation& a_element_op,
const BElementwiseOperation& b_element_op,
const CElementwiseOperation& c_element_op) const
{
const ck::index_t m_per_thread = ThreadwiseGemm_Dispatch::ThreadMaxMr;
const ck::index_t n_per_thread = ThreadwiseGemm_Dispatch::ThreadMaxNr;
const ck::index_t k_per_thread = C;
const auto GemmM = c_grid_desc.GetLength(I0);
const auto GemmN = c_grid_desc.GetLength(I1);
const auto GemmK = a_grid_desc.GetLength(I1);
const intptr_t Hi = input_spatial_lengths[0];
const intptr_t Wi = input_spatial_lengths[1];
const intptr_t Ho = output_spatial_lengths[0];
const intptr_t Wo = output_spatial_lengths[1];
const intptr_t Y = filter_spatial_lengths[0];
const intptr_t X = filter_spatial_lengths[1];
const intptr_t Sy = conv_filter_strides[0];
const intptr_t Sx = conv_filter_strides[1];
const intptr_t Dy = conv_filter_dilations[0];
const intptr_t Dx = conv_filter_dilations[1];
const intptr_t Py = input_left_pads[0];
const intptr_t Px = input_left_pads[1];
const intptr_t X_Dx = X * Dx;
// const index_t Y_Dy = Y * Dy;
// const index_t InRightPadH = input_right_pads[0];
// const index_t InRightPadW = input_right_pads[1];
constexpr auto a_block_copy_dim = AGridDesc::GetNumOfDimension();
constexpr auto b_block_copy_dim = BGridDesc::GetNumOfDimension();
auto a_grid_buf = ck::cpu::make_dynamic_buffer<ck::AddressSpaceEnum::Global>(
const_cast<FloatA*>(p_a_grid), a_grid_desc.GetElementSpaceSize());
auto b_grid_buf = ck::cpu::make_dynamic_buffer<ck::AddressSpaceEnum::Global>(
const_cast<FloatB*>(p_b_grid), b_grid_desc.GetElementSpaceSize());
auto c_grid_buf = ck::cpu::make_dynamic_buffer<ck::AddressSpaceEnum::Global>(
reinterpret_cast<FloatC*>(p_c_grid), c_grid_desc.GetElementSpaceSize());
auto c0_grid_buf = ck::cpu::make_dynamic_buffer<ck::AddressSpaceEnum::Global>(
reinterpret_cast<const FloatC0*>(p_c0_grid), c0_grid_desc.GetElementSpaceSize());
auto c1_grid_buf = ck::cpu::make_dynamic_buffer<ck::AddressSpaceEnum::Global>(
reinterpret_cast<const FloatC1*>(p_c1_grid), c1_grid_desc.GetElementSpaceSize());
int total_threads = omp_get_max_threads();
if(total_threads > 1 && ck::getenv_int("CK_CPU_BIND_CORE", 0) != 0)
{
#pragma omp parallel
{
int tid = omp_get_thread_num();
cpu_set_t set;
CPU_ZERO(&set);
CPU_SET(tid, &set);
if(sched_setaffinity(0, sizeof(set), &set) == -1)
{
throw std::runtime_error("wrong! fail to set thread affinity");
}
}
}
auto devide_thread = [](ck::index_t n_, ck::index_t length_, ck::index_t factor_) {
ck::index_t t_ = n_;
while(t_ > length_ && (t_ % factor_ == 0))
{
t_ /= factor_;
}
return t_;
};
const intptr_t num_works_n = N;
const intptr_t num_works_ho = Ho;
// const intptr_t num_works_nho = N * Ho;
const intptr_t num_works_wo = math::integer_divide_ceil(Wo, m_per_thread);
const intptr_t num_works_k = math::integer_divide_ceil(K, n_per_thread);
auto distribute_num_threads_n_ho_wo_k = [&](ck::index_t& num_threads_n_,
ck::index_t& num_threads_ho_,
ck::index_t& num_threads_wo_,
ck::index_t& num_threads_k_) {
// TODO: only consider multiply of 2 to divide threads
ck::index_t num_threads = total_threads;
num_threads_n_ = devide_thread(num_threads, num_works_n, 2);
num_threads = num_threads / num_threads_n_;
num_threads_ho_ = devide_thread(num_threads, num_works_ho, 2);
num_threads = num_threads / num_threads_ho_;
num_threads_wo_ = devide_thread(num_threads, num_works_wo, 2);
num_threads = num_threads / num_threads_wo_;
num_threads_k_ = devide_thread(num_threads, num_works_k, 2);
// num_threads = num_threads / num_threads_k_;
};
ck::index_t num_threads_n;
ck::index_t num_threads_ho;
ck::index_t num_threads_wo;
ck::index_t num_threads_k;
distribute_num_threads_n_ho_wo_k(
num_threads_n, num_threads_ho, num_threads_wo, num_threads_k);
const ck::index_t num_works_n_per_thread =
math::integer_divide_ceil(num_works_n, num_threads_n);
const ck::index_t num_works_ho_per_thread =
math::integer_divide_ceil(num_works_ho, num_threads_ho);
const ck::index_t num_works_wo_per_thread =
math::integer_divide_ceil(num_works_wo, num_threads_wo);
const ck::index_t num_works_k_per_thread =
math::integer_divide_ceil(num_works_k, num_threads_k);
// printf("num_threads_nho:%d, num_threads_wo:%d, num_threads_k:%d |
// num_works_nho_per_thread:%d, num_works_wo_per_thread:%d, num_works_k_per_thread:%d\n",
// num_threads_nho, num_threads_wo, num_threads_k, num_works_nho_per_thread,
// num_works_wo_per_thread, num_works_k_per_thread); fflush(stdout);
if((X - 1) * Dx + 1 <= Px || (Y - 1) * Dy + 1 <= Py)
{
// padding zero case, outpout will have zero due to upsampling
// TODO: This is ugly and slow
ck::cpu::avx2_util::memset32_avx2(&c_grid_buf.p_data_[0], 0, N * Ho * Wo * K);
// printf("___ clear\n");
}
if(dynamic_tunable.loop_over_spec ==
ck::tensor_operation::cpu::device::ConvolutionForwardBlockLoopOverSpecialization_t::
LoopOver_MNK)
{
// only parallel in gemm m dim
#pragma omp parallel
{
DeviceAlignedMemCPU a_block_mem(
UseALocalBuffer ? m_per_thread * k_per_thread * sizeof(FloatA) : 0,
MemAlignmentByte);
DeviceAlignedMemCPU b_block_mem(
UseBLocalBuffer ? k_per_thread * n_per_thread * sizeof(FloatB) : 0,
MemAlignmentByte);
DeviceAlignedMemCPU c_block_mem(
UseCLocalBuffer ? (m_per_thread * n_per_thread * sizeof(FloatC)) : 0,
MemAlignmentByte);
auto a_block_buf = ck::cpu::make_dynamic_buffer<ck::AddressSpaceEnum::Global>(
UseALocalBuffer ? reinterpret_cast<FloatA*>(a_block_mem.mpDeviceBuf)
: const_cast<FloatA*>(p_a_grid),
UseALocalBuffer ? a_block_mem.mMemSize / sizeof(FloatA)
: a_grid_desc.GetElementSpaceSize());
auto b_block_buf = ck::cpu::make_dynamic_buffer<ck::AddressSpaceEnum::Global>(
UseBLocalBuffer ? reinterpret_cast<FloatB*>(b_block_mem.mpDeviceBuf)
: const_cast<FloatB*>(p_b_grid),
UseBLocalBuffer ? b_block_mem.mMemSize / sizeof(FloatB)
: b_grid_desc.GetElementSpaceSize());
auto c_block_buf = ck::cpu::make_dynamic_buffer<ck::AddressSpaceEnum::Global>(
UseCLocalBuffer ? reinterpret_cast<FloatC*>(c_block_mem.mpDeviceBuf)
: reinterpret_cast<FloatC*>(p_c_grid),
UseCLocalBuffer ? c_block_mem.mMemSize / sizeof(FloatC)
: c_grid_desc.GetElementSpaceSize());
ck::index_t tid = omp_get_thread_num();
const ck::index_t tid_n = tid % num_threads_n;
tid /= num_threads_n;
const ck::index_t tid_ho = tid % num_threads_ho;
tid /= num_threads_ho;
const ck::index_t tid_wo = tid % num_threads_wo;
tid /= num_threads_wo;
const ck::index_t tid_k = tid;
ck::cpu::ThreadwiseGemmParam param;
// param.Kr = k_per_block;
param.lda = Sx * C * sizeof(FloatA);
param.ldb = GetBLeadingElement(b_grid_desc) * sizeof(FloatB);
param.ldc = GetCLeadingElement(c_grid_desc) * sizeof(FloatC);
param.alpha = 1.0f; // TODO
param.Kr = C;
// ihi = iho * s_stride_h + iy * s_dilation_h - s_pad_h
// iwi = iwo * s_stride_w + ix * s_dilation_w - s_pad_w
// ck::index_t i_nho = tid_nho * num_works_nho_per_thread;
// ck::index_t i_ho = i_nho % Ho;
// ck::index_t i_n = i_nho / Ho;
// auto accumulate_n_ho = [&]() {
// i_ho++;
// if(i_ho >= Wo)
// {
// i_ho = 0;
// i_n++;
// }
// };
for(intptr_t i_n = tid_n * num_works_n_per_thread;
(i_n < (tid_n + 1) * num_works_n_per_thread) && i_n < num_works_n;
i_n += 1)
{
for(intptr_t i_ho = tid_ho * num_works_ho_per_thread;
(i_ho < (tid_ho + 1) * num_works_ho_per_thread) && i_ho < num_works_ho;
i_ho += 1)
{
// for input
intptr_t i_hi_no_y = i_ho * Sy - Py;
for(intptr_t i_wo = tid_wo * num_works_wo_per_thread * m_per_thread;
i_wo < (tid_wo + 1) * num_works_wo_per_thread * m_per_thread &&
i_wo < Wo;
i_wo += m_per_thread)
{
intptr_t current_wo_size_no_dx =
ck::math::min(Wo - i_wo, (intptr_t)m_per_thread);
intptr_t i_wi_no_x = i_wo * Sx - Px;
// printf("-- i_nho:%d, i_wo:%d, num_works_nho:%d,
// num_threads_nho:%d(Hi:%d,nWi:%d)\n",
// i_nho, i_wo, num_works_nho, num_threads_nho, Hi,
// Wi);fflush(stdout);
for(intptr_t i_k = tid_k * num_works_k_per_thread * n_per_thread;
i_k < (tid_k + 1) * num_works_k_per_thread * n_per_thread;
i_k += n_per_thread)
{
intptr_t i_dx = 0;
intptr_t i_dy = 0;
bool accmulate_c = false;
intptr_t current_k_size =
ck::math::min(K - i_k, (intptr_t)n_per_thread);
auto accumulate_dy_dx = [&]() {
i_dx += Dx;
if(i_dx >= X_Dx)
{
i_dx = 0;
i_dy += Dy;
}
};
for(intptr_t i_yxc = 0; i_yxc < (Y * X * C);
i_yxc += C, accumulate_dy_dx())
{
intptr_t current_i_wo = i_wo;
intptr_t i_hi = i_hi_no_y + i_dy;
if(i_hi < 0 || i_hi >= Hi)
continue;
intptr_t i_wi = i_wi_no_x + i_dx;
intptr_t current_wo_size = current_wo_size_no_dx;
intptr_t pad_wo_size = 0; // when left pad, we may never have
// a chance to clear zero (like
// padding) we need to manually clear that
if(i_wi < 0)
{
intptr_t wi_to_zero_length =
-i_wi; // keep this a possitive number
intptr_t steps_wo_turn_possitive =
(wi_to_zero_length + Sx - 1) /
Sx; // how many steps need to move wo, to let wi to be
// possitive
current_wo_size -= steps_wo_turn_possitive;
if(current_wo_size <= 0)
continue;
current_i_wo += steps_wo_turn_possitive;
if(!accmulate_c)
pad_wo_size =
steps_wo_turn_possitive; // if already accumulating,
// no need to manually set
i_wi += steps_wo_turn_possitive *
Sx; // now i_wi will be a possitive number
}
if(i_wi >= Wi)
continue;
// shrink right wi/wo
if((i_wi + ((current_wo_size - 1) * Sx)) >= Wi)
{
// printf(" ->[r] i_wi:%d, r:%d(%d), ", i_wi, i_wi +
// ((current_wo_size - 1) * Sx), current_wo_size);
current_wo_size = (Wi - 1 - i_wi) / Sx +
1; // NOTE: this be careful why here
// should be compute like this.
if(current_wo_size <= 0)
continue;
}
param.accmulate_c = accmulate_c ? 1 : 0;
accmulate_c = true;
intptr_t current_input_offset =
i_n * Hi * Wi * C + i_hi * Wi * C + i_wi * C;
if(pad_wo_size != 0)
{
for(intptr_t i_wo_pad = 0; i_wo_pad < pad_wo_size;
i_wo_pad++)
{
const intptr_t offset_c = GetCBlockStartOffset(
c_grid_desc,
(i_n * Ho + i_ho) * Wo + i_wo_pad,
i_k);
// printf("pad_wo_size:%d, current_k_block_size:%d,
// clear offset_c:%d\n",
// pad_wo_size, current_k_size,
// offset_c);fflush(stdout);
ck::cpu::avx2_util::memset32_avx2(
&c_block_buf.p_data_[offset_c], 0, current_k_size);
}
}
const intptr_t offset_a = current_input_offset;
const intptr_t offset_b =
GetBBlockStartOffset(b_grid_desc, i_yxc, i_k);
const intptr_t offset_c = GetCBlockStartOffset(
c_grid_desc, (i_n * Ho + i_ho) * Wo + current_i_wo, i_k);
// printf("offset_a:%lu, offset_b:%lu, offset_c:%lu, i_n:%d,
// i_hi:%d, i_wi:%d, i_dx:%d, i_dy:%d, i_k:%d, i_ho:%d, i_wo:%d,
// current_wo_size:%d, current_k_size:%d, i_nho:%d, lda:%d,
// ldb:%d, ldc:%d, acc:%d\n",
// offset_a, offset_b, offset_c, i_n, i_hi, i_wi, i_dx,
// i_dy, i_k, i_ho, current_i_wo, current_wo_size,
// current_k_size, i_nho, param.lda / sizeof(FloatA),
// param.ldb / sizeof(FloatB), param.ldc / sizeof(FloatC),
// param.accmulate_c); fflush(stdout);
param.p_a = &a_block_buf.p_data_[offset_a];
param.p_b = &b_block_buf.p_data_[offset_b];
param.p_c = &c_block_buf.p_data_[offset_c];
ThreadwiseGemm_Dispatch::Run(
&param, current_wo_size, current_k_size);
}
}
}
// thread block wise fusion
for(intptr_t i_wo = tid_wo * num_works_wo_per_thread * m_per_thread;
i_wo < (tid_wo + 1) * num_works_wo_per_thread * m_per_thread &&
i_wo < Wo;
i_wo += 1)
{
const intptr_t n_size =
ck::math::min(K - tid_k * num_works_k_per_thread * n_per_thread,
num_works_k_per_thread * n_per_thread);
if constexpr(CThreadwiseCopy::FuseBias && CThreadwiseCopy::FuseAdd)
{
const intptr_t offset_c = GetCBlockStartOffset(
c_grid_desc, (i_n * Ho + i_ho) * Wo + i_wo, 0);
const intptr_t offset_c0 = 0;
avx2_util::memcpy32_avx2_with_extra_2src(
&c_block_buf.p_data_[offset_c],
&c_block_buf.p_data_[offset_c],
&c0_grid_buf.p_data_[offset_c0],
&c1_grid_buf.p_data_[offset_c],
n_size,
c_element_op);
}
else
{
}
}
}
}
}
}
else if(dynamic_tunable.loop_over_spec ==
ck::tensor_operation::cpu::device::ConvolutionForwardBlockLoopOverSpecialization_t::
LoopOver_MKN)
{
// only parallel in gemm m dim
#pragma omp parallel
{
DeviceAlignedMemCPU a_block_mem(
UseALocalBuffer ? m_per_thread * k_per_thread * sizeof(FloatA) : 0,
MemAlignmentByte);
DeviceAlignedMemCPU b_block_mem(
UseBLocalBuffer ? k_per_thread * n_per_thread * sizeof(FloatB) : 0,
MemAlignmentByte);
DeviceAlignedMemCPU c_block_mem(
UseCLocalBuffer ? (m_per_thread * n_per_thread * sizeof(FloatC)) : 0,
MemAlignmentByte);
auto a_block_buf = ck::cpu::make_dynamic_buffer<ck::AddressSpaceEnum::Global>(
UseALocalBuffer ? reinterpret_cast<FloatA*>(a_block_mem.mpDeviceBuf)
: const_cast<FloatA*>(p_a_grid),
UseALocalBuffer ? a_block_mem.mMemSize / sizeof(FloatA)
: a_grid_desc.GetElementSpaceSize());
auto b_block_buf = ck::cpu::make_dynamic_buffer<ck::AddressSpaceEnum::Global>(
UseBLocalBuffer ? reinterpret_cast<FloatB*>(b_block_mem.mpDeviceBuf)
: const_cast<FloatB*>(p_b_grid),
UseBLocalBuffer ? b_block_mem.mMemSize / sizeof(FloatB)
: b_grid_desc.GetElementSpaceSize());
auto c_block_buf = ck::cpu::make_dynamic_buffer<ck::AddressSpaceEnum::Global>(
UseCLocalBuffer ? reinterpret_cast<FloatC*>(c_block_mem.mpDeviceBuf)
: reinterpret_cast<FloatC*>(p_c_grid),
UseCLocalBuffer ? c_block_mem.mMemSize / sizeof(FloatC)
: c_grid_desc.GetElementSpaceSize());
ck::cpu::ThreadwiseGemmParam param;
// param.Kr = k_per_block;
param.lda = Sx * C * sizeof(FloatA);
param.ldb = GetBLeadingElement(b_grid_desc) * sizeof(FloatB);
param.ldc = GetCLeadingElement(c_grid_desc) * sizeof(FloatC);
param.alpha = 1.0f; // TODO
param.Kr = C;
// ihi = iho * s_stride_h + iy * s_dilation_h - s_pad_h
// iwi = iwo * s_stride_w + ix * s_dilation_w - s_pad_w
ck::index_t tid = omp_get_thread_num();
const ck::index_t tid_n = tid % num_threads_n;
tid /= num_threads_n;
const ck::index_t tid_ho = tid % num_threads_ho;
tid /= num_threads_ho;
const ck::index_t tid_wo = tid % num_threads_wo;
tid /= num_threads_wo;
const ck::index_t tid_k = tid;
for(intptr_t i_n = tid_n * num_works_n_per_thread;
(i_n < (tid_n + 1) * num_works_n_per_thread) && i_n < num_works_n;
i_n += 1)
{
for(intptr_t i_ho = tid_ho * num_works_ho_per_thread;
(i_ho < (tid_ho + 1) * num_works_ho_per_thread) && i_ho < num_works_ho;
i_ho += 1)
{
// for input
intptr_t i_hi_no_y = i_ho * Sy - Py;
for(intptr_t i_wo = tid_wo * num_works_wo_per_thread * m_per_thread;
i_wo < (tid_wo + 1) * num_works_wo_per_thread * m_per_thread &&
i_wo < Wo;
i_wo += m_per_thread)
{
intptr_t current_wo_size_no_dx =
ck::math::min(Wo - i_wo, (intptr_t)m_per_thread);
intptr_t i_wi_no_x = i_wo * Sx - Px;
intptr_t i_dx = 0;
intptr_t i_dy = 0;
bool accmulate_c = false;
// printf("-- [%d] i_n:%d, i_ho:%d, i_wo:%d, num_works_n:%d,
// num_threads_n:%d(Hi:%d, Wi:%d), current_wo_size_no_dx:%d,
// m_per_thread:%d\n",
// tid, i_n, i_ho, i_wo, num_works_n, num_threads_n, Hi, Wi,
// current_wo_size_no_dx, m_per_thread);fflush(stdout);
auto accumulate_dy_dx = [&]() {
i_dx += Dx;
if(i_dx >= X_Dx)
{
i_dx = 0;
i_dy += Dy;
}
};
for(intptr_t i_yxc = 0; i_yxc < (Y * X * C);
i_yxc += C, accumulate_dy_dx())
{
intptr_t current_i_wo = i_wo;
intptr_t i_hi = i_hi_no_y + i_dy;
bool run_pad_only = false;
if(i_hi < 0 || i_hi >= Hi)
continue;
intptr_t i_wi = i_wi_no_x + i_dx;
intptr_t current_wo_size = current_wo_size_no_dx;
intptr_t pad_wo_size = 0; // when left pad, we may never have a
// chance to clear zero (like
// padding) we need to manually clear that
/* left corner shift
* when i_wi is negative, need shift i_wo to right to make i_wi
* possitive sx px i_wi steps_wo_turn_possitive
* 1 0
* 0, 1, 2.... 0 2 0 0, 2, 4... 0 1 1 -1,
* 0, 1.... 1 2 1 -1, 1, 3.... 1 2 2 -2, 0, 2... 1 2
* 3 -3, -1, 1... 2 3 1 -1, 2, 5... 1 3 2 -2,
* 1, 4.... 1 3 3 -3, 0, 3 1 3 4 -4,
* -1, 2... 2
*/
if(i_wi < 0)
{
intptr_t wi_to_zero_length =
-i_wi; // keep this a possitive number
intptr_t steps_wo_turn_possitive =
(wi_to_zero_length + Sx - 1) /
Sx; // how many steps need to move wo, to let wi to be
// possitive
current_wo_size -= steps_wo_turn_possitive;
// printf("--- current_wo_size:%d, i_wi:%d\n", current_wo_size,
// i_wi);
if(current_wo_size <= 0)
continue;
current_i_wo += steps_wo_turn_possitive;
if(!accmulate_c)
pad_wo_size =
steps_wo_turn_possitive; // if already accumulating, no
// need to manually set
i_wi += steps_wo_turn_possitive *
Sx; // now i_wi will be a possitive number
}
if(i_wi >= Wi)
{
continue;
}
// shrink right wi/wo
if((i_wi + ((current_wo_size - 1) * Sx)) >= Wi)
{
// printf(" ->[r] i_wi:%d, r:%d(%d), ", i_wi, i_wi +
// ((current_wo_size - 1) * Sx), current_wo_size);
current_wo_size =
(Wi - 1 - i_wi) / Sx + 1; // NOTE: this be careful why here
// should be compute like this.
if(current_wo_size <= 0)
continue;
}
param.accmulate_c = accmulate_c ? 1 : 0;
accmulate_c = true;
intptr_t current_input_offset =
i_n * Hi * Wi * C + i_hi * Wi * C + i_wi * C;
if(pad_wo_size != 0)
{
// manually clear zero. this may and only may need once along
// the gemm_k reduction
intptr_t i_k = tid_k * num_works_k_per_thread * n_per_thread;
intptr_t current_k_block_size = ck::math::min(
K - i_k, (intptr_t)num_works_k_per_thread * n_per_thread);
const intptr_t offset_c = GetCBlockStartOffset(
c_grid_desc, (i_n * Ho + i_ho) * Wo, i_k);
// printf("[%d] pad_wo_size:%d, current_k_block_size:%d,
// offset_c:%d, i_wo:%d\n",
// tid, pad_wo_size, current_k_block_size, offset_c,
// i_wo);fflush(stdout);
ck::cpu::avx2_util::memset32_avx2(
&c_block_buf.p_data_[offset_c],
0,
current_k_block_size * pad_wo_size);
}
if(run_pad_only)
continue;
for(intptr_t i_k = tid_k * num_works_k_per_thread * n_per_thread;
i_k < (tid_k + 1) * num_works_k_per_thread * n_per_thread;
i_k += n_per_thread)
{
intptr_t current_k_size =
ck::math::min(K - i_k, (intptr_t)n_per_thread);
const intptr_t offset_a = current_input_offset;
const intptr_t offset_b =
GetBBlockStartOffset(b_grid_desc, i_yxc, i_k);
const intptr_t offset_c = GetCBlockStartOffset(
c_grid_desc, (i_n * Ho + i_ho) * Wo + current_i_wo, i_k);
// printf("[%d] offset_a:%lu, offset_b:%lu, offset_c:%lu,
// i_n:%d, i_hi:%d, i_wi:%d, i_dx:%d, i_dy:%d, i_k:%d, i_ho:%d,
// i_wo:%d, current_wo_size:%d, i_n:%d, i_ho:%d, lda:%d,
// ldb:%d\n",
// tid, offset_a, offset_b, offset_c, i_n, i_hi, i_wi, i_dx,
// i_dy, i_k, i_ho, current_i_wo, current_wo_size, i_n,
// i_ho, param.lda / sizeof(FloatA), param.ldb /
// sizeof(FloatB)); fflush(stdout);
param.p_a = &a_block_buf.p_data_[offset_a];
param.p_b = &b_block_buf.p_data_[offset_b];
param.p_c = &c_block_buf.p_data_[offset_c];
ThreadwiseGemm_Dispatch::Run(
&param, current_wo_size, current_k_size);
}
}
}
// thread block wise fusion
for(intptr_t i_wo = tid_wo * num_works_wo_per_thread * m_per_thread;
i_wo < (tid_wo + 1) * num_works_wo_per_thread * m_per_thread &&
i_wo < Wo;
i_wo += 1)
{
const intptr_t n_size =
ck::math::min(K - tid_k * num_works_k_per_thread * n_per_thread,
num_works_k_per_thread * n_per_thread);
if constexpr(CThreadwiseCopy::FuseBias && CThreadwiseCopy::FuseAdd)
{
const intptr_t offset_c = GetCBlockStartOffset(
c_grid_desc, (i_n * Ho + i_ho) * Wo + i_wo, 0);
const intptr_t offset_c0 = 0;
avx2_util::memcpy32_avx2_with_extra_2src(
&c_block_buf.p_data_[offset_c],
&c_block_buf.p_data_[offset_c],
&c0_grid_buf.p_data_[offset_c0],
&c1_grid_buf.p_data_[offset_c],
n_size,
c_element_op);
}
else
{
}
}
}
}
}
}
}
};
} // namespace cpu
} // namespace ck
#endif
...@@ -1768,6 +1768,9 @@ struct ThreadwiseTensorSliceTransferAvx2Specialization_MatC_Store_Bias_Residual_ ...@@ -1768,6 +1768,9 @@ struct ThreadwiseTensorSliceTransferAvx2Specialization_MatC_Store_Bias_Residual_
static constexpr ck::index_t nDim = SrcDesc::GetNumOfDimension(); static constexpr ck::index_t nDim = SrcDesc::GetNumOfDimension();
using Index = MultiIndex<nDim>; using Index = MultiIndex<nDim>;
static constexpr bool FuseBias = true;
static constexpr bool FuseAdd = true;
constexpr ThreadwiseTensorSliceTransferAvx2Specialization_MatC_Store_Bias_Residual_MxN( constexpr ThreadwiseTensorSliceTransferAvx2Specialization_MatC_Store_Bias_Residual_MxN(
const SrcDesc& src_desc, const SrcDesc& src_desc,
const Index&, const Index&,
...@@ -2434,6 +2437,9 @@ struct ThreadwiseTensorSliceTransferAvx2Specialization_MatC_Store_Bias_MxN ...@@ -2434,6 +2437,9 @@ struct ThreadwiseTensorSliceTransferAvx2Specialization_MatC_Store_Bias_MxN
static constexpr ck::index_t nDim = SrcDesc::GetNumOfDimension(); static constexpr ck::index_t nDim = SrcDesc::GetNumOfDimension();
using Index = MultiIndex<nDim>; using Index = MultiIndex<nDim>;
static constexpr bool FuseBias = true;
static constexpr bool FuseAdd = false;
constexpr ThreadwiseTensorSliceTransferAvx2Specialization_MatC_Store_Bias_MxN( constexpr ThreadwiseTensorSliceTransferAvx2Specialization_MatC_Store_Bias_MxN(
const SrcDesc& src_desc, const SrcDesc& src_desc,
const Index&, const Index&,
......
#include <stdlib.h>
#include <utility>
#include <memory>
#include "ck/ck.hpp"
#include "ck/tensor_operation/cpu/device/convolution_forward_specialization_cpu.hpp"
#include "ck/tensor_operation/cpu/device/device_convnd_direct_fwd_bias_activation_add_avx2_nhwc_kyxck8_nhwk.hpp"
#include "ck/tensor_operation/cpu/element/element_wise_operation_cpu.hpp"
#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
namespace ck {
namespace tensor_operation {
namespace cpu {
namespace device {
namespace device_conv2d_fwd_bias_activation_add_avx2_instance {
using InType = float;
using WeiType = float;
using OutType = float;
using AccType = float;
using InLayout = ck::tensor_layout::gemm::RowMajor; // NHWC
using WeiLayout = ck::tensor_layout::gemm::ColumnMajor; // KYXCK8
static constexpr bool NonTemporalStore = false;
using PT = ck::tensor_operation::cpu::element_wise::PassThrough;
using AddReluAdd = ck::tensor_operation::cpu::element_wise::AddReluAdd;
using AddRelu = ck::tensor_operation::cpu::element_wise::AddRelu;
using Add = ck::tensor_operation::cpu::element_wise::Add;
using AddAddRelu = ck::tensor_operation::cpu::element_wise::AddAddRelu;
static constexpr auto ConvFwdDefault =
ck::tensor_operation::cpu::device::ConvolutionForwardSpecialization_t::Default;
static constexpr auto ConvFwd1x1P0 =
ck::tensor_operation::cpu::device::ConvolutionForwardSpecialization_t::Filter1x1Pad0;
static constexpr auto ConvFwd1x1S1P0 =
ck::tensor_operation::cpu::device::ConvolutionForwardSpecialization_t::Filter1x1Stride1Pad0;
static constexpr auto DefaultGemmKLoop =
ck::tensor_operation::cpu::device::ConvolutionForwardGemmKSpecialization_t::DefaultGemmKLoop;
static constexpr auto GemmKLoopOverC =
ck::tensor_operation::cpu::device::ConvolutionForwardGemmKSpecialization_t::NHWC_GemmKLoopOverC;
static constexpr auto LoopOver_MNK = ck::tensor_operation::cpu::device::LoopOver_MNK;
static constexpr auto LoopOver_MKN = ck::tensor_operation::cpu::device::LoopOver_MKN;
void add_device_conv2d_direct_fwd_bias_activation_add_avx2_nhwc_kyxck8_nhwk(
std::vector<DeviceConvFwdBiasActivationAddPtr<PT, PT, AddReluAdd>>& instances)
{
ck::tensor_operation::device::instance::add_device_operation_instances(
instances,
std::make_tuple(
// clang-format off
DeviceConvNDDirectFwdBiasActivationAddAvx2_Input_N_Hi_Wi_C_Weight_K_Y_X_C_K8_Output_N_Ho_Wo_K<float, float, float, float, float, PT, PT, AddReluAdd, ConvFwdDefault, 2, 6, 16, false, false, false, true, true, false>({0, 0, 0, DefaultGemmKLoop, LoopOver_MKN}),
DeviceConvNDDirectFwdBiasActivationAddAvx2_Input_N_Hi_Wi_C_Weight_K_Y_X_C_K8_Output_N_Ho_Wo_K<float, float, float, float, float, PT, PT, AddReluAdd, ConvFwdDefault, 2, 6, 16, false, false, false, true, true, false>({0, 0, 0, DefaultGemmKLoop, LoopOver_MNK})
// clang-format on
));
}
} // namespace device_conv2d_fwd_bias_activation_add_avx2_instance
} // namespace device
} // namespace cpu
} // namespace tensor_operation
} // namespace ck
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment