"...targets/git@developer.sourcefind.cn:gaoqiong/migraphx.git" did not exist on "a4c2b8899ac2f3ec72abfba9ea07dac3e2eea5ac"
Commit 00b03957 authored by Shucai Xiao's avatar Shucai Xiao
Browse files

clean up the quantize_ins implementation

parent d096e06a
...@@ -31,7 +31,7 @@ struct fp_conversion ...@@ -31,7 +31,7 @@ struct fp_conversion
shape compute_shape(std::vector<shape> inputs) const shape compute_shape(std::vector<shape> inputs) const
{ {
check_shapes{inputs, *this}.has(1); check_shapes{inputs, *this}.has(1);
return {targe_type, inputs.front().lens()}; return {targe_type, inputs.front().lens(), inputs.front().strides()};
} }
argument compute(const shape& output_shape, std::vector<argument> args) const argument compute(const shape& output_shape, std::vector<argument> args) const
......
...@@ -90,22 +90,23 @@ void quantize_ins(program& prog, const std::vector<std::string>& ins_names) ...@@ -90,22 +90,23 @@ void quantize_ins(program& prog, const std::vector<std::string>& ins_names)
if(inputs != converted_inputs) if(inputs != converted_inputs)
{ {
auto op = ins->get_operator(); auto op = ins->get_operator();
instruction::replace(ins, op, compute_shape(op, converted_inputs), converted_inputs); auto ins_shape = compute_shape(op, converted_inputs);
} if (ins_shape.type() != orig_type)
if(ins->get_shape().type() != orig_type)
{
// insert another fp_conversion instruction to convert it back
if(ins == std::prev(prog.end()))
{
prog.add_instruction(op::fp_conversion{orig_type}, ins);
}
else
{ {
auto ins_orig_type = // insert another fp_conversion instruction to convert it back
prog.insert_instruction(std::next(ins), op::fp_conversion{orig_type}, ins); if(ins == std::prev(prog.end()))
prog.replace_instruction(ins, ins_orig_type); {
prog.add_instruction(op::fp_conversion{orig_type}, ins);
}
else
{
auto ins_orig_type =
prog.insert_instruction(std::next(ins), op::fp_conversion{orig_type}, ins);
prog.replace_instruction(ins, ins_orig_type);
}
} }
instruction::replace(ins, op, compute_shape(op, converted_inputs), converted_inputs);
} }
} }
} }
......
...@@ -9,7 +9,7 @@ namespace gpu { ...@@ -9,7 +9,7 @@ namespace gpu {
shape hip_fp_conversion::compute_shape(std::vector<shape> inputs) const shape hip_fp_conversion::compute_shape(std::vector<shape> inputs) const
{ {
inputs.pop_back(); inputs.pop_back();
check_shapes{inputs}.not_broadcasted().not_transposed(); check_shapes{inputs}.packed();
return op.compute_shape(inputs); return op.compute_shape(inputs);
} }
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment