Commit 417d6644 authored by charlie's avatar charlie
Browse files

Merge branch 'develop' of github.com:ROCmSoftwarePlatform/AMDMIGraphX into dyn_conv

parents 79e27dac 4a312201
...@@ -12,9 +12,9 @@ ...@@ -12,9 +12,9 @@
namespace migraphx { namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS { inline namespace MIGRAPHX_INLINE_NS {
void rewrite_pooling::apply(module& prog) const void rewrite_pooling::apply(module& m) const
{ {
for(auto ins : iterator_for(prog)) for(auto ins : iterator_for(m))
{ {
if(ins->name() != "pooling") if(ins->name() != "pooling")
continue; continue;
...@@ -33,26 +33,25 @@ void rewrite_pooling::apply(module& prog) const ...@@ -33,26 +33,25 @@ void rewrite_pooling::apply(module& prog) const
continue; continue;
std::int64_t n = s.lens()[0]; std::int64_t n = s.lens()[0];
std::int64_t c = s.lens()[1]; std::int64_t c = s.lens()[1];
auto reshape = prog.insert_instruction( auto reshape = m.insert_instruction(
ins, make_op("reshape", {{"dims", {n * c, -1}}}), ins->inputs().front()); ins, make_op("reshape", {{"dims", {n * c, -1}}}), ins->inputs().front());
instruction_ref pooling{}; instruction_ref pooling{};
// average pooling // average pooling
if(op.mode == op::pooling_mode::average) if(op.mode == op::pooling_mode::average)
{ {
pooling = pooling = m.insert_instruction(ins, make_op("reduce_mean", {{"axes", {1}}}), reshape);
prog.insert_instruction(ins, make_op("reduce_mean", {{"axes", {1}}}), reshape);
} }
// max pooling // max pooling
else else
{ {
pooling = prog.insert_instruction(ins, make_op("reduce_max", {{"axes", {1}}}), reshape); pooling = m.insert_instruction(ins, make_op("reduce_max", {{"axes", {1}}}), reshape);
} }
std::vector<int64_t> rsp_lens(lens.size(), 1); std::vector<int64_t> rsp_lens(lens.size(), 1);
rsp_lens[0] = n; rsp_lens[0] = n;
rsp_lens[1] = c; rsp_lens[1] = c;
prog.replace_instruction(ins, make_op("reshape", {{"dims", rsp_lens}}), pooling); m.replace_instruction(ins, make_op("reshape", {{"dims", rsp_lens}}), pooling);
} }
} }
......
This diff is collapsed.
...@@ -42,7 +42,7 @@ struct stream_info ...@@ -42,7 +42,7 @@ struct stream_info
std::unordered_map<instruction_ref, std::size_t> iweights; std::unordered_map<instruction_ref, std::size_t> iweights;
ins_dep_map mod_implicit_deps; ins_dep_map mod_implicit_deps;
void calc_implicit_deps(const module& p) { mod_implicit_deps = p.calc_implicit_deps(); } void calc_implicit_deps(const module& m) { mod_implicit_deps = m.calc_implicit_deps(); }
void accumulate_weights(instruction_ref last, const schedule_model& model) void accumulate_weights(instruction_ref last, const schedule_model& model)
{ {
...@@ -116,15 +116,15 @@ struct stream_info ...@@ -116,15 +116,15 @@ struct stream_info
} }
}; };
std::size_t assign_streams(module& p, std::size_t n) std::size_t assign_streams(module& m, std::size_t n)
{ {
assert(n > 0); assert(n > 0);
partition critical; partition critical;
std::unordered_map<instruction_ref, std::deque<partition>> partitions; std::unordered_map<instruction_ref, std::deque<partition>> partitions;
partitions.reserve(weights.size()); partitions.reserve(weights.size());
fix([&](auto self, auto ins, auto& part) { fix([&](auto self, auto ins, auto& part) {
assert(not is_end(ins, p.end())); assert(not is_end(ins, m.end()));
if(not p.has_instruction(ins)) if(not m.has_instruction(ins))
return; return;
if(contains(partitions, ins)) if(contains(partitions, ins))
return; return;
...@@ -151,8 +151,8 @@ struct stream_info ...@@ -151,8 +151,8 @@ struct stream_info
} }
} }
// Sort instructions // Sort instructions
p.move_instruction(ins, p.end()); m.move_instruction(ins, m.end());
})(std::prev(p.end()), critical); })(std::prev(m.end()), critical);
// Set the critical partition to stream 0 // Set the critical partition to stream 0
set_stream(critical, 0); set_stream(critical, 0);
...@@ -197,13 +197,13 @@ struct stream_info ...@@ -197,13 +197,13 @@ struct stream_info
} }
}; };
void sort(module& p, std::size_t) void sort(module& m, std::size_t)
{ {
std::set<weight_ins, compare_weight_ins> children; std::set<weight_ins, compare_weight_ins> children;
std::unordered_map<instruction_ref, std::size_t> visited; std::unordered_map<instruction_ref, std::size_t> visited;
auto last = std::prev(p.end()); auto last = std::prev(m.end());
auto mw = this->weights.at(last); auto mw = this->weights.at(last);
auto nw = mw / (p.size() + 1); auto nw = mw / (m.size() + 1);
auto add_child = [&](auto ins) { auto add_child = [&](auto ins) {
auto x = 1 + (mw - this->weights.at(ins)) / (nw + 1); auto x = 1 + (mw - this->weights.at(ins)) / (nw + 1);
auto w = x * this->iweights.at(ins); auto w = x * this->iweights.at(ins);
...@@ -222,10 +222,10 @@ struct stream_info ...@@ -222,10 +222,10 @@ struct stream_info
// Pop the first element // Pop the first element
auto top = children.begin()->second; auto top = children.begin()->second;
children.erase(children.begin()); children.erase(children.begin());
p.move_instruction(top, p.begin()); m.move_instruction(top, m.begin());
for(auto ins : top->inputs()) for(auto ins : top->inputs())
{ {
if(not p.has_instruction(ins)) if(not m.has_instruction(ins))
continue; continue;
add_child(ins); add_child(ins);
} }
...@@ -234,7 +234,7 @@ struct stream_info ...@@ -234,7 +234,7 @@ struct stream_info
{ {
for(auto ins : mod_implicit_deps.at(top)) for(auto ins : mod_implicit_deps.at(top))
{ {
assert(p.has_instruction(ins)); assert(m.has_instruction(ins));
add_child(ins); add_child(ins);
} }
} }
...@@ -242,12 +242,12 @@ struct stream_info ...@@ -242,12 +242,12 @@ struct stream_info
// move dangling parameter to the front so as not be removed // move dangling parameter to the front so as not be removed
auto ins = std::next(last); auto ins = std::next(last);
while(ins != p.end()) while(ins != m.end())
{ {
auto next = std::next(ins); auto next = std::next(ins);
if(ins->name() == "@param") if(ins->name() == "@param")
{ {
p.move_instruction(ins, p.begin()); m.move_instruction(ins, m.begin());
} }
ins = next; ins = next;
} }
...@@ -364,18 +364,18 @@ struct stream_info ...@@ -364,18 +364,18 @@ struct stream_info
} }
std::unordered_map<instruction_ref, std::vector<std::vector<instruction_ref>>> std::unordered_map<instruction_ref, std::vector<std::vector<instruction_ref>>>
find_concurrent_instructions(module& p) const find_concurrent_instructions(module& m) const
{ {
std::unordered_map<instruction_ref, std::vector<std::vector<instruction_ref>>> result; std::unordered_map<instruction_ref, std::vector<std::vector<instruction_ref>>> result;
std::unordered_map<instruction_ref, std::unordered_set<instruction_ref>> merge_from; std::unordered_map<instruction_ref, std::unordered_set<instruction_ref>> merge_from;
dominator_info di = compute_dominator(p); dominator_info di = compute_dominator(m);
result.reserve(p.size()); result.reserve(m.size());
merge_from.reserve(p.size()); merge_from.reserve(m.size());
for(auto ins : reverse_iterator_for(p)) for(auto ins : reverse_iterator_for(m))
{ {
for(auto&& arg : ins->outputs()) for(auto&& arg : ins->outputs())
{ {
if(not p.has_instruction(arg)) if(not m.has_instruction(arg))
continue; continue;
if(is_merge_point(arg)) if(is_merge_point(arg))
merge_from[ins].insert(arg); merge_from[ins].insert(arg);
...@@ -415,18 +415,18 @@ struct stream_info ...@@ -415,18 +415,18 @@ struct stream_info
} }
std::unordered_map<instruction_ref, std::unordered_set<instruction_ref>> std::unordered_map<instruction_ref, std::unordered_set<instruction_ref>>
get_conflicts(module& p) get_conflicts(module& m)
{ {
using conflict_table_type = using conflict_table_type =
std::unordered_map<instruction_ref, std::unordered_set<instruction_ref>>; std::unordered_map<instruction_ref, std::unordered_set<instruction_ref>>;
conflict_table_type conflict_table; conflict_table_type conflict_table;
auto concur_ins = this->find_concurrent_instructions(p); auto concur_ins = this->find_concurrent_instructions(m);
// Compute an index for each instruction // Compute an index for each instruction
std::unordered_map<instruction_ref, std::size_t> ins2index; std::unordered_map<instruction_ref, std::size_t> ins2index;
std::size_t index_total = 0; std::size_t index_total = 0;
for(auto ins : iterator_for(p)) for(auto ins : iterator_for(m))
ins2index[ins] = index_total++; ins2index[ins] = index_total++;
std::vector<conflict_table_type> thread_conflict_tables( std::vector<conflict_table_type> thread_conflict_tables(
...@@ -507,21 +507,21 @@ struct stream_info ...@@ -507,21 +507,21 @@ struct stream_info
} }
}; };
void schedule::apply(module& p) const void schedule::apply(module& m) const
{ {
if(not enable) if(not enable)
return; return;
stream_info si; stream_info si;
si.calc_implicit_deps(p); si.calc_implicit_deps(m);
auto last = std::prev(p.end()); auto last = std::prev(m.end());
si.accumulate_weights(last, model); si.accumulate_weights(last, model);
auto nstreams = si.assign_streams(p, model.concurrency()); auto nstreams = si.assign_streams(m, model.concurrency());
si.sort(p, model.concurrency()); si.sort(m, model.concurrency());
if(enabled(MIGRAPHX_TRACE_COMPILE{}) or enabled(MIGRAPHX_TRACE_SCHEDULE{})) if(enabled(MIGRAPHX_TRACE_COMPILE{}) or enabled(MIGRAPHX_TRACE_SCHEDULE{}))
{ {
p.annotate(std::cout, [&](auto ins) { m.annotate(std::cout, [&](auto ins) {
if(ins->name() == "@param" and not contains(si.weights, ins)) if(ins->name() == "@param" and not contains(si.weights, ins))
return; return;
...@@ -548,9 +548,9 @@ void schedule::apply(module& p) const ...@@ -548,9 +548,9 @@ void schedule::apply(module& p) const
std::unordered_map<instruction_ref, std::size_t> ins2wait; std::unordered_map<instruction_ref, std::size_t> ins2wait;
std::unordered_map<std::size_t, std::unordered_set<std::size_t>> waited_for; std::unordered_map<std::size_t, std::unordered_set<std::size_t>> waited_for;
std::unordered_map<instruction_ref, std::unordered_set<std::size_t>> ins2waited; std::unordered_map<instruction_ref, std::unordered_set<std::size_t>> ins2waited;
ins2wait.reserve(p.size()); ins2wait.reserve(m.size());
ins2waited.reserve(p.size()); ins2waited.reserve(m.size());
for(auto ins : iterator_for(p)) for(auto ins : iterator_for(m))
{ {
// Only schedule instructions that have a stream // Only schedule instructions that have a stream
if(not si.has_stream(ins)) if(not si.has_stream(ins))
...@@ -559,7 +559,7 @@ void schedule::apply(module& p) const ...@@ -559,7 +559,7 @@ void schedule::apply(module& p) const
// Schedule instruction on the stream // Schedule instruction on the stream
auto stream = si.get_stream(ins); auto stream = si.get_stream(ins);
assert(stream < model.concurrency()); assert(stream < model.concurrency());
model.sched(p, ins, stream); model.sched(m, ins, stream);
// Insert wait instructions // Insert wait instructions
if(si.is_merge_point(ins, stream)) if(si.is_merge_point(ins, stream))
{ {
...@@ -572,14 +572,14 @@ void schedule::apply(module& p) const ...@@ -572,14 +572,14 @@ void schedule::apply(module& p) const
if(not contains(ins2wait, i)) if(not contains(ins2wait, i))
{ {
ins2wait[i] = wait_id; ins2wait[i] = wait_id;
model.record(p, i, wait_id); model.record(m, i, wait_id);
wait_id++; wait_id++;
} }
auto w = ins2wait.at(i); auto w = ins2wait.at(i);
// If we already waited for the event on this stream then dont // If we already waited for the event on this stream then dont
// insert another wait event // insert another wait event
if(not contains(waited_for[stream], w)) if(not contains(waited_for[stream], w))
model.wait(p, ins, w); model.wait(m, ins, w);
// Store the event as waited // Store the event as waited
waited_for[stream].insert(w); waited_for[stream].insert(w);
// Store all wait events that have been waited on prior to the recorded instruction // Store all wait events that have been waited on prior to the recorded instruction
...@@ -594,7 +594,7 @@ void schedule::apply(module& p) const ...@@ -594,7 +594,7 @@ void schedule::apply(module& p) const
} }
// Add memory conflicts // Add memory conflicts
auto conflict_table = si.get_conflicts(p); auto conflict_table = si.get_conflicts(m);
for(auto&& ip : conflict_table) for(auto&& ip : conflict_table)
{ {
if(ip.second.empty()) if(ip.second.empty())
...@@ -602,7 +602,7 @@ void schedule::apply(module& p) const ...@@ -602,7 +602,7 @@ void schedule::apply(module& p) const
std::vector<instruction_ref> args; std::vector<instruction_ref> args;
args.push_back(ip.first); args.push_back(ip.first);
args.insert(args.end(), ip.second.begin(), ip.second.end()); args.insert(args.end(), ip.second.begin(), ip.second.end());
p.insert_instruction(std::next(ip.first), make_op("identity"), args); m.insert_instruction(std::next(ip.first), make_op("identity"), args);
} }
} }
......
This diff is collapsed.
...@@ -53,7 +53,7 @@ struct match_find_quantizable_ops ...@@ -53,7 +53,7 @@ struct match_find_quantizable_ops
match::arg(1)(dequantizelinear_op("x2", "scale2"))); match::arg(1)(dequantizelinear_op("x2", "scale2")));
} }
void apply(module& m, match::matcher_result r) const void apply(module& m, const match::matcher_result& r) const
{ {
auto qop = r.result; auto qop = r.result;
auto q1 = r.instructions["x1"]; auto q1 = r.instructions["x1"];
......
...@@ -70,19 +70,19 @@ struct find_reshaper ...@@ -70,19 +70,19 @@ struct find_reshaper
match::any_of[match::outputs()](match::name(reshaper_names()))); match::any_of[match::outputs()](match::name(reshaper_names())));
} }
void apply(module& p, const match::matcher_result& mr) const void apply(module& m, const match::matcher_result& mr) const
{ {
auto ins = mr.result; auto ins = mr.result;
std::vector<instruction_ref> reshapes{ins}; std::vector<instruction_ref> reshapes{ins};
while(is_reshaper(reshapes.back())) while(is_reshaper(reshapes.back()))
{ {
assert(!reshapes.back()->inputs().empty()); assert(!reshapes.back()->inputs().empty());
assert(p.has_instruction(reshapes.back()->inputs().front())); assert(m.has_instruction(reshapes.back()->inputs().front()));
auto input = reshapes.back()->inputs().front(); auto input = reshapes.back()->inputs().front();
reshapes.push_back(input); reshapes.push_back(input);
} }
std::pair<instruction_ref, instruction_ref> r{p.end(), p.end()}; std::pair<instruction_ref, instruction_ref> r{m.end(), m.end()};
for(auto start : iterator_for(reshapes)) for(auto start : iterator_for(reshapes))
{ {
auto last = std::find_if(reshapes.rbegin(), reshapes.rend(), [&](auto&& i) { auto last = std::find_if(reshapes.rbegin(), reshapes.rend(), [&](auto&& i) {
...@@ -96,7 +96,7 @@ struct find_reshaper ...@@ -96,7 +96,7 @@ struct find_reshaper
} }
if(r.first != r.second) if(r.first != r.second)
{ {
p.replace_instruction(r.first, r.second); m.replace_instruction(r.first, r.second);
} }
} }
}; };
...@@ -117,10 +117,10 @@ struct find_nop_reshapes ...@@ -117,10 +117,10 @@ struct find_nop_reshapes
return match::name(reshapes)(match::same_shape(match::arg(0))); return match::name(reshapes)(match::same_shape(match::arg(0)));
} }
void apply(module& p, const match::matcher_result& mr) const void apply(module& m, const match::matcher_result& mr) const
{ {
auto ins = mr.result; auto ins = mr.result;
p.replace_instruction(ins, ins->inputs().front()); m.replace_instruction(ins, ins->inputs().front());
} }
}; };
...@@ -132,7 +132,7 @@ struct find_transpose ...@@ -132,7 +132,7 @@ struct find_transpose
match::skip_output(match::name("contiguous"))(match::name("transpose")))); match::skip_output(match::name("contiguous"))(match::name("transpose"))));
} }
void apply(module& p, const match::matcher_result& mr) const void apply(module& m, const match::matcher_result& mr) const
{ {
auto ins = mr.result; auto ins = mr.result;
auto x = ins; auto x = ins;
...@@ -149,11 +149,11 @@ struct find_transpose ...@@ -149,11 +149,11 @@ struct find_transpose
return; return;
if(is_no_transpose(dims)) if(is_no_transpose(dims))
{ {
p.replace_instruction(ins, t->inputs().front()); m.replace_instruction(ins, t->inputs().front());
} }
else else
{ {
p.replace_instruction( m.replace_instruction(
ins, make_op("transpose", {{"permutation", dims}}), t->inputs().front()); ins, make_op("transpose", {{"permutation", dims}}), t->inputs().front());
} }
} }
...@@ -223,7 +223,7 @@ struct find_nested_slice ...@@ -223,7 +223,7 @@ struct find_nested_slice
return result; return result;
} }
void apply(module& p, const match::matcher_result& mr) const void apply(module& m, const match::matcher_result& mr) const
{ {
auto ins = mr.result; auto ins = mr.result;
auto slice = ins->inputs().front(); auto slice = ins->inputs().front();
...@@ -241,7 +241,7 @@ struct find_nested_slice ...@@ -241,7 +241,7 @@ struct find_nested_slice
op.starts.push_back(pp.second.first); op.starts.push_back(pp.second.first);
op.ends.push_back(pp.second.second); op.ends.push_back(pp.second.second);
} }
p.replace_instruction(ins, op, input); m.replace_instruction(ins, op, input);
} }
}; };
...@@ -252,7 +252,7 @@ struct find_concat_transpose ...@@ -252,7 +252,7 @@ struct find_concat_transpose
return match::name("concat")(match::all_of[match::inputs()](match::transpose_shape())); return match::name("concat")(match::all_of[match::inputs()](match::transpose_shape()));
} }
void apply(module& p, const match::matcher_result& mr) const void apply(module& m, const match::matcher_result& mr) const
{ {
auto ins = mr.result; auto ins = mr.result;
auto trans_inputs = ins->inputs(); auto trans_inputs = ins->inputs();
...@@ -279,14 +279,14 @@ struct find_concat_transpose ...@@ -279,14 +279,14 @@ struct find_concat_transpose
std::vector<instruction_ref> inputs; std::vector<instruction_ref> inputs;
std::transform( std::transform(
ins->inputs().begin(), ins->inputs().end(), std::back_inserter(inputs), [&](auto i) { ins->inputs().begin(), ins->inputs().end(), std::back_inserter(inputs), [&](auto i) {
return p.insert_instruction( return m.insert_instruction(
ins, make_op("transpose", {{"permutation", permutation}}), i); ins, make_op("transpose", {{"permutation", permutation}}), i);
}); });
auto concat = p.insert_instruction(ins, op, inputs); auto concat = m.insert_instruction(ins, op, inputs);
auto t = p.insert_instruction( auto t = m.insert_instruction(
ins, make_op("transpose", {{"permutation", ipermutation}}), concat); ins, make_op("transpose", {{"permutation", ipermutation}}), concat);
assert(ins->get_shape().lens() == t->get_shape().lens()); assert(ins->get_shape().lens() == t->get_shape().lens());
p.replace_instruction(ins, t); m.replace_instruction(ins, t);
} }
}; };
...@@ -303,7 +303,7 @@ struct find_nested_concat ...@@ -303,7 +303,7 @@ struct find_nested_concat
return op.axis; return op.axis;
} }
void apply(module& p, const match::matcher_result& mr) const void apply(module& m, const match::matcher_result& mr) const
{ {
auto ins = mr.result; auto ins = mr.result;
auto axis = get_axis(ins); auto axis = get_axis(ins);
...@@ -317,7 +317,7 @@ struct find_nested_concat ...@@ -317,7 +317,7 @@ struct find_nested_concat
args.push_back(i); args.push_back(i);
} }
})(ins->inputs()); })(ins->inputs());
p.replace_instruction(ins, ins->get_operator(), args); m.replace_instruction(ins, ins->get_operator(), args);
} }
}; };
...@@ -329,7 +329,7 @@ struct find_resize ...@@ -329,7 +329,7 @@ struct find_resize
match::args(match::name("reshape").bind("data"), match::is_constant().bind("ind"))); match::args(match::name("reshape").bind("data"), match::is_constant().bind("ind")));
} }
void apply(module& p, match::matcher_result r) const void apply(module& m, const match::matcher_result& r) const
{ {
auto ins = r.result; auto ins = r.result;
auto ins_rsp = r.instructions["data"]; auto ins_rsp = r.instructions["data"];
...@@ -417,13 +417,13 @@ struct find_resize ...@@ -417,13 +417,13 @@ struct find_resize
} }
auto in_rsp = ins_rsp->inputs().front(); auto in_rsp = ins_rsp->inputs().front();
auto rsp_data = p.insert_instruction( auto rsp_data = m.insert_instruction(
ins_rsp, migraphx::make_op("reshape", {{"dims", in_dims}}), in_rsp); ins_rsp, migraphx::make_op("reshape", {{"dims", in_dims}}), in_rsp);
auto mb_rsp = p.insert_instruction( auto mb_rsp = m.insert_instruction(
ins_rsp, migraphx::make_op("multibroadcast", {{"out_lens", out_dims}}), rsp_data); ins_rsp, migraphx::make_op("multibroadcast", {{"out_lens", out_dims}}), rsp_data);
auto std_mb = p.insert_instruction(ins, migraphx::make_op("contiguous"), mb_rsp); auto std_mb = m.insert_instruction(ins, migraphx::make_op("contiguous"), mb_rsp);
std::vector<int64_t> rsp_dims(out_lens.begin(), out_lens.end()); std::vector<int64_t> rsp_dims(out_lens.begin(), out_lens.end());
p.replace_instruction(ins, migraphx::make_op("reshape", {{"dims", rsp_dims}}), std_mb); m.replace_instruction(ins, migraphx::make_op("reshape", {{"dims", rsp_dims}}), std_mb);
} }
}; };
...@@ -436,7 +436,7 @@ struct find_where_op ...@@ -436,7 +436,7 @@ struct find_where_op
match::is_constant().bind("ind"))); match::is_constant().bind("ind")));
} }
void apply(module& p, match::matcher_result r) const void apply(module& m, const match::matcher_result& r) const
{ {
auto ins = r.result; auto ins = r.result;
auto concat = r.instructions["data"]; auto concat = r.instructions["data"];
...@@ -475,11 +475,11 @@ struct find_where_op ...@@ -475,11 +475,11 @@ struct find_where_op
if(val) if(val)
{ {
p.replace_instruction(ins, inputs.at(0)); m.replace_instruction(ins, inputs.at(0));
} }
else else
{ {
p.replace_instruction(ins, inputs.at(1)); m.replace_instruction(ins, inputs.at(1));
} }
} }
}; };
...@@ -496,7 +496,7 @@ struct find_reshape_cont ...@@ -496,7 +496,7 @@ struct find_reshape_cont
match::any())); match::any()));
} }
void apply(module& p, match::matcher_result r) const void apply(module& m, const match::matcher_result& r) const
{ {
auto ins = r.result; auto ins = r.result;
auto ins_cont = r.instructions["cont"]; auto ins_cont = r.instructions["cont"];
...@@ -530,11 +530,11 @@ struct find_reshape_cont ...@@ -530,11 +530,11 @@ struct find_reshape_cont
else else
{ {
inputs.push_back( inputs.push_back(
p.insert_instruction(ins, make_op("reshape", {{"dims", dims}}), in)); m.insert_instruction(ins, make_op("reshape", {{"dims", dims}}), in));
} }
} }
auto out = p.insert_instruction(ins, ins->get_operator(), inputs); auto out = m.insert_instruction(ins, ins->get_operator(), inputs);
p.replace_instruction(ins, make_op("reshape", {{"dims", out_dims}}), out); m.replace_instruction(ins, make_op("reshape", {{"dims", out_dims}}), out);
} }
}; };
...@@ -564,25 +564,25 @@ struct find_transpose_contiguous_reshaper_unary ...@@ -564,25 +564,25 @@ struct find_transpose_contiguous_reshaper_unary
match::args(match_transpose_contiguous_reshaper())); match::args(match_transpose_contiguous_reshaper()));
} }
void apply(module& p, match::matcher_result r) const void apply(module& m, const match::matcher_result& r) const
{ {
auto ins = r.result; auto ins = r.result;
auto reshaper_ins = r.instructions["reshaper_ins"]; auto reshaper_ins = r.instructions["reshaper_ins"];
auto trans_ins = r.instructions["trans_ins"]; auto trans_ins = r.instructions["trans_ins"];
auto cont_ins = r.instructions["cont_ins"]; auto cont_ins = r.instructions["cont_ins"];
auto unary_op_name = ins->get_operator().name(); auto unary_op_name = ins->get_operator().name();
auto unary_ins = p.insert_instruction(cont_ins, make_op(unary_op_name), trans_ins); auto unary_ins = m.insert_instruction(cont_ins, make_op(unary_op_name), trans_ins);
auto new_cont_ins = p.insert_instruction(cont_ins, make_op("contiguous"), unary_ins); auto new_cont_ins = m.insert_instruction(cont_ins, make_op("contiguous"), unary_ins);
// older cont and reshape are removed by deadcode elimination // older cont and reshape are removed by deadcode elimination
p.replace_instruction(ins, reshaper_ins->get_operator(), new_cont_ins); m.replace_instruction(ins, reshaper_ins->get_operator(), new_cont_ins);
} }
}; };
void simplify_reshapes::apply(module& p) const void simplify_reshapes::apply(module& m) const
{ {
for(int i = 0; i < 2; i++) for(int i = 0; i < 2; i++)
{ {
match::find_matches(p, match::find_matches(m,
find_where_op{}, find_where_op{},
find_resize{}, find_resize{},
find_reshape_cont{}, find_reshape_cont{},
...@@ -594,7 +594,7 @@ void simplify_reshapes::apply(module& p) const ...@@ -594,7 +594,7 @@ void simplify_reshapes::apply(module& p) const
find_nested_slice{}, find_nested_slice{},
find_nested_concat{}, find_nested_concat{},
find_transpose_contiguous_reshaper_unary{}); find_transpose_contiguous_reshaper_unary{});
dead_code_elimination{}.apply(p); dead_code_elimination{}.apply(m);
} }
} }
......
...@@ -352,7 +352,7 @@ struct cpu_apply ...@@ -352,7 +352,7 @@ struct cpu_apply
std::transform(bind_inputs.begin(), std::transform(bind_inputs.begin(),
bind_inputs.end(), bind_inputs.end(),
std::back_inserter(inputs), std::back_inserter(inputs),
[&](const auto& s) { return r.instructions.at(s); }); [&](const auto& s) { return r.instructions[s]; });
inputs.push_back(this->insert_allocation(ins, ins->get_shape())); inputs.push_back(this->insert_allocation(ins, ins->get_shape()));
modl->replace_instruction(ins, op, inputs); modl->replace_instruction(ins, op, inputs);
}); });
......
...@@ -158,6 +158,7 @@ add_library(migraphx_gpu ...@@ -158,6 +158,7 @@ add_library(migraphx_gpu
nonzero.cpp nonzero.cpp
pack_args.cpp pack_args.cpp
pack_int8_args.cpp pack_int8_args.cpp
prefuse_ops.cpp
pad.cpp pad.cpp
pooling.cpp pooling.cpp
quant_convolution.cpp quant_convolution.cpp
......
...@@ -28,30 +28,30 @@ struct hip_stream_model ...@@ -28,30 +28,30 @@ struct hip_stream_model
bool is_wait(migraphx::instruction_ref ins) const { return ins->name() == "gpu::wait_event"; } bool is_wait(migraphx::instruction_ref ins) const { return ins->name() == "gpu::wait_event"; }
}; };
stream_model make_stream_model(const module& p) stream_model make_stream_model(const module& m)
{ {
hip_stream_model m; hip_stream_model hsm;
std::size_t stream = 0; std::size_t stream = 0;
for(auto ins : iterator_for(p)) for(auto ins : iterator_for(m))
{ {
if(ins->name() == "gpu::set_stream") if(ins->name() == "gpu::set_stream")
{ {
auto v = ins->get_operator().to_value(); auto v = ins->get_operator().to_value();
stream = v["stream"].to<std::size_t>(); stream = v["stream"].to<std::size_t>();
m.max_stream = std::max(stream, m.max_stream); hsm.max_stream = std::max(stream, hsm.max_stream);
} }
if(ins->get_operator().is_context_free()) if(ins->get_operator().is_context_free())
continue; continue;
if(contains({"hip::hip_allocate_memory", "hip::hip_copy_literal", "@param"}, ins->name())) if(contains({"hip::hip_allocate_memory", "hip::hip_copy_literal", "@param"}, ins->name()))
continue; continue;
m.ins2stream[ins] = stream; hsm.ins2stream[ins] = stream;
} }
return m; return hsm;
} }
std::vector<stream_race> analyze_streams(const module& p) std::vector<stream_race> analyze_streams(const module& m)
{ {
return migraphx::analyze_streams(p, make_stream_model(p)); return migraphx::analyze_streams(m, make_stream_model(m));
} }
} // namespace gpu } // namespace gpu
......
...@@ -22,6 +22,7 @@ namespace gpu { ...@@ -22,6 +22,7 @@ namespace gpu {
MIGRAPHX_DECLARE_ENV_VAR(MIGRAPHX_GPU_DEBUG); MIGRAPHX_DECLARE_ENV_VAR(MIGRAPHX_GPU_DEBUG);
MIGRAPHX_DECLARE_ENV_VAR(MIGRAPHX_GPU_OPTIMIZE); MIGRAPHX_DECLARE_ENV_VAR(MIGRAPHX_GPU_OPTIMIZE);
MIGRAPHX_DECLARE_ENV_VAR(MIGRAPHX_GPU_DUMP_ASM); MIGRAPHX_DECLARE_ENV_VAR(MIGRAPHX_GPU_DUMP_ASM);
MIGRAPHX_DECLARE_ENV_VAR(MIGRAPHX_GPU_DUMP_SRC);
#if MIGRAPHX_USE_HIPRTC #if MIGRAPHX_USE_HIPRTC
...@@ -247,6 +248,16 @@ compile_hip_src(const std::vector<src_file>& srcs, std::string params, const std ...@@ -247,6 +248,16 @@ compile_hip_src(const std::vector<src_file>& srcs, std::string params, const std
MIGRAPHX_THROW("Missing hsaco"); MIGRAPHX_THROW("Missing hsaco");
}; };
if(enabled(MIGRAPHX_GPU_DUMP_SRC{}))
{
for(const auto& src : srcs)
{
if(src.path.extension() != ".cpp")
continue;
std::cout << std::string(src.content.first, src.len()) << std::endl;
}
}
if(enabled(MIGRAPHX_GPU_DUMP_ASM{})) if(enabled(MIGRAPHX_GPU_DUMP_ASM{}))
{ {
......
...@@ -11,11 +11,11 @@ namespace migraphx { ...@@ -11,11 +11,11 @@ namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS { inline namespace MIGRAPHX_INLINE_NS {
namespace gpu { namespace gpu {
void eliminate_workspace::apply(module& p) const void eliminate_workspace::apply(module& m) const
{ {
std::size_t n = 0; std::size_t n = 0;
std::vector<instruction_ref> allocs; std::vector<instruction_ref> allocs;
for(auto ins : iterator_for(p)) for(auto ins : iterator_for(m))
{ {
if(ins->outputs().size() != 1) if(ins->outputs().size() != 1)
continue; continue;
...@@ -30,11 +30,11 @@ void eliminate_workspace::apply(module& p) const ...@@ -30,11 +30,11 @@ void eliminate_workspace::apply(module& p) const
} }
if(n > 0) if(n > 0)
{ {
auto ws = p.add_parameter("workspace", shape{shape::int8_type, {n}}); auto ws = m.add_parameter("workspace", shape{shape::int8_type, {n}});
for(auto&& a : allocs) for(auto&& a : allocs)
{ {
p.replace_instruction(a, ws); m.replace_instruction(a, ws);
p.remove_instruction(a); m.remove_instruction(a);
} }
} }
} }
......
...@@ -316,7 +316,7 @@ struct find_layernorm ...@@ -316,7 +316,7 @@ struct find_layernorm
{ {
auto matcher() const { return match::layernorm(&gpu_name); } auto matcher() const { return match::layernorm(&gpu_name); }
void apply(module& p, match::matcher_result r) const void apply(module& m, const match::matcher_result& r) const
{ {
auto ins = r.result; auto ins = r.result;
auto x_ins = r.instructions["x"]; auto x_ins = r.instructions["x"];
...@@ -331,7 +331,7 @@ struct find_layernorm ...@@ -331,7 +331,7 @@ struct find_layernorm
if(relements > 1024 or (relements % 4 != 0 and relements > 256)) if(relements > 1024 or (relements % 4 != 0 and relements > 256))
return; return;
p.replace_instruction(ins, hip_layernorm{}, x_ins, args.back()); m.replace_instruction(ins, hip_layernorm{}, x_ins, args.back());
} }
}; };
...@@ -343,11 +343,11 @@ struct find_triadd_layernorm ...@@ -343,11 +343,11 @@ struct find_triadd_layernorm
match::used_once(), match::all_of[match::inputs()](match::standard_shape())))); match::used_once(), match::all_of[match::inputs()](match::standard_shape()))));
} }
void apply(module& p, const match::matcher_result& r) const void apply(module& m, const match::matcher_result& r) const
{ {
auto ins = r.result; auto ins = r.result;
auto triadd = ins->inputs().front(); auto triadd = ins->inputs().front();
p.replace_instruction(ins, hip_triadd_layernorm{}, triadd->inputs()); m.replace_instruction(ins, hip_triadd_layernorm{}, triadd->inputs());
} }
}; };
...@@ -355,13 +355,13 @@ struct find_gelu ...@@ -355,13 +355,13 @@ struct find_gelu
{ {
auto matcher() const { return match::gelu_erf(&gpu_name); } auto matcher() const { return match::gelu_erf(&gpu_name); }
void apply(module& p, match::matcher_result r) const void apply(module& m, const match::matcher_result& r) const
{ {
auto ins = r.result; auto ins = r.result;
auto x_ins = r.instructions["x"]; auto x_ins = r.instructions["x"];
auto args = ins->inputs(); auto args = ins->inputs();
p.replace_instruction(ins, hip_gelu{}, x_ins, args.back()); m.replace_instruction(ins, hip_gelu{}, x_ins, args.back());
} }
}; };
...@@ -372,7 +372,7 @@ struct find_add_gelu ...@@ -372,7 +372,7 @@ struct find_add_gelu
return match::name("gpu::gelu")(match::arg(0)(match::name("gpu::add").bind("add"))); return match::name("gpu::gelu")(match::arg(0)(match::name("gpu::add").bind("add")));
} }
void apply(module& p, match::matcher_result r) const void apply(module& m, const match::matcher_result& r) const
{ {
auto add_ins = r.instructions["add"]; auto add_ins = r.instructions["add"];
auto ins = r.result; auto ins = r.result;
...@@ -381,7 +381,7 @@ struct find_add_gelu ...@@ -381,7 +381,7 @@ struct find_add_gelu
move_broadcasted_back(args); move_broadcasted_back(args);
args.back() = ins->inputs().back(); args.back() = ins->inputs().back();
p.replace_instruction(ins, hip_add_gelu{}, args); m.replace_instruction(ins, hip_add_gelu{}, args);
} }
}; };
...@@ -391,16 +391,16 @@ struct find_gelu_new ...@@ -391,16 +391,16 @@ struct find_gelu_new
auto matcher() const { return match::gelu_tanh(&gpu_name); } auto matcher() const { return match::gelu_tanh(&gpu_name); }
void apply(module& p, match::matcher_result r) const void apply(module& m, const match::matcher_result& r) const
{ {
auto ins = r.result; auto ins = r.result;
auto x_ins = r.instructions["x"]; auto x_ins = r.instructions["x"];
auto args = ins->inputs(); auto args = ins->inputs();
if(fast_math) if(fast_math)
p.replace_instruction(ins, hip_gelu{}, x_ins, args.back()); m.replace_instruction(ins, hip_gelu{}, x_ins, args.back());
else else
p.replace_instruction(ins, hip_gelu_new{}, x_ins, args.back()); m.replace_instruction(ins, hip_gelu_new{}, x_ins, args.back());
} }
}; };
...@@ -411,7 +411,7 @@ struct find_add_gelu_new ...@@ -411,7 +411,7 @@ struct find_add_gelu_new
return match::name("gpu::gelu_new")(match::arg(0)(match::name("gpu::add").bind("add"))); return match::name("gpu::gelu_new")(match::arg(0)(match::name("gpu::add").bind("add")));
} }
void apply(module& p, match::matcher_result r) const void apply(module& m, const match::matcher_result& r) const
{ {
auto add_ins = r.instructions["add"]; auto add_ins = r.instructions["add"];
auto ins = r.result; auto ins = r.result;
...@@ -420,7 +420,7 @@ struct find_add_gelu_new ...@@ -420,7 +420,7 @@ struct find_add_gelu_new
move_broadcasted_back(args); move_broadcasted_back(args);
args.back() = ins->inputs().back(); args.back() = ins->inputs().back();
p.replace_instruction(ins, hip_add_gelu_new{}, args); m.replace_instruction(ins, hip_add_gelu_new{}, args);
} }
}; };
...@@ -435,7 +435,7 @@ struct find_add_clip ...@@ -435,7 +435,7 @@ struct find_add_clip
.bind("add"))); .bind("add")));
} }
void apply(module& p, match::matcher_result r) const void apply(module& m, const match::matcher_result& r) const
{ {
auto add_ins = r.instructions["add"]; auto add_ins = r.instructions["add"];
auto ins = r.result; auto ins = r.result;
...@@ -448,9 +448,9 @@ struct find_add_clip ...@@ -448,9 +448,9 @@ struct find_add_clip
add_args.pop_back(); add_args.pop_back();
add_args.insert(add_args.end(), std::next(ins_args.begin()), ins_args.end()); add_args.insert(add_args.end(), std::next(ins_args.begin()), ins_args.end());
if(add_ins->name() == "gpu::add") if(add_ins->name() == "gpu::add")
p.replace_instruction(ins, hip_add_clip{}, add_args); m.replace_instruction(ins, hip_add_clip{}, add_args);
else if(add_ins->name() == "gpu::triadd") else if(add_ins->name() == "gpu::triadd")
p.replace_instruction(ins, hip_triadd_clip{}, add_args); m.replace_instruction(ins, hip_triadd_clip{}, add_args);
} }
}; };
...@@ -470,7 +470,7 @@ struct find_add_unary ...@@ -470,7 +470,7 @@ struct find_add_unary
.bind("add"))); .bind("add")));
} }
void apply(module& p, match::matcher_result r) const void apply(module& m, const match::matcher_result& r) const
{ {
auto add_ins = r.instructions["add"]; auto add_ins = r.instructions["add"];
auto ins = r.result; auto ins = r.result;
...@@ -481,9 +481,9 @@ struct find_add_unary ...@@ -481,9 +481,9 @@ struct find_add_unary
// Use the allocation from the relu operator // Use the allocation from the relu operator
args.back() = ins->inputs().back(); args.back() = ins->inputs().back();
if(add_ins->name() == "gpu::add") if(add_ins->name() == "gpu::add")
p.replace_instruction(ins, binary_add_op, args); m.replace_instruction(ins, binary_add_op, args);
else if(add_ins->name() == "gpu::triadd") else if(add_ins->name() == "gpu::triadd")
p.replace_instruction(ins, ternary_add_op, args); m.replace_instruction(ins, ternary_add_op, args);
} }
}; };
...@@ -498,7 +498,7 @@ struct find_triadd ...@@ -498,7 +498,7 @@ struct find_triadd
.bind("input"))); .bind("input")));
} }
void apply(module& p, match::matcher_result r) const void apply(module& m, const match::matcher_result& r) const
{ {
auto add_ins = r.instructions["add"]; auto add_ins = r.instructions["add"];
auto input_ins = r.instructions["input"]; auto input_ins = r.instructions["input"];
...@@ -513,7 +513,7 @@ struct find_triadd ...@@ -513,7 +513,7 @@ struct find_triadd
move_broadcasted_back(args); move_broadcasted_back(args);
args.back() = ins->inputs().back(); args.back() = ins->inputs().back();
p.replace_instruction(ins, hip_triadd{}, args); m.replace_instruction(ins, hip_triadd{}, args);
} }
}; };
...@@ -525,7 +525,7 @@ struct find_mul_add ...@@ -525,7 +525,7 @@ struct find_mul_add
match::name("gpu::mul")(match::used_once()).bind("mul"), match::any().bind("b"))); match::name("gpu::mul")(match::used_once()).bind("mul"), match::any().bind("b")));
} }
void apply(module& p, match::matcher_result r) const void apply(module& m, const match::matcher_result& r) const
{ {
auto mul_ins = r.instructions["mul"]; auto mul_ins = r.instructions["mul"];
auto b_ins = r.instructions["b"]; auto b_ins = r.instructions["b"];
...@@ -538,7 +538,7 @@ struct find_mul_add ...@@ -538,7 +538,7 @@ struct find_mul_add
args.insert(std::prev(args.end()), b_ins); args.insert(std::prev(args.end()), b_ins);
args.back() = ins->inputs().back(); args.back() = ins->inputs().back();
p.replace_instruction(ins, hip_mul_add{}, args); m.replace_instruction(ins, hip_mul_add{}, args);
} }
}; };
...@@ -550,7 +550,7 @@ struct find_mul_add_relu ...@@ -550,7 +550,7 @@ struct find_mul_add_relu
match::arg(0)(match::name("gpu::mul_add")(match::used_once()).bind("mul_add"))); match::arg(0)(match::name("gpu::mul_add")(match::used_once()).bind("mul_add")));
} }
void apply(module& p, match::matcher_result r) const void apply(module& m, const match::matcher_result& r) const
{ {
auto mul_add_ins = r.instructions["mul_add"]; auto mul_add_ins = r.instructions["mul_add"];
auto ins = r.result; auto ins = r.result;
...@@ -558,7 +558,7 @@ struct find_mul_add_relu ...@@ -558,7 +558,7 @@ struct find_mul_add_relu
// Use the allocation from the relu operator // Use the allocation from the relu operator
args.back() = ins->inputs().back(); args.back() = ins->inputs().back();
p.replace_instruction(ins, hip_mul_add_relu{}, args); m.replace_instruction(ins, hip_mul_add_relu{}, args);
} }
}; };
...@@ -783,7 +783,7 @@ auto conv_bias(Ms... ms) ...@@ -783,7 +783,7 @@ auto conv_bias(Ms... ms)
} }
template <class Op> template <class Op>
void apply_conv_bias(context& ctx, module& p, match::matcher_result r) void apply_conv_bias(context& ctx, module& m, const match::matcher_result& r)
{ {
auto conv_ins = r.instructions["conv"]; auto conv_ins = r.instructions["conv"];
auto bias_ins = r.instructions["bias"]; auto bias_ins = r.instructions["bias"];
...@@ -798,7 +798,7 @@ void apply_conv_bias(context& ctx, module& p, match::matcher_result r) ...@@ -798,7 +798,7 @@ void apply_conv_bias(context& ctx, module& p, match::matcher_result r)
// TODO: Insert ws allocation // TODO: Insert ws allocation
auto ws = cb.get_workspace(ctx); auto ws = cb.get_workspace(ctx);
(void)ws; (void)ws;
p.replace_instruction(ins, cb, input_ins, weights_ins, old_ws_ins, bias_ins, alloc_ins); m.replace_instruction(ins, cb, input_ins, weights_ins, old_ws_ins, bias_ins, alloc_ins);
} }
inline auto precompile_name(std::string s) // NOLINT inline auto precompile_name(std::string s) // NOLINT
...@@ -829,9 +829,9 @@ struct find_conv_bias ...@@ -829,9 +829,9 @@ struct find_conv_bias
match::output(match::name(std::unordered_set<std::string>{"gpu::relu"})))); match::output(match::name(std::unordered_set<std::string>{"gpu::relu"}))));
} }
void apply(module& p, match::matcher_result r) const void apply(module& m, const match::matcher_result& r) const
{ {
apply_conv_bias<miopen_conv_bias>(*ctx, p, std::move(r)); apply_conv_bias<miopen_conv_bias>(*ctx, m, r);
} }
}; };
...@@ -840,9 +840,9 @@ struct find_conv_bias_relu ...@@ -840,9 +840,9 @@ struct find_conv_bias_relu
context* ctx = nullptr; context* ctx = nullptr;
auto matcher() const { return match::name("gpu::relu")(match::arg(0)(conv_bias())); } auto matcher() const { return match::name("gpu::relu")(match::arg(0)(conv_bias())); }
void apply(module& p, match::matcher_result r) const void apply(module& m, const match::matcher_result& r) const
{ {
apply_conv_bias<miopen_conv_bias_relu>(*ctx, p, std::move(r)); apply_conv_bias<miopen_conv_bias_relu>(*ctx, m, r);
} }
}; };
...@@ -857,7 +857,7 @@ struct find_conv_pointwise ...@@ -857,7 +857,7 @@ struct find_conv_pointwise
fusable_conv(match::used_once()).bind("conv"))); fusable_conv(match::used_once()).bind("conv")));
} }
void apply(module& m, match::matcher_result r) const void apply(module& m, const match::matcher_result& r) const
{ {
auto conv_ins = r.instructions["conv"]; auto conv_ins = r.instructions["conv"];
auto bias_ins = r.instructions["bias"]; auto bias_ins = r.instructions["bias"];
...@@ -896,7 +896,7 @@ struct find_gemm_add ...@@ -896,7 +896,7 @@ struct find_gemm_add
match::name("gpu::gemm")(match::nargs(3)).bind("gemm"))); match::name("gpu::gemm")(match::nargs(3)).bind("gemm")));
} }
void apply(module& p, match::matcher_result r) const void apply(module& m, const match::matcher_result& r) const
{ {
auto ins = r.result; auto ins = r.result;
auto gemm_ins = r.instructions["gemm"]; auto gemm_ins = r.instructions["gemm"];
...@@ -919,15 +919,15 @@ struct find_gemm_add ...@@ -919,15 +919,15 @@ struct find_gemm_add
auto copy_ins = c_ins; auto copy_ins = c_ins;
// Insert copy // Insert copy
if(ins == p.end() or c_ins->outputs().size() > 1 or c_ins->inputs().empty()) if(ins == m.end() or c_ins->outputs().size() > 1 or c_ins->inputs().empty())
{ {
copy_ins = p.insert_instruction(ins, hip_copy{}, c_ins, ins->inputs().back()); copy_ins = m.insert_instruction(ins, hip_copy{}, c_ins, ins->inputs().back());
} }
inputs.push_back(copy_ins); inputs.push_back(copy_ins);
inputs.push_back(copy_ins); inputs.push_back(copy_ins);
gemm.beta = 1; gemm.beta = 1;
p.replace_instruction(ins, gemm, inputs); m.replace_instruction(ins, gemm, inputs);
} }
}; };
...@@ -938,22 +938,22 @@ struct find_commutative_broadcast ...@@ -938,22 +938,22 @@ struct find_commutative_broadcast
return match::name("gpu::add", "gpu::mul")(match::arg(1)(match::broadcast_shape())); return match::name("gpu::add", "gpu::mul")(match::arg(1)(match::broadcast_shape()));
} }
void apply(module& p, const match::matcher_result& r) const void apply(module& m, const match::matcher_result& r) const
{ {
auto ins = r.result; auto ins = r.result;
auto args = ins->inputs(); auto args = ins->inputs();
move_broadcasted_back(args); move_broadcasted_back(args);
p.replace_instruction(ins, ins->get_operator(), args); m.replace_instruction(ins, ins->get_operator(), args);
} }
}; };
void fuse_ops::apply(module& p) const void fuse_ops::apply(module& m) const
{ {
match::find_matches(p, find_gelu{}, find_gelu_new{fast_math}); match::find_matches(m, find_gelu{}, find_gelu_new{fast_math});
run_passes(p, {dead_code_elimination{}}); run_passes(m, {dead_code_elimination{}});
match::find_matches(p, find_triadd{}); match::find_matches(m, find_triadd{});
match::find_matches(p, match::find_matches(m,
find_layernorm{}, find_layernorm{},
find_conv_pointwise{ctx}, find_conv_pointwise{ctx},
find_conv_bias_relu{ctx}, find_conv_bias_relu{ctx},
...@@ -966,8 +966,8 @@ void fuse_ops::apply(module& p) const ...@@ -966,8 +966,8 @@ void fuse_ops::apply(module& p) const
find_add_unary{"gpu::sigmoid", hip_add_sigmoid{}, hip_triadd_sigmoid{}}, find_add_unary{"gpu::sigmoid", hip_add_sigmoid{}, hip_triadd_sigmoid{}},
find_add_unary{"gpu::tanh", hip_add_tanh{}, hip_triadd_tanh{}}, find_add_unary{"gpu::tanh", hip_add_tanh{}, hip_triadd_tanh{}},
find_add_clip{}); find_add_clip{});
run_passes(p, {dead_code_elimination{}}); run_passes(m, {dead_code_elimination{}});
match::find_matches(p, find_triadd_layernorm{}, find_gemm_add{}, find_commutative_broadcast{}); match::find_matches(m, find_triadd_layernorm{}, find_gemm_add{}, find_commutative_broadcast{});
} }
} // namespace gpu } // namespace gpu
......
...@@ -11,7 +11,7 @@ struct module; ...@@ -11,7 +11,7 @@ struct module;
namespace gpu { namespace gpu {
std::vector<stream_race> analyze_streams(const module& p); std::vector<stream_race> analyze_streams(const module& m);
} // namespace gpu } // namespace gpu
} // namespace MIGRAPHX_INLINE_NS } // namespace MIGRAPHX_INLINE_NS
......
...@@ -14,7 +14,7 @@ namespace gpu { ...@@ -14,7 +14,7 @@ namespace gpu {
struct eliminate_workspace struct eliminate_workspace
{ {
std::string name() const { return "eliminate_workspace"; } std::string name() const { return "eliminate_workspace"; }
void apply(module& p) const; void apply(module& m) const;
}; };
} // namespace gpu } // namespace gpu
} // namespace MIGRAPHX_INLINE_NS } // namespace MIGRAPHX_INLINE_NS
......
...@@ -16,7 +16,7 @@ struct fuse_ops ...@@ -16,7 +16,7 @@ struct fuse_ops
context* ctx = nullptr; context* ctx = nullptr;
bool fast_math = true; bool fast_math = true;
std::string name() const { return "gpu::fuse_ops"; } std::string name() const { return "gpu::fuse_ops"; }
void apply(module& p) const; void apply(module& m) const;
}; };
} // namespace gpu } // namespace gpu
......
#ifndef MIGRAPHX_GUARD_GPU_PREFUSE_OPS_HPP
#define MIGRAPHX_GUARD_GPU_PREFUSE_OPS_HPP
#include <migraphx/config.hpp>
#include <migraphx/gpu/context.hpp>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
struct module;
namespace gpu {
struct prefuse_ops
{
std::string name() const { return "gpu::prefuse_ops"; }
void apply(module& m) const;
};
} // namespace gpu
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
#endif // MIGRAPHX_GUARD_GPU_PREFUSE_OPS_HPP
...@@ -17,9 +17,9 @@ struct schedule_model ...@@ -17,9 +17,9 @@ struct schedule_model
{ {
std::size_t streams = 0; std::size_t streams = 0;
std::size_t concurrency() const; std::size_t concurrency() const;
void sched(module& p, instruction_ref ins, std::size_t n) const; void sched(module& m, instruction_ref ins, std::size_t n) const;
void wait(module& p, instruction_ref ins, std::size_t wait_id) const; void wait(module& m, instruction_ref ins, std::size_t wait_id) const;
void record(module& p, instruction_ref ins, std::size_t wait_id) const; void record(module& m, instruction_ref ins, std::size_t wait_id) const;
std::size_t weight(const operation& op) const; std::size_t weight(const operation& op) const;
}; };
......
...@@ -15,7 +15,7 @@ namespace gpu { ...@@ -15,7 +15,7 @@ namespace gpu {
struct sync_device struct sync_device
{ {
std::string name() const { return "sync_device"; } std::string name() const { return "sync_device"; }
void apply(module& p) const; void apply(module& m) const;
}; };
} // namespace gpu } // namespace gpu
} // namespace MIGRAPHX_INLINE_NS } // namespace MIGRAPHX_INLINE_NS
......
...@@ -14,7 +14,7 @@ struct write_literals ...@@ -14,7 +14,7 @@ struct write_literals
context* ctx = nullptr; context* ctx = nullptr;
std::string name() const { return "gpu::write_literals"; } std::string name() const { return "gpu::write_literals"; }
void apply(module& p) const; void apply(module& m) const;
}; };
} // namespace gpu } // namespace gpu
......
...@@ -6,6 +6,7 @@ ...@@ -6,6 +6,7 @@
#include <migraphx/cpp_generator.hpp> #include <migraphx/cpp_generator.hpp>
#include <migraphx/ranges.hpp> #include <migraphx/ranges.hpp>
#include <migraphx/reduce_dims.hpp> #include <migraphx/reduce_dims.hpp>
#include <migraphx/permutation.hpp>
#include <migraphx/stringutils.hpp> #include <migraphx/stringutils.hpp>
#include <migraphx/dead_code_elimination.hpp> #include <migraphx/dead_code_elimination.hpp>
#include <migraphx/eliminate_common_subexpression.hpp> #include <migraphx/eliminate_common_subexpression.hpp>
...@@ -26,9 +27,10 @@ namespace migraphx { ...@@ -26,9 +27,10 @@ namespace migraphx {
${preamble} ${preamble}
extern "C" { extern "C" {
__global__ void kernel(${params}) __global__ void ${kernel}(${params})
{ {
pointwise(${lambda}, ${args}); auto idx = make_index();
pointwise(idx, auto_preload<${preloads}>(idx), vectorize<${vec_size}, ${axis}>())(${lambda}, ${args});
} }
} }
...@@ -37,44 +39,123 @@ __global__ void kernel(${params}) ...@@ -37,44 +39,123 @@ __global__ void kernel(${params})
)__migraphx__"; )__migraphx__";
static std::vector<std::string> get_op_names(const module& m)
{
std::vector<std::string> result;
for(auto& ins : m)
{
if(starts_with(ins.name(), "@"))
continue;
result.push_back(ins.name());
}
return result;
}
struct pointwise_compiler : compiler<pointwise_compiler> struct pointwise_compiler : compiler<pointwise_compiler>
{ {
std::vector<std::string> names() const { return {"pointwise"}; } std::vector<std::string> names() const { return {"pointwise"}; }
static std::size_t oversubscribe(const std::vector<shape>& inputs) static std::size_t oversubscribe_if(bool b)
{ {
if(std::any_of(inputs.begin(), inputs.end(), [](const auto& s) { return s.broadcasted(); })) if(b)
return 1;
else
return 256; return 256;
else
return 1;
} }
static std::size_t vectorize_elements(const std::vector<shape>& inputs) static std::size_t find_fast_axis(const std::vector<shape>& inputs)
{ {
std::size_t n = inputs.front().elements(); auto permutation = find_permutation(inputs);
auto it = std::max_element(permutation.begin(), permutation.end());
return it - permutation.begin();
}
static std::vector<bool> preload(std::size_t axis, const std::vector<shape>& inputs)
{
const std::size_t max_lds_bytes = 4096;
std::vector<bool> result;
std::transform(inputs.begin(),
inputs.end(),
std::back_inserter(result),
[&](const shape& input) { return input.strides()[axis] == 0; });
auto bytes = std::inner_product(inputs.begin(),
inputs.end(),
result.begin(),
std::size_t{0},
std::plus<>{},
[](const shape& s, bool b) -> std::size_t {
if(b)
return s.bytes();
return 0;
});
if(bytes < max_lds_bytes)
return result;
// TODO: Try to partially preload items
std::fill(result.begin(), result.end(), false);
return result;
}
static std::string preload_str(const std::vector<bool>& bs)
{
std::vector<std::string> bool_strs;
std::transform(bs.begin(), std::prev(bs.end()), std::back_inserter(bool_strs), [](bool b) {
if(b)
return "true";
return "false";
});
return "false, " + join_strings(bool_strs, ", ");
}
static std::vector<std::size_t> vector_sizes(const std::vector<shape>& inputs)
{
// If all inputs is half then only use half2
if(std::all_of(inputs.begin(), inputs.end(), [](const auto& s) { if(std::all_of(inputs.begin(), inputs.end(), [](const auto& s) {
return s.packed() or s.broadcasted(); return s.type() == shape::half_type;
})) }))
{ return {2};
if((n % 4) == 0) return {4, 2};
return n / 4;
else if((n % 2) == 0)
return n / 2;
} }
return n; static auto vectorize_elements(std::size_t axis, const std::vector<shape>& inputs)
{
auto sizes = vector_sizes(inputs);
std::vector<std::size_t> max_vec_size;
std::transform(inputs.begin(),
inputs.end(),
std::back_inserter(max_vec_size),
[&](const auto& input) -> std::size_t {
auto stride = input.strides()[axis];
auto len = input.lens()[axis];
if(stride != 0 and stride != 1)
return 1;
auto it = std::find_if(
sizes.begin(), sizes.end(), [&](auto i) { return (len % i) == 0; });
if(it != sizes.end())
return *it;
return 1;
});
return *std::min_element(max_vec_size.begin(), max_vec_size.end());
} }
operation compile_op(context& ctx, const std::vector<shape>& inputs, const value& v) const operation compile_op(context& ctx, const std::vector<shape>& inputs, const value& v) const
{ {
hip_compile_options options; hip_compile_options options;
options.set_launch_params(
v, compute_global_for(ctx, vectorize_elements(inputs), oversubscribe(inputs)));
options.inputs = inputs; options.inputs = inputs;
options.output = inputs.back(); options.output = inputs.back();
options.virtual_inputs = reduce_dims(inputs); options.virtual_inputs = reduce_dims(inputs);
options.params = "-Wno-float-equal"; options.params = "-Wno-float-equal";
auto axis = find_fast_axis(options.virtual_inputs);
auto vec_size = vectorize_elements(axis, options.virtual_inputs);
auto preloads = preload(axis, options.virtual_inputs);
auto is_preloading =
std::accumulate(preloads.begin(), preloads.end(), false, std::logical_or<>{});
options.kernel_name = v.get("kernel", "kernel");
options.set_launch_params(v,
compute_global_for(ctx,
options.output.elements() / vec_size,
oversubscribe_if(not is_preloading)));
auto src = interpolate_string(pointwise_kernel, auto src = interpolate_string(pointwise_kernel,
{{"params", enum_params(inputs.size(), "void * private_p")}, {{"kernel", options.kernel_name},
{"params", enum_params(inputs.size(), "void * private_p")},
{"args", enum_params(inputs.size(), "private_p")}, {"args", enum_params(inputs.size(), "private_p")},
{"lambda", v.at("lambda").to<std::string>()}, {"lambda", v.at("lambda").to<std::string>()},
{"vec_size", std::to_string(vec_size)},
{"axis", std::to_string(axis)},
{"preloads", preload_str(preloads)},
{"preamble", v.get("preamble", std::string{})}}); {"preamble", v.get("preamble", std::string{})}});
return compile_hip_code_object(src, options); return compile_hip_code_object(src, options);
} }
...@@ -100,8 +181,13 @@ struct pointwise_compiler : compiler<pointwise_compiler> ...@@ -100,8 +181,13 @@ struct pointwise_compiler : compiler<pointwise_compiler>
auto name = g.create_function( auto name = g.create_function(
g.generate_module(*pm).set_attributes({"__device__"}).set_generic_types(*pm)); g.generate_module(*pm).set_attributes({"__device__"}).set_generic_types(*pm));
std::string lambda = "MIGRAPHX_LIFT(" + name + ")"; std::string lambda = "MIGRAPHX_LIFT(" + name + ")";
auto op_names = get_op_names(*pm);
op_names.push_back("kernel");
auto op_name_string = join_strings(op_names, "_");
return replace( return replace(
compile_op(ctx, to_shapes(ins->inputs()), {{"lambda", lambda}, {"preamble", g.str()}})); compile_op(ctx,
to_shapes(ins->inputs()),
{{"lambda", lambda}, {"preamble", g.str()}, {"kernel", op_name_string}}));
} }
}; };
} // namespace gpu } // namespace gpu
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment