Commit 239d50dc authored by charlie's avatar charlie
Browse files

tidy & cppcheck fix

parent 86cf29ed
...@@ -47,101 +47,112 @@ struct reshape ...@@ -47,101 +47,112 @@ struct reshape
value attributes() const { return {{"require_std_shape", true}}; } value attributes() const { return {{"require_std_shape", true}}; }
std::string name() const { return "reshape"; } std::string name() const { return "reshape"; }
shape compute_shape(std::vector<shape> inputs) const
shape dyn_compute_shape(shape s0) const
{ {
check_shapes{inputs, *this, true}.has(1); auto dyn_dims = s0.dyn_dims();
auto n_neg_dims = std::count(dims.begin(), dims.end(), -1); int not_fixed_index = -1;
if(n_neg_dims > 1) // track number of fixed elements in input and output
MIGRAPHX_THROW("Reshape: Dimensions for reshape can only have one -1 dim"); std::size_t num_dims_ele = 1;
auto s0 = inputs[0]; std::size_t num_dd_ele = 1;
if(s0.dynamic()) for(std::size_t i = 0; i < dyn_dims.size(); ++i)
{ {
auto dyn_dims = s0.dyn_dims(); if(dyn_dims[i].is_fixed())
int not_fixed_index = -1; {
// track number of fixed elements in input and output num_dims_ele *= dims[i];
std::size_t num_dims_ele = 1; num_dd_ele *= dyn_dims[i].min;
std::size_t num_dd_ele = 1; }
for(std::size_t i = 0; i < dyn_dims.size(); ++i) else
{ {
if(dyn_dims[i].is_fixed()) if(not_fixed_index == -1)
{ {
num_dims_ele *= dims[i]; not_fixed_index = i;
num_dd_ele *= dyn_dims[i].min;
} }
else else
{ {
if(not_fixed_index == -1) MIGRAPHX_THROW("Reshape: Only support one non-fixed dynamic_dimension");
{
not_fixed_index = i;
}
else
{
MIGRAPHX_THROW("Reshape: Only support one non-fixed dynamic_dimension");
}
} }
} }
if(num_dims_ele != num_dd_ele) }
{ if(num_dims_ele != num_dd_ele)
MIGRAPHX_THROW("Reshape: Number of fixed elements must match. Input: " + {
std::to_string(num_dd_ele) + MIGRAPHX_THROW("Reshape: Number of fixed elements must match. Input: " +
" Output: " + std::to_string(num_dims_ele)); std::to_string(num_dd_ele) + " Output: " + std::to_string(num_dims_ele));
} }
if(dims[not_fixed_index] != 0 and dims[not_fixed_index] != -1) if(dims[not_fixed_index] != 0 and dims[not_fixed_index] != -1)
{
MIGRAPHX_THROW("Reshape: Non-fixed dynamic_dimension doesn't match with 0 or -1 "
"output dimension");
}
// construct output dynamic shape from dims attribute
std::vector<shape::dynamic_dimension> output_dyn_dims = {};
for(std::size_t i = 0; i < dims.size(); ++i)
{
if(i == not_fixed_index)
{ {
MIGRAPHX_THROW("Reshape: Non-fixed dynamic_dimension doesn't match with 0 or -1 " output_dyn_dims.push_back(dyn_dims[not_fixed_index]);
"output dimension");
} }
// construct output dynamic shape from dims attribute else
std::vector<shape::dynamic_dimension> output_dyn_dims = {};
for(std::size_t i = 0; i < dims.size(); ++i)
{ {
if(i == not_fixed_index) std::size_t d = dims[i];
{ output_dyn_dims.push_back({d, d, 0});
output_dyn_dims.push_back(dyn_dims[not_fixed_index]);
}
else
{
auto d = static_cast<std::size_t>(dims[i]);
output_dyn_dims.push_back({d, d, 0});
}
} }
return {s0.type(), output_dyn_dims};
} }
else return {s0.type(), output_dyn_dims};
}
template <class T>
shape static_compute_shape(std::vector<shape> inputs, T n_neg_dims) const
{
check_shapes{inputs, *this}.standard();
auto&& idims = inputs.front().lens();
std::vector<std::size_t> rdims(dims.begin(), dims.end());
for(std::size_t i = 0; i < dims.size(); i++)
{ {
check_shapes{inputs, *this}.standard(); if(dims[i] == 0)
auto&& idims = inputs.front().lens(); rdims[i] = idims[i];
std::vector<std::size_t> rdims(dims.begin(), dims.end());
for(std::size_t i = 0; i < dims.size(); i++) // since rdims using size_t type, -1 is the max value
{ // is size_t that cause later compuation incorrect
if(dims[i] == 0) if(dims[i] == -1)
rdims[i] = idims[i]; rdims[i] = 1;
}
// since rdims using size_t type, -1 is the max value if(n_neg_dims > 0)
// is size_t that cause later compuation incorrect {
size_t missing_dim =
inputs.front().elements() /
std::accumulate(rdims.begin(), rdims.end(), 1, std::multiplies<int64_t>());
for(std::size_t i = 0; i < rdims.size(); i++)
{
if(dims[i] == -1) if(dims[i] == -1)
rdims[i] = 1; rdims[i] = missing_dim;
} }
}
if(n_neg_dims > 0) shape s{inputs.front().type(), rdims};
{ if(s.elements() != inputs.front().elements())
size_t missing_dim = MIGRAPHX_THROW("Reshape: Wrong number of elements for reshape: reshape has " +
inputs.front().elements() / std::to_string(s.elements()) + " elements whereas the input has " +
std::accumulate(rdims.begin(), rdims.end(), 1, std::multiplies<int64_t>()); std::to_string(inputs.front().elements()));
for(std::size_t i = 0; i < rdims.size(); i++) return s;
{ }
if(dims[i] == -1)
rdims[i] = missing_dim;
}
}
shape s{inputs.front().type(), rdims}; shape compute_shape(std::vector<shape> inputs) const
if(s.elements() != inputs.front().elements()) {
MIGRAPHX_THROW("Reshape: Wrong number of elements for reshape: reshape has " + check_shapes{inputs, *this, true}.has(1);
std::to_string(s.elements()) + " elements whereas the input has " + auto n_neg_dims = std::count(dims.begin(), dims.end(), -1);
std::to_string(inputs.front().elements())); if(n_neg_dims > 1)
return s; MIGRAPHX_THROW("Reshape: Dimensions for reshape can only have one -1 dim");
auto s0 = inputs[0];
if(s0.dynamic())
{
return dyn_compute_shape(s0);
}
else
{
return static_compute_shape(inputs, n_neg_dims);
} }
} }
......
...@@ -2000,7 +2000,7 @@ TEST_CASE(reshape_dyn_shape) ...@@ -2000,7 +2000,7 @@ TEST_CASE(reshape_dyn_shape)
} }
else else
{ {
auto d = static_cast<std::size_t>(new_shape[i]); std::size_t d = new_shape[i];
out_dyn_dims.push_back({d, d, 0}); out_dyn_dims.push_back({d, d, 0});
} }
} }
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment