Commit 433e75b0 authored by Khalique's avatar Khalique
Browse files

added add_bcast, implicit_bcast, unknown, softmax

parent 1181aa0c
...@@ -7,6 +7,24 @@ ...@@ -7,6 +7,24 @@
namespace migraphx { namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS { inline namespace MIGRAPHX_INLINE_NS {
struct unknown
{
std::string op;
std::string name() const { return "unknown:" + op; }
shape compute_shape(std::vector<shape> input) const
{
if(input.empty())
return {};
else
return input.front();
}
friend std::ostream& operator<<(std::ostream& os, const unknown& x)
{
os << x.name();
return os;
}
};
/// Create a program from an onnx file /// Create a program from an onnx file
program parse_onnx(const std::string& name); program parse_onnx(const std::string& name);
......
...@@ -15,26 +15,10 @@ ...@@ -15,26 +15,10 @@
#include <migraphx/ranges.hpp> #include <migraphx/ranges.hpp>
#include <migraphx/instruction.hpp> #include <migraphx/instruction.hpp>
#include <migraphx/config.hpp> #include <migraphx/config.hpp>
#include <migraphx/onnx.hpp>
namespace migraphx { namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS { inline namespace MIGRAPHX_INLINE_NS {
struct unknown
{
std::string op;
std::string name() const { return "unknown:" + op; }
shape compute_shape(std::vector<shape> input) const
{
if(input.empty())
return {};
else
return input.front();
}
friend std::ostream& operator<<(std::ostream& os, const unknown& x)
{
os << x.name();
return os;
}
};
struct onnx_parser struct onnx_parser
{ {
......
add_bcast-example:
-
0
12"Add*
axis*
broadcasttest-add_bcastZ
0




Z
1


b
2




B
\ No newline at end of file
implicit_bcast-example:q

0
12"Addtest-multi_bcastZ
0




Z
1


b
2




B
\ No newline at end of file
...@@ -301,6 +301,54 @@ void atan_test() ...@@ -301,6 +301,54 @@ void atan_test()
EXPECT(p == prog); EXPECT(p == prog);
} }
void add_bcast_test()
{
migraphx::program p;
auto l0 = p.add_parameter("0", migraphx::shape{migraphx::shape::float_type, {2, 3, 4, 5}});
auto l1 = p.add_parameter("1", migraphx::shape{migraphx::shape::float_type, {3, 4}});
auto l2 = p.add_instruction(migraphx::op::broadcast{1, l0->get_shape()}, l1);
p.add_instruction(migraphx::op::add{}, l0, l2);
auto prog = migraphx::parse_onnx("add_bcast_test.onnx");
EXPECT(p == prog);
}
void implicit_bcast_test()
{
migraphx::program p;
auto l0 = p.add_parameter("0", migraphx::shape{migraphx::shape::float_type, {2, 3, 4, 5}});
auto l1 = p.add_parameter("1", migraphx::shape{migraphx::shape::float_type, {3, 4}});
auto l2 = p.add_instruction(migraphx::op::multibroadcast{{0, 0, 4, 5}}, l0);
auto l3 = p.add_instruction(migraphx::op::multibroadcast{{0, 0, 4, 5}}, l1);
p.add_instruction(migraphx::op::add{}, l2, l3);
auto prog = migraphx::parse_onnx("implicit_bcast_test.onnx");
EXPECT(p == prog);
}
void unknown_test()
{
migraphx::program p;
auto l0 = p.add_parameter("0", migraphx::shape{migraphx::shape::float_type, {2, 3, 4, 5}});
auto l1 = p.add_parameter("1", migraphx::shape{migraphx::shape::float_type, {3, 4}});
p.add_instruction(migraphx::unknown{"Unknown"}, l0, l1);
auto prog = migraphx::parse_onnx("unknown_test.onnx");
EXPECT(p == prog);
}
void softmax_test()
{
migraphx::program p;
auto l0 = p.add_parameter("0", migraphx::shape{migraphx::shape::float_type, {1, 3}});
p.add_instruction(migraphx::op::softmax{}, l0);
auto prog = migraphx::parse_onnx("softmax_test.onnx");
EXPECT(p == prog);
}
int main() int main()
{ {
pytorch_conv_bias_test(); pytorch_conv_bias_test();
...@@ -325,4 +373,7 @@ int main() ...@@ -325,4 +373,7 @@ int main()
asin_test(); asin_test();
acos_test(); acos_test();
atan_test(); atan_test();
add_bcast_test();
implicit_bcast_test();
unknown_test();
} }
softmax-example:I

01"Softmax test-softmaxZ
0


b
1


B
\ No newline at end of file
unknown-example:q

0
12"Unknown test-unknownZ
0




Z
1


b
2




B
\ No newline at end of file
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment