Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
MIGraphX
Commits
4e55c401
Commit
4e55c401
authored
Jan 20, 2023
by
Paul
Browse files
Some fixes
parent
40762b08
Changes
2
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
20 additions
and
15 deletions
+20
-15
src/fuse_reduce.cpp
src/fuse_reduce.cpp
+19
-14
src/targets/gpu/prefuse_ops.cpp
src/targets/gpu/prefuse_ops.cpp
+1
-1
No files found.
src/fuse_reduce.cpp
View file @
4e55c401
...
...
@@ -31,6 +31,7 @@
#include <migraphx/ranges.hpp>
#include <migraphx/check_shapes.hpp>
#include <migraphx/matcher.hpp>
#include <migraphx/register_op.hpp>
#include <iterator>
namespace
migraphx
{
...
...
@@ -49,24 +50,27 @@ struct fused_reduce
shape
compute_shape
(
const
std
::
vector
<
shape
>&
inputs
,
std
::
vector
<
module_ref
>
mods
)
const
{
if
(
mods
.
size
()
!=
1
)
{
MIGRAPHX_THROW
(
"should have one submodule."
);
}
auto
*
sm
=
mods
.
front
();
if
(
sm
->
get_output_shapes
().
size
()
!=
1
)
MIGRAPHX_THROW
(
"Only one output supported"
);
check_shapes
{
inputs
,
*
this
}.
has
(
sm
->
get_parameter_shapes
().
size
()).
same_dims
();
auto
s
=
inputs
.
at
(
0
);
auto
lens
=
s
.
lens
();
for
(
const
auto
&
axis
:
axes
)
if
(
lens
!=
sm
->
get_output_shapes
().
front
().
lens
()
)
{
lens
[
axis
]
=
1
;
for
(
const
auto
&
axis
:
axes
)
{
lens
[
axis
]
=
1
;
}
}
if
(
sm
->
get_output_shapes
().
size
()
!=
1
)
MIGRAPHX_THROW
(
"Only one output supported"
);
return
inputs
[
0
].
with_lens
(
sm
->
get_output_shapes
().
front
().
type
(),
lens
);
return
shape
::
from_permutation
(
sm
->
get_output_shapes
().
front
().
type
(),
lens
,
find_permutation
(
inputs
));
}
std
::
string
name
()
const
{
return
"fused_reduce"
;
}
};
MIGRAPHX_REGISTER_OP
(
fused_reduce
);
static
void
create_reduce_modules
(
module_pass_manager
&
mpm
)
{
...
...
@@ -87,8 +91,8 @@ static void create_reduce_modules(module_pass_manager& mpm)
auto
r
=
rm
->
add_instruction
(
ins
->
get_operator
(),
x0
);
rm
->
add_return
({
r
});
// TODO: Set axes
mpm
.
get_module
().
replace_instruction
(
ins
,
make_op
(
"fused_reduce"
),
ins
->
inputs
(),
{
rm
});
auto
v
=
ins
->
get_operator
().
to_value
();
mpm
.
get_module
().
replace_instruction
(
ins
,
make_op
(
"fused_reduce"
,
{{
"axes"
,
v
[
"axes"
]}}
),
ins
->
inputs
(),
{
rm
});
}
}
...
...
@@ -130,10 +134,11 @@ struct find_reduce_pointwise
auto
ins
=
r
.
result
;
auto
reduce
=
r
.
instructions
[
"reduce"
];
auto
*
old_rm
=
reduce
->
module_inputs
().
front
();
const
auto
*
old_rm
=
reduce
->
module_inputs
().
front
();
auto
*
rm
=
mpm
.
create_module
(
old_rm
->
name
()
+
":pointwise"
);
// Copy module
*
rm
=
*
old_rm
;
rm
->
set_bypass
();
// Copy module instructions
rm
->
add_instructions
(
old_rm
);
auto
map_ins
=
get_ins_param_map
(
reduce
->
inputs
(),
rm
);
auto
new_inputs
=
reduce
->
inputs
();
for
(
auto
input
:
ins
->
inputs
())
...
...
@@ -152,8 +157,8 @@ struct find_reduce_pointwise
}
}
auto
out
=
rm
->
insert
_instructions
(
std
::
prev
(
rm
->
end
()),
{
ins
},
map_ins
);
rm
->
replace
_return
(
out
);
auto
out
=
rm
->
add
_instructions
({
ins
},
map_ins
);
rm
->
add
_return
(
out
);
mpm
.
get_module
().
replace_instruction
(
ins
,
reduce
->
get_operator
(),
new_inputs
,
{
rm
});
}
};
...
...
src/targets/gpu/prefuse_ops.cpp
View file @
4e55c401
...
...
@@ -116,7 +116,7 @@ struct find_add_layernorm
void
prefuse_ops
::
apply
(
module
&
m
)
const
{
match
::
find_matches
(
m
,
find_add_layernorm
{},
find_layernorm
{});
//
match::find_matches(m, find_add_layernorm{}, find_layernorm{});
}
}
// namespace gpu
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment