Commit 557d7aa3 authored by rocking

1. Rename the gemm quantization example

2. Share the requantization lambda function with conv
parent 42c5e890
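For context, here is a minimal standalone sketch of what the shared requantization epilogue in this commit computes. It assumes `Activation_Mul_Clamp<Relu>` scales the accumulator by a single multiplier, applies Relu, and clamps the result to the int8 range, mirroring the removed `RequantReluRequant` functor but with a single scale; `ReluSketch` and `ActivationMulClampSketch` below are hypothetical stand-ins, not the CK implementation.

```cpp
// Illustration only: a standalone approximation of the requantization epilogue.
// The real operator lives in ck::tensor_operation::element_wise (Relu,
// Activation_Mul_Clamp); the types here are hypothetical stand-ins.
#include <algorithm>
#include <cstdio>

struct ReluSketch
{
    float operator()(float x) const { return x > 0.f ? x : 0.f; }
};

template <typename Activation>
struct ActivationMulClampSketch
{
    float requant_scale_;    // single multiplier shared by the GEMM and conv examples
    Activation activation_;

    void operator()(float& y, const float& x) const
    {
        // Scale the accumulator (passed in as float), apply the activation, then
        // clamp into the representable int8 range. With Relu and a positive scale,
        // scale-then-Relu and Relu-then-scale give the same result.
        float requantized = activation_(requant_scale_ * x);
        y = std::min(127.f, std::max(-128.f, requantized));
    }
};

int main()
{
    ActivationMulClampSketch<ReluSketch> c_element_op{0.03f, ReluSketch{}};
    float y = 0.f;
    c_element_op(y, 5000.f);               // e.g. one accumulated GEMM value
    std::printf("requantized: %f\n", y);   // 0.03 * 5000 = 150 -> clamped to 127
    return 0;
}
```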
add_example_executable(example_gemm_xdl_relu_quantization_int8 gemm_xdl_relu_quantization_int8.cpp)
\ No newline at end of file
@@ -18,30 +18,12 @@
 #include "ck/library/reference_tensor_operation/cpu/reference_gemm.hpp"
 #include "ck/library/utility/check_err.hpp"
 
-struct RequantReluRequant
-{
-    // FIXME: We just need one scale for Relu / Leaky Relu / PRelu
-    RequantReluRequant(float scaleGemm, float scaleRelu)
-        : scaleGemm_(scaleGemm), scaleRelu_(scaleRelu)
-    {
-    }
-
-    __host__ __device__ constexpr void operator()(float& y, const float& x) const
-    {
-        float gemm_requant = scaleGemm_ * x;
-        float relu         = gemm_requant > 0 ? gemm_requant : 0;
-        float relu_requant = scaleRelu_ * relu;
-        y = relu_requant > 127 ? 127 : relu_requant < -128 ? -128 : relu_requant;
-    }
-
-    float scaleGemm_;
-    float scaleRelu_;
-};
-
 template <ck::index_t... Is>
 using S = ck::Sequence<Is...>;
 
 using PassThrough = ck::tensor_operation::element_wise::PassThrough;
+using ActivationOp = ck::tensor_operation::element_wise::Relu;
+using CElementOp   = ck::tensor_operation::element_wise::Activation_Mul_Clamp<ActivationOp>;
 
 using ADataType = int8_t;
 using BDataType = int8_t;
@@ -67,7 +49,7 @@ using DeviceGemmInstance = ck::tensor_operation::device::DeviceGemm_Xdl_CShuffle
     CShuffleDataType,   // typename CShuffleDataType,
     PassThrough,        // typename AElementwiseOperation,
     PassThrough,        // typename BElementwiseOperation,
-    RequantReluRequant, // typename CElementwiseOperation,
+    CElementOp,         // typename CElementwiseOperation,
     GemmDefault,        // GemmSpecialization GemmSpec,
     1,                  // index_t NumGemmKPrefetchStage,
     256,                // index_t BlockSize,
@@ -100,13 +82,8 @@ using DeviceGemmInstance = ck::tensor_operation::device::DeviceGemm_Xdl_CShuffle
     16>; // index_t CShuffleBlockTransferScalarPerVector_NPerBlock>
 // clang-format on
 
-using ReferenceGemmInstance = ck::tensor_operation::host::ReferenceGemm<ADataType,
-                                                                        BDataType,
-                                                                        CDataType,
-                                                                        float,
-                                                                        PassThrough,
-                                                                        PassThrough,
-                                                                        RequantReluRequant>;
+using ReferenceGemmInstance = ck::tensor_operation::host::
+    ReferenceGemm<ADataType, BDataType, CDataType, float, PassThrough, PassThrough, CElementOp>;
 
 int main(int argc, char* argv[])
 {
@@ -123,8 +100,7 @@ int main(int argc, char* argv[])
     ck::index_t StrideB = 4096;
     ck::index_t StrideC = 4096;
 
-    float scale_gemm = 0.03;
-    float scale_relu = 1;
+    float quant_multiplier = 0.03;
 
     if(argc == 4)
     {
@@ -199,7 +175,7 @@ int main(int argc, char* argv[])
     auto a_element_op = PassThrough{};
     auto b_element_op = PassThrough{};
-    auto c_element_op = RequantReluRequant{scale_gemm, scale_relu};
+    auto c_element_op = CElementOp{quant_multiplier, ActivationOp{}};
 
     // do GEMM
     auto gemm = DeviceGemmInstance{};
...
add_example_executable(example_gemm_xdl_requant_relu_requant_int8 gemm_xdl_requant_relu_requant_int8.cpp)
\ No newline at end of file