Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
MIGraphX
Commits
f787d5bd
Unverified
Commit
f787d5bd
authored
Aug 08, 2023
by
kahmed10
Committed by
GitHub
Aug 08, 2023
Browse files
int8 optimizations (#1973)
* add quant_dot fusion, clip literal opt
parent
a359d2c8
Changes
5
Hide whitespace changes
Inline
Side-by-side
Showing
5 changed files
with
71 additions
and
47 deletions
+71
-47
src/rewrite_quantization.cpp
src/rewrite_quantization.cpp
+5
-7
src/simplify_algebra.cpp
src/simplify_algebra.cpp
+6
-5
test/onnx/onnx_test.cpp
test/onnx/onnx_test.cpp
+38
-30
test/rewrite_quantization_test.cpp
test/rewrite_quantization_test.cpp
+13
-0
test/simplify_algebra_test.cpp
test/simplify_algebra_test.cpp
+9
-5
No files found.
src/rewrite_quantization.cpp
View file @
f787d5bd
...
...
@@ -28,6 +28,7 @@
#include <migraphx/tune_axis.hpp>
#include <migraphx/program.hpp>
#include <migraphx/shape.hpp>
#include <migraphx/common.hpp>
namespace
migraphx
{
inline
namespace
MIGRAPHX_INLINE_NS
{
...
...
@@ -61,13 +62,10 @@ void apply_quantizelinear(module& m, instruction_ref ins)
max_quant
=
qt
.
max
();
min_quant
=
qt
.
min
();
});
auto
s
=
add_zero_point
->
get_shape
();
std
::
vector
<
int
>
min_data
(
s
.
elements
(),
min_quant
);
std
::
vector
<
int
>
max_data
(
s
.
elements
(),
max_quant
);
auto
min_arg
=
m
.
add_literal
(
literal
(
s
,
min_data
));
auto
max_arg
=
m
.
add_literal
(
literal
(
s
,
max_data
));
auto
saturate
=
m
.
insert_instruction
(
ins
,
make_op
(
"clip"
),
add_zero_point
,
min_arg
,
max_arg
);
auto
s
=
add_zero_point
->
get_shape
();
auto
min_arg
=
m
.
add_literal
(
literal
{
shape
{
s
.
type
()},
{
min_quant
}});
auto
max_arg
=
m
.
add_literal
(
literal
{
shape
{
s
.
type
()},
{
max_quant
}});
auto
saturate
=
insert_common_op
(
m
,
ins
,
make_op
(
"clip"
),
{
add_zero_point
,
min_arg
,
max_arg
});
m
.
replace_instruction
(
ins
,
make_op
(
"convert"
,
{{
"target_type"
,
ins
->
get_shape
().
type
()}}),
saturate
);
}
...
...
src/simplify_algebra.cpp
View file @
f787d5bd
...
...
@@ -1095,8 +1095,9 @@ MIGRAPHX_PRED_MATCHER(horiz_conv_dot, instruction_ref ins)
};
};
auto
dots
=
std
::
count_if
(
ins
->
outputs
().
begin
(),
ins
->
outputs
().
end
(),
pred
(
"dot"
));
auto
qdots
=
std
::
count_if
(
ins
->
outputs
().
begin
(),
ins
->
outputs
().
end
(),
pred
(
"quant_dot"
));
auto
convs
=
std
::
count_if
(
ins
->
outputs
().
begin
(),
ins
->
outputs
().
end
(),
pred
(
"convolution"
));
return
(
dots
>=
2
or
convs
>=
2
);
return
(
dots
>=
2
or
convs
>=
2
or
qdots
>=
2
);
}
struct
find_conv_dot_horiz_fusion
...
...
@@ -1110,7 +1111,7 @@ struct find_conv_dot_horiz_fusion
auto
pred
=
[](
auto
i
,
auto
j
)
{
if
(
i
->
get_operator
()
!=
j
->
get_operator
())
return
false
;
if
(
not
contains
({
"dot"
,
"convolution"
},
i
->
name
()))
if
(
not
contains
({
"quant_dot"
,
"dot"
,
"convolution"
},
i
->
name
()))
return
true
;
auto
x
=
i
->
inputs
()[
1
]
->
get_shape
().
lens
();
auto
y
=
j
->
inputs
()[
1
]
->
get_shape
().
lens
();
...
...
@@ -1118,7 +1119,7 @@ struct find_conv_dot_horiz_fusion
return
false
;
// Check that non-axes match
int
axis
=
1
;
if
(
i
->
name
()
==
"dot"
)
if
(
i
->
name
()
==
"dot"
or
i
->
name
()
==
"quant_dot"
)
{
axis
=
x
.
size
()
-
1
;
}
...
...
@@ -1129,7 +1130,7 @@ struct find_conv_dot_horiz_fusion
if
(
std
::
distance
(
start
,
last
)
<
2
)
return
;
auto
&&
name
=
(
*
start
)
->
name
();
if
(
not
contains
({
"dot"
,
"convolution"
},
name
))
if
(
not
contains
({
"quant_dot"
,
"dot"
,
"convolution"
},
name
))
return
;
auto
op
=
(
*
start
)
->
get_operator
();
int
group
=
1
;
...
...
@@ -1144,7 +1145,7 @@ struct find_conv_dot_horiz_fusion
start
,
last
,
std
::
back_inserter
(
args
),
[
&
](
auto
x
)
{
return
x
->
inputs
().
at
(
1
);
});
int
axis
=
1
;
int
concat_axis
=
0
;
if
(
name
==
"dot"
)
if
(
name
==
"dot"
or
name
==
"quant_dot"
)
{
axis
=
int
(
args
.
front
()
->
get_shape
().
lens
().
size
()
-
1
);
concat_axis
=
axis
;
...
...
test/onnx/onnx_test.cpp
View file @
f787d5bd
...
...
@@ -4712,14 +4712,16 @@ TEST_CASE(quantizelinear_test)
auto l1 = mm->add_parameter("1", {migraphx::shape::float_type, {1}});
auto l1_mbcast =
mm->add_instruction(migraphx::make_op("multibroadcast", {{"out_lens", {5}}}), l1);
auto
div
=
mm
->
add_instruction
(
migraphx
::
make_op
(
"div"
),
l0
,
l1_mbcast
);
auto
round
=
mm
->
add_instruction
(
migraphx
::
make_op
(
"round"
),
div
);
auto
s
=
round
->
get_shape
();
std
::
vector
<
int
>
min_data
(
s
.
elements
(),
0
);
std
::
vector
<
int
>
max_data
(
s
.
elements
(),
255
);
auto
min_arg
=
mm
->
add_literal
(
s
,
min_data
);
auto
max_arg
=
mm
->
add_literal
(
s
,
max_data
);
auto
clip
=
mm
->
add_instruction
(
migraphx
::
make_op
(
"clip"
),
round
,
min_arg
,
max_arg
);
auto div = mm->add_instruction(migraphx::make_op("div"), l0, l1_mbcast);
auto round = mm->add_instruction(migraphx::make_op("round"), div);
auto s = round->get_shape();
auto min_arg = mm->add_literal(migraphx::literal{migraphx::shape{s.type()}, {0}});
auto max_arg = mm->add_literal(migraphx::literal{migraphx::shape{s.type()}, {255}});
auto min_mbcast =
mm->add_instruction(migraphx::make_op("multibroadcast", {{"out_lens", s.lens()}}), min_arg);
auto max_mbcast =
mm->add_instruction(migraphx::make_op("multibroadcast", {{"out_lens", s.lens()}}), max_arg);
auto clip = mm->add_instruction(migraphx::make_op("clip"), round, min_mbcast, max_mbcast);
mm->add_instruction(
migraphx::make_op("convert",
{{"target_type", migraphx::to_value(migraphx::shape::uint8_type)}}),
...
...
@@ -4741,14 +4743,16 @@ TEST_CASE(quantizelinear_int32_test)
migraphx::make_op("convert",
{{"target_type", migraphx::to_value(migraphx::shape::float_type)}}),
l0);
auto
div
=
mm
->
add_instruction
(
migraphx
::
make_op
(
"div"
),
l0
,
l1_mbcast
);
auto
round
=
mm
->
add_instruction
(
migraphx
::
make_op
(
"round"
),
div
);
auto
s
=
round
->
get_shape
();
std
::
vector
<
int
>
min_data
(
s
.
elements
(),
0
);
std
::
vector
<
int
>
max_data
(
s
.
elements
(),
255
);
auto
min_arg
=
mm
->
add_literal
(
s
,
min_data
);
auto
max_arg
=
mm
->
add_literal
(
s
,
max_data
);
auto
clip
=
mm
->
add_instruction
(
migraphx
::
make_op
(
"clip"
),
round
,
min_arg
,
max_arg
);
auto div = mm->add_instruction(migraphx::make_op("div"), l0, l1_mbcast);
auto round = mm->add_instruction(migraphx::make_op("round"), div);
auto s = round->get_shape();
auto min_arg = mm->add_literal(migraphx::literal{migraphx::shape{s.type()}, {0}});
auto max_arg = mm->add_literal(migraphx::literal{migraphx::shape{s.type()}, {255}});
auto min_mbcast =
mm->add_instruction(migraphx::make_op("multibroadcast", {{"out_lens", s.lens()}}), min_arg);
auto max_mbcast =
mm->add_instruction(migraphx::make_op("multibroadcast", {{"out_lens", s.lens()}}), max_arg);
auto clip = mm->add_instruction(migraphx::make_op("clip"), round, min_mbcast, max_mbcast);
mm->add_instruction(
migraphx::make_op("convert",
{{"target_type", migraphx::to_value(migraphx::shape::uint8_type)}}),
...
...
@@ -4775,13 +4779,15 @@ TEST_CASE(quantizelinear_zero_point_test)
migraphx::make_op("convert",
{{"target_type", migraphx::to_value(migraphx::shape::float_type)}}),
l2_mbcast);
auto
add
=
mm
->
add_instruction
(
migraphx
::
make_op
(
"add"
),
round
,
l2_mbcast
);
auto
s
=
round
->
get_shape
();
std
::
vector
<
int
>
min_data
(
s
.
elements
(),
-
128
);
std
::
vector
<
int
>
max_data
(
s
.
elements
(),
127
);
auto
min_arg
=
mm
->
add_literal
(
s
,
min_data
);
auto
max_arg
=
mm
->
add_literal
(
s
,
max_data
);
auto
clip
=
mm
->
add_instruction
(
migraphx
::
make_op
(
"clip"
),
add
,
min_arg
,
max_arg
);
auto add = mm->add_instruction(migraphx::make_op("add"), round, l2_mbcast);
auto s = round->get_shape();
auto min_arg = mm->add_literal(migraphx::literal{migraphx::shape{s.type()}, {-128}});
auto max_arg = mm->add_literal(migraphx::literal{migraphx::shape{s.type()}, {127}});
auto min_mbcast =
mm->add_instruction(migraphx::make_op("multibroadcast", {{"out_lens", s.lens()}}), min_arg);
auto max_mbcast =
mm->add_instruction(migraphx::make_op("multibroadcast", {{"out_lens", s.lens()}}), max_arg);
auto clip = mm->add_instruction(migraphx::make_op("clip"), add, min_mbcast, max_mbcast);
mm->add_instruction(
migraphx::make_op("convert",
{{"target_type", migraphx::to_value(migraphx::shape::int8_type)}}),
...
...
@@ -4812,13 +4818,15 @@ migraphx::program make_quantizelinear_axis_prog()
migraphx::make_op("convert",
{{"target_type", migraphx::to_value(migraphx::shape::float_type)}}),
l2_bcast);
auto
add
=
mm
->
add_instruction
(
migraphx
::
make_op
(
"add"
),
round
,
l2_bcast
);
auto
s
=
round
->
get_shape
();
std
::
vector
<
int
>
min_data
(
s
.
elements
(),
-
128
);
std
::
vector
<
int
>
max_data
(
s
.
elements
(),
127
);
auto
min_arg
=
mm
->
add_literal
(
s
,
min_data
);
auto
max_arg
=
mm
->
add_literal
(
s
,
max_data
);
auto
clip
=
mm
->
add_instruction
(
migraphx
::
make_op
(
"clip"
),
add
,
min_arg
,
max_arg
);
auto add = mm->add_instruction(migraphx::make_op("add"), round, l2_bcast);
auto s = round->get_shape();
auto min_arg = mm->add_literal(migraphx::literal{migraphx::shape{s.type()}, {-128}});
auto max_arg = mm->add_literal(migraphx::literal{migraphx::shape{s.type()}, {127}});
auto min_mbcast =
mm->add_instruction(migraphx::make_op("multibroadcast", {{"out_lens", s.lens()}}), min_arg);
auto max_mbcast =
mm->add_instruction(migraphx::make_op("multibroadcast", {{"out_lens", s.lens()}}), max_arg);
auto clip = mm->add_instruction(migraphx::make_op("clip"), add, min_mbcast, max_mbcast);
mm->add_instruction(
migraphx::make_op("convert",
{{"target_type", migraphx::to_value(migraphx::shape::int8_type)}}),
...
...
test/rewrite_quantization_test.cpp
View file @
f787d5bd
...
...
@@ -37,6 +37,17 @@
bool
is_quantizelinear
(
migraphx
::
instruction
&
ins
)
{
return
ins
.
name
()
==
"quantizelinear"
;
}
bool
is_dequantizelinear
(
migraphx
::
instruction
&
ins
)
{
return
ins
.
name
()
==
"dequantizelinear"
;
}
bool
is_clip_scalar
(
migraphx
::
instruction
&
ins
)
{
if
(
ins
.
name
()
==
"clip"
)
{
assert
(
ins
.
inputs
().
size
()
>
1
);
return
(
std
::
all_of
(
ins
.
inputs
().
begin
()
+
1
,
ins
.
inputs
().
end
(),
[](
auto
input
)
{
return
input
->
get_shape
().
scalar
();
}));
}
return
false
;
}
void
run_pass
(
migraphx
::
module
&
m
)
{
migraphx
::
run_passes
(
m
,
{
migraphx
::
rewrite_quantization
{}});
}
...
...
@@ -70,6 +81,8 @@ TEST_CASE(quantizelinear)
EXPECT
(
eval
(
p1
)
==
eval
(
p2
));
EXPECT
(
any_of
(
*
p1
.
get_main_module
(),
&
is_quantizelinear
));
EXPECT
(
none_of
(
*
p2
.
get_main_module
(),
&
is_quantizelinear
));
// ensure clip literals created in quantized program are scalar
EXPECT
(
any_of
(
*
p2
.
get_main_module
(),
&
is_clip_scalar
));
}
TEST_CASE
(
dequantizelinear
)
...
...
test/simplify_algebra_test.cpp
View file @
f787d5bd
...
...
@@ -2189,16 +2189,16 @@ TEST_CASE(simplify_split_between_add)
EXPECT
(
m1
.
sort
()
==
m2
.
sort
());
}
TEST_CASE
(
simplify_dot_horiz
)
void
test_dot_horiz
(
migraphx
::
shape
::
type_t
type
,
const
std
::
string
&
dot_type
)
{
auto
s
=
migraphx
::
shape
{
migraphx
::
shape
::
int32_
type
,
{
3
,
2
,
2
}};
auto
s
=
migraphx
::
shape
{
type
,
{
3
,
2
,
2
}};
migraphx
::
module
m1
;
{
auto
input
=
m1
.
add_parameter
(
"input"
,
s
);
auto
a
=
m1
.
add_literal
(
migraphx
::
generate_literal
(
s
,
0
));
auto
b
=
m1
.
add_literal
(
migraphx
::
generate_literal
(
s
,
1
));
auto
x
=
m1
.
add_instruction
(
migraphx
::
make_op
(
"
dot
"
),
input
,
a
);
auto
y
=
m1
.
add_instruction
(
migraphx
::
make_op
(
"
dot
"
),
input
,
b
);
auto
x
=
m1
.
add_instruction
(
migraphx
::
make_op
(
dot
_type
),
input
,
a
);
auto
y
=
m1
.
add_instruction
(
migraphx
::
make_op
(
dot
_type
),
input
,
b
);
auto
sum
=
m1
.
add_instruction
(
migraphx
::
make_op
(
"add"
),
x
,
y
);
m1
.
add_instruction
(
pass_op
{},
sum
);
}
...
...
@@ -2210,7 +2210,7 @@ TEST_CASE(simplify_dot_horiz)
auto
a
=
m2
.
add_literal
(
migraphx
::
generate_literal
(
s
,
0
));
auto
b
=
m2
.
add_literal
(
migraphx
::
generate_literal
(
s
,
1
));
auto
concat
=
m2
.
add_instruction
(
migraphx
::
make_op
(
"concat"
,
{{
"axis"
,
2
}}),
a
,
b
);
auto
dot
=
m2
.
add_instruction
(
migraphx
::
make_op
(
"
dot
"
),
input
,
concat
);
auto
dot
=
m2
.
add_instruction
(
migraphx
::
make_op
(
dot
_type
),
input
,
concat
);
auto
x
=
m2
.
add_instruction
(
migraphx
::
make_op
(
"slice"
,
{{
"axes"
,
{
2
}},
{
"starts"
,
{
0
}},
{
"ends"
,
{
2
}}}),
dot
);
auto
y
=
m2
.
add_instruction
(
...
...
@@ -2221,6 +2221,10 @@ TEST_CASE(simplify_dot_horiz)
EXPECT
(
m1
.
sort
()
==
m2
.
sort
());
}
TEST_CASE
(
simplify_dot_horiz
)
{
test_dot_horiz
(
migraphx
::
shape
::
int32_type
,
"dot"
);
}
TEST_CASE
(
simplify_quant_dot_horiz
)
{
test_dot_horiz
(
migraphx
::
shape
::
int8_type
,
"quant_dot"
);
}
TEST_CASE
(
simplify_dot_horiz_same_constant
)
{
auto
s
=
migraphx
::
shape
{
migraphx
::
shape
::
int32_type
,
{
3
,
2
,
2
}};
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment