"...git@developer.sourcefind.cn:OpenDAS/torchaudio.git" did not exist on "75fc9a464871d7ab7c5676b78c1254951dcf6104"
Commit 32b1f924 authored by charlie's avatar charlie
Browse files

progress

parent 0b0a6d4f
...@@ -33,100 +33,62 @@ namespace migraphx { ...@@ -33,100 +33,62 @@ namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS { inline namespace MIGRAPHX_INLINE_NS {
namespace op { namespace op {
// Make this work just for exact matches
// can get rid of the other attributes and just check all the parameters are the same
// GPU version of this might have to deal with output parameters
// see loop op for how the output parameters are dealt with there
// Can have multiple inputs but only one output?
struct select_module struct select_module
{ {
// output shape of the dynamic model // output shape of the dynamic model
shape output_dyn_shape; shape output_dyn_shape;
int input_batch_index = -1;
int output_batch_index = -1;
std::string dyn_batch_param_name;
template <class Self, class F> template <class Self, class F>
static auto reflect(Self& self, F f) static auto reflect(Self& self, F f)
{ {
return pack(f(self.output_dyn_shape, "output_dyn_shape"), return pack(f(self.output_dyn_shape, "output_dyn_shape"));
f(self.input_batch_index, "input_batch_index"),
f(self.output_batch_index, "output_batch_index"),
f(self.dyn_batch_param_name, "dyn_batch_param_name"));
} }
std::string name() const { return "select_module"; } std::string name() const { return "select_module"; }
// runs once during model compilation with dynamic shape input
// may run on each model evaluation with static shape input
shape compute_shape(std::vector<shape> inputs) const shape compute_shape(std::vector<shape> inputs) const
{ {
check_shapes{inputs, *this, true}.has(1); check_shapes{inputs, *this, true};
auto s0 = inputs.at(0); return shape{output_dyn_shape};
if(s0.dynamic())
{
// should we check that the submodules have the same parameters here?
// check that no more than one parameter is non-fixed?
// would need to use version of compute_shape with the parameter list
return shape{output_dyn_shape};
}
else
{
auto batch_size = s0.lens().at(input_batch_index);
auto dds = output_dyn_shape.dyn_dims();
dds.at(output_batch_index) = {batch_size, batch_size};
std::vector<std::size_t> dims;
if(std::all_of(dds.begin(), dds.end(), [](auto dd) { return dd.is_fixed(); }))
{
std::transform(
dds.begin(), dds.end(), std::back_inserter(dims), [](auto d) { return d.max; });
return {output_dyn_shape.type(), dims};
}
else
{
MIGRAPHX_THROW("SELECT_MODULE: more than one input dimension was non-fixed");
}
}
} }
argument compute(const dyn_output& dyn_out, argument compute(const shape&,
const std::vector<argument>& args, const std::vector<argument>& args,
const std::vector<module_ref>& submodule_list, const std::vector<module_ref>& submodule_list,
const std::function<std::vector<argument>( const std::function<std::vector<argument>(
module_ref&, const std::unordered_map<std::string, argument>&)>& run) const module_ref&, const std::unordered_map<std::string, argument>&)>& run) const
{ {
std::vector<module_ref> modules_to_run; // find submodule with parameter shapes exactly the same as the input arguments
for(const auto& mod : submodule_list) // assuming arguments are in the same order as the parameters
{ auto module_to_run = std::find_if(submodule_list.begin(), submodule_list.end(), [&](module_ref mr) {
// find submodule with the same parameter shape as the input data auto param_names = mr.get_parameter_names();
auto p_shape = mod->get_parameter_shape(dyn_batch_param_name); std::equal(args.cbegin(), args.cend(), param_names.cbegin(), [&](auto a, auto p_name) {
if(p_shape == args.at(0).get_shape()) return a.get_shape() == mr.get_parameter_shape(p_name);
{ });
modules_to_run.push_back(mod); });
break;
}
}
// TODO if an exact match is not found, assemble module list from binary base
if(modules_to_run.empty()) if(module_to_run == submodule_list.end())
{ {
MIGRAPHX_THROW("SELECT_MODULE: no compatible submodules found for input shape: " + MIGRAPHX_THROW("SELECT_MODULE: no compatible submodules found for given input shapes");
migraphx::to_string(args.at(0).get_shape()));
} }
std::set<std::string> pnames;
for(const auto& mod : modules_to_run) auto param_names = module_to_run.get_parameter_names();
{
// TODO If all the modules have the same parameters, this would only need to run once
auto names = mod->get_parameter_names();
pnames.insert(names.begin(), names.end());
}
assert(pnames.size() <= args.size()); assert(pnames.size() <= args.size());
std::unordered_map<std::string, argument> params; std::unordered_map<std::string, argument> params;
std::transform(pnames.begin(), std::transform(param_names.begin(),
pnames.end(), param_names.end(),
args.begin(), args.begin(),
std::inserter(params, params.end()), std::inserter(params, params.end()),
[](auto&& name, auto&& arg) { return std::make_pair(name, arg); }); [](auto&& name, auto&& arg) { return std::make_pair(name, arg); });
// TODO run multiple modules and split the parameter data to each batch size auto results = run(module_to_run, params);
auto results = run(modules_to_run.at(0), params); return argument{results};
return results.at(0);
} }
}; };
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment