Commit eb8f205b authored by charlie

Dynamic weight handling

parent a0dd2ef9
@@ -62,23 +62,21 @@ struct convolution
         MIGRAPHX_THROW("CONVOLUTION: input and attribute size mismatch!");
     }
-    const shape& input = inputs.at(0);
-    const shape& weights = inputs.at(1);
-    if(weights.dynamic())
-    {
-        MIGRAPHX_THROW("CONVOLUTION: dynamic weights not supported");
-    }
+    const shape& input = inputs.at(0);
+    const shape& weights = inputs.at(1);
     const size_t num_spatial_dims = input_size - 2;
     if(num_spatial_dims != this->kdims())
     {
         MIGRAPHX_THROW("CONVOLUTION: input k-dims does not match attribute size");
     }
-    if(!input.dynamic() and input.lens().at(1) != (weights.lens().at(1) * group))
+    if(not input.dynamic() and not weights.dynamic() and
+       input.lens().at(1) != (weights.lens().at(1) * group))
         MIGRAPHX_THROW("CONVOLUTION: mismatched channel numbers");
     auto calc_output_lens =
-        [this, &weights, &num_spatial_dims, &padding_size](std::vector<std::size_t> lens) {
+        [this, &num_spatial_dims, &padding_size](std::vector<std::size_t> i_lens,
+                                                 std::vector<std::size_t> w_lens) {
            std::vector<size_t> ret = {};
            // calculate the output shape of the convolution: ((W - K + 2P) / S) + 1
            for(size_t i = 0; i < num_spatial_dims; i++)
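(Aside on the channel guard above: a grouped convolution splits the C input channels into `group` slices, so each filter's channel count is C / group; the condition `input.lens().at(1) == weights.lens().at(1) * group` encodes exactly that. For example, C = 8 with group = 2 means the weights tensor carries 4 channels per filter. When either shape is dynamic its lens are not yet concrete, hence the added `not weights.dynamic()` term that defers the check.)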
@@ -91,8 +89,7 @@ struct convolution
            }
            ret.push_back(std::size_t(std::max<std::ptrdiff_t>(
                1,
-               (lens[i + 2] - (1 + dilation[i] * (weights.lens()[i + 2] - 1)) +
-                   padding_factor) /
+               (i_lens[i + 2] - (1 + dilation[i] * (w_lens[i + 2] - 1)) + padding_factor) /
                    stride[i] +
                1)));
        }
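For reference, the expression above is the usual convolution output-length formula with dilation folded in: the effective kernel extent along an axis is 1 + dilation * (K - 1), and the output length is ((W - K_eff + 2P) / S) + 1, clamped to at least 1. A self-contained sketch of that arithmetic (plain C++, not the MIGraphX API; `pad` here stands for the total padding along the axis, the 2P term that `padding_factor` plays above):

```cpp
#include <algorithm>
#include <cstddef>
#include <iostream>

// w = input spatial length, k = kernel spatial length, pad = total padding
// along this axis, s = stride, d = dilation.
std::size_t conv_out_len(std::size_t w, std::size_t k, std::size_t pad,
                         std::size_t s, std::size_t d)
{
    // Effective kernel extent once dilation spreads the taps apart.
    const std::ptrdiff_t k_eff = 1 + d * (k - 1);
    // ((W - K_eff + 2P) / S) + 1, clamped to at least 1.
    return std::size_t(std::max<std::ptrdiff_t>(
        1, (std::ptrdiff_t(w) - k_eff + std::ptrdiff_t(pad)) / std::ptrdiff_t(s) + 1));
}

int main()
{
    // W = 32, K = 3, P = 1 per side (2P = 2), S = 1, D = 1 -> 32 ("same" padding).
    std::cout << conv_out_len(32, 3, 2, 1, 1) << '\n'; // 32
    // Same input with dilation 2: effective kernel is 5, so 32 - 5 + 2 + 1 = 30.
    std::cout << conv_out_len(32, 3, 2, 1, 2) << '\n'; // 30
}
```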
@@ -103,9 +100,9 @@ struct convolution
     {
         std::vector<shape::dynamic_dimension> output_dyn_dims = {input.dyn_dims().at(0),
                                                                  input.dyn_dims().at(1)};
-        auto min_spatial_dims = calc_output_lens(input.min_lens());
-        auto max_spatial_dims = calc_output_lens(input.max_lens());
-        auto opt_spatial_dims = calc_output_lens(input.opt_lens());
+        auto min_spatial_dims = calc_output_lens(input.min_lens(), weights.min_lens());
+        auto max_spatial_dims = calc_output_lens(input.max_lens(), weights.max_lens());
+        auto opt_spatial_dims = calc_output_lens(input.opt_lens(), weights.opt_lens());
         for(size_t i = 0; i < num_spatial_dims; ++i)
         {
             output_dyn_dims.push_back(shape::dynamic_dimension{
@@ -116,7 +113,7 @@ struct convolution
     else
     {
         std::vector<size_t> output_lens{input.lens()[0], weights.lens()[0]};
-        auto spatial_lens = calc_output_lens(input.lens());
+        auto spatial_lens = calc_output_lens(input.lens(), weights.lens());
         std::for_each(spatial_lens.begin(), spatial_lens.end(), [&output_lens](auto x) {
             output_lens.push_back(x);
         });
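The two branches share one pattern: the same length formula runs once per bound. The dynamic branch evaluates it at the min, max, and opt lens of both input and weights and zips the results into per-dimension ranges; the static branch evaluates it once on concrete lens. A hedged sketch of that zip step, using a stand-in for shape::dynamic_dimension (names are illustrative, not the MIGraphX API):

```cpp
#include <cstddef>
#include <vector>

// Illustrative stand-in for MIGraphX's shape::dynamic_dimension.
struct dyn_dim
{
    std::size_t min, max, opt;
};

// Mirrors the dynamic branch above: run the same length formula at each
// bound, then zip the three result vectors into per-dimension ranges.
template <class Calc>
std::vector<dyn_dim> zip_bounds(Calc calc,
                                const std::vector<std::size_t>& in_min,
                                const std::vector<std::size_t>& in_max,
                                const std::vector<std::size_t>& in_opt,
                                const std::vector<std::size_t>& w_min,
                                const std::vector<std::size_t>& w_max,
                                const std::vector<std::size_t>& w_opt)
{
    auto mins = calc(in_min, w_min);
    auto maxs = calc(in_max, w_max);
    auto opts = calc(in_opt, w_opt);
    std::vector<dyn_dim> out;
    for(std::size_t i = 0; i < mins.size(); ++i)
        out.push_back({mins[i], maxs[i], opts[i]});
    return out;
}
```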
@@ -312,7 +312,9 @@ std::vector<argument> generic_eval(const module* mod,
             return shapes;
         };
         // TODO: Consider how this will be handled when memoized.
-        // Could memoize these output shapes now so not recalculating
+        // Could memoize these output shapes into a map so not recalculating
+        // TODO: Issue with incompatible input tensor to kernel and needing to set
+        // padding/strides
         output_shape = ins->get_operator().compute_shape(to_shapes(values));
     }
     else
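The memoization the TODO gestures at could be as small as a map keyed by instruction, consulted before recomputing. A hedged sketch under that assumption (the stub types stand in for migraphx::shape and migraphx::instruction; nothing below is code from this commit):

```cpp
#include <cstddef>
#include <unordered_map>
#include <vector>

// Minimal stubs so the sketch compiles on its own.
struct shape
{
    std::vector<std::size_t> lens;
};
struct instruction
{
    // Stands in for ins->get_operator().compute_shape(...) above.
    shape recompute(const std::vector<shape>& in_shapes) const
    {
        return in_shapes.empty() ? shape{} : in_shapes.front();
    }
};

// One way "memoize these output shapes into a map" could look: cache each
// instruction's recomputed output shape so repeated evaluations of the same
// dynamic program do not redo the work.
shape memoized_shape(const instruction* ins,
                     const std::vector<shape>& in_shapes,
                     std::unordered_map<const instruction*, shape>& cache)
{
    auto it = cache.find(ins);
    if(it != cache.end())
        return it->second;               // hit: reuse the prior result
    shape s = ins->recompute(in_shapes); // miss: compute once...
    cache.emplace(ins, s);               // ...and remember it
    return s;
}
```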
@@ -333,7 +335,6 @@ std::vector<argument> generic_eval(const module* mod,
             }));
         }
         assert(results.find(ins) != results.end());
-        // TODO: update this assert for dynamic shapes
         if(not ins->get_shape().dynamic())
         {
             assert(results.at(ins).get_shape() == ins->get_shape());
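The equality assert only makes sense for static shapes: a dynamic instruction declares per-dimension ranges rather than one concrete shape, so exact comparison cannot hold. A dynamic-aware check would instead verify that each computed length falls inside its declared [min, max] bound; a hedged sketch of such a check (not code from this commit; `dyn_dim` is the same illustrative stand-in as above):

```cpp
#include <cstddef>
#include <vector>

// Illustrative stand-in for shape::dynamic_dimension.
struct dyn_dim
{
    std::size_t min, max;
};

// Possible dynamic-aware counterpart to the equality assert: every concrete
// output length must fall inside its declared [min, max] range.
bool within_declared_bounds(const std::vector<std::size_t>& lens,
                            const std::vector<dyn_dim>& dims)
{
    if(lens.size() != dims.size())
        return false;
    for(std::size_t i = 0; i < lens.size(); ++i)
        if(lens[i] < dims[i].min or lens[i] > dims[i].max)
            return false;
    return true;
}
```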