"vscode:/vscode.git/clone" did not exist on "4b70d68ecfbd0d38ed05cae534f34e6390180c19"
Unverified Commit a9d6071a authored by Charlie Lin's avatar Charlie Lin Committed by GitHub
Browse files

Refactor dynamic_dimension fixed compare (#1470)

Implements the operator==(dynamic_dimension, size_t) functions
parent b41c1f01
...@@ -77,7 +77,6 @@ std::vector<shape::dynamic_dimension> compute_broadcasted_dyn_dims(shape s0, sha ...@@ -77,7 +77,6 @@ std::vector<shape::dynamic_dimension> compute_broadcasted_dyn_dims(shape s0, sha
} }
auto offset = s1.ndim() - s0.ndim(); auto offset = s1.ndim() - s0.ndim();
std::vector<shape::dynamic_dimension> out_dims(s1.dyn_dims()); std::vector<shape::dynamic_dimension> out_dims(s1.dyn_dims());
shape::dynamic_dimension one_dyn_dim{1, 1, 0};
std::transform( std::transform(
s0.dyn_dims().cbegin(), s0.dyn_dims().cbegin(),
s0.dyn_dims().cend(), s0.dyn_dims().cend(),
...@@ -88,7 +87,7 @@ std::vector<shape::dynamic_dimension> compute_broadcasted_dyn_dims(shape s0, sha ...@@ -88,7 +87,7 @@ std::vector<shape::dynamic_dimension> compute_broadcasted_dyn_dims(shape s0, sha
{ {
return a; return a;
} }
else if(a == one_dyn_dim or b == one_dyn_dim) else if(a == 1 or b == 1)
{ {
// setting opt to 0, may need to be changed // setting opt to 0, may need to be changed
return shape::dynamic_dimension{std::max(a.min, b.min), std::max(a.max, b.max), 0}; return shape::dynamic_dimension{std::max(a.min, b.min), std::max(a.max, b.max), 0};
......
...@@ -59,9 +59,8 @@ struct squeeze ...@@ -59,9 +59,8 @@ struct squeeze
auto input_shape = inputs[0]; auto input_shape = inputs[0];
if(input_shape.dynamic()) if(input_shape.dynamic())
{ {
shape::dynamic_dimension one_dyn_dim{1, 1, 0};
if(std::any_of(axes.begin(), axes.end(), [&](auto axis) { if(std::any_of(axes.begin(), axes.end(), [&](auto axis) {
return input_shape.dyn_dims()[axis] != one_dyn_dim; return input_shape.dyn_dims()[axis] != 1;
})) }))
{ {
MIGRAPHX_THROW( MIGRAPHX_THROW(
...@@ -73,7 +72,7 @@ struct squeeze ...@@ -73,7 +72,7 @@ struct squeeze
std::copy_if(input_shape.dyn_dims().cbegin(), std::copy_if(input_shape.dyn_dims().cbegin(),
input_shape.dyn_dims().cend(), input_shape.dyn_dims().cend(),
std::back_inserter(dyn_dims), std::back_inserter(dyn_dims),
[&](auto dd) { return dd != one_dyn_dim; }); [&](auto dd) { return dd != 1; });
} }
else else
{ {
......
...@@ -101,6 +101,12 @@ struct shape ...@@ -101,6 +101,12 @@ struct shape
friend bool operator==(const dynamic_dimension& x, const dynamic_dimension& y); friend bool operator==(const dynamic_dimension& x, const dynamic_dimension& y);
friend bool operator!=(const dynamic_dimension& x, const dynamic_dimension& y); friend bool operator!=(const dynamic_dimension& x, const dynamic_dimension& y);
friend std::ostream& operator<<(std::ostream& os, const dynamic_dimension& x); friend std::ostream& operator<<(std::ostream& os, const dynamic_dimension& x);
// compare to fixed std::size_t dimension
friend bool operator==(const dynamic_dimension& x, const std::size_t& y);
friend bool operator==(const std::size_t& x, const dynamic_dimension& y);
friend bool operator!=(const dynamic_dimension& x, const std::size_t& y);
friend bool operator!=(const std::size_t& x, const dynamic_dimension& y);
}; };
static const std::vector<type_t>& types(); static const std::vector<type_t>& types();
......
...@@ -521,6 +521,14 @@ std::ostream& operator<<(std::ostream& os, const shape::dynamic_dimension& x) ...@@ -521,6 +521,14 @@ std::ostream& operator<<(std::ostream& os, const shape::dynamic_dimension& x)
return os; return os;
} }
bool operator==(const shape::dynamic_dimension& x, const std::size_t& y)
{
return x.min == y and x.max == y;
}
bool operator==(const std::size_t& x, const shape::dynamic_dimension& y) { return y == x; }
bool operator!=(const shape::dynamic_dimension& x, const std::size_t& y) { return not(x == y); }
bool operator!=(const std::size_t& x, const shape::dynamic_dimension& y) { return not(x == y); }
bool operator==(const shape& x, const shape& y) bool operator==(const shape& x, const shape& y)
{ {
if(x.dynamic() and y.dynamic()) if(x.dynamic() and y.dynamic())
......
...@@ -160,6 +160,20 @@ TEST_CASE(test_shape_dynamic_compares) ...@@ -160,6 +160,20 @@ TEST_CASE(test_shape_dynamic_compares)
EXPECT(ss0.str() != ss3.str()); EXPECT(ss0.str() != ss3.str());
} }
TEST_CASE(dynamic_dimension_size_t_compares)
{
using migraphx::shape;
auto a = shape::dynamic_dimension{2, 2, 2};
EXPECT(a == 2);
EXPECT(a != 3);
EXPECT(static_cast<std::size_t>(2) == a);
EXPECT(static_cast<std::size_t>(3) != a);
auto b = shape::dynamic_dimension{2, 4, 0};
EXPECT(b != 2);
EXPECT(static_cast<std::size_t>(2) != b);
}
TEST_CASE(test_shape_dynamic_errors) TEST_CASE(test_shape_dynamic_errors)
{ {
using migraphx::shape; using migraphx::shape;
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment