Commit 7255bc66 authored by Scott Thornton's avatar Scott Thornton
Browse files

Merge branch 'master' into onnx_parsing_squeeze_slice_concat

parents bc367f6b ad414ba9
...@@ -368,6 +368,17 @@ struct test_add_relu ...@@ -368,6 +368,17 @@ struct test_add_relu
} }
}; };
struct test_leaky_relu
{
migraph::program create_program() const
{
migraph::program p;
auto x = p.add_parameter("x", migraph::shape{migraph::shape::float_type, {4, 3, 3, 3}});
p.add_instruction(migraph::op::leaky_relu{0.01}, x);
return p;
}
};
struct test_conv_pooling struct test_conv_pooling
{ {
migraph::program create_program() const migraph::program create_program() const
...@@ -566,6 +577,40 @@ struct test_conv_bn_relu_pooling ...@@ -566,6 +577,40 @@ struct test_conv_bn_relu_pooling
} }
}; };
struct test_concat
{
migraph::program create_program() const
{
migraph::program p;
std::size_t axis = 1;
migraph::shape s0{migraph::shape::int32_type, {2, 2}};
migraph::shape s1{migraph::shape::int32_type, {2, 3}};
migraph::shape s2{migraph::shape::int32_type, {2, 1}};
auto l0 = p.add_parameter("x", s0);
auto l1 = p.add_parameter("y", s1);
auto l2 = p.add_parameter("z", s2);
p.add_instruction(migraph::op::concat{axis}, l0, l1, l2);
return p;
}
};
struct test_concat2
{
migraph::program create_program() const
{
migraph::program p;
std::size_t axis = 0;
migraph::shape s0{migraph::shape::int32_type, {2, 2}};
migraph::shape s1{migraph::shape::int32_type, {3, 2}};
migraph::shape s2{migraph::shape::int32_type, {1, 2}};
auto l0 = p.add_parameter("x", s0);
auto l1 = p.add_parameter("y", s1);
auto l2 = p.add_parameter("z", s2);
p.add_instruction(migraph::op::concat{axis}, l0, l1, l2);
return p;
}
};
struct test_conv_bn_relu_pooling2 struct test_conv_bn_relu_pooling2
{ {
static migraph::instruction_ref static migraph::instruction_ref
...@@ -604,6 +649,8 @@ struct test_conv_bn_relu_pooling2 ...@@ -604,6 +649,8 @@ struct test_conv_bn_relu_pooling2
int main() int main()
{ {
verify_program<test_concat>();
verify_program<test_concat2>();
verify_program<test_add>(); verify_program<test_add>();
verify_program<test_triadd>(); verify_program<test_triadd>();
verify_program<test_triadd2>(); verify_program<test_triadd2>();
...@@ -619,6 +666,7 @@ int main() ...@@ -619,6 +666,7 @@ int main()
verify_program<test_conv2>(); verify_program<test_conv2>();
verify_program<test_conv_relu>(); verify_program<test_conv_relu>();
verify_program<test_add_relu>(); verify_program<test_add_relu>();
verify_program<test_leaky_relu>();
verify_program<test_conv_pooling>(); verify_program<test_conv_pooling>();
verify_program<test_gemm>(); verify_program<test_gemm>();
// verify_program<test_gemm_ld>(); // verify_program<test_gemm_ld>();
......
leaky_relu-example:R
"
01" LeakyRelu*
alpha
#<
test-modelZ
0

b
1

B
\ No newline at end of file
...@@ -54,7 +54,7 @@ void pytorch_conv_bn_relu_maxpool() ...@@ -54,7 +54,7 @@ void pytorch_conv_bn_relu_maxpool()
auto l3 = p.add_instruction(migraph::op::convolution{}, l0, l1); auto l3 = p.add_instruction(migraph::op::convolution{}, l0, l1);
auto l4 = p.add_instruction(migraph::op::broadcast{axis, l3->get_shape()}, l2); auto l4 = p.add_instruction(migraph::op::broadcast{axis, l3->get_shape()}, l2);
auto l5 = p.add_instruction(migraph::op::add{}, l3, l4); auto l5 = p.add_instruction(migraph::op::add{}, l3, l4);
auto l6 = p.add_instruction(migraph::op::batch_norm_inference{}, l5, p3, p4, p5, p6); auto l6 = p.add_instruction(migraph::op::batch_norm_inference{1.0e-5f}, l5, p3, p4, p5, p6);
auto l7 = p.add_instruction(migraph::op::activation{"relu"}, l6); auto l7 = p.add_instruction(migraph::op::activation{"relu"}, l6);
p.add_instruction(migraph::op::pooling{"max", {{0, 0}}, {{2, 2}}, {{2, 2}}}, l7); p.add_instruction(migraph::op::pooling{"max", {{0, 0}}, {{2, 2}}, {{2, 2}}}, l7);
...@@ -88,10 +88,23 @@ void pytorch_conv_relu_maxpool_x2() ...@@ -88,10 +88,23 @@ void pytorch_conv_relu_maxpool_x2()
EXPECT(p == prog); EXPECT(p == prog);
} }
void leaky_relu_test()
{
migraph::program p;
float alpha = 0.01f;
auto l0 = p.add_parameter("0", {migraph::shape::float_type, {3}});
p.add_instruction(migraph::op::leaky_relu{alpha}, l0);
auto prog = migraph::parse_onnx("leaky_relu.onnx");
EXPECT(p == prog);
}
int main() int main()
{ {
pytorch_conv_bias_test(); pytorch_conv_bias_test();
pytorch_conv_relu_maxpool(); pytorch_conv_relu_maxpool();
pytorch_conv_bn_relu_maxpool(); pytorch_conv_bn_relu_maxpool();
pytorch_conv_relu_maxpool_x2(); pytorch_conv_relu_maxpool_x2();
leaky_relu_test();
} }
...@@ -6,6 +6,11 @@ ...@@ -6,6 +6,11 @@
struct simple_operation struct simple_operation
{ {
template <class T, class F>
static auto reflect(T& x, F f)
{
return migraph::pack(f(x.data, "data"));
}
int data = 1; int data = 1;
std::string name() const { return "simple"; } std::string name() const { return "simple"; }
migraph::shape compute_shape(const std::vector<migraph::shape>&) const migraph::shape compute_shape(const std::vector<migraph::shape>&) const
...@@ -19,7 +24,7 @@ struct simple_operation ...@@ -19,7 +24,7 @@ struct simple_operation
} }
friend std::ostream& operator<<(std::ostream& os, const simple_operation& op) friend std::ostream& operator<<(std::ostream& os, const simple_operation& op)
{ {
os << "[" << op.name() << "]"; os << op.name() << "[" << op.data << "]";
return os; return os;
} }
}; };
...@@ -44,9 +49,23 @@ void operation_copy_test() ...@@ -44,9 +49,23 @@ void operation_copy_test()
migraph::operation op1 = s; // NOLINT migraph::operation op1 = s; // NOLINT
migraph::operation op2 = op1; // NOLINT migraph::operation op2 = op1; // NOLINT
// cppcheck-suppress duplicateExpression // cppcheck-suppress duplicateExpression
EXPECT(s.name() == op1.name()); EXPECT(s == op1);
// cppcheck-suppress duplicateExpression // cppcheck-suppress duplicateExpression
EXPECT(op2.name() == op1.name()); EXPECT(op2 == op1);
}
void operation_equal_test()
{
simple_operation s{};
migraph::operation op1 = s;
s.data = 2;
migraph::operation op2 = op1; // NOLINT
migraph::operation op3 = s; // NOLINT
EXPECT(s != op1);
EXPECT(op2 == op1);
EXPECT(op3 != op2);
EXPECT(op3 != op1);
} }
struct not_operation struct not_operation
...@@ -70,7 +89,7 @@ void operation_print() ...@@ -70,7 +89,7 @@ void operation_print()
std::stringstream ss; std::stringstream ss;
ss << op; ss << op;
std::string s = ss.str(); std::string s = ss.str();
EXPECT(s == "[simple]"); EXPECT(s == "simple[1]");
} }
void operation_default_print() void operation_default_print()
...@@ -85,6 +104,7 @@ void operation_default_print() ...@@ -85,6 +104,7 @@ void operation_default_print()
int main() int main()
{ {
operation_copy_test(); operation_copy_test();
operation_equal_test();
operation_any_cast(); operation_any_cast();
operation_print(); operation_print();
operation_default_print(); operation_default_print();
......
...@@ -8,7 +8,8 @@ ...@@ -8,7 +8,8 @@
#include <type_traits> #include <type_traits>
#include <utility> #include <utility>
#include <migraph/shape.hpp> #include <migraph/shape.hpp>
#include <migraph/rank.hpp> #include <migraph/reflect.hpp>
#include <migraph/streamutils.hpp>
#include <migraph/argument.hpp> #include <migraph/argument.hpp>
#include <migraph/context.hpp> #include <migraph/context.hpp>
#include <migraph/auto_any_cast.hpp> #include <migraph/auto_any_cast.hpp>
...@@ -54,11 +55,34 @@ namespace operation_stream { ...@@ -54,11 +55,34 @@ namespace operation_stream {
template <class T> template <class T>
auto operator<<(std::ostream& os, const T& x) -> decltype(os << x.name()) auto operator<<(std::ostream& os, const T& x) -> decltype(os << x.name())
{ {
return os << x.name(); os << x.name();
char delim = '[';
reflect_each(x, [&](auto& y, auto name) {
os << delim;
os << name << "=";
stream_write_value(os, y);
delim = ',';
});
if(delim == ',')
os << "]";
return os;
} }
} // namespace operation_stream } // namespace operation_stream
namespace operation_equal {
template <class T, class U>
auto operator==(const T& x, const U& y) -> decltype(x.name() == y.name())
{
if(x.name() != y.name())
return false;
const auto& yy = any_cast<T>(y);
return reflect_tie(x) == reflect_tie(yy);
}
} // namespace operation_equal
template <class T> template <class T>
auto compute_op(rank<1>, auto compute_op(rank<1>,
const T& x, const T& x,
...@@ -85,13 +109,32 @@ compute_op(const T& x, context& ctx, const shape& output_shape, const std::vecto ...@@ -85,13 +109,32 @@ compute_op(const T& x, context& ctx, const shape& output_shape, const std::vecto
} }
<% <%
interface('operation', interface(
virtual('name', returns='std::string', const=True), 'operation',
virtual('compute_shape', returns='shape', input='const std::vector<shape>&', const=True), virtual('name', returns = 'std::string', const = True),
virtual('compute', returns='argument', ctx='context&', output='const shape&', input='const std::vector<argument>&', const=True, default='compute_op'), virtual('compute_shape', returns = 'shape', input = 'const std::vector<shape>&', const = True),
friend('operator<<', returns='std::ostream &', os='std::ostream &', op='const operation &', using='migraph::operation_stream::operator<<') virtual('compute',
) returns = 'argument',
%> ctx = 'context&',
output = 'const shape&',
input = 'const std::vector<argument>&',
const = True,
default = 'compute_op'),
friend('operator<<',
returns = 'std::ostream &',
os = 'std::ostream &',
op = 'const operation &',
using = 'migraph::operation_stream::operator<<'),
friend('operator==',
returns = 'bool',
x = 'const operation &',
y = 'const operation &',
using = 'migraph::operation_equal::operator==')) %>
inline bool operator!=(const operation& x, const operation& y)
{
return !(x == y);
}
#endif #endif
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment