Commit 0b0a6d4f authored by charlie

First pass on the operator

Only works if a submodule with the exact batch size is present; will need to make it assemble from other batch sizes later.
parent 65590e8e
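In broad strokes, the behavior the commit message describes: walk the candidate submodules and pick the one whose parameter shape matches the incoming batch exactly, throwing if none matches. A minimal standalone sketch of that selection rule (the `Shape`/`Module` types below are illustrative stand-ins, not the MIGraphX API):

```cpp
#include <cstddef>
#include <stdexcept>
#include <string>
#include <vector>

// Illustrative stand-ins for the real MIGraphX types.
struct Shape
{
    std::vector<std::size_t> lens;
    bool operator==(const Shape& o) const { return lens == o.lens; }
};
struct Module
{
    std::string name;
    Shape param_shape;
};

// Pick the submodule whose parameter shape matches the input exactly;
// assembling a result from other batch sizes is left for later.
const Module& select_exact(const std::vector<Module>& mods, const Shape& input)
{
    for(const auto& m : mods)
    {
        if(m.param_shape == input)
            return m;
    }
    throw std::runtime_error("no compatible submodule for input shape");
}
```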
@@ -181,13 +181,13 @@ struct context
     void wait_for(any_ptr queue)
     {
         assert((*this).private_detail_te_handle_mem_var);
-        (*this).private_detail_te_get_handle().wait_for(queue);
+        (*this).private_detail_te_get_handle().wait_for(std::move(queue));
     }
     void finish_on(any_ptr queue)
     {
         assert((*this).private_detail_te_handle_mem_var);
-        (*this).private_detail_te_get_handle().finish_on(queue);
+        (*this).private_detail_te_get_handle().finish_on(std::move(queue));
     }
     void finish() const
@@ -260,29 +260,29 @@ struct context
     template <class T>
     static auto private_detail_te_default_wait_for(char, T&& private_detail_te_self, any_ptr queue)
-        -> decltype(private_detail_te_self.wait_for(queue))
+        -> decltype(private_detail_te_self.wait_for(std::move(queue)))
     {
-        private_detail_te_self.wait_for(queue);
+        private_detail_te_self.wait_for(std::move(queue));
     }
     template <class T>
     static void private_detail_te_default_wait_for(float, T&& private_detail_te_self, any_ptr queue)
     {
-        wait_for_context(private_detail_te_self, queue);
+        wait_for_context(private_detail_te_self, std::move(queue));
     }
     template <class T>
     static auto private_detail_te_default_finish_on(char, T&& private_detail_te_self, any_ptr queue)
-        -> decltype(private_detail_te_self.finish_on(queue))
+        -> decltype(private_detail_te_self.finish_on(std::move(queue)))
     {
-        private_detail_te_self.finish_on(queue);
+        private_detail_te_self.finish_on(std::move(queue));
     }
     template <class T>
     static void
     private_detail_te_default_finish_on(float, T&& private_detail_te_self, any_ptr queue)
     {
-        finish_on_context(private_detail_te_self, queue);
+        finish_on_context(private_detail_te_self, std::move(queue));
     }
     template <typename PrivateDetailTypeErasedT>
@@ -302,7 +302,7 @@ struct context
         PrivateDetailTypeErasedT value,
         typename std::enable_if<not std::is_reference<PrivateDetailTypeErasedU>::value,
                                 int>::type* = nullptr) noexcept
-        : private_detail_te_value(value)
+        : private_detail_te_value(std::move(value))
     {
     }
@@ -334,13 +334,13 @@ struct context
     void wait_for(any_ptr queue) override
     {
-        private_detail_te_default_wait_for(char(0), private_detail_te_value, queue);
+        private_detail_te_default_wait_for(char(0), private_detail_te_value, std::move(queue));
     }
     void finish_on(any_ptr queue) override
     {
-        private_detail_te_default_finish_on(char(0), private_detail_te_value, queue);
+        private_detail_te_default_finish_on(char(0), private_detail_te_value, std::move(queue));
    }
     void finish() const override { private_detail_te_value.finish(); }
...
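The change running through all four hunks above is the same sink-parameter idiom: `any_ptr queue` is taken by value and then forwarded with `std::move` at each layer instead of being copied again. A minimal sketch of why that matters for a forwarding wrapper (the `inner`/`wrapper` types here are illustrative, not MIGraphX API; `std::unique_ptr` stands in for a move-only payload):

```cpp
#include <memory>
#include <utility>

struct inner
{
    // Consumes the queue handle.
    void wait_for(std::unique_ptr<int> q) { (void)q; }
};

// Type-erasure-style wrapper: the by-value parameter is moved along,
// so each forwarding layer costs a move rather than a copy. With a
// move-only type like unique_ptr, std::move is what makes this compile.
struct wrapper
{
    inner handle;
    void wait_for(std::unique_ptr<int> q) { handle.wait_for(std::move(q)); }
};
```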
@@ -52,8 +52,8 @@ struct select_module
     std::string name() const { return "select_module"; }
-    // this should run once during model compilation with dynamic shape input
-    // run once on each model evaluation with static shape input
+    // runs once during model compilation with dynamic shape input
+    // may run on each model evaluation with static shape input
     shape compute_shape(std::vector<shape> inputs) const
     {
         check_shapes{inputs, *this, true}.has(1);
@@ -95,7 +95,7 @@ struct select_module
         {
             // find submodule with the same parameter shape as the input data
             auto p_shape = mod->get_parameter_shape(dyn_batch_param_name);
-            if(p_shape == dyn_out.computed_shape)
+            if(p_shape == args.at(0).get_shape())
             {
                 modules_to_run.push_back(mod);
                 break;
@@ -105,17 +105,18 @@ struct select_module
         if(modules_to_run.empty())
         {
-            MIGRAPHX_THROW("SELECT_MODULE: no compatible submodules found");
+            MIGRAPHX_THROW("SELECT_MODULE: no compatible submodules found for input shape: " +
+                           migraphx::to_string(args.at(0).get_shape()));
         }
         std::set<std::string> pnames;
         for(const auto& mod : modules_to_run)
         {
-            // If all the modules have the same parameters, this would only need to run once
+            // TODO If all the modules have the same parameters, this would only need to run once
             auto names = mod->get_parameter_names();
             pnames.insert(names.begin(), names.end());
         }
-        assert(pnames.size() < args.size());
+        assert(pnames.size() <= args.size());
         std::unordered_map<std::string, argument> params;
         std::transform(pnames.begin(),
                        pnames.end(),
@@ -125,7 +126,7 @@ struct select_module
         // TODO run multiple modules and split the parameter data to each batch size
         auto results = run(modules_to_run.at(0), params);
-        return argument{results};
+        return results.at(0);
     }
 };
...
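The parameter gathering above collects every parameter name used by the selected submodules into a set and pairs each name with an input argument. A rough standalone sketch of that step, assuming (as the `pnames.size() <= args.size()` assert suggests) that names and arguments correspond by position; `int` stands in for `migraphx::argument`:

```cpp
#include <cstddef>
#include <set>
#include <string>
#include <unordered_map>
#include <vector>

// Illustrative only: pair each collected parameter name with an input
// argument by position. The real operator maps names onto the arguments
// supplied to select_module.
std::unordered_map<std::string, int>
gather_params(const std::set<std::string>& pnames, const std::vector<int>& args)
{
    std::unordered_map<std::string, int> params;
    std::size_t i = 0;
    for(const auto& name : pnames)
        params.emplace(name, args.at(i++));
    return params;
}
```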
@@ -166,7 +166,7 @@ shape compute_shape_op(const T& x, const std::vector<shape>& inputs)
 }
 template <class T>
-auto mod_compute_shape_op(rank<1>,
+auto mod_compute_shape_op(rank<2>,
                           const T& x,
                           const std::vector<shape>& inputs,
                           const std::vector<module_ref>& mod_args)
@@ -175,6 +175,15 @@ auto mod_compute_shape_op(rank<1>,
     return x.compute_shape(inputs, mod_args);
 }
+template <class T>
+auto mod_compute_shape_op(rank<1>,
+                          const T& x,
+                          const std::vector<shape>& inputs,
+                          const std::vector<module_ref>&) -> decltype(x.compute_shape(inputs))
+{
+    return x.compute_shape(inputs);
+}
 template <class T>
 shape mod_compute_shape_op(rank<0>,
                            const T& x,
...
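The `rank<N>` first parameter in these overloads is the usual tag-dispatch trick: `rank<2>` derives from `rank<1>`, which derives from `rank<0>`, so overload resolution prefers the highest-ranked overload whose trailing `decltype` is well-formed and falls through to the next one otherwise. A compact, self-contained illustration of the mechanism (the `dispatch`/`compute` names are invented for the example):

```cpp
#include <iostream>

template <int N>
struct rank : rank<N - 1>
{
};
template <>
struct rank<0>
{
};

struct two_arg { int compute(int a, int b) const { return a + b; } };
struct one_arg { int compute(int a) const { return a * 2; } };

// Preferred: chosen when x.compute(a, b) is well-formed.
template <class T>
auto dispatch(rank<2>, const T& x, int a, int b) -> decltype(x.compute(a, b))
{
    return x.compute(a, b);
}

// Fallback: chosen when only x.compute(a) exists.
template <class T>
auto dispatch(rank<1>, const T& x, int a, int) -> decltype(x.compute(a))
{
    return x.compute(a);
}

int main()
{
    std::cout << dispatch(rank<2>{}, two_arg{}, 3, 4) << '\n'; // 7: rank<2> overload
    std::cout << dispatch(rank<2>{}, one_arg{}, 3, 4) << '\n'; // 6: falls back to rank<1>
}
```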
@@ -2303,9 +2303,10 @@ TEST_CASE(select_module_dyn)
 TEST_CASE(select_module_static)
 {
     migraphx::shape input{migraphx::shape::float_type, {3, 3, 255, 255}};
+    migraphx::shape out_attr = migraphx::shape{migraphx::shape::float_type, {{1, 4}, {1000, 1000}}};
     expect_shape(migraphx::shape{migraphx::shape::float_type, {3, 1000}},
                  migraphx::make_op("select_module",
-                                   {{"output_dyn_shape", {{1, 4}, {1000, 1000}}},
+                                   {{"output_dyn_shape", migraphx::to_value(out_attr)},
                                     {"output_batch_index", 0},
                                     {"input_batch_index", 0}}),
                  input);
...
@@ -6989,6 +6989,49 @@ TEST_CASE(scatternd_reduction_test)
     }
 }
+TEST_CASE(select_module_test)
+{
+    migraphx::program p;
+    // create batch submodules
+    auto create_submodule = [&](std::size_t batch_size, std::string module_name) {
+        auto* submod = p.create_module(module_name);
+        migraphx::shape sm_shape{migraphx::shape::float_type, {batch_size, 2, 2}};
+        auto sm_input = submod->add_parameter("data", sm_shape);
+        auto reduce_ins =
+            submod->add_instruction(migraphx::make_op("reduce_sum", {{"axes", {1}}}), sm_input);
+        auto squeeze_ins = submod->add_instruction(migraphx::make_op("squeeze"), reduce_ins);
+        submod->add_return({squeeze_ins});
+        return submod;
+    };
+    auto* batch1 = create_submodule(1, "batch_1");
+    auto* batch2 = create_submodule(2, "batch_2");
+    auto* batch4 = create_submodule(4, "batch_4");
+
+    auto* mm = p.get_main_module();
+    migraphx::shape s{migraphx::shape::float_type, {{1, 4}, {2, 2}, {2, 2}}};
+    auto input = mm->add_parameter("data", s);
+    migraphx::shape out_attr = migraphx::shape{migraphx::shape::float_type, {{1, 4}, {2, 2}}};
+    mm->add_instruction(migraphx::make_op("select_module",
+                                          {{"output_dyn_shape", migraphx::to_value(out_attr)},
+                                           {"output_batch_index", 0},
+                                           {"input_batch_index", 0},
+                                           {"dyn_batch_param_name", "data"}}),
+                        {input},
+                        {batch1, batch2, batch4});
+    p.compile(migraphx::ref::target{});
+
+    std::vector<float> input_data{-4, 8, -1, 4, -1, 8, 8, -4};
+    migraphx::parameter_map params;
+    migraphx::shape input_fixed_shape{migraphx::shape::float_type, {2, 2, 2}};
+    params["data"] = migraphx::argument(input_fixed_shape, input_data.data());
+    auto result = p.eval(params).back();
+    std::vector<float> results_vector;
+    result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); });
+    std::vector<float> gold{-5, 12, 7, 4};
+    EXPECT(migraphx::verify_range(results_vector, gold));
+}
+
 TEST_CASE(sigmoid_test)
 {
     migraphx::program p;
...