Commit 1adbdb5d authored by charlie's avatar charlie
Browse files

Review updates

parent 239d50dc
......@@ -51,7 +51,13 @@ struct reshape
shape dyn_compute_shape(shape s0) const
{
auto dyn_dims = s0.dyn_dims();
int not_fixed_index = -1;
auto num_not_fixed = std::count_if(
dyn_dims.cbegin(), dyn_dims.cend(), [](auto dd) { return not dd.is_fixed(); });
if(num_not_fixed != 1)
{
MIGRAPHX_THROW("Reshape: Only supports one non-fixed dynamic_dimension");
}
std::size_t not_fixed_index = 0;
// track number of fixed elements in input and output
std::size_t num_dims_ele = 1;
std::size_t num_dd_ele = 1;
......@@ -63,16 +69,9 @@ struct reshape
num_dd_ele *= dyn_dims[i].min;
}
else
{
if(not_fixed_index == -1)
{
not_fixed_index = i;
}
else
{
MIGRAPHX_THROW("Reshape: Only support one non-fixed dynamic_dimension");
}
}
}
if(num_dims_ele != num_dd_ele)
{
......@@ -85,19 +84,12 @@ struct reshape
"output dimension");
}
// construct output dynamic shape from dims attribute
std::vector<shape::dynamic_dimension> output_dyn_dims = {};
for(std::size_t i = 0; i < dims.size(); ++i)
{
if(i == not_fixed_index)
{
output_dyn_dims.push_back(dyn_dims[not_fixed_index]);
}
else
{
std::size_t d = dims[i];
output_dyn_dims.push_back({d, d, 0});
}
}
std::vector<shape::dynamic_dimension> output_dyn_dims(dims.size());
std::transform(dims.cbegin(), dims.cend(), output_dyn_dims.begin(), [](auto in_d) {
std::size_t d = in_d;
return shape::dynamic_dimension{d, d};
});
output_dyn_dims[not_fixed_index] = dyn_dims[not_fixed_index];
return {s0.type(), output_dyn_dims};
}
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment