Commit 090ba885 authored by carlushuang

add elementwise fusion support

parent 8ce9fe57
@@ -896,6 +896,11 @@ struct DeviceConvNDFwdAvx2_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K
            << "_B" << string_local_buffer(UseBLocalBuffer)
            << "_C" << string_local_buffer(UseCLocalBuffer)
            ;
+        if constexpr (!std::is_same<OutElementwiseOperation,
+                                    ck::tensor_operation::cpu::element_wise::PassThrough>::value)
+        {
+            str << "_" << OutElementwiseOperation::Name();
+        }
        // clang-format on
        return str.str();
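The hunk above only changes how an instance reports itself: when the fused output operator is anything other than PassThrough, its Name() is appended to the instance's type string. A minimal standalone sketch of that behavior (the base string and the two functor structs here are illustrative stand-ins, not the actual CK types):

// Illustrative only: how the Name() suffixing behaves for two stand-in operators.
#include <iostream>
#include <sstream>
#include <string>
#include <type_traits>

struct PassThrough { static constexpr const char* Name() { return "PassThrough"; } };
struct Relu        { static constexpr const char* Name() { return "Relu"; } };

template <typename OutElementwiseOperation>
std::string instance_name()
{
    std::ostringstream str;
    str << "DeviceConv2dFwdAvx2_MT256_NT128"; // hypothetical base name
    if constexpr(!std::is_same<OutElementwiseOperation, PassThrough>::value)
    {
        str << "_" << OutElementwiseOperation::Name();
    }
    return str.str();
}

int main()
{
    std::cout << instance_name<PassThrough>() << "\n"; // DeviceConv2dFwdAvx2_MT256_NT128
    std::cout << instance_name<Relu>() << "\n";        // DeviceConv2dFwdAvx2_MT256_NT128_Relu
    return 0;
}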
#pragma once

#include "data_type_cpu.hpp"

namespace ck {
namespace tensor_operation {
namespace cpu {
namespace element_wise {

using float8_t = ck::cpu::float8_t;
using float4_t = ck::cpu::float4_t;

struct PassThrough
{
-    void operator()(float& y, const float& x) const { y = x; }
-    void operator()(float4_t& y, const float4_t& x) const { y = x; }
-    void operator()(float8_t& y, const float8_t& x) const { y = x; }
+    void operator()(float& y, const float& x) const { y = Apply(x); }
+    void operator()(float4_t& y, const float4_t& x) const { y = Apply(x); }
+    void operator()(float8_t& y, const float8_t& x) const { y = Apply(x); }
+
+    float Apply(const float& x) const { return x; }
+    float4_t Apply(const float4_t& x) const { return x; }
+    float8_t Apply(const float8_t& x) const { return x; }
+
+    static constexpr char* Name() { return "PassThrough"; }
};

struct Add
{
-    void operator()(float& y, const float& x0, const float& x1) const { y = x0 + x1; }
+    void operator()(float& y, const float& x0, const float& x1) const { y = Apply(x0, x1); }

    void operator()(float4_t& y, const float4_t& x0, const float4_t& x1) const
    {
-        y = _mm_add_ps(x0, x1);
+        y = Apply(x0, x1);
    }

    void operator()(float8_t& y, const float8_t& x0, const float8_t& x1) const
    {
-        y = _mm256_add_ps(x0, x1);
+        y = Apply(x0, x1);
    }
+
+    float Apply(const float& x0, const float& x1) const { return x0 + x1; }
+    float4_t Apply(const float4_t& x0, const float4_t& x1) const { return _mm_add_ps(x0, x1); }
+    float8_t Apply(const float8_t& x0, const float8_t& x1) const { return _mm256_add_ps(x0, x1); }
+
+    static constexpr char* Name() { return "Add"; }
+};
+
+struct Relu
+{
+    void operator()(float& y, const float& x) const { y = Apply(x); }
+    void operator()(float4_t& y, const float4_t& x) const { y = Apply(x); }
+    void operator()(float8_t& y, const float8_t& x) const { y = Apply(x); }
+
+    float Apply(const float& x) const { return x > 0 ? x : 0; }
+    float4_t Apply(const float4_t& x) const { return _mm_max_ps(x, _mm_setzero_ps()); }
+    float8_t Apply(const float8_t& x) const { return _mm256_max_ps(x, _mm256_setzero_ps()); }
+
+    static constexpr char* Name() { return "Relu"; }
};

struct AlphaBetaAdd
{
    AlphaBetaAdd(float alpha, float beta) : alpha_(alpha), beta_(beta) {}

-    void operator()(float& y, const float& x0, const float& x1) const
-    {
-        y = alpha_ * x0 + beta_ * x1;
-    }
+    void operator()(float& y, const float& x0, const float& x1) const { y = Apply(x0, x1); }

    void operator()(float4_t& y, const float4_t& x0, const float4_t& x1) const
    {
-        y = _mm_add_ps(_mm_mul_ps(x0, _mm_set1_ps(alpha_)), _mm_mul_ps(x1, _mm_set1_ps(beta_)));
+        y = Apply(x0, x1);
    }

    void operator()(float8_t& y, const float8_t& x0, const float8_t& x1) const
    {
-        y = _mm256_add_ps(_mm256_mul_ps(x0, _mm256_set1_ps(alpha_)),
-                          _mm256_mul_ps(x1, _mm256_set1_ps(beta_)));
+        y = Apply(x0, x1);
    }

+    float Apply(const float& x0, const float& x1) const { return alpha_ * x0 + beta_ * x1; }
+
+    float4_t Apply(const float4_t& x0, const float4_t& x1) const
+    {
+        return _mm_add_ps(_mm_mul_ps(x0, _mm_set1_ps(alpha_)), _mm_mul_ps(x1, _mm_set1_ps(beta_)));
+    }
+
+    float8_t Apply(const float8_t& x0, const float8_t& x1) const
+    {
+        return _mm256_add_ps(_mm256_mul_ps(x0, _mm256_set1_ps(alpha_)),
+                             _mm256_mul_ps(x1, _mm256_set1_ps(beta_)));
+    }
+
+    static constexpr char* Name() { return "AlphaBetaAdd"; }
+
    float alpha_;
    float beta_;
};

struct AddRelu
{
-    void operator()(float& y, const float& x0, const float& x1) const
-    {
-        const float a = x0 + x1;
-        y = a > 0 ? a : 0;
-    }
+    void operator()(float& y, const float& x0, const float& x1) const { y = Apply(x0, x1); }

    void operator()(float4_t& y, const float4_t& x0, const float4_t& x1) const
    {
-        y = _mm_max_ps(_mm_add_ps(x0, x1), _mm_setzero_ps());
+        y = Apply(x0, x1);
    }

    void operator()(float8_t& y, const float8_t& x0, const float8_t& x1) const
    {
-        y = _mm256_max_ps(_mm256_add_ps(x0, x1), _mm256_setzero_ps());
+        y = Apply(x0, x1);
    }
+
+    float Apply(const float& x0, const float& x1) const
+    {
+        const float a = x0 + x1;
+        return a > 0 ? a : 0;
+    }
+
+    float4_t Apply(const float4_t& x0, const float4_t& x1) const
+    {
+        return _mm_max_ps(_mm_add_ps(x0, x1), _mm_setzero_ps());
+    }
+
+    float8_t Apply(const float8_t& x0, const float8_t& x1) const
+    {
+        return _mm256_max_ps(_mm256_add_ps(x0, x1), _mm256_setzero_ps());
+    }
+
+    static constexpr char* Name() { return "AddRelu"; }
};

-#if 0
-struct AddHardswish
-{
-    void operator()(float& y, const float& x0, const float& x1) const
-    {
-        float a = x0 + x1;
-        float b = a + float{3};
-        float c = (b > 0) * (b > float{6} ? float{6} : b) * a * float{0.166667};
-        y = c;
-    }
-
-    void operator()(half_t& y, const half_t& x0, const half_t& x1) const
-    {
-        float a = x0 + x1;
-        float b = a + float{3};
-        float c = (b > 0) * (b > float{6} ? float{6} : b) * a * float{0.166667};
-        y = c;
-    }
-};
-#endif
-
struct AddReluAdd
{
    void operator()(float& y, const float& x0, const float& x1, const float& x2) const
    {
        float a = x0 + x1;
        float b = a > 0 ? a : 0;
        float c = b + x2;
        y = c;
    }

    void operator()(float4_t& y, const float4_t& x0, const float4_t& x1, const float4_t& x2) const
    {
        float4_t a = _mm_add_ps(x0, x1);
        float4_t b = _mm_max_ps(a, _mm_setzero_ps());
        y = _mm_add_ps(b, x2);
    }

    void operator()(float8_t& y, const float8_t& x0, const float8_t& x1, const float8_t& x2) const
    {
        float8_t a = _mm256_add_ps(x0, x1);
        float8_t b = _mm256_max_ps(a, _mm256_setzero_ps());
        y = _mm256_add_ps(b, x2);
    }
+
+    static constexpr char* Name() { return "AddReluAdd"; }
};

-#if 0
-struct AddHardswishAdd
-{
-    void operator()(float& y, const float& x0, const float& x1, const float& x2) const
-    {
-        float a = x0 + x1;
-        float b = a + float{3};
-        float c = (b > 0) * (b > float{6} ? float{6} : b) * a * float{0.166667};
-        float d = c + x2;
-        y = d;
-    }
-
-    void operator()(half_t& y, const half_t& x0, const half_t& x1, const half_t& x2) const
-    {
-        float a = x0 + x1;
-        float b = a + float{3};
-        float c = (b > 0) * (b > float{6} ? float{6} : b) * a * float{0.166667};
-        float d = c + x2;
-        y = d;
-    }
-};
-#endif
-
-#if 0
-struct RequantReluRequant
-{
-    // FIXME: We just need one scale for Relu / Leaky Relu / PRelu
-    RequantReluRequant(float scaleGemm, float scaleRelu)
-        : scaleGemm_(scaleGemm), scaleRelu_(scaleRelu)
-    {
-    }
-
-    void operator()(int8_t& y, const int& x) const
-    {
-        float gemm_requant = scaleGemm_ * static_cast<float>(x);
-        float relu = gemm_requant > 0 ? gemm_requant : 0;
-        float relu_requant = scaleRelu_ * relu;
-        y = static_cast<int8_t>(relu_requant > 127 ? 127
-                                : relu_requant < -128 ? -128 : relu_requant);
-    }
-
-    // for reference_gemm
-    void operator()(float& y, const float& x) const
-    {
-        float gemm_requant = scaleGemm_ * x;
-        float relu = gemm_requant > 0 ? gemm_requant : 0;
-        float relu_requant = scaleRelu_ * relu;
-        y = static_cast<float>(relu_requant > 127 ? 127
-                               : relu_requant < -128 ? -128 : relu_requant);
-    }
-
-    float scaleGemm_;
-    float scaleRelu_;
-};
-#endif
-
// Unary operators are usually called element-wisely before/after the reduction is executed on the
// elements. They are needed for easy implementation of reduction types of AVG, NRM1, NRM2

template <typename Y, typename X, bool HasDividing = false>
struct UnaryIdentic;

template <>
struct UnaryIdentic<float, float, false>
{
    UnaryIdentic(const int32_t divider = 1) { (void)divider; };

    void operator()(float& y, const float& x) const { y = x; };
};

template <>
struct UnaryIdentic<float, float, true>
{
    UnaryIdentic(const int32_t divider = 1) { divider_ = divider; };

    void operator()(float& y, const float& x) const { y = x / type_convert<float>(divider_); };

    int32_t divider_ = 1;
};

template <>
struct UnaryIdentic<float4_t, float4_t, false>
{
    UnaryIdentic(const int32_t divider = 1) { (void)divider; };

    void operator()(float4_t& y, const float4_t& x) const { y = x; };
};

template <>
struct UnaryIdentic<float4_t, float4_t, true>
{
    UnaryIdentic(const int32_t divider = 1) { divider_ = divider; };

    void operator()(float4_t& y, const float4_t& x) const
    {
        y = _mm_div_ps(x, _mm_set1_ps(static_cast<float>(divider_)));
    };

    int32_t divider_ = 1;
};

template <>
struct UnaryIdentic<float8_t, float8_t, false>
{
    UnaryIdentic(const int32_t divider = 1) { (void)divider; };

    void operator()(float8_t& y, const float8_t& x) const { y = x; };
};

template <>
struct UnaryIdentic<float8_t, float8_t, true>
{
    UnaryIdentic(const int32_t divider = 1) { divider_ = divider; };

    void operator()(float8_t& y, const float8_t& x) const
    {
        y = _mm256_div_ps(x, _mm256_set1_ps(static_cast<float>(divider_)));
    };

    int32_t divider_ = 1;
};

template <typename Y, typename X, bool HasDividing = false>
struct UnarySquare;

template <>
struct UnarySquare<float, float, false>
{
    UnarySquare(const int32_t divider = 1) { (void)divider; };

    void operator()(float& y, const float& x) const { y = x * x; };
};

template <>
struct UnarySquare<float, float, true>
{
    UnarySquare(const int32_t divider = 1) { divider_ = divider; };

    void operator()(float& y, const float& x) const { y = x * x / type_convert<float>(divider_); };

    int32_t divider_ = 1;
};

template <>
struct UnarySquare<float4_t, float4_t, false>
{
    UnarySquare(const int32_t divider = 1) { (void)divider; };

    void operator()(float4_t& y, const float4_t& x) const { y = _mm_mul_ps(x, x); };
};

template <>
struct UnarySquare<float4_t, float4_t, true>
{
    UnarySquare(const int32_t divider = 1) { divider_ = divider; };

    void operator()(float4_t& y, const float4_t& x) const
    {
        y = _mm_div_ps(_mm_mul_ps(x, x), _mm_set1_ps(static_cast<float>(divider_)));
    };

    int32_t divider_ = 1;
};

template <>
struct UnarySquare<float8_t, float8_t, false>
{
    UnarySquare(const int32_t divider = 1) { (void)divider; };

    void operator()(float8_t& y, const float8_t& x) const { y = _mm256_mul_ps(x, x); };
};

template <>
struct UnarySquare<float8_t, float8_t, true>
{
    UnarySquare(const int32_t divider = 1) { divider_ = divider; };

    void operator()(float8_t& y, const float8_t& x) const
    {
        y = _mm256_div_ps(_mm256_mul_ps(x, x), _mm256_set1_ps(static_cast<float>(divider_)));
    };

    int32_t divider_ = 1;
};

template <typename Y, typename X>
struct UnaryAbs;

template <>
struct UnaryAbs<float, float>
{
    UnaryAbs(const int32_t divider = 1) { (void)divider; };

    void operator()(float& y, const float& x) const { y = abs(x); };
};

template <>
struct UnaryAbs<float4_t, float4_t>
{
    UnaryAbs(const int32_t divider = 1) { (void)divider; };

    void operator()(float4_t& y, const float4_t& x) const
    {
        __m128 Mask = _mm_castsi128_ps(_mm_set1_epi32(~0x80000000));
        y = _mm_and_ps(Mask, x);
    };
};

template <>
struct UnaryAbs<float8_t, float8_t>
{
    UnaryAbs(const int32_t divider = 1) { (void)divider; };

    void operator()(float8_t& y, const float8_t& x) const
    {
        __m256 Mask = _mm256_castsi256_ps(_mm256_set1_epi32(~0x80000000));
        y = _mm256_and_ps(Mask, x);
    };
};

template <typename Y, typename X>
struct UnarySqrt;

template <>
struct UnarySqrt<float, float>
{
    void operator()(float& y, const float& x) const { y = sqrtf(x); };
};

template <>
struct UnarySqrt<float4_t, float4_t>
{
    void operator()(float4_t& y, const float4_t& x) const { y = _mm_sqrt_ps(x); };
};

template <>
struct UnarySqrt<float8_t, float8_t>
{
    void operator()(float8_t& y, const float8_t& x) const { y = _mm256_sqrt_ps(x); };
};

} // namespace element_wise
} // namespace cpu
} // namespace tensor_operation
} // namespace ck
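A minimal usage sketch of the functors defined above (not part of the commit; the include path is assumed, and ck::cpu::float8_t is assumed to be __m256 or a binary-compatible 8-wide float vector on an AVX2 build):

// Illustrative only: exercising the scalar and 8-wide overloads plus Name().
#include "element_wise_operation_cpu.hpp" // assumed header name for the file above
#include <immintrin.h>
#include <cstdio>

int main()
{
    using namespace ck::tensor_operation::cpu::element_wise;

    Relu relu;
    float y = 0.f;
    relu(y, -3.5f); // scalar overload clamps negatives to 0
    std::printf("%s: %f\n", Relu::Name(), y);

    float8_t v = _mm256_set1_ps(-1.0f);
    float8_t r = _mm256_setzero_ps();
    relu(r, v); // 8-wide overload uses _mm256_max_ps against zero

    AddRelu add_relu;
    float z = 0.f;
    add_relu(z, 2.0f, -5.0f); // (2 + -5) -> 0 after the fused ReLU
    std::printf("%s: %f\n", AddRelu::Name(), z);
    return 0;
}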
@@ -128,6 +128,51 @@ struct GridwiseGemmAvx2_MxN
        return make_naive_tensor_descriptor_packed(make_tuple(m_per_blk, n_per_blk));
    }

+    static auto GetAMultiIndex(const ck::index_t m_per_blk, const ck::index_t k_per_blk)
+    {
+        if constexpr(std::is_same<typename ThreadwiseGemm_Dispatch::MatrixALayout,
+                                  ck::tensor_layout::gemm::RowMajor>::value)
+        {
+            // A : M, K
+            return ck::make_multi_index(m_per_blk, k_per_blk);
+        }
+        else
+        {
+            // A : K, M
+            return ck::make_multi_index(
+                k_per_blk,
+                math::integer_least_multiple(m_per_blk,
+                                             ThreadwiseGemm_Dispatch::MatrixAMinVectorSize));
+        }
+    }
+
+    static auto GetBMultiIndex(const ck::index_t k_per_blk, const ck::index_t n_per_blk)
+    {
+        // n_per_blk should be 8x
+        if constexpr(std::is_same<typename ThreadwiseGemm_Dispatch::MatrixBLayout,
+                                  ck::tensor_layout::gemm::RowMajor>::value)
+        {
+            // B : K, N
+            return ck::make_multi_index(
+                k_per_blk,
+                math::integer_least_multiple(n_per_blk,
+                                             ThreadwiseGemm_Dispatch::MatrixBMinVectorSize));
+        }
+        else
+        {
+            // B : N/8, K, N8
+            return ck::make_multi_index(
+                math::integer_divide_ceil(n_per_blk, ThreadwiseGemm_Dispatch::MatrixBMinVectorSize),
+                k_per_blk,
+                ThreadwiseGemm_Dispatch::MatrixBMinVectorSize);
+        }
+    }
+
+    static auto GetCMultiIndex(const ck::index_t m_per_blk, const ck::index_t n_per_blk)
+    {
+        return ck::make_multi_index(m_per_blk, n_per_blk);
+    }
+
    static constexpr bool CheckValidity(const AGridDesc& a_grid_desc,
                                        const BGridDesc& b_grid_desc,
                                        const CGridDesc& c_grid_desc)
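The helpers above size the slice windows handed to RunRead/RunWrite in the hunks below. A small illustration of the padded shapes they produce, assuming math::integer_divide_ceil and math::integer_least_multiple have their usual round-up meanings and MatrixBMinVectorSize is 8 (the AVX2 register width in floats):

// Illustrative only: padded slice shapes for a block whose N is not a multiple of 8.
#include <cstdio>

static int integer_divide_ceil(int a, int b) { return (a + b - 1) / b; }                   // assumed semantics
static int integer_least_multiple(int a, int b) { return integer_divide_ceil(a, b) * b; }  // assumed semantics

int main()
{
    const int k_per_blk = 64, n_per_blk = 100, vec = 8;

    // B laid out as N/8, K, N8: the leading dimension is a ceiling division,
    // so the trailing N is effectively padded up to a multiple of 8.
    std::printf("B slice = (%d, %d, %d)\n",
                integer_divide_ceil(n_per_blk, vec), k_per_blk, vec); // (13, 64, 8)

    // Column-major A pads M the same way via integer_least_multiple.
    std::printf("padded M = %d\n", integer_least_multiple(100, 8));   // 104
    return 0;
}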
@@ -300,14 +345,18 @@ struct GridwiseGemmAvx2_MxN
                UseCLocalBuffer ? GetCBlockDescriptor(mc_size, nc_size) : c_grid_desc;

            if constexpr(UseCLocalBuffer)
            {
-                c_threadwise_copy.SetDstSliceOrigin(c_grid_desc,
-                                                    ck::make_multi_index(i_mc, i_nc));
+                // c_threadwise_copy.SetDstSliceOrigin(c_grid_desc,
+                //                                     ck::make_multi_index(i_mc, i_nc));
            }
            else
            {
                c_threadwise_copy.SetSrcSliceOrigin(c_block_desc,
                                                    ck::make_multi_index(i_mc, i_nc));
-                c_threadwise_copy.Run(c_block_desc, c_block_buf, c_grid_desc, c_grid_buf);
+                c_threadwise_copy.RunRead(c_block_desc,
+                                          c_block_buf,
+                                          c_grid_desc,
+                                          c_grid_buf,
+                                          GetCMultiIndex(mc_size, nc_size));
            }

            for(ck::index_t i_kc = 0; i_kc < GemmK; i_kc += k_per_block)

@@ -317,8 +366,16 @@ struct GridwiseGemmAvx2_MxN
                auto a_block_desc = GetABlockDescriptor(mc_size, kc_size);
                auto b_block_desc = GetBBlockDescriptor(kc_size, nc_size);

-                a_threadwise_copy.Run(a_grid_desc, a_grid_buf, a_block_desc, a_block_buf);
-                b_threadwise_copy.Run(b_grid_desc, b_grid_buf, b_block_desc, b_block_buf);
+                a_threadwise_copy.RunRead(a_grid_desc,
+                                          a_grid_buf,
+                                          a_block_desc,
+                                          a_block_buf,
+                                          GetAMultiIndex(mc_size, kc_size));
+                b_threadwise_copy.RunRead(b_grid_desc,
+                                          b_grid_buf,
+                                          b_block_desc,
+                                          b_block_buf,
+                                          GetBMultiIndex(kc_size, nc_size));

                blockwise_gemm.Run(a_block_desc,
                                   a_block_buf,

@@ -338,8 +395,14 @@ struct GridwiseGemmAvx2_MxN
                }
            }

-            if constexpr(UseCLocalBuffer)
-                c_threadwise_copy.Run(c_block_desc, c_block_buf, c_grid_desc, c_grid_buf);
+            // if constexpr(UseCLocalBuffer)
+            c_threadwise_copy.SetDstSliceOrigin(c_grid_desc,
+                                                ck::make_multi_index(i_mc, i_nc));
+            c_threadwise_copy.RunWrite(c_block_desc,
+                                       c_block_buf,
+                                       c_grid_desc,
+                                       c_grid_buf,
+                                       GetCMultiIndex(mc_size, nc_size));
        }
    }
}

@@ -415,7 +478,11 @@ struct GridwiseGemmAvx2_MxN
                    ck::index_t kc_size = ck::math::min(GemmK - i_kc, k_per_block);
                    auto a_block_desc = GetABlockDescriptor(mc_size, kc_size);

-                    a_threadwise_copy.Run(a_grid_desc, a_grid_buf, a_block_desc, a_block_buf);
+                    a_threadwise_copy.RunRead(a_grid_desc,
+                                              a_grid_buf,
+                                              a_block_desc,
+                                              a_block_buf,
+                                              GetAMultiIndex(mc_size, kc_size));

                    b_threadwise_copy.SetSrcSliceOrigin(b_grid_desc,
                                                        ck::make_multi_index(0, i_kc, 0));

@@ -429,8 +496,11 @@ struct GridwiseGemmAvx2_MxN
                            nc_size, ThreadwiseGemm_Dispatch::MatrixBMinVectorSize);
                        auto b_block_desc = GetBBlockDescriptor(kc_size, nc_size);

-                        b_threadwise_copy.Run(
-                            b_grid_desc, b_grid_buf, b_block_desc, b_block_buf);
+                        b_threadwise_copy.RunRead(b_grid_desc,
+                                                  b_grid_buf,
+                                                  b_block_desc,
+                                                  b_block_buf,
+                                                  GetBMultiIndex(kc_size, nc_size));

                        auto c_block_desc = UseCLocalBuffer
                                                ? GetCBlockDescriptor(mc_size, nc_size)

@@ -440,8 +510,11 @@ struct GridwiseGemmAvx2_MxN
                        {
                            c_threadwise_copy.SetSrcSliceOrigin(
                                c_block_desc, ck::make_multi_index(i_mc, i_nc));
-                            c_threadwise_copy.Run(
-                                c_block_desc, c_block_buf, c_grid_desc, c_grid_buf);
+                            c_threadwise_copy.RunRead(c_block_desc,
+                                                      c_block_buf,
+                                                      c_grid_desc,
+                                                      c_grid_buf,
+                                                      GetCMultiIndex(mc_size, nc_size));
                        }

                        blockwise_gemm.Run(a_block_desc,

@@ -456,14 +529,36 @@ struct GridwiseGemmAvx2_MxN
                                           i_kc != 0);

                        if((i_nc + n_per_block) < GemmN)
+                        {
                            b_threadwise_copy.MoveSrcSliceWindow(b_grid_desc, b_move_k_step);
+                        }

                        if constexpr(UseCLocalBuffer)
                        {
                            c_threadwise_copy.SetDstSliceOrigin(
                                c_grid_desc, ck::make_multi_index(i_mc, i_nc));
-                            c_threadwise_copy.Run(
-                                c_block_desc, c_block_buf, c_grid_desc, c_grid_buf);
+
+                            c_threadwise_copy.RunWrite(c_block_desc,
+                                                       c_block_buf,
+                                                       c_grid_desc,
+                                                       c_grid_buf,
+                                                       GetCMultiIndex(mc_size, nc_size));
+                        }
+                        else
+                        {
+                            // only write for last K, since the RunWrite here is just doing
+                            // elementwise op from global to global
+                            if((i_kc + k_per_block) >= GemmK)
+                            {
+                                c_threadwise_copy.SetDstSliceOrigin(
+                                    c_grid_desc, ck::make_multi_index(i_mc, i_nc));
+                                c_threadwise_copy.RunWrite(c_block_desc,
+                                                           c_block_buf,
+                                                           c_grid_desc,
+                                                           c_grid_buf,
+                                                           GetCMultiIndex(mc_size, nc_size));
+                            }
                        }
                    }
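The net effect of splitting Run into RunRead/RunWrite is that the fused output element-wise operation is applied exactly once, while the C tile is stored (and, on the path that writes straight to global memory, only on the last K step). A self-contained toy illustration of that fusion idea, not CK code:

// Toy example: a GEMM whose epilogue applies the output element-wise op at write-back,
// instead of a separate pass over the output tensor afterwards.
#include <cstdio>
#include <vector>

struct Relu { float Apply(float x) const { return x > 0 ? x : 0; } };

template <typename OutElementOp>
void gemm_fused(const std::vector<float>& a, const std::vector<float>& b, std::vector<float>& c,
                int M, int N, int K, OutElementOp out_op)
{
    for(int m = 0; m < M; ++m)
        for(int n = 0; n < N; ++n)
        {
            float acc = 0.f;
            for(int k = 0; k < K; ++k)
                acc += a[m * K + k] * b[k * N + n];
            c[m * N + n] = out_op.Apply(acc); // fused epilogue: no second pass over C
        }
}

int main()
{
    std::vector<float> a = {1, -2, 3, -4}, b = {1, 0, 0, 1}, c(4);
    gemm_fused(a, b, c, 2, 2, 2, Relu{});
    std::printf("%g %g %g %g\n", c[0], c[1], c[2], c[3]); // 1 0 3 0: negatives clamped
    return 0;
}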
@@ -19,7 +19,8 @@
using InLayout = ck::tensor_layout::gemm::RowMajor;     // NHWC
using WeiLayout = ck::tensor_layout::gemm::ColumnMajor; // KYXC

static constexpr bool NonTemporalStore = false;

using PT = ck::tensor_operation::cpu::element_wise::PassThrough;
+using Relu = ck::tensor_operation::cpu::element_wise::Relu;

using ThreadwiseGemmAvx2_MxN_4x24_Dispatch =
    ck::cpu::ThreadwiseGemmAvx2_MxN_4x24_Dispatch<InType,
                                                  WeiType,

@@ -110,6 +111,59 @@ using device_conv2d_fwd_avx2_nhwc_kyxc_nhwk_f32_mt_instances = std::tuple<
    DEVICE_CONV2D_FWD_AVX2_NHWC_KYXC_NHWK_F32(PT, PT, PT, 1024, 416, 128, 6, 16, true, true, true)>;
// clang-format on

+using device_conv2d_fwd_avx2_nhwc_kyxc_nhwk_f32_relu_instances = std::tuple<
+    // clang-format off
+    DEVICE_CONV2D_FWD_AVX2_NHWC_KYXC_NHWK_F32(PT, PT, Relu, 256, 128, 64, 6, 16, true, true, false),
+    DEVICE_CONV2D_FWD_AVX2_NHWC_KYXC_NHWK_F32(PT, PT, Relu, 256, 128, 128, 6, 16, true, true, false),
+    DEVICE_CONV2D_FWD_AVX2_NHWC_KYXC_NHWK_F32(PT, PT, Relu, 128, 256, 128, 6, 16, true, true, false),
+    DEVICE_CONV2D_FWD_AVX2_NHWC_KYXC_NHWK_F32(PT, PT, Relu, 512, 240, 128, 4, 24, true, true, false),
+    DEVICE_CONV2D_FWD_AVX2_NHWC_KYXC_NHWK_F32(PT, PT, Relu, 512, 256, 128, 6, 16, true, true, false),
+    DEVICE_CONV2D_FWD_AVX2_NHWC_KYXC_NHWK_F32(PT, PT, Relu, 768, 320, 128, 6, 16, true, true, false),
+    DEVICE_CONV2D_FWD_AVX2_NHWC_KYXC_NHWK_F32(PT, PT, Relu, 896, 352, 128, 6, 16, true, true, false),
+    DEVICE_CONV2D_FWD_AVX2_NHWC_KYXC_NHWK_F32(PT, PT, Relu, 1024, 416, 128, 6, 16, true, true, false)>;
+// clang-format on
+
+// use this in single thread, but gemm_n is not multiple of 8
+using device_conv2d_fwd_avx2_nhwc_kyxc_nhwk_f32_local_c_relu_instances = std::tuple<
+    // clang-format off
+    DEVICE_CONV2D_FWD_AVX2_NHWC_KYXC_NHWK_F32(PT, PT, Relu, 256, 128, 64, 6, 16, true, true, true),
+    DEVICE_CONV2D_FWD_AVX2_NHWC_KYXC_NHWK_F32(PT, PT, Relu, 256, 128, 128, 6, 16, true, true, true),
+    DEVICE_CONV2D_FWD_AVX2_NHWC_KYXC_NHWK_F32(PT, PT, Relu, 128, 256, 128, 6, 16, true, true, true),
+    DEVICE_CONV2D_FWD_AVX2_NHWC_KYXC_NHWK_F32(PT, PT, Relu, 512, 240, 128, 4, 24, true, true, true),
+    DEVICE_CONV2D_FWD_AVX2_NHWC_KYXC_NHWK_F32(PT, PT, Relu, 512, 256, 128, 6, 16, true, true, true),
+    DEVICE_CONV2D_FWD_AVX2_NHWC_KYXC_NHWK_F32(PT, PT, Relu, 768, 320, 128, 6, 16, true, true, true),
+    DEVICE_CONV2D_FWD_AVX2_NHWC_KYXC_NHWK_F32(PT, PT, Relu, 896, 352, 128, 6, 16, true, true, true),
+    DEVICE_CONV2D_FWD_AVX2_NHWC_KYXC_NHWK_F32(PT, PT, Relu, 1024, 416, 128, 6, 16, true, true, true)>;
+// clang-format on
+
+// use this in multi thread environment (need local C buffer to avoid cache coherence, although some
+// time no local c is better...)
+using device_conv2d_fwd_avx2_nhwc_kyxc_nhwk_f32_mt_relu_instances = std::tuple<
+    // clang-format off
+    DEVICE_CONV2D_FWD_AVX2_NHWC_KYXC_NHWK_F32(PT, PT, Relu, 48, 24, 128, 4, 24, true, true, true),
+    DEVICE_CONV2D_FWD_AVX2_NHWC_KYXC_NHWK_F32(PT, PT, Relu, 72, 16, 128, 6, 16, true, true, true),
+    DEVICE_CONV2D_FWD_AVX2_NHWC_KYXC_NHWK_F32(PT, PT, Relu, 72, 32, 128, 6, 16, true, true, true),
+    DEVICE_CONV2D_FWD_AVX2_NHWC_KYXC_NHWK_F32(PT, PT, Relu, 96, 32, 128, 6, 16, true, true, true),
+    DEVICE_CONV2D_FWD_AVX2_NHWC_KYXC_NHWK_F32(PT, PT, Relu, 96, 64, 128, 6, 16, true, true, true),
+    DEVICE_CONV2D_FWD_AVX2_NHWC_KYXC_NHWK_F32(PT, PT, Relu, 120, 32, 128, 6, 16, true, true, true),
+    DEVICE_CONV2D_FWD_AVX2_NHWC_KYXC_NHWK_F32(PT, PT, Relu, 120, 64, 128, 6, 16, true, true, true),
+    // DEVICE_CONV2D_FWD_AVX2_NHWC_KYXC_NHWK_F32(PT, PT, PT, 256, 128, 64, 6, 16, true, true, true),
+    DEVICE_CONV2D_FWD_AVX2_NHWC_KYXC_NHWK_F32(PT, PT, Relu, 256, 128, 128, 6, 16, true, true, true),
+    DEVICE_CONV2D_FWD_AVX2_NHWC_KYXC_NHWK_F32(PT, PT, Relu, 128, 256, 128, 6, 16, true, true, true),
+    DEVICE_CONV2D_FWD_AVX2_NHWC_KYXC_NHWK_F32(PT, PT, Relu, 512, 240, 128, 4, 24, true, true, true),
+    DEVICE_CONV2D_FWD_AVX2_NHWC_KYXC_NHWK_F32(PT, PT, Relu, 512, 256, 128, 6, 16, true, true, true),
+    DEVICE_CONV2D_FWD_AVX2_NHWC_KYXC_NHWK_F32(PT, PT, Relu, 768, 320, 128, 6, 16, true, true, true),
+    DEVICE_CONV2D_FWD_AVX2_NHWC_KYXC_NHWK_F32(PT, PT, Relu, 896, 352, 128, 6, 16, true, true, true),
+    DEVICE_CONV2D_FWD_AVX2_NHWC_KYXC_NHWK_F32(PT, PT, Relu, 1024, 416, 128, 6, 16, true, true, true)>;
+// clang-format on
+
void add_device_conv2d_fwd_avx2_nhwc_kyxc_nhwk(std::vector<DeviceConvFwdPtr<PT, PT, PT>>& instances)
{
    ck::tensor_operation::device::add_device_operation_instances(

@@ -130,6 +184,27 @@ void add_device_conv2d_fwd_avx2_nhwc_kyxc_nhwk_mt(
        instances, device_conv2d_fwd_avx2_nhwc_kyxc_nhwk_f32_mt_instances{});
}

+void add_device_conv2d_fwd_avx2_nhwc_kyxc_nhwk_relu(
+    std::vector<DeviceConvFwdPtr<PT, PT, Relu>>& instances)
+{
+    ck::tensor_operation::device::add_device_operation_instances(
+        instances, device_conv2d_fwd_avx2_nhwc_kyxc_nhwk_f32_relu_instances{});
+}
+
+void add_device_conv2d_fwd_avx2_nhwc_kyxc_nhwk_local_c_relu(
+    std::vector<DeviceConvFwdPtr<PT, PT, Relu>>& instances)
+{
+    ck::tensor_operation::device::add_device_operation_instances(
+        instances, device_conv2d_fwd_avx2_nhwc_kyxc_nhwk_f32_local_c_relu_instances{});
+}
+
+void add_device_conv2d_fwd_avx2_nhwc_kyxc_nhwk_mt_relu(
+    std::vector<DeviceConvFwdPtr<PT, PT, Relu>>& instances)
+{
+    ck::tensor_operation::device::add_device_operation_instances(
+        instances, device_conv2d_fwd_avx2_nhwc_kyxc_nhwk_f32_mt_relu_instances{});
+}
+
} // namespace device_conv2d_fwd_avx2_instance
} // namespace device
} // namespace cpu
@@ -24,6 +24,15 @@ void add_device_conv2d_fwd_avx2_nhwc_kyxc_nhwk_local_c(
void add_device_conv2d_fwd_avx2_nhwc_kyxc_nhwk_mt(
    std::vector<DeviceConvFwdPtr<PassThrough, PassThrough, PassThrough>>& instances);

+void add_device_conv2d_fwd_avx2_nhwc_kyxc_nhwk_relu(
+    std::vector<DeviceConvFwdPtr<PassThrough, PassThrough, Relu>>& instances);
+
+void add_device_conv2d_fwd_avx2_nhwc_kyxc_nhwk_local_c_relu(
+    std::vector<DeviceConvFwdPtr<PassThrough, PassThrough, Relu>>& instances);
+
+void add_device_conv2d_fwd_avx2_nhwc_kyxc_nhwk_mt_relu(
+    std::vector<DeviceConvFwdPtr<PassThrough, PassThrough, Relu>>& instances);
+
} // namespace device_conv2d_fwd_avx2_instance
} // namespace device
} // namespace cpu
@@ -12,6 +12,11 @@
#include <omp.h>

#define AVX2_DATA_ALIGNMENT 32

+#define TEST_FUSION_PASSTHROUGH 0
+#define TEST_FUSION_RELU 1
+
+#define TEST_FUSION TEST_FUSION_RELU
+
using F32 = float;
using F16 = ck::half_t;

@@ -22,6 +27,7 @@ namespace device {
namespace device_conv2d_fwd_avx2_instance {

using PassThrough = ck::tensor_operation::cpu::element_wise::PassThrough;
+using Relu = ck::tensor_operation::cpu::element_wise::Relu;

void add_device_conv2d_fwd_avx2_nhwc_kyxc_nhwk(
    std::vector<DeviceConvFwdPtr<PassThrough, PassThrough, PassThrough>>& instances);

@@ -32,6 +38,15 @@ void add_device_conv2d_fwd_avx2_nhwc_kyxc_nhwk_local_c(
void add_device_conv2d_fwd_avx2_nhwc_kyxc_nhwk_mt(
    std::vector<DeviceConvFwdPtr<PassThrough, PassThrough, PassThrough>>& instances);

+void add_device_conv2d_fwd_avx2_nhwc_kyxc_nhwk_relu(
+    std::vector<DeviceConvFwdPtr<PassThrough, PassThrough, Relu>>& instances);
+
+void add_device_conv2d_fwd_avx2_nhwc_kyxc_nhwk_local_c_relu(
+    std::vector<DeviceConvFwdPtr<PassThrough, PassThrough, Relu>>& instances);
+
+void add_device_conv2d_fwd_avx2_nhwc_kyxc_nhwk_mt_relu(
+    std::vector<DeviceConvFwdPtr<PassThrough, PassThrough, Relu>>& instances);
+
} // namespace device_conv2d_fwd_avx2_instance
} // namespace device
} // namespace cpu

@@ -40,7 +55,12 @@ void add_device_conv2d_fwd_avx2_nhwc_kyxc_nhwk_mt(
using InElementOp  = ck::tensor_operation::cpu::element_wise::PassThrough;
using WeiElementOp = ck::tensor_operation::cpu::element_wise::PassThrough;
+#if TEST_FUSION == TEST_FUSION_PASSTHROUGH
using OutElementOp = ck::tensor_operation::cpu::element_wise::PassThrough;
+#endif
+#if TEST_FUSION == TEST_FUSION_RELU
+using OutElementOp = ck::tensor_operation::cpu::element_wise::Relu;
+#endif

template <typename T>
static bool

@@ -295,9 +315,16 @@ int main(int argc, char* argv[])
        ref_invoker.Run(ref_argument);
    }

    using PassThrough = ck::tensor_operation::cpu::element_wise::PassThrough;
+    using Relu        = ck::tensor_operation::cpu::element_wise::Relu;

+#if TEST_FUSION == TEST_FUSION_PASSTHROUGH
    using DeviceConvFwdNoOpPtr = ck::tensor_operation::cpu::device::
        DeviceConvFwdPtr<PassThrough, PassThrough, PassThrough>;
+#endif
+#if TEST_FUSION == TEST_FUSION_RELU
+    using DeviceConvFwdNoOpPtr =
+        ck::tensor_operation::cpu::device::DeviceConvFwdPtr<PassThrough, PassThrough, Relu>;
+#endif

    // add device Conv instances
    std::vector<DeviceConvFwdNoOpPtr> conv_ptrs;

@@ -306,6 +333,7 @@ int main(int argc, char* argv[])
       ck::is_same_v<ck::remove_cv_t<WeiDataType>, float> &&
       ck::is_same_v<ck::remove_cv_t<OutDataType>, float>)
    {
+#if TEST_FUSION == TEST_FUSION_PASSTHROUGH
        if(omp_get_max_threads() > 1)
        {
            ck::tensor_operation::cpu::device::device_conv2d_fwd_avx2_instance::

@@ -322,6 +350,25 @@ int main(int argc, char* argv[])
            ck::tensor_operation::cpu::device::device_conv2d_fwd_avx2_instance::
                add_device_conv2d_fwd_avx2_nhwc_kyxc_nhwk_local_c(conv_ptrs);
        }
+#endif
+#if TEST_FUSION == TEST_FUSION_RELU
+        if(omp_get_max_threads() > 1)
+        {
+            ck::tensor_operation::cpu::device::device_conv2d_fwd_avx2_instance::
+                add_device_conv2d_fwd_avx2_nhwc_kyxc_nhwk_mt_relu(conv_ptrs);
+            ck::tensor_operation::cpu::device::device_conv2d_fwd_avx2_instance::
+                add_device_conv2d_fwd_avx2_nhwc_kyxc_nhwk_relu(conv_ptrs);
+        }
+        else
+        {
+            if(K % 8 == 0)
+                ck::tensor_operation::cpu::device::device_conv2d_fwd_avx2_instance::
+                    add_device_conv2d_fwd_avx2_nhwc_kyxc_nhwk_relu(conv_ptrs);
+            else
+                ck::tensor_operation::cpu::device::device_conv2d_fwd_avx2_instance::
+                    add_device_conv2d_fwd_avx2_nhwc_kyxc_nhwk_local_c_relu(conv_ptrs);
+        }
+#endif
    }

    if(conv_ptrs.size() <= 0)