Commit 00b03957 authored by Shucai Xiao's avatar Shucai Xiao
Browse files

clean up the quantize_ins implementation

parent d096e06a
......@@ -31,7 +31,7 @@ struct fp_conversion
shape compute_shape(std::vector<shape> inputs) const
{
check_shapes{inputs, *this}.has(1);
return {targe_type, inputs.front().lens()};
return {targe_type, inputs.front().lens(), inputs.front().strides()};
}
argument compute(const shape& output_shape, std::vector<argument> args) const
......
......@@ -90,22 +90,23 @@ void quantize_ins(program& prog, const std::vector<std::string>& ins_names)
if(inputs != converted_inputs)
{
auto op = ins->get_operator();
instruction::replace(ins, op, compute_shape(op, converted_inputs), converted_inputs);
}
if(ins->get_shape().type() != orig_type)
{
// insert another fp_conversion instruction to convert it back
if(ins == std::prev(prog.end()))
{
prog.add_instruction(op::fp_conversion{orig_type}, ins);
}
else
auto ins_shape = compute_shape(op, converted_inputs);
if (ins_shape.type() != orig_type)
{
auto ins_orig_type =
prog.insert_instruction(std::next(ins), op::fp_conversion{orig_type}, ins);
prog.replace_instruction(ins, ins_orig_type);
// insert another fp_conversion instruction to convert it back
if(ins == std::prev(prog.end()))
{
prog.add_instruction(op::fp_conversion{orig_type}, ins);
}
else
{
auto ins_orig_type =
prog.insert_instruction(std::next(ins), op::fp_conversion{orig_type}, ins);
prog.replace_instruction(ins, ins_orig_type);
}
}
instruction::replace(ins, op, compute_shape(op, converted_inputs), converted_inputs);
}
}
}
......
......@@ -9,7 +9,7 @@ namespace gpu {
shape hip_fp_conversion::compute_shape(std::vector<shape> inputs) const
{
inputs.pop_back();
check_shapes{inputs}.not_broadcasted().not_transposed();
check_shapes{inputs}.packed();
return op.compute_shape(inputs);
}
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment