Commit 36bb977b authored by Brian Pickrell's avatar Brian Pickrell
Browse files

Merge branch 'develop' into rand_uniform

parents d626a09e c65ab678
......@@ -460,11 +460,11 @@ instruction_ref module::add_parameter(std::string name, shape s)
instruction_ref module::add_return(std::vector<instruction_ref> args)
{
impl->push_back({builtin::returns{}, {}, std::move(args)});
shape instr_shape = compute_shape(builtin::returns{}, args);
impl->push_back({builtin::returns{}, instr_shape, std::move(args)});
auto result = std::prev(impl->instructions.end());
instruction::backreference(result);
assert(result->valid(begin()));
return result;
}
......@@ -1011,9 +1011,17 @@ std::vector<module_ref> module::get_sub_modules(bool shallow) const
module& module::sort()
{
auto implicit_deps = calc_implicit_deps();
fix([&](auto self, auto ins) {
this->move_instruction(ins, this->begin());
for(auto child : ins->inputs())
auto ins_inputs = ins->inputs();
if(implicit_deps.find(ins) != implicit_deps.end())
{
auto ins_implict_inputs = implicit_deps.at(ins);
ins_inputs.insert(
ins_inputs.end(), ins_implict_inputs.begin(), ins_implict_inputs.end());
}
for(auto child : ins_inputs)
{
if(not contains(this->impl->instructions, child))
{
......
......@@ -74,5 +74,15 @@ std::vector<int64_t> find_permutation(const std::vector<shape>& shapes)
return it->first;
}
std::vector<shape> normalize_permutation(const std::vector<shape>& shapes)
{
auto result = shapes;
auto perm = find_permutation(shapes);
std::transform(result.begin(), result.end(), result.begin(), [&](auto s) {
return reorder_shape(s, perm);
});
return result;
}
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
......@@ -40,13 +40,14 @@
#include <migraphx/make_op.hpp>
#include <migraphx/marker.hpp>
#include <migraphx/supported_segments.hpp>
#include <iostream>
#include <queue>
#include <sstream>
#include <algorithm>
#include <set>
#include <unordered_map>
#include <utility>
#include <unordered_set>
#include <map>
#include <cassert>
......@@ -1191,11 +1192,19 @@ void program::remove_unused_modules()
program& program::sort()
{
for(auto& pp : this->impl->modules)
std::queue<migraphx::module_ref> mqueue;
mqueue.push(get_main_module());
while(not mqueue.empty())
{
pp.second.sort();
module_ref current_mod = mqueue.front();
current_mod->sort();
mqueue.pop();
auto child_mods = current_mod->get_sub_modules(true);
for(auto& sub_mod : child_mods)
{
mqueue.push(sub_mod);
}
}
return *this;
}
......
......@@ -23,14 +23,24 @@
#####################################################################################
option(MIGRAPHX_ENABLE_PYTHON "Enable python bindings" ON)
add_library(migraphx_py py_loader.cpp)
target_include_directories(migraphx_py PRIVATE include)
target_link_libraries(migraphx_py PUBLIC migraphx)
rocm_install_targets(TARGETS migraphx_py INCLUDE include)
if(MIGRAPHX_ENABLE_PYTHON)
include(PythonModules)
add_custom_target(migraphx_py)
foreach(PYTHON_VERSION ${PYTHON_VERSIONS})
py_add_module(migraphx_py_${PYTHON_VERSION} migraphx_py.cpp PYTHON_VERSION ${PYTHON_VERSION} PYTHON_MODULE migraphx)
target_link_libraries(migraphx_py_${PYTHON_VERSION} PRIVATE migraphx migraphx_tf migraphx_onnx migraphx_all_targets)
py_add_module(migraphx_pybind_${PYTHON_VERSION} migraphx_py.cpp PYTHON_VERSION ${PYTHON_VERSION} PYTHON_MODULE migraphx)
target_link_libraries(migraphx_pybind_${PYTHON_VERSION} PRIVATE migraphx migraphx_tf migraphx_onnx migraphx_all_targets)
rocm_install_targets(TARGETS migraphx_pybind_${PYTHON_VERSION})
add_dependencies(migraphx_py migraphx_pybind_${PYTHON_VERSION})
add_library(migraphx_py_${PYTHON_VERSION} py.cpp)
target_include_directories(migraphx_py_${PYTHON_VERSION} PRIVATE include)
target_link_libraries(migraphx_py_${PYTHON_VERSION} PUBLIC migraphx)
target_link_libraries(migraphx_py_${PYTHON_VERSION} PRIVATE pybind11::pybind11 python${PYTHON_VERSION}::runtime)
rocm_install_targets(TARGETS migraphx_py_${PYTHON_VERSION})
add_dependencies(migraphx_py migraphx_py_${PYTHON_VERSION})
endforeach()
......
/*
* The MIT License (MIT)
*
* Copyright (c) 2015-2022 Advanced Micro Devices, Inc. All rights reserved.
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
* in the Software without restriction, including without limitation the rights
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
* copies of the Software, and to permit persons to whom the Software is
* furnished to do so, subject to the following conditions:
*
* The above copyright notice and this permission notice shall be included in
* all copies or substantial portions of the Software.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
* THE SOFTWARE.
*/
#ifndef MIGRAPHX_GUARD_MIGRAPHX_PY_HPP
#define MIGRAPHX_GUARD_MIGRAPHX_PY_HPP
#include <migraphx/config.hpp>
#include <migraphx/program.hpp>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
program load_py(const std::string& filename);
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
#endif // MIGRAPHX_GUARD_MIGRAPHX_PY_HPP
/*
* The MIT License (MIT)
*
* Copyright (c) 2015-2022 Advanced Micro Devices, Inc. All rights reserved.
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
* in the Software without restriction, including without limitation the rights
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
* copies of the Software, and to permit persons to whom the Software is
* furnished to do so, subject to the following conditions:
*
* The above copyright notice and this permission notice shall be included in
* all copies or substantial portions of the Software.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
* THE SOFTWARE.
*/
#include <migraphx/config.hpp>
#include <migraphx/program.hpp>
#include <migraphx/dynamic_loader.hpp>
#include <migraphx/file_buffer.hpp>
#include <pybind11/embed.h>
namespace py = pybind11;
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
#ifdef __clang__
#pragma clang diagnostic push
#pragma clang diagnostic ignored "-Wreturn-type-c-linkage"
#endif
// extern "C" is used to disable name mangling, but the function will still be called from C++
extern "C" program migraphx_load_py(const std::string& filename);
#ifdef __clang__
#pragma clang diagnostic pop
#endif
const std::string& python_path()
{
static const auto path = dynamic_loader::path(&migraphx_load_py).parent_path().string();
return path;
}
static py::dict run_file(const std::string& file)
{
py::object scope = py::module_::import("__main__").attr("__dict__");
std::string buffer;
buffer.append("import sys\n");
buffer.append("sys.path.insert(0, '" + python_path() + "')\n");
buffer.append("import migraphx\n");
buffer.append(read_string(file));
py::exec(buffer, scope);
return scope.cast<py::dict>();
}
extern "C" program migraphx_load_py(const std::string& filename)
{
py::scoped_interpreter guard{};
py::dict vars = run_file(filename);
auto it = std::find_if(vars.begin(), vars.end(), [](const auto& p) {
return py::isinstance<migraphx::program>(p.second);
});
if(it == vars.end())
MIGRAPHX_THROW("No program variable found");
return it->second.cast<migraphx::program>();
}
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
/*
* The MIT License (MIT)
*
* Copyright (c) 2015-2022 Advanced Micro Devices, Inc. All rights reserved.
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
* in the Software without restriction, including without limitation the rights
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
* copies of the Software, and to permit persons to whom the Software is
* furnished to do so, subject to the following conditions:
*
* The above copyright notice and this permission notice shall be included in
* all copies or substantial portions of the Software.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
* THE SOFTWARE.
*/
#include <migraphx/py.hpp>
#include <migraphx/dynamic_loader.hpp>
#include <migraphx/process.hpp>
#include <migraphx/ranges.hpp>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
static std::vector<fs::path> find_available_python_versions()
{
std::vector<fs::path> result;
auto path = dynamic_loader::path(&load_py).parent_path();
for(const auto& entry : fs::directory_iterator{path})
{
auto p = entry.path();
if(not fs::is_regular_file(p))
continue;
if(not contains(p.stem().string(), "migraphx_py_"))
continue;
result.push_back(p);
}
std::sort(result.begin(), result.end(), std::greater<>{});
return result;
}
static dynamic_loader load_py_lib()
{
auto libs = find_available_python_versions();
for(const auto& lib : libs)
{
auto result = dynamic_loader::try_load(lib);
if(result.has_value())
return *result;
}
MIGRAPHX_THROW("Cant find a viable version of python");
}
static dynamic_loader py_lib()
{
static dynamic_loader lib = load_py_lib();
return lib;
}
program load_py(const std::string& filename)
{
static auto f = py_lib().get_function<program(const std::string&)>("migraphx_load_py");
return f(filename);
}
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
......@@ -43,7 +43,11 @@ struct dnnl_gemm : dnnl_extend_op<dnnl_gemm, dnnl::matmul, op::dot>
MIGRAPHX_DNNL_PREFIX(ARG_BIAS)};
}
void required(const check_shapes& cs) const { cs.not_broadcasted(); }
template <class T>
void required(const check_shapes<T>& cs) const
{
cs.not_broadcasted();
}
dnnl::matmul::desc get_desc(const std::unordered_map<int, dnnl::memory::desc>& m) const
{
......
......@@ -400,7 +400,11 @@ struct dnnl_extend_op : dnnl_op<Derived, Primitive>
}
// dnnl has some issues with non-packed inputs
void required(const check_shapes& cs) const { cs.packed_or_broadcasted(); }
template <class T>
void required(const check_shapes<T>& cs) const
{
cs.packed_or_broadcasted();
}
std::string name() const { return "dnnl::" + op.name(); }
shape compute_shape(std::vector<shape> inputs) const
......
......@@ -48,7 +48,7 @@ include(Embed)
file(GLOB KERNEL_FILES CONFIGURE_DEPENDS
${CMAKE_CURRENT_SOURCE_DIR}/kernels/include/migraphx/kernels/*.hpp)
message(STATUS "KERNEL_FILES: ${KERNEL_FILES}")
add_embed_library(migraphx_kernels ${KERNEL_FILES})
add_embed_library(migraphx_kernels ${KERNEL_FILES} RELATIVE ${CMAKE_CURRENT_SOURCE_DIR}/kernels/include/)
file(GLOB DEVICE_GPU_SRCS CONFIGURE_DEPENDS ${CMAKE_CURRENT_SOURCE_DIR}/device/*.cpp)
add_library(migraphx_device ${DEVICE_GPU_SRCS})
......@@ -89,7 +89,7 @@ rocm_clang_tidy_check(kernel_file_check)
file(GLOB JIT_GPU_SRCS CONFIGURE_DEPENDS ${CMAKE_CURRENT_SOURCE_DIR}/jit/*.cpp)
if(NOT WIN32)
if(WIN32)
# TODO: re-enable when CK is ported to Windows
list(REMOVE_ITEM JIT_GPU_SRCS ${CMAKE_CURRENT_SOURCE_DIR}/jit/ck_gemm.cpp)
endif()
......
......@@ -331,7 +331,7 @@ static std::vector<std::string> get_op_names(const module& m)
{
if(starts_with(ins.name(), "@"))
continue;
if(ins.name() == "multibroadcast")
if(contains({"multibroadcast", "contiguous"}, ins.name()))
continue;
if(ins.name() == "pointwise")
{
......
......@@ -167,7 +167,7 @@ operation compile_hip_code_object(const std::string& content, hip_compile_option
[](auto&& p) {
auto&& name = p.first;
auto&& c = p.second;
auto path = fs::path{"migraphx"} / "kernels" / name;
auto path = name;
return src_file{path, c};
});
srcs.push_back(src_file{fs::path{"main.cpp"},
......
......@@ -32,6 +32,7 @@
#include <migraphx/instruction.hpp>
#include <migraphx/optional.hpp>
#include <migraphx/rank.hpp>
#include <migraphx/gpu/tuning_config.hpp>
#include <functional>
namespace migraphx {
......@@ -68,12 +69,6 @@ struct compiler_replace
}
};
struct tuning_config
{
value problem;
std::vector<value> solutions;
};
using compiler_compile =
std::function<compiler_replace(context&, instruction_ref, operation, const value&)>;
using compiler_compile_op =
......
......@@ -29,6 +29,7 @@
#include <migraphx/gpu/config.hpp>
#include <migraphx/gpu/code_object_op.hpp>
#include <migraphx/instruction_ref.hpp>
#include <migraphx/gpu/tuning_config.hpp>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
......@@ -36,16 +37,19 @@ struct module;
namespace gpu {
MIGRAPHX_GPU_EXPORT std::string dump_mlir(const module& m);
MIGRAPHX_GPU_EXPORT code_object_op compile_mlir(const context& ctx,
module m,
const std::vector<instruction_ref>& inputs);
const std::vector<instruction_ref>& inputs,
const value& solution);
MIGRAPHX_GPU_EXPORT instruction_ref insert_mlir(module& m,
instruction_ref ins,
code_object_op co,
const std::vector<instruction_ref>& inputs);
MIGRAPHX_GPU_EXPORT tuning_config get_tuning_config_mlir(module m,
const std::vector<shape>& inputs);
} // namespace gpu
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
......
/*
* The MIT License (MIT)
*
* Copyright (c) 2015-2022 Advanced Micro Devices, Inc. All rights reserved.
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
* in the Software without restriction, including without limitation the rights
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
* copies of the Software, and to permit persons to whom the Software is
* furnished to do so, subject to the following conditions:
*
* The above copyright notice and this permission notice shall be included in
* all copies or substantial portions of the Software.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
* THE SOFTWARE.
*/
#ifndef MIGRAPHX_GUARD_GPU_TUNING_CONFIG_HPP
#define MIGRAPHX_GUARD_GPU_TUNING_CONFIG_HPP
#include <migraphx/config.hpp>
#include <migraphx/value.hpp>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
namespace gpu {
struct tuning_config
{
value problem;
std::vector<value> solutions;
};
} // namespace gpu
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
#endif // MIGRAPHX_GUARD_GPU_TUNING_CONFIG_HPP
......@@ -36,11 +36,12 @@ struct mlir_compiler : compiler<mlir_compiler>
operation compile_op(context&, const std::vector<shape>&, const value&) const { return {}; }
compiler_replace compile(context& ctx, instruction_ref ins, const operation&) const
compiler_replace
compile(context& ctx, instruction_ref ins, const operation&, const value& solution) const
{
auto* smod = ins->module_inputs().front();
assert(smod->get_parameter_names().size() == ins->inputs().size() - 1);
return insert(compile_mlir(ctx, *smod, ins->inputs()));
return insert(compile_mlir(ctx, *smod, ins->inputs(), solution));
}
compiler_replace insert(code_object_op co) const
......@@ -50,6 +51,16 @@ struct mlir_compiler : compiler<mlir_compiler>
m.replace_instruction(ins, mlir);
}};
}
optional<tuning_config>
get_tuning_config(context&, instruction_ref ins, const operation&, bool exhaustive) const
{
if(not exhaustive)
return nullopt;
auto shapes = to_shapes(ins->inputs());
auto* smod = ins->module_inputs().front();
return get_tuning_config_mlir(*smod, shapes);
}
};
} // namespace gpu
......
......@@ -72,7 +72,7 @@ struct pointwise_compiler : compiler<pointwise_compiler>
hip_compile_options options;
options.inputs = inputs;
options.output = inputs.back();
options.virtual_inputs = reduce_dims(inputs);
options.virtual_inputs = reduce_dims(normalize_permutation(inputs));
options.params = "-Wno-float-equal";
auto axis = find_fast_axis(options.virtual_inputs);
auto vec = vectorize::elements(ctx, axis, options.virtual_inputs);
......
......@@ -84,7 +84,7 @@ static shape get_reduced_shape(const shape& s, const std::vector<T>& axes)
std::fill(lens.begin(), lens.end(), 1);
for(const auto& axis : axes)
lens[axis] = s.lens()[axis];
return shape{s.type(), lens};
return s.with_lens(lens);
}
template <class T>
......@@ -93,7 +93,7 @@ static shape get_output_shape(const shape& s, const std::vector<T>& axes)
auto lens = s.lens();
for(const auto& axis : axes)
lens[axis] = 1;
return shape{s.type(), lens};
return s.with_lens(lens);
}
template <class ReduceLens>
......@@ -228,7 +228,7 @@ struct fused_reduce_compiler : compiler<fused_reduce_compiler>
auto virtual_inputs = inputs;
virtual_inputs.push_back(get_reduced_shape(inputs.front(), axes));
virtual_inputs.push_back(get_output_shape(inputs.front(), axes));
virtual_inputs = reduce_dims(virtual_inputs);
virtual_inputs = reduce_dims(normalize_permutation(virtual_inputs));
auto reduce_output_shape = virtual_inputs.back();
virtual_inputs.pop_back();
auto reduction_shape = virtual_inputs.back();
......
......@@ -50,8 +50,10 @@
#include <migraphx/ranges.hpp>
#include <migraphx/gpu/code_object_op.hpp>
#include <migraphx/gpu/context.hpp>
#include <migraphx/gpu/compile_gen.hpp>
#include <migraphx/gpu/device_name.hpp>
#include <migraphx/gpu/perfdb.hpp>
#include <migraphx/gpu/tuning_config.hpp>
#include <migraphx/iterator_for.hpp>
#include <migraphx/permutation.hpp>
#include <deque>
......@@ -134,6 +136,10 @@ using mlir_block = MIGRAPHX_MANAGE_MLIR_HANDLE(MlirBlock, mlirBlockD
using mlir_pass_manager = MIGRAPHX_MANAGE_MLIR_HANDLE(MlirPassManager, mlirPassManagerDestroy);
using mlir_tuning_table = MIGRAPHX_MANAGE_MLIR_HANDLE(MlirRockTuningTable,
mlirRockTuningTableDestroy);
using mlir_tuning_space = MIGRAPHX_MANAGE_MLIR_HANDLE(MlirRockTuningSpace,
mlirRockTuningSpaceDestroy);
using mlir_tuning_param = MIGRAPHX_MANAGE_MLIR_HANDLE(MlirRockTuningParam,
mlirRockTuningParamDestroy);
std::string_view to_string_view(MlirStringRef s) { return {s.data, s.length}; }
......@@ -554,14 +560,7 @@ struct mlir_program
static std::string get_symbol_name(const module& m)
{
for(auto ins : iterator_for(m))
{
if(ins->name() == "convolution" or ins->name() == "dot")
{
return "mlir_" + ins->name();
}
}
return "main";
return "mlir_" + gen::generate_name_from_ops(m);
}
void parse(const module& m)
......@@ -616,18 +615,30 @@ struct mlir_program
}
}
code_object_op compile() MIGRAPHX_TIDY_CONST
void run_high_level_pipeline() MIGRAPHX_TIDY_CONST
{
mlir_pass_manager pm_front{mlirPassManagerCreate(ctx.get())};
mlir_pass_manager pm_back{mlirPassManagerCreate(ctx.get())};
// 1st pipeline to call
mlirMIGraphXAddHighLevelPipeline(pm_front.get());
mlirPassManagerRunOnOp(pm_front.get(), mlirModuleGetOperation(mmodule.get()));
}
// 2nd pipeline to call
get_module_tuned();
void run_backend_pipeline() MIGRAPHX_TIDY_CONST
{
mlir_pass_manager pm_back{mlirPassManagerCreate(ctx.get())};
mlirMIGraphXAddBackendPipeline(pm_back.get(), target_arch.c_str());
mlirPassManagerRunOnOp(pm_back.get(), mlirModuleGetOperation(mmodule.get()));
}
code_object_op compile(const value& solution) MIGRAPHX_TIDY_CONST
{
// 1st pipeline to call
run_high_level_pipeline();
if(solution.is_null())
get_module_tuned();
else
set_tuning(solution);
// 2nd pipeline to call
run_backend_pipeline();
code_object_op op{};
op.symbol_name = sym_name;
......@@ -658,6 +669,33 @@ struct mlir_program
MIGRAPHX_THROW("Failed to compile mlir program");
}
void set_tuning(const value& v)
{
auto str = v.to<std::string>();
// We need to make a copy of the buffer since mlirRockTuningSetFromStr may modify the string
std::vector<char> buffer(str.begin(), str.end());
buffer.push_back(0);
if(not mlirRockTuningSetFromStr(mmodule.get(), buffer.data()))
MIGRAPHX_THROW("Failed setting tuning key: " + str);
}
tuning_config get_tuning_config() MIGRAPHX_TIDY_CONST
{
tuning_config tc;
run_high_level_pipeline();
mlir_tuning_space params{mlirRockTuningSpaceCreate(mmodule.get())};
for(auto i : range(mlirRockTuningGetNumParamsFull(params.get())))
{
mlir_tuning_param param{mlirRockTuningParamCreate()};
if(not mlirRockTuningParamGet(params.get(), i, param.get()))
MIGRAPHX_THROW("Incorrect mlir tuning parameter: " + std::to_string(i));
tc.solutions.push_back(std::string{mlirRockTuningGetParamStr(param.get())});
}
mlir_tuning_table tuning_table{mlirRockTuningTableCreate()};
tc.problem = std::string{mlirRockTuningGetKey(tuning_table.get(), mmodule.get())};
return tc;
}
std::string get_tune_params(bool xdlops) const { return get_mlir_perf_for_conv(pp, xdlops); }
// This function appends to tuning cfg file that could be
......@@ -749,14 +787,14 @@ std::string dump_mlir(const module& m)
return mlir_print(&mlirOperationPrint, mod_op);
}
void adjust_param_shapes(module& m, const std::vector<instruction_ref>& inputs)
void adjust_param_shapes(module& m, const std::vector<shape>& inputs)
{
auto names = m.get_parameter_names();
std::sort(names.begin(), names.end());
for(auto i : range(names.size()))
{
const auto& name = names[i];
const auto& input = inputs[i]->get_shape();
const auto& input = inputs[i];
auto param = m.get_parameter(name);
if(input.standard())
continue;
......@@ -794,9 +832,12 @@ void adjust_param_shapes(module& m, const std::vector<instruction_ref>& inputs)
}
}
code_object_op compile_mlir(const context&, module m, const std::vector<instruction_ref>& inputs)
code_object_op compile_mlir(const context&,
module m,
const std::vector<instruction_ref>& inputs,
const value& solution)
{
adjust_param_shapes(m, inputs);
adjust_param_shapes(m, to_shapes(inputs));
const bool trace = enabled(MIGRAPHX_TRACE_MLIR{});
if(trace)
......@@ -808,8 +849,9 @@ code_object_op compile_mlir(const context&, module m, const std::vector<instruct
auto mod_op = mlirModuleGetOperation(mp.mmodule.get());
if(trace)
std::cout << mlir_print(&mlirOperationPrint, mod_op) << std::endl;
auto co = mp.compile();
co.output = m.get_output_shapes().front();
auto co = mp.compile(solution);
co.expected_inputs = to_shapes(inputs);
co.output = m.get_output_shapes().front();
return co;
}
......@@ -829,6 +871,16 @@ instruction_ref insert_mlir(module& m,
return m.insert_instruction(ins, co, refs);
}
tuning_config get_tuning_config_mlir(module m, const std::vector<shape>& inputs)
{
adjust_param_shapes(m, inputs);
mlir_program mp;
mp.find_target();
mp.parse(m);
return mp.get_tuning_config();
}
#else
std::string dump_mlir(const module&) { return {}; }
......@@ -840,11 +892,11 @@ void use(T&)
// Disabling clang-tidy warning on non-real useage.
// NOLINTBEGIN(performance-unnecessary-value-param)
code_object_op compile_mlir(const context&, module, const std::vector<instruction_ref>&)
code_object_op
compile_mlir(const context&, module, const std::vector<instruction_ref>&, const value&)
{
return {};
}
// NOLINTEND(performance-unnecessary-value-param)
instruction_ref
// cppcheck-suppress funcArgNamesDifferent
......@@ -854,6 +906,9 @@ insert_mlir(module& m, instruction_ref, code_object_op co, const std::vector<ins
return m.end();
}
tuning_config get_tuning_config_mlir(module, const std::vector<shape>&) { return {}; }
// NOLINTEND(performance-unnecessary-value-param)
#endif
} // namespace gpu
......
......@@ -75,7 +75,7 @@ namespace gpu {
MIGRAPHX_DECLARE_ENV_VAR(MIGRAPHX_DISABLE_SCHEDULE_PASS)
MIGRAPHX_DECLARE_ENV_VAR(MIGRAPHX_DISABLE_REDUCE_FUSION)
MIGRAPHX_DECLARE_ENV_VAR(MIGRAPHX_ENABLE_NHWC)
#ifdef _WIN32
#ifndef _WIN32
MIGRAPHX_DECLARE_ENV_VAR(MIGRAPHX_ENABLE_CK)
#endif
......@@ -138,7 +138,7 @@ std::vector<pass> target::get_passes(migraphx::context& gctx, const compile_opti
dead_code_elimination{},
enable_pass(not enabled(MIGRAPHX_DISABLE_REDUCE_FUSION{}), fuse_reduce{}),
dead_code_elimination{},
#ifdef _WIN32
#ifndef _WIN32
enable_pass(enabled(MIGRAPHX_ENABLE_CK{}), fuse_ck{}),
#endif
dead_code_elimination{},
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment