Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
gaoqiong
MIGraphX
Commits
94d93226
Commit
94d93226
authored
Jan 25, 2023
by
Paul
Browse files
Fuse pointwise before as well
parent
50b471d5
Changes
3
Hide whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
136 additions
and
49 deletions
+136
-49
src/fuse_reduce.cpp
src/fuse_reduce.cpp
+116
-44
src/targets/gpu/compile_gen.cpp
src/targets/gpu/compile_gen.cpp
+19
-5
src/targets/gpu/jit/fused_reduce.cpp
src/targets/gpu/jit/fused_reduce.cpp
+1
-0
No files found.
src/fuse_reduce.cpp
View file @
94d93226
...
@@ -33,6 +33,7 @@
...
@@ -33,6 +33,7 @@
#include <migraphx/matcher.hpp>
#include <migraphx/matcher.hpp>
#include <migraphx/register_op.hpp>
#include <migraphx/register_op.hpp>
#include <iterator>
#include <iterator>
#include <map>
namespace
migraphx
{
namespace
migraphx
{
inline
namespace
MIGRAPHX_INLINE_NS
{
inline
namespace
MIGRAPHX_INLINE_NS
{
...
@@ -73,6 +74,79 @@ struct fused_reduce
...
@@ -73,6 +74,79 @@ struct fused_reduce
};
};
MIGRAPHX_REGISTER_OP
(
fused_reduce
);
MIGRAPHX_REGISTER_OP
(
fused_reduce
);
static
std
::
unordered_map
<
instruction_ref
,
instruction_ref
>
get_ins_param_map
(
const
std
::
vector
<
instruction_ref
>&
inputs
,
const_module_ref
sm
)
{
std
::
unordered_map
<
instruction_ref
,
instruction_ref
>
result
;
auto
names
=
sm
->
get_parameter_names
();
std
::
sort
(
names
.
begin
(),
names
.
end
());
assert
(
names
.
size
()
==
inputs
.
size
());
std
::
transform
(
names
.
begin
(),
names
.
end
(),
inputs
.
begin
(),
std
::
inserter
(
result
,
result
.
end
()),
[
&
](
const
auto
&
name
,
auto
input
)
{
return
std
::
make_pair
(
input
,
sm
->
get_parameter
(
name
));
});
return
result
;
}
/// Add one parameter to `sm` for every input of `ins` that has no entry in
/// `map_ins` yet, and record the new parameter in `map_ins`.
///
/// New parameters are named "x<k>", where k starts at the number of parameter
/// shapes `sm` reports on entry and grows by one per parameter added here.
///
/// @param sm      Submodule that receives the new parameters.
/// @param ins     Instruction whose inputs are examined.
/// @param map_ins Outer instruction -> submodule instruction map; extended in place.
static void insert_params(module_ref sm,
                          instruction_ref ins,
                          std::unordered_map<instruction_ref, instruction_ref>& map_ins)
{
    auto next_index = sm->get_parameter_shapes().size();
    for(auto input : ins->inputs())
    {
        // Inputs that are already mapped keep their existing mapping.
        if(not contains(map_ins, input))
        {
            // TODO: Ensure standard shape
            const auto param_name = "x" + std::to_string(next_index);
            ++next_index;
            map_ins[input] = sm->add_parameter(param_name, input->get_shape());
        }
    }
}
/// Copy the single instruction `ins` into submodule `sm`.
///
/// Parameters are first created (via insert_params) for any input of `ins`
/// that is not yet present in `map_ins`; `ins` is then added to `sm` using
/// `map_ins` to translate its inputs.
///
/// @param sm      Destination submodule.
/// @param ins     Instruction to copy.
/// @param map_ins Outer instruction -> submodule instruction map; extended with
///                any newly created parameters.
/// @return Whatever `module::add_instructions` returns for the inserted instruction.
static auto insert_ins_in_submodule(module_ref sm,
                                    instruction_ref ins,
                                    std::unordered_map<instruction_ref, instruction_ref>& map_ins)
{
    insert_params(sm, ins, map_ins);
    auto inserted = sm->add_instructions({ins}, map_ins);
    return inserted;
}
/// Convenience overload: copy `ins` into `sm` starting from an empty
/// instruction map, so every input of `ins` becomes a new parameter of `sm`.
/// The map is local to this call and is discarded afterwards.
///
/// @param sm  Destination submodule.
/// @param ins Instruction to copy.
/// @return The result of the three-argument overload.
static auto insert_ins_in_submodule(module_ref sm, instruction_ref ins)
{
    std::unordered_map<instruction_ref, instruction_ref> fresh_map;
    return insert_ins_in_submodule(sm, ins, fresh_map);
}
/// Copy the body of the first module input of `ins` into submodule `sm`.
///
/// Steps:
///  1. Create parameters in `sm` for inputs of `ins` missing from `map_ins`.
///  2. For each (outer input, inner parameter) pair of the source module, make
///     the inner parameter map to whatever the outer input maps to.
///  3. Add the source module's instructions to `sm` using `map_ins`.
///
/// @param sm      Destination submodule.
/// @param ins     Instruction whose first module input is copied.
/// @param map_ins Instruction map; extended in place with parameter entries.
/// @return Whatever `module::add_instructions` returns for the copied module.
static auto
insert_module_in_submodule(module_ref sm,
                           instruction_ref ins,
                           std::unordered_map<instruction_ref, instruction_ref>& map_ins)
{
    insert_params(sm, ins, map_ins);

    auto* source_module      = ins->module_inputs().front();
    const auto outer_to_param = get_ins_param_map(ins->inputs(), source_module);
    for(const auto& entry : outer_to_param)
    {
        // entry.first is the outer input, entry.second the source module's
        // parameter. at() throws if the outer input is not mapped.
        auto mapped            = map_ins.at(entry.first);
        map_ins[entry.second] = mapped;
    }
    return sm->add_instructions(source_module, map_ins);
}
/// Collect the outer instructions that feed the parameters of submodule `sm`.
///
/// Every entry of `map_ins` whose mapped instruction belongs to `sm` and is an
/// "@param" instruction contributes its key. The keys are returned ordered by
/// the parameter's name (string ordering of the "parameter" field of the
/// operator's value).
///
/// @param sm      Submodule whose parameters are being matched.
/// @param map_ins Outer instruction -> submodule instruction map.
/// @return Outer instructions sorted by their parameter name.
static std::vector<instruction_ref>
find_inputs(module_ref sm, const std::unordered_map<instruction_ref, instruction_ref>& map_ins)
{
    // std::map keeps the entries sorted by parameter name.
    std::map<std::string, instruction_ref> input_by_param_name;
    for(const auto& entry : map_ins)
    {
        const auto& outer = entry.first;
        const auto& inner = entry.second;
        if(sm->has_instruction(inner) and inner->name() == "@param")
        {
            auto op_value   = inner->get_operator().to_value();
            auto param_name = op_value.at("parameter").to<std::string>();
            input_by_param_name[param_name] = outer;
        }
    }

    std::vector<instruction_ref> ordered_inputs;
    for(const auto& named : input_by_param_name)
        ordered_inputs.push_back(named.second);
    return ordered_inputs;
}
static
void
create_reduce_modules
(
module_pass_manager
&
mpm
)
static
void
create_reduce_modules
(
module_pass_manager
&
mpm
)
{
{
std
::
size_t
n
=
0
;
std
::
size_t
n
=
0
;
...
@@ -87,10 +161,7 @@ static void create_reduce_modules(module_pass_manager& mpm)
...
@@ -87,10 +161,7 @@ static void create_reduce_modules(module_pass_manager& mpm)
mpm
.
create_module
(
mpm
.
get_module
().
name
()
+
":"
+
ins
->
name
()
+
std
::
to_string
(
n
++
));
mpm
.
create_module
(
mpm
.
get_module
().
name
()
+
":"
+
ins
->
name
()
+
std
::
to_string
(
n
++
));
rm
->
set_bypass
();
rm
->
set_bypass
();
// TODO: Ensure standard shape
rm
->
add_return
(
insert_ins_in_submodule
(
rm
,
ins
));
auto
x0
=
rm
->
add_parameter
(
"x0"
,
ins
->
inputs
().
front
()
->
get_shape
());
auto
r
=
rm
->
add_instruction
(
ins
->
get_operator
(),
x0
);
rm
->
add_return
({
r
});
auto
v
=
ins
->
get_operator
().
to_value
();
auto
v
=
ins
->
get_operator
().
to_value
();
mpm
.
get_module
().
replace_instruction
(
mpm
.
get_module
().
replace_instruction
(
...
@@ -98,23 +169,6 @@ static void create_reduce_modules(module_pass_manager& mpm)
...
@@ -98,23 +169,6 @@ static void create_reduce_modules(module_pass_manager& mpm)
}
}
}
}
static
std
::
unordered_map
<
instruction_ref
,
instruction_ref
>
get_ins_param_map
(
const
std
::
vector
<
instruction_ref
>&
inputs
,
const_module_ref
sm
)
{
std
::
unordered_map
<
instruction_ref
,
instruction_ref
>
result
;
auto
names
=
sm
->
get_parameter_names
();
std
::
sort
(
names
.
begin
(),
names
.
end
());
assert
(
names
.
size
()
==
inputs
.
size
());
std
::
transform
(
names
.
begin
(),
names
.
end
(),
inputs
.
begin
(),
std
::
inserter
(
result
,
result
.
end
()),
[
&
](
const
auto
&
name
,
auto
input
)
{
return
std
::
make_pair
(
input
,
sm
->
get_parameter
(
name
));
});
return
result
;
}
static
std
::
vector
<
instruction_ref
>
get_returns
(
module
&
m
)
static
std
::
vector
<
instruction_ref
>
get_returns
(
module
&
m
)
{
{
auto
last
=
std
::
prev
(
m
.
end
());
auto
last
=
std
::
prev
(
m
.
end
());
...
@@ -123,6 +177,36 @@ static std::vector<instruction_ref> get_returns(module& m)
...
@@ -123,6 +177,36 @@ static std::vector<instruction_ref> get_returns(module& m)
return
{
last
};
return
{
last
};
}
}
namespace
{
/// Matcher/rewriter that folds a "pointwise" instruction into the
/// "fused_reduce" instruction that consumes it.
struct find_pointwise_reduce
{
    /// Matches a "fused_reduce" instruction that has at least one input which
    /// is a "pointwise" instruction used exactly once; that input is bound
    /// under the name "pointwise".
    auto matcher() const
    {
        auto single_use_pointwise =
            match::name("pointwise")(match::used_once()).bind("pointwise");
        return match::name("fused_reduce")(match::any_of[match::inputs()](single_use_pointwise));
    }

    /// Builds a new bypass submodule named "<pointwise module name>:reduce"
    /// that contains the pointwise instruction followed by the body of the
    /// fused_reduce's module, then replaces the fused_reduce instruction with
    /// one that uses the new submodule and the recomputed input list.
    ///
    /// @param mpm Pass manager used to create the module and edit the graph.
    /// @param r   Match result; `r.result` is the fused_reduce instruction.
    void apply(module_pass_manager& mpm, const match::matcher_result& r) const
    {
        auto reduce_ins    = r.result;
        auto pointwise_ins = r.instructions["pointwise"];

        const auto* pointwise_mod = pointwise_ins->module_inputs().front();
        auto* fused_mod           = mpm.create_module(pointwise_mod->name() + ":reduce");
        fused_mod->set_bypass();

        std::unordered_map<instruction_ref, instruction_ref> map_ins;

        // Copy the pointwise instruction first and register its copy so the
        // reduce body's reference to it resolves inside the new module.
        auto pointwise_copy    = insert_ins_in_submodule(fused_mod, pointwise_ins, map_ins).front();
        map_ins[pointwise_ins] = pointwise_copy;

        // Then copy the body of the fused_reduce module.
        insert_module_in_submodule(fused_mod, reduce_ins, map_ins);

        auto fused_inputs = find_inputs(fused_mod, map_ins);
        mpm.get_module().replace_instruction(
            reduce_ins, reduce_ins->get_operator(), fused_inputs, {fused_mod});
    }
};
struct
find_reduce_pointwise
struct
find_reduce_pointwise
{
{
auto
matcher
()
const
auto
matcher
()
const
...
@@ -133,43 +217,31 @@ struct find_reduce_pointwise
...
@@ -133,43 +217,31 @@ struct find_reduce_pointwise
void
apply
(
module_pass_manager
&
mpm
,
const
match
::
matcher_result
&
r
)
const
void
apply
(
module_pass_manager
&
mpm
,
const
match
::
matcher_result
&
r
)
const
{
{
auto
ins
=
r
.
result
;
auto
pw
=
r
.
result
;
auto
reduce
=
r
.
instructions
[
"reduce"
];
auto
reduce
=
r
.
instructions
[
"reduce"
];
const
auto
*
old_rm
=
reduce
->
module_inputs
().
front
();
const
auto
*
old_rm
=
reduce
->
module_inputs
().
front
();
auto
*
rm
=
mpm
.
create_module
(
old_rm
->
name
()
+
":pointwise"
);
auto
*
rm
=
mpm
.
create_module
(
old_rm
->
name
()
+
":pointwise"
);
rm
->
set_bypass
();
rm
->
set_bypass
();
std
::
unordered_map
<
instruction_ref
,
instruction_ref
>
map_ins
;
// Copy module instructions
// Copy module instructions
rm
->
add_instructions
(
old_rm
);
insert_module_in_submodule
(
rm
,
reduce
,
map_ins
);
auto
map_ins
=
get_ins_param_map
(
reduce
->
inputs
(),
rm
);
map_ins
[
reduce
]
=
get_returns
(
*
rm
).
front
();
auto
new_inputs
=
reduce
->
inputs
();
for
(
auto
input
:
ins
->
inputs
())
{
if
(
contains
(
map_ins
,
input
))
continue
;
if
(
input
==
reduce
)
{
map_ins
[
input
]
=
get_returns
(
*
rm
).
front
();
}
else
{
map_ins
[
input
]
=
rm
->
add_parameter
(
"x"
+
std
::
to_string
(
new_inputs
.
size
()),
input
->
get_shape
());
new_inputs
.
push_back
(
input
);
}
}
auto
out
=
rm
->
add_instructions
({
ins
},
map_ins
);
auto
out
=
insert_ins_in_submodule
(
rm
,
pw
,
map_ins
);
rm
->
add_return
(
out
);
rm
->
replace_return
(
out
);
mpm
.
get_module
().
replace_instruction
(
ins
,
reduce
->
get_operator
(),
new_inputs
,
{
rm
});
auto
new_inputs
=
find_inputs
(
rm
,
map_ins
);
mpm
.
get_module
().
replace_instruction
(
pw
,
reduce
->
get_operator
(),
new_inputs
,
{
rm
});
}
}
};
};
}
void
fuse_reduce
::
apply
(
module_pass_manager
&
mpm
)
const
void
fuse_reduce
::
apply
(
module_pass_manager
&
mpm
)
const
{
{
create_reduce_modules
(
mpm
);
create_reduce_modules
(
mpm
);
mpm
.
run_pass
(
dead_code_elimination
{});
mpm
.
run_pass
(
dead_code_elimination
{});
match
::
find_matches
(
mpm
,
find_reduce_pointwise
{});
match
::
find_matches
(
mpm
,
find_reduce_pointwise
{}
,
find_pointwise_reduce
{}
);
mpm
.
run_pass
(
dead_code_elimination
{});
mpm
.
run_pass
(
dead_code_elimination
{});
}
}
...
...
src/targets/gpu/compile_gen.cpp
View file @
94d93226
...
@@ -260,13 +260,11 @@ struct reduce_op
...
@@ -260,13 +260,11 @@ struct reduce_op
}
}
};
};
// const std::string& generate_reduce_body = R"__migraphx__(
// )__migraphx__";
std
::
string
generate_reduce
(
const
module
&
rm
,
const
std
::
string
&
name
)
std
::
string
generate_reduce
(
const
module
&
rm
,
const
std
::
string
&
name
)
{
{
module
m
=
rm
;
module
m
=
rm
;
cpp_generator
g
;
cpp_generator
g
;
auto
ilens
=
rm
.
get_parameter_shapes
().
begin
()
->
second
.
lens
();
std
::
size_t
i
=
0
;
std
::
size_t
i
=
0
;
auto
f
=
g
.
generate_module
(
m
,
[
&
](
instruction_ref
ins
,
const
auto
&
names
)
{
auto
f
=
g
.
generate_module
(
m
,
[
&
](
instruction_ref
ins
,
const
auto
&
names
)
{
if
(
contains
(
ins
->
name
(),
"reduce"
))
if
(
contains
(
ins
->
name
(),
"reduce"
))
...
@@ -278,8 +276,24 @@ std::string generate_reduce(const module& rm, const std::string& name)
...
@@ -278,8 +276,24 @@ std::string generate_reduce(const module& rm, const std::string& name)
auto
pointwise_name
=
"pointwise"
+
std
::
to_string
(
i
);
auto
pointwise_name
=
"pointwise"
+
std
::
to_string
(
i
);
i
++
;
i
++
;
generate_pointwise
(
g
,
*
ins
->
module_inputs
().
front
(),
pointwise_name
);
generate_pointwise
(
g
,
*
ins
->
module_inputs
().
front
(),
pointwise_name
);
return
pointwise_name
+
"("
+
std
::
vector
<
instruction_ref
>
tensors
;
join_strings
(
cpp_generator
::
to_args
(
ins
->
inputs
(),
names
),
", "
)
+
")"
;
std
::
copy_if
(
ins
->
inputs
().
begin
(),
ins
->
inputs
().
end
(),
std
::
back_inserter
(
tensors
),
[
&
](
auto
input
)
{
return
input
->
get_shape
().
lens
()
==
ilens
and
not
input
->
get_shape
().
broadcasted
();
});
auto
inner_names
=
names
;
for
(
auto
input
:
tensors
)
inner_names
[
input
]
+=
"_lambda_param"
;
auto
call_function
=
pointwise_name
+
"("
+
join_strings
(
cpp_generator
::
to_args
(
ins
->
inputs
(),
inner_names
),
", "
)
+
")"
;
if
(
tensors
.
empty
())
return
call_function
;
const
std
::
string
inner_template
=
"r.inner([=](${params}) { return ${call}; })(${args})"
;
auto
args
=
cpp_generator
::
to_args
(
tensors
,
names
);
auto
params
=
cpp_generator
::
to_args
(
tensors
,
inner_names
);
std
::
transform
(
params
.
begin
(),
params
.
end
(),
params
.
begin
(),
[](
auto
s
)
{
return
"auto "
+
s
;
});
return
interpolate_string
(
inner_template
,
{{
"params"
,
join_strings
(
params
,
", "
)},
{
"args"
,
join_strings
(
args
,
", "
)},
{
"call"
,
call_function
}});
}
}
MIGRAPHX_THROW
(
"Unknown operator: "
+
ins
->
name
());
MIGRAPHX_THROW
(
"Unknown operator: "
+
ins
->
name
());
});
});
...
...
src/targets/gpu/jit/fused_reduce.cpp
View file @
94d93226
...
@@ -37,6 +37,7 @@ using namespace migraphx::gpu::gen; // NOLINT
...
@@ -37,6 +37,7 @@ using namespace migraphx::gpu::gen; // NOLINT
static
const
char
*
const
simple_reduce_kernel
=
R"__migraphx__(
static
const
char
*
const
simple_reduce_kernel
=
R"__migraphx__(
#include <migraphx/kernels/index.hpp>
#include <migraphx/kernels/index.hpp>
#include <migraphx/kernels/reduce.hpp>
#include <migraphx/kernels/reduce.hpp>
#include <migraphx/kernels/pointwise.hpp>
#include <migraphx/kernels/vectorize.hpp>
#include <migraphx/kernels/vectorize.hpp>
#include <args.hpp>
#include <args.hpp>
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment