Unverified Commit 803eb3ce authored by Paul Fultz II's avatar Paul Fultz II Committed by GitHub
Browse files

Merge pull request #112 from ROCmSoftwarePlatform/multi_broadcast

Multi broadcast
parents d9b08400 bd60be01
...@@ -762,6 +762,48 @@ struct broadcast ...@@ -762,6 +762,48 @@ struct broadcast
int output_alias(const std::vector<shape>&) const { return 0; } int output_alias(const std::vector<shape>&) const { return 0; }
}; };
struct multibroadcast
{
    // Target dimensions the input will be broadcast to (numpy-style,
    // aligned on trailing axes).
    std::vector<std::size_t> output_lens;

    template <class Self, class F>
    static auto reflect(Self& self, F f)
    {
        return pack(f(self.output_lens, "output_lens"));
    }

    std::string name() const { return "multibroadcast"; }

    /// Computes the broadcasted shape for a single input.
    ///
    /// The input's trailing dimensions are aligned with the output's.
    /// An axis whose length matches the output keeps its original stride;
    /// a broadcast axis (input length 1, or an axis the input lacks) gets
    /// stride 0 so the same element is reused along that axis.
    ///
    /// Throws when the input is rank-0, has more dimensions than
    /// output_lens, or has an axis that is neither equal to the
    /// corresponding output axis nor 1 (an incompatible broadcast —
    /// previously this case silently produced a zero stride and wrong data).
    shape compute_shape(std::vector<shape> inputs) const
    {
        check_shapes{inputs, *this}.has(1);
        auto t     = inputs.at(0).type();
        auto input = inputs.at(0);
        if(input.lens().empty())
            MIGRAPH_THROW("inputs dimensions should be > 0");
        if(input.lens().size() > output_lens.size())
            MIGRAPH_THROW("inputs dimensions should <= output size");

        std::vector<std::size_t> bcast_strides(output_lens.size(), 0);
        // Offset aligns the input's last axis with the output's last axis.
        auto offset = output_lens.size() - input.lens().size();
        // Signed index: the loop counts down past 0. input is non-empty here,
        // so lens().size() - 1 cannot underflow.
        for(std::ptrdiff_t i = input.lens().size() - 1; i >= 0; i--)
        {
            if(output_lens[i + offset] == input.lens()[i])
            {
                bcast_strides[i + offset] = input.strides()[i];
            }
            else if(input.lens()[i] != 1)
            {
                // Only length-1 axes may be broadcast; anything else is an
                // incompatible shape and must not be silently zero-strided.
                MIGRAPH_THROW("multibroadcast: input dimension is not compatible with output");
            }
        }
        return {t, output_lens, bcast_strides};
    }

    // Broadcasting is a view change only: reuse the input's data buffer
    // under the new (possibly zero-strided) shape.
    argument compute(context&, shape output_shape, std::vector<argument> args) const
    {
        return {std::move(output_shape), std::move(args.at(0).data)};
    }

    // The output aliases input 0 (no new memory is allocated).
    int output_alias(const std::vector<shape>&) const { return 0; }
};
struct scalar struct scalar
{ {
shape scalar_bcast; shape scalar_bcast;
...@@ -789,7 +831,9 @@ struct binary ...@@ -789,7 +831,9 @@ struct binary
shape compute_shape(std::vector<shape> inputs) const shape compute_shape(std::vector<shape> inputs) const
{ {
check_shapes{inputs}.has(2).same_type().same_dims(); check_shapes{inputs}.has(2).same_type().same_dims();
return inputs.at(0); auto t = inputs.at(0).type();
auto lens = inputs.at(0).lens();
return {t, lens};
} }
}; };
......
...@@ -49,16 +49,17 @@ struct onnx_parser ...@@ -49,16 +49,17 @@ struct onnx_parser
onnx_parser() onnx_parser()
{ {
add_generic_op("Add", op::add{});
add_generic_op("Div", op::div{});
add_generic_op("MatMul", op::dot{}); add_generic_op("MatMul", op::dot{});
add_generic_op("Mul", op::mul{});
add_generic_op("Relu", op::relu{}); add_generic_op("Relu", op::relu{});
add_generic_op("Sub", op::sub{});
add_generic_op("Sum", op::add{});
// disable dropout for inference // disable dropout for inference
add_generic_op("Dropout", op::identity{}); add_generic_op("Dropout", op::identity{});
add_broadcastable_binary_op("Add", op::add{});
add_broadcastable_binary_op("Div", op::div{});
add_broadcastable_binary_op("Mul", op::mul{});
add_broadcastable_binary_op("Sub", op::sub{});
add_broadcastable_binary_op("Sum", op::add{});
add_mem_op("ImageScaler", &onnx_parser::parse_imagescaler); add_mem_op("ImageScaler", &onnx_parser::parse_imagescaler);
add_mem_op("LeakyRelu", &onnx_parser::parse_leaky_relu); add_mem_op("LeakyRelu", &onnx_parser::parse_leaky_relu);
add_mem_op("Constant", &onnx_parser::parse_constant); add_mem_op("Constant", &onnx_parser::parse_constant);
...@@ -92,12 +93,13 @@ struct onnx_parser ...@@ -92,12 +93,13 @@ struct onnx_parser
return std::mem_fn(f)(*this, name, std::forward<decltype(xs)>(xs)...); return std::mem_fn(f)(*this, name, std::forward<decltype(xs)>(xs)...);
}); });
} }
template <class T> template <class T>
void add_generic_op(std::string name, T x) void add_broadcastable_binary_op(std::string name, T x)
{ {
ops.emplace(name, [this, x](attribute_map attributes, std::vector<instruction_ref> args) { ops.emplace(name, [this, x](attribute_map attributes, std::vector<instruction_ref> args) {
if(args.size() == 2 and contains(attributes, "broadcast")) if(args.size() != 2)
MIGRAPH_THROW("binary operators should have 2 operands");
if(contains(attributes, "broadcast"))
{ {
uint64_t broadcasted = parse_value(attributes.at("broadcast")).at<uint64_t>(); uint64_t broadcasted = parse_value(attributes.at("broadcast")).at<uint64_t>();
if(broadcasted != 0) if(broadcasted != 0)
...@@ -109,7 +111,51 @@ struct onnx_parser ...@@ -109,7 +111,51 @@ struct onnx_parser
prog.add_instruction(op::broadcast{axis, args[0]->get_shape()}, args[1]); prog.add_instruction(op::broadcast{axis, args[0]->get_shape()}, args[1]);
return prog.add_instruction(x, args[0], l); return prog.add_instruction(x, args[0], l);
} }
return prog.add_instruction(x, args);
}
else
{
// Example:
// s0 = (3,2,4,5) and s1 = (2,1,1)
//
// In this case we need to broadcast (:,1,1) portion of
// s1 plus broadcast the 1st dimension of s1
// giving output_lens = (3,2,4,5)
//
// Another example:
// s0 = (3,2,1,5) and s1 = (2,7,5)
// In this case we need to broadcast the (:,:,1:,:) axis
// of s0 plus the 1st dimension of s1 giving
// output_lens = (3,2,7,5)
//
// Get lengths for both arguments
const std::vector<std::size_t>* s0 = &args[0]->get_shape().lens();
const std::vector<std::size_t>* s1 = &args[1]->get_shape().lens();
// Make sure s0 is the smaller size
if(s0->size() > s1->size())
std::swap(s0, s1);
// Copy the larger vector to output_lens
std::vector<std::size_t> output_lens(s1->size());
auto offset = s1->size() - s0->size();
std::transform(s0->begin(),
s0->end(),
s1->begin() + offset,
output_lens.begin() + offset,
[](auto a, auto b) { return std::max(a, b); });
auto l0 = prog.add_instruction(op::multibroadcast{output_lens}, args[0]);
auto l1 = prog.add_instruction(op::multibroadcast{output_lens}, args[1]);
return prog.add_instruction(x, l0, l1);
} }
});
}
template <class T>
void add_generic_op(std::string name, T x)
{
ops.emplace(name, [this, x](attribute_map, std::vector<instruction_ref> args) {
return prog.add_instruction(x, args); return prog.add_instruction(x, args);
}); });
} }
...@@ -607,8 +653,10 @@ struct onnx_parser ...@@ -607,8 +653,10 @@ struct onnx_parser
} }
std::vector<std::size_t> dims; std::vector<std::size_t> dims;
auto&& tensor_dims = t.tensor_type().shape().dim(); auto&& tensor_dims = t.tensor_type().shape().dim();
std::transform( std::transform(tensor_dims.begin(),
tensor_dims.begin(), tensor_dims.end(), std::back_inserter(dims), [](auto&& d) { tensor_dims.end(),
std::back_inserter(dims),
[](auto&& d) -> std::size_t {
if(not d.has_dim_value()) if(not d.has_dim_value())
{ {
long default_batch_size = 1; // FIXME long default_batch_size = 1; // FIXME
......
...@@ -487,6 +487,7 @@ TEST_CASE(broadcast_test) ...@@ -487,6 +487,7 @@ TEST_CASE(broadcast_test)
} }
TEST_CASE(add_broadcast_test) TEST_CASE(add_broadcast_test)
{ {
{
migraph::program p; migraph::program p;
migraph::shape a_shape{migraph::shape::float_type, {2, 2, 3}}; migraph::shape a_shape{migraph::shape::float_type, {2, 2, 3}};
std::vector<float> a_data{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11}; std::vector<float> a_data{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11};
...@@ -504,6 +505,26 @@ TEST_CASE(add_broadcast_test) ...@@ -504,6 +505,26 @@ TEST_CASE(add_broadcast_test)
result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); }); result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); });
std::vector<float> gold = {0, 1, 2, 2, 3, 4, 4, 5, 6, 6, 7, 8}; std::vector<float> gold = {0, 1, 2, 2, 3, 4, 4, 5, 6, 6, 7, 8};
EXPECT(migraph::verify_range(results_vector, gold)); EXPECT(migraph::verify_range(results_vector, gold));
}
{
migraph::program p;
migraph::shape a_shape{migraph::shape::float_type, {2, 2, 3}};
std::vector<float> a_data{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11};
migraph::shape b_shape{migraph::shape::float_type, {2, 2, 1}};
std::vector<float> b_data{0, -1, -2, -3};
auto l1 = p.add_literal(migraph::literal{a_shape, a_data});
auto l2 = p.add_literal(migraph::literal{b_shape, b_data});
auto l3 = p.add_instruction(migraph::op::multibroadcast{{2, 2, 3}}, l1);
auto l4 = p.add_instruction(migraph::op::multibroadcast{{2, 2, 3}}, l2);
p.add_instruction(migraph::op::add{}, l3, l4);
p.compile(migraph::cpu::target{});
auto result = p.eval({});
EXPECT(result.get_shape().packed());
std::vector<float> results_vector(12);
result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); });
std::vector<float> gold = {0, 1, 2, 2, 3, 4, 4, 5, 6, 6, 7, 8};
EXPECT(migraph::verify_range(results_vector, gold));
}
} }
TEST_CASE(sub_test) TEST_CASE(sub_test)
......
...@@ -145,4 +145,68 @@ TEST_CASE(slice_shape) ...@@ -145,4 +145,68 @@ TEST_CASE(slice_shape)
migraph::op::slice{{2}, {2}, {10}}, migraph::op::slice{{2}, {2}, {10}},
input); input);
} }
TEST_CASE(multibroadcast)
{
    // Each scope verifies the shape produced by op::multibroadcast:
    // axes that match the target keep their stride, broadcast axes
    // (missing or length-1) get stride 0. The last two scopes verify
    // that invalid inputs throw.
    {
        std::vector<std::size_t> target{4, 2, 5, 3};
        migraph::shape in_shape{migraph::shape::float_type, {2, 1, 3}};
        expect_shape(migraph::shape{migraph::shape::float_type, target, {0, 3, 0, 1}},
                     migraph::op::multibroadcast{target},
                     in_shape);
    }
    {
        std::vector<std::size_t> target{4, 2, 5, 3};
        migraph::shape in_shape{migraph::shape::float_type, {2, 1, 1}};
        expect_shape(migraph::shape{migraph::shape::float_type, target, {0, 1, 0, 0}},
                     migraph::op::multibroadcast{target},
                     in_shape);
    }
    {
        std::vector<std::size_t> target{4, 2, 5, 3};
        migraph::shape in_shape{migraph::shape::float_type, {5, 1}};
        expect_shape(migraph::shape{migraph::shape::float_type, target, {0, 0, 1, 0}},
                     migraph::op::multibroadcast{target},
                     in_shape);
    }
    {
        std::vector<std::size_t> target{4, 2, 5, 3};
        migraph::shape in_shape{migraph::shape::float_type, {4, 1, 1, 1}};
        expect_shape(migraph::shape{migraph::shape::float_type, target, {1, 0, 0, 0}},
                     migraph::op::multibroadcast{target},
                     in_shape);
    }
    {
        std::vector<std::size_t> target{4, 2, 5, 3};
        migraph::shape in_shape{migraph::shape::float_type, {3}};
        expect_shape(migraph::shape{migraph::shape::float_type, target, {0, 0, 0, 1}},
                     migraph::op::multibroadcast{target},
                     in_shape);
    }
    {
        std::vector<std::size_t> target{4, 4, 1, 3};
        migraph::shape in_shape{migraph::shape::float_type, {4, 1, 3}};
        expect_shape(migraph::shape{migraph::shape::float_type, target, {0, 3, 3, 1}},
                     migraph::op::multibroadcast{target},
                     in_shape);
    }
    {
        std::vector<std::size_t> target{4, 1, 1, 3};
        migraph::shape in_shape{migraph::shape::float_type, {4, 1, 1, 1}};
        expect_shape(migraph::shape{migraph::shape::float_type, target, {1, 1, 1, 0}},
                     migraph::op::multibroadcast{target},
                     in_shape);
    }
    {
        // Input has more dimensions than the target: must throw.
        std::vector<std::size_t> target{4, 1, 3};
        migraph::shape in_shape{migraph::shape::float_type, {4, 1, 1, 1}};
        throws_shape(migraph::op::multibroadcast{target}, in_shape);
    }
    {
        // Rank-0 input: must throw.
        std::vector<std::size_t> target{4, 1, 3};
        migraph::shape in_shape{migraph::shape::float_type, {}};
        throws_shape(migraph::op::multibroadcast{target}, in_shape);
    }
}
int main(int argc, const char* argv[]) { test::run(argc, argv); } int main(int argc, const char* argv[]) { test::run(argc, argv); }
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment