Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
MIGraphX
Commits
a2092da6
Commit
a2092da6
authored
Oct 30, 2018
by
Scott Thornton
Browse files
Added optimization pass to remove concat operator when appropriate
parent
ceaf5ee0
Changes
6
Show whitespace changes
Inline
Side-by-side
Showing
6 changed files
with
312 additions
and
1 deletion
+312
-1
src/CMakeLists.txt
src/CMakeLists.txt
+1
-0
src/eliminate_concat.cpp
src/eliminate_concat.cpp
+70
-0
src/include/migraph/eliminate_concat.hpp
src/include/migraph/eliminate_concat.hpp
+21
-0
src/include/migraph/operators.hpp
src/include/migraph/operators.hpp
+9
-1
test/eliminate_concat_test.cpp
test/eliminate_concat_test.cpp
+166
-0
tools/include/concat_opt.hpp
tools/include/concat_opt.hpp
+45
-0
No files found.
src/CMakeLists.txt
View file @
a2092da6
...
@@ -6,6 +6,7 @@ add_library(migraph
...
@@ -6,6 +6,7 @@ add_library(migraph
dead_code_elimination.cpp
dead_code_elimination.cpp
eliminate_allocation.cpp
eliminate_allocation.cpp
eliminate_contiguous.cpp
eliminate_contiguous.cpp
eliminate_concat.cpp
fwd_conv_batchnorm_rewrite.cpp
fwd_conv_batchnorm_rewrite.cpp
env.cpp
env.cpp
generate.cpp
generate.cpp
...
...
src/eliminate_concat.cpp
0 → 100644
View file @
a2092da6
#include <iterator>
#include <migraph/eliminate_concat.hpp>
#include <migraph/program.hpp>
#include <migraph/instruction.hpp>
#include <migraph/operators.hpp>
#include <migraph/iterator_for.hpp>
#include <migraph/dfor.hpp>
namespace
migraph
{
void
eliminate_concat
::
apply
(
program
&
p
)
const
{
for
(
auto
ins
:
iterator_for
(
p
))
{
// Look for the concat operator
if
(
ins
->
name
()
!=
concat_opt
.
name
())
continue
;
// If any inputs are literals then abort
if
(
std
::
any_of
(
ins
->
inputs
().
begin
()
+
1
,
ins
->
inputs
().
end
(),
[](
auto
arg
)
{
return
arg
->
name
()
==
"@literal"
;
}))
continue
;
// We can only do this optimization when concat axis is either the leftmost
// axis OR the sizes to the left of this axis are all equal to 1
// Since we've already checked that the non-axis dimensions are identical
// we only need to check the first input
auto
lens
=
ins
->
inputs
().
front
()
->
get_shape
().
lens
();
auto
concat_op
=
concat_opt
.
get_concat
(
ins
->
get_operator
());
if
(
concat_op
.
axis
==
0
||
std
::
all_of
(
lens
.
begin
(),
lens
.
begin
()
+
concat_op
.
axis
,
[]
(
auto
x
)
{
return
x
==
1
;
}))
{
// Last input should be an allocation
auto
last
=
ins
->
inputs
().
back
();
if
(
last
->
name
()
!=
concat_opt
.
allocate
())
continue
;
// Where are the allocations for the tensors to be concatenated?
std
::
vector
<
instruction_ref
>
allocations
;
for
(
auto
ins2
=
ins
->
inputs
().
begin
();
ins2
!=
ins
->
inputs
().
end
()
-
1
;
ins2
++
)
{
auto
last2
=
(
*
ins2
)
->
inputs
().
back
();
if
(
last2
->
name
()
==
concat_opt
.
allocate
())
{
allocations
.
push_back
(
last2
);
}
}
// Need to sort the allocations, so that we know where to
// insert the "super"-allocation
std
::
sort
(
allocations
.
begin
(),
allocations
.
end
(),
[
&
]
(
instruction_ref
x
,
instruction_ref
y
)
{
return
std
::
distance
(
p
.
begin
(),
x
)
<
std
::
distance
(
p
.
begin
(),
y
);
});
// Move "super" allocation to the front
auto
first
=
allocations
.
front
();
auto
super
=
p
.
move_instruction
(
last
,
first
);
std
::
size_t
offset
=
0
;
for
(
auto
x
:
allocations
)
{
migraph
::
op
::
load
op
{
x
->
get_shape
(),
offset
};
p
.
replace_instruction
(
x
,
op
,
{
super
});
offset
+=
x
->
get_shape
().
elements
();
}
std
::
vector
<
instruction_ref
>
args
=
{
super
};
std
::
copy
(
ins
->
inputs
().
begin
(),
ins
->
inputs
().
end
()
-
1
,
std
::
back_inserter
(
args
));
p
.
replace_instruction
(
ins
,
migraph
::
op
::
identity
{},
args
);
}
}
}
}
// namespace migraph
src/include/migraph/eliminate_concat.hpp
0 → 100644
View file @
a2092da6
#ifndef MIGRAPH_GUARD_RTGLIB_ELIMINATE_CONCAT_HPP
#define MIGRAPH_GUARD_RTGLIB_ELIMINATE_CONCAT_HPP
#include <string>
#include <migraph/instruction_ref.hpp>
#include <migraph/concat_opt.hpp>
namespace
migraph
{
struct
program
;
struct
eliminate_concat
{
concat_optimization
concat_opt
;
std
::
string
name
()
const
{
return
"eliminate_concat"
;
}
void
apply
(
program
&
p
)
const
;
};
}
// namespace migraph
#endif
src/include/migraph/operators.hpp
View file @
a2092da6
...
@@ -617,9 +617,17 @@ struct unary
...
@@ -617,9 +617,17 @@ struct unary
}
}
};
};
struct
identity
:
unary
struct
identity
{
{
std
::
string
name
()
const
{
return
"identity"
;
}
std
::
string
name
()
const
{
return
"identity"
;
}
shape
compute_shape
(
std
::
vector
<
shape
>
inputs
)
const
{
return
inputs
.
at
(
0
);
}
argument
compute
(
context
&
,
shape
output_shape
,
std
::
vector
<
argument
>
args
)
const
{
return
{
std
::
move
(
output_shape
),
std
::
move
(
args
.
at
(
0
).
data
)};
}
};
};
struct
abs
:
unary
struct
abs
:
unary
...
...
test/eliminate_concat_test.cpp
0 → 100644
View file @
a2092da6
#include <migraph/eliminate_concat.hpp>
#include <migraph/dead_code_elimination.hpp>
#include <migraph/operators.hpp>
#include <basic_ops.hpp>
#include <test.hpp>
struct
concat
{
concat
(
std
::
size_t
axis
)
{
op
.
axis
=
axis
;
}
migraph
::
op
::
concat
op
;
std
::
string
name
()
const
{
return
"eliminate_concat::concat"
;
}
migraph
::
shape
compute_shape
(
std
::
vector
<
migraph
::
shape
>
inputs
)
const
{
return
op
.
compute_shape
(
inputs
);
}
migraph
::
argument
compute
(
migraph
::
context
&
ctx
,
const
migraph
::
shape
&
output_shape
,
const
std
::
vector
<
migraph
::
argument
>&
args
)
const
{
return
{
output_shape
};
}
};
struct
concat_test_optimization
{
/// A unique name used to identify the concat optimization
std
::
string
name
()
const
{
return
"eliminate_concat::concat"
;
}
/// A unique name used to identify the allocate operator
std
::
string
allocate
()
const
{
return
"allocate"
;
}
/// Return the lowered concat operator
migraph
::
op
::
concat
get_concat
(
const
migraph
::
operation
&
op
)
const
{
return
migraph
::
any_cast
<
concat
>
(
op
).
op
;
}
};
struct
eliminate_concat_target
{
std
::
size_t
align
=
32
;
std
::
string
name
()
const
{
return
"eliminate_target"
;
}
std
::
vector
<
migraph
::
pass
>
get_passes
(
migraph
::
context
&
)
const
{
return
{
migraph
::
eliminate_concat
{
concat_test_optimization
{}},
migraph
::
dead_code_elimination
{}};
}
migraph
::
context
get_context
()
const
{
return
{};
}
};
struct
allocate
{
migraph
::
shape
s
{};
std
::
string
name
()
const
{
return
"allocate"
;
}
migraph
::
shape
compute_shape
(
const
std
::
vector
<
migraph
::
shape
>&
inputs
)
const
{
migraph
::
check_shapes
{
inputs
}.
has
(
0
);
return
s
;
}
migraph
::
argument
compute
(
migraph
::
context
&
,
const
migraph
::
shape
&
output_shape
,
const
std
::
vector
<
migraph
::
argument
>&
)
const
{
return
{
output_shape
};
}
};
struct
fred_op
{
std
::
string
name
()
const
{
return
"fred_op"
;
}
migraph
::
shape
compute_shape
(
const
std
::
vector
<
migraph
::
shape
>&
inputs
)
const
{
migraph
::
check_shapes
{
inputs
}.
has
(
1
);
return
inputs
.
at
(
0
);
}
migraph
::
argument
compute
(
migraph
::
context
&
,
const
migraph
::
shape
&
output_shape
,
const
std
::
vector
<
migraph
::
argument
>&
args
)
const
{
return
args
.
at
(
0
);
}
};
void
basic
()
{
auto
create_test_program
=
[]()
{
migraph
::
program
p
;
auto
a1
=
p
.
add_instruction
(
allocate
{
migraph
::
shape
{
migraph
::
shape
::
float_type
,
{
1
,
2
,
8
,
8
}}});
auto
p1
=
p
.
add_instruction
(
fred_op
{},
a1
);
auto
a2
=
p
.
add_instruction
(
allocate
{
migraph
::
shape
{
migraph
::
shape
::
float_type
,
{
1
,
3
,
8
,
8
}}});
auto
p2
=
p
.
add_instruction
(
fred_op
{},
a2
);
auto
a3
=
p
.
add_instruction
(
allocate
{
migraph
::
shape
{
migraph
::
shape
::
float_type
,
{
1
,
5
,
8
,
8
}}});
auto
p3
=
p
.
add_instruction
(
fred_op
{},
a3
);
std
::
size_t
axis
=
1
;
auto
a4
=
p
.
add_instruction
(
allocate
{
migraph
::
shape
{
migraph
::
shape
::
float_type
,
{
1
,
10
,
8
,
8
}}});
auto
p4
=
p
.
add_instruction
(
concat
(
axis
),
p1
,
p2
,
p3
,
a4
);
return
p
;
};
auto
create_control_program
=
[]()
{
migraph
::
program
p
;
auto
a1
=
p
.
add_instruction
(
allocate
{
migraph
::
shape
{
migraph
::
shape
::
float_type
,
{
1
,
10
,
8
,
8
}}});
auto
l1
=
p
.
add_instruction
(
migraph
::
op
::
load
{
migraph
::
shape
{
migraph
::
shape
::
float_type
,
{
1
,
2
,
8
,
8
}},
0
},
{
a1
});
auto
p1
=
p
.
add_instruction
(
fred_op
{},
l1
);
auto
l2
=
p
.
add_instruction
(
migraph
::
op
::
load
{
migraph
::
shape
{
migraph
::
shape
::
float_type
,
{
1
,
3
,
8
,
8
}},
128
},
{
a1
});
auto
p2
=
p
.
add_instruction
(
fred_op
{},
l2
);
auto
l3
=
p
.
add_instruction
(
migraph
::
op
::
load
{
migraph
::
shape
{
migraph
::
shape
::
float_type
,
{
1
,
5
,
8
,
8
}},
320
},
{
a1
});
auto
p3
=
p
.
add_instruction
(
fred_op
{},
l3
);
auto
i1
=
p
.
add_instruction
(
migraph
::
op
::
identity
{},
{
a1
,
p1
,
p2
,
p3
});
return
p
;
};
auto
p1
=
create_test_program
();
auto
p2
=
create_control_program
();
p1
.
compile
(
eliminate_concat_target
{});
EXPECT
(
p1
==
p2
);
}
void
wont_work
()
{
auto
create_test_program
=
[]()
{
migraph
::
program
p
;
auto
a1
=
p
.
add_instruction
(
allocate
{
migraph
::
shape
{
migraph
::
shape
::
float_type
,
{
2
,
2
,
8
,
8
}}});
auto
p1
=
p
.
add_instruction
(
fred_op
{},
a1
);
auto
a2
=
p
.
add_instruction
(
allocate
{
migraph
::
shape
{
migraph
::
shape
::
float_type
,
{
2
,
3
,
8
,
8
}}});
auto
p2
=
p
.
add_instruction
(
fred_op
{},
a2
);
auto
a3
=
p
.
add_instruction
(
allocate
{
migraph
::
shape
{
migraph
::
shape
::
float_type
,
{
2
,
5
,
8
,
8
}}});
auto
p3
=
p
.
add_instruction
(
fred_op
{},
a3
);
std
::
size_t
axis
=
1
;
auto
a4
=
p
.
add_instruction
(
allocate
{
migraph
::
shape
{
migraph
::
shape
::
float_type
,
{
2
,
10
,
8
,
8
}}});
auto
p4
=
p
.
add_instruction
(
concat
(
axis
),
p1
,
p2
,
p3
,
a4
);
return
p
;
};
auto
create_control_program
=
[]()
{
migraph
::
program
p
;
auto
a1
=
p
.
add_instruction
(
allocate
{
migraph
::
shape
{
migraph
::
shape
::
float_type
,
{
2
,
2
,
8
,
8
}}});
auto
p1
=
p
.
add_instruction
(
fred_op
{},
a1
);
auto
a2
=
p
.
add_instruction
(
allocate
{
migraph
::
shape
{
migraph
::
shape
::
float_type
,
{
2
,
3
,
8
,
8
}}});
auto
p2
=
p
.
add_instruction
(
fred_op
{},
a2
);
auto
a3
=
p
.
add_instruction
(
allocate
{
migraph
::
shape
{
migraph
::
shape
::
float_type
,
{
2
,
5
,
8
,
8
}}});
auto
p3
=
p
.
add_instruction
(
fred_op
{},
a3
);
std
::
size_t
axis
=
1
;
auto
a4
=
p
.
add_instruction
(
allocate
{
migraph
::
shape
{
migraph
::
shape
::
float_type
,
{
2
,
10
,
8
,
8
}}});
auto
p4
=
p
.
add_instruction
(
concat
(
axis
),
p1
,
p2
,
p3
,
a4
);
return
p
;
};
auto
p1
=
create_test_program
();
auto
p2
=
create_control_program
();
p1
.
compile
(
eliminate_concat_target
{});
EXPECT
(
p1
==
p2
);
}
int
main
()
{
setenv
(
"MIGRAPH_DISABLE_MEMORY_COLORING"
,
"1"
,
1
);
basic
();
wont_work
();
}
tools/include/concat_opt.hpp
0 → 100644
View file @
a2092da6
#ifndef MIGRAPH_GUARD_CONCAT_OPT_HPP
#define MIGRAPH_GUARD_CONCAT_OPT_HPP
#include <cassert>
#include <string>
#include <functional>
#include <memory>
#include <type_traits>
#include <utility>
#include <migraph/operation.hpp>
#include <migraph/operators.hpp>
namespace
migraph
{
struct
program
;
#ifdef DOXYGEN
/// An interface for applying an optimization for the concat instruction
struct
concat_optimization
{
/// A unique name used to identify the concat optimization
std
::
string
name
()
const
;
/// A unique name used to identify the allocate operator
std
::
string
allocate
()
const
;
/// Return the lowered concat operator
op
::
concat
get_concat
(
const
operation
&
op
)
const
;
};
#else
<%
interface
(
'
concat_optimization
'
,
virtual
(
'
name
'
,
returns
=
'
std
::
string
'
,
const
=
True
),
virtual
(
'
allocate
'
,
returns
=
'
std
::
string
'
,
const
=
True
),
virtual
(
'
get_concat
'
,
returns
=
'
op
::
concat
'
,
op
=
'
const
operation
&
'
,
const
=
True
)
)
%>
#endif
}
// namespace migraph
#endif
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment