Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
MIGraphX
Commits
7e61114a
Unverified
Commit
7e61114a
authored
Dec 12, 2023
by
Umang Yadav
Committed by
GitHub
Dec 12, 2023
Browse files
refactoring quantization passes (#2544)
parent
b742b528
Changes
8
Hide whitespace changes
Inline
Side-by-side
Showing
8 changed files
with
26 additions
and
33 deletions
+26
-33
src/api/api.cpp
src/api/api.cpp
+3
-3
src/include/migraphx/quantization.hpp
src/include/migraphx/quantization.hpp
+2
-2
src/include/migraphx/quantize_8bits.hpp
src/include/migraphx/quantize_8bits.hpp
+3
-3
src/py/migraphx_py.cpp
src/py/migraphx_py.cpp
+1
-1
src/quantization.cpp
src/quantization.cpp
+8
-10
src/quantize_8bits.cpp
src/quantize_8bits.cpp
+1
-5
test/quantization.cpp
test/quantization.cpp
+6
-7
tools/api/api.cpp
tools/api/api.cpp
+2
-2
No files found.
src/api/api.cpp
View file @
7e61114a
...
@@ -231,13 +231,13 @@ void quantize_fp16_with_op_names(program& prog, std::vector<std::string>& names)
...
@@ -231,13 +231,13 @@ void quantize_fp16_with_op_names(program& prog, std::vector<std::string>& names)
struct
quantize_int8_options
struct
quantize_int8_options
{
{
std
::
vector
<
parameter_map
>
calibration
=
{};
std
::
vector
<
parameter_map
>
calibration
=
{};
std
::
vector
<
std
::
string
>
op_names
=
{};
std
::
unordered_set
<
std
::
string
>
op_names
=
{};
};
};
void
add_op_name
(
quantize_int8_options
&
options
,
const
char
*
name
)
void
add_op_name
(
quantize_int8_options
&
options
,
const
char
*
name
)
{
{
options
.
op_names
.
push_back
(
name
);
options
.
op_names
.
insert
(
name
);
}
}
void
add_calibration_data
(
quantize_int8_options
&
options
,
parameter_map
&
data
)
void
add_calibration_data
(
quantize_int8_options
&
options
,
parameter_map
&
data
)
...
...
src/include/migraphx/quantization.hpp
View file @
7e61114a
...
@@ -44,8 +44,8 @@ MIGRAPHX_EXPORT void quantize_fp16(program& prog,
...
@@ -44,8 +44,8 @@ MIGRAPHX_EXPORT void quantize_fp16(program& prog,
MIGRAPHX_EXPORT
void
quantize_int8
(
program
&
prog
,
MIGRAPHX_EXPORT
void
quantize_int8
(
program
&
prog
,
const
target
&
t
,
const
target
&
t
,
const
std
::
vector
<
parameter_map
>&
calibration
,
const
std
::
vector
<
parameter_map
>&
calibration
,
const
std
::
vector
<
std
::
string
>&
ins_names
=
{
"dot"
,
const
std
::
unordered_set
<
std
::
string
>&
ins_names
=
{
"convolution"
});
"dot"
,
"convolution"
});
MIGRAPHX_EXPORT
void
MIGRAPHX_EXPORT
void
quantize_fp8
(
program
&
prog
,
const
target
&
t
,
const
std
::
vector
<
parameter_map
>&
calibration
);
quantize_fp8
(
program
&
prog
,
const
target
&
t
,
const
std
::
vector
<
parameter_map
>&
calibration
);
...
...
src/include/migraphx/quantize_8bits.hpp
View file @
7e61114a
...
@@ -25,6 +25,7 @@
...
@@ -25,6 +25,7 @@
#define MIGRAPHX_GUARD_RTGLIB_QUANTIZE_8BITS_HPP
#define MIGRAPHX_GUARD_RTGLIB_QUANTIZE_8BITS_HPP
#include <string>
#include <string>
#include <unordered_set>
#include <vector>
#include <vector>
#include <functional>
#include <functional>
#include <migraphx/argument.hpp>
#include <migraphx/argument.hpp>
...
@@ -41,7 +42,7 @@ struct module;
...
@@ -41,7 +42,7 @@ struct module;
*/
*/
struct
MIGRAPHX_EXPORT
capture_arguments_pass
struct
MIGRAPHX_EXPORT
capture_arguments_pass
{
{
std
::
vector
<
std
::
string
>
ins_names
=
{
"dot"
,
"convolution"
};
std
::
unordered_set
<
std
::
string
>
ins_names
=
{
"dot"
,
"convolution"
};
std
::
function
<
void
(
std
::
size_t
,
std
::
vector
<
argument
>
)
>
f
{};
std
::
function
<
void
(
std
::
size_t
,
std
::
vector
<
argument
>
)
>
f
{};
std
::
size_t
*
param_index
=
nullptr
;
std
::
size_t
*
param_index
=
nullptr
;
std
::
string
name
()
const
{
return
"capture_arguments"
;
}
std
::
string
name
()
const
{
return
"capture_arguments"
;
}
...
@@ -53,8 +54,7 @@ struct MIGRAPHX_EXPORT capture_arguments_pass
...
@@ -53,8 +54,7 @@ struct MIGRAPHX_EXPORT capture_arguments_pass
*/
*/
struct
MIGRAPHX_EXPORT
quantize_8bits_pass
struct
MIGRAPHX_EXPORT
quantize_8bits_pass
{
{
shape
::
type_t
precision
=
shape
::
int8_type
;
shape
::
type_t
precision
=
shape
::
int8_type
;
std
::
vector
<
std
::
string
>
ins_names
=
{
"dot"
,
"convolution"
};
std
::
vector
<
std
::
pair
<
float
,
float
>>
quant_params
;
std
::
vector
<
std
::
pair
<
float
,
float
>>
quant_params
;
std
::
string
name
()
const
{
return
"quantize_8bits"
;
}
std
::
string
name
()
const
{
return
"quantize_8bits"
;
}
void
apply
(
module
&
m
)
const
;
void
apply
(
module
&
m
)
const
;
...
...
src/py/migraphx_py.cpp
View file @
7e61114a
...
@@ -580,7 +580,7 @@ MIGRAPHX_PYBIND11_MODULE(migraphx, m)
...
@@ -580,7 +580,7 @@ MIGRAPHX_PYBIND11_MODULE(migraphx, m)
py
::
arg
(
"prog"
),
py
::
arg
(
"prog"
),
py
::
arg
(
"t"
),
py
::
arg
(
"t"
),
py
::
arg
(
"calibration"
)
=
std
::
vector
<
migraphx
::
parameter_map
>
{},
py
::
arg
(
"calibration"
)
=
std
::
vector
<
migraphx
::
parameter_map
>
{},
py
::
arg
(
"ins_names"
)
=
std
::
vector
<
std
::
string
>
{
"dot"
,
"convolution"
});
py
::
arg
(
"ins_names"
)
=
std
::
unordered_set
<
std
::
string
>
{
"dot"
,
"convolution"
});
#ifdef HAVE_GPU
#ifdef HAVE_GPU
m
.
def
(
"allocate_gpu"
,
&
migraphx
::
gpu
::
allocate_gpu
,
py
::
arg
(
"s"
),
py
::
arg
(
"host"
)
=
false
);
m
.
def
(
"allocate_gpu"
,
&
migraphx
::
gpu
::
allocate_gpu
,
py
::
arg
(
"s"
),
py
::
arg
(
"host"
)
=
false
);
...
...
src/quantization.cpp
View file @
7e61114a
...
@@ -61,7 +61,7 @@ void quantize_8bits(program& prog,
...
@@ -61,7 +61,7 @@ void quantize_8bits(program& prog,
const
target
&
t
,
const
target
&
t
,
shape
::
type_t
precision
,
shape
::
type_t
precision
,
const
std
::
vector
<
parameter_map
>&
calibration
,
const
std
::
vector
<
parameter_map
>&
calibration
,
const
std
::
vector
<
std
::
string
>&
ins_names
)
const
std
::
unordered_set
<
std
::
string
>&
ins_names
)
{
{
// Run optimize_module() before converting to int8/fp8 to const eval and fold in FP32 to
// Run optimize_module() before converting to int8/fp8 to const eval and fold in FP32 to
// avoid loss of precision.
// avoid loss of precision.
...
@@ -138,7 +138,7 @@ void quantize_8bits(program& prog,
...
@@ -138,7 +138,7 @@ void quantize_8bits(program& prog,
}
}
run_passes
(
prog
,
run_passes
(
prog
,
{
quantize_8bits_pass
{
precision
,
ins_names
,
*
quant_8bit_params
},
{
quantize_8bits_pass
{
precision
,
*
quant_8bit_params
},
simplify_qdq
{},
simplify_qdq
{},
optimize_module
{},
optimize_module
{},
dead_code_elimination
{}});
dead_code_elimination
{}});
...
@@ -147,12 +147,10 @@ void quantize_8bits(program& prog,
...
@@ -147,12 +147,10 @@ void quantize_8bits(program& prog,
void
quantize_int8
(
program
&
prog
,
void
quantize_int8
(
program
&
prog
,
const
target
&
t
,
const
target
&
t
,
const
std
::
vector
<
parameter_map
>&
calibration
,
const
std
::
vector
<
parameter_map
>&
calibration
,
const
std
::
vector
<
std
::
string
>&
ins_names
)
const
std
::
unordered_set
<
std
::
string
>&
ins_names
)
{
{
std
::
set
<
std
::
string
>
op_names
=
{
"convolution"
,
"dot"
};
std
::
unordered_set
<
std
::
string
>
op_names
=
{
"convolution"
,
"dot"
};
std
::
set
<
std
::
string
>
input_ins_names
(
ins_names
.
begin
(),
ins_names
.
end
());
if
(
op_names
!=
ins_names
)
if
(
not
std
::
includes
(
op_names
.
begin
(),
op_names
.
end
(),
input_ins_names
.
begin
(),
input_ins_names
.
end
()))
{
{
MIGRAPHX_THROW
(
"QUANTIZE_INT8: only support DOT and CONVOLUTION operation"
);
MIGRAPHX_THROW
(
"QUANTIZE_INT8: only support DOT and CONVOLUTION operation"
);
}
}
...
@@ -164,7 +162,7 @@ void quantize_fp8(program& prog, const target& t, const std::vector<parameter_ma
...
@@ -164,7 +162,7 @@ void quantize_fp8(program& prog, const target& t, const std::vector<parameter_ma
std
::
cout
<<
"[Warning] : MIGraphX has BETA support for FP8. Using FP8 may result in "
std
::
cout
<<
"[Warning] : MIGraphX has BETA support for FP8. Using FP8 may result in "
"incorrect final outputs
\n
"
;
"incorrect final outputs
\n
"
;
std
::
vector
<
std
::
string
>
supported_ins_names
;
std
::
unordered_set
<
std
::
string
>
supported_ins_names
;
auto
*
mm
=
prog
.
get_main_module
();
auto
*
mm
=
prog
.
get_main_module
();
for
(
auto
ins
:
iterator_for
(
*
mm
))
for
(
auto
ins
:
iterator_for
(
*
mm
))
{
{
...
@@ -172,9 +170,9 @@ void quantize_fp8(program& prog, const target& t, const std::vector<parameter_ma
...
@@ -172,9 +170,9 @@ void quantize_fp8(program& prog, const target& t, const std::vector<parameter_ma
{
{
continue
;
continue
;
}
}
else
if
(
not
starts_with
(
ins
->
name
(),
"@"
))
if
(
not
starts_with
(
ins
->
name
(),
"@"
))
{
{
supported_ins_names
.
push_back
(
ins
->
name
());
supported_ins_names
.
insert
(
ins
->
name
());
}
}
}
}
quantize_8bits
(
prog
,
t
,
shape
::
fp8e4m3fnuz_type
,
calibration
,
supported_ins_names
);
quantize_8bits
(
prog
,
t
,
shape
::
fp8e4m3fnuz_type
,
calibration
,
supported_ins_names
);
...
...
src/quantize_8bits.cpp
View file @
7e61114a
...
@@ -90,11 +90,7 @@ void capture_arguments_pass::apply(module& m) const // NOLINT
...
@@ -90,11 +90,7 @@ void capture_arguments_pass::apply(module& m) const // NOLINT
for
(
auto
ins
:
iterator_for
(
m
))
for
(
auto
ins
:
iterator_for
(
m
))
{
{
if
(
not
contains
(
ins_names
,
ins
->
name
()))
if
((
not
contains
(
ins_names
,
ins
->
name
()))
or
(
ins
->
name
()
==
"convert"
))
{
continue
;
}
if
(
ins
->
name
()
==
"convert"
)
{
{
continue
;
continue
;
}
}
...
...
test/quantization.cpp
View file @
7e61114a
...
@@ -654,7 +654,7 @@ TEST_CASE(dot_float)
...
@@ -654,7 +654,7 @@ TEST_CASE(dot_float)
migraphx
::
run_passes
(
p
,
{
migraphx
::
capture_arguments_pass
{{
"dot"
},
{},
&
param_index
}});
migraphx
::
run_passes
(
p
,
{
migraphx
::
capture_arguments_pass
{{
"dot"
},
{},
&
param_index
}});
migraphx
::
run_passes
(
migraphx
::
run_passes
(
p
,
p
,
{
migraphx
::
quantize_8bits_pass
{
migraphx
::
shape
::
type_t
::
int8_type
,
{
"dot"
},
quant_params
},
{
migraphx
::
quantize_8bits_pass
{
migraphx
::
shape
::
type_t
::
int8_type
,
quant_params
},
migraphx
::
dead_code_elimination
{}});
migraphx
::
dead_code_elimination
{}});
auto
qp
=
create_int8_quantized_prog
();
auto
qp
=
create_int8_quantized_prog
();
...
@@ -749,7 +749,7 @@ TEST_CASE(dot_double_2args)
...
@@ -749,7 +749,7 @@ TEST_CASE(dot_double_2args)
migraphx
::
run_passes
(
p
,
{
migraphx
::
capture_arguments_pass
{{
"dot"
},
{},
&
param_index
}});
migraphx
::
run_passes
(
p
,
{
migraphx
::
capture_arguments_pass
{{
"dot"
},
{},
&
param_index
}});
migraphx
::
run_passes
(
migraphx
::
run_passes
(
p
,
p
,
{
migraphx
::
quantize_8bits_pass
{
migraphx
::
shape
::
type_t
::
int8_type
,
{
"dot"
},
quant_params
},
{
migraphx
::
quantize_8bits_pass
{
migraphx
::
shape
::
type_t
::
int8_type
,
quant_params
},
migraphx
::
dead_code_elimination
{}});
migraphx
::
dead_code_elimination
{}});
EXPECT
(
p
==
create_int8_quantized_prog
());
EXPECT
(
p
==
create_int8_quantized_prog
());
...
@@ -823,7 +823,7 @@ TEST_CASE(dot_half_1arg)
...
@@ -823,7 +823,7 @@ TEST_CASE(dot_half_1arg)
migraphx
::
run_passes
(
p
,
{
migraphx
::
capture_arguments_pass
{{
"dot"
},
{},
&
param_index
}});
migraphx
::
run_passes
(
p
,
{
migraphx
::
capture_arguments_pass
{{
"dot"
},
{},
&
param_index
}});
migraphx
::
run_passes
(
migraphx
::
run_passes
(
p
,
p
,
{
migraphx
::
quantize_8bits_pass
{
migraphx
::
shape
::
int8_type
,
{
"dot"
},
quant_params
},
{
migraphx
::
quantize_8bits_pass
{
migraphx
::
shape
::
int8_type
,
quant_params
},
migraphx
::
dead_code_elimination
{}});
migraphx
::
dead_code_elimination
{}});
EXPECT
(
p
==
create_int8_quantized_prog
());
EXPECT
(
p
==
create_int8_quantized_prog
());
...
@@ -881,7 +881,7 @@ TEST_CASE(conv_float)
...
@@ -881,7 +881,7 @@ TEST_CASE(conv_float)
migraphx
::
run_passes
(
p
,
{
migraphx
::
capture_arguments_pass
{{
"convolution"
},
{},
&
param_index
}});
migraphx
::
run_passes
(
p
,
{
migraphx
::
capture_arguments_pass
{{
"convolution"
},
{},
&
param_index
}});
migraphx
::
run_passes
(
p
,
migraphx
::
run_passes
(
p
,
{
migraphx
::
quantize_8bits_pass
{
{
migraphx
::
quantize_8bits_pass
{
migraphx
::
shape
::
type_t
::
int8_type
,
{
"convolution"
},
quant_params
}});
migraphx
::
shape
::
type_t
::
int8_type
,
quant_params
}});
optimize_prog_int8
(
p
);
optimize_prog_int8
(
p
);
auto
qp
=
create_int8_quantized_prog
();
auto
qp
=
create_int8_quantized_prog
();
...
@@ -908,7 +908,7 @@ TEST_CASE(conv_float_throw)
...
@@ -908,7 +908,7 @@ TEST_CASE(conv_float_throw)
test
::
throws
([
&
]
{
test
::
throws
([
&
]
{
migraphx
::
run_passes
(
p
,
migraphx
::
run_passes
(
p
,
{
migraphx
::
quantize_8bits_pass
{
{
migraphx
::
quantize_8bits_pass
{
migraphx
::
shape
::
type_t
::
int8_type
,
{
"add"
},
quant_params
}});
migraphx
::
shape
::
type_t
::
int8_type
,
quant_params
}});
});
});
}
}
...
@@ -961,7 +961,7 @@ TEST_CASE(conv_half)
...
@@ -961,7 +961,7 @@ TEST_CASE(conv_half)
migraphx
::
run_passes
(
p
,
{
migraphx
::
capture_arguments_pass
{{
"convolution"
},
{},
&
param_index
}});
migraphx
::
run_passes
(
p
,
{
migraphx
::
capture_arguments_pass
{{
"convolution"
},
{},
&
param_index
}});
migraphx
::
run_passes
(
p
,
migraphx
::
run_passes
(
p
,
{
migraphx
::
quantize_8bits_pass
{
{
migraphx
::
quantize_8bits_pass
{
migraphx
::
shape
::
type_t
::
int8_type
,
{
"convolution"
},
quant_params
}});
migraphx
::
shape
::
type_t
::
int8_type
,
quant_params
}});
optimize_prog_int8
(
p
);
optimize_prog_int8
(
p
);
auto
qp
=
create_int8_quantized_prog
();
auto
qp
=
create_int8_quantized_prog
();
...
@@ -1242,7 +1242,6 @@ TEST_CASE(int8_subgraph)
...
@@ -1242,7 +1242,6 @@ TEST_CASE(int8_subgraph)
p1
,
{
migraphx
::
capture_arguments_pass
{{
"convolution"
,
"dot"
},
{},
&
param_index
}});
p1
,
{
migraphx
::
capture_arguments_pass
{{
"convolution"
,
"dot"
},
{},
&
param_index
}});
migraphx
::
run_passes
(
p1
,
migraphx
::
run_passes
(
p1
,
{
migraphx
::
quantize_8bits_pass
{
migraphx
::
shape
::
type_t
::
int8_type
,
{
migraphx
::
quantize_8bits_pass
{
migraphx
::
shape
::
type_t
::
int8_type
,
{
"convolution"
,
"dot"
},
quant_params
}});
quant_params
}});
optimize_prog_int8
(
p1
);
optimize_prog_int8
(
p1
);
...
...
tools/api/api.cpp
View file @
7e61114a
...
@@ -232,12 +232,12 @@ void quantize_fp16_with_op_names(program& prog, std::vector<std::string>& names)
...
@@ -232,12 +232,12 @@ void quantize_fp16_with_op_names(program& prog, std::vector<std::string>& names)
struct
quantize_int8_options
struct
quantize_int8_options
{
{
std
::
vector
<
parameter_map
>
calibration
=
{};
std
::
vector
<
parameter_map
>
calibration
=
{};
std
::
vector
<
std
::
string
>
op_names
=
{};
std
::
unordered_set
<
std
::
string
>
op_names
=
{};
};
};
void
add_op_name
(
quantize_int8_options
&
options
,
const
char
*
name
)
void
add_op_name
(
quantize_int8_options
&
options
,
const
char
*
name
)
{
{
options
.
op_names
.
push_back
(
name
);
options
.
op_names
.
insert
(
name
);
}
}
void
add_calibration_data
(
quantize_int8_options
&
options
,
parameter_map
&
data
)
void
add_calibration_data
(
quantize_int8_options
&
options
,
parameter_map
&
data
)
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment