Commit ca17bcd6 authored by Paul Fultz II, committed by mvermeulen
Browse files

Enable scheduler for 1 stream (#399)

* Enable scheduler for 1 stream

* Formatting

* Improve performance of sorting

* Formatting

* Adjust the weight calculation

* Formatting

* Simplify formula

* Formatting

* Avoid division by zero

* Fix scheduler test

* Check for either 1 or 2

* Check for waits when order may change

* Formatting
parent 9799d373
......@@ -21,11 +21,11 @@ bool disabled(const char* name)
return contains({"0", "disable", "disabled", "no", "false"}, e.front());
}
std::size_t value_of(const char* name)
std::size_t value_of(const char* name, std::size_t fallback)
{
auto e = env(name);
if(e.empty())
return 0;
return fallback;
return std::stoul(e.front());
}
......
......@@ -19,7 +19,7 @@ bool enabled(const char* name);
bool disabled(const char* name);
std::vector<std::string> env(const char* name);
std::size_t value_of(const char* name);
std::size_t value_of(const char* name, std::size_t fallback = 0);
template <class T>
bool enabled(T)
......@@ -36,9 +36,9 @@ bool disabled(T)
}
template <class T>
std::size_t value_of(T)
std::size_t value_of(T, std::size_t fallback = 0)
{
static const std::size_t result = value_of(T::value());
static const std::size_t result = value_of(T::value(), fallback);
return result;
}
......
......@@ -384,25 +384,26 @@ argument generic_eval(const program& p,
values.reserve(16);
for(auto ins : iterator_for(p))
{
if(ins->name() == "@literal")
const auto& name = ins->name();
if(name == "@literal")
{
results.emplace(ins, trace(ins, [&] { return ins->get_literal().get_argument(); }));
}
else if(ins->name() == "@param")
else if(name == "@param")
{
results.emplace(
ins, trace(ins, [&] {
auto param_name = any_cast<builtin::param>(ins->get_operator()).parameter;
if(not contains(params, param_name))
MIGRAPHX_THROW("Parameter not found: " + param_name);
auto param = params.at(param_name);
auto param = params[param_name];
if(param.get_shape() != ins->get_shape())
MIGRAPHX_THROW("Incorrect shape {" + to_string(param.get_shape()) +
"} for parameter: " + param_name);
return param;
}));
}
else if(ins->name() == "@outline")
else if(name == "@outline")
{
results.emplace(ins, trace(ins, [&] { return argument{ins->get_shape(), nullptr}; }));
}
......
......@@ -9,6 +9,7 @@
#include <migraphx/ranges.hpp>
#include <unordered_map>
#include <unordered_set>
#include <queue>
#include <thread>
#include <mutex>
#include <set>
......@@ -18,6 +19,8 @@
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
MIGRAPHX_DECLARE_ENV_VAR(MIGRAPHX_TRACE_SCHEDULE)
auto get_inputs()
{
return [](auto i) { return i->inputs(); };
......@@ -54,6 +57,17 @@ struct stream_info
})(last);
}
template <class Compare>
// Order the argument instructions by their accumulated weight (then by
// input count, then by instruction identity so the order is total and
// deterministic), using `compare` to pick ascending or descending order.
void sort_args_by_weight(std::vector<instruction_ref>& args, Compare compare) const
{
// Fewer than two arguments: nothing to order
if(args.size() < 2)
return;
auto key = [this](auto ins) {
return std::make_tuple(this->weights.at(ins), ins->inputs().size(), std::addressof(*ins));
};
std::sort(args.begin(), args.end(), by(compare, key));
}
std::vector<instruction_ref>::iterator sort_args(std::vector<instruction_ref>& args)
{
if(args.size() < 2)
......@@ -61,11 +75,8 @@ struct stream_info
return args.end();
}
const std::size_t min_partition_threshold = 2;
auto compare = by(std::greater<>{}, [&](auto x) {
return std::make_tuple(this->weights[x], x->inputs().size());
});
std::sort(args.begin(), args.end(), compare);
const std::size_t min_partition_threshold = 1;
sort_args_by_weight(args, std::greater<>{});
auto it = std::lower_bound(std::next(args.begin()),
args.end(),
......@@ -89,8 +100,9 @@ struct stream_info
}
};
void assign_streams(program& p, std::size_t n)
std::size_t assign_streams(program& p, std::size_t n)
{
assert(n > 0);
partition critical;
std::unordered_map<instruction_ref, std::deque<partition>> partitions;
partitions.reserve(weights.size());
......@@ -126,19 +138,77 @@ struct stream_info
// Set the critical partition to stream 0
set_stream(critical, 0);
std::vector<std::size_t> streams(n - 1);
// Assign streams for the other partitions
for(auto&& ins_part : partitions)
if(n == 1)
{
// Assign streams for the other partitions
for(auto&& ins_part : partitions)
for(auto&& part : ins_part.second)
set_stream(part, 0);
return 1;
}
else
{
std::sort(
ins_part.second.begin(), ins_part.second.end(), by(std::greater<>{}, [](auto&& x) {
return std::make_tuple(x.weight, x.instructions.size());
}));
for(auto&& part : ins_part.second)
std::vector<std::size_t> streams(n - 1);
// Assign streams for the other partitions
for(auto&& ins_part : partitions)
{
auto stream = std::min_element(streams.begin(), streams.end()) - streams.begin();
set_stream(part, stream + 1);
streams[stream] += part.weight;
std::sort(ins_part.second.begin(),
ins_part.second.end(),
by(std::greater<>{}, [](auto&& x) {
return std::make_tuple(x.weight, x.instructions.size());
}));
for(auto&& part : ins_part.second)
{
auto stream =
std::min_element(streams.begin(), streams.end()) - streams.begin();
set_stream(part, stream + 1);
streams[stream] += part.weight;
}
}
return 1 + std::count_if(streams.begin(), streams.end(), [](auto x) { return x > 0; });
}
}
// A pending instruction paired with its scheduling priority key.
using weight_ins = std::pair<std::size_t, instruction_ref>;
// Strict weak ordering for weight_ins: compare by the weight key first,
// then break ties by the address of the underlying instruction so that
// equal-weight entries still have a stable, total order.
struct compare_weight_ins
{
bool operator()(const weight_ins& x, const weight_ins& y) const
{
if(x.first != y.first)
return x.first < y.first;
return std::addressof(*x.second) < std::addressof(*y.second);
}
};
// Reorder the program's instructions by walking backwards from the last
// instruction, repeatedly moving the lowest-priority pending instruction
// to the front of the program. A std::set ordered by compare_weight_ins
// serves as the priority queue. The second (concurrency) parameter is
// unused here.
void sort(program& p, std::size_t) const
{
// Pending instructions, ordered by (priority key, instruction address)
std::set<weight_ins, compare_weight_ins> children;
// How many times each instruction has been enqueued (its visit count)
std::unordered_map<instruction_ref, std::size_t> visited;
auto last = std::prev(p.end());
// mw: accumulated weight at the final instruction (the program total)
auto mw = this->weights.at(last);
// nw: roughly the average weight per instruction; the +1 keeps the
// divisor nonzero for an empty program
auto nw = mw / (p.size() + 1);
auto add_child = [&](auto ins) {
// x grows as the instruction's accumulated weight shrinks relative
// to the total; the +1 on nw avoids division by zero
auto x = 1 + (mw - this->weights.at(ins)) / (nw + 1);
// Scale by the instruction's individual weight to form the key
auto w = x * this->iweights.at(ins);
auto& v = visited[ins];
// Only enqueue if this (count * key, ins) entry is not already
// pending; on insert, bump the visit count first so repeated
// visits enqueue under a larger key
auto it = children.find(std::make_pair(v * w, ins));
if(it == children.end())
{
v++;
children.insert(std::make_pair(v * w, ins));
}
};
// Seed the queue with the program's final instruction
add_child(last);
while(not children.empty())
{
// Pop the first element (smallest key in the set)
auto top = children.begin()->second;
children.erase(children.begin());
// Each popped instruction is pushed to the front, so earlier pops
// end up later in the final program order
p.move_instruction(top, p.begin());
// Enqueue the popped instruction's inputs for placement before it
for(auto ins : top->inputs())
{
add_child(ins);
}
}
}
......@@ -398,9 +468,10 @@ void schedule::apply(program& p) const
stream_info si;
auto last = std::prev(p.end());
si.accumulate_weights(last, model);
si.assign_streams(p, model.concurrency());
auto nstreams = si.assign_streams(p, model.concurrency());
si.sort(p, model.concurrency());
if(enabled(MIGRAPHX_TRACE_COMPILE{}))
if(enabled(MIGRAPHX_TRACE_COMPILE{}) or enabled(MIGRAPHX_TRACE_SCHEDULE{}))
{
p.annotate(std::cout, [&](auto ins) {
std::cout << ":";
......@@ -417,6 +488,10 @@ void schedule::apply(program& p) const
std::cout << std::endl;
}
// No concurrency
if(nstreams < 2)
return;
// Schedule instructions
std::size_t wait_id = 0;
std::unordered_map<instruction_ref, std::size_t> ins2wait;
......
......@@ -12,6 +12,7 @@ inline namespace MIGRAPHX_INLINE_NS {
namespace gpu {
MIGRAPHX_DECLARE_ENV_VAR(MIGRAPHX_ENABLE_NULL_STREAM)
MIGRAPHX_DECLARE_ENV_VAR(MIGRAPHX_NSTREAMS)
using hip_event_ptr = MIGRAPHX_MANAGE_PTR(hipEvent_t, hipEventDestroy);
......@@ -126,7 +127,7 @@ struct hip_device
struct context
{
context(std::size_t device_id = 0, std::size_t n = 4)
context(std::size_t device_id = 0, std::size_t n = value_of(MIGRAPHX_NSTREAMS{}, 1))
: current_device(std::make_shared<hip_device>(device_id, n))
{
}
......
......@@ -97,11 +97,10 @@ static std::unordered_map<std::string, std::size_t> create_weight_map()
{
return {{"hip::load_literal", 0},
{"hip::allocate", 0},
{"gpu::convolution", 4},
{"gpu::conv_bias_relu", 4},
{"gpu::pooling", 2},
{"gpu::gemm", 2},
{"gpu::concat", 1}};
{"gpu::convolution", 8},
{"gpu::conv_bias_relu", 8},
{"gpu::pooling", 4},
{"gpu::gemm", 4}};
}
static const std::unordered_map<std::string, std::size_t>& weight_map()
......@@ -114,7 +113,7 @@ std::size_t schedule_model::weight(const operation& op) const
{
if(weight_map().count(op.name()) == 0)
{
return 1;
return 2;
}
return weight_map().at(op.name());
}
......
......@@ -30,7 +30,7 @@ namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
namespace gpu {
MIGRAPHX_DECLARE_ENV_VAR(MIGRAPHX_ENABLE_SCHEDULE_PASS)
MIGRAPHX_DECLARE_ENV_VAR(MIGRAPHX_DISABLE_SCHEDULE_PASS)
std::vector<pass> target::get_passes(migraphx::context& gctx) const
{
......@@ -70,7 +70,7 @@ std::vector<pass> target::get_passes(migraphx::context& gctx) const
fuse_ops{&ctx},
dead_code_elimination{},
write_literals{&ctx},
schedule{gpu::schedule_model{ctx.get_current_device().nstreams()}, enabled(MIGRAPHX_ENABLE_SCHEDULE_PASS{})},
schedule{gpu::schedule_model{ctx.get_current_device().nstreams()}, not enabled(MIGRAPHX_DISABLE_SCHEDULE_PASS{})},
memory_coloring{"hip::allocate"},
dead_code_elimination{},
eliminate_workspace{},
......
......@@ -116,6 +116,7 @@ struct lhs_expression
TEST_LHS_REOPERATOR(%)
TEST_LHS_REOPERATOR(&)
TEST_LHS_REOPERATOR(|)
TEST_LHS_REOPERATOR (^)
TEST_LHS_REOPERATOR(&&)
TEST_LHS_REOPERATOR(||)
};
......
......@@ -280,7 +280,7 @@ TEST_CASE(stream_free)
EXPECT(not t.has_stream(one));
EXPECT(not t.has_stream(onep1));
EXPECT(not t.has_stream(onep2));
EXPECT(t.get_stream(binary) == 0);
EXPECT(not t.has_stream(binary));
}
TEST_CASE(zero_record)
......@@ -616,8 +616,9 @@ TEST_CASE(inner_par_merge)
t.get_stream(outer1),
t.get_stream(outer2)}));
EXPECT(t.get_stream(outer1) == 1);
EXPECT(t.get_stream(outer2) == 2);
EXPECT(t.get_stream(outer1) != t.get_stream(outer2));
EXPECT(migraphx::contains({1, 2}, t.get_stream(outer1)));
EXPECT(migraphx::contains({1, 2}, t.get_stream(outer2)));
EXPECT(t.get_stream(i1) != t.get_stream(i2));
for(auto ins : c1)
......@@ -704,9 +705,8 @@ TEST_CASE(inner_split1)
EXPECT(
get_wait_for(output) ==
get_wait_for(t.get_stream(output), {t.get_stream(i1), t.get_stream(s1), t.get_stream(s2)}));
EXPECT(get_wait_for(s1).empty());
// TODO: Remove the extra wait here
// EXPECT(get_wait_for(s2).empty());
// Either s1 or s2 has a wait depending on the sort order but not both
EXPECT(get_wait_for(s1).empty() xor get_wait_for(s2).empty());
t.check_conflicts(p, {c1, {i1}, {s1}, {s2}});
}
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment