Commit 983c7c1f authored by charlie's avatar charlie
Browse files

Simplfy further

parent 1adbdb5d
......@@ -57,7 +57,6 @@ struct reshape
{
MIGRAPHX_THROW("Reshape: Only supports one non-fixed dynamic_dimension");
}
std::size_t not_fixed_index = 0;
// track number of fixed elements in input and output
std::size_t num_dims_ele = 1;
std::size_t num_dd_ele = 1;
......@@ -70,7 +69,12 @@ struct reshape
}
else
{
not_fixed_index = i;
if(dims[i] != 0 and dims[i] != -1)
{
MIGRAPHX_THROW(
"Reshape: Non-fixed dynamic_dimension doesn't match with 0 or -1 "
"output dimension");
}
}
}
if(num_dims_ele != num_dd_ele)
......@@ -78,18 +82,17 @@ struct reshape
MIGRAPHX_THROW("Reshape: Number of fixed elements must match. Input: " +
std::to_string(num_dd_ele) + " Output: " + std::to_string(num_dims_ele));
}
if(dims[not_fixed_index] != 0 and dims[not_fixed_index] != -1)
{
MIGRAPHX_THROW("Reshape: Non-fixed dynamic_dimension doesn't match with 0 or -1 "
"output dimension");
}
// construct output dynamic shape from dims attribute
std::vector<shape::dynamic_dimension> output_dyn_dims(dims.size());
std::transform(dims.cbegin(), dims.cend(), output_dyn_dims.begin(), [](auto in_d) {
std::size_t d = in_d;
return shape::dynamic_dimension{d, d};
});
output_dyn_dims[not_fixed_index] = dyn_dims[not_fixed_index];
std::transform(dims.cbegin(),
dims.cend(),
dyn_dims.cbegin(),
output_dyn_dims.begin(),
[](std::size_t dim, auto dyn_dim) {
if(not dyn_dim.is_fixed())
return dyn_dim;
return shape::dynamic_dimension{dim, dim};
});
return {s0.type(), output_dyn_dims};
}
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment