Commit 412c298e authored by charlie
Browse files

Progress on changing ops

multibroadcast and broadcast take two inputs
parent 78c799c5
...@@ -44,7 +44,6 @@ struct broadcast ...@@ -44,7 +44,6 @@ struct broadcast
{ {
uint64_t axis = 0; uint64_t axis = 0;
std::vector<std::size_t> broadcast_lens; std::vector<std::size_t> broadcast_lens;
std::vector<shape::dynamic_dimension> broadcast_dyn_dims;
template <class Self, class F> template <class Self, class F>
static auto reflect(Self& self, F f) static auto reflect(Self& self, F f)
...@@ -55,38 +54,72 @@ struct broadcast ...@@ -55,38 +54,72 @@ struct broadcast
// Returns the operator's name string.
std::string name() const
{
    return "broadcast";
}
// Computes the output shape of the broadcast operator from its input shapes.
//
// One input:  the output lens are the `broadcast_lens` attribute; the input's
//             strides are placed starting at `axis` and every other stride is 0.
// Two inputs: the output lens are computed from both input shapes with
//             compute_broadcasted_lens(); both inputs must be dynamic, or both
//             must be static.
//
// Throws (via MIGRAPHX_THROW) when `axis` does not fit the output rank, when the
// input lens do not match the output lens starting at `axis` (one input), or
// when one input is dynamic and the other is static (two inputs).
shape compute_shape(std::vector<shape> inputs) const
{
    // NOTE(review): the third argument presumably allows dynamic shapes and
    // has(1, 2) presumably accepts one or two inputs -- confirm against check_shapes.
    check_shapes{inputs, *this, true}.has(1, 2);
    auto s0 = inputs.at(0);
    auto t  = s0.type();

    // Validates that s0's dimensions, placed at `axis`, fit inside an output of
    // `out_ndims` dimensions. Shared by both branches so that the std::copy of
    // the strides below can never write past the end of the stride vector.
    auto check_axis_fits = [&](std::size_t out_ndims) {
        // the broadcast op is deprecated now, so not handling the negative
        // value of axis anymore
        if(axis >= out_ndims)
        {
            MIGRAPHX_THROW("BROADCAST : axis is out of range");
        }
        if(out_ndims - axis < s0.lens().size())
        {
            MIGRAPHX_THROW("BROADCAST: (broadcast ndims - axis) is less than s0 ndims");
        }
    };

    if(inputs.size() == 1)
    {
        check_axis_fits(broadcast_lens.size());
        if(not std::equal(s0.lens().begin(), s0.lens().end(), broadcast_lens.begin() + axis))
        {
            MIGRAPHX_THROW("BROADCAST: when broadcasting, succeeding sizes must match");
        }

        // Dimensions not covered by s0 keep stride 0.
        std::vector<size_t> bcast_strides(broadcast_lens.size(), 0);
        std::copy(s0.strides().begin(), s0.strides().end(), bcast_strides.begin() + axis);
        shape output{t, broadcast_lens, std::move(bcast_strides)};
        if(output.elements() < s0.elements())
            MIGRAPHX_THROW("BROADCAST: output size must be greater than or equal to s0 size");
        return output;
    }
    else
    {
        // The second input supplies the shape to broadcast against.
        auto s1 = inputs.at(1);
        if(s0.dynamic() and s1.dynamic())
        {
            auto bcast_max_lens = compute_broadcasted_lens(s0.max_lens(), s1.max_lens());
            auto bcast_min_lens = compute_broadcasted_lens(s0.min_lens(), s1.min_lens());
            auto bcast_opt_lens = compute_broadcasted_lens(s0.opt_lens(), s1.opt_lens());
            std::vector<shape::dynamic_dimension> output_dyn_dims;
            output_dyn_dims.reserve(bcast_max_lens.size());
            for(std::size_t i = 0; i < bcast_max_lens.size(); ++i)
            {
                // NOTE(review): assumes the min/max/opt results all have the same
                // length -- confirm against compute_broadcasted_lens.
                output_dyn_dims.push_back(shape::dynamic_dimension{
                    bcast_min_lens[i], bcast_max_lens[i], bcast_opt_lens[i]});
            }
            return {t, std::move(output_dyn_dims)};
        }
        else if(not s0.dynamic() and not s1.dynamic())
        {
            auto bcast_lens = compute_broadcasted_lens(s0.lens(), s1.lens());
            // Size the strides from the computed output lens (not from the
            // `broadcast_lens` attribute) so lens and strides always agree.
            check_axis_fits(bcast_lens.size());
            std::vector<size_t> bcast_strides(bcast_lens.size(), 0);
            // NOTE(review): s0's strides are copied verbatim at `axis`; if
            // compute_broadcasted_lens expands a size-1 dimension of s0, that
            // dimension may need stride 0 instead -- confirm intended semantics.
            std::copy(s0.strides().begin(), s0.strides().end(), bcast_strides.begin() + axis);
            return {t, std::move(bcast_lens), std::move(bcast_strides)};
        }
        else
        {
            MIGRAPHX_THROW(
                "BROADCAST: s0 and s1 are not both dynamic or static");
        }
    }
}
// Evaluates the operator: returns the first argument reshaped to output_shape.
argument compute(shape output_shape, std::vector<argument> args) const
{
    auto& input = args[0];
    return input.reshape(output_shape);
}
// Always returns 0, regardless of the input shapes.
// NOTE(review): presumably this marks the output as aliasing input 0 -- confirm
// against the operation interface.
std::ptrdiff_t output_alias(const std::vector<shape>&) const
{
    return 0;
}
}; };
......
...@@ -51,39 +51,47 @@ struct multibroadcast ...@@ -51,39 +51,47 @@ struct multibroadcast
check_shapes{inputs, *this, true}.has(1, 2); check_shapes{inputs, *this, true}.has(1, 2);
auto t = inputs.at(0).type(); auto t = inputs.at(0).type();
auto input_shape = inputs.at(0); auto s0 = inputs.at(0);
if(s0.lens().empty())
{
MIGRAPHX_THROW("MULTIBROADCAST: inputs dimensions should be > 0");
}
auto make_bcast_strides = [&](std::size_t out_num_dims, std::size_t offset)
{
std::vector<size_t> bcast_strides(out_num_dims, 0);
for(std::ptrdiff_t i = s0.lens().size() - 1; i >= 0; i--)
{
if(output_lens[i + offset] == s0.lens()[i])
{
bcast_strides[i + offset] = s0.strides()[i];
}
}
return bcast_strides;
};
if(inputs.size() == 1) if(inputs.size() == 1)
{ {
if(input_shape.lens().empty()) if(s0.lens().size() > output_lens.size())
{
MIGRAPHX_THROW("MULTIBROADCAST: inputs dimensions should be > 0");
}
if(input_shape.lens().size() > output_lens.size())
{ {
MIGRAPHX_THROW("MULTIBROADCAST: inputs dimensions should <= output size"); MIGRAPHX_THROW("MULTIBROADCAST: inputs dimensions should <= output size");
} }
auto offset = output_lens.size() - input_shape.lens().size(); auto offset = output_lens.size() - s0.lens().size();
for(std::ptrdiff_t i = input_shape.lens().size() - 1; i >= 0; i--) for(std::ptrdiff_t i = s0.lens().size() - 1; i >= 0; i--)
{ {
if(output_lens[i + offset] != input_shape.lens()[i] and input_shape.lens()[i] != 1) if(output_lens[i + offset] != s0.lens()[i] and s0.lens()[i] != 1)
{ {
MIGRAPHX_THROW( MIGRAPHX_THROW(
"MULTIBROADCAST: input shape {" + to_string_range(input_shape.lens()) + "MULTIBROADCAST: input shape {" + to_string_range(s0.lens()) +
"} cannot be broadcasted to {" + to_string_range(output_lens) + "}!"); "} cannot be broadcasted to {" + to_string_range(output_lens) + "}!");
} }
} }
std::vector<size_t> bcast_strides(output_lens.size(), 0); auto bcast_strides = make_bcast_strides(output_lens.size(), offset);
for(std::ptrdiff_t i = input_shape.lens().size() - 1; i >= 0; i--) return {t, output_lens, std::move(bcast_strides)};
{
if(output_lens[i + offset] == input_shape.lens()[i])
{
bcast_strides[i + offset] = input_shape.strides()[i];
}
}
return {t, output_lens, bcast_strides};
} }
else else
{ {
...@@ -91,16 +99,37 @@ struct multibroadcast ...@@ -91,16 +99,37 @@ struct multibroadcast
// shapes can be dynamic (at compile-time) or static (at evaluation time) // shapes can be dynamic (at compile-time) or static (at evaluation time)
// this function will be called through compute_output_shape conversion to dyn_output // this function will be called through compute_output_shape conversion to dyn_output
// new compute_broadcasted_lens for dynamic shapes // new compute_broadcasted_lens for dynamic shapes
auto other_shape = inputs.at(1); // do we want this to work in both broadcast directions?
if(input_shape.dynamic() and other_shape.dynamic()) {} // s0 and s1 as shape inputs
else if(not input_shape.dynamic() and not other_shape.dynamic()) // always s0 -> s1 shape or allow s0 to retain the same shape?
// presuming that it's always s0 -> s1 shape, since that's closer to the current behavior
// compute_broadcasted_lens() will swap the shapes if s1.size() < s0.size(), may need to make another function
auto s1 = inputs.at(1);
if(s0.dynamic() and s1.dynamic())
{
auto bcast_max_lens = compute_broadcasted_lens(s0.max_lens(), s1.max_lens());
auto bcast_min_lens = compute_broadcasted_lens(s0.min_lens(), s1.min_lens());
auto bcast_opt_lens = compute_broadcasted_lens(s0.opt_lens(), s1.opt_lens());
std::vector<shape::dynamic_dimension> output_dyn_dims = {};
for(size_t i = 0; i < bcast_max_lens.size(); ++i)
{
output_dyn_dims.push_back(shape::dynamic_dimension{
min_spatial_dims[i], max_spatial_dims[i], opt_spatial_dims[i]});
}
return {t, std::move(output_dyn_dims)};
}
else if(not s0.dynamic() and not s1.dynamic())
{ {
auto output_lens = compute_broadcasted_lens(input_shape.lens(), other_shape.lens()); auto bcast_lens = compute_broadcasted_lens(s0.lens(), s1.lens());
auto offset = s1.lens().size() - s0.lens().size();
auto bcast_strides = make_bcast_strides(s1.lens().size(), offset);
return {t, std::move(bcast_lens), std::move(bcast_strides)};
} }
else else
{ {
MIGRAPHX_THROW( MIGRAPHX_THROW(
"MULTIBROADCAST: input_shape and other_shape are not both dynamic or static"); "MULTIBROADCAST: s0 and s1 are not both dynamic or static");
} }
} }
} }
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment