Commit a465fc9d authored by charlie's avatar charlie
Browse files

Dynamic conv draft progress

parent 417d6644
...@@ -40,7 +40,7 @@ void auto_contiguous::apply(module& m) const ...@@ -40,7 +40,7 @@ void auto_contiguous::apply(module& m) const
if(ins->outputs().empty() and ins != last) if(ins->outputs().empty() and ins != last)
continue; continue;
shape s = ins->get_shape(); shape s = ins->get_shape();
if(not s.standard() and s.elements() != 0) if(not s.dynamic() and not s.standard() and s.elements() != 0)
{ {
auto c = m.insert_instruction(std::next(ins), make_op("contiguous"), ins); auto c = m.insert_instruction(std::next(ins), make_op("contiguous"), ins);
m.replace_instruction(ins, c); m.replace_instruction(ins, c);
......
...@@ -17,6 +17,11 @@ static void update_op(const instruction_ref& input, const instruction_ref& ins, ...@@ -17,6 +17,11 @@ static void update_op(const instruction_ref& input, const instruction_ref& ins,
auto val = op.to_value(); auto val = op.to_value();
auto op_padding = val.at("padding").to_vector<size_t>(); auto op_padding = val.at("padding").to_vector<size_t>();
if(input->get_shape().dynamic())
{
return;
}
auto kdims = input->get_shape().lens().size() - 2; auto kdims = input->get_shape().lens().size() - 2;
if(std::equal(op_padding.begin(), if(std::equal(op_padding.begin(),
op_padding.begin() + kdims, op_padding.begin() + kdims,
......
...@@ -264,8 +264,10 @@ std::vector<argument> generic_eval(const module* mod, ...@@ -264,8 +264,10 @@ std::vector<argument> generic_eval(const module* mod,
auto param_name = any_cast<builtin::param>(ins->get_operator()).parameter; auto param_name = any_cast<builtin::param>(ins->get_operator()).parameter;
if(not contains(params, param_name)) if(not contains(params, param_name))
MIGRAPHX_THROW("Parameter not found: " + param_name); MIGRAPHX_THROW("Parameter not found: " + param_name);
auto param = params[param_name]; auto param = params[param_name];
if(param.get_shape() != ins->get_shape()) // TODO: may want to check correct number of dimensions and/or was within bounds
if(not ins->get_shape().dynamic() and param.get_shape() != ins->get_shape())
MIGRAPHX_THROW("Incorrect shape {" + to_string(param.get_shape()) + MIGRAPHX_THROW("Incorrect shape {" + to_string(param.get_shape()) +
"} for parameter: " + param_name); "} for parameter: " + param_name);
return param; return param;
...@@ -297,6 +299,25 @@ std::vector<argument> generic_eval(const module* mod, ...@@ -297,6 +299,25 @@ std::vector<argument> generic_eval(const module* mod,
return results[i]; return results[i];
}); });
shape output_shape;
auto ins_shape = ins->get_shape();
if(ins_shape.dynamic())
{
// Make into a std::vector<instruction_ref> of inputs
auto to_shapes = [](std::vector<argument> args) {
std::vector<shape> shapes(args.size());
std::transform(args.begin(), args.end(), shapes.begin(), [](argument i) {
return i.get_shape();
});
return shapes;
};
output_shape = ins->get_operator().compute_shape(to_shapes(values));
}
else
{
output_shape = ins_shape;
}
const auto& mod_args = ins->module_inputs(); const auto& mod_args = ins->module_inputs();
auto module_eval = [&](module_ref smod, auto module_eval = [&](module_ref smod,
const std::unordered_map<std::string, argument>& inputs) { const std::unordered_map<std::string, argument>& inputs) {
...@@ -306,11 +327,12 @@ std::vector<argument> generic_eval(const module* mod, ...@@ -306,11 +327,12 @@ std::vector<argument> generic_eval(const module* mod,
results.emplace(ins, trace(ins, [&] { results.emplace(ins, trace(ins, [&] {
return ins->normalized_operator().compute( return ins->normalized_operator().compute(
ctx, ins->get_shape(), values, mod_args, module_eval); ctx, output_shape, values, mod_args, module_eval);
})); }));
} }
assert(results.find(ins) != results.end()); assert(results.find(ins) != results.end());
assert(results.at(ins).get_shape() == ins->get_shape()); // TODO: update this assert for dynamic shapes
// assert(results.at(ins).get_shape() == ins->get_shape());
} }
return {results.at(std::prev(mod->end()))}; return {results.at(std::prev(mod->end()))};
} }
......
...@@ -210,6 +210,10 @@ struct ref_convolution : auto_register_op<ref_convolution<Op>> ...@@ -210,6 +210,10 @@ struct ref_convolution : auto_register_op<ref_convolution<Op>>
} }
argument compute(context&, shape output_shape, std::vector<argument> args) const argument compute(context&, shape output_shape, std::vector<argument> args) const
{ {
if(output_shape.dynamic())
{
op.normalize_compute_shape({args.at(0).get_shape(), args.at(1).get_shape()});
}
argument result{output_shape}; argument result{output_shape};
visit_quantize(result, args[0], args[1])([&](auto output, auto input, auto weights) { visit_quantize(result, args[0], args[1])([&](auto output, auto input, auto weights) {
auto in_lens = input.get_shape().lens(); auto in_lens = input.get_shape().lens();
......
...@@ -855,11 +855,11 @@ TEST_CASE(conv_dynamic_batch_test) ...@@ -855,11 +855,11 @@ TEST_CASE(conv_dynamic_batch_test)
migraphx::program p; migraphx::program p;
auto* mm = p.get_main_module(); auto* mm = p.get_main_module();
migraphx::shape input_shape{migraphx::shape::float_type, migraphx::shape input_dyn_shape{migraphx::shape::float_type,
{{1, 100, 0}, {3, 3, 0}, {4, 4, 0}, {4, 4, 0}}}; {{1, 100, 0}, {3, 3, 0}, {4, 4, 0}, {4, 4, 0}}};
migraphx::shape weights_shape{migraphx::shape::float_type, {2, 3, 3, 3}}; migraphx::shape weights_shape{migraphx::shape::float_type, {2, 3, 3, 3}};
auto input = mm->add_parameter("X", input_shape); auto input = mm->add_parameter("X", input_dyn_shape);
auto weights = mm->add_parameter("W", weights_shape); auto weights = mm->add_parameter("W", weights_shape);
mm->add_instruction(migraphx::make_op("convolution", {{"padding", {1, 1}}, {"stride", {2, 2}}}), mm->add_instruction(migraphx::make_op("convolution", {{"padding", {1, 1}}, {"stride", {2, 2}}}),
input, input,
...@@ -910,8 +910,10 @@ TEST_CASE(conv_dynamic_batch_test) ...@@ -910,8 +910,10 @@ TEST_CASE(conv_dynamic_batch_test)
-0.16138598, -0.16138598,
0.79344082}; 0.79344082};
migraphx::shape input_fixed_shape{migraphx::shape::float_type, {2, 3, 4, 4}};
migraphx::parameter_map params; migraphx::parameter_map params;
params["X"] = migraphx::argument(input_shape, a.data()); params["X"] = migraphx::argument(input_fixed_shape, a.data());
params["W"] = migraphx::argument(weights_shape, c.data()); params["W"] = migraphx::argument(weights_shape, c.data());
auto result = p.eval(params).back(); auto result = p.eval(params).back();
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment