Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
gaoqiong
MIGraphX
Commits
3d9bef13
Commit
3d9bef13
authored
Sep 16, 2018
by
Paul
Browse files
Move functions to cpp file
parent
b263425a
Changes
3
Hide whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
191 additions
and
135 deletions
+191
-135
src/CMakeLists.txt
src/CMakeLists.txt
+1
-0
src/include/migraph/instruction.hpp
src/include/migraph/instruction.hpp
+29
-135
src/instruction.cpp
src/instruction.cpp
+161
-0
No files found.
src/CMakeLists.txt
View file @
3d9bef13
...
...
@@ -7,6 +7,7 @@ add_library(migraph
fwd_conv_batchnorm_rewrite.cpp
env.cpp
generate.cpp
instruction.cpp
program.cpp
shape.cpp
simplify_reshapes.cpp
...
...
src/include/migraph/instruction.hpp
View file @
3d9bef13
...
...
@@ -3,10 +3,8 @@
#include <migraph/literal.hpp>
#include <migraph/shape.hpp>
#include <migraph/builtin.hpp>
#include <migraph/instruction_ref.hpp>
#include <migraph/operation.hpp>
#include <migraph/erase.hpp>
#include <string>
#include <utility>
...
...
@@ -18,156 +16,61 @@ struct instruction
{
instruction
()
{}
instruction
(
operation
o
,
shape
r
,
std
::
vector
<
instruction_ref
>
args
)
:
op
(
std
::
move
(
o
)),
result
(
std
::
move
(
r
)),
arguments
(
std
::
move
(
args
))
{
}
instruction
(
literal
l
)
:
op
(
builtin
::
literal
{}),
result
(
l
.
get_shape
()),
lit
(
std
::
move
(
l
))
{}
instruction
(
operation
o
,
shape
r
,
std
::
vector
<
instruction_ref
>
args
);
void
replace
(
const
shape
&
r
)
{
if
(
r
!=
result
)
{
result
=
r
;
for
(
auto
&&
ins
:
output
)
{
assert
(
ins
->
name
().
front
()
!=
'@'
);
ins
->
recompute_shape
();
}
}
}
instruction
(
literal
l
);
void
re
compute_shape
()
{
replace
(
compute_shape
(
op
,
arguments
));
}
void
re
place
(
const
shape
&
r
);
void
clear_arguments
()
{
for
(
auto
&&
arg
:
arguments
)
{
arg
->
remove_output
(
*
this
);
}
arguments
.
clear
();
}
friend
bool
operator
==
(
const
instruction
&
i
,
instruction_ref
ref
)
{
return
std
::
addressof
(
i
)
==
std
::
addressof
(
*
ref
);
}
void
recompute_shape
();
bool
valid
(
instruction_ref
start
)
const
{
return
valid
()
&&
std
::
all_of
(
arguments
.
begin
(),
arguments
.
end
(),
[
&
](
instruction_ref
i
)
{
auto
self
=
std
::
find
(
i
->
outputs
().
begin
(),
i
->
outputs
().
end
(),
*
this
);
return
self
!=
i
->
outputs
().
end
()
&&
std
::
distance
(
start
,
i
)
<
std
::
distance
(
start
,
*
self
);
});
}
void
clear_arguments
();
friend
bool
operator
==
(
const
instruction
&
i
,
instruction_ref
ref
);
bool
valid
()
const
{
shape
computed
;
if
(
op
.
name
()
==
"@literal"
)
{
computed
=
lit
.
get_shape
();
}
else
if
(
op
.
name
()
==
"@param"
)
{
computed
=
result
;
}
else
{
try
{
computed
=
compute_shape
(
op
,
arguments
);
}
catch
(
migraph
::
exception
&
)
{
return
false
;
}
}
return
result
==
computed
&&
std
::
all_of
(
output
.
begin
(),
output
.
end
(),
[
&
](
instruction_ref
i
)
{
return
std
::
find
(
i
->
inputs
().
begin
(),
i
->
inputs
().
end
(),
*
this
)
!=
i
->
inputs
().
end
();
});
}
bool
valid
(
instruction_ref
start
)
const
;
shape
get_shape
()
const
{
return
result
;
}
const
literal
&
get_literal
()
const
{
assert
(
op
.
name
()
==
"@literal"
);
return
lit
;
}
bool
valid
()
const
;
shape
get_shape
()
const
;
const
literal
&
get_literal
()
const
;
const
operation
&
get_operator
()
const
{
return
op
;
}
const
operation
&
get_operator
()
const
;
std
::
string
name
()
const
{
return
op
.
name
();
}
std
::
string
name
()
const
;
const
std
::
vector
<
instruction_ref
>&
inputs
()
const
{
return
arguments
;
}
const
std
::
vector
<
instruction_ref
>&
inputs
()
const
;
const
std
::
vector
<
instruction_ref
>&
outputs
()
const
{
return
output
;
}
const
std
::
vector
<
instruction_ref
>&
outputs
()
const
;
friend
bool
operator
==
(
instruction_ref
ref
,
const
instruction
&
i
)
{
return
i
==
ref
;
}
friend
bool
operator
==
(
instruction_ref
ref
,
const
instruction
&
i
)
;
friend
bool
operator
!=
(
const
instruction
&
i
,
instruction_ref
ref
)
{
return
!
(
i
==
ref
);
}
friend
bool
operator
!=
(
const
instruction
&
i
,
instruction_ref
ref
)
;
friend
bool
operator
!=
(
instruction_ref
ref
,
const
instruction
&
i
)
{
return
!
(
i
==
ref
);
}
friend
bool
operator
!=
(
instruction_ref
ref
,
const
instruction
&
i
)
;
void
add_output
(
instruction_ref
ins
)
{
if
(
std
::
find
(
output
.
begin
(),
output
.
end
(),
ins
)
==
output
.
end
())
output
.
push_back
(
ins
);
}
void
add_output
(
instruction_ref
ins
);
template
<
class
T
>
void
remove_output
(
const
T
&
ins
)
{
migraph
::
erase
(
output
,
ins
);
}
void
remove_output
(
const
T
&
ins
);
static
void
backreference
(
instruction_ref
ref
)
{
for
(
auto
&&
arg
:
ref
->
inputs
())
arg
->
add_output
(
ref
);
}
static
void
backreference
(
instruction_ref
ref
);
static
void
replace_argument
(
instruction_ref
ins
,
instruction_ref
old
,
instruction_ref
new_ins
)
{
ins
->
replace_argument
(
old
,
new_ins
);
backreference
(
ins
);
ins
->
recompute_shape
();
}
static
void
replace_argument
(
instruction_ref
ins
,
instruction_ref
old
,
instruction_ref
new_ins
);
static
void
replace
(
instruction_ref
ins
,
operation
o
,
const
shape
&
r
,
std
::
vector
<
instruction_ref
>
args
)
{
ins
->
replace
(
std
::
move
(
o
),
r
,
std
::
move
(
args
));
backreference
(
ins
);
}
replace
(
instruction_ref
ins
,
operation
o
,
const
shape
&
r
,
std
::
vector
<
instruction_ref
>
args
);
private:
// internal
void
replace
(
operation
o
,
const
shape
&
r
,
std
::
vector
<
instruction_ref
>
args
)
{
op
=
std
::
move
(
o
);
replace
(
r
);
replace
(
std
::
move
(
args
));
}
void
replace
(
operation
o
,
const
shape
&
r
,
std
::
vector
<
instruction_ref
>
args
);
// internal
void
replace
(
std
::
vector
<
instruction_ref
>
args
)
{
clear_arguments
();
arguments
=
std
::
move
(
args
);
}
void
replace
(
std
::
vector
<
instruction_ref
>
args
);
// internal
void
replace_argument
(
instruction_ref
old
,
instruction_ref
new_ins
)
{
std
::
replace
(
arguments
.
begin
(),
arguments
.
end
(),
old
,
new_ins
);
old
->
remove_output
(
*
this
);
}
void
replace_argument
(
instruction_ref
old
,
instruction_ref
new_ins
);
operation
op
;
shape
result
;
std
::
vector
<
instruction_ref
>
output
;
...
...
@@ -175,15 +78,6 @@ struct instruction
literal
lit
;
};
// TODO: Move to a cpp file
inline
shape
compute_shape
(
const
operation
&
op
,
const
std
::
vector
<
instruction_ref
>&
args
)
{
std
::
vector
<
shape
>
shapes
(
args
.
size
());
std
::
transform
(
args
.
begin
(),
args
.
end
(),
shapes
.
begin
(),
[](
instruction_ref
i
)
{
return
i
->
get_shape
();
});
return
op
.
compute_shape
(
shapes
);
}
}
// namespace migraph
namespace
std
{
...
...
src/instruction.cpp
0 → 100644
View file @
3d9bef13
#include <migraph/instruction.hpp>
#include <migraph/builtin.hpp>
#include <migraph/erase.hpp>
namespace
migraph
{
instruction
::
instruction
(
operation
o
,
shape
r
,
std
::
vector
<
instruction_ref
>
args
)
:
op
(
std
::
move
(
o
)),
result
(
std
::
move
(
r
)),
arguments
(
std
::
move
(
args
))
{
}
instruction
::
instruction
(
literal
l
)
:
op
(
builtin
::
literal
{}),
result
(
l
.
get_shape
()),
lit
(
std
::
move
(
l
))
{}
void
instruction
::
replace
(
const
shape
&
r
)
{
if
(
r
!=
result
)
{
result
=
r
;
for
(
auto
&&
ins
:
output
)
{
assert
(
ins
->
name
().
front
()
!=
'@'
);
ins
->
recompute_shape
();
}
}
}
void
instruction
::
recompute_shape
()
{
replace
(
compute_shape
(
op
,
arguments
));
}
void
instruction
::
clear_arguments
()
{
for
(
auto
&&
arg
:
arguments
)
{
arg
->
remove_output
(
*
this
);
}
arguments
.
clear
();
}
bool
operator
==
(
const
instruction
&
i
,
instruction_ref
ref
)
{
return
std
::
addressof
(
i
)
==
std
::
addressof
(
*
ref
);
}
bool
instruction
::
valid
(
instruction_ref
start
)
const
{
return
valid
()
&&
std
::
all_of
(
arguments
.
begin
(),
arguments
.
end
(),
[
&
](
instruction_ref
i
)
{
auto
self
=
std
::
find
(
i
->
outputs
().
begin
(),
i
->
outputs
().
end
(),
*
this
);
return
self
!=
i
->
outputs
().
end
()
&&
std
::
distance
(
start
,
i
)
<
std
::
distance
(
start
,
*
self
);
});
}
bool
instruction
::
valid
()
const
{
shape
computed
;
if
(
op
.
name
()
==
"@literal"
)
{
computed
=
lit
.
get_shape
();
}
else
if
(
op
.
name
()
==
"@param"
)
{
computed
=
result
;
}
else
{
try
{
computed
=
compute_shape
(
op
,
arguments
);
}
catch
(
migraph
::
exception
&
)
{
return
false
;
}
}
return
result
==
computed
&&
std
::
all_of
(
output
.
begin
(),
output
.
end
(),
[
&
](
instruction_ref
i
)
{
return
std
::
find
(
i
->
inputs
().
begin
(),
i
->
inputs
().
end
(),
*
this
)
!=
i
->
inputs
().
end
();
});
}
shape
instruction
::
get_shape
()
const
{
return
result
;
}
const
literal
&
instruction
::
get_literal
()
const
{
assert
(
op
.
name
()
==
"@literal"
);
return
lit
;
}
const
operation
&
instruction
::
get_operator
()
const
{
return
op
;
}
std
::
string
instruction
::
name
()
const
{
return
op
.
name
();
}
const
std
::
vector
<
instruction_ref
>&
instruction
::
inputs
()
const
{
return
arguments
;
}
const
std
::
vector
<
instruction_ref
>&
instruction
::
outputs
()
const
{
return
output
;
}
bool
operator
==
(
instruction_ref
ref
,
const
instruction
&
i
)
{
return
i
==
ref
;
}
bool
operator
!=
(
const
instruction
&
i
,
instruction_ref
ref
)
{
return
!
(
i
==
ref
);
}
bool
operator
!=
(
instruction_ref
ref
,
const
instruction
&
i
)
{
return
!
(
i
==
ref
);
}
void
instruction
::
add_output
(
instruction_ref
ins
)
{
if
(
std
::
find
(
output
.
begin
(),
output
.
end
(),
ins
)
==
output
.
end
())
output
.
push_back
(
ins
);
}
template
<
class
T
>
void
instruction
::
remove_output
(
const
T
&
ins
)
{
migraph
::
erase
(
output
,
ins
);
}
void
instruction
::
backreference
(
instruction_ref
ref
)
{
for
(
auto
&&
arg
:
ref
->
inputs
())
arg
->
add_output
(
ref
);
}
void
instruction
::
replace_argument
(
instruction_ref
ins
,
instruction_ref
old
,
instruction_ref
new_ins
)
{
ins
->
replace_argument
(
old
,
new_ins
);
backreference
(
ins
);
ins
->
recompute_shape
();
}
void
instruction
::
replace
(
instruction_ref
ins
,
operation
o
,
const
shape
&
r
,
std
::
vector
<
instruction_ref
>
args
)
{
ins
->
replace
(
std
::
move
(
o
),
r
,
std
::
move
(
args
));
backreference
(
ins
);
}
void
instruction
::
replace
(
operation
o
,
const
shape
&
r
,
std
::
vector
<
instruction_ref
>
args
)
{
op
=
std
::
move
(
o
);
replace
(
r
);
replace
(
std
::
move
(
args
));
}
void
instruction
::
replace
(
std
::
vector
<
instruction_ref
>
args
)
{
clear_arguments
();
arguments
=
std
::
move
(
args
);
}
void
instruction
::
replace_argument
(
instruction_ref
old
,
instruction_ref
new_ins
)
{
std
::
replace
(
arguments
.
begin
(),
arguments
.
end
(),
old
,
new_ins
);
old
->
remove_output
(
*
this
);
}
shape
compute_shape
(
const
operation
&
op
,
const
std
::
vector
<
instruction_ref
>&
args
)
{
std
::
vector
<
shape
>
shapes
(
args
.
size
());
std
::
transform
(
args
.
begin
(),
args
.
end
(),
shapes
.
begin
(),
[](
instruction_ref
i
)
{
return
i
->
get_shape
();
});
return
op
.
compute_shape
(
shapes
);
}
}
// namespace migraph
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment