Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
gaoqiong
MIGraphX
Commits
f787d5bd
Unverified
Commit
f787d5bd
authored
Aug 08, 2023
by
kahmed10
Committed by
GitHub
Aug 08, 2023
Browse files
int8 optimizations (#1973)
* add quant_dot fusion, clip literal opt
parent
a359d2c8
Changes
5
Hide whitespace changes
Inline
Side-by-side
Showing
5 changed files
with
71 additions
and
47 deletions
+71
-47
src/rewrite_quantization.cpp
src/rewrite_quantization.cpp
+5
-7
src/simplify_algebra.cpp
src/simplify_algebra.cpp
+6
-5
test/onnx/onnx_test.cpp
test/onnx/onnx_test.cpp
+38
-30
test/rewrite_quantization_test.cpp
test/rewrite_quantization_test.cpp
+13
-0
test/simplify_algebra_test.cpp
test/simplify_algebra_test.cpp
+9
-5
No files found.
src/rewrite_quantization.cpp
View file @
f787d5bd
...
@@ -28,6 +28,7 @@
...
@@ -28,6 +28,7 @@
#include <migraphx/tune_axis.hpp>
#include <migraphx/tune_axis.hpp>
#include <migraphx/program.hpp>
#include <migraphx/program.hpp>
#include <migraphx/shape.hpp>
#include <migraphx/shape.hpp>
#include <migraphx/common.hpp>
namespace
migraphx
{
namespace
migraphx
{
inline
namespace
MIGRAPHX_INLINE_NS
{
inline
namespace
MIGRAPHX_INLINE_NS
{
...
@@ -61,13 +62,10 @@ void apply_quantizelinear(module& m, instruction_ref ins)
...
@@ -61,13 +62,10 @@ void apply_quantizelinear(module& m, instruction_ref ins)
max_quant
=
qt
.
max
();
max_quant
=
qt
.
max
();
min_quant
=
qt
.
min
();
min_quant
=
qt
.
min
();
});
});
auto
s
=
add_zero_point
->
get_shape
();
auto
s
=
add_zero_point
->
get_shape
();
std
::
vector
<
int
>
min_data
(
s
.
elements
(),
min_quant
);
auto
min_arg
=
m
.
add_literal
(
literal
{
shape
{
s
.
type
()},
{
min_quant
}});
std
::
vector
<
int
>
max_data
(
s
.
elements
(),
max_quant
);
auto
max_arg
=
m
.
add_literal
(
literal
{
shape
{
s
.
type
()},
{
max_quant
}});
auto
min_arg
=
m
.
add_literal
(
literal
(
s
,
min_data
));
auto
saturate
=
insert_common_op
(
m
,
ins
,
make_op
(
"clip"
),
{
add_zero_point
,
min_arg
,
max_arg
});
auto
max_arg
=
m
.
add_literal
(
literal
(
s
,
max_data
));
auto
saturate
=
m
.
insert_instruction
(
ins
,
make_op
(
"clip"
),
add_zero_point
,
min_arg
,
max_arg
);
m
.
replace_instruction
(
m
.
replace_instruction
(
ins
,
make_op
(
"convert"
,
{{
"target_type"
,
ins
->
get_shape
().
type
()}}),
saturate
);
ins
,
make_op
(
"convert"
,
{{
"target_type"
,
ins
->
get_shape
().
type
()}}),
saturate
);
}
}
...
...
src/simplify_algebra.cpp
View file @
f787d5bd
...
@@ -1095,8 +1095,9 @@ MIGRAPHX_PRED_MATCHER(horiz_conv_dot, instruction_ref ins)
...
@@ -1095,8 +1095,9 @@ MIGRAPHX_PRED_MATCHER(horiz_conv_dot, instruction_ref ins)
};
};
};
};
auto
dots
=
std
::
count_if
(
ins
->
outputs
().
begin
(),
ins
->
outputs
().
end
(),
pred
(
"dot"
));
auto
dots
=
std
::
count_if
(
ins
->
outputs
().
begin
(),
ins
->
outputs
().
end
(),
pred
(
"dot"
));
auto
qdots
=
std
::
count_if
(
ins
->
outputs
().
begin
(),
ins
->
outputs
().
end
(),
pred
(
"quant_dot"
));
auto
convs
=
std
::
count_if
(
ins
->
outputs
().
begin
(),
ins
->
outputs
().
end
(),
pred
(
"convolution"
));
auto
convs
=
std
::
count_if
(
ins
->
outputs
().
begin
(),
ins
->
outputs
().
end
(),
pred
(
"convolution"
));
return
(
dots
>=
2
or
convs
>=
2
);
return
(
dots
>=
2
or
convs
>=
2
or
qdots
>=
2
);
}
}
struct
find_conv_dot_horiz_fusion
struct
find_conv_dot_horiz_fusion
...
@@ -1110,7 +1111,7 @@ struct find_conv_dot_horiz_fusion
...
@@ -1110,7 +1111,7 @@ struct find_conv_dot_horiz_fusion
auto
pred
=
[](
auto
i
,
auto
j
)
{
auto
pred
=
[](
auto
i
,
auto
j
)
{
if
(
i
->
get_operator
()
!=
j
->
get_operator
())
if
(
i
->
get_operator
()
!=
j
->
get_operator
())
return
false
;
return
false
;
if
(
not
contains
({
"dot"
,
"convolution"
},
i
->
name
()))
if
(
not
contains
({
"quant_dot"
,
"dot"
,
"convolution"
},
i
->
name
()))
return
true
;
return
true
;
auto
x
=
i
->
inputs
()[
1
]
->
get_shape
().
lens
();
auto
x
=
i
->
inputs
()[
1
]
->
get_shape
().
lens
();
auto
y
=
j
->
inputs
()[
1
]
->
get_shape
().
lens
();
auto
y
=
j
->
inputs
()[
1
]
->
get_shape
().
lens
();
...
@@ -1118,7 +1119,7 @@ struct find_conv_dot_horiz_fusion
...
@@ -1118,7 +1119,7 @@ struct find_conv_dot_horiz_fusion
return
false
;
return
false
;
// Check that non-axes match
// Check that non-axes match
int
axis
=
1
;
int
axis
=
1
;
if
(
i
->
name
()
==
"dot"
)
if
(
i
->
name
()
==
"dot"
or
i
->
name
()
==
"quant_dot"
)
{
{
axis
=
x
.
size
()
-
1
;
axis
=
x
.
size
()
-
1
;
}
}
...
@@ -1129,7 +1130,7 @@ struct find_conv_dot_horiz_fusion
...
@@ -1129,7 +1130,7 @@ struct find_conv_dot_horiz_fusion
if
(
std
::
distance
(
start
,
last
)
<
2
)
if
(
std
::
distance
(
start
,
last
)
<
2
)
return
;
return
;
auto
&&
name
=
(
*
start
)
->
name
();
auto
&&
name
=
(
*
start
)
->
name
();
if
(
not
contains
({
"dot"
,
"convolution"
},
name
))
if
(
not
contains
({
"quant_dot"
,
"dot"
,
"convolution"
},
name
))
return
;
return
;
auto
op
=
(
*
start
)
->
get_operator
();
auto
op
=
(
*
start
)
->
get_operator
();
int
group
=
1
;
int
group
=
1
;
...
@@ -1144,7 +1145,7 @@ struct find_conv_dot_horiz_fusion
...
@@ -1144,7 +1145,7 @@ struct find_conv_dot_horiz_fusion
start
,
last
,
std
::
back_inserter
(
args
),
[
&
](
auto
x
)
{
return
x
->
inputs
().
at
(
1
);
});
start
,
last
,
std
::
back_inserter
(
args
),
[
&
](
auto
x
)
{
return
x
->
inputs
().
at
(
1
);
});
int
axis
=
1
;
int
axis
=
1
;
int
concat_axis
=
0
;
int
concat_axis
=
0
;
if
(
name
==
"dot"
)
if
(
name
==
"dot"
or
name
==
"quant_dot"
)
{
{
axis
=
int
(
args
.
front
()
->
get_shape
().
lens
().
size
()
-
1
);
axis
=
int
(
args
.
front
()
->
get_shape
().
lens
().
size
()
-
1
);
concat_axis
=
axis
;
concat_axis
=
axis
;
...
...
test/onnx/onnx_test.cpp
View file @
f787d5bd
...
@@ -4712,14 +4712,16 @@ TEST_CASE(quantizelinear_test)
...
@@ -4712,14 +4712,16 @@ TEST_CASE(quantizelinear_test)
auto l1 = mm->add_parameter("1", {migraphx::shape::float_type, {1}});
auto l1 = mm->add_parameter("1", {migraphx::shape::float_type, {1}});
auto l1_mbcast =
auto l1_mbcast =
mm->add_instruction(migraphx::make_op("multibroadcast", {{"out_lens", {5}}}), l1);
mm->add_instruction(migraphx::make_op("multibroadcast", {{"out_lens", {5}}}), l1);
auto div = mm->add_instruction(migraphx::make_op("div"), l0, l1_mbcast);
auto div = mm->add_instruction(migraphx::make_op("div"), l0, l1_mbcast);
auto round = mm->add_instruction(migraphx::make_op("round"), div);
auto round = mm->add_instruction(migraphx::make_op("round"), div);
auto s = round->get_shape();
auto s = round->get_shape();
std::vector<int> min_data(s.elements(), 0);
auto min_arg = mm->add_literal(migraphx::literal{migraphx::shape{s.type()}, {0}});
std::vector<int> max_data(s.elements(), 255);
auto max_arg = mm->add_literal(migraphx::literal{migraphx::shape{s.type()}, {255}});
auto min_arg = mm->add_literal(s, min_data);
auto min_mbcast =
auto max_arg = mm->add_literal(s, max_data);
mm->add_instruction(migraphx::make_op("multibroadcast", {{"out_lens", s.lens()}}), min_arg);
auto clip = mm->add_instruction(migraphx::make_op("clip"), round, min_arg, max_arg);
auto max_mbcast =
mm->add_instruction(migraphx::make_op("multibroadcast", {{"out_lens", s.lens()}}), max_arg);
auto clip = mm->add_instruction(migraphx::make_op("clip"), round, min_mbcast, max_mbcast);
mm->add_instruction(
mm->add_instruction(
migraphx::make_op("convert",
migraphx::make_op("convert",
{{"target_type", migraphx::to_value(migraphx::shape::uint8_type)}}),
{{"target_type", migraphx::to_value(migraphx::shape::uint8_type)}}),
...
@@ -4741,14 +4743,16 @@ TEST_CASE(quantizelinear_int32_test)
...
@@ -4741,14 +4743,16 @@ TEST_CASE(quantizelinear_int32_test)
migraphx::make_op("convert",
migraphx::make_op("convert",
{{"target_type", migraphx::to_value(migraphx::shape::float_type)}}),
{{"target_type", migraphx::to_value(migraphx::shape::float_type)}}),
l0);
l0);
auto div = mm->add_instruction(migraphx::make_op("div"), l0, l1_mbcast);
auto div = mm->add_instruction(migraphx::make_op("div"), l0, l1_mbcast);
auto round = mm->add_instruction(migraphx::make_op("round"), div);
auto round = mm->add_instruction(migraphx::make_op("round"), div);
auto s = round->get_shape();
auto s = round->get_shape();
std::vector<int> min_data(s.elements(), 0);
auto min_arg = mm->add_literal(migraphx::literal{migraphx::shape{s.type()}, {0}});
std::vector<int> max_data(s.elements(), 255);
auto max_arg = mm->add_literal(migraphx::literal{migraphx::shape{s.type()}, {255}});
auto min_arg = mm->add_literal(s, min_data);
auto min_mbcast =
auto max_arg = mm->add_literal(s, max_data);
mm->add_instruction(migraphx::make_op("multibroadcast", {{"out_lens", s.lens()}}), min_arg);
auto clip = mm->add_instruction(migraphx::make_op("clip"), round, min_arg, max_arg);
auto max_mbcast =
mm->add_instruction(migraphx::make_op("multibroadcast", {{"out_lens", s.lens()}}), max_arg);
auto clip = mm->add_instruction(migraphx::make_op("clip"), round, min_mbcast, max_mbcast);
mm->add_instruction(
mm->add_instruction(
migraphx::make_op("convert",
migraphx::make_op("convert",
{{"target_type", migraphx::to_value(migraphx::shape::uint8_type)}}),
{{"target_type", migraphx::to_value(migraphx::shape::uint8_type)}}),
...
@@ -4775,13 +4779,15 @@ TEST_CASE(quantizelinear_zero_point_test)
...
@@ -4775,13 +4779,15 @@ TEST_CASE(quantizelinear_zero_point_test)
migraphx::make_op("convert",
migraphx::make_op("convert",
{{"target_type", migraphx::to_value(migraphx::shape::float_type)}}),
{{"target_type", migraphx::to_value(migraphx::shape::float_type)}}),
l2_mbcast);
l2_mbcast);
auto add = mm->add_instruction(migraphx::make_op("add"), round, l2_mbcast);
auto add = mm->add_instruction(migraphx::make_op("add"), round, l2_mbcast);
auto s = round->get_shape();
auto s = round->get_shape();
std::vector<int> min_data(s.elements(), -128);
auto min_arg = mm->add_literal(migraphx::literal{migraphx::shape{s.type()}, {-128}});
std::vector<int> max_data(s.elements(), 127);
auto max_arg = mm->add_literal(migraphx::literal{migraphx::shape{s.type()}, {127}});
auto min_arg = mm->add_literal(s, min_data);
auto min_mbcast =
auto max_arg = mm->add_literal(s, max_data);
mm->add_instruction(migraphx::make_op("multibroadcast", {{"out_lens", s.lens()}}), min_arg);
auto clip = mm->add_instruction(migraphx::make_op("clip"), add, min_arg, max_arg);
auto max_mbcast =
mm->add_instruction(migraphx::make_op("multibroadcast", {{"out_lens", s.lens()}}), max_arg);
auto clip = mm->add_instruction(migraphx::make_op("clip"), add, min_mbcast, max_mbcast);
mm->add_instruction(
mm->add_instruction(
migraphx::make_op("convert",
migraphx::make_op("convert",
{{"target_type", migraphx::to_value(migraphx::shape::int8_type)}}),
{{"target_type", migraphx::to_value(migraphx::shape::int8_type)}}),
...
@@ -4812,13 +4818,15 @@ migraphx::program make_quantizelinear_axis_prog()
...
@@ -4812,13 +4818,15 @@ migraphx::program make_quantizelinear_axis_prog()
migraphx::make_op("convert",
migraphx::make_op("convert",
{{"target_type", migraphx::to_value(migraphx::shape::float_type)}}),
{{"target_type", migraphx::to_value(migraphx::shape::float_type)}}),
l2_bcast);
l2_bcast);
auto add = mm->add_instruction(migraphx::make_op("add"), round, l2_bcast);
auto add = mm->add_instruction(migraphx::make_op("add"), round, l2_bcast);
auto s = round->get_shape();
auto s = round->get_shape();
std::vector<int> min_data(s.elements(), -128);
auto min_arg = mm->add_literal(migraphx::literal{migraphx::shape{s.type()}, {-128}});
std::vector<int> max_data(s.elements(), 127);
auto max_arg = mm->add_literal(migraphx::literal{migraphx::shape{s.type()}, {127}});
auto min_arg = mm->add_literal(s, min_data);
auto min_mbcast =
auto max_arg = mm->add_literal(s, max_data);
mm->add_instruction(migraphx::make_op("multibroadcast", {{"out_lens", s.lens()}}), min_arg);
auto clip = mm->add_instruction(migraphx::make_op("clip"), add, min_arg, max_arg);
auto max_mbcast =
mm->add_instruction(migraphx::make_op("multibroadcast", {{"out_lens", s.lens()}}), max_arg);
auto clip = mm->add_instruction(migraphx::make_op("clip"), add, min_mbcast, max_mbcast);
mm->add_instruction(
mm->add_instruction(
migraphx::make_op("convert",
migraphx::make_op("convert",
{{"target_type", migraphx::to_value(migraphx::shape::int8_type)}}),
{{"target_type", migraphx::to_value(migraphx::shape::int8_type)}}),
...
...
test/rewrite_quantization_test.cpp
View file @
f787d5bd
...
@@ -37,6 +37,17 @@
...
@@ -37,6 +37,17 @@
bool
is_quantizelinear
(
migraphx
::
instruction
&
ins
)
{
return
ins
.
name
()
==
"quantizelinear"
;
}
bool
is_quantizelinear
(
migraphx
::
instruction
&
ins
)
{
return
ins
.
name
()
==
"quantizelinear"
;
}
bool
is_dequantizelinear
(
migraphx
::
instruction
&
ins
)
{
return
ins
.
name
()
==
"dequantizelinear"
;
}
bool
is_dequantizelinear
(
migraphx
::
instruction
&
ins
)
{
return
ins
.
name
()
==
"dequantizelinear"
;
}
bool
is_clip_scalar
(
migraphx
::
instruction
&
ins
)
{
if
(
ins
.
name
()
==
"clip"
)
{
assert
(
ins
.
inputs
().
size
()
>
1
);
return
(
std
::
all_of
(
ins
.
inputs
().
begin
()
+
1
,
ins
.
inputs
().
end
(),
[](
auto
input
)
{
return
input
->
get_shape
().
scalar
();
}));
}
return
false
;
}
void
run_pass
(
migraphx
::
module
&
m
)
{
migraphx
::
run_passes
(
m
,
{
migraphx
::
rewrite_quantization
{}});
}
void
run_pass
(
migraphx
::
module
&
m
)
{
migraphx
::
run_passes
(
m
,
{
migraphx
::
rewrite_quantization
{}});
}
...
@@ -70,6 +81,8 @@ TEST_CASE(quantizelinear)
...
@@ -70,6 +81,8 @@ TEST_CASE(quantizelinear)
EXPECT
(
eval
(
p1
)
==
eval
(
p2
));
EXPECT
(
eval
(
p1
)
==
eval
(
p2
));
EXPECT
(
any_of
(
*
p1
.
get_main_module
(),
&
is_quantizelinear
));
EXPECT
(
any_of
(
*
p1
.
get_main_module
(),
&
is_quantizelinear
));
EXPECT
(
none_of
(
*
p2
.
get_main_module
(),
&
is_quantizelinear
));
EXPECT
(
none_of
(
*
p2
.
get_main_module
(),
&
is_quantizelinear
));
// ensure clip literals created in quantized program are scalar
EXPECT
(
any_of
(
*
p2
.
get_main_module
(),
&
is_clip_scalar
));
}
}
TEST_CASE
(
dequantizelinear
)
TEST_CASE
(
dequantizelinear
)
...
...
test/simplify_algebra_test.cpp
View file @
f787d5bd
...
@@ -2189,16 +2189,16 @@ TEST_CASE(simplify_split_between_add)
...
@@ -2189,16 +2189,16 @@ TEST_CASE(simplify_split_between_add)
EXPECT
(
m1
.
sort
()
==
m2
.
sort
());
EXPECT
(
m1
.
sort
()
==
m2
.
sort
());
}
}
TEST_CASE
(
simplify_dot_horiz
)
void
test_dot_horiz
(
migraphx
::
shape
::
type_t
type
,
const
std
::
string
&
dot_type
)
{
{
auto
s
=
migraphx
::
shape
{
migraphx
::
shape
::
int32_
type
,
{
3
,
2
,
2
}};
auto
s
=
migraphx
::
shape
{
type
,
{
3
,
2
,
2
}};
migraphx
::
module
m1
;
migraphx
::
module
m1
;
{
{
auto
input
=
m1
.
add_parameter
(
"input"
,
s
);
auto
input
=
m1
.
add_parameter
(
"input"
,
s
);
auto
a
=
m1
.
add_literal
(
migraphx
::
generate_literal
(
s
,
0
));
auto
a
=
m1
.
add_literal
(
migraphx
::
generate_literal
(
s
,
0
));
auto
b
=
m1
.
add_literal
(
migraphx
::
generate_literal
(
s
,
1
));
auto
b
=
m1
.
add_literal
(
migraphx
::
generate_literal
(
s
,
1
));
auto
x
=
m1
.
add_instruction
(
migraphx
::
make_op
(
"
dot
"
),
input
,
a
);
auto
x
=
m1
.
add_instruction
(
migraphx
::
make_op
(
dot
_type
),
input
,
a
);
auto
y
=
m1
.
add_instruction
(
migraphx
::
make_op
(
"
dot
"
),
input
,
b
);
auto
y
=
m1
.
add_instruction
(
migraphx
::
make_op
(
dot
_type
),
input
,
b
);
auto
sum
=
m1
.
add_instruction
(
migraphx
::
make_op
(
"add"
),
x
,
y
);
auto
sum
=
m1
.
add_instruction
(
migraphx
::
make_op
(
"add"
),
x
,
y
);
m1
.
add_instruction
(
pass_op
{},
sum
);
m1
.
add_instruction
(
pass_op
{},
sum
);
}
}
...
@@ -2210,7 +2210,7 @@ TEST_CASE(simplify_dot_horiz)
...
@@ -2210,7 +2210,7 @@ TEST_CASE(simplify_dot_horiz)
auto
a
=
m2
.
add_literal
(
migraphx
::
generate_literal
(
s
,
0
));
auto
a
=
m2
.
add_literal
(
migraphx
::
generate_literal
(
s
,
0
));
auto
b
=
m2
.
add_literal
(
migraphx
::
generate_literal
(
s
,
1
));
auto
b
=
m2
.
add_literal
(
migraphx
::
generate_literal
(
s
,
1
));
auto
concat
=
m2
.
add_instruction
(
migraphx
::
make_op
(
"concat"
,
{{
"axis"
,
2
}}),
a
,
b
);
auto
concat
=
m2
.
add_instruction
(
migraphx
::
make_op
(
"concat"
,
{{
"axis"
,
2
}}),
a
,
b
);
auto
dot
=
m2
.
add_instruction
(
migraphx
::
make_op
(
"
dot
"
),
input
,
concat
);
auto
dot
=
m2
.
add_instruction
(
migraphx
::
make_op
(
dot
_type
),
input
,
concat
);
auto
x
=
m2
.
add_instruction
(
auto
x
=
m2
.
add_instruction
(
migraphx
::
make_op
(
"slice"
,
{{
"axes"
,
{
2
}},
{
"starts"
,
{
0
}},
{
"ends"
,
{
2
}}}),
dot
);
migraphx
::
make_op
(
"slice"
,
{{
"axes"
,
{
2
}},
{
"starts"
,
{
0
}},
{
"ends"
,
{
2
}}}),
dot
);
auto
y
=
m2
.
add_instruction
(
auto
y
=
m2
.
add_instruction
(
...
@@ -2221,6 +2221,10 @@ TEST_CASE(simplify_dot_horiz)
...
@@ -2221,6 +2221,10 @@ TEST_CASE(simplify_dot_horiz)
EXPECT
(
m1
.
sort
()
==
m2
.
sort
());
EXPECT
(
m1
.
sort
()
==
m2
.
sort
());
}
}
TEST_CASE
(
simplify_dot_horiz
)
{
test_dot_horiz
(
migraphx
::
shape
::
int32_type
,
"dot"
);
}
TEST_CASE
(
simplify_quant_dot_horiz
)
{
test_dot_horiz
(
migraphx
::
shape
::
int8_type
,
"quant_dot"
);
}
TEST_CASE
(
simplify_dot_horiz_same_constant
)
TEST_CASE
(
simplify_dot_horiz_same_constant
)
{
{
auto
s
=
migraphx
::
shape
{
migraphx
::
shape
::
int32_type
,
{
3
,
2
,
2
}};
auto
s
=
migraphx
::
shape
{
migraphx
::
shape
::
int32_type
,
{
3
,
2
,
2
}};
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment