Commit 965ac6fc authored by Shucai Xiao's avatar Shucai Xiao
Browse files

clang format

parent 47f6954f
......@@ -44,43 +44,46 @@ instruction_ref insert_quant_ins(program& prog,
ins->get_shape().type() == shape::int32_type);
instruction_ref quant_ins{};
auto insert_loc = std::next(ins);
if (type == shape::int8_type)
if(type == shape::int8_type)
{
auto scaled_ins = ins;
if (scale != 1.0f)
if(scale != 1.0f)
{
auto float_ins = scaled_ins;
if (scaled_ins->get_shape().type() != shape::float_type)
if(scaled_ins->get_shape().type() != shape::float_type)
{
float_ins = prog.insert_instruction(insert_loc, op::convert{shape::float_type}, scaled_ins);
float_ins =
prog.insert_instruction(insert_loc, op::convert{shape::float_type}, scaled_ins);
}
std::vector<float> vec_scale(scaled_ins->get_shape().elements(), scale);
auto l_scale = prog.add_literal(literal(scaled_ins->get_shape(), vec_scale));
scaled_ins = prog.insert_instruction(insert_loc, op::mul{}, l_scale, float_ins);
scaled_ins = prog.insert_instruction(insert_loc, op::mul{}, l_scale, float_ins);
}
auto shifted_ins = scaled_ins;
if (shift != 0.0f)
if(shift != 0.0f)
{
auto float_ins = shifted_ins;
if (shifted_ins->get_shape().type() != shape::float_type)
if(shifted_ins->get_shape().type() != shape::float_type)
{
float_ins = prog.insert_instruction(insert_loc, op::convert{shape::float_type}, shifted_ins);
float_ins = prog.insert_instruction(
insert_loc, op::convert{shape::float_type}, shifted_ins);
}
std::vector<float> vec_shift(shifted_ins->get_shape().elements(), shift);
auto l_shift = prog.add_literal(literal(shifted_ins->get_shape(), vec_shift));
shifted_ins = prog.insert_instruction(insert_loc, op::add{}, l_shift, float_ins);
shifted_ins = prog.insert_instruction(insert_loc, op::add{}, l_shift, float_ins);
}
auto clipped_ins = prog.insert_instruction(insert_loc, op::clip{127.0f, -128.0f}, shifted_ins);
auto clipped_ins =
prog.insert_instruction(insert_loc, op::clip{127.0f, -128.0f}, shifted_ins);
auto rounded_ins = prog.insert_instruction(insert_loc, op::round{}, clipped_ins);
quant_ins = prog.insert_instruction(insert_loc, op::convert{type}, rounded_ins);
quant_ins = prog.insert_instruction(insert_loc, op::convert{type}, rounded_ins);
}
else
{
quant_ins = prog.insert_instruction(insert_loc, op::convert{type}, ins);
quant_ins = prog.insert_instruction(insert_loc, op::convert{type}, ins);
}
map_ins[ins] = quant_ins;
return quant_ins;
......@@ -239,12 +242,14 @@ void quantize_int8(program& prog,
}
else
{
quant_input = insert_quant_ins(prog, input, quant_type, map_quant_ins, param.first, param.second);
quant_input = insert_quant_ins(
prog, input, quant_type, map_quant_ins, param.first, param.second);
}
}
else
{
quant_input = insert_quant_ins(prog, input, quant_type, map_quant_ins, param.first, param.second);
quant_input = insert_quant_ins(
prog, input, quant_type, map_quant_ins, param.first, param.second);
}
converted_inputs.push_back(quant_input);
}
......@@ -341,25 +346,28 @@ void quantize_int8(program& prog,
converted_inputs);
float threshold = 50.0f;
std::vector<float> vec_factor(quant_conv->get_shape().elements(), adjust_factor);
if (quant_conv->get_shape().type() == orig_type and adjust_factor >= threshold)
if(quant_conv->get_shape().type() == orig_type and adjust_factor >= threshold)
{
auto l_factor = prog.add_literal(literal(quant_conv->get_shape(), vec_factor.begin(), vec_factor.end()));
auto l_factor = prog.add_literal(
literal(quant_conv->get_shape(), vec_factor.begin(), vec_factor.end()));
prog.replace_instruction(ins, op::mul{}, quant_conv, l_factor);
}
// convert quant_conv output to float type, multiply the factor and
// conver back to original type
else
{
auto float_conv = prog.insert_instruction(ins, op::convert{shape::float_type}, quant_conv);
auto float_conv =
prog.insert_instruction(ins, op::convert{shape::float_type}, quant_conv);
auto l_factor = prog.add_literal(literal(float_conv->get_shape(), vec_factor));
if (orig_type == shape::float_type)
if(orig_type == shape::float_type)
{
prog.replace_instruction(ins, op::mul{}, l_factor, float_conv);
}
else
{
auto adjusted_conv = prog.insert_instruction(ins, op::mul{}, l_factor, float_conv);
prog.replace_instruction(ins, op::convert{orig_type}, adjusted_conv);
auto adjusted_conv =
prog.insert_instruction(ins, op::mul{}, l_factor, float_conv);
prog.replace_instruction(ins, op::convert{orig_type}, adjusted_conv);
}
}
}
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment