Commit b41a56cf authored by turneram's avatar turneram
Browse files

Merge remote-tracking branch 'origin/develop' into ck-proto

parents cf1172c8 1704bb04
...@@ -53,6 +53,7 @@ jobs: ...@@ -53,6 +53,7 @@ jobs:
CXX=/opt/rocm/llvm/bin/clang++ CC=/opt/rocm/llvm/bin/clang cmake \ CXX=/opt/rocm/llvm/bin/clang++ CC=/opt/rocm/llvm/bin/clang cmake \
-DMIGRAPHX_ENABLE_GPU=On \ -DMIGRAPHX_ENABLE_GPU=On \
-DMIGRAPHX_ENABLE_CPU=On \ -DMIGRAPHX_ENABLE_CPU=On \
-DMIGRAPHX_ENABLE_FPGA=On \
-DROCM_ENABLE_GH_ANNOTATIONS=On \ -DROCM_ENABLE_GH_ANNOTATIONS=On \
-DCLANG_TIDY_DEPEND_ON_TARGET=Off \ -DCLANG_TIDY_DEPEND_ON_TARGET=Off \
-DCLANG_TIDY_CACHE=/data/tidy-cache \ -DCLANG_TIDY_CACHE=/data/tidy-cache \
......
name: MIGraphX Performance Tests name: MIGraphX Performance Tests
on: on:
push:
branches: [develop]
pull_request: pull_request:
branches: [develop] branches: [develop]
types: [opened, synchronize, closed]
schedule: schedule:
- cron: "0 5 * * 1-6" - cron: "0 5 * * 1-6"
......
...@@ -90,7 +90,6 @@ add_library(migraphx ...@@ -90,7 +90,6 @@ add_library(migraphx
shape.cpp shape.cpp
simplify_algebra.cpp simplify_algebra.cpp
simplify_reshapes.cpp simplify_reshapes.cpp
target_assignments.cpp
tmp_dir.cpp tmp_dir.cpp
value.cpp value.cpp
verify_args.cpp verify_args.cpp
......
...@@ -517,7 +517,7 @@ struct shape : MIGRAPHX_CONST_HANDLE_BASE(shape) ...@@ -517,7 +517,7 @@ struct shape : MIGRAPHX_CONST_HANDLE_BASE(shape)
MIGRAPHX_DEPRECATED("Contructor without lifetime annotation is deprecated.") MIGRAPHX_DEPRECATED("Contructor without lifetime annotation is deprecated.")
shape(const migraphx_shape* p) { this->set_handle(p, borrow{}); } shape(const migraphx_shape* p) { this->set_handle(p, borrow{}); }
MIGRAPHX_HANDLE_CONSTRUCTOR(shape); MIGRAPHX_HANDLE_CONSTRUCTOR(shape)
/// Construct a scalar shape /// Construct a scalar shape
shape(migraphx_shape_datatype_t type) shape(migraphx_shape_datatype_t type)
...@@ -601,7 +601,7 @@ struct argument : MIGRAPHX_CONST_HANDLE_BASE(argument) ...@@ -601,7 +601,7 @@ struct argument : MIGRAPHX_CONST_HANDLE_BASE(argument)
{ {
argument() {} argument() {}
MIGRAPHX_HANDLE_CONSTRUCTOR(argument); MIGRAPHX_HANDLE_CONSTRUCTOR(argument)
MIGRAPHX_DEPRECATED("Contructor without lifetime annotation is deprecated.") MIGRAPHX_DEPRECATED("Contructor without lifetime annotation is deprecated.")
argument(const migraphx_argument* p) { this->set_handle(p, borrow{}); } argument(const migraphx_argument* p) { this->set_handle(p, borrow{}); }
...@@ -655,7 +655,7 @@ struct target : MIGRAPHX_HANDLE_BASE(target) ...@@ -655,7 +655,7 @@ struct target : MIGRAPHX_HANDLE_BASE(target)
{ {
target() {} target() {}
MIGRAPHX_HANDLE_CONSTRUCTOR(target); MIGRAPHX_HANDLE_CONSTRUCTOR(target)
/// Construct a target from its name /// Construct a target from its name
target(const char* name) { this->make_handle(&migraphx_target_create, name); } target(const char* name) { this->make_handle(&migraphx_target_create, name); }
...@@ -665,7 +665,7 @@ struct program_parameter_shapes : MIGRAPHX_HANDLE_BASE(program_parameter_shapes) ...@@ -665,7 +665,7 @@ struct program_parameter_shapes : MIGRAPHX_HANDLE_BASE(program_parameter_shapes)
{ {
program_parameter_shapes() {} program_parameter_shapes() {}
MIGRAPHX_HANDLE_CONSTRUCTOR(program_parameter_shapes); MIGRAPHX_HANDLE_CONSTRUCTOR(program_parameter_shapes)
size_t size() const size_t size() const
{ {
...@@ -695,7 +695,7 @@ struct program_parameter_shapes : MIGRAPHX_HANDLE_BASE(program_parameter_shapes) ...@@ -695,7 +695,7 @@ struct program_parameter_shapes : MIGRAPHX_HANDLE_BASE(program_parameter_shapes)
/// A class to construct the inputs parameters for a program /// A class to construct the inputs parameters for a program
struct program_parameters : MIGRAPHX_HANDLE_BASE(program_parameters) struct program_parameters : MIGRAPHX_HANDLE_BASE(program_parameters)
{ {
MIGRAPHX_HANDLE_CONSTRUCTOR(program_parameters); MIGRAPHX_HANDLE_CONSTRUCTOR(program_parameters)
MIGRAPHX_DEPRECATED("Contructor without lifetime annotation is deprecated.") MIGRAPHX_DEPRECATED("Contructor without lifetime annotation is deprecated.")
program_parameters(migraphx_program_parameters* p) { this->set_handle(p, borrow{}); } program_parameters(migraphx_program_parameters* p) { this->set_handle(p, borrow{}); }
...@@ -722,7 +722,7 @@ struct program_parameters : MIGRAPHX_HANDLE_BASE(program_parameters) ...@@ -722,7 +722,7 @@ struct program_parameters : MIGRAPHX_HANDLE_BASE(program_parameters)
struct arguments : MIGRAPHX_HANDLE_BASE(arguments), array_base<arguments> struct arguments : MIGRAPHX_HANDLE_BASE(arguments), array_base<arguments>
{ {
MIGRAPHX_HANDLE_CONSTRUCTOR(arguments); MIGRAPHX_HANDLE_CONSTRUCTOR(arguments)
size_t size() const size_t size() const
{ {
...@@ -741,7 +741,7 @@ struct arguments : MIGRAPHX_HANDLE_BASE(arguments), array_base<arguments> ...@@ -741,7 +741,7 @@ struct arguments : MIGRAPHX_HANDLE_BASE(arguments), array_base<arguments>
struct shapes : MIGRAPHX_HANDLE_BASE(shapes), array_base<shapes> struct shapes : MIGRAPHX_HANDLE_BASE(shapes), array_base<shapes>
{ {
MIGRAPHX_HANDLE_CONSTRUCTOR(shapes); MIGRAPHX_HANDLE_CONSTRUCTOR(shapes)
size_t size() const size_t size() const
{ {
...@@ -760,7 +760,7 @@ struct shapes : MIGRAPHX_HANDLE_BASE(shapes), array_base<shapes> ...@@ -760,7 +760,7 @@ struct shapes : MIGRAPHX_HANDLE_BASE(shapes), array_base<shapes>
struct operation : MIGRAPHX_HANDLE_BASE(operation) struct operation : MIGRAPHX_HANDLE_BASE(operation)
{ {
MIGRAPHX_HANDLE_CONSTRUCTOR(operation); MIGRAPHX_HANDLE_CONSTRUCTOR(operation)
template <class... Ts> template <class... Ts>
operation(const char* name, const char* attributes = nullptr, Ts... xs) operation(const char* name, const char* attributes = nullptr, Ts... xs)
...@@ -778,12 +778,12 @@ struct operation : MIGRAPHX_HANDLE_BASE(operation) ...@@ -778,12 +778,12 @@ struct operation : MIGRAPHX_HANDLE_BASE(operation)
struct instruction : MIGRAPHX_CONST_HANDLE_BASE(instruction) struct instruction : MIGRAPHX_CONST_HANDLE_BASE(instruction)
{ {
MIGRAPHX_HANDLE_CONSTRUCTOR(instruction); MIGRAPHX_HANDLE_CONSTRUCTOR(instruction)
}; };
struct instructions : MIGRAPHX_HANDLE_BASE(instructions) struct instructions : MIGRAPHX_HANDLE_BASE(instructions)
{ {
MIGRAPHX_HANDLE_CONSTRUCTOR(instructions); MIGRAPHX_HANDLE_CONSTRUCTOR(instructions)
template <class... Ts> template <class... Ts>
instructions(Ts... xs) instructions(Ts... xs)
...@@ -797,7 +797,7 @@ struct module; ...@@ -797,7 +797,7 @@ struct module;
struct modules : MIGRAPHX_HANDLE_BASE(modules) struct modules : MIGRAPHX_HANDLE_BASE(modules)
{ {
MIGRAPHX_HANDLE_CONSTRUCTOR(modules); MIGRAPHX_HANDLE_CONSTRUCTOR(modules)
template <class... Ts> template <class... Ts>
modules(Ts... xs) modules(Ts... xs)
...@@ -911,7 +911,7 @@ struct compile_options : MIGRAPHX_HANDLE_BASE(compile_options) ...@@ -911,7 +911,7 @@ struct compile_options : MIGRAPHX_HANDLE_BASE(compile_options)
{ {
compile_options() { this->make_handle(&migraphx_compile_options_create); } compile_options() { this->make_handle(&migraphx_compile_options_create); }
MIGRAPHX_HANDLE_CONSTRUCTOR(compile_options); MIGRAPHX_HANDLE_CONSTRUCTOR(compile_options)
/// For targets with offloaded memory(such as the gpu), this will insert /// For targets with offloaded memory(such as the gpu), this will insert
/// instructions during compilation to copy the input parameters to the /// instructions during compilation to copy the input parameters to the
...@@ -935,7 +935,7 @@ struct program : MIGRAPHX_HANDLE_BASE(program) ...@@ -935,7 +935,7 @@ struct program : MIGRAPHX_HANDLE_BASE(program)
{ {
program() { this->make_handle(&migraphx_program_create); } program() { this->make_handle(&migraphx_program_create); }
MIGRAPHX_HANDLE_CONSTRUCTOR(program); MIGRAPHX_HANDLE_CONSTRUCTOR(program)
/// Compile the program for a specific target to be ran on /// Compile the program for a specific target to be ran on
void compile(const target& ptarget, const compile_options& poptions) const void compile(const target& ptarget, const compile_options& poptions) const
...@@ -1021,7 +1021,7 @@ struct program : MIGRAPHX_HANDLE_BASE(program) ...@@ -1021,7 +1021,7 @@ struct program : MIGRAPHX_HANDLE_BASE(program)
// options for migraphx file format options // options for migraphx file format options
struct file_options : MIGRAPHX_HANDLE_BASE(file_options) struct file_options : MIGRAPHX_HANDLE_BASE(file_options)
{ {
MIGRAPHX_HANDLE_CONSTRUCTOR(file_options); MIGRAPHX_HANDLE_CONSTRUCTOR(file_options)
file_options() { this->make_handle(&migraphx_file_options_create); } file_options() { this->make_handle(&migraphx_file_options_create); }
// set file format // set file format
...@@ -1063,7 +1063,7 @@ struct onnx_options : MIGRAPHX_HANDLE_BASE(onnx_options) ...@@ -1063,7 +1063,7 @@ struct onnx_options : MIGRAPHX_HANDLE_BASE(onnx_options)
{ {
onnx_options() { this->make_handle(&migraphx_onnx_options_create); } onnx_options() { this->make_handle(&migraphx_onnx_options_create); }
MIGRAPHX_HANDLE_CONSTRUCTOR(onnx_options); MIGRAPHX_HANDLE_CONSTRUCTOR(onnx_options)
/// Make onnx parser treat an inputs with a certain dimensions /// Make onnx parser treat an inputs with a certain dimensions
void set_input_parameter_shape(const std::string& name, std::vector<std::size_t> dim) void set_input_parameter_shape(const std::string& name, std::vector<std::size_t> dim)
...@@ -1145,7 +1145,7 @@ struct tf_options : MIGRAPHX_HANDLE_BASE(tf_options) ...@@ -1145,7 +1145,7 @@ struct tf_options : MIGRAPHX_HANDLE_BASE(tf_options)
{ {
tf_options() { this->make_handle(&migraphx_tf_options_create); } tf_options() { this->make_handle(&migraphx_tf_options_create); }
MIGRAPHX_HANDLE_CONSTRUCTOR(tf_options); MIGRAPHX_HANDLE_CONSTRUCTOR(tf_options)
/// Make tf parser treat an inputs with a certain dimensions /// Make tf parser treat an inputs with a certain dimensions
void set_input_parameter_shape(const std::string& name, std::vector<std::size_t> dim) void set_input_parameter_shape(const std::string& name, std::vector<std::size_t> dim)
...@@ -1198,7 +1198,7 @@ struct quantize_op_names : MIGRAPHX_HANDLE_BASE(quantize_op_names) ...@@ -1198,7 +1198,7 @@ struct quantize_op_names : MIGRAPHX_HANDLE_BASE(quantize_op_names)
{ {
quantize_op_names() { this->make_handle(&migraphx_quantize_op_names_create); } quantize_op_names() { this->make_handle(&migraphx_quantize_op_names_create); }
MIGRAPHX_HANDLE_CONSTRUCTOR(quantize_op_names); MIGRAPHX_HANDLE_CONSTRUCTOR(quantize_op_names)
void add(const std::string& name) void add(const std::string& name)
{ {
...@@ -1223,7 +1223,7 @@ struct quantize_int8_options : MIGRAPHX_HANDLE_BASE(quantize_int8_options) ...@@ -1223,7 +1223,7 @@ struct quantize_int8_options : MIGRAPHX_HANDLE_BASE(quantize_int8_options)
{ {
quantize_int8_options() { this->make_handle(&migraphx_quantize_int8_options_create); } quantize_int8_options() { this->make_handle(&migraphx_quantize_int8_options_create); }
MIGRAPHX_HANDLE_CONSTRUCTOR(quantize_int8_options); MIGRAPHX_HANDLE_CONSTRUCTOR(quantize_int8_options)
/// Add an operator that should be quantized /// Add an operator that should be quantized
void add_op_name(const std::string& name) void add_op_name(const std::string& name)
......
...@@ -25,13 +25,10 @@ ...@@ -25,13 +25,10 @@
#include <migraphx/make_op.hpp> #include <migraphx/make_op.hpp>
#include <migraphx/program.hpp> #include <migraphx/program.hpp>
#include <migraphx/generate.hpp> #include <migraphx/generate.hpp>
#include <migraphx/json.hpp>
#include "models.hpp" #include "models.hpp"
namespace migraphx { namespace migraphx {
namespace driver { namespace driver {
inline namespace MIGRAPHX_INLINE_NS { inline namespace MIGRAPHX_INLINE_NS {
migraphx::program alexnet(unsigned batch) // NOLINT(readability-function-size) migraphx::program alexnet(unsigned batch) // NOLINT(readability-function-size)
{ {
migraphx::program p; migraphx::program p;
...@@ -42,179 +39,153 @@ migraphx::program alexnet(unsigned batch) // NOLINT(readability-function-size) ...@@ -42,179 +39,153 @@ migraphx::program alexnet(unsigned batch) // NOLINT(readability-function-size)
migraphx::generate_literal(migraphx::shape{migraphx::shape::float_type, {1}}, 1))); migraphx::generate_literal(migraphx::shape{migraphx::shape::float_type, {1}}, 1)));
auto x_main_module_2 = mmain->add_literal(migraphx::abs( auto x_main_module_2 = mmain->add_literal(migraphx::abs(
migraphx::generate_literal(migraphx::shape{migraphx::shape::float_type, {1}}, 2))); migraphx::generate_literal(migraphx::shape{migraphx::shape::float_type, {1}}, 2)));
auto x_input_1 = mmain->add_parameter( auto x_0 = mmain->add_parameter(
"input.1", migraphx::shape{migraphx::shape::float_type, {batch, 3, 224, 224}}); "0", migraphx::shape{migraphx::shape::float_type, {batch, 3, 224, 224}});
auto x_main_module_4 = mmain->add_literal( auto x_main_module_4 = mmain->add_literal(
migraphx::generate_literal(migraphx::shape{migraphx::shape::float_type, {4096, 4096}}, 3)); migraphx::generate_literal(migraphx::shape{migraphx::shape::float_type, {1000}}, 3));
auto x_main_module_5 = mmain->add_literal( auto x_main_module_5 = mmain->add_literal(
migraphx::generate_literal(migraphx::shape{migraphx::shape::float_type, {4096}}, 4)); migraphx::generate_literal(migraphx::shape{migraphx::shape::float_type, {1000, 4096}}, 4));
auto x_main_module_6 = mmain->add_literal( auto x_main_module_6 = mmain->add_literal(
migraphx::generate_literal(migraphx::shape{migraphx::shape::float_type, {4096, 9216}}, 5)); migraphx::generate_literal(migraphx::shape{migraphx::shape::float_type, {4096}}, 5));
auto x_main_module_7 = mmain->add_literal( auto x_main_module_7 = mmain->add_literal(
migraphx::generate_literal(migraphx::shape{migraphx::shape::float_type, {4096}}, 6)); migraphx::generate_literal(migraphx::shape{migraphx::shape::float_type, {4096, 4096}}, 6));
auto x_main_module_8 = mmain->add_literal( auto x_main_module_8 = mmain->add_literal(
migraphx::generate_literal(migraphx::shape{migraphx::shape::float_type, {1000, 4096}}, 7)); migraphx::generate_literal(migraphx::shape{migraphx::shape::float_type, {4096}}, 7));
auto x_main_module_9 = mmain->add_literal( auto x_main_module_9 = mmain->add_literal(
migraphx::generate_literal(migraphx::shape{migraphx::shape::float_type, {1000}}, 8)); migraphx::generate_literal(migraphx::shape{migraphx::shape::float_type, {4096, 9216}}, 8));
auto x_main_module_10 = mmain->add_literal(migraphx::generate_literal( auto x_main_module_10 = mmain->add_literal(
migraphx::shape{migraphx::shape::float_type, {256, 384, 3, 3}}, 9)); migraphx::generate_literal(migraphx::shape{migraphx::shape::float_type, {256}}, 9));
auto x_main_module_11 = mmain->add_literal( auto x_main_module_11 = mmain->add_literal(migraphx::generate_literal(
migraphx::generate_literal(migraphx::shape{migraphx::shape::float_type, {256}}, 10)); migraphx::shape{migraphx::shape::float_type, {256, 256, 3, 3}}, 10));
auto x_main_module_12 = mmain->add_literal(migraphx::generate_literal( auto x_main_module_12 = mmain->add_literal(
migraphx::shape{migraphx::shape::float_type, {384, 192, 3, 3}}, 11)); migraphx::generate_literal(migraphx::shape{migraphx::shape::float_type, {256}}, 11));
auto x_main_module_13 = mmain->add_literal( auto x_main_module_13 = mmain->add_literal(migraphx::generate_literal(
migraphx::generate_literal(migraphx::shape{migraphx::shape::float_type, {384}}, 12)); migraphx::shape{migraphx::shape::float_type, {256, 384, 3, 3}}, 12));
auto x_main_module_14 = mmain->add_literal(migraphx::generate_literal( auto x_main_module_14 = mmain->add_literal(
migraphx::shape{migraphx::shape::float_type, {192, 64, 5, 5}}, 13)); migraphx::generate_literal(migraphx::shape{migraphx::shape::float_type, {384}}, 13));
auto x_main_module_15 = mmain->add_literal( auto x_main_module_15 = mmain->add_literal(migraphx::generate_literal(
migraphx::generate_literal(migraphx::shape{migraphx::shape::float_type, {192}}, 14)); migraphx::shape{migraphx::shape::float_type, {384, 192, 3, 3}}, 14));
auto x_main_module_16 = mmain->add_literal(migraphx::generate_literal( auto x_main_module_16 = mmain->add_literal(
migraphx::shape{migraphx::shape::float_type, {256, 256, 3, 3}}, 15)); migraphx::generate_literal(migraphx::shape{migraphx::shape::float_type, {192}}, 15));
auto x_main_module_17 = mmain->add_literal( auto x_main_module_17 = mmain->add_literal(migraphx::generate_literal(
migraphx::generate_literal(migraphx::shape{migraphx::shape::float_type, {256}}, 16)); migraphx::shape{migraphx::shape::float_type, {192, 64, 5, 5}}, 16));
auto x_main_module_18 = mmain->add_literal(migraphx::generate_literal( auto x_main_module_18 = mmain->add_literal(
migraphx::shape{migraphx::shape::float_type, {64, 3, 11, 11}}, 17)); migraphx::generate_literal(migraphx::shape{migraphx::shape::float_type, {64}}, 17));
auto x_main_module_19 = mmain->add_literal( auto x_main_module_19 = mmain->add_literal(migraphx::generate_literal(
migraphx::generate_literal(migraphx::shape{migraphx::shape::float_type, {64}}, 18)); migraphx::shape{migraphx::shape::float_type, {64, 3, 11, 11}}, 18));
auto x_main_module_20 = mmain->add_instruction( auto x_main_module_20 = mmain->add_instruction(
migraphx::make_op( migraphx::make_json_op("convolution",
"convolution", "{dilation:[1,1],group:1,padding:[2,2,2,2],padding_mode:0,stride:[4,"
migraphx::from_json_string( "4],use_dynamic_same_auto_pad:0}"),
"{dilation:[1,1],group:1,padding:[2,2,2,2],padding_mode:0,stride:[4,4]}")), x_0,
x_input_1,
x_main_module_18);
auto x_main_module_21 = mmain->add_instruction(
migraphx::make_op("broadcast",
migraphx::from_json_string("{axis:1,out_lens:[1,64,55,55]}")),
x_main_module_19); x_main_module_19);
auto x_main_module_21 = mmain->add_instruction(
migraphx::make_json_op("broadcast", "{axis:1,out_lens:[1,64,55,55]}"), x_main_module_18);
auto x_main_module_22 = auto x_main_module_22 =
mmain->add_instruction(migraphx::make_op("add"), x_main_module_20, x_main_module_21); mmain->add_instruction(migraphx::make_op("add"), x_main_module_20, x_main_module_21);
auto x_main_module_23 = mmain->add_instruction(migraphx::make_op("relu"), x_main_module_22); auto x_main_module_23 = mmain->add_instruction(migraphx::make_op("relu"), x_main_module_22);
auto x_main_module_24 = mmain->add_instruction( auto x_main_module_24 = mmain->add_instruction(
migraphx::make_op( migraphx::make_json_op(
"pooling", "pooling",
migraphx::from_json_string( "{ceil_mode:0,lengths:[3,3],lp_order:2,mode:1,padding:[0,0,0,0],stride:[2,2]}"),
"{ceil_mode:0,lengths:[3,3],lp_order:2,mode:1,padding:[0,0,0,0],stride:[2,2]}")),
x_main_module_23); x_main_module_23);
auto x_main_module_25 = mmain->add_instruction( auto x_main_module_25 = mmain->add_instruction(
migraphx::make_op( migraphx::make_json_op("convolution",
"convolution", "{dilation:[1,1],group:1,padding:[2,2,2,2],padding_mode:0,stride:[1,"
migraphx::from_json_string( "1],use_dynamic_same_auto_pad:0}"),
"{dilation:[1,1],group:1,padding:[2,2,2,2],padding_mode:0,stride:[1,1]}")),
x_main_module_24, x_main_module_24,
x_main_module_14); x_main_module_17);
auto x_main_module_26 = mmain->add_instruction( auto x_main_module_26 = mmain->add_instruction(
migraphx::make_op("broadcast", migraphx::make_json_op("broadcast", "{axis:1,out_lens:[1,192,27,27]}"), x_main_module_16);
migraphx::from_json_string("{axis:1,out_lens:[1,192,27,27]}")),
x_main_module_15);
auto x_main_module_27 = auto x_main_module_27 =
mmain->add_instruction(migraphx::make_op("add"), x_main_module_25, x_main_module_26); mmain->add_instruction(migraphx::make_op("add"), x_main_module_25, x_main_module_26);
auto x_main_module_28 = mmain->add_instruction(migraphx::make_op("relu"), x_main_module_27); auto x_main_module_28 = mmain->add_instruction(migraphx::make_op("relu"), x_main_module_27);
auto x_main_module_29 = mmain->add_instruction( auto x_main_module_29 = mmain->add_instruction(
migraphx::make_op( migraphx::make_json_op(
"pooling", "pooling",
migraphx::from_json_string( "{ceil_mode:0,lengths:[3,3],lp_order:2,mode:1,padding:[0,0,0,0],stride:[2,2]}"),
"{ceil_mode:0,lengths:[3,3],lp_order:2,mode:1,padding:[0,0,0,0],stride:[2,2]}")),
x_main_module_28); x_main_module_28);
auto x_main_module_30 = mmain->add_instruction( auto x_main_module_30 = mmain->add_instruction(
migraphx::make_op( migraphx::make_json_op("convolution",
"convolution", "{dilation:[1,1],group:1,padding:[1,1,1,1],padding_mode:0,stride:[1,"
migraphx::from_json_string( "1],use_dynamic_same_auto_pad:0}"),
"{dilation:[1,1],group:1,padding:[1,1,1,1],padding_mode:0,stride:[1,1]}")),
x_main_module_29, x_main_module_29,
x_main_module_12); x_main_module_15);
auto x_main_module_31 = mmain->add_instruction( auto x_main_module_31 = mmain->add_instruction(
migraphx::make_op("broadcast", migraphx::make_json_op("broadcast", "{axis:1,out_lens:[1,384,13,13]}"), x_main_module_14);
migraphx::from_json_string("{axis:1,out_lens:[1,384,13,13]}")),
x_main_module_13);
auto x_main_module_32 = auto x_main_module_32 =
mmain->add_instruction(migraphx::make_op("add"), x_main_module_30, x_main_module_31); mmain->add_instruction(migraphx::make_op("add"), x_main_module_30, x_main_module_31);
auto x_main_module_33 = mmain->add_instruction(migraphx::make_op("relu"), x_main_module_32); auto x_main_module_33 = mmain->add_instruction(migraphx::make_op("relu"), x_main_module_32);
auto x_main_module_34 = mmain->add_instruction( auto x_main_module_34 = mmain->add_instruction(
migraphx::make_op( migraphx::make_json_op("convolution",
"convolution", "{dilation:[1,1],group:1,padding:[1,1,1,1],padding_mode:0,stride:[1,"
migraphx::from_json_string( "1],use_dynamic_same_auto_pad:0}"),
"{dilation:[1,1],group:1,padding:[1,1,1,1],padding_mode:0,stride:[1,1]}")),
x_main_module_33, x_main_module_33,
x_main_module_10); x_main_module_13);
auto x_main_module_35 = mmain->add_instruction( auto x_main_module_35 = mmain->add_instruction(
migraphx::make_op("broadcast", migraphx::make_json_op("broadcast", "{axis:1,out_lens:[1,256,13,13]}"), x_main_module_12);
migraphx::from_json_string("{axis:1,out_lens:[1,256,13,13]}")),
x_main_module_11);
auto x_main_module_36 = auto x_main_module_36 =
mmain->add_instruction(migraphx::make_op("add"), x_main_module_34, x_main_module_35); mmain->add_instruction(migraphx::make_op("add"), x_main_module_34, x_main_module_35);
auto x_main_module_37 = mmain->add_instruction(migraphx::make_op("relu"), x_main_module_36); auto x_main_module_37 = mmain->add_instruction(migraphx::make_op("relu"), x_main_module_36);
auto x_main_module_38 = mmain->add_instruction( auto x_main_module_38 = mmain->add_instruction(
migraphx::make_op( migraphx::make_json_op("convolution",
"convolution", "{dilation:[1,1],group:1,padding:[1,1,1,1],padding_mode:0,stride:[1,"
migraphx::from_json_string( "1],use_dynamic_same_auto_pad:0}"),
"{dilation:[1,1],group:1,padding:[1,1,1,1],padding_mode:0,stride:[1,1]}")),
x_main_module_37, x_main_module_37,
x_main_module_16); x_main_module_11);
auto x_main_module_39 = mmain->add_instruction( auto x_main_module_39 = mmain->add_instruction(
migraphx::make_op("broadcast", migraphx::make_json_op("broadcast", "{axis:1,out_lens:[1,256,13,13]}"), x_main_module_10);
migraphx::from_json_string("{axis:1,out_lens:[1,256,13,13]}")),
x_main_module_17);
auto x_main_module_40 = auto x_main_module_40 =
mmain->add_instruction(migraphx::make_op("add"), x_main_module_38, x_main_module_39); mmain->add_instruction(migraphx::make_op("add"), x_main_module_38, x_main_module_39);
auto x_main_module_41 = mmain->add_instruction(migraphx::make_op("relu"), x_main_module_40); auto x_main_module_41 = mmain->add_instruction(migraphx::make_op("relu"), x_main_module_40);
auto x_main_module_42 = mmain->add_instruction( auto x_main_module_42 = mmain->add_instruction(
migraphx::make_op( migraphx::make_json_op(
"pooling", "pooling",
migraphx::from_json_string( "{ceil_mode:0,lengths:[3,3],lp_order:2,mode:1,padding:[0,0,0,0],stride:[2,2]}"),
"{ceil_mode:0,lengths:[3,3],lp_order:2,mode:1,padding:[0,0,0,0],stride:[2,2]}")),
x_main_module_41); x_main_module_41);
auto x_main_module_43 = mmain->add_instruction( auto x_main_module_43 =
migraphx::make_op("reshape", migraphx::from_json_string("{dims:[1,9216]}")), mmain->add_instruction(migraphx::make_json_op("flatten", "{axis:1}"), x_main_module_42);
x_main_module_42); auto x_main_module_44 = mmain->add_instruction(migraphx::make_op("identity"), x_main_module_43);
auto x_main_module_44 = mmain->add_instruction( auto x_main_module_45 = mmain->add_instruction(
migraphx::make_op("transpose", migraphx::from_json_string("{permutation:[1,0]}")), migraphx::make_json_op("transpose", "{permutation:[1,0]}"), x_main_module_9);
x_main_module_6); auto x_main_module_46 =
auto x_main_module_45 = mmain->add_instruction(migraphx::make_op("dot"), x_main_module_44, x_main_module_45);
mmain->add_instruction(migraphx::make_op("dot"), x_main_module_43, x_main_module_44);
auto x_main_module_46 = mmain->add_instruction(
migraphx::make_op("multibroadcast", migraphx::from_json_string("{out_lens:[1,4096]}")),
x_main_module_7);
auto x_main_module_47 = mmain->add_instruction( auto x_main_module_47 = mmain->add_instruction(
migraphx::make_op("multibroadcast", migraphx::from_json_string("{out_lens:[1,4096]}")), migraphx::make_json_op("multibroadcast", "{out_lens:[1,4096]}"), x_main_module_8);
x_main_module_2); auto x_main_module_48 = mmain->add_instruction(
auto x_main_module_48 = migraphx::make_json_op("multibroadcast", "{out_lens:[1,4096]}"), x_main_module_2);
mmain->add_instruction(migraphx::make_op("mul"), x_main_module_46, x_main_module_47);
auto x_main_module_49 = auto x_main_module_49 =
mmain->add_instruction(migraphx::make_op("add"), x_main_module_45, x_main_module_48); mmain->add_instruction(migraphx::make_op("mul"), x_main_module_47, x_main_module_48);
auto x_main_module_50 = mmain->add_instruction(migraphx::make_op("relu"), x_main_module_49); auto x_main_module_50 =
auto x_main_module_51 = mmain->add_instruction( mmain->add_instruction(migraphx::make_op("add"), x_main_module_46, x_main_module_49);
migraphx::make_op("transpose", migraphx::from_json_string("{permutation:[1,0]}")), auto x_main_module_51 = mmain->add_instruction(migraphx::make_op("relu"), x_main_module_50);
x_main_module_4); auto x_main_module_52 = mmain->add_instruction(migraphx::make_op("identity"), x_main_module_51);
auto x_main_module_52 =
mmain->add_instruction(migraphx::make_op("dot"), x_main_module_50, x_main_module_51);
auto x_main_module_53 = mmain->add_instruction( auto x_main_module_53 = mmain->add_instruction(
migraphx::make_op("multibroadcast", migraphx::from_json_string("{out_lens:[1,4096]}")), migraphx::make_json_op("transpose", "{permutation:[1,0]}"), x_main_module_7);
x_main_module_5); auto x_main_module_54 =
auto x_main_module_54 = mmain->add_instruction( mmain->add_instruction(migraphx::make_op("dot"), x_main_module_52, x_main_module_53);
migraphx::make_op("multibroadcast", migraphx::from_json_string("{out_lens:[1,4096]}")), auto x_main_module_55 = mmain->add_instruction(
x_main_module_1); migraphx::make_json_op("multibroadcast", "{out_lens:[1,4096]}"), x_main_module_6);
auto x_main_module_55 = auto x_main_module_56 = mmain->add_instruction(
mmain->add_instruction(migraphx::make_op("mul"), x_main_module_53, x_main_module_54); migraphx::make_json_op("multibroadcast", "{out_lens:[1,4096]}"), x_main_module_1);
auto x_main_module_56 = auto x_main_module_57 =
mmain->add_instruction(migraphx::make_op("add"), x_main_module_52, x_main_module_55); mmain->add_instruction(migraphx::make_op("mul"), x_main_module_55, x_main_module_56);
auto x_main_module_57 = mmain->add_instruction(migraphx::make_op("relu"), x_main_module_56); auto x_main_module_58 =
auto x_main_module_58 = mmain->add_instruction( mmain->add_instruction(migraphx::make_op("add"), x_main_module_54, x_main_module_57);
migraphx::make_op("transpose", migraphx::from_json_string("{permutation:[1,0]}")), auto x_main_module_59 = mmain->add_instruction(migraphx::make_op("relu"), x_main_module_58);
x_main_module_8);
auto x_main_module_59 =
mmain->add_instruction(migraphx::make_op("dot"), x_main_module_57, x_main_module_58);
auto x_main_module_60 = mmain->add_instruction( auto x_main_module_60 = mmain->add_instruction(
migraphx::make_op("multibroadcast", migraphx::from_json_string("{out_lens:[1,1000]}")), migraphx::make_json_op("transpose", "{permutation:[1,0]}"), x_main_module_5);
x_main_module_9); auto x_main_module_61 =
auto x_main_module_61 = mmain->add_instruction( mmain->add_instruction(migraphx::make_op("dot"), x_main_module_59, x_main_module_60);
migraphx::make_op("multibroadcast", migraphx::from_json_string("{out_lens:[1,1000]}")), auto x_main_module_62 = mmain->add_instruction(
x_main_module_0); migraphx::make_json_op("multibroadcast", "{out_lens:[1,1000]}"), x_main_module_4);
auto x_main_module_62 = auto x_main_module_63 = mmain->add_instruction(
mmain->add_instruction(migraphx::make_op("mul"), x_main_module_60, x_main_module_61); migraphx::make_json_op("multibroadcast", "{out_lens:[1,1000]}"), x_main_module_0);
auto x_main_module_63 = auto x_main_module_64 =
mmain->add_instruction(migraphx::make_op("add"), x_main_module_59, x_main_module_62); mmain->add_instruction(migraphx::make_op("mul"), x_main_module_62, x_main_module_63);
mmain->add_return({x_main_module_63}); auto x_main_module_65 =
mmain->add_instruction(migraphx::make_op("add"), x_main_module_61, x_main_module_64);
mmain->add_return({x_main_module_65});
return p; return p;
} }
......
This source diff could not be displayed because it is too large. You can view the blob instead.
This diff is collapsed.
...@@ -27,6 +27,8 @@ ...@@ -27,6 +27,8 @@
#include <migraphx/config.hpp> #include <migraphx/config.hpp>
#include <migraphx/operation.hpp> #include <migraphx/operation.hpp>
#include <migraphx/value.hpp> #include <migraphx/value.hpp>
#include <migraphx/json.hpp>
#include <migraphx/convert_to_json.hpp>
namespace migraphx { namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS { inline namespace MIGRAPHX_INLINE_NS {
...@@ -46,6 +48,8 @@ operation make_op(const std::string& name, const Value& v) ...@@ -46,6 +48,8 @@ operation make_op(const std::string& name, const Value& v)
return make_op_from_value(name, v); return make_op_from_value(name, v);
} }
operation make_json_op(const std::string& name, const std::string& s);
} // namespace MIGRAPHX_INLINE_NS } // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx } // namespace migraphx
......
...@@ -35,17 +35,13 @@ struct onnx_options ...@@ -35,17 +35,13 @@ struct onnx_options
{ {
/// Old way to set default fixed dimension size /// Old way to set default fixed dimension size
std::size_t default_dim_value = 0; std::size_t default_dim_value = 0;
/*! /// Default dynamic dimension size (if both default_dim_value and default_dyn_dim_value set
* Default dynamic dimension size (if both default_dim_value and default_dyn_dim_value /// parser throws)
* set parser throws)
*/
shape::dynamic_dimension default_dyn_dim_value = {1, 1, 0}; shape::dynamic_dimension default_dyn_dim_value = {1, 1, 0};
/// Explicitly specify the dims of an input /// Explicitly specify the dims of an input
std::unordered_map<std::string, std::vector<std::size_t>> map_input_dims = {}; std::unordered_map<std::string, std::vector<std::size_t>> map_input_dims = {};
/*! /// Explicitly specify dynamic dims of an input (if both map_input_dims and map_dyn_input_dims
* Explicitly specify dynamic dims of an input (if both map_input_dims and /// set parser throws)
* map_dyn_input_dims set parser throws)
*/
std::unordered_map<std::string, std::vector<shape::dynamic_dimension>> map_dyn_input_dims = {}; std::unordered_map<std::string, std::vector<shape::dynamic_dimension>> map_dyn_input_dims = {};
/// Continue parsing onnx file if an unknown operator is found /// Continue parsing onnx file if an unknown operator is found
bool skip_unknown_operators = false; bool skip_unknown_operators = false;
...@@ -53,6 +49,8 @@ struct onnx_options ...@@ -53,6 +49,8 @@ struct onnx_options
bool print_program_on_error = false; bool print_program_on_error = false;
/// Max iter num for the loop operator /// Max iter num for the loop operator
int64_t max_loop_iterations = 10; int64_t max_loop_iterations = 10;
/// Use dynamic output for operators when available
bool use_dyn_output = false;
}; };
/// Create a program from an onnx file /// Create a program from an onnx file
......
...@@ -45,7 +45,15 @@ struct convert : unary<convert> ...@@ -45,7 +45,15 @@ struct convert : unary<convert>
shape compute_shape(std::vector<shape> inputs) const shape compute_shape(std::vector<shape> inputs) const
{ {
check_shapes{inputs, *this}.has(1); check_shapes{inputs, *this}.has(1);
return {target_type, inputs.at(0).lens(), inputs.at(0).strides()}; auto input = inputs.at(0);
if(input.dynamic())
{
return {target_type, input.dyn_dims()};
}
else
{
return {target_type, input.lens(), input.strides()};
}
} }
std::string point_op() const std::string point_op() const
......
...@@ -45,11 +45,13 @@ namespace op { ...@@ -45,11 +45,13 @@ namespace op {
struct nonmaxsuppression struct nonmaxsuppression
{ {
bool center_point_box = false; bool center_point_box = false;
bool use_dyn_output = false;
template <class Self, class F> template <class Self, class F>
static auto reflect(Self& self, F f) static auto reflect(Self& self, F f)
{ {
return pack(f(self.center_point_box, "center_point_box")); return pack(f(self.center_point_box, "center_point_box"),
f(self.use_dyn_output, "use_dyn_output"));
} }
std::string name() const { return "nonmaxsuppression"; } std::string name() const { return "nonmaxsuppression"; }
...@@ -57,27 +59,81 @@ struct nonmaxsuppression ...@@ -57,27 +59,81 @@ struct nonmaxsuppression
shape compute_shape(std::vector<shape> inputs) const shape compute_shape(std::vector<shape> inputs) const
{ {
// requires at least 2 inputs // requires at least 2 inputs
check_shapes{{inputs.at(0), inputs.at(1)}, *this}.only_dims(3); check_shapes{{inputs.at(0), inputs.at(1)}, *this, true}.only_dims(3).same_ndims();
auto lens = inputs.front().lens(); auto boxes_max_lens = inputs.at(0).max_lens();
// num batches * num boxes
const auto max_num_boxes = boxes_max_lens.at(0) * boxes_max_lens.at(1);
// check input shape auto fixed_shape_error_check = [&]() {
if(lens[1] != inputs.at(1).lens()[2]) auto lens = inputs.front().lens();
if(lens[1] != inputs.at(1).lens()[2])
{
MIGRAPHX_THROW(
"NonMaxSuppression: spatial dimension mismatch between boxes and scores input");
}
if(lens[0] != inputs.at(1).lens()[0])
{
MIGRAPHX_THROW(
"NonMaxSuppression: number of batches mismatch between boxes and scores input");
}
};
if(use_dyn_output)
{ {
MIGRAPHX_THROW( if(inputs.at(0).dynamic())
"NonMaxSuppression: spatial dimension mismatch between boxes and scores input"); {
// both boxes and scores should be dynamic
// check dynamic dimensions are consistent
const auto boxes_dims = inputs.at(0).dyn_dims();
const auto scores_dims = inputs.at(1).dyn_dims();
if(boxes_dims.at(1) != scores_dims.at(2))
{
MIGRAPHX_THROW("NonMaxSuppression: dynamic spatial dimension mismatch between "
"boxes and scores input");
}
if(boxes_dims.at(0) != scores_dims.at(0))
{
MIGRAPHX_THROW("NonMaxSuppression: dynamic number of batches mismatch between "
"boxes and scores input");
}
}
else if(inputs.at(1).dynamic())
{
// scores has dynamic shape, boxes fixed shape
// check that it is only a dynamic number of classes
const auto scores_dims = inputs.at(1).dyn_dims();
const auto boxes_lens = inputs.at(0).lens();
if(not scores_dims.at(0).is_fixed() or scores_dims.at(0).max != boxes_lens.at(0))
{
MIGRAPHX_THROW("NonMaxSuppression: scores dynamic num_classes; num_batches not "
"fixed or mismatched");
}
if(not scores_dims.at(2).is_fixed() or scores_dims.at(2).max != boxes_lens.at(1))
{
MIGRAPHX_THROW("NonMaxSuppression: scores dynamic num_classes; "
"spatial_dimension not fixed or mismatches");
}
}
else
{
fixed_shape_error_check();
}
std::vector<shape::dynamic_dimension> out_lens = {};
out_lens.push_back({0, max_num_boxes, 0});
out_lens.push_back({3, 3, 0});
return {shape::int64_type, out_lens};
} }
else
// check batch sizes
if(lens[0] != inputs.at(1).lens()[0])
{ {
MIGRAPHX_THROW( if(inputs.at(0).dynamic() or inputs.at(1).dynamic())
"NonMaxSuppression: number of batches mismatch between boxes and scores input"); {
MIGRAPHX_THROW(
"NonMaxSuppression: dynamic input shape with use_dyn_output set to false");
}
fixed_shape_error_check();
std::vector<std::size_t> out_lens = {max_num_boxes, 3};
return {shape::int64_type, out_lens};
} }
std::vector<int64_t> out_lens(2);
out_lens.at(0) = lens.at(1);
out_lens.at(1) = 3;
return {shape::int64_type, out_lens};
} }
struct box struct box
...@@ -181,13 +237,13 @@ struct nonmaxsuppression ...@@ -181,13 +237,13 @@ struct nonmaxsuppression
} }
template <class Output, class Boxes, class Scores> template <class Output, class Boxes, class Scores>
void compute_nms(Output output, std::size_t compute_nms(Output output,
Boxes boxes, Boxes boxes,
Scores scores, Scores scores,
const shape& output_shape, const shape& max_output_shape,
std::size_t max_output_boxes_per_class, std::size_t max_output_boxes_per_class,
double iou_threshold, double iou_threshold,
double score_threshold) const double score_threshold) const
{ {
std::fill(output.begin(), output.end(), 0); std::fill(output.begin(), output.end(), 0);
const auto& lens = scores.get_shape().lens(); const auto& lens = scores.get_shape().lens();
...@@ -197,7 +253,7 @@ struct nonmaxsuppression ...@@ -197,7 +253,7 @@ struct nonmaxsuppression
// boxes of a class with NMS applied [score, index] // boxes of a class with NMS applied [score, index]
std::vector<std::pair<double, int64_t>> selected_boxes_inside_class; std::vector<std::pair<double, int64_t>> selected_boxes_inside_class;
std::vector<int64_t> selected_indices; std::vector<int64_t> selected_indices;
selected_boxes_inside_class.reserve(output_shape.elements()); selected_boxes_inside_class.reserve(max_output_shape.elements());
// iterate over batches and classes // iterate over batches and classes
shape comp_s{shape::double_type, {num_batches, num_classes}}; shape comp_s{shape::double_type, {num_batches, num_classes}};
shape_for_each(comp_s, [&](auto idx) { shape_for_each(comp_s, [&](auto idx) {
...@@ -237,11 +293,14 @@ struct nonmaxsuppression ...@@ -237,11 +293,14 @@ struct nonmaxsuppression
} }
}); });
std::copy(selected_indices.begin(), selected_indices.end(), output.begin()); std::copy(selected_indices.begin(), selected_indices.end(), output.begin());
return selected_indices.size() / 3;
} }
argument compute(const shape& output_shape, std::vector<argument> args) const argument compute(const shape& output_shape, std::vector<argument> args) const
{ {
argument result{output_shape}; // make buffer of maximum size
shape max_output_shape = {output_shape.type(), output_shape.max_lens()};
argument result{max_output_shape};
std::size_t max_output_boxes_per_class = std::size_t max_output_boxes_per_class =
(args.size() > 2) ? (args.at(2).at<std::size_t>()) : 0; (args.size() > 2) ? (args.at(2).at<std::size_t>()) : 0;
...@@ -249,22 +308,29 @@ struct nonmaxsuppression ...@@ -249,22 +308,29 @@ struct nonmaxsuppression
{ {
return result; return result;
} }
double iou_threshold = (args.size() > 3) ? (args.at(3).at<double>()) : 0.0f; double iou_threshold = (args.size() > 3) ? (args.at(3).at<double>()) : 0.0f;
double score_threshold = (args.size() > 4) ? (args.at(4).at<double>()) : 0.0f; double score_threshold = (args.size() > 4) ? (args.at(4).at<double>()) : 0.0f;
std::size_t num_selected = 0;
result.visit([&](auto output) { result.visit([&](auto output) {
visit_all(args[0], args[1])([&](auto boxes, auto scores) { visit_all(args[0], args[1])([&](auto boxes, auto scores) {
compute_nms(output, num_selected = compute_nms(output,
boxes, boxes,
scores, scores,
output_shape, max_output_shape,
max_output_boxes_per_class, max_output_boxes_per_class,
iou_threshold, iou_threshold,
score_threshold); score_threshold);
}); });
}); });
if(use_dyn_output)
return result; {
return result.reshape({output_shape.type(), {num_selected, 3}});
}
else
{
return result;
}
} }
}; };
......
...@@ -21,16 +21,24 @@ ...@@ -21,16 +21,24 @@
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
* THE SOFTWARE. * THE SOFTWARE.
*/ */
#ifndef MIGRAPHX_GUARD_MIGRAPHX_SUPPORTED_SEGMENTS_HPP
#define MIGRAPHX_GUARD_MIGRAPHX_SUPPORTED_SEGMENTS_HPP
#include <migraphx/target_assignments.hpp> #include <unordered_set>
#include <migraphx/instruction_ref.hpp>
namespace migraphx { namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS { inline namespace MIGRAPHX_INLINE_NS {
void target_assignments::add_assignment(instruction_ref ins, const std::string& target) struct supported_segment
{ {
assignments.emplace(ins, target); std::unordered_set<instruction_ref> instructions;
} float metric;
};
using supported_segments = std::vector<supported_segment>;
} // namespace MIGRAPHX_INLINE_NS } // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx } // namespace migraphx
#endif // MIGRAPHX_GUARD_MIGRAPHX_SUPPORTED_SEGMENTS_HPP
...@@ -37,8 +37,10 @@ ...@@ -37,8 +37,10 @@
#include <migraphx/compile_options.hpp> #include <migraphx/compile_options.hpp>
#include <migraphx/argument.hpp> #include <migraphx/argument.hpp>
#include <migraphx/rank.hpp> #include <migraphx/rank.hpp>
#include <migraphx/module_ref.hpp>
#include <migraphx/support_metric.hpp> #include <migraphx/support_metric.hpp>
#include <migraphx/instruction_ref.hpp> #include <migraphx/instruction_ref.hpp>
#include <migraphx/supported_segments.hpp>
namespace migraphx { namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS { inline namespace MIGRAPHX_INLINE_NS {
...@@ -64,12 +66,12 @@ struct target ...@@ -64,12 +66,12 @@ struct target
*/ */
context get_context() const; context get_context() const;
/** /**
* @brief Check how well an instruction is supported on a target with the given metric * @brief Get the ranges of instructions that are supported on a target
* @param ins Instruction to check if it's supported * @param module Module to check for supported instructions
* @param metric Used to define how the return value should be interpreted * @param metric Used to define how the quality of the support should be measured
* @return The value based on the chosen metric. Negative numbers mean unsupported * @return the supported segments of the graph
*/ */
float is_supported(T&, instruction_ref ins, support_metric m) const; supported_segments target_is_supported(T&, const_module_ref mod, support_metric metric) const;
/** /**
* @brief copy an argument to the current target. * @brief copy an argument to the current target.
* *
...@@ -115,9 +117,9 @@ argument copy_from_target(T&, const argument& arg) ...@@ -115,9 +117,9 @@ argument copy_from_target(T&, const argument& arg)
} }
template <class T> template <class T>
float target_is_supported(T&, instruction_ref, support_metric) supported_segments target_find_supported(T&, const_module_ref, support_metric)
{ {
return 0; return {};
} }
#ifdef TYPE_ERASED_DECLARATION #ifdef TYPE_ERASED_DECLARATION
...@@ -132,7 +134,7 @@ struct target ...@@ -132,7 +134,7 @@ struct target
// //
context get_context() const; context get_context() const;
// (optional) // (optional)
float is_supported(instruction_ref ins, support_metric m) const; supported_segments find_supported(const_module_ref mod, support_metric m) const;
// (optional) // (optional)
argument copy_to(const argument& input) const; argument copy_to(const argument& input) const;
// (optional) // (optional)
...@@ -224,10 +226,10 @@ struct target ...@@ -224,10 +226,10 @@ struct target
return (*this).private_detail_te_get_handle().get_context(); return (*this).private_detail_te_get_handle().get_context();
} }
float is_supported(instruction_ref ins, support_metric m) const supported_segments find_supported(const_module_ref mod, support_metric m) const
{ {
assert((*this).private_detail_te_handle_mem_var); assert((*this).private_detail_te_handle_mem_var);
return (*this).private_detail_te_get_handle().is_supported(ins, m); return (*this).private_detail_te_get_handle().find_supported(mod, m);
} }
argument copy_to(const argument& input) const argument copy_to(const argument& input) const
...@@ -261,33 +263,33 @@ struct target ...@@ -261,33 +263,33 @@ struct target
virtual std::shared_ptr<private_detail_te_handle_base_type> clone() const = 0; virtual std::shared_ptr<private_detail_te_handle_base_type> clone() const = 0;
virtual const std::type_info& type() const = 0; virtual const std::type_info& type() const = 0;
virtual std::string name() const = 0; virtual std::string name() const = 0;
virtual std::vector<pass> get_passes(context& ctx, virtual std::vector<pass> get_passes(context& ctx,
const compile_options& options) const = 0; const compile_options& options) const = 0;
virtual context get_context() const = 0; virtual context get_context() const = 0;
virtual float is_supported(instruction_ref ins, support_metric m) const = 0; virtual supported_segments find_supported(const_module_ref mod, support_metric m) const = 0;
virtual argument copy_to(const argument& input) const = 0; virtual argument copy_to(const argument& input) const = 0;
virtual argument copy_from(const argument& input) const = 0; virtual argument copy_from(const argument& input) const = 0;
virtual argument allocate(const shape& s) const = 0; virtual argument allocate(const shape& s) const = 0;
}; };
template <class T> template <class T>
static auto private_detail_te_default_is_supported(char, static auto private_detail_te_default_find_supported(char,
T&& private_detail_te_self, T&& private_detail_te_self,
instruction_ref ins, const_module_ref mod,
support_metric m) support_metric m)
-> decltype(private_detail_te_self.is_supported(ins, m)) -> decltype(private_detail_te_self.find_supported(mod, m))
{ {
return private_detail_te_self.is_supported(ins, m); return private_detail_te_self.find_supported(mod, m);
} }
template <class T> template <class T>
static float private_detail_te_default_is_supported(float, static supported_segments private_detail_te_default_find_supported(float,
T&& private_detail_te_self, T&& private_detail_te_self,
instruction_ref ins, const_module_ref mod,
support_metric m) support_metric m)
{ {
return target_is_supported(private_detail_te_self, ins, m); return target_find_supported(private_detail_te_self, mod, m);
} }
template <class T> template <class T>
...@@ -372,10 +374,11 @@ struct target ...@@ -372,10 +374,11 @@ struct target
context get_context() const override { return private_detail_te_value.get_context(); } context get_context() const override { return private_detail_te_value.get_context(); }
float is_supported(instruction_ref ins, support_metric m) const override supported_segments find_supported(const_module_ref mod, support_metric m) const override
{ {
return private_detail_te_default_is_supported(char(0), private_detail_te_value, ins, m); return private_detail_te_default_find_supported(
char(0), private_detail_te_value, mod, m);
} }
argument copy_to(const argument& input) const override argument copy_to(const argument& input) const override
......
...@@ -33,10 +33,20 @@ inline namespace MIGRAPHX_INLINE_NS { ...@@ -33,10 +33,20 @@ inline namespace MIGRAPHX_INLINE_NS {
struct target_assignments struct target_assignments
{ {
void add_assignment(instruction_ref ins, const std::string& target); using iterator = std::unordered_map<instruction_ref, std::string>::const_iterator;
using value_type = std::pair<instruction_ref, std::string>;
auto begin() const { return assignments.cbegin(); } auto size() const { return assignments.size(); }
auto end() const { return assignments.cend(); } auto& at(instruction_ref ins) const { return assignments.at(ins); }
auto insert(iterator it, const std::pair<instruction_ref, std::string>& assignment)
{
return assignments.insert(it, assignment);
}
auto find(instruction_ref ins) const { return assignments.find(ins); }
auto begin() const { return assignments.begin(); }
auto end() const { return assignments.end(); }
private: private:
std::unordered_map<instruction_ref, std::string> assignments; std::unordered_map<instruction_ref, std::string> assignments;
......
...@@ -64,5 +64,10 @@ operation make_op_from_value(const std::string& name, const value& v) ...@@ -64,5 +64,10 @@ operation make_op_from_value(const std::string& name, const value& v)
}); });
} }
operation make_json_op(const std::string& name, const std::string& s)
{
return make_op(name, from_json_string(convert_to_json(s)));
}
} // namespace MIGRAPHX_INLINE_NS } // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx } // namespace migraphx
...@@ -788,12 +788,15 @@ static std::string cpp_var_name(const std::string& name) ...@@ -788,12 +788,15 @@ static std::string cpp_var_name(const std::string& name)
static void print_make_op(std::ostream& os, const operation& op) static void print_make_op(std::ostream& os, const operation& op)
{ {
os << "migraphx::make_op(" << enclose_name(op.name());
auto v = op.to_value(); auto v = op.to_value();
if(not v.empty()) if(not v.empty())
{ {
os << ", " os << "migraphx::make_json_op(" << enclose_name(op.name());
<< "migraphx::from_json_string(" << enclose_name(to_json_string(v)) << ")"; os << ", " << enclose_name(to_json_string(v));
}
else
{
os << "migraphx::make_op(" << enclose_name(op.name());
} }
os << ")"; os << ")";
} }
......
...@@ -97,6 +97,7 @@ struct onnx_parser ...@@ -97,6 +97,7 @@ struct onnx_parser
shape::dynamic_dimension default_dyn_dim_value = {1, 1, 0}; shape::dynamic_dimension default_dyn_dim_value = {1, 1, 0};
std::unordered_map<std::string, std::vector<std::size_t>> map_input_dims; std::unordered_map<std::string, std::vector<std::size_t>> map_input_dims;
std::unordered_map<std::string, std::vector<shape::dynamic_dimension>> map_dyn_input_dims; std::unordered_map<std::string, std::vector<shape::dynamic_dimension>> map_dyn_input_dims;
bool use_dyn_output = false;
bool skip_unknown_operators = false; bool skip_unknown_operators = false;
int64_t max_loop_iterations = 10; int64_t max_loop_iterations = 10;
int64_t opset_version = 13; int64_t opset_version = 13;
......
...@@ -60,8 +60,14 @@ program parse_onnx_from(const onnx_options& options, Ts&&... xs) ...@@ -60,8 +60,14 @@ program parse_onnx_from(const onnx_options& options, Ts&&... xs)
{ {
parser.default_dyn_dim_value = options.default_dyn_dim_value; parser.default_dyn_dim_value = options.default_dyn_dim_value;
} }
if(not options.map_input_dims.empty() and not options.map_dyn_input_dims.empty())
{
MIGRAPHX_THROW("PARSE_ONNX_FROM: both map_input_dims and map_dyn_input_dims non-empty, only"
"one should be used");
}
parser.skip_unknown_operators = options.skip_unknown_operators; parser.skip_unknown_operators = options.skip_unknown_operators;
parser.max_loop_iterations = options.max_loop_iterations; parser.max_loop_iterations = options.max_loop_iterations;
parser.use_dyn_output = options.use_dyn_output;
if(options.print_program_on_error) if(options.print_program_on_error)
{ {
...@@ -80,6 +86,7 @@ program parse_onnx_from(const onnx_options& options, Ts&&... xs) ...@@ -80,6 +86,7 @@ program parse_onnx_from(const onnx_options& options, Ts&&... xs)
{ {
parser.parse_from(std::forward<Ts>(xs)...); parser.parse_from(std::forward<Ts>(xs)...);
} }
return std::move(parser.prog); return std::move(parser.prog);
} }
......
...@@ -256,11 +256,6 @@ int64_t onnx_parser::get_opset_version(const onnx::ModelProto& model) ...@@ -256,11 +256,6 @@ int64_t onnx_parser::get_opset_version(const onnx::ModelProto& model)
void onnx_parser::parse_graph(module* mod, const onnx::GraphProto& graph) void onnx_parser::parse_graph(module* mod, const onnx::GraphProto& graph)
{ {
if(not map_input_dims.empty() and not map_dyn_input_dims.empty())
{
MIGRAPHX_THROW("PARSE_GRAPH: both map_input_dims and map_dyn_input_dims non-empty, only"
"one should be used");
}
std::unordered_map<std::string, instruction_ref> mod_insts; std::unordered_map<std::string, instruction_ref> mod_insts;
for(auto&& f : graph.initializer()) for(auto&& f : graph.initializer())
{ {
......
...@@ -58,7 +58,6 @@ struct parse_generic_op : op_parser<parse_generic_op> ...@@ -58,7 +58,6 @@ struct parse_generic_op : op_parser<parse_generic_op>
{"Log", "log"}, {"Log", "log"},
{"LRN", "lrn"}, {"LRN", "lrn"},
{"Neg", "neg"}, {"Neg", "neg"},
{"NonMaxSuppression", "nonmaxsuppression"},
{"Reciprocal", "recip"}, {"Reciprocal", "recip"},
{"Relu", "relu"}, {"Relu", "relu"},
{"Round", "round"}, {"Round", "round"},
...@@ -75,7 +74,7 @@ struct parse_generic_op : op_parser<parse_generic_op> ...@@ -75,7 +74,7 @@ struct parse_generic_op : op_parser<parse_generic_op>
bool needs_contiguous(const std::string& op_name) const bool needs_contiguous(const std::string& op_name) const
{ {
return contains({"flatten", "gather", "nonmaxsuppression", "scatter"}, op_name); return contains({"flatten", "gather", "scatter"}, op_name);
} }
instruction_ref parse(const op_desc& opd, instruction_ref parse(const op_desc& opd,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment