Commit eb8f205b authored by charlie's avatar charlie
Browse files

Dynamic weight handling

parent a0dd2ef9
...@@ -64,21 +64,19 @@ struct convolution ...@@ -64,21 +64,19 @@ struct convolution
const shape& input = inputs.at(0); const shape& input = inputs.at(0);
const shape& weights = inputs.at(1); const shape& weights = inputs.at(1);
if(weights.dynamic())
{
MIGRAPHX_THROW("CONVOLUTION: dynamic weights not supported");
}
const size_t num_spatial_dims = input_size - 2; const size_t num_spatial_dims = input_size - 2;
if(num_spatial_dims != this->kdims()) if(num_spatial_dims != this->kdims())
{ {
MIGRAPHX_THROW("CONVOLUTION: input k-dims does not match attribute size"); MIGRAPHX_THROW("CONVOLUTION: input k-dims does not match attribute size");
} }
if(!input.dynamic() and input.lens().at(1) != (weights.lens().at(1) * group)) if(not input.dynamic() and not weights.dynamic() and
input.lens().at(1) != (weights.lens().at(1) * group))
MIGRAPHX_THROW("CONVOLUTION: mismatched channel numbers"); MIGRAPHX_THROW("CONVOLUTION: mismatched channel numbers");
auto calc_output_lens = auto calc_output_lens =
[this, &weights, &num_spatial_dims, &padding_size](std::vector<std::size_t> lens) { [this, &num_spatial_dims, &padding_size](std::vector<std::size_t> i_lens,
std::vector<std::size_t> w_lens) {
std::vector<size_t> ret = {}; std::vector<size_t> ret = {};
// calculate the output shape of the convolution: ((W - K + 2P) / S) + 1 // calculate the output shape of the convolution: ((W - K + 2P) / S) + 1
for(size_t i = 0; i < num_spatial_dims; i++) for(size_t i = 0; i < num_spatial_dims; i++)
...@@ -91,8 +89,7 @@ struct convolution ...@@ -91,8 +89,7 @@ struct convolution
} }
ret.push_back(std::size_t(std::max<std::ptrdiff_t>( ret.push_back(std::size_t(std::max<std::ptrdiff_t>(
1, 1,
(lens[i + 2] - (1 + dilation[i] * (weights.lens()[i + 2] - 1)) + (i_lens[i + 2] - (1 + dilation[i] * (w_lens[i + 2] - 1)) + padding_factor) /
padding_factor) /
stride[i] + stride[i] +
1))); 1)));
} }
...@@ -103,9 +100,9 @@ struct convolution ...@@ -103,9 +100,9 @@ struct convolution
{ {
std::vector<shape::dynamic_dimension> output_dyn_dims = {input.dyn_dims().at(0), std::vector<shape::dynamic_dimension> output_dyn_dims = {input.dyn_dims().at(0),
input.dyn_dims().at(1)}; input.dyn_dims().at(1)};
auto min_spatial_dims = calc_output_lens(input.min_lens()); auto min_spatial_dims = calc_output_lens(input.min_lens(), weights.min_lens());
auto max_spatial_dims = calc_output_lens(input.max_lens()); auto max_spatial_dims = calc_output_lens(input.max_lens(), weights.max_lens());
auto opt_spatial_dims = calc_output_lens(input.opt_lens()); auto opt_spatial_dims = calc_output_lens(input.opt_lens(), weights.opt_lens());
for(size_t i = 0; i < num_spatial_dims; ++i) for(size_t i = 0; i < num_spatial_dims; ++i)
{ {
output_dyn_dims.push_back(shape::dynamic_dimension{ output_dyn_dims.push_back(shape::dynamic_dimension{
...@@ -116,7 +113,7 @@ struct convolution ...@@ -116,7 +113,7 @@ struct convolution
else else
{ {
std::vector<size_t> output_lens{input.lens()[0], weights.lens()[0]}; std::vector<size_t> output_lens{input.lens()[0], weights.lens()[0]};
auto spatial_lens = calc_output_lens(input.lens()); auto spatial_lens = calc_output_lens(input.lens(), weights.lens());
std::for_each(spatial_lens.begin(), spatial_lens.end(), [&output_lens](auto x) { std::for_each(spatial_lens.begin(), spatial_lens.end(), [&output_lens](auto x) {
output_lens.push_back(x); output_lens.push_back(x);
}); });
......
...@@ -312,7 +312,9 @@ std::vector<argument> generic_eval(const module* mod, ...@@ -312,7 +312,9 @@ std::vector<argument> generic_eval(const module* mod,
return shapes; return shapes;
}; };
// TODO: Consider how this will be handled when memoized. // TODO: Consider how this will be handled when memoized.
// Could memoize these output shapes now so not recalculating // Could memoize these output shapes into a map so not recalculating
// TODO: Issue with incompatible input tensor to kernel and needing to set
// padding/strides
output_shape = ins->get_operator().compute_shape(to_shapes(values)); output_shape = ins->get_operator().compute_shape(to_shapes(values));
} }
else else
...@@ -333,7 +335,6 @@ std::vector<argument> generic_eval(const module* mod, ...@@ -333,7 +335,6 @@ std::vector<argument> generic_eval(const module* mod,
})); }));
} }
assert(results.find(ins) != results.end()); assert(results.find(ins) != results.end());
// TODO: update this assert for dynamic shapes
if(not ins->get_shape().dynamic()) if(not ins->get_shape().dynamic())
{ {
assert(results.at(ins).get_shape() == ins->get_shape()); assert(results.at(ins).get_shape() == ins->get_shape());
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment