Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
MIGraphX
Commits
e6290061
Unverified
Commit
e6290061
authored
Aug 18, 2023
by
Paul Fultz II
Committed by
GitHub
Aug 18, 2023
Browse files
Remove operators.hpp includes (#2086)
parent
e4ef64f4
Changes
14
Hide whitespace changes
Inline
Side-by-side
Showing
14 changed files
with
111 additions
and
98 deletions
+111
-98
src/memory_coloring.cpp
src/memory_coloring.cpp
+5
-3
test/eliminate_pad_test.cpp
test/eliminate_pad_test.cpp
+3
-4
test/gpu/quantization.cpp
test/gpu/quantization.cpp
+2
-2
test/inline_module_test.cpp
test/inline_module_test.cpp
+0
-1
test/insert_pad_test.cpp
test/insert_pad_test.cpp
+6
-5
test/layout_nhwc.cpp
test/layout_nhwc.cpp
+0
-1
test/onnx/onnx_rnn_test.cpp
test/onnx/onnx_rnn_test.cpp
+1
-1
test/onnx/verify_onnx.cpp
test/onnx/verify_onnx.cpp
+0
-1
test/op_shape_test.cpp
test/op_shape_test.cpp
+51
-36
test/pad_calc_test.cpp
test/pad_calc_test.cpp
+0
-1
test/quantization.cpp
test/quantization.cpp
+0
-1
test/simplify_algebra_test.cpp
test/simplify_algebra_test.cpp
+35
-33
test/simplify_reshapes_test.cpp
test/simplify_reshapes_test.cpp
+6
-7
test/verify/gemm_literal.cpp
test/verify/gemm_literal.cpp
+2
-2
No files found.
src/memory_coloring.cpp
View file @
e6290061
...
...
@@ -23,9 +23,9 @@
*/
#include <migraphx/memory_coloring.hpp>
#include <migraphx/module.hpp>
#include <migraphx/operators.hpp>
#include <migraphx/instruction.hpp>
#include <migraphx/iterator_for.hpp>
#include <migraphx/make_op.hpp>
#include <migraphx/functional.hpp>
#include <migraphx/algorithm.hpp>
#include <migraphx/ranges.hpp>
...
...
@@ -382,7 +382,8 @@ void memory_coloring::apply(module& m) const
auto
s
=
ins
->
get_shape
();
std
::
size_t
offset
=
seg
.
first
*
alignment
;
assert
(
offset
<
n
);
m
.
replace_instruction
(
ins
,
op
::
load
{
s
,
offset
},
mem
);
m
.
replace_instruction
(
ins
,
make_op
(
"load"
,
{{
"shape"
,
to_value
(
s
)},
{
"offset"
,
offset
}}),
mem
);
}
// Replace zero allocation
...
...
@@ -391,7 +392,8 @@ void memory_coloring::apply(module& m) const
if
(
ins
->
name
()
!=
allocation_op
)
continue
;
assert
(
ins
->
get_shape
().
bytes
()
==
0
);
m
.
replace_instruction
(
ins
,
op
::
load
{
ins
->
get_shape
(),
0
},
mem
);
m
.
replace_instruction
(
ins
,
make_op
(
"load"
,
{{
"shape"
,
to_value
(
ins
->
get_shape
())},
{
"offset"
,
0
}}),
mem
);
}
// Remove scratch parameter if its not used
...
...
test/eliminate_pad_test.cpp
View file @
e6290061
...
...
@@ -27,7 +27,7 @@
#include <migraphx/pass_manager.hpp>
#include <migraphx/instruction.hpp>
#include <basic_ops.hpp>
#include <migraphx/op
erators
.hpp>
#include <migraphx/op
/common
.hpp>
#include <migraphx/make_op.hpp>
#include <test.hpp>
...
...
@@ -58,9 +58,8 @@ create_conv(migraphx::instruction_ref& l_img,
migraphx
::
shape
s_weights
{
migraphx
::
shape
::
int32_type
,
{
4
,
channels
,
3
,
3
}};
std
::
vector
<
int32_t
>
weights
(
4
*
channels
*
3
*
3
);
auto
l_weights
=
m
.
add_literal
(
migraphx
::
literal
{
s_weights
,
weights
});
migraphx
::
op
::
convolution
op
;
op
.
padding_mode
=
padding_mode
;
return
m
.
add_instruction
(
op
,
l_img
,
l_weights
);
return
m
.
add_instruction
(
migraphx
::
make_op
(
"convolution"
,
{{
"padding_mode"
,
padding_mode
}}),
l_img
,
l_weights
);
}
TEST_CASE
(
rewrite_pad
)
...
...
test/gpu/quantization.cpp
View file @
e6290061
...
...
@@ -24,7 +24,7 @@
#include <iostream>
#include <vector>
#include <migraphx/gpu/fuse_mlir.hpp>
#include <migraphx/
operators
.hpp>
#include <migraphx/
make_op
.hpp>
#include <migraphx/instruction.hpp>
#include <migraphx/quantization.hpp>
#include <migraphx/generate.hpp>
...
...
@@ -90,7 +90,7 @@ TEST_CASE(int8_quantization)
migraphx
::
shape
sc
{
migraphx
::
shape
::
float_type
,
{
5
,
8
}};
auto
pa
=
mm
->
add_parameter
(
"a"
,
sa
);
auto
pb
=
mm
->
add_parameter
(
"b"
,
sb
);
mm
->
add_instruction
(
migraphx
::
op
::
dot
{}
,
pa
,
pb
);
mm
->
add_instruction
(
migraphx
::
make_op
(
"
dot
"
)
,
pa
,
pb
);
return
p
;
};
...
...
test/inline_module_test.cpp
View file @
e6290061
...
...
@@ -26,7 +26,6 @@
#include <migraphx/pass_manager.hpp>
#include <migraphx/instruction.hpp>
#include <basic_ops.hpp>
#include <migraphx/operators.hpp>
#include <migraphx/make_op.hpp>
#include <test.hpp>
...
...
test/insert_pad_test.cpp
View file @
e6290061
...
...
@@ -26,8 +26,8 @@
#include <migraphx/insert_pad.hpp>
#include <migraphx/pass_manager.hpp>
#include <migraphx/instruction.hpp>
#include <migraphx/op/common.hpp>
#include <basic_ops.hpp>
#include <migraphx/operators.hpp>
#include <migraphx/make_op.hpp>
#include <test.hpp>
...
...
@@ -58,10 +58,11 @@ create_conv(migraphx::instruction_ref& l_img,
migraphx
::
shape
s_weights
{
migraphx
::
shape
::
int32_type
,
{
4
,
channels
,
3
,
3
}};
std
::
vector
<
int32_t
>
weights
(
4
*
channels
*
3
*
3
);
auto
l_weights
=
m
.
add_literal
(
migraphx
::
literal
{
s_weights
,
weights
});
migraphx
::
op
::
convolution
op
;
op
.
padding_mode
=
padding_mode
;
op
.
padding
=
{
0
,
0
,
1
,
1
};
return
m
.
add_instruction
(
op
,
l_img
,
l_weights
);
return
m
.
add_instruction
(
migraphx
::
make_op
(
"convolution"
,
{{
"padding_mode"
,
padding_mode
},
{
"padding"
,
{
0
,
0
,
1
,
1
}}}),
l_img
,
l_weights
);
}
TEST_CASE
(
rewrite_pad
)
...
...
test/layout_nhwc.cpp
View file @
e6290061
...
...
@@ -24,7 +24,6 @@
#include <migraphx/layout_nhwc.hpp>
#include <migraphx/dead_code_elimination.hpp>
#include <migraphx/pass_manager.hpp>
#include <migraphx/operators.hpp>
#include <migraphx/generate.hpp>
#include <migraphx/ranges.hpp>
#include <migraphx/instruction.hpp>
...
...
test/onnx/onnx_rnn_test.cpp
View file @
e6290061
...
...
@@ -24,7 +24,7 @@
#include <iostream>
#include <vector>
#include <migraphx/literal.hpp>
#include <migraphx/op
erators
.hpp>
#include <migraphx/op
/common
.hpp>
#include <migraphx/program.hpp>
#include <migraphx/instruction.hpp>
#include <migraphx/pass_manager.hpp>
...
...
test/onnx/verify_onnx.cpp
View file @
e6290061
...
...
@@ -24,7 +24,6 @@
#include <iostream>
#include <vector>
#include <migraphx/literal.hpp>
#include <migraphx/operators.hpp>
#include <migraphx/program.hpp>
#include <migraphx/register_target.hpp>
#include <migraphx/pass_manager.hpp>
...
...
test/op_shape_test.cpp
View file @
e6290061
...
...
@@ -24,7 +24,8 @@
#include <migraphx/program.hpp>
#include <migraphx/iterator_for.hpp>
#include <migraphx/instruction.hpp>
#include <migraphx/operators.hpp>
#include <migraphx/permutation.hpp>
#include <migraphx/op/common.hpp>
#include <sstream>
#include <migraphx/make_op.hpp>
...
...
@@ -156,13 +157,13 @@ TEST_CASE(broadcast)
{
std
::
vector
<
std
::
size_t
>
lens
{
1
,
1
};
migraphx
::
shape
input
{
migraphx
::
shape
::
float_type
,
{
2
}};
throws_shape
(
migraphx
::
op
::
broadcast
{
1
,
lens
},
input
);
throws_shape
(
migraphx
::
make_op
(
"broadcast"
,
{{
"axis"
,
1
},
{
"out_lens"
,
lens
}
})
,
input
);
}
{
std
::
vector
<
std
::
size_t
>
lens
{
2
,
2
};
migraphx
::
shape
input
{
migraphx
::
shape
::
float_type
,
{
1
,
2
}};
throws_shape
(
migraphx
::
op
::
broadcast
{
1
,
lens
},
input
);
throws_shape
(
migraphx
::
make_op
(
"broadcast"
,
{{
"axis"
,
1
},
{
"out_lens"
,
lens
}
})
,
input
);
}
{
...
...
@@ -1252,36 +1253,45 @@ TEST_CASE(inconsistent_attr_shape)
input
);
}
template
<
class
T
>
void
test_softmax_variations
()
void
test_softmax_variations
(
const
std
::
string
&
name
)
{
{
migraphx
::
shape
input
{
migraphx
::
shape
::
float_type
,
{
2
,
3
,
4
,
5
}};
expect_shape
(
migraphx
::
shape
{
migraphx
::
shape
::
float_type
,
{
2
,
3
,
4
,
5
}},
T
{
0
},
input
);
expect_shape
(
migraphx
::
shape
{
migraphx
::
shape
::
float_type
,
{
2
,
3
,
4
,
5
}},
migraphx
::
make_op
(
name
,
{{
"axis"
,
0
}}),
input
);
}
{
migraphx
::
shape
input
{
migraphx
::
shape
::
float_type
,
{
2
,
3
,
4
,
5
}};
expect_shape
(
migraphx
::
shape
{
migraphx
::
shape
::
float_type
,
{
2
,
3
,
4
,
5
}},
T
{
1
},
input
);
expect_shape
(
migraphx
::
shape
{
migraphx
::
shape
::
float_type
,
{
2
,
3
,
4
,
5
}},
migraphx
::
make_op
(
name
,
{{
"axis"
,
1
}}),
input
);
}
{
migraphx
::
shape
input
{
migraphx
::
shape
::
float_type
,
{
2
,
3
,
4
,
5
}};
expect_shape
(
migraphx
::
shape
{
migraphx
::
shape
::
float_type
,
{
2
,
3
,
4
,
5
}},
T
{
2
},
input
);
expect_shape
(
migraphx
::
shape
{
migraphx
::
shape
::
float_type
,
{
2
,
3
,
4
,
5
}},
migraphx
::
make_op
(
name
,
{{
"axis"
,
2
}}),
input
);
}
{
migraphx
::
shape
input
{
migraphx
::
shape
::
float_type
,
{
2
,
3
,
4
,
5
}};
expect_shape
(
migraphx
::
shape
{
migraphx
::
shape
::
float_type
,
{
2
,
3
,
4
,
5
}},
T
{
3
},
input
);
expect_shape
(
migraphx
::
shape
{
migraphx
::
shape
::
float_type
,
{
2
,
3
,
4
,
5
}},
migraphx
::
make_op
(
name
,
{{
"axis"
,
3
}}),
input
);
}
{
migraphx
::
shape
input
{
migraphx
::
shape
::
float_type
,
{
2
,
3
,
4
,
5
}};
int
axis
=
4
;
throws_shape
(
T
{
axis
},
input
);
throws_shape
(
migraphx
::
make_op
(
name
,
{{
"axis"
,
axis
}
})
,
input
);
}
}
TEST_CASE
(
logsoftmax
)
{
test_softmax_variations
<
migraphx
::
op
::
logsoftmax
>
();
}
TEST_CASE
(
logsoftmax
)
{
test_softmax_variations
(
"logsoftmax"
);
}
TEST_CASE
(
softmax
)
{
test_softmax_variations
(
"softmax"
);
}
TEST_CASE
(
lstm
)
{
...
...
@@ -2328,47 +2338,54 @@ TEST_CASE(dqlinear_mismatch_type)
throws_shape
(
migraphx
::
make_op
(
"dequantizelinear"
),
input
,
scales
,
zeros
);
}
template
<
class
T
>
void
test_reduce_ops
()
void
test_reduce_ops
(
const
std
::
string
&
name
)
{
{
migraphx
::
shape
input
{
migraphx
::
shape
::
float_type
,
{
2
,
3
,
4
,
5
}};
expect_shape
(
migraphx
::
shape
{
migraphx
::
shape
::
float_type
,
{
1
,
1
,
1
,
1
}},
T
{},
input
);
expect_shape
(
migraphx
::
shape
{
migraphx
::
shape
::
float_type
,
{
1
,
1
,
1
,
1
}},
migraphx
::
make_op
(
name
),
input
);
}
{
migraphx
::
shape
input
{
migraphx
::
shape
::
float_type
,
{
2
,
3
,
4
,
5
}};
expect_shape
(
migraphx
::
shape
{
migraphx
::
shape
::
float_type
,
{
1
,
1
,
1
,
1
}},
T
{{
0
,
1
,
2
,
3
}},
input
);
expect_shape
(
migraphx
::
shape
{
migraphx
::
shape
::
float_type
,
{
1
,
1
,
1
,
1
}},
migraphx
::
make_op
(
name
,
{{
"axes"
,
{
0
,
1
,
2
,
3
}}}),
input
);
}
{
migraphx
::
shape
input
{
migraphx
::
shape
::
float_type
,
{
2
,
3
,
4
,
5
}};
expect_shape
(
migraphx
::
shape
{
migraphx
::
shape
::
float_type
,
{
2
,
3
,
1
,
1
}},
T
{{
2
,
3
}},
input
);
expect_shape
(
migraphx
::
shape
{
migraphx
::
shape
::
float_type
,
{
2
,
3
,
1
,
1
}},
migraphx
::
make_op
(
name
,
{{
"axes"
,
{
2
,
3
}}}),
input
);
}
{
migraphx
::
shape
input
{
migraphx
::
shape
::
float_type
,
{
2
,
3
,
4
,
5
}};
expect_shape
(
migraphx
::
shape
{
migraphx
::
shape
::
float_type
,
{
1
,
3
,
4
,
5
}},
T
{{
0
}},
input
);
expect_shape
(
migraphx
::
shape
{
migraphx
::
shape
::
float_type
,
{
1
,
3
,
4
,
5
}},
migraphx
::
make_op
(
name
,
{{
"axes"
,
{
0
}}}),
input
);
}
{
migraphx
::
shape
input
{
migraphx
::
shape
::
float_type
,
{
2
,
3
,
4
,
5
}};
expect_shape
(
migraphx
::
shape
{
migraphx
::
shape
::
float_type
,
{
2
,
3
,
4
,
1
}},
T
{{
-
1
}},
input
);
expect_shape
(
migraphx
::
shape
{
migraphx
::
shape
::
float_type
,
{
2
,
3
,
4
,
1
}},
migraphx
::
make_op
(
name
,
{{
"axes"
,
{
-
1
}}}),
input
);
}
{
migraphx
::
shape
input
{
migraphx
::
shape
::
float_type
,
{
2
,
3
,
4
,
5
}};
throws_shape
(
T
{
{
4
}},
input
);
throws_shape
(
migraphx
::
make_op
(
name
,
{{
"axes"
,
{
4
}}
})
,
input
);
}
}
// dynamic shape
template
<
class
T
>
void
test_dyn_reduce_ops
()
void
test_dyn_reduce_ops
(
const
std
::
string
&
name
)
{
{
migraphx
::
shape
input
{
migraphx
::
shape
::
float_type
,
{{
2
,
3
,
{
3
}},
{
2
,
4
,
{
4
}}}};
expect_shape
(
migraphx
::
shape
{
migraphx
::
shape
::
float_type
,
std
::
vector
<
migraphx
::
shape
::
dynamic_dimension
>
({{
2
,
3
,
{
3
}},
{
1
,
1
}})},
T
{
{
-
1
}},
migraphx
::
make_op
(
name
,
{{
"axes"
,
{
-
1
}}
})
,
input
);
}
{
...
...
@@ -2376,7 +2393,7 @@ void test_dyn_reduce_ops()
expect_shape
(
migraphx
::
shape
{
migraphx
::
shape
::
float_type
,
std
::
vector
<
migraphx
::
shape
::
dynamic_dimension
>
({{
1
,
1
},
{
2
,
4
,
{
4
}}})},
T
{
{
0
}},
migraphx
::
make_op
(
name
,
{{
"axes"
,
{
0
}}
})
,
input
);
}
{
...
...
@@ -2385,24 +2402,24 @@ void test_dyn_reduce_ops()
expect_shape
(
migraphx
::
shape
{
migraphx
::
shape
::
float_type
,
std
::
vector
<
migraphx
::
shape
::
dynamic_dimension
>
({{
1
,
1
},
{
1
,
1
}})},
T
{{}}
,
migraphx
::
make_op
(
name
)
,
input
);
}
{
migraphx
::
shape
input
{
migraphx
::
shape
::
float_type
,
{{
2
,
3
,
{
3
}},
{
2
,
4
,
{
4
}}}};
throws_shape
(
T
{
{
4
}},
input
);
throws_shape
(
migraphx
::
make_op
(
name
,
{{
"axes"
,
{
4
}}
})
,
input
);
}
}
TEST_CASE
(
reduce_max
)
{
test_reduce_ops
<
migraphx
::
op
::
reduce_max
>
(
);
}
TEST_CASE
(
reduce_mean
)
{
test_reduce_ops
<
migraphx
::
op
::
reduce_mean
>
(
);
}
TEST_CASE
(
reduce_prod
)
{
test_reduce_ops
<
migraphx
::
op
::
reduce_prod
>
(
);
}
TEST_CASE
(
reduce_sum
)
{
test_reduce_ops
<
migraphx
::
op
::
reduce_sum
>
(
);
}
TEST_CASE
(
reduce_max
)
{
test_reduce_ops
(
"
reduce_max
"
);
}
TEST_CASE
(
reduce_mean
)
{
test_reduce_ops
(
"
reduce_mean
"
);
}
TEST_CASE
(
reduce_prod
)
{
test_reduce_ops
(
"
reduce_prod
"
);
}
TEST_CASE
(
reduce_sum
)
{
test_reduce_ops
(
"
reduce_sum
"
);
}
TEST_CASE
(
reduce_max_dyn
)
{
test_dyn_reduce_ops
<
migraphx
::
op
::
reduce_max
>
(
);
}
TEST_CASE
(
reduce_mean_dyn
)
{
test_dyn_reduce_ops
<
migraphx
::
op
::
reduce_mean
>
(
);
}
TEST_CASE
(
reduce_prod_dyn
)
{
test_dyn_reduce_ops
<
migraphx
::
op
::
reduce_prod
>
(
);
}
TEST_CASE
(
reduce_sum_dyn
)
{
test_dyn_reduce_ops
<
migraphx
::
op
::
reduce_sum
>
(
);
}
TEST_CASE
(
reduce_max_dyn
)
{
test_dyn_reduce_ops
(
"
reduce_max
"
);
}
TEST_CASE
(
reduce_mean_dyn
)
{
test_dyn_reduce_ops
(
"
reduce_mean
"
);
}
TEST_CASE
(
reduce_prod_dyn
)
{
test_dyn_reduce_ops
(
"
reduce_prod
"
);
}
TEST_CASE
(
reduce_sum_dyn
)
{
test_dyn_reduce_ops
(
"
reduce_sum
"
);
}
TEST_CASE
(
reshape_shape
)
{
...
...
@@ -2962,8 +2979,6 @@ TEST_CASE(slice_dyn_shape5)
input
);
}
TEST_CASE
(
softmax
)
{
test_softmax_variations
<
migraphx
::
op
::
softmax
>
();
}
TEST_CASE
(
softmax_dyn0
)
{
migraphx
::
shape
input
{
migraphx
::
shape
::
float_type
,
{{
1
,
4
},
{
3
,
3
},
{
4
,
4
},
{
5
,
5
}}};
...
...
test/pad_calc_test.cpp
View file @
e6290061
...
...
@@ -22,7 +22,6 @@
* THE SOFTWARE.
*/
#include <migraphx/program.hpp>
#include <migraphx/operators.hpp>
#include <migraphx/pad_calc.hpp>
#include "test.hpp"
...
...
test/quantization.cpp
View file @
e6290061
...
...
@@ -24,7 +24,6 @@
#include <iostream>
#include <vector>
#include <migraphx/literal.hpp>
#include <migraphx/operators.hpp>
#include <migraphx/instruction.hpp>
#include <migraphx/generate.hpp>
#include <migraphx/register_target.hpp>
...
...
test/simplify_algebra_test.cpp
View file @
e6290061
...
...
@@ -24,7 +24,7 @@
#include <migraphx/simplify_algebra.hpp>
#include <migraphx/dead_code_elimination.hpp>
#include <migraphx/pass_manager.hpp>
#include <migraphx/
o
per
ators
.hpp>
#include <migraphx/per
mutation
.hpp>
#include <migraphx/generate.hpp>
#include <migraphx/ranges.hpp>
#include <migraphx/instruction.hpp>
...
...
@@ -153,7 +153,7 @@ TEST_CASE(simplify_add_broadcast1)
{
migraphx
::
shape
inner
{
migraphx
::
shape
::
int32_type
,
{
2
}};
migraphx
::
shape
outer
{
migraphx
::
shape
::
int32_type
,
{
1
,
2
,
3
,
3
}};
migraphx
::
op
::
broadcast
b
{
1
,
{
1
,
2
,
3
,
3
}};
auto
b
=
migraphx
::
make_op
(
"broadcast"
,
{{
"axis"
,
1
},
{
"out_lens"
,
{
1
,
2
,
3
,
3
}}
})
;
migraphx
::
module
m1
;
{
auto
x
=
m1
.
add_parameter
(
"x"
,
outer
);
...
...
@@ -188,7 +188,7 @@ TEST_CASE(simplify_add_broadcast2)
{
migraphx
::
shape
inner
{
migraphx
::
shape
::
int32_type
,
{
2
}};
migraphx
::
shape
outer
{
migraphx
::
shape
::
int32_type
,
{
1
,
2
,
3
,
3
}};
migraphx
::
op
::
broadcast
b
{
1
,
{
1
,
2
,
3
,
3
}};
auto
b
=
migraphx
::
make_op
(
"broadcast"
,
{{
"axis"
,
1
},
{
"out_lens"
,
{
1
,
2
,
3
,
3
}}
})
;
auto
create_program
=
[
&
]
{
migraphx
::
module
m
;
auto
x
=
m
.
add_parameter
(
"x"
,
outer
);
...
...
@@ -539,7 +539,7 @@ TEST_CASE(simplify_conv_add)
TEST_CASE
(
simplify_inner_broadcast1
)
{
auto
b
=
migraphx
::
op
::
broadcast
{
1
,
{
2
,
1
,
4
,
5
}};
auto
b
=
migraphx
::
make_op
(
"broadcast"
,
{{
"axis"
,
1
},
{
"out_lens"
,
{
2
,
1
,
4
,
5
}}
})
;
migraphx
::
module
m1
;
{
auto
x
=
m1
.
add_parameter
(
"x"
,
{
migraphx
::
shape
::
int32_type
,
{
1
}});
...
...
@@ -564,7 +564,7 @@ TEST_CASE(simplify_inner_broadcast1)
TEST_CASE
(
simplify_inner_broadcast2
)
{
auto
b
=
migraphx
::
op
::
multibroadcast
{
{
2
,
1
,
4
,
5
}};
auto
b
=
migraphx
::
make_op
(
"
multibroadcast
"
,
{{
"out_lens"
,
{
2
,
1
,
4
,
5
}}
})
;
migraphx
::
module
m1
;
{
auto
x
=
m1
.
add_parameter
(
"x"
,
{
migraphx
::
shape
::
int32_type
,
{
1
,
1
,
1
,
1
}});
...
...
@@ -589,7 +589,7 @@ TEST_CASE(simplify_inner_broadcast2)
TEST_CASE
(
simplify_inner_broadcast_scalar
)
{
auto
b
=
migraphx
::
op
::
multibroadcast
{
{
32
,
384
}};
auto
b
=
migraphx
::
make_op
(
"
multibroadcast
"
,
{{
"out_lens"
,
{
32
,
384
}}
})
;
migraphx
::
module
m1
;
{
auto
x
=
m1
.
add_parameter
(
"x"
,
{
migraphx
::
shape
::
int32_type
,
{
1
,
384
}});
...
...
@@ -605,7 +605,8 @@ TEST_CASE(simplify_inner_broadcast_scalar)
{
auto
x
=
m2
.
add_parameter
(
"x"
,
{
migraphx
::
shape
::
int32_type
,
{
1
,
384
}});
auto
y
=
m2
.
add_parameter
(
"y"
,
{
migraphx
::
shape
::
int32_type
,
{
1
,
1
}});
auto
yb
=
m2
.
add_instruction
(
migraphx
::
op
::
multibroadcast
{{
1
,
384
}},
y
);
auto
yb
=
m2
.
add_instruction
(
migraphx
::
make_op
(
"multibroadcast"
,
{{
"out_lens"
,
{
1
,
384
}}}),
y
);
auto
sum
=
m2
.
add_instruction
(
migraphx
::
make_op
(
"add"
),
x
,
yb
);
auto
sumb
=
m2
.
add_instruction
(
b
,
sum
);
m2
.
add_instruction
(
pass_op
{},
sumb
);
...
...
@@ -615,7 +616,7 @@ TEST_CASE(simplify_inner_broadcast_scalar)
TEST_CASE
(
simplify_inner_broadcast_different_dims
)
{
auto
b
=
migraphx
::
op
::
multibroadcast
{
{
2
,
384
,
768
}};
auto
b
=
migraphx
::
make_op
(
"
multibroadcast
"
,
{{
"out_lens"
,
{
2
,
384
,
768
}}
})
;
migraphx
::
module
m1
;
{
auto
x
=
m1
.
add_parameter
(
"x"
,
{
migraphx
::
shape
::
int32_type
,
{
384
,
768
}});
...
...
@@ -631,7 +632,8 @@ TEST_CASE(simplify_inner_broadcast_different_dims)
{
auto
x
=
m2
.
add_parameter
(
"x"
,
{
migraphx
::
shape
::
int32_type
,
{
384
,
768
}});
auto
y
=
m2
.
add_parameter
(
"y"
,
{
migraphx
::
shape
::
int32_type
,
{
768
}});
auto
yb
=
m2
.
add_instruction
(
migraphx
::
op
::
multibroadcast
{{
384
,
768
}},
y
);
auto
yb
=
m2
.
add_instruction
(
migraphx
::
make_op
(
"multibroadcast"
,
{{
"out_lens"
,
{
384
,
768
}}}),
y
);
auto
sum
=
m2
.
add_instruction
(
migraphx
::
make_op
(
"add"
),
x
,
yb
);
auto
sumb
=
m2
.
add_instruction
(
b
,
sum
);
m2
.
add_instruction
(
pass_op
{},
sumb
);
...
...
@@ -641,8 +643,8 @@ TEST_CASE(simplify_inner_broadcast_different_dims)
TEST_CASE
(
simplify_inner_broadcast_different_broadcasts
)
{
auto
b
=
migraphx
::
op
::
broadcast
{
1
,
{
1
,
24
,
112
,
112
}};
auto
mb
=
migraphx
::
op
::
multibroadcast
{
{
1
,
24
,
112
,
112
}};
auto
b
=
migraphx
::
make_op
(
"broadcast"
,
{{
"axis"
,
1
},
{
"out_lens"
,
{
1
,
24
,
112
,
112
}}
})
;
auto
mb
=
migraphx
::
make_op
(
"
multibroadcast
"
,
{{
"out_lens"
,
{
1
,
24
,
112
,
112
}}
})
;
migraphx
::
module
m1
;
{
auto
x
=
m1
.
add_parameter
(
"x"
,
{
migraphx
::
shape
::
int32_type
,
{
24
}});
...
...
@@ -891,7 +893,7 @@ TEST_CASE(simplify_concat_add_relu_partial_broadcast)
auto
s
=
migraphx
::
shape
{
migraphx
::
shape
::
int32_type
,
{
2
,
1
,
4
,
5
}};
migraphx
::
module
m1
;
{
auto
b
=
migraphx
::
op
::
broadcast
{
1
,
{
2
,
1
,
4
,
5
}};
auto
b
=
migraphx
::
make_op
(
"broadcast"
,
{{
"axis"
,
1
},
{
"out_lens"
,
{
2
,
1
,
4
,
5
}}
})
;
auto
x
=
m1
.
add_parameter
(
"x"
,
s
);
auto
y
=
m1
.
add_parameter
(
"y"
,
s
);
auto
one
=
m1
.
add_literal
(
1
);
...
...
@@ -907,7 +909,7 @@ TEST_CASE(simplify_concat_add_relu_partial_broadcast)
migraphx
::
module
m2
;
{
auto
b
=
migraphx
::
op
::
broadcast
{
1
,
{
2
,
2
,
4
,
5
}};
auto
b
=
migraphx
::
make_op
(
"broadcast"
,
{{
"axis"
,
1
},
{
"out_lens"
,
{
2
,
2
,
4
,
5
}}
})
;
auto
x
=
m2
.
add_parameter
(
"x"
,
s
);
auto
y
=
m2
.
add_parameter
(
"y"
,
s
);
auto
one
=
m2
.
add_literal
(
1
);
...
...
@@ -926,7 +928,7 @@ TEST_CASE(simplify_concat_add_relu_broadcast_different_axis)
auto
s
=
migraphx
::
shape
{
migraphx
::
shape
::
int32_type
,
{
2
,
1
,
4
,
5
}};
migraphx
::
module
m1
;
{
auto
b
=
migraphx
::
op
::
broadcast
{
1
,
{
2
,
1
,
4
,
5
}};
auto
b
=
migraphx
::
make_op
(
"broadcast"
,
{{
"axis"
,
1
},
{
"out_lens"
,
{
2
,
1
,
4
,
5
}}
})
;
auto
x
=
m1
.
add_parameter
(
"x"
,
s
);
auto
y
=
m1
.
add_parameter
(
"y"
,
s
);
auto
one
=
m1
.
add_literal
(
1
);
...
...
@@ -944,7 +946,7 @@ TEST_CASE(simplify_concat_add_relu_broadcast_different_axis)
migraphx
::
module
m2
;
{
auto
b
=
migraphx
::
op
::
broadcast
{
1
,
{
2
,
2
,
4
,
5
}};
auto
b
=
migraphx
::
make_op
(
"broadcast"
,
{{
"axis"
,
1
},
{
"out_lens"
,
{
2
,
2
,
4
,
5
}}
})
;
auto
x
=
m2
.
add_parameter
(
"x"
,
s
);
auto
y
=
m2
.
add_parameter
(
"y"
,
s
);
auto
one
=
m2
.
add_literal
(
1
);
...
...
@@ -964,7 +966,7 @@ TEST_CASE(simplify_concat_add_relu_broadcast_same_axis)
auto
s
=
migraphx
::
shape
{
migraphx
::
shape
::
int32_type
,
{
2
,
1
,
4
,
5
}};
migraphx
::
module
m1
;
{
auto
b
=
migraphx
::
op
::
broadcast
{
1
,
{
2
,
1
,
4
,
5
}};
auto
b
=
migraphx
::
make_op
(
"broadcast"
,
{{
"axis"
,
1
},
{
"out_lens"
,
{
2
,
1
,
4
,
5
}}
})
;
auto
x
=
m1
.
add_parameter
(
"x"
,
s
);
auto
y
=
m1
.
add_parameter
(
"y"
,
s
);
auto
one
=
m1
.
add_literal
(
1
);
...
...
@@ -982,7 +984,7 @@ TEST_CASE(simplify_concat_add_relu_broadcast_same_axis)
migraphx
::
module
m2
;
{
auto
b
=
migraphx
::
op
::
broadcast
{
1
,
{
2
,
1
,
4
,
5
}};
auto
b
=
migraphx
::
make_op
(
"broadcast"
,
{{
"axis"
,
1
},
{
"out_lens"
,
{
2
,
1
,
4
,
5
}}
})
;
auto
x
=
m2
.
add_parameter
(
"x"
,
s
);
auto
y
=
m2
.
add_parameter
(
"y"
,
s
);
auto
one
=
m2
.
add_literal
(
1
);
...
...
@@ -1695,7 +1697,7 @@ TEST_CASE(simplify_split_add_relu)
auto
s
=
migraphx
::
shape
{
migraphx
::
shape
::
int32_type
,
{
3
,
2
,
4
}};
migraphx
::
module
m1
;
{
auto
b
=
migraphx
::
op
::
broadcast
{
1
,
{
3
,
1
,
4
}};
auto
b
=
migraphx
::
make_op
(
"broadcast"
,
{{
"axis"
,
1
},
{
"out_lens"
,
{
3
,
1
,
4
}}
})
;
auto
input
=
m1
.
add_parameter
(
"input"
,
s
);
auto
x
=
m1
.
add_instruction
(
migraphx
::
make_op
(
"slice"
,
{{
"axes"
,
{
1
}},
{
"starts"
,
{
0
}},
{
"ends"
,
{
1
}}}),
input
);
...
...
@@ -1716,7 +1718,7 @@ TEST_CASE(simplify_split_add_relu)
migraphx
::
module
m2
;
{
auto
b
=
migraphx
::
op
::
broadcast
{
1
,
{
3
,
2
,
4
}};
auto
b
=
migraphx
::
make_op
(
"broadcast"
,
{{
"axis"
,
1
},
{
"out_lens"
,
{
3
,
2
,
4
}}
})
;
auto
input
=
m2
.
add_parameter
(
"input"
,
s
);
auto
one
=
m2
.
add_literal
(
1
);
auto
two
=
m2
.
add_literal
(
2
);
...
...
@@ -1846,8 +1848,8 @@ TEST_CASE(simplify_split_add_relu_reshape)
auto
s
=
migraphx
::
shape
{
migraphx
::
shape
::
int32_type
,
{
3
,
2
,
4
}};
migraphx
::
module
m1
;
{
auto
b
=
migraphx
::
op
::
broadcast
{
1
,
{
3
,
1
,
4
}};
auto
r
=
migraphx
::
op
::
reshape
{
{
3
,
4
}};
auto
b
=
migraphx
::
make_op
(
"broadcast"
,
{{
"axis"
,
1
},
{
"out_lens"
,
{
3
,
1
,
4
}}
})
;
auto
r
=
migraphx
::
make_op
(
"reshape"
,
{{
"dims"
,
{
3
,
4
}}
})
;
auto
input
=
m1
.
add_parameter
(
"input"
,
s
);
auto
x
=
m1
.
add_instruction
(
migraphx
::
make_op
(
"slice"
,
{{
"axes"
,
{
1
}},
{
"starts"
,
{
0
}},
{
"ends"
,
{
1
}}}),
input
);
...
...
@@ -1870,7 +1872,7 @@ TEST_CASE(simplify_split_add_relu_reshape)
migraphx
::
module
m2
;
{
auto
b
=
migraphx
::
op
::
broadcast
{
1
,
{
3
,
2
,
4
}};
auto
b
=
migraphx
::
make_op
(
"broadcast"
,
{{
"axis"
,
1
},
{
"out_lens"
,
{
3
,
2
,
4
}}
})
;
auto
input
=
m2
.
add_parameter
(
"input"
,
s
);
auto
one
=
m2
.
add_literal
(
1
);
auto
two
=
m2
.
add_literal
(
2
);
...
...
@@ -1894,7 +1896,7 @@ TEST_CASE(simplify_slice_different_axis)
auto
s
=
migraphx
::
shape
{
migraphx
::
shape
::
int32_type
,
{
3
,
2
,
4
,
2
}};
migraphx
::
module
m1
;
{
auto
r
=
migraphx
::
op
::
reshape
{
{
3
,
2
,
4
}};
auto
r
=
migraphx
::
make_op
(
"reshape"
,
{{
"dims"
,
{
3
,
2
,
4
}}
})
;
auto
input
=
m1
.
add_parameter
(
"input"
,
s
);
auto
x
=
m1
.
add_instruction
(
migraphx
::
make_op
(
"slice"
,
{{
"axes"
,
{
1
}},
{
"starts"
,
{
0
}},
{
"ends"
,
{
1
}}}),
input
);
...
...
@@ -1926,7 +1928,7 @@ TEST_CASE(simplify_slice_missing_begining_slice)
auto
s
=
migraphx
::
shape
{
migraphx
::
shape
::
int32_type
,
{
3
,
3
,
4
}};
migraphx
::
module
m1
;
{
auto
b
=
migraphx
::
op
::
broadcast
{
1
,
{
3
,
1
,
4
}};
auto
b
=
migraphx
::
make_op
(
"broadcast"
,
{{
"axis"
,
1
},
{
"out_lens"
,
{
3
,
1
,
4
}}
})
;
auto
input
=
m1
.
add_parameter
(
"input"
,
s
);
auto
x
=
m1
.
add_instruction
(
migraphx
::
make_op
(
"slice"
,
{{
"axes"
,
{
1
}},
{
"starts"
,
{
2
}},
{
"ends"
,
{
3
}}}),
input
);
...
...
@@ -1954,7 +1956,7 @@ TEST_CASE(simplify_slice_missing_middle_slice)
auto
s
=
migraphx
::
shape
{
migraphx
::
shape
::
int32_type
,
{
3
,
3
,
4
}};
migraphx
::
module
m1
;
{
auto
b
=
migraphx
::
op
::
broadcast
{
1
,
{
3
,
1
,
4
}};
auto
b
=
migraphx
::
make_op
(
"broadcast"
,
{{
"axis"
,
1
},
{
"out_lens"
,
{
3
,
1
,
4
}}
})
;
auto
input
=
m1
.
add_parameter
(
"input"
,
s
);
auto
x
=
m1
.
add_instruction
(
migraphx
::
make_op
(
"slice"
,
{{
"axes"
,
{
1
}},
{
"starts"
,
{
2
}},
{
"ends"
,
{
3
}}}),
input
);
...
...
@@ -1982,7 +1984,7 @@ TEST_CASE(simplify_slice_missing_end_slice)
auto
s
=
migraphx
::
shape
{
migraphx
::
shape
::
int32_type
,
{
3
,
3
,
4
}};
migraphx
::
module
m1
;
{
auto
b
=
migraphx
::
op
::
broadcast
{
1
,
{
3
,
1
,
4
}};
auto
b
=
migraphx
::
make_op
(
"broadcast"
,
{{
"axis"
,
1
},
{
"out_lens"
,
{
3
,
1
,
4
}}
})
;
auto
input
=
m1
.
add_parameter
(
"input"
,
s
);
auto
x
=
m1
.
add_instruction
(
migraphx
::
make_op
(
"slice"
,
{{
"axes"
,
{
1
}},
{
"starts"
,
{
0
}},
{
"ends"
,
{
1
}}}),
input
);
...
...
@@ -2010,7 +2012,7 @@ TEST_CASE(simplify_split_add_relu_concat_same_axis)
auto
s
=
migraphx
::
shape
{
migraphx
::
shape
::
int32_type
,
{
3
,
2
,
4
}};
migraphx
::
module
m1
;
{
auto
b
=
migraphx
::
op
::
broadcast
{
1
,
{
3
,
1
,
4
}};
auto
b
=
migraphx
::
make_op
(
"broadcast"
,
{{
"axis"
,
1
},
{
"out_lens"
,
{
3
,
1
,
4
}}
})
;
auto
input
=
m1
.
add_parameter
(
"input"
,
s
);
auto
x
=
m1
.
add_instruction
(
migraphx
::
make_op
(
"slice"
,
{{
"axes"
,
{
1
}},
{
"starts"
,
{
0
}},
{
"ends"
,
{
1
}}}),
input
);
...
...
@@ -2031,7 +2033,7 @@ TEST_CASE(simplify_split_add_relu_concat_same_axis)
migraphx
::
module
m2
;
{
auto
b
=
migraphx
::
op
::
broadcast
{
1
,
{
3
,
2
,
4
}};
auto
b
=
migraphx
::
make_op
(
"broadcast"
,
{{
"axis"
,
1
},
{
"out_lens"
,
{
3
,
2
,
4
}}
})
;
auto
input
=
m2
.
add_parameter
(
"input"
,
s
);
auto
one
=
m2
.
add_literal
(
1
);
auto
two
=
m2
.
add_literal
(
2
);
...
...
@@ -2049,7 +2051,7 @@ TEST_CASE(simplify_split_add_relu_multi_axes)
auto
s
=
migraphx
::
shape
{
migraphx
::
shape
::
int32_type
,
{
3
,
2
,
4
,
6
}};
migraphx
::
module
m1
;
{
auto
b
=
migraphx
::
op
::
broadcast
{
1
,
{
3
,
1
,
4
,
3
}};
auto
b
=
migraphx
::
make_op
(
"broadcast"
,
{{
"axis"
,
1
},
{
"out_lens"
,
{
3
,
1
,
4
,
3
}}
})
;
auto
input
=
m1
.
add_parameter
(
"input"
,
s
);
auto
x
=
m1
.
add_instruction
(
migraphx
::
make_op
(
"slice"
,
{{
"axes"
,
{
1
,
3
}},
{
"starts"
,
{
0
,
0
}},
{
"ends"
,
{
1
,
3
}}}),
...
...
@@ -2078,7 +2080,7 @@ TEST_CASE(simplify_split_add_relu_used_multiple_split1)
auto
s
=
migraphx
::
shape
{
migraphx
::
shape
::
int32_type
,
{
3
,
2
,
4
}};
migraphx
::
module
m1
;
{
auto
b
=
migraphx
::
op
::
broadcast
{
1
,
{
3
,
1
,
4
}};
auto
b
=
migraphx
::
make_op
(
"broadcast"
,
{{
"axis"
,
1
},
{
"out_lens"
,
{
3
,
1
,
4
}}
})
;
auto
input
=
m1
.
add_parameter
(
"input"
,
s
);
auto
x
=
m1
.
add_instruction
(
migraphx
::
make_op
(
"slice"
,
{{
"axes"
,
{
1
}},
{
"starts"
,
{
0
}},
{
"ends"
,
{
1
}}}),
input
);
...
...
@@ -2100,7 +2102,7 @@ TEST_CASE(simplify_split_add_relu_used_multiple_split1)
migraphx
::
module
m2
;
{
auto
b
=
migraphx
::
op
::
broadcast
{
1
,
{
3
,
2
,
4
}};
auto
b
=
migraphx
::
make_op
(
"broadcast"
,
{{
"axis"
,
1
},
{
"out_lens"
,
{
3
,
2
,
4
}}
})
;
auto
input
=
m2
.
add_parameter
(
"input"
,
s
);
auto
slice
=
m2
.
add_instruction
(
migraphx
::
make_op
(
"slice"
,
{{
"axes"
,
{
1
}},
{
"starts"
,
{
0
}},
{
"ends"
,
{
1
}}}),
input
);
...
...
@@ -2126,7 +2128,7 @@ TEST_CASE(simplify_split_add_relu_used_multiple_split2)
auto
s
=
migraphx
::
shape
{
migraphx
::
shape
::
int32_type
,
{
3
,
2
,
4
}};
migraphx
::
module
m1
;
{
auto
b
=
migraphx
::
op
::
broadcast
{
1
,
{
3
,
1
,
4
}};
auto
b
=
migraphx
::
make_op
(
"broadcast"
,
{{
"axis"
,
1
},
{
"out_lens"
,
{
3
,
1
,
4
}}
})
;
auto
input
=
m1
.
add_parameter
(
"input"
,
s
);
auto
x
=
m1
.
add_instruction
(
migraphx
::
make_op
(
"slice"
,
{{
"axes"
,
{
1
}},
{
"starts"
,
{
0
}},
{
"ends"
,
{
1
}}}),
input
);
...
...
@@ -2149,7 +2151,7 @@ TEST_CASE(simplify_split_add_relu_used_multiple_split2)
migraphx
::
module
m2
;
{
auto
b
=
migraphx
::
op
::
broadcast
{
1
,
{
3
,
2
,
4
}};
auto
b
=
migraphx
::
make_op
(
"broadcast"
,
{{
"axis"
,
1
},
{
"out_lens"
,
{
3
,
2
,
4
}}
})
;
auto
input
=
m2
.
add_parameter
(
"input"
,
s
);
auto
slice
=
m2
.
add_instruction
(
migraphx
::
make_op
(
"slice"
,
{{
"axes"
,
{
1
}},
{
"starts"
,
{
0
}},
{
"ends"
,
{
1
}}}),
input
);
...
...
test/simplify_reshapes_test.cpp
View file @
e6290061
...
...
@@ -24,7 +24,6 @@
#include <migraphx/simplify_reshapes.hpp>
#include <migraphx/dead_code_elimination.hpp>
#include <migraphx/pass_manager.hpp>
#include <migraphx/operators.hpp>
#include <migraphx/instruction.hpp>
#include <migraphx/generate.hpp>
#include <basic_ops.hpp>
...
...
@@ -477,7 +476,7 @@ TEST_CASE(concat_multibroadcasts1)
std
::
find_if
(
m
.
begin
(),
m
.
end
(),
[](
auto
ins
)
{
return
ins
.
name
()
==
"multibroadcast"
;
});
auto
md
=
std
::
distance
(
m
.
begin
(),
new_mb
);
EXPECT
(
cd
==
md
-
1
);
EXPECT
(
migraphx
::
any_cast
<
migraphx
::
op
::
concat
>
(
new_concat
->
get_operator
()).
axis
==
1
);
EXPECT
(
new_concat
->
get_operator
().
to_value
()[
"axis"
].
to
<
int
>
()
==
1
);
}
TEST_CASE
(
concat_multibroadcasts2
)
...
...
@@ -500,7 +499,7 @@ TEST_CASE(concat_multibroadcasts2)
std
::
find_if
(
m
.
begin
(),
m
.
end
(),
[](
auto
ins
)
{
return
ins
.
name
()
==
"multibroadcast"
;
});
auto
md
=
std
::
distance
(
m
.
begin
(),
new_mb
);
EXPECT
(
cd
==
md
-
1
);
EXPECT
(
migraphx
::
any_cast
<
migraphx
::
op
::
concat
>
(
new_concat
->
get_operator
()).
axis
==
0
);
EXPECT
(
new_concat
->
get_operator
().
to_value
()[
"axis"
].
to
<
int
>
()
==
0
);
}
TEST_CASE
(
concat_multibroadcasts3
)
...
...
@@ -523,7 +522,7 @@ TEST_CASE(concat_multibroadcasts3)
std
::
find_if
(
m
.
begin
(),
m
.
end
(),
[](
auto
ins
)
{
return
ins
.
name
()
==
"multibroadcast"
;
});
auto
md
=
std
::
distance
(
m
.
begin
(),
new_mb
);
EXPECT
(
cd
==
md
-
1
);
EXPECT
(
migraphx
::
any_cast
<
migraphx
::
op
::
concat
>
(
new_concat
->
get_operator
()).
axis
==
2
);
EXPECT
(
new_concat
->
get_operator
().
to_value
()[
"axis"
].
to
<
int
>
()
==
2
);
}
TEST_CASE
(
concat_multibroadcasts4
)
...
...
@@ -559,7 +558,7 @@ TEST_CASE(concat_transpose1)
auto
new_concat
=
std
::
find_if
(
m
.
begin
(),
m
.
end
(),
[](
auto
ins
)
{
return
ins
.
name
()
==
"concat"
;
});
EXPECT
(
bool
{
new_concat
!=
m
.
end
()});
EXPECT
(
migraphx
::
any_cast
<
migraphx
::
op
::
concat
>
(
new_concat
->
get_operator
()).
axis
==
3
);
EXPECT
(
new_concat
->
get_operator
().
to_value
()[
"axis"
].
to
<
int
>
()
==
3
);
}
TEST_CASE
(
concat_transpose2
)
...
...
@@ -583,7 +582,7 @@ TEST_CASE(concat_transpose2)
auto
new_concat
=
std
::
find_if
(
m
.
begin
(),
m
.
end
(),
[](
auto
ins
)
{
return
ins
.
name
()
==
"concat"
;
});
EXPECT
(
bool
{
new_concat
!=
m
.
end
()});
EXPECT
(
migraphx
::
any_cast
<
migraphx
::
op
::
concat
>
(
new_concat
->
get_operator
()).
axis
==
1
);
EXPECT
(
new_concat
->
get_operator
().
to_value
()[
"axis"
].
to
<
int
>
()
==
1
);
}
TEST_CASE
(
concat_transpose3
)
...
...
@@ -607,7 +606,7 @@ TEST_CASE(concat_transpose3)
auto
new_concat
=
std
::
find_if
(
m
.
begin
(),
m
.
end
(),
[](
auto
ins
)
{
return
ins
.
name
()
==
"concat"
;
});
EXPECT
(
bool
{
new_concat
!=
m
.
end
()});
EXPECT
(
migraphx
::
any_cast
<
migraphx
::
op
::
concat
>
(
new_concat
->
get_operator
()).
axis
==
1
);
EXPECT
(
new_concat
->
get_operator
().
to_value
()[
"axis"
].
to
<
int
>
()
==
1
);
}
TEST_CASE
(
concat_transpose4
)
...
...
test/verify/gemm_literal.cpp
View file @
e6290061
...
...
@@ -25,7 +25,7 @@
#include "verify_program.hpp"
#include <migraphx/program.hpp>
#include <migraphx/generate.hpp>
#include <migraphx/
operators
.hpp>
#include <migraphx/
make_op
.hpp>
struct
gemm_literal
:
verify_program
<
gemm_literal
>
{
...
...
@@ -38,7 +38,7 @@ struct gemm_literal : verify_program<gemm_literal>
auto
a
=
mm
->
add_literal
(
migraphx
::
generate_literal
(
a_shape
));
auto
b
=
mm
->
add_parameter
(
"b"
,
b_shape
);
mm
->
add_instruction
(
migraphx
::
op
::
dot
{}
,
a
,
b
);
mm
->
add_instruction
(
migraphx
::
make_op
(
"
dot
"
)
,
a
,
b
);
return
p
;
}
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment