Commit ca17bcd6 authored by Paul Fultz II, committed by mvermeulen

Enable scheduler for 1 stream (#399)

* Enable scheduler for 1 stream

* Formatting

* Improve performance of sorting

* Formatting

* Adjust the weight calculation

* Formatting

* Simplify formula

* Formatting

* Avoid division by zero

* Fix scheduler test

* Check for either 1 or 2

* Check for waits when order may change

* Formatting
parent 9799d373
@@ -21,11 +21,11 @@ bool disabled(const char* name)
     return contains({"0", "disable", "disabled", "no", "false"}, e.front());
 }
 
-std::size_t value_of(const char* name)
+std::size_t value_of(const char* name, std::size_t fallback)
 {
     auto e = env(name);
     if(e.empty())
-        return 0;
+        return fallback;
     return std::stoul(e.front());
 }
...
@@ -19,7 +19,7 @@ bool enabled(const char* name);
 bool disabled(const char* name);
 std::vector<std::string> env(const char* name);
-std::size_t value_of(const char* name);
+std::size_t value_of(const char* name, std::size_t fallback = 0);
 
 template <class T>
 bool enabled(T)
@@ -36,9 +36,9 @@ bool disabled(T)
 }
 
 template <class T>
-std::size_t value_of(T)
+std::size_t value_of(T, std::size_t fallback = 0)
 {
-    static const std::size_t result = value_of(T::value());
+    static const std::size_t result = value_of(T::value(), fallback);
     return result;
 }
...
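Note: the new `fallback` argument lets callers distinguish "variable unset" from "variable set to 0", which the HIP context further down uses to default `MIGRAPHX_NSTREAMS` to 1. A minimal standalone sketch of the same pattern, using `std::getenv` instead of MIGraphX's `env()` helper (`value_of_sketch` is an illustrative name, not part of the library):

```
#include <cstddef>
#include <cstdlib>
#include <string>

// Read an integer environment variable, returning `fallback` when it is unset.
// Illustrative only; the real value_of() goes through migraphx::env().
std::size_t value_of_sketch(const char* name, std::size_t fallback = 0)
{
    const char* e = std::getenv(name);
    if(e == nullptr)
        return fallback;
    return std::stoul(e);
}

// Example: default to a single stream unless MIGRAPHX_NSTREAMS is set.
// std::size_t nstreams = value_of_sketch("MIGRAPHX_NSTREAMS", 1);
```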
@@ -384,25 +384,26 @@ argument generic_eval(const program& p,
     values.reserve(16);
     for(auto ins : iterator_for(p))
     {
-        if(ins->name() == "@literal")
+        const auto& name = ins->name();
+        if(name == "@literal")
         {
             results.emplace(ins, trace(ins, [&] { return ins->get_literal().get_argument(); }));
         }
-        else if(ins->name() == "@param")
+        else if(name == "@param")
         {
             results.emplace(
                 ins, trace(ins, [&] {
                     auto param_name = any_cast<builtin::param>(ins->get_operator()).parameter;
                     if(not contains(params, param_name))
                         MIGRAPHX_THROW("Parameter not found: " + param_name);
-                    auto param = params.at(param_name);
+                    auto param = params[param_name];
                     if(param.get_shape() != ins->get_shape())
                         MIGRAPHX_THROW("Incorrect shape {" + to_string(param.get_shape()) +
                                        "} for parameter: " + param_name);
                     return param;
                 }));
         }
-        else if(ins->name() == "@outline")
+        else if(name == "@outline")
         {
             results.emplace(ins, trace(ins, [&] { return argument{ins->get_shape(), nullptr}; }));
         }
...
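Note: the only behavioural point in this hunk is that `ins->name()` is now called once per instruction instead of once per comparison. A toy sketch of the pattern (the `node` type and `dispatch` function here are stand-ins, not MIGraphX classes):

```
#include <string>

struct node
{
    std::string op;
    // Stand-in for instruction::name(); returns a reference, so no copy is made.
    const std::string& name() const { return op; }
};

int dispatch(const node& ins)
{
    // Bind the reference once instead of re-invoking name() for every branch.
    const auto& name = ins.name();
    if(name == "@literal")
        return 0;
    if(name == "@param")
        return 1;
    if(name == "@outline")
        return 2;
    return -1;
}
```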
@@ -9,6 +9,7 @@
 #include <migraphx/ranges.hpp>
 #include <unordered_map>
 #include <unordered_set>
+#include <queue>
 #include <thread>
 #include <mutex>
 #include <set>
@@ -18,6 +19,8 @@
 namespace migraphx {
 inline namespace MIGRAPHX_INLINE_NS {
 
+MIGRAPHX_DECLARE_ENV_VAR(MIGRAPHX_TRACE_SCHEDULE)
+
 auto get_inputs()
 {
     return [](auto i) { return i->inputs(); };
@@ -54,6 +57,17 @@ struct stream_info
         })(last);
     }
 
+    template <class Compare>
+    void sort_args_by_weight(std::vector<instruction_ref>& args, Compare compare) const
+    {
+        if(args.size() < 2)
+            return;
+        std::sort(args.begin(), args.end(), by(compare, [this](auto x) {
+                      return std::make_tuple(
+                          this->weights.at(x), x->inputs().size(), std::addressof(*x));
+                  }));
+    }
+
     std::vector<instruction_ref>::iterator sort_args(std::vector<instruction_ref>& args)
     {
         if(args.size() < 2)
@@ -61,11 +75,8 @@ struct stream_info
             return args.end();
         }
 
-        const std::size_t min_partition_threshold = 2;
-        auto compare = by(std::greater<>{}, [&](auto x) {
-            return std::make_tuple(this->weights[x], x->inputs().size());
-        });
-        std::sort(args.begin(), args.end(), compare);
+        const std::size_t min_partition_threshold = 1;
+        sort_args_by_weight(args, std::greater<>{});
 
         auto it = std::lower_bound(std::next(args.begin()),
                                    args.end(),
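Note: `sort_args_by_weight` orders arguments by a composite key of accumulated weight, then fan-in, then the instruction's address, so equal-weight arguments still compare in a well-defined order and the partitioning no longer depends on the sort implementation's handling of ties. A self-contained sketch of the sort-by-projection idea, with simplified stand-ins for MIGraphX's `by()` helper and for instructions:

```
#include <algorithm>
#include <cstddef>
#include <functional>
#include <memory>
#include <tuple>
#include <vector>

// Simplified stand-in for migraphx's by(): build a comparator from a compare
// function and a key projection.
template <class Compare, class Projection>
auto by(Compare compare, Projection proj)
{
    return [=](const auto& x, const auto& y) { return compare(proj(x), proj(y)); };
}

struct node
{
    std::size_t weight = 0;
    std::size_t fanin  = 0;
};

// Sort heaviest-first; fan-in and then the node's address act as tie-breakers
// so elements with equal weight still have a strict total order.
void sort_by_weight(std::vector<node*>& args)
{
    if(args.size() < 2)
        return;
    std::sort(args.begin(), args.end(), by(std::greater<>{}, [](node* x) {
                  return std::make_tuple(x->weight, x->fanin, std::addressof(*x));
              }));
}
```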
@@ -89,8 +100,9 @@ struct stream_info
         }
     };
 
-    void assign_streams(program& p, std::size_t n)
+    std::size_t assign_streams(program& p, std::size_t n)
     {
+        assert(n > 0);
         partition critical;
         std::unordered_map<instruction_ref, std::deque<partition>> partitions;
         partitions.reserve(weights.size());
@@ -126,19 +138,77 @@ struct stream_info
         // Set the critical partition to stream 0
         set_stream(critical, 0);
-        std::vector<std::size_t> streams(n - 1);
-        // Assign streams for the other partitions
-        for(auto&& ins_part : partitions)
-        {
-            std::sort(
-                ins_part.second.begin(), ins_part.second.end(), by(std::greater<>{}, [](auto&& x) {
-                    return std::make_tuple(x.weight, x.instructions.size());
-                }));
-            for(auto&& part : ins_part.second)
-            {
-                auto stream = std::min_element(streams.begin(), streams.end()) - streams.begin();
-                set_stream(part, stream + 1);
-                streams[stream] += part.weight;
+        if(n == 1)
+        {
+            // Assign streams for the other partitions
+            for(auto&& ins_part : partitions)
+                for(auto&& part : ins_part.second)
+                    set_stream(part, 0);
+            return 1;
+        }
+        else
+        {
+            std::vector<std::size_t> streams(n - 1);
+            // Assign streams for the other partitions
+            for(auto&& ins_part : partitions)
+            {
+                std::sort(ins_part.second.begin(),
+                          ins_part.second.end(),
+                          by(std::greater<>{}, [](auto&& x) {
+                              return std::make_tuple(x.weight, x.instructions.size());
+                          }));
+                for(auto&& part : ins_part.second)
+                {
+                    auto stream =
+                        std::min_element(streams.begin(), streams.end()) - streams.begin();
+                    set_stream(part, stream + 1);
+                    streams[stream] += part.weight;
+                }
+            }
+            return 1 + std::count_if(streams.begin(), streams.end(), [](auto x) { return x > 0; });
+        }
+    }
+
+    using weight_ins = std::pair<std::size_t, instruction_ref>;
+    struct compare_weight_ins
+    {
+        bool operator()(const weight_ins& x, const weight_ins& y) const
+        {
+            return std::make_pair(x.first, std::addressof(*x.second)) <
+                   std::make_pair(y.first, std::addressof(*y.second));
+        }
+    };
+
+    void sort(program& p, std::size_t) const
+    {
+        std::set<weight_ins, compare_weight_ins> children;
+        std::unordered_map<instruction_ref, std::size_t> visited;
+        auto last = std::prev(p.end());
+        auto mw = this->weights.at(last);
+        auto nw = mw / (p.size() + 1);
+        auto add_child = [&](auto ins) {
+            auto x  = 1 + (mw - this->weights.at(ins)) / (nw + 1);
+            auto w  = x * this->iweights.at(ins);
+            auto& v = visited[ins];
+            auto it = children.find(std::make_pair(v * w, ins));
+            if(it == children.end())
+            {
+                v++;
+                children.insert(std::make_pair(v * w, ins));
+            }
+        };
+        add_child(last);
+        while(not children.empty())
+        {
+            // Pop the first element
+            auto top = children.begin()->second;
+            children.erase(children.begin());
+            p.move_instruction(top, p.begin());
+            for(auto ins : top->inputs())
+            {
+                add_child(ins);
             }
         }
     }
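Note: two pieces of logic carry this hunk. With a single stream everything lands on stream 0 and `assign_streams` now reports that only one stream is in use; otherwise each non-critical partition is placed greedily on the least-loaded of streams 1..n-1, and the return value counts the streams that actually received work. In `sort()`, an instruction's priority is `iweight * (1 + (mw - weight) / (nw + 1))` with `nw = mw / (p.size() + 1)`; the `+ 1` in the divisor is what the "Avoid division by zero" commit refers to, since `nw` can be 0 for small programs. A hedged, standalone sketch of just the greedy placement (the `partition` type and `place_partitions` name here are illustrative, not the MIGraphX types):

```
#include <algorithm>
#include <cstddef>
#include <vector>

struct partition
{
    std::size_t weight = 0;
    std::size_t stream = 0;
};

// Greedy load balancing: each partition goes to whichever of streams 1..n-1
// currently has the smallest accumulated weight. Stream 0 is reserved for the
// critical path, matching assign_streams above.
void place_partitions(std::vector<partition>& parts, std::size_t n)
{
    if(n < 2)
    {
        for(auto& part : parts)
            part.stream = 0;
        return;
    }
    // Heaviest partitions first so the big ones get spread out early.
    std::sort(parts.begin(), parts.end(), [](const partition& a, const partition& b) {
        return a.weight > b.weight;
    });
    std::vector<std::size_t> load(n - 1, 0);
    for(auto& part : parts)
    {
        auto s      = std::min_element(load.begin(), load.end()) - load.begin();
        part.stream = s + 1;
        load[s] += part.weight;
    }
}
```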
@@ -398,9 +468,10 @@ void schedule::apply(program& p) const
     stream_info si;
     auto last = std::prev(p.end());
     si.accumulate_weights(last, model);
-    si.assign_streams(p, model.concurrency());
+    auto nstreams = si.assign_streams(p, model.concurrency());
+    si.sort(p, model.concurrency());
 
-    if(enabled(MIGRAPHX_TRACE_COMPILE{}))
+    if(enabled(MIGRAPHX_TRACE_COMPILE{}) or enabled(MIGRAPHX_TRACE_SCHEDULE{}))
     {
         p.annotate(std::cout, [&](auto ins) {
             std::cout << ":";
@@ -417,6 +488,10 @@ void schedule::apply(program& p) const
         std::cout << std::endl;
     }
 
+    // No concurrency
+    if(nstreams < 2)
+        return;
+
     // Schedule instructions
     std::size_t wait_id = 0;
     std::unordered_map<instruction_ref, std::size_t> ins2wait;
...
@@ -12,6 +12,7 @@ inline namespace MIGRAPHX_INLINE_NS {
 namespace gpu {
 
 MIGRAPHX_DECLARE_ENV_VAR(MIGRAPHX_ENABLE_NULL_STREAM)
+MIGRAPHX_DECLARE_ENV_VAR(MIGRAPHX_NSTREAMS)
 
 using hip_event_ptr = MIGRAPHX_MANAGE_PTR(hipEvent_t, hipEventDestroy);
@@ -126,7 +127,7 @@ struct hip_device
 
 struct context
 {
-    context(std::size_t device_id = 0, std::size_t n = 4)
+    context(std::size_t device_id = 0, std::size_t n = value_of(MIGRAPHX_NSTREAMS{}, 1))
         : current_device(std::make_shared<hip_device>(device_id, n))
     {
     }
...
@@ -97,11 +97,10 @@ static std::unordered_map<std::string, std::size_t> create_weight_map()
 {
     return {{"hip::load_literal", 0},
             {"hip::allocate", 0},
-            {"gpu::convolution", 4},
-            {"gpu::conv_bias_relu", 4},
-            {"gpu::pooling", 2},
-            {"gpu::gemm", 2},
-            {"gpu::concat", 1}};
+            {"gpu::convolution", 8},
+            {"gpu::conv_bias_relu", 8},
+            {"gpu::pooling", 4},
+            {"gpu::gemm", 4}};
 }
 
 static const std::unordered_map<std::string, std::size_t>& weight_map()
@@ -114,7 +113,7 @@ std::size_t schedule_model::weight(const operation& op) const
 {
     if(weight_map().count(op.name()) == 0)
     {
-        return 1;
+        return 2;
     }
     return weight_map().at(op.name());
 }
...
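Note: the cost table now spreads over a wider range, with convolutions at 8, pooling and gemm at 4, and any unlisted op (including `gpu::concat`, whose explicit entry was dropped) defaulting to 2, while literal loads and allocations stay free at 0. A minimal sketch of this table-with-fallback lookup, using an illustrative function name rather than the schedule_model API:

```
#include <cstddef>
#include <string>
#include <unordered_map>

// Relative operator costs used for stream partitioning; any op not listed
// falls back to a default cost of 2. Values mirror the updated weight map.
std::size_t op_weight(const std::string& name)
{
    static const std::unordered_map<std::string, std::size_t> weights = {
        {"hip::load_literal", 0},
        {"hip::allocate", 0},
        {"gpu::convolution", 8},
        {"gpu::conv_bias_relu", 8},
        {"gpu::pooling", 4},
        {"gpu::gemm", 4}};
    auto it = weights.find(name);
    return it == weights.end() ? 2 : it->second;
}
```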
@@ -30,7 +30,7 @@ namespace migraphx {
 inline namespace MIGRAPHX_INLINE_NS {
 namespace gpu {
 
-MIGRAPHX_DECLARE_ENV_VAR(MIGRAPHX_ENABLE_SCHEDULE_PASS)
+MIGRAPHX_DECLARE_ENV_VAR(MIGRAPHX_DISABLE_SCHEDULE_PASS)
 
 std::vector<pass> target::get_passes(migraphx::context& gctx) const
 {
@@ -70,7 +70,7 @@ std::vector<pass> target::get_passes(migraphx::context& gctx) const
         fuse_ops{&ctx},
         dead_code_elimination{},
         write_literals{&ctx},
-        schedule{gpu::schedule_model{ctx.get_current_device().nstreams()}, enabled(MIGRAPHX_ENABLE_SCHEDULE_PASS{})},
+        schedule{gpu::schedule_model{ctx.get_current_device().nstreams()}, not enabled(MIGRAPHX_DISABLE_SCHEDULE_PASS{})},
         memory_coloring{"hip::allocate"},
         dead_code_elimination{},
         eliminate_workspace{},
...
@@ -116,6 +116,7 @@ struct lhs_expression
     TEST_LHS_REOPERATOR(%)
     TEST_LHS_REOPERATOR(&)
     TEST_LHS_REOPERATOR(|)
+    TEST_LHS_REOPERATOR (^)
     TEST_LHS_REOPERATOR(&&)
     TEST_LHS_REOPERATOR(||)
 };
...
@@ -280,7 +280,7 @@ TEST_CASE(stream_free)
     EXPECT(not t.has_stream(one));
     EXPECT(not t.has_stream(onep1));
     EXPECT(not t.has_stream(onep2));
-    EXPECT(t.get_stream(binary) == 0);
+    EXPECT(not t.has_stream(binary));
 }
 
 TEST_CASE(zero_record)
@@ -616,8 +616,9 @@ TEST_CASE(inner_par_merge)
                              t.get_stream(outer1),
                              t.get_stream(outer2)}));
-    EXPECT(t.get_stream(outer1) == 1);
-    EXPECT(t.get_stream(outer2) == 2);
+    EXPECT(t.get_stream(outer1) != t.get_stream(outer2));
+    EXPECT(migraphx::contains({1, 2}, t.get_stream(outer1)));
+    EXPECT(migraphx::contains({1, 2}, t.get_stream(outer2)));
     EXPECT(t.get_stream(i1) != t.get_stream(i2));
 
     for(auto ins : c1)
@@ -704,9 +705,8 @@ TEST_CASE(inner_split1)
     EXPECT(
         get_wait_for(output) ==
         get_wait_for(t.get_stream(output), {t.get_stream(i1), t.get_stream(s1), t.get_stream(s2)}));
-    EXPECT(get_wait_for(s1).empty());
-    // TODO: Remove the extra wait here
-    // EXPECT(get_wait_for(s2).empty());
+    // Either s1 or s2 has a wait depending on the sort order but not both
+    EXPECT(get_wait_for(s1).empty() xor get_wait_for(s2).empty());
 
     t.check_conflicts(p, {c1, {i1}, {s1}, {s2}});
 }
...