Commit a9815bf4 authored by charlie's avatar charlie
Browse files

Merge branch 'dyn_ref_multibroadcast' of...

Merge branch 'dyn_ref_multibroadcast' of github.com:ROCmSoftwarePlatform/AMDMIGraphX into dyn_transpose
parents b80e2db1 2fa68ded
...@@ -81,6 +81,14 @@ void throws_shape(const migraphx::shape&, Ts...) ...@@ -81,6 +81,14 @@ void throws_shape(const migraphx::shape&, Ts...)
"An expected shape should not be passed to throws_shape function"); "An expected shape should not be passed to throws_shape function");
} }
TEST_CASE(binary_dyn_static_error)
{
migraphx::shape a_shape{migraphx::shape::float_type, {1, 4, 4}};
std::vector<migraphx::shape::dynamic_dimension> b{{1, 1, 0}, {4, 4, 4}, {4, 4, 0}};
migraphx::shape b_shape{migraphx::shape::float_type, b};
throws_shape(migraphx::make_op("add"), a_shape, b_shape);
}
TEST_CASE(broadcast) TEST_CASE(broadcast)
{ {
{ {
...@@ -1207,6 +1215,21 @@ TEST_CASE(multibroadcast_2in_static_dyn1) ...@@ -1207,6 +1215,21 @@ TEST_CASE(multibroadcast_2in_static_dyn1)
a_shape); a_shape);
} }
TEST_CASE(multibroadcast_2in_static_dyn2)
{
migraphx::shape a_shape{migraphx::shape::float_type, {1, 6}};
std::vector<migraphx::shape::dynamic_dimension> b{{8, 8, 0}, {6, 6, 0}};
migraphx::shape b_shape{migraphx::shape::float_type, b};
expect_shape(migraphx::shape{migraphx::shape::float_type, {{8, 8, 0}, {6, 6, 0}}},
migraphx::make_op("multibroadcast", {{"out_dyn_dims", migraphx::to_value(b)}}),
a_shape,
b_shape);
expect_shape(migraphx::shape{migraphx::shape::float_type, {{8, 8, 0}, {6, 6, 0}}},
migraphx::make_op("multibroadcast", {{"out_dyn_dims", migraphx::to_value(b)}}),
b_shape,
a_shape);
}
TEST_CASE(multibroadcast_2in_static_dyn_error0) TEST_CASE(multibroadcast_2in_static_dyn_error0)
{ {
// doesn't match on first dimension // doesn't match on first dimension
...@@ -1253,6 +1276,22 @@ TEST_CASE(multibroadcast_2in_dyn_dyn0) ...@@ -1253,6 +1276,22 @@ TEST_CASE(multibroadcast_2in_dyn_dyn0)
a_shape); a_shape);
} }
TEST_CASE(multibroadcast_2in_dyn_dyn1)
{
std::vector<migraphx::shape::dynamic_dimension> a{{1, 4, 0}, {2, 4, 2}, {2, 4, 0}};
migraphx::shape a_shape{migraphx::shape::float_type, a};
std::vector<migraphx::shape::dynamic_dimension> b{{2, 4, 2}, {2, 4, 0}};
migraphx::shape b_shape{migraphx::shape::float_type, b};
expect_shape(migraphx::shape{migraphx::shape::float_type, {{1, 4, 0}, {2, 4, 2}, {2, 4, 0}}},
migraphx::make_op("multibroadcast", {{"out_dyn_dims", migraphx::to_value(a)}}),
a_shape,
b_shape);
expect_shape(migraphx::shape{migraphx::shape::float_type, {{1, 4, 0}, {2, 4, 2}, {2, 4, 0}}},
migraphx::make_op("multibroadcast", {{"out_dyn_dims", migraphx::to_value(a)}}),
b_shape,
a_shape);
}
TEST_CASE(multibroadcast_2in_dyn_dyn_error0) TEST_CASE(multibroadcast_2in_dyn_dyn_error0)
{ {
// max doesn't match on second dimension of a // max doesn't match on second dimension of a
......
...@@ -38,6 +38,27 @@ TEST_CASE(test_shape_default) ...@@ -38,6 +38,27 @@ TEST_CASE(test_shape_default)
EXPECT(s.elements() == 0); EXPECT(s.elements() == 0);
EXPECT(s.bytes() == 0); EXPECT(s.bytes() == 0);
} }
TEST_CASE(test_dyn_4arg_constructor)
{
migraphx::shape s{migraphx::shape::float_type,
{
1,
4,
4,
},
{
4,
4,
4,
},
{0, 0, 0}};
std::vector<migraphx::shape::dynamic_dimension> expected_dyn_dims = {
{1, 4, 0}, {4, 4, 0}, {4, 4, 0}};
EXPECT(s.dynamic());
EXPECT(s.dyn_dims() == expected_dyn_dims);
}
TEST_CASE(test_shape_assign) TEST_CASE(test_shape_assign)
{ {
migraphx::shape s1{migraphx::shape::float_type, {100, 32, 8, 8}}; migraphx::shape s1{migraphx::shape::float_type, {100, 32, 8, 8}};
......
/*
* The MIT License (MIT)
*
* Copyright (c) 2015-2022 Advanced Micro Devices, Inc. All rights reserved.
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
* in the Software without restriction, including without limitation the rights
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
* copies of the Software, and to permit persons to whom the Software is
* furnished to do so, subject to the following conditions:
*
* The above copyright notice and this permission notice shall be included in
* all copies or substantial portions of the Software.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
* THE SOFTWARE.
*/
#include "verify_program.hpp"
#include <migraphx/program.hpp>
#include <migraphx/generate.hpp>
#include <migraphx/make_op.hpp>
struct test_concat_broadcast_add : verify_program<test_concat_broadcast_add>
{
migraphx::program create_program() const
{
migraphx::program p;
auto* mm = p.get_main_module();
migraphx::shape s0{migraphx::shape::float_type, {1, 2, 4}};
migraphx::shape s1{migraphx::shape::float_type, {1, 6, 4}};
migraphx::shape s2{migraphx::shape::float_type, {6, 1}};
auto x = mm->add_parameter("x", s0);
auto y = mm->add_parameter("y", s0);
auto z = mm->add_parameter("z", s0);
auto concat = mm->add_instruction(migraphx::make_op("concat", {{"axis", 1}}), x, y, z);
auto b = mm->add_literal(migraphx::generate_literal(s2, 15));
auto bb =
mm->add_instruction(migraphx::make_op("multibroadcast", {{"out_lens", s1.lens()}}), b);
mm->add_instruction(migraphx::make_op("add"), concat, bb);
return p;
}
};
/*
* The MIT License (MIT)
*
* Copyright (c) 2015-2022 Advanced Micro Devices, Inc. All rights reserved.
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
* in the Software without restriction, including without limitation the rights
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
* copies of the Software, and to permit persons to whom the Software is
* furnished to do so, subject to the following conditions:
*
* The above copyright notice and this permission notice shall be included in
* all copies or substantial portions of the Software.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
* THE SOFTWARE.
*/
#include "verify_program.hpp"
#include <migraphx/program.hpp>
#include <migraphx/generate.hpp>
#include <migraphx/make_op.hpp>
struct test_slice_concat_add : verify_program<test_slice_concat_add>
{
migraphx::program create_program() const
{
migraphx::program p;
auto* mm = p.get_main_module();
migraphx::shape s0{migraphx::shape::float_type, {1, 24, 2, 2}};
migraphx::shape s1{migraphx::shape::float_type, {1, 8, 2, 2}};
auto x = mm->add_parameter("x", s0);
auto y = mm->add_parameter("y", s1);
auto z = mm->add_parameter("z", s0);
auto slice = mm->add_instruction(
migraphx::make_op("slice", {{"axes", {1}}, {"starts", {0}}, {"ends", {8}}}), x);
auto concat = mm->add_instruction(migraphx::make_op("concat", {{"axis", 1}}), slice, y, y);
mm->add_instruction(migraphx::make_op("add"), concat, z);
return p;
}
};
#####################################################################################
# The MIT License (MIT)
#
# Copyright (c) 2015-2022 Advanced Micro Devices, Inc. All rights reserved.
#
# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
# in the Software without restriction, including without limitation the rights
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
# copies of the Software, and to permit persons to whom the Software is
# furnished to do so, subject to the following conditions:
#
# The above copyright notice and this permission notice shall be included in
# all copies or substantial portions of the Software.
#
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
# THE SOFTWARE.
#####################################################################################
import argparse
import onnx
from onnx import version_converter
def parse_args():
parser = argparse.ArgumentParser(
description=
'MIGraphX Onnx Model Convertion. Use to convert the opset of the input model to MIGraphX\'s'
)
req_args = parser.add_argument_group(title='required arguments')
req_args.add_argument('--model',
type=str,
required=True,
help='path to onnx file')
req_args.add_argument('--output',
type=str,
required=True,
help='path to output onnx file')
req_args.add_argument('--opset',
type=int,
required=True,
help='The output opset')
req_args.add_argument('--infer_shapes',
action='store_true',
help='Infer shapes for output model')
parser.add_argument('--verbose',
action='store_true',
help='show verbose information (for debugging)')
args = parser.parse_args()
return args
def main():
args = parse_args()
model_path = args.model
out_model_path = args.output
target_opset = args.opset
verbose = args.verbose
infer_shapes = args.infer_shapes
original_model = onnx.load(model_path)
if verbose:
print(f"The model before conversion:\n{original_model}")
# A full list of supported adapters can be found here:
# https://github.com/onnx/onnx/blob/main/onnx/version_converter.py#L21
# Apply the version conversion on the original model
converted_model = version_converter.convert_version(
original_model, target_opset)
if infer_shapes:
converted_model = onnx.shape_inference.infer_shapes(converted_model)
if verbose:
print(f"The model after conversion:\n{converted_model}")
# Save the ONNX model
onnx.save(converted_model, out_model_path)
if __name__ == '__main__':
main()
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment