Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
gaoqiong
MIGraphX
Commits
f787d5bd
Unverified
Commit
f787d5bd
authored
Aug 08, 2023
by
kahmed10
Committed by
GitHub
Aug 08, 2023
Browse files
int8 optimizations (#1973)
* add quant_dot fusion, clip literal opt
parent
a359d2c8
Changes
5
Hide whitespace changes
Inline
Side-by-side
Showing
5 changed files
with
71 additions
and
47 deletions
+71
-47
src/rewrite_quantization.cpp
src/rewrite_quantization.cpp
+5
-7
src/simplify_algebra.cpp
src/simplify_algebra.cpp
+6
-5
test/onnx/onnx_test.cpp
test/onnx/onnx_test.cpp
+38
-30
test/rewrite_quantization_test.cpp
test/rewrite_quantization_test.cpp
+13
-0
test/simplify_algebra_test.cpp
test/simplify_algebra_test.cpp
+9
-5
No files found.
src/rewrite_quantization.cpp
View file @
f787d5bd
...
...
@@ -28,6 +28,7 @@
#include <migraphx/tune_axis.hpp>
#include <migraphx/program.hpp>
#include <migraphx/shape.hpp>
#include <migraphx/common.hpp>
namespace
migraphx
{
inline
namespace
MIGRAPHX_INLINE_NS
{
...
...
@@ -61,13 +62,10 @@ void apply_quantizelinear(module& m, instruction_ref ins)
max_quant
=
qt
.
max
();
min_quant
=
qt
.
min
();
});
auto
s
=
add_zero_point
->
get_shape
();
std
::
vector
<
int
>
min_data
(
s
.
elements
(),
min_quant
);
std
::
vector
<
int
>
max_data
(
s
.
elements
(),
max_quant
);
auto
min_arg
=
m
.
add_literal
(
literal
(
s
,
min_data
));
auto
max_arg
=
m
.
add_literal
(
literal
(
s
,
max_data
));
auto
saturate
=
m
.
insert_instruction
(
ins
,
make_op
(
"clip"
),
add_zero_point
,
min_arg
,
max_arg
);
auto
s
=
add_zero_point
->
get_shape
();
auto
min_arg
=
m
.
add_literal
(
literal
{
shape
{
s
.
type
()},
{
min_quant
}});
auto
max_arg
=
m
.
add_literal
(
literal
{
shape
{
s
.
type
()},
{
max_quant
}});
auto
saturate
=
insert_common_op
(
m
,
ins
,
make_op
(
"clip"
),
{
add_zero_point
,
min_arg
,
max_arg
});
m
.
replace_instruction
(
ins
,
make_op
(
"convert"
,
{{
"target_type"
,
ins
->
get_shape
().
type
()}}),
saturate
);
}
...
...
src/simplify_algebra.cpp
View file @
f787d5bd
...
...
@@ -1095,8 +1095,9 @@ MIGRAPHX_PRED_MATCHER(horiz_conv_dot, instruction_ref ins)
};
};
auto
dots
=
std
::
count_if
(
ins
->
outputs
().
begin
(),
ins
->
outputs
().
end
(),
pred
(
"dot"
));
auto
qdots
=
std
::
count_if
(
ins
->
outputs
().
begin
(),
ins
->
outputs
().
end
(),
pred
(
"quant_dot"
));
auto
convs
=
std
::
count_if
(
ins
->
outputs
().
begin
(),
ins
->
outputs
().
end
(),
pred
(
"convolution"
));
return
(
dots
>=
2
or
convs
>=
2
);
return
(
dots
>=
2
or
convs
>=
2
or
qdots
>=
2
);
}
struct
find_conv_dot_horiz_fusion
...
...
@@ -1110,7 +1111,7 @@ struct find_conv_dot_horiz_fusion
auto
pred
=
[](
auto
i
,
auto
j
)
{
if
(
i
->
get_operator
()
!=
j
->
get_operator
())
return
false
;
if
(
not
contains
({
"dot"
,
"convolution"
},
i
->
name
()))
if
(
not
contains
({
"quant_dot"
,
"dot"
,
"convolution"
},
i
->
name
()))
return
true
;
auto
x
=
i
->
inputs
()[
1
]
->
get_shape
().
lens
();
auto
y
=
j
->
inputs
()[
1
]
->
get_shape
().
lens
();
...
...
@@ -1118,7 +1119,7 @@ struct find_conv_dot_horiz_fusion
return
false
;
// Check that non-axes match
int
axis
=
1
;
if
(
i
->
name
()
==
"dot"
)
if
(
i
->
name
()
==
"dot"
or
i
->
name
()
==
"quant_dot"
)
{
axis
=
x
.
size
()
-
1
;
}
...
...
@@ -1129,7 +1130,7 @@ struct find_conv_dot_horiz_fusion
if
(
std
::
distance
(
start
,
last
)
<
2
)
return
;
auto
&&
name
=
(
*
start
)
->
name
();
if
(
not
contains
({
"dot"
,
"convolution"
},
name
))
if
(
not
contains
({
"quant_dot"
,
"dot"
,
"convolution"
},
name
))
return
;
auto
op
=
(
*
start
)
->
get_operator
();
int
group
=
1
;
...
...
@@ -1144,7 +1145,7 @@ struct find_conv_dot_horiz_fusion
start
,
last
,
std
::
back_inserter
(
args
),
[
&
](
auto
x
)
{
return
x
->
inputs
().
at
(
1
);
});
int
axis
=
1
;
int
concat_axis
=
0
;
if
(
name
==
"dot"
)
if
(
name
==
"dot"
or
name
==
"quant_dot"
)
{
axis
=
int
(
args
.
front
()
->
get_shape
().
lens
().
size
()
-
1
);
concat_axis
=
axis
;
...
...
test/onnx/onnx_test.cpp
View file @
f787d5bd
...
...
@@ -4712,14 +4712,16 @@ TEST_CASE(quantizelinear_test)
auto
l1
=
mm
->
add_parameter
(
"1"
,
{
migraphx
::
shape
::
float_type
,
{
1
}});
auto
l1_mbcast
=
mm
->
add_instruction
(
migraphx
::
make_op
(
"multibroadcast"
,
{{
"out_lens"
,
{
5
}}}),
l1
);
auto
div
=
mm
->
add_instruction
(
migraphx
::
make_op
(
"div"
),
l0
,
l1_mbcast
);
auto
round
=
mm
->
add_instruction
(
migraphx
::
make_op
(
"round"
),
div
);
auto
s
=
round
->
get_shape
();
std
::
vector
<
int
>
min_data
(
s
.
elements
(),
0
);
std
::
vector
<
int
>
max_data
(
s
.
elements
(),
255
);
auto
min_arg
=
mm
->
add_literal
(
s
,
min_data
);
auto
max_arg
=
mm
->
add_literal
(
s
,
max_data
);
auto
clip
=
mm
->
add_instruction
(
migraphx
::
make_op
(
"clip"
),
round
,
min_arg
,
max_arg
);
auto
div
=
mm
->
add_instruction
(
migraphx
::
make_op
(
"div"
),
l0
,
l1_mbcast
);
auto
round
=
mm
->
add_instruction
(
migraphx
::
make_op
(
"round"
),
div
);
auto
s
=
round
->
get_shape
();
auto
min_arg
=
mm
->
add_literal
(
migraphx
::
literal
{
migraphx
::
shape
{
s
.
type
()},
{
0
}});
auto
max_arg
=
mm
->
add_literal
(
migraphx
::
literal
{
migraphx
::
shape
{
s
.
type
()},
{
255
}});
auto
min_mbcast
=
mm
->
add_instruction
(
migraphx
::
make_op
(
"multibroadcast"
,
{{
"out_lens"
,
s
.
lens
()}}),
min_arg
);
auto
max_mbcast
=
mm
->
add_instruction
(
migraphx
::
make_op
(
"multibroadcast"
,
{{
"out_lens"
,
s
.
lens
()}}),
max_arg
);
auto
clip
=
mm
->
add_instruction
(
migraphx
::
make_op
(
"clip"
),
round
,
min_mbcast
,
max_mbcast
);
mm
->
add_instruction
(
migraphx
::
make_op
(
"convert"
,
{{
"target_type"
,
migraphx
::
to_value
(
migraphx
::
shape
::
uint8_type
)}}),
...
...
@@ -4741,14 +4743,16 @@ TEST_CASE(quantizelinear_int32_test)
migraphx
::
make_op
(
"convert"
,
{{
"target_type"
,
migraphx
::
to_value
(
migraphx
::
shape
::
float_type
)}}),
l0
);
auto
div
=
mm
->
add_instruction
(
migraphx
::
make_op
(
"div"
),
l0
,
l1_mbcast
);
auto
round
=
mm
->
add_instruction
(
migraphx
::
make_op
(
"round"
),
div
);
auto
s
=
round
->
get_shape
();
std
::
vector
<
int
>
min_data
(
s
.
elements
(),
0
);
std
::
vector
<
int
>
max_data
(
s
.
elements
(),
255
);
auto
min_arg
=
mm
->
add_literal
(
s
,
min_data
);
auto
max_arg
=
mm
->
add_literal
(
s
,
max_data
);
auto
clip
=
mm
->
add_instruction
(
migraphx
::
make_op
(
"clip"
),
round
,
min_arg
,
max_arg
);
auto
div
=
mm
->
add_instruction
(
migraphx
::
make_op
(
"div"
),
l0
,
l1_mbcast
);
auto
round
=
mm
->
add_instruction
(
migraphx
::
make_op
(
"round"
),
div
);
auto
s
=
round
->
get_shape
();
auto
min_arg
=
mm
->
add_literal
(
migraphx
::
literal
{
migraphx
::
shape
{
s
.
type
()},
{
0
}});
auto
max_arg
=
mm
->
add_literal
(
migraphx
::
literal
{
migraphx
::
shape
{
s
.
type
()},
{
255
}});
auto
min_mbcast
=
mm
->
add_instruction
(
migraphx
::
make_op
(
"multibroadcast"
,
{{
"out_lens"
,
s
.
lens
()}}),
min_arg
);
auto
max_mbcast
=
mm
->
add_instruction
(
migraphx
::
make_op
(
"multibroadcast"
,
{{
"out_lens"
,
s
.
lens
()}}),
max_arg
);
auto
clip
=
mm
->
add_instruction
(
migraphx
::
make_op
(
"clip"
),
round
,
min_mbcast
,
max_mbcast
);
mm
->
add_instruction
(
migraphx
::
make_op
(
"convert"
,
{{
"target_type"
,
migraphx
::
to_value
(
migraphx
::
shape
::
uint8_type
)}}),
...
...
@@ -4775,13 +4779,15 @@ TEST_CASE(quantizelinear_zero_point_test)
migraphx
::
make_op
(
"convert"
,
{{
"target_type"
,
migraphx
::
to_value
(
migraphx
::
shape
::
float_type
)}}),
l2_mbcast
);
auto
add
=
mm
->
add_instruction
(
migraphx
::
make_op
(
"add"
),
round
,
l2_mbcast
);
auto
s
=
round
->
get_shape
();
std
::
vector
<
int
>
min_data
(
s
.
elements
(),
-
128
);
std
::
vector
<
int
>
max_data
(
s
.
elements
(),
127
);
auto
min_arg
=
mm
->
add_literal
(
s
,
min_data
);
auto
max_arg
=
mm
->
add_literal
(
s
,
max_data
);
auto
clip
=
mm
->
add_instruction
(
migraphx
::
make_op
(
"clip"
),
add
,
min_arg
,
max_arg
);
auto
add
=
mm
->
add_instruction
(
migraphx
::
make_op
(
"add"
),
round
,
l2_mbcast
);
auto
s
=
round
->
get_shape
();
auto
min_arg
=
mm
->
add_literal
(
migraphx
::
literal
{
migraphx
::
shape
{
s
.
type
()},
{
-
128
}});
auto
max_arg
=
mm
->
add_literal
(
migraphx
::
literal
{
migraphx
::
shape
{
s
.
type
()},
{
127
}});
auto
min_mbcast
=
mm
->
add_instruction
(
migraphx
::
make_op
(
"multibroadcast"
,
{{
"out_lens"
,
s
.
lens
()}}),
min_arg
);
auto
max_mbcast
=
mm
->
add_instruction
(
migraphx
::
make_op
(
"multibroadcast"
,
{{
"out_lens"
,
s
.
lens
()}}),
max_arg
);
auto
clip
=
mm
->
add_instruction
(
migraphx
::
make_op
(
"clip"
),
add
,
min_mbcast
,
max_mbcast
);
mm
->
add_instruction
(
migraphx
::
make_op
(
"convert"
,
{{
"target_type"
,
migraphx
::
to_value
(
migraphx
::
shape
::
int8_type
)}}),
...
...
@@ -4812,13 +4818,15 @@ migraphx::program make_quantizelinear_axis_prog()
migraphx
::
make_op
(
"convert"
,
{{
"target_type"
,
migraphx
::
to_value
(
migraphx
::
shape
::
float_type
)}}),
l2_bcast
);
auto
add
=
mm
->
add_instruction
(
migraphx
::
make_op
(
"add"
),
round
,
l2_bcast
);
auto
s
=
round
->
get_shape
();
std
::
vector
<
int
>
min_data
(
s
.
elements
(),
-
128
);
std
::
vector
<
int
>
max_data
(
s
.
elements
(),
127
);
auto
min_arg
=
mm
->
add_literal
(
s
,
min_data
);
auto
max_arg
=
mm
->
add_literal
(
s
,
max_data
);
auto
clip
=
mm
->
add_instruction
(
migraphx
::
make_op
(
"clip"
),
add
,
min_arg
,
max_arg
);
auto
add
=
mm
->
add_instruction
(
migraphx
::
make_op
(
"add"
),
round
,
l2_bcast
);
auto
s
=
round
->
get_shape
();
auto
min_arg
=
mm
->
add_literal
(
migraphx
::
literal
{
migraphx
::
shape
{
s
.
type
()},
{
-
128
}});
auto
max_arg
=
mm
->
add_literal
(
migraphx
::
literal
{
migraphx
::
shape
{
s
.
type
()},
{
127
}});
auto
min_mbcast
=
mm
->
add_instruction
(
migraphx
::
make_op
(
"multibroadcast"
,
{{
"out_lens"
,
s
.
lens
()}}),
min_arg
);
auto
max_mbcast
=
mm
->
add_instruction
(
migraphx
::
make_op
(
"multibroadcast"
,
{{
"out_lens"
,
s
.
lens
()}}),
max_arg
);
auto
clip
=
mm
->
add_instruction
(
migraphx
::
make_op
(
"clip"
),
add
,
min_mbcast
,
max_mbcast
);
mm
->
add_instruction
(
migraphx
::
make_op
(
"convert"
,
{{
"target_type"
,
migraphx
::
to_value
(
migraphx
::
shape
::
int8_type
)}}),
...
...
test/rewrite_quantization_test.cpp
View file @
f787d5bd
...
...
@@ -37,6 +37,17 @@
// Predicate used with any_of/none_of over a module: true when the
// instruction's name is exactly "quantizelinear".
bool is_quantizelinear(migraphx::instruction& ins)
{
    const auto& op_name = ins.name();
    return op_name == "quantizelinear";
}
// Predicate: true when the instruction's name is exactly "dequantizelinear".
bool is_dequantizelinear(migraphx::instruction& ins)
{
    const auto& op_name = ins.name();
    return op_name == "dequantizelinear";
}
// Predicate: true when `ins` is a "clip" instruction and every input after
// the first one reports a scalar shape. Any instruction with a different
// name yields false.
bool is_clip_scalar(migraphx::instruction& ins)
{
    if(ins.name() != "clip")
        return false;

    // The first input is skipped below, so at least one more must follow it.
    assert(ins.inputs().size() > 1);

    const auto& args = ins.inputs();
    return std::all_of(
        args.begin() + 1, args.end(), [](auto arg) { return arg->get_shape().scalar(); });
}
// Runs a pass list containing only rewrite_quantization over module `m`.
void run_pass(migraphx::module& m)
{
    using migraphx::rewrite_quantization;
    migraphx::run_passes(m, {rewrite_quantization{}});
}
...
...
@@ -70,6 +81,8 @@ TEST_CASE(quantizelinear)
EXPECT
(
eval
(
p1
)
==
eval
(
p2
));
EXPECT
(
any_of
(
*
p1
.
get_main_module
(),
&
is_quantizelinear
));
EXPECT
(
none_of
(
*
p2
.
get_main_module
(),
&
is_quantizelinear
));
// ensure clip literals created in quantized program are scalar
EXPECT
(
any_of
(
*
p2
.
get_main_module
(),
&
is_clip_scalar
));
}
TEST_CASE
(
dequantizelinear
)
...
...
test/simplify_algebra_test.cpp
View file @
f787d5bd
...
...
@@ -2189,16 +2189,16 @@ TEST_CASE(simplify_split_between_add)
EXPECT
(
m1
.
sort
()
==
m2
.
sort
());
}
TEST_CASE
(
simplify_dot_horiz
)
void
test_dot_horiz
(
migraphx
::
shape
::
type_t
type
,
const
std
::
string
&
dot_type
)
{
auto
s
=
migraphx
::
shape
{
migraphx
::
shape
::
int32_
type
,
{
3
,
2
,
2
}};
auto
s
=
migraphx
::
shape
{
type
,
{
3
,
2
,
2
}};
migraphx
::
module
m1
;
{
auto
input
=
m1
.
add_parameter
(
"input"
,
s
);
auto
a
=
m1
.
add_literal
(
migraphx
::
generate_literal
(
s
,
0
));
auto
b
=
m1
.
add_literal
(
migraphx
::
generate_literal
(
s
,
1
));
auto
x
=
m1
.
add_instruction
(
migraphx
::
make_op
(
"
dot
"
),
input
,
a
);
auto
y
=
m1
.
add_instruction
(
migraphx
::
make_op
(
"
dot
"
),
input
,
b
);
auto
x
=
m1
.
add_instruction
(
migraphx
::
make_op
(
dot
_type
),
input
,
a
);
auto
y
=
m1
.
add_instruction
(
migraphx
::
make_op
(
dot
_type
),
input
,
b
);
auto
sum
=
m1
.
add_instruction
(
migraphx
::
make_op
(
"add"
),
x
,
y
);
m1
.
add_instruction
(
pass_op
{},
sum
);
}
...
...
@@ -2210,7 +2210,7 @@ TEST_CASE(simplify_dot_horiz)
auto
a
=
m2
.
add_literal
(
migraphx
::
generate_literal
(
s
,
0
));
auto
b
=
m2
.
add_literal
(
migraphx
::
generate_literal
(
s
,
1
));
auto
concat
=
m2
.
add_instruction
(
migraphx
::
make_op
(
"concat"
,
{{
"axis"
,
2
}}),
a
,
b
);
auto
dot
=
m2
.
add_instruction
(
migraphx
::
make_op
(
"
dot
"
),
input
,
concat
);
auto
dot
=
m2
.
add_instruction
(
migraphx
::
make_op
(
dot
_type
),
input
,
concat
);
auto
x
=
m2
.
add_instruction
(
migraphx
::
make_op
(
"slice"
,
{{
"axes"
,
{
2
}},
{
"starts"
,
{
0
}},
{
"ends"
,
{
2
}}}),
dot
);
auto
y
=
m2
.
add_instruction
(
...
...
@@ -2221,6 +2221,10 @@ TEST_CASE(simplify_dot_horiz)
EXPECT
(
m1
.
sort
()
==
m2
.
sort
());
}
// Drives the shared test_dot_horiz helper with int32 operands and the
// "dot" operator.
TEST_CASE(simplify_dot_horiz)
{
    const auto operand_type = migraphx::shape::int32_type;
    test_dot_horiz(operand_type, "dot");
}
// Drives the shared test_dot_horiz helper with int8 operands and the
// "quant_dot" operator.
TEST_CASE(simplify_quant_dot_horiz)
{
    const auto operand_type = migraphx::shape::int8_type;
    test_dot_horiz(operand_type, "quant_dot");
}
TEST_CASE
(
simplify_dot_horiz_same_constant
)
{
auto
s
=
migraphx
::
shape
{
migraphx
::
shape
::
int32_type
,
{
3
,
2
,
2
}};
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment