Commit 090ba885 authored by carlushuang

add elementwise fusion support

parent 8ce9fe57
...@@ -896,6 +896,11 @@ struct DeviceConvNDFwdAvx2_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K
            << "_B" << string_local_buffer(UseBLocalBuffer)
            << "_C" << string_local_buffer(UseCLocalBuffer)
            ;
        if constexpr (!std::is_same<OutElementwiseOperation,
                                    ck::tensor_operation::cpu::element_wise::PassThrough>::value)
        {
            str << "_" << OutElementwiseOperation::Name();
        }
        // clang-format on
        return str.str();
...
#pragma once
#include "data_type_cpu.hpp"

namespace ck {
namespace tensor_operation {
namespace cpu {
namespace element_wise {

using float8_t = ck::cpu::float8_t;
using float4_t = ck::cpu::float4_t;

struct PassThrough
{
    void operator()(float& y, const float& x) const { y = Apply(x); }
    void operator()(float4_t& y, const float4_t& x) const { y = Apply(x); }
    void operator()(float8_t& y, const float8_t& x) const { y = Apply(x); }

    float Apply(const float& x) const { return x; }
    float4_t Apply(const float4_t& x) const { return x; }
    float8_t Apply(const float8_t& x) const { return x; }

    static constexpr char* Name() { return "PassThrough"; }
};

struct Add
{
    void operator()(float& y, const float& x0, const float& x1) const { y = Apply(x0, x1); }

    void operator()(float4_t& y, const float4_t& x0, const float4_t& x1) const
    {
        y = Apply(x0, x1);
    }

    void operator()(float8_t& y, const float8_t& x0, const float8_t& x1) const
    {
        y = Apply(x0, x1);
    }

    float Apply(const float& x0, const float& x1) const { return x0 + x1; }
    float4_t Apply(const float4_t& x0, const float4_t& x1) const { return _mm_add_ps(x0, x1); }
    float8_t Apply(const float8_t& x0, const float8_t& x1) const { return _mm256_add_ps(x0, x1); }

    static constexpr char* Name() { return "Add"; }
};

struct Relu
{
    void operator()(float& y, const float& x) const { y = Apply(x); }
    void operator()(float4_t& y, const float4_t& x) const { y = Apply(x); }
    void operator()(float8_t& y, const float8_t& x) const { y = Apply(x); }

    float Apply(const float& x) const { return x > 0 ? x : 0; }
    float4_t Apply(const float4_t& x) const { return _mm_max_ps(x, _mm_setzero_ps()); }
    float8_t Apply(const float8_t& x) const { return _mm256_max_ps(x, _mm256_setzero_ps()); }

    static constexpr char* Name() { return "Relu"; }
};

struct AlphaBetaAdd
{
    AlphaBetaAdd(float alpha, float beta) : alpha_(alpha), beta_(beta) {}

    void operator()(float& y, const float& x0, const float& x1) const { y = Apply(x0, x1); }

    void operator()(float4_t& y, const float4_t& x0, const float4_t& x1) const
    {
        y = Apply(x0, x1);
    }

    void operator()(float8_t& y, const float8_t& x0, const float8_t& x1) const
    {
        y = Apply(x0, x1);
    }

    float Apply(const float& x0, const float& x1) const { return alpha_ * x0 + beta_ * x1; }

    float4_t Apply(const float4_t& x0, const float4_t& x1) const
    {
        return _mm_add_ps(_mm_mul_ps(x0, _mm_set1_ps(alpha_)), _mm_mul_ps(x1, _mm_set1_ps(beta_)));
    }

    float8_t Apply(const float8_t& x0, const float8_t& x1) const
    {
        return _mm256_add_ps(_mm256_mul_ps(x0, _mm256_set1_ps(alpha_)),
                             _mm256_mul_ps(x1, _mm256_set1_ps(beta_)));
    }

    static constexpr char* Name() { return "AlphaBetaAdd"; }

    float alpha_;
    float beta_;
};

struct AddRelu
{
    void operator()(float& y, const float& x0, const float& x1) const { y = Apply(x0, x1); }

    void operator()(float4_t& y, const float4_t& x0, const float4_t& x1) const
    {
        y = Apply(x0, x1);
    }

    void operator()(float8_t& y, const float8_t& x0, const float8_t& x1) const
    {
        y = Apply(x0, x1);
    }

    float Apply(const float& x0, const float& x1) const
    {
        const float a = x0 + x1;
        return a > 0 ? a : 0;
    }

    float4_t Apply(const float4_t& x0, const float4_t& x1) const
    {
        return _mm_max_ps(_mm_add_ps(x0, x1), _mm_setzero_ps());
    }

    float8_t Apply(const float8_t& x0, const float8_t& x1) const
    {
        return _mm256_max_ps(_mm256_add_ps(x0, x1), _mm256_setzero_ps());
    }

    static constexpr char* Name() { return "AddRelu"; }
};

struct AddReluAdd
{
    void operator()(float& y, const float& x0, const float& x1, const float& x2) const
    {
        float a = x0 + x1;
        float b = a > 0 ? a : 0;
        float c = b + x2;
        y       = c;
    }

    void operator()(float4_t& y, const float4_t& x0, const float4_t& x1, const float4_t& x2) const
    {
        float4_t a = _mm_add_ps(x0, x1);
        float4_t b = _mm_max_ps(a, _mm_setzero_ps());
        y          = _mm_add_ps(b, x2);
    }

    void operator()(float8_t& y, const float8_t& x0, const float8_t& x1, const float8_t& x2) const
    {
        float8_t a = _mm256_add_ps(x0, x1);
        float8_t b = _mm256_max_ps(a, _mm256_setzero_ps());
        y          = _mm256_add_ps(b, x2);
    }

    static constexpr char* Name() { return "AddReluAdd"; }
};

// (the previously #if 0'd AddHardswish, AddHardswishAdd and RequantReluRequant stubs that used to
// sit here are dropped by this commit)

// Unary operators are usually called element-wise before/after the reduction is executed on the
// elements. They are needed for easy implementation of reduction types of AVG, NRM1, NRM2

template <typename Y, typename X, bool HasDividing = false>
struct UnaryIdentic;

template <>
struct UnaryIdentic<float, float, false>
{
    UnaryIdentic(const int32_t divider = 1) { (void)divider; };

    void operator()(float& y, const float& x) const { y = x; };
};

template <>
struct UnaryIdentic<float, float, true>
{
    UnaryIdentic(const int32_t divider = 1) { divider_ = divider; };

    void operator()(float& y, const float& x) const { y = x / type_convert<float>(divider_); };

    int32_t divider_ = 1;
};

template <>
struct UnaryIdentic<float4_t, float4_t, false>
{
    UnaryIdentic(const int32_t divider = 1) { (void)divider; };

    void operator()(float4_t& y, const float4_t& x) const { y = x; };
};

template <>
struct UnaryIdentic<float4_t, float4_t, true>
{
    UnaryIdentic(const int32_t divider = 1) { divider_ = divider; };

    void operator()(float4_t& y, const float4_t& x) const
    {
        y = _mm_div_ps(x, _mm_set1_ps(static_cast<float>(divider_)));
    };

    int32_t divider_ = 1;
};

template <>
struct UnaryIdentic<float8_t, float8_t, false>
{
    UnaryIdentic(const int32_t divider = 1) { (void)divider; };

    void operator()(float8_t& y, const float8_t& x) const { y = x; };
};

template <>
struct UnaryIdentic<float8_t, float8_t, true>
{
    UnaryIdentic(const int32_t divider = 1) { divider_ = divider; };

    void operator()(float8_t& y, const float8_t& x) const
    {
        y = _mm256_div_ps(x, _mm256_set1_ps(static_cast<float>(divider_)));
    };

    int32_t divider_ = 1;
};

template <typename Y, typename X, bool HasDividing = false>
struct UnarySquare;

template <>
struct UnarySquare<float, float, false>
{
    UnarySquare(const int32_t divider = 1) { (void)divider; };

    void operator()(float& y, const float& x) const { y = x * x; };
};

template <>
struct UnarySquare<float, float, true>
{
    UnarySquare(const int32_t divider = 1) { divider_ = divider; };

    void operator()(float& y, const float& x) const { y = x * x / type_convert<float>(divider_); };

    int32_t divider_ = 1;
};

template <>
struct UnarySquare<float4_t, float4_t, false>
{
    UnarySquare(const int32_t divider = 1) { (void)divider; };

    void operator()(float4_t& y, const float4_t& x) const { y = _mm_mul_ps(x, x); };
};

template <>
struct UnarySquare<float4_t, float4_t, true>
{
    UnarySquare(const int32_t divider = 1) { divider_ = divider; };

    void operator()(float4_t& y, const float4_t& x) const
    {
        y = _mm_div_ps(_mm_mul_ps(x, x), _mm_set1_ps(static_cast<float>(divider_)));
    };

    int32_t divider_ = 1;
};

template <>
struct UnarySquare<float8_t, float8_t, false>
{
    UnarySquare(const int32_t divider = 1) { (void)divider; };

    void operator()(float8_t& y, const float8_t& x) const { y = _mm256_mul_ps(x, x); };
};

template <>
struct UnarySquare<float8_t, float8_t, true>
{
    UnarySquare(const int32_t divider = 1) { divider_ = divider; };

    void operator()(float8_t& y, const float8_t& x) const
    {
        y = _mm256_div_ps(_mm256_mul_ps(x, x), _mm256_set1_ps(static_cast<float>(divider_)));
    };

    int32_t divider_ = 1;
};

template <typename Y, typename X>
struct UnaryAbs;

template <>
struct UnaryAbs<float, float>
{
    UnaryAbs(const int32_t divider = 1) { (void)divider; };

    void operator()(float& y, const float& x) const { y = abs(x); };
};

template <>
struct UnaryAbs<float4_t, float4_t>
{
    UnaryAbs(const int32_t divider = 1) { (void)divider; };

    void operator()(float4_t& y, const float4_t& x) const
    {
        __m128 Mask = _mm_castsi128_ps(_mm_set1_epi32(~0x80000000));
        y           = _mm_and_ps(Mask, x);
    };
};

template <>
struct UnaryAbs<float8_t, float8_t>
{
    UnaryAbs(const int32_t divider = 1) { (void)divider; };

    void operator()(float8_t& y, const float8_t& x) const
    {
        __m256 Mask = _mm256_castsi256_ps(_mm256_set1_epi32(~0x80000000));
        y           = _mm256_and_ps(Mask, x);
    };
};

template <typename Y, typename X>
struct UnarySqrt;

template <>
struct UnarySqrt<float, float>
{
    void operator()(float& y, const float& x) const { y = sqrtf(x); };
};

template <>
struct UnarySqrt<float4_t, float4_t>
{
    void operator()(float4_t& y, const float4_t& x) const { y = _mm_sqrt_ps(x); };
};

template <>
struct UnarySqrt<float8_t, float8_t>
{
    void operator()(float8_t& y, const float8_t& x) const { y = _mm256_sqrt_ps(x); };
};

} // namespace element_wise
} // namespace cpu
} // namespace tensor_operation
} // namespace ck
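
A minimal usage sketch of the functors above (the main() wrapper is illustrative only and not part of the commit; it only exercises the scalar overloads so it makes no assumption about how float4_t/float8_t are defined):

#include "element_wise_operation_cpu.hpp"
#include <cstdio>

int main()
{
    using ck::tensor_operation::cpu::element_wise::Relu;
    using ck::tensor_operation::cpu::element_wise::AlphaBetaAdd;

    Relu relu;
    float y = 0.f;
    relu(y, -3.0f);                 // operator() overload: y == 0
    float r = relu.Apply(2.5f);     // Apply() is what the fused copy helpers call: r == 2.5

    AlphaBetaAdd ab(2.0f, 0.5f);
    float z = 0.f;
    ab(z, 3.0f, 4.0f);              // z == 2*3 + 0.5*4 == 8

    std::printf("%s %s %f %f %f\n", Relu::Name(), AlphaBetaAdd::Name(), y, r, z);
    return 0;
}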
...@@ -128,6 +128,51 @@ struct GridwiseGemmAvx2_MxN
        return make_naive_tensor_descriptor_packed(make_tuple(m_per_blk, n_per_blk));
    }

    static auto GetAMultiIndex(const ck::index_t m_per_blk, const ck::index_t k_per_blk)
    {
        if constexpr(std::is_same<typename ThreadwiseGemm_Dispatch::MatrixALayout,
                                  ck::tensor_layout::gemm::RowMajor>::value)
        {
            // A : M, K
            return ck::make_multi_index(m_per_blk, k_per_blk);
        }
        else
        {
            // A : K, M
            return ck::make_multi_index(
                k_per_blk,
                math::integer_least_multiple(m_per_blk,
                                             ThreadwiseGemm_Dispatch::MatrixAMinVectorSize));
        }
    }

    static auto GetBMultiIndex(const ck::index_t k_per_blk, const ck::index_t n_per_blk)
    {
        // n_per_blk should be 8x
        if constexpr(std::is_same<typename ThreadwiseGemm_Dispatch::MatrixBLayout,
                                  ck::tensor_layout::gemm::RowMajor>::value)
        {
            // B : K, N
            return ck::make_multi_index(
                k_per_blk,
                math::integer_least_multiple(n_per_blk,
                                             ThreadwiseGemm_Dispatch::MatrixBMinVectorSize));
        }
        else
        {
            // B : N/8, K, N8
            return ck::make_multi_index(
                math::integer_divide_ceil(n_per_blk, ThreadwiseGemm_Dispatch::MatrixBMinVectorSize),
                k_per_blk,
                ThreadwiseGemm_Dispatch::MatrixBMinVectorSize);
        }
    }

    static auto GetCMultiIndex(const ck::index_t m_per_blk, const ck::index_t n_per_blk)
    {
        return ck::make_multi_index(m_per_blk, n_per_blk);
    }

    static constexpr bool CheckValidity(const AGridDesc& a_grid_desc,
                                        const BGridDesc& b_grid_desc,
                                        const CGridDesc& c_grid_desc)
...@@ -300,14 +345,18 @@ struct GridwiseGemmAvx2_MxN
                UseCLocalBuffer ? GetCBlockDescriptor(mc_size, nc_size) : c_grid_desc;

            if constexpr(UseCLocalBuffer)
            {
                // c_threadwise_copy.SetDstSliceOrigin(c_grid_desc,
                //                                     ck::make_multi_index(i_mc, i_nc));
            }
            else
            {
                c_threadwise_copy.SetSrcSliceOrigin(c_block_desc,
                                                    ck::make_multi_index(i_mc, i_nc));
                c_threadwise_copy.RunRead(c_block_desc,
                                          c_block_buf,
                                          c_grid_desc,
                                          c_grid_buf,
                                          GetCMultiIndex(mc_size, nc_size));
            }

            for(ck::index_t i_kc = 0; i_kc < GemmK; i_kc += k_per_block)
...@@ -317,8 +366,16 @@ struct GridwiseGemmAvx2_MxN
                auto a_block_desc = GetABlockDescriptor(mc_size, kc_size);
                auto b_block_desc = GetBBlockDescriptor(kc_size, nc_size);

                a_threadwise_copy.RunRead(a_grid_desc,
                                          a_grid_buf,
                                          a_block_desc,
                                          a_block_buf,
                                          GetAMultiIndex(mc_size, kc_size));
                b_threadwise_copy.RunRead(b_grid_desc,
                                          b_grid_buf,
                                          b_block_desc,
                                          b_block_buf,
                                          GetBMultiIndex(kc_size, nc_size));

                blockwise_gemm.Run(a_block_desc,
                                   a_block_buf,
...@@ -338,8 +395,14 @@ struct GridwiseGemmAvx2_MxN
                }
            }

            // if constexpr(UseCLocalBuffer)
            c_threadwise_copy.SetDstSliceOrigin(c_grid_desc,
                                                ck::make_multi_index(i_mc, i_nc));
            c_threadwise_copy.RunWrite(c_block_desc,
                                       c_block_buf,
                                       c_grid_desc,
                                       c_grid_buf,
                                       GetCMultiIndex(mc_size, nc_size));
        }
    }
}
...@@ -415,7 +478,11 @@ struct GridwiseGemmAvx2_MxN
                    ck::index_t kc_size = ck::math::min(GemmK - i_kc, k_per_block);
                    auto a_block_desc   = GetABlockDescriptor(mc_size, kc_size);

                    a_threadwise_copy.RunRead(a_grid_desc,
                                              a_grid_buf,
                                              a_block_desc,
                                              a_block_buf,
                                              GetAMultiIndex(mc_size, kc_size));

                    b_threadwise_copy.SetSrcSliceOrigin(b_grid_desc,
                                                        ck::make_multi_index(0, i_kc, 0));
...@@ -429,8 +496,11 @@ struct GridwiseGemmAvx2_MxN
                            nc_size, ThreadwiseGemm_Dispatch::MatrixBMinVectorSize);
                        auto b_block_desc = GetBBlockDescriptor(kc_size, nc_size);

                        b_threadwise_copy.RunRead(b_grid_desc,
                                                  b_grid_buf,
                                                  b_block_desc,
                                                  b_block_buf,
                                                  GetBMultiIndex(kc_size, nc_size));

                        auto c_block_desc = UseCLocalBuffer
                                                ? GetCBlockDescriptor(mc_size, nc_size)
...@@ -440,8 +510,11 @@ struct GridwiseGemmAvx2_MxN
                        {
                            c_threadwise_copy.SetSrcSliceOrigin(
                                c_block_desc, ck::make_multi_index(i_mc, i_nc));
                            c_threadwise_copy.RunRead(c_block_desc,
                                                      c_block_buf,
                                                      c_grid_desc,
                                                      c_grid_buf,
                                                      GetCMultiIndex(mc_size, nc_size));
                        }

                        blockwise_gemm.Run(a_block_desc,
...@@ -456,14 +529,36 @@ struct GridwiseGemmAvx2_MxN
                                           i_kc != 0);

                        if((i_nc + n_per_block) < GemmN)
                        {
                            b_threadwise_copy.MoveSrcSliceWindow(b_grid_desc, b_move_k_step);
                        }

                        if constexpr(UseCLocalBuffer)
                        {
                            c_threadwise_copy.SetDstSliceOrigin(
                                c_grid_desc, ck::make_multi_index(i_mc, i_nc));

                            c_threadwise_copy.RunWrite(c_block_desc,
                                                       c_block_buf,
                                                       c_grid_desc,
                                                       c_grid_buf,
                                                       GetCMultiIndex(mc_size, nc_size));
                        }
                        else
                        {
                            // only write for last K, since the RunWrite here is just doing
                            // elementwise op from global to global
                            if((i_kc + k_per_block) >= GemmK)
                            {
                                c_threadwise_copy.SetDstSliceOrigin(
                                    c_grid_desc, ck::make_multi_index(i_mc, i_nc));
                                c_threadwise_copy.RunWrite(c_block_desc,
                                                           c_block_buf,
                                                           c_grid_desc,
                                                           c_grid_buf,
                                                           GetCMultiIndex(mc_size, nc_size));
                            }
                        }
                    }
...
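
A toy example of the reasoning behind the "only write for last K" branch above: when C is accumulated directly in global memory, the fused output op (e.g. ReLU) must be applied exactly once, after all K blocks have been accumulated. Applying it per K step would change the result. This is a self-contained illustration, not kernel code:

#include <algorithm>
#include <cstdio>

int main()
{
    const float partial_k[3] = {2.0f, -5.0f, 1.0f}; // per-K-block contributions to one C element

    float c = 0.0f;
    for(float p : partial_k)
        c += p;                          // accumulate over all K blocks first
    float fused = std::max(c, 0.0f);     // ReLU applied once at the end -> 0 (c == -2)

    float wrong = 0.0f;
    for(float p : partial_k)
        wrong = std::max(wrong + p, 0.0f); // ReLU applied every K step -> 1 (incorrect)

    std::printf("fused=%f wrong=%f\n", fused, wrong);
    return 0;
}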
...@@ -8,7 +8,7 @@
#include "tensor_descriptor_helper.hpp"
#include "tensor_space_filling_curve.hpp"
#include "dynamic_buffer_cpu.hpp"
#include "element_wise_operation_cpu.hpp"
#include "convolution_forward_specialization_cpu.hpp"
#include <immintrin.h>
...@@ -17,7 +17,8 @@ namespace cpu {
namespace avx2_util {

template <typename ElementwiseOp>
void memcpy32_avx2(void* dst, const void* src, const ck::index_t n, const ElementwiseOp& element_op)
{
    // 16-8-4-2-1 pattern
    ck::index_t i_n = n;
...@@ -25,33 +26,33 @@ inline void memcpy32_avx2(void* dst, const void* src, const ck::index_t n)
    const float* p_src = reinterpret_cast<const float*>(src);
    while(i_n >= 16)
    {
        _mm256_storeu_ps(p_dst + 0, element_op.Apply(_mm256_loadu_ps(p_src + 0)));
        _mm256_storeu_ps(p_dst + 8, element_op.Apply(_mm256_loadu_ps(p_src + 8)));
        p_dst += 16;
        p_src += 16;
        i_n -= 16;
    }
    if(i_n & 8)
    {
        _mm256_storeu_ps(p_dst, element_op.Apply(_mm256_loadu_ps(p_src)));
        p_dst += 8;
        p_src += 8;
    }
    if(i_n & 4)
    {
        _mm_storeu_ps(p_dst, element_op.Apply(_mm_loadu_ps(p_src)));
        p_dst += 4;
        p_src += 4;
    }
    if(i_n & 2)
    {
        _mm_storeu_si64(p_dst, element_op.Apply(_mm_loadu_si64(p_src)));
        p_dst += 2;
        p_src += 2;
    }
    if(i_n & 1)
    {
        *p_dst = element_op.Apply(*p_src);
    }
}
...@@ -90,8 +91,12 @@ inline void memset32_avx2(void* dst, const int32_t value, const ck::index_t n)
    }
}

template <typename ElementwiseOp>
void transpose8x8_avx2(void* dst,
                       ck::index_t stride_dst,
                       const void* src,
                       ck::index_t stride_src,
                       const ElementwiseOp& element_op)
{
    // TODO: use vinsertf128 for better port usage. vpermf128 is slow
    __m256 r0, r1, r2, r3, r4, r5, r6, r7;
...@@ -100,14 +105,14 @@ transpose8x8_avx2(void* dst, ck::index_t stride_dst, const void* src, ck::index_
    float* p_dst       = reinterpret_cast<float*>(dst);
    const float* p_src = reinterpret_cast<const float*>(src);

    r0 = element_op.Apply(_mm256_loadu_ps(p_src + 0 * stride_src));
    r1 = element_op.Apply(_mm256_loadu_ps(p_src + 1 * stride_src));
    r2 = element_op.Apply(_mm256_loadu_ps(p_src + 2 * stride_src));
    r3 = element_op.Apply(_mm256_loadu_ps(p_src + 3 * stride_src));
    r4 = element_op.Apply(_mm256_loadu_ps(p_src + 4 * stride_src));
    r5 = element_op.Apply(_mm256_loadu_ps(p_src + 5 * stride_src));
    r6 = element_op.Apply(_mm256_loadu_ps(p_src + 6 * stride_src));
    r7 = element_op.Apply(_mm256_loadu_ps(p_src + 7 * stride_src));

    t0 = _mm256_unpacklo_ps(r0, r1);
    t1 = _mm256_unpackhi_ps(r0, r1);
...@@ -354,11 +359,12 @@ struct ThreadwiseTensorSliceTransferAvx2Specialization_ConvFwd_In_NHWC
    void SetDstSliceOrigin(const DstDesc&, const Index&) {}

    template <typename SrcBuffer, typename DstBuffer, typename SliceLengths>
    void RunRead(const SrcDesc& src_desc,
                 const SrcBuffer& src_buf,
                 const DstDesc& dst_desc,
                 DstBuffer& dst_buf,
                 const SliceLengths& slice_length)
    {
        if constexpr(BypassTransfer)
        {
...@@ -385,14 +391,22 @@ struct ThreadwiseTensorSliceTransferAvx2Specialization_ConvFwd_In_NHWC
            // standard 8-4-2-1 pattern
            while(i_m_itr >= 8)
            {
                avx2_util::memcpy32_avx2(p_dst + 0 * k_per_block, p_src + 0 * C, k_per_block, element_op_);
                avx2_util::memcpy32_avx2(p_dst + 1 * k_per_block, p_src + 1 * C, k_per_block, element_op_);
                avx2_util::memcpy32_avx2(p_dst + 2 * k_per_block, p_src + 2 * C, k_per_block, element_op_);
                avx2_util::memcpy32_avx2(p_dst + 3 * k_per_block, p_src + 3 * C, k_per_block, element_op_);
                avx2_util::memcpy32_avx2(p_dst + 4 * k_per_block, p_src + 4 * C, k_per_block, element_op_);
                avx2_util::memcpy32_avx2(p_dst + 5 * k_per_block, p_src + 5 * C, k_per_block, element_op_);
                avx2_util::memcpy32_avx2(p_dst + 6 * k_per_block, p_src + 6 * C, k_per_block, element_op_);
                avx2_util::memcpy32_avx2(p_dst + 7 * k_per_block, p_src + 7 * C, k_per_block, element_op_);

                i_m_itr -= 8;
                p_dst += 8 * k_per_block;
...@@ -400,10 +414,14 @@ struct ThreadwiseTensorSliceTransferAvx2Specialization_ConvFwd_In_NHWC
            }
            if(i_m_itr & 4)
            {
                avx2_util::memcpy32_avx2(p_dst + 0 * k_per_block, p_src + 0 * C, k_per_block, element_op_);
                avx2_util::memcpy32_avx2(p_dst + 1 * k_per_block, p_src + 1 * C, k_per_block, element_op_);
                avx2_util::memcpy32_avx2(p_dst + 2 * k_per_block, p_src + 2 * C, k_per_block, element_op_);
                avx2_util::memcpy32_avx2(p_dst + 3 * k_per_block, p_src + 3 * C, k_per_block, element_op_);

                p_dst += 4 * k_per_block;
                p_src += 4 * C;
...@@ -411,8 +429,10 @@ struct ThreadwiseTensorSliceTransferAvx2Specialization_ConvFwd_In_NHWC
            }
            if(i_m_itr & 2)
            {
                avx2_util::memcpy32_avx2(p_dst + 0 * k_per_block, p_src + 0 * C, k_per_block, element_op_);
                avx2_util::memcpy32_avx2(p_dst + 1 * k_per_block, p_src + 1 * C, k_per_block, element_op_);

                p_dst += 2 * k_per_block;
                p_src += 2 * C;
...@@ -420,7 +440,8 @@ struct ThreadwiseTensorSliceTransferAvx2Specialization_ConvFwd_In_NHWC
            }
            if(i_m_itr & 1)
            {
                avx2_util::memcpy32_avx2(p_dst + 0 * k_per_block, p_src + 0 * C, k_per_block, element_op_);
            }
        }
        else if constexpr(ConvForwardSpecialization ==
...@@ -431,7 +452,7 @@ struct ThreadwiseTensorSliceTransferAvx2Specialization_ConvFwd_In_NHWC
            ck::index_t i_ho_itr = i_ho;
            while(i_m_itr > 0)
            {
                avx2_util::memcpy32_avx2(p_dst, p_src, k_per_block, element_op_);
                p_dst += k_per_block;
                i_wo_itr++;
                p_src += input_offset_acc_wi;
...@@ -468,7 +489,7 @@ struct ThreadwiseTensorSliceTransferAvx2Specialization_ConvFwd_In_NHWC
                {
                    if((*reinterpret_cast<uint32_t*>(&i_hi_itr) < Hi) &&
                       (*reinterpret_cast<uint32_t*>(&i_wi_itr) < Wi))
                        avx2_util::memcpy32_avx2(p_dst, p_src, k_per_block, element_op_);
                    else
                        avx2_util::memset32_avx2(p_dst, 0, k_per_block);
...@@ -523,7 +544,8 @@ struct ThreadwiseTensorSliceTransferAvx2Specialization_ConvFwd_In_NHWC
                        if((*reinterpret_cast<uint32_t*>(&i_hi_itr_k) < Hi) &&
                           (*reinterpret_cast<uint32_t*>(&i_wi_itr_k) < Wi))
                            avx2_util::memcpy32_avx2(p_dst_k, p_src_k, current_k_block, element_op_);
                        else
                            avx2_util::memset32_avx2(p_dst_k, 0, current_k_block);
...@@ -730,8 +752,12 @@ struct ThreadwiseTensorSliceTransferAvx2Specialization_ConvFwd_Wei_NHWC
    void SetDstSliceOrigin(const DstDesc&, const Index&) {}

    template <typename SrcBuffer, typename DstBuffer, typename SliceLengths>
    void RunRead(const SrcDesc&,
                 const SrcBuffer& src_buf,
                 const DstDesc& dst_desc,
                 DstBuffer& dst_buf,
                 const SliceLengths& slice_length)
    {
        if constexpr(BypassTransfer)
        {
...@@ -766,85 +792,85 @@ struct ThreadwiseTensorSliceTransferAvx2Specialization_ConvFwd_Wei_NHWC
                float* p_dst_k = p_dst;
                while(i_k_itr >= 8)
                {
                    avx2_util::transpose8x8_avx2(p_dst_k, 8, p_src_k, GemmK, element_op_);
                    p_dst_k += 8 * 8;
                    p_src_k += 8;
                    i_k_itr -= 8;
                }
                if(i_k_itr & 4)
                {
                    p_dst_k[0 * 8 + 0] = element_op_.Apply(p_src_k[0 * GemmK + 0]);
                    p_dst_k[0 * 8 + 1] = element_op_.Apply(p_src_k[1 * GemmK + 0]);
                    p_dst_k[0 * 8 + 2] = element_op_.Apply(p_src_k[2 * GemmK + 0]);
                    p_dst_k[0 * 8 + 3] = element_op_.Apply(p_src_k[3 * GemmK + 0]);
                    p_dst_k[0 * 8 + 4] = element_op_.Apply(p_src_k[4 * GemmK + 0]);
                    p_dst_k[0 * 8 + 5] = element_op_.Apply(p_src_k[5 * GemmK + 0]);
                    p_dst_k[0 * 8 + 6] = element_op_.Apply(p_src_k[6 * GemmK + 0]);
                    p_dst_k[0 * 8 + 7] = element_op_.Apply(p_src_k[7 * GemmK + 0]);

                    p_dst_k[1 * 8 + 0] = element_op_.Apply(p_src_k[0 * GemmK + 1]);
                    p_dst_k[1 * 8 + 1] = element_op_.Apply(p_src_k[1 * GemmK + 1]);
                    p_dst_k[1 * 8 + 2] = element_op_.Apply(p_src_k[2 * GemmK + 1]);
                    p_dst_k[1 * 8 + 3] = element_op_.Apply(p_src_k[3 * GemmK + 1]);
                    p_dst_k[1 * 8 + 4] = element_op_.Apply(p_src_k[4 * GemmK + 1]);
                    p_dst_k[1 * 8 + 5] = element_op_.Apply(p_src_k[5 * GemmK + 1]);
                    p_dst_k[1 * 8 + 6] = element_op_.Apply(p_src_k[6 * GemmK + 1]);
                    p_dst_k[1 * 8 + 7] = element_op_.Apply(p_src_k[7 * GemmK + 1]);

                    p_dst_k[2 * 8 + 0] = element_op_.Apply(p_src_k[0 * GemmK + 2]);
                    p_dst_k[2 * 8 + 1] = element_op_.Apply(p_src_k[1 * GemmK + 2]);
                    p_dst_k[2 * 8 + 2] = element_op_.Apply(p_src_k[2 * GemmK + 2]);
                    p_dst_k[2 * 8 + 3] = element_op_.Apply(p_src_k[3 * GemmK + 2]);
                    p_dst_k[2 * 8 + 4] = element_op_.Apply(p_src_k[4 * GemmK + 2]);
                    p_dst_k[2 * 8 + 5] = element_op_.Apply(p_src_k[5 * GemmK + 2]);
                    p_dst_k[2 * 8 + 6] = element_op_.Apply(p_src_k[6 * GemmK + 2]);
                    p_dst_k[2 * 8 + 7] = element_op_.Apply(p_src_k[7 * GemmK + 2]);

                    p_dst_k[3 * 8 + 0] = element_op_.Apply(p_src_k[0 * GemmK + 3]);
                    p_dst_k[3 * 8 + 1] = element_op_.Apply(p_src_k[1 * GemmK + 3]);
                    p_dst_k[3 * 8 + 2] = element_op_.Apply(p_src_k[2 * GemmK + 3]);
                    p_dst_k[3 * 8 + 3] = element_op_.Apply(p_src_k[3 * GemmK + 3]);
                    p_dst_k[3 * 8 + 4] = element_op_.Apply(p_src_k[4 * GemmK + 3]);
                    p_dst_k[3 * 8 + 5] = element_op_.Apply(p_src_k[5 * GemmK + 3]);
                    p_dst_k[3 * 8 + 6] = element_op_.Apply(p_src_k[6 * GemmK + 3]);
                    p_dst_k[3 * 8 + 7] = element_op_.Apply(p_src_k[7 * GemmK + 3]);

                    p_dst_k += 4 * 8;
                    p_src_k += 4;
                }
                if(i_k_itr & 2)
                {
                    p_dst_k[0 * 8 + 0] = element_op_.Apply(p_src_k[0 * GemmK + 0]);
                    p_dst_k[0 * 8 + 1] = element_op_.Apply(p_src_k[1 * GemmK + 0]);
                    p_dst_k[0 * 8 + 2] = element_op_.Apply(p_src_k[2 * GemmK + 0]);
                    p_dst_k[0 * 8 + 3] = element_op_.Apply(p_src_k[3 * GemmK + 0]);
                    p_dst_k[0 * 8 + 4] = element_op_.Apply(p_src_k[4 * GemmK + 0]);
                    p_dst_k[0 * 8 + 5] = element_op_.Apply(p_src_k[5 * GemmK + 0]);
                    p_dst_k[0 * 8 + 6] = element_op_.Apply(p_src_k[6 * GemmK + 0]);
                    p_dst_k[0 * 8 + 7] = element_op_.Apply(p_src_k[7 * GemmK + 0]);

                    p_dst_k[1 * 8 + 0] = element_op_.Apply(p_src_k[0 * GemmK + 1]);
                    p_dst_k[1 * 8 + 1] = element_op_.Apply(p_src_k[1 * GemmK + 1]);
                    p_dst_k[1 * 8 + 2] = element_op_.Apply(p_src_k[2 * GemmK + 1]);
                    p_dst_k[1 * 8 + 3] = element_op_.Apply(p_src_k[3 * GemmK + 1]);
                    p_dst_k[1 * 8 + 4] = element_op_.Apply(p_src_k[4 * GemmK + 1]);
                    p_dst_k[1 * 8 + 5] = element_op_.Apply(p_src_k[5 * GemmK + 1]);
                    p_dst_k[1 * 8 + 6] = element_op_.Apply(p_src_k[6 * GemmK + 1]);
                    p_dst_k[1 * 8 + 7] = element_op_.Apply(p_src_k[7 * GemmK + 1]);

                    p_dst_k += 2 * 8;
                    p_src_k += 2;
                }
                if(i_k_itr & 1)
                {
                    p_dst_k[0 * 8 + 0] = element_op_.Apply(p_src_k[0 * GemmK + 0]);
                    p_dst_k[0 * 8 + 1] = element_op_.Apply(p_src_k[1 * GemmK + 0]);
                    p_dst_k[0 * 8 + 2] = element_op_.Apply(p_src_k[2 * GemmK + 0]);
                    p_dst_k[0 * 8 + 3] = element_op_.Apply(p_src_k[3 * GemmK + 0]);
                    p_dst_k[0 * 8 + 4] = element_op_.Apply(p_src_k[4 * GemmK + 0]);
                    p_dst_k[0 * 8 + 5] = element_op_.Apply(p_src_k[5 * GemmK + 0]);
                    p_dst_k[0 * 8 + 6] = element_op_.Apply(p_src_k[6 * GemmK + 0]);
                    p_dst_k[0 * 8 + 7] = element_op_.Apply(p_src_k[7 * GemmK + 0]);
                }
            }
            else
...@@ -858,8 +884,9 @@ struct ThreadwiseTensorSliceTransferAvx2Specialization_ConvFwd_Wei_NHWC
                    {
                        ck::index_t i_current_n_itr = i_n_itr + i_sub_n + i_gemm_n;
                        float v = i_current_n_itr < GemmN
                                      ? element_op_.Apply(p_src_k[i_sub_n * GemmK + i_sub_k])
                                      : .0f;

                        p_dst_k[i_sub_k * 8 + i_sub_n] = v;
                    }
...@@ -949,14 +976,101 @@ struct ThreadwiseTensorSliceTransferAvx2Specialization_MatC_Store_MxN
        dst_offset = i_dst_gemm_m * DstGemmN + i_dst_gemm_n;
    }

    template <typename SrcBuffer, typename DstBuffer, typename SliceLengths>
    void RunRead(const SrcDesc& src_desc,
                 SrcBuffer& src_buf,
                 const DstDesc& dst_desc,
                 DstBuffer& dst_buf,
                 const SliceLengths& slice_length)
    {
        if constexpr(BypassTransfer)
        {
            src_buf.p_data_ = reinterpret_cast<float*>(dst_buf.p_data_) + src_offset;
        }
    }

    template <typename SrcBuffer, typename DstBuffer, typename SliceLengths>
    void RunWrite(const SrcDesc& src_desc,
                  SrcBuffer& src_buf,
                  const DstDesc& dst_desc,
                  DstBuffer& dst_buf,
                  const SliceLengths& slice_length)
    {
        if constexpr(BypassTransfer)
        {
            // src_buf.p_data_ = reinterpret_cast<float*>(dst_buf.p_data_) + src_offset;
            if constexpr(!std::is_same<ElementwiseOperation,
                                       ck::tensor_operation::cpu::element_wise::PassThrough>::value)
            {
                const ck::index_t m_per_block = slice_length[Number<0>{}];
                const ck::index_t n_per_block = slice_length[Number<1>{}];
                const ck::index_t current_n   = ck::math::min(DstGemmN - i_dst_gemm_n, n_per_block);

                float* p_dst = reinterpret_cast<float*>(dst_buf.p_data_) + dst_offset;

                ck::index_t i_m_itr = m_per_block;

                // printf("xxxx %d, current_n:%d, DstGemmN:%d, n_per_block:%d,
                // dst_offset:%d\n",__LINE__, current_n,
                // DstGemmN, n_per_block, dst_offset);fflush(stdout);

                // standard 8-4-2-1 pattern
                while(i_m_itr >= 8)
                {
                    avx2_util::memcpy32_avx2(p_dst + 0 * DstGemmN, p_dst + 0 * DstGemmN, current_n, element_op_);
                    avx2_util::memcpy32_avx2(p_dst + 1 * DstGemmN, p_dst + 1 * DstGemmN, current_n, element_op_);
                    avx2_util::memcpy32_avx2(p_dst + 2 * DstGemmN, p_dst + 2 * DstGemmN, current_n, element_op_);
                    avx2_util::memcpy32_avx2(p_dst + 3 * DstGemmN, p_dst + 3 * DstGemmN, current_n, element_op_);
                    avx2_util::memcpy32_avx2(p_dst + 4 * DstGemmN, p_dst + 4 * DstGemmN, current_n, element_op_);
                    avx2_util::memcpy32_avx2(p_dst + 5 * DstGemmN, p_dst + 5 * DstGemmN, current_n, element_op_);
                    avx2_util::memcpy32_avx2(p_dst + 6 * DstGemmN, p_dst + 6 * DstGemmN, current_n, element_op_);
                    avx2_util::memcpy32_avx2(p_dst + 7 * DstGemmN, p_dst + 7 * DstGemmN, current_n, element_op_);

                    i_m_itr -= 8;
                    p_dst += 8 * DstGemmN;
                }
                if(i_m_itr & 4)
                {
                    avx2_util::memcpy32_avx2(p_dst + 0 * DstGemmN, p_dst + 0 * DstGemmN, current_n, element_op_);
                    avx2_util::memcpy32_avx2(p_dst + 1 * DstGemmN, p_dst + 1 * DstGemmN, current_n, element_op_);
                    avx2_util::memcpy32_avx2(p_dst + 2 * DstGemmN, p_dst + 2 * DstGemmN, current_n, element_op_);
                    avx2_util::memcpy32_avx2(p_dst + 3 * DstGemmN, p_dst + 3 * DstGemmN, current_n, element_op_);

                    p_dst += 4 * DstGemmN;
                }
                if(i_m_itr & 2)
                {
                    avx2_util::memcpy32_avx2(p_dst + 0 * DstGemmN, p_dst + 0 * DstGemmN, current_n, element_op_);
                    avx2_util::memcpy32_avx2(p_dst + 1 * DstGemmN, p_dst + 1 * DstGemmN, current_n, element_op_);

                    p_dst += 2 * DstGemmN;
                }
                if(i_m_itr & 1)
                {
                    avx2_util::memcpy32_avx2(p_dst + 0 * DstGemmN, p_dst + 0 * DstGemmN, current_n, element_op_);
                }
            }
        }
        else
        {
            const ck::index_t m_per_block =
...@@ -978,14 +1092,22 @@ struct ThreadwiseTensorSliceTransferAvx2Specialization_MatC_Store_MxN
            // standard 8-4-2-1 pattern
            while(i_m_itr >= 8)
            {
                avx2_util::memcpy32_avx2(p_dst + 0 * DstGemmN, p_src + 0 * n_per_block, current_n, element_op_);
                avx2_util::memcpy32_avx2(p_dst + 1 * DstGemmN, p_src + 1 * n_per_block, current_n, element_op_);
                avx2_util::memcpy32_avx2(p_dst + 2 * DstGemmN, p_src + 2 * n_per_block, current_n, element_op_);
                avx2_util::memcpy32_avx2(p_dst + 3 * DstGemmN, p_src + 3 * n_per_block, current_n, element_op_);
                avx2_util::memcpy32_avx2(p_dst + 4 * DstGemmN, p_src + 4 * n_per_block, current_n, element_op_);
                avx2_util::memcpy32_avx2(p_dst + 5 * DstGemmN, p_src + 5 * n_per_block, current_n, element_op_);
                avx2_util::memcpy32_avx2(p_dst + 6 * DstGemmN, p_src + 6 * n_per_block, current_n, element_op_);
                avx2_util::memcpy32_avx2(p_dst + 7 * DstGemmN, p_src + 7 * n_per_block, current_n, element_op_);

                i_m_itr -= 8;
                p_dst += 8 * DstGemmN;
...@@ -994,10 +1116,14 @@ struct ThreadwiseTensorSliceTransferAvx2Specialization_MatC_Store_MxN
            }
            if(i_m_itr & 4)
            {
                avx2_util::memcpy32_avx2(p_dst + 0 * DstGemmN, p_src + 0 * n_per_block, current_n, element_op_);
                avx2_util::memcpy32_avx2(p_dst + 1 * DstGemmN, p_src + 1 * n_per_block, current_n, element_op_);
                avx2_util::memcpy32_avx2(p_dst + 2 * DstGemmN, p_src + 2 * n_per_block, current_n, element_op_);
                avx2_util::memcpy32_avx2(p_dst + 3 * DstGemmN, p_src + 3 * n_per_block, current_n, element_op_);

                p_dst += 4 * DstGemmN;
                p_src += 4 * n_per_block;
...@@ -1005,8 +1131,10 @@ struct ThreadwiseTensorSliceTransferAvx2Specialization_MatC_Store_MxN
            }
            if(i_m_itr & 2)
            {
                avx2_util::memcpy32_avx2(p_dst + 0 * DstGemmN, p_src + 0 * n_per_block, current_n, element_op_);
                avx2_util::memcpy32_avx2(p_dst + 1 * DstGemmN, p_src + 1 * n_per_block, current_n, element_op_);

                p_dst += 2 * DstGemmN;
                p_src += 2 * n_per_block;
...@@ -1014,7 +1142,8 @@ struct ThreadwiseTensorSliceTransferAvx2Specialization_MatC_Store_MxN
            }
            if(i_m_itr & 1)
            {
                avx2_util::memcpy32_avx2(p_dst + 0 * DstGemmN, p_src + 0 * n_per_block, current_n, element_op_);
            }

            // printf("xxxx %d\n",__LINE__);fflush(stdout);
...
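
A minimal usage sketch of the fused copy helper above (illustrative only; it assumes the enclosing namespace is ck::cpu::avx2_util as the surrounding file suggests, an AVX2-capable build, and that the header defining memcpy32_avx2 is included — its path is not shown in this diff):

// #include "element_wise_operation_cpu.hpp" plus the AVX2 transfer header from this commit
#include <cstdio>

int main()
{
    float src[13], dst[13];
    for(int i = 0; i < 13; ++i)
        src[i] = static_cast<float>(i - 6); // -6 .. 6

    ck::tensor_operation::cpu::element_wise::Relu relu;
    // copies 13 floats via the 8/4/1 tails, clamping negatives to 0 on the way
    ck::cpu::avx2_util::memcpy32_avx2(dst, src, 13, relu);

    for(int i = 0; i < 13; ++i)
        std::printf("%.0f ", dst[i]); // 0 0 0 0 0 0 0 1 2 3 4 5 6
    return 0;
}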
...@@ -19,7 +19,8 @@ using InLayout = ck::tensor_layout::gemm::RowMajor; / ...@@ -19,7 +19,8 @@ using InLayout = ck::tensor_layout::gemm::RowMajor; /
using WeiLayout = ck::tensor_layout::gemm::ColumnMajor; // KYXC using WeiLayout = ck::tensor_layout::gemm::ColumnMajor; // KYXC
static constexpr bool NonTemporalStore = false; static constexpr bool NonTemporalStore = false;
using PT = ck::tensor_operation::cpu::element_wise::PassThrough; using PT = ck::tensor_operation::cpu::element_wise::PassThrough;
using Relu = ck::tensor_operation::cpu::element_wise::Relu;
using ThreadwiseGemmAvx2_MxN_4x24_Dispatch = using ThreadwiseGemmAvx2_MxN_4x24_Dispatch =
ck::cpu::ThreadwiseGemmAvx2_MxN_4x24_Dispatch<InType, ck::cpu::ThreadwiseGemmAvx2_MxN_4x24_Dispatch<InType,
WeiType, WeiType,
...@@ -110,6 +111,59 @@ using device_conv2d_fwd_avx2_nhwc_kyxc_nhwk_f32_mt_instances = std::tuple< ...@@ -110,6 +111,59 @@ using device_conv2d_fwd_avx2_nhwc_kyxc_nhwk_f32_mt_instances = std::tuple<
DEVICE_CONV2D_FWD_AVX2_NHWC_KYXC_NHWK_F32(PT, PT, PT, 1024, 416, 128, 6, 16, true, true, true)>; DEVICE_CONV2D_FWD_AVX2_NHWC_KYXC_NHWK_F32(PT, PT, PT, 1024, 416, 128, 6, 16, true, true, true)>;
// clang-format on // clang-format on
using device_conv2d_fwd_avx2_nhwc_kyxc_nhwk_f32_relu_instances = std::tuple<
// clang-format off
DEVICE_CONV2D_FWD_AVX2_NHWC_KYXC_NHWK_F32(PT, PT, Relu, 256, 128, 64, 6, 16, true, true, false),
DEVICE_CONV2D_FWD_AVX2_NHWC_KYXC_NHWK_F32(PT, PT, Relu, 256, 128, 128, 6, 16, true, true, false),
DEVICE_CONV2D_FWD_AVX2_NHWC_KYXC_NHWK_F32(PT, PT, Relu, 128, 256, 128, 6, 16, true, true, false),
DEVICE_CONV2D_FWD_AVX2_NHWC_KYXC_NHWK_F32(PT, PT, Relu, 512, 240, 128, 4, 24, true, true, false),
DEVICE_CONV2D_FWD_AVX2_NHWC_KYXC_NHWK_F32(PT, PT, Relu, 512, 256, 128, 6, 16, true, true, false),
DEVICE_CONV2D_FWD_AVX2_NHWC_KYXC_NHWK_F32(PT, PT, Relu, 768, 320, 128, 6, 16, true, true, false),
DEVICE_CONV2D_FWD_AVX2_NHWC_KYXC_NHWK_F32(PT, PT, Relu, 896, 352, 128, 6, 16, true, true, false),
DEVICE_CONV2D_FWD_AVX2_NHWC_KYXC_NHWK_F32(PT, PT, Relu, 1024, 416, 128, 6, 16, true, true, false)>;
// clang-format on
// use this in single thread, but gemm_n is not multiple of 8
using device_conv2d_fwd_avx2_nhwc_kyxc_nhwk_f32_local_c_relu_instances = std::tuple<
// clang-format off
DEVICE_CONV2D_FWD_AVX2_NHWC_KYXC_NHWK_F32(PT, PT, Relu, 256, 128, 64, 6, 16, true, true, true),
DEVICE_CONV2D_FWD_AVX2_NHWC_KYXC_NHWK_F32(PT, PT, Relu, 256, 128, 128, 6, 16, true, true, true),
DEVICE_CONV2D_FWD_AVX2_NHWC_KYXC_NHWK_F32(PT, PT, Relu, 128, 256, 128, 6, 16, true, true, true),
DEVICE_CONV2D_FWD_AVX2_NHWC_KYXC_NHWK_F32(PT, PT, Relu, 512, 240, 128, 4, 24, true, true, true),
DEVICE_CONV2D_FWD_AVX2_NHWC_KYXC_NHWK_F32(PT, PT, Relu, 512, 256, 128, 6, 16, true, true, true),
DEVICE_CONV2D_FWD_AVX2_NHWC_KYXC_NHWK_F32(PT, PT, Relu, 768, 320, 128, 6, 16, true, true, true),
DEVICE_CONV2D_FWD_AVX2_NHWC_KYXC_NHWK_F32(PT, PT, Relu, 896, 352, 128, 6, 16, true, true, true),
DEVICE_CONV2D_FWD_AVX2_NHWC_KYXC_NHWK_F32(PT, PT, Relu, 1024, 416, 128, 6, 16, true, true, true)>;
// clang-format on
// use this in multi thread environment (need local C buffer to avoid cache coherence, although some
// time no local c is better...)
using device_conv2d_fwd_avx2_nhwc_kyxc_nhwk_f32_mt_relu_instances = std::tuple<
// clang-format off
DEVICE_CONV2D_FWD_AVX2_NHWC_KYXC_NHWK_F32(PT, PT, Relu, 48, 24, 128, 4, 24, true, true, true),
DEVICE_CONV2D_FWD_AVX2_NHWC_KYXC_NHWK_F32(PT, PT, Relu, 72, 16, 128, 6, 16, true, true, true),
DEVICE_CONV2D_FWD_AVX2_NHWC_KYXC_NHWK_F32(PT, PT, Relu, 72, 32, 128, 6, 16, true, true, true),
DEVICE_CONV2D_FWD_AVX2_NHWC_KYXC_NHWK_F32(PT, PT, Relu, 96, 32, 128, 6, 16, true, true, true),
DEVICE_CONV2D_FWD_AVX2_NHWC_KYXC_NHWK_F32(PT, PT, Relu, 96, 64, 128, 6, 16, true, true, true),
DEVICE_CONV2D_FWD_AVX2_NHWC_KYXC_NHWK_F32(PT, PT, Relu, 120, 32, 128, 6, 16, true, true, true),
DEVICE_CONV2D_FWD_AVX2_NHWC_KYXC_NHWK_F32(PT, PT, Relu, 120, 64, 128, 6, 16, true, true, true),
// DEVICE_CONV2D_FWD_AVX2_NHWC_KYXC_NHWK_F32(PT, PT, PT, 256, 128, 64, 6, 16, true, true, true),
DEVICE_CONV2D_FWD_AVX2_NHWC_KYXC_NHWK_F32(PT, PT, Relu, 256, 128, 128, 6, 16, true, true, true),
DEVICE_CONV2D_FWD_AVX2_NHWC_KYXC_NHWK_F32(PT, PT, Relu, 128, 256, 128, 6, 16, true, true, true),
DEVICE_CONV2D_FWD_AVX2_NHWC_KYXC_NHWK_F32(PT, PT, Relu, 512, 240, 128, 4, 24, true, true, true),
DEVICE_CONV2D_FWD_AVX2_NHWC_KYXC_NHWK_F32(PT, PT, Relu, 512, 256, 128, 6, 16, true, true, true),
DEVICE_CONV2D_FWD_AVX2_NHWC_KYXC_NHWK_F32(PT, PT, Relu, 768, 320, 128, 6, 16, true, true, true),
DEVICE_CONV2D_FWD_AVX2_NHWC_KYXC_NHWK_F32(PT, PT, Relu, 896, 352, 128, 6, 16, true, true, true),
DEVICE_CONV2D_FWD_AVX2_NHWC_KYXC_NHWK_F32(PT, PT, Relu, 1024, 416, 128, 6, 16, true, true, true)>;
// clang-format on
void add_device_conv2d_fwd_avx2_nhwc_kyxc_nhwk(std::vector<DeviceConvFwdPtr<PT, PT, PT>>& instances) void add_device_conv2d_fwd_avx2_nhwc_kyxc_nhwk(std::vector<DeviceConvFwdPtr<PT, PT, PT>>& instances)
{ {
ck::tensor_operation::device::add_device_operation_instances( ck::tensor_operation::device::add_device_operation_instances(
...@@ -130,6 +184,27 @@ void add_device_conv2d_fwd_avx2_nhwc_kyxc_nhwk_mt( ...@@ -130,6 +184,27 @@ void add_device_conv2d_fwd_avx2_nhwc_kyxc_nhwk_mt(
instances, device_conv2d_fwd_avx2_nhwc_kyxc_nhwk_f32_mt_instances{}); instances, device_conv2d_fwd_avx2_nhwc_kyxc_nhwk_f32_mt_instances{});
} }
void add_device_conv2d_fwd_avx2_nhwc_kyxc_nhwk_relu(
std::vector<DeviceConvFwdPtr<PT, PT, Relu>>& instances)
{
ck::tensor_operation::device::add_device_operation_instances(
instances, device_conv2d_fwd_avx2_nhwc_kyxc_nhwk_f32_relu_instances{});
}
void add_device_conv2d_fwd_avx2_nhwc_kyxc_nhwk_local_c_relu(
std::vector<DeviceConvFwdPtr<PT, PT, Relu>>& instances)
{
ck::tensor_operation::device::add_device_operation_instances(
instances, device_conv2d_fwd_avx2_nhwc_kyxc_nhwk_f32_local_c_relu_instances{});
}
void add_device_conv2d_fwd_avx2_nhwc_kyxc_nhwk_mt_relu(
std::vector<DeviceConvFwdPtr<PT, PT, Relu>>& instances)
{
ck::tensor_operation::device::add_device_operation_instances(
instances, device_conv2d_fwd_avx2_nhwc_kyxc_nhwk_f32_mt_relu_instances{});
}
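For illustration, a caller that only wants the fused convolution+ReLU kernels would go through the new entry points. The sketch below is hypothetical caller code, not part of the diff; it assumes the PT/Relu aliases of this file and that each instance exposes GetTypeString() like the other device operations in this repository:

// Hypothetical caller: collect the conv+ReLU instances and print their names.
std::vector<DeviceConvFwdPtr<PT, PT, Relu>> relu_instances;
add_device_conv2d_fwd_avx2_nhwc_kyxc_nhwk_relu(relu_instances);
for(const auto& inst : relu_instances)
    std::cout << inst->GetTypeString() << std::endl; // assumes <iostream> is available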
} // namespace device_conv2d_fwd_avx2_instance
} // namespace device
} // namespace cpu
...
...@@ -24,6 +24,15 @@ void add_device_conv2d_fwd_avx2_nhwc_kyxc_nhwk_local_c(
void add_device_conv2d_fwd_avx2_nhwc_kyxc_nhwk_mt(
    std::vector<DeviceConvFwdPtr<PassThrough, PassThrough, PassThrough>>& instances);
void add_device_conv2d_fwd_avx2_nhwc_kyxc_nhwk_relu(
std::vector<DeviceConvFwdPtr<PassThrough, PassThrough, Relu>>& instances);
void add_device_conv2d_fwd_avx2_nhwc_kyxc_nhwk_local_c_relu(
std::vector<DeviceConvFwdPtr<PassThrough, PassThrough, Relu>>& instances);
void add_device_conv2d_fwd_avx2_nhwc_kyxc_nhwk_mt_relu(
std::vector<DeviceConvFwdPtr<PassThrough, PassThrough, Relu>>& instances);
} // namespace device_conv2d_fwd_avx2_instance
} // namespace device
} // namespace cpu
...
...@@ -12,6 +12,11 @@
#include <omp.h>
#define AVX2_DATA_ALIGNMENT 32
#define TEST_FUSION_PASSTHROUGH 0
#define TEST_FUSION_RELU 1
#define TEST_FUSION TEST_FUSION_RELU
using F32 = float;
using F16 = ck::half_t;
...@@ -22,6 +27,7 @@ namespace device {
namespace device_conv2d_fwd_avx2_instance {
using PassThrough = ck::tensor_operation::cpu::element_wise::PassThrough;
using Relu = ck::tensor_operation::cpu::element_wise::Relu;
void add_device_conv2d_fwd_avx2_nhwc_kyxc_nhwk(
    std::vector<DeviceConvFwdPtr<PassThrough, PassThrough, PassThrough>>& instances);
...@@ -32,6 +38,15 @@ void add_device_conv2d_fwd_avx2_nhwc_kyxc_nhwk_local_c(
void add_device_conv2d_fwd_avx2_nhwc_kyxc_nhwk_mt(
    std::vector<DeviceConvFwdPtr<PassThrough, PassThrough, PassThrough>>& instances);
void add_device_conv2d_fwd_avx2_nhwc_kyxc_nhwk_relu(
std::vector<DeviceConvFwdPtr<PassThrough, PassThrough, Relu>>& instances);
void add_device_conv2d_fwd_avx2_nhwc_kyxc_nhwk_local_c_relu(
std::vector<DeviceConvFwdPtr<PassThrough, PassThrough, Relu>>& instances);
void add_device_conv2d_fwd_avx2_nhwc_kyxc_nhwk_mt_relu(
std::vector<DeviceConvFwdPtr<PassThrough, PassThrough, Relu>>& instances);
} // namespace device_conv2d_fwd_avx2_instance
} // namespace device
} // namespace cpu
...@@ -40,7 +55,12 @@ void add_device_conv2d_fwd_avx2_nhwc_kyxc_nhwk_mt(
using InElementOp = ck::tensor_operation::cpu::element_wise::PassThrough;
using WeiElementOp = ck::tensor_operation::cpu::element_wise::PassThrough;
#if TEST_FUSION == TEST_FUSION_PASSTHROUGH
using OutElementOp = ck::tensor_operation::cpu::element_wise::PassThrough;
#endif
#if TEST_FUSION == TEST_FUSION_RELU
using OutElementOp = ck::tensor_operation::cpu::element_wise::Relu;
#endif
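With this switch the fused output operator is applied to every value as it is written back: under TEST_FUSION_RELU the stored result is max(x, 0), while PassThrough stores the value unchanged. A minimal scalar illustration, assuming the OutElementOp alias selected above (the helper name demo_out_element_op is made up for this sketch):

// Sketch only: apply the selected output fusion to a single value.
static float demo_out_element_op(float acc)
{
    OutElementOp out_op; // PassThrough or Relu, depending on TEST_FUSION
    float y = 0.0f;
    out_op(y, acc);      // Relu writes max(acc, 0); PassThrough writes acc as-is
    return y;
}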
template <typename T>
static bool
...@@ -295,9 +315,16 @@ int main(int argc, char* argv[])
ref_invoker.Run(ref_argument);
}
using PassThrough = ck::tensor_operation::cpu::element_wise::PassThrough;
using Relu = ck::tensor_operation::cpu::element_wise::Relu;
#if TEST_FUSION == TEST_FUSION_PASSTHROUGH
using DeviceConvFwdNoOpPtr = ck::tensor_operation::cpu::device::
    DeviceConvFwdPtr<PassThrough, PassThrough, PassThrough>;
#endif
#if TEST_FUSION == TEST_FUSION_RELU
using DeviceConvFwdNoOpPtr =
ck::tensor_operation::cpu::device::DeviceConvFwdPtr<PassThrough, PassThrough, Relu>;
#endif
// add device Conv instances
std::vector<DeviceConvFwdNoOpPtr> conv_ptrs;
...@@ -306,6 +333,7 @@ int main(int argc, char* argv[])
ck::is_same_v<ck::remove_cv_t<WeiDataType>, float> &&
ck::is_same_v<ck::remove_cv_t<OutDataType>, float>)
{
#if TEST_FUSION == TEST_FUSION_PASSTHROUGH
if(omp_get_max_threads() > 1)
{
ck::tensor_operation::cpu::device::device_conv2d_fwd_avx2_instance::
...@@ -322,6 +350,25 @@ ck::tensor_operation::cpu::device::device_conv2d_fwd_avx2_instance::
ck::tensor_operation::cpu::device::device_conv2d_fwd_avx2_instance::
add_device_conv2d_fwd_avx2_nhwc_kyxc_nhwk_local_c(conv_ptrs);
}
#endif
#if TEST_FUSION == TEST_FUSION_RELU
if(omp_get_max_threads() > 1)
{
ck::tensor_operation::cpu::device::device_conv2d_fwd_avx2_instance::
add_device_conv2d_fwd_avx2_nhwc_kyxc_nhwk_mt_relu(conv_ptrs);
ck::tensor_operation::cpu::device::device_conv2d_fwd_avx2_instance::
add_device_conv2d_fwd_avx2_nhwc_kyxc_nhwk_relu(conv_ptrs);
}
else
{
if(K % 8 == 0)
ck::tensor_operation::cpu::device::device_conv2d_fwd_avx2_instance::
add_device_conv2d_fwd_avx2_nhwc_kyxc_nhwk_relu(conv_ptrs);
else
ck::tensor_operation::cpu::device::device_conv2d_fwd_avx2_instance::
add_device_conv2d_fwd_avx2_nhwc_kyxc_nhwk_local_c_relu(conv_ptrs);
}
#endif
}
if(conv_ptrs.size() <= 0)
...