Unverified Commit ed5f9897 authored by Paul Fultz II's avatar Paul Fultz II Committed by GitHub
Browse files

Merge pull request #110 from ROCmSoftwarePlatform/dropout

Dropout
parents 6feca68d ddcaca24
...@@ -56,6 +56,8 @@ struct onnx_parser ...@@ -56,6 +56,8 @@ struct onnx_parser
add_generic_op("Relu", op::relu{}); add_generic_op("Relu", op::relu{});
add_generic_op("Sub", op::sub{}); add_generic_op("Sub", op::sub{});
add_generic_op("Sum", op::add{}); add_generic_op("Sum", op::add{});
// disable dropout for inference
add_generic_op("Dropout", op::identity{});
add_mem_op("ImageScaler", &onnx_parser::parse_imagescaler); add_mem_op("ImageScaler", &onnx_parser::parse_imagescaler);
add_mem_op("LeakyRelu", &onnx_parser::parse_leaky_relu); add_mem_op("LeakyRelu", &onnx_parser::parse_leaky_relu);
......
dropout-example:Y

01"Dropout test-dropoutZ
0




b
1




B
\ No newline at end of file
...@@ -158,6 +158,17 @@ void transpose_test() ...@@ -158,6 +158,17 @@ void transpose_test()
EXPECT(p == prog); EXPECT(p == prog);
} }
void dropout_test()
{
migraph::program p;
auto input = p.add_parameter("0", migraph::shape{migraph::shape::float_type, {1, 3, 2, 2}});
p.add_instruction(migraph::op::identity{}, input);
auto prog = migraph::parse_onnx("dropout_test.onnx");
EXPECT(p == prog);
}
int main() int main()
{ {
pytorch_conv_bias_test(); pytorch_conv_bias_test();
...@@ -169,4 +180,5 @@ int main() ...@@ -169,4 +180,5 @@ int main()
globalavgpool_test(); globalavgpool_test();
globalmaxpool_test(); globalmaxpool_test();
transpose_test(); transpose_test();
dropout_test();
} }
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment