Commit 7cd48ef1 authored by Chao Liu's avatar Chao Liu
Browse files

refactor

parent 96c73d70
...@@ -19,22 +19,36 @@ ...@@ -19,22 +19,36 @@
#include "reference_gemm.hpp" #include "reference_gemm.hpp"
#include "gemm_specialization.hpp" #include "gemm_specialization.hpp"
template <ck::index_t... Is> struct RequantReluRequant
using S = ck::Sequence<Is...>; {
// FIXME: We just need one scale for Relu / Leaky Relu / PRelu
RequantReluRequant(float scaleGemm, float scaleRelu)
: scaleGemm_(scaleGemm), scaleRelu_(scaleRelu)
{
}
using F32 = float; __host__ __device__ constexpr void operator()(float& y, const float& x) const
{
float gemm_requant = scaleGemm_ * x;
float relu = gemm_requant > 0 ? gemm_requant : 0;
float relu_requant = scaleRelu_ * relu;
y = relu_requant > 127 ? 127 : relu_requant < -128 ? -128 : relu_requant;
}
float scaleGemm_;
float scaleRelu_;
};
using Row = ck::tensor_layout::gemm::RowMajor; template <ck::index_t... Is>
using Col = ck::tensor_layout::gemm::ColumnMajor; using S = ck::Sequence<Is...>;
using PassThrough = ck::tensor_operation::element_wise::PassThrough; using PassThrough = ck::tensor_operation::element_wise::PassThrough;
using RequantReluRequant = ck::tensor_operation::element_wise::RequantReluRequant;
using ADataType = int8_t; using ADataType = int8_t;
using BDataType = int8_t; using BDataType = int8_t;
using CDataType = int8_t; using CDataType = int8_t;
using AccDataType = int32_t; using AccDataType = int32_t;
using CShuffleDataType = int32_t; using CShuffleDataType = float;
using ALayout = ck::tensor_layout::gemm::RowMajor; using ALayout = ck::tensor_layout::gemm::RowMajor;
using BLayout = ck::tensor_layout::gemm::ColumnMajor; using BLayout = ck::tensor_layout::gemm::ColumnMajor;
......
...@@ -143,37 +143,6 @@ struct AddHardswishAdd ...@@ -143,37 +143,6 @@ struct AddHardswishAdd
} }
}; };
struct RequantReluRequant
{
// FIXME: We just need one scale for Relu / Leaky Relu / PRelu
RequantReluRequant(float scaleGemm, float scaleRelu)
: scaleGemm_(scaleGemm), scaleRelu_(scaleRelu)
{
}
__host__ __device__ constexpr void operator()(int8_t& y, const int& x) const
{
float gemm_requant = scaleGemm_ * static_cast<float>(x);
float relu = gemm_requant > 0 ? gemm_requant : 0;
float relu_requant = scaleRelu_ * relu;
y = static_cast<int8_t>(relu_requant > 127 ? 127
: relu_requant < -128 ? -128 : relu_requant);
}
// for reference_gemm
__host__ __device__ constexpr void operator()(float& y, const float& x) const
{
float gemm_requant = scaleGemm_ * x;
float relu = gemm_requant > 0 ? gemm_requant : 0;
float relu_requant = scaleRelu_ * relu;
y = static_cast<float>(relu_requant > 127 ? 127
: relu_requant < -128 ? -128 : relu_requant);
}
float scaleGemm_;
float scaleRelu_;
};
// Unary operators are usually called element-wisely before/after the reduction is executed on the // Unary operators are usually called element-wisely before/after the reduction is executed on the
// elements. They are needed for easy implementation of reduction types of AVG, NRM1, NRM2 // elements. They are needed for easy implementation of reduction types of AVG, NRM1, NRM2
......
...@@ -171,9 +171,12 @@ check_err(const std::vector<T>& out, ...@@ -171,9 +171,12 @@ check_err(const std::vector<T>& out,
for(std::size_t i = 0; i < ref.size(); ++i) for(std::size_t i = 0; i < ref.size(); ++i)
{ {
if(out[i] != ref[i]) const int64_t out_v = static_cast<int64_t>(out[i]);
const int64_t ref_v = static_cast<int64_t>(ref[i]);
if(out_v != ref_v)
{ {
std::cout << "out[" << i << "] != ref[" << i << "]: " << out[i] << " != " << ref[i] std::cout << "out[" << i << "] != ref[" << i << "]: " << out_v << " != " << ref_v
<< std::endl << std::endl
<< msg << std::endl; << msg << std::endl;
return false; return false;
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment