Unverified Commit e9e3eacc authored by Charlie Lin's avatar Charlie Lin Committed by GitHub
Browse files

Split single dynamic dimension compiler pass (#1580)

Adds a new GPU compiler pass split_single_dyn_dim that handles when one input parameter has a single non-fixed dynamic_dimension.
commonly occurs for dynamic batch or BERT sequence length
Splits the dynamic shape into several submodules will static input parameters to handle all of the cases in the dynamic_dimension range.
Essentially does what I manually did for the select_module verify tests
Adds a compile option split_single_dyn_dim that toggles the pass on/off. Defaults to false.
Updates verify_program.hpp and run_verify.cpp to allow for the tests to change the compile_options
parent 32b9fd08
...@@ -91,6 +91,7 @@ add_library(migraphx ...@@ -91,6 +91,7 @@ add_library(migraphx
shape.cpp shape.cpp
simplify_algebra.cpp simplify_algebra.cpp
simplify_reshapes.cpp simplify_reshapes.cpp
split_single_dyn_dim.cpp
tmp_dir.cpp tmp_dir.cpp
value.cpp value.cpp
verify_args.cpp verify_args.cpp
......
...@@ -32,9 +32,17 @@ inline namespace MIGRAPHX_INLINE_NS { ...@@ -32,9 +32,17 @@ inline namespace MIGRAPHX_INLINE_NS {
struct compile_options struct compile_options
{ {
/**
* Have MIGX allocate memory for parameters and add instructions
* to copy parameters and output to/from an offload device like a GPU.
*/
bool offload_copy = false; bool offload_copy = false;
bool fast_math = true; bool fast_math = true;
bool exhaustive_tune = false; bool exhaustive_tune = false;
/// Use the split_single_dyn_dim pass
bool split_single_dyn_dim = false;
tracer trace{}; tracer trace{};
}; };
......
/*
* The MIT License (MIT)
*
* Copyright (c) 2015-2023 Advanced Micro Devices, Inc. All rights reserved.
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
* in the Software without restriction, including without limitation the rights
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
* copies of the Software, and to permit persons to whom the Software is
* furnished to do so, subject to the following conditions:
*
* The above copyright notice and this permission notice shall be included in
* all copies or substantial portions of the Software.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
* THE SOFTWARE.
*/
#ifndef MIGRAPHX_GUARD_RTGLIB_SPLIT_SINGLE_DYN_DIM_HPP
#define MIGRAPHX_GUARD_RTGLIB_SPLIT_SINGLE_DYN_DIM_HPP
#include <string>
#include <migraphx/program.hpp>
#include <migraphx/instruction_ref.hpp>
#include <migraphx/config.hpp>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
/**
* Split dynamic dimension over submodules if exactly one dimension in the parameter list is
* dynamic.
*/
struct split_single_dyn_dim
{
std::string name() const { return "split_single_dyn_dim"; }
void apply(module_pass_manager&) const;
};
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
#endif
/*
* The MIT License (MIT)
*
* Copyright (c) 2015-2023 Advanced Micro Devices, Inc. All rights reserved.
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
* in the Software without restriction, including without limitation the rights
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
* copies of the Software, and to permit persons to whom the Software is
* furnished to do so, subject to the following conditions:
*
* The above copyright notice and this permission notice shall be included in
* all copies or substantial portions of the Software.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
* THE SOFTWARE.
*/
#include <migraphx/split_single_dyn_dim.hpp>
#include <migraphx/module.hpp>
#include <migraphx/pass_manager.hpp>
#include <migraphx/functional.hpp>
#include <migraphx/make_op.hpp>
#include <migraphx/ranges.hpp>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
struct dynamic_dimensions_check
{
std::string dyn_param_str;
size_t dyn_index;
size_t min_dim;
size_t max_dim;
};
optional<dynamic_dimensions_check>
has_one_dyn_dim(const std::unordered_map<std::string, shape>& param_shapes)
{
// True if parameters contain exactly one dynamic shape with exactly one non-fixed
// dynamic_dimension.
auto is_dynamic = [](const auto& p) { return p.second.dynamic(); };
auto ps_it = std::find_if(param_shapes.begin(), param_shapes.end(), is_dynamic);
if(ps_it == param_shapes.end())
return std::nullopt;
// Check if there is a second dynamic parameter
if(std::any_of(std::next(ps_it), param_shapes.end(), is_dynamic))
return std::nullopt;
const auto& dds = ps_it->second.dyn_dims();
auto is_non_fixed = [](const auto& dd) { return not dd.is_fixed(); };
auto dds_it = std::find_if(dds.begin(), dds.end(), is_non_fixed);
if(dds_it == dds.end())
return std::nullopt;
// Check if there is a second non-fixed dynamic_dimension
if(std::any_of(std::next(dds_it), dds.end(), is_non_fixed))
return std::nullopt;
return dynamic_dimensions_check{ps_it->first,
static_cast<std::size_t>(std::distance(dds.begin(), dds_it)),
dds_it->min,
dds_it->max};
}
/**
* Makes all the shapes in the dynamic_dimension range.
* Probably won't work for `if` and `loop` instructions, depending on how the submodules for those
* work. Inserts select_module instruction to the top. Replaces return, bypassing other
* instructions.
*/
void split_single_dyn_dim::apply(module_pass_manager& mpm) const
{
module_ref mm = &mpm.get_module();
auto param_names = mm->get_parameter_names();
auto param_shapes = mm->get_parameter_shapes();
optional<dynamic_dimensions_check> dd_check = has_one_dyn_dim(param_shapes);
if(dd_check.has_value())
{
const auto& dyn_param = mm->get_parameter(dd_check->dyn_param_str);
auto dyn_param_shape = mm->get_parameter_shape(dd_check->dyn_param_str);
std::vector<module_ref> submodules;
// create submodules for each dimension size
for(size_t dim_size : migraphx::range(dd_check->min_dim, dd_check->max_dim + 1))
{
auto* submod = mpm.create_module("dim_" + std::to_string(dim_size));
// instruction map for new static shaped submodule parameters
std::unordered_map<instruction_ref, instruction_ref> map_ins;
// create static shape using dim_size
auto static_lens = dyn_param_shape.max_lens();
static_lens.at(dd_check->dyn_index) = dim_size;
map_ins[dyn_param] = submod->add_parameter(
dd_check->dyn_param_str, migraphx::shape{dyn_param_shape.type(), static_lens});
auto outputs = submod->add_instructions(mm, map_ins);
submod->add_return({outputs});
submodules.push_back(submod);
}
// redirect to select_module operator and return
std::vector<instruction_ref> sm_inputs;
std::transform(param_names.cbegin(),
param_names.cend(),
std::back_inserter(sm_inputs),
[&](auto pn) { return mm->get_parameter(pn); });
auto output_shapes = mm->get_output_shapes();
migraphx::shape out_attr = migraphx::shape{output_shapes};
auto sm_ins = mm->add_instruction(
migraphx::make_op("select_module",
{{"output_dyn_shapes", migraphx::to_value(out_attr)}}),
sm_inputs,
submodules);
std::vector<instruction_ref> outputs(output_shapes.size());
for(size_t i = 0; i < output_shapes.size(); ++i)
{
outputs.at(i) =
mm->add_instruction(migraphx::make_op("get_tuple_elem", {{"index", i}}), sm_ins);
}
mm->replace_return(outputs);
}
}
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
...@@ -51,6 +51,7 @@ ...@@ -51,6 +51,7 @@
#include <migraphx/simplify_algebra.hpp> #include <migraphx/simplify_algebra.hpp>
#include <migraphx/simplify_qdq.hpp> #include <migraphx/simplify_qdq.hpp>
#include <migraphx/simplify_reshapes.hpp> #include <migraphx/simplify_reshapes.hpp>
#include <migraphx/split_single_dyn_dim.hpp>
#include <migraphx/gpu/allocation_model.hpp> #include <migraphx/gpu/allocation_model.hpp>
#include <migraphx/gpu/compile_miopen.hpp> #include <migraphx/gpu/compile_miopen.hpp>
#include <migraphx/gpu/compile_ops.hpp> #include <migraphx/gpu/compile_ops.hpp>
...@@ -101,6 +102,8 @@ std::vector<pass> target::get_passes(migraphx::context& gctx, const compile_opti ...@@ -101,6 +102,8 @@ std::vector<pass> target::get_passes(migraphx::context& gctx, const compile_opti
// clang-format off // clang-format off
return return
{ {
enable_pass(options.split_single_dyn_dim, split_single_dyn_dim{}),
enable_pass(options.split_single_dyn_dim, dead_code_elimination{}),
normalize_ops{}, normalize_ops{},
dead_code_elimination{}, dead_code_elimination{},
simplify_qdq{}, simplify_qdq{},
......
...@@ -7705,8 +7705,8 @@ TEST_CASE(slice_test) ...@@ -7705,8 +7705,8 @@ TEST_CASE(slice_test)
TEST_CASE(slice_dyn_test0) TEST_CASE(slice_dyn_test0)
{ {
// Slice a single dynamic dimension. ax1 slice limits are smaller than min; ax2 "ends" is too // Slice a single dynamic dimension. ax1 slice limits are smaller than min; ax2 "ends" is
// large // too large
migraphx::program p; migraphx::program p;
auto* mm = p.get_main_module(); auto* mm = p.get_main_module();
migraphx::shape s{migraphx::shape::int32_type, {{2, 3, 0}, {2, 2, 0}, {3, 3, 0}}}; migraphx::shape s{migraphx::shape::int32_type, {{2, 3, 0}, {2, 2, 0}, {3, 3, 0}}};
......
/*
* The MIT License (MIT)
*
* Copyright (c) 2015-2023 Advanced Micro Devices, Inc. All rights reserved.
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
* in the Software without restriction, including without limitation the rights
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
* copies of the Software, and to permit persons to whom the Software is
* furnished to do so, subject to the following conditions:
*
* The above copyright notice and this permission notice shall be included in
* all copies or substantial portions of the Software.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
* THE SOFTWARE.
*/
#include <migraphx/split_single_dyn_dim.hpp>
#include <migraphx/program.hpp>
#include <migraphx/dead_code_elimination.hpp>
#include <migraphx/pass_manager.hpp>
#include <migraphx/make_op.hpp>
#include <migraphx/serialize.hpp>
#include <test.hpp>
void run_pass(migraphx::program& p)
{
migraphx::run_passes(p, {migraphx::split_single_dyn_dim{}, migraphx::dead_code_elimination{}});
}
TEST_CASE(dynamic_batch)
{
// Slightly different from ref_ops_test in that the literal is copied over the submodules.
// A different compiler pass will pull the literals from the submodules to the main module.
migraphx::program p0;
{
auto* mm0 = p0.get_main_module();
// create batch submodules
auto create_submodule = [&](std::size_t batch_size, const std::string& module_name) {
auto* submod = p0.create_module(module_name);
migraphx::shape sm_shape{migraphx::shape::float_type, {batch_size, 4}};
auto sm_input = submod->add_parameter("data", sm_shape);
migraphx::shape lit_s{migraphx::shape{migraphx::shape::float_type, {1}}};
auto literal_ins = submod->add_literal(migraphx::literal{lit_s, {6}});
auto broadcast_lit =
submod->add_instruction(migraphx::make_op("multibroadcast"), literal_ins, sm_input);
auto add_ins =
submod->add_instruction(migraphx::make_op("add"), sm_input, broadcast_lit);
submod->add_return({add_ins});
return submod;
};
auto* dim1 = create_submodule(1, "dim_1");
auto* dim2 = create_submodule(2, "dim_2");
auto* dim3 = create_submodule(3, "dim_3");
auto* dim4 = create_submodule(4, "dim_4");
migraphx::shape s{migraphx::shape::float_type, {{1, 4}, {4, 4}}};
auto input0 = mm0->add_parameter("data", s);
std::vector<migraphx::shape> sub_shapes = {};
sub_shapes.push_back(migraphx::shape{migraphx::shape::float_type, {{1, 4}, {4, 4}}});
migraphx::shape out_attr = migraphx::shape{sub_shapes};
auto sm_ins = mm0->add_instruction(
migraphx::make_op("select_module",
{{"output_dyn_shapes", migraphx::to_value(out_attr)}}),
{input0},
{dim1, dim2, dim3, dim4});
auto ret =
mm0->add_instruction(migraphx::make_op("get_tuple_elem", {{"index", 0}}), sm_ins);
mm0->add_return({ret});
}
migraphx::program p1;
{
auto* mm1 = p1.get_main_module();
migraphx::shape s{migraphx::shape::float_type, {{1, 4}, {4, 4}}};
auto input1 = mm1->add_parameter("data", s);
migraphx::shape lit_s{migraphx::shape{migraphx::shape::float_type, {1}}};
auto literal_ins = mm1->add_literal(migraphx::literal{lit_s, {6}});
auto broadcast_lit =
mm1->add_instruction(migraphx::make_op("multibroadcast"), literal_ins, input1);
auto add_ins = mm1->add_instruction(migraphx::make_op("add"), input1, broadcast_lit);
mm1->add_return({add_ins});
}
run_pass(p1);
EXPECT(p0 == p1);
}
TEST_CASE(multiple_outputs)
{
migraphx::program p0;
{
auto* mm0 = p0.get_main_module();
// create batch submodules
auto create_submodule = [&](std::size_t batch_size, const std::string& module_name) {
auto* submod = p0.create_module(module_name);
migraphx::shape sm_shape{migraphx::shape::float_type, {batch_size, 4}};
auto sm_input = submod->add_parameter("data", sm_shape);
migraphx::shape lit_s{migraphx::shape{migraphx::shape::float_type, {1}}};
auto literal_ins = submod->add_literal(migraphx::literal{lit_s, {6}});
auto broadcast_lit =
submod->add_instruction(migraphx::make_op("multibroadcast"), literal_ins, sm_input);
auto add0_ins =
submod->add_instruction(migraphx::make_op("add"), sm_input, broadcast_lit);
auto add1_ins = submod->add_instruction(migraphx::make_op("add"), sm_input, sm_input);
submod->add_return({add0_ins, add1_ins});
return submod;
};
auto* dim1 = create_submodule(1, "dim_1");
auto* dim2 = create_submodule(2, "dim_2");
auto* dim3 = create_submodule(3, "dim_3");
auto* dim4 = create_submodule(4, "dim_4");
migraphx::shape s{migraphx::shape::float_type, {{1, 4}, {4, 4}}};
auto input0 = mm0->add_parameter("data", s);
std::vector<migraphx::shape> sub_shapes = {};
migraphx::shape tmp_s{migraphx::shape::float_type, {{1, 4}, {4, 4}}};
sub_shapes.push_back(tmp_s);
sub_shapes.push_back(tmp_s);
migraphx::shape out_attr = migraphx::shape{sub_shapes};
auto sm_ins = mm0->add_instruction(
migraphx::make_op("select_module",
{{"output_dyn_shapes", migraphx::to_value(out_attr)}}),
{input0},
{dim1, dim2, dim3, dim4});
auto ret0 =
mm0->add_instruction(migraphx::make_op("get_tuple_elem", {{"index", 0}}), sm_ins);
auto ret1 =
mm0->add_instruction(migraphx::make_op("get_tuple_elem", {{"index", 1}}), sm_ins);
mm0->add_return({ret0, ret1});
}
migraphx::program p1;
{
auto* mm1 = p1.get_main_module();
migraphx::shape s{migraphx::shape::float_type, {{1, 4}, {4, 4}}};
auto input1 = mm1->add_parameter("data", s);
migraphx::shape lit_s{migraphx::shape{migraphx::shape::float_type, {1}}};
auto literal_ins = mm1->add_literal(migraphx::literal{lit_s, {6}});
auto broadcast_lit =
mm1->add_instruction(migraphx::make_op("multibroadcast"), literal_ins, input1);
auto add0_ins = mm1->add_instruction(migraphx::make_op("add"), input1, broadcast_lit);
auto add1_ins = mm1->add_instruction(migraphx::make_op("add"), input1, input1);
mm1->add_return({add0_ins, add1_ins});
}
run_pass(p1);
EXPECT(p0 == p1);
}
int main(int argc, const char* argv[]) { test::run(argc, argv); }
...@@ -73,7 +73,8 @@ int main(int argc, const char* argv[]) ...@@ -73,7 +73,8 @@ int main(int argc, const char* argv[])
"test_if_literal", "test_if_literal",
"test_select_module_add", "test_select_module_add",
"test_select_module_reduce", "test_select_module_reduce",
"test_select_module_conv"}); "test_select_module_conv",
"test_split_single_dyn_dim"});
rv.disable_test_for("gpu", {"test_conv_bn_add"}); rv.disable_test_for("gpu", {"test_conv_bn_add"});
rv.run(argc, argv); rv.run(argc, argv);
} }
...@@ -67,15 +67,17 @@ inline void verify_load_save(const migraphx::program& p) ...@@ -67,15 +67,17 @@ inline void verify_load_save(const migraphx::program& p)
EXPECT(p == loaded); EXPECT(p == loaded);
} }
inline void compile_check(migraphx::program& p, const migraphx::target& t, bool show_trace = false) inline void compile_check(migraphx::program& p,
const migraphx::target& t,
migraphx::compile_options c_opts,
bool show_trace = false)
{ {
auto name = t.name(); auto name = t.name();
auto shapes = p.get_output_shapes(); auto shapes = p.get_output_shapes();
std::stringstream ss; std::stringstream ss;
migraphx::compile_options options;
if(show_trace) if(show_trace)
options.trace = migraphx::tracer{std::cout}; c_opts.trace = migraphx::tracer{std::cout};
p.compile(t, options); p.compile(t, c_opts);
if(shapes.size() != p.get_output_shapes().size()) if(shapes.size() != p.get_output_shapes().size())
{ {
std::cout << ss.str() << std::endl; std::cout << ss.str() << std::endl;
...@@ -115,19 +117,23 @@ void run_verify::validate(const migraphx::target& t, ...@@ -115,19 +117,23 @@ void run_verify::validate(const migraphx::target& t,
} }
std::vector<migraphx::argument> run_verify::run_ref(migraphx::program p, std::vector<migraphx::argument> run_verify::run_ref(migraphx::program p,
migraphx::parameter_map inputs) const migraphx::parameter_map inputs,
const migraphx::compile_options& c_opts) const
{ {
migraphx::target t = migraphx::make_target("ref"); migraphx::target t = migraphx::make_target("ref");
auto_print pp{p, t.name()}; auto_print pp{p, t.name()};
compile_check(p, t); compile_check(p, t, c_opts);
return p.eval(std::move(inputs)); return p.eval(std::move(inputs));
} }
std::pair<migraphx::program, std::vector<migraphx::argument>> run_verify::run_target( std::pair<migraphx::program, std::vector<migraphx::argument>>
const migraphx::target& t, migraphx::program p, const migraphx::parameter_map& inputs) const run_verify::run_target(const migraphx::target& t,
migraphx::program p,
const migraphx::parameter_map& inputs,
const migraphx::compile_options& c_opts) const
{ {
auto_print pp{p, t.name()}; auto_print pp{p, t.name()};
auto trace_target = migraphx::string_value_of(MIGRAPHX_TRACE_TEST_COMPILE{}); auto trace_target = migraphx::string_value_of(MIGRAPHX_TRACE_TEST_COMPILE{});
compile_check(p, t, (trace_target == t.name())); compile_check(p, t, c_opts, (trace_target == t.name()));
migraphx::parameter_map m; migraphx::parameter_map m;
for(auto&& input : inputs) for(auto&& input : inputs)
{ {
...@@ -157,7 +163,9 @@ auto get_hash(const T& x) ...@@ -157,7 +163,9 @@ auto get_hash(const T& x)
return std::hash<T>{}(x); return std::hash<T>{}(x);
} }
void run_verify::verify(const std::string& name, const migraphx::program& p) const void run_verify::verify(const std::string& name,
const migraphx::program& p,
const migraphx::compile_options& c_opts) const
{ {
using result_future = using result_future =
std::future<std::pair<migraphx::program, std::vector<migraphx::argument>>>; std::future<std::pair<migraphx::program, std::vector<migraphx::argument>>>;
...@@ -197,13 +205,13 @@ void run_verify::verify(const std::string& name, const migraphx::program& p) con ...@@ -197,13 +205,13 @@ void run_verify::verify(const std::string& name, const migraphx::program& p) con
} }
} }
auto gold_f = detach_async([=] { return run_ref(p, m); }); auto gold_f = detach_async([=] { return run_ref(p, m, c_opts); });
for(const auto& tname : target_names) for(const auto& tname : target_names)
{ {
target_info ti = get_target_info(tname); target_info ti = get_target_info(tname);
auto t = migraphx::make_target(tname); auto t = migraphx::make_target(tname);
results.emplace_back(tname, results.emplace_back(
detach_async([=] { return run_target(t, p, m); }, ti.parallel)); tname, detach_async([=] { return run_target(t, p, m, c_opts); }, ti.parallel));
} }
assert(gold_f.valid()); assert(gold_f.valid());
...@@ -244,7 +252,7 @@ void run_verify::run(int argc, const char* argv[]) const ...@@ -244,7 +252,7 @@ void run_verify::run(int argc, const char* argv[]) const
for(auto&& p : get_programs()) for(auto&& p : get_programs())
{ {
labels[p.section].push_back(p.name); labels[p.section].push_back(p.name);
test::add_test_case(p.name, [=] { verify(p.name, p.get_program()); }); test::add_test_case(p.name, [=] { verify(p.name, p.get_program(), p.compile_options); });
} }
test::driver d{}; test::driver d{};
d.get_case_names = [&](const std::string& name) -> std::vector<std::string> { d.get_case_names = [&](const std::string& name) -> std::vector<std::string> {
......
...@@ -40,15 +40,19 @@ struct target_info ...@@ -40,15 +40,19 @@ struct target_info
struct run_verify struct run_verify
{ {
std::vector<migraphx::argument> run_ref(migraphx::program p, std::vector<migraphx::argument> run_ref(migraphx::program p,
migraphx::parameter_map inputs) const; migraphx::parameter_map inputs,
const migraphx::compile_options& c_opts) const;
std::pair<migraphx::program, std::vector<migraphx::argument>> std::pair<migraphx::program, std::vector<migraphx::argument>>
run_target(const migraphx::target& t, run_target(const migraphx::target& t,
migraphx::program p, migraphx::program p,
const migraphx::parameter_map& inputs) const; const migraphx::parameter_map& inputs,
const migraphx::compile_options& c_opts) const;
void validate(const migraphx::target& t, void validate(const migraphx::target& t,
const migraphx::program& p, const migraphx::program& p,
const migraphx::parameter_map& m) const; const migraphx::parameter_map& m) const;
void verify(const std::string& name, const migraphx::program& p) const; void verify(const std::string& name,
const migraphx::program& p,
const migraphx::compile_options& c_opts) const;
void run(int argc, const char* argv[]) const; void run(int argc, const char* argv[]) const;
target_info get_target_info(const std::string& name) const; target_info get_target_info(const std::string& name) const;
......
/*
* The MIT License (MIT)
*
* Copyright (c) 2015-2023 Advanced Micro Devices, Inc. All rights reserved.
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
* in the Software without restriction, including without limitation the rights
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
* copies of the Software, and to permit persons to whom the Software is
* furnished to do so, subject to the following conditions:
*
* The above copyright notice and this permission notice shall be included in
* all copies or substantial portions of the Software.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
* THE SOFTWARE.
*/
#include "verify_program.hpp"
#include <migraphx/pass_manager.hpp>
#include <migraphx/generate.hpp>
#include <migraphx/make_op.hpp>
/**
* Test that the split_single_dyn_dim GPU compiler pass produces the same results as ref.
*/
struct test_split_single_dyn_dim : verify_program<test_split_single_dyn_dim>
{
migraphx::program create_program() const
{
migraphx::program p;
auto* mm = p.get_main_module();
migraphx::shape lit_s{migraphx::shape{migraphx::shape::float_type, {1}}};
auto literal_ins = mm->add_literal(migraphx::literal{lit_s, {6}});
migraphx::shape s{migraphx::shape::float_type, {{1, 4}, {4, 4}}};
auto input = mm->add_parameter("data", s);
auto broadcast_lit =
mm->add_instruction(migraphx::make_op("multibroadcast"), literal_ins, input);
auto add_ins = mm->add_instruction(migraphx::make_op("add"), input, broadcast_lit);
mm->add_return({add_ins});
return p;
}
migraphx::compile_options get_compile_options() const
{
migraphx::compile_options co;
co.split_single_dyn_dim = true;
return co;
};
};
...@@ -24,15 +24,17 @@ ...@@ -24,15 +24,17 @@
#ifndef MIGRAPHX_GUARD_AUTO_REGISTER_VERIFY_PROGRAM_HPP #ifndef MIGRAPHX_GUARD_AUTO_REGISTER_VERIFY_PROGRAM_HPP
#define MIGRAPHX_GUARD_AUTO_REGISTER_VERIFY_PROGRAM_HPP #define MIGRAPHX_GUARD_AUTO_REGISTER_VERIFY_PROGRAM_HPP
#include <functional>
#include <migraphx/auto_register.hpp> #include <migraphx/auto_register.hpp>
#include <migraphx/program.hpp> #include <migraphx/program.hpp>
#include <functional> #include <migraphx/compile_options.hpp>
struct program_info struct program_info
{ {
std::string name; std::string name;
std::string section; std::string section;
std::function<migraphx::program()> get_program; std::function<migraphx::program()> get_program;
migraphx::compile_options compile_options;
}; };
void register_program_info(const program_info& pi); void register_program_info(const program_info& pi);
...@@ -48,6 +50,7 @@ struct register_verify_program_action ...@@ -48,6 +50,7 @@ struct register_verify_program_action
pi.name = migraphx::get_type_name<T>(); pi.name = migraphx::get_type_name<T>();
pi.section = x.section(); pi.section = x.section();
pi.get_program = [x] { return x.create_program(); }; pi.get_program = [x] { return x.create_program(); };
pi.compile_options = x.get_compile_options();
register_program_info(pi); register_program_info(pi);
} }
}; };
...@@ -59,6 +62,7 @@ template <class T> ...@@ -59,6 +62,7 @@ template <class T>
struct verify_program : auto_register_verify_program<T> struct verify_program : auto_register_verify_program<T>
{ {
std::string section() const { return "general"; }; std::string section() const { return "general"; };
migraphx::compile_options get_compile_options() const { return migraphx::compile_options{}; };
}; };
#endif #endif
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment