Commit 32b1f924 authored by charlie

progress

parent 0b0a6d4f
@@ -33,100 +33,62 @@ namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
namespace op {
// TODO: make this work just for exact matches;
// the other attributes can then be removed and we just check that all the parameters are the same
// TODO: the GPU version of this might have to deal with output parameters;
// see the loop op for how output parameters are handled there
// Can this have multiple inputs but only one output?
struct select_module
{
// output shape of the dynamic model
shape output_dyn_shape;
-int input_batch_index = -1;
-int output_batch_index = -1;
-std::string dyn_batch_param_name;
template <class Self, class F>
static auto reflect(Self& self, F f)
{
-return pack(f(self.output_dyn_shape, "output_dyn_shape"),
-f(self.input_batch_index, "input_batch_index"),
-f(self.output_batch_index, "output_batch_index"),
-f(self.dyn_batch_param_name, "dyn_batch_param_name"));
+return pack(f(self.output_dyn_shape, "output_dyn_shape"));
}
std::string name() const { return "select_module"; }
// runs once during model compilation with dynamic shape input
// may run on each model evaluation with static shape input
shape compute_shape(std::vector<shape> inputs) const
{
-check_shapes{inputs, *this, true}.has(1);
-auto s0 = inputs.at(0);
-if(s0.dynamic())
-{
-// should we check that the submodules have the same parameters here?
-// check that no more than one parameter is non-fixed?
-// would need to use version of compute_shape with the parameter list
-return shape{output_dyn_shape};
-}
-else
-{
-auto batch_size = s0.lens().at(input_batch_index);
-auto dds = output_dyn_shape.dyn_dims();
-dds.at(output_batch_index) = {batch_size, batch_size};
-std::vector<std::size_t> dims;
-if(std::all_of(dds.begin(), dds.end(), [](auto dd) { return dd.is_fixed(); }))
-{
-std::transform(
-dds.begin(), dds.end(), std::back_inserter(dims), [](auto d) { return d.max; });
-return {output_dyn_shape.type(), dims};
-}
-else
-{
-MIGRAPHX_THROW("SELECT_MODULE: more than one input dimension was non-fixed");
-}
-}
+check_shapes{inputs, *this, true};
+return shape{output_dyn_shape};
}
-argument compute(const dyn_output& dyn_out,
+argument compute(const shape&,
const std::vector<argument>& args,
const std::vector<module_ref>& submodule_list,
const std::function<std::vector<argument>(
module_ref&, const std::unordered_map<std::string, argument>&)>& run) const
{
-std::vector<module_ref> modules_to_run;
-for(const auto& mod : submodule_list)
-{
-// find submodule with the same parameter shape as the input data
-auto p_shape = mod->get_parameter_shape(dyn_batch_param_name);
-if(p_shape == args.at(0).get_shape())
-{
-modules_to_run.push_back(mod);
-break;
-}
-}
// TODO: if an exact match is not found, assemble the module list from a binary base
// (e.g. decompose the batch size into powers of two)
+// find submodule with parameter shapes exactly the same as the input arguments
+// assuming arguments are in the same order as the parameters
+auto module_to_run = std::find_if(submodule_list.begin(), submodule_list.end(), [&](module_ref mr) {
+auto param_names = mr->get_parameter_names();
+return std::equal(args.cbegin(), args.cend(), param_names.cbegin(), param_names.cend(), [&](auto a, auto p_name) {
+return a.get_shape() == mr->get_parameter_shape(p_name);
+});
+});
-if(modules_to_run.empty())
+if(module_to_run == submodule_list.end())
{
MIGRAPHX_THROW("SELECT_MODULE: no compatible submodules found for input shape: " +
migraphx::to_string(args.at(0).get_shape()));
MIGRAPHX_THROW("SELECT_MODULE: no compatible submodules found for given input shapes");
}
-std::set<std::string> pnames;
-for(const auto& mod : modules_to_run)
-{
-// TODO If all the modules have the same parameters, this would only need to run once
-auto names = mod->get_parameter_names();
-pnames.insert(names.begin(), names.end());
-}
+auto param_names = (*module_to_run)->get_parameter_names();
-assert(pnames.size() <= args.size());
std::unordered_map<std::string, argument> params;
-std::transform(pnames.begin(),
-pnames.end(),
+std::transform(param_names.begin(),
+param_names.end(),
args.begin(),
std::inserter(params, params.end()),
[](auto&& name, auto&& arg) { return std::make_pair(name, arg); });
// TODO: run multiple modules and split the parameter data across their batch sizes
-auto results = run(modules_to_run.at(0), params);
-return results.at(0);
+// copy the module_ref so it can bind to the non-const reference expected by run()
+auto selected_mod = *module_to_run;
+auto results = run(selected_mod, params);
+return argument{results};
}
};
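
For illustration only, here is a minimal standalone sketch of the exact-match selection logic introduced above. It uses hypothetical stand-in types (a fake_module holding plain dimension vectors) rather than the real migraphx::module, argument, and shape classes, and simply picks the first submodule whose parameter shapes exactly match the input shapes, assuming inputs are given in parameter order.

#include <algorithm>
#include <cassert>
#include <cstddef>
#include <string>
#include <vector>

// Hypothetical stand-in for a compiled submodule: a name plus the shapes
// (dimension lists) of its parameters, in declaration order.
struct fake_module
{
    std::string name;
    std::vector<std::vector<std::size_t>> param_shapes;
};

// Select the first module whose parameter shapes exactly match the input
// shapes; returns nullptr when no exact match exists.
const fake_module* select_exact_match(const std::vector<fake_module>& mods,
                                      const std::vector<std::vector<std::size_t>>& input_shapes)
{
    auto it = std::find_if(mods.begin(), mods.end(), [&](const fake_module& m) {
        return std::equal(input_shapes.begin(),
                          input_shapes.end(),
                          m.param_shapes.begin(),
                          m.param_shapes.end());
    });
    return it == mods.end() ? nullptr : &*it;
}

int main()
{
    std::vector<fake_module> mods = {
        {"batch_1", {{1, 3, 224, 224}}},
        {"batch_4", {{4, 3, 224, 224}}},
    };
    // An input with batch size 4 selects the "batch_4" submodule.
    const fake_module* m = select_exact_match(mods, {{4, 3, 224, 224}});
    assert(m != nullptr && m->name == "batch_4");
    (void)m;
    return 0;
}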