Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
MIGraphX
Commits
01342ae1
Unverified
Commit
01342ae1
authored
Jun 22, 2023
by
Zhuoran Yin
Committed by
GitHub
Jun 22, 2023
Browse files
[mlir] Adding mlir quant_dot operator support (#1816)
Add mlir quant_dot operator support
parent
c5cd87ce
Changes
4
Show whitespace changes
Inline
Side-by-side
Showing
4 changed files
with
32 additions
and
5 deletions
+32
-5
Dockerfile
Dockerfile
+1
-1
src/targets/gpu/fuse_mlir.cpp
src/targets/gpu/fuse_mlir.cpp
+3
-1
src/targets/gpu/mlir.cpp
src/targets/gpu/mlir.cpp
+0
-2
test/gpu/mlir.cpp
test/gpu/mlir.cpp
+28
-1
No files found.
Dockerfile
View file @
01342ae1
...
...
@@ -113,7 +113,7 @@ RUN git clone --single-branch --branch ${ONNXRUNTIME_BRANCH} --recursive ${ONNXR
ADD
tools/build_and_test_onnxrt.sh /onnxruntime/build_and_test_onnxrt.sh
RUN
cget
-p
/usr/local
install
ROCmSoftwarePlatform/rocMLIR@
a997d5f51314b45d7a4c04f1599966dcf53f9b4d
-DBUILD_MIXR_TARGET
=
On
-DLLVM_ENABLE_ZSTD
=
Off
-DLLVM_ENABLE_THREADS
=
Off
RUN
cget
-p
/usr/local
install
ROCmSoftwarePlatform/rocMLIR@
f5ab829b1a46eca600eb8e9df4eaa9944c845f07
-DBUILD_MIXR_TARGET
=
On
-DLLVM_ENABLE_ZSTD
=
Off
-DLLVM_ENABLE_THREADS
=
Off
ENV
MIOPEN_FIND_DB_PATH=/tmp/miopen/find-db
ENV
MIOPEN_USER_DB_PATH=/tmp/miopen/user-db
...
...
src/targets/gpu/fuse_mlir.cpp
View file @
01342ae1
...
...
@@ -139,7 +139,8 @@ struct find_mlir_op
auto
matcher
()
const
{
auto
dot_or_conv
=
match
::
skip
(
match
::
name
(
"contiguous"
))(
match
::
any_of
(
match
::
name
(
"dot"
),
is_mlir_conv
()).
bind
(
"gemm_based_op"
));
match
::
any_of
(
match
::
name
(
"dot"
),
match
::
name
(
"quant_dot"
),
is_mlir_conv
())
.
bind
(
"gemm_based_op"
));
return
match
::
name
(
"pointwise"
)(
match
::
any_of
[
match
::
inputs
()](
dot_or_conv
.
bind
(
"x"
)));
}
...
...
@@ -205,6 +206,7 @@ struct find_mlir_op
"convolution"
,
"quant_convolution"
,
"dot"
,
"quant_dot"
,
"add"
,
"relu"
,
"dequantizelinear"
,
...
...
src/targets/gpu/mlir.cpp
View file @
01342ae1
...
...
@@ -244,8 +244,6 @@ struct mlir_program
MlirAttribute
attribute
(
std
::
int64_t
i
)
const
{
if
(
i
<
0
)
MIGRAPHX_THROW
(
"MLIR cant handle negative values since they are ambiguous"
);
return
mlirIntegerAttrGet
(
mlirIntegerTypeGet
(
ctx
.
get
(),
64
),
i
);
}
MlirAttribute
attribute
(
std
::
uint64_t
i
)
const
...
...
test/gpu/mlir.cpp
View file @
01342ae1
...
...
@@ -187,12 +187,39 @@ module {
EXPECT
(
verify_mlir
(
m
));
}
TEST_CASE
(
quant_dot_add
)
{
const
std
::
string
mlir_output
=
R"__migraphx__(
module {
func.func @main(%arg0: tensor<1x5x4xi8>, %arg1: tensor<1x4x3xi8>, %arg2: tensor<1x5x3xi32>) -> tensor<1x5x3xi32> attributes {arch = "", kernel = "mixr"} {
%0 = migraphx.quant_dot(%arg0, %arg1) : (tensor<1x5x4xi8>, tensor<1x4x3xi8>) -> tensor<1x5x3xi32>
%1 = migraphx.add(%0, %arg2) : (tensor<1x5x3xi32>, tensor<1x5x3xi32>) -> tensor<1x5x3xi32>
return %1 : tensor<1x5x3xi32>
}
}
)__migraphx__"
;
migraphx
::
module
m
;
auto
arg0
=
m
.
add_parameter
(
"arg0"
,
{
migraphx
::
shape
::
int8_type
,
{
1
,
5
,
4
}});
auto
arg1
=
m
.
add_parameter
(
"arg1"
,
{
migraphx
::
shape
::
int8_type
,
{
1
,
4
,
3
}});
auto
arg2
=
m
.
add_parameter
(
"arg2"
,
{
migraphx
::
shape
::
int32_type
,
{
1
,
5
,
3
}});
auto
conv
=
m
.
add_instruction
(
migraphx
::
make_op
(
"quant_dot"
),
arg0
,
arg1
);
auto
add
=
m
.
add_instruction
(
migraphx
::
make_op
(
"add"
),
conv
,
arg2
);
m
.
add_return
({
add
});
auto
s
=
migraphx
::
gpu
::
dump_mlir
(
m
);
// Skip test if MLIR is not enabled
if
(
s
.
empty
())
return
;
CHECK
(
encode
(
s
)
==
encode
(
mlir_output
));
EXPECT
(
verify_mlir
(
m
));
}
TEST_CASE
(
dot_add
)
{
const
std
::
string
mlir_output
=
R"__migraphx__(
module {
func.func @mlir_dot(%arg0: tensor<1x5x4xf32>, %arg1: tensor<1x4x3xf32>, %arg2: tensor<1x5x3xf32>) -> tensor<1x5x3xf32> attributes {arch = "", kernel = "mixr"} {
%0 = migraphx.dot(%arg0, %arg1) : tensor<1x5x4xf32>, tensor<1x4x3xf32> -> tensor<1x5x3xf32>
%0 = migraphx.dot(%arg0, %arg1) :
(
tensor<1x5x4xf32>, tensor<1x4x3xf32>
)
-> tensor<1x5x3xf32>
%1 = migraphx.add(%0, %arg2) : (tensor<1x5x3xf32>, tensor<1x5x3xf32>) -> tensor<1x5x3xf32>
return %1 : tensor<1x5x3xf32>
}
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment