convert.cpp 1.62 KB
Newer Older
1
2
3
4
5
6
7
8
9
10
11
12
13
14
#ifndef MIGRAPHX_GUARD_OPERATORS_CONVERT_HPP
#define MIGRAPHX_GUARD_OPERATORS_CONVERT_HPP

#include <array>
#include <migraphx/op/unary.hpp>
#include <migraphx/operation.hpp>
#include <migraphx/check_shapes.hpp>
#include <migraphx/stringutils.hpp>
#include <migraphx/streamutils.hpp>
#include <migraphx/literal.hpp>
#include <migraphx/shape_for_each.hpp>
#include <migraphx/config.hpp>
#include <cmath>
#include <utility>
15
16
17

namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
18
19
20
namespace op {

/// Element-wise conversion operator.
///
/// Every input element `x` is mapped to `scale * x + shift`, evaluated and
/// returned as `float`; the output shape reported by `compute_shape` carries
/// `target_type`. When `target_type` is `shape::int8_type` the value is also
/// rounded to the nearest integer and saturated to [-128, 127].
struct convert : unary<convert>
{
    /// Element type of the output shape.
    shape::type_t target_type = shape::half_type;
    /// Factor each element is multiplied by.
    float scale = 1.0f;
    /// Offset added to each element after scaling.
    float shift = 0.0f;

    /// Passes each field together with its name to `f` and packs the results.
    template <class Self, class F>
    static auto reflect(Self& self, F f)
    {
        return pack(
            f(self.target_type, "target_type"), f(self.scale, "scale"), f(self.shift, "shift"));
    }

    /// Returns a shape with `target_type` as element type and the lengths and
    /// strides of the first input.
    shape compute_shape(std::vector<shape> inputs) const
    {
        // NOTE(review): presumably enforces exactly one input -- confirm against check_shapes.
        check_shapes{inputs, *this}.has(1);
        return {target_type, inputs.at(0).lens(), inputs.at(0).strides()};
    }

    /// Returns the per-element function.
    ///
    /// The lambda captures `this` (via `[&]`) and reads `scale`, `shift` and
    /// `target_type` when it is called, so it must not outlive this operator.
    auto apply() const
    {
        return [&](auto x) {
            float res = scale * x + shift;
            if(target_type == shape::int8_type)
            {
                // Round half away from zero. A plain `res + 0.5f` only rounds
                // correctly for non-negative values when the later narrowing
                // truncates toward zero (e.g. -1.7 -> -1.2 -> -1 instead of -2).
                // NOTE(review): the float -> int8 narrowing itself happens
                // outside this block; confirm it does not round a second time.
                res = std::round(res);

                // Saturate to the int8 range. NaN fails both comparisons and
                // is passed through unchanged, as before.
                if(res > 127.0f)
                {
                    res = 127.0f;
                }
                if(res < -128.0f)
                {
                    res = -128.0f;
                }
            }

            return res;
        };
    }

    /// Converts to `t`, keeping the default scale (1) and shift (0).
    convert(shape::type_t t) : target_type{t} {}
    /// Converts to `t`, applying `sle * x + sft` to every element first.
    convert(shape::type_t t, float sle, float sft) : target_type{t}, scale{sle}, shift{sft} {}
    /// Uses the in-class defaults: half_type target, scale 1, shift 0.
    convert() {}
};

} // namespace op
60
61
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
62
63

#endif