"git@developer.sourcefind.cn:modelzoo/resnet50_tensorflow.git" did not exist on "4c0860697849547aa4bd3109068a681358a05ac6"
Commit 8e4b1022 authored by Scott Thornton
Browse files

Made broadcast a unary operator

parent 3d264140
...@@ -59,7 +59,7 @@ void fwd_conv_batchnorm_rewrite::apply(program& p) const ...@@ -59,7 +59,7 @@ void fwd_conv_batchnorm_rewrite::apply(program& p) const
auto l_weights = p.add_literal({weights.get_shape(), new_weights.data()}); auto l_weights = p.add_literal({weights.get_shape(), new_weights.data()});
auto l_bias = p.add_literal({new_bias.get_shape(), new_bias.data()}); auto l_bias = p.add_literal({new_bias.get_shape(), new_bias.data()});
auto c = p.replace_instruction(conv_ins, conv_op, {conv_ins->inputs()[0], l_weights}); auto c = p.replace_instruction(conv_ins, conv_op, {conv_ins->inputs()[0], l_weights});
auto b = p.insert_instruction(ins, op::broadcast{1}, c, l_bias); auto b = p.insert_instruction(ins, op::broadcast{1, c->get_shape()}, l_bias);
p.replace_instruction(ins, op::add{}, {c, b}); p.replace_instruction(ins, op::add{}, {c, b});
} }
} }
......
...@@ -618,34 +618,34 @@ struct flatten ...@@ -618,34 +618,34 @@ struct flatten
struct broadcast struct broadcast
{ {
uint64_t axis = 0; uint64_t axis = 0;
shape broadcast_shape;
std::string name() const { return "broadcast"; } std::string name() const { return "broadcast"; }
shape compute_shape(std::vector<shape> inputs) const shape compute_shape(std::vector<shape> inputs) const
{ {
auto t = inputs.at(0).type(); auto t = inputs.at(0).type();
auto result = inputs.at(0); auto input = inputs.at(0);
auto input = inputs.at(1);
std::vector<size_t> bcast_strides(result.lens().size(), 0); std::vector<size_t> bcast_strides(broadcast_shape.lens().size(), 0);
if(std::all_of( if(std::all_of(
result.lens().cbegin(), result.lens().cend(), [&](auto x) { return x == 1; })) broadcast_shape.lens().cbegin(), broadcast_shape.lens().cend(), [&](auto x) { return x == 1; }))
{ {
if(axis != 0) if(axis != 0)
MIGRAPH_THROW("when broadcasting tensor of size 1, axis should be 0"); MIGRAPH_THROW("when broadcasting tensor of size 1, axis should be 0");
return {t, result.lens(), std::move(bcast_strides)}; return {t, broadcast_shape.lens(), std::move(bcast_strides)};
} }
else else
{ {
assert(result.lens().size() - axis >= input.lens().size()); assert(broadcast_shape.lens().size() - axis >= input.lens().size());
if(!std::equal(input.lens().begin(), input.lens().end(), result.lens().begin() + axis)) if(!std::equal(input.lens().begin(), input.lens().end(), broadcast_shape.lens().begin() + axis))
MIGRAPH_THROW("when broadcasting success sizes must match"); MIGRAPH_THROW("when broadcasting success sizes must match");
std::copy(input.strides().begin(), input.strides().end(), bcast_strides.begin() + axis); std::copy(input.strides().begin(), input.strides().end(), bcast_strides.begin() + axis);
return {t, result.lens(), std::move(bcast_strides)}; return {t, broadcast_shape.lens(), std::move(bcast_strides)};
} }
} }
argument compute(context&, shape output_shape, std::vector<argument> args) const argument compute(context&, shape output_shape, std::vector<argument> args) const
{ {
return {std::move(output_shape), std::move(args.at(1).data)}; return {std::move(output_shape), std::move(args.at(0).data)};
} }
friend std::ostream& operator<<(std::ostream& os, const broadcast& op) friend std::ostream& operator<<(std::ostream& os, const broadcast& op)
{ {
......
...@@ -93,7 +93,7 @@ struct onnx_parser ...@@ -93,7 +93,7 @@ struct onnx_parser
uint64_t axis = (contains(attributes, "axis")) uint64_t axis = (contains(attributes, "axis"))
? parse_value(attributes.at("axis")).at<uint64_t>() ? parse_value(attributes.at("axis")).at<uint64_t>()
: 0; : 0;
auto l = prog.add_instruction(op::broadcast{axis}, args); auto l = prog.add_instruction(op::broadcast{axis, args[0]->get_shape()}, args[1]);
return prog.add_instruction(x, args[0], l); return prog.add_instruction(x, args[0], l);
} }
} }
...@@ -131,7 +131,7 @@ struct onnx_parser ...@@ -131,7 +131,7 @@ struct onnx_parser
{ {
uint64_t axis = 1; uint64_t axis = 1;
auto l1 = prog.add_instruction(op, args[0], args[1]); auto l1 = prog.add_instruction(op, args[0], args[1]);
auto l2 = prog.add_instruction(op::broadcast{axis}, l1, args[2]); auto l2 = prog.add_instruction(op::broadcast{axis, l1->get_shape()}, args[2]);
return prog.add_instruction(op::add{}, l1, l2); return prog.add_instruction(op::add{}, l1, l2);
} }
return prog.add_instruction(op, args); return prog.add_instruction(op, args);
...@@ -223,7 +223,7 @@ struct onnx_parser ...@@ -223,7 +223,7 @@ struct onnx_parser
{ {
uint64_t axis = 1; uint64_t axis = 1;
auto l3 = prog.add_instruction(op::gemm{alpha, beta}, l1, l2); auto l3 = prog.add_instruction(op::gemm{alpha, beta}, l1, l2);
auto l4 = prog.add_instruction(op::broadcast{axis}, l3, args[2]); auto l4 = prog.add_instruction(op::broadcast{axis, l3->get_shape()}, args[2]);
return prog.add_instruction(op::add{}, l3, l4); return prog.add_instruction(op::add{}, l3, l4);
} }
return prog.add_instruction(op::gemm{alpha, beta}, l1, l2); return prog.add_instruction(op::gemm{alpha, beta}, l1, l2);
......
#include <migraph/auto_contiguous.hpp> #include <migraph/auto_contiguous.hpp>
#include <migraph/operators.hpp> #include <migraph/operators.hpp>
#include <migraph/instruction.hpp>
#include <basic_ops.hpp> #include <basic_ops.hpp>
#include <test.hpp> #include <test.hpp>
...@@ -57,7 +58,7 @@ void after_literal_broadcast() ...@@ -57,7 +58,7 @@ void after_literal_broadcast()
auto l2 = p.add_literal(get_2()); auto l2 = p.add_literal(get_2());
EXPECT(p.get_shape().standard()); EXPECT(p.get_shape().standard());
EXPECT(not p.get_shape().broadcasted()); EXPECT(not p.get_shape().broadcasted());
auto b = p.add_instruction(migraph::op::broadcast{}, l1, l2); auto b = p.add_instruction(migraph::op::broadcast{0, l1->get_shape()}, l2);
p.add_instruction(pass_op{}, b); p.add_instruction(pass_op{}, b);
EXPECT(not p.get_shape().standard()); EXPECT(not p.get_shape().standard());
EXPECT(p.get_shape().broadcasted()); EXPECT(p.get_shape().broadcasted());
...@@ -88,7 +89,7 @@ void after_param_broadcast() ...@@ -88,7 +89,7 @@ void after_param_broadcast()
auto l2 = p.add_parameter("2", {migraph::shape::float_type, {2}}); auto l2 = p.add_parameter("2", {migraph::shape::float_type, {2}});
EXPECT(p.get_shape().standard()); EXPECT(p.get_shape().standard());
EXPECT(not p.get_shape().broadcasted()); EXPECT(not p.get_shape().broadcasted());
auto b = p.add_instruction(migraph::op::broadcast{}, l1, l2); auto b = p.add_instruction(migraph::op::broadcast{0, l1->get_shape()}, l2);
p.add_instruction(pass_op{}, b); p.add_instruction(pass_op{}, b);
EXPECT(not p.get_shape().standard()); EXPECT(not p.get_shape().standard());
EXPECT(p.get_shape().broadcasted()); EXPECT(p.get_shape().broadcasted());
......
...@@ -2,6 +2,7 @@ ...@@ -2,6 +2,7 @@
#include <vector> #include <vector>
#include <migraph/literal.hpp> #include <migraph/literal.hpp>
#include <migraph/operators.hpp> #include <migraph/operators.hpp>
#include <migraph/instruction.hpp>
#include <migraph/cpu/cpu_target.hpp> #include <migraph/cpu/cpu_target.hpp>
#include <migraph/verify.hpp> #include <migraph/verify.hpp>
#include "test.hpp" #include "test.hpp"
...@@ -385,7 +386,7 @@ void broadcast_test() ...@@ -385,7 +386,7 @@ void broadcast_test()
uint64_t axis = 0; uint64_t axis = 0;
auto l1 = p.add_literal(migraph::literal{a_shape, a_data}); auto l1 = p.add_literal(migraph::literal{a_shape, a_data});
auto l2 = p.add_literal(migraph::literal{b_shape, b_data}); auto l2 = p.add_literal(migraph::literal{b_shape, b_data});
p.add_instruction(migraph::op::broadcast{axis}, l1, l2); p.add_instruction(migraph::op::broadcast{axis, l1->get_shape()}, l2);
p.compile(migraph::cpu::cpu_target{}); p.compile(migraph::cpu::cpu_target{});
auto result = p.eval({}); auto result = p.eval({});
auto output = result.get<int32_t>(); auto output = result.get<int32_t>();
...@@ -404,7 +405,7 @@ void add_broadcast_test() ...@@ -404,7 +405,7 @@ void add_broadcast_test()
uint64_t axis = 0; uint64_t axis = 0;
auto l1 = p.add_literal(migraph::literal{a_shape, a_data}); auto l1 = p.add_literal(migraph::literal{a_shape, a_data});
auto l2 = p.add_literal(migraph::literal{b_shape, b_data}); auto l2 = p.add_literal(migraph::literal{b_shape, b_data});
auto l3 = p.add_instruction(migraph::op::broadcast{axis}, l1, l2); auto l3 = p.add_instruction(migraph::op::broadcast{axis, l1->get_shape()}, l2);
p.add_instruction(migraph::op::add{}, l1, l3); p.add_instruction(migraph::op::add{}, l1, l3);
p.compile(migraph::cpu::cpu_target{}); p.compile(migraph::cpu::cpu_target{});
auto result = p.eval({}); auto result = p.eval({});
......
...@@ -3,6 +3,7 @@ ...@@ -3,6 +3,7 @@
#include <migraph/literal.hpp> #include <migraph/literal.hpp>
#include <migraph/operators.hpp> #include <migraph/operators.hpp>
#include <migraph/program.hpp> #include <migraph/program.hpp>
#include <migraph/instruction.hpp>
#include <migraph/onnx.hpp> #include <migraph/onnx.hpp>
#include "test.hpp" #include "test.hpp"
...@@ -14,7 +15,7 @@ void pytorch_conv_bias_test() ...@@ -14,7 +15,7 @@ void pytorch_conv_bias_test()
auto l2 = p.add_parameter("2", {migraph::shape::float_type, {1}}); auto l2 = p.add_parameter("2", {migraph::shape::float_type, {1}});
uint64_t axis = 1; uint64_t axis = 1;
auto l3 = p.add_instruction(migraph::op::convolution{}, l0, l1); auto l3 = p.add_instruction(migraph::op::convolution{}, l0, l1);
auto l4 = p.add_instruction(migraph::op::broadcast{axis}, l3, l2); auto l4 = p.add_instruction(migraph::op::broadcast{axis, l3->get_shape()}, l2);
p.add_instruction(migraph::op::add{}, l3, l4); p.add_instruction(migraph::op::add{}, l3, l4);
auto prog = migraph::parse_onnx("conv.onnx"); auto prog = migraph::parse_onnx("conv.onnx");
...@@ -29,7 +30,7 @@ void pytorch_conv_relu_maxpool() ...@@ -29,7 +30,7 @@ void pytorch_conv_relu_maxpool()
auto l2 = p.add_parameter("2", {migraph::shape::float_type, {1}}); auto l2 = p.add_parameter("2", {migraph::shape::float_type, {1}});
uint64_t axis = 1; uint64_t axis = 1;
auto l3 = p.add_instruction(migraph::op::convolution{}, l0, l1); auto l3 = p.add_instruction(migraph::op::convolution{}, l0, l1);
auto l4 = p.add_instruction(migraph::op::broadcast{axis}, l3, l2); auto l4 = p.add_instruction(migraph::op::broadcast{axis, l3->get_shape()}, l2);
auto l5 = p.add_instruction(migraph::op::add{}, l3, l4); auto l5 = p.add_instruction(migraph::op::add{}, l3, l4);
auto l6 = p.add_instruction(migraph::op::activation{"relu"}, l5); auto l6 = p.add_instruction(migraph::op::activation{"relu"}, l5);
p.add_instruction(migraph::op::pooling{"max", {{0, 0}}, {{2, 2}}, {{2, 2}}}, l6); p.add_instruction(migraph::op::pooling{"max", {{0, 0}}, {{2, 2}}, {{2, 2}}}, l6);
...@@ -51,7 +52,7 @@ void pytorch_conv_bn_relu_maxpool() ...@@ -51,7 +52,7 @@ void pytorch_conv_bn_relu_maxpool()
auto p6 = p.add_parameter("6", {migraph::shape::float_type, {1}}); auto p6 = p.add_parameter("6", {migraph::shape::float_type, {1}});
uint64_t axis = 1; uint64_t axis = 1;
auto l3 = p.add_instruction(migraph::op::convolution{}, l0, l1); auto l3 = p.add_instruction(migraph::op::convolution{}, l0, l1);
auto l4 = p.add_instruction(migraph::op::broadcast{axis}, l3, l2); auto l4 = p.add_instruction(migraph::op::broadcast{axis, l3->get_shape()}, l2);
auto l5 = p.add_instruction(migraph::op::add{}, l3, l4); auto l5 = p.add_instruction(migraph::op::add{}, l3, l4);
auto l6 = p.add_instruction(migraph::op::batch_norm_inference{}, l5, p3, p4, p5, p6); auto l6 = p.add_instruction(migraph::op::batch_norm_inference{}, l5, p3, p4, p5, p6);
auto l7 = p.add_instruction(migraph::op::activation{"relu"}, l6); auto l7 = p.add_instruction(migraph::op::activation{"relu"}, l6);
...@@ -69,7 +70,7 @@ void pytorch_conv_relu_maxpool_x2() ...@@ -69,7 +70,7 @@ void pytorch_conv_relu_maxpool_x2()
auto l2 = p.add_parameter("2", {migraph::shape::float_type, {5}}); auto l2 = p.add_parameter("2", {migraph::shape::float_type, {5}});
uint64_t axis = 1; uint64_t axis = 1;
auto l3 = p.add_instruction(migraph::op::convolution{}, l0, l1); auto l3 = p.add_instruction(migraph::op::convolution{}, l0, l1);
auto l4 = p.add_instruction(migraph::op::broadcast{axis}, l3, l2); auto l4 = p.add_instruction(migraph::op::broadcast{axis, l3->get_shape()}, l2);
auto l5 = p.add_instruction(migraph::op::add{}, l3, l4); auto l5 = p.add_instruction(migraph::op::add{}, l3, l4);
auto l6 = p.add_instruction(migraph::op::activation{"relu"}, l5); auto l6 = p.add_instruction(migraph::op::activation{"relu"}, l5);
auto l7 = p.add_instruction(migraph::op::pooling{"max", {{0, 0}}, {{2, 2}}, {{2, 2}}}, l6); auto l7 = p.add_instruction(migraph::op::pooling{"max", {{0, 0}}, {{2, 2}}, {{2, 2}}}, l6);
...@@ -77,7 +78,7 @@ void pytorch_conv_relu_maxpool_x2() ...@@ -77,7 +78,7 @@ void pytorch_conv_relu_maxpool_x2()
auto l8 = p.add_parameter("3", {migraph::shape::float_type, {1, 5, 5, 5}}); auto l8 = p.add_parameter("3", {migraph::shape::float_type, {1, 5, 5, 5}});
auto l9 = p.add_parameter("4", {migraph::shape::float_type, {1}}); auto l9 = p.add_parameter("4", {migraph::shape::float_type, {1}});
auto l10 = p.add_instruction(migraph::op::convolution{}, l7, l8); auto l10 = p.add_instruction(migraph::op::convolution{}, l7, l8);
auto l11 = p.add_instruction(migraph::op::broadcast{axis}, l10, l9); auto l11 = p.add_instruction(migraph::op::broadcast{axis, l10->get_shape()}, l9);
auto l12 = p.add_instruction(migraph::op::add{}, l10, l11); auto l12 = p.add_instruction(migraph::op::add{}, l10, l11);
auto l13 = p.add_instruction(migraph::op::activation{"relu"}, l12); auto l13 = p.add_instruction(migraph::op::activation{"relu"}, l12);
p.add_instruction(migraph::op::pooling{"max", {{0, 0}}, {{2, 2}}, {{2, 2}}}, l13); p.add_instruction(migraph::op::pooling{"max", {{0, 0}}, {{2, 2}}, {{2, 2}}}, l13);
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment