Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
MIGraphX
Commits
794d0e76
Commit
794d0e76
authored
Dec 08, 2023
by
Gyula Zakor
Browse files
Add uint8 to quant_dot
parent
99ebfe11
Changes
5
Show whitespace changes
Inline
Side-by-side
Showing
5 changed files
with
17 additions
and
28 deletions
+17
-28
src/include/migraphx/op/quant_dot.hpp
src/include/migraphx/op/quant_dot.hpp
+6
-5
src/onnx/parse_matmul.cpp
src/onnx/parse_matmul.cpp
+7
-19
src/targets/gpu/gemm_impl.cpp
src/targets/gpu/gemm_impl.cpp
+1
-1
src/targets/gpu/target.cpp
src/targets/gpu/target.cpp
+2
-0
test/onnx/verify_onnx.cpp
test/onnx/verify_onnx.cpp
+1
-3
No files found.
src/include/migraphx/op/quant_dot.hpp
View file @
794d0e76
...
@@ -44,9 +44,10 @@ struct quant_dot
...
@@ -44,9 +44,10 @@ struct quant_dot
const
shape
&
a
=
inputs
.
at
(
0
);
const
shape
&
a
=
inputs
.
at
(
0
);
const
shape
&
b
=
inputs
.
at
(
1
);
const
shape
&
b
=
inputs
.
at
(
1
);
auto
t
=
a
.
type
();
auto
t
=
a
.
type
();
if
(
t
!=
shape
::
int8_type
)
std
::
set
<
migraphx
::
shape
::
type_t
>
suppported_types
=
{
shape
::
int8_type
,
shape
::
uint8_type
};
if
(
not
contains
(
suppported_types
,
t
))
{
{
MIGRAPHX_THROW
(
"QUANT_DOT: only support data type int8_t"
);
MIGRAPHX_THROW
(
"QUANT_DOT: only support data type int8_t
and uint8_t
"
);
}
}
if
(
not
std
::
all_of
(
if
(
not
std
::
all_of
(
...
...
src/onnx/parse_matmul.cpp
View file @
794d0e76
...
@@ -113,32 +113,20 @@ struct parse_matmul : op_parser<parse_matmul>
...
@@ -113,32 +113,20 @@ struct parse_matmul : op_parser<parse_matmul>
}
}
}
}
// MatMulInteger can accept uint8 as input type or have zero point values
// parse a_zero_point and b_zero_point values
// In these cases fall back to dot with half float inputs
if
(
args
.
size
()
>
2
)
auto
ba0_type
=
ba0
->
get_shape
().
type
();
{
auto
ba1_type
=
ba1
->
get_shape
().
type
();
auto
has_a0_zero_point
=
args
.
size
()
>
2
;
auto
has_a1_zero_point
=
args
.
size
()
>
3
;
if
(
is_quant_dot
and
(
ba0_type
==
migraphx
::
shape
::
uint8_type
or
ba1_type
==
migraphx
::
shape
::
uint8_type
or
has_a0_zero_point
))
{
// gpu implementation (gemm) only accepts floating point types for dot
ba0
=
info
.
add_instruction
(
ba0
=
info
.
add_instruction
(
make_op
(
"convert"
,
{{
"target_type"
,
migraphx
::
shape
::
half_type
}}),
ba0
);
make_op
(
"convert"
,
{{
"target_type"
,
migraphx
::
shape
::
float_type
}}),
ba0
);
ba1
=
info
.
add_instruction
(
make_op
(
"convert"
,
{{
"target_type"
,
migraphx
::
shape
::
half_type
}}),
ba1
);
if
(
has_a0_zero_point
)
{
ba0
=
info
.
add_common_op
(
"sub"
,
ba0
,
args
[
2
]);
ba0
=
info
.
add_common_op
(
"sub"
,
ba0
,
args
[
2
]);
}
if
(
args
.
size
()
>
3
)
if
(
has_a1_zero_point
)
{
{
ba1
=
info
.
add_instruction
(
make_op
(
"convert"
,
{{
"target_type"
,
migraphx
::
shape
::
float_type
}}),
ba1
);
ba1
=
info
.
add_common_op
(
"sub"
,
ba1
,
args
[
3
]);
ba1
=
info
.
add_common_op
(
"sub"
,
ba1
,
args
[
3
]);
}
}
dot_res
=
info
.
add_instruction
(
make_op
(
"dot"
),
ba0
,
ba1
);
dot_res
=
info
.
add_instruction
(
make_op
(
"dot"
),
ba0
,
ba1
);
dot_res
=
info
.
add_instruction
(
make_op
(
"convert"
,
{{
"target_type"
,
migraphx
::
shape
::
int32_type
}}),
dot_res
);
}
}
else
else
{
{
...
...
src/targets/gpu/gemm_impl.cpp
View file @
794d0e76
...
@@ -196,7 +196,7 @@ struct gemm_impl
...
@@ -196,7 +196,7 @@ struct gemm_impl
arg_type
=
get_type
(
input_shapes
[
0
].
type
());
arg_type
=
get_type
(
input_shapes
[
0
].
type
());
output_type
=
arg_type
;
output_type
=
arg_type
;
if
(
output_type
==
rocblas_datatype_i8_r
)
if
(
output_type
==
rocblas_datatype_i8_r
or
output_type
==
rocblas_datatype_u8_r
)
{
{
output_type
=
rocblas_datatype_i32_r
;
output_type
=
rocblas_datatype_i32_r
;
}
}
...
...
src/targets/gpu/target.cpp
View file @
794d0e76
...
@@ -140,6 +140,8 @@ std::vector<pass> target::get_passes(migraphx::context& gctx, const compile_opti
...
@@ -140,6 +140,8 @@ std::vector<pass> target::get_passes(migraphx::context& gctx, const compile_opti
simplify_qdq
{},
simplify_qdq
{},
enable_pass
(
not
mlir_enabled
(),
rewrite_quantization
{}),
enable_pass
(
not
mlir_enabled
(),
rewrite_quantization
{}),
dead_code_elimination
{},
dead_code_elimination
{},
// workaround for rocBLAS unsupported error when using uint8 in quant_dot
eliminate_data_type
{{
migraphx
::
shape
::
uint8_type
},
shape
::
float_type
,
{
"quant_dot"
}},
eliminate_data_type
{
unsupported_types
,
shape
::
type_t
::
float_type
},
eliminate_data_type
{
unsupported_types
,
shape
::
type_t
::
float_type
},
simplify_reshapes
{},
simplify_reshapes
{},
eliminate_identity
{},
eliminate_identity
{},
...
...
test/onnx/verify_onnx.cpp
View file @
794d0e76
...
@@ -1218,9 +1218,7 @@ TEST_CASE(lpnormalization_2norm)
...
@@ -1218,9 +1218,7 @@ TEST_CASE(lpnormalization_2norm)
TEST_CASE
(
matmulinteger_unsigned_test
)
TEST_CASE
(
matmulinteger_unsigned_test
)
{
{
migraphx
::
program
p
=
migraphx
::
parse_onnx
(
"matmulinteger_unsigned_test.onnx"
);
migraphx
::
program
p
=
migraphx
::
parse_onnx
(
"matmulinteger_unsigned_test.onnx"
);
migraphx
::
compile_options
gpu_opt
;
p
.
compile
(
migraphx
::
make_target
(
"ref"
));
gpu_opt
.
offload_copy
=
true
;
p
.
compile
(
migraphx
::
make_target
(
"ref"
),
gpu_opt
);
migraphx
::
shape
s0
{
migraphx
::
shape
::
uint8_type
,
{
4
,
3
}};
migraphx
::
shape
s0
{
migraphx
::
shape
::
uint8_type
,
{
4
,
3
}};
std
::
vector
<
uint8_t
>
data0
=
{
11
,
7
,
3
,
10
,
6
,
2
,
9
,
5
,
1
,
8
,
4
,
0
};
std
::
vector
<
uint8_t
>
data0
=
{
11
,
7
,
3
,
10
,
6
,
2
,
9
,
5
,
1
,
8
,
4
,
0
};
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment