Commit 9b3b10ed authored by Paul
Browse files

Merge branch 'mlir-c' into mlir-c-sqlite

parents 68d86c3d 595532cd
...@@ -80,7 +80,7 @@ ...@@ -80,7 +80,7 @@
"outputs": [], "outputs": [],
"source": [ "source": [
"if not os.path.exists(\"yolov4_fp16.mxr\"):\n", "if not os.path.exists(\"yolov4_fp16.mxr\"):\n",
" !/opt/rocm/bin/migraphx-driver compile ./utilities/yolov4.onnx --gpu --enable-offload-copy --fp16ref --binary -o yolov4_fp16.mxr\n", " !/opt/rocm/bin/migraphx-driver compile ./utilities/yolov4.onnx --gpu --enable-offload-copy --fp16 --binary -o yolov4_fp16.mxr\n",
"if not os.path.exists(\"yolov4.mxr\"):\n", "if not os.path.exists(\"yolov4.mxr\"):\n",
" !/opt/rocm/bin/migraphx-driver compile ./utilities/yolov4.onnx --gpu --enable-offload-copy --binary -o yolov4.mxr" " !/opt/rocm/bin/migraphx-driver compile ./utilities/yolov4.onnx --gpu --enable-offload-copy --binary -o yolov4.mxr"
] ]
......
/* /*
* The MIT License (MIT) * The MIT License (MIT)
* *
...@@ -21,10 +22,10 @@ ...@@ -21,10 +22,10 @@
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
* THE SOFTWARE. * THE SOFTWARE.
*/ */
#include <migraphx/operators.hpp> #include <migraphx/make_op.hpp>
#include <migraphx/program.hpp> #include <migraphx/program.hpp>
#include <migraphx/generate.hpp> #include <migraphx/generate.hpp>
#include <migraphx/apply_alpha_beta.hpp> #include <migraphx/json.hpp>
#include "models.hpp" #include "models.hpp"
namespace migraphx { namespace migraphx {
...@@ -34,173 +35,189 @@ inline namespace MIGRAPHX_INLINE_NS { ...@@ -34,173 +35,189 @@ inline namespace MIGRAPHX_INLINE_NS {
// alexnet(batch): builds and returns a migraphx::program whose main module is an
// AlexNet-style graph. This span is a side-by-side diff: on every line the LEFT half
// is the old revision and the RIGHT half is the new revision.
//
// Both revisions:
//   - take one float input parameter of shape {batch, 3, 224, 224}
//     (named "0" in the old revision, "input.1" in the new one);
//   - create every weight/bias with migraphx::generate_literal(shape, k), using a
//     different k per literal (presumably a seed -- confirm against generate.hpp);
//   - chain five convolution -> broadcast(bias) -> add -> relu stages, with a pooling
//     step after the 1st, 2nd and 5th stage, followed by three dot layers with
//     transposed weights.
//
// Differences visible here:
//   - old: operators are built as migraphx::op::* structs with fields assigned;
//     new: operators are built with migraphx::make_op(name, from_json_string(...)).
//   - old: dot layers go through add_apply_alpha_beta(..., alpha = 1, beta = 1);
//     new: dot, then add of (multibroadcast(bias) * multibroadcast(scalar literal)),
//     where the three {1}-shaped scalar literals are wrapped in migraphx::abs.
//   - old: flatten(axis = 1) followed by identity; new: reshape with dims [1,9216].
//   - new only: the final instruction is registered with mmain->add_return(...).
//
// NOTE(review): in the new revision `batch` is used only for the input parameter
// shape; every broadcast/multibroadcast out_lens and the reshape dims hard-code a
// leading 1, whereas the old revision used `batch` there. Presumably batch != 1 no
// longer produces consistent shapes -- confirm before relying on it.
migraphx::program alexnet(unsigned batch) // NOLINT(readability-function-size) migraphx::program alexnet(unsigned batch) // NOLINT(readability-function-size)
{ {
migraphx::program p; migraphx::program p;
auto* mm = p.get_main_module(); migraphx::module_ref mmain = p.get_main_module();
auto m0 = auto x_main_module_0 = mmain->add_literal(migraphx::abs(
mm->add_parameter("0", migraphx::shape{migraphx::shape::float_type, {batch, 3, 224, 224}}); migraphx::generate_literal(migraphx::shape{migraphx::shape::float_type, {1}}, 0)));
auto mx0 = mm->add_literal( auto x_main_module_1 = mmain->add_literal(migraphx::abs(
migraphx::generate_literal(migraphx::shape{migraphx::shape::float_type, {1000}}, 0)); migraphx::generate_literal(migraphx::shape{migraphx::shape::float_type, {1}}, 1)));
auto mx1 = mm->add_literal( auto x_main_module_2 = mmain->add_literal(migraphx::abs(
migraphx::generate_literal(migraphx::shape{migraphx::shape::float_type, {1000, 4096}}, 1)); migraphx::generate_literal(migraphx::shape{migraphx::shape::float_type, {1}}, 2)));
auto mx2 = mm->add_literal( auto x_input_1 = mmain->add_parameter(
migraphx::generate_literal(migraphx::shape{migraphx::shape::float_type, {4096}}, 2)); "input.1", migraphx::shape{migraphx::shape::float_type, {batch, 3, 224, 224}});
auto mx3 = mm->add_literal( auto x_main_module_4 = mmain->add_literal(
migraphx::generate_literal(migraphx::shape{migraphx::shape::float_type, {4096, 4096}}, 3)); migraphx::generate_literal(migraphx::shape{migraphx::shape::float_type, {4096, 4096}}, 3));
auto mx4 = mm->add_literal( auto x_main_module_5 = mmain->add_literal(
migraphx::generate_literal(migraphx::shape{migraphx::shape::float_type, {4096}}, 4)); migraphx::generate_literal(migraphx::shape{migraphx::shape::float_type, {4096}}, 4));
auto mx5 = mm->add_literal( auto x_main_module_6 = mmain->add_literal(
migraphx::generate_literal(migraphx::shape{migraphx::shape::float_type, {4096, 9216}}, 5)); migraphx::generate_literal(migraphx::shape{migraphx::shape::float_type, {4096, 9216}}, 5));
auto mx6 = mm->add_literal( auto x_main_module_7 = mmain->add_literal(
migraphx::generate_literal(migraphx::shape{migraphx::shape::float_type, {256}}, 6)); migraphx::generate_literal(migraphx::shape{migraphx::shape::float_type, {4096}}, 6));
auto mx7 = mm->add_literal(migraphx::generate_literal( auto x_main_module_8 = mmain->add_literal(
migraphx::shape{migraphx::shape::float_type, {256, 256, 3, 3}}, 7)); migraphx::generate_literal(migraphx::shape{migraphx::shape::float_type, {1000, 4096}}, 7));
auto mx8 = mm->add_literal( auto x_main_module_9 = mmain->add_literal(
migraphx::generate_literal(migraphx::shape{migraphx::shape::float_type, {256}}, 8)); migraphx::generate_literal(migraphx::shape{migraphx::shape::float_type, {1000}}, 8));
auto mx9 = mm->add_literal(migraphx::generate_literal( auto x_main_module_10 = mmain->add_literal(migraphx::generate_literal(
migraphx::shape{migraphx::shape::float_type, {256, 384, 3, 3}}, 9)); migraphx::shape{migraphx::shape::float_type, {256, 384, 3, 3}}, 9));
auto mx10 = mm->add_literal( auto x_main_module_11 = mmain->add_literal(
migraphx::generate_literal(migraphx::shape{migraphx::shape::float_type, {384}}, 10)); migraphx::generate_literal(migraphx::shape{migraphx::shape::float_type, {256}}, 10));
auto mx11 = mm->add_literal(migraphx::generate_literal( auto x_main_module_12 = mmain->add_literal(migraphx::generate_literal(
migraphx::shape{migraphx::shape::float_type, {384, 192, 3, 3}}, 11)); migraphx::shape{migraphx::shape::float_type, {384, 192, 3, 3}}, 11));
auto mx12 = mm->add_literal( auto x_main_module_13 = mmain->add_literal(
migraphx::generate_literal(migraphx::shape{migraphx::shape::float_type, {192}}, 12)); migraphx::generate_literal(migraphx::shape{migraphx::shape::float_type, {384}}, 12));
auto mx13 = mm->add_literal(migraphx::generate_literal( auto x_main_module_14 = mmain->add_literal(migraphx::generate_literal(
migraphx::shape{migraphx::shape::float_type, {192, 64, 5, 5}}, 13)); migraphx::shape{migraphx::shape::float_type, {192, 64, 5, 5}}, 13));
auto mx14 = mm->add_literal( auto x_main_module_15 = mmain->add_literal(
migraphx::generate_literal(migraphx::shape{migraphx::shape::float_type, {64}}, 14)); migraphx::generate_literal(migraphx::shape{migraphx::shape::float_type, {192}}, 14));
auto mx15 = mm->add_literal(migraphx::generate_literal( auto x_main_module_16 = mmain->add_literal(migraphx::generate_literal(
migraphx::shape{migraphx::shape::float_type, {64, 3, 11, 11}}, 15)); migraphx::shape{migraphx::shape::float_type, {256, 256, 3, 3}}, 15));
migraphx::op::convolution convolution16; auto x_main_module_17 = mmain->add_literal(
convolution16.padding = {2, 2}; migraphx::generate_literal(migraphx::shape{migraphx::shape::float_type, {256}}, 16));
convolution16.stride = {4, 4}; auto x_main_module_18 = mmain->add_literal(migraphx::generate_literal(
convolution16.dilation = {1, 1}; migraphx::shape{migraphx::shape::float_type, {64, 3, 11, 11}}, 17));
convolution16.group = 1; auto x_main_module_19 = mmain->add_literal(
auto mx16 = mm->add_instruction(convolution16, m0, mx15); migraphx::generate_literal(migraphx::shape{migraphx::shape::float_type, {64}}, 18));
migraphx::op::broadcast broadcast17; auto x_main_module_20 = mmain->add_instruction(
broadcast17.axis = 1; migraphx::make_op(
broadcast17.broadcast_lens = {batch, 64, 55, 55}; "convolution",
auto mx17 = mm->add_instruction(broadcast17, mx14); migraphx::from_json_string(
migraphx::op::add add18; "{dilation:[1,1],group:1,padding:[2,2,2,2],padding_mode:0,stride:[4,4]}")),
auto mx18 = mm->add_instruction(add18, mx16, mx17); x_input_1,
migraphx::op::relu relu19; x_main_module_18);
// NOTE(review): new revision's first bias broadcast uses out_lens [1,64,55,55];
// the old revision (broadcast17 above) used {batch, 64, 55, 55}.
auto mx19 = mm->add_instruction(relu19, mx18); auto x_main_module_21 = mmain->add_instruction(
migraphx::op::pooling pooling20; migraphx::make_op("broadcast",
pooling20.mode = migraphx::op::pooling_mode::max; migraphx::from_json_string("{axis:1,out_lens:[1,64,55,55]}")),
pooling20.padding = {0, 0}; x_main_module_19);
pooling20.stride = {2, 2}; auto x_main_module_22 =
pooling20.lengths = {3, 3}; mmain->add_instruction(migraphx::make_op("add"), x_main_module_20, x_main_module_21);
auto mx20 = mm->add_instruction(pooling20, mx19); auto x_main_module_23 = mmain->add_instruction(migraphx::make_op("relu"), x_main_module_22);
migraphx::op::convolution convolution21; auto x_main_module_24 = mmain->add_instruction(
convolution21.padding = {2, 2}; migraphx::make_op(
convolution21.stride = {1, 1}; "pooling",
convolution21.dilation = {1, 1}; migraphx::from_json_string(
convolution21.group = 1; "{ceil_mode:0,lengths:[3,3],lp_order:2,mode:1,padding:[0,0,0,0],stride:[2,2]}")),
auto mx21 = mm->add_instruction(convolution21, mx20, mx13); x_main_module_23);
migraphx::op::broadcast broadcast22; auto x_main_module_25 = mmain->add_instruction(
broadcast22.axis = 1; migraphx::make_op(
broadcast22.broadcast_lens = {batch, 192, 27, 27}; "convolution",
auto mx22 = mm->add_instruction(broadcast22, mx12); migraphx::from_json_string(
migraphx::op::add add23; "{dilation:[1,1],group:1,padding:[2,2,2,2],padding_mode:0,stride:[1,1]}")),
auto mx23 = mm->add_instruction(add23, mx21, mx22); x_main_module_24,
migraphx::op::relu relu24; x_main_module_14);
auto mx24 = mm->add_instruction(relu24, mx23); auto x_main_module_26 = mmain->add_instruction(
migraphx::op::pooling pooling25; migraphx::make_op("broadcast",
pooling25.mode = migraphx::op::pooling_mode::max; migraphx::from_json_string("{axis:1,out_lens:[1,192,27,27]}")),
pooling25.padding = {0, 0}; x_main_module_15);
pooling25.stride = {2, 2}; auto x_main_module_27 =
pooling25.lengths = {3, 3}; mmain->add_instruction(migraphx::make_op("add"), x_main_module_25, x_main_module_26);
auto mx25 = mm->add_instruction(pooling25, mx24); auto x_main_module_28 = mmain->add_instruction(migraphx::make_op("relu"), x_main_module_27);
migraphx::op::convolution convolution26; auto x_main_module_29 = mmain->add_instruction(
convolution26.padding = {1, 1}; migraphx::make_op(
convolution26.stride = {1, 1}; "pooling",
convolution26.dilation = {1, 1}; migraphx::from_json_string(
convolution26.group = 1; "{ceil_mode:0,lengths:[3,3],lp_order:2,mode:1,padding:[0,0,0,0],stride:[2,2]}")),
auto mx26 = mm->add_instruction(convolution26, mx25, mx11); x_main_module_28);
migraphx::op::broadcast broadcast27; auto x_main_module_30 = mmain->add_instruction(
broadcast27.axis = 1; migraphx::make_op(
broadcast27.broadcast_lens = {batch, 384, 13, 13}; "convolution",
auto mx27 = mm->add_instruction(broadcast27, mx10); migraphx::from_json_string(
migraphx::op::add add28; "{dilation:[1,1],group:1,padding:[1,1,1,1],padding_mode:0,stride:[1,1]}")),
auto mx28 = mm->add_instruction(add28, mx26, mx27); x_main_module_29,
migraphx::op::relu relu29; x_main_module_12);
auto mx29 = mm->add_instruction(relu29, mx28); auto x_main_module_31 = mmain->add_instruction(
migraphx::op::convolution convolution30; migraphx::make_op("broadcast",
convolution30.padding = {1, 1}; migraphx::from_json_string("{axis:1,out_lens:[1,384,13,13]}")),
convolution30.stride = {1, 1}; x_main_module_13);
convolution30.dilation = {1, 1}; auto x_main_module_32 =
convolution30.group = 1; mmain->add_instruction(migraphx::make_op("add"), x_main_module_30, x_main_module_31);
auto mx30 = mm->add_instruction(convolution30, mx29, mx9); auto x_main_module_33 = mmain->add_instruction(migraphx::make_op("relu"), x_main_module_32);
migraphx::op::broadcast broadcast31; auto x_main_module_34 = mmain->add_instruction(
broadcast31.axis = 1; migraphx::make_op(
broadcast31.broadcast_lens = {batch, 256, 13, 13}; "convolution",
auto mx31 = mm->add_instruction(broadcast31, mx8); migraphx::from_json_string(
migraphx::op::add add32; "{dilation:[1,1],group:1,padding:[1,1,1,1],padding_mode:0,stride:[1,1]}")),
auto mx32 = mm->add_instruction(add32, mx30, mx31); x_main_module_33,
migraphx::op::relu relu33; x_main_module_10);
auto mx33 = mm->add_instruction(relu33, mx32); auto x_main_module_35 = mmain->add_instruction(
migraphx::op::convolution convolution34; migraphx::make_op("broadcast",
convolution34.padding = {1, 1}; migraphx::from_json_string("{axis:1,out_lens:[1,256,13,13]}")),
convolution34.stride = {1, 1}; x_main_module_11);
convolution34.dilation = {1, 1}; auto x_main_module_36 =
convolution34.group = 1; mmain->add_instruction(migraphx::make_op("add"), x_main_module_34, x_main_module_35);
auto mx34 = mm->add_instruction(convolution34, mx33, mx7); auto x_main_module_37 = mmain->add_instruction(migraphx::make_op("relu"), x_main_module_36);
migraphx::op::broadcast broadcast35; auto x_main_module_38 = mmain->add_instruction(
broadcast35.axis = 1; migraphx::make_op(
broadcast35.broadcast_lens = {batch, 256, 13, 13}; "convolution",
auto mx35 = mm->add_instruction(broadcast35, mx6); migraphx::from_json_string(
migraphx::op::add add36; "{dilation:[1,1],group:1,padding:[1,1,1,1],padding_mode:0,stride:[1,1]}")),
auto mx36 = mm->add_instruction(add36, mx34, mx35); x_main_module_37,
migraphx::op::relu relu37; x_main_module_16);
auto mx37 = mm->add_instruction(relu37, mx36); auto x_main_module_39 = mmain->add_instruction(
migraphx::op::pooling pooling38; migraphx::make_op("broadcast",
pooling38.mode = migraphx::op::pooling_mode::max; migraphx::from_json_string("{axis:1,out_lens:[1,256,13,13]}")),
pooling38.padding = {0, 0}; x_main_module_17);
pooling38.stride = {2, 2}; auto x_main_module_40 =
pooling38.lengths = {3, 3}; mmain->add_instruction(migraphx::make_op("add"), x_main_module_38, x_main_module_39);
auto mx38 = mm->add_instruction(pooling38, mx37); auto x_main_module_41 = mmain->add_instruction(migraphx::make_op("relu"), x_main_module_40);
migraphx::op::flatten flatten39; auto x_main_module_42 = mmain->add_instruction(
flatten39.axis = 1; migraphx::make_op(
auto mx39 = mm->add_instruction(flatten39, mx38); "pooling",
migraphx::op::identity identity40; migraphx::from_json_string(
auto mx40 = mm->add_instruction(identity40, mx39); "{ceil_mode:0,lengths:[3,3],lp_order:2,mode:1,padding:[0,0,0,0],stride:[2,2]}")),
migraphx::op::transpose transpose41; x_main_module_41);
// NOTE(review): new revision replaces flatten(axis=1) with reshape to dims [1,9216]
// (fixed leading 1), and the following multibroadcasts use out_lens [1,4096] / [1,1000]
// where the old revision used {batch, 4096} / {batch, 1000}.
transpose41.dims = {1, 0}; auto x_main_module_43 = mmain->add_instruction(
auto mx41 = mm->add_instruction(transpose41, mx5); migraphx::make_op("reshape", migraphx::from_json_string("{dims:[1,9216]}")),
migraphx::op::multibroadcast multibroadcast42; x_main_module_42);
multibroadcast42.output_lens = {batch, 4096}; auto x_main_module_44 = mmain->add_instruction(
auto mx42 = mm->add_instruction(multibroadcast42, mx4); migraphx::make_op("transpose", migraphx::from_json_string("{permutation:[1,0]}")),
float dot43_alpha = 1; x_main_module_6);
float dot43_beta = 1; auto x_main_module_45 =
auto mx43 = migraphx::add_apply_alpha_beta( mmain->add_instruction(migraphx::make_op("dot"), x_main_module_43, x_main_module_44);
*mm, {mx40, mx41, mx42}, migraphx::make_op("dot"), dot43_alpha, dot43_beta); auto x_main_module_46 = mmain->add_instruction(
migraphx::op::relu relu44; migraphx::make_op("multibroadcast", migraphx::from_json_string("{out_lens:[1,4096]}")),
auto mx44 = mm->add_instruction(relu44, mx43); x_main_module_7);
migraphx::op::identity identity45; auto x_main_module_47 = mmain->add_instruction(
auto mx45 = mm->add_instruction(identity45, mx44); migraphx::make_op("multibroadcast", migraphx::from_json_string("{out_lens:[1,4096]}")),
migraphx::op::transpose transpose46; x_main_module_2);
transpose46.dims = {1, 0}; auto x_main_module_48 =
auto mx46 = mm->add_instruction(transpose46, mx3); mmain->add_instruction(migraphx::make_op("mul"), x_main_module_46, x_main_module_47);
migraphx::op::multibroadcast multibroadcast47; auto x_main_module_49 =
multibroadcast47.output_lens = {batch, 4096}; mmain->add_instruction(migraphx::make_op("add"), x_main_module_45, x_main_module_48);
auto mx47 = mm->add_instruction(multibroadcast47, mx2); auto x_main_module_50 = mmain->add_instruction(migraphx::make_op("relu"), x_main_module_49);
float dot48_alpha = 1; auto x_main_module_51 = mmain->add_instruction(
float dot48_beta = 1; migraphx::make_op("transpose", migraphx::from_json_string("{permutation:[1,0]}")),
auto mx48 = migraphx::add_apply_alpha_beta( x_main_module_4);
*mm, {mx45, mx46, mx47}, migraphx::make_op("dot"), dot48_alpha, dot48_beta); auto x_main_module_52 =
migraphx::op::relu relu49; mmain->add_instruction(migraphx::make_op("dot"), x_main_module_50, x_main_module_51);
auto mx49 = mm->add_instruction(relu49, mx48); auto x_main_module_53 = mmain->add_instruction(
migraphx::op::transpose transpose50; migraphx::make_op("multibroadcast", migraphx::from_json_string("{out_lens:[1,4096]}")),
transpose50.dims = {1, 0}; x_main_module_5);
auto mx50 = mm->add_instruction(transpose50, mx1); auto x_main_module_54 = mmain->add_instruction(
migraphx::op::multibroadcast multibroadcast51; migraphx::make_op("multibroadcast", migraphx::from_json_string("{out_lens:[1,4096]}")),
multibroadcast51.output_lens = {batch, 1000}; x_main_module_1);
auto mx51 = mm->add_instruction(multibroadcast51, mx0); auto x_main_module_55 =
float dot52_alpha = 1; mmain->add_instruction(migraphx::make_op("mul"), x_main_module_53, x_main_module_54);
float dot52_beta = 1; auto x_main_module_56 =
migraphx::add_apply_alpha_beta( mmain->add_instruction(migraphx::make_op("add"), x_main_module_52, x_main_module_55);
*mm, {mx49, mx50, mx51}, migraphx::make_op("dot"), dot52_alpha, dot52_beta); auto x_main_module_57 = mmain->add_instruction(migraphx::make_op("relu"), x_main_module_56);
auto x_main_module_58 = mmain->add_instruction(
migraphx::make_op("transpose", migraphx::from_json_string("{permutation:[1,0]}")),
x_main_module_8);
auto x_main_module_59 =
mmain->add_instruction(migraphx::make_op("dot"), x_main_module_57, x_main_module_58);
auto x_main_module_60 = mmain->add_instruction(
migraphx::make_op("multibroadcast", migraphx::from_json_string("{out_lens:[1,1000]}")),
x_main_module_9);
auto x_main_module_61 = mmain->add_instruction(
migraphx::make_op("multibroadcast", migraphx::from_json_string("{out_lens:[1,1000]}")),
x_main_module_0);
auto x_main_module_62 =
mmain->add_instruction(migraphx::make_op("mul"), x_main_module_60, x_main_module_61);
auto x_main_module_63 =
mmain->add_instruction(migraphx::make_op("add"), x_main_module_59, x_main_module_62);
// New revision only: explicitly mark the last add as the module's return value.
mmain->add_return({x_main_module_63});
return p; return p;
} }
} // namespace MIGRAPHX_INLINE_NS } // namespace MIGRAPHX_INLINE_NS
} // namespace driver } // namespace driver
} // namespace migraphx } // namespace migraphx
This source diff could not be displayed because it is too large. You can view the blob instead.
...@@ -210,6 +210,9 @@ struct loader ...@@ -210,6 +210,9 @@ struct loader
auto last = std::prev(mm->end(), trim); auto last = std::prev(mm->end(), trim);
mm->remove_instructions(last, mm->end()); mm->remove_instructions(last, mm->end());
} }
// Remove unused variable when exporting to cpp
if(output_type == "cpp")
migraphx::run_passes(*p.get_main_module(), {migraphx::dead_code_elimination{}});
if(optimize) if(optimize)
{ {
migraphx::run_passes(*p.get_main_module(), migraphx::run_passes(*p.get_main_module(),
......
This diff is collapsed.
...@@ -142,7 +142,7 @@ static std::vector<instruction_ref> append_pointwise_module(instruction_ref ins, ...@@ -142,7 +142,7 @@ static std::vector<instruction_ref> append_pointwise_module(instruction_ref ins,
input_map[input] = map_ins[param]; input_map[input] = map_ins[param];
} }
} }
pm->replace_return(pm->insert_module_instructions(last, xm, map_ins)); pm->replace_return(pm->insert_instructions(last, xm, map_ins));
return inputs; return inputs;
} }
......
...@@ -81,8 +81,9 @@ struct basic_iota_iterator ...@@ -81,8 +81,9 @@ struct basic_iota_iterator
index--; index--;
return it; return it;
} }
// TODO: operator->
// Dereference: yields f(index). Unchanged between the old (left) and new (right) revision.
reference operator*() const { return f(index); } reference operator*() const { return f(index); }
// Member access (added in the new revision): returns the address of f(index).
// NOTE(review): this is only valid if f(index) yields an lvalue; if f returns by
// value the address-of would not compile or would dangle -- confirm how `reference`
// and `pointer` are defined for this iterator.
pointer operator->() const { return &f(index); }
// Subscript (added in the new revision): yields f(index + n) without modifying the iterator.
reference operator[](int n) const { return f(index + n); }
}; };
template <class T, class F> template <class T, class F>
......
...@@ -120,9 +120,33 @@ struct module ...@@ -120,9 +120,33 @@ struct module
instruction_ref move_instructions(instruction_ref src, instruction_ref dst); instruction_ref move_instructions(instruction_ref src, instruction_ref dst);
// Side-by-side diff of module's instruction-copy declarations.
// LEFT (old): the single insert_module_instructions(ins, m, map_ins) declaration.
// RIGHT (new): three add_instructions overloads and three insert_instructions
// overloads. Each returns std::vector<instruction_ref> and takes an optional
// map_ins (instruction_ref -> instruction_ref) that defaults to empty.
// Sources accepted by each family: a vector of instructions, a const_module_ref,
// or a [start, last) pair of instruction_refs. The insert_* family additionally
// takes a leading `ins` position argument.
// NOTE(review): presumably add_* appends at the end of the module and insert_*
// places the copies before `ins` -- confirm against module.cpp.
std::vector<instruction_ref> std::vector<instruction_ref>
insert_module_instructions(instruction_ref ins, add_instructions(const std::vector<instruction_ref>& instructions,
const_module_ref m, std::unordered_map<instruction_ref, instruction_ref> map_ins = {});
std::unordered_map<instruction_ref, instruction_ref> map_ins = {});
std::vector<instruction_ref>
add_instructions(const_module_ref m,
std::unordered_map<instruction_ref, instruction_ref> map_ins = {});
std::vector<instruction_ref>
add_instructions(instruction_ref start,
instruction_ref last,
std::unordered_map<instruction_ref, instruction_ref> map_ins = {});
std::vector<instruction_ref>
insert_instructions(instruction_ref ins,
const std::vector<instruction_ref>& instructions,
std::unordered_map<instruction_ref, instruction_ref> map_ins = {});
// Same parameter list as the old insert_module_instructions; the call sites shown
// elsewhere in this diff switch to this overload.
std::vector<instruction_ref>
insert_instructions(instruction_ref ins,
const_module_ref m,
std::unordered_map<instruction_ref, instruction_ref> map_ins = {});
std::vector<instruction_ref>
insert_instructions(instruction_ref ins,
instruction_ref start,
instruction_ref last,
std::unordered_map<instruction_ref, instruction_ref> map_ins = {});
template <class... Ts> template <class... Ts>
instruction_ref add_literal(Ts&&... xs) instruction_ref add_literal(Ts&&... xs)
...@@ -179,7 +203,9 @@ struct module ...@@ -179,7 +203,9 @@ struct module
void print_cpp(std::ostream& os) const; void print_cpp(std::ostream& os) const;
// print_cpp overload that takes and returns an instruction -> name map.
// LEFT (old): print_cpp(os, names). RIGHT (new): print_cpp(os, mname, names); the
// new `mname` string is inserted between the stream and the map, so callers of the
// old two-argument form must be updated.
// NOTE(review): mname is presumably the variable name emitted for the module in the
// generated C++ -- confirm in module.cpp.
std::unordered_map<instruction_ref, std::string> std::unordered_map<instruction_ref, std::string>
print_cpp(std::ostream& os, std::unordered_map<instruction_ref, std::string> names) const; print_cpp(std::ostream& os,
const std::string& mname,
std::unordered_map<instruction_ref, std::string> names) const;
void annotate(std::ostream& os, std::function<void(instruction_ref)> a) const; void annotate(std::ostream& os, std::function<void(instruction_ref)> a) const;
......
...@@ -56,14 +56,21 @@ struct nonmaxsuppression ...@@ -56,14 +56,21 @@ struct nonmaxsuppression
shape compute_shape(std::vector<shape> inputs) const shape compute_shape(std::vector<shape> inputs) const
{ {
// requires at least 2 inputs // requires at least 2 inputs
check_shapes{inputs, *this}.standard();
check_shapes{{inputs.at(0), inputs.at(1)}, *this}.only_dims(3); check_shapes{{inputs.at(0), inputs.at(1)}, *this}.only_dims(3);
auto lens = inputs.front().lens(); auto lens = inputs.front().lens();
// check input shape // check input shape
if(lens[1] != inputs.at(1).lens()[2]) if(lens[1] != inputs.at(1).lens()[2])
{ {
MIGRAPHX_THROW("NonMaxSuppression: dimension mismatch between first and second input!"); MIGRAPHX_THROW(
"NonMaxSuppression: spatial dimension mismatch between boxes and scores input");
}
// check batch sizes
if(lens[0] != inputs.at(1).lens()[0])
{
MIGRAPHX_THROW(
"NonMaxSuppression: number of batches mismatch between boxes and scores input");
} }
std::vector<int64_t> out_lens(2); std::vector<int64_t> out_lens(2);
...@@ -74,8 +81,8 @@ struct nonmaxsuppression ...@@ -74,8 +81,8 @@ struct nonmaxsuppression
struct box struct box
{ {
std::array<float, 2> x; std::array<double, 2> x;
std::array<float, 2> y; std::array<double, 2> y;
void sort() void sort()
{ {
...@@ -83,9 +90,9 @@ struct nonmaxsuppression ...@@ -83,9 +90,9 @@ struct nonmaxsuppression
std::sort(y.begin(), y.end()); std::sort(y.begin(), y.end());
} }
std::array<float, 2>& operator[](std::size_t i) { return i == 0 ? x : y; } std::array<double, 2>& operator[](std::size_t i) { return i == 0 ? x : y; }
float area() const double area() const
{ {
assert(std::is_sorted(x.begin(), x.end())); assert(std::is_sorted(x.begin(), x.end()));
assert(std::is_sorted(y.begin(), y.end())); assert(std::is_sorted(y.begin(), y.end()));
...@@ -94,29 +101,29 @@ struct nonmaxsuppression ...@@ -94,29 +101,29 @@ struct nonmaxsuppression
}; };
// batch_box: decodes one box from a flat sequence that stores 4 values per box.
// (Side-by-side diff: LEFT = old revision, RIGHT = new revision.)
//   boxes           - old: const T* to the first box; new: any T supporting `+ n`
//                     and `[i]` (an iterator or pointer), passed by value.
//   bidx / box_idx  - index of the box; its values start at boxes + 4 * index.
// Returns a `box` whose x and y members are two-element intervals:
//   - center_point_box set: the 4 values are read as x_center, y_center, width,
//     height, and each interval is center -/+ half the extent;
//   - otherwise: x = {value[1], value[3]} and y = {value[0], value[2]}.
// The intervals are not sorted here. The new revision does the arithmetic in double
// instead of float.
template <class T> template <class T>
box batch_box(const T* boxes, std::size_t bidx) const box batch_box(T boxes, std::size_t box_idx) const
{ {
box result{}; box result{};
const T* start = boxes + 4 * bidx; auto start = boxes + 4 * box_idx;
if(center_point_box) if(center_point_box)
{ {
float half_width = start[2] / 2.0f; double half_width = start[2] / 2.0;
float half_height = start[3] / 2.0f; double half_height = start[3] / 2.0;
float x_center = start[0]; double x_center = start[0];
float y_center = start[1]; double y_center = start[1];
result.x = {x_center - half_width, x_center + half_width}; result.x = {x_center - half_width, x_center + half_width};
result.y = {y_center - half_height, y_center + half_height}; result.y = {y_center - half_height, y_center + half_height};
} }
else else
{ {
result.x = {start[1], start[3]}; result.x = {static_cast<double>(start[1]), static_cast<double>(start[3])};
result.y = {start[0], start[2]}; result.y = {static_cast<double>(start[0]), static_cast<double>(start[2])};
} }
return result; return result;
} }
inline bool suppress_by_iou(box b1, box b2, float iou_threshold) const inline bool suppress_by_iou(box b1, box b2, double iou_threshold) const
{ {
b1.sort(); b1.sort();
b2.sort(); b2.sort();
...@@ -128,7 +135,7 @@ struct nonmaxsuppression ...@@ -128,7 +135,7 @@ struct nonmaxsuppression
intersection[i][1] = std::min(b1[i][1], b2[i][1]); intersection[i][1] = std::min(b1[i][1], b2[i][1]);
} }
std::vector<std::array<float, 2>> bbox = {intersection.x, intersection.y}; std::vector<std::array<double, 2>> bbox = {intersection.x, intersection.y};
if(std::any_of(bbox.begin(), bbox.end(), [](auto bx) { if(std::any_of(bbox.begin(), bbox.end(), [](auto bx) {
return not std::is_sorted(bx.begin(), bx.end()); return not std::is_sorted(bx.begin(), bx.end());
})) }))
...@@ -136,115 +143,124 @@ struct nonmaxsuppression ...@@ -136,115 +143,124 @@ struct nonmaxsuppression
return false; return false;
} }
const float area1 = b1.area(); const double area1 = b1.area();
const float area2 = b2.area(); const double area2 = b2.area();
const float intersection_area = intersection.area(); const double intersection_area = intersection.area();
const float union_area = area1 + area2 - intersection_area; const double union_area = area1 + area2 - intersection_area;
if(area1 <= .0f or area2 <= .0f or union_area <= .0f) if(area1 <= .0f or area2 <= .0f or union_area <= .0f)
{ {
return false; return false;
} }
const float intersection_over_union = intersection_area / union_area; const double intersection_over_union = intersection_area / union_area;
return intersection_over_union > iou_threshold; return intersection_over_union > iou_threshold;
} }
// Side-by-side diff. LEFT column: the beginning of the OLD compute() (allocates the
// result, zero-fills it, then starts reading the optional arguments). RIGHT column:
// the NEW helper filter_boxes_by_score, documented here.
//
// filter_boxes_by_score(scores_start, num_boxes, score_threshold):
//   Walks num_boxes scores starting at scores_start and returns a
//   std::priority_queue of (score, box index) pairs for every score that is
//   >= score_threshold.
//   box_idx is incremented inside the predicate and read as box_idx - 1 inside the
//   transform, so the recorded index is the position of the score just tested.
//   NOTE(review): this assumes transform_if calls the predicate and then the
//   transform for one element before moving to the next -- confirm in the
//   algorithm header.
argument compute(const shape& output_shape, std::vector<argument> args) const // filter boxes below score_threshold
template <class T>
std::priority_queue<std::pair<double, int64_t>>
filter_boxes_by_score(T scores_start, std::size_t num_boxes, double score_threshold) const
{ {
argument result{output_shape}; std::priority_queue<std::pair<double, int64_t>> boxes_heap;
auto insert_to_boxes_heap =
result.visit([&](auto out) { std::fill(out.begin(), out.end(), 0); }); make_function_output_iterator([&](const auto& x) { boxes_heap.push(x); });
int64_t box_idx = 0;
std::size_t max_output_boxes_per_class = 0; transform_if(
float iou_threshold = 0.0f; scores_start,
float score_threshold = 0.0f; scores_start + num_boxes,
insert_to_boxes_heap,
if(args.size() > 2) [&](auto sc) {
{ box_idx++;
max_output_boxes_per_class = args.at(2).at<std::size_t>(); return sc >= score_threshold;
} },
// max_output_boxes_per_class is 0, no output [&](auto sc) { return std::make_pair(sc, box_idx - 1); });
if(max_output_boxes_per_class == 0) return boxes_heap;
{ }
return result;
}
if(args.size() > 3)
{
iou_threshold = args.at(3).at<float>();
}
if(args.size() > 4)
{
score_threshold = args.at(4).at<float>();
}
const auto& lens = args.at(1).get_shape().lens();
auto batch_num = lens[0];
auto class_num = lens[1];
auto box_num = args.at(0).get_shape().lens()[1];
// Side-by-side diff. LEFT column: the middle of the OLD compute(), which ran the
// selection loop inline on float data. RIGHT column: the NEW helper compute_nms,
// documented here.
//
// compute_nms(output, boxes, scores, output_shape, max_output_boxes_per_class,
//             iou_threshold, score_threshold):
//   1. Zero-fills `output`.
//   2. Reads num_batches, num_classes and num_boxes from the first three lens of
//      scores' shape.
//   3. For every (batch, class) pair (iterated with shape_for_each over comp_s):
//        - builds a heap of candidate (score, index) pairs with
//          filter_boxes_by_score;
//        - repeatedly takes the top-scoring candidate and keeps it unless
//          suppress_by_iou reports it against a box already kept for this class;
//        - stops when the heap is empty or max_output_boxes_per_class boxes are kept.
//      Each kept box appends three values to selected_indices:
//      batch index, class index, box index.
//   4. Copies selected_indices to the start of `output`.
// NOTE(review): no bound check is visible before the final std::copy; presumably
// output_shape is sized for the maximum number of selections -- confirm in
// compute_shape.
// NOTE(review): comp_s appears to be used only for its lens; the element type
// (float_type old, double_type new) presumably does not affect the iteration.
std::vector<std::pair<float, int64_t>> selected_boxes_inside_class; template <class Output, class Boxes, class Scores>
void compute_nms(Output output,
Boxes boxes,
Scores scores,
const shape& output_shape,
std::size_t max_output_boxes_per_class,
double iou_threshold,
double score_threshold) const
{
std::fill(output.begin(), output.end(), 0);
const auto& lens = scores.get_shape().lens();
const auto num_batches = lens[0];
const auto num_classes = lens[1];
const auto num_boxes = lens[2];
// boxes of a class with NMS applied [score, index]
std::vector<std::pair<double, int64_t>> selected_boxes_inside_class;
std::vector<int64_t> selected_indices; std::vector<int64_t> selected_indices;
selected_boxes_inside_class.reserve(output_shape.elements()); selected_boxes_inside_class.reserve(output_shape.elements());
// iterate over batches and classes
auto scores = make_view<float>(args.at(1).get_shape(), args.at(1).cast<float>()); shape comp_s{shape::double_type, {num_batches, num_classes}};
const float* boxes = args.at(0).cast<float>();
shape comp_s{shape::float_type, {batch_num, class_num}};
shape_for_each(comp_s, [&](auto idx) { shape_for_each(comp_s, [&](auto idx) {
auto bidx = idx[0]; auto batch_idx = idx[0];
auto cidx = idx[1]; auto class_idx = idx[1];
// index offset for this class
std::size_t score_offset = (bidx * class_num + cidx) * box_num; auto scores_start = scores.begin() + (batch_idx * num_classes + class_idx) * num_boxes;
const float* batch_boxes = boxes + bidx * box_num * 4; // iterator to first value of this batch
std::priority_queue<std::pair<float, int64_t>> sorted_boxes; auto batch_boxes_start = boxes.begin() + batch_idx * num_boxes * 4;
auto insert_to_sorted_boxes = auto boxes_heap = filter_boxes_by_score(scores_start, num_boxes, score_threshold);
make_function_output_iterator([&](const auto& x) { sorted_boxes.push(x); });
int64_t box_idx = 0;
transform_if(
scores.begin() + score_offset,
scores.begin() + score_offset + box_num,
insert_to_sorted_boxes,
[&](auto sc) {
box_idx++;
return sc >= score_threshold;
},
[&](auto sc) { return std::make_pair(sc, box_idx - 1); });
selected_boxes_inside_class.clear(); selected_boxes_inside_class.clear();
// Get the next box with top score, filter by iou_threshold // Get the next box with top score, filter by iou_threshold
while(!sorted_boxes.empty() && while(!boxes_heap.empty() &&
selected_boxes_inside_class.size() < max_output_boxes_per_class) selected_boxes_inside_class.size() < max_output_boxes_per_class)
{ {
const std::pair<float, int64_t>& next_top_score = sorted_boxes.top(); // Check with existing selected boxes for this class, remove box if it
// exceeds the IOU (Intersection Over Union) threshold
// Check with existing selected boxes for this class, suppress if exceed the IOU const auto next_top_score = boxes_heap.top();
// (Intersection Over Union) threshold bool not_selected =
bool not_selected = std::any_of( std::any_of(selected_boxes_inside_class.begin(),
selected_boxes_inside_class.begin(), selected_boxes_inside_class.end(),
selected_boxes_inside_class.end(), [&](auto selected_index) {
[&](auto selected_index) { return this->suppress_by_iou(
return this->suppress_by_iou(batch_box(batch_boxes, next_top_score.second), batch_box(batch_boxes_start, next_top_score.second),
batch_box(batch_boxes, selected_index.second), batch_box(batch_boxes_start, selected_index.second),
iou_threshold); iou_threshold);
}); });
if(not not_selected) if(not not_selected)
{ {
selected_boxes_inside_class.push_back(next_top_score); selected_boxes_inside_class.push_back(next_top_score);
selected_indices.push_back(bidx); selected_indices.push_back(batch_idx);
selected_indices.push_back(cidx); selected_indices.push_back(class_idx);
selected_indices.push_back(next_top_score.second); selected_indices.push_back(next_top_score.second);
} }
sorted_boxes.pop(); boxes_heap.pop();
} }
}); });
std::copy(selected_indices.begin(), selected_indices.end(), output.begin());
}
argument compute(const shape& output_shape, std::vector<argument> args) const
{
argument result{output_shape};
result.visit([&](auto out) { std::size_t max_output_boxes_per_class =
std::copy(selected_indices.begin(), selected_indices.end(), out.begin()); (args.size() > 2) ? (args.at(2).at<std::size_t>()) : 0;
if(max_output_boxes_per_class == 0)
{
return result;
}
double iou_threshold = (args.size() > 3) ? (args.at(3).at<double>()) : 0.0f;
double score_threshold = (args.size() > 4) ? (args.at(4).at<double>()) : 0.0f;
result.visit([&](auto output) {
visit_all(args[0], args[1])([&](auto boxes, auto scores) {
compute_nms(output,
boxes,
scores,
output_shape,
max_output_boxes_per_class,
iou_threshold,
score_threshold);
});
}); });
return result; return result;
......
...@@ -44,8 +44,8 @@ auto with_char(F f) ...@@ -44,8 +44,8 @@ auto with_char(F f)
return [=](unsigned char c) -> bool { return f(c); }; return [=](unsigned char c) -> bool { return f(c); };
} }
inline std::string inline void
replace_string(std::string subject, const std::string& search, const std::string& replace) replace_string_inplace(std::string& subject, const std::string& search, const std::string& replace)
{ {
size_t pos = 0; size_t pos = 0;
while((pos = subject.find(search, pos)) != std::string::npos) while((pos = subject.find(search, pos)) != std::string::npos)
...@@ -53,6 +53,12 @@ replace_string(std::string subject, const std::string& search, const std::string ...@@ -53,6 +53,12 @@ replace_string(std::string subject, const std::string& search, const std::string
subject.replace(pos, search.length(), replace); subject.replace(pos, search.length(), replace);
pos += replace.length(); pos += replace.length();
} }
}
inline std::string
replace_string(std::string subject, const std::string& search, const std::string& replace)
{
replace_string_inplace(subject, search, replace);
return subject; return subject;
} }
......
...@@ -35,7 +35,7 @@ static void inline_submodule(module& m, instruction_ref ins, bool cond) ...@@ -35,7 +35,7 @@ static void inline_submodule(module& m, instruction_ref ins, bool cond)
{ {
const auto& mod_inputs = ins->module_inputs(); const auto& mod_inputs = ins->module_inputs();
module_ref smod = cond ? mod_inputs.at(0) : mod_inputs.at(1); module_ref smod = cond ? mod_inputs.at(0) : mod_inputs.at(1);
auto mod_outputs = m.insert_module_instructions(ins, smod); auto mod_outputs = m.insert_instructions(ins, smod);
auto ins_outputs = ins->outputs(); auto ins_outputs = ins->outputs();
assert(mod_outputs.size() >= ins_outputs.size()); assert(mod_outputs.size() >= ins_outputs.size());
......
...@@ -35,6 +35,7 @@ ...@@ -35,6 +35,7 @@
#include <migraphx/make_op.hpp> #include <migraphx/make_op.hpp>
#include <migraphx/register_target.hpp> #include <migraphx/register_target.hpp>
#include <migraphx/make_op.hpp> #include <migraphx/make_op.hpp>
#include <migraphx/json.hpp>
#include <iostream> #include <iostream>
#include <sstream> #include <sstream>
#include <algorithm> #include <algorithm>
...@@ -196,6 +197,62 @@ void module::assign(const module& m) ...@@ -196,6 +197,62 @@ void module::assign(const module& m)
} }
} }
// Copies every instruction in `instructions` into module `m`.
//
// Ordinary instructions are inserted before `ins`; literals, parameters and
// outlines are created through the module's add_* functions. `map_ins` maps
// source instructions to their counterparts in `m`: entries supplied by the
// caller are reused instead of copied, and each new copy is recorded so later
// instructions can have their inputs rewired.
//
// Returns the remapped inputs of the first "@return" encountered (copying
// stops there); otherwise the counterpart of the last instruction visited, or
// an empty vector when the range is empty.
template <class Range>
static std::vector<instruction_ref>
insert_generic_instructions(module& m,
                            instruction_ref ins,
                            Range&& instructions,
                            std::unordered_map<instruction_ref, instruction_ref> map_ins)
{
    assert(m.has_instruction(ins) or is_end(ins, m.end()));
    std::vector<instruction_ref> outputs;
    instruction_ref final_ins;
    for(instruction_ref src : instructions)
    {
        // Remember the most recently visited instruction, even when it is skipped.
        final_ins = src;
        if(contains(map_ins, src))
            continue;

        const std::string op_name = src->name();
        instruction_ref dst;
        if(op_name == "@literal")
        {
            dst = m.add_literal(src->get_literal());
        }
        else if(op_name == "@param")
        {
            std::string param_name = any_cast<builtin::param>(src->get_operator()).parameter;
            dst                    = m.add_parameter(param_name, src->get_shape());
        }
        else if(op_name == "@outline")
        {
            dst = m.add_outline(src->get_shape());
        }
        else
        {
            auto sub_modules = src->module_inputs();
            auto src_inputs  = src->inputs();
            // Redirect inputs that already have a counterpart in `m`; leave the rest as-is.
            std::vector<instruction_ref> new_inputs;
            new_inputs.reserve(src_inputs.size());
            for(auto input : src_inputs)
            {
                auto found = map_ins.find(input);
                new_inputs.push_back(found == map_ins.end() ? input : found->second);
            }

            // A return is not copied: its remapped inputs become the result.
            if(op_name == "@return")
            {
                outputs = new_inputs;
                break;
            }

            dst = m.insert_instruction(ins, src->get_operator(), new_inputs, sub_modules);
        }
        map_ins[src] = dst;
    }
    // No explicit return seen: fall back to the last instruction's counterpart.
    if(outputs.empty() and instructions.begin() != instructions.end())
        outputs = {map_ins.at(final_ins)};
    return outputs;
}
instruction_ref module::add_instruction(const operation& op, std::vector<instruction_ref> args) instruction_ref module::add_instruction(const operation& op, std::vector<instruction_ref> args)
{ {
return insert_instruction(impl->instructions.end(), op, std::move(args)); return insert_instruction(impl->instructions.end(), op, std::move(args));
...@@ -335,54 +392,51 @@ instruction_ref module::move_instructions(instruction_ref src, instruction_ref d ...@@ -335,54 +392,51 @@ instruction_ref module::move_instructions(instruction_ref src, instruction_ref d
} }
std::vector<instruction_ref> std::vector<instruction_ref>
module::insert_module_instructions(instruction_ref ins, module::add_instructions(const std::vector<instruction_ref>& instructions,
const_module_ref m, std::unordered_map<instruction_ref, instruction_ref> map_ins)
std::unordered_map<instruction_ref, instruction_ref> map_ins)
{ {
std::vector<instruction_ref> mod_outputs; return this->insert_instructions(this->end(), instructions, std::move(map_ins));
for(auto sins : iterator_for(*m)) }
{
if(contains(map_ins, sins))
continue;
instruction_ref copy_ins;
if(sins->name() == "@literal")
{
auto l = sins->get_literal();
copy_ins = this->add_literal(l);
}
else if(sins->name() == "@param")
{
auto&& name = any_cast<builtin::param>(sins->get_operator()).parameter;
auto s = sins->get_shape();
copy_ins = this->add_parameter(name, s);
}
else if(sins->name() == "@outline")
{
auto s = sins->get_shape();
copy_ins = this->add_outline(s);
}
else
{
auto mod_args = sins->module_inputs();
auto inputs = sins->inputs();
std::vector<instruction_ref> copy_inputs(inputs.size());
std::transform(inputs.begin(), inputs.end(), copy_inputs.begin(), [&](auto i) {
return contains(map_ins, i) ? map_ins[i] : i;
});
if(sins->name() == "@return") std::vector<instruction_ref>
{ module::add_instructions(const_module_ref m,
mod_outputs = copy_inputs; std::unordered_map<instruction_ref, instruction_ref> map_ins)
break; {
} return this->insert_instructions(this->end(), m, std::move(map_ins));
}
copy_ins = this->insert_instruction(ins, sins->get_operator(), copy_inputs, mod_args); std::vector<instruction_ref>
} module::add_instructions(instruction_ref start,
map_ins[sins] = copy_ins; instruction_ref last,
} std::unordered_map<instruction_ref, instruction_ref> map_ins)
if(mod_outputs.empty()) {
mod_outputs = {map_ins.at(std::prev(m->end()))}; return this->insert_instructions(this->end(), start, last, std::move(map_ins));
return mod_outputs; }
std::vector<instruction_ref>
module::insert_instructions(instruction_ref ins,
const std::vector<instruction_ref>& instructions,
std::unordered_map<instruction_ref, instruction_ref> map_ins)
{
return insert_generic_instructions(*this, ins, instructions, std::move(map_ins));
}
std::vector<instruction_ref>
module::insert_instructions(instruction_ref ins,
const_module_ref m,
std::unordered_map<instruction_ref, instruction_ref> map_ins)
{
return insert_generic_instructions(*this, ins, iterator_for(*m), std::move(map_ins));
}
std::vector<instruction_ref>
module::insert_instructions(instruction_ref ins,
instruction_ref start,
instruction_ref last,
std::unordered_map<instruction_ref, instruction_ref> map_ins)
{
auto r = range(start, last);
return insert_generic_instructions(*this, ins, iterator_for(r), std::move(map_ins));
} }
instruction_ref module::add_literal(literal l) instruction_ref module::add_literal(literal l)
...@@ -708,44 +762,33 @@ void module::print_graph(std::ostream& os, bool brief) const ...@@ -708,44 +762,33 @@ void module::print_graph(std::ostream& os, bool brief) const
os << "}" << std::endl; os << "}" << std::endl;
} }
static std::string cpp_var_name(const std::string& name) static std::string to_c_id(const std::string& name, char rep = '_')
{ {
return "m" + replace_string(name, "@", "x"); std::string id = transform_string(name, [&](auto c) {
if(with_char(::isalnum)(c) or c == '_')
return c;
return rep;
});
while(contains(id, "__"))
replace_string_inplace(id, "__", "_");
return id;
} }
static std::string cpp_op_var(const std::string& name, instruction_ref ins) static std::string cpp_var_name(const std::string& name)
{ {
return replace_string(name, "@", ins->name()); return to_c_id("x_" + replace_string(name, ":", "_module_"));
} }
static void print_op_attributes(std::ostream& os, const std::string& name, const operation& op) static void print_make_op(std::ostream& os, const operation& op)
{ {
std::string x = to_string(op); os << "migraphx::make_op(" << enclose_name(op.name());
if(contains(x, "[")) auto v = op.to_value();
if(not v.empty())
{ {
auto start = x.find('['); os << ", "
auto end = x.find(']'); << "migraphx::from_json_string(" << enclose_name(to_json_string(v)) << ")";
std::string attribute_text = x.substr(start + 1, end - start - 1);
std::vector<std::string> attributes;
for(auto&& attribute : split_string(attribute_text, ','))
{
if(contains(attribute, '='))
attributes.push_back(attribute);
else
attributes.back() += "," + attribute;
}
for(auto&& attribute : attributes)
{
auto p = split_string(attribute, '=');
auto key = p.front();
auto value = p.back();
if(contains({"bn_mode", "padding_mode"}, key))
continue;
if(key == "mode")
value = enclose_name(trim(value));
os << name << "." << key << " = " << value << ";" << std::endl;
}
} }
os << ")";
} }
static void print_cpp_shape(std::ostream& os, const migraphx::shape& s) static void print_cpp_shape(std::ostream& os, const migraphx::shape& s)
...@@ -758,22 +801,25 @@ static void print_cpp_shape(std::ostream& os, const migraphx::shape& s) ...@@ -758,22 +801,25 @@ static void print_cpp_shape(std::ostream& os, const migraphx::shape& s)
} }
std::unordered_map<instruction_ref, std::string> std::unordered_map<instruction_ref, std::string>
module::print_cpp(std::ostream& os, std::unordered_map<instruction_ref, std::string> names) const module::print_cpp(std::ostream& os,
const std::string& mname,
std::unordered_map<instruction_ref, std::string> names) const
{ {
os << "migraphx::module p;" << std::endl; // cppcheck-suppress variableScope
unsigned long seed = 0; unsigned long seed = names.size();
auto last = std::prev(this->end());
names = this->print( names = this->print(
[&](auto ins, auto ins_names) { [&](auto ins, auto ins_names) {
auto op = cpp_op_var(ins_names.at(ins), ins); std::vector<std::string> input_vars;
if(ins->name().front() != '@') std::transform(ins->inputs().begin(),
{ ins->inputs().end(),
os << "migraphx::op::" << ins->name() << " " << op << ";" << std::endl; std::back_inserter(input_vars),
print_op_attributes(os, op, ins->get_operator()); [&](auto input) { return cpp_var_name(ins_names.at(input)); });
} if(ins != last)
os << "auto " << cpp_var_name(ins_names.at(ins)) << " = "; os << "auto " << cpp_var_name(ins_names.at(ins)) << " = ";
if(ins->name() == "@literal") if(ins->name() == "@literal")
{ {
os << "p.add_literal("; os << mname << "->add_literal(";
bool use_abs = false; bool use_abs = false;
ins->get_literal().visit([&](auto v) { ins->get_literal().visit([&](auto v) {
use_abs = std::none_of(v.begin(), v.end(), [](auto x) { return x < 0; }); use_abs = std::none_of(v.begin(), v.end(), [](auto x) { return x < 0; });
...@@ -791,17 +837,22 @@ module::print_cpp(std::ostream& os, std::unordered_map<instruction_ref, std::str ...@@ -791,17 +837,22 @@ module::print_cpp(std::ostream& os, std::unordered_map<instruction_ref, std::str
else if(ins->name() == "@param") else if(ins->name() == "@param")
{ {
std::string name = any_cast<builtin::param>(ins->get_operator()).parameter; std::string name = any_cast<builtin::param>(ins->get_operator()).parameter;
os << "p.add_parameter(" << enclose_name(name) << ","; os << mname << "->add_parameter(" << enclose_name(name) << ",";
print_cpp_shape(os, ins->get_shape()); print_cpp_shape(os, ins->get_shape());
os << ");" << std::endl; os << ");" << std::endl;
} }
else if(ins->name() == "@return")
{
os << mname << "->add_return({";
os << join_strings(input_vars, ", ");
os << "});" << std::endl;
}
else else
{ {
os << "p.add_instruction(" << op; assert(ins->name().front() != '@');
for(auto input : ins->inputs()) os << mname << "->add_instruction(";
{ print_make_op(os, ins->get_operator());
os << ", " << cpp_var_name(ins_names.at(input)); os << ", " << join_strings(input_vars, ", ");
}
os << ");" << std::endl; os << ");" << std::endl;
} }
}, },
...@@ -810,7 +861,7 @@ module::print_cpp(std::ostream& os, std::unordered_map<instruction_ref, std::str ...@@ -810,7 +861,7 @@ module::print_cpp(std::ostream& os, std::unordered_map<instruction_ref, std::str
return names; return names;
} }
void module::print_cpp(std::ostream& os) const { this->print_cpp(os, {}); } void module::print_cpp(std::ostream& os) const { this->print_cpp(os, this->name(), {}); }
void module::annotate(std::ostream& os, std::function<void(instruction_ref)> a) const void module::annotate(std::ostream& os, std::function<void(instruction_ref)> a) const
{ {
......
...@@ -790,10 +790,17 @@ void program::print_cpp(std::ostream& os) const ...@@ -790,10 +790,17 @@ void program::print_cpp(std::ostream& os) const
{ {
auto vec_modules = this->get_modules(); auto vec_modules = this->get_modules();
std::unordered_map<instruction_ref, std::string> names; std::unordered_map<instruction_ref, std::string> names;
os << "migraphx::program p;\n";
for(auto& mod : vec_modules) for(auto& mod : vec_modules)
{ {
os << "module: \"" << mod->name() << "\"" << std::endl; std::string var_name = "m" + mod->name();
names = mod->print_cpp(os, names); os << "migraphx::module_ref " << var_name << " = ";
if(mod->name() == "main")
os << "p.get_main_module();";
else
os << "p.create_module(\"" << mod->name() << "\");";
os << std::endl;
names = mod->print_cpp(os, var_name, names);
os << std::endl; os << std::endl;
} }
} }
......
...@@ -108,7 +108,7 @@ struct find_conv_pointwise ...@@ -108,7 +108,7 @@ struct find_conv_pointwise
return std::make_pair(pm->get_parameter(name), return std::make_pair(pm->get_parameter(name),
mm->add_parameter(name, input->get_shape())); mm->add_parameter(name, input->get_shape()));
}); });
mm->add_return(mm->insert_module_instructions(mm->end(), pm, param_map)); mm->add_return(mm->insert_instructions(mm->end(), pm, param_map));
std::vector<instruction_ref> inputs; std::vector<instruction_ref> inputs;
std::copy_if(ins->inputs().begin(), std::copy_if(ins->inputs().begin(),
......
...@@ -132,7 +132,7 @@ migraphx::argument run_ref(migraphx::program p, const migraphx::parameter_map& i ...@@ -132,7 +132,7 @@ migraphx::argument run_ref(migraphx::program p, const migraphx::parameter_map& i
bool verify_mlir(const migraphx::module& mmlir) bool verify_mlir(const migraphx::module& mmlir)
{ {
migraphx::program ref; migraphx::program ref;
ref.get_main_module()->insert_module_instructions(ref.get_main_module()->end(), &mmlir); ref.get_main_module()->insert_instructions(ref.get_main_module()->end(), &mmlir);
auto inputs = generate_params(ref); auto inputs = generate_params(ref);
......
...@@ -300,6 +300,96 @@ TEST_CASE(parameter_name_order) ...@@ -300,6 +300,96 @@ TEST_CASE(parameter_name_order)
EXPECT(param_names == names1); EXPECT(param_names == names1);
} }
// Inserting a whole module before an existing instruction, with the source
// parameter mapped onto a parameter of the destination module.
TEST_CASE(insert_instructions_module)
{
    migraphx::shape one_int{migraphx::shape::int32_type, {1}};

    // Destination module: x1 -> sqrt -> add(sqrt, x1)
    migraphx::module dst("m1");
    auto dst_param = dst.add_parameter("x1", one_int);
    auto dst_sqrt  = dst.add_instruction(migraphx::make_op("sqrt"), {dst_param});
    dst.add_instruction(migraphx::make_op("add"), {dst_sqrt, dst_param});

    // Source module: x2 -> sqrt
    migraphx::module src("m2");
    auto src_param = src.add_parameter("x2", one_int);
    src.add_instruction(migraphx::make_op("sqrt"), {src_param});

    // Insert before the existing sqrt, substituting x1 for x2.
    dst.insert_instructions(dst_sqrt, &src, {{src_param, dst_param}});

    const auto count_named = [&](const std::string& op_name) {
        return std::count_if(
            dst.begin(), dst.end(), [&](auto&& ins) { return ins.name() == op_name; });
    };

    // The instruction directly before the original sqrt is also a sqrt.
    EXPECT(std::prev(dst_sqrt)->name() == "sqrt");
    EXPECT(count_named("sqrt") == 2);
    // The mapped parameter must not have been duplicated.
    EXPECT(count_named("@param") == 1);
    EXPECT(contains(dst.get_parameter_shapes(), "x1"));
    EXPECT(not contains(dst.get_parameter_shapes(), "x2"));
}
// Adding a whole module to another module, with the source parameter mapped
// onto a parameter of the destination module.
TEST_CASE(add_instructions_module)
{
    migraphx::shape one_int{migraphx::shape::int32_type, {1}};

    // Destination module: x1 -> sqrt
    migraphx::module dst("m1");
    auto dst_param = dst.add_parameter("x1", one_int);
    dst.add_instruction(migraphx::make_op("sqrt"), {dst_param});

    // Source module: x2 -> sqrt
    migraphx::module src("m2");
    auto src_param = src.add_parameter("x2", one_int);
    src.add_instruction(migraphx::make_op("sqrt"), {src_param});

    dst.add_instructions(&src, {{src_param, dst_param}});

    const auto count_named = [&](const std::string& op_name) {
        return std::count_if(
            dst.begin(), dst.end(), [&](auto&& ins) { return ins.name() == op_name; });
    };

    EXPECT(count_named("sqrt") == 2);
    // The mapped parameter must not have been duplicated.
    EXPECT(count_named("@param") == 1);
    EXPECT(contains(dst.get_parameter_shapes(), "x1"));
    EXPECT(not contains(dst.get_parameter_shapes(), "x2"));
}
// Adding an iterator range [sqrt, end) taken from another module, with the
// source parameter mapped onto a parameter of the destination module.
TEST_CASE(add_instructions_range)
{
    migraphx::shape one_int{migraphx::shape::int32_type, {1}};

    // Destination module: x1 -> sqrt
    migraphx::module dst("m1");
    auto dst_param = dst.add_parameter("x1", one_int);
    dst.add_instruction(migraphx::make_op("sqrt"), {dst_param});

    // Source module: x2 -> sqrt
    migraphx::module src("m2");
    auto src_param = src.add_parameter("x2", one_int);
    auto src_sqrt  = src.add_instruction(migraphx::make_op("sqrt"), {src_param});

    // The range starts at the sqrt, so the source parameter itself is not part of it.
    dst.add_instructions(src_sqrt, src.end(), {{src_param, dst_param}});

    const auto count_named = [&](const std::string& op_name) {
        return std::count_if(
            dst.begin(), dst.end(), [&](auto&& ins) { return ins.name() == op_name; });
    };

    // At least one instruction in the destination reads from x1.
    EXPECT(std::any_of(dst.begin(), dst.end(), [&](auto&& ins) {
        return migraphx::contains(ins.inputs(), dst_param);
    }));
    EXPECT(count_named("sqrt") == 2);
    EXPECT(count_named("@param") == 1);
    EXPECT(contains(dst.get_parameter_shapes(), "x1"));
    EXPECT(not contains(dst.get_parameter_shapes(), "x2"));
}
// Adding an explicit list of instructions taken from another module, with the
// source parameter mapped onto a parameter of the destination module.
TEST_CASE(add_instructions_vector)
{
    migraphx::shape one_int{migraphx::shape::int32_type, {1}};

    // Destination module: x1 -> sqrt
    migraphx::module dst("m1");
    auto dst_param = dst.add_parameter("x1", one_int);
    dst.add_instruction(migraphx::make_op("sqrt"), {dst_param});

    // Source module: x2 -> sqrt
    migraphx::module src("m2");
    auto src_param = src.add_parameter("x2", one_int);
    auto src_sqrt  = src.add_instruction(migraphx::make_op("sqrt"), {src_param});

    // Only the sqrt is listed; its input is redirected through the map.
    dst.add_instructions({src_sqrt}, {{src_param, dst_param}});

    const auto count_named = [&](const std::string& op_name) {
        return std::count_if(
            dst.begin(), dst.end(), [&](auto&& ins) { return ins.name() == op_name; });
    };

    // At least one instruction in the destination reads from x1.
    EXPECT(std::any_of(dst.begin(), dst.end(), [&](auto&& ins) {
        return migraphx::contains(ins.inputs(), dst_param);
    }));
    EXPECT(count_named("sqrt") == 2);
    EXPECT(count_named("@param") == 1);
    EXPECT(contains(dst.get_parameter_shapes(), "x1"));
    EXPECT(not contains(dst.get_parameter_shapes(), "x2"));
}
struct check_for_pass_op struct check_for_pass_op
{ {
bool* found = nullptr; bool* found = nullptr;
......
...@@ -3187,6 +3187,80 @@ TEST_CASE(nms_test) ...@@ -3187,6 +3187,80 @@ TEST_CASE(nms_test)
EXPECT(migraphx::verify_range(result, gold)); EXPECT(migraphx::verify_range(result, gold));
} }
// Runs nonmaxsuppression (center_point_box = 1) on the ref target where the
// boxes input is the output of a transpose instruction rather than a literal
// fed in directly.
TEST_CASE(nms_transpose1_test)
{
migraphx::program p;
auto* mm = p.get_main_module();
// Boxes are stored as a {1, 4, 6} literal and transposed below with
// permutation {0, 2, 1} before being passed to the operator.
migraphx::shape boxes_s{migraphx::shape::float_type, {1, 4, 6}};
std::vector<float> boxes_vec = {
0.5, 0.5, 0.5, 0.5, 0.5, 0.5, 0.5, 0.6, 0.4, 10.5, 10.6, 100.5,
1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0,
};
migraphx::shape scores_s{migraphx::shape::float_type, {1, 1, 6}};
std::vector<float> scores_vec = {0.9, 0.75, 0.6, 0.95, 0.5, 0.3};
auto t_boxes_l = mm->add_literal(migraphx::literal(boxes_s, boxes_vec));
auto scores_l = mm->add_literal(migraphx::literal(scores_s, scores_vec));
// Third, fourth and fifth operator inputs: max output boxes per class,
// IOU threshold and score threshold.
auto max_out_l = mm->add_literal(int64_t{4});
auto iou_threshold = mm->add_literal(0.5f);
auto score_threshold = mm->add_literal(0.0f);
auto transpose_boxes = mm->add_instruction(
migraphx::make_op("transpose", {{"permutation", {0, 2, 1}}}), t_boxes_l);
auto r = mm->add_instruction(migraphx::make_op("nonmaxsuppression", {{"center_point_box", 1}}),
transpose_boxes,
scores_l,
max_out_l,
iou_threshold,
score_threshold);
mm->add_return({r});
p.compile(migraphx::ref::target{});
auto output = p.eval({}).back();
std::vector<int64_t> result;
output.visit([&](auto out) { result.assign(out.begin(), out.end()); });
// 18 expected values; only the leading entries 0, 0, 3 and 0, 0, 5 are non-zero.
// NOTE(review): presumably (batch, class, box index) triples with the unused
// slots left at zero - confirm against the operator's output layout.
std::vector<int64_t> gold = {0, 0, 3, 0, 0, 0, 0, 0, 5, 0, 0, 0, 0, 0, 0, 0, 0, 0};
EXPECT(migraphx::verify_range(result, gold));
}
// Runs nonmaxsuppression (center_point_box = 1) on the ref target where the
// boxes input is the output of a transpose instruction; this variant uses a
// {4, 1, 6} boxes literal and permutation {1, 2, 0}.
TEST_CASE(nms_transpose2_test)
{
migraphx::program p;
auto* mm = p.get_main_module();
// Boxes are stored as a {4, 1, 6} literal and transposed below with
// permutation {1, 2, 0} before being passed to the operator.
migraphx::shape boxes_s{migraphx::shape::float_type, {4, 1, 6}};
std::vector<float> boxes_vec = {
0.5, 0.5, 0.5, 0.5, 0.5, 0.5, 0.5, 0.6, 0.4, 10.5, 10.6, 100.5,
1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0,
};
migraphx::shape scores_s{migraphx::shape::float_type, {1, 1, 6}};
std::vector<float> scores_vec = {0.9, 0.75, 0.6, 0.95, 0.5, 0.3};
auto t_boxes_l = mm->add_literal(migraphx::literal(boxes_s, boxes_vec));
auto scores_l = mm->add_literal(migraphx::literal(scores_s, scores_vec));
// Third, fourth and fifth operator inputs: max output boxes per class,
// IOU threshold and score threshold.
auto max_out_l = mm->add_literal(int64_t{4});
auto iou_threshold = mm->add_literal(0.5f);
auto score_threshold = mm->add_literal(0.0f);
auto transpose_boxes = mm->add_instruction(
migraphx::make_op("transpose", {{"permutation", {1, 2, 0}}}), t_boxes_l);
auto r = mm->add_instruction(migraphx::make_op("nonmaxsuppression", {{"center_point_box", 1}}),
transpose_boxes,
scores_l,
max_out_l,
iou_threshold,
score_threshold);
mm->add_return({r});
p.compile(migraphx::ref::target{});
auto output = p.eval({}).back();
std::vector<int64_t> result;
output.visit([&](auto out) { result.assign(out.begin(), out.end()); });
// Same expected values as nms_transpose1_test: 18 entries, with only the
// leading 0, 0, 3 and 0, 0, 5 non-zero.
// NOTE(review): presumably (batch, class, box index) triples with the unused
// slots left at zero - confirm against the operator's output layout.
std::vector<int64_t> gold = {0, 0, 3, 0, 0, 0, 0, 0, 5, 0, 0, 0, 0, 0, 0, 0, 0, 0};
EXPECT(migraphx::verify_range(result, gold));
}
TEST_CASE(nonzero_test) TEST_CASE(nonzero_test)
{ {
migraphx::program p; migraphx::program p;
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment