Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
MIGraphX
Commits
db3c07fb
Unverified
Commit
db3c07fb
authored
Dec 12, 2023
by
Umang Yadav
Committed by
GitHub
Dec 12, 2023
Browse files
Add `--fp8` option to quantize models in FP8 inside `migraphx-driver` (#2535)
parent
aac4e950
Changes
11
Hide whitespace changes
Inline
Side-by-side
Showing
11 changed files
with
127 additions
and
57 deletions
+127
-57
docs/dev/env_vars.rst
docs/dev/env_vars.rst
+1
-1
docs/driver/compile.rst
docs/driver/compile.rst
+3
-0
examples/migraphx/migraphx_driver/README.md
examples/migraphx/migraphx_driver/README.md
+1
-0
src/CMakeLists.txt
src/CMakeLists.txt
+1
-1
src/driver/main.cpp
src/driver/main.cpp
+6
-0
src/include/migraphx/quantization.hpp
src/include/migraphx/quantization.hpp
+2
-0
src/include/migraphx/quantize_8bits.hpp
src/include/migraphx/quantize_8bits.hpp
+8
-7
src/quantization.cpp
src/quantization.cpp
+55
-28
src/quantize_8bits.cpp
src/quantize_8bits.cpp
+21
-9
src/simplify_qdq.cpp
src/simplify_qdq.cpp
+9
-3
test/quantization.cpp
test/quantization.cpp
+20
-8
No files found.
docs/dev/env_vars.rst
View file @
db3c07fb
...
...
@@ -82,7 +82,7 @@ Print debug statements for the ``schedule`` pass.
Set to "1", "enable", "enabled", "yes", or "true" to use.
Traces instructions replaced with a constant.
.. envvar:: MIGRAPHX_
INT8
_QUANTIZATION_PARAMS
.. envvar:: MIGRAPHX_
8BITS
_QUANTIZATION_PARAMS
Set to "1", "enable", "enabled", "yes", or "true" to use.
Print the quantization parameters in only the main module.
...
...
docs/driver/compile.rst
View file @
db3c07fb
...
...
@@ -38,3 +38,6 @@ Quantize for fp16
Quantize for int8
.. option:: --fp8
Quantize for Float8E4M3FNUZ type
examples/migraphx/migraphx_driver/README.md
View file @
db3c07fb
...
...
@@ -55,6 +55,7 @@ See below for a comprehensive list of commands and option arguments, as well as
| --exhaustive-tune | Enable exhaustive search to find fastest kernel |
| --fp16 | Quantize for fp16 |
| --int8 | Quantize for int8 |
| --fp8 | Quantize for Float8E4M3FNUZ type |
| --rms-tol | Tolerance for the RMS error (Default: 0.001) |
| --atol | Tolerance for elementwise absolute difference (Default: 0.001) |
| --rtol | Tolerance for elementwise relative difference (Default: 0.001) |
...
...
src/CMakeLists.txt
View file @
db3c07fb
...
...
@@ -81,7 +81,7 @@ add_library(migraphx
promote_literals.cpp
quantization.cpp
quantize_fp16.cpp
quantize_
int8
.cpp
quantize_
8bits
.cpp
reduce_dims.cpp
register_op.cpp
register_target.cpp
...
...
src/driver/main.cpp
View file @
db3c07fb
...
...
@@ -445,6 +445,7 @@ struct compiler
compiler_target
ct
;
compile_options
co
;
bool
to_fp16
=
false
;
bool
to_fp8
=
false
;
bool
to_int8
=
false
;
std
::
vector
<
std
::
string
>
fill0
;
...
...
@@ -468,6 +469,7 @@ struct compiler
ap
.
set_value
(
true
));
ap
(
to_fp16
,
{
"--fp16"
},
ap
.
help
(
"Quantize for fp16"
),
ap
.
set_value
(
true
));
ap
(
to_int8
,
{
"--int8"
},
ap
.
help
(
"Quantize for int8"
),
ap
.
set_value
(
true
));
ap
(
to_fp8
,
{
"--fp8"
},
ap
.
help
(
"Quantize for fp8e4m3fnuz type"
),
ap
.
set_value
(
true
));
}
auto
params
(
const
program
&
p
)
...
...
@@ -518,6 +520,10 @@ struct compiler
{
quantize_int8
(
p
,
t
,
{
host_params
(
p
)});
}
if
(
to_fp8
)
{
quantize_fp8
(
p
,
t
,
{
host_params
(
p
)});
}
p
.
compile
(
t
,
co
);
l
.
save
(
p
);
return
p
;
...
...
src/include/migraphx/quantization.hpp
View file @
db3c07fb
...
...
@@ -46,6 +46,8 @@ MIGRAPHX_EXPORT void quantize_int8(program& prog,
const
std
::
vector
<
parameter_map
>&
calibration
,
const
std
::
vector
<
std
::
string
>&
ins_names
=
{
"dot"
,
"convolution"
});
MIGRAPHX_EXPORT
void
quantize_fp8
(
program
&
prog
,
const
target
&
t
,
const
std
::
vector
<
parameter_map
>&
calibration
);
}
// namespace MIGRAPHX_INLINE_NS
}
// namespace migraphx
...
...
src/include/migraphx/quantize_
int8
.hpp
→
src/include/migraphx/quantize_
8bits
.hpp
View file @
db3c07fb
/*
* The MIT License (MIT)
*
* Copyright (c) 2015-202
2
Advanced Micro Devices, Inc. All rights reserved.
* Copyright (c) 2015-202
3
Advanced Micro Devices, Inc. All rights reserved.
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
...
...
@@ -21,8 +21,8 @@
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
* THE SOFTWARE.
*/
#ifndef MIGRAPHX_GUARD_RTGLIB_QUANTIZE_
INT8
_HPP
#define MIGRAPHX_GUARD_RTGLIB_QUANTIZE_
INT8
_HPP
#ifndef MIGRAPHX_GUARD_RTGLIB_QUANTIZE_
8BITS
_HPP
#define MIGRAPHX_GUARD_RTGLIB_QUANTIZE_
8BITS
_HPP
#include <string>
#include <vector>
...
...
@@ -37,7 +37,7 @@ struct program;
struct
module
;
/**
* capture inputs of operators to be quantized to int8
* capture inputs of operators to be quantized to int8
or fp8
*/
struct
MIGRAPHX_EXPORT
capture_arguments_pass
{
...
...
@@ -49,13 +49,14 @@ struct MIGRAPHX_EXPORT capture_arguments_pass
};
/**
* quantize a program to int8
* quantize a program to int8
or fp8
*/
struct
MIGRAPHX_EXPORT
quantize_
int8
_pass
struct
MIGRAPHX_EXPORT
quantize_
8bits
_pass
{
shape
::
type_t
precision
=
shape
::
int8_type
;
std
::
vector
<
std
::
string
>
ins_names
=
{
"dot"
,
"convolution"
};
std
::
vector
<
std
::
pair
<
float
,
float
>>
quant_params
;
std
::
string
name
()
const
{
return
"quantize_
int8
"
;
}
std
::
string
name
()
const
{
return
"quantize_
8bits
"
;
}
void
apply
(
module
&
m
)
const
;
};
...
...
src/quantization.cpp
View file @
db3c07fb
/*
* The MIT License (MIT)
*
* Copyright (c) 2015-202
2
Advanced Micro Devices, Inc. All rights reserved.
* Copyright (c) 2015-202
3
Advanced Micro Devices, Inc. All rights reserved.
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
...
...
@@ -25,7 +25,7 @@
#include <migraphx/instruction_ref.hpp>
#include <migraphx/quantization.hpp>
#include <migraphx/quantize_fp16.hpp>
#include <migraphx/quantize_
int8
.hpp>
#include <migraphx/quantize_
8bits
.hpp>
#include <migraphx/simplify_reshapes.hpp>
#include <migraphx/simplify_qdq.hpp>
#include <migraphx/eliminate_common_subexpression.hpp>
...
...
@@ -45,7 +45,7 @@
namespace
migraphx
{
inline
namespace
MIGRAPHX_INLINE_NS
{
MIGRAPHX_DECLARE_ENV_VAR
(
MIGRAPHX_
INT8
_QUANTIZATION_PARAMS
)
MIGRAPHX_DECLARE_ENV_VAR
(
MIGRAPHX_
8BITS
_QUANTIZATION_PARAMS
)
// This function is to convert any instructions specified in the input
// from double or float to float16 by inserting a convert operator.
...
...
@@ -57,29 +57,22 @@ void quantize_fp16(program& prog, const std::vector<std::string>& ins_names)
run_passes
(
prog
,
{
optimize_module
{},
quantize_fp16_pass
{
ins_names
},
optimize_module
{}});
}
void
quantize_int8
(
program
&
prog
,
const
target
&
t
,
const
std
::
vector
<
parameter_map
>&
calibration
,
const
std
::
vector
<
std
::
string
>&
ins_names
)
void
quantize_8bits
(
program
&
prog
,
const
target
&
t
,
shape
::
type_t
precision
,
const
std
::
vector
<
parameter_map
>&
calibration
,
const
std
::
vector
<
std
::
string
>&
ins_names
)
{
std
::
set
<
std
::
string
>
op_names
=
{
"convolution"
,
"dot"
};
std
::
set
<
std
::
string
>
input_ins_names
(
ins_names
.
begin
(),
ins_names
.
end
());
if
(
not
std
::
includes
(
op_names
.
begin
(),
op_names
.
end
(),
input_ins_names
.
begin
(),
input_ins_names
.
end
()))
{
MIGRAPHX_THROW
(
"QUANTIZE_INT8: only support DOT and CONVOLUTION operation"
);
}
// Run optimize_module() before converting to int8 to const eval and fold in FP32 to
// Run optimize_module() before converting to int8/fp8 to const eval and fold in FP32 to
// avoid loss of precision.
run_passes
(
prog
,
{
optimize_module
{}});
std
::
shared_ptr
<
std
::
vector
<
std
::
pair
<
float
,
float
>>>
int8_quan
t_params
=
std
::
shared_ptr
<
std
::
vector
<
std
::
pair
<
float
,
float
>>>
quant_8bi
t_params
=
std
::
make_shared
<
std
::
vector
<
std
::
pair
<
float
,
float
>>>
();
std
::
shared_ptr
<
std
::
vector
<
float
>>
max_abs_vals
=
std
::
make_shared
<
std
::
vector
<
float
>>
();
auto
calc_quant_params
=
[
int8_quant_params
,
max_abs_vals
,
&
t
](
std
::
size_t
ins_index
,
std
::
vector
<
argument
>
args
)
{
float
quantized_range
=
(
precision
==
shape
::
type_t
::
int8_type
)
?
127.0
:
240.0
;
auto
calc_quant_params
=
[
&
](
std
::
size_t
ins_index
,
std
::
vector
<
argument
>
args
)
{
std
::
pair
<
float
,
float
>
param_pair
{
64.0
f
,
0.0
f
};
// scale and shift are needed only for the int8 type, and we do not
// consider shift, so set shift to 0
...
...
@@ -90,23 +83,22 @@ void quantize_int8(program& prog,
auto
min_val
=
*
std
::
min_element
(
vec_val
.
begin
(),
vec_val
.
end
());
auto
max_abs
=
std
::
max
(
std
::
fabs
(
max_val
),
std
::
fabs
(
min_val
));
max_abs_vals
->
at
(
ins_index
)
=
std
::
max
(
max_abs_vals
->
at
(
ins_index
),
max_abs
);
// if all values are 0, no need to do scaling
if
(
max_abs_vals
->
at
(
ins_index
)
==
0.0
f
)
if
(
float_equal
(
max_abs_vals
->
at
(
ins_index
)
,
0.0
f
)
)
{
param_pair
.
first
=
1.0
f
;
}
else
{
param_pair
.
first
=
127.0
f
/
max_abs_vals
->
at
(
ins_index
);
param_pair
.
first
=
quantized_range
/
max_abs_vals
->
at
(
ins_index
);
}
int8_quan
t_params
->
at
(
ins_index
)
=
param_pair
;
quant_8bi
t_params
->
at
(
ins_index
)
=
param_pair
;
};
// pass to add capture argument op
std
::
size_t
param_num
=
0
;
run_passes
(
prog
,
{
capture_arguments_pass
{
ins_names
,
calc_quant_params
,
&
param_num
}});
int8_quan
t_params
->
resize
(
param_num
,
std
::
pair
<
float
,
float
>
(
64.0
f
,
0.0
f
));
quant_8bi
t_params
->
resize
(
param_num
,
std
::
pair
<
float
,
float
>
(
64.0
f
,
0.0
f
));
max_abs_vals
->
resize
(
param_num
,
0.0
f
);
// use the calibration data to compute the quantization scale
...
...
@@ -134,11 +126,11 @@ void quantize_int8(program& prog,
}
// print the quantization parameters in only the main module
if
(
enabled
(
MIGRAPHX_
INT8
_QUANTIZATION_PARAMS
{}))
if
(
enabled
(
MIGRAPHX_
8BITS
_QUANTIZATION_PARAMS
{}))
{
for
(
std
::
size_t
i
=
0
;
i
<
int8_quan
t_params
->
size
();
++
i
)
for
(
std
::
size_t
i
=
0
;
i
<
quant_8bi
t_params
->
size
();
++
i
)
{
auto
param
=
int8_quan
t_params
->
at
(
i
);
auto
param
=
quant_8bi
t_params
->
at
(
i
);
std
::
cout
<<
"ins_index = "
<<
i
<<
", scale = "
<<
param
.
first
<<
", shift = "
<<
param
.
second
<<
std
::
endl
;
}
...
...
@@ -146,11 +138,46 @@ void quantize_int8(program& prog,
}
run_passes
(
prog
,
{
quantize_
int8
_pass
{
ins_names
,
*
int8_quan
t_params
},
{
quantize_
8bits
_pass
{
precision
,
ins_names
,
*
quant_8bi
t_params
},
simplify_qdq
{},
optimize_module
{},
dead_code_elimination
{}});
}
/**
 * Quantize the requested instructions of a program to int8.
 *
 * Every entry of ins_names must be either "dot" or "convolution"; if any
 * other name is present, MIGRAPHX_THROW is raised before quantize_8bits()
 * is called. Otherwise the work is forwarded to quantize_8bits() with
 * shape::int8_type as the target precision.
 *
 * @param prog        program to quantize (forwarded to quantize_8bits)
 * @param t           target forwarded to quantize_8bits
 * @param calibration calibration parameter maps forwarded to quantize_8bits
 * @param ins_names   names of the instructions to quantize
 */
void quantize_int8(program& prog,
                   const target& t,
                   const std::vector<parameter_map>& calibration,
                   const std::vector<std::string>& ins_names)
{
    const std::set<std::string> supported_ops = {"convolution", "dot"};

    // Reject the request as soon as one requested name is outside the supported set.
    const bool has_unsupported_op =
        std::any_of(ins_names.begin(), ins_names.end(), [&](const std::string& requested) {
            return supported_ops.count(requested) == 0;
        });
    if(has_unsupported_op)
    {
        MIGRAPHX_THROW("QUANTIZE_INT8: only support DOT and CONVOLUTION operation");
    }

    quantize_8bits(prog, t, shape::int8_type, calibration, ins_names);
}
/**
 * Quantize a program to fp8 (shape::fp8e4m3fnuz_type).
 *
 * Prints a BETA-support warning to std::cout, collects the names of the
 * instructions found in the main module and forwards them to
 * quantize_8bits(). Instructions named "convert" and instructions whose
 * name starts with "@" are left out of the list.
 *
 * @param prog        program to quantize; only its main module is scanned for names
 * @param t           target forwarded to quantize_8bits
 * @param calibration calibration parameter maps forwarded to quantize_8bits
 */
void quantize_fp8(program& prog, const target& t, const std::vector<parameter_map>& calibration)
{
    std::cout << "[Warning] : MIGraphX has BETA support for FP8. Using FP8 may result in "
                 "incorrect final outputs\n";

    // Record each operator name only once (in order of first appearance). Without this
    // the list grows with the number of instructions instead of the number of distinct
    // operators, which makes every later name lookup needlessly slow.
    // NOTE(review): assumes downstream passes only test membership in this list
    // (capture_arguments_pass does contains(ins_names, ins->name())) - confirm.
    std::vector<std::string> supported_ins_names;
    std::set<std::string> seen_names;
    auto* mm = prog.get_main_module();
    for(auto ins : iterator_for(*mm))
    {
        const std::string ins_name = ins->name();
        if(ins_name == "convert" or starts_with(ins_name, "@"))
        {
            continue;
        }
        if(seen_names.insert(ins_name).second)
        {
            supported_ins_names.push_back(ins_name);
        }
    }
    quantize_8bits(prog, t, shape::fp8e4m3fnuz_type, calibration, supported_ins_names);
}
}
// namespace MIGRAPHX_INLINE_NS
}
// namespace migraphx
src/quantize_
int8
.cpp
→
src/quantize_
8bits
.cpp
View file @
db3c07fb
/*
* The MIT License (MIT)
*
* Copyright (c) 2015-202
2
Advanced Micro Devices, Inc. All rights reserved.
* Copyright (c) 2015-202
3
Advanced Micro Devices, Inc. All rights reserved.
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
...
...
@@ -25,7 +25,7 @@
#include <migraphx/float_equal.hpp>
#include <migraphx/instruction_ref.hpp>
#include <migraphx/quantization.hpp>
#include <migraphx/quantize_
int8
.hpp>
#include <migraphx/quantize_
8bits
.hpp>
#include <migraphx/program.hpp>
#include <migraphx/instruction.hpp>
#include <migraphx/iterator_for.hpp>
...
...
@@ -41,8 +41,6 @@
namespace
migraphx
{
inline
namespace
MIGRAPHX_INLINE_NS
{
MIGRAPHX_DECLARE_ENV_VAR
(
MIGRAPHX_INT8_QUANTIZATION_PARAMS
)
static
std
::
vector
<
shape
::
type_t
>&
get_quantizable_type
()
{
static
std
::
vector
<
shape
::
type_t
>
quantable_types
=
{
...
...
@@ -50,7 +48,7 @@ static std::vector<shape::type_t>& get_quantizable_type()
return
quantable_types
;
}
void
quantize_
int8
_pass
::
apply
(
module
&
m
)
const
// NOLINT
void
quantize_
8bits
_pass
::
apply
(
module
&
m
)
const
// NOLINT
{
const
auto
&
quantizable_types
=
get_quantizable_type
();
for
(
auto
ins
:
iterator_for
(
m
))
...
...
@@ -66,9 +64,10 @@ void quantize_int8_pass::apply(module& m) const // NOLINT
auto
input
=
ins
->
inputs
().
front
();
auto
s
=
input
->
get_shape
();
if
(
contains
(
quantizable_types
,
s
.
type
())
and
s
.
type
()
!=
shape
::
int8_type
)
if
(
contains
(
quantizable_types
,
s
.
type
())
and
s
.
type
()
!=
precision
)
{
auto
zero_point
=
m
.
add_literal
(
static_cast
<
int8_t
>
(
param
.
second
));
auto
zero_point
=
m
.
add_literal
(
migraphx
::
literal
{
migraphx
::
shape
{
precision
},
{
param
.
second
}});
auto
scale
=
m
.
add_literal
(
literal
({
s
.
type
()},
{
1.0
f
/
param
.
first
}));
const
auto
&
lens
=
s
.
lens
();
scale
=
...
...
@@ -87,19 +86,32 @@ void quantize_int8_pass::apply(module& m) const // NOLINT
void
capture_arguments_pass
::
apply
(
module
&
m
)
const
// NOLINT
{
assert
(
param_index
!=
nullptr
);
const
auto
&
quantizable_types
=
get_quantizable_type
();
for
(
auto
ins
:
iterator_for
(
m
))
{
if
(
not
contains
(
ins_names
,
ins
->
name
()))
{
continue
;
}
if
(
ins
->
name
()
==
"convert"
)
{
continue
;
}
auto
inputs
=
ins
->
inputs
();
std
::
vector
<
instruction_ref
>
new_args
;
for
(
auto
input
:
inputs
)
{
auto
new_in
=
m
.
insert_instruction
(
ins
,
op
::
capture
{(
*
param_index
)
++
,
f
},
input
);
new_args
.
push_back
(
new_in
);
if
(
contains
(
quantizable_types
,
input
->
get_shape
().
type
()))
{
auto
new_in
=
m
.
insert_instruction
(
ins
,
op
::
capture
{(
*
param_index
)
++
,
f
},
input
);
new_args
.
push_back
(
new_in
);
}
else
{
new_args
.
push_back
(
input
);
}
}
m
.
replace_instruction
(
ins
,
ins
->
get_operator
(),
new_args
);
}
...
...
src/simplify_qdq.cpp
View file @
db3c07fb
...
...
@@ -210,9 +210,15 @@ bool compare_literals(instruction_ref ins1, instruction_ref ins2)
bool
diff_shapes_equal_vals
=
false
;
visit_all
(
ins1
->
get_literal
(),
ins2
->
get_literal
())([
&
](
const
auto
l1
,
const
auto
l2
)
{
diff_shapes_equal_vals
=
std
::
all_of
(
l1
.
begin
()
+
1
,
l1
.
end
(),
[
&
](
auto
v
)
{
return
float_equal
(
v
,
l1
.
front
());
})
and
std
::
all_of
(
l2
.
begin
(),
l2
.
end
(),
[
&
](
auto
v
)
{
return
float_equal
(
v
,
l1
.
front
());
});
std
::
all_of
(
l1
.
begin
()
+
1
,
l1
.
end
(),
[
&
](
auto
v
)
{
return
((
float_equal
(
v
,
l1
.
front
()))
or
(
std
::
isinf
(
l1
.
front
())
and
std
::
isinf
(
v
)));
})
and
std
::
all_of
(
l2
.
begin
(),
l2
.
end
(),
[
&
](
auto
v
)
{
return
((
float_equal
(
v
,
l1
.
front
()))
or
(
std
::
isinf
(
l1
.
front
())
and
std
::
isinf
(
v
)));
});
});
return
(
x
==
y
)
or
diff_shapes_equal_vals
;
...
...
test/quantization.cpp
View file @
db3c07fb
...
...
@@ -30,7 +30,7 @@
#include <migraphx/verify.hpp>
#include <migraphx/apply_alpha_beta.hpp>
#include <migraphx/quantization.hpp>
#include <migraphx/quantize_
int8
.hpp>
#include <migraphx/quantize_
8bits
.hpp>
#include <migraphx/quantize_fp16.hpp>
#include <migraphx/dead_code_elimination.hpp>
#include <migraphx/simplify_reshapes.hpp>
...
...
@@ -654,7 +654,8 @@ TEST_CASE(dot_float)
migraphx
::
run_passes
(
p
,
{
migraphx
::
capture_arguments_pass
{{
"dot"
},
{},
&
param_index
}});
migraphx
::
run_passes
(
p
,
{
migraphx
::
quantize_int8_pass
{{
"dot"
},
quant_params
},
migraphx
::
dead_code_elimination
{}});
{
migraphx
::
quantize_8bits_pass
{
migraphx
::
shape
::
type_t
::
int8_type
,
{
"dot"
},
quant_params
},
migraphx
::
dead_code_elimination
{}});
auto
qp
=
create_int8_quantized_prog
();
EXPECT
(
p
==
qp
);
...
...
@@ -748,7 +749,8 @@ TEST_CASE(dot_double_2args)
migraphx
::
run_passes
(
p
,
{
migraphx
::
capture_arguments_pass
{{
"dot"
},
{},
&
param_index
}});
migraphx
::
run_passes
(
p
,
{
migraphx
::
quantize_int8_pass
{{
"dot"
},
quant_params
},
migraphx
::
dead_code_elimination
{}});
{
migraphx
::
quantize_8bits_pass
{
migraphx
::
shape
::
type_t
::
int8_type
,
{
"dot"
},
quant_params
},
migraphx
::
dead_code_elimination
{}});
EXPECT
(
p
==
create_int8_quantized_prog
());
optimize_prog_int8
(
p
);
...
...
@@ -821,7 +823,8 @@ TEST_CASE(dot_half_1arg)
migraphx
::
run_passes
(
p
,
{
migraphx
::
capture_arguments_pass
{{
"dot"
},
{},
&
param_index
}});
migraphx
::
run_passes
(
p
,
{
migraphx
::
quantize_int8_pass
{{
"dot"
},
quant_params
},
migraphx
::
dead_code_elimination
{}});
{
migraphx
::
quantize_8bits_pass
{
migraphx
::
shape
::
int8_type
,
{
"dot"
},
quant_params
},
migraphx
::
dead_code_elimination
{}});
EXPECT
(
p
==
create_int8_quantized_prog
());
optimize_prog_int8
(
p
);
...
...
@@ -876,7 +879,9 @@ TEST_CASE(conv_float)
const
std
::
vector
<
std
::
pair
<
float
,
float
>>&
quant_params
{{
0.1
f
,
0.0
f
},
{
0.1
f
,
0.0
f
}};
std
::
size_t
param_index
=
0
;
migraphx
::
run_passes
(
p
,
{
migraphx
::
capture_arguments_pass
{{
"convolution"
},
{},
&
param_index
}});
migraphx
::
run_passes
(
p
,
{
migraphx
::
quantize_int8_pass
{{
"convolution"
},
quant_params
}});
migraphx
::
run_passes
(
p
,
{
migraphx
::
quantize_8bits_pass
{
migraphx
::
shape
::
type_t
::
int8_type
,
{
"convolution"
},
quant_params
}});
optimize_prog_int8
(
p
);
auto
qp
=
create_int8_quantized_prog
();
...
...
@@ -901,7 +906,9 @@ TEST_CASE(conv_float_throw)
auto
p
=
create_program
();
const
std
::
vector
<
std
::
pair
<
float
,
float
>>&
quant_params
{{
0.1
f
,
0.0
f
},
{
0.1
f
,
0.0
f
}};
test
::
throws
([
&
]
{
migraphx
::
run_passes
(
p
,
{
migraphx
::
quantize_int8_pass
{{
"add"
},
quant_params
}});
migraphx
::
run_passes
(
p
,
{
migraphx
::
quantize_8bits_pass
{
migraphx
::
shape
::
type_t
::
int8_type
,
{
"add"
},
quant_params
}});
});
}
...
...
@@ -952,7 +959,9 @@ TEST_CASE(conv_half)
const
std
::
vector
<
std
::
pair
<
float
,
float
>>&
quant_params
{{
0.1
f
,
0.0
f
},
{
0.1
f
,
0.0
f
}};
std
::
size_t
param_index
=
0
;
migraphx
::
run_passes
(
p
,
{
migraphx
::
capture_arguments_pass
{{
"convolution"
},
{},
&
param_index
}});
migraphx
::
run_passes
(
p
,
{
migraphx
::
quantize_int8_pass
{{
"convolution"
},
quant_params
}});
migraphx
::
run_passes
(
p
,
{
migraphx
::
quantize_8bits_pass
{
migraphx
::
shape
::
type_t
::
int8_type
,
{
"convolution"
},
quant_params
}});
optimize_prog_int8
(
p
);
auto
qp
=
create_int8_quantized_prog
();
...
...
@@ -1231,7 +1240,10 @@ TEST_CASE(int8_subgraph)
std
::
size_t
param_index
=
0
;
migraphx
::
run_passes
(
p1
,
{
migraphx
::
capture_arguments_pass
{{
"convolution"
,
"dot"
},
{},
&
param_index
}});
migraphx
::
run_passes
(
p1
,
{
migraphx
::
quantize_int8_pass
{{
"convolution"
,
"dot"
},
quant_params
}});
migraphx
::
run_passes
(
p1
,
{
migraphx
::
quantize_8bits_pass
{
migraphx
::
shape
::
type_t
::
int8_type
,
{
"convolution"
,
"dot"
},
quant_params
}});
optimize_prog_int8
(
p1
);
auto
p2
=
create_int8_program
();
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment