Commit d0aae8be authored by charlie's avatar charlie
Browse files

Add operator+= and operator-= for dyn_dim

parent c3bb72ac
......@@ -66,7 +66,7 @@ struct pad
auto out_dyn_dims = s0.dyn_dims();
for(std::size_t i = 0; i < s0.ndim(); ++i)
{
out_dyn_dims[i] = out_dyn_dims[i] + pads[i] + pads[i + s0.ndim()];
out_dyn_dims[i] += pads[i] + pads[i + s0.ndim()];
}
return {s0.type(), out_dyn_dims};
}
......
......@@ -109,6 +109,8 @@ struct shape
friend bool operator!=(const std::size_t& x, const dynamic_dimension& y);
// add and subtract fixed std::size_t dimension
dynamic_dimension& operator+=(const std::size_t& x);
dynamic_dimension& operator-=(const std::size_t& x);
friend dynamic_dimension operator+(const dynamic_dimension& x, const std::size_t& y);
friend dynamic_dimension operator+(const std::size_t& x, const dynamic_dimension& y);
friend dynamic_dimension operator-(const dynamic_dimension& x, const std::size_t& y);
......
......@@ -504,6 +504,28 @@ bool shape::dynamic_dimension::is_fixed() const { return this->min == this->max;
bool shape::dynamic_dimension::has_optimal() const { return opt != 0; }
shape::dynamic_dimension& shape::dynamic_dimension::operator+=(const std::size_t& x)
{
this->min += x;
this->max += x;
this->opt == 0 ? this->opt = 0 : this->opt += x;
return *this;
}
shape::dynamic_dimension& shape::dynamic_dimension::operator-=(const std::size_t& x)
{
assert(this->min >= x);
assert(this->max >= x);
this->min -= x;
this->max -= x;
if(this->opt != 0)
{
assert(this->opt >= y);
this->opt -= x;
}
return *this;
}
bool operator==(const shape::dynamic_dimension& x, const shape::dynamic_dimension& y)
{
// don't check opt if both are fixed
......@@ -531,25 +553,19 @@ bool operator!=(const std::size_t& x, const shape::dynamic_dimension& y) { retur
shape::dynamic_dimension operator+(const shape::dynamic_dimension& x, const std::size_t& y)
{
return {x.min + y, x.max + y, x.opt == 0 ? 0 : x.opt + y};
auto dd = x;
return dd += y;
}
shape::dynamic_dimension operator+(const std::size_t& x, const shape::dynamic_dimension& y)
{
return y + x;
}
shape::dynamic_dimension operator-(const shape::dynamic_dimension& x, const std::size_t& y)
{
assert(x.min >= y);
assert(x.max >= y);
if(x.opt == 0)
{
return {x.min - y, x.max - y, 0};
}
else
{
assert(x.opt >= y);
return {x.min - y, x.max - y, x.opt - y};
}
auto dd = x;
return dd -= y;
}
bool operator==(const shape& x, const shape& y)
......
......@@ -179,6 +179,11 @@ TEST_CASE(dynamic_dimension_add_sub_fixed)
using migraphx::shape;
auto a = shape::dynamic_dimension{2, 5, 2};
a += 3;
EXPECT(a == shape::dynamic_dimension{5, 8, 5});
a -= 3;
EXPECT(a == shape::dynamic_dimension{2, 5, 2});
auto b = shape::dynamic_dimension{3, 6, 3};
EXPECT((a + 1) == b);
EXPECT((1 + a) == b);
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment