Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
MIGraphX
Commits
ac04f3cc
Commit
ac04f3cc
authored
Nov 10, 2023
by
Khalique Ahmed
Browse files
manual_merge
parents
d39c3343
d8011adf
Changes
539
Hide whitespace changes
Inline
Side-by-side
Showing
20 changed files
with
636 additions
and
169 deletions
+636
-169
test/simplify_reshapes_test.cpp
test/simplify_reshapes_test.cpp
+109
-10
test/split_single_dyn_dim_test.cpp
test/split_single_dyn_dim_test.cpp
+4
-64
test/stringutils.cpp
test/stringutils.cpp
+25
-0
test/targets.cpp
test/targets.cpp
+1
-5
test/tf/tf_test.cpp
test/tf/tf_test.cpp
+10
-21
test/verify/CMakeLists.txt
test/verify/CMakeLists.txt
+3
-3
test/verify/ck_gemm_softmax_gemm.cpp
test/verify/ck_gemm_softmax_gemm.cpp
+56
-0
test/verify/gemm_2args_mm_8.cpp
test/verify/gemm_2args_mm_8.cpp
+10
-5
test/verify/gemm_literal.cpp
test/verify/gemm_literal.cpp
+2
-2
test/verify/quant_conv_1.cpp
test/verify/quant_conv_1.cpp
+1
-1
test/verify/quant_conv_2.cpp
test/verify/quant_conv_2.cpp
+1
-1
test/verify/run_verify.cpp
test/verify/run_verify.cpp
+3
-3
test/verify/test_arg_ops.cpp
test/verify/test_arg_ops.cpp
+100
-52
test/verify/test_flatten_dot_relu.cpp
test/verify/test_flatten_dot_relu.cpp
+46
-0
test/verify/test_isinf.cpp
test/verify/test_isinf.cpp
+52
-0
test/verify/test_isinf_broadcast.cpp
test/verify/test_isinf_broadcast.cpp
+48
-0
test/verify/test_layernorm.cpp
test/verify/test_layernorm.cpp
+22
-2
test/verify/test_nearbyint.cpp
test/verify/test_nearbyint.cpp
+47
-0
test/verify/test_reduce_add.cpp
test/verify/test_reduce_add.cpp
+48
-0
test/verify/test_reduce_noop_add.cpp
test/verify/test_reduce_noop_add.cpp
+48
-0
No files found.
test/simplify_reshapes_test.cpp
View file @
ac04f3cc
/*
* The MIT License (MIT)
*
* Copyright (c) 2015-202
2
Advanced Micro Devices, Inc. All rights reserved.
* Copyright (c) 2015-202
3
Advanced Micro Devices, Inc. All rights reserved.
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
...
...
@@ -24,7 +24,6 @@
#include <migraphx/simplify_reshapes.hpp>
#include <migraphx/dead_code_elimination.hpp>
#include <migraphx/pass_manager.hpp>
#include <migraphx/operators.hpp>
#include <migraphx/instruction.hpp>
#include <migraphx/generate.hpp>
#include <basic_ops.hpp>
...
...
@@ -68,6 +67,106 @@ migraphx::module make_concat_multibroadcast(const std::vector<size_t>& in_lens,
return
m
;
}
TEST_CASE
(
broadcast_transpose
)
{
migraphx
::
module
m1
;
{
auto
l
=
m1
.
add_parameter
(
"x"
,
{
migraphx
::
shape
::
float_type
,
{
5
}});
auto
mb
=
m1
.
add_instruction
(
migraphx
::
make_op
(
"multibroadcast"
,
{{
"out_lens"
,
{
2
,
3
,
5
}}}),
l
);
auto
t1
=
m1
.
add_instruction
(
migraphx
::
make_op
(
"transpose"
,
{{
"permutation"
,
{
2
,
0
,
1
}}}),
mb
);
m1
.
add_return
({
t1
});
}
run_pass
(
m1
);
migraphx
::
module
m2
;
{
auto
l
=
m2
.
add_parameter
(
"x"
,
{
migraphx
::
shape
::
float_type
,
{
5
}});
auto
u1
=
m2
.
add_instruction
(
migraphx
::
make_op
(
"unsqueeze"
,
{{
"axes"
,
{
0
,
1
}}}),
l
);
auto
t1
=
m2
.
add_instruction
(
migraphx
::
make_op
(
"transpose"
,
{{
"permutation"
,
{
2
,
0
,
1
}}}),
u1
);
auto
mb
=
m2
.
add_instruction
(
migraphx
::
make_op
(
"multibroadcast"
,
{{
"out_lens"
,
{
5
,
2
,
3
}}}),
t1
);
m2
.
add_return
({
mb
});
}
EXPECT
(
m1
==
m2
);
}
TEST_CASE
(
broadcast_transpose_opt
)
{
// extra transpose from transformation will be optimized out
migraphx
::
module
m1
;
{
auto
l
=
m1
.
add_parameter
(
"x"
,
{
migraphx
::
shape
::
float_type
,
{
5
}});
auto
mb
=
m1
.
add_instruction
(
migraphx
::
make_op
(
"multibroadcast"
,
{{
"out_lens"
,
{
2
,
3
,
5
}}}),
l
);
auto
t1
=
m1
.
add_instruction
(
migraphx
::
make_op
(
"transpose"
,
{{
"permutation"
,
{
1
,
0
,
2
}}}),
mb
);
m1
.
add_return
({
t1
});
}
run_pass
(
m1
);
migraphx
::
module
m2
;
{
auto
l
=
m2
.
add_parameter
(
"x"
,
{
migraphx
::
shape
::
float_type
,
{
5
}});
auto
u1
=
m2
.
add_instruction
(
migraphx
::
make_op
(
"unsqueeze"
,
{{
"axes"
,
{
0
,
1
}}}),
l
);
auto
mb
=
m2
.
add_instruction
(
migraphx
::
make_op
(
"multibroadcast"
,
{{
"out_lens"
,
{
3
,
2
,
5
}}}),
u1
);
m2
.
add_return
({
mb
});
}
EXPECT
(
m1
==
m2
);
}
TEST_CASE
(
broadcast_transpose_scalar
)
{
migraphx
::
module
m1
;
{
auto
l
=
m1
.
add_parameter
(
"x"
,
{
migraphx
::
shape
::
float_type
,
{
1
},
{
0
}});
auto
mb
=
m1
.
add_instruction
(
migraphx
::
make_op
(
"multibroadcast"
,
{{
"out_lens"
,
{
2
,
3
}}}),
l
);
auto
t1
=
m1
.
add_instruction
(
migraphx
::
make_op
(
"transpose"
,
{{
"permutation"
,
{
1
,
0
}}}),
mb
);
m1
.
add_return
({
t1
});
}
run_pass
(
m1
);
migraphx
::
module
m2
;
{
auto
l
=
m2
.
add_parameter
(
"x"
,
{
migraphx
::
shape
::
float_type
,
{
1
},
{
0
}});
auto
mb
=
m2
.
add_instruction
(
migraphx
::
make_op
(
"multibroadcast"
,
{{
"out_lens"
,
{
3
,
2
}}}),
l
);
m2
.
add_return
({
mb
});
}
EXPECT
(
m1
==
m2
);
}
TEST_CASE
(
broadcast_transpose_scalar_multi_use
)
{
// multibroadcast used more than once
migraphx
::
module
m1
;
{
auto
l
=
m1
.
add_parameter
(
"x"
,
{
migraphx
::
shape
::
float_type
,
{
1
},
{
0
}});
auto
mb
=
m1
.
add_instruction
(
migraphx
::
make_op
(
"multibroadcast"
,
{{
"out_lens"
,
{
2
,
3
}}}),
l
);
auto
t1
=
m1
.
add_instruction
(
migraphx
::
make_op
(
"transpose"
,
{{
"permutation"
,
{
1
,
0
}}}),
mb
);
auto
id
=
m1
.
add_instruction
(
migraphx
::
make_op
(
"identity"
),
mb
);
m1
.
add_return
({
t1
,
id
});
}
run_pass
(
m1
);
migraphx
::
module
m2
;
{
auto
l
=
m2
.
add_parameter
(
"x"
,
{
migraphx
::
shape
::
float_type
,
{
1
},
{
0
}});
auto
mb
=
m2
.
add_instruction
(
migraphx
::
make_op
(
"multibroadcast"
,
{{
"out_lens"
,
{
3
,
2
}}}),
l
);
auto
mb2
=
m2
.
add_instruction
(
migraphx
::
make_op
(
"multibroadcast"
,
{{
"out_lens"
,
{
2
,
3
}}}),
l
);
auto
id
=
m2
.
add_instruction
(
migraphx
::
make_op
(
"identity"
),
mb2
);
m2
.
add_return
({
mb
,
id
});
}
EXPECT
(
m1
==
m2
);
}
TEST_CASE
(
double_contig
)
{
migraphx
::
program
p
;
...
...
@@ -477,7 +576,7 @@ TEST_CASE(concat_multibroadcasts1)
std
::
find_if
(
m
.
begin
(),
m
.
end
(),
[](
auto
ins
)
{
return
ins
.
name
()
==
"multibroadcast"
;
});
auto
md
=
std
::
distance
(
m
.
begin
(),
new_mb
);
EXPECT
(
cd
==
md
-
1
);
EXPECT
(
migraphx
::
any_cast
<
migraphx
::
op
::
concat
>
(
new_concat
->
get_operator
()).
axis
==
1
);
EXPECT
(
new_concat
->
get_operator
().
to_value
()[
"axis"
].
to
<
int
>
()
==
1
);
}
TEST_CASE
(
concat_multibroadcasts2
)
...
...
@@ -500,7 +599,7 @@ TEST_CASE(concat_multibroadcasts2)
std
::
find_if
(
m
.
begin
(),
m
.
end
(),
[](
auto
ins
)
{
return
ins
.
name
()
==
"multibroadcast"
;
});
auto
md
=
std
::
distance
(
m
.
begin
(),
new_mb
);
EXPECT
(
cd
==
md
-
1
);
EXPECT
(
migraphx
::
any_cast
<
migraphx
::
op
::
concat
>
(
new_concat
->
get_operator
()).
axis
==
0
);
EXPECT
(
new_concat
->
get_operator
().
to_value
()[
"axis"
].
to
<
int
>
()
==
0
);
}
TEST_CASE
(
concat_multibroadcasts3
)
...
...
@@ -523,7 +622,7 @@ TEST_CASE(concat_multibroadcasts3)
std
::
find_if
(
m
.
begin
(),
m
.
end
(),
[](
auto
ins
)
{
return
ins
.
name
()
==
"multibroadcast"
;
});
auto
md
=
std
::
distance
(
m
.
begin
(),
new_mb
);
EXPECT
(
cd
==
md
-
1
);
EXPECT
(
migraphx
::
any_cast
<
migraphx
::
op
::
concat
>
(
new_concat
->
get_operator
()).
axis
==
2
);
EXPECT
(
new_concat
->
get_operator
().
to_value
()[
"axis"
].
to
<
int
>
()
==
2
);
}
TEST_CASE
(
concat_multibroadcasts4
)
...
...
@@ -559,7 +658,7 @@ TEST_CASE(concat_transpose1)
auto
new_concat
=
std
::
find_if
(
m
.
begin
(),
m
.
end
(),
[](
auto
ins
)
{
return
ins
.
name
()
==
"concat"
;
});
EXPECT
(
bool
{
new_concat
!=
m
.
end
()});
EXPECT
(
migraphx
::
any_cast
<
migraphx
::
op
::
concat
>
(
new_concat
->
get_operator
()).
axis
==
3
);
EXPECT
(
new_concat
->
get_operator
().
to_value
()[
"axis"
].
to
<
int
>
()
==
3
);
}
TEST_CASE
(
concat_transpose2
)
...
...
@@ -583,7 +682,7 @@ TEST_CASE(concat_transpose2)
auto
new_concat
=
std
::
find_if
(
m
.
begin
(),
m
.
end
(),
[](
auto
ins
)
{
return
ins
.
name
()
==
"concat"
;
});
EXPECT
(
bool
{
new_concat
!=
m
.
end
()});
EXPECT
(
migraphx
::
any_cast
<
migraphx
::
op
::
concat
>
(
new_concat
->
get_operator
()).
axis
==
1
);
EXPECT
(
new_concat
->
get_operator
().
to_value
()[
"axis"
].
to
<
int
>
()
==
1
);
}
TEST_CASE
(
concat_transpose3
)
...
...
@@ -607,7 +706,7 @@ TEST_CASE(concat_transpose3)
auto
new_concat
=
std
::
find_if
(
m
.
begin
(),
m
.
end
(),
[](
auto
ins
)
{
return
ins
.
name
()
==
"concat"
;
});
EXPECT
(
bool
{
new_concat
!=
m
.
end
()});
EXPECT
(
migraphx
::
any_cast
<
migraphx
::
op
::
concat
>
(
new_concat
->
get_operator
()).
axis
==
1
);
EXPECT
(
new_concat
->
get_operator
().
to_value
()[
"axis"
].
to
<
int
>
()
==
1
);
}
TEST_CASE
(
concat_transpose4
)
...
...
@@ -1246,7 +1345,7 @@ TEST_CASE(transpose_contiguous_unsqueeze_unary)
auto
cont_ins
=
m1
.
add_instruction
(
migraphx
::
make_op
(
"contiguous"
),
transpose_ins
);
auto
unsq_ins
=
m1
.
add_instruction
(
migraphx
::
make_op
(
"unsqueeze"
,
{{
"axes"
,
{
2
}}}),
cont_ins
);
auto
round
=
m1
.
add_instruction
(
migraphx
::
make_op
(
"
round
"
),
unsq_ins
);
auto
round
=
m1
.
add_instruction
(
migraphx
::
make_op
(
"
nearbyint
"
),
unsq_ins
);
m1
.
add_instruction
(
pass_op
{},
round
);
}
run_pass
(
m1
);
...
...
@@ -1255,7 +1354,7 @@ TEST_CASE(transpose_contiguous_unsqueeze_unary)
auto
x
=
m2
.
add_parameter
(
"x"
,
{
migraphx
::
shape
::
float_type
,
{
2
,
8
,
5
,
5
}});
auto
transpose_ins
=
m2
.
add_instruction
(
migraphx
::
make_op
(
"transpose"
,
{{
"permutation"
,
{
0
,
2
,
3
,
1
}}}),
x
);
auto
round
=
m2
.
add_instruction
(
migraphx
::
make_op
(
"
round
"
),
transpose_ins
);
auto
round
=
m2
.
add_instruction
(
migraphx
::
make_op
(
"
nearbyint
"
),
transpose_ins
);
auto
cont_ins
=
m2
.
add_instruction
(
migraphx
::
make_op
(
"contiguous"
),
round
);
auto
unsq_ins
=
m2
.
add_instruction
(
migraphx
::
make_op
(
"unsqueeze"
,
{{
"axes"
,
{
2
}}}),
cont_ins
);
...
...
test/split_single_dyn_dim_test.cpp
View file @
ac04f3cc
...
...
@@ -50,8 +50,8 @@ TEST_CASE(dynamic_batch)
auto
sm_input
=
submod
->
add_parameter
(
"data"
,
sm_shape
);
migraphx
::
shape
lit_s
{
migraphx
::
shape
{
migraphx
::
shape
::
float_type
,
{
1
}}};
auto
literal_ins
=
submod
->
add_literal
(
migraphx
::
literal
{
lit_s
,
{
6
}});
auto
broadcast_lit
=
submod
->
add_instruction
(
migraphx
::
make_op
(
"multibroadcast"
,
{{
"out_lens"
,
sm_shape
.
lens
()}}
),
literal_ins
);
auto
broadcast_lit
=
submod
->
add_instruction
(
migraphx
::
make_op
(
"multibroadcast"
),
literal_ins
,
sm_input
);
auto
add_ins
=
submod
->
add_instruction
(
migraphx
::
make_op
(
"add"
),
sm_input
,
broadcast_lit
);
submod
->
add_return
({
add_ins
});
...
...
@@ -107,8 +107,8 @@ TEST_CASE(multiple_outputs)
auto
sm_input
=
submod
->
add_parameter
(
"data"
,
sm_shape
);
migraphx
::
shape
lit_s
{
migraphx
::
shape
{
migraphx
::
shape
::
float_type
,
{
1
}}};
auto
literal_ins
=
submod
->
add_literal
(
migraphx
::
literal
{
lit_s
,
{
6
}});
auto
broadcast_lit
=
submod
->
add_instruction
(
migraphx
::
make_op
(
"multibroadcast"
,
{{
"out_lens"
,
sm_shape
.
lens
()}}
),
literal_ins
);
auto
broadcast_lit
=
submod
->
add_instruction
(
migraphx
::
make_op
(
"multibroadcast"
),
literal_ins
,
sm_input
);
auto
add0_ins
=
submod
->
add_instruction
(
migraphx
::
make_op
(
"add"
),
sm_input
,
broadcast_lit
);
auto
add1_ins
=
submod
->
add_instruction
(
migraphx
::
make_op
(
"add"
),
sm_input
,
sm_input
);
...
...
@@ -157,64 +157,4 @@ TEST_CASE(multiple_outputs)
EXPECT
(
p0
==
p1
);
}
TEST_CASE
(
broadcast_match
)
{
// Slightly different from ref_ops_test in that the literal is copied over the submodules.
// A different compiler pass will pull the literals from the submodules to the main module.
migraphx
::
program
p0
;
{
auto
*
mm0
=
p0
.
get_main_module
();
// create batch submodules
auto
create_submodule
=
[
&
](
std
::
size_t
batch_size
,
const
std
::
string
&
module_name
)
{
auto
*
submod
=
p0
.
create_module
(
module_name
);
migraphx
::
shape
sm_shape
{
migraphx
::
shape
::
float_type
,
{
batch_size
,
4
}};
auto
sm_input
=
submod
->
add_parameter
(
"data"
,
sm_shape
);
migraphx
::
shape
lit_s
{
migraphx
::
shape
{
migraphx
::
shape
::
float_type
,
{
4
}}};
auto
literal_ins
=
submod
->
add_literal
(
migraphx
::
literal
{
lit_s
,
{
6
,
5
,
4
,
3
}});
auto
broadcast_lit
=
submod
->
add_instruction
(
migraphx
::
make_op
(
"broadcast"
,
{{
"axis"
,
1
},
{
"out_lens"
,
sm_shape
.
lens
()}}),
literal_ins
);
auto
add_ins
=
submod
->
add_instruction
(
migraphx
::
make_op
(
"add"
),
sm_input
,
broadcast_lit
);
submod
->
add_return
({
add_ins
});
return
submod
;
};
auto
*
dim1
=
create_submodule
(
1
,
"dim_1"
);
auto
*
dim2
=
create_submodule
(
2
,
"dim_2"
);
auto
*
dim3
=
create_submodule
(
3
,
"dim_3"
);
auto
*
dim4
=
create_submodule
(
4
,
"dim_4"
);
migraphx
::
shape
s
{
migraphx
::
shape
::
float_type
,
{{
1
,
4
},
{
4
,
4
}}};
auto
input0
=
mm0
->
add_parameter
(
"data"
,
s
);
std
::
vector
<
migraphx
::
shape
>
sub_shapes
=
{};
sub_shapes
.
push_back
(
migraphx
::
shape
{
migraphx
::
shape
::
float_type
,
{{
1
,
4
},
{
4
,
4
}}});
migraphx
::
shape
out_attr
=
migraphx
::
shape
{
sub_shapes
};
auto
sm_ins
=
mm0
->
add_instruction
(
migraphx
::
make_op
(
"select_module"
,
{{
"output_dyn_shapes"
,
migraphx
::
to_value
(
out_attr
)}}),
{
input0
},
{
dim1
,
dim2
,
dim3
,
dim4
});
auto
ret
=
mm0
->
add_instruction
(
migraphx
::
make_op
(
"get_tuple_elem"
,
{{
"index"
,
0
}}),
sm_ins
);
mm0
->
add_return
({
ret
});
}
migraphx
::
program
p1
;
{
auto
*
mm1
=
p1
.
get_main_module
();
migraphx
::
shape
s
{
migraphx
::
shape
::
float_type
,
{{
1
,
4
},
{
4
,
4
}}};
auto
input1
=
mm1
->
add_parameter
(
"data"
,
s
);
migraphx
::
shape
lit_s
{
migraphx
::
shape
{
migraphx
::
shape
::
float_type
,
{
4
}}};
auto
literal_ins
=
mm1
->
add_literal
(
migraphx
::
literal
{
lit_s
,
{
6
,
5
,
4
,
3
}});
auto
broadcast_lit
=
mm1
->
add_instruction
(
migraphx
::
make_op
(
"broadcast"
,
{{
"axis"
,
1
}}),
literal_ins
,
input1
);
auto
add_ins
=
mm1
->
add_instruction
(
migraphx
::
make_op
(
"add"
),
input1
,
broadcast_lit
);
mm1
->
add_return
({
add_ins
});
}
run_pass
(
p1
);
EXPECT
(
p0
==
p1
);
}
int
main
(
int
argc
,
const
char
*
argv
[])
{
test
::
run
(
argc
,
argv
);
}
test/stringutils.cpp
View file @
ac04f3cc
...
...
@@ -99,4 +99,29 @@ TEST_CASE(interpolate_string_custom3)
EXPECT
(
s
==
"****b****"
);
}
TEST_CASE
(
slit_string_simple1
)
{
std
::
string
input
=
"one,two,three"
;
auto
resuts
=
migraphx
::
split_string
(
input
,
','
);
EXPECT
(
resuts
.
size
()
==
3
);
EXPECT
(
resuts
.
front
()
==
"one"
);
EXPECT
(
resuts
.
back
()
==
"three"
);
}
TEST_CASE
(
slit_string_simple2
)
{
std
::
string
input
=
"one"
;
auto
resuts
=
migraphx
::
split_string
(
input
,
','
);
EXPECT
(
resuts
.
size
()
==
1
);
EXPECT
(
resuts
.
front
()
==
"one"
);
}
TEST_CASE
(
slit_string_simple3
)
{
std
::
string
input
=
"one two three"
;
auto
resuts
=
migraphx
::
split_string
(
input
,
','
);
EXPECT
(
resuts
.
size
()
==
1
);
EXPECT
(
resuts
.
front
()
==
"one two three"
);
}
int
main
(
int
argc
,
const
char
*
argv
[])
{
test
::
run
(
argc
,
argv
);
}
test/targets.cpp
View file @
ac04f3cc
/*
* The MIT License (MIT)
*
* Copyright (c) 2015-202
2
Advanced Micro Devices, Inc. All rights reserved.
* Copyright (c) 2015-202
3
Advanced Micro Devices, Inc. All rights reserved.
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
...
...
@@ -41,11 +41,7 @@ TEST_CASE(make_invalid_target)
TEST_CASE
(
targets
)
{
// GCC doesn't load libmigraphx_ref unless necesssary even though it is linked to the test.
// Force it to load by making ref target
#if defined(__GNUC__) && !defined(__clang__)
auto
ref_target
=
migraphx
::
make_target
(
"ref"
);
#endif
auto
ts
=
migraphx
::
get_targets
();
EXPECT
(
ts
.
size
()
>=
1
);
}
...
...
test/tf/tf_test.cpp
View file @
ac04f3cc
...
...
@@ -37,7 +37,6 @@
#include <migraphx/op/convolution.hpp>
#include <migraphx/op/reduce_mean.hpp>
#include <migraphx/op/pooling.hpp>
#include <migraphx/op/slice.hpp>
#include <migraphx/serialize.hpp>
...
...
@@ -840,12 +839,8 @@ TEST_CASE(slice_test)
mm
->
add_literal
(
migraphx
::
literal
{
s0
,
{
1
,
0
}});
mm
->
add_literal
(
migraphx
::
literal
{
s0
,
{
2
,
-
1
}});
migraphx
::
op
::
slice
op
;
op
.
starts
=
{
1
,
0
};
op
.
ends
=
{
3
,
10
};
op
.
axes
=
std
::
vector
<
int64_t
>
(
num_axes
);
std
::
iota
(
op
.
axes
.
begin
(),
op
.
axes
.
end
(),
0
);
mm
->
add_instruction
(
op
,
l0
);
mm
->
add_instruction
(
migraphx
::
make_op
(
"slice"
,
{{
"starts"
,
{
1
,
0
}},
{
"ends"
,
{
3
,
10
}},
{
"axes"
,
{
0
,
1
}}}),
l0
);
auto
prog
=
optimize_tf
(
"slice_test.pb"
,
false
);
EXPECT
(
p
==
prog
);
...
...
@@ -975,13 +970,10 @@ TEST_CASE(stridedslice_test)
auto
l0
=
mm
->
add_parameter
(
"0"
,
migraphx
::
shape
{
migraphx
::
shape
::
float_type
,
{
1
,
10
,
1
,
1
}});
auto
l1
=
mm
->
add_instruction
(
migraphx
::
make_op
(
"transpose"
,
{{
"permutation"
,
{
0
,
2
,
3
,
1
}}}),
l0
);
std
::
size_t
num_axes
=
4
;
migraphx
::
op
::
slice
op
;
op
.
starts
=
{
0
,
0
,
0
,
0
};
op
.
ends
=
{
1
,
1
,
1
,
5
};
op
.
axes
=
std
::
vector
<
int64_t
>
(
num_axes
);
std
::
iota
(
op
.
axes
.
begin
(),
op
.
axes
.
end
(),
0
);
auto
l2
=
mm
->
add_instruction
(
op
,
l1
);
auto
l2
=
mm
->
add_instruction
(
migraphx
::
make_op
(
"slice"
,
{{
"starts"
,
{
0
,
0
,
0
,
0
}},
{
"ends"
,
{
1
,
1
,
1
,
5
}},
{
"axes"
,
{
0
,
1
,
2
,
3
}}}),
l1
);
auto
shrink_axis
=
1
;
mm
->
add_instruction
(
migraphx
::
make_op
(
"squeeze"
,
{{
"axes"
,
{
shrink_axis
}}}),
l2
);
auto
prog
=
optimize_tf
(
"stridedslice_test.pb"
,
true
);
...
...
@@ -995,12 +987,6 @@ TEST_CASE(stridedslice_masks_test)
auto
*
mm
=
p
.
get_main_module
();
auto
l0
=
mm
->
add_parameter
(
"0"
,
migraphx
::
shape
{
migraphx
::
shape
::
float_type
,
{
1
,
10
,
3
,
3
}});
std
::
size_t
num_axes
=
4
;
migraphx
::
op
::
slice
op
;
op
.
starts
=
{
0
,
1
,
1
,
0
};
op
.
ends
=
{
1
,
3
,
3
,
10
};
op
.
axes
=
std
::
vector
<
int64_t
>
(
num_axes
);
std
::
iota
(
op
.
axes
.
begin
(),
op
.
axes
.
end
(),
0
);
// add literals for starts, ends, and strides in tf (NHWC format)
mm
->
add_literal
(
migraphx
::
shape
{
migraphx
::
shape
::
int32_type
,
{
4
}},
std
::
vector
<
int
>
{
0
,
1
,
1
,
0
});
...
...
@@ -1011,7 +997,10 @@ TEST_CASE(stridedslice_masks_test)
auto
l1
=
mm
->
add_instruction
(
migraphx
::
make_op
(
"transpose"
,
{{
"permutation"
,
{
0
,
2
,
3
,
1
}}}),
l0
);
auto
l2
=
mm
->
add_instruction
(
op
,
l1
);
auto
l2
=
mm
->
add_instruction
(
migraphx
::
make_op
(
"slice"
,
{{
"starts"
,
{
0
,
1
,
1
,
0
}},
{
"ends"
,
{
1
,
3
,
3
,
10
}},
{
"axes"
,
{
0
,
1
,
2
,
3
}}}),
l1
);
auto
l3
=
mm
->
add_instruction
(
migraphx
::
make_op
(
"transpose"
,
{{
"permutation"
,
{
0
,
3
,
1
,
2
}}}),
l2
);
mm
->
add_return
({
l3
});
...
...
test/verify/CMakeLists.txt
View file @
ac04f3cc
...
...
@@ -25,14 +25,14 @@
file
(
GLOB VERIFY_TESTS CONFIGURE_DEPENDS *.cpp
)
add_executable
(
test_verify
${
VERIFY_TESTS
}
)
add_dependencies
(
test
s
test_verify
)
add_dependencies
(
check
test_verify
)
rocm_mark_as_
test
(
test_verify
)
rocm_install_test
(
TARGETS
test_verify
)
target_link_libraries
(
test_verify migraphx migraphx_all_targets
)
target_include_directories
(
test_verify PUBLIC ../include
)
rocm_clang_tidy_check
(
test_verify
)
foreach
(
SECTION general rnn
)
add_test
_command
(
test_verify_
${
SECTION
}
test_verify
${
SECTION
}
)
rocm_
add_test
(
NAME
test_verify_
${
SECTION
}
COMMAND
test_verify
${
SECTION
}
)
set_tests_properties
(
test_verify_
${
SECTION
}
PROPERTIES
COST 100
)
...
...
test/verify/ck_gemm_softmax_gemm.cpp
0 → 100644
View file @
ac04f3cc
/*
* The MIT License (MIT)
*
* Copyright (c) 2015-2022 Advanced Micro Devices, Inc. All rights reserved.
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
* in the Software without restriction, including without limitation the rights
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
* copies of the Software, and to permit persons to whom the Software is
* furnished to do so, subject to the following conditions:
*
* The above copyright notice and this permission notice shall be included in
* all copies or substantial portions of the Software.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
* THE SOFTWARE.
*/
#include "verify_program.hpp"
#include <migraphx/program.hpp>
#include <migraphx/generate.hpp>
#include <migraphx/make_op.hpp>
struct
ck_gemm_softmax_gemm
:
verify_program
<
ck_gemm_softmax_gemm
>
{
migraphx
::
program
create_program
()
const
{
migraphx
::
program
p
;
auto
*
mm
=
p
.
get_main_module
();
migraphx
::
shape
m1_shape
{
migraphx
::
shape
::
half_type
,
{
1
,
12
,
256
,
256
}};
migraphx
::
shape
m2_shape
{
migraphx
::
shape
::
half_type
,
{
1
,
12
,
256
,
256
}};
auto
m2_elements
=
m2_shape
.
elements
();
auto
a
=
mm
->
add_parameter
(
"1"
,
m1_shape
);
auto
b
=
mm
->
add_parameter
(
"2"
,
m1_shape
);
auto
b1
=
mm
->
add_parameter
(
"3"
,
m1_shape
);
std
::
vector
<
float
>
eights
(
m2_elements
,
0.125
);
auto
eight
=
mm
->
add_literal
(
migraphx
::
literal
{
m2_shape
,
eights
});
std
::
vector
<
float
>
zeros
(
m2_elements
,
0
);
auto
zero
=
mm
->
add_literal
(
migraphx
::
literal
{
m2_shape
,
zeros
});
b
=
mm
->
add_instruction
(
migraphx
::
make_op
(
"transpose"
,
{{
"permutation"
,
{
0
,
1
,
3
,
2
}}}),
b
);
auto
gemm1
=
mm
->
add_instruction
(
migraphx
::
make_op
(
"dot"
),
a
,
b
);
auto
scale
=
mm
->
add_instruction
(
migraphx
::
make_op
(
"mul"
),
gemm1
,
eight
);
auto
bias
=
mm
->
add_instruction
(
migraphx
::
make_op
(
"add"
),
scale
,
zero
);
auto
softmax
=
mm
->
add_instruction
(
migraphx
::
make_op
(
"softmax"
,
{{
"axis"
,
-
1
}}),
bias
);
mm
->
add_instruction
(
migraphx
::
make_op
(
"dot"
),
softmax
,
b1
);
return
p
;
}
};
test/verify/
test_round
.cpp
→
test/verify/
gemm_2args_mm_8
.cpp
View file @
ac04f3cc
...
...
@@ -27,16 +27,21 @@
#include <migraphx/generate.hpp>
#include <migraphx/make_op.hpp>
struct
test_round
:
verify_program
<
test_round
>
struct
gemm_2args_mm_8
:
verify_program
<
gemm_2args_mm_8
>
{
migraphx
::
program
create_program
()
const
{
migraphx
::
program
p
;
auto
*
mm
=
p
.
get_main_module
();
migraphx
::
shape
a_shape
{
migraphx
::
shape
::
float_type
,
{
2
,
128
,
32
},
{
4096
,
1
,
128
}};
migraphx
::
shape
b_shape
{
migraphx
::
shape
::
float_type
,
{
32
,
32
}};
auto
a
=
mm
->
add_parameter
(
"a"
,
a_shape
);
auto
b
=
mm
->
add_parameter
(
"b"
,
b_shape
);
auto
bb
=
mm
->
add_instruction
(
migraphx
::
make_op
(
"multibroadcast"
,
{{
"out_lens"
,
{
2
,
32
,
32
}}}),
b
);
mm
->
add_instruction
(
migraphx
::
make_op
(
"dot"
),
a
,
bb
);
migraphx
::
shape
s
{
migraphx
::
shape
::
float_type
,
{
2
,
3
,
4
,
6
}};
auto
param
=
mm
->
add_parameter
(
"x"
,
s
);
mm
->
add_instruction
(
migraphx
::
make_op
(
"round"
),
param
);
return
p
;
}
;
}
};
test/verify/gemm_literal.cpp
View file @
ac04f3cc
...
...
@@ -25,7 +25,7 @@
#include "verify_program.hpp"
#include <migraphx/program.hpp>
#include <migraphx/generate.hpp>
#include <migraphx/
operators
.hpp>
#include <migraphx/
make_op
.hpp>
struct
gemm_literal
:
verify_program
<
gemm_literal
>
{
...
...
@@ -38,7 +38,7 @@ struct gemm_literal : verify_program<gemm_literal>
auto
a
=
mm
->
add_literal
(
migraphx
::
generate_literal
(
a_shape
));
auto
b
=
mm
->
add_parameter
(
"b"
,
b_shape
);
mm
->
add_instruction
(
migraphx
::
op
::
dot
{}
,
a
,
b
);
mm
->
add_instruction
(
migraphx
::
make_op
(
"
dot
"
)
,
a
,
b
);
return
p
;
}
...
...
test/verify/quant_conv_
default_mode
.cpp
→
test/verify/quant_conv_
1
.cpp
View file @
ac04f3cc
...
...
@@ -27,7 +27,7 @@
#include <migraphx/generate.hpp>
#include <migraphx/op/quant_convolution.hpp>
struct
quant_conv_
default_mode
:
verify_program
<
quant_conv_
default_mode
>
struct
quant_conv_
1
:
verify_program
<
quant_conv_
1
>
{
migraphx
::
program
create_program
()
const
{
...
...
test/verify/quant_conv_
int8x4_default
.cpp
→
test/verify/quant_conv_
2
.cpp
View file @
ac04f3cc
...
...
@@ -27,7 +27,7 @@
#include <migraphx/generate.hpp>
#include <migraphx/op/quant_convolution.hpp>
struct
quant_conv_
int8x4_default
:
verify_program
<
quant_conv_
int8x4_default
>
struct
quant_conv_
2
:
verify_program
<
quant_conv_
2
>
{
migraphx
::
program
create_program
()
const
{
...
...
test/verify/run_verify.cpp
View file @
ac04f3cc
...
...
@@ -44,8 +44,7 @@ MIGRAPHX_DECLARE_ENV_VAR(MIGRAPHX_DUMP_TEST)
// An improved async, that doesn't block
template
<
class
Function
>
std
::
future
<
typename
std
::
result_of
<
Function
()
>::
type
>
detach_async
(
Function
&&
f
,
bool
parallel
=
true
)
std
::
future
<
std
::
invoke_result_t
<
Function
>>
detach_async
(
Function
&&
f
,
bool
parallel
=
true
)
{
if
(
parallel
)
{
...
...
@@ -251,7 +250,8 @@ void run_verify::verify(const std::string& name,
std
::
size_t
num
=
gold
.
size
();
for
(
std
::
size_t
i
=
0
;
((
i
<
num
)
and
passed
);
++
i
)
{
passed
&=
migraphx
::
verify_args
(
tname
,
gold
[
i
],
result
[
i
]);
passed
&=
migraphx
::
verify_args_with_tolerance
(
tname
,
result
[
i
],
migraphx
::
verify
::
expected
{
gold
[
i
]});
}
if
(
not
passed
or
migraphx
::
enabled
(
MIGRAPHX_TRACE_TEST
{}))
...
...
test/verify/test_arg_ops.cpp
View file @
ac04f3cc
/*
* The MIT License (MIT)
*
* Copyright (c) 2015-202
2
Advanced Micro Devices, Inc. All rights reserved.
* Copyright (c) 2015-202
3
Advanced Micro Devices, Inc. All rights reserved.
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
...
...
@@ -29,8 +29,8 @@
#include <migraphx/op/argmax.hpp>
#include <migraphx/op/argmin.hpp>
template
<
class
T
,
int
Axis
,
int
NonStdShape
>
struct
test_arg_ops
:
verify_program
<
test_arg_ops
<
T
,
Axis
,
NonStdShape
>>
template
<
class
T
,
int
Axis
,
bool
LastIndex
,
int
NonStdShape
>
struct
test_arg_ops
:
verify_program
<
test_arg_ops
<
T
,
Axis
,
LastIndex
,
NonStdShape
>>
{
migraphx
::
program
create_program
()
const
{
...
...
@@ -54,63 +54,111 @@ struct test_arg_ops : verify_program<test_arg_ops<T, Axis, NonStdShape>>
break
;
default:
break
;
}
mm
->
add_instruction
(
T
{
Axis
},
param
);
mm
->
add_instruction
(
T
{
Axis
,
LastIndex
},
param
);
return
p
;
}
};
// transpose argmax tests
template
struct
test_arg_ops
<
migraphx
::
op
::
argmax
,
0
,
0
>;
template
struct
test_arg_ops
<
migraphx
::
op
::
argmax
,
1
,
0
>;
template
struct
test_arg_ops
<
migraphx
::
op
::
argmax
,
2
,
0
>;
template
struct
test_arg_ops
<
migraphx
::
op
::
argmax
,
3
,
0
>;
template
struct
test_arg_ops
<
migraphx
::
op
::
argmax
,
-
1
,
0
>;
template
struct
test_arg_ops
<
migraphx
::
op
::
argmax
,
-
2
,
0
>;
template
struct
test_arg_ops
<
migraphx
::
op
::
argmax
,
0
,
true
,
0
>;
template
struct
test_arg_ops
<
migraphx
::
op
::
argmax
,
0
,
false
,
0
>;
template
struct
test_arg_ops
<
migraphx
::
op
::
argmax
,
1
,
true
,
0
>;
template
struct
test_arg_ops
<
migraphx
::
op
::
argmax
,
1
,
false
,
0
>;
template
struct
test_arg_ops
<
migraphx
::
op
::
argmax
,
2
,
true
,
0
>;
template
struct
test_arg_ops
<
migraphx
::
op
::
argmax
,
2
,
false
,
0
>;
template
struct
test_arg_ops
<
migraphx
::
op
::
argmax
,
3
,
true
,
0
>;
template
struct
test_arg_ops
<
migraphx
::
op
::
argmax
,
3
,
false
,
0
>;
template
struct
test_arg_ops
<
migraphx
::
op
::
argmax
,
-
1
,
true
,
0
>;
template
struct
test_arg_ops
<
migraphx
::
op
::
argmax
,
-
1
,
false
,
0
>;
template
struct
test_arg_ops
<
migraphx
::
op
::
argmax
,
-
2
,
true
,
0
>;
template
struct
test_arg_ops
<
migraphx
::
op
::
argmax
,
-
2
,
false
,
0
>;
// transpose argmin tests
template
struct
test_arg_ops
<
migraphx
::
op
::
argmin
,
0
,
0
>;
template
struct
test_arg_ops
<
migraphx
::
op
::
argmin
,
1
,
0
>;
template
struct
test_arg_ops
<
migraphx
::
op
::
argmin
,
2
,
0
>;
template
struct
test_arg_ops
<
migraphx
::
op
::
argmin
,
3
,
0
>;
template
struct
test_arg_ops
<
migraphx
::
op
::
argmin
,
-
3
,
0
>;
template
struct
test_arg_ops
<
migraphx
::
op
::
argmin
,
-
4
,
0
>;
template
struct
test_arg_ops
<
migraphx
::
op
::
argmin
,
0
,
true
,
0
>;
template
struct
test_arg_ops
<
migraphx
::
op
::
argmin
,
0
,
false
,
0
>;
template
struct
test_arg_ops
<
migraphx
::
op
::
argmin
,
1
,
true
,
0
>;
template
struct
test_arg_ops
<
migraphx
::
op
::
argmin
,
1
,
false
,
0
>;
template
struct
test_arg_ops
<
migraphx
::
op
::
argmin
,
2
,
true
,
0
>;
template
struct
test_arg_ops
<
migraphx
::
op
::
argmin
,
2
,
false
,
0
>;
template
struct
test_arg_ops
<
migraphx
::
op
::
argmin
,
3
,
true
,
0
>;
template
struct
test_arg_ops
<
migraphx
::
op
::
argmin
,
3
,
false
,
0
>;
template
struct
test_arg_ops
<
migraphx
::
op
::
argmin
,
-
3
,
true
,
0
>;
template
struct
test_arg_ops
<
migraphx
::
op
::
argmin
,
-
3
,
false
,
0
>;
template
struct
test_arg_ops
<
migraphx
::
op
::
argmin
,
-
4
,
true
,
0
>;
template
struct
test_arg_ops
<
migraphx
::
op
::
argmin
,
-
4
,
false
,
0
>;
// broadcast argmax tests
template
struct
test_arg_ops
<
migraphx
::
op
::
argmax
,
0
,
1
>;
template
struct
test_arg_ops
<
migraphx
::
op
::
argmax
,
1
,
1
>;
template
struct
test_arg_ops
<
migraphx
::
op
::
argmax
,
2
,
1
>;
template
struct
test_arg_ops
<
migraphx
::
op
::
argmax
,
3
,
1
>;
template
struct
test_arg_ops
<
migraphx
::
op
::
argmax
,
-
1
,
1
>;
template
struct
test_arg_ops
<
migraphx
::
op
::
argmax
,
-
2
,
1
>;
template
struct
test_arg_ops
<
migraphx
::
op
::
argmax
,
0
,
true
,
1
>;
template
struct
test_arg_ops
<
migraphx
::
op
::
argmax
,
0
,
false
,
1
>;
template
struct
test_arg_ops
<
migraphx
::
op
::
argmax
,
1
,
true
,
1
>;
template
struct
test_arg_ops
<
migraphx
::
op
::
argmax
,
1
,
false
,
1
>;
template
struct
test_arg_ops
<
migraphx
::
op
::
argmax
,
2
,
true
,
1
>;
template
struct
test_arg_ops
<
migraphx
::
op
::
argmax
,
2
,
false
,
1
>;
template
struct
test_arg_ops
<
migraphx
::
op
::
argmax
,
3
,
true
,
1
>;
template
struct
test_arg_ops
<
migraphx
::
op
::
argmax
,
3
,
false
,
1
>;
template
struct
test_arg_ops
<
migraphx
::
op
::
argmax
,
-
1
,
true
,
1
>;
template
struct
test_arg_ops
<
migraphx
::
op
::
argmax
,
-
1
,
false
,
1
>;
template
struct
test_arg_ops
<
migraphx
::
op
::
argmax
,
-
2
,
true
,
1
>;
template
struct
test_arg_ops
<
migraphx
::
op
::
argmax
,
-
2
,
false
,
1
>;
// broadcast argmin tests
template
struct
test_arg_ops
<
migraphx
::
op
::
argmin
,
0
,
1
>;
template
struct
test_arg_ops
<
migraphx
::
op
::
argmin
,
1
,
1
>;
template
struct
test_arg_ops
<
migraphx
::
op
::
argmin
,
2
,
1
>;
template
struct
test_arg_ops
<
migraphx
::
op
::
argmin
,
3
,
1
>;
template
struct
test_arg_ops
<
migraphx
::
op
::
argmin
,
-
3
,
1
>;
template
struct
test_arg_ops
<
migraphx
::
op
::
argmin
,
-
4
,
1
>;
template
struct
test_arg_ops
<
migraphx
::
op
::
argmin
,
0
,
true
,
1
>;
template
struct
test_arg_ops
<
migraphx
::
op
::
argmin
,
0
,
false
,
1
>;
template
struct
test_arg_ops
<
migraphx
::
op
::
argmin
,
1
,
true
,
1
>;
template
struct
test_arg_ops
<
migraphx
::
op
::
argmin
,
1
,
false
,
1
>;
template
struct
test_arg_ops
<
migraphx
::
op
::
argmin
,
2
,
true
,
1
>;
template
struct
test_arg_ops
<
migraphx
::
op
::
argmin
,
2
,
false
,
1
>;
template
struct
test_arg_ops
<
migraphx
::
op
::
argmin
,
3
,
true
,
1
>;
template
struct
test_arg_ops
<
migraphx
::
op
::
argmin
,
3
,
false
,
1
>;
template
struct
test_arg_ops
<
migraphx
::
op
::
argmin
,
-
3
,
true
,
1
>;
template
struct
test_arg_ops
<
migraphx
::
op
::
argmin
,
-
3
,
false
,
1
>;
template
struct
test_arg_ops
<
migraphx
::
op
::
argmin
,
-
4
,
true
,
1
>;
template
struct
test_arg_ops
<
migraphx
::
op
::
argmin
,
-
4
,
false
,
1
>;
// slice argmax tests
template
struct
test_arg_ops
<
migraphx
::
op
::
argmax
,
0
,
2
>;
template
struct
test_arg_ops
<
migraphx
::
op
::
argmax
,
1
,
2
>;
template
struct
test_arg_ops
<
migraphx
::
op
::
argmax
,
2
,
2
>;
template
struct
test_arg_ops
<
migraphx
::
op
::
argmax
,
3
,
2
>;
template
struct
test_arg_ops
<
migraphx
::
op
::
argmax
,
-
1
,
2
>;
template
struct
test_arg_ops
<
migraphx
::
op
::
argmax
,
-
2
,
2
>;
template
struct
test_arg_ops
<
migraphx
::
op
::
argmax
,
0
,
true
,
2
>;
template
struct
test_arg_ops
<
migraphx
::
op
::
argmax
,
0
,
false
,
2
>;
template
struct
test_arg_ops
<
migraphx
::
op
::
argmax
,
1
,
true
,
2
>;
template
struct
test_arg_ops
<
migraphx
::
op
::
argmax
,
1
,
false
,
2
>;
template
struct
test_arg_ops
<
migraphx
::
op
::
argmax
,
2
,
true
,
2
>;
template
struct
test_arg_ops
<
migraphx
::
op
::
argmax
,
2
,
false
,
2
>;
template
struct
test_arg_ops
<
migraphx
::
op
::
argmax
,
3
,
true
,
2
>;
template
struct
test_arg_ops
<
migraphx
::
op
::
argmax
,
3
,
false
,
2
>;
template
struct
test_arg_ops
<
migraphx
::
op
::
argmax
,
-
1
,
true
,
2
>;
template
struct
test_arg_ops
<
migraphx
::
op
::
argmax
,
-
1
,
false
,
2
>;
template
struct
test_arg_ops
<
migraphx
::
op
::
argmax
,
-
2
,
true
,
2
>;
template
struct
test_arg_ops
<
migraphx
::
op
::
argmax
,
-
2
,
false
,
2
>;
// slice argmin tests
template
struct
test_arg_ops
<
migraphx
::
op
::
argmin
,
0
,
2
>;
template
struct
test_arg_ops
<
migraphx
::
op
::
argmin
,
1
,
2
>;
template
struct
test_arg_ops
<
migraphx
::
op
::
argmin
,
2
,
2
>;
template
struct
test_arg_ops
<
migraphx
::
op
::
argmin
,
3
,
2
>;
template
struct
test_arg_ops
<
migraphx
::
op
::
argmin
,
-
3
,
2
>;
template
struct
test_arg_ops
<
migraphx
::
op
::
argmin
,
-
4
,
2
>;
template
struct
test_arg_ops
<
migraphx
::
op
::
argmin
,
0
,
true
,
2
>;
template
struct
test_arg_ops
<
migraphx
::
op
::
argmin
,
0
,
false
,
2
>;
template
struct
test_arg_ops
<
migraphx
::
op
::
argmin
,
1
,
true
,
2
>;
template
struct
test_arg_ops
<
migraphx
::
op
::
argmin
,
1
,
false
,
2
>;
template
struct
test_arg_ops
<
migraphx
::
op
::
argmin
,
2
,
true
,
2
>;
template
struct
test_arg_ops
<
migraphx
::
op
::
argmin
,
2
,
false
,
2
>;
template
struct
test_arg_ops
<
migraphx
::
op
::
argmin
,
3
,
true
,
2
>;
template
struct
test_arg_ops
<
migraphx
::
op
::
argmin
,
3
,
false
,
2
>;
template
struct
test_arg_ops
<
migraphx
::
op
::
argmin
,
-
3
,
true
,
2
>;
template
struct
test_arg_ops
<
migraphx
::
op
::
argmin
,
-
3
,
false
,
2
>;
template
struct
test_arg_ops
<
migraphx
::
op
::
argmin
,
-
4
,
true
,
2
>;
template
struct
test_arg_ops
<
migraphx
::
op
::
argmin
,
-
4
,
false
,
2
>;
// default case, standard shape argmax tests
template
struct
test_arg_ops
<
migraphx
::
op
::
argmax
,
0
,
3
>;
template
struct
test_arg_ops
<
migraphx
::
op
::
argmax
,
1
,
3
>;
template
struct
test_arg_ops
<
migraphx
::
op
::
argmax
,
2
,
3
>;
template
struct
test_arg_ops
<
migraphx
::
op
::
argmax
,
3
,
3
>;
template
struct
test_arg_ops
<
migraphx
::
op
::
argmax
,
-
1
,
3
>;
template
struct
test_arg_ops
<
migraphx
::
op
::
argmax
,
-
2
,
3
>;
template
struct
test_arg_ops
<
migraphx
::
op
::
argmax
,
0
,
true
,
3
>;
template
struct
test_arg_ops
<
migraphx
::
op
::
argmax
,
0
,
false
,
3
>;
template
struct
test_arg_ops
<
migraphx
::
op
::
argmax
,
1
,
true
,
3
>;
template
struct
test_arg_ops
<
migraphx
::
op
::
argmax
,
1
,
false
,
3
>;
template
struct
test_arg_ops
<
migraphx
::
op
::
argmax
,
2
,
true
,
3
>;
template
struct
test_arg_ops
<
migraphx
::
op
::
argmax
,
2
,
false
,
3
>;
template
struct
test_arg_ops
<
migraphx
::
op
::
argmax
,
3
,
true
,
3
>;
template
struct
test_arg_ops
<
migraphx
::
op
::
argmax
,
3
,
false
,
3
>;
template
struct
test_arg_ops
<
migraphx
::
op
::
argmax
,
-
1
,
true
,
3
>;
template
struct
test_arg_ops
<
migraphx
::
op
::
argmax
,
-
1
,
false
,
3
>;
template
struct
test_arg_ops
<
migraphx
::
op
::
argmax
,
-
2
,
true
,
3
>;
template
struct
test_arg_ops
<
migraphx
::
op
::
argmax
,
-
2
,
false
,
3
>;
// default case, standard shape argmin tests
template
struct
test_arg_ops
<
migraphx
::
op
::
argmin
,
0
,
3
>;
template
struct
test_arg_ops
<
migraphx
::
op
::
argmin
,
1
,
3
>;
template
struct
test_arg_ops
<
migraphx
::
op
::
argmin
,
2
,
3
>;
template
struct
test_arg_ops
<
migraphx
::
op
::
argmin
,
3
,
3
>;
template
struct
test_arg_ops
<
migraphx
::
op
::
argmin
,
-
3
,
3
>;
template
struct
test_arg_ops
<
migraphx
::
op
::
argmin
,
-
4
,
3
>;
template
struct
test_arg_ops
<
migraphx
::
op
::
argmin
,
0
,
true
,
3
>;
template
struct
test_arg_ops
<
migraphx
::
op
::
argmin
,
0
,
false
,
3
>;
template
struct
test_arg_ops
<
migraphx
::
op
::
argmin
,
1
,
true
,
3
>;
template
struct
test_arg_ops
<
migraphx
::
op
::
argmin
,
1
,
false
,
3
>;
template
struct
test_arg_ops
<
migraphx
::
op
::
argmin
,
2
,
true
,
3
>;
template
struct
test_arg_ops
<
migraphx
::
op
::
argmin
,
2
,
false
,
3
>;
template
struct
test_arg_ops
<
migraphx
::
op
::
argmin
,
3
,
true
,
3
>;
template
struct
test_arg_ops
<
migraphx
::
op
::
argmin
,
3
,
false
,
3
>;
template
struct
test_arg_ops
<
migraphx
::
op
::
argmin
,
-
3
,
true
,
3
>;
template
struct
test_arg_ops
<
migraphx
::
op
::
argmin
,
-
3
,
false
,
3
>;
template
struct
test_arg_ops
<
migraphx
::
op
::
argmin
,
-
4
,
true
,
3
>;
template
struct
test_arg_ops
<
migraphx
::
op
::
argmin
,
-
4
,
false
,
3
>;
test/verify/test_flatten_dot_relu.cpp
0 → 100644
View file @
ac04f3cc
/*
* The MIT License (MIT)
*
* Copyright (c) 2015-2022 Advanced Micro Devices, Inc. All rights reserved.
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
* in the Software without restriction, including without limitation the rights
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
* copies of the Software, and to permit persons to whom the Software is
* furnished to do so, subject to the following conditions:
*
* The above copyright notice and this permission notice shall be included in
* all copies or substantial portions of the Software.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
* THE SOFTWARE.
*/
#include "verify_program.hpp"
#include <migraphx/program.hpp>
#include <migraphx/generate.hpp>
#include <migraphx/make_op.hpp>
struct
test_flatten_dot_relu
:
verify_program
<
test_flatten_dot_relu
>
{
migraphx
::
program
create_program
()
const
{
migraphx
::
program
p
;
auto
*
mm
=
p
.
get_main_module
();
auto
a
=
mm
->
add_parameter
(
"a"
,
migraphx
::
shape
{
migraphx
::
shape
::
float_type
,
{
1
,
2
,
3
,
3
,
5
}});
a
=
mm
->
add_instruction
(
migraphx
::
make_op
(
"flatten"
,
{{
"axis"
,
3
}}),
a
);
auto
b
=
mm
->
add_parameter
(
"b"
,
migraphx
::
shape
{
migraphx
::
shape
::
float_type
,
{
1
,
5
,
3
,
3
,
1
}});
b
=
mm
->
add_instruction
(
migraphx
::
make_op
(
"flatten"
,
{{
"axis"
,
3
}}),
b
);
auto
dot
=
mm
->
add_instruction
(
migraphx
::
make_op
(
"dot"
),
a
,
b
);
mm
->
add_instruction
(
migraphx
::
make_op
(
"relu"
),
dot
);
return
p
;
}
};
test/verify/test_isinf.cpp
0 → 100644
View file @
ac04f3cc
/*
* The MIT License (MIT)
*
* Copyright (c) 2015-2023 Advanced Micro Devices, Inc. All rights reserved.
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
* in the Software without restriction, including without limitation the rights
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
* copies of the Software, and to permit persons to whom the Software is
* furnished to do so, subject to the following conditions:
*
* The above copyright notice and this permission notice shall be included in
* all copies or substantial portions of the Software.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
* THE SOFTWARE.
*/
#include <limits>
#include "verify_program.hpp"
#include <migraphx/program.hpp>
#include <migraphx/generate.hpp>
#include <migraphx/make_op.hpp>
template
<
class
T
>
struct
test_isinf
:
verify_program
<
test_isinf
<
T
>>
{
migraphx
::
program
create_program
()
const
{
migraphx
::
program
p
;
auto
*
mm
=
p
.
get_main_module
();
auto
max
=
std
::
numeric_limits
<
T
>::
max
();
auto
min
=
std
::
numeric_limits
<
T
>::
min
();
auto
inf
=
std
::
numeric_limits
<
T
>::
infinity
();
auto
nan
=
std
::
numeric_limits
<
T
>::
quiet_NaN
();
auto
x
=
mm
->
add_parameter
(
"x"
,
migraphx
::
shape
{
migraphx
::
shape
::
get_type
<
T
>
(),
{
5
}});
std
::
vector
<
T
>
data0
{
inf
,
-
inf
,
max
,
min
,
nan
};
migraphx
::
shape
s1
{
migraphx
::
shape
::
get_type
<
T
>
(),
{
5
}};
auto
l0
=
mm
->
add_literal
(
migraphx
::
literal
{
s1
,
data0
});
x
=
mm
->
add_instruction
(
migraphx
::
make_op
(
"concat"
,
{{
"axis"
,
0
}}),
x
,
l0
);
mm
->
add_instruction
(
migraphx
::
make_op
(
"isinf"
),
x
);
return
p
;
}
};
template
struct
test_isinf
<
migraphx
::
half
>;
template
struct
test_isinf
<
float
>;
test/verify/test_isinf_broadcast.cpp
0 → 100644
View file @
ac04f3cc
/*
* The MIT License (MIT)
*
* Copyright (c) 2015-2023 Advanced Micro Devices, Inc. All rights reserved.
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
* in the Software without restriction, including without limitation the rights
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
* copies of the Software, and to permit persons to whom the Software is
* furnished to do so, subject to the following conditions:
*
* The above copyright notice and this permission notice shall be included in
* all copies or substantial portions of the Software.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
* THE SOFTWARE.
*/
#include <limits>
#include "verify_program.hpp"
#include <migraphx/program.hpp>
#include <migraphx/generate.hpp>
#include <migraphx/make_op.hpp>
struct
test_isinf_broadcast
:
verify_program
<
test_isinf_broadcast
>
{
migraphx
::
program
create_program
()
const
{
migraphx
::
program
p
;
auto
*
mm
=
p
.
get_main_module
();
auto
x
=
mm
->
add_parameter
(
"x"
,
migraphx
::
shape
{
migraphx
::
shape
::
float_type
,
{
2
}});
auto
s0
=
migraphx
::
shape
{
migraphx
::
shape
::
float_type
,
{
2
,
2
}};
x
=
mm
->
add_instruction
(
migraphx
::
make_op
(
"broadcast"
,
{{
"axis"
,
0
},
{
"out_lens"
,
s0
.
lens
()}}),
x
);
auto
inf
=
std
::
numeric_limits
<
float
>::
infinity
();
std
::
vector
<
float
>
data0
{
-
inf
,
inf
};
migraphx
::
shape
s1
{
migraphx
::
shape
::
float_type
,
{
1
,
2
}};
auto
l0
=
mm
->
add_literal
(
migraphx
::
literal
{
s1
,
data0
});
x
=
mm
->
add_instruction
(
migraphx
::
make_op
(
"concat"
,
{{
"axis"
,
0
}}),
x
,
l0
);
mm
->
add_instruction
(
migraphx
::
make_op
(
"isinf"
),
x
);
return
p
;
}
};
test/verify/test_layernorm.cpp
View file @
ac04f3cc
...
...
@@ -49,7 +49,8 @@ migraphx::instruction_ref add_layernorm(migraphx::module& m,
auto
pow
=
m
.
add_instruction
(
migraphx
::
make_op
(
"pow"
),
sub
,
exponent_mbcast
);
auto
var
=
m
.
add_instruction
(
migraphx
::
make_op
(
"reduce_mean"
,
{{
"axes"
,
{
2
}}}),
pow
);
auto
epsilon_mbcast
=
m
.
add_instruction
(
migraphx
::
make_op
(
"multibroadcast"
,
{{
"out_lens"
,
{
1
,
dims
.
at
(
1
),
1
}}}),
epsilon
);
migraphx
::
make_op
(
"multibroadcast"
,
{{
"out_lens"
,
{
dims
.
at
(
0
),
dims
.
at
(
1
),
1
}}}),
epsilon
);
auto
add_epsilon
=
m
.
add_instruction
(
migraphx
::
make_op
(
"add"
),
var
,
epsilon_mbcast
);
auto
sqrt
=
m
.
add_instruction
(
migraphx
::
make_op
(
"sqrt"
),
add_epsilon
);
auto
sqrt_mbcast
=
...
...
@@ -57,7 +58,8 @@ migraphx::instruction_ref add_layernorm(migraphx::module& m,
auto
div
=
m
.
add_instruction
(
migraphx
::
make_op
(
"div"
),
sub
,
sqrt_mbcast
);
auto
scale_mbcast
=
m
.
add_instruction
(
migraphx
::
make_op
(
"multibroadcast"
,
{{
"out_lens"
,
dims
}}),
scale
);
auto
mul
=
m
.
add_instruction
(
migraphx
::
make_op
(
"mul"
),
scale_mbcast
,
div
);
auto
mul
=
m
.
add_instruction
(
migraphx
::
make_op
(
"mul"
),
div
,
scale_mbcast
);
auto
bias_mbcast
=
m
.
add_instruction
(
migraphx
::
make_op
(
"multibroadcast"
,
{{
"out_lens"
,
dims
}}),
bias
);
return
m
.
add_instruction
(
migraphx
::
make_op
(
"add"
),
mul
,
bias_mbcast
);
...
...
@@ -161,3 +163,21 @@ struct test_layernorm_triadd_large : verify_program<test_layernorm_triadd_large>
return
p
;
}
};
struct
test_add_layernorm_add_gemm_nonstd
:
verify_program
<
test_add_layernorm_add_gemm_nonstd
>
{
migraphx
::
program
create_program
()
const
{
migraphx
::
program
p
;
auto
*
mm
=
p
.
get_main_module
();
auto
s
=
migraphx
::
shape
::
from_permutation
(
migraphx
::
shape
::
float_type
,
{
8
,
1
,
16
},
{
1
,
2
,
0
});
auto
x
=
mm
->
add_parameter
(
"x"
,
s
);
auto
y
=
mm
->
add_parameter
(
"y"
,
s
);
auto
z
=
mm
->
add_parameter
(
"z"
,
migraphx
::
shape
{
migraphx
::
shape
::
float_type
,
{
8
,
16
,
64
}});
auto
add
=
mm
->
add_instruction
(
migraphx
::
make_op
(
"add"
),
x
,
y
);
auto
layernorm_ins
=
add_layernorm
(
*
mm
,
add
,
s
.
lens
());
mm
->
add_instruction
(
migraphx
::
make_op
(
"dot"
),
layernorm_ins
,
z
);
return
p
;
}
};
test/verify/test_nearbyint.cpp
0 → 100644
View file @
ac04f3cc
/*
* The MIT License (MIT)
*
* Copyright (c) 2015-2023 Advanced Micro Devices, Inc. All rights reserved.
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
* in the Software without restriction, including without limitation the rights
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
* copies of the Software, and to permit persons to whom the Software is
* furnished to do so, subject to the following conditions:
*
* The above copyright notice and this permission notice shall be included in
* all copies or substantial portions of the Software.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
* THE SOFTWARE.
*/
#include "verify_program.hpp"
#include <migraphx/program.hpp>
#include <migraphx/generate.hpp>
#include <migraphx/make_op.hpp>
template
<
class
T
>
struct
test_nearbyint
:
verify_program
<
test_nearbyint
<
T
>>
{
migraphx
::
program
create_program
()
const
{
migraphx
::
program
p
;
std
::
vector
<
float
>
tmp
{
-
4.5
,
-
3.5
,
0.5
,
2.5
,
3.5
};
std
::
vector
<
T
>
data
{
tmp
.
cbegin
(),
tmp
.
cend
()};
migraphx
::
shape
s1
{
migraphx
::
shape
::
get_type
<
T
>
(),
{
5
}};
auto
*
mm
=
p
.
get_main_module
();
auto
l0
=
mm
->
add_literal
(
migraphx
::
literal
{
s1
,
data
});
mm
->
add_instruction
(
migraphx
::
make_op
(
"isinf"
),
l0
);
return
p
;
};
};
template
struct
test_nearbyint
<
migraphx
::
half
>;
template
struct
test_nearbyint
<
float
>;
test/verify/test_reduce_add.cpp
0 → 100644
View file @
ac04f3cc
/*
* The MIT License (MIT)
*
* Copyright (c) 2015-2022 Advanced Micro Devices, Inc. All rights reserved.
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
* in the Software without restriction, including without limitation the rights
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
* copies of the Software, and to permit persons to whom the Software is
* furnished to do so, subject to the following conditions:
*
* The above copyright notice and this permission notice shall be included in
* all copies or substantial portions of the Software.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
* THE SOFTWARE.
*/
#include "verify_program.hpp"
#include <migraphx/program.hpp>
#include <migraphx/generate.hpp>
#include <migraphx/make_op.hpp>
#include <migraphx/instruction.hpp>
struct
test_reduce_add
:
verify_program
<
test_reduce_add
>
{
migraphx
::
program
create_program
()
const
{
migraphx
::
program
p
;
auto
*
mm
=
p
.
get_main_module
();
migraphx
::
shape
s
{
migraphx
::
shape
::
float_type
,
{
4
,
1000
,
2
,
2
}};
migraphx
::
shape
bs
{
migraphx
::
shape
::
half_type
,
{
1
,
32
,
128
}};
auto
x
=
mm
->
add_parameter
(
"x"
,
s
);
auto
reduce_mean
=
mm
->
add_instruction
(
migraphx
::
make_op
(
"reduce_mean"
,
{{
"axes"
,
{
2
,
3
}}}),
x
);
auto
reduce_max
=
mm
->
add_instruction
(
migraphx
::
make_op
(
"reduce_max"
,
{{
"axes"
,
{
2
,
3
}}}),
x
);
auto
add
=
mm
->
add_instruction
(
migraphx
::
make_op
(
"add"
),
reduce_mean
,
reduce_max
);
mm
->
add_return
({
add
});
return
p
;
};
};
test/verify/test_reduce_noop_add.cpp
0 → 100644
View file @
ac04f3cc
/*
* The MIT License (MIT)
*
* Copyright (c) 2015-2022 Advanced Micro Devices, Inc. All rights reserved.
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
* in the Software without restriction, including without limitation the rights
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
* copies of the Software, and to permit persons to whom the Software is
* furnished to do so, subject to the following conditions:
*
* The above copyright notice and this permission notice shall be included in
* all copies or substantial portions of the Software.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
* THE SOFTWARE.
*/
#include "verify_program.hpp"
#include <migraphx/program.hpp>
#include <migraphx/generate.hpp>
#include <migraphx/make_op.hpp>
#include <migraphx/instruction.hpp>
struct
test_reduce_noop_add
:
verify_program
<
test_reduce_noop_add
>
{
migraphx
::
program
create_program
()
const
{
migraphx
::
program
p
;
auto
*
mm
=
p
.
get_main_module
();
migraphx
::
shape
s
{
migraphx
::
shape
::
float_type
,
{
4
,
1000
,
1
,
1
}};
migraphx
::
shape
bs
{
migraphx
::
shape
::
half_type
,
{
1
,
32
,
128
}};
auto
x
=
mm
->
add_parameter
(
"x"
,
s
);
auto
reduce_mean
=
mm
->
add_instruction
(
migraphx
::
make_op
(
"reduce_mean"
,
{{
"axes"
,
{
2
,
3
}}}),
x
);
auto
reduce_max
=
mm
->
add_instruction
(
migraphx
::
make_op
(
"reduce_max"
,
{{
"axes"
,
{
2
,
3
}}}),
x
);
auto
add
=
mm
->
add_instruction
(
migraphx
::
make_op
(
"add"
),
reduce_mean
,
reduce_max
);
mm
->
add_return
({
add
});
return
p
;
};
};
Prev
1
…
22
23
24
25
26
27
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment