"profiler/vscode:/vscode.git/clone" did not exist on "a2969aa8b6c3da78fe31a54a589c393480eabe77"
Commit 557d7aa3 authored by rocking's avatar rocking
Browse files

1. Rename gemm quantization

2. shares the requantization lambda function with conv
parent 42c5e890
add_example_executable(example_gemm_xdl_relu_quantization_int8 gemm_xdl_relu_quantization_int8.cpp)
\ No newline at end of file
......@@ -18,30 +18,12 @@
#include "ck/library/reference_tensor_operation/cpu/reference_gemm.hpp"
#include "ck/library/utility/check_err.hpp"
struct RequantReluRequant
{
// FIXME: We just need one scale for Relu / Leaky Relu / PRelu
RequantReluRequant(float scaleGemm, float scaleRelu)
: scaleGemm_(scaleGemm), scaleRelu_(scaleRelu)
{
}
__host__ __device__ constexpr void operator()(float& y, const float& x) const
{
float gemm_requant = scaleGemm_ * x;
float relu = gemm_requant > 0 ? gemm_requant : 0;
float relu_requant = scaleRelu_ * relu;
y = relu_requant > 127 ? 127 : relu_requant < -128 ? -128 : relu_requant;
}
float scaleGemm_;
float scaleRelu_;
};
template <ck::index_t... Is>
using S = ck::Sequence<Is...>;
using PassThrough = ck::tensor_operation::element_wise::PassThrough;
using ActivationOp = ck::tensor_operation::element_wise::Relu;
using CElementOp = ck::tensor_operation::element_wise::Activation_Mul_Clamp<ActivationOp>;
using ADataType = int8_t;
using BDataType = int8_t;
......@@ -67,7 +49,7 @@ using DeviceGemmInstance = ck::tensor_operation::device::DeviceGemm_Xdl_CShuffle
CShuffleDataType, // typename CShuffleDataType,
PassThrough, // typename AElementwiseOperation,
PassThrough, // typename BElementwiseOperation,
RequantReluRequant, // typename CElementwiseOperation,
CElementOp, // typename CElementwiseOperation,
GemmDefault, // GemmSpecialization GemmSpec,
1, // index_t NumGemmKPrefetchStage,
256, // index_t BlockSize,
......@@ -100,13 +82,8 @@ using DeviceGemmInstance = ck::tensor_operation::device::DeviceGemm_Xdl_CShuffle
16>; // index_t CShuffleBlockTransferScalarPerVector_NPerBlock>
// clang-format on
using ReferenceGemmInstance = ck::tensor_operation::host::ReferenceGemm<ADataType,
BDataType,
CDataType,
float,
PassThrough,
PassThrough,
RequantReluRequant>;
using ReferenceGemmInstance = ck::tensor_operation::host::
ReferenceGemm<ADataType, BDataType, CDataType, float, PassThrough, PassThrough, CElementOp>;
int main(int argc, char* argv[])
{
......@@ -123,8 +100,7 @@ int main(int argc, char* argv[])
ck::index_t StrideB = 4096;
ck::index_t StrideC = 4096;
float scale_gemm = 0.03;
float scale_relu = 1;
float quant_multiplier = 0.03;
if(argc == 4)
{
......@@ -199,7 +175,7 @@ int main(int argc, char* argv[])
auto a_element_op = PassThrough{};
auto b_element_op = PassThrough{};
auto c_element_op = RequantReluRequant{scale_gemm, scale_relu};
auto c_element_op = CElementOp{quant_multiplier, ActivationOp{}};
// do GEMM
auto gemm = DeviceGemmInstance{};
......
add_example_executable(example_gemm_xdl_requant_relu_requant_int8 gemm_xdl_requant_relu_requant_int8.cpp)
\ No newline at end of file
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment