// pointwise.cpp
#include <migraphx/gpu/compiler.hpp>
#include <migraphx/gpu/context.hpp>
#include <migraphx/gpu/compile_hip_code_object.hpp>
#include <migraphx/gpu/compile_hip.hpp>

#include <migraphx/cpp_generator.hpp>
#include <migraphx/ranges.hpp>
#include <migraphx/reduce_dims.hpp>
#include <migraphx/stringutils.hpp>
#include <migraphx/dead_code_elimination.hpp>
#include <migraphx/eliminate_common_subexpression.hpp>
#include <migraphx/module.hpp>
#include <migraphx/pass_manager.hpp>

namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
namespace gpu {

// HIP source template for the generated pointwise kernel.
// pointwise_compiler::compile_op fills the placeholders with interpolate_string:
//   ${preamble} - extra source emitted before the kernel (the generated
//                 __device__ function text from cpp_generator, when present)
//   ${params}   - kernel parameter list, one `void * private_p<i>` per input
//   ${lambda}   - callable passed as the first argument to `pointwise`
//   ${args}     - the kernel parameters forwarded to `pointwise`
// The literal is compiled verbatim; do not edit it casually.
static const char* const pointwise_kernel = R"__migraphx__(
#include <migraphx/kernels/index.hpp>
#include <migraphx/kernels/pointwise.hpp>
#include <args.hpp>

namespace migraphx {

${preamble}

extern "C" {
__global__ void kernel(${params}) 
{
    pointwise(${lambda}, ${args});
}
    
}

} // namespace migraphx

)__migraphx__";

/// Compiler for the "pointwise" GPU operator.
///
/// `compile` lowers the instruction's submodule to a `__device__` C++ function
/// with `cpp_generator`, wraps it in `MIGRAPHX_LIFT`, splices it into the
/// `pointwise_kernel` template and builds a HIP code object from the result.
struct pointwise_compiler : compiler<pointwise_compiler>
{
    /// Operator names handled by this compiler.
    std::vector<std::string> names() const { return {"pointwise"}; }

    /// Oversubscription factor passed to `compute_global_for` when sizing the
    /// launch.
    ///
    /// The factor is currently a constant 4 and the input shapes are not
    /// consulted. The parameter is left unnamed (rather than removed) so the
    /// signature and the call in `compile_op` stay unchanged without tripping
    /// -Wunused-parameter.
    static std::size_t oversubscribe(const std::vector<shape>& /* inputs */) { return 4; }

    /// Number of work items to request for the first input's elements.
    ///
    /// The element count comes from `inputs.front()`. Reduction applies only when
    /// every input is packed or broadcasted: the count is divided by 4 if it is a
    /// multiple of 4, otherwise by 2 if it is even. In all other cases the raw
    /// element count is returned.
    /// NOTE(review): presumably mirrors the vector width used by the device-side
    /// `pointwise` helper — confirm against kernels/pointwise.hpp.
    ///
    /// @pre `inputs` is non-empty.
    static std::size_t vectorize_elements(const std::vector<shape>& inputs)
    {
        assert(not inputs.empty());
        const std::size_t n     = inputs.front().elements();
        const bool vectorizable = std::all_of(inputs.begin(), inputs.end(), [](const auto& s) {
            return s.packed() or s.broadcasted();
        });
        if(vectorizable)
        {
            if((n % 4) == 0)
                return n / 4;
            if((n % 2) == 0)
                return n / 2;
        }
        return n;
    }

    /// Build the HIP code object for a pointwise kernel.
    ///
    /// @param ctx    GPU context, used by `compute_global_for` to size the launch.
    /// @param inputs Shapes of every kernel argument; the last one is the output.
    /// @param v      Must contain "lambda" (device callable, as source text) and
    ///               may contain "preamble" (source placed before the kernel).
    ///               `v` is also forwarded to `set_launch_params`.
    /// @return       The compiled code-object operation.
    operation compile_op(context& ctx, const std::vector<shape>& inputs, const value& v) const
    {
        hip_compile_options options;
        options.set_launch_params(
            v, compute_global_for(ctx, vectorize_elements(inputs), oversubscribe(inputs)));
        options.inputs         = inputs;
        options.output         = inputs.back();
        options.virtual_inputs = reduce_dims(inputs);
        // Generated code can compare values with `==` (see the "equal" point op in
        // `compile`), so silence the float-equality warning.
        options.params = "-Wno-float-equal";
        auto src       = interpolate_string(pointwise_kernel,
                                      {{"params", enum_params(inputs.size(), "void * private_p")},
                                       {"args", enum_params(inputs.size(), "private_p")},
                                       {"lambda", v.at("lambda").to<std::string>()},
                                       {"preamble", v.get("preamble", std::string{})}});
        return compile_hip_code_object(src, options);
    }

    /// Compile a "pointwise" instruction from its submodule.
    ///
    /// Runs common-subexpression and dead-code elimination on the instruction's
    /// first module input (this mutates that module in place). It then emits the
    /// module as a `__device__` function and compiles the kernel with the
    /// instruction's input shapes.
    ///
    /// @pre `ins` has at least one module input.
    compiler_replace compile(context& ctx, instruction_ref ins, const operation&) const
    {
        assert(not ins->module_inputs().empty());
        auto* pm = ins->module_inputs().front();
        run_passes(*pm, {eliminate_common_subexpression{}, dead_code_elimination{}});
        cpp_generator g;
        // Qualify generated function calls with the migraphx namespace.
        g.fmap([](const std::string& fname) { return "migraphx::" + fname; });
        // Ops emitted as inline expressions instead of plain function calls.
        g.add_point_op("where", "${function:where}(${0}, ${1}, ${2})");
        g.add_point_op("prelu", "${function:where}(${0} < 0, ${0} * ${1}, ${0})");
        g.add_point_op("sign",
                       "${function:where}(${0} > 0, 1, ${function:where}(${0} < 0, -1, 0))");
        // NOTE(review): comparison/logical results are wrapped in migraphx::abs;
        // the intent is not visible here (possibly a bool-to-numeric conversion).
        g.add_point_op("equal", "migraphx::abs(${0} == ${1})");
        g.add_point_op("less", "migraphx::abs(${0} < ${1})");
        g.add_point_op("greater", "migraphx::abs(${0} > ${1})");
        g.add_point_op("not", "migraphx::abs(not ${0})");
        // Add explicit conversions of each result to its shape's C++ type.
        g.fresult(
            [](const shape& s) { return "migraphx::convert<" + shape::cpp_type(s.type()) + ">"; });
        auto name = g.create_function(
            g.generate_module(*pm).set_attributes({"__device__"}).set_generic_types(*pm));
        // Wrap the generated function so it can be passed as a callable object.
        std::string lambda = "MIGRAPHX_LIFT(" + name + ")";
        return replace(
            compile_op(ctx, to_shapes(ins->inputs()), {{"lambda", lambda}, {"preamble", g.str()}}));
    }
};
} // namespace gpu
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx