Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
MIGraphX
Commits
94d93226
Commit
94d93226
authored
Jan 25, 2023
by
Paul
Browse files
Fuse pointwise before as well
parent
50b471d5
Changes
3
Hide whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
136 additions
and
49 deletions
+136
-49
src/fuse_reduce.cpp
src/fuse_reduce.cpp
+116
-44
src/targets/gpu/compile_gen.cpp
src/targets/gpu/compile_gen.cpp
+19
-5
src/targets/gpu/jit/fused_reduce.cpp
src/targets/gpu/jit/fused_reduce.cpp
+1
-0
No files found.
src/fuse_reduce.cpp
View file @
94d93226
...
...
@@ -33,6 +33,7 @@
#include <migraphx/matcher.hpp>
#include <migraphx/register_op.hpp>
#include <iterator>
#include <map>
namespace
migraphx
{
inline
namespace
MIGRAPHX_INLINE_NS
{
...
...
@@ -73,6 +74,79 @@ struct fused_reduce
};
MIGRAPHX_REGISTER_OP
(
fused_reduce
);
static
std
::
unordered_map
<
instruction_ref
,
instruction_ref
>
get_ins_param_map
(
const
std
::
vector
<
instruction_ref
>&
inputs
,
const_module_ref
sm
)
{
std
::
unordered_map
<
instruction_ref
,
instruction_ref
>
result
;
auto
names
=
sm
->
get_parameter_names
();
std
::
sort
(
names
.
begin
(),
names
.
end
());
assert
(
names
.
size
()
==
inputs
.
size
());
std
::
transform
(
names
.
begin
(),
names
.
end
(),
inputs
.
begin
(),
std
::
inserter
(
result
,
result
.
end
()),
[
&
](
const
auto
&
name
,
auto
input
)
{
return
std
::
make_pair
(
input
,
sm
->
get_parameter
(
name
));
});
return
result
;
}
// Make sure every input of `ins` has an entry in `map_ins`.
//
// Inputs that are already mapped are left untouched. For each remaining input
// a new parameter is added to `sm`, named "x<k>" where k starts at the number
// of parameters `sm` had on entry and grows by one per parameter added here,
// using the input's shape. The new parameter is recorded in `map_ins`.
static void insert_params(module_ref sm,
                          instruction_ref ins,
                          std::unordered_map<instruction_ref, instruction_ref>& map_ins)
{
    auto next_index = sm->get_parameter_shapes().size();
    for(auto arg : ins->inputs())
    {
        if(not contains(map_ins, arg))
        {
            // TODO: Ensure standard shape
            map_ins[arg] = sm->add_parameter("x" + std::to_string(next_index), arg->get_shape());
            ++next_index;
        }
    }
}
// Copy the single instruction `ins` into `sm`.
//
// Any input of `ins` not yet present in `map_ins` first gets a fresh parameter
// in `sm` (see insert_params); `map_ins` is then handed to add_instructions.
// Returns whatever add_instructions returns for the one-element list.
static auto insert_ins_in_submodule(module_ref sm,
                                    instruction_ref ins,
                                    std::unordered_map<instruction_ref, instruction_ref>& map_ins)
{
    insert_params(sm, ins, map_ins);
    auto inserted = sm->add_instructions({ins}, map_ins);
    return inserted;
}
// Convenience overload: copy `ins` into `sm` starting from an empty
// instruction map, which is discarded once the call returns.
static auto insert_ins_in_submodule(module_ref sm, instruction_ref ins)
{
    std::unordered_map<instruction_ref, instruction_ref> scratch_map;
    return insert_ins_in_submodule(sm, ins, scratch_map);
}
// Copy the instructions of the first module attached to `ins` into `sm`.
//
// Steps:
//  1. Every input of `ins` not yet in `map_ins` gets a parameter in `sm`.
//  2. Each parameter of the attached module is mapped to whatever its
//     corresponding input of `ins` is mapped to, so the copied body refers to
//     instructions of `sm` instead of its own parameters.
//  3. The attached module is added to `sm` through add_instructions.
// Returns the result of that add_instructions call.
static auto insert_module_in_submodule(module_ref sm,
                                       instruction_ref ins,
                                       std::unordered_map<instruction_ref, instruction_ref>& map_ins)
{
    insert_params(sm, ins, map_ins);

    auto* inner = ins->module_inputs().front();
    auto input_to_param = get_ins_param_map(ins->inputs(), inner);
    for(const auto& entry : input_to_param)
    {
        // Take a copy before operator[] may insert a new element into map_ins.
        auto mapped_input       = map_ins.at(entry.first);
        map_ins[entry.second] = mapped_input;
    }
    return sm->add_instructions(inner, map_ins);
}
// Collect the outer instructions that feed the parameters of `sm`.
//
// Walks `map_ins` and keeps the entries whose mapped instruction belongs to
// `sm` and is a "@param" instruction. The kept keys are returned ordered by
// the parameter's name (std::map ordering of the "parameter" field of the
// operator's value).
static std::vector<instruction_ref>
find_inputs(module_ref sm, const std::unordered_map<instruction_ref, instruction_ref>& map_ins)
{
    std::map<std::string, instruction_ref> input_by_param_name;
    for(const auto& entry : map_ins)
    {
        auto mapped = entry.second;
        // Short-circuit keeps name() from being queried on foreign instructions.
        const bool is_param_of_sm = sm->has_instruction(mapped) and mapped->name() == "@param";
        if(not is_param_of_sm)
            continue;
        auto op_value   = mapped->get_operator().to_value();
        auto param_name = op_value.at("parameter").to<std::string>();
        input_by_param_name[param_name] = entry.first;
    }

    std::vector<instruction_ref> ordered_inputs;
    for(const auto& named : input_by_param_name)
        ordered_inputs.push_back(named.second);
    return ordered_inputs;
}
static
void
create_reduce_modules
(
module_pass_manager
&
mpm
)
{
std
::
size_t
n
=
0
;
...
...
@@ -87,10 +161,7 @@ static void create_reduce_modules(module_pass_manager& mpm)
mpm
.
create_module
(
mpm
.
get_module
().
name
()
+
":"
+
ins
->
name
()
+
std
::
to_string
(
n
++
));
rm
->
set_bypass
();
// TODO: Ensure standard shape
auto
x0
=
rm
->
add_parameter
(
"x0"
,
ins
->
inputs
().
front
()
->
get_shape
());
auto
r
=
rm
->
add_instruction
(
ins
->
get_operator
(),
x0
);
rm
->
add_return
({
r
});
rm
->
add_return
(
insert_ins_in_submodule
(
rm
,
ins
));
auto
v
=
ins
->
get_operator
().
to_value
();
mpm
.
get_module
().
replace_instruction
(
...
...
@@ -98,23 +169,6 @@ static void create_reduce_modules(module_pass_manager& mpm)
}
}
static
std
::
unordered_map
<
instruction_ref
,
instruction_ref
>
get_ins_param_map
(
const
std
::
vector
<
instruction_ref
>&
inputs
,
const_module_ref
sm
)
{
std
::
unordered_map
<
instruction_ref
,
instruction_ref
>
result
;
auto
names
=
sm
->
get_parameter_names
();
std
::
sort
(
names
.
begin
(),
names
.
end
());
assert
(
names
.
size
()
==
inputs
.
size
());
std
::
transform
(
names
.
begin
(),
names
.
end
(),
inputs
.
begin
(),
std
::
inserter
(
result
,
result
.
end
()),
[
&
](
const
auto
&
name
,
auto
input
)
{
return
std
::
make_pair
(
input
,
sm
->
get_parameter
(
name
));
});
return
result
;
}
static
std
::
vector
<
instruction_ref
>
get_returns
(
module
&
m
)
{
auto
last
=
std
::
prev
(
m
.
end
());
...
...
@@ -123,6 +177,36 @@ static std::vector<instruction_ref> get_returns(module& m)
return
{
last
};
}
namespace
{
// Matcher/rewriter that folds a "pointwise" producer into the "fused_reduce"
// instruction that consumes it.
struct find_pointwise_reduce
{
    // Matches a "fused_reduce" instruction for which any input is a
    // "pointwise" instruction satisfying used_once(); that input is bound
    // under the name "pointwise".
    auto matcher() const
    {
        return match::name("fused_reduce")(match::any_of[match::inputs()](
            match::name("pointwise")(match::used_once()).bind("pointwise")));
    }

    // Creates a new bypass module named "<pointwise module name>:reduce",
    // copies the pointwise instruction and then the body of the fused_reduce
    // module into it, and replaces the fused_reduce instruction with one that
    // uses the new module and the inputs reported by find_inputs.
    void apply(module_pass_manager& mpm, const match::matcher_result& r) const
    {
        auto reduce_ins    = r.result;
        auto pointwise_ins = r.instructions["pointwise"];

        const auto* pointwise_mod = pointwise_ins->module_inputs().front();
        auto* fused_mod           = mpm.create_module(pointwise_mod->name() + ":reduce");
        fused_mod->set_bypass();

        std::unordered_map<instruction_ref, instruction_ref> map_ins;

        // Insert pointwise, and make later references to it resolve to the copy.
        auto pointwise_copy    = insert_ins_in_submodule(fused_mod, pointwise_ins, map_ins).front();
        map_ins[pointwise_ins] = pointwise_copy;

        // Insert fused_reduce.
        // NOTE(review): the value returned here is discarded and no add_return
        // is issued in this function - confirm the new module's outputs are as
        // intended.
        insert_module_in_submodule(fused_mod, reduce_ins, map_ins);

        auto new_inputs = find_inputs(fused_mod, map_ins);
        mpm.get_module().replace_instruction(
            reduce_ins, reduce_ins->get_operator(), new_inputs, {fused_mod});
    }
};
struct
find_reduce_pointwise
{
auto
matcher
()
const
...
...
@@ -133,43 +217,31 @@ struct find_reduce_pointwise
void
apply
(
module_pass_manager
&
mpm
,
const
match
::
matcher_result
&
r
)
const
{
auto
ins
=
r
.
result
;
auto
pw
=
r
.
result
;
auto
reduce
=
r
.
instructions
[
"reduce"
];
const
auto
*
old_rm
=
reduce
->
module_inputs
().
front
();
auto
*
rm
=
mpm
.
create_module
(
old_rm
->
name
()
+
":pointwise"
);
rm
->
set_bypass
();
std
::
unordered_map
<
instruction_ref
,
instruction_ref
>
map_ins
;
// Copy module instructions
rm
->
add_instructions
(
old_rm
);
auto
map_ins
=
get_ins_param_map
(
reduce
->
inputs
(),
rm
);
auto
new_inputs
=
reduce
->
inputs
();
for
(
auto
input
:
ins
->
inputs
())
{
if
(
contains
(
map_ins
,
input
))
continue
;
if
(
input
==
reduce
)
{
map_ins
[
input
]
=
get_returns
(
*
rm
).
front
();
}
else
{
map_ins
[
input
]
=
rm
->
add_parameter
(
"x"
+
std
::
to_string
(
new_inputs
.
size
()),
input
->
get_shape
());
new_inputs
.
push_back
(
input
);
}
}
insert_module_in_submodule
(
rm
,
reduce
,
map_ins
);
map_ins
[
reduce
]
=
get_returns
(
*
rm
).
front
();
auto
out
=
rm
->
add_instructions
({
ins
},
map_ins
);
rm
->
add_return
(
out
);
mpm
.
get_module
().
replace_instruction
(
ins
,
reduce
->
get_operator
(),
new_inputs
,
{
rm
});
auto
out
=
insert_ins_in_submodule
(
rm
,
pw
,
map_ins
);
rm
->
replace_return
(
out
);
auto
new_inputs
=
find_inputs
(
rm
,
map_ins
);
mpm
.
get_module
().
replace_instruction
(
pw
,
reduce
->
get_operator
(),
new_inputs
,
{
rm
});
}
};
}
// Entry point of the fuse_reduce pass: calls create_reduce_modules, runs
// dead_code_elimination, applies the matcher structs through
// match::find_matches, then runs dead_code_elimination again.
void
fuse_reduce
::
apply
(
module_pass_manager
&
mpm
)
const
{
create_reduce_modules
(
mpm
);
mpm
.
run_pass
(
dead_code_elimination
{});
// NOTE(review): this listing appears to interleave both sides of a diff; the
// single-matcher call directly below is presumably the removed line that the
// two-matcher call after it replaces - confirm against the actual file.
match
::
find_matches
(
mpm
,
find_reduce_pointwise
{});
// Applies both the reduce->pointwise and the pointwise->reduce rewrites.
match
::
find_matches
(
mpm
,
find_reduce_pointwise
{}
,
find_pointwise_reduce
{}
);
mpm
.
run_pass
(
dead_code_elimination
{});
}
...
...
src/targets/gpu/compile_gen.cpp
View file @
94d93226
...
...
@@ -260,13 +260,11 @@ struct reduce_op
}
};
// const std::string& generate_reduce_body = R"__migraphx__(
// )__migraphx__";
std
::
string
generate_reduce
(
const
module
&
rm
,
const
std
::
string
&
name
)
{
module
m
=
rm
;
cpp_generator
g
;
auto
ilens
=
rm
.
get_parameter_shapes
().
begin
()
->
second
.
lens
();
std
::
size_t
i
=
0
;
auto
f
=
g
.
generate_module
(
m
,
[
&
](
instruction_ref
ins
,
const
auto
&
names
)
{
if
(
contains
(
ins
->
name
(),
"reduce"
))
...
...
@@ -278,8 +276,24 @@ std::string generate_reduce(const module& rm, const std::string& name)
auto
pointwise_name
=
"pointwise"
+
std
::
to_string
(
i
);
i
++
;
generate_pointwise
(
g
,
*
ins
->
module_inputs
().
front
(),
pointwise_name
);
return
pointwise_name
+
"("
+
join_strings
(
cpp_generator
::
to_args
(
ins
->
inputs
(),
names
),
", "
)
+
")"
;
std
::
vector
<
instruction_ref
>
tensors
;
std
::
copy_if
(
ins
->
inputs
().
begin
(),
ins
->
inputs
().
end
(),
std
::
back_inserter
(
tensors
),
[
&
](
auto
input
)
{
return
input
->
get_shape
().
lens
()
==
ilens
and
not
input
->
get_shape
().
broadcasted
();
});
auto
inner_names
=
names
;
for
(
auto
input
:
tensors
)
inner_names
[
input
]
+=
"_lambda_param"
;
auto
call_function
=
pointwise_name
+
"("
+
join_strings
(
cpp_generator
::
to_args
(
ins
->
inputs
(),
inner_names
),
", "
)
+
")"
;
if
(
tensors
.
empty
())
return
call_function
;
const
std
::
string
inner_template
=
"r.inner([=](${params}) { return ${call}; })(${args})"
;
auto
args
=
cpp_generator
::
to_args
(
tensors
,
names
);
auto
params
=
cpp_generator
::
to_args
(
tensors
,
inner_names
);
std
::
transform
(
params
.
begin
(),
params
.
end
(),
params
.
begin
(),
[](
auto
s
)
{
return
"auto "
+
s
;
});
return
interpolate_string
(
inner_template
,
{{
"params"
,
join_strings
(
params
,
", "
)},
{
"args"
,
join_strings
(
args
,
", "
)},
{
"call"
,
call_function
}});
}
MIGRAPHX_THROW
(
"Unknown operator: "
+
ins
->
name
());
});
...
...
src/targets/gpu/jit/fused_reduce.cpp
View file @
94d93226
...
...
@@ -37,6 +37,7 @@ using namespace migraphx::gpu::gen; // NOLINT
static
const
char
*
const
simple_reduce_kernel
=
R"__migraphx__(
#include <migraphx/kernels/index.hpp>
#include <migraphx/kernels/reduce.hpp>
#include <migraphx/kernels/pointwise.hpp>
#include <migraphx/kernels/vectorize.hpp>
#include <args.hpp>
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment