Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
MIGraphX
Commits
752cb65a
Commit
752cb65a
authored
Dec 06, 2023
by
Umang Yadav
Browse files
WIP mobilenet
parent
4926f035
Changes
5
Show whitespace changes
Inline
Side-by-side
Showing
5 changed files
with
23 additions
and
17 deletions
+23
-17
src/include/migraphx/op/dequantizelinear.hpp
src/include/migraphx/op/dequantizelinear.hpp
+2
-2
src/include/migraphx/op/quantizelinear.hpp
src/include/migraphx/op/quantizelinear.hpp
+8
-4
src/onnx/onnx_parser.cpp
src/onnx/onnx_parser.cpp
+4
-4
src/rewrite_quantization.cpp
src/rewrite_quantization.cpp
+4
-4
src/simplify_qdq.cpp
src/simplify_qdq.cpp
+5
-3
No files found.
src/include/migraphx/op/dequantizelinear.hpp
View file @
752cb65a
...
@@ -72,8 +72,8 @@ struct dequantizelinear
...
@@ -72,8 +72,8 @@ struct dequantizelinear
visit_all
(
x
,
x_zero_point
)([
&
](
auto
input
,
auto
zero_pts
)
{
visit_all
(
x
,
x_zero_point
)([
&
](
auto
input
,
auto
zero_pts
)
{
visit_all
(
result
,
x_scale
)([
&
](
auto
output
,
auto
scales
)
{
visit_all
(
result
,
x_scale
)([
&
](
auto
output
,
auto
scales
)
{
par_for
(
output_shape
.
elements
(),
[
&
](
auto
i
)
{
par_for
(
output_shape
.
elements
(),
[
&
](
auto
i
)
{
output
[
i
]
=
static_cast
<
double
>
(
static_cast
<
int64_t
>
(
input
[
i
])
-
output
[
i
]
=
static_cast
<
double
>
(
static_cast
<
double
>
(
input
[
i
])
-
static_cast
<
int64_t
>
(
zero_pts
[
i
]))
*
static_cast
<
double
>
(
zero_pts
[
i
]))
*
scales
[
i
];
scales
[
i
];
});
});
});
});
...
...
src/include/migraphx/op/quantizelinear.hpp
View file @
752cb65a
...
@@ -58,6 +58,10 @@ struct quantizelinear
...
@@ -58,6 +58,10 @@ struct quantizelinear
{
{
return
{
inputs
[
2
].
type
(),
inputs
[
0
].
lens
(),
inputs
[
0
].
strides
()};
return
{
inputs
[
2
].
type
(),
inputs
[
0
].
lens
(),
inputs
[
0
].
strides
()};
}
}
if
(
inputs
[
0
].
type
()
==
shape
::
float_type
)
{
return
{
shape
::
fp8e4m3fnuz_type
,
inputs
[
0
].
lens
(),
inputs
[
0
].
strides
()};
}
return
{
shape
::
uint8_type
,
inputs
[
0
].
lens
(),
inputs
[
0
].
strides
()};
return
{
shape
::
uint8_type
,
inputs
[
0
].
lens
(),
inputs
[
0
].
strides
()};
}
}
...
@@ -80,10 +84,10 @@ struct quantizelinear
...
@@ -80,10 +84,10 @@ struct quantizelinear
auto
min_value
=
std
::
numeric_limits
<
quant_type
>::
min
();
auto
min_value
=
std
::
numeric_limits
<
quant_type
>::
min
();
auto
max_value
=
std
::
numeric_limits
<
quant_type
>::
max
();
auto
max_value
=
std
::
numeric_limits
<
quant_type
>::
max
();
par_for
(
output_shape
.
elements
(),
[
&
](
auto
i
)
{
par_for
(
output_shape
.
elements
(),
[
&
](
auto
i
)
{
int64_t
quantized
=
static_cast
<
int64_t
>
(
std
::
nearbyint
(
input
[
i
]
/
scales
[
i
]))
+
double
quantized
=
static_cast
<
double
>
(
std
::
nearbyint
(
input
[
i
]
/
scales
[
i
]))
+
static_cast
<
int64_t
>
(
zero_pts
[
i
]);
static_cast
<
double
>
(
zero_pts
[
i
]);
output
[
i
]
=
std
::
max
(
static_cast
<
int64_t
>
(
min_value
),
output
[
i
]
=
std
::
max
(
static_cast
<
double
>
(
min_value
),
std
::
min
(
static_cast
<
int64_t
>
(
max_value
),
quantized
));
std
::
min
(
static_cast
<
double
>
(
max_value
),
quantized
));
});
});
});
});
});
});
...
...
src/onnx/onnx_parser.cpp
View file @
752cb65a
...
@@ -549,7 +549,7 @@ literal onnx_parser::parse_tensor(const onnx::TensorProto& t) const
...
@@ -549,7 +549,7 @@ literal onnx_parser::parse_tensor(const onnx::TensorProto& t) const
case
onnx
::
TensorProto
::
DOUBLE
:
case
onnx
::
TensorProto
::
DOUBLE
:
return
create_literal
(
shape
::
double_type
,
dims
,
t
.
double_data
());
return
create_literal
(
shape
::
double_type
,
dims
,
t
.
double_data
());
case
onnx
::
TensorProto
::
FLOAT
:
return
create_literal
(
shape
::
float_type
,
dims
,
t
.
float_data
());
case
onnx
::
TensorProto
::
FLOAT
:
return
create_literal
(
shape
::
float_type
,
dims
,
t
.
float_data
());
case
onnx
::
TensorProto
::
FLOAT8E4M3FN
UZ
:
{
case
onnx
::
TensorProto
::
FLOAT8E4M3FN
:
{
std
::
vector
<
int32_t
>
data_int32
(
t
.
int32_data
().
begin
(),
t
.
int32_data
().
end
());
std
::
vector
<
int32_t
>
data_int32
(
t
.
int32_data
().
begin
(),
t
.
int32_data
().
end
());
std
::
vector
<
migraphx
::
fp8
::
fp8e4m3fnuz
>
data_fp8
;
std
::
vector
<
migraphx
::
fp8
::
fp8e4m3fnuz
>
data_fp8
;
std
::
transform
(
data_int32
.
begin
(),
std
::
transform
(
data_int32
.
begin
(),
...
@@ -560,7 +560,7 @@ literal onnx_parser::parse_tensor(const onnx::TensorProto& t) const
...
@@ -560,7 +560,7 @@ literal onnx_parser::parse_tensor(const onnx::TensorProto& t) const
}
}
case
onnx
::
TensorProto
::
FLOAT8E5M2FNUZ
:
case
onnx
::
TensorProto
::
FLOAT8E5M2FNUZ
:
case
onnx
::
TensorProto
::
FLOAT8E5M2
:
case
onnx
::
TensorProto
::
FLOAT8E5M2
:
case
onnx
::
TensorProto
::
FLOAT8E4M3FN
:
case
onnx
::
TensorProto
::
FLOAT8E4M3FN
UZ
:
case
onnx
::
TensorProto
::
UNDEFINED
:
case
onnx
::
TensorProto
::
UNDEFINED
:
case
onnx
::
TensorProto
::
STRING
:
case
onnx
::
TensorProto
::
STRING
:
case
onnx
::
TensorProto
::
COMPLEX64
:
case
onnx
::
TensorProto
::
COMPLEX64
:
...
@@ -625,11 +625,11 @@ shape::type_t get_type(int dtype)
...
@@ -625,11 +625,11 @@ shape::type_t get_type(int dtype)
case
11
:
return
shape
::
double_type
;
case
11
:
return
shape
::
double_type
;
case
12
:
return
shape
::
uint32_type
;
case
12
:
return
shape
::
uint32_type
;
case
13
:
return
shape
::
uint64_type
;
case
13
:
return
shape
::
uint64_type
;
case
1
8
:
return
shape
::
fp8e4m3fnuz_type
;
case
1
7
:
return
shape
::
fp8e4m3fnuz_type
;
case
14
:
case
14
:
case
15
:
case
15
:
case
16
:
case
16
:
case
1
7
:
case
1
8
:
case
19
:
case
19
:
case
20
:
case
20
:
default:
{
default:
{
...
...
src/rewrite_quantization.cpp
View file @
752cb65a
...
@@ -58,8 +58,8 @@ void apply_quantizelinear(module& m, instruction_ref ins)
...
@@ -58,8 +58,8 @@ void apply_quantizelinear(module& m, instruction_ref ins)
add_zero_point
=
m
.
insert_instruction
(
ins
,
make_op
(
"add"
),
add_zero_point
,
zero_point
);
add_zero_point
=
m
.
insert_instruction
(
ins
,
make_op
(
"add"
),
add_zero_point
,
zero_point
);
}
}
int64_t
max_quant
=
0
;
double
max_quant
=
0
;
int64_t
min_quant
=
0
;
double
min_quant
=
0
;
ins
->
get_shape
().
visit_type
([
&
](
auto
qt
)
{
ins
->
get_shape
().
visit_type
([
&
](
auto
qt
)
{
max_quant
=
qt
.
max
();
max_quant
=
qt
.
max
();
min_quant
=
qt
.
min
();
min_quant
=
qt
.
min
();
...
@@ -70,8 +70,8 @@ void apply_quantizelinear(module& m, instruction_ref ins)
...
@@ -70,8 +70,8 @@ void apply_quantizelinear(module& m, instruction_ref ins)
if
(
enabled
(
MIGRAPHX_ENABLE_CK_WORKAROUNDS
{}))
if
(
enabled
(
MIGRAPHX_ENABLE_CK_WORKAROUNDS
{}))
{
{
std
::
vector
<
int
>
min_data
(
s
.
elements
(),
min_quant
);
std
::
vector
<
double
>
min_data
(
s
.
elements
(),
min_quant
);
std
::
vector
<
int
>
max_data
(
s
.
elements
(),
max_quant
);
std
::
vector
<
double
>
max_data
(
s
.
elements
(),
max_quant
);
min_arg
=
m
.
add_literal
(
literal
(
s
,
min_data
));
min_arg
=
m
.
add_literal
(
literal
(
s
,
min_data
));
max_arg
=
m
.
add_literal
(
literal
(
s
,
max_data
));
max_arg
=
m
.
add_literal
(
literal
(
s
,
max_data
));
}
}
...
...
src/simplify_qdq.cpp
View file @
752cb65a
...
@@ -125,9 +125,11 @@ struct match_find_quantizable_ops
...
@@ -125,9 +125,11 @@ struct match_find_quantizable_ops
auto
zp1
=
r
.
instructions
[
"zp1"
];
auto
zp1
=
r
.
instructions
[
"zp1"
];
auto
zp2
=
r
.
instructions
[
"zp2"
];
auto
zp2
=
r
.
instructions
[
"zp2"
];
// Only INT8 type currently supported
// Only INT8 or FP8 type currently supported
if
(
dq1
->
inputs
().
front
()
->
get_shape
().
type
()
!=
migraphx
::
shape
::
int8_type
or
std
::
set
<
migraphx
::
shape
::
type_t
>
supported_types
=
{
migraphx
::
shape
::
fp8e4m3fnuz_type
,
dq2
->
inputs
().
front
()
->
get_shape
().
type
()
!=
migraphx
::
shape
::
int8_type
)
migraphx
::
shape
::
int8_type
};
if
(
not
contains
(
supported_types
,
dq1
->
inputs
().
front
()
->
get_shape
().
type
())
or
not
contains
(
supported_types
,
dq2
->
inputs
().
front
()
->
get_shape
().
type
()))
return
;
return
;
// Only symmetric quantization supported (ie. non-zero zero_points not allowed)
// Only symmetric quantization supported (ie. non-zero zero_points not allowed)
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment