Unverified Commit e67aa78c authored by Shucai Xiao's avatar Shucai Xiao Committed by GitHub
Browse files

Dropout change for two outputs (#626)



* add support for latest dropout version

* clang format

* fix a build error

* fix a cppcheck error

* add bool type

* code backup

* code backup

* clang format

* fix build warnings

* clang format

* add the equal operator

* add the equal operator

* clang format

* remove unnecessary code

* refine unit tests

* clang format

* fix review comments and a bug

* clang format

* additional changes

* clang format

* remove unnecessary changes

* remove unnecessary changes
Co-authored-by: default avatarPaul Fultz II <pfultz2@yahoo.com>
parent eacad500
......@@ -86,7 +86,6 @@ struct onnx_parser
add_generic_op("Concat", "concat");
add_generic_op("Cos", "cos");
add_generic_op("Cosh", "cosh");
add_generic_op("Dropout", "identity");
add_generic_op("Erf", "erf");
add_generic_op("Exp", "exp");
add_generic_op("Flatten", "flatten");
......@@ -134,6 +133,7 @@ struct onnx_parser
add_mem_op("Conv", "convolution", &onnx_parser::parse_conv);
add_mem_op("ConvInteger", "quant_convolution", &onnx_parser::parse_conv);
add_mem_op("ConvTranspose", &onnx_parser::parse_conv_transpose);
add_mem_op("Dropout", &onnx_parser::parse_dropout);
add_mem_op("Elu", &onnx_parser::parse_elu);
add_mem_op("Equal", &onnx_parser::parse_equal);
add_mem_op("Expand", &onnx_parser::parse_expand);
......@@ -2375,6 +2375,18 @@ struct onnx_parser
MIGRAPHX_THROW("PARSE_ATEN: unsupported custom operator");
}
std::vector<instruction_ref>
parse_dropout(const std::string&, const node_info&, std::vector<instruction_ref> args)
{
auto out = prog.add_instruction(make_op("identity"), args[0]);
auto s = args[0]->get_shape();
std::vector<int8_t> vec(s.elements(), 1);
shape mask_s{shape::bool_type, s.lens()};
auto mask = prog.add_literal(literal(mask_s, vec));
return {out, mask};
}
template <class T>
std::vector<std::size_t> nonzero_indices(const std::vector<T>& data)
{
......
......@@ -750,10 +750,13 @@ TEST_CASE(dropout_test)
{
    // Reference program: Dropout at inference is an identity on the data,
    // plus an all-true bool-typed mask literal for the second output.
    // NOTE: the diff text had interleaved the pre-change lines
    // (an uncaptured identity instruction and a second `prog` declaration
    // via optimize_onnx) with their replacements; only the post-change
    // version is kept here so the test compiles and matches the parser.
    migraphx::program p;
    auto input = p.add_parameter("0", migraphx::shape{migraphx::shape::float_type, {1, 3, 2, 2}});
    auto out = p.add_instruction(migraphx::op::identity{}, input);
    migraphx::shape s{migraphx::shape::bool_type, {1, 3, 2, 2}};
    // bool tensors are byte-backed; 1 == true (nothing dropped)
    std::vector<int8_t> vec(s.elements(), 1);
    p.add_literal(migraphx::literal(s, vec));
    p.add_return({out});
    auto prog = migraphx::parse_onnx("dropout_test.onnx");
    EXPECT(p == prog);
}
......
......@@ -187,8 +187,6 @@ def create_backend_test(testname=None, target_device=None):
)
backend_test.exclude(
r'test_argmin_no_keepdims_example_select_last_index_cpu')
backend_test.exclude(r'test_dropout_default_mask_cpu')
backend_test.exclude(r'test_dropout_default_mask_ratio_cpu')
backend_test.exclude(r'test_logsoftmax_axis_0_cpu')
backend_test.exclude(r'test_logsoftmax_axis_1_cpu')
backend_test.exclude(r'test_logsoftmax_default_axis_cpu')
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment