Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
MIGraphX
Commits
f3a8933c
Commit
f3a8933c
authored
Nov 02, 2023
by
Paul
Browse files
Merge branch 'develop' into blas_tuning
parents
ca300bd6
b249fb8a
Changes
86
Hide whitespace changes
Inline
Side-by-side
Showing
6 changed files
with
146 additions
and
72 deletions
+146
-72
test/rewrite_quantization_test.cpp
test/rewrite_quantization_test.cpp
+8
-1
test/targets.cpp
test/targets.cpp
+0
-4
test/verify/quant_conv_1.cpp
test/verify/quant_conv_1.cpp
+1
-1
test/verify/quant_conv_2.cpp
test/verify/quant_conv_2.cpp
+1
-1
test/verify/test_arg_ops.cpp
test/verify/test_arg_ops.cpp
+100
-52
tools/accuracy/accuracy_checker.py
tools/accuracy/accuracy_checker.py
+36
-13
No files found.
test/rewrite_quantization_test.cpp
View file @
f3a8933c
...
@@ -31,10 +31,13 @@
...
@@ -31,10 +31,13 @@
#include <migraphx/ranges.hpp>
#include <migraphx/ranges.hpp>
#include <test.hpp>
#include <test.hpp>
#include <migraphx/make_op.hpp>
#include <migraphx/make_op.hpp>
#include <migraphx/env.hpp>
#include <migraphx/serialize.hpp>
#include <migraphx/serialize.hpp>
#include <migraphx/pass_manager.hpp>
#include <migraphx/pass_manager.hpp>
MIGRAPHX_DECLARE_ENV_VAR
(
MIGRAPHX_ENABLE_CK_WORKAROUNDS
);
bool
is_quantizelinear
(
migraphx
::
instruction
&
ins
)
{
return
ins
.
name
()
==
"quantizelinear"
;
}
bool
is_quantizelinear
(
migraphx
::
instruction
&
ins
)
{
return
ins
.
name
()
==
"quantizelinear"
;
}
bool
is_dequantizelinear
(
migraphx
::
instruction
&
ins
)
{
return
ins
.
name
()
==
"dequantizelinear"
;
}
bool
is_dequantizelinear
(
migraphx
::
instruction
&
ins
)
{
return
ins
.
name
()
==
"dequantizelinear"
;
}
bool
is_clip_scalar
(
migraphx
::
instruction
&
ins
)
bool
is_clip_scalar
(
migraphx
::
instruction
&
ins
)
...
@@ -82,7 +85,11 @@ TEST_CASE(quantizelinear)
...
@@ -82,7 +85,11 @@ TEST_CASE(quantizelinear)
EXPECT
(
any_of
(
*
p1
.
get_main_module
(),
&
is_quantizelinear
));
EXPECT
(
any_of
(
*
p1
.
get_main_module
(),
&
is_quantizelinear
));
EXPECT
(
none_of
(
*
p2
.
get_main_module
(),
&
is_quantizelinear
));
EXPECT
(
none_of
(
*
p2
.
get_main_module
(),
&
is_quantizelinear
));
// ensure clip literals created in quantized program are scalar
// ensure clip literals created in quantized program are scalar
EXPECT
(
any_of
(
*
p2
.
get_main_module
(),
&
is_clip_scalar
));
// unless CK workarounds are enabled
if
(
migraphx
::
enabled
(
MIGRAPHX_ENABLE_CK_WORKAROUNDS
{}))
EXPECT
(
none_of
(
*
p2
.
get_main_module
(),
&
is_clip_scalar
));
else
EXPECT
(
any_of
(
*
p2
.
get_main_module
(),
&
is_clip_scalar
));
}
}
TEST_CASE
(
dequantizelinear
)
TEST_CASE
(
dequantizelinear
)
...
...
test/targets.cpp
View file @
f3a8933c
...
@@ -41,11 +41,7 @@ TEST_CASE(make_invalid_target)
...
@@ -41,11 +41,7 @@ TEST_CASE(make_invalid_target)
TEST_CASE
(
targets
)
TEST_CASE
(
targets
)
{
{
// GCC doesn't load libmigraphx_ref unless necesssary even though it is linked to the test.
// Force it to load by making ref target
#if defined(__GNUC__) && !defined(__clang__)
auto
ref_target
=
migraphx
::
make_target
(
"ref"
);
auto
ref_target
=
migraphx
::
make_target
(
"ref"
);
#endif
auto
ts
=
migraphx
::
get_targets
();
auto
ts
=
migraphx
::
get_targets
();
EXPECT
(
ts
.
size
()
>=
1
);
EXPECT
(
ts
.
size
()
>=
1
);
}
}
...
...
test/verify/quant_conv_
default_mode
.cpp
→
test/verify/quant_conv_
1
.cpp
View file @
f3a8933c
...
@@ -27,7 +27,7 @@
...
@@ -27,7 +27,7 @@
#include <migraphx/generate.hpp>
#include <migraphx/generate.hpp>
#include <migraphx/op/quant_convolution.hpp>
#include <migraphx/op/quant_convolution.hpp>
struct
quant_conv_
default_mode
:
verify_program
<
quant_conv_
default_mode
>
struct
quant_conv_
1
:
verify_program
<
quant_conv_
1
>
{
{
migraphx
::
program
create_program
()
const
migraphx
::
program
create_program
()
const
{
{
...
...
test/verify/quant_conv_
int8x4_default
.cpp
→
test/verify/quant_conv_
2
.cpp
View file @
f3a8933c
...
@@ -27,7 +27,7 @@
...
@@ -27,7 +27,7 @@
#include <migraphx/generate.hpp>
#include <migraphx/generate.hpp>
#include <migraphx/op/quant_convolution.hpp>
#include <migraphx/op/quant_convolution.hpp>
struct
quant_conv_
int8x4_default
:
verify_program
<
quant_conv_
int8x4_default
>
struct
quant_conv_
2
:
verify_program
<
quant_conv_
2
>
{
{
migraphx
::
program
create_program
()
const
migraphx
::
program
create_program
()
const
{
{
...
...
test/verify/test_arg_ops.cpp
View file @
f3a8933c
/*
/*
* The MIT License (MIT)
* The MIT License (MIT)
*
*
* Copyright (c) 2015-202
2
Advanced Micro Devices, Inc. All rights reserved.
* Copyright (c) 2015-202
3
Advanced Micro Devices, Inc. All rights reserved.
*
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
* of this software and associated documentation files (the "Software"), to deal
...
@@ -29,8 +29,8 @@
...
@@ -29,8 +29,8 @@
#include <migraphx/op/argmax.hpp>
#include <migraphx/op/argmax.hpp>
#include <migraphx/op/argmin.hpp>
#include <migraphx/op/argmin.hpp>
template
<
class
T
,
int
Axis
,
int
NonStdShape
>
template
<
class
T
,
int
Axis
,
bool
LastIndex
,
int
NonStdShape
>
struct
test_arg_ops
:
verify_program
<
test_arg_ops
<
T
,
Axis
,
NonStdShape
>>
struct
test_arg_ops
:
verify_program
<
test_arg_ops
<
T
,
Axis
,
LastIndex
,
NonStdShape
>>
{
{
migraphx
::
program
create_program
()
const
migraphx
::
program
create_program
()
const
{
{
...
@@ -54,63 +54,111 @@ struct test_arg_ops : verify_program<test_arg_ops<T, Axis, NonStdShape>>
...
@@ -54,63 +54,111 @@ struct test_arg_ops : verify_program<test_arg_ops<T, Axis, NonStdShape>>
break
;
break
;
default:
break
;
default:
break
;
}
}
mm
->
add_instruction
(
T
{
Axis
},
param
);
mm
->
add_instruction
(
T
{
Axis
,
LastIndex
},
param
);
return
p
;
return
p
;
}
}
};
};
// transpose argmax tests
// transpose argmax tests
template
struct
test_arg_ops
<
migraphx
::
op
::
argmax
,
0
,
0
>;
template
struct
test_arg_ops
<
migraphx
::
op
::
argmax
,
0
,
true
,
0
>;
template
struct
test_arg_ops
<
migraphx
::
op
::
argmax
,
1
,
0
>;
template
struct
test_arg_ops
<
migraphx
::
op
::
argmax
,
0
,
false
,
0
>;
template
struct
test_arg_ops
<
migraphx
::
op
::
argmax
,
2
,
0
>;
template
struct
test_arg_ops
<
migraphx
::
op
::
argmax
,
1
,
true
,
0
>;
template
struct
test_arg_ops
<
migraphx
::
op
::
argmax
,
3
,
0
>;
template
struct
test_arg_ops
<
migraphx
::
op
::
argmax
,
1
,
false
,
0
>;
template
struct
test_arg_ops
<
migraphx
::
op
::
argmax
,
-
1
,
0
>;
template
struct
test_arg_ops
<
migraphx
::
op
::
argmax
,
2
,
true
,
0
>;
template
struct
test_arg_ops
<
migraphx
::
op
::
argmax
,
-
2
,
0
>;
template
struct
test_arg_ops
<
migraphx
::
op
::
argmax
,
2
,
false
,
0
>;
template
struct
test_arg_ops
<
migraphx
::
op
::
argmax
,
3
,
true
,
0
>;
template
struct
test_arg_ops
<
migraphx
::
op
::
argmax
,
3
,
false
,
0
>;
template
struct
test_arg_ops
<
migraphx
::
op
::
argmax
,
-
1
,
true
,
0
>;
template
struct
test_arg_ops
<
migraphx
::
op
::
argmax
,
-
1
,
false
,
0
>;
template
struct
test_arg_ops
<
migraphx
::
op
::
argmax
,
-
2
,
true
,
0
>;
template
struct
test_arg_ops
<
migraphx
::
op
::
argmax
,
-
2
,
false
,
0
>;
// transpose argmin tests
// transpose argmin tests
template
struct
test_arg_ops
<
migraphx
::
op
::
argmin
,
0
,
0
>;
template
struct
test_arg_ops
<
migraphx
::
op
::
argmin
,
0
,
true
,
0
>;
template
struct
test_arg_ops
<
migraphx
::
op
::
argmin
,
1
,
0
>;
template
struct
test_arg_ops
<
migraphx
::
op
::
argmin
,
0
,
false
,
0
>;
template
struct
test_arg_ops
<
migraphx
::
op
::
argmin
,
2
,
0
>;
template
struct
test_arg_ops
<
migraphx
::
op
::
argmin
,
1
,
true
,
0
>;
template
struct
test_arg_ops
<
migraphx
::
op
::
argmin
,
3
,
0
>;
template
struct
test_arg_ops
<
migraphx
::
op
::
argmin
,
1
,
false
,
0
>;
template
struct
test_arg_ops
<
migraphx
::
op
::
argmin
,
-
3
,
0
>;
template
struct
test_arg_ops
<
migraphx
::
op
::
argmin
,
2
,
true
,
0
>;
template
struct
test_arg_ops
<
migraphx
::
op
::
argmin
,
-
4
,
0
>;
template
struct
test_arg_ops
<
migraphx
::
op
::
argmin
,
2
,
false
,
0
>;
template
struct
test_arg_ops
<
migraphx
::
op
::
argmin
,
3
,
true
,
0
>;
template
struct
test_arg_ops
<
migraphx
::
op
::
argmin
,
3
,
false
,
0
>;
template
struct
test_arg_ops
<
migraphx
::
op
::
argmin
,
-
3
,
true
,
0
>;
template
struct
test_arg_ops
<
migraphx
::
op
::
argmin
,
-
3
,
false
,
0
>;
template
struct
test_arg_ops
<
migraphx
::
op
::
argmin
,
-
4
,
true
,
0
>;
template
struct
test_arg_ops
<
migraphx
::
op
::
argmin
,
-
4
,
false
,
0
>;
// broadcast argmax tests
// broadcast argmax tests
template
struct
test_arg_ops
<
migraphx
::
op
::
argmax
,
0
,
1
>;
template
struct
test_arg_ops
<
migraphx
::
op
::
argmax
,
0
,
true
,
1
>;
template
struct
test_arg_ops
<
migraphx
::
op
::
argmax
,
1
,
1
>;
template
struct
test_arg_ops
<
migraphx
::
op
::
argmax
,
0
,
false
,
1
>;
template
struct
test_arg_ops
<
migraphx
::
op
::
argmax
,
2
,
1
>;
template
struct
test_arg_ops
<
migraphx
::
op
::
argmax
,
1
,
true
,
1
>;
template
struct
test_arg_ops
<
migraphx
::
op
::
argmax
,
3
,
1
>;
template
struct
test_arg_ops
<
migraphx
::
op
::
argmax
,
1
,
false
,
1
>;
template
struct
test_arg_ops
<
migraphx
::
op
::
argmax
,
-
1
,
1
>;
template
struct
test_arg_ops
<
migraphx
::
op
::
argmax
,
2
,
true
,
1
>;
template
struct
test_arg_ops
<
migraphx
::
op
::
argmax
,
-
2
,
1
>;
template
struct
test_arg_ops
<
migraphx
::
op
::
argmax
,
2
,
false
,
1
>;
template
struct
test_arg_ops
<
migraphx
::
op
::
argmax
,
3
,
true
,
1
>;
template
struct
test_arg_ops
<
migraphx
::
op
::
argmax
,
3
,
false
,
1
>;
template
struct
test_arg_ops
<
migraphx
::
op
::
argmax
,
-
1
,
true
,
1
>;
template
struct
test_arg_ops
<
migraphx
::
op
::
argmax
,
-
1
,
false
,
1
>;
template
struct
test_arg_ops
<
migraphx
::
op
::
argmax
,
-
2
,
true
,
1
>;
template
struct
test_arg_ops
<
migraphx
::
op
::
argmax
,
-
2
,
false
,
1
>;
// broadcast argmin tests
// broadcast argmin tests
template
struct
test_arg_ops
<
migraphx
::
op
::
argmin
,
0
,
1
>;
template
struct
test_arg_ops
<
migraphx
::
op
::
argmin
,
0
,
true
,
1
>;
template
struct
test_arg_ops
<
migraphx
::
op
::
argmin
,
1
,
1
>;
template
struct
test_arg_ops
<
migraphx
::
op
::
argmin
,
0
,
false
,
1
>;
template
struct
test_arg_ops
<
migraphx
::
op
::
argmin
,
2
,
1
>;
template
struct
test_arg_ops
<
migraphx
::
op
::
argmin
,
1
,
true
,
1
>;
template
struct
test_arg_ops
<
migraphx
::
op
::
argmin
,
3
,
1
>;
template
struct
test_arg_ops
<
migraphx
::
op
::
argmin
,
1
,
false
,
1
>;
template
struct
test_arg_ops
<
migraphx
::
op
::
argmin
,
-
3
,
1
>;
template
struct
test_arg_ops
<
migraphx
::
op
::
argmin
,
2
,
true
,
1
>;
template
struct
test_arg_ops
<
migraphx
::
op
::
argmin
,
-
4
,
1
>;
template
struct
test_arg_ops
<
migraphx
::
op
::
argmin
,
2
,
false
,
1
>;
template
struct
test_arg_ops
<
migraphx
::
op
::
argmin
,
3
,
true
,
1
>;
template
struct
test_arg_ops
<
migraphx
::
op
::
argmin
,
3
,
false
,
1
>;
template
struct
test_arg_ops
<
migraphx
::
op
::
argmin
,
-
3
,
true
,
1
>;
template
struct
test_arg_ops
<
migraphx
::
op
::
argmin
,
-
3
,
false
,
1
>;
template
struct
test_arg_ops
<
migraphx
::
op
::
argmin
,
-
4
,
true
,
1
>;
template
struct
test_arg_ops
<
migraphx
::
op
::
argmin
,
-
4
,
false
,
1
>;
// slice argmax tests
// slice argmax tests
template
struct
test_arg_ops
<
migraphx
::
op
::
argmax
,
0
,
2
>;
template
struct
test_arg_ops
<
migraphx
::
op
::
argmax
,
0
,
true
,
2
>;
template
struct
test_arg_ops
<
migraphx
::
op
::
argmax
,
1
,
2
>;
template
struct
test_arg_ops
<
migraphx
::
op
::
argmax
,
0
,
false
,
2
>;
template
struct
test_arg_ops
<
migraphx
::
op
::
argmax
,
2
,
2
>;
template
struct
test_arg_ops
<
migraphx
::
op
::
argmax
,
1
,
true
,
2
>;
template
struct
test_arg_ops
<
migraphx
::
op
::
argmax
,
3
,
2
>;
template
struct
test_arg_ops
<
migraphx
::
op
::
argmax
,
1
,
false
,
2
>;
template
struct
test_arg_ops
<
migraphx
::
op
::
argmax
,
-
1
,
2
>;
template
struct
test_arg_ops
<
migraphx
::
op
::
argmax
,
2
,
true
,
2
>;
template
struct
test_arg_ops
<
migraphx
::
op
::
argmax
,
-
2
,
2
>;
template
struct
test_arg_ops
<
migraphx
::
op
::
argmax
,
2
,
false
,
2
>;
template
struct
test_arg_ops
<
migraphx
::
op
::
argmax
,
3
,
true
,
2
>;
template
struct
test_arg_ops
<
migraphx
::
op
::
argmax
,
3
,
false
,
2
>;
template
struct
test_arg_ops
<
migraphx
::
op
::
argmax
,
-
1
,
true
,
2
>;
template
struct
test_arg_ops
<
migraphx
::
op
::
argmax
,
-
1
,
false
,
2
>;
template
struct
test_arg_ops
<
migraphx
::
op
::
argmax
,
-
2
,
true
,
2
>;
template
struct
test_arg_ops
<
migraphx
::
op
::
argmax
,
-
2
,
false
,
2
>;
// slice argmin tests
// slice argmin tests
template
struct
test_arg_ops
<
migraphx
::
op
::
argmin
,
0
,
2
>;
template
struct
test_arg_ops
<
migraphx
::
op
::
argmin
,
0
,
true
,
2
>;
template
struct
test_arg_ops
<
migraphx
::
op
::
argmin
,
1
,
2
>;
template
struct
test_arg_ops
<
migraphx
::
op
::
argmin
,
0
,
false
,
2
>;
template
struct
test_arg_ops
<
migraphx
::
op
::
argmin
,
2
,
2
>;
template
struct
test_arg_ops
<
migraphx
::
op
::
argmin
,
1
,
true
,
2
>;
template
struct
test_arg_ops
<
migraphx
::
op
::
argmin
,
3
,
2
>;
template
struct
test_arg_ops
<
migraphx
::
op
::
argmin
,
1
,
false
,
2
>;
template
struct
test_arg_ops
<
migraphx
::
op
::
argmin
,
-
3
,
2
>;
template
struct
test_arg_ops
<
migraphx
::
op
::
argmin
,
2
,
true
,
2
>;
template
struct
test_arg_ops
<
migraphx
::
op
::
argmin
,
-
4
,
2
>;
template
struct
test_arg_ops
<
migraphx
::
op
::
argmin
,
2
,
false
,
2
>;
template
struct
test_arg_ops
<
migraphx
::
op
::
argmin
,
3
,
true
,
2
>;
template
struct
test_arg_ops
<
migraphx
::
op
::
argmin
,
3
,
false
,
2
>;
template
struct
test_arg_ops
<
migraphx
::
op
::
argmin
,
-
3
,
true
,
2
>;
template
struct
test_arg_ops
<
migraphx
::
op
::
argmin
,
-
3
,
false
,
2
>;
template
struct
test_arg_ops
<
migraphx
::
op
::
argmin
,
-
4
,
true
,
2
>;
template
struct
test_arg_ops
<
migraphx
::
op
::
argmin
,
-
4
,
false
,
2
>;
// default case, standard shape argmax tests
// default case, standard shape argmax tests
template
struct
test_arg_ops
<
migraphx
::
op
::
argmax
,
0
,
3
>;
template
struct
test_arg_ops
<
migraphx
::
op
::
argmax
,
0
,
true
,
3
>;
template
struct
test_arg_ops
<
migraphx
::
op
::
argmax
,
1
,
3
>;
template
struct
test_arg_ops
<
migraphx
::
op
::
argmax
,
0
,
false
,
3
>;
template
struct
test_arg_ops
<
migraphx
::
op
::
argmax
,
2
,
3
>;
template
struct
test_arg_ops
<
migraphx
::
op
::
argmax
,
1
,
true
,
3
>;
template
struct
test_arg_ops
<
migraphx
::
op
::
argmax
,
3
,
3
>;
template
struct
test_arg_ops
<
migraphx
::
op
::
argmax
,
1
,
false
,
3
>;
template
struct
test_arg_ops
<
migraphx
::
op
::
argmax
,
-
1
,
3
>;
template
struct
test_arg_ops
<
migraphx
::
op
::
argmax
,
2
,
true
,
3
>;
template
struct
test_arg_ops
<
migraphx
::
op
::
argmax
,
-
2
,
3
>;
template
struct
test_arg_ops
<
migraphx
::
op
::
argmax
,
2
,
false
,
3
>;
template
struct
test_arg_ops
<
migraphx
::
op
::
argmax
,
3
,
true
,
3
>;
template
struct
test_arg_ops
<
migraphx
::
op
::
argmax
,
3
,
false
,
3
>;
template
struct
test_arg_ops
<
migraphx
::
op
::
argmax
,
-
1
,
true
,
3
>;
template
struct
test_arg_ops
<
migraphx
::
op
::
argmax
,
-
1
,
false
,
3
>;
template
struct
test_arg_ops
<
migraphx
::
op
::
argmax
,
-
2
,
true
,
3
>;
template
struct
test_arg_ops
<
migraphx
::
op
::
argmax
,
-
2
,
false
,
3
>;
// default case, standard shape argmin tests
// default case, standard shape argmin tests
template
struct
test_arg_ops
<
migraphx
::
op
::
argmin
,
0
,
3
>;
template
struct
test_arg_ops
<
migraphx
::
op
::
argmin
,
0
,
true
,
3
>;
template
struct
test_arg_ops
<
migraphx
::
op
::
argmin
,
1
,
3
>;
template
struct
test_arg_ops
<
migraphx
::
op
::
argmin
,
0
,
false
,
3
>;
template
struct
test_arg_ops
<
migraphx
::
op
::
argmin
,
2
,
3
>;
template
struct
test_arg_ops
<
migraphx
::
op
::
argmin
,
1
,
true
,
3
>;
template
struct
test_arg_ops
<
migraphx
::
op
::
argmin
,
3
,
3
>;
template
struct
test_arg_ops
<
migraphx
::
op
::
argmin
,
1
,
false
,
3
>;
template
struct
test_arg_ops
<
migraphx
::
op
::
argmin
,
-
3
,
3
>;
template
struct
test_arg_ops
<
migraphx
::
op
::
argmin
,
2
,
true
,
3
>;
template
struct
test_arg_ops
<
migraphx
::
op
::
argmin
,
-
4
,
3
>;
template
struct
test_arg_ops
<
migraphx
::
op
::
argmin
,
2
,
false
,
3
>;
template
struct
test_arg_ops
<
migraphx
::
op
::
argmin
,
3
,
true
,
3
>;
template
struct
test_arg_ops
<
migraphx
::
op
::
argmin
,
3
,
false
,
3
>;
template
struct
test_arg_ops
<
migraphx
::
op
::
argmin
,
-
3
,
true
,
3
>;
template
struct
test_arg_ops
<
migraphx
::
op
::
argmin
,
-
3
,
false
,
3
>;
template
struct
test_arg_ops
<
migraphx
::
op
::
argmin
,
-
4
,
true
,
3
>;
template
struct
test_arg_ops
<
migraphx
::
op
::
argmin
,
-
4
,
false
,
3
>;
tools/accuracy/accuracy_checker.py
View file @
f3a8933c
#####################################################################################
#####################################################################################
# The MIT License (MIT)
# The MIT License (MIT)
#
#
# Copyright (c) 2015-202
2
Advanced Micro Devices, Inc. All rights reserved.
# Copyright (c) 2015-202
3
Advanced Micro Devices, Inc. All rights reserved.
#
#
# Permission is hereby granted, free of charge, to any person obtaining a copy
# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
# of this software and associated documentation files (the "Software"), to deal
...
@@ -52,6 +52,12 @@ def parse_args():
...
@@ -52,6 +52,12 @@ def parse_args():
parser
.
add_argument
(
'--fill0'
,
parser
.
add_argument
(
'--fill0'
,
action
=
'store_true'
,
action
=
'store_true'
,
help
=
'fill all arguments with a value of 0'
)
help
=
'fill all arguments with a value of 0'
)
parser
.
add_argument
(
'--fp16'
,
action
=
'store_true'
,
help
=
'quantize MIGraphX model to fp16'
)
parser
.
add_argument
(
'--argmax'
,
action
=
'store_true'
,
help
=
'use argmax for accuracy'
)
parser
.
add_argument
(
'--verbose'
,
parser
.
add_argument
(
'--verbose'
,
action
=
'store_true'
,
action
=
'store_true'
,
help
=
'show verbose information (for debugging)'
)
help
=
'show verbose information (for debugging)'
)
...
@@ -105,7 +111,7 @@ def parse_args():
...
@@ -105,7 +111,7 @@ def parse_args():
args
=
parser
.
parse_args
()
args
=
parser
.
parse_args
()
return
args
return
args
,
parser
# taken from ../test_runner.py
# taken from ../test_runner.py
...
@@ -113,6 +119,7 @@ def check_correctness(gold_outputs,
...
@@ -113,6 +119,7 @@ def check_correctness(gold_outputs,
outputs
,
outputs
,
rtol
=
1e-3
,
rtol
=
1e-3
,
atol
=
1e-3
,
atol
=
1e-3
,
use_argmax
=
False
,
verbose
=
False
):
verbose
=
False
):
if
len
(
gold_outputs
)
!=
len
(
outputs
):
if
len
(
gold_outputs
)
!=
len
(
outputs
):
print
(
'Number of outputs {} is not equal to expected number {}'
.
format
(
print
(
'Number of outputs {} is not equal to expected number {}'
.
format
(
...
@@ -121,18 +128,29 @@ def check_correctness(gold_outputs,
...
@@ -121,18 +128,29 @@ def check_correctness(gold_outputs,
out_num
=
len
(
gold_outputs
)
out_num
=
len
(
gold_outputs
)
ret
=
True
ret
=
True
for
i
in
range
(
out_num
):
if
not
np
.
allclose
(
gold_outputs
[
i
],
outputs
[
i
],
rtol
,
atol
):
if
not
use_argmax
:
for
i
in
range
(
out_num
):
if
not
np
.
allclose
(
gold_outputs
[
i
],
outputs
[
i
],
rtol
,
atol
):
ret
=
False
if
verbose
:
print
(
'
\n
Output {} is incorrect ...'
.
format
(
i
))
print
(
'Expected value:
\n
{}'
.
format
(
gold_outputs
[
i
]))
print
(
'......'
)
print
(
'Actual value:
\n
{}
\n
'
.
format
(
outputs
[
i
]))
else
:
print
(
'Outputs do not match'
)
break
else
:
golden_argmax
=
np
.
argmax
(
gold_outputs
)
actual_argmax
=
np
.
argmax
(
outputs
)
if
actual_argmax
!=
golden_argmax
:
ret
=
False
ret
=
False
print
(
'
\n
Output argmax is incorrect ...'
)
if
verbose
:
if
verbose
:
print
(
'
\n
Output {} is incorrect ...'
.
format
(
i
))
print
(
'Expected argmax value:
\n
{}'
.
format
(
golden_argmax
))
print
(
'Expected value:
\n
{}'
.
format
(
gold_outputs
[
i
]))
print
(
'......'
)
print
(
'......'
)
print
(
'Actual value:
\n
{}
\n
'
.
format
(
outputs
[
i
]))
print
(
'Actual argmax value:
\n
{}
\n
'
.
format
(
actual_argmax
))
else
:
print
(
'Outputs do not match'
)
break
return
ret
return
ret
...
@@ -155,13 +173,14 @@ def get_np_datatype(in_type):
...
@@ -155,13 +173,14 @@ def get_np_datatype(in_type):
def
main
():
def
main
():
args
=
parse_args
()
args
,
parser
=
parse_args
()
use_onnx
=
True
use_onnx
=
True
if
args
.
onnx
==
None
:
if
args
.
onnx
==
None
:
use_onnx
=
False
use_onnx
=
False
if
not
use_onnx
and
args
.
tf
==
None
:
if
not
use_onnx
and
args
.
tf
==
None
:
print
(
'Error: please specify either an onnx or tf pb file'
)
print
(
'Error: please specify either an onnx or tf pb file'
)
parser
.
print_help
()
sys
.
exit
(
-
1
)
sys
.
exit
(
-
1
)
model_name
=
args
.
onnx
model_name
=
args
.
onnx
...
@@ -194,6 +213,9 @@ def main():
...
@@ -194,6 +213,9 @@ def main():
batch_size
=
batch
,
batch_size
=
batch
,
map_input_dims
=
input_dims
)
map_input_dims
=
input_dims
)
if
(
args
.
fp16
):
migraphx
.
quantize_fp16
(
model
)
if
args
.
verbose
:
if
args
.
verbose
:
print
(
model
)
print
(
model
)
...
@@ -300,7 +322,8 @@ def main():
...
@@ -300,7 +322,8 @@ def main():
if
not
args
.
ort_run
:
if
not
args
.
ort_run
:
is_correct
=
check_correctness
(
pred_fw
,
pred_migx
,
args
.
tolerance
,
is_correct
=
check_correctness
(
pred_fw
,
pred_migx
,
args
.
tolerance
,
args
.
tolerance
,
args
.
verbose
)
args
.
tolerance
,
args
.
argmax
,
args
.
verbose
)
verbose_string
=
' Rerun with --verbose for detailed information.'
\
verbose_string
=
' Rerun with --verbose for detailed information.'
\
if
not
args
.
verbose
else
''
if
not
args
.
verbose
else
''
if
is_correct
:
if
is_correct
:
...
...
Prev
1
2
3
4
5
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment