/*
 * The MIT License (MIT)
 *
 * Copyright (c) 2015-2022 Advanced Micro Devices, Inc. All rights reserved.
 *
 * Permission is hereby granted, free of charge, to any person obtaining a copy
 * of this software and associated documentation files (the "Software"), to deal
 * in the Software without restriction, including without limitation the rights
 * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
 * copies of the Software, and to permit persons to whom the Software is
 * furnished to do so, subject to the following conditions:
 *
 * The above copyright notice and this permission notice shall be included in
 * all copies or substantial portions of the Software.
 *
 * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
 * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
 * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.  IN NO EVENT SHALL THE
 * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
 * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
 * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
 * THE SOFTWARE.
 */
#include <migraphx/gpu/compiler.hpp>
#include <migraphx/gpu/context.hpp>
#include <migraphx/gpu/compile_hip_code_object.hpp>
#include <migraphx/gpu/compile_hip.hpp>
#include <migraphx/gpu/compile_gen.hpp>
#include <migraphx/reduce_dims.hpp>
#include <migraphx/algorithm.hpp>

namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
namespace gpu {

using namespace migraphx::gpu::gen; // NOLINT

// HIP source template for the concat kernel. The ${...} placeholders are
// substituted by concat_compiler::compile_op through interpolate_string:
//   ${preamble}      - extra device code emitted before the kernel (e.g.
//                      generated pointwise functions); may be empty
//   ${kernel}        - kernel entry-point name
//   ${params}/${args} - enumerated `void*` kernel parameters and their uses
//   ${transformers}  - argument transformers produced from the vectorize info
//   ${concat_params} - named lambda parameters, one per argument consumed by
//                      each concat operand functor
//   ${concat_args}   - one `pack(functor, args...)` per concat operand
//   ${post}          - functor passed as the first argument to concat<...>
//   ${axis}          - concat axis
// NOLINTNEXTLINE
static const char* const concat_kernel = R"__migraphx__(
#include <migraphx/kernels/concat.hpp>
#include <migraphx/kernels/vectorize.hpp>
#include <migraphx/kernels/ops.hpp>
#include <args.hpp>

namespace migraphx {

${preamble}

extern "C" {

MIGRAPHX_GLOBAL void ${kernel}(${params}) 
{
    transform_args(make_tensors(), rotate_last(), ${transformers})(${args})([](auto y, ${concat_params}, auto... xs) {
        concat<${axis}>(${concat_args})(${post}, y, xs...);
    });
}

}

} // namespace migraphx

)__migraphx__";

struct concat_compiler : compiler<concat_compiler>
{
Paul's avatar
Paul committed
66
    std::vector<std::string> names() const { return {"fused_concat", "concat"}; }
Paul's avatar
Paul committed
67
68
69
70
71
72
73
74
75

    operation compile_op(context& ctx, const std::vector<shape>& inputs, const value& v) const
    {
        hip_compile_options options;
        options.inputs      = inputs;
        options.output      = inputs.back();
        options.params      = "-Wno-float-equal";
        options.kernel_name = v.get("kernel", "concat_kernel");
        auto axis           = find_fast_axis(options.inputs);
Paul's avatar
Format  
Paul committed
76
77
        auto op_names       = v.at("ops").to_vector<std::string>();
        auto args           = v.at("args");
Paul's avatar
Paul committed
78
79
80
81
        vectorize vec{};
        if(axis != v.at("axis").to<std::size_t>())
            vec = vectorize::elements(ctx, axis, options.inputs);
        auto nelements_per_op = options.inputs.back().elements() / op_names.size();
Paul's avatar
Format  
Paul committed
82
        options.set_launch_params(v, compute_global_for(ctx, nelements_per_op / vec.size, 256));
Paul's avatar
Paul committed
83
84
        std::vector<std::string> concat_params;
        std::vector<std::string> concat_args;
Paul's avatar
Paul committed
85
        for(auto i : range(op_names.size()))
Paul's avatar
Paul committed
86
        {
Paul's avatar
Paul committed
87
            const auto& name = op_names[i];
Paul's avatar
Format  
Paul committed
88
            auto n      = args.at(name).to<std::size_t>();
Paul's avatar
Format  
Paul committed
89
            auto prefix      = to_c_id(name + std::to_string(i) + "_concat_x");
Paul's avatar
Paul committed
90
91
            transform(range(n), std::back_inserter(concat_params), [&](auto j) {
                return "auto " + prefix + std::to_string(j);
Paul's avatar
Paul committed
92
93
            });
            std::vector<std::string> pack_args = {"MIGRAPHX_LIFT(" + name + ")"};
Paul's avatar
Paul committed
94
95
            transform(range(n), std::back_inserter(pack_args), [&](auto j) {
                return prefix + std::to_string(j);
Paul's avatar
Paul committed
96
97
98
            });
            concat_args.push_back("pack(" + join_strings(pack_args, ", ") + ")");
        }
Paul's avatar
Paul committed
99
        auto src = interpolate_string(concat_kernel,
Paul's avatar
Format  
Paul committed
100
101
102
103
104
105
106
107
108
                                      {{"kernel", options.kernel_name},
                                       {"params", enum_params(inputs.size(), "void * private_p")},
                                       {"args", enum_params(inputs.size(), "private_p")},
                                       {"concat_params", join_strings(concat_params, ", ")},
                                       {"concat_args", join_strings(concat_args, ", ")},
                                       {"post", v.get("post", std::string{"op::id{}"})},
                                       {"transformers", make_transformer_args(vec)},
                                       {"preamble", v.get("preamble", std::string{})},
                                       {"axis", v.at("axis").to<std::string>()}});
Paul's avatar
Paul committed
109
110
111
112
113
114
        return compile_hip_code_object(src, options);
    }

    compiler_replace compile(context& ctx, instruction_ref ins, const operation& op) const
    {
        auto v = op.to_value();
Paul's avatar
Format  
Paul committed
115
        if(op.name() == "fused_concat")
Paul's avatar
Paul committed
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
        {
            std::unordered_map<std::string, std::string> mod_names_lookup;
            transform(range(ins->module_inputs().size()),
                      std::inserter(mod_names_lookup, mod_names_lookup.end()),
                      [&](auto i) {
                          return std::make_pair(ins->module_inputs()[i]->name(),
                                                "pointwise" + std::to_string(i));
                      });
            v["preamble"] = transform_accumulate(
                ins->module_inputs().begin(),
                ins->module_inputs().end(),
                std::string{},
                std::plus<>{},
                [&](module_ref mod) {
                    return generate_pointwise(*mod, mod_names_lookup.at(mod->name())) + "\n";
                });
            std::vector<std::string> mod_names;
            std::transform(ins->module_inputs().begin(),
                           ins->module_inputs().end() - 1,
                           std::back_inserter(mod_names),
                           [&](module_ref mod) { return mod_names_lookup.at(mod->name()); });
Paul's avatar
Format  
Paul committed
137
            v["ops"]            = mod_names;
Paul's avatar
Paul committed
138
139
140
141
142
143
144
145
146
147
            module_ref last_mod = ins->module_inputs().back();
            v["post"]           = "MIGRAPHX_LIFT(" + mod_names_lookup.at(last_mod->name()) + ")";
            std::unordered_map<std::string, std::size_t> mod_args;
            std::transform(ins->module_inputs().begin(),
                           ins->module_inputs().end() - 1,
                           std::inserter(mod_args, mod_args.end()),
                           [&](module_ref mod) {
                               const auto& name = mod_names_lookup.at(mod->name());
                               return std::make_pair(name, mod->get_parameter_names().size());
                           });
Paul's avatar
Format  
Paul committed
148
            v["args"]        = mod_args;
Paul's avatar
Paul committed
149
150
151
152
153
154
155
156
157
158
159
160
161
            auto prefix_name = transform_accumulate(ins->module_inputs().begin(),
                                                    ins->module_inputs().end() - 1,
                                                    std::string{},
                                                    std::plus<>{},
                                                    [&](module_ref mod) -> std::string {
                                                        auto name = generate_name_from_ops(*mod);
                                                        if(name.empty())
                                                            return "";
                                                        return name + "_";
                                                    });
            v["kernel"]      = prefix_name + "concat_" +
                          generate_name_from_ops(*(ins->module_inputs().back())) + "_kernel";
        }
Paul's avatar
Format  
Paul committed
162
        else if(op.name() == "concat")
Paul's avatar
Paul committed
163
164
165
166
        {
            auto concat_inputs = ins->inputs().size() - 1;
            if(not ins->module_inputs().empty())
            {
Paul's avatar
Format  
Paul committed
167
                auto* pm      = ins->module_inputs().front();
Paul's avatar
Paul committed
168
                concat_inputs = ins->inputs().size() - pm->get_parameter_names().size();
Paul's avatar
Format  
Paul committed
169
170
171
                v["preamble"] = generate_pointwise(*pm, "post_concat");
                v["post"]     = "MIGRAPHX_LIFT(post_concat)";
                v["kernel"]   = "concat_" + generate_name_from_ops(*pm) + "_kernel";
Paul's avatar
Paul committed
172
173
            }
            std::vector<std::string> mod_names(concat_inputs, "op::id{}");
Paul's avatar
Format  
Paul committed
174
            v["ops"]                                              = mod_names;
Paul's avatar
Paul committed
175
            std::unordered_map<std::string, std::size_t> mod_args = {{"op::id{}", 1}};
Paul's avatar
Format  
Paul committed
176
            v["args"]                                             = mod_args;
Paul's avatar
Paul committed
177
        }
Paul's avatar
Paul committed
178
179
180
181
        return compile_op(ctx, to_shapes(ins->inputs()), v);
    }
};

} // namespace gpu
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx