Commit 553a8d02 authored by charlie's avatar charlie
Browse files

refactor again, made a compute broadcast for dyn_dims

parent c539a7b0
......@@ -68,48 +68,46 @@ std::vector<std::size_t> compute_broadcasted_lens(std::vector<std::size_t> s0,
std::vector<shape::dynamic_dimension> compute_broadcasted_dyn_dims(shape s0, shape s1)
{
if(s0.dynamic() or s1.dynamic())
if(not s0.dynamic() and not s1.dynamic())
{
// change both shapes to dynamic_dimension representation
if(not s0.dynamic())
s0 = s0.to_dynamic();
if(not s1.dynamic())
s1 = s1.to_dynamic();
if(s0.rank() > s1.rank())
{
std::swap(s0, s1);
}
auto offset = s1.rank() - s0.rank();
std::vector<shape::dynamic_dimension> out_dims(s1.dyn_dims());
std::vector<shape::dynamic_dimension> one_dyn_dims{{1, 1, 0}, {1, 1, 1}};
std::transform(
s0.dyn_dims().cbegin(),
s0.dyn_dims().cend(),
s1.dyn_dims().cbegin() + offset,
out_dims.begin() + offset,
[&](auto a, auto b) {
if(a == b)
{
return a;
}
else if(contains(one_dyn_dims, a) or contains(one_dyn_dims, b))
{
return shape::dynamic_dimension{
std::max(a.min, b.min), std::max(a.max, b.max), std::max(a.opt, b.opt)};
}
else
{
MIGRAPHX_THROW("COMPUTE_BROADCASTED_DYN_DIMS: dynamic shapes {" +
migraphx::to_string_range(s0.dyn_dims()) + "} and {" +
migraphx::to_string_range(s1.dyn_dims()) + "} mismatch!");
}
});
MIGRAPHX_THROW("COMPUTE_BROADCASTED_DYN_DIMS: given two static shapes");
}
else
// change both shapes to dynamic_dimension representation
if(not s0.dynamic())
s0 = s0.to_dynamic();
if(not s1.dynamic())
s1 = s1.to_dynamic();
if(s0.ndim() > s1.ndim())
{
MIGRAPHX_THROW("COMPUTE_BROADCASTED_DYN_DIMS: given two static shapes");
std::swap(s0, s1);
}
auto offset = s1.ndim() - s0.ndim();
std::vector<shape::dynamic_dimension> out_dims(s1.dyn_dims());
std::vector<shape::dynamic_dimension> one_dyn_dims{{1, 1, 0}, {1, 1, 1}};
std::transform(
s0.dyn_dims().cbegin(),
s0.dyn_dims().cend(),
s1.dyn_dims().cbegin() + offset,
out_dims.begin() + offset,
[&](auto a, auto b) {
if(a == b)
{
return a;
}
else if(contains(one_dyn_dims, a) or contains(one_dyn_dims, b))
{
// setting opt to 0, may need to be changed
return shape::dynamic_dimension{std::max(a.min, b.min), std::max(a.max, b.max), 0};
}
else
{
MIGRAPHX_THROW("COMPUTE_BROADCASTED_DYN_DIMS: dynamic shapes {" +
migraphx::to_string_range(s0.dyn_dims()) + "} and {" +
migraphx::to_string_range(s1.dyn_dims()) + "} mismatch!");
}
});
return out_dims;
}
// Compute the common (broadcasted) dimensions of a list of fixed shapes
......@@ -183,11 +181,19 @@ instruction_ref insert_common_op(module& m,
// multibroadcast?
if(inputs[0]->get_shape().dyn_dims() != c_dyn_dims)
{
inputs[0] = m.insert_instruction(ins, make_op("multibroadcast"), inputs[0], inputs[1]);
inputs[0] =
m.insert_instruction(ins,
make_op("multibroadcast", {{"out_dyn_dims", c_dyn_dims}}),
inputs[0],
inputs[1]);
}
if(inputs[1]->get_shape().dyn_dims() != c_dyn_dims)
{
inputs[1] = m.insert_instruction(ins, make_op("multibroadcast"), inputs[1], inputs[0]);
inputs[1] =
m.insert_instruction(ins,
make_op("multibroadcast", {{"out_dyn_dims", c_dyn_dims}}),
inputs[1],
inputs[0]);
}
std::transform(inputs.begin(), inputs.end(), inputs.begin(), [&](auto input) {
if(input->get_shape().type() != c_type)
......
......@@ -37,6 +37,8 @@ struct operation;
std::vector<std::size_t> compute_broadcasted_lens(std::vector<std::size_t> s0,
std::vector<std::size_t> s1);
std::vector<shape::dynamic_dimension> compute_broadcasted_dyn_dims(shape s0, shape s1);
shape common_shape(const std::vector<shape>& shapes);
instruction_ref insert_common_op(module& m,
......
......@@ -58,30 +58,35 @@ struct broadcast
check_shapes{inputs, *this, true}.has(1, 2);
auto s0 = inputs.at(0);
auto t = s0.type();
std::vector<size_t> bcast_strides(broadcast_lens.size(), 0);
// the broadcast op is deprecated now, so not handling the negative
// value of axis anymore
if(axis >= broadcast_lens.size())
if(inputs.size() == 1)
{
MIGRAPHX_THROW("BROADCAST : axis is out of range");
}
std::vector<size_t> bcast_strides(broadcast_lens.size(), 0);
// the broadcast op is deprecated now, so not handling the negative
// value of axis anymore
if(axis >= broadcast_lens.size())
{
MIGRAPHX_THROW("BROADCAST : axis is out of range");
}
if(broadcast_lens.size() - axis < s0.lens().size())
{
MIGRAPHX_THROW("BROADCAST: (broadcast ndims - axis) is less than s0 ndims");
}
if(broadcast_lens.size() - axis < s0.lens().size())
{
MIGRAPHX_THROW("BROADCAST: (broadcast ndims - axis) is less than s0 ndims");
}
if(not std::equal(s0.lens().begin(), s0.lens().end(), broadcast_lens.begin() + axis))
if(not std::equal(s0.lens().begin(), s0.lens().end(), broadcast_lens.begin() + axis))
{
MIGRAPHX_THROW("BROADCAST: when broadcasting, succeeding sizes must match");
}
std::copy(s0.strides().begin(), s0.strides().end(), bcast_strides.begin() + axis);
shape output{t, broadcast_lens, std::move(bcast_strides)};
if(output.elements() < s0.elements())
MIGRAPHX_THROW("BROADCAST: output size must be greater than or equal to s0 size");
return output;
}
else
{
MIGRAPHX_THROW("BROADCAST: when broadcasting, succeeding sizes must match");
}
std::copy(s0.strides().begin(), s0.strides().end(), bcast_strides.begin() + axis);
shape output{t, broadcast_lens, std::move(bcast_strides)};
if(output.elements() < s0.elements())
MIGRAPHX_THROW("BROADCAST: output size must be greater than or equal to s0 size");
return output;
}
argument compute(shape output_shape, std::vector<argument> args) const
......
......@@ -43,11 +43,12 @@ namespace op {
struct multibroadcast
{
std::vector<std::size_t> output_lens;
std::vector<shape::dynamic_dimension> output_dyn_dims;
template <class Self, class F>
static auto reflect(Self& self, F f)
{
return pack(f(self.output_lens, "out_lens"));
return pack(f(self.output_lens, "out_lens"), f(self.output_dyn_dims, "out_dyn_dims"));
}
std::string name() const { return "multibroadcast"; }
......@@ -101,62 +102,13 @@ struct multibroadcast
{
// two inputs
auto s1 = inputs.at(1);
if(s0.dynamic() and s1.dynamic())
if(s0.dynamic() or s1.dynamic())
{
// TODO handle both dynamic case
MIGRAPHX_THROW(
"MULTIBROADCAST_2IN: not handled; two dynamic shape inputs not handled");
}
else if(s0.dynamic() or s1.dynamic())
{
// only handles the case when broadcasting static shape to dynamic shape
// all the dimensions in the static shape must match to a fixed dimension in the
// dynamic shape or be 1
// TODO: handling the other possibilities
if(s1.dynamic())
{
std::swap(s0, s1);
}
auto static_rank = s1.lens().size();
auto dyn_rank = s0.max_lens().size();
if(static_rank > dyn_rank)
if(not output_dyn_dims.empty())
{
MIGRAPHX_THROW("MULTIBROADCAST_2IN: not handled; static shape has a higher "
"rank than dynamic shape");
return {t, output_dyn_dims};
}
return s0;
auto offset = dyn_rank - static_rank;
std::vector<shape::dynamic_dimension> out_dims(s0.dyn_dims());
std::transform(s0.dyn_dims().begin(),
s0.dyn_dims().end(),
s1.lens().begin() + offset,
out_lens.begin() + offset,
[&](auto a, auto b) {
if(a == b)
{
return a;
}
else if((a == 1 or b == 1) and a != 0 and b != 0)
{
return std::max(a, b);
}
else
{
// if not matching nor 1, set to 0
return static_cast<std::size_t>(0);
}
});
/*
auto bcast_min_lens = compute_broadcasted_lens(s0.min_lens(), s1.min_lens());
auto bcast_max_lens = compute_broadcasted_lens(s0.max_lens(), s1.max_lens());
auto bcast_opt_lens = compute_broadcasted_opt_lens(s0.opt_lens(), s1.opt_lens());
return {t,
std::move(bcast_min_lens),
std::move(bcast_max_lens),
std::move(bcast_opt_lens)};
*/
return {t, compute_broadcasted_dyn_dims(s0, s1)};
}
else
{
......
......@@ -142,6 +142,12 @@ struct shape
const std::vector<std::size_t>& lens() const;
const std::vector<std::size_t>& strides() const;
/*!
* The number of dimensions in the shape.
* Same as the number of indices required to get a data value.
*/
std::size_t ndim() const;
/*!
* Return the number of elements in the tensor.
*/
......@@ -227,6 +233,9 @@ struct shape
shape with_type(type_t t) const;
// convert the shape to an equivalent dynamic shape
shape to_dynamic() const;
friend bool operator==(const shape& x, const shape& y);
friend bool operator!=(const shape& x, const shape& y);
friend std::ostream& operator<<(std::ostream& os, const shape& x);
......
......@@ -265,6 +265,15 @@ const std::vector<std::size_t>& shape::lens() const { return impl->m_lens; }
const std::vector<std::size_t>& shape::strides() const { return impl->m_strides; }
std::size_t shape::ndim() const
{
if(this->dynamic())
{
return dyn_dims().size();
}
return lens().size();
}
std::size_t shape::elements() const { return impl->elements(); }
std::size_t shape::bytes() const
......@@ -458,6 +467,15 @@ shape shape::with_type(type_t t) const
return {c};
}
shape shape::to_dynamic() const
{
if(this->dynamic())
{
return *this;
}
return {type(), lens(), lens(), lens()};
}
std::size_t shape::element_space() const { return impl->element_space(); }
std::string shape::type_string() const { return name(this->type()); }
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment