"...resnet50_tensorflow.git" did not exist on "1edd6e8613af13f816120e56be610d1eb9a08cd1"
Commit 3b800ff3 authored by Paul's avatar Paul
Browse files

Avoid waits on zero-weighted instructions

parent 9b14e317
...@@ -32,6 +32,7 @@ struct stream_info ...@@ -32,6 +32,7 @@ struct stream_info
{ {
std::unordered_map<instruction_ref, std::size_t> ins2stream; std::unordered_map<instruction_ref, std::size_t> ins2stream;
std::unordered_map<instruction_ref, std::size_t> weights; std::unordered_map<instruction_ref, std::size_t> weights;
std::unordered_map<instruction_ref, std::size_t> iweights;
void accumulate_weights(instruction_ref last, const schedule_model& model) void accumulate_weights(instruction_ref last, const schedule_model& model)
{ {
...@@ -42,6 +43,7 @@ struct stream_info ...@@ -42,6 +43,7 @@ struct stream_info
auto&& op = ins->get_operator(); auto&& op = ins->get_operator();
if(not is_context_free(op) and op.name()[0] != '@') if(not is_context_free(op) and op.name()[0] != '@')
weight = model.weight(op); weight = model.weight(op);
iweights[ins] = weight;
weights[ins] = weights[ins] =
std::accumulate(ins->inputs().begin(), std::accumulate(ins->inputs().begin(),
ins->inputs().end(), ins->inputs().end(),
...@@ -55,14 +57,14 @@ struct stream_info ...@@ -55,14 +57,14 @@ struct stream_info
void assign_streams(program& p, std::size_t streams) void assign_streams(program& p, std::size_t streams)
{ {
const std::size_t min_partition_threshold = 2; const std::size_t min_partition_threshold = 2;
for(std::size_t stream = 0; stream < streams; stream++) for(std::size_t stream = 0; stream < streams-1; stream++)
{ {
fix([&](auto self, auto ins) { fix([&](auto self, auto ins) {
// If weight is zero then stop // If weight is zero then stop
if(this->weights[ins] == 0) if(this->weights[ins] == 0)
return; return;
// Only assign streams if not already assigned // Only assign streams if not already assigned
if(not this->has_stream(ins)) if(not this->has_stream(ins) and this->iweights[ins] > 0)
this->set_stream(ins, stream); this->set_stream(ins, stream);
instruction_ref child = p.end(); instruction_ref child = p.end();
std::size_t w = 0; std::size_t w = 0;
...@@ -90,13 +92,17 @@ struct stream_info ...@@ -90,13 +92,17 @@ struct stream_info
{ {
if(has_stream(ins)) if(has_stream(ins))
continue; continue;
if(weights[ins] == 0) if(iweights[ins] == 0)
continue; continue;
set_stream(ins, streams - 1); set_stream(ins, streams - 1);
} }
} }
void set_stream(instruction_ref ins, std::size_t n) { ins2stream[ins] = n; } void set_stream(instruction_ref ins, std::size_t n)
{
assert(iweights[ins] > 0);
ins2stream[ins] = n;
}
std::size_t get_stream(instruction_ref ins) const { return ins2stream.at(ins); } std::size_t get_stream(instruction_ref ins) const { return ins2stream.at(ins); }
...@@ -143,7 +149,7 @@ struct stream_info ...@@ -143,7 +149,7 @@ struct stream_info
return fix<bool>([&](auto self, auto ins) { return fix<bool>([&](auto self, auto ins) {
for(auto i : select(ins)) for(auto i : select(ins))
{ {
if(weights.at(i) == 0) if(iweights.at(i) == 0)
{ {
if(not self(i)) if(not self(i))
return false; return false;
...@@ -195,7 +201,7 @@ struct stream_info ...@@ -195,7 +201,7 @@ struct stream_info
std::unordered_map<instruction_ref, std::unordered_set<instruction_ref>> split_from; std::unordered_map<instruction_ref, std::unordered_set<instruction_ref>> split_from;
for(auto ins : iterator_for(p)) for(auto ins : iterator_for(p))
{ {
if(weights[ins] == 0) if(iweights[ins] == 0)
continue; continue;
for(auto&& arg : ins->inputs()) for(auto&& arg : ins->inputs())
{ {
...@@ -238,12 +244,13 @@ void schedule::apply(program& p) const ...@@ -238,12 +244,13 @@ void schedule::apply(program& p) const
fix([&](auto self, auto ins) { fix([&](auto self, auto ins) {
auto args = ins->inputs(); auto args = ins->inputs();
std::sort(args.begin(), args.end(), [&](auto x, auto y) { std::sort(args.begin(), args.end(), [&](auto x, auto y) {
return si.weights[x] < si.weights[y]; return std::make_tuple(si.weights[x], x->inputs().size()) < std::make_tuple(si.weights[y], y->inputs().size());
}); });
for(auto i : args) for(auto i : args)
{
p.move_instruction(i, p.begin()); p.move_instruction(i, p.begin());
for(auto i : args)
self(i); self(i);
}
})(last); })(last);
if(enabled(MIGRAPHX_TRACE_COMPILE{})) if(enabled(MIGRAPHX_TRACE_COMPILE{}))
...@@ -265,6 +272,7 @@ void schedule::apply(program& p) const ...@@ -265,6 +272,7 @@ void schedule::apply(program& p) const
// Only schedule instructions that have a stream // Only schedule instructions that have a stream
if(not si.has_stream(ins)) if(not si.has_stream(ins))
continue; continue;
assert(si.weights[ins] > 0);
// Schedule instruction on the stream // Schedule instruction on the stream
auto stream = si.get_stream(ins); auto stream = si.get_stream(ins);
assert(stream < model.concurrency()); assert(stream < model.concurrency());
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment