Commit 239d50dc authored by charlie's avatar charlie
Browse files

tidy & cppcheck fix

parent 86cf29ed
......@@ -47,14 +47,8 @@ struct reshape
value attributes() const { return {{"require_std_shape", true}}; }
std::string name() const { return "reshape"; }
shape compute_shape(std::vector<shape> inputs) const
{
check_shapes{inputs, *this, true}.has(1);
auto n_neg_dims = std::count(dims.begin(), dims.end(), -1);
if(n_neg_dims > 1)
MIGRAPHX_THROW("Reshape: Dimensions for reshape can only have one -1 dim");
auto s0 = inputs[0];
if(s0.dynamic())
shape dyn_compute_shape(shape s0) const
{
auto dyn_dims = s0.dyn_dims();
int not_fixed_index = -1;
......@@ -83,8 +77,7 @@ struct reshape
if(num_dims_ele != num_dd_ele)
{
MIGRAPHX_THROW("Reshape: Number of fixed elements must match. Input: " +
std::to_string(num_dd_ele) +
" Output: " + std::to_string(num_dims_ele));
std::to_string(num_dd_ele) + " Output: " + std::to_string(num_dims_ele));
}
if(dims[not_fixed_index] != 0 and dims[not_fixed_index] != -1)
{
......@@ -101,13 +94,15 @@ struct reshape
}
else
{
auto d = static_cast<std::size_t>(dims[i]);
std::size_t d = dims[i];
output_dyn_dims.push_back({d, d, 0});
}
}
return {s0.type(), output_dyn_dims};
}
else
template <class T>
shape static_compute_shape(std::vector<shape> inputs, T n_neg_dims) const
{
check_shapes{inputs, *this}.standard();
auto&& idims = inputs.front().lens();
......@@ -143,6 +138,22 @@ struct reshape
std::to_string(inputs.front().elements()));
return s;
}
shape compute_shape(std::vector<shape> inputs) const
{
check_shapes{inputs, *this, true}.has(1);
auto n_neg_dims = std::count(dims.begin(), dims.end(), -1);
if(n_neg_dims > 1)
MIGRAPHX_THROW("Reshape: Dimensions for reshape can only have one -1 dim");
auto s0 = inputs[0];
if(s0.dynamic())
{
return dyn_compute_shape(s0);
}
else
{
return static_compute_shape(inputs, n_neg_dims);
}
}
argument compute(const dyn_output& dyn_out, std::vector<argument> args) const
......
......@@ -2000,7 +2000,7 @@ TEST_CASE(reshape_dyn_shape)
}
else
{
auto d = static_cast<std::size_t>(new_shape[i]);
std::size_t d = new_shape[i];
out_dyn_dims.push_back({d, d, 0});
}
}
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment