Commit a88810da authored by charlie's avatar charlie
Browse files

Fix things

convolution revert
parent d9d2215a
...@@ -135,6 +135,7 @@ instruction_ref insert_common_op(module& m, ...@@ -135,6 +135,7 @@ instruction_ref insert_common_op(module& m,
a_input = m.insert_instruction( a_input = m.insert_instruction(
ins, make_op("convert", {{"target_type", c_type}}), a_input); ins, make_op("convert", {{"target_type", c_type}}), a_input);
} }
return a_input;
}); });
} }
else else
...@@ -151,6 +152,7 @@ instruction_ref insert_common_op(module& m, ...@@ -151,6 +152,7 @@ instruction_ref insert_common_op(module& m,
input = m.insert_instruction( input = m.insert_instruction(
ins, make_op("convert", {{"target_type", common.type()}}), input); ins, make_op("convert", {{"target_type", common.type()}}), input);
} }
return input;
}); });
} }
return m.insert_instruction(ins, op, inputs); return m.insert_instruction(ins, op, inputs);
......
...@@ -176,7 +176,11 @@ struct convolution ...@@ -176,7 +176,11 @@ struct convolution
auto min_spatial_dims = calc_conv_lens(x_shape.min_lens(), w_shape.max_lens()); auto min_spatial_dims = calc_conv_lens(x_shape.min_lens(), w_shape.max_lens());
auto max_spatial_dims = calc_conv_lens(x_shape.max_lens(), w_shape.min_lens()); auto max_spatial_dims = calc_conv_lens(x_shape.max_lens(), w_shape.min_lens());
auto opt_spatial_dims = calc_conv_lens(x_shape.opt_lens(), w_shape.opt_lens()); auto opt_spatial_dims = calc_conv_lens(x_shape.opt_lens(), w_shape.opt_lens());
return shape{x_shape.type(), min_spatial_dims, max_spatial_dims, opt_spatial_dims}; for(size_t i = 0; i < num_spatial_dims; ++i)
{
output_dyn_dims.push_back(shape::dynamic_dimension{
min_spatial_dims[i], max_spatial_dims[i], opt_spatial_dims[i]});
}
} }
return shape{x_shape.type(), output_dyn_dims}; return shape{x_shape.type(), output_dyn_dims};
} }
......
...@@ -237,6 +237,14 @@ shape::shape(type_t t, std::vector<shape::dynamic_dimension> dims) ...@@ -237,6 +237,14 @@ shape::shape(type_t t, std::vector<shape::dynamic_dimension> dims)
{ {
} }
shape::shape(type_t t,
std::vector<std::size_t> mins,
std::vector<std::size_t> maxes,
std::vector<std::size_t> opts)
: impl(std::make_shared<shape_impl>(t, std::move(mins), std::move(maxes), std::move(opts)))
{
}
shape::shape(const std::vector<shape>& subs) : impl(std::make_shared<shape_impl>(subs)) {} shape::shape(const std::vector<shape>& subs) : impl(std::make_shared<shape_impl>(subs)) {}
shape::shape(std::shared_ptr<shape_impl> pimpl) : impl(std::move(pimpl)) {} shape::shape(std::shared_ptr<shape_impl> pimpl) : impl(std::move(pimpl)) {}
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment