Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
MIGraphX
Commits
22f60c24
Unverified
Commit
22f60c24
authored
Dec 07, 2023
by
Umang Yadav
Committed by
GitHub
Dec 07, 2023
Browse files
Enable simplify qdq to work with FP8 types (#2528)
parent
dfc18d6c
Changes
5
Hide whitespace changes
Inline
Side-by-side
Showing
5 changed files
with
86 additions
and
27 deletions
+86
-27
src/include/migraphx/op/dequantizelinear.hpp
src/include/migraphx/op/dequantizelinear.hpp
+2
-2
src/include/migraphx/op/quantizelinear.hpp
src/include/migraphx/op/quantizelinear.hpp
+4
-4
src/rewrite_quantization.cpp
src/rewrite_quantization.cpp
+4
-4
src/simplify_qdq.cpp
src/simplify_qdq.cpp
+20
-16
test/simplify_qdq_test.cpp
test/simplify_qdq_test.cpp
+56
-1
No files found.
src/include/migraphx/op/dequantizelinear.hpp
View file @
22f60c24
...
...
@@ -72,8 +72,8 @@ struct dequantizelinear
visit_all
(
x
,
x_zero_point
)([
&
](
auto
input
,
auto
zero_pts
)
{
visit_all
(
result
,
x_scale
)([
&
](
auto
output
,
auto
scales
)
{
par_for
(
output_shape
.
elements
(),
[
&
](
auto
i
)
{
output
[
i
]
=
static_cast
<
double
>
(
static_cast
<
int64_t
>
(
input
[
i
])
-
static_cast
<
int64_t
>
(
zero_pts
[
i
]))
*
output
[
i
]
=
static_cast
<
double
>
(
static_cast
<
double
>
(
input
[
i
])
-
static_cast
<
double
>
(
zero_pts
[
i
]))
*
scales
[
i
];
});
});
...
...
src/include/migraphx/op/quantizelinear.hpp
View file @
22f60c24
...
...
@@ -80,10 +80,10 @@ struct quantizelinear
auto
min_value
=
std
::
numeric_limits
<
quant_type
>::
min
();
auto
max_value
=
std
::
numeric_limits
<
quant_type
>::
max
();
par_for
(
output_shape
.
elements
(),
[
&
](
auto
i
)
{
int64_t
quantized
=
static_cast
<
int64_t
>
(
std
::
nearbyint
(
input
[
i
]
/
scales
[
i
]))
+
static_cast
<
int64_t
>
(
zero_pts
[
i
]);
output
[
i
]
=
std
::
max
(
static_cast
<
int64_t
>
(
min_value
),
std
::
min
(
static_cast
<
int64_t
>
(
max_value
),
quantized
));
double
quantized
=
static_cast
<
double
>
(
std
::
nearbyint
(
input
[
i
]
/
scales
[
i
]))
+
static_cast
<
double
>
(
zero_pts
[
i
]);
output
[
i
]
=
std
::
max
(
static_cast
<
double
>
(
min_value
),
std
::
min
(
static_cast
<
double
>
(
max_value
),
quantized
));
});
});
});
...
...
src/rewrite_quantization.cpp
View file @
22f60c24
...
...
@@ -58,8 +58,8 @@ void apply_quantizelinear(module& m, instruction_ref ins)
add_zero_point
=
m
.
insert_instruction
(
ins
,
make_op
(
"add"
),
add_zero_point
,
zero_point
);
}
int64_t
max_quant
=
0
;
int64_t
min_quant
=
0
;
double
max_quant
=
0
;
double
min_quant
=
0
;
ins
->
get_shape
().
visit_type
([
&
](
auto
qt
)
{
max_quant
=
qt
.
max
();
min_quant
=
qt
.
min
();
...
...
@@ -70,8 +70,8 @@ void apply_quantizelinear(module& m, instruction_ref ins)
if
(
enabled
(
MIGRAPHX_ENABLE_CK_WORKAROUNDS
{}))
{
std
::
vector
<
int
>
min_data
(
s
.
elements
(),
min_quant
);
std
::
vector
<
int
>
max_data
(
s
.
elements
(),
max_quant
);
std
::
vector
<
double
>
min_data
(
s
.
elements
(),
min_quant
);
std
::
vector
<
double
>
max_data
(
s
.
elements
(),
max_quant
);
min_arg
=
m
.
add_literal
(
literal
(
s
,
min_data
));
max_arg
=
m
.
add_literal
(
literal
(
s
,
max_data
));
}
...
...
src/simplify_qdq.cpp
View file @
22f60c24
...
...
@@ -82,18 +82,21 @@ struct match_find_quantizable_ops
// Helper function to insert quantized versions of any broadcasts and transpose ops that
// occur between dequantizelinear and the quantized op
static
auto
propagate_quantized_ins
(
module
&
m
,
const
instruction_ref
dqins
,
const
instruction_ref
qop
)
propagate_quantized_ins
(
module
&
m
,
const
instruction_ref
dqins
,
const
instruction_ref
qop
_arg
)
{
auto
qinp
=
dqins
->
inputs
().
front
();
auto
next_ins
=
dqins
;
while
(
next_ins
!=
qop
)
auto
prev_ins
=
qop_arg
;
std
::
vector
<
instruction_ref
>
ins_inbetween
;
// matcher skips continguous, multi/broadcasts and transposes, collect all those
// instructions
while
(
prev_ins
!=
dqins
)
{
if
(
next_ins
->
name
()
!=
"dequantizelinear"
)
{
qinp
=
m
.
insert_instruction
(
qop
,
next_ins
->
get_operator
(),
qinp
);
}
next_ins
=
next_ins
->
outputs
().
front
();
ins_inbetween
.
push_back
(
prev_ins
);
prev_ins
=
prev_ins
->
inputs
().
front
();
}
auto
qinp
=
dqins
->
inputs
().
front
();
for
(
auto
ins
:
reverse_iterator_for
(
ins_inbetween
))
{
qinp
=
m
.
insert_instruction
(
dqins
,
(
*
ins
)
->
get_operator
(),
{
qinp
});
}
return
qinp
;
}
...
...
@@ -124,10 +127,11 @@ struct match_find_quantizable_ops
auto
scale2
=
r
.
instructions
[
"scale2"
];
auto
zp1
=
r
.
instructions
[
"zp1"
];
auto
zp2
=
r
.
instructions
[
"zp2"
];
// Only INT8 type currently supported
if
(
dq1
->
inputs
().
front
()
->
get_shape
().
type
()
!=
migraphx
::
shape
::
int8_type
or
dq2
->
inputs
().
front
()
->
get_shape
().
type
()
!=
migraphx
::
shape
::
int8_type
)
// Only INT8 or FP8 type currently supported
std
::
set
<
migraphx
::
shape
::
type_t
>
supported_types
=
{
migraphx
::
shape
::
fp8e4m3fnuz_type
,
migraphx
::
shape
::
int8_type
};
if
(
not
contains
(
supported_types
,
dq1
->
inputs
().
front
()
->
get_shape
().
type
())
or
not
contains
(
supported_types
,
dq2
->
inputs
().
front
()
->
get_shape
().
type
()))
return
;
// Only symmetric quantization supported (ie. non-zero zero_points not allowed)
...
...
@@ -140,8 +144,8 @@ struct match_find_quantizable_ops
// Propagate q1 and q2 through any broadcasts and transposes before qop
auto
qop_args
=
qop
->
inputs
();
qop_args
.
at
(
0
)
=
propagate_quantized_ins
(
m
,
dq1
,
qop
);
qop_args
.
at
(
1
)
=
propagate_quantized_ins
(
m
,
dq2
,
qop
);
qop_args
.
at
(
0
)
=
propagate_quantized_ins
(
m
,
dq1
,
qop
_args
[
0
]
);
qop_args
.
at
(
1
)
=
propagate_quantized_ins
(
m
,
dq2
,
qop
_args
[
1
]
);
instruction_ref
dq
;
instruction_ref
out_scale
;
instruction_ref
zero_point
;
...
...
test/simplify_qdq_test.cpp
View file @
22f60c24
...
...
@@ -527,6 +527,62 @@ TEST_CASE(dot_add)
EXPECT
(
m1
==
m2
);
}
TEST_CASE
(
dot_add_multiple_dq_use
)
{
migraphx
::
shape
sh1
{
migraphx
::
shape
::
float_type
,
{
32
,
1
}};
migraphx
::
shape
sh2
{
migraphx
::
shape
::
float_type
,
{
32
,
32
}};
migraphx
::
module
m1
;
{
auto
t1
=
m1
.
add_parameter
(
"t1"
,
sh1
);
auto
t2
=
m1
.
add_parameter
(
"t2"
,
sh2
);
auto
scale
=
m1
.
add_literal
(
0.5
f
);
auto
zero
=
m1
.
add_literal
(
std
::
int8_t
{
0
});
auto
q1
=
add_quantize_op
(
m1
,
"quantizelinear"
,
t1
,
scale
,
zero
);
auto
d1
=
add_quantize_op
(
m1
,
"dequantizelinear"
,
q1
,
scale
,
zero
);
auto
d1_t
=
m1
.
add_instruction
(
migraphx
::
make_op
(
"transpose"
,
{{
"permutation"
,
{
1
,
0
}}}),
d1
);
auto
d1_tmb
=
m1
.
add_instruction
(
migraphx
::
make_op
(
"multibroadcast"
,
{{
"out_lens"
,
{
32
,
32
}}}),
d1_t
);
auto
d1_tmbc
=
m1
.
add_instruction
(
migraphx
::
make_op
(
"contiguous"
),
d1_tmb
);
auto
q2
=
add_quantize_op
(
m1
,
"quantizelinear"
,
t2
,
scale
,
zero
);
auto
d2
=
add_quantize_op
(
m1
,
"dequantizelinear"
,
q2
,
scale
,
zero
);
auto
dot_1
=
m1
.
add_instruction
(
migraphx
::
make_op
(
"dot"
),
d1_tmbc
,
d2
);
auto
q3
=
add_quantize_op
(
m1
,
"quantizelinear"
,
dot_1
,
scale
,
zero
);
auto
d3
=
add_quantize_op
(
m1
,
"dequantizelinear"
,
q3
,
scale
,
zero
);
auto
dot_2
=
m1
.
add_instruction
(
migraphx
::
make_op
(
"dot"
),
d3
,
d1
);
auto
add
=
m1
.
add_instruction
(
migraphx
::
make_op
(
"add"
),
{
dot_2
,
d1
});
m1
.
add_return
({
add
});
}
migraphx
::
module
m2
;
{
auto
t1
=
m2
.
add_parameter
(
"t1"
,
sh1
);
auto
t2
=
m2
.
add_parameter
(
"t2"
,
sh2
);
auto
scale
=
m2
.
add_literal
(
0.5
f
);
auto
zero
=
m2
.
add_literal
(
std
::
int8_t
{
0
});
auto
q1
=
add_quantize_op
(
m2
,
"quantizelinear"
,
t1
,
scale
,
zero
);
auto
q1_t
=
m2
.
add_instruction
(
migraphx
::
make_op
(
"transpose"
,
{{
"permutation"
,
{
1
,
0
}}}),
q1
);
auto
q1_tmb
=
m2
.
add_instruction
(
migraphx
::
make_op
(
"multibroadcast"
,
{{
"out_lens"
,
{
32
,
32
}}}),
q1_t
);
auto
q1_tmbc
=
m2
.
add_instruction
(
migraphx
::
make_op
(
"contiguous"
),
q1_tmb
);
auto
q2
=
add_quantize_op
(
m2
,
"quantizelinear"
,
t2
,
scale
,
zero
);
auto
dot_1
=
m2
.
add_instruction
(
migraphx
::
make_op
(
"quant_dot"
),
q1_tmbc
,
q2
);
auto
out_scale
=
add_scale_mul
(
m2
,
scale
,
scale
,
1
,
1
,
dot_1
->
get_shape
().
lens
());
auto
d3
=
add_quantize_op
(
m2
,
"dequantizelinear"
,
dot_1
,
out_scale
);
auto
d3_q
=
add_quantize_op
(
m2
,
"quantizelinear"
,
d3
,
scale
,
zero
);
auto
dot_2
=
m2
.
add_instruction
(
migraphx
::
make_op
(
"quant_dot"
),
d3_q
,
q1
);
auto
out_scale_2
=
add_scale_mul
(
m2
,
scale
,
scale
,
1
,
1
,
dot_2
->
get_shape
().
lens
());
auto
d4
=
add_quantize_op
(
m2
,
"dequantizelinear"
,
dot_2
,
out_scale_2
);
auto
add
=
m2
.
add_instruction
(
migraphx
::
make_op
(
"add"
),
d4
,
t1
);
m2
.
add_return
({
add
});
}
run_pass
(
m1
);
EXPECT
(
m1
==
m2
);
}
TEST_CASE
(
conv
)
{
migraphx
::
shape
s4
{
migraphx
::
shape
::
int8_type
,
{
1280
,
320
,
1
,
1
}};
...
...
@@ -919,7 +975,6 @@ TEST_CASE(mobilenet_snippet)
auto
mod1
=
create_module
();
auto
mod2
=
create_module
();
run_pass
(
mod2
);
auto
match_qdq
=
migraphx
::
match
::
name
(
"dequantizelinear"
)(
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment