Commit 78c799c5 authored by charlie's avatar charlie
Browse files

Initial

parent b4bbdde5
...@@ -44,6 +44,7 @@ struct broadcast ...@@ -44,6 +44,7 @@ struct broadcast
{ {
uint64_t axis = 0; uint64_t axis = 0;
std::vector<std::size_t> broadcast_lens; std::vector<std::size_t> broadcast_lens;
std::vector<shape::dynamic_dimension> broadcast_dyn_dims;
template <class Self, class F> template <class Self, class F>
static auto reflect(Self& self, F f) static auto reflect(Self& self, F f)
...@@ -54,11 +55,12 @@ struct broadcast ...@@ -54,11 +55,12 @@ struct broadcast
std::string name() const { return "broadcast"; } std::string name() const { return "broadcast"; }
shape compute_shape(std::vector<shape> inputs) const shape compute_shape(std::vector<shape> inputs) const
{ {
check_shapes{inputs, *this}.has(1);
auto input = inputs.at(0); auto input = inputs.at(0);
auto t = input.type(); auto t = input.type();
std::vector<size_t> bcast_strides(broadcast_lens.size(), 0); std::vector<size_t> bcast_strides(broadcast_lens.size(), 0);
// the broacast op is deprecated now, so not handling the negative // the broadcast op is deprecated now, so not handling the negative
// value of axis anymore // value of axis anymore
if(axis >= broadcast_lens.size()) if(axis >= broadcast_lens.size())
{ {
......
...@@ -26,6 +26,8 @@ ...@@ -26,6 +26,8 @@
#include <migraphx/check_shapes.hpp> #include <migraphx/check_shapes.hpp>
#include <migraphx/argument.hpp> #include <migraphx/argument.hpp>
#include <migraphx/operation.hpp>
#include <migraphx/common.hpp>
#include <migraphx/config.hpp> #include <migraphx/config.hpp>
namespace migraphx { namespace migraphx {
...@@ -46,44 +48,66 @@ struct multibroadcast ...@@ -46,44 +48,66 @@ struct multibroadcast
shape compute_shape(std::vector<shape> inputs) const shape compute_shape(std::vector<shape> inputs) const
{ {
check_shapes{inputs, *this}.has(1); check_shapes{inputs, *this, true}.has(1, 2);
auto t = inputs.at(0).type();
auto input = inputs.at(0);
if(input.lens().empty()) auto t = inputs.at(0).type();
auto input_shape = inputs.at(0);
if(inputs.size() == 1)
{
if(input_shape.lens().empty())
{ {
MIGRAPHX_THROW("MULTIBROADCAST: inputs dimensions should be > 0"); MIGRAPHX_THROW("MULTIBROADCAST: inputs dimensions should be > 0");
} }
if(input.lens().size() > output_lens.size()) if(input_shape.lens().size() > output_lens.size())
{ {
MIGRAPHX_THROW("MULTIBROADCAST: inputs dimensions should <= output size"); MIGRAPHX_THROW("MULTIBROADCAST: inputs dimensions should <= output size");
} }
auto offset = output_lens.size() - input.lens().size(); auto offset = output_lens.size() - input_shape.lens().size();
for(std::ptrdiff_t i = input.lens().size() - 1; i >= 0; i--) for(std::ptrdiff_t i = input_shape.lens().size() - 1; i >= 0; i--)
{ {
if(output_lens[i + offset] != input.lens()[i] and input.lens()[i] != 1) if(output_lens[i + offset] != input_shape.lens()[i] and input_shape.lens()[i] != 1)
{ {
MIGRAPHX_THROW("MULTIBROADCAST: input shape {" + to_string_range(input.lens()) + MIGRAPHX_THROW(
"} cannot be broadcasted to {" + to_string_range(output_lens) + "MULTIBROADCAST: input shape {" + to_string_range(input_shape.lens()) +
"}!"); "} cannot be broadcasted to {" + to_string_range(output_lens) + "}!");
} }
} }
std::vector<size_t> bcast_strides(output_lens.size(), 0); std::vector<size_t> bcast_strides(output_lens.size(), 0);
for(std::ptrdiff_t i = input.lens().size() - 1; i >= 0; i--) for(std::ptrdiff_t i = input_shape.lens().size() - 1; i >= 0; i--)
{ {
if(output_lens[i + offset] == input.lens()[i]) if(output_lens[i + offset] == input_shape.lens()[i])
{ {
bcast_strides[i + offset] = input.strides()[i]; bcast_strides[i + offset] = input_shape.strides()[i];
} }
} }
return {t, output_lens, bcast_strides}; return {t, output_lens, bcast_strides};
} }
argument compute(shape output_shape, std::vector<argument> args) const else
{
// need both shapes when handling dynamic case
// shapes can be dynamic (at compile-time) or static (at evaluation time)
// this function will be called through compute_output_shape conversion to dyn_output
// new compute_broadcasted_lens for dynamic shapes
auto other_shape = inputs.at(1);
if(input_shape.dynamic() and other_shape.dynamic()) {}
else if(not input_shape.dynamic() and not other_shape.dynamic())
{
auto output_lens = compute_broadcasted_lens(input_shape.lens(), other_shape.lens());
}
else
{
MIGRAPHX_THROW(
"MULTIBROADCAST: input_shape and other_shape are not both dynamic or static");
}
}
}
argument compute(const dyn_output& dyn_out, std::vector<argument> args) const
{ {
return args[0].reshape(output_shape); return args[0].reshape(dyn_out.computed_shape);
} }
std::ptrdiff_t output_alias(const std::vector<shape>&) const { return 0; } std::ptrdiff_t output_alias(const std::vector<shape>&) const { return 0; }
}; };
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment