Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
gaoqiong
MIGraphX
Commits
01342ae1
Unverified
Commit
01342ae1
authored
Jun 22, 2023
by
Zhuoran Yin
Committed by
GitHub
Jun 22, 2023
Browse files
[mlir] Adding mlir quant_dot operator support (#1816)
Add mlir quant_dot operator support
parent
c5cd87ce
Changes
4
Hide whitespace changes
Inline
Side-by-side
Showing
4 changed files
with
32 additions
and
5 deletions
+32
-5
Dockerfile
Dockerfile
+1
-1
src/targets/gpu/fuse_mlir.cpp
src/targets/gpu/fuse_mlir.cpp
+3
-1
src/targets/gpu/mlir.cpp
src/targets/gpu/mlir.cpp
+0
-2
test/gpu/mlir.cpp
test/gpu/mlir.cpp
+28
-1
No files found.
Dockerfile
View file @
01342ae1
...
...
@@ -113,7 +113,7 @@ RUN git clone --single-branch --branch ${ONNXRUNTIME_BRANCH} --recursive ${ONNXR
ADD
tools/build_and_test_onnxrt.sh /onnxruntime/build_and_test_onnxrt.sh
RUN
cget
-p
/usr/local
install
ROCmSoftwarePlatform/rocMLIR@
a997d5f51314b45d7a4c04f1599966dcf53f9b4d
-DBUILD_MIXR_TARGET
=
On
-DLLVM_ENABLE_ZSTD
=
Off
-DLLVM_ENABLE_THREADS
=
Off
RUN
cget
-p
/usr/local
install
ROCmSoftwarePlatform/rocMLIR@
f5ab829b1a46eca600eb8e9df4eaa9944c845f07
-DBUILD_MIXR_TARGET
=
On
-DLLVM_ENABLE_ZSTD
=
Off
-DLLVM_ENABLE_THREADS
=
Off
ENV
MIOPEN_FIND_DB_PATH=/tmp/miopen/find-db
ENV
MIOPEN_USER_DB_PATH=/tmp/miopen/user-db
...
...
src/targets/gpu/fuse_mlir.cpp
View file @
01342ae1
...
...
@@ -139,7 +139,8 @@ struct find_mlir_op
auto
matcher
()
const
{
auto
dot_or_conv
=
match
::
skip
(
match
::
name
(
"contiguous"
))(
match
::
any_of
(
match
::
name
(
"dot"
),
is_mlir_conv
()).
bind
(
"gemm_based_op"
));
match
::
any_of
(
match
::
name
(
"dot"
),
match
::
name
(
"quant_dot"
),
is_mlir_conv
())
.
bind
(
"gemm_based_op"
));
return
match
::
name
(
"pointwise"
)(
match
::
any_of
[
match
::
inputs
()](
dot_or_conv
.
bind
(
"x"
)));
}
...
...
@@ -205,6 +206,7 @@ struct find_mlir_op
"convolution"
,
"quant_convolution"
,
"dot"
,
"quant_dot"
,
"add"
,
"relu"
,
"dequantizelinear"
,
...
...
src/targets/gpu/mlir.cpp
View file @
01342ae1
...
...
@@ -244,8 +244,6 @@ struct mlir_program
MlirAttribute
attribute
(
std
::
int64_t
i
)
const
{
if
(
i
<
0
)
MIGRAPHX_THROW
(
"MLIR cant handle negative values since they are ambiguous"
);
return
mlirIntegerAttrGet
(
mlirIntegerTypeGet
(
ctx
.
get
(),
64
),
i
);
}
MlirAttribute
attribute
(
std
::
uint64_t
i
)
const
...
...
test/gpu/mlir.cpp
View file @
01342ae1
...
...
@@ -187,12 +187,39 @@ module {
EXPECT
(
verify_mlir
(
m
));
}
TEST_CASE
(
quant_dot_add
)
{
const
std
::
string
mlir_output
=
R"__migraphx__(
module {
func.func @main(%arg0: tensor<1x5x4xi8>, %arg1: tensor<1x4x3xi8>, %arg2: tensor<1x5x3xi32>) -> tensor<1x5x3xi32> attributes {arch = "", kernel = "mixr"} {
%0 = migraphx.quant_dot(%arg0, %arg1) : (tensor<1x5x4xi8>, tensor<1x4x3xi8>) -> tensor<1x5x3xi32>
%1 = migraphx.add(%0, %arg2) : (tensor<1x5x3xi32>, tensor<1x5x3xi32>) -> tensor<1x5x3xi32>
return %1 : tensor<1x5x3xi32>
}
}
)__migraphx__"
;
migraphx
::
module
m
;
auto
arg0
=
m
.
add_parameter
(
"arg0"
,
{
migraphx
::
shape
::
int8_type
,
{
1
,
5
,
4
}});
auto
arg1
=
m
.
add_parameter
(
"arg1"
,
{
migraphx
::
shape
::
int8_type
,
{
1
,
4
,
3
}});
auto
arg2
=
m
.
add_parameter
(
"arg2"
,
{
migraphx
::
shape
::
int32_type
,
{
1
,
5
,
3
}});
auto
conv
=
m
.
add_instruction
(
migraphx
::
make_op
(
"quant_dot"
),
arg0
,
arg1
);
auto
add
=
m
.
add_instruction
(
migraphx
::
make_op
(
"add"
),
conv
,
arg2
);
m
.
add_return
({
add
});
auto
s
=
migraphx
::
gpu
::
dump_mlir
(
m
);
// Skip test if MLIR is not enabled
if
(
s
.
empty
())
return
;
CHECK
(
encode
(
s
)
==
encode
(
mlir_output
));
EXPECT
(
verify_mlir
(
m
));
}
TEST_CASE
(
dot_add
)
{
const
std
::
string
mlir_output
=
R"__migraphx__(
module {
func.func @mlir_dot(%arg0: tensor<1x5x4xf32>, %arg1: tensor<1x4x3xf32>, %arg2: tensor<1x5x3xf32>) -> tensor<1x5x3xf32> attributes {arch = "", kernel = "mixr"} {
%0 = migraphx.dot(%arg0, %arg1) : tensor<1x5x4xf32>, tensor<1x4x3xf32> -> tensor<1x5x3xf32>
%0 = migraphx.dot(%arg0, %arg1) :
(
tensor<1x5x4xf32>, tensor<1x4x3xf32>
)
-> tensor<1x5x3xf32>
%1 = migraphx.add(%0, %arg2) : (tensor<1x5x3xf32>, tensor<1x5x3xf32>) -> tensor<1x5x3xf32>
return %1 : tensor<1x5x3xf32>
}
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment