Commit 9550f6e9 authored by charlie's avatar charlie
Browse files

Merge branch 'dyn_nms' of github.com:ROCmSoftwarePlatform/AMDMIGraphX into dyn_topK

parents d7fb5892 0d8c92a0
......@@ -25,16 +25,8 @@
#define MIGRAPHX_GUARD_OPERATORS_SCATTER_NONE_HPP
#include <array>
#include <migraphx/check_shapes.hpp>
#include <migraphx/stringutils.hpp>
#include <migraphx/streamutils.hpp>
#include <migraphx/shape_for_each.hpp>
#include <migraphx/config.hpp>
#include <migraphx/value.hpp>
#include <migraphx/op/normalize_attribute.hpp>
#include <migraphx/op/scatter.hpp>
#include <cmath>
#include <utility>
// Scatter op. with "none" as the reduction function (just copies the value). This is identical to
// the previously existing Scatter op.
......
......@@ -26,6 +26,7 @@
#include <migraphx/op/name.hpp>
#include <migraphx/check_shapes.hpp>
#include <migraphx/argument.hpp>
#include <migraphx/par_for.hpp>
namespace migraphx {
......
......@@ -24,16 +24,9 @@
#ifndef MIGRAPHX_GUARD_OPERATORS_SINH_HPP
#define MIGRAPHX_GUARD_OPERATORS_SINH_HPP
#include <array>
#include <migraphx/op/unary.hpp>
#include <migraphx/check_shapes.hpp>
#include <migraphx/stringutils.hpp>
#include <migraphx/streamutils.hpp>
#include <migraphx/literal.hpp>
#include <migraphx/shape_for_each.hpp>
#include <migraphx/config.hpp>
#include <cmath>
#include <utility>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
......
......@@ -26,6 +26,7 @@
#include <algorithm>
#include <migraphx/check_shapes.hpp>
#include <migraphx/argument.hpp>
#include <migraphx/config.hpp>
#include <migraphx/op/normalize_attribute.hpp>
#include <migraphx/par_for.hpp>
......
......@@ -25,7 +25,6 @@
#define MIGRAPHX_GUARD_RTGLIB_UNKNOWN_HPP
#include <migraphx/config.hpp>
#include <migraphx/argument.hpp>
#include <migraphx/check_shapes.hpp>
namespace migraphx {
......
......@@ -44,13 +44,21 @@ program parse_onnx_from(const onnx_options& options, Ts&&... xs)
parser.map_input_dims = options.map_input_dims;
parser.map_dyn_input_dims = options.map_dyn_input_dims;
auto dim_val = options.default_dim_value;
if(dim_val == 0)
if(dim_val != 0)
{
parser.default_dyn_dim_value = options.default_dyn_dim_value;
if(options.default_dyn_dim_value != shape::dynamic_dimension{1, 1, 0})
{
MIGRAPHX_THROW("PARSE_ONNX_FROM: both default_dim_value and default_dyn_dim_value"
"set to non-default value");
}
else
{
parser.default_dyn_dim_value = {dim_val, dim_val, 0};
}
}
else
{
parser.default_dyn_dim_value = {dim_val, dim_val, 0};
parser.default_dyn_dim_value = options.default_dyn_dim_value;
}
parser.skip_unknown_operators = options.skip_unknown_operators;
parser.max_loop_iterations = options.max_loop_iterations;
......
......@@ -257,6 +257,11 @@ int64_t onnx_parser::get_opset_version(const onnx::ModelProto& model)
void onnx_parser::parse_graph(module* mod, const onnx::GraphProto& graph)
{
if(not map_input_dims.empty() and not map_dyn_input_dims.empty())
{
MIGRAPHX_THROW("PARSE_GRAPH: both map_input_dims and map_dyn_input_dims non-empty, only"
"one should be used");
}
std::unordered_map<std::string, instruction_ref> mod_insts;
for(auto&& f : graph.initializer())
{
......
......@@ -24,6 +24,7 @@
#ifndef MIGRAPHX_GUARD_RTGLIB_INT8_CONV_PACK_HPP
#define MIGRAPHX_GUARD_RTGLIB_INT8_CONV_PACK_HPP
#include <migraphx/argument.hpp>
#include <migraphx/op/quant_dot.hpp>
#include <migraphx/config.hpp>
#include <utility>
......
......@@ -24,6 +24,7 @@
#ifndef MIGRAPHX_GUARD_RTGLIB_INT8_GEMM_PACK_HPP
#define MIGRAPHX_GUARD_RTGLIB_INT8_GEMM_PACK_HPP
#include <migraphx/argument.hpp>
#include <migraphx/op/quant_dot.hpp>
#include <migraphx/config.hpp>
#include <utility>
......
......@@ -216,7 +216,7 @@ static std::vector<T> get_data_vals(const google::protobuf::RepeatedField<T>& da
std::fill(data_vals.begin(), data_vals.end(), data[0]);
}
else
copy(data.begin(), data.end(), std::back_inserter(data_vals));
copy(data.begin(), data.end(), data_vals.begin());
return data_vals;
}
......@@ -329,33 +329,37 @@ void tf_parser::parse_node(const std::string& name)
auto&& node = nodes.at(name);
if(not is_valid_op(node))
return;
std::vector<instruction_ref> args;
for(auto&& input : node.input())
{
// control dependencies (signified by ^ before the name) are ignored
if(contains(input, "^"))
continue;
if(nodes.count(input) > 0)
std::string input_name = input;
// if input has trailing `:0` index then remove it
auto multi_out_idx = input.find(':');
if(multi_out_idx != std::string::npos && input.substr(multi_out_idx + 1) == "0")
{
input_name = input.substr(0, multi_out_idx);
}
if(nodes.count(input_name) > 0)
{
std::string iname;
// input was from a node with multiple outputs
if(contains(input, ':'))
if(contains(input_name, ':'))
{
iname = input.substr(0, input.find(':'));
input_name = input_name.substr(0, input.find(':'));
}
else
{
iname = get_name(nodes.at(input));
input_name = get_name(nodes.at(input_name));
}
assert(name != iname);
this->parse_node(iname);
args.push_back(instructions.at(input));
assert(name != input_name);
this->parse_node(input_name);
args.push_back(instructions.at(input_name));
}
else
{
args.push_back(instructions.at(input));
args.push_back(instructions.at(input_name));
}
}
std::vector<instruction_ref> result;
......
......@@ -3109,7 +3109,7 @@ def max_test():
a = helper.make_tensor_value_info('0', TensorProto.FLOAT, [3])
b = helper.make_tensor_value_info('1', TensorProto.FLOAT, [3])
c = helper.make_tensor_value_info('2', TensorProto.FLOAT, [3])
y = helper.make_tensor_value_info('2', TensorProto.FLOAT, [3])
y = helper.make_tensor_value_info('3', TensorProto.FLOAT, [3])
node = onnx.helper.make_node(
'Max',
......@@ -3243,7 +3243,7 @@ def min_test():
a = helper.make_tensor_value_info('0', TensorProto.FLOAT, [3])
b = helper.make_tensor_value_info('1', TensorProto.FLOAT, [3])
c = helper.make_tensor_value_info('2', TensorProto.FLOAT, [3])
y = helper.make_tensor_value_info('2', TensorProto.FLOAT, [3])
y = helper.make_tensor_value_info('3', TensorProto.FLOAT, [3])
node = onnx.helper.make_node(
'Min',
......
 max-example:e
max_test:a

0
1
23"Max test-dropoutZ
23"Maxmax_testZ
0

......@@ -15,7 +15,7 @@

b
2
3

B
\ No newline at end of file
B
\ No newline at end of file
 min-example:e
min_test:a

0
1
23"Min test-dropoutZ
23"Minmin_testZ
0

......@@ -15,7 +15,7 @@

b
2
3

B
\ No newline at end of file
B
\ No newline at end of file
......@@ -2874,7 +2874,9 @@ TEST_CASE(max_test)
auto l0 = mm->add_instruction(migraphx::make_op("max"), input0, input1);
mm->add_instruction(migraphx::make_op("max"), l0, input2);
optimize_onnx("max_test.onnx");
auto prog = optimize_onnx("max_test.onnx");
EXPECT(p == prog);
}
TEST_CASE(maxpool_notset_test)
......@@ -2989,7 +2991,9 @@ TEST_CASE(min_test)
auto l0 = mm->add_instruction(migraphx::make_op("min"), input0, input1);
mm->add_instruction(migraphx::make_op("min"), l0, input2);
optimize_onnx("min_test.onnx");
auto prog = optimize_onnx("min_test.onnx");
EXPECT(p == prog);
}
TEST_CASE(multinomial_test)
......@@ -5573,6 +5577,26 @@ TEST_CASE(variable_batch_user_input_test4)
EXPECT(p == prog);
}
TEST_CASE(variable_batch_user_input_test5)
{
// Error using default_dim_value and default_dyn_dim_value
migraphx::onnx_options options;
options.default_dim_value = 2;
options.default_dyn_dim_value = {1, 2, 0};
EXPECT(test::throws([&] { migraphx::parse_onnx("variable_batch_test.onnx", options); }));
}
TEST_CASE(variable_batch_user_input_test6)
{
// Error using both map_dyn_input_dims and map_input_dims
migraphx::onnx_options options;
options.map_dyn_input_dims["0"] = {{2, 5, 0}, {3, 3, 0}, {16, 16, 0}, {16, 16, 0}};
options.map_input_dims["0"] = {2, 3, 16, 16};
EXPECT(test::throws([&] { migraphx::parse_onnx("variable_batch_test.onnx", options); }));
}
TEST_CASE(variable_batch_leq_zero_test)
{
migraphx::program p;
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment