Commit 983c7c1f authored by charlie's avatar charlie
Browse files

Simplfy further

parent 1adbdb5d
...@@ -57,7 +57,6 @@ struct reshape ...@@ -57,7 +57,6 @@ struct reshape
{ {
MIGRAPHX_THROW("Reshape: Only supports one non-fixed dynamic_dimension"); MIGRAPHX_THROW("Reshape: Only supports one non-fixed dynamic_dimension");
} }
std::size_t not_fixed_index = 0;
// track number of fixed elements in input and output // track number of fixed elements in input and output
std::size_t num_dims_ele = 1; std::size_t num_dims_ele = 1;
std::size_t num_dd_ele = 1; std::size_t num_dd_ele = 1;
...@@ -70,7 +69,12 @@ struct reshape ...@@ -70,7 +69,12 @@ struct reshape
} }
else else
{ {
not_fixed_index = i; if(dims[i] != 0 and dims[i] != -1)
{
MIGRAPHX_THROW(
"Reshape: Non-fixed dynamic_dimension doesn't match with 0 or -1 "
"output dimension");
}
} }
} }
if(num_dims_ele != num_dd_ele) if(num_dims_ele != num_dd_ele)
...@@ -78,18 +82,17 @@ struct reshape ...@@ -78,18 +82,17 @@ struct reshape
MIGRAPHX_THROW("Reshape: Number of fixed elements must match. Input: " + MIGRAPHX_THROW("Reshape: Number of fixed elements must match. Input: " +
std::to_string(num_dd_ele) + " Output: " + std::to_string(num_dims_ele)); std::to_string(num_dd_ele) + " Output: " + std::to_string(num_dims_ele));
} }
if(dims[not_fixed_index] != 0 and dims[not_fixed_index] != -1)
{
MIGRAPHX_THROW("Reshape: Non-fixed dynamic_dimension doesn't match with 0 or -1 "
"output dimension");
}
// construct output dynamic shape from dims attribute // construct output dynamic shape from dims attribute
std::vector<shape::dynamic_dimension> output_dyn_dims(dims.size()); std::vector<shape::dynamic_dimension> output_dyn_dims(dims.size());
std::transform(dims.cbegin(), dims.cend(), output_dyn_dims.begin(), [](auto in_d) { std::transform(dims.cbegin(),
std::size_t d = in_d; dims.cend(),
return shape::dynamic_dimension{d, d}; dyn_dims.cbegin(),
}); output_dyn_dims.begin(),
output_dyn_dims[not_fixed_index] = dyn_dims[not_fixed_index]; [](std::size_t dim, auto dyn_dim) {
if(not dyn_dim.is_fixed())
return dyn_dim;
return shape::dynamic_dimension{dim, dim};
});
return {s0.type(), output_dyn_dims}; return {s0.type(), output_dyn_dims};
} }
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment