Unverified Commit 072fd5cc authored by Paul Fultz II's avatar Paul Fultz II Committed by GitHub
Browse files

Enable eval to handle multiple contexts (#1751)

This is to help enable multi-target execution. We store a vector of targets and contexts. Currently this will only compile a single target, the PR #1672 is needed to enable multiple targets.

This will also serialize the targets and contexts.

When using the execution_environment or prog.get_context() it will always use the context from the first target assuming this is the "primary" target. Although, its unlikely a user would use execution_environment with a multi-target environment.
parent 697709a7
...@@ -94,6 +94,7 @@ add_library(migraphx ...@@ -94,6 +94,7 @@ add_library(migraphx
simplify_algebra.cpp simplify_algebra.cpp
simplify_reshapes.cpp simplify_reshapes.cpp
split_single_dyn_dim.cpp split_single_dyn_dim.cpp
target.cpp
tmp_dir.cpp tmp_dir.cpp
value.cpp value.cpp
verify_args.cpp verify_args.cpp
......
...@@ -137,6 +137,7 @@ struct instruction ...@@ -137,6 +137,7 @@ struct instruction
operation normalized_operator() const; operation normalized_operator() const;
std::size_t get_target_id() const; std::size_t get_target_id() const;
void set_target_id(std::size_t tid); void set_target_id(std::size_t tid);
void debug_print() const; void debug_print() const;
......
...@@ -189,7 +189,7 @@ struct module ...@@ -189,7 +189,7 @@ struct module
instruction_ref validate() const; instruction_ref validate() const;
instruction_ref find_dangling_reference() const; instruction_ref find_dangling_reference() const;
void finalize(context& ctx); void finalize(std::vector<context>& contexts);
void debug_print() const; void debug_print() const;
void debug_print(instruction_ref ins) const; void debug_print(instruction_ref ins) const;
......
...@@ -261,11 +261,13 @@ auto compute_op(rank<1>, ...@@ -261,11 +261,13 @@ auto compute_op(rank<1>,
template <class T, class F> template <class T, class F>
argument compute_op(rank<0>, argument compute_op(rank<0>,
const T& x, const T& x,
const shape&, const shape& output,
const std::vector<argument>&, const std::vector<argument>& inputs,
const std::vector<module_ref>&, const std::vector<module_ref>& module_args,
F) F)
{ {
if(module_args.empty())
return compute_op(x, output, inputs);
std::string name = x.name(); std::string name = x.name();
MIGRAPHX_THROW("Not computable: " + name); MIGRAPHX_THROW("Not computable: " + name);
} }
......
...@@ -79,6 +79,9 @@ struct program ...@@ -79,6 +79,9 @@ struct program
std::vector<argument> eval(parameter_map params, std::vector<argument> eval(parameter_map params,
execution_environment exec_env = execution_environment{}) const; execution_environment exec_env = execution_environment{}) const;
void finish() const;
std::size_t size() const; std::size_t size() const;
std::vector<shape> get_output_shapes() const; std::vector<shape> get_output_shapes() const;
......
...@@ -45,6 +45,8 @@ ...@@ -45,6 +45,8 @@
namespace migraphx { namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS { inline namespace MIGRAPHX_INLINE_NS {
struct value;
#ifdef DOXYGEN #ifdef DOXYGEN
/// An interface for a compilation target /// An interface for a compilation target
...@@ -467,6 +469,9 @@ inline const ValueType& any_cast(const target& x) ...@@ -467,6 +469,9 @@ inline const ValueType& any_cast(const target& x)
#endif #endif
void migraphx_to_value(value& v, const target& t);
void migraphx_from_value(const value& v, target& t);
} // namespace MIGRAPHX_INLINE_NS } // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx } // namespace migraphx
......
...@@ -473,7 +473,9 @@ operation instruction::normalized_operator() const ...@@ -473,7 +473,9 @@ operation instruction::normalized_operator() const
return o; return o;
} }
std::size_t instruction::get_target_id() const { return target_id; } std::size_t instruction::get_target_id() const { return target_id; }
void instruction::set_target_id(std::size_t tid) { this->target_id = tid; } void instruction::set_target_id(std::size_t tid) { this->target_id = tid; }
std::vector<shape> to_shapes(const std::vector<instruction_ref>& args) std::vector<shape> to_shapes(const std::vector<instruction_ref>& args)
{ {
std::vector<shape> shapes(args.size()); std::vector<shape> shapes(args.size());
......
...@@ -652,8 +652,9 @@ instruction_ref module::find_dangling_reference() const ...@@ -652,8 +652,9 @@ instruction_ref module::find_dangling_reference() const
return end(); return end();
} }
void module::finalize(context& ctx) void module::finalize(std::vector<context>& contexts)
{ {
assert(not contexts.empty());
const bool trace = enabled(MIGRAPHX_TRACE_FINALIZE{}); const bool trace = enabled(MIGRAPHX_TRACE_FINALIZE{});
for(auto ins : iterator_for(*this)) for(auto ins : iterator_for(*this))
{ {
...@@ -662,10 +663,10 @@ void module::finalize(context& ctx) ...@@ -662,10 +663,10 @@ void module::finalize(context& ctx)
std::cout << "Finalize: "; std::cout << "Finalize: ";
this->debug_print(ins); this->debug_print(ins);
} }
ins->finalize(ctx); ins->finalize(contexts[ins->get_target_id()]);
for(const auto& smod : ins->module_inputs()) for(const auto& smod : ins->module_inputs())
{ {
smod->finalize(ctx); smod->finalize(contexts);
} }
} }
......
...@@ -70,9 +70,8 @@ struct program_impl ...@@ -70,9 +70,8 @@ struct program_impl
{ {
// A map is used to keep references to modules of the program // A map is used to keep references to modules of the program
std::unordered_map<std::string, module> modules; std::unordered_map<std::string, module> modules;
context ctx;
std::string target_name;
std::vector<context> contexts; std::vector<context> contexts;
std::vector<target> targets;
}; };
program::program() : impl(std::make_unique<program_impl>()) { this->create_module("main"); } program::program() : impl(std::make_unique<program_impl>()) { this->create_module("main"); }
...@@ -96,14 +95,8 @@ void program::assign(const program& p) ...@@ -96,14 +95,8 @@ void program::assign(const program& p)
{ {
impl = std::make_unique<program_impl>(); impl = std::make_unique<program_impl>();
} }
else if(not impl->modules.empty())
{
impl->modules.clear();
}
impl->ctx = p.impl->ctx; *impl = *p.impl;
impl->target_name = p.impl->target_name;
impl->modules = p.impl->modules;
// build a map from old ins to new ins // build a map from old ins to new ins
// Build a map from old module to new module // Build a map from old module to new module
...@@ -166,7 +159,11 @@ std::vector<shape> program::get_output_shapes() const ...@@ -166,7 +159,11 @@ std::vector<shape> program::get_output_shapes() const
return mm->get_output_shapes(); return mm->get_output_shapes();
} }
context& program::get_context() const { return impl->ctx; } context& program::get_context() const
{
assert(impl->contexts.size() == 1);
return impl->contexts.front();
}
instruction_ref program::validate() const instruction_ref program::validate() const
{ {
...@@ -217,7 +214,7 @@ target_assignments program::get_target_assignments(const std::vector<target>& ta ...@@ -217,7 +214,7 @@ target_assignments program::get_target_assignments(const std::vector<target>& ta
return p; return p;
} }
bool program::is_compiled() const { return not this->impl->target_name.empty(); } bool program::is_compiled() const { return not this->impl->contexts.empty(); }
void program::compile(const std::vector<target>& targets, std::vector<compile_options> compile_opts) void program::compile(const std::vector<target>& targets, std::vector<compile_options> compile_opts)
{ {
...@@ -299,24 +296,24 @@ void program::compile(const std::vector<target>& targets, std::vector<compile_op ...@@ -299,24 +296,24 @@ void program::compile(const std::vector<target>& targets, std::vector<compile_op
MIGRAPHX_THROW("Dangling reference in module " + current_mod->name() + MIGRAPHX_THROW("Dangling reference in module " + current_mod->name() +
" from instruction " + std::to_string(index)); " from instruction " + std::to_string(index));
} }
current_mod->finalize(this->impl->contexts[root_target_id]);
} }
} }
this->finalize();
} }
void program::compile(const target& t, compile_options options) void program::compile(const target& t, compile_options options)
{ {
// todo: combine with multi-target compile method // todo: combine with multi-target compile method
assert(not this->is_compiled()); assert(not this->is_compiled());
this->impl->target_name = t.name(); this->impl->targets = {t};
this->impl->ctx = t.get_context(); this->impl->contexts = {t.get_context()};
if(enabled(MIGRAPHX_TRACE_COMPILE{})) if(enabled(MIGRAPHX_TRACE_COMPILE{}))
options.trace = tracer{std::cout}; options.trace = tracer{std::cout};
options.trace(*this); options.trace(*this);
options.trace(); options.trace();
auto&& passes = t.get_passes(this->impl->ctx, options); auto&& passes = t.get_passes(this->impl->contexts.front(), options);
run_passes(*this, passes, options.trace); run_passes(*this, passes, options.trace);
auto mods = this->get_modules(); auto mods = this->get_modules();
// Validate and finalize // Validate and finalize
...@@ -335,14 +332,14 @@ void program::compile(const target& t, compile_options options) ...@@ -335,14 +332,14 @@ void program::compile(const target& t, compile_options options)
MIGRAPHX_THROW("Dangling reference in module " + mod->name() + " from instruction " + MIGRAPHX_THROW("Dangling reference in module " + mod->name() + " from instruction " +
std::to_string(index)); std::to_string(index));
} }
mod->finalize(this->impl->ctx); mod->finalize(this->impl->contexts);
} }
} }
void program::finalize() void program::finalize()
{ {
auto* mm = this->get_main_module(); auto* mm = this->get_main_module();
mm->finalize(this->impl->ctx); mm->finalize(this->impl->contexts);
} }
template <class T> template <class T>
...@@ -429,16 +426,15 @@ void preview_argument(std::ostream& os, const argument& a) ...@@ -429,16 +426,15 @@ void preview_argument(std::ostream& os, const argument& a)
template <class F> template <class F>
std::vector<argument> generic_eval(const module* mod, std::vector<argument> generic_eval(const module* mod,
context& ctx, std::vector<context>& ctx,
std::unordered_map<std::string, argument> params, std::unordered_map<std::string, argument> params,
std::unordered_map<instruction_ref, argument> results, std::unordered_map<instruction_ref, argument> results,
F make_trace) F trace)
{ {
assert(mod->validate() == mod->end()); assert(mod->validate() == mod->end());
results.reserve(mod->size() * 2); results.reserve(mod->size() * 2);
std::vector<argument> values; std::vector<argument> values;
values.reserve(16); values.reserve(16);
auto trace = make_trace(mod);
for(auto ins : iterator_for(*mod)) for(auto ins : iterator_for(*mod))
{ {
assert(results.find(ins) == results.end()); assert(results.find(ins) == results.end());
...@@ -494,13 +490,18 @@ std::vector<argument> generic_eval(const module* mod, ...@@ -494,13 +490,18 @@ std::vector<argument> generic_eval(const module* mod,
const auto& mod_args = ins->module_inputs(); const auto& mod_args = ins->module_inputs();
auto module_eval = [&](module_ref smod, auto module_eval = [&](module_ref smod,
const std::unordered_map<std::string, argument>& inputs) { const std::unordered_map<std::string, argument>& inputs) {
auto ssctx = ctx; return generic_eval(smod, ctx, inputs, results, trace);
return generic_eval(smod, ssctx, inputs, results, make_trace);
}; };
results.emplace(ins, trace(ins, [&] { results.emplace(
return ins->normalized_operator().compute( ins, trace(ins, [&] {
ctx, ins->get_shape(), values, mod_args, module_eval); auto op = ins->normalized_operator();
if(op.is_context_free())
return op.compute(ins->get_shape(), values, mod_args, module_eval);
if(ins->get_target_id() >= ctx.size())
MIGRAPHX_THROW("No context available for " + op.name());
return op.compute(
ctx[ins->get_target_id()], ins->get_shape(), values, mod_args, module_eval);
})); }));
} }
assert(results.find(ins) != results.end()); assert(results.find(ins) != results.end());
...@@ -514,44 +515,25 @@ std::vector<argument> generic_eval(const module* mod, ...@@ -514,44 +515,25 @@ std::vector<argument> generic_eval(const module* mod,
template <class F> template <class F>
std::vector<argument> generic_eval(const program& p, std::vector<argument> generic_eval(const program& p,
context& ctx, std::vector<context>& ctx,
std::unordered_map<std::string, argument> params, std::unordered_map<std::string, argument> params,
F make_trace) F trace)
{ {
const module* mm = p.get_main_module(); const module* mm = p.get_main_module();
return generic_eval(mm, ctx, params, {}, make_trace); return generic_eval(mm, ctx, params, {}, trace);
} }
std::vector<argument> program::eval(parameter_map params, execution_environment exec_env) const std::vector<argument> program::eval(parameter_map params, execution_environment exec_env) const
{ {
auto& ctx = this->impl->ctx; auto& contexts = this->impl->contexts;
#ifndef NDEBUG
auto with_check_context = [&](auto f) {
return [=, &ctx](auto&&) {
auto sctx = std::make_shared<context>(ctx);
auto check_context = [=, &ctx](auto g) {
assert(is_shared(ctx, *sctx));
auto x = g();
*sctx = ctx;
return x;
};
return [=](auto&&... xs) { return f(xs..., check_context); };
};
};
#else
auto with_check_context = [](auto f) {
return [=](auto&&) {
return [=](auto&&... xs) { return f(xs..., [](auto g) { return g(); }); };
};
};
#endif
auto trace_level = value_of(MIGRAPHX_TRACE_EVAL{}); auto trace_level = value_of(MIGRAPHX_TRACE_EVAL{});
std::vector<argument> ret; std::vector<argument> ret;
if(exec_env.async) if(exec_env.async)
{ {
ctx.wait_for(exec_env.queue); assert(contexts.size() == 1);
contexts.front().wait_for(exec_env.queue);
} }
if(trace_level > 0) if(trace_level > 0)
...@@ -563,16 +545,12 @@ std::vector<argument> program::eval(parameter_map params, execution_environment ...@@ -563,16 +545,12 @@ std::vector<argument> program::eval(parameter_map params, execution_environment
instruction::print(ss, x, ins_names); instruction::print(ss, x, ins_names);
ins_out[x] = ss.str(); ins_out[x] = ss.str();
}); });
ret = generic_eval(*this, contexts, std::move(params), [&](instruction_ref ins, auto f) {
ret = generic_eval( auto& ctx = contexts[ins->get_target_id()];
*this,
ctx,
std::move(params),
with_check_context([&](auto& ins, auto f, auto&& check_context) {
ctx.finish(); ctx.finish();
std::cout << "Run instruction: " << ins_out.at(ins) << std::endl; std::cout << "Run instruction: " << ins_out.at(ins) << std::endl;
timer t{}; timer t{};
auto result = check_context(f); auto result = f();
double t1 = t.record<milliseconds>(); double t1 = t.record<milliseconds>();
ctx.finish(); ctx.finish();
double t2 = t.record<milliseconds>(); double t2 = t.record<milliseconds>();
...@@ -583,7 +561,7 @@ std::vector<argument> program::eval(parameter_map params, execution_environment ...@@ -583,7 +561,7 @@ std::vector<argument> program::eval(parameter_map params, execution_environment
migraphx::argument buffer; migraphx::argument buffer;
try try
{ {
target tgt = make_target(this->impl->target_name); const target& tgt = this->impl->targets.at(ins->get_target_id());
buffer = tgt.copy_from(result); buffer = tgt.copy_from(result);
} }
catch(const migraphx::exception&) catch(const migraphx::exception&)
...@@ -593,8 +571,7 @@ std::vector<argument> program::eval(parameter_map params, execution_environment ...@@ -593,8 +571,7 @@ std::vector<argument> program::eval(parameter_map params, execution_environment
} }
catch(...) catch(...)
{ {
MIGRAPHX_THROW( MIGRAPHX_THROW("MIGraphX program execution with MIGRAPHX_TRACE_EVAL failed.\n");
"MIGraphX program execution with MIGRAPHX_TRACE_EVAL failed.\n");
} }
if(trace_level == 2) if(trace_level == 2)
{ {
...@@ -611,35 +588,36 @@ std::vector<argument> program::eval(parameter_map params, execution_environment ...@@ -611,35 +588,36 @@ std::vector<argument> program::eval(parameter_map params, execution_environment
} }
} }
return result; return result;
})); });
} }
else else
{ {
ret = generic_eval(*this, ret = generic_eval(*this, contexts, std::move(params), [&](auto&&, auto f) { return f(); });
ctx,
std::move(params),
with_check_context([&](auto&, auto f, auto&& check_context) {
return check_context(f);
}));
} }
if(exec_env.async) if(exec_env.async)
{ {
ctx.finish_on(exec_env.queue); assert(contexts.size() == 1);
contexts.front().finish_on(exec_env.queue);
} }
return ret; return ret;
} }
const int program_file_version = 5; void program::finish() const
{
for(const auto& ctx : this->impl->contexts)
ctx.finish();
}
const int program_file_version = 6;
value program::to_value() const value program::to_value() const
{ {
value result; value result;
result["version"] = program_file_version; result["version"] = program_file_version;
result["target"] = this->impl->target_name; result["targets"] = migraphx::to_value(this->impl->targets);
if(not this->impl->target_name.empty()) result["contexts"] = migraphx::to_value(this->impl->contexts);
result["context"] = this->impl->ctx.to_value();
value module_vals = value::object{}; value module_vals = value::object{};
std::unordered_map<instruction_ref, std::string> names; std::unordered_map<instruction_ref, std::string> names;
...@@ -768,12 +746,12 @@ void program::from_value(const value& v) ...@@ -768,12 +746,12 @@ void program::from_value(const value& v)
MIGRAPHX_THROW("Warning: Program version mismatch"); MIGRAPHX_THROW("Warning: Program version mismatch");
} }
this->impl->target_name = v.at("target").to<std::string>(); migraphx::from_value(v.at("targets"), this->impl->targets);
if(not this->impl->target_name.empty())
for(auto i : range(this->impl->targets.size()))
{ {
target t = make_target(this->impl->target_name); this->impl->contexts.push_back(this->impl->targets[i].get_context());
this->impl->ctx = t.get_context(); this->impl->contexts.back().from_value(v.at("contexts")[i]);
this->impl->ctx.from_value(v.at("context"));
} }
auto module_vals = v.at("modules"); auto module_vals = v.at("modules");
...@@ -794,6 +772,8 @@ void program::from_value(const value& v) ...@@ -794,6 +772,8 @@ void program::from_value(const value& v)
auto* mm = get_main_module(); auto* mm = get_main_module();
mod_from_val(mm, module_vals, map_insts, map_mods); mod_from_val(mm, module_vals, map_insts, map_mods);
// Finalize a compiled model
if(not this->impl->contexts.empty())
this->finalize(); this->finalize();
} }
...@@ -814,19 +794,19 @@ std::string perf_group(const operation& op) ...@@ -814,19 +794,19 @@ std::string perf_group(const operation& op)
void program::mark(const parameter_map& params, marker&& m) void program::mark(const parameter_map& params, marker&& m)
{ {
auto& ctx = this->impl->ctx; auto& ctx = this->impl->contexts;
// Run once by itself // Run once by itself
eval(params); eval(params);
ctx.finish(); this->finish();
// Start marking // Start marking
m.mark_start(*this); m.mark_start(*this);
generic_eval(*this, ctx, params, always([&](auto ins, auto f) { generic_eval(*this, ctx, params, [&](auto ins, auto f) {
argument result; argument result;
m.mark_start(ins); m.mark_start(ins);
result = f(); result = f();
m.mark_stop(ins); m.mark_stop(ins);
return result; return result;
})); });
m.mark_stop(*this); m.mark_stop(*this);
} }
...@@ -835,10 +815,10 @@ void program::perf_report(std::ostream& os, ...@@ -835,10 +815,10 @@ void program::perf_report(std::ostream& os,
parameter_map params, parameter_map params,
std::size_t batch) const std::size_t batch) const
{ {
auto& ctx = this->impl->ctx; auto& ctx = this->impl->contexts;
// Run once by itself // Run once by itself
eval(params); eval(params);
ctx.finish(); this->finish();
// Run and time entire program // Run and time entire program
std::vector<double> total_vec; std::vector<double> total_vec;
total_vec.reserve(n); total_vec.reserve(n);
...@@ -846,28 +826,28 @@ void program::perf_report(std::ostream& os, ...@@ -846,28 +826,28 @@ void program::perf_report(std::ostream& os,
{ {
total_vec.push_back(time<milliseconds>([&] { total_vec.push_back(time<milliseconds>([&] {
eval(params); eval(params);
ctx.finish(); this->finish();
})); }));
} }
std::sort(total_vec.begin(), total_vec.end()); std::sort(total_vec.begin(), total_vec.end());
std::unordered_map<instruction_ref, std::vector<double>> ins_vec; std::unordered_map<instruction_ref, std::vector<double>> ins_vec;
// Fill the map // Fill the map
generic_eval(*this, ctx, params, always([&](auto ins, auto) { generic_eval(*this, ctx, params, [&](auto ins, auto) {
ins_vec[ins].reserve(n); ins_vec[ins].reserve(n);
return argument{ins->get_shape(), nullptr}; return argument{ins->get_shape(), nullptr};
})); });
// Run and time each instruction // Run and time each instruction
for(std::size_t i = 0; i < n; i++) for(std::size_t i = 0; i < n; i++)
{ {
generic_eval(*this, ctx, params, always([&](auto ins, auto f) { generic_eval(*this, ctx, params, [&](auto ins, auto f) {
argument result; argument result;
ins_vec[ins].push_back(time<milliseconds>([&] { ins_vec[ins].push_back(time<milliseconds>([&] {
result = f(); result = f();
ctx.finish(); this->impl->contexts[ins->get_target_id()].finish();
})); }));
return result; return result;
})); });
} }
for(auto&& p : ins_vec) for(auto&& p : ins_vec)
std::sort(p.second.begin(), p.second.end()); std::sort(p.second.begin(), p.second.end());
...@@ -1035,10 +1015,10 @@ void program::print_cpp(std::ostream& os) const ...@@ -1035,10 +1015,10 @@ void program::print_cpp(std::ostream& os) const
void program::dry_run(std::unordered_map<std::string, argument> params) const void program::dry_run(std::unordered_map<std::string, argument> params) const
{ {
auto& ctx = this->impl->ctx; auto& ctx = this->impl->contexts;
generic_eval(*this, ctx, std::move(params), always([](auto ins, auto&&...) { generic_eval(*this, ctx, std::move(params), [](auto ins, auto&&...) {
return argument{ins->get_shape(), nullptr}; return argument{ins->get_shape(), nullptr};
})); });
} }
void program::annotate(std::ostream& os, const std::function<void(instruction_ref)>& a) const void program::annotate(std::ostream& os, const std::function<void(instruction_ref)>& a) const
......
/*
* The MIT License (MIT)
*
* Copyright (c) 2015-2022 Advanced Micro Devices, Inc. All rights reserved.
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
* in the Software without restriction, including without limitation the rights
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
* copies of the Software, and to permit persons to whom the Software is
* furnished to do so, subject to the following conditions:
*
* The above copyright notice and this permission notice shall be included in
* all copies or substantial portions of the Software.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
* THE SOFTWARE.
*/
#include <migraphx/target.hpp>
#include <migraphx/register_target.hpp>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
void migraphx_to_value(value& v, const target& t) { v["name"] = t.name(); }
void migraphx_from_value(const value& v, target& t)
{
t = make_target(v.at("name").to<std::string>());
}
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
...@@ -41,7 +41,7 @@ TEST_CASE(simple_test) ...@@ -41,7 +41,7 @@ TEST_CASE(simple_test)
auto* mm = p.get_main_module(); auto* mm = p.get_main_module();
auto one = mm->add_literal(1); auto one = mm->add_literal(1);
auto two = mm->add_literal(2); auto two = mm->add_literal(2);
mm->add_instruction(sum_op{}, one, two); mm->add_instruction(migraphx::make_op("add"), one, two);
auto count = std::distance(mm->begin(), mm->end()); auto count = std::distance(mm->begin(), mm->end());
run_pass(p); run_pass(p);
EXPECT(std::distance(mm->begin(), mm->end()) == count); EXPECT(std::distance(mm->begin(), mm->end()) == count);
...@@ -57,7 +57,7 @@ TEST_CASE(simple_test_nop) ...@@ -57,7 +57,7 @@ TEST_CASE(simple_test_nop)
auto one = mm->add_literal(1); auto one = mm->add_literal(1);
auto two = mm->add_literal(2); auto two = mm->add_literal(2);
mm->add_instruction(nop{}); mm->add_instruction(nop{});
mm->add_instruction(sum_op{}, one, two); mm->add_instruction(migraphx::make_op("add"), one, two);
auto count = std::distance(mm->begin(), mm->end()); auto count = std::distance(mm->begin(), mm->end());
run_pass(p); run_pass(p);
EXPECT(std::distance(mm->begin(), mm->end()) == count); EXPECT(std::distance(mm->begin(), mm->end()) == count);
...@@ -73,7 +73,7 @@ TEST_CASE(simple_test_nop2) ...@@ -73,7 +73,7 @@ TEST_CASE(simple_test_nop2)
auto one = mm->add_literal(1); auto one = mm->add_literal(1);
auto two = mm->add_literal(2); auto two = mm->add_literal(2);
mm->add_instruction(nop{}); mm->add_instruction(nop{});
mm->add_instruction(sum_op{}, one, two); mm->add_instruction(migraphx::make_op("add"), one, two);
mm->add_instruction(nop{}); mm->add_instruction(nop{});
run_pass(p); run_pass(p);
EXPECT(std::distance(mm->begin(), mm->end()) == 2); EXPECT(std::distance(mm->begin(), mm->end()) == 2);
...@@ -88,8 +88,8 @@ TEST_CASE(duplicate_test1) ...@@ -88,8 +88,8 @@ TEST_CASE(duplicate_test1)
auto* mm = p.get_main_module(); auto* mm = p.get_main_module();
auto one = mm->add_literal(1); auto one = mm->add_literal(1);
auto two = mm->add_literal(2); auto two = mm->add_literal(2);
mm->add_instruction(sum_op{}, one, two); mm->add_instruction(migraphx::make_op("add"), one, two);
mm->add_instruction(sum_op{}, one, two); mm->add_instruction(migraphx::make_op("add"), one, two);
auto count = std::distance(mm->begin(), mm->end()); auto count = std::distance(mm->begin(), mm->end());
run_pass(p); run_pass(p);
EXPECT(std::distance(mm->begin(), mm->end()) == (count - 1)); EXPECT(std::distance(mm->begin(), mm->end()) == (count - 1));
...@@ -104,9 +104,9 @@ TEST_CASE(duplicate_test2) ...@@ -104,9 +104,9 @@ TEST_CASE(duplicate_test2)
auto* mm = p.get_main_module(); auto* mm = p.get_main_module();
auto one = mm->add_literal(1); auto one = mm->add_literal(1);
auto two = mm->add_literal(2); auto two = mm->add_literal(2);
mm->add_instruction(sum_op{}, one, two); mm->add_instruction(migraphx::make_op("add"), one, two);
mm->add_instruction(minus_op{}, one, two); mm->add_instruction(migraphx::make_op("sub"), one, two);
mm->add_instruction(sum_op{}, one, two); mm->add_instruction(migraphx::make_op("add"), one, two);
auto count = std::distance(mm->begin(), mm->end()); auto count = std::distance(mm->begin(), mm->end());
run_pass(p); run_pass(p);
EXPECT(std::distance(mm->begin(), mm->end()) == (count - 2)); EXPECT(std::distance(mm->begin(), mm->end()) == (count - 2));
...@@ -121,11 +121,11 @@ TEST_CASE(depth_test) ...@@ -121,11 +121,11 @@ TEST_CASE(depth_test)
auto* mm = p.get_main_module(); auto* mm = p.get_main_module();
auto one = mm->add_literal(1); auto one = mm->add_literal(1);
auto two = mm->add_literal(2); auto two = mm->add_literal(2);
auto x1 = mm->add_instruction(sum_op{}, one, two); auto x1 = mm->add_instruction(migraphx::make_op("add"), one, two);
auto x2 = mm->add_instruction(sum_op{}, one, two); auto x2 = mm->add_instruction(migraphx::make_op("add"), one, two);
mm->add_instruction(minus_op{}, x1, x2); mm->add_instruction(migraphx::make_op("sub"), x1, x2);
mm->add_instruction(minus_op{}, x1, x2); mm->add_instruction(migraphx::make_op("sub"), x1, x2);
mm->add_instruction(sum_op{}, one, two); mm->add_instruction(migraphx::make_op("add"), one, two);
auto count = std::distance(mm->begin(), mm->end()); auto count = std::distance(mm->begin(), mm->end());
run_pass(p); run_pass(p);
EXPECT(std::distance(mm->begin(), mm->end()) == (count - 4)); EXPECT(std::distance(mm->begin(), mm->end()) == (count - 4));
...@@ -141,7 +141,7 @@ TEST_CASE(undefined_test) ...@@ -141,7 +141,7 @@ TEST_CASE(undefined_test)
auto one = mm->add_literal(1); auto one = mm->add_literal(1);
auto two = mm->add_literal(2); auto two = mm->add_literal(2);
mm->add_instruction(migraphx::make_op("undefined")); mm->add_instruction(migraphx::make_op("undefined"));
mm->add_instruction(sum_op{}, one, two); mm->add_instruction(migraphx::make_op("add"), one, two);
auto count = std::distance(mm->begin(), mm->end()); auto count = std::distance(mm->begin(), mm->end());
run_pass(p); run_pass(p);
EXPECT(std::distance(mm->begin(), mm->end()) == count - 1); EXPECT(std::distance(mm->begin(), mm->end()) == count - 1);
......
...@@ -45,7 +45,7 @@ TEST_CASE(simple_test) ...@@ -45,7 +45,7 @@ TEST_CASE(simple_test)
auto one_identity = mm->add_instruction(migraphx::make_op("identity"), one); auto one_identity = mm->add_instruction(migraphx::make_op("identity"), one);
auto two = mm->add_literal(2); auto two = mm->add_literal(2);
auto two_identity = mm->add_instruction(migraphx::make_op("identity"), two); auto two_identity = mm->add_instruction(migraphx::make_op("identity"), two);
mm->add_instruction(sum_op{}, one_identity, two_identity); mm->add_instruction(migraphx::make_op("add"), one_identity, two_identity);
run_pass(p); run_pass(p);
EXPECT(std::none_of(mm->begin(), mm->end(), [](const migraphx::instruction& ins) { EXPECT(std::none_of(mm->begin(), mm->end(), [](const migraphx::instruction& ins) {
return ins.name() == "identity"; return ins.name() == "identity";
...@@ -62,7 +62,7 @@ TEST_CASE(simple_test_end) ...@@ -62,7 +62,7 @@ TEST_CASE(simple_test_end)
auto one = mm->add_literal(1); auto one = mm->add_literal(1);
auto two = mm->add_literal(2); auto two = mm->add_literal(2);
auto ans = mm->add_instruction(sum_op{}, one, two); auto ans = mm->add_instruction(migraphx::make_op("add"), one, two);
mm->add_instruction(migraphx::make_op("identity"), ans); mm->add_instruction(migraphx::make_op("identity"), ans);
run_pass(p); run_pass(p);
EXPECT(std::none_of(mm->begin(), mm->end(), [](const migraphx::instruction& ins) { EXPECT(std::none_of(mm->begin(), mm->end(), [](const migraphx::instruction& ins) {
...@@ -81,8 +81,8 @@ TEST_CASE(simple_test_end_dependency) ...@@ -81,8 +81,8 @@ TEST_CASE(simple_test_end_dependency)
auto one = mm->add_literal(1.0); auto one = mm->add_literal(1.0);
auto two = mm->add_literal(2.0); auto two = mm->add_literal(2.0);
auto three = mm->add_literal(3.0); auto three = mm->add_literal(3.0);
auto ans = mm->add_instruction(sum_op{}, one, two); auto ans = mm->add_instruction(migraphx::make_op("add"), one, two);
mm->add_instruction(sum_op{}, ans, three); mm->add_instruction(migraphx::make_op("add"), ans, three);
mm->add_instruction(migraphx::make_op("identity"), ans); mm->add_instruction(migraphx::make_op("identity"), ans);
run_pass(p); run_pass(p);
EXPECT(std::any_of(mm->begin(), mm->end(), [](const migraphx::instruction& ins) { EXPECT(std::any_of(mm->begin(), mm->end(), [](const migraphx::instruction& ins) {
......
...@@ -27,6 +27,7 @@ ...@@ -27,6 +27,7 @@
#include <migraphx/instruction.hpp> #include <migraphx/instruction.hpp>
#include <migraphx/stringutils.hpp> #include <migraphx/stringutils.hpp>
#include <migraphx/compile_options.hpp> #include <migraphx/compile_options.hpp>
#include <migraphx/make_op.hpp>
#include <sstream> #include <sstream>
#include "test.hpp" #include "test.hpp"
#include <basic_ops.hpp> #include <basic_ops.hpp>
...@@ -49,7 +50,7 @@ struct id_target ...@@ -49,7 +50,7 @@ struct id_target
struct id_ctx_op struct id_ctx_op
{ {
std::string name() const { return "id_ctx_op"; } std::string name() const { return ""; }
migraphx::argument migraphx::argument
compute(id_target::context&, const migraphx::shape&, std::vector<migraphx::argument> args) const compute(id_target::context&, const migraphx::shape&, std::vector<migraphx::argument> args) const
{ {
...@@ -156,7 +157,7 @@ TEST_CASE(literal_test1) ...@@ -156,7 +157,7 @@ TEST_CASE(literal_test1)
auto* mm = p.get_main_module(); auto* mm = p.get_main_module();
auto one = mm->add_literal(1); auto one = mm->add_literal(1);
auto two = mm->add_literal(2); auto two = mm->add_literal(2);
mm->add_instruction(sum_op{}, one, two); mm->add_instruction(migraphx::make_op("add"), one, two);
auto result = p.eval({}).back(); auto result = p.eval({}).back();
EXPECT(result == migraphx::literal{3}); EXPECT(result == migraphx::literal{3});
EXPECT(result != migraphx::literal{4}); EXPECT(result != migraphx::literal{4});
...@@ -168,8 +169,8 @@ TEST_CASE(literal_test2) ...@@ -168,8 +169,8 @@ TEST_CASE(literal_test2)
auto* mm = p.get_main_module(); auto* mm = p.get_main_module();
auto one = mm->add_literal(1); auto one = mm->add_literal(1);
auto two = mm->add_literal(2); auto two = mm->add_literal(2);
auto sum1 = mm->add_instruction(sum_op{}, one, two); auto sum1 = mm->add_instruction(migraphx::make_op("add"), one, two);
mm->add_instruction(sum_op{}, sum1, two); mm->add_instruction(migraphx::make_op("add"), sum1, two);
auto result = p.eval({}).back(); auto result = p.eval({}).back();
EXPECT(result == migraphx::literal{5}); EXPECT(result == migraphx::literal{5});
...@@ -182,7 +183,7 @@ TEST_CASE(print_test) ...@@ -182,7 +183,7 @@ TEST_CASE(print_test)
auto* mm = p.get_main_module(); auto* mm = p.get_main_module();
auto x = mm->add_parameter("x", {migraphx::shape::int32_type}); auto x = mm->add_parameter("x", {migraphx::shape::int32_type});
auto two = mm->add_literal(2); auto two = mm->add_literal(2);
mm->add_instruction(sum_op{}, x, two); mm->add_instruction(migraphx::make_op("add"), x, two);
std::stringstream ss; std::stringstream ss;
ss << p; ss << p;
...@@ -197,7 +198,7 @@ TEST_CASE(param_test) ...@@ -197,7 +198,7 @@ TEST_CASE(param_test)
auto x = mm->add_parameter("x", {migraphx::shape::int32_type}); auto x = mm->add_parameter("x", {migraphx::shape::int32_type});
auto y = mm->add_parameter("y", {migraphx::shape::int32_type}); auto y = mm->add_parameter("y", {migraphx::shape::int32_type});
mm->add_instruction(sum_op{}, x, y); mm->add_instruction(migraphx::make_op("add"), x, y);
auto result = p.eval({{"x", migraphx::literal{1}.get_argument()}, auto result = p.eval({{"x", migraphx::literal{1}.get_argument()},
{"y", migraphx::literal{2}.get_argument()}}) {"y", migraphx::literal{2}.get_argument()}})
.back(); .back();
...@@ -227,7 +228,7 @@ TEST_CASE(param_error_shape_test) ...@@ -227,7 +228,7 @@ TEST_CASE(param_error_shape_test)
auto x = mm->add_parameter("x", {migraphx::shape::int32_type, {1, 1}}); auto x = mm->add_parameter("x", {migraphx::shape::int32_type, {1, 1}});
auto y = mm->add_parameter("y", {migraphx::shape::int32_type, {1, 1}}); auto y = mm->add_parameter("y", {migraphx::shape::int32_type, {1, 1}});
mm->add_instruction(sum_op{}, x, y); mm->add_instruction(migraphx::make_op("add"), x, y);
EXPECT(test::throws<migraphx::exception>( EXPECT(test::throws<migraphx::exception>(
[&] { [&] {
p.eval({ p.eval({
...@@ -245,7 +246,7 @@ TEST_CASE(get_param1) ...@@ -245,7 +246,7 @@ TEST_CASE(get_param1)
migraphx::shape s{migraphx::shape::int32_type, {1, 2}}; migraphx::shape s{migraphx::shape::int32_type, {1, 2}};
auto x = mm->add_parameter("x", s); auto x = mm->add_parameter("x", s);
auto y = mm->add_parameter("y", s); auto y = mm->add_parameter("y", s);
mm->add_instruction(sum_op{}, x, y); mm->add_instruction(migraphx::make_op("add"), x, y);
EXPECT(bool{p.get_parameter("x") == x}); EXPECT(bool{p.get_parameter("x") == x});
EXPECT(bool{p.get_parameter("y") == y}); EXPECT(bool{p.get_parameter("y") == y});
EXPECT(bool{p.get_parameter("nonexistent") == mm->end()}); EXPECT(bool{p.get_parameter("nonexistent") == mm->end()});
...@@ -257,7 +258,7 @@ TEST_CASE(get_param2) ...@@ -257,7 +258,7 @@ TEST_CASE(get_param2)
auto* mm = p.get_main_module(); auto* mm = p.get_main_module();
auto one = mm->add_literal(1); auto one = mm->add_literal(1);
auto two = mm->add_literal(2); auto two = mm->add_literal(2);
mm->add_instruction(sum_op{}, one, two); mm->add_instruction(migraphx::make_op("add"), one, two);
EXPECT(bool{p.get_parameter("nonexistent") == mm->end()}); EXPECT(bool{p.get_parameter("nonexistent") == mm->end()});
} }
...@@ -268,7 +269,7 @@ TEST_CASE(get_param_shapes) ...@@ -268,7 +269,7 @@ TEST_CASE(get_param_shapes)
migraphx::shape s{migraphx::shape::int32_type, {1, 2}}; migraphx::shape s{migraphx::shape::int32_type, {1, 2}};
auto x = mm->add_parameter("x", s); auto x = mm->add_parameter("x", s);
auto y = mm->add_parameter("y", s); auto y = mm->add_parameter("y", s);
mm->add_instruction(sum_op{}, x, y); mm->add_instruction(migraphx::make_op("add"), x, y);
auto m = p.get_parameter_shapes(); auto m = p.get_parameter_shapes();
EXPECT(m.count("nonexistent") == 0); EXPECT(m.count("nonexistent") == 0);
EXPECT(m.at("x") == s); EXPECT(m.at("x") == s);
...@@ -281,8 +282,8 @@ TEST_CASE(replace_test) ...@@ -281,8 +282,8 @@ TEST_CASE(replace_test)
auto* mm = p.get_main_module(); auto* mm = p.get_main_module();
auto one = mm->add_literal(1); auto one = mm->add_literal(1);
auto two = mm->add_literal(2); auto two = mm->add_literal(2);
auto sum = mm->add_instruction(sum_op{}, one, two); auto sum = mm->add_instruction(migraphx::make_op("add"), one, two);
mm->replace_instruction(sum, minus_op{}, two, one); mm->replace_instruction(sum, migraphx::make_op("sub"), two, one);
EXPECT(bool{p.validate() == mm->end()}); EXPECT(bool{p.validate() == mm->end()});
auto result = p.eval({}).back(); auto result = p.eval({}).back();
...@@ -296,8 +297,8 @@ TEST_CASE(replace_ins_test) ...@@ -296,8 +297,8 @@ TEST_CASE(replace_ins_test)
auto* mm = p.get_main_module(); auto* mm = p.get_main_module();
auto one = mm->add_literal(1); auto one = mm->add_literal(1);
auto two = mm->add_literal(2); auto two = mm->add_literal(2);
auto sum = mm->add_instruction(sum_op{}, one, two); auto sum = mm->add_instruction(migraphx::make_op("add"), one, two);
auto minus = mm->add_instruction(minus_op{}, two, one); auto minus = mm->add_instruction(migraphx::make_op("sub"), two, one);
mm->replace_instruction(sum, minus); mm->replace_instruction(sum, minus);
EXPECT(bool{p.validate() == mm->end()}); EXPECT(bool{p.validate() == mm->end()});
...@@ -312,8 +313,8 @@ TEST_CASE(replace_ins_test2) ...@@ -312,8 +313,8 @@ TEST_CASE(replace_ins_test2)
auto* mm = p.get_main_module(); auto* mm = p.get_main_module();
auto one = mm->add_literal(1); auto one = mm->add_literal(1);
auto two = mm->add_literal(2); auto two = mm->add_literal(2);
auto sum = mm->add_instruction(sum_op{}, one, two); auto sum = mm->add_instruction(migraphx::make_op("add"), one, two);
auto minus = mm->add_instruction(minus_op{}, two, one); auto minus = mm->add_instruction(migraphx::make_op("sub"), two, one);
mm->add_instruction(pass_op{}, minus); mm->add_instruction(pass_op{}, minus);
mm->replace_instruction(two, sum); mm->replace_instruction(two, sum);
EXPECT(bool{p.validate() == mm->end()}); EXPECT(bool{p.validate() == mm->end()});
...@@ -329,8 +330,8 @@ TEST_CASE(replace_op_test) ...@@ -329,8 +330,8 @@ TEST_CASE(replace_op_test)
auto* mm = p.get_main_module(); auto* mm = p.get_main_module();
auto one = mm->add_literal(1); auto one = mm->add_literal(1);
auto two = mm->add_literal(2); auto two = mm->add_literal(2);
auto sum = mm->add_instruction(sum_op{}, two, one); auto sum = mm->add_instruction(migraphx::make_op("add"), two, one);
sum->replace(minus_op{}); sum->replace(migraphx::make_op("sub"));
EXPECT(bool{p.validate() == mm->end()}); EXPECT(bool{p.validate() == mm->end()});
auto result = p.eval({}).back(); auto result = p.eval({}).back();
...@@ -344,7 +345,7 @@ TEST_CASE(replace_op_recompute_shape_throw) ...@@ -344,7 +345,7 @@ TEST_CASE(replace_op_recompute_shape_throw)
auto* mm = p.get_main_module(); auto* mm = p.get_main_module();
auto one = mm->add_literal(1); auto one = mm->add_literal(1);
auto two = mm->add_literal(2); auto two = mm->add_literal(2);
auto sum = mm->add_instruction(sum_op{}, one, two); auto sum = mm->add_instruction(migraphx::make_op("add"), one, two);
EXPECT(test::throws<migraphx::exception>([&] { sum->replace(unary_pass_op{}); })); EXPECT(test::throws<migraphx::exception>([&] { sum->replace(unary_pass_op{}); }));
} }
...@@ -354,11 +355,11 @@ TEST_CASE(insert_replace_test) ...@@ -354,11 +355,11 @@ TEST_CASE(insert_replace_test)
auto* mm = p.get_main_module(); auto* mm = p.get_main_module();
auto one = mm->add_literal(1); auto one = mm->add_literal(1);
auto two = mm->add_literal(2); auto two = mm->add_literal(2);
auto sum1 = mm->add_instruction(sum_op{}, one, two); auto sum1 = mm->add_instruction(migraphx::make_op("add"), one, two);
mm->add_instruction(sum_op{}, sum1, two); mm->add_instruction(migraphx::make_op("add"), sum1, two);
auto sum0 = mm->insert_instruction(sum1, sum_op{}, two, two); auto sum0 = mm->insert_instruction(sum1, migraphx::make_op("add"), two, two);
mm->replace_instruction(sum1, minus_op{}, sum0, two); mm->replace_instruction(sum1, migraphx::make_op("sub"), sum0, two);
EXPECT(bool{p.validate() == mm->end()}); EXPECT(bool{p.validate() == mm->end()});
auto result = p.eval({}).back(); auto result = p.eval({}).back();
...@@ -372,8 +373,8 @@ TEST_CASE(remove_test1) ...@@ -372,8 +373,8 @@ TEST_CASE(remove_test1)
auto* mm = p.get_main_module(); auto* mm = p.get_main_module();
auto one = mm->add_literal(1); auto one = mm->add_literal(1);
auto two = mm->add_literal(2); auto two = mm->add_literal(2);
auto sum = mm->add_instruction(sum_op{}, one, two); auto sum = mm->add_instruction(migraphx::make_op("add"), one, two);
auto removed = mm->add_instruction(minus_op{}, sum, one); auto removed = mm->add_instruction(migraphx::make_op("sub"), sum, one);
mm->remove_instruction(removed); mm->remove_instruction(removed);
EXPECT(bool{p.validate() == mm->end()}); EXPECT(bool{p.validate() == mm->end()});
...@@ -388,8 +389,8 @@ TEST_CASE(remove_test2) ...@@ -388,8 +389,8 @@ TEST_CASE(remove_test2)
auto* mm = p.get_main_module(); auto* mm = p.get_main_module();
auto one = mm->add_literal(1); auto one = mm->add_literal(1);
auto two = mm->add_literal(2); auto two = mm->add_literal(2);
auto removed = mm->add_instruction(minus_op{}, two, one); auto removed = mm->add_instruction(migraphx::make_op("sub"), two, one);
mm->add_instruction(sum_op{}, one, two); mm->add_instruction(migraphx::make_op("add"), one, two);
mm->remove_instruction(removed); mm->remove_instruction(removed);
EXPECT(bool{p.validate() == mm->end()}); EXPECT(bool{p.validate() == mm->end()});
...@@ -404,7 +405,7 @@ TEST_CASE(target_test) ...@@ -404,7 +405,7 @@ TEST_CASE(target_test)
auto* mm = p.get_main_module(); auto* mm = p.get_main_module();
auto one = mm->add_literal(1); auto one = mm->add_literal(1);
auto two = mm->add_literal(2); auto two = mm->add_literal(2);
mm->add_instruction(sum_op{}, one, two); mm->add_instruction(migraphx::make_op("add"), one, two);
p.compile(id_target{}); p.compile(id_target{});
auto result = p.eval({}).back(); auto result = p.eval({}).back();
EXPECT(result == migraphx::literal{3}); EXPECT(result == migraphx::literal{3});
......
...@@ -86,6 +86,25 @@ struct minus_op ...@@ -86,6 +86,25 @@ struct minus_op
}; };
struct pass_op struct pass_op
{
std::string name() const { return "pass"; }
migraphx::argument compute(const migraphx::shape&, std::vector<migraphx::argument> args) const
{
if(args.empty())
return {};
return args.front();
}
migraphx::shape compute_shape(std::vector<migraphx::shape> inputs) const
{
if(inputs.empty())
return {};
return inputs.front();
}
int output_alias(const std::vector<migraphx::shape>& s) const { return s.empty() ? -1 : 0; }
};
struct non_const_pass_op
{ {
std::string name() const { return "pass"; } std::string name() const { return "pass"; }
migraphx::argument migraphx::argument
...@@ -176,9 +195,7 @@ struct pass_standard_op ...@@ -176,9 +195,7 @@ struct pass_standard_op
struct nop struct nop
{ {
std::string name() const { return "nop"; } std::string name() const { return "nop"; }
migraphx::argument compute(migraphx::context&, migraphx::argument compute(const migraphx::shape&, const std::vector<migraphx::argument>&) const
const migraphx::shape&,
const std::vector<migraphx::argument>&) const
{ {
return {}; return {};
} }
......
...@@ -40,12 +40,12 @@ TEST_CASE(const_add) ...@@ -40,12 +40,12 @@ TEST_CASE(const_add)
auto one = m1.add_literal(1); auto one = m1.add_literal(1);
auto two = m1.add_literal(2); auto two = m1.add_literal(2);
auto sum = m1.add_instruction(migraphx::make_op("add"), one, two); auto sum = m1.add_instruction(migraphx::make_op("add"), one, two);
m1.add_instruction(pass_op{}, sum); m1.add_instruction(non_const_pass_op{}, sum);
run_pass(m1); run_pass(m1);
migraphx::module m2; migraphx::module m2;
auto total = m2.add_literal(3); auto total = m2.add_literal(3);
m2.add_instruction(pass_op{}, total); m2.add_instruction(non_const_pass_op{}, total);
EXPECT(m1 == m2); EXPECT(m1 == m2);
} }
...@@ -55,12 +55,12 @@ TEST_CASE(const_add_parameter) ...@@ -55,12 +55,12 @@ TEST_CASE(const_add_parameter)
auto one = m1.add_parameter("one", {migraphx::shape::int32_type, {1}}); auto one = m1.add_parameter("one", {migraphx::shape::int32_type, {1}});
auto two = m1.add_literal(2); auto two = m1.add_literal(2);
auto sum = m1.add_instruction(migraphx::make_op("add"), one, two); auto sum = m1.add_instruction(migraphx::make_op("add"), one, two);
m1.add_instruction(pass_op{}, sum); m1.add_instruction(non_const_pass_op{}, sum);
run_pass(m1); run_pass(m1);
migraphx::module m2; migraphx::module m2;
auto total = m2.add_literal(3); auto total = m2.add_literal(3);
m2.add_instruction(pass_op{}, total); m2.add_instruction(non_const_pass_op{}, total);
EXPECT(m1 != m2); EXPECT(m1 != m2);
} }
...@@ -71,12 +71,12 @@ TEST_CASE(const_multiadd) ...@@ -71,12 +71,12 @@ TEST_CASE(const_multiadd)
auto two = m1.add_literal(2); auto two = m1.add_literal(2);
auto sum1 = m1.add_instruction(migraphx::make_op("add"), one, two); auto sum1 = m1.add_instruction(migraphx::make_op("add"), one, two);
auto sum2 = m1.add_instruction(migraphx::make_op("add"), sum1, two); auto sum2 = m1.add_instruction(migraphx::make_op("add"), sum1, two);
m1.add_instruction(pass_op{}, sum2); m1.add_instruction(non_const_pass_op{}, sum2);
run_pass(m1); run_pass(m1);
migraphx::module m2; migraphx::module m2;
auto total = m2.add_literal(5); auto total = m2.add_literal(5);
m2.add_instruction(pass_op{}, total); m2.add_instruction(non_const_pass_op{}, total);
EXPECT(m1 == m2); EXPECT(m1 == m2);
} }
...@@ -88,12 +88,12 @@ TEST_CASE(const_add_mul) ...@@ -88,12 +88,12 @@ TEST_CASE(const_add_mul)
auto mul = m1.add_instruction(migraphx::make_op("mul"), two, two); auto mul = m1.add_instruction(migraphx::make_op("mul"), two, two);
auto sum1 = m1.add_instruction(migraphx::make_op("add"), one, mul); auto sum1 = m1.add_instruction(migraphx::make_op("add"), one, mul);
auto sum2 = m1.add_instruction(migraphx::make_op("add"), sum1, two); auto sum2 = m1.add_instruction(migraphx::make_op("add"), sum1, two);
m1.add_instruction(pass_op{}, sum2); m1.add_instruction(non_const_pass_op{}, sum2);
run_pass(m1); run_pass(m1);
migraphx::module m2; migraphx::module m2;
auto total = m2.add_literal(7); auto total = m2.add_literal(7);
m2.add_instruction(pass_op{}, total); m2.add_instruction(non_const_pass_op{}, total);
EXPECT(m1 == m2); EXPECT(m1 == m2);
} }
...@@ -105,13 +105,13 @@ TEST_CASE(const_add_scalar) ...@@ -105,13 +105,13 @@ TEST_CASE(const_add_scalar)
auto two = m1.add_instruction(migraphx::make_op("scalar", {{"scalar_bcst_dims", {2, 2}}}), auto two = m1.add_instruction(migraphx::make_op("scalar", {{"scalar_bcst_dims", {2, 2}}}),
m1.add_literal(2)); m1.add_literal(2));
auto sum = m1.add_instruction(migraphx::make_op("add"), one, two); auto sum = m1.add_instruction(migraphx::make_op("add"), one, two);
m1.add_instruction(pass_op{}, sum); m1.add_instruction(non_const_pass_op{}, sum);
run_pass(m1); run_pass(m1);
migraphx::module m2; migraphx::module m2;
auto total = auto total =
m2.add_literal(migraphx::literal{{migraphx::shape::int32_type, {2, 2}}, {3, 3, 3, 3}}); m2.add_literal(migraphx::literal{{migraphx::shape::int32_type, {2, 2}}, {3, 3, 3, 3}});
m2.add_instruction(pass_op{}, total); m2.add_instruction(non_const_pass_op{}, total);
EXPECT(m1 == m2); EXPECT(m1 == m2);
} }
...@@ -121,7 +121,7 @@ TEST_CASE(const_scalar) ...@@ -121,7 +121,7 @@ TEST_CASE(const_scalar)
{ {
auto one = m1.add_instruction(migraphx::make_op("scalar", {{"scalar_bcst_dims", {2, 2}}}), auto one = m1.add_instruction(migraphx::make_op("scalar", {{"scalar_bcst_dims", {2, 2}}}),
m1.add_literal(1)); m1.add_literal(1));
m1.add_instruction(pass_op{}, one); m1.add_instruction(non_const_pass_op{}, one);
} }
run_pass(m1); run_pass(m1);
...@@ -129,7 +129,7 @@ TEST_CASE(const_scalar) ...@@ -129,7 +129,7 @@ TEST_CASE(const_scalar)
{ {
auto one = m2.add_instruction(migraphx::make_op("scalar", {{"scalar_bcst_dims", {2, 2}}}), auto one = m2.add_instruction(migraphx::make_op("scalar", {{"scalar_bcst_dims", {2, 2}}}),
m2.add_literal(1)); m2.add_literal(1));
m2.add_instruction(pass_op{}, one); m2.add_instruction(non_const_pass_op{}, one);
} }
EXPECT(m1 == m2); EXPECT(m1 == m2);
} }
......
...@@ -23,6 +23,7 @@ ...@@ -23,6 +23,7 @@
*/ */
#include <migraphx/program.hpp> #include <migraphx/program.hpp>
#include <migraphx/instruction.hpp> #include <migraphx/instruction.hpp>
#include <migraphx/make_op.hpp>
#include <basic_ops.hpp> #include <basic_ops.hpp>
#include <test.hpp> #include <test.hpp>
#include <rob.hpp> #include <rob.hpp>
...@@ -33,7 +34,7 @@ TEST_CASE(simple_test) ...@@ -33,7 +34,7 @@ TEST_CASE(simple_test)
auto* mm = p.get_main_module(); auto* mm = p.get_main_module();
auto one = mm->add_literal(1); auto one = mm->add_literal(1);
auto two = mm->add_literal(2); auto two = mm->add_literal(2);
mm->add_instruction(sum_op{}, one, two); mm->add_instruction(migraphx::make_op("add"), one, two);
EXPECT(bool{mm->validate() == mm->end()}); EXPECT(bool{mm->validate() == mm->end()});
auto result = p.eval({}); auto result = p.eval({});
EXPECT(result.back() == migraphx::literal{3}); EXPECT(result.back() == migraphx::literal{3});
...@@ -46,7 +47,7 @@ TEST_CASE(out_of_order) ...@@ -46,7 +47,7 @@ TEST_CASE(out_of_order)
auto* mm = p.get_main_module(); auto* mm = p.get_main_module();
auto one = mm->add_literal(1); auto one = mm->add_literal(1);
auto two = mm->add_literal(2); auto two = mm->add_literal(2);
auto ins = mm->add_instruction(sum_op{}, one, two); auto ins = mm->add_instruction(migraphx::make_op("add"), one, two);
mm->move_instruction(two, mm->end()); mm->move_instruction(two, mm->end());
EXPECT(bool{p.validate() == ins}); EXPECT(bool{p.validate() == ins});
} }
...@@ -57,7 +58,7 @@ TEST_CASE(incomplete_args) ...@@ -57,7 +58,7 @@ TEST_CASE(incomplete_args)
auto* mm = p.get_main_module(); auto* mm = p.get_main_module();
auto one = mm->add_literal(1); auto one = mm->add_literal(1);
auto two = mm->add_literal(2); auto two = mm->add_literal(2);
auto ins = mm->add_instruction(sum_op{}, one, two); auto ins = mm->add_instruction(migraphx::make_op("add"), one, two);
ins->clear_arguments(); ins->clear_arguments();
EXPECT(bool{p.validate() == ins}); EXPECT(bool{p.validate() == ins});
} }
...@@ -73,7 +74,7 @@ TEST_CASE(invalid_args) ...@@ -73,7 +74,7 @@ TEST_CASE(invalid_args)
auto* mm = p.get_main_module(); auto* mm = p.get_main_module();
auto one = mm->add_literal(1); auto one = mm->add_literal(1);
auto two = mm->add_literal(2); auto two = mm->add_literal(2);
auto ins = mm->add_instruction(sum_op{}, one, two); auto ins = mm->add_instruction(migraphx::make_op("add"), one, two);
access_ins_arguments(*ins).clear(); access_ins_arguments(*ins).clear();
EXPECT(bool{mm->validate() == mm->begin()}); EXPECT(bool{mm->validate() == mm->begin()});
} }
......
...@@ -261,11 +261,13 @@ auto compute_op(rank<1>, ...@@ -261,11 +261,13 @@ auto compute_op(rank<1>,
template <class T, class F> template <class T, class F>
argument compute_op(rank<0>, argument compute_op(rank<0>,
const T& x, const T& x,
const shape&, const shape& output,
const std::vector<argument>&, const std::vector<argument>& inputs,
const std::vector<module_ref>&, const std::vector<module_ref>& module_args,
F) F)
{ {
if(module_args.empty())
return compute_op(x, output, inputs);
std::string name = x.name(); std::string name = x.name();
MIGRAPHX_THROW("Not computable: " + name); MIGRAPHX_THROW("Not computable: " + name);
} }
......
...@@ -45,6 +45,8 @@ ...@@ -45,6 +45,8 @@
namespace migraphx { namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS { inline namespace MIGRAPHX_INLINE_NS {
struct value;
#ifdef DOXYGEN #ifdef DOXYGEN
/// An interface for a compilation target /// An interface for a compilation target
...@@ -123,11 +125,20 @@ supported_segments target_find_supported(T&, const_module_ref, support_metric) ...@@ -123,11 +125,20 @@ supported_segments target_find_supported(T&, const_module_ref, support_metric)
} }
<% <%
interface('target', interface('target',
virtual('name', returns='std::string', const=True), virtual('name', returns = 'std::string', const = True),
virtual('get_passes', ctx='context&', options='const compile_options&', returns='std::vector<pass>', const=True), virtual('get_passes',
virtual('get_context', returns='context', const=True), ctx = 'context&',
virtual('find_supported', returns='supported_segments', mod='const_module_ref', m='support_metric', const=True, default='target_find_supported'), options = 'const compile_options&',
returns = 'std::vector<pass>',
const = True),
virtual('get_context', returns = 'context', const = True),
virtual('find_supported',
returns = 'supported_segments',
mod = 'const_module_ref',
m = 'support_metric',
const = True,
default = 'target_find_supported'),
virtual('copy_to', virtual('copy_to',
returns = 'argument', returns = 'argument',
input = 'const argument&', input = 'const argument&',
...@@ -138,13 +149,17 @@ interface('target', ...@@ -138,13 +149,17 @@ interface('target',
input = 'const argument&', input = 'const argument&',
const = True, const = True,
default = 'copy_from_target'), default = 'copy_from_target'),
virtual('allocate', s='const shape&', returns='argument', const=True, virtual('allocate',
default = 'target_allocate') s = 'const shape&',
) returns = 'argument',
%> const = True,
default = 'target_allocate')) %>
#endif #endif
void migraphx_to_value(value& v, const target& t);
void migraphx_from_value(const value& v, target& t);
} // namespace MIGRAPHX_INLINE_NS } // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx } // namespace migraphx
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment