Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
MIGraphX
Commits
99ebfe11
Commit
99ebfe11
authored
Dec 01, 2023
by
Gyula Zakor
Browse files
Update MatMulInteger op parsing
parent
7e53592e
Changes
5
Hide whitespace changes
Inline
Side-by-side
Showing
5 changed files
with
74 additions
and
3 deletions
+74
-3
src/onnx/parse_matmul.cpp
src/onnx/parse_matmul.cpp
+34
-2
test/onnx/gen_onnx.py
test/onnx/gen_onnx.py
+17
-0
test/onnx/matmulinteger_unsigned_test.onnx
test/onnx/matmulinteger_unsigned_test.onnx
+0
-0
test/onnx/verify_onnx.cpp
test/onnx/verify_onnx.cpp
+23
-0
test/py/onnx_backend_test.py
test/py/onnx_backend_test.py
+0
-1
No files found.
src/onnx/parse_matmul.cpp
View file @
99ebfe11
...
@@ -62,9 +62,10 @@ struct parse_matmul : op_parser<parse_matmul>
...
@@ -62,9 +62,10 @@ struct parse_matmul : op_parser<parse_matmul>
a1
=
info
.
add_instruction
(
make_op
(
"unsqueeze"
,
{{
"axes"
,
{
1
}}}),
args
[
1
]);
a1
=
info
.
add_instruction
(
make_op
(
"unsqueeze"
,
{{
"axes"
,
{
1
}}}),
args
[
1
]);
}
}
auto
is_quant_dot
=
opd
.
op_name
==
"quant_dot"
;
if
(
s0
.
dynamic
()
or
s1
.
dynamic
())
if
(
s0
.
dynamic
()
or
s1
.
dynamic
())
{
{
if
(
opd
.
op_name
==
"
quant_dot
"
)
if
(
is_
quant_dot
)
{
{
MIGRAPHX_THROW
(
"PARSE_MATMUL: dynamic MatMulInteger not supported"
);
MIGRAPHX_THROW
(
"PARSE_MATMUL: dynamic MatMulInteger not supported"
);
}
}
...
@@ -111,7 +112,38 @@ struct parse_matmul : op_parser<parse_matmul>
...
@@ -111,7 +112,38 @@ struct parse_matmul : op_parser<parse_matmul>
make_op
(
"multibroadcast"
,
{{
"out_lens"
,
l1_broadcasted_lens
}}),
a1
);
make_op
(
"multibroadcast"
,
{{
"out_lens"
,
l1_broadcasted_lens
}}),
a1
);
}
}
}
}
dot_res
=
info
.
add_instruction
(
make_op
(
opd
.
op_name
),
ba0
,
ba1
);
// MatMulInteger can accept uint8 inputs or carry zero-point values.
// In these cases, fall back to dot with half-precision float inputs.
auto
ba0_type
=
ba0
->
get_shape
().
type
();
auto
ba1_type
=
ba1
->
get_shape
().
type
();
auto
has_a0_zero_point
=
args
.
size
()
>
2
;
auto
has_a1_zero_point
=
args
.
size
()
>
3
;
if
(
is_quant_dot
and
(
ba0_type
==
migraphx
::
shape
::
uint8_type
or
ba1_type
==
migraphx
::
shape
::
uint8_type
or
has_a0_zero_point
))
{
// gpu implementation (gemm) only accepts floating point types for dot
ba0
=
info
.
add_instruction
(
make_op
(
"convert"
,
{{
"target_type"
,
migraphx
::
shape
::
half_type
}}),
ba0
);
ba1
=
info
.
add_instruction
(
make_op
(
"convert"
,
{{
"target_type"
,
migraphx
::
shape
::
half_type
}}),
ba1
);
if
(
has_a0_zero_point
)
{
ba0
=
info
.
add_common_op
(
"sub"
,
ba0
,
args
[
2
]);
}
if
(
has_a1_zero_point
)
{
ba1
=
info
.
add_common_op
(
"sub"
,
ba1
,
args
[
3
]);
}
dot_res
=
info
.
add_instruction
(
make_op
(
"dot"
),
ba0
,
ba1
);
dot_res
=
info
.
add_instruction
(
make_op
(
"convert"
,
{{
"target_type"
,
migraphx
::
shape
::
int32_type
}}),
dot_res
);
}
else
{
dot_res
=
info
.
add_instruction
(
make_op
(
opd
.
op_name
),
ba0
,
ba1
);
}
}
}
// squeeze the appended or prepended dimensions
// squeeze the appended or prepended dimensions
...
...
test/onnx/gen_onnx.py
View file @
99ebfe11
...
@@ -4866,6 +4866,23 @@ def matmulinteger_dyn_error():
...
@@ -4866,6 +4866,23 @@ def matmulinteger_dyn_error():
return
([
node
],
[
m1
,
m2
],
[
y
])
return
([
node
],
[
m1
,
m2
],
[
y
])
@onnx_test()
def matmulinteger_unsigned_test():
    """Describe a MatMulInteger graph with uint8 operands and zero points.

    Inputs '1' (4x3) and '2' (3x2) are uint8 value infos; '3' and '4' are
    scalar uint8 tensors holding the zero points 12 and 0. The output 'y'
    is declared as a 4x2 int32 tensor.

    Returns a 4-tuple: (nodes, inputs, outputs, constant tensors).
    NOTE(review): the fourth element is presumably emitted by onnx_test as
    graph initializers; confirm against the decorator.
    """
    # Matrix operands, supplied at run time.
    lhs_info = helper.make_tensor_value_info('1', TensorProto.UINT8, [4, 3])
    rhs_info = helper.make_tensor_value_info('2', TensorProto.UINT8, [3, 2])

    # Scalar (empty dims) zero points for the left and right operand.
    lhs_zero_point = helper.make_tensor('3', TensorProto.UINT8, [], [12])
    rhs_zero_point = helper.make_tensor('4', TensorProto.UINT8, [], [0])

    out_info = helper.make_tensor_value_info('y', TensorProto.INT32, [4, 2])

    matmul_node = onnx.helper.make_node('MatMulInteger',
                                        inputs=['1', '2', '3', '4'],
                                        outputs=['y'])

    nodes = [matmul_node]
    graph_inputs = [lhs_info, rhs_info]
    graph_outputs = [out_info]
    constants = [lhs_zero_point, rhs_zero_point]
    return (nodes, graph_inputs, graph_outputs, constants)
@
onnx_test
()
@
onnx_test
()
def
max_test
():
def
max_test
():
a
=
helper
.
make_tensor_value_info
(
'0'
,
TensorProto
.
FLOAT
,
[
3
])
a
=
helper
.
make_tensor_value_info
(
'0'
,
TensorProto
.
FLOAT
,
[
3
])
...
...
test/onnx/matmulinteger_unsigned_test.onnx
0 → 100644
View file @
99ebfe11
File added
test/onnx/verify_onnx.cpp
View file @
99ebfe11
...
@@ -1215,6 +1215,29 @@ TEST_CASE(lpnormalization_2norm)
...
@@ -1215,6 +1215,29 @@ TEST_CASE(lpnormalization_2norm)
EXPECT
(
migraphx
::
verify
::
verify_rms_range
(
result_vector
,
gold
));
EXPECT
(
migraphx
::
verify
::
verify_rms_range
(
result_vector
,
gold
));
}
}
// Parses matmulinteger_unsigned_test.onnx, runs it on the "ref" target with
// uint8 inputs, and checks the int32 output against precomputed gold values.
// Only parameters "1" and "2" are fed here; any other operands of the model
// have to come from the .onnx file itself.
TEST_CASE(matmulinteger_unsigned_test)
{
    // Parameter "1": 4x3 uint8 matrix.
    migraphx::shape lhs_shape{migraphx::shape::uint8_type, {4, 3}};
    std::vector<uint8_t> lhs_data{11, 7, 3, 10, 6, 2, 9, 5, 1, 8, 4, 0};

    // Parameter "2": 3x2 uint8 matrix.
    migraphx::shape rhs_shape{migraphx::shape::uint8_type, {3, 2}};
    std::vector<uint8_t> rhs_data{1, 4, 2, 5, 3, 6};

    migraphx::program prog = migraphx::parse_onnx("matmulinteger_unsigned_test.onnx");

    // offload_copy is requested even though the program is compiled for "ref".
    migraphx::compile_options options;
    options.offload_copy = true;
    prog.compile(migraphx::make_target("ref"), options);

    migraphx::parameter_map params;
    params["1"] = migraphx::argument(lhs_shape, lhs_data.data());
    params["2"] = migraphx::argument(rhs_shape, rhs_data.data());

    // The last evaluated result is the program output.
    auto output_arg = prog.eval(params).back();

    std::vector<int32_t> result_vector;
    output_arg.visit([&](auto view) { result_vector.assign(view.begin(), view.end()); });

    // Expected 4x2 result, row-major.
    std::vector<int32_t> gold{-38, -83, -44, -98, -50, -113, -56, -128};
    EXPECT(migraphx::verify::verify_rms_range(result_vector, gold));
}
TEST_CASE
(
mean_broadcast_test
)
TEST_CASE
(
mean_broadcast_test
)
{
{
migraphx
::
program
p
=
migraphx
::
parse_onnx
(
"mean_broadcast_test.onnx"
);
migraphx
::
program
p
=
migraphx
::
parse_onnx
(
"mean_broadcast_test.onnx"
);
...
...
test/py/onnx_backend_test.py
View file @
99ebfe11
...
@@ -134,7 +134,6 @@ def disabled_tests_onnx_1_7_0(backend_test):
...
@@ -134,7 +134,6 @@ def disabled_tests_onnx_1_7_0(backend_test):
backend_test
.
exclude
(
r
'test_hardmax_example_cpu'
)
backend_test
.
exclude
(
r
'test_hardmax_example_cpu'
)
backend_test
.
exclude
(
r
'test_hardmax_negative_axis_cpu'
)
backend_test
.
exclude
(
r
'test_hardmax_negative_axis_cpu'
)
backend_test
.
exclude
(
r
'test_hardmax_one_hot_cpu'
)
backend_test
.
exclude
(
r
'test_hardmax_one_hot_cpu'
)
backend_test
.
exclude
(
r
'test_matmulinteger_cpu'
)
backend_test
.
exclude
(
r
'test_maxpool_2d_uint8_cpu'
)
backend_test
.
exclude
(
r
'test_maxpool_2d_uint8_cpu'
)
backend_test
.
exclude
(
r
'test_maxunpool_export_with_output_shape_cpu'
)
backend_test
.
exclude
(
r
'test_maxunpool_export_with_output_shape_cpu'
)
backend_test
.
exclude
(
r
'test_maxunpool_export_without_output_shape_cpu'
)
backend_test
.
exclude
(
r
'test_maxunpool_export_without_output_shape_cpu'
)
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment