Unverified Commit 072fd5cc authored by Paul Fultz II's avatar Paul Fultz II Committed by GitHub
Browse files

Enable eval to handle multiple contexts (#1751)

This is to help enable multi-target execution. We store a vector of targets and contexts. Currently this will only compile a single target, the PR #1672 is needed to enable multiple targets.

This will also serialize the targets and contexts.

When using the execution_environment or prog.get_context() it will always use the context from the first target assuming this is the "primary" target. Although, its unlikely a user would use execution_environment with a multi-target environment.
parent 697709a7
......@@ -94,6 +94,7 @@ add_library(migraphx
simplify_algebra.cpp
simplify_reshapes.cpp
split_single_dyn_dim.cpp
target.cpp
tmp_dir.cpp
value.cpp
verify_args.cpp
......
......@@ -137,6 +137,7 @@ struct instruction
operation normalized_operator() const;
std::size_t get_target_id() const;
void set_target_id(std::size_t tid);
void debug_print() const;
......
......@@ -189,7 +189,7 @@ struct module
instruction_ref validate() const;
instruction_ref find_dangling_reference() const;
void finalize(context& ctx);
void finalize(std::vector<context>& contexts);
void debug_print() const;
void debug_print(instruction_ref ins) const;
......
......@@ -261,11 +261,13 @@ auto compute_op(rank<1>,
template <class T, class F>
argument compute_op(rank<0>,
const T& x,
const shape&,
const std::vector<argument>&,
const std::vector<module_ref>&,
const shape& output,
const std::vector<argument>& inputs,
const std::vector<module_ref>& module_args,
F)
{
if(module_args.empty())
return compute_op(x, output, inputs);
std::string name = x.name();
MIGRAPHX_THROW("Not computable: " + name);
}
......
......@@ -79,6 +79,9 @@ struct program
std::vector<argument> eval(parameter_map params,
execution_environment exec_env = execution_environment{}) const;
void finish() const;
std::size_t size() const;
std::vector<shape> get_output_shapes() const;
......
......@@ -45,6 +45,8 @@
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
struct value;
#ifdef DOXYGEN
/// An interface for a compilation target
......@@ -467,6 +469,9 @@ inline const ValueType& any_cast(const target& x)
#endif
void migraphx_to_value(value& v, const target& t);
void migraphx_from_value(const value& v, target& t);
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
......
......@@ -473,7 +473,9 @@ operation instruction::normalized_operator() const
return o;
}
std::size_t instruction::get_target_id() const { return target_id; }
void instruction::set_target_id(std::size_t tid) { this->target_id = tid; }
std::vector<shape> to_shapes(const std::vector<instruction_ref>& args)
{
std::vector<shape> shapes(args.size());
......
......@@ -652,8 +652,9 @@ instruction_ref module::find_dangling_reference() const
return end();
}
void module::finalize(context& ctx)
void module::finalize(std::vector<context>& contexts)
{
assert(not contexts.empty());
const bool trace = enabled(MIGRAPHX_TRACE_FINALIZE{});
for(auto ins : iterator_for(*this))
{
......@@ -662,10 +663,10 @@ void module::finalize(context& ctx)
std::cout << "Finalize: ";
this->debug_print(ins);
}
ins->finalize(ctx);
ins->finalize(contexts[ins->get_target_id()]);
for(const auto& smod : ins->module_inputs())
{
smod->finalize(ctx);
smod->finalize(contexts);
}
}
......
......@@ -70,9 +70,8 @@ struct program_impl
{
// A map is used to keep references to modules of the program
std::unordered_map<std::string, module> modules;
context ctx;
std::string target_name;
std::vector<context> contexts;
std::vector<target> targets;
};
program::program() : impl(std::make_unique<program_impl>()) { this->create_module("main"); }
......@@ -96,14 +95,8 @@ void program::assign(const program& p)
{
impl = std::make_unique<program_impl>();
}
else if(not impl->modules.empty())
{
impl->modules.clear();
}
impl->ctx = p.impl->ctx;
impl->target_name = p.impl->target_name;
impl->modules = p.impl->modules;
*impl = *p.impl;
// build a map from old ins to new ins
// Build a map from old module to new module
......@@ -166,7 +159,11 @@ std::vector<shape> program::get_output_shapes() const
return mm->get_output_shapes();
}
context& program::get_context() const { return impl->ctx; }
context& program::get_context() const
{
assert(impl->contexts.size() == 1);
return impl->contexts.front();
}
instruction_ref program::validate() const
{
......@@ -217,7 +214,7 @@ target_assignments program::get_target_assignments(const std::vector<target>& ta
return p;
}
bool program::is_compiled() const { return not this->impl->target_name.empty(); }
bool program::is_compiled() const { return not this->impl->contexts.empty(); }
void program::compile(const std::vector<target>& targets, std::vector<compile_options> compile_opts)
{
......@@ -299,24 +296,24 @@ void program::compile(const std::vector<target>& targets, std::vector<compile_op
MIGRAPHX_THROW("Dangling reference in module " + current_mod->name() +
" from instruction " + std::to_string(index));
}
current_mod->finalize(this->impl->contexts[root_target_id]);
}
}
this->finalize();
}
void program::compile(const target& t, compile_options options)
{
// todo: combine with multi-target compile method
assert(not this->is_compiled());
this->impl->target_name = t.name();
this->impl->ctx = t.get_context();
this->impl->targets = {t};
this->impl->contexts = {t.get_context()};
if(enabled(MIGRAPHX_TRACE_COMPILE{}))
options.trace = tracer{std::cout};
options.trace(*this);
options.trace();
auto&& passes = t.get_passes(this->impl->ctx, options);
auto&& passes = t.get_passes(this->impl->contexts.front(), options);
run_passes(*this, passes, options.trace);
auto mods = this->get_modules();
// Validate and finalize
......@@ -335,14 +332,14 @@ void program::compile(const target& t, compile_options options)
MIGRAPHX_THROW("Dangling reference in module " + mod->name() + " from instruction " +
std::to_string(index));
}
mod->finalize(this->impl->ctx);
mod->finalize(this->impl->contexts);
}
}
void program::finalize()
{
auto* mm = this->get_main_module();
mm->finalize(this->impl->ctx);
mm->finalize(this->impl->contexts);
}
template <class T>
......@@ -429,16 +426,15 @@ void preview_argument(std::ostream& os, const argument& a)
template <class F>
std::vector<argument> generic_eval(const module* mod,
context& ctx,
std::vector<context>& ctx,
std::unordered_map<std::string, argument> params,
std::unordered_map<instruction_ref, argument> results,
F make_trace)
F trace)
{
assert(mod->validate() == mod->end());
results.reserve(mod->size() * 2);
std::vector<argument> values;
values.reserve(16);
auto trace = make_trace(mod);
for(auto ins : iterator_for(*mod))
{
assert(results.find(ins) == results.end());
......@@ -494,14 +490,19 @@ std::vector<argument> generic_eval(const module* mod,
const auto& mod_args = ins->module_inputs();
auto module_eval = [&](module_ref smod,
const std::unordered_map<std::string, argument>& inputs) {
auto ssctx = ctx;
return generic_eval(smod, ssctx, inputs, results, make_trace);
return generic_eval(smod, ctx, inputs, results, trace);
};
results.emplace(ins, trace(ins, [&] {
return ins->normalized_operator().compute(
ctx, ins->get_shape(), values, mod_args, module_eval);
}));
results.emplace(
ins, trace(ins, [&] {
auto op = ins->normalized_operator();
if(op.is_context_free())
return op.compute(ins->get_shape(), values, mod_args, module_eval);
if(ins->get_target_id() >= ctx.size())
MIGRAPHX_THROW("No context available for " + op.name());
return op.compute(
ctx[ins->get_target_id()], ins->get_shape(), values, mod_args, module_eval);
}));
}
assert(results.find(ins) != results.end());
if(not ins->get_shape().any_of_dynamic())
......@@ -514,44 +515,25 @@ std::vector<argument> generic_eval(const module* mod,
template <class F>
std::vector<argument> generic_eval(const program& p,
context& ctx,
std::vector<context>& ctx,
std::unordered_map<std::string, argument> params,
F make_trace)
F trace)
{
const module* mm = p.get_main_module();
return generic_eval(mm, ctx, params, {}, make_trace);
return generic_eval(mm, ctx, params, {}, trace);
}
std::vector<argument> program::eval(parameter_map params, execution_environment exec_env) const
{
auto& ctx = this->impl->ctx;
#ifndef NDEBUG
auto with_check_context = [&](auto f) {
return [=, &ctx](auto&&) {
auto sctx = std::make_shared<context>(ctx);
auto check_context = [=, &ctx](auto g) {
assert(is_shared(ctx, *sctx));
auto x = g();
*sctx = ctx;
return x;
};
return [=](auto&&... xs) { return f(xs..., check_context); };
};
};
#else
auto with_check_context = [](auto f) {
return [=](auto&&) {
return [=](auto&&... xs) { return f(xs..., [](auto g) { return g(); }); };
};
};
#endif
auto& contexts = this->impl->contexts;
auto trace_level = value_of(MIGRAPHX_TRACE_EVAL{});
std::vector<argument> ret;
if(exec_env.async)
{
ctx.wait_for(exec_env.queue);
assert(contexts.size() == 1);
contexts.front().wait_for(exec_env.queue);
}
if(trace_level > 0)
......@@ -563,83 +545,79 @@ std::vector<argument> program::eval(parameter_map params, execution_environment
instruction::print(ss, x, ins_names);
ins_out[x] = ss.str();
});
ret = generic_eval(
*this,
ctx,
std::move(params),
with_check_context([&](auto& ins, auto f, auto&& check_context) {
ctx.finish();
std::cout << "Run instruction: " << ins_out.at(ins) << std::endl;
timer t{};
auto result = check_context(f);
double t1 = t.record<milliseconds>();
ctx.finish();
double t2 = t.record<milliseconds>();
std::cout << "Time: " << t1 << "ms, " << t2 << "ms" << std::endl;
if(trace_level > 1 and ins->name().front() != '@' and ins->name() != "load" and
not result.empty())
ret = generic_eval(*this, contexts, std::move(params), [&](instruction_ref ins, auto f) {
auto& ctx = contexts[ins->get_target_id()];
ctx.finish();
std::cout << "Run instruction: " << ins_out.at(ins) << std::endl;
timer t{};
auto result = f();
double t1 = t.record<milliseconds>();
ctx.finish();
double t2 = t.record<milliseconds>();
std::cout << "Time: " << t1 << "ms, " << t2 << "ms" << std::endl;
if(trace_level > 1 and ins->name().front() != '@' and ins->name() != "load" and
not result.empty())
{
migraphx::argument buffer;
try
{
migraphx::argument buffer;
try
{
target tgt = make_target(this->impl->target_name);
buffer = tgt.copy_from(result);
}
catch(const migraphx::exception&)
{
// instruction was run on host then no need to copy buffer from target
buffer = result;
}
catch(...)
{
MIGRAPHX_THROW(
"MIGraphX program execution with MIGRAPHX_TRACE_EVAL failed.\n");
}
if(trace_level == 2)
{
std::cout << "Output has " << to_string_range(classify_argument(buffer))
<< std::endl;
std::cout << "Output: ";
preview_argument(std::cout, buffer);
std::cout << std::endl;
print_statistics(std::cout, buffer);
}
else
{
std::cout << "Output: " << buffer << std::endl;
}
const target& tgt = this->impl->targets.at(ins->get_target_id());
buffer = tgt.copy_from(result);
}
return result;
}));
catch(const migraphx::exception&)
{
// instruction was run on host then no need to copy buffer from target
buffer = result;
}
catch(...)
{
MIGRAPHX_THROW("MIGraphX program execution with MIGRAPHX_TRACE_EVAL failed.\n");
}
if(trace_level == 2)
{
std::cout << "Output has " << to_string_range(classify_argument(buffer))
<< std::endl;
std::cout << "Output: ";
preview_argument(std::cout, buffer);
std::cout << std::endl;
print_statistics(std::cout, buffer);
}
else
{
std::cout << "Output: " << buffer << std::endl;
}
}
return result;
});
}
else
{
ret = generic_eval(*this,
ctx,
std::move(params),
with_check_context([&](auto&, auto f, auto&& check_context) {
return check_context(f);
}));
ret = generic_eval(*this, contexts, std::move(params), [&](auto&&, auto f) { return f(); });
}
if(exec_env.async)
{
ctx.finish_on(exec_env.queue);
assert(contexts.size() == 1);
contexts.front().finish_on(exec_env.queue);
}
return ret;
}
const int program_file_version = 5;
void program::finish() const
{
for(const auto& ctx : this->impl->contexts)
ctx.finish();
}
const int program_file_version = 6;
value program::to_value() const
{
value result;
result["version"] = program_file_version;
result["target"] = this->impl->target_name;
if(not this->impl->target_name.empty())
result["context"] = this->impl->ctx.to_value();
result["version"] = program_file_version;
result["targets"] = migraphx::to_value(this->impl->targets);
result["contexts"] = migraphx::to_value(this->impl->contexts);
value module_vals = value::object{};
std::unordered_map<instruction_ref, std::string> names;
......@@ -768,12 +746,12 @@ void program::from_value(const value& v)
MIGRAPHX_THROW("Warning: Program version mismatch");
}
this->impl->target_name = v.at("target").to<std::string>();
if(not this->impl->target_name.empty())
migraphx::from_value(v.at("targets"), this->impl->targets);
for(auto i : range(this->impl->targets.size()))
{
target t = make_target(this->impl->target_name);
this->impl->ctx = t.get_context();
this->impl->ctx.from_value(v.at("context"));
this->impl->contexts.push_back(this->impl->targets[i].get_context());
this->impl->contexts.back().from_value(v.at("contexts")[i]);
}
auto module_vals = v.at("modules");
......@@ -794,7 +772,9 @@ void program::from_value(const value& v)
auto* mm = get_main_module();
mod_from_val(mm, module_vals, map_insts, map_mods);
this->finalize();
// Finalize a compiled model
if(not this->impl->contexts.empty())
this->finalize();
}
double common_average(const std::vector<double>& v)
......@@ -814,19 +794,19 @@ std::string perf_group(const operation& op)
void program::mark(const parameter_map& params, marker&& m)
{
auto& ctx = this->impl->ctx;
auto& ctx = this->impl->contexts;
// Run once by itself
eval(params);
ctx.finish();
this->finish();
// Start marking
m.mark_start(*this);
generic_eval(*this, ctx, params, always([&](auto ins, auto f) {
generic_eval(*this, ctx, params, [&](auto ins, auto f) {
argument result;
m.mark_start(ins);
result = f();
m.mark_stop(ins);
return result;
}));
});
m.mark_stop(*this);
}
......@@ -835,10 +815,10 @@ void program::perf_report(std::ostream& os,
parameter_map params,
std::size_t batch) const
{
auto& ctx = this->impl->ctx;
auto& ctx = this->impl->contexts;
// Run once by itself
eval(params);
ctx.finish();
this->finish();
// Run and time entire program
std::vector<double> total_vec;
total_vec.reserve(n);
......@@ -846,28 +826,28 @@ void program::perf_report(std::ostream& os,
{
total_vec.push_back(time<milliseconds>([&] {
eval(params);
ctx.finish();
this->finish();
}));
}
std::sort(total_vec.begin(), total_vec.end());
std::unordered_map<instruction_ref, std::vector<double>> ins_vec;
// Fill the map
generic_eval(*this, ctx, params, always([&](auto ins, auto) {
generic_eval(*this, ctx, params, [&](auto ins, auto) {
ins_vec[ins].reserve(n);
return argument{ins->get_shape(), nullptr};
}));
});
// Run and time each instruction
for(std::size_t i = 0; i < n; i++)
{
generic_eval(*this, ctx, params, always([&](auto ins, auto f) {
generic_eval(*this, ctx, params, [&](auto ins, auto f) {
argument result;
ins_vec[ins].push_back(time<milliseconds>([&] {
result = f();
ctx.finish();
this->impl->contexts[ins->get_target_id()].finish();
}));
return result;
}));
});
}
for(auto&& p : ins_vec)
std::sort(p.second.begin(), p.second.end());
......@@ -1035,10 +1015,10 @@ void program::print_cpp(std::ostream& os) const
void program::dry_run(std::unordered_map<std::string, argument> params) const
{
auto& ctx = this->impl->ctx;
generic_eval(*this, ctx, std::move(params), always([](auto ins, auto&&...) {
auto& ctx = this->impl->contexts;
generic_eval(*this, ctx, std::move(params), [](auto ins, auto&&...) {
return argument{ins->get_shape(), nullptr};
}));
});
}
void program::annotate(std::ostream& os, const std::function<void(instruction_ref)>& a) const
......
/*
* The MIT License (MIT)
*
* Copyright (c) 2015-2022 Advanced Micro Devices, Inc. All rights reserved.
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
* in the Software without restriction, including without limitation the rights
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
* copies of the Software, and to permit persons to whom the Software is
* furnished to do so, subject to the following conditions:
*
* The above copyright notice and this permission notice shall be included in
* all copies or substantial portions of the Software.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
* THE SOFTWARE.
*/
#include <migraphx/target.hpp>
#include <migraphx/register_target.hpp>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
void migraphx_to_value(value& v, const target& t) { v["name"] = t.name(); }
void migraphx_from_value(const value& v, target& t)
{
t = make_target(v.at("name").to<std::string>());
}
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
......@@ -41,7 +41,7 @@ TEST_CASE(simple_test)
auto* mm = p.get_main_module();
auto one = mm->add_literal(1);
auto two = mm->add_literal(2);
mm->add_instruction(sum_op{}, one, two);
mm->add_instruction(migraphx::make_op("add"), one, two);
auto count = std::distance(mm->begin(), mm->end());
run_pass(p);
EXPECT(std::distance(mm->begin(), mm->end()) == count);
......@@ -57,7 +57,7 @@ TEST_CASE(simple_test_nop)
auto one = mm->add_literal(1);
auto two = mm->add_literal(2);
mm->add_instruction(nop{});
mm->add_instruction(sum_op{}, one, two);
mm->add_instruction(migraphx::make_op("add"), one, two);
auto count = std::distance(mm->begin(), mm->end());
run_pass(p);
EXPECT(std::distance(mm->begin(), mm->end()) == count);
......@@ -73,7 +73,7 @@ TEST_CASE(simple_test_nop2)
auto one = mm->add_literal(1);
auto two = mm->add_literal(2);
mm->add_instruction(nop{});
mm->add_instruction(sum_op{}, one, two);
mm->add_instruction(migraphx::make_op("add"), one, two);
mm->add_instruction(nop{});
run_pass(p);
EXPECT(std::distance(mm->begin(), mm->end()) == 2);
......@@ -88,8 +88,8 @@ TEST_CASE(duplicate_test1)
auto* mm = p.get_main_module();
auto one = mm->add_literal(1);
auto two = mm->add_literal(2);
mm->add_instruction(sum_op{}, one, two);
mm->add_instruction(sum_op{}, one, two);
mm->add_instruction(migraphx::make_op("add"), one, two);
mm->add_instruction(migraphx::make_op("add"), one, two);
auto count = std::distance(mm->begin(), mm->end());
run_pass(p);
EXPECT(std::distance(mm->begin(), mm->end()) == (count - 1));
......@@ -104,9 +104,9 @@ TEST_CASE(duplicate_test2)
auto* mm = p.get_main_module();
auto one = mm->add_literal(1);
auto two = mm->add_literal(2);
mm->add_instruction(sum_op{}, one, two);
mm->add_instruction(minus_op{}, one, two);
mm->add_instruction(sum_op{}, one, two);
mm->add_instruction(migraphx::make_op("add"), one, two);
mm->add_instruction(migraphx::make_op("sub"), one, two);
mm->add_instruction(migraphx::make_op("add"), one, two);
auto count = std::distance(mm->begin(), mm->end());
run_pass(p);
EXPECT(std::distance(mm->begin(), mm->end()) == (count - 2));
......@@ -121,11 +121,11 @@ TEST_CASE(depth_test)
auto* mm = p.get_main_module();
auto one = mm->add_literal(1);
auto two = mm->add_literal(2);
auto x1 = mm->add_instruction(sum_op{}, one, two);
auto x2 = mm->add_instruction(sum_op{}, one, two);
mm->add_instruction(minus_op{}, x1, x2);
mm->add_instruction(minus_op{}, x1, x2);
mm->add_instruction(sum_op{}, one, two);
auto x1 = mm->add_instruction(migraphx::make_op("add"), one, two);
auto x2 = mm->add_instruction(migraphx::make_op("add"), one, two);
mm->add_instruction(migraphx::make_op("sub"), x1, x2);
mm->add_instruction(migraphx::make_op("sub"), x1, x2);
mm->add_instruction(migraphx::make_op("add"), one, two);
auto count = std::distance(mm->begin(), mm->end());
run_pass(p);
EXPECT(std::distance(mm->begin(), mm->end()) == (count - 4));
......@@ -141,7 +141,7 @@ TEST_CASE(undefined_test)
auto one = mm->add_literal(1);
auto two = mm->add_literal(2);
mm->add_instruction(migraphx::make_op("undefined"));
mm->add_instruction(sum_op{}, one, two);
mm->add_instruction(migraphx::make_op("add"), one, two);
auto count = std::distance(mm->begin(), mm->end());
run_pass(p);
EXPECT(std::distance(mm->begin(), mm->end()) == count - 1);
......
......@@ -45,7 +45,7 @@ TEST_CASE(simple_test)
auto one_identity = mm->add_instruction(migraphx::make_op("identity"), one);
auto two = mm->add_literal(2);
auto two_identity = mm->add_instruction(migraphx::make_op("identity"), two);
mm->add_instruction(sum_op{}, one_identity, two_identity);
mm->add_instruction(migraphx::make_op("add"), one_identity, two_identity);
run_pass(p);
EXPECT(std::none_of(mm->begin(), mm->end(), [](const migraphx::instruction& ins) {
return ins.name() == "identity";
......@@ -62,7 +62,7 @@ TEST_CASE(simple_test_end)
auto one = mm->add_literal(1);
auto two = mm->add_literal(2);
auto ans = mm->add_instruction(sum_op{}, one, two);
auto ans = mm->add_instruction(migraphx::make_op("add"), one, two);
mm->add_instruction(migraphx::make_op("identity"), ans);
run_pass(p);
EXPECT(std::none_of(mm->begin(), mm->end(), [](const migraphx::instruction& ins) {
......@@ -81,8 +81,8 @@ TEST_CASE(simple_test_end_dependency)
auto one = mm->add_literal(1.0);
auto two = mm->add_literal(2.0);
auto three = mm->add_literal(3.0);
auto ans = mm->add_instruction(sum_op{}, one, two);
mm->add_instruction(sum_op{}, ans, three);
auto ans = mm->add_instruction(migraphx::make_op("add"), one, two);
mm->add_instruction(migraphx::make_op("add"), ans, three);
mm->add_instruction(migraphx::make_op("identity"), ans);
run_pass(p);
EXPECT(std::any_of(mm->begin(), mm->end(), [](const migraphx::instruction& ins) {
......
......@@ -27,6 +27,7 @@
#include <migraphx/instruction.hpp>
#include <migraphx/stringutils.hpp>
#include <migraphx/compile_options.hpp>
#include <migraphx/make_op.hpp>
#include <sstream>
#include "test.hpp"
#include <basic_ops.hpp>
......@@ -49,7 +50,7 @@ struct id_target
struct id_ctx_op
{
std::string name() const { return "id_ctx_op"; }
std::string name() const { return ""; }
migraphx::argument
compute(id_target::context&, const migraphx::shape&, std::vector<migraphx::argument> args) const
{
......@@ -156,7 +157,7 @@ TEST_CASE(literal_test1)
auto* mm = p.get_main_module();
auto one = mm->add_literal(1);
auto two = mm->add_literal(2);
mm->add_instruction(sum_op{}, one, two);
mm->add_instruction(migraphx::make_op("add"), one, two);
auto result = p.eval({}).back();
EXPECT(result == migraphx::literal{3});
EXPECT(result != migraphx::literal{4});
......@@ -168,8 +169,8 @@ TEST_CASE(literal_test2)
auto* mm = p.get_main_module();
auto one = mm->add_literal(1);
auto two = mm->add_literal(2);
auto sum1 = mm->add_instruction(sum_op{}, one, two);
mm->add_instruction(sum_op{}, sum1, two);
auto sum1 = mm->add_instruction(migraphx::make_op("add"), one, two);
mm->add_instruction(migraphx::make_op("add"), sum1, two);
auto result = p.eval({}).back();
EXPECT(result == migraphx::literal{5});
......@@ -182,7 +183,7 @@ TEST_CASE(print_test)
auto* mm = p.get_main_module();
auto x = mm->add_parameter("x", {migraphx::shape::int32_type});
auto two = mm->add_literal(2);
mm->add_instruction(sum_op{}, x, two);
mm->add_instruction(migraphx::make_op("add"), x, two);
std::stringstream ss;
ss << p;
......@@ -197,7 +198,7 @@ TEST_CASE(param_test)
auto x = mm->add_parameter("x", {migraphx::shape::int32_type});
auto y = mm->add_parameter("y", {migraphx::shape::int32_type});
mm->add_instruction(sum_op{}, x, y);
mm->add_instruction(migraphx::make_op("add"), x, y);
auto result = p.eval({{"x", migraphx::literal{1}.get_argument()},
{"y", migraphx::literal{2}.get_argument()}})
.back();
......@@ -227,7 +228,7 @@ TEST_CASE(param_error_shape_test)
auto x = mm->add_parameter("x", {migraphx::shape::int32_type, {1, 1}});
auto y = mm->add_parameter("y", {migraphx::shape::int32_type, {1, 1}});
mm->add_instruction(sum_op{}, x, y);
mm->add_instruction(migraphx::make_op("add"), x, y);
EXPECT(test::throws<migraphx::exception>(
[&] {
p.eval({
......@@ -245,7 +246,7 @@ TEST_CASE(get_param1)
migraphx::shape s{migraphx::shape::int32_type, {1, 2}};
auto x = mm->add_parameter("x", s);
auto y = mm->add_parameter("y", s);
mm->add_instruction(sum_op{}, x, y);
mm->add_instruction(migraphx::make_op("add"), x, y);
EXPECT(bool{p.get_parameter("x") == x});
EXPECT(bool{p.get_parameter("y") == y});
EXPECT(bool{p.get_parameter("nonexistent") == mm->end()});
......@@ -257,7 +258,7 @@ TEST_CASE(get_param2)
auto* mm = p.get_main_module();
auto one = mm->add_literal(1);
auto two = mm->add_literal(2);
mm->add_instruction(sum_op{}, one, two);
mm->add_instruction(migraphx::make_op("add"), one, two);
EXPECT(bool{p.get_parameter("nonexistent") == mm->end()});
}
......@@ -268,7 +269,7 @@ TEST_CASE(get_param_shapes)
migraphx::shape s{migraphx::shape::int32_type, {1, 2}};
auto x = mm->add_parameter("x", s);
auto y = mm->add_parameter("y", s);
mm->add_instruction(sum_op{}, x, y);
mm->add_instruction(migraphx::make_op("add"), x, y);
auto m = p.get_parameter_shapes();
EXPECT(m.count("nonexistent") == 0);
EXPECT(m.at("x") == s);
......@@ -281,8 +282,8 @@ TEST_CASE(replace_test)
auto* mm = p.get_main_module();
auto one = mm->add_literal(1);
auto two = mm->add_literal(2);
auto sum = mm->add_instruction(sum_op{}, one, two);
mm->replace_instruction(sum, minus_op{}, two, one);
auto sum = mm->add_instruction(migraphx::make_op("add"), one, two);
mm->replace_instruction(sum, migraphx::make_op("sub"), two, one);
EXPECT(bool{p.validate() == mm->end()});
auto result = p.eval({}).back();
......@@ -296,8 +297,8 @@ TEST_CASE(replace_ins_test)
auto* mm = p.get_main_module();
auto one = mm->add_literal(1);
auto two = mm->add_literal(2);
auto sum = mm->add_instruction(sum_op{}, one, two);
auto minus = mm->add_instruction(minus_op{}, two, one);
auto sum = mm->add_instruction(migraphx::make_op("add"), one, two);
auto minus = mm->add_instruction(migraphx::make_op("sub"), two, one);
mm->replace_instruction(sum, minus);
EXPECT(bool{p.validate() == mm->end()});
......@@ -312,8 +313,8 @@ TEST_CASE(replace_ins_test2)
auto* mm = p.get_main_module();
auto one = mm->add_literal(1);
auto two = mm->add_literal(2);
auto sum = mm->add_instruction(sum_op{}, one, two);
auto minus = mm->add_instruction(minus_op{}, two, one);
auto sum = mm->add_instruction(migraphx::make_op("add"), one, two);
auto minus = mm->add_instruction(migraphx::make_op("sub"), two, one);
mm->add_instruction(pass_op{}, minus);
mm->replace_instruction(two, sum);
EXPECT(bool{p.validate() == mm->end()});
......@@ -329,8 +330,8 @@ TEST_CASE(replace_op_test)
auto* mm = p.get_main_module();
auto one = mm->add_literal(1);
auto two = mm->add_literal(2);
auto sum = mm->add_instruction(sum_op{}, two, one);
sum->replace(minus_op{});
auto sum = mm->add_instruction(migraphx::make_op("add"), two, one);
sum->replace(migraphx::make_op("sub"));
EXPECT(bool{p.validate() == mm->end()});
auto result = p.eval({}).back();
......@@ -344,7 +345,7 @@ TEST_CASE(replace_op_recompute_shape_throw)
auto* mm = p.get_main_module();
auto one = mm->add_literal(1);
auto two = mm->add_literal(2);
auto sum = mm->add_instruction(sum_op{}, one, two);
auto sum = mm->add_instruction(migraphx::make_op("add"), one, two);
EXPECT(test::throws<migraphx::exception>([&] { sum->replace(unary_pass_op{}); }));
}
......@@ -354,11 +355,11 @@ TEST_CASE(insert_replace_test)
auto* mm = p.get_main_module();
auto one = mm->add_literal(1);
auto two = mm->add_literal(2);
auto sum1 = mm->add_instruction(sum_op{}, one, two);
mm->add_instruction(sum_op{}, sum1, two);
auto sum1 = mm->add_instruction(migraphx::make_op("add"), one, two);
mm->add_instruction(migraphx::make_op("add"), sum1, two);
auto sum0 = mm->insert_instruction(sum1, sum_op{}, two, two);
mm->replace_instruction(sum1, minus_op{}, sum0, two);
auto sum0 = mm->insert_instruction(sum1, migraphx::make_op("add"), two, two);
mm->replace_instruction(sum1, migraphx::make_op("sub"), sum0, two);
EXPECT(bool{p.validate() == mm->end()});
auto result = p.eval({}).back();
......@@ -372,8 +373,8 @@ TEST_CASE(remove_test1)
auto* mm = p.get_main_module();
auto one = mm->add_literal(1);
auto two = mm->add_literal(2);
auto sum = mm->add_instruction(sum_op{}, one, two);
auto removed = mm->add_instruction(minus_op{}, sum, one);
auto sum = mm->add_instruction(migraphx::make_op("add"), one, two);
auto removed = mm->add_instruction(migraphx::make_op("sub"), sum, one);
mm->remove_instruction(removed);
EXPECT(bool{p.validate() == mm->end()});
......@@ -388,8 +389,8 @@ TEST_CASE(remove_test2)
auto* mm = p.get_main_module();
auto one = mm->add_literal(1);
auto two = mm->add_literal(2);
auto removed = mm->add_instruction(minus_op{}, two, one);
mm->add_instruction(sum_op{}, one, two);
auto removed = mm->add_instruction(migraphx::make_op("sub"), two, one);
mm->add_instruction(migraphx::make_op("add"), one, two);
mm->remove_instruction(removed);
EXPECT(bool{p.validate() == mm->end()});
......@@ -404,7 +405,7 @@ TEST_CASE(target_test)
auto* mm = p.get_main_module();
auto one = mm->add_literal(1);
auto two = mm->add_literal(2);
mm->add_instruction(sum_op{}, one, two);
mm->add_instruction(migraphx::make_op("add"), one, two);
p.compile(id_target{});
auto result = p.eval({}).back();
EXPECT(result == migraphx::literal{3});
......
......@@ -86,6 +86,25 @@ struct minus_op
};
struct pass_op
{
std::string name() const { return "pass"; }
migraphx::argument compute(const migraphx::shape&, std::vector<migraphx::argument> args) const
{
if(args.empty())
return {};
return args.front();
}
migraphx::shape compute_shape(std::vector<migraphx::shape> inputs) const
{
if(inputs.empty())
return {};
return inputs.front();
}
int output_alias(const std::vector<migraphx::shape>& s) const { return s.empty() ? -1 : 0; }
};
struct non_const_pass_op
{
std::string name() const { return "pass"; }
migraphx::argument
......@@ -176,9 +195,7 @@ struct pass_standard_op
struct nop
{
std::string name() const { return "nop"; }
migraphx::argument compute(migraphx::context&,
const migraphx::shape&,
const std::vector<migraphx::argument>&) const
migraphx::argument compute(const migraphx::shape&, const std::vector<migraphx::argument>&) const
{
return {};
}
......
......@@ -40,12 +40,12 @@ TEST_CASE(const_add)
auto one = m1.add_literal(1);
auto two = m1.add_literal(2);
auto sum = m1.add_instruction(migraphx::make_op("add"), one, two);
m1.add_instruction(pass_op{}, sum);
m1.add_instruction(non_const_pass_op{}, sum);
run_pass(m1);
migraphx::module m2;
auto total = m2.add_literal(3);
m2.add_instruction(pass_op{}, total);
m2.add_instruction(non_const_pass_op{}, total);
EXPECT(m1 == m2);
}
......@@ -55,12 +55,12 @@ TEST_CASE(const_add_parameter)
auto one = m1.add_parameter("one", {migraphx::shape::int32_type, {1}});
auto two = m1.add_literal(2);
auto sum = m1.add_instruction(migraphx::make_op("add"), one, two);
m1.add_instruction(pass_op{}, sum);
m1.add_instruction(non_const_pass_op{}, sum);
run_pass(m1);
migraphx::module m2;
auto total = m2.add_literal(3);
m2.add_instruction(pass_op{}, total);
m2.add_instruction(non_const_pass_op{}, total);
EXPECT(m1 != m2);
}
......@@ -71,12 +71,12 @@ TEST_CASE(const_multiadd)
auto two = m1.add_literal(2);
auto sum1 = m1.add_instruction(migraphx::make_op("add"), one, two);
auto sum2 = m1.add_instruction(migraphx::make_op("add"), sum1, two);
m1.add_instruction(pass_op{}, sum2);
m1.add_instruction(non_const_pass_op{}, sum2);
run_pass(m1);
migraphx::module m2;
auto total = m2.add_literal(5);
m2.add_instruction(pass_op{}, total);
m2.add_instruction(non_const_pass_op{}, total);
EXPECT(m1 == m2);
}
......@@ -88,12 +88,12 @@ TEST_CASE(const_add_mul)
auto mul = m1.add_instruction(migraphx::make_op("mul"), two, two);
auto sum1 = m1.add_instruction(migraphx::make_op("add"), one, mul);
auto sum2 = m1.add_instruction(migraphx::make_op("add"), sum1, two);
m1.add_instruction(pass_op{}, sum2);
m1.add_instruction(non_const_pass_op{}, sum2);
run_pass(m1);
migraphx::module m2;
auto total = m2.add_literal(7);
m2.add_instruction(pass_op{}, total);
m2.add_instruction(non_const_pass_op{}, total);
EXPECT(m1 == m2);
}
......@@ -105,13 +105,13 @@ TEST_CASE(const_add_scalar)
auto two = m1.add_instruction(migraphx::make_op("scalar", {{"scalar_bcst_dims", {2, 2}}}),
m1.add_literal(2));
auto sum = m1.add_instruction(migraphx::make_op("add"), one, two);
m1.add_instruction(pass_op{}, sum);
m1.add_instruction(non_const_pass_op{}, sum);
run_pass(m1);
migraphx::module m2;
auto total =
m2.add_literal(migraphx::literal{{migraphx::shape::int32_type, {2, 2}}, {3, 3, 3, 3}});
m2.add_instruction(pass_op{}, total);
m2.add_instruction(non_const_pass_op{}, total);
EXPECT(m1 == m2);
}
......@@ -121,7 +121,7 @@ TEST_CASE(const_scalar)
{
auto one = m1.add_instruction(migraphx::make_op("scalar", {{"scalar_bcst_dims", {2, 2}}}),
m1.add_literal(1));
m1.add_instruction(pass_op{}, one);
m1.add_instruction(non_const_pass_op{}, one);
}
run_pass(m1);
......@@ -129,7 +129,7 @@ TEST_CASE(const_scalar)
{
auto one = m2.add_instruction(migraphx::make_op("scalar", {{"scalar_bcst_dims", {2, 2}}}),
m2.add_literal(1));
m2.add_instruction(pass_op{}, one);
m2.add_instruction(non_const_pass_op{}, one);
}
EXPECT(m1 == m2);
}
......
......@@ -23,6 +23,7 @@
*/
#include <migraphx/program.hpp>
#include <migraphx/instruction.hpp>
#include <migraphx/make_op.hpp>
#include <basic_ops.hpp>
#include <test.hpp>
#include <rob.hpp>
......@@ -33,7 +34,7 @@ TEST_CASE(simple_test)
auto* mm = p.get_main_module();
auto one = mm->add_literal(1);
auto two = mm->add_literal(2);
mm->add_instruction(sum_op{}, one, two);
mm->add_instruction(migraphx::make_op("add"), one, two);
EXPECT(bool{mm->validate() == mm->end()});
auto result = p.eval({});
EXPECT(result.back() == migraphx::literal{3});
......@@ -46,7 +47,7 @@ TEST_CASE(out_of_order)
auto* mm = p.get_main_module();
auto one = mm->add_literal(1);
auto two = mm->add_literal(2);
auto ins = mm->add_instruction(sum_op{}, one, two);
auto ins = mm->add_instruction(migraphx::make_op("add"), one, two);
mm->move_instruction(two, mm->end());
EXPECT(bool{p.validate() == ins});
}
......@@ -57,7 +58,7 @@ TEST_CASE(incomplete_args)
auto* mm = p.get_main_module();
auto one = mm->add_literal(1);
auto two = mm->add_literal(2);
auto ins = mm->add_instruction(sum_op{}, one, two);
auto ins = mm->add_instruction(migraphx::make_op("add"), one, two);
ins->clear_arguments();
EXPECT(bool{p.validate() == ins});
}
......@@ -73,7 +74,7 @@ TEST_CASE(invalid_args)
auto* mm = p.get_main_module();
auto one = mm->add_literal(1);
auto two = mm->add_literal(2);
auto ins = mm->add_instruction(sum_op{}, one, two);
auto ins = mm->add_instruction(migraphx::make_op("add"), one, two);
access_ins_arguments(*ins).clear();
EXPECT(bool{mm->validate() == mm->begin()});
}
......
......@@ -261,11 +261,13 @@ auto compute_op(rank<1>,
template <class T, class F>
argument compute_op(rank<0>,
const T& x,
const shape&,
const std::vector<argument>&,
const std::vector<module_ref>&,
const shape& output,
const std::vector<argument>& inputs,
const std::vector<module_ref>& module_args,
F)
{
if(module_args.empty())
return compute_op(x, output, inputs);
std::string name = x.name();
MIGRAPHX_THROW("Not computable: " + name);
}
......
......@@ -45,6 +45,8 @@
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
struct value;
#ifdef DOXYGEN
/// An interface for a compilation target
......@@ -123,28 +125,41 @@ supported_segments target_find_supported(T&, const_module_ref, support_metric)
}
<%
interface('target',
virtual('name', returns='std::string', const=True),
virtual('get_passes', ctx='context&', options='const compile_options&', returns='std::vector<pass>', const=True),
virtual('get_context', returns='context', const=True),
virtual('find_supported', returns='supported_segments', mod='const_module_ref', m='support_metric', const=True, default='target_find_supported'),
virtual('copy_to',
returns = 'argument',
input = 'const argument&',
const = True,
default = 'copy_to_target'),
virtual('copy_from',
returns = 'argument',
input = 'const argument&',
const = True,
default = 'copy_from_target'),
virtual('allocate', s='const shape&', returns='argument', const=True,
default = 'target_allocate')
)
%>
interface('target',
virtual('name', returns = 'std::string', const = True),
virtual('get_passes',
ctx = 'context&',
options = 'const compile_options&',
returns = 'std::vector<pass>',
const = True),
virtual('get_context', returns = 'context', const = True),
virtual('find_supported',
returns = 'supported_segments',
mod = 'const_module_ref',
m = 'support_metric',
const = True,
default = 'target_find_supported'),
virtual('copy_to',
returns = 'argument',
input = 'const argument&',
const = True,
default = 'copy_to_target'),
virtual('copy_from',
returns = 'argument',
input = 'const argument&',
const = True,
default = 'copy_from_target'),
virtual('allocate',
s = 'const shape&',
returns = 'argument',
const = True,
default = 'target_allocate')) %>
#endif
void migraphx_to_value(value& v, const target& t);
void migraphx_from_value(const value& v, target& t);
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment