Unverified Commit 250d3c87 authored by Chris Austen's avatar Chris Austen Committed by GitHub
Browse files

Merge branch 'develop' into ck-flash-attn

parents 135eb63e f3939b99
/*
* The MIT License (MIT)
*
* Copyright (c) 2015-2022 Advanced Micro Devices, Inc. All rights reserved.
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
* in the Software without restriction, including without limitation the rights
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
* copies of the Software, and to permit persons to whom the Software is
* furnished to do so, subject to the following conditions:
*
* The above copyright notice and this permission notice shall be included in
* all copies or substantial portions of the Software.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
* THE SOFTWARE.
*/
#include <migraphx/gpu/device/targets.hpp>
#include <migraphx/stringutils.hpp>
#include <migraphx/errors.hpp>
#include <hip/hip_runtime_api.h>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
namespace gpu {
namespace device {
static std::vector<std::string> parse_targets() { return split_string(MIGRAPHX_GPU_TARGETS, ';'); }
const std::vector<std::string>& get_targets()
{
static auto result = parse_targets();
return result;
}
std::string get_targets_as_string() { return join_strings(get_targets(), ", "); }
static int get_device_id()
{
int device;
auto status = hipGetDevice(&device);
if(status != hipSuccess)
MIGRAPHX_THROW("No device");
return device;
}
std::string get_device_name()
{
hipDeviceProp_t props{};
auto status = hipGetDeviceProperties(&props, get_device_id());
if(status != hipSuccess)
MIGRAPHX_THROW("Failed to get device properties");
return props.gcnArchName;
}
} // namespace device
} // namespace gpu
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
/*
* The MIT License (MIT)
*
* Copyright (c) 2015-2023 Advanced Micro Devices, Inc. All rights reserved.
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
* in the Software without restriction, including without limitation the rights
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
* copies of the Software, and to permit persons to whom the Software is
* furnished to do so, subject to the following conditions:
*
* The above copyright notice and this permission notice shall be included in
* all copies or substantial portions of the Software.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
* THE SOFTWARE.
*/
#ifndef MIGRAPHX_GUARD_DEVICE_TARGETS_CPP
#define MIGRAPHX_GUARD_DEVICE_TARGETS_CPP
#include <migraphx/config.hpp>
#include <string>
#include <vector>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
namespace gpu {
namespace device {
#define MIGRAPHX_GPU_TARGETS "@GPU_TARGETS@" // NOLINT
const std::vector<std::string>& get_targets();
std::string get_targets_as_string();
std::string get_device_name();
} // namespace device
} // namespace gpu
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
#endif // MIGRAPHX_GUARD_DEVICE_TARGETS_CPP
...@@ -103,7 +103,10 @@ struct mlir_op ...@@ -103,7 +103,10 @@ struct mlir_op
} }
if(ins->name() == "@return") if(ins->name() == "@return")
{ {
return ins_shapes[ins->inputs().at(0)].with_type(type); auto s = ins_shapes[ins->inputs().at(0)].with_type(type);
if(not s.standard())
MIGRAPHX_THROW("MLIR doesnt support non-standard output");
return s;
} }
std::vector<shape> input_shapes; std::vector<shape> input_shapes;
input_shapes.resize(ins->inputs().size()); input_shapes.resize(ins->inputs().size());
...@@ -119,6 +122,33 @@ struct mlir_op ...@@ -119,6 +122,33 @@ struct mlir_op
MIGRAPHX_REGISTER_OP(mlir_op); MIGRAPHX_REGISTER_OP(mlir_op);
namespace { namespace {
std::tuple<instruction_ref, std::vector<instruction_ref>>
fuse_input_ops_and_gemm_based_op(module_ref mm, instruction_ref gemm_based_op)
{
std::vector<instruction_ref> top_inputs;
std::vector<instruction_ref> imm_inputs;
size_t input_cnt = 0;
for(instruction_ref input : gemm_based_op->inputs())
{
std::vector<operation> op_stream;
while(contains({"slice", "transpose", "contiguous", "reshape"}, input->name()))
{
op_stream.push_back(input->get_operator());
input = input->inputs().at(0);
}
top_inputs.push_back(input);
instruction_ref prev_input =
mm->add_parameter("y" + std::to_string(input_cnt++), input->get_shape());
for(const auto& op : reverse(op_stream))
{
prev_input = mm->add_instruction(op, {prev_input});
}
imm_inputs.push_back(prev_input);
}
instruction_ref new_gemm_based_op =
mm->add_instruction(gemm_based_op->get_operator(), imm_inputs);
return {new_gemm_based_op, top_inputs};
}
MIGRAPHX_PRED_MATCHER(is_mlir_conv, instruction_ref ins) MIGRAPHX_PRED_MATCHER(is_mlir_conv, instruction_ref ins)
{ {
...@@ -134,7 +164,7 @@ MIGRAPHX_PRED_MATCHER(is_mlir_conv, instruction_ref ins) ...@@ -134,7 +164,7 @@ MIGRAPHX_PRED_MATCHER(is_mlir_conv, instruction_ref ins)
return true; return true;
} }
struct find_mlir_op struct find_mlir_fused_ops
{ {
auto matcher() const auto matcher() const
{ {
...@@ -163,34 +193,6 @@ struct find_mlir_op ...@@ -163,34 +193,6 @@ struct find_mlir_op
return ins_map; return ins_map;
} }
std::tuple<instruction_ref, std::vector<instruction_ref>>
fuse_input_ops_and_gemm_based_op(module_ref mm, instruction_ref gemm_based_op) const
{
std::vector<instruction_ref> top_inputs;
std::vector<instruction_ref> imm_inputs;
size_t input_cnt = 0;
for(instruction_ref input : gemm_based_op->inputs())
{
std::vector<operation> op_stream;
while(contains({"slice", "transpose", "contiguous", "reshape"}, input->name()))
{
op_stream.push_back(input->get_operator());
input = input->inputs().at(0);
}
top_inputs.push_back(input);
instruction_ref prev_input =
mm->add_parameter("y" + std::to_string(input_cnt++), input->get_shape());
for(const auto& op : reverse(op_stream))
{
prev_input = mm->add_instruction(op, {prev_input});
}
imm_inputs.push_back(prev_input);
}
instruction_ref new_gemm_based_op =
mm->add_instruction(gemm_based_op->get_operator(), imm_inputs);
return {new_gemm_based_op, top_inputs};
}
// Whitelist supported fusion options, including imposing type constraints // Whitelist supported fusion options, including imposing type constraints
// for cases where MLIR only supports an operation (usually a pointwise function) // for cases where MLIR only supports an operation (usually a pointwise function)
// on particular types. // on particular types.
...@@ -236,8 +238,7 @@ struct find_mlir_op ...@@ -236,8 +238,7 @@ struct find_mlir_op
"log", "log",
"recip", "recip",
"rsqrt", "rsqrt",
// There are bugs in MLIR right now for models using sigmoid so disable it for now "sigmoid",
// "sigmoid",
"softmax", "softmax",
"tanh", "tanh",
}; };
...@@ -282,9 +283,9 @@ struct find_mlir_op ...@@ -282,9 +283,9 @@ struct find_mlir_op
names.end(), names.end(),
ins->inputs().begin(), ins->inputs().begin(),
std::inserter(param_map, param_map.end()), std::inserter(param_map, param_map.end()),
[&, &anchor_op = anchor_op](auto name, auto input) { [&, &anchor = anchor_op](auto name, auto input) {
if(input == x_ins) if(input == x_ins)
return std::make_pair(pm->get_parameter(name), anchor_op); return std::make_pair(pm->get_parameter(name), anchor);
return std::make_pair(pm->get_parameter(name), return std::make_pair(pm->get_parameter(name),
mm->add_parameter(name, input->get_shape())); mm->add_parameter(name, input->get_shape()));
}); });
...@@ -301,20 +302,115 @@ struct find_mlir_op ...@@ -301,20 +302,115 @@ struct find_mlir_op
} }
}; };
struct find_mlir_standalone_op
{
void apply(module_pass_manager& mpm, const match::matcher_result& r) const
{
auto conv_based_op = r.result;
// enable only for fp32/fp16/i8 types
if(std::any_of(conv_based_op->inputs().begin(), conv_based_op->inputs().end(), [&](auto i) {
return not contains(
{shape::type_t::float_type, shape::type_t::half_type, shape::type_t::int8_type},
i->get_shape().type());
}))
return;
static size_t counter = 0;
module_ref mm = mpm.create_module("mlir_" + std::to_string(counter++));
mm->set_bypass();
auto [anchor_op, top_inputs] = fuse_input_ops_and_gemm_based_op(mm, conv_based_op);
mm->add_return({anchor_op});
mpm.get_module().replace_instruction(
conv_based_op, mlir_op{conv_based_op->get_operator()}, top_inputs, {mm});
}
};
struct find_mlir_standalone_convolution_op : find_mlir_standalone_op
{
auto matcher() const { return is_mlir_conv; }
};
struct find_mlir_standalone_dot_op : find_mlir_standalone_op
{
auto matcher() const { return match::any_of(match::name("dot"), match::name("quant_dot")); }
};
/**
* @brief Declares a new MIGraphX environment variable which forces to generate
* only specific MLIR operations.
*
* The variable, if defined, forces MIGraphX to use only specific operations
* with MLIR regardless of the underlying GPU architecture. The variable accepts
* a list of operations separated by comma. The variable recognizes the following
* operations: "fused", "convolution", "dot". If the variable is not defined MIGraphX
* will decide by itself which operations to delegate to MLIR. The variable is
* intended to be primarily used by rocMLIR developers.
*/
MIGRAPHX_DECLARE_ENV_VAR(MIGRAPHX_MLIR_USE_SPECIFIC_OPS);
bool is_self_decide() { return string_value_of(MIGRAPHX_MLIR_USE_SPECIFIC_OPS{}, "").empty(); }
bool is_requested(std::string_view option)
{
assert(not is_self_decide());
auto string_value = string_value_of(MIGRAPHX_MLIR_USE_SPECIFIC_OPS{}, "");
const auto options = split_string(string_value, ',');
return contains(options, option);
}
bool is_enabled(std::string_view op_name, context* ctx)
{
if(is_self_decide())
{
if(op_name == "fused")
{
return true;
}
else if(op_name == "convolution" or op_name == "quant_convolution")
{
if(ctx == nullptr)
{
return false;
}
else
{
const auto& device = ctx->get_current_device();
const std::string navi_family{"gfx110"};
return starts_with(device.get_gfx_name(), navi_family);
}
}
else
{
return false;
}
}
return is_requested(op_name);
}
} // namespace } // namespace
#endif #endif // MIGRAPHX_MLIR
void fuse_mlir::apply(module_pass_manager& mpm) const void fuse_mlir::apply(module_pass_manager& mpm) const
{ {
#ifdef MIGRAPHX_MLIR #ifdef MIGRAPHX_MLIR
match::find_matches(mpm, find_mlir_op{}); if(is_enabled("fused", this->ctx))
{
match::find_matches(mpm, find_mlir_fused_ops{});
}
if(is_enabled("convolution", this->ctx))
{
match::find_matches(mpm, find_mlir_standalone_convolution_op{});
}
if(is_enabled("dot", this->ctx))
{
match::find_matches(mpm, find_mlir_standalone_dot_op{});
}
#else #else
(void)mpm; (void)mpm;
#endif #endif
} }
} // namespace gpu } // namespace gpu
} // namespace MIGRAPHX_INLINE_NS } // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx } // namespace migraphx
...@@ -790,22 +790,26 @@ struct find_layernorm_pointwise ...@@ -790,22 +790,26 @@ struct find_layernorm_pointwise
{ {
auto matcher() const auto matcher() const
{ {
return precompile_name("pointwise")(match::arg(0)( return precompile_name("pointwise")(match::any_of[match::inputs()](
precompile_name("gpu::prelayernorm", "gpu::preadd_layernorm").bind("layernorm"))); precompile_name("gpu::prelayernorm", "gpu::preadd_layernorm").bind("layernorm")));
} }
void apply(module& m, const match::matcher_result& r) const void apply(module& m, const match::matcher_result& r) const
{ {
auto ins = r.result; auto pw_ins = r.result;
auto layernorm = r.instructions["layernorm"]; auto layernorm = r.instructions["layernorm"];
if(not layernorm->module_inputs().empty()) if(not layernorm->module_inputs().empty())
return; return;
auto* pm = ins->module_inputs().front(); auto* pm = pw_ins->module_inputs().front();
auto pw_inputs = pw_ins->inputs();
auto ln_pos = std::find(pw_inputs.begin(), pw_inputs.end(), layernorm);
assert(ln_pos != pw_inputs.end());
pw_inputs.erase(ln_pos);
auto inputs = layernorm->inputs(); auto inputs = layernorm->inputs();
inputs.pop_back(); inputs.pop_back();
inputs.insert(inputs.end(), ins->inputs().begin() + 1, ins->inputs().end()); inputs.insert(inputs.end(), pw_inputs.begin(), pw_inputs.end());
m.replace_instruction(ins, layernorm->get_operator(), inputs, {pm}); m.replace_instruction(pw_ins, layernorm->get_operator(), inputs, {pm});
} }
}; };
......
...@@ -55,7 +55,7 @@ bool is_device_ptr(const void* ptr) ...@@ -55,7 +55,7 @@ bool is_device_ptr(const void* ptr)
auto status = hipPointerGetAttributes(&attr, ptr); auto status = hipPointerGetAttributes(&attr, ptr);
if(status != hipSuccess) if(status != hipSuccess)
return false; return false;
return attr.memoryType == hipMemoryTypeDevice; return attr.type == hipMemoryTypeDevice;
} }
std::size_t get_available_gpu_memory() std::size_t get_available_gpu_memory()
......
...@@ -58,6 +58,8 @@ struct hiprtc_src_file ...@@ -58,6 +58,8 @@ struct hiprtc_src_file
} }
}; };
MIGRAPHX_GPU_EXPORT bool hip_has_flags(const std::vector<std::string>& flags);
MIGRAPHX_GPU_EXPORT std::vector<std::vector<char>> compile_hip_src_with_hiprtc( MIGRAPHX_GPU_EXPORT std::vector<std::vector<char>> compile_hip_src_with_hiprtc(
std::vector<hiprtc_src_file> srcs, std::string params, const std::string& arch); std::vector<hiprtc_src_file> srcs, std::string params, const std::string& arch);
......
...@@ -84,8 +84,10 @@ struct miopen_convolution ...@@ -84,8 +84,10 @@ struct miopen_convolution
{ {
check_shapes{inputs, op}.has(4); check_shapes{inputs, op}.has(4);
std::vector<shape> conv_inputs(inputs.begin(), inputs.begin() + 2); std::vector<shape> conv_inputs(inputs.begin(), inputs.begin() + 2);
check_shapes{conv_inputs, *this}.max_ndims(5).packed_layouts( check_shapes{conv_inputs, *this}
{{0, 1, 2}, {0, 1, 2, 3}, {0, 2, 3, 1}, {0, 1, 2, 3, 4}}); .max_ndims(5)
.packed_layouts({{0, 1, 2}, {0, 1, 2, 3}, {0, 2, 3, 1}, {0, 1, 2, 3, 4}})
.same_layout();
return migraphx::compute_shape<Op>(op, conv_inputs); return migraphx::compute_shape<Op>(op, conv_inputs);
} }
......
...@@ -49,7 +49,8 @@ MIGRAPHX_GPU_EXPORT instruction_ref insert_mlir(module& m, ...@@ -49,7 +49,8 @@ MIGRAPHX_GPU_EXPORT instruction_ref insert_mlir(module& m,
MIGRAPHX_GPU_EXPORT tuning_config get_tuning_config_mlir(const context& migraphx_ctx, MIGRAPHX_GPU_EXPORT tuning_config get_tuning_config_mlir(const context& migraphx_ctx,
module m, module m,
const std::vector<shape>& inputs); const std::vector<shape>& inputs,
bool exhaustive);
} // namespace gpu } // namespace gpu
} // namespace MIGRAPHX_INLINE_NS } // namespace MIGRAPHX_INLINE_NS
......
...@@ -57,11 +57,9 @@ struct mlir_compiler : compiler<mlir_compiler> ...@@ -57,11 +57,9 @@ struct mlir_compiler : compiler<mlir_compiler>
const operation&, const operation&,
bool exhaustive) const bool exhaustive) const
{ {
if(not exhaustive)
return nullopt;
auto shapes = to_shapes(ins->inputs()); auto shapes = to_shapes(ins->inputs());
auto* smod = ins->module_inputs().front(); auto* smod = ins->module_inputs().front();
return get_tuning_config_mlir(ctx, *smod, shapes); return get_tuning_config_mlir(ctx, *smod, shapes, exhaustive);
} }
}; };
......
...@@ -81,7 +81,7 @@ struct roialign_compiler : compiler<roialign_compiler> ...@@ -81,7 +81,7 @@ struct roialign_compiler : compiler<roialign_compiler>
// coord_trans_mode // coord_trans_mode
auto ctm = v.at("coordinate_transformation_mode").to<std::string>(); auto ctm = v.at("coordinate_transformation_mode").to<std::string>();
float rois_offset = (ctm == "output_half_pixel") ? -0.5f : 0.0f; float rois_offset = (ctm == "half_pixel") ? -0.5f : 0.0f;
options.params += " -DROIS_OFFSET=" + std::to_string(rois_offset); options.params += " -DROIS_OFFSET=" + std::to_string(rois_offset);
// spatial_scale // spatial_scale
......
/* /*
* The MIT License (MIT) * The MIT License (MIT)
* *
* Copyright (c) 2015-2022 Advanced Micro Devices, Inc. All rights reserved. * Copyright (c) 2015-2023 Advanced Micro Devices, Inc. All rights reserved.
* *
* Permission is hereby granted, free of charge, to any person obtaining a copy * Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal * of this software and associated documentation files (the "Software"), to deal
...@@ -40,6 +40,7 @@ ...@@ -40,6 +40,7 @@
#include <migraphx/op/if_op.hpp> #include <migraphx/op/if_op.hpp>
#include <migraphx/op/reshape.hpp> #include <migraphx/op/reshape.hpp>
#include <migraphx/op/quant_dot.hpp> #include <migraphx/op/quant_dot.hpp>
#include <migraphx/op/reshape_lazy.hpp>
#include <migraphx/gpu/context.hpp> #include <migraphx/gpu/context.hpp>
#include <migraphx/gpu/lowering.hpp> #include <migraphx/gpu/lowering.hpp>
...@@ -89,7 +90,6 @@ struct miopen_apply ...@@ -89,7 +90,6 @@ struct miopen_apply
offload_copy = (mod == mpm->get_root_module()) ? pass->offload_copy : false; offload_copy = (mod == mpm->get_root_module()) ? pass->offload_copy : false;
add_generic_op("contiguous"); add_generic_op("contiguous");
add_extend_op("argmax"); add_extend_op("argmax");
add_extend_op("argmin"); add_extend_op("argmin");
add_extend_op("logsoftmax"); add_extend_op("logsoftmax");
...@@ -115,6 +115,7 @@ struct miopen_apply ...@@ -115,6 +115,7 @@ struct miopen_apply
add_neg_op(); add_neg_op();
add_nms_op(); add_nms_op();
add_select_module_op(); add_select_module_op();
add_reshape_lazy_op();
} }
void copy_params() const void copy_params() const
...@@ -376,6 +377,32 @@ struct miopen_apply ...@@ -376,6 +377,32 @@ struct miopen_apply
return mod->replace_instruction(ins, ins->get_operator(), inputs, ins->module_inputs()); return mod->replace_instruction(ins, ins->get_operator(), inputs, ins->module_inputs());
}); });
} }
/**
* Adds reshape lazy to reshape ops that can be aliased instead of copied.
* `gpu::contiguous` are added before and after the reshape; these contiguous
* instructions can be removed by the eliminate_contiguous pass.
*/
void add_reshape_lazy_op()
{
apply_map.emplace("reshape", [=](instruction_ref ins) {
std::vector<instruction_ref> before_contiguous_args = ins->inputs();
auto before_alloc = insert_allocation(ins, std::prev(ins)->get_shape());
before_contiguous_args.push_back(before_alloc);
auto before_contig =
mod->insert_instruction(ins, make_op("gpu::contiguous"), {before_contiguous_args});
auto new_lazy_reshape = mod->insert_instruction(
ins,
make_op("reshape_lazy", {{"dims", {ins->get_operator().to_value().at("dims")}}}),
before_contig);
std::vector<instruction_ref> after_contiguous_args = {new_lazy_reshape};
auto after_alloc = insert_allocation(new_lazy_reshape, new_lazy_reshape->get_shape());
after_contiguous_args.push_back(after_alloc);
return mod->replace_instruction(ins, make_op("gpu::contiguous"), after_contiguous_args);
});
}
}; };
void lowering::apply(module_pass_manager& mpm) const void lowering::apply(module_pass_manager& mpm) const
......
...@@ -22,7 +22,9 @@ ...@@ -22,7 +22,9 @@
* THE SOFTWARE. * THE SOFTWARE.
*/ */
#include "migraphx/make_op.hpp" #include "migraphx/make_op.hpp"
#include <migraphx/stringutils.hpp>
#include <migraphx/gpu/mlir.hpp> #include <migraphx/gpu/mlir.hpp>
#include <ostream>
#ifdef MIGRAPHX_MLIR #ifdef MIGRAPHX_MLIR
#include <mlir-c/IR.h> #include <mlir-c/IR.h>
...@@ -33,6 +35,7 @@ ...@@ -33,6 +35,7 @@
#include <mlir-c/Dialect/Rock.h> #include <mlir-c/Dialect/Rock.h>
#include <mlir-c/IntegerSet.h> #include <mlir-c/IntegerSet.h>
#include <mlir-c/Pass.h> #include <mlir-c/Pass.h>
#include <mlir-c/Support.h>
#include <mutex> #include <mutex>
#if !defined(MLIR_MIGRAPHX_DIALECT_API_VERSION) || MLIR_MIGRAPHX_DIALECT_API_VERSION != 3 #if !defined(MLIR_MIGRAPHX_DIALECT_API_VERSION) || MLIR_MIGRAPHX_DIALECT_API_VERSION != 3
#warning "Incompatible version of rocMLIR library used, disabling" #warning "Incompatible version of rocMLIR library used, disabling"
...@@ -69,6 +72,7 @@ inline namespace MIGRAPHX_INLINE_NS { ...@@ -69,6 +72,7 @@ inline namespace MIGRAPHX_INLINE_NS {
namespace gpu { namespace gpu {
MIGRAPHX_DECLARE_ENV_VAR(MIGRAPHX_TRACE_MLIR); MIGRAPHX_DECLARE_ENV_VAR(MIGRAPHX_TRACE_MLIR);
MIGRAPHX_DECLARE_ENV_VAR(MIGRAPHX_MLIR_TUNE_EXHAUSTIVE);
MIGRAPHX_DECLARE_ENV_VAR(MIGRAPHX_MLIR_TUNING_DB); MIGRAPHX_DECLARE_ENV_VAR(MIGRAPHX_MLIR_TUNING_DB);
MIGRAPHX_DECLARE_ENV_VAR(MIGRAPHX_MLIR_TUNING_CFG); MIGRAPHX_DECLARE_ENV_VAR(MIGRAPHX_MLIR_TUNING_CFG);
...@@ -93,6 +97,8 @@ struct mlir_handle ...@@ -93,6 +97,8 @@ struct mlir_handle
friend bool operator==(ptr x, ptr y) { return x.get_value() == y.get_value(); } friend bool operator==(ptr x, ptr y) { return x.get_value() == y.get_value(); }
friend bool operator!=(ptr x, ptr y) { return not(x == y); } friend bool operator!=(ptr x, ptr y) { return not(x == y); }
explicit operator bool() const noexcept { return obj != ptr(); }
T obj{}; T obj{};
}; };
...@@ -176,13 +182,85 @@ std::string mlir_print(F f, T x) ...@@ -176,13 +182,85 @@ std::string mlir_print(F f, T x)
return ss.str(); return ss.str();
} }
struct mlir_logger
{
std::stringstream ss;
mlir_context* ctx;
std::optional<MlirDiagnosticHandlerID> id;
mlir_logger() : ctx(nullptr), id(std::nullopt) {}
mlir_logger(mlir_context* context) : ctx(context)
{
id =
mlirContextAttachDiagnosticHandler(ctx->get(), mlir_diagnostic_print_cb, this, nullptr);
}
~mlir_logger()
{
if(id.has_value())
mlirContextDetachDiagnosticHandler(ctx->get(), *id);
}
mlir_logger(const mlir_logger& other) = delete;
mlir_logger& operator=(const mlir_logger& other) = delete;
mlir_logger(mlir_logger&& other) noexcept
: ss(std::move(other.ss)), ctx(other.ctx), id(other.id)
{
other.ctx = nullptr;
other.id = std::nullopt;
}
mlir_logger& operator=(mlir_logger other) noexcept
{
std::swap(ss, other.ss);
std::swap(ctx, other.ctx);
std::swap(id, other.id);
return *this;
}
std::string str() const { return ss.str(); }
void clear() { ss = std::stringstream{}; }
static MlirLogicalResult mlir_diagnostic_print_cb(MlirDiagnostic diag, void* logger);
MlirLogicalResult handle(MlirDiagnostic diag);
};
MlirLogicalResult mlir_logger::mlir_diagnostic_print_cb(MlirDiagnostic diag, void* logger)
{
return reinterpret_cast<mlir_logger*>(logger)->handle(diag);
}
MlirLogicalResult mlir_logger::handle(MlirDiagnostic diag)
{
MlirDiagnosticSeverity sev = mlirDiagnosticGetSeverity(diag);
switch(sev)
{
case MlirDiagnosticSeverity::MlirDiagnosticError: ss << "Error: "; break;
case MlirDiagnosticSeverity::MlirDiagnosticWarning: ss << "Warning: "; break;
case MlirDiagnosticSeverity::MlirDiagnosticNote: ss << "Note: "; break;
case MlirDiagnosticSeverity::MlirDiagnosticRemark: ss << "Remark: "; break;
}
mlir_print(mlirDiagnosticPrint, diag, [&](auto s) { ss << s; });
ss << std::endl;
for(intptr_t i = 0, e = mlirDiagnosticGetNumNotes(diag); i < e; ++i)
{
(void)handle(mlirDiagnosticGetNote(diag, i));
}
return mlirLogicalResultSuccess();
}
struct mlir_program struct mlir_program
{ {
mlir_program() mlir_program()
: ctx(mlirContextCreateWithRegistry(get_dialect_registry().get(), : ctx(mlirContextCreateWithRegistry(get_dialect_registry().get(),
/*threadingEnable=*/false)), /*threadingEnable=*/false)),
location(mlirLocationUnknownGet(ctx.get())), location(mlirLocationUnknownGet(ctx.get())),
mmodule(mlirModuleCreateEmpty(location)) mmodule(mlirModuleCreateEmpty(location)),
logger(&ctx)
{ {
mlirContextSetThreadPool(ctx.get(), get_thread_pool().get()); mlirContextSetThreadPool(ctx.get(), get_thread_pool().get());
mlirContextLoadAllAvailableDialects(ctx.get()); mlirContextLoadAllAvailableDialects(ctx.get());
...@@ -610,21 +688,49 @@ struct mlir_program ...@@ -610,21 +688,49 @@ struct mlir_program
} }
} }
void run_high_level_pipeline() MIGRAPHX_TIDY_CONST void run_high_level_pipeline()
{ {
mlir_pass_manager pm_front{mlirPassManagerCreate(ctx.get())}; mlir_pass_manager pm_front{mlirPassManagerCreate(ctx.get())};
mlirMIGraphXAddHighLevelPipeline(pm_front.get()); mlirMIGraphXAddHighLevelPipeline(pm_front.get());
mlirPassManagerRunOnOp(pm_front.get(), mlirModuleGetOperation(mmodule.get())); logger.clear();
if(mlirLogicalResultIsFailure(
mlirPassManagerRunOnOp(pm_front.get(), mlirModuleGetOperation(mmodule.get()))))
{
std::string error = "Invalid MLIR created: " + logger.str();
if(enabled(MIGRAPHX_TRACE_MLIR{}))
{
std::cout << error << std::endl;
}
MIGRAPHX_THROW(error);
}
} }
void run_backend_pipeline() MIGRAPHX_TIDY_CONST void run_backend_pipeline()
{ {
mlir_pass_manager pm_back{mlirPassManagerCreate(ctx.get())}; mlir_pass_manager pm_back{mlirPassManagerCreate(ctx.get())};
mlirMIGraphXAddBackendPipeline(pm_back.get(), target_arch.c_str()); mlirMIGraphXAddBackendPipeline(pm_back.get(), target_arch.c_str());
mlirPassManagerRunOnOp(pm_back.get(), mlirModuleGetOperation(mmodule.get())); logger.clear();
const size_t trace = value_of(MIGRAPHX_TRACE_MLIR{});
static std::mutex mutex;
auto mod_op = mlirModuleGetOperation(mmodule.get());
if(trace >= 2)
{
const std::lock_guard<std::mutex> lock(mutex);
std::cout << mlir_print(&mlirOperationPrint, mod_op) << std::endl;
} }
code_object_op compile(const value& solution) MIGRAPHX_TIDY_CONST if(mlirLogicalResultIsFailure(mlirPassManagerRunOnOp(pm_back.get(), mod_op)))
{
std::string error = "MLIR backend compilation failed: " + logger.str();
if(enabled(MIGRAPHX_TRACE_MLIR{}))
{
std::cout << error << std::endl;
}
MIGRAPHX_THROW(error);
}
}
code_object_op compile(const value& solution)
{ {
// 1st pipeline to call // 1st pipeline to call
run_high_level_pipeline(); run_high_level_pipeline();
...@@ -678,12 +784,15 @@ struct mlir_program ...@@ -678,12 +784,15 @@ struct mlir_program
MIGRAPHX_THROW("Failed setting tuning key: " + *str); MIGRAPHX_THROW("Failed setting tuning key: " + *str);
} }
tuning_config get_tuning_config() MIGRAPHX_TIDY_CONST tuning_config get_tuning_config(bool exhaustive)
{ {
tuning_config tc; tuning_config tc;
run_high_level_pipeline(); run_high_level_pipeline();
mlir_tuning_space params{ auto tuning_mode =
mlirRockTuningSpaceCreate(mmodule.get(), RocmlirTuningParamSetKindFull)}; exhaustive ? RocmlirTuningParamSetKindFull : RocmlirTuningParamSetKindQuick;
if(enabled(MIGRAPHX_MLIR_TUNE_EXHAUSTIVE{}))
tuning_mode = RocmlirTuningParamSetKindExhaustive;
mlir_tuning_space params{mlirRockTuningSpaceCreate(mmodule.get(), tuning_mode)};
for(auto i : range(mlirRockTuningGetNumParams(params.get()))) for(auto i : range(mlirRockTuningGetNumParams(params.get())))
{ {
mlir_tuning_param param{mlirRockTuningParamCreate()}; mlir_tuning_param param{mlirRockTuningParamCreate()};
...@@ -695,7 +804,8 @@ struct mlir_program ...@@ -695,7 +804,8 @@ struct mlir_program
if(perf_key_bytes > perf_key.size()) if(perf_key_bytes > perf_key.size())
MIGRAPHX_THROW("Tuning perf key was " + std::to_string(perf_key_bytes) + MIGRAPHX_THROW("Tuning perf key was " + std::to_string(perf_key_bytes) +
" bytes and thus too long"); " bytes and thus too long");
tc.solutions.emplace_back(perf_key.begin(), perf_key.begin() + perf_key_bytes); tc.solutions.emplace_back(
std::string(perf_key.begin(), perf_key.begin() + perf_key_bytes));
} }
std::array<char, ROCMLIR_TUNING_KEY_BUFSZ> tuning_key; std::array<char, ROCMLIR_TUNING_KEY_BUFSZ> tuning_key;
size_t tuning_key_bytes = size_t tuning_key_bytes =
...@@ -717,7 +827,8 @@ struct mlir_program ...@@ -717,7 +827,8 @@ struct mlir_program
if(not tuning_cfg_path.empty()) if(not tuning_cfg_path.empty())
{ {
std::vector<std::string> tokens = split_string(prob_config, '\t'); std::vector<std::string> tokens = split_string(prob_config, '\t');
std::string prob = tokens[1]; std::string prob = tokens[2];
if(starts_with(prob, "conv")) if(starts_with(prob, "conv"))
{ {
tuning_cfg_path += ".conv"; tuning_cfg_path += ".conv";
...@@ -727,6 +838,8 @@ struct mlir_program ...@@ -727,6 +838,8 @@ struct mlir_program
tuning_cfg_path += ".gemm"; tuning_cfg_path += ".gemm";
} }
std::ofstream tuning_cfg(tuning_cfg_path, std::ios::app); std::ofstream tuning_cfg(tuning_cfg_path, std::ios::app);
prob =
trim(prob, [](unsigned char c) { return (c == '\0') or (std::isspace(c) != 0); });
tuning_cfg << prob << std::endl; tuning_cfg << prob << std::endl;
} }
} }
...@@ -799,6 +912,7 @@ struct mlir_program ...@@ -799,6 +912,7 @@ struct mlir_program
mlir_context ctx; mlir_context ctx;
MlirLocation location; MlirLocation location;
mlir_module mmodule; mlir_module mmodule;
mlir_logger logger;
problem_params pp; problem_params pp;
std::deque<std::string> strings{}; std::deque<std::string> strings{};
std::string target_arch = ""; std::string target_arch = "";
...@@ -867,15 +981,22 @@ code_object_op compile_mlir(const context& migraphx_ctx, ...@@ -867,15 +981,22 @@ code_object_op compile_mlir(const context& migraphx_ctx,
adjust_param_shapes(m, to_shapes(inputs)); adjust_param_shapes(m, to_shapes(inputs));
const bool trace = enabled(MIGRAPHX_TRACE_MLIR{}); const bool trace = enabled(MIGRAPHX_TRACE_MLIR{});
static std::mutex mutex;
if(trace) if(trace)
{
const std::lock_guard<std::mutex> lock(mutex);
std::cout << m << std::endl; std::cout << m << std::endl;
}
mlir_program mp; mlir_program mp;
mp.set_gpu_properties(migraphx_ctx); mp.set_gpu_properties(migraphx_ctx);
mp.parse(m); mp.parse(m);
auto mod_op = mlirModuleGetOperation(mp.mmodule.get()); auto mod_op = mlirModuleGetOperation(mp.mmodule.get());
if(trace) if(trace)
{
const std::lock_guard<std::mutex> lock(mutex);
std::cout << mlir_print(&mlirOperationPrint, mod_op) << std::endl; std::cout << mlir_print(&mlirOperationPrint, mod_op) << std::endl;
}
auto co = mp.compile(solution); auto co = mp.compile(solution);
co.expected_inputs = to_shapes(inputs); co.expected_inputs = to_shapes(inputs);
co.output = m.get_output_shapes().front(); co.output = m.get_output_shapes().front();
...@@ -898,15 +1019,17 @@ instruction_ref insert_mlir(module& m, ...@@ -898,15 +1019,17 @@ instruction_ref insert_mlir(module& m,
return m.insert_instruction(ins, co, refs); return m.insert_instruction(ins, co, refs);
} }
tuning_config tuning_config get_tuning_config_mlir(const context& migraphx_ctx,
get_tuning_config_mlir(const context& migraphx_ctx, module m, const std::vector<shape>& inputs) module m,
const std::vector<shape>& inputs,
bool exhaustive)
{ {
adjust_param_shapes(m, inputs); adjust_param_shapes(m, inputs);
mlir_program mp; mlir_program mp;
mp.set_gpu_properties(migraphx_ctx); mp.set_gpu_properties(migraphx_ctx);
mp.parse(m); mp.parse(m);
return mp.get_tuning_config(); return mp.get_tuning_config(exhaustive);
} }
#else #else
...@@ -935,7 +1058,7 @@ insert_mlir(module& m, instruction_ref, code_object_op co, const std::vector<ins ...@@ -935,7 +1058,7 @@ insert_mlir(module& m, instruction_ref, code_object_op co, const std::vector<ins
return m.end(); return m.end();
} }
tuning_config get_tuning_config_mlir(const context&, module, const std::vector<shape>&) tuning_config get_tuning_config_mlir(const context&, module, const std::vector<shape>&, bool)
{ {
return {}; return {};
} }
......
/*
* The MIT License (MIT)
*
* Copyright (c) 2015-2022 Advanced Micro Devices, Inc. All rights reserved.
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
* in the Software without restriction, including without limitation the rights
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
* copies of the Software, and to permit persons to whom the Software is
* furnished to do so, subject to the following conditions:
*
* The above copyright notice and this permission notice shall be included in
* all copies or substantial portions of the Software.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
* THE SOFTWARE.
*/
#ifdef __HIP_DEVICE_COMPILE__
#error \
"Device compilation not allowed for migraphx_gpu. Do not link with hip::device. Device code should go into migraphx_device or migraphx_kernels"
#endif
...@@ -21,6 +21,7 @@ ...@@ -21,6 +21,7 @@
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
* THE SOFTWARE. * THE SOFTWARE.
*/ */
#include <migraphx/permutation.hpp>
#include <migraphx/gpu/prefuse_ops.hpp> #include <migraphx/gpu/prefuse_ops.hpp>
#include <migraphx/match/layernorm.hpp> #include <migraphx/match/layernorm.hpp>
#include <migraphx/check_shapes.hpp> #include <migraphx/check_shapes.hpp>
...@@ -45,40 +46,42 @@ struct layernorm_base ...@@ -45,40 +46,42 @@ struct layernorm_base
} }
shape compute_shape(std::vector<shape> inputs, std::vector<module_ref> mods) const shape compute_shape(std::vector<shape> inputs, std::vector<module_ref> mods) const
{ {
std::size_t nargs = 1; std::size_t nargs = N;
if(not mods.empty()) if(not mods.empty())
{ {
auto* pm = mods.front(); auto* pm = mods.front();
nargs = pm->get_parameter_names().size(); nargs += pm->get_parameter_names().size() - 1;
} }
check_shapes{inputs, static_cast<const Derived&>(*this)}.has(nargs + N); check_shapes{inputs, static_cast<const Derived&>(*this)}.has(nargs);
auto s = inputs.at(0); auto s = inputs.front();
auto t = s.type(); auto t = s.type();
if(not mods.empty()) if(not mods.empty())
t = mods.front()->get_output_shapes().front().type(); t = mods.front()->get_output_shapes().front().type();
if(s.scalar())
{ // Scalar output if all inputs are scalar
return s; if(inputs.front().elements() == 1 and
} all_of(inputs, [](const auto& ss) { return ss.scalar(); }))
else if(s.broadcasted()) return inputs.front();
{ auto l_s = shape::from_permutation(
return {t, s.lens()}; t, s.lens(), find_permutation(std::vector<shape>(inputs.begin(), inputs.begin() + N)));
} // just prelayernorm or preadd_layernorm
else if(nargs <= N)
{ return l_s;
return s.with_lens(t, s.lens()); // else, layernorm + pointwise fusion, preserve layout of fused op
} std::vector<shape> lp_s(inputs.begin() + N, inputs.end());
lp_s.insert(lp_s.begin(), l_s);
return shape::from_permutation(t, s.lens(), find_permutation(lp_s));
} }
}; };
struct layernorm : layernorm_base<layernorm, 0> struct layernorm : layernorm_base<layernorm, 1>
{ {
std::string name() const { return "gpu::prelayernorm"; } std::string name() const { return "gpu::prelayernorm"; }
}; };
MIGRAPHX_REGISTER_OP(layernorm); MIGRAPHX_REGISTER_OP(layernorm);
struct add_layernorm : layernorm_base<add_layernorm, 1> struct add_layernorm : layernorm_base<add_layernorm, 2>
{ {
std::string name() const { return "gpu::preadd_layernorm"; } std::string name() const { return "gpu::preadd_layernorm"; }
}; };
......
This diff is collapsed.
/* /*
* The MIT License (MIT) * The MIT License (MIT)
* *
* Copyright (c) 2015-2022 Advanced Micro Devices, Inc. All rights reserved. * Copyright (c) 2015-2023 Advanced Micro Devices, Inc. All rights reserved.
* *
* Permission is hereby granted, free of charge, to any person obtaining a copy * Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal * of this software and associated documentation files (the "Software"), to deal
......
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment