Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
gaoqiong
MIGraphX
Commits
473a045a
Commit
473a045a
authored
Mar 08, 2019
by
Paul
Browse files
Fix memory conflicts
parent
84221940
Changes
3
Hide whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
97 additions
and
45 deletions
+97
-45
src/include/migraphx/iterator_for.hpp
src/include/migraphx/iterator_for.hpp
+54
-5
src/schedule.cpp
src/schedule.cpp
+38
-38
test/schedule_test.cpp
test/schedule_test.cpp
+5
-2
No files found.
src/include/migraphx/iterator_for.hpp
View file @
473a045a
...
@@ -3,21 +3,64 @@
...
@@ -3,21 +3,64 @@
#include <cassert>
#include <cassert>
#include <type_traits>
#include <type_traits>
#include <iterator>
#include <migraphx/config.hpp>
#include <migraphx/config.hpp>
namespace
migraphx
{
namespace
migraphx
{
inline
namespace
MIGRAPHX_INLINE_NS
{
inline
namespace
MIGRAPHX_INLINE_NS
{
template
<
class
T
>
struct
iterator_for_select
{
template
<
class
T
>
static
T
deref
(
T
x
)
{
return
x
;
}
template
<
class
T
>
static
auto
begin
(
T
*
x
)
{
return
x
->
begin
();
}
template
<
class
T
>
static
auto
end
(
T
*
x
)
{
return
x
->
end
();
}
};
struct
iterator_for_select_reverse
{
template
<
class
T
>
static
auto
deref
(
T
x
)
{
return
std
::
prev
(
x
.
base
());
}
template
<
class
T
>
static
auto
begin
(
T
*
x
)
{
return
std
::
make_reverse_iterator
(
x
->
end
());
}
template
<
class
T
>
static
auto
end
(
T
*
x
)
{
return
std
::
make_reverse_iterator
(
x
->
begin
());
}
};
template
<
class
T
,
class
Selector
=
iterator_for_select
>
struct
iterator_for_range
struct
iterator_for_range
{
{
T
*
base
;
T
*
base
;
using
base_iterator
=
std
::
remove_reference_t
<
decltype
(
base
->
begin
())
>
;
using
base_iterator
=
std
::
remove_reference_t
<
decltype
(
Selector
::
begin
(
base
))
>
;
struct
iterator
struct
iterator
{
{
base_iterator
i
;
base_iterator
i
;
base_iter
ato
r
operator
*
()
const
{
return
i
;
}
a
u
to
operator
*
()
const
{
return
Selector
::
deref
(
i
)
;
}
base_iterator
operator
++
()
{
return
++
i
;
}
base_iterator
operator
++
()
{
return
++
i
;
}
bool
operator
!=
(
const
iterator
&
rhs
)
const
{
return
i
!=
rhs
.
i
;
}
bool
operator
!=
(
const
iterator
&
rhs
)
const
{
return
i
!=
rhs
.
i
;
}
};
};
...
@@ -25,12 +68,12 @@ struct iterator_for_range
...
@@ -25,12 +68,12 @@ struct iterator_for_range
iterator
begin
()
iterator
begin
()
{
{
assert
(
base
!=
nullptr
);
assert
(
base
!=
nullptr
);
return
{
base
->
begin
()};
return
{
Selector
::
begin
(
base
)};
}
}
iterator
end
()
iterator
end
()
{
{
assert
(
base
!=
nullptr
);
assert
(
base
!=
nullptr
);
return
{
base
->
end
(
)};
return
{
Selector
::
end
(
base
)};
}
}
};
};
template
<
class
T
>
template
<
class
T
>
...
@@ -39,6 +82,12 @@ iterator_for_range<T> iterator_for(T& x)
...
@@ -39,6 +82,12 @@ iterator_for_range<T> iterator_for(T& x)
return
{
&
x
};
return
{
&
x
};
}
}
template
<
class
T
>
iterator_for_range
<
T
,
iterator_for_select_reverse
>
reverse_iterator_for
(
T
&
x
)
{
return
{
&
x
};
}
}
// namespace MIGRAPHX_INLINE_NS
}
// namespace MIGRAPHX_INLINE_NS
}
// namespace migraphx
}
// namespace migraphx
...
...
src/schedule.cpp
View file @
473a045a
...
@@ -162,7 +162,7 @@ struct stream_info
...
@@ -162,7 +162,7 @@ struct stream_info
}
}
template
<
class
Selector
>
template
<
class
Selector
>
auto
get_streams
(
instruction_ref
start
,
Selector
select
)
const
auto
get_streams
_from
(
instruction_ref
start
,
Selector
select
)
const
{
{
return
[
=
](
auto
f
)
{
return
[
=
](
auto
f
)
{
return
fix
<
bool
>
([
&
](
auto
self
,
auto
ins
)
{
return
fix
<
bool
>
([
&
](
auto
self
,
auto
ins
)
{
...
@@ -184,16 +184,28 @@ struct stream_info
...
@@ -184,16 +184,28 @@ struct stream_info
};
};
}
}
std
::
unordered_set
<
std
::
size_t
>
get_streams
(
instruction_ref
ins
)
{
if
(
has_stream
(
ins
))
return
{
get_stream
(
ins
)};
std
::
unordered_set
<
std
::
size_t
>
result
;
get_streams_from
(
ins
,
get_inputs
())([
&
](
auto
s
)
{
result
.
insert
(
s
);
return
true
;
});
return
result
;
}
template
<
class
...
Ts
>
template
<
class
...
Ts
>
bool
is_merge_point
(
instruction_ref
ins
,
Ts
...
xs
)
const
bool
is_merge_point
(
instruction_ref
ins
,
Ts
...
xs
)
const
{
{
return
different
(
get_streams
(
ins
,
get_inputs
()),
xs
...);
return
different
(
get_streams
_from
(
ins
,
get_inputs
()),
xs
...);
}
}
template
<
class
...
Ts
>
template
<
class
...
Ts
>
bool
is_split_point
(
instruction_ref
ins
,
Ts
...
xs
)
const
bool
is_split_point
(
instruction_ref
ins
,
Ts
...
xs
)
const
{
{
return
different
(
get_streams
(
ins
,
get_outputs
()),
xs
...);
return
different
(
get_streams
_from
(
ins
,
get_outputs
()),
xs
...);
}
}
std
::
vector
<
instruction_ref
>
get_recorded_instructions
(
instruction_ref
start
)
std
::
vector
<
instruction_ref
>
get_recorded_instructions
(
instruction_ref
start
)
...
@@ -225,7 +237,7 @@ struct stream_info
...
@@ -225,7 +237,7 @@ struct stream_info
std
::
vector
<
std
::
size_t
>
wait_for
(
instruction_ref
ins
)
const
std
::
vector
<
std
::
size_t
>
wait_for
(
instruction_ref
ins
)
const
{
{
std
::
vector
<
std
::
size_t
>
result
;
std
::
vector
<
std
::
size_t
>
result
;
get_streams
(
ins
,
get_inputs
())([
&
](
auto
s
)
{
get_streams
_from
(
ins
,
get_inputs
())([
&
](
auto
s
)
{
result
.
push_back
(
s
);
result
.
push_back
(
s
);
return
true
;
return
true
;
});
});
...
@@ -243,35 +255,27 @@ struct stream_info
...
@@ -243,35 +255,27 @@ struct stream_info
find_concurrent_instructions
(
program
&
p
)
find_concurrent_instructions
(
program
&
p
)
{
{
std
::
unordered_map
<
instruction_ref
,
std
::
vector
<
std
::
vector
<
instruction_ref
>>>
result
;
std
::
unordered_map
<
instruction_ref
,
std
::
vector
<
std
::
vector
<
instruction_ref
>>>
result
;
std
::
unordered_map
<
instruction_ref
,
std
::
unordered_set
<
instruction_ref
>>
split
_from
;
std
::
unordered_map
<
instruction_ref
,
std
::
unordered_set
<
instruction_ref
>>
merge
_from
;
for
(
auto
ins
:
iterator_for
(
p
))
for
(
auto
ins
:
reverse_
iterator_for
(
p
))
{
{
if
(
iweights
[
ins
]
==
0
)
for
(
auto
&&
arg
:
ins
->
outputs
())
continue
;
for
(
auto
&&
arg
:
ins
->
inputs
())
{
{
if
(
is_
split
_point
(
arg
))
if
(
is_
merge
_point
(
arg
))
split
_from
[
ins
].
insert
(
arg
);
merge
_from
[
ins
].
insert
(
arg
);
split
_from
[
ins
].
insert
(
split
_from
[
arg
].
begin
(),
split
_from
[
arg
].
end
());
merge
_from
[
ins
].
insert
(
merge
_from
[
arg
].
begin
(),
merge
_from
[
arg
].
end
());
}
}
auto
stream
=
get_stream
(
ins
);
auto
streams
=
get_streams
(
ins
);
// if (is_merge_point(ins))
// {
// Collect concur instructions for each merge point.
// // post-dominator kills split point.
for
(
auto
&
merge
:
merge_from
[
ins
])
// for(auto& split : split_from[ins])
// {
// if(strictly_post_dominates(ins, split))
// split_from[ins].erase(split);
// }
// }
// Collect concur instructions for each split point.
for
(
auto
&
split
:
split_from
[
ins
])
{
{
if
(
result
[
split
].
size
()
<=
stream
)
for
(
auto
stream
:
streams
)
result
[
split
].
resize
(
stream
+
1
);
{
result
[
split
][
stream
].
push_back
(
ins
);
if
(
result
[
merge
].
size
()
<=
stream
)
result
[
merge
].
resize
(
stream
+
1
);
result
[
merge
][
stream
].
push_back
(
ins
);
}
}
}
}
}
return
result
;
return
result
;
...
@@ -304,7 +308,7 @@ void schedule::apply(program& p) const
...
@@ -304,7 +308,7 @@ void schedule::apply(program& p) const
std
::
cout
<<
":"
;
std
::
cout
<<
":"
;
std
::
cout
<<
" weight="
<<
si
.
weights
.
at
(
ins
);
std
::
cout
<<
" weight="
<<
si
.
weights
.
at
(
ins
);
std
::
cout
<<
" input={"
;
std
::
cout
<<
" input={"
;
si
.
get_streams
(
ins
,
get_inputs
())([
&
](
auto
s
)
{
si
.
get_streams
_from
(
ins
,
get_inputs
())([
&
](
auto
s
)
{
std
::
cout
<<
s
<<
","
;
std
::
cout
<<
s
<<
","
;
return
true
;
return
true
;
});
});
...
@@ -367,20 +371,16 @@ void schedule::apply(program& p) const
...
@@ -367,20 +371,16 @@ void schedule::apply(program& p) const
// Add memory conflicts
// Add memory conflicts
auto
concur_ins
=
si
.
find_concurrent_instructions
(
p
);
auto
concur_ins
=
si
.
find_concurrent_instructions
(
p
);
for
(
auto
&&
split
:
concur_ins
)
for
(
auto
&&
merge
:
concur_ins
)
{
{
dfor
(
split
.
second
.
size
(),
split
.
second
.
size
())([
&
](
auto
i
,
auto
j
)
{
dfor
(
merge
.
second
.
size
(),
merge
.
second
.
size
())([
&
](
auto
i
,
auto
j
)
{
if
(
i
==
j
)
if
(
i
==
j
)
return
;
return
;
for
(
auto
ins1
:
split
.
second
[
i
])
for
(
auto
ins1
:
merge
.
second
[
i
])
{
{
auto
args
=
split
.
second
[
j
];
auto
args
=
merge
.
second
[
j
];
args
.
insert
(
args
.
begin
(),
ins1
);
args
.
insert
(
args
.
begin
(),
ins1
);
p
.
insert_instruction
(
merge
.
first
,
op
::
identity
{},
args
);
auto
point
=
std
::
max_element
(
args
.
begin
(),
args
.
end
(),
[
&
](
auto
x
,
auto
y
)
{
return
std
::
distance
(
split
.
first
,
x
)
<
std
::
distance
(
split
.
first
,
y
);
});
p
.
insert_instruction
(
std
::
next
(
*
point
),
op
::
identity
{},
args
);
}
}
});
});
}
}
...
...
test/schedule_test.cpp
View file @
473a045a
...
@@ -94,7 +94,7 @@ struct schedule_model_test
...
@@ -94,7 +94,7 @@ struct schedule_model_test
}
}
(
*
ins2wait_for
)[
ins
]
->
push_back
(
wait2stream
->
at
(
wait_id
));
(
*
ins2wait_for
)[
ins
]
->
push_back
(
wait2stream
->
at
(
wait_id
));
}
}
void
record
(
migraphx
::
program
&
p
,
migraphx
::
instruction_ref
ins
,
std
::
size_t
wait_id
)
const
void
record
(
migraphx
::
program
&
,
migraphx
::
instruction_ref
ins
,
std
::
size_t
wait_id
)
const
{
{
(
*
wait2stream
)[
wait_id
]
=
ins2stream
->
at
(
ins
);
(
*
wait2stream
)[
wait_id
]
=
ins2stream
->
at
(
ins
);
}
}
...
@@ -181,6 +181,9 @@ std::vector<std::size_t> get_wait_for(std::size_t wait_on, std::vector<std::size
...
@@ -181,6 +181,9 @@ std::vector<std::size_t> get_wait_for(std::size_t wait_on, std::vector<std::size
std
::
vector
<
std
::
size_t
>
get_wait_for
(
migraphx
::
instruction_ref
ins
)
std
::
vector
<
std
::
size_t
>
get_wait_for
(
migraphx
::
instruction_ref
ins
)
{
{
auto
wait_ins
=
std
::
prev
(
ins
);
auto
wait_ins
=
std
::
prev
(
ins
);
// Skip identity operators
while
(
wait_ins
->
name
()
==
"identity"
)
wait_ins
=
std
::
prev
(
wait_ins
);
if
(
wait_ins
->
name
()
!=
"wait_event"
)
if
(
wait_ins
->
name
()
!=
"wait_event"
)
return
{};
return
{};
auto
wf
=
*
migraphx
::
any_cast
<
wait_event
>
(
wait_ins
->
get_operator
()).
wait_for
;
auto
wf
=
*
migraphx
::
any_cast
<
wait_event
>
(
wait_ins
->
get_operator
()).
wait_for
;
...
@@ -338,7 +341,7 @@ TEST_CASE(double_entry)
...
@@ -338,7 +341,7 @@ TEST_CASE(double_entry)
EXPECT
(
t
.
get_stream
(
binary
)
==
0
);
EXPECT
(
t
.
get_stream
(
binary
)
==
0
);
EXPECT
(
get_wait_for
(
binary
)
==
EXPECT
(
get_wait_for
(
binary
)
==
get_wait_for
(
t
.
get_stream
(
binary
),
{
t
.
get_stream
(
onep
),
t
.
get_stream
(
twop
)}));
get_wait_for
(
t
.
get_stream
(
binary
),
{
t
.
get_stream
(
onep
),
t
.
get_stream
(
twop
)}));
//
EXPECT(check_conflicts(p, onep, twop));
EXPECT
(
check_conflicts
(
p
,
onep
,
twop
));
}
}
TEST_CASE
(
two_branches
)
TEST_CASE
(
two_branches
)
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment