Commit b8ebf8ad authored by charlie

progress on gpu version

parent b50aa56c
@@ -33,8 +33,6 @@ namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
namespace op {
// Make this work just for exact matches;
// can get rid of the other attributes and just check that all the parameters are the same.
// The GPU version of this might have to deal with output parameters;
// see the loop op for how output parameters are handled there.
// Can have multiple inputs but only one output?
@@ -45,12 +43,46 @@ struct select_module
template <class Self, class F>
static auto reflect(Self& self, F f)
{
- return pack(f(self.output_dyn_shape, "output_dyn_shape"));
+ return pack(f(self.output_dyn_shapes, "output_dyn_shapes"));
}
std::string name() const { return "select_module"; }
shape compute_shape(std::vector<shape>) const { return shape{output_dyn_shapes}; }
shape compute_shape(const std::vector<shape>& inputs, std::vector<module_ref> mods) const
{
// if(std::none_of(inputs.cbegin(), inputs.cend(), [](auto input){ return input.dynamic();
// }))
//{
// if(mods.size() != 1)
// {
// MIGRAPHX_THROW("SELECT_MODULE: operator should have one submodule during eval.");
// }
// return {mods.front()->get_output_shapes()};
//}
return shape{output_dyn_shapes};
}
std::vector<std::string> get_input_parameter_names(module_ref mod) const
{
auto param_names = mod->get_parameter_names();
std::vector<std::string> ret;
std::copy_if(param_names.cbegin(),
param_names.cend(),
std::back_inserter(ret),
[](auto pn) { return not contains(pn, "#output_"); });
return ret;
}
std::vector<std::string> get_output_parameter_names(module_ref mod) const
{
auto param_names = mod->get_parameter_names();
std::vector<std::string> ret;
std::copy_if(param_names.cbegin(),
param_names.cend(),
std::back_inserter(ret),
[](auto pn) { return contains(pn, "#output_"); });
return ret;
}
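As a standalone illustration of the "#output_" naming convention these two helpers filter on, a minimal sketch (the parameter names are hypothetical, and std::string::find stands in for the contains() utility used above):

#include <algorithm>
#include <iostream>
#include <string>
#include <vector>

int main()
{
    // hypothetical parameter names as a submodule might report them
    std::vector<std::string> param_names = {"data", "scale", "main:#output_0"};
    std::vector<std::string> inputs;
    std::vector<std::string> outputs;
    // names containing "#output_" are output parameters; everything else is an input
    std::partition_copy(
        param_names.cbegin(),
        param_names.cend(),
        std::back_inserter(outputs),
        std::back_inserter(inputs),
        [](const std::string& pn) { return pn.find("#output_") != std::string::npos; });
    std::cout << inputs.size() << " inputs, " << outputs.size() << " outputs\n"; // 2 inputs, 1 outputs
}

(A single std::partition_copy produces both lists in one pass; the two copy_if helpers above keep the call sites simpler.)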
argument compute(const shape&,
const std::vector<argument>& args,
@@ -58,15 +90,17 @@ struct select_module
const std::function<std::vector<argument>(
module_ref&, const std::unordered_map<std::string, argument>&)>& run) const
{
- // find submodule with parameter shapes exactly the same as the input arguments
- // assuming arguments are in the same order as the parameters
+ // find the submodule whose input parameter shapes exactly match the input arguments,
+ // assuming the arguments are in the same order as the input parameters
auto module_iter =
std::find_if(submodule_list.cbegin(), submodule_list.cend(), [&](module_ref mr) {
- auto param_names = mr->get_parameter_names();
- std::equal(
-     args.cbegin(), args.cend(), param_names.cbegin(), [&](auto a, auto p_name) {
-         return a.get_shape() == mr->get_parameter_shape(p_name);
-     });
+ auto input_param_names = get_input_parameter_names(mr);
+ return std::equal(args.cbegin(),
+                   args.cend(),
+                   input_param_names.cbegin(),
+                   [&](auto a, auto p_name) {
+                       return a.get_shape() == mr->get_parameter_shape(p_name);
+                   });
});
if(module_iter == submodule_list.end())
@@ -74,14 +108,27 @@ struct select_module
MIGRAPHX_THROW("SELECT_MODULE: no compatible submodules found for given input shapes");
}
auto module_to_run = *module_iter;
- auto param_names = module_to_run->get_parameter_names();
- assert(param_names.size() <= args.size());
std::unordered_map<std::string, argument> params;
- std::transform(param_names.begin(),
-                param_names.end(),
-                args.begin(),
-                std::inserter(params, params.end()),
-                [](auto&& name, auto&& arg) { return std::make_pair(name, arg); });
+ // add input parameters
+ auto input_param_names = get_input_parameter_names(module_to_run);
+ assert(input_param_names.size() <= args.size());
+ std::transform(input_param_names.cbegin(),
+                input_param_names.cend(),
+                args.cbegin(),
+                std::inserter(params, params.end()),
+                [](auto&& name, auto&& a) { return std::make_pair(name, a); });
+ // add output parameters (empty if on the ref target)
+ // assuming the output parameters are in the same order as the input parameters
+ // need to set up the buffers for the output parameters
+ auto output_param_names = get_output_parameter_names(module_to_run);
+ assert(output_param_names.size() <= args.size());
+ std::transform(output_param_names.cbegin(),
+                output_param_names.cend(),
+                args.cbegin(),
+                std::inserter(params, params.end()),
+                [](auto&& name, auto&& a) { return std::make_pair(name, a); });
auto results = run(module_to_run, params);
return argument{results};
......
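The selection in compute above is an exact, order-sensitive match of argument shapes against each submodule's input parameter shapes. A self-contained sketch of the same idea, with plain vectors of lens standing in for modules and shapes (fake_module and the literal lens values are illustrative, not MIGraphX types):

#include <algorithm>
#include <cassert>
#include <cstddef>
#include <vector>

using lens = std::vector<std::size_t>;

// stand-in for a submodule: the lens of its input parameters, in parameter order
struct fake_module
{
    std::vector<lens> param_lens;
};

int main()
{
    fake_module batch1{{{1, 2, 2}}};
    fake_module batch4{{{4, 2, 2}}};
    std::vector<fake_module> submodules{batch1, batch4};
    std::vector<lens> arg_lens{{4, 2, 2}}; // shapes of the incoming arguments

    // pick the first submodule whose parameter lens equal the argument lens, position by position
    auto it = std::find_if(submodules.cbegin(), submodules.cend(), [&](const fake_module& m) {
        return m.param_lens.size() == arg_lens.size() and
               std::equal(arg_lens.cbegin(), arg_lens.cend(), m.param_lens.cbegin());
    });
    assert(it == submodules.cbegin() + 1); // the batch-4 module matches
}

Unlike this sketch, the operator does not compare lengths first; as the comment in compute notes, it relies on the arguments arriving in the same order as the input parameters.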
@@ -243,6 +243,9 @@ struct shape
/// Return true if the shape is dynamic
bool dynamic() const;
/// Return true if this shape or any of the sub_shapes is dynamic
bool any_of_dynamic() const;
shape normalize_standard() const;
shape with_lens(type_t t, const std::vector<std::size_t>& l) const;
......
@@ -380,7 +380,7 @@ std::vector<argument> generic_eval(const module* mod,
}));
}
assert(results.find(ins) != results.end());
- if(not ins->get_shape().dynamic())
+ if(not ins->get_shape().any_of_dynamic())
{
assert(results.at(ins).get_shape() == ins->get_shape());
}
......
@@ -483,6 +483,17 @@ std::string shape::type_string() const { return name(this->type()); }
bool shape::dynamic() const { return not impl->m_dyn_dims.empty(); }
bool shape::any_of_dynamic() const
{
if(this->dynamic())
{
return true;
}
return std::any_of(this->sub_shapes().cbegin(), this->sub_shapes().cend(), [](auto s) {
return s.any_of_dynamic();
});
}
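A small usage sketch of the distinction between the two predicates, building shapes the same way as the tests below (assuming it compiles inside the MIGraphX tree):

#include <migraphx/shape.hpp>
#include <cassert>
#include <vector>

void any_of_dynamic_example()
{
    // a dynamic shape: the first dimension ranges over {1, 4}
    migraphx::shape dyn{migraphx::shape::float_type, {{1, 4}, {2, 2}}};
    // a tuple shape whose only sub-shape is dynamic
    migraphx::shape tup{std::vector<migraphx::shape>{dyn}};
    assert(dyn.dynamic() and dyn.any_of_dynamic());
    // the tuple itself has no dyn_dims, so dynamic() is false, but any_of_dynamic() recurses
    assert(not tup.dynamic() and tup.any_of_dynamic());
}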
const std::vector<shape::dynamic_dimension>& shape::dyn_dims() const { return impl->m_dyn_dims; }
std::vector<std::size_t> shape::min_lens() const
......
@@ -76,6 +76,11 @@ struct gpu_loop
}
}
/**
 * Finds the output parameters for a module and returns a map from the parameter name to
 * the output argument index. Needs the module names mapped to the correct input
 * parameters to begin with; not sure where those indices are set.
 */
std::unordered_map<std::string, int> get_output_params(const module& m) const
{
auto get_output_index = [](const std::string& name) {
......
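The body of get_output_params is elided here, but given the "#output_" convention above, the index extraction presumably amounts to parsing the integer after the marker. A hypothetical standalone sketch, not the committed implementation:

#include <string>

// hypothetical: maps "x:#output_2" -> 2, and -1 when the marker is absent
int parse_output_index(const std::string& name)
{
    const std::string marker = "#output_";
    auto pos = name.find(marker);
    if(pos == std::string::npos)
        return -1;
    return std::stoi(name.substr(pos + marker.size()));
}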
@@ -2297,25 +2297,11 @@ TEST_CASE(select_module_dyn)
sub_shapes.push_back(migraphx::shape{migraphx::shape::float_type, {{1, 4}, {1000, 1000}}});
migraphx::shape out_attr = migraphx::shape{sub_shapes};
expect_shape(
migraphx::shape{migraphx::shape::float_type, {{1, 4}, {1000, 1000}}},
out_attr,
migraphx::make_op("select_module", {{"output_dyn_shapes", migraphx::to_value(out_attr)}}),
input);
}
- TEST_CASE(select_module_static)
- {
- migraphx::shape input{migraphx::shape::float_type, {3, 3, 255, 255}};
- std::vector<migraphx::shape> sub_shapes = {};
- sub_shapes.push_back(migraphx::shape{migraphx::shape::float_type, {{1, 4}, {1000, 1000}}});
- migraphx::shape out_attr = migraphx::shape{sub_shapes};
- expect_shape(migraphx::shape{migraphx::shape::float_type, {3, 1000}},
- migraphx::make_op("select_module",
- {{"output_dyn_shape", migraphx::to_value(out_attr)},
- {"output_batch_index", 0},
- {"input_batch_index", 0}}),
- input);
- }
TEST_CASE(slice_shape)
{
migraphx::shape input{migraphx::shape::int32_type, {2, 2, 3}};
......
@@ -6989,7 +6989,7 @@ TEST_CASE(scatternd_reduction_test)
}
}
- TEST_CASE(select_module_test)
+ TEST_CASE(select_module_test0)
{
migraphx::program p;
@@ -7000,25 +7000,28 @@ TEST_CASE(select_module_test)
auto sm_input = submod->add_parameter("data", sm_shape);
auto reduce_ins =
submod->add_instruction(migraphx::make_op("reduce_sum", {{"axes", {1}}}), sm_input);
- auto squeeze_ins = submod->add_instruction(migraphx::make_op("squeeze"), reduce_ins);
+ auto squeeze_ins =
+     submod->add_instruction(migraphx::make_op("squeeze", {{"axes", {1}}}), reduce_ins);
submod->add_return({squeeze_ins});
return submod;
};
auto* batch1 = create_submodule(1, "batch_1");
auto* batch2 = create_submodule(2, "batch_2");
auto* batch3 = create_submodule(3, "batch_3");
auto* batch4 = create_submodule(4, "batch_4");
auto* mm = p.get_main_module();
migraphx::shape s{migraphx::shape::float_type, {{1, 4}, {2, 2}, {2, 2}}};
- auto input = mm->add_parameter("data", s);
- migraphx::shape out_attr = migraphx::shape{migraphx::shape::float_type, {{1, 4}, {2, 2}}};
- mm->add_instruction(migraphx::make_op("select_module",
-                                       {{"output_dyn_shape", migraphx::to_value(out_attr)},
-                                        {"output_batch_index", 0},
-                                        {"input_batch_index", 0},
-                                        {"dyn_batch_param_name", "data"}}),
-                     {input},
-                     {batch1, batch2, batch4});
+ auto input = mm->add_parameter("data", s);
+ std::vector<migraphx::shape> sub_shapes = {};
+ sub_shapes.push_back(migraphx::shape{migraphx::shape::float_type, {{1, 4}, {2, 2}}});
+ migraphx::shape out_attr = migraphx::shape{sub_shapes};
+ auto sm_ins = mm->add_instruction(
+     migraphx::make_op("select_module", {{"output_dyn_shapes", migraphx::to_value(out_attr)}}),
+     {input},
+     {batch1, batch2, batch3, batch4});
+ auto ret = mm->add_instruction(migraphx::make_op("get_tuple_elem", {{"index", 0}}), sm_ins);
+ mm->add_return({ret});
p.compile(migraphx::ref::target{});
std::vector<float> input_data{-4, 8, -1, 4, -1, 8, 8, -4};
@@ -7032,6 +7035,52 @@
EXPECT(migraphx::verify_range(results_vector, gold));
}
TEST_CASE(select_module_test1)
{
migraphx::program p;
// create batch submodules
auto create_submodule = [&](std::size_t batch_size, std::string module_name) {
auto* submod = p.create_module(module_name);
migraphx::shape sm_shape{migraphx::shape::float_type, {batch_size, 2, 2}};
auto sm_input = submod->add_parameter("data", sm_shape);
auto reduce_ins =
submod->add_instruction(migraphx::make_op("reduce_sum", {{"axes", {1}}}), sm_input);
auto squeeze_ins =
submod->add_instruction(migraphx::make_op("squeeze", {{"axes", {1}}}), reduce_ins);
submod->add_return({squeeze_ins});
return submod;
};
auto* batch1 = create_submodule(1, "batch_1");
auto* batch2 = create_submodule(2, "batch_2");
auto* batch3 = create_submodule(3, "batch_3");
auto* batch4 = create_submodule(4, "batch_4");
auto* mm = p.get_main_module();
migraphx::shape s{migraphx::shape::float_type, {{1, 4}, {2, 2}, {2, 2}}};
auto input = mm->add_parameter("data", s);
std::vector<migraphx::shape> sub_shapes = {};
sub_shapes.push_back(migraphx::shape{migraphx::shape::float_type, {{1, 4}, {2, 2}}});
migraphx::shape out_attr = migraphx::shape{sub_shapes};
auto sm_ins = mm->add_instruction(
migraphx::make_op("select_module", {{"output_dyn_shapes", migraphx::to_value(out_attr)}}),
{input},
{batch1, batch2, batch3, batch4});
auto ret = mm->add_instruction(migraphx::make_op("get_tuple_elem", {{"index", 0}}), sm_ins);
mm->add_return({ret});
p.compile(migraphx::ref::target{});
std::vector<float> input_data{-4, 8, -1, 4, -1, 8, 8, -4, -4, 8, -1, 4, -1, 8, 8, -4};
migraphx::parameter_map params;
migraphx::shape input_fixed_shape{migraphx::shape::float_type, {4, 2, 2}};
params["data"] = migraphx::argument(input_fixed_shape, input_data.data());
auto result = p.eval(params).back();
std::vector<float> results_vector;
result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); });
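// gold: reduce_sum over axis 1 column-sums each 2x2 batch slice,
// e.g. [[-4, 8], [-1, 4]] -> [-5, 12] and [[-1, 8], [8, -4]] -> [7, 4]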
std::vector<float> gold{-5, 12, 7, 4, -5, 12, 7, 4};
EXPECT(migraphx::verify_range(results_vector, gold));
}
TEST_CASE(sigmoid_test)
{
migraphx::program p;
......
@@ -185,7 +185,16 @@ void run_verify::verify(const std::string& name, const migraphx::program& p) const
migraphx::parameter_map m;
for(auto&& x : p.get_parameter_shapes())
{
- m[x.first] = migraphx::generate_argument(x.second, get_hash(x.first));
+ if(x.second.dynamic())
+ {
+     // create a static shape using the maximum dimensions,
+     // e.g. a dynamic {{1, 4}, {2, 2}, {2, 2}} parameter becomes a fixed {4, 2, 2}
+     migraphx::shape static_shape{x.second.type(), x.second.max_lens()};
+     m[x.first] = migraphx::generate_argument(static_shape, get_hash(x.first));
+ }
+ else
+ {
+     m[x.first] = migraphx::generate_argument(x.second, get_hash(x.first));
+ }
}
auto gold_f = detach_async([=] { return run_ref(p, m); });
......
/*
* The MIT License (MIT)
*
* Copyright (c) 2015-2023 Advanced Micro Devices, Inc. All rights reserved.
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
* in the Software without restriction, including without limitation the rights
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
* copies of the Software, and to permit persons to whom the Software is
* furnished to do so, subject to the following conditions:
*
* The above copyright notice and this permission notice shall be included in
* all copies or substantial portions of the Software.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
* THE SOFTWARE.
*/
#include "verify_program.hpp"
#include <migraphx/program.hpp>
#include <migraphx/generate.hpp>
#include <migraphx/make_op.hpp>
struct test_select_module : verify_program<test_select_module>
{
migraphx::program create_program() const
{
migraphx::program p;
// create batch submodules
auto create_submodule = [&](std::size_t batch_size, std::string module_name) {
auto* submod = p.create_module(module_name);
migraphx::shape sm_shape{migraphx::shape::float_type, {batch_size, 2, 2}};
auto sm_input = submod->add_parameter("data", sm_shape);
auto reduce_ins =
submod->add_instruction(migraphx::make_op("reduce_sum", {{"axes", {1}}}), sm_input);
auto squeeze_ins =
submod->add_instruction(migraphx::make_op("squeeze", {{"axes", {1}}}), reduce_ins);
submod->add_return({squeeze_ins});
return submod;
};
auto* batch1 = create_submodule(1, "batch_1");
auto* batch2 = create_submodule(2, "batch_2");
auto* batch3 = create_submodule(3, "batch_3");
auto* batch4 = create_submodule(4, "batch_4");
auto* mm = p.get_main_module();
migraphx::shape s{migraphx::shape::float_type, {{1, 4}, {2, 2}, {2, 2}}};
auto input = mm->add_parameter("data", s);
std::vector<migraphx::shape> sub_shapes = {};
sub_shapes.push_back(migraphx::shape{migraphx::shape::float_type, {{1, 4}, {2, 2}}});
migraphx::shape out_attr = migraphx::shape{sub_shapes};
auto sm_ins = mm->add_instruction(
migraphx::make_op("select_module",
{{"output_dyn_shapes", migraphx::to_value(out_attr)}}),
{input},
{batch1, batch2, batch3, batch4});
auto ret = mm->add_instruction(migraphx::make_op("get_tuple_elem", {{"index", 0}}), sm_ins);
mm->add_return({ret});
return p;
}
};
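Note: with the run_verify change above, the dynamic "data" parameter of this program is generated at its maximum dimensions, so the ref and GPU targets are compared on a fixed {4, 2, 2} input.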