Unverified Commit 7f97b8ef authored by Ted Themistokleous and committed by GitHub

Merge branch 'simplify_1_mul_div_ops' into divide_by_zero_check

parents 2ba401f0 d1fed367
@@ -27,6 +27,8 @@
 #include <migraphx/config.hpp>
 #include <migraphx/operation.hpp>
 #include <migraphx/value.hpp>
+#include <migraphx/json.hpp>
+#include <migraphx/convert_to_json.hpp>
 namespace migraphx {
 inline namespace MIGRAPHX_INLINE_NS {
@@ -46,6 +48,8 @@ operation make_op(const std::string& name, const Value& v)
     return make_op_from_value(name, v);
 }
+operation make_json_op(const std::string& name, const std::string& s);
 } // namespace MIGRAPHX_INLINE_NS
 } // namespace migraphx
...
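The new `make_json_op` declaration pairs with the `migraphx/json.hpp` include added above. A hedged usage sketch; the attribute values are illustrative, and it is an assumption that the JSON-string overload parses its argument into a `value` before delegating to `make_op`:

```cpp
// Sketch only: build the same operator from a value literal and from a JSON string.
// make_op is the existing MIGraphX factory; make_json_op is the declaration added here.
auto from_value = migraphx::make_op("convolution", {{"padding", {1, 1}}, {"stride", {2, 2}}});
auto from_json =
    migraphx::make_json_op("convolution", R"({"padding": [1, 1], "stride": [2, 2]})");
```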
@@ -181,7 +181,7 @@ struct marker
         template <typename PrivateDetailTypeErasedU = PrivateDetailTypeErasedT>
         private_detail_te_handle_type(
             PrivateDetailTypeErasedT value,
-            typename std::enable_if<!std::is_reference<PrivateDetailTypeErasedU>::value,
+            typename std::enable_if<not std::is_reference<PrivateDetailTypeErasedU>::value,
                                     int>::type* = nullptr) noexcept
             : private_detail_te_value(std::move(value))
         {
@@ -233,7 +233,7 @@ struct marker
     private_detail_te_handle_base_type& private_detail_te_get_handle()
     {
         assert(private_detail_te_handle_mem_var != nullptr);
-        if(!private_detail_te_handle_mem_var.unique())
+        if(not private_detail_te_handle_mem_var.unique())
             private_detail_te_handle_mem_var = private_detail_te_handle_mem_var->clone();
         return *private_detail_te_handle_mem_var;
     }
...
@@ -38,11 +38,11 @@ struct gelu_erf_matcher
     F f;
     auto erf_fn() const
     {
-        return f("erf")(
-            used_once(),
-            arg(0)(used_once(),
-                   f("mul")(either_arg(0, 1)(none_of(has_value(M_SQRT1_2, 1e-3)).bind("x"),
-                                             has_value(M_SQRT1_2, 1e-3)))));
+        auto mul_1_sqrt_2 = f("mul")(either_arg(0, 1)(none_of(has_value(M_SQRT1_2, 1e-3)).bind("x"),
+                                                      has_value(M_SQRT1_2, 1e-3)));
+        auto div_sqrt_2 =
+            f("div")(args(none_of(has_value(M_SQRT2, 1e-3)).bind("x"), has_value(M_SQRT2, 1e-3)));
+        return f("erf")(used_once(), arg(0)(used_once(), any_of(mul_1_sqrt_2, div_sqrt_2)));
     }
     auto add_erf() const
...
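The rewritten matcher accepts both algebraic spellings of the GELU inner term, since multiplying by `M_SQRT1_2` and dividing by `M_SQRT2` compute the same quantity. A standalone check of that identity, with an illustrative input value:

```cpp
#include <cassert>
#include <cmath>

int main()
{
    // M_SQRT1_2 == 1/sqrt(2) and M_SQRT2 == sqrt(2), so the two GELU forms agree:
    //   erf(x * M_SQRT1_2) == erf(x / M_SQRT2)
    double x = 1.7;
    assert(std::abs(std::erf(x * M_SQRT1_2) - std::erf(x / M_SQRT2)) < 1e-12);
    return 0;
}
```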
@@ -50,8 +50,8 @@ struct layernorm_matcher
     {
         return f("div")(arg(0)(x_minus_mean()),
-                        arg(1)(skip_broadcasts(f("sqrt")(
-                            arg(0)(f("add")(either_arg(0, 1)(variance(), has_value(1e-12f))))))));
+                        arg(1)(skip_broadcasts(f("sqrt")(arg(0)(
+                            f("add")(either_arg(0, 1)(variance(), is_constant().bind("eps"))))))));
     }
     auto matcher() const { return layernorm_onnx(); }
...
@@ -564,6 +564,11 @@ MIGRAPHX_BASIC_MATCHER(is_unused, const matcher_context& ctx, instruction_ref in
     return nullopt;
 }
+MIGRAPHX_PRED_MATCHER(broadcast, instruction_ref ins)
+{
+    return contains({"broadcast", "multibroadcast"}, ins->name());
+}
 template <class... Ms>
 auto skip(Ms... ms)
 {
@@ -813,8 +818,7 @@ inline auto has_attribute(const std::string& name)
 template <class... Ms>
 auto pointwise(Ms... ms)
 {
-    return match::has_attribute("pointwise")(match::any_of(match::nargs(1), match::nargs(2)),
-                                             ms...);
+    return match::has_attribute("pointwise")(ms...);
 }
 } // namespace match
...
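The new `broadcast` predicate groups both broadcast flavors under a single matcher. A hedged sketch of how it might combine with other matchers from this header; the surrounding pattern is illustrative, not taken from the diff:

```cpp
// Illustrative pattern: match an add whose second argument is either kind of
// broadcast instruction, binding the broadcast for later inspection.
auto pattern = match::name("add")(match::arg(1)(match::broadcast().bind("b")));
```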
@@ -224,7 +224,7 @@ struct module
     friend std::ostream& operator<<(std::ostream& os, const module& m);
     friend bool operator==(const module& x, const module& y);
-    friend bool operator!=(const module& x, const module& y) { return !(x == y); }
+    friend bool operator!=(const module& x, const module& y) { return not(x == y); }
     private:
     void assign(const module& m);
...
@@ -35,17 +35,13 @@ struct onnx_options
 {
     /// Old way to set default fixed dimension size
     std::size_t default_dim_value = 0;
-    /*!
-     * Default dynamic dimension size (if both default_dim_value and default_dyn_dim_value
-     * set parser throws)
-     */
+    /// Default dynamic dimension size (if both default_dim_value and default_dyn_dim_value set
+    /// parser throws)
     shape::dynamic_dimension default_dyn_dim_value = {1, 1, 0};
     /// Explicitly specify the dims of an input
     std::unordered_map<std::string, std::vector<std::size_t>> map_input_dims = {};
-    /*!
-     * Explicitly specify dynamic dims of an input (if both map_input_dims and
-     * map_dyn_input_dims set parser throws)
-     */
+    /// Explicitly specify dynamic dims of an input (if both map_input_dims and map_dyn_input_dims
+    /// set parser throws)
     std::unordered_map<std::string, std::vector<shape::dynamic_dimension>> map_dyn_input_dims = {};
     /// Continue parsing onnx file if an unknown operator is found
     bool skip_unknown_operators = false;
@@ -53,6 +49,8 @@ struct onnx_options
     bool print_program_on_error = false;
     /// Max iter num for the loop operator
     int64_t max_loop_iterations = 10;
+    /// Use dynamic output for operators when available
+    bool use_dyn_output = false;
 };
 /// Create a program from an onnx file
...
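The new `use_dyn_output` flag joins the existing dynamic-dimension options. A hedged configuration sketch; the model path and input name are placeholders, and `parse_onnx` is the existing entry point this header declares:

```cpp
// Sketch: request dynamic output shapes and a dynamic batch dimension.
migraphx::onnx_options options;
options.use_dyn_output = true;
// dynamic_dimension is {min, max, opt}; here the batch may range from 1 to 4.
options.map_dyn_input_dims["input"] = {{1, 4, 0}, {3, 3, 0}, {224, 224, 0}, {224, 224, 0}};
auto prog = migraphx::parse_onnx("model.onnx", options);
```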
@@ -70,7 +70,7 @@ struct broadcast
             MIGRAPHX_THROW("BROADCAST: (broadcast ndims - axis) is less than input ndims");
         }
-        if(!std::equal(input.lens().begin(), input.lens().end(), broadcast_lens.begin() + axis))
+        if(not std::equal(input.lens().begin(), input.lens().end(), broadcast_lens.begin() + axis))
         {
             MIGRAPHX_THROW("BROADCAST: when broadcasting, succeeding sizes must match");
         }
...
@@ -37,7 +37,9 @@ enum padding_mode_t
 {
     default_, // NOLINT
     same,
-    valid
+    valid,
+    same_lower,
+    same_upper
 };
 // The pooling modes must correspond 1-1 to the operators defined for struct parse_pooling.
...
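`same_lower` and `same_upper` mirror ONNX's `SAME_LOWER`/`SAME_UPPER` auto_pad values; they differ only in where an odd leftover pad unit goes. A worked example under the usual SAME-padding arithmetic (assumed semantics, matching ONNX):

```cpp
#include <cassert>
#include <cstddef>

int main()
{
    // Assumed ONNX SAME semantics (dilation 1):
    //   out = ceil(W / S), total_pad = (out - 1) * S + K - W
    std::size_t w = 4, k = 3, s = 2;
    std::size_t out   = (w + s - 1) / s;       // 2
    std::size_t total = (out - 1) * s + k - w; // 1, an odd leftover
    assert(out == 2 and total == 1);
    // same_upper puts the extra unit at the end; same_lower at the beginning.
    return 0;
}
```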
@@ -86,7 +86,7 @@ struct concat
         {
             if(l != axis)
             {
-                if(!std::all_of(inputs.begin(), inputs.end(), [&](auto s) {
+                if(not std::all_of(inputs.begin(), inputs.end(), [&](auto s) {
                        return s.lens()[l] == first_shape_lens[l];
                    }))
                 {
...
@@ -45,7 +45,15 @@ struct convert : unary<convert>
     shape compute_shape(std::vector<shape> inputs) const
     {
         check_shapes{inputs, *this}.has(1);
-        return {target_type, inputs.at(0).lens(), inputs.at(0).strides()};
+        auto input = inputs.at(0);
+        if(input.dynamic())
+        {
+            return {target_type, input.dyn_dims()};
+        }
+        else
+        {
+            return {target_type, input.lens(), input.strides()};
+        }
     }
     std::string point_op() const
...
@@ -41,8 +41,9 @@ struct convolution
     std::vector<std::size_t> stride   = {1, 1};
     std::vector<std::size_t> dilation = {1, 1};
     int group = 1;
     padding_mode_t padding_mode = default_;
+    bool use_dynamic_same_auto_pad = false;
     template <class Self, class F>
     static auto reflect(Self& self, F f)
@@ -51,7 +52,8 @@ struct convolution
                     f(self.stride, "stride"),
                     f(self.dilation, "dilation"),
                     f(self.group, "group"),
-                    f(self.padding_mode, "padding_mode"));
+                    f(self.padding_mode, "padding_mode"),
+                    f(self.use_dynamic_same_auto_pad, "use_dynamic_same_auto_pad"));
     }
     std::string name() const { return "convolution"; }
@@ -69,43 +71,137 @@ struct convolution
     shape normalize_compute_shape(std::vector<shape> inputs) const
     {
-        check_shapes{inputs, *this}.has(2).same_type().same_ndims().min_ndims(3);
+        check_shapes{inputs, *this, true}.has(2).same_type().same_ndims().min_ndims(3);
         check_attribute_size();
-        // dim num of input and attribute should match
-        auto input_size   = inputs[0].lens().size();
-        auto padding_size = padding.size();
+        // num of dims of input and attribute should match
+        const auto input_size   = inputs[0].max_lens().size();
+        const auto padding_size = padding.size();
         if(not(input_size == padding_size / 2 + 2 or input_size == padding_size + 2))
         {
             MIGRAPHX_THROW("CONVOLUTION: input and attribute size mismatch!");
         }
-        const shape& input   = inputs.at(0);
-        const shape& weights = inputs.at(1);
-        size_t kdims = input_size - 2;
-        if(kdims != this->kdims())
+        const shape& x_shape = inputs.at(0);
+        const shape& w_shape = inputs.at(1);
+        const size_t num_spatial_dims = input_size - 2;
+        if(num_spatial_dims != this->kdims())
         {
-            MIGRAPHX_THROW("convolution: input k-dims does not match attribute size");
+            MIGRAPHX_THROW("CONVOLUTION: input k-dims does not match attribute size");
         }
-        if(input.lens().at(1) != (weights.lens().at(1) * group))
-            MIGRAPHX_THROW("CONVOLUTION: Mismatch channel numbers");
+        if(not x_shape.dynamic() and not w_shape.dynamic() and
+           x_shape.lens().at(1) != (w_shape.lens().at(1) * group))
+            MIGRAPHX_THROW("CONVOLUTION: mismatched channel numbers");
-        std::vector<size_t> output_lens{input.lens()[0], weights.lens()[0]};
-        for(size_t i = 0; i < kdims; i++)
-        {
-            auto padding_factor = 2 * padding[i];
-            if(padding_size == 2 * kdims)
-                padding_factor = padding[i] + padding[i + kdims];
-            output_lens.push_back(std::size_t(std::max<std::ptrdiff_t>(
-                1,
-                (input.lens()[i + 2] - (1 + dilation[i] * (weights.lens()[i + 2] - 1)) +
-                 padding_factor) /
-                        stride[i] +
-                    1)));
-        }
-        return inputs[0].with_lens(output_lens);
+        std::vector<op::padding_mode_t> dyn_pad_modes = {op::padding_mode_t::same_upper,
+                                                         op::padding_mode_t::same_lower};
+        if(use_dynamic_same_auto_pad and not contains(dyn_pad_modes, padding_mode))
+        {
+            MIGRAPHX_THROW("CONVOLUTION: use_dynamic_same_auto_pad set with invalid padding mode");
+        }
+        if(x_shape.dynamic() or w_shape.dynamic())
+        {
+            return dynamic_compute_shape(x_shape, w_shape);
+        }
+        else
+        {
+            return fixed_compute_shape(x_shape, w_shape);
+        }
+    }
+
+    std::vector<std::size_t> calc_conv_lens(std::vector<std::size_t> x_lens,
+                                            std::vector<std::size_t> w_lens) const
+    {
+        const size_t num_spatial_dims = x_lens.size() - 2;
+        std::vector<size_t> ret = {};
+        // calculate the output shape of the convolution: ((W - K + 2P) / S) + 1
+        for(size_t i = 0; i < num_spatial_dims; i++)
+        {
+            if(x_lens[i] == 0 or w_lens[i] == 0)
+            {
+                // for handling when a dimension = 0 (opt of dynamic_dimension)
+                ret.push_back(0);
+            }
+            else
+            {
+                auto padding_factor = 2 * padding[i];
+                if(padding.size() == 2 * num_spatial_dims)
+                {
+                    // when padding is {x0_begin, x1_begin, ... x0_end , x1_end, ...}
+                    padding_factor = padding[i] + padding[i + num_spatial_dims];
+                }
+                ret.push_back(std::size_t(std::max<std::ptrdiff_t>(
+                    1,
+                    (x_lens[i + 2] - (1 + dilation[i] * (w_lens[i + 2] - 1)) + padding_factor) /
+                            stride[i] +
+                        1)));
+            }
+        }
+        return ret;
+    }
+
+    shape dynamic_compute_shape(shape x_shape, shape w_shape) const
+    {
+        std::vector<shape::dynamic_dimension> output_dyn_dims = {};
+        auto dynamic_shape_push_back = [&](const shape& input_shape) {
+            if(input_shape.dynamic())
+            {
+                output_dyn_dims.push_back(input_shape.dyn_dims().at(0));
+            }
+            else
+            {
+                auto l = input_shape.lens().at(0);
+                output_dyn_dims.push_back({l, l, 0});
+            }
+        };
+        dynamic_shape_push_back(x_shape);
+        dynamic_shape_push_back(w_shape);
+        const size_t num_spatial_dims = x_shape.max_lens().size() - 2;
+        if(use_dynamic_same_auto_pad)
+        {
+            for(std::size_t i = 0; i < num_spatial_dims; ++i)
+            {
+                auto ceil_div = [](std::size_t x, std::size_t y) { return (x + y - 1) / y; };
+                auto s = stride[i];
+                if(x_shape.dynamic())
+                {
+                    auto x = x_shape.dyn_dims()[i + 2];
+                    output_dyn_dims.push_back(shape::dynamic_dimension{
+                        ceil_div(x.min, s), ceil_div(x.max, s), ceil_div(x.opt, s)});
+                }
+                else
+                {
+                    auto od = ceil_div(x_shape.lens()[i + 2], s);
+                    output_dyn_dims.push_back(shape::dynamic_dimension{od, od, 0});
+                }
+            }
+        }
+        else
+        {
+            auto min_spatial_dims = calc_conv_lens(x_shape.min_lens(), w_shape.max_lens());
+            auto max_spatial_dims = calc_conv_lens(x_shape.max_lens(), w_shape.min_lens());
+            auto opt_spatial_dims = calc_conv_lens(x_shape.opt_lens(), w_shape.opt_lens());
+            for(size_t i = 0; i < num_spatial_dims; ++i)
+            {
+                output_dyn_dims.push_back(shape::dynamic_dimension{
+                    min_spatial_dims[i], max_spatial_dims[i], opt_spatial_dims[i]});
+            }
+        }
+        return shape{x_shape.type(), output_dyn_dims};
+    }
+
+    shape fixed_compute_shape(shape x_shape, shape w_shape) const
+    {
+        std::vector<size_t> output_lens{x_shape.lens()[0], w_shape.lens()[0]};
+        auto spatial_lens = calc_conv_lens(x_shape.lens(), w_shape.lens());
+        std::for_each(spatial_lens.begin(), spatial_lens.end(), [&output_lens](auto x) {
+            output_lens.push_back(x);
+        });
+        return x_shape.with_lens(output_lens);
     }
     size_t kdims() const
...
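As a quick sanity check of the fixed-shape formula `calc_conv_lens` documents, here is the arithmetic for one familiar configuration; the values are illustrative, not taken from the diff:

```cpp
#include <cassert>
#include <cstddef>

int main()
{
    // out = (W - (1 + d*(K - 1)) + 2P) / S + 1, per the comment in calc_conv_lens.
    std::size_t w = 224, k = 7, p = 3, s = 2, d = 1;
    std::size_t out = (w - (1 + d * (k - 1)) + 2 * p) / s + 1;
    assert(out == 112); // the classic 7x7 stride-2 stem convolution
    return 0;
}
```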
@@ -43,13 +43,14 @@ struct dot
         const shape& b = inputs.at(1);
         auto t = a.type();
-        if(!std::all_of(inputs.begin(), inputs.end(), [](auto s) { return s.lens().size() >= 2; }))
+        if(not std::all_of(
+               inputs.begin(), inputs.end(), [](auto s) { return s.lens().size() >= 2; }))
         {
             MIGRAPHX_THROW("DOT: dot only accept 2 or more dims operands");
         }
         // only handle the case that the batch size of a and b are the same
-        if(!std::equal(
+        if(not std::equal(
                a.lens().rbegin() + 2, a.lens().rend(), b.lens().rbegin() + 2, b.lens().rend()))
         {
             MIGRAPHX_THROW("DOT: batch size of A and B mismatch: {" + to_string_range(a.lens()) +
...
@@ -21,21 +21,32 @@
  * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
  * THE SOFTWARE.
  */
-#ifndef MIGRAPHX_GUARD_RTGLIB_ACOS_HPP
-#define MIGRAPHX_GUARD_RTGLIB_ACOS_HPP
+#ifndef MIGRAPHX_GUARD_OPERATORS_FMOD_HPP
+#define MIGRAPHX_GUARD_OPERATORS_FMOD_HPP
-#include <migraphx/gpu/oper.hpp>
-#include <migraphx/gpu/device/acos.hpp>
+#include <migraphx/op/binary.hpp>
+#include <cmath>
 namespace migraphx {
 inline namespace MIGRAPHX_INLINE_NS {
-namespace gpu {
+namespace op {
-struct hip_acos : unary_device<hip_acos, device::acos>
+struct fmod : binary<fmod>
 {
+    std::string name() const { return "fmod"; }
+
+    value attributes() const
+    {
+        auto a           = base_attributes();
+        a["commutative"] = false;
+        return a;
+    }
+
+    auto apply() const
+    {
+        return [](auto x, auto y) { return std::fmod(x, y); };
+    }
 };
-} // namespace gpu
+} // namespace op
 } // namespace MIGRAPHX_INLINE_NS
 } // namespace migraphx
...
@@ -65,7 +65,7 @@ struct gather
         auto lens = inputs[0].lens();
         auto type = inputs[0].type();
         lens.erase(lens.begin() + axis);
-        if(!inputs[1].scalar())
+        if(not inputs[1].scalar())
         {
             auto ind_lens = inputs[1].lens();
             lens.insert(lens.begin() + axis, ind_lens.begin(), ind_lens.end());
...
@@ -21,21 +21,33 @@
  * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
  * THE SOFTWARE.
  */
-#ifndef MIGRAPHX_GUARD_RTGLIB_ACOSH_HPP
-#define MIGRAPHX_GUARD_RTGLIB_ACOSH_HPP
+#ifndef MIGRAPHX_GUARD_OPERATORS_MOD_HPP
+#define MIGRAPHX_GUARD_OPERATORS_MOD_HPP
-#include <migraphx/gpu/oper.hpp>
-#include <migraphx/gpu/device/acosh.hpp>
+#include <migraphx/op/binary.hpp>
+#include <cmath>
 namespace migraphx {
 inline namespace MIGRAPHX_INLINE_NS {
-namespace gpu {
+namespace op {
-struct hip_acosh : unary_device<hip_acosh, device::acosh>
+struct mod : binary<mod>
 {
+    std::string name() const { return "mod"; }
+
+    value attributes() const
+    {
+        auto a           = base_attributes();
+        a["commutative"] = false;
+        a["point_op"] = "${function:fmod}((${function:remainder}(${0}, ${1})) + ${1}, ${1})";
+        return a;
+    }
+
+    auto apply() const
+    {
+        return [](auto x, auto y) { return std::fmod((std::remainder(x, y)) + y, y); };
+    }
 };
-} // namespace gpu
+} // namespace op
 } // namespace MIGRAPHX_INLINE_NS
 } // namespace migraphx
...
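The two new operators differ only for mixed-sign operands: `fmod` keeps the sign of the dividend, while `mod`'s `fmod(remainder(x, y) + y, y)` form keeps the sign of the divisor, matching ONNX Mod with `fmod=0`. A standalone illustration with sample values:

```cpp
#include <cassert>
#include <cmath>

int main()
{
    double x = -3.0, y = 2.0;
    double fmod_result = std::fmod(x, y);                        // -1: sign of x
    double mod_result  = std::fmod(std::remainder(x, y) + y, y); //  1: sign of y
    assert(fmod_result == -1.0);
    assert(mod_result == 1.0);
    return 0;
}
```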
@@ -45,11 +45,13 @@ namespace op {
 struct nonmaxsuppression
 {
     bool center_point_box = false;
+    bool use_dyn_output   = false;
     template <class Self, class F>
     static auto reflect(Self& self, F f)
     {
-        return pack(f(self.center_point_box, "center_point_box"));
+        return pack(f(self.center_point_box, "center_point_box"),
+                    f(self.use_dyn_output, "use_dyn_output"));
     }
     std::string name() const { return "nonmaxsuppression"; }
@@ -57,27 +59,81 @@ struct nonmaxsuppression
     shape compute_shape(std::vector<shape> inputs) const
     {
         // requires at least 2 inputs
-        check_shapes{{inputs.at(0), inputs.at(1)}, *this}.only_dims(3);
-        auto lens = inputs.front().lens();
+        check_shapes{{inputs.at(0), inputs.at(1)}, *this, true}.only_dims(3).same_ndims();
+        auto boxes_max_lens = inputs.at(0).max_lens();
+        // num batches * num boxes
+        const auto max_num_boxes = boxes_max_lens.at(0) * boxes_max_lens.at(1);
-        // check input shape
-        if(lens[1] != inputs.at(1).lens()[2])
+        auto fixed_shape_error_check = [&]() {
+            auto lens = inputs.front().lens();
+            if(lens[1] != inputs.at(1).lens()[2])
+            {
+                MIGRAPHX_THROW(
+                    "NonMaxSuppression: spatial dimension mismatch between boxes and scores input");
+            }
+            if(lens[0] != inputs.at(1).lens()[0])
+            {
+                MIGRAPHX_THROW(
+                    "NonMaxSuppression: number of batches mismatch between boxes and scores input");
+            }
+        };
+
+        if(use_dyn_output)
         {
-            MIGRAPHX_THROW(
-                "NonMaxSuppression: spatial dimension mismatch between boxes and scores input");
+            if(inputs.at(0).dynamic())
+            {
+                // both boxes and scores should be dynamic
+                // check dynamic dimensions are consistent
+                const auto boxes_dims  = inputs.at(0).dyn_dims();
+                const auto scores_dims = inputs.at(1).dyn_dims();
+                if(boxes_dims.at(1) != scores_dims.at(2))
+                {
+                    MIGRAPHX_THROW("NonMaxSuppression: dynamic spatial dimension mismatch between "
+                                   "boxes and scores input");
+                }
+                if(boxes_dims.at(0) != scores_dims.at(0))
+                {
+                    MIGRAPHX_THROW("NonMaxSuppression: dynamic number of batches mismatch between "
+                                   "boxes and scores input");
+                }
+            }
+            else if(inputs.at(1).dynamic())
+            {
+                // scores has dynamic shape, boxes fixed shape
+                // check that it is only a dynamic number of classes
+                const auto scores_dims = inputs.at(1).dyn_dims();
+                const auto boxes_lens  = inputs.at(0).lens();
+                if(not scores_dims.at(0).is_fixed() or scores_dims.at(0).max != boxes_lens.at(0))
+                {
+                    MIGRAPHX_THROW("NonMaxSuppression: scores dynamic num_classes; num_batches not "
+                                   "fixed or mismatched");
+                }
+                if(not scores_dims.at(2).is_fixed() or scores_dims.at(2).max != boxes_lens.at(1))
+                {
+                    MIGRAPHX_THROW("NonMaxSuppression: scores dynamic num_classes; "
+                                   "spatial_dimension not fixed or mismatches");
+                }
+            }
+            else
+            {
+                fixed_shape_error_check();
+            }
+            std::vector<shape::dynamic_dimension> out_lens = {};
+            out_lens.push_back({0, max_num_boxes, 0});
+            out_lens.push_back({3, 3, 0});
+            return {shape::int64_type, out_lens};
         }
-        // check batch sizes
-        if(lens[0] != inputs.at(1).lens()[0])
+        else
         {
-            MIGRAPHX_THROW(
-                "NonMaxSuppression: number of batches mismatch between boxes and scores input");
+            if(inputs.at(0).dynamic() or inputs.at(1).dynamic())
+            {
+                MIGRAPHX_THROW(
+                    "NonMaxSuppression: dynamic input shape with use_dyn_output set to false");
+            }
+            fixed_shape_error_check();
+            std::vector<std::size_t> out_lens = {max_num_boxes, 3};
+            return {shape::int64_type, out_lens};
         }
-        std::vector<int64_t> out_lens(2);
-        out_lens.at(0) = lens.at(1);
-        out_lens.at(1) = 3;
-        return {shape::int64_type, out_lens};
     }
     struct box
@@ -181,13 +237,13 @@ struct nonmaxsuppression
     }
     template <class Output, class Boxes, class Scores>
-    void compute_nms(Output output,
-                     Boxes boxes,
-                     Scores scores,
-                     const shape& output_shape,
-                     std::size_t max_output_boxes_per_class,
-                     double iou_threshold,
-                     double score_threshold) const
+    std::size_t compute_nms(Output output,
+                            Boxes boxes,
+                            Scores scores,
+                            const shape& max_output_shape,
+                            std::size_t max_output_boxes_per_class,
+                            double iou_threshold,
+                            double score_threshold) const
     {
         std::fill(output.begin(), output.end(), 0);
         const auto& lens = scores.get_shape().lens();
@@ -197,7 +253,7 @@ struct nonmaxsuppression
         // boxes of a class with NMS applied [score, index]
         std::vector<std::pair<double, int64_t>> selected_boxes_inside_class;
         std::vector<int64_t> selected_indices;
-        selected_boxes_inside_class.reserve(output_shape.elements());
+        selected_boxes_inside_class.reserve(max_output_shape.elements());
         // iterate over batches and classes
         shape comp_s{shape::double_type, {num_batches, num_classes}};
         shape_for_each(comp_s, [&](auto idx) {
@@ -210,7 +266,7 @@ struct nonmaxsuppression
             auto boxes_heap = filter_boxes_by_score(scores_start, num_boxes, score_threshold);
             selected_boxes_inside_class.clear();
             // Get the next box with top score, filter by iou_threshold
-            while(!boxes_heap.empty() &&
+            while(not boxes_heap.empty() &&
                   selected_boxes_inside_class.size() < max_output_boxes_per_class)
             {
                 // Check with existing selected boxes for this class, remove box if it
@@ -237,11 +293,14 @@ struct nonmaxsuppression
             }
         });
         std::copy(selected_indices.begin(), selected_indices.end(), output.begin());
+        return selected_indices.size() / 3;
     }
     argument compute(const shape& output_shape, std::vector<argument> args) const
     {
-        argument result{output_shape};
+        // make buffer of maximum size
+        shape max_output_shape = {output_shape.type(), output_shape.max_lens()};
+        argument result{max_output_shape};
         std::size_t max_output_boxes_per_class =
             (args.size() > 2) ? (args.at(2).at<std::size_t>()) : 0;
@@ -249,22 +308,29 @@
         {
             return result;
         }
         double iou_threshold   = (args.size() > 3) ? (args.at(3).at<double>()) : 0.0f;
         double score_threshold = (args.size() > 4) ? (args.at(4).at<double>()) : 0.0f;
+        std::size_t num_selected = 0;
         result.visit([&](auto output) {
             visit_all(args[0], args[1])([&](auto boxes, auto scores) {
-                compute_nms(output,
-                            boxes,
-                            scores,
-                            output_shape,
-                            max_output_boxes_per_class,
-                            iou_threshold,
-                            score_threshold);
+                num_selected = compute_nms(output,
+                                           boxes,
+                                           scores,
+                                           max_output_shape,
+                                           max_output_boxes_per_class,
+                                           iou_threshold,
+                                           score_threshold);
             });
         });
-        return result;
+        if(use_dyn_output)
+        {
+            return result.reshape({output_shape.type(), {num_selected, 3}});
+        }
+        else
+        {
+            return result;
+        }
     }
 };
...
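With `use_dyn_output`, `compute` now fills a maximum-size buffer and trims it to the actual selection count; each selected box occupies one `[batch_index, class_index, box_index]` triple, hence `num_selected = selected_indices.size() / 3`. A small standalone check of that bookkeeping, with sample values only:

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <vector>

int main()
{
    // Two selections encoded as [batch, class, box] triples, ONNX-style.
    std::vector<int64_t> selected_indices = {0, 0, 3, 0, 1, 0};
    std::size_t num_selected = selected_indices.size() / 3; // what compute_nms returns
    assert(num_selected == 2); // the result is then reshaped to {num_selected, 3}
    return 0;
}
```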
@@ -41,8 +41,9 @@ struct quant_convolution
     std::vector<std::size_t> stride   = {1, 1};
     std::vector<std::size_t> dilation = {1, 1};
     padding_mode_t padding_mode = default_;
     int group = 1;
+    bool use_dynamic_same_auto_pad = false;
     template <class Self, class F>
     static auto reflect(Self& self, F f)
@@ -51,7 +52,8 @@ struct quant_convolution
                     f(self.stride, "stride"),
                     f(self.dilation, "dilation"),
                     f(self.padding_mode, "padding_mode"),
-                    f(self.group, "group"));
+                    f(self.group, "group"),
+                    f(self.use_dynamic_same_auto_pad, "use_dynamic_same_auto_pad"));
     }
     value attributes() const
...
@@ -49,13 +49,14 @@ struct quant_dot
             MIGRAPHX_THROW("QUANT_DOT: only support data type int8_t");
         }
-        if(!std::all_of(inputs.begin(), inputs.end(), [](auto s) { return s.lens().size() >= 2; }))
+        if(not std::all_of(
+               inputs.begin(), inputs.end(), [](auto s) { return s.lens().size() >= 2; }))
         {
             MIGRAPHX_THROW("QUANT_DOT: dot only accept 2 or more dims operands");
         }
         // only handle the case that the batch size of a and b are the same
-        if(!std::equal(
+        if(not std::equal(
                a.lens().rbegin() + 2, a.lens().rend(), b.lens().rbegin() + 2, b.lens().rend()))
         {
             MIGRAPHX_THROW("QUANT_DOT: batch size of A and B mismatch: {" +
...
@@ -78,7 +78,7 @@ struct slice
         const std::vector<std::size_t>& lens    = s.lens();
         const std::vector<std::size_t>& strides = s.strides();
         auto offset = 0;
-        if(!axes.empty())
+        if(not axes.empty())
         {
             for(std::size_t i = 0; i < axes.size(); i++)
             {
@@ -109,7 +109,7 @@ struct slice
             MIGRAPHX_THROW("SLICE: input axis " + to_string_range(axes) + " out of range");
         }
-        if(starts.size() != axes.size() || axes.size() != ends.size())
+        if(starts.size() != axes.size() or axes.size() != ends.size())
         {
             MIGRAPHX_THROW("SLICE: inconsistent sizes");
         }
...