Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
MIGraphX
Commits
7e61114a
"...targets/git@developer.sourcefind.cn:gaoqiong/migraphx.git" did not exist on "a4bb0fdb132cde885801e37087153bbfb6a26b11"
Unverified
Commit
7e61114a
authored
Dec 12, 2023
by
Umang Yadav
Committed by
GitHub
Dec 12, 2023
Browse files
refactoring quantization passes (#2544)
parent
b742b528
Changes
8
Show whitespace changes
Inline
Side-by-side
Showing
8 changed files
with
26 additions
and
33 deletions
+26
-33
src/api/api.cpp
src/api/api.cpp
+3
-3
src/include/migraphx/quantization.hpp
src/include/migraphx/quantization.hpp
+2
-2
src/include/migraphx/quantize_8bits.hpp
src/include/migraphx/quantize_8bits.hpp
+3
-3
src/py/migraphx_py.cpp
src/py/migraphx_py.cpp
+1
-1
src/quantization.cpp
src/quantization.cpp
+8
-10
src/quantize_8bits.cpp
src/quantize_8bits.cpp
+1
-5
test/quantization.cpp
test/quantization.cpp
+6
-7
tools/api/api.cpp
tools/api/api.cpp
+2
-2
No files found.
src/api/api.cpp
View file @
7e61114a
...
@@ -232,12 +232,12 @@ void quantize_fp16_with_op_names(program& prog, std::vector<std::string>& names)
...
@@ -232,12 +232,12 @@ void quantize_fp16_with_op_names(program& prog, std::vector<std::string>& names)
struct
quantize_int8_options
struct
quantize_int8_options
{
{
std
::
vector
<
parameter_map
>
calibration
=
{};
std
::
vector
<
parameter_map
>
calibration
=
{};
std
::
vector
<
std
::
string
>
op_names
=
{};
std
::
unordered_set
<
std
::
string
>
op_names
=
{};
};
};
void
add_op_name
(
quantize_int8_options
&
options
,
const
char
*
name
)
void
add_op_name
(
quantize_int8_options
&
options
,
const
char
*
name
)
{
{
options
.
op_names
.
push_back
(
name
);
options
.
op_names
.
insert
(
name
);
}
}
void
add_calibration_data
(
quantize_int8_options
&
options
,
parameter_map
&
data
)
void
add_calibration_data
(
quantize_int8_options
&
options
,
parameter_map
&
data
)
...
...
src/include/migraphx/quantization.hpp
View file @
7e61114a
...
@@ -44,8 +44,8 @@ MIGRAPHX_EXPORT void quantize_fp16(program& prog,
...
@@ -44,8 +44,8 @@ MIGRAPHX_EXPORT void quantize_fp16(program& prog,
MIGRAPHX_EXPORT
void
quantize_int8
(
program
&
prog
,
MIGRAPHX_EXPORT
void
quantize_int8
(
program
&
prog
,
const
target
&
t
,
const
target
&
t
,
const
std
::
vector
<
parameter_map
>&
calibration
,
const
std
::
vector
<
parameter_map
>&
calibration
,
const
std
::
vector
<
std
::
string
>&
ins_names
=
{
"dot"
,
const
std
::
unordered_set
<
std
::
string
>&
ins_names
=
{
"convolution"
});
"dot"
,
"convolution"
});
MIGRAPHX_EXPORT
void
MIGRAPHX_EXPORT
void
quantize_fp8
(
program
&
prog
,
const
target
&
t
,
const
std
::
vector
<
parameter_map
>&
calibration
);
quantize_fp8
(
program
&
prog
,
const
target
&
t
,
const
std
::
vector
<
parameter_map
>&
calibration
);
...
...
src/include/migraphx/quantize_8bits.hpp
View file @
7e61114a
...
@@ -25,6 +25,7 @@
...
@@ -25,6 +25,7 @@
#define MIGRAPHX_GUARD_RTGLIB_QUANTIZE_8BITS_HPP
#define MIGRAPHX_GUARD_RTGLIB_QUANTIZE_8BITS_HPP
#include <string>
#include <string>
#include <unordered_set>
#include <vector>
#include <vector>
#include <functional>
#include <functional>
#include <migraphx/argument.hpp>
#include <migraphx/argument.hpp>
...
@@ -41,7 +42,7 @@ struct module;
...
@@ -41,7 +42,7 @@ struct module;
*/
*/
struct
MIGRAPHX_EXPORT
capture_arguments_pass
struct
MIGRAPHX_EXPORT
capture_arguments_pass
{
{
std
::
vector
<
std
::
string
>
ins_names
=
{
"dot"
,
"convolution"
};
std
::
unordered_set
<
std
::
string
>
ins_names
=
{
"dot"
,
"convolution"
};
std
::
function
<
void
(
std
::
size_t
,
std
::
vector
<
argument
>
)
>
f
{};
std
::
function
<
void
(
std
::
size_t
,
std
::
vector
<
argument
>
)
>
f
{};
std
::
size_t
*
param_index
=
nullptr
;
std
::
size_t
*
param_index
=
nullptr
;
std
::
string
name
()
const
{
return
"capture_arguments"
;
}
std
::
string
name
()
const
{
return
"capture_arguments"
;
}
...
@@ -54,7 +55,6 @@ struct MIGRAPHX_EXPORT capture_arguments_pass
...
@@ -54,7 +55,6 @@ struct MIGRAPHX_EXPORT capture_arguments_pass
struct
MIGRAPHX_EXPORT
quantize_8bits_pass
struct
MIGRAPHX_EXPORT
quantize_8bits_pass
{
{
shape
::
type_t
precision
=
shape
::
int8_type
;
shape
::
type_t
precision
=
shape
::
int8_type
;
std
::
vector
<
std
::
string
>
ins_names
=
{
"dot"
,
"convolution"
};
std
::
vector
<
std
::
pair
<
float
,
float
>>
quant_params
;
std
::
vector
<
std
::
pair
<
float
,
float
>>
quant_params
;
std
::
string
name
()
const
{
return
"quantize_8bits"
;
}
std
::
string
name
()
const
{
return
"quantize_8bits"
;
}
void
apply
(
module
&
m
)
const
;
void
apply
(
module
&
m
)
const
;
...
...
src/py/migraphx_py.cpp
View file @
7e61114a
...
@@ -580,7 +580,7 @@ MIGRAPHX_PYBIND11_MODULE(migraphx, m)
...
@@ -580,7 +580,7 @@ MIGRAPHX_PYBIND11_MODULE(migraphx, m)
py
::
arg
(
"prog"
),
py
::
arg
(
"prog"
),
py
::
arg
(
"t"
),
py
::
arg
(
"t"
),
py
::
arg
(
"calibration"
)
=
std
::
vector
<
migraphx
::
parameter_map
>
{},
py
::
arg
(
"calibration"
)
=
std
::
vector
<
migraphx
::
parameter_map
>
{},
py
::
arg
(
"ins_names"
)
=
std
::
vector
<
std
::
string
>
{
"dot"
,
"convolution"
});
py
::
arg
(
"ins_names"
)
=
std
::
unordered_set
<
std
::
string
>
{
"dot"
,
"convolution"
});
#ifdef HAVE_GPU
#ifdef HAVE_GPU
m
.
def
(
"allocate_gpu"
,
&
migraphx
::
gpu
::
allocate_gpu
,
py
::
arg
(
"s"
),
py
::
arg
(
"host"
)
=
false
);
m
.
def
(
"allocate_gpu"
,
&
migraphx
::
gpu
::
allocate_gpu
,
py
::
arg
(
"s"
),
py
::
arg
(
"host"
)
=
false
);
...
...
src/quantization.cpp
View file @
7e61114a
...
@@ -61,7 +61,7 @@ void quantize_8bits(program& prog,
...
@@ -61,7 +61,7 @@ void quantize_8bits(program& prog,
const
target
&
t
,
const
target
&
t
,
shape
::
type_t
precision
,
shape
::
type_t
precision
,
const
std
::
vector
<
parameter_map
>&
calibration
,
const
std
::
vector
<
parameter_map
>&
calibration
,
const
std
::
vector
<
std
::
string
>&
ins_names
)
const
std
::
unordered_set
<
std
::
string
>&
ins_names
)
{
{
// Run optimize_module() before converting to int8/fp8 to const eval and fold in FP32 to
// Run optimize_module() before converting to int8/fp8 to const eval and fold in FP32 to
// avoid loss of precision.
// avoid loss of precision.
...
@@ -138,7 +138,7 @@ void quantize_8bits(program& prog,
...
@@ -138,7 +138,7 @@ void quantize_8bits(program& prog,
}
}
run_passes
(
prog
,
run_passes
(
prog
,
{
quantize_8bits_pass
{
precision
,
ins_names
,
*
quant_8bit_params
},
{
quantize_8bits_pass
{
precision
,
*
quant_8bit_params
},
simplify_qdq
{},
simplify_qdq
{},
optimize_module
{},
optimize_module
{},
dead_code_elimination
{}});
dead_code_elimination
{}});
...
@@ -147,12 +147,10 @@ void quantize_8bits(program& prog,
...
@@ -147,12 +147,10 @@ void quantize_8bits(program& prog,
void
quantize_int8
(
program
&
prog
,
void
quantize_int8
(
program
&
prog
,
const
target
&
t
,
const
target
&
t
,
const
std
::
vector
<
parameter_map
>&
calibration
,
const
std
::
vector
<
parameter_map
>&
calibration
,
const
std
::
vector
<
std
::
string
>&
ins_names
)
const
std
::
unordered_set
<
std
::
string
>&
ins_names
)
{
{
std
::
set
<
std
::
string
>
op_names
=
{
"convolution"
,
"dot"
};
std
::
unordered_set
<
std
::
string
>
op_names
=
{
"convolution"
,
"dot"
};
std
::
set
<
std
::
string
>
input_ins_names
(
ins_names
.
begin
(),
ins_names
.
end
());
if
(
op_names
!=
ins_names
)
if
(
not
std
::
includes
(
op_names
.
begin
(),
op_names
.
end
(),
input_ins_names
.
begin
(),
input_ins_names
.
end
()))
{
{
MIGRAPHX_THROW
(
"QUANTIZE_INT8: only support DOT and CONVOLUTION operation"
);
MIGRAPHX_THROW
(
"QUANTIZE_INT8: only support DOT and CONVOLUTION operation"
);
}
}
...
@@ -164,7 +162,7 @@ void quantize_fp8(program& prog, const target& t, const std::vector<parameter_ma
...
@@ -164,7 +162,7 @@ void quantize_fp8(program& prog, const target& t, const std::vector<parameter_ma
std
::
cout
<<
"[Warning] : MIGraphX has BETA support for FP8. Using FP8 may result in "
std
::
cout
<<
"[Warning] : MIGraphX has BETA support for FP8. Using FP8 may result in "
"incorrect final outputs
\n
"
;
"incorrect final outputs
\n
"
;
std
::
vector
<
std
::
string
>
supported_ins_names
;
std
::
unordered_set
<
std
::
string
>
supported_ins_names
;
auto
*
mm
=
prog
.
get_main_module
();
auto
*
mm
=
prog
.
get_main_module
();
for
(
auto
ins
:
iterator_for
(
*
mm
))
for
(
auto
ins
:
iterator_for
(
*
mm
))
{
{
...
@@ -172,9 +170,9 @@ void quantize_fp8(program& prog, const target& t, const std::vector<parameter_ma
...
@@ -172,9 +170,9 @@ void quantize_fp8(program& prog, const target& t, const std::vector<parameter_ma
{
{
continue
;
continue
;
}
}
else
if
(
not
starts_with
(
ins
->
name
(),
"@"
))
if
(
not
starts_with
(
ins
->
name
(),
"@"
))
{
{
supported_ins_names
.
push_back
(
ins
->
name
());
supported_ins_names
.
insert
(
ins
->
name
());
}
}
}
}
quantize_8bits
(
prog
,
t
,
shape
::
fp8e4m3fnuz_type
,
calibration
,
supported_ins_names
);
quantize_8bits
(
prog
,
t
,
shape
::
fp8e4m3fnuz_type
,
calibration
,
supported_ins_names
);
...
...
src/quantize_8bits.cpp
View file @
7e61114a
...
@@ -90,11 +90,7 @@ void capture_arguments_pass::apply(module& m) const // NOLINT
...
@@ -90,11 +90,7 @@ void capture_arguments_pass::apply(module& m) const // NOLINT
for
(
auto
ins
:
iterator_for
(
m
))
for
(
auto
ins
:
iterator_for
(
m
))
{
{
if
(
not
contains
(
ins_names
,
ins
->
name
()))
if
((
not
contains
(
ins_names
,
ins
->
name
()))
or
(
ins
->
name
()
==
"convert"
))
{
continue
;
}
if
(
ins
->
name
()
==
"convert"
)
{
{
continue
;
continue
;
}
}
...
...
test/quantization.cpp
View file @
7e61114a
...
@@ -654,7 +654,7 @@ TEST_CASE(dot_float)
...
@@ -654,7 +654,7 @@ TEST_CASE(dot_float)
migraphx
::
run_passes
(
p
,
{
migraphx
::
capture_arguments_pass
{{
"dot"
},
{},
&
param_index
}});
migraphx
::
run_passes
(
p
,
{
migraphx
::
capture_arguments_pass
{{
"dot"
},
{},
&
param_index
}});
migraphx
::
run_passes
(
migraphx
::
run_passes
(
p
,
p
,
{
migraphx
::
quantize_8bits_pass
{
migraphx
::
shape
::
type_t
::
int8_type
,
{
"dot"
},
quant_params
},
{
migraphx
::
quantize_8bits_pass
{
migraphx
::
shape
::
type_t
::
int8_type
,
quant_params
},
migraphx
::
dead_code_elimination
{}});
migraphx
::
dead_code_elimination
{}});
auto
qp
=
create_int8_quantized_prog
();
auto
qp
=
create_int8_quantized_prog
();
...
@@ -749,7 +749,7 @@ TEST_CASE(dot_double_2args)
...
@@ -749,7 +749,7 @@ TEST_CASE(dot_double_2args)
migraphx
::
run_passes
(
p
,
{
migraphx
::
capture_arguments_pass
{{
"dot"
},
{},
&
param_index
}});
migraphx
::
run_passes
(
p
,
{
migraphx
::
capture_arguments_pass
{{
"dot"
},
{},
&
param_index
}});
migraphx
::
run_passes
(
migraphx
::
run_passes
(
p
,
p
,
{
migraphx
::
quantize_8bits_pass
{
migraphx
::
shape
::
type_t
::
int8_type
,
{
"dot"
},
quant_params
},
{
migraphx
::
quantize_8bits_pass
{
migraphx
::
shape
::
type_t
::
int8_type
,
quant_params
},
migraphx
::
dead_code_elimination
{}});
migraphx
::
dead_code_elimination
{}});
EXPECT
(
p
==
create_int8_quantized_prog
());
EXPECT
(
p
==
create_int8_quantized_prog
());
...
@@ -823,7 +823,7 @@ TEST_CASE(dot_half_1arg)
...
@@ -823,7 +823,7 @@ TEST_CASE(dot_half_1arg)
migraphx
::
run_passes
(
p
,
{
migraphx
::
capture_arguments_pass
{{
"dot"
},
{},
&
param_index
}});
migraphx
::
run_passes
(
p
,
{
migraphx
::
capture_arguments_pass
{{
"dot"
},
{},
&
param_index
}});
migraphx
::
run_passes
(
migraphx
::
run_passes
(
p
,
p
,
{
migraphx
::
quantize_8bits_pass
{
migraphx
::
shape
::
int8_type
,
{
"dot"
},
quant_params
},
{
migraphx
::
quantize_8bits_pass
{
migraphx
::
shape
::
int8_type
,
quant_params
},
migraphx
::
dead_code_elimination
{}});
migraphx
::
dead_code_elimination
{}});
EXPECT
(
p
==
create_int8_quantized_prog
());
EXPECT
(
p
==
create_int8_quantized_prog
());
...
@@ -881,7 +881,7 @@ TEST_CASE(conv_float)
...
@@ -881,7 +881,7 @@ TEST_CASE(conv_float)
migraphx
::
run_passes
(
p
,
{
migraphx
::
capture_arguments_pass
{{
"convolution"
},
{},
&
param_index
}});
migraphx
::
run_passes
(
p
,
{
migraphx
::
capture_arguments_pass
{{
"convolution"
},
{},
&
param_index
}});
migraphx
::
run_passes
(
p
,
migraphx
::
run_passes
(
p
,
{
migraphx
::
quantize_8bits_pass
{
{
migraphx
::
quantize_8bits_pass
{
migraphx
::
shape
::
type_t
::
int8_type
,
{
"convolution"
},
quant_params
}});
migraphx
::
shape
::
type_t
::
int8_type
,
quant_params
}});
optimize_prog_int8
(
p
);
optimize_prog_int8
(
p
);
auto
qp
=
create_int8_quantized_prog
();
auto
qp
=
create_int8_quantized_prog
();
...
@@ -908,7 +908,7 @@ TEST_CASE(conv_float_throw)
...
@@ -908,7 +908,7 @@ TEST_CASE(conv_float_throw)
test
::
throws
([
&
]
{
test
::
throws
([
&
]
{
migraphx
::
run_passes
(
p
,
migraphx
::
run_passes
(
p
,
{
migraphx
::
quantize_8bits_pass
{
{
migraphx
::
quantize_8bits_pass
{
migraphx
::
shape
::
type_t
::
int8_type
,
{
"add"
},
quant_params
}});
migraphx
::
shape
::
type_t
::
int8_type
,
quant_params
}});
});
});
}
}
...
@@ -961,7 +961,7 @@ TEST_CASE(conv_half)
...
@@ -961,7 +961,7 @@ TEST_CASE(conv_half)
migraphx
::
run_passes
(
p
,
{
migraphx
::
capture_arguments_pass
{{
"convolution"
},
{},
&
param_index
}});
migraphx
::
run_passes
(
p
,
{
migraphx
::
capture_arguments_pass
{{
"convolution"
},
{},
&
param_index
}});
migraphx
::
run_passes
(
p
,
migraphx
::
run_passes
(
p
,
{
migraphx
::
quantize_8bits_pass
{
{
migraphx
::
quantize_8bits_pass
{
migraphx
::
shape
::
type_t
::
int8_type
,
{
"convolution"
},
quant_params
}});
migraphx
::
shape
::
type_t
::
int8_type
,
quant_params
}});
optimize_prog_int8
(
p
);
optimize_prog_int8
(
p
);
auto
qp
=
create_int8_quantized_prog
();
auto
qp
=
create_int8_quantized_prog
();
...
@@ -1242,7 +1242,6 @@ TEST_CASE(int8_subgraph)
...
@@ -1242,7 +1242,6 @@ TEST_CASE(int8_subgraph)
p1
,
{
migraphx
::
capture_arguments_pass
{{
"convolution"
,
"dot"
},
{},
&
param_index
}});
p1
,
{
migraphx
::
capture_arguments_pass
{{
"convolution"
,
"dot"
},
{},
&
param_index
}});
migraphx
::
run_passes
(
p1
,
migraphx
::
run_passes
(
p1
,
{
migraphx
::
quantize_8bits_pass
{
migraphx
::
shape
::
type_t
::
int8_type
,
{
migraphx
::
quantize_8bits_pass
{
migraphx
::
shape
::
type_t
::
int8_type
,
{
"convolution"
,
"dot"
},
quant_params
}});
quant_params
}});
optimize_prog_int8
(
p1
);
optimize_prog_int8
(
p1
);
...
...
tools/api/api.cpp
View file @
7e61114a
...
@@ -232,12 +232,12 @@ void quantize_fp16_with_op_names(program& prog, std::vector<std::string>& names)
...
@@ -232,12 +232,12 @@ void quantize_fp16_with_op_names(program& prog, std::vector<std::string>& names)
struct
quantize_int8_options
struct
quantize_int8_options
{
{
std
::
vector
<
parameter_map
>
calibration
=
{};
std
::
vector
<
parameter_map
>
calibration
=
{};
std
::
vector
<
std
::
string
>
op_names
=
{};
std
::
unordered_set
<
std
::
string
>
op_names
=
{};
};
};
void
add_op_name
(
quantize_int8_options
&
options
,
const
char
*
name
)
void
add_op_name
(
quantize_int8_options
&
options
,
const
char
*
name
)
{
{
options
.
op_names
.
push_back
(
name
);
options
.
op_names
.
insert
(
name
);
}
}
void
add_calibration_data
(
quantize_int8_options
&
options
,
parameter_map
&
data
)
void
add_calibration_data
(
quantize_int8_options
&
options
,
parameter_map
&
data
)
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment