concat.cpp 4.49 KB
Newer Older
Paul's avatar
Paul committed
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
/*
 * The MIT License (MIT)
 *
 * Copyright (c) 2015-2022 Advanced Micro Devices, Inc. All rights reserved.
 *
 * Permission is hereby granted, free of charge, to any person obtaining a copy
 * of this software and associated documentation files (the "Software"), to deal
 * in the Software without restriction, including without limitation the rights
 * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
 * copies of the Software, and to permit persons to whom the Software is
 * furnished to do so, subject to the following conditions:
 *
 * The above copyright notice and this permission notice shall be included in
 * all copies or substantial portions of the Software.
 *
 * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
 * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
 * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.  IN NO EVENT SHALL THE
 * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
 * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
 * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
 * THE SOFTWARE.
 */
#include <migraphx/gpu/compiler.hpp>
#include <migraphx/gpu/context.hpp>
#include <migraphx/gpu/compile_hip_code_object.hpp>
#include <migraphx/gpu/compile_hip.hpp>
#include <migraphx/gpu/compile_gen.hpp>
#include <migraphx/reduce_dims.hpp>
#include <numeric>

namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
namespace gpu {

using namespace migraphx::gpu::gen; // NOLINT

// HIP source template for the concat kernel. The ${...} placeholders are filled
// in by concat_compiler::compile_op via interpolate_string:
//   ${preamble}      - optional device code (e.g. a fused pointwise function)
//   ${kernel}        - kernel entry-point name
//   ${params}/${args} - one `void*` parameter/argument per input shape
//   ${transformers}  - vectorization transformers from make_transformer_args
//   ${concat_params}/${concat_args} - the tensors that are actually concatenated
//   ${axis}          - concat axis
//   ${post}          - callable applied to each element (op::id{} by default)
// NOTE(review): rotate_last() presumably moves the last argument (the output
// buffer) to the front, which is why the lambda receives `y` first - confirm.
// Nothing may be written inside the raw string except kernel code: every byte
// of it is emitted into the generated HIP source.
// NOLINTNEXTLINE
static const char* const concat_kernel = R"__migraphx__(
#include <migraphx/kernels/concat.hpp>
#include <migraphx/kernels/vectorize.hpp>
#include <migraphx/kernels/ops.hpp>
#include <args.hpp>

namespace migraphx {

${preamble}

extern "C" {

__global__ void ${kernel}(${params}) 
{
    transform_args(make_tensors(), rotate_last(), ${transformers})(${args})([](auto y, ${concat_params}, auto... xs) {
        concat<${axis}>(${concat_args})(${post}, y, xs...);
    });
}

}

} // namespace migraphx

)__migraphx__";

struct concat_compiler : compiler<concat_compiler>
{
    std::vector<std::string> names() const { return {"concat"}; }

    // Average element count over all given shapes (the concat inputs, any fused
    // post-op inputs, and the output). Only used to size the launch grid.
    // Returns 0 for an empty list instead of dividing by zero.
    static std::size_t get_concat_elements(const std::vector<shape>& inputs)
    {
        if(inputs.empty())
            return 0;
        // Seed with std::size_t: a plain `0` makes std::accumulate keep the
        // running total in an int, which truncates/overflows for large tensors.
        auto total = std::accumulate(inputs.begin(),
                                     inputs.end(),
                                     std::size_t{0},
                                     [](std::size_t acc, const auto& s) { return acc + s.elements(); });
        return total / inputs.size();
    }

    // Generates and compiles the concat kernel for `inputs`, whose last shape
    // is the output buffer.
    // Keys read from `v`:
    //   "axis"          (required) concat axis
    //   "concat_inputs" number of leading inputs to concatenate
    //                   (default: every input except the output)
    //   "kernel"        kernel name (default "concat_kernel")
    //   "post"          per-element callable source (default "op::id{}")
    //   "preamble"      extra device code emitted before the kernel (default "")
    operation compile_op(context& ctx, const std::vector<shape>& inputs, const value& v) const
    {
        auto num_of_concat_inputs = v.get("concat_inputs", inputs.size() - 1);
        hip_compile_options options;
        options.inputs      = inputs;
        options.output      = inputs.back();
        options.params      = "-Wno-float-equal";
        options.kernel_name = v.get("kernel", "concat_kernel");
        auto axis           = find_fast_axis(options.inputs);
        auto vec            = vectorize::elements(axis, options.inputs);
        options.set_launch_params(
            v, compute_global_for(ctx, get_concat_elements(options.inputs) / vec.size, 256));
        auto src = interpolate_string(
            concat_kernel,
            {{"kernel", options.kernel_name},
             {"params", enum_params(inputs.size(), "void * private_p")},
             {"args", enum_params(inputs.size(), "private_p")},
             {"concat_params", enum_params(num_of_concat_inputs, "auto concat_x")},
             {"concat_args", enum_params(num_of_concat_inputs, "concat_x")},
             {"post", v.get("post", std::string{"op::id{}"})},
             {"transformers", make_transformer_args(vec)},
             {"preamble", v.get("preamble", std::string{})},
             {"axis", v.at("axis").to<std::string>()}});
        return compile_hip_code_object(src, options);
    }

    // Compiles a concat instruction, optionally with a fused pointwise post-op
    // supplied as the instruction's first module input.
    compiler_replace compile(context& ctx, instruction_ref ins, const operation& op) const
    {
        auto v = op.to_value();
        if(not ins->module_inputs().empty())
        {
            auto* pm = ins->module_inputs().front();
            // Instruction inputs are laid out as
            //   [concat inputs..., post-op extra inputs..., output]
            // and the kernel lambda receives (y, concat_x..., xs...). The pointwise
            // module's parameters are the concatenated value plus the extras, so
            // #extras = P - 1 and #concat = size - (P - 1) - 1 = size - P.
            v["concat_inputs"] = ins->inputs().size() - pm->get_parameter_names().size();
            v["preamble"]      = generate_pointwise(*pm, "post_concat");
            v["post"]          = "MIGRAPHX_LIFT(post_concat)";
            v["kernel"]        = "concat_" + generate_name_from_ops(*pm) + "_kernel";
        }
        // Pass the augmented value: using op.to_value() here would silently drop
        // the fused post-op and miscount the concat inputs.
        return replace(compile_op(ctx, to_shapes(ins->inputs()), v));
    }
};

} // namespace gpu
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx