Commit 852a517a authored by Shucai Xiao's avatar Shucai Xiao
Browse files

clang format

parent 34493a8d
......@@ -14,7 +14,8 @@ instruction_ref insert_quant_ins(program& prog,
instruction_ref& ins,
shape::type_t type,
std::unordered_map<instruction_ref, instruction_ref>& map_ins,
float scale = 1.0f, float shift = 0.0f)
float scale = 1.0f,
float shift = 0.0f)
{
if(map_ins.count(ins) > 0)
{
......@@ -110,7 +111,7 @@ void quantize_int8(program& prog, const std::vector<std::string>& ins_names)
{
// For now, we only support the int8 quantization of gemm and convolution
std::vector<std::string> op_names = {"dot", "convolution"};
if (!std::all_of(ins_names.begin(), ins_names.end(), [&](auto name) {
if(!std::all_of(ins_names.begin(), ins_names.end(), [&](auto name) {
return std::find(op_names.begin(), op_names.end(), name);
}))
{
......@@ -145,14 +146,15 @@ void quantize_int8(program& prog, const std::vector<std::string>& ins_names)
// operation, if it has 3 inputs, then the last one should
// be converted to int32_type
shape::type_t quant_type = shape::int8_type;
if (ins->name() == "dot" and inputs.size() == 3 and input == inputs.back())
if(ins->name() == "dot" and inputs.size() == 3 and input == inputs.back())
{
quant_type = shape::int32_type;
}
auto param = int8_param[param_index++];
auto s = input->get_shape();
if(s.type() == shape::float_type || s.type() == shape::double_type || s.type() == shape::int32_type)
if(s.type() == shape::float_type || s.type() == shape::double_type ||
s.type() == shape::int32_type)
{
// if the input is a convert operator, uses its input
// as its current input
......@@ -160,19 +162,20 @@ void quantize_int8(program& prog, const std::vector<std::string>& ins_names)
if(input->name() == "convert")
{
auto tmp_ins = input->inputs().front();
if (tmp_ins->get_shape().type() == quant_type)
if(tmp_ins->get_shape().type() == quant_type)
{
quant_input = input->inputs().front();
}
else
{
quant_input = insert_quant_ins(prog, input, quant_type, map_quant_ins, param.first, param.second);
quant_input = insert_quant_ins(
prog, input, quant_type, map_quant_ins, param.first, param.second);
}
}
else
{
quant_input = insert_quant_ins(prog, input, quant_type, map_quant_ins, param.first, param.second);
quant_input = insert_quant_ins(
prog, input, quant_type, map_quant_ins, param.first, param.second);
}
converted_inputs.push_back(quant_input);
}
......@@ -210,9 +213,7 @@ void quantize_int8(program& prog, const std::vector<std::string>& ins_names)
prog.replace_instruction(ins, op, converted_inputs);
}
}
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
......@@ -13,9 +13,7 @@ shape hip_convert::compute_shape(std::vector<shape> inputs) const
return op.compute_shape(inputs);
}
argument hip_convert::compute(context& ctx,
const shape&,
const std::vector<argument>& args) const
argument hip_convert::compute(context& ctx, const shape&, const std::vector<argument>& args) const
{
device::convert(ctx.get_stream().get(), args[1], args[0], op.scale, op.shift);
return args[1];
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment