Commit 11e155c2 authored by Paul's avatar Paul
Browse files

Merge

parents 8a9c5bce aa7ff911
#ifndef MIGRAPHX_GUARD_OPERATORS_COMMON_HPP #ifndef MIGRAPHX_GUARD_OPERATORS_COMMON_HPP
#define MIGRAPHX_GUARD_OPERATORS_COMMON_HPP #define MIGRAPHX_GUARD_OPERATORS_COMMON_HPP
#include <ostream>
#include <vector>
#include <migraphx/config.hpp> #include <migraphx/config.hpp>
#include <utility> #include <utility>
...@@ -15,6 +17,15 @@ enum padding_mode_t ...@@ -15,6 +17,15 @@ enum padding_mode_t
valid valid
}; };
// The pooling modes must correspond 1-1 to the operators defined for struct parse_pooling.
// Used in pooling and roialign operators.
enum class pooling_mode
{
average,
max,
lpnorm
};
// indicate rnn computation direction // indicate rnn computation direction
enum class rnn_direction enum class rnn_direction
{ {
...@@ -23,6 +34,7 @@ enum class rnn_direction ...@@ -23,6 +34,7 @@ enum class rnn_direction
bidirectional, bidirectional,
}; };
std::ostream& operator<<(std::ostream& os, pooling_mode v);
std::ostream& operator<<(std::ostream& os, rnn_direction v); std::ostream& operator<<(std::ostream& os, rnn_direction v);
} // namespace op } // namespace op
......
...@@ -97,7 +97,6 @@ struct deconvolution ...@@ -97,7 +97,6 @@ struct deconvolution
shape win_shape{output_shape.type(), win_size}; shape win_shape{output_shape.type(), win_size};
par_dfor(in_n, wei_c)([&](int o, int k) { par_dfor(in_n, wei_c)([&](int o, int k) {
shape_for_each(win_shape, [&](auto idx_win) { shape_for_each(win_shape, [&](auto idx_win) {
const int w = idx_win[0]; const int w = idx_win[0];
...@@ -140,9 +139,7 @@ struct deconvolution ...@@ -140,9 +139,7 @@ struct deconvolution
weights(idx_wei.begin(), idx_wei.end()); weights(idx_wei.begin(), idx_wei.end());
} }
}); });
}); });
}); });
return result; return result;
} }
......
...@@ -51,7 +51,6 @@ struct flatten ...@@ -51,7 +51,6 @@ struct flatten
{ {
return args[0].reshape(output_shape); return args[0].reshape(output_shape);
} }
lifetime get_lifetime() const { return lifetime::borrow; }
std::ptrdiff_t output_alias(const std::vector<shape>&) const { return 0; } std::ptrdiff_t output_alias(const std::vector<shape>&) const { return 0; }
}; };
......
This diff is collapsed.
#ifndef MIGRAPHX_GUARD_OPERATORS_ISNAN_HPP
#define MIGRAPHX_GUARD_OPERATORS_ISNAN_HPP
#include <migraphx/op/unary.hpp>
#include <migraphx/config.hpp>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
namespace op {
struct isnan : unary<isnan>
{
auto apply() const
{
return [](auto x) { return std::isnan(x); };
}
std::string name() const { return "isnan"; }
shape compute_shape(std::vector<shape> inputs) const
{
return unary<isnan>::compute_shape(std::move(inputs)).with_type(shape::bool_type);
}
};
} // namespace op
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
#endif
...@@ -69,7 +69,6 @@ struct multibroadcast ...@@ -69,7 +69,6 @@ struct multibroadcast
{ {
return args[0].reshape(output_shape); return args[0].reshape(output_shape);
} }
lifetime get_lifetime() const { return lifetime::borrow; }
std::ptrdiff_t output_alias(const std::vector<shape>&) const { return 0; } std::ptrdiff_t output_alias(const std::vector<shape>&) const { return 0; }
}; };
......
...@@ -181,7 +181,8 @@ struct nonmaxsuppression ...@@ -181,7 +181,8 @@ struct nonmaxsuppression
make_function_output_iterator([&](const auto& x) { sorted_boxes.push(x); }); make_function_output_iterator([&](const auto& x) { sorted_boxes.push(x); });
int64_t box_idx = 0; int64_t box_idx = 0;
transform_if(scores.begin() + score_offset, transform_if(
scores.begin() + score_offset,
scores.begin() + score_offset + box_num, scores.begin() + score_offset + box_num,
insert_to_sorted_boxes, insert_to_sorted_boxes,
[&](auto sc) { [&](auto sc) {
......
This diff is collapsed.
...@@ -38,13 +38,33 @@ struct prefix_scan_op : op_name<Derived> ...@@ -38,13 +38,33 @@ struct prefix_scan_op : op_name<Derived>
shape normalize_compute_shape(std::vector<shape> inputs) const shape normalize_compute_shape(std::vector<shape> inputs) const
{ {
check_shapes{inputs, *this}.has(1); check_shapes{inputs, *this}.has(1);
return inputs.at(0); auto s = inputs.front();
if(s.broadcasted())
{
return {s.type(), s.lens()};
}
else
{
return s.with_lens(s.lens());
}
} }
argument compute(const shape&, std::vector<argument> args) const argument compute(const shape& output_shape, std::vector<argument> args) const
{
argument result{output_shape};
auto s = args[0].get_shape();
if(s == output_shape)
{ {
argument result = args[0].copy(); result = args[0].copy();
auto s = result.get_shape(); }
else
{
visit_all(result, args[0])([&](auto output, auto input) {
par_for(output_shape.elements(),
[&](auto i) { output[output_shape.index(i)] = input[s.index(i)]; });
});
s = output_shape;
}
auto slice = shape{s.type(), {s.lens()[axis]}, {s.strides()[axis]}}; auto slice = shape{s.type(), {s.lens()[axis]}, {s.strides()[axis]}};
auto lens = s.lens(); auto lens = s.lens();
lens[axis] = 1; lens[axis] = 1;
......
This diff is collapsed.
This diff is collapsed.
...@@ -40,7 +40,6 @@ struct scalar ...@@ -40,7 +40,6 @@ struct scalar
{ {
return args[0].reshape(output_shape); return args[0].reshape(output_shape);
} }
lifetime get_lifetime() const { return lifetime::borrow; }
std::ptrdiff_t output_alias(const std::vector<shape>&) const { return 0; } std::ptrdiff_t output_alias(const std::vector<shape>&) const { return 0; }
}; };
......
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment