Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
MIGraphX
Commits
fbaec470
"src/vscode:/vscode.git/clone" did not exist on "b3086ac2606d4b6999788f7faf06afa30406e44e"
Commit
fbaec470
authored
Mar 11, 2019
by
Paul
Browse files
Reduce the number of identity ops added
parent
b211af48
Changes
2
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
87 additions
and
55 deletions
+87
-55
src/schedule.cpp
src/schedule.cpp
+30
-7
test/schedule_test.cpp
test/schedule_test.cpp
+57
-48
No files found.
src/schedule.cpp
View file @
fbaec470
...
...
@@ -379,23 +379,46 @@ void schedule::apply(program& p) const
// Add memory conflicts
auto
concur_ins
=
si
.
find_concurrent_instructions
(
p
);
std
::
unordered_map
<
instruction_ref
,
std
::
unordered_set
<
instruction_ref
>>
conflict_table
;
for
(
auto
&&
merge
:
concur_ins
)
{
dfor
(
merge
.
second
.
size
(),
merge
.
second
.
size
())([
&
](
auto
i
,
auto
j
)
{
if
(
i
==
j
)
return
;
if
(
merge
.
second
[
i
].
empty
())
return
;
if
(
merge
.
second
[
j
].
empty
())
return
;
for
(
auto
ins1
:
merge
.
second
[
i
])
{
auto
args
=
merge
.
second
[
j
];
args
.
insert
(
args
.
begin
(),
ins1
);
p
.
insert_instruction
(
merge
.
first
,
op
::
identity
{},
args
);
auto
p1
=
std
::
distance
(
ins1
,
merge
.
first
);
for
(
auto
ins2
:
merge
.
second
[
j
])
{
if
(
ins1
==
ins2
)
continue
;
auto
p2
=
std
::
distance
(
ins2
,
merge
.
first
);
// The smaller distance means the instruction occurs later
if
(
p1
>
p2
)
conflict_table
[
ins2
].
insert
(
ins1
);
else
conflict_table
[
ins1
].
insert
(
ins2
);
}
}
});
}
// Remove duplicates
for
(
auto
&&
ip
:
conflict_table
)
{
auto
ins1
=
ip
.
first
;
for
(
auto
ins2
:
ip
.
second
)
if
(
contains
(
conflict_table
[
ins2
],
ins1
))
conflict_table
[
ins2
].
erase
(
ins1
);
}
for
(
auto
&&
ip
:
conflict_table
)
{
if
(
ip
.
second
.
empty
())
continue
;
std
::
vector
<
instruction_ref
>
args
;
args
.
push_back
(
ip
.
first
);
args
.
insert
(
args
.
end
(),
ip
.
second
.
begin
(),
ip
.
second
.
end
());
p
.
insert_instruction
(
std
::
next
(
ip
.
first
),
op
::
identity
{},
args
);
}
}
}
// namespace MIGRAPHX_INLINE_NS
...
...
test/schedule_test.cpp
View file @
fbaec470
...
...
@@ -140,6 +140,21 @@ struct schedule_model_test
}
};
// Reports whether the program contains an instruction named "identity" whose
// inputs include both `x` and `y`.
//
// p: program whose instructions are scanned in iteration order.
// x: first instruction to look for among an identity's inputs.
// y: second instruction to look for among the same identity's inputs.
//
// Returns true as soon as one such identity instruction is found, false if the
// whole program is scanned without a match.
bool check_conflicts(migraphx::program& p, migraphx::instruction_ref x, migraphx::instruction_ref y)
{
    for(auto ins : migraphx::iterator_for(p))
    {
        // Short-circuit evaluation keeps the checks in the same order:
        // name first, then x, then y.
        if(not(ins->name() != "identity") and migraphx::contains(ins->inputs(), x) and
           migraphx::contains(ins->inputs(), y))
        {
            return true;
        }
    }
    return false;
}
struct
schedule_target
{
schedule_model_test
model
{};
...
...
@@ -162,35 +177,29 @@ struct schedule_target
}
bool
has_stream
(
migraphx
::
instruction_ref
ins
)
{
return
model
.
ins2stream
->
count
(
ins
)
>
0
;
}
};
bool
check_conflicts
(
migraphx
::
program
&
p
,
migraphx
::
instruction_ref
x
,
migraphx
::
instruction_ref
y
)
{
for
(
auto
ins
:
migraphx
::
iterator_for
(
p
))
void
check_conflicts
(
migraphx
::
program
&
p
,
std
::
vector
<
std
::
vector
<
migraphx
::
instruction_ref
>>
conflicts
,
bool
result
=
true
)
{
if
(
ins
->
name
()
!=
"identity"
)
continue
;
if
(
not
migraphx
::
contains
(
ins
->
inputs
(),
x
))
continue
;
if
(
not
migraphx
::
contains
(
ins
->
inputs
(),
y
))
continue
;
return
true
;
migraphx
::
dfor
(
conflicts
.
size
(),
conflicts
.
size
())([
&
](
auto
i
,
auto
j
)
{
if
(
i
==
j
)
return
;
for
(
auto
ins1
:
conflicts
[
i
])
{
for
(
auto
ins2
:
conflicts
[
j
])
{
// If both instructions are on the same stream then dont check for a conflict
if
(
has_stream
(
ins1
)
and
has_stream
(
ins2
)
and
get_stream
(
ins1
)
==
get_stream
(
ins2
))
continue
;
CHECK
(
::
check_conflicts
(
p
,
ins1
,
ins2
)
==
result
);
}
}
});
}
return
false
;
}
}
;
// Asserts (via CHECK) the pairwise conflict status between groups of instructions.
//
// p:         program passed through to the pairwise check_conflicts overload.
// conflicts: groups of instructions; instructions are compared only across
//            different groups, never within the same group.
// result:    value every pairwise check_conflicts(p, ins1, ins2) call is
//            expected to return (defaults to true).
//
// NOTE(review): dfor presumably invokes the callback for every (i, j) index
// pair in [0, size) x [0, size) - confirm against its definition.
void check_conflicts(migraphx::program& p,
                     std::vector<std::vector<migraphx::instruction_ref>> conflicts,
                     bool result = true)
{
    const auto group_count = conflicts.size();
    migraphx::dfor(group_count, group_count)([&](auto first_group, auto second_group) {
        // Only compare instructions that belong to different groups.
        if(first_group != second_group)
        {
            for(auto ins1 : conflicts[first_group])
            {
                for(auto ins2 : conflicts[second_group])
                {
                    CHECK(check_conflicts(p, ins1, ins2) == result);
                }
            }
        }
    });
}
template
<
class
T
>
std
::
vector
<
T
>
sorted
(
std
::
vector
<
T
>
x
)
...
...
@@ -292,7 +301,7 @@ TEST_CASE(zero_record)
EXPECT
(
get_wait_for
(
binary
)
==
get_wait_for
(
t
.
get_stream
(
binary
),
{
t
.
get_stream
(
onep1
),
t
.
get_stream
(
onep2
)}));
EXPECT
(
check_conflicts
(
p
,
onep1
,
onep2
));
check_conflicts
(
p
,
{{
onep1
,
onei1
},
{
onep2
,
onei2
}});
t
.
check_conflicts
(
p
,
{{
onep1
,
onei1
},
{
onep2
,
onei2
}});
}
TEST_CASE
(
zero_merge1
)
...
...
@@ -397,7 +406,7 @@ TEST_CASE(double_entry)
EXPECT
(
t
.
get_stream
(
binary
)
==
0
);
EXPECT
(
get_wait_for
(
binary
)
==
get_wait_for
(
t
.
get_stream
(
binary
),
{
t
.
get_stream
(
onep
),
t
.
get_stream
(
twop
)}));
check_conflicts
(
p
,
{{
onep
,
one
},
{
twop
,
two
}});
t
.
check_conflicts
(
p
,
{{
onep
,
one
},
{
twop
,
two
}});
}
TEST_CASE
(
two_branches
)
...
...
@@ -416,7 +425,7 @@ TEST_CASE(two_branches)
EXPECT
(
t
.
get_stream
(
binary
)
==
0
);
EXPECT
(
get_wait_for
(
binary
)
==
get_wait_for
(
t
.
get_stream
(
binary
),
{
t
.
get_stream
(
c1
.
back
()),
t
.
get_stream
(
i1
)}));
check_conflicts
(
p
,
{
c1
,
{
i1
}});
t
.
check_conflicts
(
p
,
{
c1
,
{
i1
}});
}
TEST_CASE
(
four_branches
)
...
...
@@ -444,7 +453,7 @@ TEST_CASE(four_branches)
t
.
get_stream
(
c2
.
back
()),
t
.
get_stream
(
c3
.
back
()),
t
.
get_stream
(
i1
)}));
check_conflicts
(
p
,
{
c1
,
c2
,
c3
,
{
i1
}});
t
.
check_conflicts
(
p
,
{
c1
,
c2
,
c3
,
{
i1
}});
}
TEST_CASE
(
five_branches
)
...
...
@@ -475,8 +484,8 @@ TEST_CASE(five_branches)
t
.
get_stream
(
c2
.
back
()),
t
.
get_stream
(
c3
.
back
()),
t
.
get_stream
(
i1
)}));
check_conflicts
(
p
,
{
c1
,
c2
,
c3
,
c4
});
check_conflicts
(
p
,
{
c1
,
c2
,
c3
,
{
i1
}});
t
.
check_conflicts
(
p
,
{
c1
,
c2
,
c3
,
c4
});
t
.
check_conflicts
(
p
,
{
c1
,
c2
,
c3
,
{
i1
}});
}
TEST_CASE
(
four_branches_eq
)
...
...
@@ -502,7 +511,7 @@ TEST_CASE(four_branches_eq)
get_wait_for
(
t
.
get_stream
(
binary
),
{
t
.
get_stream
(
onep1
),
t
.
get_stream
(
onep2
),
t
.
get_stream
(
onep3
),
t
.
get_stream
(
onep4
)}));
check_conflicts
(
p
,
{{
onep1
},
{
onep2
},
{
onep3
},
{
onep4
}});
t
.
check_conflicts
(
p
,
{{
onep1
},
{
onep2
},
{
onep3
},
{
onep4
}});
}
TEST_CASE
(
seq_merge
)
...
...
@@ -527,7 +536,7 @@ TEST_CASE(seq_merge)
EXPECT
(
t
.
get_stream
(
binary1
)
==
t
.
get_stream
(
c1
.
back
()));
EXPECT
(
get_wait_for
(
binary1
)
==
get_wait_for
(
t
.
get_stream
(
binary1
),
{
t
.
get_stream
(
c1
.
back
()),
t
.
get_stream
(
i1
)}));
check_conflicts
(
p
,
{
c1
,
{
i1
}});
t
.
check_conflicts
(
p
,
{
c1
,
{
i1
}});
EXPECT
(
t
.
get_stream
(
i2
)
!=
t
.
get_stream
(
c2
.
back
()));
for
(
auto
ins
:
c2
)
...
...
@@ -535,7 +544,7 @@ TEST_CASE(seq_merge)
EXPECT
(
t
.
get_stream
(
binary2
)
==
0
);
EXPECT
(
get_wait_for
(
binary2
)
==
get_wait_for
(
t
.
get_stream
(
binary2
),
{
t
.
get_stream
(
c2
.
back
()),
t
.
get_stream
(
i2
)}));
check_conflicts
(
p
,
{
c2
,
{
i2
}});
t
.
check_conflicts
(
p
,
{
c2
,
{
i2
}});
}
TEST_CASE
(
par_merge
)
...
...
@@ -565,17 +574,17 @@ TEST_CASE(par_merge)
EXPECT
(
t
.
get_stream
(
binary1
)
==
0
);
EXPECT
(
get_wait_for
(
binary1
)
==
get_wait_for
(
t
.
get_stream
(
binary1
),
{
t
.
get_stream
(
c1
.
back
()),
t
.
get_stream
(
i1
)}));
check_conflicts
(
p
,
{
c1
,
{
i1
}});
t
.
check_conflicts
(
p
,
{
c1
,
{
i1
}});
for
(
auto
ins
:
c2
)
EXPECT
(
t
.
get_stream
(
ins
)
==
3
);
EXPECT
(
t
.
get_stream
(
binary2
)
==
3
);
EXPECT
(
get_wait_for
(
binary2
)
==
get_wait_for
(
t
.
get_stream
(
binary2
),
{
t
.
get_stream
(
c2
.
back
()),
t
.
get_stream
(
i2
)}));
check_conflicts
(
p
,
{
c2
,
{
i2
}});
t
.
check_conflicts
(
p
,
{
c2
,
{
i2
}});
EXPECT
(
check_conflicts
(
p
,
binary1
,
binary2
));
check_conflicts
(
p
,
{
c1
,
{
i1
},
c2
,
{
i2
}});
t
.
check_conflicts
(
p
,
{
c1
,
{
i1
},
c2
,
{
i2
}});
}
TEST_CASE
(
inner_par_merge
)
...
...
@@ -616,17 +625,17 @@ TEST_CASE(inner_par_merge)
EXPECT
(
t
.
get_stream
(
binary1
)
==
0
);
EXPECT
(
get_wait_for
(
binary1
)
==
get_wait_for
(
t
.
get_stream
(
binary1
),
{
t
.
get_stream
(
c1
.
back
()),
t
.
get_stream
(
i1
)}));
check_conflicts
(
p
,
{
c1
,
{
i1
}});
t
.
check_conflicts
(
p
,
{
c1
,
{
i1
}});
for
(
auto
ins
:
c2
)
EXPECT
(
t
.
get_stream
(
ins
)
==
3
);
EXPECT
(
t
.
get_stream
(
binary2
)
==
3
);
EXPECT
(
get_wait_for
(
binary2
)
==
get_wait_for
(
t
.
get_stream
(
binary2
),
{
t
.
get_stream
(
c2
.
back
()),
t
.
get_stream
(
i2
)}));
check_conflicts
(
p
,
{
c2
,
{
i2
}});
t
.
check_conflicts
(
p
,
{
c2
,
{
i2
}});
EXPECT
(
check_conflicts
(
p
,
binary1
,
binary2
));
check_conflicts
(
p
,
{
c1
,
{
i1
},
c2
,
{
i2
},
{
outer1
},
{
outer2
}});
t
.
check_conflicts
(
p
,
{
c1
,
{
i1
},
c2
,
{
i2
},
{
outer1
},
{
outer2
}});
}
TEST_CASE
(
par_merge_multi_entry
)
...
...
@@ -658,17 +667,17 @@ TEST_CASE(par_merge_multi_entry)
EXPECT
(
t
.
get_stream
(
binary1
)
==
0
);
EXPECT
(
get_wait_for
(
binary1
)
==
get_wait_for
(
t
.
get_stream
(
binary1
),
{
t
.
get_stream
(
c1
.
back
()),
t
.
get_stream
(
i1
)}));
check_conflicts
(
p
,
{
c1
,
{
i1
}});
t
.
check_conflicts
(
p
,
{
c1
,
{
i1
}});
for
(
auto
ins
:
c2
)
EXPECT
(
t
.
get_stream
(
ins
)
==
3
);
EXPECT
(
t
.
get_stream
(
binary2
)
==
3
);
EXPECT
(
get_wait_for
(
binary2
)
==
get_wait_for
(
t
.
get_stream
(
binary2
),
{
t
.
get_stream
(
c2
.
back
()),
t
.
get_stream
(
i2
)}));
check_conflicts
(
p
,
{
c2
,
{
i2
}});
t
.
check_conflicts
(
p
,
{
c2
,
{
i2
}});
EXPECT
(
check_conflicts
(
p
,
binary1
,
binary2
));
check_conflicts
(
p
,
{
c1
,
{
i1
},
c2
,
{
i2
}});
t
.
check_conflicts
(
p
,
{
c1
,
{
i1
},
c2
,
{
i2
}});
}
TEST_CASE
(
inner_split1
)
...
...
@@ -696,7 +705,7 @@ TEST_CASE(inner_split1)
EXPECT
(
get_wait_for
(
s1
).
empty
());
// TODO: Remove the extra wait here
// EXPECT(get_wait_for(s2).empty());
check_conflicts
(
p
,
{
c1
,
{
i1
},
{
s1
},
{
s2
}});
t
.
check_conflicts
(
p
,
{
c1
,
{
i1
},
{
s1
},
{
s2
}});
}
TEST_CASE
(
inner_split2
)
...
...
@@ -722,7 +731,7 @@ TEST_CASE(inner_split2)
get_wait_for
(
t
.
get_stream
(
output
),
{
t
.
get_stream
(
i1
),
t
.
get_stream
(
s1
.
back
()),
t
.
get_stream
(
s2
.
back
())}));
EXPECT
(
get_wait_for
(
s1
.
front
())
==
get_wait_for
({
t
.
get_stream
(
c1
.
back
())}));
check_conflicts
(
p
,
{
c1
,
{
i1
},
s1
,
s2
});
t
.
check_conflicts
(
p
,
{
c1
,
{
i1
},
s1
,
s2
});
}
TEST_CASE
(
inception_resnet
)
...
...
@@ -745,7 +754,7 @@ TEST_CASE(inception_resnet)
get_wait_for
(
t
.
get_stream
(
binary
),
{
t
.
get_stream
(
c1
.
back
()),
t
.
get_stream
(
i1
)}));
EXPECT
(
t
.
get_stream
(
output
)
==
0
);
EXPECT
(
get_wait_for
(
output
).
empty
());
check_conflicts
(
p
,
{
c1
,
{
i1
}});
t
.
check_conflicts
(
p
,
{
c1
,
{
i1
}});
}
TEST_CASE
(
inception1
)
...
...
@@ -866,7 +875,7 @@ TEST_CASE(inception1)
get_wait_for
(
t
.
get_stream
(
output
),
{
t
.
get_stream
(
i94
),
t
.
get_stream
(
i75
),
t
.
get_stream
(
i61
),
t
.
get_stream
(
i86
)}));
check_conflicts
(
p
,
{{
i80
,
i86
},
{
i69
,
i75
},
{
i48
,
i54
,
i61
},
{
i94
}});
t
.
check_conflicts
(
p
,
{{
i80
,
i86
},
{
i69
,
i75
},
{
i48
,
i54
,
i61
},
{
i94
}});
}
// Test-binary entry point: hands the command-line arguments to test::run.
int main(int argc, const char* argv[])
{
    test::run(argc, argv);
    // Explicit form of the value main returns implicitly when it falls off the end.
    return 0;
}
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment