Commit 47a07c3a authored by charlie's avatar charlie
Browse files

add dynamic_dimension.within_range()

parent 0ef0d0bb
...@@ -61,10 +61,6 @@ compute_broadcasted_dyn_dims(std::vector<shape::dynamic_dimension> dds0, ...@@ -61,10 +61,6 @@ compute_broadcasted_dyn_dims(std::vector<shape::dynamic_dimension> dds0,
} }
auto offset = dds1.size() - dds0.size(); auto offset = dds1.size() - dds0.size();
std::vector<shape::dynamic_dimension> out_dims(dds1); std::vector<shape::dynamic_dimension> out_dims(dds1);
// If one within the range of the other
auto dd_within_range = [&](shape::dynamic_dimension x, shape::dynamic_dimension y) {
return (x.min >= y.min and x.max <= y.max);
};
std::transform(dds0.cbegin(), std::transform(dds0.cbegin(),
dds0.cend(), dds0.cend(),
dds1.cbegin() + offset, dds1.cbegin() + offset,
...@@ -78,11 +74,11 @@ compute_broadcasted_dyn_dims(std::vector<shape::dynamic_dimension> dds0, ...@@ -78,11 +74,11 @@ compute_broadcasted_dyn_dims(std::vector<shape::dynamic_dimension> dds0,
{ {
return b; return b;
} }
else if(dd_within_range(a, b)) else if(a.within_range(b))
{ {
return a; return a;
} }
else if(dd_within_range(b, a)) else if(b.within_range(a))
{ {
return b; return b;
} }
......
...@@ -53,38 +53,55 @@ struct dot ...@@ -53,38 +53,55 @@ struct dot
} }
if(a.dynamic() or b.dynamic()) if(a.dynamic() or b.dynamic())
{ {
auto dd_within_range = [&](shape::dynamic_dimension x, shape::dynamic_dimension y) {
return (x.min >= y.min and x.max <= y.max);
};
auto s0 = a.to_dynamic(); auto s0 = a.to_dynamic();
auto s1 = b.to_dynamic(); auto s1 = b.to_dynamic();
if(not std::equal(s0.dyn_dims().rbegin() + 2, std::vector<shape::dynamic_dimension> out_dyn_dims;
s0.dyn_dims().rend(),
s1.dyn_dims().rbegin() + 2, // check outer dimensions are within range
s1.dyn_dims().rend(), // put within range dynamic_dimensions into the out_dyn_dims
[&](auto x, auto y) { bool outers_within_range = std::equal(s0.dyn_dims().rbegin() + 2,
return (dd_within_range(x, y) or dd_within_range(y, x)); s0.dyn_dims().rend(),
})) s1.dyn_dims().rbegin() + 2,
s1.dyn_dims().rend(),
[&](auto x, auto y) {
if(x.within_range(y))
{
out_dyn_dims.push_back(x);
return true;
}
else if(y.within_range(x))
{
out_dyn_dims.push_back(y);
return true;
}
else
{
return false;
}
});
if(not outers_within_range)
{ {
MIGRAPHX_THROW("DOT: dynamic outer dimensions of A and B mismatch or not within " MIGRAPHX_THROW("DOT: dynamic outer dimensions of A and B mismatch or not within "
"dynamic_dimension range: {" + "dynamic_dimension range: {" +
to_string_range(s0.dyn_dims()) + "} x {" + to_string_range(s0.dyn_dims()) + "} x {" +
to_string_range(s1.dyn_dims()) + "}"); to_string_range(s1.dyn_dims()) + "}");
} }
std::size_t dim_0 = s0.ndim() - 2; std::size_t dim_i = s0.ndim() - 2;
std::size_t dim_1 = s0.ndim() - 1; std::size_t dim_j = s0.ndim() - 1;
auto x = s0.dyn_dims()[dim_1]; auto x = s0.dyn_dims()[dim_i];
auto y = s1.dyn_dims()[dim_0]; auto y = s1.dyn_dims()[dim_j];
if(not dd_within_range(x, y) and not dd_within_range(y, x))
// check inner dimensions are within range
if(not x.within_range(y) and not y.within_range(x))
{ {
MIGRAPHX_THROW("DOT: dynamic inner dimensions do not match: {" + MIGRAPHX_THROW("DOT: dynamic inner dimensions do not match: {" +
to_string_range(s0.dyn_dims()) + "} x {" + to_string_range(s0.dyn_dims()) + "} x {" +
to_string_range(s1.dyn_dims()) + "}"); to_string_range(s1.dyn_dims()) + "}");
} }
// NOTE could make this compute_shape more precise by using outer dimensions of the
// shape that's dd_within_range. currently this just uses the outer dimensions of s0. out_dyn_dims.push_back(s0.dyn_dims()[dim_i]);
auto out_dyn_dims = s0.dyn_dims(); out_dyn_dims.push_back(s1.dyn_dims()[dim_j]);
out_dyn_dims[dim_1] = s1.dyn_dims()[dim_1];
return {t, out_dyn_dims}; return {t, out_dyn_dims};
} }
else else
......
...@@ -102,6 +102,11 @@ struct MIGRAPHX_EXPORT shape ...@@ -102,6 +102,11 @@ struct MIGRAPHX_EXPORT shape
bool is_fixed() const; bool is_fixed() const;
bool has_optimal() const; bool has_optimal() const;
bool within_range(const dynamic_dimension& other)
{
return (this->min >= other.min and this->max <= other.max);
}
MIGRAPHX_EXPORT friend bool operator==(const dynamic_dimension& x, MIGRAPHX_EXPORT friend bool operator==(const dynamic_dimension& x,
const dynamic_dimension& y); const dynamic_dimension& y);
MIGRAPHX_EXPORT friend bool operator!=(const dynamic_dimension& x, MIGRAPHX_EXPORT friend bool operator!=(const dynamic_dimension& x,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment