Unverified Commit 866cca5b authored by Shucai Xiao's avatar Shucai Xiao Committed by GitHub
Browse files

Neg operator (#557)

* add the neg operator

* clang format

* add missing operator

* fixed a cppcheck error

* change to use the neg operator

* clang format
parent dd6523c9
...@@ -61,9 +61,10 @@ struct onnx_parser ...@@ -61,9 +61,10 @@ struct onnx_parser
add_generic_op("Erf", op::erf{}); add_generic_op("Erf", op::erf{});
add_generic_op("Exp", op::exp{}); add_generic_op("Exp", op::exp{});
add_generic_op("Dropout", op::identity{}); add_generic_op("Dropout", op::identity{});
add_generic_op("Log", op::log{});
add_generic_op("Floor", op::floor{}); add_generic_op("Floor", op::floor{});
add_generic_op("Identity", op::identity{}); add_generic_op("Identity", op::identity{});
add_generic_op("Log", op::log{});
add_generic_op("Neg", op::neg{});
add_generic_op("Reciprocal", op::recip{}); add_generic_op("Reciprocal", op::recip{});
add_generic_op("Relu", op::relu{}); add_generic_op("Relu", op::relu{});
add_generic_op("Round", op::round{}); add_generic_op("Round", op::round{});
......
...@@ -199,6 +199,7 @@ struct miopen_apply ...@@ -199,6 +199,7 @@ struct miopen_apply
add_quant_convolution_op(); add_quant_convolution_op();
add_pooling_op(); add_pooling_op();
add_batch_norm_inference_op(); add_batch_norm_inference_op();
add_neg_op();
} }
void copy_params() void copy_params()
...@@ -448,6 +449,18 @@ struct miopen_apply ...@@ -448,6 +449,18 @@ struct miopen_apply
output); output);
}); });
} }
// use 0 - input to represent neg
void add_neg_op()
{
apply_map.emplace("neg", [=](instruction_ref ins) {
auto s = ins->get_shape();
std::vector<float> zeros(s.elements(), 0.0f);
auto l0 = prog->add_literal(literal(s, zeros));
auto output = insert_allocation(ins, s);
return prog->replace_instruction(ins, hip_sub{}, l0, ins->inputs().front(), output);
});
}
}; };
void lowering::apply(program& p) const { miopen_apply{&p, this}.apply(); } void lowering::apply(program& p) const { miopen_apply{&p, this}.apply(); }
......
...@@ -1673,6 +1673,23 @@ TEST_CASE(argmin_test_neg_1) ...@@ -1673,6 +1673,23 @@ TEST_CASE(argmin_test_neg_1)
EXPECT(migraphx::verify_range(result_vec, res_gold)); EXPECT(migraphx::verify_range(result_vec, res_gold));
} }
TEST_CASE(neg_test)
{
migraphx::program p;
migraphx::shape s{migraphx::shape::float_type, {2, 3}};
std::vector<float> data = {1.0f, 1.3f, -1.2f, 0.0f, -100.f, 200.f};
auto input = p.add_literal(migraphx::literal(s, data));
auto ret = p.add_instruction(migraphx::op::neg{}, input);
p.add_return({ret});
p.compile(migraphx::cpu::target{});
auto result = p.eval({}).back();
std::vector<float> result_vector;
result.visit([&](auto output) { result_vector.assign(output.begin(), output.end()); });
std::vector<float> gold = {-1.0f, -1.3f, 1.2f, 0.0f, 100.f, -200.f};
EXPECT(migraphx::verify_range(result_vector, gold));
}
TEST_CASE(conv2d_test) TEST_CASE(conv2d_test)
{ {
migraphx::program p; migraphx::program p;
......
...@@ -2551,4 +2551,17 @@ struct test_recip : verify_program<test_recip> ...@@ -2551,4 +2551,17 @@ struct test_recip : verify_program<test_recip>
} }
}; };
struct test_neg : verify_program<test_neg>
{
migraphx::program create_program() const
{
migraphx::program p;
migraphx::shape s{migraphx::shape::double_type, {2, 3, 4, 6}};
auto input = p.add_parameter("x", s);
p.add_instruction(migraphx::op::neg{}, input);
return p;
};
};
int main(int argc, const char* argv[]) { test::run(argc, argv); } int main(int argc, const char* argv[]) { test::run(argc, argv); }
...@@ -1621,6 +1621,16 @@ def min_test(): ...@@ -1621,6 +1621,16 @@ def min_test():
return ([node], [a, b, c], [y]) return ([node], [a, b, c], [y])
@onnx_test
def neg_test():
x = helper.make_tensor_value_info('0', TensorProto.FLOAT, [2, 3])
y = helper.make_tensor_value_info('1', TensorProto.FLOAT, [2, 3])
node = onnx.helper.make_node('Neg', inputs=['0'], outputs=['1'])
return ([node], [x], [y])
@onnx_test @onnx_test
def no_pad_test(): def no_pad_test():
x = helper.make_tensor_value_info('0', TensorProto.FLOAT, [2, 2]) x = helper.make_tensor_value_info('0', TensorProto.FLOAT, [2, 2])
......
neg_test:A
01"Negneg_testZ
0


b
1


B
\ No newline at end of file
...@@ -1271,6 +1271,19 @@ TEST_CASE(no_pad_test) ...@@ -1271,6 +1271,19 @@ TEST_CASE(no_pad_test)
EXPECT(p == prog); EXPECT(p == prog);
} }
TEST_CASE(neg_test)
{
migraphx::program p;
migraphx::shape s{migraphx::shape::float_type, {2, 3}};
auto input = p.add_parameter("0", s);
auto ret = p.add_instruction(migraphx::op::neg{}, input);
p.add_return({ret});
auto prog = migraphx::parse_onnx("neg_test.onnx");
EXPECT(p == prog);
}
TEST_CASE(onehot_test) TEST_CASE(onehot_test)
{ {
migraphx::program p; migraphx::program p;
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment