"git@developer.sourcefind.cn:gaoqiong/migraphx.git" did not exist on "fee874e35a0b8cb79045ce878ed8436393aaf1cf"
Commit 12f78eec authored by charlie's avatar charlie
Browse files

Revert broadcast.hpp changes

Trying to keep the PRs separate
parent 02ef1a0c
...@@ -27,7 +27,6 @@ ...@@ -27,7 +27,6 @@
#include <migraphx/check_shapes.hpp> #include <migraphx/check_shapes.hpp>
#include <migraphx/argument.hpp> #include <migraphx/argument.hpp>
#include <migraphx/config.hpp> #include <migraphx/config.hpp>
#include <migraphx/common.hpp>
namespace migraphx { namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS { inline namespace MIGRAPHX_INLINE_NS {
...@@ -55,39 +54,37 @@ struct broadcast ...@@ -55,39 +54,37 @@ struct broadcast
std::string name() const { return "broadcast"; } std::string name() const { return "broadcast"; }
shape compute_shape(std::vector<shape> inputs) const shape compute_shape(std::vector<shape> inputs) const
{ {
check_shapes{inputs, *this, true}.has(1); auto input = inputs.at(0);
auto s0 = inputs.at(0); auto t = input.type();
auto t = s0.type();
std::vector<size_t> bcast_strides(broadcast_lens.size(), 0); std::vector<size_t> bcast_strides(broadcast_lens.size(), 0);
// the broadcast op is deprecated now, so not handling the negative // the broacast op is deprecated now, so not handling the negative
// value of axis anymore // value of axis anymore
if(axis >= broadcast_lens.size()) if(axis >= broadcast_lens.size())
{ {
MIGRAPHX_THROW("BROADCAST : axis is out of range"); MIGRAPHX_THROW("BROADCAST : axis is out of range");
} }
if(broadcast_lens.size() - axis < s0.lens().size()) if(broadcast_lens.size() - axis < input.lens().size())
{ {
MIGRAPHX_THROW("BROADCAST: (broadcast ndims - axis) is less than s0 ndims"); MIGRAPHX_THROW("BROADCAST: (broadcast ndims - axis) is less than input ndims");
} }
if(not std::equal(s0.lens().begin(), s0.lens().end(), broadcast_lens.begin() + axis)) if(not std::equal(input.lens().begin(), input.lens().end(), broadcast_lens.begin() + axis))
{ {
MIGRAPHX_THROW("BROADCAST: when broadcasting, succeeding sizes must match"); MIGRAPHX_THROW("BROADCAST: when broadcasting, succeeding sizes must match");
} }
std::copy(s0.strides().begin(), s0.strides().end(), bcast_strides.begin() + axis); std::copy(input.strides().begin(), input.strides().end(), bcast_strides.begin() + axis);
shape output{t, broadcast_lens, std::move(bcast_strides)}; shape output{t, broadcast_lens, std::move(bcast_strides)};
if(output.elements() < s0.elements()) if(output.elements() < input.elements())
MIGRAPHX_THROW("BROADCAST: output size must be greater than or equal to s0 size"); MIGRAPHX_THROW("BROADCAST: output size must be greater than or equal to input size");
return output; return output;
} }
argument compute(shape output_shape, std::vector<argument> args) const argument compute(shape output_shape, std::vector<argument> args) const
{ {
return args[0].reshape(output_shape); return args[0].reshape(output_shape);
} }
std::ptrdiff_t output_alias(const std::vector<shape>&) const { return 0; } std::ptrdiff_t output_alias(const std::vector<shape>&) const { return 0; }
}; };
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment