Unverified Commit 7f97b8ef authored by Ted Themistokleous's avatar Ted Themistokleous Committed by GitHub
Browse files

Merge branch 'simplify_1_mul_div_ops' into divide_by_zero_check

parents 2ba401f0 d1fed367
......@@ -27,6 +27,8 @@
#include <migraphx/config.hpp>
#include <migraphx/operation.hpp>
#include <migraphx/value.hpp>
#include <migraphx/json.hpp>
#include <migraphx/convert_to_json.hpp>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
......@@ -46,6 +48,8 @@ operation make_op(const std::string& name, const Value& v)
return make_op_from_value(name, v);
}
operation make_json_op(const std::string& name, const std::string& s);
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
......
......@@ -181,7 +181,7 @@ struct marker
template <typename PrivateDetailTypeErasedU = PrivateDetailTypeErasedT>
private_detail_te_handle_type(
PrivateDetailTypeErasedT value,
typename std::enable_if<!std::is_reference<PrivateDetailTypeErasedU>::value,
typename std::enable_if<not std::is_reference<PrivateDetailTypeErasedU>::value,
int>::type* = nullptr) noexcept
: private_detail_te_value(std::move(value))
{
......@@ -233,7 +233,7 @@ struct marker
private_detail_te_handle_base_type& private_detail_te_get_handle()
{
assert(private_detail_te_handle_mem_var != nullptr);
if(!private_detail_te_handle_mem_var.unique())
if(not private_detail_te_handle_mem_var.unique())
private_detail_te_handle_mem_var = private_detail_te_handle_mem_var->clone();
return *private_detail_te_handle_mem_var;
}
......
......@@ -38,11 +38,11 @@ struct gelu_erf_matcher
F f;
auto erf_fn() const
{
return f("erf")(
used_once(),
arg(0)(used_once(),
f("mul")(either_arg(0, 1)(none_of(has_value(M_SQRT1_2, 1e-3)).bind("x"),
has_value(M_SQRT1_2, 1e-3)))));
auto mul_1_sqrt_2 = f("mul")(either_arg(0, 1)(none_of(has_value(M_SQRT1_2, 1e-3)).bind("x"),
has_value(M_SQRT1_2, 1e-3)));
auto div_sqrt_2 =
f("div")(args(none_of(has_value(M_SQRT2, 1e-3)).bind("x"), has_value(M_SQRT2, 1e-3)));
return f("erf")(used_once(), arg(0)(used_once(), any_of(mul_1_sqrt_2, div_sqrt_2)));
}
auto add_erf() const
......
......@@ -50,8 +50,8 @@ struct layernorm_matcher
{
return f("div")(arg(0)(x_minus_mean()),
arg(1)(skip_broadcasts(f("sqrt")(
arg(0)(f("add")(either_arg(0, 1)(variance(), has_value(1e-12f))))))));
arg(1)(skip_broadcasts(f("sqrt")(arg(0)(
f("add")(either_arg(0, 1)(variance(), is_constant().bind("eps"))))))));
}
auto matcher() const { return layernorm_onnx(); }
......
......@@ -564,6 +564,11 @@ MIGRAPHX_BASIC_MATCHER(is_unused, const matcher_context& ctx, instruction_ref in
return nullopt;
}
MIGRAPHX_PRED_MATCHER(broadcast, instruction_ref ins)
{
return contains({"broadcast", "multibroadcast"}, ins->name());
}
template <class... Ms>
auto skip(Ms... ms)
{
......@@ -813,8 +818,7 @@ inline auto has_attribute(const std::string& name)
template <class... Ms>
auto pointwise(Ms... ms)
{
return match::has_attribute("pointwise")(match::any_of(match::nargs(1), match::nargs(2)),
ms...);
return match::has_attribute("pointwise")(ms...);
}
} // namespace match
......
......@@ -224,7 +224,7 @@ struct module
friend std::ostream& operator<<(std::ostream& os, const module& m);
friend bool operator==(const module& x, const module& y);
friend bool operator!=(const module& x, const module& y) { return !(x == y); }
friend bool operator!=(const module& x, const module& y) { return not(x == y); }
private:
void assign(const module& m);
......
......@@ -35,17 +35,13 @@ struct onnx_options
{
/// Old way to set default fixed dimension size
std::size_t default_dim_value = 0;
/*!
* Default dynamic dimension size (if both default_dim_value and default_dyn_dim_value
* set parser throws)
*/
/// Default dynamic dimension size (if both default_dim_value and default_dyn_dim_value set
/// parser throws)
shape::dynamic_dimension default_dyn_dim_value = {1, 1, 0};
/// Explicitly specify the dims of an input
std::unordered_map<std::string, std::vector<std::size_t>> map_input_dims = {};
/*!
* Explicitly specify dynamic dims of an input (if both map_input_dims and
* map_dyn_input_dims set parser throws)
*/
/// Explicitly specify dynamic dims of an input (if both map_input_dims and map_dyn_input_dims
/// set parser throws)
std::unordered_map<std::string, std::vector<shape::dynamic_dimension>> map_dyn_input_dims = {};
/// Continue parsing onnx file if an unknown operator is found
bool skip_unknown_operators = false;
......@@ -53,6 +49,8 @@ struct onnx_options
bool print_program_on_error = false;
/// Max iter num for the loop operator
int64_t max_loop_iterations = 10;
/// Use dynamic output for operators when available
bool use_dyn_output = false;
};
/// Create a program from an onnx file
......
......@@ -70,7 +70,7 @@ struct broadcast
MIGRAPHX_THROW("BROADCAST: (broadcast ndims - axis) is less than input ndims");
}
if(!std::equal(input.lens().begin(), input.lens().end(), broadcast_lens.begin() + axis))
if(not std::equal(input.lens().begin(), input.lens().end(), broadcast_lens.begin() + axis))
{
MIGRAPHX_THROW("BROADCAST: when broadcasting, succeeding sizes must match");
}
......
......@@ -37,7 +37,9 @@ enum padding_mode_t
{
default_, // NOLINT
same,
valid
valid,
same_lower,
same_upper
};
// The pooling modes must correspond 1-1 to the operators defined for struct parse_pooling.
......
......@@ -86,7 +86,7 @@ struct concat
{
if(l != axis)
{
if(!std::all_of(inputs.begin(), inputs.end(), [&](auto s) {
if(not std::all_of(inputs.begin(), inputs.end(), [&](auto s) {
return s.lens()[l] == first_shape_lens[l];
}))
{
......
......@@ -45,7 +45,15 @@ struct convert : unary<convert>
shape compute_shape(std::vector<shape> inputs) const
{
check_shapes{inputs, *this}.has(1);
return {target_type, inputs.at(0).lens(), inputs.at(0).strides()};
auto input = inputs.at(0);
if(input.dynamic())
{
return {target_type, input.dyn_dims()};
}
else
{
return {target_type, input.lens(), input.strides()};
}
}
std::string point_op() const
......
This diff is collapsed.
......@@ -43,13 +43,14 @@ struct dot
const shape& b = inputs.at(1);
auto t = a.type();
if(!std::all_of(inputs.begin(), inputs.end(), [](auto s) { return s.lens().size() >= 2; }))
if(not std::all_of(
inputs.begin(), inputs.end(), [](auto s) { return s.lens().size() >= 2; }))
{
MIGRAPHX_THROW("DOT: dot only accept 2 or more dims operands");
}
// only handle the case that the batch size of a and b are the same
if(!std::equal(
if(not std::equal(
a.lens().rbegin() + 2, a.lens().rend(), b.lens().rbegin() + 2, b.lens().rend()))
{
MIGRAPHX_THROW("DOT: batch size of A and B mismatch: {" + to_string_range(a.lens()) +
......
......@@ -65,7 +65,7 @@ struct gather
auto lens = inputs[0].lens();
auto type = inputs[0].type();
lens.erase(lens.begin() + axis);
if(!inputs[1].scalar())
if(not inputs[1].scalar())
{
auto ind_lens = inputs[1].lens();
lens.insert(lens.begin() + axis, ind_lens.begin(), ind_lens.end());
......
This diff is collapsed.
This diff is collapsed.
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment