#ifndef MIGRAPHX_GUARD_OPERATORS_CONVERT_HPP #define MIGRAPHX_GUARD_OPERATORS_CONVERT_HPP #include #include #include #include #include #include #include #include #include #include #include namespace migraphx { inline namespace MIGRAPHX_INLINE_NS { namespace op { struct convert : unary { shape::type_t target_type = shape::half_type; float scale = 1.0f; float shift = 0.0f; template static auto reflect(Self& self, F f) { return pack( f(self.target_type, "target_type"), f(self.scale, "scale"), f(self.shift, "shift")); } shape compute_shape(std::vector inputs) const { check_shapes{inputs, *this}.has(1); return {target_type, inputs.at(0).lens(), inputs.at(0).strides()}; } auto apply() const { return [&](auto x) { float res = scale * x + shift; if(target_type == shape::int8_type) { res = res + 0.5f; res = res > 127.0 ? 127.0 : res; res = res < -128.0 ? -128.0 : res; } return res; }; } convert(shape::type_t t) : target_type{t} {} convert(shape::type_t t, float sle, float sft) : target_type{t}, scale{sle}, shift{sft} {} convert() {} }; } // namespace op } // namespace MIGRAPHX_INLINE_NS } // namespace migraphx #endif