"git@developer.sourcefind.cn:gaoqiong/migraphx.git" did not exist on "582404b97fe113c52f256fade220bbbf047d7d58"
Commit 79c38ace authored by charlie
Browse files

Refactor compute_shape() call into op.compute()

Allows for per-operator differences in handling dynamic shapes
Fix operation.hpp change to use the generator
parent 50005acf
...@@ -68,8 +68,10 @@ struct operation ...@@ -68,8 +68,10 @@ struct operation
* *
* @param ctx This is the context created by the `target` during compilation. Implementations * @param ctx This is the context created by the `target` during compilation. Implementations
* can use the target's `context` class rather than the `context` interface class. * can use the target's `context` class rather than the `context` interface class.
* @param output This is the output shape. It is equivalent to running `compute_shape` with each * @param output Equivalent to running `compute_shape` with each `shape` of the `argument`.
* `shape` of the `argument`. * For a fixed shape, the returned argument will have the same shape as `output`.
* For a dynamic shape, the returned `argument` will be a fixed shape within the bounds
* set in the dynamic shape `output`.
* @param input This is the `argument` result from the previous instruction's computation. * @param input This is the `argument` result from the previous instruction's computation.
* @return Return an `argument` of the result computation. The `shape` of `argument` should be * @return Return an `argument` of the result computation. The `shape` of `argument` should be
* the same the `output` shape. * the same the `output` shape.
......
...@@ -279,56 +279,6 @@ void preview_argument(std::ostream& os, const argument& a) ...@@ -279,56 +279,6 @@ void preview_argument(std::ostream& os, const argument& a)
}); });
} }
// Evaluate a single (non-builtin) instruction during module evaluation:
// gather the already-computed arguments for the instruction's inputs,
// resolve the concrete output shape, invoke the operator's compute(),
// and record the resulting argument in `results` keyed by `ins`.
//
// @param ctx        Execution context passed through to the operator's compute().
// @param ins        The instruction to evaluate (iterator/handle into the module;
//                   presumably an instruction_ref — the template parameter keeps it generic).
// @param trace      Callable wrapping the computation of each instruction
//                   (invoked as `trace(ins, thunk)`); allows callers to hook/trace evaluation.
// @param make_trace Forwarded to the recursive generic_eval call used for submodules.
// @param values     Scratch buffer reused across calls; resized and filled with the
//                   input arguments for this instruction.
// @param results    Map of previously evaluated instructions to their computed arguments;
//                   the result for `ins` is inserted here.
template <class Ins, class T, class F>
void process_op(context& ctx,
                Ins ins,
                T trace,
                F make_trace,
                std::vector<argument>& values,
                std::unordered_map<instruction_ref, argument>& results)
{
    // Collect the argument produced by each input instruction; every input
    // must already have been evaluated (enforced by the assert).
    values.resize(ins->inputs().size());
    std::transform(
        ins->inputs().begin(), ins->inputs().end(), values.begin(), [&](instruction_ref i) {
            assert(results.find(i) != results.end());
            return results[i];
        });
    shape output_shape;
    auto ins_shape = ins->get_shape();
    if(ins_shape.dynamic())
    {
        // Dynamic shape: the static shape on the instruction is only a bound,
        // so recompute the concrete output shape from the actual shapes of the
        // runtime input arguments.
        // Make into a std::vector<instruction_ref> of inputs
        auto to_shapes = [](std::vector<argument> args) {
            std::vector<shape> shapes(args.size());
            std::transform(args.begin(), args.end(), shapes.begin(), [](const argument& i) {
                return i.get_shape();
            });
            return shapes;
        };
        // TODO: Consider how this will be handled when memoized.
        // Could memoize these output shapes into a map so not recalculating
        // TODO: Issue with possibly wanting to use new padding/strides/dilation
        output_shape = ins->get_operator().compute_shape(to_shapes(values));
    }
    else
    {
        // Fixed shape: use the statically known shape as-is.
        output_shape = ins_shape;
    }
    const auto& mod_args = ins->module_inputs();
    // Submodule evaluator handed to the operator: recursively evaluates a
    // referenced module with a copy of the context, sharing `results`.
    auto module_eval = [&](module_ref smod,
                           const std::unordered_map<std::string, argument>& inputs) {
        auto ssctx = ctx;
        return generic_eval(smod, ssctx, inputs, results, make_trace);
    };
    // Run the (normalized) operator's compute under the trace hook and store
    // the produced argument for downstream instructions to consume.
    results.emplace(ins, trace(ins, [&] {
        return ins->normalized_operator().compute(
            ctx, output_shape, values, mod_args, module_eval);
    }));
}
template <class F> template <class F>
std::vector<argument> generic_eval(const module* mod, std::vector<argument> generic_eval(const module* mod,
context& ctx, context& ctx,
...@@ -356,12 +306,13 @@ std::vector<argument> generic_eval(const module* mod, ...@@ -356,12 +306,13 @@ std::vector<argument> generic_eval(const module* mod,
auto param_name = any_cast<builtin::param>(ins->get_operator()).parameter; auto param_name = any_cast<builtin::param>(ins->get_operator()).parameter;
if(not contains(params, param_name)) if(not contains(params, param_name))
MIGRAPHX_THROW("Parameter not found: " + param_name); MIGRAPHX_THROW("Parameter not found: " + param_name);
auto param = params[param_name]; auto param = params[param_name];
// TODO: may want to check correct number of dimensions and/or was within bounds // TODO: may want to check correct number of dimensions and/or was within bounds
if(not ins->get_shape().dynamic() and param.get_shape() != ins->get_shape()) if(not ins->get_shape().dynamic() and param.get_shape() != ins->get_shape())
{
MIGRAPHX_THROW("Incorrect shape {" + to_string(param.get_shape()) + MIGRAPHX_THROW("Incorrect shape {" + to_string(param.get_shape()) +
"} for parameter: " + param_name); "} for parameter: " + param_name);
}
return param; return param;
})); }));
} }
...@@ -384,7 +335,24 @@ std::vector<argument> generic_eval(const module* mod, ...@@ -384,7 +335,24 @@ std::vector<argument> generic_eval(const module* mod,
} }
else else
{ {
process_op(ctx, ins, trace, make_trace, values, results); values.resize(ins->inputs().size());
std::transform(
ins->inputs().begin(), ins->inputs().end(), values.begin(), [&](instruction_ref i) {
assert(results.find(i) != results.end());
return results[i];
});
const auto& mod_args = ins->module_inputs();
auto module_eval = [&](module_ref smod,
const std::unordered_map<std::string, argument>& inputs) {
auto ssctx = ctx;
return generic_eval(smod, ssctx, inputs, results, make_trace);
};
results.emplace(ins, trace(ins, [&] {
return ins->normalized_operator().compute(
ctx, ins->get_shape(), values, mod_args, module_eval);
}));
} }
assert(results.find(ins) != results.end()); assert(results.find(ins) != results.end());
if(not ins->get_shape().dynamic()) if(not ins->get_shape().dynamic())
......
...@@ -231,11 +231,13 @@ struct ref_convolution : auto_register_op<ref_convolution<Op>> ...@@ -231,11 +231,13 @@ struct ref_convolution : auto_register_op<ref_convolution<Op>>
{ {
return op.normalize_compute_shape(inputs); return op.normalize_compute_shape(inputs);
} }
argument compute(context&, shape output_shape, std::vector<argument> args) const argument compute(context&, shape output_shape, std::vector<argument> args) const
{ {
if(output_shape.dynamic()) if(output_shape.dynamic())
{ {
op.normalize_compute_shape({args.at(0).get_shape(), args.at(1).get_shape()}); output_shape =
op.normalize_compute_shape({args.at(0).get_shape(), args.at(1).get_shape()});
} }
argument result{output_shape}; argument result{output_shape};
visit_quantize(result, args[0], args[1])([&](auto output, auto input, auto weights) { visit_quantize(result, args[0], args[1])([&](auto output, auto input, auto weights) {
......
...@@ -68,8 +68,10 @@ struct operation ...@@ -68,8 +68,10 @@ struct operation
* *
* @param ctx This is the context created by the `target` during compilation. Implementations * @param ctx This is the context created by the `target` during compilation. Implementations
* can use the target's `context` class rather than the `context` interface class. * can use the target's `context` class rather than the `context` interface class.
* @param output This is the output shape. It is equivalent to running `compute_shape` with each * @param output Equivalent to running `compute_shape` with each `shape` of the `argument`.
* `shape` of the `argument`. * For a fixed shape, the returned argument will have the same shape as `output`.
* For a dynamic shape, the returned `argument` will be a fixed shape within the bounds
* set in the dynamic shape `output`.
* @param input This is the `argument` result from the previous instruction's computation. * @param input This is the `argument` result from the previous instruction's computation.
* @return Return an `argument` of the result computation. The `shape` of `argument` should be * @return Return an `argument` of the result computation. The `shape` of `argument` should be
* the same the `output` shape. * the same the `output` shape.
...@@ -137,7 +139,7 @@ auto compute_shape_op(rank<2>, const T& x, const std::vector<shape>& inputs) ...@@ -137,7 +139,7 @@ auto compute_shape_op(rank<2>, const T& x, const std::vector<shape>& inputs)
-> decltype(x.normalize_compute_shape(inputs)) -> decltype(x.normalize_compute_shape(inputs))
{ {
dependent_type<operation, T> y = x; dependent_type<operation, T> y = x;
normalize_attributes(y, inputs[0].lens()); normalize_attributes(y, inputs[0]);
return any_cast<T>(y).normalize_compute_shape(inputs); return any_cast<T>(y).normalize_compute_shape(inputs);
} }
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment