Commit 644f0bc6 authored by Paul's avatar Paul
Browse files

Add output during tracing

parent 0f8c0360
...@@ -9,6 +9,7 @@ ...@@ -9,6 +9,7 @@
#include <migraphx/instruction_ref.hpp> #include <migraphx/instruction_ref.hpp>
#include <migraphx/target.hpp> #include <migraphx/target.hpp>
#include <migraphx/tracer.hpp> #include <migraphx/tracer.hpp>
#include <migraphx/env.hpp>
#include <migraphx/config.hpp> #include <migraphx/config.hpp>
#include <algorithm> #include <algorithm>
#include <iostream> #include <iostream>
...@@ -16,6 +17,9 @@ ...@@ -16,6 +17,9 @@
namespace migraphx { namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS { inline namespace MIGRAPHX_INLINE_NS {
MIGRAPHX_DECLARE_ENV_VAR(MIGRAPHX_TRACE_COMPILE)
MIGRAPHX_DECLARE_ENV_VAR(MIGRAPHX_TRACE_EVAL)
struct program_impl; struct program_impl;
const operation& get_operation(instruction_ref ins); const operation& get_operation(instruction_ref ins);
...@@ -107,6 +111,8 @@ struct program ...@@ -107,6 +111,8 @@ struct program
void dry_run(parameter_map params) const; void dry_run(parameter_map params) const;
void annotate(std::ostream& os, std::function<void(instruction_ref)> a) const;
friend std::ostream& operator<<(std::ostream& os, const program& p); friend std::ostream& operator<<(std::ostream& os, const program& p);
friend bool operator==(const program& x, const program& y); friend bool operator==(const program& x, const program& y);
friend bool operator!=(const program& x, const program& y) { return !(x == y); } friend bool operator!=(const program& x, const program& y) { return !(x == y); }
......
...@@ -15,9 +15,6 @@ ...@@ -15,9 +15,6 @@
namespace migraphx { namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS { inline namespace MIGRAPHX_INLINE_NS {
MIGRAPHX_DECLARE_ENV_VAR(MIGRAPHX_TRACE_COMPILE)
MIGRAPHX_DECLARE_ENV_VAR(MIGRAPHX_TRACE_EVAL)
struct program_impl struct program_impl
{ {
// A list is used to keep references to an instruction stable // A list is used to keep references to an instruction stable
...@@ -532,6 +529,13 @@ void program::dry_run(std::unordered_map<std::string, argument> params) const ...@@ -532,6 +529,13 @@ void program::dry_run(std::unordered_map<std::string, argument> params) const
generic_eval(*this, ctx, std::move(params), [](auto&&...) { return argument{}; }); generic_eval(*this, ctx, std::move(params), [](auto&&...) { return argument{}; });
} }
void program::annotate(std::ostream& os, std::function<void(instruction_ref)> a) const
{
print_program(os, *this, [&](auto ins, auto&&) {
a(ins);
});
}
bool operator==(const program& x, const program& y) { return to_string(x) == to_string(y); } bool operator==(const program& x, const program& y) { return to_string(x) == to_string(y); }
std::ostream& operator<<(std::ostream& os, const program& p) std::ostream& operator<<(std::ostream& os, const program& p)
......
...@@ -224,6 +224,17 @@ void schedule::apply(program& p) const ...@@ -224,6 +224,17 @@ void schedule::apply(program& p) const
self(i); self(i);
})(last); })(last);
if(enabled(MIGRAPHX_TRACE_COMPILE{}))
{
p.annotate(std::cout, [&](auto ins) {
std::cout << ":";
std::cout << " weight=" << si.weights.at(ins);
if (si.has_stream(ins))
std::cout << " stream=" << si.get_stream(ins);
});
std::cout << std::endl;
}
// Schedule instructions // Schedule instructions
for(auto ins : iterator_for(p)) for(auto ins : iterator_for(p))
{ {
......
...@@ -120,14 +120,13 @@ bool check_conflicts(migraphx::program& p, migraphx::instruction_ref x, migraphx ...@@ -120,14 +120,13 @@ bool check_conflicts(migraphx::program& p, migraphx::instruction_ref x, migraphx
return false; return false;
} }
void check_conflicts(migraphx::program& p, void check_conflicts(migraphx::program& p, std::vector<std::vector<migraphx::instruction_ref>> conflicts)
std::vector<std::vector<migraphx::instruction_ref>> conflicts)
{ {
migraphx::dfor(conflicts.size(), conflicts.size())([&](auto i, auto j) { migraphx::dfor(conflicts.size(), conflicts.size())([&](auto i, auto j) {
if(i == j) if (i == j)
return; return;
for(auto ins1 : conflicts[i]) for(auto ins1:conflicts[i])
for(auto ins2 : conflicts[j]) for(auto ins2:conflicts[j])
CHECK(check_conflicts(p, ins1, ins2)); CHECK(check_conflicts(p, ins1, ins2));
}); });
} }
...@@ -149,12 +148,11 @@ std::vector<std::size_t> get_wait_for(migraphx::instruction_ref ins) ...@@ -149,12 +148,11 @@ std::vector<std::size_t> get_wait_for(migraphx::instruction_ref ins)
return wf; return wf;
} }
template <class T> template<class T>
std::vector<migraphx::instruction_ref> std::vector<migraphx::instruction_ref> chain(migraphx::program& p, std::size_t n, T x, migraphx::instruction_ref input)
chain(migraphx::program& p, std::size_t n, T x, migraphx::instruction_ref input)
{ {
std::vector<migraphx::instruction_ref> result; std::vector<migraphx::instruction_ref> result;
for(std::size_t i = 0; i < n; i++) for(std::size_t i = 0;i < n;i++)
{ {
result.push_back(p.add_instruction(x, input)); result.push_back(p.add_instruction(x, input));
input = result.back(); input = result.back();
...@@ -207,7 +205,7 @@ TEST_CASE(two_weights) ...@@ -207,7 +205,7 @@ TEST_CASE(two_weights)
p.compile(schedule_target{&stream}); p.compile(schedule_target{&stream});
EXPECT(stream.count(one) == 0); EXPECT(stream.count(one) == 0);
EXPECT(stream.at(i1) == 1); EXPECT(stream.at(i1) == 1);
for(auto ins : c1) for(auto ins:c1)
EXPECT(stream.at(ins) == 0); EXPECT(stream.at(ins) == 0);
EXPECT(stream.at(binary) == 0); EXPECT(stream.at(binary) == 0);
EXPECT(get_wait_for(binary) == get_wait_for(stream[binary], {stream[c1.back()], stream[i1]})); EXPECT(get_wait_for(binary) == get_wait_for(stream[binary], {stream[c1.back()], stream[i1]}));
...@@ -223,20 +221,15 @@ TEST_CASE(four_weights) ...@@ -223,20 +221,15 @@ TEST_CASE(four_weights)
auto c2 = chain(p, 3, unary_op{}, one); auto c2 = chain(p, 3, unary_op{}, one);
auto c3 = chain(p, 2, unary_op{}, one); auto c3 = chain(p, 2, unary_op{}, one);
auto i1 = p.add_instruction(unary_op{}, one); auto i1 = p.add_instruction(unary_op{}, one);
auto binary = p.add_instruction(nary_op{}, i1, c1.back()); auto binary = p.add_instruction(nary_op{}, i1, c1.back(), c2.back(), c3.back());
p.compile(schedule_target{&stream}); p.compile(schedule_target{&stream});
EXPECT(stream.count(one) == 0); EXPECT(stream.count(one) == 0);
EXPECT(stream.at(i1) == 3); EXPECT(stream.at(i1) == 3);
for(auto ins : c1) for(auto ins:c1) EXPECT(stream.at(ins) == 0);
EXPECT(stream.at(ins) == 0); for(auto ins:c2) EXPECT(stream.at(ins) == 1);
for(auto ins : c2) for(auto ins:c3) EXPECT(stream.at(ins) == 2);
EXPECT(stream.at(ins) == 1);
for(auto ins : c3)
EXPECT(stream.at(ins) == 2);
EXPECT(stream.at(binary) == 0); EXPECT(stream.at(binary) == 0);
EXPECT(get_wait_for(binary) == EXPECT(get_wait_for(binary) == get_wait_for(stream[binary], {stream[c1.back()], stream[c2.back()], stream[c3.back()], stream[i1]}));
get_wait_for(stream[binary],
{stream[c1.back()], stream[c2.back()], stream[c3.back()], stream[i1]}));
check_conflicts(p, {c1, c2, c3, {i1}}); check_conflicts(p, {c1, c2, c3, {i1}});
} }
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment