Commit 2ee0f9e8 authored by kahmed10's avatar kahmed10 Committed by mvermeulen
Browse files

Mlperf models resnet50 and mobilenetv1 support (#406)



* initial testing

* add shape op

* formatting

* add env variable for batch sizes

* formatting

* progress on driver

* progress on driver

* cleanup

* cleanup

* add and modified prev tests

* formatting

* remove comment

* add shape op test

* formatting

* manually insert shape op in test

* formatting

* create options struct for parsers

* formatting

* Add documentation for python

* Fix c++ documentaion

* add documentation to parser

* formatting

* add argmin and tests

* fix doc and definitions

* formatting

* revert test functions

* formatting
Co-authored-by: default avatarPaul Fultz II <pfultz2@yahoo.com>
parent 7c35bf11
...@@ -5,11 +5,15 @@ set(DOXYGEN_OUTPUT ${CMAKE_CURRENT_BINARY_DIR}/doxygen/) ...@@ -5,11 +5,15 @@ set(DOXYGEN_OUTPUT ${CMAKE_CURRENT_BINARY_DIR}/doxygen/)
add_doxygen_doc( add_doxygen_doc(
OUTPUT_DIRECTORY ${DOXYGEN_OUTPUT} OUTPUT_DIRECTORY ${DOXYGEN_OUTPUT}
INPUT INPUT
${CMAKE_CURRENT_SOURCE_DIR}/../src ${PROJECT_SOURCE_DIR}/src
INCLUDE_PATH INCLUDE_PATH
${CMAKE_CURRENT_SOURCE_DIR}/../src/include ${PROJECT_SOURCE_DIR}/src/include
${CMAKE_CURRENT_SOURCE_DIR}/../src/targets/cpu/include ${PROJECT_SOURCE_DIR}/src/targets/cpu/include
${CMAKE_CURRENT_SOURCE_DIR}/../src/targets/gpu/include ${PROJECT_SOURCE_DIR}/src/targets/gpu/include
STRIP_FROM_INC_PATH
${PROJECT_SOURCE_DIR}/src/include
${PROJECT_SOURCE_DIR}/src/targets/cpu/include
${PROJECT_SOURCE_DIR}/src/targets/gpu/include
SEARCH_INCLUDES YES SEARCH_INCLUDES YES
MACRO_EXPANSION YES MACRO_EXPANSION YES
RECURSIVE YES RECURSIVE YES
......
User Guide C++ User Guide
========== ==============
.. toctree:: .. toctree::
:maxdepth: 2 :maxdepth: 2
...@@ -10,4 +10,5 @@ User Guide ...@@ -10,4 +10,5 @@ User Guide
reference/operators reference/operators
reference/program reference/program
reference/targets reference/targets
reference/quantization
reference/pass reference/pass
...@@ -10,7 +10,8 @@ Welcome to AMD MIGraphX's documentation! ...@@ -10,7 +10,8 @@ Welcome to AMD MIGraphX's documentation!
:maxdepth: 3 :maxdepth: 3
:caption: Contents: :caption: Contents:
user_guide py_user_guide
cpp_user_guide
developer_guide developer_guide
......
Python User Guide
=================
.. toctree::
:maxdepth: 2
:caption: Contents:
reference/py
...@@ -6,7 +6,11 @@ operation ...@@ -6,7 +6,11 @@ operation
.. doxygenstruct:: migraphx::operation .. doxygenstruct:: migraphx::operation
.. doxygenfunction:: migraphx::MIGRAPHX_INLINE_NS::is_context_free
.. doxygenfunction:: migraphx::MIGRAPHX_INLINE_NS::has_finalize
operators operators
--------- ---------
.. doxygenfile:: operators.hpp .. doxygennamespace:: migraphx::op
...@@ -11,15 +11,10 @@ dead_code_elimination ...@@ -11,15 +11,10 @@ dead_code_elimination
.. doxygenstruct:: migraphx::dead_code_elimination .. doxygenstruct:: migraphx::dead_code_elimination
common_subexpression_elimination eliminate_common_subexpression
-------------------------------- ------------------------------
.. doxygenstruct:: migraphx::common_subexpression_elimination .. doxygenstruct:: migraphx::eliminate_common_subexpression
constant_propagate
------------------
.. doxygenstruct:: migraphx::constant_propagate
eliminate_concat eliminate_concat
---------------- ----------------
...@@ -31,10 +26,35 @@ eliminate_contiguous ...@@ -31,10 +26,35 @@ eliminate_contiguous
.. doxygenstruct:: migraphx::eliminate_contiguous .. doxygenstruct:: migraphx::eliminate_contiguous
fwd_conv_batchnorm_rewrite eliminate_identity
-------------------------- ------------------
.. doxygenstruct:: migraphx::eliminate_identity
eliminate_pad
-------------
.. doxygenstruct:: migraphx::eliminate_pad
propagate_constant
------------------
.. doxygenstruct:: migraphx::propagate_constant
rewrite_batchnorm
-----------------
.. doxygenstruct:: migraphx::rewrite_batchnorm
rewrite_rnn
-----------
.. doxygenstruct:: migraphx::rewrite_rnn
schedule
--------
.. doxygenstruct:: migraphx::fwd_conv_batchnorm_rewrite .. doxygenstruct:: migraphx::schedule
simplify_algebra simplify_algebra
---------------- ----------------
......
...@@ -22,3 +22,18 @@ parse_onnx ...@@ -22,3 +22,18 @@ parse_onnx
---------- ----------
.. doxygenfunction:: migraphx::MIGRAPHX_INLINE_NS::parse_onnx .. doxygenfunction:: migraphx::MIGRAPHX_INLINE_NS::parse_onnx
parse_tf
--------
.. doxygenfunction:: migraphx::MIGRAPHX_INLINE_NS::parse_tf
onnx_options
------------
.. doxygenstruct:: migraphx::onnx_options
tf_options
----------
.. doxygenstruct:: migraphx::tf_options
.. py:module:: migraphx
Python Reference
================
shape
-----
.. py:class:: shape(type, lens, strides=None)
Describes the shape of a tensor. This includes size, layout, and data type.
.. py:method:: type()
An integer that represents the type
:rtype: int
.. py:method:: lens()
A list of the lengths of the shape
:rtype: list[int]
.. py:method:: strides()
A list of the strides of the shape
:rtype: list[int]
.. py:method:: elements()
The number of elements in the shape
:rtype: int
.. py:method:: bytes()
The number of bytes the shape uses
:rtype: int
.. py:method:: type_size()
The number of bytes one element uses
:rtype: int
.. py:method:: packed()
Returns true if the shape is packed.
:rtype: bool
.. py:method:: transposed()
Returns true if the shape is transposed.
:rtype: bool
.. py:method:: broadcasted()
Returns true if the shape is broadcasted.
:rtype: bool
.. py:method:: standard()
Returns true if the shape is a standard shape. That is, the shape is both packed and not transposed.
:rtype: bool
.. py:method:: scalar()
Returns true if all strides are equal to 0 (scalar tensor).
:rtype: bool
argument
--------
.. py:class:: argument(data)
Construct an argument from a python buffer. This can include numpy arrays.
.. py:method:: get_shape()
Returns the shape of the argument.
:rtype: shape
.. py:method:: tolist()
Convert the elements of the argument to a python list.
:rtype: list
.. py:function:: generate_argument(s, seed=0)
Generate an argument with random data.
:param shape s: Shape of argument to generate.
:param int seed: The seed used for random number generation
:rtype: argument
target
--------
.. py:class:: target()
This represents the compilation target.
.. py:function:: get_target(name)
Constructs the target.
:param str name: The name of the target to construct. This can either be 'cpu' or 'gpu'.
:rtype: target
program
-------
.. py:class:: program()
Represents the computation graph to be compiled and run.
.. py:method:: clone()
Make a copy of the program
:rtype: program
.. py:method:: get_parameter_shapes()
Get the shapes of all the input parameters in the program.
:rtype: dict[str, shape]
.. py:method:: get_shape()
Get the shape of the final output of the program.
:rtype: shape
.. py:method:: compile(t, offload_copy=True)
Compiles the program for the target and optimizes it.
:param target t: This is the target to compile the program for.
:param bool offload_copy: For targets with offloaded memory (such as the GPU), this will insert instructions during compilation to copy the input parameters to the offloaded memory and to copy the final result from the offloaded memory back to main memory.
.. py:method:: run(params)
Run the program.
:param params: This is a map of the input parameters which will be used when running the program.
:type params: dict[str, argument]
:return: The result of the last instruction.
:rtype: argument
.. py:function:: quantize_fp16(prog, ins_names=["all"])
Quantize the program to use fp16.
:param program prog: Program to quantize.
:param ins_names: List of instructions to quantize.
:type ins_names: list[str]
.. py:function:: quantize_int8(prog, t, calibration=[], ins_names=["dot", "convolution"])
Quantize the program to use int8.
:param program prog: Program to quantize.
:param target t: Target that will be used to run the calibration data.
:param calibration: Calibration data used to decide the parameters to the int8 optimization.
:type calibration: list[dict[str, argument]]
:param ins_names: List of instructions to quantize.
:type ins_names: list[str]
parse_onnx
----------
.. py:function:: parse_onnx(filename, batch_size=1)
Load and parse an onnx file.
:param str filename: Path to file.
:param int batch_size: default batch size to use (if not specified in onnx file).
:rtype: program
parse_tf
----------
.. py:function:: parse_tf(filename, is_nhwc=True, batch_size=1)
Load and parse a tensorflow protobuf file.
:param str filename: Path to file.
:param bool is_nhwc: Use nhwc as default format.
:param int batch_size: default batch size to use (if not specified in protobuf).
:rtype: program
Quantization
============
quantize_fp16
-------------
.. doxygenfunction:: migraphx::MIGRAPHX_INLINE_NS::quantize_fp16
quantize_int8
-------------
.. doxygenfunction:: migraphx::MIGRAPHX_INLINE_NS::quantize_int8
...@@ -62,9 +62,9 @@ struct loader ...@@ -62,9 +62,9 @@ struct loader
} }
std::cout << "Reading: " << file << std::endl; std::cout << "Reading: " << file << std::endl;
if(file_type == "onnx") if(file_type == "onnx")
p = parse_onnx(file); p = parse_onnx(file, onnx_options{batch});
else if(file_type == "tf") else if(file_type == "tf")
p = parse_tf(file, is_nhwc); p = parse_tf(file, tf_options{is_nhwc, batch});
} }
else else
{ {
......
...@@ -7,8 +7,14 @@ ...@@ -7,8 +7,14 @@
namespace migraphx { namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS { inline namespace MIGRAPHX_INLINE_NS {
/// struct to pass in onnx options to parser
struct onnx_options
{
    /// Value substituted for a dynamic batch dimension (one that is
    /// unset or <= 0 in the onnx model) when building input shapes.
    unsigned int batch_size = 1;
};
/// Create a program from an onnx file /// Create a program from an onnx file
program parse_onnx(const std::string& name); program parse_onnx(const std::string& name, onnx_options = onnx_options{});
} // namespace MIGRAPHX_INLINE_NS } // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx } // namespace migraphx
......
...@@ -7,8 +7,15 @@ ...@@ -7,8 +7,15 @@
namespace migraphx { namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS { inline namespace MIGRAPHX_INLINE_NS {
/// struct to pass in tf options to parser
struct tf_options
{
    /// Whether the graph uses NHWC data layout. Defaults to true so a
    /// default-constructed tf_options matches the documented behavior
    /// of parse_tf ("default is nhwc format") and the python binding's
    /// is_nhwc=True default; previously this defaulted to false, which
    /// silently flipped the layout for parse_tf(name) callers.
    bool is_nhwc = true;
    /// Value substituted for a dynamic batch dimension (<= 0 in the
    /// protobuf) when building input shapes.
    unsigned int batch_size = 1;
};
/// Create a program from a tf pb file (default is nhwc format) /// Create a program from a tf pb file (default is nhwc format)
program parse_tf(const std::string& name, bool is_nhwc); program parse_tf(const std::string& name, tf_options = tf_options{});
} // namespace MIGRAPHX_INLINE_NS } // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx } // namespace migraphx
......
...@@ -28,8 +28,9 @@ struct onnx_parser ...@@ -28,8 +28,9 @@ struct onnx_parser
std::function<std::vector<instruction_ref>(attribute_map, std::vector<instruction_ref>)>; std::function<std::vector<instruction_ref>(attribute_map, std::vector<instruction_ref>)>;
node_map nodes; node_map nodes;
std::unordered_map<std::string, instruction_ref> instructions; std::unordered_map<std::string, instruction_ref> instructions;
program prog = program(); program prog = program();
bool is_pytorch = false; bool is_pytorch = false;
unsigned int batch_size = 1;
std::unordered_map<std::string, op_func> ops; std::unordered_map<std::string, op_func> ops;
std::unordered_map<std::string, operation> map_actv_funcs; std::unordered_map<std::string, operation> map_actv_funcs;
...@@ -1438,7 +1439,7 @@ struct onnx_parser ...@@ -1438,7 +1439,7 @@ struct onnx_parser
if(!contains(instructions, name)) if(!contains(instructions, name))
{ {
// TODO: Get shape of input parameter // TODO: Get shape of input parameter
shape s = parse_type(input.type()); shape s = parse_type(input.type(), batch_size);
instructions[name] = prog.add_parameter(name, s); instructions[name] = prog.add_parameter(name, s);
} }
} }
...@@ -1658,7 +1659,7 @@ struct onnx_parser ...@@ -1658,7 +1659,7 @@ struct onnx_parser
return literal{{shape_type, dims}, data.begin(), data.end()}; return literal{{shape_type, dims}, data.begin(), data.end()};
} }
static shape parse_type(const onnx::TypeProto& t) static shape parse_type(const onnx::TypeProto& t, const unsigned int batch_size)
{ {
shape::type_t shape_type{}; shape::type_t shape_type{};
switch(t.tensor_type().elem_type()) switch(t.tensor_type().elem_type())
...@@ -1686,13 +1687,14 @@ struct onnx_parser ...@@ -1686,13 +1687,14 @@ struct onnx_parser
std::transform(tensor_dims.begin(), std::transform(tensor_dims.begin(),
tensor_dims.end(), tensor_dims.end(),
std::back_inserter(dims), std::back_inserter(dims),
[](auto&& d) -> std::size_t { [&](auto&& d) -> std::size_t {
if(not d.has_dim_value()) if(d.has_dim_value())
{ {
long default_batch_size = 1; // FIXME if(static_cast<int>(d.dim_value()) <= 0)
return default_batch_size; return batch_size;
return d.dim_value();
} }
return d.dim_value(); return batch_size;
}); });
return {shape_type, dims}; return {shape_type, dims};
} }
...@@ -1728,10 +1730,11 @@ struct onnx_parser ...@@ -1728,10 +1730,11 @@ struct onnx_parser
} }
}; };
program parse_onnx(const std::string& name) program parse_onnx(const std::string& name, onnx_options options)
{ {
std::fstream input(name.c_str(), std::ios::in | std::ios::binary); std::fstream input(name.c_str(), std::ios::in | std::ios::binary);
onnx_parser parser; onnx_parser parser;
parser.batch_size = options.batch_size;
#ifndef NDEBUG #ifndef NDEBUG
// Log the program when it can't be parsed // Log the program when it can't be parsed
try try
......
...@@ -173,11 +173,20 @@ PYBIND11_MODULE(migraphx, m) ...@@ -173,11 +173,20 @@ PYBIND11_MODULE(migraphx, m)
.def("__repr__", [](const migraphx::program& p) { return migraphx::to_string(p); }); .def("__repr__", [](const migraphx::program& p) { return migraphx::to_string(p); });
m.def("parse_tf", m.def("parse_tf",
&migraphx::parse_tf, [](const std::string& filename, bool is_nhwc, unsigned int batch_size) {
return migraphx::parse_tf(filename, migraphx::tf_options{is_nhwc, batch_size});
},
"Parse tf protobuf (default format is nhwc)", "Parse tf protobuf (default format is nhwc)",
py::arg("filename"), py::arg("filename"),
py::arg("is_nhwc") = true); py::arg("is_nhwc") = true,
m.def("parse_onnx", &migraphx::parse_onnx); py::arg("batch_size") = 1);
m.def("parse_onnx",
[](const std::string& filename, unsigned int batch_size) {
return migraphx::parse_onnx(filename, migraphx::onnx_options{batch_size});
},
"Parse onnx file",
py::arg("filename"),
py::arg("batch_size") = 1);
m.def("get_target", [](const std::string& name) -> migraphx::target { m.def("get_target", [](const std::string& name) -> migraphx::target {
if(name == "cpu") if(name == "cpu")
......
...@@ -32,8 +32,9 @@ struct tf_parser ...@@ -32,8 +32,9 @@ struct tf_parser
node_map nodes; node_map nodes;
std::vector<tensorflow::NodeDef> input_nodes; std::vector<tensorflow::NodeDef> input_nodes;
std::unordered_map<std::string, instruction_ref> instructions; std::unordered_map<std::string, instruction_ref> instructions;
program prog = program(); program prog = program();
bool is_nhwc = true; bool is_nhwc = true;
unsigned int batch_size = 1;
std::unordered_map<std::string, op_func> ops; std::unordered_map<std::string, op_func> ops;
...@@ -189,6 +190,8 @@ struct tf_parser ...@@ -189,6 +190,8 @@ struct tf_parser
add_binary_op("SquaredDifference", op::sqdiff{}); add_binary_op("SquaredDifference", op::sqdiff{});
add_binary_op("Sub", op::sub{}); add_binary_op("Sub", op::sub{});
add_mem_op("ArgMax", &tf_parser::parse_arg_op<op::argmax>, false);
add_mem_op("ArgMin", &tf_parser::parse_arg_op<op::argmin>, false);
add_mem_op("AvgPool", &tf_parser::parse_pooling); add_mem_op("AvgPool", &tf_parser::parse_pooling);
add_mem_op("BatchMatMul", &tf_parser::parse_matmul, false); add_mem_op("BatchMatMul", &tf_parser::parse_matmul, false);
add_mem_op("BatchMatMulV2", &tf_parser::parse_matmul, false); add_mem_op("BatchMatMulV2", &tf_parser::parse_matmul, false);
...@@ -208,6 +211,7 @@ struct tf_parser ...@@ -208,6 +211,7 @@ struct tf_parser
add_mem_op("Pack", &tf_parser::parse_pack, false); add_mem_op("Pack", &tf_parser::parse_pack, false);
add_mem_op("Pad", &tf_parser::parse_pad); add_mem_op("Pad", &tf_parser::parse_pad);
add_mem_op("Reshape", &tf_parser::parse_reshape, false); add_mem_op("Reshape", &tf_parser::parse_reshape, false);
add_mem_op("Shape", &tf_parser::parse_shape, false);
add_mem_op("Slice", &tf_parser::parse_slice, false); add_mem_op("Slice", &tf_parser::parse_slice, false);
add_mem_op("Split", &tf_parser::parse_split, false); add_mem_op("Split", &tf_parser::parse_split, false);
add_mem_op("SplitV", &tf_parser::parse_split, false); add_mem_op("SplitV", &tf_parser::parse_split, false);
...@@ -323,6 +327,16 @@ struct tf_parser ...@@ -323,6 +327,16 @@ struct tf_parser
transpose); transpose);
} }
// Shared handler for TF's ArgMax/ArgMin nodes. Op is the migraphx
// reduction operator (op::argmax or op::argmin).
template <class Op>
instruction_ref
parse_arg_op(const std::string&, const attribute_map&, std::vector<instruction_ref> args)
{
    // The reduction axis is TF's second input; it must be a constant we
    // can evaluate at parse time (assumes args[1] is a literal — TODO
    // confirm no graph feeds a computed axis here).
    const int64_t axis = args[1]->eval().at<int64_t>();
    // Apply the reduction, then squeeze away the reduced dimension so
    // the output rank matches TF's semantics.
    auto reduced = prog.add_instruction(Op{axis}, args.front());
    return prog.add_instruction(op::squeeze{{axis}}, reduced);
}
instruction_ref instruction_ref
parse_batchnorm(const std::string&, attribute_map attributes, std::vector<instruction_ref> args) parse_batchnorm(const std::string&, attribute_map attributes, std::vector<instruction_ref> args)
{ {
...@@ -768,17 +782,17 @@ struct tf_parser ...@@ -768,17 +782,17 @@ struct tf_parser
return prog.add_instruction(op, make_contiguous(args[0])); return prog.add_instruction(op, make_contiguous(args[0]));
} }
void parse_from(std::istream& is) // Use a literal instruction to replace the shape since output of
// shape operator are literals in migraphx
instruction_ref
parse_shape(const std::string&, const attribute_map&, std::vector<instruction_ref> args)
{ {
tensorflow::GraphDef graph; std::vector<std::size_t> arg_shape = args[0]->get_shape().lens();
if(graph.ParseFromIstream(&is)) std::vector<int32_t> vec_shape(arg_shape.size());
{ migraphx::shape s(migraphx::shape::int32_type, {arg_shape.size()});
this->parse_graph(graph); std::transform(
} arg_shape.begin(), arg_shape.end(), vec_shape.begin(), [](auto i) { return i; });
else return prog.add_literal(migraphx::literal{s, vec_shape});
{
throw std::runtime_error("Failed reading tf file");
}
} }
instruction_ref instruction_ref
...@@ -1006,6 +1020,9 @@ struct tf_parser ...@@ -1006,6 +1020,9 @@ struct tf_parser
{ {
reorder_data(dims); reorder_data(dims);
} }
std::transform(dims.begin(), dims.end(), dims.begin(), [&](auto dim) {
return static_cast<int>(dim) <= 0 ? batch_size : dim;
});
shape s = shape{shape_type, dims}; shape s = shape{shape_type, dims};
instructions[name] = to_nhwc(prog.add_parameter(name, s)); instructions[name] = to_nhwc(prog.add_parameter(name, s));
} }
...@@ -1072,6 +1089,19 @@ struct tf_parser ...@@ -1072,6 +1089,19 @@ struct tf_parser
} }
} }
// Deserialize a tensorflow GraphDef from the stream and hand it to the
// graph parser; a protobuf parse failure is surfaced as a runtime_error.
void parse_from(std::istream& is)
{
    tensorflow::GraphDef graph;
    if(not graph.ParseFromIstream(&is))
    {
        throw std::runtime_error("Failed reading tf file");
    }
    this->parse_graph(graph);
}
static attribute_map get_attributes(const tensorflow::NodeDef& node) static attribute_map get_attributes(const tensorflow::NodeDef& node)
{ {
attribute_map result; attribute_map result;
...@@ -1343,11 +1373,12 @@ struct tf_parser ...@@ -1343,11 +1373,12 @@ struct tf_parser
} }
}; };
program parse_tf(const std::string& name, bool is_nhwc) program parse_tf(const std::string& name, tf_options options)
{ {
std::fstream input(name.c_str(), std::ios::in | std::ios::binary); std::fstream input(name.c_str(), std::ios::in | std::ios::binary);
tf_parser parser; tf_parser parser;
parser.is_nhwc = is_nhwc; parser.is_nhwc = options.is_nhwc;
parser.batch_size = options.batch_size;
#ifndef NDEBUG #ifndef NDEBUG
// Log the program when it can't be parsed // Log the program when it can't be parsed
......
...@@ -1423,23 +1423,6 @@ def sub_scalar_test(): ...@@ -1423,23 +1423,6 @@ def sub_scalar_test():
return ([arg_const, node], [arg_node], [arg_out]) return ([arg_const, node], [arg_node], [arg_out])
@onnx_test
def sum_test():
a = helper.make_tensor_value_info('0', TensorProto.FLOAT, [3])
b = helper.make_tensor_value_info('1', TensorProto.FLOAT, [3])
c = helper.make_tensor_value_info('2', TensorProto.FLOAT, [3])
y = helper.make_tensor_value_info('3', TensorProto.FLOAT, [3])
node = onnx.helper.make_node(
'Sum',
inputs=['0', '1', '2'],
outputs=['3'],
)
return ([node], [a, b, c], [y])
@onnx_test @onnx_test
def sum_test(): def sum_test():
a = helper.make_tensor_value_info('0', TensorProto.FLOAT, [3]) a = helper.make_tensor_value_info('0', TensorProto.FLOAT, [3])
...@@ -1533,7 +1516,9 @@ def transpose_gather_test(): ...@@ -1533,7 +1516,9 @@ def transpose_gather_test():
def unknown_test(): def unknown_test():
x = helper.make_tensor_value_info('0', TensorProto.FLOAT, [2, 3, 4, 5]) x = helper.make_tensor_value_info('0', TensorProto.FLOAT, [2, 3, 4, 5])
y = helper.make_tensor_value_info('1', TensorProto.FLOAT, [3, 4]) y = helper.make_tensor_value_info('1', TensorProto.FLOAT, [3, 4])
z = helper.make_tensor_value_info('2', TensorProto.FLOAT, [2, 3, 4, 5])
helper.make_tensor_value_info('2', TensorProto.FLOAT, [2, 3, 4, 5])
a = helper.make_tensor_value_info('3', TensorProto.FLOAT, [2, 3, 4, 5]) a = helper.make_tensor_value_info('3', TensorProto.FLOAT, [2, 3, 4, 5])
node = onnx.helper.make_node('Unknown', inputs=['0', '1'], outputs=['2']) node = onnx.helper.make_node('Unknown', inputs=['0', '1'], outputs=['2'])
...@@ -1541,3 +1526,26 @@ def unknown_test(): ...@@ -1541,3 +1526,26 @@ def unknown_test():
node2 = onnx.helper.make_node('Unknown', inputs=['2'], outputs=['3']) node2 = onnx.helper.make_node('Unknown', inputs=['2'], outputs=['3'])
return ([node, node2], [x, y], [a]) return ([node, node2], [x, y], [a])
@onnx_test
def variable_batch_test():
    # Batch dimension declared as None (dynamic): the parser is expected
    # to substitute its default/user-provided batch size.
    x = helper.make_tensor_value_info('0', TensorProto.FLOAT,
                                      [None, 3, 16, 16])
    y = helper.make_tensor_value_info('1', TensorProto.FLOAT,
                                      [None, 3, 16, 16])

    node = onnx.helper.make_node('Identity', inputs=['0'], outputs=['1'])

    return ([node], [x], [y])
@onnx_test
def variable_batch_leq_zero_test():
    # Batch dimensions of 0 and -1 are both treated as dynamic: the
    # parser is expected to replace any dim <= 0 with the batch size.
    x = helper.make_tensor_value_info('0', TensorProto.FLOAT, [0, 3, 16, 16])
    y = helper.make_tensor_value_info('1', TensorProto.FLOAT, [-1, 3, 16, 16])
    z = helper.make_tensor_value_info('2', TensorProto.FLOAT, [-1, 3, 16, 16])

    node = onnx.helper.make_node('Add', inputs=['0', '1'], outputs=['2'])

    return ([node], [x, y], [z])
...@@ -1120,4 +1120,25 @@ TEST_CASE(unknown_test) ...@@ -1120,4 +1120,25 @@ TEST_CASE(unknown_test)
EXPECT(p == prog); EXPECT(p == prog);
} }
// The onnx model declares its batch dimension as None; parsing without
// options must substitute the default batch_size of 1 in the parameter
// shape.
TEST_CASE(variable_batch_test)
{
    migraphx::program p;
    auto l0 = p.add_parameter("0", migraphx::shape{migraphx::shape::float_type, {1, 3, 16, 16}});
    p.add_instruction(migraphx::op::identity{}, l0);
    auto prog = migraphx::parse_onnx("variable_batch_test.onnx");

    EXPECT(p == prog);
}
// The onnx model declares batch dimensions of 0 and -1; both count as
// dynamic, so the parser must replace them with the default batch_size
// of 1.
TEST_CASE(variable_batch_leq_zero_test)
{
    migraphx::program p;
    auto l0 = p.add_parameter("0", migraphx::shape{migraphx::shape::float_type, {1, 3, 16, 16}});
    auto l1 = p.add_parameter("1", migraphx::shape{migraphx::shape::float_type, {1, 3, 16, 16}});
    p.add_instruction(migraphx::op::add{}, l0, l1);
    auto prog = migraphx::parse_onnx("variable_batch_leq_zero_test.onnx");

    EXPECT(p == prog);
}
int main(int argc, const char* argv[]) { test::run(argc, argv); } int main(int argc, const char* argv[]) { test::run(argc, argv); }
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment