Commit 459eb3dd authored by charlie's avatar charlie
Browse files

Change code to use operator compare

parent 564d38e5
......@@ -77,7 +77,6 @@ std::vector<shape::dynamic_dimension> compute_broadcasted_dyn_dims(shape s0, sha
}
auto offset = s1.ndim() - s0.ndim();
std::vector<shape::dynamic_dimension> out_dims(s1.dyn_dims());
shape::dynamic_dimension one_dyn_dim{1, 1, 0};
std::transform(
s0.dyn_dims().cbegin(),
s0.dyn_dims().cend(),
......@@ -88,7 +87,7 @@ std::vector<shape::dynamic_dimension> compute_broadcasted_dyn_dims(shape s0, sha
{
return a;
}
else if(a == one_dyn_dim or b == one_dyn_dim)
else if(a == 1 or b == 1)
{
// setting opt to 0, may need to be changed
return shape::dynamic_dimension{std::max(a.min, b.min), std::max(a.max, b.max), 0};
......
......@@ -59,9 +59,8 @@ struct squeeze
auto input_shape = inputs[0];
if(input_shape.dynamic())
{
shape::dynamic_dimension one_dyn_dim{1, 1, 0};
if(std::any_of(axes.begin(), axes.end(), [&](auto axis) {
return input_shape.dyn_dims()[axis] != one_dyn_dim;
return input_shape.dyn_dims()[axis] != 1;
}))
{
MIGRAPHX_THROW(
......@@ -73,7 +72,7 @@ struct squeeze
for(auto i : range(input_shape.ndim()))
{
auto dd = input_shape.dyn_dims()[i];
if(dd != one_dyn_dim)
if(dd != 1)
{
dyn_dims.push_back(dd);
}
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment