Commit 700d7761 authored by charlie's avatar charlie
Browse files

cleanup

parent 439f96bc
......@@ -45,40 +45,9 @@ struct select_module
shape compute_shape(const std::vector<shape>&, std::vector<module_ref>) const
{
// if(std::none_of(inputs.cbegin(), inputs.cend(), [](auto input){ return input.dynamic();
// }))
//{
// if(mods.size() != 1)
// {
// MIGRAPHX_THROW("SELECT_MODULE: operator should have one submodule during eval.");
// }
// return {mods.front()->get_output_shapes()};
//}
return shape{output_dyn_shapes};
}
std::vector<std::string> get_input_parameter_names(module_ref mod) const
{
auto param_names = mod->get_parameter_names();
std::vector<std::string> ret;
std::copy_if(param_names.cbegin(),
param_names.cend(),
std::back_inserter(ret),
[](auto pn) { return not contains(pn, "#output_"); });
return ret;
}
std::vector<std::string> get_output_parameter_names(module_ref mod) const
{
auto param_names = mod->get_parameter_names();
std::vector<std::string> ret;
std::copy_if(param_names.cbegin(),
param_names.cend(),
std::back_inserter(ret),
[](auto pn) { return contains(pn, "#output_"); });
return ret;
}
argument compute(const shape&,
const std::vector<argument>& args,
const std::vector<module_ref>& submodule_list,
......@@ -89,10 +58,10 @@ struct select_module
// assuming arguments are in the same order as the input parameters
auto module_iter =
std::find_if(submodule_list.cbegin(), submodule_list.cend(), [&](module_ref mr) {
auto input_param_names = get_input_parameter_names(mr);
assert(input_param_names.size() <= args.size());
return std::equal(input_param_names.cbegin(),
input_param_names.cend(),
auto param_names = mr->get_parameter_names();
assert(param_names.size() <= args.size());
return std::equal(param_names.cbegin(),
param_names.cend(),
args.cbegin(),
[&](auto p_name, auto a) {
return a.get_shape() == mr->get_parameter_shape(p_name);
......@@ -107,24 +76,14 @@ struct select_module
std::unordered_map<std::string, argument> params;
// add input parameters
auto input_param_names = get_input_parameter_names(module_to_run);
assert(input_param_names.size() <= args.size());
std::transform(input_param_names.begin(),
input_param_names.end(),
auto param_names = module_to_run->get_parameter_names();
assert(param_names.size() <= args.size());
std::transform(param_names.begin(),
param_names.end(),
args.begin(),
std::inserter(params, params.end()),
[](auto&& name, auto&& a) { return std::make_pair(name, a); });
// add output parameters from arguments (none if on ref)
// auto output_param_names = get_output_parameter_names(module_to_run);
// std::transform(output_param_names.begin(),
// output_param_names.end(),
// args.begin() + input_param_names.size(),
// std::inserter(params, params.end()),
// [](auto&& name, auto&& a) {
// return std::make_pair(name, a);
// });
auto results = run(module_to_run, params);
return argument{results};
}
......
/*
* The MIT License (MIT)
*
* Copyright (c) 2015-2023 Advanced Micro Devices, Inc. All rights reserved.
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
* in the Software without restriction, including without limitation the rights
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
* copies of the Software, and to permit persons to whom the Software is
* furnished to do so, subject to the following conditions:
*
* The above copyright notice and this permission notice shall be included in
* all copies or substantial portions of the Software.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
* THE SOFTWARE.
*/
#ifndef MIGRAPHX_GUARD_RTGLIB_SELECT_MODULE_HPP
#define MIGRAPHX_GUARD_RTGLIB_SELECT_MODULE_HPP
#include <migraphx/argument.hpp>
#include <migraphx/module.hpp>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
namespace gpu {
struct context;
struct hip_select_module
{
std::string name() const { return "gpu::select_module"; }
shape compute_shape(std::vector<shape> inputs, std::vector<module_ref> mods) const;
argument
compute(context& ctx,
const shape& output_shape,
const std::vector<argument>& args,
const std::vector<module_ref>& mods,
const std::function<std::vector<argument>(
module_ref&, const std::unordered_map<std::string, argument>&)>& run) const;
std::ptrdiff_t output_alias(const std::vector<shape>& shapes) const
{
return shapes.size() - 1;
}
};
} // namespace gpu
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
#endif
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment