Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
MIGraphX
Commits
b119ed8f
Commit
b119ed8f
authored
Aug 24, 2023
by
Alan Turner
Browse files
Merge branch 'develop' of
https://github.com/ROCmSoftwarePlatform/AMDMIGraphX
into develop
parents
26d1a969
6f1c947f
Changes
56
Hide whitespace changes
Inline
Side-by-side
Showing
16 changed files
with
403 additions
and
107 deletions
+403
-107
test/onnx/slice_var_input_dyn1.onnx
test/onnx/slice_var_input_dyn1.onnx
+0
-0
test/onnx/slice_var_input_static0.onnx
test/onnx/slice_var_input_static0.onnx
+0
-0
test/onnx/slice_var_input_static1.onnx
test/onnx/slice_var_input_static1.onnx
+26
-0
test/onnx/slice_var_input_steps_error.onnx
test/onnx/slice_var_input_steps_error.onnx
+29
-0
test/onnx/verify_onnx.cpp
test/onnx/verify_onnx.cpp
+0
-1
test/op_shape_test.cpp
test/op_shape_test.cpp
+114
-38
test/pad_calc_test.cpp
test/pad_calc_test.cpp
+0
-1
test/quantization.cpp
test/quantization.cpp
+0
-1
test/ref_ops_test.cpp
test/ref_ops_test.cpp
+109
-0
test/simplify_algebra_test.cpp
test/simplify_algebra_test.cpp
+35
-33
test/simplify_reshapes_test.cpp
test/simplify_reshapes_test.cpp
+6
-7
test/tf/tf_test.cpp
test/tf/tf_test.cpp
+10
-21
test/verify/gemm_literal.cpp
test/verify/gemm_literal.cpp
+2
-2
tools/build_and_test_onnxrt.sh
tools/build_and_test_onnxrt.sh
+1
-1
tools/docker/sles.docker
tools/docker/sles.docker
+46
-0
tools/install_prereqs.sh
tools/install_prereqs.sh
+25
-2
No files found.
test/onnx/slice_var_input_dyn1.onnx
0 → 100644
View file @
b119ed8f
File added
test/onnx/slice_var_input_static0.onnx
0 → 100644
View file @
b119ed8f
File added
test/onnx/slice_var_input_static1.onnx
0 → 100644
View file @
b119ed8f
slice_var_input_static1:
)
data
starts
ends
axesoutput"Sliceslice_var_input_static1Z
data
Z
starts
Z
ends
Z
axes
b
output
B
\ No newline at end of file
test/onnx/slice_var_input_steps_error.onnx
0 → 100644
View file @
b119ed8f
slice_var_input_steps_error:
0arg_step"Constant*
value**Bstep
3
data
starts
ends
axes
arg_stepoutput"Sliceslice_var_input_steps_errorZ
data
Z
starts
Z
ends
Z
axes
b
output
B
\ No newline at end of file
test/onnx/verify_onnx.cpp
View file @
b119ed8f
...
@@ -24,7 +24,6 @@
...
@@ -24,7 +24,6 @@
#include <iostream>
#include <iostream>
#include <vector>
#include <vector>
#include <migraphx/literal.hpp>
#include <migraphx/literal.hpp>
#include <migraphx/operators.hpp>
#include <migraphx/program.hpp>
#include <migraphx/program.hpp>
#include <migraphx/register_target.hpp>
#include <migraphx/register_target.hpp>
#include <migraphx/pass_manager.hpp>
#include <migraphx/pass_manager.hpp>
...
...
test/op_shape_test.cpp
View file @
b119ed8f
...
@@ -24,7 +24,8 @@
...
@@ -24,7 +24,8 @@
#include <migraphx/program.hpp>
#include <migraphx/program.hpp>
#include <migraphx/iterator_for.hpp>
#include <migraphx/iterator_for.hpp>
#include <migraphx/instruction.hpp>
#include <migraphx/instruction.hpp>
#include <migraphx/operators.hpp>
#include <migraphx/permutation.hpp>
#include <migraphx/op/common.hpp>
#include <sstream>
#include <sstream>
#include <migraphx/make_op.hpp>
#include <migraphx/make_op.hpp>
...
@@ -156,13 +157,13 @@ TEST_CASE(broadcast)
...
@@ -156,13 +157,13 @@ TEST_CASE(broadcast)
{
{
std
::
vector
<
std
::
size_t
>
lens
{
1
,
1
};
std
::
vector
<
std
::
size_t
>
lens
{
1
,
1
};
migraphx
::
shape
input
{
migraphx
::
shape
::
float_type
,
{
2
}};
migraphx
::
shape
input
{
migraphx
::
shape
::
float_type
,
{
2
}};
throws_shape
(
migraphx
::
op
::
broadcast
{
1
,
lens
},
input
);
throws_shape
(
migraphx
::
make_op
(
"broadcast"
,
{{
"axis"
,
1
},
{
"out_lens"
,
lens
}
})
,
input
);
}
}
{
{
std
::
vector
<
std
::
size_t
>
lens
{
2
,
2
};
std
::
vector
<
std
::
size_t
>
lens
{
2
,
2
};
migraphx
::
shape
input
{
migraphx
::
shape
::
float_type
,
{
1
,
2
}};
migraphx
::
shape
input
{
migraphx
::
shape
::
float_type
,
{
1
,
2
}};
throws_shape
(
migraphx
::
op
::
broadcast
{
1
,
lens
},
input
);
throws_shape
(
migraphx
::
make_op
(
"broadcast"
,
{{
"axis"
,
1
},
{
"out_lens"
,
lens
}
})
,
input
);
}
}
{
{
...
@@ -1252,36 +1253,45 @@ TEST_CASE(inconsistent_attr_shape)
...
@@ -1252,36 +1253,45 @@ TEST_CASE(inconsistent_attr_shape)
input
);
input
);
}
}
template
<
class
T
>
void
test_softmax_variations
(
const
std
::
string
&
name
)
void
test_softmax_variations
()
{
{
{
{
migraphx
::
shape
input
{
migraphx
::
shape
::
float_type
,
{
2
,
3
,
4
,
5
}};
migraphx
::
shape
input
{
migraphx
::
shape
::
float_type
,
{
2
,
3
,
4
,
5
}};
expect_shape
(
migraphx
::
shape
{
migraphx
::
shape
::
float_type
,
{
2
,
3
,
4
,
5
}},
T
{
0
},
input
);
expect_shape
(
migraphx
::
shape
{
migraphx
::
shape
::
float_type
,
{
2
,
3
,
4
,
5
}},
migraphx
::
make_op
(
name
,
{{
"axis"
,
0
}}),
input
);
}
}
{
{
migraphx
::
shape
input
{
migraphx
::
shape
::
float_type
,
{
2
,
3
,
4
,
5
}};
migraphx
::
shape
input
{
migraphx
::
shape
::
float_type
,
{
2
,
3
,
4
,
5
}};
expect_shape
(
migraphx
::
shape
{
migraphx
::
shape
::
float_type
,
{
2
,
3
,
4
,
5
}},
T
{
1
},
input
);
expect_shape
(
migraphx
::
shape
{
migraphx
::
shape
::
float_type
,
{
2
,
3
,
4
,
5
}},
migraphx
::
make_op
(
name
,
{{
"axis"
,
1
}}),
input
);
}
}
{
{
migraphx
::
shape
input
{
migraphx
::
shape
::
float_type
,
{
2
,
3
,
4
,
5
}};
migraphx
::
shape
input
{
migraphx
::
shape
::
float_type
,
{
2
,
3
,
4
,
5
}};
expect_shape
(
migraphx
::
shape
{
migraphx
::
shape
::
float_type
,
{
2
,
3
,
4
,
5
}},
T
{
2
},
input
);
expect_shape
(
migraphx
::
shape
{
migraphx
::
shape
::
float_type
,
{
2
,
3
,
4
,
5
}},
migraphx
::
make_op
(
name
,
{{
"axis"
,
2
}}),
input
);
}
}
{
{
migraphx
::
shape
input
{
migraphx
::
shape
::
float_type
,
{
2
,
3
,
4
,
5
}};
migraphx
::
shape
input
{
migraphx
::
shape
::
float_type
,
{
2
,
3
,
4
,
5
}};
expect_shape
(
migraphx
::
shape
{
migraphx
::
shape
::
float_type
,
{
2
,
3
,
4
,
5
}},
T
{
3
},
input
);
expect_shape
(
migraphx
::
shape
{
migraphx
::
shape
::
float_type
,
{
2
,
3
,
4
,
5
}},
migraphx
::
make_op
(
name
,
{{
"axis"
,
3
}}),
input
);
}
}
{
{
migraphx
::
shape
input
{
migraphx
::
shape
::
float_type
,
{
2
,
3
,
4
,
5
}};
migraphx
::
shape
input
{
migraphx
::
shape
::
float_type
,
{
2
,
3
,
4
,
5
}};
int
axis
=
4
;
int
axis
=
4
;
throws_shape
(
T
{
axis
},
input
);
throws_shape
(
migraphx
::
make_op
(
name
,
{{
"axis"
,
axis
}
})
,
input
);
}
}
}
}
TEST_CASE
(
logsoftmax
)
{
test_softmax_variations
<
migraphx
::
op
::
logsoftmax
>
();
}
TEST_CASE
(
logsoftmax
)
{
test_softmax_variations
(
"logsoftmax"
);
}
TEST_CASE
(
softmax
)
{
test_softmax_variations
(
"softmax"
);
}
TEST_CASE
(
lstm
)
TEST_CASE
(
lstm
)
{
{
...
@@ -2328,47 +2338,54 @@ TEST_CASE(dqlinear_mismatch_type)
...
@@ -2328,47 +2338,54 @@ TEST_CASE(dqlinear_mismatch_type)
throws_shape
(
migraphx
::
make_op
(
"dequantizelinear"
),
input
,
scales
,
zeros
);
throws_shape
(
migraphx
::
make_op
(
"dequantizelinear"
),
input
,
scales
,
zeros
);
}
}
template
<
class
T
>
void
test_reduce_ops
(
const
std
::
string
&
name
)
void
test_reduce_ops
()
{
{
{
{
migraphx
::
shape
input
{
migraphx
::
shape
::
float_type
,
{
2
,
3
,
4
,
5
}};
migraphx
::
shape
input
{
migraphx
::
shape
::
float_type
,
{
2
,
3
,
4
,
5
}};
expect_shape
(
migraphx
::
shape
{
migraphx
::
shape
::
float_type
,
{
1
,
1
,
1
,
1
}},
T
{},
input
);
expect_shape
(
migraphx
::
shape
{
migraphx
::
shape
::
float_type
,
{
1
,
1
,
1
,
1
}},
migraphx
::
make_op
(
name
),
input
);
}
}
{
{
migraphx
::
shape
input
{
migraphx
::
shape
::
float_type
,
{
2
,
3
,
4
,
5
}};
migraphx
::
shape
input
{
migraphx
::
shape
::
float_type
,
{
2
,
3
,
4
,
5
}};
expect_shape
(
expect_shape
(
migraphx
::
shape
{
migraphx
::
shape
::
float_type
,
{
1
,
1
,
1
,
1
}},
migraphx
::
shape
{
migraphx
::
shape
::
float_type
,
{
1
,
1
,
1
,
1
}},
T
{{
0
,
1
,
2
,
3
}},
input
);
migraphx
::
make_op
(
name
,
{{
"axes"
,
{
0
,
1
,
2
,
3
}}}),
input
);
}
}
{
{
migraphx
::
shape
input
{
migraphx
::
shape
::
float_type
,
{
2
,
3
,
4
,
5
}};
migraphx
::
shape
input
{
migraphx
::
shape
::
float_type
,
{
2
,
3
,
4
,
5
}};
expect_shape
(
migraphx
::
shape
{
migraphx
::
shape
::
float_type
,
{
2
,
3
,
1
,
1
}},
T
{{
2
,
3
}},
input
);
expect_shape
(
migraphx
::
shape
{
migraphx
::
shape
::
float_type
,
{
2
,
3
,
1
,
1
}},
migraphx
::
make_op
(
name
,
{{
"axes"
,
{
2
,
3
}}}),
input
);
}
}
{
{
migraphx
::
shape
input
{
migraphx
::
shape
::
float_type
,
{
2
,
3
,
4
,
5
}};
migraphx
::
shape
input
{
migraphx
::
shape
::
float_type
,
{
2
,
3
,
4
,
5
}};
expect_shape
(
migraphx
::
shape
{
migraphx
::
shape
::
float_type
,
{
1
,
3
,
4
,
5
}},
T
{{
0
}},
input
);
expect_shape
(
migraphx
::
shape
{
migraphx
::
shape
::
float_type
,
{
1
,
3
,
4
,
5
}},
migraphx
::
make_op
(
name
,
{{
"axes"
,
{
0
}}}),
input
);
}
}
{
{
migraphx
::
shape
input
{
migraphx
::
shape
::
float_type
,
{
2
,
3
,
4
,
5
}};
migraphx
::
shape
input
{
migraphx
::
shape
::
float_type
,
{
2
,
3
,
4
,
5
}};
expect_shape
(
migraphx
::
shape
{
migraphx
::
shape
::
float_type
,
{
2
,
3
,
4
,
1
}},
T
{{
-
1
}},
input
);
expect_shape
(
migraphx
::
shape
{
migraphx
::
shape
::
float_type
,
{
2
,
3
,
4
,
1
}},
migraphx
::
make_op
(
name
,
{{
"axes"
,
{
-
1
}}}),
input
);
}
}
{
{
migraphx
::
shape
input
{
migraphx
::
shape
::
float_type
,
{
2
,
3
,
4
,
5
}};
migraphx
::
shape
input
{
migraphx
::
shape
::
float_type
,
{
2
,
3
,
4
,
5
}};
throws_shape
(
T
{
{
4
}},
input
);
throws_shape
(
migraphx
::
make_op
(
name
,
{{
"axes"
,
{
4
}}
})
,
input
);
}
}
}
}
// dynamic shape
// dynamic shape
template
<
class
T
>
void
test_dyn_reduce_ops
(
const
std
::
string
&
name
)
void
test_dyn_reduce_ops
()
{
{
{
{
migraphx
::
shape
input
{
migraphx
::
shape
::
float_type
,
{{
2
,
3
,
{
3
}},
{
2
,
4
,
{
4
}}}};
migraphx
::
shape
input
{
migraphx
::
shape
::
float_type
,
{{
2
,
3
,
{
3
}},
{
2
,
4
,
{
4
}}}};
expect_shape
(
expect_shape
(
migraphx
::
shape
{
migraphx
::
shape
::
float_type
,
migraphx
::
shape
{
migraphx
::
shape
::
float_type
,
std
::
vector
<
migraphx
::
shape
::
dynamic_dimension
>
({{
2
,
3
,
{
3
}},
{
1
,
1
}})},
std
::
vector
<
migraphx
::
shape
::
dynamic_dimension
>
({{
2
,
3
,
{
3
}},
{
1
,
1
}})},
T
{
{
-
1
}},
migraphx
::
make_op
(
name
,
{{
"axes"
,
{
-
1
}}
})
,
input
);
input
);
}
}
{
{
...
@@ -2376,7 +2393,7 @@ void test_dyn_reduce_ops()
...
@@ -2376,7 +2393,7 @@ void test_dyn_reduce_ops()
expect_shape
(
expect_shape
(
migraphx
::
shape
{
migraphx
::
shape
::
float_type
,
migraphx
::
shape
{
migraphx
::
shape
::
float_type
,
std
::
vector
<
migraphx
::
shape
::
dynamic_dimension
>
({{
1
,
1
},
{
2
,
4
,
{
4
}}})},
std
::
vector
<
migraphx
::
shape
::
dynamic_dimension
>
({{
1
,
1
},
{
2
,
4
,
{
4
}}})},
T
{
{
0
}},
migraphx
::
make_op
(
name
,
{{
"axes"
,
{
0
}}
})
,
input
);
input
);
}
}
{
{
...
@@ -2385,24 +2402,24 @@ void test_dyn_reduce_ops()
...
@@ -2385,24 +2402,24 @@ void test_dyn_reduce_ops()
expect_shape
(
expect_shape
(
migraphx
::
shape
{
migraphx
::
shape
::
float_type
,
migraphx
::
shape
{
migraphx
::
shape
::
float_type
,
std
::
vector
<
migraphx
::
shape
::
dynamic_dimension
>
({{
1
,
1
},
{
1
,
1
}})},
std
::
vector
<
migraphx
::
shape
::
dynamic_dimension
>
({{
1
,
1
},
{
1
,
1
}})},
T
{{}}
,
migraphx
::
make_op
(
name
)
,
input
);
input
);
}
}
{
{
migraphx
::
shape
input
{
migraphx
::
shape
::
float_type
,
{{
2
,
3
,
{
3
}},
{
2
,
4
,
{
4
}}}};
migraphx
::
shape
input
{
migraphx
::
shape
::
float_type
,
{{
2
,
3
,
{
3
}},
{
2
,
4
,
{
4
}}}};
throws_shape
(
T
{
{
4
}},
input
);
throws_shape
(
migraphx
::
make_op
(
name
,
{{
"axes"
,
{
4
}}
})
,
input
);
}
}
}
}
TEST_CASE
(
reduce_max
)
{
test_reduce_ops
<
migraphx
::
op
::
reduce_max
>
(
);
}
TEST_CASE
(
reduce_max
)
{
test_reduce_ops
(
"
reduce_max
"
);
}
TEST_CASE
(
reduce_mean
)
{
test_reduce_ops
<
migraphx
::
op
::
reduce_mean
>
(
);
}
TEST_CASE
(
reduce_mean
)
{
test_reduce_ops
(
"
reduce_mean
"
);
}
TEST_CASE
(
reduce_prod
)
{
test_reduce_ops
<
migraphx
::
op
::
reduce_prod
>
(
);
}
TEST_CASE
(
reduce_prod
)
{
test_reduce_ops
(
"
reduce_prod
"
);
}
TEST_CASE
(
reduce_sum
)
{
test_reduce_ops
<
migraphx
::
op
::
reduce_sum
>
(
);
}
TEST_CASE
(
reduce_sum
)
{
test_reduce_ops
(
"
reduce_sum
"
);
}
TEST_CASE
(
reduce_max_dyn
)
{
test_dyn_reduce_ops
<
migraphx
::
op
::
reduce_max
>
(
);
}
TEST_CASE
(
reduce_max_dyn
)
{
test_dyn_reduce_ops
(
"
reduce_max
"
);
}
TEST_CASE
(
reduce_mean_dyn
)
{
test_dyn_reduce_ops
<
migraphx
::
op
::
reduce_mean
>
(
);
}
TEST_CASE
(
reduce_mean_dyn
)
{
test_dyn_reduce_ops
(
"
reduce_mean
"
);
}
TEST_CASE
(
reduce_prod_dyn
)
{
test_dyn_reduce_ops
<
migraphx
::
op
::
reduce_prod
>
(
);
}
TEST_CASE
(
reduce_prod_dyn
)
{
test_dyn_reduce_ops
(
"
reduce_prod
"
);
}
TEST_CASE
(
reduce_sum_dyn
)
{
test_dyn_reduce_ops
<
migraphx
::
op
::
reduce_sum
>
(
);
}
TEST_CASE
(
reduce_sum_dyn
)
{
test_dyn_reduce_ops
(
"
reduce_sum
"
);
}
TEST_CASE
(
reshape_shape
)
TEST_CASE
(
reshape_shape
)
{
{
...
@@ -2822,7 +2839,7 @@ TEST_CASE(select_module_dyn)
...
@@ -2822,7 +2839,7 @@ TEST_CASE(select_module_dyn)
input
);
input
);
}
}
TEST_CASE
(
slice_shape
)
TEST_CASE
(
slice_
static_
shape
)
{
{
migraphx
::
shape
input
{
migraphx
::
shape
::
int32_type
,
{
2
,
2
,
3
}};
migraphx
::
shape
input
{
migraphx
::
shape
::
int32_type
,
{
2
,
2
,
3
}};
expect_shape
(
migraphx
::
shape
{
migraphx
::
shape
::
int32_type
,
{
2
,
2
,
2
},
{
6
,
3
,
1
}},
expect_shape
(
migraphx
::
shape
{
migraphx
::
shape
::
int32_type
,
{
2
,
2
,
2
},
{
6
,
3
,
1
}},
...
@@ -2840,6 +2857,67 @@ TEST_CASE(slice_shape)
...
@@ -2840,6 +2857,67 @@ TEST_CASE(slice_shape)
input
);
input
);
}
}
TEST_CASE
(
slice_var_inputs_static_shape0
)
{
migraphx
::
shape
input
{
migraphx
::
shape
::
float_type
,
{
3
,
4
,
4
}};
migraphx
::
shape
starts
{
migraphx
::
shape
::
int64_type
,
{
2
}};
migraphx
::
shape
ends
{
migraphx
::
shape
::
int64_type
,
{
2
}};
expect_shape
(
migraphx
::
shape
{
migraphx
::
shape
::
float_type
,
{{
3
,
3
},
{
0
,
4
},
{
0
,
4
}}},
migraphx
::
make_op
(
"slice"
,
{{
"axes"
,
{
1
,
2
}}}),
input
,
starts
,
ends
);
}
TEST_CASE
(
slice_var_inputs_static_shape1
)
{
migraphx
::
shape
input
{
migraphx
::
shape
::
float_type
,
{
3
,
4
,
4
}};
migraphx
::
shape
starts
{
migraphx
::
shape
::
int64_type
,
{
2
}};
migraphx
::
shape
ends
{
migraphx
::
shape
::
int64_type
,
{
2
}};
migraphx
::
shape
axes
{
migraphx
::
shape
::
int64_type
,
{
2
}};
expect_shape
(
migraphx
::
shape
{
migraphx
::
shape
::
float_type
,
{{
0
,
3
},
{
0
,
4
},
{
0
,
4
}}},
migraphx
::
make_op
(
"slice"
),
input
,
starts
,
ends
,
axes
);
}
TEST_CASE
(
slice_var_inputs_static_error0
)
{
migraphx
::
shape
input
{
migraphx
::
shape
::
float_type
,
{
3
,
4
,
4
}};
migraphx
::
shape
starts
{
migraphx
::
shape
::
int64_type
,
{
2
}};
migraphx
::
shape
ends
{
migraphx
::
shape
::
int64_type
,
{
2
}};
migraphx
::
shape
axes
{
migraphx
::
shape
::
int64_type
,
{
3
}};
throws_shape
(
migraphx
::
make_op
(
"slice"
),
input
,
starts
,
ends
,
axes
);
}
TEST_CASE
(
slice_var_inputs_dyn_shape0
)
{
migraphx
::
shape
input
{
migraphx
::
shape
::
float_type
,
{{
3
,
6
},
{
2
,
4
,
{
2
,
4
}},
{
2
,
4
,
{
2
,
4
}}}};
migraphx
::
shape
starts
{
migraphx
::
shape
::
int64_type
,
{
2
}};
migraphx
::
shape
ends
{
migraphx
::
shape
::
int64_type
,
{
2
}};
expect_shape
(
migraphx
::
shape
{
migraphx
::
shape
::
float_type
,
{{
3
,
6
},
{
0
,
4
},
{
0
,
4
}}},
migraphx
::
make_op
(
"slice"
,
{{
"axes"
,
{
1
,
2
}}}),
input
,
starts
,
ends
);
}
TEST_CASE
(
slice_var_inputs_dyn_shape1
)
{
migraphx
::
shape
input
{
migraphx
::
shape
::
float_type
,
{{
3
,
6
},
{
2
,
4
,
{
2
,
4
}},
{
2
,
4
,
{
2
,
4
}}}};
migraphx
::
shape
starts
{
migraphx
::
shape
::
int64_type
,
{
2
}};
migraphx
::
shape
ends
{
migraphx
::
shape
::
int64_type
,
{
2
}};
migraphx
::
shape
axes
{
migraphx
::
shape
::
int64_type
,
{
2
}};
expect_shape
(
migraphx
::
shape
{
migraphx
::
shape
::
float_type
,
{{
0
,
6
},
{
0
,
4
},
{
0
,
4
}}},
migraphx
::
make_op
(
"slice"
),
input
,
starts
,
ends
,
axes
);
}
TEST_CASE
(
slice_dyn_shape0
)
TEST_CASE
(
slice_dyn_shape0
)
{
{
migraphx
::
shape
input
{
migraphx
::
shape
::
int32_type
,
{{
2
,
3
},
{
7
,
7
},
{
2
,
3
}}};
migraphx
::
shape
input
{
migraphx
::
shape
::
int32_type
,
{{
2
,
3
},
{
7
,
7
},
{
2
,
3
}}};
...
@@ -2870,7 +2948,7 @@ TEST_CASE(slice_dyn_shape2)
...
@@ -2870,7 +2948,7 @@ TEST_CASE(slice_dyn_shape2)
TEST_CASE
(
slice_dyn_shape3
)
TEST_CASE
(
slice_dyn_shape3
)
{
{
// TODO: When
variable
dimension slicing is allowed, Slice to a size smaller than min.
// TODO: When
non-fixed
dimension slicing is allowed, Slice to a size smaller than min.
// Until then, this action is an error.
// Until then, this action is an error.
migraphx
::
shape
input
{
migraphx
::
shape
::
int32_type
,
{{
2
,
3
},
{
7
,
8
},
{
2
,
3
}}};
migraphx
::
shape
input
{
migraphx
::
shape
::
int32_type
,
{{
2
,
3
},
{
7
,
8
},
{
2
,
3
}}};
throws_shape
(
migraphx
::
make_op
(
"slice"
,
{{
"axes"
,
{
1
}},
{
"starts"
,
{
0
}},
{
"ends"
,
{
1
}}}),
throws_shape
(
migraphx
::
make_op
(
"slice"
,
{{
"axes"
,
{
1
}},
{
"starts"
,
{
0
}},
{
"ends"
,
{
1
}}}),
...
@@ -2901,8 +2979,6 @@ TEST_CASE(slice_dyn_shape5)
...
@@ -2901,8 +2979,6 @@ TEST_CASE(slice_dyn_shape5)
input
);
input
);
}
}
TEST_CASE
(
softmax
)
{
test_softmax_variations
<
migraphx
::
op
::
softmax
>
();
}
TEST_CASE
(
softmax_dyn0
)
TEST_CASE
(
softmax_dyn0
)
{
{
migraphx
::
shape
input
{
migraphx
::
shape
::
float_type
,
{{
1
,
4
},
{
3
,
3
},
{
4
,
4
},
{
5
,
5
}}};
migraphx
::
shape
input
{
migraphx
::
shape
::
float_type
,
{{
1
,
4
},
{
3
,
3
},
{
4
,
4
},
{
5
,
5
}}};
...
...
test/pad_calc_test.cpp
View file @
b119ed8f
...
@@ -22,7 +22,6 @@
...
@@ -22,7 +22,6 @@
* THE SOFTWARE.
* THE SOFTWARE.
*/
*/
#include <migraphx/program.hpp>
#include <migraphx/program.hpp>
#include <migraphx/operators.hpp>
#include <migraphx/pad_calc.hpp>
#include <migraphx/pad_calc.hpp>
#include "test.hpp"
#include "test.hpp"
...
...
test/quantization.cpp
View file @
b119ed8f
...
@@ -24,7 +24,6 @@
...
@@ -24,7 +24,6 @@
#include <iostream>
#include <iostream>
#include <vector>
#include <vector>
#include <migraphx/literal.hpp>
#include <migraphx/literal.hpp>
#include <migraphx/operators.hpp>
#include <migraphx/instruction.hpp>
#include <migraphx/instruction.hpp>
#include <migraphx/generate.hpp>
#include <migraphx/generate.hpp>
#include <migraphx/register_target.hpp>
#include <migraphx/register_target.hpp>
...
...
test/ref_ops_test.cpp
View file @
b119ed8f
...
@@ -8153,6 +8153,115 @@ TEST_CASE(slice_test)
...
@@ -8153,6 +8153,115 @@ TEST_CASE(slice_test)
}
}
}
}
TEST_CASE(slice_var_inputs_static0)
{
migraphx::program p;
auto* mm = p.get_main_module();
std::vector<int32_t> data(2 * 2 * 3);
std::iota(data.begin(), data.end(), 0);
migraphx::shape s0{migraphx::shape::int32_type, {2, 2, 3}};
auto l0 = mm->add_literal(migraphx::literal{s0, data});
migraphx::shape s1{migraphx::shape::int32_type, {1}};
auto starts = mm->add_parameter("starts", s1);
auto ends = mm->add_parameter("ends", s1);
mm->add_instruction(migraphx::make_op("slice", {{"axes", {2}}}), l0, starts, ends);
p.compile(migraphx::make_target("ref"));
migraphx::parameter_map params;
std::vector<int32_t> start_data = {1};
std::vector<int32_t> end_data = {3};
params["starts"] = migraphx::argument(s1, start_data.data());
params["ends"] = migraphx::argument(s1, end_data.data());
auto result = p.eval(params).back();
std::vector<int32_t> gold = {1, 2, 4, 5, 7, 8, 10, 11};
std::vector<int32_t> results_vector(2 * 2 * 2);
result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); });
EXPECT(migraphx::verify::verify_range(results_vector, gold));
}
TEST_CASE(slice_var_inputs_static1)
{
migraphx::program p;
auto* mm = p.get_main_module();
std::vector<int32_t> data(2 * 2 * 3);
std::iota(data.begin(), data.end(), 0);
migraphx::shape s0{migraphx::shape::int32_type, {2, 2, 3}};
auto l0 = mm->add_literal(migraphx::literal{s0, data});
migraphx::shape s1{migraphx::shape::int32_type, {1}};
auto starts = mm->add_parameter("starts", s1);
auto ends = mm->add_parameter("ends", s1);
mm->add_instruction(migraphx::make_op("slice", {{"axes", {2}}}), l0, starts, ends);
p.compile(migraphx::make_target("ref"));
migraphx::parameter_map params;
std::vector<int32_t> start_data = {-2};
std::vector<int32_t> end_data = {2831};
params["starts"] = migraphx::argument(s1, start_data.data());
params["ends"] = migraphx::argument(s1, end_data.data());
auto result = p.eval(params).back();
std::vector<int32_t> gold = {1, 2, 4, 5, 7, 8, 10, 11};
std::vector<int32_t> results_vector(2 * 2 * 2);
result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); });
EXPECT(migraphx::verify::verify_range(results_vector, gold));
}
TEST_CASE(slice_var_inputs_static2)
{
migraphx::program p;
auto* mm = p.get_main_module();
std::vector<float> data(2 * 2 * 3);
std::iota(data.begin(), data.end(), 0);
migraphx::shape s0{migraphx::shape::float_type, {2, 2, 3}};
auto l0 = mm->add_literal(migraphx::literal{s0, data});
migraphx::shape s1{migraphx::shape::int64_type, {3}};
auto starts = mm->add_parameter("starts", s1);
auto ends = mm->add_parameter("ends", s1);
auto axes = mm->add_parameter("axes", s1);
mm->add_instruction(migraphx::make_op("slice"), l0, starts, ends, axes);
p.compile(migraphx::make_target("ref"));
migraphx::parameter_map params;
std::vector<int64_t> start_data = {0, 0, 0};
std::vector<int64_t> end_data = {2, 2, 2};
std::vector<int64_t> axes_data = {0, 1, 2};
params["starts"] = migraphx::argument(s1, start_data.data());
params["ends"] = migraphx::argument(s1, end_data.data());
params["axes"] = migraphx::argument(s1, axes_data.data());
auto result = p.eval(params).back();
std::vector<float> gold = {0, 1, 3, 4, 6, 7, 9, 10};
std::vector<float> results_vector(2 * 2 * 2);
result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); });
EXPECT(migraphx::verify::verify_range(results_vector, gold));
}
TEST_CASE(slice_var_inputs_dyn)
{
migraphx::program p;
auto* mm = p.get_main_module();
migraphx::shape s0{migraphx::shape::int32_type, {{2, 4, {2, 4}}, {2, 4, {2, 4}}, {3, 8}}};
auto input = mm->add_parameter("input", s0);
migraphx::shape s1{migraphx::shape::int32_type, {1}};
auto starts = mm->add_parameter("starts", s1);
auto ends = mm->add_parameter("ends", s1);
mm->add_instruction(migraphx::make_op("slice", {{"axes", {2}}}), input, starts, ends);
p.compile(migraphx::make_target("ref"));
migraphx::parameter_map params;
migraphx::shape s2{migraphx::shape::int32_type, {2, 2, 3}};
std::vector<int> input_data(2 * 2 * 3);
std::iota(input_data.begin(), input_data.end(), 0);
std::vector<int> start_data = {1};
std::vector<int> end_data = {3};
params["input"] = migraphx::argument(s2, input_data.data());
params["starts"] = migraphx::argument(s1, start_data.data());
params["ends"] = migraphx::argument(s1, end_data.data());
auto result = p.eval(params).back();
std::vector<int> gold = {1, 2, 4, 5, 7, 8, 10, 11};
std::vector<int> results_vector(2 * 2 * 2);
result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); });
EXPECT(migraphx::verify::verify_range(results_vector, gold));
}
TEST_CASE(slice_dyn_test0)
TEST_CASE(slice_dyn_test0)
{
{
// Slice a single dynamic dimension. ax1 slice limits are smaller than min; ax2 "ends" is
// Slice a single dynamic dimension. ax1 slice limits are smaller than min; ax2 "ends" is
...
...
test/simplify_algebra_test.cpp
View file @
b119ed8f
...
@@ -24,7 +24,7 @@
...
@@ -24,7 +24,7 @@
#include <migraphx/simplify_algebra.hpp>
#include <migraphx/simplify_algebra.hpp>
#include <migraphx/dead_code_elimination.hpp>
#include <migraphx/dead_code_elimination.hpp>
#include <migraphx/pass_manager.hpp>
#include <migraphx/pass_manager.hpp>
#include <migraphx/
o
per
ators
.hpp>
#include <migraphx/per
mutation
.hpp>
#include <migraphx/generate.hpp>
#include <migraphx/generate.hpp>
#include <migraphx/ranges.hpp>
#include <migraphx/ranges.hpp>
#include <migraphx/instruction.hpp>
#include <migraphx/instruction.hpp>
...
@@ -153,7 +153,7 @@ TEST_CASE(simplify_add_broadcast1)
...
@@ -153,7 +153,7 @@ TEST_CASE(simplify_add_broadcast1)
{
{
migraphx
::
shape
inner
{
migraphx
::
shape
::
int32_type
,
{
2
}};
migraphx
::
shape
inner
{
migraphx
::
shape
::
int32_type
,
{
2
}};
migraphx
::
shape
outer
{
migraphx
::
shape
::
int32_type
,
{
1
,
2
,
3
,
3
}};
migraphx
::
shape
outer
{
migraphx
::
shape
::
int32_type
,
{
1
,
2
,
3
,
3
}};
migraphx
::
op
::
broadcast
b
{
1
,
{
1
,
2
,
3
,
3
}};
auto
b
=
migraphx
::
make_op
(
"broadcast"
,
{{
"axis"
,
1
},
{
"out_lens"
,
{
1
,
2
,
3
,
3
}}
})
;
migraphx
::
module
m1
;
migraphx
::
module
m1
;
{
{
auto
x
=
m1
.
add_parameter
(
"x"
,
outer
);
auto
x
=
m1
.
add_parameter
(
"x"
,
outer
);
...
@@ -188,7 +188,7 @@ TEST_CASE(simplify_add_broadcast2)
...
@@ -188,7 +188,7 @@ TEST_CASE(simplify_add_broadcast2)
{
{
migraphx
::
shape
inner
{
migraphx
::
shape
::
int32_type
,
{
2
}};
migraphx
::
shape
inner
{
migraphx
::
shape
::
int32_type
,
{
2
}};
migraphx
::
shape
outer
{
migraphx
::
shape
::
int32_type
,
{
1
,
2
,
3
,
3
}};
migraphx
::
shape
outer
{
migraphx
::
shape
::
int32_type
,
{
1
,
2
,
3
,
3
}};
migraphx
::
op
::
broadcast
b
{
1
,
{
1
,
2
,
3
,
3
}};
auto
b
=
migraphx
::
make_op
(
"broadcast"
,
{{
"axis"
,
1
},
{
"out_lens"
,
{
1
,
2
,
3
,
3
}}
})
;
auto
create_program
=
[
&
]
{
auto
create_program
=
[
&
]
{
migraphx
::
module
m
;
migraphx
::
module
m
;
auto
x
=
m
.
add_parameter
(
"x"
,
outer
);
auto
x
=
m
.
add_parameter
(
"x"
,
outer
);
...
@@ -539,7 +539,7 @@ TEST_CASE(simplify_conv_add)
...
@@ -539,7 +539,7 @@ TEST_CASE(simplify_conv_add)
TEST_CASE
(
simplify_inner_broadcast1
)
TEST_CASE
(
simplify_inner_broadcast1
)
{
{
auto
b
=
migraphx
::
op
::
broadcast
{
1
,
{
2
,
1
,
4
,
5
}};
auto
b
=
migraphx
::
make_op
(
"broadcast"
,
{{
"axis"
,
1
},
{
"out_lens"
,
{
2
,
1
,
4
,
5
}}
})
;
migraphx
::
module
m1
;
migraphx
::
module
m1
;
{
{
auto
x
=
m1
.
add_parameter
(
"x"
,
{
migraphx
::
shape
::
int32_type
,
{
1
}});
auto
x
=
m1
.
add_parameter
(
"x"
,
{
migraphx
::
shape
::
int32_type
,
{
1
}});
...
@@ -564,7 +564,7 @@ TEST_CASE(simplify_inner_broadcast1)
...
@@ -564,7 +564,7 @@ TEST_CASE(simplify_inner_broadcast1)
TEST_CASE
(
simplify_inner_broadcast2
)
TEST_CASE
(
simplify_inner_broadcast2
)
{
{
auto
b
=
migraphx
::
op
::
multibroadcast
{
{
2
,
1
,
4
,
5
}};
auto
b
=
migraphx
::
make_op
(
"
multibroadcast
"
,
{{
"out_lens"
,
{
2
,
1
,
4
,
5
}}
})
;
migraphx
::
module
m1
;
migraphx
::
module
m1
;
{
{
auto
x
=
m1
.
add_parameter
(
"x"
,
{
migraphx
::
shape
::
int32_type
,
{
1
,
1
,
1
,
1
}});
auto
x
=
m1
.
add_parameter
(
"x"
,
{
migraphx
::
shape
::
int32_type
,
{
1
,
1
,
1
,
1
}});
...
@@ -589,7 +589,7 @@ TEST_CASE(simplify_inner_broadcast2)
...
@@ -589,7 +589,7 @@ TEST_CASE(simplify_inner_broadcast2)
TEST_CASE
(
simplify_inner_broadcast_scalar
)
TEST_CASE
(
simplify_inner_broadcast_scalar
)
{
{
auto
b
=
migraphx
::
op
::
multibroadcast
{
{
32
,
384
}};
auto
b
=
migraphx
::
make_op
(
"
multibroadcast
"
,
{{
"out_lens"
,
{
32
,
384
}}
})
;
migraphx
::
module
m1
;
migraphx
::
module
m1
;
{
{
auto
x
=
m1
.
add_parameter
(
"x"
,
{
migraphx
::
shape
::
int32_type
,
{
1
,
384
}});
auto
x
=
m1
.
add_parameter
(
"x"
,
{
migraphx
::
shape
::
int32_type
,
{
1
,
384
}});
...
@@ -605,7 +605,8 @@ TEST_CASE(simplify_inner_broadcast_scalar)
...
@@ -605,7 +605,8 @@ TEST_CASE(simplify_inner_broadcast_scalar)
{
{
auto
x
=
m2
.
add_parameter
(
"x"
,
{
migraphx
::
shape
::
int32_type
,
{
1
,
384
}});
auto
x
=
m2
.
add_parameter
(
"x"
,
{
migraphx
::
shape
::
int32_type
,
{
1
,
384
}});
auto
y
=
m2
.
add_parameter
(
"y"
,
{
migraphx
::
shape
::
int32_type
,
{
1
,
1
}});
auto
y
=
m2
.
add_parameter
(
"y"
,
{
migraphx
::
shape
::
int32_type
,
{
1
,
1
}});
auto
yb
=
m2
.
add_instruction
(
migraphx
::
op
::
multibroadcast
{{
1
,
384
}},
y
);
auto
yb
=
m2
.
add_instruction
(
migraphx
::
make_op
(
"multibroadcast"
,
{{
"out_lens"
,
{
1
,
384
}}}),
y
);
auto
sum
=
m2
.
add_instruction
(
migraphx
::
make_op
(
"add"
),
x
,
yb
);
auto
sum
=
m2
.
add_instruction
(
migraphx
::
make_op
(
"add"
),
x
,
yb
);
auto
sumb
=
m2
.
add_instruction
(
b
,
sum
);
auto
sumb
=
m2
.
add_instruction
(
b
,
sum
);
m2
.
add_instruction
(
pass_op
{},
sumb
);
m2
.
add_instruction
(
pass_op
{},
sumb
);
...
@@ -615,7 +616,7 @@ TEST_CASE(simplify_inner_broadcast_scalar)
...
@@ -615,7 +616,7 @@ TEST_CASE(simplify_inner_broadcast_scalar)
TEST_CASE
(
simplify_inner_broadcast_different_dims
)
TEST_CASE
(
simplify_inner_broadcast_different_dims
)
{
{
auto
b
=
migraphx
::
op
::
multibroadcast
{
{
2
,
384
,
768
}};
auto
b
=
migraphx
::
make_op
(
"
multibroadcast
"
,
{{
"out_lens"
,
{
2
,
384
,
768
}}
})
;
migraphx
::
module
m1
;
migraphx
::
module
m1
;
{
{
auto
x
=
m1
.
add_parameter
(
"x"
,
{
migraphx
::
shape
::
int32_type
,
{
384
,
768
}});
auto
x
=
m1
.
add_parameter
(
"x"
,
{
migraphx
::
shape
::
int32_type
,
{
384
,
768
}});
...
@@ -631,7 +632,8 @@ TEST_CASE(simplify_inner_broadcast_different_dims)
...
@@ -631,7 +632,8 @@ TEST_CASE(simplify_inner_broadcast_different_dims)
{
{
auto
x
=
m2
.
add_parameter
(
"x"
,
{
migraphx
::
shape
::
int32_type
,
{
384
,
768
}});
auto
x
=
m2
.
add_parameter
(
"x"
,
{
migraphx
::
shape
::
int32_type
,
{
384
,
768
}});
auto
y
=
m2
.
add_parameter
(
"y"
,
{
migraphx
::
shape
::
int32_type
,
{
768
}});
auto
y
=
m2
.
add_parameter
(
"y"
,
{
migraphx
::
shape
::
int32_type
,
{
768
}});
auto
yb
=
m2
.
add_instruction
(
migraphx
::
op
::
multibroadcast
{{
384
,
768
}},
y
);
auto
yb
=
m2
.
add_instruction
(
migraphx
::
make_op
(
"multibroadcast"
,
{{
"out_lens"
,
{
384
,
768
}}}),
y
);
auto
sum
=
m2
.
add_instruction
(
migraphx
::
make_op
(
"add"
),
x
,
yb
);
auto
sum
=
m2
.
add_instruction
(
migraphx
::
make_op
(
"add"
),
x
,
yb
);
auto
sumb
=
m2
.
add_instruction
(
b
,
sum
);
auto
sumb
=
m2
.
add_instruction
(
b
,
sum
);
m2
.
add_instruction
(
pass_op
{},
sumb
);
m2
.
add_instruction
(
pass_op
{},
sumb
);
...
@@ -641,8 +643,8 @@ TEST_CASE(simplify_inner_broadcast_different_dims)
...
@@ -641,8 +643,8 @@ TEST_CASE(simplify_inner_broadcast_different_dims)
TEST_CASE
(
simplify_inner_broadcast_different_broadcasts
)
TEST_CASE
(
simplify_inner_broadcast_different_broadcasts
)
{
{
auto
b
=
migraphx
::
op
::
broadcast
{
1
,
{
1
,
24
,
112
,
112
}};
auto
b
=
migraphx
::
make_op
(
"broadcast"
,
{{
"axis"
,
1
},
{
"out_lens"
,
{
1
,
24
,
112
,
112
}}
})
;
auto
mb
=
migraphx
::
op
::
multibroadcast
{
{
1
,
24
,
112
,
112
}};
auto
mb
=
migraphx
::
make_op
(
"
multibroadcast
"
,
{{
"out_lens"
,
{
1
,
24
,
112
,
112
}}
})
;
migraphx
::
module
m1
;
migraphx
::
module
m1
;
{
{
auto
x
=
m1
.
add_parameter
(
"x"
,
{
migraphx
::
shape
::
int32_type
,
{
24
}});
auto
x
=
m1
.
add_parameter
(
"x"
,
{
migraphx
::
shape
::
int32_type
,
{
24
}});
...
@@ -891,7 +893,7 @@ TEST_CASE(simplify_concat_add_relu_partial_broadcast)
...
@@ -891,7 +893,7 @@ TEST_CASE(simplify_concat_add_relu_partial_broadcast)
auto
s
=
migraphx
::
shape
{
migraphx
::
shape
::
int32_type
,
{
2
,
1
,
4
,
5
}};
auto
s
=
migraphx
::
shape
{
migraphx
::
shape
::
int32_type
,
{
2
,
1
,
4
,
5
}};
migraphx
::
module
m1
;
migraphx
::
module
m1
;
{
{
auto
b
=
migraphx
::
op
::
broadcast
{
1
,
{
2
,
1
,
4
,
5
}};
auto
b
=
migraphx
::
make_op
(
"broadcast"
,
{{
"axis"
,
1
},
{
"out_lens"
,
{
2
,
1
,
4
,
5
}}
})
;
auto
x
=
m1
.
add_parameter
(
"x"
,
s
);
auto
x
=
m1
.
add_parameter
(
"x"
,
s
);
auto
y
=
m1
.
add_parameter
(
"y"
,
s
);
auto
y
=
m1
.
add_parameter
(
"y"
,
s
);
auto
one
=
m1
.
add_literal
(
1
);
auto
one
=
m1
.
add_literal
(
1
);
...
@@ -907,7 +909,7 @@ TEST_CASE(simplify_concat_add_relu_partial_broadcast)
...
@@ -907,7 +909,7 @@ TEST_CASE(simplify_concat_add_relu_partial_broadcast)
migraphx
::
module
m2
;
migraphx
::
module
m2
;
{
{
auto
b
=
migraphx
::
op
::
broadcast
{
1
,
{
2
,
2
,
4
,
5
}};
auto
b
=
migraphx
::
make_op
(
"broadcast"
,
{{
"axis"
,
1
},
{
"out_lens"
,
{
2
,
2
,
4
,
5
}}
})
;
auto
x
=
m2
.
add_parameter
(
"x"
,
s
);
auto
x
=
m2
.
add_parameter
(
"x"
,
s
);
auto
y
=
m2
.
add_parameter
(
"y"
,
s
);
auto
y
=
m2
.
add_parameter
(
"y"
,
s
);
auto
one
=
m2
.
add_literal
(
1
);
auto
one
=
m2
.
add_literal
(
1
);
...
@@ -926,7 +928,7 @@ TEST_CASE(simplify_concat_add_relu_broadcast_different_axis)
...
@@ -926,7 +928,7 @@ TEST_CASE(simplify_concat_add_relu_broadcast_different_axis)
auto
s
=
migraphx
::
shape
{
migraphx
::
shape
::
int32_type
,
{
2
,
1
,
4
,
5
}};
auto
s
=
migraphx
::
shape
{
migraphx
::
shape
::
int32_type
,
{
2
,
1
,
4
,
5
}};
migraphx
::
module
m1
;
migraphx
::
module
m1
;
{
{
auto
b
=
migraphx
::
op
::
broadcast
{
1
,
{
2
,
1
,
4
,
5
}};
auto
b
=
migraphx
::
make_op
(
"broadcast"
,
{{
"axis"
,
1
},
{
"out_lens"
,
{
2
,
1
,
4
,
5
}}
})
;
auto
x
=
m1
.
add_parameter
(
"x"
,
s
);
auto
x
=
m1
.
add_parameter
(
"x"
,
s
);
auto
y
=
m1
.
add_parameter
(
"y"
,
s
);
auto
y
=
m1
.
add_parameter
(
"y"
,
s
);
auto
one
=
m1
.
add_literal
(
1
);
auto
one
=
m1
.
add_literal
(
1
);
...
@@ -944,7 +946,7 @@ TEST_CASE(simplify_concat_add_relu_broadcast_different_axis)
...
@@ -944,7 +946,7 @@ TEST_CASE(simplify_concat_add_relu_broadcast_different_axis)
migraphx
::
module
m2
;
migraphx
::
module
m2
;
{
{
auto
b
=
migraphx
::
op
::
broadcast
{
1
,
{
2
,
2
,
4
,
5
}};
auto
b
=
migraphx
::
make_op
(
"broadcast"
,
{{
"axis"
,
1
},
{
"out_lens"
,
{
2
,
2
,
4
,
5
}}
})
;
auto
x
=
m2
.
add_parameter
(
"x"
,
s
);
auto
x
=
m2
.
add_parameter
(
"x"
,
s
);
auto
y
=
m2
.
add_parameter
(
"y"
,
s
);
auto
y
=
m2
.
add_parameter
(
"y"
,
s
);
auto
one
=
m2
.
add_literal
(
1
);
auto
one
=
m2
.
add_literal
(
1
);
...
@@ -964,7 +966,7 @@ TEST_CASE(simplify_concat_add_relu_broadcast_same_axis)
...
@@ -964,7 +966,7 @@ TEST_CASE(simplify_concat_add_relu_broadcast_same_axis)
auto
s
=
migraphx
::
shape
{
migraphx
::
shape
::
int32_type
,
{
2
,
1
,
4
,
5
}};
auto
s
=
migraphx
::
shape
{
migraphx
::
shape
::
int32_type
,
{
2
,
1
,
4
,
5
}};
migraphx
::
module
m1
;
migraphx
::
module
m1
;
{
{
auto
b
=
migraphx
::
op
::
broadcast
{
1
,
{
2
,
1
,
4
,
5
}};
auto
b
=
migraphx
::
make_op
(
"broadcast"
,
{{
"axis"
,
1
},
{
"out_lens"
,
{
2
,
1
,
4
,
5
}}
})
;
auto
x
=
m1
.
add_parameter
(
"x"
,
s
);
auto
x
=
m1
.
add_parameter
(
"x"
,
s
);
auto
y
=
m1
.
add_parameter
(
"y"
,
s
);
auto
y
=
m1
.
add_parameter
(
"y"
,
s
);
auto
one
=
m1
.
add_literal
(
1
);
auto
one
=
m1
.
add_literal
(
1
);
...
@@ -982,7 +984,7 @@ TEST_CASE(simplify_concat_add_relu_broadcast_same_axis)
...
@@ -982,7 +984,7 @@ TEST_CASE(simplify_concat_add_relu_broadcast_same_axis)
migraphx
::
module
m2
;
migraphx
::
module
m2
;
{
{
auto
b
=
migraphx
::
op
::
broadcast
{
1
,
{
2
,
1
,
4
,
5
}};
auto
b
=
migraphx
::
make_op
(
"broadcast"
,
{{
"axis"
,
1
},
{
"out_lens"
,
{
2
,
1
,
4
,
5
}}
})
;
auto
x
=
m2
.
add_parameter
(
"x"
,
s
);
auto
x
=
m2
.
add_parameter
(
"x"
,
s
);
auto
y
=
m2
.
add_parameter
(
"y"
,
s
);
auto
y
=
m2
.
add_parameter
(
"y"
,
s
);
auto
one
=
m2
.
add_literal
(
1
);
auto
one
=
m2
.
add_literal
(
1
);
...
@@ -1695,7 +1697,7 @@ TEST_CASE(simplify_split_add_relu)
...
@@ -1695,7 +1697,7 @@ TEST_CASE(simplify_split_add_relu)
auto
s
=
migraphx
::
shape
{
migraphx
::
shape
::
int32_type
,
{
3
,
2
,
4
}};
auto
s
=
migraphx
::
shape
{
migraphx
::
shape
::
int32_type
,
{
3
,
2
,
4
}};
migraphx
::
module
m1
;
migraphx
::
module
m1
;
{
{
auto
b
=
migraphx
::
op
::
broadcast
{
1
,
{
3
,
1
,
4
}};
auto
b
=
migraphx
::
make_op
(
"broadcast"
,
{{
"axis"
,
1
},
{
"out_lens"
,
{
3
,
1
,
4
}}
})
;
auto
input
=
m1
.
add_parameter
(
"input"
,
s
);
auto
input
=
m1
.
add_parameter
(
"input"
,
s
);
auto
x
=
m1
.
add_instruction
(
auto
x
=
m1
.
add_instruction
(
migraphx
::
make_op
(
"slice"
,
{{
"axes"
,
{
1
}},
{
"starts"
,
{
0
}},
{
"ends"
,
{
1
}}}),
input
);
migraphx
::
make_op
(
"slice"
,
{{
"axes"
,
{
1
}},
{
"starts"
,
{
0
}},
{
"ends"
,
{
1
}}}),
input
);
...
@@ -1716,7 +1718,7 @@ TEST_CASE(simplify_split_add_relu)
...
@@ -1716,7 +1718,7 @@ TEST_CASE(simplify_split_add_relu)
migraphx
::
module
m2
;
migraphx
::
module
m2
;
{
{
auto
b
=
migraphx
::
op
::
broadcast
{
1
,
{
3
,
2
,
4
}};
auto
b
=
migraphx
::
make_op
(
"broadcast"
,
{{
"axis"
,
1
},
{
"out_lens"
,
{
3
,
2
,
4
}}
})
;
auto
input
=
m2
.
add_parameter
(
"input"
,
s
);
auto
input
=
m2
.
add_parameter
(
"input"
,
s
);
auto
one
=
m2
.
add_literal
(
1
);
auto
one
=
m2
.
add_literal
(
1
);
auto
two
=
m2
.
add_literal
(
2
);
auto
two
=
m2
.
add_literal
(
2
);
...
@@ -1846,8 +1848,8 @@ TEST_CASE(simplify_split_add_relu_reshape)
...
@@ -1846,8 +1848,8 @@ TEST_CASE(simplify_split_add_relu_reshape)
auto
s
=
migraphx
::
shape
{
migraphx
::
shape
::
int32_type
,
{
3
,
2
,
4
}};
auto
s
=
migraphx
::
shape
{
migraphx
::
shape
::
int32_type
,
{
3
,
2
,
4
}};
migraphx
::
module
m1
;
migraphx
::
module
m1
;
{
{
auto
b
=
migraphx
::
op
::
broadcast
{
1
,
{
3
,
1
,
4
}};
auto
b
=
migraphx
::
make_op
(
"broadcast"
,
{{
"axis"
,
1
},
{
"out_lens"
,
{
3
,
1
,
4
}}
})
;
auto
r
=
migraphx
::
op
::
reshape
{
{
3
,
4
}};
auto
r
=
migraphx
::
make_op
(
"reshape"
,
{{
"dims"
,
{
3
,
4
}}
})
;
auto
input
=
m1
.
add_parameter
(
"input"
,
s
);
auto
input
=
m1
.
add_parameter
(
"input"
,
s
);
auto
x
=
m1
.
add_instruction
(
auto
x
=
m1
.
add_instruction
(
migraphx
::
make_op
(
"slice"
,
{{
"axes"
,
{
1
}},
{
"starts"
,
{
0
}},
{
"ends"
,
{
1
}}}),
input
);
migraphx
::
make_op
(
"slice"
,
{{
"axes"
,
{
1
}},
{
"starts"
,
{
0
}},
{
"ends"
,
{
1
}}}),
input
);
...
@@ -1870,7 +1872,7 @@ TEST_CASE(simplify_split_add_relu_reshape)
...
@@ -1870,7 +1872,7 @@ TEST_CASE(simplify_split_add_relu_reshape)
migraphx
::
module
m2
;
migraphx
::
module
m2
;
{
{
auto
b
=
migraphx
::
op
::
broadcast
{
1
,
{
3
,
2
,
4
}};
auto
b
=
migraphx
::
make_op
(
"broadcast"
,
{{
"axis"
,
1
},
{
"out_lens"
,
{
3
,
2
,
4
}}
})
;
auto
input
=
m2
.
add_parameter
(
"input"
,
s
);
auto
input
=
m2
.
add_parameter
(
"input"
,
s
);
auto
one
=
m2
.
add_literal
(
1
);
auto
one
=
m2
.
add_literal
(
1
);
auto
two
=
m2
.
add_literal
(
2
);
auto
two
=
m2
.
add_literal
(
2
);
...
@@ -1894,7 +1896,7 @@ TEST_CASE(simplify_slice_different_axis)
...
@@ -1894,7 +1896,7 @@ TEST_CASE(simplify_slice_different_axis)
auto
s
=
migraphx
::
shape
{
migraphx
::
shape
::
int32_type
,
{
3
,
2
,
4
,
2
}};
auto
s
=
migraphx
::
shape
{
migraphx
::
shape
::
int32_type
,
{
3
,
2
,
4
,
2
}};
migraphx
::
module
m1
;
migraphx
::
module
m1
;
{
{
auto
r
=
migraphx
::
op
::
reshape
{
{
3
,
2
,
4
}};
auto
r
=
migraphx
::
make_op
(
"reshape"
,
{{
"dims"
,
{
3
,
2
,
4
}}
})
;
auto
input
=
m1
.
add_parameter
(
"input"
,
s
);
auto
input
=
m1
.
add_parameter
(
"input"
,
s
);
auto
x
=
m1
.
add_instruction
(
auto
x
=
m1
.
add_instruction
(
migraphx
::
make_op
(
"slice"
,
{{
"axes"
,
{
1
}},
{
"starts"
,
{
0
}},
{
"ends"
,
{
1
}}}),
input
);
migraphx
::
make_op
(
"slice"
,
{{
"axes"
,
{
1
}},
{
"starts"
,
{
0
}},
{
"ends"
,
{
1
}}}),
input
);
...
@@ -1926,7 +1928,7 @@ TEST_CASE(simplify_slice_missing_begining_slice)
...
@@ -1926,7 +1928,7 @@ TEST_CASE(simplify_slice_missing_begining_slice)
auto
s
=
migraphx
::
shape
{
migraphx
::
shape
::
int32_type
,
{
3
,
3
,
4
}};
auto
s
=
migraphx
::
shape
{
migraphx
::
shape
::
int32_type
,
{
3
,
3
,
4
}};
migraphx
::
module
m1
;
migraphx
::
module
m1
;
{
{
auto
b
=
migraphx
::
op
::
broadcast
{
1
,
{
3
,
1
,
4
}};
auto
b
=
migraphx
::
make_op
(
"broadcast"
,
{{
"axis"
,
1
},
{
"out_lens"
,
{
3
,
1
,
4
}}
})
;
auto
input
=
m1
.
add_parameter
(
"input"
,
s
);
auto
input
=
m1
.
add_parameter
(
"input"
,
s
);
auto
x
=
m1
.
add_instruction
(
auto
x
=
m1
.
add_instruction
(
migraphx
::
make_op
(
"slice"
,
{{
"axes"
,
{
1
}},
{
"starts"
,
{
2
}},
{
"ends"
,
{
3
}}}),
input
);
migraphx
::
make_op
(
"slice"
,
{{
"axes"
,
{
1
}},
{
"starts"
,
{
2
}},
{
"ends"
,
{
3
}}}),
input
);
...
@@ -1954,7 +1956,7 @@ TEST_CASE(simplify_slice_missing_middle_slice)
...
@@ -1954,7 +1956,7 @@ TEST_CASE(simplify_slice_missing_middle_slice)
auto
s
=
migraphx
::
shape
{
migraphx
::
shape
::
int32_type
,
{
3
,
3
,
4
}};
auto
s
=
migraphx
::
shape
{
migraphx
::
shape
::
int32_type
,
{
3
,
3
,
4
}};
migraphx
::
module
m1
;
migraphx
::
module
m1
;
{
{
auto
b
=
migraphx
::
op
::
broadcast
{
1
,
{
3
,
1
,
4
}};
auto
b
=
migraphx
::
make_op
(
"broadcast"
,
{{
"axis"
,
1
},
{
"out_lens"
,
{
3
,
1
,
4
}}
})
;
auto
input
=
m1
.
add_parameter
(
"input"
,
s
);
auto
input
=
m1
.
add_parameter
(
"input"
,
s
);
auto
x
=
m1
.
add_instruction
(
auto
x
=
m1
.
add_instruction
(
migraphx
::
make_op
(
"slice"
,
{{
"axes"
,
{
1
}},
{
"starts"
,
{
2
}},
{
"ends"
,
{
3
}}}),
input
);
migraphx
::
make_op
(
"slice"
,
{{
"axes"
,
{
1
}},
{
"starts"
,
{
2
}},
{
"ends"
,
{
3
}}}),
input
);
...
@@ -1982,7 +1984,7 @@ TEST_CASE(simplify_slice_missing_end_slice)
...
@@ -1982,7 +1984,7 @@ TEST_CASE(simplify_slice_missing_end_slice)
auto
s
=
migraphx
::
shape
{
migraphx
::
shape
::
int32_type
,
{
3
,
3
,
4
}};
auto
s
=
migraphx
::
shape
{
migraphx
::
shape
::
int32_type
,
{
3
,
3
,
4
}};
migraphx
::
module
m1
;
migraphx
::
module
m1
;
{
{
auto
b
=
migraphx
::
op
::
broadcast
{
1
,
{
3
,
1
,
4
}};
auto
b
=
migraphx
::
make_op
(
"broadcast"
,
{{
"axis"
,
1
},
{
"out_lens"
,
{
3
,
1
,
4
}}
})
;
auto
input
=
m1
.
add_parameter
(
"input"
,
s
);
auto
input
=
m1
.
add_parameter
(
"input"
,
s
);
auto
x
=
m1
.
add_instruction
(
auto
x
=
m1
.
add_instruction
(
migraphx
::
make_op
(
"slice"
,
{{
"axes"
,
{
1
}},
{
"starts"
,
{
0
}},
{
"ends"
,
{
1
}}}),
input
);
migraphx
::
make_op
(
"slice"
,
{{
"axes"
,
{
1
}},
{
"starts"
,
{
0
}},
{
"ends"
,
{
1
}}}),
input
);
...
@@ -2010,7 +2012,7 @@ TEST_CASE(simplify_split_add_relu_concat_same_axis)
...
@@ -2010,7 +2012,7 @@ TEST_CASE(simplify_split_add_relu_concat_same_axis)
auto
s
=
migraphx
::
shape
{
migraphx
::
shape
::
int32_type
,
{
3
,
2
,
4
}};
auto
s
=
migraphx
::
shape
{
migraphx
::
shape
::
int32_type
,
{
3
,
2
,
4
}};
migraphx
::
module
m1
;
migraphx
::
module
m1
;
{
{
auto
b
=
migraphx
::
op
::
broadcast
{
1
,
{
3
,
1
,
4
}};
auto
b
=
migraphx
::
make_op
(
"broadcast"
,
{{
"axis"
,
1
},
{
"out_lens"
,
{
3
,
1
,
4
}}
})
;
auto
input
=
m1
.
add_parameter
(
"input"
,
s
);
auto
input
=
m1
.
add_parameter
(
"input"
,
s
);
auto
x
=
m1
.
add_instruction
(
auto
x
=
m1
.
add_instruction
(
migraphx
::
make_op
(
"slice"
,
{{
"axes"
,
{
1
}},
{
"starts"
,
{
0
}},
{
"ends"
,
{
1
}}}),
input
);
migraphx
::
make_op
(
"slice"
,
{{
"axes"
,
{
1
}},
{
"starts"
,
{
0
}},
{
"ends"
,
{
1
}}}),
input
);
...
@@ -2031,7 +2033,7 @@ TEST_CASE(simplify_split_add_relu_concat_same_axis)
...
@@ -2031,7 +2033,7 @@ TEST_CASE(simplify_split_add_relu_concat_same_axis)
migraphx
::
module
m2
;
migraphx
::
module
m2
;
{
{
auto
b
=
migraphx
::
op
::
broadcast
{
1
,
{
3
,
2
,
4
}};
auto
b
=
migraphx
::
make_op
(
"broadcast"
,
{{
"axis"
,
1
},
{
"out_lens"
,
{
3
,
2
,
4
}}
})
;
auto
input
=
m2
.
add_parameter
(
"input"
,
s
);
auto
input
=
m2
.
add_parameter
(
"input"
,
s
);
auto
one
=
m2
.
add_literal
(
1
);
auto
one
=
m2
.
add_literal
(
1
);
auto
two
=
m2
.
add_literal
(
2
);
auto
two
=
m2
.
add_literal
(
2
);
...
@@ -2049,7 +2051,7 @@ TEST_CASE(simplify_split_add_relu_multi_axes)
...
@@ -2049,7 +2051,7 @@ TEST_CASE(simplify_split_add_relu_multi_axes)
auto
s
=
migraphx
::
shape
{
migraphx
::
shape
::
int32_type
,
{
3
,
2
,
4
,
6
}};
auto
s
=
migraphx
::
shape
{
migraphx
::
shape
::
int32_type
,
{
3
,
2
,
4
,
6
}};
migraphx
::
module
m1
;
migraphx
::
module
m1
;
{
{
auto
b
=
migraphx
::
op
::
broadcast
{
1
,
{
3
,
1
,
4
,
3
}};
auto
b
=
migraphx
::
make_op
(
"broadcast"
,
{{
"axis"
,
1
},
{
"out_lens"
,
{
3
,
1
,
4
,
3
}}
})
;
auto
input
=
m1
.
add_parameter
(
"input"
,
s
);
auto
input
=
m1
.
add_parameter
(
"input"
,
s
);
auto
x
=
m1
.
add_instruction
(
auto
x
=
m1
.
add_instruction
(
migraphx
::
make_op
(
"slice"
,
{{
"axes"
,
{
1
,
3
}},
{
"starts"
,
{
0
,
0
}},
{
"ends"
,
{
1
,
3
}}}),
migraphx
::
make_op
(
"slice"
,
{{
"axes"
,
{
1
,
3
}},
{
"starts"
,
{
0
,
0
}},
{
"ends"
,
{
1
,
3
}}}),
...
@@ -2078,7 +2080,7 @@ TEST_CASE(simplify_split_add_relu_used_multiple_split1)
...
@@ -2078,7 +2080,7 @@ TEST_CASE(simplify_split_add_relu_used_multiple_split1)
auto
s
=
migraphx
::
shape
{
migraphx
::
shape
::
int32_type
,
{
3
,
2
,
4
}};
auto
s
=
migraphx
::
shape
{
migraphx
::
shape
::
int32_type
,
{
3
,
2
,
4
}};
migraphx
::
module
m1
;
migraphx
::
module
m1
;
{
{
auto
b
=
migraphx
::
op
::
broadcast
{
1
,
{
3
,
1
,
4
}};
auto
b
=
migraphx
::
make_op
(
"broadcast"
,
{{
"axis"
,
1
},
{
"out_lens"
,
{
3
,
1
,
4
}}
})
;
auto
input
=
m1
.
add_parameter
(
"input"
,
s
);
auto
input
=
m1
.
add_parameter
(
"input"
,
s
);
auto
x
=
m1
.
add_instruction
(
auto
x
=
m1
.
add_instruction
(
migraphx
::
make_op
(
"slice"
,
{{
"axes"
,
{
1
}},
{
"starts"
,
{
0
}},
{
"ends"
,
{
1
}}}),
input
);
migraphx
::
make_op
(
"slice"
,
{{
"axes"
,
{
1
}},
{
"starts"
,
{
0
}},
{
"ends"
,
{
1
}}}),
input
);
...
@@ -2100,7 +2102,7 @@ TEST_CASE(simplify_split_add_relu_used_multiple_split1)
...
@@ -2100,7 +2102,7 @@ TEST_CASE(simplify_split_add_relu_used_multiple_split1)
migraphx
::
module
m2
;
migraphx
::
module
m2
;
{
{
auto
b
=
migraphx
::
op
::
broadcast
{
1
,
{
3
,
2
,
4
}};
auto
b
=
migraphx
::
make_op
(
"broadcast"
,
{{
"axis"
,
1
},
{
"out_lens"
,
{
3
,
2
,
4
}}
})
;
auto
input
=
m2
.
add_parameter
(
"input"
,
s
);
auto
input
=
m2
.
add_parameter
(
"input"
,
s
);
auto
slice
=
m2
.
add_instruction
(
auto
slice
=
m2
.
add_instruction
(
migraphx
::
make_op
(
"slice"
,
{{
"axes"
,
{
1
}},
{
"starts"
,
{
0
}},
{
"ends"
,
{
1
}}}),
input
);
migraphx
::
make_op
(
"slice"
,
{{
"axes"
,
{
1
}},
{
"starts"
,
{
0
}},
{
"ends"
,
{
1
}}}),
input
);
...
@@ -2126,7 +2128,7 @@ TEST_CASE(simplify_split_add_relu_used_multiple_split2)
...
@@ -2126,7 +2128,7 @@ TEST_CASE(simplify_split_add_relu_used_multiple_split2)
auto
s
=
migraphx
::
shape
{
migraphx
::
shape
::
int32_type
,
{
3
,
2
,
4
}};
auto
s
=
migraphx
::
shape
{
migraphx
::
shape
::
int32_type
,
{
3
,
2
,
4
}};
migraphx
::
module
m1
;
migraphx
::
module
m1
;
{
{
auto
b
=
migraphx
::
op
::
broadcast
{
1
,
{
3
,
1
,
4
}};
auto
b
=
migraphx
::
make_op
(
"broadcast"
,
{{
"axis"
,
1
},
{
"out_lens"
,
{
3
,
1
,
4
}}
})
;
auto
input
=
m1
.
add_parameter
(
"input"
,
s
);
auto
input
=
m1
.
add_parameter
(
"input"
,
s
);
auto
x
=
m1
.
add_instruction
(
auto
x
=
m1
.
add_instruction
(
migraphx
::
make_op
(
"slice"
,
{{
"axes"
,
{
1
}},
{
"starts"
,
{
0
}},
{
"ends"
,
{
1
}}}),
input
);
migraphx
::
make_op
(
"slice"
,
{{
"axes"
,
{
1
}},
{
"starts"
,
{
0
}},
{
"ends"
,
{
1
}}}),
input
);
...
@@ -2149,7 +2151,7 @@ TEST_CASE(simplify_split_add_relu_used_multiple_split2)
...
@@ -2149,7 +2151,7 @@ TEST_CASE(simplify_split_add_relu_used_multiple_split2)
migraphx
::
module
m2
;
migraphx
::
module
m2
;
{
{
auto
b
=
migraphx
::
op
::
broadcast
{
1
,
{
3
,
2
,
4
}};
auto
b
=
migraphx
::
make_op
(
"broadcast"
,
{{
"axis"
,
1
},
{
"out_lens"
,
{
3
,
2
,
4
}}
})
;
auto
input
=
m2
.
add_parameter
(
"input"
,
s
);
auto
input
=
m2
.
add_parameter
(
"input"
,
s
);
auto
slice
=
m2
.
add_instruction
(
auto
slice
=
m2
.
add_instruction
(
migraphx
::
make_op
(
"slice"
,
{{
"axes"
,
{
1
}},
{
"starts"
,
{
0
}},
{
"ends"
,
{
1
}}}),
input
);
migraphx
::
make_op
(
"slice"
,
{{
"axes"
,
{
1
}},
{
"starts"
,
{
0
}},
{
"ends"
,
{
1
}}}),
input
);
...
...
test/simplify_reshapes_test.cpp
View file @
b119ed8f
...
@@ -24,7 +24,6 @@
...
@@ -24,7 +24,6 @@
#include <migraphx/simplify_reshapes.hpp>
#include <migraphx/simplify_reshapes.hpp>
#include <migraphx/dead_code_elimination.hpp>
#include <migraphx/dead_code_elimination.hpp>
#include <migraphx/pass_manager.hpp>
#include <migraphx/pass_manager.hpp>
#include <migraphx/operators.hpp>
#include <migraphx/instruction.hpp>
#include <migraphx/instruction.hpp>
#include <migraphx/generate.hpp>
#include <migraphx/generate.hpp>
#include <basic_ops.hpp>
#include <basic_ops.hpp>
...
@@ -477,7 +476,7 @@ TEST_CASE(concat_multibroadcasts1)
...
@@ -477,7 +476,7 @@ TEST_CASE(concat_multibroadcasts1)
std
::
find_if
(
m
.
begin
(),
m
.
end
(),
[](
auto
ins
)
{
return
ins
.
name
()
==
"multibroadcast"
;
});
std
::
find_if
(
m
.
begin
(),
m
.
end
(),
[](
auto
ins
)
{
return
ins
.
name
()
==
"multibroadcast"
;
});
auto
md
=
std
::
distance
(
m
.
begin
(),
new_mb
);
auto
md
=
std
::
distance
(
m
.
begin
(),
new_mb
);
EXPECT
(
cd
==
md
-
1
);
EXPECT
(
cd
==
md
-
1
);
EXPECT
(
migraphx
::
any_cast
<
migraphx
::
op
::
concat
>
(
new_concat
->
get_operator
()).
axis
==
1
);
EXPECT
(
new_concat
->
get_operator
().
to_value
()[
"axis"
].
to
<
int
>
()
==
1
);
}
}
TEST_CASE
(
concat_multibroadcasts2
)
TEST_CASE
(
concat_multibroadcasts2
)
...
@@ -500,7 +499,7 @@ TEST_CASE(concat_multibroadcasts2)
...
@@ -500,7 +499,7 @@ TEST_CASE(concat_multibroadcasts2)
std
::
find_if
(
m
.
begin
(),
m
.
end
(),
[](
auto
ins
)
{
return
ins
.
name
()
==
"multibroadcast"
;
});
std
::
find_if
(
m
.
begin
(),
m
.
end
(),
[](
auto
ins
)
{
return
ins
.
name
()
==
"multibroadcast"
;
});
auto
md
=
std
::
distance
(
m
.
begin
(),
new_mb
);
auto
md
=
std
::
distance
(
m
.
begin
(),
new_mb
);
EXPECT
(
cd
==
md
-
1
);
EXPECT
(
cd
==
md
-
1
);
EXPECT
(
migraphx
::
any_cast
<
migraphx
::
op
::
concat
>
(
new_concat
->
get_operator
()).
axis
==
0
);
EXPECT
(
new_concat
->
get_operator
().
to_value
()[
"axis"
].
to
<
int
>
()
==
0
);
}
}
TEST_CASE
(
concat_multibroadcasts3
)
TEST_CASE
(
concat_multibroadcasts3
)
...
@@ -523,7 +522,7 @@ TEST_CASE(concat_multibroadcasts3)
...
@@ -523,7 +522,7 @@ TEST_CASE(concat_multibroadcasts3)
std
::
find_if
(
m
.
begin
(),
m
.
end
(),
[](
auto
ins
)
{
return
ins
.
name
()
==
"multibroadcast"
;
});
std
::
find_if
(
m
.
begin
(),
m
.
end
(),
[](
auto
ins
)
{
return
ins
.
name
()
==
"multibroadcast"
;
});
auto
md
=
std
::
distance
(
m
.
begin
(),
new_mb
);
auto
md
=
std
::
distance
(
m
.
begin
(),
new_mb
);
EXPECT
(
cd
==
md
-
1
);
EXPECT
(
cd
==
md
-
1
);
EXPECT
(
migraphx
::
any_cast
<
migraphx
::
op
::
concat
>
(
new_concat
->
get_operator
()).
axis
==
2
);
EXPECT
(
new_concat
->
get_operator
().
to_value
()[
"axis"
].
to
<
int
>
()
==
2
);
}
}
TEST_CASE
(
concat_multibroadcasts4
)
TEST_CASE
(
concat_multibroadcasts4
)
...
@@ -559,7 +558,7 @@ TEST_CASE(concat_transpose1)
...
@@ -559,7 +558,7 @@ TEST_CASE(concat_transpose1)
auto
new_concat
=
auto
new_concat
=
std
::
find_if
(
m
.
begin
(),
m
.
end
(),
[](
auto
ins
)
{
return
ins
.
name
()
==
"concat"
;
});
std
::
find_if
(
m
.
begin
(),
m
.
end
(),
[](
auto
ins
)
{
return
ins
.
name
()
==
"concat"
;
});
EXPECT
(
bool
{
new_concat
!=
m
.
end
()});
EXPECT
(
bool
{
new_concat
!=
m
.
end
()});
EXPECT
(
migraphx
::
any_cast
<
migraphx
::
op
::
concat
>
(
new_concat
->
get_operator
()).
axis
==
3
);
EXPECT
(
new_concat
->
get_operator
().
to_value
()[
"axis"
].
to
<
int
>
()
==
3
);
}
}
TEST_CASE
(
concat_transpose2
)
TEST_CASE
(
concat_transpose2
)
...
@@ -583,7 +582,7 @@ TEST_CASE(concat_transpose2)
...
@@ -583,7 +582,7 @@ TEST_CASE(concat_transpose2)
auto
new_concat
=
auto
new_concat
=
std
::
find_if
(
m
.
begin
(),
m
.
end
(),
[](
auto
ins
)
{
return
ins
.
name
()
==
"concat"
;
});
std
::
find_if
(
m
.
begin
(),
m
.
end
(),
[](
auto
ins
)
{
return
ins
.
name
()
==
"concat"
;
});
EXPECT
(
bool
{
new_concat
!=
m
.
end
()});
EXPECT
(
bool
{
new_concat
!=
m
.
end
()});
EXPECT
(
migraphx
::
any_cast
<
migraphx
::
op
::
concat
>
(
new_concat
->
get_operator
()).
axis
==
1
);
EXPECT
(
new_concat
->
get_operator
().
to_value
()[
"axis"
].
to
<
int
>
()
==
1
);
}
}
TEST_CASE
(
concat_transpose3
)
TEST_CASE
(
concat_transpose3
)
...
@@ -607,7 +606,7 @@ TEST_CASE(concat_transpose3)
...
@@ -607,7 +606,7 @@ TEST_CASE(concat_transpose3)
auto
new_concat
=
auto
new_concat
=
std
::
find_if
(
m
.
begin
(),
m
.
end
(),
[](
auto
ins
)
{
return
ins
.
name
()
==
"concat"
;
});
std
::
find_if
(
m
.
begin
(),
m
.
end
(),
[](
auto
ins
)
{
return
ins
.
name
()
==
"concat"
;
});
EXPECT
(
bool
{
new_concat
!=
m
.
end
()});
EXPECT
(
bool
{
new_concat
!=
m
.
end
()});
EXPECT
(
migraphx
::
any_cast
<
migraphx
::
op
::
concat
>
(
new_concat
->
get_operator
()).
axis
==
1
);
EXPECT
(
new_concat
->
get_operator
().
to_value
()[
"axis"
].
to
<
int
>
()
==
1
);
}
}
TEST_CASE
(
concat_transpose4
)
TEST_CASE
(
concat_transpose4
)
...
...
test/tf/tf_test.cpp
View file @
b119ed8f
...
@@ -37,7 +37,6 @@
...
@@ -37,7 +37,6 @@
#include <migraphx/op/convolution.hpp>
#include <migraphx/op/convolution.hpp>
#include <migraphx/op/reduce_mean.hpp>
#include <migraphx/op/reduce_mean.hpp>
#include <migraphx/op/pooling.hpp>
#include <migraphx/op/pooling.hpp>
#include <migraphx/op/slice.hpp>
#include <migraphx/serialize.hpp>
#include <migraphx/serialize.hpp>
...
@@ -840,12 +839,8 @@ TEST_CASE(slice_test)
...
@@ -840,12 +839,8 @@ TEST_CASE(slice_test)
mm
->
add_literal
(
migraphx
::
literal
{
s0
,
{
1
,
0
}});
mm
->
add_literal
(
migraphx
::
literal
{
s0
,
{
1
,
0
}});
mm
->
add_literal
(
migraphx
::
literal
{
s0
,
{
2
,
-
1
}});
mm
->
add_literal
(
migraphx
::
literal
{
s0
,
{
2
,
-
1
}});
migraphx
::
op
::
slice
op
;
mm
->
add_instruction
(
op
.
starts
=
{
1
,
0
};
migraphx
::
make_op
(
"slice"
,
{{
"starts"
,
{
1
,
0
}},
{
"ends"
,
{
3
,
10
}},
{
"axes"
,
{
0
,
1
}}}),
l0
);
op
.
ends
=
{
3
,
10
};
op
.
axes
=
std
::
vector
<
int64_t
>
(
num_axes
);
std
::
iota
(
op
.
axes
.
begin
(),
op
.
axes
.
end
(),
0
);
mm
->
add_instruction
(
op
,
l0
);
auto
prog
=
optimize_tf
(
"slice_test.pb"
,
false
);
auto
prog
=
optimize_tf
(
"slice_test.pb"
,
false
);
EXPECT
(
p
==
prog
);
EXPECT
(
p
==
prog
);
...
@@ -975,13 +970,10 @@ TEST_CASE(stridedslice_test)
...
@@ -975,13 +970,10 @@ TEST_CASE(stridedslice_test)
auto
l0
=
mm
->
add_parameter
(
"0"
,
migraphx
::
shape
{
migraphx
::
shape
::
float_type
,
{
1
,
10
,
1
,
1
}});
auto
l0
=
mm
->
add_parameter
(
"0"
,
migraphx
::
shape
{
migraphx
::
shape
::
float_type
,
{
1
,
10
,
1
,
1
}});
auto
l1
=
auto
l1
=
mm
->
add_instruction
(
migraphx
::
make_op
(
"transpose"
,
{{
"permutation"
,
{
0
,
2
,
3
,
1
}}}),
l0
);
mm
->
add_instruction
(
migraphx
::
make_op
(
"transpose"
,
{{
"permutation"
,
{
0
,
2
,
3
,
1
}}}),
l0
);
std
::
size_t
num_axes
=
4
;
auto
l2
=
mm
->
add_instruction
(
migraphx
::
op
::
slice
op
;
migraphx
::
make_op
(
op
.
starts
=
{
0
,
0
,
0
,
0
};
"slice"
,
{{
"starts"
,
{
0
,
0
,
0
,
0
}},
{
"ends"
,
{
1
,
1
,
1
,
5
}},
{
"axes"
,
{
0
,
1
,
2
,
3
}}}),
op
.
ends
=
{
1
,
1
,
1
,
5
};
l1
);
op
.
axes
=
std
::
vector
<
int64_t
>
(
num_axes
);
std
::
iota
(
op
.
axes
.
begin
(),
op
.
axes
.
end
(),
0
);
auto
l2
=
mm
->
add_instruction
(
op
,
l1
);
auto
shrink_axis
=
1
;
auto
shrink_axis
=
1
;
mm
->
add_instruction
(
migraphx
::
make_op
(
"squeeze"
,
{{
"axes"
,
{
shrink_axis
}}}),
l2
);
mm
->
add_instruction
(
migraphx
::
make_op
(
"squeeze"
,
{{
"axes"
,
{
shrink_axis
}}}),
l2
);
auto
prog
=
optimize_tf
(
"stridedslice_test.pb"
,
true
);
auto
prog
=
optimize_tf
(
"stridedslice_test.pb"
,
true
);
...
@@ -995,12 +987,6 @@ TEST_CASE(stridedslice_masks_test)
...
@@ -995,12 +987,6 @@ TEST_CASE(stridedslice_masks_test)
auto
*
mm
=
p
.
get_main_module
();
auto
*
mm
=
p
.
get_main_module
();
auto
l0
=
mm
->
add_parameter
(
"0"
,
migraphx
::
shape
{
migraphx
::
shape
::
float_type
,
{
1
,
10
,
3
,
3
}});
auto
l0
=
mm
->
add_parameter
(
"0"
,
migraphx
::
shape
{
migraphx
::
shape
::
float_type
,
{
1
,
10
,
3
,
3
}});
std
::
size_t
num_axes
=
4
;
migraphx
::
op
::
slice
op
;
op
.
starts
=
{
0
,
1
,
1
,
0
};
op
.
ends
=
{
1
,
3
,
3
,
10
};
op
.
axes
=
std
::
vector
<
int64_t
>
(
num_axes
);
std
::
iota
(
op
.
axes
.
begin
(),
op
.
axes
.
end
(),
0
);
// add literals for starts, ends, and strides in tf (NHWC format)
// add literals for starts, ends, and strides in tf (NHWC format)
mm
->
add_literal
(
migraphx
::
shape
{
migraphx
::
shape
::
int32_type
,
{
4
}},
mm
->
add_literal
(
migraphx
::
shape
{
migraphx
::
shape
::
int32_type
,
{
4
}},
std
::
vector
<
int
>
{
0
,
1
,
1
,
0
});
std
::
vector
<
int
>
{
0
,
1
,
1
,
0
});
...
@@ -1011,7 +997,10 @@ TEST_CASE(stridedslice_masks_test)
...
@@ -1011,7 +997,10 @@ TEST_CASE(stridedslice_masks_test)
auto
l1
=
auto
l1
=
mm
->
add_instruction
(
migraphx
::
make_op
(
"transpose"
,
{{
"permutation"
,
{
0
,
2
,
3
,
1
}}}),
l0
);
mm
->
add_instruction
(
migraphx
::
make_op
(
"transpose"
,
{{
"permutation"
,
{
0
,
2
,
3
,
1
}}}),
l0
);
auto
l2
=
mm
->
add_instruction
(
op
,
l1
);
auto
l2
=
mm
->
add_instruction
(
migraphx
::
make_op
(
"slice"
,
{{
"starts"
,
{
0
,
1
,
1
,
0
}},
{
"ends"
,
{
1
,
3
,
3
,
10
}},
{
"axes"
,
{
0
,
1
,
2
,
3
}}}),
l1
);
auto
l3
=
auto
l3
=
mm
->
add_instruction
(
migraphx
::
make_op
(
"transpose"
,
{{
"permutation"
,
{
0
,
3
,
1
,
2
}}}),
l2
);
mm
->
add_instruction
(
migraphx
::
make_op
(
"transpose"
,
{{
"permutation"
,
{
0
,
3
,
1
,
2
}}}),
l2
);
mm
->
add_return
({
l3
});
mm
->
add_return
({
l3
});
...
...
test/verify/gemm_literal.cpp
View file @
b119ed8f
...
@@ -25,7 +25,7 @@
...
@@ -25,7 +25,7 @@
#include "verify_program.hpp"
#include "verify_program.hpp"
#include <migraphx/program.hpp>
#include <migraphx/program.hpp>
#include <migraphx/generate.hpp>
#include <migraphx/generate.hpp>
#include <migraphx/
operators
.hpp>
#include <migraphx/
make_op
.hpp>
struct
gemm_literal
:
verify_program
<
gemm_literal
>
struct
gemm_literal
:
verify_program
<
gemm_literal
>
{
{
...
@@ -38,7 +38,7 @@ struct gemm_literal : verify_program<gemm_literal>
...
@@ -38,7 +38,7 @@ struct gemm_literal : verify_program<gemm_literal>
auto
a
=
mm
->
add_literal
(
migraphx
::
generate_literal
(
a_shape
));
auto
a
=
mm
->
add_literal
(
migraphx
::
generate_literal
(
a_shape
));
auto
b
=
mm
->
add_parameter
(
"b"
,
b_shape
);
auto
b
=
mm
->
add_parameter
(
"b"
,
b_shape
);
mm
->
add_instruction
(
migraphx
::
op
::
dot
{}
,
a
,
b
);
mm
->
add_instruction
(
migraphx
::
make_op
(
"
dot
"
)
,
a
,
b
);
return
p
;
return
p
;
}
}
...
...
tools/build_and_test_onnxrt.sh
View file @
b119ed8f
...
@@ -31,7 +31,7 @@ pip3 install -r requirements-dev.txt
...
@@ -31,7 +31,7 @@ pip3 install -r requirements-dev.txt
# Add newer cmake to the path
# Add newer cmake to the path
export
PATH
=
"/opt/cmake/bin:
$PATH
"
export
PATH
=
"/opt/cmake/bin:
$PATH
"
export
CXXFLAGS
=
"-D__HIP_PLATFORM_AMD__=1 -w"
export
CXXFLAGS
=
"-D__HIP_PLATFORM_AMD__=1 -w"
./build.sh
--config
Release
--cmake_extra_defines
CMAKE_HIP_COMPILER
=
/opt/rocm/llvm/bin/clang++
--update
--build
--parallel
--cmake_extra_defines
ONNXRUNTIME_VERSION
=
$(
cat
./VERSION_NUMBER
)
--skip_tests
--rocm_home
/opt/rocm
--use_migraphx
--migraphx_home
/opt/rocm
--rocm_version
=
`
cat
/opt/rocm/.info/version-dev
`
--allow_running_as_root
./build.sh
--config
Release
--cmake_extra_defines
CMAKE_HIP_COMPILER
=
/opt/rocm/llvm/bin/clang++
--update
--build
--build_wheel
--parallel
--cmake_extra_defines
ONNXRUNTIME_VERSION
=
$(
cat
./VERSION_NUMBER
)
--skip_tests
--rocm_home
/opt/rocm
--use_migraphx
--migraphx_home
/opt/rocm
--rocm_version
=
`
cat
/opt/rocm/.info/version-dev
`
--allow_running_as_root
cd
build/Linux/Release
cd
build/Linux/Release
#Add test launcher for onnxrt tests
#Add test launcher for onnxrt tests
...
...
tools/docker/sles.docker
0 → 100644
View file @
b119ed8f
FROM
registry.suse.com/suse/sle15:15.4
RUN
sh
-c
'echo -e "
\
[rocm]\n
\
name=rocm\n
\
baseurl=https://repo.radeon.com/rocm/zyp/5.5/main\n
\
enabled=1\n
\
gpgcheck=1\n
\
gpgkey=https://repo.radeon.com/rocm/rocm.gpg.key\n
\
" > /etc/zypp/repos.d/rocm.repo'
RUN
cat
/etc/zypp/repos.d/rocm.repo
RUN
zypper
-n
--gpg-auto-import-keys
refresh
RUN
zypper
install
-y
-t
pattern devel_basis enhanced_base
RUN
zypper
--gpg-auto-import-keys
install
-y
\
doxygen
\
gcc-c++
\
gdb
\
git
\
python3-pip
# Workaround broken rocm packages
RUN
ln
-s
/opt/rocm-
*
/opt/rocm
RUN
echo
"/opt/rocm/lib"
>
/etc/ld.so.conf.d/rocm.conf
RUN
echo
"/opt/rocm/llvm/lib"
>
/etc/ld.so.conf.d/rocm-llvm.conf
RUN
ldconfig
ENV
LC_ALL=C.UTF-8
ENV
LANG=C.UTF-8
# Install yapf
RUN
pip3
install
yapf
==
0.28.0
# Install doc requirements
# ADD docs/.sphinx/requirements.txt /doc-requirements.txt
# RUN pip3 install -r /doc-requirements.txt
# Install dependencies
ADD
dev-requirements.txt /dev-requirements.txt
ADD
requirements.txt /requirements.txt
ADD
rbuild.ini /rbuild.ini
COPY
./tools/install_prereqs.sh /
RUN
/install_prereqs.sh /usr/local /
&&
rm
/install_prereqs.sh
tools/install_prereqs.sh
View file @
b119ed8f
...
@@ -31,9 +31,30 @@ set -e
...
@@ -31,9 +31,30 @@ set -e
export
LC_ALL
=
C.UTF-8
export
LC_ALL
=
C.UTF-8
export
LANG
=
C.UTF-8
export
LANG
=
C.UTF-8
source
/etc/os-release
if
[[
(
"
${
ID
}
"
==
"sles"
)
]]
;
then
zypper
-n
--gpg-auto-import-keys
install
-y
\
cmake
\
miopen-hip-devel
\
openmp-extras-devel
\
python3-devel
\
python3-pip
\
rocblas-devel
\
rocm-cmake
else
# Need pip3 and Python headers to build dependencies
apt update
&&
apt
install
-y
\
cmake
\
libnuma-dev
\
miopen-hip-dev
\
openmp-extras
\
python3-dev
\
python3-pip
\
rocblas-dev
\
rocm-cmake
fi
# Need pip3 and Python headers to build dependencies
apt update
&&
apt
install
-y
python3-pip python3-dev cmake rocm-cmake rocblas miopen-hip openmp-extras
# Needed for cmake to build various pip packages
# Needed for cmake to build various pip packages
pip3
install
setuptools wheel
pip3
install
setuptools wheel
...
@@ -56,9 +77,11 @@ echo "Dependencies are installed at $PREFIX"
...
@@ -56,9 +77,11 @@ echo "Dependencies are installed at $PREFIX"
# Install deps with rbuild
# Install deps with rbuild
rbuild prepare
-d
$PREFIX
-s
develop
rbuild prepare
-d
$PREFIX
-s
develop
if
[[
(
"
${
ID
}
"
!=
"sles"
)
]]
;
then
export
CMAKE_ARGS
=
"-DONNX_USE_PROTOBUF_SHARED_LIBS=ON"
export
CMAKE_ARGS
=
"-DONNX_USE_PROTOBUF_SHARED_LIBS=ON"
pip3
install
onnx
==
1.10.2
numpy
==
1.21.6
typing
==
3.7.4
pytest
==
6.0.1
packaging
==
23.0
pip3
install
onnx
==
1.10.2
numpy
==
1.21.6
typing
==
3.7.4
pytest
==
6.0.1
packaging
==
23.0
# pin version of protobuf in Python for onnx runtime unit tests between dist versions
# pin version of protobuf in Python for onnx runtime unit tests between dist versions
pip3
install
protobuf
==
3.20.0
pip3
install
protobuf
==
3.20.0
fi
Prev
1
2
3
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment