Commit 31065c7d authored by charlie's avatar charlie
Browse files

Merge branch 'dyn_squeeze' of github.com:ROCmSoftwarePlatform/AMDMIGraphX into dyn_model_test

parents 6bec381f 6acbd4e4
......@@ -24,17 +24,8 @@
#ifndef MIGRAPHX_GUARD_OPERATORS_MOD_HPP
#define MIGRAPHX_GUARD_OPERATORS_MOD_HPP
#include <array>
#include <migraphx/op/binary.hpp>
#include <migraphx/check_shapes.hpp>
#include <migraphx/stringutils.hpp>
#include <migraphx/streamutils.hpp>
#include <migraphx/literal.hpp>
#include <migraphx/shape_for_each.hpp>
#include <migraphx/config.hpp>
#include <cmath>
#include <utility>
#include <type_traits>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
......@@ -47,9 +38,9 @@ struct mod : binary<mod>
{
auto a = base_attributes();
a["commutative"] = false;
a["point_op"] = "${function:fmod}((${function:remainder}(${0}, ${1})) + ${1}, ${1})";
return a;
}
std::string point_function() const { return "mod"; }
auto apply() const
{
return [](auto x, auto y) { return std::fmod((std::remainder(x, y)) + y, y); };
......
......@@ -26,64 +26,105 @@
#include <migraphx/check_shapes.hpp>
#include <migraphx/argument.hpp>
#include <migraphx/dyn_output.hpp>
#include <migraphx/common.hpp>
#include <migraphx/config.hpp>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
namespace op {
/**
* Broadcast multiple dimensions between two tensors.
* Two versions of this operator: one input and two inputs.
* One input version uses output_lens attribute and broadcasts to it.
* Two inputs version broadcasts both inputs to the common shape at evaluation time.
*/
struct multibroadcast
{
std::vector<std::size_t> output_lens;
std::vector<std::size_t> output_lens = {};
// optional attribute
std::vector<shape::dynamic_dimension> output_dyn_dims = {};
template <class Self, class F>
static auto reflect(Self& self, F f)
{
return pack(f(self.output_lens, "out_lens"));
return pack(f(self.output_lens, "out_lens"), f(self.output_dyn_dims, "out_dyn_dims"));
}
std::string name() const { return "multibroadcast"; }
shape compute_shape(std::vector<shape> inputs) const
{
check_shapes{inputs, *this}.has(1);
auto t = inputs.at(0).type();
auto input = inputs.at(0);
check_shapes{inputs, *this, true}.has(1, 2);
if(input.lens().empty())
{
MIGRAPHX_THROW("MULTIBROADCAST: inputs dimensions should be > 0");
}
auto t = inputs.at(0).type();
auto s0 = inputs.at(0);
if(input.lens().size() > output_lens.size())
if(s0.max_lens().empty())
{
MIGRAPHX_THROW("MULTIBROADCAST: inputs dimensions should <= output size");
MIGRAPHX_THROW("MULTIBROADCAST: input dimensions should be > 0");
}
auto offset = output_lens.size() - input.lens().size();
for(std::ptrdiff_t i = input.lens().size() - 1; i >= 0; i--)
auto make_bcast_strides = [&](std::vector<std::size_t> bcast_lens, std::size_t offset) {
std::vector<size_t> bcast_strides(bcast_lens.size(), 0);
for(std::ptrdiff_t i = s0.lens().size() - 1; i >= 0; i--)
{
if(bcast_lens[i + offset] == s0.lens()[i])
{
bcast_strides[i + offset] = s0.strides()[i];
}
}
return bcast_strides;
};
if(inputs.size() == 1)
{
if(output_lens[i + offset] != input.lens()[i] and input.lens()[i] != 1)
if(s0.lens().size() > output_lens.size())
{
MIGRAPHX_THROW("MULTIBROADCAST: input shape {" + to_string_range(input.lens()) +
"} cannot be broadcasted to {" + to_string_range(output_lens) +
"}!");
MIGRAPHX_THROW("MULTIBROADCAST: input dimensions should <= output size");
}
}
std::vector<size_t> bcast_strides(output_lens.size(), 0);
for(std::ptrdiff_t i = input.lens().size() - 1; i >= 0; i--)
auto offset = output_lens.size() - s0.lens().size();
for(std::ptrdiff_t i = s0.lens().size() - 1; i >= 0; i--)
{
if(output_lens[i + offset] != s0.lens()[i] and s0.lens()[i] != 1)
{
MIGRAPHX_THROW("MULTIBROADCAST: input shape {" + to_string_range(s0.lens()) +
"} cannot be broadcasted to {" + to_string_range(output_lens) +
"}!");
}
}
auto bcast_strides = make_bcast_strides(output_lens, offset);
return {t, output_lens, std::move(bcast_strides)};
}
else
{
if(output_lens[i + offset] == input.lens()[i])
// two inputs
auto s1 = inputs.at(1);
if(s0.dynamic() or s1.dynamic())
{
bcast_strides[i + offset] = input.strides()[i];
if(not output_dyn_dims.empty())
{
return {t, output_dyn_dims};
}
return {t, compute_broadcasted_dyn_dims(s0, s1)};
}
else
{
auto bcast_lens = compute_broadcasted_lens(s0.lens(), s1.lens());
auto offset = bcast_lens.size() - s0.lens().size();
auto bcast_strides = make_bcast_strides(bcast_lens, offset);
return {t, std::move(bcast_lens), std::move(bcast_strides)};
}
}
return {t, output_lens, bcast_strides};
}
argument compute(shape output_shape, std::vector<argument> args) const
argument compute(const dyn_output& dyn_out, std::vector<argument> args) const
{
return args[0].reshape(output_shape);
return args[0].reshape(dyn_out.computed_shape);
}
std::ptrdiff_t output_alias(const std::vector<shape>&) const { return 0; }
};
......
......@@ -45,11 +45,13 @@ namespace op {
struct nonmaxsuppression
{
bool center_point_box = false;
bool use_dyn_output = false;
template <class Self, class F>
static auto reflect(Self& self, F f)
{
return pack(f(self.center_point_box, "center_point_box"));
return pack(f(self.center_point_box, "center_point_box"),
f(self.use_dyn_output, "use_dyn_output"));
}
std::string name() const { return "nonmaxsuppression"; }
......@@ -57,27 +59,81 @@ struct nonmaxsuppression
shape compute_shape(std::vector<shape> inputs) const
{
// requires at least 2 inputs
check_shapes{{inputs.at(0), inputs.at(1)}, *this}.only_dims(3);
auto lens = inputs.front().lens();
check_shapes{{inputs.at(0), inputs.at(1)}, *this, true}.only_dims(3).same_ndims();
auto boxes_max_lens = inputs.at(0).max_lens();
// num batches * num boxes
const auto max_num_boxes = boxes_max_lens.at(0) * boxes_max_lens.at(1);
// check input shape
if(lens[1] != inputs.at(1).lens()[2])
auto fixed_shape_error_check = [&]() {
auto lens = inputs.front().lens();
if(lens[1] != inputs.at(1).lens()[2])
{
MIGRAPHX_THROW(
"NonMaxSuppression: spatial dimension mismatch between boxes and scores input");
}
if(lens[0] != inputs.at(1).lens()[0])
{
MIGRAPHX_THROW(
"NonMaxSuppression: number of batches mismatch between boxes and scores input");
}
};
if(use_dyn_output)
{
MIGRAPHX_THROW(
"NonMaxSuppression: spatial dimension mismatch between boxes and scores input");
if(inputs.at(0).dynamic())
{
// both boxes and scores should be dynamic
// check dynamic dimensions are consistent
const auto boxes_dims = inputs.at(0).dyn_dims();
const auto scores_dims = inputs.at(1).dyn_dims();
if(boxes_dims.at(1) != scores_dims.at(2))
{
MIGRAPHX_THROW("NonMaxSuppression: dynamic spatial dimension mismatch between "
"boxes and scores input");
}
if(boxes_dims.at(0) != scores_dims.at(0))
{
MIGRAPHX_THROW("NonMaxSuppression: dynamic number of batches mismatch between "
"boxes and scores input");
}
}
else if(inputs.at(1).dynamic())
{
// scores has dynamic shape, boxes fixed shape
// check that it is only a dynamic number of classes
const auto scores_dims = inputs.at(1).dyn_dims();
const auto boxes_lens = inputs.at(0).lens();
if(not scores_dims.at(0).is_fixed() or scores_dims.at(0).max != boxes_lens.at(0))
{
MIGRAPHX_THROW("NonMaxSuppression: scores dynamic num_classes; num_batches not "
"fixed or mismatched");
}
if(not scores_dims.at(2).is_fixed() or scores_dims.at(2).max != boxes_lens.at(1))
{
MIGRAPHX_THROW("NonMaxSuppression: scores dynamic num_classes; "
"spatial_dimension not fixed or mismatches");
}
}
else
{
fixed_shape_error_check();
}
std::vector<shape::dynamic_dimension> out_lens = {};
out_lens.push_back({0, max_num_boxes, 0});
out_lens.push_back({3, 3, 0});
return {shape::int64_type, out_lens};
}
// check batch sizes
if(lens[0] != inputs.at(1).lens()[0])
else
{
MIGRAPHX_THROW(
"NonMaxSuppression: number of batches mismatch between boxes and scores input");
if(inputs.at(0).dynamic() or inputs.at(1).dynamic())
{
MIGRAPHX_THROW(
"NonMaxSuppression: dynamic input shape with use_dyn_output set to false");
}
fixed_shape_error_check();
std::vector<std::size_t> out_lens = {max_num_boxes, 3};
return {shape::int64_type, out_lens};
}
std::vector<int64_t> out_lens(2);
out_lens.at(0) = lens.at(1);
out_lens.at(1) = 3;
return {shape::int64_type, out_lens};
}
struct box
......@@ -181,13 +237,13 @@ struct nonmaxsuppression
}
template <class Output, class Boxes, class Scores>
void compute_nms(Output output,
Boxes boxes,
Scores scores,
const shape& output_shape,
std::size_t max_output_boxes_per_class,
double iou_threshold,
double score_threshold) const
std::size_t compute_nms(Output output,
Boxes boxes,
Scores scores,
const shape& max_output_shape,
std::size_t max_output_boxes_per_class,
double iou_threshold,
double score_threshold) const
{
std::fill(output.begin(), output.end(), 0);
const auto& lens = scores.get_shape().lens();
......@@ -197,7 +253,7 @@ struct nonmaxsuppression
// boxes of a class with NMS applied [score, index]
std::vector<std::pair<double, int64_t>> selected_boxes_inside_class;
std::vector<int64_t> selected_indices;
selected_boxes_inside_class.reserve(output_shape.elements());
selected_boxes_inside_class.reserve(max_output_shape.elements());
// iterate over batches and classes
shape comp_s{shape::double_type, {num_batches, num_classes}};
shape_for_each(comp_s, [&](auto idx) {
......@@ -210,7 +266,7 @@ struct nonmaxsuppression
auto boxes_heap = filter_boxes_by_score(scores_start, num_boxes, score_threshold);
selected_boxes_inside_class.clear();
// Get the next box with top score, filter by iou_threshold
while(!boxes_heap.empty() &&
while(not boxes_heap.empty() &&
selected_boxes_inside_class.size() < max_output_boxes_per_class)
{
// Check with existing selected boxes for this class, remove box if it
......@@ -237,11 +293,14 @@ struct nonmaxsuppression
}
});
std::copy(selected_indices.begin(), selected_indices.end(), output.begin());
return selected_indices.size() / 3;
}
argument compute(const shape& output_shape, std::vector<argument> args) const
{
argument result{output_shape};
// make buffer of maximum size
shape max_output_shape = {output_shape.type(), output_shape.max_lens()};
argument result{max_output_shape};
std::size_t max_output_boxes_per_class =
(args.size() > 2) ? (args.at(2).at<std::size_t>()) : 0;
......@@ -249,22 +308,29 @@ struct nonmaxsuppression
{
return result;
}
double iou_threshold = (args.size() > 3) ? (args.at(3).at<double>()) : 0.0f;
double score_threshold = (args.size() > 4) ? (args.at(4).at<double>()) : 0.0f;
double iou_threshold = (args.size() > 3) ? (args.at(3).at<double>()) : 0.0f;
double score_threshold = (args.size() > 4) ? (args.at(4).at<double>()) : 0.0f;
std::size_t num_selected = 0;
result.visit([&](auto output) {
visit_all(args[0], args[1])([&](auto boxes, auto scores) {
compute_nms(output,
boxes,
scores,
output_shape,
max_output_boxes_per_class,
iou_threshold,
score_threshold);
num_selected = compute_nms(output,
boxes,
scores,
max_output_shape,
max_output_boxes_per_class,
iou_threshold,
score_threshold);
});
});
return result;
if(use_dyn_output)
{
return result.reshape({output_shape.type(), {num_selected, 3}});
}
else
{
return result;
}
}
};
......
......@@ -64,8 +64,8 @@ struct pooling
void check_attribute_size() const
{
if(not((padding.size() == stride.size() or (padding.size() / 2) == stride.size()) and
stride.size() == lengths.size()))
if((padding.size() != stride.size() and (padding.size() / 2) != stride.size()) or
stride.size() != lengths.size())
{
MIGRAPHX_THROW("POOLING: inconsistent attribute sizes");
}
......@@ -83,7 +83,7 @@ struct pooling
size_t kdims = input_lens.size() - 2;
auto input_size = inputs[0].lens().size();
auto padding_size = padding.size();
if(not(input_size == padding_size / 2 + 2 or input_size == padding_size + 2))
if(input_size != padding_size / 2 + 2 and input_size != padding_size + 2)
{
MIGRAPHX_THROW("POOLING: input and attribute size mismatch!");
}
......
......@@ -41,9 +41,8 @@ struct quant_convolution
std::vector<std::size_t> stride = {1, 1};
std::vector<std::size_t> dilation = {1, 1};
padding_mode_t padding_mode = default_;
int group = 1;
bool use_dynamic_same_auto_pad = false;
padding_mode_t padding_mode = default_;
int group = 1;
template <class Self, class F>
static auto reflect(Self& self, F f)
......@@ -52,8 +51,7 @@ struct quant_convolution
f(self.stride, "stride"),
f(self.dilation, "dilation"),
f(self.padding_mode, "padding_mode"),
f(self.group, "group"),
f(self.use_dynamic_same_auto_pad, "use_dynamic_same_auto_pad"));
f(self.group, "group"));
}
value attributes() const
......@@ -65,8 +63,8 @@ struct quant_convolution
void check_attribute_size() const
{
if(not((padding.size() == stride.size() or (padding.size() / 2) == stride.size()) and
stride.size() == dilation.size()))
if((padding.size() != stride.size() and (padding.size() / 2) != stride.size()) or
stride.size() != dilation.size())
{
MIGRAPHX_THROW("QUANT_CONVOLUTION: inconsistent attribute sizes");
}
......
......@@ -49,13 +49,14 @@ struct quant_dot
MIGRAPHX_THROW("QUANT_DOT: only support data type int8_t");
}
if(!std::all_of(inputs.begin(), inputs.end(), [](auto s) { return s.lens().size() >= 2; }))
if(not std::all_of(
inputs.begin(), inputs.end(), [](auto s) { return s.lens().size() >= 2; }))
{
MIGRAPHX_THROW("QUANT_DOT: dot only accept 2 or more dims operands");
}
// only handle the case that the batch size of a and b are the same
if(!std::equal(
if(not std::equal(
a.lens().rbegin() + 2, a.lens().rend(), b.lens().rbegin() + 2, b.lens().rend()))
{
MIGRAPHX_THROW("QUANT_DOT: batch size of A and B mismatch: {" +
......
......@@ -78,7 +78,7 @@ struct slice
const std::vector<std::size_t>& lens = s.lens();
const std::vector<std::size_t>& strides = s.strides();
auto offset = 0;
if(!axes.empty())
if(not axes.empty())
{
for(std::size_t i = 0; i < axes.size(); i++)
{
......@@ -109,7 +109,7 @@ struct slice
MIGRAPHX_THROW("SLICE: input axis " + to_string_range(axes) + " out of range");
}
if(starts.size() != axes.size() || axes.size() != ends.size())
if(starts.size() != axes.size() or axes.size() != ends.size())
{
MIGRAPHX_THROW("SLICE: inconsistent sizes");
}
......
......@@ -29,6 +29,7 @@
#include <migraphx/config.hpp>
#include <migraphx/value.hpp>
#include <migraphx/op/normalize_attribute.hpp>
#include <migraphx/dyn_output.hpp>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
......@@ -54,52 +55,90 @@ struct squeeze
std::string name() const { return "squeeze"; }
shape normalize_compute_shape(std::vector<shape> inputs) const
{
check_shapes{inputs, *this}.has(1);
check_shapes{inputs, *this, true}.has(1);
auto input_shape = inputs[0];
auto type = input_shape.type();
auto old_lens = input_shape.lens();
auto old_strides = input_shape.strides();
if(std::any_of(axes.begin(), axes.end(), [&](auto axis) { return old_lens[axis] != 1; }))
if(input_shape.dynamic())
{
MIGRAPHX_THROW("squeeze axis dimension should be equal to 1");
}
std::vector<std::size_t> new_lens;
std::vector<std::size_t> new_strides;
if(axes.empty())
{
for(auto i : range(old_lens.size()))
std::vector<shape::dynamic_dimension> one_dyn_dims{{1, 1, 0}, {1, 1, 1}};
if(std::any_of(axes.begin(), axes.end(), [&](auto axis) {
return not contains(one_dyn_dims, input_shape.dyn_dims()[axis]);
}))
{
MIGRAPHX_THROW(
"SQUEEZE: dynamic axis dimension should be equal to {1, 1, 0} or {1, 1, 1}");
}
std::vector<shape::dynamic_dimension> dyn_dims = {};
if(axes.empty())
{
if(old_lens[i] != 1)
for(auto i : range(input_shape.ndim()))
{
new_lens.push_back(old_lens[i]);
new_strides.push_back(old_strides[i]);
auto dd = input_shape.dyn_dims()[i];
if(not contains(one_dyn_dims, dd))
{
dyn_dims.push_back(dd);
}
}
}
}
else
{
for(auto i : range(old_lens.size()))
else
{
if(std::find(axes.begin(), axes.end(), i) == axes.end())
for(auto i : range(input_shape.ndim()))
{
new_lens.push_back(old_lens[i]);
new_strides.push_back(old_strides[i]);
if(std::find(axes.begin(), axes.end(), i) == axes.end())
{
dyn_dims.push_back(input_shape.dyn_dims()[i]);
}
}
}
}
if(new_lens.empty())
{
return shape{type};
return {input_shape.type(), dyn_dims};
}
else
{
return shape{type, new_lens, new_strides};
auto type = input_shape.type();
auto old_lens = input_shape.lens();
auto old_strides = input_shape.strides();
if(std::any_of(
axes.begin(), axes.end(), [&](auto axis) { return old_lens[axis] != 1; }))
{
MIGRAPHX_THROW("SQUEEZE: static axis dimension should be equal to 1");
}
std::vector<std::size_t> new_lens;
std::vector<std::size_t> new_strides;
if(axes.empty())
{
for(auto i : range(old_lens.size()))
{
if(old_lens[i] != 1)
{
new_lens.push_back(old_lens[i]);
new_strides.push_back(old_strides[i]);
}
}
}
else
{
for(auto i : range(old_lens.size()))
{
if(std::find(axes.begin(), axes.end(), i) == axes.end())
{
new_lens.push_back(old_lens[i]);
new_strides.push_back(old_strides[i]);
}
}
}
if(new_lens.empty())
{
return shape{type};
}
else
{
return shape{type, new_lens, new_strides};
}
}
}
argument compute(shape output_shape, std::vector<argument> args) const
argument compute(const dyn_output& dyn_out, std::vector<argument> args) const
{
return args[0].reshape(output_shape);
return args[0].reshape(dyn_out.computed_shape);
}
std::ptrdiff_t output_alias(const std::vector<shape>&) const { return 0; }
};
......
......@@ -59,7 +59,7 @@ struct transpose
}
std::vector<int64_t> axes(dims.size());
std::iota(axes.begin(), axes.end(), 0);
if(!std::is_permutation(axes.begin(), axes.end(), dims.begin()))
if(not std::is_permutation(axes.begin(), axes.end(), dims.begin()))
{
MIGRAPHX_THROW("TRANSPOSE: Invalid permutation");
}
......
......@@ -30,6 +30,7 @@
#include <migraphx/argument.hpp>
#include <migraphx/stringutils.hpp>
#include <migraphx/value.hpp>
#include <migraphx/dyn_output.hpp>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
......@@ -62,9 +63,9 @@ struct unary : op_name<Derived>
value attributes() const { return base_attributes(); }
shape compute_shape(std::vector<shape> inputs) const
{
check_shapes{inputs, static_cast<const Derived&>(*this)}.has(1);
check_shapes{inputs, static_cast<const Derived&>(*this), true}.has(1);
auto s = inputs.at(0);
if(s.scalar())
if(s.dynamic() or s.scalar())
{
return s;
}
......@@ -78,9 +79,9 @@ struct unary : op_name<Derived>
}
}
argument compute(const shape& output_shape, std::vector<argument> args) const
argument compute(const dyn_output& dyn_out, std::vector<argument> args) const
{
argument result{output_shape};
argument result{dyn_out.computed_shape};
result.visit([&](auto output) {
args[0].visit([&](auto input) {
std::transform(input.begin(),
......
......@@ -29,11 +29,20 @@
#include <migraphx/config.hpp>
#include <migraphx/value.hpp>
#include <migraphx/op/normalize_attribute.hpp>
#include <migraphx/dyn_output.hpp>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
namespace op {
/**
* Adds dimensions to a tensor based on the axes attribute.
* `axes` are based on the number of output shape dimensions and should not contain duplicates.
* `steps` are for modifying dimensions added to the middle of the original shape.
* Each step must be a factor of the original dimension.
* ex: unsqueeze(shape = [3, 4, 10], axes = [2, 4, 5], steps = [2]) -> shape = [3, 4, 2, 5, 1, 1]
* Dynamic shape version does not handle `steps`.
*/
struct unsqueeze
{
std::vector<int64_t> axes;
......@@ -56,63 +65,89 @@ struct unsqueeze
std::string name() const { return "unsqueeze"; }
shape normalize_compute_shape(std::vector<shape> inputs) const
{
check_shapes{inputs, *this}.has(1);
check_shapes{inputs, *this, true}.has(1);
auto input_shape = inputs[0];
auto type = input_shape.type();
auto old_lens = input_shape.lens();
auto old_strides = input_shape.strides();
if(input_shape.scalar())
if(input_shape.dynamic())
{
if(old_lens.size() == 1 and old_lens.front() == 1)
return shape{type, old_lens};
else
MIGRAPHX_THROW("UNSQUEEZE: Input must be a scalar");
if(not steps.empty())
{
MIGRAPHX_THROW("UNSQUEEZE_dyn: nonempty steps attribute");
}
std::vector<shape::dynamic_dimension> dyn_dims = {};
auto new_ndim = input_shape.ndim() + axes.size();
std::size_t k = 0;
for(auto i : range(new_ndim))
{
if(std::find(axes.begin(), axes.end(), i) != axes.end())
{
dyn_dims.push_back({1, 1, 0});
}
else
{
dyn_dims.push_back(input_shape.dyn_dims().at(k++));
}
}
return {input_shape.type(), dyn_dims};
}
else
{
auto type = input_shape.type();
auto old_lens = input_shape.lens();
auto old_strides = input_shape.strides();
if(input_shape.scalar())
{
if(old_lens.size() == 1 and old_lens.front() == 1)
return shape{type, old_lens};
else
MIGRAPHX_THROW("UNSQUEEZE: Input must be a scalar");
}
if(steps.size() > axes.size())
MIGRAPHX_THROW("UNSQUEEZE: Steps provided with no axis");
if(steps.size() > axes.size())
MIGRAPHX_THROW("UNSQUEEZE: Steps provided with no axis");
std::size_t new_size = old_lens.size() + axes.size();
std::size_t new_size = old_lens.size() + axes.size();
std::vector<std::size_t> new_lens(new_size);
std::vector<std::size_t> new_strides(new_size);
std::size_t p = 0;
for(auto i : range(new_size))
{
auto axis_idx = std::find(axes.begin(), axes.end(), i) - axes.begin();
if(axis_idx < axes.size())
std::vector<std::size_t> new_lens(new_size);
std::vector<std::size_t> new_strides(new_size);
std::size_t p = 0;
for(auto i : range(new_size))
{
std::int64_t step = 1;
if(axis_idx < steps.size())
step = steps[axis_idx];
if(step == 0)
MIGRAPHX_THROW("UNSQUEEZE: step must be non-zero");
new_lens[i] = step;
if(p < old_strides.size())
auto axis_idx = std::find(axes.begin(), axes.end(), i) - axes.begin();
if(axis_idx < axes.size())
{
if((old_lens[p] % step) != 0)
MIGRAPHX_THROW("UNSQUEEZE: Axis dimenstion is not divisible by step");
old_lens[p] /= step;
new_strides[i] = old_strides[p] * old_lens[p];
std::int64_t step = 1;
if(axis_idx < steps.size())
step = steps[axis_idx];
if(step == 0)
MIGRAPHX_THROW("UNSQUEEZE: step must be non-zero");
new_lens[i] = step;
if(p < old_strides.size())
{
if((old_lens[p] % step) != 0)
MIGRAPHX_THROW("UNSQUEEZE: Axis dimenstion is not divisible by step");
old_lens[p] /= step;
new_strides[i] = old_strides[p] * old_lens[p];
}
else
{
if(step != 1)
MIGRAPHX_THROW("UNSQUEEZE: Step must be 1 for extra axes");
new_strides[i] = 1;
}
}
else
{
if(step != 1)
MIGRAPHX_THROW("UNSQUEEZE: Step must be 1 for extra axes");
new_strides[i] = 1;
new_lens[i] = old_lens[p];
new_strides[i] = old_strides[p++];
}
}
else
{
new_lens[i] = old_lens[p];
new_strides[i] = old_strides[p++];
}
return shape{type, new_lens, new_strides};
}
return shape{type, new_lens, new_strides};
}
argument compute(shape output_shape, std::vector<argument> args) const
argument compute(const dyn_output& dyn_out, std::vector<argument> args) const
{
return args[0].reshape(output_shape);
return args[0].reshape(dyn_out.computed_shape);
}
std::ptrdiff_t output_alias(const std::vector<shape>&) const { return 0; }
};
......
......@@ -32,6 +32,8 @@
#include <utility>
#include <unordered_map>
#include <migraphx/reflect.hpp>
#include <migraphx/dyn_output.hpp>
#include <migraphx/functional.hpp>
#include <migraphx/streamutils.hpp>
#include <migraphx/normalize_attributes.hpp>
#include <migraphx/argument.hpp>
......@@ -199,9 +201,12 @@ auto compute_op(rank<1>,
context& ctx,
const shape& output_shape,
const std::vector<argument>& input)
-> decltype(x.compute(auto_any_cast(ctx), output_shape, input))
-> decltype(x.compute(auto_any_cast(ctx),
make_compute_output_shape(pack(x, output_shape, input)),
input))
{
return x.compute(auto_any_cast(ctx), output_shape, input);
return x.compute(
auto_any_cast(ctx), make_compute_output_shape(pack(x, output_shape, input)), input);
}
template <class T>
......@@ -220,9 +225,9 @@ compute_op(const T& x, context& ctx, const shape& output_shape, const std::vecto
template <class T>
auto compute_op(rank<1>, const T& x, const shape& output_shape, const std::vector<argument>& input)
-> decltype(x.compute(output_shape, input))
-> decltype(x.compute(make_compute_output_shape(pack(x, output_shape, input)), input))
{
return x.compute(output_shape, input);
return x.compute(make_compute_output_shape(pack(x, output_shape, input)), input);
}
template <class T>
......@@ -244,9 +249,11 @@ auto compute_op(rank<1>,
const shape& output,
const std::vector<argument>& inputs,
const std::vector<module_ref>& module_args,
F f) -> decltype(x.compute(output, inputs, module_args, f))
F f)
-> decltype(
x.compute(make_compute_output_shape(pack(x, output, inputs)), inputs, module_args, f))
{
return x.compute(output, inputs, module_args, f);
return x.compute(make_compute_output_shape(pack(x, output, inputs)), inputs, module_args, f);
}
template <class T, class F>
......@@ -278,9 +285,17 @@ auto compute_op(rank<4>,
const shape& output,
const std::vector<argument>& inputs,
const std::vector<module_ref>& module_args,
F f) -> decltype(x.compute(auto_any_cast(ctx), output, inputs, module_args, f))
F f) -> decltype(x.compute(auto_any_cast(ctx),
make_compute_output_shape(pack(x, output, inputs)),
inputs,
module_args,
f))
{
return x.compute(auto_any_cast(ctx), output, inputs, module_args, f);
return x.compute(auto_any_cast(ctx),
make_compute_output_shape(pack(x, output, inputs)),
inputs,
module_args,
f);
}
template <class T, class F>
......@@ -290,9 +305,11 @@ auto compute_op(rank<3>,
const shape& output,
const std::vector<argument>& inputs,
const std::vector<module_ref>& module_args,
F f) -> decltype(x.compute(output, inputs, module_args, f))
F f)
-> decltype(
x.compute(make_compute_output_shape(pack(x, output, inputs)), inputs, module_args, f))
{
return x.compute(output, inputs, module_args, f);
return x.compute(make_compute_output_shape(pack(x, output, inputs)), inputs, module_args, f);
}
template <class T, class F>
......@@ -302,9 +319,10 @@ auto compute_op(rank<2>,
const shape& output,
const std::vector<argument>& inputs,
const std::vector<module_ref>&,
F) -> decltype(x.compute(output, inputs))
F)
-> decltype(x.compute(make_compute_output_shape(pack(x, output, inputs)), inputs))
{
return x.compute(output, inputs);
return x.compute(make_compute_output_shape(pack(x, output, inputs)), inputs);
}
template <class T, class F>
......@@ -314,9 +332,12 @@ auto compute_op(rank<1>,
const shape& output,
const std::vector<argument>& inputs,
const std::vector<module_ref>&,
F) -> decltype(x.compute(auto_any_cast(ctx), output, inputs))
F) -> decltype(x.compute(auto_any_cast(ctx),
make_compute_output_shape(pack(x, output, inputs)),
inputs))
{
return x.compute(auto_any_cast(ctx), output, inputs);
return x.compute(
auto_any_cast(ctx), make_compute_output_shape(pack(x, output, inputs)), inputs);
}
template <class T, class F>
......@@ -348,7 +369,8 @@ auto is_context_free_op(rank<1>,
const T& x,
const shape& output_shape,
const std::vector<argument>& input)
-> decltype(x.compute(output_shape, input), std::true_type{});
-> decltype(x.compute(make_compute_output_shape(pack(x, output_shape, input)), input),
std::true_type{});
template <class T>
auto is_context_free_op(rank<0>, const T&, const shape&, const std::vector<argument>&)
......@@ -1066,7 +1088,7 @@ struct operation
template <typename PrivateDetailTypeErasedU = PrivateDetailTypeErasedT>
private_detail_te_handle_type(
PrivateDetailTypeErasedT value,
typename std::enable_if<!std::is_reference<PrivateDetailTypeErasedU>::value,
typename std::enable_if<not std::is_reference<PrivateDetailTypeErasedU>::value,
int>::type* = nullptr) noexcept
: private_detail_te_value(std::move(value))
{
......@@ -1237,7 +1259,7 @@ struct operation
private_detail_te_handle_base_type& private_detail_te_get_handle()
{
assert(private_detail_te_handle_mem_var != nullptr);
if(!private_detail_te_handle_mem_var.unique())
if(not private_detail_te_handle_mem_var.unique())
private_detail_te_handle_mem_var = private_detail_te_handle_mem_var->clone();
return *private_detail_te_handle_mem_var;
}
......@@ -1276,7 +1298,7 @@ inline const ValueType& any_cast(const operation& x)
}
#endif
inline bool operator!=(const operation& x, const operation& y) { return !(x == y); }
inline bool operator!=(const operation& x, const operation& y) { return not(x == y); }
inline value
compile(operation& op, context& ctx, const shape& output_shape, const std::vector<shape>& input)
......
......@@ -35,7 +35,6 @@
#include <migraphx/op/as_shape.hpp>
#include <migraphx/op/atan.hpp>
#include <migraphx/op/atanh.hpp>
#include <migraphx/op/batch_norm_inference.hpp>
#include <migraphx/op/binary.hpp>
#include <migraphx/op/broadcast.hpp>
#include <migraphx/op/capture.hpp>
......
......@@ -24,9 +24,10 @@
#ifndef MIGRAPHX_GUARD_OPERATORS_PAD_CALC_HPP
#define MIGRAPHX_GUARD_OPERATORS_PAD_CALC_HPP
#include <migraphx/config.hpp>
#include <cstdint>
#include <vector>
#include <migraphx/config.hpp>
#include <migraphx/shape.hpp>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
......@@ -42,18 +43,21 @@ void calculate_padding(int64_t idx,
/*!
* Calculate the padding for auto_padding. Used for dynamic shapes
* where the padding calculation must be done at evaluation time.
* \param tensor_lens input tensor image shape
* \param k_lens weights kernel shape
* \param strides strides for the kernel
* \param dilations dilations for the kernel
* \param use_upper put odd padding on upper or lower side
* \return padding in the form of {x0_begin, x1_begin, ... x0_end , x1_end, ...}
*/
std::vector<std::size_t> calc_dyn_auto_pad(std::vector<std::size_t> tensor_lens,
std::vector<std::size_t> k_lens,
std::vector<std::size_t> strides,
std::vector<std::size_t> dilations,
bool use_upper = true);
std::vector<std::size_t> calc_dyn_auto_pad(const std::vector<std::size_t>& input_lens,
const std::vector<std::size_t>& wei_lens,
const std::vector<std::size_t>& strides,
const std::vector<std::size_t>& dilations,
bool use_upper);
// Used for dynamic auto padding of convolution operators since padding needs to be computed at
// evaulation time.
shape compute_padded_shape(const shape& input,
const shape& weights,
const std::vector<std::size_t>& padding,
const std::vector<std::size_t>& stride,
const std::vector<std::size_t>& dilation);
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
......
......@@ -238,7 +238,7 @@ struct pass
template <typename PrivateDetailTypeErasedU = PrivateDetailTypeErasedT>
private_detail_te_handle_type(
PrivateDetailTypeErasedT value,
typename std::enable_if<!std::is_reference<PrivateDetailTypeErasedU>::value,
typename std::enable_if<not std::is_reference<PrivateDetailTypeErasedU>::value,
int>::type* = nullptr) noexcept
: private_detail_te_value(std::move(value))
{
......@@ -292,7 +292,7 @@ struct pass
private_detail_te_handle_base_type& private_detail_te_get_handle()
{
assert(private_detail_te_handle_mem_var != nullptr);
if(!private_detail_te_handle_mem_var.unique())
if(not private_detail_te_handle_mem_var.unique())
private_detail_te_handle_mem_var = private_detail_te_handle_mem_var->clone();
return *private_detail_te_handle_mem_var;
}
......
......@@ -37,6 +37,7 @@
#include <migraphx/assignment_options.hpp>
#include <migraphx/env.hpp>
#include <migraphx/config.hpp>
#include <migraphx/execution_environment.hpp>
#include <algorithm>
#include <iostream>
......@@ -76,8 +77,8 @@ struct program
std::unordered_map<std::string, shape> get_parameter_shapes() const;
std::vector<argument> eval(parameter_map params) const;
std::vector<argument> eval(parameter_map params,
execution_environment exec_env = execution_environment{}) const;
std::size_t size() const;
std::vector<shape> get_output_shapes() const;
......@@ -124,7 +125,7 @@ struct program
friend std::ostream& operator<<(std::ostream& os, const program& p);
friend bool operator==(const program& x, const program& y);
friend bool operator!=(const program& x, const program& y) { return !(x == y); }
friend bool operator!=(const program& x, const program& y) { return not(x == y); }
// module related api
module* create_module(const std::string& name);
......
......@@ -147,7 +147,7 @@ struct raw_data : raw_data_base
template <class T>
bool matches() const
{
return is_data_ptr<T>{} ||
return is_data_ptr<T>{} or
self->get_shape().type() == migraphx::shape::get_type<get_data_type<T>>{};
}
......@@ -232,7 +232,7 @@ auto visit_all(T&& x, Ts&&... xs)
{
auto&& s = x.get_shape();
std::initializer_list<shape::type_t> types = {xs.get_shape().type()...};
if(!std::all_of(types.begin(), types.end(), [&](shape::type_t t) { return t == s.type(); }))
if(not std::all_of(types.begin(), types.end(), [&](shape::type_t t) { return t == s.type(); }))
MIGRAPHX_THROW("Types must be the same");
return [&](auto... vs) { detail::visit_all_pack(s, vs...)(x, xs...); };
}
......@@ -241,7 +241,7 @@ template <class T>
auto visit_all(const std::vector<T>& x)
{
auto&& s = x.front().get_shape();
if(!std::all_of(
if(not std::all_of(
x.begin(), x.end(), [&](const T& y) { return y.get_shape().type() == s.type(); }))
MIGRAPHX_THROW("Types must be the same");
return [&](auto v) {
......@@ -281,7 +281,7 @@ template <class T,
std::is_base_of<raw_data_base, U>{})>
bool operator!=(const T& x, const U& y)
{
return !(x == y);
return not(x == y);
}
} // namespace MIGRAPHX_INLINE_NS
......
......@@ -56,11 +56,11 @@ auto reflect_impl(rank<0>, T&, Selector)
}
template <class T>
auto reflectable_impl(rank<1>, T&& x)
auto reflectable_impl(rank<1>, const T& x)
-> decltype(T::reflect(x, reflect_placeholder{}), std::true_type{});
template <class T>
auto reflectable_impl(rank<0>, T &&) -> decltype(std::false_type{});
auto reflectable_impl(rank<0>, const T&) -> decltype(std::false_type{});
template <class T>
struct remove_rvalue_reference
......@@ -111,8 +111,18 @@ auto reflect(T& x, Selector f)
template <class T>
auto reflect_tie(T& x)
{
return reflect(x, [](auto&& y, auto&&...) { return detail::wrap<decltype(y)>(y); })(
[](auto&&... xs) { return detail::auto_tuple(xs.get()...); });
return reflect(x, [](auto&& y, auto&&...) {
// cppcheck-suppress UnnecessaryElseStatement
if constexpr(is_reflectable<decltype(y)>{})
{
auto t = reflect_tie(y);
return detail::wrap<decltype(t)>(t);
}
else
{
return detail::wrap<decltype(y)>(y);
}
})([](auto&&... xs) { return detail::auto_tuple(xs.get()...); });
}
template <class T, class F>
......@@ -129,7 +139,7 @@ template <class T>
struct reflect_equality
{
friend bool operator==(const T& x, const T& y) { return reflect_tie(x) == reflect_tie(y); }
friend bool operator!=(const T& x, const T& y) { return !(x == y); }
friend bool operator!=(const T& x, const T& y) { return not(x == y); }
};
template <class T>
......
......@@ -31,7 +31,7 @@ namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
template <bool... Bs>
struct and_ : std::is_same<and_<Bs...>, and_<(Bs || true)...>> // NOLINT
struct and_ : std::is_same<and_<Bs...>, and_<(Bs or true)...>> // NOLINT
{
};
......
......@@ -21,8 +21,8 @@
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
* THE SOFTWARE.
*/
#ifndef MIGRAPHX_GUARD_RTGLIB_FWD_CONV_BATCHNORM_REWRITE_HPP
#define MIGRAPHX_GUARD_RTGLIB_FWD_CONV_BATCHNORM_REWRITE_HPP
#ifndef MIGRAPHX_GUARD_RTGLIB_REWRITE_GELU_HPP
#define MIGRAPHX_GUARD_RTGLIB_REWRITE_GELU_HPP
#include <string>
#include <migraphx/instruction_ref.hpp>
......@@ -34,11 +34,11 @@ inline namespace MIGRAPHX_INLINE_NS {
struct module;
/**
* Rewrite batchnorm to a multiply and add.
* Rewrite gelu standard formula as the sigmoid approximation formula
*/
struct rewrite_batchnorm
struct rewrite_gelu
{
std::string name() const { return "rewrite_batchnorm"; }
std::string name() const { return "rewrite_gelu"; }
void apply(module& m) const;
};
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment