Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
MIGraphX
Commits
c5de9766
Commit
c5de9766
authored
Jan 20, 2023
by
Paul
Browse files
Update funtion name
parent
6bc30259
Changes
1
Hide whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
25 additions
and
30 deletions
+25
-30
src/fuse_reduce.cpp
src/fuse_reduce.cpp
+25
-30
No files found.
src/fuse_reduce.cpp
View file @
c5de9766
...
@@ -52,15 +52,15 @@ struct fused_reduce
...
@@ -52,15 +52,15 @@ struct fused_reduce
{
{
MIGRAPHX_THROW
(
"should have one submodule."
);
MIGRAPHX_THROW
(
"should have one submodule."
);
}
}
auto
*
sm
=
mods
.
front
();
auto
*
sm
=
mods
.
front
();
check_shapes
{
inputs
,
*
this
}.
has
(
sm
->
get_parameter_shapes
().
size
()).
same_dims
();
check_shapes
{
inputs
,
*
this
}.
has
(
sm
->
get_parameter_shapes
().
size
()).
same_dims
();
auto
s
=
inputs
.
at
(
0
);
auto
s
=
inputs
.
at
(
0
);
auto
lens
=
s
.
lens
();
auto
lens
=
s
.
lens
();
for
(
const
auto
&
axis
:
axes
)
for
(
const
auto
&
axis
:
axes
)
{
{
lens
[
axis
]
=
1
;
lens
[
axis
]
=
1
;
}
}
if
(
sm
->
get_output_shapes
().
size
()
!=
1
)
if
(
sm
->
get_output_shapes
().
size
()
!=
1
)
MIGRAPHX_THROW
(
"Only one output supported"
);
MIGRAPHX_THROW
(
"Only one output supported"
);
return
inputs
[
0
].
with_lens
(
sm
->
get_output_shapes
().
front
().
type
(),
lens
);
return
inputs
[
0
].
with_lens
(
sm
->
get_output_shapes
().
front
().
type
(),
lens
);
}
}
...
@@ -75,16 +75,15 @@ static void create_reduce_modules(module_pass_manager& mpm)
...
@@ -75,16 +75,15 @@ static void create_reduce_modules(module_pass_manager& mpm)
{
{
if
(
not
ins
->
get_operator
().
attributes
().
get
(
"reduce"
,
false
))
if
(
not
ins
->
get_operator
().
attributes
().
get
(
"reduce"
,
false
))
continue
;
continue
;
if
(
ins
->
inputs
().
size
()
!=
1
)
if
(
ins
->
inputs
().
size
()
!=
1
)
continue
;
continue
;
auto
*
rm
=
auto
*
rm
=
mpm
.
create_module
(
mpm
.
get_module
().
name
()
+
":"
+
ins
->
name
()
+
std
::
to_string
(
n
++
));
mpm
.
create_module
(
mpm
.
get_module
().
name
()
+
":"
+
ins
->
name
()
+
std
::
to_string
(
n
++
));
rm
->
set_bypass
();
rm
->
set_bypass
();
// TODO: Ensure standard shape
// TODO: Ensure standard shape
auto
x0
=
rm
->
add_parameter
(
"x0"
,
ins
->
inputs
().
front
()
->
get_shape
());
auto
x0
=
rm
->
add_parameter
(
"x0"
,
ins
->
inputs
().
front
()
->
get_shape
());
auto
r
=
rm
->
add_instruction
(
ins
->
get_operator
(),
x0
);
auto
r
=
rm
->
add_instruction
(
ins
->
get_operator
(),
x0
);
rm
->
add_return
({
r
});
rm
->
add_return
({
r
});
// TODO: Set axes
// TODO: Set axes
...
@@ -92,27 +91,22 @@ static void create_reduce_modules(module_pass_manager& mpm)
...
@@ -92,27 +91,22 @@ static void create_reduce_modules(module_pass_manager& mpm)
}
}
}
}
static
std
::
unordered_map
<
instruction_ref
,
instruction_ref
>
static
std
::
unordered_map
<
instruction_ref
,
instruction_ref
>
get_ins_param_map
(
const
std
::
vector
<
instruction_ref
>&
inputs
,
const_module_ref
sm
)
get_param_map
(
const
std
::
vector
<
instruction_ref
>&
inputs
,
const_module_ref
sm
)
{
{
std
::
unordered_map
<
instruction_ref
,
instruction_ref
>
result
;
std
::
unordered_map
<
instruction_ref
,
instruction_ref
>
result
;
auto
names
=
sm
->
get_parameter_names
();
auto
names
=
sm
->
get_parameter_names
();
std
::
sort
(
names
.
begin
(),
names
.
end
());
std
::
sort
(
names
.
begin
(),
names
.
end
());
assert
(
names
.
size
()
==
inputs
.
size
());
assert
(
names
.
size
()
==
inputs
.
size
());
std
::
transform
(
names
.
begin
(),
std
::
transform
(
names
.
begin
(),
names
.
end
(),
inputs
.
begin
(),
std
::
inserter
(
result
,
result
.
end
()),
[
&
](
const
auto
&
name
,
auto
input
)
{
names
.
end
(),
return
std
::
make_pair
(
input
,
sm
->
get_parameter
(
name
));
inputs
.
begin
(),
});
std
::
inserter
(
result
,
result
.
end
()),
[
&
](
const
auto
&
name
,
auto
input
)
{
return
std
::
make_pair
(
input
,
sm
->
get_parameter
(
name
));
});
return
result
;
return
result
;
}
}
static
std
::
vector
<
instruction_ref
>
get_returns
(
module
&
m
)
static
std
::
vector
<
instruction_ref
>
get_returns
(
module
&
m
)
{
{
auto
last
=
std
::
prev
(
m
.
end
());
auto
last
=
std
::
prev
(
m
.
end
());
if
(
last
->
name
()
==
"@return"
)
if
(
last
->
name
()
==
"@return"
)
return
last
->
inputs
();
return
last
->
inputs
();
return
{
last
};
return
{
last
};
}
}
...
@@ -121,32 +115,33 @@ struct find_reduce_pointwise
...
@@ -121,32 +115,33 @@ struct find_reduce_pointwise
{
{
auto
matcher
()
const
auto
matcher
()
const
{
{
return
match
::
name
(
"pointwise"
)(
match
::
any_of
[
match
::
inputs
()](
return
match
::
name
(
"pointwise"
)(
match
::
any_of
[
match
::
inputs
()](
match
::
name
(
"fused_reduce"
)(
match
::
used_once
()).
bind
(
"reduce"
)));
match
::
name
(
"fused_reduce"
)(
match
::
used_once
()).
bind
(
"reduce"
)));
}
}
void
apply
(
module_pass_manager
&
mpm
,
const
match
::
matcher_result
&
r
)
const
void
apply
(
module_pass_manager
&
mpm
,
const
match
::
matcher_result
&
r
)
const
{
{
auto
ins
=
r
.
result
;
auto
ins
=
r
.
result
;
auto
reduce
=
r
.
instructions
[
"reduce"
];
auto
reduce
=
r
.
instructions
[
"reduce"
];
auto
*
old_rm
=
reduce
->
module_inputs
().
front
();
auto
*
old_rm
=
reduce
->
module_inputs
().
front
();
auto
*
rm
=
mpm
.
create_module
(
old_rm
->
name
()
+
":pointwise"
);
auto
*
rm
=
mpm
.
create_module
(
old_rm
->
name
()
+
":pointwise"
);
// Copy module
// Copy module
*
rm
=
*
old_rm
;
*
rm
=
*
old_rm
;
auto
map_ins
=
get_param_map
(
reduce
->
inputs
(),
rm
);
auto
map_ins
=
get_
ins_
param_map
(
reduce
->
inputs
(),
rm
);
auto
new_inputs
=
reduce
->
inputs
();
auto
new_inputs
=
reduce
->
inputs
();
for
(
auto
input
:
ins
->
inputs
())
for
(
auto
input
:
ins
->
inputs
())
{
{
if
(
contains
(
map_ins
,
input
))
if
(
contains
(
map_ins
,
input
))
continue
;
continue
;
if
(
input
==
reduce
)
if
(
input
==
reduce
)
{
{
map_ins
[
input
]
=
rm
->
map_ins
[
input
]
=
get_returns
(
*
rm
).
front
();
}
else
{
map_ins
[
input
]
=
rm
->
add_parameter
(
"x"
+
std
::
to_string
(
new_inputs
.
size
()),
input
->
get_shape
());
new_inputs
.
push_back
(
input
);
}
}
map_ins
[
input
]
=
rm
->
add_parameter
(
"x"
+
std
::
to_string
(
new_inputs
.
size
()),
input
->
get_shape
());
new_inputs
.
push_back
(
input
);
}
}
auto
out
=
rm
->
insert_instructions
(
std
::
prev
(
rm
->
end
()),
{
ins
},
map_ins
);
auto
out
=
rm
->
insert_instructions
(
std
::
prev
(
rm
->
end
()),
{
ins
},
map_ins
);
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment