Unverified Commit bce629f1 authored by Paul Fultz II's avatar Paul Fultz II Committed by GitHub
Browse files

Merge branch 'develop' into gru_operator

parents ce7b4b17 fb8fda8f
...@@ -105,6 +105,8 @@ struct program ...@@ -105,6 +105,8 @@ struct program
void debug_print(instruction_ref ins) const; void debug_print(instruction_ref ins) const;
void debug_print(const std::vector<instruction_ref>& inss) const; void debug_print(const std::vector<instruction_ref>& inss) const;
void dry_run(parameter_map params) const;
friend std::ostream& operator<<(std::ostream& os, const program& p); friend std::ostream& operator<<(std::ostream& os, const program& p);
friend bool operator==(const program& x, const program& y); friend bool operator==(const program& x, const program& y);
friend bool operator!=(const program& x, const program& y) { return !(x == y); } friend bool operator!=(const program& x, const program& y) { return !(x == y); }
......
...@@ -97,7 +97,7 @@ const std::vector<instruction_ref>& instruction::outputs() const { return output ...@@ -97,7 +97,7 @@ const std::vector<instruction_ref>& instruction::outputs() const { return output
bool operator==(const instruction& x, const instruction& y) bool operator==(const instruction& x, const instruction& y)
{ {
if(not(x.result == y.result and x.op == y.op and x.arguments == y.arguments)) if(std::tie(x.result, x.op, x.arguments) != std::tie(y.result, y.op, y.arguments))
return false; return false;
if(x.name() == "@literal") if(x.name() == "@literal")
return x.lit == y.lit; return x.lit == y.lit;
......
...@@ -379,20 +379,31 @@ argument generic_eval(const program& p, ...@@ -379,20 +379,31 @@ argument generic_eval(const program& p,
argument program::eval(std::unordered_map<std::string, argument> params) const argument program::eval(std::unordered_map<std::string, argument> params) const
{ {
auto& ctx = this->impl->ctx;
#ifndef NDEBUG
auto sctx = ctx;
auto check_context = [&](auto f) {
assert(is_shared(ctx, sctx));
auto x = f();
sctx = ctx;
return x;
};
#else
auto check_context = [](auto f) { return f(); };
#endif
if(enabled(MIGRAPHX_TRACE_EVAL{})) if(enabled(MIGRAPHX_TRACE_EVAL{}))
{ {
auto& ctx = this->impl->ctx; return generic_eval(*this, ctx, std::move(params), [&](auto& ins, auto f) {
return generic_eval(*this, this->impl->ctx, std::move(params), [&](auto& ins, auto f) {
ctx.finish(); ctx.finish();
std::cout << "Run instruction: "; std::cout << "Run instruction: ";
this->debug_print(ins); this->debug_print(ins);
return f(); return check_context(f);
}); });
} }
else else
{ {
return generic_eval( return generic_eval(
*this, this->impl->ctx, std::move(params), [](auto&, auto f) { return f(); }); *this, ctx, std::move(params), [&](auto&, auto f) { return check_context(f); });
} }
} }
...@@ -446,8 +457,7 @@ void program::perf_report(std::ostream& os, std::size_t n, parameter_map params) ...@@ -446,8 +457,7 @@ void program::perf_report(std::ostream& os, std::size_t n, parameter_map params)
overhead_vec.reserve(n); overhead_vec.reserve(n);
for(std::size_t i = 0; i < n; i++) for(std::size_t i = 0; i < n; i++)
{ {
overhead_vec.push_back(time<milliseconds>( overhead_vec.push_back(time<milliseconds>([&] { dry_run(params); }));
[&] { generic_eval(*this, ctx, params, [](auto...) { return argument{}; }); }));
} }
double total_time = common_average(total_vec); double total_time = common_average(total_vec);
...@@ -511,6 +521,12 @@ void program::debug_print(const std::vector<instruction_ref>& inss) const ...@@ -511,6 +521,12 @@ void program::debug_print(const std::vector<instruction_ref>& inss) const
std::cout << std::endl; std::cout << std::endl;
} }
void program::dry_run(std::unordered_map<std::string, argument> params) const
{
auto& ctx = this->impl->ctx;
generic_eval(*this, ctx, std::move(params), [](auto&&...) { return argument{}; });
}
bool operator==(const program& x, const program& y) { return to_string(x) == to_string(y); } bool operator==(const program& x, const program& y) { return to_string(x) == to_string(y); }
std::ostream& operator<<(std::ostream& os, const program& p) std::ostream& operator<<(std::ostream& os, const program& p)
......
...@@ -122,7 +122,14 @@ migraphx::argument run_gpu(migraphx::program& p) ...@@ -122,7 +122,14 @@ migraphx::argument run_gpu(migraphx::program& p)
m[x.first] = m[x.first] =
migraphx::gpu::to_gpu(migraphx::generate_argument(x.second, get_hash(x.first))); migraphx::gpu::to_gpu(migraphx::generate_argument(x.second, get_hash(x.first)));
} }
// Program should have an output parameter
EXPECT(bool{m.find("output") != m.end()}); EXPECT(bool{m.find("output") != m.end()});
// Ensure the program doesn't modify the context in a dry run
auto ctx = p.get_context();
assert(&ctx != &p.get_context());
EXPECT(is_shared(ctx, p.get_context()));
p.dry_run(m);
EXPECT(is_shared(ctx, p.get_context()));
return migraphx::gpu::from_gpu(p.eval(m)); return migraphx::gpu::from_gpu(p.eval(m));
} }
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment