Commit 901a95ef authored by Shucai Xiao's avatar Shucai Xiao
Browse files

minor change of the fp_conversion operator to support more scenarios.

parent 45dddfa7
......@@ -19,29 +19,19 @@ namespace op {
struct fp_conversion
{
bool reduce_precision = true;
std::string name() const { return "fp_conversion"; }
shape compute_shape(std::vector<shape> inputs) const
{
check_shapes{inputs, *this}.has(1);
if(reduce_precision)
{
if(inputs.front().type() != shape::float_type)
{
MIGRAPHX_THROW("FP_CONVERSION: input arguments must be type float");
}
shape::type_t targe_type = shape::half_type;
return {shape::half_type, inputs.front().lens()};
}
else
{
if(inputs.front().type() != shape::half_type)
template <class Self, class F>
static auto reflect(Self& self, F f)
{
MIGRAPHX_THROW("FP_CONVERSION: input arguments must be type fp16");
return pack(f(self.targe_type, "target_type"));
}
return {shape::float_type, inputs.front().lens()};
}
std::string name() const { return "fp_conversion"; }
shape compute_shape(std::vector<shape> inputs) const
{
check_shapes{inputs, *this}.has(1);
return {targe_type, inputs.front().lens()};
}
argument compute(const shape& output_shape, std::vector<argument> args) const
......
......@@ -80,7 +80,7 @@ void quantize(program& prog)
auto ins = std::prev(prog.end());
if(ins->get_shape().type() == shape::half_type)
{
prog.add_instruction(op::fp_conversion{false}, ins);
prog.add_instruction(op::fp_conversion{shape::float_type}, ins);
}
}
}
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment