Commit 72011beb authored by Paul's avatar Paul
Browse files

Merge branch 'develop' into jit-concat-pointwise

parents d48d9bf7 d37a4df9
......@@ -107,7 +107,7 @@
<summary>Use make_shared or make_unique instead of new</summary>
</message>
</rule>
<!-- <rule>
<rule>
<tokenlist>raw</tokenlist>
<pattern><![CDATA[ \|\| ]]></pattern>
<message>
......@@ -124,7 +124,7 @@
<severity>style</severity>
<summary>Use 'not' instead of !</summary>
</message>
</rule> -->
</rule>
<!-- <rule>
<tokenlist>raw</tokenlist>
<pattern><![CDATA[] (__device__ |__host__ )+(\(|{)]]></pattern>
......
......@@ -53,8 +53,8 @@ int main(int argc, char** argv)
migraphx::program p;
if(cmdOptionExists(argv + 2, argv + argc, "--parse") ||
!cmdOptionExists(argv + 2, argv + argc, "--load"))
if(cmdOptionExists(argv + 2, argv + argc, "--parse") or
not cmdOptionExists(argv + 2, argv + argc, "--load"))
{
std::cout << "Parsing ONNX File" << std::endl;
migraphx::onnx_options options;
......
......@@ -72,7 +72,7 @@ struct sscal_custom_op final : migraphx::experimental_custom_op_base
{
throw std::runtime_error("sscal_custom_op must have 2 input arguments");
}
if(inputs[0].lengths().size() != 1 || inputs[0].lengths()[0] != 1)
if(inputs[0].lengths().size() != 1 or inputs[0].lengths()[0] != 1)
{
throw std::runtime_error("first input argument to sscal_custom_op must be a scalar");
}
......
......@@ -51,16 +51,16 @@ int main(int argc, char** argv)
char** begin = argv + 1;
char** end = argv + argc;
const bool CPU = (std::find(begin, end, std::string("-c")) != end) ||
const bool CPU = (std::find(begin, end, std::string("-c")) != end) or
std::find(begin, end, std::string("--cpu")) != end;
const bool GPU = std::find(begin, end, std::string("-g")) != end ||
const bool GPU = std::find(begin, end, std::string("-g")) != end or
std::find(begin, end, std::string("--gpu")) != end;
const bool FP16 = std::find(begin, end, std::string("-f")) != end ||
const bool FP16 = std::find(begin, end, std::string("-f")) != end or
std::find(begin, end, std::string("--fp16")) != end;
const bool INT8 = std::find(begin, end, std::string("-i")) != end ||
const bool INT8 = std::find(begin, end, std::string("-i")) != end or
std::find(begin, end, std::string("--int8")) != end;
const bool CALIB = std::find(begin, end, std::string("--cal")) != end;
const bool PRINT = std::find(begin, end, std::string("-p")) != end ||
const bool PRINT = std::find(begin, end, std::string("-p")) != end or
std::find(begin, end, std::string("--print")) != end;
migraphx::program prog;
......@@ -182,7 +182,7 @@ void read_nth_digit(const int n, std::vector<float>& digit)
const int HEIGHT = 28;
const int WIDTH = 28;
if(!file.is_open())
if(not file.is_open())
{
return;
}
......
......@@ -82,6 +82,7 @@ add_library(migraphx
simplify_qdq.cpp
sqlite.cpp
rewrite_batchnorm.cpp
rewrite_gelu.cpp
rewrite_pooling.cpp
rewrite_quantization.cpp
rewrite_rnn.cpp
......
......@@ -588,7 +588,7 @@ struct shape : MIGRAPHX_CONST_HANDLE_BASE(shape)
return pout;
}
friend bool operator!=(const shape& px, const shape& py) { return !(px == py); }
friend bool operator!=(const shape& px, const shape& py) { return not(px == py); }
};
/**
......@@ -647,7 +647,7 @@ struct argument : MIGRAPHX_CONST_HANDLE_BASE(argument)
return pout;
}
friend bool operator!=(const argument& px, const argument& py) { return !(px == py); }
friend bool operator!=(const argument& px, const argument& py) { return not(px == py); }
};
/// A target for compilation
......@@ -684,7 +684,7 @@ struct program_parameter_shapes : MIGRAPHX_HANDLE_BASE(program_parameter_shapes)
std::vector<const char*> names() const
{
std::vector<const char*> result(this->size());
if(!result.empty())
if(not result.empty())
{
call(&migraphx_program_parameter_shapes_names, result.data(), this->get_handle_ptr());
}
......@@ -1015,7 +1015,7 @@ struct program : MIGRAPHX_HANDLE_BASE(program)
return module{p_modu, this->share_handle()};
}
friend bool operator!=(const program& px, const program& py) { return !(px == py); }
friend bool operator!=(const program& px, const program& py) { return not(px == py); }
};
// options for migraphx file format options
......
......@@ -39,7 +39,7 @@ instruction_ref insert_apply_alpha_beta(module& m,
auto a = args[0];
auto b = args[1];
auto input_type = a->get_shape().type();
if(!float_equal(alpha.at<float>(0), 1.0))
if(not float_equal(alpha.at<float>(0), 1.0))
{
auto alpha_literal = m.add_literal(alpha);
a = insert_common_op(m, pos, migraphx::make_op("mul"), {alpha_literal, a});
......
......@@ -60,7 +60,7 @@ void eliminate_concat::apply(module& m) const
auto lens = ins->inputs().front()->get_shape().lens();
auto concat_op = concat_opt.get_concat(ins->get_operator());
std::size_t axis_index = tune_axis(lens.size(), concat_op.axis, concat_op.name());
if(axis_index == 0 ||
if(axis_index == 0 or
std::all_of(lens.begin(), lens.begin() + axis_index, [](auto x) { return x == 1; }))
{
// Last input should be an allocation
......
......@@ -71,7 +71,7 @@ static bool try_compute_shape(instruction_ref ins,
return (arg == ins) ? new_shape : arg->get_shape();
});
if(!try_compute_shape(output, input_shapes, mods))
if(not try_compute_shape(output, input_shapes, mods))
{
return false;
}
......
......@@ -39,7 +39,7 @@ T generic_read_file(const std::string& filename)
is.seekg(0, std::ios::beg);
T buffer(size, 0);
if(!is.read(&buffer[0], size))
if(not is.read(&buffer[0], size))
MIGRAPHX_THROW("Error reading file: " + filename);
return buffer;
}
......
......@@ -205,7 +205,7 @@ struct allocation_model
template <typename PrivateDetailTypeErasedU = PrivateDetailTypeErasedT>
private_detail_te_handle_type(
PrivateDetailTypeErasedT value,
typename std::enable_if<!std::is_reference<PrivateDetailTypeErasedU>::value,
typename std::enable_if<not std::is_reference<PrivateDetailTypeErasedU>::value,
int>::type* = nullptr) noexcept
: private_detail_te_value(std::move(value))
{
......@@ -267,7 +267,7 @@ struct allocation_model
private_detail_te_handle_base_type& private_detail_te_get_handle()
{
assert(private_detail_te_handle_mem_var != nullptr);
if(!private_detail_te_handle_mem_var.unique())
if(not private_detail_te_handle_mem_var.unique())
private_detail_te_handle_mem_var = private_detail_te_handle_mem_var->clone();
return *private_detail_te_handle_mem_var;
}
......
......@@ -101,7 +101,7 @@ struct check_shapes
const check_shapes& nelements(std::size_t n) const
{
if(!this->all_of([&](const shape& s) { return s.elements() == n; }))
if(not this->all_of([&](const shape& s) { return s.elements() == n; }))
MIGRAPHX_THROW(prefix() + "Shapes must have only " + std::to_string(n) + " elements");
return *this;
}
......@@ -164,7 +164,7 @@ struct check_shapes
*/
const check_shapes& same_shape() const
{
if(!this->same([](const shape& s) { return s; }))
if(not this->same([](const shape& s) { return s; }))
MIGRAPHX_THROW(prefix() + "Shapes do not match");
return *this;
}
......@@ -174,7 +174,7 @@ struct check_shapes
*/
const check_shapes& same_type() const
{
if(!this->same([](const shape& s) { return s.type(); }))
if(not this->same([](const shape& s) { return s.type(); }))
MIGRAPHX_THROW(prefix() + "Types do not match");
return *this;
}
......@@ -184,10 +184,10 @@ struct check_shapes
*/
const check_shapes& same_dims() const
{
if(!this->same([](const shape& s) { return s.max_lens(); }))
if(not this->same([](const shape& s) { return s.max_lens(); }))
MIGRAPHX_THROW(prefix() + "Dimensions do not match");
if(this->any_of([&](const shape& s) { return s.dynamic(); }))
if(!this->same([](const shape& s) { return s.min_lens(); }))
if(not this->same([](const shape& s) { return s.min_lens(); }))
MIGRAPHX_THROW(prefix() + "Min dynamic dimensions do not match");
return *this;
}
......@@ -197,7 +197,7 @@ struct check_shapes
*/
const check_shapes& same_ndims() const
{
if(!this->same([](const shape& s) { return s.max_lens().size(); }))
if(not this->same([](const shape& s) { return s.max_lens().size(); }))
MIGRAPHX_THROW(prefix() + "Number of dimensions do not match");
return *this;
}
......@@ -207,7 +207,7 @@ struct check_shapes
*/
const check_shapes& standard() const
{
if(!this->all_of([](const shape& s) { return s.standard(); }))
if(not this->all_of([](const shape& s) { return s.standard(); }))
MIGRAPHX_THROW(prefix() + "Shapes are not in standard layout");
return *this;
}
......@@ -217,7 +217,7 @@ struct check_shapes
*/
const check_shapes& standard_or_scalar() const
{
if(!this->all_of([](const shape& s) { return s.standard() or s.scalar(); }))
if(not this->all_of([](const shape& s) { return s.standard() or s.scalar(); }))
MIGRAPHX_THROW(prefix() + "Shapes are not a scalar or in standard layout");
return *this;
}
......@@ -227,7 +227,7 @@ struct check_shapes
*/
const check_shapes& packed() const
{
if(!this->all_of([](const shape& s) { return s.packed(); }))
if(not this->all_of([](const shape& s) { return s.packed(); }))
MIGRAPHX_THROW(prefix() + "Shapes are not packed");
return *this;
}
......@@ -237,7 +237,7 @@ struct check_shapes
*/
const check_shapes& packed_or_broadcasted() const
{
if(!this->all_of([](const shape& s) { return s.packed() or s.broadcasted(); }))
if(not this->all_of([](const shape& s) { return s.packed() or s.broadcasted(); }))
MIGRAPHX_THROW(prefix() + "Shapes are not packed nor broadcasted");
return *this;
}
......@@ -247,7 +247,7 @@ struct check_shapes
*/
const check_shapes& tuple_type() const
{
if(!this->all_of([](const shape& s) { return s.type() == shape::tuple_type; }))
if(not this->all_of([](const shape& s) { return s.type() == shape::tuple_type; }))
MIGRAPHX_THROW(prefix() + "Shapes are not tuple!");
return *this;
}
......@@ -257,7 +257,7 @@ struct check_shapes
*/
const check_shapes& not_transposed() const
{
if(!this->all_of([](const shape& s) { return not s.transposed(); }))
if(not this->all_of([](const shape& s) { return not s.transposed(); }))
MIGRAPHX_THROW(prefix() + "Shapes are transposed");
return *this;
}
......@@ -267,7 +267,7 @@ struct check_shapes
*/
const check_shapes& not_broadcasted() const
{
if(!this->all_of([](const shape& s) { return not s.broadcasted(); }))
if(not this->all_of([](const shape& s) { return not s.broadcasted(); }))
MIGRAPHX_THROW(prefix() + "Shapes are broadcasted");
return *this;
}
......@@ -278,7 +278,7 @@ struct check_shapes
*/
const check_shapes& elements(std::size_t n) const
{
if(!this->all_of([&](const shape& s) { return s.elements() == n; }))
if(not this->all_of([&](const shape& s) { return s.elements() == n; }))
MIGRAPHX_THROW(prefix() + "Wrong number of elements");
return *this;
}
......@@ -288,7 +288,8 @@ struct check_shapes
*/
const check_shapes& batch_not_transposed() const
{
if(!this->all_of([&](const shape& s) { return batch_not_transposed_strides(s.strides()); }))
if(not this->all_of(
[&](const shape& s) { return batch_not_transposed_strides(s.strides()); }))
MIGRAPHX_THROW(prefix() + "Batch size is transposed");
return *this;
}
......
......@@ -183,7 +183,7 @@ struct concat_optimization
template <typename PrivateDetailTypeErasedU = PrivateDetailTypeErasedT>
private_detail_te_handle_type(
PrivateDetailTypeErasedT value,
typename std::enable_if<!std::is_reference<PrivateDetailTypeErasedU>::value,
typename std::enable_if<not std::is_reference<PrivateDetailTypeErasedU>::value,
int>::type* = nullptr) noexcept
: private_detail_te_value(std::move(value))
{
......@@ -233,7 +233,7 @@ struct concat_optimization
private_detail_te_handle_base_type& private_detail_te_get_handle()
{
assert(private_detail_te_handle_mem_var != nullptr);
if(!private_detail_te_handle_mem_var.unique())
if(not private_detail_te_handle_mem_var.unique())
private_detail_te_handle_mem_var = private_detail_te_handle_mem_var->clone();
return *private_detail_te_handle_mem_var;
}
......
......@@ -246,7 +246,7 @@ struct context
template <typename PrivateDetailTypeErasedU = PrivateDetailTypeErasedT>
private_detail_te_handle_type(
PrivateDetailTypeErasedT value,
typename std::enable_if<!std::is_reference<PrivateDetailTypeErasedU>::value,
typename std::enable_if<not std::is_reference<PrivateDetailTypeErasedU>::value,
int>::type* = nullptr) noexcept
: private_detail_te_value(std::move(value))
{
......@@ -306,7 +306,7 @@ struct context
private_detail_te_handle_base_type& private_detail_te_get_handle()
{
assert(private_detail_te_handle_mem_var != nullptr);
if(!private_detail_te_handle_mem_var.unique())
if(not private_detail_te_handle_mem_var.unique())
private_detail_te_handle_mem_var = private_detail_te_handle_mem_var->clone();
return *private_detail_te_handle_mem_var;
}
......
......@@ -31,9 +31,9 @@ namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
template <class Iterator, class EndIterator>
auto is_end(rank<2>, Iterator it, EndIterator) -> decltype(!it._M_dereferenceable())
auto is_end(rank<2>, Iterator it, EndIterator) -> decltype(not it._M_dereferenceable())
{
return !it._M_dereferenceable();
return not it._M_dereferenceable();
}
template <class Iterator, class EndIterator>
......
......@@ -181,7 +181,7 @@ struct marker
template <typename PrivateDetailTypeErasedU = PrivateDetailTypeErasedT>
private_detail_te_handle_type(
PrivateDetailTypeErasedT value,
typename std::enable_if<!std::is_reference<PrivateDetailTypeErasedU>::value,
typename std::enable_if<not std::is_reference<PrivateDetailTypeErasedU>::value,
int>::type* = nullptr) noexcept
: private_detail_te_value(std::move(value))
{
......@@ -233,7 +233,7 @@ struct marker
private_detail_te_handle_base_type& private_detail_te_get_handle()
{
assert(private_detail_te_handle_mem_var != nullptr);
if(!private_detail_te_handle_mem_var.unique())
if(not private_detail_te_handle_mem_var.unique())
private_detail_te_handle_mem_var = private_detail_te_handle_mem_var->clone();
return *private_detail_te_handle_mem_var;
}
......
......@@ -38,11 +38,11 @@ struct gelu_erf_matcher
F f;
auto erf_fn() const
{
return f("erf")(
used_once(),
arg(0)(used_once(),
f("mul")(either_arg(0, 1)(none_of(has_value(M_SQRT1_2, 1e-3)).bind("x"),
has_value(M_SQRT1_2, 1e-3)))));
auto mul_1_sqrt_2 = f("mul")(either_arg(0, 1)(none_of(has_value(M_SQRT1_2, 1e-3)).bind("x"),
has_value(M_SQRT1_2, 1e-3)));
auto div_sqrt_2 =
f("div")(args(none_of(has_value(M_SQRT2, 1e-3)).bind("x"), has_value(M_SQRT2, 1e-3)));
return f("erf")(used_once(), arg(0)(used_once(), any_of(mul_1_sqrt_2, div_sqrt_2)));
}
auto add_erf() const
......
......@@ -564,6 +564,11 @@ MIGRAPHX_BASIC_MATCHER(is_unused, const matcher_context& ctx, instruction_ref in
return nullopt;
}
MIGRAPHX_PRED_MATCHER(broadcast, instruction_ref ins)
{
return contains({"broadcast", "multibroadcast"}, ins->name());
}
template <class... Ms>
auto skip(Ms... ms)
{
......@@ -813,8 +818,7 @@ inline auto has_attribute(const std::string& name)
template <class... Ms>
auto pointwise(Ms... ms)
{
return match::has_attribute("pointwise")(match::any_of(match::nargs(1), match::nargs(2)),
ms...);
return match::has_attribute("pointwise")(ms...);
}
} // namespace match
......
......@@ -219,7 +219,7 @@ struct module
friend std::ostream& operator<<(std::ostream& os, const module& m);
friend bool operator==(const module& x, const module& y);
friend bool operator!=(const module& x, const module& y) { return !(x == y); }
friend bool operator!=(const module& x, const module& y) { return not(x == y); }
private:
void assign(const module& m);
......
......@@ -70,7 +70,7 @@ struct broadcast
MIGRAPHX_THROW("BROADCAST: (broadcast ndims - axis) is less than input ndims");
}
if(!std::equal(input.lens().begin(), input.lens().end(), broadcast_lens.begin() + axis))
if(not std::equal(input.lens().begin(), input.lens().end(), broadcast_lens.begin() + axis))
{
MIGRAPHX_THROW("BROADCAST: when broadcasting, succeeding sizes must match");
}
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment