Commit 239d50dc authored by charlie's avatar charlie
Browse files

tidy & cppcheck fix

parent 86cf29ed
...@@ -47,14 +47,8 @@ struct reshape ...@@ -47,14 +47,8 @@ struct reshape
value attributes() const { return {{"require_std_shape", true}}; } value attributes() const { return {{"require_std_shape", true}}; }
std::string name() const { return "reshape"; } std::string name() const { return "reshape"; }
shape compute_shape(std::vector<shape> inputs) const
{ shape dyn_compute_shape(shape s0) const
check_shapes{inputs, *this, true}.has(1);
auto n_neg_dims = std::count(dims.begin(), dims.end(), -1);
if(n_neg_dims > 1)
MIGRAPHX_THROW("Reshape: Dimensions for reshape can only have one -1 dim");
auto s0 = inputs[0];
if(s0.dynamic())
{ {
auto dyn_dims = s0.dyn_dims(); auto dyn_dims = s0.dyn_dims();
int not_fixed_index = -1; int not_fixed_index = -1;
...@@ -83,8 +77,7 @@ struct reshape ...@@ -83,8 +77,7 @@ struct reshape
if(num_dims_ele != num_dd_ele) if(num_dims_ele != num_dd_ele)
{ {
MIGRAPHX_THROW("Reshape: Number of fixed elements must match. Input: " + MIGRAPHX_THROW("Reshape: Number of fixed elements must match. Input: " +
std::to_string(num_dd_ele) + std::to_string(num_dd_ele) + " Output: " + std::to_string(num_dims_ele));
" Output: " + std::to_string(num_dims_ele));
} }
if(dims[not_fixed_index] != 0 and dims[not_fixed_index] != -1) if(dims[not_fixed_index] != 0 and dims[not_fixed_index] != -1)
{ {
...@@ -101,13 +94,15 @@ struct reshape ...@@ -101,13 +94,15 @@ struct reshape
} }
else else
{ {
auto d = static_cast<std::size_t>(dims[i]); std::size_t d = dims[i];
output_dyn_dims.push_back({d, d, 0}); output_dyn_dims.push_back({d, d, 0});
} }
} }
return {s0.type(), output_dyn_dims}; return {s0.type(), output_dyn_dims};
} }
else
template <class T>
shape static_compute_shape(std::vector<shape> inputs, T n_neg_dims) const
{ {
check_shapes{inputs, *this}.standard(); check_shapes{inputs, *this}.standard();
auto&& idims = inputs.front().lens(); auto&& idims = inputs.front().lens();
...@@ -143,6 +138,22 @@ struct reshape ...@@ -143,6 +138,22 @@ struct reshape
std::to_string(inputs.front().elements())); std::to_string(inputs.front().elements()));
return s; return s;
} }
shape compute_shape(std::vector<shape> inputs) const
{
check_shapes{inputs, *this, true}.has(1);
auto n_neg_dims = std::count(dims.begin(), dims.end(), -1);
if(n_neg_dims > 1)
MIGRAPHX_THROW("Reshape: Dimensions for reshape can only have one -1 dim");
auto s0 = inputs[0];
if(s0.dynamic())
{
return dyn_compute_shape(s0);
}
else
{
return static_compute_shape(inputs, n_neg_dims);
}
} }
argument compute(const dyn_output& dyn_out, std::vector<argument> args) const argument compute(const dyn_output& dyn_out, std::vector<argument> args) const
......
...@@ -2000,7 +2000,7 @@ TEST_CASE(reshape_dyn_shape) ...@@ -2000,7 +2000,7 @@ TEST_CASE(reshape_dyn_shape)
} }
else else
{ {
auto d = static_cast<std::size_t>(new_shape[i]); std::size_t d = new_shape[i];
out_dyn_dims.push_back({d, d, 0}); out_dyn_dims.push_back({d, d, 0});
} }
} }
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment