fused_reduce.cpp 6.87 KB
Newer Older
Paul's avatar
Paul committed
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
/*
 * The MIT License (MIT)
 *
 * Copyright (c) 2015-2022 Advanced Micro Devices, Inc. All rights reserved.
 *
 * Permission is hereby granted, free of charge, to any person obtaining a copy
 * of this software and associated documentation files (the "Software"), to deal
 * in the Software without restriction, including without limitation the rights
 * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
 * copies of the Software, and to permit persons to whom the Software is
 * furnished to do so, subject to the following conditions:
 *
 * The above copyright notice and this permission notice shall be included in
 * all copies or substantial portions of the Software.
 *
 * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
 * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
 * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.  IN NO EVENT SHALL THE
 * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
 * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
 * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
 * THE SOFTWARE.
 */
#include <migraphx/gpu/compiler.hpp>
#include <migraphx/gpu/context.hpp>
#include <migraphx/gpu/compile_hip_code_object.hpp>
#include <migraphx/gpu/compile_hip.hpp>
#include <migraphx/gpu/compile_gen.hpp>
#include <migraphx/reduce_dims.hpp>
#include <algorithm>
#include <cassert>
#include <cstddef>
#include <iterator>
#include <limits>
#include <numeric>
#include <string>
#include <vector>

namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
namespace gpu {

using namespace migraphx::gpu::gen; // NOLINT

// Source template for the fused reduce kernel. The ${...} placeholders
// (preamble, kernel, params, transformers, args, algo, reduced, lambda) are
// filled in by interpolate_string() in fused_reduce_compiler::compile_op
// before the result is handed to compile_hip_code_object().
static const char* const simple_reduce_kernel = R"__migraphx__(
#include <migraphx/kernels/index.hpp>
#include <migraphx/kernels/reduce.hpp>
#include <migraphx/kernels/vectorize.hpp>
#include <args.hpp>

namespace migraphx {

${preamble}

extern "C" {
__global__ void ${kernel}(${params})
{
    transform_args(make_tensors(), ${transformers})(${args})([](auto y, auto... xs) {
        fused_reduce<reduce::${algo}, ${reduced}>(y, partial(${lambda})(xs...));
    });
}
    
}

} // namespace migraphx

)__migraphx__";

// Ratio of the element count of the first shape to the element count of the
// last shape, i.e. how many input elements map onto each output element when
// the first shape is the input and the last shape is the output.
static std::size_t get_reduce_elements(const std::vector<shape>& inputs)
{
    const auto first_elements = inputs.front().elements();
    const auto last_elements  = inputs.back().elements();
    return first_elements / last_elements;
}
// Convenience overload: converts the instructions to their shapes with
// to_shapes() and defers to the shape-based overload.
static std::size_t get_reduce_elements(const std::vector<instruction_ref>& inputs)
{
    const auto input_shapes = to_shapes(inputs);
    return get_reduce_elements(input_shapes);
}

/// Computes the per-dimension lengths that are reduced away.
///
/// For every dimension of `output_lens`, the result holds 1 when the input and
/// output lengths match (the dimension is not reduced) and the input length
/// otherwise (the whole dimension is reduced).
///
/// @param input_lens  Dimension lengths of the reduction input. Must contain at
///                    least as many entries as `output_lens`.
/// @param output_lens Dimension lengths of the reduction output.
/// @return A vector with one entry per element of `output_lens`.
static std::vector<std::size_t> get_reduce_lens(const std::vector<std::size_t>& input_lens,
                                                const std::vector<std::size_t>& output_lens)
{
    // std::transform reads one input length per output length without any
    // bounds checking, so a shorter input would be read out of bounds.
    assert(input_lens.size() >= output_lens.size());
    std::vector<std::size_t> reduce_lens;
    // The result size is known up front; avoid repeated reallocation.
    reduce_lens.reserve(output_lens.size());
    std::transform(output_lens.begin(),
                   output_lens.end(),
                   input_lens.begin(),
                   std::back_inserter(reduce_lens),
                   [](auto out_len, auto in_len) -> std::size_t {
                       return out_len == in_len ? 1 : in_len;
                   });
    return reduce_lens;
}

Paul's avatar
Format  
Paul committed
87
template <class T>
Paul's avatar
Paul committed
88
89
static shape get_reduced_shape(const shape& s, const std::vector<T>& axes)
{
Paul's avatar
Format  
Paul committed
90
    auto lens = s.lens();
Paul's avatar
Paul committed
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
    for(const auto& axis : axes)
        lens[axis] = 1;
    return shape{s.type(), lens};
}

static std::string get_reduce_algo(const std::vector<shape>& inputs)
{
    auto rlens      = get_reduce_lens(inputs.front().lens(), inputs.back().lens());
    const auto init = std::numeric_limits<std::size_t>::max();
    // The minimum stride
    auto min_stride = std::inner_product(
        rlens.begin(),
        rlens.end(),
        inputs.front().strides().begin(),
        init,
        [](auto x, auto y) { return std::min(x, y); },
        [](auto len, auto stride) { return len == 1 ? init : stride; });
    if(min_stride > 2)
        return "lane";
    return "block";
}

/// JIT compiler for the "fused_reduce" operation. It instantiates the
/// simple_reduce_kernel source template and compiles it into a code object.
struct fused_reduce_compiler : compiler<fused_reduce_compiler>
{
    /// Names of the operations handled by this compiler.
    std::vector<std::string> names() const { return {"fused_reduce"}; }

    /// Generates the kernel source for the given shapes and settings and
    /// compiles it with compile_hip_code_object().
    ///
    /// @param ctx    GPU context forwarded to the vectorization and
    ///               launch-size helpers.
    /// @param inputs Shapes of the kernel arguments; the last shape is used as
    ///               the output shape.
    /// @param v      Settings. "axes" and "lambda" are read with at(); "algo",
    ///               "kernel" and "preamble" are read with get() and fall back
    ///               to a computed/default value when absent.
    /// @return The operation produced by compile_hip_code_object().
    ///
    /// Raises an error through MIGRAPHX_THROW when the selected algorithm is
    /// neither "block" nor "lane".
    operation compile_op(context& ctx, const std::vector<shape>& inputs, const value& v) const
    {
        // Append the reduced shape so that reduce_dims() processes it together
        // with the argument shapes, then take it back off the list.
        auto virtual_inputs = inputs;
        virtual_inputs.push_back(
            get_reduced_shape(inputs.front(), v.at("axes").to_vector<std::size_t>()));
        virtual_inputs     = reduce_dims(virtual_inputs);
        auto reduced_shape = virtual_inputs.back();
        virtual_inputs.pop_back();

        hip_compile_options options;
        options.inputs         = inputs;
        options.output         = inputs.back();
        options.virtual_inputs = virtual_inputs;
        auto faxis             = find_fast_axis({options.virtual_inputs.front()});
        vectorize vec{};
        // Vectorize if the axis is a reduction axis
        if(options.virtual_inputs.back().lens()[faxis] == 1)
        {
            vec = vectorize::elements(ctx, faxis, options.virtual_inputs);
        }
        // Number of (possibly vectorized) elements reduced per output element
        auto relements = get_reduce_elements(options.virtual_inputs) / vec.size;
        auto nelements = options.virtual_inputs.back().elements();
        auto algo      = v.get("algo", get_reduce_algo(options.virtual_inputs));
        if(algo == "block")
        {
            // Global size is scaled by the block size and the block size is
            // passed as the local size.
            auto block_size = compute_block_size(relements, 256);
            options.set_launch_params(
                v, compute_global_for(ctx, nelements * block_size, 256), block_size);
        }
        else if(algo == "lane")
        {
            // Global size is derived from the number of output elements only.
            options.set_launch_params(v, compute_global_for(ctx, nelements, 256));
        }
        else
        {
            MIGRAPHX_THROW("Unknown reduce algo: " + algo);
        }
        options.kernel_name = v.get("kernel", "reduce_kernel");
        auto src =
            interpolate_string(simple_reduce_kernel,
                               {{"kernel", options.kernel_name},
                                {"params", enum_params(inputs.size(), "void * private_p")},
                                {"args", enum_params(inputs.size(), "private_p")},
                                {"algo", algo},
                                {"reduced", "decltype(" + generate_make_shape(reduced_shape) + ")"},
                                {"lambda", v.at("lambda").to<std::string>()},
                                {"transformers", make_transformer_args(vec)},
                                {"preamble", v.get("preamble", std::string{})}});
        options.params += "-Wno-float-equal";
        return compile_hip_code_object(src, options);
    }

    /// Compiles the fused_reduce instruction `ins`.
    ///
    /// The first module input of `ins` is turned into the kernel preamble with
    /// generate_reduce() under the name "fused_reduce_op"; the kernel name is
    /// derived from that module with generate_name_from_ops(). The instruction
    /// must have at least one module input (checked with assert).
    ///
    /// @param ctx GPU context forwarded to compile_op().
    /// @param ins Instruction being compiled.
    /// @param op  Operation whose value supplies the base settings.
    /// @return The compiled operation wrapped with replace().
    compiler_replace compile(context& ctx, instruction_ref ins, const operation& op) const
    {
        assert(not ins->module_inputs().empty());
        auto v        = op.to_value();
        auto* rm      = ins->module_inputs().front();
        v["preamble"] = generate_reduce(*rm, "fused_reduce_op");
        v["lambda"]   = "MIGRAPHX_LIFT(fused_reduce_op)";
        v["kernel"]   = generate_name_from_ops(*rm) + "_kernel";
        return replace(compile_op(ctx, to_shapes(ins->inputs()), v));
    }
};
} // namespace gpu
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx