Commit 9550f6e9 authored by charlie

Merge branch 'dyn_nms' of github.com:ROCmSoftwarePlatform/AMDMIGraphX into dyn_topK

parents d7fb5892 0d8c92a0
@@ -25,16 +25,8 @@
 #define MIGRAPHX_GUARD_OPERATORS_SCATTER_NONE_HPP

 #include <array>
-#include <migraphx/check_shapes.hpp>
-#include <migraphx/stringutils.hpp>
-#include <migraphx/streamutils.hpp>
-#include <migraphx/shape_for_each.hpp>
 #include <migraphx/config.hpp>
-#include <migraphx/value.hpp>
-#include <migraphx/op/normalize_attribute.hpp>
 #include <migraphx/op/scatter.hpp>
-#include <cmath>
-#include <utility>

 // Scatter op. with "none" as the reduction function (just copies the value). This is identical to
 // the previously existing Scatter op.
...
@@ -26,6 +26,7 @@
 #include <migraphx/op/name.hpp>
 #include <migraphx/check_shapes.hpp>
+#include <migraphx/argument.hpp>
 #include <migraphx/par_for.hpp>

 namespace migraphx {
...
@@ -24,16 +24,9 @@
 #ifndef MIGRAPHX_GUARD_OPERATORS_SINH_HPP
 #define MIGRAPHX_GUARD_OPERATORS_SINH_HPP

-#include <array>
 #include <migraphx/op/unary.hpp>
-#include <migraphx/check_shapes.hpp>
-#include <migraphx/stringutils.hpp>
-#include <migraphx/streamutils.hpp>
-#include <migraphx/literal.hpp>
-#include <migraphx/shape_for_each.hpp>
 #include <migraphx/config.hpp>
 #include <cmath>
-#include <utility>

 namespace migraphx {
 inline namespace MIGRAPHX_INLINE_NS {
...
@@ -26,6 +26,7 @@
 #include <algorithm>
 #include <migraphx/check_shapes.hpp>
+#include <migraphx/argument.hpp>
 #include <migraphx/config.hpp>
 #include <migraphx/op/normalize_attribute.hpp>
 #include <migraphx/par_for.hpp>
...
@@ -25,7 +25,6 @@
 #define MIGRAPHX_GUARD_RTGLIB_UNKNOWN_HPP

 #include <migraphx/config.hpp>
-#include <migraphx/argument.hpp>
 #include <migraphx/check_shapes.hpp>

 namespace migraphx {
...
@@ -44,13 +44,21 @@ program parse_onnx_from(const onnx_options& options, Ts&&... xs)
     parser.map_input_dims     = options.map_input_dims;
     parser.map_dyn_input_dims = options.map_dyn_input_dims;
     auto dim_val = options.default_dim_value;
-    if(dim_val == 0)
+    if(dim_val != 0)
     {
-        parser.default_dyn_dim_value = options.default_dyn_dim_value;
+        if(options.default_dyn_dim_value != shape::dynamic_dimension{1, 1, 0})
+        {
+            MIGRAPHX_THROW("PARSE_ONNX_FROM: both default_dim_value and default_dyn_dim_value "
+                           "set to non-default value");
+        }
+        else
+        {
+            parser.default_dyn_dim_value = {dim_val, dim_val, 0};
+        }
     }
     else
     {
-        parser.default_dyn_dim_value = {dim_val, dim_val, 0};
+        parser.default_dyn_dim_value = options.default_dyn_dim_value;
     }
     parser.skip_unknown_operators = options.skip_unknown_operators;
     parser.max_loop_iterations    = options.max_loop_iterations;
...
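Reviewer note: with this change `default_dim_value` (static) and `default_dyn_dim_value` (dynamic) become mutually exclusive, with `{1, 1, 0}` acting as the sentinel "not set" value for the dynamic option. A minimal usage sketch, assuming the public `migraphx::parse_onnx` entry point from `<migraphx/onnx.hpp>`; the model file name is illustrative:

```cpp
#include <migraphx/onnx.hpp>

int main()
{
    // Static default only: per the diff above, the parser widens this
    // to the fixed dynamic_dimension {2, 2, 0} internally.
    migraphx::onnx_options static_opts;
    static_opts.default_dim_value = 2;
    auto p1 = migraphx::parse_onnx("model.onnx", static_opts); // illustrative file name

    // Dynamic default only: forwarded to the parser unchanged.
    migraphx::onnx_options dyn_opts;
    dyn_opts.default_dyn_dim_value = {1, 4, 0};
    auto p2 = migraphx::parse_onnx("model.onnx", dyn_opts);

    // Setting both to non-default values now throws; see
    // variable_batch_user_input_test5 near the end of this diff.
    return 0;
}
```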
@@ -257,6 +257,11 @@ int64_t onnx_parser::get_opset_version(const onnx::ModelProto& model)
 void onnx_parser::parse_graph(module* mod, const onnx::GraphProto& graph)
 {
+    if(not map_input_dims.empty() and not map_dyn_input_dims.empty())
+    {
+        MIGRAPHX_THROW("PARSE_GRAPH: both map_input_dims and map_dyn_input_dims non-empty, only "
+                       "one should be used");
+    }
     std::unordered_map<std::string, instruction_ref> mod_insts;
     for(auto&& f : graph.initializer())
     {
...
@@ -24,6 +24,7 @@
 #ifndef MIGRAPHX_GUARD_RTGLIB_INT8_CONV_PACK_HPP
 #define MIGRAPHX_GUARD_RTGLIB_INT8_CONV_PACK_HPP

+#include <migraphx/argument.hpp>
 #include <migraphx/op/quant_dot.hpp>
 #include <migraphx/config.hpp>
 #include <utility>
...
@@ -24,6 +24,7 @@
 #ifndef MIGRAPHX_GUARD_RTGLIB_INT8_GEMM_PACK_HPP
 #define MIGRAPHX_GUARD_RTGLIB_INT8_GEMM_PACK_HPP

+#include <migraphx/argument.hpp>
 #include <migraphx/op/quant_dot.hpp>
 #include <migraphx/config.hpp>
 #include <utility>
...
@@ -216,7 +216,7 @@ static std::vector<T> get_data_vals(const google::protobuf::RepeatedField<T>& da
         std::fill(data_vals.begin(), data_vals.end(), data[0]);
     }
     else
-        copy(data.begin(), data.end(), std::back_inserter(data_vals));
+        copy(data.begin(), data.end(), data_vals.begin());

     return data_vals;
 }
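Reviewer note: this one-line change matters because `get_data_vals` pre-sizes `data_vals` (the `std::fill` branch above relies on that), so appending via `std::back_inserter` left the default-constructed elements in place and doubled the vector's length; writing through `data_vals.begin()` overwrites the pre-sized storage instead. A self-contained sketch of the difference, using a plain `std::vector` as a stand-in for the protobuf `RepeatedField`:

```cpp
#include <algorithm>
#include <cassert>
#include <iterator>
#include <vector>

int main()
{
    // Stand-in for the RepeatedField contents.
    const std::vector<int> data = {1, 2, 3};

    // Old behavior: back_inserter appends after the pre-sized elements,
    // yielding {0, 0, 0, 1, 2, 3}.
    std::vector<int> appended(3);
    std::copy(data.begin(), data.end(), std::back_inserter(appended));
    assert(appended.size() == 6);

    // New behavior: writing through begin() fills the pre-sized storage,
    // yielding {1, 2, 3}.
    std::vector<int> data_vals(3);
    std::copy(data.begin(), data.end(), data_vals.begin());
    assert(data_vals == (std::vector<int>{1, 2, 3}));
    return 0;
}
```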
...
@@ -329,33 +329,37 @@ void tf_parser::parse_node(const std::string& name)
     auto&& node = nodes.at(name);
     if(not is_valid_op(node))
         return;

     std::vector<instruction_ref> args;

     for(auto&& input : node.input())
     {
         // control dependencies (signified by ^ before the name) are ignored
         if(contains(input, "^"))
             continue;
-        if(nodes.count(input) > 0)
+        std::string input_name = input;
+        // if input has trailing `:0` index then remove it
+        auto multi_out_idx = input.find(':');
+        if(multi_out_idx != std::string::npos && input.substr(multi_out_idx + 1) == "0")
+        {
+            input_name = input.substr(0, multi_out_idx);
+        }
+        if(nodes.count(input_name) > 0)
         {
-            std::string iname;
             // input was from a node with multiple outputs
-            if(contains(input, ':'))
+            if(contains(input_name, ':'))
             {
-                iname = input.substr(0, input.find(':'));
+                input_name = input_name.substr(0, input.find(':'));
             }
             else
             {
-                iname = get_name(nodes.at(input));
+                input_name = get_name(nodes.at(input_name));
             }
-            assert(name != iname);
-            this->parse_node(iname);
-            args.push_back(instructions.at(input));
+            assert(name != input_name);
+            this->parse_node(input_name);
+            args.push_back(instructions.at(input_name));
         }
         else
         {
-            args.push_back(instructions.at(input));
+            args.push_back(instructions.at(input_name));
         }
     }

     std::vector<instruction_ref> result;
...
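Reviewer note: the net effect of this hunk is that TensorFlow tensor names carrying the default output index (`foo:0`) now resolve to the same node as the bare name (`foo`), while explicit non-zero indices such as `foo:1` are preserved for the multi-output path. A standalone sketch of the normalization; `strip_default_output_index` is a hypothetical helper name, not part of the parser:

```cpp
#include <cassert>
#include <string>

// Hypothetical helper mirroring the new logic in tf_parser::parse_node:
// drop a trailing ":0" output index, keep everything else unchanged.
std::string strip_default_output_index(const std::string& input)
{
    auto pos = input.find(':');
    if(pos != std::string::npos && input.substr(pos + 1) == "0")
        return input.substr(0, pos);
    return input;
}

int main()
{
    assert(strip_default_output_index("relu:0") == "relu");     // default index stripped
    assert(strip_default_output_index("split:1") == "split:1"); // non-zero index kept
    assert(strip_default_output_index("conv") == "conv");       // bare names unchanged
    return 0;
}
```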
@@ -3109,7 +3109,7 @@ def max_test():
     a = helper.make_tensor_value_info('0', TensorProto.FLOAT, [3])
     b = helper.make_tensor_value_info('1', TensorProto.FLOAT, [3])
     c = helper.make_tensor_value_info('2', TensorProto.FLOAT, [3])
-    y = helper.make_tensor_value_info('2', TensorProto.FLOAT, [3])
+    y = helper.make_tensor_value_info('3', TensorProto.FLOAT, [3])

     node = onnx.helper.make_node(
         'Max',
@@ -3243,7 +3243,7 @@ def min_test():
     a = helper.make_tensor_value_info('0', TensorProto.FLOAT, [3])
     b = helper.make_tensor_value_info('1', TensorProto.FLOAT, [3])
     c = helper.make_tensor_value_info('2', TensorProto.FLOAT, [3])
-    y = helper.make_tensor_value_info('2', TensorProto.FLOAT, [3])
+    y = helper.make_tensor_value_info('3', TensorProto.FLOAT, [3])

     node = onnx.helper.make_node(
         'Min',
...
[Binary diff: max_test.onnx regenerated; the graph output tensor is renamed from '2' to '3' and the graph name becomes max_test, matching the gen_onnx.py change above.]
[Binary diff: min_test.onnx regenerated; the graph output tensor is renamed from '2' to '3' and the graph name becomes min_test.]
@@ -2874,7 +2874,9 @@ TEST_CASE(max_test)
     auto l0 = mm->add_instruction(migraphx::make_op("max"), input0, input1);
     mm->add_instruction(migraphx::make_op("max"), l0, input2);

-    optimize_onnx("max_test.onnx");
+    auto prog = optimize_onnx("max_test.onnx");
+
+    EXPECT(p == prog);
 }

 TEST_CASE(maxpool_notset_test)
...
@@ -2989,7 +2991,9 @@ TEST_CASE(min_test)
     auto l0 = mm->add_instruction(migraphx::make_op("min"), input0, input1);
     mm->add_instruction(migraphx::make_op("min"), l0, input2);

-    optimize_onnx("min_test.onnx");
+    auto prog = optimize_onnx("min_test.onnx");
+
+    EXPECT(p == prog);
 }

 TEST_CASE(multinomial_test)
...
@@ -5573,6 +5577,26 @@ TEST_CASE(variable_batch_user_input_test4)
     EXPECT(p == prog);
 }

+TEST_CASE(variable_batch_user_input_test5)
+{
+    // Error using default_dim_value and default_dyn_dim_value
+    migraphx::onnx_options options;
+    options.default_dim_value     = 2;
+    options.default_dyn_dim_value = {1, 2, 0};
+
+    EXPECT(test::throws([&] { migraphx::parse_onnx("variable_batch_test.onnx", options); }));
+}
+
+TEST_CASE(variable_batch_user_input_test6)
+{
+    // Error using both map_dyn_input_dims and map_input_dims
+    migraphx::onnx_options options;
+    options.map_dyn_input_dims["0"] = {{2, 5, 0}, {3, 3, 0}, {16, 16, 0}, {16, 16, 0}};
+    options.map_input_dims["0"]     = {2, 3, 16, 16};
+
+    EXPECT(test::throws([&] { migraphx::parse_onnx("variable_batch_test.onnx", options); }));
+}
+
 TEST_CASE(variable_batch_leq_zero_test)
 {
     migraphx::program p;
...