"src/git@developer.sourcefind.cn:gaoqiong/migraphx.git" did not exist on "96f0f5672f75d42f31f03594082882ba24376b73"
Commit 2c95db34 authored by charlie's avatar charlie
Browse files

Tidy fixes

parent c9c9ef5c
...@@ -237,6 +237,56 @@ void preview_argument(std::ostream& os, const argument& a) ...@@ -237,6 +237,56 @@ void preview_argument(std::ostream& os, const argument& a)
}); });
} }
// Evaluate one regular instruction during graph evaluation: gather the
// already-computed arguments of its inputs, resolve the output shape
// (recomputed from the runtime arguments when the instruction's recorded
// shape is dynamic), run the operator's compute under the tracer, and record
// the result in `results`.
//
// ctx        - execution context forwarded to compute
// ins        - instruction handle (dereferenced with ->) to evaluate
// trace      - callable invoked as trace(ins, f); wraps the computation
// make_trace - tracer factory forwarded to nested generic_eval calls
// values     - scratch buffer reused across calls; filled with the input
//              arguments for this instruction
// results    - map from evaluated instructions to their result arguments;
//              the result for `ins` is emplaced here
template <class Ins, class T, class F>
void process_op(context& ctx,
                Ins ins,
                T trace,
                F make_trace,
                std::vector<argument>& values,
                std::unordered_map<instruction_ref, argument>& results)
{
    values.resize(ins->inputs().size());
    std::transform(
        ins->inputs().begin(), ins->inputs().end(), values.begin(), [&](instruction_ref i) {
            assert(results.find(i) != results.end());
            // at() rather than operator[]: if the assert is compiled out
            // (NDEBUG) a missing input throws instead of silently
            // default-inserting an empty argument.
            return results.at(i);
        });
    shape output_shape;
    const auto& ins_shape = ins->get_shape();
    if(ins_shape.dynamic())
    {
        // The statically recorded shape is dynamic, so compute the concrete
        // output shape from the runtime argument shapes. Take the arguments
        // by const reference to avoid copying every argument.
        auto to_shapes = [](const std::vector<argument>& args) {
            std::vector<shape> shapes(args.size());
            std::transform(args.begin(), args.end(), shapes.begin(), [](const argument& i) {
                return i.get_shape();
            });
            return shapes;
        };
        // TODO: Consider how this will be handled when memoized.
        // Could memoize these output shapes into a map so not recalculating
        // TODO: Issue with possibly wanting to use new padding/strides/dilation
        output_shape = ins->get_operator().compute_shape(to_shapes(values));
    }
    else
    {
        output_shape = ins_shape;
    }
    const auto& mod_args = ins->module_inputs();
    // Nested-module evaluator handed to compute. It copies the context
    // (ssctx) before recursing — presumably so nested evaluation cannot
    // disturb the caller's context; confirm against generic_eval's contract.
    auto module_eval = [&](module_ref smod,
                           const std::unordered_map<std::string, argument>& inputs) {
        auto ssctx = ctx;
        return generic_eval(smod, ssctx, inputs, results, make_trace);
    };
    results.emplace(ins, trace(ins, [&] {
                        return ins->normalized_operator().compute(
                            ctx, output_shape, values, mod_args, module_eval);
                    }));
}
template <class F> template <class F>
std::vector<argument> generic_eval(const module* mod, std::vector<argument> generic_eval(const module* mod,
context& ctx, context& ctx,
...@@ -292,46 +342,7 @@ std::vector<argument> generic_eval(const module* mod, ...@@ -292,46 +342,7 @@ std::vector<argument> generic_eval(const module* mod,
} }
else else
{ {
values.resize(ins->inputs().size()); process_op(ctx, ins, trace, make_trace, values, results);
std::transform(
ins->inputs().begin(), ins->inputs().end(), values.begin(), [&](instruction_ref i) {
assert(results.find(i) != results.end());
return results[i];
});
shape output_shape;
auto ins_shape = ins->get_shape();
if(ins_shape.dynamic())
{
// Make into a std::vector<instruction_ref> of inputs
auto to_shapes = [](std::vector<argument> args) {
std::vector<shape> shapes(args.size());
std::transform(args.begin(), args.end(), shapes.begin(), [](argument i) {
return i.get_shape();
});
return shapes;
};
// TODO: Consider how this will be handled when memoized.
// Could memoize these output shapes into a map so not recalculating
// TODO: Issue with possibly wanting to use new padding/strides/dilation
output_shape = ins->get_operator().compute_shape(to_shapes(values));
}
else
{
output_shape = ins_shape;
}
const auto& mod_args = ins->module_inputs();
auto module_eval = [&](module_ref smod,
const std::unordered_map<std::string, argument>& inputs) {
auto ssctx = ctx;
return generic_eval(smod, ssctx, inputs, results, make_trace);
};
results.emplace(ins, trace(ins, [&] {
return ins->normalized_operator().compute(
ctx, output_shape, values, mod_args, module_eval);
}));
} }
assert(results.find(ins) != results.end()); assert(results.find(ins) != results.end());
if(not ins->get_shape().dynamic()) if(not ins->get_shape().dynamic())
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment