Commit a2de3140 authored by Paul's avatar Paul
Browse files

Merge

parents 2c69352c 2783c649
......@@ -80,7 +80,7 @@
"outputs": [],
"source": [
"if not os.path.exists(\"yolov4_fp16.mxr\"):\n",
" !/opt/rocm/bin/migraphx-driver compile ./utilities/yolov4.onnx --gpu --enable-offload-copy --fp16ref --binary -o yolov4_fp16.mxr\n",
" !/opt/rocm/bin/migraphx-driver compile ./utilities/yolov4.onnx --gpu --enable-offload-copy --fp16 --binary -o yolov4_fp16.mxr\n",
"if not os.path.exists(\"yolov4.mxr\"):\n",
" !/opt/rocm/bin/migraphx-driver compile ./utilities/yolov4.onnx --gpu --enable-offload-copy --binary -o yolov4.mxr"
]
......
/*
* The MIT License (MIT)
*
......@@ -21,10 +22,10 @@
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
* THE SOFTWARE.
*/
#include <migraphx/operators.hpp>
#include <migraphx/make_op.hpp>
#include <migraphx/program.hpp>
#include <migraphx/generate.hpp>
#include <migraphx/apply_alpha_beta.hpp>
#include <migraphx/json.hpp>
#include "models.hpp"
namespace migraphx {
......@@ -34,173 +35,189 @@ inline namespace MIGRAPHX_INLINE_NS {
migraphx::program alexnet(unsigned batch) // NOLINT(readability-function-size)
{
migraphx::program p;
auto* mm = p.get_main_module();
auto m0 =
mm->add_parameter("0", migraphx::shape{migraphx::shape::float_type, {batch, 3, 224, 224}});
auto mx0 = mm->add_literal(
migraphx::generate_literal(migraphx::shape{migraphx::shape::float_type, {1000}}, 0));
auto mx1 = mm->add_literal(
migraphx::generate_literal(migraphx::shape{migraphx::shape::float_type, {1000, 4096}}, 1));
auto mx2 = mm->add_literal(
migraphx::generate_literal(migraphx::shape{migraphx::shape::float_type, {4096}}, 2));
auto mx3 = mm->add_literal(
migraphx::module_ref mmain = p.get_main_module();
auto x_main_module_0 = mmain->add_literal(migraphx::abs(
migraphx::generate_literal(migraphx::shape{migraphx::shape::float_type, {1}}, 0)));
auto x_main_module_1 = mmain->add_literal(migraphx::abs(
migraphx::generate_literal(migraphx::shape{migraphx::shape::float_type, {1}}, 1)));
auto x_main_module_2 = mmain->add_literal(migraphx::abs(
migraphx::generate_literal(migraphx::shape{migraphx::shape::float_type, {1}}, 2)));
auto x_input_1 = mmain->add_parameter(
"input.1", migraphx::shape{migraphx::shape::float_type, {batch, 3, 224, 224}});
auto x_main_module_4 = mmain->add_literal(
migraphx::generate_literal(migraphx::shape{migraphx::shape::float_type, {4096, 4096}}, 3));
auto mx4 = mm->add_literal(
auto x_main_module_5 = mmain->add_literal(
migraphx::generate_literal(migraphx::shape{migraphx::shape::float_type, {4096}}, 4));
auto mx5 = mm->add_literal(
auto x_main_module_6 = mmain->add_literal(
migraphx::generate_literal(migraphx::shape{migraphx::shape::float_type, {4096, 9216}}, 5));
auto mx6 = mm->add_literal(
migraphx::generate_literal(migraphx::shape{migraphx::shape::float_type, {256}}, 6));
auto mx7 = mm->add_literal(migraphx::generate_literal(
migraphx::shape{migraphx::shape::float_type, {256, 256, 3, 3}}, 7));
auto mx8 = mm->add_literal(
migraphx::generate_literal(migraphx::shape{migraphx::shape::float_type, {256}}, 8));
auto mx9 = mm->add_literal(migraphx::generate_literal(
auto x_main_module_7 = mmain->add_literal(
migraphx::generate_literal(migraphx::shape{migraphx::shape::float_type, {4096}}, 6));
auto x_main_module_8 = mmain->add_literal(
migraphx::generate_literal(migraphx::shape{migraphx::shape::float_type, {1000, 4096}}, 7));
auto x_main_module_9 = mmain->add_literal(
migraphx::generate_literal(migraphx::shape{migraphx::shape::float_type, {1000}}, 8));
auto x_main_module_10 = mmain->add_literal(migraphx::generate_literal(
migraphx::shape{migraphx::shape::float_type, {256, 384, 3, 3}}, 9));
auto mx10 = mm->add_literal(
migraphx::generate_literal(migraphx::shape{migraphx::shape::float_type, {384}}, 10));
auto mx11 = mm->add_literal(migraphx::generate_literal(
auto x_main_module_11 = mmain->add_literal(
migraphx::generate_literal(migraphx::shape{migraphx::shape::float_type, {256}}, 10));
auto x_main_module_12 = mmain->add_literal(migraphx::generate_literal(
migraphx::shape{migraphx::shape::float_type, {384, 192, 3, 3}}, 11));
auto mx12 = mm->add_literal(
migraphx::generate_literal(migraphx::shape{migraphx::shape::float_type, {192}}, 12));
auto mx13 = mm->add_literal(migraphx::generate_literal(
auto x_main_module_13 = mmain->add_literal(
migraphx::generate_literal(migraphx::shape{migraphx::shape::float_type, {384}}, 12));
auto x_main_module_14 = mmain->add_literal(migraphx::generate_literal(
migraphx::shape{migraphx::shape::float_type, {192, 64, 5, 5}}, 13));
auto mx14 = mm->add_literal(
migraphx::generate_literal(migraphx::shape{migraphx::shape::float_type, {64}}, 14));
auto mx15 = mm->add_literal(migraphx::generate_literal(
migraphx::shape{migraphx::shape::float_type, {64, 3, 11, 11}}, 15));
migraphx::op::convolution convolution16;
convolution16.padding = {2, 2};
convolution16.stride = {4, 4};
convolution16.dilation = {1, 1};
convolution16.group = 1;
auto mx16 = mm->add_instruction(convolution16, m0, mx15);
migraphx::op::broadcast broadcast17;
broadcast17.axis = 1;
broadcast17.broadcast_lens = {batch, 64, 55, 55};
auto mx17 = mm->add_instruction(broadcast17, mx14);
migraphx::op::add add18;
auto mx18 = mm->add_instruction(add18, mx16, mx17);
migraphx::op::relu relu19;
auto mx19 = mm->add_instruction(relu19, mx18);
migraphx::op::pooling pooling20;
pooling20.mode = migraphx::op::pooling_mode::max;
pooling20.padding = {0, 0};
pooling20.stride = {2, 2};
pooling20.lengths = {3, 3};
auto mx20 = mm->add_instruction(pooling20, mx19);
migraphx::op::convolution convolution21;
convolution21.padding = {2, 2};
convolution21.stride = {1, 1};
convolution21.dilation = {1, 1};
convolution21.group = 1;
auto mx21 = mm->add_instruction(convolution21, mx20, mx13);
migraphx::op::broadcast broadcast22;
broadcast22.axis = 1;
broadcast22.broadcast_lens = {batch, 192, 27, 27};
auto mx22 = mm->add_instruction(broadcast22, mx12);
migraphx::op::add add23;
auto mx23 = mm->add_instruction(add23, mx21, mx22);
migraphx::op::relu relu24;
auto mx24 = mm->add_instruction(relu24, mx23);
migraphx::op::pooling pooling25;
pooling25.mode = migraphx::op::pooling_mode::max;
pooling25.padding = {0, 0};
pooling25.stride = {2, 2};
pooling25.lengths = {3, 3};
auto mx25 = mm->add_instruction(pooling25, mx24);
migraphx::op::convolution convolution26;
convolution26.padding = {1, 1};
convolution26.stride = {1, 1};
convolution26.dilation = {1, 1};
convolution26.group = 1;
auto mx26 = mm->add_instruction(convolution26, mx25, mx11);
migraphx::op::broadcast broadcast27;
broadcast27.axis = 1;
broadcast27.broadcast_lens = {batch, 384, 13, 13};
auto mx27 = mm->add_instruction(broadcast27, mx10);
migraphx::op::add add28;
auto mx28 = mm->add_instruction(add28, mx26, mx27);
migraphx::op::relu relu29;
auto mx29 = mm->add_instruction(relu29, mx28);
migraphx::op::convolution convolution30;
convolution30.padding = {1, 1};
convolution30.stride = {1, 1};
convolution30.dilation = {1, 1};
convolution30.group = 1;
auto mx30 = mm->add_instruction(convolution30, mx29, mx9);
migraphx::op::broadcast broadcast31;
broadcast31.axis = 1;
broadcast31.broadcast_lens = {batch, 256, 13, 13};
auto mx31 = mm->add_instruction(broadcast31, mx8);
migraphx::op::add add32;
auto mx32 = mm->add_instruction(add32, mx30, mx31);
migraphx::op::relu relu33;
auto mx33 = mm->add_instruction(relu33, mx32);
migraphx::op::convolution convolution34;
convolution34.padding = {1, 1};
convolution34.stride = {1, 1};
convolution34.dilation = {1, 1};
convolution34.group = 1;
auto mx34 = mm->add_instruction(convolution34, mx33, mx7);
migraphx::op::broadcast broadcast35;
broadcast35.axis = 1;
broadcast35.broadcast_lens = {batch, 256, 13, 13};
auto mx35 = mm->add_instruction(broadcast35, mx6);
migraphx::op::add add36;
auto mx36 = mm->add_instruction(add36, mx34, mx35);
migraphx::op::relu relu37;
auto mx37 = mm->add_instruction(relu37, mx36);
migraphx::op::pooling pooling38;
pooling38.mode = migraphx::op::pooling_mode::max;
pooling38.padding = {0, 0};
pooling38.stride = {2, 2};
pooling38.lengths = {3, 3};
auto mx38 = mm->add_instruction(pooling38, mx37);
migraphx::op::flatten flatten39;
flatten39.axis = 1;
auto mx39 = mm->add_instruction(flatten39, mx38);
migraphx::op::identity identity40;
auto mx40 = mm->add_instruction(identity40, mx39);
migraphx::op::transpose transpose41;
transpose41.dims = {1, 0};
auto mx41 = mm->add_instruction(transpose41, mx5);
migraphx::op::multibroadcast multibroadcast42;
multibroadcast42.output_lens = {batch, 4096};
auto mx42 = mm->add_instruction(multibroadcast42, mx4);
float dot43_alpha = 1;
float dot43_beta = 1;
auto mx43 = migraphx::add_apply_alpha_beta(
*mm, {mx40, mx41, mx42}, migraphx::make_op("dot"), dot43_alpha, dot43_beta);
migraphx::op::relu relu44;
auto mx44 = mm->add_instruction(relu44, mx43);
migraphx::op::identity identity45;
auto mx45 = mm->add_instruction(identity45, mx44);
migraphx::op::transpose transpose46;
transpose46.dims = {1, 0};
auto mx46 = mm->add_instruction(transpose46, mx3);
migraphx::op::multibroadcast multibroadcast47;
multibroadcast47.output_lens = {batch, 4096};
auto mx47 = mm->add_instruction(multibroadcast47, mx2);
float dot48_alpha = 1;
float dot48_beta = 1;
auto mx48 = migraphx::add_apply_alpha_beta(
*mm, {mx45, mx46, mx47}, migraphx::make_op("dot"), dot48_alpha, dot48_beta);
migraphx::op::relu relu49;
auto mx49 = mm->add_instruction(relu49, mx48);
migraphx::op::transpose transpose50;
transpose50.dims = {1, 0};
auto mx50 = mm->add_instruction(transpose50, mx1);
migraphx::op::multibroadcast multibroadcast51;
multibroadcast51.output_lens = {batch, 1000};
auto mx51 = mm->add_instruction(multibroadcast51, mx0);
float dot52_alpha = 1;
float dot52_beta = 1;
migraphx::add_apply_alpha_beta(
*mm, {mx49, mx50, mx51}, migraphx::make_op("dot"), dot52_alpha, dot52_beta);
auto x_main_module_15 = mmain->add_literal(
migraphx::generate_literal(migraphx::shape{migraphx::shape::float_type, {192}}, 14));
auto x_main_module_16 = mmain->add_literal(migraphx::generate_literal(
migraphx::shape{migraphx::shape::float_type, {256, 256, 3, 3}}, 15));
auto x_main_module_17 = mmain->add_literal(
migraphx::generate_literal(migraphx::shape{migraphx::shape::float_type, {256}}, 16));
auto x_main_module_18 = mmain->add_literal(migraphx::generate_literal(
migraphx::shape{migraphx::shape::float_type, {64, 3, 11, 11}}, 17));
auto x_main_module_19 = mmain->add_literal(
migraphx::generate_literal(migraphx::shape{migraphx::shape::float_type, {64}}, 18));
auto x_main_module_20 = mmain->add_instruction(
migraphx::make_op(
"convolution",
migraphx::from_json_string(
"{dilation:[1,1],group:1,padding:[2,2,2,2],padding_mode:0,stride:[4,4]}")),
x_input_1,
x_main_module_18);
auto x_main_module_21 = mmain->add_instruction(
migraphx::make_op("broadcast",
migraphx::from_json_string("{axis:1,out_lens:[1,64,55,55]}")),
x_main_module_19);
auto x_main_module_22 =
mmain->add_instruction(migraphx::make_op("add"), x_main_module_20, x_main_module_21);
auto x_main_module_23 = mmain->add_instruction(migraphx::make_op("relu"), x_main_module_22);
auto x_main_module_24 = mmain->add_instruction(
migraphx::make_op(
"pooling",
migraphx::from_json_string(
"{ceil_mode:0,lengths:[3,3],lp_order:2,mode:1,padding:[0,0,0,0],stride:[2,2]}")),
x_main_module_23);
auto x_main_module_25 = mmain->add_instruction(
migraphx::make_op(
"convolution",
migraphx::from_json_string(
"{dilation:[1,1],group:1,padding:[2,2,2,2],padding_mode:0,stride:[1,1]}")),
x_main_module_24,
x_main_module_14);
auto x_main_module_26 = mmain->add_instruction(
migraphx::make_op("broadcast",
migraphx::from_json_string("{axis:1,out_lens:[1,192,27,27]}")),
x_main_module_15);
auto x_main_module_27 =
mmain->add_instruction(migraphx::make_op("add"), x_main_module_25, x_main_module_26);
auto x_main_module_28 = mmain->add_instruction(migraphx::make_op("relu"), x_main_module_27);
auto x_main_module_29 = mmain->add_instruction(
migraphx::make_op(
"pooling",
migraphx::from_json_string(
"{ceil_mode:0,lengths:[3,3],lp_order:2,mode:1,padding:[0,0,0,0],stride:[2,2]}")),
x_main_module_28);
auto x_main_module_30 = mmain->add_instruction(
migraphx::make_op(
"convolution",
migraphx::from_json_string(
"{dilation:[1,1],group:1,padding:[1,1,1,1],padding_mode:0,stride:[1,1]}")),
x_main_module_29,
x_main_module_12);
auto x_main_module_31 = mmain->add_instruction(
migraphx::make_op("broadcast",
migraphx::from_json_string("{axis:1,out_lens:[1,384,13,13]}")),
x_main_module_13);
auto x_main_module_32 =
mmain->add_instruction(migraphx::make_op("add"), x_main_module_30, x_main_module_31);
auto x_main_module_33 = mmain->add_instruction(migraphx::make_op("relu"), x_main_module_32);
auto x_main_module_34 = mmain->add_instruction(
migraphx::make_op(
"convolution",
migraphx::from_json_string(
"{dilation:[1,1],group:1,padding:[1,1,1,1],padding_mode:0,stride:[1,1]}")),
x_main_module_33,
x_main_module_10);
auto x_main_module_35 = mmain->add_instruction(
migraphx::make_op("broadcast",
migraphx::from_json_string("{axis:1,out_lens:[1,256,13,13]}")),
x_main_module_11);
auto x_main_module_36 =
mmain->add_instruction(migraphx::make_op("add"), x_main_module_34, x_main_module_35);
auto x_main_module_37 = mmain->add_instruction(migraphx::make_op("relu"), x_main_module_36);
auto x_main_module_38 = mmain->add_instruction(
migraphx::make_op(
"convolution",
migraphx::from_json_string(
"{dilation:[1,1],group:1,padding:[1,1,1,1],padding_mode:0,stride:[1,1]}")),
x_main_module_37,
x_main_module_16);
auto x_main_module_39 = mmain->add_instruction(
migraphx::make_op("broadcast",
migraphx::from_json_string("{axis:1,out_lens:[1,256,13,13]}")),
x_main_module_17);
auto x_main_module_40 =
mmain->add_instruction(migraphx::make_op("add"), x_main_module_38, x_main_module_39);
auto x_main_module_41 = mmain->add_instruction(migraphx::make_op("relu"), x_main_module_40);
auto x_main_module_42 = mmain->add_instruction(
migraphx::make_op(
"pooling",
migraphx::from_json_string(
"{ceil_mode:0,lengths:[3,3],lp_order:2,mode:1,padding:[0,0,0,0],stride:[2,2]}")),
x_main_module_41);
auto x_main_module_43 = mmain->add_instruction(
migraphx::make_op("reshape", migraphx::from_json_string("{dims:[1,9216]}")),
x_main_module_42);
auto x_main_module_44 = mmain->add_instruction(
migraphx::make_op("transpose", migraphx::from_json_string("{permutation:[1,0]}")),
x_main_module_6);
auto x_main_module_45 =
mmain->add_instruction(migraphx::make_op("dot"), x_main_module_43, x_main_module_44);
auto x_main_module_46 = mmain->add_instruction(
migraphx::make_op("multibroadcast", migraphx::from_json_string("{out_lens:[1,4096]}")),
x_main_module_7);
auto x_main_module_47 = mmain->add_instruction(
migraphx::make_op("multibroadcast", migraphx::from_json_string("{out_lens:[1,4096]}")),
x_main_module_2);
auto x_main_module_48 =
mmain->add_instruction(migraphx::make_op("mul"), x_main_module_46, x_main_module_47);
auto x_main_module_49 =
mmain->add_instruction(migraphx::make_op("add"), x_main_module_45, x_main_module_48);
auto x_main_module_50 = mmain->add_instruction(migraphx::make_op("relu"), x_main_module_49);
auto x_main_module_51 = mmain->add_instruction(
migraphx::make_op("transpose", migraphx::from_json_string("{permutation:[1,0]}")),
x_main_module_4);
auto x_main_module_52 =
mmain->add_instruction(migraphx::make_op("dot"), x_main_module_50, x_main_module_51);
auto x_main_module_53 = mmain->add_instruction(
migraphx::make_op("multibroadcast", migraphx::from_json_string("{out_lens:[1,4096]}")),
x_main_module_5);
auto x_main_module_54 = mmain->add_instruction(
migraphx::make_op("multibroadcast", migraphx::from_json_string("{out_lens:[1,4096]}")),
x_main_module_1);
auto x_main_module_55 =
mmain->add_instruction(migraphx::make_op("mul"), x_main_module_53, x_main_module_54);
auto x_main_module_56 =
mmain->add_instruction(migraphx::make_op("add"), x_main_module_52, x_main_module_55);
auto x_main_module_57 = mmain->add_instruction(migraphx::make_op("relu"), x_main_module_56);
auto x_main_module_58 = mmain->add_instruction(
migraphx::make_op("transpose", migraphx::from_json_string("{permutation:[1,0]}")),
x_main_module_8);
auto x_main_module_59 =
mmain->add_instruction(migraphx::make_op("dot"), x_main_module_57, x_main_module_58);
auto x_main_module_60 = mmain->add_instruction(
migraphx::make_op("multibroadcast", migraphx::from_json_string("{out_lens:[1,1000]}")),
x_main_module_9);
auto x_main_module_61 = mmain->add_instruction(
migraphx::make_op("multibroadcast", migraphx::from_json_string("{out_lens:[1,1000]}")),
x_main_module_0);
auto x_main_module_62 =
mmain->add_instruction(migraphx::make_op("mul"), x_main_module_60, x_main_module_61);
auto x_main_module_63 =
mmain->add_instruction(migraphx::make_op("add"), x_main_module_59, x_main_module_62);
mmain->add_return({x_main_module_63});
return p;
}
} // namespace MIGRAPHX_INLINE_NS
} // namespace driver
} // namespace migraphx
This source diff could not be displayed because it is too large. You can view the blob instead.
......@@ -210,6 +210,9 @@ struct loader
auto last = std::prev(mm->end(), trim);
mm->remove_instructions(last, mm->end());
}
// Remove unused variable when exporting to cpp
if(output_type == "cpp")
migraphx::run_passes(*p.get_main_module(), {migraphx::dead_code_elimination{}});
if(optimize)
{
migraphx::run_passes(*p.get_main_module(),
......
This diff is collapsed.
......@@ -142,7 +142,7 @@ static std::vector<instruction_ref> append_pointwise_module(instruction_ref ins,
input_map[input] = map_ins[param];
}
}
pm->replace_return(pm->insert_module_instructions(last, xm, map_ins));
pm->replace_return(pm->insert_instructions(last, xm, map_ins));
return inputs;
}
......
......@@ -81,8 +81,9 @@ struct basic_iota_iterator
index--;
return it;
}
// TODO: operator->
reference operator*() const { return f(index); }
pointer operator->() const { return &f(index); }
reference operator[](int n) const { return f(index + n); }
};
template <class T, class F>
......
......@@ -120,9 +120,33 @@ struct module
instruction_ref move_instructions(instruction_ref src, instruction_ref dst);
std::vector<instruction_ref>
insert_module_instructions(instruction_ref ins,
const_module_ref m,
std::unordered_map<instruction_ref, instruction_ref> map_ins = {});
add_instructions(const std::vector<instruction_ref>& instructions,
std::unordered_map<instruction_ref, instruction_ref> map_ins = {});
std::vector<instruction_ref>
add_instructions(const_module_ref m,
std::unordered_map<instruction_ref, instruction_ref> map_ins = {});
std::vector<instruction_ref>
add_instructions(instruction_ref start,
instruction_ref last,
std::unordered_map<instruction_ref, instruction_ref> map_ins = {});
std::vector<instruction_ref>
insert_instructions(instruction_ref ins,
const std::vector<instruction_ref>& instructions,
std::unordered_map<instruction_ref, instruction_ref> map_ins = {});
std::vector<instruction_ref>
insert_instructions(instruction_ref ins,
const_module_ref m,
std::unordered_map<instruction_ref, instruction_ref> map_ins = {});
std::vector<instruction_ref>
insert_instructions(instruction_ref ins,
instruction_ref start,
instruction_ref last,
std::unordered_map<instruction_ref, instruction_ref> map_ins = {});
template <class... Ts>
instruction_ref add_literal(Ts&&... xs)
......@@ -179,7 +203,9 @@ struct module
void print_cpp(std::ostream& os) const;
std::unordered_map<instruction_ref, std::string>
print_cpp(std::ostream& os, std::unordered_map<instruction_ref, std::string> names) const;
print_cpp(std::ostream& os,
const std::string& mname,
std::unordered_map<instruction_ref, std::string> names) const;
void annotate(std::ostream& os, std::function<void(instruction_ref)> a) const;
......
......@@ -56,14 +56,21 @@ struct nonmaxsuppression
shape compute_shape(std::vector<shape> inputs) const
{
// requires at least 2 inputs
check_shapes{inputs, *this}.standard();
check_shapes{{inputs.at(0), inputs.at(1)}, *this}.only_dims(3);
auto lens = inputs.front().lens();
// check input shape
if(lens[1] != inputs.at(1).lens()[2])
{
MIGRAPHX_THROW("NonMaxSuppression: dimension mismatch between first and second input!");
MIGRAPHX_THROW(
"NonMaxSuppression: spatial dimension mismatch between boxes and scores input");
}
// check batch sizes
if(lens[0] != inputs.at(1).lens()[0])
{
MIGRAPHX_THROW(
"NonMaxSuppression: number of batches mismatch between boxes and scores input");
}
std::vector<int64_t> out_lens(2);
......@@ -74,8 +81,8 @@ struct nonmaxsuppression
struct box
{
std::array<float, 2> x;
std::array<float, 2> y;
std::array<double, 2> x;
std::array<double, 2> y;
void sort()
{
......@@ -83,9 +90,9 @@ struct nonmaxsuppression
std::sort(y.begin(), y.end());
}
std::array<float, 2>& operator[](std::size_t i) { return i == 0 ? x : y; }
std::array<double, 2>& operator[](std::size_t i) { return i == 0 ? x : y; }
float area() const
double area() const
{
assert(std::is_sorted(x.begin(), x.end()));
assert(std::is_sorted(y.begin(), y.end()));
......@@ -94,29 +101,29 @@ struct nonmaxsuppression
};
template <class T>
box batch_box(const T* boxes, std::size_t bidx) const
box batch_box(T boxes, std::size_t box_idx) const
{
box result{};
const T* start = boxes + 4 * bidx;
auto start = boxes + 4 * box_idx;
if(center_point_box)
{
float half_width = start[2] / 2.0f;
float half_height = start[3] / 2.0f;
float x_center = start[0];
float y_center = start[1];
result.x = {x_center - half_width, x_center + half_width};
result.y = {y_center - half_height, y_center + half_height};
double half_width = start[2] / 2.0;
double half_height = start[3] / 2.0;
double x_center = start[0];
double y_center = start[1];
result.x = {x_center - half_width, x_center + half_width};
result.y = {y_center - half_height, y_center + half_height};
}
else
{
result.x = {start[1], start[3]};
result.y = {start[0], start[2]};
result.x = {static_cast<double>(start[1]), static_cast<double>(start[3])};
result.y = {static_cast<double>(start[0]), static_cast<double>(start[2])};
}
return result;
}
inline bool suppress_by_iou(box b1, box b2, float iou_threshold) const
inline bool suppress_by_iou(box b1, box b2, double iou_threshold) const
{
b1.sort();
b2.sort();
......@@ -128,7 +135,7 @@ struct nonmaxsuppression
intersection[i][1] = std::min(b1[i][1], b2[i][1]);
}
std::vector<std::array<float, 2>> bbox = {intersection.x, intersection.y};
std::vector<std::array<double, 2>> bbox = {intersection.x, intersection.y};
if(std::any_of(bbox.begin(), bbox.end(), [](auto bx) {
return not std::is_sorted(bx.begin(), bx.end());
}))
......@@ -136,115 +143,124 @@ struct nonmaxsuppression
return false;
}
const float area1 = b1.area();
const float area2 = b2.area();
const float intersection_area = intersection.area();
const float union_area = area1 + area2 - intersection_area;
const double area1 = b1.area();
const double area2 = b2.area();
const double intersection_area = intersection.area();
const double union_area = area1 + area2 - intersection_area;
if(area1 <= .0f or area2 <= .0f or union_area <= .0f)
{
return false;
}
const float intersection_over_union = intersection_area / union_area;
const double intersection_over_union = intersection_area / union_area;
return intersection_over_union > iou_threshold;
}
argument compute(const shape& output_shape, std::vector<argument> args) const
// filter boxes below score_threshold
template <class T>
std::priority_queue<std::pair<double, int64_t>>
filter_boxes_by_score(T scores_start, std::size_t num_boxes, double score_threshold) const
{
argument result{output_shape};
result.visit([&](auto out) { std::fill(out.begin(), out.end(), 0); });
std::size_t max_output_boxes_per_class = 0;
float iou_threshold = 0.0f;
float score_threshold = 0.0f;
if(args.size() > 2)
{
max_output_boxes_per_class = args.at(2).at<std::size_t>();
}
// max_output_boxes_per_class is 0, no output
if(max_output_boxes_per_class == 0)
{
return result;
}
if(args.size() > 3)
{
iou_threshold = args.at(3).at<float>();
}
if(args.size() > 4)
{
score_threshold = args.at(4).at<float>();
}
const auto& lens = args.at(1).get_shape().lens();
auto batch_num = lens[0];
auto class_num = lens[1];
auto box_num = args.at(0).get_shape().lens()[1];
std::priority_queue<std::pair<double, int64_t>> boxes_heap;
auto insert_to_boxes_heap =
make_function_output_iterator([&](const auto& x) { boxes_heap.push(x); });
int64_t box_idx = 0;
transform_if(
scores_start,
scores_start + num_boxes,
insert_to_boxes_heap,
[&](auto sc) {
box_idx++;
return sc >= score_threshold;
},
[&](auto sc) { return std::make_pair(sc, box_idx - 1); });
return boxes_heap;
}
std::vector<std::pair<float, int64_t>> selected_boxes_inside_class;
template <class Output, class Boxes, class Scores>
void compute_nms(Output output,
Boxes boxes,
Scores scores,
const shape& output_shape,
std::size_t max_output_boxes_per_class,
double iou_threshold,
double score_threshold) const
{
std::fill(output.begin(), output.end(), 0);
const auto& lens = scores.get_shape().lens();
const auto num_batches = lens[0];
const auto num_classes = lens[1];
const auto num_boxes = lens[2];
// boxes of a class with NMS applied [score, index]
std::vector<std::pair<double, int64_t>> selected_boxes_inside_class;
std::vector<int64_t> selected_indices;
selected_boxes_inside_class.reserve(output_shape.elements());
auto scores = make_view<float>(args.at(1).get_shape(), args.at(1).cast<float>());
const float* boxes = args.at(0).cast<float>();
shape comp_s{shape::float_type, {batch_num, class_num}};
// iterate over batches and classes
shape comp_s{shape::double_type, {num_batches, num_classes}};
shape_for_each(comp_s, [&](auto idx) {
auto bidx = idx[0];
auto cidx = idx[1];
std::size_t score_offset = (bidx * class_num + cidx) * box_num;
const float* batch_boxes = boxes + bidx * box_num * 4;
std::priority_queue<std::pair<float, int64_t>> sorted_boxes;
auto insert_to_sorted_boxes =
make_function_output_iterator([&](const auto& x) { sorted_boxes.push(x); });
int64_t box_idx = 0;
transform_if(
scores.begin() + score_offset,
scores.begin() + score_offset + box_num,
insert_to_sorted_boxes,
[&](auto sc) {
box_idx++;
return sc >= score_threshold;
},
[&](auto sc) { return std::make_pair(sc, box_idx - 1); });
auto batch_idx = idx[0];
auto class_idx = idx[1];
// index offset for this class
auto scores_start = scores.begin() + (batch_idx * num_classes + class_idx) * num_boxes;
// iterator to first value of this batch
auto batch_boxes_start = boxes.begin() + batch_idx * num_boxes * 4;
auto boxes_heap = filter_boxes_by_score(scores_start, num_boxes, score_threshold);
selected_boxes_inside_class.clear();
// Get the next box with top score, filter by iou_threshold
while(!sorted_boxes.empty() &&
while(!boxes_heap.empty() &&
selected_boxes_inside_class.size() < max_output_boxes_per_class)
{
const std::pair<float, int64_t>& next_top_score = sorted_boxes.top();
// Check with existing selected boxes for this class, suppress if exceed the IOU
// (Intersection Over Union) threshold
bool not_selected = std::any_of(
selected_boxes_inside_class.begin(),
selected_boxes_inside_class.end(),
[&](auto selected_index) {
return this->suppress_by_iou(batch_box(batch_boxes, next_top_score.second),
batch_box(batch_boxes, selected_index.second),
iou_threshold);
});
// Check with existing selected boxes for this class, remove box if it
// exceeds the IOU (Intersection Over Union) threshold
const auto next_top_score = boxes_heap.top();
bool not_selected =
std::any_of(selected_boxes_inside_class.begin(),
selected_boxes_inside_class.end(),
[&](auto selected_index) {
return this->suppress_by_iou(
batch_box(batch_boxes_start, next_top_score.second),
batch_box(batch_boxes_start, selected_index.second),
iou_threshold);
});
if(not not_selected)
{
selected_boxes_inside_class.push_back(next_top_score);
selected_indices.push_back(bidx);
selected_indices.push_back(cidx);
selected_indices.push_back(batch_idx);
selected_indices.push_back(class_idx);
selected_indices.push_back(next_top_score.second);
}
sorted_boxes.pop();
boxes_heap.pop();
}
});
std::copy(selected_indices.begin(), selected_indices.end(), output.begin());
}
argument compute(const shape& output_shape, std::vector<argument> args) const
{
argument result{output_shape};
result.visit([&](auto out) {
std::copy(selected_indices.begin(), selected_indices.end(), out.begin());
std::size_t max_output_boxes_per_class =
(args.size() > 2) ? (args.at(2).at<std::size_t>()) : 0;
if(max_output_boxes_per_class == 0)
{
return result;
}
double iou_threshold = (args.size() > 3) ? (args.at(3).at<double>()) : 0.0f;
double score_threshold = (args.size() > 4) ? (args.at(4).at<double>()) : 0.0f;
result.visit([&](auto output) {
visit_all(args[0], args[1])([&](auto boxes, auto scores) {
compute_nms(output,
boxes,
scores,
output_shape,
max_output_boxes_per_class,
iou_threshold,
score_threshold);
});
});
return result;
......
......@@ -44,8 +44,8 @@ auto with_char(F f)
return [=](unsigned char c) -> bool { return f(c); };
}
inline std::string
replace_string(std::string subject, const std::string& search, const std::string& replace)
inline void
replace_string_inplace(std::string& subject, const std::string& search, const std::string& replace)
{
size_t pos = 0;
while((pos = subject.find(search, pos)) != std::string::npos)
......@@ -53,6 +53,12 @@ replace_string(std::string subject, const std::string& search, const std::string
subject.replace(pos, search.length(), replace);
pos += replace.length();
}
}
inline std::string
replace_string(std::string subject, const std::string& search, const std::string& replace)
{
replace_string_inplace(subject, search, replace);
return subject;
}
......
......@@ -35,7 +35,7 @@ static void inline_submodule(module& m, instruction_ref ins, bool cond)
{
const auto& mod_inputs = ins->module_inputs();
module_ref smod = cond ? mod_inputs.at(0) : mod_inputs.at(1);
auto mod_outputs = m.insert_module_instructions(ins, smod);
auto mod_outputs = m.insert_instructions(ins, smod);
auto ins_outputs = ins->outputs();
assert(mod_outputs.size() >= ins_outputs.size());
......
......@@ -35,6 +35,7 @@
#include <migraphx/make_op.hpp>
#include <migraphx/register_target.hpp>
#include <migraphx/make_op.hpp>
#include <migraphx/json.hpp>
#include <iostream>
#include <sstream>
#include <algorithm>
......@@ -196,6 +197,62 @@ void module::assign(const module& m)
}
}
template <class Range>
static std::vector<instruction_ref>
insert_generic_instructions(module& m,
instruction_ref ins,
Range&& instructions,
std::unordered_map<instruction_ref, instruction_ref> map_ins)
{
assert(m.has_instruction(ins) or is_end(ins, m.end()));
std::vector<instruction_ref> mod_outputs;
instruction_ref last;
for(instruction_ref sins : instructions)
{
last = sins;
if(contains(map_ins, sins))
continue;
instruction_ref copy_ins;
if(sins->name() == "@literal")
{
auto l = sins->get_literal();
copy_ins = m.add_literal(l);
}
else if(sins->name() == "@param")
{
auto&& name = any_cast<builtin::param>(sins->get_operator()).parameter;
auto s = sins->get_shape();
copy_ins = m.add_parameter(name, s);
}
else if(sins->name() == "@outline")
{
auto s = sins->get_shape();
copy_ins = m.add_outline(s);
}
else
{
auto mod_args = sins->module_inputs();
auto inputs = sins->inputs();
std::vector<instruction_ref> copy_inputs(inputs.size());
std::transform(inputs.begin(), inputs.end(), copy_inputs.begin(), [&](auto i) {
return contains(map_ins, i) ? map_ins[i] : i;
});
if(sins->name() == "@return")
{
mod_outputs = copy_inputs;
break;
}
copy_ins = m.insert_instruction(ins, sins->get_operator(), copy_inputs, mod_args);
}
map_ins[sins] = copy_ins;
}
if(mod_outputs.empty() and instructions.begin() != instructions.end())
mod_outputs = {map_ins.at(last)};
return mod_outputs;
}
instruction_ref module::add_instruction(const operation& op, std::vector<instruction_ref> args)
{
return insert_instruction(impl->instructions.end(), op, std::move(args));
......@@ -335,54 +392,48 @@ instruction_ref module::move_instructions(instruction_ref src, instruction_ref d
}
std::vector<instruction_ref>
module::insert_module_instructions(instruction_ref ins,
const_module_ref m,
std::unordered_map<instruction_ref, instruction_ref> map_ins)
module::add_instructions(const std::vector<instruction_ref>& instructions,
std::unordered_map<instruction_ref, instruction_ref> map_ins)
{
std::vector<instruction_ref> mod_outputs;
for(auto sins : iterator_for(*m))
{
if(contains(map_ins, sins))
continue;
instruction_ref copy_ins;
if(sins->name() == "@literal")
{
auto l = sins->get_literal();
copy_ins = this->add_literal(l);
}
else if(sins->name() == "@param")
{
auto&& name = any_cast<builtin::param>(sins->get_operator()).parameter;
auto s = sins->get_shape();
copy_ins = this->add_parameter(name, s);
}
else if(sins->name() == "@outline")
{
auto s = sins->get_shape();
copy_ins = this->add_outline(s);
}
else
{
auto mod_args = sins->module_inputs();
auto inputs = sins->inputs();
std::vector<instruction_ref> copy_inputs(inputs.size());
std::transform(inputs.begin(), inputs.end(), copy_inputs.begin(), [&](auto i) {
return contains(map_ins, i) ? map_ins[i] : i;
});
return this->insert_instructions(this->end(), instructions, std::move(map_ins));
}
if(sins->name() == "@return")
{
mod_outputs = copy_inputs;
break;
}
std::vector<instruction_ref>
module::add_instructions(const_module_ref m, std::unordered_map<instruction_ref, instruction_ref> map_ins)
{
return this->insert_instructions(this->end(), m, std::move(map_ins));
}
copy_ins = this->insert_instruction(ins, sins->get_operator(), copy_inputs, mod_args);
}
map_ins[sins] = copy_ins;
}
if(mod_outputs.empty())
mod_outputs = {map_ins.at(std::prev(m->end()))};
return mod_outputs;
std::vector<instruction_ref>
module::add_instructions(instruction_ref start,
instruction_ref last,
std::unordered_map<instruction_ref, instruction_ref> map_ins)
{
return this->insert_instructions(this->end(), start, last, std::move(map_ins));
}
std::vector<instruction_ref>
module::insert_instructions(instruction_ref ins,
const std::vector<instruction_ref>& instructions,
std::unordered_map<instruction_ref, instruction_ref> map_ins)
{
return insert_generic_instructions(*this, ins, instructions, std::move(map_ins));
}
std::vector<instruction_ref> module::insert_instructions(
instruction_ref ins, const_module_ref m, std::unordered_map<instruction_ref, instruction_ref> map_ins)
{
return insert_generic_instructions(*this, ins, iterator_for(*m), std::move(map_ins));
}
std::vector<instruction_ref>
module::insert_instructions(instruction_ref ins,
instruction_ref start,
instruction_ref last,
std::unordered_map<instruction_ref, instruction_ref> map_ins)
{
auto r = range(start, last);
return insert_generic_instructions(*this, ins, iterator_for(r), std::move(map_ins));
}
instruction_ref module::add_literal(literal l)
......@@ -708,44 +759,33 @@ void module::print_graph(std::ostream& os, bool brief) const
os << "}" << std::endl;
}
static std::string cpp_var_name(const std::string& name)
static std::string to_c_id(const std::string& name, char rep = '_')
{
return "m" + replace_string(name, "@", "x");
std::string id = transform_string(name, [&](auto c) {
if(with_char(::isalnum)(c) or c == '_')
return c;
return rep;
});
while(contains(id, "__"))
replace_string_inplace(id, "__", "_");
return id;
}
static std::string cpp_op_var(const std::string& name, instruction_ref ins)
static std::string cpp_var_name(const std::string& name)
{
return replace_string(name, "@", ins->name());
return to_c_id("x_" + replace_string(name, ":", "_module_"));
}
static void print_op_attributes(std::ostream& os, const std::string& name, const operation& op)
static void print_make_op(std::ostream& os, const operation& op)
{
std::string x = to_string(op);
if(contains(x, "["))
os << "migraphx::make_op(" << enclose_name(op.name());
auto v = op.to_value();
if(not v.empty())
{
auto start = x.find('[');
auto end = x.find(']');
std::string attribute_text = x.substr(start + 1, end - start - 1);
std::vector<std::string> attributes;
for(auto&& attribute : split_string(attribute_text, ','))
{
if(contains(attribute, '='))
attributes.push_back(attribute);
else
attributes.back() += "," + attribute;
}
for(auto&& attribute : attributes)
{
auto p = split_string(attribute, '=');
auto key = p.front();
auto value = p.back();
if(contains({"bn_mode", "padding_mode"}, key))
continue;
if(key == "mode")
value = enclose_name(trim(value));
os << name << "." << key << " = " << value << ";" << std::endl;
}
os << ", "
<< "migraphx::from_json_string(" << enclose_name(to_json_string(v)) << ")";
}
os << ")";
}
static void print_cpp_shape(std::ostream& os, const migraphx::shape& s)
......@@ -758,22 +798,25 @@ static void print_cpp_shape(std::ostream& os, const migraphx::shape& s)
}
std::unordered_map<instruction_ref, std::string>
module::print_cpp(std::ostream& os, std::unordered_map<instruction_ref, std::string> names) const
module::print_cpp(std::ostream& os,
const std::string& mname,
std::unordered_map<instruction_ref, std::string> names) const
{
os << "migraphx::module p;" << std::endl;
unsigned long seed = 0;
// cppcheck-suppress variableScope
unsigned long seed = names.size();
auto last = std::prev(this->end());
names = this->print(
[&](auto ins, auto ins_names) {
auto op = cpp_op_var(ins_names.at(ins), ins);
if(ins->name().front() != '@')
{
os << "migraphx::op::" << ins->name() << " " << op << ";" << std::endl;
print_op_attributes(os, op, ins->get_operator());
}
os << "auto " << cpp_var_name(ins_names.at(ins)) << " = ";
std::vector<std::string> input_vars;
std::transform(ins->inputs().begin(),
ins->inputs().end(),
std::back_inserter(input_vars),
[&](auto input) { return cpp_var_name(ins_names.at(input)); });
if(ins != last)
os << "auto " << cpp_var_name(ins_names.at(ins)) << " = ";
if(ins->name() == "@literal")
{
os << "p.add_literal(";
os << mname << "->add_literal(";
bool use_abs = false;
ins->get_literal().visit([&](auto v) {
use_abs = std::none_of(v.begin(), v.end(), [](auto x) { return x < 0; });
......@@ -791,17 +834,22 @@ module::print_cpp(std::ostream& os, std::unordered_map<instruction_ref, std::str
else if(ins->name() == "@param")
{
std::string name = any_cast<builtin::param>(ins->get_operator()).parameter;
os << "p.add_parameter(" << enclose_name(name) << ",";
os << mname << "->add_parameter(" << enclose_name(name) << ",";
print_cpp_shape(os, ins->get_shape());
os << ");" << std::endl;
}
else if(ins->name() == "@return")
{
os << mname << "->add_return({";
os << join_strings(input_vars, ", ");
os << "});" << std::endl;
}
else
{
os << "p.add_instruction(" << op;
for(auto input : ins->inputs())
{
os << ", " << cpp_var_name(ins_names.at(input));
}
assert(ins->name().front() != '@');
os << mname << "->add_instruction(";
print_make_op(os, ins->get_operator());
os << ", " << join_strings(input_vars, ", ");
os << ");" << std::endl;
}
},
......@@ -810,7 +858,7 @@ module::print_cpp(std::ostream& os, std::unordered_map<instruction_ref, std::str
return names;
}
void module::print_cpp(std::ostream& os) const { this->print_cpp(os, {}); }
void module::print_cpp(std::ostream& os) const { this->print_cpp(os, this->name(), {}); }
void module::annotate(std::ostream& os, std::function<void(instruction_ref)> a) const
{
......
......@@ -790,10 +790,17 @@ void program::print_cpp(std::ostream& os) const
{
auto vec_modules = this->get_modules();
std::unordered_map<instruction_ref, std::string> names;
os << "migraphx::program p;\n";
for(auto& mod : vec_modules)
{
os << "module: \"" << mod->name() << "\"" << std::endl;
names = mod->print_cpp(os, names);
std::string var_name = "m" + mod->name();
os << "migraphx::module_ref " << var_name << " = ";
if(mod->name() == "main")
os << "p.get_main_module();";
else
os << "p.create_module(\"" << mod->name() << "\");";
os << std::endl;
names = mod->print_cpp(os, var_name, names);
os << std::endl;
}
}
......
......@@ -300,6 +300,96 @@ TEST_CASE(parameter_name_order)
EXPECT(param_names == names1);
}
TEST_CASE(insert_instructions_module)
{
migraphx::shape s{migraphx::shape::int32_type, {1}};
migraphx::module m1("m1");
auto x1 = m1.add_parameter("x1", s);
auto sqrt = m1.add_instruction(migraphx::make_op("sqrt"), {x1});
m1.add_instruction(migraphx::make_op("add"), {sqrt, x1});
migraphx::module m2("m2");
auto x2 = m2.add_parameter("x2", s);
m2.add_instruction(migraphx::make_op("sqrt"), {x2});
m1.insert_instructions(sqrt, &m2, {{x2, x1}});
EXPECT(std::prev(sqrt)->name() == "sqrt");
EXPECT(std::count_if(m1.begin(), m1.end(), [](auto&& ins) { return ins.name() == "sqrt"; }) ==
2);
EXPECT(std::count_if(m1.begin(), m1.end(), [](auto&& ins) { return ins.name() == "@param"; }) ==
1);
EXPECT(contains(m1.get_parameter_shapes(), "x1"));
EXPECT(not contains(m1.get_parameter_shapes(), "x2"));
}
TEST_CASE(add_instructions_module)
{
migraphx::shape s{migraphx::shape::int32_type, {1}};
migraphx::module m1("m1");
auto x1 = m1.add_parameter("x1", s);
m1.add_instruction(migraphx::make_op("sqrt"), {x1});
migraphx::module m2("m2");
auto x2 = m2.add_parameter("x2", s);
m2.add_instruction(migraphx::make_op("sqrt"), {x2});
m1.add_instructions(&m2, {{x2, x1}});
EXPECT(std::count_if(m1.begin(), m1.end(), [](auto&& ins) { return ins.name() == "sqrt"; }) ==
2);
EXPECT(std::count_if(m1.begin(), m1.end(), [](auto&& ins) { return ins.name() == "@param"; }) ==
1);
EXPECT(contains(m1.get_parameter_shapes(), "x1"));
EXPECT(not contains(m1.get_parameter_shapes(), "x2"));
}
TEST_CASE(add_instructions_range)
{
migraphx::shape s{migraphx::shape::int32_type, {1}};
migraphx::module m1("m1");
auto x1 = m1.add_parameter("x1", s);
m1.add_instruction(migraphx::make_op("sqrt"), {x1});
migraphx::module m2("m2");
auto x2 = m2.add_parameter("x2", s);
auto sqrt2 = m2.add_instruction(migraphx::make_op("sqrt"), {x2});
m1.add_instructions(sqrt2, m2.end(), {{x2, x1}});
EXPECT(std::any_of(
m1.begin(), m1.end(), [&](auto&& ins) { return migraphx::contains(ins.inputs(), x1); }));
EXPECT(std::count_if(m1.begin(), m1.end(), [](auto&& ins) { return ins.name() == "sqrt"; }) ==
2);
EXPECT(std::count_if(m1.begin(), m1.end(), [](auto&& ins) { return ins.name() == "@param"; }) ==
1);
EXPECT(contains(m1.get_parameter_shapes(), "x1"));
EXPECT(not contains(m1.get_parameter_shapes(), "x2"));
}
TEST_CASE(add_instructions_vector)
{
migraphx::shape s{migraphx::shape::int32_type, {1}};
migraphx::module m1("m1");
auto x1 = m1.add_parameter("x1", s);
m1.add_instruction(migraphx::make_op("sqrt"), {x1});
migraphx::module m2("m2");
auto x2 = m2.add_parameter("x2", s);
auto sqrt2 = m2.add_instruction(migraphx::make_op("sqrt"), {x2});
m1.add_instructions({sqrt2}, {{x2, x1}});
EXPECT(std::any_of(
m1.begin(), m1.end(), [&](auto&& ins) { return migraphx::contains(ins.inputs(), x1); }));
EXPECT(std::count_if(m1.begin(), m1.end(), [](auto&& ins) { return ins.name() == "sqrt"; }) ==
2);
EXPECT(std::count_if(m1.begin(), m1.end(), [](auto&& ins) { return ins.name() == "@param"; }) ==
1);
EXPECT(contains(m1.get_parameter_shapes(), "x1"));
EXPECT(not contains(m1.get_parameter_shapes(), "x2"));
}
struct check_for_pass_op
{
bool* found = nullptr;
......
......@@ -3187,6 +3187,80 @@ TEST_CASE(nms_test)
EXPECT(migraphx::verify_range(result, gold));
}
TEST_CASE(nms_transpose1_test)
{
migraphx::program p;
auto* mm = p.get_main_module();
migraphx::shape boxes_s{migraphx::shape::float_type, {1, 4, 6}};
std::vector<float> boxes_vec = {
0.5, 0.5, 0.5, 0.5, 0.5, 0.5, 0.5, 0.6, 0.4, 10.5, 10.6, 100.5,
1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0,
};
migraphx::shape scores_s{migraphx::shape::float_type, {1, 1, 6}};
std::vector<float> scores_vec = {0.9, 0.75, 0.6, 0.95, 0.5, 0.3};
auto t_boxes_l = mm->add_literal(migraphx::literal(boxes_s, boxes_vec));
auto scores_l = mm->add_literal(migraphx::literal(scores_s, scores_vec));
auto max_out_l = mm->add_literal(int64_t{4});
auto iou_threshold = mm->add_literal(0.5f);
auto score_threshold = mm->add_literal(0.0f);
auto transpose_boxes = mm->add_instruction(
migraphx::make_op("transpose", {{"permutation", {0, 2, 1}}}), t_boxes_l);
auto r = mm->add_instruction(migraphx::make_op("nonmaxsuppression", {{"center_point_box", 1}}),
transpose_boxes,
scores_l,
max_out_l,
iou_threshold,
score_threshold);
mm->add_return({r});
p.compile(migraphx::ref::target{});
auto output = p.eval({}).back();
std::vector<int64_t> result;
output.visit([&](auto out) { result.assign(out.begin(), out.end()); });
std::vector<int64_t> gold = {0, 0, 3, 0, 0, 0, 0, 0, 5, 0, 0, 0, 0, 0, 0, 0, 0, 0};
EXPECT(migraphx::verify_range(result, gold));
}
TEST_CASE(nms_transpose2_test)
{
migraphx::program p;
auto* mm = p.get_main_module();
migraphx::shape boxes_s{migraphx::shape::float_type, {4, 1, 6}};
std::vector<float> boxes_vec = {
0.5, 0.5, 0.5, 0.5, 0.5, 0.5, 0.5, 0.6, 0.4, 10.5, 10.6, 100.5,
1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0,
};
migraphx::shape scores_s{migraphx::shape::float_type, {1, 1, 6}};
std::vector<float> scores_vec = {0.9, 0.75, 0.6, 0.95, 0.5, 0.3};
auto t_boxes_l = mm->add_literal(migraphx::literal(boxes_s, boxes_vec));
auto scores_l = mm->add_literal(migraphx::literal(scores_s, scores_vec));
auto max_out_l = mm->add_literal(int64_t{4});
auto iou_threshold = mm->add_literal(0.5f);
auto score_threshold = mm->add_literal(0.0f);
auto transpose_boxes = mm->add_instruction(
migraphx::make_op("transpose", {{"permutation", {1, 2, 0}}}), t_boxes_l);
auto r = mm->add_instruction(migraphx::make_op("nonmaxsuppression", {{"center_point_box", 1}}),
transpose_boxes,
scores_l,
max_out_l,
iou_threshold,
score_threshold);
mm->add_return({r});
p.compile(migraphx::ref::target{});
auto output = p.eval({}).back();
std::vector<int64_t> result;
output.visit([&](auto out) { result.assign(out.begin(), out.end()); });
std::vector<int64_t> gold = {0, 0, 3, 0, 0, 0, 0, 0, 5, 0, 0, 0, 0, 0, 0, 0, 0, 0};
EXPECT(migraphx::verify_range(result, gold));
}
TEST_CASE(nonzero_test)
{
migraphx::program p;
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment