Commit 090ba885 authored by carlushuang

add elementwise fusion support

parent 8ce9fe57
......@@ -896,6 +896,11 @@ struct DeviceConvNDFwdAvx2_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K
<< "_B" << string_local_buffer(UseBLocalBuffer)
<< "_C" << string_local_buffer(UseCLocalBuffer)
;
if constexpr (!std::is_same<OutElementwiseOperation,
ck::tensor_operation::cpu::element_wise::PassThrough>::value)
{
str << "_" << OutElementwiseOperation::Name();
}
// clang-format on
return str.str();
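// Note (not part of this commit): only a non-PassThrough output op appends a suffix here,
// so existing PassThrough instance names stay unchanged while fused instances gain the
// functor's Name(), e.g. a trailing "_Relu".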
......
#pragma once
#include "data_type_cpu.hpp"
namespace ck {
namespace tensor_operation {
namespace cpu {
namespace element_wise {
using float8_t = ck::cpu::float8_t;
using float4_t = ck::cpu::float4_t;
struct PassThrough
{
void operator()(float& y, const float& x) const { y = x; }
void operator()(float4_t& y, const float4_t& x) const { y = x; }
void operator()(float8_t& y, const float8_t& x) const { y = x; }
};
struct Add
{
void operator()(float& y, const float& x0, const float& x1) const { y = x0 + x1; }
void operator()(float4_t& y, const float4_t& x0, const float4_t& x1) const
{
y = _mm_add_ps(x0, x1);
}
void operator()(float8_t& y, const float8_t& x0, const float8_t& x1) const
{
y = _mm256_add_ps(x0, x1);
}
};
struct AlphaBetaAdd
{
AlphaBetaAdd(float alpha, float beta) : alpha_(alpha), beta_(beta) {}
void operator()(float& y, const float& x0, const float& x1) const
{
y = alpha_ * x0 + beta_ * x1;
}
void operator()(float4_t& y, const float4_t& x0, const float4_t& x1) const
{
y = _mm_add_ps(_mm_mul_ps(x0, _mm_set1_ps(alpha_)), _mm_mul_ps(x1, _mm_set1_ps(beta_)));
}
void operator()(float8_t& y, const float8_t& x0, const float8_t& x1) const
{
y = _mm256_add_ps(_mm256_mul_ps(x0, _mm256_set1_ps(alpha_)),
_mm256_mul_ps(x1, _mm256_set1_ps(beta_)));
}
float alpha_;
float beta_;
};
struct AddRelu
{
void operator()(float& y, const float& x0, const float& x1) const
{
const float a = x0 + x1;
y = a > 0 ? a : 0;
}
void operator()(float4_t& y, const float4_t& x0, const float4_t& x1) const
{
y = _mm_max_ps(_mm_add_ps(x0, x1), _mm_setzero_ps());
}
void operator()(float8_t& y, const float8_t& x0, const float8_t& x1) const
{
y = _mm256_max_ps(_mm256_add_ps(x0, x1), _mm256_setzero_ps());
}
};
#if 0
struct AddHardswish
{
void operator()(float& y, const float& x0, const float& x1) const
{
float a = x0 + x1;
float b = a + float{3};
float c = (b > 0) * (b > float{6} ? float{6} : b) * a * float{0.166667};
y = c;
}
void
operator()(half_t& y, const half_t& x0, const half_t& x1) const
{
float a = x0 + x1;
float b = a + float{3};
float c = (b > 0) * (b > float{6} ? float{6} : b) * a * float{0.166667};
y = c;
}
};
#endif
struct AddReluAdd
{
void operator()(float& y, const float& x0, const float& x1, const float& x2) const
{
float a = x0 + x1;
float b = a > 0 ? a : 0;
float c = b + x2;
y = c;
}
void operator()(float4_t& y, const float4_t& x0, const float4_t& x1, const float4_t& x2) const
{
float4_t a = _mm_add_ps(x0, x1);
float4_t b = _mm_max_ps(a, _mm_setzero_ps());
y = _mm_add_ps(b, x2);
}
void operator()(float8_t& y, const float8_t& x0, const float8_t& x1, const float8_t& x2) const
{
float8_t a = _mm256_add_ps(x0, x1);
float8_t b = _mm256_max_ps(a, _mm256_setzero_ps());
y = _mm256_add_ps(b, x2);
}
};
#if 0
struct AddHardswishAdd
{
void
operator()(float& y, const float& x0, const float& x1, const float& x2) const
{
float a = x0 + x1;
float b = a + float{3};
float c = (b > 0) * (b > float{6} ? float{6} : b) * a * float{0.166667};
float d = c + x2;
y = d;
}
void
operator()(half_t& y, const half_t& x0, const half_t& x1, const half_t& x2) const
{
float a = x0 + x1;
float b = a + float{3};
float c = (b > 0) * (b > float{6} ? float{6} : b) * a * float{0.166667};
float d = c + x2;
y = d;
}
};
#endif
#if 0
struct RequantReluRequant
{
// FIXME: We just need one scale for Relu / Leaky Relu / PRelu
RequantReluRequant(float scaleGemm, float scaleRelu)
: scaleGemm_(scaleGemm), scaleRelu_(scaleRelu)
{
}
void operator()(int8_t& y, const int& x) const
{
float gemm_requant = scaleGemm_ * static_cast<float>(x);
float relu = gemm_requant > 0 ? gemm_requant : 0;
float relu_requant = scaleRelu_ * relu;
y = static_cast<int8_t>(relu_requant > 127 ? 127
: relu_requant < -128 ? -128 : relu_requant);
}
// for reference_gemm
void operator()(float& y, const float& x) const
{
float gemm_requant = scaleGemm_ * x;
float relu = gemm_requant > 0 ? gemm_requant : 0;
float relu_requant = scaleRelu_ * relu;
y = static_cast<float>(relu_requant > 127 ? 127
: relu_requant < -128 ? -128 : relu_requant);
}
float scaleGemm_;
float scaleRelu_;
};
#endif
// Unary operators are usually called element-wisely before/after the reduction is executed on the
// elements. They are needed for easy implementation of reduction types of AVG, NRM1, NRM2
template <typename Y, typename X, bool HasDividing = false>
struct UnaryIdentic;
template <>
struct UnaryIdentic<float, float, false>
{
UnaryIdentic(const int32_t divider = 1) { (void)divider; };
void operator()(float& y, const float& x) const { y = x; };
};
template <>
struct UnaryIdentic<float, float, true>
{
UnaryIdentic(const int32_t divider = 1) { divider_ = divider; };
void operator()(float& y, const float& x) const { y = x / type_convert<float>(divider_); };
int32_t divider_ = 1;
};
template <>
struct UnaryIdentic<float4_t, float4_t, false>
{
UnaryIdentic(const int32_t divider = 1) { (void)divider; };
void operator()(float4_t& y, const float4_t& x) const { y = x; };
};
template <>
struct UnaryIdentic<float4_t, float4_t, true>
{
UnaryIdentic(const int32_t divider = 1) { divider_ = divider; };
void operator()(float4_t& y, const float4_t& x) const
{
y = _mm_div_ps(x, _mm_set1_ps(static_cast<float>(divider_)));
};
int32_t divider_ = 1;
};
template <>
struct UnaryIdentic<float8_t, float8_t, false>
{
UnaryIdentic(const int32_t divider = 1) { (void)divider; };
void operator()(float8_t& y, const float8_t& x) const { y = x; };
};
template <>
struct UnaryIdentic<float8_t, float8_t, true>
{
UnaryIdentic(const int32_t divider = 1) { divider_ = divider; };
void operator()(float8_t& y, const float8_t& x) const
{
y = _mm256_div_ps(x, _mm256_set1_ps(static_cast<float>(divider_)));
};
int32_t divider_ = 1;
};
template <typename Y, typename X, bool HasDividing = false>
struct UnarySquare;
template <>
struct UnarySquare<float, float, false>
{
UnarySquare(const int32_t divider = 1) { (void)divider; };
void operator()(float& y, const float& x) const { y = x * x; };
};
template <>
struct UnarySquare<float, float, true>
{
UnarySquare(const int32_t divider = 1) { divider_ = divider; };
void operator()(float& y, const float& x) const { y = x * x / type_convert<float>(divider_); };
int32_t divider_ = 1;
};
template <>
struct UnarySquare<float4_t, float4_t, false>
{
UnarySquare(const int32_t divider = 1) { (void)divider; };
void operator()(float4_t& y, const float4_t& x) const { y = _mm_mul_ps(x, x); };
};
template <>
struct UnarySquare<float4_t, float4_t, true>
{
UnarySquare(const int32_t divider = 1) { divider_ = divider; };
void operator()(float4_t& y, const float4_t& x) const
{
y = _mm_div_ps(_mm_mul_ps(x, x), _mm_set1_ps(static_cast<float>(divider_)));
};
int32_t divider_ = 1;
};
template <>
struct UnarySquare<float8_t, float8_t, false>
{
UnarySquare(const int32_t divider = 1) { (void)divider; };
void operator()(float8_t& y, const float8_t& x) const { y = _mm256_mul_ps(x, x); };
};
template <>
struct UnarySquare<float8_t, float8_t, true>
{
UnarySquare(const int32_t divider = 1) { divider_ = divider; };
void operator()(float8_t& y, const float8_t& x) const
{
y = _mm256_div_ps(_mm256_mul_ps(x, x), _mm256_set1_ps(static_cast<float>(divider_)));
};
int32_t divider_ = 1;
};
template <typename Y, typename X>
struct UnaryAbs;
template <>
struct UnaryAbs<float, float>
{
UnaryAbs(const int32_t divider = 1) { (void)divider; };
void operator()(float& y, const float& x) const { y = abs(x); };
};
template <>
struct UnaryAbs<float4_t, float4_t>
{
UnaryAbs(const int32_t divider = 1) { (void)divider; };
void operator()(float4_t& y, const float4_t& x) const
{
__m128 Mask = _mm_castsi128_ps(_mm_set1_epi32(~0x80000000));
y = _mm_and_ps(Mask, x);
};
};
template <>
struct UnaryAbs<float8_t, float8_t>
{
UnaryAbs(const int32_t divider = 1) { (void)divider; };
void operator()(float8_t& y, const float8_t& x) const
{
__m256 Mask = _mm256_castsi256_ps(_mm256_set1_epi32(~0x80000000));
y = _mm256_and_ps(Mask, x);
};
};
template <typename Y, typename X>
struct UnarySqrt;
template <>
struct UnarySqrt<float, float>
{
void operator()(float& y, const float& x) const { y = sqrtf(x); };
};
template <>
struct UnarySqrt<float4_t, float4_t>
{
void operator()(float4_t& y, const float4_t& x) const { y = _mm_sqrt_ps(x); };
};
template <>
struct UnarySqrt<float8_t, float8_t>
{
void operator()(float8_t& y, const float8_t& x) const { y = _mm256_sqrt_ps(x); };
};
} // namespace element_wise
} // namespace cpu
} // namespace tensor_operation
} // namespace ck
#pragma once
#include "data_type_cpu.hpp"
namespace ck {
namespace tensor_operation {
namespace cpu {
namespace element_wise {
using float8_t = ck::cpu::float8_t;
using float4_t = ck::cpu::float4_t;
struct PassThrough
{
void operator()(float& y, const float& x) const { y = Apply(x); }
void operator()(float4_t& y, const float4_t& x) const { y = Apply(x); }
void operator()(float8_t& y, const float8_t& x) const { y = Apply(x); }
float Apply(const float& x) const { return x; }
float4_t Apply(const float4_t& x) const { return x; }
float8_t Apply(const float8_t& x) const { return x; }
static constexpr const char* Name() { return "PassThrough"; }
};
struct Add
{
void operator()(float& y, const float& x0, const float& x1) const { y = Apply(x0, x1); }
void operator()(float4_t& y, const float4_t& x0, const float4_t& x1) const
{
y = Apply(x0, x1);
}
void operator()(float8_t& y, const float8_t& x0, const float8_t& x1) const
{
y = Apply(x0, x1);
}
float Apply(const float& x0, const float& x1) const { return x0 + x1; }
float4_t Apply(const float4_t& x0, const float4_t& x1) const { return _mm_add_ps(x0, x1); }
float8_t Apply(const float8_t& x0, const float8_t& x1) const { return _mm256_add_ps(x0, x1); }
static constexpr const char* Name() { return "Add"; }
};
struct Relu
{
void operator()(float& y, const float& x) const { y = Apply(x); }
void operator()(float4_t& y, const float4_t& x) const { y = Apply(x); }
void operator()(float8_t& y, const float8_t& x) const { y = Apply(x); }
float Apply(const float& x) const { return x > 0 ? x : 0; }
float4_t Apply(const float4_t& x) const { return _mm_max_ps(x, _mm_setzero_ps()); }
float8_t Apply(const float8_t& x) const { return _mm256_max_ps(x, _mm256_setzero_ps()); }
static constexpr const char* Name() { return "Relu"; }
};
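// Illustrative sketch (not part of this commit): every functor here exposes a scalar and an
// AVX path through the same operator(), so the identical op can be fused into both the
// vector main loop and the scalar tail of a copy. The helper below is hypothetical.
inline void relu_usage_sketch()
{
    Relu relu;

    float ys;
    relu(ys, -1.5f); // scalar path -> ys == 0.0f

    float in[8] = {-2.f, -1.f, 0.f, 1.f, 2.f, 3.f, -4.f, 5.f};
    float out[8];
    float8_t yv;
    relu(yv, _mm256_loadu_ps(in)); // AVX2 path, 8 lanes at once
    _mm256_storeu_ps(out, yv);     // out == {0, 0, 0, 1, 2, 3, 0, 5}
}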
struct AlphaBetaAdd
{
AlphaBetaAdd(float alpha, float beta) : alpha_(alpha), beta_(beta) {}
void operator()(float& y, const float& x0, const float& x1) const { y = Apply(x0, x1); }
void operator()(float4_t& y, const float4_t& x0, const float4_t& x1) const
{
y = Apply(x0, x1);
}
void operator()(float8_t& y, const float8_t& x0, const float8_t& x1) const
{
y = Apply(x0, x1);
}
float Apply(const float& x0, const float& x1) const { return alpha_ * x0 + beta_ * x1; }
float4_t Apply(const float4_t& x0, const float4_t& x1) const
{
return _mm_add_ps(_mm_mul_ps(x0, _mm_set1_ps(alpha_)), _mm_mul_ps(x1, _mm_set1_ps(beta_)));
}
float8_t Apply(const float8_t& x0, const float8_t& x1) const
{
return _mm256_add_ps(_mm256_mul_ps(x0, _mm256_set1_ps(alpha_)),
_mm256_mul_ps(x1, _mm256_set1_ps(beta_)));
}
static constexpr const char* Name() { return "AlphaBetaAdd"; }
float alpha_;
float beta_;
};
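// Illustrative sketch (not part of this commit): AlphaBetaAdd is the classic GEMM epilogue
// y = alpha * acc + beta * c_old; the sample values below are arbitrary.
inline void alpha_beta_add_usage_sketch()
{
    AlphaBetaAdd op(/*alpha*/ 2.0f, /*beta*/ 0.5f);
    float y;
    op(y, /*acc*/ 3.0f, /*c_old*/ 4.0f); // y == 2 * 3 + 0.5 * 4 == 8.0f
}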
struct AddRelu
{
void operator()(float& y, const float& x0, const float& x1) const { y = Apply(x0, x1); }
void operator()(float4_t& y, const float4_t& x0, const float4_t& x1) const
{
y = Apply(x0, x1);
}
void operator()(float8_t& y, const float8_t& x0, const float8_t& x1) const
{
y = Apply(x0, x1);
}
float Apply(const float& x0, const float& x1) const
{
const float a = x0 + x1;
return a > 0 ? a : 0;
}
float4_t Apply(const float4_t& x0, const float4_t& x1) const
{
return _mm_max_ps(_mm_add_ps(x0, x1), _mm_setzero_ps());
}
float8_t Apply(const float8_t& x0, const float8_t& x1) const
{
return _mm256_max_ps(_mm256_add_ps(x0, x1), _mm256_setzero_ps());
}
static constexpr const char* Name() { return "AddRelu"; }
};
struct AddReluAdd
{
void operator()(float& y, const float& x0, const float& x1, const float& x2) const
{
float a = x0 + x1;
float b = a > 0 ? a : 0;
float c = b + x2;
y = c;
}
void operator()(float4_t& y, const float4_t& x0, const float4_t& x1, const float4_t& x2) const
{
float4_t a = _mm_add_ps(x0, x1);
float4_t b = _mm_max_ps(a, _mm_setzero_ps());
y = _mm_add_ps(b, x2);
}
void operator()(float8_t& y, const float8_t& x0, const float8_t& x1, const float8_t& x2) const
{
float8_t a = _mm256_add_ps(x0, x1);
float8_t b = _mm256_max_ps(a, _mm256_setzero_ps());
y = _mm256_add_ps(b, x2);
}
static constexpr const char* Name() { return "AddReluAdd"; }
};
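// Illustrative sketch (not part of this commit): AddReluAdd is the usual
// bias + activation + residual fusion, y = max(x0 + x1, 0) + x2 per lane.
// The values below are made up.
inline void add_relu_add_usage_sketch()
{
    AddReluAdd op;

    float y;
    op(y, /*acc*/ -3.0f, /*bias*/ 1.0f, /*residual*/ 0.5f); // y == 0.5f

    float8_t yv;
    op(yv, _mm256_set1_ps(-3.0f), _mm256_set1_ps(1.0f), _mm256_set1_ps(0.5f));
    // every lane of yv == 0.5f
}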
// Unary operators are usually applied element-wise before/after the reduction is executed on
// the elements. They make it easy to implement reduction types such as AVG, NRM1 and NRM2.
template <typename Y, typename X, bool HasDividing = false>
struct UnaryIdentic;
template <>
struct UnaryIdentic<float, float, false>
{
UnaryIdentic(const int32_t divider = 1) { (void)divider; };
void operator()(float& y, const float& x) const { y = x; };
};
template <>
struct UnaryIdentic<float, float, true>
{
UnaryIdentic(const int32_t divider = 1) { divider_ = divider; };
void operator()(float& y, const float& x) const { y = x / type_convert<float>(divider_); };
int32_t divider_ = 1;
};
template <>
struct UnaryIdentic<float4_t, float4_t, false>
{
UnaryIdentic(const int32_t divider = 1) { (void)divider; };
void operator()(float4_t& y, const float4_t& x) const { y = x; };
};
template <>
struct UnaryIdentic<float4_t, float4_t, true>
{
UnaryIdentic(const int32_t divider = 1) { divider_ = divider; };
void operator()(float4_t& y, const float4_t& x) const
{
y = _mm_div_ps(x, _mm_set1_ps(static_cast<float>(divider_)));
};
int32_t divider_ = 1;
};
template <>
struct UnaryIdentic<float8_t, float8_t, false>
{
UnaryIdentic(const int32_t divider = 1) { (void)divider; };
void operator()(float8_t& y, const float8_t& x) const { y = x; };
};
template <>
struct UnaryIdentic<float8_t, float8_t, true>
{
UnaryIdentic(const int32_t divider = 1) { divider_ = divider; };
void operator()(float8_t& y, const float8_t& x) const
{
y = _mm256_div_ps(x, _mm256_set1_ps(static_cast<float>(divider_)));
};
int32_t divider_ = 1;
};
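// Illustrative sketch (not part of this commit): with HasDividing == true the functor
// finalizes an AVG reduction by dividing the accumulated sum by the element count.
// The count of 4 below is arbitrary.
inline void unary_identic_avg_sketch()
{
    UnaryIdentic<float, float, true> avg_finalize(/*divider*/ 4);
    float y;
    avg_finalize(y, /*sum*/ 10.0f); // y == 2.5f

    UnaryIdentic<float8_t, float8_t, true> avg_finalize_v(4);
    float8_t yv;
    avg_finalize_v(yv, _mm256_set1_ps(10.0f)); // every lane == 2.5f
}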
template <typename Y, typename X, bool HasDividing = false>
struct UnarySquare;
template <>
struct UnarySquare<float, float, false>
{
UnarySquare(const int32_t divider = 1) { (void)divider; };
void operator()(float& y, const float& x) const { y = x * x; };
};
template <>
struct UnarySquare<float, float, true>
{
UnarySquare(const int32_t divider = 1) { divider_ = divider; };
void operator()(float& y, const float& x) const { y = x * x / type_convert<float>(divider_); };
int32_t divider_ = 1;
};
template <>
struct UnarySquare<float4_t, float4_t, false>
{
UnarySquare(const int32_t divider = 1) { (void)divider; };
void operator()(float4_t& y, const float4_t& x) const { y = _mm_mul_ps(x, x); };
};
template <>
struct UnarySquare<float4_t, float4_t, true>
{
UnarySquare(const int32_t divider = 1) { divider_ = divider; };
void operator()(float4_t& y, const float4_t& x) const
{
y = _mm_div_ps(_mm_mul_ps(x, x), _mm_set1_ps(static_cast<float>(divider_)));
};
int32_t divider_ = 1;
};
template <>
struct UnarySquare<float8_t, float8_t, false>
{
UnarySquare(const int32_t divider = 1) { (void)divider; };
void operator()(float8_t& y, const float8_t& x) const { y = _mm256_mul_ps(x, x); };
};
template <>
struct UnarySquare<float8_t, float8_t, true>
{
UnarySquare(const int32_t divider = 1) { divider_ = divider; };
void operator()(float8_t& y, const float8_t& x) const
{
y = _mm256_div_ps(_mm256_mul_ps(x, x), _mm256_set1_ps(static_cast<float>(divider_)));
};
int32_t divider_ = 1;
};
template <typename Y, typename X>
struct UnaryAbs;
template <>
struct UnaryAbs<float, float>
{
UnaryAbs(const int32_t divider = 1) { (void)divider; };
void operator()(float& y, const float& x) const { y = fabsf(x); };
};
template <>
struct UnaryAbs<float4_t, float4_t>
{
UnaryAbs(const int32_t divider = 1) { (void)divider; };
void operator()(float4_t& y, const float4_t& x) const
{
__m128 Mask = _mm_castsi128_ps(_mm_set1_epi32(~0x80000000));
y = _mm_and_ps(Mask, x);
};
};
template <>
struct UnaryAbs<float8_t, float8_t>
{
UnaryAbs(const int32_t divider = 1) { (void)divider; };
void operator()(float8_t& y, const float8_t& x) const
{
__m256 Mask = _mm256_castsi256_ps(_mm256_set1_epi32(~0x80000000));
y = _mm256_and_ps(Mask, x);
};
};
template <typename Y, typename X>
struct UnarySqrt;
template <>
struct UnarySqrt<float, float>
{
void operator()(float& y, const float& x) const { y = sqrtf(x); };
};
template <>
struct UnarySqrt<float4_t, float4_t>
{
void operator()(float4_t& y, const float4_t& x) const { y = _mm_sqrt_ps(x); };
};
template <>
struct UnarySqrt<float8_t, float8_t>
{
void operator()(float8_t& y, const float8_t& x) const { y = _mm256_sqrt_ps(x); };
};
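// Illustrative sketch (not part of this commit): an NRM2-style reduction composes
// UnarySquare on the way into the accumulator and UnarySqrt on the way out.
// The 3-element input is arbitrary.
inline float nrm2_composition_sketch()
{
    const float x[3] = {3.0f, 0.0f, 4.0f};

    UnarySquare<float, float, false> square;
    UnarySqrt<float, float> sqrt_op;

    float acc = 0.0f;
    for(int i = 0; i < 3; ++i)
    {
        float sq;
        square(sq, x[i]);
        acc += sq;
    }

    float y;
    sqrt_op(y, acc); // y == sqrt(9 + 0 + 16) == 5.0f
    return y;
}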
} // namespace element_wise
} // namespace cpu
} // namespace tensor_operation
} // namespace ck
......@@ -128,6 +128,51 @@ struct GridwiseGemmAvx2_MxN
return make_naive_tensor_descriptor_packed(make_tuple(m_per_blk, n_per_blk));
}
static auto GetAMultiIndex(const ck::index_t m_per_blk, const ck::index_t k_per_blk)
{
if constexpr(std::is_same<typename ThreadwiseGemm_Dispatch::MatrixALayout,
ck::tensor_layout::gemm::RowMajor>::value)
{
// A : M, K
return ck::make_multi_index(m_per_blk, k_per_blk);
}
else
{
// A : K, M
return ck::make_multi_index(
k_per_blk,
math::integer_least_multiple(m_per_blk,
ThreadwiseGemm_Dispatch::MatrixAMinVectorSize));
}
}
static auto GetBMultiIndex(const ck::index_t k_per_blk, const ck::index_t n_per_blk)
{
// n_per_blk should be a multiple of 8
if constexpr(std::is_same<typename ThreadwiseGemm_Dispatch::MatrixBLayout,
ck::tensor_layout::gemm::RowMajor>::value)
{
// B : K, N
return ck::make_multi_index(
k_per_blk,
math::integer_least_multiple(n_per_blk,
ThreadwiseGemm_Dispatch::MatrixBMinVectorSize));
}
else
{
// B : N/8, K, N8
return ck::make_multi_index(
math::integer_divide_ceil(n_per_blk, ThreadwiseGemm_Dispatch::MatrixBMinVectorSize),
k_per_blk,
ThreadwiseGemm_Dispatch::MatrixBMinVectorSize);
}
}
static auto GetCMultiIndex(const ck::index_t m_per_blk, const ck::index_t n_per_blk)
{
return ck::make_multi_index(m_per_blk, n_per_blk);
}
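// Rough sketch (not part of this commit), assuming MatrixBMinVectorSize == 8: for a
// column-major B block with k_per_blk = 32 and n_per_blk = 20, GetBMultiIndex packs the
// block as N/8 x K x 8 panels, so N is padded up to the next multiple of 8. The arithmetic
// below mirrors integer_divide_ceil / integer_least_multiple for those sample sizes.
static_assert((20 + 8 - 1) / 8 == 3, "20 columns occupy 3 panels of 8 lanes");
static_assert(((20 + 8 - 1) / 8) * 8 == 24, "and are padded up to 24 columns");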
static constexpr bool CheckValidity(const AGridDesc& a_grid_desc,
const BGridDesc& b_grid_desc,
const CGridDesc& c_grid_desc)
......@@ -300,14 +345,18 @@ struct GridwiseGemmAvx2_MxN
UseCLocalBuffer ? GetCBlockDescriptor(mc_size, nc_size) : c_grid_desc;
if constexpr(UseCLocalBuffer)
{
c_threadwise_copy.SetDstSliceOrigin(c_grid_desc,
ck::make_multi_index(i_mc, i_nc));
// c_threadwise_copy.SetDstSliceOrigin(c_grid_desc,
// ck::make_multi_index(i_mc, i_nc));
}
else
{
c_threadwise_copy.SetSrcSliceOrigin(c_block_desc,
ck::make_multi_index(i_mc, i_nc));
c_threadwise_copy.Run(c_block_desc, c_block_buf, c_grid_desc, c_grid_buf);
c_threadwise_copy.RunRead(c_block_desc,
c_block_buf,
c_grid_desc,
c_grid_buf,
GetCMultiIndex(mc_size, nc_size));
}
for(ck::index_t i_kc = 0; i_kc < GemmK; i_kc += k_per_block)
......@@ -317,8 +366,16 @@ struct GridwiseGemmAvx2_MxN
auto a_block_desc = GetABlockDescriptor(mc_size, kc_size);
auto b_block_desc = GetBBlockDescriptor(kc_size, nc_size);
a_threadwise_copy.Run(a_grid_desc, a_grid_buf, a_block_desc, a_block_buf);
b_threadwise_copy.Run(b_grid_desc, b_grid_buf, b_block_desc, b_block_buf);
a_threadwise_copy.RunRead(a_grid_desc,
a_grid_buf,
a_block_desc,
a_block_buf,
GetAMultiIndex(mc_size, kc_size));
b_threadwise_copy.RunRead(b_grid_desc,
b_grid_buf,
b_block_desc,
b_block_buf,
GetBMultiIndex(kc_size, nc_size));
blockwise_gemm.Run(a_block_desc,
a_block_buf,
......@@ -338,8 +395,14 @@ struct GridwiseGemmAvx2_MxN
}
}
if constexpr(UseCLocalBuffer)
c_threadwise_copy.Run(c_block_desc, c_block_buf, c_grid_desc, c_grid_buf);
// if constexpr(UseCLocalBuffer)
c_threadwise_copy.SetDstSliceOrigin(c_grid_desc,
ck::make_multi_index(i_mc, i_nc));
c_threadwise_copy.RunWrite(c_block_desc,
c_block_buf,
c_grid_desc,
c_grid_buf,
GetCMultiIndex(mc_size, nc_size));
}
}
}
......@@ -415,7 +478,11 @@ struct GridwiseGemmAvx2_MxN
ck::index_t kc_size = ck::math::min(GemmK - i_kc, k_per_block);
auto a_block_desc = GetABlockDescriptor(mc_size, kc_size);
a_threadwise_copy.Run(a_grid_desc, a_grid_buf, a_block_desc, a_block_buf);
a_threadwise_copy.RunRead(a_grid_desc,
a_grid_buf,
a_block_desc,
a_block_buf,
GetAMultiIndex(mc_size, kc_size));
b_threadwise_copy.SetSrcSliceOrigin(b_grid_desc,
ck::make_multi_index(0, i_kc, 0));
......@@ -429,8 +496,11 @@ struct GridwiseGemmAvx2_MxN
nc_size, ThreadwiseGemm_Dispatch::MatrixBMinVectorSize);
auto b_block_desc = GetBBlockDescriptor(kc_size, nc_size);
b_threadwise_copy.Run(
b_grid_desc, b_grid_buf, b_block_desc, b_block_buf);
b_threadwise_copy.RunRead(b_grid_desc,
b_grid_buf,
b_block_desc,
b_block_buf,
GetBMultiIndex(kc_size, nc_size));
auto c_block_desc = UseCLocalBuffer
? GetCBlockDescriptor(mc_size, nc_size)
......@@ -440,8 +510,11 @@ struct GridwiseGemmAvx2_MxN
{
c_threadwise_copy.SetSrcSliceOrigin(
c_block_desc, ck::make_multi_index(i_mc, i_nc));
c_threadwise_copy.Run(
c_block_desc, c_block_buf, c_grid_desc, c_grid_buf);
c_threadwise_copy.RunRead(c_block_desc,
c_block_buf,
c_grid_desc,
c_grid_buf,
GetCMultiIndex(mc_size, nc_size));
}
blockwise_gemm.Run(a_block_desc,
......@@ -456,14 +529,36 @@ struct GridwiseGemmAvx2_MxN
i_kc != 0);
if((i_nc + n_per_block) < GemmN)
{
b_threadwise_copy.MoveSrcSliceWindow(b_grid_desc, b_move_k_step);
}
if constexpr(UseCLocalBuffer)
{
c_threadwise_copy.SetDstSliceOrigin(
c_grid_desc, ck::make_multi_index(i_mc, i_nc));
c_threadwise_copy.Run(
c_block_desc, c_block_buf, c_grid_desc, c_grid_buf);
c_threadwise_copy.RunWrite(c_block_desc,
c_block_buf,
c_grid_desc,
c_grid_buf,
GetCMultiIndex(mc_size, nc_size));
}
else
{
// only write on the last K iteration, since the RunWrite here just applies the
// elementwise op from global memory to global memory
if((i_kc + k_per_block) >= GemmK)
{
c_threadwise_copy.SetDstSliceOrigin(
c_grid_desc, ck::make_multi_index(i_mc, i_nc));
c_threadwise_copy.RunWrite(c_block_desc,
c_block_buf,
c_grid_desc,
c_grid_buf,
GetCMultiIndex(mc_size, nc_size));
}
}
}
......
......@@ -8,7 +8,7 @@
#include "tensor_descriptor_helper.hpp"
#include "tensor_space_filling_curve.hpp"
#include "dynamic_buffer_cpu.hpp"
#include <immintrin.h>
#include "element_wise_operation_cpu.hpp"
#include "convolution_forward_specialization_cpu.hpp"
#include <immintrin.h>
......@@ -17,7 +17,8 @@ namespace cpu {
namespace avx2_util {
inline void memcpy32_avx2(void* dst, const void* src, const ck::index_t n)
template <typename ElementwiseOp>
void memcpy32_avx2(void* dst, const void* src, const ck::index_t n, const ElementwiseOp& element_op)
{
// 16-8-4-2-1 pattern
ck::index_t i_n = n;
......@@ -25,33 +26,33 @@ inline void memcpy32_avx2(void* dst, const void* src, const ck::index_t n)
const float* p_src = reinterpret_cast<const float*>(src);
while(i_n >= 16)
{
_mm256_storeu_ps(p_dst + 0, _mm256_loadu_ps(p_src + 0));
_mm256_storeu_ps(p_dst + 8, _mm256_loadu_ps(p_src + 8));
_mm256_storeu_ps(p_dst + 0, element_op.Apply(_mm256_loadu_ps(p_src + 0)));
_mm256_storeu_ps(p_dst + 8, element_op.Apply(_mm256_loadu_ps(p_src + 8)));
p_dst += 16;
p_src += 16;
i_n -= 16;
}
if(i_n & 8)
{
_mm256_storeu_ps(p_dst, _mm256_loadu_ps(p_src));
_mm256_storeu_ps(p_dst, element_op.Apply(_mm256_loadu_ps(p_src)));
p_dst += 8;
p_src += 8;
}
if(i_n & 4)
{
_mm_storeu_ps(p_dst, _mm_loadu_ps(p_src));
_mm_storeu_ps(p_dst, element_op.Apply(_mm_loadu_ps(p_src)));
p_dst += 4;
p_src += 4;
}
if(i_n & 2)
{
_mm_storeu_si64(p_dst, _mm_loadu_si64(p_src));
// Apply() has no __m128i overload, so handle the 2-element tail with the scalar path
p_dst[0] = element_op.Apply(p_src[0]);
p_dst[1] = element_op.Apply(p_src[1]);
p_dst += 2;
p_src += 2;
}
if(i_n & 1)
{
*p_dst = *p_src;
*p_dst = element_op.Apply(*p_src);
}
}
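// Illustrative sketch (not part of this commit): any functor with the matching Apply()
// overloads can now be fused into the copy; here a ReLU is applied while the data is moved.
// The buffer size of 19 is arbitrary and exercises the 16 + 2 + 1 tail handling.
inline void memcpy32_avx2_relu_sketch()
{
    float src[19], dst[19];
    for(int i = 0; i < 19; ++i)
        src[i] = static_cast<float>(i - 9); // mix of negative and positive values

    ck::tensor_operation::cpu::element_wise::Relu relu;
    memcpy32_avx2(dst, src, 19, relu); // dst[i] == max(src[i], 0)
}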
......@@ -90,8 +91,12 @@ inline void memset32_avx2(void* dst, const int32_t value, const ck::index_t n)
}
}
inline void
transpose8x8_avx2(void* dst, ck::index_t stride_dst, const void* src, ck::index_t stride_src)
template <typename ElementwiseOp>
void transpose8x8_avx2(void* dst,
ck::index_t stride_dst,
const void* src,
ck::index_t stride_src,
const ElementwiseOp& element_op)
{
// TODO: use vinsertf128 for better port usage. vpermf128 is slow
__m256 r0, r1, r2, r3, r4, r5, r6, r7;
......@@ -100,14 +105,14 @@ transpose8x8_avx2(void* dst, ck::index_t stride_dst, const void* src, ck::index_
float* p_dst = reinterpret_cast<float*>(dst);
const float* p_src = reinterpret_cast<const float*>(src);
r0 = _mm256_loadu_ps(p_src + 0 * stride_src);
r1 = _mm256_loadu_ps(p_src + 1 * stride_src);
r2 = _mm256_loadu_ps(p_src + 2 * stride_src);
r3 = _mm256_loadu_ps(p_src + 3 * stride_src);
r4 = _mm256_loadu_ps(p_src + 4 * stride_src);
r5 = _mm256_loadu_ps(p_src + 5 * stride_src);
r6 = _mm256_loadu_ps(p_src + 6 * stride_src);
r7 = _mm256_loadu_ps(p_src + 7 * stride_src);
r0 = element_op.Apply(_mm256_loadu_ps(p_src + 0 * stride_src));
r1 = element_op.Apply(_mm256_loadu_ps(p_src + 1 * stride_src));
r2 = element_op.Apply(_mm256_loadu_ps(p_src + 2 * stride_src));
r3 = element_op.Apply(_mm256_loadu_ps(p_src + 3 * stride_src));
r4 = element_op.Apply(_mm256_loadu_ps(p_src + 4 * stride_src));
r5 = element_op.Apply(_mm256_loadu_ps(p_src + 5 * stride_src));
r6 = element_op.Apply(_mm256_loadu_ps(p_src + 6 * stride_src));
r7 = element_op.Apply(_mm256_loadu_ps(p_src + 7 * stride_src));
t0 = _mm256_unpacklo_ps(r0, r1);
t1 = _mm256_unpackhi_ps(r0, r1);
......@@ -354,11 +359,12 @@ struct ThreadwiseTensorSliceTransferAvx2Specialization_ConvFwd_In_NHWC
void SetDstSliceOrigin(const DstDesc&, const Index&) {}
template <typename SrcBuffer, typename DstBuffer>
void Run(const SrcDesc& src_desc,
const SrcBuffer& src_buf,
const DstDesc& dst_desc,
DstBuffer& dst_buf)
template <typename SrcBuffer, typename DstBuffer, typename SliceLengths>
void RunRead(const SrcDesc& src_desc,
const SrcBuffer& src_buf,
const DstDesc& dst_desc,
DstBuffer& dst_buf,
const SliceLengths& slice_length)
{
if constexpr(BypassTransfer)
{
......@@ -385,14 +391,22 @@ struct ThreadwiseTensorSliceTransferAvx2Specialization_ConvFwd_In_NHWC
// standard 8-4-2-1 pattern
while(i_m_itr >= 8)
{
avx2_util::memcpy32_avx2(p_dst + 0 * k_per_block, p_src + 0 * C, k_per_block);
avx2_util::memcpy32_avx2(p_dst + 1 * k_per_block, p_src + 1 * C, k_per_block);
avx2_util::memcpy32_avx2(p_dst + 2 * k_per_block, p_src + 2 * C, k_per_block);
avx2_util::memcpy32_avx2(p_dst + 3 * k_per_block, p_src + 3 * C, k_per_block);
avx2_util::memcpy32_avx2(p_dst + 4 * k_per_block, p_src + 4 * C, k_per_block);
avx2_util::memcpy32_avx2(p_dst + 5 * k_per_block, p_src + 5 * C, k_per_block);
avx2_util::memcpy32_avx2(p_dst + 6 * k_per_block, p_src + 6 * C, k_per_block);
avx2_util::memcpy32_avx2(p_dst + 7 * k_per_block, p_src + 7 * C, k_per_block);
avx2_util::memcpy32_avx2(
p_dst + 0 * k_per_block, p_src + 0 * C, k_per_block, element_op_);
avx2_util::memcpy32_avx2(
p_dst + 1 * k_per_block, p_src + 1 * C, k_per_block, element_op_);
avx2_util::memcpy32_avx2(
p_dst + 2 * k_per_block, p_src + 2 * C, k_per_block, element_op_);
avx2_util::memcpy32_avx2(
p_dst + 3 * k_per_block, p_src + 3 * C, k_per_block, element_op_);
avx2_util::memcpy32_avx2(
p_dst + 4 * k_per_block, p_src + 4 * C, k_per_block, element_op_);
avx2_util::memcpy32_avx2(
p_dst + 5 * k_per_block, p_src + 5 * C, k_per_block, element_op_);
avx2_util::memcpy32_avx2(
p_dst + 6 * k_per_block, p_src + 6 * C, k_per_block, element_op_);
avx2_util::memcpy32_avx2(
p_dst + 7 * k_per_block, p_src + 7 * C, k_per_block, element_op_);
i_m_itr -= 8;
p_dst += 8 * k_per_block;
......@@ -400,10 +414,14 @@ struct ThreadwiseTensorSliceTransferAvx2Specialization_ConvFwd_In_NHWC
}
if(i_m_itr & 4)
{
avx2_util::memcpy32_avx2(p_dst + 0 * k_per_block, p_src + 0 * C, k_per_block);
avx2_util::memcpy32_avx2(p_dst + 1 * k_per_block, p_src + 1 * C, k_per_block);
avx2_util::memcpy32_avx2(p_dst + 2 * k_per_block, p_src + 2 * C, k_per_block);
avx2_util::memcpy32_avx2(p_dst + 3 * k_per_block, p_src + 3 * C, k_per_block);
avx2_util::memcpy32_avx2(
p_dst + 0 * k_per_block, p_src + 0 * C, k_per_block, element_op_);
avx2_util::memcpy32_avx2(
p_dst + 1 * k_per_block, p_src + 1 * C, k_per_block, element_op_);
avx2_util::memcpy32_avx2(
p_dst + 2 * k_per_block, p_src + 2 * C, k_per_block, element_op_);
avx2_util::memcpy32_avx2(
p_dst + 3 * k_per_block, p_src + 3 * C, k_per_block, element_op_);
p_dst += 4 * k_per_block;
p_src += 4 * C;
......@@ -411,8 +429,10 @@ struct ThreadwiseTensorSliceTransferAvx2Specialization_ConvFwd_In_NHWC
if(i_m_itr & 2)
{
avx2_util::memcpy32_avx2(p_dst + 0 * k_per_block, p_src + 0 * C, k_per_block);
avx2_util::memcpy32_avx2(p_dst + 1 * k_per_block, p_src + 1 * C, k_per_block);
avx2_util::memcpy32_avx2(
p_dst + 0 * k_per_block, p_src + 0 * C, k_per_block, element_op_);
avx2_util::memcpy32_avx2(
p_dst + 1 * k_per_block, p_src + 1 * C, k_per_block, element_op_);
p_dst += 2 * k_per_block;
p_src += 2 * C;
......@@ -420,7 +440,8 @@ struct ThreadwiseTensorSliceTransferAvx2Specialization_ConvFwd_In_NHWC
if(i_m_itr & 1)
{
avx2_util::memcpy32_avx2(p_dst + 0 * k_per_block, p_src + 0 * C, k_per_block);
avx2_util::memcpy32_avx2(
p_dst + 0 * k_per_block, p_src + 0 * C, k_per_block, element_op_);
}
}
else if constexpr(ConvForwardSpecialization ==
......@@ -431,7 +452,7 @@ struct ThreadwiseTensorSliceTransferAvx2Specialization_ConvFwd_In_NHWC
ck::index_t i_ho_itr = i_ho;
while(i_m_itr > 0)
{
avx2_util::memcpy32_avx2(p_dst, p_src, k_per_block);
avx2_util::memcpy32_avx2(p_dst, p_src, k_per_block, element_op_);
p_dst += k_per_block;
i_wo_itr++;
p_src += input_offset_acc_wi;
......@@ -468,7 +489,7 @@ struct ThreadwiseTensorSliceTransferAvx2Specialization_ConvFwd_In_NHWC
{
if((*reinterpret_cast<uint32_t*>(&i_hi_itr) < Hi) &&
(*reinterpret_cast<uint32_t*>(&i_wi_itr) < Wi))
avx2_util::memcpy32_avx2(p_dst, p_src, k_per_block);
avx2_util::memcpy32_avx2(p_dst, p_src, k_per_block, element_op_);
else
avx2_util::memset32_avx2(p_dst, 0, k_per_block);
......@@ -523,7 +544,8 @@ struct ThreadwiseTensorSliceTransferAvx2Specialization_ConvFwd_In_NHWC
if((*reinterpret_cast<uint32_t*>(&i_hi_itr_k) < Hi) &&
(*reinterpret_cast<uint32_t*>(&i_wi_itr_k) < Wi))
avx2_util::memcpy32_avx2(p_dst_k, p_src_k, current_k_block);
avx2_util::memcpy32_avx2(
p_dst_k, p_src_k, current_k_block, element_op_);
else
avx2_util::memset32_avx2(p_dst_k, 0, current_k_block);
......@@ -730,8 +752,12 @@ struct ThreadwiseTensorSliceTransferAvx2Specialization_ConvFwd_Wei_NHWC
void SetDstSliceOrigin(const DstDesc&, const Index&) {}
template <typename SrcBuffer, typename DstBuffer>
void Run(const SrcDesc&, const SrcBuffer& src_buf, const DstDesc& dst_desc, DstBuffer& dst_buf)
template <typename SrcBuffer, typename DstBuffer, typename SliceLengths>
void RunRead(const SrcDesc&,
const SrcBuffer& src_buf,
const DstDesc& dst_desc,
DstBuffer& dst_buf,
const SliceLengths& slice_length)
{
if constexpr(BypassTransfer)
{
......@@ -766,85 +792,85 @@ struct ThreadwiseTensorSliceTransferAvx2Specialization_ConvFwd_Wei_NHWC
float* p_dst_k = p_dst;
while(i_k_itr >= 8)
{
avx2_util::transpose8x8_avx2(p_dst_k, 8, p_src_k, GemmK);
avx2_util::transpose8x8_avx2(p_dst_k, 8, p_src_k, GemmK, element_op_);
p_dst_k += 8 * 8;
p_src_k += 8;
i_k_itr -= 8;
}
if(i_k_itr & 4)
{
p_dst_k[0 * 8 + 0] = p_src_k[0 * GemmK + 0];
p_dst_k[0 * 8 + 1] = p_src_k[1 * GemmK + 0];
p_dst_k[0 * 8 + 2] = p_src_k[2 * GemmK + 0];
p_dst_k[0 * 8 + 3] = p_src_k[3 * GemmK + 0];
p_dst_k[0 * 8 + 4] = p_src_k[4 * GemmK + 0];
p_dst_k[0 * 8 + 5] = p_src_k[5 * GemmK + 0];
p_dst_k[0 * 8 + 6] = p_src_k[6 * GemmK + 0];
p_dst_k[0 * 8 + 7] = p_src_k[7 * GemmK + 0];
p_dst_k[1 * 8 + 0] = p_src_k[0 * GemmK + 1];
p_dst_k[1 * 8 + 1] = p_src_k[1 * GemmK + 1];
p_dst_k[1 * 8 + 2] = p_src_k[2 * GemmK + 1];
p_dst_k[1 * 8 + 3] = p_src_k[3 * GemmK + 1];
p_dst_k[1 * 8 + 4] = p_src_k[4 * GemmK + 1];
p_dst_k[1 * 8 + 5] = p_src_k[5 * GemmK + 1];
p_dst_k[1 * 8 + 6] = p_src_k[6 * GemmK + 1];
p_dst_k[1 * 8 + 7] = p_src_k[7 * GemmK + 1];
p_dst_k[2 * 8 + 0] = p_src_k[0 * GemmK + 2];
p_dst_k[2 * 8 + 1] = p_src_k[1 * GemmK + 2];
p_dst_k[2 * 8 + 2] = p_src_k[2 * GemmK + 2];
p_dst_k[2 * 8 + 3] = p_src_k[3 * GemmK + 2];
p_dst_k[2 * 8 + 4] = p_src_k[4 * GemmK + 2];
p_dst_k[2 * 8 + 5] = p_src_k[5 * GemmK + 2];
p_dst_k[2 * 8 + 6] = p_src_k[6 * GemmK + 2];
p_dst_k[2 * 8 + 7] = p_src_k[7 * GemmK + 2];
p_dst_k[3 * 8 + 0] = p_src_k[0 * GemmK + 3];
p_dst_k[3 * 8 + 1] = p_src_k[1 * GemmK + 3];
p_dst_k[3 * 8 + 2] = p_src_k[2 * GemmK + 3];
p_dst_k[3 * 8 + 3] = p_src_k[3 * GemmK + 3];
p_dst_k[3 * 8 + 4] = p_src_k[4 * GemmK + 3];
p_dst_k[3 * 8 + 5] = p_src_k[5 * GemmK + 3];
p_dst_k[3 * 8 + 6] = p_src_k[6 * GemmK + 3];
p_dst_k[3 * 8 + 7] = p_src_k[7 * GemmK + 3];
p_dst_k[0 * 8 + 0] = element_op_.Apply(p_src_k[0 * GemmK + 0]);
p_dst_k[0 * 8 + 1] = element_op_.Apply(p_src_k[1 * GemmK + 0]);
p_dst_k[0 * 8 + 2] = element_op_.Apply(p_src_k[2 * GemmK + 0]);
p_dst_k[0 * 8 + 3] = element_op_.Apply(p_src_k[3 * GemmK + 0]);
p_dst_k[0 * 8 + 4] = element_op_.Apply(p_src_k[4 * GemmK + 0]);
p_dst_k[0 * 8 + 5] = element_op_.Apply(p_src_k[5 * GemmK + 0]);
p_dst_k[0 * 8 + 6] = element_op_.Apply(p_src_k[6 * GemmK + 0]);
p_dst_k[0 * 8 + 7] = element_op_.Apply(p_src_k[7 * GemmK + 0]);
p_dst_k[1 * 8 + 0] = element_op_.Apply(p_src_k[0 * GemmK + 1]);
p_dst_k[1 * 8 + 1] = element_op_.Apply(p_src_k[1 * GemmK + 1]);
p_dst_k[1 * 8 + 2] = element_op_.Apply(p_src_k[2 * GemmK + 1]);
p_dst_k[1 * 8 + 3] = element_op_.Apply(p_src_k[3 * GemmK + 1]);
p_dst_k[1 * 8 + 4] = element_op_.Apply(p_src_k[4 * GemmK + 1]);
p_dst_k[1 * 8 + 5] = element_op_.Apply(p_src_k[5 * GemmK + 1]);
p_dst_k[1 * 8 + 6] = element_op_.Apply(p_src_k[6 * GemmK + 1]);
p_dst_k[1 * 8 + 7] = element_op_.Apply(p_src_k[7 * GemmK + 1]);
p_dst_k[2 * 8 + 0] = element_op_.Apply(p_src_k[0 * GemmK + 2]);
p_dst_k[2 * 8 + 1] = element_op_.Apply(p_src_k[1 * GemmK + 2]);
p_dst_k[2 * 8 + 2] = element_op_.Apply(p_src_k[2 * GemmK + 2]);
p_dst_k[2 * 8 + 3] = element_op_.Apply(p_src_k[3 * GemmK + 2]);
p_dst_k[2 * 8 + 4] = element_op_.Apply(p_src_k[4 * GemmK + 2]);
p_dst_k[2 * 8 + 5] = element_op_.Apply(p_src_k[5 * GemmK + 2]);
p_dst_k[2 * 8 + 6] = element_op_.Apply(p_src_k[6 * GemmK + 2]);
p_dst_k[2 * 8 + 7] = element_op_.Apply(p_src_k[7 * GemmK + 2]);
p_dst_k[3 * 8 + 0] = element_op_.Apply(p_src_k[0 * GemmK + 3]);
p_dst_k[3 * 8 + 1] = element_op_.Apply(p_src_k[1 * GemmK + 3]);
p_dst_k[3 * 8 + 2] = element_op_.Apply(p_src_k[2 * GemmK + 3]);
p_dst_k[3 * 8 + 3] = element_op_.Apply(p_src_k[3 * GemmK + 3]);
p_dst_k[3 * 8 + 4] = element_op_.Apply(p_src_k[4 * GemmK + 3]);
p_dst_k[3 * 8 + 5] = element_op_.Apply(p_src_k[5 * GemmK + 3]);
p_dst_k[3 * 8 + 6] = element_op_.Apply(p_src_k[6 * GemmK + 3]);
p_dst_k[3 * 8 + 7] = element_op_.Apply(p_src_k[7 * GemmK + 3]);
p_dst_k += 4 * 8;
p_src_k += 4;
}
if(i_k_itr & 2)
{
p_dst_k[0 * 8 + 0] = p_src_k[0 * GemmK + 0];
p_dst_k[0 * 8 + 1] = p_src_k[1 * GemmK + 0];
p_dst_k[0 * 8 + 2] = p_src_k[2 * GemmK + 0];
p_dst_k[0 * 8 + 3] = p_src_k[3 * GemmK + 0];
p_dst_k[0 * 8 + 4] = p_src_k[4 * GemmK + 0];
p_dst_k[0 * 8 + 5] = p_src_k[5 * GemmK + 0];
p_dst_k[0 * 8 + 6] = p_src_k[6 * GemmK + 0];
p_dst_k[0 * 8 + 7] = p_src_k[7 * GemmK + 0];
p_dst_k[1 * 8 + 0] = p_src_k[0 * GemmK + 1];
p_dst_k[1 * 8 + 1] = p_src_k[1 * GemmK + 1];
p_dst_k[1 * 8 + 2] = p_src_k[2 * GemmK + 1];
p_dst_k[1 * 8 + 3] = p_src_k[3 * GemmK + 1];
p_dst_k[1 * 8 + 4] = p_src_k[4 * GemmK + 1];
p_dst_k[1 * 8 + 5] = p_src_k[5 * GemmK + 1];
p_dst_k[1 * 8 + 6] = p_src_k[6 * GemmK + 1];
p_dst_k[1 * 8 + 7] = p_src_k[7 * GemmK + 1];
p_dst_k[0 * 8 + 0] = element_op_.Apply(p_src_k[0 * GemmK + 0]);
p_dst_k[0 * 8 + 1] = element_op_.Apply(p_src_k[1 * GemmK + 0]);
p_dst_k[0 * 8 + 2] = element_op_.Apply(p_src_k[2 * GemmK + 0]);
p_dst_k[0 * 8 + 3] = element_op_.Apply(p_src_k[3 * GemmK + 0]);
p_dst_k[0 * 8 + 4] = element_op_.Apply(p_src_k[4 * GemmK + 0]);
p_dst_k[0 * 8 + 5] = element_op_.Apply(p_src_k[5 * GemmK + 0]);
p_dst_k[0 * 8 + 6] = element_op_.Apply(p_src_k[6 * GemmK + 0]);
p_dst_k[0 * 8 + 7] = element_op_.Apply(p_src_k[7 * GemmK + 0]);
p_dst_k[1 * 8 + 0] = element_op_.Apply(p_src_k[0 * GemmK + 1]);
p_dst_k[1 * 8 + 1] = element_op_.Apply(p_src_k[1 * GemmK + 1]);
p_dst_k[1 * 8 + 2] = element_op_.Apply(p_src_k[2 * GemmK + 1]);
p_dst_k[1 * 8 + 3] = element_op_.Apply(p_src_k[3 * GemmK + 1]);
p_dst_k[1 * 8 + 4] = element_op_.Apply(p_src_k[4 * GemmK + 1]);
p_dst_k[1 * 8 + 5] = element_op_.Apply(p_src_k[5 * GemmK + 1]);
p_dst_k[1 * 8 + 6] = element_op_.Apply(p_src_k[6 * GemmK + 1]);
p_dst_k[1 * 8 + 7] = element_op_.Apply(p_src_k[7 * GemmK + 1]);
p_dst_k += 2 * 8;
p_src_k += 2;
}
if(i_k_itr & 1)
{
p_dst_k[0 * 8 + 0] = p_src_k[0 * GemmK + 0];
p_dst_k[0 * 8 + 1] = p_src_k[1 * GemmK + 0];
p_dst_k[0 * 8 + 2] = p_src_k[2 * GemmK + 0];
p_dst_k[0 * 8 + 3] = p_src_k[3 * GemmK + 0];
p_dst_k[0 * 8 + 4] = p_src_k[4 * GemmK + 0];
p_dst_k[0 * 8 + 5] = p_src_k[5 * GemmK + 0];
p_dst_k[0 * 8 + 6] = p_src_k[6 * GemmK + 0];
p_dst_k[0 * 8 + 7] = p_src_k[7 * GemmK + 0];
p_dst_k[0 * 8 + 0] = element_op_.Apply(p_src_k[0 * GemmK + 0]);
p_dst_k[0 * 8 + 1] = element_op_.Apply(p_src_k[1 * GemmK + 0]);
p_dst_k[0 * 8 + 2] = element_op_.Apply(p_src_k[2 * GemmK + 0]);
p_dst_k[0 * 8 + 3] = element_op_.Apply(p_src_k[3 * GemmK + 0]);
p_dst_k[0 * 8 + 4] = element_op_.Apply(p_src_k[4 * GemmK + 0]);
p_dst_k[0 * 8 + 5] = element_op_.Apply(p_src_k[5 * GemmK + 0]);
p_dst_k[0 * 8 + 6] = element_op_.Apply(p_src_k[6 * GemmK + 0]);
p_dst_k[0 * 8 + 7] = element_op_.Apply(p_src_k[7 * GemmK + 0]);
}
}
else
......@@ -858,8 +884,9 @@ struct ThreadwiseTensorSliceTransferAvx2Specialization_ConvFwd_Wei_NHWC
{
ck::index_t i_current_n_itr = i_n_itr + i_sub_n + i_gemm_n;
float v =
i_current_n_itr < GemmN ? p_src_k[i_sub_n * GemmK + i_sub_k] : .0f;
float v = i_current_n_itr < GemmN
? element_op_.Apply(p_src_k[i_sub_n * GemmK + i_sub_k])
: .0f;
p_dst_k[i_sub_k * 8 + i_sub_n] = v;
}
......@@ -949,14 +976,101 @@ struct ThreadwiseTensorSliceTransferAvx2Specialization_MatC_Store_MxN
dst_offset = i_dst_gemm_m * DstGemmN + i_dst_gemm_n;
}
template <typename SrcBuffer, typename DstBuffer>
void
Run(const SrcDesc& src_desc, SrcBuffer& src_buf, const DstDesc& dst_desc, DstBuffer& dst_buf)
template <typename SrcBuffer, typename DstBuffer, typename SliceLengths>
void RunRead(const SrcDesc& src_desc,
SrcBuffer& src_buf,
const DstDesc& dst_desc,
DstBuffer& dst_buf,
const SliceLengths& slice_length)
{
if constexpr(BypassTransfer)
{
src_buf.p_data_ = reinterpret_cast<float*>(dst_buf.p_data_) + src_offset;
}
}
template <typename SrcBuffer, typename DstBuffer, typename SliceLengths>
void RunWrite(const SrcDesc& src_desc,
SrcBuffer& src_buf,
const DstDesc& dst_desc,
DstBuffer& dst_buf,
const SliceLengths& slice_length)
{
if constexpr(BypassTransfer)
{
// src_buf.p_data_ = reinterpret_cast<float*>(dst_buf.p_data_) + src_offset;
if constexpr(!std::is_same<ElementwiseOperation,
ck::tensor_operation::cpu::element_wise::PassThrough>::value)
{
// if (true) {
const ck::index_t m_per_block = slice_length[Number<0>{}];
const ck::index_t n_per_block = slice_length[Number<1>{}];
const ck::index_t current_n = ck::math::min(DstGemmN - i_dst_gemm_n, n_per_block);
float* p_dst = reinterpret_cast<float*>(dst_buf.p_data_) + dst_offset;
ck::index_t i_m_itr = m_per_block;
// printf("xxxx %d, current_n:%d, DstGemmN:%d, n_per_block:%d,
// dst_offset:%d\n",__LINE__, current_n,
// DstGemmN, n_per_block, dst_offset);fflush(stdout);
// standard 8-4-2-1 pattern
while(i_m_itr >= 8)
{
avx2_util::memcpy32_avx2(
p_dst + 0 * DstGemmN, p_dst + 0 * DstGemmN, current_n, element_op_);
avx2_util::memcpy32_avx2(
p_dst + 1 * DstGemmN, p_dst + 1 * DstGemmN, current_n, element_op_);
avx2_util::memcpy32_avx2(
p_dst + 2 * DstGemmN, p_dst + 2 * DstGemmN, current_n, element_op_);
avx2_util::memcpy32_avx2(
p_dst + 3 * DstGemmN, p_dst + 3 * DstGemmN, current_n, element_op_);
avx2_util::memcpy32_avx2(
p_dst + 4 * DstGemmN, p_dst + 4 * DstGemmN, current_n, element_op_);
avx2_util::memcpy32_avx2(
p_dst + 5 * DstGemmN, p_dst + 5 * DstGemmN, current_n, element_op_);
avx2_util::memcpy32_avx2(
p_dst + 6 * DstGemmN, p_dst + 6 * DstGemmN, current_n, element_op_);
avx2_util::memcpy32_avx2(
p_dst + 7 * DstGemmN, p_dst + 7 * DstGemmN, current_n, element_op_);
i_m_itr -= 8;
p_dst += 8 * DstGemmN;
}
if(i_m_itr & 4)
{
avx2_util::memcpy32_avx2(
p_dst + 0 * DstGemmN, p_dst + 0 * DstGemmN, current_n, element_op_);
avx2_util::memcpy32_avx2(
p_dst + 1 * DstGemmN, p_dst + 1 * DstGemmN, current_n, element_op_);
avx2_util::memcpy32_avx2(
p_dst + 2 * DstGemmN, p_dst + 2 * DstGemmN, current_n, element_op_);
avx2_util::memcpy32_avx2(
p_dst + 3 * DstGemmN, p_dst + 3 * DstGemmN, current_n, element_op_);
p_dst += 4 * DstGemmN;
}
if(i_m_itr & 2)
{
avx2_util::memcpy32_avx2(
p_dst + 0 * DstGemmN, p_dst + 0 * DstGemmN, current_n, element_op_);
avx2_util::memcpy32_avx2(
p_dst + 1 * DstGemmN, p_dst + 1 * DstGemmN, current_n, element_op_);
p_dst += 2 * DstGemmN;
}
if(i_m_itr & 1)
{
avx2_util::memcpy32_avx2(
p_dst + 0 * DstGemmN, p_dst + 0 * DstGemmN, current_n, element_op_);
}
}
}
else
{
const ck::index_t m_per_block =
......@@ -978,14 +1092,22 @@ struct ThreadwiseTensorSliceTransferAvx2Specialization_MatC_Store_MxN
// standard 8-4-2-1 pattern
while(i_m_itr >= 8)
{
avx2_util::memcpy32_avx2(p_dst + 0 * DstGemmN, p_src + 0 * n_per_block, current_n);
avx2_util::memcpy32_avx2(p_dst + 1 * DstGemmN, p_src + 1 * n_per_block, current_n);
avx2_util::memcpy32_avx2(p_dst + 2 * DstGemmN, p_src + 2 * n_per_block, current_n);
avx2_util::memcpy32_avx2(p_dst + 3 * DstGemmN, p_src + 3 * n_per_block, current_n);
avx2_util::memcpy32_avx2(p_dst + 4 * DstGemmN, p_src + 4 * n_per_block, current_n);
avx2_util::memcpy32_avx2(p_dst + 5 * DstGemmN, p_src + 5 * n_per_block, current_n);
avx2_util::memcpy32_avx2(p_dst + 6 * DstGemmN, p_src + 6 * n_per_block, current_n);
avx2_util::memcpy32_avx2(p_dst + 7 * DstGemmN, p_src + 7 * n_per_block, current_n);
avx2_util::memcpy32_avx2(
p_dst + 0 * DstGemmN, p_src + 0 * n_per_block, current_n, element_op_);
avx2_util::memcpy32_avx2(
p_dst + 1 * DstGemmN, p_src + 1 * n_per_block, current_n, element_op_);
avx2_util::memcpy32_avx2(
p_dst + 2 * DstGemmN, p_src + 2 * n_per_block, current_n, element_op_);
avx2_util::memcpy32_avx2(
p_dst + 3 * DstGemmN, p_src + 3 * n_per_block, current_n, element_op_);
avx2_util::memcpy32_avx2(
p_dst + 4 * DstGemmN, p_src + 4 * n_per_block, current_n, element_op_);
avx2_util::memcpy32_avx2(
p_dst + 5 * DstGemmN, p_src + 5 * n_per_block, current_n, element_op_);
avx2_util::memcpy32_avx2(
p_dst + 6 * DstGemmN, p_src + 6 * n_per_block, current_n, element_op_);
avx2_util::memcpy32_avx2(
p_dst + 7 * DstGemmN, p_src + 7 * n_per_block, current_n, element_op_);
i_m_itr -= 8;
p_dst += 8 * DstGemmN;
......@@ -994,10 +1116,14 @@ struct ThreadwiseTensorSliceTransferAvx2Specialization_MatC_Store_MxN
if(i_m_itr & 4)
{
avx2_util::memcpy32_avx2(p_dst + 0 * DstGemmN, p_src + 0 * n_per_block, current_n);
avx2_util::memcpy32_avx2(p_dst + 1 * DstGemmN, p_src + 1 * n_per_block, current_n);
avx2_util::memcpy32_avx2(p_dst + 2 * DstGemmN, p_src + 2 * n_per_block, current_n);
avx2_util::memcpy32_avx2(p_dst + 3 * DstGemmN, p_src + 3 * n_per_block, current_n);
avx2_util::memcpy32_avx2(
p_dst + 0 * DstGemmN, p_src + 0 * n_per_block, current_n, element_op_);
avx2_util::memcpy32_avx2(
p_dst + 1 * DstGemmN, p_src + 1 * n_per_block, current_n, element_op_);
avx2_util::memcpy32_avx2(
p_dst + 2 * DstGemmN, p_src + 2 * n_per_block, current_n, element_op_);
avx2_util::memcpy32_avx2(
p_dst + 3 * DstGemmN, p_src + 3 * n_per_block, current_n, element_op_);
p_dst += 4 * DstGemmN;
p_src += 4 * n_per_block;
......@@ -1005,8 +1131,10 @@ struct ThreadwiseTensorSliceTransferAvx2Specialization_MatC_Store_MxN
if(i_m_itr & 2)
{
avx2_util::memcpy32_avx2(p_dst + 0 * DstGemmN, p_src + 0 * n_per_block, current_n);
avx2_util::memcpy32_avx2(p_dst + 1 * DstGemmN, p_src + 1 * n_per_block, current_n);
avx2_util::memcpy32_avx2(
p_dst + 0 * DstGemmN, p_src + 0 * n_per_block, current_n, element_op_);
avx2_util::memcpy32_avx2(
p_dst + 1 * DstGemmN, p_src + 1 * n_per_block, current_n, element_op_);
p_dst += 2 * DstGemmN;
p_src += 2 * n_per_block;
......@@ -1014,7 +1142,8 @@ struct ThreadwiseTensorSliceTransferAvx2Specialization_MatC_Store_MxN
if(i_m_itr & 1)
{
avx2_util::memcpy32_avx2(p_dst + 0 * DstGemmN, p_src + 0 * n_per_block, current_n);
avx2_util::memcpy32_avx2(
p_dst + 0 * DstGemmN, p_src + 0 * n_per_block, current_n, element_op_);
}
// printf("xxxx %d\n",__LINE__);fflush(stdout);
......
......@@ -19,7 +19,8 @@ using InLayout = ck::tensor_layout::gemm::RowMajor; /
using WeiLayout = ck::tensor_layout::gemm::ColumnMajor; // KYXC
static constexpr bool NonTemporalStore = false;
using PT = ck::tensor_operation::cpu::element_wise::PassThrough;
using PT = ck::tensor_operation::cpu::element_wise::PassThrough;
using Relu = ck::tensor_operation::cpu::element_wise::Relu;
using ThreadwiseGemmAvx2_MxN_4x24_Dispatch =
ck::cpu::ThreadwiseGemmAvx2_MxN_4x24_Dispatch<InType,
WeiType,
......@@ -110,6 +111,59 @@ using device_conv2d_fwd_avx2_nhwc_kyxc_nhwk_f32_mt_instances = std::tuple<
DEVICE_CONV2D_FWD_AVX2_NHWC_KYXC_NHWK_F32(PT, PT, PT, 1024, 416, 128, 6, 16, true, true, true)>;
// clang-format on
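// Note (not part of this commit): the three fused-ReLU instance lists that follow mirror
// the existing PassThrough lists one-to-one; only the output elementwise op template
// argument changes from PT to Relu.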
using device_conv2d_fwd_avx2_nhwc_kyxc_nhwk_f32_relu_instances = std::tuple<
// clang-format off
DEVICE_CONV2D_FWD_AVX2_NHWC_KYXC_NHWK_F32(PT, PT, Relu, 256, 128, 64, 6, 16, true, true, false),
DEVICE_CONV2D_FWD_AVX2_NHWC_KYXC_NHWK_F32(PT, PT, Relu, 256, 128, 128, 6, 16, true, true, false),
DEVICE_CONV2D_FWD_AVX2_NHWC_KYXC_NHWK_F32(PT, PT, Relu, 128, 256, 128, 6, 16, true, true, false),
DEVICE_CONV2D_FWD_AVX2_NHWC_KYXC_NHWK_F32(PT, PT, Relu, 512, 240, 128, 4, 24, true, true, false),
DEVICE_CONV2D_FWD_AVX2_NHWC_KYXC_NHWK_F32(PT, PT, Relu, 512, 256, 128, 6, 16, true, true, false),
DEVICE_CONV2D_FWD_AVX2_NHWC_KYXC_NHWK_F32(PT, PT, Relu, 768, 320, 128, 6, 16, true, true, false),
DEVICE_CONV2D_FWD_AVX2_NHWC_KYXC_NHWK_F32(PT, PT, Relu, 896, 352, 128, 6, 16, true, true, false),
DEVICE_CONV2D_FWD_AVX2_NHWC_KYXC_NHWK_F32(PT, PT, Relu, 1024, 416, 128, 6, 16, true, true, false)>;
// clang-format on
// use this in a single thread when gemm_n is not a multiple of 8
using device_conv2d_fwd_avx2_nhwc_kyxc_nhwk_f32_local_c_relu_instances = std::tuple<
// clang-format off
DEVICE_CONV2D_FWD_AVX2_NHWC_KYXC_NHWK_F32(PT, PT, Relu, 256, 128, 64, 6, 16, true, true, true),
DEVICE_CONV2D_FWD_AVX2_NHWC_KYXC_NHWK_F32(PT, PT, Relu, 256, 128, 128, 6, 16, true, true, true),
DEVICE_CONV2D_FWD_AVX2_NHWC_KYXC_NHWK_F32(PT, PT, Relu, 128, 256, 128, 6, 16, true, true, true),
DEVICE_CONV2D_FWD_AVX2_NHWC_KYXC_NHWK_F32(PT, PT, Relu, 512, 240, 128, 4, 24, true, true, true),
DEVICE_CONV2D_FWD_AVX2_NHWC_KYXC_NHWK_F32(PT, PT, Relu, 512, 256, 128, 6, 16, true, true, true),
DEVICE_CONV2D_FWD_AVX2_NHWC_KYXC_NHWK_F32(PT, PT, Relu, 768, 320, 128, 6, 16, true, true, true),
DEVICE_CONV2D_FWD_AVX2_NHWC_KYXC_NHWK_F32(PT, PT, Relu, 896, 352, 128, 6, 16, true, true, true),
DEVICE_CONV2D_FWD_AVX2_NHWC_KYXC_NHWK_F32(PT, PT, Relu, 1024, 416, 128, 6, 16, true, true, true)>;
// clang-format on
// use this in a multi-threaded environment (a local C buffer is needed to avoid cache-coherence
// traffic, although sometimes no local C buffer is faster...)
using device_conv2d_fwd_avx2_nhwc_kyxc_nhwk_f32_mt_relu_instances = std::tuple<
// clang-format off
DEVICE_CONV2D_FWD_AVX2_NHWC_KYXC_NHWK_F32(PT, PT, Relu, 48, 24, 128, 4, 24, true, true, true),
DEVICE_CONV2D_FWD_AVX2_NHWC_KYXC_NHWK_F32(PT, PT, Relu, 72, 16, 128, 6, 16, true, true, true),
DEVICE_CONV2D_FWD_AVX2_NHWC_KYXC_NHWK_F32(PT, PT, Relu, 72, 32, 128, 6, 16, true, true, true),
DEVICE_CONV2D_FWD_AVX2_NHWC_KYXC_NHWK_F32(PT, PT, Relu, 96, 32, 128, 6, 16, true, true, true),
DEVICE_CONV2D_FWD_AVX2_NHWC_KYXC_NHWK_F32(PT, PT, Relu, 96, 64, 128, 6, 16, true, true, true),
DEVICE_CONV2D_FWD_AVX2_NHWC_KYXC_NHWK_F32(PT, PT, Relu, 120, 32, 128, 6, 16, true, true, true),
DEVICE_CONV2D_FWD_AVX2_NHWC_KYXC_NHWK_F32(PT, PT, Relu, 120, 64, 128, 6, 16, true, true, true),
// DEVICE_CONV2D_FWD_AVX2_NHWC_KYXC_NHWK_F32(PT, PT, PT, 256, 128, 64, 6, 16, true, true, true),
DEVICE_CONV2D_FWD_AVX2_NHWC_KYXC_NHWK_F32(PT, PT, Relu, 256, 128, 128, 6, 16, true, true, true),
DEVICE_CONV2D_FWD_AVX2_NHWC_KYXC_NHWK_F32(PT, PT, Relu, 128, 256, 128, 6, 16, true, true, true),
DEVICE_CONV2D_FWD_AVX2_NHWC_KYXC_NHWK_F32(PT, PT, Relu, 512, 240, 128, 4, 24, true, true, true),
DEVICE_CONV2D_FWD_AVX2_NHWC_KYXC_NHWK_F32(PT, PT, Relu, 512, 256, 128, 6, 16, true, true, true),
DEVICE_CONV2D_FWD_AVX2_NHWC_KYXC_NHWK_F32(PT, PT, Relu, 768, 320, 128, 6, 16, true, true, true),
DEVICE_CONV2D_FWD_AVX2_NHWC_KYXC_NHWK_F32(PT, PT, Relu, 896, 352, 128, 6, 16, true, true, true),
DEVICE_CONV2D_FWD_AVX2_NHWC_KYXC_NHWK_F32(PT, PT, Relu, 1024, 416, 128, 6, 16, true, true, true)>;
// clang-format on
void add_device_conv2d_fwd_avx2_nhwc_kyxc_nhwk(std::vector<DeviceConvFwdPtr<PT, PT, PT>>& instances)
{
ck::tensor_operation::device::add_device_operation_instances(
......@@ -130,6 +184,27 @@ void add_device_conv2d_fwd_avx2_nhwc_kyxc_nhwk_mt(
instances, device_conv2d_fwd_avx2_nhwc_kyxc_nhwk_f32_mt_instances{});
}
void add_device_conv2d_fwd_avx2_nhwc_kyxc_nhwk_relu(
std::vector<DeviceConvFwdPtr<PT, PT, Relu>>& instances)
{
ck::tensor_operation::device::add_device_operation_instances(
instances, device_conv2d_fwd_avx2_nhwc_kyxc_nhwk_f32_relu_instances{});
}
void add_device_conv2d_fwd_avx2_nhwc_kyxc_nhwk_local_c_relu(
std::vector<DeviceConvFwdPtr<PT, PT, Relu>>& instances)
{
ck::tensor_operation::device::add_device_operation_instances(
instances, device_conv2d_fwd_avx2_nhwc_kyxc_nhwk_f32_local_c_relu_instances{});
}
void add_device_conv2d_fwd_avx2_nhwc_kyxc_nhwk_mt_relu(
std::vector<DeviceConvFwdPtr<PT, PT, Relu>>& instances)
{
ck::tensor_operation::device::add_device_operation_instances(
instances, device_conv2d_fwd_avx2_nhwc_kyxc_nhwk_f32_mt_relu_instances{});
}
} // namespace device_conv2d_fwd_avx2_instance
} // namespace device
} // namespace cpu
......
......@@ -24,6 +24,15 @@ void add_device_conv2d_fwd_avx2_nhwc_kyxc_nhwk_local_c(
void add_device_conv2d_fwd_avx2_nhwc_kyxc_nhwk_mt(
std::vector<DeviceConvFwdPtr<PassThrough, PassThrough, PassThrough>>& instances);
void add_device_conv2d_fwd_avx2_nhwc_kyxc_nhwk_relu(
std::vector<DeviceConvFwdPtr<PassThrough, PassThrough, Relu>>& instances);
void add_device_conv2d_fwd_avx2_nhwc_kyxc_nhwk_local_c_relu(
std::vector<DeviceConvFwdPtr<PassThrough, PassThrough, Relu>>& instances);
void add_device_conv2d_fwd_avx2_nhwc_kyxc_nhwk_mt_relu(
std::vector<DeviceConvFwdPtr<PassThrough, PassThrough, Relu>>& instances);
} // namespace device_conv2d_fwd_avx2_instance
} // namespace device
} // namespace cpu
......
......@@ -12,6 +12,11 @@
#include <omp.h>
#define AVX2_DATA_ALIGNMENT 32
#define TEST_FUSION_PASSTHROUGH 0
#define TEST_FUSION_RELU 1
#define TEST_FUSION TEST_FUSION_RELU
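// Note (not part of this commit): switching TEST_FUSION between TEST_FUSION_PASSTHROUGH and
// TEST_FUSION_RELU selects both the OutElementOp alias below and which instance lists get
// registered, so the test is compiled for exactly one fusion at a time.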
using F32 = float;
using F16 = ck::half_t;
......@@ -22,6 +27,7 @@ namespace device {
namespace device_conv2d_fwd_avx2_instance {
using PassThrough = ck::tensor_operation::cpu::element_wise::PassThrough;
using Relu = ck::tensor_operation::cpu::element_wise::Relu;
void add_device_conv2d_fwd_avx2_nhwc_kyxc_nhwk(
std::vector<DeviceConvFwdPtr<PassThrough, PassThrough, PassThrough>>& instances);
......@@ -32,6 +38,15 @@ void add_device_conv2d_fwd_avx2_nhwc_kyxc_nhwk_local_c(
void add_device_conv2d_fwd_avx2_nhwc_kyxc_nhwk_mt(
std::vector<DeviceConvFwdPtr<PassThrough, PassThrough, PassThrough>>& instances);
void add_device_conv2d_fwd_avx2_nhwc_kyxc_nhwk_relu(
std::vector<DeviceConvFwdPtr<PassThrough, PassThrough, Relu>>& instances);
void add_device_conv2d_fwd_avx2_nhwc_kyxc_nhwk_local_c_relu(
std::vector<DeviceConvFwdPtr<PassThrough, PassThrough, Relu>>& instances);
void add_device_conv2d_fwd_avx2_nhwc_kyxc_nhwk_mt_relu(
std::vector<DeviceConvFwdPtr<PassThrough, PassThrough, Relu>>& instances);
} // namespace device_conv2d_fwd_avx2_instance
} // namespace device
} // namespace cpu
......@@ -40,7 +55,12 @@ void add_device_conv2d_fwd_avx2_nhwc_kyxc_nhwk_mt(
using InElementOp = ck::tensor_operation::cpu::element_wise::PassThrough;
using WeiElementOp = ck::tensor_operation::cpu::element_wise::PassThrough;
#if TEST_FUSION == TEST_FUSION_PASSTHROUGH
using OutElementOp = ck::tensor_operation::cpu::element_wise::PassThrough;
#endif
#if TEST_FUSION == TEST_FUSION_RELU
using OutElementOp = ck::tensor_operation::cpu::element_wise::Relu;
#endif
template <typename T>
static bool
......@@ -295,9 +315,16 @@ int main(int argc, char* argv[])
ref_invoker.Run(ref_argument);
}
using PassThrough = ck::tensor_operation::cpu::element_wise::PassThrough;
using PassThrough = ck::tensor_operation::cpu::element_wise::PassThrough;
using Relu = ck::tensor_operation::cpu::element_wise::Relu;
#if TEST_FUSION == TEST_FUSION_PASSTHROUGH
using DeviceConvFwdNoOpPtr = ck::tensor_operation::cpu::device::
DeviceConvFwdPtr<PassThrough, PassThrough, PassThrough>;
#endif
#if TEST_FUSION == TEST_FUSION_RELU
using DeviceConvFwdNoOpPtr =
ck::tensor_operation::cpu::device::DeviceConvFwdPtr<PassThrough, PassThrough, Relu>;
#endif
// add device Conv instances
std::vector<DeviceConvFwdNoOpPtr> conv_ptrs;
......@@ -306,6 +333,7 @@ int main(int argc, char* argv[])
ck::is_same_v<ck::remove_cv_t<WeiDataType>, float> &&
ck::is_same_v<ck::remove_cv_t<OutDataType>, float>)
{
#if TEST_FUSION == TEST_FUSION_PASSTHROUGH
if(omp_get_max_threads() > 1)
{
ck::tensor_operation::cpu::device::device_conv2d_fwd_avx2_instance::
......@@ -322,6 +350,25 @@ int main(int argc, char* argv[])
ck::tensor_operation::cpu::device::device_conv2d_fwd_avx2_instance::
add_device_conv2d_fwd_avx2_nhwc_kyxc_nhwk_local_c(conv_ptrs);
}
#endif
#if TEST_FUSION == TEST_FUSION_RELU
if(omp_get_max_threads() > 1)
{
ck::tensor_operation::cpu::device::device_conv2d_fwd_avx2_instance::
add_device_conv2d_fwd_avx2_nhwc_kyxc_nhwk_mt_relu(conv_ptrs);
ck::tensor_operation::cpu::device::device_conv2d_fwd_avx2_instance::
add_device_conv2d_fwd_avx2_nhwc_kyxc_nhwk_relu(conv_ptrs);
}
else
{
if(K % 8 == 0)
ck::tensor_operation::cpu::device::device_conv2d_fwd_avx2_instance::
add_device_conv2d_fwd_avx2_nhwc_kyxc_nhwk_relu(conv_ptrs);
else
ck::tensor_operation::cpu::device::device_conv2d_fwd_avx2_instance::
add_device_conv2d_fwd_avx2_nhwc_kyxc_nhwk_local_c_relu(conv_ptrs);
}
#endif
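// Illustrative sketch (not part of this commit): once the fused-ReLU instances are
// registered, each entry of conv_ptrs can report its fusion via the type string shown
// in the first hunk (assumed to be reachable through the base pointer):
//
//   for(auto& p : conv_ptrs)
//       std::cout << p->GetTypeString() << std::endl; // fused names end with "_Relu"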
}
if(conv_ptrs.size() <= 0)
......