Commit 2669e8c1 authored by Paul's avatar Paul
Browse files

Formatting

parent ffc51fe7
...@@ -29,7 +29,7 @@ struct stream_info ...@@ -29,7 +29,7 @@ struct stream_info
if(weights.count(ins) == 0) if(weights.count(ins) == 0)
{ {
std::size_t weight = 0; std::size_t weight = 0;
auto&& op = ins->get_operator(); auto&& op = ins->get_operator();
if(not is_context_free(op) and op.name()[0] != '@') if(not is_context_free(op) and op.name()[0] != '@')
weight = model.weight(op); weight = model.weight(op);
weights[ins] = weights[ins] =
...@@ -99,20 +99,20 @@ struct stream_info ...@@ -99,20 +99,20 @@ struct stream_info
return not std::all_of(v.begin(), v.end(), [&](std::size_t x) { return x == v.front(); }); return not std::all_of(v.begin(), v.end(), [&](std::size_t x) { return x == v.front(); });
} }
template<class F> template <class F>
bool different(F f) const bool different(F f) const
{ {
bool first = true; bool first = true;
std::size_t stream = 0; std::size_t stream = 0;
bool result = false; bool result = false;
f([&](auto s) { f([&](auto s) {
if (not first and s != stream) if(not first and s != stream)
{ {
result = true; result = true;
return false; return false;
} }
stream = s; stream = s;
first = false; first = false;
return true; return true;
}); });
return result; return result;
...@@ -127,12 +127,12 @@ struct stream_info ...@@ -127,12 +127,12 @@ struct stream_info
{ {
if(weights.at(i) == 0) if(weights.at(i) == 0)
{ {
if (not self(i)) if(not self(i))
return false; return false;
} }
else else
{ {
if (not f(get_stream(i))) if(not f(get_stream(i)))
return false; return false;
} }
} }
...@@ -170,7 +170,8 @@ struct stream_info ...@@ -170,7 +170,8 @@ struct stream_info
return result; return result;
} }
std::unordered_map<instruction_ref, std::vector<std::vector<instruction_ref>>> find_concurrent_instructions(program& p) std::unordered_map<instruction_ref, std::vector<std::vector<instruction_ref>>>
find_concurrent_instructions(program& p)
{ {
std::unordered_map<instruction_ref, std::vector<std::vector<instruction_ref>>> result; std::unordered_map<instruction_ref, std::vector<std::vector<instruction_ref>>> result;
std::unordered_map<instruction_ref, std::unordered_set<instruction_ref>> split_from; std::unordered_map<instruction_ref, std::unordered_set<instruction_ref>> split_from;
...@@ -199,8 +200,8 @@ struct stream_info ...@@ -199,8 +200,8 @@ struct stream_info
for(auto& split : split_from[ins]) for(auto& split : split_from[ins])
{ {
auto stream = get_stream(ins); auto stream = get_stream(ins);
if (result[split].size() <= stream) if(result[split].size() <= stream)
result[split].resize(stream+1); result[split].resize(stream + 1);
result[split][stream].push_back(ins); result[split][stream].push_back(ins);
} }
} }
...@@ -237,19 +238,19 @@ void schedule::apply(program& p) const ...@@ -237,19 +238,19 @@ void schedule::apply(program& p) const
// Add memory conflicts // Add memory conflicts
auto concur_ins = si.find_concurrent_instructions(p); auto concur_ins = si.find_concurrent_instructions(p);
for(auto&& split:concur_ins) for(auto&& split : concur_ins)
{ {
dfor(split.second.size(), split.second.size())([&](auto i, auto j) { dfor(split.second.size(), split.second.size())([&](auto i, auto j) {
if (i == j) if(i == j)
return; return;
for(auto ins1:split.second[i]) for(auto ins1 : split.second[i])
{ {
auto idx1 = std::distance(split.first, ins1); auto idx1 = std::distance(split.first, ins1);
for(auto ins2:split.second[j]) for(auto ins2 : split.second[j])
{ {
if (ins1 == ins2) if(ins1 == ins2)
continue; continue;
auto idx2 = std::distance(split.first, ins2); auto idx2 = std::distance(split.first, ins2);
auto point = idx1 > idx2 ? ins1 : ins2; auto point = idx1 > idx2 ? ins1 : ins2;
p.insert_instruction(std::next(point), op::identity{}, ins1, ins2); p.insert_instruction(std::next(point), op::identity{}, ins1, ins2);
} }
......
...@@ -33,17 +33,21 @@ struct wait_event ...@@ -33,17 +33,21 @@ struct wait_event
argument compute(context& ctx, const shape&, const std::vector<argument>&) const argument compute(context& ctx, const shape&, const std::vector<argument>&) const
{ {
assert(event != nullptr); assert(event != nullptr);
assert(std::none_of(wait_for.begin(), wait_for.end(), [&](auto i) { return i == ctx.get_current_device().stream_id(); })); assert(std::none_of(wait_for.begin(), wait_for.end(), [&](auto i) {
return i == ctx.get_current_device().stream_id();
}));
for(auto n : wait_for) for(auto n : wait_for)
ctx.get_stream(n).record(event.get()); ctx.get_stream(n).record(event.get());
ctx.get_stream().wait(event.get()); ctx.get_stream().wait(event.get());
return {}; return {};
} }
void finalize(context& ctx, const shape&, std::vector<shape>) void finalize(context& ctx, const shape&, std::vector<shape>)
{ {
assert(std::none_of(wait_for.begin(), wait_for.end(), [&](auto i) { return i == ctx.get_current_device().stream_id(); })); assert(std::none_of(wait_for.begin(), wait_for.end(), [&](auto i) {
event = create_event(); return i == ctx.get_current_device().stream_id();
}));
event = create_event();
} }
}; };
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment