Commit fbc9dad7 authored by Shucai Xiao's avatar Shucai Xiao
Browse files

Merge branch 'fp32_fp16_convert' into ins_fp32_fp16

parents c6396935 901a95ef
...@@ -19,29 +19,19 @@ namespace op { ...@@ -19,29 +19,19 @@ namespace op {
struct fp_conversion struct fp_conversion
{ {
bool reduce_precision = true; shape::type_t targe_type = shape::half_type;
std::string name() const { return "fp_conversion"; }
shape compute_shape(std::vector<shape> inputs) const
{
check_shapes{inputs, *this}.has(1);
if(reduce_precision)
{
if(inputs.front().type() != shape::float_type)
{
MIGRAPHX_THROW("FP_CONVERSION: input arguments must be type float");
}
return {shape::half_type, inputs.front().lens()}; template <class Self, class F>
} static auto reflect(Self& self, F f)
else
{
if(inputs.front().type() != shape::half_type)
{ {
MIGRAPHX_THROW("FP_CONVERSION: input arguments must be type fp16"); return pack(f(self.targe_type, "target_type"));
} }
return {shape::float_type, inputs.front().lens()}; std::string name() const { return "fp_conversion"; }
} shape compute_shape(std::vector<shape> inputs) const
{
check_shapes{inputs, *this}.has(1);
return {targe_type, inputs.front().lens()};
} }
argument compute(const shape& output_shape, std::vector<argument> args) const argument compute(const shape& output_shape, std::vector<argument> args) const
......
...@@ -80,7 +80,7 @@ void quantize(program& prog) ...@@ -80,7 +80,7 @@ void quantize(program& prog)
auto ins = std::prev(prog.end()); auto ins = std::prev(prog.end());
if(ins->get_shape().type() == shape::half_type) if(ins->get_shape().type() == shape::half_type)
{ {
prog.add_instruction(op::fp_conversion{false}, ins); prog.add_instruction(op::fp_conversion{shape::float_type}, ins);
} }
} }
} }
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment