"docs/vscode:/vscode.git/clone" did not exist on "5822ede66ee834f406d475ce28a1d8f666458a48"
Commit 606ed5e8 authored by Brian Pickrell's avatar Brian Pickrell
Browse files

Merge branch 'rand_uniform' into multinomial_parse_merge_random

parents c27d3b62 476ed17c
...@@ -26,6 +26,7 @@ add_library(migraphx_c ...@@ -26,6 +26,7 @@ add_library(migraphx_c
api.cpp api.cpp
) )
set_target_properties(migraphx_c PROPERTIES EXPORT_NAME c) set_target_properties(migraphx_c PROPERTIES EXPORT_NAME c)
migraphx_generate_export_header(migraphx_c DIRECTORY migraphx/api)
# migraphx_c is stable API interface library. SO version of this should be # migraphx_c is stable API interface library. SO version of this should be
# bumped when binary compatibility is broken. # bumped when binary compatibility is broken.
......
...@@ -44,7 +44,7 @@ namespace migraphx { ...@@ -44,7 +44,7 @@ namespace migraphx {
static thread_local bool disable_exception_catch = false; // NOLINT static thread_local bool disable_exception_catch = false; // NOLINT
extern "C" void migraphx_test_private_disable_exception_catch(bool b) extern "C" MIGRAPHX_C_EXPORT void migraphx_test_private_disable_exception_catch(bool b)
{ {
disable_exception_catch = b; disable_exception_catch = b;
} }
...@@ -899,7 +899,7 @@ migraphx_dynamic_dimensions_assign_to(migraphx_dynamic_dimensions_t output, ...@@ -899,7 +899,7 @@ migraphx_dynamic_dimensions_assign_to(migraphx_dynamic_dimensions_t output,
extern "C" migraphx_status extern "C" migraphx_status
migraphx_dynamic_dimensions_create(migraphx_dynamic_dimensions_t* dynamic_dimensions, migraphx_dynamic_dimensions_create(migraphx_dynamic_dimensions_t* dynamic_dimensions,
const_migraphx_dynamic_dimension_t* ptr, const const_migraphx_dynamic_dimension_t* ptr,
size_t size) size_t size)
{ {
auto api_error_result = migraphx::try_([&] { auto api_error_result = migraphx::try_([&] {
...@@ -1432,7 +1432,7 @@ extern "C" migraphx_status migraphx_instructions_assign_to(migraphx_instructions ...@@ -1432,7 +1432,7 @@ extern "C" migraphx_status migraphx_instructions_assign_to(migraphx_instructions
} }
extern "C" migraphx_status migraphx_instructions_create(migraphx_instructions_t* instructions, extern "C" migraphx_status migraphx_instructions_create(migraphx_instructions_t* instructions,
const_migraphx_instruction_t* ptr, const const_migraphx_instruction_t* ptr,
size_t size) size_t size)
{ {
auto api_error_result = migraphx::try_([&] { auto api_error_result = migraphx::try_([&] {
......
This diff is collapsed.
...@@ -79,7 +79,8 @@ def dynamic_dimension(h): ...@@ -79,7 +79,8 @@ def dynamic_dimension(h):
def dynamic_dimensions(h): def dynamic_dimensions(h):
h.constructor( h.constructor(
'create', 'create',
api.params(ptr='const_migraphx_dynamic_dimension_t*', size='size_t'), api.params(ptr='const const_migraphx_dynamic_dimension_t*',
size='size_t'),
fname='migraphx::to_obj_vector<const_migraphx_dynamic_dimension_t>') fname='migraphx::to_obj_vector<const_migraphx_dynamic_dimension_t>')
h.method('size', returns='size_t') h.method('size', returns='size_t')
h.method('get', h.method('get',
...@@ -215,7 +216,7 @@ def instruction(h): ...@@ -215,7 +216,7 @@ def instruction(h):
def instructions(h): def instructions(h):
h.constructor( h.constructor(
'create', 'create',
api.params(ptr='const_migraphx_instruction_t*', size='size_t'), api.params(ptr='const const_migraphx_instruction_t*', size='size_t'),
fname='migraphx::to_obj_vector<const_migraphx_instruction_t>') fname='migraphx::to_obj_vector<const_migraphx_instruction_t>')
......
...@@ -32,18 +32,20 @@ add_executable(driver ...@@ -32,18 +32,20 @@ add_executable(driver
marker_roctx.cpp marker_roctx.cpp
) )
set_target_properties(driver PROPERTIES OUTPUT_NAME migraphx-driver) set_target_properties(driver PROPERTIES OUTPUT_NAME migraphx-driver)
# Copy driver for backwards compatibility if(NOT WIN32)
add_custom_command( # Copy driver for backwards compatibility (Linux only)
TARGET driver add_custom_command(
TARGET driver
POST_BUILD COMMAND ${CMAKE_COMMAND} -E copy POST_BUILD COMMAND ${CMAKE_COMMAND} -E copy
$<TARGET_FILE:driver> $<TARGET_FILE:driver>
${CMAKE_RUNTIME_OUTPUT_DIRECTORY}/driver ${CMAKE_RUNTIME_OUTPUT_DIRECTORY}/driver
BYPRODUCTS ${CMAKE_RUNTIME_OUTPUT_DIRECTORY}/driver BYPRODUCTS ${CMAKE_RUNTIME_OUTPUT_DIRECTORY}/driver
) )
set_directory_properties(PROPERTIES ADDITIONAL_MAKE_CLEAN_FILES ${CMAKE_RUNTIME_OUTPUT_DIRECTORY}/driver) set_directory_properties(PROPERTIES ADDITIONAL_CLEAN_FILES ${CMAKE_RUNTIME_OUTPUT_DIRECTORY}/driver)
endif()
rocm_clang_tidy_check(driver) rocm_clang_tidy_check(driver)
target_link_libraries(driver migraphx_all_targets migraphx_onnx migraphx_tf) target_link_libraries(driver migraphx_all_targets migraphx_onnx migraphx_tf migraphx_py)
rocm_install_targets( rocm_install_targets(
TARGETS driver TARGETS driver
......
...@@ -338,11 +338,22 @@ struct argument_parser ...@@ -338,11 +338,22 @@ struct argument_parser
MIGRAPHX_DRIVER_STATIC auto file_exist() MIGRAPHX_DRIVER_STATIC auto file_exist()
{ {
return validate([](auto&, auto&, auto& params) { return validate([](auto&, auto&, const auto& params) {
if(params.empty()) if(params.empty())
throw std::runtime_error("No argument passed."); throw std::runtime_error("No argument passed.");
if(not fs::exists(params.back())) if(not fs::exists(params.back()))
throw std::runtime_error("Path does not exists: " + params.back()); throw std::runtime_error("Path does not exist: " + params.back());
});
}
MIGRAPHX_DRIVER_STATIC auto matches(const std::unordered_set<std::string>& names)
{
return validate([=](auto&, auto&, const auto& params) {
auto invalid_param = std::find_if(
params.begin(), params.end(), [&](const auto& p) { return names.count(p) == 0; });
if(invalid_param != params.end())
throw std::runtime_error("Invalid argument: " + *invalid_param +
". Valid arguments are {" + to_string_range(names) + "}");
}); });
} }
...@@ -570,8 +581,7 @@ struct argument_parser ...@@ -570,8 +581,7 @@ struct argument_parser
continue; continue;
if(flag[0] != '-') if(flag[0] != '-')
continue; continue;
auto d = std::ptrdiff_t d = levenshtein_distance(flag, input);
levenshtein_distance(flag.begin(), flag.end(), input.begin(), input.end());
if(d < result.distance) if(d < result.distance)
result = result_t{&arg, flag, input, d}; result = result_t{&arg, flag, input, d};
} }
......
...@@ -32,6 +32,7 @@ ...@@ -32,6 +32,7 @@
#include <migraphx/tf.hpp> #include <migraphx/tf.hpp>
#include <migraphx/onnx.hpp> #include <migraphx/onnx.hpp>
#include <migraphx/py.hpp>
#include <migraphx/stringutils.hpp> #include <migraphx/stringutils.hpp>
#include <migraphx/convert_to_json.hpp> #include <migraphx/convert_to_json.hpp>
#include <migraphx/load_save.hpp> #include <migraphx/load_save.hpp>
...@@ -81,6 +82,7 @@ struct loader ...@@ -81,6 +82,7 @@ struct loader
{"--model"}, {"--model"},
ap.help("Load model"), ap.help("Load model"),
ap.type("resnet50|inceptionv3|alexnet"), ap.type("resnet50|inceptionv3|alexnet"),
ap.matches({"resnet50", "inceptionv3", "alexnet"}),
ap.group("input")); ap.group("input"));
ap(file_type, {"--onnx"}, ap.help("Load as onnx"), ap.set_value("onnx")); ap(file_type, {"--onnx"}, ap.help("Load as onnx"), ap.set_value("onnx"));
ap(file_type, {"--tf"}, ap.help("Load as tensorflow"), ap.set_value("tf")); ap(file_type, {"--tf"}, ap.help("Load as tensorflow"), ap.set_value("tf"));
...@@ -241,6 +243,20 @@ struct loader ...@@ -241,6 +243,20 @@ struct loader
return options; return options;
} }
static std::string get_file_type(const std::string& file)
{
if(ends_with(file, ".onnx"))
return "onnx";
else if(ends_with(file, ".pb"))
return "tf";
else if(ends_with(file, ".json"))
return "json";
else if(ends_with(file, ".py"))
return "py";
else
return "migraphx";
}
program load() program load()
{ {
program p; program p;
...@@ -248,14 +264,7 @@ struct loader ...@@ -248,14 +264,7 @@ struct loader
{ {
if(file_type.empty()) if(file_type.empty())
{ {
if(ends_with(file, ".onnx")) file_type = get_file_type(file);
file_type = "onnx";
else if(ends_with(file, ".pb"))
file_type = "tf";
else if(ends_with(file, ".json"))
file_type = "json";
else
file_type = "migraphx";
} }
std::cout << "Reading: " << file << std::endl; std::cout << "Reading: " << file << std::endl;
if(file_type == "onnx") if(file_type == "onnx")
...@@ -272,6 +281,10 @@ struct loader ...@@ -272,6 +281,10 @@ struct loader
options.format = "json"; options.format = "json";
p = migraphx::load(file, options); p = migraphx::load(file, options);
} }
else if(file_type == "py")
{
p = migraphx::load_py(file);
}
else if(file_type == "migraphx") else if(file_type == "migraphx")
{ {
p = migraphx::load(file); p = migraphx::load(file);
...@@ -757,7 +770,7 @@ struct main_command ...@@ -757,7 +770,7 @@ struct main_command
{ {
std::cout << "'" << color::fg_yellow << wrong_commands.front() << color::reset std::cout << "'" << color::fg_yellow << wrong_commands.front() << color::reset
<< "' is not a valid command." << std::endl; << "' is not a valid command." << std::endl;
std::cout << get_command_help("Available commands:") << std::endl; std::cout << get_command_help("Available commands:");
} }
else else
{ {
......
...@@ -48,7 +48,7 @@ struct dynamic_loader_impl ...@@ -48,7 +48,7 @@ struct dynamic_loader_impl
#pragma GCC diagnostic ignored "-Wignored-attributes" #pragma GCC diagnostic ignored "-Wignored-attributes"
#endif #endif
dynamic_loader_impl(const fs::path& p, std::shared_ptr<tmp_dir> t = nullptr) dynamic_loader_impl(const fs::path& p, std::shared_ptr<tmp_dir> t = nullptr)
: handle(dlopen(p.string().c_str(), RTLD_LAZY), : handle(dlopen(p.string().c_str(), RTLD_GLOBAL | RTLD_NOW),
manage_deleter<decltype(&dlclose), &dlclose>{}), manage_deleter<decltype(&dlclose), &dlclose>{}),
temp(std::move(t)) temp(std::move(t))
{ {
...@@ -81,6 +81,18 @@ fs::path dynamic_loader::path(void* address) ...@@ -81,6 +81,18 @@ fs::path dynamic_loader::path(void* address)
return p; return p;
} }
optional<dynamic_loader> dynamic_loader::try_load(const fs::path& p)
{
try
{
return dynamic_loader{p};
}
catch(const std::exception&)
{
return nullopt;
}
}
dynamic_loader::dynamic_loader(const fs::path& p) : impl(std::make_shared<dynamic_loader_impl>(p)) dynamic_loader::dynamic_loader(const fs::path& p) : impl(std::make_shared<dynamic_loader_impl>(p))
{ {
} }
......
...@@ -35,6 +35,8 @@ ...@@ -35,6 +35,8 @@
namespace migraphx { namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS { inline namespace MIGRAPHX_INLINE_NS {
MIGRAPHX_DECLARE_ENV_VAR(MIGRAPHX_TRACE_ELIMINATE_CONTIGUOUS)
static bool try_compute_shape(instruction_ref ins, static bool try_compute_shape(instruction_ref ins,
const std::vector<shape>& inputs, const std::vector<shape>& inputs,
const std::vector<module_ref>& mods) const std::vector<module_ref>& mods)
...@@ -78,14 +80,26 @@ static bool try_compute_shape(instruction_ref ins, ...@@ -78,14 +80,26 @@ static bool try_compute_shape(instruction_ref ins,
return (arg == ins) ? new_shape : arg->get_shape(); return (arg == ins) ? new_shape : arg->get_shape();
}); });
if(not try_compute_shape(output, input_shapes, mods)) if(not try_compute_shape(output, input_shapes, output->module_inputs()))
{ {
return false; return false;
} }
} }
} }
catch(const std::exception& e)
{
if(enabled(MIGRAPHX_TRACE_ELIMINATE_CONTIGUOUS{}))
{
std::cout << "Exception: " << e.what() << std::endl;
}
return false;
}
catch(...) catch(...)
{ {
if(enabled(MIGRAPHX_TRACE_ELIMINATE_CONTIGUOUS{}))
{
std::cout << "Unknown exception" << std::endl;
}
return false; return false;
} }
...@@ -127,6 +141,11 @@ static void remove_contiguous(const std::string& op_name, module& m, F f) ...@@ -127,6 +141,11 @@ static void remove_contiguous(const std::string& op_name, module& m, F f)
{ {
if(arg->name() != op_name) if(arg->name() != op_name)
continue; continue;
if(enabled(MIGRAPHX_TRACE_ELIMINATE_CONTIGUOUS{}))
{
std::cout << "eliminate_contiguous: ";
m.debug_print(ins);
}
auto prev = arg->inputs().front(); auto prev = arg->inputs().front();
replace(new_args, arg, prev); replace(new_args, arg, prev);
if(try_compute_shape(ins, new_args, mod_args)) if(try_compute_shape(ins, new_args, mod_args))
......
...@@ -41,7 +41,7 @@ static literal get_scalar(instruction_ref ins) ...@@ -41,7 +41,7 @@ static literal get_scalar(instruction_ref ins)
if(ins->name() == "contiguous") if(ins->name() == "contiguous")
return get_scalar(ins->inputs().front()); return get_scalar(ins->inputs().front());
const auto& s = ins->get_shape(); const auto& s = ins->get_shape();
if(s.elements() != 1 && not(s.scalar())) if(s.elements() != 1 and not(s.scalar()))
return {}; return {};
if(not ins->can_eval()) if(not ins->can_eval())
return {}; return {};
......
...@@ -52,7 +52,7 @@ struct fused_reduce ...@@ -52,7 +52,7 @@ struct fused_reduce
{ {
if(mods.size() != 1) if(mods.size() != 1)
MIGRAPHX_THROW("should have one submodule."); MIGRAPHX_THROW("should have one submodule.");
auto* sm = mods.front(); const auto* sm = mods.front();
if(sm->get_output_shapes().size() != 1) if(sm->get_output_shapes().size() != 1)
MIGRAPHX_THROW("Only one output supported"); MIGRAPHX_THROW("Only one output supported");
auto names = sm->get_parameter_names(); auto names = sm->get_parameter_names();
...@@ -143,7 +143,7 @@ insert_module_in_submodule(module_ref sm, ...@@ -143,7 +143,7 @@ insert_module_in_submodule(module_ref sm,
} }
static std::vector<instruction_ref> static std::vector<instruction_ref>
find_inputs(module_ref sm, find_inputs(const_module_ref sm,
const module& parent, const module& parent,
const std::unordered_map<instruction_ref, instruction_ref>& map_ins) const std::unordered_map<instruction_ref, instruction_ref>& map_ins)
{ {
......
...@@ -26,6 +26,8 @@ ...@@ -26,6 +26,8 @@
#include <algorithm> #include <algorithm>
#include <numeric> #include <numeric>
#include <string>
#include <vector>
#include <migraphx/config.hpp> #include <migraphx/config.hpp>
namespace migraphx { namespace migraphx {
...@@ -90,6 +92,42 @@ levenshtein_distance(Iterator1 first1, Iterator1 last1, Iterator2 first2, Iterat ...@@ -90,6 +92,42 @@ levenshtein_distance(Iterator1 first1, Iterator1 last1, Iterator2 first2, Iterat
return std::ptrdiff_t{1} + std::min({x1, x2, x3}); return std::ptrdiff_t{1} + std::min({x1, x2, x3});
} }
inline size_t levenshtein_distance(const std::string& s1, const std::string& s2)
{
const size_t l1 = s1.length();
const size_t l2 = s2.length();
if(l1 < l2)
levenshtein_distance(s2, s1);
std::vector<size_t> d(l2 + 1);
std::iota(d.begin(), d.end(), 0);
for(size_t i = 1; i <= l1; i++)
{
size_t prev_cost = d[0];
d[0] = i;
for(size_t j = 1; j <= l2; j++)
{
if(s1[i - 1] == s2[j - 1])
{
d[j] = prev_cost;
}
else
{
size_t cost_insert_or_delete = std::min(d[j - 1], d[j]);
size_t cost_substitute = prev_cost;
prev_cost = d[j];
d[j] = std::min(cost_substitute, cost_insert_or_delete) + 1;
}
}
}
return d[l2];
}
} // namespace MIGRAPHX_INLINE_NS } // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx } // namespace migraphx
......
...@@ -90,7 +90,17 @@ struct param ...@@ -90,7 +90,17 @@ struct param
struct returns struct returns
{ {
std::string name() const { return "@return"; } std::string name() const { return "@return"; }
shape compute_shape(const std::vector<shape>&) const { return {}; }
shape compute_shape(const std::vector<shape>& arg) const
{
if(arg.empty())
return {};
else if(arg.size() == 1)
return arg[0];
else
return arg;
}
argument compute(context&, const shape&, const std::vector<argument>&) const argument compute(context&, const shape&, const std::vector<argument>&) const
{ {
MIGRAPHX_THROW("builtin"); MIGRAPHX_THROW("builtin");
......
/* /*
* The MIT License (MIT) * The MIT License (MIT)
* *
* Copyright (c) 2015-2022 Advanced Micro Devices, Inc. All rights reserved. * Copyright (c) 2015-2023 Advanced Micro Devices, Inc. All rights reserved.
* *
* Permission is hereby granted, free of charge, to any person obtaining a copy * Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal * of this software and associated documentation files (the "Software"), to deal
...@@ -34,21 +34,37 @@ ...@@ -34,21 +34,37 @@
namespace migraphx { namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS { inline namespace MIGRAPHX_INLINE_NS {
// Check that deduced type is incrementable, dereferencable, and comparable
template <class, class = void>
struct is_iterator
{
};
template <class T>
struct is_iterator<T,
std::void_t<decltype(++std::declval<T&>()),
decltype(*std::declval<T&>()),
decltype(std::declval<T&>() == std::declval<T&>())>> : std::true_type
{
};
template <class Iterator>
struct check_shapes struct check_shapes
{ {
const shape* begin; static_assert(is_iterator<Iterator>{}, "CHECK_SHAPES: Deduced type must be an iterator");
const shape* end; Iterator begin;
Iterator end;
std::string name; std::string name;
bool dynamic_allowed; bool dynamic_allowed;
check_shapes(const shape* b, const shape* e, const std::string& n, const bool d = false) check_shapes(Iterator b, Iterator e, const std::string& n, const bool d = false)
: begin(b), end(e), name(n), dynamic_allowed(d) : begin(b), end(e), name(n), dynamic_allowed(d)
{ {
check_dynamic(); check_dynamic();
} }
template <class Op> template <class Op>
check_shapes(const shape* b, const shape* e, const Op& op, const bool d = false) check_shapes(Iterator b, Iterator e, const Op& op, const bool d = false)
: begin(b), end(e), name(op.name()), dynamic_allowed(d) : begin(b), end(e), name(op.name()), dynamic_allowed(d)
{ {
check_dynamic(); check_dynamic();
...@@ -56,7 +72,7 @@ struct check_shapes ...@@ -56,7 +72,7 @@ struct check_shapes
template <class Op> template <class Op>
check_shapes(const std::vector<shape>& s, const Op& op, const bool d = false) check_shapes(const std::vector<shape>& s, const Op& op, const bool d = false)
: begin(s.data()), end(s.data() + s.size()), name(op.name()), dynamic_allowed(d) : begin(s.begin()), end(s.end()), name(op.name()), dynamic_allowed(d)
{ {
check_dynamic(); check_dynamic();
} }
...@@ -81,8 +97,6 @@ struct check_shapes ...@@ -81,8 +97,6 @@ struct check_shapes
{ {
if(begin == end) if(begin == end)
return 0; return 0;
assert(begin != nullptr);
assert(end != nullptr);
return end - begin; return end - begin;
} }
...@@ -131,8 +145,6 @@ struct check_shapes ...@@ -131,8 +145,6 @@ struct check_shapes
*/ */
const check_shapes& only_dims(std::size_t n) const const check_shapes& only_dims(std::size_t n) const
{ {
assert(begin != nullptr);
assert(end != nullptr);
if(begin != end) if(begin != end)
{ {
if(begin->max_lens().size() != n) if(begin->max_lens().size() != n)
...@@ -148,8 +160,6 @@ struct check_shapes ...@@ -148,8 +160,6 @@ struct check_shapes
*/ */
const check_shapes& max_ndims(std::size_t n) const const check_shapes& max_ndims(std::size_t n) const
{ {
assert(begin != nullptr);
assert(end != nullptr);
if(begin != end) if(begin != end)
{ {
if(begin->max_lens().size() > n) if(begin->max_lens().size() > n)
...@@ -166,8 +176,6 @@ struct check_shapes ...@@ -166,8 +176,6 @@ struct check_shapes
*/ */
const check_shapes& min_ndims(std::size_t n) const const check_shapes& min_ndims(std::size_t n) const
{ {
assert(begin != nullptr);
assert(end != nullptr);
if(begin != end) if(begin != end)
{ {
if(begin->max_lens().size() < n) if(begin->max_lens().size() < n)
...@@ -330,8 +338,6 @@ struct check_shapes ...@@ -330,8 +338,6 @@ struct check_shapes
{ {
if(begin == end) if(begin == end)
return true; return true;
assert(begin != nullptr);
assert(end != nullptr);
auto&& key = f(*begin); auto&& key = f(*begin);
return this->all_of([&](const shape& s) { return f(s) == key; }); return this->all_of([&](const shape& s) { return f(s) == key; });
} }
...@@ -341,8 +347,6 @@ struct check_shapes ...@@ -341,8 +347,6 @@ struct check_shapes
{ {
if(begin == end) if(begin == end)
return true; return true;
assert(begin != nullptr);
assert(end != nullptr);
return std::all_of(begin, end, p); return std::all_of(begin, end, p);
} }
...@@ -351,17 +355,13 @@ struct check_shapes ...@@ -351,17 +355,13 @@ struct check_shapes
{ {
if(begin == end) if(begin == end)
return false; return false;
assert(begin != nullptr);
assert(end != nullptr);
return std::any_of(begin, end, p); return std::any_of(begin, end, p);
} }
const shape* get(long i) const Iterator get(long i) const
{ {
if(i >= size()) if(i >= size())
MIGRAPHX_THROW(prefix() + "Accessing shape out of bounds"); MIGRAPHX_THROW(prefix() + "Accessing shape out of bounds");
assert(begin != nullptr);
assert(end != nullptr);
if(i < 0) if(i < 0)
return end - i; return end - i;
return begin + i; return begin + i;
...@@ -394,6 +394,11 @@ struct check_shapes ...@@ -394,6 +394,11 @@ struct check_shapes
} }
}; };
// Deduction guide for std::vector constructor
template <class Op>
check_shapes(const std::vector<shape>&, const Op&, bool d = false)
-> check_shapes<std::vector<shape>::const_iterator>;
} // namespace MIGRAPHX_INLINE_NS } // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx } // namespace migraphx
......
...@@ -26,6 +26,7 @@ ...@@ -26,6 +26,7 @@
#include <migraphx/config.hpp> #include <migraphx/config.hpp>
#include <migraphx/filesystem.hpp> #include <migraphx/filesystem.hpp>
#include <migraphx/optional.hpp>
#include <functional> #include <functional>
#include <memory> #include <memory>
#include <vector> #include <vector>
...@@ -43,6 +44,9 @@ struct MIGRAPHX_EXPORT dynamic_loader ...@@ -43,6 +44,9 @@ struct MIGRAPHX_EXPORT dynamic_loader
return path(reinterpret_cast<void*>(address)); return path(reinterpret_cast<void*>(address));
} }
static fs::path path(void* address); static fs::path path(void* address);
static optional<dynamic_loader> try_load(const fs::path& p);
dynamic_loader() = default; dynamic_loader() = default;
dynamic_loader(const fs::path& p); dynamic_loader(const fs::path& p);
......
...@@ -222,7 +222,17 @@ struct MIGRAPHX_EXPORT module ...@@ -222,7 +222,17 @@ struct MIGRAPHX_EXPORT module
void annotate(std::ostream& os, std::function<void(instruction_ref)> a) const; void annotate(std::ostream& os, std::function<void(instruction_ref)> a) const;
std::vector<module_ref> get_sub_modules(bool shallow = false) const; std::vector<module_ref> get_sub_modules(bool shallow = false) const;
/* sorts the module in topological order aka reverse-post order (RPO) DFS order
it takes last instruction or @return as the root and walks back the graph and moves inputs
of the each instruction such that it appears before the instruction itself.
*/
module& sort(); module& sort();
/* Any instruction "X" can have module arguments and those modules inside them can use any other
* instruction "Y" from predecessor modules of the instruction "X". Such instruction "Y" inside
* module args are not listed as input instructions to "X". But those instructions "Y" must be
* evaluted before the instruction "X" can. Therefore such "Y" instructions are considered
* implicit dependency to "X".
*/
ins_dep_map calc_implicit_deps() const; ins_dep_map calc_implicit_deps() const;
MIGRAPHX_EXPORT friend std::ostream& operator<<(std::ostream& os, const module& m); MIGRAPHX_EXPORT friend std::ostream& operator<<(std::ostream& os, const module& m);
......
/* /*
* The MIT License (MIT) * The MIT License (MIT)
* *
* Copyright (c) 2015-2022 Advanced Micro Devices, Inc. All rights reserved. * Copyright (c) 2015-2023 Advanced Micro Devices, Inc. All rights reserved.
* *
* Permission is hereby granted, free of charge, to any person obtaining a copy * Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal * of this software and associated documentation files (the "Software"), to deal
...@@ -28,6 +28,7 @@ ...@@ -28,6 +28,7 @@
#include <migraphx/shape.hpp> #include <migraphx/shape.hpp>
#include <cstring> #include <cstring>
#include <vector> #include <vector>
#include <migraphx/op/normalize_attribute.hpp>
namespace migraphx { namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS { inline namespace MIGRAPHX_INLINE_NS {
...@@ -42,6 +43,36 @@ struct select_dependent_type ...@@ -42,6 +43,36 @@ struct select_dependent_type
template <class T, class... Ts> template <class T, class... Ts>
using dependent_type = typename select_dependent_type<T, Ts...>::type; using dependent_type = typename select_dependent_type<T, Ts...>::type;
/**
* Used to normalize variable input axes at model runtime.
* Example: the axes inputs of the slice operator.
*
* \param axes the axes to normalize
* \param input_shape shape of the input tensor
* \param attr_val the normalize_axes attributes from the operator
* \param prefix error message prefix
*/
std::vector<int64_t> normalize_axes(const std::vector<int64_t>& axes,
const shape& input_shape,
const value& attr_val,
const std::string& prefix = "");
/**
* Used to normalize variable input axes at model runtime.
* Example: the starts and ends inputs of the slice operator.
*
* \param indices the indices to normalize
* \param axes which axes the indices apply over
* \param input_shape shape of the input tensor
* \param attr_val the normalize_axes attributes from the operator
* \param prefix error message prefix
*/
std::vector<int64_t> normalize_indices(const std::vector<int64_t>& indices,
const std::vector<int64_t>& axes,
const shape& input_shape,
const value& attr_val,
const std::string& prefix = "");
MIGRAPHX_EXPORT MIGRAPHX_EXPORT
bool normalize_attributes(operation& op, const shape& input_shape); bool normalize_attributes(operation& op, const shape& input_shape);
......
...@@ -82,7 +82,7 @@ struct convolution ...@@ -82,7 +82,7 @@ struct convolution
const auto input_ndim = inputs[0].ndim(); const auto input_ndim = inputs[0].ndim();
const auto padding_size = padding.size(); const auto padding_size = padding.size();
if(input_ndim != padding_size / 2 + 2 && input_ndim != padding_size + 2) if(input_ndim != padding_size / 2 + 2 and input_ndim != padding_size + 2)
{ {
MIGRAPHX_THROW("CONVOLUTION: input and attribute size mismatch!"); MIGRAPHX_THROW("CONVOLUTION: input and attribute size mismatch!");
} }
......
...@@ -71,7 +71,7 @@ struct if_op ...@@ -71,7 +71,7 @@ struct if_op
std::unordered_map<std::string, argument> params; std::unordered_map<std::string, argument> params;
std::set<std::string> pnames; std::set<std::string> pnames;
for(const auto& smod : mods) for(const_module_ref smod : mods)
{ {
auto names = smod->get_parameter_names(); auto names = smod->get_parameter_names();
pnames.insert(names.begin(), names.end()); pnames.insert(names.begin(), names.end());
......
...@@ -59,9 +59,9 @@ struct loop ...@@ -59,9 +59,9 @@ struct loop
MIGRAPHX_THROW("LOOP: operator should have one submodule."); MIGRAPHX_THROW("LOOP: operator should have one submodule.");
} }
const auto& mod = mods.front(); const_module_ref mod = mods.front();
auto mod_out_shapes = mod->get_output_shapes(); auto mod_out_shapes = mod->get_output_shapes();
auto dep_param_num = inputs.size() - 2; auto dep_param_num = inputs.size() - 2;
// first item of the mod output shapes is condition used in loop, // first item of the mod output shapes is condition used in loop,
// which is not needed to compute output shape // which is not needed to compute output shape
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment