Commit 00b03957 authored by Shucai Xiao's avatar Shucai Xiao
Browse files

clean up the quantize_ins implementation

parent d096e06a
...@@ -31,7 +31,7 @@ struct fp_conversion ...@@ -31,7 +31,7 @@ struct fp_conversion
shape compute_shape(std::vector<shape> inputs) const shape compute_shape(std::vector<shape> inputs) const
{ {
check_shapes{inputs, *this}.has(1); check_shapes{inputs, *this}.has(1);
return {targe_type, inputs.front().lens()}; return {targe_type, inputs.front().lens(), inputs.front().strides()};
} }
argument compute(const shape& output_shape, std::vector<argument> args) const argument compute(const shape& output_shape, std::vector<argument> args) const
......
...@@ -90,10 +90,8 @@ void quantize_ins(program& prog, const std::vector<std::string>& ins_names) ...@@ -90,10 +90,8 @@ void quantize_ins(program& prog, const std::vector<std::string>& ins_names)
if(inputs != converted_inputs) if(inputs != converted_inputs)
{ {
auto op = ins->get_operator(); auto op = ins->get_operator();
instruction::replace(ins, op, compute_shape(op, converted_inputs), converted_inputs); auto ins_shape = compute_shape(op, converted_inputs);
} if (ins_shape.type() != orig_type)
if(ins->get_shape().type() != orig_type)
{ {
// insert another fp_conversion instruction to convert it back // insert another fp_conversion instruction to convert it back
if(ins == std::prev(prog.end())) if(ins == std::prev(prog.end()))
...@@ -107,6 +105,9 @@ void quantize_ins(program& prog, const std::vector<std::string>& ins_names) ...@@ -107,6 +105,9 @@ void quantize_ins(program& prog, const std::vector<std::string>& ins_names)
prog.replace_instruction(ins, ins_orig_type); prog.replace_instruction(ins, ins_orig_type);
} }
} }
instruction::replace(ins, op, compute_shape(op, converted_inputs), converted_inputs);
}
} }
} }
......
...@@ -9,7 +9,7 @@ namespace gpu { ...@@ -9,7 +9,7 @@ namespace gpu {
shape hip_fp_conversion::compute_shape(std::vector<shape> inputs) const shape hip_fp_conversion::compute_shape(std::vector<shape> inputs) const
{ {
inputs.pop_back(); inputs.pop_back();
check_shapes{inputs}.not_broadcasted().not_transposed(); check_shapes{inputs}.packed();
return op.compute_shape(inputs); return op.compute_shape(inputs);
} }
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment