/* * The MIT License (MIT) * * Copyright (c) 2015-2022 Advanced Micro Devices, Inc. All rights reserved. * * Permission is hereby granted, free of charge, to any person obtaining a copy * of this software and associated documentation files (the "Software"), to deal * in the Software without restriction, including without limitation the rights * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell * copies of the Software, and to permit persons to whom the Software is * furnished to do so, subject to the following conditions: * * The above copyright notice and this permission notice shall be included in * all copies or substantial portions of the Software. * * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN * THE SOFTWARE. */ #include #include #include #include #include #include #include namespace migraphx { inline namespace MIGRAPHX_INLINE_NS { struct module; namespace gpu { MIGRAPHX_DECLARE_ENV_VAR(MIGRAPHX_ENABLE_MLIR); bool mlir_enabled() { #ifdef MIGRAPHX_MLIR const bool mlir_enabled = enabled(MIGRAPHX_ENABLE_MLIR{}); if(mlir_enabled) { return true; } else { std::cerr << "WARNING: MIGraphX built with MLIR but it is not enabled. Please set the env " "var MIGRAPHX_ENABLE_MLIR to use MLIR kernel generator." << std::endl; return false; } #else return false; #endif } #ifdef MIGRAPHX_MLIR struct mlir_op { std::string name() const { return "gpu::mlir_op"; } operation op = make_op("convolution"); template static auto reflect(Self& self, F f) { return pack(f(self.op, "op")); } shape compute_shape(std::vector inputs, const std::vector& mods) const { check_shapes{inputs, *this}.packed_or_broadcasted(); if(mods.size() != 1) MIGRAPHX_THROW("should have one submodule."); if(inputs.size() < 2) MIGRAPHX_THROW("should have at least two inputs."); module_ref mod = mods[0]; auto type = mod->get_output_shapes().front().type(); std::unordered_map ins_shapes; size_t param_cnt = 0; std::vector names = mod->get_parameter_names(); std::sort(names.begin(), names.end()); for(std::string param_name : names) { ins_shapes[mod->get_parameter(param_name)] = inputs[param_cnt++]; } for(auto ins : iterator_for(*mod)) { if(ins->name() == "@param") { continue; } if(ins->name() == "@literal") { ins_shapes[ins] = ins->get_shape(); continue; } if(ins->name() == "@return") { return ins_shapes[ins->inputs().at(0)].with_type(type); } std::vector input_shapes; input_shapes.resize(ins->inputs().size()); std::transform(ins->inputs().begin(), ins->inputs().end(), input_shapes.begin(), [&](auto in) { return ins_shapes[in]; }); ins_shapes[ins] = ins->get_operator().compute_shape(input_shapes); } MIGRAPHX_THROW("No return found in the submodule"); } }; MIGRAPHX_REGISTER_OP(mlir_op); namespace { MIGRAPHX_PRED_MATCHER(is_mlir_conv, instruction_ref ins) { if(ins->name() != "convolution" and ins->name() != "quant_convolution") return false; value v = ins->get_operator().to_value(); auto group = v.at("group").to(); if(group != 1) return false; // Avoid MLIR assertion: Index < Length && "Invalid index!" if(ins->get_shape().lens().size() != 4) return false; return true; } struct find_mlir_op { auto matcher() const { auto dot_or_conv = match::skip(match::name("contiguous"))( match::any_of(match::name("dot"), match::name("quant_dot"), is_mlir_conv()) .bind("gemm_based_op")); return match::name("pointwise")(match::any_of[match::inputs()](dot_or_conv.bind("x"))); } std::unordered_map create_param_map_with_literals(module_ref mm, const module* pm, const shape& shape) const { std::unordered_map ins_map; for(auto ins : iterator_for(*pm)) { if(ins->name() != "@literal") { continue; } literal r = ins->get_literal(); instruction_ref literal = mm->add_literal(r); instruction_ref mbcast = mm->add_instruction( make_op("multibroadcast", {{"out_lens", shape.lens()}}), literal); ins_map[ins] = mbcast; } return ins_map; } std::tuple> fuse_input_ops_and_gemm_based_op(module_ref mm, instruction_ref gemm_based_op) const { std::vector top_inputs; std::vector imm_inputs; size_t input_cnt = 0; for(instruction_ref input : gemm_based_op->inputs()) { std::vector op_stream; while(contains({"slice", "transpose", "contiguous", "reshape"}, input->name())) { op_stream.push_back(input->get_operator()); input = input->inputs().at(0); } top_inputs.push_back(input); instruction_ref prev_input = mm->add_parameter("y" + std::to_string(input_cnt++), input->get_shape()); for(const auto& op : reverse(op_stream)) { prev_input = mm->add_instruction(op, {prev_input}); } imm_inputs.push_back(prev_input); } instruction_ref new_gemm_based_op = mm->add_instruction(gemm_based_op->get_operator(), imm_inputs); return {new_gemm_based_op, top_inputs}; } // Whitelist supported fusion options, including imposing type constraints // for cases where MLIR only supports an operation (usually a pointwise function) // on particular types. bool is_pointwise_op_supported_by_mlir(const instruction& i) const { using type_t = shape::type_t; const auto& name = i.name(); const auto result_type = i.get_shape().type(); const std::initializer_list allowed_types = {type_t::float_type, type_t::half_type, type_t::int8_type, type_t::int32_type, type_t::bool_type}; // Preliminary type check. if(not contains(allowed_types, result_type)) { return false; } const std::initializer_list any_type_ops = {"@literal", "@param", "@return"}; const std::initializer_list no_bool_ops = {"convolution", "quant_convolution", "dot", "quant_dot", "add", "clip", "relu", "sub", "mul", "div", "pow", "where", "quantizelinear", "dequantizelinear", "abs", "neg"}; const std::initializer_list fp_only_ops = {"ceil", "erf", "exp", "floor", "log", "recip", "rsqrt", "sigmoid" "softmax", "tanh"}; bool is_float = contains({type_t::float_type, type_t::half_type}, result_type); if(contains(any_type_ops, name)) return true; if(result_type != type_t::bool_type and contains(no_bool_ops, name)) return true; if(is_float and contains(fp_only_ops, name)) return true; // Only conversions between floating types are known to be unambigiously // supported. if(is_float and name == "convert") { return std::all_of(i.inputs().begin(), i.inputs().end(), [](const auto& arg) { return contains({type_t::float_type, type_t::half_type}, arg->get_shape().type()); }); } return false; } void apply(module_pass_manager& mpm, const match::matcher_result& r) const { auto ins = r.result; auto gemm_based_op = r.instructions["gemm_based_op"]; auto x_ins = r.instructions["x"]; // input after contiguous auto* pm = ins->module_inputs().front(); auto names = pm->get_parameter_names(); // Whitelist pointwise operators. if(std::any_of(pm->begin(), pm->end(), [&](const auto& i) { return not is_pointwise_op_supported_by_mlir(i); })) return; std::sort(names.begin(), names.end()); module_ref mm = mpm.create_module("mlir_" + pm->name()); mm->set_bypass(); std::unordered_map param_map = create_param_map_with_literals(mm, pm, gemm_based_op->get_shape()); auto [anchor_op, top_inputs] = fuse_input_ops_and_gemm_based_op(mm, gemm_based_op); std::transform(names.begin(), names.end(), ins->inputs().begin(), std::inserter(param_map, param_map.end()), [&, &anchor_op = anchor_op](auto name, auto input) { if(input == x_ins) return std::make_pair(pm->get_parameter(name), anchor_op); return std::make_pair(pm->get_parameter(name), mm->add_parameter(name, input->get_shape())); }); mm->add_return(mm->insert_instructions(mm->end(), pm, param_map)); std::vector inputs; std::copy_if(ins->inputs().begin(), ins->inputs().end(), std::back_inserter(inputs), [&](auto input) { return input != gemm_based_op; }); inputs.insert(inputs.end(), top_inputs.begin(), top_inputs.end()); mpm.get_module().replace_instruction( ins, mlir_op{gemm_based_op->get_operator()}, inputs, {mm}); } }; } // namespace #endif void fuse_mlir::apply(module_pass_manager& mpm) const { #ifdef MIGRAPHX_MLIR match::find_matches(mpm, find_mlir_op{}); #else (void)mpm; #endif } } // namespace gpu } // namespace MIGRAPHX_INLINE_NS } // namespace migraphx