Unverified Commit 3d912972 authored by Paul Fultz II's avatar Paul Fultz II Committed by GitHub
Browse files

Merge pull request #74 from ROCmSoftwarePlatform/broadcast_unary

Broadcast unary
parents 0566387c 8ee940c1
......@@ -59,7 +59,7 @@ void fwd_conv_batchnorm_rewrite::apply(program& p) const
auto l_weights = p.add_literal({weights.get_shape(), new_weights.data()});
auto l_bias = p.add_literal({new_bias.get_shape(), new_bias.data()});
auto c = p.replace_instruction(conv_ins, conv_op, {conv_ins->inputs()[0], l_weights});
auto b = p.insert_instruction(ins, op::broadcast{1}, c, l_bias);
auto b = p.insert_instruction(ins, op::broadcast{1, c->get_shape()}, l_bias);
p.replace_instruction(ins, op::add{}, {c, b});
}
}
......
......@@ -618,34 +618,36 @@ struct flatten
struct broadcast
{
uint64_t axis = 0;
shape broadcast_shape;
std::string name() const { return "broadcast"; }
shape compute_shape(std::vector<shape> inputs) const
{
auto t = inputs.at(0).type();
auto result = inputs.at(0);
auto input = inputs.at(1);
auto input = inputs.at(0);
std::vector<size_t> bcast_strides(result.lens().size(), 0);
std::vector<size_t> bcast_strides(broadcast_shape.lens().size(), 0);
if(std::all_of(
result.lens().cbegin(), result.lens().cend(), [&](auto x) { return x == 1; }))
if(std::all_of(broadcast_shape.lens().cbegin(), broadcast_shape.lens().cend(), [&](auto x) {
return x == 1;
}))
{
if(axis != 0)
MIGRAPH_THROW("when broadcasting tensor of size 1, axis should be 0");
return {t, result.lens(), std::move(bcast_strides)};
return {t, broadcast_shape.lens(), std::move(bcast_strides)};
}
else
{
assert(result.lens().size() - axis >= input.lens().size());
if(!std::equal(input.lens().begin(), input.lens().end(), result.lens().begin() + axis))
assert(broadcast_shape.lens().size() - axis >= input.lens().size());
if(!std::equal(
input.lens().begin(), input.lens().end(), broadcast_shape.lens().begin() + axis))
MIGRAPH_THROW("when broadcasting success sizes must match");
std::copy(input.strides().begin(), input.strides().end(), bcast_strides.begin() + axis);
return {t, result.lens(), std::move(bcast_strides)};
return {t, broadcast_shape.lens(), std::move(bcast_strides)};
}
}
argument compute(context&, shape output_shape, std::vector<argument> args) const
{
return {std::move(output_shape), std::move(args.at(1).data)};
return {std::move(output_shape), std::move(args.at(0).data)};
}
friend std::ostream& operator<<(std::ostream& os, const broadcast& op)
{
......
......@@ -93,7 +93,8 @@ struct onnx_parser
uint64_t axis = (contains(attributes, "axis"))
? parse_value(attributes.at("axis")).at<uint64_t>()
: 0;
auto l = prog.add_instruction(op::broadcast{axis}, args);
auto l =
prog.add_instruction(op::broadcast{axis, args[0]->get_shape()}, args[1]);
return prog.add_instruction(x, args[0], l);
}
}
......@@ -131,7 +132,7 @@ struct onnx_parser
{
uint64_t axis = 1;
auto l1 = prog.add_instruction(op, args[0], args[1]);
auto l2 = prog.add_instruction(op::broadcast{axis}, l1, args[2]);
auto l2 = prog.add_instruction(op::broadcast{axis, l1->get_shape()}, args[2]);
return prog.add_instruction(op::add{}, l1, l2);
}
return prog.add_instruction(op, args);
......@@ -223,7 +224,7 @@ struct onnx_parser
{
uint64_t axis = 1;
auto l3 = prog.add_instruction(op::gemm{alpha, beta}, l1, l2);
auto l4 = prog.add_instruction(op::broadcast{axis}, l3, args[2]);
auto l4 = prog.add_instruction(op::broadcast{axis, l3->get_shape()}, args[2]);
return prog.add_instruction(op::add{}, l3, l4);
}
return prog.add_instruction(op::gemm{alpha, beta}, l1, l2);
......
#include <migraph/auto_contiguous.hpp>
#include <migraph/operators.hpp>
#include <migraph/instruction.hpp>
#include <basic_ops.hpp>
#include <test.hpp>
......@@ -57,7 +58,7 @@ void after_literal_broadcast()
auto l2 = p.add_literal(get_2());
EXPECT(p.get_shape().standard());
EXPECT(not p.get_shape().broadcasted());
auto b = p.add_instruction(migraph::op::broadcast{}, l1, l2);
auto b = p.add_instruction(migraph::op::broadcast{0, l1->get_shape()}, l2);
p.add_instruction(pass_op{}, b);
EXPECT(not p.get_shape().standard());
EXPECT(p.get_shape().broadcasted());
......@@ -88,7 +89,7 @@ void after_param_broadcast()
auto l2 = p.add_parameter("2", {migraph::shape::float_type, {2}});
EXPECT(p.get_shape().standard());
EXPECT(not p.get_shape().broadcasted());
auto b = p.add_instruction(migraph::op::broadcast{}, l1, l2);
auto b = p.add_instruction(migraph::op::broadcast{0, l1->get_shape()}, l2);
p.add_instruction(pass_op{}, b);
EXPECT(not p.get_shape().standard());
EXPECT(p.get_shape().broadcasted());
......
......@@ -2,6 +2,7 @@
#include <vector>
#include <migraph/literal.hpp>
#include <migraph/operators.hpp>
#include <migraph/instruction.hpp>
#include <migraph/cpu/cpu_target.hpp>
#include <migraph/verify.hpp>
#include "test.hpp"
......@@ -385,7 +386,7 @@ void broadcast_test()
uint64_t axis = 0;
auto l1 = p.add_literal(migraph::literal{a_shape, a_data});
auto l2 = p.add_literal(migraph::literal{b_shape, b_data});
p.add_instruction(migraph::op::broadcast{axis}, l1, l2);
p.add_instruction(migraph::op::broadcast{axis, l1->get_shape()}, l2);
p.compile(migraph::cpu::cpu_target{});
auto result = p.eval({});
auto output = result.get<int32_t>();
......@@ -404,7 +405,7 @@ void add_broadcast_test()
uint64_t axis = 0;
auto l1 = p.add_literal(migraph::literal{a_shape, a_data});
auto l2 = p.add_literal(migraph::literal{b_shape, b_data});
auto l3 = p.add_instruction(migraph::op::broadcast{axis}, l1, l2);
auto l3 = p.add_instruction(migraph::op::broadcast{axis, l1->get_shape()}, l2);
p.add_instruction(migraph::op::add{}, l1, l3);
p.compile(migraph::cpu::cpu_target{});
auto result = p.eval({});
......
......@@ -9,6 +9,7 @@
#include <migraph/manage_ptr.hpp>
#include <migraph/type_name.hpp>
#include <migraph/verify_args.hpp>
#include <migraph/instruction.hpp>
#include <miopen/miopen.h>
......@@ -181,7 +182,7 @@ struct test_add_broadcast
migraph::shape s{migraph::shape::float_type, {3}};
auto x = p.add_parameter("x", {migraph::shape::float_type, {2, 2, 3}});
auto y = p.add_parameter("y", {migraph::shape::float_type, {2, 2}});
auto by = p.add_instruction(migraph::op::broadcast{0}, x, y);
auto by = p.add_instruction(migraph::op::broadcast{0, x->get_shape()}, y);
p.add_instruction(migraph::op::add{}, x, by);
return p;
}
......@@ -195,7 +196,7 @@ struct test_add_broadcast2
migraph::shape s{migraph::shape::float_type, {3}};
auto x = p.add_parameter("x", {migraph::shape::float_type, {2, 3, 4}});
auto y = p.add_parameter("y", {migraph::shape::float_type, {3}});
auto by = p.add_instruction(migraph::op::broadcast{1}, x, y);
auto by = p.add_instruction(migraph::op::broadcast{1, x->get_shape()}, y);
p.add_instruction(migraph::op::add{}, x, by);
return p;
}
......@@ -209,7 +210,7 @@ struct test_add_broadcast3
migraph::shape s{migraph::shape::float_type, {3}};
auto x = p.add_parameter("x", {migraph::shape::float_type, {2, 4, 5}});
auto y = p.add_parameter("y", {migraph::shape::float_type, {4}});
auto by = p.add_instruction(migraph::op::broadcast{1}, x, y);
auto by = p.add_instruction(migraph::op::broadcast{1, x->get_shape()}, y);
p.add_instruction(migraph::op::add{}, x, by);
return p;
}
......@@ -223,7 +224,7 @@ struct test_add_broadcast4
migraph::shape s{migraph::shape::float_type, {3}};
auto x = p.add_parameter("x", {migraph::shape::float_type, {2, 3, 5}});
auto y = p.add_parameter("y", {migraph::shape::float_type, {3}});
auto by = p.add_instruction(migraph::op::broadcast{1}, x, y);
auto by = p.add_instruction(migraph::op::broadcast{1, x->get_shape()}, y);
p.add_instruction(migraph::op::add{}, x, by);
return p;
}
......@@ -237,7 +238,7 @@ struct test_add_broadcast5
migraph::shape s{migraph::shape::float_type, {3}};
auto x = p.add_parameter("x", {migraph::shape::float_type, {2, 4, 8}});
auto y = p.add_parameter("y", {migraph::shape::float_type, {4}});
auto by = p.add_instruction(migraph::op::broadcast{1}, x, y);
auto by = p.add_instruction(migraph::op::broadcast{1, x->get_shape()}, y);
p.add_instruction(migraph::op::add{}, x, by);
return p;
}
......
......@@ -3,6 +3,7 @@
#include <migraph/literal.hpp>
#include <migraph/operators.hpp>
#include <migraph/program.hpp>
#include <migraph/instruction.hpp>
#include <migraph/onnx.hpp>
#include "test.hpp"
......@@ -14,7 +15,7 @@ void pytorch_conv_bias_test()
auto l2 = p.add_parameter("2", {migraph::shape::float_type, {1}});
uint64_t axis = 1;
auto l3 = p.add_instruction(migraph::op::convolution{}, l0, l1);
auto l4 = p.add_instruction(migraph::op::broadcast{axis}, l3, l2);
auto l4 = p.add_instruction(migraph::op::broadcast{axis, l3->get_shape()}, l2);
p.add_instruction(migraph::op::add{}, l3, l4);
auto prog = migraph::parse_onnx("conv.onnx");
......@@ -29,7 +30,7 @@ void pytorch_conv_relu_maxpool()
auto l2 = p.add_parameter("2", {migraph::shape::float_type, {1}});
uint64_t axis = 1;
auto l3 = p.add_instruction(migraph::op::convolution{}, l0, l1);
auto l4 = p.add_instruction(migraph::op::broadcast{axis}, l3, l2);
auto l4 = p.add_instruction(migraph::op::broadcast{axis, l3->get_shape()}, l2);
auto l5 = p.add_instruction(migraph::op::add{}, l3, l4);
auto l6 = p.add_instruction(migraph::op::activation{"relu"}, l5);
p.add_instruction(migraph::op::pooling{"max", {{0, 0}}, {{2, 2}}, {{2, 2}}}, l6);
......@@ -51,7 +52,7 @@ void pytorch_conv_bn_relu_maxpool()
auto p6 = p.add_parameter("6", {migraph::shape::float_type, {1}});
uint64_t axis = 1;
auto l3 = p.add_instruction(migraph::op::convolution{}, l0, l1);
auto l4 = p.add_instruction(migraph::op::broadcast{axis}, l3, l2);
auto l4 = p.add_instruction(migraph::op::broadcast{axis, l3->get_shape()}, l2);
auto l5 = p.add_instruction(migraph::op::add{}, l3, l4);
auto l6 = p.add_instruction(migraph::op::batch_norm_inference{}, l5, p3, p4, p5, p6);
auto l7 = p.add_instruction(migraph::op::activation{"relu"}, l6);
......@@ -69,7 +70,7 @@ void pytorch_conv_relu_maxpool_x2()
auto l2 = p.add_parameter("2", {migraph::shape::float_type, {5}});
uint64_t axis = 1;
auto l3 = p.add_instruction(migraph::op::convolution{}, l0, l1);
auto l4 = p.add_instruction(migraph::op::broadcast{axis}, l3, l2);
auto l4 = p.add_instruction(migraph::op::broadcast{axis, l3->get_shape()}, l2);
auto l5 = p.add_instruction(migraph::op::add{}, l3, l4);
auto l6 = p.add_instruction(migraph::op::activation{"relu"}, l5);
auto l7 = p.add_instruction(migraph::op::pooling{"max", {{0, 0}}, {{2, 2}}, {{2, 2}}}, l6);
......@@ -77,7 +78,7 @@ void pytorch_conv_relu_maxpool_x2()
auto l8 = p.add_parameter("3", {migraph::shape::float_type, {1, 5, 5, 5}});
auto l9 = p.add_parameter("4", {migraph::shape::float_type, {1}});
auto l10 = p.add_instruction(migraph::op::convolution{}, l7, l8);
auto l11 = p.add_instruction(migraph::op::broadcast{axis}, l10, l9);
auto l11 = p.add_instruction(migraph::op::broadcast{axis, l10->get_shape()}, l9);
auto l12 = p.add_instruction(migraph::op::add{}, l10, l11);
auto l13 = p.add_instruction(migraph::op::activation{"relu"}, l12);
p.add_instruction(migraph::op::pooling{"max", {{0, 0}}, {{2, 2}}, {{2, 2}}}, l13);
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment