gaoqiong / MIGraphX

Commit bb7f65d9, authored Jan 31, 2023 by Paul

Fix conflicts

Parents: b9ec1d6d, a4b82653
Changes: 18 changed files with 435 additions and 75 deletions (+435, -75)
.github/workflows/sync-onnxrt-main.yaml  +4 -3
Dockerfile  +4 -8
src/CMakeLists.txt  +1 -0
src/include/migraphx/match/layernorm.hpp  +1 -1
src/include/migraphx/matcher.hpp  +1 -1
src/include/migraphx/op/gather.hpp  +40 -15
src/include/migraphx/optimize_module.hpp  +48 -0
src/optimize_module.cpp  +49 -0
src/targets/gpu/fuse_ops.cpp  +53 -28
src/targets/gpu/include/migraphx/gpu/prefuse_ops.hpp  +2 -2
src/targets/gpu/prefuse_ops.cpp  +12 -7
src/targets/gpu/target.cpp  +3 -10
test/onnx/gather_dyn_test.onnx  +0 -0
test/onnx/gather_scalar_test.onnx  +0 -0
test/onnx/gen_onnx.py  +34 -0
test/onnx/onnx_test.cpp  +40 -0
test/op_shape_test.cpp  +71 -0
test/ref_ops_test.cpp  +72 -0
.github/workflows/sync-onnxrt-main.yaml

 name: Onnxruntime main weekly sync
 on:
   schedule:
-    - cron: "05 09 * * 5"
+    - cron: "05 17 * * 1"
 jobs:
   runs-on: ubuntu-latest
   sync:
     steps:
       - uses: actions/checkout@v3
         with:
           ref: develop
           path: ../
   get_date:
     steps:
 ...
Dockerfile

@@ -95,20 +95,16 @@ RUN cget -p $PREFIX install facebook/zstd@v1.4.5 -X subdir -DCMAKE_DIR=build/cma
 RUN cget -p $PREFIX install ccache@v4.1 -DENABLE_TESTING=OFF
 RUN cget -p /opt/cmake install kitware/cmake@v3.24.3
-RUN export ONNXRT_COMMIT=$(cat test/onnx/.onnxrt-commit)
+COPY ./test/onnx/.onnxrt-commit /
 ARG ONNXRUNTIME_REPO=https://github.com/Microsoft/onnxruntime
 ARG ONNXRUNTIME_BRANCH=main
-ARG ONNXRUNTIME_COMMIT=$ONNXRT_COMMIT
+ARG ONNXRUNTIME_COMMIT
 # Let us know which commit where're using for CI
 RUN echo "Onnxruntime Commit:" && echo $ONNXRUNTIME_COMMIT
 RUN git clone --single-branch --branch ${ONNXRUNTIME_BRANCH} --recursive ${ONNXRUNTIME_REPO} onnxruntime && \
     cd onnxruntime && \
-    git checkout ${ONNXRUNTIME_COMMIT} && \
+    if [ -z "$ONNXRUNTIME_COMMIT" ]; then git checkout $(cat /.onnxrt-commit); else git checkout ${ONNXRUNTIME_COMMIT}; fi && \
-    /bin/sh dockerfiles/scripts/install_common_deps.sh
+    /bin/sh /onnxruntime/dockerfiles/scripts/install_common_deps.sh
 ADD tools/build_and_test_onnxrt.sh /onnxruntime/build_and_test_onnxrt.sh
 ...
src/CMakeLists.txt

@@ -64,6 +64,7 @@ add_library(migraphx
     normalize_ops.cpp
     op_enums.cpp
     operation.cpp
+    optimize_module.cpp
     opt/memory_coloring.cpp
     opt/memory_coloring_impl.cpp
     pad_calc.cpp
 ...
src/include/migraphx/match/layernorm.hpp

@@ -52,7 +52,7 @@ struct layernorm_matcher
     auto sqrt_add_eps(const std::string& name) const
     {
         auto add_eps = f("add")(either_arg(0, 1)(variance(), is_constant().bind("eps")));
-        return skip_broadcasts(f(name)(arg(0)(add_eps)));
+        return skip_broadcasts(f(name)(arg(0)(any_of(add_eps, variance()))));
     }

     auto layernorm_onnx() const
 ...
src/include/migraphx/matcher.hpp

@@ -615,7 +615,7 @@ inline auto var(std::string s)
     [=, s = std::move(s)](const matcher_context& ctx,
                           instruction_ref) -> optional<instruction_ref> {
         auto it = ctx.instructions.find(s);
         if(it == ctx.instructions.end())
             return nullopt;
         return it->second;
     });
 ...
src/include/migraphx/op/gather.hpp

@@ -26,6 +26,7 @@
 #include <array>
 #include <migraphx/check_shapes.hpp>
+#include <migraphx/dyn_output.hpp>
 #include <migraphx/stringutils.hpp>
 #include <migraphx/streamutils.hpp>
 #include <migraphx/literal.hpp>
 ...

@@ -61,35 +62,59 @@ struct gather
     shape normalize_compute_shape(std::vector<shape> inputs) const
     {
-        check_shapes{inputs, *this}.has(2);
-        auto lens = inputs[0].lens();
-        auto type = inputs[0].type();
-        lens.erase(lens.begin() + axis);
-        if(not inputs[1].scalar())
-        {
-            auto ind_lens = inputs[1].lens();
-            lens.insert(lens.begin() + axis, ind_lens.begin(), ind_lens.end());
-        }
-        // for scalar output
-        if(lens.empty())
-        {
-            return {type};
-        }
-        return {type, lens};
+        check_shapes{inputs, *this, true}.has(2);
+        shape data    = inputs[0];
+        shape indices = inputs[1];
+        auto type     = data.type();
+        // If index_dims is dynamic, convert the data to dynamic too.
+        if(indices.dynamic())
+        {
+            data = data.to_dynamic();
+        }
+        if(data.dynamic())
+        {
+            auto dims = data.dyn_dims();
+            dims.erase(dims.begin() + axis);
+            if(not indices.scalar())
+            {
+                auto index_dims = indices.to_dynamic().dyn_dims();
+                dims.insert(dims.begin() + axis, index_dims.begin(), index_dims.end());
+            }
+            return {type, dims};
+        }
+        else
+        {
+            // Both data and indices are static. indices may be scalar
+            auto lens = data.lens();
+            lens.erase(lens.begin() + axis);
+            if(not indices.scalar())
+            {
+                auto ind_lens = indices.lens();
+                lens.insert(lens.begin() + axis, ind_lens.begin(), ind_lens.end());
+            }
+            // for scalar output
+            if(lens.empty())
+            {
+                return {type};
+            }
+            return {type, lens};
+        }
     }

-    argument compute(const shape& output_shape, std::vector<argument> args) const
+    argument compute(const dyn_output& dyn_out, std::vector<argument> args) const
     {
-        argument result{output_shape};
+        argument result{dyn_out.computed_shape};
         // negative axis means counting dimensions from back
         auto lens = args[0].get_shape().lens();
         std::size_t axis_dim_size = lens[axis];
         // max dimension in axis
         visit_all(result, args[0])([&](auto output, auto data) {
             args[1].visit([&](auto indices) {
-                if(output_shape.scalar())
+                if(dyn_out.computed_shape.scalar())
                 {
                     auto in_index = indices.front();
                     in_index      = (in_index < 0) ? in_index + axis_dim_size : in_index;
 ...
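
As a reading aid (not part of the commit): both the static and dynamic branches above implement the same rule, namely drop the axis dimension of the data shape and splice the index dimensions in at that position, with a scalar index contributing nothing. A stand-alone sketch with hypothetical values; the helper below is illustrative, not MIGraphX API:

#include <cassert>
#include <cstddef>
#include <vector>

// Hypothetical helper: erase the axis dimension from the data extents, then
// splice the index extents in at the same position. A scalar index passes an
// empty index_lens, so the axis dimension simply drops out.
std::vector<std::size_t> gather_output_lens(std::vector<std::size_t> data_lens,
                                            const std::vector<std::size_t>& index_lens,
                                            std::size_t axis)
{
    data_lens.erase(data_lens.begin() + axis);
    data_lens.insert(data_lens.begin() + axis, index_lens.begin(), index_lens.end());
    return data_lens;
}

int main()
{
    // data {3,4,5,6}, indices {2,3}, axis = 1  ->  {3,2,3,5,6}
    assert((gather_output_lens({3, 4, 5, 6}, {2, 3}, 1) ==
            std::vector<std::size_t>{3, 2, 3, 5, 6}));
    // scalar index on the same data and axis   ->  {3,5,6}
    assert((gather_output_lens({3, 4, 5, 6}, {}, 1) == std::vector<std::size_t>{3, 5, 6}));
    return 0;
}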
src/include/migraphx/optimize_module.hpp  (new file, 0 → 100644)
/*
* The MIT License (MIT)
*
* Copyright (c) 2015-2022 Advanced Micro Devices, Inc. All rights reserved.
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
* in the Software without restriction, including without limitation the rights
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
* copies of the Software, and to permit persons to whom the Software is
* furnished to do so, subject to the following conditions:
*
* The above copyright notice and this permission notice shall be included in
* all copies or substantial portions of the Software.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
* THE SOFTWARE.
*/
#ifndef MIGRAPHX_GUARD_RTGLIB_OPTIMIZE_MODULE_HPP
#define MIGRAPHX_GUARD_RTGLIB_OPTIMIZE_MODULE_HPP
#include <string>
#include <migraphx/instruction_ref.hpp>
#include <migraphx/config.hpp>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {

struct module_pass_manager;

/**
 * Runs several passes in a loop
 */
struct optimize_module
{
    std::string name() const { return "optimize_module"; }
    void apply(module_pass_manager& mpm) const;
};

} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
#endif
src/optimize_module.cpp  (new file, 0 → 100644)
/*
* The MIT License (MIT)
*
* Copyright (c) 2015-2022 Advanced Micro Devices, Inc. All rights reserved.
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
* in the Software without restriction, including without limitation the rights
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
* copies of the Software, and to permit persons to whom the Software is
* furnished to do so, subject to the following conditions:
*
* The above copyright notice and this permission notice shall be included in
* all copies or substantial portions of the Software.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
* THE SOFTWARE.
*/
#include <migraphx/optimize_module.hpp>
#include <migraphx/pass_manager.hpp>
#include <migraphx/simplify_reshapes.hpp>
#include <migraphx/simplify_algebra.hpp>
#include <migraphx/eliminate_common_subexpression.hpp>
#include <migraphx/dead_code_elimination.hpp>
#include <migraphx/propagate_constant.hpp>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {

void optimize_module::apply(module_pass_manager& mpm) const
{
    for(int i = 0; i < 2; i++)
    {
        mpm.run_pass(simplify_reshapes{});
        mpm.run_pass(simplify_algebra{});
        mpm.run_pass(eliminate_common_subexpression{});
        mpm.run_pass(dead_code_elimination{});
        mpm.run_pass(propagate_constant{});
        mpm.run_pass(dead_code_elimination{});
    }
}

} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
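
The new pass simply bundles the simplify/eliminate/propagate loop, so callers such as target.cpp (later in this commit) can list one pass instead of six. A minimal usage sketch, assuming the run_passes helper declared in migraphx/pass_manager.hpp; this snippet is illustrative and not part of the commit:

#include <migraphx/module.hpp>
#include <migraphx/optimize_module.hpp>
#include <migraphx/pass_manager.hpp>

// Run the bundled simplification loop on a module; the GPU target gets the
// same effect by naming optimize_module{} once in its pass list.
void run_optimize(migraphx::module& m)
{
    migraphx::run_passes(m, {migraphx::optimize_module{}});
}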
src/targets/gpu/fuse_ops.cpp

@@ -553,11 +553,13 @@ struct find_gemm_pointwise
 {
     auto matcher() const
     {
-        return precompile_name("pointwise")(
+        auto gemm_op =
+            match::name("gpu::gemm")(match::nargs(3), match::used_once()).bind("gemm");
+        auto binary_op = match::all_of(
             match::nargs(3),
             match::either_arg(0, 1)(
                 match::any_of(match::standard_shape(), match::is_constant()).bind("c"),
-                match::name("gpu::gemm")(match::nargs(3), match::used_once()).bind("gemm")));
+                gemm_op));
+        auto unary_op = match::all_of(match::nargs(2), match::arg(0)(gemm_op));
+        return precompile_name("pointwise")(match::any_of(binary_op, unary_op));
     }
     // TODO: Move to matcher.hpp
 ...

@@ -589,61 +591,84 @@
         return match::name("@return")(match::args(match::any_of(add, mul_add, add_mul)));
     }

+    static auto match_mul(const std::string& input)
+    {
+        auto mul = match_mul_const(match_param(input), "alpha");
+        return match::name("@return")(match::args(mul));
+    }
+
     static float get_float(instruction_ref ins) { return ins->get_literal().at<float>(); }

     template <class Gemm>
     static bool update_gemm(Gemm& gemm, module_ref pm, unsigned input)
     {
         auto names = pm->get_parameter_names();
-        if(names.size() != 2)
-            return false;
         std::sort(names.begin(), names.end());
-        unsigned output = input == 0 ? 1 : 0;
-        auto mr = match::match_instruction(
-            *pm, std::prev(pm->end()), match_add(names[input], names[output]));
-        if(mr.result == pm->end())
-            return false;
-        if(contains(mr.instructions, "alpha_mul"))
-            gemm.alpha *= get_float(mr.instructions["alpha"]);
-        else if(contains(mr.instructions, "beta_mul"))
-            gemm.beta *= get_float(mr.instructions["beta"]);
-        else if(contains(mr.instructions, "gamma_mul"))
-        {
-            gemm.alpha *= get_float(mr.instructions["gamma"]);
-            gemm.beta *= get_float(mr.instructions["gamma"]);
-        }
-        return true;
+        if(names.size() == 1)
+        {
+            auto mr = match::match_instruction(
+                *pm, std::prev(pm->end()), match_mul(names[input]));
+            if(mr.result == pm->end())
+                return false;
+            gemm.alpha *= get_float(mr.instructions["alpha"]);
+            return true;
+        }
+        else if(names.size() == 2)
+        {
+            unsigned output = input == 0 ? 1 : 0;
+            auto mr = match::match_instruction(
+                *pm, std::prev(pm->end()), match_add(names[input], names[output]));
+            if(mr.result == pm->end())
+                return false;
+            if(contains(mr.instructions, "alpha_mul"))
+                gemm.alpha *= get_float(mr.instructions["alpha"]);
+            else if(contains(mr.instructions, "beta_mul"))
+                gemm.beta *= get_float(mr.instructions["beta"]);
+            else if(contains(mr.instructions, "gamma_mul"))
+            {
+                gemm.alpha *= get_float(mr.instructions["gamma"]);
+                gemm.beta *= get_float(mr.instructions["gamma"]);
+            }
+            return true;
+        }
+        else
+        {
+            return false;
+        }
     }

     void apply(module& m, const match::matcher_result& r) const
     {
         auto ins      = r.result;
         auto gemm_ins = r.instructions["gemm"];
-        auto c_ins    = r.instructions["c"];
         auto gemm     = any_cast<rocblas_gemm<op::dot>>(gemm_ins->get_operator());
         // Already fused gemm
         if(not float_equal(gemm.beta, 0))
             return;
-        gemm.beta = 1;
+        if(ins->inputs().size() == 3)
+            gemm.beta = 1;
         if(not update_gemm(
                gemm, ins->module_inputs().front(), ins->inputs().front() == gemm_ins ? 0 : 1))
             return;
-        // const-fold input if not standard shape since rocblas can't handle it
-        if(not c_ins->get_shape().standard())
-        {
-            auto c = make_op("contiguous");
-            auto l = c.compute(c.compute_shape({c_ins->get_shape()}), {c_ins->eval()});
-            c_ins  = m.add_literal(l.get_shape(), l.data());
-        }
         auto inputs = gemm_ins->inputs();
         inputs.pop_back();
-        inputs.push_back(c_ins);
+        if(ins->inputs().size() == 3)
+        {
+            auto c_ins = r.instructions["c"];
+            // const-fold input if not standard shape since rocblas can't handle it
+            if(not c_ins->get_shape().standard())
+            {
+                auto c = make_op("contiguous");
+                auto l = c.compute(c.compute_shape({c_ins->get_shape()}), {c_ins->eval()});
+                c_ins  = m.add_literal(l.get_shape(), l.data());
+            }
+            inputs.push_back(c_ins);
+        }
         inputs.push_back(ins->inputs().back());
         m.replace_instruction(ins, gemm, inputs);
 ...
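
For orientation (not stated in the diff itself): the rocBLAS-backed gpu::gemm computes

    out = alpha * (A x B) + beta * C

so find_gemm_pointwise folds a trailing pointwise module into the GEMM by adjusting alpha and beta. The reworked matcher now accepts either a binary consumer (an add or multiply-add, which supplies the C operand and may scale alpha and beta) or a unary consumer (a plain multiply by a constant, which only scales alpha), which is why update_gemm branches on the number of pointwise parameters.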
src/targets/gpu/include/migraphx/gpu/prefuse_ops.hpp

@@ -30,14 +30,14 @@
 namespace migraphx {
 inline namespace MIGRAPHX_INLINE_NS {

-struct module;
+struct module_pass_manager;

 namespace gpu {

 struct prefuse_ops
 {
     std::string name() const { return "gpu::prefuse_ops"; }
-    void apply(module& m) const;
+    void apply(module_pass_manager& m) const;
 };

 } // namespace gpu
 ...
src/targets/gpu/prefuse_ops.cpp

@@ -26,6 +26,8 @@
 #include <migraphx/check_shapes.hpp>
 #include <migraphx/make_op.hpp>
 #include <migraphx/register_op.hpp>
+#include <migraphx/pass_manager.hpp>
+#include <migraphx/dead_code_elimination.hpp>

 namespace migraphx {
 inline namespace MIGRAPHX_INLINE_NS {
 ...

@@ -90,7 +92,9 @@ struct find_layernorm
     {
         auto ins   = r.result;
         auto x_ins = r.instructions["x"];
-        auto eps   = r.instructions["eps"]->eval().at<float>();
+        float eps = 0;
+        if(contains(r.instructions, "eps"))
+            eps = r.instructions["eps"]->eval().at<float>();

         m.replace_instruction(ins, layernorm{eps}, x_ins);
     }
 ...

@@ -100,24 +104,25 @@ struct find_add_layernorm
 {
     auto matcher() const
     {
-        return match::layernorm()(
-            match::var("x")(match::name("add")(match::used_once()).bind("add")));
+        return match::name("gpu::prelayernorm")(
+            match::args(match::name("add")(match::used_once()).bind("add")));
     }

     void apply(module& m, const match::matcher_result& r) const
     {
         auto ins     = r.result;
         auto add_ins = r.instructions["add"];
-        auto eps     = r.instructions["eps"]->eval().at<float>();
-        m.replace_instruction(ins, add_layernorm{eps}, add_ins->inputs());
+        auto op      = any_cast<layernorm>(ins->get_operator());
+        m.replace_instruction(ins, add_layernorm{op.epsilon}, add_ins->inputs());
     }
 };

 } // namespace

-void prefuse_ops::apply(module& m) const
+void prefuse_ops::apply(module_pass_manager& mpm) const
 {
-    match::find_matches(m, find_add_layernorm{}, find_layernorm{});
+    match::find_matches(mpm.get_module(), find_layernorm{});
+    mpm.run_pass(dead_code_elimination{});
+    match::find_matches(mpm.get_module(), find_add_layernorm{});
 }

 } // namespace gpu
 ...
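
A schematic of the two-stage rewrite (hypothetical instruction stream, not from the diff): running find_layernorm first and dead_code_elimination in between is what lets find_add_layernorm match on gpu::prelayernorm directly and take epsilon from the operator instead of a matcher binding.

    add1 = add(x, y)
    ln1  = gpu::prelayernorm(add1)    after find_layernorm and dead_code_elimination
    ln1  = add_layernorm(x, y)        after find_add_layernorm, epsilon copied from the prelayernorm op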
src/targets/gpu/target.cpp

@@ -38,6 +38,7 @@
 #include <migraphx/layout_nhwc.hpp>
 #include <migraphx/memory_coloring.hpp>
 #include <migraphx/normalize_ops.hpp>
+#include <migraphx/optimize_module.hpp>
 #include <migraphx/preallocate_param.hpp>
 #include <migraphx/propagate_constant.hpp>
 #include <migraphx/register_target.hpp>
 ...

@@ -118,21 +119,13 @@ std::vector<pass> target::get_passes(migraphx::context& gctx, const compile_opti
         rewrite_pooling{},
         dead_code_elimination{},
         rewrite_gelu{},
-        dead_code_elimination{},
-        eliminate_common_subexpression{},
-        dead_code_elimination{},
-        simplify_algebra{},
-        simplify_reshapes{},
+        optimize_module{},
         enable_pass(enabled(MIGRAPHX_ENABLE_NHWC{}), layout_nhwc{}),
         dead_code_elimination{},
-        simplify_reshapes{},
-        simplify_algebra{},
         prefuse_ops{},
         dead_code_elimination{},
         auto_contiguous{},
-        simplify_reshapes{},
-        propagate_constant{},
-        dead_code_elimination{},
+        optimize_module{},
         enable_pass(not enabled(MIGRAPHX_DISABLE_POINTWISE_FUSION{}), fuse_pointwise{}),
         dead_code_elimination{},
         fuse_mlir{&ctx},
 ...
test/onnx/gather_dyn_test.onnx  (new file, 0 → 100644)
File added

test/onnx/gather_scalar_test.onnx  (new file, 0 → 100644)
File added
test/onnx/gen_onnx.py

@@ -2053,6 +2053,40 @@ def gather_test():
    return ([node], [x, i], [y])


@onnx_test()
def gather_scalar_test():
    x = helper.make_tensor_value_info('data', TensorProto.FLOAT, [3, 4, 5, 6])
    i = helper.make_tensor_value_info('indices', TensorProto.INT32, [])
    y = helper.make_tensor_value_info('y', TensorProto.FLOAT, [4, 5, 6])

    node = onnx.helper.make_node(
        'Gather',
        inputs=['data', 'indices'],
        outputs=['y'],
        axis=1,
    )

    return ([node], [x, i], [y])


@onnx_test()
def gather_dyn_test():
    x = helper.make_tensor_value_info('data', TensorProto.FLOAT, [None, 4, 5, 6])
    i = helper.make_tensor_value_info('indices', TensorProto.INT32, [None, 3, 4, 5])
    y = helper.make_tensor_value_info('y', TensorProto.FLOAT, [2, 3, 4, 5])

    node = onnx.helper.make_node(
        'Gather',
        inputs=['data', 'indices'],
        outputs=['y'],
        axis=1,
    )

    return ([node], [x, i], [y])


@onnx_test()
def gather_elements_axis0_test():
    x = helper.make_tensor_value_info('data', TensorProto.FLOAT, [3, 4])
...
test/onnx/onnx_test.cpp

@@ -2048,6 +2048,46 @@ TEST_CASE(gather_test)
    EXPECT(p == prog);
}

TEST_CASE(gather_scalar_test)
{
    migraphx::program p;
    auto* mm = p.get_main_module();
    auto l0 =
        mm->add_parameter("data", migraphx::shape{migraphx::shape::float_type, {3, 4, 5, 6}});
    std::vector<size_t> idims{1};
    auto l1 =
        mm->add_parameter("indices", migraphx::shape{migraphx::shape::int32_type, idims, {0}});
    int axis = 1;
    mm->add_instruction(migraphx::make_op("gather", {{"axis", axis}}), l0, l1);
    auto prog = optimize_onnx("gather_scalar_test.onnx");

    EXPECT(p == prog);
}

TEST_CASE(gather_dyn_test)
{
    migraphx::program p;
    auto* mm = p.get_main_module();
    auto l0  = mm->add_parameter(
        "data",
        migraphx::shape{migraphx::shape::float_type, {{1, 4, 0}, {4, 4, 0}, {5, 5, 0}, {6, 6, 0}}});
    auto l1 = mm->add_parameter(
        "indices",
        migraphx::shape{migraphx::shape::int32_type, {{1, 4, 0}, {3, 3, 0}, {4, 4, 0}, {5, 5, 0}}});
    auto cont_l0 = mm->add_instruction(migraphx::make_op("contiguous"), l0);
    auto cont_l1 = mm->add_instruction(migraphx::make_op("contiguous"), l1);

    int axis       = 1;
    auto gather_op = migraphx::make_op("gather", {{"axis", axis}});
    auto ret       = mm->add_instruction(gather_op, cont_l0, cont_l1);
    mm->add_return({ret});

    migraphx::onnx_options options;
    options.default_dyn_dim_value = {1, 4, 0};
    auto prog                     = parse_onnx("gather_dyn_test.onnx", options);

    EXPECT(p == prog);
}

TEST_CASE(gather_elements_axis0_test)
{
    migraphx::program p;
...
test/op_shape_test.cpp

@@ -831,6 +831,77 @@ TEST_CASE(gather)
    }
}

TEST_CASE(gather_dyn0)
{
    // Insert dynamic index into dynamic shape
    migraphx::shape input{migraphx::shape::float_type,
                          {{2, 3, 2}, {3, 4, 3}, {6, 9, 7}, {12, 14, 13}}};
    migraphx::shape indices{migraphx::shape::int32_type, {{2, 7, 3}, {3, 3, 0}}};
    int axis = 1;
    expect_shape(migraphx::shape{migraphx::shape::float_type,
                                 {{2, 3, 2}, {2, 7, 3}, {3, 3, 0}, {6, 9, 7}, {12, 14, 13}}},
                 migraphx::make_op("gather", {{"axis", axis}}),
                 input,
                 indices);
}

TEST_CASE(gather_dyn1)
{
    // Insert static index into dynamic shape
    migraphx::shape input{migraphx::shape::float_type,
                          {{2, 3, 2}, {3, 4, 3}, {6, 9, 7}, {12, 14, 13}}};
    migraphx::shape indices{migraphx::shape::int32_type, {2, 3}};
    int axis = 1;
    expect_shape(migraphx::shape{migraphx::shape::float_type,
                                 {{2, 3, 2}, {2, 2, 0}, {3, 3, 0}, {6, 9, 7}, {12, 14, 13}}},
                 migraphx::make_op("gather", {{"axis", axis}}),
                 input,
                 indices);
}

TEST_CASE(gather_dyn2)
{
    // Insert scalar (static) index into dynamic shape
    migraphx::shape input{migraphx::shape::float_type,
                          {{2, 3, 2}, {3, 4, 3}, {6, 9, 7}, {12, 14, 13}}};
    std::vector<std::size_t> mins;
    std::vector<std::size_t> maxes;
    std::vector<std::size_t> opts;
    migraphx::shape indices{migraphx::shape::int32_type, mins, maxes, opts};
    int axis = 1;
    expect_shape(migraphx::shape{migraphx::shape::float_type,
                                 {{2, 3, 2}, {6, 9, 7}, {12, 14, 13}}},
                 migraphx::make_op("gather", {{"axis", axis}}),
                 input,
                 indices);
}

TEST_CASE(gather_dyn3)
{
    // Insert dynamic index into static shape, axis 1
    migraphx::shape input{migraphx::shape::float_type, {2, 3, 6, 12}};
    migraphx::shape indices{migraphx::shape::int32_type, {{2, 3, 2}, {3, 4, 3}}};
    int axis = 1;
    expect_shape(migraphx::shape{migraphx::shape::float_type,
                                 {{2, 2, 0}, {2, 3, 2}, {3, 4, 3}, {6, 6, 0}, {12, 12, 0}}},
                 migraphx::make_op("gather", {{"axis", axis}}),
                 input,
                 indices);
}

TEST_CASE(gather_dyn4)
{
    // Insert dynamic index into static shape, axis 0
    migraphx::shape input{migraphx::shape::float_type, {2, 3, 6, 12}};
    migraphx::shape indices{migraphx::shape::int32_type, {{2, 3, 2}, {3, 4, 3}}};
    int axis = 0;
    expect_shape(migraphx::shape{migraphx::shape::float_type,
                                 {{2, 3, 2}, {3, 4, 3}, {3, 3, 0}, {6, 6, 0}, {12, 12, 0}}},
                 migraphx::make_op("gather", {{"axis", axis}}),
                 input,
                 indices);
}

TEST_CASE(get_tuple_elem_test)
{
    migraphx::shape s0{migraphx::shape::bool_type, {1, 1}};
...
test/ref_ops_test.cpp

@@ -2524,6 +2524,78 @@ TEST_CASE(gather_test)
    }
}

TEST_CASE(gather_dyn_test0)
{
// Dynamic data, static indices
migraphx::program p;
auto* mm = p.get_main_module();
migraphx::shape s{migraphx::shape::int32_type, {{2, 5, 0}, {3, 3, 0}}};
auto x = mm->add_parameter("x", s);
std::vector<int> indices{1, 2};
migraphx::shape s_ind{migraphx::shape::int32_type, {1, 2}};
auto ind = mm->add_parameter("indices", s_ind);
mm->add_instruction(migraphx::make_op("gather", {{"axis", 1}}), x, ind);
migraphx::shape sresult{migraphx::shape::int32_type, {{2, 5, 0}, {1, 1, 0}, {2, 2, 0}}};
EXPECT(p.get_output_shapes().back() == sresult);
p.compile(migraphx::ref::target{});
migraphx::shape input_fixed_shape{migraphx::shape::int32_type, {2, 3}};
migraphx::shape input_indices{migraphx::shape::int32_type, {1, 2}};
migraphx::parameter_map params;
std::vector<int> data(2 * 3);
std::iota(data.begin(), data.end(), 0);
params["x"] = migraphx::argument(input_fixed_shape, data.data());
params["indices"] = migraphx::argument(input_indices, indices.data());
auto result = p.eval(params).back();
std::vector<int> gold = {1, 2, 4, 5};
std::vector<int> results_vector(2 * 1 * 2);
result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); });
EXPECT(migraphx::verify_range(results_vector, gold));
migraphx::shape sfinal{migraphx::shape::int32_type, {2, 1, 2}};
EXPECT(result.get_shape() == sfinal);
}
TEST_CASE(gather_dyn_test1)
{
// Dynamic data, dynamic indices
migraphx::program p;
auto* mm = p.get_main_module();
migraphx::shape s{migraphx::shape::int32_type, {{2, 5, 0}, {4, 4, 0}}};
auto x = mm->add_parameter("x", s);
migraphx::shape s_ind{migraphx::shape::int32_type, {{1, 8, 7}, {2, 3, 3}}};
auto ind = mm->add_parameter("indices", s_ind);
mm->add_instruction(migraphx::make_op("gather", {{"axis", 0}}), x, ind);
migraphx::shape sresult{migraphx::shape::int32_type, {{1, 8, 7}, {2, 3, 3}, {4, 4, 0}}};
EXPECT(p.get_output_shapes().back() == sresult);
p.compile(migraphx::ref::target{});
migraphx::shape input_fixed_shape{migraphx::shape::int32_type, {3, 4}};
migraphx::shape input_indices_shape{migraphx::shape::int32_type, {1, 2}};
std::vector<int> indices{2, 0};
migraphx::parameter_map params;
std::vector<int> data(3 * 4);
std::iota(data.begin(), data.end(), 0);
params["x"] = migraphx::argument(input_fixed_shape, data.data());
params["indices"] = migraphx::argument(input_indices_shape, indices.data());
auto result = p.eval(params).back();
std::vector<int> gold = {8, 9, 10, 11, 0, 1, 2, 3};
std::vector<int> results_vector(1 * 2 * 4);
result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); });
EXPECT(migraphx::verify_range(results_vector, gold));
migraphx::shape sfinal{migraphx::shape::int32_type, {1, 2, 4}};
EXPECT(result.get_shape() == sfinal);
}
TEST_CASE(gathernd_test)
{
    {
...