Commit 12f78eec authored by charlie's avatar charlie
Browse files

Revert broadcast.hpp changes

Trying to keep the PRs separate
parent 02ef1a0c
......@@ -27,7 +27,6 @@
#include <migraphx/check_shapes.hpp>
#include <migraphx/argument.hpp>
#include <migraphx/config.hpp>
#include <migraphx/common.hpp>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
......@@ -55,39 +54,37 @@ struct broadcast
std::string name() const { return "broadcast"; }
shape compute_shape(std::vector<shape> inputs) const
{
check_shapes{inputs, *this, true}.has(1);
auto s0 = inputs.at(0);
auto t = s0.type();
auto input = inputs.at(0);
auto t = input.type();
std::vector<size_t> bcast_strides(broadcast_lens.size(), 0);
// the broadcast op is deprecated now, so not handling the negative
// the broacast op is deprecated now, so not handling the negative
// value of axis anymore
if(axis >= broadcast_lens.size())
{
MIGRAPHX_THROW("BROADCAST : axis is out of range");
}
if(broadcast_lens.size() - axis < s0.lens().size())
if(broadcast_lens.size() - axis < input.lens().size())
{
MIGRAPHX_THROW("BROADCAST: (broadcast ndims - axis) is less than s0 ndims");
MIGRAPHX_THROW("BROADCAST: (broadcast ndims - axis) is less than input ndims");
}
if(not std::equal(s0.lens().begin(), s0.lens().end(), broadcast_lens.begin() + axis))
if(not std::equal(input.lens().begin(), input.lens().end(), broadcast_lens.begin() + axis))
{
MIGRAPHX_THROW("BROADCAST: when broadcasting, succeeding sizes must match");
}
std::copy(s0.strides().begin(), s0.strides().end(), bcast_strides.begin() + axis);
std::copy(input.strides().begin(), input.strides().end(), bcast_strides.begin() + axis);
shape output{t, broadcast_lens, std::move(bcast_strides)};
if(output.elements() < s0.elements())
MIGRAPHX_THROW("BROADCAST: output size must be greater than or equal to s0 size");
if(output.elements() < input.elements())
MIGRAPHX_THROW("BROADCAST: output size must be greater than or equal to input size");
return output;
}
argument compute(shape output_shape, std::vector<argument> args) const
{
return args[0].reshape(output_shape);
}
std::ptrdiff_t output_alias(const std::vector<shape>&) const { return 0; }
};
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment