"vscode:/vscode.git/clone" did not exist on "dc8a84a0113d9134e936530f99d805a897f108bc"
Commit 433e75b0 authored by Khalique's avatar Khalique
Browse files

added add_bcast, implicit_bcast, unknown, softmax

parent 1181aa0c
......@@ -7,6 +7,24 @@
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
struct unknown
{
std::string op;
std::string name() const { return "unknown:" + op; }
shape compute_shape(std::vector<shape> input) const
{
if(input.empty())
return {};
else
return input.front();
}
friend std::ostream& operator<<(std::ostream& os, const unknown& x)
{
os << x.name();
return os;
}
};
/// Create a program from an onnx file
program parse_onnx(const std::string& name);
......
......@@ -15,26 +15,10 @@
#include <migraphx/ranges.hpp>
#include <migraphx/instruction.hpp>
#include <migraphx/config.hpp>
#include <migraphx/onnx.hpp>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
struct unknown
{
std::string op;
std::string name() const { return "unknown:" + op; }
shape compute_shape(std::vector<shape> input) const
{
if(input.empty())
return {};
else
return input.front();
}
friend std::ostream& operator<<(std::ostream& os, const unknown& x)
{
os << x.name();
return os;
}
};
struct onnx_parser
{
......
add_bcast-example:
-
0
12"Add*
axis*
broadcasttest-add_bcastZ
0




Z
1


b
2




B
\ No newline at end of file
implicit_bcast-example:q

0
12"Addtest-multi_bcastZ
0




Z
1


b
2




B
\ No newline at end of file
......@@ -301,6 +301,54 @@ void atan_test()
EXPECT(p == prog);
}
void add_bcast_test()
{
migraphx::program p;
auto l0 = p.add_parameter("0", migraphx::shape{migraphx::shape::float_type, {2, 3, 4, 5}});
auto l1 = p.add_parameter("1", migraphx::shape{migraphx::shape::float_type, {3, 4}});
auto l2 = p.add_instruction(migraphx::op::broadcast{1, l0->get_shape()}, l1);
p.add_instruction(migraphx::op::add{}, l0, l2);
auto prog = migraphx::parse_onnx("add_bcast_test.onnx");
EXPECT(p == prog);
}
void implicit_bcast_test()
{
migraphx::program p;
auto l0 = p.add_parameter("0", migraphx::shape{migraphx::shape::float_type, {2, 3, 4, 5}});
auto l1 = p.add_parameter("1", migraphx::shape{migraphx::shape::float_type, {3, 4}});
auto l2 = p.add_instruction(migraphx::op::multibroadcast{{0, 0, 4, 5}}, l0);
auto l3 = p.add_instruction(migraphx::op::multibroadcast{{0, 0, 4, 5}}, l1);
p.add_instruction(migraphx::op::add{}, l2, l3);
auto prog = migraphx::parse_onnx("implicit_bcast_test.onnx");
EXPECT(p == prog);
}
void unknown_test()
{
migraphx::program p;
auto l0 = p.add_parameter("0", migraphx::shape{migraphx::shape::float_type, {2, 3, 4, 5}});
auto l1 = p.add_parameter("1", migraphx::shape{migraphx::shape::float_type, {3, 4}});
p.add_instruction(migraphx::unknown{"Unknown"}, l0, l1);
auto prog = migraphx::parse_onnx("unknown_test.onnx");
EXPECT(p == prog);
}
void softmax_test()
{
migraphx::program p;
auto l0 = p.add_parameter("0", migraphx::shape{migraphx::shape::float_type, {1, 3}});
p.add_instruction(migraphx::op::softmax{}, l0);
auto prog = migraphx::parse_onnx("softmax_test.onnx");
EXPECT(p == prog);
}
int main()
{
pytorch_conv_bias_test();
......@@ -325,4 +373,7 @@ int main()
asin_test();
acos_test();
atan_test();
add_bcast_test();
implicit_bcast_test();
unknown_test();
}
softmax-example:I

01"Softmax test-softmaxZ
0


b
1


B
\ No newline at end of file
unknown-example:q

0
12"Unknown test-unknownZ
0




Z
1


b
2




B
\ No newline at end of file
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment