Commit 1d5d035c authored by Shucai Xiao's avatar Shucai Xiao
Browse files

clang format

parent ea9776f5
...@@ -161,7 +161,8 @@ void quantize_int8(program& prog, const std::vector<std::string>& ins_names) ...@@ -161,7 +161,8 @@ void quantize_int8(program& prog, const std::vector<std::string>& ins_names)
auto s = input->get_shape(); auto s = input->get_shape();
if((s.type() == shape::float_type || s.type() == shape::double_type || if((s.type() == shape::float_type || s.type() == shape::double_type ||
s.type() == shape::int32_type) && s.type() != quant_type) s.type() == shape::int32_type) &&
s.type() != quant_type)
{ {
// if the input is a convert operator, uses its input // if the input is a convert operator, uses its input
// as its current input // as its current input
...@@ -211,59 +212,67 @@ void quantize_int8(program& prog, const std::vector<std::string>& ins_names) ...@@ -211,59 +212,67 @@ void quantize_int8(program& prog, const std::vector<std::string>& ins_names)
// abs(quant_alpha) > 50 (some tmp value set here), we can convert // abs(quant_alpha) > 50 (some tmp value set here), we can convert
// it to an integer as the new_alpha in the quant_dot // it to an integer as the new_alpha in the quant_dot
float threshold = 50.0f; float threshold = 50.0f;
if (fabs(new_alpha) >= threshold && fabs(new_beta) >= threshold) if(fabs(new_alpha) >= threshold && fabs(new_beta) >= threshold)
{ {
int32_t quant_alpha = static_cast<int32_t>(new_alpha); int32_t quant_alpha = static_cast<int32_t>(new_alpha);
int32_t quant_beta = static_cast<int32_t>(new_beta); int32_t quant_beta = static_cast<int32_t>(new_beta);
shape quant_shape = compute_shape(op::quant_dot{1, 0}, converted_inputs); shape quant_shape = compute_shape(op::quant_dot{1, 0}, converted_inputs);
if (quant_shape.type() == orig_type) if(quant_shape.type() == orig_type)
{ {
prog.replace_instruction(ins, op::quant_dot{quant_alpha, quant_beta}, converted_inputs); prog.replace_instruction(
ins, op::quant_dot{quant_alpha, quant_beta}, converted_inputs);
} }
else else
{ {
auto quant_dot = prog.insert_instruction(ins, op::quant_dot{quant_alpha, quant_beta}, converted_inputs); auto quant_dot = prog.insert_instruction(
ins, op::quant_dot{quant_alpha, quant_beta}, converted_inputs);
prog.replace_instruction(ins, op::convert{orig_type}, quant_dot); prog.replace_instruction(ins, op::convert{orig_type}, quant_dot);
} }
} }
// only alpha can be quantized, quantization of beta will cause // only alpha can be quantized, quantization of beta will cause
// big error, so we have to manually do the multiplication and // big error, so we have to manually do the multiplication and
// addition // addition
else if (fabs(new_alpha) >= threshold) else if(fabs(new_alpha) >= threshold)
{ {
int32_t quant_alpha = static_cast<int32_t>(new_alpha); int32_t quant_alpha = static_cast<int32_t>(new_alpha);
int32_t quant_beta = 0; int32_t quant_beta = 0;
if (orig_type == shape::int32_type) if(orig_type == shape::int32_type)
{ {
if (inputs.size() == 2 or dot_op.beta == 0.0f) if(inputs.size() == 2 or dot_op.beta == 0.0f)
{ {
prog.replace_instruction(ins, op::quant_dot{quant_alpha, quant_beta}, converted_inputs); prog.replace_instruction(
ins, op::quant_dot{quant_alpha, quant_beta}, converted_inputs);
} }
// if there are 3 inputs, we need to consider the third argument // if there are 3 inputs, we need to consider the third argument
else else
{ {
auto q_dot = prog.insert_instruction(ins, op::quant_dot{quant_alpha, quant_beta}, converted_inputs); auto q_dot = prog.insert_instruction(
ins, op::quant_dot{quant_alpha, quant_beta}, converted_inputs);
std::vector<float> vec_beta(q_dot->get_shape().elements(), dot_op.beta); std::vector<float> vec_beta(q_dot->get_shape().elements(), dot_op.beta);
auto l_beta = prog.add_literal(literal{orig_type, vec_beta}); auto l_beta = prog.add_literal(literal{orig_type, vec_beta});
auto beta_c = prog.insert_instruction(ins, op::mul{}, l_beta, inputs.back()); auto beta_c =
prog.insert_instruction(ins, op::mul{}, l_beta, inputs.back());
prog.replace_instruction(ins, op::add{}, q_dot, beta_c); prog.replace_instruction(ins, op::add{}, q_dot, beta_c);
} }
} }
else else
{ {
if (inputs.size() == 2 or dot_op.beta == 0.0f) if(inputs.size() == 2 or dot_op.beta == 0.0f)
{ {
auto q_dot = prog.insert_instruction(ins, op::quant_dot{quant_alpha, quant_beta}, converted_inputs); auto q_dot = prog.insert_instruction(
ins, op::quant_dot{quant_alpha, quant_beta}, converted_inputs);
prog.replace_instruction(ins, op::convert{orig_type}, q_dot); prog.replace_instruction(ins, op::convert{orig_type}, q_dot);
} }
// if there are 3 inputs, we need to consider the third argument // if there are 3 inputs, we need to consider the third argument
else else
{ {
auto q_dot = prog.insert_instruction(ins, op::quant_dot{quant_alpha, quant_beta}, converted_inputs); auto q_dot = prog.insert_instruction(
ins, op::quant_dot{quant_alpha, quant_beta}, converted_inputs);
auto oq_dot = prog.insert_instruction(ins, op::convert{orig_type}, q_dot); auto oq_dot = prog.insert_instruction(ins, op::convert{orig_type}, q_dot);
std::vector<float> vec_beta(q_dot->get_shape().elements(), dot_op.beta); std::vector<float> vec_beta(q_dot->get_shape().elements(), dot_op.beta);
auto l_beta = prog.add_literal(literal{oq_dot->get_shape(), vec_beta}); auto l_beta = prog.add_literal(literal{oq_dot->get_shape(), vec_beta});
auto beta_c = prog.insert_instruction(ins, op::mul{}, l_beta, inputs.back()); auto beta_c =
prog.insert_instruction(ins, op::mul{}, l_beta, inputs.back());
prog.replace_instruction(ins, op::add{}, q_dot, beta_c); prog.replace_instruction(ins, op::add{}, q_dot, beta_c);
} }
} }
...@@ -272,10 +281,10 @@ void quantize_int8(program& prog, const std::vector<std::string>& ins_names) ...@@ -272,10 +281,10 @@ void quantize_int8(program& prog, const std::vector<std::string>& ins_names)
{ {
auto q_dot = prog.insert_instruction(ins, op::quant_dot{1, 0}, converted_inputs); auto q_dot = prog.insert_instruction(ins, op::quant_dot{1, 0}, converted_inputs);
std::vector<float> vec_alpha(q_dot->get_shape().elements(), new_alpha); std::vector<float> vec_alpha(q_dot->get_shape().elements(), new_alpha);
if (orig_type == shape::int32_type) if(orig_type == shape::int32_type)
{ {
auto l_alpha = prog.add_literal(literal(ins->get_shape(), vec_alpha)); auto l_alpha = prog.add_literal(literal(ins->get_shape(), vec_alpha));
if (converted_inputs.size() == 2 or dot_op.beta == 0.0f) if(converted_inputs.size() == 2 or dot_op.beta == 0.0f)
{ {
prog.replace_instruction(ins, op::mul{}, l_alpha, q_dot); prog.replace_instruction(ins, op::mul{}, l_alpha, q_dot);
} }
...@@ -285,7 +294,8 @@ void quantize_int8(program& prog, const std::vector<std::string>& ins_names) ...@@ -285,7 +294,8 @@ void quantize_int8(program& prog, const std::vector<std::string>& ins_names)
std::vector<float> vec_beta(ins->get_shape().elements(), new_beta); std::vector<float> vec_beta(ins->get_shape().elements(), new_beta);
auto l_beta = prog.add_literal(literal(ins->get_shape(), vec_beta)); auto l_beta = prog.add_literal(literal(ins->get_shape(), vec_beta));
auto alpha_ab = prog.insert_instruction(ins, op::mul{}, l_alpha, q_dot); auto alpha_ab = prog.insert_instruction(ins, op::mul{}, l_alpha, q_dot);
auto beta_c = prog.insert_instruction(ins, op::mul{}, l_beta, inputs.back()); auto beta_c =
prog.insert_instruction(ins, op::mul{}, l_beta, inputs.back());
prog.replace_instruction(ins, op::add{}, alpha_ab, beta_c); prog.replace_instruction(ins, op::add{}, alpha_ab, beta_c);
} }
} }
...@@ -293,7 +303,7 @@ void quantize_int8(program& prog, const std::vector<std::string>& ins_names) ...@@ -293,7 +303,7 @@ void quantize_int8(program& prog, const std::vector<std::string>& ins_names)
{ {
auto oq_dot = prog.insert_instruction(ins, op::convert{orig_type}, q_dot); auto oq_dot = prog.insert_instruction(ins, op::convert{orig_type}, q_dot);
auto l_alpha = prog.add_literal(literal(ins->get_shape(), vec_alpha)); auto l_alpha = prog.add_literal(literal(ins->get_shape(), vec_alpha));
if (converted_inputs.size() == 2 or dot_op.beta == 0.0f) if(converted_inputs.size() == 2 or dot_op.beta == 0.0f)
{ {
prog.replace_instruction(ins, op::mul{}, l_alpha, oq_dot); prog.replace_instruction(ins, op::mul{}, l_alpha, oq_dot);
} }
...@@ -303,7 +313,8 @@ void quantize_int8(program& prog, const std::vector<std::string>& ins_names) ...@@ -303,7 +313,8 @@ void quantize_int8(program& prog, const std::vector<std::string>& ins_names)
std::vector<float> vec_beta(ins->get_shape().elements(), new_beta); std::vector<float> vec_beta(ins->get_shape().elements(), new_beta);
auto l_beta = prog.add_literal(literal(ins->get_shape(), vec_beta)); auto l_beta = prog.add_literal(literal(ins->get_shape(), vec_beta));
auto alpha_ab = prog.insert_instruction(ins, op::mul{}, l_alpha, oq_dot); auto alpha_ab = prog.insert_instruction(ins, op::mul{}, l_alpha, oq_dot);
auto beta_c = prog.insert_instruction(ins, op::mul{}, l_beta, inputs.back()); auto beta_c =
prog.insert_instruction(ins, op::mul{}, l_beta, inputs.back());
prog.replace_instruction(ins, op::add{}, alpha_ab, beta_c); prog.replace_instruction(ins, op::add{}, alpha_ab, beta_c);
} }
} }
...@@ -324,22 +335,31 @@ void quantize_int8(program& prog, const std::vector<std::string>& ins_names) ...@@ -324,22 +335,31 @@ void quantize_int8(program& prog, const std::vector<std::string>& ins_names)
shape quant_shape = compute_shape(op::quant_convolution{}, converted_inputs); shape quant_shape = compute_shape(op::quant_convolution{}, converted_inputs);
std::vector<float> vec_factor(quant_shape.elements(), adjust_factor); std::vector<float> vec_factor(quant_shape.elements(), adjust_factor);
auto fl = prog.add_literal(literal{{orig_type, quant_shape.lens()}, vec_factor}); auto fl = prog.add_literal(literal{{orig_type, quant_shape.lens()}, vec_factor});
if (quant_shape.type() == orig_type) if(quant_shape.type() == orig_type)
{ {
if (adjust_factor == 1.0f) if(adjust_factor == 1.0f)
{ {
prog.replace_instruction(ins, op::quant_convolution{padding, stride, dilation, padding_mode, group}, converted_inputs); prog.replace_instruction(
ins,
op::quant_convolution{padding, stride, dilation, padding_mode, group},
converted_inputs);
} }
else else
{ {
auto quant_conv = prog.replace_instruction(ins, op::quant_convolution{padding, stride, dilation, padding_mode, group}, converted_inputs); auto quant_conv = prog.replace_instruction(
ins,
op::quant_convolution{padding, stride, dilation, padding_mode, group},
converted_inputs);
prog.replace_instruction(ins, op::mul{}, quant_conv, fl); prog.replace_instruction(ins, op::mul{}, quant_conv, fl);
} }
} }
else else
{ {
auto quant_conv = prog.insert_instruction(ins, op::quant_convolution{padding, stride, dilation, padding_mode, group}, converted_inputs); auto quant_conv = prog.insert_instruction(
if (adjust_factor == 1.0f) ins,
op::quant_convolution{padding, stride, dilation, padding_mode, group},
converted_inputs);
if(adjust_factor == 1.0f)
{ {
prog.replace_instruction(ins, op::convert{orig_type}, quant_conv); prog.replace_instruction(ins, op::convert{orig_type}, quant_conv);
} }
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment