Commit 7255bc66 authored by Scott Thornton's avatar Scott Thornton
Browse files

Merge branch 'master' into onnx_parsing_squeeze_slice_concat

parents bc367f6b ad414ba9
......@@ -368,6 +368,17 @@ struct test_add_relu
}
};
struct test_leaky_relu
{
migraph::program create_program() const
{
migraph::program p;
auto x = p.add_parameter("x", migraph::shape{migraph::shape::float_type, {4, 3, 3, 3}});
p.add_instruction(migraph::op::leaky_relu{0.01}, x);
return p;
}
};
struct test_conv_pooling
{
migraph::program create_program() const
......@@ -566,6 +577,40 @@ struct test_conv_bn_relu_pooling
}
};
struct test_concat
{
migraph::program create_program() const
{
migraph::program p;
std::size_t axis = 1;
migraph::shape s0{migraph::shape::int32_type, {2, 2}};
migraph::shape s1{migraph::shape::int32_type, {2, 3}};
migraph::shape s2{migraph::shape::int32_type, {2, 1}};
auto l0 = p.add_parameter("x", s0);
auto l1 = p.add_parameter("y", s1);
auto l2 = p.add_parameter("z", s2);
p.add_instruction(migraph::op::concat{axis}, l0, l1, l2);
return p;
}
};
struct test_concat2
{
migraph::program create_program() const
{
migraph::program p;
std::size_t axis = 0;
migraph::shape s0{migraph::shape::int32_type, {2, 2}};
migraph::shape s1{migraph::shape::int32_type, {3, 2}};
migraph::shape s2{migraph::shape::int32_type, {1, 2}};
auto l0 = p.add_parameter("x", s0);
auto l1 = p.add_parameter("y", s1);
auto l2 = p.add_parameter("z", s2);
p.add_instruction(migraph::op::concat{axis}, l0, l1, l2);
return p;
}
};
struct test_conv_bn_relu_pooling2
{
static migraph::instruction_ref
......@@ -604,6 +649,8 @@ struct test_conv_bn_relu_pooling2
int main()
{
verify_program<test_concat>();
verify_program<test_concat2>();
verify_program<test_add>();
verify_program<test_triadd>();
verify_program<test_triadd2>();
......@@ -619,6 +666,7 @@ int main()
verify_program<test_conv2>();
verify_program<test_conv_relu>();
verify_program<test_add_relu>();
verify_program<test_leaky_relu>();
verify_program<test_conv_pooling>();
verify_program<test_gemm>();
// verify_program<test_gemm_ld>();
......
leaky_relu-example:R
"
01" LeakyRelu*
alpha
#<
test-modelZ
0

b
1

B
\ No newline at end of file
......@@ -54,8 +54,8 @@ void pytorch_conv_bn_relu_maxpool()
auto l3 = p.add_instruction(migraph::op::convolution{}, l0, l1);
auto l4 = p.add_instruction(migraph::op::broadcast{axis, l3->get_shape()}, l2);
auto l5 = p.add_instruction(migraph::op::add{}, l3, l4);
auto l6 = p.add_instruction(migraph::op::batch_norm_inference{}, l5, p3, p4, p5, p6);
auto l7 = p.add_instruction(migraph::op::activation{"relu"}, l6);
auto l6 = p.add_instruction(migraph::op::batch_norm_inference{1.0e-5f}, l5, p3, p4, p5, p6);
auto l7 = p.add_instruction(migraph::op::activation{"relu"}, l6);
p.add_instruction(migraph::op::pooling{"max", {{0, 0}}, {{2, 2}}, {{2, 2}}}, l7);
auto prog = migraph::parse_onnx("conv_bn_relu_maxpool.onnx");
......@@ -88,10 +88,23 @@ void pytorch_conv_relu_maxpool_x2()
EXPECT(p == prog);
}
void leaky_relu_test()
{
migraph::program p;
float alpha = 0.01f;
auto l0 = p.add_parameter("0", {migraph::shape::float_type, {3}});
p.add_instruction(migraph::op::leaky_relu{alpha}, l0);
auto prog = migraph::parse_onnx("leaky_relu.onnx");
EXPECT(p == prog);
}
int main()
{
pytorch_conv_bias_test();
pytorch_conv_relu_maxpool();
pytorch_conv_bn_relu_maxpool();
pytorch_conv_relu_maxpool_x2();
leaky_relu_test();
}
......@@ -6,6 +6,11 @@
struct simple_operation
{
template <class T, class F>
static auto reflect(T& x, F f)
{
return migraph::pack(f(x.data, "data"));
}
int data = 1;
std::string name() const { return "simple"; }
migraph::shape compute_shape(const std::vector<migraph::shape>&) const
......@@ -19,7 +24,7 @@ struct simple_operation
}
friend std::ostream& operator<<(std::ostream& os, const simple_operation& op)
{
os << "[" << op.name() << "]";
os << op.name() << "[" << op.data << "]";
return os;
}
};
......@@ -44,9 +49,23 @@ void operation_copy_test()
migraph::operation op1 = s; // NOLINT
migraph::operation op2 = op1; // NOLINT
// cppcheck-suppress duplicateExpression
EXPECT(s.name() == op1.name());
EXPECT(s == op1);
// cppcheck-suppress duplicateExpression
EXPECT(op2.name() == op1.name());
EXPECT(op2 == op1);
}
void operation_equal_test()
{
simple_operation s{};
migraph::operation op1 = s;
s.data = 2;
migraph::operation op2 = op1; // NOLINT
migraph::operation op3 = s; // NOLINT
EXPECT(s != op1);
EXPECT(op2 == op1);
EXPECT(op3 != op2);
EXPECT(op3 != op1);
}
struct not_operation
......@@ -70,7 +89,7 @@ void operation_print()
std::stringstream ss;
ss << op;
std::string s = ss.str();
EXPECT(s == "[simple]");
EXPECT(s == "simple[1]");
}
void operation_default_print()
......@@ -85,6 +104,7 @@ void operation_default_print()
int main()
{
operation_copy_test();
operation_equal_test();
operation_any_cast();
operation_print();
operation_default_print();
......
......@@ -8,7 +8,8 @@
#include <type_traits>
#include <utility>
#include <migraph/shape.hpp>
#include <migraph/rank.hpp>
#include <migraph/reflect.hpp>
#include <migraph/streamutils.hpp>
#include <migraph/argument.hpp>
#include <migraph/context.hpp>
#include <migraph/auto_any_cast.hpp>
......@@ -54,11 +55,34 @@ namespace operation_stream {
template <class T>
auto operator<<(std::ostream& os, const T& x) -> decltype(os << x.name())
{
return os << x.name();
os << x.name();
char delim = '[';
reflect_each(x, [&](auto& y, auto name) {
os << delim;
os << name << "=";
stream_write_value(os, y);
delim = ',';
});
if(delim == ',')
os << "]";
return os;
}
} // namespace operation_stream
namespace operation_equal {
template <class T, class U>
auto operator==(const T& x, const U& y) -> decltype(x.name() == y.name())
{
if(x.name() != y.name())
return false;
const auto& yy = any_cast<T>(y);
return reflect_tie(x) == reflect_tie(yy);
}
} // namespace operation_equal
template <class T>
auto compute_op(rank<1>,
const T& x,
......@@ -85,13 +109,32 @@ compute_op(const T& x, context& ctx, const shape& output_shape, const std::vecto
}
<%
interface('operation',
virtual('name', returns='std::string', const=True),
virtual('compute_shape', returns='shape', input='const std::vector<shape>&', const=True),
virtual('compute', returns='argument', ctx='context&', output='const shape&', input='const std::vector<argument>&', const=True, default='compute_op'),
friend('operator<<', returns='std::ostream &', os='std::ostream &', op='const operation &', using='migraph::operation_stream::operator<<')
)
%>
interface(
'operation',
virtual('name', returns = 'std::string', const = True),
virtual('compute_shape', returns = 'shape', input = 'const std::vector<shape>&', const = True),
virtual('compute',
returns = 'argument',
ctx = 'context&',
output = 'const shape&',
input = 'const std::vector<argument>&',
const = True,
default = 'compute_op'),
friend('operator<<',
returns = 'std::ostream &',
os = 'std::ostream &',
op = 'const operation &',
using = 'migraph::operation_stream::operator<<'),
friend('operator==',
returns = 'bool',
x = 'const operation &',
y = 'const operation &',
using = 'migraph::operation_equal::operator==')) %>
inline bool operator!=(const operation& x, const operation& y)
{
return !(x == y);
}
#endif
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment