Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
MIGraphX
Commits
f25606f9
Unverified
Commit
f25606f9
authored
Oct 17, 2023
by
Charlie Lin
Committed by
GitHub
Oct 17, 2023
Browse files
2 Input Reshape `ref` implementation (#2304)
parent
a7200610
Changes
8
Hide whitespace changes
Inline
Side-by-side
Showing
8 changed files
with
262 additions
and
48 deletions
+262
-48
src/include/migraphx/op/reshape.hpp
src/include/migraphx/op/reshape.hpp
+49
-9
src/onnx/parse_reshape.cpp
src/onnx/parse_reshape.cpp
+16
-6
test/onnx/gen_onnx.py
test/onnx/gen_onnx.py
+18
-0
test/onnx/onnx_test.cpp
test/onnx/onnx_test.cpp
+63
-31
test/onnx/reshape_variable_input_dyn_test.onnx
test/onnx/reshape_variable_input_dyn_test.onnx
+0
-0
test/onnx/reshape_variable_input_test.onnx
test/onnx/reshape_variable_input_test.onnx
+17
-0
test/op_shape_test.cpp
test/op_shape_test.cpp
+22
-1
test/ref/reshape.cpp
test/ref/reshape.cpp
+77
-1
No files found.
src/include/migraphx/op/reshape.hpp
View file @
f25606f9
...
@@ -36,6 +36,22 @@ namespace migraphx {
...
@@ -36,6 +36,22 @@ namespace migraphx {
inline
namespace
MIGRAPHX_INLINE_NS
{
inline
namespace
MIGRAPHX_INLINE_NS
{
namespace
op
{
namespace
op
{
/**
* 1 input version:
* reshape(input_data)
* this.dims = output_dims
* Makes a copy of input_data to the output shape.
*
* 2 input version:
* reshape(input_data, output_buffer)
* this.dims = unset
* Copies input_data to output_buffer; output_buffer already has the output shape.
* This version will not fail gracefully if the input shape and output_buffer shape are
* incompatible. There's a throw that will catch when the number of elements do not match at
* runtime. This version should only be used for dynamic reshapes (output dimensions only known at
* runtime). If output_buffer has a static shape during compile/parse, you can use the 1 input
* version.
*/
struct
reshape
struct
reshape
{
{
std
::
vector
<
int64_t
>
dims
;
std
::
vector
<
int64_t
>
dims
;
...
@@ -215,32 +231,56 @@ struct reshape
...
@@ -215,32 +231,56 @@ struct reshape
shape
compute_shape
(
std
::
vector
<
shape
>
inputs
)
const
shape
compute_shape
(
std
::
vector
<
shape
>
inputs
)
const
{
{
check_shapes
{
inputs
,
*
this
,
true
}.
has
(
1
);
check_shapes
{
inputs
,
*
this
,
true
}.
has
(
1
,
2
);
auto
n_neg_dims
=
std
::
count
(
dims
.
begin
(),
dims
.
end
(),
-
1
);
auto
n_neg_dims
=
std
::
count
(
dims
.
begin
(),
dims
.
end
(),
-
1
);
if
(
n_neg_dims
>
1
)
if
(
n_neg_dims
>
1
)
MIGRAPHX_THROW
(
"reshape: Dimensions for reshape can only have one -1 dim"
);
MIGRAPHX_THROW
(
"reshape: Dimensions for reshape can only have one -1 dim"
);
auto
s0
=
inputs
.
front
();
auto
s0
=
inputs
.
front
();
if
(
s0
.
dynamic
()
)
if
(
inputs
.
size
()
==
1
)
{
{
return
dyn_compute_shape
(
s0
);
if
(
s0
.
dynamic
())
{
return
dyn_compute_shape
(
s0
);
}
else
{
return
static_compute_shape
(
inputs
,
n_neg_dims
);
}
}
}
else
else
{
{
return
static_compute_shape
(
inputs
,
n_neg_dims
);
return
inputs
.
back
(
);
}
}
}
}
argument
compute
(
const
dyn_output
&
dyn_out
,
std
::
vector
<
argument
>
args
)
const
argument
compute
(
const
dyn_output
&
dyn_out
,
std
::
vector
<
argument
>
args
)
const
{
{
assert
(
dyn_out
.
computed_shape
.
standard
());
assert
(
dyn_out
.
computed_shape
.
standard
());
argument
result
{
dyn_out
.
computed_shape
};
if
(
args
.
size
()
==
1
)
{
argument
result
{
dyn_out
.
computed_shape
};
visit_all
(
result
,
args
[
0
])([
&
](
auto
output
,
auto
input
)
{
visit_all
(
result
,
args
[
0
])([
&
](
auto
output
,
auto
input
)
{
std
::
copy
(
input
.
begin
(),
input
.
end
(),
output
.
begin
());
std
::
copy
(
input
.
begin
(),
input
.
end
(),
output
.
begin
());
});
});
return
result
;
return
result
;
}
else
{
// 2 arg
if
(
args
[
0
].
get_shape
().
elements
()
!=
args
[
1
].
get_shape
().
elements
())
{
MIGRAPHX_THROW
(
"Reshape: Number of elements must match at runtime. Input: "
+
std
::
to_string
(
args
[
0
].
get_shape
().
elements
())
+
" Output buffer: "
+
std
::
to_string
(
args
[
1
].
get_shape
().
elements
()));
}
visit_all
(
args
[
1
],
args
[
0
])([
&
](
auto
output
,
auto
input
)
{
std
::
copy
(
input
.
begin
(),
input
.
end
(),
output
.
begin
());
});
return
args
[
1
];
}
}
}
};
};
...
...
src/onnx/parse_reshape.cpp
View file @
f25606f9
/*
/*
* The MIT License (MIT)
* The MIT License (MIT)
*
*
* Copyright (c) 2015-202
2
Advanced Micro Devices, Inc. All rights reserved.
* Copyright (c) 2015-202
3
Advanced Micro Devices, Inc. All rights reserved.
*
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
* of this software and associated documentation files (the "Software"), to deal
...
@@ -45,15 +45,25 @@ struct parse_reshape : op_parser<parse_reshape>
...
@@ -45,15 +45,25 @@ struct parse_reshape : op_parser<parse_reshape>
{
{
literal
s
=
parser
.
parse_value
(
info
.
attributes
.
at
(
"shape"
));
literal
s
=
parser
.
parse_value
(
info
.
attributes
.
at
(
"shape"
));
s
.
visit
([
&
](
auto
v
)
{
copy
(
v
,
std
::
back_inserter
(
dims
));
});
s
.
visit
([
&
](
auto
v
)
{
copy
(
v
,
std
::
back_inserter
(
dims
));
});
return
info
.
add_instruction
(
make_op
(
"reshape"
,
{{
"dims"
,
dims
}}),
args
[
0
]);
}
}
if
(
args
.
size
()
==
2
)
else
{
{
// 2 inputs
auto
s
=
args
[
1
]
->
eval
();
auto
s
=
args
[
1
]
->
eval
();
check_arg_empty
(
s
,
"Reshape: non-constant shape input is not supported"
);
if
(
s
.
empty
())
s
.
visit
([
&
](
auto
v
)
{
copy
(
v
,
std
::
back_inserter
(
dims
));
});
{
// arg[1] not eval-able
auto
alloc_ins
=
info
.
add_instruction
(
make_op
(
"allocate"
,
{{
"buf_type"
,
args
[
0
]
->
get_shape
().
type
()}}),
args
[
1
]);
return
info
.
add_instruction
(
make_op
(
"reshape"
),
args
[
0
],
alloc_ins
);
}
else
{
s
.
visit
([
&
](
auto
v
)
{
copy
(
v
,
std
::
back_inserter
(
dims
));
});
return
info
.
add_instruction
(
make_op
(
"reshape"
,
{{
"dims"
,
dims
}}),
args
[
0
]);
}
}
}
return
info
.
add_instruction
(
make_op
(
"reshape"
,
{{
"dims"
,
dims
}}),
args
[
0
]);
}
}
};
};
...
...
test/onnx/gen_onnx.py
View file @
f25606f9
...
@@ -6065,6 +6065,24 @@ def reshape_non_standard_test():
...
@@ -6065,6 +6065,24 @@ def reshape_non_standard_test():
return
([
trans
,
res
],
[
x
],
[
y
])
return
([
trans
,
res
],
[
x
],
[
y
])
@
onnx_test
()
def
reshape_variable_input_test
():
x
=
helper
.
make_tensor_value_info
(
'0'
,
TensorProto
.
FLOAT
,
[
4
,
2
,
3
])
x_shape
=
helper
.
make_tensor_value_info
(
'1'
,
TensorProto
.
INT64
,
[
2
])
y
=
helper
.
make_tensor_value_info
(
'2'
,
TensorProto
.
FLOAT
,
[
3
,
8
])
node
=
onnx
.
helper
.
make_node
(
'Reshape'
,
inputs
=
[
'0'
,
'1'
],
outputs
=
[
'2'
])
return
([
node
],
[
x
,
x_shape
],
[
y
])
@
onnx_test
()
def
reshape_variable_input_dyn_test
():
x
=
helper
.
make_tensor_value_info
(
'0'
,
TensorProto
.
FLOAT
,
[
None
,
2
,
3
])
x_shape
=
helper
.
make_tensor_value_info
(
'1'
,
TensorProto
.
INT64
,
[
2
])
y
=
helper
.
make_tensor_value_info
(
'2'
,
TensorProto
.
FLOAT
,
[
None
,
6
])
node
=
onnx
.
helper
.
make_node
(
'Reshape'
,
inputs
=
[
'0'
,
'1'
],
outputs
=
[
'2'
])
return
([
node
],
[
x
,
x_shape
],
[
y
])
@
onnx_test
()
@
onnx_test
()
def
resize_downsample_f_test
():
def
resize_downsample_f_test
():
scales
=
np
.
array
([
1.0
,
1.0
,
0.6
,
0.6
],
dtype
=
np
.
float32
)
scales
=
np
.
array
([
1.0
,
1.0
,
0.6
,
0.6
],
dtype
=
np
.
float32
)
...
...
test/onnx/onnx_test.cpp
View file @
f25606f9
...
@@ -362,10 +362,10 @@ TEST_CASE(averagepool_notset_test)
...
@@ -362,10 +362,10 @@ TEST_CASE(averagepool_notset_test)
auto* mm = p.get_main_module();
auto* mm = p.get_main_module();
auto input = mm->add_parameter("x", migraphx::shape{migraphx::shape::float_type, {1, 1, 5, 5}});
auto input = mm->add_parameter("x", migraphx::shape{migraphx::shape::float_type, {1, 1, 5, 5}});
auto ins = mm->add_instruction(migraphx::make_op("pooling",
auto ins = mm->add_instruction(migraphx::make_op("pooling",
{{
"mode"
,
migraphx
::
op
::
pooling_mode
::
average
},
{{"mode", migraphx::op::pooling_mode::average},
{
"padding"
,
{
2
,
2
,
2
,
2
}},
{"padding", {2, 2, 2, 2}},
{
"stride"
,
{
2
,
2
}},
{"stride", {2, 2}},
{
"lengths"
,
{
6
,
6
}}}),
{"lengths", {6, 6}}}),
input);
input);
auto ret = mm->add_instruction(
auto ret = mm->add_instruction(
migraphx::make_op("slice", {{"axes", {2, 3}}, {"starts", {1, 1}}, {"ends", {2, 2}}}), ins);
migraphx::make_op("slice", {{"axes", {2, 3}}, {"starts", {1, 1}}, {"ends", {2, 2}}}), ins);
...
@@ -382,11 +382,11 @@ TEST_CASE(averagepool_nt_cip_test)
...
@@ -382,11 +382,11 @@ TEST_CASE(averagepool_nt_cip_test)
auto input = mm->add_parameter("x", migraphx::shape{migraphx::shape::float_type, {1, 1, 5, 5}});
auto input = mm->add_parameter("x", migraphx::shape{migraphx::shape::float_type, {1, 1, 5, 5}});
std::vector<int64_t> pads = {0, 0, 0, 0, 0, 0, 1, 1};
std::vector<int64_t> pads = {0, 0, 0, 0, 0, 0, 1, 1};
auto ins_pad = mm->add_instruction(migraphx::make_op("pad", {{"pads", pads}}), input);
auto ins_pad = mm->add_instruction(migraphx::make_op("pad", {{"pads", pads}}), input);
auto
ret
=
mm
->
add_instruction
(
migraphx
::
make_op
(
"pooling"
,
auto ret
= mm->add_instruction(migraphx::make_op("pooling",
{{
"mode"
,
migraphx
::
op
::
pooling_mode
::
average
},
{{"mode", migraphx::op::pooling_mode::average},
{
"padding"
,
{
0
,
0
,
0
,
0
}},
{"padding", {0, 0, 0, 0}},
{
"stride"
,
{
2
,
2
}},
{"stride", {2, 2}},
{
"lengths"
,
{
6
,
6
}}}),
{"lengths", {6, 6}}}),
ins_pad);
ins_pad);
mm->add_return({ret});
mm->add_return({ret});
...
@@ -426,11 +426,11 @@ TEST_CASE(averagepool_sl_cip_test)
...
@@ -426,11 +426,11 @@ TEST_CASE(averagepool_sl_cip_test)
auto input = mm->add_parameter("x", migraphx::shape{migraphx::shape::float_type, {1, 1, 5, 5}});
auto input = mm->add_parameter("x", migraphx::shape{migraphx::shape::float_type, {1, 1, 5, 5}});
std::vector<int64_t> pads = {0, 0, 1, 1, 0, 0, 0, 0};
std::vector<int64_t> pads = {0, 0, 1, 1, 0, 0, 0, 0};
auto ins_pad = mm->add_instruction(migraphx::make_op("pad", {{"pads", pads}}), input);
auto ins_pad = mm->add_instruction(migraphx::make_op("pad", {{"pads", pads}}), input);
auto
ret
=
mm
->
add_instruction
(
migraphx
::
make_op
(
"pooling"
,
auto ret
= mm->add_instruction(migraphx::make_op("pooling",
{{
"mode"
,
migraphx
::
op
::
pooling_mode
::
average
},
{{"mode", migraphx::op::pooling_mode::average},
{
"padding"
,
{
0
,
0
,
0
,
0
}},
{"padding", {0, 0, 0, 0}},
{
"stride"
,
{
1
,
1
}},
{"stride", {1, 1}},
{
"lengths"
,
{
2
,
2
}}}),
{"lengths", {2, 2}}}),
ins_pad);
ins_pad);
mm->add_return({ret});
mm->add_return({ret});
auto prog = migraphx::parse_onnx("averagepool_sl_cip_test.onnx");
auto prog = migraphx::parse_onnx("averagepool_sl_cip_test.onnx");
...
@@ -444,10 +444,10 @@ TEST_CASE(averagepool_same_upper_test)
...
@@ -444,10 +444,10 @@ TEST_CASE(averagepool_same_upper_test)
auto* mm = p.get_main_module();
auto* mm = p.get_main_module();
auto input = mm->add_parameter("x", migraphx::shape{migraphx::shape::float_type, {1, 1, 5, 5}});
auto input = mm->add_parameter("x", migraphx::shape{migraphx::shape::float_type, {1, 1, 5, 5}});
auto ins = mm->add_instruction(migraphx::make_op("pooling",
auto ins = mm->add_instruction(migraphx::make_op("pooling",
{{
"mode"
,
migraphx
::
op
::
pooling_mode
::
average
},
{{"mode", migraphx::op::pooling_mode::average},
{
"padding"
,
{
1
,
1
,
1
,
1
}},
{"padding", {1, 1, 1, 1}},
{
"stride"
,
{
1
,
1
}},
{"stride", {1, 1}},
{
"lengths"
,
{
2
,
2
}}}),
{"lengths", {2, 2}}}),
input);
input);
auto ret = mm->add_instruction(
auto ret = mm->add_instruction(
migraphx::make_op("slice", {{"axes", {2, 3}}, {"starts", {1, 1}}, {"ends", {6, 6}}}), ins);
migraphx::make_op("slice", {{"axes", {2, 3}}, {"starts", {1, 1}}, {"ends", {6, 6}}}), ins);
...
@@ -1634,7 +1634,7 @@ TEST_CASE(conv_transpose_input_pads_asymm_1d_test)
...
@@ -1634,7 +1634,7 @@ TEST_CASE(conv_transpose_input_pads_asymm_1d_test)
auto l1 = mm->add_parameter("w", {migraphx::shape::float_type, {1, 2, 3}});
auto l1 = mm->add_parameter("w", {migraphx::shape::float_type, {1, 2, 3}});
auto l2 = mm->add_instruction(
auto l2 = mm->add_instruction(
migraphx::make_op("convolution_backwards",
migraphx::make_op("convolution_backwards",
{{
"padding"
,
{
0
}},
{
"stride"
,
{
2
}},
{
"dilation"
,
{
1
}}}),
{{"padding", {0}}, {"stride", {2}}, {"dilation", {1}}}),
l0,
l0,
l1);
l1);
mm->add_instruction(migraphx::make_op("slice", {{"axes", {2}}, {"starts", {0}}, {"ends", {6}}}),
mm->add_instruction(migraphx::make_op("slice", {{"axes", {2}}, {"starts", {0}}, {"ends", {6}}}),
...
@@ -1668,7 +1668,7 @@ TEST_CASE(conv_transpose_output_padding_3d_test)
...
@@ -1668,7 +1668,7 @@ TEST_CASE(conv_transpose_output_padding_3d_test)
auto l1 = mm->add_parameter("w", {migraphx::shape::float_type, {1, 2, 3, 3, 3}});
auto l1 = mm->add_parameter("w", {migraphx::shape::float_type, {1, 2, 3, 3, 3}});
auto l2 = mm->add_instruction(
auto l2 = mm->add_instruction(
migraphx::make_op("convolution_backwards",
migraphx::make_op("convolution_backwards",
{{
"padding"
,
{
0
,
0
,
0
}},
{
"stride"
,
{
3
,
2
,
2
}},
{
"dilation"
,
{
1
,
1
,
1
}}}),
{{"padding", {0, 0, 0}}, {"stride", {3, 2, 2}}, {"dilation", {1, 1, 1}}}),
l0,
l0,
l1);
l1);
mm->add_instruction(migraphx::make_op("pad", {{"pads", {0, 0, 0, 0, 0, 0, 0, 1, 1, 1}}}), l2);
mm->add_instruction(migraphx::make_op("pad", {{"pads", {0, 0, 0, 0, 0, 0, 0, 1, 1, 1}}}), l2);
...
@@ -1701,7 +1701,7 @@ TEST_CASE(conv_transpose_output_shape_3d_test)
...
@@ -1701,7 +1701,7 @@ TEST_CASE(conv_transpose_output_shape_3d_test)
auto l1 = mm->add_parameter("w", {migraphx::shape::float_type, {1, 2, 3, 3, 3}});
auto l1 = mm->add_parameter("w", {migraphx::shape::float_type, {1, 2, 3, 3, 3}});
auto l2 = mm->add_instruction(
auto l2 = mm->add_instruction(
migraphx::make_op("convolution_backwards",
migraphx::make_op("convolution_backwards",
{{
"padding"
,
{
0
,
0
,
0
}},
{
"stride"
,
{
3
,
2
,
2
}},
{
"dilation"
,
{
1
,
1
,
1
}}}),
{{"padding", {0, 0, 0}}, {"stride", {3, 2, 2}}, {"dilation", {1, 1, 1}}}),
l0,
l0,
l1);
l1);
mm->add_instruction(migraphx::make_op("pad", {{"pads", {0, 0, 0, 0, 0, 0, 0, 1, 1, 1}}}), l2);
mm->add_instruction(migraphx::make_op("pad", {{"pads", {0, 0, 0, 0, 0, 0, 0, 1, 1, 1}}}), l2);
...
@@ -1996,7 +1996,7 @@ TEST_CASE(equal_test)
...
@@ -1996,7 +1996,7 @@ TEST_CASE(equal_test)
auto eq = mm->add_instruction(migraphx::make_op("equal"), input1, input2);
auto eq = mm->add_instruction(migraphx::make_op("equal"), input1, input2);
auto ret = mm->add_instruction(
auto ret = mm->add_instruction(
migraphx::make_op("convert",
migraphx::make_op("convert",
{{
"target_type"
,
migraphx
::
to_value
(
migraphx
::
shape
::
bool_type
)}}),
{{"target_type", migraphx::to_value(migraphx::shape::bool_type)}}),
eq);
eq);
mm->add_return({ret});
mm->add_return({ret});
...
@@ -2016,7 +2016,7 @@ TEST_CASE(equal_bool_test)
...
@@ -2016,7 +2016,7 @@ TEST_CASE(equal_bool_test)
auto input2 = mm->add_parameter("x2", sb);
auto input2 = mm->add_parameter("x2", sb);
auto cin1 = mm->add_instruction(
auto cin1 = mm->add_instruction(
migraphx::make_op("convert",
migraphx::make_op("convert",
{{
"target_type"
,
migraphx
::
to_value
(
migraphx
::
shape
::
bool_type
)}}),
{{"target_type", migraphx::to_value(migraphx::shape::bool_type)}}),
input1);
input1);
auto ret = mm->add_instruction(migraphx::make_op("equal"), cin1, input2);
auto ret = mm->add_instruction(migraphx::make_op("equal"), cin1, input2);
mm->add_return({ret});
mm->add_return({ret});
...
@@ -2726,7 +2726,7 @@ TEST_CASE(greater_test)
...
@@ -2726,7 +2726,7 @@ TEST_CASE(greater_test)
auto gr = mm->add_instruction(migraphx::make_op("greater"), input1, input2);
auto gr = mm->add_instruction(migraphx::make_op("greater"), input1, input2);
auto ret = mm->add_instruction(
auto ret = mm->add_instruction(
migraphx::make_op("convert",
migraphx::make_op("convert",
{{
"target_type"
,
migraphx
::
to_value
(
migraphx
::
shape
::
bool_type
)}}),
{{"target_type", migraphx::to_value(migraphx::shape::bool_type)}}),
gr);
gr);
mm->add_return({ret});
mm->add_return({ret});
...
@@ -2745,7 +2745,7 @@ TEST_CASE(greater_bool_test)
...
@@ -2745,7 +2745,7 @@ TEST_CASE(greater_bool_test)
auto input2 = mm->add_parameter("x2", sb);
auto input2 = mm->add_parameter("x2", sb);
auto cin1 = mm->add_instruction(
auto cin1 = mm->add_instruction(
migraphx::make_op("convert",
migraphx::make_op("convert",
{{
"target_type"
,
migraphx
::
to_value
(
migraphx
::
shape
::
bool_type
)}}),
{{"target_type", migraphx::to_value(migraphx::shape::bool_type)}}),
input1);
input1);
auto ret = mm->add_instruction(migraphx::make_op("greater"), cin1, input2);
auto ret = mm->add_instruction(migraphx::make_op("greater"), cin1, input2);
mm->add_return({ret});
mm->add_return({ret});
...
@@ -3602,7 +3602,7 @@ TEST_CASE(less_test)
...
@@ -3602,7 +3602,7 @@ TEST_CASE(less_test)
auto le = mm->add_instruction(migraphx::make_op("less"), input1, input2);
auto le = mm->add_instruction(migraphx::make_op("less"), input1, input2);
auto ret = mm->add_instruction(
auto ret = mm->add_instruction(
migraphx::make_op("convert",
migraphx::make_op("convert",
{{
"target_type"
,
migraphx
::
to_value
(
migraphx
::
shape
::
bool_type
)}}),
{{"target_type", migraphx::to_value(migraphx::shape::bool_type)}}),
le);
le);
mm->add_return({ret});
mm->add_return({ret});
...
@@ -3621,7 +3621,7 @@ TEST_CASE(less_bool_test)
...
@@ -3621,7 +3621,7 @@ TEST_CASE(less_bool_test)
auto input2 = mm->add_parameter("x2", sb);
auto input2 = mm->add_parameter("x2", sb);
auto cin1 = mm->add_instruction(
auto cin1 = mm->add_instruction(
migraphx::make_op("convert",
migraphx::make_op("convert",
{{
"target_type"
,
migraphx
::
to_value
(
migraphx
::
shape
::
bool_type
)}}),
{{"target_type", migraphx::to_value(migraphx::shape::bool_type)}}),
input1);
input1);
auto ret = mm->add_instruction(migraphx::make_op("less"), cin1, input2);
auto ret = mm->add_instruction(migraphx::make_op("less"), cin1, input2);
mm->add_return({ret});
mm->add_return({ret});
...
@@ -5463,7 +5463,7 @@ TEST_CASE(reducel1_dyn_test)
...
@@ -5463,7 +5463,7 @@ TEST_CASE(reducel1_dyn_test)
// a shape with 4 dynamic dimensions
// a shape with 4 dynamic dimensions
auto l0 = mm->add_parameter("x",
auto l0 = mm->add_parameter("x",
migraphx::shape{migraphx::shape::float_type,
migraphx::shape{migraphx::shape::float_type,
{{
3
,
3
},
{
3
,
5
},
{
4
,
6
,
{
5
}},
{
5
,
7
,
{
6
}}}});
{{3, 3}, {3, 5}, {4, 6, {5}}, {5, 7, {6}}}});
auto abs_ins = mm->add_instruction(migraphx::make_op("abs"), l0);
auto abs_ins = mm->add_instruction(migraphx::make_op("abs"), l0);
auto sum_ins =
auto sum_ins =
mm->add_instruction(migraphx::make_op("reduce_sum", {{"axes", {-2}}}), abs_ins);
mm->add_instruction(migraphx::make_op("reduce_sum", {{"axes", {-2}}}), abs_ins);
...
@@ -5483,7 +5483,7 @@ TEST_CASE(reducel1_dyn_test)
...
@@ -5483,7 +5483,7 @@ TEST_CASE(reducel1_dyn_test)
// No axes given in the onnx file. Parser should default to all axes.
// No axes given in the onnx file. Parser should default to all axes.
auto l0 = mm->add_parameter("x",
auto l0 = mm->add_parameter("x",
migraphx::shape{migraphx::shape::float_type,
migraphx::shape{migraphx::shape::float_type,
{{
3
,
3
},
{
3
,
5
},
{
4
,
6
,
{
5
}},
{
5
,
7
,
{
6
}}}});
{{3, 3}, {3, 5}, {4, 6, {5}}, {5, 7, {6}}}});
auto abs_ins = mm->add_instruction(migraphx::make_op("abs"), l0);
auto abs_ins = mm->add_instruction(migraphx::make_op("abs"), l0);
auto sum_ins =
auto sum_ins =
mm->add_instruction(migraphx::make_op("reduce_sum", {{"axes", {0, 1, 2, 3}}}), abs_ins);
mm->add_instruction(migraphx::make_op("reduce_sum", {{"axes", {0, 1, 2, 3}}}), abs_ins);
...
@@ -5719,6 +5719,38 @@ TEST_CASE(reshape_non_standard_test)
...
@@ -5719,6 +5719,38 @@ TEST_CASE(reshape_non_standard_test)
EXPECT(p == prog);
EXPECT(p == prog);
}
}
TEST_CASE(reshape_variable_input_test)
{
migraphx::program p;
auto* mm = p.get_main_module();
auto p0 = mm->add_parameter("0", migraphx::shape{migraphx::shape::float_type, {4, 2, 3}});
auto p1 = mm->add_parameter("1", migraphx::shape{migraphx::shape::int64_type, {2}});
auto alloc = mm->add_instruction(
migraphx::make_op("allocate", {{"buf_type", migraphx::shape::float_type}}), p1);
mm->add_instruction(migraphx::make_op("reshape"), p0, alloc);
auto prog = optimize_onnx("reshape_variable_input_test.onnx");
EXPECT(p == prog);
}
TEST_CASE(reshape_variable_input_dyn_test)
{
migraphx::program p;
auto* mm = p.get_main_module();
auto p0 = mm->add_parameter(
"0", migraphx::shape{migraphx::shape::float_type, {{1, 4}, {2, 2}, {3, 3}}});
auto p1 = mm->add_parameter("1", migraphx::shape{migraphx::shape::int64_type, {2}});
auto alloc = mm->add_instruction(
migraphx::make_op("allocate", {{"buf_type", migraphx::shape::float_type}}), p1);
auto reshape = mm->add_instruction(migraphx::make_op("reshape"), p0, alloc);
mm->add_return({reshape});
migraphx::onnx_options options;
options.default_dyn_dim_value = {1, 4};
auto prog = parse_onnx("reshape_variable_input_dyn_test.onnx", options);
EXPECT(p == prog);
}
TEST_CASE(resize_downsample_c_test)
TEST_CASE(resize_downsample_c_test)
{
{
migraphx::program p;
migraphx::program p;
...
@@ -7169,7 +7201,7 @@ TEST_CASE(squeeze_unsqueeze_dyn_test)
...
@@ -7169,7 +7201,7 @@ TEST_CASE(squeeze_unsqueeze_dyn_test)
std::vector<int64_t> unsqueeze_axes{0, 1, 3, 5};
std::vector<int64_t> unsqueeze_axes{0, 1, 3, 5};
auto l0 = mm->add_parameter("0",
auto l0 = mm->add_parameter("0",
migraphx::shape{migraphx::shape::float_type,
migraphx::shape{migraphx::shape::float_type,
{{
1
,
1
},
{
1
,
4
},
{
1
,
1
},
{
1
,
1
},
{
1
,
4
},
{
1
,
1
}}});
{{1, 1}, {1, 4}, {1, 1}, {1, 1}, {1, 4}, {1, 1}}});
auto c0 = mm->add_instruction(migraphx::make_op("contiguous"), l0);
auto c0 = mm->add_instruction(migraphx::make_op("contiguous"), l0);
auto l1 = mm->add_instruction(migraphx::make_op("squeeze", {{"axes", squeeze_axes}}), c0);
auto l1 = mm->add_instruction(migraphx::make_op("squeeze", {{"axes", squeeze_axes}}), c0);
auto c1 = mm->add_instruction(migraphx::make_op("contiguous"), l1);
auto c1 = mm->add_instruction(migraphx::make_op("contiguous"), l1);
...
@@ -7249,7 +7281,7 @@ TEST_CASE(sum_int_test)
...
@@ -7249,7 +7281,7 @@ TEST_CASE(sum_int_test)
auto input2 = mm->add_parameter("2", migraphx::shape{migraphx::shape::uint32_type, {3}});
auto input2 = mm->add_parameter("2", migraphx::shape{migraphx::shape::uint32_type, {3}});
auto cin0 = mm->add_instruction(
auto cin0 = mm->add_instruction(
migraphx::make_op("convert",
migraphx::make_op("convert",
{{
"target_type"
,
migraphx
::
to_value
(
migraphx
::
shape
::
uint32_type
)}}),
{{"target_type", migraphx::to_value(migraphx::shape::uint32_type)}}),
input0);
input0);
auto cin1 = mm->add_instruction(
auto cin1 = mm->add_instruction(
migraphx::make_op("convert",
migraphx::make_op("convert",
...
...
test/onnx/reshape_variable_input_dyn_test.onnx
0 → 100644
View file @
f25606f9
File added
test/onnx/reshape_variable_input_test.onnx
0 → 100644
View file @
f25606f9
reshape_variable_input_test:p
0
12"Reshapereshape_variable_input_testZ
0
Z
1
b
2
B
\ No newline at end of file
test/op_shape_test.cpp
View file @
f25606f9
...
@@ -2684,7 +2684,7 @@ TEST_CASE(reshape_broadcast_squeeze_memlayout_change)
...
@@ -2684,7 +2684,7 @@ TEST_CASE(reshape_broadcast_squeeze_memlayout_change)
expect_shape
(
output
,
migraphx
::
make_op
(
"reshape"
,
{{
"dims"
,
output
.
lens
()}}),
input
);
expect_shape
(
output
,
migraphx
::
make_op
(
"reshape"
,
{{
"dims"
,
output
.
lens
()}}),
input
);
}
}
TEST_CASE
(
reshape_dyn_
shape
)
TEST_CASE
(
reshape_dyn_
1in
)
{
{
migraphx
::
shape
input
{
migraphx
::
shape
::
float_type
,
{{
1
,
4
},
{
24
,
24
},
{
1
,
1
},
{
1
,
1
}}};
migraphx
::
shape
input
{
migraphx
::
shape
::
float_type
,
{{
1
,
4
},
{
24
,
24
},
{
1
,
1
},
{
1
,
1
}}};
for
(
auto
&&
new_shape
:
std
::
vector
<
std
::
vector
<
int64_t
>>
{
for
(
auto
&&
new_shape
:
std
::
vector
<
std
::
vector
<
int64_t
>>
{
...
@@ -2708,6 +2708,27 @@ TEST_CASE(reshape_dyn_shape)
...
@@ -2708,6 +2708,27 @@ TEST_CASE(reshape_dyn_shape)
}
}
}
}
TEST_CASE
(
reshape_dyn_2in_0
)
{
migraphx
::
shape
input
{
migraphx
::
shape
::
float_type
,
{{
1
,
4
},
{
24
,
24
},
{
1
,
1
},
{
1
,
1
}}};
migraphx
::
shape
output
{
migraphx
::
shape
::
float_type
,
{{
1
,
4
},
{
8
,
8
},
{
3
,
3
},
{
1
,
1
}}};
expect_shape
(
output
,
migraphx
::
make_op
(
"reshape"
),
input
,
output
);
}
TEST_CASE
(
reshape_dyn_2in_1
)
{
migraphx
::
shape
input
{
migraphx
::
shape
::
float_type
,
{{
1
,
4
},
{
24
,
24
},
{
1
,
1
},
{
1
,
1
}}};
migraphx
::
shape
output
{
migraphx
::
shape
::
float_type
,
{{
12
,
12
},
{
2
,
2
},
{
1
,
1
},
{
1
,
4
}}};
expect_shape
(
output
,
migraphx
::
make_op
(
"reshape"
),
input
,
output
);
}
TEST_CASE
(
reshape_dyn_2in_2
)
{
migraphx
::
shape
input
{
migraphx
::
shape
::
float_type
,
{
2
,
24
,
1
,
1
}};
migraphx
::
shape
output
{
migraphx
::
shape
::
float_type
,
{{
1
,
2
},
{
6
,
12
},
{
1
,
1
},
{
4
,
4
}}};
expect_shape
(
output
,
migraphx
::
make_op
(
"reshape"
),
input
,
output
);
}
TEST_CASE
(
reshape_multiple_non_fixed_error
)
TEST_CASE
(
reshape_multiple_non_fixed_error
)
{
{
migraphx
::
shape
input
{
migraphx
::
shape
::
float_type
,
{{
1
,
4
},
{
24
,
24
},
{
10
,
20
},
{
1
,
1
}}};
migraphx
::
shape
input
{
migraphx
::
shape
::
float_type
,
{{
1
,
4
},
{
24
,
24
},
{
10
,
20
},
{
1
,
1
}}};
...
...
test/ref/reshape.cpp
View file @
f25606f9
...
@@ -153,7 +153,7 @@ TEST_CASE(reshape_test2)
...
@@ -153,7 +153,7 @@ TEST_CASE(reshape_test2)
EXPECT
(
migraphx
::
verify
::
verify_rms_range
(
results_vector
,
gold
));
EXPECT
(
migraphx
::
verify
::
verify_rms_range
(
results_vector
,
gold
));
}
}
TEST_CASE
(
reshape_dyn_test
)
TEST_CASE
(
reshape_dyn_
1in_
test
)
{
{
migraphx
::
program
p
;
migraphx
::
program
p
;
auto
*
mm
=
p
.
get_main_module
();
auto
*
mm
=
p
.
get_main_module
();
...
@@ -173,3 +173,79 @@ TEST_CASE(reshape_dyn_test)
...
@@ -173,3 +173,79 @@ TEST_CASE(reshape_dyn_test)
result
.
visit
([
&
](
auto
output
)
{
results_vector
.
assign
(
output
.
begin
(),
output
.
end
());
});
result
.
visit
([
&
](
auto
output
)
{
results_vector
.
assign
(
output
.
begin
(),
output
.
end
());
});
EXPECT
(
migraphx
::
verify
::
verify_rms_range
(
results_vector
,
gold
));
EXPECT
(
migraphx
::
verify
::
verify_rms_range
(
results_vector
,
gold
));
}
}
TEST_CASE
(
reshape_2in_test0
)
{
migraphx
::
program
p
;
auto
*
mm
=
p
.
get_main_module
();
migraphx
::
shape
s_in
{
migraphx
::
shape
::
float_type
,
{{
1
,
4
},
{
24
,
24
},
{
1
,
1
},
{
1
,
1
}}};
migraphx
::
shape
s_out
{
migraphx
::
shape
::
float_type
,
{{
1
,
4
},
{
6
,
6
},
{
4
,
4
},
{
1
,
1
}}};
auto
input
=
mm
->
add_parameter
(
"X"
,
s_in
);
auto
output_buffer
=
mm
->
add_parameter
(
"Y"
,
s_out
);
mm
->
add_instruction
(
migraphx
::
make_op
(
"reshape"
),
input
,
output_buffer
);
p
.
compile
(
migraphx
::
make_target
(
"ref"
));
std
::
vector
<
float
>
gold
(
48
);
std
::
iota
(
gold
.
begin
(),
gold
.
end
(),
-
3.
);
std
::
vector
<
float
>
buffer
(
48
);
std
::
iota
(
buffer
.
begin
(),
buffer
.
end
(),
0.
);
migraphx
::
parameter_map
params
;
migraphx
::
shape
input_fixed_shape
{
migraphx
::
shape
::
float_type
,
{
2
,
24
,
1
,
1
}};
migraphx
::
shape
output_fixed_shape
{
migraphx
::
shape
::
float_type
,
{
2
,
6
,
4
,
1
}};
params
[
"X"
]
=
migraphx
::
argument
(
input_fixed_shape
,
gold
.
data
());
params
[
"Y"
]
=
migraphx
::
argument
(
output_fixed_shape
,
buffer
.
data
());
auto
result
=
p
.
eval
(
params
).
back
();
EXPECT
(
result
.
get_shape
()
==
output_fixed_shape
);
std
::
vector
<
float
>
results_vector
{};
result
.
visit
([
&
](
auto
output
)
{
results_vector
.
assign
(
output
.
begin
(),
output
.
end
());
});
EXPECT
(
migraphx
::
verify
::
verify_rms_range
(
results_vector
,
gold
));
}
TEST_CASE
(
reshape_2in_test1
)
{
migraphx
::
program
p
;
auto
*
mm
=
p
.
get_main_module
();
migraphx
::
shape
s_in
{
migraphx
::
shape
::
float_type
,
{
2
,
24
,
1
,
1
}};
migraphx
::
shape
s_out
{
migraphx
::
shape
::
float_type
,
{{
2
,
4
},
{
6
,
6
},
{
2
,
4
},
{
1
,
1
}}};
auto
input
=
mm
->
add_parameter
(
"X"
,
s_in
);
auto
output_buffer
=
mm
->
add_parameter
(
"Y"
,
s_out
);
mm
->
add_instruction
(
migraphx
::
make_op
(
"reshape"
),
input
,
output_buffer
);
p
.
compile
(
migraphx
::
make_target
(
"ref"
));
std
::
vector
<
float
>
gold
(
48
);
std
::
iota
(
gold
.
begin
(),
gold
.
end
(),
-
3.
);
std
::
vector
<
float
>
buffer
(
48
);
std
::
iota
(
buffer
.
begin
(),
buffer
.
end
(),
0.
);
migraphx
::
parameter_map
params
;
migraphx
::
shape
output_fixed_shape
{
migraphx
::
shape
::
float_type
,
{
2
,
6
,
4
,
1
}};
params
[
"X"
]
=
migraphx
::
argument
(
s_in
,
gold
.
data
());
params
[
"Y"
]
=
migraphx
::
argument
(
output_fixed_shape
,
buffer
.
data
());
auto
result
=
p
.
eval
(
params
).
back
();
EXPECT
(
result
.
get_shape
()
==
output_fixed_shape
);
std
::
vector
<
float
>
results_vector
{};
result
.
visit
([
&
](
auto
output
)
{
results_vector
.
assign
(
output
.
begin
(),
output
.
end
());
});
EXPECT
(
migraphx
::
verify
::
verify_rms_range
(
results_vector
,
gold
));
}
TEST_CASE
(
reshape_2in_elements_runtime_error
)
{
migraphx
::
program
p
;
auto
*
mm
=
p
.
get_main_module
();
migraphx
::
shape
s_in
{
migraphx
::
shape
::
float_type
,
{
2
,
24
,
1
,
1
}};
migraphx
::
shape
s_out
{
migraphx
::
shape
::
float_type
,
{{
2
,
4
},
{
6
,
6
},
{
2
,
4
},
{
1
,
1
}}};
auto
input
=
mm
->
add_parameter
(
"X"
,
s_in
);
auto
output_buffer
=
mm
->
add_parameter
(
"Y"
,
s_out
);
mm
->
add_instruction
(
migraphx
::
make_op
(
"reshape"
),
input
,
output_buffer
);
p
.
compile
(
migraphx
::
make_target
(
"ref"
));
std
::
vector
<
float
>
gold
(
48
);
std
::
iota
(
gold
.
begin
(),
gold
.
end
(),
-
3.
);
std
::
vector
<
float
>
buffer
(
48
);
std
::
iota
(
buffer
.
begin
(),
buffer
.
end
(),
0.
);
migraphx
::
parameter_map
params
;
// elements do not match up
migraphx
::
shape
output_fixed_shape
{
migraphx
::
shape
::
float_type
,
{
2
,
6
,
2
,
1
}};
params
[
"X"
]
=
migraphx
::
argument
(
s_in
,
gold
.
data
());
params
[
"Y"
]
=
migraphx
::
argument
(
output_fixed_shape
,
buffer
.
data
());
EXPECT
(
test
::
throws
([
&
]
{
std
::
ignore
=
p
.
eval
(
params
).
back
();
}));
}
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment