Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
gaoqiong
MIGraphX
Commits
d2d5fd19
Commit
d2d5fd19
authored
Aug 11, 2018
by
Paul
Browse files
Fix test for replacements
parent
64370f87
Changes
6
Hide whitespace changes
Inline
Side-by-side
Showing
6 changed files
with
95 additions
and
17 deletions
+95
-17
src/auto_contiguous.cpp
src/auto_contiguous.cpp
+1
-3
src/include/migraph/instruction.hpp
src/include/migraph/instruction.hpp
+26
-8
src/include/migraph/program.hpp
src/include/migraph/program.hpp
+3
-0
src/program.cpp
src/program.cpp
+15
-4
test/auto_contiguous_test.cpp
test/auto_contiguous_test.cpp
+2
-2
test/eval_test.cpp
test/eval_test.cpp
+48
-0
No files found.
src/auto_contiguous.cpp
View file @
d2d5fd19
...
@@ -14,9 +14,7 @@ void auto_contiguous::apply(program& p) const
...
@@ -14,9 +14,7 @@ void auto_contiguous::apply(program& p) const
if
(
not
s
.
standard
())
if
(
not
s
.
standard
())
{
{
auto
c
=
p
.
insert_instruction
(
std
::
next
(
ins
),
contiguous
{},
ins
);
auto
c
=
p
.
insert_instruction
(
std
::
next
(
ins
),
contiguous
{},
ins
);
p
.
replace_instructions
(
ins
,
ins
,
std
::
next
(
c
));
p
.
replace_instruction
(
ins
,
c
);
// auto prev = p.insert_instruction(ins, ins->op, ins->arguments);
// p.replace_instruction(ins, contiguous{}, prev);
}
}
}
}
}
}
...
...
src/include/migraph/instruction.hpp
View file @
d2d5fd19
...
@@ -55,6 +55,7 @@ struct instruction
...
@@ -55,6 +55,7 @@ struct instruction
void
replace_argument
(
instruction_ref
old
,
instruction_ref
new_ins
)
void
replace_argument
(
instruction_ref
old
,
instruction_ref
new_ins
)
{
{
std
::
replace
(
arguments
.
begin
(),
arguments
.
end
(),
old
,
new_ins
);
std
::
replace
(
arguments
.
begin
(),
arguments
.
end
(),
old
,
new_ins
);
old
->
remove_output
(
*
this
);
recompute_shape
();
recompute_shape
();
}
}
...
@@ -62,7 +63,7 @@ struct instruction
...
@@ -62,7 +63,7 @@ struct instruction
{
{
for
(
auto
&&
arg
:
arguments
)
for
(
auto
&&
arg
:
arguments
)
{
{
migraph
::
erase
(
arg
->
output
,
*
this
);
arg
->
remove_
output
(
*
this
);
}
}
arguments
.
clear
();
arguments
.
clear
();
}
}
...
@@ -73,6 +74,16 @@ struct instruction
...
@@ -73,6 +74,16 @@ struct instruction
}
}
bool
valid
(
instruction_ref
start
)
const
bool
valid
(
instruction_ref
start
)
const
{
return
valid
()
&&
std
::
all_of
(
arguments
.
begin
(),
arguments
.
end
(),
[
&
](
instruction_ref
i
)
{
auto
self
=
std
::
find
(
i
->
output
.
begin
(),
i
->
output
.
end
(),
*
this
);
return
self
!=
i
->
output
.
end
()
&&
std
::
distance
(
start
,
i
)
<
std
::
distance
(
start
,
*
self
);
});
}
bool
valid
()
const
{
{
shape
computed
;
shape
computed
;
if
(
op
.
name
()
==
"@literal"
)
if
(
op
.
name
()
==
"@literal"
)
...
@@ -100,12 +111,7 @@ struct instruction
...
@@ -100,12 +111,7 @@ struct instruction
[
&
](
instruction_ref
i
)
{
[
&
](
instruction_ref
i
)
{
return
std
::
find
(
i
->
arguments
.
begin
(),
i
->
arguments
.
end
(),
*
this
)
!=
return
std
::
find
(
i
->
arguments
.
begin
(),
i
->
arguments
.
end
(),
*
this
)
!=
i
->
arguments
.
end
();
i
->
arguments
.
end
();
})
&&
});
std
::
all_of
(
arguments
.
begin
(),
arguments
.
end
(),
[
&
](
instruction_ref
i
)
{
auto
self
=
std
::
find
(
i
->
output
.
begin
(),
i
->
output
.
end
(),
*
this
);
return
self
!=
i
->
output
.
end
()
&&
std
::
distance
(
start
,
i
)
<
std
::
distance
(
start
,
*
self
);
});
}
}
friend
bool
operator
==
(
instruction_ref
ref
,
const
instruction
&
i
)
{
return
i
==
ref
;
}
friend
bool
operator
==
(
instruction_ref
ref
,
const
instruction
&
i
)
{
return
i
==
ref
;
}
...
@@ -114,6 +120,18 @@ struct instruction
...
@@ -114,6 +120,18 @@ struct instruction
friend
bool
operator
!=
(
instruction_ref
ref
,
const
instruction
&
i
)
{
return
!
(
i
==
ref
);
}
friend
bool
operator
!=
(
instruction_ref
ref
,
const
instruction
&
i
)
{
return
!
(
i
==
ref
);
}
void
add_output
(
instruction_ref
ins
)
{
if
(
std
::
find
(
output
.
begin
(),
output
.
end
(),
ins
)
==
output
.
end
())
output
.
push_back
(
ins
);
}
template
<
class
T
>
void
remove_output
(
const
T
&
ins
)
{
migraph
::
erase
(
output
,
ins
);
}
operation
op
;
operation
op
;
shape
result
;
shape
result
;
std
::
vector
<
instruction_ref
>
output
;
std
::
vector
<
instruction_ref
>
output
;
...
@@ -124,7 +142,7 @@ struct instruction
...
@@ -124,7 +142,7 @@ struct instruction
inline
void
backreference
(
instruction_ref
ref
)
inline
void
backreference
(
instruction_ref
ref
)
{
{
for
(
auto
&&
arg
:
ref
->
arguments
)
for
(
auto
&&
arg
:
ref
->
arguments
)
arg
->
output
.
push_back
(
ref
);
arg
->
add_
output
(
ref
);
}
}
// TODO: Move to a cpp file
// TODO: Move to a cpp file
...
...
src/include/migraph/program.hpp
View file @
d2d5fd19
...
@@ -55,6 +55,9 @@ struct program
...
@@ -55,6 +55,9 @@ struct program
instruction_ref
instruction_ref
replace_instructions
(
instruction_ref
ins
,
instruction_ref
start
,
instruction_ref
last
);
replace_instructions
(
instruction_ref
ins
,
instruction_ref
start
,
instruction_ref
last
);
instruction_ref
replace_instruction
(
instruction_ref
ins
,
instruction_ref
start
);
instruction_ref
remove_instruction
(
instruction_ref
ins
);
instruction_ref
remove_instruction
(
instruction_ref
ins
);
instruction_ref
remove_instructions
(
instruction_ref
first
,
instruction_ref
last
);
instruction_ref
remove_instructions
(
instruction_ref
first
,
instruction_ref
last
);
...
...
src/program.cpp
View file @
d2d5fd19
...
@@ -38,6 +38,7 @@ program::insert_instruction(instruction_ref ins, operation op, std::vector<instr
...
@@ -38,6 +38,7 @@ program::insert_instruction(instruction_ref ins, operation op, std::vector<instr
auto
result
=
impl
->
instructions
.
insert
(
ins
,
{
op
,
r
,
args
});
auto
result
=
impl
->
instructions
.
insert
(
ins
,
{
op
,
r
,
args
});
backreference
(
result
);
backreference
(
result
);
assert
(
result
->
arguments
==
args
);
assert
(
result
->
arguments
==
args
);
assert
(
result
->
valid
(
begin
()));
return
result
;
return
result
;
}
}
...
@@ -52,6 +53,7 @@ program::replace_instruction(instruction_ref ins, operation op, std::vector<inst
...
@@ -52,6 +53,7 @@ program::replace_instruction(instruction_ref ins, operation op, std::vector<inst
shape
r
=
compute_shape
(
op
,
args
);
shape
r
=
compute_shape
(
op
,
args
);
ins
->
replace
(
op
,
r
,
args
);
ins
->
replace
(
op
,
r
,
args
);
backreference
(
ins
);
backreference
(
ins
);
assert
(
ins
->
valid
(
begin
()));
return
ins
;
return
ins
;
}
}
...
@@ -61,16 +63,25 @@ program::replace_instructions(instruction_ref ins, instruction_ref start, instru
...
@@ -61,16 +63,25 @@ program::replace_instructions(instruction_ref ins, instruction_ref start, instru
auto
rep
=
std
::
prev
(
last
);
auto
rep
=
std
::
prev
(
last
);
for
(
auto
&&
out
:
ins
->
output
)
for
(
auto
&&
out
:
ins
->
output
)
{
{
if
(
std
::
find
(
start
,
last
,
out
)
==
last
)
if
(
std
::
find
(
start
,
last
,
out
)
==
last
)
{
{
out
->
replace_argument
(
ins
,
rep
);
out
->
replace_argument
(
ins
,
rep
);
backreference
(
out
);
backreference
(
out
);
}
}
assert
(
out
->
valid
(
begin
()));
}
}
assert
(
rep
->
valid
(
begin
()));
assert
(
ins
->
valid
(
begin
()));
if
(
ins
->
output
.
empty
())
if
(
ins
->
output
.
empty
())
return
remove_instruction
(
ins
);
remove_instruction
(
ins
);
return
ins
;
return
rep
;
}
instruction_ref
program
::
replace_instruction
(
instruction_ref
ins
,
instruction_ref
start
)
{
assert
(
ins
!=
start
);
return
replace_instructions
(
ins
,
start
,
std
::
next
(
start
));
}
}
instruction_ref
program
::
remove_instruction
(
instruction_ref
ins
)
instruction_ref
program
::
remove_instruction
(
instruction_ref
ins
)
...
@@ -182,7 +193,7 @@ void program::compile(const target& t)
...
@@ -182,7 +193,7 @@ void program::compile(const target& t)
{
{
auto
index
=
std
::
distance
(
impl
->
instructions
.
begin
(),
invalid
);
auto
index
=
std
::
distance
(
impl
->
instructions
.
begin
(),
invalid
);
MIGRAPH_THROW
(
p
.
name
()
+
" pass produces invalid program at instruction "
+
MIGRAPH_THROW
(
p
.
name
()
+
" pass produces invalid program at instruction "
+
std
::
to_string
(
index
));
std
::
to_string
(
index
)
+
": "
+
invalid
->
op
.
name
()
);
}
}
#endif
#endif
}
}
...
...
test/auto_contiguous_test.cpp
View file @
d2d5fd19
...
@@ -27,7 +27,7 @@ migraph::literal get_2() { return migraph::literal{{migraph::shape::float_type,
...
@@ -27,7 +27,7 @@ migraph::literal get_2() { return migraph::literal{{migraph::shape::float_type,
migraph
::
literal
get_2_broadcasted
()
migraph
::
literal
get_2_broadcasted
()
{
{
return
migraph
::
literal
{{
migraph
::
shape
::
float_type
,
{
2
},
{
1
,
0
}},
{
1
,
2
}};
return
migraph
::
literal
{{
migraph
::
shape
::
float_type
,
{
2
,
1
},
{
1
,
0
}},
{
1
,
2
}};
}
}
void
literal_broadcast
()
void
literal_broadcast
()
...
@@ -116,7 +116,7 @@ void after_param_broadcast()
...
@@ -116,7 +116,7 @@ void after_param_broadcast()
int
main
()
int
main
()
{
{
literal_broadcast
();
//
literal_broadcast();
literal_transpose
();
literal_transpose
();
after_literal_transpose
();
after_literal_transpose
();
after_literal_broadcast
();
after_literal_broadcast
();
...
...
test/eval_test.cpp
View file @
d2d5fd19
...
@@ -118,6 +118,51 @@ void replace_test()
...
@@ -118,6 +118,51 @@ void replace_test()
EXPECT
(
result
!=
migraph
::
literal
{
3
});
EXPECT
(
result
!=
migraph
::
literal
{
3
});
}
}
void
replace_ins_test
()
{
migraph
::
program
p
;
auto
one
=
p
.
add_literal
(
1
);
auto
two
=
p
.
add_literal
(
2
);
auto
sum
=
p
.
add_instruction
(
sum_op
{},
one
,
two
);
auto
minus
=
p
.
add_instruction
(
minus_op
{},
two
,
one
);
p
.
replace_instruction
(
sum
,
minus
);
auto
result
=
p
.
eval
({});
EXPECT
(
result
==
migraph
::
literal
{
1
});
EXPECT
(
result
!=
migraph
::
literal
{
3
});
}
void
replace_ins_test2
()
{
migraph
::
program
p
;
auto
one
=
p
.
add_literal
(
1
);
auto
two
=
p
.
add_literal
(
2
);
auto
sum
=
p
.
add_instruction
(
sum_op
{},
one
,
two
);
p
.
add_instruction
(
minus_op
{},
two
,
one
);
p
.
replace_instruction
(
two
,
sum
);
auto
result
=
p
.
eval
({});
EXPECT
(
result
==
migraph
::
literal
{
2
});
EXPECT
(
result
!=
migraph
::
literal
{
3
});
}
void
replace_inss_test
()
{
migraph
::
program
p
;
auto
one
=
p
.
add_literal
(
1
);
auto
two
=
p
.
add_literal
(
2
);
auto
sum
=
p
.
add_instruction
(
sum_op
{},
one
,
two
);
p
.
add_instruction
(
minus_op
{},
two
,
one
);
p
.
replace_instructions
(
two
,
two
,
std
::
next
(
sum
));
auto
result
=
p
.
eval
({});
EXPECT
(
result
==
migraph
::
literal
{
2
});
EXPECT
(
result
!=
migraph
::
literal
{
3
});
}
void
insert_replace_test
()
void
insert_replace_test
()
{
{
migraph
::
program
p
;
migraph
::
program
p
;
...
@@ -181,6 +226,9 @@ int main()
...
@@ -181,6 +226,9 @@ int main()
print_test
();
print_test
();
param_test
();
param_test
();
replace_test
();
replace_test
();
replace_ins_test
();
replace_ins_test2
();
replace_inss_test
();
insert_replace_test
();
insert_replace_test
();
target_test
();
target_test
();
reverse_target_test
();
reverse_target_test
();
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment