Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
MIGraphX
Commits
22f60c24
Unverified
Commit
22f60c24
authored
Dec 07, 2023
by
Umang Yadav
Committed by
GitHub
Dec 07, 2023
Browse files
Enable simplify qdq to work with FP8 types (#2528)
parent
dfc18d6c
Changes
5
Hide whitespace changes
Inline
Side-by-side
Showing
5 changed files
with
86 additions
and
27 deletions
+86
-27
src/include/migraphx/op/dequantizelinear.hpp
src/include/migraphx/op/dequantizelinear.hpp
+2
-2
src/include/migraphx/op/quantizelinear.hpp
src/include/migraphx/op/quantizelinear.hpp
+4
-4
src/rewrite_quantization.cpp
src/rewrite_quantization.cpp
+4
-4
src/simplify_qdq.cpp
src/simplify_qdq.cpp
+20
-16
test/simplify_qdq_test.cpp
test/simplify_qdq_test.cpp
+56
-1
No files found.
src/include/migraphx/op/dequantizelinear.hpp
View file @
22f60c24
...
@@ -72,8 +72,8 @@ struct dequantizelinear
...
@@ -72,8 +72,8 @@ struct dequantizelinear
visit_all
(
x
,
x_zero_point
)([
&
](
auto
input
,
auto
zero_pts
)
{
visit_all
(
x
,
x_zero_point
)([
&
](
auto
input
,
auto
zero_pts
)
{
visit_all
(
result
,
x_scale
)([
&
](
auto
output
,
auto
scales
)
{
visit_all
(
result
,
x_scale
)([
&
](
auto
output
,
auto
scales
)
{
par_for
(
output_shape
.
elements
(),
[
&
](
auto
i
)
{
par_for
(
output_shape
.
elements
(),
[
&
](
auto
i
)
{
output
[
i
]
=
static_cast
<
double
>
(
static_cast
<
int64_t
>
(
input
[
i
])
-
output
[
i
]
=
static_cast
<
double
>
(
static_cast
<
double
>
(
input
[
i
])
-
static_cast
<
int64_t
>
(
zero_pts
[
i
]))
*
static_cast
<
double
>
(
zero_pts
[
i
]))
*
scales
[
i
];
scales
[
i
];
});
});
});
});
...
...
src/include/migraphx/op/quantizelinear.hpp
View file @
22f60c24
...
@@ -80,10 +80,10 @@ struct quantizelinear
...
@@ -80,10 +80,10 @@ struct quantizelinear
auto
min_value
=
std
::
numeric_limits
<
quant_type
>::
min
();
auto
min_value
=
std
::
numeric_limits
<
quant_type
>::
min
();
auto
max_value
=
std
::
numeric_limits
<
quant_type
>::
max
();
auto
max_value
=
std
::
numeric_limits
<
quant_type
>::
max
();
par_for
(
output_shape
.
elements
(),
[
&
](
auto
i
)
{
par_for
(
output_shape
.
elements
(),
[
&
](
auto
i
)
{
int64_t
quantized
=
static_cast
<
int64_t
>
(
std
::
nearbyint
(
input
[
i
]
/
scales
[
i
]))
+
double
quantized
=
static_cast
<
double
>
(
std
::
nearbyint
(
input
[
i
]
/
scales
[
i
]))
+
static_cast
<
int64_t
>
(
zero_pts
[
i
]);
static_cast
<
double
>
(
zero_pts
[
i
]);
output
[
i
]
=
std
::
max
(
static_cast
<
int64_t
>
(
min_value
),
output
[
i
]
=
std
::
max
(
static_cast
<
double
>
(
min_value
),
std
::
min
(
static_cast
<
int64_t
>
(
max_value
),
quantized
));
std
::
min
(
static_cast
<
double
>
(
max_value
),
quantized
));
});
});
});
});
});
});
...
...
src/rewrite_quantization.cpp
View file @
22f60c24
...
@@ -58,8 +58,8 @@ void apply_quantizelinear(module& m, instruction_ref ins)
...
@@ -58,8 +58,8 @@ void apply_quantizelinear(module& m, instruction_ref ins)
add_zero_point
=
m
.
insert_instruction
(
ins
,
make_op
(
"add"
),
add_zero_point
,
zero_point
);
add_zero_point
=
m
.
insert_instruction
(
ins
,
make_op
(
"add"
),
add_zero_point
,
zero_point
);
}
}
int64_t
max_quant
=
0
;
double
max_quant
=
0
;
int64_t
min_quant
=
0
;
double
min_quant
=
0
;
ins
->
get_shape
().
visit_type
([
&
](
auto
qt
)
{
ins
->
get_shape
().
visit_type
([
&
](
auto
qt
)
{
max_quant
=
qt
.
max
();
max_quant
=
qt
.
max
();
min_quant
=
qt
.
min
();
min_quant
=
qt
.
min
();
...
@@ -70,8 +70,8 @@ void apply_quantizelinear(module& m, instruction_ref ins)
...
@@ -70,8 +70,8 @@ void apply_quantizelinear(module& m, instruction_ref ins)
if
(
enabled
(
MIGRAPHX_ENABLE_CK_WORKAROUNDS
{}))
if
(
enabled
(
MIGRAPHX_ENABLE_CK_WORKAROUNDS
{}))
{
{
std
::
vector
<
int
>
min_data
(
s
.
elements
(),
min_quant
);
std
::
vector
<
double
>
min_data
(
s
.
elements
(),
min_quant
);
std
::
vector
<
int
>
max_data
(
s
.
elements
(),
max_quant
);
std
::
vector
<
double
>
max_data
(
s
.
elements
(),
max_quant
);
min_arg
=
m
.
add_literal
(
literal
(
s
,
min_data
));
min_arg
=
m
.
add_literal
(
literal
(
s
,
min_data
));
max_arg
=
m
.
add_literal
(
literal
(
s
,
max_data
));
max_arg
=
m
.
add_literal
(
literal
(
s
,
max_data
));
}
}
...
...
src/simplify_qdq.cpp
View file @
22f60c24
...
@@ -82,18 +82,21 @@ struct match_find_quantizable_ops
...
@@ -82,18 +82,21 @@ struct match_find_quantizable_ops
// Helper function to insert quantized versions of any broadcasts and transpose ops that
// Helper function to insert quantized versions of any broadcasts and transpose ops that
// occur between dequantizelinear and the quantized op
// occur between dequantizelinear and the quantized op
static
auto
static
auto
propagate_quantized_ins
(
module
&
m
,
const
instruction_ref
dqins
,
const
instruction_ref
qop
)
propagate_quantized_ins
(
module
&
m
,
const
instruction_ref
dqins
,
const
instruction_ref
qop
_arg
)
{
{
auto
qinp
=
dqins
->
inputs
().
front
();
auto
prev_ins
=
qop_arg
;
auto
next_ins
=
dqins
;
std
::
vector
<
instruction_ref
>
ins_inbetween
;
// matcher skips contiguous, multi/broadcasts and transposes, collect all those
while
(
next_ins
!=
qop
)
// instructions
while
(
prev_ins
!=
dqins
)
{
{
if
(
next_ins
->
name
()
!=
"dequantizelinear"
)
ins_inbetween
.
push_back
(
prev_ins
);
{
prev_ins
=
prev_ins
->
inputs
().
front
();
qinp
=
m
.
insert_instruction
(
qop
,
next_ins
->
get_operator
(),
qinp
);
}
}
auto
qinp
=
dqins
->
inputs
().
front
();
next_ins
=
next_ins
->
outputs
().
front
();
for
(
auto
ins
:
reverse_iterator_for
(
ins_inbetween
))
{
qinp
=
m
.
insert_instruction
(
dqins
,
(
*
ins
)
->
get_operator
(),
{
qinp
});
}
}
return
qinp
;
return
qinp
;
}
}
...
@@ -124,10 +127,11 @@ struct match_find_quantizable_ops
...
@@ -124,10 +127,11 @@ struct match_find_quantizable_ops
auto
scale2
=
r
.
instructions
[
"scale2"
];
auto
scale2
=
r
.
instructions
[
"scale2"
];
auto
zp1
=
r
.
instructions
[
"zp1"
];
auto
zp1
=
r
.
instructions
[
"zp1"
];
auto
zp2
=
r
.
instructions
[
"zp2"
];
auto
zp2
=
r
.
instructions
[
"zp2"
];
// Only INT8 or FP8 type currently supported
// Only INT8 type currently supported
std
::
set
<
migraphx
::
shape
::
type_t
>
supported_types
=
{
migraphx
::
shape
::
fp8e4m3fnuz_type
,
if
(
dq1
->
inputs
().
front
()
->
get_shape
().
type
()
!=
migraphx
::
shape
::
int8_type
or
migraphx
::
shape
::
int8_type
};
dq2
->
inputs
().
front
()
->
get_shape
().
type
()
!=
migraphx
::
shape
::
int8_type
)
if
(
not
contains
(
supported_types
,
dq1
->
inputs
().
front
()
->
get_shape
().
type
())
or
not
contains
(
supported_types
,
dq2
->
inputs
().
front
()
->
get_shape
().
type
()))
return
;
return
;
// Only symmetric quantization supported (i.e. non-zero zero_points not allowed)
// Only symmetric quantization supported (i.e. non-zero zero_points not allowed)
...
@@ -140,8 +144,8 @@ struct match_find_quantizable_ops
...
@@ -140,8 +144,8 @@ struct match_find_quantizable_ops
// Propagate q1 and q2 through any broadcasts and transposes before qop
// Propagate q1 and q2 through any broadcasts and transposes before qop
auto
qop_args
=
qop
->
inputs
();
auto
qop_args
=
qop
->
inputs
();
qop_args
.
at
(
0
)
=
propagate_quantized_ins
(
m
,
dq1
,
qop
);
qop_args
.
at
(
0
)
=
propagate_quantized_ins
(
m
,
dq1
,
qop
_args
[
0
]
);
qop_args
.
at
(
1
)
=
propagate_quantized_ins
(
m
,
dq2
,
qop
);
qop_args
.
at
(
1
)
=
propagate_quantized_ins
(
m
,
dq2
,
qop
_args
[
1
]
);
instruction_ref
dq
;
instruction_ref
dq
;
instruction_ref
out_scale
;
instruction_ref
out_scale
;
instruction_ref
zero_point
;
instruction_ref
zero_point
;
...
...
test/simplify_qdq_test.cpp
View file @
22f60c24
...
@@ -527,6 +527,62 @@ TEST_CASE(dot_add)
...
@@ -527,6 +527,62 @@ TEST_CASE(dot_add)
EXPECT
(
m1
==
m2
);
EXPECT
(
m1
==
m2
);
}
}
TEST_CASE
(
dot_add_multiple_dq_use
)
{
migraphx
::
shape
sh1
{
migraphx
::
shape
::
float_type
,
{
32
,
1
}};
migraphx
::
shape
sh2
{
migraphx
::
shape
::
float_type
,
{
32
,
32
}};
migraphx
::
module
m1
;
{
auto
t1
=
m1
.
add_parameter
(
"t1"
,
sh1
);
auto
t2
=
m1
.
add_parameter
(
"t2"
,
sh2
);
auto
scale
=
m1
.
add_literal
(
0.5
f
);
auto
zero
=
m1
.
add_literal
(
std
::
int8_t
{
0
});
auto
q1
=
add_quantize_op
(
m1
,
"quantizelinear"
,
t1
,
scale
,
zero
);
auto
d1
=
add_quantize_op
(
m1
,
"dequantizelinear"
,
q1
,
scale
,
zero
);
auto
d1_t
=
m1
.
add_instruction
(
migraphx
::
make_op
(
"transpose"
,
{{
"permutation"
,
{
1
,
0
}}}),
d1
);
auto
d1_tmb
=
m1
.
add_instruction
(
migraphx
::
make_op
(
"multibroadcast"
,
{{
"out_lens"
,
{
32
,
32
}}}),
d1_t
);
auto
d1_tmbc
=
m1
.
add_instruction
(
migraphx
::
make_op
(
"contiguous"
),
d1_tmb
);
auto
q2
=
add_quantize_op
(
m1
,
"quantizelinear"
,
t2
,
scale
,
zero
);
auto
d2
=
add_quantize_op
(
m1
,
"dequantizelinear"
,
q2
,
scale
,
zero
);
auto
dot_1
=
m1
.
add_instruction
(
migraphx
::
make_op
(
"dot"
),
d1_tmbc
,
d2
);
auto
q3
=
add_quantize_op
(
m1
,
"quantizelinear"
,
dot_1
,
scale
,
zero
);
auto
d3
=
add_quantize_op
(
m1
,
"dequantizelinear"
,
q3
,
scale
,
zero
);
auto
dot_2
=
m1
.
add_instruction
(
migraphx
::
make_op
(
"dot"
),
d3
,
d1
);
auto
add
=
m1
.
add_instruction
(
migraphx
::
make_op
(
"add"
),
{
dot_2
,
d1
});
m1
.
add_return
({
add
});
}
migraphx
::
module
m2
;
{
auto
t1
=
m2
.
add_parameter
(
"t1"
,
sh1
);
auto
t2
=
m2
.
add_parameter
(
"t2"
,
sh2
);
auto
scale
=
m2
.
add_literal
(
0.5
f
);
auto
zero
=
m2
.
add_literal
(
std
::
int8_t
{
0
});
auto
q1
=
add_quantize_op
(
m2
,
"quantizelinear"
,
t1
,
scale
,
zero
);
auto
q1_t
=
m2
.
add_instruction
(
migraphx
::
make_op
(
"transpose"
,
{{
"permutation"
,
{
1
,
0
}}}),
q1
);
auto
q1_tmb
=
m2
.
add_instruction
(
migraphx
::
make_op
(
"multibroadcast"
,
{{
"out_lens"
,
{
32
,
32
}}}),
q1_t
);
auto
q1_tmbc
=
m2
.
add_instruction
(
migraphx
::
make_op
(
"contiguous"
),
q1_tmb
);
auto
q2
=
add_quantize_op
(
m2
,
"quantizelinear"
,
t2
,
scale
,
zero
);
auto
dot_1
=
m2
.
add_instruction
(
migraphx
::
make_op
(
"quant_dot"
),
q1_tmbc
,
q2
);
auto
out_scale
=
add_scale_mul
(
m2
,
scale
,
scale
,
1
,
1
,
dot_1
->
get_shape
().
lens
());
auto
d3
=
add_quantize_op
(
m2
,
"dequantizelinear"
,
dot_1
,
out_scale
);
auto
d3_q
=
add_quantize_op
(
m2
,
"quantizelinear"
,
d3
,
scale
,
zero
);
auto
dot_2
=
m2
.
add_instruction
(
migraphx
::
make_op
(
"quant_dot"
),
d3_q
,
q1
);
auto
out_scale_2
=
add_scale_mul
(
m2
,
scale
,
scale
,
1
,
1
,
dot_2
->
get_shape
().
lens
());
auto
d4
=
add_quantize_op
(
m2
,
"dequantizelinear"
,
dot_2
,
out_scale_2
);
auto
add
=
m2
.
add_instruction
(
migraphx
::
make_op
(
"add"
),
d4
,
t1
);
m2
.
add_return
({
add
});
}
run_pass
(
m1
);
EXPECT
(
m1
==
m2
);
}
TEST_CASE
(
conv
)
TEST_CASE
(
conv
)
{
{
migraphx
::
shape
s4
{
migraphx
::
shape
::
int8_type
,
{
1280
,
320
,
1
,
1
}};
migraphx
::
shape
s4
{
migraphx
::
shape
::
int8_type
,
{
1280
,
320
,
1
,
1
}};
...
@@ -919,7 +975,6 @@ TEST_CASE(mobilenet_snippet)
...
@@ -919,7 +975,6 @@ TEST_CASE(mobilenet_snippet)
auto
mod1
=
create_module
();
auto
mod1
=
create_module
();
auto
mod2
=
create_module
();
auto
mod2
=
create_module
();
run_pass
(
mod2
);
run_pass
(
mod2
);
auto
match_qdq
=
migraphx
::
match
::
name
(
"dequantizelinear"
)(
auto
match_qdq
=
migraphx
::
match
::
name
(
"dequantizelinear"
)(
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment