Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
MIGraphX
Commits
db3c07fb
Unverified
Commit
db3c07fb
authored
Dec 12, 2023
by
Umang Yadav
Committed by
GitHub
Dec 12, 2023
Browse files
Add `--fp8` option to quantize models in FP8 inside `migraphx-driver` (#2535)
parent
aac4e950
Changes
11
Show whitespace changes
Inline
Side-by-side
Showing
11 changed files
with
127 additions
and
57 deletions
+127
-57
docs/dev/env_vars.rst
docs/dev/env_vars.rst
+1
-1
docs/driver/compile.rst
docs/driver/compile.rst
+3
-0
examples/migraphx/migraphx_driver/README.md
examples/migraphx/migraphx_driver/README.md
+1
-0
src/CMakeLists.txt
src/CMakeLists.txt
+1
-1
src/driver/main.cpp
src/driver/main.cpp
+6
-0
src/include/migraphx/quantization.hpp
src/include/migraphx/quantization.hpp
+2
-0
src/include/migraphx/quantize_8bits.hpp
src/include/migraphx/quantize_8bits.hpp
+8
-7
src/quantization.cpp
src/quantization.cpp
+55
-28
src/quantize_8bits.cpp
src/quantize_8bits.cpp
+21
-9
src/simplify_qdq.cpp
src/simplify_qdq.cpp
+9
-3
test/quantization.cpp
test/quantization.cpp
+20
-8
No files found.
docs/dev/env_vars.rst
View file @
db3c07fb
...
@@ -82,7 +82,7 @@ Print debug statements for the ``schedule`` pass.
...
@@ -82,7 +82,7 @@ Print debug statements for the ``schedule`` pass.
Set to "1", "enable", "enabled", "yes", or "true" to use.
Set to "1", "enable", "enabled", "yes", or "true" to use.
Traces instructions replaced with a constant.
Traces instructions replaced with a constant.
.. envvar:: MIGRAPHX_
INT8
_QUANTIZATION_PARAMS
.. envvar:: MIGRAPHX_
8BITS
_QUANTIZATION_PARAMS
Set to "1", "enable", "enabled", "yes", or "true" to use.
Set to "1", "enable", "enabled", "yes", or "true" to use.
Print the quantization parameters in only the main module.
Print the quantization parameters in only the main module.
...
...
docs/driver/compile.rst
View file @
db3c07fb
...
@@ -38,3 +38,6 @@ Quantize for fp16
...
@@ -38,3 +38,6 @@ Quantize for fp16
Quantize for int8
Quantize for int8
.. option:: --fp8
Quantize for Float8E4M3FNUZ type
examples/migraphx/migraphx_driver/README.md
View file @
db3c07fb
...
@@ -55,6 +55,7 @@ See below for a comprehensive list of commands and option arguments, as well as
...
@@ -55,6 +55,7 @@ See below for a comprehensive list of commands and option arguments, as well as
| --exhaustive-tune | Enable exhaustive search to find fastest kernel |
| --exhaustive-tune | Enable exhaustive search to find fastest kernel |
| --fp16 | Quantize for fp16 |
| --fp16 | Quantize for fp16 |
| --int8 | Quantize for int8 |
| --int8 | Quantize for int8 |
| --fp8 | Quantize for Float8E4M3FNUZ type |
| --rms-tol | Tolerance for the RMS error (Default: 0.001) |
| --rms-tol | Tolerance for the RMS error (Default: 0.001) |
| --atol | Tolerance for elementwise absolute difference (Default: 0.001) |
| --atol | Tolerance for elementwise absolute difference (Default: 0.001) |
| --rtol | Tolerance for elementwise relative difference (Default: 0.001) |
| --rtol | Tolerance for elementwise relative difference (Default: 0.001) |
...
...
src/CMakeLists.txt
View file @
db3c07fb
...
@@ -81,7 +81,7 @@ add_library(migraphx
...
@@ -81,7 +81,7 @@ add_library(migraphx
promote_literals.cpp
promote_literals.cpp
quantization.cpp
quantization.cpp
quantize_fp16.cpp
quantize_fp16.cpp
quantize_
int8
.cpp
quantize_
8bits
.cpp
reduce_dims.cpp
reduce_dims.cpp
register_op.cpp
register_op.cpp
register_target.cpp
register_target.cpp
...
...
src/driver/main.cpp
View file @
db3c07fb
...
@@ -445,6 +445,7 @@ struct compiler
...
@@ -445,6 +445,7 @@ struct compiler
compiler_target
ct
;
compiler_target
ct
;
compile_options
co
;
compile_options
co
;
bool
to_fp16
=
false
;
bool
to_fp16
=
false
;
bool
to_fp8
=
false
;
bool
to_int8
=
false
;
bool
to_int8
=
false
;
std
::
vector
<
std
::
string
>
fill0
;
std
::
vector
<
std
::
string
>
fill0
;
...
@@ -468,6 +469,7 @@ struct compiler
...
@@ -468,6 +469,7 @@ struct compiler
ap
.
set_value
(
true
));
ap
.
set_value
(
true
));
ap
(
to_fp16
,
{
"--fp16"
},
ap
.
help
(
"Quantize for fp16"
),
ap
.
set_value
(
true
));
ap
(
to_fp16
,
{
"--fp16"
},
ap
.
help
(
"Quantize for fp16"
),
ap
.
set_value
(
true
));
ap
(
to_int8
,
{
"--int8"
},
ap
.
help
(
"Quantize for int8"
),
ap
.
set_value
(
true
));
ap
(
to_int8
,
{
"--int8"
},
ap
.
help
(
"Quantize for int8"
),
ap
.
set_value
(
true
));
ap
(
to_fp8
,
{
"--fp8"
},
ap
.
help
(
"Quantize for fp8e4m3fnuz type"
),
ap
.
set_value
(
true
));
}
}
auto
params
(
const
program
&
p
)
auto
params
(
const
program
&
p
)
...
@@ -518,6 +520,10 @@ struct compiler
...
@@ -518,6 +520,10 @@ struct compiler
{
{
quantize_int8
(
p
,
t
,
{
host_params
(
p
)});
quantize_int8
(
p
,
t
,
{
host_params
(
p
)});
}
}
if
(
to_fp8
)
{
quantize_fp8
(
p
,
t
,
{
host_params
(
p
)});
}
p
.
compile
(
t
,
co
);
p
.
compile
(
t
,
co
);
l
.
save
(
p
);
l
.
save
(
p
);
return
p
;
return
p
;
...
...
src/include/migraphx/quantization.hpp
View file @
db3c07fb
...
@@ -46,6 +46,8 @@ MIGRAPHX_EXPORT void quantize_int8(program& prog,
...
@@ -46,6 +46,8 @@ MIGRAPHX_EXPORT void quantize_int8(program& prog,
const
std
::
vector
<
parameter_map
>&
calibration
,
const
std
::
vector
<
parameter_map
>&
calibration
,
const
std
::
vector
<
std
::
string
>&
ins_names
=
{
"dot"
,
const
std
::
vector
<
std
::
string
>&
ins_names
=
{
"dot"
,
"convolution"
});
"convolution"
});
MIGRAPHX_EXPORT
void
quantize_fp8
(
program
&
prog
,
const
target
&
t
,
const
std
::
vector
<
parameter_map
>&
calibration
);
}
// namespace MIGRAPHX_INLINE_NS
}
// namespace MIGRAPHX_INLINE_NS
}
// namespace migraphx
}
// namespace migraphx
...
...
src/include/migraphx/quantize_
int8
.hpp
→
src/include/migraphx/quantize_
8bits
.hpp
View file @
db3c07fb
/*
/*
* The MIT License (MIT)
* The MIT License (MIT)
*
*
* Copyright (c) 2015-202
2
Advanced Micro Devices, Inc. All rights reserved.
* Copyright (c) 2015-202
3
Advanced Micro Devices, Inc. All rights reserved.
*
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
* of this software and associated documentation files (the "Software"), to deal
...
@@ -21,8 +21,8 @@
...
@@ -21,8 +21,8 @@
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
* THE SOFTWARE.
* THE SOFTWARE.
*/
*/
#ifndef MIGRAPHX_GUARD_RTGLIB_QUANTIZE_
INT8
_HPP
#ifndef MIGRAPHX_GUARD_RTGLIB_QUANTIZE_
8BITS
_HPP
#define MIGRAPHX_GUARD_RTGLIB_QUANTIZE_
INT8
_HPP
#define MIGRAPHX_GUARD_RTGLIB_QUANTIZE_
8BITS
_HPP
#include <string>
#include <string>
#include <vector>
#include <vector>
...
@@ -37,7 +37,7 @@ struct program;
...
@@ -37,7 +37,7 @@ struct program;
struct
module
;
struct
module
;
/**
/**
* capture inputs of operators to be quantized to int8
* capture inputs of operators to be quantized to int8
or fp8
*/
*/
struct
MIGRAPHX_EXPORT
capture_arguments_pass
struct
MIGRAPHX_EXPORT
capture_arguments_pass
{
{
...
@@ -49,13 +49,14 @@ struct MIGRAPHX_EXPORT capture_arguments_pass
...
@@ -49,13 +49,14 @@ struct MIGRAPHX_EXPORT capture_arguments_pass
};
};
/**
/**
* quantize a program to int8
* quantize a program to int8
or fp8
*/
*/
struct
MIGRAPHX_EXPORT
quantize_
int8
_pass
struct
MIGRAPHX_EXPORT
quantize_
8bits
_pass
{
{
shape
::
type_t
precision
=
shape
::
int8_type
;
std
::
vector
<
std
::
string
>
ins_names
=
{
"dot"
,
"convolution"
};
std
::
vector
<
std
::
string
>
ins_names
=
{
"dot"
,
"convolution"
};
std
::
vector
<
std
::
pair
<
float
,
float
>>
quant_params
;
std
::
vector
<
std
::
pair
<
float
,
float
>>
quant_params
;
std
::
string
name
()
const
{
return
"quantize_
int8
"
;
}
std
::
string
name
()
const
{
return
"quantize_
8bits
"
;
}
void
apply
(
module
&
m
)
const
;
void
apply
(
module
&
m
)
const
;
};
};
...
...
src/quantization.cpp
View file @
db3c07fb
/*
/*
* The MIT License (MIT)
* The MIT License (MIT)
*
*
* Copyright (c) 2015-202
2
Advanced Micro Devices, Inc. All rights reserved.
* Copyright (c) 2015-202
3
Advanced Micro Devices, Inc. All rights reserved.
*
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
* of this software and associated documentation files (the "Software"), to deal
...
@@ -25,7 +25,7 @@
...
@@ -25,7 +25,7 @@
#include <migraphx/instruction_ref.hpp>
#include <migraphx/instruction_ref.hpp>
#include <migraphx/quantization.hpp>
#include <migraphx/quantization.hpp>
#include <migraphx/quantize_fp16.hpp>
#include <migraphx/quantize_fp16.hpp>
#include <migraphx/quantize_
int8
.hpp>
#include <migraphx/quantize_
8bits
.hpp>
#include <migraphx/simplify_reshapes.hpp>
#include <migraphx/simplify_reshapes.hpp>
#include <migraphx/simplify_qdq.hpp>
#include <migraphx/simplify_qdq.hpp>
#include <migraphx/eliminate_common_subexpression.hpp>
#include <migraphx/eliminate_common_subexpression.hpp>
...
@@ -45,7 +45,7 @@
...
@@ -45,7 +45,7 @@
namespace
migraphx
{
namespace
migraphx
{
inline
namespace
MIGRAPHX_INLINE_NS
{
inline
namespace
MIGRAPHX_INLINE_NS
{
MIGRAPHX_DECLARE_ENV_VAR
(
MIGRAPHX_
INT8
_QUANTIZATION_PARAMS
)
MIGRAPHX_DECLARE_ENV_VAR
(
MIGRAPHX_
8BITS
_QUANTIZATION_PARAMS
)
// This function is to convert any instructions specified in the input
// This function is to convert any instructions specified in the input
// from double or float to float16 by inserting a convert operator.
// from double or float to float16 by inserting a convert operator.
...
@@ -57,29 +57,22 @@ void quantize_fp16(program& prog, const std::vector<std::string>& ins_names)
...
@@ -57,29 +57,22 @@ void quantize_fp16(program& prog, const std::vector<std::string>& ins_names)
run_passes
(
prog
,
{
optimize_module
{},
quantize_fp16_pass
{
ins_names
},
optimize_module
{}});
run_passes
(
prog
,
{
optimize_module
{},
quantize_fp16_pass
{
ins_names
},
optimize_module
{}});
}
}
void
quantize_
int8
(
program
&
prog
,
void
quantize_
8bits
(
program
&
prog
,
const
target
&
t
,
const
target
&
t
,
shape
::
type_t
precision
,
const
std
::
vector
<
parameter_map
>&
calibration
,
const
std
::
vector
<
parameter_map
>&
calibration
,
const
std
::
vector
<
std
::
string
>&
ins_names
)
const
std
::
vector
<
std
::
string
>&
ins_names
)
{
{
std
::
set
<
std
::
string
>
op_names
=
{
"convolution"
,
"dot"
};
// Run optimize_module() before converting to int8/fp8 to const eval and fold in FP32 to
std
::
set
<
std
::
string
>
input_ins_names
(
ins_names
.
begin
(),
ins_names
.
end
());
if
(
not
std
::
includes
(
op_names
.
begin
(),
op_names
.
end
(),
input_ins_names
.
begin
(),
input_ins_names
.
end
()))
{
MIGRAPHX_THROW
(
"QUANTIZE_INT8: only support DOT and CONVOLUTION operation"
);
}
// Run optimize_module() before converting to int8 to const eval and fold in FP32 to
// avoid loss of precision.
// avoid loss of precision.
run_passes
(
prog
,
{
optimize_module
{}});
run_passes
(
prog
,
{
optimize_module
{}});
std
::
shared_ptr
<
std
::
vector
<
std
::
pair
<
float
,
float
>>>
int8_quan
t_params
=
std
::
shared_ptr
<
std
::
vector
<
std
::
pair
<
float
,
float
>>>
quant_8bi
t_params
=
std
::
make_shared
<
std
::
vector
<
std
::
pair
<
float
,
float
>>>
();
std
::
make_shared
<
std
::
vector
<
std
::
pair
<
float
,
float
>>>
();
std
::
shared_ptr
<
std
::
vector
<
float
>>
max_abs_vals
=
std
::
make_shared
<
std
::
vector
<
float
>>
();
std
::
shared_ptr
<
std
::
vector
<
float
>>
max_abs_vals
=
std
::
make_shared
<
std
::
vector
<
float
>>
();
auto
calc_quant_params
=
[
int8_quant_params
,
max_abs_vals
,
&
t
](
std
::
size_t
ins_index
,
float
quantized_range
=
(
precision
==
shape
::
type_t
::
int8_type
)
?
127.0
:
240.0
;
std
::
vector
<
argument
>
args
)
{
auto
calc_quant_params
=
[
&
](
std
::
size_t
ins_index
,
std
::
vector
<
argument
>
args
)
{
std
::
pair
<
float
,
float
>
param_pair
{
64.0
f
,
0.0
f
};
std
::
pair
<
float
,
float
>
param_pair
{
64.0
f
,
0.0
f
};
// scale and shift is need for only int8 type, and we do not
// scale and shift is need for only int8 type, and we do not
// consider shift, so set shift to 0
// consider shift, so set shift to 0
...
@@ -90,23 +83,22 @@ void quantize_int8(program& prog,
...
@@ -90,23 +83,22 @@ void quantize_int8(program& prog,
auto
min_val
=
*
std
::
min_element
(
vec_val
.
begin
(),
vec_val
.
end
());
auto
min_val
=
*
std
::
min_element
(
vec_val
.
begin
(),
vec_val
.
end
());
auto
max_abs
=
std
::
max
(
std
::
fabs
(
max_val
),
std
::
fabs
(
min_val
));
auto
max_abs
=
std
::
max
(
std
::
fabs
(
max_val
),
std
::
fabs
(
min_val
));
max_abs_vals
->
at
(
ins_index
)
=
std
::
max
(
max_abs_vals
->
at
(
ins_index
),
max_abs
);
max_abs_vals
->
at
(
ins_index
)
=
std
::
max
(
max_abs_vals
->
at
(
ins_index
),
max_abs
);
// if all values are 0, no need to do scaling
// if all values are 0, no need to do scaling
if
(
max_abs_vals
->
at
(
ins_index
)
==
0.0
f
)
if
(
float_equal
(
max_abs_vals
->
at
(
ins_index
)
,
0.0
f
)
)
{
{
param_pair
.
first
=
1.0
f
;
param_pair
.
first
=
1.0
f
;
}
}
else
else
{
{
param_pair
.
first
=
127.0
f
/
max_abs_vals
->
at
(
ins_index
);
param_pair
.
first
=
quantized_range
/
max_abs_vals
->
at
(
ins_index
);
}
}
int8_quan
t_params
->
at
(
ins_index
)
=
param_pair
;
quant_8bi
t_params
->
at
(
ins_index
)
=
param_pair
;
};
};
// pass to add capture argument op
// pass to add capture argument op
std
::
size_t
param_num
=
0
;
std
::
size_t
param_num
=
0
;
run_passes
(
prog
,
{
capture_arguments_pass
{
ins_names
,
calc_quant_params
,
&
param_num
}});
run_passes
(
prog
,
{
capture_arguments_pass
{
ins_names
,
calc_quant_params
,
&
param_num
}});
int8_quan
t_params
->
resize
(
param_num
,
std
::
pair
<
float
,
float
>
(
64.0
f
,
0.0
f
));
quant_8bi
t_params
->
resize
(
param_num
,
std
::
pair
<
float
,
float
>
(
64.0
f
,
0.0
f
));
max_abs_vals
->
resize
(
param_num
,
0.0
f
);
max_abs_vals
->
resize
(
param_num
,
0.0
f
);
// use the calibration data to compute the quantization scale
// use the calibration data to compute the quantization scale
...
@@ -134,11 +126,11 @@ void quantize_int8(program& prog,
...
@@ -134,11 +126,11 @@ void quantize_int8(program& prog,
}
}
// print the quantization parameters in only the main module
// print the quantization parameters in only the main module
if
(
enabled
(
MIGRAPHX_
INT8
_QUANTIZATION_PARAMS
{}))
if
(
enabled
(
MIGRAPHX_
8BITS
_QUANTIZATION_PARAMS
{}))
{
{
for
(
std
::
size_t
i
=
0
;
i
<
int8_quan
t_params
->
size
();
++
i
)
for
(
std
::
size_t
i
=
0
;
i
<
quant_8bi
t_params
->
size
();
++
i
)
{
{
auto
param
=
int8_quan
t_params
->
at
(
i
);
auto
param
=
quant_8bi
t_params
->
at
(
i
);
std
::
cout
<<
"ins_index = "
<<
i
<<
", scale = "
<<
param
.
first
std
::
cout
<<
"ins_index = "
<<
i
<<
", scale = "
<<
param
.
first
<<
", shift = "
<<
param
.
second
<<
std
::
endl
;
<<
", shift = "
<<
param
.
second
<<
std
::
endl
;
}
}
...
@@ -146,11 +138,46 @@ void quantize_int8(program& prog,
...
@@ -146,11 +138,46 @@ void quantize_int8(program& prog,
}
}
run_passes
(
prog
,
run_passes
(
prog
,
{
quantize_
int8
_pass
{
ins_names
,
*
int8_quan
t_params
},
{
quantize_
8bits
_pass
{
precision
,
ins_names
,
*
quant_8bi
t_params
},
simplify_qdq
{},
simplify_qdq
{},
optimize_module
{},
optimize_module
{},
dead_code_elimination
{}});
dead_code_elimination
{}});
}
}
void
quantize_int8
(
program
&
prog
,
const
target
&
t
,
const
std
::
vector
<
parameter_map
>&
calibration
,
const
std
::
vector
<
std
::
string
>&
ins_names
)
{
std
::
set
<
std
::
string
>
op_names
=
{
"convolution"
,
"dot"
};
std
::
set
<
std
::
string
>
input_ins_names
(
ins_names
.
begin
(),
ins_names
.
end
());
if
(
not
std
::
includes
(
op_names
.
begin
(),
op_names
.
end
(),
input_ins_names
.
begin
(),
input_ins_names
.
end
()))
{
MIGRAPHX_THROW
(
"QUANTIZE_INT8: only support DOT and CONVOLUTION operation"
);
}
quantize_8bits
(
prog
,
t
,
shape
::
int8_type
,
calibration
,
ins_names
);
}
void
quantize_fp8
(
program
&
prog
,
const
target
&
t
,
const
std
::
vector
<
parameter_map
>&
calibration
)
{
std
::
cout
<<
"[Warning] : MIGraphX has BETA support for FP8. Using FP8 may result in "
"incorrect final outputs
\n
"
;
std
::
vector
<
std
::
string
>
supported_ins_names
;
auto
*
mm
=
prog
.
get_main_module
();
for
(
auto
ins
:
iterator_for
(
*
mm
))
{
if
(
ins
->
name
()
==
"convert"
)
{
continue
;
}
else
if
(
not
starts_with
(
ins
->
name
(),
"@"
))
{
supported_ins_names
.
push_back
(
ins
->
name
());
}
}
quantize_8bits
(
prog
,
t
,
shape
::
fp8e4m3fnuz_type
,
calibration
,
supported_ins_names
);
}
}
// namespace MIGRAPHX_INLINE_NS
}
// namespace MIGRAPHX_INLINE_NS
}
// namespace migraphx
}
// namespace migraphx
src/quantize_
int8
.cpp
→
src/quantize_
8bits
.cpp
View file @
db3c07fb
/*
/*
* The MIT License (MIT)
* The MIT License (MIT)
*
*
* Copyright (c) 2015-202
2
Advanced Micro Devices, Inc. All rights reserved.
* Copyright (c) 2015-202
3
Advanced Micro Devices, Inc. All rights reserved.
*
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
* of this software and associated documentation files (the "Software"), to deal
...
@@ -25,7 +25,7 @@
...
@@ -25,7 +25,7 @@
#include <migraphx/float_equal.hpp>
#include <migraphx/float_equal.hpp>
#include <migraphx/instruction_ref.hpp>
#include <migraphx/instruction_ref.hpp>
#include <migraphx/quantization.hpp>
#include <migraphx/quantization.hpp>
#include <migraphx/quantize_
int8
.hpp>
#include <migraphx/quantize_
8bits
.hpp>
#include <migraphx/program.hpp>
#include <migraphx/program.hpp>
#include <migraphx/instruction.hpp>
#include <migraphx/instruction.hpp>
#include <migraphx/iterator_for.hpp>
#include <migraphx/iterator_for.hpp>
...
@@ -41,8 +41,6 @@
...
@@ -41,8 +41,6 @@
namespace
migraphx
{
namespace
migraphx
{
inline
namespace
MIGRAPHX_INLINE_NS
{
inline
namespace
MIGRAPHX_INLINE_NS
{
MIGRAPHX_DECLARE_ENV_VAR
(
MIGRAPHX_INT8_QUANTIZATION_PARAMS
)
static
std
::
vector
<
shape
::
type_t
>&
get_quantizable_type
()
static
std
::
vector
<
shape
::
type_t
>&
get_quantizable_type
()
{
{
static
std
::
vector
<
shape
::
type_t
>
quantable_types
=
{
static
std
::
vector
<
shape
::
type_t
>
quantable_types
=
{
...
@@ -50,7 +48,7 @@ static std::vector<shape::type_t>& get_quantizable_type()
...
@@ -50,7 +48,7 @@ static std::vector<shape::type_t>& get_quantizable_type()
return
quantable_types
;
return
quantable_types
;
}
}
void
quantize_
int8
_pass
::
apply
(
module
&
m
)
const
// NOLINT
void
quantize_
8bits
_pass
::
apply
(
module
&
m
)
const
// NOLINT
{
{
const
auto
&
quantizable_types
=
get_quantizable_type
();
const
auto
&
quantizable_types
=
get_quantizable_type
();
for
(
auto
ins
:
iterator_for
(
m
))
for
(
auto
ins
:
iterator_for
(
m
))
...
@@ -66,9 +64,10 @@ void quantize_int8_pass::apply(module& m) const // NOLINT
...
@@ -66,9 +64,10 @@ void quantize_int8_pass::apply(module& m) const // NOLINT
auto
input
=
ins
->
inputs
().
front
();
auto
input
=
ins
->
inputs
().
front
();
auto
s
=
input
->
get_shape
();
auto
s
=
input
->
get_shape
();
if
(
contains
(
quantizable_types
,
s
.
type
())
and
s
.
type
()
!=
shape
::
int8_type
)
if
(
contains
(
quantizable_types
,
s
.
type
())
and
s
.
type
()
!=
precision
)
{
{
auto
zero_point
=
m
.
add_literal
(
static_cast
<
int8_t
>
(
param
.
second
));
auto
zero_point
=
m
.
add_literal
(
migraphx
::
literal
{
migraphx
::
shape
{
precision
},
{
param
.
second
}});
auto
scale
=
m
.
add_literal
(
literal
({
s
.
type
()},
{
1.0
f
/
param
.
first
}));
auto
scale
=
m
.
add_literal
(
literal
({
s
.
type
()},
{
1.0
f
/
param
.
first
}));
const
auto
&
lens
=
s
.
lens
();
const
auto
&
lens
=
s
.
lens
();
scale
=
scale
=
...
@@ -87,20 +86,33 @@ void quantize_int8_pass::apply(module& m) const // NOLINT
...
@@ -87,20 +86,33 @@ void quantize_int8_pass::apply(module& m) const // NOLINT
void
capture_arguments_pass
::
apply
(
module
&
m
)
const
// NOLINT
void
capture_arguments_pass
::
apply
(
module
&
m
)
const
// NOLINT
{
{
assert
(
param_index
!=
nullptr
);
assert
(
param_index
!=
nullptr
);
const
auto
&
quantizable_types
=
get_quantizable_type
();
for
(
auto
ins
:
iterator_for
(
m
))
for
(
auto
ins
:
iterator_for
(
m
))
{
{
if
(
not
contains
(
ins_names
,
ins
->
name
()))
if
(
not
contains
(
ins_names
,
ins
->
name
()))
{
{
continue
;
continue
;
}
}
if
(
ins
->
name
()
==
"convert"
)
{
continue
;
}
auto
inputs
=
ins
->
inputs
();
auto
inputs
=
ins
->
inputs
();
std
::
vector
<
instruction_ref
>
new_args
;
std
::
vector
<
instruction_ref
>
new_args
;
for
(
auto
input
:
inputs
)
for
(
auto
input
:
inputs
)
{
if
(
contains
(
quantizable_types
,
input
->
get_shape
().
type
()))
{
{
auto
new_in
=
m
.
insert_instruction
(
ins
,
op
::
capture
{(
*
param_index
)
++
,
f
},
input
);
auto
new_in
=
m
.
insert_instruction
(
ins
,
op
::
capture
{(
*
param_index
)
++
,
f
},
input
);
new_args
.
push_back
(
new_in
);
new_args
.
push_back
(
new_in
);
}
}
else
{
new_args
.
push_back
(
input
);
}
}
m
.
replace_instruction
(
ins
,
ins
->
get_operator
(),
new_args
);
m
.
replace_instruction
(
ins
,
ins
->
get_operator
(),
new_args
);
}
}
}
}
...
...
src/simplify_qdq.cpp
View file @
db3c07fb
...
@@ -210,9 +210,15 @@ bool compare_literals(instruction_ref ins1, instruction_ref ins2)
...
@@ -210,9 +210,15 @@ bool compare_literals(instruction_ref ins1, instruction_ref ins2)
bool
diff_shapes_equal_vals
=
false
;
bool
diff_shapes_equal_vals
=
false
;
visit_all
(
ins1
->
get_literal
(),
ins2
->
get_literal
())([
&
](
const
auto
l1
,
const
auto
l2
)
{
visit_all
(
ins1
->
get_literal
(),
ins2
->
get_literal
())([
&
](
const
auto
l1
,
const
auto
l2
)
{
diff_shapes_equal_vals
=
diff_shapes_equal_vals
=
std
::
all_of
(
std
::
all_of
(
l1
.
begin
()
+
1
,
l1
.
begin
()
+
1
,
l1
.
end
(),
[
&
](
auto
v
)
{
return
float_equal
(
v
,
l1
.
front
());
})
and
l1
.
end
(),
std
::
all_of
(
l2
.
begin
(),
l2
.
end
(),
[
&
](
auto
v
)
{
return
float_equal
(
v
,
l1
.
front
());
});
[
&
](
auto
v
)
{
return
((
float_equal
(
v
,
l1
.
front
()))
or
(
std
::
isinf
(
l1
.
front
())
and
std
::
isinf
(
v
)));
})
and
std
::
all_of
(
l2
.
begin
(),
l2
.
end
(),
[
&
](
auto
v
)
{
return
((
float_equal
(
v
,
l1
.
front
()))
or
(
std
::
isinf
(
l1
.
front
())
and
std
::
isinf
(
v
)));
});
});
});
return
(
x
==
y
)
or
diff_shapes_equal_vals
;
return
(
x
==
y
)
or
diff_shapes_equal_vals
;
...
...
test/quantization.cpp
View file @
db3c07fb
...
@@ -30,7 +30,7 @@
...
@@ -30,7 +30,7 @@
#include <migraphx/verify.hpp>
#include <migraphx/verify.hpp>
#include <migraphx/apply_alpha_beta.hpp>
#include <migraphx/apply_alpha_beta.hpp>
#include <migraphx/quantization.hpp>
#include <migraphx/quantization.hpp>
#include <migraphx/quantize_
int8
.hpp>
#include <migraphx/quantize_
8bits
.hpp>
#include <migraphx/quantize_fp16.hpp>
#include <migraphx/quantize_fp16.hpp>
#include <migraphx/dead_code_elimination.hpp>
#include <migraphx/dead_code_elimination.hpp>
#include <migraphx/simplify_reshapes.hpp>
#include <migraphx/simplify_reshapes.hpp>
...
@@ -654,7 +654,8 @@ TEST_CASE(dot_float)
...
@@ -654,7 +654,8 @@ TEST_CASE(dot_float)
migraphx
::
run_passes
(
p
,
{
migraphx
::
capture_arguments_pass
{{
"dot"
},
{},
&
param_index
}});
migraphx
::
run_passes
(
p
,
{
migraphx
::
capture_arguments_pass
{{
"dot"
},
{},
&
param_index
}});
migraphx
::
run_passes
(
migraphx
::
run_passes
(
p
,
p
,
{
migraphx
::
quantize_int8_pass
{{
"dot"
},
quant_params
},
migraphx
::
dead_code_elimination
{}});
{
migraphx
::
quantize_8bits_pass
{
migraphx
::
shape
::
type_t
::
int8_type
,
{
"dot"
},
quant_params
},
migraphx
::
dead_code_elimination
{}});
auto
qp
=
create_int8_quantized_prog
();
auto
qp
=
create_int8_quantized_prog
();
EXPECT
(
p
==
qp
);
EXPECT
(
p
==
qp
);
...
@@ -748,7 +749,8 @@ TEST_CASE(dot_double_2args)
...
@@ -748,7 +749,8 @@ TEST_CASE(dot_double_2args)
migraphx
::
run_passes
(
p
,
{
migraphx
::
capture_arguments_pass
{{
"dot"
},
{},
&
param_index
}});
migraphx
::
run_passes
(
p
,
{
migraphx
::
capture_arguments_pass
{{
"dot"
},
{},
&
param_index
}});
migraphx
::
run_passes
(
migraphx
::
run_passes
(
p
,
p
,
{
migraphx
::
quantize_int8_pass
{{
"dot"
},
quant_params
},
migraphx
::
dead_code_elimination
{}});
{
migraphx
::
quantize_8bits_pass
{
migraphx
::
shape
::
type_t
::
int8_type
,
{
"dot"
},
quant_params
},
migraphx
::
dead_code_elimination
{}});
EXPECT
(
p
==
create_int8_quantized_prog
());
EXPECT
(
p
==
create_int8_quantized_prog
());
optimize_prog_int8
(
p
);
optimize_prog_int8
(
p
);
...
@@ -821,7 +823,8 @@ TEST_CASE(dot_half_1arg)
...
@@ -821,7 +823,8 @@ TEST_CASE(dot_half_1arg)
migraphx
::
run_passes
(
p
,
{
migraphx
::
capture_arguments_pass
{{
"dot"
},
{},
&
param_index
}});
migraphx
::
run_passes
(
p
,
{
migraphx
::
capture_arguments_pass
{{
"dot"
},
{},
&
param_index
}});
migraphx
::
run_passes
(
migraphx
::
run_passes
(
p
,
p
,
{
migraphx
::
quantize_int8_pass
{{
"dot"
},
quant_params
},
migraphx
::
dead_code_elimination
{}});
{
migraphx
::
quantize_8bits_pass
{
migraphx
::
shape
::
int8_type
,
{
"dot"
},
quant_params
},
migraphx
::
dead_code_elimination
{}});
EXPECT
(
p
==
create_int8_quantized_prog
());
EXPECT
(
p
==
create_int8_quantized_prog
());
optimize_prog_int8
(
p
);
optimize_prog_int8
(
p
);
...
@@ -876,7 +879,9 @@ TEST_CASE(conv_float)
...
@@ -876,7 +879,9 @@ TEST_CASE(conv_float)
const
std
::
vector
<
std
::
pair
<
float
,
float
>>&
quant_params
{{
0.1
f
,
0.0
f
},
{
0.1
f
,
0.0
f
}};
const
std
::
vector
<
std
::
pair
<
float
,
float
>>&
quant_params
{{
0.1
f
,
0.0
f
},
{
0.1
f
,
0.0
f
}};
std
::
size_t
param_index
=
0
;
std
::
size_t
param_index
=
0
;
migraphx
::
run_passes
(
p
,
{
migraphx
::
capture_arguments_pass
{{
"convolution"
},
{},
&
param_index
}});
migraphx
::
run_passes
(
p
,
{
migraphx
::
capture_arguments_pass
{{
"convolution"
},
{},
&
param_index
}});
migraphx
::
run_passes
(
p
,
{
migraphx
::
quantize_int8_pass
{{
"convolution"
},
quant_params
}});
migraphx
::
run_passes
(
p
,
{
migraphx
::
quantize_8bits_pass
{
migraphx
::
shape
::
type_t
::
int8_type
,
{
"convolution"
},
quant_params
}});
optimize_prog_int8
(
p
);
optimize_prog_int8
(
p
);
auto
qp
=
create_int8_quantized_prog
();
auto
qp
=
create_int8_quantized_prog
();
...
@@ -901,7 +906,9 @@ TEST_CASE(conv_float_throw)
...
@@ -901,7 +906,9 @@ TEST_CASE(conv_float_throw)
auto
p
=
create_program
();
auto
p
=
create_program
();
const
std
::
vector
<
std
::
pair
<
float
,
float
>>&
quant_params
{{
0.1
f
,
0.0
f
},
{
0.1
f
,
0.0
f
}};
const
std
::
vector
<
std
::
pair
<
float
,
float
>>&
quant_params
{{
0.1
f
,
0.0
f
},
{
0.1
f
,
0.0
f
}};
test
::
throws
([
&
]
{
test
::
throws
([
&
]
{
migraphx
::
run_passes
(
p
,
{
migraphx
::
quantize_int8_pass
{{
"add"
},
quant_params
}});
migraphx
::
run_passes
(
p
,
{
migraphx
::
quantize_8bits_pass
{
migraphx
::
shape
::
type_t
::
int8_type
,
{
"add"
},
quant_params
}});
});
});
}
}
...
@@ -952,7 +959,9 @@ TEST_CASE(conv_half)
...
@@ -952,7 +959,9 @@ TEST_CASE(conv_half)
const
std
::
vector
<
std
::
pair
<
float
,
float
>>&
quant_params
{{
0.1
f
,
0.0
f
},
{
0.1
f
,
0.0
f
}};
const
std
::
vector
<
std
::
pair
<
float
,
float
>>&
quant_params
{{
0.1
f
,
0.0
f
},
{
0.1
f
,
0.0
f
}};
std
::
size_t
param_index
=
0
;
std
::
size_t
param_index
=
0
;
migraphx
::
run_passes
(
p
,
{
migraphx
::
capture_arguments_pass
{{
"convolution"
},
{},
&
param_index
}});
migraphx
::
run_passes
(
p
,
{
migraphx
::
capture_arguments_pass
{{
"convolution"
},
{},
&
param_index
}});
migraphx
::
run_passes
(
p
,
{
migraphx
::
quantize_int8_pass
{{
"convolution"
},
quant_params
}});
migraphx
::
run_passes
(
p
,
{
migraphx
::
quantize_8bits_pass
{
migraphx
::
shape
::
type_t
::
int8_type
,
{
"convolution"
},
quant_params
}});
optimize_prog_int8
(
p
);
optimize_prog_int8
(
p
);
auto
qp
=
create_int8_quantized_prog
();
auto
qp
=
create_int8_quantized_prog
();
...
@@ -1231,7 +1240,10 @@ TEST_CASE(int8_subgraph)
...
@@ -1231,7 +1240,10 @@ TEST_CASE(int8_subgraph)
std
::
size_t
param_index
=
0
;
std
::
size_t
param_index
=
0
;
migraphx
::
run_passes
(
migraphx
::
run_passes
(
p1
,
{
migraphx
::
capture_arguments_pass
{{
"convolution"
,
"dot"
},
{},
&
param_index
}});
p1
,
{
migraphx
::
capture_arguments_pass
{{
"convolution"
,
"dot"
},
{},
&
param_index
}});
migraphx
::
run_passes
(
p1
,
{
migraphx
::
quantize_int8_pass
{{
"convolution"
,
"dot"
},
quant_params
}});
migraphx
::
run_passes
(
p1
,
{
migraphx
::
quantize_8bits_pass
{
migraphx
::
shape
::
type_t
::
int8_type
,
{
"convolution"
,
"dot"
},
quant_params
}});
optimize_prog_int8
(
p1
);
optimize_prog_int8
(
p1
);
auto
p2
=
create_int8_program
();
auto
p2
=
create_int8_program
();
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment