Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
MIGraphX
Commits
3d9bef13
Commit
3d9bef13
authored
Sep 16, 2018
by
Paul
Browse files
Move functions to cpp file
parent
b263425a
Changes
3
Hide whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
191 additions
and
135 deletions
+191
-135
src/CMakeLists.txt
src/CMakeLists.txt
+1
-0
src/include/migraph/instruction.hpp
src/include/migraph/instruction.hpp
+29
-135
src/instruction.cpp
src/instruction.cpp
+161
-0
No files found.
src/CMakeLists.txt
View file @
3d9bef13
...
...
@@ -7,6 +7,7 @@ add_library(migraph
fwd_conv_batchnorm_rewrite.cpp
env.cpp
generate.cpp
instruction.cpp
program.cpp
shape.cpp
simplify_reshapes.cpp
...
...
src/include/migraph/instruction.hpp
View file @
3d9bef13
...
...
@@ -3,10 +3,8 @@
#include <migraph/literal.hpp>
#include <migraph/shape.hpp>
#include <migraph/builtin.hpp>
#include <migraph/instruction_ref.hpp>
#include <migraph/operation.hpp>
#include <migraph/erase.hpp>
#include <string>
#include <utility>
...
...
@@ -18,156 +16,61 @@ struct instruction
{
instruction
()
{}
instruction
(
operation
o
,
shape
r
,
std
::
vector
<
instruction_ref
>
args
)
:
op
(
std
::
move
(
o
)),
result
(
std
::
move
(
r
)),
arguments
(
std
::
move
(
args
))
{
}
instruction
(
literal
l
)
:
op
(
builtin
::
literal
{}),
result
(
l
.
get_shape
()),
lit
(
std
::
move
(
l
))
{}
instruction
(
operation
o
,
shape
r
,
std
::
vector
<
instruction_ref
>
args
);
void
replace
(
const
shape
&
r
)
{
if
(
r
!=
result
)
{
result
=
r
;
for
(
auto
&&
ins
:
output
)
{
assert
(
ins
->
name
().
front
()
!=
'@'
);
ins
->
recompute_shape
();
}
}
}
instruction
(
literal
l
);
void
re
compute_shape
()
{
replace
(
compute_shape
(
op
,
arguments
));
}
void
re
place
(
const
shape
&
r
);
void
clear_arguments
()
{
for
(
auto
&&
arg
:
arguments
)
{
arg
->
remove_output
(
*
this
);
}
arguments
.
clear
();
}
friend
bool
operator
==
(
const
instruction
&
i
,
instruction_ref
ref
)
{
return
std
::
addressof
(
i
)
==
std
::
addressof
(
*
ref
);
}
void
recompute_shape
();
bool
valid
(
instruction_ref
start
)
const
{
return
valid
()
&&
std
::
all_of
(
arguments
.
begin
(),
arguments
.
end
(),
[
&
](
instruction_ref
i
)
{
auto
self
=
std
::
find
(
i
->
outputs
().
begin
(),
i
->
outputs
().
end
(),
*
this
);
return
self
!=
i
->
outputs
().
end
()
&&
std
::
distance
(
start
,
i
)
<
std
::
distance
(
start
,
*
self
);
});
}
void
clear_arguments
();
friend
bool
operator
==
(
const
instruction
&
i
,
instruction_ref
ref
);
bool
valid
()
const
{
shape
computed
;
if
(
op
.
name
()
==
"@literal"
)
{
computed
=
lit
.
get_shape
();
}
else
if
(
op
.
name
()
==
"@param"
)
{
computed
=
result
;
}
else
{
try
{
computed
=
compute_shape
(
op
,
arguments
);
}
catch
(
migraph
::
exception
&
)
{
return
false
;
}
}
return
result
==
computed
&&
std
::
all_of
(
output
.
begin
(),
output
.
end
(),
[
&
](
instruction_ref
i
)
{
return
std
::
find
(
i
->
inputs
().
begin
(),
i
->
inputs
().
end
(),
*
this
)
!=
i
->
inputs
().
end
();
});
}
bool
valid
(
instruction_ref
start
)
const
;
shape
get_shape
()
const
{
return
result
;
}
const
literal
&
get_literal
()
const
{
assert
(
op
.
name
()
==
"@literal"
);
return
lit
;
}
bool
valid
()
const
;
shape
get_shape
()
const
;
const
literal
&
get_literal
()
const
;
const
operation
&
get_operator
()
const
{
return
op
;
}
const
operation
&
get_operator
()
const
;
std
::
string
name
()
const
{
return
op
.
name
();
}
std
::
string
name
()
const
;
const
std
::
vector
<
instruction_ref
>&
inputs
()
const
{
return
arguments
;
}
const
std
::
vector
<
instruction_ref
>&
inputs
()
const
;
const
std
::
vector
<
instruction_ref
>&
outputs
()
const
{
return
output
;
}
const
std
::
vector
<
instruction_ref
>&
outputs
()
const
;
friend
bool
operator
==
(
instruction_ref
ref
,
const
instruction
&
i
)
{
return
i
==
ref
;
}
friend
bool
operator
==
(
instruction_ref
ref
,
const
instruction
&
i
)
;
friend
bool
operator
!=
(
const
instruction
&
i
,
instruction_ref
ref
)
{
return
!
(
i
==
ref
);
}
friend
bool
operator
!=
(
const
instruction
&
i
,
instruction_ref
ref
)
;
friend
bool
operator
!=
(
instruction_ref
ref
,
const
instruction
&
i
)
{
return
!
(
i
==
ref
);
}
friend
bool
operator
!=
(
instruction_ref
ref
,
const
instruction
&
i
)
;
void
add_output
(
instruction_ref
ins
)
{
if
(
std
::
find
(
output
.
begin
(),
output
.
end
(),
ins
)
==
output
.
end
())
output
.
push_back
(
ins
);
}
void
add_output
(
instruction_ref
ins
);
template
<
class
T
>
void
remove_output
(
const
T
&
ins
)
{
migraph
::
erase
(
output
,
ins
);
}
void
remove_output
(
const
T
&
ins
);
static
void
backreference
(
instruction_ref
ref
)
{
for
(
auto
&&
arg
:
ref
->
inputs
())
arg
->
add_output
(
ref
);
}
static
void
backreference
(
instruction_ref
ref
);
static
void
replace_argument
(
instruction_ref
ins
,
instruction_ref
old
,
instruction_ref
new_ins
)
{
ins
->
replace_argument
(
old
,
new_ins
);
backreference
(
ins
);
ins
->
recompute_shape
();
}
static
void
replace_argument
(
instruction_ref
ins
,
instruction_ref
old
,
instruction_ref
new_ins
);
static
void
replace
(
instruction_ref
ins
,
operation
o
,
const
shape
&
r
,
std
::
vector
<
instruction_ref
>
args
)
{
ins
->
replace
(
std
::
move
(
o
),
r
,
std
::
move
(
args
));
backreference
(
ins
);
}
replace
(
instruction_ref
ins
,
operation
o
,
const
shape
&
r
,
std
::
vector
<
instruction_ref
>
args
);
private:
// internal
void
replace
(
operation
o
,
const
shape
&
r
,
std
::
vector
<
instruction_ref
>
args
)
{
op
=
std
::
move
(
o
);
replace
(
r
);
replace
(
std
::
move
(
args
));
}
void
replace
(
operation
o
,
const
shape
&
r
,
std
::
vector
<
instruction_ref
>
args
);
// internal
void
replace
(
std
::
vector
<
instruction_ref
>
args
)
{
clear_arguments
();
arguments
=
std
::
move
(
args
);
}
void
replace
(
std
::
vector
<
instruction_ref
>
args
);
// internal
void
replace_argument
(
instruction_ref
old
,
instruction_ref
new_ins
)
{
std
::
replace
(
arguments
.
begin
(),
arguments
.
end
(),
old
,
new_ins
);
old
->
remove_output
(
*
this
);
}
void
replace_argument
(
instruction_ref
old
,
instruction_ref
new_ins
);
operation
op
;
shape
result
;
std
::
vector
<
instruction_ref
>
output
;
...
...
@@ -175,15 +78,6 @@ struct instruction
literal
lit
;
};
// TODO: Move to a cpp file
inline
shape
compute_shape
(
const
operation
&
op
,
const
std
::
vector
<
instruction_ref
>&
args
)
{
std
::
vector
<
shape
>
shapes
(
args
.
size
());
std
::
transform
(
args
.
begin
(),
args
.
end
(),
shapes
.
begin
(),
[](
instruction_ref
i
)
{
return
i
->
get_shape
();
});
return
op
.
compute_shape
(
shapes
);
}
}
// namespace migraph
namespace
std
{
...
...
src/instruction.cpp
0 → 100644
View file @
3d9bef13
#include <migraph/instruction.hpp>
#include <migraph/builtin.hpp>
#include <migraph/erase.hpp>
namespace
migraph
{
instruction
::
instruction
(
operation
o
,
shape
r
,
std
::
vector
<
instruction_ref
>
args
)
:
op
(
std
::
move
(
o
)),
result
(
std
::
move
(
r
)),
arguments
(
std
::
move
(
args
))
{
}
instruction
::
instruction
(
literal
l
)
:
op
(
builtin
::
literal
{}),
result
(
l
.
get_shape
()),
lit
(
std
::
move
(
l
))
{}
void
instruction
::
replace
(
const
shape
&
r
)
{
if
(
r
!=
result
)
{
result
=
r
;
for
(
auto
&&
ins
:
output
)
{
assert
(
ins
->
name
().
front
()
!=
'@'
);
ins
->
recompute_shape
();
}
}
}
void
instruction
::
recompute_shape
()
{
replace
(
compute_shape
(
op
,
arguments
));
}
void
instruction
::
clear_arguments
()
{
for
(
auto
&&
arg
:
arguments
)
{
arg
->
remove_output
(
*
this
);
}
arguments
.
clear
();
}
bool
operator
==
(
const
instruction
&
i
,
instruction_ref
ref
)
{
return
std
::
addressof
(
i
)
==
std
::
addressof
(
*
ref
);
}
bool
instruction
::
valid
(
instruction_ref
start
)
const
{
return
valid
()
&&
std
::
all_of
(
arguments
.
begin
(),
arguments
.
end
(),
[
&
](
instruction_ref
i
)
{
auto
self
=
std
::
find
(
i
->
outputs
().
begin
(),
i
->
outputs
().
end
(),
*
this
);
return
self
!=
i
->
outputs
().
end
()
&&
std
::
distance
(
start
,
i
)
<
std
::
distance
(
start
,
*
self
);
});
}
bool
instruction
::
valid
()
const
{
shape
computed
;
if
(
op
.
name
()
==
"@literal"
)
{
computed
=
lit
.
get_shape
();
}
else
if
(
op
.
name
()
==
"@param"
)
{
computed
=
result
;
}
else
{
try
{
computed
=
compute_shape
(
op
,
arguments
);
}
catch
(
migraph
::
exception
&
)
{
return
false
;
}
}
return
result
==
computed
&&
std
::
all_of
(
output
.
begin
(),
output
.
end
(),
[
&
](
instruction_ref
i
)
{
return
std
::
find
(
i
->
inputs
().
begin
(),
i
->
inputs
().
end
(),
*
this
)
!=
i
->
inputs
().
end
();
});
}
shape
instruction
::
get_shape
()
const
{
return
result
;
}
const
literal
&
instruction
::
get_literal
()
const
{
assert
(
op
.
name
()
==
"@literal"
);
return
lit
;
}
const
operation
&
instruction
::
get_operator
()
const
{
return
op
;
}
std
::
string
instruction
::
name
()
const
{
return
op
.
name
();
}
const
std
::
vector
<
instruction_ref
>&
instruction
::
inputs
()
const
{
return
arguments
;
}
const
std
::
vector
<
instruction_ref
>&
instruction
::
outputs
()
const
{
return
output
;
}
bool
operator
==
(
instruction_ref
ref
,
const
instruction
&
i
)
{
return
i
==
ref
;
}
bool
operator
!=
(
const
instruction
&
i
,
instruction_ref
ref
)
{
return
!
(
i
==
ref
);
}
bool
operator
!=
(
instruction_ref
ref
,
const
instruction
&
i
)
{
return
!
(
i
==
ref
);
}
void
instruction
::
add_output
(
instruction_ref
ins
)
{
if
(
std
::
find
(
output
.
begin
(),
output
.
end
(),
ins
)
==
output
.
end
())
output
.
push_back
(
ins
);
}
template
<
class
T
>
void
instruction
::
remove_output
(
const
T
&
ins
)
{
migraph
::
erase
(
output
,
ins
);
}
void
instruction
::
backreference
(
instruction_ref
ref
)
{
for
(
auto
&&
arg
:
ref
->
inputs
())
arg
->
add_output
(
ref
);
}
void
instruction
::
replace_argument
(
instruction_ref
ins
,
instruction_ref
old
,
instruction_ref
new_ins
)
{
ins
->
replace_argument
(
old
,
new_ins
);
backreference
(
ins
);
ins
->
recompute_shape
();
}
void
instruction
::
replace
(
instruction_ref
ins
,
operation
o
,
const
shape
&
r
,
std
::
vector
<
instruction_ref
>
args
)
{
ins
->
replace
(
std
::
move
(
o
),
r
,
std
::
move
(
args
));
backreference
(
ins
);
}
void
instruction
::
replace
(
operation
o
,
const
shape
&
r
,
std
::
vector
<
instruction_ref
>
args
)
{
op
=
std
::
move
(
o
);
replace
(
r
);
replace
(
std
::
move
(
args
));
}
void
instruction
::
replace
(
std
::
vector
<
instruction_ref
>
args
)
{
clear_arguments
();
arguments
=
std
::
move
(
args
);
}
void
instruction
::
replace_argument
(
instruction_ref
old
,
instruction_ref
new_ins
)
{
std
::
replace
(
arguments
.
begin
(),
arguments
.
end
(),
old
,
new_ins
);
old
->
remove_output
(
*
this
);
}
shape
compute_shape
(
const
operation
&
op
,
const
std
::
vector
<
instruction_ref
>&
args
)
{
std
::
vector
<
shape
>
shapes
(
args
.
size
());
std
::
transform
(
args
.
begin
(),
args
.
end
(),
shapes
.
begin
(),
[](
instruction_ref
i
)
{
return
i
->
get_shape
();
});
return
op
.
compute_shape
(
shapes
);
}
}
// namespace migraph
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment