Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
MIGraphX
Commits
b4bba9a0
Commit
b4bba9a0
authored
Jan 25, 2023
by
Paul
Browse files
Format
parent
94d93226
Changes
2
Show whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
48 additions
and
31 deletions
+48
-31
src/fuse_reduce.cpp
src/fuse_reduce.cpp
+26
-17
src/targets/gpu/compile_gen.cpp
src/targets/gpu/compile_gen.cpp
+22
-14
No files found.
src/fuse_reduce.cpp
View file @
b4bba9a0
...
@@ -91,10 +91,12 @@ get_ins_param_map(const std::vector<instruction_ref>& inputs, const_module_ref s
...
@@ -91,10 +91,12 @@ get_ins_param_map(const std::vector<instruction_ref>& inputs, const_module_ref s
return
result
;
return
result
;
}
}
static
void
insert_params
(
module_ref
sm
,
instruction_ref
ins
,
std
::
unordered_map
<
instruction_ref
,
instruction_ref
>&
map_ins
)
static
void
insert_params
(
module_ref
sm
,
instruction_ref
ins
,
std
::
unordered_map
<
instruction_ref
,
instruction_ref
>&
map_ins
)
{
{
auto
n
=
sm
->
get_parameter_shapes
().
size
();
auto
n
=
sm
->
get_parameter_shapes
().
size
();
for
(
auto
input
:
ins
->
inputs
())
for
(
auto
input
:
ins
->
inputs
())
{
{
if
(
contains
(
map_ins
,
input
))
if
(
contains
(
map_ins
,
input
))
continue
;
continue
;
...
@@ -103,7 +105,9 @@ static void insert_params(module_ref sm, instruction_ref ins, std::unordered_map
...
@@ -103,7 +105,9 @@ static void insert_params(module_ref sm, instruction_ref ins, std::unordered_map
}
}
}
}
static
auto
insert_ins_in_submodule
(
module_ref
sm
,
instruction_ref
ins
,
std
::
unordered_map
<
instruction_ref
,
instruction_ref
>&
map_ins
)
static
auto
insert_ins_in_submodule
(
module_ref
sm
,
instruction_ref
ins
,
std
::
unordered_map
<
instruction_ref
,
instruction_ref
>&
map_ins
)
{
{
insert_params
(
sm
,
ins
,
map_ins
);
insert_params
(
sm
,
ins
,
map_ins
);
return
sm
->
add_instructions
({
ins
},
map_ins
);
return
sm
->
add_instructions
({
ins
},
map_ins
);
...
@@ -115,23 +119,27 @@ static auto insert_ins_in_submodule(module_ref sm, instruction_ref ins)
...
@@ -115,23 +119,27 @@ static auto insert_ins_in_submodule(module_ref sm, instruction_ref ins)
return
insert_ins_in_submodule
(
sm
,
ins
,
map_ins
);
return
insert_ins_in_submodule
(
sm
,
ins
,
map_ins
);
}
}
static
auto
insert_module_in_submodule
(
module_ref
sm
,
instruction_ref
ins
,
std
::
unordered_map
<
instruction_ref
,
instruction_ref
>&
map_ins
)
static
auto
insert_module_in_submodule
(
module_ref
sm
,
instruction_ref
ins
,
std
::
unordered_map
<
instruction_ref
,
instruction_ref
>&
map_ins
)
{
{
insert_params
(
sm
,
ins
,
map_ins
);
insert_params
(
sm
,
ins
,
map_ins
);
auto
*
m
=
ins
->
module_inputs
().
front
();
auto
*
m
=
ins
->
module_inputs
().
front
();
auto
param_map
=
get_ins_param_map
(
ins
->
inputs
(),
m
);
auto
param_map
=
get_ins_param_map
(
ins
->
inputs
(),
m
);
for
(
auto
&&
[
input
,
param
]
:
param_map
)
for
(
auto
&&
[
input
,
param
]
:
param_map
)
{
{
map_ins
[
param
]
=
map_ins
.
at
(
input
);
map_ins
[
param
]
=
map_ins
.
at
(
input
);
}
}
return
sm
->
add_instructions
(
m
,
map_ins
);
return
sm
->
add_instructions
(
m
,
map_ins
);
}
}
static
std
::
vector
<
instruction_ref
>
find_inputs
(
module_ref
sm
,
const
std
::
unordered_map
<
instruction_ref
,
instruction_ref
>&
map_ins
)
static
std
::
vector
<
instruction_ref
>
find_inputs
(
module_ref
sm
,
const
std
::
unordered_map
<
instruction_ref
,
instruction_ref
>&
map_ins
)
{
{
std
::
vector
<
instruction_ref
>
result
;
std
::
vector
<
instruction_ref
>
result
;
std
::
map
<
std
::
string
,
instruction_ref
>
names
;
std
::
map
<
std
::
string
,
instruction_ref
>
names
;
for
(
auto
&&
[
input
,
param
]
:
map_ins
)
for
(
auto
&&
[
input
,
param
]
:
map_ins
)
{
{
if
(
not
sm
->
has_instruction
(
param
))
if
(
not
sm
->
has_instruction
(
param
))
continue
;
continue
;
...
@@ -182,7 +190,8 @@ struct find_pointwise_reduce
...
@@ -182,7 +190,8 @@ struct find_pointwise_reduce
{
{
auto
matcher
()
const
auto
matcher
()
const
{
{
return
match
::
name
(
"fused_reduce"
)(
match
::
any_of
[
match
::
inputs
()](
match
::
name
(
"pointwise"
)(
match
::
used_once
()).
bind
(
"pointwise"
)));
return
match
::
name
(
"fused_reduce"
)(
match
::
any_of
[
match
::
inputs
()](
match
::
name
(
"pointwise"
)(
match
::
used_once
()).
bind
(
"pointwise"
)));
}
}
void
apply
(
module_pass_manager
&
mpm
,
const
match
::
matcher_result
&
r
)
const
void
apply
(
module_pass_manager
&
mpm
,
const
match
::
matcher_result
&
r
)
const
...
@@ -235,7 +244,7 @@ struct find_reduce_pointwise
...
@@ -235,7 +244,7 @@ struct find_reduce_pointwise
mpm
.
get_module
().
replace_instruction
(
pw
,
reduce
->
get_operator
(),
new_inputs
,
{
rm
});
mpm
.
get_module
().
replace_instruction
(
pw
,
reduce
->
get_operator
(),
new_inputs
,
{
rm
});
}
}
};
};
}
}
// namespace
void
fuse_reduce
::
apply
(
module_pass_manager
&
mpm
)
const
void
fuse_reduce
::
apply
(
module_pass_manager
&
mpm
)
const
{
{
...
...
src/targets/gpu/compile_gen.cpp
View file @
b4bba9a0
...
@@ -277,23 +277,31 @@ std::string generate_reduce(const module& rm, const std::string& name)
...
@@ -277,23 +277,31 @@ std::string generate_reduce(const module& rm, const std::string& name)
i
++
;
i
++
;
generate_pointwise
(
g
,
*
ins
->
module_inputs
().
front
(),
pointwise_name
);
generate_pointwise
(
g
,
*
ins
->
module_inputs
().
front
(),
pointwise_name
);
std
::
vector
<
instruction_ref
>
tensors
;
std
::
vector
<
instruction_ref
>
tensors
;
std
::
copy_if
(
ins
->
inputs
().
begin
(),
ins
->
inputs
().
end
(),
std
::
back_inserter
(
tensors
),
[
&
](
auto
input
)
{
std
::
copy_if
(
ins
->
inputs
().
begin
(),
return
input
->
get_shape
().
lens
()
==
ilens
and
not
input
->
get_shape
().
broadcasted
();
ins
->
inputs
().
end
(),
std
::
back_inserter
(
tensors
),
[
&
](
auto
input
)
{
return
input
->
get_shape
().
lens
()
==
ilens
and
not
input
->
get_shape
().
broadcasted
();
});
});
auto
inner_names
=
names
;
auto
inner_names
=
names
;
for
(
auto
input
:
tensors
)
for
(
auto
input
:
tensors
)
inner_names
[
input
]
+=
"_lambda_param"
;
inner_names
[
input
]
+=
"_lambda_param"
;
auto
call_function
=
pointwise_name
+
"("
+
auto
call_function
=
pointwise_name
+
"("
+
join_strings
(
cpp_generator
::
to_args
(
ins
->
inputs
(),
inner_names
),
", "
)
+
")"
;
join_strings
(
cpp_generator
::
to_args
(
ins
->
inputs
(),
inner_names
),
", "
)
+
")"
;
if
(
tensors
.
empty
())
if
(
tensors
.
empty
())
return
call_function
;
return
call_function
;
const
std
::
string
inner_template
=
"r.inner([=](${params}) { return ${call}; })(${args})"
;
const
std
::
string
inner_template
=
"r.inner([=](${params}) { return ${call}; })(${args})"
;
auto
args
=
cpp_generator
::
to_args
(
tensors
,
names
);
auto
args
=
cpp_generator
::
to_args
(
tensors
,
names
);
auto
params
=
cpp_generator
::
to_args
(
tensors
,
inner_names
);
auto
params
=
cpp_generator
::
to_args
(
tensors
,
inner_names
);
std
::
transform
(
params
.
begin
(),
params
.
end
(),
params
.
begin
(),
[](
auto
s
)
{
std
::
transform
(
return
"auto "
+
s
;
params
.
begin
(),
params
.
end
(),
params
.
begin
(),
[](
auto
s
)
{
return
"auto "
+
s
;
});
});
return
interpolate_string
(
inner_template
,
return
interpolate_string
(
inner_template
,
{{
"params"
,
join_strings
(
params
,
", "
)},
{
"args"
,
join_strings
(
args
,
", "
)},
{
"call"
,
call_function
}});
{{
"params"
,
join_strings
(
params
,
", "
)},
{
"args"
,
join_strings
(
args
,
", "
)},
{
"call"
,
call_function
}});
}
}
MIGRAPHX_THROW
(
"Unknown operator: "
+
ins
->
name
());
MIGRAPHX_THROW
(
"Unknown operator: "
+
ins
->
name
());
});
});
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment