Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
MIGraphX
Commits
b4bba9a0
Commit
b4bba9a0
authored
Jan 25, 2023
by
Paul
Browse files
Format
parent
94d93226
Changes
2
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
48 additions
and
31 deletions
+48
-31
src/fuse_reduce.cpp
src/fuse_reduce.cpp
+26
-17
src/targets/gpu/compile_gen.cpp
src/targets/gpu/compile_gen.cpp
+22
-14
No files found.
src/fuse_reduce.cpp
View file @
b4bba9a0
...
...
@@ -91,10 +91,12 @@ get_ins_param_map(const std::vector<instruction_ref>& inputs, const_module_ref s
return
result
;
}
static
void
insert_params
(
module_ref
sm
,
instruction_ref
ins
,
std
::
unordered_map
<
instruction_ref
,
instruction_ref
>&
map_ins
)
static
void
insert_params
(
module_ref
sm
,
instruction_ref
ins
,
std
::
unordered_map
<
instruction_ref
,
instruction_ref
>&
map_ins
)
{
auto
n
=
sm
->
get_parameter_shapes
().
size
();
for
(
auto
input
:
ins
->
inputs
())
for
(
auto
input
:
ins
->
inputs
())
{
if
(
contains
(
map_ins
,
input
))
continue
;
...
...
@@ -103,7 +105,9 @@ static void insert_params(module_ref sm, instruction_ref ins, std::unordered_map
}
}
static
auto
insert_ins_in_submodule
(
module_ref
sm
,
instruction_ref
ins
,
std
::
unordered_map
<
instruction_ref
,
instruction_ref
>&
map_ins
)
static
auto
insert_ins_in_submodule
(
module_ref
sm
,
instruction_ref
ins
,
std
::
unordered_map
<
instruction_ref
,
instruction_ref
>&
map_ins
)
{
insert_params
(
sm
,
ins
,
map_ins
);
return
sm
->
add_instructions
({
ins
},
map_ins
);
...
...
@@ -115,30 +119,34 @@ static auto insert_ins_in_submodule(module_ref sm, instruction_ref ins)
return
insert_ins_in_submodule
(
sm
,
ins
,
map_ins
);
}
static
auto
insert_module_in_submodule
(
module_ref
sm
,
instruction_ref
ins
,
std
::
unordered_map
<
instruction_ref
,
instruction_ref
>&
map_ins
)
static
auto
insert_module_in_submodule
(
module_ref
sm
,
instruction_ref
ins
,
std
::
unordered_map
<
instruction_ref
,
instruction_ref
>&
map_ins
)
{
insert_params
(
sm
,
ins
,
map_ins
);
auto
*
m
=
ins
->
module_inputs
().
front
();
auto
*
m
=
ins
->
module_inputs
().
front
();
auto
param_map
=
get_ins_param_map
(
ins
->
inputs
(),
m
);
for
(
auto
&&
[
input
,
param
]
:
param_map
)
for
(
auto
&&
[
input
,
param
]
:
param_map
)
{
map_ins
[
param
]
=
map_ins
.
at
(
input
);
}
return
sm
->
add_instructions
(
m
,
map_ins
);
}
static
std
::
vector
<
instruction_ref
>
find_inputs
(
module_ref
sm
,
const
std
::
unordered_map
<
instruction_ref
,
instruction_ref
>&
map_ins
)
static
std
::
vector
<
instruction_ref
>
find_inputs
(
module_ref
sm
,
const
std
::
unordered_map
<
instruction_ref
,
instruction_ref
>&
map_ins
)
{
std
::
vector
<
instruction_ref
>
result
;
std
::
map
<
std
::
string
,
instruction_ref
>
names
;
for
(
auto
&&
[
input
,
param
]
:
map_ins
)
for
(
auto
&&
[
input
,
param
]
:
map_ins
)
{
if
(
not
sm
->
has_instruction
(
param
))
continue
;
if
(
param
->
name
()
!=
"@param"
)
continue
;
auto
v
=
param
->
get_operator
().
to_value
();
auto
name
=
v
.
at
(
"parameter"
).
to
<
std
::
string
>
();
auto
v
=
param
->
get_operator
().
to_value
();
auto
name
=
v
.
at
(
"parameter"
).
to
<
std
::
string
>
();
names
[
name
]
=
input
;
}
std
::
transform
(
names
.
begin
(),
names
.
end
(),
std
::
back_inserter
(
result
),
[](
const
auto
&
p
)
{
...
...
@@ -182,22 +190,23 @@ struct find_pointwise_reduce
{
auto
matcher
()
const
{
return
match
::
name
(
"fused_reduce"
)(
match
::
any_of
[
match
::
inputs
()](
match
::
name
(
"pointwise"
)(
match
::
used_once
()).
bind
(
"pointwise"
)));
return
match
::
name
(
"fused_reduce"
)(
match
::
any_of
[
match
::
inputs
()](
match
::
name
(
"pointwise"
)(
match
::
used_once
()).
bind
(
"pointwise"
)));
}
void
apply
(
module_pass_manager
&
mpm
,
const
match
::
matcher_result
&
r
)
const
{
auto
reduce
=
r
.
result
;
auto
pw
=
r
.
instructions
[
"pointwise"
];
auto
reduce
=
r
.
result
;
auto
pw
=
r
.
instructions
[
"pointwise"
];
const
auto
*
pm
=
pw
->
module_inputs
().
front
();
// const auto* old_rm = reduce->module_inputs().front();
auto
*
rm
=
mpm
.
create_module
(
pm
->
name
()
+
":reduce"
);
auto
*
rm
=
mpm
.
create_module
(
pm
->
name
()
+
":reduce"
);
rm
->
set_bypass
();
std
::
unordered_map
<
instruction_ref
,
instruction_ref
>
map_ins
;
// Insert pointwise
auto
rins
=
insert_ins_in_submodule
(
rm
,
pw
,
map_ins
).
front
();
auto
rins
=
insert_ins_in_submodule
(
rm
,
pw
,
map_ins
).
front
();
map_ins
[
pw
]
=
rins
;
// Insert fused_reduce
insert_module_in_submodule
(
rm
,
reduce
,
map_ins
);
...
...
@@ -217,7 +226,7 @@ struct find_reduce_pointwise
void
apply
(
module_pass_manager
&
mpm
,
const
match
::
matcher_result
&
r
)
const
{
auto
pw
=
r
.
result
;
auto
pw
=
r
.
result
;
auto
reduce
=
r
.
instructions
[
"reduce"
];
const
auto
*
old_rm
=
reduce
->
module_inputs
().
front
();
...
...
@@ -235,7 +244,7 @@ struct find_reduce_pointwise
mpm
.
get_module
().
replace_instruction
(
pw
,
reduce
->
get_operator
(),
new_inputs
,
{
rm
});
}
};
}
}
// namespace
void
fuse_reduce
::
apply
(
module_pass_manager
&
mpm
)
const
{
...
...
src/targets/gpu/compile_gen.cpp
View file @
b4bba9a0
...
...
@@ -264,7 +264,7 @@ std::string generate_reduce(const module& rm, const std::string& name)
{
module
m
=
rm
;
cpp_generator
g
;
auto
ilens
=
rm
.
get_parameter_shapes
().
begin
()
->
second
.
lens
();
auto
ilens
=
rm
.
get_parameter_shapes
().
begin
()
->
second
.
lens
();
std
::
size_t
i
=
0
;
auto
f
=
g
.
generate_module
(
m
,
[
&
](
instruction_ref
ins
,
const
auto
&
names
)
{
if
(
contains
(
ins
->
name
(),
"reduce"
))
...
...
@@ -277,23 +277,31 @@ std::string generate_reduce(const module& rm, const std::string& name)
i
++
;
generate_pointwise
(
g
,
*
ins
->
module_inputs
().
front
(),
pointwise_name
);
std
::
vector
<
instruction_ref
>
tensors
;
std
::
copy_if
(
ins
->
inputs
().
begin
(),
ins
->
inputs
().
end
(),
std
::
back_inserter
(
tensors
),
[
&
](
auto
input
)
{
return
input
->
get_shape
().
lens
()
==
ilens
and
not
input
->
get_shape
().
broadcasted
();
});
std
::
copy_if
(
ins
->
inputs
().
begin
(),
ins
->
inputs
().
end
(),
std
::
back_inserter
(
tensors
),
[
&
](
auto
input
)
{
return
input
->
get_shape
().
lens
()
==
ilens
and
not
input
->
get_shape
().
broadcasted
();
});
auto
inner_names
=
names
;
for
(
auto
input
:
tensors
)
for
(
auto
input
:
tensors
)
inner_names
[
input
]
+=
"_lambda_param"
;
auto
call_function
=
pointwise_name
+
"("
+
join_strings
(
cpp_generator
::
to_args
(
ins
->
inputs
(),
inner_names
),
", "
)
+
")"
;
if
(
tensors
.
empty
())
auto
call_function
=
pointwise_name
+
"("
+
join_strings
(
cpp_generator
::
to_args
(
ins
->
inputs
(),
inner_names
),
", "
)
+
")"
;
if
(
tensors
.
empty
())
return
call_function
;
const
std
::
string
inner_template
=
"r.inner([=](${params}) { return ${call}; })(${args})"
;
auto
args
=
cpp_generator
::
to_args
(
tensors
,
names
);
const
std
::
string
inner_template
=
"r.inner([=](${params}) { return ${call}; })(${args})"
;
auto
args
=
cpp_generator
::
to_args
(
tensors
,
names
);
auto
params
=
cpp_generator
::
to_args
(
tensors
,
inner_names
);
std
::
transform
(
params
.
begin
(),
params
.
end
(),
params
.
begin
(),
[](
auto
s
)
{
return
"auto "
+
s
;
});
return
interpolate_string
(
inner_template
,
{{
"params"
,
join_strings
(
params
,
", "
)},
{
"args"
,
join_strings
(
args
,
", "
)},
{
"call"
,
call_function
}});
std
::
transform
(
params
.
begin
(),
params
.
end
(),
params
.
begin
(),
[](
auto
s
)
{
return
"auto "
+
s
;
});
return
interpolate_string
(
inner_template
,
{{
"params"
,
join_strings
(
params
,
", "
)},
{
"args"
,
join_strings
(
args
,
", "
)},
{
"call"
,
call_function
}});
}
MIGRAPHX_THROW
(
"Unknown operator: "
+
ins
->
name
());
});
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment