Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
MIGraphX
Commits
13d14c66
Commit
13d14c66
authored
Oct 24, 2023
by
Brian Pickrell
Browse files
Merge branch 'develop' into dyn_resize_gather
parents
f4e7d9d9
d1abf06f
Changes
420
Hide whitespace changes
Inline
Side-by-side
Showing
20 changed files
with
581 additions
and
206 deletions
+581
-206
test/ref/quant_convolution.cpp
test/ref/quant_convolution.cpp
+38
-38
test/ref/quantizelinear.cpp
test/ref/quantizelinear.cpp
+1
-1
test/ref/random_seed.cpp
test/ref/random_seed.cpp
+52
-0
test/ref/random_uniform.cpp
test/ref/random_uniform.cpp
+176
-0
test/ref/recip.cpp
test/ref/recip.cpp
+3
-3
test/ref/reduce_max.cpp
test/ref/reduce_max.cpp
+2
-2
test/ref/reduce_mean.cpp
test/ref/reduce_mean.cpp
+1
-1
test/ref/reduce_min.cpp
test/ref/reduce_min.cpp
+1
-1
test/ref/reduce_prod.cpp
test/ref/reduce_prod.cpp
+1
-1
test/ref/reduce_sum.cpp
test/ref/reduce_sum.cpp
+1
-1
test/ref/relu.cpp
test/ref/relu.cpp
+3
-3
test/ref/reshape.cpp
test/ref/reshape.cpp
+161
-13
test/ref/reverse.cpp
test/ref/reverse.cpp
+13
-13
test/ref/rnn_ops.cpp
test/ref/rnn_ops.cpp
+97
-98
test/ref/roialign.cpp
test/ref/roialign.cpp
+9
-9
test/ref/round.cpp
test/ref/round.cpp
+3
-3
test/ref/rsqrt.cpp
test/ref/rsqrt.cpp
+3
-3
test/ref/scalar.cpp
test/ref/scalar.cpp
+2
-2
test/ref/scatter.cpp
test/ref/scatter.cpp
+11
-11
test/ref/scatternd_add.cpp
test/ref/scatternd_add.cpp
+3
-3
No files found.
test/ref/quant_convolution.cpp
View file @
13d14c66
...
...
@@ -24,7 +24,7 @@
#include <migraphx/instruction.hpp>
#include <migraphx/literal.hpp>
#include <migraphx/make_op.hpp>
#include <migraphx/
onnx
.hpp>
#include <migraphx/
program
.hpp>
#include <migraphx/register_target.hpp>
#include <migraphx/verify.hpp>
...
...
@@ -47,25 +47,25 @@ TEST_CASE(quant_conv2d_padding_stride_test)
p
.
compile
(
migraphx
::
make_target
(
"ref"
));
auto
result
=
p
.
eval
({}).
back
();
std
::
vector
<
int32_t
>
s
=
{
4521
,
7014
,
7830
,
11952
,
10515
,
16734
,
19737
,
30906
,
13161
,
19542
,
19494
,
28800
,
34707
,
52590
,
54729
,
82746
};
std
::
vector
<
int32_t
>
gold
=
{
4521
,
7014
,
7830
,
11952
,
10515
,
16734
,
19737
,
30906
,
13161
,
19542
,
19494
,
28800
,
34707
,
52590
,
54729
,
82746
};
std
::
vector
<
int32_t
>
results_vector
;
result
.
visit
([
&
](
auto
output
)
{
results_vector
.
assign
(
output
.
begin
(),
output
.
end
());
});
EXPECT
(
migraphx
::
verify
::
verify_range
(
results_vector
,
s
));
EXPECT
(
migraphx
::
verify
::
verify_
rms_
range
(
results_vector
,
gold
));
}
TEST_CASE
(
quant_conv2d_padding_test
)
...
...
@@ -83,8 +83,8 @@ TEST_CASE(quant_conv2d_padding_test)
mm
->
add_instruction
(
migraphx
::
make_op
(
"quant_convolution"
,
{{
"padding"
,
{
1
,
1
}},
{
"stride"
,
{
1
,
1
}}}),
al
,
cl
);
p
.
compile
(
migraphx
::
make_target
(
"ref"
));
auto
result
=
p
.
eval
({}).
back
();
std
::
vector
<
int32_t
>
s
=
{
auto
result
=
p
.
eval
({}).
back
();
std
::
vector
<
int32_t
>
gold
=
{
4521
,
6753
,
7014
,
4635
,
6858
,
10197
,
10548
,
6939
,
7830
,
11601
,
11952
,
7839
,
5007
,
7383
,
7590
,
4953
,
10515
,
15987
,
16734
,
11277
,
16821
,
25506
,
26586
,
17874
,
19737
,
29826
,
30906
,
20718
,
13593
,
20505
,
21198
,
14187
,
13161
,
19281
,
19542
,
12699
,
18522
,
27045
,
27396
,
...
...
@@ -93,7 +93,7 @@ TEST_CASE(quant_conv2d_padding_test)
std
::
vector
<
int32_t
>
results_vector
;
result
.
visit
([
&
](
auto
output
)
{
results_vector
.
assign
(
output
.
begin
(),
output
.
end
());
});
EXPECT
(
migraphx
::
verify
::
verify_range
(
results_vector
,
s
));
EXPECT
(
migraphx
::
verify
::
verify_
rms_
range
(
results_vector
,
gold
));
}
TEST_CASE
(
quant_conv2d_test
)
...
...
@@ -114,24 +114,24 @@ TEST_CASE(quant_conv2d_test)
p
.
compile
(
migraphx
::
make_target
(
"ref"
));
auto
result
=
p
.
eval
({}).
back
();
std
::
vector
<
int32_t
>
s
=
{
10197
,
10548
,
11601
,
11952
,
25506
,
26586
,
29826
,
30906
,
27045
,
27396
,
28449
,
28800
,
77346
,
78426
,
81666
,
82746
};
std
::
vector
<
int32_t
>
gold
=
{
10197
,
10548
,
11601
,
11952
,
25506
,
26586
,
29826
,
30906
,
27045
,
27396
,
28449
,
28800
,
77346
,
78426
,
81666
,
82746
};
std
::
vector
<
int32_t
>
results_vector
;
result
.
visit
([
&
](
auto
output
)
{
results_vector
.
assign
(
output
.
begin
(),
output
.
end
());
});
EXPECT
(
migraphx
::
verify
::
verify_range
(
results_vector
,
s
));
EXPECT
(
migraphx
::
verify
::
verify_
rms_
range
(
results_vector
,
gold
));
}
test/ref/quantizelinear.cpp
View file @
13d14c66
...
...
@@ -24,7 +24,7 @@
#include <migraphx/instruction.hpp>
#include <migraphx/literal.hpp>
#include <migraphx/make_op.hpp>
#include <migraphx/
onnx
.hpp>
#include <migraphx/
program
.hpp>
#include <migraphx/register_target.hpp>
#include <migraphx/verify.hpp>
...
...
test/ref/random_seed.cpp
0 → 100644
View file @
13d14c66
/*
* The MIT License (MIT)
*
* Copyright (c) 2015-2023 Advanced Micro Devices, Inc. All rights reserved.
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
* in the Software without restriction, including without limitation the rights
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
* copies of the Software, and to permit persons to whom the Software is
* furnished to do so, subject to the following conditions:
*
* The above copyright notice and this permission notice shall be included in
* all copies or substantial portions of the Software.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
* THE SOFTWARE.
*/
#include <migraphx/instruction.hpp>
#include <migraphx/make_op.hpp>
#include <migraphx/program.hpp>
#include <migraphx/register_target.hpp>
#include <migraphx/verify.hpp>
#include <random>
#include <test.hpp>
/**
* Reference test for the random_seed operation
*/
TEST_CASE
(
random_seed_test
)
{
migraphx
::
program
p
;
auto
*
mm
=
p
.
get_main_module
();
mm
->
add_instruction
(
migraphx
::
make_op
(
"random_seed"
));
p
.
compile
(
migraphx
::
make_target
(
"ref"
));
auto
result
=
p
.
eval
({}).
back
();
std
::
vector
<
uint64_t
>
result_vec1
(
1
);
result
.
visit
([
&
](
auto
output
)
{
result_vec1
.
assign
(
output
.
begin
(),
output
.
end
());
});
std
::
vector
<
uint64_t
>
result_vec2
(
1
);
// Identical calls should give different seeds every time with 1/(2^64) chance of a repeat.
// We don't analyze for true randomness.
result
=
p
.
eval
({}).
back
();
result
.
visit
([
&
](
auto
output
)
{
result_vec2
.
assign
(
output
.
begin
(),
output
.
end
());
});
EXPECT
(
result_vec1
[
0
]
!=
result_vec2
[
0
]);
}
test/ref/random_uniform.cpp
0 → 100644
View file @
13d14c66
/*
* The MIT License (MIT)
*
* Copyright (c) 2015-2023 Advanced Micro Devices, Inc. All rights reserved.
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
* in the Software without restriction, including without limitation the rights
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
* copies of the Software, and to permit persons to whom the Software is
* furnished to do so, subject to the following conditions:
*
* The above copyright notice and this permission notice shall be included in
* all copies or substantial portions of the Software.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
* THE SOFTWARE.
*/
#include <migraphx/instruction.hpp>
#include <migraphx/literal.hpp>
#include <migraphx/make_op.hpp>
#include <migraphx/onnx.hpp>
#include <migraphx/register_target.hpp>
#include <migraphx/verify.hpp>
#include <random>
#include <test.hpp>
/**
* Reference test for the random_uniform operation. Also invokes the random_seed operation.
*/
TEST_CASE
(
random_uniform_test
)
{
migraphx
::
program
p
;
auto
*
mm
=
p
.
get_main_module
();
uint64_t
seed
(
0
);
size_t
sample_size
(
200
);
// Shape of the random data
migraphx
::
shape
rs
{
migraphx
::
shape
::
float_type
,
{
1
,
sample_size
}};
// data tensor must be allocated at this point but does not need to be initialized.
std
::
vector
<
float
>
data
(
sample_size
);
auto
input
=
mm
->
add_literal
(
migraphx
::
literal
(
rs
,
data
));
// Runtime randomization seed
migraphx
::
shape
seed_shape
{
migraphx
::
shape
::
uint64_type
,
{
1
}};
std
::
vector
<
uint64_t
>
seed_data
{
seed
};
auto
seed_input
=
mm
->
add_literal
(
migraphx
::
literal
(
seed_shape
,
seed_data
));
mm
->
add_instruction
(
migraphx
::
make_op
(
"random_uniform"
),
seed_input
,
input
);
p
.
compile
(
migraphx
::
make_target
(
"ref"
));
// no params_map needed
auto
result
=
p
.
eval
({}).
back
();
std
::
vector
<
float
>
result_vec
(
sample_size
);
result
.
visit
([
&
](
auto
output
)
{
result_vec
.
assign
(
output
.
begin
(),
output
.
end
());
});
// Compare result with the STL's mt19937 generator
std
::
mt19937
gen
(
seed
);
std
::
uniform_real_distribution
<>
dis
(
0.0
,
1.0
);
std
::
vector
<
float
>
rand_samples
(
sample_size
);
std
::
generate
(
rand_samples
.
begin
(),
rand_samples
.
end
(),
[
&
]()
{
return
dis
(
gen
);
});
EXPECT
(
migraphx
::
verify
::
verify_range_with_tolerance
(
result_vec
,
migraphx
::
verify
::
expected
{
rand_samples
},
migraphx
::
verify
::
tolerance
{
0.00001
}));
}
TEST_CASE
(
random_uniform_int_test
)
{
// random uniform distribution with an integer type input shape
migraphx
::
program
p
;
auto
*
mm
=
p
.
get_main_module
();
float
seed
(
0.1
);
size_t
sample_size
(
200
);
// Shape of the random data
migraphx
::
shape
rs
{
migraphx
::
shape
::
uint16_type
,
{
1
,
sample_size
}};
// data tensor must be allocated at this point but does not need to be initialized.
std
::
vector
<
uint16_t
>
data
(
sample_size
);
auto
input
=
mm
->
add_literal
(
migraphx
::
literal
(
rs
,
data
));
// Runtime randomization seed
migraphx
::
shape
seed_shape
{
migraphx
::
shape
::
float_type
,
{
1
}};
std
::
vector
<
float
>
seed_data
{
seed
};
auto
seed_input
=
mm
->
add_literal
(
migraphx
::
literal
(
seed_shape
,
seed_data
));
mm
->
add_instruction
(
migraphx
::
make_op
(
"random_uniform"
),
seed_input
,
input
);
p
.
compile
(
migraphx
::
make_target
(
"ref"
));
migraphx
::
parameter_map
params0
;
auto
result
=
p
.
eval
(
params0
).
back
();
std
::
vector
<
uint16_t
>
result_vec
(
sample_size
);
result
.
visit
([
&
](
auto
output
)
{
result_vec
.
assign
(
output
.
begin
(),
output
.
end
());
});
// Compare result with the STL's mt19937 generator
std
::
mt19937
gen
(
seed
);
std
::
uniform_int_distribution
<
uint16_t
>
dis
;
std
::
vector
<
uint16_t
>
gold_rand_samples
(
sample_size
);
std
::
generate
(
gold_rand_samples
.
begin
(),
gold_rand_samples
.
end
(),
[
&
]()
{
return
dis
(
gen
);
});
EXPECT
(
migraphx
::
verify
::
verify_rms_range
(
result_vec
,
gold_rand_samples
));
}
TEST_CASE
(
random_uniform_dyn_test
)
{
migraphx
::
program
p
;
auto
*
mm
=
p
.
get_main_module
();
uint64_t
seed
(
17
);
size_t
sample_size
(
200
);
// Shape of the random data
migraphx
::
shape
rs
{
migraphx
::
shape
::
float_type
,
{{
1
,
2
},
{
2
,
sample_size
+
1
}}};
auto
input
=
mm
->
add_parameter
(
"Input_1"
,
rs
);
// Runtime randomization seed
migraphx
::
shape
seed_shape
{
migraphx
::
shape
::
uint64_type
,
{
1
}};
auto
seed_input
=
mm
->
add_parameter
(
"Seed"
,
seed_shape
);
mm
->
add_instruction
(
migraphx
::
make_op
(
"random_uniform"
,
{}),
seed_input
,
input
);
p
.
compile
(
migraphx
::
make_target
(
"ref"
));
// Create a dummy input to hold the random data
migraphx
::
shape
input_fixed_shape1
{
migraphx
::
shape
::
float_type
,
{
sample_size
}};
migraphx
::
parameter_map
params0
;
params0
[
"Input_1"
]
=
migraphx
::
argument
(
input_fixed_shape1
);
std
::
vector
<
uint64_t
>
seed_data
=
{
seed
};
params0
[
"Seed"
]
=
migraphx
::
argument
(
seed_shape
,
seed_data
.
data
());
auto
result
=
p
.
eval
(
params0
).
back
();
std
::
vector
<
float
>
result_vec
(
sample_size
);
result
.
visit
([
&
](
auto
output
)
{
result_vec
.
assign
(
output
.
begin
(),
output
.
end
());
});
// Compare result with the STL's mt19937 generator
std
::
mt19937
gen
(
seed
);
std
::
uniform_real_distribution
<>
dis
(
0.0
,
1.0
);
std
::
vector
<
float
>
gold_rand_samples
(
sample_size
);
std
::
generate
(
gold_rand_samples
.
begin
(),
gold_rand_samples
.
end
(),
[
&
]()
{
return
dis
(
gen
);
});
EXPECT
(
migraphx
::
verify
::
verify_rms_range
(
result_vec
,
gold_rand_samples
));
}
TEST_CASE
(
random_uniform_and_seed_test
)
{
migraphx
::
program
p
;
auto
*
mm
=
p
.
get_main_module
();
size_t
sample_size
(
20000
);
// Shape of the random data
migraphx
::
shape
rs
{
migraphx
::
shape
::
float_type
,
{{
1
,
2
},
{
2
,
sample_size
+
1
}}};
auto
input
=
mm
->
add_parameter
(
"Input_1"
,
rs
);
// Runtime randomization seed
auto
seed_input
=
mm
->
add_instruction
(
migraphx
::
make_op
(
"random_seed"
));
mm
->
add_instruction
(
migraphx
::
make_op
(
"random_uniform"
),
seed_input
,
input
);
p
.
compile
(
migraphx
::
make_target
(
"ref"
));
// Create a dummy input to hold the random data
migraphx
::
shape
input_fixed_shape1
{
migraphx
::
shape
::
float_type
,
{
sample_size
}};
migraphx
::
parameter_map
params0
;
params0
[
"Input_1"
]
=
migraphx
::
argument
(
input_fixed_shape1
);
auto
result
=
p
.
eval
(
params0
).
back
();
result
.
visit
([
&
](
auto
output
)
{
EXPECT
(
output
.
size
()
==
sample_size
);
});
// Do not check the content of the data since it's not repeatable
}
test/ref/recip.cpp
View file @
13d14c66
...
...
@@ -24,7 +24,7 @@
#include <migraphx/instruction.hpp>
#include <migraphx/literal.hpp>
#include <migraphx/make_op.hpp>
#include <migraphx/
onnx
.hpp>
#include <migraphx/
program
.hpp>
#include <migraphx/register_target.hpp>
#include <migraphx/verify.hpp>
...
...
@@ -43,7 +43,7 @@ TEST_CASE(recip_test)
std
::
vector
<
float
>
results_vector
(
3
);
result
.
visit
([
&
](
auto
output
)
{
results_vector
.
assign
(
output
.
begin
(),
output
.
end
());
});
std
::
vector
<
float
>
gold
=
{
-
2.0
f
,
10.0
f
,
2.0
f
};
EXPECT
(
migraphx
::
verify
::
verify_range
(
results_vector
,
gold
));
EXPECT
(
migraphx
::
verify
::
verify_
rms_
range
(
results_vector
,
gold
));
}
TEST_CASE
(
recip_dyn_test
)
...
...
@@ -64,5 +64,5 @@ TEST_CASE(recip_dyn_test)
std
::
vector
<
float
>
results_vector
(
3
);
result
.
visit
([
&
](
auto
output
)
{
results_vector
.
assign
(
output
.
begin
(),
output
.
end
());
});
std
::
vector
<
float
>
gold
=
{
-
2.0
f
,
10.0
f
,
2.0
f
};
EXPECT
(
migraphx
::
verify
::
verify_range
(
results_vector
,
gold
));
EXPECT
(
migraphx
::
verify
::
verify_
rms_
range
(
results_vector
,
gold
));
}
test/ref/reduce_max.cpp
View file @
13d14c66
...
...
@@ -24,7 +24,7 @@
#include <migraphx/instruction.hpp>
#include <migraphx/literal.hpp>
#include <migraphx/make_op.hpp>
#include <migraphx/
onnx
.hpp>
#include <migraphx/
program
.hpp>
#include <migraphx/register_target.hpp>
#include <migraphx/verify.hpp>
...
...
@@ -64,7 +64,7 @@ TEST_CASE(reduce_max_dynamic_axis0)
std
::
vector
<
float
>
results_vector
;
result
.
visit
([
&
](
auto
output
)
{
results_vector
.
assign
(
output
.
begin
(),
output
.
end
());
});
std
::
vector
<
float
>
gold
=
{
6
,
7
,
8
,
9
,
10
};
EXPECT
(
migraphx
::
verify
::
verify_range
(
results_vector
,
gold
));
EXPECT
(
migraphx
::
verify
::
verify_
rms_
range
(
results_vector
,
gold
));
}
TEST_CASE
(
reduce_max_axis01
)
...
...
test/ref/reduce_mean.cpp
View file @
13d14c66
...
...
@@ -24,7 +24,7 @@
#include <migraphx/instruction.hpp>
#include <migraphx/literal.hpp>
#include <migraphx/make_op.hpp>
#include <migraphx/
onnx
.hpp>
#include <migraphx/
program
.hpp>
#include <migraphx/register_target.hpp>
#include <migraphx/verify.hpp>
...
...
test/ref/reduce_min.cpp
View file @
13d14c66
...
...
@@ -24,7 +24,7 @@
#include <migraphx/instruction.hpp>
#include <migraphx/literal.hpp>
#include <migraphx/make_op.hpp>
#include <migraphx/
onnx
.hpp>
#include <migraphx/
program
.hpp>
#include <migraphx/register_target.hpp>
#include <migraphx/verify.hpp>
...
...
test/ref/reduce_prod.cpp
View file @
13d14c66
...
...
@@ -24,7 +24,7 @@
#include <migraphx/instruction.hpp>
#include <migraphx/literal.hpp>
#include <migraphx/make_op.hpp>
#include <migraphx/
onnx
.hpp>
#include <migraphx/
program
.hpp>
#include <migraphx/register_target.hpp>
#include <migraphx/verify.hpp>
...
...
test/ref/reduce_sum.cpp
View file @
13d14c66
...
...
@@ -24,7 +24,7 @@
#include <migraphx/instruction.hpp>
#include <migraphx/literal.hpp>
#include <migraphx/make_op.hpp>
#include <migraphx/
onnx
.hpp>
#include <migraphx/
program
.hpp>
#include <migraphx/register_target.hpp>
#include <migraphx/verify.hpp>
...
...
test/ref/relu.cpp
View file @
13d14c66
...
...
@@ -24,7 +24,7 @@
#include <migraphx/instruction.hpp>
#include <migraphx/literal.hpp>
#include <migraphx/make_op.hpp>
#include <migraphx/
onnx
.hpp>
#include <migraphx/
program
.hpp>
#include <migraphx/register_target.hpp>
#include <migraphx/verify.hpp>
...
...
@@ -42,7 +42,7 @@ TEST_CASE(relu_test)
std
::
vector
<
float
>
results_vector
(
3
);
result
.
visit
([
&
](
auto
output
)
{
results_vector
.
assign
(
output
.
begin
(),
output
.
end
());
});
std
::
vector
<
float
>
gold
=
{
0.
f
,
0.
f
,
1.
f
};
EXPECT
(
migraphx
::
verify
::
verify_range
(
results_vector
,
gold
));
EXPECT
(
migraphx
::
verify
::
verify_
rms_
range
(
results_vector
,
gold
));
}
TEST_CASE
(
relu_dyn_test
)
...
...
@@ -63,5 +63,5 @@ TEST_CASE(relu_dyn_test)
std
::
vector
<
float
>
results_vector
(
3
);
result
.
visit
([
&
](
auto
output
)
{
results_vector
.
assign
(
output
.
begin
(),
output
.
end
());
});
std
::
vector
<
float
>
gold
=
{
0.
f
,
0.
f
,
1.
f
};
EXPECT
(
migraphx
::
verify
::
verify_range
(
results_vector
,
gold
));
EXPECT
(
migraphx
::
verify
::
verify_
rms_
range
(
results_vector
,
gold
));
}
test/ref/reshape.cpp
View file @
13d14c66
...
...
@@ -24,13 +24,13 @@
#include <migraphx/instruction.hpp>
#include <migraphx/literal.hpp>
#include <migraphx/make_op.hpp>
#include <migraphx/
onnx
.hpp>
#include <migraphx/
program
.hpp>
#include <migraphx/register_target.hpp>
#include <migraphx/verify.hpp>
#include <test.hpp>
TEST_CASE
(
reshape_test0
)
TEST_CASE
(
reshape_
lazy_
test0
)
{
migraphx
::
shape
a_shape
{
migraphx
::
shape
::
float_type
,
{
24
,
1
,
1
,
1
}};
std
::
vector
<
float
>
data
(
24
);
...
...
@@ -39,15 +39,15 @@ TEST_CASE(reshape_test0)
auto
*
mm
=
p
.
get_main_module
();
auto
l
=
mm
->
add_literal
(
migraphx
::
literal
{
a_shape
,
data
});
std
::
vector
<
int64_t
>
new_shape
=
{
8
,
3
,
1
,
1
};
mm
->
add_instruction
(
migraphx
::
make_op
(
"reshape"
,
{{
"dims"
,
new_shape
}}),
l
);
mm
->
add_instruction
(
migraphx
::
make_op
(
"reshape
_lazy
"
,
{{
"dims"
,
new_shape
}}),
l
);
p
.
compile
(
migraphx
::
make_target
(
"ref"
));
auto
result
=
p
.
eval
({}).
back
();
std
::
vector
<
float
>
results_vector
{};
result
.
visit
([
&
](
auto
output
)
{
results_vector
.
assign
(
output
.
begin
(),
output
.
end
());
});
EXPECT
(
migraphx
::
verify
::
verify_range
(
results_vector
,
data
));
EXPECT
(
migraphx
::
verify
::
verify_
rms_
range
(
results_vector
,
data
));
}
TEST_CASE
(
reshape_test1
)
TEST_CASE
(
reshape_
lazy_
test1
)
{
migraphx
::
shape
a_shape
{
migraphx
::
shape
::
float_type
,
{
24
,
1
,
1
,
1
}};
std
::
vector
<
float
>
data
(
24
);
...
...
@@ -56,15 +56,15 @@ TEST_CASE(reshape_test1)
auto
*
mm
=
p
.
get_main_module
();
auto
l
=
mm
->
add_literal
(
migraphx
::
literal
{
a_shape
,
data
});
std
::
vector
<
int64_t
>
new_shape
=
{
1
,
3
,
4
,
2
};
mm
->
add_instruction
(
migraphx
::
make_op
(
"reshape"
,
{{
"dims"
,
new_shape
}}),
l
);
mm
->
add_instruction
(
migraphx
::
make_op
(
"reshape
_lazy
"
,
{{
"dims"
,
new_shape
}}),
l
);
p
.
compile
(
migraphx
::
make_target
(
"ref"
));
auto
result
=
p
.
eval
({}).
back
();
std
::
vector
<
float
>
results_vector
{};
result
.
visit
([
&
](
auto
output
)
{
results_vector
.
assign
(
output
.
begin
(),
output
.
end
());
});
EXPECT
(
migraphx
::
verify
::
verify_range
(
results_vector
,
data
));
EXPECT
(
migraphx
::
verify
::
verify_
rms_
range
(
results_vector
,
data
));
}
TEST_CASE
(
reshape_test2
)
TEST_CASE
(
reshape_
lazy_
test2
)
{
migraphx
::
shape
a_shape
{
migraphx
::
shape
::
float_type
,
{
24
,
1
,
1
,
1
}};
std
::
vector
<
float
>
data
(
24
);
...
...
@@ -73,22 +73,22 @@ TEST_CASE(reshape_test2)
auto
*
mm
=
p
.
get_main_module
();
auto
l
=
mm
->
add_literal
(
migraphx
::
literal
{
a_shape
,
data
});
std
::
vector
<
int64_t
>
new_shape
=
{
1
,
2
,
3
,
4
};
mm
->
add_instruction
(
migraphx
::
make_op
(
"reshape"
,
{{
"dims"
,
new_shape
}}),
l
);
mm
->
add_instruction
(
migraphx
::
make_op
(
"reshape
_lazy
"
,
{{
"dims"
,
new_shape
}}),
l
);
p
.
compile
(
migraphx
::
make_target
(
"ref"
));
auto
result
=
p
.
eval
({}).
back
();
std
::
vector
<
float
>
results_vector
{};
result
.
visit
([
&
](
auto
output
)
{
results_vector
.
assign
(
output
.
begin
(),
output
.
end
());
});
EXPECT
(
migraphx
::
verify
::
verify_range
(
results_vector
,
data
));
EXPECT
(
migraphx
::
verify
::
verify_
rms_
range
(
results_vector
,
data
));
}
TEST_CASE
(
reshape_dyn_test
)
TEST_CASE
(
reshape_
lazy_
dyn_test
)
{
migraphx
::
program
p
;
auto
*
mm
=
p
.
get_main_module
();
migraphx
::
shape
s
{
migraphx
::
shape
::
float_type
,
{{
1
,
4
},
{
24
,
24
},
{
1
,
1
},
{
1
,
1
}}};
std
::
vector
<
int64_t
>
new_shape
=
{
0
,
8
,
3
,
1
};
auto
input
=
mm
->
add_parameter
(
"X"
,
s
);
mm
->
add_instruction
(
migraphx
::
make_op
(
"reshape"
,
{{
"dims"
,
new_shape
}}),
input
);
mm
->
add_instruction
(
migraphx
::
make_op
(
"reshape
_lazy
"
,
{{
"dims"
,
new_shape
}}),
input
);
p
.
compile
(
migraphx
::
make_target
(
"ref"
));
std
::
vector
<
float
>
data
(
48
);
...
...
@@ -99,5 +99,153 @@ TEST_CASE(reshape_dyn_test)
auto
result
=
p
.
eval
(
params
).
back
();
std
::
vector
<
float
>
results_vector
{};
result
.
visit
([
&
](
auto
output
)
{
results_vector
.
assign
(
output
.
begin
(),
output
.
end
());
});
EXPECT
(
migraphx
::
verify
::
verify_range
(
results_vector
,
data
));
EXPECT
(
migraphx
::
verify
::
verify_rms_range
(
results_vector
,
data
));
}
TEST_CASE
(
reshape_test0
)
{
migraphx
::
shape
a_shape
{
migraphx
::
shape
::
float_type
,
{
24
,
1
,
1
,
1
}};
std
::
vector
<
float
>
gold
(
24
);
std
::
iota
(
gold
.
begin
(),
gold
.
end
(),
-
3
);
migraphx
::
program
p
;
auto
*
mm
=
p
.
get_main_module
();
auto
l
=
mm
->
add_literal
(
migraphx
::
literal
{
a_shape
,
gold
});
std
::
vector
<
int64_t
>
new_shape
=
{
8
,
3
,
1
,
1
};
mm
->
add_instruction
(
migraphx
::
make_op
(
"reshape"
,
{{
"dims"
,
new_shape
}}),
l
);
p
.
compile
(
migraphx
::
make_target
(
"ref"
));
auto
result
=
p
.
eval
({}).
back
();
std
::
vector
<
float
>
results_vector
{};
result
.
visit
([
&
](
auto
output
)
{
results_vector
.
assign
(
output
.
begin
(),
output
.
end
());
});
EXPECT
(
migraphx
::
verify
::
verify_rms_range
(
results_vector
,
gold
));
}
TEST_CASE
(
reshape_test1
)
{
migraphx
::
shape
a_shape
{
migraphx
::
shape
::
float_type
,
{
24
,
1
,
1
,
1
}};
std
::
vector
<
float
>
gold
(
24
);
std
::
iota
(
gold
.
begin
(),
gold
.
end
(),
-
3
);
migraphx
::
program
p
;
auto
*
mm
=
p
.
get_main_module
();
auto
l
=
mm
->
add_literal
(
migraphx
::
literal
{
a_shape
,
gold
});
std
::
vector
<
int64_t
>
new_shape
=
{
1
,
3
,
4
,
2
};
mm
->
add_instruction
(
migraphx
::
make_op
(
"reshape"
,
{{
"dims"
,
new_shape
}}),
l
);
p
.
compile
(
migraphx
::
make_target
(
"ref"
));
auto
result
=
p
.
eval
({}).
back
();
std
::
vector
<
float
>
results_vector
{};
result
.
visit
([
&
](
auto
output
)
{
results_vector
.
assign
(
output
.
begin
(),
output
.
end
());
});
EXPECT
(
migraphx
::
verify
::
verify_rms_range
(
results_vector
,
gold
));
}
TEST_CASE
(
reshape_test2
)
{
migraphx
::
shape
a_shape
{
migraphx
::
shape
::
float_type
,
{
24
,
1
,
1
,
1
}};
std
::
vector
<
float
>
gold
(
24
);
std
::
iota
(
gold
.
begin
(),
gold
.
end
(),
-
3
);
migraphx
::
program
p
;
auto
*
mm
=
p
.
get_main_module
();
auto
l
=
mm
->
add_literal
(
migraphx
::
literal
{
a_shape
,
gold
});
std
::
vector
<
int64_t
>
new_shape
=
{
1
,
2
,
3
,
4
};
mm
->
add_instruction
(
migraphx
::
make_op
(
"reshape"
,
{{
"dims"
,
new_shape
}}),
l
);
p
.
compile
(
migraphx
::
make_target
(
"ref"
));
auto
result
=
p
.
eval
({}).
back
();
std
::
vector
<
float
>
results_vector
{};
result
.
visit
([
&
](
auto
output
)
{
results_vector
.
assign
(
output
.
begin
(),
output
.
end
());
});
EXPECT
(
migraphx
::
verify
::
verify_rms_range
(
results_vector
,
gold
));
}
TEST_CASE
(
reshape_dyn_1in_test
)
{
migraphx
::
program
p
;
auto
*
mm
=
p
.
get_main_module
();
migraphx
::
shape
s
{
migraphx
::
shape
::
float_type
,
{{
1
,
4
},
{
24
,
24
},
{
1
,
1
},
{
1
,
1
}}};
std
::
vector
<
int64_t
>
new_shape
=
{
0
,
8
,
3
,
1
};
auto
input
=
mm
->
add_parameter
(
"X"
,
s
);
mm
->
add_instruction
(
migraphx
::
make_op
(
"reshape"
,
{{
"dims"
,
new_shape
}}),
input
);
p
.
compile
(
migraphx
::
make_target
(
"ref"
));
std
::
vector
<
float
>
gold
(
48
);
std
::
iota
(
gold
.
begin
(),
gold
.
end
(),
-
3
);
migraphx
::
parameter_map
params
;
migraphx
::
shape
input_fixed_shape
{
migraphx
::
shape
::
float_type
,
{
2
,
24
,
1
,
1
}};
params
[
"X"
]
=
migraphx
::
argument
(
input_fixed_shape
,
gold
.
data
());
auto
result
=
p
.
eval
(
params
).
back
();
std
::
vector
<
float
>
results_vector
{};
result
.
visit
([
&
](
auto
output
)
{
results_vector
.
assign
(
output
.
begin
(),
output
.
end
());
});
EXPECT
(
migraphx
::
verify
::
verify_rms_range
(
results_vector
,
gold
));
}
TEST_CASE
(
reshape_2in_test0
)
{
migraphx
::
program
p
;
auto
*
mm
=
p
.
get_main_module
();
migraphx
::
shape
s_in
{
migraphx
::
shape
::
float_type
,
{{
1
,
4
},
{
24
,
24
},
{
1
,
1
},
{
1
,
1
}}};
migraphx
::
shape
s_out
{
migraphx
::
shape
::
float_type
,
{{
1
,
4
},
{
6
,
6
},
{
4
,
4
},
{
1
,
1
}}};
auto
input
=
mm
->
add_parameter
(
"X"
,
s_in
);
auto
output_buffer
=
mm
->
add_parameter
(
"Y"
,
s_out
);
mm
->
add_instruction
(
migraphx
::
make_op
(
"reshape"
),
input
,
output_buffer
);
p
.
compile
(
migraphx
::
make_target
(
"ref"
));
std
::
vector
<
float
>
gold
(
48
);
std
::
iota
(
gold
.
begin
(),
gold
.
end
(),
-
3.
);
std
::
vector
<
float
>
buffer
(
48
);
std
::
iota
(
buffer
.
begin
(),
buffer
.
end
(),
0.
);
migraphx
::
parameter_map
params
;
migraphx
::
shape
input_fixed_shape
{
migraphx
::
shape
::
float_type
,
{
2
,
24
,
1
,
1
}};
migraphx
::
shape
output_fixed_shape
{
migraphx
::
shape
::
float_type
,
{
2
,
6
,
4
,
1
}};
params
[
"X"
]
=
migraphx
::
argument
(
input_fixed_shape
,
gold
.
data
());
params
[
"Y"
]
=
migraphx
::
argument
(
output_fixed_shape
,
buffer
.
data
());
auto
result
=
p
.
eval
(
params
).
back
();
EXPECT
(
result
.
get_shape
()
==
output_fixed_shape
);
std
::
vector
<
float
>
results_vector
{};
result
.
visit
([
&
](
auto
output
)
{
results_vector
.
assign
(
output
.
begin
(),
output
.
end
());
});
EXPECT
(
migraphx
::
verify
::
verify_rms_range
(
results_vector
,
gold
));
}
TEST_CASE
(
reshape_2in_test1
)
{
migraphx
::
program
p
;
auto
*
mm
=
p
.
get_main_module
();
migraphx
::
shape
s_in
{
migraphx
::
shape
::
float_type
,
{
2
,
24
,
1
,
1
}};
migraphx
::
shape
s_out
{
migraphx
::
shape
::
float_type
,
{{
2
,
4
},
{
6
,
6
},
{
2
,
4
},
{
1
,
1
}}};
auto
input
=
mm
->
add_parameter
(
"X"
,
s_in
);
auto
output_buffer
=
mm
->
add_parameter
(
"Y"
,
s_out
);
mm
->
add_instruction
(
migraphx
::
make_op
(
"reshape"
),
input
,
output_buffer
);
p
.
compile
(
migraphx
::
make_target
(
"ref"
));
std
::
vector
<
float
>
gold
(
48
);
std
::
iota
(
gold
.
begin
(),
gold
.
end
(),
-
3.
);
std
::
vector
<
float
>
buffer
(
48
);
std
::
iota
(
buffer
.
begin
(),
buffer
.
end
(),
0.
);
migraphx
::
parameter_map
params
;
migraphx
::
shape
output_fixed_shape
{
migraphx
::
shape
::
float_type
,
{
2
,
6
,
4
,
1
}};
params
[
"X"
]
=
migraphx
::
argument
(
s_in
,
gold
.
data
());
params
[
"Y"
]
=
migraphx
::
argument
(
output_fixed_shape
,
buffer
.
data
());
auto
result
=
p
.
eval
(
params
).
back
();
EXPECT
(
result
.
get_shape
()
==
output_fixed_shape
);
std
::
vector
<
float
>
results_vector
{};
result
.
visit
([
&
](
auto
output
)
{
results_vector
.
assign
(
output
.
begin
(),
output
.
end
());
});
EXPECT
(
migraphx
::
verify
::
verify_rms_range
(
results_vector
,
gold
));
}
TEST_CASE
(
reshape_2in_elements_runtime_error
)
{
migraphx
::
program
p
;
auto
*
mm
=
p
.
get_main_module
();
migraphx
::
shape
s_in
{
migraphx
::
shape
::
float_type
,
{
2
,
24
,
1
,
1
}};
migraphx
::
shape
s_out
{
migraphx
::
shape
::
float_type
,
{{
2
,
4
},
{
6
,
6
},
{
2
,
4
},
{
1
,
1
}}};
auto
input
=
mm
->
add_parameter
(
"X"
,
s_in
);
auto
output_buffer
=
mm
->
add_parameter
(
"Y"
,
s_out
);
mm
->
add_instruction
(
migraphx
::
make_op
(
"reshape"
),
input
,
output_buffer
);
p
.
compile
(
migraphx
::
make_target
(
"ref"
));
std
::
vector
<
float
>
gold
(
48
);
std
::
iota
(
gold
.
begin
(),
gold
.
end
(),
-
3.
);
std
::
vector
<
float
>
buffer
(
48
);
std
::
iota
(
buffer
.
begin
(),
buffer
.
end
(),
0.
);
migraphx
::
parameter_map
params
;
// elements do not match up
migraphx
::
shape
output_fixed_shape
{
migraphx
::
shape
::
float_type
,
{
2
,
6
,
2
,
1
}};
params
[
"X"
]
=
migraphx
::
argument
(
s_in
,
gold
.
data
());
params
[
"Y"
]
=
migraphx
::
argument
(
output_fixed_shape
,
buffer
.
data
());
EXPECT
(
test
::
throws
([
&
]
{
std
::
ignore
=
p
.
eval
(
params
).
back
();
}));
}
test/ref/reverse.cpp
View file @
13d14c66
...
...
@@ -24,7 +24,7 @@
#include <migraphx/instruction.hpp>
#include <migraphx/literal.hpp>
#include <migraphx/make_op.hpp>
#include <migraphx/
onnx
.hpp>
#include <migraphx/
program
.hpp>
#include <migraphx/register_target.hpp>
#include <migraphx/verify.hpp>
...
...
@@ -44,9 +44,9 @@ TEST_CASE(reverse_test_axis0)
auto
result
=
p
.
eval
({}).
back
();
std
::
vector
<
float
>
results_vector
;
result
.
visit
([
&
](
auto
output
)
{
results_vector
.
assign
(
output
.
begin
(),
output
.
end
());
});
std
::
vector
<
float
>
target_data
=
data
;
std
::
swap_ranges
(
target_data
.
begin
(),
target_data
.
begin
()
+
16
,
target_data
.
begin
()
+
16
);
EXPECT
(
migraphx
::
verify
::
verify_range
(
results_vector
,
target_data
));
std
::
vector
<
float
>
gold
=
data
;
std
::
swap_ranges
(
gold
.
begin
(),
gold
.
begin
()
+
16
,
gold
.
begin
()
+
16
);
EXPECT
(
migraphx
::
verify
::
verify_
rms_
range
(
results_vector
,
gold
));
}
TEST_CASE
(
reverse_test_axis1
)
...
...
@@ -63,10 +63,10 @@ TEST_CASE(reverse_test_axis1)
auto
result
=
p
.
eval
({}).
back
();
std
::
vector
<
float
>
results_vector
;
result
.
visit
([
&
](
auto
output
)
{
results_vector
.
assign
(
output
.
begin
(),
output
.
end
());
});
std
::
vector
<
float
>
target_data
=
data
;
std
::
reverse
(
target_data
.
begin
(),
target_data
.
begin
()
+
16
);
std
::
reverse
(
target_data
.
end
()
-
16
,
target_data
.
end
());
EXPECT
(
migraphx
::
verify
::
verify_range
(
results_vector
,
target_data
));
std
::
vector
<
float
>
gold
=
data
;
std
::
reverse
(
gold
.
begin
(),
gold
.
begin
()
+
16
);
std
::
reverse
(
gold
.
end
()
-
16
,
gold
.
end
());
EXPECT
(
migraphx
::
verify
::
verify_
rms_
range
(
results_vector
,
gold
));
}
TEST_CASE
(
reverse_test_axis10
)
...
...
@@ -83,9 +83,9 @@ TEST_CASE(reverse_test_axis10)
auto
result
=
p
.
eval
({}).
back
();
std
::
vector
<
float
>
results_vector
;
result
.
visit
([
&
](
auto
output
)
{
results_vector
.
assign
(
output
.
begin
(),
output
.
end
());
});
std
::
vector
<
float
>
target_data
=
data
;
std
::
reverse
(
target_data
.
begin
(),
target_data
.
begin
()
+
16
);
std
::
reverse
(
target_data
.
end
()
-
16
,
target_data
.
end
());
std
::
swap_ranges
(
target_data
.
begin
(),
target_data
.
begin
()
+
16
,
target_data
.
begin
()
+
16
);
EXPECT
(
migraphx
::
verify
::
verify_range
(
results_vector
,
target_data
));
std
::
vector
<
float
>
gold
=
data
;
std
::
reverse
(
gold
.
begin
(),
gold
.
begin
()
+
16
);
std
::
reverse
(
gold
.
end
()
-
16
,
gold
.
end
());
std
::
swap_ranges
(
gold
.
begin
(),
gold
.
begin
()
+
16
,
gold
.
begin
()
+
16
);
EXPECT
(
migraphx
::
verify
::
verify_
rms_
range
(
results_vector
,
gold
));
}
test/ref/rnn_ops.cpp
View file @
13d14c66
...
...
@@ -21,18 +21,13 @@
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
* THE SOFTWARE.
*/
#include <iostream>
#include <vector>
#include <migraphx/instruction.hpp>
#include <migraphx/literal.hpp>
#include <migraphx/make_op.hpp>
#include <migraphx/op/common.hpp>
#include <migraphx/
instruction
.hpp>
#include <migraphx/
program
.hpp>
#include <migraphx/register_target.hpp>
#include <migraphx/verify.hpp>
#include <migraphx/onnx.hpp>
#include <migraphx/make_op.hpp>
#include <migraphx/quantization.hpp>
#include <migraphx/serialize.hpp>
#include "test.hpp"
...
...
@@ -150,8 +145,8 @@ TEST_CASE(rnn_forward)
-0.16477929,
-0.11893477};
EXPECT
(
migraphx
::
verify
::
verify_range
(
hs_data
,
hs_data_gold
));
EXPECT
(
migraphx
::
verify
::
verify_range
(
lho_data
,
lho_data_gold
));
EXPECT(migraphx::verify::verify_
rms_
range(hs_data, hs_data_gold));
EXPECT(migraphx::verify::verify_
rms_
range(lho_data, lho_data_gold));
}
{
...
...
@@ -211,8 +206,8 @@ TEST_CASE(rnn_forward)
0.44193283,
-0.16477929,
-0.11893477};
EXPECT
(
migraphx
::
verify
::
verify_range
(
last_output_data
,
last_output_data_gold
));
EXPECT
(
migraphx
::
verify
::
verify_range
(
hs_data
,
hs_data_gold
));
EXPECT(migraphx::verify::verify_
rms_
range(last_output_data, last_output_data_gold));
EXPECT(migraphx::verify::verify_
rms_
range(hs_data, hs_data_gold));
}
{
...
...
@@ -271,8 +266,8 @@ TEST_CASE(rnn_forward)
0};
std::vector<float> last_output_data_gold{
0.034457, 0.191679, -0.394683, -0.308897, -0.371446, 0.317082, 0.131042, -0.18736};
EXPECT
(
migraphx
::
verify
::
verify_range
(
last_output_data
,
last_output_data_gold
));
EXPECT
(
migraphx
::
verify
::
verify_range
(
hs_data
,
hs_data_gold
));
EXPECT(migraphx::verify::verify_
rms_
range(last_output_data, last_output_data_gold));
EXPECT(migraphx::verify::verify_
rms_
range(hs_data, hs_data_gold));
}
// 3 args
...
...
@@ -302,7 +297,7 @@ TEST_CASE(rnn_forward)
std::vector<float> last_output_data_gold{
0.2935145, -0.23719997, -0.31123261, -0.18357255, 0., 0., 0., 0.};
EXPECT
(
migraphx
::
verify
::
verify_range
(
last_output_data
,
last_output_data_gold
));
EXPECT(migraphx::verify::verify_
rms_
range(last_output_data, last_output_data_gold));
}
// seq_len = 1
...
...
@@ -349,7 +344,7 @@ TEST_CASE(rnn_forward)
0.31708236,
0.13104209,
-0.18736027};
EXPECT
(
migraphx
::
verify
::
verify_range
(
hs_data
,
hs_data_gold
));
EXPECT(migraphx::verify::verify_
rms_
range(hs_data, hs_data_gold));
}
}
...
...
@@ -443,7 +438,7 @@ TEST_CASE(rnn_reverse)
0.46251031,
-0.20639211,
0.37488942};
EXPECT
(
migraphx
::
verify
::
verify_range
(
hs_data
,
hs_data_gold
));
EXPECT(migraphx::verify::verify_
rms_
range(hs_data, hs_data_gold));
}
// rnn last output as program output
...
...
@@ -486,7 +481,7 @@ TEST_CASE(rnn_reverse)
0.44124447,
0.14365635,
0.14803654};
EXPECT
(
migraphx
::
verify
::
verify_range
(
last_output_data
,
last_output_data_gold
));
EXPECT(migraphx::verify::verify_
rms_
range(last_output_data, last_output_data_gold));
}
// rnn hidden states and last hidden state output as program outputs
...
...
@@ -549,8 +544,8 @@ TEST_CASE(rnn_reverse)
0.14365635,
0.14803654};
EXPECT
(
migraphx
::
verify
::
verify_range
(
hs_data
,
hs_data_gold
));
EXPECT
(
migraphx
::
verify
::
verify_range
(
last_output_data
,
last_output_data_gold
));
EXPECT(migraphx::verify::verify_
rms_
range(hs_data, hs_data_gold));
EXPECT(migraphx::verify::verify_
rms_
range(last_output_data, last_output_data_gold));
}
// rnn hidden states and last hidden state output as program outputs
...
...
@@ -611,8 +606,8 @@ TEST_CASE(rnn_reverse)
std::vector<float> last_output_data_gold{
-0.293853, 0.167968, 0.51076, 0.402587, -0.0070999, 0.46251, -0.206392, 0.374889};
EXPECT
(
migraphx
::
verify
::
verify_range
(
hs_data
,
hs_data_gold
));
EXPECT
(
migraphx
::
verify
::
verify_range
(
last_output_data
,
last_output_data_gold
));
EXPECT(migraphx::verify::verify_
rms_
range(hs_data, hs_data_gold));
EXPECT(migraphx::verify::verify_
rms_
range(last_output_data, last_output_data_gold));
}
}
...
...
@@ -723,8 +718,8 @@ TEST_CASE(rnn_bidirectional)
0.14365635,
0.14803654};
EXPECT
(
migraphx
::
verify
::
verify_range
(
hs_data
,
hs_data_gold
));
EXPECT
(
migraphx
::
verify
::
verify_range
(
last_output_data
,
last_output_data_gold
));
EXPECT(migraphx::verify::verify_
rms_
range(hs_data, hs_data_gold));
EXPECT(migraphx::verify::verify_
rms_
range(last_output_data, last_output_data_gold));
}
// last rnn output for program output
...
...
@@ -789,8 +784,8 @@ TEST_CASE(rnn_bidirectional)
0.143656,
0.148037};
EXPECT
(
migraphx
::
verify
::
verify_range
(
hs_data
,
hs_data_gold
));
EXPECT
(
migraphx
::
verify
::
verify_range
(
last_output_data
,
last_output_data_gold
));
EXPECT(migraphx::verify::verify_
rms_
range(hs_data, hs_data_gold));
EXPECT(migraphx::verify::verify_
rms_
range(last_output_data, last_output_data_gold));
}
// 4 args
...
...
@@ -840,7 +835,7 @@ TEST_CASE(rnn_bidirectional)
0.14365635,
0.14803654};
EXPECT
(
migraphx
::
verify
::
verify_range
(
last_output_data
,
last_output_data_gold
));
EXPECT(migraphx::verify::verify_
rms_
range(last_output_data, last_output_data_gold));
}
// 3 args
...
...
@@ -875,7 +870,7 @@ TEST_CASE(rnn_bidirectional)
0.2935145, -0.23719997, -0.31123261, -0.18357255, 0., 0., 0., 0.,
0., 0., 0., 0., 0., 0., 0., 0.};
EXPECT
(
migraphx
::
verify
::
verify_range
(
last_output_data
,
last_output_data_gold
));
EXPECT(migraphx::verify::verify_
rms_
range(last_output_data, last_output_data_gold));
}
// concatenation of hidden state for program output
...
...
@@ -928,7 +923,7 @@ TEST_CASE(rnn_bidirectional)
-0.20639211,
0.37488942};
EXPECT
(
migraphx
::
verify
::
verify_range
(
hs_data
,
hs_data_gold
));
EXPECT(migraphx::verify::verify_
rms_
range(hs_data, hs_data_gold));
}
}
...
...
@@ -1013,7 +1008,10 @@ TEST_CASE(rnn_fp16)
std::vector<float> last_output_data_gold{
0.2935145, -0.23719997, -0.31123261, -0.18357255, 0., 0., 0., 0.};
EXPECT
(
migraphx
::
verify
::
verify_range
(
last_output_data
,
last_output_data_gold
,
5e4
));
EXPECT(migraphx::verify::verify_range_with_tolerance(
last_output_data,
migraphx::verify::expected{last_output_data_gold},
migraphx::verify::tolerance{0.005}));
}
TEST_CASE(gru_forward)
...
...
@@ -1111,7 +1109,7 @@ TEST_CASE(gru_forward)
0.48523626, 0.60002893, -0.3969709, 0.43360898, 0.35775262, 0.23280787,
-0.52179873, -0.21944991, 0.4535257, -0.13735442, 0.51757574, 0.50380427};
EXPECT
(
migraphx
::
verify
::
verify_range
(
hs_data
,
hs_data_gold
));
EXPECT(migraphx::verify::verify_
rms_
range(hs_data, hs_data_gold));
}
// last output for output
...
...
@@ -1157,7 +1155,7 @@ TEST_CASE(gru_forward)
0.51757574,
0.50380427};
EXPECT
(
migraphx
::
verify
::
verify_range
(
hs_data
,
hs_data_gold
));
EXPECT(migraphx::verify::verify_
rms_
range(hs_data, hs_data_gold));
}
// two rnn_last_hs_output operators after gru
...
...
@@ -1204,7 +1202,7 @@ TEST_CASE(gru_forward)
0.51757574,
0.50380427};
EXPECT
(
migraphx
::
verify
::
verify_range
(
hs_data
,
hs_data_gold
));
EXPECT(migraphx::verify::verify_
rms_
range(hs_data, hs_data_gold));
}
// last output for output, linear_before_reset = 0
...
...
@@ -1250,7 +1248,7 @@ TEST_CASE(gru_forward)
0.6014447,
0.43445644};
EXPECT
(
migraphx
::
verify
::
verify_range
(
hs_data
,
hs_data_gold
));
EXPECT(migraphx::verify::verify_
rms_
range(hs_data, hs_data_gold));
}
}
...
...
@@ -1335,7 +1333,7 @@ TEST_CASE(gru_forward_args)
-0.232523, 0.00214573, 0.231693, -0.160475, -0.518952,
0.0467166, 0.12327, -0.374162, 0.137778, 0.251976};
EXPECT
(
migraphx
::
verify
::
verify_range
(
hs_data
,
hs_data_gold
));
EXPECT(migraphx::verify::verify_
rms_
range(hs_data, hs_data_gold));
}
// 4 args (bias is used)
...
...
@@ -1378,7 +1376,7 @@ TEST_CASE(gru_forward_args)
-0.416866, 0.377186, 0.32922, 0.162214, -0.519973,
-0.140072, 0.465076, -0.229563, 0.500164, 0.195166};
EXPECT
(
migraphx
::
verify
::
verify_range
(
hs_data
,
hs_data_gold
));
EXPECT(migraphx::verify::verify_
rms_
range(hs_data, hs_data_gold));
}
// 4 args (ih is used)
...
...
@@ -1422,7 +1420,7 @@ TEST_CASE(gru_forward_args)
-0.197, 0.0885705, 0.269396, -0.0414511, -0.515137,
-0.03075, 0.158326, -0.296488, 0.177983, 0.519498};
EXPECT
(
migraphx
::
verify
::
verify_range
(
hs_data
,
hs_data_gold
));
EXPECT(migraphx::verify::verify_
rms_
range(hs_data, hs_data_gold));
}
}
...
...
@@ -1524,7 +1522,7 @@ TEST_CASE(gru_forward_actv_funcs)
0.51757574,
0.50380427};
EXPECT
(
migraphx
::
verify
::
verify_range
(
hs_data
,
hs_data_gold
));
EXPECT(migraphx::verify::verify_
rms_
range(hs_data, hs_data_gold));
}
// 1 activation function (sigmoid) specified
...
...
@@ -1565,7 +1563,7 @@ TEST_CASE(gru_forward_actv_funcs)
0.35652235, 0.6033026, 0.52634895, 0.5815402, 0.3001663,
0.39814138, 0.4354002, 0.4310627, 0.6708563, 0.7509278};
EXPECT
(
migraphx
::
verify
::
verify_range
(
hs_data
,
hs_data_gold
));
EXPECT(migraphx::verify::verify_
rms_
range(hs_data, hs_data_gold));
}
// 1 activation function (tanh) specified
...
...
@@ -1610,7 +1608,7 @@ TEST_CASE(gru_forward_actv_funcs)
0.65615714,
0.53612584};
EXPECT
(
migraphx
::
verify
::
verify_range
(
hs_data
,
hs_data_gold
));
EXPECT(migraphx::verify::verify_
rms_
range(hs_data, hs_data_gold));
}
// seq length of 1
...
...
@@ -1660,7 +1658,7 @@ TEST_CASE(gru_forward_actv_funcs)
0.6104771,
0.79759157};
EXPECT
(
migraphx
::
verify
::
verify_range
(
hs_data
,
hs_data_gold
));
EXPECT(migraphx::verify::verify_
rms_
range(hs_data, hs_data_gold));
}
}
...
...
@@ -1776,8 +1774,8 @@ TEST_CASE(gru_reverse)
0.55703,
0.54711};
EXPECT
(
migraphx
::
verify
::
verify_range
(
hs_data
,
hs_data_gold
));
EXPECT
(
migraphx
::
verify
::
verify_range
(
lho_data
,
lho_data_gold
));
EXPECT(migraphx::verify::verify_
rms_
range(hs_data, hs_data_gold));
EXPECT(migraphx::verify::verify_
rms_
range(lho_data, lho_data_gold));
}
// variable input sequence length
...
...
@@ -1837,8 +1835,8 @@ TEST_CASE(gru_reverse)
0.558397,
0.664423};
EXPECT
(
migraphx
::
verify
::
verify_range
(
hs_data
,
hs_data_gold
));
EXPECT
(
migraphx
::
verify
::
verify_range
(
lho_data
,
lho_data_gold
));
EXPECT(migraphx::verify::verify_
rms_
range(hs_data, hs_data_gold));
EXPECT(migraphx::verify::verify_
rms_
range(lho_data, lho_data_gold));
}
// last output for output, linear_before_reset = 0
...
...
@@ -1884,7 +1882,7 @@ TEST_CASE(gru_reverse)
0.646604,
0.463943};
EXPECT
(
migraphx
::
verify
::
verify_range
(
hs_data
,
hs_data_gold
));
EXPECT(migraphx::verify::verify_
rms_
range(hs_data, hs_data_gold));
}
// no activation function specified, so default is used.
...
...
@@ -1923,7 +1921,7 @@ TEST_CASE(gru_reverse)
-0.329512, 0.476095, 0.284044, 0.392077, -0.369226,
-0.3275, -0.027301, 0.143774, 0.655686, 0.782831};
EXPECT
(
migraphx
::
verify
::
verify_range
(
hs_data
,
hs_data_gold
));
EXPECT(migraphx::verify::verify_
rms_
range(hs_data, hs_data_gold));
}
// seq length of 1
...
...
@@ -1973,7 +1971,7 @@ TEST_CASE(gru_reverse)
0.610477,
0.797592};
EXPECT
(
migraphx
::
verify
::
verify_range
(
hs_data
,
hs_data_gold
));
EXPECT(migraphx::verify::verify_
rms_
range(hs_data, hs_data_gold));
}
}
...
...
@@ -2104,8 +2102,8 @@ TEST_CASE(gru_bidirectional)
0.0248217, 0.435231, -0.144448, 0.101531, -0.111305,
0.381317, 0.468983, 0.230557, 0.348021, 0.180229};
EXPECT
(
migraphx
::
verify
::
verify_range
(
hs_data
,
hs_data_gold
));
EXPECT
(
migraphx
::
verify
::
verify_range
(
lho_data
,
lho_data_gold
));
EXPECT(migraphx::verify::verify_
rms_
range(hs_data, hs_data_gold));
EXPECT(migraphx::verify::verify_
rms_
range(lho_data, lho_data_gold));
}
// same input sequence length, but shorter than max squence length
...
...
@@ -2173,8 +2171,8 @@ TEST_CASE(gru_bidirectional)
0.0248217, 0.435231, -0.144448, 0.101531, -0.111305,
0.381317, 0.468983, 0.230557, 0.348021, 0.180229};
EXPECT
(
migraphx
::
verify
::
verify_range
(
hs_data
,
hs_data_gold
));
EXPECT
(
migraphx
::
verify
::
verify_range
(
lho_data
,
lho_data_gold
));
EXPECT(migraphx::verify::verify_
rms_
range(hs_data, hs_data_gold));
EXPECT(migraphx::verify::verify_
rms_
range(lho_data, lho_data_gold));
}
// variable input sequence lengths
...
...
@@ -2232,8 +2230,8 @@ TEST_CASE(gru_bidirectional)
-0.0271321, 0.624762, -0.117084, 0.509115, -0.0175078,
0.182457, 0.304506, 0.313825, 0.397697, 0.300873};
EXPECT
(
migraphx
::
verify
::
verify_range
(
hs_data
,
hs_data_gold
));
EXPECT
(
migraphx
::
verify
::
verify_range
(
lho_data
,
lho_data_gold
));
EXPECT(migraphx::verify::verify_
rms_
range(hs_data, hs_data_gold));
EXPECT(migraphx::verify::verify_
rms_
range(lho_data, lho_data_gold));
}
// last output for output, linear_before_reset = 0
...
...
@@ -2273,7 +2271,7 @@ TEST_CASE(gru_bidirectional)
-0.10688055, -0.4767866, 0.6317833, 0.00286336, 0.53692746, -0.00617076, 0.04564289,
-0.18030001, 0.39584228, 0.53879917, 0.384983, 0.2759448, 0.11611474};
EXPECT
(
migraphx
::
verify
::
verify_range
(
hs_data
,
hs_data_gold
));
EXPECT(migraphx::verify::verify_
rms_
range(hs_data, hs_data_gold));
}
}
...
...
@@ -2375,7 +2373,7 @@ TEST_CASE(gru_bidirectional_args)
0.469122, -0.306578, -0.221095, -0.106449, -0.248934, -0.00682121, 0.288407,
0.198708, 0.0695644, 0.211621, 0.00246037};
EXPECT
(
migraphx
::
verify
::
verify_range
(
hs_data
,
hs_data_gold
));
EXPECT(migraphx::verify::verify_
rms_
range(hs_data, hs_data_gold));
}
// 4 args (bias is used)
...
...
@@ -2426,7 +2424,7 @@ TEST_CASE(gru_bidirectional_args)
0.476508, -0.313413, -0.0361821, -0.173037, -0.235731, -0.163113, 0.349008,
0.248674, -0.0295413, 0.291437, -0.165005};
EXPECT
(
migraphx
::
verify
::
verify_range
(
hs_data
,
hs_data_gold
));
EXPECT(migraphx::verify::verify_
rms_
range(hs_data, hs_data_gold));
}
// 4 args (ih is used)
...
...
@@ -2474,7 +2472,7 @@ TEST_CASE(gru_bidirectional_args)
0.233106, 0.32996, -0.17175, 0.0190231, -0.154805, -0.205631, -0.405354,
0.519054, -0.380409, -0.0350301, -0.00633752, 0.403791, 0.181883, -0.0977917,
-0.0339407, 0.413089, 0.721238, 0.431879};
EXPECT
(
migraphx
::
verify
::
verify_range
(
hs_data
,
hs_data_gold
));
EXPECT(migraphx::verify::verify_
rms_
range(hs_data, hs_data_gold));
}
}
...
...
@@ -2588,7 +2586,7 @@ TEST_CASE(gru_bidirectional_actv_funcs)
0.0248217, 0.435231, -0.144448, 0.101531, -0.111305,
0.381317, 0.468983, 0.230557, 0.348021, 0.180229};
EXPECT
(
migraphx
::
verify
::
verify_range
(
hs_data
,
hs_data_gold
));
EXPECT(migraphx::verify::verify_
rms_
range(hs_data, hs_data_gold));
}
// 1 activation function (sigmoid) specified
...
...
@@ -2631,7 +2629,7 @@ TEST_CASE(gru_bidirectional_actv_funcs)
0.463795, 0.539649, 0.487682, 0.554471, 0.395916, 0.430744, 0.415923, 0.424275,
0.409655, 0.698256, 0.126883, 0.554374, 0.216137, 0.671491, 0.263833, 0.0678646,
0.132732, 0.477083, 0.802206, 0.626802};
EXPECT
(
migraphx
::
verify
::
verify_range
(
hs_data
,
hs_data_gold
));
EXPECT(migraphx::verify::verify_
rms_
range(hs_data, hs_data_gold));
}
// 1 activation function (tanh) specified
...
...
@@ -2675,7 +2673,7 @@ TEST_CASE(gru_bidirectional_actv_funcs)
0.66716, -0.704461, -0.393346, -0.627123, 0.210395, 0.0563026, 0.31419,
0.759629, 0.000258222, 0.350835, -0.682684};
EXPECT
(
migraphx
::
verify
::
verify_range
(
hs_data
,
hs_data_gold
));
EXPECT(migraphx::verify::verify_
rms_
range(hs_data, hs_data_gold));
}
// 3 activation functions specified
...
...
@@ -2715,7 +2713,7 @@ TEST_CASE(gru_bidirectional_actv_funcs)
1.15142, 0.457633, 0.300962, 0.361245, 0.666199,
0.330446, 0.301982, -0.443763, -0.0655817, -0.326473,
0.861394, 0.560799, -0.101768, 0.145142, 0.128956};
EXPECT
(
migraphx
::
verify
::
verify_range
(
hs_data
,
hs_data_gold
));
EXPECT(migraphx::verify::verify_
rms_
range(hs_data, hs_data_gold));
}
// 4 activation functions all specified
...
...
@@ -2763,7 +2761,7 @@ TEST_CASE(gru_bidirectional_actv_funcs)
0.648851, -0.395918, 0.231694, -0.160503, 0.383289, 0.0879262, -0.0254665,
0.079043, 0.322652, 0.752701, 0.243775};
EXPECT
(
migraphx
::
verify
::
verify_range
(
hs_data
,
hs_data_gold
));
EXPECT(migraphx::verify::verify_
rms_
range(hs_data, hs_data_gold));
}
}
...
...
@@ -2878,7 +2876,7 @@ TEST_CASE(gru_bidirectional_seq_1)
-0.0271321, 0.624762, -0.117084, 0.509115, -0.0175078,
-0.144492, -0.0115366, 0.409153, 0.487015, 0.550755};
EXPECT
(
migraphx
::
verify
::
verify_range
(
hs_data
,
hs_data_gold
));
EXPECT(migraphx::verify::verify_
rms_
range(hs_data, hs_data_gold));
}
TEST_CASE(gru_fp16)
...
...
@@ -2988,7 +2986,8 @@ TEST_CASE(gru_fp16)
-0.3969709, 0.43360898, 0.35775262, 0.23280787, -0.52179873,
-0.21944991, 0.4535257, -0.13735442, 0.51757574, 0.50380427};
EXPECT
(
migraphx
::
verify
::
verify_range
(
hs_data
,
hs_data_gold
,
5e4
));
EXPECT(migraphx::verify::verify_range_with_tolerance(
hs_data, migraphx::verify::expected{hs_data_gold}, migraphx::verify::tolerance{0.005}));
}
TEST_CASE(lstm_forward)
...
...
@@ -3119,7 +3118,7 @@ TEST_CASE(lstm_forward)
0.0498799, 0.125772, 0.0533032, -0.131413, 0.0988431, -0.018085, -0.159434,
0.030266, -0.0847427, 0.0874114, 0.304256, -0.0585745, -0.0223018, 0.131113,
0.135643, -0.0566208, 0.142701, 0.0342236, -0.198664, 0.0702607};
EXPECT
(
migraphx
::
verify
::
verify_range
(
hs_data
,
hs_data_gold
));
EXPECT(migraphx::verify::verify_
rms_
range(hs_data, hs_data_gold));
}
// forward, last_output as program output
...
...
@@ -3172,7 +3171,7 @@ TEST_CASE(lstm_forward)
0.0342236,
-0.198664,
0.0702607};
EXPECT
(
migraphx
::
verify
::
verify_range
(
output_data
,
output_data_gold
));
EXPECT(migraphx::verify::verify_
rms_
range(output_data, output_data_gold));
}
// forward, last_cell_output as program output
...
...
@@ -3225,7 +3224,7 @@ TEST_CASE(lstm_forward)
0.078598,
-0.64457,
0.119811};
EXPECT
(
migraphx
::
verify
::
verify_range
(
output_data
,
output_data_gold
));
EXPECT(migraphx::verify::verify_
rms_
range(output_data, output_data_gold));
}
}
...
...
@@ -3347,7 +3346,7 @@ TEST_CASE(lstm_forward_more)
0.00496085, 0.0662588, -0.048577, -0.187329, 0.0855831, -0.0171894, -0.140202,
0.0828391, -0.165194, -0.0372928, 0.273786, -0.100877, -0.0458544, -0.0401315,
0.0737483, -0.064505, 0.136898, 0.00160891, -0.184812, 0.147774};
EXPECT
(
migraphx
::
verify
::
verify_range
(
output_data
,
output_data_gold
));
EXPECT(migraphx::verify::verify_
rms_
range(output_data, output_data_gold));
}
// forward, 8 args
...
...
@@ -3396,7 +3395,7 @@ TEST_CASE(lstm_forward_more)
0.218258, 0.0944405, 0.0431211, -0.132394, 0.103489, 0.0142918, -0.123408,
0.0401075, -0.058052, 0.0795391, 0.266617, -0.0128746, 0.0309878, 0.0971544,
0.149294, -0.0492549, 0.187761, 0.0501726, -0.121584, 0.0606723};
EXPECT
(
migraphx
::
verify
::
verify_range
(
hs_data
,
hs_data_gold
));
EXPECT(migraphx::verify::verify_
rms_
range(hs_data, hs_data_gold));
}
// forward, last_output as program output, sequence length shorter
...
...
@@ -3458,7 +3457,7 @@ TEST_CASE(lstm_forward_more)
0.0342236,
-0.198664,
0.0702607};
EXPECT
(
migraphx
::
verify
::
verify_range
(
output_data
,
output_data_gold
));
EXPECT(migraphx::verify::verify_
rms_
range(output_data, output_data_gold));
}
// seq_len = 1
...
...
@@ -3516,7 +3515,7 @@ TEST_CASE(lstm_forward_more)
-0.121195,
-0.4065,
-0.252054};
EXPECT
(
migraphx
::
verify
::
verify_range
(
hs_data
,
hs_data_gold
));
EXPECT(migraphx::verify::verify_
rms_
range(hs_data, hs_data_gold));
}
}
...
...
@@ -3646,7 +3645,7 @@ TEST_CASE(lstm_reverse)
0.960938, 0.133565, 0.269741, 0.130438, -0.0252804, 0.267356, 0.146353,
0.0789186, -0.185038, -0.026845, 0.177273, -0.0774616, 0.946669, 0.0868676,
0.044508, -0.373961, -0.0681467, 0.382748, 0.230211, -0.161537};
EXPECT
(
migraphx
::
verify
::
verify_range
(
output_data
,
output_data_gold
));
EXPECT(migraphx::verify::verify_
rms_
range(output_data, output_data_gold));
}
// reverse, sequence lengths are the same, but less than max_seq_lens
...
...
@@ -3704,7 +3703,7 @@ TEST_CASE(lstm_reverse)
0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0,
0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0,
0.0, 0.0};
EXPECT
(
migraphx
::
verify
::
verify_range
(
output_data
,
output_data_gold
));
EXPECT(migraphx::verify::verify_
rms_
range(output_data, output_data_gold));
}
// variable sequence lengths
...
...
@@ -3754,7 +3753,7 @@ TEST_CASE(lstm_reverse)
0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0};
EXPECT
(
migraphx
::
verify
::
verify_range
(
output_data
,
output_data_gold
));
EXPECT(migraphx::verify::verify_
rms_
range(output_data, output_data_gold));
}
// reverse, 3 args, last cell output as program output
...
...
@@ -3796,7 +3795,7 @@ TEST_CASE(lstm_reverse)
0.141613,
0.348002,
0.667298};
EXPECT
(
migraphx
::
verify
::
verify_range
(
output_data
,
output_data_gold
));
EXPECT(migraphx::verify::verify_
rms_
range(output_data, output_data_gold));
}
// reverse, 3 args, 0 actv function
...
...
@@ -3835,7 +3834,7 @@ TEST_CASE(lstm_reverse)
0.141613,
0.348002,
0.667298};
EXPECT
(
migraphx
::
verify
::
verify_range
(
output_data
,
output_data_gold
));
EXPECT(migraphx::verify::verify_
rms_
range(output_data, output_data_gold));
}
}
...
...
@@ -3953,7 +3952,7 @@ TEST_CASE(lstm_reverse_actv)
0.310306, 0.262902, 0.276964, 0.295002, 0.373802, 0.366785, 0.419791, 0.393216,
0.262827, 0.371441, 0.369022, 0.298262, 0.334143, 0.309444, 0.174822, 0.251634,
0.244564, 0.214386, 0.185994, 0.226699, 0.28445, 0.376092, 0.338326, 0.259502};
EXPECT
(
migraphx
::
verify
::
verify_range
(
output_data
,
output_data_gold
));
EXPECT(migraphx::verify::verify_
rms_
range(output_data, output_data_gold));
}
// reverse, 3 args, 2 actv functions
...
...
@@ -3994,7 +3993,7 @@ TEST_CASE(lstm_reverse_actv)
0.233866,
0.48646,
0.481844};
EXPECT
(
migraphx
::
verify
::
verify_range
(
output_data
,
output_data_gold
));
EXPECT(migraphx::verify::verify_
rms_
range(output_data, output_data_gold));
}
// reverse, 3 args, seq_len = 1, concatenation of hidden states as program output
...
...
@@ -4040,7 +4039,7 @@ TEST_CASE(lstm_reverse_actv)
0.070535,
0.327809,
0.407388};
EXPECT
(
migraphx
::
verify
::
verify_range
(
output_data
,
output_data_gold
));
EXPECT(migraphx::verify::verify_
rms_
range(output_data, output_data_gold));
}
}
...
...
@@ -4167,7 +4166,7 @@ TEST_CASE(lstm_bidirectional)
0.0971544, 0.149294, -0.0492549, 0.187761, 0.0501726, -0.121584, 0.0606723,
-0.185038, -0.026845, 0.177273, -0.0774616, 0.946669, 0.0868676, 0.044508,
-0.373961, -0.0681467, 0.382748, 0.230211, -0.161537};
EXPECT
(
migraphx
::
verify
::
verify_range
(
output_data
,
output_data_gold
));
EXPECT(migraphx::verify::verify_
rms_
range(output_data, output_data_gold));
}
// last hidden state as program output
...
...
@@ -4210,7 +4209,7 @@ TEST_CASE(lstm_bidirectional)
-0.058052, 0.0795391, 0.266617, -0.0128746, 0.0309878, 0.0971544, 0.149294, -0.0492549,
0.187761, 0.0501726, -0.121584, 0.0606723, -0.120174, 0.043157, 0.117138, -0.222188,
0.789732, 0.128538, 0.20909, 0.0553812, -0.224905, 0.32421, 0.344048, 0.271694};
EXPECT
(
migraphx
::
verify
::
verify_range
(
output_data
,
output_data_gold
));
EXPECT(migraphx::verify::verify_
rms_
range(output_data, output_data_gold));
}
// last cell output as program output
...
...
@@ -4253,7 +4252,7 @@ TEST_CASE(lstm_bidirectional)
-0.077353, 0.245616, 0.361023, -0.0443759, 0.0685243, 0.20465, 0.277867, -0.112934,
0.67312, 0.120508, -0.726968, 0.113845, -0.889294, 0.182463, 0.186512, -0.402334,
1.48161, 0.524116, 0.347113, 0.181813, -0.434265, 0.747833, 0.416053, 0.558713};
EXPECT
(
migraphx
::
verify
::
verify_range
(
output_data
,
output_data_gold
));
EXPECT(migraphx::verify::verify_
rms_
range(output_data, output_data_gold));
}
// 3 args, concatenation of hidden states as program output
...
...
@@ -4296,7 +4295,7 @@ TEST_CASE(lstm_bidirectional)
-0.0401315, 0.0737483, -0.064505, 0.136898, 0.00160891, -0.184812, 0.147774,
-0.021205, -0.125423, 0.0206439, -0.187097, -0.0051453, -0.0767618, -0.0735348,
-0.0826436, 0.214159, 0.262295, 0.0247127, 0.14472};
EXPECT
(
migraphx
::
verify
::
verify_range
(
output_data
,
output_data_gold
));
EXPECT(migraphx::verify::verify_
rms_
range(output_data, output_data_gold));
}
// sequence length is 1, contenation of hidden state as program output
...
...
@@ -4333,7 +4332,7 @@ TEST_CASE(lstm_bidirectional)
-0.0623361, 0.0598866, 0.101585, 0.0687269, -0.161725, -0.25617,
-0.104351, -0.0471426, -0.0905753, 0.01506, 0.059797, 0.104239,
-0.0266768, 0.0727547, -0.146298, 0.070535, 0.327809, 0.407388};
EXPECT
(
migraphx
::
verify
::
verify_range
(
output_data
,
output_data_gold
));
EXPECT(migraphx::verify::verify_
rms_
range(output_data, output_data_gold));
}
}
...
...
@@ -4485,9 +4484,9 @@ TEST_CASE(lstm_bidirectional_var_seq_lens)
0.391174, 0.0308845, -0.561745, 0.0730323, -0.326822, 0.301121, 0.219523, 0.415242,
2.08242, 0.442513, 0.187127, 0.0577626, -0.611307, 0.55454, 0.4364, 0.509436};
EXPECT
(
migraphx
::
verify
::
verify_range
(
output_data
,
output_data_gold
));
EXPECT
(
migraphx
::
verify
::
verify_range
(
last_output_data
,
last_output_data_gold
));
EXPECT
(
migraphx
::
verify
::
verify_range
(
last_cell_data
,
last_cell_data_gold
));
EXPECT(migraphx::verify::verify_
rms_
range(output_data, output_data_gold));
EXPECT(migraphx::verify::verify_
rms_
range(last_output_data, last_output_data_gold));
EXPECT(migraphx::verify::verify_
rms_
range(last_cell_data, last_cell_data_gold));
}
// last cell output as program output
...
...
@@ -4572,9 +4571,9 @@ TEST_CASE(lstm_bidirectional_var_seq_lens)
-0.077353, 0.245616, 0.361023, -0.0443759, 0.0685243, 0.20465, 0.277867, -0.112934,
0.67312, 0.120508, -0.726968, 0.113845, -0.889294, 0.182463, 0.186512, -0.402334,
1.48161, 0.524116, 0.347113, 0.181813, -0.434265, 0.747833, 0.416053, 0.558713};
EXPECT
(
migraphx
::
verify
::
verify_range
(
hs_data
,
hs_data_gold
));
EXPECT
(
migraphx
::
verify
::
verify_range
(
lho_data
,
lho_data_gold
));
EXPECT
(
migraphx
::
verify
::
verify_range
(
lco_data
,
lco_data_gold
));
EXPECT(migraphx::verify::verify_
rms_
range(hs_data, hs_data_gold));
EXPECT(migraphx::verify::verify_
rms_
range(lho_data, lho_data_gold));
EXPECT(migraphx::verify::verify_
rms_
range(lco_data, lco_data_gold));
}
}
...
...
@@ -4659,7 +4658,7 @@ TEST_CASE(lstm_bidirectional_actv_func)
-0.0401315, 0.0737483, -0.064505, 0.136898, 0.00160891, -0.184812, 0.147774,
-0.021205, -0.125423, 0.0206439, -0.187097, -0.0051453, -0.0767618, -0.0735348,
-0.0826436, 0.214159, 0.262295, 0.0247127, 0.14472};
EXPECT
(
migraphx
::
verify
::
verify_range
(
output_data
,
output_data_gold
));
EXPECT(migraphx::verify::verify_
rms_
range(output_data, output_data_gold));
}
// 3 args, 1 actv func
...
...
@@ -4699,7 +4698,7 @@ TEST_CASE(lstm_bidirectional_actv_func)
0.450186, 0.263538, 0.402895, 0.216177, 0.267257, 0.342535, 0.257797, 0.268563,
0.193043, 0.275645, 0.167678, 0.350889, 0.334143, 0.309444, 0.174822, 0.251634,
0.244564, 0.214386, 0.185994, 0.226699, 0.28445, 0.376092, 0.338326, 0.259502};
EXPECT
(
migraphx
::
verify
::
verify_range
(
output_data
,
output_data_gold
));
EXPECT(migraphx::verify::verify_
rms_
range(output_data, output_data_gold));
}
// 3 args, 2 actv func
...
...
@@ -4732,7 +4731,7 @@ TEST_CASE(lstm_bidirectional_actv_func)
0.0737483, -0.064505, 0.136898, 0.00160891, -0.184812, 0.147774,
-0.162851, -0.102647, -0.113827, -0.142818, 0.0513685, 0.0547876,
0.0201981, -0.00808453, -0.00520328, 0.0945081, 0.264123, 0.410805};
EXPECT
(
migraphx
::
verify
::
verify_range
(
output_data
,
output_data_gold
));
EXPECT(migraphx::verify::verify_
rms_
range(output_data, output_data_gold));
}
// 3 args, 4 actv func
...
...
@@ -4768,7 +4767,7 @@ TEST_CASE(lstm_bidirectional_actv_func)
0.0737483, -0.064505, 0.136898, 0.00160891, -0.184812, 0.147774,
0.246078, 0.199709, 0.303753, 0.301178, 0.264634, 0.304661,
0.349371, 0.288934, 0.405483, 0.445586, 0.515814, 0.473186};
EXPECT
(
migraphx
::
verify
::
verify_range
(
output_data
,
output_data_gold
));
EXPECT(migraphx::verify::verify_
rms_
range(output_data, output_data_gold));
}
// 3 args, 5 actv func
...
...
@@ -4804,7 +4803,7 @@ TEST_CASE(lstm_bidirectional_actv_func)
0.0737483, -0.064505, 0.136898, 0.00160891, -0.184812, 0.147774,
-0.162851, -0.102647, -0.113827, -0.142818, 0.0513685, 0.0547876,
0.0201981, -0.00808453, -0.00520328, 0.0945081, 0.264123, 0.410805};
EXPECT
(
migraphx
::
verify
::
verify_range
(
output_data
,
output_data_gold
));
EXPECT(migraphx::verify::verify_
rms_
range(output_data, output_data_gold));
}
// 3 args, 6 actv func
...
...
@@ -4841,7 +4840,7 @@ TEST_CASE(lstm_bidirectional_actv_func)
0.0737483, -0.064505, 0.136898, 0.00160891, -0.184812, 0.147774,
-0.162851, -0.102647, -0.113827, -0.142818, 0.0513685, 0.0547876,
0.0201981, -0.00808453, -0.00520328, 0.0945081, 0.264123, 0.410805};
EXPECT
(
migraphx
::
verify
::
verify_range
(
output_data
,
output_data_gold
));
EXPECT(migraphx::verify::verify_
rms_
range(output_data, output_data_gold));
}
}
...
...
@@ -4986,5 +4985,5 @@ TEST_CASE(lstm_fp16)
0.0498799, 0.125772, 0.0533032, -0.131413, 0.0988431, -0.018085, -0.159434,
0.030266, -0.0847427, 0.0874114, 0.304256, -0.0585745, -0.0223018, 0.131113,
0.135643, -0.0566208, 0.142701, 0.0342236, -0.198664, 0.0702607};
EXPECT
(
migraphx
::
verify
::
verify_range
(
hs_data
,
hs_data_gold
,
5e4
));
EXPECT(migraphx::verify::verify_
rms_
range(hs_data, hs_data_gold, 5e4));
}
test/ref/roialign.cpp
View file @
13d14c66
...
...
@@ -24,7 +24,7 @@
#include <migraphx/instruction.hpp>
#include <migraphx/literal.hpp>
#include <migraphx/make_op.hpp>
#include <migraphx/
onnx
.hpp>
#include <migraphx/
program
.hpp>
#include <migraphx/op/pooling.hpp>
#include <migraphx/register_target.hpp>
#include <migraphx/verify.hpp>
...
...
@@ -73,14 +73,14 @@ TEST_CASE(roialign_out_of_bound_test)
};
{
auto
p
=
create_program
(
"
output_
half_pixel"
);
auto
p
=
create_program
(
"half_pixel"
);
p
.
compile
(
migraphx
::
make_target
(
"ref"
));
auto
result
=
p
.
eval
({}).
back
();
std
::
vector
<
float
>
results_vector
;
result
.
visit
([
&
](
auto
output
)
{
results_vector
.
assign
(
output
.
begin
(),
output
.
end
());
});
std
::
vector
<
float
>
gold
=
{
0.0
f
,
0.0
f
,
0.0
f
};
EXPECT
(
migraphx
::
verify
::
verify_range
(
results_vector
,
gold
));
EXPECT
(
migraphx
::
verify
::
verify_
rms_
range
(
results_vector
,
gold
));
}
}
...
...
@@ -130,7 +130,7 @@ TEST_CASE(roialign_test)
};
{
auto
p
=
create_program
();
auto
p
=
create_program
(
"output_half_pixel"
);
p
.
compile
(
migraphx
::
make_target
(
"ref"
));
auto
result
=
p
.
eval
({}).
back
();
std
::
vector
<
float
>
results_vector
;
...
...
@@ -150,11 +150,11 @@ TEST_CASE(roialign_test)
0.256580025
,
0.214098021
,
0.279604018
,
0.360000014
,
0.436488032
,
0.350427985
,
0.288755983
,
0.366139978
,
0.234920025
};
EXPECT
(
migraphx
::
verify
::
verify_range
(
results_vector
,
gold
));
EXPECT
(
migraphx
::
verify
::
verify_
rms_
range
(
results_vector
,
gold
));
}
{
auto
p
=
create_program
(
"
output_
half_pixel"
);
auto
p
=
create_program
(
"half_pixel"
);
p
.
compile
(
migraphx
::
make_target
(
"ref"
));
auto
result
=
p
.
eval
({}).
back
();
std
::
vector
<
float
>
results_vector
;
...
...
@@ -171,11 +171,11 @@ TEST_CASE(roialign_test)
0.929997
,
0.66257
,
0.561664
,
0.481275
,
0.495449
,
0.666306
,
0.663573
,
0.372107
,
0.205603
,
0.192776
,
0.247849
};
EXPECT
(
migraphx
::
verify
::
verify_range
(
results_vector
,
gold
));
EXPECT
(
migraphx
::
verify
::
verify_
rms_
range
(
results_vector
,
gold
));
}
{
auto
p
=
create_program
(
"
output_
half_pixel"
,
migraphx
::
op
::
pooling_mode
::
max
,
0
);
auto
p
=
create_program
(
"half_pixel"
,
migraphx
::
op
::
pooling_mode
::
max
,
0
);
p
.
compile
(
migraphx
::
make_target
(
"ref"
));
auto
result
=
p
.
eval
({}).
back
();
std
::
vector
<
float
>
results_vector
;
...
...
@@ -192,6 +192,6 @@ TEST_CASE(roialign_test)
0.44757
,
0.351855
,
0.342265
,
0.244475
,
0.274841
,
0.553644
,
0.607176
,
0.202392
,
0.07425
,
0.066087
,
0.126279
};
EXPECT
(
migraphx
::
verify
::
verify_range
(
results_vector
,
gold
));
EXPECT
(
migraphx
::
verify
::
verify_
rms_
range
(
results_vector
,
gold
));
}
}
test/ref/round.cpp
View file @
13d14c66
...
...
@@ -24,7 +24,7 @@
#include <migraphx/instruction.hpp>
#include <migraphx/literal.hpp>
#include <migraphx/make_op.hpp>
#include <migraphx/
onnx
.hpp>
#include <migraphx/
program
.hpp>
#include <migraphx/register_target.hpp>
#include <migraphx/verify.hpp>
...
...
@@ -43,7 +43,7 @@ TEST_CASE(round_test)
std
::
vector
<
float
>
results_vector
;
result
.
visit
([
&
](
auto
output
)
{
results_vector
.
assign
(
output
.
begin
(),
output
.
end
());
});
std
::
vector
<
float
>
gold
=
{
1.0
,
2.0
,
2.0
,
-
1.0
,
-
2.0
,
-
2.0
,
0.0
,
2.0
,
-
2.0
};
EXPECT
(
migraphx
::
verify
::
verify_range
(
results_vector
,
gold
));
EXPECT
(
migraphx
::
verify
::
verify_
rms_
range
(
results_vector
,
gold
));
}
TEST_CASE
(
round_dyn_test
)
...
...
@@ -64,5 +64,5 @@ TEST_CASE(round_dyn_test)
std
::
vector
<
float
>
results_vector
;
result
.
visit
([
&
](
auto
output
)
{
results_vector
.
assign
(
output
.
begin
(),
output
.
end
());
});
std
::
vector
<
float
>
gold
=
{
1.0
,
2.0
,
2.0
,
-
1.0
,
-
2.0
,
-
2.0
,
0.0
,
2.0
,
-
2.0
};
EXPECT
(
migraphx
::
verify
::
verify_range
(
results_vector
,
gold
));
EXPECT
(
migraphx
::
verify
::
verify_
rms_
range
(
results_vector
,
gold
));
}
test/ref/rsqrt.cpp
View file @
13d14c66
...
...
@@ -24,7 +24,7 @@
#include <migraphx/instruction.hpp>
#include <migraphx/literal.hpp>
#include <migraphx/make_op.hpp>
#include <migraphx/
onnx
.hpp>
#include <migraphx/
program
.hpp>
#include <migraphx/register_target.hpp>
#include <migraphx/verify.hpp>
...
...
@@ -42,7 +42,7 @@ TEST_CASE(rsqrt_test)
std
::
vector
<
float
>
results_vector
(
3
);
result
.
visit
([
&
](
auto
output
)
{
results_vector
.
assign
(
output
.
begin
(),
output
.
end
());
});
std
::
vector
<
float
>
gold
=
{
0.5
,
0.25
,
0.125
};
EXPECT
(
migraphx
::
verify
::
verify_range
(
results_vector
,
gold
));
EXPECT
(
migraphx
::
verify
::
verify_
rms_
range
(
results_vector
,
gold
));
}
TEST_CASE
(
rsqrt_dyn_test
)
...
...
@@ -63,5 +63,5 @@ TEST_CASE(rsqrt_dyn_test)
std
::
vector
<
float
>
results_vector
(
3
);
result
.
visit
([
&
](
auto
output
)
{
results_vector
.
assign
(
output
.
begin
(),
output
.
end
());
});
std
::
vector
<
float
>
gold
=
{
0.5
,
0.25
,
0.125
};
EXPECT
(
migraphx
::
verify
::
verify_range
(
results_vector
,
gold
));
EXPECT
(
migraphx
::
verify
::
verify_
rms_
range
(
results_vector
,
gold
));
}
test/ref/scalar.cpp
View file @
13d14c66
...
...
@@ -24,7 +24,7 @@
#include <migraphx/instruction.hpp>
#include <migraphx/literal.hpp>
#include <migraphx/make_op.hpp>
#include <migraphx/
onnx
.hpp>
#include <migraphx/
program
.hpp>
#include <migraphx/register_target.hpp>
#include <migraphx/verify.hpp>
...
...
@@ -77,5 +77,5 @@ TEST_CASE(imagescaler_test)
0.53
,
0.73
,
0.93
};
EXPECT
(
migraphx
::
verify
::
verify_range
(
results_vector
,
gold
));
EXPECT
(
migraphx
::
verify
::
verify_
rms_
range
(
results_vector
,
gold
));
}
test/ref/scatter.cpp
View file @
13d14c66
...
...
@@ -24,7 +24,7 @@
#include <migraphx/instruction.hpp>
#include <migraphx/literal.hpp>
#include <migraphx/make_op.hpp>
#include <migraphx/
onnx
.hpp>
#include <migraphx/
program
.hpp>
#include <migraphx/register_target.hpp>
#include <migraphx/verify.hpp>
...
...
@@ -64,7 +64,7 @@ TEST_CASE(scatter_ax0_test)
std
::
vector
<
float
>
results_vector
;
result
.
visit
([
&
](
auto
output
)
{
results_vector
.
assign
(
output
.
begin
(),
output
.
end
());
});
std
::
vector
<
float
>
gold
=
{
2.0
,
1.1
,
0.0
,
1.0
,
0.0
,
2.2
,
0.0
,
2.1
,
1.2
};
EXPECT
(
migraphx
::
verify
::
verify_range
(
results_vector
,
gold
));
EXPECT
(
migraphx
::
verify
::
verify_
rms_
range
(
results_vector
,
gold
));
}
}
...
...
@@ -78,7 +78,7 @@ TEST_CASE(scatter_ax_neg_test)
std
::
vector
<
float
>
results_vector
;
result
.
visit
([
&
](
auto
output
)
{
results_vector
.
assign
(
output
.
begin
(),
output
.
end
());
});
std
::
vector
<
float
>
gold
=
{
2.0
,
1.1
,
0.0
,
1.0
,
0.0
,
2.2
,
0.0
,
2.1
,
1.2
};
EXPECT
(
migraphx
::
verify
::
verify_range
(
results_vector
,
gold
));
EXPECT
(
migraphx
::
verify
::
verify_
rms_
range
(
results_vector
,
gold
));
}
}
...
...
@@ -91,7 +91,7 @@ TEST_CASE(scatter_ax1_test)
std
::
vector
<
float
>
results_vector
;
result
.
visit
([
&
](
auto
output
)
{
results_vector
.
assign
(
output
.
begin
(),
output
.
end
());
});
std
::
vector
<
float
>
gold
=
{
1.1
,
1.0
,
1.2
,
2.0
,
2.2
,
2.1
,
0.0
,
0.0
,
0.0
};
EXPECT
(
migraphx
::
verify
::
verify_range
(
results_vector
,
gold
));
EXPECT
(
migraphx
::
verify
::
verify_
rms_
range
(
results_vector
,
gold
));
}
}
...
...
@@ -128,7 +128,7 @@ TEST_CASE(scatter_reduction1_test)
std
::
vector
<
float
>
results_vector
;
result
.
visit
([
&
](
auto
output
)
{
results_vector
.
assign
(
output
.
begin
(),
output
.
end
());
});
std
::
vector
<
float
>
gold_none
=
{
1.0
,
1.1
,
3.0
,
2.1
,
5.0
};
EXPECT
(
migraphx
::
verify
::
verify_range
(
results_vector
,
gold_none
));
EXPECT
(
migraphx
::
verify
::
verify_
rms_
range
(
results_vector
,
gold_none
));
}
}
...
...
@@ -142,7 +142,7 @@ TEST_CASE(scatter_reduction2_test)
result
.
visit
([
&
](
auto
output
)
{
results_vector
.
assign
(
output
.
begin
(),
output
.
end
());
});
std
::
vector
<
float
>
gold_mul
=
{
1.0
,
2.2
,
3.0
,
8.4
,
5.0
};
EXPECT
(
migraphx
::
verify
::
verify_range
(
results_vector
,
gold_mul
));
EXPECT
(
migraphx
::
verify
::
verify_
rms_
range
(
results_vector
,
gold_mul
));
}
}
TEST_CASE
(
scatter_reduction3_test
)
...
...
@@ -155,7 +155,7 @@ TEST_CASE(scatter_reduction3_test)
result
.
visit
([
&
](
auto
output
)
{
results_vector
.
assign
(
output
.
begin
(),
output
.
end
());
});
std
::
vector
<
float
>
gold_add
=
{
1.0
,
3.1
,
3.0
,
6.1
,
5.0
};
EXPECT
(
migraphx
::
verify
::
verify_range
(
results_vector
,
gold_add
));
EXPECT
(
migraphx
::
verify
::
verify_
rms_
range
(
results_vector
,
gold_add
));
}
}
...
...
@@ -184,7 +184,7 @@ TEST_CASE(scatter_reduction_3x3_test)
result
.
visit
([
&
](
auto
output
)
{
results_vector
.
assign
(
output
.
begin
(),
output
.
end
());
});
std
::
vector
<
float
>
gold_a2
=
{
4.1
,
4.0
,
4.2
,
10.0
,
10.2
,
10.1
,
3.0
,
3.0
,
3.0
};
EXPECT
(
migraphx
::
verify
::
verify_range
(
results_vector
,
gold_a2
));
EXPECT
(
migraphx
::
verify
::
verify_
rms_
range
(
results_vector
,
gold_a2
));
}
}
...
...
@@ -221,7 +221,7 @@ TEST_CASE(scatter_reduction_3x3_xpose1_test)
std
::
vector
<
float
>
results_vector
;
result
.
visit
([
&
](
auto
output
)
{
results_vector
.
assign
(
output
.
begin
(),
output
.
end
());
});
std
::
vector
<
float
>
gold_none2
=
{
1.1
,
7.0
,
3.0
,
1.0
,
7.2
,
3.0
,
1.2
,
7.1
,
3.0
};
EXPECT
(
migraphx
::
verify
::
verify_range
(
results_vector
,
gold_none2
));
EXPECT
(
migraphx
::
verify
::
verify_
rms_
range
(
results_vector
,
gold_none2
));
}
}
...
...
@@ -236,7 +236,7 @@ TEST_CASE(scatter_reduction_3x3_xpose2_test)
result
.
visit
([
&
](
auto
output
)
{
results_vector
.
assign
(
output
.
begin
(),
output
.
end
());
});
std
::
vector
<
float
>
gold_a3
=
{
4.1
,
10.0
,
3.0
,
4.0
,
10.2
,
3.0
,
4.2
,
10.1
,
3.0
};
EXPECT
(
migraphx
::
verify
::
verify_range
(
results_vector
,
gold_a3
));
EXPECT
(
migraphx
::
verify
::
verify_
rms_
range
(
results_vector
,
gold_a3
));
}
}
...
...
@@ -250,6 +250,6 @@ TEST_CASE(scatter_reduction_3x3_xpose3_test)
result
.
visit
([
&
](
auto
output
)
{
results_vector
.
assign
(
output
.
begin
(),
output
.
end
());
});
std
::
vector
<
float
>
gold_mul2
=
{
3.3
,
21.0
,
3.0
,
3.0
,
21.6
,
3.0
,
3.6
,
21.3
,
3.0
};
EXPECT
(
migraphx
::
verify
::
verify_range
(
results_vector
,
gold_mul2
));
EXPECT
(
migraphx
::
verify
::
verify_
rms_
range
(
results_vector
,
gold_mul2
));
}
}
test/ref/scatternd_add.cpp
View file @
13d14c66
...
...
@@ -24,7 +24,7 @@
#include <migraphx/instruction.hpp>
#include <migraphx/literal.hpp>
#include <migraphx/make_op.hpp>
#include <migraphx/
onnx
.hpp>
#include <migraphx/
program
.hpp>
#include <migraphx/register_target.hpp>
#include <migraphx/verify.hpp>
...
...
@@ -57,7 +57,7 @@ TEST_CASE(scatternd_add_reduction_test)
result
.
visit
([
&
](
auto
output
)
{
results_vector
.
assign
(
output
.
begin
(),
output
.
end
());
});
std
::
vector
<
float
>
gold
{
1
,
3
,
3
,
5
,
6
,
6
,
7
,
9
};
EXPECT
(
migraphx
::
verify
::
verify_range
(
results_vector
,
gold
));
EXPECT
(
migraphx
::
verify
::
verify_
rms_
range
(
results_vector
,
gold
));
}
TEST_CASE
(
scatternd_reduction_dyn_test
)
...
...
@@ -102,5 +102,5 @@ TEST_CASE(scatternd_reduction_dyn_test)
1
,
2
,
3
,
4
,
5
,
6
,
7
,
8
,
8
,
7
,
6
,
5
,
4
,
3
,
2
,
1
,
9
,
8
,
7
,
6
,
6
,
5
,
4
,
3
,
4
,
5
,
6
,
7
,
9
,
10
,
11
,
12
,
8
,
7
,
6
,
5
,
4
,
3
,
2
,
1
,
1
,
2
,
3
,
4
,
5
,
6
,
7
,
8
};
EXPECT
(
migraphx
::
verify
::
verify_range
(
results_vector
,
gold
));
EXPECT
(
migraphx
::
verify
::
verify_
rms_
range
(
results_vector
,
gold
));
}
Prev
1
…
14
15
16
17
18
19
20
21
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment