Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
MIGraphX
Commits
5139b930
Unverified
Commit
5139b930
authored
Oct 17, 2023
by
Charlie Lin
Committed by
GitHub
Oct 17, 2023
Browse files
Change driver verify to check for fp16 and --fp16 (#2334)
parent
94bda243
Changes
4
Hide whitespace changes
Inline
Side-by-side
Showing
4 changed files
with
60 additions
and
38 deletions
+60
-38
src/driver/argument_parser.hpp
src/driver/argument_parser.hpp
+7
-0
src/driver/main.cpp
src/driver/main.cpp
+11
-38
src/driver/verify.cpp
src/driver/verify.cpp
+36
-0
src/driver/verify.hpp
src/driver/verify.hpp
+6
-0
No files found.
src/driver/argument_parser.hpp
View file @
5139b930
...
@@ -187,6 +187,13 @@ struct value_parser
...
@@ -187,6 +187,13 @@ struct value_parser
}
}
};
};
// version for std::optional object
template
<
class
T
>
struct
value_parser
<
std
::
optional
<
T
>>
{
static
T
apply
(
const
std
::
string
&
x
)
{
return
value_parser
<
T
>::
apply
(
x
);
}
};
struct
argument_parser
struct
argument_parser
{
{
struct
argument
struct
argument
...
...
src/driver/main.cpp
View file @
5139b930
...
@@ -540,22 +540,17 @@ struct params : command<params>
...
@@ -540,22 +540,17 @@ struct params : command<params>
struct
verify
:
command
<
verify
>
struct
verify
:
command
<
verify
>
{
{
compiler
c
;
compiler
c
;
// Set to -1. as nonsense initial value
std
::
optional
<
double
>
rms_tol
;
double
rms_tol
=
-
1.0
;
std
::
optional
<
double
>
atol
;
double
atol
=
-
1.0
;
std
::
optional
<
double
>
rtol
;
double
rtol
=
-
1.0
;
bool
per_instruction
=
false
;
bool
per_instruction
=
false
;
bool
reduce
=
false
;
bool
reduce
=
false
;
void
parse
(
argument_parser
&
ap
)
void
parse
(
argument_parser
&
ap
)
{
{
c
.
parse
(
ap
);
c
.
parse
(
ap
);
ap
(
rms_tol
,
{
"--rms-tol"
},
ap
.
help
(
"Tolerance for the RMS error (Default: 0.001)"
));
ap
(
rms_tol
,
{
"--rms-tol"
},
ap
.
help
(
"Tolerance for the RMS error"
));
ap
(
atol
,
ap
(
atol
,
{
"--atol"
},
ap
.
help
(
"Tolerance for the elementwise absolute difference"
));
{
"--atol"
},
ap
(
rtol
,
{
"--rtol"
},
ap
.
help
(
"Tolerance for the elementwise relative difference"
));
ap
.
help
(
"Tolerance for the elementwise absolute difference (Default: 0.001)"
));
ap
(
rtol
,
{
"--rtol"
},
ap
.
help
(
"Tolerance for the elementwise relative difference (Default: 0.001)"
));
ap
(
per_instruction
,
ap
(
per_instruction
,
{
"-i"
,
"--per-instruction"
},
{
"-i"
,
"--per-instruction"
},
ap
.
help
(
"Verify each instruction"
),
ap
.
help
(
"Verify each instruction"
),
...
@@ -572,33 +567,6 @@ struct verify : command<verify>
...
@@ -572,33 +567,6 @@ struct verify : command<verify>
auto
t
=
c
.
ct
.
get_target
();
auto
t
=
c
.
ct
.
get_target
();
auto
m
=
c
.
parameters
.
generate
(
p
,
t
,
true
,
c
.
l
.
batch
);
auto
m
=
c
.
parameters
.
generate
(
p
,
t
,
true
,
c
.
l
.
batch
);
// TODO remove this and make the driver able to figure out datatype most used in the model
// then set the tolerances appropriately. Need to check here because c.to_fp16 only set
// after argument_parser.parse() is run. This code is complicated because there's not a
// good way to change the default tolerances after reading `--fp16` but before reading
// `--rms-tol`, `--atol`, and `--rtol`.
migraphx
::
verify
::
tolerance
tols
{};
if
(
c
.
to_fp16
)
{
tols
=
migraphx
::
verify
::
tolerance
{
8e-2
,
4e-2
,
4e-2
};
}
if
(
not
float_equal
(
this
->
rms_tol
,
-
1.0
))
{
tols
.
rms_tol
=
this
->
rms_tol
;
}
if
(
not
float_equal
(
this
->
atol
,
-
1.0
))
{
tols
.
atol
=
this
->
atol
;
}
if
(
not
float_equal
(
this
->
rtol
,
-
1.0
))
{
tols
.
rtol
=
this
->
rtol
;
}
std
::
cout
<<
"rms_tol: "
<<
tols
.
rms_tol
<<
std
::
endl
;
std
::
cout
<<
"atol: "
<<
tols
.
atol
<<
std
::
endl
;
std
::
cout
<<
"rtol: "
<<
tols
.
rtol
<<
std
::
endl
;
auto
quantize
=
precision
::
fp32
;
auto
quantize
=
precision
::
fp32
;
if
(
c
.
to_fp16
)
if
(
c
.
to_fp16
)
{
{
...
@@ -609,6 +577,11 @@ struct verify : command<verify>
...
@@ -609,6 +577,11 @@ struct verify : command<verify>
quantize
=
precision
::
int8
;
quantize
=
precision
::
int8
;
}
}
auto
tols
=
get_tolerances
(
p
,
quantize
,
rms_tol
,
atol
,
rtol
);
std
::
cout
<<
"rms_tol: "
<<
tols
.
rms_tol
<<
std
::
endl
;
std
::
cout
<<
"atol: "
<<
tols
.
atol
<<
std
::
endl
;
std
::
cout
<<
"rtol: "
<<
tols
.
rtol
<<
std
::
endl
;
if
(
per_instruction
)
if
(
per_instruction
)
{
{
verify_instructions
(
p
,
t
,
c
.
co
,
quantize
,
tols
);
verify_instructions
(
p
,
t
,
c
.
co
,
quantize
,
tols
);
...
...
src/driver/verify.cpp
View file @
5139b930
...
@@ -36,6 +36,42 @@ namespace migraphx {
...
@@ -36,6 +36,42 @@ namespace migraphx {
namespace
driver
{
namespace
driver
{
inline
namespace
MIGRAPHX_INLINE_NS
{
inline
namespace
MIGRAPHX_INLINE_NS
{
/**
* Gives tolerances based on user input (`rms_tol`, `atol`, `rtol` parameters) and defaults.
* Sets to fp16 tolerances if `quantize` input is fp16 or any fp16 instruction in found in the
* model.
*/
verify
::
tolerance
get_tolerances
(
const
program
&
p
,
precision
quantize
,
std
::
optional
<
double
>
rms_tol
,
std
::
optional
<
double
>
atol
,
std
::
optional
<
double
>
rtol
)
{
bool
has_fp16
=
any_of
(
p
.
get_modules
(),
[](
auto
&&
m
)
{
return
any_of
(
*
m
,
[](
auto
&&
ins
)
{
return
(
ins
.
get_shape
().
type
()
==
shape
::
half_type
);
});
});
migraphx
::
verify
::
tolerance
result
{};
if
(
has_fp16
or
quantize
==
precision
::
fp16
)
{
result
.
rms_tol
=
8e-2
;
result
.
atol
=
4e-2
;
result
.
rtol
=
4e-2
;
}
if
(
rms_tol
)
{
result
.
rms_tol
=
*
rms_tol
;
}
if
(
atol
)
{
result
.
atol
=
*
atol
;
}
if
(
rtol
)
{
result
.
rtol
=
*
rtol
;
}
return
result
;
}
std
::
vector
<
argument
>
run_ref
(
program
p
,
const
parameter_map
&
inputs
)
std
::
vector
<
argument
>
run_ref
(
program
p
,
const
parameter_map
&
inputs
)
{
{
p
.
compile
(
migraphx
::
make_target
(
"ref"
));
p
.
compile
(
migraphx
::
make_target
(
"ref"
));
...
...
src/driver/verify.hpp
View file @
5139b930
...
@@ -32,6 +32,12 @@ namespace migraphx {
...
@@ -32,6 +32,12 @@ namespace migraphx {
namespace
driver
{
namespace
driver
{
inline
namespace
MIGRAPHX_INLINE_NS
{
inline
namespace
MIGRAPHX_INLINE_NS
{
verify
::
tolerance
get_tolerances
(
const
program
&
p
,
precision
quantize
,
std
::
optional
<
double
>
rms_tol
,
std
::
optional
<
double
>
atol
,
std
::
optional
<
double
>
rtol
);
void
verify_program
(
const
std
::
string
&
name
,
void
verify_program
(
const
std
::
string
&
name
,
const
program
&
p
,
const
program
&
p
,
const
target
&
t
,
const
target
&
t
,
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment