"test/vscode:/vscode.git/clone" did not exist on "963224f50b28ac2996610e38127f4b569c8c36da"
quant_convolution.cpp 1.7 KB
Newer Older
1
2
3
4
5
6
7
8
9
10
11
12
13
14
#ifndef MIGRAPHX_GUARD_OPERATORS_CONVERT_HPP
#define MIGRAPHX_GUARD_OPERATORS_CONVERT_HPP

#include <array>
#include <migraphx/op/unary.hpp>
#include <migraphx/operation.hpp>
#include <migraphx/check_shapes.hpp>
#include <migraphx/stringutils.hpp>
#include <migraphx/streamutils.hpp>
#include <migraphx/literal.hpp>
#include <migraphx/shape_for_each.hpp>
#include <migraphx/config.hpp>
#include <cmath>
#include <utility>
15
16
17

namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
18
namespace op {
19

20
struct convert : unary<convert>
21
{
22
23
24
    shape::type_t target_type = shape::half_type;
    float scale               = 1.0f;
    float shift               = 0.0f;
25

26
27
    template <class Self, class F>
    static auto reflect(Self& self, F f)
28
    {
29
30
        return pack(
            f(self.target_type, "target_type"), f(self.scale, "scale"), f(self.shift, "shift"));
31
32
    }

33
    shape compute_shape(std::vector<shape> inputs) const
34
    {
35
36
        check_shapes{inputs, *this}.has(1);
        return {target_type, inputs.at(0).lens(), inputs.at(0).strides()};
37
38
    }

39
    auto apply() const
40
    {
41
42
43
44
45
46
47
48
49
50
51
52
        return [&](auto x) {
            float res = scale * x + shift;
            if(target_type == shape::int8_type)
            {
                int factor = (res >= 0.0f) ? 1 : -1;
                res        = res + factor * 0.5f;
                res        = res > 127.0f ? 127.0f : res;
                res        = res < -128.0f ? -128.0f : res;
            }

            return res;
        };
53
    }
54

55
56
57
58
    convert(shape::type_t t) : target_type{t} {}
    convert(shape::type_t t, float sle, float sft) : target_type{t}, scale{sle}, shift{sft} {}
    convert() {}
};
59

60
} // namespace op
61
62
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
63
64

#endif