Commit 5c7a1969 authored by charlie's avatar charlie
Browse files

Initial case

parent 56c43445
...@@ -48,12 +48,66 @@ struct reshape ...@@ -48,12 +48,66 @@ struct reshape
std::string name() const { return "reshape"; } std::string name() const { return "reshape"; }
shape compute_shape(std::vector<shape> inputs) const shape compute_shape(std::vector<shape> inputs) const
{ {
check_shapes{inputs, *this}.has(1).standard(); check_shapes{inputs, *this}.has(1);
auto&& idims = inputs.front().lens();
std::vector<std::size_t> rdims(dims.begin(), dims.end());
auto n_neg_dims = std::count(dims.begin(), dims.end(), -1); auto n_neg_dims = std::count(dims.begin(), dims.end(), -1);
if(n_neg_dims > 1) if(n_neg_dims > 1)
MIGRAPHX_THROW("Reshape: Dimensions for reshape can only have one -1 dim"); MIGRAPHX_THROW("Reshape: Dimensions for reshape can only have one -1 dim");
auto s0 = inputs[0];
if(s0.dynamic())
{
auto dyn_dims = s0.dyn_dims();
std::size_t not_fixed_index = 0;
// track number of fixed elements in input and output
std::size_t num_dims_ele = 1;
std::size_t num_dd_ele = 1;
for(std::size_t i = 0; i < dyn_dims.size(); ++i)
{
if(dyn_dims[i].is_fixed())
{
num_dims_ele *= dims[i];
num_dd_ele *= dyn_dims[i].min;
}
else
{
if(not_fixed_index == 0)
{
not_fixed_index = i;
}
else
{
MIGRAPHX_THROW("Reshape: Only support one non-fixed dynamic_dimension");
}
}
}
if (num_dims_ele != num_dd_ele)
{
MIGRAPHX_THROW("Reshape: Number of fixed elements must match. Input: " + std::to_string(num_dd_ele) + " Output: " + std::to_string(num_dims_ele));
}
if(dims[not_fixed_index] != 0 and dims[not_fixed_index] != -1)
{
MIGRAPHX_THROW("Reshape: Non-fixed dynamic_dimension doesn't match with 0 or -1 output dimension");
}
// construct output dynamic shape from dims attribute
std::vector<shape::dynamic_dimension> output_dyn_dims = {};
for(std::size_t i = 0; i < dims.size(); ++i)
{
if(i == not_fixed_index)
{
output_dyn_dims.push_back(dyn_dims[not_fixed_index]);
}
else
{
auto d = static_cast<std::size_t>(dims[i]);
output_dyn_dims.push_back({d, d, 0});
}
}
return {s0.type(), output_dyn_dims};
}
else
{
check_shapes{inputs, *this}.standard();
auto&& idims = inputs.front().lens();
std::vector<std::size_t> rdims(dims.begin(), dims.end());
for(std::size_t i = 0; i < dims.size(); i++) for(std::size_t i = 0; i < dims.size(); i++)
{ {
...@@ -85,6 +139,7 @@ struct reshape ...@@ -85,6 +139,7 @@ struct reshape
std::to_string(inputs.front().elements())); std::to_string(inputs.front().elements()));
return s; return s;
} }
}
argument compute(shape output_shape, std::vector<argument> args) const argument compute(shape output_shape, std::vector<argument> args) const
{ {
......
...@@ -49,7 +49,7 @@ struct parse_reshape : op_parser<parse_reshape> ...@@ -49,7 +49,7 @@ struct parse_reshape : op_parser<parse_reshape>
if(args.size() == 2) if(args.size() == 2)
{ {
auto s = args[1]->eval(); auto s = args[1]->eval();
check_arg_empty(s, "Reshape: dynamic shape is not supported"); check_arg_empty(s, "Reshape: non-constant shape input is not supported");
s.visit([&](auto v) { copy(v, std::back_inserter(dims)); }); s.visit([&](auto v) { copy(v, std::back_inserter(dims)); });
} }
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment