Commit 65b6a759 authored by turneram's avatar turneram
Browse files

Merge remote-tracking branch 'origin/develop' into ck-proto

parents ddb0c230 d78bcdfb
...@@ -63,7 +63,7 @@ set(CMAKE_EXTRA_INCLUDE_FILES) ...@@ -63,7 +63,7 @@ set(CMAKE_EXTRA_INCLUDE_FILES)
include(ROCMSetupVersion) include(ROCMSetupVersion)
rocm_setup_version(VERSION 2.3) rocm_setup_version(VERSION 2.4)
set(MIGRAPHX_SO_VERSION ${PROJECT_VERSION_MAJOR}.${PROJECT_VERSION_MINOR}) set(MIGRAPHX_SO_VERSION ${PROJECT_VERSION_MAJOR}.${PROJECT_VERSION_MINOR})
option( BUILD_SHARED_LIBS "Build as a shared library" ON ) option( BUILD_SHARED_LIBS "Build as a shared library" ON )
......
...@@ -107,7 +107,7 @@ ...@@ -107,7 +107,7 @@
<summary>Use make_shared or make_unique instead of new</summary> <summary>Use make_shared or make_unique instead of new</summary>
</message> </message>
</rule> </rule>
<!-- <rule> <rule>
<tokenlist>raw</tokenlist> <tokenlist>raw</tokenlist>
<pattern><![CDATA[ \|\| ]]></pattern> <pattern><![CDATA[ \|\| ]]></pattern>
<message> <message>
...@@ -124,7 +124,7 @@ ...@@ -124,7 +124,7 @@
<severity>style</severity> <severity>style</severity>
<summary>Use 'not' instead of !</summary> <summary>Use 'not' instead of !</summary>
</message> </message>
</rule> --> </rule>
<!-- <rule> <!-- <rule>
<tokenlist>raw</tokenlist> <tokenlist>raw</tokenlist>
<pattern><![CDATA[] (__device__ |__host__ )+(\(|{)]]></pattern> <pattern><![CDATA[] (__device__ |__host__ )+(\(|{)]]></pattern>
......
...@@ -53,8 +53,8 @@ int main(int argc, char** argv) ...@@ -53,8 +53,8 @@ int main(int argc, char** argv)
migraphx::program p; migraphx::program p;
if(cmdOptionExists(argv + 2, argv + argc, "--parse") || if(cmdOptionExists(argv + 2, argv + argc, "--parse") or
!cmdOptionExists(argv + 2, argv + argc, "--load")) not cmdOptionExists(argv + 2, argv + argc, "--load"))
{ {
std::cout << "Parsing ONNX File" << std::endl; std::cout << "Parsing ONNX File" << std::endl;
migraphx::onnx_options options; migraphx::onnx_options options;
......
...@@ -72,7 +72,7 @@ struct sscal_custom_op final : migraphx::experimental_custom_op_base ...@@ -72,7 +72,7 @@ struct sscal_custom_op final : migraphx::experimental_custom_op_base
{ {
throw std::runtime_error("sscal_custom_op must have 2 input arguments"); throw std::runtime_error("sscal_custom_op must have 2 input arguments");
} }
if(inputs[0].lengths().size() != 1 || inputs[0].lengths()[0] != 1) if(inputs[0].lengths().size() != 1 or inputs[0].lengths()[0] != 1)
{ {
throw std::runtime_error("first input argument to sscal_custom_op must be a scalar"); throw std::runtime_error("first input argument to sscal_custom_op must be a scalar");
} }
......
...@@ -51,16 +51,16 @@ int main(int argc, char** argv) ...@@ -51,16 +51,16 @@ int main(int argc, char** argv)
char** begin = argv + 1; char** begin = argv + 1;
char** end = argv + argc; char** end = argv + argc;
const bool CPU = (std::find(begin, end, std::string("-c")) != end) || const bool CPU = (std::find(begin, end, std::string("-c")) != end) or
std::find(begin, end, std::string("--cpu")) != end; std::find(begin, end, std::string("--cpu")) != end;
const bool GPU = std::find(begin, end, std::string("-g")) != end || const bool GPU = std::find(begin, end, std::string("-g")) != end or
std::find(begin, end, std::string("--gpu")) != end; std::find(begin, end, std::string("--gpu")) != end;
const bool FP16 = std::find(begin, end, std::string("-f")) != end || const bool FP16 = std::find(begin, end, std::string("-f")) != end or
std::find(begin, end, std::string("--fp16")) != end; std::find(begin, end, std::string("--fp16")) != end;
const bool INT8 = std::find(begin, end, std::string("-i")) != end || const bool INT8 = std::find(begin, end, std::string("-i")) != end or
std::find(begin, end, std::string("--int8")) != end; std::find(begin, end, std::string("--int8")) != end;
const bool CALIB = std::find(begin, end, std::string("--cal")) != end; const bool CALIB = std::find(begin, end, std::string("--cal")) != end;
const bool PRINT = std::find(begin, end, std::string("-p")) != end || const bool PRINT = std::find(begin, end, std::string("-p")) != end or
std::find(begin, end, std::string("--print")) != end; std::find(begin, end, std::string("--print")) != end;
migraphx::program prog; migraphx::program prog;
...@@ -182,7 +182,7 @@ void read_nth_digit(const int n, std::vector<float>& digit) ...@@ -182,7 +182,7 @@ void read_nth_digit(const int n, std::vector<float>& digit)
const int HEIGHT = 28; const int HEIGHT = 28;
const int WIDTH = 28; const int WIDTH = 28;
if(!file.is_open()) if(not file.is_open())
{ {
return; return;
} }
......
...@@ -82,6 +82,7 @@ add_library(migraphx ...@@ -82,6 +82,7 @@ add_library(migraphx
simplify_qdq.cpp simplify_qdq.cpp
sqlite.cpp sqlite.cpp
rewrite_batchnorm.cpp rewrite_batchnorm.cpp
rewrite_gelu.cpp
rewrite_pooling.cpp rewrite_pooling.cpp
rewrite_quantization.cpp rewrite_quantization.cpp
rewrite_rnn.cpp rewrite_rnn.cpp
......
...@@ -588,7 +588,7 @@ struct shape : MIGRAPHX_CONST_HANDLE_BASE(shape) ...@@ -588,7 +588,7 @@ struct shape : MIGRAPHX_CONST_HANDLE_BASE(shape)
return pout; return pout;
} }
friend bool operator!=(const shape& px, const shape& py) { return !(px == py); } friend bool operator!=(const shape& px, const shape& py) { return not(px == py); }
}; };
/** /**
...@@ -647,7 +647,7 @@ struct argument : MIGRAPHX_CONST_HANDLE_BASE(argument) ...@@ -647,7 +647,7 @@ struct argument : MIGRAPHX_CONST_HANDLE_BASE(argument)
return pout; return pout;
} }
friend bool operator!=(const argument& px, const argument& py) { return !(px == py); } friend bool operator!=(const argument& px, const argument& py) { return not(px == py); }
}; };
/// A target for compilation /// A target for compilation
...@@ -684,7 +684,7 @@ struct program_parameter_shapes : MIGRAPHX_HANDLE_BASE(program_parameter_shapes) ...@@ -684,7 +684,7 @@ struct program_parameter_shapes : MIGRAPHX_HANDLE_BASE(program_parameter_shapes)
std::vector<const char*> names() const std::vector<const char*> names() const
{ {
std::vector<const char*> result(this->size()); std::vector<const char*> result(this->size());
if(!result.empty()) if(not result.empty())
{ {
call(&migraphx_program_parameter_shapes_names, result.data(), this->get_handle_ptr()); call(&migraphx_program_parameter_shapes_names, result.data(), this->get_handle_ptr());
} }
...@@ -1015,7 +1015,7 @@ struct program : MIGRAPHX_HANDLE_BASE(program) ...@@ -1015,7 +1015,7 @@ struct program : MIGRAPHX_HANDLE_BASE(program)
return module{p_modu, this->share_handle()}; return module{p_modu, this->share_handle()};
} }
friend bool operator!=(const program& px, const program& py) { return !(px == py); } friend bool operator!=(const program& px, const program& py) { return not(px == py); }
}; };
// options for migraphx file format options // options for migraphx file format options
......
...@@ -39,7 +39,7 @@ instruction_ref insert_apply_alpha_beta(module& m, ...@@ -39,7 +39,7 @@ instruction_ref insert_apply_alpha_beta(module& m,
auto a = args[0]; auto a = args[0];
auto b = args[1]; auto b = args[1];
auto input_type = a->get_shape().type(); auto input_type = a->get_shape().type();
if(!float_equal(alpha.at<float>(0), 1.0)) if(not float_equal(alpha.at<float>(0), 1.0))
{ {
auto alpha_literal = m.add_literal(alpha); auto alpha_literal = m.add_literal(alpha);
a = insert_common_op(m, pos, migraphx::make_op("mul"), {alpha_literal, a}); a = insert_common_op(m, pos, migraphx::make_op("mul"), {alpha_literal, a});
......
...@@ -60,7 +60,7 @@ void eliminate_concat::apply(module& m) const ...@@ -60,7 +60,7 @@ void eliminate_concat::apply(module& m) const
auto lens = ins->inputs().front()->get_shape().lens(); auto lens = ins->inputs().front()->get_shape().lens();
auto concat_op = concat_opt.get_concat(ins->get_operator()); auto concat_op = concat_opt.get_concat(ins->get_operator());
std::size_t axis_index = tune_axis(lens.size(), concat_op.axis, concat_op.name()); std::size_t axis_index = tune_axis(lens.size(), concat_op.axis, concat_op.name());
if(axis_index == 0 || if(axis_index == 0 or
std::all_of(lens.begin(), lens.begin() + axis_index, [](auto x) { return x == 1; })) std::all_of(lens.begin(), lens.begin() + axis_index, [](auto x) { return x == 1; }))
{ {
// Last input should be an allocation // Last input should be an allocation
......
...@@ -71,7 +71,7 @@ static bool try_compute_shape(instruction_ref ins, ...@@ -71,7 +71,7 @@ static bool try_compute_shape(instruction_ref ins,
return (arg == ins) ? new_shape : arg->get_shape(); return (arg == ins) ? new_shape : arg->get_shape();
}); });
if(!try_compute_shape(output, input_shapes, mods)) if(not try_compute_shape(output, input_shapes, mods))
{ {
return false; return false;
} }
......
...@@ -39,7 +39,7 @@ T generic_read_file(const std::string& filename) ...@@ -39,7 +39,7 @@ T generic_read_file(const std::string& filename)
is.seekg(0, std::ios::beg); is.seekg(0, std::ios::beg);
T buffer(size, 0); T buffer(size, 0);
if(!is.read(&buffer[0], size)) if(not is.read(&buffer[0], size))
MIGRAPHX_THROW("Error reading file: " + filename); MIGRAPHX_THROW("Error reading file: " + filename);
return buffer; return buffer;
} }
......
...@@ -205,7 +205,7 @@ struct allocation_model ...@@ -205,7 +205,7 @@ struct allocation_model
template <typename PrivateDetailTypeErasedU = PrivateDetailTypeErasedT> template <typename PrivateDetailTypeErasedU = PrivateDetailTypeErasedT>
private_detail_te_handle_type( private_detail_te_handle_type(
PrivateDetailTypeErasedT value, PrivateDetailTypeErasedT value,
typename std::enable_if<!std::is_reference<PrivateDetailTypeErasedU>::value, typename std::enable_if<not std::is_reference<PrivateDetailTypeErasedU>::value,
int>::type* = nullptr) noexcept int>::type* = nullptr) noexcept
: private_detail_te_value(std::move(value)) : private_detail_te_value(std::move(value))
{ {
...@@ -267,7 +267,7 @@ struct allocation_model ...@@ -267,7 +267,7 @@ struct allocation_model
private_detail_te_handle_base_type& private_detail_te_get_handle() private_detail_te_handle_base_type& private_detail_te_get_handle()
{ {
assert(private_detail_te_handle_mem_var != nullptr); assert(private_detail_te_handle_mem_var != nullptr);
if(!private_detail_te_handle_mem_var.unique()) if(not private_detail_te_handle_mem_var.unique())
private_detail_te_handle_mem_var = private_detail_te_handle_mem_var->clone(); private_detail_te_handle_mem_var = private_detail_te_handle_mem_var->clone();
return *private_detail_te_handle_mem_var; return *private_detail_te_handle_mem_var;
} }
......
...@@ -101,7 +101,7 @@ struct check_shapes ...@@ -101,7 +101,7 @@ struct check_shapes
const check_shapes& nelements(std::size_t n) const const check_shapes& nelements(std::size_t n) const
{ {
if(!this->all_of([&](const shape& s) { return s.elements() == n; })) if(not this->all_of([&](const shape& s) { return s.elements() == n; }))
MIGRAPHX_THROW(prefix() + "Shapes must have only " + std::to_string(n) + " elements"); MIGRAPHX_THROW(prefix() + "Shapes must have only " + std::to_string(n) + " elements");
return *this; return *this;
} }
...@@ -164,7 +164,7 @@ struct check_shapes ...@@ -164,7 +164,7 @@ struct check_shapes
*/ */
const check_shapes& same_shape() const const check_shapes& same_shape() const
{ {
if(!this->same([](const shape& s) { return s; })) if(not this->same([](const shape& s) { return s; }))
MIGRAPHX_THROW(prefix() + "Shapes do not match"); MIGRAPHX_THROW(prefix() + "Shapes do not match");
return *this; return *this;
} }
...@@ -174,7 +174,7 @@ struct check_shapes ...@@ -174,7 +174,7 @@ struct check_shapes
*/ */
const check_shapes& same_type() const const check_shapes& same_type() const
{ {
if(!this->same([](const shape& s) { return s.type(); })) if(not this->same([](const shape& s) { return s.type(); }))
MIGRAPHX_THROW(prefix() + "Types do not match"); MIGRAPHX_THROW(prefix() + "Types do not match");
return *this; return *this;
} }
...@@ -184,10 +184,10 @@ struct check_shapes ...@@ -184,10 +184,10 @@ struct check_shapes
*/ */
const check_shapes& same_dims() const const check_shapes& same_dims() const
{ {
if(!this->same([](const shape& s) { return s.max_lens(); })) if(not this->same([](const shape& s) { return s.max_lens(); }))
MIGRAPHX_THROW(prefix() + "Dimensions do not match"); MIGRAPHX_THROW(prefix() + "Dimensions do not match");
if(this->any_of([&](const shape& s) { return s.dynamic(); })) if(this->any_of([&](const shape& s) { return s.dynamic(); }))
if(!this->same([](const shape& s) { return s.min_lens(); })) if(not this->same([](const shape& s) { return s.min_lens(); }))
MIGRAPHX_THROW(prefix() + "Min dynamic dimensions do not match"); MIGRAPHX_THROW(prefix() + "Min dynamic dimensions do not match");
return *this; return *this;
} }
...@@ -197,7 +197,7 @@ struct check_shapes ...@@ -197,7 +197,7 @@ struct check_shapes
*/ */
const check_shapes& same_ndims() const const check_shapes& same_ndims() const
{ {
if(!this->same([](const shape& s) { return s.max_lens().size(); })) if(not this->same([](const shape& s) { return s.max_lens().size(); }))
MIGRAPHX_THROW(prefix() + "Number of dimensions do not match"); MIGRAPHX_THROW(prefix() + "Number of dimensions do not match");
return *this; return *this;
} }
...@@ -207,7 +207,7 @@ struct check_shapes ...@@ -207,7 +207,7 @@ struct check_shapes
*/ */
const check_shapes& standard() const const check_shapes& standard() const
{ {
if(!this->all_of([](const shape& s) { return s.standard(); })) if(not this->all_of([](const shape& s) { return s.standard(); }))
MIGRAPHX_THROW(prefix() + "Shapes are not in standard layout"); MIGRAPHX_THROW(prefix() + "Shapes are not in standard layout");
return *this; return *this;
} }
...@@ -217,7 +217,7 @@ struct check_shapes ...@@ -217,7 +217,7 @@ struct check_shapes
*/ */
const check_shapes& standard_or_scalar() const const check_shapes& standard_or_scalar() const
{ {
if(!this->all_of([](const shape& s) { return s.standard() or s.scalar(); })) if(not this->all_of([](const shape& s) { return s.standard() or s.scalar(); }))
MIGRAPHX_THROW(prefix() + "Shapes are not a scalar or in standard layout"); MIGRAPHX_THROW(prefix() + "Shapes are not a scalar or in standard layout");
return *this; return *this;
} }
...@@ -227,7 +227,7 @@ struct check_shapes ...@@ -227,7 +227,7 @@ struct check_shapes
*/ */
const check_shapes& packed() const const check_shapes& packed() const
{ {
if(!this->all_of([](const shape& s) { return s.packed(); })) if(not this->all_of([](const shape& s) { return s.packed(); }))
MIGRAPHX_THROW(prefix() + "Shapes are not packed"); MIGRAPHX_THROW(prefix() + "Shapes are not packed");
return *this; return *this;
} }
...@@ -237,7 +237,7 @@ struct check_shapes ...@@ -237,7 +237,7 @@ struct check_shapes
*/ */
const check_shapes& packed_or_broadcasted() const const check_shapes& packed_or_broadcasted() const
{ {
if(!this->all_of([](const shape& s) { return s.packed() or s.broadcasted(); })) if(not this->all_of([](const shape& s) { return s.packed() or s.broadcasted(); }))
MIGRAPHX_THROW(prefix() + "Shapes are not packed nor broadcasted"); MIGRAPHX_THROW(prefix() + "Shapes are not packed nor broadcasted");
return *this; return *this;
} }
...@@ -247,7 +247,7 @@ struct check_shapes ...@@ -247,7 +247,7 @@ struct check_shapes
*/ */
const check_shapes& tuple_type() const const check_shapes& tuple_type() const
{ {
if(!this->all_of([](const shape& s) { return s.type() == shape::tuple_type; })) if(not this->all_of([](const shape& s) { return s.type() == shape::tuple_type; }))
MIGRAPHX_THROW(prefix() + "Shapes are not tuple!"); MIGRAPHX_THROW(prefix() + "Shapes are not tuple!");
return *this; return *this;
} }
...@@ -257,7 +257,7 @@ struct check_shapes ...@@ -257,7 +257,7 @@ struct check_shapes
*/ */
const check_shapes& not_transposed() const const check_shapes& not_transposed() const
{ {
if(!this->all_of([](const shape& s) { return not s.transposed(); })) if(not this->all_of([](const shape& s) { return not s.transposed(); }))
MIGRAPHX_THROW(prefix() + "Shapes are transposed"); MIGRAPHX_THROW(prefix() + "Shapes are transposed");
return *this; return *this;
} }
...@@ -267,7 +267,7 @@ struct check_shapes ...@@ -267,7 +267,7 @@ struct check_shapes
*/ */
const check_shapes& not_broadcasted() const const check_shapes& not_broadcasted() const
{ {
if(!this->all_of([](const shape& s) { return not s.broadcasted(); })) if(not this->all_of([](const shape& s) { return not s.broadcasted(); }))
MIGRAPHX_THROW(prefix() + "Shapes are broadcasted"); MIGRAPHX_THROW(prefix() + "Shapes are broadcasted");
return *this; return *this;
} }
...@@ -278,7 +278,7 @@ struct check_shapes ...@@ -278,7 +278,7 @@ struct check_shapes
*/ */
const check_shapes& elements(std::size_t n) const const check_shapes& elements(std::size_t n) const
{ {
if(!this->all_of([&](const shape& s) { return s.elements() == n; })) if(not this->all_of([&](const shape& s) { return s.elements() == n; }))
MIGRAPHX_THROW(prefix() + "Wrong number of elements"); MIGRAPHX_THROW(prefix() + "Wrong number of elements");
return *this; return *this;
} }
...@@ -288,7 +288,8 @@ struct check_shapes ...@@ -288,7 +288,8 @@ struct check_shapes
*/ */
const check_shapes& batch_not_transposed() const const check_shapes& batch_not_transposed() const
{ {
if(!this->all_of([&](const shape& s) { return batch_not_transposed_strides(s.strides()); })) if(not this->all_of(
[&](const shape& s) { return batch_not_transposed_strides(s.strides()); }))
MIGRAPHX_THROW(prefix() + "Batch size is transposed"); MIGRAPHX_THROW(prefix() + "Batch size is transposed");
return *this; return *this;
} }
......
...@@ -183,7 +183,7 @@ struct concat_optimization ...@@ -183,7 +183,7 @@ struct concat_optimization
template <typename PrivateDetailTypeErasedU = PrivateDetailTypeErasedT> template <typename PrivateDetailTypeErasedU = PrivateDetailTypeErasedT>
private_detail_te_handle_type( private_detail_te_handle_type(
PrivateDetailTypeErasedT value, PrivateDetailTypeErasedT value,
typename std::enable_if<!std::is_reference<PrivateDetailTypeErasedU>::value, typename std::enable_if<not std::is_reference<PrivateDetailTypeErasedU>::value,
int>::type* = nullptr) noexcept int>::type* = nullptr) noexcept
: private_detail_te_value(std::move(value)) : private_detail_te_value(std::move(value))
{ {
...@@ -233,7 +233,7 @@ struct concat_optimization ...@@ -233,7 +233,7 @@ struct concat_optimization
private_detail_te_handle_base_type& private_detail_te_get_handle() private_detail_te_handle_base_type& private_detail_te_get_handle()
{ {
assert(private_detail_te_handle_mem_var != nullptr); assert(private_detail_te_handle_mem_var != nullptr);
if(!private_detail_te_handle_mem_var.unique()) if(not private_detail_te_handle_mem_var.unique())
private_detail_te_handle_mem_var = private_detail_te_handle_mem_var->clone(); private_detail_te_handle_mem_var = private_detail_te_handle_mem_var->clone();
return *private_detail_te_handle_mem_var; return *private_detail_te_handle_mem_var;
} }
......
...@@ -246,7 +246,7 @@ struct context ...@@ -246,7 +246,7 @@ struct context
template <typename PrivateDetailTypeErasedU = PrivateDetailTypeErasedT> template <typename PrivateDetailTypeErasedU = PrivateDetailTypeErasedT>
private_detail_te_handle_type( private_detail_te_handle_type(
PrivateDetailTypeErasedT value, PrivateDetailTypeErasedT value,
typename std::enable_if<!std::is_reference<PrivateDetailTypeErasedU>::value, typename std::enable_if<not std::is_reference<PrivateDetailTypeErasedU>::value,
int>::type* = nullptr) noexcept int>::type* = nullptr) noexcept
: private_detail_te_value(std::move(value)) : private_detail_te_value(std::move(value))
{ {
...@@ -306,7 +306,7 @@ struct context ...@@ -306,7 +306,7 @@ struct context
private_detail_te_handle_base_type& private_detail_te_get_handle() private_detail_te_handle_base_type& private_detail_te_get_handle()
{ {
assert(private_detail_te_handle_mem_var != nullptr); assert(private_detail_te_handle_mem_var != nullptr);
if(!private_detail_te_handle_mem_var.unique()) if(not private_detail_te_handle_mem_var.unique())
private_detail_te_handle_mem_var = private_detail_te_handle_mem_var->clone(); private_detail_te_handle_mem_var = private_detail_te_handle_mem_var->clone();
return *private_detail_te_handle_mem_var; return *private_detail_te_handle_mem_var;
} }
......
...@@ -31,9 +31,9 @@ namespace migraphx { ...@@ -31,9 +31,9 @@ namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS { inline namespace MIGRAPHX_INLINE_NS {
template <class Iterator, class EndIterator> template <class Iterator, class EndIterator>
auto is_end(rank<2>, Iterator it, EndIterator) -> decltype(!it._M_dereferenceable()) auto is_end(rank<2>, Iterator it, EndIterator) -> decltype(not it._M_dereferenceable())
{ {
return !it._M_dereferenceable(); return not it._M_dereferenceable();
} }
template <class Iterator, class EndIterator> template <class Iterator, class EndIterator>
......
...@@ -181,7 +181,7 @@ struct marker ...@@ -181,7 +181,7 @@ struct marker
template <typename PrivateDetailTypeErasedU = PrivateDetailTypeErasedT> template <typename PrivateDetailTypeErasedU = PrivateDetailTypeErasedT>
private_detail_te_handle_type( private_detail_te_handle_type(
PrivateDetailTypeErasedT value, PrivateDetailTypeErasedT value,
typename std::enable_if<!std::is_reference<PrivateDetailTypeErasedU>::value, typename std::enable_if<not std::is_reference<PrivateDetailTypeErasedU>::value,
int>::type* = nullptr) noexcept int>::type* = nullptr) noexcept
: private_detail_te_value(std::move(value)) : private_detail_te_value(std::move(value))
{ {
...@@ -233,7 +233,7 @@ struct marker ...@@ -233,7 +233,7 @@ struct marker
private_detail_te_handle_base_type& private_detail_te_get_handle() private_detail_te_handle_base_type& private_detail_te_get_handle()
{ {
assert(private_detail_te_handle_mem_var != nullptr); assert(private_detail_te_handle_mem_var != nullptr);
if(!private_detail_te_handle_mem_var.unique()) if(not private_detail_te_handle_mem_var.unique())
private_detail_te_handle_mem_var = private_detail_te_handle_mem_var->clone(); private_detail_te_handle_mem_var = private_detail_te_handle_mem_var->clone();
return *private_detail_te_handle_mem_var; return *private_detail_te_handle_mem_var;
} }
......
...@@ -38,11 +38,11 @@ struct gelu_erf_matcher ...@@ -38,11 +38,11 @@ struct gelu_erf_matcher
F f; F f;
auto erf_fn() const auto erf_fn() const
{ {
return f("erf")( auto mul_1_sqrt_2 = f("mul")(either_arg(0, 1)(none_of(has_value(M_SQRT1_2, 1e-3)).bind("x"),
used_once(), has_value(M_SQRT1_2, 1e-3)));
arg(0)(used_once(), auto div_sqrt_2 =
f("mul")(either_arg(0, 1)(none_of(has_value(M_SQRT1_2, 1e-3)).bind("x"), f("div")(args(none_of(has_value(M_SQRT2, 1e-3)).bind("x"), has_value(M_SQRT2, 1e-3)));
has_value(M_SQRT1_2, 1e-3))))); return f("erf")(used_once(), arg(0)(used_once(), any_of(mul_1_sqrt_2, div_sqrt_2)));
} }
auto add_erf() const auto add_erf() const
......
...@@ -564,6 +564,11 @@ MIGRAPHX_BASIC_MATCHER(is_unused, const matcher_context& ctx, instruction_ref in ...@@ -564,6 +564,11 @@ MIGRAPHX_BASIC_MATCHER(is_unused, const matcher_context& ctx, instruction_ref in
return nullopt; return nullopt;
} }
MIGRAPHX_PRED_MATCHER(broadcast, instruction_ref ins)
{
return contains({"broadcast", "multibroadcast"}, ins->name());
}
template <class... Ms> template <class... Ms>
auto skip(Ms... ms) auto skip(Ms... ms)
{ {
...@@ -813,8 +818,7 @@ inline auto has_attribute(const std::string& name) ...@@ -813,8 +818,7 @@ inline auto has_attribute(const std::string& name)
template <class... Ms> template <class... Ms>
auto pointwise(Ms... ms) auto pointwise(Ms... ms)
{ {
return match::has_attribute("pointwise")(match::any_of(match::nargs(1), match::nargs(2)), return match::has_attribute("pointwise")(ms...);
ms...);
} }
} // namespace match } // namespace match
......
...@@ -219,7 +219,7 @@ struct module ...@@ -219,7 +219,7 @@ struct module
friend std::ostream& operator<<(std::ostream& os, const module& m); friend std::ostream& operator<<(std::ostream& os, const module& m);
friend bool operator==(const module& x, const module& y); friend bool operator==(const module& x, const module& y);
friend bool operator!=(const module& x, const module& y) { return !(x == y); } friend bool operator!=(const module& x, const module& y) { return not(x == y); }
private: private:
void assign(const module& m); void assign(const module& m);
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment