Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
MIGraphX
Commits
f3a8933c
Commit
f3a8933c
authored
Nov 02, 2023
by
Paul
Browse files
Merge branch 'develop' into blas_tuning
parents
ca300bd6
b249fb8a
Changes
86
Hide whitespace changes
Inline
Side-by-side
Showing
20 changed files
with
486 additions
and
76 deletions
+486
-76
test/onnx/mvn_rank_3_test.onnx
test/onnx/mvn_rank_3_test.onnx
+0
-0
test/onnx/onnx_test.cpp
test/onnx/onnx_test.cpp
+128
-43
test/onnx/tril_batch_diff_k_test.onnx
test/onnx/tril_batch_diff_k_test.onnx
+0
-0
test/onnx/tril_neg_k_test.onnx
test/onnx/tril_neg_k_test.onnx
+0
-0
test/onnx/tril_out_k_test.onnx
test/onnx/tril_out_k_test.onnx
+0
-0
test/onnx/tril_row_one_test.onnx
test/onnx/tril_row_one_test.onnx
+0
-0
test/onnx/tril_test.onnx
test/onnx/tril_test.onnx
+0
-0
test/onnx/triu_batch_diff_k_test.onnx
test/onnx/triu_batch_diff_k_test.onnx
+15
-0
test/onnx/triu_neg_k_test.onnx
test/onnx/triu_neg_k_test.onnx
+13
-0
test/onnx/triu_out_k_test.onnx
test/onnx/triu_out_k_test.onnx
+13
-0
test/onnx/triu_row_one_test.onnx
test/onnx/triu_row_one_test.onnx
+13
-0
test/onnx/triu_test.onnx
test/onnx/triu_test.onnx
+11
-0
test/onnx/verify_onnx.cpp
test/onnx/verify_onnx.cpp
+175
-10
test/op_shape_test.cpp
test/op_shape_test.cpp
+17
-2
test/py/CMakeLists.txt
test/py/CMakeLists.txt
+13
-4
test/py/onnx_backend_test.py
test/py/onnx_backend_test.py
+0
-15
test/py/requirements.txt
test/py/requirements.txt
+1
-1
test/ref/allocate.cpp
test/ref/allocate.cpp
+19
-1
test/ref/argmax.cpp
test/ref/argmax.cpp
+34
-0
test/ref/argmin.cpp
test/ref/argmin.cpp
+34
-0
No files found.
test/onnx/mvn_rank_3_test.onnx
0 → 100644
View file @
f3a8933c
File added
test/onnx/onnx_test.cpp
View file @
f3a8933c
...
...
@@ -42,11 +42,14 @@
#include <migraphx/op/lrn.hpp>
#include <migraphx/op/reshape.hpp>
#include <migraphx/op/unknown.hpp>
#include <migraphx/env.hpp>
#include <migraphx/serialize.hpp>
#include "test.hpp"
MIGRAPHX_DECLARE_ENV_VAR
(
MIGRAPHX_ENABLE_CK_WORKAROUNDS
);
migraphx
::
program
optimize_onnx
(
const
std
::
string
&
name
,
bool
run_passes
=
false
)
{
migraphx
::
onnx_options
options
;
...
...
@@ -181,6 +184,19 @@ TEST_CASE(argmax_test)
EXPECT
(
p
==
prog
);
}
TEST_CASE
(
argmax_select_last_index_test
)
{
migraphx
::
program
p
;
auto
*
mm
=
p
.
get_main_module
();
auto
l0
=
mm
->
add_parameter
(
"x"
,
migraphx
::
shape
{
migraphx
::
shape
::
float_type
,
{
3
,
4
,
5
,
6
}});
auto
ins
=
mm
->
add_instruction
(
migraphx
::
make_op
(
"argmax"
,
{{
"axis"
,
2
},
{
"select_last_index"
,
true
}}),
l0
);
mm
->
add_instruction
(
migraphx
::
make_op
(
"squeeze"
,
{{
"axes"
,
{
2
}}}),
ins
);
auto
prog
=
optimize_onnx
(
"argmax_select_last_index_test.onnx"
);
EXPECT
(
p
==
prog
);
}
TEST_CASE
(
argmax_dyn_test
)
{
migraphx
::
program
p
;
...
...
@@ -210,6 +226,19 @@ TEST_CASE(argmin_test)
EXPECT
(
p
==
prog
);
}
TEST_CASE
(
argmin_select_last_index_test
)
{
migraphx
::
program
p
;
auto
*
mm
=
p
.
get_main_module
();
auto
l0
=
mm
->
add_parameter
(
"x"
,
migraphx
::
shape
{
migraphx
::
shape
::
float_type
,
{
3
,
4
,
5
,
6
}});
auto
ins
=
mm
->
add_instruction
(
migraphx
::
make_op
(
"argmin"
,
{{
"axis"
,
3
},
{
"select_last_index"
,
true
}}),
l0
);
mm
->
add_instruction
(
migraphx
::
make_op
(
"squeeze"
,
{{
"axes"
,
{
3
}}}),
ins
);
auto
prog
=
optimize_onnx
(
"argmin_select_last_index_test.onnx"
);
EXPECT
(
p
==
prog
);
}
TEST_CASE
(
asin_test
)
{
migraphx
::
program
p
;
...
...
@@ -4501,6 +4530,66 @@ TEST_CASE(mean_integral_test)
EXPECT
(
p
==
prog
);
}
void
mvn_n_rank_test
(
std
::
vector
<
int64_t
>
axes
,
std
::
vector
<
size_t
>
input_shape
,
const
std
::
string
&
test_file
)
{
using
migraphx
::
make_op
;
migraphx
::
program
p
;
auto
*
mm
=
p
.
get_main_module
();
auto
data
=
mm
->
add_parameter
(
"data"
,
{
migraphx
::
shape
::
float_type
,
std
::
move
(
input_shape
)});
auto
data_mean
=
mm
->
add_instruction
(
make_op
(
"reduce_mean"
,
{{
"axes"
,
axes
}}),
data
);
auto
data_mean_squared
=
add_common_op
(
*
mm
,
make_op
(
"mul"
),
{
data_mean
,
data_mean
});
auto
data_squared
=
add_common_op
(
*
mm
,
make_op
(
"mul"
),
{
data
,
data
});
auto
data_squared_mean
=
mm
->
add_instruction
(
make_op
(
"reduce_mean"
,
{{
"axes"
,
axes
}}),
data_squared
);
auto
mean_sub
=
add_common_op
(
*
mm
,
make_op
(
"sub"
),
{
data_squared_mean
,
data_mean_squared
});
auto
std
=
add_common_op
(
*
mm
,
make_op
(
"sqrt"
),
{
mean_sub
});
auto
dividend
=
add_common_op
(
*
mm
,
make_op
(
"sub"
),
{
data
,
data_mean
});
auto
epsilon
=
mm
->
add_literal
({
migraphx
::
shape
::
float_type
,
{
1e-9
}});
auto
divisor
=
add_common_op
(
*
mm
,
make_op
(
"add"
),
{
std
,
epsilon
});
add_common_op
(
*
mm
,
make_op
(
"div"
),
{
dividend
,
divisor
});
auto
prog
=
optimize_onnx
(
test_file
);
EXPECT
(
p
==
prog
);
}
TEST_CASE
(
mvn_default_axes_test
)
{
mvn_n_rank_test
({
0
,
2
,
3
},
{
2
,
2
,
2
,
2
},
"mvn_default_axes_test.onnx"
);
}
TEST_CASE
(
mvn_default_axes_rank_too_small_test
)
{
EXPECT
(
test
::
throws
([
&
]
{
migraphx
::
parse_onnx
(
"mvn_default_axes_rank_too_small_test.onnx"
);
}));
}
TEST_CASE
(
mvn_default_axes_rank_too_big_test
)
{
EXPECT
(
test
::
throws
([
&
]
{
migraphx
::
parse_onnx
(
"mvn_default_axes_rank_too_big_test.onnx"
);
}));
}
TEST_CASE
(
mvn_rank_2_test
)
{
mvn_n_rank_test
({
1
},
{
2
,
2
},
"mvn_rank_2_test.onnx"
);
}
TEST_CASE
(
mvn_rank_3_test
)
{
mvn_n_rank_test
({
0
,
1
},
{
2
,
2
,
2
},
"mvn_rank_3_test.onnx"
);
}
TEST_CASE
(
mvn_axes_rank_too_small_test
)
{
EXPECT
(
test
::
throws
([
&
]
{
migraphx
::
parse_onnx
(
"mvn_axes_rank_too_small_test.onnx"
);
}));
}
TEST_CASE
(
mvn_axes_rank_too_big_test
)
{
EXPECT
(
test
::
throws
([
&
]
{
migraphx
::
parse_onnx
(
"mvn_axes_rank_too_big_test.onnx"
);
}));
}
TEST_CASE
(
min_test
)
{
migraphx
::
program
p
;
...
...
@@ -5480,6 +5569,31 @@ TEST_CASE(qlinearmatmul_2D_test)
EXPECT
(
p
.
sort
()
==
prog
.
sort
());
}
migraphx
::
instruction_ref
insert_quantizelinear_clip
(
migraphx
::
module
&
m
,
const
migraphx
::
instruction_ref
ins
,
const
migraphx
::
instruction_ref
round
,
const
migraphx
::
shape
s
,
const
int64_t
min_quant
,
const
int64_t
max_quant
)
{
migraphx
::
instruction_ref
min_arg
;
migraphx
::
instruction_ref
max_arg
;
if
(
migraphx
::
enabled
(
MIGRAPHX_ENABLE_CK_WORKAROUNDS
{}))
{
std
::
vector
<
int
>
min_data
(
s
.
elements
(),
min_quant
);
std
::
vector
<
int
>
max_data
(
s
.
elements
(),
max_quant
);
min_arg
=
m
.
add_literal
(
migraphx
::
literal
(
s
,
min_data
));
max_arg
=
m
.
add_literal
(
migraphx
::
literal
(
s
,
max_data
));
}
else
{
min_arg
=
m
.
add_literal
(
migraphx
::
literal
{
migraphx
::
shape
{
s
.
type
()},
{
min_quant
}});
max_arg
=
m
.
add_literal
(
migraphx
::
literal
{
migraphx
::
shape
{
s
.
type
()},
{
max_quant
}});
}
return
migraphx
::
insert_common_op
(
m
,
ins
,
migraphx
::
make_op
(
"clip"
),
{
round
,
min_arg
,
max_arg
});
}
TEST_CASE
(
quantizelinear_test
)
{
migraphx
::
program
p
;
...
...
@@ -5488,16 +5602,10 @@ TEST_CASE(quantizelinear_test)
auto
l1
=
mm
->
add_parameter
(
"1"
,
{
migraphx
::
shape
::
float_type
,
{
1
}});
auto
l1_mbcast
=
mm
->
add_instruction
(
migraphx
::
make_op
(
"multibroadcast"
,
{{
"out_lens"
,
{
5
}}}),
l1
);
auto
div
=
mm
->
add_instruction
(
migraphx
::
make_op
(
"div"
),
l0
,
l1_mbcast
);
auto
round
=
mm
->
add_instruction
(
migraphx
::
make_op
(
"round"
),
div
);
auto
s
=
round
->
get_shape
();
auto
min_arg
=
mm
->
add_literal
(
migraphx
::
literal
{
migraphx
::
shape
{
s
.
type
()},
{
0
}});
auto
max_arg
=
mm
->
add_literal
(
migraphx
::
literal
{
migraphx
::
shape
{
s
.
type
()},
{
255
}});
auto
min_mbcast
=
mm
->
add_instruction
(
migraphx
::
make_op
(
"multibroadcast"
,
{{
"out_lens"
,
s
.
lens
()}}),
min_arg
);
auto
max_mbcast
=
mm
->
add_instruction
(
migraphx
::
make_op
(
"multibroadcast"
,
{{
"out_lens"
,
s
.
lens
()}}),
max_arg
);
auto
clip
=
mm
->
add_instruction
(
migraphx
::
make_op
(
"clip"
),
round
,
min_mbcast
,
max_mbcast
);
auto
div
=
mm
->
add_instruction
(
migraphx
::
make_op
(
"div"
),
l0
,
l1_mbcast
);
auto
round
=
mm
->
add_instruction
(
migraphx
::
make_op
(
"round"
),
div
);
auto
s
=
round
->
get_shape
();
auto
clip
=
insert_quantizelinear_clip
(
*
mm
,
div
,
round
,
s
,
0
,
255
);
mm
->
add_instruction
(
migraphx
::
make_op
(
"convert"
,
{{
"target_type"
,
migraphx
::
to_value
(
migraphx
::
shape
::
uint8_type
)}}),
...
...
@@ -5519,16 +5627,10 @@ TEST_CASE(quantizelinear_int32_test)
migraphx
::
make_op
(
"convert"
,
{{
"target_type"
,
migraphx
::
to_value
(
migraphx
::
shape
::
float_type
)}}),
l0
);
auto
div
=
mm
->
add_instruction
(
migraphx
::
make_op
(
"div"
),
l0
,
l1_mbcast
);
auto
round
=
mm
->
add_instruction
(
migraphx
::
make_op
(
"round"
),
div
);
auto
s
=
round
->
get_shape
();
auto
min_arg
=
mm
->
add_literal
(
migraphx
::
literal
{
migraphx
::
shape
{
s
.
type
()},
{
0
}});
auto
max_arg
=
mm
->
add_literal
(
migraphx
::
literal
{
migraphx
::
shape
{
s
.
type
()},
{
255
}});
auto
min_mbcast
=
mm
->
add_instruction
(
migraphx
::
make_op
(
"multibroadcast"
,
{{
"out_lens"
,
s
.
lens
()}}),
min_arg
);
auto
max_mbcast
=
mm
->
add_instruction
(
migraphx
::
make_op
(
"multibroadcast"
,
{{
"out_lens"
,
s
.
lens
()}}),
max_arg
);
auto
clip
=
mm
->
add_instruction
(
migraphx
::
make_op
(
"clip"
),
round
,
min_mbcast
,
max_mbcast
);
auto
div
=
mm
->
add_instruction
(
migraphx
::
make_op
(
"div"
),
l0
,
l1_mbcast
);
auto
round
=
mm
->
add_instruction
(
migraphx
::
make_op
(
"round"
),
div
);
auto
s
=
round
->
get_shape
();
auto
clip
=
insert_quantizelinear_clip
(
*
mm
,
div
,
round
,
s
,
0
,
255
);
mm
->
add_instruction
(
migraphx
::
make_op
(
"convert"
,
{{
"target_type"
,
migraphx
::
to_value
(
migraphx
::
shape
::
uint8_type
)}}),
...
...
@@ -5555,15 +5657,9 @@ TEST_CASE(quantizelinear_zero_point_test)
migraphx
::
make_op
(
"convert"
,
{{
"target_type"
,
migraphx
::
to_value
(
migraphx
::
shape
::
float_type
)}}),
l2_mbcast
);
auto
add
=
mm
->
add_instruction
(
migraphx
::
make_op
(
"add"
),
round
,
l2_mbcast
);
auto
s
=
round
->
get_shape
();
auto
min_arg
=
mm
->
add_literal
(
migraphx
::
literal
{
migraphx
::
shape
{
s
.
type
()},
{
-
128
}});
auto
max_arg
=
mm
->
add_literal
(
migraphx
::
literal
{
migraphx
::
shape
{
s
.
type
()},
{
127
}});
auto
min_mbcast
=
mm
->
add_instruction
(
migraphx
::
make_op
(
"multibroadcast"
,
{{
"out_lens"
,
s
.
lens
()}}),
min_arg
);
auto
max_mbcast
=
mm
->
add_instruction
(
migraphx
::
make_op
(
"multibroadcast"
,
{{
"out_lens"
,
s
.
lens
()}}),
max_arg
);
auto
clip
=
mm
->
add_instruction
(
migraphx
::
make_op
(
"clip"
),
add
,
min_mbcast
,
max_mbcast
);
auto
add
=
mm
->
add_instruction
(
migraphx
::
make_op
(
"add"
),
round
,
l2_mbcast
);
auto
s
=
round
->
get_shape
();
auto
clip
=
insert_quantizelinear_clip
(
*
mm
,
div
,
add
,
s
,
-
128
,
127
);
mm
->
add_instruction
(
migraphx
::
make_op
(
"convert"
,
{{
"target_type"
,
migraphx
::
to_value
(
migraphx
::
shape
::
int8_type
)}}),
...
...
@@ -5594,15 +5690,9 @@ migraphx::program make_quantizelinear_axis_prog()
migraphx
::
make_op
(
"convert"
,
{{
"target_type"
,
migraphx
::
to_value
(
migraphx
::
shape
::
float_type
)}}),
l2_bcast
);
auto
add
=
mm
->
add_instruction
(
migraphx
::
make_op
(
"add"
),
round
,
l2_bcast
);
auto
s
=
round
->
get_shape
();
auto
min_arg
=
mm
->
add_literal
(
migraphx
::
literal
{
migraphx
::
shape
{
s
.
type
()},
{
-
128
}});
auto
max_arg
=
mm
->
add_literal
(
migraphx
::
literal
{
migraphx
::
shape
{
s
.
type
()},
{
127
}});
auto
min_mbcast
=
mm
->
add_instruction
(
migraphx
::
make_op
(
"multibroadcast"
,
{{
"out_lens"
,
s
.
lens
()}}),
min_arg
);
auto
max_mbcast
=
mm
->
add_instruction
(
migraphx
::
make_op
(
"multibroadcast"
,
{{
"out_lens"
,
s
.
lens
()}}),
max_arg
);
auto
clip
=
mm
->
add_instruction
(
migraphx
::
make_op
(
"clip"
),
add
,
min_mbcast
,
max_mbcast
);
auto
add
=
mm
->
add_instruction
(
migraphx
::
make_op
(
"add"
),
round
,
l2_bcast
);
auto
s
=
round
->
get_shape
();
auto
clip
=
insert_quantizelinear_clip
(
*
mm
,
div
,
add
,
s
,
-
128
,
127
);
mm
->
add_instruction
(
migraphx
::
make_op
(
"convert"
,
{{
"target_type"
,
migraphx
::
to_value
(
migraphx
::
shape
::
int8_type
)}}),
...
...
@@ -8031,11 +8121,6 @@ TEST_CASE(transpose_gather_test)
EXPECT
(
p
.
sort
()
==
prog
.
sort
());
}
TEST_CASE
(
trilu_neg_k_test
)
{
EXPECT
(
test
::
throws
([
&
]
{
migraphx
::
parse_onnx
(
"trilu_neg_k_test.onnx"
);
}));
}
TEST_CASE
(
undefined_test
)
{
migraphx
::
program
p
;
...
...
test/onnx/tril_batch_diff_k_test.onnx
0 → 100644
View file @
f3a8933c
File added
test/onnx/tril_neg_k_test.onnx
0 → 100644
View file @
f3a8933c
File added
test/onnx/tril_out_k_test.onnx
0 → 100644
View file @
f3a8933c
File added
test/onnx/tril_row_one_test.onnx
0 → 100644
View file @
f3a8933c
File added
test/onnx/tril
u_lower
_test.onnx
→
test/onnx/tril_test.onnx
View file @
f3a8933c
No preview for this file type
test/onnx/tri
l
u_batch_diff_k_test.onnx
→
test/onnx/triu_batch_diff_k_test.onnx
View file @
f3a8933c
tri
l
u_batch_diff_k_test:
i
triu_batch_diff_k_test:
h
x
ky"Trilu
tri
l
u_batch_diff_k_test*
ky"Trilu
triu_batch_diff_k_test*
:BkZ
x
...
...
@@ -12,4 +12,4 @@
B
\ No newline at end of file
B
\ No newline at end of file
test/onnx/tri
l
u_neg_k_test.onnx
→
test/onnx/triu_neg_k_test.onnx
View file @
f3a8933c
tri
l
u_neg_k_test:
c
triu_neg_k_test:
b
x
ky"Trilu
tri
l
u_neg_k_test*:
ky"Trilu
triu_neg_k_test*:
BkZ
x
...
...
@@ -10,4 +10,4 @@
y
B
\ No newline at end of file
B
\ No newline at end of file
test/onnx/tri
l
u_out_k_test.onnx
→
test/onnx/triu_out_k_test.onnx
View file @
f3a8933c
tri
l
u_out_k_test:
Z
triu_out_k_test:
Y
x
ky"Trilu
tri
l
u_out_k_test*
ky"Trilu
triu_out_k_test*
:BkZ
x
...
...
@@ -10,4 +10,4 @@
y
B
\ No newline at end of file
B
\ No newline at end of file
test/onnx/tri
l
u_row_one_test.onnx
→
test/onnx/triu_row_one_test.onnx
View file @
f3a8933c
tri
l
u_row_one_test:
\
triu_row_one_test:
[
x
ky"Trilu
tri
l
u_row_one_test*
ky"Trilu
triu_row_one_test*
:BkZ
x
...
...
@@ -10,4 +10,4 @@
y
B
\ No newline at end of file
B
\ No newline at end of file
test/onnx/tri
l
u_test.onnx
→
test/onnx/triu_test.onnx
View file @
f3a8933c
trilu_test:E
triu_test:D
xy"Trilu
trilu_testZ
xy"Trilu triu_testZ
x
...
...
@@ -10,4 +8,4 @@ trilu_testZ
y
B
\ No newline at end of file
B
\ No newline at end of file
test/onnx/verify_onnx.cpp
View file @
f3a8933c
...
...
@@ -1211,6 +1211,115 @@ TEST_CASE(mean_integral_test)
EXPECT
(
migraphx
::
verify
::
verify_rms_range
(
result_vector
,
gold
));
}
template
<
typename
T
=
float
>
std
::
vector
<
T
>
mvn_test
(
std
::
vector
<
size_t
>
data_lens
,
const
std
::
string
&
test_file
)
{
migraphx
::
program
p
=
migraphx
::
parse_onnx
(
test_file
);
p
.
compile
(
migraphx
::
make_target
(
"ref"
));
migraphx
::
shape
data_shape
(
migraphx
::
shape
::
get_type
<
T
>
{},
std
::
move
(
data_lens
));
std
::
vector
<
T
>
data
(
data_shape
.
elements
());
std
::
iota
(
begin
(
data
),
end
(
data
),
0
);
migraphx
::
parameter_map
pm
;
pm
[
"data"
]
=
migraphx
::
argument
(
data_shape
,
data
.
data
());
auto
result
=
p
.
eval
(
pm
).
back
();
std
::
vector
<
T
>
result_vector
;
result
.
visit
([
&
](
auto
output
)
{
result_vector
.
assign
(
output
.
begin
(),
output
.
end
());
});
return
result_vector
;
}
TEST_CASE
(
mvn_default_axes_test
)
{
auto
result
=
mvn_test
({
2
,
2
,
2
,
2
},
"mvn_default_axes_test.onnx"
);
std
::
vector
<
float
>
gold
{
-
1.32424438
,
-
1.08347268
,
-
0.84270097
,
-
0.60192927
,
-
1.32424438
,
-
1.08347268
,
-
0.84270097
,
-
0.60192927
,
0.60192927
,
0.84270097
,
1.08347268
,
1.32424438
,
0.60192927
,
0.84270097
,
1.08347268
,
1.32424438
};
EXPECT
(
migraphx
::
verify
::
verify_rms_range
(
result
,
gold
));
}
TEST_CASE
(
mvn_default_axes_fp16_test
)
{
using
migraphx
::
half
;
auto
result
=
mvn_test
<
half
>
({
2
,
2
,
2
,
2
},
"mvn_default_axes_fp16_test.onnx"
);
std
::
vector
<
half
>
gold
{
half
{
-
1.324
},
half
{
-
1.084
},
half
{
-
0.843
},
half
{
-
0.602
},
half
{
-
1.324
},
half
{
-
1.084
},
half
{
-
0.843
},
half
{
-
0.602
},
half
{
0.602
},
half
{
0.843
},
half
{
1.084
},
half
{
1.324
},
half
{
0.602
},
half
{
0.843
},
half
{
1.084
},
half
{
1.324
}};
EXPECT
(
migraphx
::
verify
::
verify_rms_range
(
result
,
gold
));
}
TEST_CASE
(
mvn_rank_2_test
)
{
auto
result
=
mvn_test
({
2
,
2
},
"mvn_rank_2_test.onnx"
);
std
::
vector
<
float
>
gold
{
-
1
,
1
,
-
1
,
1
};
EXPECT
(
migraphx
::
verify
::
verify_rms_range
(
result
,
gold
));
}
TEST_CASE
(
mvn_rank_2_fp16_test
)
{
using
migraphx
::
half
;
auto
result
=
mvn_test
<
migraphx
::
half
>
({
2
,
2
},
"mvn_rank_2_fp16_test.onnx"
);
std
::
vector
<
migraphx
::
half
>
gold
{
half
{
-
1
},
half
{
1
},
half
{
-
1
},
half
{
1
}};
EXPECT
(
migraphx
::
verify
::
verify_rms_range
(
result
,
gold
));
}
TEST_CASE
(
mvn_rank_3_test
)
{
auto
result
=
mvn_test
({
2
,
2
,
2
},
"mvn_rank_3_test.onnx"
);
std
::
vector
<
float
>
gold
{
-
1.34164079
,
-
1.34164079
,
-
0.4472136
,
-
0.4472136
,
0.4472136
,
0.4472136
,
1.34164079
,
1.34164079
};
EXPECT
(
migraphx
::
verify
::
verify_rms_range
(
result
,
gold
));
}
TEST_CASE
(
mvn_rank_3_fp16_test
)
{
using
migraphx
::
half
;
auto
result
=
mvn_test
<
half
>
({
2
,
2
,
2
},
"mvn_rank_3_fp16_test.onnx"
);
std
::
vector
<
half
>
gold
{
half
{
-
1.342
},
half
{
-
1.342
},
half
{
-
0.4473
},
half
{
-
0.4473
},
half
{
0.4473
},
half
{
0.4473
},
half
{
1.342
},
half
{
1.342
}};
EXPECT
(
migraphx
::
verify
::
verify_rms_range
(
result
,
gold
));
}
TEST_CASE
(
mod_test
)
{
migraphx
::
program
p
=
migraphx
::
parse_onnx
(
"mod_test.onnx"
);
...
...
@@ -2124,9 +2233,10 @@ std::vector<float> gen_trilu_test(const migraphx::shape& s, const migraphx::prog
result
.
visit
([
&
](
auto
output
)
{
result_vector
.
assign
(
output
.
begin
(),
output
.
end
());
});
return
result_vector
;
}
TEST_CASE
(
trilu_test
)
TEST_CASE
(
triu_test
)
{
migraphx
::
program
p
=
migraphx
::
parse_onnx
(
"tri
l
u_test.onnx"
);
migraphx
::
program
p
=
migraphx
::
parse_onnx
(
"triu_test.onnx"
);
std
::
vector
<
float
>
result_vector
=
gen_trilu_test
({
migraphx
::
shape
::
float_type
,
{
3
,
4
}},
p
);
...
...
@@ -2135,9 +2245,9 @@ TEST_CASE(trilu_test)
EXPECT
(
migraphx
::
verify
::
verify_rms_range
(
result_vector
,
gold
));
}
TEST_CASE
(
tri
l
u_batch_diff_k_test
)
TEST_CASE
(
triu_batch_diff_k_test
)
{
migraphx
::
program
p
=
migraphx
::
parse_onnx
(
"tri
l
u_batch_diff_k_test.onnx"
);
migraphx
::
program
p
=
migraphx
::
parse_onnx
(
"triu_batch_diff_k_test.onnx"
);
std
::
vector
<
float
>
result_vector
=
gen_trilu_test
({
migraphx
::
shape
::
float_type
,
{
2
,
2
,
3
}},
p
);
...
...
@@ -2146,9 +2256,42 @@ TEST_CASE(trilu_batch_diff_k_test)
EXPECT
(
migraphx
::
verify
::
verify_rms_range
(
result_vector
,
gold
));
}
TEST_CASE
(
tril
u_lower
_test
)
TEST_CASE
(
tril_test
)
{
migraphx
::
program
p
=
migraphx
::
parse_onnx
(
"trilu_lower_test.onnx"
);
migraphx
::
program
p
=
migraphx
::
parse_onnx
(
"tril_test.onnx"
);
std
::
vector
<
float
>
result_vector
=
gen_trilu_test
({
migraphx
::
shape
::
float_type
,
{
3
,
4
}},
p
);
std
::
vector
<
float
>
gold
=
{
1
,
0
,
0
,
0
,
5
,
6
,
0
,
0
,
9
,
10
,
11
,
0
};
EXPECT
(
migraphx
::
verify
::
verify_rms_range
(
result_vector
,
gold
));
}
TEST_CASE
(
tril_batch_diff_k_test
)
{
migraphx
::
program
p
=
migraphx
::
parse_onnx
(
"tril_batch_diff_k_test.onnx"
);
std
::
vector
<
float
>
result_vector
=
gen_trilu_test
({
migraphx
::
shape
::
float_type
,
{
2
,
2
,
3
}},
p
);
std
::
vector
<
float
>
gold
=
{
1
,
2
,
3
,
4
,
5
,
6
,
7
,
8
,
9
,
10
,
11
,
12
};
EXPECT
(
migraphx
::
verify
::
verify_rms_range
(
result_vector
,
gold
));
}
TEST_CASE
(
triu_neg_k_test
)
{
migraphx
::
program
p
=
migraphx
::
parse_onnx
(
"triu_neg_k_test.onnx"
);
std
::
vector
<
float
>
result_vector
=
gen_trilu_test
({
migraphx
::
shape
::
float_type
,
{
3
,
4
}},
p
);
std
::
vector
<
float
>
gold
=
{
1
,
2
,
3
,
4
,
5
,
6
,
7
,
8
,
0
,
10
,
11
,
12
};
EXPECT
(
migraphx
::
verify
::
verify_rms_range
(
result_vector
,
gold
));
}
TEST_CASE
(
tril_neg_k_test
)
{
migraphx
::
program
p
=
migraphx
::
parse_onnx
(
"tril_neg_k_test.onnx"
);
std
::
vector
<
float
>
result_vector
=
gen_trilu_test
({
migraphx
::
shape
::
float_type
,
{
3
,
4
}},
p
);
...
...
@@ -2157,9 +2300,9 @@ TEST_CASE(trilu_lower_test)
EXPECT
(
migraphx
::
verify
::
verify_rms_range
(
result_vector
,
gold
));
}
TEST_CASE
(
tri
l
u_out_k_test
)
TEST_CASE
(
triu_out_k_test
)
{
migraphx
::
program
p
=
migraphx
::
parse_onnx
(
"tri
l
u_out_k_test.onnx"
);
migraphx
::
program
p
=
migraphx
::
parse_onnx
(
"triu_out_k_test.onnx"
);
std
::
vector
<
float
>
result_vector
=
gen_trilu_test
({
migraphx
::
shape
::
float_type
,
{
3
,
4
}},
p
);
...
...
@@ -2168,9 +2311,20 @@ TEST_CASE(trilu_out_k_test)
EXPECT
(
migraphx
::
verify
::
verify_rms_range
(
result_vector
,
gold
));
}
TEST_CASE
(
trilu_row_one_test
)
TEST_CASE
(
tril_out_k_test
)
{
migraphx
::
program
p
=
migraphx
::
parse_onnx
(
"tril_out_k_test.onnx"
);
std
::
vector
<
float
>
result_vector
=
gen_trilu_test
({
migraphx
::
shape
::
float_type
,
{
3
,
4
}},
p
);
std
::
vector
<
float
>
gold
=
{
1
,
2
,
3
,
4
,
5
,
6
,
7
,
8
,
9
,
10
,
11
,
12
};
EXPECT
(
migraphx
::
verify
::
verify_rms_range
(
result_vector
,
gold
));
}
TEST_CASE
(
triu_row_one_test
)
{
migraphx
::
program
p
=
migraphx
::
parse_onnx
(
"tri
l
u_row_one_test.onnx"
);
migraphx
::
program
p
=
migraphx
::
parse_onnx
(
"triu_row_one_test.onnx"
);
std
::
vector
<
float
>
result_vector
=
gen_trilu_test
({
migraphx
::
shape
::
float_type
,
{
1
,
4
}},
p
);
...
...
@@ -2179,4 +2333,15 @@ TEST_CASE(trilu_row_one_test)
EXPECT
(
migraphx
::
verify
::
verify_rms_range
(
result_vector
,
gold
));
}
TEST_CASE
(
tril_row_one_test
)
{
migraphx
::
program
p
=
migraphx
::
parse_onnx
(
"tril_row_one_test.onnx"
);
std
::
vector
<
float
>
result_vector
=
gen_trilu_test
({
migraphx
::
shape
::
float_type
,
{
1
,
4
}},
p
);
std
::
vector
<
float
>
gold
=
{
1
,
2
,
0
,
0
};
EXPECT
(
migraphx
::
verify
::
verify_rms_range
(
result_vector
,
gold
));
}
int
main
(
int
argc
,
const
char
*
argv
[])
{
test
::
run
(
argc
,
argv
);
}
test/op_shape_test.cpp
View file @
f3a8933c
...
...
@@ -88,7 +88,7 @@ TEST_CASE(allocate_static)
expect_shape
(
out_shape
,
migraphx
::
make_op
(
"allocate"
,
{{
"shape"
,
to_value
(
out_shape
)}}));
}
TEST_CASE
(
allocate_static_input
_error
)
TEST_CASE
(
allocate_static_input
)
{
migraphx
::
shape
input
{
migraphx
::
shape
::
int64_type
,
{
3
}};
migraphx
::
shape
out_shape
{
migraphx
::
shape
::
float_type
,
{
2
,
3
,
4
}};
...
...
@@ -116,7 +116,7 @@ TEST_CASE(allocate_dyn_with_shape_attr)
input
);
}
TEST_CASE
(
allocate_dyn_no_input
_error
)
TEST_CASE
(
allocate_dyn_no_input
)
{
migraphx
::
shape
shape_attr
{
migraphx
::
shape
::
float_type
,
{{
1
,
4
},
{
3
,
3
},
{
4
,
8
,
{
4
,
6
}},
{
4
,
8
},
{
4
,
6
}}};
...
...
@@ -124,6 +124,21 @@ TEST_CASE(allocate_dyn_no_input_error)
migraphx
::
make_op
(
"allocate"
,
{{
"shape"
,
migraphx
::
to_value
(
shape_attr
)}}));
}
TEST_CASE
(
allocate_shape_and_buf_type_error
)
{
migraphx
::
shape
shape_attr
{
migraphx
::
shape
::
float_type
,
{{
1
,
4
},
{
3
,
3
},
{
4
,
8
,
{
4
,
6
}},
{
4
,
8
},
{
4
,
6
}}};
throws_shape
(
migraphx
::
make_op
(
"allocate"
,
{{
"shape"
,
migraphx
::
to_value
(
shape_attr
)},
{
"buf_type"
,
migraphx
::
shape
::
half_type
}}));
}
TEST_CASE
(
allocate_no_attr_error
)
{
migraphx
::
shape
input
{
migraphx
::
shape
::
int64_type
,
{
4
}};
throws_shape
(
migraphx
::
make_op
(
"allocate"
),
input
);
}
TEST_CASE
(
argmax_axis0
)
{
migraphx
::
shape
input
{
migraphx
::
shape
::
half_type
,
{
2
,
3
,
4
,
5
}};
...
...
test/py/CMakeLists.txt
View file @
f3a8933c
...
...
@@ -28,6 +28,7 @@ set(VENV_ONNX ${CMAKE_BINARY_DIR}/test/py/venv-onnx)
set
(
REQUIREMENTS
${
CMAKE_CURRENT_SOURCE_DIR
}
/requirements.txt
)
set
(
REQUIREMENTS_ONNX
${
CMAKE_CURRENT_SOURCE_DIR
}
/requirements-onnx.txt
)
set
(
PYTHON_VERSION_TO_DISABLE_ONNX 3.6
)
option
(
MIGRAPHX_DISABLE_VIRTUAL_ENV
"Disable python virtual environments"
OFF
)
function
(
add_py_venv_fixture FIXTURE_NAME VIRTUAL_ENV_DIR REQUIREMENTS_FILE
)
...
...
@@ -61,23 +62,31 @@ function(add_py_test NAME SCRIPT FIXTURE_NAME VENV_DIR)
"PYTHONMALLOC=debug"
"MALLOC_CHECK_=3"
)
set
(
PYTHON_EXECUTABLE
${
VENV_DIR
}
/
${
PYTHON_VERSION
}
/bin/python
)
if
(
MIGRAPHX_DISABLE_VIRTUAL_ENV
)
set
(
PYTHON_EXECUTABLE
${
PYTHON_
${
PYTHON_VERSION
}
_EXECUTABLE
}
)
else
()
set
(
PYTHON_EXECUTABLE
${
VENV_DIR
}
/
${
PYTHON_VERSION
}
/bin/python
)
endif
()
if
(
NOT
(
${
FIXTURE_NAME
}
STREQUAL
"onnx"
AND
${
PYTHON_VERSION
}
STREQUAL
${
PYTHON_VERSION_TO_DISABLE_ONNX
}
))
add_test
(
NAME test_py_
${
PYTHON_VERSION
}
_
${
NAME
}
COMMAND
${
ENV_COMMAND
}
${
PYTHON_EXECUTABLE
}
${
CMAKE_CURRENT_SOURCE_DIR
}
/
${
SCRIPT
}
${
ARGN
}
)
set_tests_properties
(
test_py_
${
PYTHON_VERSION
}
_
${
NAME
}
PROPERTIES FIXTURES_REQUIRED
${
FIXTURE_NAME
}
_
${
PYTHON_VERSION
}
_VENV
)
add_custom_target
(
test_py_
${
PYTHON_VERSION
}
_
${
NAME
}
COMMAND
${
ENV_COMMAND
}
${
PYTHON_EXECUTABLE
}
${
CMAKE_CURRENT_SOURCE_DIR
}
/
${
SCRIPT
}
${
ARGN
}
COMMENT
"
${
PYTHON_EXECUTABLE
}
${
SCRIPT
}
"
)
if
(
NOT MIGRAPHX_DISABLE_VIRTUAL_ENV
)
set_tests_properties
(
test_py_
${
PYTHON_VERSION
}
_
${
NAME
}
PROPERTIES FIXTURES_REQUIRED
${
FIXTURE_NAME
}
_
${
PYTHON_VERSION
}
_VENV
)
endif
()
endif
()
endforeach
()
endfunction
()
add_dependencies
(
tests migraphx_py
)
add_dependencies
(
check migraphx_py
)
add_py_venv_fixture
(
common
${
VENV
}
${
REQUIREMENTS
}
)
add_py_venv_fixture
(
onnx
${
VENV_ONNX
}
${
REQUIREMENTS_ONNX
}
)
if
(
NOT MIGRAPHX_DISABLE_VIRTUAL_ENV
)
add_py_venv_fixture
(
common
${
VENV
}
${
REQUIREMENTS
}
)
add_py_venv_fixture
(
onnx
${
VENV_ONNX
}
${
REQUIREMENTS_ONNX
}
)
endif
()
add_py_test
(
ref test_cpu.py common
${
VENV
}
WORKING_DIRECTORY
${
TEST_ONNX_DIR
}
)
add_py_test
(
save_load test_save_load.py common
${
VENV
}
WORKING_DIRECTORY
${
TEST_ONNX_DIR
}
)
...
...
test/py/onnx_backend_test.py
View file @
f3a8933c
...
...
@@ -66,16 +66,6 @@ class MIGraphXBackendTest(onnx.backend.test.BackendTest):
def
disabled_tests_onnx_1_7_0
(
backend_test
):
# fails
# from OnnxBackendNodeModelTest
backend_test
.
exclude
(
r
'test_argmax_keepdims_example_select_last_index_cpu'
)
backend_test
.
exclude
(
r
'test_argmax_negative_axis_keepdims_example_select_last_index_cpu'
)
backend_test
.
exclude
(
r
'test_argmax_no_keepdims_example_select_last_index_cpu'
)
backend_test
.
exclude
(
r
'test_argmin_keepdims_example_select_last_index_cpu'
)
backend_test
.
exclude
(
r
'test_argmin_negative_axis_keepdims_example_select_last_index_cpu'
)
backend_test
.
exclude
(
r
'test_argmin_no_keepdims_example_select_last_index_cpu'
)
backend_test
.
exclude
(
r
'test_logsoftmax_axis_0_cpu'
)
backend_test
.
exclude
(
r
'test_logsoftmax_axis_1_cpu'
)
backend_test
.
exclude
(
r
'test_logsoftmax_default_axis_cpu'
)
...
...
@@ -154,7 +144,6 @@ def disabled_tests_onnx_1_7_0(backend_test):
backend_test
.
exclude
(
r
'test_maxunpool_export_without_output_shape_cpu'
)
backend_test
.
exclude
(
r
'test_mod_mixed_sign_int32_cpu'
)
backend_test
.
exclude
(
r
'test_mod_mixed_sign_int8_cpu'
)
backend_test
.
exclude
(
r
'test_mvn_cpu'
)
backend_test
.
exclude
(
r
'test_negative_log_likelihood_loss_iinput_shape_is_NCd1_weight_ignore_index_cpu'
)
...
...
@@ -591,9 +580,6 @@ def disabled_tests_onnx_1_9_0(backend_test):
backend_test
.
exclude
(
r
'test_gru_batchwise_cpu'
)
backend_test
.
exclude
(
r
'test_lstm_batchwise_cpu'
)
backend_test
.
exclude
(
r
'test_simple_rnn_batchwise_cpu'
)
backend_test
.
exclude
(
r
'test_tril_cpu'
)
backend_test
.
exclude
(
r
'test_tril_one_row_neg_cpu'
)
backend_test
.
exclude
(
r
'test_tril_square_cpu'
)
# from OnnxBackendPyTorchConvertedModelTest
backend_test
.
exclude
(
r
'test_MaxPool1d_stride_padding_dilation_cpu'
)
backend_test
.
exclude
(
r
'test_MaxPool2d_stride_padding_dilation_cpu'
)
...
...
@@ -803,7 +789,6 @@ def disabled_tests_onnx_1_13_0(backend_test):
backend_test
.
exclude
(
r
'test_group_normalization_example_cpu'
)
backend_test
.
exclude
(
r
'test_group_normalization_example_expanded_cpu'
)
backend_test
.
exclude
(
r
'test_mish_cpu'
)
backend_test
.
exclude
(
r
'test_mvn_expanded_ver18_cpu'
)
backend_test
.
exclude
(
r
'test_optional_get_element_optional_sequence_cpu'
)
backend_test
.
exclude
(
r
'test_optional_get_element_optional_tensor_cpu'
)
backend_test
.
exclude
(
r
'test_optional_get_element_tensor_cpu'
)
...
...
test/py/requirements.txt
View file @
f3a8933c
...
...
@@ -22,4 +22,4 @@
# THE SOFTWARE.
#####################################################################################
numpy==1.21.6
\ No newline at end of file
numpy==1.19.5
\ No newline at end of file
test/ref/allocate.cpp
View file @
f3a8933c
...
...
@@ -30,7 +30,7 @@
#include <test.hpp>
TEST_CASE
(
allocate_dyn
)
TEST_CASE
(
allocate_dyn
0
)
{
migraphx
::
program
p
;
auto
*
mm
=
p
.
get_main_module
();
...
...
@@ -47,3 +47,21 @@ TEST_CASE(allocate_dyn)
migraphx
::
shape
sresult
{
migraphx
::
shape
::
float_type
,
{
2
,
3
,
4
,
4
}};
result
.
visit
([
&
](
auto
output
)
{
EXPECT
(
output
.
get_shape
()
==
sresult
);
});
}
TEST_CASE
(
allocate_dyn1
)
{
migraphx
::
program
p
;
auto
*
mm
=
p
.
get_main_module
();
migraphx
::
shape
s
{
migraphx
::
shape
::
int64_type
,
{
4
}};
migraphx
::
shape
out_shape
{
migraphx
::
shape
::
float_type
,
{
2
,
3
,
4
,
4
}};
auto
out_dims
=
mm
->
add_parameter
(
"out_dims"
,
s
);
mm
->
add_instruction
(
migraphx
::
make_op
(
"allocate"
,
{{
"shape"
,
migraphx
::
to_value
(
out_shape
)}}),
out_dims
);
p
.
compile
(
migraphx
::
make_target
(
"ref"
));
migraphx
::
parameter_map
params
;
std
::
vector
<
int64_t
>
data
=
{
2
,
3
,
4
,
4
};
params
[
"out_dims"
]
=
migraphx
::
argument
(
s
,
data
.
data
());
auto
result
=
p
.
eval
(
params
).
back
();
result
.
visit
([
&
](
auto
output
)
{
EXPECT
(
output
.
get_shape
()
==
out_shape
);
});
}
test/ref/argmax.cpp
View file @
f3a8933c
...
...
@@ -147,3 +147,37 @@ TEST_CASE(argmax_test_nonstd_shape)
res_gold
.
visit
([
&
](
auto
output
)
{
res_gold_vec
.
assign
(
output
.
begin
(),
output
.
end
());
});
EXPECT
(
migraphx
::
verify
::
verify_rms_range
(
result_vec
,
res_gold_vec
));
}
TEST_CASE
(
argmax_test_select_last_index_0
)
{
migraphx
::
program
p
;
auto
*
mm
=
p
.
get_main_module
();
std
::
vector
<
float
>
data
=
{
2.0305
,
-
1.853
,
2.0305
,
-
1.5706
,
0.7545
,
0.7545
};
std
::
vector
<
int64_t
>
res_gold
=
{
2
,
2
};
migraphx
::
shape
data_shape
{
migraphx
::
shape
::
float_type
,
{
2
,
3
}};
auto
dl
=
mm
->
add_literal
(
migraphx
::
literal
{
data_shape
,
data
});
mm
->
add_instruction
(
migraphx
::
make_op
(
"argmax"
,
{{
"axis"
,
1
},
{
"select_last_index"
,
true
}}),
dl
);
p
.
compile
(
migraphx
::
make_target
(
"ref"
));
auto
result
=
p
.
eval
({}).
back
();
std
::
vector
<
int64_t
>
result_vec
;
result
.
visit
([
&
](
auto
output
)
{
result_vec
.
assign
(
output
.
begin
(),
output
.
end
());
});
EXPECT
(
migraphx
::
verify
::
verify_rms_range
(
result_vec
,
res_gold
));
}
TEST_CASE
(
argmax_test_select_last_index_1
)
{
migraphx
::
program
p
;
auto
*
mm
=
p
.
get_main_module
();
std
::
vector
<
float
>
data
=
{
2.0305
,
-
1.853
,
2.0305
,
-
1.5706
,
0.7545
,
0.7545
};
std
::
vector
<
int64_t
>
res_gold
=
{
0
,
1
};
migraphx
::
shape
data_shape
{
migraphx
::
shape
::
float_type
,
{
2
,
3
}};
auto
dl
=
mm
->
add_literal
(
migraphx
::
literal
{
data_shape
,
data
});
mm
->
add_instruction
(
migraphx
::
make_op
(
"argmax"
,
{{
"axis"
,
1
},
{
"select_last_index"
,
false
}}),
dl
);
p
.
compile
(
migraphx
::
make_target
(
"ref"
));
auto
result
=
p
.
eval
({}).
back
();
std
::
vector
<
int64_t
>
result_vec
;
result
.
visit
([
&
](
auto
output
)
{
result_vec
.
assign
(
output
.
begin
(),
output
.
end
());
});
EXPECT
(
migraphx
::
verify
::
verify_rms_range
(
result_vec
,
res_gold
));
}
test/ref/argmin.cpp
View file @
f3a8933c
...
...
@@ -125,3 +125,37 @@ TEST_CASE(argmin_test_nonstd_shape)
res_gold
.
visit
([
&
](
auto
output
)
{
res_gold_vec
.
assign
(
output
.
begin
(),
output
.
end
());
});
EXPECT
(
migraphx
::
verify
::
verify_rms_range
(
result_vec
,
res_gold_vec
));
}
TEST_CASE
(
argmin_test_select_last_index_0
)
{
migraphx
::
program
p
;
auto
*
mm
=
p
.
get_main_module
();
std
::
vector
<
float
>
data
=
{
-
2.0305
,
0.853
,
-
2.0305
,
1.5706
,
0.7545
,
0.7545
};
std
::
vector
<
int64_t
>
res_gold
=
{
2
,
2
};
migraphx
::
shape
data_shape
{
migraphx
::
shape
::
float_type
,
{
2
,
3
}};
auto
dl
=
mm
->
add_literal
(
migraphx
::
literal
{
data_shape
,
data
});
mm
->
add_instruction
(
migraphx
::
make_op
(
"argmin"
,
{{
"axis"
,
1
},
{
"select_last_index"
,
true
}}),
dl
);
p
.
compile
(
migraphx
::
make_target
(
"ref"
));
auto
result
=
p
.
eval
({}).
back
();
std
::
vector
<
int64_t
>
result_vec
;
result
.
visit
([
&
](
auto
output
)
{
result_vec
.
assign
(
output
.
begin
(),
output
.
end
());
});
EXPECT
(
migraphx
::
verify
::
verify_rms_range
(
result_vec
,
res_gold
));
}
TEST_CASE
(
argmin_test_select_last_index_1
)
{
migraphx
::
program
p
;
auto
*
mm
=
p
.
get_main_module
();
std
::
vector
<
float
>
data
=
{
-
2.0305
,
0.853
,
-
2.0305
,
1.5706
,
0.7545
,
0.7545
};
std
::
vector
<
int64_t
>
res_gold
=
{
0
,
1
};
migraphx
::
shape
data_shape
{
migraphx
::
shape
::
float_type
,
{
2
,
3
}};
auto
dl
=
mm
->
add_literal
(
migraphx
::
literal
{
data_shape
,
data
});
mm
->
add_instruction
(
migraphx
::
make_op
(
"argmin"
,
{{
"axis"
,
1
},
{
"select_last_index"
,
false
}}),
dl
);
p
.
compile
(
migraphx
::
make_target
(
"ref"
));
auto
result
=
p
.
eval
({}).
back
();
std
::
vector
<
int64_t
>
result_vec
;
result
.
visit
([
&
](
auto
output
)
{
result_vec
.
assign
(
output
.
begin
(),
output
.
end
());
});
EXPECT
(
migraphx
::
verify
::
verify_rms_range
(
result_vec
,
res_gold
));
}
Prev
1
2
3
4
5
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment