Commit 0b0a6d4f authored by charlie's avatar charlie
Browse files

First pass on the operator

only works if exact batch size submodule present
will need to make it assemble from other sizes later
parent 65590e8e
......@@ -181,13 +181,13 @@ struct context
void wait_for(any_ptr queue)
{
assert((*this).private_detail_te_handle_mem_var);
(*this).private_detail_te_get_handle().wait_for(queue);
(*this).private_detail_te_get_handle().wait_for(std::move(queue));
}
void finish_on(any_ptr queue)
{
assert((*this).private_detail_te_handle_mem_var);
(*this).private_detail_te_get_handle().finish_on(queue);
(*this).private_detail_te_get_handle().finish_on(std::move(queue));
}
void finish() const
......@@ -260,29 +260,29 @@ struct context
template <class T>
static auto private_detail_te_default_wait_for(char, T&& private_detail_te_self, any_ptr queue)
-> decltype(private_detail_te_self.wait_for(queue))
-> decltype(private_detail_te_self.wait_for(std::move(queue)))
{
private_detail_te_self.wait_for(queue);
private_detail_te_self.wait_for(std::move(queue));
}
template <class T>
static void private_detail_te_default_wait_for(float, T&& private_detail_te_self, any_ptr queue)
{
wait_for_context(private_detail_te_self, queue);
wait_for_context(private_detail_te_self, std::move(queue));
}
template <class T>
static auto private_detail_te_default_finish_on(char, T&& private_detail_te_self, any_ptr queue)
-> decltype(private_detail_te_self.finish_on(queue))
-> decltype(private_detail_te_self.finish_on(std::move(queue)))
{
private_detail_te_self.finish_on(queue);
private_detail_te_self.finish_on(std::move(queue));
}
template <class T>
static void
private_detail_te_default_finish_on(float, T&& private_detail_te_self, any_ptr queue)
{
finish_on_context(private_detail_te_self, queue);
finish_on_context(private_detail_te_self, std::move(queue));
}
template <typename PrivateDetailTypeErasedT>
......@@ -302,7 +302,7 @@ struct context
PrivateDetailTypeErasedT value,
typename std::enable_if<not std::is_reference<PrivateDetailTypeErasedU>::value,
int>::type* = nullptr) noexcept
: private_detail_te_value(value)
: private_detail_te_value(std::move(value))
{
}
......@@ -334,13 +334,13 @@ struct context
void wait_for(any_ptr queue) override
{
private_detail_te_default_wait_for(char(0), private_detail_te_value, queue);
private_detail_te_default_wait_for(char(0), private_detail_te_value, std::move(queue));
}
void finish_on(any_ptr queue) override
{
private_detail_te_default_finish_on(char(0), private_detail_te_value, queue);
private_detail_te_default_finish_on(char(0), private_detail_te_value, std::move(queue));
}
void finish() const override { private_detail_te_value.finish(); }
......
......@@ -52,8 +52,8 @@ struct select_module
std::string name() const { return "select_module"; }
// this should run once during model compilation with dynamic shape input
// run once on each model evaluation with static shape input
// runs once during model compilation with dynamic shape input
// may run on each model evaluation with static shape input
shape compute_shape(std::vector<shape> inputs) const
{
check_shapes{inputs, *this, true}.has(1);
......@@ -95,7 +95,7 @@ struct select_module
{
// find submodule with the same parameter shape as the input data
auto p_shape = mod->get_parameter_shape(dyn_batch_param_name);
if(p_shape == dyn_out.computed_shape)
if(p_shape == args.at(0).get_shape())
{
modules_to_run.push_back(mod);
break;
......@@ -105,17 +105,18 @@ struct select_module
if(modules_to_run.empty())
{
MIGRAPHX_THROW("SELECT_MODULE: no compatible submodules found");
MIGRAPHX_THROW("SELECT_MODULE: no compatible submodules found for input shape: " +
migraphx::to_string(args.at(0).get_shape()));
}
std::set<std::string> pnames;
for(const auto& mod : modules_to_run)
{
// If all the modules have the same parameters, this would only need to run once
// TODO If all the modules have the same parameters, this would only need to run once
auto names = mod->get_parameter_names();
pnames.insert(names.begin(), names.end());
}
assert(pnames.size() < args.size());
assert(pnames.size() <= args.size());
std::unordered_map<std::string, argument> params;
std::transform(pnames.begin(),
pnames.end(),
......@@ -125,7 +126,7 @@ struct select_module
// TODO run multiple modules and split the parameter data to each batch size
auto results = run(modules_to_run.at(0), params);
return argument{results};
return results.at(0);
}
};
......
......@@ -166,7 +166,7 @@ shape compute_shape_op(const T& x, const std::vector<shape>& inputs)
}
template <class T>
auto mod_compute_shape_op(rank<1>,
auto mod_compute_shape_op(rank<2>,
const T& x,
const std::vector<shape>& inputs,
const std::vector<module_ref>& mod_args)
......@@ -175,6 +175,15 @@ auto mod_compute_shape_op(rank<1>,
return x.compute_shape(inputs, mod_args);
}
template <class T>
auto mod_compute_shape_op(rank<1>,
const T& x,
const std::vector<shape>& inputs,
const std::vector<module_ref>&) -> decltype(x.compute_shape(inputs))
{
return x.compute_shape(inputs);
}
template <class T>
shape mod_compute_shape_op(rank<0>,
const T& x,
......
......@@ -2303,9 +2303,10 @@ TEST_CASE(select_module_dyn)
TEST_CASE(select_module_static)
{
migraphx::shape input{migraphx::shape::float_type, {3, 3, 255, 255}};
migraphx::shape out_attr = migraphx::shape{migraphx::shape::float_type, {{1, 4}, {1000, 1000}}};
expect_shape(migraphx::shape{migraphx::shape::float_type, {3, 1000}},
migraphx::make_op("select_module",
{{"output_dyn_shape", {{1, 4}, {1000, 1000}}},
{{"output_dyn_shape", migraphx::to_value(out_attr)},
{"output_batch_index", 0},
{"input_batch_index", 0}}),
input);
......
......@@ -6989,6 +6989,49 @@ TEST_CASE(scatternd_reduction_test)
}
}
TEST_CASE(select_module_test)
{
migraphx::program p;
// create batch submodules
auto create_submodule = [&](std::size_t batch_size, std::string module_name) {
auto* submod = p.create_module(module_name);
migraphx::shape sm_shape{migraphx::shape::float_type, {batch_size, 2, 2}};
auto sm_input = submod->add_parameter("data", sm_shape);
auto reduce_ins =
submod->add_instruction(migraphx::make_op("reduce_sum", {{"axes", {1}}}), sm_input);
auto squeeze_ins = submod->add_instruction(migraphx::make_op("squeeze"), reduce_ins);
submod->add_return({squeeze_ins});
return submod;
};
auto* batch1 = create_submodule(1, "batch_1");
auto* batch2 = create_submodule(2, "batch_2");
auto* batch4 = create_submodule(4, "batch_4");
auto* mm = p.get_main_module();
migraphx::shape s{migraphx::shape::float_type, {{1, 4}, {2, 2}, {2, 2}}};
auto input = mm->add_parameter("data", s);
migraphx::shape out_attr = migraphx::shape{migraphx::shape::float_type, {{1, 4}, {2, 2}}};
mm->add_instruction(migraphx::make_op("select_module",
{{"output_dyn_shape", migraphx::to_value(out_attr)},
{"output_batch_index", 0},
{"input_batch_index", 0},
{"dyn_batch_param_name", "data"}}),
{input},
{batch1, batch2, batch4});
p.compile(migraphx::ref::target{});
std::vector<float> input_data{-4, 8, -1, 4, -1, 8, 8, -4};
migraphx::parameter_map params;
migraphx::shape input_fixed_shape{migraphx::shape::float_type, {2, 2, 2}};
params["data"] = migraphx::argument(input_fixed_shape, input_data.data());
auto result = p.eval(params).back();
std::vector<float> results_vector;
result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); });
std::vector<float> gold{-5, 12, 7, 4};
EXPECT(migraphx::verify_range(results_vector, gold));
}
TEST_CASE(sigmoid_test)
{
migraphx::program p;
......
......@@ -166,7 +166,7 @@ shape compute_shape_op(const T& x, const std::vector<shape>& inputs)
}
template <class T>
auto mod_compute_shape_op(rank<1>,
auto mod_compute_shape_op(rank<2>,
const T& x,
const std::vector<shape>& inputs,
const std::vector<module_ref>& mod_args)
......@@ -175,6 +175,15 @@ auto mod_compute_shape_op(rank<1>,
return x.compute_shape(inputs, mod_args);
}
template <class T>
auto mod_compute_shape_op(rank<1>,
const T& x,
const std::vector<shape>& inputs,
const std::vector<module_ref>&) -> decltype(x.compute_shape(inputs))
{
return x.compute_shape(inputs);
}
template <class T>
shape mod_compute_shape_op(rank<0>,
const T& x,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment