Commit 30c8ff61 authored by Khalique Ahmed's avatar Khalique Ahmed
Browse files

integrating auto_cont changes

parents 4ff8a292 7aee6388
...@@ -26,7 +26,7 @@ def rocmtestnode(Map conf) { ...@@ -26,7 +26,7 @@ def rocmtestnode(Map conf) {
rm -rf build rm -rf build
mkdir build mkdir build
cd build cd build
cmake -DCMAKE_C_COMPILER_LAUNCHER=ccache -DCMAKE_CXX_COMPILER_LAUNCHER=ccache -DBUILD_DEV=On ${flags} .. cmake -DCMAKE_C_COMPILER_LAUNCHER=ccache -DCMAKE_CXX_COMPILER_LAUNCHER=ccache -DBUILD_DEV=On -DCMAKE_EXECUTE_PROCESS_COMMAND_ECHO=STDOUT ${flags} ..
git diff git diff
git diff-index --quiet HEAD || (echo "Git repo is not clean after running cmake." && exit 1) git diff-index --quiet HEAD || (echo "Git repo is not clean after running cmake." && exit 1)
make -j\$(nproc) generate VERBOSE=1 make -j\$(nproc) generate VERBOSE=1
...@@ -90,7 +90,7 @@ def rocmnodename(name) { ...@@ -90,7 +90,7 @@ def rocmnodename(name) {
} else if(name == "mi100+") { } else if(name == "mi100+") {
node_name = "${rocmtest_name} && (gfx908 || gfx90a) && !vm"; node_name = "${rocmtest_name} && (gfx908 || gfx90a) && !vm";
} else if(name == "cdna") { } else if(name == "cdna") {
node_name = "${rocmtest_name} && (gfx908 || gfx90a || vega) && !vm"; node_name = "${rocmtest_name} && (gfx908 || gfx90a || vega20) && !vm";
} else if(name == "nogpu") { } else if(name == "nogpu") {
node_name = "${rocmtest_name} && nogpu"; node_name = "${rocmtest_name} && nogpu";
} }
......
...@@ -24,11 +24,7 @@ ...@@ -24,11 +24,7 @@
find_program(EMBED_LD ld) find_program(EMBED_LD ld)
find_program(EMBED_OBJCOPY objcopy) find_program(EMBED_OBJCOPY objcopy)
if(LINUX) option(EMBED_USE_LD "Use ld to embed data files" OFF)
option(EMBED_USE_LD "Use ld to embed data files" ON)
else()
option(EMBED_USE_LD "Use ld to embed data files" OFF)
endif()
function(wrap_string) function(wrap_string)
set(options) set(options)
...@@ -60,8 +56,8 @@ endfunction() ...@@ -60,8 +56,8 @@ endfunction()
function(generate_embed_source EMBED_NAME) function(generate_embed_source EMBED_NAME)
set(options) set(options)
set(oneValueArgs SRC HEADER) set(oneValueArgs SRC HEADER RELATIVE)
set(multiValueArgs OBJECTS SYMBOLS) set(multiValueArgs OBJECTS SYMBOLS FILES)
cmake_parse_arguments(PARSE "${options}" "${oneValueArgs}" "${multiValueArgs}" ${ARGN}) cmake_parse_arguments(PARSE "${options}" "${oneValueArgs}" "${multiValueArgs}" ${ARGN})
...@@ -78,6 +74,8 @@ function(generate_embed_source EMBED_NAME) ...@@ -78,6 +74,8 @@ function(generate_embed_source EMBED_NAME)
foreach(idx RANGE ${LEN}) foreach(idx RANGE ${LEN})
list(GET PARSE_SYMBOLS ${idx} SYMBOL) list(GET PARSE_SYMBOLS ${idx} SYMBOL)
list(GET PARSE_OBJECTS ${idx} OBJECT) list(GET PARSE_OBJECTS ${idx} OBJECT)
list(GET PARSE_FILES ${idx} FILE)
set(START_SYMBOL "_binary_${SYMBOL}_start") set(START_SYMBOL "_binary_${SYMBOL}_start")
set(END_SYMBOL "_binary_${SYMBOL}_end") set(END_SYMBOL "_binary_${SYMBOL}_end")
if(EMBED_USE_LD) if(EMBED_USE_LD)
...@@ -92,9 +90,11 @@ function(generate_embed_source EMBED_NAME) ...@@ -92,9 +90,11 @@ function(generate_embed_source EMBED_NAME)
") ")
endif() endif()
# TODO: Should use NAME_WLE if(PARSE_RELATIVE)
get_filename_component(BASE_NAME "${OBJECT}" NAME) file(RELATIVE_PATH BASE_NAME ${PARSE_RELATIVE} "${FILE}")
string(REGEX REPLACE ".[A-Za-z0-9_]+$" "" BASE_NAME ${BASE_NAME}) else()
get_filename_component(BASE_NAME "${FILE}" NAME)
endif()
string(APPEND INIT_KERNELS " string(APPEND INIT_KERNELS "
{ \"${BASE_NAME}\", { ${START_SYMBOL}, ${END_SYMBOL}} }, { \"${BASE_NAME}\", { ${START_SYMBOL}, ${END_SYMBOL}} },
...@@ -162,6 +162,11 @@ function(embed_file OUTPUT_FILE OUTPUT_SYMBOL FILE) ...@@ -162,6 +162,11 @@ function(embed_file OUTPUT_FILE OUTPUT_SYMBOL FILE)
endfunction() endfunction()
function(add_embed_library EMBED_NAME) function(add_embed_library EMBED_NAME)
set(options)
set(oneValueArgs RELATIVE)
set(multiValueArgs)
cmake_parse_arguments(PARSE "${options}" "${oneValueArgs}" "${multiValueArgs}" ${ARGN})
file(MAKE_DIRECTORY ${CMAKE_CURRENT_BINARY_DIR}/embed) file(MAKE_DIRECTORY ${CMAKE_CURRENT_BINARY_DIR}/embed)
file(MAKE_DIRECTORY ${CMAKE_CURRENT_BINARY_DIR}/embed/${EMBED_NAME}) file(MAKE_DIRECTORY ${CMAKE_CURRENT_BINARY_DIR}/embed/${EMBED_NAME})
set(EMBED_DIR ${CMAKE_CURRENT_BINARY_DIR}/embed/${EMBED_NAME}) set(EMBED_DIR ${CMAKE_CURRENT_BINARY_DIR}/embed/${EMBED_NAME})
...@@ -171,15 +176,26 @@ function(add_embed_library EMBED_NAME) ...@@ -171,15 +176,26 @@ function(add_embed_library EMBED_NAME)
set(OUTPUT_FILES) set(OUTPUT_FILES)
set(SYMBOLS) set(SYMBOLS)
message(STATUS "Embedding files") message(STATUS "Embedding files")
foreach(FILE ${ARGN}) foreach(FILE ${PARSE_UNPARSED_ARGUMENTS})
embed_file(OUTPUT_FILE OUTPUT_SYMBOL ${FILE}) embed_file(OUTPUT_FILE OUTPUT_SYMBOL ${FILE})
list(APPEND OUTPUT_FILES ${OUTPUT_FILE}) list(APPEND OUTPUT_FILES ${OUTPUT_FILE})
list(APPEND SYMBOLS ${OUTPUT_SYMBOL}) list(APPEND SYMBOLS ${OUTPUT_SYMBOL})
endforeach() endforeach()
message(STATUS "Generating embedding library ${EMBED_NAME}") message(STATUS "Generating embedding library ${EMBED_NAME}")
generate_embed_source(${EMBED_NAME} SRC ${SRC_FILE} HEADER ${HEADER_FILE} OBJECTS ${OUTPUT_FILES} SYMBOLS ${SYMBOLS}) generate_embed_source(${EMBED_NAME} SRC ${SRC_FILE} HEADER ${HEADER_FILE} OBJECTS ${OUTPUT_FILES} SYMBOLS ${SYMBOLS} RELATIVE ${PARSE_RELATIVE} FILES ${PARSE_UNPARSED_ARGUMENTS})
add_library(${EMBED_NAME} STATIC ${OUTPUT_FILES} "${SRC_FILE}")
target_include_directories(${EMBED_NAME} PUBLIC "${EMBED_DIR}/include") set(INTERNAL_EMBED_LIB embed_lib_${EMBED_NAME})
target_compile_options(${EMBED_NAME} PRIVATE -Wno-reserved-identifier -Wno-extern-initializer -Wno-missing-variable-declarations) add_library(${INTERNAL_EMBED_LIB} OBJECT "${SRC_FILE}")
set_target_properties(${EMBED_NAME} PROPERTIES POSITION_INDEPENDENT_CODE On) target_include_directories(${INTERNAL_EMBED_LIB} PRIVATE "${EMBED_DIR}/include")
target_compile_options(${INTERNAL_EMBED_LIB} PRIVATE -Wno-reserved-identifier -Wno-extern-initializer -Wno-missing-variable-declarations)
set_target_properties(${INTERNAL_EMBED_LIB} PROPERTIES POSITION_INDEPENDENT_CODE On)
add_library(${EMBED_NAME} INTERFACE)
if(EMBED_USE_LD)
target_sources(${EMBED_NAME} INTERFACE ${OUTPUT_FILES})
else()
target_sources(${INTERNAL_EMBED_LIB} PRIVATE ${OUTPUT_FILES})
endif()
target_sources(${EMBED_NAME} INTERFACE $<TARGET_OBJECTS:${INTERNAL_EMBED_LIB}>)
target_include_directories(${EMBED_NAME} INTERFACE "${EMBED_DIR}/include")
endfunction() endfunction()
...@@ -38,12 +38,22 @@ macro(find_python version) ...@@ -38,12 +38,22 @@ macro(find_python version)
find_program(PYTHON_CONFIG_${version} python${version}-config) find_program(PYTHON_CONFIG_${version} python${version}-config)
if(EXISTS ${PYTHON_CONFIG_${version}}) if(EXISTS ${PYTHON_CONFIG_${version}})
py_exec(COMMAND ${PYTHON_CONFIG_${version}} --includes OUTPUT_VARIABLE _python_include_args) py_exec(COMMAND ${PYTHON_CONFIG_${version}} --includes OUTPUT_VARIABLE _python_include_args)
execute_process(COMMAND ${PYTHON_CONFIG_${version}} --ldflags --embed OUTPUT_VARIABLE _python_ldflags_args RESULT_VARIABLE _python_ldflags_result)
if(NOT _python_ldflags_result EQUAL 0)
py_exec(COMMAND ${PYTHON_CONFIG_${version}} --ldflags OUTPUT_VARIABLE _python_ldflags_args)
endif()
separate_arguments(_python_includes UNIX_COMMAND "${_python_include_args}") separate_arguments(_python_includes UNIX_COMMAND "${_python_include_args}")
separate_arguments(_python_ldflags UNIX_COMMAND "${_python_ldflags_args}")
string(REPLACE "-I" "" _python_includes "${_python_includes}") string(REPLACE "-I" "" _python_includes "${_python_includes}")
add_library(python${version}::headers INTERFACE IMPORTED GLOBAL) add_library(python${version}::headers INTERFACE IMPORTED GLOBAL)
set_target_properties(python${version}::headers PROPERTIES set_target_properties(python${version}::headers PROPERTIES
INTERFACE_INCLUDE_DIRECTORIES "${_python_includes}" INTERFACE_INCLUDE_DIRECTORIES "${_python_includes}"
) )
add_library(python${version}::runtime INTERFACE IMPORTED GLOBAL)
set_target_properties(python${version}::runtime PROPERTIES
INTERFACE_LINK_OPTIONS "${_python_ldflags}"
INTERFACE_LINK_LIBRARIES python${version}::headers
)
py_exec(COMMAND ${PYTHON_CONFIG_${version}} --prefix OUTPUT_VARIABLE _python_prefix) py_exec(COMMAND ${PYTHON_CONFIG_${version}} --prefix OUTPUT_VARIABLE _python_prefix)
string(STRIP "${_python_prefix}" _python_prefix) string(STRIP "${_python_prefix}" _python_prefix)
set(PYTHON_${version}_EXECUTABLE "${_python_prefix}/bin/python${version}" CACHE PATH "") set(PYTHON_${version}_EXECUTABLE "${_python_prefix}/bin/python${version}" CACHE PATH "")
......
...@@ -82,6 +82,10 @@ Print out program in text format. ...@@ -82,6 +82,10 @@ Print out program in text format.
Print out program in binary format. Print out program in binary format.
.. option:: --py
Print out program using python API.
.. option:: --output, -o [std::string] .. option:: --output, -o [std::string]
Output to file. Output to file.
......
...@@ -26,6 +26,7 @@ add_library(migraphx_c ...@@ -26,6 +26,7 @@ add_library(migraphx_c
api.cpp api.cpp
) )
set_target_properties(migraphx_c PROPERTIES EXPORT_NAME c) set_target_properties(migraphx_c PROPERTIES EXPORT_NAME c)
migraphx_generate_export_header(migraphx_c DIRECTORY migraphx/api)
# migraphx_c is stable API interface library. SO version of this should be # migraphx_c is stable API interface library. SO version of this should be
# bumped when binary compatibility is broken. # bumped when binary compatibility is broken.
......
...@@ -44,7 +44,7 @@ namespace migraphx { ...@@ -44,7 +44,7 @@ namespace migraphx {
static thread_local bool disable_exception_catch = false; // NOLINT static thread_local bool disable_exception_catch = false; // NOLINT
extern "C" void migraphx_test_private_disable_exception_catch(bool b) extern "C" MIGRAPHX_C_EXPORT void migraphx_test_private_disable_exception_catch(bool b)
{ {
disable_exception_catch = b; disable_exception_catch = b;
} }
......
This diff is collapsed.
...@@ -25,6 +25,7 @@ ...@@ -25,6 +25,7 @@
#include <migraphx/program.hpp> #include <migraphx/program.hpp>
#include <migraphx/instruction.hpp> #include <migraphx/instruction.hpp>
#include <migraphx/make_op.hpp> #include <migraphx/make_op.hpp>
#include <migraphx/ranges.hpp>
#include <migraphx/iterator_for.hpp> #include <migraphx/iterator_for.hpp>
...@@ -59,18 +60,27 @@ void auto_contiguous::apply(module& m) const ...@@ -59,18 +60,27 @@ void auto_contiguous::apply(module& m) const
auto last = std::prev(m.end()); auto last = std::prev(m.end());
for(auto ins : iterator_for(m)) for(auto ins : iterator_for(m))
{ {
if(ins->name() == "layout") if(contains({"layout", "contiguous", "@return", "@param", "@outline"}, ins->name()))
continue; continue;
// for last instruction that is NOT a return // for last instruction that is NOT a return
if(ins->outputs().empty() and ins != last) if(ins->outputs().empty() and ins != last)
continue; continue;
shape s = ins->get_shape(); shape s = ins->get_shape();
if(not s.dynamic() and not s.standard() and s.elements() != 0) if(s.dynamic())
continue;
if(s.type() == shape::tuple_type)
continue;
if(s.standard() and ins->name() == "@literal")
continue;
if(s.scalar() and not contains(ins->name(), "broadcast"))
{ {
continue;
}
if(ins->name() == "pointwise")
std::cout << "HERE" << std::endl;
auto c = m.insert_instruction(std::next(ins), make_op("contiguous"), ins); auto c = m.insert_instruction(std::next(ins), make_op("contiguous"), ins);
m.replace_instruction(ins, c); m.replace_instruction(ins, c);
} }
}
} }
} // namespace MIGRAPHX_INLINE_NS } // namespace MIGRAPHX_INLINE_NS
......
...@@ -32,18 +32,20 @@ add_executable(driver ...@@ -32,18 +32,20 @@ add_executable(driver
marker_roctx.cpp marker_roctx.cpp
) )
set_target_properties(driver PROPERTIES OUTPUT_NAME migraphx-driver) set_target_properties(driver PROPERTIES OUTPUT_NAME migraphx-driver)
# Copy driver for backwards compatibility if(NOT WIN32)
add_custom_command( # Copy driver for backwards compatibility (Linux only)
add_custom_command(
TARGET driver TARGET driver
POST_BUILD COMMAND ${CMAKE_COMMAND} -E copy POST_BUILD COMMAND ${CMAKE_COMMAND} -E copy
$<TARGET_FILE:driver> $<TARGET_FILE:driver>
${CMAKE_RUNTIME_OUTPUT_DIRECTORY}/driver ${CMAKE_RUNTIME_OUTPUT_DIRECTORY}/driver
BYPRODUCTS ${CMAKE_RUNTIME_OUTPUT_DIRECTORY}/driver BYPRODUCTS ${CMAKE_RUNTIME_OUTPUT_DIRECTORY}/driver
) )
set_directory_properties(PROPERTIES ADDITIONAL_MAKE_CLEAN_FILES ${CMAKE_RUNTIME_OUTPUT_DIRECTORY}/driver) set_directory_properties(PROPERTIES ADDITIONAL_CLEAN_FILES ${CMAKE_RUNTIME_OUTPUT_DIRECTORY}/driver)
endif()
rocm_clang_tidy_check(driver) rocm_clang_tidy_check(driver)
target_link_libraries(driver migraphx_all_targets migraphx_onnx migraphx_tf) target_link_libraries(driver migraphx_all_targets migraphx_onnx migraphx_tf migraphx_py)
rocm_install_targets( rocm_install_targets(
TARGETS driver TARGETS driver
......
...@@ -32,6 +32,7 @@ ...@@ -32,6 +32,7 @@
#include <migraphx/tf.hpp> #include <migraphx/tf.hpp>
#include <migraphx/onnx.hpp> #include <migraphx/onnx.hpp>
#include <migraphx/py.hpp>
#include <migraphx/stringutils.hpp> #include <migraphx/stringutils.hpp>
#include <migraphx/convert_to_json.hpp> #include <migraphx/convert_to_json.hpp>
#include <migraphx/load_save.hpp> #include <migraphx/load_save.hpp>
...@@ -241,6 +242,20 @@ struct loader ...@@ -241,6 +242,20 @@ struct loader
return options; return options;
} }
static std::string get_file_type(const std::string& file)
{
if(ends_with(file, ".onnx"))
return "onnx";
else if(ends_with(file, ".pb"))
return "tf";
else if(ends_with(file, ".json"))
return "json";
else if(ends_with(file, ".py"))
return "py";
else
return "migraphx";
}
program load() program load()
{ {
program p; program p;
...@@ -248,14 +263,7 @@ struct loader ...@@ -248,14 +263,7 @@ struct loader
{ {
if(file_type.empty()) if(file_type.empty())
{ {
if(ends_with(file, ".onnx")) file_type = get_file_type(file);
file_type = "onnx";
else if(ends_with(file, ".pb"))
file_type = "tf";
else if(ends_with(file, ".json"))
file_type = "json";
else
file_type = "migraphx";
} }
std::cout << "Reading: " << file << std::endl; std::cout << "Reading: " << file << std::endl;
if(file_type == "onnx") if(file_type == "onnx")
...@@ -272,6 +280,10 @@ struct loader ...@@ -272,6 +280,10 @@ struct loader
options.format = "json"; options.format = "json";
p = migraphx::load(file, options); p = migraphx::load(file, options);
} }
else if(file_type == "py")
{
p = migraphx::load_py(file);
}
else if(file_type == "migraphx") else if(file_type == "migraphx")
{ {
p = migraphx::load(file); p = migraphx::load(file);
......
...@@ -48,7 +48,7 @@ struct dynamic_loader_impl ...@@ -48,7 +48,7 @@ struct dynamic_loader_impl
#pragma GCC diagnostic ignored "-Wignored-attributes" #pragma GCC diagnostic ignored "-Wignored-attributes"
#endif #endif
dynamic_loader_impl(const fs::path& p, std::shared_ptr<tmp_dir> t = nullptr) dynamic_loader_impl(const fs::path& p, std::shared_ptr<tmp_dir> t = nullptr)
: handle(dlopen(p.string().c_str(), RTLD_LAZY), : handle(dlopen(p.string().c_str(), RTLD_GLOBAL | RTLD_NOW),
manage_deleter<decltype(&dlclose), &dlclose>{}), manage_deleter<decltype(&dlclose), &dlclose>{}),
temp(std::move(t)) temp(std::move(t))
{ {
...@@ -81,6 +81,18 @@ fs::path dynamic_loader::path(void* address) ...@@ -81,6 +81,18 @@ fs::path dynamic_loader::path(void* address)
return p; return p;
} }
optional<dynamic_loader> dynamic_loader::try_load(const fs::path& p)
{
try
{
return dynamic_loader{p};
}
catch(const std::exception&)
{
return nullopt;
}
}
dynamic_loader::dynamic_loader(const fs::path& p) : impl(std::make_shared<dynamic_loader_impl>(p)) dynamic_loader::dynamic_loader(const fs::path& p) : impl(std::make_shared<dynamic_loader_impl>(p))
{ {
} }
......
...@@ -30,6 +30,7 @@ ...@@ -30,6 +30,7 @@
#include <migraphx/op/contiguous.hpp> #include <migraphx/op/contiguous.hpp>
#include <migraphx/op/identity.hpp> #include <migraphx/op/identity.hpp>
#include <migraphx/par_for.hpp> #include <migraphx/par_for.hpp>
#include <type_traits>
#include <utility> #include <utility>
namespace migraphx { namespace migraphx {
...@@ -161,6 +162,18 @@ static void remove_contiguous(const std::string& op_name, module& m, F f) ...@@ -161,6 +162,18 @@ static void remove_contiguous(const std::string& op_name, module& m, F f)
} }
} }
static void remove_contiguous_noops(const std::string& op_name, module& m)
{
for(auto ins : iterator_for(m))
{
if(ins->name() != op_name)
continue;
if(ins->inputs().front()->get_shape() != ins->get_shape())
continue;
m.replace_instruction(ins, ins->inputs().front());
}
}
void eliminate_contiguous::apply(module& m) const void eliminate_contiguous::apply(module& m) const
{ {
// Skip contiguous from splits first // Skip contiguous from splits first
...@@ -170,6 +183,7 @@ void eliminate_contiguous::apply(module& m) const ...@@ -170,6 +183,7 @@ void eliminate_contiguous::apply(module& m) const
return (ins->inputs().front()->outputs().size() == 1); return (ins->inputs().front()->outputs().size() == 1);
}); });
remove_contiguous(op_name, m, [](auto) { return true; }); remove_contiguous(op_name, m, [](auto) { return true; });
remove_contiguous_noops(op_name, m);
} }
} // namespace MIGRAPHX_INLINE_NS } // namespace MIGRAPHX_INLINE_NS
......
...@@ -108,91 +108,42 @@ void remove_layout(module& m) ...@@ -108,91 +108,42 @@ void remove_layout(module& m)
// return convs; // return convs;
// } // }
// void remove_layout(module& m, const std::vector<instruction_ref>& convs) void remove_layout(module& m, const std::unordered_set<instruction_ref>& output_layouts)
// { {
// if(convs.size() < 2) return; for(auto ins : iterator_for(m))
// m.debug_print(); {
if(ins->name() != "gpu::precompile_op")
// for(auto i = 0; i < convs.size() - 1; i++) continue;
// {
// bool reached_start = false;
// for(auto ins : iterator_for(m))
// {
// if(ins == convs[i])
// reached_start = true;
// if(reached_start)
// {
// if(ins->name() == "gpu::pooling")
// break;
// if(ins == convs[i + 1])
// {
// m.debug_print(convs[i]->outputs().front());
// m.debug_print(convs[i]->outputs().front()->outputs().front());
// m.replace_instruction(convs[i]->outputs().front(),
// convs[i]->outputs().front()->outputs().front()); std::cout << "HERE" <<
// std::endl; m.debug_print(convs[i]->outputs().front());
// // m.debug_print(convs[i]->outputs().front());
// // m.debug_print(convs[i]->outputs().front()->outputs().front());
// std::cout << std::endl;
// m.debug_print(convs[i]->inputs());
// std::cout << std::endl;
// m.debug_print(convs[i + 1]->inputs());
// for(auto j = 0; j < convs[i + 1]->inputs().size(); j++)
// {
// if(convs[i]->inputs()[j] == convs[i + 1]->inputs()[j])
// {
// std::cout << "HERE2" << std::endl;
// continue;
// }
// m.replace_instruction(convs[i + 1]->inputs()[j], convs[i +
// 1]->inputs()[j]->inputs().front()); m.debug_print(convs[i+1]);
// }
// break;
// }
// }
// }
// }
// }
// void remove_layout(module& m, const std::unordered_set<instruction_ref>& output_layouts)
// {
// for(auto ins : iterator_for(m))
// {
// if(ins->name() != "gpu::precompile_op")
// continue;
// auto precompile_op = ins->get_operator(); auto precompile_op = ins->get_operator();
// auto val = precompile_op.to_value(); auto val = precompile_op.to_value();
// if(val["op"].at("name").to<std::string>() != "layout") if(val["op"].at("name").to<std::string>() != "layout")
// { {
// // std::cout << val["op"].at("name").to<std::string>() << std::endl; // std::cout << val["op"].at("name").to<std::string>() << std::endl;
// continue; continue;
// } }
// m.debug_print(ins); // m.debug_print(ins);
// if(ins->get_shape() != ins->inputs().front()->get_shape()) if(ins->get_shape() != ins->inputs().front()->get_shape())
// { {
// std::cout << ins->get_shape() << " " << ins->inputs().front()->get_shape() << // std::cout << ins->get_shape() << " " << ins->inputs().front()->get_shape() <<
// std::endl; continue; // std::endl;
// } continue;
// if(contains(output_layouts, ins)) }
// continue; if(contains(output_layouts, ins))
continue;
// m.replace_instruction(ins, ins->inputs().front()); m.replace_instruction(ins, ins->inputs().front());
// } }
// } }
void eliminate_layout::apply(module_pass_manager& mpm) const void eliminate_layout::apply(module_pass_manager& mpm) const
{ {
// std::unordered_set<instruction_ref> output_layouts = std::unordered_set<instruction_ref> output_layouts =
// preserve_output_layout(mpm.get_module()); remove_layout(mpm.get_module(), preserve_output_layout(mpm.get_module());
remove_layout(mpm.get_module(), output_layouts);
// find_convs(mpm.get_module())); // find_convs(mpm.get_module()));
remove_layout(mpm.get_module()); // remove_layout(mpm.get_module());
mpm.run_pass(dead_code_elimination{}); mpm.run_pass(dead_code_elimination{});
} }
......
...@@ -90,7 +90,17 @@ struct param ...@@ -90,7 +90,17 @@ struct param
struct returns struct returns
{ {
std::string name() const { return "@return"; } std::string name() const { return "@return"; }
shape compute_shape(const std::vector<shape>&) const { return {}; }
shape compute_shape(const std::vector<shape>& arg) const
{
if(arg.empty())
return {};
else if(arg.size() == 1)
return arg[0];
else
return arg;
}
argument compute(context&, const shape&, const std::vector<argument>&) const argument compute(context&, const shape&, const std::vector<argument>&) const
{ {
MIGRAPHX_THROW("builtin"); MIGRAPHX_THROW("builtin");
......
...@@ -26,6 +26,7 @@ ...@@ -26,6 +26,7 @@
#include <migraphx/config.hpp> #include <migraphx/config.hpp>
#include <migraphx/filesystem.hpp> #include <migraphx/filesystem.hpp>
#include <migraphx/optional.hpp>
#include <functional> #include <functional>
#include <memory> #include <memory>
#include <vector> #include <vector>
...@@ -43,6 +44,9 @@ struct MIGRAPHX_EXPORT dynamic_loader ...@@ -43,6 +44,9 @@ struct MIGRAPHX_EXPORT dynamic_loader
return path(reinterpret_cast<void*>(address)); return path(reinterpret_cast<void*>(address));
} }
static fs::path path(void* address); static fs::path path(void* address);
static optional<dynamic_loader> try_load(const fs::path& p);
dynamic_loader() = default; dynamic_loader() = default;
dynamic_loader(const fs::path& p); dynamic_loader(const fs::path& p);
......
...@@ -222,7 +222,17 @@ struct MIGRAPHX_EXPORT module ...@@ -222,7 +222,17 @@ struct MIGRAPHX_EXPORT module
void annotate(std::ostream& os, std::function<void(instruction_ref)> a) const; void annotate(std::ostream& os, std::function<void(instruction_ref)> a) const;
std::vector<module_ref> get_sub_modules(bool shallow = false) const; std::vector<module_ref> get_sub_modules(bool shallow = false) const;
/* sorts the module in topological order aka reverse-post order (RPO) DFS order
it takes last instruction or @return as the root and walks back the graph and moves inputs
of the each instruction such that it appears before the instruction itself.
*/
module& sort(); module& sort();
/* Any instruction "X" can have module arguments and those modules inside them can use any other
* instruction "Y" from predecessor modules of the instruction "X". Such instruction "Y" inside
* module args are not listed as input instructions to "X". But those instructions "Y" must be
* evaluted before the instruction "X" can. Therefore such "Y" instructions are considered
* implicit dependency to "X".
*/
ins_dep_map calc_implicit_deps() const; ins_dep_map calc_implicit_deps() const;
MIGRAPHX_EXPORT friend std::ostream& operator<<(std::ostream& os, const module& m); MIGRAPHX_EXPORT friend std::ostream& operator<<(std::ostream& os, const module& m);
......
...@@ -64,10 +64,7 @@ void instruction::replace(const shape& r) ...@@ -64,10 +64,7 @@ void instruction::replace(const shape& r)
result = r; result = r;
for(auto&& ins : output) for(auto&& ins : output)
{ {
if(ins->name() == "@return") assert(ins->name() == "@return" or ins->name().front() != '@');
continue;
assert(ins->name().front() != '@');
ins->recompute_shape(); ins->recompute_shape();
} }
} }
...@@ -122,10 +119,6 @@ bool instruction::valid() const ...@@ -122,10 +119,6 @@ bool instruction::valid() const
{ {
computed = result; computed = result;
} }
else if(op.name() == "@return")
{
computed = {};
}
else else
{ {
try try
...@@ -145,6 +138,7 @@ bool instruction::valid() const ...@@ -145,6 +138,7 @@ bool instruction::valid() const
} }
shape instruction::get_shape() const { return result; } shape instruction::get_shape() const { return result; }
const literal& instruction::get_literal() const const literal& instruction::get_literal() const
{ {
assert(op.name() == "@literal"); assert(op.name() == "@literal");
......
...@@ -149,10 +149,10 @@ void layout_nhwc::apply(module_pass_manager& mpm) const ...@@ -149,10 +149,10 @@ void layout_nhwc::apply(module_pass_manager& mpm) const
// std::cout << "after layout" << std::endl; // std::cout << "after layout" << std::endl;
// mpm.get_module().debug_print(); // mpm.get_module().debug_print();
// if(not this->skip_elim_contiguous) // if(not this->skip_elim_contiguous)
mpm.run_pass(eliminate_contiguous{"contiguous"}); // mpm.run_pass(eliminate_contiguous{"contiguous"});
mpm.run_pass(dead_code_elimination{}); // mpm.run_pass(dead_code_elimination{});
mpm.run_pass(auto_contiguous{}); // mpm.run_pass(auto_contiguous{});
mpm.run_pass(dead_code_elimination{}); // mpm.run_pass(dead_code_elimination{});
// remove_layout(mpm.get_module(), output_layouts); // remove_layout(mpm.get_module(), output_layouts);
// mpm.run_pass(dead_code_elimination{}); // mpm.run_pass(dead_code_elimination{});
} }
......
...@@ -460,11 +460,11 @@ instruction_ref module::add_parameter(std::string name, shape s) ...@@ -460,11 +460,11 @@ instruction_ref module::add_parameter(std::string name, shape s)
instruction_ref module::add_return(std::vector<instruction_ref> args) instruction_ref module::add_return(std::vector<instruction_ref> args)
{ {
impl->push_back({builtin::returns{}, {}, std::move(args)}); shape instr_shape = compute_shape(builtin::returns{}, args);
impl->push_back({builtin::returns{}, instr_shape, std::move(args)});
auto result = std::prev(impl->instructions.end()); auto result = std::prev(impl->instructions.end());
instruction::backreference(result); instruction::backreference(result);
assert(result->valid(begin())); assert(result->valid(begin()));
return result; return result;
} }
...@@ -1011,9 +1011,17 @@ std::vector<module_ref> module::get_sub_modules(bool shallow) const ...@@ -1011,9 +1011,17 @@ std::vector<module_ref> module::get_sub_modules(bool shallow) const
module& module::sort() module& module::sort()
{ {
auto implicit_deps = calc_implicit_deps();
fix([&](auto self, auto ins) { fix([&](auto self, auto ins) {
this->move_instruction(ins, this->begin()); this->move_instruction(ins, this->begin());
for(auto child : ins->inputs()) auto ins_inputs = ins->inputs();
if(implicit_deps.find(ins) != implicit_deps.end())
{
auto ins_implict_inputs = implicit_deps.at(ins);
ins_inputs.insert(
ins_inputs.end(), ins_implict_inputs.begin(), ins_implict_inputs.end());
}
for(auto child : ins_inputs)
{ {
if(not contains(this->impl->instructions, child)) if(not contains(this->impl->instructions, child))
{ {
......
...@@ -40,13 +40,14 @@ ...@@ -40,13 +40,14 @@
#include <migraphx/make_op.hpp> #include <migraphx/make_op.hpp>
#include <migraphx/marker.hpp> #include <migraphx/marker.hpp>
#include <migraphx/supported_segments.hpp> #include <migraphx/supported_segments.hpp>
#include <iostream> #include <iostream>
#include <queue>
#include <sstream> #include <sstream>
#include <algorithm> #include <algorithm>
#include <set> #include <set>
#include <unordered_map> #include <unordered_map>
#include <utility> #include <utility>
#include <unordered_set> #include <unordered_set>
#include <map> #include <map>
#include <cassert> #include <cassert>
...@@ -1191,11 +1192,19 @@ void program::remove_unused_modules() ...@@ -1191,11 +1192,19 @@ void program::remove_unused_modules()
program& program::sort() program& program::sort()
{ {
for(auto& pp : this->impl->modules) std::queue<migraphx::module_ref> mqueue;
mqueue.push(get_main_module());
while(not mqueue.empty())
{ {
pp.second.sort(); module_ref current_mod = mqueue.front();
current_mod->sort();
mqueue.pop();
auto child_mods = current_mod->get_sub_modules(true);
for(auto& sub_mod : child_mods)
{
mqueue.push(sub_mod);
}
} }
return *this; return *this;
} }
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment