"vscode:/vscode.git/clone" did not exist on "881a6b58c3b5594d7f2ca1150b5a6779dceee808"
Commit 090ba885 authored by carlushuang

add elementwise fusion support
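The commit threads an output elementwise operation (for example Relu) through the AVX2 convolution/GEMM epilogue so the op is applied when the C tile is written out. Below is a minimal sketch of how such a functor is meant to be used, assuming float8_t is __m256 and that c_local, c_dst and tile_elements are caller-provided (float pointers and a multiple of 8); the loop is illustrative only, since in the diff the actual application happens inside the threadwise copy's RunWrite:

// illustrative sketch (assumes <immintrin.h> and the element_wise_operation_cpu header)
ck::tensor_operation::cpu::element_wise::Relu out_op;
for(ck::index_t i = 0; i < tile_elements; i += 8)
{
    ck::cpu::float8_t v = _mm256_loadu_ps(c_local + i); // float8_t assumed to be __m256
    out_op(v, v);                                       // y = max(x, 0) via _mm256_max_ps
    _mm256_storeu_ps(c_dst + i, v);
}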

parent 8ce9fe57
......@@ -896,6 +896,11 @@ struct DeviceConvNDFwdAvx2_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K
<< "_B" << string_local_buffer(UseBLocalBuffer)
<< "_C" << string_local_buffer(UseCLocalBuffer)
;
if constexpr (!std::is_same<OutElementwiseOperation,
ck::tensor_operation::cpu::element_wise::PassThrough>::value)
{
str << "_" << OutElementwiseOperation::Name();
}
// clang-format on
return str.str();
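With this change, any non-PassThrough output op simply appends its Name() to the instance string: an instance that previously ended with the _B/_C local-buffer flags gains a suffix such as "_Relu" when a Relu output op is fused, while PassThrough instances keep their old names.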
......
#pragma once
#include "data_type_cpu.hpp"
namespace ck {
namespace tensor_operation {
namespace cpu {
namespace element_wise {
using float8_t = ck::cpu::float8_t;
using float4_t = ck::cpu::float4_t;
struct PassThrough
{
void operator()(float& y, const float& x) const { y = x; }
void operator()(float4_t& y, const float4_t& x) const { y = x; }
void operator()(float8_t& y, const float8_t& x) const { y = x; }
};
struct Add
{
void operator()(float& y, const float& x0, const float& x1) const { y = x0 + x1; }
void operator()(float4_t& y, const float4_t& x0, const float4_t& x1) const
{
y = _mm_add_ps(x0, x1);
}
void operator()(float8_t& y, const float8_t& x0, const float8_t& x1) const
{
y = _mm256_add_ps(x0, x1);
}
};
struct AlphaBetaAdd
{
AlphaBetaAdd(float alpha, float beta) : alpha_(alpha), beta_(beta) {}
void operator()(float& y, const float& x0, const float& x1) const
{
y = alpha_ * x0 + beta_ * x1;
}
void operator()(float4_t& y, const float4_t& x0, const float4_t& x1) const
{
y = _mm_add_ps(_mm_mul_ps(x0, _mm_set1_ps(alpha_)), _mm_mul_ps(x1, _mm_set1_ps(beta_)));
}
void operator()(float8_t& y, const float8_t& x0, const float8_t& x1) const
{
y = _mm256_add_ps(_mm256_mul_ps(x0, _mm256_set1_ps(alpha_)),
_mm256_mul_ps(x1, _mm256_set1_ps(beta_)));
}
float alpha_;
float beta_;
};
struct AddRelu
{
void operator()(float& y, const float& x0, const float& x1) const
{
const float a = x0 + x1;
y = a > 0 ? a : 0;
}
void operator()(float4_t& y, const float4_t& x0, const float4_t& x1) const
{
y = _mm_max_ps(_mm_add_ps(x0, x1), _mm_setzero_ps());
}
void operator()(float8_t& y, const float8_t& x0, const float8_t& x1) const
{
y = _mm256_max_ps(_mm256_add_ps(x0, x1), _mm256_setzero_ps());
}
};
#if 0
struct AddHardswish
{
void operator()(float& y, const float& x0, const float& x1) const
{
float a = x0 + x1;
float b = a + float{3};
float c = (b > 0) * (b > float{6} ? float{6} : b) * a * float{0.166667};
y = c;
}
void
operator()(half_t& y, const half_t& x0, const half_t& x1) const
{
float a = x0 + x1;
float b = a + float{3};
float c = (b > 0) * (b > float{6} ? float{6} : b) * a * float{0.166667};
y = c;
}
};
#endif
struct AddReluAdd
{
void operator()(float& y, const float& x0, const float& x1, const float& x2) const
{
float a = x0 + x1;
float b = a > 0 ? a : 0;
float c = b + x2;
y = c;
}
void operator()(float4_t& y, const float4_t& x0, const float4_t& x1, const float4_t& x2) const
{
float4_t a = _mm_add_ps(x0, x1);
float4_t b = _mm_max_ps(a, _mm_setzero_ps());
y = _mm_add_ps(b, x2);
}
void operator()(float8_t& y, const float8_t& x0, const float8_t& x1, const float8_t& x2) const
{
float8_t a = _mm256_add_ps(x0, x1);
float8_t b = _mm256_max_ps(a, _mm256_setzero_ps());
y = _mm256_add_ps(b, x2);
}
};
#if 0
struct AddHardswishAdd
{
void
operator()(float& y, const float& x0, const float& x1, const float& x2) const
{
float a = x0 + x1;
float b = a + float{3};
float c = (b > 0) * (b > float{6} ? float{6} : b) * a * float{0.166667};
float d = c + x2;
y = d;
}
void
operator()(half_t& y, const half_t& x0, const half_t& x1, const half_t& x2) const
{
float a = x0 + x1;
float b = a + float{3};
float c = (b > 0) * (b > float{6} ? float{6} : b) * a * float{0.166667};
float d = c + x2;
y = d;
}
};
#endif
#if 0
struct RequantReluRequant
{
// FIXME: We just need one scale for Relu / Leaky Relu / PRelu
RequantReluRequant(float scaleGemm, float scaleRelu)
: scaleGemm_(scaleGemm), scaleRelu_(scaleRelu)
{
}
void operator()(int8_t& y, const int& x) const
{
float gemm_requant = scaleGemm_ * static_cast<float>(x);
float relu = gemm_requant > 0 ? gemm_requant : 0;
float relu_requant = scaleRelu_ * relu;
y = static_cast<int8_t>(relu_requant > 127 ? 127
: relu_requant < -128 ? -128 : relu_requant);
}
// for reference_gemm
void operator()(float& y, const float& x) const
{
float gemm_requant = scaleGemm_ * x;
float relu = gemm_requant > 0 ? gemm_requant : 0;
float relu_requant = scaleRelu_ * relu;
y = static_cast<float>(relu_requant > 127 ? 127
: relu_requant < -128 ? -128 : relu_requant);
}
float scaleGemm_;
float scaleRelu_;
};
#endif
// Unary operators are usually applied element-wise before/after the reduction is executed on
// the elements. They make it easy to implement the AVG, NRM1 and NRM2 reduction types.
template <typename Y, typename X, bool HasDividing = false>
struct UnaryIdentic;
template <>
struct UnaryIdentic<float, float, false>
{
UnaryIdentic(const int32_t divider = 1) { (void)divider; };
void operator()(float& y, const float& x) const { y = x; };
};
template <>
struct UnaryIdentic<float, float, true>
{
UnaryIdentic(const int32_t divider = 1) { divider_ = divider; };
void operator()(float& y, const float& x) const { y = x / type_convert<float>(divider_); };
int32_t divider_ = 1;
};
template <>
struct UnaryIdentic<float4_t, float4_t, false>
{
UnaryIdentic(const int32_t divider = 1) { (void)divider; };
void operator()(float4_t& y, const float4_t& x) const { y = x; };
};
template <>
struct UnaryIdentic<float4_t, float4_t, true>
{
UnaryIdentic(const int32_t divider = 1) { divider_ = divider; };
void operator()(float4_t& y, const float4_t& x) const
{
y = _mm_div_ps(x, _mm_set1_ps(static_cast<float>(divider_)));
};
int32_t divider_ = 1;
};
template <>
struct UnaryIdentic<float8_t, float8_t, false>
{
UnaryIdentic(const int32_t divider = 1) { (void)divider; };
void operator()(float8_t& y, const float8_t& x) const { y = x; };
};
template <>
struct UnaryIdentic<float8_t, float8_t, true>
{
UnaryIdentic(const int32_t divider = 1) { divider_ = divider; };
void operator()(float8_t& y, const float8_t& x) const
{
y = _mm256_div_ps(x, _mm256_set1_ps(static_cast<float>(divider_)));
};
int32_t divider_ = 1;
};
template <typename Y, typename X, bool HasDividing = false>
struct UnarySquare;
template <>
struct UnarySquare<float, float, false>
{
UnarySquare(const int32_t divider = 1) { (void)divider; };
void operator()(float& y, const float& x) const { y = x * x; };
};
template <>
struct UnarySquare<float, float, true>
{
UnarySquare(const int32_t divider = 1) { divider_ = divider; };
void operator()(float& y, const float& x) const { y = x * x / type_convert<float>(divider_); };
int32_t divider_ = 1;
};
template <>
struct UnarySquare<float4_t, float4_t, false>
{
UnarySquare(const int32_t divider = 1) { (void)divider; };
void operator()(float4_t& y, const float4_t& x) const { y = _mm_mul_ps(x, x); };
};
template <>
struct UnarySquare<float4_t, float4_t, true>
{
UnarySquare(const int32_t divider = 1) { divider_ = divider; };
void operator()(float4_t& y, const float4_t& x) const
{
y = _mm_div_ps(_mm_mul_ps(x, x), _mm_set1_ps(static_cast<float>(divider_)));
};
int32_t divider_ = 1;
};
template <>
struct UnarySquare<float8_t, float8_t, false>
{
UnarySquare(const int32_t divider = 1) { (void)divider; };
void operator()(float8_t& y, const float8_t& x) const { y = _mm256_mul_ps(x, x); };
};
template <>
struct UnarySquare<float8_t, float8_t, true>
{
UnarySquare(const int32_t divider = 1) { divider_ = divider; };
void operator()(float8_t& y, const float8_t& x) const
{
y = _mm256_div_ps(_mm256_mul_ps(x, x), _mm256_set1_ps(static_cast<float>(divider_)));
};
int32_t divider_ = 1;
};
template <typename Y, typename X>
struct UnaryAbs;
template <>
struct UnaryAbs<float, float>
{
UnaryAbs(const int32_t divider = 1) { (void)divider; };
void operator()(float& y, const float& x) const { y = std::abs(x); };
};
template <>
struct UnaryAbs<float4_t, float4_t>
{
UnaryAbs(const int32_t divider = 1) { (void)divider; };
void operator()(float4_t& y, const float4_t& x) const
{
__m128 Mask = _mm_castsi128_ps(_mm_set1_epi32(~0x80000000));
y = _mm_and_ps(Mask, x);
};
};
template <>
struct UnaryAbs<float8_t, float8_t>
{
UnaryAbs(const int32_t divider = 1) { (void)divider; };
void operator()(float8_t& y, const float8_t& x) const
{
__m256 Mask = _mm256_castsi256_ps(_mm256_set1_epi32(~0x80000000));
y = _mm256_and_ps(Mask, x);
};
};
template <typename Y, typename X>
struct UnarySqrt;
template <>
struct UnarySqrt<float, float>
{
void operator()(float& y, const float& x) const { y = sqrtf(x); };
};
template <>
struct UnarySqrt<float4_t, float4_t>
{
void operator()(float4_t& y, const float4_t& x) const { y = _mm_sqrt_ps(x); };
};
template <>
struct UnarySqrt<float8_t, float8_t>
{
void operator()(float8_t& y, const float8_t& x) const { y = _mm256_sqrt_ps(x); };
};
} // namespace element_wise
} // namespace cpu
} // namespace tensor_operation
} // namespace ck
#pragma once
#include "data_type_cpu.hpp"
namespace ck {
namespace tensor_operation {
namespace cpu {
namespace element_wise {
using float8_t = ck::cpu::float8_t;
using float4_t = ck::cpu::float4_t;
struct PassThrough
{
void operator()(float& y, const float& x) const { y = Apply(x); }
void operator()(float4_t& y, const float4_t& x) const { y = Apply(x); }
void operator()(float8_t& y, const float8_t& x) const { y = Apply(x); }
float Apply(const float& x) const { return x; }
float4_t Apply(const float4_t& x) const { return x; }
float8_t Apply(const float8_t& x) const { return x; }
static constexpr char* Name() { return "PassThrough"; }
};
struct Add
{
void operator()(float& y, const float& x0, const float& x1) const { y = Apply(x0, x1); }
void operator()(float4_t& y, const float4_t& x0, const float4_t& x1) const
{
y = Apply(x0, x1);
}
void operator()(float8_t& y, const float8_t& x0, const float8_t& x1) const
{
y = Apply(x0, x1);
}
float Apply(const float& x0, const float& x1) const { return x0 + x1; }
float4_t Apply(const float4_t& x0, const float4_t& x1) const { return _mm_add_ps(x0, x1); }
float8_t Apply(const float8_t& x0, const float8_t& x1) const { return _mm256_add_ps(x0, x1); }
static constexpr char* Name() { return "Add"; }
};
struct Relu
{
void operator()(float& y, const float& x) const { y = Apply(x); }
void operator()(float4_t& y, const float4_t& x) const { y = Apply(x); }
void operator()(float8_t& y, const float8_t& x) const { y = Apply(x); }
float Apply(const float& x) const { return x > 0 ? x : 0; }
float4_t Apply(const float4_t& x) const { return _mm_max_ps(x, _mm_setzero_ps()); }
float8_t Apply(const float8_t& x) const { return _mm256_max_ps(x, _mm256_setzero_ps()); }
static constexpr char* Name() { return "Relu"; }
};
struct AlphaBetaAdd
{
AlphaBetaAdd(float alpha, float beta) : alpha_(alpha), beta_(beta) {}
void operator()(float& y, const float& x0, const float& x1) const { y = Apply(x0, x1); }
void operator()(float4_t& y, const float4_t& x0, const float4_t& x1) const
{
y = Apply(x0, x1);
}
void operator()(float8_t& y, const float8_t& x0, const float8_t& x1) const
{
y = Apply(x0, x1);
}
float Apply(const float& x0, const float& x1) const { return alpha_ * x0 + beta_ * x1; }
float4_t Apply(const float4_t& x0, const float4_t& x1) const
{
return _mm_add_ps(_mm_mul_ps(x0, _mm_set1_ps(alpha_)), _mm_mul_ps(x1, _mm_set1_ps(beta_)));
}
float8_t Apply(const float8_t& x0, const float8_t& x1) const
{
return _mm256_add_ps(_mm256_mul_ps(x0, _mm256_set1_ps(alpha_)),
_mm256_mul_ps(x1, _mm256_set1_ps(beta_)));
}
static constexpr char* Name() { return "AlphaBetaAdd"; }
float alpha_;
float beta_;
};
struct AddRelu
{
void operator()(float& y, const float& x0, const float& x1) const { y = Apply(x0, x1); }
void operator()(float4_t& y, const float4_t& x0, const float4_t& x1) const
{
y = Apply(x0, x1);
}
void operator()(float8_t& y, const float8_t& x0, const float8_t& x1) const
{
y = Apply(x0, x1);
}
float Apply(const float& x0, const float& x1) const
{
const float a = x0 + x1;
return a > 0 ? a : 0;
}
float4_t Apply(const float4_t& x0, const float4_t& x1) const
{
return _mm_max_ps(_mm_add_ps(x0, x1), _mm_setzero_ps());
}
float8_t Apply(const float8_t& x0, const float8_t& x1) const
{
return _mm256_max_ps(_mm256_add_ps(x0, x1), _mm256_setzero_ps());
}
static constexpr char* Name() { return "AddRelu"; }
};
struct AddReluAdd
{
void operator()(float& y, const float& x0, const float& x1, const float& x2) const
{
float a = x0 + x1;
float b = a > 0 ? a : 0;
float c = b + x2;
y = c;
}
void operator()(float4_t& y, const float4_t& x0, const float4_t& x1, const float4_t& x2) const
{
float4_t a = _mm_add_ps(x0, x1);
float4_t b = _mm_max_ps(a, _mm_setzero_ps());
y = _mm_add_ps(b, x2);
}
void operator()(float8_t& y, const float8_t& x0, const float8_t& x1, const float8_t& x2) const
{
float8_t a = _mm256_add_ps(x0, x1);
float8_t b = _mm256_max_ps(a, _mm256_setzero_ps());
y = _mm256_add_ps(b, x2);
}
static constexpr char* Name() { return "AddReluAdd"; }
};
// Unary operators are usually applied element-wise before/after the reduction is executed on
// the elements. They make it easy to implement the AVG, NRM1 and NRM2 reduction types.
template <typename Y, typename X, bool HasDividing = false>
struct UnaryIdentic;
template <>
struct UnaryIdentic<float, float, false>
{
UnaryIdentic(const int32_t divider = 1) { (void)divider; };
void operator()(float& y, const float& x) const { y = x; };
};
template <>
struct UnaryIdentic<float, float, true>
{
UnaryIdentic(const int32_t divider = 1) { divider_ = divider; };
void operator()(float& y, const float& x) const { y = x / type_convert<float>(divider_); };
int32_t divider_ = 1;
};
template <>
struct UnaryIdentic<float4_t, float4_t, false>
{
UnaryIdentic(const int32_t divider = 1) { (void)divider; };
void operator()(float4_t& y, const float4_t& x) const { y = x; };
};
template <>
struct UnaryIdentic<float4_t, float4_t, true>
{
UnaryIdentic(const int32_t divider = 1) { divider_ = divider; };
void operator()(float4_t& y, const float4_t& x) const
{
y = _mm_div_ps(x, _mm_set1_ps(static_cast<float>(divider_)));
};
int32_t divider_ = 1;
};
template <>
struct UnaryIdentic<float8_t, float8_t, false>
{
UnaryIdentic(const int32_t divider = 1) { (void)divider; };
void operator()(float8_t& y, const float8_t& x) const { y = x; };
};
template <>
struct UnaryIdentic<float8_t, float8_t, true>
{
UnaryIdentic(const int32_t divider = 1) { divider_ = divider; };
void operator()(float8_t& y, const float8_t& x) const
{
y = _mm256_div_ps(x, _mm256_set1_ps(static_cast<float>(divider_)));
};
int32_t divider_ = 1;
};
template <typename Y, typename X, bool HasDividing = false>
struct UnarySquare;
template <>
struct UnarySquare<float, float, false>
{
UnarySquare(const int32_t divider = 1) { (void)divider; };
void operator()(float& y, const float& x) const { y = x * x; };
};
template <>
struct UnarySquare<float, float, true>
{
UnarySquare(const int32_t divider = 1) { divider_ = divider; };
void operator()(float& y, const float& x) const { y = x * x / type_convert<float>(divider_); };
int32_t divider_ = 1;
};
template <>
struct UnarySquare<float4_t, float4_t, false>
{
UnarySquare(const int32_t divider = 1) { (void)divider; };
void operator()(float4_t& y, const float4_t& x) const { y = _mm_mul_ps(x, x); };
};
template <>
struct UnarySquare<float4_t, float4_t, true>
{
UnarySquare(const int32_t divider = 1) { divider_ = divider; };
void operator()(float4_t& y, const float4_t& x) const
{
y = _mm_div_ps(_mm_mul_ps(x, x), _mm_set1_ps(static_cast<float>(divider_)));
};
int32_t divider_ = 1;
};
template <>
struct UnarySquare<float8_t, float8_t, false>
{
UnarySquare(const int32_t divider = 1) { (void)divider; };
void operator()(float8_t& y, const float8_t& x) const { y = _mm256_mul_ps(x, x); };
};
template <>
struct UnarySquare<float8_t, float8_t, true>
{
UnarySquare(const int32_t divider = 1) { divider_ = divider; };
void operator()(float8_t& y, const float8_t& x) const
{
y = _mm256_div_ps(_mm256_mul_ps(x, x), _mm256_set1_ps(static_cast<float>(divider_)));
};
int32_t divider_ = 1;
};
template <typename Y, typename X>
struct UnaryAbs;
template <>
struct UnaryAbs<float, float>
{
UnaryAbs(const int32_t divider = 1) { (void)divider; };
void operator()(float& y, const float& x) const { y = std::abs(x); };
};
template <>
struct UnaryAbs<float4_t, float4_t>
{
UnaryAbs(const int32_t divider = 1) { (void)divider; };
void operator()(float4_t& y, const float4_t& x) const
{
__m128 Mask = _mm_castsi128_ps(_mm_set1_epi32(~0x80000000));
y = _mm_and_ps(Mask, x);
};
};
template <>
struct UnaryAbs<float8_t, float8_t>
{
UnaryAbs(const int32_t divider = 1) { (void)divider; };
void operator()(float8_t& y, const float8_t& x) const
{
__m256 Mask = _mm256_castsi256_ps(_mm256_set1_epi32(~0x80000000));
y = _mm256_and_ps(Mask, x);
};
};
template <typename Y, typename X>
struct UnarySqrt;
template <>
struct UnarySqrt<float, float>
{
void operator()(float& y, const float& x) const { y = sqrtf(x); };
};
template <>
struct UnarySqrt<float4_t, float4_t>
{
void operator()(float4_t& y, const float4_t& x) const { y = _mm_sqrt_ps(x); };
};
template <>
struct UnarySqrt<float8_t, float8_t>
{
void operator()(float8_t& y, const float8_t& x) const { y = _mm256_sqrt_ps(x); };
};
} // namespace element_wise
} // namespace cpu
} // namespace tensor_operation
} // namespace ck
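The header's comment notes that the unary operators exist to make AVG/NRM1/NRM2-style reductions easy to express. A minimal sketch of that pattern over a plain scalar loop, assuming a caller-provided float array data of length n (the float4_t/float8_t specializations follow the same shape with the _mm/_mm256 intrinsics):

// NRM2: square pre-op, sqrt post-op. AVG would use UnaryIdentic<..., true> with a divider instead.
using namespace ck::tensor_operation::cpu::element_wise;
UnarySquare<float, float, false> pre_op; // y = x * x
UnarySqrt<float, float> post_op;         // y = sqrtf(x)
float acc = 0.0f;
for(int i = 0; i < n; ++i)
{
    float t;
    pre_op(t, data[i]);
    acc += t;
}
float nrm2;
post_op(nrm2, acc); // Euclidean norm of data[0..n)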
......@@ -128,6 +128,51 @@ struct GridwiseGemmAvx2_MxN
return make_naive_tensor_descriptor_packed(make_tuple(m_per_blk, n_per_blk));
}
static auto GetAMultiIndex(const ck::index_t m_per_blk, const ck::index_t k_per_blk)
{
if constexpr(std::is_same<typename ThreadwiseGemm_Dispatch::MatrixALayout,
ck::tensor_layout::gemm::RowMajor>::value)
{
// A : M, K
return ck::make_multi_index(m_per_blk, k_per_blk);
}
else
{
// A : K, M
return ck::make_multi_index(
k_per_blk,
math::integer_least_multiple(m_per_blk,
ThreadwiseGemm_Dispatch::MatrixAMinVectorSize));
}
}
static auto GetBMultiIndex(const ck::index_t k_per_blk, const ck::index_t n_per_blk)
{
// n_per_blk should be a multiple of 8
if constexpr(std::is_same<typename ThreadwiseGemm_Dispatch::MatrixBLayout,
ck::tensor_layout::gemm::RowMajor>::value)
{
// B : K, N
return ck::make_multi_index(
k_per_blk,
math::integer_least_multiple(n_per_blk,
ThreadwiseGemm_Dispatch::MatrixBMinVectorSize));
}
else
{
// B : N/8, K, N8
return ck::make_multi_index(
math::integer_divide_ceil(n_per_blk, ThreadwiseGemm_Dispatch::MatrixBMinVectorSize),
k_per_blk,
ThreadwiseGemm_Dispatch::MatrixBMinVectorSize);
}
}
static auto GetCMultiIndex(const ck::index_t m_per_blk, const ck::index_t n_per_blk)
{
return ck::make_multi_index(m_per_blk, n_per_blk);
}
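The added helpers pad the slice lengths to the dispatch's vector sizes before handing them to RunRead. As a worked example, assuming MatrixBMinVectorSize == 8 (the AVX2 register width in floats) and a column-major B, a tail block with n_per_blk = 20 and k_per_blk = 64 gives GetBMultiIndex(64, 20) = (integer_divide_ceil(20, 8), 64, 8) = (3, 64, 8), i.e. the copy reads three 8-wide column groups; the row-major case instead rounds 20 up to integer_least_multiple(20, 8) = 24.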
static constexpr bool CheckValidity(const AGridDesc& a_grid_desc,
const BGridDesc& b_grid_desc,
const CGridDesc& c_grid_desc)
......@@ -300,14 +345,18 @@ struct GridwiseGemmAvx2_MxN
UseCLocalBuffer ? GetCBlockDescriptor(mc_size, nc_size) : c_grid_desc;
if constexpr(UseCLocalBuffer)
{
c_threadwise_copy.SetDstSliceOrigin(c_grid_desc,
ck::make_multi_index(i_mc, i_nc));
// c_threadwise_copy.SetDstSliceOrigin(c_grid_desc,
// ck::make_multi_index(i_mc, i_nc));
}
else
{
c_threadwise_copy.SetSrcSliceOrigin(c_block_desc,
ck::make_multi_index(i_mc, i_nc));
c_threadwise_copy.Run(c_block_desc, c_block_buf, c_grid_desc, c_grid_buf);
c_threadwise_copy.RunRead(c_block_desc,
c_block_buf,
c_grid_desc,
c_grid_buf,
GetCMultiIndex(mc_size, nc_size));
}
for(ck::index_t i_kc = 0; i_kc < GemmK; i_kc += k_per_block)
......@@ -317,8 +366,16 @@ struct GridwiseGemmAvx2_MxN
auto a_block_desc = GetABlockDescriptor(mc_size, kc_size);
auto b_block_desc = GetBBlockDescriptor(kc_size, nc_size);
a_threadwise_copy.Run(a_grid_desc, a_grid_buf, a_block_desc, a_block_buf);
b_threadwise_copy.Run(b_grid_desc, b_grid_buf, b_block_desc, b_block_buf);
a_threadwise_copy.RunRead(a_grid_desc,
a_grid_buf,
a_block_desc,
a_block_buf,
GetAMultiIndex(mc_size, kc_size));
b_threadwise_copy.RunRead(b_grid_desc,
b_grid_buf,
b_block_desc,
b_block_buf,
GetBMultiIndex(kc_size, nc_size));
blockwise_gemm.Run(a_block_desc,
a_block_buf,
......@@ -338,8 +395,14 @@ struct GridwiseGemmAvx2_MxN
}
}
if constexpr(UseCLocalBuffer)
c_threadwise_copy.Run(c_block_desc, c_block_buf, c_grid_desc, c_grid_buf);
// if constexpr(UseCLocalBuffer)
c_threadwise_copy.SetDstSliceOrigin(c_grid_desc,
ck::make_multi_index(i_mc, i_nc));
c_threadwise_copy.RunWrite(c_block_desc,
c_block_buf,
c_grid_desc,
c_grid_buf,
GetCMultiIndex(mc_size, nc_size));
}
}
}
......@@ -415,7 +478,11 @@ struct GridwiseGemmAvx2_MxN
ck::index_t kc_size = ck::math::min(GemmK - i_kc, k_per_block);
auto a_block_desc = GetABlockDescriptor(mc_size, kc_size);
a_threadwise_copy.Run(a_grid_desc, a_grid_buf, a_block_desc, a_block_buf);
a_threadwise_copy.RunRead(a_grid_desc,
a_grid_buf,
a_block_desc,
a_block_buf,
GetAMultiIndex(mc_size, kc_size));
b_threadwise_copy.SetSrcSliceOrigin(b_grid_desc,
ck::make_multi_index(0, i_kc, 0));
......@@ -429,8 +496,11 @@ struct GridwiseGemmAvx2_MxN
nc_size, ThreadwiseGemm_Dispatch::MatrixBMinVectorSize);
auto b_block_desc = GetBBlockDescriptor(kc_size, nc_size);
b_threadwise_copy.Run(
b_grid_desc, b_grid_buf, b_block_desc, b_block_buf);
b_threadwise_copy.RunRead(b_grid_desc,
b_grid_buf,
b_block_desc,
b_block_buf,
GetBMultiIndex(kc_size, nc_size));
auto c_block_desc = UseCLocalBuffer
? GetCBlockDescriptor(mc_size, nc_size)
......@@ -440,8 +510,11 @@ struct GridwiseGemmAvx2_MxN
{
c_threadwise_copy.SetSrcSliceOrigin(
c_block_desc, ck::make_multi_index(i_mc, i_nc));
c_threadwise_copy.Run(
c_block_desc, c_block_buf, c_grid_desc, c_grid_buf);
c_threadwise_copy.RunRead(c_block_desc,
c_block_buf,
c_grid_desc,
c_grid_buf,
GetCMultiIndex(mc_size, nc_size));
}
blockwise_gemm.Run(a_block_desc,
......@@ -456,14 +529,36 @@ struct GridwiseGemmAvx2_MxN
i_kc != 0);
if((i_nc + n_per_block) < GemmN)
{
b_threadwise_copy.MoveSrcSliceWindow(b_grid_desc, b_move_k_step);
}
if constexpr(UseCLocalBuffer)
{
c_threadwise_copy.SetDstSliceOrigin(
c_grid_desc, ck::make_multi_index(i_mc, i_nc));
c_threadwise_copy.Run(
c_block_desc, c_block_buf, c_grid_desc, c_grid_buf);
c_threadwise_copy.RunWrite(c_block_desc,
c_block_buf,
c_grid_desc,
c_grid_buf,
GetCMultiIndex(mc_size, nc_size));
}
else
{
// only write on the last K block: without a local C buffer, RunWrite here just
// applies the elementwise op from global memory to global memory, so it must
// run once, after the final K accumulation
if((i_kc + k_per_block) >= GemmK)
{
c_threadwise_copy.SetDstSliceOrigin(
c_grid_desc, ck::make_multi_index(i_mc, i_nc));
c_threadwise_copy.RunWrite(c_block_desc,
c_block_buf,
c_grid_desc,
c_grid_buf,
GetCMultiIndex(mc_size, nc_size));
}
}
}
......
......@@ -19,7 +19,8 @@ using InLayout = ck::tensor_layout::gemm::RowMajor; /
using WeiLayout = ck::tensor_layout::gemm::ColumnMajor; // KYXC
static constexpr bool NonTemporalStore = false;
using PT = ck::tensor_operation::cpu::element_wise::PassThrough;
using PT = ck::tensor_operation::cpu::element_wise::PassThrough;
using Relu = ck::tensor_operation::cpu::element_wise::Relu;
using ThreadwiseGemmAvx2_MxN_4x24_Dispatch =
ck::cpu::ThreadwiseGemmAvx2_MxN_4x24_Dispatch<InType,
WeiType,
......@@ -110,6 +111,59 @@ using device_conv2d_fwd_avx2_nhwc_kyxc_nhwk_f32_mt_instances = std::tuple<
DEVICE_CONV2D_FWD_AVX2_NHWC_KYXC_NHWK_F32(PT, PT, PT, 1024, 416, 128, 6, 16, true, true, true)>;
// clang-format on
using device_conv2d_fwd_avx2_nhwc_kyxc_nhwk_f32_relu_instances = std::tuple<
// clang-format off
DEVICE_CONV2D_FWD_AVX2_NHWC_KYXC_NHWK_F32(PT, PT, Relu, 256, 128, 64, 6, 16, true, true, false),
DEVICE_CONV2D_FWD_AVX2_NHWC_KYXC_NHWK_F32(PT, PT, Relu, 256, 128, 128, 6, 16, true, true, false),
DEVICE_CONV2D_FWD_AVX2_NHWC_KYXC_NHWK_F32(PT, PT, Relu, 128, 256, 128, 6, 16, true, true, false),
DEVICE_CONV2D_FWD_AVX2_NHWC_KYXC_NHWK_F32(PT, PT, Relu, 512, 240, 128, 4, 24, true, true, false),
DEVICE_CONV2D_FWD_AVX2_NHWC_KYXC_NHWK_F32(PT, PT, Relu, 512, 256, 128, 6, 16, true, true, false),
DEVICE_CONV2D_FWD_AVX2_NHWC_KYXC_NHWK_F32(PT, PT, Relu, 768, 320, 128, 6, 16, true, true, false),
DEVICE_CONV2D_FWD_AVX2_NHWC_KYXC_NHWK_F32(PT, PT, Relu, 896, 352, 128, 6, 16, true, true, false),
DEVICE_CONV2D_FWD_AVX2_NHWC_KYXC_NHWK_F32(PT, PT, Relu, 1024, 416, 128, 6, 16, true, true, false)>;
// clang-format on
// use this in a single thread when gemm_n is not a multiple of 8
using device_conv2d_fwd_avx2_nhwc_kyxc_nhwk_f32_local_c_relu_instances = std::tuple<
// clang-format off
DEVICE_CONV2D_FWD_AVX2_NHWC_KYXC_NHWK_F32(PT, PT, Relu, 256, 128, 64, 6, 16, true, true, true),
DEVICE_CONV2D_FWD_AVX2_NHWC_KYXC_NHWK_F32(PT, PT, Relu, 256, 128, 128, 6, 16, true, true, true),
DEVICE_CONV2D_FWD_AVX2_NHWC_KYXC_NHWK_F32(PT, PT, Relu, 128, 256, 128, 6, 16, true, true, true),
DEVICE_CONV2D_FWD_AVX2_NHWC_KYXC_NHWK_F32(PT, PT, Relu, 512, 240, 128, 4, 24, true, true, true),
DEVICE_CONV2D_FWD_AVX2_NHWC_KYXC_NHWK_F32(PT, PT, Relu, 512, 256, 128, 6, 16, true, true, true),
DEVICE_CONV2D_FWD_AVX2_NHWC_KYXC_NHWK_F32(PT, PT, Relu, 768, 320, 128, 6, 16, true, true, true),
DEVICE_CONV2D_FWD_AVX2_NHWC_KYXC_NHWK_F32(PT, PT, Relu, 896, 352, 128, 6, 16, true, true, true),
DEVICE_CONV2D_FWD_AVX2_NHWC_KYXC_NHWK_F32(PT, PT, Relu, 1024, 416, 128, 6, 16, true, true, true)>;
// clang-format on
// use this in a multi-threaded environment (a local C buffer is needed to avoid cache-coherence
// traffic, although sometimes no local C buffer is faster)
using device_conv2d_fwd_avx2_nhwc_kyxc_nhwk_f32_mt_relu_instances = std::tuple<
// clang-format off
DEVICE_CONV2D_FWD_AVX2_NHWC_KYXC_NHWK_F32(PT, PT, Relu, 48, 24, 128, 4, 24, true, true, true),
DEVICE_CONV2D_FWD_AVX2_NHWC_KYXC_NHWK_F32(PT, PT, Relu, 72, 16, 128, 6, 16, true, true, true),
DEVICE_CONV2D_FWD_AVX2_NHWC_KYXC_NHWK_F32(PT, PT, Relu, 72, 32, 128, 6, 16, true, true, true),
DEVICE_CONV2D_FWD_AVX2_NHWC_KYXC_NHWK_F32(PT, PT, Relu, 96, 32, 128, 6, 16, true, true, true),
DEVICE_CONV2D_FWD_AVX2_NHWC_KYXC_NHWK_F32(PT, PT, Relu, 96, 64, 128, 6, 16, true, true, true),
DEVICE_CONV2D_FWD_AVX2_NHWC_KYXC_NHWK_F32(PT, PT, Relu, 120, 32, 128, 6, 16, true, true, true),
DEVICE_CONV2D_FWD_AVX2_NHWC_KYXC_NHWK_F32(PT, PT, Relu, 120, 64, 128, 6, 16, true, true, true),
// DEVICE_CONV2D_FWD_AVX2_NHWC_KYXC_NHWK_F32(PT, PT, PT, 256, 128, 64, 6, 16, true, true, true),
DEVICE_CONV2D_FWD_AVX2_NHWC_KYXC_NHWK_F32(PT, PT, Relu, 256, 128, 128, 6, 16, true, true, true),
DEVICE_CONV2D_FWD_AVX2_NHWC_KYXC_NHWK_F32(PT, PT, Relu, 128, 256, 128, 6, 16, true, true, true),
DEVICE_CONV2D_FWD_AVX2_NHWC_KYXC_NHWK_F32(PT, PT, Relu, 512, 240, 128, 4, 24, true, true, true),
DEVICE_CONV2D_FWD_AVX2_NHWC_KYXC_NHWK_F32(PT, PT, Relu, 512, 256, 128, 6, 16, true, true, true),
DEVICE_CONV2D_FWD_AVX2_NHWC_KYXC_NHWK_F32(PT, PT, Relu, 768, 320, 128, 6, 16, true, true, true),
DEVICE_CONV2D_FWD_AVX2_NHWC_KYXC_NHWK_F32(PT, PT, Relu, 896, 352, 128, 6, 16, true, true, true),
DEVICE_CONV2D_FWD_AVX2_NHWC_KYXC_NHWK_F32(PT, PT, Relu, 1024, 416, 128, 6, 16, true, true, true)>;
// clang-format on
void add_device_conv2d_fwd_avx2_nhwc_kyxc_nhwk(std::vector<DeviceConvFwdPtr<PT, PT, PT>>& instances)
{
ck::tensor_operation::device::add_device_operation_instances(
......@@ -130,6 +184,27 @@ void add_device_conv2d_fwd_avx2_nhwc_kyxc_nhwk_mt(
instances, device_conv2d_fwd_avx2_nhwc_kyxc_nhwk_f32_mt_instances{});
}
void add_device_conv2d_fwd_avx2_nhwc_kyxc_nhwk_relu(
std::vector<DeviceConvFwdPtr<PT, PT, Relu>>& instances)
{
ck::tensor_operation::device::add_device_operation_instances(
instances, device_conv2d_fwd_avx2_nhwc_kyxc_nhwk_f32_relu_instances{});
}
void add_device_conv2d_fwd_avx2_nhwc_kyxc_nhwk_local_c_relu(
std::vector<DeviceConvFwdPtr<PT, PT, Relu>>& instances)
{
ck::tensor_operation::device::add_device_operation_instances(
instances, device_conv2d_fwd_avx2_nhwc_kyxc_nhwk_f32_local_c_relu_instances{});
}
void add_device_conv2d_fwd_avx2_nhwc_kyxc_nhwk_mt_relu(
std::vector<DeviceConvFwdPtr<PT, PT, Relu>>& instances)
{
ck::tensor_operation::device::add_device_operation_instances(
instances, device_conv2d_fwd_avx2_nhwc_kyxc_nhwk_f32_mt_relu_instances{});
}
} // namespace device_conv2d_fwd_avx2_instance
} // namespace device
} // namespace cpu
......
......@@ -24,6 +24,15 @@ void add_device_conv2d_fwd_avx2_nhwc_kyxc_nhwk_local_c(
void add_device_conv2d_fwd_avx2_nhwc_kyxc_nhwk_mt(
std::vector<DeviceConvFwdPtr<PassThrough, PassThrough, PassThrough>>& instances);
void add_device_conv2d_fwd_avx2_nhwc_kyxc_nhwk_relu(
std::vector<DeviceConvFwdPtr<PassThrough, PassThrough, Relu>>& instances);
void add_device_conv2d_fwd_avx2_nhwc_kyxc_nhwk_local_c_relu(
std::vector<DeviceConvFwdPtr<PassThrough, PassThrough, Relu>>& instances);
void add_device_conv2d_fwd_avx2_nhwc_kyxc_nhwk_mt_relu(
std::vector<DeviceConvFwdPtr<PassThrough, PassThrough, Relu>>& instances);
} // namespace device_conv2d_fwd_avx2_instance
} // namespace device
} // namespace cpu
......
......@@ -12,6 +12,11 @@
#include <omp.h>
#define AVX2_DATA_ALIGNMENT 32
#define TEST_FUSION_PASSTHROUGH 0
#define TEST_FUSION_RELU 1
#define TEST_FUSION TEST_FUSION_RELU
using F32 = float;
using F16 = ck::half_t;
......@@ -22,6 +27,7 @@ namespace device {
namespace device_conv2d_fwd_avx2_instance {
using PassThrough = ck::tensor_operation::cpu::element_wise::PassThrough;
using Relu = ck::tensor_operation::cpu::element_wise::Relu;
void add_device_conv2d_fwd_avx2_nhwc_kyxc_nhwk(
std::vector<DeviceConvFwdPtr<PassThrough, PassThrough, PassThrough>>& instances);
......@@ -32,6 +38,15 @@ void add_device_conv2d_fwd_avx2_nhwc_kyxc_nhwk_local_c(
void add_device_conv2d_fwd_avx2_nhwc_kyxc_nhwk_mt(
std::vector<DeviceConvFwdPtr<PassThrough, PassThrough, PassThrough>>& instances);
void add_device_conv2d_fwd_avx2_nhwc_kyxc_nhwk_relu(
std::vector<DeviceConvFwdPtr<PassThrough, PassThrough, Relu>>& instances);
void add_device_conv2d_fwd_avx2_nhwc_kyxc_nhwk_local_c_relu(
std::vector<DeviceConvFwdPtr<PassThrough, PassThrough, Relu>>& instances);
void add_device_conv2d_fwd_avx2_nhwc_kyxc_nhwk_mt_relu(
std::vector<DeviceConvFwdPtr<PassThrough, PassThrough, Relu>>& instances);
} // namespace device_conv2d_fwd_avx2_instance
} // namespace device
} // namespace cpu
......@@ -40,7 +55,12 @@ void add_device_conv2d_fwd_avx2_nhwc_kyxc_nhwk_mt(
using InElementOp = ck::tensor_operation::cpu::element_wise::PassThrough;
using WeiElementOp = ck::tensor_operation::cpu::element_wise::PassThrough;
#if TEST_FUSION == TEST_FUSION_PASSTHROUGH
using OutElementOp = ck::tensor_operation::cpu::element_wise::PassThrough;
#endif
#if TEST_FUSION == TEST_FUSION_RELU
using OutElementOp = ck::tensor_operation::cpu::element_wise::Relu;
#endif
template <typename T>
static bool
......@@ -295,9 +315,16 @@ int main(int argc, char* argv[])
ref_invoker.Run(ref_argument);
}
using PassThrough = ck::tensor_operation::cpu::element_wise::PassThrough;
using PassThrough = ck::tensor_operation::cpu::element_wise::PassThrough;
using Relu = ck::tensor_operation::cpu::element_wise::Relu;
#if TEST_FUSION == TEST_FUSION_PASSTHROUGH
using DeviceConvFwdNoOpPtr = ck::tensor_operation::cpu::device::
DeviceConvFwdPtr<PassThrough, PassThrough, PassThrough>;
#endif
#if TEST_FUSION == TEST_FUSION_RELU
using DeviceConvFwdNoOpPtr =
ck::tensor_operation::cpu::device::DeviceConvFwdPtr<PassThrough, PassThrough, Relu>;
#endif
// add device Conv instances
std::vector<DeviceConvFwdNoOpPtr> conv_ptrs;
......@@ -306,6 +333,7 @@ int main(int argc, char* argv[])
ck::is_same_v<ck::remove_cv_t<WeiDataType>, float> &&
ck::is_same_v<ck::remove_cv_t<OutDataType>, float>)
{
#if TEST_FUSION == TEST_FUSION_PASSTHROUGH
if(omp_get_max_threads() > 1)
{
ck::tensor_operation::cpu::device::device_conv2d_fwd_avx2_instance::
......@@ -322,6 +350,25 @@ int main(int argc, char* argv[])
ck::tensor_operation::cpu::device::device_conv2d_fwd_avx2_instance::
add_device_conv2d_fwd_avx2_nhwc_kyxc_nhwk_local_c(conv_ptrs);
}
#endif
#if TEST_FUSION == TEST_FUSION_RELU
if(omp_get_max_threads() > 1)
{
ck::tensor_operation::cpu::device::device_conv2d_fwd_avx2_instance::
add_device_conv2d_fwd_avx2_nhwc_kyxc_nhwk_mt_relu(conv_ptrs);
ck::tensor_operation::cpu::device::device_conv2d_fwd_avx2_instance::
add_device_conv2d_fwd_avx2_nhwc_kyxc_nhwk_relu(conv_ptrs);
}
else
{
if(K % 8 == 0)
ck::tensor_operation::cpu::device::device_conv2d_fwd_avx2_instance::
add_device_conv2d_fwd_avx2_nhwc_kyxc_nhwk_relu(conv_ptrs);
else
ck::tensor_operation::cpu::device::device_conv2d_fwd_avx2_instance::
add_device_conv2d_fwd_avx2_nhwc_kyxc_nhwk_local_c_relu(conv_ptrs);
}
#endif
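The K % 8 check above mirrors the comments on the instance lists: in the single-threaded path, K (the output-channel count, which becomes GemmN) must be a multiple of 8 for the direct-store Relu instances; otherwise the local-C Relu instances are selected, matching the "gemm_n is not a multiple of 8" note on their definition.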
}
if(conv_ptrs.size() <= 0)
......