Commit d5636acd authored by charlie

Merge branch 'dyn_dim_onnx_parser' of github.com:ROCmSoftwarePlatform/AMDMIGraphX into dyn_conv

parents fff09052 5f215b71
......@@ -349,25 +349,27 @@ match::matcher_result find_match(module& modl, M&& m)
MIGRAPHX_DECLARE_ENV_VAR(MIGRAPHX_TRACE_MATCHES)
/// Find matches for an instruction in the module
template <class... Ms>
void find_matches(module& mod, instruction_ref ins, Ms&&... ms)
template <class Mod, class... Ms>
void find_matches(Mod& mod, instruction_ref ins, Ms&&... ms)
{
#if !defined(__GNUC__) || defined(__clang__) || __GNUC__ > 5
const
#endif
bool trace = enabled(MIGRAPHX_TRACE_MATCHES{});
bool match = false;
int trace = value_of(MIGRAPHX_TRACE_MATCHES{});
bool match = false;
each_args(
[&](auto&& m) {
if(match)
return;
auto r = match_instruction(mod, ins, m.matcher());
if(r.result == mod.end())
if(trace > 1)
std::cout << "Match: " << get_type_name(m) << std::endl;
auto r = match_instruction(get_module(mod), ins, m.matcher());
if(r.result == get_module(mod).end())
return;
if(trace)
if(trace > 0)
{
std::cout << "Matched by " << get_type_name(m) << std::endl;
mod.debug_print(ins);
get_module(mod).debug_print(ins);
}
m.apply(mod, r);
match = true;
......@@ -376,10 +378,10 @@ void find_matches(module& mod, instruction_ref ins, Ms&&... ms)
}
/// Find matches in a module
template <class... Ms>
void find_matches(module& mod, Ms&&... ms)
template <class Mod, class... Ms>
void find_matches(Mod& mod, Ms&&... ms)
{
for(auto ins : iterator_for(mod))
for(auto ins : iterator_for(get_module(mod)))
{
find_matches(mod, ins, ms...);
}
......
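A minimal usage sketch for the templated find_matches (the remove_identity matcher below is hypothetical, not part of this commit). Note that MIGRAPHX_TRACE_MATCHES is now an integer: any value > 0 prints matched instructions, and > 1 also prints every matcher as it is tried.

#include <migraphx/matcher.hpp>
#include <migraphx/instruction.hpp>

// Hypothetical matcher/rewriter pair, sketch only: drop identity instructions.
struct remove_identity
{
    auto matcher() const { return migraphx::match::name("identity"); }

    void apply(migraphx::module& m, const migraphx::match::matcher_result& r) const
    {
        m.replace_instruction(r.result, r.result->inputs().front());
    }
};

void run_pass(migraphx::module& m)
{
    migraphx::match::find_matches(m, remove_identity{});
}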
......@@ -120,9 +120,33 @@ struct module
instruction_ref move_instructions(instruction_ref src, instruction_ref dst);
std::vector<instruction_ref>
insert_module_instructions(instruction_ref ins,
module_ref m,
std::unordered_map<instruction_ref, instruction_ref> map_ins = {});
add_instructions(const std::vector<instruction_ref>& instructions,
std::unordered_map<instruction_ref, instruction_ref> map_ins = {});
std::vector<instruction_ref>
add_instructions(const_module_ref m,
std::unordered_map<instruction_ref, instruction_ref> map_ins = {});
std::vector<instruction_ref>
add_instructions(instruction_ref start,
instruction_ref last,
std::unordered_map<instruction_ref, instruction_ref> map_ins = {});
std::vector<instruction_ref>
insert_instructions(instruction_ref ins,
const std::vector<instruction_ref>& instructions,
std::unordered_map<instruction_ref, instruction_ref> map_ins = {});
std::vector<instruction_ref>
insert_instructions(instruction_ref ins,
const_module_ref m,
std::unordered_map<instruction_ref, instruction_ref> map_ins = {});
std::vector<instruction_ref>
insert_instructions(instruction_ref ins,
instruction_ref start,
instruction_ref last,
std::unordered_map<instruction_ref, instruction_ref> map_ins = {});
template <class... Ts>
instruction_ref add_literal(Ts&&... xs)
......@@ -140,6 +164,10 @@ struct module
instruction_ref replace_return(std::vector<instruction_ref> args);
instruction_ref insert_literal(instruction_ref ins, literal l);
instruction_ref insert_parameter(instruction_ref ins, std::string name, shape s);
std::vector<std::string> get_parameter_names() const;
shape get_parameter_shape(std::string name) const;
......@@ -179,7 +207,9 @@ struct module
void print_cpp(std::ostream& os) const;
std::unordered_map<instruction_ref, std::string>
print_cpp(std::ostream& os, std::unordered_map<instruction_ref, std::string> names) const;
print_cpp(std::ostream& os,
const std::string& mname,
std::unordered_map<instruction_ref, std::string> names) const;
void annotate(std::ostream& os, std::function<void(instruction_ref)> a) const;
......@@ -201,6 +231,8 @@ struct module
std::unique_ptr<module_impl> impl;
};
inline module& get_module(module& m) { return m; }
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
......
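The new add_instructions/insert_instructions overloads replace insert_module_instructions and also accept an instruction vector or an explicit [start, last) range, with map_ins pre-seeding how referenced instructions are remapped. A rough sketch of how callers might use them (helper names are illustrative):

#include <migraphx/module.hpp>
#include <unordered_map>
#include <vector>

// Copy the whole of src (parameters and literals included) to the end of dst.
std::vector<migraphx::instruction_ref>
append_module(migraphx::module& dst, const migraphx::module& src)
{
    return dst.add_instructions(&src);
}

// Copy only the instructions in [start, last) in front of `where`,
// remapping their inputs through map_ins.
std::vector<migraphx::instruction_ref>
splice_range(migraphx::module& dst,
             migraphx::instruction_ref where,
             migraphx::instruction_ref start,
             migraphx::instruction_ref last,
             std::unordered_map<migraphx::instruction_ref, migraphx::instruction_ref> map_ins)
{
    return dst.insert_instructions(where, start, last, std::move(map_ins));
}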
......@@ -32,7 +32,8 @@ namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
struct module;
using module_ref = module*;
using module_ref = module*;
using const_module_ref = const module*;
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
......
......@@ -56,14 +56,21 @@ struct nonmaxsuppression
shape compute_shape(std::vector<shape> inputs) const
{
// requires at least 2 inputs
check_shapes{inputs, *this}.standard();
check_shapes{{inputs.at(0), inputs.at(1)}, *this}.only_dims(3);
auto lens = inputs.front().lens();
// check input shape
if(lens[1] != inputs.at(1).lens()[2])
{
MIGRAPHX_THROW("NonMaxSuppression: dimension mismatch between first and second input!");
MIGRAPHX_THROW(
"NonMaxSuppression: spatial dimension mismatch between boxes and scores input");
}
// check batch sizes
if(lens[0] != inputs.at(1).lens()[0])
{
MIGRAPHX_THROW(
"NonMaxSuppression: number of batches mismatch between boxes and scores input");
}
std::vector<int64_t> out_lens(2);
......@@ -74,8 +81,8 @@ struct nonmaxsuppression
struct box
{
std::array<float, 2> x;
std::array<float, 2> y;
std::array<double, 2> x;
std::array<double, 2> y;
void sort()
{
......@@ -83,9 +90,9 @@ struct nonmaxsuppression
std::sort(y.begin(), y.end());
}
std::array<float, 2>& operator[](std::size_t i) { return i == 0 ? x : y; }
std::array<double, 2>& operator[](std::size_t i) { return i == 0 ? x : y; }
float area() const
double area() const
{
assert(std::is_sorted(x.begin(), x.end()));
assert(std::is_sorted(y.begin(), y.end()));
......@@ -94,29 +101,29 @@ struct nonmaxsuppression
};
template <class T>
box batch_box(const T* boxes, std::size_t bidx) const
box batch_box(T boxes, std::size_t box_idx) const
{
box result{};
const T* start = boxes + 4 * bidx;
auto start = boxes + 4 * box_idx;
if(center_point_box)
{
float half_width = start[2] / 2.0f;
float half_height = start[3] / 2.0f;
float x_center = start[0];
float y_center = start[1];
result.x = {x_center - half_width, x_center + half_width};
result.y = {y_center - half_height, y_center + half_height};
double half_width = start[2] / 2.0;
double half_height = start[3] / 2.0;
double x_center = start[0];
double y_center = start[1];
result.x = {x_center - half_width, x_center + half_width};
result.y = {y_center - half_height, y_center + half_height};
}
else
{
result.x = {start[1], start[3]};
result.y = {start[0], start[2]};
result.x = {static_cast<double>(start[1]), static_cast<double>(start[3])};
result.y = {static_cast<double>(start[0]), static_cast<double>(start[2])};
}
return result;
}
inline bool suppress_by_iou(box b1, box b2, float iou_threshold) const
inline bool suppress_by_iou(box b1, box b2, double iou_threshold) const
{
b1.sort();
b2.sort();
......@@ -128,7 +135,7 @@ struct nonmaxsuppression
intersection[i][1] = std::min(b1[i][1], b2[i][1]);
}
std::vector<std::array<float, 2>> bbox = {intersection.x, intersection.y};
std::vector<std::array<double, 2>> bbox = {intersection.x, intersection.y};
if(std::any_of(bbox.begin(), bbox.end(), [](auto bx) {
return not std::is_sorted(bx.begin(), bx.end());
}))
......@@ -136,115 +143,124 @@ struct nonmaxsuppression
return false;
}
const float area1 = b1.area();
const float area2 = b2.area();
const float intersection_area = intersection.area();
const float union_area = area1 + area2 - intersection_area;
const double area1 = b1.area();
const double area2 = b2.area();
const double intersection_area = intersection.area();
const double union_area = area1 + area2 - intersection_area;
if(area1 <= .0f or area2 <= .0f or union_area <= .0f)
{
return false;
}
const float intersection_over_union = intersection_area / union_area;
const double intersection_over_union = intersection_area / union_area;
return intersection_over_union > iou_threshold;
}
argument compute(const shape& output_shape, std::vector<argument> args) const
// filter boxes below score_threshold
// (i.e. drop boxes whose score is below score_threshold and return the rest in a max-heap keyed by score)
template <class T>
std::priority_queue<std::pair<double, int64_t>>
filter_boxes_by_score(T scores_start, std::size_t num_boxes, double score_threshold) const
{
argument result{output_shape};
result.visit([&](auto out) { std::fill(out.begin(), out.end(), 0); });
std::size_t max_output_boxes_per_class = 0;
float iou_threshold = 0.0f;
float score_threshold = 0.0f;
if(args.size() > 2)
{
max_output_boxes_per_class = args.at(2).at<std::size_t>();
}
// max_output_boxes_per_class is 0, no output
if(max_output_boxes_per_class == 0)
{
return result;
}
if(args.size() > 3)
{
iou_threshold = args.at(3).at<float>();
}
if(args.size() > 4)
{
score_threshold = args.at(4).at<float>();
}
const auto& lens = args.at(1).get_shape().lens();
auto batch_num = lens[0];
auto class_num = lens[1];
auto box_num = args.at(0).get_shape().lens()[1];
std::priority_queue<std::pair<double, int64_t>> boxes_heap;
auto insert_to_boxes_heap =
make_function_output_iterator([&](const auto& x) { boxes_heap.push(x); });
int64_t box_idx = 0;
transform_if(
scores_start,
scores_start + num_boxes,
insert_to_boxes_heap,
[&](auto sc) {
box_idx++;
return sc >= score_threshold;
},
[&](auto sc) { return std::make_pair(sc, box_idx - 1); });
return boxes_heap;
}
std::vector<std::pair<float, int64_t>> selected_boxes_inside_class;
template <class Output, class Boxes, class Scores>
void compute_nms(Output output,
Boxes boxes,
Scores scores,
const shape& output_shape,
std::size_t max_output_boxes_per_class,
double iou_threshold,
double score_threshold) const
{
std::fill(output.begin(), output.end(), 0);
const auto& lens = scores.get_shape().lens();
const auto num_batches = lens[0];
const auto num_classes = lens[1];
const auto num_boxes = lens[2];
// boxes of a class with NMS applied [score, index]
std::vector<std::pair<double, int64_t>> selected_boxes_inside_class;
std::vector<int64_t> selected_indices;
selected_boxes_inside_class.reserve(output_shape.elements());
auto scores = make_view<float>(args.at(1).get_shape(), args.at(1).cast<float>());
const float* boxes = args.at(0).cast<float>();
shape comp_s{shape::float_type, {batch_num, class_num}};
// iterate over batches and classes
shape comp_s{shape::double_type, {num_batches, num_classes}};
shape_for_each(comp_s, [&](auto idx) {
auto bidx = idx[0];
auto cidx = idx[1];
std::size_t score_offset = (bidx * class_num + cidx) * box_num;
const float* batch_boxes = boxes + bidx * box_num * 4;
std::priority_queue<std::pair<float, int64_t>> sorted_boxes;
auto insert_to_sorted_boxes =
make_function_output_iterator([&](const auto& x) { sorted_boxes.push(x); });
int64_t box_idx = 0;
transform_if(
scores.begin() + score_offset,
scores.begin() + score_offset + box_num,
insert_to_sorted_boxes,
[&](auto sc) {
box_idx++;
return sc >= score_threshold;
},
[&](auto sc) { return std::make_pair(sc, box_idx - 1); });
auto batch_idx = idx[0];
auto class_idx = idx[1];
// iterator to the scores of this (batch, class) pair
auto scores_start = scores.begin() + (batch_idx * num_classes + class_idx) * num_boxes;
// iterator to first value of this batch
auto batch_boxes_start = boxes.begin() + batch_idx * num_boxes * 4;
auto boxes_heap = filter_boxes_by_score(scores_start, num_boxes, score_threshold);
selected_boxes_inside_class.clear();
// Get the next box with top score, filter by iou_threshold
while(!sorted_boxes.empty() &&
while(!boxes_heap.empty() &&
selected_boxes_inside_class.size() < max_output_boxes_per_class)
{
const std::pair<float, int64_t>& next_top_score = sorted_boxes.top();
// Check with existing selected boxes for this class, suppress if exceed the IOU
// (Intersection Over Union) threshold
bool not_selected = std::any_of(
selected_boxes_inside_class.begin(),
selected_boxes_inside_class.end(),
[&](auto selected_index) {
return this->suppress_by_iou(batch_box(batch_boxes, next_top_score.second),
batch_box(batch_boxes, selected_index.second),
iou_threshold);
});
// Check with existing selected boxes for this class, remove box if it
// exceeds the IOU (Intersection Over Union) threshold
const auto next_top_score = boxes_heap.top();
bool not_selected =
std::any_of(selected_boxes_inside_class.begin(),
selected_boxes_inside_class.end(),
[&](auto selected_index) {
return this->suppress_by_iou(
batch_box(batch_boxes_start, next_top_score.second),
batch_box(batch_boxes_start, selected_index.second),
iou_threshold);
});
if(not not_selected)
{
selected_boxes_inside_class.push_back(next_top_score);
selected_indices.push_back(bidx);
selected_indices.push_back(cidx);
selected_indices.push_back(batch_idx);
selected_indices.push_back(class_idx);
selected_indices.push_back(next_top_score.second);
}
sorted_boxes.pop();
boxes_heap.pop();
}
});
std::copy(selected_indices.begin(), selected_indices.end(), output.begin());
}
argument compute(const shape& output_shape, std::vector<argument> args) const
{
argument result{output_shape};
result.visit([&](auto out) {
std::copy(selected_indices.begin(), selected_indices.end(), out.begin());
std::size_t max_output_boxes_per_class =
(args.size() > 2) ? (args.at(2).at<std::size_t>()) : 0;
if(max_output_boxes_per_class == 0)
{
return result;
}
double iou_threshold = (args.size() > 3) ? (args.at(3).at<double>()) : 0.0f;
double score_threshold = (args.size() > 4) ? (args.at(4).at<double>()) : 0.0f;
result.visit([&](auto output) {
visit_all(args[0], args[1])([&](auto boxes, auto scores) {
compute_nms(output,
boxes,
scores,
output_shape,
max_output_boxes_per_class,
iou_threshold,
score_threshold);
});
});
return result;
......
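The box math now runs in double precision. A small standalone example of the IoU formula used by suppress_by_iou: two unit squares offset by half a unit have intersection 0.5 and union 1.5, so IoU = 1/3 and the pair is only suppressed when iou_threshold < 1/3.

#include <algorithm>
#include <array>
#include <iostream>

int main()
{
    std::array<double, 2> ax{0.0, 1.0}, ay{0.0, 1.0}; // box A: [0,1] x [0,1]
    std::array<double, 2> bx{0.5, 1.5}, by{0.0, 1.0}; // box B: [0.5,1.5] x [0,1]
    auto area = [](auto x, auto y) { return (x[1] - x[0]) * (y[1] - y[0]); };
    // Intersection rectangle per dimension: [max of lows, min of highs]
    std::array<double, 2> ix{std::max(ax[0], bx[0]), std::min(ax[1], bx[1])};
    std::array<double, 2> iy{std::max(ay[0], by[0]), std::min(ay[1], by[1])};
    double inter = area(ix, iy);
    double iou   = inter / (area(ax, ay) + area(bx, by) - inter);
    std::cout << iou << std::endl; // prints ~0.333
}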
......@@ -42,11 +42,12 @@ namespace op {
struct unsqueeze
{
std::vector<int64_t> axes;
std::vector<int64_t> steps;
template <class Self, class F>
static auto reflect(Self& self, F f)
{
return pack(f(self.axes, "axes"));
return pack(f(self.axes, "axes"), f(self.steps, "steps"));
}
value attributes() const
......@@ -73,6 +74,9 @@ struct unsqueeze
MIGRAPHX_THROW("UNSQUEEZE: Input must be a scalar");
}
if(steps.size() > axes.size())
MIGRAPHX_THROW("UNSQUEEZE: Steps provided with no axis");
std::size_t new_size = old_lens.size() + axes.size();
std::vector<std::size_t> new_lens(new_size);
......@@ -80,16 +84,27 @@ struct unsqueeze
std::size_t p = 0;
for(auto i : range(new_size))
{
if(std::find(axes.begin(), axes.end(), i) != axes.end())
auto axis_idx = std::find(axes.begin(), axes.end(), i) - axes.begin();
if(axis_idx < axes.size())
{
new_lens[i] = 1;
if(p == 0) // unsqueeze on the first axes
std::int64_t step = 1;
if(axis_idx < steps.size())
step = steps[axis_idx];
if(step == 0)
MIGRAPHX_THROW("UNSQUEEZE: step must be non-zero");
new_lens[i] = step;
if(p < old_strides.size())
{
new_strides[i] = old_lens[0] * old_strides[0];
if((old_lens[p] % step) != 0)
MIGRAPHX_THROW("UNSQUEEZE: Axis dimenstion is not divisible by step");
old_lens[p] /= step;
new_strides[i] = old_strides[p] * old_lens[p];
}
else // unsqueeze on middle or last axes
else
{
new_strides[i] = (p < old_strides.size()) ? old_strides[p - 1] : 1;
if(step != 1)
MIGRAPHX_THROW("UNSQUEEZE: Step must be 1 for extra axes");
new_strides[i] = 1;
}
}
else
......
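With the new steps attribute, an unsqueezed axis takes the step as its length and splits the original dimension at that position instead of inserting a pure length-1 axis. A hedged sketch of the intended behaviour using the standard construction API (not part of this commit): unsqueezing {3, 4} at axis 1 with step 2 should give lens {3, 2, 2} with strides {4, 2, 1}, still aliasing the same data.

#include <migraphx/module.hpp>
#include <migraphx/make_op.hpp>
#include <migraphx/shape.hpp>

migraphx::instruction_ref build_unsqueeze_example(migraphx::module& m)
{
    auto x = m.add_parameter("x", migraphx::shape{migraphx::shape::float_type, {3, 4}});
    // steps[i] applies to axes[i]; the step must divide the dimension it splits.
    return m.add_instruction(
        migraphx::make_op("unsqueeze", {{"axes", {1}}, {"steps", {2}}}), x);
    // expected result shape: lens {3, 2, 2}
}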
......@@ -33,6 +33,8 @@
#include <migraphx/instruction_ref.hpp>
#include <migraphx/target.hpp>
#include <migraphx/compile_options.hpp>
#include <migraphx/target_assignments.hpp>
#include <migraphx/assignment_options.hpp>
#include <migraphx/env.hpp>
#include <migraphx/config.hpp>
#include <algorithm>
......@@ -84,6 +86,9 @@ struct program
instruction_ref validate() const;
target_assignments get_target_assignments(const std::vector<target>& targets,
assignment_options options = assignment_options{});
void compile(const target& t, compile_options options = compile_options{});
bool is_compiled() const;
......
......@@ -198,6 +198,12 @@ void transform(Range&& r, Iterator it, F f)
std::transform(r.begin(), r.end(), it, f);
}
template <class Range1, class Range2, class Iterator, class F>
void transform(Range1&& r1, Range2&& r2, Iterator it, F f)
{
std::transform(r1.begin(), r1.end(), r2.begin(), it, f);
}
template <class Range>
auto reverse(Range& r)
{
......
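The new two-range transform overload mirrors the four-iterator std::transform. A small usage sketch, assuming the helper lives in migraphx/ranges.hpp like the single-range version:

#include <migraphx/ranges.hpp>
#include <vector>

std::vector<int> elementwise_sum(const std::vector<int>& a, const std::vector<int>& b)
{
    // b must have at least as many elements as a.
    std::vector<int> c(a.size());
    migraphx::transform(a, b, c.begin(), [](int x, int y) { return x + y; });
    return c; // {1,2,3} + {10,20,30} -> {11,22,33}
}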
......@@ -256,6 +256,10 @@ struct shape
std::size_t size(std::size_t n = 1) const { return sizeof(type) * n; }
auto is_integral() const { return std::is_integral<type>{}; }
auto is_signed() const { return std::is_signed<type>{}; }
auto is_unsigned() const { return std::is_unsigned<type>{}; }
template <class U>
type* from(U* buffer, std::size_t n = 0) const
{
......@@ -314,8 +318,9 @@ struct shape
const std::vector<shape>& sub_shapes() const;
/*!
* Returns size of the data buffer.
* Assuming a packed shape, returns maximum size of the data buffer for dynamic shape.
* Returns the number of elements in the data buffer.
* For a dynamic shape, returns the maximum number of elements of the data buffer and assumes it
* is packed.
*/
std::size_t element_space() const;
......
......@@ -44,8 +44,8 @@ auto with_char(F f)
return [=](unsigned char c) -> bool { return f(c); };
}
inline std::string
replace_string(std::string subject, const std::string& search, const std::string& replace)
inline void
replace_string_inplace(std::string& subject, const std::string& search, const std::string& replace)
{
size_t pos = 0;
while((pos = subject.find(search, pos)) != std::string::npos)
......@@ -53,6 +53,12 @@ replace_string(std::string subject, const std::string& search, const std::string
subject.replace(pos, search.length(), replace);
pos += replace.length();
}
}
inline std::string
replace_string(std::string subject, const std::string& search, const std::string& replace)
{
replace_string_inplace(subject, search, replace);
return subject;
}
......
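replace_string now delegates to the new in-place helper. A quick sketch of the difference, assuming the usual migraphx/stringutils.hpp header:

#include <migraphx/stringutils.hpp>
#include <string>

void replace_example()
{
    std::string s = "a--b--c";
    migraphx::replace_string_inplace(s, "--", "_"); // s becomes "a_b_c"
    auto t = migraphx::replace_string(s, "_", "."); // t is "a.b.c", s unchanged
}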
/*
* The MIT License (MIT)
*
* Copyright (c) 2015-2022 Advanced Micro Devices, Inc. All rights reserved.
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
* in the Software without restriction, including without limitation the rights
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
* copies of the Software, and to permit persons to whom the Software is
* furnished to do so, subject to the following conditions:
*
* The above copyright notice and this permission notice shall be included in
* all copies or substantial portions of the Software.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
* THE SOFTWARE.
*/
#ifndef MIGRAPHX_GUARD_MIGRAPHX_SUPPORT_METRIC_HPP
#define MIGRAPHX_GUARD_MIGRAPHX_SUPPORT_METRIC_HPP
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
enum class support_metric
{
latency,
throughput
};
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
#endif // MIGRAPHX_GUARD_MIGRAPHX_SUPPORT_METRIC_HPP
......@@ -37,6 +37,8 @@
#include <migraphx/compile_options.hpp>
#include <migraphx/argument.hpp>
#include <migraphx/rank.hpp>
#include <migraphx/support_metric.hpp>
#include <migraphx/instruction_ref.hpp>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
......@@ -61,6 +63,13 @@ struct target
* @return The context to be used during compilation and execution.
*/
context get_context() const;
/**
* @brief Check how well an instruction is supported on a target with the given metric
* @param ins Instruction to check for support on this target
* @param metric Used to define how the return value should be interpreted
* @return The value based on the chosen metric. Negative numbers mean unsupported
*/
float is_supported(T&, instruction_ref ins, support_metric m) const;
/**
* @brief copy an argument to the current target.
*
......@@ -105,6 +114,12 @@ argument copy_from_target(T&, const argument& arg)
return arg;
}
template <class T>
float target_is_supported(T&, instruction_ref, support_metric)
{
return 0;
}
#ifdef TYPE_ERASED_DECLARATION
// Type-erased interface for:
......@@ -117,6 +132,8 @@ struct target
//
context get_context() const;
// (optional)
float is_supported(instruction_ref ins, support_metric m) const;
// (optional)
argument copy_to(const argument& input) const;
// (optional)
argument copy_from(const argument& input) const;
......@@ -207,6 +224,12 @@ struct target
return (*this).private_detail_te_get_handle().get_context();
}
float is_supported(instruction_ref ins, support_metric m) const
{
assert((*this).private_detail_te_handle_mem_var);
return (*this).private_detail_te_get_handle().is_supported(ins, m);
}
argument copy_to(const argument& input) const
{
assert((*this).private_detail_te_handle_mem_var);
......@@ -242,11 +265,31 @@ struct target
virtual std::vector<pass> get_passes(context& ctx,
const compile_options& options) const = 0;
virtual context get_context() const = 0;
virtual float is_supported(instruction_ref ins, support_metric m) const = 0;
virtual argument copy_to(const argument& input) const = 0;
virtual argument copy_from(const argument& input) const = 0;
virtual argument allocate(const shape& s) const = 0;
};
template <class T>
static auto private_detail_te_default_is_supported(char,
T&& private_detail_te_self,
instruction_ref ins,
support_metric m)
-> decltype(private_detail_te_self.is_supported(ins, m))
{
return private_detail_te_self.is_supported(ins, m);
}
template <class T>
static float private_detail_te_default_is_supported(float,
T&& private_detail_te_self,
instruction_ref ins,
support_metric m)
{
return target_is_supported(private_detail_te_self, ins, m);
}
template <class T>
static auto
private_detail_te_default_copy_to(char, T&& private_detail_te_self, const argument& input)
......@@ -329,6 +372,12 @@ struct target
context get_context() const override { return private_detail_te_value.get_context(); }
float is_supported(instruction_ref ins, support_metric m) const override
{
return private_detail_te_default_is_supported(char(0), private_detail_te_value, ins, m);
}
argument copy_to(const argument& input) const override
{
......
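is_supported is an optional member of the type-erased target: targets that implement it are called directly, and everything else falls back to target_is_supported(), which reports 0 for every instruction. A hedged sketch of how a target could opt in (my_target is hypothetical and omits the required members such as name, get_passes, get_context and allocate):

#include <migraphx/instruction.hpp>
#include <migraphx/instruction_ref.hpp>
#include <migraphx/support_metric.hpp>
#include <string>

struct my_target // hypothetical, sketch only
{
    std::string name() const { return "my_target"; }

    // Negative values mean the instruction is unsupported on this target.
    float is_supported(migraphx::instruction_ref ins, migraphx::support_metric) const
    {
        return ins->name() == "convolution" ? 1.0f : -1.0f;
    }
    // get_passes, get_context, copy_to/copy_from, allocate omitted for brevity.
};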
/*
* The MIT License (MIT)
*
* Copyright (c) 2015-2022 Advanced Micro Devices, Inc. All rights reserved.
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
* in the Software without restriction, including without limitation the rights
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
* copies of the Software, and to permit persons to whom the Software is
* furnished to do so, subject to the following conditions:
*
* The above copyright notice and this permission notice shall be included in
* all copies or substantial portions of the Software.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
* THE SOFTWARE.
*/
#ifndef MIGRAPHX_GUARD_MIGRAPHX_ASSIGNMENT_HPP
#define MIGRAPHX_GUARD_MIGRAPHX_ASSIGNMENT_HPP
#include <unordered_map>
#include <migraphx/instruction_ref.hpp>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
struct target_assignments
{
void add_assignment(instruction_ref ins, const std::string& target);
auto begin() const { return assignments.cbegin(); }
auto end() const { return assignments.cend(); }
private:
std::unordered_map<instruction_ref, std::string> assignments;
};
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
#endif // MIGRAPHX_GUARD_MIGRAPHX_ASSIGNMENT_HPP
......@@ -35,7 +35,7 @@ static void inline_submodule(module& m, instruction_ref ins, bool cond)
{
const auto& mod_inputs = ins->module_inputs();
module_ref smod = cond ? mod_inputs.at(0) : mod_inputs.at(1);
auto mod_outputs = m.insert_module_instructions(ins, smod);
auto mod_outputs = m.insert_instructions(ins, smod);
auto ins_outputs = ins->outputs();
assert(mod_outputs.size() >= ins_outputs.size());
......
......@@ -35,6 +35,7 @@
#include <migraphx/make_op.hpp>
#include <migraphx/register_target.hpp>
#include <migraphx/make_op.hpp>
#include <migraphx/json.hpp>
#include <iostream>
#include <sstream>
#include <algorithm>
......@@ -196,6 +197,62 @@ void module::assign(const module& m)
}
}
template <class Range>
static std::vector<instruction_ref>
insert_generic_instructions(module& m,
instruction_ref ins,
Range&& instructions,
std::unordered_map<instruction_ref, instruction_ref> map_ins)
{
assert(m.has_instruction(ins) or is_end(ins, m.end()));
std::vector<instruction_ref> mod_outputs;
instruction_ref last;
for(instruction_ref sins : instructions)
{
last = sins;
if(contains(map_ins, sins))
continue;
instruction_ref copy_ins;
if(sins->name() == "@literal")
{
auto l = sins->get_literal();
copy_ins = m.add_literal(l);
}
else if(sins->name() == "@param")
{
auto&& name = any_cast<builtin::param>(sins->get_operator()).parameter;
auto s = sins->get_shape();
copy_ins = m.add_parameter(name, s);
}
else if(sins->name() == "@outline")
{
auto s = sins->get_shape();
copy_ins = m.add_outline(s);
}
else
{
auto mod_args = sins->module_inputs();
auto inputs = sins->inputs();
std::vector<instruction_ref> copy_inputs(inputs.size());
std::transform(inputs.begin(), inputs.end(), copy_inputs.begin(), [&](auto i) {
return contains(map_ins, i) ? map_ins[i] : i;
});
if(sins->name() == "@return")
{
mod_outputs = copy_inputs;
break;
}
copy_ins = m.insert_instruction(ins, sins->get_operator(), copy_inputs, mod_args);
}
map_ins[sins] = copy_ins;
}
if(mod_outputs.empty() and instructions.begin() != instructions.end())
mod_outputs = {map_ins.at(last)};
return mod_outputs;
}
instruction_ref module::add_instruction(const operation& op, std::vector<instruction_ref> args)
{
return insert_instruction(impl->instructions.end(), op, std::move(args));
......@@ -334,61 +391,56 @@ instruction_ref module::move_instructions(instruction_ref src, instruction_ref d
return src;
}
std::vector<instruction_ref> module::insert_module_instructions(
instruction_ref ins, module_ref m, std::unordered_map<instruction_ref, instruction_ref> map_ins)
std::vector<instruction_ref>
module::add_instructions(const std::vector<instruction_ref>& instructions,
std::unordered_map<instruction_ref, instruction_ref> map_ins)
{
std::vector<instruction_ref> mod_outputs;
for(auto sins : iterator_for(*m))
{
if(contains(map_ins, sins))
continue;
instruction_ref copy_ins;
if(sins->name() == "@literal")
{
auto l = sins->get_literal();
copy_ins = this->add_literal(l);
}
else if(sins->name() == "@param")
{
auto&& name = any_cast<builtin::param>(sins->get_operator()).parameter;
auto s = sins->get_shape();
copy_ins = this->add_parameter(name, s);
}
else if(sins->name() == "@outline")
{
auto s = sins->get_shape();
copy_ins = this->add_outline(s);
}
else
{
auto mod_args = sins->module_inputs();
auto inputs = sins->inputs();
std::vector<instruction_ref> copy_inputs(inputs.size());
std::transform(inputs.begin(), inputs.end(), copy_inputs.begin(), [&](auto i) {
return contains(map_ins, i) ? map_ins[i] : i;
});
return this->insert_instructions(this->end(), instructions, std::move(map_ins));
}
if(sins->name() == "@return")
{
mod_outputs = copy_inputs;
break;
}
std::vector<instruction_ref>
module::add_instructions(const_module_ref m,
std::unordered_map<instruction_ref, instruction_ref> map_ins)
{
return this->insert_instructions(this->end(), m, std::move(map_ins));
}
copy_ins = this->insert_instruction(ins, sins->get_operator(), copy_inputs, mod_args);
}
map_ins[sins] = copy_ins;
}
if(mod_outputs.empty())
mod_outputs = {map_ins.at(std::prev(m->end()))};
return mod_outputs;
std::vector<instruction_ref>
module::add_instructions(instruction_ref start,
instruction_ref last,
std::unordered_map<instruction_ref, instruction_ref> map_ins)
{
return this->insert_instructions(this->end(), start, last, std::move(map_ins));
}
instruction_ref module::add_literal(literal l)
std::vector<instruction_ref>
module::insert_instructions(instruction_ref ins,
const std::vector<instruction_ref>& instructions,
std::unordered_map<instruction_ref, instruction_ref> map_ins)
{
impl->emplace_front(std::move(l));
return impl->instructions.begin();
return insert_generic_instructions(*this, ins, instructions, std::move(map_ins));
}
std::vector<instruction_ref>
module::insert_instructions(instruction_ref ins,
const_module_ref m,
std::unordered_map<instruction_ref, instruction_ref> map_ins)
{
return insert_generic_instructions(*this, ins, iterator_for(*m), std::move(map_ins));
}
std::vector<instruction_ref>
module::insert_instructions(instruction_ref ins,
instruction_ref start,
instruction_ref last,
std::unordered_map<instruction_ref, instruction_ref> map_ins)
{
auto r = range(start, last);
return insert_generic_instructions(*this, ins, iterator_for(r), std::move(map_ins));
}
instruction_ref module::add_literal(literal l) { return insert_literal(begin(), std::move(l)); }
instruction_ref module::add_outline(const shape& s)
{
impl->push_front({builtin::outline{s}, s, {}});
......@@ -397,10 +449,7 @@ instruction_ref module::add_outline(const shape& s)
instruction_ref module::add_parameter(std::string name, shape s)
{
assert(get_parameter_shape(name) == shape{});
impl->push_front({builtin::param{std::move(name), impl->nparams}, std::move(s), {}});
impl->nparams++;
return impl->instructions.begin();
return insert_parameter(begin(), std::move(name), std::move(s));
}
instruction_ref module::add_return(std::vector<instruction_ref> args)
......@@ -413,6 +462,20 @@ instruction_ref module::add_return(std::vector<instruction_ref> args)
return result;
}
instruction_ref module::insert_literal(instruction_ref ins, literal l)
{
impl->emplace(ins, std::move(l));
return std::prev(ins);
}
instruction_ref module::insert_parameter(instruction_ref ins, std::string name, shape s)
{
assert(get_parameter_shape(name) == shape{});
impl->insert(ins, {builtin::param{std::move(name), impl->nparams}, std::move(s), {}});
impl->nparams++;
return std::prev(ins);
}
instruction_ref module::replace_return(std::vector<instruction_ref> args)
{
auto last = std::prev(this->end());
......@@ -706,44 +769,33 @@ void module::print_graph(std::ostream& os, bool brief) const
os << "}" << std::endl;
}
static std::string cpp_var_name(const std::string& name)
static std::string to_c_id(const std::string& name, char rep = '_')
{
return "m" + replace_string(name, "@", "x");
std::string id = transform_string(name, [&](auto c) {
if(with_char(::isalnum)(c) or c == '_')
return c;
return rep;
});
while(contains(id, "__"))
replace_string_inplace(id, "__", "_");
return id;
}
static std::string cpp_op_var(const std::string& name, instruction_ref ins)
static std::string cpp_var_name(const std::string& name)
{
return replace_string(name, "@", ins->name());
return to_c_id("x_" + replace_string(name, ":", "_module_"));
}
static void print_op_attributes(std::ostream& os, const std::string& name, const operation& op)
static void print_make_op(std::ostream& os, const operation& op)
{
std::string x = to_string(op);
if(contains(x, "["))
os << "migraphx::make_op(" << enclose_name(op.name());
auto v = op.to_value();
if(not v.empty())
{
auto start = x.find('[');
auto end = x.find(']');
std::string attribute_text = x.substr(start + 1, end - start - 1);
std::vector<std::string> attributes;
for(auto&& attribute : split_string(attribute_text, ','))
{
if(contains(attribute, '='))
attributes.push_back(attribute);
else
attributes.back() += "," + attribute;
}
for(auto&& attribute : attributes)
{
auto p = split_string(attribute, '=');
auto key = p.front();
auto value = p.back();
if(contains({"bn_mode", "padding_mode"}, key))
continue;
if(key == "mode")
value = enclose_name(trim(value));
os << name << "." << key << " = " << value << ";" << std::endl;
}
os << ", "
<< "migraphx::from_json_string(" << enclose_name(to_json_string(v)) << ")";
}
os << ")";
}
static void print_cpp_shape(std::ostream& os, const migraphx::shape& s)
......@@ -756,22 +808,25 @@ static void print_cpp_shape(std::ostream& os, const migraphx::shape& s)
}
std::unordered_map<instruction_ref, std::string>
module::print_cpp(std::ostream& os, std::unordered_map<instruction_ref, std::string> names) const
module::print_cpp(std::ostream& os,
const std::string& mname,
std::unordered_map<instruction_ref, std::string> names) const
{
os << "migraphx::module p;" << std::endl;
unsigned long seed = 0;
// cppcheck-suppress variableScope
unsigned long seed = names.size();
auto last = std::prev(this->end());
names = this->print(
[&](auto ins, auto ins_names) {
auto op = cpp_op_var(ins_names.at(ins), ins);
if(ins->name().front() != '@')
{
os << "migraphx::op::" << ins->name() << " " << op << ";" << std::endl;
print_op_attributes(os, op, ins->get_operator());
}
os << "auto " << cpp_var_name(ins_names.at(ins)) << " = ";
std::vector<std::string> input_vars;
std::transform(ins->inputs().begin(),
ins->inputs().end(),
std::back_inserter(input_vars),
[&](auto input) { return cpp_var_name(ins_names.at(input)); });
if(ins != last)
os << "auto " << cpp_var_name(ins_names.at(ins)) << " = ";
if(ins->name() == "@literal")
{
os << "p.add_literal(";
os << mname << "->add_literal(";
bool use_abs = false;
ins->get_literal().visit([&](auto v) {
use_abs = std::none_of(v.begin(), v.end(), [](auto x) { return x < 0; });
......@@ -789,17 +844,22 @@ module::print_cpp(std::ostream& os, std::unordered_map<instruction_ref, std::str
else if(ins->name() == "@param")
{
std::string name = any_cast<builtin::param>(ins->get_operator()).parameter;
os << "p.add_parameter(" << enclose_name(name) << ",";
os << mname << "->add_parameter(" << enclose_name(name) << ",";
print_cpp_shape(os, ins->get_shape());
os << ");" << std::endl;
}
else if(ins->name() == "@return")
{
os << mname << "->add_return({";
os << join_strings(input_vars, ", ");
os << "});" << std::endl;
}
else
{
os << "p.add_instruction(" << op;
for(auto input : ins->inputs())
{
os << ", " << cpp_var_name(ins_names.at(input));
}
assert(ins->name().front() != '@');
os << mname << "->add_instruction(";
print_make_op(os, ins->get_operator());
os << ", " << join_strings(input_vars, ", ");
os << ");" << std::endl;
}
},
......@@ -808,7 +868,7 @@ module::print_cpp(std::ostream& os, std::unordered_map<instruction_ref, std::str
return names;
}
void module::print_cpp(std::ostream& os) const { this->print_cpp(os, {}); }
void module::print_cpp(std::ostream& os) const { this->print_cpp(os, this->name(), {}); }
void module::annotate(std::ostream& os, std::function<void(instruction_ref)> a) const
{
......
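print_cpp now emits one module_ref variable per module, names the local variables after the printed instruction names, and serializes operator attributes through from_json_string instead of assigning op struct fields. Roughly the kind of C++ it produces (illustrative only, not an exact dump; meant to be pasted into a function body):

migraphx::program p;
migraphx::module_ref mmain = p.get_main_module();
auto x_x = mmain->add_parameter("x", migraphx::shape{migraphx::shape::float_type, {2, 3}});
auto x_0 = mmain->add_instruction(
    migraphx::make_op("reduce_sum", migraphx::from_json_string("{\"axes\":[0]}")), x_x);
mmain->add_return({x_0});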
......@@ -159,6 +159,25 @@ instruction_ref program::validate() const
return mm->validate();
}
target_assignments program::get_target_assignments(const std::vector<target>& targets,
assignment_options options)
{
const auto m = options.metric;
target_assignments p;
const auto* mod = get_main_module();
for(auto it : iterator_for(*mod))
{
auto t = std::max_element(
targets.begin(), targets.end(), [it, m](const target& lhs, const target& rhs) {
return lhs.is_supported(it, m) < rhs.is_supported(it, m);
});
p.add_assignment(it, t->name());
}
return p;
}
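get_target_assignments picks, per instruction, the target whose is_supported score is highest under the requested metric. A hedged usage sketch; make_target and the "ref" target name come from the existing registry API:

#include <migraphx/program.hpp>
#include <migraphx/register_target.hpp>
#include <migraphx/instruction.hpp>
#include <iostream>
#include <vector>

void print_assignments(migraphx::program& p)
{
    std::vector<migraphx::target> targets = {migraphx::make_target("ref")};
    for(const auto& [ins, target_name] : p.get_target_assignments(targets))
        std::cout << ins->name() << " -> " << target_name << std::endl;
}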
bool program::is_compiled() const { return not this->impl->target_name.empty(); }
void program::compile(const target& t, compile_options options)
......@@ -542,12 +561,14 @@ static void mod_from_val(module_ref mod,
if(name == "@param")
{
output = mod->add_parameter(fields["parameter"].to<std::string>(),
migraphx::from_value<shape>(node.at("shape")));
output = mod->insert_parameter(mod->end(),
fields["parameter"].to<std::string>(),
migraphx::from_value<shape>(node.at("shape")));
}
else if(name == "@literal")
{
output = mod->add_literal(migraphx::from_value<literal>(node.at("literal")));
output =
mod->insert_literal(mod->end(), migraphx::from_value<literal>(node.at("literal")));
}
else
{
......@@ -582,11 +603,11 @@ static void mod_from_val(module_ref mod,
}
else if(module_inputs.empty())
{
output = mod->add_instruction(op, inputs);
output = mod->insert_instruction(mod->end(), op, inputs);
}
else
{
output = mod->add_instruction(op, inputs, module_inputs);
output = mod->insert_instruction(mod->end(), op, inputs, module_inputs);
}
}
output->set_normalized(normalized);
......@@ -828,10 +849,17 @@ void program::print_cpp(std::ostream& os) const
{
auto vec_modules = this->get_modules();
std::unordered_map<instruction_ref, std::string> names;
os << "migraphx::program p;\n";
for(auto& mod : vec_modules)
{
os << "module: \"" << mod->name() << "\"" << std::endl;
names = mod->print_cpp(os, names);
std::string var_name = "m" + mod->name();
os << "migraphx::module_ref " << var_name << " = ";
if(mod->name() == "main")
os << "p.get_main_module();";
else
os << "p.create_module(\"" << mod->name() << "\");";
os << std::endl;
names = mod->print_cpp(os, var_name, names);
os << std::endl;
}
}
......
......@@ -36,7 +36,7 @@ void raw_data_to_value(value& v, const RawData& rd)
result["shape"] = migraphx::to_value(rd.get_shape());
if(rd.get_shape().type() == shape::tuple_type)
result["sub"] = migraphx::to_value(rd.get_sub_objects());
else
else if(not rd.empty())
result["data"] = migraphx::value::binary(rd.data(), rd.get_shape().bytes());
v = result;
}
......@@ -56,7 +56,7 @@ void migraphx_from_value(const value& v, argument& a)
literal l = migraphx::from_value<literal>(v);
a = l.get_argument();
}
else
else if(v.contains("sub"))
{
a = migraphx::from_value<std::vector<argument>>(v.at("sub"));
}
......
......@@ -272,7 +272,7 @@ struct find_concat_transpose
{
auto matcher() const
{
return match::name("concat")(match::all_of[match::inputs()](match::transpose_shape()));
return match::name("concat")(match::all_of[match::inputs()](match::name("transpose")));
}
void apply(module& m, const match::matcher_result& mr) const
......@@ -601,6 +601,69 @@ struct find_transpose_contiguous_reshaper_unary
}
};
struct find_slice_transpose
{
auto matcher() const
{
return match::any(match::any_of[match::outputs()](
match::name("slice")(match::output(match::name("transpose")))));
}
static std::vector<int64_t> find_common_perm(const std::vector<instruction_ref>& transposes)
{
std::map<std::vector<int64_t>, int64_t> count;
for(auto t : transposes)
{
auto perm = t->get_operator().to_value()["permutation"].to_vector<int64_t>();
count[perm]++;
}
return std::max_element(
count.begin(), count.end(), by(std::less<>{}, [](auto&& p) { return p.second; }))
->first;
}
void apply(module& m, const match::matcher_result& r) const
{
auto ins = r.result;
std::vector<instruction_ref> splits;
std::copy_if(ins->outputs().begin(),
ins->outputs().end(),
std::back_inserter(splits),
[&](instruction_ref out) {
return out->name() == "slice" and out->outputs().size() == 1 and
out->outputs().front()->name() == "transpose";
});
if(splits.size() < 2)
return;
std::vector<instruction_ref> transposes;
std::transform(splits.begin(),
splits.end(),
std::back_inserter(transposes),
[](auto split) { return split->outputs().front(); });
auto perm = find_common_perm(transposes);
auto iperm = invert_permutation(perm);
auto pre = m.insert_instruction(
std::next(ins), make_op("transpose", {{"permutation", perm}}), ins);
for(auto i : range(transposes.size()))
{
auto split = splits[i];
auto t = transposes[i];
auto op = any_cast<op::slice>(split->get_operator());
std::transform(op.axes.begin(), op.axes.end(), op.axes.begin(), [&](auto axis) {
return iperm[axis];
});
auto new_ins = m.insert_instruction(t, op, pre);
if(t->get_operator() != pre->get_operator())
{
auto curr = t->get_operator().to_value()["permutation"].to_vector<int64_t>();
new_ins = m.insert_instruction(
t, make_op("transpose", {{"permutation", reorder_dims(iperm, curr)}}), new_ins);
}
m.replace_instruction(t, new_ins);
}
}
};
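A concrete case of the rewrite above: for x of shape {2, 8, 4}, two slices along axis 1 each followed by transpose with permutation {0, 2, 1} become a single transpose of x to {2, 4, 8} followed by slices along axis 2 (axis 1 mapped through the inverse permutation); any transpose whose permutation differs from the common one keeps a small corrective transpose computed with reorder_dims.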
void simplify_reshapes::apply(module& m) const
{
for(int i = 0; i < 2; i++)
......@@ -616,6 +679,7 @@ void simplify_reshapes::apply(module& m) const
find_nested_convert{},
find_nested_slice{},
find_nested_concat{},
find_slice_transpose{},
find_transpose_contiguous_reshaper_unary{});
dead_code_elimination{}.apply(m);
}
......
/*
* The MIT License (MIT)
*
* Copyright (c) 2015-2022 Advanced Micro Devices, Inc. All rights reserved.
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
* in the Software without restriction, including without limitation the rights
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
* copies of the Software, and to permit persons to whom the Software is
* furnished to do so, subject to the following conditions:
*
* The above copyright notice and this permission notice shall be included in
* all copies or substantial portions of the Software.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
* THE SOFTWARE.
*/
#include <migraphx/target_assignments.hpp>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
void target_assignments::add_assignment(instruction_ref ins, const std::string& target)
{
assignments.emplace(ins, target);
}
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
......@@ -25,6 +25,7 @@
#include <migraphx/module.hpp>
#include <migraphx/instruction.hpp>
#include <migraphx/iterator_for.hpp>
#include <migraphx/register_op.hpp>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
......@@ -52,6 +53,7 @@ struct cpu_literal
return os;
}
};
MIGRAPHX_REGISTER_OP(cpu_literal);
void write_literals::apply(module& m) const
{
......
......@@ -164,6 +164,7 @@ add_library(migraphx_gpu
deconvolution.cpp
device_name.cpp
elu.cpp
fuse_mlir.cpp
fuse_ops.cpp
gather.cpp
gemm_impl.cpp
......@@ -176,7 +177,7 @@ add_library(migraphx_gpu
loop.cpp
lrn.cpp
leaky_relu.cpp
mlir_conv.cpp
mlir.cpp
multinomial.cpp
nonzero.cpp
pack_args.cpp
......@@ -320,16 +321,26 @@ message(STATUS "extractkernel: ${MIGRAPHX_EXTRACT_KERNEL}")
set(MIGRAPHX_ENABLE_MLIR OFF CACHE BOOL "")
if(MIGRAPHX_ENABLE_MLIR)
find_library(LIBMLIRMIOPEN MLIRMIOpenThin REQUIRED)
find_library(MLIRAPI_LIBRARY MLIRMIOpen
PATH_SUFFIXES
# Workaround broken mlir install
lib/ lib/lib)
# REQUIRED is not supported before cmake 3.18
if(NOT LIBMLIRMIOPEN)
message(FATAL_ERROR "libMLIRMIOpenThin not found")
if(NOT MLIRAPI_LIBRARY)
message(FATAL_ERROR "libMLIRMIOpen not found")
else()
message(STATUS "Build with libMLIRMIOpenThin: " ${LIBMLIRMIOPEN})
message(STATUS "Build with libMLIRMIOpen: " ${MLIRAPI_LIBRARY})
endif()
target_compile_definitions(migraphx_gpu PRIVATE "-DMIGRAPHX_MLIR_MIOPEN_SUPPORT")
target_link_libraries(migraphx_gpu PUBLIC ${LIBMLIRMIOPEN})
find_path(MLIRAPI_HEADERS NAMES mlir-c/Dialect/MIGraphX.h)
# Workaround for broken MLIR installation
find_path(MLIRAPI_HEADERS2 NAMES mlir-c/Registration.h
PATH_SUFFIXES
include/external/include external/include)
target_compile_definitions(migraphx_gpu PRIVATE "-DMIGRAPHX_MLIR")
target_include_directories(migraphx_gpu SYSTEM PRIVATE ${MLIRAPI_HEADERS} ${MLIRAPI_HEADERS2})
target_link_libraries(migraphx_gpu PUBLIC ${MLIRAPI_LIBRARY})
endif()
set(MIGRAPHX_USE_HIPRTC OFF CACHE BOOL "")
......