Commit 26c33a16 authored by Scott Thornton's avatar Scott Thornton
Browse files

Added multibroadcast + test

parent 2946e34e
......@@ -759,6 +759,53 @@ struct broadcast
int output_alias(const std::vector<shape>&) const { return 0; }
};
struct multibroadcast
{
std::vector<std::size_t> output_lens;
std::string name() const { return "multibroadcast"; }
shape compute_shape(std::vector<shape> inputs) const
{
check_shapes{inputs, *this}.has(1);
auto t = inputs.at(0).type();
auto input = inputs.at(0);
if (input.lens().size() <= 0)
MIGRAPH_THROW("inputs dimensions should be > 0");
if (input.lens().size() > output_lens.size())
MIGRAPH_THROW("inputs dimensions should <= output size");
std::vector<size_t> bcast_strides(output_lens.size(), 0);
auto extra = output_lens.size()-input.lens().size();
if (input.lens().size() < output_lens.size())
{
for (std::size_t i = output_lens.size()-1; i > 0; i--)
{
if (output_lens[i] == input.lens()[i-extra])
{
bcast_strides[i] = input.strides()[i-extra];
}
}
}
else
{
for (std::size_t i = 0; i < input.lens().size(); i++)
{
if (output_lens[i] == input.lens()[i])
{
bcast_strides[i] = input.strides()[i];
}
}
}
return {t, output_lens, bcast_strides};
}
argument compute(context&, shape output_shape, std::vector<argument> args) const
{
return {std::move(output_shape), std::move(args.at(0).data)};
}
int output_alias(const std::vector<shape>&) const { return 0; }
};
struct scalar
{
shape scalar_bcast;
......
......@@ -145,8 +145,42 @@ void slice_shape()
migraph::op::slice{{2}, {2}, {10}},
input);
}
void multibroadcast_shape()
{
{
std::vector<std::size_t> lens{4,2,5,3};
migraph::shape input{migraph::shape::float_type, {2,1,3}};
expect_shape(migraph::shape{migraph::shape::float_type, lens, {0,3,0,1}},
migraph::op::multibroadcast{lens}, input);
}
{
std::vector<std::size_t> lens{4,2,5,3};
migraph::shape input{migraph::shape::float_type, {2,1,1}};
expect_shape(migraph::shape{migraph::shape::float_type, lens, {0,1,0,0}},
migraph::op::multibroadcast{lens}, input);
}
{
std::vector<std::size_t> lens{4,1,1,3};
migraph::shape input{migraph::shape::float_type, {4,1,1,1}};
expect_shape(migraph::shape{migraph::shape::float_type, lens, {1,1,1,0}},
migraph::op::multibroadcast{lens}, input);
}
{
std::vector<std::size_t> lens{4,1,3};
migraph::shape input{migraph::shape::float_type, {4,1,1,1}};
throws_shape(migraph::op::multibroadcast{lens}, input);
}
{
std::vector<std::size_t> lens{4,1,3};
migraph::shape input{migraph::shape::float_type, {}};
throws_shape(migraph::op::multibroadcast{lens}, input);
}
}
int main()
{
multibroadcast_shape();
batch_norm_inference_shape();
convolution_shape();
transpose_shape();
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment