Commit 5bc70a9c authored by charlie's avatar charlie
Browse files

Test created and works

parent 61b53e47
...@@ -91,6 +91,7 @@ add_library(migraphx ...@@ -91,6 +91,7 @@ add_library(migraphx
shape.cpp shape.cpp
simplify_algebra.cpp simplify_algebra.cpp
simplify_reshapes.cpp simplify_reshapes.cpp
split_single_dyn_dim.cpp
tmp_dir.cpp tmp_dir.cpp
value.cpp value.cpp
verify_args.cpp verify_args.cpp
......
...@@ -21,8 +21,8 @@ ...@@ -21,8 +21,8 @@
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
* THE SOFTWARE. * THE SOFTWARE.
*/ */
#ifndef MIGRAPHX_GUARD_RTGLIB_SPLIT_DYNAMIC_BATCH_HPP #ifndef MIGRAPHX_GUARD_RTGLIB_SPLIT_SINGLE_DYN_DIM_HPP
#define MIGRAPHX_GUARD_RTGLIB_SPLIT_DYNAMIC_BATCH_HPP #define MIGRAPHX_GUARD_RTGLIB_SPLIT_SINGLE_DYN_DIM_HPP
#include <string> #include <string>
#include <migraphx/program.hpp> #include <migraphx/program.hpp>
...@@ -36,9 +36,9 @@ inline namespace MIGRAPHX_INLINE_NS { ...@@ -36,9 +36,9 @@ inline namespace MIGRAPHX_INLINE_NS {
* Split dynamic batch dimension over submodules if exactly one dimension in the parameter list * Split dynamic batch dimension over submodules if exactly one dimension in the parameter list
* is dynamic. Should only run on the main module. * is dynamic. Should only run on the main module.
*/ */
struct split_dynamic_batch struct split_single_dyn_dim
{ {
std::string name() const { return "split_dynamic_batch"; } std::string name() const { return "split_single_dyn_dim"; }
void apply(module_pass_manager& p) const; void apply(module_pass_manager& p) const;
}; };
......
...@@ -94,7 +94,7 @@ bool has_one_dyn_dim(std::unordered_map<std::string, shape> param_shapes, ...@@ -94,7 +94,7 @@ bool has_one_dyn_dim(std::unordered_map<std::string, shape> param_shapes,
* instruction to the top, replace return bypassing other instructions. Unused instructions should * instruction to the top, replace return bypassing other instructions. Unused instructions should
* be removed by dead_code_elimination * be removed by dead_code_elimination
*/ */
void split_dynamic_batch::apply(module_pass_manager& mpm) const void split_single_dyn_dim::apply(module_pass_manager& mpm) const
{ {
module_ref mm; module_ref mm;
mm = &mpm.get_module(); mm = &mpm.get_module();
...@@ -117,7 +117,8 @@ void split_dynamic_batch::apply(module_pass_manager& mpm) const ...@@ -117,7 +117,8 @@ void split_dynamic_batch::apply(module_pass_manager& mpm) const
auto static_param = auto static_param =
submod->add_parameter(dyn_param_name, migraphx::shape{dps.type(), static_lens}); submod->add_parameter(dyn_param_name, migraphx::shape{dps.type(), static_lens});
map_ins[mm->get_parameter(dyn_param_name)] = static_param; map_ins[mm->get_parameter(dyn_param_name)] = static_param;
submod->add_instructions(mm, map_ins); auto outputs = submod->add_instructions(mm, map_ins);
submod->add_return({outputs});
submodules.push_back(submod); submodules.push_back(submod);
} }
// redirect to select_module operator and return; // redirect to select_module operator and return;
...@@ -126,10 +127,10 @@ void split_dynamic_batch::apply(module_pass_manager& mpm) const ...@@ -126,10 +127,10 @@ void split_dynamic_batch::apply(module_pass_manager& mpm) const
param_names.cend(), param_names.cend(),
std::back_inserter(sm_inputs), std::back_inserter(sm_inputs),
[&](auto pn) { return mm->get_parameter(pn); }); [&](auto pn) { return mm->get_parameter(pn); });
auto sm_ins = mm->insert_instruction( migraphx::shape out_attr = migraphx::shape{mm->get_output_shapes()};
mm->begin(), auto sm_ins = mm->add_instruction(
migraphx::make_op("select_module", migraphx::make_op("select_module",
{{"output_dyn_shapes", migraphx::to_value(mm->get_output_shapes())}}), {{"output_dyn_shapes", migraphx::to_value(out_attr)}}),
sm_inputs, sm_inputs,
submodules); submodules);
mm->replace_return({sm_ins}); mm->replace_return({sm_ins});
......
/*
* The MIT License (MIT)
*
* Copyright (c) 2015-2023 Advanced Micro Devices, Inc. All rights reserved.
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
* in the Software without restriction, including without limitation the rights
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
* copies of the Software, and to permit persons to whom the Software is
* furnished to do so, subject to the following conditions:
*
* The above copyright notice and this permission notice shall be included in
* all copies or substantial portions of the Software.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
* THE SOFTWARE.
*/
#include <migraphx/split_single_dyn_dim.hpp>
#include <migraphx/program.hpp>
#include <migraphx/dead_code_elimination.hpp>
#include <migraphx/pass_manager.hpp>
#include <migraphx/make_op.hpp>
#include <migraphx/serialize.hpp>
#include <test.hpp>
void run_pass(migraphx::program& p)
{
migraphx::run_passes(p, {migraphx::split_single_dyn_dim{}, migraphx::dead_code_elimination{}});
}
TEST_CASE(dynamic_batch)
{
// slightly different from ref_ops_test in that the literal is copied over the submodules
// different compiler pass will pull the literals from the submodules to the main module
migraphx::program p0;
{
auto* mm0 = p0.get_main_module();
// create batch submodules
auto create_submodule = [&](std::size_t batch_size, std::string module_name) {
auto* submod = p0.create_module(module_name);
migraphx::shape sm_shape{migraphx::shape::float_type, {batch_size, 4}};
auto sm_input = submod->add_parameter("data", sm_shape);
migraphx::shape lit_s{migraphx::shape{migraphx::shape::float_type, {1}}};
auto literal_ins = submod->add_literal(migraphx::literal{lit_s, {6}});
auto broadcast_lit =
submod->add_instruction(migraphx::make_op("multibroadcast"), literal_ins, sm_input);
auto add_ins =
submod->add_instruction(migraphx::make_op("add"), sm_input, broadcast_lit);
submod->add_return({add_ins});
return submod;
};
auto* batch1 = create_submodule(1, "batch_1");
auto* batch2 = create_submodule(2, "batch_2");
auto* batch3 = create_submodule(3, "batch_3");
auto* batch4 = create_submodule(4, "batch_4");
migraphx::shape s{migraphx::shape::float_type, {{1, 4}, {4, 4}}};
auto input0 = mm0->add_parameter("data", s);
std::vector<migraphx::shape> sub_shapes = {};
sub_shapes.push_back(migraphx::shape{migraphx::shape::float_type, {{1, 4}, {4, 4}}});
migraphx::shape out_attr = migraphx::shape{sub_shapes};
auto sm_ins = mm0->add_instruction(
migraphx::make_op("select_module",
{{"output_dyn_shapes", migraphx::to_value(out_attr)}}),
{input0},
{batch1, batch2, batch3, batch4});
mm0->add_return({sm_ins});
}
migraphx::program p1;
{
auto* mm1 = p1.get_main_module();
migraphx::shape s{migraphx::shape::float_type, {{1, 4}, {4, 4}}};
auto input1 = mm1->add_parameter("data", s);
migraphx::shape lit_s{migraphx::shape{migraphx::shape::float_type, {1}}};
auto literal_ins = mm1->add_literal(migraphx::literal{lit_s, {6}});
auto broadcast_lit =
mm1->add_instruction(migraphx::make_op("multibroadcast"), literal_ins, input1);
auto add_ins =
mm1->add_instruction(migraphx::make_op("add"), input1, broadcast_lit);
mm1->add_return({add_ins});
}
run_pass(p1);
EXPECT(p0 == p1);
}
int main(int argc, const char* argv[]) { test::run(argc, argv); }
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment