Unverified Commit 40fbef9b authored by Ted Themistokleous's avatar Ted Themistokleous Committed by GitHub
Browse files

Merge branch 'develop' into threaded_nms

parents d164b151 aeb9f78c
...@@ -106,6 +106,11 @@ cpp_generator::function& cpp_generator::function::set_generic_types(const module ...@@ -106,6 +106,11 @@ cpp_generator::function& cpp_generator::function::set_generic_types(const module
return *this; return *this;
} }
/// Prepends a `(void)<pname>;` statement to the front of the function body and
/// returns this function object so that further setters can be chained.
cpp_generator::function& cpp_generator::function::unused_param(const std::string& pname)
{
    const std::string discard_stmt = "(void)" + pname + ";\n";
    body.insert(0, discard_stmt);
    return *this;
}
cpp_generator::function& cpp_generator::function::add_generic_param(const std::string& pname) cpp_generator::function& cpp_generator::function::add_generic_param(const std::string& pname)
{ {
params.push_back({pname, "T" + pname}); params.push_back({pname, "T" + pname});
...@@ -174,6 +179,8 @@ std::string cpp_generator::generate_point_op(const operation& op, ...@@ -174,6 +179,8 @@ std::string cpp_generator::generate_point_op(const operation& op,
else if(with_char(::isdigit)(key[0])) else if(with_char(::isdigit)(key[0]))
{ {
auto i = std::stoul(key); auto i = std::stoul(key);
if(i >= args.size())
MIGRAPHX_THROW("Invalid argument index: " + key);
return args.at(i); return args.at(i);
} }
else if(v.contains(key)) else if(v.contains(key))
...@@ -201,8 +208,24 @@ cpp_generator::function cpp_generator::generate_module(const module& m, ...@@ -201,8 +208,24 @@ cpp_generator::function cpp_generator::generate_module(const module& m,
f.set_name(name).set_types(m).set_body( f.set_name(name).set_types(m).set_body(
m, [&](instruction_ref ins, const auto& names) -> std::string { m, [&](instruction_ref ins, const auto& names) -> std::string {
if(ins->name() == "@literal") if(ins->name() == "@literal")
return shape::cpp_type(ins->get_shape().type()) + "(" + {
ins->get_literal().to_string() + ")"; std::string string_literal;
ins->get_literal().visit([&](auto v) {
assert(v.size() == 1);
auto x = v.front();
if(std::isinf(x))
{
string_literal = "__builtin_huge_val()";
if(x < 0)
string_literal = "-__builtin_huge_val()";
}
else if(std::isnan(x))
string_literal = "__builtin_nan()";
else
string_literal = ins->get_literal().to_string();
});
return shape::cpp_type(ins->get_shape().type()) + "(" + string_literal + ")";
}
auto s = g(ins, names); auto s = g(ins, names);
if(impl->fresult) if(impl->fresult)
return impl->fresult(ins->get_shape()) + '(' + s + ')'; return impl->fresult(ins->get_shape()) + '(' + s + ')';
...@@ -238,6 +261,8 @@ std::string cpp_generator::create_function(const cpp_generator::function& f) ...@@ -238,6 +261,8 @@ std::string cpp_generator::create_function(const cpp_generator::function& f)
std::string name = f.name.empty() ? "f" + std::to_string(impl->function_count) : f.name; std::string name = f.name.empty() ? "f" + std::to_string(impl->function_count) : f.name;
impl->fs << join_strings(f.attributes, " ") << " " << f.return_type << " " << name; impl->fs << join_strings(f.attributes, " ") << " " << f.return_type << " " << name;
char delim = '('; char delim = '(';
if(f.params.empty())
impl->fs << delim;
for(auto&& p : f.params) for(auto&& p : f.params)
{ {
impl->fs << delim << p.type << " " << p.name; impl->fs << delim << p.type << " " << p.name;
......
...@@ -49,8 +49,10 @@ void dead_code_elimination::apply(module& m) const ...@@ -49,8 +49,10 @@ void dead_code_elimination::apply(module& m) const
if(i == last) if(i == last)
break; break;
// Skip instruction with empty shape as output unless its [dynamic, builtin, undefined, // Skip instruction with empty shape as output unless its [dynamic, builtin, undefined,
// identity, allocate] // identity, allocate or tuple_type]
if((not i->get_shape().dynamic() and i->get_shape().elements() == 0) and if((not i->get_shape().dynamic() and
(i->get_shape().elements() == 0 and
i->get_shape().type() != migraphx::shape::tuple_type)) and
not(i->name().front() == '@') and not contains({"identity", "allocate"}, i->name()) and not(i->name().front() == '@') and not contains({"identity", "allocate"}, i->name()) and
not i->is_undefined()) not i->is_undefined())
continue; continue;
......
...@@ -32,18 +32,20 @@ add_executable(driver ...@@ -32,18 +32,20 @@ add_executable(driver
marker_roctx.cpp marker_roctx.cpp
) )
set_target_properties(driver PROPERTIES OUTPUT_NAME migraphx-driver) set_target_properties(driver PROPERTIES OUTPUT_NAME migraphx-driver)
# Copy driver for backwards compatibility if(NOT WIN32)
add_custom_command( # Copy driver for backwards compatibility (Linux only)
TARGET driver add_custom_command(
TARGET driver
POST_BUILD COMMAND ${CMAKE_COMMAND} -E copy POST_BUILD COMMAND ${CMAKE_COMMAND} -E copy
$<TARGET_FILE:driver> $<TARGET_FILE:driver>
${CMAKE_RUNTIME_OUTPUT_DIRECTORY}/driver ${CMAKE_RUNTIME_OUTPUT_DIRECTORY}/driver
BYPRODUCTS ${CMAKE_RUNTIME_OUTPUT_DIRECTORY}/driver BYPRODUCTS ${CMAKE_RUNTIME_OUTPUT_DIRECTORY}/driver
) )
set_directory_properties(PROPERTIES ADDITIONAL_MAKE_CLEAN_FILES ${CMAKE_RUNTIME_OUTPUT_DIRECTORY}/driver) set_directory_properties(PROPERTIES ADDITIONAL_CLEAN_FILES ${CMAKE_RUNTIME_OUTPUT_DIRECTORY}/driver)
endif()
rocm_clang_tidy_check(driver) rocm_clang_tidy_check(driver)
target_link_libraries(driver migraphx_all_targets migraphx_onnx migraphx_tf) target_link_libraries(driver migraphx_all_targets migraphx_onnx migraphx_tf migraphx_py)
rocm_install_targets( rocm_install_targets(
TARGETS driver TARGETS driver
......
...@@ -342,7 +342,19 @@ struct argument_parser ...@@ -342,7 +342,19 @@ struct argument_parser
if(params.empty()) if(params.empty())
throw std::runtime_error("No argument passed."); throw std::runtime_error("No argument passed.");
if(not fs::exists(params.back())) if(not fs::exists(params.back()))
throw std::runtime_error("Path does not exists: " + params.back()); throw std::runtime_error("Path does not exist: " + params.back());
});
}
MIGRAPHX_DRIVER_STATIC auto matches(const std::unordered_set<std::string>& names)
{
return validate([=](auto&, auto&, auto& params) {
for(const auto& p : params)
{
if(names.count(p) == 0)
throw std::runtime_error("Invalid argument: " + p + ". Valid arguments are {" +
to_string_range(names) + "}");
}
}); });
} }
...@@ -570,8 +582,7 @@ struct argument_parser ...@@ -570,8 +582,7 @@ struct argument_parser
continue; continue;
if(flag[0] != '-') if(flag[0] != '-')
continue; continue;
auto d = std::ptrdiff_t d = levenshtein_distance(flag, input);
levenshtein_distance(flag.begin(), flag.end(), input.begin(), input.end());
if(d < result.distance) if(d < result.distance)
result = result_t{&arg, flag, input, d}; result = result_t{&arg, flag, input, d};
} }
......
...@@ -32,6 +32,7 @@ ...@@ -32,6 +32,7 @@
#include <migraphx/tf.hpp> #include <migraphx/tf.hpp>
#include <migraphx/onnx.hpp> #include <migraphx/onnx.hpp>
#include <migraphx/py.hpp>
#include <migraphx/stringutils.hpp> #include <migraphx/stringutils.hpp>
#include <migraphx/convert_to_json.hpp> #include <migraphx/convert_to_json.hpp>
#include <migraphx/load_save.hpp> #include <migraphx/load_save.hpp>
...@@ -81,6 +82,7 @@ struct loader ...@@ -81,6 +82,7 @@ struct loader
{"--model"}, {"--model"},
ap.help("Load model"), ap.help("Load model"),
ap.type("resnet50|inceptionv3|alexnet"), ap.type("resnet50|inceptionv3|alexnet"),
ap.matches({"resnet50", "inceptionv3", "alexnet"}),
ap.group("input")); ap.group("input"));
ap(file_type, {"--onnx"}, ap.help("Load as onnx"), ap.set_value("onnx")); ap(file_type, {"--onnx"}, ap.help("Load as onnx"), ap.set_value("onnx"));
ap(file_type, {"--tf"}, ap.help("Load as tensorflow"), ap.set_value("tf")); ap(file_type, {"--tf"}, ap.help("Load as tensorflow"), ap.set_value("tf"));
...@@ -241,6 +243,20 @@ struct loader ...@@ -241,6 +243,20 @@ struct loader
return options; return options;
} }
/// Maps a file name to a loader format by its suffix:
/// ".onnx" -> "onnx", ".pb" -> "tf", ".json" -> "json", ".py" -> "py".
/// Any other name falls back to "migraphx".
static std::string get_file_type(const std::string& file)
{
    if(ends_with(file, ".onnx"))
        return "onnx";
    if(ends_with(file, ".pb"))
        return "tf";
    if(ends_with(file, ".json"))
        return "json";
    if(ends_with(file, ".py"))
        return "py";
    return "migraphx";
}
program load() program load()
{ {
program p; program p;
...@@ -248,14 +264,7 @@ struct loader ...@@ -248,14 +264,7 @@ struct loader
{ {
if(file_type.empty()) if(file_type.empty())
{ {
if(ends_with(file, ".onnx")) file_type = get_file_type(file);
file_type = "onnx";
else if(ends_with(file, ".pb"))
file_type = "tf";
else if(ends_with(file, ".json"))
file_type = "json";
else
file_type = "migraphx";
} }
std::cout << "Reading: " << file << std::endl; std::cout << "Reading: " << file << std::endl;
if(file_type == "onnx") if(file_type == "onnx")
...@@ -272,6 +281,10 @@ struct loader ...@@ -272,6 +281,10 @@ struct loader
options.format = "json"; options.format = "json";
p = migraphx::load(file, options); p = migraphx::load(file, options);
} }
else if(file_type == "py")
{
p = migraphx::load_py(file);
}
else if(file_type == "migraphx") else if(file_type == "migraphx")
{ {
p = migraphx::load(file); p = migraphx::load(file);
...@@ -415,7 +428,8 @@ struct compiler ...@@ -415,7 +428,8 @@ struct compiler
program_params parameters; program_params parameters;
compiler_target ct; compiler_target ct;
compile_options co; compile_options co;
precision quantize = precision::fp32; bool to_fp16 = false;
bool to_int8 = false;
std::vector<std::string> fill0; std::vector<std::string> fill0;
std::vector<std::string> fill1; std::vector<std::string> fill1;
...@@ -436,13 +450,8 @@ struct compiler ...@@ -436,13 +450,8 @@ struct compiler
{"--exhaustive-tune"}, {"--exhaustive-tune"},
ap.help("Exhastively search for best tuning parameters for kernels"), ap.help("Exhastively search for best tuning parameters for kernels"),
ap.set_value(true)); ap.set_value(true));
ap(co.split_single_dyn_dim, ap(to_fp16, {"--fp16"}, ap.help("Quantize for fp16"), ap.set_value(true));
{"--split-single-dyn-dim"}, ap(to_int8, {"--int8"}, ap.help("Quantize for int8"), ap.set_value(true));
ap.help("If there is a single non-fixed dynamic dimension in the model, then split to "
"static submodules"),
ap.set_value(true));
ap(quantize, {"--fp16"}, ap.help("Quantize for fp16"), ap.set_value(precision::fp16));
ap(quantize, {"--int8"}, ap.help("Quantize for int8"), ap.set_value(precision::int8));
} }
auto params(const program& p) auto params(const program& p)
...@@ -450,20 +459,46 @@ struct compiler ...@@ -450,20 +459,46 @@ struct compiler
return parameters.generate(p, ct.get_target(), co.offload_copy, l.batch); return parameters.generate(p, ct.get_target(), co.offload_copy, l.batch);
} }
// Generates parameters for `p` with the same call as params(), except that the
// third argument is always true instead of co.offload_copy.
auto host_params(const program& p)
{
    const bool always_offload = true;
    return parameters.generate(p, ct.get_target(), always_offload, l.batch);
}
program compile() program compile()
{ {
auto p = l.load(); auto p = l.load();
// Dont compile if its already been compiled // Dont compile if its already been compiled
if(p.is_compiled()) if(p.is_compiled())
{
if(ct.target_name == "gpu")
{
if(is_offload_copy_set(p) and not co.offload_copy)
{
std::cout << "MIGraphX program was likely compiled with offload_copy set, Try "
"passing "
"`--enable-offload-copy` if program run fails.\n";
}
else if(co.offload_copy)
{
std::cout << "MIGraphX program was likely compiled without "
"offload_copy set, Try "
"removing "
"`--enable-offload-copy` flag if passed to driver, if program run "
"fails.\n";
}
}
return p; return p;
}
auto t = ct.get_target(); auto t = ct.get_target();
if(quantize == precision::fp16) if(to_fp16)
{ {
quantize_fp16(p); quantize_fp16(p);
} }
else if(quantize == precision::int8) if(to_int8)
{ {
quantize_int8(p, t, {params(p)}); quantize_int8(p, t, {host_params(p)});
} }
p.compile(t, co); p.compile(t, co);
l.save(p); l.save(p);
...@@ -522,17 +557,23 @@ struct verify : command<verify> ...@@ -522,17 +557,23 @@ struct verify : command<verify>
auto t = c.ct.get_target(); auto t = c.ct.get_target();
auto m = c.parameters.generate(p, t, true, c.l.batch); auto m = c.parameters.generate(p, t, true, c.l.batch);
auto quantize = precision::fp32;
if(c.to_fp16)
quantize = precision::fp16;
if(c.to_int8)
quantize = precision::int8;
if(per_instruction) if(per_instruction)
{ {
verify_instructions(p, t, c.co, c.quantize, tolerance); verify_instructions(p, t, c.co, quantize, tolerance);
} }
else if(reduce) else if(reduce)
{ {
verify_reduced_program(p, t, c.co, c.quantize, m, tolerance); verify_reduced_program(p, t, c.co, quantize, m, tolerance);
} }
else else
{ {
verify_program(c.l.file, p, t, c.co, c.quantize, m, tolerance); verify_program(c.l.file, p, t, c.co, quantize, m, tolerance);
} }
} }
}; };
...@@ -662,6 +703,26 @@ struct onnx : command<onnx> ...@@ -662,6 +703,26 @@ struct onnx : command<onnx>
} }
}; };
// Command exposing TensorFlow-related queries. With --list / -l it prints every
// name returned by get_tf_operators(), one per line; otherwise it does nothing.
struct tf : command<tf>
{
    bool show_ops = false;

    // Registers the --list / -l flag, which switches show_ops on.
    void parse(argument_parser& ap)
    {
        ap(show_ops,
           {"--list", "-l"},
           ap.help("List all tf operators supported by MIGraphX"),
           ap.set_value(true));
    }

    void run() const
    {
        if(not show_ops)
            return;
        for(const auto& op_name : get_tf_operators())
            std::cout << op_name << std::endl;
    }
};
struct main_command struct main_command
{ {
static std::string get_command_help(const std::string& title = colorize(color::fg_yellow, static std::string get_command_help(const std::string& title = colorize(color::fg_yellow,
...@@ -709,7 +770,7 @@ struct main_command ...@@ -709,7 +770,7 @@ struct main_command
{ {
std::cout << "'" << color::fg_yellow << wrong_commands.front() << color::reset std::cout << "'" << color::fg_yellow << wrong_commands.front() << color::reset
<< "' is not a valid command." << std::endl; << "' is not a valid command." << std::endl;
std::cout << get_command_help("Available commands:") << std::endl; std::cout << get_command_help("Available commands:");
} }
else else
{ {
......
...@@ -24,6 +24,8 @@ ...@@ -24,6 +24,8 @@
#include "perf.hpp" #include "perf.hpp"
#include <migraphx/generate.hpp> #include <migraphx/generate.hpp>
#include <migraphx/instruction.hpp>
#include <migraphx/instruction_ref.hpp>
#include <migraphx/register_target.hpp> #include <migraphx/register_target.hpp>
#ifdef HAVE_GPU #ifdef HAVE_GPU
#include <migraphx/gpu/hip.hpp> #include <migraphx/gpu/hip.hpp>
...@@ -97,6 +99,38 @@ target get_target(bool gpu) ...@@ -97,6 +99,38 @@ target get_target(bool gpu)
return make_target("cpu"); return make_target("cpu");
} }
// Heuristic check of whether a compiled program looks like it was compiled with
// offload_copy. Returns true only if, after scanning the main module, every parameter
// instruction has been erased from the tracking set and no "@return" input was rejected.
bool is_offload_copy_set(const program& p)
{
assert(p.is_compiled());
const module* mm = p.get_main_module();
std::vector<std::string> param_names = mm->get_parameter_names();
// Collect the instruction of every main-module parameter; entries are erased below
// as they are matched.
std::unordered_set<instruction_ref> param_ins;
std::transform(param_names.begin(),
param_names.end(),
std::inserter(param_ins, param_ins.begin()),
[&](const auto& i) { return mm->get_parameter(i); });
for(const auto& i : *mm)
{
if(i.name() == "hip::copy_to_gpu")
{
// Erase whatever get_output_alias reports for the copy's first input
// (a no-op if that instruction is not a tracked parameter).
auto copy_arg = instruction::get_output_alias(i.inputs().front(), true);
param_ins.erase(copy_arg);
}
else if(i.name() == "@return")
{
auto return_args = i.inputs();
for(const auto& j : return_args)
{
auto alias_ins = instruction::get_output_alias(j, true);
// Reject as soon as a returned value's alias is not a "hip::copy_from_gpu".
// NOTE(review): an "@param" alias always satisfies the second clause as well, so the
// first clause only contributes its erase() side effect - confirm this is intended.
if((alias_ins->name() == "@param" && param_ins.erase(alias_ins) == 0) or
(alias_ins->name() != "hip::copy_from_gpu"))
return false;
}
}
}
// Any parameter still in the set was never matched by a "hip::copy_to_gpu".
return param_ins.empty();
}
} // namespace MIGRAPHX_INLINE_NS } // namespace MIGRAPHX_INLINE_NS
} // namespace driver } // namespace driver
} // namespace migraphx } // namespace migraphx
...@@ -39,6 +39,15 @@ parameter_map create_param_map(const program& p, const target& t, bool offload = ...@@ -39,6 +39,15 @@ parameter_map create_param_map(const program& p, const target& t, bool offload =
parameter_map fill_param_map(parameter_map& m, const program& p, bool gpu); parameter_map fill_param_map(parameter_map& m, const program& p, bool gpu);
parameter_map create_param_map(const program& p, bool gpu = true); parameter_map create_param_map(const program& p, bool gpu = true);
target get_target(bool gpu); target get_target(bool gpu);
/**
 * @brief Checks if MIGraphX program compiled for "GPU" has offload_copy set or not. This is
 * intended to print a HINT for the users and would not always correctly classify compiled
 * program as with or without offload_copy in all cases.
 * @param p Compiled MIGraphX program for GPU backend
 * @return true if program is classified as compiled with "offload_copy" set
 */
bool is_offload_copy_set(const program& p);
} // namespace MIGRAPHX_INLINE_NS } // namespace MIGRAPHX_INLINE_NS
} // namespace driver } // namespace driver
......
...@@ -48,7 +48,7 @@ struct dynamic_loader_impl ...@@ -48,7 +48,7 @@ struct dynamic_loader_impl
#pragma GCC diagnostic ignored "-Wignored-attributes" #pragma GCC diagnostic ignored "-Wignored-attributes"
#endif #endif
dynamic_loader_impl(const fs::path& p, std::shared_ptr<tmp_dir> t = nullptr) dynamic_loader_impl(const fs::path& p, std::shared_ptr<tmp_dir> t = nullptr)
: handle(dlopen(p.string().c_str(), RTLD_LAZY), : handle(dlopen(p.string().c_str(), RTLD_GLOBAL | RTLD_NOW),
manage_deleter<decltype(&dlclose), &dlclose>{}), manage_deleter<decltype(&dlclose), &dlclose>{}),
temp(std::move(t)) temp(std::move(t))
{ {
...@@ -81,6 +81,18 @@ fs::path dynamic_loader::path(void* address) ...@@ -81,6 +81,18 @@ fs::path dynamic_loader::path(void* address)
return p; return p;
} }
/// Attempts to construct a dynamic_loader for `p`. If construction throws a
/// std::exception, the error is discarded and nullopt is returned instead.
optional<dynamic_loader> dynamic_loader::try_load(const fs::path& p)
{
    try
    {
        dynamic_loader loaded{p};
        return loaded;
    }
    catch(const std::exception&)
    {
        // Failure to load is reported as "no value" rather than propagated.
        return nullopt;
    }
}
dynamic_loader::dynamic_loader(const fs::path& p) : impl(std::make_shared<dynamic_loader_impl>(p)) dynamic_loader::dynamic_loader(const fs::path& p) : impl(std::make_shared<dynamic_loader_impl>(p))
{ {
} }
......
...@@ -31,6 +31,8 @@ ...@@ -31,6 +31,8 @@
#include <migraphx/ranges.hpp> #include <migraphx/ranges.hpp>
#include <iterator> #include <iterator>
MIGRAPHX_DECLARE_ENV_VAR(MIGRAPHX_DISABLE_POINTWISE_FUSION)
namespace migraphx { namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS { inline namespace MIGRAPHX_INLINE_NS {
...@@ -67,13 +69,13 @@ static void create_pointwise_modules(module_pass_manager& mpm) ...@@ -67,13 +69,13 @@ static void create_pointwise_modules(module_pass_manager& mpm)
continue; continue;
if(ins->get_operator().name() == "layout") if(ins->get_operator().name() == "layout")
continue; continue;
assert(ins->get_operator().attributes().contains("point_op"));
auto* pm = mpm.create_module(mpm.get_module().name() + ":pointwise" + std::to_string(n++)); auto* pm = mpm.create_module(mpm.get_module().name() + ":pointwise" + std::to_string(n++));
pm->set_bypass(); pm->set_bypass();
std::unordered_map<instruction_ref, instruction_ref> param_map; std::unordered_map<instruction_ref, instruction_ref> param_map;
std::vector<instruction_ref> pointwise_inputs; std::vector<instruction_ref> pointwise_inputs;
std::size_t i = 0; std::size_t i = 0;
for(auto input : ins->inputs()) for(auto input : ins->inputs())
{ {
if(contains(param_map, input)) if(contains(param_map, input))
...@@ -92,6 +94,10 @@ static void create_pointwise_modules(module_pass_manager& mpm) ...@@ -92,6 +94,10 @@ static void create_pointwise_modules(module_pass_manager& mpm)
} }
} }
// Don't create pointwise module if no inputs are detected
if(pointwise_inputs.empty())
continue;
std::vector<instruction_ref> inputs; std::vector<instruction_ref> inputs;
std::transform(ins->inputs().begin(), std::transform(ins->inputs().begin(),
ins->inputs().end(), ins->inputs().end(),
...@@ -188,6 +194,10 @@ void fuse_pointwise::apply(module_pass_manager& mpm) const ...@@ -188,6 +194,10 @@ void fuse_pointwise::apply(module_pass_manager& mpm) const
{ {
create_pointwise_modules(mpm); create_pointwise_modules(mpm);
mpm.run_pass(dead_code_elimination{}); mpm.run_pass(dead_code_elimination{});
if(enabled(MIGRAPHX_DISABLE_POINTWISE_FUSION{}))
{
return;
}
for(int i = 0; i < 8; i++) for(int i = 0; i < 8; i++)
{ {
if(not find_pointwise_modules(mpm.get_module())) if(not find_pointwise_modules(mpm.get_module()))
......
...@@ -26,7 +26,7 @@ ...@@ -26,7 +26,7 @@
namespace migraphx { namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS { inline namespace MIGRAPHX_INLINE_NS {
argument fill_argument(shape s, unsigned long value) argument fill_argument(shape s, double value)
{ {
argument result; argument result;
if(s.type() == shape::tuple_type) if(s.type() == shape::tuple_type)
......
...@@ -32,7 +32,7 @@ inline namespace MIGRAPHX_INLINE_NS { ...@@ -32,7 +32,7 @@ inline namespace MIGRAPHX_INLINE_NS {
struct module; struct module;
struct adjust_allocation struct MIGRAPHX_EXPORT adjust_allocation
{ {
allocation_model model; allocation_model model;
std::string name() const { return "adjust_allocation"; } std::string name() const { return "adjust_allocation"; }
......
...@@ -90,6 +90,43 @@ levenshtein_distance(Iterator1 first1, Iterator1 last1, Iterator2 first2, Iterat ...@@ -90,6 +90,43 @@ levenshtein_distance(Iterator1 first1, Iterator1 last1, Iterator2 first2, Iterat
return std::ptrdiff_t{1} + std::min({x1, x2, x3}); return std::ptrdiff_t{1} + std::min({x1, x2, x3});
} }
/// Computes the Levenshtein (edit) distance between two strings: the minimum
/// number of single-character insertions, deletions and substitutions needed
/// to turn `s1` into `s2`.
///
/// Uses a single rolling row of size min(|s1|, |s2|) + 1.
///
/// @param s1 First string
/// @param s2 Second string
/// @return The edit distance; 0 when the strings are equal
inline size_t levenshtein_distance(const std::string& s1, const std::string& s2)
{
    const size_t l1 = s1.length();
    const size_t l2 = s2.length();

    // The distance is symmetric, so iterate with the shorter string as the row.
    // The result of the swapped call must be returned, not discarded.
    if(l1 < l2)
        return levenshtein_distance(s2, s1);

    // d[j] holds the distance between the first i chars of s1 and the first j chars of s2.
    std::vector<size_t> d(l2 + 1);
    for(size_t j = 1; j <= l2; j++)
        d[j] = j;

    for(size_t i = 1; i <= l1; i++)
    {
        // Value of the previous row at column j - 1 (the "diagonal" cell).
        size_t diagonal = d[0];
        d[0]            = i;
        for(size_t j = 1; j <= l2; j++)
        {
            // Previous-row value at column j; it becomes the diagonal for column j + 1
            // regardless of whether the characters match.
            const size_t above = d[j];
            if(s1[i - 1] == s2[j - 1])
                d[j] = diagonal;
            else
                d[j] = std::min({diagonal, d[j - 1], above}) + 1;
            diagonal = above;
        }
    }

    return d[l2];
}
} // namespace MIGRAPHX_INLINE_NS } // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx } // namespace migraphx
......
...@@ -60,7 +60,7 @@ struct allocation_model ...@@ -60,7 +60,7 @@ struct allocation_model
#ifdef TYPE_ERASED_DECLARATION #ifdef TYPE_ERASED_DECLARATION
// Type-erased interface for: // Type-erased interface for:
struct allocation_model struct MIGRAPHX_EXPORT allocation_model
{ {
// //
std::string name() const; std::string name() const;
...@@ -96,7 +96,7 @@ struct allocation_model ...@@ -96,7 +96,7 @@ struct allocation_model
{ {
using std::swap; using std::swap;
auto* derived = this->any_cast<PrivateDetailTypeErasedT>(); auto* derived = this->any_cast<PrivateDetailTypeErasedT>();
if(derived and private_detail_te_handle_mem_var.unique()) if(derived and private_detail_te_handle_mem_var.use_count() == 1)
{ {
*derived = std::forward<PrivateDetailTypeErasedT>(value); *derived = std::forward<PrivateDetailTypeErasedT>(value);
} }
...@@ -267,7 +267,7 @@ struct allocation_model ...@@ -267,7 +267,7 @@ struct allocation_model
private_detail_te_handle_base_type& private_detail_te_get_handle() private_detail_te_handle_base_type& private_detail_te_get_handle()
{ {
assert(private_detail_te_handle_mem_var != nullptr); assert(private_detail_te_handle_mem_var != nullptr);
if(not private_detail_te_handle_mem_var.unique()) if(private_detail_te_handle_mem_var.use_count() > 1)
private_detail_te_handle_mem_var = private_detail_te_handle_mem_var->clone(); private_detail_te_handle_mem_var = private_detail_te_handle_mem_var->clone();
return *private_detail_te_handle_mem_var; return *private_detail_te_handle_mem_var;
} }
......
...@@ -39,7 +39,8 @@ struct stream_race ...@@ -39,7 +39,8 @@ struct stream_race
instruction_ref before; instruction_ref before;
}; };
std::vector<stream_race> analyze_streams(const module& m, const stream_model& strmm); MIGRAPHX_EXPORT std::vector<stream_race> analyze_streams(const module& m,
const stream_model& strmm);
} // namespace MIGRAPHX_INLINE_NS } // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx } // namespace migraphx
......
...@@ -33,6 +33,7 @@ ...@@ -33,6 +33,7 @@
namespace migraphx { namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS { inline namespace MIGRAPHX_INLINE_NS {
MIGRAPHX_EXPORT
instruction_ref insert_apply_alpha_beta(module& m, instruction_ref insert_apply_alpha_beta(module& m,
instruction_ref pos, instruction_ref pos,
const std::vector<instruction_ref>& args, const std::vector<instruction_ref>& args,
......
...@@ -42,7 +42,7 @@ inline namespace MIGRAPHX_INLINE_NS { ...@@ -42,7 +42,7 @@ inline namespace MIGRAPHX_INLINE_NS {
* or it can be owned by the argument. * or it can be owned by the argument.
* *
*/ */
struct argument : raw_data<argument> struct MIGRAPHX_EXPORT argument : raw_data<argument>
{ {
argument() = default; argument() = default;
...@@ -93,6 +93,16 @@ struct argument : raw_data<argument> ...@@ -93,6 +93,16 @@ struct argument : raw_data<argument>
/// Return the ith element /// Return the ith element
argument element(std::size_t i) const; argument element(std::size_t i) const;
/// Copies the range [start, end) into this argument, writing from the beginning
/// of the view passed to the visit callback.
// Keeps the same data ordering as the given container
template <class Iterator>
void fill(Iterator start, Iterator end)
{
// Debug-only guard: the source range must not contain more items than the shape has elements.
assert(std::distance(start, end) <= m_shape.elements());
this->visit([&](auto output) {
// `output` is whatever view visit() supplies; elements are written in source order.
std::copy(start, end, output.begin());
});
}
private: private:
void assign_buffer(std::function<char*()> d); void assign_buffer(std::function<char*()> d);
struct data_t struct data_t
...@@ -107,9 +117,9 @@ struct argument : raw_data<argument> ...@@ -107,9 +117,9 @@ struct argument : raw_data<argument>
data_t m_data{}; data_t m_data{};
}; };
std::vector<shape> to_shapes(const std::vector<argument>& args); MIGRAPHX_EXPORT std::vector<shape> to_shapes(const std::vector<argument>& args);
void migraphx_to_value(value& v, const argument& a); MIGRAPHX_EXPORT void migraphx_to_value(value& v, const argument& a);
void migraphx_from_value(const value& v, argument& a); MIGRAPHX_EXPORT void migraphx_from_value(const value& v, argument& a);
} // namespace MIGRAPHX_INLINE_NS } // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx } // namespace migraphx
......
...@@ -42,7 +42,7 @@ void any_cast() ...@@ -42,7 +42,7 @@ void any_cast()
template <class T> template <class T>
struct auto_any_caster struct auto_any_caster
{ {
T& x; T& x; // NOLINT
template <class U> template <class U>
operator U&() operator U&()
......
...@@ -33,7 +33,7 @@ inline namespace MIGRAPHX_INLINE_NS { ...@@ -33,7 +33,7 @@ inline namespace MIGRAPHX_INLINE_NS {
struct module; struct module;
struct auto_contiguous struct MIGRAPHX_EXPORT auto_contiguous
{ {
std::string name() const { return "auto_contiguous"; } std::string name() const { return "auto_contiguous"; }
void apply(module& m) const; void apply(module& m) const;
......
...@@ -90,7 +90,17 @@ struct param ...@@ -90,7 +90,17 @@ struct param
struct returns struct returns
{ {
std::string name() const { return "@return"; } std::string name() const { return "@return"; }
shape compute_shape(const std::vector<shape>&) const { return {}; }
// Shape of "@return": a default-constructed shape when there are no inputs, the
// input's own shape when there is exactly one, and otherwise a shape converted
// from the whole list of input shapes.
shape compute_shape(const std::vector<shape>& arg) const
{
    switch(arg.size())
    {
    case 0: return {};
    case 1: return arg.front();
    default: return arg;
    }
}
argument compute(context&, const shape&, const std::vector<argument>&) const argument compute(context&, const shape&, const std::vector<argument>&) const
{ {
MIGRAPHX_THROW("builtin"); MIGRAPHX_THROW("builtin");
......
...@@ -38,8 +38,8 @@ struct check_shapes ...@@ -38,8 +38,8 @@ struct check_shapes
{ {
const shape* begin; const shape* begin;
const shape* end; const shape* end;
const std::string name; std::string name;
const bool dynamic_allowed; bool dynamic_allowed;
check_shapes(const shape* b, const shape* e, const std::string& n, const bool d = false) check_shapes(const shape* b, const shape* e, const std::string& n, const bool d = false)
: begin(b), end(e), name(n), dynamic_allowed(d) : begin(b), end(e), name(n), dynamic_allowed(d)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment