"vscode:/vscode.git/clone" did not exist on "1371e66ed990b038492cd8a09fd6b00d709e0fce"
Unverified Commit b8deb54c authored by Paul Fultz II's avatar Paul Fultz II Committed by GitHub
Browse files

Merge branch 'develop' into fuse-horiz-contiguous

parents fee84355 ca8a54fe
...@@ -56,14 +56,21 @@ struct nonmaxsuppression ...@@ -56,14 +56,21 @@ struct nonmaxsuppression
shape compute_shape(std::vector<shape> inputs) const shape compute_shape(std::vector<shape> inputs) const
{ {
// requires at least 2 inputs // requires at least 2 inputs
check_shapes{inputs, *this}.standard();
check_shapes{{inputs.at(0), inputs.at(1)}, *this}.only_dims(3); check_shapes{{inputs.at(0), inputs.at(1)}, *this}.only_dims(3);
auto lens = inputs.front().lens(); auto lens = inputs.front().lens();
// check input shape // check input shape
if(lens[1] != inputs.at(1).lens()[2]) if(lens[1] != inputs.at(1).lens()[2])
{ {
MIGRAPHX_THROW("NonMaxSuppression: dimension mismatch between first and second input!"); MIGRAPHX_THROW(
"NonMaxSuppression: spatial dimension mismatch between boxes and scores input");
}
// check batch sizes
if(lens[0] != inputs.at(1).lens()[0])
{
MIGRAPHX_THROW(
"NonMaxSuppression: number of batches mismatch between boxes and scores input");
} }
std::vector<int64_t> out_lens(2); std::vector<int64_t> out_lens(2);
...@@ -74,8 +81,8 @@ struct nonmaxsuppression ...@@ -74,8 +81,8 @@ struct nonmaxsuppression
struct box struct box
{ {
std::array<float, 2> x; std::array<double, 2> x;
std::array<float, 2> y; std::array<double, 2> y;
void sort() void sort()
{ {
...@@ -83,9 +90,9 @@ struct nonmaxsuppression ...@@ -83,9 +90,9 @@ struct nonmaxsuppression
std::sort(y.begin(), y.end()); std::sort(y.begin(), y.end());
} }
std::array<float, 2>& operator[](std::size_t i) { return i == 0 ? x : y; } std::array<double, 2>& operator[](std::size_t i) { return i == 0 ? x : y; }
float area() const double area() const
{ {
assert(std::is_sorted(x.begin(), x.end())); assert(std::is_sorted(x.begin(), x.end()));
assert(std::is_sorted(y.begin(), y.end())); assert(std::is_sorted(y.begin(), y.end()));
...@@ -94,29 +101,29 @@ struct nonmaxsuppression ...@@ -94,29 +101,29 @@ struct nonmaxsuppression
}; };
template <class T> template <class T>
box batch_box(const T* boxes, std::size_t bidx) const box batch_box(T boxes, std::size_t box_idx) const
{ {
box result{}; box result{};
const T* start = boxes + 4 * bidx; auto start = boxes + 4 * box_idx;
if(center_point_box) if(center_point_box)
{ {
float half_width = start[2] / 2.0f; double half_width = start[2] / 2.0;
float half_height = start[3] / 2.0f; double half_height = start[3] / 2.0;
float x_center = start[0]; double x_center = start[0];
float y_center = start[1]; double y_center = start[1];
result.x = {x_center - half_width, x_center + half_width}; result.x = {x_center - half_width, x_center + half_width};
result.y = {y_center - half_height, y_center + half_height}; result.y = {y_center - half_height, y_center + half_height};
} }
else else
{ {
result.x = {start[1], start[3]}; result.x = {static_cast<double>(start[1]), static_cast<double>(start[3])};
result.y = {start[0], start[2]}; result.y = {static_cast<double>(start[0]), static_cast<double>(start[2])};
} }
return result; return result;
} }
inline bool suppress_by_iou(box b1, box b2, float iou_threshold) const inline bool suppress_by_iou(box b1, box b2, double iou_threshold) const
{ {
b1.sort(); b1.sort();
b2.sort(); b2.sort();
...@@ -128,7 +135,7 @@ struct nonmaxsuppression ...@@ -128,7 +135,7 @@ struct nonmaxsuppression
intersection[i][1] = std::min(b1[i][1], b2[i][1]); intersection[i][1] = std::min(b1[i][1], b2[i][1]);
} }
std::vector<std::array<float, 2>> bbox = {intersection.x, intersection.y}; std::vector<std::array<double, 2>> bbox = {intersection.x, intersection.y};
if(std::any_of(bbox.begin(), bbox.end(), [](auto bx) { if(std::any_of(bbox.begin(), bbox.end(), [](auto bx) {
return not std::is_sorted(bx.begin(), bx.end()); return not std::is_sorted(bx.begin(), bx.end());
})) }))
...@@ -136,115 +143,124 @@ struct nonmaxsuppression ...@@ -136,115 +143,124 @@ struct nonmaxsuppression
return false; return false;
} }
const float area1 = b1.area(); const double area1 = b1.area();
const float area2 = b2.area(); const double area2 = b2.area();
const float intersection_area = intersection.area(); const double intersection_area = intersection.area();
const float union_area = area1 + area2 - intersection_area; const double union_area = area1 + area2 - intersection_area;
if(area1 <= .0f or area2 <= .0f or union_area <= .0f) if(area1 <= .0f or area2 <= .0f or union_area <= .0f)
{ {
return false; return false;
} }
const float intersection_over_union = intersection_area / union_area; const double intersection_over_union = intersection_area / union_area;
return intersection_over_union > iou_threshold; return intersection_over_union > iou_threshold;
} }
argument compute(const shape& output_shape, std::vector<argument> args) const // filter boxes below score_threshold
template <class T>
std::priority_queue<std::pair<double, int64_t>>
filter_boxes_by_score(T scores_start, std::size_t num_boxes, double score_threshold) const
{ {
argument result{output_shape}; std::priority_queue<std::pair<double, int64_t>> boxes_heap;
auto insert_to_boxes_heap =
result.visit([&](auto out) { std::fill(out.begin(), out.end(), 0); }); make_function_output_iterator([&](const auto& x) { boxes_heap.push(x); });
int64_t box_idx = 0;
std::size_t max_output_boxes_per_class = 0; transform_if(
float iou_threshold = 0.0f; scores_start,
float score_threshold = 0.0f; scores_start + num_boxes,
insert_to_boxes_heap,
if(args.size() > 2) [&](auto sc) {
{ box_idx++;
max_output_boxes_per_class = args.at(2).at<std::size_t>(); return sc >= score_threshold;
} },
// max_output_boxes_per_class is 0, no output [&](auto sc) { return std::make_pair(sc, box_idx - 1); });
if(max_output_boxes_per_class == 0) return boxes_heap;
{ }
return result;
}
if(args.size() > 3)
{
iou_threshold = args.at(3).at<float>();
}
if(args.size() > 4)
{
score_threshold = args.at(4).at<float>();
}
const auto& lens = args.at(1).get_shape().lens();
auto batch_num = lens[0];
auto class_num = lens[1];
auto box_num = args.at(0).get_shape().lens()[1];
std::vector<std::pair<float, int64_t>> selected_boxes_inside_class; template <class Output, class Boxes, class Scores>
void compute_nms(Output output,
Boxes boxes,
Scores scores,
const shape& output_shape,
std::size_t max_output_boxes_per_class,
double iou_threshold,
double score_threshold) const
{
std::fill(output.begin(), output.end(), 0);
const auto& lens = scores.get_shape().lens();
const auto num_batches = lens[0];
const auto num_classes = lens[1];
const auto num_boxes = lens[2];
// boxes of a class with NMS applied [score, index]
std::vector<std::pair<double, int64_t>> selected_boxes_inside_class;
std::vector<int64_t> selected_indices; std::vector<int64_t> selected_indices;
selected_boxes_inside_class.reserve(output_shape.elements()); selected_boxes_inside_class.reserve(output_shape.elements());
// iterate over batches and classes
auto scores = make_view<float>(args.at(1).get_shape(), args.at(1).cast<float>()); shape comp_s{shape::double_type, {num_batches, num_classes}};
const float* boxes = args.at(0).cast<float>();
shape comp_s{shape::float_type, {batch_num, class_num}};
shape_for_each(comp_s, [&](auto idx) { shape_for_each(comp_s, [&](auto idx) {
auto bidx = idx[0]; auto batch_idx = idx[0];
auto cidx = idx[1]; auto class_idx = idx[1];
// index offset for this class
std::size_t score_offset = (bidx * class_num + cidx) * box_num; auto scores_start = scores.begin() + (batch_idx * num_classes + class_idx) * num_boxes;
const float* batch_boxes = boxes + bidx * box_num * 4; // iterator to first value of this batch
std::priority_queue<std::pair<float, int64_t>> sorted_boxes; auto batch_boxes_start = boxes.begin() + batch_idx * num_boxes * 4;
auto insert_to_sorted_boxes = auto boxes_heap = filter_boxes_by_score(scores_start, num_boxes, score_threshold);
make_function_output_iterator([&](const auto& x) { sorted_boxes.push(x); });
int64_t box_idx = 0;
transform_if(
scores.begin() + score_offset,
scores.begin() + score_offset + box_num,
insert_to_sorted_boxes,
[&](auto sc) {
box_idx++;
return sc >= score_threshold;
},
[&](auto sc) { return std::make_pair(sc, box_idx - 1); });
selected_boxes_inside_class.clear(); selected_boxes_inside_class.clear();
// Get the next box with top score, filter by iou_threshold // Get the next box with top score, filter by iou_threshold
while(!sorted_boxes.empty() && while(!boxes_heap.empty() &&
selected_boxes_inside_class.size() < max_output_boxes_per_class) selected_boxes_inside_class.size() < max_output_boxes_per_class)
{ {
const std::pair<float, int64_t>& next_top_score = sorted_boxes.top(); // Check with existing selected boxes for this class, remove box if it
// exceeds the IOU (Intersection Over Union) threshold
// Check with existing selected boxes for this class, suppress if exceed the IOU const auto next_top_score = boxes_heap.top();
// (Intersection Over Union) threshold bool not_selected =
bool not_selected = std::any_of( std::any_of(selected_boxes_inside_class.begin(),
selected_boxes_inside_class.begin(), selected_boxes_inside_class.end(),
selected_boxes_inside_class.end(), [&](auto selected_index) {
[&](auto selected_index) { return this->suppress_by_iou(
return this->suppress_by_iou(batch_box(batch_boxes, next_top_score.second), batch_box(batch_boxes_start, next_top_score.second),
batch_box(batch_boxes, selected_index.second), batch_box(batch_boxes_start, selected_index.second),
iou_threshold); iou_threshold);
}); });
if(not not_selected) if(not not_selected)
{ {
selected_boxes_inside_class.push_back(next_top_score); selected_boxes_inside_class.push_back(next_top_score);
selected_indices.push_back(bidx); selected_indices.push_back(batch_idx);
selected_indices.push_back(cidx); selected_indices.push_back(class_idx);
selected_indices.push_back(next_top_score.second); selected_indices.push_back(next_top_score.second);
} }
sorted_boxes.pop(); boxes_heap.pop();
} }
}); });
std::copy(selected_indices.begin(), selected_indices.end(), output.begin());
}
argument compute(const shape& output_shape, std::vector<argument> args) const
{
argument result{output_shape};
result.visit([&](auto out) { std::size_t max_output_boxes_per_class =
std::copy(selected_indices.begin(), selected_indices.end(), out.begin()); (args.size() > 2) ? (args.at(2).at<std::size_t>()) : 0;
if(max_output_boxes_per_class == 0)
{
return result;
}
double iou_threshold = (args.size() > 3) ? (args.at(3).at<double>()) : 0.0f;
double score_threshold = (args.size() > 4) ? (args.at(4).at<double>()) : 0.0f;
result.visit([&](auto output) {
visit_all(args[0], args[1])([&](auto boxes, auto scores) {
compute_nms(output,
boxes,
scores,
output_shape,
max_output_boxes_per_class,
iou_threshold,
score_threshold);
});
}); });
return result; return result;
......
...@@ -38,6 +38,7 @@ struct module_pass_manager ...@@ -38,6 +38,7 @@ struct module_pass_manager
module_pass_manager(const module_pass_manager&) = delete; module_pass_manager(const module_pass_manager&) = delete;
virtual module& get_module() = 0; virtual module& get_module() = 0;
virtual module* create_module(const std::string& name) = 0; virtual module* create_module(const std::string& name) = 0;
virtual module* get_common_parent() = 0;
virtual void run_pass(const pass& p) = 0; virtual void run_pass(const pass& p) = 0;
protected: protected:
......
...@@ -132,6 +132,8 @@ struct program ...@@ -132,6 +132,8 @@ struct program
std::vector<const module*> get_modules() const; std::vector<const module*> get_modules() const;
std::vector<module*> get_modules(); std::vector<module*> get_modules();
std::unordered_multimap<module_ref, module_ref> get_module_tree();
void remove_module(const std::string& name); void remove_module(const std::string& name);
void remove_unused_modules(); void remove_unused_modules();
......
...@@ -198,6 +198,12 @@ void transform(Range&& r, Iterator it, F f) ...@@ -198,6 +198,12 @@ void transform(Range&& r, Iterator it, F f)
std::transform(r.begin(), r.end(), it, f); std::transform(r.begin(), r.end(), it, f);
} }
/// Binary range transform: applies `f` to corresponding elements of `r1` and
/// `r2` and writes each result through `it`.
///
/// `r1` determines how many elements are processed; `r2` is only asked for its
/// begin iterator, so it must provide at least as many elements as `r1`.
template <class Range1, class Range2, class Iterator, class F>
void transform(Range1&& r1, Range2&& r2, Iterator it, F f)
{
    auto first1 = r1.begin();
    auto last1  = r1.end();
    auto first2 = r2.begin();
    std::transform(first1, last1, first2, it, f);
}
template <class Range> template <class Range>
auto reverse(Range& r) auto reverse(Range& r)
{ {
...@@ -210,10 +216,16 @@ void replace(Range&& r, const T& old, const T& new_x) ...@@ -210,10 +216,16 @@ void replace(Range&& r, const T& old, const T& new_x)
std::replace(r.begin(), r.end(), old, new_x); std::replace(r.begin(), r.end(), old, new_x);
} }
template <class R1, class R2> template <class R1, class R2, class... Predicate>
bool equal(R1&& r1, R2&& r2) bool equal(R1&& r1, R2&& r2, Predicate... pred)
{
return std::equal(r1.begin(), r1.end(), r2.begin(), r2.end(), pred...);
}
template <class Range>
auto distance(Range&& r)
{ {
return std::equal(r1.begin(), r1.end(), r2.begin(), r2.end()); return std::distance(r.begin(), r.end());
} }
template <class R> template <class R>
......
...@@ -191,6 +191,10 @@ struct shape ...@@ -191,6 +191,10 @@ struct shape
std::size_t size(std::size_t n = 1) const { return sizeof(type) * n; } std::size_t size(std::size_t n = 1) const { return sizeof(type) * n; }
auto is_integral() const { return std::is_integral<type>{}; }
auto is_signed() const { return std::is_signed<type>{}; }
auto is_unsigned() const { return std::is_unsigned<type>{}; }
template <class U> template <class U>
type* from(U* buffer, std::size_t n = 0) const type* from(U* buffer, std::size_t n = 0) const
{ {
......
...@@ -44,8 +44,8 @@ auto with_char(F f) ...@@ -44,8 +44,8 @@ auto with_char(F f)
return [=](unsigned char c) -> bool { return f(c); }; return [=](unsigned char c) -> bool { return f(c); };
} }
inline std::string inline void
replace_string(std::string subject, const std::string& search, const std::string& replace) replace_string_inplace(std::string& subject, const std::string& search, const std::string& replace)
{ {
size_t pos = 0; size_t pos = 0;
while((pos = subject.find(search, pos)) != std::string::npos) while((pos = subject.find(search, pos)) != std::string::npos)
...@@ -53,6 +53,12 @@ replace_string(std::string subject, const std::string& search, const std::string ...@@ -53,6 +53,12 @@ replace_string(std::string subject, const std::string& search, const std::string
subject.replace(pos, search.length(), replace); subject.replace(pos, search.length(), replace);
pos += replace.length(); pos += replace.length();
} }
}
inline std::string
replace_string(std::string subject, const std::string& search, const std::string& replace)
{
replace_string_inplace(subject, search, replace);
return subject; return subject;
} }
......
...@@ -35,7 +35,7 @@ static void inline_submodule(module& m, instruction_ref ins, bool cond) ...@@ -35,7 +35,7 @@ static void inline_submodule(module& m, instruction_ref ins, bool cond)
{ {
const auto& mod_inputs = ins->module_inputs(); const auto& mod_inputs = ins->module_inputs();
module_ref smod = cond ? mod_inputs.at(0) : mod_inputs.at(1); module_ref smod = cond ? mod_inputs.at(0) : mod_inputs.at(1);
auto mod_outputs = m.insert_module_instructions(ins, smod); auto mod_outputs = m.insert_instructions(ins, smod);
auto ins_outputs = ins->outputs(); auto ins_outputs = ins->outputs();
assert(mod_outputs.size() >= ins_outputs.size()); assert(mod_outputs.size() >= ins_outputs.size());
......
...@@ -35,6 +35,7 @@ ...@@ -35,6 +35,7 @@
#include <migraphx/make_op.hpp> #include <migraphx/make_op.hpp>
#include <migraphx/register_target.hpp> #include <migraphx/register_target.hpp>
#include <migraphx/make_op.hpp> #include <migraphx/make_op.hpp>
#include <migraphx/json.hpp>
#include <iostream> #include <iostream>
#include <sstream> #include <sstream>
#include <algorithm> #include <algorithm>
...@@ -196,6 +197,62 @@ void module::assign(const module& m) ...@@ -196,6 +197,62 @@ void module::assign(const module& m)
} }
} }
/// Copies every instruction in `instructions` into module `m`.
///
/// - `m`            module receiving the copies.
/// - `ins`          position in `m`; non-builtin instructions are inserted
///                  before it. Literals, parameters and outlines are created
///                  through `add_literal` / `add_parameter` / `add_outline`.
/// - `instructions` range of instructions to copy, visited in order.
/// - `map_ins`      source -> copy mapping. Instructions already present as
///                  keys are not copied again, and inputs found in the map are
///                  rewritten to their mapped instruction. Taken by value, so
///                  the caller's map is not updated.
///
/// Returns the remapped inputs of the first `@return` encountered (copying
/// stops there). Otherwise returns the mapping of the last visited
/// instruction, or an empty vector when `instructions` is empty.
template <class Range>
static std::vector<instruction_ref>
insert_generic_instructions(module& m,
                            instruction_ref ins,
                            Range&& instructions,
                            std::unordered_map<instruction_ref, instruction_ref> map_ins)
{
    assert(m.has_instruction(ins) or is_end(ins, m.end()));

    // Resolve an input to its recorded copy; inputs without a copy are kept as-is.
    auto remap = [&](instruction_ref x) {
        auto found = map_ins.find(x);
        return found == map_ins.end() ? x : found->second;
    };

    std::vector<instruction_ref> outputs;
    instruction_ref last_visited;
    for(instruction_ref src : instructions)
    {
        last_visited = src;
        // Already mapped (supplied by the caller or copied earlier): skip it.
        if(contains(map_ins, src))
            continue;

        const auto kind = src->name();
        instruction_ref copied;
        if(kind == "@literal")
        {
            auto lit = src->get_literal();
            copied   = m.add_literal(lit);
        }
        else if(kind == "@param")
        {
            auto&& param_name = any_cast<builtin::param>(src->get_operator()).parameter;
            auto param_shape  = src->get_shape();
            copied            = m.add_parameter(param_name, param_shape);
        }
        else if(kind == "@outline")
        {
            auto outline_shape = src->get_shape();
            copied             = m.add_outline(outline_shape);
        }
        else
        {
            auto mod_args      = src->module_inputs();
            auto source_inputs = src->inputs();
            std::vector<instruction_ref> new_inputs;
            new_inputs.reserve(source_inputs.size());
            for(auto input : source_inputs)
                new_inputs.push_back(remap(input));

            // A return is not copied: its remapped inputs are the result.
            if(kind == "@return")
            {
                outputs = new_inputs;
                break;
            }

            copied = m.insert_instruction(ins, src->get_operator(), new_inputs, mod_args);
        }
        map_ins[src] = copied;
    }

    // No return was seen (or it had no inputs): fall back to the last instruction.
    if(outputs.empty() and instructions.begin() != instructions.end())
        outputs = {map_ins.at(last_visited)};
    return outputs;
}
instruction_ref module::add_instruction(const operation& op, std::vector<instruction_ref> args) instruction_ref module::add_instruction(const operation& op, std::vector<instruction_ref> args)
{ {
return insert_instruction(impl->instructions.end(), op, std::move(args)); return insert_instruction(impl->instructions.end(), op, std::move(args));
...@@ -334,53 +391,52 @@ instruction_ref module::move_instructions(instruction_ref src, instruction_ref d ...@@ -334,53 +391,52 @@ instruction_ref module::move_instructions(instruction_ref src, instruction_ref d
return src; return src;
} }
std::vector<instruction_ref> module::insert_module_instructions( std::vector<instruction_ref>
instruction_ref ins, module_ref m, std::unordered_map<instruction_ref, instruction_ref> map_ins) module::add_instructions(const std::vector<instruction_ref>& instructions,
std::unordered_map<instruction_ref, instruction_ref> map_ins)
{ {
std::vector<instruction_ref> mod_outputs; return this->insert_instructions(this->end(), instructions, std::move(map_ins));
for(auto sins : iterator_for(*m)) }
{
if(contains(map_ins, sins))
continue;
instruction_ref copy_ins;
if(sins->name() == "@literal")
{
auto l = sins->get_literal();
copy_ins = this->add_literal(l);
}
else if(sins->name() == "@param")
{
auto&& name = any_cast<builtin::param>(sins->get_operator()).parameter;
auto s = sins->get_shape();
copy_ins = this->add_parameter(name, s);
}
else if(sins->name() == "@outline")
{
auto s = sins->get_shape();
copy_ins = this->add_outline(s);
}
else
{
auto mod_args = sins->module_inputs();
auto inputs = sins->inputs();
std::vector<instruction_ref> copy_inputs(inputs.size());
std::transform(inputs.begin(), inputs.end(), copy_inputs.begin(), [&](auto i) {
return contains(map_ins, i) ? map_ins[i] : i;
});
if(sins->name() == "@return") std::vector<instruction_ref>
{ module::add_instructions(const_module_ref m,
mod_outputs = copy_inputs; std::unordered_map<instruction_ref, instruction_ref> map_ins)
break; {
} return this->insert_instructions(this->end(), m, std::move(map_ins));
}
copy_ins = this->insert_instruction(ins, sins->get_operator(), copy_inputs, mod_args); std::vector<instruction_ref>
} module::add_instructions(instruction_ref start,
map_ins[sins] = copy_ins; instruction_ref last,
} std::unordered_map<instruction_ref, instruction_ref> map_ins)
if(mod_outputs.empty()) {
mod_outputs = {map_ins.at(std::prev(m->end()))}; return this->insert_instructions(this->end(), start, last, std::move(map_ins));
return mod_outputs; }
/// Copies the listed `instructions` into this module, using `ins` as the
/// insertion position, by forwarding to insert_generic_instructions.
/// `map_ins` seeds the source -> copy mapping and is moved into the helper.
/// Returns whatever insert_generic_instructions reports as the outputs.
std::vector<instruction_ref>
module::insert_instructions(instruction_ref ins,
                            const std::vector<instruction_ref>& instructions,
                            std::unordered_map<instruction_ref, instruction_ref> map_ins)
{
    module& self = *this;
    return insert_generic_instructions(self, ins, instructions, std::move(map_ins));
}
/// Copies every instruction of module `m` into this module, using `ins` as the
/// insertion position, by forwarding to insert_generic_instructions.
/// `map_ins` seeds the source -> copy mapping and is moved into the helper.
/// Returns whatever insert_generic_instructions reports as the outputs.
std::vector<instruction_ref>
module::insert_instructions(instruction_ref ins,
                            const_module_ref m,
                            std::unordered_map<instruction_ref, instruction_ref> map_ins)
{
    auto source_instructions = iterator_for(*m);
    return insert_generic_instructions(*this, ins, source_instructions, std::move(map_ins));
}
std::vector<instruction_ref>
module::insert_instructions(instruction_ref ins,
instruction_ref start,
instruction_ref last,
std::unordered_map<instruction_ref, instruction_ref> map_ins)
{
auto r = range(start, last);
return insert_generic_instructions(*this, ins, iterator_for(r), std::move(map_ins));
} }
instruction_ref module::add_literal(literal l) instruction_ref module::add_literal(literal l)
...@@ -706,44 +762,33 @@ void module::print_graph(std::ostream& os, bool brief) const ...@@ -706,44 +762,33 @@ void module::print_graph(std::ostream& os, bool brief) const
os << "}" << std::endl; os << "}" << std::endl;
} }
static std::string cpp_var_name(const std::string& name) static std::string to_c_id(const std::string& name, char rep = '_')
{ {
return "m" + replace_string(name, "@", "x"); std::string id = transform_string(name, [&](auto c) {
if(with_char(::isalnum)(c) or c == '_')
return c;
return rep;
});
while(contains(id, "__"))
replace_string_inplace(id, "__", "_");
return id;
} }
static std::string cpp_op_var(const std::string& name, instruction_ref ins) static std::string cpp_var_name(const std::string& name)
{ {
return replace_string(name, "@", ins->name()); return to_c_id("x_" + replace_string(name, ":", "_module_"));
} }
static void print_op_attributes(std::ostream& os, const std::string& name, const operation& op) static void print_make_op(std::ostream& os, const operation& op)
{ {
std::string x = to_string(op); os << "migraphx::make_op(" << enclose_name(op.name());
if(contains(x, "[")) auto v = op.to_value();
if(not v.empty())
{ {
auto start = x.find('['); os << ", "
auto end = x.find(']'); << "migraphx::from_json_string(" << enclose_name(to_json_string(v)) << ")";
std::string attribute_text = x.substr(start + 1, end - start - 1);
std::vector<std::string> attributes;
for(auto&& attribute : split_string(attribute_text, ','))
{
if(contains(attribute, '='))
attributes.push_back(attribute);
else
attributes.back() += "," + attribute;
}
for(auto&& attribute : attributes)
{
auto p = split_string(attribute, '=');
auto key = p.front();
auto value = p.back();
if(contains({"bn_mode", "padding_mode"}, key))
continue;
if(key == "mode")
value = enclose_name(trim(value));
os << name << "." << key << " = " << value << ";" << std::endl;
}
} }
os << ")";
} }
static void print_cpp_shape(std::ostream& os, const migraphx::shape& s) static void print_cpp_shape(std::ostream& os, const migraphx::shape& s)
...@@ -756,22 +801,25 @@ static void print_cpp_shape(std::ostream& os, const migraphx::shape& s) ...@@ -756,22 +801,25 @@ static void print_cpp_shape(std::ostream& os, const migraphx::shape& s)
} }
std::unordered_map<instruction_ref, std::string> std::unordered_map<instruction_ref, std::string>
module::print_cpp(std::ostream& os, std::unordered_map<instruction_ref, std::string> names) const module::print_cpp(std::ostream& os,
const std::string& mname,
std::unordered_map<instruction_ref, std::string> names) const
{ {
os << "migraphx::module p;" << std::endl; // cppcheck-suppress variableScope
unsigned long seed = 0; unsigned long seed = names.size();
auto last = std::prev(this->end());
names = this->print( names = this->print(
[&](auto ins, auto ins_names) { [&](auto ins, auto ins_names) {
auto op = cpp_op_var(ins_names.at(ins), ins); std::vector<std::string> input_vars;
if(ins->name().front() != '@') std::transform(ins->inputs().begin(),
{ ins->inputs().end(),
os << "migraphx::op::" << ins->name() << " " << op << ";" << std::endl; std::back_inserter(input_vars),
print_op_attributes(os, op, ins->get_operator()); [&](auto input) { return cpp_var_name(ins_names.at(input)); });
} if(ins != last)
os << "auto " << cpp_var_name(ins_names.at(ins)) << " = "; os << "auto " << cpp_var_name(ins_names.at(ins)) << " = ";
if(ins->name() == "@literal") if(ins->name() == "@literal")
{ {
os << "p.add_literal("; os << mname << "->add_literal(";
bool use_abs = false; bool use_abs = false;
ins->get_literal().visit([&](auto v) { ins->get_literal().visit([&](auto v) {
use_abs = std::none_of(v.begin(), v.end(), [](auto x) { return x < 0; }); use_abs = std::none_of(v.begin(), v.end(), [](auto x) { return x < 0; });
...@@ -789,17 +837,22 @@ module::print_cpp(std::ostream& os, std::unordered_map<instruction_ref, std::str ...@@ -789,17 +837,22 @@ module::print_cpp(std::ostream& os, std::unordered_map<instruction_ref, std::str
else if(ins->name() == "@param") else if(ins->name() == "@param")
{ {
std::string name = any_cast<builtin::param>(ins->get_operator()).parameter; std::string name = any_cast<builtin::param>(ins->get_operator()).parameter;
os << "p.add_parameter(" << enclose_name(name) << ","; os << mname << "->add_parameter(" << enclose_name(name) << ",";
print_cpp_shape(os, ins->get_shape()); print_cpp_shape(os, ins->get_shape());
os << ");" << std::endl; os << ");" << std::endl;
} }
else if(ins->name() == "@return")
{
os << mname << "->add_return({";
os << join_strings(input_vars, ", ");
os << "});" << std::endl;
}
else else
{ {
os << "p.add_instruction(" << op; assert(ins->name().front() != '@');
for(auto input : ins->inputs()) os << mname << "->add_instruction(";
{ print_make_op(os, ins->get_operator());
os << ", " << cpp_var_name(ins_names.at(input)); os << ", " << join_strings(input_vars, ", ");
}
os << ");" << std::endl; os << ");" << std::endl;
} }
}, },
...@@ -808,7 +861,7 @@ module::print_cpp(std::ostream& os, std::unordered_map<instruction_ref, std::str ...@@ -808,7 +861,7 @@ module::print_cpp(std::ostream& os, std::unordered_map<instruction_ref, std::str
return names; return names;
} }
void module::print_cpp(std::ostream& os) const { this->print_cpp(os, {}); } void module::print_cpp(std::ostream& os) const { this->print_cpp(os, this->name(), {}); }
void module::annotate(std::ostream& os, std::function<void(instruction_ref)> a) const void module::annotate(std::ostream& os, std::function<void(instruction_ref)> a) const
{ {
...@@ -819,17 +872,20 @@ void module::annotate(std::ostream& os, std::function<void(instruction_ref)> a) ...@@ -819,17 +872,20 @@ void module::annotate(std::ostream& os, std::function<void(instruction_ref)> a)
}); });
} }
std::vector<module_ref> module::get_sub_modules() const std::vector<module_ref> module::get_sub_modules(bool shallow) const
{ {
std::vector<module_ref> vec_modules; std::vector<module_ref> vec_modules;
for(auto ins : iterator_for(*this)) for(auto ins : iterator_for(*this))
{ {
const auto& mod_args = ins->module_inputs(); const auto& mod_args = ins->module_inputs();
vec_modules.insert(vec_modules.end(), mod_args.begin(), mod_args.end()); vec_modules.insert(vec_modules.end(), mod_args.begin(), mod_args.end());
for(const auto& smod : mod_args) if(not shallow)
{ {
auto sub_mods = smod->get_sub_modules(); for(const auto& smod : mod_args)
vec_modules.insert(vec_modules.end(), sub_mods.begin(), sub_mods.end()); {
auto sub_mods = smod->get_sub_modules();
vec_modules.insert(vec_modules.end(), sub_mods.begin(), sub_mods.end());
}
} }
} }
......
...@@ -66,14 +66,12 @@ void run_pass(program& prog, const pass& p, tracer trace) ...@@ -66,14 +66,12 @@ void run_pass(program& prog, const pass& p, tracer trace)
struct module_pm : module_pass_manager struct module_pm : module_pass_manager
{ {
module* mod; module* mod = nullptr;
program* prog; tracer* t = nullptr;
tracer* t; module* common_parent = nullptr;
program* prog = nullptr;
module_pm(module* pmod = nullptr, program* pprog = nullptr, tracer* pt = nullptr) module_pm(module* pmod = nullptr, tracer* pt = nullptr) : mod(pmod), t(pt) {}
: mod(pmod), prog(pprog), t(pt)
{
}
template <class... Ts> template <class... Ts>
void trace(Ts&&... xs) const void trace(Ts&&... xs) const
...@@ -92,6 +90,7 @@ struct module_pm : module_pass_manager ...@@ -92,6 +90,7 @@ struct module_pm : module_pass_manager
assert(prog); assert(prog);
return prog->create_module(name); return prog->create_module(name);
} }
virtual module* get_common_parent() override { return common_parent; }
virtual void run_pass(const pass& p) override virtual void run_pass(const pass& p) override
{ {
assert(mod); assert(mod);
...@@ -111,7 +110,7 @@ void run_passes(module& mod, const std::vector<pass>& passes, tracer trace) ...@@ -111,7 +110,7 @@ void run_passes(module& mod, const std::vector<pass>& passes, tracer trace)
trace = tracer{std::cout}; trace = tracer{std::cout};
for(const auto& p : passes) for(const auto& p : passes)
{ {
module_pm{&mod, nullptr, &trace}.run_pass(p); module_pm{&mod, &trace}.run_pass(p);
} }
} }
...@@ -119,14 +118,31 @@ void run_passes(program& prog, const std::vector<pass>& passes, tracer trace) ...@@ -119,14 +118,31 @@ void run_passes(program& prog, const std::vector<pass>& passes, tracer trace)
{ {
if(enabled(MIGRAPHX_TRACE_PASSES{})) if(enabled(MIGRAPHX_TRACE_PASSES{}))
trace = tracer{std::cout}; trace = tracer{std::cout};
std::unordered_set<module_ref> visited;
for(const auto& p : passes) for(const auto& p : passes)
{ {
auto mods = prog.get_modules(); auto mods = prog.get_modules();
auto tree = prog.get_module_tree();
visited.clear();
for(const auto& mod : reverse(mods)) for(const auto& mod : reverse(mods))
{ {
if(mod->bypass()) if(mod->bypass())
continue; continue;
module_pm{mod, &prog, &trace}.run_pass(p); if(not visited.insert(mod).second)
continue;
module_pm mpm{mod, &trace};
mpm.prog = &prog;
auto parents = range(tree.equal_range(mod));
auto nparents = distance(parents);
if(nparents == 0)
mpm.common_parent = nullptr;
else if(nparents == 1)
mpm.common_parent = parents.begin()->second;
else
// Just set common parent to main module when there are multiple parents for now
// TODO: Compute the common parent
mpm.common_parent = prog.get_main_module();
mpm.run_pass(p);
} }
run_pass(prog, p, trace); run_pass(prog, p, trace);
} }
......
...@@ -790,10 +790,17 @@ void program::print_cpp(std::ostream& os) const ...@@ -790,10 +790,17 @@ void program::print_cpp(std::ostream& os) const
{ {
auto vec_modules = this->get_modules(); auto vec_modules = this->get_modules();
std::unordered_map<instruction_ref, std::string> names; std::unordered_map<instruction_ref, std::string> names;
os << "migraphx::program p;\n";
for(auto& mod : vec_modules) for(auto& mod : vec_modules)
{ {
os << "module: \"" << mod->name() << "\"" << std::endl; std::string var_name = "m" + mod->name();
names = mod->print_cpp(os, names); os << "migraphx::module_ref " << var_name << " = ";
if(mod->name() == "main")
os << "p.get_main_module();";
else
os << "p.create_module(\"" << mod->name() << "\");";
os << std::endl;
names = mod->print_cpp(os, var_name, names);
os << std::endl; os << std::endl;
} }
} }
...@@ -869,6 +876,23 @@ std::vector<module*> program::get_modules() ...@@ -869,6 +876,23 @@ std::vector<module*> program::get_modules()
return result; return result;
} }
/// Records the child -> parent edges of the module graph rooted at `pm`.
///
/// - `pm` module whose direct submodules (`get_sub_modules(true)`) are visited.
/// - `m`  multimap keyed by submodule; each mapped value is one of its parents.
///
/// Each distinct (submodule, parent) edge is inserted exactly once. A submodule
/// reachable along several paths, or referenced more than once by the same
/// parent, would otherwise produce duplicate identical entries, and callers
/// that count entries per key would over-count its parents. Skipping an edge
/// that is already recorded also avoids re-walking a subtree that has already
/// been traversed.
template <class Module, class Map>
void generic_insert_module_tree(Module* pm, Map& m)
{
    for(auto* sm : pm->get_sub_modules(true))
    {
        // Has this exact edge been recorded already?
        auto known_parents = m.equal_range(sm);
        bool already_seen  = false;
        for(auto it = known_parents.first; it != known_parents.second; ++it)
        {
            if(it->second == pm)
            {
                already_seen = true;
                break;
            }
        }
        if(already_seen)
            continue;

        m.insert(std::make_pair(sm, pm));
        generic_insert_module_tree(sm, m);
    }
}
/// Builds the submodule -> parent relation for this program by walking the
/// module graph from the main module with generic_insert_module_tree.
/// The main module itself never appears as a key.
std::unordered_multimap<module_ref, module_ref> program::get_module_tree()
{
    std::unordered_multimap<module_ref, module_ref> tree;
    auto* root = this->get_main_module();
    generic_insert_module_tree(root, tree);
    return tree;
}
template <class Map, class T> template <class Map, class T>
bool is_unused_module(Map& m, const std::vector<T*>& mods, const std::string& name) bool is_unused_module(Map& m, const std::vector<T*>& mods, const std::string& name)
{ {
......
...@@ -61,9 +61,7 @@ struct shape_impl ...@@ -61,9 +61,7 @@ struct shape_impl
{ {
assert(t != shape::tuple_type); assert(t != shape::tuple_type);
assert(m_lens.size() == m_strides.size()); assert(m_lens.size() == m_strides.size());
// assert(std::any_of(m_strides.begin(), m_strides.end(), [](auto x) { return x > 0; }) and m_standard = this->elements() == this->element_space() and not skips() and
// "At least one stride must be non-zero");
m_standard = this->elements() == this->element_space() and
std::is_sorted(m_strides.rbegin(), m_strides.rend()); std::is_sorted(m_strides.rbegin(), m_strides.rend());
} }
...@@ -110,6 +108,15 @@ struct shape_impl ...@@ -110,6 +108,15 @@ struct shape_impl
m_lens.begin(), m_lens.end(), std::size_t{1}, std::multiplies<std::size_t>()); m_lens.begin(), m_lens.end(), std::size_t{1}, std::multiplies<std::size_t>());
} }
// Does the shape skip over elements?
// A shape holding exactly one element never skips; any other shape skips
// when none of its strides is equal to 1.
bool skips() const
{
    assert(m_lens.size() == m_strides.size());
    const bool single_element = (elements() == 1);
    if(single_element)
        return false;
    const bool has_unit_stride = std::any_of(
        m_strides.begin(), m_strides.end(), [](auto stride) { return stride == 1; });
    return not has_unit_stride;
}
std::shared_ptr<shape_impl> copy() const { return std::make_shared<shape_impl>(*this); } std::shared_ptr<shape_impl> copy() const { return std::make_shared<shape_impl>(*this); }
}; };
...@@ -260,7 +267,8 @@ void shape::multi_copy(std::size_t i, std::size_t* start, const std::size_t* end ...@@ -260,7 +267,8 @@ void shape::multi_copy(std::size_t i, std::size_t* start, const std::size_t* end
bool shape::packed() const bool shape::packed() const
{ {
return this->sub_shapes().empty() and this->elements() == this->element_space(); return this->sub_shapes().empty() and not impl->skips() and
this->elements() == this->element_space();
} }
bool shape::transposed() const bool shape::transposed() const
...@@ -285,10 +293,8 @@ bool shape::transposed() const ...@@ -285,10 +293,8 @@ bool shape::transposed() const
bool shape::broadcasted() const bool shape::broadcasted() const
{ {
assert(this->lens().size() == this->strides().size()); assert(this->lens().size() == this->strides().size());
return std::accumulate(this->strides().begin(), return std::any_of(
this->strides().end(), this->strides().begin(), this->strides().end(), [](auto x) { return x == 0; });
std::size_t{1},
std::multiplies<std::size_t>()) == 0;
} }
bool shape::scalar() const bool shape::scalar() const
......
...@@ -164,6 +164,7 @@ add_library(migraphx_gpu ...@@ -164,6 +164,7 @@ add_library(migraphx_gpu
deconvolution.cpp deconvolution.cpp
device_name.cpp device_name.cpp
elu.cpp elu.cpp
fuse_mlir.cpp
fuse_ops.cpp fuse_ops.cpp
gather.cpp gather.cpp
gemm_impl.cpp gemm_impl.cpp
...@@ -176,7 +177,7 @@ add_library(migraphx_gpu ...@@ -176,7 +177,7 @@ add_library(migraphx_gpu
loop.cpp loop.cpp
lrn.cpp lrn.cpp
leaky_relu.cpp leaky_relu.cpp
mlir_conv.cpp mlir.cpp
multinomial.cpp multinomial.cpp
nonzero.cpp nonzero.cpp
pack_args.cpp pack_args.cpp
...@@ -320,16 +321,26 @@ message(STATUS "extractkernel: ${MIGRAPHX_EXTRACT_KERNEL}") ...@@ -320,16 +321,26 @@ message(STATUS "extractkernel: ${MIGRAPHX_EXTRACT_KERNEL}")
set(MIGRAPHX_ENABLE_MLIR OFF CACHE BOOL "") set(MIGRAPHX_ENABLE_MLIR OFF CACHE BOOL "")
if(MIGRAPHX_ENABLE_MLIR) if(MIGRAPHX_ENABLE_MLIR)
find_library(LIBMLIRMIOPEN MLIRMIOpenThin REQUIRED) find_library(MLIRAPI_LIBRARY MLIRMIOpen
PATH_SUFFIXES
# Workaround for broken mlir install
lib/ lib/lib)
# REQUIRED is not supported before cmake 3.18 # REQUIRED is not supported before cmake 3.18
if(NOT LIBMLIRMIOPEN) if(NOT MLIRAPI_LIBRARY)
message(FATAL_ERROR "libMLIRMIOpenThin not found") message(FATAL_ERROR "libMLIRMIOpen not found")
else() else()
message(STATUS "Build with libMLIRMIOpenThin: " ${LIBMLIRMIOPEN}) message(STATUS "Build with libMLIRMIOpen: " ${MLIRAPI_LIBRARY})
endif() endif()
target_compile_definitions(migraphx_gpu PRIVATE "-DMIGRAPHX_MLIR_MIOPEN_SUPPORT") find_path(MLIRAPI_HEADERS NAMES mlir-c/Dialect/MIGraphX.h)
target_link_libraries(migraphx_gpu PUBLIC ${LIBMLIRMIOPEN}) # Workaround MLIR broken installation
find_path(MLIRAPI_HEADERS2 NAMES mlir-c/Registration.h
PATH_SUFFIXES
include/external/include external/include)
target_compile_definitions(migraphx_gpu PRIVATE "-DMIGRAPHX_MLIR")
target_include_directories(migraphx_gpu SYSTEM PRIVATE ${MLIRAPI_HEADERS} ${MLIRAPI_HEADERS2})
target_link_libraries(migraphx_gpu PUBLIC ${MLIRAPI_LIBRARY})
endif() endif()
set(MIGRAPHX_USE_HIPRTC OFF CACHE BOOL "") set(MIGRAPHX_USE_HIPRTC OFF CACHE BOOL "")
......
...@@ -52,7 +52,7 @@ code_object_op::compute(context& ctx, const shape&, const std::vector<argument>& ...@@ -52,7 +52,7 @@ code_object_op::compute(context& ctx, const shape&, const std::vector<argument>&
std::transform( std::transform(
args.begin(), args.end(), kargs.begin(), [](const argument& a) { return a.data(); }); args.begin(), args.end(), kargs.begin(), [](const argument& a) { return a.data(); });
k.launch(ctx.get_stream().get(), global, local, std::move(kargs)); k.launch(ctx.get_stream().get(), global, local, std::move(kargs));
return args.back(); return args[get_output_arg(args.size())];
} }
void code_object_op::finalize(context&, const shape&, const std::vector<shape>&) void code_object_op::finalize(context&, const shape&, const std::vector<shape>&)
{ {
......
/*
* The MIT License (MIT)
*
* Copyright (c) 2015-2022 Advanced Micro Devices, Inc. All rights reserved.
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
* in the Software without restriction, including without limitation the rights
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
* copies of the Software, and to permit persons to whom the Software is
* furnished to do so, subject to the following conditions:
*
* The above copyright notice and this permission notice shall be included in
* all copies or substantial portions of the Software.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
* THE SOFTWARE.
*/
#include <migraphx/gpu/fuse_mlir.hpp>
#include <migraphx/gpu/mlir.hpp>
#include <migraphx/matcher.hpp>
#include <migraphx/pass_manager.hpp>
#include <migraphx/make_op.hpp>
#include <migraphx/register_op.hpp>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
struct module;
namespace gpu {
#ifdef MIGRAPHX_MLIR
// Operator named "gpu::mlir_conv". It wraps another operation (a
// "convolution" by default) and must be given exactly one submodule.
struct mlir_conv
{
    operation op = make_op("convolution");

    template <class Self, class F>
    static auto reflect(Self& self, F f)
    {
        return pack(f(self.op, "op"));
    }

    std::string name() const { return "gpu::mlir_conv"; }

    // The output shape is whatever the wrapped operation computes from the
    // last two inputs; every input must have a standard shape.
    shape compute_shape(std::vector<shape> inputs, const std::vector<module_ref>& mods) const
    {
        check_shapes{inputs, *this}.standard();
        const bool single_submodule = (mods.size() == 1);
        if(not single_submodule)
            MIGRAPHX_THROW("should have one submodule.");
        const bool enough_inputs = (inputs.size() >= 2);
        if(not enough_inputs)
            MIGRAPHX_THROW("should have at least two inputs.");
        const auto last = inputs.size() - 1;
        return op.compute_shape({inputs[last - 1], inputs[last]});
    }
};
MIGRAPHX_REGISTER_OP(mlir_conv);
namespace {
// Rewrites a "pointwise" instruction that consumes a convolution (optionally
// through "contiguous" instructions) into a single gpu::mlir_conv instruction
// whose body is a newly created "mlir_<pointwise module name>" submodule.
struct find_conv_pointwise
{
    // Find a convolution followed by a pointwise operation.
    // "convolution" is bound to the convolution instruction itself, while "x"
    // is bound to the instruction the pointwise op actually takes as input
    // (the convolution, or the contiguous sitting on top of it).
    auto matcher() const
    {
        auto convolution =
            match::skip(match::name("contiguous"))(match::name("convolution").bind("convolution"));
        return match::name("pointwise")(match::any_of[match::inputs()](convolution.bind("x")));
    }

    // Builds the fused submodule and replaces the matched pointwise instruction.
    // Does nothing when the pointwise module contains an operator outside the
    // whitelist or when any input is not fp32.
    void apply(module_pass_manager& mpm, const match::matcher_result& r) const
    {
        auto ins      = r.result;
        auto conv_ins = r.instructions["convolution"];
        auto x_ins    = r.instructions["x"]; // input after contiguous
        auto* pm      = ins->module_inputs().front();
        auto names    = pm->get_parameter_names();
        // Whitelist pointwise operators
        if(std::any_of(pm->begin(), pm->end(), [](const auto& i) {
               return not contains({"@literal", "@param", "@return", "convolution", "add", "relu"},
                                   i.name());
           }))
            return;
        // Only fuse with fp32 for now
        if(std::any_of(ins->inputs().begin(), ins->inputs().end(), [&](auto i) {
               return i->get_shape().type() != shape::type_t::float_type;
           }))
            return;
        // Sorted parameter names are paired positionally with ins->inputs() below
        std::sort(names.begin(), names.end());
        module_ref mm = mpm.create_module("mlir_" + pm->name());
        mm->set_bypass();
        std::unordered_map<instruction_ref, instruction_ref> param_map;
        // The convolution operands become two extra parameters, numbered after
        // the existing pointwise parameters.
        auto x    = mm->add_parameter("x" + std::to_string(names.size()),
                                   conv_ins->inputs().at(0)->get_shape());
        auto w    = mm->add_parameter("x" + std::to_string(names.size() + 1),
                                   conv_ins->inputs().at(1)->get_shape());
        auto conv = mm->add_instruction(conv_ins->get_operator(), {x, w});
        // Map each pointwise parameter either to the in-module convolution (for
        // the input that came from the convolution) or to a fresh parameter.
        std::transform(names.begin(),
                       names.end(),
                       ins->inputs().begin(),
                       std::inserter(param_map, param_map.end()),
                       [&](auto name, auto input) {
                           if(input == x_ins)
                               return std::make_pair(pm->get_parameter(name), conv);
                           return std::make_pair(pm->get_parameter(name),
                                                 mm->add_parameter(name, input->get_shape()));
                       });
        mm->add_return(mm->insert_instructions(mm->end(), pm, param_map));

        // Outer inputs: every pointwise input except the one replaced by the
        // in-module convolution, followed by the convolution's own operands.
        // The filter must use x_ins (what the pointwise op really consumes):
        // when a contiguous sits between the two, conv_ins never appears in
        // ins->inputs(), so comparing against it would leave the contiguous
        // behind as an input with no matching submodule parameter.
        std::vector<instruction_ref> inputs;
        std::copy_if(ins->inputs().begin(),
                     ins->inputs().end(),
                     std::back_inserter(inputs),
                     [&](auto input) { return input != x_ins; });
        inputs.insert(inputs.end(), conv_ins->inputs().begin(), conv_ins->inputs().end());
        mpm.get_module().replace_instruction(
            ins, mlir_conv{conv_ins->get_operator()}, inputs, {mm});
    }
};
} // namespace
#endif
// Runs the convolution + pointwise fusion when built with MIGRAPHX_MLIR;
// otherwise the pass leaves everything untouched.
void fuse_mlir::apply(module_pass_manager& mpm) const
{
#ifndef MIGRAPHX_MLIR
    // MLIR support compiled out: only silence the unused-parameter warning
    (void)mpm;
#else
    match::find_matches(mpm, find_conv_pointwise{});
#endif
}
} // namespace gpu
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
...@@ -48,6 +48,7 @@ ...@@ -48,6 +48,7 @@
#include <migraphx/instruction.hpp> #include <migraphx/instruction.hpp>
#include <migraphx/register_op.hpp> #include <migraphx/register_op.hpp>
#include <migraphx/array.hpp> #include <migraphx/array.hpp>
#include <migraphx/make_op.hpp>
#include <migraphx/op/clip.hpp> #include <migraphx/op/clip.hpp>
#include <cmath> #include <cmath>
#include <set> #include <set>
...@@ -335,6 +336,7 @@ void move_standard_front(std::vector<instruction_ref>& args) ...@@ -335,6 +336,7 @@ void move_standard_front(std::vector<instruction_ref>& args)
auto gpu_name(const std::string& s) { return match::name("gpu::" + s); } auto gpu_name(const std::string& s) { return match::name("gpu::" + s); }
namespace {
struct find_layernorm struct find_layernorm
{ {
auto matcher() const { return match::layernorm(&gpu_name); } auto matcher() const { return match::layernorm(&gpu_name); }
...@@ -700,6 +702,7 @@ struct miopen_fusion ...@@ -700,6 +702,7 @@ struct miopen_fusion
return args.back(); return args.back();
} }
}; };
MIGRAPHX_REGISTER_OP(miopen_fusion)
struct miopen_conv_bias struct miopen_conv_bias
{ {
...@@ -834,15 +837,6 @@ inline auto precompile_name(std::string s) // NOLINT ...@@ -834,15 +837,6 @@ inline auto precompile_name(std::string s) // NOLINT
}); });
} }
template <class... Ms>
auto conv_bias_pointwise(Ms... ms)
{
return precompile_name("pointwise")(
match::either_arg(0, 1)(bias_shape(match::used_once()).bind("bias"),
fusable_conv(match::used_once()).bind("conv")),
ms...);
}
struct find_conv_bias struct find_conv_bias
{ {
context* ctx = nullptr; context* ctx = nullptr;
...@@ -1011,10 +1005,45 @@ struct find_commutative_broadcast ...@@ -1011,10 +1005,45 @@ struct find_commutative_broadcast
m.replace_instruction(ins, ins->get_operator(), args); m.replace_instruction(ins, ins->get_operator(), args);
} }
}; };
} // namespace
// Replaces every "gpu::contiguous" instruction with a "gpu::precompile_op"
// that wraps a plain "contiguous" operator, keeping the same inputs.
struct find_contiguous
{
    auto matcher() const { return match::name("gpu::contiguous"); }

    void apply(module& m, const match::matcher_result& r) const
    {
        auto target = r.result;
        m.replace_instruction(
            target,
            make_op("gpu::precompile_op", {{"op", to_value(make_op("contiguous"))}}),
            target->inputs());
    }
};
// Folds a "gpu::contiguous" whose first argument is a precompiled "pointwise"
// instruction into that pointwise instruction.
struct find_contiguous_pointwise
{
    auto matcher() const
    {
        auto pointwise_arg = match::arg(0)(precompile_name("pointwise"));
        return match::name("gpu::contiguous")(pointwise_arg);
    }

    void apply(module& m, const match::matcher_result& r) const
    {
        auto contiguous_ins = r.result;
        auto pointwise_ins  = contiguous_ins->inputs().front();
        // Reuse the pointwise inputs, swapping its last one for the last input
        // of the contiguous (presumably the output allocation - confirm
        // against the gpu lowering convention).
        auto new_inputs   = pointwise_ins->inputs();
        new_inputs.back() = contiguous_ins->inputs().back();
        m.replace_instruction(contiguous_ins,
                              pointwise_ins->get_operator(),
                              new_inputs,
                              pointwise_ins->module_inputs());
    }
};
void fuse_ops::apply(module& m) const void fuse_ops::apply(module& m) const
{ {
match::find_matches(m, find_gelu{}, find_gelu_new{fast_math}); match::find_matches(m, find_contiguous_pointwise{}, find_gelu{}, find_gelu_new{fast_math});
run_passes(m, {dead_code_elimination{}}); run_passes(m, {dead_code_elimination{}});
match::find_matches(m, find_triadd{}); match::find_matches(m, find_triadd{});
match::find_matches(m, match::find_matches(m,
...@@ -1036,6 +1065,7 @@ void fuse_ops::apply(module& m) const ...@@ -1036,6 +1065,7 @@ void fuse_ops::apply(module& m) const
find_gemm_add{}, find_gemm_add{},
find_gemm_pointwise{}, find_gemm_pointwise{},
find_commutative_broadcast{}); find_commutative_broadcast{});
match::find_matches(m, find_contiguous{});
} }
} // namespace gpu } // namespace gpu
......
...@@ -38,12 +38,13 @@ struct context; ...@@ -38,12 +38,13 @@ struct context;
struct code_object_op struct code_object_op
{ {
value::binary code_object; value::binary code_object{};
std::string symbol_name; std::string symbol_name = "";
std::size_t global; std::size_t global = 0;
std::size_t local; std::size_t local = 0;
std::vector<shape> expected_inputs; std::vector<shape> expected_inputs{};
shape output; shape output{};
std::int64_t output_arg = -1;
kernel k{}; kernel k{};
template <class Self, class F> template <class Self, class F>
...@@ -66,9 +67,13 @@ struct code_object_op ...@@ -66,9 +67,13 @@ struct code_object_op
argument argument
compute(context& ctx, const shape& output_shape, const std::vector<argument>& args) const; compute(context& ctx, const shape& output_shape, const std::vector<argument>& args) const;
void finalize(context&, const shape&, const std::vector<shape>&); void finalize(context&, const shape&, const std::vector<shape>&);
// Resolves `output_arg` to an absolute argument index for a call taking `n`
// arguments. A negative `output_arg` counts back from the end (-1 selects
// index n - 1); a non-negative value is returned unchanged.
std::int64_t get_output_arg(std::size_t n) const
{
    if(output_arg >= 0)
        return output_arg;
    // Unsigned wrap-around: adding the converted negative offset subtracts
    // its magnitude from n.
    const std::size_t from_end = n + static_cast<std::size_t>(output_arg);
    return static_cast<std::int64_t>(from_end);
}
std::ptrdiff_t output_alias(const std::vector<shape>& shapes) const std::ptrdiff_t output_alias(const std::vector<shape>& shapes) const
{ {
return shapes.size() - 1; return get_output_arg(shapes.size());
} }
friend std::ostream& operator<<(std::ostream& os, const code_object_op& op) friend std::ostream& operator<<(std::ostream& os, const code_object_op& op)
......
...@@ -21,8 +21,8 @@ ...@@ -21,8 +21,8 @@
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
* THE SOFTWARE. * THE SOFTWARE.
*/ */
#ifndef MIGRAPHX_GUARD_RTGLIB_MIOPEN_MLIR_CONV_HPP #ifndef MIGRAPHX_GUARD_GPU_FUSE_MLIR_HPP
#define MIGRAPHX_GUARD_RTGLIB_MIOPEN_MLIR_CONV_HPP #define MIGRAPHX_GUARD_GPU_FUSE_MLIR_HPP
#include <migraphx/config.hpp> #include <migraphx/config.hpp>
#include <migraphx/gpu/context.hpp> #include <migraphx/gpu/context.hpp>
...@@ -30,18 +30,19 @@ ...@@ -30,18 +30,19 @@
namespace migraphx { namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS { inline namespace MIGRAPHX_INLINE_NS {
struct module; struct module_pass_manager;
namespace gpu { namespace gpu {
struct mlir_conv
struct fuse_mlir
{ {
context* ctx; context* ctx = nullptr;
std::string name() const { return "mlir::convolution"; } std::string name() const { return "gpu::fuse_mlir"; }
void apply(module& m) const; void apply(module_pass_manager& mpm) const;
}; };
} // namespace gpu } // namespace gpu
} // namespace MIGRAPHX_INLINE_NS } // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx } // namespace migraphx
#endif // MIGRAPHX_GUARD_GPU_FUSE_MLIR_HPP
#endif
/*
* The MIT License (MIT)
*
* Copyright (c) 2015-2022 Advanced Micro Devices, Inc. All rights reserved.
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
* in the Software without restriction, including without limitation the rights
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
* copies of the Software, and to permit persons to whom the Software is
* furnished to do so, subject to the following conditions:
*
* The above copyright notice and this permission notice shall be included in
* all copies or substantial portions of the Software.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
* THE SOFTWARE.
*/
#ifndef MIGRAPHX_GUARD_RTGLIB_GPU_MLIR_HPP
#define MIGRAPHX_GUARD_RTGLIB_GPU_MLIR_HPP
#include <string>
#include <vector>
#include <migraphx/config.hpp>
#include <migraphx/gpu/code_object_op.hpp>
#include <migraphx/instruction_ref.hpp>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
struct module;
namespace gpu {
// Returns a string for module `m`; judging by the name this is the textual
// MLIR form of the module (definition is not in this header - confirm there).
std::string dump_mlir(const module& m);
// Produces a code_object_op from module `m` using `ctx`; presumably compiles
// the module through MLIR into a launchable kernel (confirm in the definition).
code_object_op compile_mlir(const context& ctx, const module& m);
// Takes a code object `co` and `inputs` and returns an instruction_ref for
// module `m`; presumably the new instruction is inserted relative to `ins`
// (confirm in the definition).
instruction_ref insert_mlir(module& m,
instruction_ref ins,
code_object_op co,
const std::vector<instruction_ref>& inputs);
} // namespace gpu
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
#endif
/*
* The MIT License (MIT)
*
* Copyright (c) 2015-2022 Advanced Micro Devices, Inc. All rights reserved.
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
* in the Software without restriction, including without limitation the rights
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
* copies of the Software, and to permit persons to whom the Software is
* furnished to do so, subject to the following conditions:
*
* The above copyright notice and this permission notice shall be included in
* all copies or substantial portions of the Software.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
* THE SOFTWARE.
*/
#include <migraphx/gpu/compiler.hpp>
#include <migraphx/make_op.hpp>
#include <migraphx/gpu/context.hpp>
#include <migraphx/gpu/mlir.hpp>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
namespace gpu {
// Compiler entry for "gpu::mlir_conv" instructions: the instruction's first
// submodule is handed to compile_mlir and the instruction is then replaced by
// the result of insert_mlir.
struct mlir_compiler : compiler<mlir_compiler>
{
    std::vector<std::string> names() const
    {
        std::vector<std::string> handled{"gpu::mlir_conv"};
        return handled;
    }

    // Not used for this compiler; returns a default-constructed operation.
    operation compile_op(context&, const std::vector<shape>&, const value&) const
    {
        return {};
    }

    compiler_replace compile(context& ctx, instruction_ref ins, const operation&) const
    {
        auto* submodule = ins->module_inputs().front();
        // The instruction carries one more input than the submodule has parameters
        assert(submodule->get_parameter_names().size() + 1 == ins->inputs().size());
        auto code_object = compile_mlir(ctx, *submodule);
        return insert(std::move(code_object));
    }

    // Wraps the compiled code object in a callback that performs the replacement.
    compiler_replace insert(code_object_op co) const
    {
        return [co = std::move(co)](module& m, instruction_ref ins) {
            auto replacement = insert_mlir(m, ins, co, ins->inputs());
            m.replace_instruction(ins, replacement);
        };
    }
};
} // namespace gpu
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
...@@ -79,7 +79,7 @@ static std::vector<std::string> get_op_names(const module& m) ...@@ -79,7 +79,7 @@ static std::vector<std::string> get_op_names(const module& m)
struct pointwise_compiler : compiler<pointwise_compiler> struct pointwise_compiler : compiler<pointwise_compiler>
{ {
std::vector<std::string> names() const { return {"pointwise"}; } std::vector<std::string> names() const { return {"pointwise", "contiguous"}; }
static std::size_t oversubscribe_if(bool b) static std::size_t oversubscribe_if(bool b)
{ {
...@@ -114,34 +114,45 @@ struct pointwise_compiler : compiler<pointwise_compiler> ...@@ -114,34 +114,45 @@ struct pointwise_compiler : compiler<pointwise_compiler>
return compile_hip_code_object(src, options); return compile_hip_code_object(src, options);
} }
compiler_replace compile(context& ctx, instruction_ref ins, const operation&) const compiler_replace compile(context& ctx, instruction_ref ins, const operation& op) const
{ {
assert(not ins->module_inputs().empty()); if(op.name() == "contiguous")
auto* pm = ins->module_inputs().front(); {
run_passes(*pm, {eliminate_common_subexpression{}, dead_code_elimination{}}); return replace(compile_op(
cpp_generator g; ctx,
g.fmap([](const std::string& fname) { return "migraphx::" + fname; }); to_shapes(ins->inputs()),
g.add_point_op("where", "${function:where}(${0}, ${1}, ${2})"); {{"lambda", "[](auto x) { return x; }"}, {"kernel", "contiguous_kernel"}}));
g.add_point_op("prelu", "${function:where}(${0} < 0, ${0} * ${1}, ${0})"); }
g.add_point_op("sign", else
"${function:where}(${0} > 0, 1, ${function:where}(${0} < 0, -1, 0))"); {
g.add_point_op("equal", "migraphx::abs(${0} == ${1})"); assert(not ins->module_inputs().empty());
g.add_point_op("less", "migraphx::abs(${0} < ${1})"); auto* pm = ins->module_inputs().front();
g.add_point_op("greater", "migraphx::abs(${0} > ${1})"); run_passes(*pm, {eliminate_common_subexpression{}, dead_code_elimination{}});
g.add_point_op("not", "migraphx::abs(not ${0})"); cpp_generator g;
// Add explict conversions g.fmap([](const std::string& fname) { return "migraphx::" + fname; });
g.fresult( g.add_point_op("where", "${function:where}(${0}, ${1}, ${2})");
[](const shape& s) { return "migraphx::convert<" + shape::cpp_type(s.type()) + ">"; }); g.add_point_op("prelu", "${function:where}(${0} < 0, ${0} * ${1}, ${0})");
auto name = g.create_function( g.add_point_op("sign",
g.generate_module(*pm).set_attributes({"__device__"}).set_generic_types(*pm)); "${function:where}(${0} > 0, 1, ${function:where}(${0} < 0, -1, 0))");
std::string lambda = "MIGRAPHX_LIFT(" + name + ")"; g.add_point_op("equal", "migraphx::abs(${0} == ${1})");
auto op_names = get_op_names(*pm); g.add_point_op("less", "migraphx::abs(${0} < ${1})");
op_names.push_back("kernel"); g.add_point_op("greater", "migraphx::abs(${0} > ${1})");
auto op_name_string = join_strings(op_names, "_"); g.add_point_op("not", "migraphx::abs(not ${0})");
return replace( // Add explict conversions
compile_op(ctx, g.fresult([](const shape& s) {
to_shapes(ins->inputs()), return "migraphx::convert<" + shape::cpp_type(s.type()) + ">";
{{"lambda", lambda}, {"preamble", g.str()}, {"kernel", op_name_string}})); });
auto name = g.create_function(
g.generate_module(*pm).set_attributes({"__device__"}).set_generic_types(*pm));
std::string lambda = "MIGRAPHX_LIFT(" + name + ")";
auto op_names = get_op_names(*pm);
op_names.push_back("kernel");
auto op_name_string = join_strings(op_names, "_");
return replace(compile_op(
ctx,
to_shapes(ins->inputs()),
{{"lambda", lambda}, {"preamble", g.str()}, {"kernel", op_name_string}}));
}
} }
}; };
} // namespace gpu } // namespace gpu
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment