Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
MIGraphX
Commits
d2d5fd19
Commit
d2d5fd19
authored
Aug 11, 2018
by
Paul
Browse files
Fix test for replacements
parent
64370f87
Changes
6
Hide whitespace changes
Inline
Side-by-side
Showing
6 changed files
with
95 additions
and
17 deletions
+95
-17
src/auto_contiguous.cpp
src/auto_contiguous.cpp
+1
-3
src/include/migraph/instruction.hpp
src/include/migraph/instruction.hpp
+26
-8
src/include/migraph/program.hpp
src/include/migraph/program.hpp
+3
-0
src/program.cpp
src/program.cpp
+15
-4
test/auto_contiguous_test.cpp
test/auto_contiguous_test.cpp
+2
-2
test/eval_test.cpp
test/eval_test.cpp
+48
-0
No files found.
src/auto_contiguous.cpp
View file @
d2d5fd19
...
@@ -14,9 +14,7 @@ void auto_contiguous::apply(program& p) const
...
@@ -14,9 +14,7 @@ void auto_contiguous::apply(program& p) const
if
(
not
s
.
standard
())
if
(
not
s
.
standard
())
{
{
auto
c
=
p
.
insert_instruction
(
std
::
next
(
ins
),
contiguous
{},
ins
);
auto
c
=
p
.
insert_instruction
(
std
::
next
(
ins
),
contiguous
{},
ins
);
p
.
replace_instructions
(
ins
,
ins
,
std
::
next
(
c
));
p
.
replace_instruction
(
ins
,
c
);
// auto prev = p.insert_instruction(ins, ins->op, ins->arguments);
// p.replace_instruction(ins, contiguous{}, prev);
}
}
}
}
}
}
...
...
src/include/migraph/instruction.hpp
View file @
d2d5fd19
...
@@ -55,6 +55,7 @@ struct instruction
...
@@ -55,6 +55,7 @@ struct instruction
void
replace_argument
(
instruction_ref
old
,
instruction_ref
new_ins
)
void
replace_argument
(
instruction_ref
old
,
instruction_ref
new_ins
)
{
{
std
::
replace
(
arguments
.
begin
(),
arguments
.
end
(),
old
,
new_ins
);
std
::
replace
(
arguments
.
begin
(),
arguments
.
end
(),
old
,
new_ins
);
old
->
remove_output
(
*
this
);
recompute_shape
();
recompute_shape
();
}
}
...
@@ -62,7 +63,7 @@ struct instruction
...
@@ -62,7 +63,7 @@ struct instruction
{
{
for
(
auto
&&
arg
:
arguments
)
for
(
auto
&&
arg
:
arguments
)
{
{
migraph
::
erase
(
arg
->
output
,
*
this
);
arg
->
remove_
output
(
*
this
);
}
}
arguments
.
clear
();
arguments
.
clear
();
}
}
...
@@ -73,6 +74,16 @@ struct instruction
...
@@ -73,6 +74,16 @@ struct instruction
}
}
bool
valid
(
instruction_ref
start
)
const
bool
valid
(
instruction_ref
start
)
const
{
return
valid
()
&&
std
::
all_of
(
arguments
.
begin
(),
arguments
.
end
(),
[
&
](
instruction_ref
i
)
{
auto
self
=
std
::
find
(
i
->
output
.
begin
(),
i
->
output
.
end
(),
*
this
);
return
self
!=
i
->
output
.
end
()
&&
std
::
distance
(
start
,
i
)
<
std
::
distance
(
start
,
*
self
);
});
}
bool
valid
()
const
{
{
shape
computed
;
shape
computed
;
if
(
op
.
name
()
==
"@literal"
)
if
(
op
.
name
()
==
"@literal"
)
...
@@ -100,12 +111,7 @@ struct instruction
...
@@ -100,12 +111,7 @@ struct instruction
[
&
](
instruction_ref
i
)
{
[
&
](
instruction_ref
i
)
{
return
std
::
find
(
i
->
arguments
.
begin
(),
i
->
arguments
.
end
(),
*
this
)
!=
return
std
::
find
(
i
->
arguments
.
begin
(),
i
->
arguments
.
end
(),
*
this
)
!=
i
->
arguments
.
end
();
i
->
arguments
.
end
();
})
&&
});
std
::
all_of
(
arguments
.
begin
(),
arguments
.
end
(),
[
&
](
instruction_ref
i
)
{
auto
self
=
std
::
find
(
i
->
output
.
begin
(),
i
->
output
.
end
(),
*
this
);
return
self
!=
i
->
output
.
end
()
&&
std
::
distance
(
start
,
i
)
<
std
::
distance
(
start
,
*
self
);
});
}
}
friend
bool
operator
==
(
instruction_ref
ref
,
const
instruction
&
i
)
{
return
i
==
ref
;
}
friend
bool
operator
==
(
instruction_ref
ref
,
const
instruction
&
i
)
{
return
i
==
ref
;
}
...
@@ -114,6 +120,18 @@ struct instruction
...
@@ -114,6 +120,18 @@ struct instruction
friend
bool
operator
!=
(
instruction_ref
ref
,
const
instruction
&
i
)
{
return
!
(
i
==
ref
);
}
friend
bool
operator
!=
(
instruction_ref
ref
,
const
instruction
&
i
)
{
return
!
(
i
==
ref
);
}
void
add_output
(
instruction_ref
ins
)
{
if
(
std
::
find
(
output
.
begin
(),
output
.
end
(),
ins
)
==
output
.
end
())
output
.
push_back
(
ins
);
}
template
<
class
T
>
void
remove_output
(
const
T
&
ins
)
{
migraph
::
erase
(
output
,
ins
);
}
operation
op
;
operation
op
;
shape
result
;
shape
result
;
std
::
vector
<
instruction_ref
>
output
;
std
::
vector
<
instruction_ref
>
output
;
...
@@ -124,7 +142,7 @@ struct instruction
...
@@ -124,7 +142,7 @@ struct instruction
inline
void
backreference
(
instruction_ref
ref
)
inline
void
backreference
(
instruction_ref
ref
)
{
{
for
(
auto
&&
arg
:
ref
->
arguments
)
for
(
auto
&&
arg
:
ref
->
arguments
)
arg
->
output
.
push_back
(
ref
);
arg
->
add_
output
(
ref
);
}
}
// TODO: Move to a cpp file
// TODO: Move to a cpp file
...
...
src/include/migraph/program.hpp
View file @
d2d5fd19
...
@@ -55,6 +55,9 @@ struct program
...
@@ -55,6 +55,9 @@ struct program
instruction_ref
instruction_ref
replace_instructions
(
instruction_ref
ins
,
instruction_ref
start
,
instruction_ref
last
);
replace_instructions
(
instruction_ref
ins
,
instruction_ref
start
,
instruction_ref
last
);
instruction_ref
replace_instruction
(
instruction_ref
ins
,
instruction_ref
start
);
instruction_ref
remove_instruction
(
instruction_ref
ins
);
instruction_ref
remove_instruction
(
instruction_ref
ins
);
instruction_ref
remove_instructions
(
instruction_ref
first
,
instruction_ref
last
);
instruction_ref
remove_instructions
(
instruction_ref
first
,
instruction_ref
last
);
...
...
src/program.cpp
View file @
d2d5fd19
...
@@ -38,6 +38,7 @@ program::insert_instruction(instruction_ref ins, operation op, std::vector<instr
...
@@ -38,6 +38,7 @@ program::insert_instruction(instruction_ref ins, operation op, std::vector<instr
auto
result
=
impl
->
instructions
.
insert
(
ins
,
{
op
,
r
,
args
});
auto
result
=
impl
->
instructions
.
insert
(
ins
,
{
op
,
r
,
args
});
backreference
(
result
);
backreference
(
result
);
assert
(
result
->
arguments
==
args
);
assert
(
result
->
arguments
==
args
);
assert
(
result
->
valid
(
begin
()));
return
result
;
return
result
;
}
}
...
@@ -52,6 +53,7 @@ program::replace_instruction(instruction_ref ins, operation op, std::vector<inst
...
@@ -52,6 +53,7 @@ program::replace_instruction(instruction_ref ins, operation op, std::vector<inst
shape
r
=
compute_shape
(
op
,
args
);
shape
r
=
compute_shape
(
op
,
args
);
ins
->
replace
(
op
,
r
,
args
);
ins
->
replace
(
op
,
r
,
args
);
backreference
(
ins
);
backreference
(
ins
);
assert
(
ins
->
valid
(
begin
()));
return
ins
;
return
ins
;
}
}
...
@@ -61,16 +63,25 @@ program::replace_instructions(instruction_ref ins, instruction_ref start, instru
...
@@ -61,16 +63,25 @@ program::replace_instructions(instruction_ref ins, instruction_ref start, instru
auto
rep
=
std
::
prev
(
last
);
auto
rep
=
std
::
prev
(
last
);
for
(
auto
&&
out
:
ins
->
output
)
for
(
auto
&&
out
:
ins
->
output
)
{
{
if
(
std
::
find
(
start
,
last
,
out
)
==
last
)
if
(
std
::
find
(
start
,
last
,
out
)
==
last
)
{
{
out
->
replace_argument
(
ins
,
rep
);
out
->
replace_argument
(
ins
,
rep
);
backreference
(
out
);
backreference
(
out
);
}
}
assert
(
out
->
valid
(
begin
()));
}
}
assert
(
rep
->
valid
(
begin
()));
assert
(
ins
->
valid
(
begin
()));
if
(
ins
->
output
.
empty
())
if
(
ins
->
output
.
empty
())
return
remove_instruction
(
ins
);
remove_instruction
(
ins
);
return
ins
;
return
rep
;
}
instruction_ref
program
::
replace_instruction
(
instruction_ref
ins
,
instruction_ref
start
)
{
assert
(
ins
!=
start
);
return
replace_instructions
(
ins
,
start
,
std
::
next
(
start
));
}
}
instruction_ref
program
::
remove_instruction
(
instruction_ref
ins
)
instruction_ref
program
::
remove_instruction
(
instruction_ref
ins
)
...
@@ -182,7 +193,7 @@ void program::compile(const target& t)
...
@@ -182,7 +193,7 @@ void program::compile(const target& t)
{
{
auto
index
=
std
::
distance
(
impl
->
instructions
.
begin
(),
invalid
);
auto
index
=
std
::
distance
(
impl
->
instructions
.
begin
(),
invalid
);
MIGRAPH_THROW
(
p
.
name
()
+
" pass produces invalid program at instruction "
+
MIGRAPH_THROW
(
p
.
name
()
+
" pass produces invalid program at instruction "
+
std
::
to_string
(
index
));
std
::
to_string
(
index
)
+
": "
+
invalid
->
op
.
name
()
);
}
}
#endif
#endif
}
}
...
...
test/auto_contiguous_test.cpp
View file @
d2d5fd19
...
@@ -27,7 +27,7 @@ migraph::literal get_2() { return migraph::literal{{migraph::shape::float_type,
...
@@ -27,7 +27,7 @@ migraph::literal get_2() { return migraph::literal{{migraph::shape::float_type,
migraph
::
literal
get_2_broadcasted
()
migraph
::
literal
get_2_broadcasted
()
{
{
return
migraph
::
literal
{{
migraph
::
shape
::
float_type
,
{
2
},
{
1
,
0
}},
{
1
,
2
}};
return
migraph
::
literal
{{
migraph
::
shape
::
float_type
,
{
2
,
1
},
{
1
,
0
}},
{
1
,
2
}};
}
}
void
literal_broadcast
()
void
literal_broadcast
()
...
@@ -116,7 +116,7 @@ void after_param_broadcast()
...
@@ -116,7 +116,7 @@ void after_param_broadcast()
int
main
()
int
main
()
{
{
literal_broadcast
();
//
literal_broadcast();
literal_transpose
();
literal_transpose
();
after_literal_transpose
();
after_literal_transpose
();
after_literal_broadcast
();
after_literal_broadcast
();
...
...
test/eval_test.cpp
View file @
d2d5fd19
...
@@ -118,6 +118,51 @@ void replace_test()
...
@@ -118,6 +118,51 @@ void replace_test()
EXPECT
(
result
!=
migraph
::
literal
{
3
});
EXPECT
(
result
!=
migraph
::
literal
{
3
});
}
}
void
replace_ins_test
()
{
migraph
::
program
p
;
auto
one
=
p
.
add_literal
(
1
);
auto
two
=
p
.
add_literal
(
2
);
auto
sum
=
p
.
add_instruction
(
sum_op
{},
one
,
two
);
auto
minus
=
p
.
add_instruction
(
minus_op
{},
two
,
one
);
p
.
replace_instruction
(
sum
,
minus
);
auto
result
=
p
.
eval
({});
EXPECT
(
result
==
migraph
::
literal
{
1
});
EXPECT
(
result
!=
migraph
::
literal
{
3
});
}
void
replace_ins_test2
()
{
migraph
::
program
p
;
auto
one
=
p
.
add_literal
(
1
);
auto
two
=
p
.
add_literal
(
2
);
auto
sum
=
p
.
add_instruction
(
sum_op
{},
one
,
two
);
p
.
add_instruction
(
minus_op
{},
two
,
one
);
p
.
replace_instruction
(
two
,
sum
);
auto
result
=
p
.
eval
({});
EXPECT
(
result
==
migraph
::
literal
{
2
});
EXPECT
(
result
!=
migraph
::
literal
{
3
});
}
void
replace_inss_test
()
{
migraph
::
program
p
;
auto
one
=
p
.
add_literal
(
1
);
auto
two
=
p
.
add_literal
(
2
);
auto
sum
=
p
.
add_instruction
(
sum_op
{},
one
,
two
);
p
.
add_instruction
(
minus_op
{},
two
,
one
);
p
.
replace_instructions
(
two
,
two
,
std
::
next
(
sum
));
auto
result
=
p
.
eval
({});
EXPECT
(
result
==
migraph
::
literal
{
2
});
EXPECT
(
result
!=
migraph
::
literal
{
3
});
}
void
insert_replace_test
()
void
insert_replace_test
()
{
{
migraph
::
program
p
;
migraph
::
program
p
;
...
@@ -181,6 +226,9 @@ int main()
...
@@ -181,6 +226,9 @@ int main()
print_test
();
print_test
();
param_test
();
param_test
();
replace_test
();
replace_test
();
replace_ins_test
();
replace_ins_test2
();
replace_inss_test
();
insert_replace_test
();
insert_replace_test
();
target_test
();
target_test
();
reverse_target_test
();
reverse_target_test
();
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment