Commit 78c799c5 authored by charlie's avatar charlie
Browse files

Initial

parent b4bbdde5
......@@ -44,6 +44,7 @@ struct broadcast
{
uint64_t axis = 0;
std::vector<std::size_t> broadcast_lens;
std::vector<shape::dynamic_dimension> broadcast_dyn_dims;
template <class Self, class F>
static auto reflect(Self& self, F f)
......@@ -54,11 +55,12 @@ struct broadcast
std::string name() const { return "broadcast"; }
shape compute_shape(std::vector<shape> inputs) const
{
check_shapes{inputs, *this}.has(1);
auto input = inputs.at(0);
auto t = input.type();
std::vector<size_t> bcast_strides(broadcast_lens.size(), 0);
// the broacast op is deprecated now, so not handling the negative
// the broadcast op is deprecated now, so not handling the negative
// value of axis anymore
if(axis >= broadcast_lens.size())
{
......
......@@ -26,6 +26,8 @@
#include <migraphx/check_shapes.hpp>
#include <migraphx/argument.hpp>
#include <migraphx/operation.hpp>
#include <migraphx/common.hpp>
#include <migraphx/config.hpp>
namespace migraphx {
......@@ -46,44 +48,66 @@ struct multibroadcast
shape compute_shape(std::vector<shape> inputs) const
{
check_shapes{inputs, *this}.has(1);
auto t = inputs.at(0).type();
auto input = inputs.at(0);
check_shapes{inputs, *this, true}.has(1, 2);
if(input.lens().empty())
auto t = inputs.at(0).type();
auto input_shape = inputs.at(0);
if(inputs.size() == 1)
{
MIGRAPHX_THROW("MULTIBROADCAST: inputs dimensions should be > 0");
}
if(input_shape.lens().empty())
{
MIGRAPHX_THROW("MULTIBROADCAST: inputs dimensions should be > 0");
}
if(input.lens().size() > output_lens.size())
{
MIGRAPHX_THROW("MULTIBROADCAST: inputs dimensions should <= output size");
}
if(input_shape.lens().size() > output_lens.size())
{
MIGRAPHX_THROW("MULTIBROADCAST: inputs dimensions should <= output size");
}
auto offset = output_lens.size() - input.lens().size();
for(std::ptrdiff_t i = input.lens().size() - 1; i >= 0; i--)
{
if(output_lens[i + offset] != input.lens()[i] and input.lens()[i] != 1)
auto offset = output_lens.size() - input_shape.lens().size();
for(std::ptrdiff_t i = input_shape.lens().size() - 1; i >= 0; i--)
{
MIGRAPHX_THROW("MULTIBROADCAST: input shape {" + to_string_range(input.lens()) +
"} cannot be broadcasted to {" + to_string_range(output_lens) +
"}!");
if(output_lens[i + offset] != input_shape.lens()[i] and input_shape.lens()[i] != 1)
{
MIGRAPHX_THROW(
"MULTIBROADCAST: input shape {" + to_string_range(input_shape.lens()) +
"} cannot be broadcasted to {" + to_string_range(output_lens) + "}!");
}
}
}
std::vector<size_t> bcast_strides(output_lens.size(), 0);
for(std::ptrdiff_t i = input.lens().size() - 1; i >= 0; i--)
std::vector<size_t> bcast_strides(output_lens.size(), 0);
for(std::ptrdiff_t i = input_shape.lens().size() - 1; i >= 0; i--)
{
if(output_lens[i + offset] == input_shape.lens()[i])
{
bcast_strides[i + offset] = input_shape.strides()[i];
}
}
return {t, output_lens, bcast_strides};
}
else
{
if(output_lens[i + offset] == input.lens()[i])
// need both shapes when handling dynamic case
// shapes can be dynamic (at compile-time) or static (at evaluation time)
// this function will be called through compute_output_shape conversion to dyn_output
// new compute_broadcasted_lens for dynamic shapes
auto other_shape = inputs.at(1);
if(input_shape.dynamic() and other_shape.dynamic()) {}
else if(not input_shape.dynamic() and not other_shape.dynamic())
{
bcast_strides[i + offset] = input.strides()[i];
auto output_lens = compute_broadcasted_lens(input_shape.lens(), other_shape.lens());
}
else
{
MIGRAPHX_THROW(
"MULTIBROADCAST: input_shape and other_shape are not both dynamic or static");
}
}
return {t, output_lens, bcast_strides};
}
argument compute(shape output_shape, std::vector<argument> args) const
argument compute(const dyn_output& dyn_out, std::vector<argument> args) const
{
return args[0].reshape(output_shape);
return args[0].reshape(dyn_out.computed_shape);
}
std::ptrdiff_t output_alias(const std::vector<shape>&) const { return 0; }
};
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment