Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
MIGraphX
Commits
29820def
"ollama/llm/llama.cpp/examples/rpc" did not exist on "22cb4ffc026b1fb71549031f174dc92f3751db56"
Commit
29820def
authored
Sep 21, 2023
by
Paul
Browse files
Merge
parents
6aa89319
be33669b
Changes
147
Expand all
Show whitespace changes
Inline
Side-by-side
Showing
20 changed files
with
242 additions
and
185 deletions
+242
-185
src/targets/gpu/include/migraphx/gpu/mlir.hpp
src/targets/gpu/include/migraphx/gpu/mlir.hpp
+2
-1
src/targets/gpu/jit/mlir.cpp
src/targets/gpu/jit/mlir.cpp
+1
-3
src/targets/gpu/mlir.cpp
src/targets/gpu/mlir.cpp
+9
-6
src/targets/gpu/prefuse_ops.cpp
src/targets/gpu/prefuse_ops.cpp
+21
-18
src/verify_args.cpp
src/verify_args.cpp
+19
-17
test/gpu/codegen_literal.cpp
test/gpu/codegen_literal.cpp
+1
-1
test/gpu/manage_host_buffer.cpp
test/gpu/manage_host_buffer.cpp
+2
-2
test/gpu/mlir.cpp
test/gpu/mlir.cpp
+2
-1
test/gpu/quantization.cpp
test/gpu/quantization.cpp
+6
-4
test/onnx/.onnxrt-commit
test/onnx/.onnxrt-commit
+1
-1
test/onnx/verify_onnx.cpp
test/onnx/verify_onnx.cpp
+96
-96
test/op_shape_test.cpp
test/op_shape_test.cpp
+44
-0
test/quantization.cpp
test/quantization.cpp
+8
-5
test/ref/abs.cpp
test/ref/abs.cpp
+3
-3
test/ref/acos.cpp
test/ref/acos.cpp
+3
-3
test/ref/acosh.cpp
test/ref/acosh.cpp
+3
-3
test/ref/add.cpp
test/ref/add.cpp
+7
-7
test/ref/allocate.cpp
test/ref/allocate.cpp
+1
-1
test/ref/argmax.cpp
test/ref/argmax.cpp
+7
-7
test/ref/argmin.cpp
test/ref/argmin.cpp
+6
-6
No files found.
src/targets/gpu/include/migraphx/gpu/mlir.hpp
View file @
29820def
...
@@ -49,7 +49,8 @@ MIGRAPHX_GPU_EXPORT instruction_ref insert_mlir(module& m,
...
@@ -49,7 +49,8 @@ MIGRAPHX_GPU_EXPORT instruction_ref insert_mlir(module& m,
MIGRAPHX_GPU_EXPORT
tuning_config
get_tuning_config_mlir
(
const
context
&
migraphx_ctx
,
MIGRAPHX_GPU_EXPORT
tuning_config
get_tuning_config_mlir
(
const
context
&
migraphx_ctx
,
module
m
,
module
m
,
const
std
::
vector
<
shape
>&
inputs
);
const
std
::
vector
<
shape
>&
inputs
,
bool
exhaustive
);
}
// namespace gpu
}
// namespace gpu
}
// namespace MIGRAPHX_INLINE_NS
}
// namespace MIGRAPHX_INLINE_NS
...
...
src/targets/gpu/jit/mlir.cpp
View file @
29820def
...
@@ -57,11 +57,9 @@ struct mlir_compiler : compiler<mlir_compiler>
...
@@ -57,11 +57,9 @@ struct mlir_compiler : compiler<mlir_compiler>
const
operation
&
,
const
operation
&
,
bool
exhaustive
)
const
bool
exhaustive
)
const
{
{
if
(
not
exhaustive
)
return
nullopt
;
auto
shapes
=
to_shapes
(
ins
->
inputs
());
auto
shapes
=
to_shapes
(
ins
->
inputs
());
auto
*
smod
=
ins
->
module_inputs
().
front
();
auto
*
smod
=
ins
->
module_inputs
().
front
();
return
get_tuning_config_mlir
(
ctx
,
*
smod
,
shapes
);
return
get_tuning_config_mlir
(
ctx
,
*
smod
,
shapes
,
exhaustive
);
}
}
};
};
...
...
src/targets/gpu/mlir.cpp
View file @
29820def
...
@@ -682,11 +682,12 @@ struct mlir_program
...
@@ -682,11 +682,12 @@ struct mlir_program
MIGRAPHX_THROW
(
"Failed setting tuning key: "
+
*
str
);
MIGRAPHX_THROW
(
"Failed setting tuning key: "
+
*
str
);
}
}
tuning_config
get_tuning_config
()
MIGRAPHX_TIDY_CONST
tuning_config
get_tuning_config
(
bool
exhaustive
)
MIGRAPHX_TIDY_CONST
{
{
tuning_config
tc
;
tuning_config
tc
;
run_high_level_pipeline
();
run_high_level_pipeline
();
auto
tuning_mode
=
RocmlirTuningParamSetKindFull
;
auto
tuning_mode
=
exhaustive
?
RocmlirTuningParamSetKindFull
:
RocmlirTuningParamSetKindQuick
;
if
(
enabled
(
MIGRAPHX_MLIR_TUNE_EXHAUSTIVE
{}))
if
(
enabled
(
MIGRAPHX_MLIR_TUNE_EXHAUSTIVE
{}))
tuning_mode
=
RocmlirTuningParamSetKindExhaustive
;
tuning_mode
=
RocmlirTuningParamSetKindExhaustive
;
mlir_tuning_space
params
{
mlirRockTuningSpaceCreate
(
mmodule
.
get
(),
tuning_mode
)};
mlir_tuning_space
params
{
mlirRockTuningSpaceCreate
(
mmodule
.
get
(),
tuning_mode
)};
...
@@ -914,15 +915,17 @@ instruction_ref insert_mlir(module& m,
...
@@ -914,15 +915,17 @@ instruction_ref insert_mlir(module& m,
return
m
.
insert_instruction
(
ins
,
co
,
refs
);
return
m
.
insert_instruction
(
ins
,
co
,
refs
);
}
}
tuning_config
tuning_config
get_tuning_config_mlir
(
const
context
&
migraphx_ctx
,
get_tuning_config_mlir
(
const
context
&
migraphx_ctx
,
module
m
,
const
std
::
vector
<
shape
>&
inputs
)
module
m
,
const
std
::
vector
<
shape
>&
inputs
,
bool
exhaustive
)
{
{
adjust_param_shapes
(
m
,
inputs
);
adjust_param_shapes
(
m
,
inputs
);
mlir_program
mp
;
mlir_program
mp
;
mp
.
set_gpu_properties
(
migraphx_ctx
);
mp
.
set_gpu_properties
(
migraphx_ctx
);
mp
.
parse
(
m
);
mp
.
parse
(
m
);
return
mp
.
get_tuning_config
();
return
mp
.
get_tuning_config
(
exhaustive
);
}
}
#else
#else
...
@@ -951,7 +954,7 @@ insert_mlir(module& m, instruction_ref, code_object_op co, const std::vector<ins
...
@@ -951,7 +954,7 @@ insert_mlir(module& m, instruction_ref, code_object_op co, const std::vector<ins
return
m
.
end
();
return
m
.
end
();
}
}
tuning_config
get_tuning_config_mlir
(
const
context
&
,
module
,
const
std
::
vector
<
shape
>&
)
tuning_config
get_tuning_config_mlir
(
const
context
&
,
module
,
const
std
::
vector
<
shape
>&
,
bool
)
{
{
return
{};
return
{};
}
}
...
...
src/targets/gpu/prefuse_ops.cpp
View file @
29820def
...
@@ -21,6 +21,7 @@
...
@@ -21,6 +21,7 @@
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
* THE SOFTWARE.
* THE SOFTWARE.
*/
*/
#include <migraphx/permutation.hpp>
#include <migraphx/gpu/prefuse_ops.hpp>
#include <migraphx/gpu/prefuse_ops.hpp>
#include <migraphx/match/layernorm.hpp>
#include <migraphx/match/layernorm.hpp>
#include <migraphx/check_shapes.hpp>
#include <migraphx/check_shapes.hpp>
...
@@ -45,40 +46,42 @@ struct layernorm_base
...
@@ -45,40 +46,42 @@ struct layernorm_base
}
}
shape
compute_shape
(
std
::
vector
<
shape
>
inputs
,
std
::
vector
<
module_ref
>
mods
)
const
shape
compute_shape
(
std
::
vector
<
shape
>
inputs
,
std
::
vector
<
module_ref
>
mods
)
const
{
{
std
::
size_t
nargs
=
1
;
std
::
size_t
nargs
=
N
;
if
(
not
mods
.
empty
())
if
(
not
mods
.
empty
())
{
{
auto
*
pm
=
mods
.
front
();
auto
*
pm
=
mods
.
front
();
nargs
=
pm
->
get_parameter_names
().
size
();
nargs
+
=
pm
->
get_parameter_names
().
size
()
-
1
;
}
}
check_shapes
{
inputs
,
static_cast
<
const
Derived
&>
(
*
this
)}.
has
(
nargs
+
N
);
check_shapes
{
inputs
,
static_cast
<
const
Derived
&>
(
*
this
)}.
has
(
nargs
);
auto
s
=
inputs
.
a
t
(
0
);
auto
s
=
inputs
.
fron
t
();
auto
t
=
s
.
type
();
auto
t
=
s
.
type
();
if
(
not
mods
.
empty
())
if
(
not
mods
.
empty
())
t
=
mods
.
front
()
->
get_output_shapes
().
front
().
type
();
t
=
mods
.
front
()
->
get_output_shapes
().
front
().
type
();
if
(
s
.
scalar
())
{
// Scalar output if all inputs are scalar
return
s
;
if
(
inputs
.
front
().
elements
()
==
1
and
}
all_of
(
inputs
,
[](
const
auto
&
ss
)
{
return
ss
.
scalar
();
}))
else
if
(
s
.
broadcasted
())
return
inputs
.
front
();
{
auto
l_s
=
shape
::
from_permutation
(
return
{
t
,
s
.
lens
()};
t
,
s
.
lens
(),
find_permutation
(
std
::
vector
<
shape
>
(
inputs
.
begin
(),
inputs
.
begin
()
+
N
)));
}
// just prelayernorm or preadd_layernorm
else
if
(
nargs
<=
N
)
{
return
l_s
;
return
s
.
with_lens
(
t
,
s
.
lens
());
// else, layernorm + pointwise fusion, preserve layout of fused op
}
std
::
vector
<
shape
>
lp_s
(
inputs
.
begin
()
+
N
,
inputs
.
end
());
lp_s
.
insert
(
lp_s
.
begin
(),
l_s
);
return
shape
::
from_permutation
(
t
,
s
.
lens
(),
find_permutation
(
lp_s
));
}
}
};
};
struct
layernorm
:
layernorm_base
<
layernorm
,
0
>
struct
layernorm
:
layernorm_base
<
layernorm
,
1
>
{
{
std
::
string
name
()
const
{
return
"gpu::prelayernorm"
;
}
std
::
string
name
()
const
{
return
"gpu::prelayernorm"
;
}
};
};
MIGRAPHX_REGISTER_OP
(
layernorm
);
MIGRAPHX_REGISTER_OP
(
layernorm
);
struct
add_layernorm
:
layernorm_base
<
add_layernorm
,
1
>
struct
add_layernorm
:
layernorm_base
<
add_layernorm
,
2
>
{
{
std
::
string
name
()
const
{
return
"gpu::preadd_layernorm"
;
}
std
::
string
name
()
const
{
return
"gpu::preadd_layernorm"
;
}
};
};
...
...
src/verify_args.cpp
View file @
29820def
...
@@ -28,19 +28,20 @@ namespace migraphx {
...
@@ -28,19 +28,20 @@ namespace migraphx {
inline
namespace
MIGRAPHX_INLINE_NS
{
inline
namespace
MIGRAPHX_INLINE_NS
{
bool
verify_args
(
const
std
::
string
&
name
,
bool
verify_args
(
const
std
::
string
&
name
,
const
argument
&
ref_arg
,
const
argument
&
target_arg
,
const
argument
&
target_arg
,
double
tolerance
)
const
verify
::
expected
<
argument
>&
ref_arg
,
verify
::
tolerance
tols
)
{
{
bool
passed
=
true
;
bool
passed
=
true
;
visit_all
(
ref_arg
,
target_arg
)([
&
](
auto
ref
,
auto
target
)
{
visit_all
(
ref_arg
.
data
(),
target_arg
)([
&
](
auto
ref
,
auto
target
)
{
double
error
;
double
rms_error
;
passed
=
verify
::
verify_range
(
ref
,
target
,
tolerance
,
&
error
);
passed
=
verify
::
verify_range_with_tolerance
(
target
,
verify
::
expected
{
ref
},
tols
,
&
rms_error
);
if
(
not
passed
)
if
(
not
passed
)
{
{
// TODO: Check for nans
// TODO: Check for nans
std
::
cout
<<
"FAILED: "
<<
name
<<
std
::
endl
;
std
::
cout
<<
"FAILED: "
<<
name
<<
std
::
endl
;
std
::
cout
<<
"
e
rror: "
<<
error
<<
std
::
endl
;
std
::
cout
<<
"
RMS E
rror: "
<<
rms_
error
<<
std
::
endl
;
if
(
ref
.
size
()
<
32
)
if
(
ref
.
size
()
<
32
)
std
::
cout
<<
"ref:"
<<
ref
<<
std
::
endl
;
std
::
cout
<<
"ref:"
<<
ref
<<
std
::
endl
;
if
(
target
.
size
()
<
32
)
if
(
target
.
size
()
<
32
)
...
@@ -78,16 +79,6 @@ bool verify_args(const std::string& name,
...
@@ -78,16 +79,6 @@ bool verify_args(const std::string& name,
if
(
verify
::
range_zero
(
target
))
if
(
verify
::
range_zero
(
target
))
std
::
cout
<<
"Target data is all zeros"
<<
std
::
endl
;
std
::
cout
<<
"Target data is all zeros"
<<
std
::
endl
;
// auto mxdiff = max_diff(ref, target);
// std::cout << "Max diff: " << mxdiff << std::endl;
// auto idx = mismatch_idx(ref, target, float_equal);
// if(idx < verify::range_distance(ref))
// {
// std::cout << "Mismatch at " << idx << ": " << ref[idx] << " != " << target[idx]
// << std::endl;
// }
auto
ref_nan_idx
=
find_idx
(
ref
,
verify
::
not_finite
);
auto
ref_nan_idx
=
find_idx
(
ref
,
verify
::
not_finite
);
if
(
ref_nan_idx
>=
0
)
if
(
ref_nan_idx
>=
0
)
std
::
cout
<<
"Non finite number found in ref at "
<<
ref_nan_idx
<<
": "
std
::
cout
<<
"Non finite number found in ref at "
<<
ref_nan_idx
<<
": "
...
@@ -97,11 +88,22 @@ bool verify_args(const std::string& name,
...
@@ -97,11 +88,22 @@ bool verify_args(const std::string& name,
if
(
target_nan_idx
>=
0
)
if
(
target_nan_idx
>=
0
)
std
::
cout
<<
"Non finite number found in target at "
<<
target_nan_idx
<<
": "
std
::
cout
<<
"Non finite number found in target at "
<<
target_nan_idx
<<
": "
<<
target
[
target_nan_idx
]
<<
std
::
endl
;
<<
target
[
target_nan_idx
]
<<
std
::
endl
;
//
std::cout << std::endl;
std
::
cout
<<
"MIGraphX verification passed successfully."
<<
std
::
endl
;
}
}
});
});
return
passed
;
return
passed
;
}
}
bool
verify_args_with_tolerance
(
const
std
::
string
&
name
,
const
argument
&
target_arg
,
const
verify
::
expected
<
argument
>&
ref_arg
,
std
::
size_t
tolerance
)
{
double
rms_tol
=
0.001
;
target_arg
.
visit
([
&
](
auto
ta
)
{
rms_tol
=
verify
::
get_rms_tol
(
ta
,
tolerance
);
});
verify
::
tolerance
tols
{
rms_tol
};
return
verify_args
(
name
,
target_arg
,
ref_arg
,
tols
);
}
}
// namespace MIGRAPHX_INLINE_NS
}
// namespace MIGRAPHX_INLINE_NS
}
// namespace migraphx
}
// namespace migraphx
test/gpu/codegen_literal.cpp
View file @
29820def
...
@@ -80,7 +80,7 @@ TEST_CASE(mul_literal_round_test)
...
@@ -80,7 +80,7 @@ TEST_CASE(mul_literal_round_test)
migraphx
::
target
gpu_t
=
migraphx
::
make_target
(
"gpu"
);
migraphx
::
target
gpu_t
=
migraphx
::
make_target
(
"gpu"
);
run_prog
(
p
,
gpu_t
,
m
,
gpu_result
);
run_prog
(
p
,
gpu_t
,
m
,
gpu_result
);
EXPECT
(
migraphx
::
verify
::
verify_range
(
ref
_result
,
gpu
_result
));
EXPECT
(
migraphx
::
verify
::
verify_
rms_
range
(
gpu
_result
,
ref
_result
));
}
}
int
main
(
int
argc
,
const
char
*
argv
[])
{
test
::
run
(
argc
,
argv
);
}
int
main
(
int
argc
,
const
char
*
argv
[])
{
test
::
run
(
argc
,
argv
);
}
test/gpu/manage_host_buffer.cpp
View file @
29820def
...
@@ -53,7 +53,6 @@ TEST_CASE(host_same_buffer_copy)
...
@@ -53,7 +53,6 @@ TEST_CASE(host_same_buffer_copy)
migraphx
::
parameter_map
pp
;
migraphx
::
parameter_map
pp
;
std
::
vector
<
float
>
a_vec
(
ss
.
elements
(),
-
1
);
std
::
vector
<
float
>
a_vec
(
ss
.
elements
(),
-
1
);
std
::
vector
<
float
>
b_vec
(
ss
.
elements
(),
2
);
std
::
vector
<
float
>
b_vec
(
ss
.
elements
(),
2
);
std
::
vector
<
float
>
c_vec
(
ss
.
elements
(),
0
);
pp
[
"a"
]
=
migraphx
::
argument
(
ss
,
a_vec
.
data
());
pp
[
"a"
]
=
migraphx
::
argument
(
ss
,
a_vec
.
data
());
pp
[
"b"
]
=
migraphx
::
argument
(
ss
,
b_vec
.
data
());
pp
[
"b"
]
=
migraphx
::
argument
(
ss
,
b_vec
.
data
());
std
::
vector
<
float
>
gpu_result
;
std
::
vector
<
float
>
gpu_result
;
...
@@ -64,7 +63,8 @@ TEST_CASE(host_same_buffer_copy)
...
@@ -64,7 +63,8 @@ TEST_CASE(host_same_buffer_copy)
auto
result
=
p
.
eval
(
pp
).
back
();
auto
result
=
p
.
eval
(
pp
).
back
();
std
::
vector
<
float
>
results_vector
(
ss
.
elements
(),
-
1
);
std
::
vector
<
float
>
results_vector
(
ss
.
elements
(),
-
1
);
result
.
visit
([
&
](
auto
output
)
{
results_vector
.
assign
(
output
.
begin
(),
output
.
end
());
});
result
.
visit
([
&
](
auto
output
)
{
results_vector
.
assign
(
output
.
begin
(),
output
.
end
());
});
EXPECT
(
migraphx
::
verify
::
verify_range
(
c_vec
,
results_vector
));
std
::
vector
<
float
>
gold_vec
(
ss
.
elements
(),
0
);
EXPECT
(
migraphx
::
verify
::
verify_rms_range
(
results_vector
,
gold_vec
));
}
}
TEST_CASE
(
arguments_lifetime
)
TEST_CASE
(
arguments_lifetime
)
...
...
test/gpu/mlir.cpp
View file @
29820def
...
@@ -133,7 +133,8 @@ bool verify_mlir(const migraphx::module& mmlir)
...
@@ -133,7 +133,8 @@ bool verify_mlir(const migraphx::module& mmlir)
auto
inputs
=
generate_params
(
ref
);
auto
inputs
=
generate_params
(
ref
);
auto
mlir
=
create_program_from_mlir
(
mmlir
);
auto
mlir
=
create_program_from_mlir
(
mmlir
);
return
migraphx
::
verify_args
(
"mlir"
,
run_ref
(
ref
,
inputs
),
run_gpu
(
mlir
,
inputs
));
return
migraphx
::
verify_args_with_tolerance
(
"mlir"
,
run_gpu
(
mlir
,
inputs
),
migraphx
::
verify
::
expected
{
run_ref
(
ref
,
inputs
)});
}
}
TEST_CASE
(
conv
)
TEST_CASE
(
conv
)
...
...
test/gpu/quantization.cpp
View file @
29820def
...
@@ -40,7 +40,6 @@
...
@@ -40,7 +40,6 @@
TEST_CASE
(
gpu_target_copy
)
TEST_CASE
(
gpu_target_copy
)
{
{
migraphx
::
target
gpu_t
=
migraphx
::
make_target
(
"gpu"
);
migraphx
::
target
gpu_t
=
migraphx
::
make_target
(
"gpu"
);
migraphx
::
target
ref_t
=
migraphx
::
make_target
(
"ref"
);
migraphx
::
shape
s
{
migraphx
::
shape
::
int8_type
,
{
2
,
3
,
4
,
5
}};
migraphx
::
shape
s
{
migraphx
::
shape
::
int8_type
,
{
2
,
3
,
4
,
5
}};
auto
ref_arg_orig
=
migraphx
::
generate_argument
(
s
,
0x123456L
);
auto
ref_arg_orig
=
migraphx
::
generate_argument
(
s
,
0x123456L
);
...
@@ -52,7 +51,7 @@ TEST_CASE(gpu_target_copy)
...
@@ -52,7 +51,7 @@ TEST_CASE(gpu_target_copy)
std
::
vector
<
int8_t
>
val_final
;
std
::
vector
<
int8_t
>
val_final
;
ref_arg_final
.
visit
([
&
](
auto
v
)
{
val_final
.
assign
(
v
.
begin
(),
v
.
end
());
});
ref_arg_final
.
visit
([
&
](
auto
v
)
{
val_final
.
assign
(
v
.
begin
(),
v
.
end
());
});
EXPECT
(
migraphx
::
verify
::
verify_range
(
val_orig
,
val_final
));
EXPECT
(
migraphx
::
verify
::
verify_
rms_
range
(
val_orig
,
val_final
));
}
}
TEST_CASE
(
int8_quantization
)
TEST_CASE
(
int8_quantization
)
...
@@ -118,9 +117,12 @@ TEST_CASE(int8_quantization)
...
@@ -118,9 +117,12 @@ TEST_CASE(int8_quantization)
// the regular pipeline uses the rewrite_quantization in the much
// the regular pipeline uses the rewrite_quantization in the much
// earlier stage.
// earlier stage.
if
(
migraphx
::
gpu
::
mlir_enabled
())
if
(
migraphx
::
gpu
::
mlir_enabled
())
EXPECT
(
migraphx
::
verify
::
verify_range
(
ref_result
,
gpu_result
,
1e5
));
EXPECT
(
migraphx
::
verify
::
verify_range_with_tolerance
(
gpu_result
,
migraphx
::
verify
::
expected
{
ref_result
},
migraphx
::
verify
::
tolerance
{
0.01
}));
else
else
EXPECT
(
migraphx
::
verify
::
verify_range
(
ref
_result
,
gpu
_result
));
EXPECT
(
migraphx
::
verify
::
verify_
rms_
range
(
gpu
_result
,
ref
_result
));
}
}
}
}
...
...
test/onnx/.onnxrt-commit
View file @
29820def
ae74a517b62baa6d973e46b5b51ac9a640512c46
377f959c69e9f213cd4a8c71a5e80162a412989a
test/onnx/verify_onnx.cpp
View file @
29820def
This diff is collapsed.
Click to expand it.
test/op_shape_test.cpp
View file @
29820def
...
@@ -890,6 +890,50 @@ TEST_CASE(flatten_dyn_axis4)
...
@@ -890,6 +890,50 @@ TEST_CASE(flatten_dyn_axis4)
input
);
input
);
}
}
TEST_CASE
(
fill_static_int
)
{
migraphx
::
shape
default_value
{
migraphx
::
shape
::
int64_type
,
{
1
},
{
0
}};
migraphx
::
shape
data
{
migraphx
::
shape
::
int64_type
,
{
3
,
4
,
4
}};
expect_shape
(
migraphx
::
shape
{
migraphx
::
shape
::
int64_type
,
{
3
,
4
,
4
}},
migraphx
::
make_op
(
"fill"
),
default_value
,
data
);
}
TEST_CASE
(
fill_static_float
)
{
migraphx
::
shape
default_value
{
migraphx
::
shape
::
float_type
,
{
1
},
{
0
}};
migraphx
::
shape
data
{
migraphx
::
shape
::
float_type
,
{
4
,
8
}};
expect_shape
(
migraphx
::
shape
{
migraphx
::
shape
::
float_type
,
{
4
,
8
}},
migraphx
::
make_op
(
"fill"
),
default_value
,
data
);
}
TEST_CASE
(
fill_dyn_int
)
{
migraphx
::
shape
default_value
{
migraphx
::
shape
::
int64_type
,
{
1
},
{
0
}};
migraphx
::
shape
data
{
migraphx
::
shape
::
int64_type
,
{{
1
,
4
},
{
4
,
8
,
{
4
,
6
,
8
}},
{
4
,
8
,
{
4
,
6
,
8
}}}};
expect_shape
(
migraphx
::
shape
{
migraphx
::
shape
::
int64_type
,
{{
1
,
4
},
{
4
,
8
,
{
4
,
6
,
8
}},
{
4
,
8
,
{
4
,
6
,
8
}}}},
migraphx
::
make_op
(
"fill"
),
default_value
,
data
);
}
TEST_CASE
(
fill_dyn_float
)
{
migraphx
::
shape
default_value
{
migraphx
::
shape
::
float_type
,
{
1
},
{
0
}};
migraphx
::
shape
data
{
migraphx
::
shape
::
float_type
,
{{
1
,
4
},
{
4
,
8
,
{
4
,
6
,
8
}},
{
4
,
8
,
{
4
,
6
,
8
}}}};
expect_shape
(
migraphx
::
shape
{
migraphx
::
shape
::
float_type
,
{{
1
,
4
},
{
4
,
8
,
{
4
,
6
,
8
}},
{
4
,
8
,
{
4
,
6
,
8
}}}},
migraphx
::
make_op
(
"fill"
),
default_value
,
data
);
}
TEST_CASE
(
gather
)
TEST_CASE
(
gather
)
{
{
{
{
...
...
test/quantization.cpp
View file @
29820def
...
@@ -1013,7 +1013,7 @@ TEST_CASE(target_copy)
...
@@ -1013,7 +1013,7 @@ TEST_CASE(target_copy)
std
::
vector
<
float
>
orig_result
;
std
::
vector
<
float
>
orig_result
;
run_prog
(
p
,
ref_t
,
m
,
orig_result
);
run_prog
(
p
,
ref_t
,
m
,
orig_result
);
EXPECT
(
migraphx
::
verify
::
verify_range
(
ref_result
,
orig_result
));
EXPECT
(
migraphx
::
verify
::
verify_
rms_
range
(
ref_result
,
orig_result
));
}
}
}
}
...
@@ -1077,7 +1077,10 @@ TEST_CASE(int8_quantization_dot)
...
@@ -1077,7 +1077,10 @@ TEST_CASE(int8_quantization_dot)
std
::
vector
<
float
>
no_quant_result
;
std
::
vector
<
float
>
no_quant_result
;
run_prog
(
p
,
ref_t
,
m
,
no_quant_result
);
run_prog
(
p
,
ref_t
,
m
,
no_quant_result
);
EXPECT
(
migraphx
::
verify
::
verify_range
(
quant_result
,
no_quant_result
,
30000
));
EXPECT
(
migraphx
::
verify
::
verify_range_with_tolerance
(
quant_result
,
migraphx
::
verify
::
expected
{
no_quant_result
},
migraphx
::
verify
::
tolerance
{
0.003
}));
}
}
}
}
...
@@ -1122,7 +1125,7 @@ TEST_CASE(int8_quantization_conv)
...
@@ -1122,7 +1125,7 @@ TEST_CASE(int8_quantization_conv)
std
::
vector
<
float
>
no_quant_result
;
std
::
vector
<
float
>
no_quant_result
;
run_prog
(
p
,
ref_t
,
no_quant_result
);
run_prog
(
p
,
ref_t
,
no_quant_result
);
EXPECT
(
migraphx
::
verify
::
verify_range
(
quant_result
,
no_quant_result
));
EXPECT
(
migraphx
::
verify
::
verify_
rms_
range
(
quant_result
,
no_quant_result
));
}
}
}
}
...
@@ -1274,7 +1277,7 @@ TEST_CASE(test_op_capture)
...
@@ -1274,7 +1277,7 @@ TEST_CASE(test_op_capture)
cap_res
.
visit
([
&
](
auto
output
)
{
cap_vec
.
assign
(
output
.
begin
(),
output
.
end
());
});
cap_res
.
visit
([
&
](
auto
output
)
{
cap_vec
.
assign
(
output
.
begin
(),
output
.
end
());
});
res
.
visit
([
&
](
auto
output
)
{
vec
.
assign
(
output
.
begin
(),
output
.
end
());
});
res
.
visit
([
&
](
auto
output
)
{
vec
.
assign
(
output
.
begin
(),
output
.
end
());
});
EXPECT
(
migraphx
::
verify
::
verify_range
(
vec
,
cap_vec
));
EXPECT
(
migraphx
::
verify
::
verify_
rms_
range
(
vec
,
cap_vec
));
}
}
int
main
(
int
argc
,
const
char
*
argv
[])
{
test
::
run
(
argc
,
argv
);
}
int
main
(
int
argc
,
const
char
*
argv
[])
{
test
::
run
(
argc
,
argv
);
}
test/ref/abs.cpp
View file @
29820def
...
@@ -24,7 +24,7 @@
...
@@ -24,7 +24,7 @@
#include <migraphx/instruction.hpp>
#include <migraphx/instruction.hpp>
#include <migraphx/literal.hpp>
#include <migraphx/literal.hpp>
#include <migraphx/make_op.hpp>
#include <migraphx/make_op.hpp>
#include <migraphx/
onnx
.hpp>
#include <migraphx/
program
.hpp>
#include <migraphx/register_target.hpp>
#include <migraphx/register_target.hpp>
#include <migraphx/verify.hpp>
#include <migraphx/verify.hpp>
...
@@ -42,7 +42,7 @@ TEST_CASE(abs_test)
...
@@ -42,7 +42,7 @@ TEST_CASE(abs_test)
std
::
vector
<
float
>
results_vector
(
4
);
std
::
vector
<
float
>
results_vector
(
4
);
result
.
visit
([
&
](
auto
output
)
{
results_vector
.
assign
(
output
.
begin
(),
output
.
end
());
});
result
.
visit
([
&
](
auto
output
)
{
results_vector
.
assign
(
output
.
begin
(),
output
.
end
());
});
std
::
vector
<
float
>
gold
{
1
,
2
,
3
,
4
};
std
::
vector
<
float
>
gold
{
1
,
2
,
3
,
4
};
EXPECT
(
migraphx
::
verify
::
verify_range
(
results_vector
,
gold
));
EXPECT
(
migraphx
::
verify
::
verify_
rms_
range
(
results_vector
,
gold
));
}
}
TEST_CASE
(
abs_dyn_test
)
TEST_CASE
(
abs_dyn_test
)
...
@@ -62,5 +62,5 @@ TEST_CASE(abs_dyn_test)
...
@@ -62,5 +62,5 @@ TEST_CASE(abs_dyn_test)
std
::
vector
<
float
>
results_vector
(
4
);
std
::
vector
<
float
>
results_vector
(
4
);
result
.
visit
([
&
](
auto
output
)
{
results_vector
.
assign
(
output
.
begin
(),
output
.
end
());
});
result
.
visit
([
&
](
auto
output
)
{
results_vector
.
assign
(
output
.
begin
(),
output
.
end
());
});
std
::
vector
<
float
>
gold
{
1
,
2
,
3
,
4
};
std
::
vector
<
float
>
gold
{
1
,
2
,
3
,
4
};
EXPECT
(
migraphx
::
verify
::
verify_range
(
results_vector
,
gold
));
EXPECT
(
migraphx
::
verify
::
verify_
rms_
range
(
results_vector
,
gold
));
}
}
test/ref/acos.cpp
View file @
29820def
...
@@ -24,7 +24,7 @@
...
@@ -24,7 +24,7 @@
#include <migraphx/instruction.hpp>
#include <migraphx/instruction.hpp>
#include <migraphx/literal.hpp>
#include <migraphx/literal.hpp>
#include <migraphx/make_op.hpp>
#include <migraphx/make_op.hpp>
#include <migraphx/
onnx
.hpp>
#include <migraphx/
program
.hpp>
#include <migraphx/register_target.hpp>
#include <migraphx/register_target.hpp>
#include <migraphx/verify.hpp>
#include <migraphx/verify.hpp>
...
@@ -45,7 +45,7 @@ TEST_CASE(acos_test)
...
@@ -45,7 +45,7 @@ TEST_CASE(acos_test)
std
::
vector
<
float
>
gold
=
data
;
std
::
vector
<
float
>
gold
=
data
;
std
::
transform
(
std
::
transform
(
gold
.
begin
(),
gold
.
end
(),
gold
.
begin
(),
[](
float
n
)
->
float
{
return
acosf
(
n
);
});
gold
.
begin
(),
gold
.
end
(),
gold
.
begin
(),
[](
float
n
)
->
float
{
return
acosf
(
n
);
});
EXPECT
(
migraphx
::
verify
::
verify_range
(
results_vector
,
gold
));
EXPECT
(
migraphx
::
verify
::
verify_
rms_
range
(
results_vector
,
gold
));
}
}
TEST_CASE
(
acos_dyn_test
)
TEST_CASE
(
acos_dyn_test
)
...
@@ -68,5 +68,5 @@ TEST_CASE(acos_dyn_test)
...
@@ -68,5 +68,5 @@ TEST_CASE(acos_dyn_test)
std
::
vector
<
float
>
gold
=
input_data
;
std
::
vector
<
float
>
gold
=
input_data
;
std
::
transform
(
std
::
transform
(
gold
.
begin
(),
gold
.
end
(),
gold
.
begin
(),
[](
float
n
)
->
float
{
return
acosf
(
n
);
});
gold
.
begin
(),
gold
.
end
(),
gold
.
begin
(),
[](
float
n
)
->
float
{
return
acosf
(
n
);
});
EXPECT
(
migraphx
::
verify
::
verify_range
(
results_vector
,
gold
));
EXPECT
(
migraphx
::
verify
::
verify_
rms_
range
(
results_vector
,
gold
));
}
}
test/ref/acosh.cpp
View file @
29820def
...
@@ -24,7 +24,7 @@
...
@@ -24,7 +24,7 @@
#include <migraphx/instruction.hpp>
#include <migraphx/instruction.hpp>
#include <migraphx/literal.hpp>
#include <migraphx/literal.hpp>
#include <migraphx/make_op.hpp>
#include <migraphx/make_op.hpp>
#include <migraphx/
onnx
.hpp>
#include <migraphx/
program
.hpp>
#include <migraphx/register_target.hpp>
#include <migraphx/register_target.hpp>
#include <migraphx/verify.hpp>
#include <migraphx/verify.hpp>
...
@@ -45,7 +45,7 @@ TEST_CASE(acosh_test)
...
@@ -45,7 +45,7 @@ TEST_CASE(acosh_test)
std
::
vector
<
float
>
gold
=
data
;
std
::
vector
<
float
>
gold
=
data
;
std
::
transform
(
std
::
transform
(
gold
.
begin
(),
gold
.
end
(),
gold
.
begin
(),
[](
float
n
)
->
float
{
return
acoshf
(
n
);
});
gold
.
begin
(),
gold
.
end
(),
gold
.
begin
(),
[](
float
n
)
->
float
{
return
acoshf
(
n
);
});
EXPECT
(
migraphx
::
verify
::
verify_range
(
results_vector
,
gold
));
EXPECT
(
migraphx
::
verify
::
verify_
rms_
range
(
results_vector
,
gold
));
}
}
TEST_CASE
(
acosh_dyn_test
)
TEST_CASE
(
acosh_dyn_test
)
...
@@ -68,5 +68,5 @@ TEST_CASE(acosh_dyn_test)
...
@@ -68,5 +68,5 @@ TEST_CASE(acosh_dyn_test)
std
::
vector
<
float
>
gold
=
input_data
;
std
::
vector
<
float
>
gold
=
input_data
;
std
::
transform
(
std
::
transform
(
gold
.
begin
(),
gold
.
end
(),
gold
.
begin
(),
[](
float
n
)
->
float
{
return
acoshf
(
n
);
});
gold
.
begin
(),
gold
.
end
(),
gold
.
begin
(),
[](
float
n
)
->
float
{
return
acoshf
(
n
);
});
EXPECT
(
migraphx
::
verify
::
verify_range
(
results_vector
,
gold
));
EXPECT
(
migraphx
::
verify
::
verify_
rms_
range
(
results_vector
,
gold
));
}
}
test/ref/add.cpp
View file @
29820def
...
@@ -24,7 +24,7 @@
...
@@ -24,7 +24,7 @@
#include <migraphx/instruction.hpp>
#include <migraphx/instruction.hpp>
#include <migraphx/literal.hpp>
#include <migraphx/literal.hpp>
#include <migraphx/make_op.hpp>
#include <migraphx/make_op.hpp>
#include <migraphx/
onnx
.hpp>
#include <migraphx/
program
.hpp>
#include <migraphx/quantization.hpp>
#include <migraphx/quantization.hpp>
#include <migraphx/register_target.hpp>
#include <migraphx/register_target.hpp>
#include <migraphx/verify.hpp>
#include <migraphx/verify.hpp>
...
@@ -51,7 +51,7 @@ TEST_CASE(add_broadcast_test)
...
@@ -51,7 +51,7 @@ TEST_CASE(add_broadcast_test)
std
::
vector
<
float
>
results_vector
(
12
);
std
::
vector
<
float
>
results_vector
(
12
);
result
.
visit
([
&
](
auto
output
)
{
results_vector
.
assign
(
output
.
begin
(),
output
.
end
());
});
result
.
visit
([
&
](
auto
output
)
{
results_vector
.
assign
(
output
.
begin
(),
output
.
end
());
});
std
::
vector
<
float
>
gold
=
{
0
,
1
,
2
,
2
,
3
,
4
,
4
,
5
,
6
,
6
,
7
,
8
};
std
::
vector
<
float
>
gold
=
{
0
,
1
,
2
,
2
,
3
,
4
,
4
,
5
,
6
,
6
,
7
,
8
};
EXPECT
(
migraphx
::
verify
::
verify_range
(
results_vector
,
gold
));
EXPECT
(
migraphx
::
verify
::
verify_
rms_
range
(
results_vector
,
gold
));
}
}
TEST_CASE
(
add_multibroadcast_test
)
TEST_CASE
(
add_multibroadcast_test
)
...
@@ -75,7 +75,7 @@ TEST_CASE(add_multibroadcast_test)
...
@@ -75,7 +75,7 @@ TEST_CASE(add_multibroadcast_test)
std
::
vector
<
float
>
results_vector
(
12
);
std
::
vector
<
float
>
results_vector
(
12
);
result
.
visit
([
&
](
auto
output
)
{
results_vector
.
assign
(
output
.
begin
(),
output
.
end
());
});
result
.
visit
([
&
](
auto
output
)
{
results_vector
.
assign
(
output
.
begin
(),
output
.
end
());
});
std
::
vector
<
float
>
gold
=
{
0
,
1
,
2
,
2
,
3
,
4
,
4
,
5
,
6
,
6
,
7
,
8
};
std
::
vector
<
float
>
gold
=
{
0
,
1
,
2
,
2
,
3
,
4
,
4
,
5
,
6
,
6
,
7
,
8
};
EXPECT
(
migraphx
::
verify
::
verify_range
(
results_vector
,
gold
));
EXPECT
(
migraphx
::
verify
::
verify_
rms_
range
(
results_vector
,
gold
));
}
}
TEST_CASE
(
add_test
)
TEST_CASE
(
add_test
)
...
@@ -91,7 +91,7 @@ TEST_CASE(add_test)
...
@@ -91,7 +91,7 @@ TEST_CASE(add_test)
std
::
vector
<
float
>
results_vector
(
3
);
std
::
vector
<
float
>
results_vector
(
3
);
result
.
visit
([
&
](
auto
output
)
{
results_vector
.
assign
(
output
.
begin
(),
output
.
end
());
});
result
.
visit
([
&
](
auto
output
)
{
results_vector
.
assign
(
output
.
begin
(),
output
.
end
());
});
std
::
vector
<
float
>
gold
=
{
0
,
2
,
4
};
std
::
vector
<
float
>
gold
=
{
0
,
2
,
4
};
EXPECT
(
migraphx
::
verify
::
verify_range
(
results_vector
,
gold
));
EXPECT
(
migraphx
::
verify
::
verify_
rms_
range
(
results_vector
,
gold
));
}
}
TEST_CASE
(
add_dyn_test
)
TEST_CASE
(
add_dyn_test
)
...
@@ -115,7 +115,7 @@ TEST_CASE(add_dyn_test)
...
@@ -115,7 +115,7 @@ TEST_CASE(add_dyn_test)
std
::
vector
<
float
>
results_vector
(
3
);
std
::
vector
<
float
>
results_vector
(
3
);
result
.
visit
([
&
](
auto
output
)
{
results_vector
.
assign
(
output
.
begin
(),
output
.
end
());
});
result
.
visit
([
&
](
auto
output
)
{
results_vector
.
assign
(
output
.
begin
(),
output
.
end
());
});
std
::
vector
<
float
>
gold
=
{
0
,
2
,
4
};
std
::
vector
<
float
>
gold
=
{
0
,
2
,
4
};
EXPECT
(
migraphx
::
verify
::
verify_range
(
results_vector
,
gold
));
EXPECT
(
migraphx
::
verify
::
verify_
rms_
range
(
results_vector
,
gold
));
}
}
TEST_CASE
(
fp16_test
)
TEST_CASE
(
fp16_test
)
...
@@ -134,7 +134,7 @@ TEST_CASE(fp16_test)
...
@@ -134,7 +134,7 @@ TEST_CASE(fp16_test)
std
::
vector
<
migraphx
::
half
>
results_vector
(
1
);
std
::
vector
<
migraphx
::
half
>
results_vector
(
1
);
result
.
visit
([
&
](
auto
output
)
{
results_vector
.
assign
(
output
.
begin
(),
output
.
end
());
});
result
.
visit
([
&
](
auto
output
)
{
results_vector
.
assign
(
output
.
begin
(),
output
.
end
());
});
std
::
vector
<
migraphx
::
half
>
gold
{
c
};
std
::
vector
<
migraphx
::
half
>
gold
{
c
};
EXPECT
(
migraphx
::
verify
::
verify_range
(
results_vector
,
gold
));
EXPECT
(
migraphx
::
verify
::
verify_
rms_
range
(
results_vector
,
gold
));
}
}
TEST_CASE
(
fp32_fp16_test
)
TEST_CASE
(
fp32_fp16_test
)
...
@@ -159,7 +159,7 @@ TEST_CASE(fp32_fp16_test)
...
@@ -159,7 +159,7 @@ TEST_CASE(fp32_fp16_test)
auto
result
=
p
.
eval
({}).
back
();
auto
result
=
p
.
eval
({}).
back
();
std
::
vector
<
float
>
res
;
std
::
vector
<
float
>
res
;
result
.
visit
([
&
](
auto
output
)
{
res
.
assign
(
output
.
begin
(),
output
.
end
());
});
result
.
visit
([
&
](
auto
output
)
{
res
.
assign
(
output
.
begin
(),
output
.
end
());
});
EXPECT
(
migraphx
::
verify
::
verify_range
(
res
,
gold_res
));
EXPECT
(
migraphx
::
verify
::
verify_
rms_
range
(
res
,
gold_res
));
};
};
test_case
({
"all"
});
test_case
({
"all"
});
...
...
test/ref/allocate.cpp
View file @
29820def
...
@@ -24,7 +24,7 @@
...
@@ -24,7 +24,7 @@
#include <migraphx/instruction.hpp>
#include <migraphx/instruction.hpp>
#include <migraphx/literal.hpp>
#include <migraphx/literal.hpp>
#include <migraphx/make_op.hpp>
#include <migraphx/make_op.hpp>
#include <migraphx/
onnx
.hpp>
#include <migraphx/
program
.hpp>
#include <migraphx/register_target.hpp>
#include <migraphx/register_target.hpp>
#include <migraphx/verify.hpp>
#include <migraphx/verify.hpp>
...
...
test/ref/argmax.cpp
View file @
29820def
...
@@ -25,7 +25,7 @@
...
@@ -25,7 +25,7 @@
#include <migraphx/instruction.hpp>
#include <migraphx/instruction.hpp>
#include <migraphx/literal.hpp>
#include <migraphx/literal.hpp>
#include <migraphx/make_op.hpp>
#include <migraphx/make_op.hpp>
#include <migraphx/
onnx
.hpp>
#include <migraphx/
program
.hpp>
#include <migraphx/register_target.hpp>
#include <migraphx/register_target.hpp>
#include <migraphx/verify.hpp>
#include <migraphx/verify.hpp>
...
@@ -47,7 +47,7 @@ TEST_CASE(argmax_test_0)
...
@@ -47,7 +47,7 @@ TEST_CASE(argmax_test_0)
std
::
vector
<
int64_t
>
result_vec
;
std
::
vector
<
int64_t
>
result_vec
;
result
.
visit
([
&
](
auto
output
)
{
result_vec
.
assign
(
output
.
begin
(),
output
.
end
());
});
result
.
visit
([
&
](
auto
output
)
{
result_vec
.
assign
(
output
.
begin
(),
output
.
end
());
});
EXPECT
(
migraphx
::
verify
::
verify_range
(
result_vec
,
res_gold
));
EXPECT
(
migraphx
::
verify
::
verify_
rms_
range
(
result_vec
,
res_gold
));
}
}
TEST_CASE
(
argmax_test_1
)
TEST_CASE
(
argmax_test_1
)
...
@@ -66,7 +66,7 @@ TEST_CASE(argmax_test_1)
...
@@ -66,7 +66,7 @@ TEST_CASE(argmax_test_1)
std
::
vector
<
int64_t
>
result_vec
;
std
::
vector
<
int64_t
>
result_vec
;
result
.
visit
([
&
](
auto
output
)
{
result_vec
.
assign
(
output
.
begin
(),
output
.
end
());
});
result
.
visit
([
&
](
auto
output
)
{
result_vec
.
assign
(
output
.
begin
(),
output
.
end
());
});
EXPECT
(
migraphx
::
verify
::
verify_range
(
result_vec
,
res_gold
));
EXPECT
(
migraphx
::
verify
::
verify_
rms_
range
(
result_vec
,
res_gold
));
}
}
TEST_CASE
(
argmax_test_2
)
TEST_CASE
(
argmax_test_2
)
...
@@ -85,7 +85,7 @@ TEST_CASE(argmax_test_2)
...
@@ -85,7 +85,7 @@ TEST_CASE(argmax_test_2)
std
::
vector
<
int64_t
>
result_vec
;
std
::
vector
<
int64_t
>
result_vec
;
result
.
visit
([
&
](
auto
output
)
{
result_vec
.
assign
(
output
.
begin
(),
output
.
end
());
});
result
.
visit
([
&
](
auto
output
)
{
result_vec
.
assign
(
output
.
begin
(),
output
.
end
());
});
EXPECT
(
migraphx
::
verify
::
verify_range
(
result_vec
,
res_gold
));
EXPECT
(
migraphx
::
verify
::
verify_
rms_
range
(
result_vec
,
res_gold
));
}
}
TEST_CASE
(
argmax_test_neg_2
)
TEST_CASE
(
argmax_test_neg_2
)
...
@@ -104,7 +104,7 @@ TEST_CASE(argmax_test_neg_2)
...
@@ -104,7 +104,7 @@ TEST_CASE(argmax_test_neg_2)
std
::
vector
<
int64_t
>
result_vec
;
std
::
vector
<
int64_t
>
result_vec
;
result
.
visit
([
&
](
auto
output
)
{
result_vec
.
assign
(
output
.
begin
(),
output
.
end
());
});
result
.
visit
([
&
](
auto
output
)
{
result_vec
.
assign
(
output
.
begin
(),
output
.
end
());
});
EXPECT
(
migraphx
::
verify
::
verify_range
(
result_vec
,
res_gold
));
EXPECT
(
migraphx
::
verify
::
verify_
rms_
range
(
result_vec
,
res_gold
));
}
}
TEST_CASE
(
argmax_dyn_test
)
TEST_CASE
(
argmax_dyn_test
)
...
@@ -126,7 +126,7 @@ TEST_CASE(argmax_dyn_test)
...
@@ -126,7 +126,7 @@ TEST_CASE(argmax_dyn_test)
std
::
vector
<
int64_t
>
result_vec
;
std
::
vector
<
int64_t
>
result_vec
;
result
.
visit
([
&
](
auto
output
)
{
result_vec
.
assign
(
output
.
begin
(),
output
.
end
());
});
result
.
visit
([
&
](
auto
output
)
{
result_vec
.
assign
(
output
.
begin
(),
output
.
end
());
});
std
::
vector
<
int64_t
>
res_gold
=
{
0
,
0
,
1
,
0
,
1
,
0
,
0
,
0
,
1
,
1
,
0
,
1
};
std
::
vector
<
int64_t
>
res_gold
=
{
0
,
0
,
1
,
0
,
1
,
0
,
0
,
0
,
1
,
1
,
0
,
1
};
EXPECT
(
migraphx
::
verify
::
verify_range
(
result_vec
,
res_gold
));
EXPECT
(
migraphx
::
verify
::
verify_
rms_
range
(
result_vec
,
res_gold
));
}
}
TEST_CASE
(
argmax_test_nonstd_shape
)
TEST_CASE
(
argmax_test_nonstd_shape
)
...
@@ -145,5 +145,5 @@ TEST_CASE(argmax_test_nonstd_shape)
...
@@ -145,5 +145,5 @@ TEST_CASE(argmax_test_nonstd_shape)
result
.
visit
([
&
](
auto
output
)
{
result_vec
.
assign
(
output
.
begin
(),
output
.
end
());
});
result
.
visit
([
&
](
auto
output
)
{
result_vec
.
assign
(
output
.
begin
(),
output
.
end
());
});
std
::
vector
<
int64_t
>
res_gold_vec
;
std
::
vector
<
int64_t
>
res_gold_vec
;
res_gold
.
visit
([
&
](
auto
output
)
{
res_gold_vec
.
assign
(
output
.
begin
(),
output
.
end
());
});
res_gold
.
visit
([
&
](
auto
output
)
{
res_gold_vec
.
assign
(
output
.
begin
(),
output
.
end
());
});
EXPECT
(
migraphx
::
verify
::
verify_range
(
result_vec
,
res_gold_vec
));
EXPECT
(
migraphx
::
verify
::
verify_
rms_
range
(
result_vec
,
res_gold_vec
));
}
}
test/ref/argmin.cpp
View file @
29820def
...
@@ -25,7 +25,7 @@
...
@@ -25,7 +25,7 @@
#include <migraphx/instruction.hpp>
#include <migraphx/instruction.hpp>
#include <migraphx/literal.hpp>
#include <migraphx/literal.hpp>
#include <migraphx/make_op.hpp>
#include <migraphx/make_op.hpp>
#include <migraphx/
onnx
.hpp>
#include <migraphx/
program
.hpp>
#include <migraphx/register_target.hpp>
#include <migraphx/register_target.hpp>
#include <migraphx/verify.hpp>
#include <migraphx/verify.hpp>
...
@@ -47,7 +47,7 @@ TEST_CASE(argmin_test_0)
...
@@ -47,7 +47,7 @@ TEST_CASE(argmin_test_0)
std
::
vector
<
int64_t
>
result_vec
;
std
::
vector
<
int64_t
>
result_vec
;
result
.
visit
([
&
](
auto
output
)
{
result_vec
.
assign
(
output
.
begin
(),
output
.
end
());
});
result
.
visit
([
&
](
auto
output
)
{
result_vec
.
assign
(
output
.
begin
(),
output
.
end
());
});
EXPECT
(
migraphx
::
verify
::
verify_range
(
result_vec
,
res_gold
));
EXPECT
(
migraphx
::
verify
::
verify_
rms_
range
(
result_vec
,
res_gold
));
}
}
TEST_CASE
(
argmin_test_1
)
TEST_CASE
(
argmin_test_1
)
...
@@ -66,7 +66,7 @@ TEST_CASE(argmin_test_1)
...
@@ -66,7 +66,7 @@ TEST_CASE(argmin_test_1)
std
::
vector
<
int64_t
>
result_vec
;
std
::
vector
<
int64_t
>
result_vec
;
result
.
visit
([
&
](
auto
output
)
{
result_vec
.
assign
(
output
.
begin
(),
output
.
end
());
});
result
.
visit
([
&
](
auto
output
)
{
result_vec
.
assign
(
output
.
begin
(),
output
.
end
());
});
EXPECT
(
migraphx
::
verify
::
verify_range
(
result_vec
,
res_gold
));
EXPECT
(
migraphx
::
verify
::
verify_
rms_
range
(
result_vec
,
res_gold
));
}
}
TEST_CASE
(
argmin_test_2
)
TEST_CASE
(
argmin_test_2
)
...
@@ -85,7 +85,7 @@ TEST_CASE(argmin_test_2)
...
@@ -85,7 +85,7 @@ TEST_CASE(argmin_test_2)
std
::
vector
<
int64_t
>
result_vec
;
std
::
vector
<
int64_t
>
result_vec
;
result
.
visit
([
&
](
auto
output
)
{
result_vec
.
assign
(
output
.
begin
(),
output
.
end
());
});
result
.
visit
([
&
](
auto
output
)
{
result_vec
.
assign
(
output
.
begin
(),
output
.
end
());
});
EXPECT
(
migraphx
::
verify
::
verify_range
(
result_vec
,
res_gold
));
EXPECT
(
migraphx
::
verify
::
verify_
rms_
range
(
result_vec
,
res_gold
));
}
}
TEST_CASE
(
argmin_test_neg_1
)
TEST_CASE
(
argmin_test_neg_1
)
...
@@ -104,7 +104,7 @@ TEST_CASE(argmin_test_neg_1)
...
@@ -104,7 +104,7 @@ TEST_CASE(argmin_test_neg_1)
std
::
vector
<
int64_t
>
result_vec
;
std
::
vector
<
int64_t
>
result_vec
;
result
.
visit
([
&
](
auto
output
)
{
result_vec
.
assign
(
output
.
begin
(),
output
.
end
());
});
result
.
visit
([
&
](
auto
output
)
{
result_vec
.
assign
(
output
.
begin
(),
output
.
end
());
});
EXPECT
(
migraphx
::
verify
::
verify_range
(
result_vec
,
res_gold
));
EXPECT
(
migraphx
::
verify
::
verify_
rms_
range
(
result_vec
,
res_gold
));
}
}
TEST_CASE
(
argmin_test_nonstd_shape
)
TEST_CASE
(
argmin_test_nonstd_shape
)
...
@@ -123,5 +123,5 @@ TEST_CASE(argmin_test_nonstd_shape)
...
@@ -123,5 +123,5 @@ TEST_CASE(argmin_test_nonstd_shape)
result
.
visit
([
&
](
auto
output
)
{
result_vec
.
assign
(
output
.
begin
(),
output
.
end
());
});
result
.
visit
([
&
](
auto
output
)
{
result_vec
.
assign
(
output
.
begin
(),
output
.
end
());
});
std
::
vector
<
int64_t
>
res_gold_vec
;
std
::
vector
<
int64_t
>
res_gold_vec
;
res_gold
.
visit
([
&
](
auto
output
)
{
res_gold_vec
.
assign
(
output
.
begin
(),
output
.
end
());
});
res_gold
.
visit
([
&
](
auto
output
)
{
res_gold_vec
.
assign
(
output
.
begin
(),
output
.
end
());
});
EXPECT
(
migraphx
::
verify
::
verify_range
(
result_vec
,
res_gold_vec
));
EXPECT
(
migraphx
::
verify
::
verify_
rms_
range
(
result_vec
,
res_gold_vec
));
}
}
Prev
1
2
3
4
5
6
…
8
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment