Commit 30c49503 authored by Khalique Ahmed's avatar Khalique Ahmed
Browse files

manual merge

parents 870a396b 09aaa63e
...@@ -41,7 +41,7 @@ migraphx::instruction* as_address(const instruction_ref& ins) noexcept; ...@@ -41,7 +41,7 @@ migraphx::instruction* as_address(const instruction_ref& ins) noexcept;
namespace std { namespace std {
template <> template <>
struct hash<migraphx::instruction_ref> struct hash<migraphx::instruction_ref> // NOLINT
{ {
using argument_type = migraphx::instruction_ref; using argument_type = migraphx::instruction_ref;
using result_type = std::size_t; using result_type = std::size_t;
...@@ -52,7 +52,7 @@ struct hash<migraphx::instruction_ref> ...@@ -52,7 +52,7 @@ struct hash<migraphx::instruction_ref>
}; };
template <> template <>
struct equal_to<migraphx::instruction_ref> struct equal_to<migraphx::instruction_ref> // NOLINT
{ {
using argument_type = migraphx::instruction_ref; using argument_type = migraphx::instruction_ref;
using result_type = bool; using result_type = bool;
......
...@@ -36,22 +36,46 @@ template <class F> ...@@ -36,22 +36,46 @@ template <class F>
struct layernorm_matcher struct layernorm_matcher
{ {
F f; F f;
auto last_axis() const
{
return make_basic_pred_matcher([](instruction_ref ins) {
auto v = ins->get_operator().to_value();
if(not v.contains("axes"))
return false;
auto axes = v["axes"].to_vector<std::size_t>();
if(axes.size() != 1)
return false;
return axes.front() == ins->inputs().front()->get_shape().lens().size() - 1;
});
}
auto reduce_mean() const { return f("reduce_mean")(last_axis()); }
auto x_minus_mean() const auto x_minus_mean() const
{ {
return f("sub")(arg(0)(any().bind("x")), arg(1)(skip_broadcasts(f("reduce_mean")))); return f("sub")(arg(0)(any().bind("x")), arg(1)(skip_broadcasts(reduce_mean())));
} }
auto variance() const auto variance() const
{ {
return f("reduce_mean")(arg(0)(f("pow")(arg(0)(x_minus_mean()), arg(1)(has_value(2.0f))))); return reduce_mean()(arg(0)(any_of(
f("pow")(arg(0)(x_minus_mean()), arg(1)(has_value(2.0f))),
f("mul")(arg(0)(x_minus_mean()), arg(1)(x_minus_mean())),
f("sqdiff")(either_arg(0, 1)(any().bind("x"), skip_broadcasts(reduce_mean()))))));
} }
auto layernorm_onnx() const auto sqrt_add_eps(const std::string& name) const
{ {
return f("div")(arg(0)(x_minus_mean()), auto add_eps = f("add")(either_arg(0, 1)(variance(), is_constant().bind("eps")));
return skip_broadcasts(f(name)(arg(0)(any_of(add_eps, variance()))));
}
arg(1)(skip_broadcasts(f("sqrt")(arg(0)( auto layernorm_onnx() const
f("add")(either_arg(0, 1)(variance(), is_constant().bind("eps")))))))); {
auto div_sqrt = f("div")(arg(0)(x_minus_mean()), arg(1)(sqrt_add_eps("sqrt")));
auto mul_rsqrt = f("mul")(either_arg(0, 1)(x_minus_mean(), sqrt_add_eps("rsqrt")));
return any(any_of(div_sqrt, mul_rsqrt));
} }
auto matcher() const { return layernorm_onnx(); } auto matcher() const { return layernorm_onnx(); }
......
...@@ -33,13 +33,14 @@ inline namespace MIGRAPHX_INLINE_NS { ...@@ -33,13 +33,14 @@ inline namespace MIGRAPHX_INLINE_NS {
struct module; struct module;
/** /**
* Remove memory allocations. It uses graph coloring to find memory allocations that can be reused. * Remove multiple memory allocations using graph coloring to find memory allocations that can be
* reused.
*/ */
struct memory_coloring struct memory_coloring
{ {
std::string allocation_op{}; std::string allocation_op{};
bool verify = false; bool verify = false;
std::string name() const { return "memory coloring"; } std::string name() const { return "memory_coloring"; }
void apply(module& m) const; void apply(module& m) const;
}; };
......
...@@ -44,7 +44,7 @@ struct allocate ...@@ -44,7 +44,7 @@ struct allocate
std::string name() const { return "allocate"; } std::string name() const { return "allocate"; }
shape compute_shape(const std::vector<shape>& inputs) const shape compute_shape(const std::vector<shape>& inputs) const
{ {
migraphx::check_shapes{inputs, *this}.has(0); migraphx::check_shapes{inputs, *this, true}.has(0);
return s; return s;
} }
argument compute(const shape& output_shape, const std::vector<argument>&) const argument compute(const shape& output_shape, const std::vector<argument>&) const
......
...@@ -26,6 +26,7 @@ ...@@ -26,6 +26,7 @@
#include <array> #include <array>
#include <migraphx/check_shapes.hpp> #include <migraphx/check_shapes.hpp>
#include <migraphx/dyn_output.hpp>
#include <migraphx/stringutils.hpp> #include <migraphx/stringutils.hpp>
#include <migraphx/streamutils.hpp> #include <migraphx/streamutils.hpp>
#include <migraphx/literal.hpp> #include <migraphx/literal.hpp>
...@@ -73,49 +74,87 @@ struct concat ...@@ -73,49 +74,87 @@ struct concat
} }
return offsets; return offsets;
} }
shape normalize_compute_shape(std::vector<shape> inputs) const shape normalize_compute_shape(std::vector<shape> inputs) const
{ {
if(inputs.empty()) // inputs can contain 1 or more shapes (variadic). compute_shape_op ensures there must
// be at least 1.
check_shapes{inputs, *this, true}.same_ndims().same_type();
if(std::none_of(inputs.begin(), inputs.end(), [&](const shape& s) { return s.dynamic(); }))
{ {
MIGRAPHX_THROW("CONCAT: Number of input tensors should exceed 0"); // Static input shapes
const auto& first_shape_lens = inputs.front().lens();
const auto& type = inputs.front().type();
for(std::size_t ll = 0; ll < first_shape_lens.size(); ll++)
{
if(ll != axis)
{
if(not std::all_of(inputs.begin(), inputs.end(), [&](auto s) {
return s.lens()[ll] == first_shape_lens[ll];
}))
{
MIGRAPHX_THROW("CONCAT: all input dimensions should match along axis " +
std::to_string(ll));
}
}
}
std::size_t new_dim_axis = 0;
for(const auto& input : inputs)
{
const auto& lens = input.lens();
new_dim_axis += lens[axis];
}
std::vector<std::size_t> new_lens = first_shape_lens;
new_lens[axis] = new_dim_axis;
return shape::from_permutation(type, new_lens, find_permutation(inputs));
} }
else if(std::all_of(
const auto& first_shape_lens = inputs.front().lens(); inputs.begin(), inputs.end(), [&](const shape& s) { return s.dynamic(); }))
const auto& type = inputs.front().type();
for(std::size_t l = 0; l < first_shape_lens.size(); l++)
{ {
if(l != axis) // Dynamic input shapes
for(std::size_t index = 0; index < inputs[0].ndim(); index++)
{ {
if(not std::all_of(inputs.begin(), inputs.end(), [&](auto s) { if(index != axis)
return s.lens()[l] == first_shape_lens[l];
}))
{ {
MIGRAPHX_THROW("CONCAT: Non-axis dimensions should match"); if(not std::all_of(inputs.begin(), inputs.end(), [&](const shape& s) {
return s.dyn_dims()[index] == inputs[0].dyn_dims()[index];
}))
MIGRAPHX_THROW("CONCAT: all input dimensions should match in axis " +
std::to_string(index));
} }
} }
std::size_t new_min = 0;
std::size_t new_max = 0;
for(const auto& input : inputs)
{
auto ddim = input.dyn_dims()[axis];
new_min += ddim.min;
new_max += ddim.max;
}
auto new_dims = inputs[0].dyn_dims();
new_dims[axis] = migraphx::shape::dynamic_dimension{new_min, new_max, 0};
return {inputs[0].type(), new_dims};
} }
std::size_t new_dim_axis = 0; else
for(const auto& input : inputs)
{ {
const auto& lens = input.lens(); MIGRAPHX_THROW("CONCAT: Cannot mix static and dynamic input shapes.");
new_dim_axis += lens[axis];
} }
std::vector<std::size_t> new_lens;
std::copy(first_shape_lens.begin(), first_shape_lens.end(), std::back_inserter(new_lens));
new_lens[axis] = new_dim_axis;
return shape::from_permutation(type, new_lens, find_permutation(inputs));
} }
argument compute(const shape& output_shape, std::vector<argument> args) const
argument compute(const dyn_output& dyn_out, std::vector<argument> args) const
{ {
argument result{output_shape}; argument result{dyn_out.computed_shape};
std::vector<std::size_t> coffsets = compute_offsets(output_shape, args); std::vector<std::size_t> coffsets = compute_offsets(dyn_out.computed_shape, args);
for(std::size_t l = 0; l < args.size(); l++) for(std::size_t l = 0; l < args.size(); l++)
{ {
auto argl = args[l]; auto argl = args[l];
visit_all(result, argl)([&](auto output, auto input) { visit_all(result, argl)([&](auto output, auto input) {
auto slice_shape = auto slice_shape = shape{dyn_out.computed_shape.type(),
shape{output_shape.type(), input.get_shape().lens(), output_shape.strides()}; input.get_shape().lens(),
auto slice = make_view(slice_shape, output.data() + coffsets[l]); dyn_out.computed_shape.strides()};
auto slice = make_view(slice_shape, output.data() + coffsets[l]);
std::copy(input.begin(), input.end(), slice.begin()); std::copy(input.begin(), input.end(), slice.begin());
}); });
} }
......
...@@ -26,6 +26,7 @@ ...@@ -26,6 +26,7 @@
#include <array> #include <array>
#include <migraphx/check_shapes.hpp> #include <migraphx/check_shapes.hpp>
#include <migraphx/dyn_output.hpp>
#include <migraphx/stringutils.hpp> #include <migraphx/stringutils.hpp>
#include <migraphx/streamutils.hpp> #include <migraphx/streamutils.hpp>
#include <migraphx/literal.hpp> #include <migraphx/literal.hpp>
...@@ -61,35 +62,59 @@ struct gather ...@@ -61,35 +62,59 @@ struct gather
shape normalize_compute_shape(std::vector<shape> inputs) const shape normalize_compute_shape(std::vector<shape> inputs) const
{ {
check_shapes{inputs, *this}.has(2); check_shapes{inputs, *this, true}.has(2);
auto lens = inputs[0].lens(); shape data = inputs[0];
auto type = inputs[0].type(); shape indices = inputs[1];
lens.erase(lens.begin() + axis); auto type = data.type();
if(not inputs[1].scalar()) // If index_dims is dynamic, convert the data to dynamic too.
if(indices.dynamic())
{ {
auto ind_lens = inputs[1].lens(); data = data.to_dynamic();
lens.insert(lens.begin() + axis, ind_lens.begin(), ind_lens.end());
} }
if(data.dynamic())
// for scalar output
if(lens.empty())
{ {
return {type}; auto dims = data.dyn_dims();
dims.erase(dims.begin() + axis);
if(not indices.scalar())
{
auto index_dims = indices.to_dynamic().dyn_dims();
dims.insert(dims.begin() + axis, index_dims.begin(), index_dims.end());
}
return {type, dims};
} }
else
{
// Both data and indices are static. indices may be scalar
auto lens = data.lens();
lens.erase(lens.begin() + axis);
return {type, lens}; if(not indices.scalar())
{
auto ind_lens = indices.lens();
lens.insert(lens.begin() + axis, ind_lens.begin(), ind_lens.end());
}
// for scalar output
if(lens.empty())
{
return {type};
}
return {type, lens};
}
} }
argument compute(const shape& output_shape, std::vector<argument> args) const argument compute(const dyn_output& dyn_out, std::vector<argument> args) const
{ {
argument result{output_shape}; argument result{dyn_out.computed_shape};
// negative axis means counting dimensions from back // negative axis means counting dimensions from back
auto lens = args[0].get_shape().lens(); auto lens = args[0].get_shape().lens();
std::size_t axis_dim_size = lens[axis]; std::size_t axis_dim_size = lens[axis];
// max dimension in axis // max dimension in axis
visit_all(result, args[0])([&](auto output, auto data) { visit_all(result, args[0])([&](auto output, auto data) {
args[1].visit([&](auto indices) { args[1].visit([&](auto indices) {
if(output_shape.scalar()) if(dyn_out.computed_shape.scalar())
{ {
auto in_index = indices.front(); auto in_index = indices.front();
in_index = (in_index < 0) ? in_index + axis_dim_size : in_index; in_index = (in_index < 0) ? in_index + axis_dim_size : in_index;
......
/* /*
* The MIT License (MIT) * The MIT License (MIT)
* *
* Copyright (c) 2015-2022 Advanced Micro Devices, Inc. All rights reserved. * Copyright (c) 2015-2023 Advanced Micro Devices, Inc. All rights reserved.
* *
* Permission is hereby granted, free of charge, to any person obtaining a copy * Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal * of this software and associated documentation files (the "Software"), to deal
...@@ -25,6 +25,7 @@ ...@@ -25,6 +25,7 @@
#define MIGRAPHX_GUARD_OPERATORS_GATHERND_HPP #define MIGRAPHX_GUARD_OPERATORS_GATHERND_HPP
#include <migraphx/check_shapes.hpp> #include <migraphx/check_shapes.hpp>
#include <migraphx/dyn_output.hpp>
#include <migraphx/shape_for_each.hpp> #include <migraphx/shape_for_each.hpp>
#include <migraphx/par_for.hpp> #include <migraphx/par_for.hpp>
#include <migraphx/argument.hpp> #include <migraphx/argument.hpp>
...@@ -47,33 +48,103 @@ struct gathernd ...@@ -47,33 +48,103 @@ struct gathernd
shape compute_shape(std::vector<shape> inputs) const shape compute_shape(std::vector<shape> inputs) const
{ {
check_shapes{inputs, *this}.has(2); check_shapes{inputs, *this, true}.has(2);
auto r = inputs.front().lens().size(); auto i_shape = inputs.back();
auto q = inputs.back().lens().size(); auto data_shape = inputs.front();
auto k = inputs.back().lens().back(); auto r = data_shape.ndim();
auto q = i_shape.ndim();
size_t k;
if(i_shape.dynamic())
{
// the rank of the output is a function of k, so it must be fixed.
if(not i_shape.dyn_dims().back().is_fixed())
{
MIGRAPHX_THROW(
"GATHERND: last dimension of indices tensor must be fixed (min=max)");
}
k = i_shape.dyn_dims().back().min;
}
else
k = i_shape.lens().back();
// Begin input validation checks.
int output_ndim = int(q) + r - k - batch_dims - 1;
if(k > r - batch_dims) if(k > r - batch_dims)
{ {
MIGRAPHX_THROW("GATHERND: Indices of length " + std::to_string(k) + MIGRAPHX_THROW("GATHERND: Indices of length " + std::to_string(k) +
" cannot be used to access data of rank " + " cannot be used to access data of rank " +
std::to_string(r - batch_dims)); std::to_string(r - batch_dims));
} }
auto indices_lens_iter = inputs.back().lens().begin();
auto output_lens_size = q + r - k - batch_dims - 1; if(batch_dims >= q or batch_dims >= r)
std::vector<std::size_t> output_lens(output_lens_size); {
std::copy(indices_lens_iter, indices_lens_iter + (q - 1), output_lens.begin()); MIGRAPHX_THROW("GATHERND: rank of an input cannot be less than batch_dims=" +
if(k < r - batch_dims) std::to_string(batch_dims));
}
if(output_ndim < 0)
{
MIGRAPHX_THROW("GATHERND: Indices too large for static data input: k=" +
std::to_string(k));
}
if(migraphx::none_of(inputs, [](auto v) { return v.dynamic(); }))
{
auto indices_lens_iter = i_shape.lens().begin();
// A rank 0 output is a scalar
if(output_ndim == 0)
return shape{data_shape.type(), {1}};
// Part of the output shape comes from indices tensor, part from data tensor
std::vector<std::size_t> output_lens(output_ndim);
std::copy(indices_lens_iter, indices_lens_iter + (q - 1), output_lens.begin());
// fill the rest of output shape from data tensor
if(k + batch_dims < r)
{
auto data_lens = data_shape.lens();
std::copy(data_lens.begin() + batch_dims + k,
data_lens.end(),
output_lens.begin() + q - 1);
}
shape output_shape{data_shape.type(), output_lens};
return output_shape;
}
else
{ {
auto data_lens = inputs.front().lens(); // If one or both inputs are dynamic shapes, the output is dynamic.
std::copy( // Make both inputs dynamic to simplify computations.
data_lens.begin() + batch_dims + k, data_lens.end(), output_lens.begin() + q - 1); data_shape = data_shape.to_dynamic();
i_shape = i_shape.to_dynamic();
// A rank 0 output is a scalar
if(output_ndim == 0)
return shape(data_shape.type(), {shape::dynamic_dimension({1, 1, 0})});
// Part of the output shape comes from indices tensor, part from data tensor
std::vector<shape::dynamic_dimension> output_dims(output_ndim);
std::copy(i_shape.dyn_dims().begin(),
i_shape.dyn_dims().begin() + q - 1,
output_dims.begin());
// fill the rest of output shape from data tensor
if(k + batch_dims < r)
{
auto data_dims = data_shape.dyn_dims();
std::copy(data_dims.begin() + batch_dims + k,
data_dims.begin() + r,
output_dims.begin() + q - 1);
}
shape output_shape(data_shape.type(), output_dims);
return output_shape;
} }
shape output_shape{inputs.front().type(), output_lens};
return output_shape;
} }
argument compute(const shape& output_shape, std::vector<argument> args) const argument compute(const dyn_output& dyn_out, std::vector<argument> args) const
{ {
argument result{output_shape}; argument result{dyn_out.computed_shape};
visit_all(result, args[0])([&](auto output, auto data) { visit_all(result, args[0])([&](auto output, auto data) {
args[1].visit([&](auto indices) { args[1].visit([&](auto indices) {
auto indices_shape = indices.get_shape(); auto indices_shape = indices.get_shape();
......
...@@ -143,16 +143,22 @@ struct nonmaxsuppression ...@@ -143,16 +143,22 @@ struct nonmaxsuppression
void sort() void sort()
{ {
std::sort(x.begin(), x.end()); if(x[0] > x[1])
std::sort(y.begin(), y.end()); {
std::swap(x[0], x[1]);
}
if(y[0] > y[1])
{
std::swap(y[0], y[1]);
}
} }
std::array<double, 2>& operator[](std::size_t i) { return i == 0 ? x : y; } std::array<double, 2>& operator[](std::size_t i) { return i == 0 ? x : y; }
double area() const double area() const
{ {
assert(std::is_sorted(x.begin(), x.end())); assert(x[0] <= x[1]);
assert(std::is_sorted(y.begin(), y.end())); assert(y[0] <= y[1]);
return (x[1] - x[0]) * (y[1] - y[0]); return (x[1] - x[0]) * (y[1] - y[0]);
} }
}; };
...@@ -190,14 +196,10 @@ struct nonmaxsuppression ...@@ -190,14 +196,10 @@ struct nonmaxsuppression
{ {
intersection[i][0] = std::max(b1[i][0], b2[i][0]); intersection[i][0] = std::max(b1[i][0], b2[i][0]);
intersection[i][1] = std::min(b1[i][1], b2[i][1]); intersection[i][1] = std::min(b1[i][1], b2[i][1]);
} if(intersection[i][0] > intersection[i][1])
{
std::vector<std::array<double, 2>> bbox = {intersection.x, intersection.y}; return false;
if(std::any_of(bbox.begin(), bbox.end(), [](auto bx) { }
return not std::is_sorted(bx.begin(), bx.end());
}))
{
return false;
} }
const double area1 = b1.area(); const double area1 = b1.area();
...@@ -265,31 +267,31 @@ struct nonmaxsuppression ...@@ -265,31 +267,31 @@ struct nonmaxsuppression
auto batch_boxes_start = boxes.begin() + batch_idx * num_boxes * 4; auto batch_boxes_start = boxes.begin() + batch_idx * num_boxes * 4;
auto boxes_heap = filter_boxes_by_score(scores_start, num_boxes, score_threshold); auto boxes_heap = filter_boxes_by_score(scores_start, num_boxes, score_threshold);
selected_boxes_inside_class.clear(); selected_boxes_inside_class.clear();
// Get the next box with top score, filter by iou_threshold
while(not boxes_heap.empty() && while(not boxes_heap.empty() &&
selected_boxes_inside_class.size() < max_output_boxes_per_class) selected_boxes_inside_class.size() < max_output_boxes_per_class)
{ {
// Check with existing selected boxes for this class, remove box if it // select next top scorer box and remove any boxes from boxes_heap that exceeds IOU
// exceeds the IOU (Intersection Over Union) threshold // threshold with the selected box
const auto next_top_score = boxes_heap.top(); const auto next_top_score = boxes_heap.top();
bool not_selected = boxes_heap.pop();
std::any_of(selected_boxes_inside_class.begin(), selected_boxes_inside_class.push_back(next_top_score);
selected_boxes_inside_class.end(), selected_indices.push_back(batch_idx);
[&](auto selected_index) { selected_indices.push_back(class_idx);
return this->suppress_by_iou( selected_indices.push_back(next_top_score.second);
batch_box(batch_boxes_start, next_top_score.second), std::priority_queue<std::pair<double, int64_t>> remainder_boxes;
batch_box(batch_boxes_start, selected_index.second), while(not boxes_heap.empty())
iou_threshold);
});
if(not not_selected)
{ {
selected_boxes_inside_class.push_back(next_top_score); auto iou_candidate_box = boxes_heap.top();
selected_indices.push_back(batch_idx); if(not this->suppress_by_iou(
selected_indices.push_back(class_idx); batch_box(batch_boxes_start, iou_candidate_box.second),
selected_indices.push_back(next_top_score.second); batch_box(batch_boxes_start, next_top_score.second),
iou_threshold))
{
remainder_boxes.push(iou_candidate_box);
}
boxes_heap.pop();
} }
boxes_heap.pop(); boxes_heap = remainder_boxes;
} }
}); });
std::copy(selected_indices.begin(), selected_indices.end(), output.begin()); std::copy(selected_indices.begin(), selected_indices.end(), output.begin());
......
...@@ -31,18 +31,30 @@ namespace migraphx { ...@@ -31,18 +31,30 @@ namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS { inline namespace MIGRAPHX_INLINE_NS {
namespace op { namespace op {
// different attributes /**
// 1) use_input(default)/use_output * `normalize_attribute` settings:
// 2) use_rank(default)/use_len * Note that default options are not included as enums.
// 3) clip_min(default)/not_clip_min * 1. `use_input` (default) vs. `use_output`:
// 3.1) include_min(default)/exclude_min * Affects the rank of the attribute.
// 4) clip_max(default)/not_clip_max * `use_input -> lens.size()`, `use_output -> lens.size() + vec.size()`.
// 4.1) exclude_max(default)/include_max * 2. use_rank (default) vs use_len:
// 5) normalize padding * `use_rank` sets the max value/index of the attribute as the rank of lens.
* `use_lens` sets the max value/index as the corresponding value in lens at the axes index.
* 3. `clip_min` vs. `not_clip_min` (default):
* Clip values less than the minimum to the minimum or not.
* 4. `include_min` vs. `exclude_min` (default):
* Include or exclude the minimum value/index for range checking and clipping.
* 5. `clip_max` vs. `not_clip_max` (default):
* Clip values greater than the maximum or not.
* 6. `include_max` vs. `exclude_max` (default):
* Include or exclude the maximum value/index for range checking and clipping.
* 7. `normalize_padding`:
* To normalize the padding to `2*(pad ndim)` dimensions.
*/
enum class normalize_attribute enum class normalize_attribute
{ {
use_len,
use_output, use_output,
use_len,
clip_max, clip_max,
clip_min, clip_min,
include_max, include_max,
......
...@@ -28,6 +28,7 @@ ...@@ -28,6 +28,7 @@
#include <vector> #include <vector>
#include <cmath> #include <cmath>
#include <utility> #include <utility>
#include <migraphx/check_shapes.hpp>
#include <migraphx/config.hpp> #include <migraphx/config.hpp>
#include <migraphx/argument.hpp> #include <migraphx/argument.hpp>
#include <migraphx/op/normalize_attribute.hpp> #include <migraphx/op/normalize_attribute.hpp>
...@@ -60,6 +61,7 @@ struct reverse ...@@ -60,6 +61,7 @@ struct reverse
shape normalize_compute_shape(std::vector<shape> inputs) const shape normalize_compute_shape(std::vector<shape> inputs) const
{ {
check_shapes{inputs, *this}.has(1);
return inputs[0].with_lens(inputs[0].lens()); return inputs[0].with_lens(inputs[0].lens());
} }
......
...@@ -28,44 +28,89 @@ ...@@ -28,44 +28,89 @@
#include <migraphx/check_shapes.hpp> #include <migraphx/check_shapes.hpp>
#include <migraphx/argument.hpp> #include <migraphx/argument.hpp>
#include <migraphx/par_for.hpp> #include <migraphx/par_for.hpp>
#include <migraphx/ranges.hpp>
namespace migraphx { namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS { inline namespace MIGRAPHX_INLINE_NS {
namespace op { namespace op {
/**
* @brief
* N-dimensional Scatter operations. This struct is parent class to ops which differ in what formula
* is used to reduce (combine old and new values of) the scattered value. It was originally based
* on Onnx ScatterND operation (see
* https://github.com/onnx/onnx/blob/main/docs/Operators.md#ScatterND) and is also similar to Numpy
* numpy.add.at().
*
* @tparam Derived a template parameter in the CRTP inheritance idiom, represents one of the child
* operations.
*/
template <class Derived> template <class Derived>
struct scatternd_op : op_name<Derived> struct scatternd_op : op_name<Derived>
{ {
/** Validate input shapes and return the correct output shape. For Scatter ops, the output
* is the same shape as the data tensor (first input), but cast to a standard shape.
*
*/
shape compute_shape(std::vector<shape> inputs) const shape compute_shape(std::vector<shape> inputs) const
{ {
check_shapes{inputs, *this}.has(3); check_shapes{inputs, *this, true}.has(3);
auto r = inputs.front().lens().size(); auto data_shape = inputs.front();
auto q = inputs.at(1).lens().size(); auto index_shape = inputs.at(1);
auto k = inputs.at(1).lens().back(); auto upd_shape = inputs.back();
auto ind_lens = inputs.at(1).lens();
auto upd_lens = inputs.back().lens(); auto r = data_shape.ndim();
auto data_lens = inputs.front().lens(); auto q = index_shape.ndim();
size_t k;
if(index_shape.dynamic())
{
// the rank of the output is a function of k, so k must be fixed.
if(not index_shape.dyn_dims().back().is_fixed())
{
MIGRAPHX_THROW(
"GATHERND: last dimension of indices tensor must be fixed (min=max)");
}
k = index_shape.dyn_dims().back().min;
}
else
k = index_shape.lens().back();
// Checks on the sizes of input tensors
if(q + r != upd_shape.ndim() + k + 1)
MIGRAPHX_THROW("ScatterND: ranks of inputs don't match. " + std::to_string(q) + " + " +
std::to_string(r) + " - " + std::to_string(k) +
" - 1 != " + std::to_string(upd_shape.ndim()));
if(k > r) if(k > r)
MIGRAPHX_THROW("ScatterND: index of size " + std::to_string(k) + MIGRAPHX_THROW("ScatterND: index of size " + std::to_string(k) +
" is too large for tensor of rank " + std::to_string(r)); " is too large for tensor of rank " + std::to_string(r));
if(not(std::equal(ind_lens.begin(), ind_lens.begin() + q - 1, upd_lens.begin()) and
std::equal(data_lens.begin() + k, data_lens.end(), upd_lens.begin() + q - 1))) // Convert all static shape dimensions to dynamic so they can be compared.
MIGRAPHX_THROW("ScatterND: incorrect update shape. update.lens != indices.lens[0:q-1] " // It's possible for some of the 3 inputs to be dynamic shapes and some static,
"++ data.lens[k:r-1]"); // but any dynamic dimension that's compared to a static dimension must be fixed.
auto s = inputs.front(); auto ind_dims = index_shape.to_dynamic().dyn_dims();
if(s.broadcasted()) auto upd_dims = upd_shape.to_dynamic().dyn_dims();
auto data_dims = data_shape.to_dynamic().dyn_dims();
// Check that corresponding portions of tensor shapes match.
if(not(std::equal(ind_dims.begin(), ind_dims.begin() + q - 1, upd_dims.begin()) and
std::equal(data_dims.begin() + k, data_dims.end(), upd_dims.begin() + q - 1)))
MIGRAPHX_THROW("ScatterND: incorrect update shape. Update dimensions must match "
"indices and data.");
if(data_shape.dynamic())
return data_shape;
else if(data_shape.broadcasted())
{ {
return {s.type(), s.lens()}; return {data_shape.type(), data_shape.lens()};
} }
else else
{ {
return s.with_lens(s.lens()); return data_shape.with_lens(data_shape.lens());
} }
} }
argument compute(const shape& output_shape, std::vector<argument> args) const argument compute(const dyn_output& dyn_out, std::vector<argument> args) const
{ {
argument result{output_shape}; argument result{dyn_out.computed_shape};
auto& self = static_cast<const Derived&>(*this); auto& self = static_cast<const Derived&>(*this);
visit_all(result, args[0], args[2])([&](auto output, auto data, auto updates) { visit_all(result, args[0], args[2])([&](auto output, auto data, auto updates) {
std::copy(data.begin(), data.end(), output.begin()); std::copy(data.begin(), data.end(), output.begin());
...@@ -74,8 +119,8 @@ struct scatternd_op : op_name<Derived> ...@@ -74,8 +119,8 @@ struct scatternd_op : op_name<Derived>
auto updates_std = shape{updates_shape.type(), updates_shape.lens()}; auto updates_std = shape{updates_shape.type(), updates_shape.lens()};
auto indices_shape = indices.get_shape(); auto indices_shape = indices.get_shape();
auto k = indices_shape.lens().back(); auto k = indices_shape.lens().back();
auto q = indices_shape.lens().size(); auto q = indices_shape.ndim();
auto r = output_shape.lens().size(); auto r = dyn_out.computed_shape.ndim();
par_for(updates_shape.elements(), [&](const auto i) { par_for(updates_shape.elements(), [&](const auto i) {
auto updates_idx = updates_std.multi(i); auto updates_idx = updates_std.multi(i);
std::vector<std::size_t> indices_idx(q, 0); std::vector<std::size_t> indices_idx(q, 0);
...@@ -89,7 +134,7 @@ struct scatternd_op : op_name<Derived> ...@@ -89,7 +134,7 @@ struct scatternd_op : op_name<Derived>
std::copy(index_start, index_end, out_idx.begin()); std::copy(index_start, index_end, out_idx.begin());
std::copy(updates_idx.begin() + q - 1, updates_idx.end(), out_idx.begin() + k); std::copy(updates_idx.begin() + q - 1, updates_idx.end(), out_idx.begin() + k);
self.reduction()(output[output_shape.index(out_idx)], updates[i]); self.reduction()(output[dyn_out.computed_shape.index(out_idx)], updates[i]);
}); });
}); });
}); });
......
/*
* The MIT License (MIT)
*
* Copyright (c) 2015-2023 Advanced Micro Devices, Inc. All rights reserved.
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
* in the Software without restriction, including without limitation the rights
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
* copies of the Software, and to permit persons to whom the Software is
* furnished to do so, subject to the following conditions:
*
* The above copyright notice and this permission notice shall be included in
* all copies or substantial portions of the Software.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
* THE SOFTWARE.
*/
#ifndef MIGRAPHX_GUARD_OPERATORS_SELECT_MODULE_HPP
#define MIGRAPHX_GUARD_OPERATORS_SELECT_MODULE_HPP
#include <migraphx/check_shapes.hpp>
#include <migraphx/module.hpp>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
namespace op {
/**
 * Operation that selects one submodule out of a list and runs it.
 * The submodule is chosen by matching the shapes of the runtime input
 * arguments against each submodule's input parameter shapes, which lets a
 * program with a dynamic output shape dispatch to a statically-shaped body.
 */
struct select_module
{
    // Dynamic shape(s) this operation advertises at compile time; the actual
    // output comes from whichever submodule is selected at run time.
    shape output_dyn_shapes;

    template <class Self, class F>
    static auto reflect(Self& self, F f)
    {
        return pack(f(self.output_dyn_shapes, "output_dyn_shapes"));
    }

    std::string name() const { return "select_module"; }

    /// The compile-time shape is independent of the inputs: it is whatever
    /// dynamic shape was stored in the attribute.
    shape compute_shape(const std::vector<shape>& inputs, const std::vector<module_ref>&) const
    {
        check_shapes{inputs, *this, true}.has_at_least(1);
        return shape{output_dyn_shapes};
    }

    /// Names of a submodule's input parameters; inputs are the parameters
    /// whose names do NOT carry the "#output_" tag.
    std::vector<std::string> get_input_parameter_names(module_ref mod) const
    {
        auto param_names = mod->get_parameter_names();
        std::vector<std::string> ret;
        // take each name by const reference to avoid a per-element string copy
        std::copy_if(param_names.cbegin(),
                     param_names.cend(),
                     std::back_inserter(ret),
                     [](const auto& pn) { return not contains(pn, "#output_"); });
        return ret;
    }

    /// Names of a submodule's output parameters (those tagged "#output_").
    std::vector<std::string> get_output_parameter_names(module_ref mod) const
    {
        auto param_names = mod->get_parameter_names();
        std::vector<std::string> ret;
        // take each name by const reference to avoid a per-element string copy
        std::copy_if(param_names.cbegin(),
                     param_names.cend(),
                     std::back_inserter(ret),
                     [](const auto& pn) { return contains(pn, "#output_"); });
        return ret;
    }

    /**
     * Pick the submodule whose input parameter shapes exactly match the shapes
     * of the incoming arguments, build its parameter map, and run it.
     * @param args      instruction arguments; the last one holds the tuple of
     *                  output buffers for the submodule's output parameters
     * @param submodule_list candidate submodules to select from
     * @param run       callback that executes a module with a parameter map
     * @throws MIGRAPHX_THROW when no submodule matches the input shapes
     */
    argument compute(const shape&,
                     const std::vector<argument>& args,
                     const std::vector<module_ref>& submodule_list,
                     const std::function<std::vector<argument>(
                         module_ref&, const std::unordered_map<std::string, argument>&)>& run) const
    {
        // Find submodule with input parameter shapes exactly the same as the input instruction
        // arguments. Assuming instruction arguments are in the same order as the instruction
        // parameters.
        auto module_iter =
            std::find_if(submodule_list.cbegin(), submodule_list.cend(), [&](module_ref mr) {
                auto in_param_names = get_input_parameter_names(mr);
                auto param_shapes   = mr->get_parameter_shapes();
                assert(in_param_names.size() <= args.size());
                return std::equal(in_param_names.cbegin(),
                                  in_param_names.cend(),
                                  args.cbegin(),
                                  // const refs: avoid copying the name string and argument
                                  [&](const auto& p_name, const auto& a) {
                                      return a.get_shape() == param_shapes[p_name];
                                  });
            });
        if(module_iter == submodule_list.cend())
        {
            MIGRAPHX_THROW("SELECT_MODULE: no compatible submodules found for given input shapes");
        }
        auto* module_to_run = *module_iter;
        std::unordered_map<std::string, argument> p_map;

        // add input parameters to parameter_map
        auto in_param_names = get_input_parameter_names(module_to_run);
        assert(in_param_names.size() <= args.size());
        std::transform(in_param_names.begin(),
                       in_param_names.end(),
                       args.begin(),
                       std::inserter(p_map, p_map.end()),
                       [&](auto&& pname, auto&& a) { return std::make_pair(pname, a); });

        // One tuple output parameter in main module to multiple output parameters in submodule
        auto out_param_names    = get_output_parameter_names(module_to_run);
        auto output_sub_objects = args.back().get_sub_objects();
        assert(out_param_names.size() == output_sub_objects.size());
        std::transform(out_param_names.begin(),
                       out_param_names.end(),
                       output_sub_objects.begin(),
                       std::inserter(p_map, p_map.end()),
                       [&](auto&& pname, auto&& a) {
                           auto ps = module_to_run->get_parameter_shape(pname);
                           // reshape the output buffer if the submodule expects a
                           // different (but same-sized) layout
                           if(a.get_shape() != ps)
                           {
                               assert(ps.bytes() == a.get_shape().bytes());
                               return std::make_pair(pname, a.reshape(ps));
                           }
                           else
                           {
                               return std::make_pair(pname, a);
                           }
                       });

        auto results = run(module_to_run, p_map);
        return argument{results};
    }

    /// The result aliases the last input argument (the output buffer tuple).
    std::ptrdiff_t output_alias(const std::vector<shape>& shapes) const
    {
        return shapes.size() - 1;
    }
};
} // namespace op
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
#endif
/* /*
* The MIT License (MIT) * The MIT License (MIT)
* *
* Copyright (c) 2015-2022 Advanced Micro Devices, Inc. All rights reserved. * Copyright (c) 2015-2023 Advanced Micro Devices, Inc. All rights reserved.
* *
* Permission is hereby granted, free of charge, to any person obtaining a copy * Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal * of this software and associated documentation files (the "Software"), to deal
...@@ -27,6 +27,7 @@ ...@@ -27,6 +27,7 @@
#include <migraphx/check_shapes.hpp> #include <migraphx/check_shapes.hpp>
#include <migraphx/argument.hpp> #include <migraphx/argument.hpp>
#include <migraphx/config.hpp> #include <migraphx/config.hpp>
#include <migraphx/dyn_output.hpp>
#include <migraphx/value.hpp> #include <migraphx/value.hpp>
#include <migraphx/op/normalize_attribute.hpp> #include <migraphx/op/normalize_attribute.hpp>
...@@ -46,6 +47,10 @@ struct slice ...@@ -46,6 +47,10 @@ struct slice
return pack(f(self.axes, "axes"), f(self.starts, "starts"), f(self.ends, "ends")); return pack(f(self.axes, "axes"), f(self.starts, "starts"), f(self.ends, "ends"));
} }
/**
* Ensure that attribute vectors axes, starts, and ends are all the same size and values are in
* limits.
*/
value attributes() const value attributes() const
{ {
value normalize = value::object{}; value normalize = value::object{};
...@@ -65,14 +70,6 @@ struct slice ...@@ -65,14 +70,6 @@ struct slice
std::string name() const { return "slice"; } std::string name() const { return "slice"; }
auto fix_index(const std::vector<std::size_t>& lens, std::size_t axis, int64_t index) const
{
int64_t r = std::min(index, static_cast<int64_t>(lens[axis]));
if(r < 0)
r += lens[axis];
return std::size_t(r);
}
auto compute_offset(const shape& s) const auto compute_offset(const shape& s) const
{ {
const std::vector<std::size_t>& lens = s.lens(); const std::vector<std::size_t>& lens = s.lens();
...@@ -83,14 +80,14 @@ struct slice ...@@ -83,14 +80,14 @@ struct slice
for(std::size_t i = 0; i < axes.size(); i++) for(std::size_t i = 0; i < axes.size(); i++)
{ {
auto axis = axes[i]; auto axis = axes[i];
offset += fix_index(lens, axis, starts[i]) * strides[axis]; offset += starts[i] * strides[axis];
} }
} }
else else
{ {
for(std::size_t axis = 0; axis < lens.size(); axis++) for(std::size_t axis = 0; axis < lens.size(); axis++)
{ {
offset += fix_index(lens, axis, starts[axis]) * strides[axis]; offset += starts[axis] * strides[axis];
} }
} }
return offset; return offset;
...@@ -98,37 +95,81 @@ struct slice ...@@ -98,37 +95,81 @@ struct slice
shape normalize_compute_shape(std::vector<shape> inputs) const shape normalize_compute_shape(std::vector<shape> inputs) const
{ {
auto input_shape = inputs[0]; check_shapes{inputs, *this, true}.has(1);
auto t = input_shape.type(); auto input_shape = inputs[0];
const auto& old_lens = input_shape.lens(); auto t = input_shape.type();
const auto& old_strides = input_shape.strides();
if(std::any_of( // TODO: When support for dynamic shapes is added to normalize_attributes,
axes.begin(), axes.end(), [&](auto i) { return (i >= old_lens.size() and i < 0); })) // remove this restriction.
if(input_shape.dynamic() and std::any_of(axes.begin(), axes.end(), [&](auto axis) {
return not input_shape.dyn_dims()[axis].is_fixed();
}))
{ {
MIGRAPHX_THROW("SLICE: input axis " + to_string_range(axes) + " out of range"); MIGRAPHX_THROW("SLICE: slicing is not allowed on non-fixed dynamic input axis ");
} }
if(starts.size() != axes.size() or axes.size() != ends.size()) // For a static shape, old_lens will be adjusted to a new size
// for those axes that are sliced.
// For dynamic shape, the adjusted old_lens become the new max values,
// while updating the old mins and opts if possible.
std::vector<std::size_t> new_mins;
std::vector<std::size_t> new_opts;
std::vector<std::size_t> old_lens;
std::vector<std::size_t> old_strides;
if(input_shape.dynamic())
{
old_lens = input_shape.max_lens();
new_mins = input_shape.min_lens();
new_opts = input_shape.opt_lens();
}
else
{ {
MIGRAPHX_THROW("SLICE: inconsistent sizes"); old_lens = input_shape.lens();
// For static shape (including during eval step after a dynamic input) the strides are
// indexed into the pre-slice array, so they are larger than the apparent size of the
// resulting shape.
old_strides = input_shape.strides();
} }
std::vector<std::size_t> new_lens = old_lens; std::vector<std::size_t> new_lens = old_lens;
for(std::size_t i = 0; i < axes.size(); i++) for(std::size_t i = 0; i < axes.size(); i++)
{ {
auto axis = axes[i]; auto axis = axes[i];
new_lens[axis] = size_t sliced_length = ends[i] - starts[i];
fix_index(old_lens, axis, ends[i]) - fix_index(old_lens, axis, starts[i]); // A Numpy indexing convention: a slice size larger than the actual dimension
// is legal and the "ends" value is clipped to the axis size
new_lens[axis] = std::min(new_lens[axis], sliced_length);
if(input_shape.dynamic())
{
// TODO: when non-fixed shape slicing is allowed, this will be different than
// sliced_length, making use of TBD start/end values.
std::size_t sliced_min_length = ends[i] - starts[i];
// if the slice size is smaller than maxes but larger than mins
new_mins[axis] = std::min(sliced_min_length, new_mins[axis]);
auto sliced_opt_length = ends[i] - starts[i];
if(new_opts[axis] != 0)
new_opts[axis] = sliced_opt_length;
if(new_opts[axis] < new_mins[axis] or new_opts[axis] > new_lens[axis])
new_opts[axis] = 0;
}
}
if(input_shape.dynamic())
{
return shape{t, new_mins, new_lens, new_opts};
}
else
{
return shape{t, new_lens, old_strides};
} }
return shape{t, new_lens, old_strides};
} }
argument compute(shape output_shape, std::vector<argument> args) const argument compute(const dyn_output& dyn_out, std::vector<argument> args) const
{ {
auto input = args[0]; auto input = args[0];
auto offset = compute_offset(input.get_shape()) * output_shape.type_size();
return {std::move(output_shape), [=] { return input.data() + offset; }}; auto offset = compute_offset(input.get_shape()) * dyn_out.computed_shape.type_size();
return {dyn_out.computed_shape, [=] { return input.data() + offset; }};
} }
std::ptrdiff_t output_alias(const std::vector<shape>&) const { return 0; } std::ptrdiff_t output_alias(const std::vector<shape>&) const { return 0; }
}; };
......
/* /*
* The MIT License (MIT) * The MIT License (MIT)
* *
* Copyright (c) 2015-2022 Advanced Micro Devices, Inc. All rights reserved. * Copyright (c) 2015-2023 Advanced Micro Devices, Inc. All rights reserved.
* *
* Permission is hereby granted, free of charge, to any person obtaining a copy * Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal * of this software and associated documentation files (the "Software"), to deal
...@@ -42,9 +42,17 @@ struct where ...@@ -42,9 +42,17 @@ struct where
shape compute_shape(std::vector<shape> inputs) const shape compute_shape(std::vector<shape> inputs) const
{ {
check_shapes{inputs, *this}.has(3).same_dims(); check_shapes{inputs, *this, true}.has(3).same_dims();
auto s1 = inputs.at(1); auto s1 = inputs.at(1);
auto s2 = inputs.at(2); auto s2 = inputs.at(2);
if(s1.dynamic() or s2.dynamic())
{
if(s1 == s2)
return s1;
MIGRAPHX_THROW("WHERE: dynamic input shapes must be the same");
}
// Compare two static shapes, returning a standard shape
if(s1 == s2 and s1.packed()) if(s1 == s2 and s1.packed())
{ {
return s1; return s1;
...@@ -63,12 +71,12 @@ struct where ...@@ -63,12 +71,12 @@ struct where
} }
} }
argument compute(const shape& output_shape, std::vector<argument> args) const argument compute(const dyn_output& dyn_out, std::vector<argument> args) const
{ {
argument result{output_shape}; argument result{dyn_out.computed_shape};
visit_all(result, args[1], args[2])([&](auto output, const auto x, const auto y) { visit_all(result, args[1], args[2])([&](auto output, const auto x, const auto y) {
args[0].visit([&](const auto condition) { args[0].visit([&](const auto condition) {
par_for(output_shape.elements(), par_for(dyn_out.computed_shape.elements(),
[&](auto i) { output[i] = condition[i] ? x[i] : y[i]; }); [&](auto i) { output[i] = condition[i] ? x[i] : y[i]; });
}); });
}); });
......
...@@ -140,6 +140,8 @@ template <class T> ...@@ -140,6 +140,8 @@ template <class T>
auto compute_shape_op(rank<2>, const T& x, const std::vector<shape>& inputs) auto compute_shape_op(rank<2>, const T& x, const std::vector<shape>& inputs)
-> decltype(x.normalize_compute_shape(inputs)) -> decltype(x.normalize_compute_shape(inputs))
{ {
if(inputs.empty())
MIGRAPHX_THROW("At least one input is required for " + x.name());
dependent_type<operation, T> y = x; dependent_type<operation, T> y = x;
normalize_attributes(y, inputs[0].max_lens()); normalize_attributes(y, inputs[0].max_lens());
return any_cast<T>(y).normalize_compute_shape(inputs); return any_cast<T>(y).normalize_compute_shape(inputs);
......
...@@ -21,18 +21,28 @@ ...@@ -21,18 +21,28 @@
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
* THE SOFTWARE. * THE SOFTWARE.
*/ */
#ifndef MIGRAPHX_GUARD_RTGLIB_OPTIMIZE_MODULE_HPP
#define MIGRAPHX_GUARD_RTGLIB_OPTIMIZE_MODULE_HPP
#ifndef MIGRAPHX_GUARD_PASS_CONFIG_HPP #include <string>
#define MIGRAPHX_GUARD_PASS_CONFIG_HPP #include <migraphx/instruction_ref.hpp>
#include <migraphx/env.hpp>
#include <migraphx/config.hpp> #include <migraphx/config.hpp>
namespace migraphx { namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS { inline namespace MIGRAPHX_INLINE_NS {
MIGRAPHX_DECLARE_ENV_VAR(MIGRAPHX_DISABLE_MEMORY_COLORING) struct module_pass_manager;
/**
* Runs several passes in a loop
*/
struct optimize_module
{
std::string name() const { return "optimize_module"; }
void apply(module_pass_manager& mpm) const;
};
} // namespace MIGRAPHX_INLINE_NS } // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx } // namespace migraphx
#endif // MIGRAPHX_GUARD_PASS_CONFIG_HPP
#endif
...@@ -33,15 +33,36 @@ ...@@ -33,15 +33,36 @@
namespace migraphx { namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS { inline namespace MIGRAPHX_INLINE_NS {
// unregister all ops for specified target, useful when unloading dynamically plugged-in target lib
void unregister_op(const std::string& op_name);
namespace detail {
struct op_handler
{
operation op;
std::string name;
op_handler(const operation& op_r) : op(op_r), name(op.name()){};
~op_handler() { unregister_op(name); }
};
} // namespace detail
void register_op_init();
void register_op(const operation& op); void register_op(const operation& op);
operation load_op(const std::string& name); operation load_op(const std::string& name);
bool has_op(const std::string& name); bool has_op(const std::string& name);
std::vector<std::string> get_operators(); std::vector<std::string> get_operators();
template <class T> template <class T>
void register_op() void register_op()
{ {
register_op(T{}); register_op_init(); // instantiate static op_map;
static auto op_h = detail::op_handler(T{});
register_op(op_h.op);
} }
struct register_op_action struct register_op_action
......
...@@ -33,14 +33,28 @@ ...@@ -33,14 +33,28 @@
namespace migraphx { namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS { inline namespace MIGRAPHX_INLINE_NS {
void register_target_init();
void register_target(const target& t); void register_target(const target& t);
void unregister_target(const std::string& name);
target make_target(const std::string& name); target make_target(const std::string& name);
std::vector<std::string> get_targets(); std::vector<std::string> get_targets();
namespace detail {
struct target_handler
{
target t;
std::string target_name;
target_handler(const target& t_r) : t(t_r), target_name(t.name()) {}
~target_handler() { unregister_target(target_name); }
};
} // namespace detail
template <class T> template <class T>
void register_target() void register_target()
{ {
register_target(T{}); register_target_init();
static auto t_h = detail::target_handler(T{});
register_target(t_h.t);
} }
struct register_target_action struct register_target_action
......
...@@ -32,6 +32,9 @@ inline namespace MIGRAPHX_INLINE_NS { ...@@ -32,6 +32,9 @@ inline namespace MIGRAPHX_INLINE_NS {
struct module; struct module;
/**
* Replace `allocate` instructions with target allocations or output parameters.
*/
struct replace_allocate struct replace_allocate
{ {
allocation_model model; allocation_model model;
......
...@@ -28,6 +28,7 @@ ...@@ -28,6 +28,7 @@
#include <migraphx/value.hpp> #include <migraphx/value.hpp>
#include <migraphx/reflect.hpp> #include <migraphx/reflect.hpp>
#include <migraphx/requires.hpp> #include <migraphx/requires.hpp>
#include <migraphx/optional.hpp>
#include <migraphx/rank.hpp> #include <migraphx/rank.hpp>
#include <type_traits> #include <type_traits>
...@@ -60,11 +61,12 @@ value to_value_impl(rank<0>, const T&) ...@@ -60,11 +61,12 @@ value to_value_impl(rank<0>, const T&)
return value::object{}; return value::object{};
} }
template <class T, class U> template <class T>
value to_value_impl(rank<1>, const std::pair<T, U>& x) auto to_value_impl(rank<1>, const T& x) -> decltype(std::tuple_size<T>{}, value{})
{ {
value result = value::array{};
return {x.first, x.second}; repeat_c<std::tuple_size<T>{}>([&](auto i) { result.push_back(to_value(std::get<i>(x))); });
return result;
} }
template <class T> template <class T>
...@@ -86,46 +88,55 @@ value to_value_impl(rank<3>, const T& x) ...@@ -86,46 +88,55 @@ value to_value_impl(rank<3>, const T& x)
return result; return result;
} }
template <class T>
auto to_value_impl(rank<4>, const optional<T>& x)
{
value result{};
if(x.has_value())
return to_value(*x);
return result;
}
template <class T, MIGRAPHX_REQUIRES(std::is_signed<T>{})> template <class T, MIGRAPHX_REQUIRES(std::is_signed<T>{})>
value to_value_impl(rank<4>, const T& x) value to_value_impl(rank<5>, const T& x)
{ {
return std::int64_t{x}; return std::int64_t{x};
} }
template <class T, MIGRAPHX_REQUIRES(std::is_unsigned<T>{})> template <class T, MIGRAPHX_REQUIRES(std::is_unsigned<T>{})>
value to_value_impl(rank<5>, const T& x) value to_value_impl(rank<6>, const T& x)
{ {
return std::uint64_t{x}; return std::uint64_t{x};
} }
template <class T, MIGRAPHX_REQUIRES(std::is_floating_point<T>{})> template <class T, MIGRAPHX_REQUIRES(std::is_floating_point<T>{})>
value to_value_impl(rank<6>, const T& x) value to_value_impl(rank<7>, const T& x)
{ {
return double{x}; return double{x};
} }
template <class T, MIGRAPHX_REQUIRES(std::is_enum<T>{})> template <class T, MIGRAPHX_REQUIRES(std::is_enum<T>{})>
value to_value_impl(rank<7>, const T& x) value to_value_impl(rank<8>, const T& x)
{ {
return x; return x;
} }
inline value to_value_impl(rank<8>, const std::string& x) { return x; } inline value to_value_impl(rank<9>, const std::string& x) { return x; }
template <class T> template <class T>
auto to_value_impl(rank<9>, const T& x) -> decltype(migraphx_to_value(x)) auto to_value_impl(rank<10>, const T& x) -> decltype(migraphx_to_value(x))
{ {
return migraphx_to_value(x); return migraphx_to_value(x);
} }
template <class T> template <class T>
auto to_value_impl(rank<10>, const T& x) -> decltype(x.to_value()) auto to_value_impl(rank<11>, const T& x) -> decltype(x.to_value())
{ {
return x.to_value(); return x.to_value();
} }
template <class T> template <class T>
auto to_value_impl(rank<11>, const T& x) auto to_value_impl(rank<12>, const T& x)
-> decltype(migraphx_to_value(std::declval<value&>(), x), value{}) -> decltype(migraphx_to_value(std::declval<value&>(), x), value{})
{ {
value v; value v;
...@@ -144,7 +155,14 @@ void from_value_impl(rank<0>, const value& v, T& x) ...@@ -144,7 +155,14 @@ void from_value_impl(rank<0>, const value& v, T& x)
} }
template <class T> template <class T>
auto from_value_impl(rank<1>, const value& v, T& x) auto from_value_impl(rank<1>, const value& v, T& x) -> decltype(std::tuple_size<T>{}, void())
{
repeat_c<std::tuple_size<T>{}>(
[&](auto i) { std::get<i>(x) = from_value<std::tuple_element_t<i, T>>(v[i]); });
}
template <class T>
auto from_value_impl(rank<2>, const value& v, T& x)
-> decltype(x.insert(x.end(), *x.begin()), void()) -> decltype(x.insert(x.end(), *x.begin()), void())
{ {
x.clear(); x.clear();
...@@ -153,7 +171,7 @@ auto from_value_impl(rank<1>, const value& v, T& x) ...@@ -153,7 +171,7 @@ auto from_value_impl(rank<1>, const value& v, T& x)
} }
template <class T, MIGRAPHX_REQUIRES(std::is_arithmetic<typename T::value_type>{})> template <class T, MIGRAPHX_REQUIRES(std::is_arithmetic<typename T::value_type>{})>
auto from_value_impl(rank<2>, const value& v, T& x) auto from_value_impl(rank<3>, const value& v, T& x)
-> decltype(x.insert(x.end(), *x.begin()), void()) -> decltype(x.insert(x.end(), *x.begin()), void())
{ {
x.clear(); x.clear();
...@@ -170,7 +188,7 @@ auto from_value_impl(rank<2>, const value& v, T& x) ...@@ -170,7 +188,7 @@ auto from_value_impl(rank<2>, const value& v, T& x)
} }
template <class T> template <class T>
auto from_value_impl(rank<3>, const value& v, T& x) -> decltype(x.insert(*x.begin()), void()) auto from_value_impl(rank<4>, const value& v, T& x) -> decltype(x.insert(*x.begin()), void())
{ {
x.clear(); x.clear();
for(auto&& e : v) for(auto&& e : v)
...@@ -178,7 +196,7 @@ auto from_value_impl(rank<3>, const value& v, T& x) -> decltype(x.insert(*x.begi ...@@ -178,7 +196,7 @@ auto from_value_impl(rank<3>, const value& v, T& x) -> decltype(x.insert(*x.begi
} }
template <class T, MIGRAPHX_REQUIRES(is_reflectable<T>{})> template <class T, MIGRAPHX_REQUIRES(is_reflectable<T>{})>
void from_value_impl(rank<4>, const value& v, T& x) void from_value_impl(rank<5>, const value& v, T& x)
{ {
reflect_each(x, [&](auto& y, const std::string& name) { reflect_each(x, [&](auto& y, const std::string& name) {
using type = std::decay_t<decltype(y)>; using type = std::decay_t<decltype(y)>;
...@@ -187,28 +205,29 @@ void from_value_impl(rank<4>, const value& v, T& x) ...@@ -187,28 +205,29 @@ void from_value_impl(rank<4>, const value& v, T& x)
}); });
} }
template <class T, MIGRAPHX_REQUIRES(std::is_arithmetic<T>{})> template <class T>
void from_value_impl(rank<5>, const value& v, T& x) void from_value_impl(rank<6>, const value& v, optional<T>& x)
{ {
x = v.to<T>(); if(not v.is_null())
x = from_value<T>(v);
} }
template <class T, MIGRAPHX_REQUIRES(std::is_enum<T>{})> template <class T, MIGRAPHX_REQUIRES(std::is_arithmetic<T>{} or std::is_enum<T>{})>
void from_value_impl(rank<6>, const value& v, T& x) void from_value_impl(rank<7>, const value& v, T& x)
{ {
x = v.to<T>(); x = v.to<T>();
} }
inline void from_value_impl(rank<7>, const value& v, std::string& x) { x = v.to<std::string>(); } inline void from_value_impl(rank<8>, const value& v, std::string& x) { x = v.to<std::string>(); }
template <class T> template <class T>
auto from_value_impl(rank<8>, const value& v, T& x) -> decltype(x.from_value(v), void()) auto from_value_impl(rank<9>, const value& v, T& x) -> decltype(x.from_value(v), void())
{ {
x.from_value(v); x.from_value(v);
} }
template <class T> template <class T>
auto from_value_impl(rank<9>, const value& v, T& x) -> decltype(migraphx_from_value(v, x), void()) auto from_value_impl(rank<10>, const value& v, T& x) -> decltype(migraphx_from_value(v, x), void())
{ {
migraphx_from_value(v, x); migraphx_from_value(v, x);
} }
...@@ -218,13 +237,13 @@ auto from_value_impl(rank<9>, const value& v, T& x) -> decltype(migraphx_from_va ...@@ -218,13 +237,13 @@ auto from_value_impl(rank<9>, const value& v, T& x) -> decltype(migraphx_from_va
template <class T> template <class T>
value to_value(const T& x) value to_value(const T& x)
{ {
return detail::to_value_impl(rank<11>{}, x); return detail::to_value_impl(rank<12>{}, x);
} }
template <class T> template <class T>
void from_value(const value& v, T& x) void from_value(const value& v, T& x)
{ {
detail::from_value_impl(rank<9>{}, v, x); detail::from_value_impl(rank<10>{}, v, x);
} }
} // namespace MIGRAPHX_INLINE_NS } // namespace MIGRAPHX_INLINE_NS
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment